clang 23.0.0git
SemaHLSL.cpp
Go to the documentation of this file.
1//===- SemaHLSL.cpp - Semantic Analysis for HLSL constructs ---------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8// This implements Semantic Analysis for HLSL constructs.
9//===----------------------------------------------------------------------===//
10
11#include "clang/Sema/SemaHLSL.h"
14#include "clang/AST/Attr.h"
15#include "clang/AST/Decl.h"
16#include "clang/AST/DeclBase.h"
17#include "clang/AST/DeclCXX.h"
20#include "clang/AST/Expr.h"
22#include "clang/AST/Type.h"
23#include "clang/AST/TypeBase.h"
24#include "clang/AST/TypeLoc.h"
28#include "clang/Basic/LLVM.h"
33#include "clang/Sema/Lookup.h"
35#include "clang/Sema/Sema.h"
36#include "clang/Sema/Template.h"
37#include "llvm/ADT/ArrayRef.h"
38#include "llvm/ADT/STLExtras.h"
39#include "llvm/ADT/SmallVector.h"
40#include "llvm/ADT/StringExtras.h"
41#include "llvm/ADT/StringRef.h"
42#include "llvm/ADT/Twine.h"
43#include "llvm/Frontend/HLSL/HLSLBinding.h"
44#include "llvm/Frontend/HLSL/RootSignatureValidations.h"
45#include "llvm/Support/Casting.h"
46#include "llvm/Support/DXILABI.h"
47#include "llvm/Support/ErrorHandling.h"
48#include "llvm/Support/FormatVariadic.h"
49#include "llvm/TargetParser/Triple.h"
50#include <cmath>
51#include <cstddef>
52#include <iterator>
53#include <utility>
54
55using namespace clang;
56using namespace clang::hlsl;
57using RegisterType = HLSLResourceBindingAttr::RegisterType;
58
60 CXXRecordDecl *StructDecl);
61
63 switch (RC) {
64 case ResourceClass::SRV:
65 return RegisterType::SRV;
66 case ResourceClass::UAV:
67 return RegisterType::UAV;
68 case ResourceClass::CBuffer:
69 return RegisterType::CBuffer;
70 case ResourceClass::Sampler:
71 return RegisterType::Sampler;
72 }
73 llvm_unreachable("unexpected ResourceClass value");
74}
75
76static RegisterType getRegisterType(const HLSLAttributedResourceType *ResTy) {
77 return getRegisterType(ResTy->getAttrs().ResourceClass);
78}
79
81 switch (RC) {
82 case ResourceClass::SRV:
83 case ResourceClass::UAV:
85 case ResourceClass::CBuffer:
87 case ResourceClass::Sampler:
89 }
90 llvm_unreachable("unexpected ResourceClass value");
91}
92
93// Converts the first letter of string Slot to RegisterType.
94// Returns false if the letter does not correspond to a valid register type.
95static bool convertToRegisterType(StringRef Slot, RegisterType *RT) {
96 assert(RT != nullptr);
97 switch (Slot[0]) {
98 case 't':
99 case 'T':
100 *RT = RegisterType::SRV;
101 return true;
102 case 'u':
103 case 'U':
104 *RT = RegisterType::UAV;
105 return true;
106 case 'b':
107 case 'B':
108 *RT = RegisterType::CBuffer;
109 return true;
110 case 's':
111 case 'S':
112 *RT = RegisterType::Sampler;
113 return true;
114 case 'c':
115 case 'C':
116 *RT = RegisterType::C;
117 return true;
118 case 'i':
119 case 'I':
120 *RT = RegisterType::I;
121 return true;
122 default:
123 return false;
124 }
125}
126
128 switch (RT) {
129 case RegisterType::SRV:
130 return 't';
131 case RegisterType::UAV:
132 return 'u';
133 case RegisterType::CBuffer:
134 return 'b';
135 case RegisterType::Sampler:
136 return 's';
137 case RegisterType::C:
138 return 'c';
139 case RegisterType::I:
140 return 'i';
141 }
142 llvm_unreachable("unexpected RegisterType value");
143}
144
146 switch (RT) {
147 case RegisterType::SRV:
148 return ResourceClass::SRV;
149 case RegisterType::UAV:
150 return ResourceClass::UAV;
151 case RegisterType::CBuffer:
152 return ResourceClass::CBuffer;
153 case RegisterType::Sampler:
154 return ResourceClass::Sampler;
155 case RegisterType::C:
156 case RegisterType::I:
157 // Deliberately falling through to the unreachable below.
158 break;
159 }
160 llvm_unreachable("unexpected RegisterType value");
161}
162
164 const auto *BT = dyn_cast<BuiltinType>(Type);
165 if (!BT) {
166 if (!Type->isEnumeralType())
167 return Builtin::NotBuiltin;
168 return Builtin::BI__builtin_get_spirv_spec_constant_int;
169 }
170
171 switch (BT->getKind()) {
172 case BuiltinType::Bool:
173 return Builtin::BI__builtin_get_spirv_spec_constant_bool;
174 case BuiltinType::Short:
175 return Builtin::BI__builtin_get_spirv_spec_constant_short;
176 case BuiltinType::Int:
177 return Builtin::BI__builtin_get_spirv_spec_constant_int;
178 case BuiltinType::LongLong:
179 return Builtin::BI__builtin_get_spirv_spec_constant_longlong;
180 case BuiltinType::UShort:
181 return Builtin::BI__builtin_get_spirv_spec_constant_ushort;
182 case BuiltinType::UInt:
183 return Builtin::BI__builtin_get_spirv_spec_constant_uint;
184 case BuiltinType::ULongLong:
185 return Builtin::BI__builtin_get_spirv_spec_constant_ulonglong;
186 case BuiltinType::Half:
187 return Builtin::BI__builtin_get_spirv_spec_constant_half;
188 case BuiltinType::Float:
189 return Builtin::BI__builtin_get_spirv_spec_constant_float;
190 case BuiltinType::Double:
191 return Builtin::BI__builtin_get_spirv_spec_constant_double;
192 default:
193 return Builtin::NotBuiltin;
194 }
195}
196
// Builds the textual register name for RegType and index N: the register-type
// letter returned by getRegisterTypeChar followed by the decimal value of N.
// The text is streamed into `Buffer` through a raw_svector_ostream and then
// handed to AST.backupStr. That is presumably done so the returned StringRef
// is backed by ASTContext-owned storage rather than by `Buffer`.
static StringRef createRegisterString(ASTContext &AST, RegisterType RegType,
                                      unsigned N) {
  llvm::raw_svector_ostream OS(Buffer);
  OS << getRegisterTypeChar(RegType);
  OS << N;
  return AST.backupStr(OS.str());
}
205
207 ResourceClass ResClass) {
208 assert(getDeclBindingInfo(VD, ResClass) == nullptr &&
209 "DeclBindingInfo already added");
210 assert(!hasBindingInfoForDecl(VD) || BindingsList.back().Decl == VD);
211 // VarDecl may have multiple entries for different resource classes.
212 // DeclToBindingListIndex stores the index of the first binding we saw
213 // for this decl. If there are any additional ones then that index
214 // shouldn't be updated.
215 DeclToBindingListIndex.try_emplace(VD, BindingsList.size());
216 return &BindingsList.emplace_back(VD, ResClass);
217}
218
220 ResourceClass ResClass) {
221 auto Entry = DeclToBindingListIndex.find(VD);
222 if (Entry != DeclToBindingListIndex.end()) {
223 for (unsigned Index = Entry->getSecond();
224 Index < BindingsList.size() && BindingsList[Index].Decl == VD;
225 ++Index) {
226 if (BindingsList[Index].ResClass == ResClass)
227 return &BindingsList[Index];
228 }
229 }
230 return nullptr;
231}
232
234 return DeclToBindingListIndex.contains(VD);
235}
236
238
// Begins a `cbuffer` (CBuffer == true) or `tbuffer` (CBuffer == false) block.
// The steps are:
// - create the buffer declaration `Result` in the current lexical context;
// - attach an implicit HLSLResourceClassAttr (CBuffer for cbuffer, SRV for
//   tbuffer);
// - push `Result` onto BufferScope's scope chains;
// - make `Result` the current DeclContext.
// Returns the new declaration.
Decl *SemaHLSL::ActOnStartBuffer(Scope *BufferScope, bool CBuffer,
                                 SourceLocation KwLoc, IdentifierInfo *Ident,
                                 SourceLocation IdentLoc,
                                 SourceLocation LBrace) {
  // The buffer declaration is parented to the current lexical context. The
  // keyword, identifier and left-brace locations are passed along when it is
  // constructed.
  DeclContext *LexicalParent = SemaRef.getCurLexicalContext();
      getASTContext(), LexicalParent, CBuffer, KwLoc, Ident, IdentLoc, LBrace);

  // if CBuffer is false, then it's a TBuffer
  auto RC = CBuffer ? llvm::hlsl::ResourceClass::CBuffer
                    : llvm::hlsl::ResourceClass::SRV;
  Result->addAttr(HLSLResourceClassAttr::CreateImplicit(getASTContext(), RC));

  // Make the buffer visible in BufferScope, then enter it as the current
  // DeclContext for the declarations that follow inside the braces.
  SemaRef.PushOnScopeChains(Result, BufferScope);
  SemaRef.PushDeclContext(BufferScope, Result);

  return Result;
}
258
259static unsigned calculateLegacyCbufferFieldAlign(const ASTContext &Context,
260 QualType T) {
261 // Arrays, Matrices, and Structs are always aligned to new buffer rows
262 if (T->isArrayType() || T->isStructureType() || T->isConstantMatrixType())
263 return 16;
264
265 // Vectors are aligned to the type they contain
266 if (const VectorType *VT = T->getAs<VectorType>())
267 return calculateLegacyCbufferFieldAlign(Context, VT->getElementType());
268
269 assert(Context.getTypeSize(T) <= 64 &&
270 "Scalar bit widths larger than 64 not supported");
271
272 // Scalar types are aligned to their byte width
273 return Context.getTypeSize(T) / 8;
274}
275
276// Calculate the size of a legacy cbuffer type in bytes based on
277// https://learn.microsoft.com/en-us/windows/win32/direct3dhlsl/dx-graphics-hlsl-packing-rules
278static unsigned calculateLegacyCbufferSize(const ASTContext &Context,
279 QualType T) {
280 constexpr unsigned CBufferAlign = 16;
281 if (const auto *RD = T->getAsRecordDecl()) {
282 unsigned Size = 0;
283 for (const FieldDecl *Field : RD->fields()) {
284 QualType Ty = Field->getType();
285 unsigned FieldSize = calculateLegacyCbufferSize(Context, Ty);
286 unsigned FieldAlign = calculateLegacyCbufferFieldAlign(Context, Ty);
287
288 // If the field crosses the row boundary after alignment it drops to the
289 // next row
290 unsigned AlignSize = llvm::alignTo(Size, FieldAlign);
291 if ((AlignSize % CBufferAlign) + FieldSize > CBufferAlign) {
292 FieldAlign = CBufferAlign;
293 }
294
295 Size = llvm::alignTo(Size, FieldAlign);
296 Size += FieldSize;
297 }
298 return Size;
299 }
300
301 if (const ConstantArrayType *AT = Context.getAsConstantArrayType(T)) {
302 unsigned ElementCount = AT->getSize().getZExtValue();
303 if (ElementCount == 0)
304 return 0;
305
306 unsigned ElementSize =
307 calculateLegacyCbufferSize(Context, AT->getElementType());
308 unsigned AlignedElementSize = llvm::alignTo(ElementSize, CBufferAlign);
309 return AlignedElementSize * (ElementCount - 1) + ElementSize;
310 }
311
312 if (const VectorType *VT = T->getAs<VectorType>()) {
313 unsigned ElementCount = VT->getNumElements();
314 unsigned ElementSize =
315 calculateLegacyCbufferSize(Context, VT->getElementType());
316 return ElementSize * ElementCount;
317 }
318
319 return Context.getTypeSize(T) / 8;
320}
321
// Validate packoffset:
// - if packoffset is used it must be set on all declarations inside the buffer
// - packoffset ranges must not overlap
// Only VarDecls inside the buffer are considered. Mixing annotated and
// unannotated declarations is diagnosed as a warning (warn_hlsl_packoffset_mix).
// Overlapping ranges are diagnosed as errors (err_hlsl_packoffset_overlap).
// The buffer is marked with setHasValidPackoffset(false) when any overlap is
// found. Nothing is recorded if no declaration uses packoffset.
static void validatePackoffset(Sema &S, HLSLBufferDecl *BufDecl) {

  // Make sure the packoffset annotations are either on all declarations
  // or on none.
  bool HasPackOffset = false;
  bool HasNonPackOffset = false;
  for (auto *Field : BufDecl->buffer_decls()) {
    // Non-variable declarations in the buffer do not take part in the layout.
    VarDecl *Var = dyn_cast<VarDecl>(Field);
    if (!Var)
      continue;
    if (Field->hasAttr<HLSLPackOffsetAttr>()) {
      PackOffsetVec.emplace_back(Var, Field->getAttr<HLSLPackOffsetAttr>());
      HasPackOffset = true;
    } else {
      HasNonPackOffset = true;
    }
  }

  if (!HasPackOffset)
    return;

  if (HasNonPackOffset)
    S.Diag(BufDecl->getLocation(), diag::warn_hlsl_packoffset_mix);

  // Make sure there is no overlap in packoffset - sort PackOffsetVec by offset
  // and compare adjacent values.
  bool IsValid = true;
  ASTContext &Context = S.getASTContext();
  std::sort(PackOffsetVec.begin(), PackOffsetVec.end(),
            [](const std::pair<VarDecl *, HLSLPackOffsetAttr *> &LHS,
               const std::pair<VarDecl *, HLSLPackOffsetAttr *> &RHS) {
              return LHS.second->getOffsetInBytes() <
                     RHS.second->getOffsetInBytes();
            });
  // PackOffsetVec has at least one entry here (HasPackOffset is true), so
  // size() - 1 does not wrap around.
  for (unsigned i = 0; i < PackOffsetVec.size() - 1; i++) {
    VarDecl *Var = PackOffsetVec[i].first;
    HLSLPackOffsetAttr *Attr = PackOffsetVec[i].second;
    // [Begin, End) is the byte range occupied by Var under the legacy layout.
    unsigned Size = calculateLegacyCbufferSize(Context, Var->getType());
    unsigned Begin = Attr->getOffsetInBytes();
    unsigned End = Begin + Size;
    unsigned NextBegin = PackOffsetVec[i + 1].second->getOffsetInBytes();
    if (End > NextBegin) {
      // Report the overlap on the later declaration, naming the earlier one.
      VarDecl *NextVar = PackOffsetVec[i + 1].first;
      S.Diag(NextVar->getLocation(), diag::err_hlsl_packoffset_overlap)
          << NextVar << Var;
      IsValid = false;
    }
  }
  BufDecl->setHasValidPackoffset(IsValid);
}
376
// Returns true if the array has a zero size, i.e. if any of its dimensions
// is 0. Returns false for a null CAT.
// The loop stops when it reaches a zero-sized ConstantArrayType level (result:
// true). It also stops when the dyn_cast to the next level (presumably the
// element type) is not a ConstantArrayType and yields nullptr (result: false).
static bool isZeroSizedArray(const ConstantArrayType *CAT) {
  while (CAT && !CAT->isZeroSize())
    CAT = dyn_cast<ConstantArrayType>(
  return CAT != nullptr;
}
384
388
392
393static const HLSLAttributedResourceType *
395 assert(QT->isHLSLResourceRecordArray() &&
396 "expected array of resource records");
397 const Type *Ty = QT->getUnqualifiedDesugaredType();
398 while (const ArrayType *AT = dyn_cast<ArrayType>(Ty))
400 return HLSLAttributedResourceType::findHandleTypeOnResource(Ty);
401}
402
403static const HLSLAttributedResourceType *
407
408// Returns true if the type is a leaf element type that is not valid to be
409// included in HLSL Buffer, such as a resource class, empty struct, zero-sized
// array, or a builtin intangible type. Returns false if it is a valid leaf
// element type or if it is a record type that needs to be inspected further.
415 return true;
416 if (const auto *RD = Ty->getAsCXXRecordDecl())
417 return RD->isEmpty();
418 if (Ty->isConstantArrayType() &&
420 return true;
422 return true;
423 return false;
424}
425
426// Returns true if the struct contains at least one element that prevents it
427// from being included inside HLSL Buffer as is, such as an intangible type,
428// empty struct, or zero-sized array. If it does, a new implicit layout struct
429// needs to be created for HLSL Buffer use that will exclude these unwanted
430// declarations (see createHostLayoutStruct function).
432 if (RD->isHLSLIntangible() || RD->isEmpty())
433 return true;
434 // check fields
435 for (const FieldDecl *Field : RD->fields()) {
436 QualType Ty = Field->getType();
438 return true;
439 if (const auto *RD = Ty->getAsCXXRecordDecl();
441 return true;
442 }
443 // check bases
444 for (const CXXBaseSpecifier &Base : RD->bases())
446 Base.getType()->castAsCXXRecordDecl()))
447 return true;
448 return false;
449}
450
452 DeclContext *DC) {
453 CXXRecordDecl *RD = nullptr;
454 for (NamedDecl *Decl :
456 if (CXXRecordDecl *FoundRD = dyn_cast<CXXRecordDecl>(Decl)) {
457 assert(RD == nullptr &&
458 "there should be at most 1 record by a given name in a scope");
459 RD = FoundRD;
460 }
461 }
462 return RD;
463}
464
// Creates a name for a buffer layout struct using the provided name base.
466// If the name must be unique (not previously defined), a suffix is added
467// until a unique name is found.
469 bool MustBeUnique) {
470 ASTContext &AST = S.getASTContext();
471
472 IdentifierInfo *NameBaseII = BaseDecl->getIdentifier();
473 llvm::SmallString<64> Name("__cblayout_");
474 if (NameBaseII) {
475 Name.append(NameBaseII->getName());
476 } else {
477 // anonymous struct
478 Name.append("anon");
479 MustBeUnique = true;
480 }
481
482 size_t NameLength = Name.size();
483 IdentifierInfo *II = &AST.Idents.get(Name, tok::TokenKind::identifier);
484 if (!MustBeUnique)
485 return II;
486
487 unsigned suffix = 0;
488 while (true) {
489 if (suffix != 0) {
490 Name.append("_");
491 Name.append(llvm::Twine(suffix).str());
492 II = &AST.Idents.get(Name, tok::TokenKind::identifier);
493 }
494 if (!findRecordDeclInContext(II, BaseDecl->getDeclContext()))
495 return II;
496 // declaration with that name already exists - increment suffix and try
497 // again until unique name is found
498 suffix++;
499 Name.truncate(NameLength);
500 };
501}
502
// Returns the type to use for Ty inside an HLSL buffer host layout struct.
// - Record types are either returned unchanged (the early return) or replaced
//   by the struct built by createHostLayoutStruct. nullptr is returned when
//   that helper produces no struct.
// - Constant arrays are rebuilt over the host layout element type. The
//   original size, size modifier and index CVR qualifiers are kept; element
//   qualifiers are dropped (QualType(ElementTy, 0)). nullptr propagates from
//   the element type.
// - All other types are returned unchanged.
static const Type *createHostLayoutType(Sema &S, const Type *Ty) {
  ASTContext &AST = S.getASTContext();
  if (auto *RD = Ty->getAsCXXRecordDecl()) {
    return Ty;
    RD = createHostLayoutStruct(S, RD);
    if (!RD)
      return nullptr;
    return AST.getCanonicalTagType(RD)->getTypePtr();
  }

  if (const auto *CAT = dyn_cast<ConstantArrayType>(Ty)) {
    // Convert the unqualified, desugared element type first.
    const Type *ElementTy = createHostLayoutType(
        S, CAT->getElementType()->getUnqualifiedDesugaredType());
    if (!ElementTy)
      return nullptr;
    return AST
        .getConstantArrayType(QualType(ElementTy, 0), CAT->getSize(), nullptr,
                              CAT->getSizeModifier(),
                              CAT->getIndexTypeCVRQualifiers())
        .getTypePtr();
  }
  return Ty;
}
527
528// Creates a field declaration of given name and type for HLSL buffer layout
// struct. Returns nullptr if the type cannot be used in an HLSL Buffer layout.
531 IdentifierInfo *II,
532 CXXRecordDecl *LayoutStruct) {
534 return nullptr;
535
536 Ty = createHostLayoutType(S, Ty);
537 if (!Ty)
538 return nullptr;
539
540 QualType QT = QualType(Ty, 0);
541 ASTContext &AST = S.getASTContext();
543 auto *Field = FieldDecl::Create(AST, LayoutStruct, SourceLocation(),
544 SourceLocation(), II, QT, TSI, nullptr, false,
546 Field->setAccess(AccessSpecifier::AS_public);
547 return Field;
548}
549
550// Creates host layout struct for a struct included in HLSL Buffer.
551// The layout struct will include only fields that are allowed in HLSL buffer.
552// These fields will be filtered out:
553// - resource classes
554// - empty structs
555// - zero-sized arrays
556// Returns nullptr if the resulting layout struct would be empty.
558 CXXRecordDecl *StructDecl) {
559 assert(requiresImplicitBufferLayoutStructure(StructDecl) &&
560 "struct is already HLSL buffer compatible");
561
562 ASTContext &AST = S.getASTContext();
563 DeclContext *DC = StructDecl->getDeclContext();
564 IdentifierInfo *II = getHostLayoutStructName(S, StructDecl, false);
565
  // Reuse the layout struct if it already exists.
567 if (CXXRecordDecl *RD = findRecordDeclInContext(II, DC))
568 return RD;
569
570 CXXRecordDecl *LS =
571 CXXRecordDecl::Create(AST, TagDecl::TagKind::Struct, DC, SourceLocation(),
572 SourceLocation(), II);
573 LS->setImplicit(true);
574 LS->addAttr(PackedAttr::CreateImplicit(AST));
575 LS->startDefinition();
576
577 // copy base struct, create HLSL Buffer compatible version if needed
578 if (unsigned NumBases = StructDecl->getNumBases()) {
579 assert(NumBases == 1 && "HLSL supports only one base type");
580 (void)NumBases;
581 CXXBaseSpecifier Base = *StructDecl->bases_begin();
582 CXXRecordDecl *BaseDecl = Base.getType()->castAsCXXRecordDecl();
584 BaseDecl = createHostLayoutStruct(S, BaseDecl);
585 if (BaseDecl) {
586 TypeSourceInfo *TSI =
588 Base = CXXBaseSpecifier(SourceRange(), false, StructDecl->isClass(),
589 AS_none, TSI, SourceLocation());
590 }
591 }
592 if (BaseDecl) {
593 const CXXBaseSpecifier *BasesArray[1] = {&Base};
594 LS->setBases(BasesArray, 1);
595 }
596 }
597
598 // filter struct fields
599 for (const FieldDecl *FD : StructDecl->fields()) {
600 const Type *Ty = FD->getType()->getUnqualifiedDesugaredType();
601 if (FieldDecl *NewFD =
602 createFieldForHostLayoutStruct(S, Ty, FD->getIdentifier(), LS))
603 LS->addDecl(NewFD);
604 }
605 LS->completeDefinition();
606
607 if (LS->field_empty() && LS->getNumBases() == 0)
608 return nullptr;
609
610 DC->addDecl(LS);
611 return LS;
612}
613
614// Creates host layout struct for HLSL Buffer. The struct will include only
615// fields of types that are allowed in HLSL buffer and it will filter out:
616// - static or groupshared variable declarations
617// - resource classes
618// - empty structs
619// - zero-sized arrays
620// - non-variable declarations
621// The layout struct will be added to the HLSLBufferDecl declarations.
623 ASTContext &AST = S.getASTContext();
624 IdentifierInfo *II = getHostLayoutStructName(S, BufDecl, true);
625
626 CXXRecordDecl *LS =
627 CXXRecordDecl::Create(AST, TagDecl::TagKind::Struct, BufDecl,
629 LS->addAttr(PackedAttr::CreateImplicit(AST));
630 LS->setImplicit(true);
631 LS->startDefinition();
632
633 for (Decl *D : BufDecl->buffer_decls()) {
634 VarDecl *VD = dyn_cast<VarDecl>(D);
635 if (!VD || VD->getStorageClass() == SC_Static ||
637 continue;
638 const Type *Ty = VD->getType()->getUnqualifiedDesugaredType();
639
640 FieldDecl *FD =
642 // Declarations collected for the default $Globals constant buffer have
643 // already been checked to have non-empty cbuffer layout, so
644 // createFieldForHostLayoutStruct should always succeed. These declarations
645 // already have their address space set to hlsl_constant.
646 // For declarations in a named cbuffer block
647 // createFieldForHostLayoutStruct can still return nullptr if the type
648 // is empty (does not have a cbuffer layout).
649 assert((FD || VD->getType().getAddressSpace() != LangAS::hlsl_constant) &&
650 "host layout field for $Globals decl failed to be created");
651 if (FD) {
652 // Add the field decl to the layout struct.
653 LS->addDecl(FD);
655 // Update address space of the original decl to hlsl_constant.
656 QualType NewTy =
658 VD->setType(NewTy);
659 }
660 }
661 }
662 LS->completeDefinition();
663 BufDecl->addLayoutStruct(LS);
664}
665
667 uint32_t ImplicitBindingOrderID) {
668 auto *Attr =
669 HLSLResourceBindingAttr::CreateImplicit(S.getASTContext(), "", "0", {});
670 Attr->setBinding(RT, std::nullopt, 0);
671 Attr->setImplicitBindingOrderID(ImplicitBindingOrderID);
672 D->addAttr(Attr);
673}
674
675// Handle end of cbuffer/tbuffer declaration
677 auto *BufDecl = cast<HLSLBufferDecl>(Dcl);
678 BufDecl->setRBraceLoc(RBrace);
679
680 validatePackoffset(SemaRef, BufDecl);
681
683
684 // Handle implicit binding if needed.
685 ResourceBindingAttrs ResourceAttrs(Dcl);
686 if (!ResourceAttrs.isExplicit()) {
687 SemaRef.Diag(Dcl->getLocation(), diag::warn_hlsl_implicit_binding);
688 // Use HLSLResourceBindingAttr to transfer implicit binding order_ID
689 // to codegen. If it does not exist, create an implicit attribute.
690 uint32_t OrderID = getNextImplicitBindingOrderID();
691 if (ResourceAttrs.hasBinding())
692 ResourceAttrs.setImplicitOrderID(OrderID);
693 else
695 BufDecl->isCBuffer() ? RegisterType::CBuffer
696 : RegisterType::SRV,
697 OrderID);
698 }
699
700 SemaRef.PopDeclContext();
701}
702
703HLSLNumThreadsAttr *SemaHLSL::mergeNumThreadsAttr(Decl *D,
704 const AttributeCommonInfo &AL,
705 int X, int Y, int Z) {
706 if (HLSLNumThreadsAttr *NT = D->getAttr<HLSLNumThreadsAttr>()) {
707 if (NT->getX() != X || NT->getY() != Y || NT->getZ() != Z) {
708 Diag(NT->getLocation(), diag::err_hlsl_attribute_param_mismatch) << AL;
709 Diag(AL.getLoc(), diag::note_conflicting_attribute);
710 }
711 return nullptr;
712 }
713 return ::new (getASTContext())
714 HLSLNumThreadsAttr(getASTContext(), AL, X, Y, Z);
715}
716
718 const AttributeCommonInfo &AL,
719 int Min, int Max, int Preferred,
720 int SpelledArgsCount) {
721 if (HLSLWaveSizeAttr *WS = D->getAttr<HLSLWaveSizeAttr>()) {
722 if (WS->getMin() != Min || WS->getMax() != Max ||
723 WS->getPreferred() != Preferred ||
724 WS->getSpelledArgsCount() != SpelledArgsCount) {
725 Diag(WS->getLocation(), diag::err_hlsl_attribute_param_mismatch) << AL;
726 Diag(AL.getLoc(), diag::note_conflicting_attribute);
727 }
728 return nullptr;
729 }
730 HLSLWaveSizeAttr *Result = ::new (getASTContext())
731 HLSLWaveSizeAttr(getASTContext(), AL, Min, Max, Preferred);
732 Result->setSpelledArgsCount(SpelledArgsCount);
733 return Result;
734}
735
736HLSLVkConstantIdAttr *
738 int Id) {
739
741 if (TargetInfo.getTriple().getArch() != llvm::Triple::spirv) {
742 Diag(AL.getLoc(), diag::warn_attribute_ignored) << AL;
743 return nullptr;
744 }
745
746 auto *VD = cast<VarDecl>(D);
747
748 if (getSpecConstBuiltinId(VD->getType()->getUnqualifiedDesugaredType()) ==
750 Diag(VD->getLocation(), diag::err_specialization_const);
751 return nullptr;
752 }
753
754 if (!VD->getType().isConstQualified()) {
755 Diag(VD->getLocation(), diag::err_specialization_const);
756 return nullptr;
757 }
758
759 if (HLSLVkConstantIdAttr *CI = D->getAttr<HLSLVkConstantIdAttr>()) {
760 if (CI->getId() != Id) {
761 Diag(CI->getLocation(), diag::err_hlsl_attribute_param_mismatch) << AL;
762 Diag(AL.getLoc(), diag::note_conflicting_attribute);
763 }
764 return nullptr;
765 }
766
767 HLSLVkConstantIdAttr *Result =
768 ::new (getASTContext()) HLSLVkConstantIdAttr(getASTContext(), AL, Id);
769 return Result;
770}
771
772HLSLShaderAttr *
774 llvm::Triple::EnvironmentType ShaderType) {
775 if (HLSLShaderAttr *NT = D->getAttr<HLSLShaderAttr>()) {
776 if (NT->getType() != ShaderType) {
777 Diag(NT->getLocation(), diag::err_hlsl_attribute_param_mismatch) << AL;
778 Diag(AL.getLoc(), diag::note_conflicting_attribute);
779 }
780 return nullptr;
781 }
782 return HLSLShaderAttr::Create(getASTContext(), ShaderType, AL);
783}
784
785HLSLParamModifierAttr *
787 HLSLParamModifierAttr::Spelling Spelling) {
788 // We can only merge an `in` attribute with an `out` attribute. All other
789 // combinations of duplicated attributes are ill-formed.
790 if (HLSLParamModifierAttr *PA = D->getAttr<HLSLParamModifierAttr>()) {
791 if ((PA->isIn() && Spelling == HLSLParamModifierAttr::Keyword_out) ||
792 (PA->isOut() && Spelling == HLSLParamModifierAttr::Keyword_in)) {
793 D->dropAttr<HLSLParamModifierAttr>();
794 SourceRange AdjustedRange = {PA->getLocation(), AL.getRange().getEnd()};
795 return HLSLParamModifierAttr::Create(
796 getASTContext(), /*MergedSpelling=*/true, AdjustedRange,
797 HLSLParamModifierAttr::Keyword_inout);
798 }
799 Diag(AL.getLoc(), diag::err_hlsl_duplicate_parameter_modifier) << AL;
800 Diag(PA->getLocation(), diag::note_conflicting_attribute);
801 return nullptr;
802 }
803 return HLSLParamModifierAttr::Create(getASTContext(), AL);
804}
805
808
810 return;
811
812 // If we have specified a root signature to override the entry function then
813 // attach it now
814 HLSLRootSignatureDecl *SignatureDecl =
816 if (SignatureDecl) {
817 FD->dropAttr<RootSignatureAttr>();
818 // We could look up the SourceRange of the macro here as well
819 AttributeCommonInfo AL(RootSigOverrideIdent, AttributeScopeInfo(),
820 SourceRange(), ParsedAttr::Form::Microsoft());
821 FD->addAttr(::new (getASTContext()) RootSignatureAttr(
822 getASTContext(), AL, RootSigOverrideIdent, SignatureDecl));
823 }
824
825 llvm::Triple::EnvironmentType Env = TargetInfo.getTriple().getEnvironment();
826 if (HLSLShaderAttr::isValidShaderType(Env) && Env != llvm::Triple::Library) {
827 if (const auto *Shader = FD->getAttr<HLSLShaderAttr>()) {
828 // The entry point is already annotated - check that it matches the
829 // triple.
830 if (Shader->getType() != Env) {
831 Diag(Shader->getLocation(), diag::err_hlsl_entry_shader_attr_mismatch)
832 << Shader;
833 FD->setInvalidDecl();
834 }
835 } else {
836 // Implicitly add the shader attribute if the entry function isn't
837 // explicitly annotated.
838 FD->addAttr(HLSLShaderAttr::CreateImplicit(getASTContext(), Env,
839 FD->getBeginLoc()));
840 }
841 } else {
842 switch (Env) {
843 case llvm::Triple::UnknownEnvironment:
844 case llvm::Triple::Library:
845 break;
846 case llvm::Triple::RootSignature:
847 llvm_unreachable("rootsig environment has no functions");
848 default:
849 llvm_unreachable("Unhandled environment in triple");
850 }
851 }
852}
853
854static bool isVkPipelineBuiltin(const ASTContext &AstContext, FunctionDecl *FD,
855 HLSLAppliedSemanticAttr *Semantic,
856 bool IsInput) {
857 if (AstContext.getTargetInfo().getTriple().getOS() != llvm::Triple::Vulkan)
858 return false;
859
860 const auto *ShaderAttr = FD->getAttr<HLSLShaderAttr>();
861 assert(ShaderAttr && "Entry point has no shader attribute");
862 llvm::Triple::EnvironmentType ST = ShaderAttr->getType();
863 auto SemanticName = Semantic->getSemanticName().upper();
864
865 // The SV_Position semantic is lowered to:
866 // - Position built-in for vertex output.
867 // - FragCoord built-in for fragment input.
868 if (SemanticName == "SV_POSITION") {
869 return (ST == llvm::Triple::Vertex && !IsInput) ||
870 (ST == llvm::Triple::Pixel && IsInput);
871 }
872 if (SemanticName == "SV_VERTEXID")
873 return true;
874
875 return false;
876}
877
878bool SemaHLSL::determineActiveSemanticOnScalar(FunctionDecl *FD,
879 DeclaratorDecl *OutputDecl,
881 SemanticInfo &ActiveSemantic,
882 SemaHLSL::SemanticContext &SC) {
883 if (ActiveSemantic.Semantic == nullptr) {
884 ActiveSemantic.Semantic = D->getAttr<HLSLParsedSemanticAttr>();
885 if (ActiveSemantic.Semantic)
886 ActiveSemantic.Index = ActiveSemantic.Semantic->getSemanticIndex();
887 }
888
889 if (!ActiveSemantic.Semantic) {
890 Diag(D->getLocation(), diag::err_hlsl_missing_semantic_annotation);
891 return false;
892 }
893
894 auto *A = ::new (getASTContext())
895 HLSLAppliedSemanticAttr(getASTContext(), *ActiveSemantic.Semantic,
896 ActiveSemantic.Semantic->getAttrName()->getName(),
897 ActiveSemantic.Index.value_or(0));
898 if (!A)
900
901 checkSemanticAnnotation(FD, D, A, SC);
902 OutputDecl->addAttr(A);
903
904 unsigned Location = ActiveSemantic.Index.value_or(0);
905
907 SC.CurrentIOType & IOType::In)) {
908 bool HasVkLocation = false;
909 if (auto *A = D->getAttr<HLSLVkLocationAttr>()) {
910 HasVkLocation = true;
911 Location = A->getLocation();
912 }
913
914 if (SC.UsesExplicitVkLocations.value_or(HasVkLocation) != HasVkLocation) {
915 Diag(D->getLocation(), diag::err_hlsl_semantic_partial_explicit_indexing);
916 return false;
917 }
918 SC.UsesExplicitVkLocations = HasVkLocation;
919 }
920
921 const ConstantArrayType *AT = dyn_cast<ConstantArrayType>(D->getType());
922 unsigned ElementCount = AT ? AT->getZExtSize() : 1;
923 ActiveSemantic.Index = Location + ElementCount;
924
925 Twine BaseName = Twine(ActiveSemantic.Semantic->getAttrName()->getName());
926 for (unsigned I = 0; I < ElementCount; ++I) {
927 Twine VariableName = BaseName.concat(Twine(Location + I));
928
929 auto [_, Inserted] = SC.ActiveSemantics.insert(VariableName.str());
930 if (!Inserted) {
931 Diag(D->getLocation(), diag::err_hlsl_semantic_index_overlap)
932 << VariableName.str();
933 return false;
934 }
935 }
936
937 return true;
938}
939
940bool SemaHLSL::determineActiveSemantic(FunctionDecl *FD,
941 DeclaratorDecl *OutputDecl,
943 SemanticInfo &ActiveSemantic,
944 SemaHLSL::SemanticContext &SC) {
945 if (ActiveSemantic.Semantic == nullptr) {
946 ActiveSemantic.Semantic = D->getAttr<HLSLParsedSemanticAttr>();
947 if (ActiveSemantic.Semantic)
948 ActiveSemantic.Index = ActiveSemantic.Semantic->getSemanticIndex();
949 }
950
951 const Type *T = D == FD ? &*FD->getReturnType() : &*D->getType();
953
954 const RecordType *RT = dyn_cast<RecordType>(T);
955 if (!RT)
956 return determineActiveSemanticOnScalar(FD, OutputDecl, D, ActiveSemantic,
957 SC);
958
959 const RecordDecl *RD = RT->getDecl();
960 for (FieldDecl *Field : RD->fields()) {
961 SemanticInfo Info = ActiveSemantic;
962 if (!determineActiveSemantic(FD, OutputDecl, Field, Info, SC)) {
963 Diag(Field->getLocation(), diag::note_hlsl_semantic_used_here) << Field;
964 return false;
965 }
966 if (ActiveSemantic.Semantic)
967 ActiveSemantic = Info;
968 }
969
970 return true;
971}
972
974 const auto *ShaderAttr = FD->getAttr<HLSLShaderAttr>();
975 assert(ShaderAttr && "Entry point has no shader attribute");
976 llvm::Triple::EnvironmentType ST = ShaderAttr->getType();
978 VersionTuple Ver = TargetInfo.getTriple().getOSVersion();
979 switch (ST) {
980 case llvm::Triple::Pixel:
981 case llvm::Triple::Vertex:
982 case llvm::Triple::Geometry:
983 case llvm::Triple::Hull:
984 case llvm::Triple::Domain:
985 case llvm::Triple::RayGeneration:
986 case llvm::Triple::Intersection:
987 case llvm::Triple::AnyHit:
988 case llvm::Triple::ClosestHit:
989 case llvm::Triple::Miss:
990 case llvm::Triple::Callable:
991 if (const auto *NT = FD->getAttr<HLSLNumThreadsAttr>()) {
992 diagnoseAttrStageMismatch(NT, ST,
993 {llvm::Triple::Compute,
994 llvm::Triple::Amplification,
995 llvm::Triple::Mesh});
996 FD->setInvalidDecl();
997 }
998 if (const auto *WS = FD->getAttr<HLSLWaveSizeAttr>()) {
999 diagnoseAttrStageMismatch(WS, ST,
1000 {llvm::Triple::Compute,
1001 llvm::Triple::Amplification,
1002 llvm::Triple::Mesh});
1003 FD->setInvalidDecl();
1004 }
1005 break;
1006
1007 case llvm::Triple::Compute:
1008 case llvm::Triple::Amplification:
1009 case llvm::Triple::Mesh:
1010 if (!FD->hasAttr<HLSLNumThreadsAttr>()) {
1011 Diag(FD->getLocation(), diag::err_hlsl_missing_numthreads)
1012 << llvm::Triple::getEnvironmentTypeName(ST);
1013 FD->setInvalidDecl();
1014 }
1015 if (const auto *WS = FD->getAttr<HLSLWaveSizeAttr>()) {
1016 if (Ver < VersionTuple(6, 6)) {
1017 Diag(WS->getLocation(), diag::err_hlsl_attribute_in_wrong_shader_model)
1018 << WS << "6.6";
1019 FD->setInvalidDecl();
1020 } else if (WS->getSpelledArgsCount() > 1 && Ver < VersionTuple(6, 8)) {
1021 Diag(
1022 WS->getLocation(),
1023 diag::err_hlsl_attribute_number_arguments_insufficient_shader_model)
1024 << WS << WS->getSpelledArgsCount() << "6.8";
1025 FD->setInvalidDecl();
1026 }
1027 }
1028 break;
1029 case llvm::Triple::RootSignature:
1030 llvm_unreachable("rootsig environment has no function entry point");
1031 default:
1032 llvm_unreachable("Unhandled environment in triple");
1033 }
1034
1035 SemaHLSL::SemanticContext InputSC = {};
1036 InputSC.CurrentIOType = IOType::In;
1037
1038 for (ParmVarDecl *Param : FD->parameters()) {
1039 SemanticInfo ActiveSemantic;
1040 ActiveSemantic.Semantic = Param->getAttr<HLSLParsedSemanticAttr>();
1041 if (ActiveSemantic.Semantic)
1042 ActiveSemantic.Index = ActiveSemantic.Semantic->getSemanticIndex();
1043
1044 // FIXME: Verify output semantics in parameters.
1045 if (!determineActiveSemantic(FD, Param, Param, ActiveSemantic, InputSC)) {
1046 Diag(Param->getLocation(), diag::note_previous_decl) << Param;
1047 FD->setInvalidDecl();
1048 }
1049 }
1050
1051 SemanticInfo ActiveSemantic;
1052 SemaHLSL::SemanticContext OutputSC = {};
1053 OutputSC.CurrentIOType = IOType::Out;
1054 ActiveSemantic.Semantic = FD->getAttr<HLSLParsedSemanticAttr>();
1055 if (ActiveSemantic.Semantic)
1056 ActiveSemantic.Index = ActiveSemantic.Semantic->getSemanticIndex();
1057 if (!FD->getReturnType()->isVoidType())
1058 determineActiveSemantic(FD, FD, FD, ActiveSemantic, OutputSC);
1059}
1060
1061void SemaHLSL::checkSemanticAnnotation(
1062 FunctionDecl *EntryPoint, const Decl *Param,
1063 const HLSLAppliedSemanticAttr *SemanticAttr, const SemanticContext &SC) {
1064 auto *ShaderAttr = EntryPoint->getAttr<HLSLShaderAttr>();
1065 assert(ShaderAttr && "Entry point has no shader attribute");
1066 llvm::Triple::EnvironmentType ST = ShaderAttr->getType();
1067
1068 auto SemanticName = SemanticAttr->getSemanticName().upper();
1069 if (SemanticName == "SV_DISPATCHTHREADID" ||
1070 SemanticName == "SV_GROUPINDEX" || SemanticName == "SV_GROUPTHREADID" ||
1071 SemanticName == "SV_GROUPID") {
1072
1073 if (ST != llvm::Triple::Compute)
1074 diagnoseSemanticStageMismatch(SemanticAttr, ST, SC.CurrentIOType,
1075 {{llvm::Triple::Compute, IOType::In}});
1076
1077 if (SemanticAttr->getSemanticIndex() != 0) {
1078 std::string PrettyName =
1079 "'" + SemanticAttr->getSemanticName().str() + "'";
1080 Diag(SemanticAttr->getLoc(),
1081 diag::err_hlsl_semantic_indexing_not_supported)
1082 << PrettyName;
1083 }
1084 return;
1085 }
1086
1087 if (SemanticName == "SV_POSITION") {
1088 // SV_Position can be an input or output in vertex shaders,
1089 // but only an input in pixel shaders.
1090 diagnoseSemanticStageMismatch(SemanticAttr, ST, SC.CurrentIOType,
1091 {{llvm::Triple::Vertex, IOType::InOut},
1092 {llvm::Triple::Pixel, IOType::In}});
1093 return;
1094 }
1095 if (SemanticName == "SV_VERTEXID") {
1096 diagnoseSemanticStageMismatch(SemanticAttr, ST, SC.CurrentIOType,
1097 {{llvm::Triple::Vertex, IOType::In}});
1098 return;
1099 }
1100
1101 if (SemanticName == "SV_TARGET") {
1102 diagnoseSemanticStageMismatch(SemanticAttr, ST, SC.CurrentIOType,
1103 {{llvm::Triple::Pixel, IOType::Out}});
1104 return;
1105 }
1106
1107 // FIXME: catch-all for non-implemented system semantics reaching this
1108 // location.
1109 if (SemanticAttr->getAttrName()->getName().starts_with_insensitive("SV_"))
1110 llvm_unreachable("Unknown SemanticAttr");
1111}
1112
1113void SemaHLSL::diagnoseAttrStageMismatch(
1114 const Attr *A, llvm::Triple::EnvironmentType Stage,
1115 std::initializer_list<llvm::Triple::EnvironmentType> AllowedStages) {
1116 SmallVector<StringRef, 8> StageStrings;
1117 llvm::transform(AllowedStages, std::back_inserter(StageStrings),
1118 [](llvm::Triple::EnvironmentType ST) {
1119 return StringRef(
1120 HLSLShaderAttr::ConvertEnvironmentTypeToStr(ST));
1121 });
1122 Diag(A->getLoc(), diag::err_hlsl_attr_unsupported_in_stage)
1123 << A->getAttrName() << llvm::Triple::getEnvironmentTypeName(Stage)
1124 << (AllowedStages.size() != 1) << join(StageStrings, ", ");
1125}
1126
1127void SemaHLSL::diagnoseSemanticStageMismatch(
1128 const Attr *A, llvm::Triple::EnvironmentType Stage, IOType CurrentIOType,
1129 std::initializer_list<SemanticStageInfo> Allowed) {
1130
1131 for (auto &Case : Allowed) {
1132 if (Case.Stage != Stage)
1133 continue;
1134
1135 if (CurrentIOType & Case.AllowedIOTypesMask)
1136 return;
1137
1138 SmallVector<std::string, 8> ValidCases;
1139 llvm::transform(
1140 Allowed, std::back_inserter(ValidCases), [](SemanticStageInfo Case) {
1141 SmallVector<std::string, 2> ValidType;
1142 if (Case.AllowedIOTypesMask & IOType::In)
1143 ValidType.push_back("input");
1144 if (Case.AllowedIOTypesMask & IOType::Out)
1145 ValidType.push_back("output");
1146 return std::string(
1147 HLSLShaderAttr::ConvertEnvironmentTypeToStr(Case.Stage)) +
1148 " " + join(ValidType, "/");
1149 });
1150 Diag(A->getLoc(), diag::err_hlsl_semantic_unsupported_iotype_for_stage)
1151 << A->getAttrName() << (CurrentIOType & IOType::In ? "input" : "output")
1152 << llvm::Triple::getEnvironmentTypeName(Case.Stage)
1153 << join(ValidCases, ", ");
1154 return;
1155 }
1156
1157 SmallVector<StringRef, 8> StageStrings;
1158 llvm::transform(
1159 Allowed, std::back_inserter(StageStrings), [](SemanticStageInfo Case) {
1160 return StringRef(
1161 HLSLShaderAttr::ConvertEnvironmentTypeToStr(Case.Stage));
1162 });
1163
1164 Diag(A->getLoc(), diag::err_hlsl_attr_unsupported_in_stage)
1165 << A->getAttrName() << llvm::Triple::getEnvironmentTypeName(Stage)
1166 << (Allowed.size() != 1) << join(StageStrings, ", ");
1167}
1168
1169template <CastKind Kind>
1170static void castVector(Sema &S, ExprResult &E, QualType &Ty, unsigned Sz) {
1171 if (const auto *VTy = Ty->getAs<VectorType>())
1172 Ty = VTy->getElementType();
1173 Ty = S.getASTContext().getExtVectorType(Ty, Sz);
1174 E = S.ImpCastExprToType(E.get(), Ty, Kind);
1175}
1176
1177template <CastKind Kind>
1179 E = S.ImpCastExprToType(E.get(), Ty, Kind);
1180 return Ty;
1181}
1182
1184 Sema &SemaRef, ExprResult &LHS, ExprResult &RHS, QualType LHSType,
1185 QualType RHSType, QualType LElTy, QualType RElTy, bool IsCompAssign) {
1186 bool LHSFloat = LElTy->isRealFloatingType();
1187 bool RHSFloat = RElTy->isRealFloatingType();
1188
1189 if (LHSFloat && RHSFloat) {
1190 if (IsCompAssign ||
1191 SemaRef.getASTContext().getFloatingTypeOrder(LElTy, RElTy) > 0)
1192 return castElement<CK_FloatingCast>(SemaRef, RHS, LHSType);
1193
1194 return castElement<CK_FloatingCast>(SemaRef, LHS, RHSType);
1195 }
1196
1197 if (LHSFloat)
1198 return castElement<CK_IntegralToFloating>(SemaRef, RHS, LHSType);
1199
1200 assert(RHSFloat);
1201 if (IsCompAssign)
1202 return castElement<clang::CK_FloatingToIntegral>(SemaRef, RHS, LHSType);
1203
1204 return castElement<CK_IntegralToFloating>(SemaRef, LHS, RHSType);
1205}
1206
1208 Sema &SemaRef, ExprResult &LHS, ExprResult &RHS, QualType LHSType,
1209 QualType RHSType, QualType LElTy, QualType RElTy, bool IsCompAssign) {
1210
1211 int IntOrder = SemaRef.Context.getIntegerTypeOrder(LElTy, RElTy);
1212 bool LHSSigned = LElTy->hasSignedIntegerRepresentation();
1213 bool RHSSigned = RElTy->hasSignedIntegerRepresentation();
1214 auto &Ctx = SemaRef.getASTContext();
1215
1216 // If both types have the same signedness, use the higher ranked type.
1217 if (LHSSigned == RHSSigned) {
1218 if (IsCompAssign || IntOrder >= 0)
1219 return castElement<CK_IntegralCast>(SemaRef, RHS, LHSType);
1220
1221 return castElement<CK_IntegralCast>(SemaRef, LHS, RHSType);
1222 }
1223
1224 // If the unsigned type has greater than or equal rank of the signed type, use
1225 // the unsigned type.
1226 if (IntOrder != (LHSSigned ? 1 : -1)) {
1227 if (IsCompAssign || RHSSigned)
1228 return castElement<CK_IntegralCast>(SemaRef, RHS, LHSType);
1229 return castElement<CK_IntegralCast>(SemaRef, LHS, RHSType);
1230 }
1231
1232 // At this point the signed type has higher rank than the unsigned type, which
1233 // means it will be the same size or bigger. If the signed type is bigger, it
1234 // can represent all the values of the unsigned type, so select it.
1235 if (Ctx.getIntWidth(LElTy) != Ctx.getIntWidth(RElTy)) {
1236 if (IsCompAssign || LHSSigned)
1237 return castElement<CK_IntegralCast>(SemaRef, RHS, LHSType);
1238 return castElement<CK_IntegralCast>(SemaRef, LHS, RHSType);
1239 }
1240
1241 // This is a bit of an odd duck case in HLSL. It shouldn't happen, but can due
1242 // to C/C++ leaking through. The place this happens today is long vs long
1243 // long. When arguments are vector<unsigned long, N> and vector<long long, N>,
1244 // the long long has higher rank than long even though they are the same size.
1245
1246 // If this is a compound assignment cast the right hand side to the left hand
1247 // side's type.
1248 if (IsCompAssign)
1249 return castElement<CK_IntegralCast>(SemaRef, RHS, LHSType);
1250
1251 // If this isn't a compound assignment we convert to unsigned long long.
1252 QualType ElTy = Ctx.getCorrespondingUnsignedType(LHSSigned ? LElTy : RElTy);
1253 QualType NewTy = Ctx.getExtVectorType(
1254 ElTy, RHSType->castAs<VectorType>()->getNumElements());
1255 (void)castElement<CK_IntegralCast>(SemaRef, RHS, NewTy);
1256
1257 return castElement<CK_IntegralCast>(SemaRef, LHS, NewTy);
1258}
1259
1261 QualType SrcTy) {
1262 if (DestTy->isRealFloatingType() && SrcTy->isRealFloatingType())
1263 return CK_FloatingCast;
1264 if (DestTy->isIntegralType(Ctx) && SrcTy->isIntegralType(Ctx))
1265 return CK_IntegralCast;
1266 if (DestTy->isRealFloatingType())
1267 return CK_IntegralToFloating;
1268 assert(SrcTy->isRealFloatingType() && DestTy->isIntegralType(Ctx));
1269 return CK_FloatingToIntegral;
1270}
1271
1273 QualType LHSType,
1274 QualType RHSType,
1275 bool IsCompAssign) {
1276 const auto *LVecTy = LHSType->getAs<VectorType>();
1277 const auto *RVecTy = RHSType->getAs<VectorType>();
1278 auto &Ctx = getASTContext();
1279
1280 // If the LHS is not a vector and this is a compound assignment, we truncate
1281 // the argument to a scalar then convert it to the LHS's type.
1282 if (!LVecTy && IsCompAssign) {
1283 QualType RElTy = RHSType->castAs<VectorType>()->getElementType();
1284 RHS = SemaRef.ImpCastExprToType(RHS.get(), RElTy, CK_HLSLVectorTruncation);
1285 RHSType = RHS.get()->getType();
1286 if (Ctx.hasSameUnqualifiedType(LHSType, RHSType))
1287 return LHSType;
1288 RHS = SemaRef.ImpCastExprToType(RHS.get(), LHSType,
1289 getScalarCastKind(Ctx, LHSType, RHSType));
1290 return LHSType;
1291 }
1292
1293 unsigned EndSz = std::numeric_limits<unsigned>::max();
1294 unsigned LSz = 0;
1295 if (LVecTy)
1296 LSz = EndSz = LVecTy->getNumElements();
1297 if (RVecTy)
1298 EndSz = std::min(RVecTy->getNumElements(), EndSz);
1299 assert(EndSz != std::numeric_limits<unsigned>::max() &&
1300 "one of the above should have had a value");
1301
1302 // In a compound assignment, the left operand does not change type, the right
1303 // operand is converted to the type of the left operand.
1304 if (IsCompAssign && LSz != EndSz) {
1305 Diag(LHS.get()->getBeginLoc(),
1306 diag::err_hlsl_vector_compound_assignment_truncation)
1307 << LHSType << RHSType;
1308 return QualType();
1309 }
1310
1311 if (RVecTy && RVecTy->getNumElements() > EndSz)
1312 castVector<CK_HLSLVectorTruncation>(SemaRef, RHS, RHSType, EndSz);
1313 if (!IsCompAssign && LVecTy && LVecTy->getNumElements() > EndSz)
1314 castVector<CK_HLSLVectorTruncation>(SemaRef, LHS, LHSType, EndSz);
1315
1316 if (!RVecTy)
1317 castVector<CK_VectorSplat>(SemaRef, RHS, RHSType, EndSz);
1318 if (!IsCompAssign && !LVecTy)
1319 castVector<CK_VectorSplat>(SemaRef, LHS, LHSType, EndSz);
1320
1321 // If we're at the same type after resizing we can stop here.
1322 if (Ctx.hasSameUnqualifiedType(LHSType, RHSType))
1323 return Ctx.getCommonSugaredType(LHSType, RHSType);
1324
1325 QualType LElTy = LHSType->castAs<VectorType>()->getElementType();
1326 QualType RElTy = RHSType->castAs<VectorType>()->getElementType();
1327
1328 // Handle conversion for floating point vectors.
1329 if (LElTy->isRealFloatingType() || RElTy->isRealFloatingType())
1330 return handleFloatVectorBinOpConversion(SemaRef, LHS, RHS, LHSType, RHSType,
1331 LElTy, RElTy, IsCompAssign);
1332
1333 assert(LElTy->isIntegralType(Ctx) && RElTy->isIntegralType(Ctx) &&
1334 "HLSL Vectors can only contain integer or floating point types");
1335 return handleIntegerVectorBinOpConversion(SemaRef, LHS, RHS, LHSType, RHSType,
1336 LElTy, RElTy, IsCompAssign);
1337}
1338
1340 BinaryOperatorKind Opc) {
1341 assert((Opc == BO_LOr || Opc == BO_LAnd) &&
1342 "Called with non-logical operator");
1344 llvm::raw_svector_ostream OS(Buff);
1345 PrintingPolicy PP(SemaRef.getLangOpts());
1346 StringRef NewFnName = Opc == BO_LOr ? "or" : "and";
1347 OS << NewFnName << "(";
1348 LHS->printPretty(OS, nullptr, PP);
1349 OS << ", ";
1350 RHS->printPretty(OS, nullptr, PP);
1351 OS << ")";
1352 SourceRange FullRange = SourceRange(LHS->getBeginLoc(), RHS->getEndLoc());
1353 SemaRef.Diag(LHS->getBeginLoc(), diag::note_function_suggestion)
1354 << NewFnName << FixItHint::CreateReplacement(FullRange, OS.str());
1355}
1356
1357std::pair<IdentifierInfo *, bool>
1359 llvm::hash_code Hash = llvm::hash_value(Signature);
1360 std::string IdStr = "__hlsl_rootsig_decl_" + std::to_string(Hash);
1361 IdentifierInfo *DeclIdent = &(getASTContext().Idents.get(IdStr));
1362
1363 // Check if we have already found a decl of the same name.
1364 LookupResult R(SemaRef, DeclIdent, SourceLocation(),
1366 bool Found = SemaRef.LookupQualifiedName(R, SemaRef.CurContext);
1367 return {DeclIdent, Found};
1368}
1369
1371 SourceLocation Loc, IdentifierInfo *DeclIdent,
1373
1374 if (handleRootSignatureElements(RootElements))
1375 return;
1376
1378 for (auto &RootSigElement : RootElements)
1379 Elements.push_back(RootSigElement.getElement());
1380
1381 auto *SignatureDecl = HLSLRootSignatureDecl::Create(
1382 SemaRef.getASTContext(), /*DeclContext=*/SemaRef.CurContext, Loc,
1383 DeclIdent, SemaRef.getLangOpts().HLSLRootSigVer, Elements);
1384
1385 SignatureDecl->setImplicit();
1386 SemaRef.PushOnScopeChains(SignatureDecl, SemaRef.getCurScope());
1387}
1388
1391 if (RootSigOverrideIdent) {
1392 LookupResult R(SemaRef, RootSigOverrideIdent, SourceLocation(),
1394 if (SemaRef.LookupQualifiedName(R, DC))
1395 return dyn_cast<HLSLRootSignatureDecl>(R.getFoundDecl());
1396 }
1397
1398 return nullptr;
1399}
1400
1401namespace {
1402
1403struct PerVisibilityBindingChecker {
1404 SemaHLSL *S;
1405 // We need one builder per `llvm::dxbc::ShaderVisibility` value.
1406 std::array<llvm::hlsl::BindingInfoBuilder, 8> Builders;
1407
1408 struct ElemInfo {
1409 const hlsl::RootSignatureElement *Elem;
1410 llvm::dxbc::ShaderVisibility Vis;
1411 bool Diagnosed;
1412 };
1413 llvm::SmallVector<ElemInfo> ElemInfoMap;
1414
1415 PerVisibilityBindingChecker(SemaHLSL *S) : S(S) {}
1416
1417 void trackBinding(llvm::dxbc::ShaderVisibility Visibility,
1418 llvm::dxil::ResourceClass RC, uint32_t Space,
1419 uint32_t LowerBound, uint32_t UpperBound,
1420 const hlsl::RootSignatureElement *Elem) {
1421 uint32_t BuilderIndex = llvm::to_underlying(Visibility);
1422 assert(BuilderIndex < Builders.size() &&
1423 "Not enough builders for visibility type");
1424 Builders[BuilderIndex].trackBinding(RC, Space, LowerBound, UpperBound,
1425 static_cast<const void *>(Elem));
1426
1427 static_assert(llvm::to_underlying(llvm::dxbc::ShaderVisibility::All) == 0,
1428 "'All' visibility must come first");
1429 if (Visibility == llvm::dxbc::ShaderVisibility::All)
1430 for (size_t I = 1, E = Builders.size(); I < E; ++I)
1431 Builders[I].trackBinding(RC, Space, LowerBound, UpperBound,
1432 static_cast<const void *>(Elem));
1433
1434 ElemInfoMap.push_back({Elem, Visibility, false});
1435 }
1436
1437 ElemInfo &getInfo(const hlsl::RootSignatureElement *Elem) {
1438 auto It = llvm::lower_bound(
1439 ElemInfoMap, Elem,
1440 [](const auto &LHS, const auto &RHS) { return LHS.Elem < RHS; });
1441 assert(It->Elem == Elem && "Element not in map");
1442 return *It;
1443 }
1444
1445 bool checkOverlap() {
1446 llvm::sort(ElemInfoMap, [](const auto &LHS, const auto &RHS) {
1447 return LHS.Elem < RHS.Elem;
1448 });
1449
1450 bool HadOverlap = false;
1451
1452 using llvm::hlsl::BindingInfoBuilder;
1453 auto ReportOverlap = [this,
1454 &HadOverlap](const BindingInfoBuilder &Builder,
1455 const llvm::hlsl::Binding &Reported) {
1456 HadOverlap = true;
1457
1458 const auto *Elem =
1459 static_cast<const hlsl::RootSignatureElement *>(Reported.Cookie);
1460 const llvm::hlsl::Binding &Previous = Builder.findOverlapping(Reported);
1461 const auto *PrevElem =
1462 static_cast<const hlsl::RootSignatureElement *>(Previous.Cookie);
1463
1464 ElemInfo &Info = getInfo(Elem);
1465 // We will have already diagnosed this binding if there's overlap in the
1466 // "All" visibility as well as any particular visibility.
1467 if (Info.Diagnosed)
1468 return;
1469 Info.Diagnosed = true;
1470
1471 ElemInfo &PrevInfo = getInfo(PrevElem);
1472 llvm::dxbc::ShaderVisibility CommonVis =
1473 Info.Vis == llvm::dxbc::ShaderVisibility::All ? PrevInfo.Vis
1474 : Info.Vis;
1475
1476 this->S->Diag(Elem->getLocation(), diag::err_hlsl_resource_range_overlap)
1477 << llvm::to_underlying(Reported.RC) << Reported.LowerBound
1478 << Reported.isUnbounded() << Reported.UpperBound
1479 << llvm::to_underlying(Previous.RC) << Previous.LowerBound
1480 << Previous.isUnbounded() << Previous.UpperBound << Reported.Space
1481 << CommonVis;
1482
1483 this->S->Diag(PrevElem->getLocation(),
1484 diag::note_hlsl_resource_range_here);
1485 };
1486
1487 for (BindingInfoBuilder &Builder : Builders)
1488 Builder.calculateBindingInfo(ReportOverlap);
1489
1490 return HadOverlap;
1491 }
1492};
1493
1494static CXXMethodDecl *lookupMethod(Sema &S, CXXRecordDecl *RecordDecl,
1495 StringRef Name, SourceLocation Loc) {
1496 DeclarationName DeclName(&S.getASTContext().Idents.get(Name));
1497 LookupResult Result(S, DeclName, Loc, Sema::LookupMemberName);
1498 if (!S.LookupQualifiedName(Result, static_cast<DeclContext *>(RecordDecl)))
1499 return nullptr;
1500 return cast<CXXMethodDecl>(Result.getFoundDecl());
1501}
1502
1503} // end anonymous namespace
1504
1507 // Define some common error handling functions
1508 bool HadError = false;
1509 auto ReportError = [this, &HadError](SourceLocation Loc, uint32_t LowerBound,
1510 uint32_t UpperBound) {
1511 HadError = true;
1512 this->Diag(Loc, diag::err_hlsl_invalid_rootsig_value)
1513 << LowerBound << UpperBound;
1514 };
1515
1516 auto ReportFloatError = [this, &HadError](SourceLocation Loc,
1517 float LowerBound,
1518 float UpperBound) {
1519 HadError = true;
1520 this->Diag(Loc, diag::err_hlsl_invalid_rootsig_value)
1521 << llvm::formatv("{0:f}", LowerBound).sstr<6>()
1522 << llvm::formatv("{0:f}", UpperBound).sstr<6>();
1523 };
1524
1525 auto VerifyRegister = [ReportError](SourceLocation Loc, uint32_t Register) {
1526 if (!llvm::hlsl::rootsig::verifyRegisterValue(Register))
1527 ReportError(Loc, 0, 0xfffffffe);
1528 };
1529
1530 auto VerifySpace = [ReportError](SourceLocation Loc, uint32_t Space) {
1531 if (!llvm::hlsl::rootsig::verifyRegisterSpace(Space))
1532 ReportError(Loc, 0, 0xffffffef);
1533 };
1534
1535 const uint32_t Version =
1536 llvm::to_underlying(SemaRef.getLangOpts().HLSLRootSigVer);
1537 const uint32_t VersionEnum = Version - 1;
1538 auto ReportFlagError = [this, &HadError, VersionEnum](SourceLocation Loc) {
1539 HadError = true;
1540 this->Diag(Loc, diag::err_hlsl_invalid_rootsig_flag)
1541 << /*version minor*/ VersionEnum;
1542 };
1543
1544 // Iterate through the elements and do basic validations
1545 for (const hlsl::RootSignatureElement &RootSigElem : Elements) {
1546 SourceLocation Loc = RootSigElem.getLocation();
1547 const llvm::hlsl::rootsig::RootElement &Elem = RootSigElem.getElement();
1548 if (const auto *Descriptor =
1549 std::get_if<llvm::hlsl::rootsig::RootDescriptor>(&Elem)) {
1550 VerifyRegister(Loc, Descriptor->Reg.Number);
1551 VerifySpace(Loc, Descriptor->Space);
1552
1553 if (!llvm::hlsl::rootsig::verifyRootDescriptorFlag(Version,
1554 Descriptor->Flags))
1555 ReportFlagError(Loc);
1556 } else if (const auto *Constants =
1557 std::get_if<llvm::hlsl::rootsig::RootConstants>(&Elem)) {
1558 VerifyRegister(Loc, Constants->Reg.Number);
1559 VerifySpace(Loc, Constants->Space);
1560 } else if (const auto *Sampler =
1561 std::get_if<llvm::hlsl::rootsig::StaticSampler>(&Elem)) {
1562 VerifyRegister(Loc, Sampler->Reg.Number);
1563 VerifySpace(Loc, Sampler->Space);
1564
1565 assert(!std::isnan(Sampler->MaxLOD) && !std::isnan(Sampler->MinLOD) &&
1566 "By construction, parseFloatParam can't produce a NaN from a "
1567 "float_literal token");
1568
1569 if (!llvm::hlsl::rootsig::verifyMaxAnisotropy(Sampler->MaxAnisotropy))
1570 ReportError(Loc, 0, 16);
1571 if (!llvm::hlsl::rootsig::verifyMipLODBias(Sampler->MipLODBias))
1572 ReportFloatError(Loc, -16.f, 15.99f);
1573 } else if (const auto *Clause =
1574 std::get_if<llvm::hlsl::rootsig::DescriptorTableClause>(
1575 &Elem)) {
1576 VerifyRegister(Loc, Clause->Reg.Number);
1577 VerifySpace(Loc, Clause->Space);
1578
1579 if (!llvm::hlsl::rootsig::verifyNumDescriptors(Clause->NumDescriptors)) {
1580 // NumDescriptor could techincally be ~0u but that is reserved for
1581 // unbounded, so the diagnostic will not report that as a valid int
1582 // value
1583 ReportError(Loc, 1, 0xfffffffe);
1584 }
1585
1586 if (!llvm::hlsl::rootsig::verifyDescriptorRangeFlag(Version, Clause->Type,
1587 Clause->Flags))
1588 ReportFlagError(Loc);
1589 }
1590 }
1591
1592 PerVisibilityBindingChecker BindingChecker(this);
1593 SmallVector<std::pair<const llvm::hlsl::rootsig::DescriptorTableClause *,
1595 UnboundClauses;
1596
1597 for (const hlsl::RootSignatureElement &RootSigElem : Elements) {
1598 const llvm::hlsl::rootsig::RootElement &Elem = RootSigElem.getElement();
1599 if (const auto *Descriptor =
1600 std::get_if<llvm::hlsl::rootsig::RootDescriptor>(&Elem)) {
1601 uint32_t LowerBound(Descriptor->Reg.Number);
1602 uint32_t UpperBound(LowerBound); // inclusive range
1603
1604 BindingChecker.trackBinding(
1605 Descriptor->Visibility,
1606 static_cast<llvm::dxil::ResourceClass>(Descriptor->Type),
1607 Descriptor->Space, LowerBound, UpperBound, &RootSigElem);
1608 } else if (const auto *Constants =
1609 std::get_if<llvm::hlsl::rootsig::RootConstants>(&Elem)) {
1610 uint32_t LowerBound(Constants->Reg.Number);
1611 uint32_t UpperBound(LowerBound); // inclusive range
1612
1613 BindingChecker.trackBinding(
1614 Constants->Visibility, llvm::dxil::ResourceClass::CBuffer,
1615 Constants->Space, LowerBound, UpperBound, &RootSigElem);
1616 } else if (const auto *Sampler =
1617 std::get_if<llvm::hlsl::rootsig::StaticSampler>(&Elem)) {
1618 uint32_t LowerBound(Sampler->Reg.Number);
1619 uint32_t UpperBound(LowerBound); // inclusive range
1620
1621 BindingChecker.trackBinding(
1622 Sampler->Visibility, llvm::dxil::ResourceClass::Sampler,
1623 Sampler->Space, LowerBound, UpperBound, &RootSigElem);
1624 } else if (const auto *Clause =
1625 std::get_if<llvm::hlsl::rootsig::DescriptorTableClause>(
1626 &Elem)) {
1627 // We'll process these once we see the table element.
1628 UnboundClauses.emplace_back(Clause, &RootSigElem);
1629 } else if (const auto *Table =
1630 std::get_if<llvm::hlsl::rootsig::DescriptorTable>(&Elem)) {
1631 assert(UnboundClauses.size() == Table->NumClauses &&
1632 "Number of unbound elements must match the number of clauses");
1633 bool HasAnySampler = false;
1634 bool HasAnyNonSampler = false;
1635 uint64_t Offset = 0;
1636 bool IsPrevUnbound = false;
1637 for (const auto &[Clause, ClauseElem] : UnboundClauses) {
1638 SourceLocation Loc = ClauseElem->getLocation();
1639 if (Clause->Type == llvm::dxil::ResourceClass::Sampler)
1640 HasAnySampler = true;
1641 else
1642 HasAnyNonSampler = true;
1643
1644 if (HasAnySampler && HasAnyNonSampler)
1645 Diag(Loc, diag::err_hlsl_invalid_mixed_resources);
1646
1647 // Relevant error will have already been reported above and needs to be
1648 // fixed before we can conduct further analysis, so shortcut error
1649 // return
1650 if (Clause->NumDescriptors == 0)
1651 return true;
1652
1653 bool IsAppending =
1654 Clause->Offset == llvm::hlsl::rootsig::DescriptorTableOffsetAppend;
1655 if (!IsAppending)
1656 Offset = Clause->Offset;
1657
1658 uint64_t RangeBound = llvm::hlsl::rootsig::computeRangeBound(
1659 Offset, Clause->NumDescriptors);
1660
1661 if (IsPrevUnbound && IsAppending)
1662 Diag(Loc, diag::err_hlsl_appending_onto_unbound);
1663 else if (!llvm::hlsl::rootsig::verifyNoOverflowedOffset(RangeBound))
1664 Diag(Loc, diag::err_hlsl_offset_overflow) << Offset << RangeBound;
1665
1666 // Update offset to be 1 past this range's bound
1667 Offset = RangeBound + 1;
1668 IsPrevUnbound = Clause->NumDescriptors ==
1669 llvm::hlsl::rootsig::NumDescriptorsUnbounded;
1670
1671 // Compute the register bounds and track resource binding
1672 uint32_t LowerBound(Clause->Reg.Number);
1673 uint32_t UpperBound = llvm::hlsl::rootsig::computeRangeBound(
1674 LowerBound, Clause->NumDescriptors);
1675
1676 BindingChecker.trackBinding(
1677 Table->Visibility,
1678 static_cast<llvm::dxil::ResourceClass>(Clause->Type), Clause->Space,
1679 LowerBound, UpperBound, ClauseElem);
1680 }
1681 UnboundClauses.clear();
1682 }
1683 }
1684
1685 return BindingChecker.checkOverlap();
1686}
1687
1689 if (AL.getNumArgs() != 1) {
1690 Diag(AL.getLoc(), diag::err_attribute_wrong_number_arguments) << AL << 1;
1691 return;
1692 }
1693
1695 if (auto *RS = D->getAttr<RootSignatureAttr>()) {
1696 if (RS->getSignatureIdent() != Ident) {
1697 Diag(AL.getLoc(), diag::err_disallowed_duplicate_attribute) << RS;
1698 return;
1699 }
1700
1701 Diag(AL.getLoc(), diag::warn_duplicate_attribute_exact) << RS;
1702 return;
1703 }
1704
1706 if (SemaRef.LookupQualifiedName(R, D->getDeclContext()))
1707 if (auto *SignatureDecl =
1708 dyn_cast<HLSLRootSignatureDecl>(R.getFoundDecl())) {
1709 D->addAttr(::new (getASTContext()) RootSignatureAttr(
1710 getASTContext(), AL, Ident, SignatureDecl));
1711 }
1712}
1713
1715 llvm::VersionTuple SMVersion =
1716 getASTContext().getTargetInfo().getTriple().getOSVersion();
1717 bool IsDXIL = getASTContext().getTargetInfo().getTriple().getArch() ==
1718 llvm::Triple::dxil;
1719
1720 uint32_t ZMax = 1024;
1721 uint32_t ThreadMax = 1024;
1722 if (IsDXIL && SMVersion.getMajor() <= 4) {
1723 ZMax = 1;
1724 ThreadMax = 768;
1725 } else if (IsDXIL && SMVersion.getMajor() == 5) {
1726 ZMax = 64;
1727 ThreadMax = 1024;
1728 }
1729
1730 uint32_t X;
1731 if (!SemaRef.checkUInt32Argument(AL, AL.getArgAsExpr(0), X))
1732 return;
1733 if (X > 1024) {
1734 Diag(AL.getArgAsExpr(0)->getExprLoc(),
1735 diag::err_hlsl_numthreads_argument_oor)
1736 << 0 << 1024;
1737 return;
1738 }
1739 uint32_t Y;
1740 if (!SemaRef.checkUInt32Argument(AL, AL.getArgAsExpr(1), Y))
1741 return;
1742 if (Y > 1024) {
1743 Diag(AL.getArgAsExpr(1)->getExprLoc(),
1744 diag::err_hlsl_numthreads_argument_oor)
1745 << 1 << 1024;
1746 return;
1747 }
1748 uint32_t Z;
1749 if (!SemaRef.checkUInt32Argument(AL, AL.getArgAsExpr(2), Z))
1750 return;
1751 if (Z > ZMax) {
1752 SemaRef.Diag(AL.getArgAsExpr(2)->getExprLoc(),
1753 diag::err_hlsl_numthreads_argument_oor)
1754 << 2 << ZMax;
1755 return;
1756 }
1757
1758 if (X * Y * Z > ThreadMax) {
1759 Diag(AL.getLoc(), diag::err_hlsl_numthreads_invalid) << ThreadMax;
1760 return;
1761 }
1762
1763 HLSLNumThreadsAttr *NewAttr = mergeNumThreadsAttr(D, AL, X, Y, Z);
1764 if (NewAttr)
1765 D->addAttr(NewAttr);
1766}
1767
1768static bool isValidWaveSizeValue(unsigned Value) {
1769 return llvm::isPowerOf2_32(Value) && Value >= 4 && Value <= 128;
1770}
1771
1773 // validate that the wavesize argument is a power of 2 between 4 and 128
1774 // inclusive
1775 unsigned SpelledArgsCount = AL.getNumArgs();
1776 if (SpelledArgsCount == 0 || SpelledArgsCount > 3)
1777 return;
1778
1779 uint32_t Min;
1780 if (!SemaRef.checkUInt32Argument(AL, AL.getArgAsExpr(0), Min))
1781 return;
1782
1783 uint32_t Max = 0;
1784 if (SpelledArgsCount > 1 &&
1785 !SemaRef.checkUInt32Argument(AL, AL.getArgAsExpr(1), Max))
1786 return;
1787
1788 uint32_t Preferred = 0;
1789 if (SpelledArgsCount > 2 &&
1790 !SemaRef.checkUInt32Argument(AL, AL.getArgAsExpr(2), Preferred))
1791 return;
1792
1793 if (SpelledArgsCount > 2) {
1794 if (!isValidWaveSizeValue(Preferred)) {
1795 Diag(AL.getArgAsExpr(2)->getExprLoc(),
1796 diag::err_attribute_power_of_two_in_range)
1797 << AL << llvm::dxil::MinWaveSize << llvm::dxil::MaxWaveSize
1798 << Preferred;
1799 return;
1800 }
1801 // Preferred not in range.
1802 if (Preferred < Min || Preferred > Max) {
1803 Diag(AL.getArgAsExpr(2)->getExprLoc(),
1804 diag::err_attribute_power_of_two_in_range)
1805 << AL << Min << Max << Preferred;
1806 return;
1807 }
1808 } else if (SpelledArgsCount > 1) {
1809 if (!isValidWaveSizeValue(Max)) {
1810 Diag(AL.getArgAsExpr(1)->getExprLoc(),
1811 diag::err_attribute_power_of_two_in_range)
1812 << AL << llvm::dxil::MinWaveSize << llvm::dxil::MaxWaveSize << Max;
1813 return;
1814 }
1815 if (Max < Min) {
1816 Diag(AL.getLoc(), diag::err_attribute_argument_invalid) << AL << 1;
1817 return;
1818 } else if (Max == Min) {
1819 Diag(AL.getLoc(), diag::warn_attr_min_eq_max) << AL;
1820 }
1821 } else {
1822 if (!isValidWaveSizeValue(Min)) {
1823 Diag(AL.getArgAsExpr(0)->getExprLoc(),
1824 diag::err_attribute_power_of_two_in_range)
1825 << AL << llvm::dxil::MinWaveSize << llvm::dxil::MaxWaveSize << Min;
1826 return;
1827 }
1828 }
1829
1830 HLSLWaveSizeAttr *NewAttr =
1831 mergeWaveSizeAttr(D, AL, Min, Max, Preferred, SpelledArgsCount);
1832 if (NewAttr)
1833 D->addAttr(NewAttr);
1834}
1835
1837 uint32_t ID;
1838 if (!SemaRef.checkUInt32Argument(AL, AL.getArgAsExpr(0), ID))
1839 return;
1840 D->addAttr(::new (getASTContext())
1841 HLSLVkExtBuiltinInputAttr(getASTContext(), AL, ID));
1842}
1843
1845 uint32_t ID;
1846 if (!SemaRef.checkUInt32Argument(AL, AL.getArgAsExpr(0), ID))
1847 return;
1848 D->addAttr(::new (getASTContext())
1849 HLSLVkExtBuiltinOutputAttr(getASTContext(), AL, ID));
1850}
1851
1853 D->addAttr(::new (getASTContext())
1854 HLSLVkPushConstantAttr(getASTContext(), AL));
1855}
1856
1858 uint32_t Id;
1859 if (!SemaRef.checkUInt32Argument(AL, AL.getArgAsExpr(0), Id))
1860 return;
1861 HLSLVkConstantIdAttr *NewAttr = mergeVkConstantIdAttr(D, AL, Id);
1862 if (NewAttr)
1863 D->addAttr(NewAttr);
1864}
1865
1867 uint32_t Binding = 0;
1868 if (!SemaRef.checkUInt32Argument(AL, AL.getArgAsExpr(0), Binding))
1869 return;
1870 uint32_t Set = 0;
1871 if (AL.getNumArgs() > 1 &&
1872 !SemaRef.checkUInt32Argument(AL, AL.getArgAsExpr(1), Set))
1873 return;
1874
1875 D->addAttr(::new (getASTContext())
1876 HLSLVkBindingAttr(getASTContext(), AL, Binding, Set));
1877}
1878
1880 uint32_t Location;
1881 if (!SemaRef.checkUInt32Argument(AL, AL.getArgAsExpr(0), Location))
1882 return;
1883
1884 D->addAttr(::new (getASTContext())
1885 HLSLVkLocationAttr(getASTContext(), AL, Location));
1886}
1887
1889 const auto *VT = T->getAs<VectorType>();
1890
1891 if (!T->hasUnsignedIntegerRepresentation() ||
1892 (VT && VT->getNumElements() > 3)) {
1893 Diag(AL.getLoc(), diag::err_hlsl_attr_invalid_type)
1894 << AL << "uint/uint2/uint3";
1895 return false;
1896 }
1897
1898 return true;
1899}
1900
1902 const auto *VT = T->getAs<VectorType>();
1903 if (!T->hasFloatingRepresentation() || (VT && VT->getNumElements() > 4)) {
1904 Diag(AL.getLoc(), diag::err_hlsl_attr_invalid_type)
1905 << AL << "float/float1/float2/float3/float4";
1906 return false;
1907 }
1908
1909 return true;
1910}
1911
1913 std::optional<unsigned> Index) {
1914 std::string SemanticName = AL.getAttrName()->getName().upper();
1915
1916 auto *VD = cast<ValueDecl>(D);
1917 QualType ValueType = VD->getType();
1918 if (auto *FD = dyn_cast<FunctionDecl>(D))
1919 ValueType = FD->getReturnType();
1920
1921 bool IsOutput = false;
1922 if (HLSLParamModifierAttr *MA = D->getAttr<HLSLParamModifierAttr>()) {
1923 if (MA->isOut()) {
1924 IsOutput = true;
1925 ValueType = cast<ReferenceType>(ValueType)->getPointeeType();
1926 }
1927 }
1928
1929 if (SemanticName == "SV_DISPATCHTHREADID") {
1930 diagnoseInputIDType(ValueType, AL);
1931 if (IsOutput)
1932 Diag(AL.getLoc(), diag::err_hlsl_semantic_output_not_supported) << AL;
1933 if (Index.has_value())
1934 Diag(AL.getLoc(), diag::err_hlsl_semantic_indexing_not_supported) << AL;
1936 return;
1937 }
1938
1939 if (SemanticName == "SV_GROUPINDEX") {
1940 if (IsOutput)
1941 Diag(AL.getLoc(), diag::err_hlsl_semantic_output_not_supported) << AL;
1942 if (Index.has_value())
1943 Diag(AL.getLoc(), diag::err_hlsl_semantic_indexing_not_supported) << AL;
1945 return;
1946 }
1947
1948 if (SemanticName == "SV_GROUPTHREADID") {
1949 diagnoseInputIDType(ValueType, AL);
1950 if (IsOutput)
1951 Diag(AL.getLoc(), diag::err_hlsl_semantic_output_not_supported) << AL;
1952 if (Index.has_value())
1953 Diag(AL.getLoc(), diag::err_hlsl_semantic_indexing_not_supported) << AL;
1955 return;
1956 }
1957
1958 if (SemanticName == "SV_GROUPID") {
1959 diagnoseInputIDType(ValueType, AL);
1960 if (IsOutput)
1961 Diag(AL.getLoc(), diag::err_hlsl_semantic_output_not_supported) << AL;
1962 if (Index.has_value())
1963 Diag(AL.getLoc(), diag::err_hlsl_semantic_indexing_not_supported) << AL;
1965 return;
1966 }
1967
1968 if (SemanticName == "SV_POSITION") {
1969 const auto *VT = ValueType->getAs<VectorType>();
1970 if (!ValueType->hasFloatingRepresentation() ||
1971 (VT && VT->getNumElements() > 4))
1972 Diag(AL.getLoc(), diag::err_hlsl_attr_invalid_type)
1973 << AL << "float/float1/float2/float3/float4";
1975 return;
1976 }
1977
1978 if (SemanticName == "SV_VERTEXID") {
1979 uint64_t SizeInBits = SemaRef.Context.getTypeSize(ValueType);
1980 if (!ValueType->isUnsignedIntegerType() || SizeInBits != 32)
1981 Diag(AL.getLoc(), diag::err_hlsl_attr_invalid_type) << AL << "uint";
1983 return;
1984 }
1985
1986 if (SemanticName == "SV_TARGET") {
1987 const auto *VT = ValueType->getAs<VectorType>();
1988 if (!ValueType->hasFloatingRepresentation() ||
1989 (VT && VT->getNumElements() > 4))
1990 Diag(AL.getLoc(), diag::err_hlsl_attr_invalid_type)
1991 << AL << "float/float1/float2/float3/float4";
1993 return;
1994 }
1995
1996 Diag(AL.getLoc(), diag::err_hlsl_unknown_semantic) << AL;
1997}
1998
2000 uint32_t IndexValue(0), ExplicitIndex(0);
2001 if (!SemaRef.checkUInt32Argument(AL, AL.getArgAsExpr(0), IndexValue) ||
2002 !SemaRef.checkUInt32Argument(AL, AL.getArgAsExpr(1), ExplicitIndex)) {
2003 assert(0 && "HLSLUnparsedSemantic is expected to have 2 int arguments.");
2004 }
2005 assert(IndexValue > 0 ? ExplicitIndex : true);
2006 std::optional<unsigned> Index =
2007 ExplicitIndex ? std::optional<unsigned>(IndexValue) : std::nullopt;
2008
2009 if (AL.getAttrName()->getName().starts_with_insensitive("SV_"))
2010 diagnoseSystemSemanticAttr(D, AL, Index);
2011 else
2013}
2014
2017 Diag(AL.getLoc(), diag::err_hlsl_attr_invalid_ast_node)
2018 << AL << "shader constant in a constant buffer";
2019 return;
2020 }
2021
2022 uint32_t SubComponent;
2023 if (!SemaRef.checkUInt32Argument(AL, AL.getArgAsExpr(0), SubComponent))
2024 return;
2025 uint32_t Component;
2026 if (!SemaRef.checkUInt32Argument(AL, AL.getArgAsExpr(1), Component))
2027 return;
2028
2029 QualType T = cast<VarDecl>(D)->getType().getCanonicalType();
2030 // Check if T is an array or struct type.
2031 // TODO: mark matrix type as aggregate type.
2032 bool IsAggregateTy = (T->isArrayType() || T->isStructureType());
2033
2034 // Check Component is valid for T.
2035 if (Component) {
2036 unsigned Size = getASTContext().getTypeSize(T);
2037 if (IsAggregateTy) {
2038 Diag(AL.getLoc(), diag::err_hlsl_invalid_register_or_packoffset);
2039 return;
2040 } else {
2041 // Make sure Component + sizeof(T) <= 4.
2042 if ((Component * 32 + Size) > 128) {
2043 Diag(AL.getLoc(), diag::err_hlsl_packoffset_cross_reg_boundary);
2044 return;
2045 }
2046 QualType EltTy = T;
2047 if (const auto *VT = T->getAs<VectorType>())
2048 EltTy = VT->getElementType();
2049 unsigned Align = getASTContext().getTypeAlign(EltTy);
2050 if (Align > 32 && Component == 1) {
2051 // NOTE: Component 3 will hit err_hlsl_packoffset_cross_reg_boundary.
2052 // So we only need to check Component 1 here.
2053 Diag(AL.getLoc(), diag::err_hlsl_packoffset_alignment_mismatch)
2054 << Align << EltTy;
2055 return;
2056 }
2057 }
2058 }
2059
2060 D->addAttr(::new (getASTContext()) HLSLPackOffsetAttr(
2061 getASTContext(), AL, SubComponent, Component));
2062}
2063
2065 StringRef Str;
2066 SourceLocation ArgLoc;
2067 if (!SemaRef.checkStringLiteralArgumentAttr(AL, 0, Str, &ArgLoc))
2068 return;
2069
2070 llvm::Triple::EnvironmentType ShaderType;
2071 if (!HLSLShaderAttr::ConvertStrToEnvironmentType(Str, ShaderType)) {
2072 Diag(AL.getLoc(), diag::warn_attribute_type_not_supported)
2073 << AL << Str << ArgLoc;
2074 return;
2075 }
2076
2077 // FIXME: check function match the shader stage.
2078
2079 HLSLShaderAttr *NewAttr = mergeShaderAttr(D, AL, ShaderType);
2080 if (NewAttr)
2081 D->addAttr(NewAttr);
2082}
2083
2085 Sema &S, QualType Wrapped, ArrayRef<const Attr *> AttrList,
2086 QualType &ResType, HLSLAttributedResourceLocInfo *LocInfo) {
2087 assert(AttrList.size() && "expected list of resource attributes");
2088
2089 QualType ContainedTy = QualType();
2090 TypeSourceInfo *ContainedTyInfo = nullptr;
2091 SourceLocation LocBegin = AttrList[0]->getRange().getBegin();
2092 SourceLocation LocEnd = AttrList[0]->getRange().getEnd();
2093
2094 HLSLAttributedResourceType::Attributes ResAttrs;
2095
2096 bool HasResourceClass = false;
2097 bool HasResourceDimension = false;
2098 for (const Attr *A : AttrList) {
2099 if (!A)
2100 continue;
2101 LocEnd = A->getRange().getEnd();
2102 switch (A->getKind()) {
2103 case attr::HLSLResourceClass: {
2104 ResourceClass RC = cast<HLSLResourceClassAttr>(A)->getResourceClass();
2105 if (HasResourceClass) {
2106 S.Diag(A->getLocation(), ResAttrs.ResourceClass == RC
2107 ? diag::warn_duplicate_attribute_exact
2108 : diag::warn_duplicate_attribute)
2109 << A;
2110 return false;
2111 }
2112 ResAttrs.ResourceClass = RC;
2113 HasResourceClass = true;
2114 break;
2115 }
2116 case attr::HLSLResourceDimension: {
2117 llvm::dxil::ResourceDimension RD =
2118 cast<HLSLResourceDimensionAttr>(A)->getDimension();
2119 if (HasResourceDimension) {
2120 S.Diag(A->getLocation(), ResAttrs.ResourceDimension == RD
2121 ? diag::warn_duplicate_attribute_exact
2122 : diag::warn_duplicate_attribute)
2123 << A;
2124 return false;
2125 }
2126 ResAttrs.ResourceDimension = RD;
2127 HasResourceDimension = true;
2128 break;
2129 }
2130 case attr::HLSLROV:
2131 if (ResAttrs.IsROV) {
2132 S.Diag(A->getLocation(), diag::warn_duplicate_attribute_exact) << A;
2133 return false;
2134 }
2135 ResAttrs.IsROV = true;
2136 break;
2137 case attr::HLSLRawBuffer:
2138 if (ResAttrs.RawBuffer) {
2139 S.Diag(A->getLocation(), diag::warn_duplicate_attribute_exact) << A;
2140 return false;
2141 }
2142 ResAttrs.RawBuffer = true;
2143 break;
2144 case attr::HLSLIsCounter:
2145 if (ResAttrs.IsCounter) {
2146 S.Diag(A->getLocation(), diag::warn_duplicate_attribute_exact) << A;
2147 return false;
2148 }
2149 ResAttrs.IsCounter = true;
2150 break;
2151 case attr::HLSLContainedType: {
2152 const HLSLContainedTypeAttr *CTAttr = cast<HLSLContainedTypeAttr>(A);
2153 QualType Ty = CTAttr->getType();
2154 if (!ContainedTy.isNull()) {
2155 S.Diag(A->getLocation(), ContainedTy == Ty
2156 ? diag::warn_duplicate_attribute_exact
2157 : diag::warn_duplicate_attribute)
2158 << A;
2159 return false;
2160 }
2161 ContainedTy = Ty;
2162 ContainedTyInfo = CTAttr->getTypeLoc();
2163 break;
2164 }
2165 default:
2166 llvm_unreachable("unhandled resource attribute type");
2167 }
2168 }
2169
2170 if (!HasResourceClass) {
2171 S.Diag(AttrList.back()->getRange().getEnd(),
2172 diag::err_hlsl_missing_resource_class);
2173 return false;
2174 }
2175
2177 Wrapped, ContainedTy, ResAttrs);
2178
2179 if (LocInfo && ContainedTyInfo) {
2180 LocInfo->Range = SourceRange(LocBegin, LocEnd);
2181 LocInfo->ContainedTyInfo = ContainedTyInfo;
2182 }
2183 return true;
2184}
2185
2186// Validates and creates an HLSL attribute that is applied as type attribute on
2187// HLSL resource. The attributes are collected in HLSLResourcesTypeAttrs and at
2188// the end of the declaration they are applied to the declaration type by
2189// wrapping it in HLSLAttributedResourceType.
2191 // only allow resource type attributes on intangible types
2192 if (!T->isHLSLResourceType()) {
2193 Diag(AL.getLoc(), diag::err_hlsl_attribute_needs_intangible_type)
2194 << AL << getASTContext().HLSLResourceTy;
2195 return false;
2196 }
2197
2198 // validate number of arguments
2199 if (!AL.checkExactlyNumArgs(SemaRef, AL.getMinArgs()))
2200 return false;
2201
2202 Attr *A = nullptr;
2203
2207 {
2208 AttributeCommonInfo::AS_CXX11, 0, false /*IsAlignas*/,
2209 false /*IsRegularKeywordAttribute*/
2210 });
2211
2212 switch (AL.getKind()) {
2213 case ParsedAttr::AT_HLSLResourceClass: {
2214 if (!AL.isArgIdent(0)) {
2215 Diag(AL.getLoc(), diag::err_attribute_argument_type)
2216 << AL << AANT_ArgumentIdentifier;
2217 return false;
2218 }
2219
2220 IdentifierLoc *Loc = AL.getArgAsIdent(0);
2221 StringRef Identifier = Loc->getIdentifierInfo()->getName();
2222 SourceLocation ArgLoc = Loc->getLoc();
2223
2224 // Validate resource class value
2225 ResourceClass RC;
2226 if (!HLSLResourceClassAttr::ConvertStrToResourceClass(Identifier, RC)) {
2227 Diag(ArgLoc, diag::warn_attribute_type_not_supported)
2228 << "ResourceClass" << Identifier;
2229 return false;
2230 }
2231 A = HLSLResourceClassAttr::Create(getASTContext(), RC, ACI);
2232 break;
2233 }
2234
2235 case ParsedAttr::AT_HLSLResourceDimension: {
2236 StringRef Identifier;
2237 SourceLocation ArgLoc;
2238 if (!SemaRef.checkStringLiteralArgumentAttr(AL, 0, Identifier, &ArgLoc))
2239 return false;
2240
2241 // Validate resource dimension value
2242 llvm::dxil::ResourceDimension RD;
2243 if (!HLSLResourceDimensionAttr::ConvertStrToResourceDimension(Identifier,
2244 RD)) {
2245 Diag(ArgLoc, diag::warn_attribute_type_not_supported)
2246 << "ResourceDimension" << Identifier;
2247 return false;
2248 }
2249 A = HLSLResourceDimensionAttr::Create(getASTContext(), RD, ACI);
2250 break;
2251 }
2252
2253 case ParsedAttr::AT_HLSLROV:
2254 A = HLSLROVAttr::Create(getASTContext(), ACI);
2255 break;
2256
2257 case ParsedAttr::AT_HLSLRawBuffer:
2258 A = HLSLRawBufferAttr::Create(getASTContext(), ACI);
2259 break;
2260
2261 case ParsedAttr::AT_HLSLIsCounter:
2262 A = HLSLIsCounterAttr::Create(getASTContext(), ACI);
2263 break;
2264
2265 case ParsedAttr::AT_HLSLContainedType: {
2266 if (AL.getNumArgs() != 1 && !AL.hasParsedType()) {
2267 Diag(AL.getLoc(), diag::err_attribute_wrong_number_arguments) << AL << 1;
2268 return false;
2269 }
2270
2271 TypeSourceInfo *TSI = nullptr;
2272 QualType QT = SemaRef.GetTypeFromParser(AL.getTypeArg(), &TSI);
2273 assert(TSI && "no type source info for attribute argument");
2274 if (SemaRef.RequireCompleteType(TSI->getTypeLoc().getBeginLoc(), QT,
2275 diag::err_incomplete_type))
2276 return false;
2277 A = HLSLContainedTypeAttr::Create(getASTContext(), TSI, ACI);
2278 break;
2279 }
2280
2281 default:
2282 llvm_unreachable("unhandled HLSL attribute");
2283 }
2284
2285 HLSLResourcesTypeAttrs.emplace_back(A);
2286 return true;
2287}
2288
2289// Combines all resource type attributes and creates HLSLAttributedResourceType.
2291 if (!HLSLResourcesTypeAttrs.size())
2292 return CurrentType;
2293
2294 QualType QT = CurrentType;
2297 HLSLResourcesTypeAttrs, QT, &LocInfo)) {
2298 const HLSLAttributedResourceType *RT =
2300
2301 // Temporarily store TypeLoc information for the new type.
2302 // It will be transferred to HLSLAttributesResourceTypeLoc
2303 // shortly after the type is created by TypeSpecLocFiller which
2304 // will call the TakeLocForHLSLAttribute method below.
2305 LocsForHLSLAttributedResources.insert(std::pair(RT, LocInfo));
2306 }
2307 HLSLResourcesTypeAttrs.clear();
2308 return QT;
2309}
2310
2311// Returns source location for the HLSLAttributedResourceType
2313SemaHLSL::TakeLocForHLSLAttribute(const HLSLAttributedResourceType *RT) {
2314 HLSLAttributedResourceLocInfo LocInfo = {};
2315 auto I = LocsForHLSLAttributedResources.find(RT);
2316 if (I != LocsForHLSLAttributedResources.end()) {
2317 LocInfo = I->second;
2318 LocsForHLSLAttributedResources.erase(I);
2319 return LocInfo;
2320 }
2321 LocInfo.Range = SourceRange();
2322 return LocInfo;
2323}
2324
2325// Walks though the global variable declaration, collects all resource binding
2326// requirements and adds them to Bindings
2327void SemaHLSL::collectResourceBindingsOnUserRecordDecl(const VarDecl *VD,
2328 const RecordType *RT) {
2329 const RecordDecl *RD = RT->getDecl()->getDefinitionOrSelf();
2330 for (FieldDecl *FD : RD->fields()) {
2331 const Type *Ty = FD->getType()->getUnqualifiedDesugaredType();
2332
2333 // Unwrap arrays
2334 // FIXME: Calculate array size while unwrapping
2335 assert(!Ty->isIncompleteArrayType() &&
2336 "incomplete arrays inside user defined types are not supported");
2337 while (Ty->isConstantArrayType()) {
2340 }
2341
2342 if (!Ty->isRecordType())
2343 continue;
2344
2345 if (const HLSLAttributedResourceType *AttrResType =
2346 HLSLAttributedResourceType::findHandleTypeOnResource(Ty)) {
2347 // Add a new DeclBindingInfo to Bindings if it does not already exist
2348 ResourceClass RC = AttrResType->getAttrs().ResourceClass;
2349 DeclBindingInfo *DBI = Bindings.getDeclBindingInfo(VD, RC);
2350 if (!DBI)
2351 Bindings.addDeclBindingInfo(VD, RC);
2352 } else if (const RecordType *RT = dyn_cast<RecordType>(Ty)) {
2353 // Recursively scan embedded struct or class; it would be nice to do this
2354 // without recursion, but tricky to correctly calculate the size of the
2355 // binding, which is something we are probably going to need to do later
2356 // on. Hopefully nesting of structs in structs too many levels is
2357 // unlikely.
2358 collectResourceBindingsOnUserRecordDecl(VD, RT);
2359 }
2360 }
2361}
2362
2363// Diagnose localized register binding errors for a single binding; does not
2364// diagnose resource binding on user record types, that will be done later
2365// in processResourceBindingOnDecl based on the information collected in
2366// collectResourceBindingsOnVarDecl.
2367// Returns false if the register binding is not valid.
2369 Decl *D, RegisterType RegType,
2370 bool SpecifiedSpace) {
2371 int RegTypeNum = static_cast<int>(RegType);
2372
2373 // check if the decl type is groupshared
2374 if (D->hasAttr<HLSLGroupSharedAddressSpaceAttr>()) {
2375 S.Diag(ArgLoc, diag::err_hlsl_binding_type_mismatch) << RegTypeNum;
2376 return false;
2377 }
2378
2379 // Cbuffers and Tbuffers are HLSLBufferDecl types
2380 if (HLSLBufferDecl *CBufferOrTBuffer = dyn_cast<HLSLBufferDecl>(D)) {
2381 ResourceClass RC = CBufferOrTBuffer->isCBuffer() ? ResourceClass::CBuffer
2382 : ResourceClass::SRV;
2383 if (RegType == getRegisterType(RC))
2384 return true;
2385
2386 S.Diag(D->getLocation(), diag::err_hlsl_binding_type_mismatch)
2387 << RegTypeNum;
2388 return false;
2389 }
2390
2391 // Samplers, UAVs, and SRVs are VarDecl types
2392 assert(isa<VarDecl>(D) && "D is expected to be VarDecl or HLSLBufferDecl");
2393 VarDecl *VD = cast<VarDecl>(D);
2394
2395 // Resource
2396 if (const HLSLAttributedResourceType *AttrResType =
2397 HLSLAttributedResourceType::findHandleTypeOnResource(
2398 VD->getType().getTypePtr())) {
2399 if (RegType == getRegisterType(AttrResType))
2400 return true;
2401
2402 S.Diag(D->getLocation(), diag::err_hlsl_binding_type_mismatch)
2403 << RegTypeNum;
2404 return false;
2405 }
2406
2407 const clang::Type *Ty = VD->getType().getTypePtr();
2408 while (Ty->isArrayType())
2410
2411 // Basic types
2412 if (Ty->isArithmeticType() || Ty->isVectorType()) {
2413 bool DeclaredInCOrTBuffer = isa<HLSLBufferDecl>(D->getDeclContext());
2414 if (SpecifiedSpace && !DeclaredInCOrTBuffer)
2415 S.Diag(ArgLoc, diag::err_hlsl_space_on_global_constant);
2416
2417 if (!DeclaredInCOrTBuffer && (Ty->isIntegralType(S.getASTContext()) ||
2418 Ty->isFloatingType() || Ty->isVectorType())) {
2419 // Register annotation on default constant buffer declaration ($Globals)
2420 if (RegType == RegisterType::CBuffer)
2421 S.Diag(ArgLoc, diag::warn_hlsl_deprecated_register_type_b);
2422 else if (RegType != RegisterType::C)
2423 S.Diag(ArgLoc, diag::err_hlsl_binding_type_mismatch) << RegTypeNum;
2424 else
2425 return true;
2426 } else {
2427 if (RegType == RegisterType::C)
2428 S.Diag(ArgLoc, diag::warn_hlsl_register_type_c_packoffset);
2429 else
2430 S.Diag(ArgLoc, diag::err_hlsl_binding_type_mismatch) << RegTypeNum;
2431 }
2432 return false;
2433 }
2434 if (Ty->isRecordType())
2435 // RecordTypes will be diagnosed in processResourceBindingOnDecl
2436 // that is called from ActOnVariableDeclarator
2437 return true;
2438
2439 // Anything else is an error
2440 S.Diag(ArgLoc, diag::err_hlsl_binding_type_mismatch) << RegTypeNum;
2441 return false;
2442}
2443
2445 RegisterType regType) {
2446 // make sure that there are no two register annotations
2447 // applied to the decl with the same register type
2448 bool RegisterTypesDetected[5] = {false};
2449 RegisterTypesDetected[static_cast<int>(regType)] = true;
2450
2451 for (auto it = TheDecl->attr_begin(); it != TheDecl->attr_end(); ++it) {
2452 if (HLSLResourceBindingAttr *attr =
2453 dyn_cast<HLSLResourceBindingAttr>(*it)) {
2454
2455 RegisterType otherRegType = attr->getRegisterType();
2456 if (RegisterTypesDetected[static_cast<int>(otherRegType)]) {
2457 int otherRegTypeNum = static_cast<int>(otherRegType);
2458 S.Diag(TheDecl->getLocation(),
2459 diag::err_hlsl_duplicate_register_annotation)
2460 << otherRegTypeNum;
2461 return false;
2462 }
2463 RegisterTypesDetected[static_cast<int>(otherRegType)] = true;
2464 }
2465 }
2466 return true;
2467}
2468
2470 Decl *D, RegisterType RegType,
2471 bool SpecifiedSpace) {
2472
2473 // exactly one of these two types should be set
2474 assert(((isa<VarDecl>(D) && !isa<HLSLBufferDecl>(D)) ||
2475 (!isa<VarDecl>(D) && isa<HLSLBufferDecl>(D))) &&
2476 "expecting VarDecl or HLSLBufferDecl");
2477
2478 // check if the declaration contains resource matching the register type
2479 if (!DiagnoseLocalRegisterBinding(S, ArgLoc, D, RegType, SpecifiedSpace))
2480 return false;
2481
2482 // next, if multiple register annotations exist, check that none conflict.
2483 return ValidateMultipleRegisterAnnotations(S, D, RegType);
2484}
2485
2486// return false if the slot count exceeds the limit, true otherwise
2487static bool AccumulateHLSLResourceSlots(QualType Ty, uint64_t &StartSlot,
2488 const uint64_t &Limit,
2489 const ResourceClass ResClass,
2490 ASTContext &Ctx,
2491 uint64_t ArrayCount = 1) {
2492 Ty = Ty.getCanonicalType();
2493 const Type *T = Ty.getTypePtr();
2494
2495 // Early exit if already overflowed
2496 if (StartSlot > Limit)
2497 return false;
2498
2499 // Case 1: array type
2500 if (const auto *AT = dyn_cast<ArrayType>(T)) {
2501 uint64_t Count = 1;
2502
2503 if (const auto *CAT = dyn_cast<ConstantArrayType>(AT))
2504 Count = CAT->getSize().getZExtValue();
2505
2506 QualType ElemTy = AT->getElementType();
2507 return AccumulateHLSLResourceSlots(ElemTy, StartSlot, Limit, ResClass, Ctx,
2508 ArrayCount * Count);
2509 }
2510
2511 // Case 2: resource leaf
2512 if (auto ResTy = dyn_cast<HLSLAttributedResourceType>(T)) {
2513 // First ensure this resource counts towards the corresponding
2514 // register type limit.
2515 if (ResTy->getAttrs().ResourceClass != ResClass)
2516 return true;
2517
2518 // Validate highest slot used
2519 uint64_t EndSlot = StartSlot + ArrayCount - 1;
2520 if (EndSlot > Limit)
2521 return false;
2522
2523 // Advance SlotCount past the consumed range
2524 StartSlot = EndSlot + 1;
2525 return true;
2526 }
2527
2528 // Case 3: struct / record
2529 if (const auto *RT = dyn_cast<RecordType>(T)) {
2530 const RecordDecl *RD = RT->getDecl();
2531
2532 if (const auto *CXXRD = dyn_cast<CXXRecordDecl>(RD)) {
2533 for (const CXXBaseSpecifier &Base : CXXRD->bases()) {
2534 if (!AccumulateHLSLResourceSlots(Base.getType(), StartSlot, Limit,
2535 ResClass, Ctx, ArrayCount))
2536 return false;
2537 }
2538 }
2539
2540 for (const FieldDecl *Field : RD->fields()) {
2541 if (!AccumulateHLSLResourceSlots(Field->getType(), StartSlot, Limit,
2542 ResClass, Ctx, ArrayCount))
2543 return false;
2544 }
2545
2546 return true;
2547 }
2548
2549 // Case 4: everything else
2550 return true;
2551}
2552
2553// return true if there is something invalid, false otherwise
2554static bool ValidateRegisterNumber(uint64_t SlotNum, Decl *TheDecl,
2555 ASTContext &Ctx, RegisterType RegTy) {
2556 const uint64_t Limit = UINT32_MAX;
2557 if (SlotNum > Limit)
2558 return true;
2559
2560 // after verifying the number doesn't exceed uint32max, we don't need
2561 // to look further into c or i register types
2562 if (RegTy == RegisterType::C || RegTy == RegisterType::I)
2563 return false;
2564
2565 if (VarDecl *VD = dyn_cast<VarDecl>(TheDecl)) {
2566 uint64_t BaseSlot = SlotNum;
2567
2568 if (!AccumulateHLSLResourceSlots(VD->getType(), SlotNum, Limit,
2569 getResourceClass(RegTy), Ctx))
2570 return true;
2571
2572 // After AccumulateHLSLResourceSlots runs, SlotNum is now
2573 // the first free slot; last used was SlotNum - 1
2574 return (BaseSlot > Limit);
2575 }
2576 // handle the cbuffer/tbuffer case
2577 if (isa<HLSLBufferDecl>(TheDecl))
2578 // resources cannot be put within a cbuffer, so no need
2579 // to analyze the structure since the register number
2580 // won't be pushed any higher.
2581 return (SlotNum > Limit);
2582
2583 // we don't expect any other decl type, so fail
2584 llvm_unreachable("unexpected decl type");
2585}
2586
2588 if (VarDecl *VD = dyn_cast<VarDecl>(TheDecl)) {
2589 QualType Ty = VD->getType();
2590 if (const auto *IAT = dyn_cast<IncompleteArrayType>(Ty))
2591 Ty = IAT->getElementType();
2592 if (SemaRef.RequireCompleteType(TheDecl->getBeginLoc(), Ty,
2593 diag::err_incomplete_type))
2594 return;
2595 }
2596
2597 StringRef Slot = "";
2598 StringRef Space = "";
2599 SourceLocation SlotLoc, SpaceLoc;
2600
2601 if (!AL.isArgIdent(0)) {
2602 Diag(AL.getLoc(), diag::err_attribute_argument_type)
2603 << AL << AANT_ArgumentIdentifier;
2604 return;
2605 }
2606 IdentifierLoc *Loc = AL.getArgAsIdent(0);
2607
2608 if (AL.getNumArgs() == 2) {
2609 Slot = Loc->getIdentifierInfo()->getName();
2610 SlotLoc = Loc->getLoc();
2611 if (!AL.isArgIdent(1)) {
2612 Diag(AL.getLoc(), diag::err_attribute_argument_type)
2613 << AL << AANT_ArgumentIdentifier;
2614 return;
2615 }
2616 Loc = AL.getArgAsIdent(1);
2617 Space = Loc->getIdentifierInfo()->getName();
2618 SpaceLoc = Loc->getLoc();
2619 } else {
2620 StringRef Str = Loc->getIdentifierInfo()->getName();
2621 if (Str.starts_with("space")) {
2622 Space = Str;
2623 SpaceLoc = Loc->getLoc();
2624 } else {
2625 Slot = Str;
2626 SlotLoc = Loc->getLoc();
2627 Space = "space0";
2628 }
2629 }
2630
2631 RegisterType RegType = RegisterType::SRV;
2632 std::optional<unsigned> SlotNum;
2633 unsigned SpaceNum = 0;
2634
2635 // Validate slot
2636 if (!Slot.empty()) {
2637 if (!convertToRegisterType(Slot, &RegType)) {
2638 Diag(SlotLoc, diag::err_hlsl_binding_type_invalid) << Slot.substr(0, 1);
2639 return;
2640 }
2641 if (RegType == RegisterType::I) {
2642 Diag(SlotLoc, diag::warn_hlsl_deprecated_register_type_i);
2643 return;
2644 }
2645 const StringRef SlotNumStr = Slot.substr(1);
2646
2647 uint64_t N;
2648
2649 // validate that the slot number is a non-empty number
2650 if (SlotNumStr.getAsInteger(10, N)) {
2651 Diag(SlotLoc, diag::err_hlsl_unsupported_register_number);
2652 return;
2653 }
2654
2655 // Validate register number. It should not exceed UINT32_MAX,
2656 // including if the resource type is an array that starts
2657 // before UINT32_MAX, but ends afterwards.
2658 if (ValidateRegisterNumber(N, TheDecl, getASTContext(), RegType)) {
2659 Diag(SlotLoc, diag::err_hlsl_register_number_too_large);
2660 return;
2661 }
2662
2663 // the slot number has been validated and does not exceed UINT32_MAX
2664 SlotNum = (unsigned)N;
2665 }
2666
2667 // Validate space
2668 if (!Space.starts_with("space")) {
2669 Diag(SpaceLoc, diag::err_hlsl_expected_space) << Space;
2670 return;
2671 }
2672 StringRef SpaceNumStr = Space.substr(5);
2673 if (SpaceNumStr.getAsInteger(10, SpaceNum)) {
2674 Diag(SpaceLoc, diag::err_hlsl_expected_space) << Space;
2675 return;
2676 }
2677
2678 // If we have slot, diagnose it is the right register type for the decl
2679 if (SlotNum.has_value())
2680 if (!DiagnoseHLSLRegisterAttribute(SemaRef, SlotLoc, TheDecl, RegType,
2681 !SpaceLoc.isInvalid()))
2682 return;
2683
2684 HLSLResourceBindingAttr *NewAttr =
2685 HLSLResourceBindingAttr::Create(getASTContext(), Slot, Space, AL);
2686 if (NewAttr) {
2687 NewAttr->setBinding(RegType, SlotNum, SpaceNum);
2688 TheDecl->addAttr(NewAttr);
2689 }
2690}
2691
2693 HLSLParamModifierAttr *NewAttr = mergeParamModifierAttr(
2694 D, AL,
2695 static_cast<HLSLParamModifierAttr::Spelling>(AL.getSemanticSpelling()));
2696 if (NewAttr)
2697 D->addAttr(NewAttr);
2698}
2699
2700static bool isMatrixOrArrayOfMatrix(const ASTContext &Ctx, QualType QT) {
2701 const Type *Ty = QT->getUnqualifiedDesugaredType();
2702 while (isa<ArrayType>(Ty))
2704 return Ty->isDependentType() || Ty->isConstantMatrixType();
2705}
2706
2708 SourceLocation Loc,
2709 const IdentifierInfo *AttrName) {
2710 QualType Ty;
2711 if (auto *VD = dyn_cast<ValueDecl>(D))
2712 Ty = VD->getType();
2713 else if (auto *TD = dyn_cast<TypedefNameDecl>(D))
2714 Ty = TD->getUnderlyingType();
2715
2716 if (Ty.isNull() || Ty->isDependentType())
2717 return false;
2718
2719 // For functions, the qualifier can apply to the return type or any parameter.
2720 if (const auto *FPT = Ty->getAs<FunctionProtoType>()) {
2721 if (isMatrixOrArrayOfMatrix(SemaRef.getASTContext(), FPT->getReturnType()))
2722 return false;
2723 SemaRef.Diag(Loc, diag::err_hlsl_matrix_layout_non_matrix) << AttrName;
2724 return true;
2725 }
2726
2727 if (isMatrixOrArrayOfMatrix(SemaRef.getASTContext(), Ty))
2728 return false;
2729
2730 SemaRef.Diag(Loc, diag::err_hlsl_matrix_layout_non_matrix) << AttrName;
2731 return true;
2732}
2733
2735 // row_major and column_major are only valid on matrix types.
2737 AL.getAttrName()))
2738 return;
2739
2740 // Check for conflicting or duplicate matrix layout attributes.
2741 if (const auto *Existing = D->getAttr<HLSLMatrixLayoutAttr>()) {
2742 if (Existing->getSemanticSpelling() != AL.getSemanticSpelling()) {
2743 Diag(AL.getLoc(), diag::err_hlsl_matrix_layout_conflict)
2744 << AL.getAttrName() << Existing->getAttrName();
2745 Diag(Existing->getLoc(), diag::note_conflicting_attribute);
2746 } else {
2747 Diag(AL.getLoc(), diag::warn_duplicate_attribute_exact)
2748 << AL.getAttrName();
2749 Diag(Existing->getLoc(), diag::note_previous_attribute);
2750 }
2751 return;
2752 }
2753
2754 D->addAttr(::new (getASTContext()) HLSLMatrixLayoutAttr(getASTContext(), AL));
2755}
2756
2758 Decl *D, const HLSLMatrixLayoutAttr *Attr) {
2760 Attr->getAttrName());
2761}
2762
2763namespace {
2764
2765/// This class implements HLSL availability diagnostics for default
2766/// and relaxed mode
2767///
2768/// The goal of this diagnostic is to emit an error or warning when an
2769/// unavailable API is found in code that is reachable from the shader
2770/// entry function or from an exported function (when compiling a shader
2771/// library).
2772///
2773/// This is done by traversing the AST of all shader entry point functions
2774/// and of all exported functions, and any functions that are referenced
2775/// from this AST. In other words, any functions that are reachable from
2776/// the entry points.
2777class DiagnoseHLSLAvailability : public DynamicRecursiveASTVisitor {
2778 Sema &SemaRef;
2779
2780 // Stack of functions to be scaned
2782
2783 // Tracks which environments functions have been scanned in.
2784 //
2785 // Maps FunctionDecl to an unsigned number that represents the set of shader
2786 // environments the function has been scanned for.
2787 // The llvm::Triple::EnvironmentType enum values for shader stages guaranteed
2788 // to be numbered from llvm::Triple::Pixel to llvm::Triple::Amplification
2789 // (verified by static_asserts in Triple.cpp), we can use it to index
2790 // individual bits in the set, as long as we shift the values to start with 0
2791 // by subtracting the value of llvm::Triple::Pixel first.
2792 //
2793 // The N'th bit in the set will be set if the function has been scanned
2794 // in shader environment whose llvm::Triple::EnvironmentType integer value
2795 // equals (llvm::Triple::Pixel + N).
2796 //
2797 // For example, if a function has been scanned in compute and pixel stage
2798 // environment, the value will be 0x21 (100001 binary) because:
2799 //
2800 // (int)(llvm::Triple::Pixel - llvm::Triple::Pixel) == 0
2801 // (int)(llvm::Triple::Compute - llvm::Triple::Pixel) == 5
2802 //
2803 // A FunctionDecl is mapped to 0 (or not included in the map) if it has not
2804 // been scanned in any environment.
2805 llvm::DenseMap<const FunctionDecl *, unsigned> ScannedDecls;
2806
2807 // Do not access these directly, use the get/set methods below to make
2808 // sure the values are in sync
2809 llvm::Triple::EnvironmentType CurrentShaderEnvironment;
2810 unsigned CurrentShaderStageBit;
2811
2812 // True if scanning a function that was already scanned in a different
2813 // shader stage context, and therefore we should not report issues that
2814 // depend only on shader model version because they would be duplicate.
2815 bool ReportOnlyShaderStageIssues;
2816
2817 // Helper methods for dealing with current stage context / environment
2818 void SetShaderStageContext(llvm::Triple::EnvironmentType ShaderType) {
2819 static_assert(sizeof(unsigned) >= 4);
2820 assert(HLSLShaderAttr::isValidShaderType(ShaderType));
2821 assert((unsigned)(ShaderType - llvm::Triple::Pixel) < 31 &&
2822 "ShaderType is too big for this bitmap"); // 31 is reserved for
2823 // "unknown"
2824
2825 unsigned bitmapIndex = ShaderType - llvm::Triple::Pixel;
2826 CurrentShaderEnvironment = ShaderType;
2827 CurrentShaderStageBit = (1 << bitmapIndex);
2828 }
2829
2830 void SetUnknownShaderStageContext() {
2831 CurrentShaderEnvironment = llvm::Triple::UnknownEnvironment;
2832 CurrentShaderStageBit = (1 << 31);
2833 }
2834
2835 llvm::Triple::EnvironmentType GetCurrentShaderEnvironment() const {
2836 return CurrentShaderEnvironment;
2837 }
2838
2839 bool InUnknownShaderStageContext() const {
2840 return CurrentShaderEnvironment == llvm::Triple::UnknownEnvironment;
2841 }
2842
2843 // Helper methods for dealing with shader stage bitmap
2844 void AddToScannedFunctions(const FunctionDecl *FD) {
2845 unsigned &ScannedStages = ScannedDecls[FD];
2846 ScannedStages |= CurrentShaderStageBit;
2847 }
2848
2849 unsigned GetScannedStages(const FunctionDecl *FD) { return ScannedDecls[FD]; }
2850
2851 bool WasAlreadyScannedInCurrentStage(const FunctionDecl *FD) {
2852 return WasAlreadyScannedInCurrentStage(GetScannedStages(FD));
2853 }
2854
2855 bool WasAlreadyScannedInCurrentStage(unsigned ScannerStages) {
2856 return ScannerStages & CurrentShaderStageBit;
2857 }
2858
2859 static bool NeverBeenScanned(unsigned ScannedStages) {
2860 return ScannedStages == 0;
2861 }
2862
2863 // Scanning methods
2864 void HandleFunctionOrMethodRef(FunctionDecl *FD, Expr *RefExpr);
2865 void CheckDeclAvailability(NamedDecl *D, const AvailabilityAttr *AA,
2866 SourceRange Range);
2867 const AvailabilityAttr *FindAvailabilityAttr(const Decl *D);
2868 bool HasMatchingEnvironmentOrNone(const AvailabilityAttr *AA);
2869
2870public:
2871 DiagnoseHLSLAvailability(Sema &SemaRef)
2872 : SemaRef(SemaRef),
2873 CurrentShaderEnvironment(llvm::Triple::UnknownEnvironment),
2874 CurrentShaderStageBit(0), ReportOnlyShaderStageIssues(false) {}
2875
2876 // AST traversal methods
2877 void RunOnTranslationUnit(const TranslationUnitDecl *TU);
2878 void RunOnFunction(const FunctionDecl *FD);
2879
2880 bool VisitDeclRefExpr(DeclRefExpr *DRE) override {
2881 FunctionDecl *FD = llvm::dyn_cast<FunctionDecl>(DRE->getDecl());
2882 if (FD)
2883 HandleFunctionOrMethodRef(FD, DRE);
2884 return true;
2885 }
2886
2887 bool VisitMemberExpr(MemberExpr *ME) override {
2888 FunctionDecl *FD = llvm::dyn_cast<FunctionDecl>(ME->getMemberDecl());
2889 if (FD)
2890 HandleFunctionOrMethodRef(FD, ME);
2891 return true;
2892 }
2893};
2894
2895void DiagnoseHLSLAvailability::HandleFunctionOrMethodRef(FunctionDecl *FD,
2896 Expr *RefExpr) {
2897 assert((isa<DeclRefExpr>(RefExpr) || isa<MemberExpr>(RefExpr)) &&
2898 "expected DeclRefExpr or MemberExpr");
2899
2900 // has a definition -> add to stack to be scanned
2901 const FunctionDecl *FDWithBody = nullptr;
2902 if (FD->hasBody(FDWithBody)) {
2903 if (!WasAlreadyScannedInCurrentStage(FDWithBody))
2904 DeclsToScan.push_back(FDWithBody);
2905 return;
2906 }
2907
2908 // no body -> diagnose availability
2909 const AvailabilityAttr *AA = FindAvailabilityAttr(FD);
2910 if (AA)
2911 CheckDeclAvailability(
2912 FD, AA, SourceRange(RefExpr->getBeginLoc(), RefExpr->getEndLoc()));
2913}
2914
// Walks the declarations of TU, descending into namespace and export
// declarations via the DeclContextsToScan worklist. It skips implicit decls
// and non-definitions, and calls RunOnFunction on each function definition that
// either has an HLSLShaderAttr (scanned in that shader stage's context) or has
// a redeclaration in an export context (scanned in the unknown-stage context).
2915void DiagnoseHLSLAvailability::RunOnTranslationUnit(
2916 const TranslationUnitDecl *TU) {
2917
2918 // Iterate over all shader entry functions and library exports, and for those
2919 // that have a body (definition), run diag scan on each, setting appropriate
2920 // shader environment context based on whether it is a shader entry function
2921 // or an exported function. Exported functions can be in namespaces and in
2922 // export declarations so we need to scan those declaration contexts as well.
2924 DeclContextsToScan.push_back(TU);
2925
2926 while (!DeclContextsToScan.empty()) {
2927 const DeclContext *DC = DeclContextsToScan.pop_back_val();
2928 for (auto &D : DC->decls()) {
2929 // do not scan implicit declaration generated by the implementation
2930 if (D->isImplicit())
2931 continue;
2932
2933 // for namespace or export declaration add the context to the list to be
2934 // scanned later
// NOTE(review): isa<NamespaceDecl, ExportDecl>(D) plus cast<DeclContext>(D)
// would state this intent more directly than two dyn_casts.
2935 if (llvm::dyn_cast<NamespaceDecl>(D) || llvm::dyn_cast<ExportDecl>(D)) {
2936 DeclContextsToScan.push_back(llvm::dyn_cast<DeclContext>(D));
2937 continue;
2938 }
2939
2940 // skip over other decls or function decls without body
2941 const FunctionDecl *FD = llvm::dyn_cast<FunctionDecl>(D);
2942 if (!FD || !FD->isThisDeclarationADefinition())
2943 continue;
2944
2945 // shader entry point
2946 if (HLSLShaderAttr *ShaderAttr = FD->getAttr<HLSLShaderAttr>()) {
2947 SetShaderStageContext(ShaderAttr->getType());
2948 RunOnFunction(FD);
2949 continue;
2950 }
2951 // exported library function
2952 // FIXME: replace this loop with external linkage check once issue #92071
2953 // is resolved
2954 bool isExport = FD->isInExportDeclContext();
2955 if (!isExport) {
2956 for (const auto *Redecl : FD->redecls()) {
2957 if (Redecl->isInExportDeclContext()) {
2958 isExport = true;
2959 break;
2960 }
2961 }
2962 }
2963 if (isExport) {
2964 SetUnknownShaderStageContext();
2965 RunOnFunction(FD);
2966 continue;
2967 }
2968 }
2969 }
2970}
2971
2972void DiagnoseHLSLAvailability::RunOnFunction(const FunctionDecl *FD) {
2973 assert(DeclsToScan.empty() && "DeclsToScan should be empty");
2974 DeclsToScan.push_back(FD);
2975
2976 while (!DeclsToScan.empty()) {
2977 // Take one decl from the stack and check it by traversing its AST.
2978 // For any CallExpr found during the traversal add it's callee to the top of
2979 // the stack to be processed next. Functions already processed are stored in
2980 // ScannedDecls.
2981 const FunctionDecl *FD = DeclsToScan.pop_back_val();
2982
2983 // Decl was already scanned
2984 const unsigned ScannedStages = GetScannedStages(FD);
2985 if (WasAlreadyScannedInCurrentStage(ScannedStages))
2986 continue;
2987
2988 ReportOnlyShaderStageIssues = !NeverBeenScanned(ScannedStages);
2989
2990 AddToScannedFunctions(FD);
2991 TraverseStmt(FD->getBody());
2992 }
2993}
2994
2995bool DiagnoseHLSLAvailability::HasMatchingEnvironmentOrNone(
2996 const AvailabilityAttr *AA) {
2997 const IdentifierInfo *IIEnvironment = AA->getEnvironment();
2998 if (!IIEnvironment)
2999 return true;
3000
3001 llvm::Triple::EnvironmentType CurrentEnv = GetCurrentShaderEnvironment();
3002 if (CurrentEnv == llvm::Triple::UnknownEnvironment)
3003 return false;
3004
3005 llvm::Triple::EnvironmentType AttrEnv =
3006 AvailabilityAttr::getEnvironmentType(IIEnvironment->getName());
3007
3008 return CurrentEnv == AttrEnv;
3009}
3010
// Returns the AvailabilityAttr of D (as written, not its effective attribute)
// whose effective platform name equals TargetPlatform. The first such
// attribute whose environment matches (or that has none) is returned at once.
// Otherwise the last platform-only match is returned, or nullptr if there is
// none.
3011const AvailabilityAttr *
3012DiagnoseHLSLAvailability::FindAvailabilityAttr(const Decl *D) {
3013 AvailabilityAttr const *PartialMatch = nullptr;
3014 // Check each AvailabilityAttr to find the one for this platform.
3015 // For multiple attributes with the same platform try to find one for this
3016 // environment.
3017 for (const auto *A : D->attrs()) {
3018 if (const auto *Avail = dyn_cast<AvailabilityAttr>(A)) {
3019 const AvailabilityAttr *EffectiveAvail = Avail->getEffectiveAttr();
3020 StringRef AttrPlatform = EffectiveAvail->getPlatform()->getName();
3021 StringRef TargetPlatform =
3023
3024 // Match the platform name.
3025 if (AttrPlatform == TargetPlatform) {
3026 // Find the best matching attribute for this environment
3027 if (HasMatchingEnvironmentOrNone(EffectiveAvail))
3028 return Avail;
// Platform matches but environment does not; remember it as a fallback.
3029 PartialMatch = Avail;
3030 }
3031 }
3032 }
3033 return PartialMatch;
3034}
3035
3036// Check availability against target shader model version and current shader
3037// stage and emit diagnostic
// Nothing is reported for an attribute without an environment in strict mode,
// or while only stage-specific issues are being reported. Nothing is reported
// for an attribute with an environment in the unknown-stage context.
// Otherwise, if the target version is below the introduced version or the
// environment does not match, a warning is emitted at Range, followed by a
// note at D's declaration.
3038void DiagnoseHLSLAvailability::CheckDeclAvailability(NamedDecl *D,
3039 const AvailabilityAttr *AA,
3040 SourceRange Range) {
3041
3042 const IdentifierInfo *IIEnv = AA->getEnvironment();
3043
3044 if (!IIEnv) {
3045 // The availability attribute does not have an environment -> it depends only
3046 // on shader model version and not on the specific shader stage.
3047
3048 // Skip emitting the diagnostics if the diagnostic mode is set to
3049 // strict (-fhlsl-strict-availability) because all relevant diagnostics
3050 // were already emitted in the DiagnoseUnguardedAvailability scan
3051 // (SemaAvailability.cpp).
3052 if (SemaRef.getLangOpts().HLSLStrictAvailability)
3053 return;
3054
3055 // Do not report shader-stage-independent issues if scanning a function
3056 // that was already scanned in a different shader stage context (they would
3057 // be duplicates).
3058 if (ReportOnlyShaderStageIssues)
3059 return;
3060
3061 } else {
3062 // The availability attribute has an environment -> we need to know
3063 // the current stage context to properly diagnose it.
3064 if (InUnknownShaderStageContext())
3065 return;
3066 }
3067
3068 // Check introduced version and if environment matches
3069 bool EnvironmentMatches = HasMatchingEnvironmentOrNone(AA);
3070 VersionTuple Introduced = AA->getIntroduced();
3071 VersionTuple TargetVersion =
3073
3074 if (TargetVersion >= Introduced && EnvironmentMatches)
3075 return;
3076
3077 // Emit diagnostic message
3078 const TargetInfo &TI = SemaRef.getASTContext().getTargetInfo();
3079 llvm::StringRef PlatformName(
3080 AvailabilityAttr::getPrettyPlatformName(TI.getPlatformName()));
3081
3082 llvm::StringRef CurrentEnvStr =
3083 llvm::Triple::getEnvironmentTypeName(GetCurrentShaderEnvironment());
3084
// NOTE(review): AA->getEnvironment() is the same value already held in IIEnv.
3085 llvm::StringRef AttrEnvStr =
3086 AA->getEnvironment() ? AA->getEnvironment()->getName() : "";
3087 bool UseEnvironment = !AttrEnvStr.empty();
3088
// Environment matches -> the version is too low (warn_hlsl_availability);
// environment mismatch -> the decl is unavailable in this stage.
3089 if (EnvironmentMatches) {
3090 SemaRef.Diag(Range.getBegin(), diag::warn_hlsl_availability)
3091 << Range << D << PlatformName << Introduced.getAsString()
3092 << UseEnvironment << CurrentEnvStr;
3093 } else {
3094 SemaRef.Diag(Range.getBegin(), diag::warn_hlsl_availability_unavailable)
3095 << Range << D;
3096 }
3097
3098 SemaRef.Diag(D->getLocation(), diag::note_partial_availability_specified_here)
3099 << D << PlatformName << Introduced.getAsString()
3100 << SemaRef.Context.getTargetInfo().getPlatformMinVersion().getAsString()
3101 << UseEnvironment << AttrEnvStr << CurrentEnvStr;
3102}
3103
3104} // namespace
3105
// End-of-translation-unit processing. If any declarations were collected in
// DefaultCBufferDecls, a default cbuffer decl is built from them, given an
// implicit CBuffer binding, added to the current lexical context, marked with
// HasValidPackoffset when any of them has a register(c#) slot, and handed to
// the AST consumer as a top-level decl. Availability diagnostics then run
// over TU in all cases.
3107 // process default CBuffer - create buffer layout struct and invoke codegen
3108 if (!DefaultCBufferDecls.empty()) {
3110 SemaRef.getASTContext(), SemaRef.getCurLexicalContext(),
3111 DefaultCBufferDecls);
3112 addImplicitBindingAttrToDecl(SemaRef, DefaultCBuffer, RegisterType::CBuffer,
3114 SemaRef.getCurLexicalContext()->addDecl(DefaultCBuffer);
3116
3117 // Set HasValidPackoffset if any of the decls has a register(c#) annotation.
3118 for (const Decl *VD : DefaultCBufferDecls) {
3119 const HLSLResourceBindingAttr *RBA =
3120 VD->getAttr<HLSLResourceBindingAttr>();
3121 if (RBA && RBA->hasRegisterSlot() &&
3122 RBA->getRegisterType() == HLSLResourceBindingAttr::RegisterType::C) {
3123 DefaultCBuffer->setHasValidPackoffset(true);
3124 break;
3125 }
3126 }
3127
3128 DeclGroupRef DG(DefaultCBuffer);
3129 SemaRef.Consumer.HandleTopLevelDecl(DG);
3130 }
3131 diagnoseAvailabilityViolations(TU);
3132}
3133
3134// For resource member access through a global struct array, verify that the
3135// array index selecting the struct element is a constant integer expression.
3136// Returns false if the member expression is invalid.
3138 assert((ME->getType()->isHLSLResourceRecord() ||
3140 "expected member expr to have resource record type or array of them");
3141
3142 // Walk the AST from MemberExpr to the VarDecl of the parent struct instance
3143 // and take note of any non-constant array indexing along the way. If the
3144 // VarDecl we find is a global variable, report error if there was any
3145 // non-constant array index in the resource member access along the way.
3146 const Expr *NonConstIndexExpr = nullptr;
3147 const Expr *E = ME->getBase();
3148 while (E) {
3149 if (const DeclRefExpr *DRE = dyn_cast<DeclRefExpr>(E)) {
3150 if (!NonConstIndexExpr)
3151 return true;
3152
// NOTE(review): cast<> asserts if the root DeclRefExpr names something other
// than a VarDecl.
3153 const VarDecl *VD = cast<VarDecl>(DRE->getDecl());
3154 if (!VD->hasGlobalStorage())
3155 return true;
3156
3157 SemaRef.Diag(NonConstIndexExpr->getExprLoc(),
3158 diag::err_hlsl_resource_member_array_access_not_constant);
3159 return false;
3160 }
3161
// Only subscripts, member accesses and implicit casts are expected between
// ME and its root variable; anything else is treated as unreachable.
3162 if (const auto *ASE = dyn_cast<ArraySubscriptExpr>(E)) {
3163 const Expr *IdxExpr = ASE->getIdx();
// A later (inner) non-constant index overwrites an earlier one, so the
// diagnosed index is the last one found on the walk.
3164 if (!IdxExpr->isIntegerConstantExpr(SemaRef.getASTContext()))
3165 NonConstIndexExpr = IdxExpr;
3166 E = ASE->getBase();
3167 } else if (const auto *SubME = dyn_cast<MemberExpr>(E)) {
3168 E = SubME->getBase();
3169 } else if (const auto *ICE = dyn_cast<ImplicitCastExpr>(E)) {
3170 E = ICE->getSubExpr();
3171 } else {
3172 llvm_unreachable("unexpected expr type in resource member access");
3173 }
3174 }
3175 return true;
3176}
3177
3179 CXXRecordDecl *RD) {
// Builds the canonical type `const Type &` with Type in the hlsl_constant
// address space, and looks up in RD the conversion function named after it.
// The lookup itself is asserted to succeed. Returns the first result accepted
// by the check inside the loop, or nullptr if none is accepted.
3180 QualType AddrSpaceType =
3181 SemaRef.Context.getCanonicalType(SemaRef.Context.getAddrSpaceQualType(
3182 Type.withConst(), LangAS::hlsl_constant));
3183 QualType ReturnTy = SemaRef.Context.getCanonicalType(
3184 SemaRef.Context.getLValueReferenceType(AddrSpaceType));
3185
// ReturnTy is already canonical, which is what CreateUnsafe expects.
3186 DeclarationName ConvName =
3187 SemaRef.Context.DeclarationNames.getCXXConversionFunctionName(
3188 CanQualType::CreateUnsafe(ReturnTy));
3189 LookupResult ConvR(SemaRef, ConvName, SourceLocation(),
3191 [[maybe_unused]] bool LookupSucceeded =
3192 SemaRef.LookupQualifiedName(ConvR, RD);
3193 assert(LookupSucceeded);
3194
3195 for (NamedDecl *D : ConvR) {
3197 return D;
3198 }
3199 return nullptr;
3200}
3201
// If BaseExpr's type carries a resource handle of class CBuffer, returns a
// member call to the ConstantBuffer conversion function for the handle's
// contained type. That function is required to exist (asserted). Returns
// std::nullopt when there is no handle or the handle is not a CBuffer.
3202std::optional<ExprResult>
3204 QualType BaseType = BaseExpr.get()->getType();
3205 const HLSLAttributedResourceType *ResTy =
3206 HLSLAttributedResourceType::findHandleTypeOnResource(
3207 BaseType.getTypePtr());
3208 if (!ResTy ||
3209 ResTy->getAttrs().ResourceClass != llvm::dxil::ResourceClass::CBuffer)
3210 return std::nullopt;
3211
3212 QualType TemplateType = ResTy->getContainedType();
3213
3214 NamedDecl *NamedConversionDecl = getConstantBufferConversionFunction(
3215 TemplateType, BaseType->getAsCXXRecordDecl());
3216 assert(NamedConversionDecl &&
3217 "Could not find conversion function for ConstantBuffer.");
// The lookup result may be a using-shadow decl; the call needs the real
// conversion declaration as well.
3218 auto *ConversionDecl =
3219 cast<CXXConversionDecl>(NamedConversionDecl->getUnderlyingDecl());
3220
3221 return SemaRef.BuildCXXMemberCallExpr(BaseExpr.get(), NamedConversionDecl,
3222 ConversionDecl,
3223 /*HadMultipleCandidates=*/false);
3224}
3225
// Runs the DiagnoseHLSLAvailability scan over TU, except in strict mode for
// targets whose environment is not Library.
3226void SemaHLSL::diagnoseAvailabilityViolations(TranslationUnitDecl *TU) {
3227 // Skip running the diagnostics scan if the diagnostic mode is
3228 // strict (-fhlsl-strict-availability) and the target shader stage is known
3229 // because all relevant diagnostics were already emitted in the
3230 // DiagnoseUnguardedAvailability scan (SemaAvailability.cpp).
3232 if (SemaRef.getLangOpts().HLSLStrictAvailability &&
3233 TI.getTriple().getEnvironment() != llvm::Triple::EnvironmentType::Library)
3234 return;
3235
3236 DiagnoseHLSLAvailability(SemaRef).RunOnTranslationUnit(TU);
3237}
3238
// Compares the type of every argument after the first against argument 0's
// type. At the first failing comparison it emits
// err_vec_builtin_incompatible_vector over the whole argument range and
// returns true. Requires at least two arguments (asserted).
3239static bool CheckAllArgsHaveSameType(Sema *S, CallExpr *TheCall) {
3240 assert(TheCall->getNumArgs() > 1);
3241 QualType ArgTy0 = TheCall->getArg(0)->getType();
3242
3243 for (unsigned I = 1, N = TheCall->getNumArgs(); I < N; ++I) {
3245 ArgTy0, TheCall->getArg(I)->getType())) {
3246 S->Diag(TheCall->getBeginLoc(), diag::err_vec_builtin_incompatible_vector)
3247 << TheCall->getDirectCallee() << /*useAllTerminology*/ true
3248 << SourceRange(TheCall->getArg(0)->getBeginLoc(),
3249 TheCall->getArg(N - 1)->getEndLoc());
3250 return true;
3251 }
3252 }
3253 return false;
3254}
3255
// When Arg's type is rejected by the comparison against ExpectedType, emits
// err_typecheck_convert_incompatible (ArgType vs ExpectedType) at Arg and
// returns true; otherwise returns false.
3257 QualType ArgType = Arg->getType();
3259 S->Diag(Arg->getBeginLoc(), diag::err_typecheck_convert_incompatible)
3260 << ArgType << ExpectedType << 1 << 0 << 0;
3261 return true;
3262 }
3263 return false;
3264}
3265
3267 Sema *S, CallExpr *TheCall,
3268 llvm::function_ref<bool(Sema *S, SourceLocation Loc, int ArgOrdinal,
3269 clang::QualType PassedType)>
3270 Check) {
// Applies Check to each argument in order, passing its begin location, its
// 1-based ordinal and its type. Returns true at the first argument for which
// Check returns true (Check is expected to have diagnosed it).
3271 for (unsigned I = 0; I < TheCall->getNumArgs(); ++I) {
3272 Expr *Arg = TheCall->getArg(I);
3273 if (Check(S, Arg->getBeginLoc(), I + 1, Arg->getType()))
3274 return true;
3275 }
3276 return false;
3277}
3278
3280 int ArgOrdinal,
3281 clang::QualType PassedType) {
// Accepts a 32-bit float scalar or a vector of 32-bit floats. Anything else is
// diagnosed with err_builtin_invalid_arg_type and reported as an error.
3282 clang::QualType BaseType =
3283 PassedType->isVectorType()
3284 ? PassedType->castAs<clang::VectorType>()->getElementType()
3285 : PassedType;
3286 if (!BaseType->isFloat32Type())
3287 return S->Diag(Loc, diag::err_builtin_invalid_arg_type)
3288 << ArgOrdinal << /* scalar or vector of */ 5 << /* no int */ 0
3289 << /* float */ 1 << PassedType;
3290 return false;
3291}
3292
3294 int ArgOrdinal,
3295 clang::QualType PassedType) {
// Accepts half or 32-bit float scalars, or vectors of them. Anything else is
// diagnosed with err_builtin_invalid_arg_type and reported as an error.
3296 clang::QualType BaseType =
3297 PassedType->isVectorType()
3298 ? PassedType->castAs<clang::VectorType>()->getElementType()
3299 : PassedType;
3300 if (!BaseType->isHalfType() && !BaseType->isFloat32Type())
3301 return S->Diag(Loc, diag::err_builtin_invalid_arg_type)
3302 << ArgOrdinal << /* scalar or vector of */ 5 << /* no int */ 0
3303 << /* half or float */ 2 << PassedType;
3304 return false;
3305}
3306
3308 int ArgOrdinal,
3309 clang::QualType PassedType) {
// Accepts a double scalar, or a vector or matrix whose element type is double.
// Anything else is diagnosed with err_builtin_requires_double_type.
3310 clang::QualType BaseType =
3311 PassedType->isVectorType()
3312 ? PassedType->castAs<clang::VectorType>()->getElementType()
3313 : PassedType->isMatrixType()
3314 ? PassedType->castAs<clang::MatrixType>()->getElementType()
3315 : PassedType;
3316 if (!BaseType->isDoubleType()) {
3317 // FIXME: adopt standard `err_builtin_invalid_arg_type` instead of using
3318 // this custom error.
3319 return S->Diag(Loc, diag::err_builtin_requires_double_type)
3320 << ArgOrdinal << PassedType;
3321 }
3322
3323 return false;
3324}
3325
// Checks that argument ArgIndex, with casts stripped, passes the
// isModifiableLvalue comparison below. If it does not, emits
// error_hlsl_inout_lvalue and returns true. OrigLoc starts as the argument's
// location and is passed by pointer to isModifiableLvalue, which may update
// it before it is used for the diagnostic.
3326static bool CheckModifiableLValue(Sema *S, CallExpr *TheCall,
3327 unsigned ArgIndex) {
3328 auto *Arg = TheCall->getArg(ArgIndex);
3329 SourceLocation OrigLoc = Arg->getExprLoc();
3330 if (Arg->IgnoreCasts()->isModifiableLvalue(S->Context, &OrigLoc) ==
3332 return false;
3333 S->Diag(OrigLoc, diag::error_hlsl_inout_lvalue) << Arg << 0;
3334 return true;
3335}
3336
3337static bool CheckNoDoubleVectors(Sema *S, SourceLocation Loc, int ArgOrdinal,
3338 clang::QualType PassedType) {
3339 const auto *VecTy = PassedType->getAs<VectorType>();
3340 if (!VecTy)
3341 return false;
3342
3343 if (VecTy->getElementType()->isDoubleType())
3344 return S->Diag(Loc, diag::err_builtin_invalid_arg_type)
3345 << ArgOrdinal << /* scalar */ 1 << /* no int */ 0 << /* fp */ 1
3346 << PassedType;
3347 return false;
3348}
3349
3351 int ArgOrdinal,
3352 clang::QualType PassedType) {
// Accepts any type with an integer or floating-point representation (scalar
// or vector); anything else is diagnosed with err_builtin_invalid_arg_type.
3353 if (!PassedType->hasIntegerRepresentation() &&
3354 !PassedType->hasFloatingRepresentation())
3355 return S->Diag(Loc, diag::err_builtin_invalid_arg_type)
3356 << ArgOrdinal << /* scalar or vector of */ 5 << /* integer */ 1
3357 << /* fp */ 1 << PassedType;
3358 return false;
3359}
3360
3362 int ArgOrdinal,
3363 clang::QualType PassedType) {
// Accepts only vectors whose element type is an unsigned integer. Scalars
// (even unsigned ones) and other vectors are diagnosed with
// err_builtin_invalid_arg_type.
3364 if (auto *VecTy = PassedType->getAs<VectorType>())
3365 if (VecTy->getElementType()->isUnsignedIntegerType())
3366 return false;
3367
3368 return S->Diag(Loc, diag::err_builtin_invalid_arg_type)
3369 << ArgOrdinal << /* vector of */ 4 << /* uint */ 3 << /* no fp */ 0
3370 << PassedType;
3371}
3372
3373// Checks for unsigned integers of all sizes, as a scalar or as vector elements.
// Types without an unsigned integer representation are diagnosed with
// err_builtin_invalid_arg_type.
3375 int ArgOrdinal,
3376 clang::QualType PassedType) {
3377 if (!PassedType->hasUnsignedIntegerRepresentation())
3378 return S->Diag(Loc, diag::err_builtin_invalid_arg_type)
3379 << ArgOrdinal << /* scalar or vector of */ 5 << /* unsigned int */ 3
3380 << /* no fp */ 0 << PassedType;
3381 return false;
3382}
3383
// Checks that argument 0's element type (or its own type, if scalar) has a bit
// width equal to Width. On mismatch, emits err_integer_incorrect_bit_count and
// returns true.
// NOTE(review): ArgOrdinal is not used; the check and the diagnostic location
// always refer to argument 0.
3384static bool CheckExpectedBitWidth(Sema *S, CallExpr *TheCall,
3385 unsigned ArgOrdinal, unsigned Width) {
3386 QualType ArgTy = TheCall->getArg(0)->getType();
3387 if (auto *VTy = ArgTy->getAs<VectorType>())
3388 ArgTy = VTy->getElementType();
3389 // ensure arg type has expected bit width
3390 uint64_t ElementBitCount =
3392 if (ElementBitCount != Width) {
3393 S->Diag(TheCall->getArg(0)->getBeginLoc(),
3394 diag::err_integer_incorrect_bit_count)
3395 << Width << ElementBitCount;
3396 return true;
3397 }
3398 return false;
3399}
3400
3402 QualType ReturnType) {
// Sets the call's type to ReturnType. When argument 0 is a vector, the type is
// instead an ext vector of ReturnType with argument 0's element count.
3403 auto *VecTyA = TheCall->getArg(0)->getType()->getAs<VectorType>();
3404 if (VecTyA)
3405 ReturnType =
3406 S->Context.getExtVectorType(ReturnType, VecTyA->getNumElements());
3407
3408 TheCall->setType(ReturnType);
3409}
3410
3411static bool CheckScalarOrVector(Sema *S, CallExpr *TheCall, QualType Scalar,
3412 unsigned ArgIndex) {
3413 assert(TheCall->getNumArgs() >= ArgIndex);
3414 QualType ArgType = TheCall->getArg(ArgIndex)->getType();
3415 auto *VTy = ArgType->getAs<VectorType>();
3416 // not the scalar or vector<scalar>
3417 if (!(S->Context.hasSameUnqualifiedType(ArgType, Scalar) ||
3418 (VTy &&
3419 S->Context.hasSameUnqualifiedType(VTy->getElementType(), Scalar)))) {
3420 S->Diag(TheCall->getArg(0)->getBeginLoc(),
3421 diag::err_typecheck_expect_scalar_or_vector)
3422 << ArgType << Scalar;
3423 return true;
3424 }
3425 return false;
3426}
3427
3429 QualType Scalar, unsigned ArgIndex) {
// Accepts argument ArgIndex when it is Scalar, a vector of Scalar, or a
// constant matrix of Scalar (qualifiers ignored). Otherwise diagnoses
// err_typecheck_expect_scalar_or_vector_or_matrix at that argument and
// returns true.
3430 assert(TheCall->getNumArgs() > ArgIndex);
3431
3432 Expr *Arg = TheCall->getArg(ArgIndex);
3433 QualType ArgType = Arg->getType();
3434
3435 // Scalar: T
3436 if (S->Context.hasSameUnqualifiedType(ArgType, Scalar))
3437 return false;
3438
3439 // Vector: vector<T>
3440 if (const auto *VTy = ArgType->getAs<VectorType>()) {
3441 if (S->Context.hasSameUnqualifiedType(VTy->getElementType(), Scalar))
3442 return false;
3443 }
3444
3445 // Matrix: ConstantMatrixType with element type T
3446 if (const auto *MTy = ArgType->getAs<ConstantMatrixType>()) {
3447 if (S->Context.hasSameUnqualifiedType(MTy->getElementType(), Scalar))
3448 return false;
3449 }
3450
3451 // Not a scalar/vector/matrix-of-scalar
3452 S->Diag(Arg->getBeginLoc(),
3453 diag::err_typecheck_expect_scalar_or_vector_or_matrix)
3454 << ArgType << Scalar;
3455 return true;
3456}
3457
3458static bool CheckAnyScalarOrVector(Sema *S, CallExpr *TheCall,
3459 unsigned ArgIndex) {
3460 assert(TheCall->getNumArgs() >= ArgIndex);
3461 QualType ArgType = TheCall->getArg(ArgIndex)->getType();
3462 auto *VTy = ArgType->getAs<VectorType>();
3463 // not the scalar or vector<scalar>
3464 if (!(ArgType->isScalarType() ||
3465 (VTy && VTy->getElementType()->isScalarType()))) {
3466 S->Diag(TheCall->getArg(0)->getBeginLoc(),
3467 diag::err_typecheck_expect_any_scalar_or_vector)
3468 << ArgType << 1;
3469 return true;
3470 }
3471 return false;
3472}
3473
3474// Check that the argument is not a bool or vector<bool>
3475// Returns true on error
3477 unsigned ArgIndex) {
3478 QualType BoolType = S->getASTContext().BoolTy;
3479 assert(ArgIndex < TheCall->getNumArgs());
3480 QualType ArgType = TheCall->getArg(ArgIndex)->getType();
3481 auto *VTy = ArgType->getAs<VectorType>();
3482 // is the bool or vector<bool>
3483 if (S->Context.hasSameUnqualifiedType(ArgType, BoolType) ||
3484 (VTy &&
3485 S->Context.hasSameUnqualifiedType(VTy->getElementType(), BoolType))) {
// NOTE(review): the diagnostic is anchored at argument 0 even when ArgIndex
// names a different argument.
3486 S->Diag(TheCall->getArg(0)->getBeginLoc(),
3487 diag::err_typecheck_expect_any_scalar_or_vector)
3488 << ArgType << 0;
3489 return true;
3490 }
3491 return false;
3492}
3493
3494static bool CheckWaveActive(Sema *S, CallExpr *TheCall) {
3495 if (CheckNotBoolScalarOrVector(S, TheCall, 0))
3496 return true;
3497 return false;
3498}
3499
3500static bool CheckWavePrefix(Sema *S, CallExpr *TheCall) {
3501 if (CheckNotBoolScalarOrVector(S, TheCall, 0))
3502 return true;
3503 return false;
3504}
3505
3506static bool CheckBoolSelect(Sema *S, CallExpr *TheCall) {
3507 assert(TheCall->getNumArgs() == 3);
3508 Expr *Arg1 = TheCall->getArg(1);
3509 Expr *Arg2 = TheCall->getArg(2);
3510 if (!S->Context.hasSameUnqualifiedType(Arg1->getType(), Arg2->getType())) {
3511 S->Diag(TheCall->getBeginLoc(),
3512 diag::err_typecheck_call_different_arg_types)
3513 << Arg1->getType() << Arg2->getType() << Arg1->getSourceRange()
3514 << Arg2->getSourceRange();
3515 return true;
3516 }
3517
3518 TheCall->setType(Arg1->getType());
3519 return false;
3520}
3521
3522static bool CheckVectorSelect(Sema *S, CallExpr *TheCall) {
3523 assert(TheCall->getNumArgs() == 3);
3524 Expr *Arg1 = TheCall->getArg(1);
3525 QualType Arg1Ty = Arg1->getType();
3526 Expr *Arg2 = TheCall->getArg(2);
3527 QualType Arg2Ty = Arg2->getType();
3528
3529 QualType Arg1ScalarTy = Arg1Ty;
3530 if (auto VTy = Arg1ScalarTy->getAs<VectorType>())
3531 Arg1ScalarTy = VTy->getElementType();
3532
3533 QualType Arg2ScalarTy = Arg2Ty;
3534 if (auto VTy = Arg2ScalarTy->getAs<VectorType>())
3535 Arg2ScalarTy = VTy->getElementType();
3536
3537 if (!S->Context.hasSameUnqualifiedType(Arg1ScalarTy, Arg2ScalarTy))
3538 S->Diag(Arg1->getBeginLoc(), diag::err_hlsl_builtin_scalar_vector_mismatch)
3539 << /* second and third */ 1 << TheCall->getCallee() << Arg1Ty << Arg2Ty;
3540
3541 QualType Arg0Ty = TheCall->getArg(0)->getType();
3542 unsigned Arg0Length = Arg0Ty->getAs<VectorType>()->getNumElements();
3543 unsigned Arg1Length = Arg1Ty->isVectorType()
3544 ? Arg1Ty->getAs<VectorType>()->getNumElements()
3545 : 0;
3546 unsigned Arg2Length = Arg2Ty->isVectorType()
3547 ? Arg2Ty->getAs<VectorType>()->getNumElements()
3548 : 0;
3549 if (Arg1Length > 0 && Arg0Length != Arg1Length) {
3550 S->Diag(TheCall->getBeginLoc(),
3551 diag::err_typecheck_vector_lengths_not_equal)
3552 << Arg0Ty << Arg1Ty << TheCall->getArg(0)->getSourceRange()
3553 << Arg1->getSourceRange();
3554 return true;
3555 }
3556
3557 if (Arg2Length > 0 && Arg0Length != Arg2Length) {
3558 S->Diag(TheCall->getBeginLoc(),
3559 diag::err_typecheck_vector_lengths_not_equal)
3560 << Arg0Ty << Arg2Ty << TheCall->getArg(0)->getSourceRange()
3561 << Arg2->getSourceRange();
3562 return true;
3563 }
3564
3565 TheCall->setType(
3566 S->getASTContext().getExtVectorType(Arg1ScalarTy, Arg0Length));
3567 return false;
3568}
3569
3570static bool CheckIndexType(Sema *S, CallExpr *TheCall, unsigned IndexArgIndex) {
3571 assert(TheCall->getNumArgs() > IndexArgIndex && "Index argument missing");
3572 QualType ArgType = TheCall->getArg(IndexArgIndex)->getType();
3573 QualType IndexTy = ArgType;
3574 unsigned int ActualDim = 1;
3575 if (const auto *VTy = IndexTy->getAs<VectorType>()) {
3576 ActualDim = VTy->getNumElements();
3577 IndexTy = VTy->getElementType();
3578 }
3579 if (!IndexTy->isIntegerType()) {
3580 S->Diag(TheCall->getArg(IndexArgIndex)->getBeginLoc(),
3581 diag::err_typecheck_expect_int)
3582 << ArgType;
3583 return true;
3584 }
3585
3586 QualType ResourceArgTy = TheCall->getArg(0)->getType();
3587 const HLSLAttributedResourceType *ResTy =
3588 ResourceArgTy.getTypePtr()->getAs<HLSLAttributedResourceType>();
3589 assert(ResTy && "Resource argument must be a resource");
3590 HLSLAttributedResourceType::Attributes ResAttrs = ResTy->getAttrs();
3591
3592 unsigned int ExpectedDim = 1;
3593 if (ResAttrs.ResourceDimension != llvm::dxil::ResourceDimension::Unknown)
3594 ExpectedDim = getResourceDimensions(ResAttrs.ResourceDimension);
3595
3596 if (ActualDim != ExpectedDim) {
3597 S->Diag(TheCall->getArg(IndexArgIndex)->getBeginLoc(),
3598 diag::err_hlsl_builtin_resource_coordinate_dimension_mismatch)
3599 << cast<NamedDecl>(TheCall->getCalleeDecl()) << ExpectedDim
3600 << ActualDim;
3601 return true;
3602 }
3603
3604 return false;
3605}
3606
3608 Sema *S, CallExpr *TheCall, unsigned ArgIndex,
3609 llvm::function_ref<bool(const HLSLAttributedResourceType *ResType)> Check =
3610 nullptr) {
// Requires argument ArgIndex to be an HLSLAttributedResourceType; otherwise
// diagnoses err_typecheck_expect_hlsl_resource. If Check is provided and
// returns true for the resource type (true means "invalid"),
// err_invalid_hlsl_resource_type is diagnosed. Returns true on error.
// NOTE(review): getArg(ArgIndex) needs ArgIndex < getNumArgs(); the `>=` in the
// assert below admits ArgIndex == getNumArgs().
3611 assert(TheCall->getNumArgs() >= ArgIndex);
3612 QualType ArgType = TheCall->getArg(ArgIndex)->getType();
3613 const HLSLAttributedResourceType *ResTy =
3614 ArgType.getTypePtr()->getAs<HLSLAttributedResourceType>();
3615 if (!ResTy) {
3616 S->Diag(TheCall->getArg(ArgIndex)->getBeginLoc(),
3617 diag::err_typecheck_expect_hlsl_resource)
3618 << ArgType;
3619 return true;
3620 }
3621 if (Check && Check(ResTy)) {
3622 S->Diag(TheCall->getArg(ArgIndex)->getExprLoc(),
3623 diag::err_invalid_hlsl_resource_type)
3624 << ArgType;
3625 return true;
3626 }
3627 return false;
3628}
3629
// Checks only the element count of PassedType (1 for a non-vector) against
// ExpectedCount; the element type is not checked. On mismatch, emits
// err_typecheck_convert_incompatible at Loc, naming an ext vector of BaseType
// with ExpectedCount elements as the expected type, and returns true.
3630static bool CheckVectorElementCount(Sema *S, QualType PassedType,
3631 QualType BaseType, unsigned ExpectedCount,
3632 SourceLocation Loc) {
3633 unsigned PassedCount = 1;
3634 if (const auto *VecTy = PassedType->getAs<VectorType>())
3635 PassedCount = VecTy->getNumElements();
3636
3637 if (PassedCount != ExpectedCount) {
3639 S->Context.getExtVectorType(BaseType, ExpectedCount);
3640 S->Diag(Loc, diag::err_typecheck_convert_incompatible)
3641 << PassedType << ExpectedType << 1 << 0 << 0;
3642 return true;
3643 }
3644 return false;
3645}
3646
/// Which texture sampling builtin CheckSamplingBuiltin is validating; the kind
/// selects the accepted argument-count range and the extra operand checks.
enum class SampleKind {
  Sample,
  Bias,
  Grad,
  Level,
  Cmp,
  CmpLevelZero,
};
3648
// Shared checks for sampling-style builtins, in this order:
// - argument 0 must be a resource whose dimension is not Unknown;
// - argument 1 must be a resource of class Sampler;
// - argument 2 must have exactly as many elements as the texture's dimension
//   count.
// Returns true on error.
3650 // Check the texture handle.
3651 if (CheckResourceHandle(&S, TheCall, 0,
3652 [](const HLSLAttributedResourceType *ResType) {
3653 return ResType->getAttrs().ResourceDimension ==
3654 llvm::dxil::ResourceDimension::Unknown;
3655 }))
3656 return true;
3657
3658 // Check the sampler handle.
3659 if (CheckResourceHandle(&S, TheCall, 1,
3660 [](const HLSLAttributedResourceType *ResType) {
3661 return ResType->getAttrs().ResourceClass !=
3662 llvm::hlsl::ResourceClass::Sampler;
3663 }))
3664 return true;
3665
3666 auto *ResourceTy =
3667 TheCall->getArg(0)->getType()->castAs<HLSLAttributedResourceType>();
3668
3669 // Check the location.
// NOTE(review): unlike the offset checks in the gather/load helpers, the
// location mismatch is reported at the call's begin location rather than at
// argument 2.
3670 unsigned ExpectedDim =
3671 getResourceDimensions(ResourceTy->getAttrs().ResourceDimension);
3672 if (CheckVectorElementCount(&S, TheCall->getArg(2)->getType(),
3673 S.Context.FloatTy, ExpectedDim,
3674 TheCall->getBeginLoc()))
3675 return true;
3676
3677 return false;
3678}
3679
3680static bool CheckCalculateLodBuiltin(Sema &S, CallExpr *TheCall) {
3681 if (S.checkArgCount(TheCall, 3))
3682 return true;
3683
3684 if (CheckTextureSamplerAndLocation(S, TheCall))
3685 return true;
3686
3687 TheCall->setType(S.Context.FloatTy);
3688 return false;
3689}
3690
3691static bool CheckGatherBuiltin(Sema &S, CallExpr *TheCall, bool IsCmp) {
3692 if (S.checkArgCountRange(TheCall, IsCmp ? 5 : 4, IsCmp ? 6 : 5))
3693 return true;
3694
3695 if (CheckTextureSamplerAndLocation(S, TheCall))
3696 return true;
3697
3698 unsigned NextIdx = 3;
3699 if (IsCmp) {
3700 // Check the compare value.
3701 QualType CmpTy = TheCall->getArg(NextIdx)->getType();
3702 if (!CmpTy->isFloatingType() || CmpTy->isVectorType()) {
3703 S.Diag(TheCall->getArg(NextIdx)->getBeginLoc(),
3704 diag::err_typecheck_convert_incompatible)
3705 << CmpTy << S.Context.FloatTy << 1 << 0 << 0;
3706 return true;
3707 }
3708 NextIdx++;
3709 }
3710
3711 // Check the component operand.
3712 Expr *ComponentArg = TheCall->getArg(NextIdx);
3713 QualType ComponentTy = ComponentArg->getType();
3714 if (!ComponentTy->isIntegerType() || ComponentTy->isVectorType()) {
3715 S.Diag(ComponentArg->getBeginLoc(),
3716 diag::err_typecheck_convert_incompatible)
3717 << ComponentTy << S.Context.UnsignedIntTy << 1 << 0 << 0;
3718 return true;
3719 }
3720
3721 // GatherCmp operations on Vulkan target must use component 0 (Red).
3722 if (IsCmp && S.getASTContext().getTargetInfo().getTriple().isSPIRV()) {
3723 std::optional<llvm::APSInt> ComponentOpt =
3724 ComponentArg->getIntegerConstantExpr(S.getASTContext());
3725 if (ComponentOpt) {
3726 int64_t ComponentVal = ComponentOpt->getSExtValue();
3727 if (ComponentVal != 0) {
3728 // Issue an error if the component is not 0 (Red).
3729 // 0 -> Red, 1 -> Green, 2 -> Blue, 3 -> Alpha
3730 assert(ComponentVal >= 0 && ComponentVal <= 3 &&
3731 "The component is not in the expected range.");
3732 S.Diag(ComponentArg->getBeginLoc(),
3733 diag::err_hlsl_gathercmp_invalid_component)
3734 << ComponentVal;
3735 return true;
3736 }
3737 }
3738 }
3739
3740 NextIdx++;
3741
3742 // Check the offset operand.
3743 const HLSLAttributedResourceType *ResourceTy =
3744 TheCall->getArg(0)->getType()->castAs<HLSLAttributedResourceType>();
3745 if (TheCall->getNumArgs() > NextIdx) {
3746 unsigned ExpectedDim =
3747 getResourceDimensions(ResourceTy->getAttrs().ResourceDimension);
3748 if (CheckVectorElementCount(&S, TheCall->getArg(NextIdx)->getType(),
3749 S.Context.IntTy, ExpectedDim,
3750 TheCall->getArg(NextIdx)->getBeginLoc()))
3751 return true;
3752 NextIdx++;
3753 }
3754
3755 assert(ResourceTy->hasContainedType() &&
3756 "Expecting a contained type for resource with a dimension "
3757 "attribute.");
3758 QualType ReturnType = ResourceTy->getContainedType();
3759
3760 if (IsCmp) {
3761 if (!ReturnType->hasFloatingRepresentation()) {
3762 S.Diag(TheCall->getBeginLoc(), diag::err_hlsl_samplecmp_requires_float);
3763 return true;
3764 }
3765 }
3766
3767 if (const auto *VecTy = ReturnType->getAs<VectorType>())
3768 ReturnType = VecTy->getElementType();
3769 ReturnType = S.Context.getExtVectorType(ReturnType, 4);
3770
3771 TheCall->setType(ReturnType);
3772
3773 return false;
3774}
3775static bool CheckLoadLevelBuiltin(Sema &S, CallExpr *TheCall) {
3776 if (S.checkArgCountRange(TheCall, 2, 3))
3777 return true;
3778
3779 // Check the texture handle.
3780 if (CheckResourceHandle(&S, TheCall, 0,
3781 [](const HLSLAttributedResourceType *ResType) {
3782 return ResType->getAttrs().ResourceDimension ==
3783 llvm::dxil::ResourceDimension::Unknown;
3784 }))
3785 return true;
3786
3787 auto *ResourceTy =
3788 TheCall->getArg(0)->getType()->castAs<HLSLAttributedResourceType>();
3789
3790 // Check the location + lod (int3 for Texture2D).
3791 unsigned ExpectedDim =
3792 getResourceDimensions(ResourceTy->getAttrs().ResourceDimension);
3793 QualType CoordLODTy = TheCall->getArg(1)->getType();
3794 if (CheckVectorElementCount(&S, CoordLODTy, S.Context.IntTy, ExpectedDim + 1,
3795 TheCall->getArg(1)->getBeginLoc()))
3796 return true;
3797
3798 QualType EltTy = CoordLODTy;
3799 if (const auto *VTy = EltTy->getAs<VectorType>())
3800 EltTy = VTy->getElementType();
3801 if (!EltTy->isIntegerType()) {
3802 S.Diag(TheCall->getArg(1)->getBeginLoc(), diag::err_typecheck_expect_int)
3803 << CoordLODTy;
3804 return true;
3805 }
3806
3807 // Check the offset operand.
3808 if (TheCall->getNumArgs() > 2) {
3809 if (CheckVectorElementCount(&S, TheCall->getArg(2)->getType(),
3810 S.Context.IntTy, ExpectedDim,
3811 TheCall->getArg(2)->getBeginLoc()))
3812 return true;
3813 }
3814
3815 TheCall->setType(ResourceTy->getContainedType());
3816 return false;
3817}
3818
3819static bool CheckSamplingBuiltin(Sema &S, CallExpr *TheCall, SampleKind Kind) {
3820 unsigned MinArgs, MaxArgs;
3821 if (Kind == SampleKind::Sample) {
3822 MinArgs = 3;
3823 MaxArgs = 5;
3824 } else if (Kind == SampleKind::Bias) {
3825 MinArgs = 4;
3826 MaxArgs = 6;
3827 } else if (Kind == SampleKind::Grad) {
3828 MinArgs = 5;
3829 MaxArgs = 7;
3830 } else if (Kind == SampleKind::Level) {
3831 MinArgs = 4;
3832 MaxArgs = 5;
3833 } else if (Kind == SampleKind::Cmp) {
3834 MinArgs = 4;
3835 MaxArgs = 6;
3836 } else {
3837 assert(Kind == SampleKind::CmpLevelZero);
3838 MinArgs = 4;
3839 MaxArgs = 5;
3840 }
3841
3842 if (S.checkArgCountRange(TheCall, MinArgs, MaxArgs))
3843 return true;
3844
3845 if (CheckTextureSamplerAndLocation(S, TheCall))
3846 return true;
3847
3848 const HLSLAttributedResourceType *ResourceTy =
3849 TheCall->getArg(0)->getType()->castAs<HLSLAttributedResourceType>();
3850 unsigned ExpectedDim =
3851 getResourceDimensions(ResourceTy->getAttrs().ResourceDimension);
3852
3853 unsigned NextIdx = 3;
3854 if (Kind == SampleKind::Bias || Kind == SampleKind::Level ||
3855 Kind == SampleKind::Cmp || Kind == SampleKind::CmpLevelZero) {
3856 // Check the bias, lod level, or compare value, depending on the kind.
3857 // All of them must be a scalar float value.
3858 QualType BiasOrLODOrCmpTy = TheCall->getArg(NextIdx)->getType();
3859 if (!BiasOrLODOrCmpTy->isFloatingType() ||
3860 BiasOrLODOrCmpTy->isVectorType()) {
3861 S.Diag(TheCall->getArg(NextIdx)->getBeginLoc(),
3862 diag::err_typecheck_convert_incompatible)
3863 << BiasOrLODOrCmpTy << S.Context.FloatTy << 1 << 0 << 0;
3864 return true;
3865 }
3866 NextIdx++;
3867 } else if (Kind == SampleKind::Grad) {
3868 // Check the DDX operand.
3869 if (CheckVectorElementCount(&S, TheCall->getArg(NextIdx)->getType(),
3870 S.Context.FloatTy, ExpectedDim,
3871 TheCall->getArg(NextIdx)->getBeginLoc()))
3872 return true;
3873
3874 // Check the DDY operand.
3875 if (CheckVectorElementCount(&S, TheCall->getArg(NextIdx + 1)->getType(),
3876 S.Context.FloatTy, ExpectedDim,
3877 TheCall->getArg(NextIdx + 1)->getBeginLoc()))
3878 return true;
3879 NextIdx += 2;
3880 }
3881
3882 // Check the offset operand.
3883 if (TheCall->getNumArgs() > NextIdx) {
3884 if (CheckVectorElementCount(&S, TheCall->getArg(NextIdx)->getType(),
3885 S.Context.IntTy, ExpectedDim,
3886 TheCall->getArg(NextIdx)->getBeginLoc()))
3887 return true;
3888 NextIdx++;
3889 }
3890
3891 // Check the clamp operand.
3892 if (Kind != SampleKind::Level && Kind != SampleKind::CmpLevelZero &&
3893 TheCall->getNumArgs() > NextIdx) {
3894 QualType ClampTy = TheCall->getArg(NextIdx)->getType();
3895 if (!ClampTy->isFloatingType() || ClampTy->isVectorType()) {
3896 S.Diag(TheCall->getArg(NextIdx)->getBeginLoc(),
3897 diag::err_typecheck_convert_incompatible)
3898 << ClampTy << S.Context.FloatTy << 1 << 0 << 0;
3899 return true;
3900 }
3901 }
3902
3903 assert(ResourceTy->hasContainedType() &&
3904 "Expecting a contained type for resource with a dimension "
3905 "attribute.");
3906 QualType ReturnType = ResourceTy->getContainedType();
3907 if (Kind == SampleKind::Cmp || Kind == SampleKind::CmpLevelZero) {
3908 if (!ReturnType->hasFloatingRepresentation()) {
3909 S.Diag(TheCall->getBeginLoc(), diag::err_hlsl_samplecmp_requires_float);
3910 return true;
3911 }
3912 ReturnType = S.Context.FloatTy;
3913 }
3914 TheCall->setType(ReturnType);
3915
3916 return false;
3917}
3918
3919// Note: returning true in this case results in CheckBuiltinFunctionCall
3920// returning an ExprError
3921bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
3922 switch (BuiltinID) {
3923 case Builtin::BI__builtin_hlsl_adduint64: {
3924 if (SemaRef.checkArgCount(TheCall, 2))
3925 return true;
3926
3927 if (CheckAllArgTypesAreCorrect(&SemaRef, TheCall,
3929 return true;
3930
3931 // ensure arg integers are 32-bits
3932 if (CheckExpectedBitWidth(&SemaRef, TheCall, 0, 32))
3933 return true;
3934
3935 // ensure both args are vectors of total bit size of a multiple of 64
3936 auto *VTy = TheCall->getArg(0)->getType()->getAs<VectorType>();
3937 int NumElementsArg = VTy->getNumElements();
3938 if (NumElementsArg != 2 && NumElementsArg != 4) {
3939 SemaRef.Diag(TheCall->getBeginLoc(), diag::err_vector_incorrect_bit_count)
3940 << 1 /*a multiple of*/ << 64 << NumElementsArg * 32;
3941 return true;
3942 }
3943
3944 // ensure first arg and second arg have the same type
3945 if (CheckAllArgsHaveSameType(&SemaRef, TheCall))
3946 return true;
3947
3948 ExprResult A = TheCall->getArg(0);
3949 QualType ArgTyA = A.get()->getType();
3950 // return type is the same as the input type
3951 TheCall->setType(ArgTyA);
3952 break;
3953 }
3954 case Builtin::BI__builtin_hlsl_resource_getpointer: {
3955 if (SemaRef.checkArgCountRange(TheCall, 1, 2) ||
3956 CheckResourceHandle(&SemaRef, TheCall, 0) ||
3957 (TheCall->getNumArgs() == 2 && CheckIndexType(&SemaRef, TheCall, 1)))
3958 return true;
3959
3960 auto *ResourceTy =
3961 TheCall->getArg(0)->getType()->castAs<HLSLAttributedResourceType>();
3962 QualType ContainedTy = ResourceTy->getContainedType();
3963 auto ReturnType = SemaRef.Context.getAddrSpaceQualType(
3964 ContainedTy,
3965 getLangASFromResourceClass(ResourceTy->getAttrs().ResourceClass));
3966 ReturnType = SemaRef.Context.getPointerType(ReturnType);
3967 TheCall->setType(ReturnType);
3968
3969 break;
3970 }
3971 case Builtin::BI__builtin_hlsl_resource_getpointer_typed: {
3972 if (SemaRef.checkArgCount(TheCall, 3) ||
3973 CheckResourceHandle(&SemaRef, TheCall, 0) ||
3974 CheckIndexType(&SemaRef, TheCall, 1))
3975 return true;
3976
3977 QualType ElementTy = TheCall->getArg(2)->getType();
3978 assert(ElementTy->isPointerType() &&
3979 "expected pointer type for second argument");
3980 ElementTy = ElementTy->getPointeeType();
3981
3982 // Reject array types
3983 if (ElementTy->isArrayType())
3984 return SemaRef.Diag(
3985 cast<FunctionDecl>(SemaRef.CurContext)->getPointOfInstantiation(),
3986 diag::err_invalid_use_of_array_type);
3987
3988 auto *ResourceTy =
3989 TheCall->getArg(0)->getType()->castAs<HLSLAttributedResourceType>();
3990 auto ReturnType = SemaRef.Context.getAddrSpaceQualType(
3991 ElementTy,
3992 getLangASFromResourceClass(ResourceTy->getAttrs().ResourceClass));
3993 ReturnType = SemaRef.Context.getPointerType(ReturnType);
3994 TheCall->setType(ReturnType);
3995
3996 break;
3997 }
3998 case Builtin::BI__builtin_hlsl_resource_load_with_status: {
3999 if (SemaRef.checkArgCount(TheCall, 3) ||
4000 CheckResourceHandle(&SemaRef, TheCall, 0) ||
4001 CheckArgTypeMatches(&SemaRef, TheCall->getArg(1),
4002 SemaRef.getASTContext().UnsignedIntTy) ||
4003 CheckArgTypeMatches(&SemaRef, TheCall->getArg(2),
4004 SemaRef.getASTContext().UnsignedIntTy) ||
4005 CheckModifiableLValue(&SemaRef, TheCall, 2))
4006 return true;
4007
4008 auto *ResourceTy =
4009 TheCall->getArg(0)->getType()->castAs<HLSLAttributedResourceType>();
4010 QualType ReturnType = ResourceTy->getContainedType();
4011 TheCall->setType(ReturnType);
4012
4013 break;
4014 }
4015 case Builtin::BI__builtin_hlsl_resource_load_with_status_typed: {
4016 if (SemaRef.checkArgCount(TheCall, 4) ||
4017 CheckResourceHandle(&SemaRef, TheCall, 0) ||
4018 CheckArgTypeMatches(&SemaRef, TheCall->getArg(1),
4019 SemaRef.getASTContext().UnsignedIntTy) ||
4020 CheckArgTypeMatches(&SemaRef, TheCall->getArg(2),
4021 SemaRef.getASTContext().UnsignedIntTy) ||
4022 CheckModifiableLValue(&SemaRef, TheCall, 2))
4023 return true;
4024
4025 QualType ReturnType = TheCall->getArg(3)->getType();
4026 assert(ReturnType->isPointerType() &&
4027 "expected pointer type for second argument");
4028 ReturnType = ReturnType->getPointeeType();
4029
4030 // Reject array types
4031 if (ReturnType->isArrayType())
4032 return SemaRef.Diag(
4033 cast<FunctionDecl>(SemaRef.CurContext)->getPointOfInstantiation(),
4034 diag::err_invalid_use_of_array_type);
4035
4036 TheCall->setType(ReturnType);
4037
4038 break;
4039 }
4040 case Builtin::BI__builtin_hlsl_resource_load_level:
4041 return CheckLoadLevelBuiltin(SemaRef, TheCall);
4042 case Builtin::BI__builtin_hlsl_resource_sample:
4044 case Builtin::BI__builtin_hlsl_resource_sample_bias:
4046 case Builtin::BI__builtin_hlsl_resource_sample_grad:
4048 case Builtin::BI__builtin_hlsl_resource_sample_level:
4050 case Builtin::BI__builtin_hlsl_resource_sample_cmp:
4052 case Builtin::BI__builtin_hlsl_resource_sample_cmp_level_zero:
4054 case Builtin::BI__builtin_hlsl_resource_calculate_lod:
4055 case Builtin::BI__builtin_hlsl_resource_calculate_lod_unclamped:
4056 return CheckCalculateLodBuiltin(SemaRef, TheCall);
4057 case Builtin::BI__builtin_hlsl_resource_gather:
4058 return CheckGatherBuiltin(SemaRef, TheCall, /*IsCmp=*/false);
4059 case Builtin::BI__builtin_hlsl_resource_gather_cmp:
4060 return CheckGatherBuiltin(SemaRef, TheCall, /*IsCmp=*/true);
4061 case Builtin::BI__builtin_hlsl_resource_uninitializedhandle: {
4062 assert(TheCall->getNumArgs() == 1 && "expected 1 arg");
4063 // Update return type to be the attributed resource type from arg0.
4064 QualType ResourceTy = TheCall->getArg(0)->getType();
4065 TheCall->setType(ResourceTy);
4066 break;
4067 }
4068 case Builtin::BI__builtin_hlsl_resource_handlefrombinding: {
4069 assert(TheCall->getNumArgs() == 6 && "expected 6 args");
4070 // Update return type to be the attributed resource type from arg0.
4071 QualType ResourceTy = TheCall->getArg(0)->getType();
4072 TheCall->setType(ResourceTy);
4073 break;
4074 }
4075 case Builtin::BI__builtin_hlsl_resource_handlefromimplicitbinding: {
4076 assert(TheCall->getNumArgs() == 6 && "expected 6 args");
4077 // Update return type to be the attributed resource type from arg0.
4078 QualType ResourceTy = TheCall->getArg(0)->getType();
4079 TheCall->setType(ResourceTy);
4080 break;
4081 }
4082 case Builtin::BI__builtin_hlsl_resource_counterhandlefromimplicitbinding: {
4083 assert(TheCall->getNumArgs() == 3 && "expected 3 args");
4084 ASTContext &AST = SemaRef.getASTContext();
4085 QualType MainHandleTy = TheCall->getArg(0)->getType();
4086 auto *MainResType = MainHandleTy->getAs<HLSLAttributedResourceType>();
4087 auto MainAttrs = MainResType->getAttrs();
4088 assert(!MainAttrs.IsCounter && "cannot create a counter from a counter");
4089 MainAttrs.IsCounter = true;
4090 QualType CounterHandleTy = AST.getHLSLAttributedResourceType(
4091 MainResType->getWrappedType(), MainResType->getContainedType(),
4092 MainAttrs);
4093 // Update return type to be the attributed resource type from arg0
4094 // with added IsCounter flag.
4095 TheCall->setType(CounterHandleTy);
4096 break;
4097 }
4098 case Builtin::BI__builtin_hlsl_and:
4099 case Builtin::BI__builtin_hlsl_or: {
4100 if (SemaRef.checkArgCount(TheCall, 2))
4101 return true;
4102 if (CheckScalarOrVectorOrMatrix(&SemaRef, TheCall, getASTContext().BoolTy,
4103 0))
4104 return true;
4105 if (CheckAllArgsHaveSameType(&SemaRef, TheCall))
4106 return true;
4107
4108 ExprResult A = TheCall->getArg(0);
4109 QualType ArgTyA = A.get()->getType();
4110 // return type is the same as the input type
4111 TheCall->setType(ArgTyA);
4112 break;
4113 }
4114 case Builtin::BI__builtin_hlsl_all:
4115 case Builtin::BI__builtin_hlsl_any: {
4116 if (SemaRef.checkArgCount(TheCall, 1))
4117 return true;
4118 if (CheckAnyScalarOrVector(&SemaRef, TheCall, 0))
4119 return true;
4120 break;
4121 }
4122 case Builtin::BI__builtin_hlsl_asdouble: {
4123 if (SemaRef.checkArgCount(TheCall, 2))
4124 return true;
4126 &SemaRef, TheCall,
4127 /*only check for uint*/ SemaRef.Context.UnsignedIntTy,
4128 /* arg index */ 0))
4129 return true;
4131 &SemaRef, TheCall,
4132 /*only check for uint*/ SemaRef.Context.UnsignedIntTy,
4133 /* arg index */ 1))
4134 return true;
4135 if (CheckAllArgsHaveSameType(&SemaRef, TheCall))
4136 return true;
4137
4138 SetElementTypeAsReturnType(&SemaRef, TheCall, getASTContext().DoubleTy);
4139 break;
4140 }
4141 case Builtin::BI__builtin_hlsl_elementwise_clamp: {
4142 if (SemaRef.BuiltinElementwiseTernaryMath(
4143 TheCall, /*ArgTyRestr=*/
4145 return true;
4146 break;
4147 }
4148 case Builtin::BI__builtin_hlsl_dot: {
4149 // arg count is checked by BuiltinVectorToScalarMath
4150 if (SemaRef.BuiltinVectorToScalarMath(TheCall))
4151 return true;
4153 return true;
4154 break;
4155 }
4156 case Builtin::BI__builtin_hlsl_elementwise_firstbithigh:
4157 case Builtin::BI__builtin_hlsl_elementwise_firstbitlow: {
4158 if (SemaRef.PrepareBuiltinElementwiseMathOneArgCall(TheCall))
4159 return true;
4160
4161 const Expr *Arg = TheCall->getArg(0);
4162 QualType ArgTy = Arg->getType();
4163 QualType EltTy = ArgTy;
4164
4165 QualType ResTy = SemaRef.Context.UnsignedIntTy;
4166
4167 if (auto *VecTy = EltTy->getAs<VectorType>()) {
4168 EltTy = VecTy->getElementType();
4169 ResTy = SemaRef.Context.getExtVectorType(ResTy, VecTy->getNumElements());
4170 }
4171
4172 if (!EltTy->isIntegerType()) {
4173 Diag(Arg->getBeginLoc(), diag::err_builtin_invalid_arg_type)
4174 << 1 << /* scalar or vector of */ 5 << /* integer ty */ 1
4175 << /* no fp */ 0 << ArgTy;
4176 return true;
4177 }
4178
4179 TheCall->setType(ResTy);
4180 break;
4181 }
4182 case Builtin::BI__builtin_hlsl_select: {
4183 if (SemaRef.checkArgCount(TheCall, 3))
4184 return true;
4185 if (CheckScalarOrVector(&SemaRef, TheCall, getASTContext().BoolTy, 0))
4186 return true;
4187 QualType ArgTy = TheCall->getArg(0)->getType();
4188 if (ArgTy->isBooleanType() && CheckBoolSelect(&SemaRef, TheCall))
4189 return true;
4190 auto *VTy = ArgTy->getAs<VectorType>();
4191 if (VTy && VTy->getElementType()->isBooleanType() &&
4192 CheckVectorSelect(&SemaRef, TheCall))
4193 return true;
4194 break;
4195 }
4196 case Builtin::BI__builtin_hlsl_elementwise_saturate:
4197 case Builtin::BI__builtin_hlsl_elementwise_rcp: {
4198 if (SemaRef.checkArgCount(TheCall, 1))
4199 return true;
4200 if (!TheCall->getArg(0)
4201 ->getType()
4202 ->hasFloatingRepresentation()) // half or float or double
4203 return SemaRef.Diag(TheCall->getArg(0)->getBeginLoc(),
4204 diag::err_builtin_invalid_arg_type)
4205 << /* ordinal */ 1 << /* scalar or vector */ 5 << /* no int */ 0
4206 << /* fp */ 1 << TheCall->getArg(0)->getType();
4207 if (SemaRef.PrepareBuiltinElementwiseMathOneArgCall(TheCall))
4208 return true;
4209 break;
4210 }
4211 case Builtin::BI__builtin_hlsl_elementwise_degrees:
4212 case Builtin::BI__builtin_hlsl_elementwise_radians:
4213 case Builtin::BI__builtin_hlsl_elementwise_rsqrt:
4214 case Builtin::BI__builtin_hlsl_elementwise_frac:
4215 case Builtin::BI__builtin_hlsl_elementwise_ddx_coarse:
4216 case Builtin::BI__builtin_hlsl_elementwise_ddy_coarse:
4217 case Builtin::BI__builtin_hlsl_elementwise_ddx_fine:
4218 case Builtin::BI__builtin_hlsl_elementwise_ddy_fine: {
4219 if (SemaRef.checkArgCount(TheCall, 1))
4220 return true;
4221 if (CheckAllArgTypesAreCorrect(&SemaRef, TheCall,
4223 return true;
4224 if (SemaRef.PrepareBuiltinElementwiseMathOneArgCall(TheCall))
4225 return true;
4226 break;
4227 }
4228 case Builtin::BI__builtin_hlsl_elementwise_isinf:
4229 case Builtin::BI__builtin_hlsl_elementwise_isnan: {
4230 if (SemaRef.checkArgCount(TheCall, 1))
4231 return true;
4232 if (CheckAllArgTypesAreCorrect(&SemaRef, TheCall,
4234 return true;
4235 if (SemaRef.PrepareBuiltinElementwiseMathOneArgCall(TheCall))
4236 return true;
4238 break;
4239 }
4240 case Builtin::BI__builtin_hlsl_lerp: {
4241 if (SemaRef.checkArgCount(TheCall, 3))
4242 return true;
4243 if (CheckAllArgTypesAreCorrect(&SemaRef, TheCall,
4245 return true;
4246 if (CheckAllArgsHaveSameType(&SemaRef, TheCall))
4247 return true;
4248 if (SemaRef.BuiltinElementwiseTernaryMath(TheCall))
4249 return true;
4250 break;
4251 }
4252 case Builtin::BI__builtin_hlsl_mad: {
4253 if (SemaRef.BuiltinElementwiseTernaryMath(
4254 TheCall, /*ArgTyRestr=*/
4256 return true;
4257 break;
4258 }
4259 case Builtin::BI__builtin_hlsl_mul: {
4260 if (SemaRef.checkArgCount(TheCall, 2))
4261 return true;
4262
4263 Expr *Arg0 = TheCall->getArg(0);
4264 Expr *Arg1 = TheCall->getArg(1);
4265 QualType Ty0 = Arg0->getType();
4266 QualType Ty1 = Arg1->getType();
4267
4268 auto getElemType = [](QualType T) -> QualType {
4269 if (const auto *VTy = T->getAs<VectorType>())
4270 return VTy->getElementType();
4271 if (const auto *MTy = T->getAs<ConstantMatrixType>())
4272 return MTy->getElementType();
4273 return T;
4274 };
4275
4276 QualType EltTy0 = getElemType(Ty0);
4277
4278 bool IsVec0 = Ty0->isVectorType();
4279 bool IsMat0 = Ty0->isConstantMatrixType();
4280 bool IsVec1 = Ty1->isVectorType();
4281 bool IsMat1 = Ty1->isConstantMatrixType();
4282
4283 QualType RetTy;
4284
4285 if (IsVec0 && IsMat1) {
4286 auto *MatTy = Ty1->castAs<ConstantMatrixType>();
4287 RetTy = getASTContext().getExtVectorType(EltTy0, MatTy->getNumColumns());
4288 } else if (IsMat0 && IsVec1) {
4289 auto *MatTy = Ty0->castAs<ConstantMatrixType>();
4290 RetTy = getASTContext().getExtVectorType(EltTy0, MatTy->getNumRows());
4291 } else {
4292 assert(IsMat0 && IsMat1);
4293 auto *MatTy0 = Ty0->castAs<ConstantMatrixType>();
4294 auto *MatTy1 = Ty1->castAs<ConstantMatrixType>();
4296 EltTy0, MatTy0->getNumRows(), MatTy1->getNumColumns());
4297 }
4298
4299 TheCall->setType(RetTy);
4300 break;
4301 }
4302 case Builtin::BI__builtin_hlsl_normalize: {
4303 if (SemaRef.checkArgCount(TheCall, 1))
4304 return true;
4305 if (CheckAllArgTypesAreCorrect(&SemaRef, TheCall,
4307 return true;
4308 ExprResult A = TheCall->getArg(0);
4309 QualType ArgTyA = A.get()->getType();
4310 // return type is the same as the input type
4311 TheCall->setType(ArgTyA);
4312 break;
4313 }
4314 case Builtin::BI__builtin_elementwise_fma: {
4315 if (SemaRef.checkArgCount(TheCall, 3) ||
4316 CheckAllArgsHaveSameType(&SemaRef, TheCall)) {
4317 return true;
4318 }
4319
4320 if (CheckAllArgTypesAreCorrect(&SemaRef, TheCall,
4322 return true;
4323
4324 ExprResult A = TheCall->getArg(0);
4325 QualType ArgTyA = A.get()->getType();
4326 // return type is the same as input type
4327 TheCall->setType(ArgTyA);
4328 break;
4329 }
4330 case Builtin::BI__builtin_hlsl_transpose: {
4331 if (SemaRef.checkArgCount(TheCall, 1))
4332 return true;
4333
4334 Expr *Arg = TheCall->getArg(0);
4335 QualType ArgTy = Arg->getType();
4336
4337 const auto *MatTy = ArgTy->getAs<ConstantMatrixType>();
4338 if (!MatTy) {
4339 SemaRef.Diag(Arg->getBeginLoc(), diag::err_builtin_invalid_arg_type)
4340 << 1 << /* matrix */ 3 << /* no int */ 0 << /* no fp */ 0 << ArgTy;
4341 return true;
4342 }
4343
4345 MatTy->getElementType(), MatTy->getNumColumns(), MatTy->getNumRows());
4346 TheCall->setType(RetTy);
4347 break;
4348 }
4349 case Builtin::BI__builtin_hlsl_elementwise_sign: {
4350 if (SemaRef.PrepareBuiltinElementwiseMathOneArgCall(TheCall))
4351 return true;
4352 if (CheckAllArgTypesAreCorrect(&SemaRef, TheCall,
4354 return true;
4356 break;
4357 }
4358 case Builtin::BI__builtin_hlsl_step: {
4359 if (SemaRef.checkArgCount(TheCall, 2))
4360 return true;
4361 if (CheckAllArgTypesAreCorrect(&SemaRef, TheCall,
4363 return true;
4364
4365 ExprResult A = TheCall->getArg(0);
4366 QualType ArgTyA = A.get()->getType();
4367 // return type is the same as the input type
4368 TheCall->setType(ArgTyA);
4369 break;
4370 }
4371 case Builtin::BI__builtin_hlsl_wave_active_all_equal: {
4372 if (SemaRef.checkArgCount(TheCall, 1))
4373 return true;
4374
4375 // Ensure input expr type is a scalar/vector
4376 if (CheckAnyScalarOrVector(&SemaRef, TheCall, 0))
4377 return true;
4378
4379 QualType InputTy = TheCall->getArg(0)->getType();
4380 ASTContext &Ctx = getASTContext();
4381
4382 QualType RetTy;
4383
4384 // If vector, construct bool vector of same size
4385 if (const auto *VecTy = InputTy->getAs<ExtVectorType>()) {
4386 unsigned NumElts = VecTy->getNumElements();
4387 RetTy = Ctx.getExtVectorType(Ctx.BoolTy, NumElts);
4388 } else {
4389 // Scalar case
4390 RetTy = Ctx.BoolTy;
4391 }
4392
4393 TheCall->setType(RetTy);
4394 break;
4395 }
4396 case Builtin::BI__builtin_hlsl_wave_active_max:
4397 case Builtin::BI__builtin_hlsl_wave_active_min:
4398 case Builtin::BI__builtin_hlsl_wave_active_sum:
4399 case Builtin::BI__builtin_hlsl_wave_active_product: {
4400 if (SemaRef.checkArgCount(TheCall, 1))
4401 return true;
4402
4403 // Ensure input expr type is a scalar/vector and the same as the return type
4404 if (CheckAnyScalarOrVector(&SemaRef, TheCall, 0))
4405 return true;
4406 if (CheckWaveActive(&SemaRef, TheCall))
4407 return true;
4408 ExprResult Expr = TheCall->getArg(0);
4409 QualType ArgTyExpr = Expr.get()->getType();
4410 TheCall->setType(ArgTyExpr);
4411 break;
4412 }
4413 case Builtin::BI__builtin_hlsl_wave_active_bit_or:
4414 case Builtin::BI__builtin_hlsl_wave_active_bit_xor:
4415 case Builtin::BI__builtin_hlsl_wave_active_bit_and: {
4416 if (SemaRef.checkArgCount(TheCall, 1))
4417 return true;
4418
4419 // Ensure input expr type is a scalar/vector
4420 if (CheckAnyScalarOrVector(&SemaRef, TheCall, 0))
4421 return true;
4422
4423 if (CheckWaveActive(&SemaRef, TheCall))
4424 return true;
4425
4426 // Ensure the expr type is interpretable as a uint or vector<uint>
4427 ExprResult Expr = TheCall->getArg(0);
4428 QualType ArgTyExpr = Expr.get()->getType();
4429 auto *VTy = ArgTyExpr->getAs<VectorType>();
4430 if (!(ArgTyExpr->isIntegerType() ||
4431 (VTy && VTy->getElementType()->isIntegerType()))) {
4432 SemaRef.Diag(TheCall->getArg(0)->getBeginLoc(),
4433 diag::err_builtin_invalid_arg_type)
4434 << ArgTyExpr << SemaRef.Context.UnsignedIntTy << 1 << 0 << 0;
4435 return true;
4436 }
4437
4438 // Ensure input expr type is the same as the return type
4439 TheCall->setType(ArgTyExpr);
4440 break;
4441 }
4442 // Note these are llvm builtins that we want to catch invalid intrinsic
4443 // generation. Normal handling of these builtins will occur elsewhere.
4444 case Builtin::BI__builtin_elementwise_bitreverse: {
4445 // does not include a check for number of arguments
4446 // because that is done previously
4447 if (CheckAllArgTypesAreCorrect(&SemaRef, TheCall,
4449 return true;
4450 break;
4451 }
4452 case Builtin::BI__builtin_hlsl_wave_prefix_count_bits: {
4453 if (SemaRef.checkArgCount(TheCall, 1))
4454 return true;
4455
4456 QualType ArgType = TheCall->getArg(0)->getType();
4457
4458 if (!(ArgType->isScalarType())) {
4459 SemaRef.Diag(TheCall->getArg(0)->getBeginLoc(),
4460 diag::err_typecheck_expect_any_scalar_or_vector)
4461 << ArgType << 0;
4462 return true;
4463 }
4464
4465 if (!(ArgType->isBooleanType())) {
4466 SemaRef.Diag(TheCall->getArg(0)->getBeginLoc(),
4467 diag::err_typecheck_expect_any_scalar_or_vector)
4468 << ArgType << 0;
4469 return true;
4470 }
4471
4472 break;
4473 }
4474 case Builtin::BI__builtin_hlsl_wave_read_lane_at: {
4475 if (SemaRef.checkArgCount(TheCall, 2))
4476 return true;
4477
4478 // Ensure index parameter type can be interpreted as a uint
4479 ExprResult Index = TheCall->getArg(1);
4480 QualType ArgTyIndex = Index.get()->getType();
4481 if (!ArgTyIndex->isIntegerType()) {
4482 SemaRef.Diag(TheCall->getArg(1)->getBeginLoc(),
4483 diag::err_typecheck_convert_incompatible)
4484 << ArgTyIndex << SemaRef.Context.UnsignedIntTy << 1 << 0 << 0;
4485 return true;
4486 }
4487
4488 // Ensure input expr type is a scalar/vector and the same as the return type
4489 if (CheckAnyScalarOrVector(&SemaRef, TheCall, 0))
4490 return true;
4491
4492 ExprResult Expr = TheCall->getArg(0);
4493 QualType ArgTyExpr = Expr.get()->getType();
4494 TheCall->setType(ArgTyExpr);
4495 break;
4496 }
4497 case Builtin::BI__builtin_hlsl_wave_get_lane_index: {
4498 if (SemaRef.checkArgCount(TheCall, 0))
4499 return true;
4500 break;
4501 }
4502 case Builtin::BI__builtin_hlsl_wave_prefix_sum:
4503 case Builtin::BI__builtin_hlsl_wave_prefix_product: {
4504 if (SemaRef.checkArgCount(TheCall, 1))
4505 return true;
4506
4507 // Ensure input expr type is a scalar/vector and the same as the return type
4508 if (CheckAnyScalarOrVector(&SemaRef, TheCall, 0))
4509 return true;
4510 if (CheckWavePrefix(&SemaRef, TheCall))
4511 return true;
4512 ExprResult Expr = TheCall->getArg(0);
4513 QualType ArgTyExpr = Expr.get()->getType();
4514 TheCall->setType(ArgTyExpr);
4515 break;
4516 }
4517 case Builtin::BI__builtin_hlsl_quad_read_across_x:
4518 case Builtin::BI__builtin_hlsl_quad_read_across_y: {
4519 if (SemaRef.checkArgCount(TheCall, 1))
4520 return true;
4521
4522 if (CheckAnyScalarOrVector(&SemaRef, TheCall, 0))
4523 return true;
4524 if (CheckNotBoolScalarOrVector(&SemaRef, TheCall, 0))
4525 return true;
4526 ExprResult Expr = TheCall->getArg(0);
4527 QualType ArgTyExpr = Expr.get()->getType();
4528 TheCall->setType(ArgTyExpr);
4529 break;
4530 }
4531 case Builtin::BI__builtin_hlsl_elementwise_splitdouble: {
4532 if (SemaRef.checkArgCount(TheCall, 3))
4533 return true;
4534
4535 if (CheckScalarOrVector(&SemaRef, TheCall, SemaRef.Context.DoubleTy, 0) ||
4536 CheckScalarOrVector(&SemaRef, TheCall, SemaRef.Context.UnsignedIntTy,
4537 1) ||
4538 CheckScalarOrVector(&SemaRef, TheCall, SemaRef.Context.UnsignedIntTy,
4539 2))
4540 return true;
4541
4542 if (CheckModifiableLValue(&SemaRef, TheCall, 1) ||
4543 CheckModifiableLValue(&SemaRef, TheCall, 2))
4544 return true;
4545 break;
4546 }
4547 case Builtin::BI__builtin_hlsl_elementwise_clip: {
4548 if (SemaRef.checkArgCount(TheCall, 1))
4549 return true;
4550
4551 if (CheckScalarOrVector(&SemaRef, TheCall, SemaRef.Context.FloatTy, 0))
4552 return true;
4553 break;
4554 }
4555 case Builtin::BI__builtin_elementwise_acos:
4556 case Builtin::BI__builtin_elementwise_asin:
4557 case Builtin::BI__builtin_elementwise_atan:
4558 case Builtin::BI__builtin_elementwise_atan2:
4559 case Builtin::BI__builtin_elementwise_ceil:
4560 case Builtin::BI__builtin_elementwise_cos:
4561 case Builtin::BI__builtin_elementwise_cosh:
4562 case Builtin::BI__builtin_elementwise_exp:
4563 case Builtin::BI__builtin_elementwise_exp2:
4564 case Builtin::BI__builtin_elementwise_exp10:
4565 case Builtin::BI__builtin_elementwise_floor:
4566 case Builtin::BI__builtin_elementwise_fmod:
4567 case Builtin::BI__builtin_elementwise_log:
4568 case Builtin::BI__builtin_elementwise_log2:
4569 case Builtin::BI__builtin_elementwise_log10:
4570 case Builtin::BI__builtin_elementwise_pow:
4571 case Builtin::BI__builtin_elementwise_roundeven:
4572 case Builtin::BI__builtin_elementwise_sin:
4573 case Builtin::BI__builtin_elementwise_sinh:
4574 case Builtin::BI__builtin_elementwise_sqrt:
4575 case Builtin::BI__builtin_elementwise_tan:
4576 case Builtin::BI__builtin_elementwise_tanh:
4577 case Builtin::BI__builtin_elementwise_trunc: {
4578 if (CheckAllArgTypesAreCorrect(&SemaRef, TheCall,
4580 return true;
4581 break;
4582 }
4583 case Builtin::BI__builtin_hlsl_buffer_update_counter: {
4584 assert(TheCall->getNumArgs() == 2 && "expected 2 args");
4585 auto checkResTy = [](const HLSLAttributedResourceType *ResTy) -> bool {
4586 return !(ResTy->getAttrs().ResourceClass == ResourceClass::UAV &&
4587 ResTy->getAttrs().RawBuffer && ResTy->hasContainedType());
4588 };
4589 if (CheckResourceHandle(&SemaRef, TheCall, 0, checkResTy))
4590 return true;
4591 Expr *OffsetExpr = TheCall->getArg(1);
4592 std::optional<llvm::APSInt> Offset =
4593 OffsetExpr->getIntegerConstantExpr(SemaRef.getASTContext());
4594 if (!Offset.has_value() || std::abs(Offset->getExtValue()) != 1) {
4595 SemaRef.Diag(TheCall->getArg(1)->getBeginLoc(),
4596 diag::err_hlsl_expect_arg_const_int_one_or_neg_one)
4597 << 1;
4598 return true;
4599 }
4600 break;
4601 }
4602 case Builtin::BI__builtin_hlsl_elementwise_f16tof32: {
4603 if (SemaRef.checkArgCount(TheCall, 1))
4604 return true;
4605 if (CheckAllArgTypesAreCorrect(&SemaRef, TheCall,
4607 return true;
4608 // ensure arg integers are 32 bits
4609 if (CheckExpectedBitWidth(&SemaRef, TheCall, 0, 32))
4610 return true;
4611 // check it wasn't a bool type
4612 QualType ArgTy = TheCall->getArg(0)->getType();
4613 if (auto *VTy = ArgTy->getAs<VectorType>())
4614 ArgTy = VTy->getElementType();
4615 if (ArgTy->isBooleanType()) {
4616 SemaRef.Diag(TheCall->getArg(0)->getBeginLoc(),
4617 diag::err_builtin_invalid_arg_type)
4618 << 1 << /* scalar or vector of */ 5 << /* unsigned int */ 3
4619 << /* no fp */ 0 << TheCall->getArg(0)->getType();
4620 return true;
4621 }
4622
4623 SetElementTypeAsReturnType(&SemaRef, TheCall, getASTContext().FloatTy);
4624 break;
4625 }
4626 case Builtin::BI__builtin_hlsl_elementwise_f32tof16: {
4627 if (SemaRef.checkArgCount(TheCall, 1))
4628 return true;
4630 return true;
4632 getASTContext().UnsignedIntTy);
4633 break;
4634 }
4635 }
4636 return false;
4637}
4638
4642 WorkList.push_back(BaseTy);
4643 while (!WorkList.empty()) {
4644 QualType T = WorkList.pop_back_val();
4645 T = T.getCanonicalType().getUnqualifiedType();
4646 if (const auto *AT = dyn_cast<ConstantArrayType>(T)) {
4647 llvm::SmallVector<QualType, 16> ElementFields;
4648 // Generally I've avoided recursion in this algorithm, but arrays of
4649 // structs could be time-consuming to flatten and churn through on the
4650 // work list. Hopefully nesting arrays of structs containing arrays
4651 // of structs too many levels deep is unlikely.
4652 BuildFlattenedTypeList(AT->getElementType(), ElementFields);
4653 // Repeat the element's field list n times.
4654 for (uint64_t Ct = 0; Ct < AT->getZExtSize(); ++Ct)
4655 llvm::append_range(List, ElementFields);
4656 continue;
4657 }
4658 // Vectors can only have element types that are builtin types, so this can
4659 // add directly to the list instead of to the WorkList.
4660 if (const auto *VT = dyn_cast<VectorType>(T)) {
4661 List.insert(List.end(), VT->getNumElements(), VT->getElementType());
4662 continue;
4663 }
4664 if (const auto *MT = dyn_cast<ConstantMatrixType>(T)) {
4665 List.insert(List.end(), MT->getNumElementsFlattened(),
4666 MT->getElementType());
4667 continue;
4668 }
4669 if (const auto *RD = T->getAsCXXRecordDecl()) {
4670 if (RD->isStandardLayout())
4671 RD = RD->getStandardLayoutBaseWithFields();
4672
4673 // For types that we shouldn't decompose (unions and non-aggregates), just
4674 // add the type itself to the list.
4675 if (RD->isUnion() || !RD->isAggregate()) {
4676 List.push_back(T);
4677 continue;
4678 }
4679
4681 for (const auto *FD : RD->fields())
4682 if (!FD->isUnnamedBitField())
4683 FieldTypes.push_back(FD->getType());
4684 // Reverse the newly added sub-range.
4685 std::reverse(FieldTypes.begin(), FieldTypes.end());
4686 llvm::append_range(WorkList, FieldTypes);
4687
4688 // If this wasn't a standard layout type we may also have some base
4689 // classes to deal with.
4690 if (!RD->isStandardLayout()) {
4691 FieldTypes.clear();
4692 for (const auto &Base : RD->bases())
4693 FieldTypes.push_back(Base.getType());
4694 std::reverse(FieldTypes.begin(), FieldTypes.end());
4695 llvm::append_range(WorkList, FieldTypes);
4696 }
4697 continue;
4698 }
4699 List.push_back(T);
4700 }
4701}
4702
4704 // null and array types are not allowed.
4705 if (QT.isNull() || QT->isArrayType())
4706 return false;
4707
4708 // UDT types are not allowed
4709 if (QT->isRecordType())
4710 return false;
4711
4712 if (QT->isBooleanType() || QT->isEnumeralType())
4713 return false;
4714
4715 // the only other valid builtin types are scalars or vectors
4716 if (QT->isArithmeticType()) {
4717 if (SemaRef.Context.getTypeSize(QT) / 8 > 16)
4718 return false;
4719 return true;
4720 }
4721
4722 if (const VectorType *VT = QT->getAs<VectorType>()) {
4723 int ArraySize = VT->getNumElements();
4724
4725 if (ArraySize > 4)
4726 return false;
4727
4728 QualType ElTy = VT->getElementType();
4729 if (ElTy->isBooleanType())
4730 return false;
4731
4732 if (SemaRef.Context.getTypeSize(QT) / 8 > 16)
4733 return false;
4734 return true;
4735 }
4736
4737 return false;
4738}
4739
4741 if (T1.isNull() || T2.isNull())
4742 return false;
4743
4746
4747 // If both types are the same canonical type, they're obviously compatible.
4748 if (SemaRef.getASTContext().hasSameType(T1, T2))
4749 return true;
4750
4752 BuildFlattenedTypeList(T1, T1Types);
4754 BuildFlattenedTypeList(T2, T2Types);
4755
4756 // Check the flattened type list
4757 return llvm::equal(T1Types, T2Types,
4758 [this](QualType LHS, QualType RHS) -> bool {
4759 return SemaRef.IsLayoutCompatible(LHS, RHS);
4760 });
4761}
4762
4764 FunctionDecl *Old) {
4765 if (New->getNumParams() != Old->getNumParams())
4766 return true;
4767
4768 bool HadError = false;
4769
4770 for (unsigned i = 0, e = New->getNumParams(); i != e; ++i) {
4771 ParmVarDecl *NewParam = New->getParamDecl(i);
4772 ParmVarDecl *OldParam = Old->getParamDecl(i);
4773
4774 // HLSL parameter declarations for inout and out must match between
4775 // declarations. In HLSL inout and out are ambiguous at the call site,
4776 // but have different calling behavior, so you cannot overload a
4777 // method based on a difference between inout and out annotations.
4778 const auto *NDAttr = NewParam->getAttr<HLSLParamModifierAttr>();
4779 unsigned NSpellingIdx = (NDAttr ? NDAttr->getSpellingListIndex() : 0);
4780 const auto *ODAttr = OldParam->getAttr<HLSLParamModifierAttr>();
4781 unsigned OSpellingIdx = (ODAttr ? ODAttr->getSpellingListIndex() : 0);
4782
4783 if (NSpellingIdx != OSpellingIdx) {
4784 SemaRef.Diag(NewParam->getLocation(),
4785 diag::err_hlsl_param_qualifier_mismatch)
4786 << NDAttr << NewParam;
4787 SemaRef.Diag(OldParam->getLocation(), diag::note_previous_declaration_as)
4788 << ODAttr;
4789 HadError = true;
4790 }
4791 }
4792 return HadError;
4793}
4794
4795// Generally follows PerformScalarCast, with cases reordered for
4796// clarity of what types are supported
4798
4799 if (!SrcTy->isScalarType() || !DestTy->isScalarType())
4800 return false;
4801
4802 if (SemaRef.getASTContext().hasSameUnqualifiedType(SrcTy, DestTy))
4803 return true;
4804
4805 switch (SrcTy->getScalarTypeKind()) {
4806 case Type::STK_Bool: // casting from bool is like casting from an integer
4807 case Type::STK_Integral:
4808 switch (DestTy->getScalarTypeKind()) {
4809 case Type::STK_Bool:
4810 case Type::STK_Integral:
4811 case Type::STK_Floating:
4812 return true;
4813 case Type::STK_CPointer:
4817 llvm_unreachable("HLSL doesn't support pointers.");
4820 llvm_unreachable("HLSL doesn't support complex types.");
4822 llvm_unreachable("HLSL doesn't support fixed point types.");
4823 }
4824 llvm_unreachable("Should have returned before this");
4825
4826 case Type::STK_Floating:
4827 switch (DestTy->getScalarTypeKind()) {
4828 case Type::STK_Floating:
4829 case Type::STK_Bool:
4830 case Type::STK_Integral:
4831 return true;
4834 llvm_unreachable("HLSL doesn't support complex types.");
4836 llvm_unreachable("HLSL doesn't support fixed point types.");
4837 case Type::STK_CPointer:
4841 llvm_unreachable("HLSL doesn't support pointers.");
4842 }
4843 llvm_unreachable("Should have returned before this");
4844
4846 case Type::STK_CPointer:
4849 llvm_unreachable("HLSL doesn't support pointers.");
4850
4852 llvm_unreachable("HLSL doesn't support fixed point types.");
4853
4856 llvm_unreachable("HLSL doesn't support complex types.");
4857 }
4858
4859 llvm_unreachable("Unhandled scalar cast");
4860}
4861
4862// Can perform an HLSL Aggregate splat cast if the Dest is an aggregate and the
4863// Src is a scalar, a vector of length 1, or a 1x1 matrix
4864// Or if Dest is a vector and Src is a vector of length 1 or a 1x1 matrix
4866
4867 QualType SrcTy = Src->getType();
4868 // Not a valid HLSL Aggregate Splat cast if Dest is a scalar or if this is
4869 // going to be a vector splat from a scalar.
4870 if ((SrcTy->isScalarType() && DestTy->isVectorType()) ||
4871 DestTy->isScalarType())
4872 return false;
4873
4874 const VectorType *SrcVecTy = SrcTy->getAs<VectorType>();
4875 const ConstantMatrixType *SrcMatTy = SrcTy->getAs<ConstantMatrixType>();
4876
4877 // Src isn't a scalar, a vector of length 1, or a 1x1 matrix
4878 if (!SrcTy->isScalarType() &&
4879 !(SrcVecTy && SrcVecTy->getNumElements() == 1) &&
4880 !(SrcMatTy && SrcMatTy->getNumElementsFlattened() == 1))
4881 return false;
4882
4883 if (SrcVecTy)
4884 SrcTy = SrcVecTy->getElementType();
4885 else if (SrcMatTy)
4886 SrcTy = SrcMatTy->getElementType();
4887
4889 BuildFlattenedTypeList(DestTy, DestTypes);
4890
4891 for (unsigned I = 0, Size = DestTypes.size(); I < Size; ++I) {
4892 if (DestTypes[I]->isUnionType())
4893 return false;
4894 if (!CanPerformScalarCast(SrcTy, DestTypes[I]))
4895 return false;
4896 }
4897 return true;
4898}
4899
4900// Can we perform an HLSL Elementwise cast?
4902
4903 // Don't handle casts where LHS and RHS are any combination of scalar/vector
4904 // There must be an aggregate somewhere
4905 QualType SrcTy = Src->getType();
4906 if (SrcTy->isScalarType()) // always a splat and this cast doesn't handle that
4907 return false;
4908
4909 if (SrcTy->isVectorType() &&
4910 (DestTy->isScalarType() || DestTy->isVectorType()))
4911 return false;
4912
4913 if (SrcTy->isConstantMatrixType() &&
4914 (DestTy->isScalarType() || DestTy->isConstantMatrixType()))
4915 return false;
4916
4918 BuildFlattenedTypeList(DestTy, DestTypes);
4920 BuildFlattenedTypeList(SrcTy, SrcTypes);
4921
4922 // Usually the size of SrcTypes must be greater than or equal to the size of
4923 // DestTypes.
4924 if (SrcTypes.size() < DestTypes.size())
4925 return false;
4926
4927 unsigned SrcSize = SrcTypes.size();
4928 unsigned DstSize = DestTypes.size();
4929 unsigned I;
4930 for (I = 0; I < DstSize && I < SrcSize; I++) {
4931 if (SrcTypes[I]->isUnionType() || DestTypes[I]->isUnionType())
4932 return false;
4933 if (!CanPerformScalarCast(SrcTypes[I], DestTypes[I])) {
4934 return false;
4935 }
4936 }
4937
4938 // check the rest of the source type for unions.
4939 for (; I < SrcSize; I++) {
4940 if (SrcTypes[I]->isUnionType())
4941 return false;
4942 }
4943 return true;
4944}
4945
4947 assert(Param->hasAttr<HLSLParamModifierAttr>() &&
4948 "We should not get here without a parameter modifier expression");
4949 const auto *Attr = Param->getAttr<HLSLParamModifierAttr>();
4950 if (Attr->getABI() == ParameterABI::Ordinary)
4951 return ExprResult(Arg);
4952
4953 bool IsInOut = Attr->getABI() == ParameterABI::HLSLInOut;
4954 if (!Arg->isLValue()) {
4955 SemaRef.Diag(Arg->getBeginLoc(), diag::error_hlsl_inout_lvalue)
4956 << Arg << (IsInOut ? 1 : 0);
4957 return ExprError();
4958 }
4959
4960 ASTContext &Ctx = SemaRef.getASTContext();
4961
4962 QualType Ty = Param->getType().getNonLValueExprType(Ctx);
4963
4964 // HLSL allows implicit conversions from scalars to vectors, but not the
4965 // inverse, so we need to disallow `inout` with scalar->vector or
4966 // scalar->matrix conversions.
4967 if (Arg->getType()->isScalarType() != Ty->isScalarType()) {
4968 SemaRef.Diag(Arg->getBeginLoc(), diag::error_hlsl_inout_scalar_extension)
4969 << Arg << (IsInOut ? 1 : 0);
4970 return ExprError();
4971 }
4972
4973 auto *ArgOpV = new (Ctx) OpaqueValueExpr(Param->getBeginLoc(), Arg->getType(),
4974 VK_LValue, OK_Ordinary, Arg);
4975
4976 // Parameters are initialized via copy initialization. This allows for
4977 // overload resolution of argument constructors.
4978 InitializedEntity Entity =
4980 ExprResult Res =
4981 SemaRef.PerformCopyInitialization(Entity, Param->getBeginLoc(), ArgOpV);
4982 if (Res.isInvalid())
4983 return ExprError();
4984 Expr *Base = Res.get();
4985 // After the cast, drop the reference type when creating the exprs.
4986 Ty = Ty.getNonLValueExprType(Ctx);
4987 auto *OpV = new (Ctx)
4988 OpaqueValueExpr(Param->getBeginLoc(), Ty, VK_LValue, OK_Ordinary, Base);
4989
4990 // Writebacks are performed with `=` binary operator, which allows for
4991 // overload resolution on writeback result expressions.
4992 Res = SemaRef.ActOnBinOp(SemaRef.getCurScope(), Param->getBeginLoc(),
4993 tok::equal, ArgOpV, OpV);
4994
4995 if (Res.isInvalid())
4996 return ExprError();
4997 Expr *Writeback = Res.get();
4998 auto *OutExpr =
4999 HLSLOutArgExpr::Create(Ctx, Ty, ArgOpV, OpV, Writeback, IsInOut);
5000
5001 return ExprResult(OutExpr);
5002}
5003
5005 // If HLSL gains support for references, all the cites that use this will need
5006 // to be updated with semantic checking to produce errors for
5007 // pointers/references.
5008 assert(!Ty->isReferenceType() &&
5009 "Pointer and reference types cannot be inout or out parameters");
5010 Ty = SemaRef.getASTContext().getLValueReferenceType(Ty);
5011 Ty.addRestrict();
5012 return Ty;
5013}
5014
5015// Returns true if the type has a non-empty constant buffer layout (if it is
5016// scalar, vector or matrix, or if it contains any of these.
5018 const Type *Ty = QT->getUnqualifiedDesugaredType();
5019 if (Ty->isScalarType() || Ty->isVectorType() || Ty->isMatrixType())
5020 return true;
5021
5023 return false;
5024
5025 if (const auto *RD = Ty->getAsCXXRecordDecl()) {
5026 for (const auto *FD : RD->fields()) {
5028 return true;
5029 }
5030 assert(RD->getNumBases() <= 1 &&
5031 "HLSL doesn't support multiple inheritance");
5032 return RD->getNumBases()
5033 ? hasConstantBufferLayout(RD->bases_begin()->getType())
5034 : false;
5035 }
5036
5037 if (const auto *AT = dyn_cast<ArrayType>(Ty)) {
5038 if (const auto *CAT = dyn_cast<ConstantArrayType>(AT))
5039 if (isZeroSizedArray(CAT))
5040 return false;
5042 }
5043
5044 return false;
5045}
5046
5047static bool IsDefaultBufferConstantDecl(const ASTContext &Ctx, VarDecl *VD) {
5048 bool IsVulkan =
5049 Ctx.getTargetInfo().getTriple().getOS() == llvm::Triple::Vulkan;
5050 bool IsVKPushConstant = IsVulkan && VD->hasAttr<HLSLVkPushConstantAttr>();
5051 QualType QT = VD->getType();
5052 return VD->getDeclContext()->isTranslationUnit() &&
5053 QT.getAddressSpace() == LangAS::Default &&
5054 VD->getStorageClass() != SC_Static &&
5055 !VD->hasAttr<HLSLVkConstantIdAttr>() && !IsVKPushConstant &&
5057}
5058
5060 // The variable already has an address space (groupshared for ex).
5061 if (Decl->getType().hasAddressSpace())
5062 return;
5063
5064 if (Decl->getType()->isDependentType())
5065 return;
5066
5067 QualType Type = Decl->getType();
5068
5069 if (Decl->hasAttr<HLSLVkExtBuiltinInputAttr>()) {
5070 LangAS ImplAS = LangAS::hlsl_input;
5071 Type = SemaRef.getASTContext().getAddrSpaceQualType(Type, ImplAS);
5072 Decl->setType(Type);
5073 return;
5074 }
5075
5076 if (Decl->hasAttr<HLSLVkExtBuiltinOutputAttr>()) {
5077 LangAS ImplAS = LangAS::hlsl_output;
5078 Type = SemaRef.getASTContext().getAddrSpaceQualType(Type, ImplAS);
5079 Decl->setType(Type);
5080
5081 // HLSL uses `static` differently than C++. For BuiltIn output, the static
5082 // does not imply private to the module scope.
5083 // Marking it as external to reflect the semantic this attribute brings.
5084 // See https://github.com/microsoft/hlsl-specs/issues/350
5085 Decl->setStorageClass(SC_Extern);
5086 return;
5087 }
5088
5089 bool IsVulkan = getASTContext().getTargetInfo().getTriple().getOS() ==
5090 llvm::Triple::Vulkan;
5091 if (IsVulkan && Decl->hasAttr<HLSLVkPushConstantAttr>()) {
5092 if (HasDeclaredAPushConstant)
5093 SemaRef.Diag(Decl->getLocation(), diag::err_hlsl_push_constant_unique);
5094
5096 Type = SemaRef.getASTContext().getAddrSpaceQualType(Type, ImplAS);
5097 Decl->setType(Type);
5098 HasDeclaredAPushConstant = true;
5099 return;
5100 }
5101
5102 if (Type->isSamplerT() || Type->isVoidType())
5103 return;
5104
5105 // Resource handles.
5107 return;
5108
5109 // Only static globals belong to the Private address space.
5110 // Non-static globals belongs to the cbuffer.
5111 if (Decl->getStorageClass() != SC_Static && !Decl->isStaticDataMember())
5112 return;
5113
5115 Type = SemaRef.getASTContext().getAddrSpaceQualType(Type, ImplAS);
5116 Decl->setType(Type);
5117}
5118
5119namespace {
5120
5121// Helper class for assigning bindings to resources declared within a struct.
5122// It keeps track of all binding attributes declared on a struct instance, and
5123// the offsets for each register type that have been assigned so far.
5124// Handles both explicit and implicit bindings.
5125class StructBindingContext {
5126 // Bindings and offsets per register type. We only need to support four
5127 // register types - SRV (u), UAV (t), CBuffer (c), and Sampler (s).
5128 HLSLResourceBindingAttr *RegBindingsAttrs[4];
5129 unsigned RegBindingOffset[4];
5130
5131 // Make sure the RegisterType values are what we expect
5132 static_assert(static_cast<unsigned>(RegisterType::SRV) == 0 &&
5133 static_cast<unsigned>(RegisterType::UAV) == 1 &&
5134 static_cast<unsigned>(RegisterType::CBuffer) == 2 &&
5135 static_cast<unsigned>(RegisterType::Sampler) == 3,
5136 "unexpected register type values");
5137
5138 // Vulkan binding attribute does not vary by register type.
5139 HLSLVkBindingAttr *VkBindingAttr;
5140 unsigned VkBindingOffset;
5141
5142public:
5143 // Constructor: gather all binding attributes on a struct instance and
5144 // initialize offsets.
5145 StructBindingContext(VarDecl *VD) {
5146 for (unsigned i = 0; i < 4; ++i) {
5147 RegBindingsAttrs[i] = nullptr;
5148 RegBindingOffset[i] = 0;
5149 }
5150 VkBindingAttr = nullptr;
5151 VkBindingOffset = 0;
5152
5153 ASTContext &AST = VD->getASTContext();
5154 bool IsSpirv = AST.getTargetInfo().getTriple().isSPIRV();
5155
5156 for (Attr *A : VD->attrs()) {
5157 if (auto *RBA = dyn_cast<HLSLResourceBindingAttr>(A)) {
5158 RegisterType RegType = RBA->getRegisterType();
5159 unsigned RegTypeIdx = static_cast<unsigned>(RegType);
5160 // Ignore unsupported register annotations, such as 'c' or 'i'.
5161 if (RegTypeIdx < 4)
5162 RegBindingsAttrs[RegTypeIdx] = RBA;
5163 continue;
5164 }
5165 // Gather the Vulkan binding attributes only if the target is SPIR-V.
5166 if (IsSpirv) {
5167 if (auto *VBA = dyn_cast<HLSLVkBindingAttr>(A))
5168 VkBindingAttr = VBA;
5169 }
5170 }
5171 }
5172
5173 // Creates a binding attribute for a resource based on the gathered attributes
5174 // and the required register type and range.
5175 Attr *createBindingAttr(SemaHLSL &S, ASTContext &AST, RegisterType RegType,
5176 unsigned Range, bool HasCounter) {
5177 assert(static_cast<unsigned>(RegType) < 4 && "unexpected register type");
5178
5179 if (VkBindingAttr) {
5180 unsigned Offset = VkBindingOffset;
5181 VkBindingOffset += Range;
5182 return HLSLVkBindingAttr::CreateImplicit(
5183 AST, VkBindingAttr->getBinding() + Offset, VkBindingAttr->getSet(),
5184 VkBindingAttr->getRange());
5185 }
5186
5187 HLSLResourceBindingAttr *RBA =
5188 RegBindingsAttrs[static_cast<unsigned>(RegType)];
5189 HLSLResourceBindingAttr *NewAttr = nullptr;
5190
5191 if (RBA && RBA->hasRegisterSlot()) {
5192 // Explicit binding - create a new attribute with offseted slot number
5193 // based on the required register type.
5194 unsigned Offset = RegBindingOffset[static_cast<unsigned>(RegType)];
5195 RegBindingOffset[static_cast<unsigned>(RegType)] += Range;
5196
5197 unsigned NewSlotNumber = RBA->getSlotNumber() + Offset;
5198 StringRef NewSlotNumberStr =
5199 createRegisterString(AST, RBA->getRegisterType(), NewSlotNumber);
5200 NewAttr = HLSLResourceBindingAttr::CreateImplicit(
5201 AST, NewSlotNumberStr, RBA->getSpace(), RBA->getRange());
5202 NewAttr->setBinding(RegType, NewSlotNumber, RBA->getSpaceNumber());
5203 } else {
5204 // No binding attribute or space-only binding - create a binding
5205 // attribute for implicit binding.
5206 NewAttr = HLSLResourceBindingAttr::CreateImplicit(AST, "", "0", {});
5207 NewAttr->setBinding(RegType, std::nullopt,
5208 RBA ? RBA->getSpaceNumber() : 0);
5209 NewAttr->setImplicitBindingOrderID(S.getNextImplicitBindingOrderID());
5210 }
5211 if (HasCounter)
5212 NewAttr->setImplicitCounterBindingOrderID(
5214 return NewAttr;
5215 }
5216};
5217
5218// Creates a global variable declaration for a resource field embedded in a
5219// struct, assigns it a binding, initializes it, and associates it with the
5220// struct declaration via an HLSLAssociatedResourceDeclAttr.
5221static void createGlobalResourceDeclForStruct(
5222 Sema &S, VarDecl *ParentVD, SourceLocation Loc, IdentifierInfo *Id,
5223 QualType ResTy, StructBindingContext &BindingCtx) {
5224 assert(isResourceRecordTypeOrArrayOf(ResTy) &&
5225 "expected resource type or array of resources");
5226
5227 DeclContext *DC = ParentVD->getNonTransparentDeclContext();
5228 assert(DC->isTranslationUnit() && "expected translation unit decl context");
5229
5230 ASTContext &AST = S.getASTContext();
5231 VarDecl *ResDecl =
5232 VarDecl::Create(AST, DC, Loc, Loc, Id, ResTy, nullptr, SC_None);
5233
5234 unsigned Range = 1;
5235 const Type *SingleResTy = ResTy.getTypePtr()->getUnqualifiedDesugaredType();
5236 while (const auto *AT = dyn_cast<ArrayType>(SingleResTy)) {
5237 const auto *CAT = dyn_cast<ConstantArrayType>(AT);
5238 Range = CAT ? (Range * CAT->getSize().getZExtValue()) : 0;
5239 SingleResTy =
5241 }
5242 const HLSLAttributedResourceType *ResHandleTy =
5243 HLSLAttributedResourceType::findHandleTypeOnResource(SingleResTy);
5244
5245 // Add a binding attribute to the global resource declaration.
5246 bool HasCounter = hasCounterHandle(SingleResTy->getAsCXXRecordDecl());
5247 Attr *BindingAttr = BindingCtx.createBindingAttr(
5248 S.HLSL(), AST, getRegisterType(ResHandleTy), Range, HasCounter);
5249 ResDecl->addAttr(BindingAttr);
5250 ResDecl->addAttr(InternalLinkageAttr::CreateImplicit(AST));
5251 ResDecl->setImplicit();
5252
5253 if (Range == 1)
5254 S.HLSL().initGlobalResourceDecl(ResDecl);
5255 else
5256 S.HLSL().initGlobalResourceArrayDecl(ResDecl);
5257
5258 ParentVD->addAttr(
5259 HLSLAssociatedResourceDeclAttr::CreateImplicit(AST, ResDecl));
5260 DC->addDecl(ResDecl);
5261
5262 DeclGroupRef DG(ResDecl);
5264}
5265
5266static void handleArrayOfStructWithResources(
5267 Sema &S, VarDecl *ParentVD, const ConstantArrayType *CAT,
5268 EmbeddedResourceNameBuilder &NameBuilder, StructBindingContext &BindingCtx);
5269
5270// Scans base and all fields of a struct/class type to find all embedded
5271// resources or resource arrays. Creates a global variable for each resource
5272// found.
5273static void handleStructWithResources(Sema &S, VarDecl *ParentVD,
5274 const CXXRecordDecl *RD,
5275 EmbeddedResourceNameBuilder &NameBuilder,
5276 StructBindingContext &BindingCtx) {
5277
5278 // Scan the base classes.
5279 assert(RD->getNumBases() <= 1 && "HLSL doesn't support multiple inheritance");
5280 const auto *BasesIt = RD->bases_begin();
5281 if (BasesIt != RD->bases_end()) {
5282 QualType QT = BasesIt->getType();
5283 if (QT->isHLSLIntangibleType()) {
5284 CXXRecordDecl *BaseRD = QT->getAsCXXRecordDecl();
5285 NameBuilder.pushBaseName(BaseRD->getName());
5286 handleStructWithResources(S, ParentVD, BaseRD, NameBuilder, BindingCtx);
5287 NameBuilder.pop();
5288 }
5289 }
5290 // Process this class fields.
5291 for (const FieldDecl *FD : RD->fields()) {
5292 QualType FDTy = FD->getType().getCanonicalType();
5293 if (!FDTy->isHLSLIntangibleType())
5294 continue;
5295
5296 NameBuilder.pushName(FD->getName());
5297
5299 IdentifierInfo *II = NameBuilder.getNameAsIdentifier(S.getASTContext());
5300 createGlobalResourceDeclForStruct(S, ParentVD, FD->getLocation(), II,
5301 FDTy, BindingCtx);
5302 } else if (const auto *RD = FDTy->getAsCXXRecordDecl()) {
5303 handleStructWithResources(S, ParentVD, RD, NameBuilder, BindingCtx);
5304
5305 } else if (const auto *ArrayTy = dyn_cast<ConstantArrayType>(FDTy)) {
5306 assert(!FDTy->isHLSLResourceRecordArray() &&
5307 "resource arrays should have been already handled");
5308 handleArrayOfStructWithResources(S, ParentVD, ArrayTy, NameBuilder,
5309 BindingCtx);
5310 }
5311 NameBuilder.pop();
5312 }
5313}
5314
5315// Processes array of structs with resources.
5316static void
5317handleArrayOfStructWithResources(Sema &S, VarDecl *ParentVD,
5318 const ConstantArrayType *CAT,
5319 EmbeddedResourceNameBuilder &NameBuilder,
5320 StructBindingContext &BindingCtx) {
5321
5322 QualType ElementTy = CAT->getElementType().getCanonicalType();
5323 assert(ElementTy->isHLSLIntangibleType() && "Expected HLSL intangible type");
5324
5325 const ConstantArrayType *SubCAT = dyn_cast<ConstantArrayType>(ElementTy);
5326 const CXXRecordDecl *ElementRD = ElementTy->getAsCXXRecordDecl();
5327
5328 if (!SubCAT && !ElementRD)
5329 return;
5330
5331 for (unsigned I = 0, E = CAT->getSize().getZExtValue(); I < E; ++I) {
5332 NameBuilder.pushArrayIndex(I);
5333 if (ElementRD)
5334 handleStructWithResources(S, ParentVD, ElementRD, NameBuilder,
5335 BindingCtx);
5336 else
5337 handleArrayOfStructWithResources(S, ParentVD, SubCAT, NameBuilder,
5338 BindingCtx);
5339 NameBuilder.pop();
5340 }
5341}
5342
5343} // namespace
5344
5345// Scans all fields of a user-defined struct (or array of structs)
5346// to find all embedded resources or resource arrays. For each resource
5347// a global variable of the resource type is created and associated
5348// with the parent declaration (VD) through a HLSLAssociatedResourceDeclAttr
5349// attribute.
5350void SemaHLSL::handleGlobalStructOrArrayOfWithResources(VarDecl *VD) {
5351 EmbeddedResourceNameBuilder NameBuilder(VD->getName());
5352 StructBindingContext BindingCtx(VD);
5353
5354 const Type *VDTy = VD->getType().getTypePtr();
5355 assert(VDTy->isHLSLIntangibleType() && !isResourceRecordTypeOrArrayOf(VD) &&
5356 "Expected non-resource struct or array type");
5357
5358 if (const CXXRecordDecl *RD = VDTy->getAsCXXRecordDecl()) {
5359 handleStructWithResources(SemaRef, VD, RD, NameBuilder, BindingCtx);
5360 return;
5361 }
5362
5363 if (const auto *CAT = dyn_cast<ConstantArrayType>(VDTy)) {
5364 handleArrayOfStructWithResources(SemaRef, VD, CAT, NameBuilder, BindingCtx);
5365 return;
5366 }
5367}
5368
5370 if (VD->hasGlobalStorage()) {
5371 // make sure the declaration has a complete type
5372 if (SemaRef.RequireCompleteType(
5373 VD->getLocation(),
5374 SemaRef.getASTContext().getBaseElementType(VD->getType()),
5375 diag::err_typecheck_decl_incomplete_type)) {
5376 VD->setInvalidDecl();
5378 return;
5379 }
5380
5381 // Global variables outside a cbuffer block that are not a resource, static,
5382 // groupshared, or an empty array or struct belong to the default constant
5383 // buffer $Globals (to be created at the end of the translation unit).
5385 // update address space to hlsl_constant
5388 VD->setType(NewTy);
5389 DefaultCBufferDecls.push_back(VD);
5390 }
5391
5392 // find all resources bindings on decl
5393 if (VD->getType()->isHLSLIntangibleType())
5394 collectResourceBindingsOnVarDecl(VD);
5395
5396 if (VD->hasAttr<HLSLVkConstantIdAttr>())
5398
5400 VD->getStorageClass() != SC_Static) {
5401 // Add internal linkage attribute to non-static resource variables. The
5402 // global externally visible storage is accessed through the handle, which
5403 // is a member. The variable itself is not externally visible.
5404 VD->addAttr(InternalLinkageAttr::CreateImplicit(getASTContext()));
5405 }
5406
5407 // process explicit bindings
5408 processExplicitBindingsOnDecl(VD);
5409
5410 // Add implicit binding attribute to non-static resource arrays.
5411 if (VD->getType()->isHLSLResourceRecordArray() &&
5412 VD->getStorageClass() != SC_Static) {
5413 // If the resource array does not have an explicit binding attribute,
5414 // create an implicit one. It will be used to transfer implicit binding
5415 // order_ID to codegen.
5416 ResourceBindingAttrs Binding(VD);
5417 if (!Binding.isExplicit()) {
5418 uint32_t OrderID = getNextImplicitBindingOrderID();
5419 if (Binding.hasBinding())
5420 Binding.setImplicitOrderID(OrderID);
5421 else {
5424 OrderID);
5425 // Re-create the binding object to pick up the new attribute.
5426 Binding = ResourceBindingAttrs(VD);
5427 }
5428 }
5429
5430 // Get to the base type of a potentially multi-dimensional array.
5432
5433 const CXXRecordDecl *RD = Ty->getAsCXXRecordDecl();
5434 if (hasCounterHandle(RD)) {
5435 if (!Binding.hasCounterImplicitOrderID()) {
5436 uint32_t OrderID = getNextImplicitBindingOrderID();
5437 Binding.setCounterImplicitOrderID(OrderID);
5438 }
5439 }
5440 }
5441
5442 // Process resources in user-defined structs, or arrays of such structs.
5443 const Type *VDTy = VD->getType().getTypePtr();
5444 if (VD->getStorageClass() != SC_Static && VDTy->isHLSLIntangibleType() &&
5446 handleGlobalStructOrArrayOfWithResources(VD);
5447
5448 // Mark groupshared variables as extern so they will have
5449 // external storage and won't be default initialized
5450 if (VD->hasAttr<HLSLGroupSharedAddressSpaceAttr>())
5452 }
5453
5455}
5456
5458 assert(VD->getType()->isHLSLResourceRecord() &&
5459 "expected resource record type");
5460
5461 ASTContext &AST = SemaRef.getASTContext();
5462 uint64_t UIntTySize = AST.getTypeSize(AST.UnsignedIntTy);
5463 uint64_t IntTySize = AST.getTypeSize(AST.IntTy);
5464
5465 // Gather resource binding attributes.
5466 ResourceBindingAttrs Binding(VD);
5467
5468 // Find correct initialization method and create its arguments.
5469 QualType ResourceTy = VD->getType();
5470 CXXRecordDecl *ResourceDecl = ResourceTy->getAsCXXRecordDecl();
5471 CXXMethodDecl *CreateMethod = nullptr;
5473
5474 bool HasCounter = hasCounterHandle(ResourceDecl);
5475 const char *CreateMethodName;
5476 if (Binding.isExplicit())
5477 CreateMethodName = HasCounter ? "__createFromBindingWithImplicitCounter"
5478 : "__createFromBinding";
5479 else
5480 CreateMethodName = HasCounter
5481 ? "__createFromImplicitBindingWithImplicitCounter"
5482 : "__createFromImplicitBinding";
5483
5484 CreateMethod =
5485 lookupMethod(SemaRef, ResourceDecl, CreateMethodName, VD->getLocation());
5486
5487 if (!CreateMethod) {
5488 // This can happen if someone creates a struct that looks like an HLSL
5489 // resource record but does not have the required static create method.
5490 // No binding will be generated for it.
5491 assert(!ResourceDecl->isImplicit() &&
5492 "create method lookup should always succeed for built-in resource "
5493 "records");
5494 return false;
5495 }
5496
5497 if (Binding.isExplicit()) {
5498 IntegerLiteral *RegSlot =
5499 IntegerLiteral::Create(AST, llvm::APInt(UIntTySize, Binding.getSlot()),
5501 Args.push_back(RegSlot);
5502 } else {
5503 uint32_t OrderID = (Binding.hasImplicitOrderID())
5504 ? Binding.getImplicitOrderID()
5506 IntegerLiteral *OrderId =
5507 IntegerLiteral::Create(AST, llvm::APInt(UIntTySize, OrderID),
5509 Args.push_back(OrderId);
5510 }
5511
5512 IntegerLiteral *Space =
5513 IntegerLiteral::Create(AST, llvm::APInt(UIntTySize, Binding.getSpace()),
5515 Args.push_back(Space);
5516
5518 AST, llvm::APInt(IntTySize, 1), AST.IntTy, SourceLocation());
5519 Args.push_back(RangeSize);
5520
5522 AST, llvm::APInt(UIntTySize, 0), AST.UnsignedIntTy, SourceLocation());
5523 Args.push_back(Index);
5524
5525 StringRef VarName = VD->getName();
5527 AST, VarName, StringLiteralKind::Ordinary, false,
5528 AST.getStringLiteralArrayType(AST.CharTy.withConst(), VarName.size()),
5529 SourceLocation());
5531 AST, AST.getPointerType(AST.CharTy.withConst()), CK_ArrayToPointerDecay,
5532 Name, nullptr, VK_PRValue, FPOptionsOverride());
5533 Args.push_back(NameCast);
5534
5535 if (HasCounter) {
5536 // Will this be in the correct order?
5537 uint32_t CounterOrderID = getNextImplicitBindingOrderID();
5538 IntegerLiteral *CounterId =
5539 IntegerLiteral::Create(AST, llvm::APInt(UIntTySize, CounterOrderID),
5541 Args.push_back(CounterId);
5542 }
5543
5544 // Make sure the create method template is instantiated and emitted.
5545 if (!CreateMethod->isDefined() && CreateMethod->isTemplateInstantiation())
5546 SemaRef.InstantiateFunctionDefinition(VD->getLocation(), CreateMethod,
5547 true);
5548
5549 // Create CallExpr with a call to the static method and set it as the decl
5550 // initialization.
5552 AST, NestedNameSpecifierLoc(), SourceLocation(), CreateMethod, false,
5553 CreateMethod->getNameInfo(), CreateMethod->getType(), VK_PRValue);
5554
5555 auto *ImpCast = ImplicitCastExpr::Create(
5556 AST, AST.getPointerType(CreateMethod->getType()),
5557 CK_FunctionToPointerDecay, DRE, nullptr, VK_PRValue, FPOptionsOverride());
5558
5559 CallExpr *InitExpr =
5560 CallExpr::Create(AST, ImpCast, Args, ResourceTy, VK_PRValue,
5562 VD->setInit(InitExpr);
5564 SemaRef.CheckCompleteVariableDeclaration(VD);
5565 return true;
5566}
5567
5569 assert(VD->getType()->isHLSLResourceRecordArray() &&
5570 "expected array of resource records");
5571
5572 // Individual resources in a resource array are not initialized here. They
5573 // are initialized later on during codegen when the individual resources are
5574 // accessed. Codegen will emit a call to the resource initialization method
5575 // with the specified array index. We need to make sure though that the method
5576 // for the specific resource type is instantiated, so codegen can emit a call
5577 // to it when the array element is accessed.
5578
5579 // Find correct initialization method based on the resource binding
5580 // information.
5581 ASTContext &AST = SemaRef.getASTContext();
5582 QualType ResElementTy = AST.getBaseElementType(VD->getType());
5583 CXXRecordDecl *ResourceDecl = ResElementTy->getAsCXXRecordDecl();
5584 CXXMethodDecl *CreateMethod = nullptr;
5585
5586 bool HasCounter = hasCounterHandle(ResourceDecl);
5587 ResourceBindingAttrs ResourceAttrs(VD);
5588 if (ResourceAttrs.isExplicit())
5589 // Resource has explicit binding.
5590 CreateMethod =
5591 lookupMethod(SemaRef, ResourceDecl,
5592 HasCounter ? "__createFromBindingWithImplicitCounter"
5593 : "__createFromBinding",
5594 VD->getLocation());
5595 else
5596 // Resource has implicit binding.
5597 CreateMethod = lookupMethod(
5598 SemaRef, ResourceDecl,
5599 HasCounter ? "__createFromImplicitBindingWithImplicitCounter"
5600 : "__createFromImplicitBinding",
5601 VD->getLocation());
5602
5603 if (!CreateMethod)
5604 return false;
5605
5606 // Make sure the create method template is instantiated and emitted.
5607 if (!CreateMethod->isDefined() && CreateMethod->isTemplateInstantiation())
5608 SemaRef.InstantiateFunctionDefinition(VD->getLocation(), CreateMethod,
5609 true);
5610 return true;
5611}
5612
5613// Returns true if the initialization has been handled.
5614// Returns false to use default initialization.
5616 // Objects in the hlsl_constant address space are initialized
5617 // externally, so don't synthesize an implicit initializer.
5619 return true;
5620
5621 // Initialize non-static resources at the global scope.
5622 if (VD->hasGlobalStorage() && VD->getStorageClass() != SC_Static) {
5623 const Type *Ty = VD->getType().getTypePtr();
5624 if (Ty->isHLSLResourceRecord())
5625 return initGlobalResourceDecl(VD);
5626 if (Ty->isHLSLResourceRecordArray())
5627 return initGlobalResourceArrayDecl(VD);
5628 }
5629 return false;
5630}
5631
5632std::optional<const DeclBindingInfo *> SemaHLSL::inferGlobalBinding(Expr *E) {
5633 if (auto *Ternary = dyn_cast<ConditionalOperator>(E)) {
5634 auto TrueInfo = inferGlobalBinding(Ternary->getTrueExpr());
5635 auto FalseInfo = inferGlobalBinding(Ternary->getFalseExpr());
5636 if (!TrueInfo || !FalseInfo)
5637 return std::nullopt;
5638 if (*TrueInfo != *FalseInfo)
5639 return std::nullopt;
5640 return TrueInfo;
5641 }
5642
5643 if (auto *ASE = dyn_cast<ArraySubscriptExpr>(E))
5644 E = ASE->getBase()->IgnoreParenImpCasts();
5645
5646 if (DeclRefExpr *DRE = dyn_cast<DeclRefExpr>(E->IgnoreParens()))
5647 if (VarDecl *VD = dyn_cast<VarDecl>(DRE->getDecl())) {
5648 const Type *Ty = VD->getType()->getUnqualifiedDesugaredType();
5649 if (Ty->isArrayType())
5651
5652 if (const auto *AttrResType =
5653 HLSLAttributedResourceType::findHandleTypeOnResource(Ty)) {
5654 ResourceClass RC = AttrResType->getAttrs().ResourceClass;
5655 return Bindings.getDeclBindingInfo(VD, RC);
5656 }
5657 }
5658
5659 return nullptr;
5660}
5661
5662void SemaHLSL::trackLocalResource(VarDecl *VD, Expr *E) {
5663 std::optional<const DeclBindingInfo *> ExprBinding = inferGlobalBinding(E);
5664 if (!ExprBinding) {
5665 SemaRef.Diag(E->getBeginLoc(),
5666 diag::warn_hlsl_assigning_local_resource_is_not_unique)
5667 << E << VD;
5668 return; // Expr use multiple resources
5669 }
5670
5671 if (*ExprBinding == nullptr)
5672 return; // No binding could be inferred to track, return without error
5673
5674 auto PrevBinding = Assigns.find(VD);
5675 if (PrevBinding == Assigns.end()) {
5676 // No previous binding recorded, simply record the new assignment
5677 Assigns.insert({VD, *ExprBinding});
5678 return;
5679 }
5680
5681 // Otherwise, warn if the assignment implies different resource bindings
5682 if (*ExprBinding != PrevBinding->second) {
5683 SemaRef.Diag(E->getBeginLoc(),
5684 diag::warn_hlsl_assigning_local_resource_is_not_unique)
5685 << E << VD;
5686 SemaRef.Diag(VD->getLocation(), diag::note_var_declared_here) << VD;
5687 return;
5688 }
5689
5690 return;
5691}
5692
5694 Expr *RHSExpr, SourceLocation Loc) {
5695 assert((LHSExpr->getType()->isHLSLResourceRecord() ||
5696 LHSExpr->getType()->isHLSLResourceRecordArray()) &&
5697 "expected LHS to be a resource record or array of resource records");
5698 if (Opc != BO_Assign)
5699 return true;
5700
5701 // If LHS is an array subscript, get the underlying declaration.
5702 Expr *E = LHSExpr;
5703 while (auto *ASE = dyn_cast<ArraySubscriptExpr>(E))
5704 E = ASE->getBase()->IgnoreParenImpCasts();
5705
5706 // Report error if LHS is a non-static resource declared at a global scope.
5707 if (DeclRefExpr *DRE = dyn_cast<DeclRefExpr>(E->IgnoreParens())) {
5708 if (VarDecl *VD = dyn_cast<VarDecl>(DRE->getDecl())) {
5709 if (VD->hasGlobalStorage() && VD->getStorageClass() != SC_Static) {
5710 // assignment to global resource is not allowed
5711 SemaRef.Diag(Loc, diag::err_hlsl_assign_to_global_resource) << VD;
5712 SemaRef.Diag(VD->getLocation(), diag::note_var_declared_here) << VD;
5713 return false;
5714 }
5715
5716 trackLocalResource(VD, RHSExpr);
5717 }
5718 }
5719 return true;
5720}
5721
5722// Walks though the global variable declaration, collects all resource binding
5723// requirements and adds them to Bindings
5724void SemaHLSL::collectResourceBindingsOnVarDecl(VarDecl *VD) {
5725 assert(VD->hasGlobalStorage() && VD->getType()->isHLSLIntangibleType() &&
5726 "expected global variable that contains HLSL resource");
5727
5728 // Cbuffers and Tbuffers are HLSLBufferDecl types
5729 if (const HLSLBufferDecl *CBufferOrTBuffer = dyn_cast<HLSLBufferDecl>(VD)) {
5730 Bindings.addDeclBindingInfo(VD, CBufferOrTBuffer->isCBuffer()
5731 ? ResourceClass::CBuffer
5732 : ResourceClass::SRV);
5733 return;
5734 }
5735
5736 // Unwrap arrays
5737 // FIXME: Calculate array size while unwrapping
5738 const Type *Ty = VD->getType()->getUnqualifiedDesugaredType();
5739 while (Ty->isArrayType()) {
5740 const ArrayType *AT = cast<ArrayType>(Ty);
5742 }
5743
5744 // Resource (or array of resources)
5745 if (const HLSLAttributedResourceType *AttrResType =
5746 HLSLAttributedResourceType::findHandleTypeOnResource(Ty)) {
5747 Bindings.addDeclBindingInfo(VD, AttrResType->getAttrs().ResourceClass);
5748 return;
5749 }
5750
5751 // User defined record type
5752 if (const RecordType *RT = dyn_cast<RecordType>(Ty))
5753 collectResourceBindingsOnUserRecordDecl(VD, RT);
5754}
5755
5756// Walks though the explicit resource binding attributes on the declaration,
5757// and makes sure there is a resource that matched the binding and updates
5758// DeclBindingInfoLists
5759void SemaHLSL::processExplicitBindingsOnDecl(VarDecl *VD) {
5760 assert(VD->hasGlobalStorage() && "expected global variable");
5761
5762 bool HasBinding = false;
5763 for (Attr *A : VD->attrs()) {
5764 if (isa<HLSLVkBindingAttr>(A)) {
5765 HasBinding = true;
5766 if (auto PA = VD->getAttr<HLSLVkPushConstantAttr>())
5767 Diag(PA->getLoc(), diag::err_hlsl_attr_incompatible) << A << PA;
5768 }
5769
5770 HLSLResourceBindingAttr *RBA = dyn_cast<HLSLResourceBindingAttr>(A);
5771 if (!RBA || !RBA->hasRegisterSlot())
5772 continue;
5773 HasBinding = true;
5774
5775 RegisterType RT = RBA->getRegisterType();
5776 assert(RT != RegisterType::I && "invalid or obsolete register type should "
5777 "never have an attribute created");
5778
5779 if (RT == RegisterType::C) {
5780 if (Bindings.hasBindingInfoForDecl(VD))
5781 SemaRef.Diag(VD->getLocation(),
5782 diag::warn_hlsl_user_defined_type_missing_member)
5783 << static_cast<int>(RT);
5784 continue;
5785 }
5786
5787 // Find DeclBindingInfo for this binding and update it, or report error
5788 // if it does not exist (user type does to contain resources with the
5789 // expected resource class).
5791 if (DeclBindingInfo *BI = Bindings.getDeclBindingInfo(VD, RC)) {
5792 // update binding info
5793 BI->setBindingAttribute(RBA, BindingType::Explicit);
5794 } else {
5795 SemaRef.Diag(VD->getLocation(),
5796 diag::warn_hlsl_user_defined_type_missing_member)
5797 << static_cast<int>(RT);
5798 }
5799 }
5800
5801 if (!HasBinding && isResourceRecordTypeOrArrayOf(VD))
5802 SemaRef.Diag(VD->getLocation(), diag::warn_hlsl_implicit_binding);
5803}
5804namespace {
5805class InitListTransformer {
5806 Sema &S;
5807 ASTContext &Ctx;
5808 QualType InitTy;
5809 QualType *DstIt = nullptr;
5810 Expr **ArgIt = nullptr;
5811 // Is wrapping the destination type iterator required? This is only used for
5812 // incomplete array types where we loop over the destination type since we
5813 // don't know the full number of elements from the declaration.
5814 bool Wrap;
5815
5816 bool castInitializer(Expr *E) {
5817 assert(DstIt && "This should always be something!");
5818 if (DstIt == DestTypes.end()) {
5819 if (!Wrap) {
5820 ArgExprs.push_back(E);
5821 // This is odd, but it isn't technically a failure due to conversion, we
5822 // handle mismatched counts of arguments differently.
5823 return true;
5824 }
5825 DstIt = DestTypes.begin();
5826 }
5827 InitializedEntity Entity = InitializedEntity::InitializeParameter(
5828 Ctx, *DstIt, /* Consumed (ObjC) */ false);
5829 ExprResult Res = S.PerformCopyInitialization(Entity, E->getBeginLoc(), E);
5830 if (Res.isInvalid())
5831 return false;
5832 Expr *Init = Res.get();
5833 ArgExprs.push_back(Init);
5834 DstIt++;
5835 return true;
5836 }
5837
5838 bool buildInitializerListImpl(Expr *E) {
5839 // If this is an initialization list, traverse the sub initializers.
5840 if (auto *Init = dyn_cast<InitListExpr>(E)) {
5841 for (auto *SubInit : Init->inits())
5842 if (!buildInitializerListImpl(SubInit))
5843 return false;
5844 return true;
5845 }
5846
5847 // If this is a scalar type, just enqueue the expression.
5848 QualType Ty = E->getType().getDesugaredType(Ctx);
5849
5850 if (Ty->isScalarType() || (Ty->isRecordType() && !Ty->isAggregateType()) ||
5852 return castInitializer(E);
5853
5854 // If this is an aggregate type and a prvalue, create an xvalue temporary
5855 // so the member accesses will be xvalues. Wrap it in OpaqueExpr to make
5856 // sure codegen will not generate duplicate copies.
5857 if (E->isPRValue() && Ty->isAggregateType()) {
5859 if (TmpExpr.isInvalid())
5860 return false;
5861 E = TmpExpr.get();
5862 E = new (Ctx) OpaqueValueExpr(E->getBeginLoc(), E->getType(),
5863 E->getValueKind(), E->getObjectKind(), E);
5864 }
5865
5866 if (auto *VecTy = Ty->getAs<VectorType>()) {
5867 uint64_t Size = VecTy->getNumElements();
5868
5869 QualType SizeTy = Ctx.getSizeType();
5870 uint64_t SizeTySize = Ctx.getTypeSize(SizeTy);
5871 for (uint64_t I = 0; I < Size; ++I) {
5872 auto *Idx = IntegerLiteral::Create(Ctx, llvm::APInt(SizeTySize, I),
5873 SizeTy, SourceLocation());
5874
5876 E, E->getBeginLoc(), Idx, E->getEndLoc());
5877 if (ElExpr.isInvalid())
5878 return false;
5879 if (!castInitializer(ElExpr.get()))
5880 return false;
5881 }
5882 return true;
5883 }
5884 if (auto *MTy = Ty->getAs<ConstantMatrixType>()) {
5885 unsigned Rows = MTy->getNumRows();
5886 unsigned Cols = MTy->getNumColumns();
5887 QualType ElemTy = MTy->getElementType();
5888
5889 for (unsigned R = 0; R < Rows; ++R) {
5890 for (unsigned C = 0; C < Cols; ++C) {
5891 // row index literal
5892 Expr *RowIdx = IntegerLiteral::Create(
5893 Ctx, llvm::APInt(Ctx.getIntWidth(Ctx.IntTy), R), Ctx.IntTy,
5894 E->getBeginLoc());
5895 // column index literal
5896 Expr *ColIdx = IntegerLiteral::Create(
5897 Ctx, llvm::APInt(Ctx.getIntWidth(Ctx.IntTy), C), Ctx.IntTy,
5898 E->getBeginLoc());
5900 E, RowIdx, ColIdx, E->getEndLoc());
5901 if (ElExpr.isInvalid())
5902 return false;
5903 if (!castInitializer(ElExpr.get()))
5904 return false;
5905 ElExpr.get()->setType(ElemTy);
5906 }
5907 }
5908 return true;
5909 }
5910
5911 if (auto *ArrTy = dyn_cast<ConstantArrayType>(Ty.getTypePtr())) {
5912 uint64_t Size = ArrTy->getZExtSize();
5913 QualType SizeTy = Ctx.getSizeType();
5914 uint64_t SizeTySize = Ctx.getTypeSize(SizeTy);
5915 for (uint64_t I = 0; I < Size; ++I) {
5916 auto *Idx = IntegerLiteral::Create(Ctx, llvm::APInt(SizeTySize, I),
5917 SizeTy, SourceLocation());
5919 E, E->getBeginLoc(), Idx, E->getEndLoc());
5920 if (ElExpr.isInvalid())
5921 return false;
5922 if (!buildInitializerListImpl(ElExpr.get()))
5923 return false;
5924 }
5925 return true;
5926 }
5927
5928 if (auto *RD = Ty->getAsCXXRecordDecl()) {
5929 llvm::SmallVector<CXXRecordDecl *> RecordDecls;
5930 RecordDecls.push_back(RD);
5931 while (RecordDecls.back()->getNumBases()) {
5932 CXXRecordDecl *D = RecordDecls.back();
5933 assert(D->getNumBases() == 1 &&
5934 "HLSL doesn't support multiple inheritance");
5935 RecordDecls.push_back(
5937 }
5938 while (!RecordDecls.empty()) {
5939 CXXRecordDecl *RD = RecordDecls.pop_back_val();
5940 for (auto *FD : RD->fields()) {
5941 if (FD->isUnnamedBitField())
5942 continue;
5943 DeclAccessPair Found = DeclAccessPair::make(FD, FD->getAccess());
5944 DeclarationNameInfo NameInfo(FD->getDeclName(), E->getBeginLoc());
5946 E, false, E->getBeginLoc(), CXXScopeSpec(), FD, Found, NameInfo);
5947 if (Res.isInvalid())
5948 return false;
5949 if (!buildInitializerListImpl(Res.get()))
5950 return false;
5951 }
5952 }
5953 }
5954 return true;
5955 }
5956
5957 Expr *generateInitListsImpl(QualType Ty) {
5958 Ty = Ty.getDesugaredType(Ctx);
5959 assert(ArgIt != ArgExprs.end() && "Something is off in iteration!");
5960 if (Ty->isScalarType() || (Ty->isRecordType() && !Ty->isAggregateType()) ||
5962 return *(ArgIt++);
5963
5964 llvm::SmallVector<Expr *> Inits;
5965 if (Ty->isVectorType() || Ty->isConstantArrayType() ||
5966 Ty->isConstantMatrixType()) {
5967 QualType ElTy;
5968 uint64_t Size = 0;
5969 if (auto *ATy = Ty->getAs<VectorType>()) {
5970 ElTy = ATy->getElementType();
5971 Size = ATy->getNumElements();
5972 } else if (auto *CMTy = Ty->getAs<ConstantMatrixType>()) {
5973 ElTy = CMTy->getElementType();
5974 Size = CMTy->getNumElementsFlattened();
5975 } else {
5976 auto *VTy = cast<ConstantArrayType>(Ty.getTypePtr());
5977 ElTy = VTy->getElementType();
5978 Size = VTy->getZExtSize();
5979 }
5980 for (uint64_t I = 0; I < Size; ++I)
5981 Inits.push_back(generateInitListsImpl(ElTy));
5982 }
5983 if (auto *RD = Ty->getAsCXXRecordDecl()) {
5984 llvm::SmallVector<CXXRecordDecl *> RecordDecls;
5985 RecordDecls.push_back(RD);
5986 while (RecordDecls.back()->getNumBases()) {
5987 CXXRecordDecl *D = RecordDecls.back();
5988 assert(D->getNumBases() == 1 &&
5989 "HLSL doesn't support multiple inheritance");
5990 RecordDecls.push_back(
5992 }
5993 while (!RecordDecls.empty()) {
5994 CXXRecordDecl *RD = RecordDecls.pop_back_val();
5995 for (auto *FD : RD->fields())
5996 if (!FD->isUnnamedBitField())
5997 Inits.push_back(generateInitListsImpl(FD->getType()));
5998 }
5999 }
6000 auto *NewInit = new (Ctx) InitListExpr(Ctx, Inits.front()->getBeginLoc(),
6001 Inits, Inits.back()->getEndLoc());
6002 NewInit->setType(Ty);
6003 return NewInit;
6004 }
6005
6006public:
6007 llvm::SmallVector<QualType, 16> DestTypes;
6008 llvm::SmallVector<Expr *, 16> ArgExprs;
6009 InitListTransformer(Sema &SemaRef, const InitializedEntity &Entity)
6010 : S(SemaRef), Ctx(SemaRef.getASTContext()),
6011 Wrap(Entity.getType()->isIncompleteArrayType()) {
6012 InitTy = Entity.getType().getNonReferenceType();
6013 // When we're generating initializer lists for incomplete array types we
6014 // need to wrap around both when building the initializers and when
6015 // generating the final initializer lists.
6016 if (Wrap) {
6017 assert(InitTy->isIncompleteArrayType());
6018 const IncompleteArrayType *IAT = Ctx.getAsIncompleteArrayType(InitTy);
6019 InitTy = IAT->getElementType();
6020 }
6021 BuildFlattenedTypeList(InitTy, DestTypes);
6022 DstIt = DestTypes.begin();
6023 }
6024
6025 bool buildInitializerList(Expr *E) { return buildInitializerListImpl(E); }
6026
6027 Expr *generateInitLists() {
6028 assert(!ArgExprs.empty() &&
6029 "Call buildInitializerList to generate argument expressions.");
6030 ArgIt = ArgExprs.begin();
6031 if (!Wrap)
6032 return generateInitListsImpl(InitTy);
6033 llvm::SmallVector<Expr *> Inits;
6034 while (ArgIt != ArgExprs.end())
6035 Inits.push_back(generateInitListsImpl(InitTy));
6036
6037 auto *NewInit = new (Ctx) InitListExpr(Ctx, Inits.front()->getBeginLoc(),
6038 Inits, Inits.back()->getEndLoc());
6039 llvm::APInt ArySize(64, Inits.size());
6040 NewInit->setType(Ctx.getConstantArrayType(InitTy, ArySize, nullptr,
6041 ArraySizeModifier::Normal, 0));
6042 return NewInit;
6043 }
6044};
6045} // namespace
6046
6047// Recursively detect any incomplete array anywhere in the type graph,
6048// including arrays, struct fields, and base classes.
6050 Ty = Ty.getCanonicalType();
6051
6052 // Array types
6053 if (const ArrayType *AT = dyn_cast<ArrayType>(Ty)) {
6055 return true;
6057 }
6058
6059 // Record (struct/class) types
6060 if (const auto *RT = Ty->getAs<RecordType>()) {
6061 const RecordDecl *RD = RT->getDecl();
6062
6063 // Walk base classes (for C++ / HLSL structs with inheritance)
6064 if (const auto *CXXRD = dyn_cast<CXXRecordDecl>(RD)) {
6065 for (const CXXBaseSpecifier &Base : CXXRD->bases()) {
6066 if (containsIncompleteArrayType(Base.getType()))
6067 return true;
6068 }
6069 }
6070
6071 // Walk fields
6072 for (const FieldDecl *F : RD->fields()) {
6073 if (containsIncompleteArrayType(F->getType()))
6074 return true;
6075 }
6076 }
6077
6078 return false;
6079}
6080
6082 InitListExpr *Init) {
6083 // If the initializer is a scalar, just return it.
6084 if (Init->getType()->isScalarType())
6085 return true;
6086 ASTContext &Ctx = SemaRef.getASTContext();
6087 InitListTransformer ILT(SemaRef, Entity);
6088
6089 for (unsigned I = 0; I < Init->getNumInits(); ++I) {
6090 Expr *E = Init->getInit(I);
6091 if (E->HasSideEffects(Ctx)) {
6092 QualType Ty = E->getType();
6093 if (Ty->isRecordType())
6094 E = new (Ctx) MaterializeTemporaryExpr(Ty, E, E->isLValue());
6095 E = new (Ctx) OpaqueValueExpr(E->getBeginLoc(), Ty, E->getValueKind(),
6096 E->getObjectKind(), E);
6097 Init->setInit(I, E);
6098 }
6099 if (!ILT.buildInitializerList(E))
6100 return false;
6101 }
6102 size_t ExpectedSize = ILT.DestTypes.size();
6103 size_t ActualSize = ILT.ArgExprs.size();
6104 if (ExpectedSize == 0 && ActualSize == 0)
6105 return true;
6106
6107 // Reject empty initializer if *any* incomplete array exists structurally
6108 if (ActualSize == 0 && containsIncompleteArrayType(Entity.getType())) {
6109 QualType InitTy = Entity.getType().getNonReferenceType();
6110 if (InitTy.hasAddressSpace())
6111 InitTy = SemaRef.getASTContext().removeAddrSpaceQualType(InitTy);
6112
6113 SemaRef.Diag(Init->getBeginLoc(), diag::err_hlsl_incorrect_num_initializers)
6114 << /*TooManyOrFew=*/(int)(ExpectedSize < ActualSize) << InitTy
6115 << /*ExpectedSize=*/ExpectedSize << /*ActualSize=*/ActualSize;
6116 return false;
6117 }
6118
6119 // We infer size after validating legality.
6120 // For incomplete arrays it is completely arbitrary to choose whether we think
6121 // the user intended fewer or more elements. This implementation assumes that
6122 // the user intended more, and errors that there are too few initializers to
6123 // complete the final element.
6124 if (Entity.getType()->isIncompleteArrayType()) {
6125 assert(ExpectedSize > 0 &&
6126 "The expected size of an incomplete array type must be at least 1.");
6127 ExpectedSize =
6128 ((ActualSize + ExpectedSize - 1) / ExpectedSize) * ExpectedSize;
6129 }
6130
6131 // An initializer list might be attempting to initialize a reference or
6132 // rvalue-reference. When checking the initializer we should look through
6133 // the reference.
6134 QualType InitTy = Entity.getType().getNonReferenceType();
6135 if (InitTy.hasAddressSpace())
6136 InitTy = SemaRef.getASTContext().removeAddrSpaceQualType(InitTy);
6137 if (ExpectedSize != ActualSize) {
6138 int TooManyOrFew = ActualSize > ExpectedSize ? 1 : 0;
6139 SemaRef.Diag(Init->getBeginLoc(), diag::err_hlsl_incorrect_num_initializers)
6140 << TooManyOrFew << InitTy << ExpectedSize << ActualSize;
6141 return false;
6142 }
6143
6144 // generateInitListsImpl will always return an InitListExpr here, because the
6145 // scalar case is handled above.
6146 auto *NewInit = cast<InitListExpr>(ILT.generateInitLists());
6147 Init->resizeInits(Ctx, NewInit->getNumInits());
6148 for (unsigned I = 0; I < NewInit->getNumInits(); ++I)
6149 Init->updateInit(Ctx, I, NewInit->getInit(I));
6150 return true;
6151}
6152
6153static QualType ReportMatrixInvalidMember(Sema &S, StringRef Name,
6154 StringRef Expected,
6155 SourceLocation OpLoc,
6156 SourceLocation CompLoc) {
6157 S.Diag(OpLoc, diag::err_builtin_matrix_invalid_member)
6158 << Name << Expected << SourceRange(CompLoc);
6159 return QualType();
6160}
6161
6164 const IdentifierInfo *CompName,
6165 SourceLocation CompLoc) {
6166 const auto *MT = baseType->castAs<ConstantMatrixType>();
6167 StringRef AccessorName = CompName->getName();
6168 assert(!AccessorName.empty() && "Matrix Accessor must have a name");
6169
6170 unsigned Rows = MT->getNumRows();
6171 unsigned Cols = MT->getNumColumns();
6172 bool IsZeroBasedAccessor = false;
6173 unsigned ChunkLen = 0;
6174 if (AccessorName.size() < 2)
6175 return ReportMatrixInvalidMember(S, AccessorName,
6176 "length 4 for zero based: \'_mRC\' or "
6177 "length 3 for one-based: \'_RC\' accessor",
6178 OpLoc, CompLoc);
6179
6180 if (AccessorName[0] == '_') {
6181 if (AccessorName[1] == 'm') {
6182 IsZeroBasedAccessor = true;
6183 ChunkLen = 4; // zero-based: "_mRC"
6184 } else {
6185 ChunkLen = 3; // one-based: "_RC"
6186 }
6187 } else
6189 S, AccessorName, "zero based: \'_mRC\' or one-based: \'_RC\' accessor",
6190 OpLoc, CompLoc);
6191
6192 if (AccessorName.size() % ChunkLen != 0) {
6193 const llvm::StringRef Expected = IsZeroBasedAccessor
6194 ? "zero based: '_mRC' accessor"
6195 : "one-based: '_RC' accessor";
6196
6197 return ReportMatrixInvalidMember(S, AccessorName, Expected, OpLoc, CompLoc);
6198 }
6199
6200 auto isDigit = [](char c) { return c >= '0' && c <= '9'; };
6201 auto isZeroBasedIndex = [](unsigned i) { return i <= 3; };
6202 auto isOneBasedIndex = [](unsigned i) { return i >= 1 && i <= 4; };
6203
6204 bool HasRepeated = false;
6205 SmallVector<bool, 16> Seen(Rows * Cols, false);
6206 unsigned NumComponents = 0;
6207 const char *Begin = AccessorName.data();
6208
6209 for (unsigned I = 0, E = AccessorName.size(); I < E; I += ChunkLen) {
6210 const char *Chunk = Begin + I;
6211 char RowChar = 0, ColChar = 0;
6212 if (IsZeroBasedAccessor) {
6213 // Zero-based: "_mRC"
6214 if (Chunk[0] != '_' || Chunk[1] != 'm') {
6215 char Bad = (Chunk[0] != '_') ? Chunk[0] : Chunk[1];
6217 S, StringRef(&Bad, 1), "\'_m\' prefix",
6218 OpLoc.getLocWithOffset(I + (Bad == Chunk[0] ? 1 : 2)), CompLoc);
6219 }
6220 RowChar = Chunk[2];
6221 ColChar = Chunk[3];
6222 } else {
6223 // One-based: "_RC"
6224 if (Chunk[0] != '_')
6226 S, StringRef(&Chunk[0], 1), "\'_\' prefix",
6227 OpLoc.getLocWithOffset(I + 1), CompLoc);
6228 RowChar = Chunk[1];
6229 ColChar = Chunk[2];
6230 }
6231
6232 // Must be digits.
6233 bool IsDigitsError = false;
6234 if (!isDigit(RowChar)) {
6235 unsigned BadPos = IsZeroBasedAccessor ? 2 : 1;
6236 ReportMatrixInvalidMember(S, StringRef(&RowChar, 1), "row as integer",
6237 OpLoc.getLocWithOffset(I + BadPos + 1),
6238 CompLoc);
6239 IsDigitsError = true;
6240 }
6241
6242 if (!isDigit(ColChar)) {
6243 unsigned BadPos = IsZeroBasedAccessor ? 3 : 2;
6244 ReportMatrixInvalidMember(S, StringRef(&ColChar, 1), "column as integer",
6245 OpLoc.getLocWithOffset(I + BadPos + 1),
6246 CompLoc);
6247 IsDigitsError = true;
6248 }
6249 if (IsDigitsError)
6250 return QualType();
6251
6252 unsigned Row = RowChar - '0';
6253 unsigned Col = ColChar - '0';
6254
6255 bool HasIndexingError = false;
6256 if (IsZeroBasedAccessor) {
6257 // 0-based [0..3]
6258 if (!isZeroBasedIndex(Row)) {
6259 S.Diag(OpLoc, diag::err_hlsl_matrix_element_not_in_bounds)
6260 << /*row*/ 0 << /*zero-based*/ 0 << SourceRange(CompLoc);
6261 HasIndexingError = true;
6262 }
6263 if (!isZeroBasedIndex(Col)) {
6264 S.Diag(OpLoc, diag::err_hlsl_matrix_element_not_in_bounds)
6265 << /*col*/ 1 << /*zero-based*/ 0 << SourceRange(CompLoc);
6266 HasIndexingError = true;
6267 }
6268 } else {
6269 // 1-based [1..4]
6270 if (!isOneBasedIndex(Row)) {
6271 S.Diag(OpLoc, diag::err_hlsl_matrix_element_not_in_bounds)
6272 << /*row*/ 0 << /*one-based*/ 1 << SourceRange(CompLoc);
6273 HasIndexingError = true;
6274 }
6275 if (!isOneBasedIndex(Col)) {
6276 S.Diag(OpLoc, diag::err_hlsl_matrix_element_not_in_bounds)
6277 << /*col*/ 1 << /*one-based*/ 1 << SourceRange(CompLoc);
6278 HasIndexingError = true;
6279 }
6280 // Convert to 0-based after range checking.
6281 --Row;
6282 --Col;
6283 }
6284
6285 if (HasIndexingError)
6286 return QualType();
6287
6288 // Note: matrix swizzle index is hard coded. That means Row and Col can
6289 // potentially be larger than Rows and Cols if matrix size is less than
6290 // the max index size.
6291 bool HasBoundsError = false;
6292 if (Row >= Rows) {
6293 Diag(OpLoc, diag::err_hlsl_matrix_index_out_of_bounds)
6294 << /*Row*/ 0 << Row << Rows << SourceRange(CompLoc);
6295 HasBoundsError = true;
6296 }
6297 if (Col >= Cols) {
6298 Diag(OpLoc, diag::err_hlsl_matrix_index_out_of_bounds)
6299 << /*Col*/ 1 << Col << Cols << SourceRange(CompLoc);
6300 HasBoundsError = true;
6301 }
6302 if (HasBoundsError)
6303 return QualType();
6304
6305 unsigned FlatIndex = Row * Cols + Col;
6306 if (Seen[FlatIndex])
6307 HasRepeated = true;
6308 Seen[FlatIndex] = true;
6309 ++NumComponents;
6310 }
6311 if (NumComponents == 0 || NumComponents > 4) {
6312 S.Diag(OpLoc, diag::err_hlsl_matrix_swizzle_invalid_length)
6313 << NumComponents << SourceRange(CompLoc);
6314 return QualType();
6315 }
6316
6317 QualType ElemTy = MT->getElementType();
6318 if (NumComponents == 1)
6319 return ElemTy;
6320 QualType VT = S.Context.getExtVectorType(ElemTy, NumComponents);
6321 if (HasRepeated)
6322 VK = VK_PRValue;
6323
6324 for (Sema::ExtVectorDeclsType::iterator
6326 E = S.ExtVectorDecls.end();
6327 I != E; ++I) {
6328 if ((*I)->getUnderlyingType() == VT)
6330 /*Qualifier=*/std::nullopt, *I);
6331 }
6332
6333 return VT;
6334}
6335
6337 // If initializing a local resource, track the resource binding it is using
6338 if (VDecl->getType()->isHLSLResourceRecord() && !VDecl->hasGlobalStorage())
6339 trackLocalResource(VDecl, Init);
6340
6341 const HLSLVkConstantIdAttr *ConstIdAttr =
6342 VDecl->getAttr<HLSLVkConstantIdAttr>();
6343 if (!ConstIdAttr)
6344 return true;
6345
6346 ASTContext &Context = SemaRef.getASTContext();
6347
6348 APValue InitValue;
6349 if (!Init->isCXX11ConstantExpr(Context, &InitValue)) {
6350 Diag(VDecl->getLocation(), diag::err_specialization_const);
6351 VDecl->setInvalidDecl();
6352 return false;
6353 }
6354
6355 Builtin::ID BID =
6357
6358 // Argument 1: The ID from the attribute
6359 int ConstantID = ConstIdAttr->getId();
6360 llvm::APInt IDVal(Context.getIntWidth(Context.IntTy), ConstantID);
6361 Expr *IdExpr = IntegerLiteral::Create(Context, IDVal, Context.IntTy,
6362 ConstIdAttr->getLocation());
6363
6364 SmallVector<Expr *, 2> Args = {IdExpr, Init};
6365 Expr *C = SemaRef.BuildBuiltinCallExpr(Init->getExprLoc(), BID, Args);
6366 if (C->getType()->getCanonicalTypeUnqualified() !=
6368 C = SemaRef
6369 .BuildCStyleCastExpr(SourceLocation(),
6370 Context.getTrivialTypeSourceInfo(
6371 Init->getType(), Init->getExprLoc()),
6372 SourceLocation(), C)
6373 .get();
6374 }
6375 Init = C;
6376 return true;
6377}
6378
6380 SourceLocation NameLoc) {
6381 if (!Template)
6382 return QualType();
6383
6384 DeclContext *DC = Template->getDeclContext();
6385 if (!DC->isNamespace() || !cast<NamespaceDecl>(DC)->getIdentifier() ||
6386 cast<NamespaceDecl>(DC)->getName() != "hlsl")
6387 return QualType();
6388
6389 TemplateParameterList *Params = Template->getTemplateParameters();
6390 if (!Params || Params->size() != 1)
6391 return QualType();
6392
6393 if (!Template->isImplicit())
6394 return QualType();
6395
6396 // We manually extract default arguments here instead of letting
6397 // CheckTemplateIdType handle it. This ensures that for resource types that
6398 // lack a default argument (like Buffer), we return a null QualType, which
6399 // triggers the "requires template arguments" error rather than a less
6400 // descriptive "too few template arguments" error.
6401 TemplateArgumentListInfo TemplateArgs(NameLoc, NameLoc);
6402 for (NamedDecl *P : *Params) {
6403 if (auto *TTP = dyn_cast<TemplateTypeParmDecl>(P)) {
6404 if (TTP->hasDefaultArgument()) {
6405 TemplateArgs.addArgument(TTP->getDefaultArgument());
6406 continue;
6407 }
6408 } else if (auto *NTTP = dyn_cast<NonTypeTemplateParmDecl>(P)) {
6409 if (NTTP->hasDefaultArgument()) {
6410 TemplateArgs.addArgument(NTTP->getDefaultArgument());
6411 continue;
6412 }
6413 } else if (auto *TTPD = dyn_cast<TemplateTemplateParmDecl>(P)) {
6414 if (TTPD->hasDefaultArgument()) {
6415 TemplateArgs.addArgument(TTPD->getDefaultArgument());
6416 continue;
6417 }
6418 }
6419 return QualType();
6420 }
6421
6422 return SemaRef.CheckTemplateIdType(
6424 TemplateArgs, nullptr, /*ForNestedNameSpecifier=*/false);
6425}
Defines the clang::ASTContext interface.
Defines enum values for all the target-independent builtin functions.
llvm::dxil::ResourceClass ResourceClass
Defines the C++ Decl subclasses, other than those for templates (found in DeclTemplate....
TokenType getType() const
Returns the token's type, e.g.
FormatToken * Previous
The previous token in the unwrapped line.
Defines the clang::IdentifierInfo, clang::IdentifierTable, and clang::Selector interfaces.
#define X(type, name)
Definition Value.h:97
Forward-declares and imports various common LLVM datatypes that clang wants to use unqualified.
llvm::SmallVector< std::pair< const MemRegion *, SVal >, 4 > Bindings
static bool CheckArgTypeMatches(Sema *S, Expr *Arg, QualType ExpectedType)
static void BuildFlattenedTypeList(QualType BaseTy, llvm::SmallVectorImpl< QualType > &List)
static bool CheckUnsignedIntRepresentation(Sema *S, SourceLocation Loc, int ArgOrdinal, clang::QualType PassedType)
static bool containsIncompleteArrayType(QualType Ty)
static QualType handleIntegerVectorBinOpConversion(Sema &SemaRef, ExprResult &LHS, ExprResult &RHS, QualType LHSType, QualType RHSType, QualType LElTy, QualType RElTy, bool IsCompAssign)
static bool convertToRegisterType(StringRef Slot, RegisterType *RT)
Definition SemaHLSL.cpp:95
static StringRef createRegisterString(ASTContext &AST, RegisterType RegType, unsigned N)
Definition SemaHLSL.cpp:197
static bool CheckWaveActive(Sema *S, CallExpr *TheCall)
static void createHostLayoutStructForBuffer(Sema &S, HLSLBufferDecl *BufDecl)
Definition SemaHLSL.cpp:622
static void castVector(Sema &S, ExprResult &E, QualType &Ty, unsigned Sz)
static QualType ReportMatrixInvalidMember(Sema &S, StringRef Name, StringRef Expected, SourceLocation OpLoc, SourceLocation CompLoc)
static bool CheckBoolSelect(Sema *S, CallExpr *TheCall)
static unsigned calculateLegacyCbufferFieldAlign(const ASTContext &Context, QualType T)
Definition SemaHLSL.cpp:259
static bool isZeroSizedArray(const ConstantArrayType *CAT)
Definition SemaHLSL.cpp:378
static bool DiagnoseHLSLRegisterAttribute(Sema &S, SourceLocation &ArgLoc, Decl *D, RegisterType RegType, bool SpecifiedSpace)
static bool hasConstantBufferLayout(QualType QT)
static FieldDecl * createFieldForHostLayoutStruct(Sema &S, const Type *Ty, IdentifierInfo *II, CXXRecordDecl *LayoutStruct)
Definition SemaHLSL.cpp:530
static bool CheckUnsignedIntVecRepresentation(Sema *S, SourceLocation Loc, int ArgOrdinal, clang::QualType PassedType)
SampleKind
static bool isInvalidConstantBufferLeafElementType(const Type *Ty)
Definition SemaHLSL.cpp:412
static bool CheckCalculateLodBuiltin(Sema &S, CallExpr *TheCall)
static Builtin::ID getSpecConstBuiltinId(const Type *Type)
Definition SemaHLSL.cpp:163
static bool CheckFloatingOrIntRepresentation(Sema *S, SourceLocation Loc, int ArgOrdinal, clang::QualType PassedType)
static const Type * createHostLayoutType(Sema &S, const Type *Ty)
Definition SemaHLSL.cpp:503
static bool CheckAnyScalarOrVector(Sema *S, CallExpr *TheCall, unsigned ArgIndex)
static const HLSLAttributedResourceType * getResourceArrayHandleType(QualType QT)
Definition SemaHLSL.cpp:394
static IdentifierInfo * getHostLayoutStructName(Sema &S, NamedDecl *BaseDecl, bool MustBeUnique)
Definition SemaHLSL.cpp:468
static void addImplicitBindingAttrToDecl(Sema &S, Decl *D, RegisterType RT, uint32_t ImplicitBindingOrderID)
Definition SemaHLSL.cpp:666
static void SetElementTypeAsReturnType(Sema *S, CallExpr *TheCall, QualType ReturnType)
static unsigned calculateLegacyCbufferSize(const ASTContext &Context, QualType T)
Definition SemaHLSL.cpp:278
static bool CheckLoadLevelBuiltin(Sema &S, CallExpr *TheCall)
static RegisterType getRegisterType(ResourceClass RC)
Definition SemaHLSL.cpp:62
static bool ValidateRegisterNumber(uint64_t SlotNum, Decl *TheDecl, ASTContext &Ctx, RegisterType RegTy)
static bool isVkPipelineBuiltin(const ASTContext &AstContext, FunctionDecl *FD, HLSLAppliedSemanticAttr *Semantic, bool IsInput)
Definition SemaHLSL.cpp:854
static bool CheckVectorElementCount(Sema *S, QualType PassedType, QualType BaseType, unsigned ExpectedCount, SourceLocation Loc)
static bool CheckModifiableLValue(Sema *S, CallExpr *TheCall, unsigned ArgIndex)
static QualType castElement(Sema &S, ExprResult &E, QualType Ty)
static char getRegisterTypeChar(RegisterType RT)
Definition SemaHLSL.cpp:127
static bool CheckNotBoolScalarOrVector(Sema *S, CallExpr *TheCall, unsigned ArgIndex)
static bool isMatrixOrArrayOfMatrix(const ASTContext &Ctx, QualType QT)
static CXXRecordDecl * findRecordDeclInContext(IdentifierInfo *II, DeclContext *DC)
Definition SemaHLSL.cpp:451
static bool CheckWavePrefix(Sema *S, CallExpr *TheCall)
static bool CheckExpectedBitWidth(Sema *S, CallExpr *TheCall, unsigned ArgOrdinal, unsigned Width)
static LangAS getLangASFromResourceClass(ResourceClass RC)
Definition SemaHLSL.cpp:80
static bool diagnoseMatrixLayoutOnNonMatrix(Sema &SemaRef, Decl *D, SourceLocation Loc, const IdentifierInfo *AttrName)
static bool CheckVectorSelect(Sema *S, CallExpr *TheCall)
static QualType handleFloatVectorBinOpConversion(Sema &SemaRef, ExprResult &LHS, ExprResult &RHS, QualType LHSType, QualType RHSType, QualType LElTy, QualType RElTy, bool IsCompAssign)
static ResourceClass getResourceClass(RegisterType RT)
Definition SemaHLSL.cpp:145
static CXXRecordDecl * createHostLayoutStruct(Sema &S, CXXRecordDecl *StructDecl)
Definition SemaHLSL.cpp:557
static bool CheckScalarOrVector(Sema *S, CallExpr *TheCall, QualType Scalar, unsigned ArgIndex)
static bool CheckSamplingBuiltin(Sema &S, CallExpr *TheCall, SampleKind Kind)
static bool CheckScalarOrVectorOrMatrix(Sema *S, CallExpr *TheCall, QualType Scalar, unsigned ArgIndex)
static bool CheckFloatRepresentation(Sema *S, SourceLocation Loc, int ArgOrdinal, clang::QualType PassedType)
static bool CheckAnyDoubleRepresentation(Sema *S, SourceLocation Loc, int ArgOrdinal, clang::QualType PassedType)
static bool requiresImplicitBufferLayoutStructure(const CXXRecordDecl *RD)
Definition SemaHLSL.cpp:431
static bool CheckResourceHandle(Sema *S, CallExpr *TheCall, unsigned ArgIndex, llvm::function_ref< bool(const HLSLAttributedResourceType *ResType)> Check=nullptr)
static void validatePackoffset(Sema &S, HLSLBufferDecl *BufDecl)
Definition SemaHLSL.cpp:325
static bool IsDefaultBufferConstantDecl(const ASTContext &Ctx, VarDecl *VD)
HLSLResourceBindingAttr::RegisterType RegisterType
Definition SemaHLSL.cpp:57
static CastKind getScalarCastKind(ASTContext &Ctx, QualType DestTy, QualType SrcTy)
static bool CheckGatherBuiltin(Sema &S, CallExpr *TheCall, bool IsCmp)
static bool isValidWaveSizeValue(unsigned Value)
static bool isResourceRecordTypeOrArrayOf(QualType Ty)
Definition SemaHLSL.cpp:385
static bool AccumulateHLSLResourceSlots(QualType Ty, uint64_t &StartSlot, const uint64_t &Limit, const ResourceClass ResClass, ASTContext &Ctx, uint64_t ArrayCount=1)
static bool CheckNoDoubleVectors(Sema *S, SourceLocation Loc, int ArgOrdinal, clang::QualType PassedType)
static bool ValidateMultipleRegisterAnnotations(Sema &S, Decl *TheDecl, RegisterType regType)
static bool CheckTextureSamplerAndLocation(Sema &S, CallExpr *TheCall)
static bool DiagnoseLocalRegisterBinding(Sema &S, SourceLocation &ArgLoc, Decl *D, RegisterType RegType, bool SpecifiedSpace)
static bool CheckIndexType(Sema *S, CallExpr *TheCall, unsigned IndexArgIndex)
This file declares semantic analysis for HLSL constructs.
Defines the clang::SourceLocation class and associated facilities.
Defines various enumerations that describe declaration and type specifiers.
C Language Family Type Representation.
Defines the clang::TypeLoc interface and its subclasses.
C Language Family Type Representation.
static const TypeInfo & getInfo(unsigned id)
Definition Types.cpp:44
__device__ __2f16 float c
return(__x > > __y)|(__x<<(32 - __y))
APValue - This class implements a discriminated union of [uninitialized] [APSInt] [APFloat],...
Definition APValue.h:122
virtual bool HandleTopLevelDecl(DeclGroupRef D)
HandleTopLevelDecl - Handle the specified top-level declaration.
Holds long-lived AST nodes (such as types and decls) that can be referred to throughout the semantic ...
Definition ASTContext.h:227
unsigned getIntWidth(QualType T) const
int getIntegerTypeOrder(QualType LHS, QualType RHS) const
Return the highest ranked integer type, see C99 6.3.1.8p1.
CanQualType FloatTy
QualType getPointerType(QualType T) const
Return the uniqued reference to the type for a pointer to the specified type.
const IncompleteArrayType * getAsIncompleteArrayType(QualType T) const
IdentifierTable & Idents
Definition ASTContext.h:805
QualType getConstantArrayType(QualType EltTy, const llvm::APInt &ArySize, const Expr *SizeExpr, ArraySizeModifier ASM, unsigned IndexTypeQuals) const
Return the unique reference to the type for a constant array of the specified element type.
QualType getBaseElementType(const ArrayType *VAT) const
Return the innermost element type of an array type.
int getFloatingTypeOrder(QualType LHS, QualType RHS) const
Compare the rank of the two specified floating point types, ignoring the domain of the type (i....
CanQualType BoolTy
TypeSourceInfo * getTrivialTypeSourceInfo(QualType T, SourceLocation Loc=SourceLocation()) const
Allocate a TypeSourceInfo where all locations have been initialized to a given location,...
QualType getStringLiteralArrayType(QualType EltTy, unsigned Length) const
Return a type for a constant array for a string literal of the specified element type and length.
CanQualType CharTy
CanQualType IntTy
uint64_t getTypeSize(QualType T) const
Return the size of the specified (complete) type T, in bits.
CharUnits getTypeSizeInChars(QualType T) const
Return the size of the specified (complete) type T, in characters.
CanQualType UnsignedIntTy
QualType getTypedefType(ElaboratedTypeKeyword Keyword, NestedNameSpecifier Qualifier, const TypedefNameDecl *Decl, QualType UnderlyingType=QualType(), std::optional< bool > TypeMatchesDeclOrNone=std::nullopt) const
Return the unique reference to the type for the specified typedef-name decl.
llvm::StringRef backupStr(llvm::StringRef S) const
Definition ASTContext.h:887
QualType getSizeType() const
Return the unique type for "size_t" (C99 7.17), defined in <stddef.h>.
QualType getExtVectorType(QualType VectorType, unsigned NumElts) const
Return the unique reference to an extended vector type of the specified element type and size.
const TargetInfo & getTargetInfo() const
Definition ASTContext.h:924
QualType getHLSLAttributedResourceType(QualType Wrapped, QualType Contained, const HLSLAttributedResourceType::Attributes &Attrs)
QualType getAddrSpaceQualType(QualType T, LangAS AddressSpace) const
Return the uniqued reference to the type for an address space qualified type with the specified type ...
CanQualType getCanonicalTagType(const TagDecl *TD) const
static bool hasSameUnqualifiedType(QualType T1, QualType T2)
Determine whether the given types are equivalent after cvr-qualifiers have been removed.
QualType getConstantMatrixType(QualType ElementType, unsigned NumRows, unsigned NumColumns) const
Return the unique reference to the matrix type of the specified element type and size.
unsigned getTypeAlign(QualType T) const
Return the ABI-specified alignment of a (complete) type T, in bits.
PtrTy get() const
Definition Ownership.h:171
bool isInvalid() const
Definition Ownership.h:167
Represents an array type, per C99 6.7.5.2 - Array Declarators.
Definition TypeBase.h:3777
QualType getElementType() const
Definition TypeBase.h:3789
Attr - This represents one attribute.
Definition Attr.h:46
attr::Kind getKind() const
Definition Attr.h:92
SourceLocation getLocation() const
Definition Attr.h:99
SourceLocation getScopeLoc() const
const IdentifierInfo * getScopeName() const
SourceLocation getLoc() const
const IdentifierInfo * getAttrName() const
Represents a base class of a C++ class.
Definition DeclCXX.h:146
QualType getType() const
Retrieves the type of the base class.
Definition DeclCXX.h:249
Represents a static or instance method of a struct/union/class.
Definition DeclCXX.h:2132
Represents a C++ struct/union/class.
Definition DeclCXX.h:258
bool isHLSLIntangible() const
Returns true if the class contains HLSL intangible type, either as a field or in base class.
Definition DeclCXX.h:1556
static CXXRecordDecl * Create(const ASTContext &C, TagKind TK, DeclContext *DC, SourceLocation StartLoc, SourceLocation IdLoc, IdentifierInfo *Id, CXXRecordDecl *PrevDecl=nullptr)
Definition DeclCXX.cpp:132
void setBases(CXXBaseSpecifier const *const *Bases, unsigned NumBases)
Sets the base classes of this struct or class.
Definition DeclCXX.cpp:184
base_class_iterator bases_end()
Definition DeclCXX.h:617
void completeDefinition() override
Indicates that the definition of this class is now complete.
Definition DeclCXX.cpp:2249
base_class_range bases()
Definition DeclCXX.h:608
unsigned getNumBases() const
Retrieves the number of base classes of this class.
Definition DeclCXX.h:602
base_class_iterator bases_begin()
Definition DeclCXX.h:615
bool isEmpty() const
Determine whether this is an empty class in the sense of (C++11 [meta.unary.prop]).
Definition DeclCXX.h:1186
CallExpr - Represents a function call (C99 6.5.2.2, C++ [expr.call]).
Definition Expr.h:2946
Expr * getArg(unsigned Arg)
getArg - Return the specified argument.
Definition Expr.h:3150
SourceLocation getBeginLoc() const
Definition Expr.h:3280
static CallExpr * Create(const ASTContext &Ctx, Expr *Fn, ArrayRef< Expr * > Args, QualType Ty, ExprValueKind VK, SourceLocation RParenLoc, FPOptionsOverride FPFeatures, unsigned MinNumArgs=0, ADLCallKind UsesADL=NotADL)
Create a call expression.
Definition Expr.cpp:1517
FunctionDecl * getDirectCallee()
If the callee is a FunctionDecl, return it. Otherwise return null.
Definition Expr.h:3129
Expr * getCallee()
Definition Expr.h:3093
unsigned getNumArgs() const
getNumArgs - Return the number of actual arguments to this call.
Definition Expr.h:3137
Decl * getCalleeDecl()
Definition Expr.h:3123
static CanQual< Type > CreateUnsafe(QualType Other)
QualType withConst() const
Retrieves a version of this type with const applied.
const T * getTypePtr() const
Retrieve the underlying type pointer, which refers to a canonical type.
QuantityType getQuantity() const
getQuantity - Get the raw integer representation of this quantity.
Definition CharUnits.h:185
Represents the canonical version of C arrays with a specified constant size.
Definition TypeBase.h:3815
bool isZeroSize() const
Return true if the size is zero.
Definition TypeBase.h:3885
llvm::APInt getSize() const
Return the constant array size as an APInt.
Definition TypeBase.h:3871
uint64_t getZExtSize() const
Return the size zero-extended as a uint64_t.
Definition TypeBase.h:3891
Represents a concrete matrix type with constant number of rows and columns.
Definition TypeBase.h:4442
unsigned getNumColumns() const
Returns the number of columns in the matrix.
Definition TypeBase.h:4461
static DeclAccessPair make(NamedDecl *D, AccessSpecifier AS)
DeclContext - This is used only as base class of specific decl types that can act as declaration cont...
Definition DeclBase.h:1462
bool isNamespace() const
Definition DeclBase.h:2211
lookup_result lookup(DeclarationName Name) const
lookup - Find the declarations (if any) with the given Name in this context.
bool isTranslationUnit() const
Definition DeclBase.h:2198
void addDecl(Decl *D)
Add the declaration D into this context.
decl_range decls() const
decls_begin/decls_end - Iterate over the declarations stored in this context.
Definition DeclBase.h:2386
DeclContext * getNonTransparentContext()
A reference to a declared variable, function, enum, etc.
Definition Expr.h:1273
static DeclRefExpr * Create(const ASTContext &Context, NestedNameSpecifierLoc QualifierLoc, SourceLocation TemplateKWLoc, ValueDecl *D, bool RefersToEnclosingVariableOrCapture, SourceLocation NameLoc, QualType T, ExprValueKind VK, NamedDecl *FoundD=nullptr, const TemplateArgumentListInfo *TemplateArgs=nullptr, NonOdrUseReason NOUR=NOUR_None)
Definition Expr.cpp:488
ValueDecl * getDecl()
Definition Expr.h:1341
Decl - This represents one declaration (or definition), e.g.
Definition DeclBase.h:86
T * getAttr() const
Definition DeclBase.h:581
ASTContext & getASTContext() const LLVM_READONLY
Definition DeclBase.cpp:547
void addAttr(Attr *A)
attr_iterator attr_end() const
Definition DeclBase.h:550
bool isImplicit() const
isImplicit - Indicates whether the declaration was implicitly generated by the implementation.
Definition DeclBase.h:601
void setInvalidDecl(bool Invalid=true)
setInvalidDecl - Indicates the Decl had a semantic error.
Definition DeclBase.cpp:178
bool isInExportDeclContext() const
Whether this declaration was exported in a lexical context.
attr_iterator attr_begin() const
Definition DeclBase.h:547
DeclContext * getNonTransparentDeclContext()
Return the non transparent context.
SourceLocation getLocation() const
Definition DeclBase.h:447
void setImplicit(bool I=true)
Definition DeclBase.h:602
DeclContext * getDeclContext()
Definition DeclBase.h:456
attr_range attrs() const
Definition DeclBase.h:543
AccessSpecifier getAccess() const
Definition DeclBase.h:515
SourceLocation getBeginLoc() const LLVM_READONLY
Definition DeclBase.h:439
void dropAttr()
Definition DeclBase.h:564
bool hasAttr() const
Definition DeclBase.h:585
The name of a declaration.
Represents a ValueDecl that came out of a declarator.
Definition Decl.h:780
SourceLocation getBeginLoc() const LLVM_READONLY
Definition Decl.h:831
This represents one expression.
Definition Expr.h:112
bool isIntegerConstantExpr(const ASTContext &Ctx) const
void setType(QualType t)
Definition Expr.h:145
ExprValueKind getValueKind() const
getValueKind - The value kind that this expression produces.
Definition Expr.h:447
Expr * IgnoreParenImpCasts() LLVM_READONLY
Skip past any parentheses and implicit casts which might surround this expression until reaching a fi...
Definition Expr.cpp:3090
Expr * IgnoreParens() LLVM_READONLY
Skip past any parentheses which might surround this expression until reaching a fixed point.
Definition Expr.cpp:3086
std::optional< llvm::APSInt > getIntegerConstantExpr(const ASTContext &Ctx) const
isIntegerConstantExpr - Return the value if this expression is a valid integer constant expression.
bool isPRValue() const
Definition Expr.h:285
bool isLValue() const
isLValue - True if this expression is an "l-value" according to the rules of the current language.
Definition Expr.h:284
ExprObjectKind getObjectKind() const
getObjectKind - The object kind that this expression produces.
Definition Expr.h:454
bool HasSideEffects(const ASTContext &Ctx, bool IncludePossibleEffects=true) const
HasSideEffects - This routine returns true for all those expressions which have any effect other than...
Definition Expr.cpp:3688
SourceLocation getExprLoc() const LLVM_READONLY
getExprLoc - Return the preferred location for the arrow when diagnosing a problem with a generic exp...
Definition Expr.cpp:277
@ MLV_Valid
Definition Expr.h:306
QualType getType() const
Definition Expr.h:144
ExtVectorType - Extended vector type.
Definition TypeBase.h:4322
Represents difference between two FPOptions values.
Represents a member of a struct/union/class.
Definition Decl.h:3178
static FieldDecl * Create(const ASTContext &C, DeclContext *DC, SourceLocation StartLoc, SourceLocation IdLoc, const IdentifierInfo *Id, QualType T, TypeSourceInfo *TInfo, Expr *BW, bool Mutable, InClassInitStyle InitStyle)
Definition Decl.cpp:4695
static FixItHint CreateReplacement(CharSourceRange RemoveRange, StringRef Code)
Create a code modification hint that replaces the given source range with the given code string.
Definition Diagnostic.h:141
Represents a function declaration or definition.
Definition Decl.h:2018
const ParmVarDecl * getParamDecl(unsigned i) const
Definition Decl.h:2815
Stmt * getBody(const FunctionDecl *&Definition) const
Retrieve the body (definition) of the function.
Definition Decl.cpp:3274
bool isThisDeclarationADefinition() const
Returns whether this specific declaration of the function is also a definition that does not contain ...
Definition Decl.h:2332
QualType getReturnType() const
Definition Decl.h:2863
ArrayRef< ParmVarDecl * > parameters() const
Definition Decl.h:2792
bool isTemplateInstantiation() const
Determines if the given function was instantiated from a function template.
Definition Decl.cpp:4252
redecl_range redecls() const
Returns an iterator range for all the redeclarations of the same decl.
unsigned getNumParams() const
Return the number of parameters this function must have based on its FunctionType.
Definition Decl.cpp:3821
DeclarationNameInfo getNameInfo() const
Definition Decl.h:2229
bool hasBody(const FunctionDecl *&Definition) const
Returns true if the function has a body.
Definition Decl.cpp:3194
bool isDefined(const FunctionDecl *&Definition, bool CheckForPendingFriendDefinition=false) const
Returns true if the function has a definition that does not need to be instantiated.
Definition Decl.cpp:3241
Represents a prototype with parameter type info, e.g.
Definition TypeBase.h:5362
HLSLBufferDecl - Represent a cbuffer or tbuffer declaration.
Definition Decl.h:5212
static HLSLBufferDecl * Create(ASTContext &C, DeclContext *LexicalParent, bool CBuffer, SourceLocation KwLoc, IdentifierInfo *ID, SourceLocation IDLoc, SourceLocation LBrace)
Definition Decl.cpp:5905
void addLayoutStruct(CXXRecordDecl *LS)
Definition Decl.cpp:5945
void setHasValidPackoffset(bool PO)
Definition Decl.h:5257
static HLSLBufferDecl * CreateDefaultCBuffer(ASTContext &C, DeclContext *LexicalParent, ArrayRef< Decl * > DefaultCBufferDecls)
Definition Decl.cpp:5928
buffer_decl_range buffer_decls() const
Definition Decl.h:5287
static HLSLOutArgExpr * Create(const ASTContext &C, QualType Ty, OpaqueValueExpr *Base, OpaqueValueExpr *OpV, Expr *WB, bool IsInOut)
Definition Expr.cpp:5645
static HLSLRootSignatureDecl * Create(ASTContext &C, DeclContext *DC, SourceLocation Loc, IdentifierInfo *ID, llvm::dxbc::RootSignatureVersion Version, ArrayRef< llvm::hlsl::rootsig::RootElement > RootElements)
Definition Decl.cpp:5991
One of these records is kept for each identifier that is lexed.
StringRef getName() const
Return the actual identifier string.
A simple pair of identifier info and location.
SourceLocation getLoc() const
IdentifierInfo * getIdentifierInfo() const
IdentifierInfo & get(StringRef Name)
Return the identifier token info for the specified named identifier.
ImplicitCastExpr - Allows us to explicitly represent implicit type conversions, which have no direct ...
Definition Expr.h:3856
static ImplicitCastExpr * Create(const ASTContext &Context, QualType T, CastKind Kind, Expr *Operand, const CXXCastPath *BasePath, ExprValueKind Cat, FPOptionsOverride FPO)
Definition Expr.cpp:2073
Describes an C or C++ initializer list.
Definition Expr.h:5302
Describes an entity that is being initialized.
QualType getType() const
Retrieve type being initialized.
static InitializedEntity InitializeParameter(ASTContext &Context, ParmVarDecl *Parm)
Create the initialization entity for a parameter.
static IntegerLiteral * Create(const ASTContext &C, const llvm::APInt &V, QualType type, SourceLocation l)
Returns a new integer literal with value 'V' and type 'type'.
Definition Expr.cpp:975
iterator begin(Source *source, bool LocalOnly=false)
Represents the results of name lookup.
Definition Lookup.h:147
Represents a prvalue temporary that is written into memory so that a reference can bind to it.
Definition ExprCXX.h:4920
Represents a matrix type, as defined in the Matrix Types clang extensions.
Definition TypeBase.h:4392
MemberExpr - [C99 6.5.2.3] Structure and Union Members.
Definition Expr.h:3367
ValueDecl * getMemberDecl() const
Retrieve the member declaration to which this expression refers.
Definition Expr.h:3450
Expr * getBase() const
Definition Expr.h:3444
This represents a decl that may have a name.
Definition Decl.h:274
NamedDecl * getUnderlyingDecl()
Looks through UsingDecls and ObjCCompatibleAliasDecls for the underlying named decl.
Definition Decl.h:487
IdentifierInfo * getIdentifier() const
Get the identifier that names this declaration, if there is one.
Definition Decl.h:295
StringRef getName() const
Get the name of identifier for this declaration as a StringRef.
Definition Decl.h:301
DeclarationName getDeclName() const
Get the actual, stored name of the declaration, which may be a special name.
Definition Decl.h:340
A C++ nested-name-specifier augmented with source location information.
OpaqueValueExpr - An expression referring to an opaque object of a fixed type and value class.
Definition Expr.h:1181
Represents a parameter to a function.
Definition Decl.h:1808
ParsedAttr - Represents a syntactic attribute.
Definition ParsedAttr.h:119
unsigned getSemanticSpelling() const
If the parsed attribute has a semantic equivalent, and it would have a semantic Spelling enumeration ...
unsigned getMinArgs() const
bool checkExactlyNumArgs(class Sema &S, unsigned Num) const
Check if the attribute has exactly as many args as Num.
IdentifierLoc * getArgAsIdent(unsigned Arg) const
Definition ParsedAttr.h:389
bool hasParsedType() const
Definition ParsedAttr.h:337
const ParsedType & getTypeArg() const
Definition ParsedAttr.h:459
unsigned getNumArgs() const
getNumArgs - Return the number of actual arguments to this attribute.
Definition ParsedAttr.h:371
bool isArgIdent(unsigned Arg) const
Definition ParsedAttr.h:385
Expr * getArgAsExpr(unsigned Arg) const
Definition ParsedAttr.h:383
AttributeCommonInfo::Kind getKind() const
Definition ParsedAttr.h:610
A (possibly-)qualified type.
Definition TypeBase.h:937
void addRestrict()
Add the restrict qualifier to this QualType.
Definition TypeBase.h:1183
QualType getNonLValueExprType(const ASTContext &Context) const
Determine the type of a (typically non-lvalue) expression with the specified result type.
Definition Type.cpp:3677
QualType getDesugaredType(const ASTContext &Context) const
Return the specified type with any "sugar" removed from the type.
Definition TypeBase.h:1307
bool isNull() const
Return true if this QualType doesn't point to a type yet.
Definition TypeBase.h:1004
const Type * getTypePtr() const
Retrieves a pointer to the underlying (unqualified) type.
Definition TypeBase.h:8436
LangAS getAddressSpace() const
Return the address space of this type.
Definition TypeBase.h:8562
QualType getNonReferenceType() const
If Type is a reference type (e.g., const int&), returns the type that the reference refers to ("const...
Definition TypeBase.h:8621
QualType getCanonicalType() const
Definition TypeBase.h:8488
QualType getUnqualifiedType() const
Retrieve the unqualified variant of the given type, removing as little sugar as possible.
Definition TypeBase.h:8530
bool hasAddressSpace() const
Check if this type has any address space qualifier.
Definition TypeBase.h:8557
Represents a struct/union/class.
Definition Decl.h:4343
field_range fields() const
Definition Decl.h:4546
bool field_empty() const
Definition Decl.h:4554
bool hasBindingInfoForDecl(const VarDecl *VD) const
Definition SemaHLSL.cpp:233
DeclBindingInfo * getDeclBindingInfo(const VarDecl *VD, ResourceClass ResClass)
Definition SemaHLSL.cpp:219
DeclBindingInfo * addDeclBindingInfo(const VarDecl *VD, ResourceClass ResClass)
Definition SemaHLSL.cpp:206
Scope - A scope is a transient data structure that is used while parsing the program.
Definition Scope.h:41
SemaBase(Sema &S)
Definition SemaBase.cpp:7
ASTContext & getASTContext() const
Definition SemaBase.cpp:9
Sema & SemaRef
Definition SemaBase.h:40
SemaDiagnosticBuilder Diag(SourceLocation Loc, unsigned DiagID)
Emit a diagnostic.
Definition SemaBase.cpp:61
ExprResult ActOnOutParamExpr(ParmVarDecl *Param, Expr *Arg)
HLSLRootSignatureDecl * lookupRootSignatureOverrideDecl(DeclContext *DC) const
bool CanPerformElementwiseCast(Expr *Src, QualType DestType)
void handleWaveSizeAttr(Decl *D, const ParsedAttr &AL)
void handleVkLocationAttr(Decl *D, const ParsedAttr &AL)
HLSLAttributedResourceLocInfo TakeLocForHLSLAttribute(const HLSLAttributedResourceType *RT)
void handleSemanticAttr(Decl *D, const ParsedAttr &AL)
bool CanPerformScalarCast(QualType SrcTy, QualType DestTy)
QualType ProcessResourceTypeAttributes(QualType Wrapped)
void handleShaderAttr(Decl *D, const ParsedAttr &AL)
uint32_t getNextImplicitBindingOrderID()
Definition SemaHLSL.h:244
void CheckEntryPoint(FunctionDecl *FD)
Definition SemaHLSL.cpp:973
void handleVkExtBuiltinOutputAttr(Decl *D, const ParsedAttr &AL)
void emitLogicalOperatorFixIt(Expr *LHS, Expr *RHS, BinaryOperatorKind Opc)
T * createSemanticAttr(const AttributeCommonInfo &ACI, std::optional< unsigned > Location)
Definition SemaHLSL.h:196
bool initGlobalResourceDecl(VarDecl *VD)
void ActOnEndOfTranslationUnit(TranslationUnitDecl *TU)
bool initGlobalResourceArrayDecl(VarDecl *VD)
HLSLVkConstantIdAttr * mergeVkConstantIdAttr(Decl *D, const AttributeCommonInfo &AL, int Id)
Definition SemaHLSL.cpp:737
HLSLNumThreadsAttr * mergeNumThreadsAttr(Decl *D, const AttributeCommonInfo &AL, int X, int Y, int Z)
Definition SemaHLSL.cpp:703
void deduceAddressSpace(VarDecl *Decl)
std::pair< IdentifierInfo *, bool > ActOnStartRootSignatureDecl(StringRef Signature)
Computes the unique Root Signature identifier from the given signature, then lookup if there is a pre...
void handlePackOffsetAttr(Decl *D, const ParsedAttr &AL)
std::optional< ExprResult > tryPerformConstantBufferConversion(ExprResult &BaseExpr)
bool diagnosePositionType(QualType T, const ParsedAttr &AL)
bool handleInitialization(VarDecl *VDecl, Expr *&Init)
bool diagnoseInputIDType(QualType T, const ParsedAttr &AL)
void handleParamModifierAttr(Decl *D, const ParsedAttr &AL)
bool CheckResourceBinOp(BinaryOperatorKind Opc, Expr *LHSExpr, Expr *RHSExpr, SourceLocation Loc)
bool CanPerformAggregateSplatCast(Expr *Src, QualType DestType)
bool ActOnResourceMemberAccessExpr(MemberExpr *ME)
bool IsScalarizedLayoutCompatible(QualType T1, QualType T2) const
QualType ActOnTemplateShorthand(TemplateDecl *Template, SourceLocation NameLoc)
void diagnoseSystemSemanticAttr(Decl *D, const ParsedAttr &AL, std::optional< unsigned > Index)
void handleRootSignatureAttr(Decl *D, const ParsedAttr &AL)
bool CheckCompatibleParameterABI(FunctionDecl *New, FunctionDecl *Old)
QualType handleVectorBinOpConversion(ExprResult &LHS, ExprResult &RHS, QualType LHSType, QualType RHSType, bool IsCompAssign)
QualType checkMatrixComponent(Sema &S, QualType baseType, ExprValueKind &VK, SourceLocation OpLoc, const IdentifierInfo *CompName, SourceLocation CompLoc)
void handleResourceBindingAttr(Decl *D, const ParsedAttr &AL)
bool IsTypedResourceElementCompatible(QualType T1)
bool transformInitList(const InitializedEntity &Entity, InitListExpr *Init)
void handleNumThreadsAttr(Decl *D, const ParsedAttr &AL)
bool ActOnUninitializedVarDecl(VarDecl *D)
void handleVkExtBuiltinInputAttr(Decl *D, const ParsedAttr &AL)
bool diagnoseInstantiatedMatrixLayoutAttr(Decl *D, const HLSLMatrixLayoutAttr *Attr)
void ActOnTopLevelFunction(FunctionDecl *FD)
Definition SemaHLSL.cpp:806
bool handleResourceTypeAttr(QualType T, const ParsedAttr &AL)
void handleVkPushConstantAttr(Decl *D, const ParsedAttr &AL)
HLSLShaderAttr * mergeShaderAttr(Decl *D, const AttributeCommonInfo &AL, llvm::Triple::EnvironmentType ShaderType)
Definition SemaHLSL.cpp:773
NamedDecl * getConstantBufferConversionFunction(QualType Type, CXXRecordDecl *RD)
void ActOnFinishBuffer(Decl *Dcl, SourceLocation RBrace)
Definition SemaHLSL.cpp:676
void handleVkBindingAttr(Decl *D, const ParsedAttr &AL)
HLSLParamModifierAttr * mergeParamModifierAttr(Decl *D, const AttributeCommonInfo &AL, HLSLParamModifierAttr::Spelling Spelling)
Definition SemaHLSL.cpp:786
QualType getInoutParameterType(QualType Ty)
void handleMatrixLayoutAttr(Decl *D, const ParsedAttr &AL)
SemaHLSL(Sema &S)
Definition SemaHLSL.cpp:237
void handleVkConstantIdAttr(Decl *D, const ParsedAttr &AL)
Decl * ActOnStartBuffer(Scope *BufferScope, bool CBuffer, SourceLocation KwLoc, IdentifierInfo *Ident, SourceLocation IdentLoc, SourceLocation LBrace)
Definition SemaHLSL.cpp:239
HLSLWaveSizeAttr * mergeWaveSizeAttr(Decl *D, const AttributeCommonInfo &AL, int Min, int Max, int Preferred, int SpelledArgsCount)
Definition SemaHLSL.cpp:717
bool handleRootSignatureElements(ArrayRef< hlsl::RootSignatureElement > Elements)
void ActOnFinishRootSignatureDecl(SourceLocation Loc, IdentifierInfo *DeclIdent, ArrayRef< hlsl::RootSignatureElement > Elements)
Creates the Root Signature decl of the parsed Root Signature elements onto the AST and push it onto c...
void ActOnVariableDeclarator(VarDecl *VD)
bool CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall)
Sema - This implements semantic analysis and AST building for C.
Definition Sema.h:868
@ LookupOrdinaryName
Ordinary name lookup, which finds ordinary names (functions, variables, typedefs, etc....
Definition Sema.h:9415
@ LookupMemberName
Member name lookup, which finds the names of class/struct/union members.
Definition Sema.h:9423
ExtVectorDeclsType ExtVectorDecls
ExtVectorDecls - This is a list all the extended vector types.
Definition Sema.h:4955
ASTContext & Context
Definition Sema.h:1308
ASTContext & getASTContext() const
Definition Sema.h:939
ExprResult ImpCastExprToType(Expr *E, QualType Type, CastKind CK, ExprValueKind VK=VK_PRValue, const CXXCastPath *BasePath=nullptr, CheckedConversionKind CCK=CheckedConversionKind::Implicit)
ImpCastExprToType - If Expr is not of type 'Type', insert an implicit cast.
Definition Sema.cpp:762
const LangOptions & getLangOpts() const
Definition Sema.h:932
ExprResult TemporaryMaterializationConversion(Expr *E)
If E is a prvalue denoting an unmaterialized temporary, materialize it as an xvalue.
SemaHLSL & HLSL()
Definition Sema.h:1483
ExprResult BuildFieldReferenceExpr(Expr *BaseExpr, bool IsArrow, SourceLocation OpLoc, const CXXScopeSpec &SS, FieldDecl *Field, DeclAccessPair FoundDecl, const DeclarationNameInfo &MemberNameInfo)
bool checkArgCountRange(CallExpr *Call, unsigned MinArgCount, unsigned MaxArgCount)
Checks that a call expression's argument count is in the desired range.
ExternalSemaSource * getExternalSource() const
Definition Sema.h:942
ASTConsumer & Consumer
Definition Sema.h:1309
bool checkArgCount(CallExpr *Call, unsigned DesiredArgCount)
Checks that a call expression's argument count is the desired number.
ExprResult CreateBuiltinArraySubscriptExpr(Expr *Base, SourceLocation LLoc, Expr *Idx, SourceLocation RLoc)
bool LookupQualifiedName(LookupResult &R, DeclContext *LookupCtx, bool InUnqualifiedLookup=false)
Perform qualified name lookup into a given context.
ExprResult PerformCopyInitialization(const InitializedEntity &Entity, SourceLocation EqualLoc, ExprResult Init, bool TopLevelOfInitList=false, bool AllowExplicit=false)
ExprResult CreateBuiltinMatrixSubscriptExpr(Expr *Base, Expr *RowIdx, Expr *ColumnIdx, SourceLocation RBLoc)
Encodes a location in the source.
SourceLocation getLocWithOffset(IntTy Offset) const
Return a source location with the specified offset from this SourceLocation.
A trivial tuple used to represent a source range.
SourceLocation getEnd() const
SourceLocation getEndLoc() const LLVM_READONLY
Definition Stmt.cpp:367
void printPretty(raw_ostream &OS, PrinterHelper *Helper, const PrintingPolicy &Policy, unsigned Indentation=0, StringRef NewlineSymbol="\n", const ASTContext *Context=nullptr) const
SourceRange getSourceRange() const LLVM_READONLY
SourceLocation tokens are not useful in isolation - they are low level value objects created/interpre...
Definition Stmt.cpp:343
SourceLocation getBeginLoc() const LLVM_READONLY
Definition Stmt.cpp:355
StringLiteral - This represents a string literal expression, e.g.
Definition Expr.h:1802
static StringLiteral * Create(const ASTContext &Ctx, StringRef Str, StringLiteralKind Kind, bool Pascal, QualType Ty, ArrayRef< SourceLocation > Locs)
This is the "fully general" constructor that allows representation of strings formed from one or more...
Definition Expr.cpp:1188
void startDefinition()
Starts the definition of this tag declaration.
Definition Decl.cpp:4901
bool isUnion() const
Definition Decl.h:3946
bool isClass() const
Definition Decl.h:3945
Exposes information about the current target.
Definition TargetInfo.h:227
TargetOptions & getTargetOpts() const
Retrieve the target options.
Definition TargetInfo.h:327
const llvm::Triple & getTriple() const
Returns the target triple of the primary target.
StringRef getPlatformName() const
Retrieve the name of the platform as it is used in the availability attribute.
VersionTuple getPlatformMinVersion() const
Retrieve the minimum desired version of the platform, to which the program should be compiled.
std::string HLSLEntry
The entry point name for HLSL shader being compiled as specified by -E.
A convenient class for passing around template argument information.
void addArgument(const TemplateArgumentLoc &Loc)
The base class of all kinds of template declarations (e.g., class, function, etc.).
Stores a list of template parameters for a TemplateDecl and its derived classes.
The top declaration context.
Definition Decl.h:105
SourceLocation getBeginLoc() const
Get the begin source location.
Definition TypeLoc.cpp:193
A container of type source information.
Definition TypeBase.h:8407
TypeLoc getTypeLoc() const
Return the TypeLoc wrapper for the type source info.
Definition TypeLoc.h:267
The base class of the type hierarchy.
Definition TypeBase.h:1871
bool isVoidType() const
Definition TypeBase.h:9039
bool isBooleanType() const
Definition TypeBase.h:9176
bool isIncompleteArrayType() const
Definition TypeBase.h:8780
CXXRecordDecl * getAsCXXRecordDecl() const
Retrieves the CXXRecordDecl that this type refers to, either because the type is a RecordType or beca...
Definition Type.h:26
bool isConstantArrayType() const
Definition TypeBase.h:8776
bool hasIntegerRepresentation() const
Determine whether this type has an integer representation of some sort, e.g., it is an integer type o...
Definition Type.cpp:2119
bool isArrayType() const
Definition TypeBase.h:8772
CXXRecordDecl * castAsCXXRecordDecl() const
Definition Type.h:36
bool isArithmeticType() const
Definition Type.cpp:2422
bool isConstantMatrixType() const
Definition TypeBase.h:8840
bool isHLSLBuiltinIntangibleType() const
Definition TypeBase.h:8984
bool isPointerType() const
Definition TypeBase.h:8673
CanQualType getCanonicalTypeUnqualified() const
bool isIntegerType() const
isIntegerType() does not include complex integers (a GCC extension).
Definition TypeBase.h:9083
const T * castAs() const
Member-template castAs<specific type>.
Definition TypeBase.h:9333
bool isReferenceType() const
Definition TypeBase.h:8697
bool isHLSLIntangibleType() const
Definition Type.cpp:5508
bool isEnumeralType() const
Definition TypeBase.h:8804
bool isScalarType() const
Definition TypeBase.h:9145
bool isIntegralType(const ASTContext &Ctx) const
Determine whether this type is an integral type.
Definition Type.cpp:2156
const Type * getArrayElementTypeNoTypeQual() const
If this is an array type, return the element type of the array, potentially with type qualifiers miss...
Definition Type.cpp:508
QualType getPointeeType() const
If this is a pointer, ObjC object pointer, or block pointer, this returns the respective pointee.
Definition Type.cpp:789
bool hasUnsignedIntegerRepresentation() const
Determine whether this type has an unsigned integer representation of some sort, e....
Definition Type.cpp:2376
bool isDependentType() const
Whether this type is a dependent type, meaning that its definition somehow depends on a template para...
Definition TypeBase.h:2837
bool isAggregateType() const
Determines whether the type is a C++ aggregate type or C aggregate or union type.
Definition Type.cpp:2503
ScalarTypeKind getScalarTypeKind() const
Given that this is a scalar type, classify it.
Definition Type.cpp:2454
bool hasSignedIntegerRepresentation() const
Determine whether this type has an signed integer representation of some sort, e.g....
Definition Type.cpp:2310
bool isMatrixType() const
Definition TypeBase.h:8836
bool isHLSLResourceRecord() const
Definition Type.cpp:5495
bool hasFloatingRepresentation() const
Determine whether this type has a floating-point representation of some sort, e.g....
Definition Type.cpp:2397
bool isVectorType() const
Definition TypeBase.h:8812
bool isRealFloatingType() const
Floating point categories.
Definition Type.cpp:2405
bool isHLSLAttributedResourceType() const
Definition TypeBase.h:8996
@ STK_FloatingComplex
Definition TypeBase.h:2819
@ STK_ObjCObjectPointer
Definition TypeBase.h:2813
@ STK_IntegralComplex
Definition TypeBase.h:2818
@ STK_MemberPointer
Definition TypeBase.h:2814
bool isFloatingType() const
Definition Type.cpp:2389
bool isSamplerT() const
Definition TypeBase.h:8917
const T * getAs() const
Member-template getAs<specific type>'.
Definition TypeBase.h:9266
const Type * getUnqualifiedDesugaredType() const
Return the specified type with any "sugar" removed from the type, removing any typedefs,...
Definition Type.cpp:690
bool isRecordType() const
Definition TypeBase.h:8800
bool isHLSLResourceRecordArray() const
Definition Type.cpp:5499
void setType(QualType newType)
Definition Decl.h:724
QualType getType() const
Definition Decl.h:723
Represents a variable declaration or definition.
Definition Decl.h:924
static VarDecl * Create(ASTContext &C, DeclContext *DC, SourceLocation StartLoc, SourceLocation IdLoc, const IdentifierInfo *Id, QualType T, TypeSourceInfo *TInfo, StorageClass S)
Definition Decl.cpp:2130
void setInitStyle(InitializationStyle Style)
Definition Decl.h:1465
@ CallInit
Call-style initialization (C++98)
Definition Decl.h:932
void setStorageClass(StorageClass SC)
Definition Decl.cpp:2142
bool hasGlobalStorage() const
Returns true for all variables that do not have local storage.
Definition Decl.h:1239
void setInit(Expr *I)
Definition Decl.cpp:2456
StorageClass getStorageClass() const
Returns the storage class as written in the source.
Definition Decl.h:1166
Represents a GCC generic vector type.
Definition TypeBase.h:4230
unsigned getNumElements() const
Definition TypeBase.h:4245
QualType getElementType() const
Definition TypeBase.h:4244
IdentifierInfo * getNameAsIdentifier(ASTContext &AST) const
Defines the clang::TargetInfo interface.
Definition SPIR.cpp:47
uint32_t getResourceDimensions(llvm::dxil::ResourceDimension Dim)
bool hasCounterHandle(const CXXRecordDecl *RD)
The JSON file list parser is used to communicate input to InstallAPI.
bool isa(CodeGen::Address addr)
Definition Address.h:330
if(T->getSizeExpr()) TRY_TO(TraverseStmt(const_cast< Expr * >(T -> getSizeExpr())))
static bool CheckFloatOrHalfRepresentation(Sema *S, SourceLocation Loc, int ArgOrdinal, clang::QualType PassedType)
Definition SemaSPIRV.cpp:66
@ ICIS_NoInit
No in-class initializer.
Definition Specifiers.h:273
@ TemplateName
The identifier is a template name. FIXME: Add an annotation for that.
Definition Parser.h:61
@ OK_Ordinary
An ordinary object is located at an address in memory.
Definition Specifiers.h:152
static bool CheckAllArgTypesAreCorrect(Sema *S, CallExpr *TheCall, llvm::ArrayRef< llvm::function_ref< bool(Sema *, SourceLocation, int, QualType)> > Checks)
Definition SemaSPIRV.cpp:49
@ AS_public
Definition Specifiers.h:125
@ AS_none
Definition Specifiers.h:128
@ SC_Extern
Definition Specifiers.h:252
@ SC_Static
Definition Specifiers.h:253
@ SC_None
Definition Specifiers.h:251
@ AANT_ArgumentIdentifier
@ Result
The result type of a method or function.
Definition TypeBase.h:905
@ Ordinary
This parameter uses ordinary ABI rules for its type.
Definition Specifiers.h:383
llvm::Expected< QualType > ExpectedType
@ Template
We are parsing a template declaration.
Definition Parser.h:81
LLVM_READONLY bool isDigit(unsigned char c)
Return true if this character is an ASCII digit: [0-9].
Definition CharInfo.h:114
static bool CheckAllArgsHaveSameType(Sema *S, CallExpr *TheCall)
Definition SemaSPIRV.cpp:32
ExprResult ExprError()
Definition Ownership.h:265
@ Type
The name was classified as a type.
Definition Sema.h:564
LangAS
Defines the address space values used by the address space qualifier of QualType.
bool CreateHLSLAttributedResourceType(Sema &S, QualType Wrapped, ArrayRef< const Attr * > AttrList, QualType &ResType, HLSLAttributedResourceLocInfo *LocInfo=nullptr)
CastKind
CastKind - The kind of operation required for a conversion.
ExprValueKind
The categorization of expression values, currently following the C++11 scheme.
Definition Specifiers.h:133
@ VK_PRValue
A pr-value expression (in the C++11 taxonomy) produces a temporary value.
Definition Specifiers.h:136
@ VK_LValue
An l-value expression is a reference to an object with independent storage.
Definition Specifiers.h:140
DynamicRecursiveASTVisitorBase< false > DynamicRecursiveASTVisitor
U cast(CodeGen::Address addr)
Definition Address.h:327
@ None
No keyword precedes the qualified type name.
Definition TypeBase.h:5982
ActionResult< Expr * > ExprResult
Definition Ownership.h:249
Visibility
Describes the different kinds of visibility that a declaration may have.
Definition Visibility.h:34
unsigned long uint64_t
unsigned int uint32_t
hash_code hash_value(const clang::dependencies::ModuleID &ID)
__DEVICE__ bool isnan(float __x)
__DEVICE__ _Tp abs(const std::complex< _Tp > &__c)
#define false
Definition stdbool.h:26
Describes how types, statements, expressions, and declarations should be printed.
void setCounterImplicitOrderID(unsigned Value) const
void setImplicitOrderID(unsigned Value) const
const SourceLocation & getLocation() const
Definition SemaHLSL.h:48
const llvm::hlsl::rootsig::RootElement & getElement() const
Definition SemaHLSL.h:47