clang 17.0.0git
RISCVVIntrinsicUtils.cpp
Go to the documentation of this file.
1//===- RISCVVIntrinsicUtils.cpp - RISC-V Vector Intrinsic Utils -*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
10#include "llvm/ADT/ArrayRef.h"
11#include "llvm/ADT/SmallSet.h"
12#include "llvm/ADT/StringExtras.h"
13#include "llvm/ADT/StringMap.h"
14#include "llvm/ADT/StringSet.h"
15#include "llvm/ADT/Twine.h"
16#include "llvm/Support/ErrorHandling.h"
17#include "llvm/Support/raw_ostream.h"
18#include <numeric>
19#include <optional>
20
21using namespace llvm;
22
23namespace clang {
24namespace RISCV {
25
26const PrototypeDescriptor PrototypeDescriptor::Mask = PrototypeDescriptor(
28const PrototypeDescriptor PrototypeDescriptor::VL =
29 PrototypeDescriptor(BaseTypeModifier::SizeT);
30const PrototypeDescriptor PrototypeDescriptor::Vector =
31 PrototypeDescriptor(BaseTypeModifier::Vector);
32
33//===----------------------------------------------------------------------===//
34// Type implementation
35//===----------------------------------------------------------------------===//
36
37LMULType::LMULType(int NewLog2LMUL) {
38 // Check Log2LMUL is -3, -2, -1, 0, 1, 2, 3
39 assert(NewLog2LMUL <= 3 && NewLog2LMUL >= -3 && "Bad LMUL number!");
40 Log2LMUL = NewLog2LMUL;
41}
42
43std::string LMULType::str() const {
44 if (Log2LMUL < 0)
45 return "mf" + utostr(1ULL << (-Log2LMUL));
46 return "m" + utostr(1ULL << Log2LMUL);
47}
48
49VScaleVal LMULType::getScale(unsigned ElementBitwidth) const {
50 int Log2ScaleResult = 0;
51 switch (ElementBitwidth) {
52 default:
53 break;
54 case 8:
55 Log2ScaleResult = Log2LMUL + 3;
56 break;
57 case 16:
58 Log2ScaleResult = Log2LMUL + 2;
59 break;
60 case 32:
61 Log2ScaleResult = Log2LMUL + 1;
62 break;
63 case 64:
64 Log2ScaleResult = Log2LMUL;
65 break;
66 }
67 // Illegal vscale result would be less than 1
68 if (Log2ScaleResult < 0)
69 return std::nullopt;
70 return 1 << Log2ScaleResult;
71}
72
73void LMULType::MulLog2LMUL(int log2LMUL) { Log2LMUL += log2LMUL; }
74
75RVVType::RVVType(BasicType BT, int Log2LMUL,
76 const PrototypeDescriptor &prototype)
77 : BT(BT), LMUL(LMULType(Log2LMUL)) {
78 applyBasicType();
79 applyModifier(prototype);
80 Valid = verifyType();
81 if (Valid) {
82 initBuiltinStr();
83 initTypeStr();
84 if (isVector()) {
85 initClangBuiltinStr();
86 }
87 }
88}
89
90// clang-format off
91// Boolean types are encoded by the ratio n = SEW/LMUL.
92// SEW/LMUL | 1 | 2 | 4 | 8 | 16 | 32 | 64
93// c type | vbool64_t | vbool32_t | vbool16_t | vbool8_t | vbool4_t | vbool2_t | vbool1_t
94// IR type | nxv1i1 | nxv2i1 | nxv4i1 | nxv8i1 | nxv16i1 | nxv32i1 | nxv64i1
95
96// type\lmul | 1/8 | 1/4 | 1/2 | 1 | 2 | 4 | 8
97// -------- |------ | -------- | ------- | ------- | -------- | -------- | --------
98// i64 | N/A | N/A | N/A | nxv1i64 | nxv2i64 | nxv4i64 | nxv8i64
99// i32 | N/A | N/A | nxv1i32 | nxv2i32 | nxv4i32 | nxv8i32 | nxv16i32
100// i16 | N/A | nxv1i16 | nxv2i16 | nxv4i16 | nxv8i16 | nxv16i16 | nxv32i16
101// i8 | nxv1i8 | nxv2i8 | nxv4i8 | nxv8i8 | nxv16i8 | nxv32i8 | nxv64i8
102// double | N/A | N/A | N/A | nxv1f64 | nxv2f64 | nxv4f64 | nxv8f64
103// float | N/A | N/A | nxv1f32 | nxv2f32 | nxv4f32 | nxv8f32 | nxv16f32
104// half | N/A | nxv1f16 | nxv2f16 | nxv4f16 | nxv8f16 | nxv16f16 | nxv32f16
105// clang-format on
106
107bool RVVType::verifyType() const {
108 if (ScalarType == Invalid)
109 return false;
110 if (isScalar())
111 return true;
112 if (!Scale)
113 return false;
114 if (isFloat() && ElementBitwidth == 8)
115 return false;
116 unsigned V = *Scale;
117 switch (ElementBitwidth) {
118 case 1:
119 case 8:
120 // Check Scale is 1,2,4,8,16,32,64
121 return (V <= 64 && isPowerOf2_32(V));
122 case 16:
123 // Check Scale is 1,2,4,8,16,32
124 return (V <= 32 && isPowerOf2_32(V));
125 case 32:
126 // Check Scale is 1,2,4,8,16
127 return (V <= 16 && isPowerOf2_32(V));
128 case 64:
129 // Check Scale is 1,2,4,8
130 return (V <= 8 && isPowerOf2_32(V));
131 }
132 return false;
133}
134
135void RVVType::initBuiltinStr() {
136 assert(isValid() && "RVVType is invalid");
137 switch (ScalarType) {
139 BuiltinStr = "v";
140 return;
142 BuiltinStr = "z";
143 if (IsImmediate)
144 BuiltinStr = "I" + BuiltinStr;
145 if (IsPointer)
146 BuiltinStr += "*";
147 return;
149 BuiltinStr = "Y";
150 return;
152 BuiltinStr = "ULi";
153 return;
155 BuiltinStr = "Li";
156 return;
158 assert(ElementBitwidth == 1);
159 BuiltinStr += "b";
160 break;
163 switch (ElementBitwidth) {
164 case 8:
165 BuiltinStr += "c";
166 break;
167 case 16:
168 BuiltinStr += "s";
169 break;
170 case 32:
171 BuiltinStr += "i";
172 break;
173 case 64:
174 BuiltinStr += "Wi";
175 break;
176 default:
177 llvm_unreachable("Unhandled ElementBitwidth!");
178 }
179 if (isSignedInteger())
180 BuiltinStr = "S" + BuiltinStr;
181 else
182 BuiltinStr = "U" + BuiltinStr;
183 break;
185 switch (ElementBitwidth) {
186 case 16:
187 BuiltinStr += "x";
188 break;
189 case 32:
190 BuiltinStr += "f";
191 break;
192 case 64:
193 BuiltinStr += "d";
194 break;
195 default:
196 llvm_unreachable("Unhandled ElementBitwidth!");
197 }
198 break;
199 default:
200 llvm_unreachable("ScalarType is invalid!");
201 }
202 if (IsImmediate)
203 BuiltinStr = "I" + BuiltinStr;
204 if (isScalar()) {
205 if (IsConstant)
206 BuiltinStr += "C";
207 if (IsPointer)
208 BuiltinStr += "*";
209 return;
210 }
211 BuiltinStr = "q" + utostr(*Scale) + BuiltinStr;
212 // Pointer to vector types. Defined for segment load intrinsics.
213 // segment load intrinsics have pointer type arguments to store the loaded
214 // vector values.
215 if (IsPointer)
216 BuiltinStr += "*";
217}
218
219void RVVType::initClangBuiltinStr() {
220 assert(isValid() && "RVVType is invalid");
221 assert(isVector() && "Handle Vector type only");
222
223 ClangBuiltinStr = "__rvv_";
224 switch (ScalarType) {
226 ClangBuiltinStr += "bool" + utostr(64 / *Scale) + "_t";
227 return;
229 ClangBuiltinStr += "float";
230 break;
232 ClangBuiltinStr += "int";
233 break;
235 ClangBuiltinStr += "uint";
236 break;
237 default:
238 llvm_unreachable("ScalarTypeKind is invalid");
239 }
240 ClangBuiltinStr += utostr(ElementBitwidth) + LMUL.str() + "_t";
241}
242
243void RVVType::initTypeStr() {
244 assert(isValid() && "RVVType is invalid");
245
246 if (IsConstant)
247 Str += "const ";
248
249 auto getTypeString = [&](StringRef TypeStr) {
250 if (isScalar())
251 return Twine(TypeStr + Twine(ElementBitwidth) + "_t").str();
252 return Twine("v" + TypeStr + Twine(ElementBitwidth) + LMUL.str() + "_t")
253 .str();
254 };
255
256 switch (ScalarType) {
258 Str = "void";
259 return;
261 Str = "size_t";
262 if (IsPointer)
263 Str += " *";
264 return;
266 Str = "ptrdiff_t";
267 return;
269 Str = "unsigned long";
270 return;
272 Str = "long";
273 return;
275 if (isScalar())
276 Str += "bool";
277 else
278 // Vector bool is special case, the formulate is
279 // `vbool<N>_t = MVT::nxv<64/N>i1` ex. vbool16_t = MVT::4i1
280 Str += "vbool" + utostr(64 / *Scale) + "_t";
281 break;
283 if (isScalar()) {
284 if (ElementBitwidth == 64)
285 Str += "double";
286 else if (ElementBitwidth == 32)
287 Str += "float";
288 else if (ElementBitwidth == 16)
289 Str += "_Float16";
290 else
291 llvm_unreachable("Unhandled floating type.");
292 } else
293 Str += getTypeString("float");
294 break;
296 Str += getTypeString("int");
297 break;
299 Str += getTypeString("uint");
300 break;
301 default:
302 llvm_unreachable("ScalarType is invalid!");
303 }
304 if (IsPointer)
305 Str += " *";
306}
307
308void RVVType::initShortStr() {
309 switch (ScalarType) {
311 assert(isVector());
312 ShortStr = "b" + utostr(64 / *Scale);
313 return;
315 ShortStr = "f" + utostr(ElementBitwidth);
316 break;
318 ShortStr = "i" + utostr(ElementBitwidth);
319 break;
321 ShortStr = "u" + utostr(ElementBitwidth);
322 break;
323 default:
324 llvm_unreachable("Unhandled case!");
325 }
326 if (isVector())
327 ShortStr += LMUL.str();
328}
329
330void RVVType::applyBasicType() {
331 switch (BT) {
332 case BasicType::Int8:
333 ElementBitwidth = 8;
335 break;
336 case BasicType::Int16:
337 ElementBitwidth = 16;
339 break;
340 case BasicType::Int32:
341 ElementBitwidth = 32;
343 break;
344 case BasicType::Int64:
345 ElementBitwidth = 64;
347 break;
349 ElementBitwidth = 16;
350 ScalarType = ScalarTypeKind::Float;
351 break;
353 ElementBitwidth = 32;
354 ScalarType = ScalarTypeKind::Float;
355 break;
357 ElementBitwidth = 64;
358 ScalarType = ScalarTypeKind::Float;
359 break;
360 default:
361 llvm_unreachable("Unhandled type code!");
362 }
363 assert(ElementBitwidth != 0 && "Bad element bitwidth!");
364}
365
366std::optional<PrototypeDescriptor>
368 llvm::StringRef PrototypeDescriptorStr) {
372
373 if (PrototypeDescriptorStr.empty())
374 return PD;
375
376 // Handle base type modifier
377 auto PType = PrototypeDescriptorStr.back();
378 switch (PType) {
379 case 'e':
381 break;
382 case 'v':
384 break;
385 case 'w':
388 break;
389 case 'q':
392 break;
393 case 'o':
396 break;
397 case 'm':
400 break;
401 case '0':
403 break;
404 case 'z':
406 break;
407 case 't':
409 break;
410 case 'u':
412 break;
413 case 'l':
415 break;
416 default:
417 llvm_unreachable("Illegal primitive type transformers!");
418 }
419 PD.PT = static_cast<uint8_t>(PT);
420 PrototypeDescriptorStr = PrototypeDescriptorStr.drop_back();
421
422 // Compute the vector type transformers, it can only appear one time.
423 if (PrototypeDescriptorStr.startswith("(")) {
425 "VectorTypeModifier should only have one modifier");
426 size_t Idx = PrototypeDescriptorStr.find(')');
427 assert(Idx != StringRef::npos);
428 StringRef ComplexType = PrototypeDescriptorStr.slice(1, Idx);
429 PrototypeDescriptorStr = PrototypeDescriptorStr.drop_front(Idx + 1);
430 assert(!PrototypeDescriptorStr.contains('(') &&
431 "Only allow one vector type modifier");
432
433 auto ComplexTT = ComplexType.split(":");
434 if (ComplexTT.first == "Log2EEW") {
435 uint32_t Log2EEW;
436 if (ComplexTT.second.getAsInteger(10, Log2EEW)) {
437 llvm_unreachable("Invalid Log2EEW value!");
438 return std::nullopt;
439 }
440 switch (Log2EEW) {
441 case 3:
443 break;
444 case 4:
446 break;
447 case 5:
449 break;
450 case 6:
452 break;
453 default:
454 llvm_unreachable("Invalid Log2EEW value, should be [3-6]");
455 return std::nullopt;
456 }
457 } else if (ComplexTT.first == "FixedSEW") {
458 uint32_t NewSEW;
459 if (ComplexTT.second.getAsInteger(10, NewSEW)) {
460 llvm_unreachable("Invalid FixedSEW value!");
461 return std::nullopt;
462 }
463 switch (NewSEW) {
464 case 8:
466 break;
467 case 16:
469 break;
470 case 32:
472 break;
473 case 64:
475 break;
476 default:
477 llvm_unreachable("Invalid FixedSEW value, should be 8, 16, 32 or 64");
478 return std::nullopt;
479 }
480 } else if (ComplexTT.first == "LFixedLog2LMUL") {
481 int32_t Log2LMUL;
482 if (ComplexTT.second.getAsInteger(10, Log2LMUL)) {
483 llvm_unreachable("Invalid LFixedLog2LMUL value!");
484 return std::nullopt;
485 }
486 switch (Log2LMUL) {
487 case -3:
489 break;
490 case -2:
492 break;
493 case -1:
495 break;
496 case 0:
498 break;
499 case 1:
501 break;
502 case 2:
504 break;
505 case 3:
507 break;
508 default:
509 llvm_unreachable("Invalid LFixedLog2LMUL value, should be [-3, 3]");
510 return std::nullopt;
511 }
512 } else if (ComplexTT.first == "SFixedLog2LMUL") {
513 int32_t Log2LMUL;
514 if (ComplexTT.second.getAsInteger(10, Log2LMUL)) {
515 llvm_unreachable("Invalid SFixedLog2LMUL value!");
516 return std::nullopt;
517 }
518 switch (Log2LMUL) {
519 case -3:
521 break;
522 case -2:
524 break;
525 case -1:
527 break;
528 case 0:
530 break;
531 case 1:
533 break;
534 case 2:
536 break;
537 case 3:
539 break;
540 default:
541 llvm_unreachable("Invalid LFixedLog2LMUL value, should be [-3, 3]");
542 return std::nullopt;
543 }
544
545 } else {
546 llvm_unreachable("Illegal complex type transformers!");
547 }
548 }
549 PD.VTM = static_cast<uint8_t>(VTM);
550
551 // Compute the remain type transformers
553 for (char I : PrototypeDescriptorStr) {
554 switch (I) {
555 case 'P':
557 llvm_unreachable("'P' transformer cannot be used after 'C'");
559 llvm_unreachable("'P' transformer cannot be used twice");
561 break;
562 case 'C':
564 break;
565 case 'K':
567 break;
568 case 'U':
570 break;
571 case 'I':
573 break;
574 case 'F':
576 break;
577 case 'S':
579 break;
580 default:
581 llvm_unreachable("Illegal non-primitive type transformer!");
582 }
583 }
584 PD.TM = static_cast<uint8_t>(TM);
585
586 return PD;
587}
588
589void RVVType::applyModifier(const PrototypeDescriptor &Transformer) {
590 // Handle primitive type transformer
591 switch (static_cast<BaseTypeModifier>(Transformer.PT)) {
593 Scale = 0;
594 break;
596 Scale = LMUL.getScale(ElementBitwidth);
597 break;
599 ScalarType = ScalarTypeKind::Void;
600 break;
602 ScalarType = ScalarTypeKind::Size_t;
603 break;
605 ScalarType = ScalarTypeKind::Ptrdiff_t;
606 break;
608 ScalarType = ScalarTypeKind::UnsignedLong;
609 break;
611 ScalarType = ScalarTypeKind::SignedLong;
612 break;
614 ScalarType = ScalarTypeKind::Invalid;
615 return;
616 }
617
618 switch (static_cast<VectorTypeModifier>(Transformer.VTM)) {
620 ElementBitwidth *= 2;
621 LMUL.MulLog2LMUL(1);
622 Scale = LMUL.getScale(ElementBitwidth);
623 break;
625 ElementBitwidth *= 4;
626 LMUL.MulLog2LMUL(2);
627 Scale = LMUL.getScale(ElementBitwidth);
628 break;
630 ElementBitwidth *= 8;
631 LMUL.MulLog2LMUL(3);
632 Scale = LMUL.getScale(ElementBitwidth);
633 break;
635 ScalarType = ScalarTypeKind::Boolean;
636 Scale = LMUL.getScale(ElementBitwidth);
637 ElementBitwidth = 1;
638 break;
640 applyLog2EEW(3);
641 break;
643 applyLog2EEW(4);
644 break;
646 applyLog2EEW(5);
647 break;
649 applyLog2EEW(6);
650 break;
652 applyFixedSEW(8);
653 break;
655 applyFixedSEW(16);
656 break;
658 applyFixedSEW(32);
659 break;
661 applyFixedSEW(64);
662 break;
664 applyFixedLog2LMUL(-3, FixedLMULType::LargerThan);
665 break;
667 applyFixedLog2LMUL(-2, FixedLMULType::LargerThan);
668 break;
670 applyFixedLog2LMUL(-1, FixedLMULType::LargerThan);
671 break;
673 applyFixedLog2LMUL(0, FixedLMULType::LargerThan);
674 break;
676 applyFixedLog2LMUL(1, FixedLMULType::LargerThan);
677 break;
679 applyFixedLog2LMUL(2, FixedLMULType::LargerThan);
680 break;
682 applyFixedLog2LMUL(3, FixedLMULType::LargerThan);
683 break;
685 applyFixedLog2LMUL(-3, FixedLMULType::SmallerThan);
686 break;
688 applyFixedLog2LMUL(-2, FixedLMULType::SmallerThan);
689 break;
691 applyFixedLog2LMUL(-1, FixedLMULType::SmallerThan);
692 break;
694 applyFixedLog2LMUL(0, FixedLMULType::SmallerThan);
695 break;
697 applyFixedLog2LMUL(1, FixedLMULType::SmallerThan);
698 break;
700 applyFixedLog2LMUL(2, FixedLMULType::SmallerThan);
701 break;
703 applyFixedLog2LMUL(3, FixedLMULType::SmallerThan);
704 break;
706 break;
707 }
708
709 for (unsigned TypeModifierMaskShift = 0;
710 TypeModifierMaskShift <= static_cast<unsigned>(TypeModifier::MaxOffset);
711 ++TypeModifierMaskShift) {
712 unsigned TypeModifierMask = 1 << TypeModifierMaskShift;
713 if ((static_cast<unsigned>(Transformer.TM) & TypeModifierMask) !=
714 TypeModifierMask)
715 continue;
716 switch (static_cast<TypeModifier>(TypeModifierMask)) {
718 IsPointer = true;
719 break;
721 IsConstant = true;
722 break;
724 IsImmediate = true;
725 IsConstant = true;
726 break;
729 break;
732 break;
734 ScalarType = ScalarTypeKind::Float;
735 break;
737 LMUL = LMULType(0);
738 // Update ElementBitwidth need to update Scale too.
739 Scale = LMUL.getScale(ElementBitwidth);
740 break;
741 default:
742 llvm_unreachable("Unknown type modifier mask!");
743 }
744 }
745}
746
747void RVVType::applyLog2EEW(unsigned Log2EEW) {
748 // update new elmul = (eew/sew) * lmul
749 LMUL.MulLog2LMUL(Log2EEW - Log2_32(ElementBitwidth));
750 // update new eew
751 ElementBitwidth = 1 << Log2EEW;
753 Scale = LMUL.getScale(ElementBitwidth);
754}
755
756void RVVType::applyFixedSEW(unsigned NewSEW) {
757 // Set invalid type if src and dst SEW are same.
758 if (ElementBitwidth == NewSEW) {
759 ScalarType = ScalarTypeKind::Invalid;
760 return;
761 }
762 // Update new SEW
763 ElementBitwidth = NewSEW;
764 Scale = LMUL.getScale(ElementBitwidth);
765}
766
767void RVVType::applyFixedLog2LMUL(int Log2LMUL, enum FixedLMULType Type) {
768 switch (Type) {
769 case FixedLMULType::LargerThan:
770 if (Log2LMUL < LMUL.Log2LMUL) {
771 ScalarType = ScalarTypeKind::Invalid;
772 return;
773 }
774 break;
775 case FixedLMULType::SmallerThan:
776 if (Log2LMUL > LMUL.Log2LMUL) {
777 ScalarType = ScalarTypeKind::Invalid;
778 return;
779 }
780 break;
781 }
782
783 // Update new LMUL
784 LMUL = LMULType(Log2LMUL);
785 Scale = LMUL.getScale(ElementBitwidth);
786}
787
788std::optional<RVVTypes>
789RVVTypeCache::computeTypes(BasicType BT, int Log2LMUL, unsigned NF,
791 // LMUL x NF must be less than or equal to 8.
792 if ((Log2LMUL >= 1) && (1 << Log2LMUL) * NF > 8)
793 return std::nullopt;
794
795 RVVTypes Types;
796 for (const PrototypeDescriptor &Proto : Prototype) {
797 auto T = computeType(BT, Log2LMUL, Proto);
798 if (!T)
799 return std::nullopt;
800 // Record legal type index
801 Types.push_back(*T);
802 }
803 return Types;
804}
805
806// Compute the hash value of RVVType, used for cache the result of computeType.
807static uint64_t computeRVVTypeHashValue(BasicType BT, int Log2LMUL,
808 PrototypeDescriptor Proto) {
809 // Layout of hash value:
810 // 0 8 16 24 32 40
811 // | Log2LMUL + 3 | BT | Proto.PT | Proto.TM | Proto.VTM |
812 assert(Log2LMUL >= -3 && Log2LMUL <= 3);
813 return (Log2LMUL + 3) | (static_cast<uint64_t>(BT) & 0xff) << 8 |
814 ((uint64_t)(Proto.PT & 0xff) << 16) |
815 ((uint64_t)(Proto.TM & 0xff) << 24) |
816 ((uint64_t)(Proto.VTM & 0xff) << 32);
817}
818
819std::optional<RVVTypePtr> RVVTypeCache::computeType(BasicType BT, int Log2LMUL,
820 PrototypeDescriptor Proto) {
821 uint64_t Idx = computeRVVTypeHashValue(BT, Log2LMUL, Proto);
822 // Search first
823 auto It = LegalTypes.find(Idx);
824 if (It != LegalTypes.end())
825 return &(It->second);
826
827 if (IllegalTypes.count(Idx))
828 return std::nullopt;
829
830 // Compute type and record the result.
831 RVVType T(BT, Log2LMUL, Proto);
832 if (T.isValid()) {
833 // Record legal type index and value.
834 std::pair<std::unordered_map<uint64_t, RVVType>::iterator, bool>
835 InsertResult = LegalTypes.insert({Idx, T});
836 return &(InsertResult.first->second);
837 }
838 // Record illegal type index.
839 IllegalTypes.insert(Idx);
840 return std::nullopt;
841}
842
843//===----------------------------------------------------------------------===//
844// RVVIntrinsic implementation
845//===----------------------------------------------------------------------===//
846RVVIntrinsic::RVVIntrinsic(StringRef NewName, StringRef Suffix,
847 StringRef NewOverloadedName,
848 StringRef OverloadedSuffix, StringRef IRName,
849 bool IsMasked, bool HasMaskedOffOperand, bool HasVL,
850 PolicyScheme Scheme, bool SupportOverloading,
851 bool HasBuiltinAlias, StringRef ManualCodegen,
852 const RVVTypes &OutInTypes,
853 const std::vector<int64_t> &NewIntrinsicTypes,
854 const std::vector<StringRef> &RequiredFeatures,
855 unsigned NF, Policy NewPolicyAttrs)
856 : IRName(IRName), IsMasked(IsMasked),
857 HasMaskedOffOperand(HasMaskedOffOperand), HasVL(HasVL), Scheme(Scheme),
858 SupportOverloading(SupportOverloading), HasBuiltinAlias(HasBuiltinAlias),
859 ManualCodegen(ManualCodegen.str()), NF(NF), PolicyAttrs(NewPolicyAttrs) {
860
861 // Init BuiltinName, Name and OverloadedName
862 BuiltinName = NewName.str();
863 Name = BuiltinName;
864 if (NewOverloadedName.empty())
865 OverloadedName = NewName.split("_").first.str();
866 else
867 OverloadedName = NewOverloadedName.str();
868 if (!Suffix.empty())
869 Name += "_" + Suffix.str();
870 if (!OverloadedSuffix.empty())
871 OverloadedName += "_" + OverloadedSuffix.str();
872
873 updateNamesAndPolicy(IsMasked, hasPolicy(), Name, BuiltinName, OverloadedName,
874 PolicyAttrs);
875
876 // Init OutputType and InputTypes
877 OutputType = OutInTypes[0];
878 InputTypes.assign(OutInTypes.begin() + 1, OutInTypes.end());
879
880 // IntrinsicTypes is unmasked TA version index. Need to update it
881 // if there is merge operand (It is always in first operand).
882 IntrinsicTypes = NewIntrinsicTypes;
883 if ((IsMasked && hasMaskedOffOperand()) ||
884 (!IsMasked && hasPassthruOperand())) {
885 for (auto &I : IntrinsicTypes) {
886 if (I >= 0)
887 I += NF;
888 }
889 }
890}
891
893 std::string S;
894 S += OutputType->getBuiltinStr();
895 for (const auto &T : InputTypes) {
896 S += T->getBuiltinStr();
897 }
898 return S;
899}
900
902 RVVTypeCache &TypeCache, BasicType Type, int Log2LMUL,
903 llvm::ArrayRef<PrototypeDescriptor> PrototypeDescriptors) {
904 SmallVector<std::string> SuffixStrs;
905 for (auto PD : PrototypeDescriptors) {
906 auto T = TypeCache.computeType(Type, Log2LMUL, PD);
907 SuffixStrs.push_back((*T)->getShortStr());
908 }
909 return join(SuffixStrs, "_");
910}
911
914 bool HasMaskedOffOperand, bool HasVL, unsigned NF,
915 PolicyScheme DefaultScheme, Policy PolicyAttrs) {
916 SmallVector<PrototypeDescriptor> NewPrototype(Prototype.begin(),
917 Prototype.end());
918 bool HasPassthruOp = DefaultScheme == PolicyScheme::HasPassthruOperand;
919 if (IsMasked) {
920 // If HasMaskedOffOperand, insert result type as first input operand if
921 // need.
922 if (HasMaskedOffOperand && !PolicyAttrs.isTAMAPolicy()) {
923 if (NF == 1) {
924 NewPrototype.insert(NewPrototype.begin() + 1, NewPrototype[0]);
925 } else if (NF > 1) {
926 // Convert
927 // (void, op0 address, op1 address, ...)
928 // to
929 // (void, op0 address, op1 address, ..., maskedoff0, maskedoff1, ...)
930 PrototypeDescriptor MaskoffType = NewPrototype[1];
931 MaskoffType.TM &= ~static_cast<uint8_t>(TypeModifier::Pointer);
932 NewPrototype.insert(NewPrototype.begin() + NF + 1, NF, MaskoffType);
933 }
934 }
935 if (HasMaskedOffOperand && NF > 1) {
936 // Convert
937 // (void, op0 address, op1 address, ..., maskedoff0, maskedoff1, ...)
938 // to
939 // (void, op0 address, op1 address, ..., mask, maskedoff0, maskedoff1,
940 // ...)
941 NewPrototype.insert(NewPrototype.begin() + NF + 1,
943 } else {
944 // If IsMasked, insert PrototypeDescriptor:Mask as first input operand.
945 NewPrototype.insert(NewPrototype.begin() + 1, PrototypeDescriptor::Mask);
946 }
947 } else {
948 if (NF == 1) {
949 if (PolicyAttrs.isTUPolicy() && HasPassthruOp)
950 NewPrototype.insert(NewPrototype.begin(), NewPrototype[0]);
951 } else if (PolicyAttrs.isTUPolicy() && HasPassthruOp) {
952 // NF > 1 cases for segment load operations.
953 // Convert
954 // (void, op0 address, op1 address, ...)
955 // to
956 // (void, op0 address, op1 address, maskedoff0, maskedoff1, ...)
957 PrototypeDescriptor MaskoffType = Prototype[1];
958 MaskoffType.TM &= ~static_cast<uint8_t>(TypeModifier::Pointer);
959 NewPrototype.insert(NewPrototype.begin() + NF + 1, NF, MaskoffType);
960 }
961 }
962
963 // If HasVL, append PrototypeDescriptor:VL to last operand
964 if (HasVL)
965 NewPrototype.push_back(PrototypeDescriptor::VL);
966 return NewPrototype;
967}
968
971}
972
975 bool HasMaskPolicy) {
976 if (HasTailPolicy && HasMaskPolicy)
983 if (HasTailPolicy && !HasMaskPolicy)
986 if (!HasTailPolicy && HasMaskPolicy)
989 llvm_unreachable("An RVV instruction should not be without both tail policy "
990 "and mask policy");
991}
992
993void RVVIntrinsic::updateNamesAndPolicy(bool IsMasked, bool HasPolicy,
994 std::string &Name,
995 std::string &BuiltinName,
996 std::string &OverloadedName,
997 Policy &PolicyAttrs) {
998
999 auto appendPolicySuffix = [&](const std::string &suffix) {
1000 Name += suffix;
1001 BuiltinName += suffix;
1002 OverloadedName += suffix;
1003 };
1004
1005 // This follows the naming guideline under riscv-c-api-doc to add the
1006 // `__riscv_` suffix for all RVV intrinsics.
1007 Name = "__riscv_" + Name;
1008 OverloadedName = "__riscv_" + OverloadedName;
1009
1010 if (IsMasked) {
1011 if (PolicyAttrs.isTUMUPolicy())
1012 appendPolicySuffix("_tumu");
1013 else if (PolicyAttrs.isTUMAPolicy())
1014 appendPolicySuffix("_tum");
1015 else if (PolicyAttrs.isTAMUPolicy())
1016 appendPolicySuffix("_mu");
1017 else if (PolicyAttrs.isTAMAPolicy()) {
1018 Name += "_m";
1019 if (HasPolicy)
1020 BuiltinName += "_tama";
1021 else
1022 BuiltinName += "_m";
1023 } else
1024 llvm_unreachable("Unhandled policy condition");
1025 } else {
1026 if (PolicyAttrs.isTUPolicy())
1027 appendPolicySuffix("_tu");
1028 else if (PolicyAttrs.isTAPolicy()) {
1029 if (HasPolicy)
1030 BuiltinName += "_ta";
1031 } else
1032 llvm_unreachable("Unhandled policy condition");
1033 }
1034}
1035
1037 SmallVector<PrototypeDescriptor> PrototypeDescriptors;
1038 const StringRef Primaries("evwqom0ztul");
1039 while (!Prototypes.empty()) {
1040 size_t Idx = 0;
1041 // Skip over complex prototype because it could contain primitive type
1042 // character.
1043 if (Prototypes[0] == '(')
1044 Idx = Prototypes.find_first_of(')');
1045 Idx = Prototypes.find_first_of(Primaries, Idx);
1046 assert(Idx != StringRef::npos);
1048 Prototypes.slice(0, Idx + 1));
1049 if (!PD)
1050 llvm_unreachable("Error during parsing prototype.");
1051 PrototypeDescriptors.push_back(*PD);
1052 Prototypes = Prototypes.drop_front(Idx + 1);
1053 }
1054 return PrototypeDescriptors;
1055}
1056
1057raw_ostream &operator<<(raw_ostream &OS, const RVVIntrinsicRecord &Record) {
1058 OS << "{";
1059 OS << "\"" << Record.Name << "\",";
1060 if (Record.OverloadedName == nullptr ||
1061 StringRef(Record.OverloadedName).empty())
1062 OS << "nullptr,";
1063 else
1064 OS << "\"" << Record.OverloadedName << "\",";
1065 OS << Record.PrototypeIndex << ",";
1066 OS << Record.SuffixIndex << ",";
1067 OS << Record.OverloadedSuffixIndex << ",";
1068 OS << (int)Record.PrototypeLength << ",";
1069 OS << (int)Record.SuffixLength << ",";
1070 OS << (int)Record.OverloadedSuffixSize << ",";
1071 OS << (int)Record.RequiredExtensions << ",";
1072 OS << (int)Record.TypeRangeMask << ",";
1073 OS << (int)Record.Log2LMULMask << ",";
1074 OS << (int)Record.NF << ",";
1075 OS << (int)Record.HasMasked << ",";
1076 OS << (int)Record.HasVL << ",";
1077 OS << (int)Record.HasMaskedOffOperand << ",";
1078 OS << (int)Record.HasTailPolicy << ",";
1079 OS << (int)Record.HasMaskPolicy << ",";
1080 OS << (int)Record.UnMaskedPolicyScheme << ",";
1081 OS << (int)Record.MaskedPolicyScheme << ",";
1082 OS << "},\n";
1083 return OS;
1084}
1085
1086} // end namespace RISCV
1087} // end namespace clang
#define V(N, I)
Definition: ASTContext.h:3217
static bool getTypeString(SmallStringEnc &Enc, const Decl *D, const CodeGen::CodeGenModule &CGM, TypeStringCache &TSC)
The XCore ABI includes a type information section that communicates symbol type information to the li...
static bool isVector(QualType QT, QualType ElementType)
This helper function returns true if QT is a vector type that has element type ElementType.
Definition: SemaExpr.cpp:9751
__device__ int
Complex values, per C99 6.2.5p11.
Definition: Type.h:2735
RVVIntrinsic(llvm::StringRef Name, llvm::StringRef Suffix, llvm::StringRef OverloadedName, llvm::StringRef OverloadedSuffix, llvm::StringRef IRName, bool IsMasked, bool HasMaskedOffOperand, bool HasVL, PolicyScheme Scheme, bool SupportOverloading, bool HasBuiltinAlias, llvm::StringRef ManualCodegen, const RVVTypes &Types, const std::vector< int64_t > &IntrinsicTypes, const std::vector< llvm::StringRef > &RequiredFeatures, unsigned NF, Policy PolicyAttrs)
static llvm::SmallVector< Policy > getSupportedMaskedPolicies(bool HasTailPolicy, bool HasMaskPolicy)
static std::string getSuffixStr(RVVTypeCache &TypeCache, BasicType Type, int Log2LMUL, llvm::ArrayRef< PrototypeDescriptor > PrototypeDescriptors)
static void updateNamesAndPolicy(bool IsMasked, bool HasPolicy, std::string &Name, std::string &BuiltinName, std::string &OverloadedName, Policy &PolicyAttrs)
static llvm::SmallVector< PrototypeDescriptor > computeBuiltinTypes(llvm::ArrayRef< PrototypeDescriptor > Prototype, bool IsMasked, bool HasMaskedOffOperand, bool HasVL, unsigned NF, PolicyScheme DefaultScheme, Policy PolicyAttrs)
static llvm::SmallVector< Policy > getSupportedUnMaskedPolicies()
std::string getBuiltinTypeStr() const
std::optional< RVVTypePtr > computeType(BasicType BT, int Log2LMUL, PrototypeDescriptor Proto)
std::optional< RVVTypes > computeTypes(BasicType BT, int Log2LMUL, unsigned NF, llvm::ArrayRef< PrototypeDescriptor > Prototype)
Compute output and input types by applying different config (basic type and LMUL with type transforme...
const std::string & getBuiltinStr() const
The base class of the type hierarchy.
Definition: Type.h:1566
llvm::raw_ostream & operator<<(llvm::raw_ostream &OS, const RVVIntrinsicRecord &RVVInstrRecord)
llvm::SmallVector< PrototypeDescriptor > parsePrototypes(llvm::StringRef Prototypes)
static uint64_t computeRVVTypeHashValue(BasicType BT, int Log2LMUL, PrototypeDescriptor Proto)
std::optional< unsigned > VScaleVal
std::vector< RVVTypePtr > RVVTypes
YAML serialization mapping.
Definition: Dominators.h:30
std::optional< unsigned > getScale(unsigned ElementBitwidth) const
void MulLog2LMUL(int Log2LMUL)
static std::optional< PrototypeDescriptor > parsePrototypeDescriptor(llvm::StringRef PrototypeStr)
static const PrototypeDescriptor VL
static const PrototypeDescriptor Mask
static const PrototypeDescriptor Vector