clang 23.0.0git
SemaHLSL.cpp
Go to the documentation of this file.
1//===- SemaHLSL.cpp - Semantic Analysis for HLSL constructs ---------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8// This implements Semantic Analysis for HLSL constructs.
9//===----------------------------------------------------------------------===//
10
11#include "clang/Sema/SemaHLSL.h"
14#include "clang/AST/Attr.h"
15#include "clang/AST/Decl.h"
16#include "clang/AST/DeclBase.h"
17#include "clang/AST/DeclCXX.h"
20#include "clang/AST/Expr.h"
22#include "clang/AST/Type.h"
23#include "clang/AST/TypeBase.h"
24#include "clang/AST/TypeLoc.h"
28#include "clang/Basic/LLVM.h"
33#include "clang/Sema/Lookup.h"
35#include "clang/Sema/Sema.h"
36#include "clang/Sema/Template.h"
37#include "llvm/ADT/ArrayRef.h"
38#include "llvm/ADT/STLExtras.h"
39#include "llvm/ADT/SmallVector.h"
40#include "llvm/ADT/StringExtras.h"
41#include "llvm/ADT/StringRef.h"
42#include "llvm/ADT/Twine.h"
43#include "llvm/Frontend/HLSL/HLSLBinding.h"
44#include "llvm/Frontend/HLSL/RootSignatureValidations.h"
45#include "llvm/Support/Casting.h"
46#include "llvm/Support/DXILABI.h"
47#include "llvm/Support/ErrorHandling.h"
48#include "llvm/Support/FormatVariadic.h"
49#include "llvm/TargetParser/Triple.h"
50#include <cmath>
51#include <cstddef>
52#include <iterator>
53#include <utility>
54
55using namespace clang;
56using namespace clang::hlsl;
57using RegisterType = HLSLResourceBindingAttr::RegisterType;
58
60 CXXRecordDecl *StructDecl);
61
63 switch (RC) {
64 case ResourceClass::SRV:
65 return RegisterType::SRV;
66 case ResourceClass::UAV:
67 return RegisterType::UAV;
68 case ResourceClass::CBuffer:
69 return RegisterType::CBuffer;
70 case ResourceClass::Sampler:
71 return RegisterType::Sampler;
72 }
73 llvm_unreachable("unexpected ResourceClass value");
74}
75
76static RegisterType getRegisterType(const HLSLAttributedResourceType *ResTy) {
77 return getRegisterType(ResTy->getAttrs().ResourceClass);
78}
79
80// Converts the first letter of string Slot to RegisterType.
81// Returns false if the letter does not correspond to a valid register type.
82static bool convertToRegisterType(StringRef Slot, RegisterType *RT) {
83 assert(RT != nullptr);
84 switch (Slot[0]) {
85 case 't':
86 case 'T':
87 *RT = RegisterType::SRV;
88 return true;
89 case 'u':
90 case 'U':
91 *RT = RegisterType::UAV;
92 return true;
93 case 'b':
94 case 'B':
95 *RT = RegisterType::CBuffer;
96 return true;
97 case 's':
98 case 'S':
99 *RT = RegisterType::Sampler;
100 return true;
101 case 'c':
102 case 'C':
103 *RT = RegisterType::C;
104 return true;
105 case 'i':
106 case 'I':
107 *RT = RegisterType::I;
108 return true;
109 default:
110 return false;
111 }
112}
113
115 switch (RT) {
116 case RegisterType::SRV:
117 return 't';
118 case RegisterType::UAV:
119 return 'u';
120 case RegisterType::CBuffer:
121 return 'b';
122 case RegisterType::Sampler:
123 return 's';
124 case RegisterType::C:
125 return 'c';
126 case RegisterType::I:
127 return 'i';
128 }
129 llvm_unreachable("unexpected RegisterType value");
130}
131
133 switch (RT) {
134 case RegisterType::SRV:
135 return ResourceClass::SRV;
136 case RegisterType::UAV:
137 return ResourceClass::UAV;
138 case RegisterType::CBuffer:
139 return ResourceClass::CBuffer;
140 case RegisterType::Sampler:
141 return ResourceClass::Sampler;
142 case RegisterType::C:
143 case RegisterType::I:
144 // Deliberately falling through to the unreachable below.
145 break;
146 }
147 llvm_unreachable("unexpected RegisterType value");
148}
149
151 const auto *BT = dyn_cast<BuiltinType>(Type);
152 if (!BT) {
153 if (!Type->isEnumeralType())
154 return Builtin::NotBuiltin;
155 return Builtin::BI__builtin_get_spirv_spec_constant_int;
156 }
157
158 switch (BT->getKind()) {
159 case BuiltinType::Bool:
160 return Builtin::BI__builtin_get_spirv_spec_constant_bool;
161 case BuiltinType::Short:
162 return Builtin::BI__builtin_get_spirv_spec_constant_short;
163 case BuiltinType::Int:
164 return Builtin::BI__builtin_get_spirv_spec_constant_int;
165 case BuiltinType::LongLong:
166 return Builtin::BI__builtin_get_spirv_spec_constant_longlong;
167 case BuiltinType::UShort:
168 return Builtin::BI__builtin_get_spirv_spec_constant_ushort;
169 case BuiltinType::UInt:
170 return Builtin::BI__builtin_get_spirv_spec_constant_uint;
171 case BuiltinType::ULongLong:
172 return Builtin::BI__builtin_get_spirv_spec_constant_ulonglong;
173 case BuiltinType::Half:
174 return Builtin::BI__builtin_get_spirv_spec_constant_half;
175 case BuiltinType::Float:
176 return Builtin::BI__builtin_get_spirv_spec_constant_float;
177 case BuiltinType::Double:
178 return Builtin::BI__builtin_get_spirv_spec_constant_double;
179 default:
180 return Builtin::NotBuiltin;
181 }
182}
183
184static StringRef createRegisterString(ASTContext &AST, RegisterType RegType,
185 unsigned N) {
187 llvm::raw_svector_ostream OS(Buffer);
188 OS << getRegisterTypeChar(RegType);
189 OS << N;
190 return AST.backupStr(OS.str());
191}
192
194 ResourceClass ResClass) {
195 assert(getDeclBindingInfo(VD, ResClass) == nullptr &&
196 "DeclBindingInfo already added");
197 assert(!hasBindingInfoForDecl(VD) || BindingsList.back().Decl == VD);
198 // VarDecl may have multiple entries for different resource classes.
199 // DeclToBindingListIndex stores the index of the first binding we saw
200 // for this decl. If there are any additional ones then that index
201 // shouldn't be updated.
202 DeclToBindingListIndex.try_emplace(VD, BindingsList.size());
203 return &BindingsList.emplace_back(VD, ResClass);
204}
205
207 ResourceClass ResClass) {
208 auto Entry = DeclToBindingListIndex.find(VD);
209 if (Entry != DeclToBindingListIndex.end()) {
210 for (unsigned Index = Entry->getSecond();
211 Index < BindingsList.size() && BindingsList[Index].Decl == VD;
212 ++Index) {
213 if (BindingsList[Index].ResClass == ResClass)
214 return &BindingsList[Index];
215 }
216 }
217 return nullptr;
218}
219
221 return DeclToBindingListIndex.contains(VD);
222}
223
225
226Decl *SemaHLSL::ActOnStartBuffer(Scope *BufferScope, bool CBuffer,
227 SourceLocation KwLoc, IdentifierInfo *Ident,
228 SourceLocation IdentLoc,
229 SourceLocation LBrace) {
230 // For anonymous namespace, take the location of the left brace.
231 DeclContext *LexicalParent = SemaRef.getCurLexicalContext();
233 getASTContext(), LexicalParent, CBuffer, KwLoc, Ident, IdentLoc, LBrace);
234
235 // if CBuffer is false, then it's a TBuffer
236 auto RC = CBuffer ? llvm::hlsl::ResourceClass::CBuffer
237 : llvm::hlsl::ResourceClass::SRV;
238 Result->addAttr(HLSLResourceClassAttr::CreateImplicit(getASTContext(), RC));
239
240 SemaRef.PushOnScopeChains(Result, BufferScope);
241 SemaRef.PushDeclContext(BufferScope, Result);
242
243 return Result;
244}
245
246static unsigned calculateLegacyCbufferFieldAlign(const ASTContext &Context,
247 QualType T) {
248 // Arrays, Matrices, and Structs are always aligned to new buffer rows
249 if (T->isArrayType() || T->isStructureType() || T->isConstantMatrixType())
250 return 16;
251
252 // Vectors are aligned to the type they contain
253 if (const VectorType *VT = T->getAs<VectorType>())
254 return calculateLegacyCbufferFieldAlign(Context, VT->getElementType());
255
256 assert(Context.getTypeSize(T) <= 64 &&
257 "Scalar bit widths larger than 64 not supported");
258
259 // Scalar types are aligned to their byte width
260 return Context.getTypeSize(T) / 8;
261}
262
263// Calculate the size of a legacy cbuffer type in bytes based on
264// https://learn.microsoft.com/en-us/windows/win32/direct3dhlsl/dx-graphics-hlsl-packing-rules
265static unsigned calculateLegacyCbufferSize(const ASTContext &Context,
266 QualType T) {
267 constexpr unsigned CBufferAlign = 16;
268 if (const auto *RD = T->getAsRecordDecl()) {
269 unsigned Size = 0;
270 for (const FieldDecl *Field : RD->fields()) {
271 QualType Ty = Field->getType();
272 unsigned FieldSize = calculateLegacyCbufferSize(Context, Ty);
273 unsigned FieldAlign = calculateLegacyCbufferFieldAlign(Context, Ty);
274
275 // If the field crosses the row boundary after alignment it drops to the
276 // next row
277 unsigned AlignSize = llvm::alignTo(Size, FieldAlign);
278 if ((AlignSize % CBufferAlign) + FieldSize > CBufferAlign) {
279 FieldAlign = CBufferAlign;
280 }
281
282 Size = llvm::alignTo(Size, FieldAlign);
283 Size += FieldSize;
284 }
285 return Size;
286 }
287
288 if (const ConstantArrayType *AT = Context.getAsConstantArrayType(T)) {
289 unsigned ElementCount = AT->getSize().getZExtValue();
290 if (ElementCount == 0)
291 return 0;
292
293 unsigned ElementSize =
294 calculateLegacyCbufferSize(Context, AT->getElementType());
295 unsigned AlignedElementSize = llvm::alignTo(ElementSize, CBufferAlign);
296 return AlignedElementSize * (ElementCount - 1) + ElementSize;
297 }
298
299 if (const VectorType *VT = T->getAs<VectorType>()) {
300 unsigned ElementCount = VT->getNumElements();
301 unsigned ElementSize =
302 calculateLegacyCbufferSize(Context, VT->getElementType());
303 return ElementSize * ElementCount;
304 }
305
306 return Context.getTypeSize(T) / 8;
307}
308
309// Validate packoffset:
309// - if packoffset is used, it must be set on all declarations inside the buffer
311// - packoffset ranges must not overlap
312static void validatePackoffset(Sema &S, HLSLBufferDecl *BufDecl) {
314
315 // Make sure the packoffset annotations are either on all declarations
316 // or on none.
317 bool HasPackOffset = false;
318 bool HasNonPackOffset = false;
319 for (auto *Field : BufDecl->buffer_decls()) {
320 VarDecl *Var = dyn_cast<VarDecl>(Field);
321 if (!Var)
322 continue;
323 if (Field->hasAttr<HLSLPackOffsetAttr>()) {
324 PackOffsetVec.emplace_back(Var, Field->getAttr<HLSLPackOffsetAttr>());
325 HasPackOffset = true;
326 } else {
327 HasNonPackOffset = true;
328 }
329 }
330
331 if (!HasPackOffset)
332 return;
333
334 if (HasNonPackOffset)
335 S.Diag(BufDecl->getLocation(), diag::warn_hlsl_packoffset_mix);
336
337 // Make sure there is no overlap in packoffset - sort PackOffsetVec by offset
338 // and compare adjacent values.
339 bool IsValid = true;
340 ASTContext &Context = S.getASTContext();
341 std::sort(PackOffsetVec.begin(), PackOffsetVec.end(),
342 [](const std::pair<VarDecl *, HLSLPackOffsetAttr *> &LHS,
343 const std::pair<VarDecl *, HLSLPackOffsetAttr *> &RHS) {
344 return LHS.second->getOffsetInBytes() <
345 RHS.second->getOffsetInBytes();
346 });
347 for (unsigned i = 0; i < PackOffsetVec.size() - 1; i++) {
348 VarDecl *Var = PackOffsetVec[i].first;
349 HLSLPackOffsetAttr *Attr = PackOffsetVec[i].second;
350 unsigned Size = calculateLegacyCbufferSize(Context, Var->getType());
351 unsigned Begin = Attr->getOffsetInBytes();
352 unsigned End = Begin + Size;
353 unsigned NextBegin = PackOffsetVec[i + 1].second->getOffsetInBytes();
354 if (End > NextBegin) {
355 VarDecl *NextVar = PackOffsetVec[i + 1].first;
356 S.Diag(NextVar->getLocation(), diag::err_hlsl_packoffset_overlap)
357 << NextVar << Var;
358 IsValid = false;
359 }
360 }
361 BufDecl->setHasValidPackoffset(IsValid);
362}
363
364// Returns true if the array has a zero size = if any of the dimensions is 0
365static bool isZeroSizedArray(const ConstantArrayType *CAT) {
366 while (CAT && !CAT->isZeroSize())
367 CAT = dyn_cast<ConstantArrayType>(
369 return CAT != nullptr;
370}
371
375
379
380static const HLSLAttributedResourceType *
382 assert(QT->isHLSLResourceRecordArray() &&
383 "expected array of resource records");
384 const Type *Ty = QT->getUnqualifiedDesugaredType();
385 while (const ArrayType *AT = dyn_cast<ArrayType>(Ty))
387 return HLSLAttributedResourceType::findHandleTypeOnResource(Ty);
388}
389
390static const HLSLAttributedResourceType *
394
395// Returns true if the type is a leaf element type that is not valid to be
396// included in HLSL Buffer, such as a resource class, empty struct, zero-sized
397// array, or a builtin intangible type. Returns false if it is a valid leaf element
398// type or if it is a record type that needs to be inspected further.
402 return true;
403 if (const auto *RD = Ty->getAsCXXRecordDecl())
404 return RD->isEmpty();
405 if (Ty->isConstantArrayType() &&
407 return true;
409 return true;
410 return false;
411}
412
413// Returns true if the struct contains at least one element that prevents it
414// from being included inside HLSL Buffer as is, such as an intangible type,
415// empty struct, or zero-sized array. If it does, a new implicit layout struct
416// needs to be created for HLSL Buffer use that will exclude these unwanted
417// declarations (see createHostLayoutStruct function).
419 if (RD->isHLSLIntangible() || RD->isEmpty())
420 return true;
421 // check fields
422 for (const FieldDecl *Field : RD->fields()) {
423 QualType Ty = Field->getType();
425 return true;
426 if (const auto *RD = Ty->getAsCXXRecordDecl();
428 return true;
429 }
430 // check bases
431 for (const CXXBaseSpecifier &Base : RD->bases())
433 Base.getType()->castAsCXXRecordDecl()))
434 return true;
435 return false;
436}
437
439 DeclContext *DC) {
440 CXXRecordDecl *RD = nullptr;
441 for (NamedDecl *Decl :
443 if (CXXRecordDecl *FoundRD = dyn_cast<CXXRecordDecl>(Decl)) {
444 assert(RD == nullptr &&
445 "there should be at most 1 record by a given name in a scope");
446 RD = FoundRD;
447 }
448 }
449 return RD;
450}
451
452// Creates a name for buffer layout struct using the provided name base.
453// If the name must be unique (not previously defined), a suffix is added
454// until a unique name is found.
456 bool MustBeUnique) {
457 ASTContext &AST = S.getASTContext();
458
459 IdentifierInfo *NameBaseII = BaseDecl->getIdentifier();
460 llvm::SmallString<64> Name("__cblayout_");
461 if (NameBaseII) {
462 Name.append(NameBaseII->getName());
463 } else {
464 // anonymous struct
465 Name.append("anon");
466 MustBeUnique = true;
467 }
468
469 size_t NameLength = Name.size();
470 IdentifierInfo *II = &AST.Idents.get(Name, tok::TokenKind::identifier);
471 if (!MustBeUnique)
472 return II;
473
474 unsigned suffix = 0;
475 while (true) {
476 if (suffix != 0) {
477 Name.append("_");
478 Name.append(llvm::Twine(suffix).str());
479 II = &AST.Idents.get(Name, tok::TokenKind::identifier);
480 }
481 if (!findRecordDeclInContext(II, BaseDecl->getDeclContext()))
482 return II;
483 // declaration with that name already exists - increment suffix and try
484 // again until unique name is found
485 suffix++;
486 Name.truncate(NameLength);
487 };
488}
489
490static const Type *createHostLayoutType(Sema &S, const Type *Ty) {
491 ASTContext &AST = S.getASTContext();
492 if (auto *RD = Ty->getAsCXXRecordDecl()) {
494 return Ty;
495 RD = createHostLayoutStruct(S, RD);
496 if (!RD)
497 return nullptr;
498 return AST.getCanonicalTagType(RD)->getTypePtr();
499 }
500
501 if (const auto *CAT = dyn_cast<ConstantArrayType>(Ty)) {
502 const Type *ElementTy = createHostLayoutType(
503 S, CAT->getElementType()->getUnqualifiedDesugaredType());
504 if (!ElementTy)
505 return nullptr;
506 return AST
507 .getConstantArrayType(QualType(ElementTy, 0), CAT->getSize(), nullptr,
508 CAT->getSizeModifier(),
509 CAT->getIndexTypeCVRQualifiers())
510 .getTypePtr();
511 }
512 return Ty;
513}
514
515// Creates a field declaration of given name and type for HLSL buffer layout
516// struct. Returns nullptr if the type cannot be used in HLSL Buffer layout.
518 IdentifierInfo *II,
519 CXXRecordDecl *LayoutStruct) {
521 return nullptr;
522
523 Ty = createHostLayoutType(S, Ty);
524 if (!Ty)
525 return nullptr;
526
527 QualType QT = QualType(Ty, 0);
528 ASTContext &AST = S.getASTContext();
530 auto *Field = FieldDecl::Create(AST, LayoutStruct, SourceLocation(),
531 SourceLocation(), II, QT, TSI, nullptr, false,
533 Field->setAccess(AccessSpecifier::AS_public);
534 return Field;
535}
536
537// Creates host layout struct for a struct included in HLSL Buffer.
538// The layout struct will include only fields that are allowed in HLSL buffer.
539// These fields will be filtered out:
540// - resource classes
541// - empty structs
542// - zero-sized arrays
543// Returns nullptr if the resulting layout struct would be empty.
545 CXXRecordDecl *StructDecl) {
546 assert(requiresImplicitBufferLayoutStructure(StructDecl) &&
547 "struct is already HLSL buffer compatible");
548
549 ASTContext &AST = S.getASTContext();
550 DeclContext *DC = StructDecl->getDeclContext();
551 IdentifierInfo *II = getHostLayoutStructName(S, StructDecl, false);
552
553 // reuse existing if the layout struct if it already exists
554 if (CXXRecordDecl *RD = findRecordDeclInContext(II, DC))
555 return RD;
556
557 CXXRecordDecl *LS =
558 CXXRecordDecl::Create(AST, TagDecl::TagKind::Struct, DC, SourceLocation(),
559 SourceLocation(), II);
560 LS->setImplicit(true);
561 LS->addAttr(PackedAttr::CreateImplicit(AST));
562 LS->startDefinition();
563
564 // copy base struct, create HLSL Buffer compatible version if needed
565 if (unsigned NumBases = StructDecl->getNumBases()) {
566 assert(NumBases == 1 && "HLSL supports only one base type");
567 (void)NumBases;
568 CXXBaseSpecifier Base = *StructDecl->bases_begin();
569 CXXRecordDecl *BaseDecl = Base.getType()->castAsCXXRecordDecl();
571 BaseDecl = createHostLayoutStruct(S, BaseDecl);
572 if (BaseDecl) {
573 TypeSourceInfo *TSI =
575 Base = CXXBaseSpecifier(SourceRange(), false, StructDecl->isClass(),
576 AS_none, TSI, SourceLocation());
577 }
578 }
579 if (BaseDecl) {
580 const CXXBaseSpecifier *BasesArray[1] = {&Base};
581 LS->setBases(BasesArray, 1);
582 }
583 }
584
585 // filter struct fields
586 for (const FieldDecl *FD : StructDecl->fields()) {
587 const Type *Ty = FD->getType()->getUnqualifiedDesugaredType();
588 if (FieldDecl *NewFD =
589 createFieldForHostLayoutStruct(S, Ty, FD->getIdentifier(), LS))
590 LS->addDecl(NewFD);
591 }
592 LS->completeDefinition();
593
594 if (LS->field_empty() && LS->getNumBases() == 0)
595 return nullptr;
596
597 DC->addDecl(LS);
598 return LS;
599}
600
601// Creates host layout struct for HLSL Buffer. The struct will include only
602// fields of types that are allowed in HLSL buffer and it will filter out:
603// - static or groupshared variable declarations
604// - resource classes
605// - empty structs
606// - zero-sized arrays
607// - non-variable declarations
608// The layout struct will be added to the HLSLBufferDecl declarations.
610 ASTContext &AST = S.getASTContext();
611 IdentifierInfo *II = getHostLayoutStructName(S, BufDecl, true);
612
613 CXXRecordDecl *LS =
614 CXXRecordDecl::Create(AST, TagDecl::TagKind::Struct, BufDecl,
616 LS->addAttr(PackedAttr::CreateImplicit(AST));
617 LS->setImplicit(true);
618 LS->startDefinition();
619
620 for (Decl *D : BufDecl->buffer_decls()) {
621 VarDecl *VD = dyn_cast<VarDecl>(D);
622 if (!VD || VD->getStorageClass() == SC_Static ||
624 continue;
625 const Type *Ty = VD->getType()->getUnqualifiedDesugaredType();
626
627 FieldDecl *FD =
629 // Declarations collected for the default $Globals constant buffer have
630 // already been checked to have non-empty cbuffer layout, so
631 // createFieldForHostLayoutStruct should always succeed. These declarations
632 // already have their address space set to hlsl_constant.
633 // For declarations in a named cbuffer block
634 // createFieldForHostLayoutStruct can still return nullptr if the type
635 // is empty (does not have a cbuffer layout).
636 assert((FD || VD->getType().getAddressSpace() != LangAS::hlsl_constant) &&
637 "host layout field for $Globals decl failed to be created");
638 if (FD) {
639 // Add the field decl to the layout struct.
640 LS->addDecl(FD);
642 // Update address space of the original decl to hlsl_constant.
643 QualType NewTy =
645 VD->setType(NewTy);
646 }
647 }
648 }
649 LS->completeDefinition();
650 BufDecl->addLayoutStruct(LS);
651}
652
654 uint32_t ImplicitBindingOrderID) {
655 auto *Attr =
656 HLSLResourceBindingAttr::CreateImplicit(S.getASTContext(), "", "0", {});
657 Attr->setBinding(RT, std::nullopt, 0);
658 Attr->setImplicitBindingOrderID(ImplicitBindingOrderID);
659 D->addAttr(Attr);
660}
661
662// Handle end of cbuffer/tbuffer declaration
664 auto *BufDecl = cast<HLSLBufferDecl>(Dcl);
665 BufDecl->setRBraceLoc(RBrace);
666
667 validatePackoffset(SemaRef, BufDecl);
668
670
671 // Handle implicit binding if needed.
672 ResourceBindingAttrs ResourceAttrs(Dcl);
673 if (!ResourceAttrs.isExplicit()) {
674 SemaRef.Diag(Dcl->getLocation(), diag::warn_hlsl_implicit_binding);
675 // Use HLSLResourceBindingAttr to transfer implicit binding order_ID
676 // to codegen. If it does not exist, create an implicit attribute.
677 uint32_t OrderID = getNextImplicitBindingOrderID();
678 if (ResourceAttrs.hasBinding())
679 ResourceAttrs.setImplicitOrderID(OrderID);
680 else
682 BufDecl->isCBuffer() ? RegisterType::CBuffer
683 : RegisterType::SRV,
684 OrderID);
685 }
686
687 SemaRef.PopDeclContext();
688}
689
690HLSLNumThreadsAttr *SemaHLSL::mergeNumThreadsAttr(Decl *D,
691 const AttributeCommonInfo &AL,
692 int X, int Y, int Z) {
693 if (HLSLNumThreadsAttr *NT = D->getAttr<HLSLNumThreadsAttr>()) {
694 if (NT->getX() != X || NT->getY() != Y || NT->getZ() != Z) {
695 Diag(NT->getLocation(), diag::err_hlsl_attribute_param_mismatch) << AL;
696 Diag(AL.getLoc(), diag::note_conflicting_attribute);
697 }
698 return nullptr;
699 }
700 return ::new (getASTContext())
701 HLSLNumThreadsAttr(getASTContext(), AL, X, Y, Z);
702}
703
705 const AttributeCommonInfo &AL,
706 int Min, int Max, int Preferred,
707 int SpelledArgsCount) {
708 if (HLSLWaveSizeAttr *WS = D->getAttr<HLSLWaveSizeAttr>()) {
709 if (WS->getMin() != Min || WS->getMax() != Max ||
710 WS->getPreferred() != Preferred ||
711 WS->getSpelledArgsCount() != SpelledArgsCount) {
712 Diag(WS->getLocation(), diag::err_hlsl_attribute_param_mismatch) << AL;
713 Diag(AL.getLoc(), diag::note_conflicting_attribute);
714 }
715 return nullptr;
716 }
717 HLSLWaveSizeAttr *Result = ::new (getASTContext())
718 HLSLWaveSizeAttr(getASTContext(), AL, Min, Max, Preferred);
719 Result->setSpelledArgsCount(SpelledArgsCount);
720 return Result;
721}
722
723HLSLVkConstantIdAttr *
725 int Id) {
726
728 if (TargetInfo.getTriple().getArch() != llvm::Triple::spirv) {
729 Diag(AL.getLoc(), diag::warn_attribute_ignored) << AL;
730 return nullptr;
731 }
732
733 auto *VD = cast<VarDecl>(D);
734
735 if (getSpecConstBuiltinId(VD->getType()->getUnqualifiedDesugaredType()) ==
737 Diag(VD->getLocation(), diag::err_specialization_const);
738 return nullptr;
739 }
740
741 if (!VD->getType().isConstQualified()) {
742 Diag(VD->getLocation(), diag::err_specialization_const);
743 return nullptr;
744 }
745
746 if (HLSLVkConstantIdAttr *CI = D->getAttr<HLSLVkConstantIdAttr>()) {
747 if (CI->getId() != Id) {
748 Diag(CI->getLocation(), diag::err_hlsl_attribute_param_mismatch) << AL;
749 Diag(AL.getLoc(), diag::note_conflicting_attribute);
750 }
751 return nullptr;
752 }
753
754 HLSLVkConstantIdAttr *Result =
755 ::new (getASTContext()) HLSLVkConstantIdAttr(getASTContext(), AL, Id);
756 return Result;
757}
758
759HLSLShaderAttr *
761 llvm::Triple::EnvironmentType ShaderType) {
762 if (HLSLShaderAttr *NT = D->getAttr<HLSLShaderAttr>()) {
763 if (NT->getType() != ShaderType) {
764 Diag(NT->getLocation(), diag::err_hlsl_attribute_param_mismatch) << AL;
765 Diag(AL.getLoc(), diag::note_conflicting_attribute);
766 }
767 return nullptr;
768 }
769 return HLSLShaderAttr::Create(getASTContext(), ShaderType, AL);
770}
771
772HLSLParamModifierAttr *
774 HLSLParamModifierAttr::Spelling Spelling) {
775 // We can only merge an `in` attribute with an `out` attribute. All other
776 // combinations of duplicated attributes are ill-formed.
777 if (HLSLParamModifierAttr *PA = D->getAttr<HLSLParamModifierAttr>()) {
778 if ((PA->isIn() && Spelling == HLSLParamModifierAttr::Keyword_out) ||
779 (PA->isOut() && Spelling == HLSLParamModifierAttr::Keyword_in)) {
780 D->dropAttr<HLSLParamModifierAttr>();
781 SourceRange AdjustedRange = {PA->getLocation(), AL.getRange().getEnd()};
782 return HLSLParamModifierAttr::Create(
783 getASTContext(), /*MergedSpelling=*/true, AdjustedRange,
784 HLSLParamModifierAttr::Keyword_inout);
785 }
786 Diag(AL.getLoc(), diag::err_hlsl_duplicate_parameter_modifier) << AL;
787 Diag(PA->getLocation(), diag::note_conflicting_attribute);
788 return nullptr;
789 }
790 return HLSLParamModifierAttr::Create(getASTContext(), AL);
791}
792
795
797 return;
798
799 // If we have specified a root signature to override the entry function then
800 // attach it now
801 HLSLRootSignatureDecl *SignatureDecl =
803 if (SignatureDecl) {
804 FD->dropAttr<RootSignatureAttr>();
805 // We could look up the SourceRange of the macro here as well
806 AttributeCommonInfo AL(RootSigOverrideIdent, AttributeScopeInfo(),
807 SourceRange(), ParsedAttr::Form::Microsoft());
808 FD->addAttr(::new (getASTContext()) RootSignatureAttr(
809 getASTContext(), AL, RootSigOverrideIdent, SignatureDecl));
810 }
811
812 llvm::Triple::EnvironmentType Env = TargetInfo.getTriple().getEnvironment();
813 if (HLSLShaderAttr::isValidShaderType(Env) && Env != llvm::Triple::Library) {
814 if (const auto *Shader = FD->getAttr<HLSLShaderAttr>()) {
815 // The entry point is already annotated - check that it matches the
816 // triple.
817 if (Shader->getType() != Env) {
818 Diag(Shader->getLocation(), diag::err_hlsl_entry_shader_attr_mismatch)
819 << Shader;
820 FD->setInvalidDecl();
821 }
822 } else {
823 // Implicitly add the shader attribute if the entry function isn't
824 // explicitly annotated.
825 FD->addAttr(HLSLShaderAttr::CreateImplicit(getASTContext(), Env,
826 FD->getBeginLoc()));
827 }
828 } else {
829 switch (Env) {
830 case llvm::Triple::UnknownEnvironment:
831 case llvm::Triple::Library:
832 break;
833 case llvm::Triple::RootSignature:
834 llvm_unreachable("rootsig environment has no functions");
835 default:
836 llvm_unreachable("Unhandled environment in triple");
837 }
838 }
839}
840
841static bool isVkPipelineBuiltin(const ASTContext &AstContext, FunctionDecl *FD,
842 HLSLAppliedSemanticAttr *Semantic,
843 bool IsInput) {
844 if (AstContext.getTargetInfo().getTriple().getOS() != llvm::Triple::Vulkan)
845 return false;
846
847 const auto *ShaderAttr = FD->getAttr<HLSLShaderAttr>();
848 assert(ShaderAttr && "Entry point has no shader attribute");
849 llvm::Triple::EnvironmentType ST = ShaderAttr->getType();
850 auto SemanticName = Semantic->getSemanticName().upper();
851
852 // The SV_Position semantic is lowered to:
853 // - Position built-in for vertex output.
854 // - FragCoord built-in for fragment input.
855 if (SemanticName == "SV_POSITION") {
856 return (ST == llvm::Triple::Vertex && !IsInput) ||
857 (ST == llvm::Triple::Pixel && IsInput);
858 }
859 if (SemanticName == "SV_VERTEXID")
860 return true;
861
862 return false;
863}
864
865bool SemaHLSL::determineActiveSemanticOnScalar(FunctionDecl *FD,
866 DeclaratorDecl *OutputDecl,
868 SemanticInfo &ActiveSemantic,
869 SemaHLSL::SemanticContext &SC) {
870 if (ActiveSemantic.Semantic == nullptr) {
871 ActiveSemantic.Semantic = D->getAttr<HLSLParsedSemanticAttr>();
872 if (ActiveSemantic.Semantic)
873 ActiveSemantic.Index = ActiveSemantic.Semantic->getSemanticIndex();
874 }
875
876 if (!ActiveSemantic.Semantic) {
877 Diag(D->getLocation(), diag::err_hlsl_missing_semantic_annotation);
878 return false;
879 }
880
881 auto *A = ::new (getASTContext())
882 HLSLAppliedSemanticAttr(getASTContext(), *ActiveSemantic.Semantic,
883 ActiveSemantic.Semantic->getAttrName()->getName(),
884 ActiveSemantic.Index.value_or(0));
885 if (!A)
887
888 checkSemanticAnnotation(FD, D, A, SC);
889 OutputDecl->addAttr(A);
890
891 unsigned Location = ActiveSemantic.Index.value_or(0);
892
894 SC.CurrentIOType & IOType::In)) {
895 bool HasVkLocation = false;
896 if (auto *A = D->getAttr<HLSLVkLocationAttr>()) {
897 HasVkLocation = true;
898 Location = A->getLocation();
899 }
900
901 if (SC.UsesExplicitVkLocations.value_or(HasVkLocation) != HasVkLocation) {
902 Diag(D->getLocation(), diag::err_hlsl_semantic_partial_explicit_indexing);
903 return false;
904 }
905 SC.UsesExplicitVkLocations = HasVkLocation;
906 }
907
908 const ConstantArrayType *AT = dyn_cast<ConstantArrayType>(D->getType());
909 unsigned ElementCount = AT ? AT->getZExtSize() : 1;
910 ActiveSemantic.Index = Location + ElementCount;
911
912 Twine BaseName = Twine(ActiveSemantic.Semantic->getAttrName()->getName());
913 for (unsigned I = 0; I < ElementCount; ++I) {
914 Twine VariableName = BaseName.concat(Twine(Location + I));
915
916 auto [_, Inserted] = SC.ActiveSemantics.insert(VariableName.str());
917 if (!Inserted) {
918 Diag(D->getLocation(), diag::err_hlsl_semantic_index_overlap)
919 << VariableName.str();
920 return false;
921 }
922 }
923
924 return true;
925}
926
927bool SemaHLSL::determineActiveSemantic(FunctionDecl *FD,
928 DeclaratorDecl *OutputDecl,
930 SemanticInfo &ActiveSemantic,
931 SemaHLSL::SemanticContext &SC) {
932 if (ActiveSemantic.Semantic == nullptr) {
933 ActiveSemantic.Semantic = D->getAttr<HLSLParsedSemanticAttr>();
934 if (ActiveSemantic.Semantic)
935 ActiveSemantic.Index = ActiveSemantic.Semantic->getSemanticIndex();
936 }
937
938 const Type *T = D == FD ? &*FD->getReturnType() : &*D->getType();
940
941 const RecordType *RT = dyn_cast<RecordType>(T);
942 if (!RT)
943 return determineActiveSemanticOnScalar(FD, OutputDecl, D, ActiveSemantic,
944 SC);
945
946 const RecordDecl *RD = RT->getDecl();
947 for (FieldDecl *Field : RD->fields()) {
948 SemanticInfo Info = ActiveSemantic;
949 if (!determineActiveSemantic(FD, OutputDecl, Field, Info, SC)) {
950 Diag(Field->getLocation(), diag::note_hlsl_semantic_used_here) << Field;
951 return false;
952 }
953 if (ActiveSemantic.Semantic)
954 ActiveSemantic = Info;
955 }
956
957 return true;
958}
959
961 const auto *ShaderAttr = FD->getAttr<HLSLShaderAttr>();
962 assert(ShaderAttr && "Entry point has no shader attribute");
963 llvm::Triple::EnvironmentType ST = ShaderAttr->getType();
965 VersionTuple Ver = TargetInfo.getTriple().getOSVersion();
966 switch (ST) {
967 case llvm::Triple::Pixel:
968 case llvm::Triple::Vertex:
969 case llvm::Triple::Geometry:
970 case llvm::Triple::Hull:
971 case llvm::Triple::Domain:
972 case llvm::Triple::RayGeneration:
973 case llvm::Triple::Intersection:
974 case llvm::Triple::AnyHit:
975 case llvm::Triple::ClosestHit:
976 case llvm::Triple::Miss:
977 case llvm::Triple::Callable:
978 if (const auto *NT = FD->getAttr<HLSLNumThreadsAttr>()) {
979 diagnoseAttrStageMismatch(NT, ST,
980 {llvm::Triple::Compute,
981 llvm::Triple::Amplification,
982 llvm::Triple::Mesh});
983 FD->setInvalidDecl();
984 }
985 if (const auto *WS = FD->getAttr<HLSLWaveSizeAttr>()) {
986 diagnoseAttrStageMismatch(WS, ST,
987 {llvm::Triple::Compute,
988 llvm::Triple::Amplification,
989 llvm::Triple::Mesh});
990 FD->setInvalidDecl();
991 }
992 break;
993
994 case llvm::Triple::Compute:
995 case llvm::Triple::Amplification:
996 case llvm::Triple::Mesh:
997 if (!FD->hasAttr<HLSLNumThreadsAttr>()) {
998 Diag(FD->getLocation(), diag::err_hlsl_missing_numthreads)
999 << llvm::Triple::getEnvironmentTypeName(ST);
1000 FD->setInvalidDecl();
1001 }
1002 if (const auto *WS = FD->getAttr<HLSLWaveSizeAttr>()) {
1003 if (Ver < VersionTuple(6, 6)) {
1004 Diag(WS->getLocation(), diag::err_hlsl_attribute_in_wrong_shader_model)
1005 << WS << "6.6";
1006 FD->setInvalidDecl();
1007 } else if (WS->getSpelledArgsCount() > 1 && Ver < VersionTuple(6, 8)) {
1008 Diag(
1009 WS->getLocation(),
1010 diag::err_hlsl_attribute_number_arguments_insufficient_shader_model)
1011 << WS << WS->getSpelledArgsCount() << "6.8";
1012 FD->setInvalidDecl();
1013 }
1014 }
1015 break;
1016 case llvm::Triple::RootSignature:
1017 llvm_unreachable("rootsig environment has no function entry point");
1018 default:
1019 llvm_unreachable("Unhandled environment in triple");
1020 }
1021
1022 SemaHLSL::SemanticContext InputSC = {};
1023 InputSC.CurrentIOType = IOType::In;
1024
1025 for (ParmVarDecl *Param : FD->parameters()) {
1026 SemanticInfo ActiveSemantic;
1027 ActiveSemantic.Semantic = Param->getAttr<HLSLParsedSemanticAttr>();
1028 if (ActiveSemantic.Semantic)
1029 ActiveSemantic.Index = ActiveSemantic.Semantic->getSemanticIndex();
1030
1031 // FIXME: Verify output semantics in parameters.
1032 if (!determineActiveSemantic(FD, Param, Param, ActiveSemantic, InputSC)) {
1033 Diag(Param->getLocation(), diag::note_previous_decl) << Param;
1034 FD->setInvalidDecl();
1035 }
1036 }
1037
1038 SemanticInfo ActiveSemantic;
1039 SemaHLSL::SemanticContext OutputSC = {};
1040 OutputSC.CurrentIOType = IOType::Out;
1041 ActiveSemantic.Semantic = FD->getAttr<HLSLParsedSemanticAttr>();
1042 if (ActiveSemantic.Semantic)
1043 ActiveSemantic.Index = ActiveSemantic.Semantic->getSemanticIndex();
1044 if (!FD->getReturnType()->isVoidType())
1045 determineActiveSemantic(FD, FD, FD, ActiveSemantic, OutputSC);
1046}
1047
1048void SemaHLSL::checkSemanticAnnotation(
1049 FunctionDecl *EntryPoint, const Decl *Param,
1050 const HLSLAppliedSemanticAttr *SemanticAttr, const SemanticContext &SC) {
1051 auto *ShaderAttr = EntryPoint->getAttr<HLSLShaderAttr>();
1052 assert(ShaderAttr && "Entry point has no shader attribute");
1053 llvm::Triple::EnvironmentType ST = ShaderAttr->getType();
1054
1055 auto SemanticName = SemanticAttr->getSemanticName().upper();
1056 if (SemanticName == "SV_DISPATCHTHREADID" ||
1057 SemanticName == "SV_GROUPINDEX" || SemanticName == "SV_GROUPTHREADID" ||
1058 SemanticName == "SV_GROUPID") {
1059
1060 if (ST != llvm::Triple::Compute)
1061 diagnoseSemanticStageMismatch(SemanticAttr, ST, SC.CurrentIOType,
1062 {{llvm::Triple::Compute, IOType::In}});
1063
1064 if (SemanticAttr->getSemanticIndex() != 0) {
1065 std::string PrettyName =
1066 "'" + SemanticAttr->getSemanticName().str() + "'";
1067 Diag(SemanticAttr->getLoc(),
1068 diag::err_hlsl_semantic_indexing_not_supported)
1069 << PrettyName;
1070 }
1071 return;
1072 }
1073
1074 if (SemanticName == "SV_POSITION") {
1075 // SV_Position can be an input or output in vertex shaders,
1076 // but only an input in pixel shaders.
1077 diagnoseSemanticStageMismatch(SemanticAttr, ST, SC.CurrentIOType,
1078 {{llvm::Triple::Vertex, IOType::InOut},
1079 {llvm::Triple::Pixel, IOType::In}});
1080 return;
1081 }
1082 if (SemanticName == "SV_VERTEXID") {
1083 diagnoseSemanticStageMismatch(SemanticAttr, ST, SC.CurrentIOType,
1084 {{llvm::Triple::Vertex, IOType::In}});
1085 return;
1086 }
1087
1088 if (SemanticName == "SV_TARGET") {
1089 diagnoseSemanticStageMismatch(SemanticAttr, ST, SC.CurrentIOType,
1090 {{llvm::Triple::Pixel, IOType::Out}});
1091 return;
1092 }
1093
1094 // FIXME: catch-all for non-implemented system semantics reaching this
1095 // location.
1096 if (SemanticAttr->getAttrName()->getName().starts_with_insensitive("SV_"))
1097 llvm_unreachable("Unknown SemanticAttr");
1098}
1099
1100void SemaHLSL::diagnoseAttrStageMismatch(
1101 const Attr *A, llvm::Triple::EnvironmentType Stage,
1102 std::initializer_list<llvm::Triple::EnvironmentType> AllowedStages) {
1103 SmallVector<StringRef, 8> StageStrings;
1104 llvm::transform(AllowedStages, std::back_inserter(StageStrings),
1105 [](llvm::Triple::EnvironmentType ST) {
1106 return StringRef(
1107 HLSLShaderAttr::ConvertEnvironmentTypeToStr(ST));
1108 });
1109 Diag(A->getLoc(), diag::err_hlsl_attr_unsupported_in_stage)
1110 << A->getAttrName() << llvm::Triple::getEnvironmentTypeName(Stage)
1111 << (AllowedStages.size() != 1) << join(StageStrings, ", ");
1112}
1113
1114void SemaHLSL::diagnoseSemanticStageMismatch(
1115 const Attr *A, llvm::Triple::EnvironmentType Stage, IOType CurrentIOType,
1116 std::initializer_list<SemanticStageInfo> Allowed) {
1117
1118 for (auto &Case : Allowed) {
1119 if (Case.Stage != Stage)
1120 continue;
1121
1122 if (CurrentIOType & Case.AllowedIOTypesMask)
1123 return;
1124
1125 SmallVector<std::string, 8> ValidCases;
1126 llvm::transform(
1127 Allowed, std::back_inserter(ValidCases), [](SemanticStageInfo Case) {
1128 SmallVector<std::string, 2> ValidType;
1129 if (Case.AllowedIOTypesMask & IOType::In)
1130 ValidType.push_back("input");
1131 if (Case.AllowedIOTypesMask & IOType::Out)
1132 ValidType.push_back("output");
1133 return std::string(
1134 HLSLShaderAttr::ConvertEnvironmentTypeToStr(Case.Stage)) +
1135 " " + join(ValidType, "/");
1136 });
1137 Diag(A->getLoc(), diag::err_hlsl_semantic_unsupported_iotype_for_stage)
1138 << A->getAttrName() << (CurrentIOType & IOType::In ? "input" : "output")
1139 << llvm::Triple::getEnvironmentTypeName(Case.Stage)
1140 << join(ValidCases, ", ");
1141 return;
1142 }
1143
1144 SmallVector<StringRef, 8> StageStrings;
1145 llvm::transform(
1146 Allowed, std::back_inserter(StageStrings), [](SemanticStageInfo Case) {
1147 return StringRef(
1148 HLSLShaderAttr::ConvertEnvironmentTypeToStr(Case.Stage));
1149 });
1150
1151 Diag(A->getLoc(), diag::err_hlsl_attr_unsupported_in_stage)
1152 << A->getAttrName() << llvm::Triple::getEnvironmentTypeName(Stage)
1153 << (Allowed.size() != 1) << join(StageStrings, ", ");
1154}
1155
1156template <CastKind Kind>
1157static void castVector(Sema &S, ExprResult &E, QualType &Ty, unsigned Sz) {
1158 if (const auto *VTy = Ty->getAs<VectorType>())
1159 Ty = VTy->getElementType();
1160 Ty = S.getASTContext().getExtVectorType(Ty, Sz);
1161 E = S.ImpCastExprToType(E.get(), Ty, Kind);
1162}
1163
1164template <CastKind Kind>
1166 E = S.ImpCastExprToType(E.get(), Ty, Kind);
1167 return Ty;
1168}
1169
1171 Sema &SemaRef, ExprResult &LHS, ExprResult &RHS, QualType LHSType,
1172 QualType RHSType, QualType LElTy, QualType RElTy, bool IsCompAssign) {
1173 bool LHSFloat = LElTy->isRealFloatingType();
1174 bool RHSFloat = RElTy->isRealFloatingType();
1175
1176 if (LHSFloat && RHSFloat) {
1177 if (IsCompAssign ||
1178 SemaRef.getASTContext().getFloatingTypeOrder(LElTy, RElTy) > 0)
1179 return castElement<CK_FloatingCast>(SemaRef, RHS, LHSType);
1180
1181 return castElement<CK_FloatingCast>(SemaRef, LHS, RHSType);
1182 }
1183
1184 if (LHSFloat)
1185 return castElement<CK_IntegralToFloating>(SemaRef, RHS, LHSType);
1186
1187 assert(RHSFloat);
1188 if (IsCompAssign)
1189 return castElement<clang::CK_FloatingToIntegral>(SemaRef, RHS, LHSType);
1190
1191 return castElement<CK_IntegralToFloating>(SemaRef, LHS, RHSType);
1192}
1193
1195 Sema &SemaRef, ExprResult &LHS, ExprResult &RHS, QualType LHSType,
1196 QualType RHSType, QualType LElTy, QualType RElTy, bool IsCompAssign) {
1197
1198 int IntOrder = SemaRef.Context.getIntegerTypeOrder(LElTy, RElTy);
1199 bool LHSSigned = LElTy->hasSignedIntegerRepresentation();
1200 bool RHSSigned = RElTy->hasSignedIntegerRepresentation();
1201 auto &Ctx = SemaRef.getASTContext();
1202
1203 // If both types have the same signedness, use the higher ranked type.
1204 if (LHSSigned == RHSSigned) {
1205 if (IsCompAssign || IntOrder >= 0)
1206 return castElement<CK_IntegralCast>(SemaRef, RHS, LHSType);
1207
1208 return castElement<CK_IntegralCast>(SemaRef, LHS, RHSType);
1209 }
1210
1211 // If the unsigned type has greater than or equal rank of the signed type, use
1212 // the unsigned type.
1213 if (IntOrder != (LHSSigned ? 1 : -1)) {
1214 if (IsCompAssign || RHSSigned)
1215 return castElement<CK_IntegralCast>(SemaRef, RHS, LHSType);
1216 return castElement<CK_IntegralCast>(SemaRef, LHS, RHSType);
1217 }
1218
1219 // At this point the signed type has higher rank than the unsigned type, which
1220 // means it will be the same size or bigger. If the signed type is bigger, it
1221 // can represent all the values of the unsigned type, so select it.
1222 if (Ctx.getIntWidth(LElTy) != Ctx.getIntWidth(RElTy)) {
1223 if (IsCompAssign || LHSSigned)
1224 return castElement<CK_IntegralCast>(SemaRef, RHS, LHSType);
1225 return castElement<CK_IntegralCast>(SemaRef, LHS, RHSType);
1226 }
1227
1228 // This is a bit of an odd duck case in HLSL. It shouldn't happen, but can due
1229 // to C/C++ leaking through. The place this happens today is long vs long
1230 // long. When arguments are vector<unsigned long, N> and vector<long long, N>,
1231 // the long long has higher rank than long even though they are the same size.
1232
1233 // If this is a compound assignment cast the right hand side to the left hand
1234 // side's type.
1235 if (IsCompAssign)
1236 return castElement<CK_IntegralCast>(SemaRef, RHS, LHSType);
1237
1238 // If this isn't a compound assignment we convert to unsigned long long.
1239 QualType ElTy = Ctx.getCorrespondingUnsignedType(LHSSigned ? LElTy : RElTy);
1240 QualType NewTy = Ctx.getExtVectorType(
1241 ElTy, RHSType->castAs<VectorType>()->getNumElements());
1242 (void)castElement<CK_IntegralCast>(SemaRef, RHS, NewTy);
1243
1244 return castElement<CK_IntegralCast>(SemaRef, LHS, NewTy);
1245}
1246
1248 QualType SrcTy) {
1249 if (DestTy->isRealFloatingType() && SrcTy->isRealFloatingType())
1250 return CK_FloatingCast;
1251 if (DestTy->isIntegralType(Ctx) && SrcTy->isIntegralType(Ctx))
1252 return CK_IntegralCast;
1253 if (DestTy->isRealFloatingType())
1254 return CK_IntegralToFloating;
1255 assert(SrcTy->isRealFloatingType() && DestTy->isIntegralType(Ctx));
1256 return CK_FloatingToIntegral;
1257}
1258
1260 QualType LHSType,
1261 QualType RHSType,
1262 bool IsCompAssign) {
1263 const auto *LVecTy = LHSType->getAs<VectorType>();
1264 const auto *RVecTy = RHSType->getAs<VectorType>();
1265 auto &Ctx = getASTContext();
1266
1267 // If the LHS is not a vector and this is a compound assignment, we truncate
1268 // the argument to a scalar then convert it to the LHS's type.
1269 if (!LVecTy && IsCompAssign) {
1270 QualType RElTy = RHSType->castAs<VectorType>()->getElementType();
1271 RHS = SemaRef.ImpCastExprToType(RHS.get(), RElTy, CK_HLSLVectorTruncation);
1272 RHSType = RHS.get()->getType();
1273 if (Ctx.hasSameUnqualifiedType(LHSType, RHSType))
1274 return LHSType;
1275 RHS = SemaRef.ImpCastExprToType(RHS.get(), LHSType,
1276 getScalarCastKind(Ctx, LHSType, RHSType));
1277 return LHSType;
1278 }
1279
1280 unsigned EndSz = std::numeric_limits<unsigned>::max();
1281 unsigned LSz = 0;
1282 if (LVecTy)
1283 LSz = EndSz = LVecTy->getNumElements();
1284 if (RVecTy)
1285 EndSz = std::min(RVecTy->getNumElements(), EndSz);
1286 assert(EndSz != std::numeric_limits<unsigned>::max() &&
1287 "one of the above should have had a value");
1288
1289 // In a compound assignment, the left operand does not change type, the right
1290 // operand is converted to the type of the left operand.
1291 if (IsCompAssign && LSz != EndSz) {
1292 Diag(LHS.get()->getBeginLoc(),
1293 diag::err_hlsl_vector_compound_assignment_truncation)
1294 << LHSType << RHSType;
1295 return QualType();
1296 }
1297
1298 if (RVecTy && RVecTy->getNumElements() > EndSz)
1299 castVector<CK_HLSLVectorTruncation>(SemaRef, RHS, RHSType, EndSz);
1300 if (!IsCompAssign && LVecTy && LVecTy->getNumElements() > EndSz)
1301 castVector<CK_HLSLVectorTruncation>(SemaRef, LHS, LHSType, EndSz);
1302
1303 if (!RVecTy)
1304 castVector<CK_VectorSplat>(SemaRef, RHS, RHSType, EndSz);
1305 if (!IsCompAssign && !LVecTy)
1306 castVector<CK_VectorSplat>(SemaRef, LHS, LHSType, EndSz);
1307
1308 // If we're at the same type after resizing we can stop here.
1309 if (Ctx.hasSameUnqualifiedType(LHSType, RHSType))
1310 return Ctx.getCommonSugaredType(LHSType, RHSType);
1311
1312 QualType LElTy = LHSType->castAs<VectorType>()->getElementType();
1313 QualType RElTy = RHSType->castAs<VectorType>()->getElementType();
1314
1315 // Handle conversion for floating point vectors.
1316 if (LElTy->isRealFloatingType() || RElTy->isRealFloatingType())
1317 return handleFloatVectorBinOpConversion(SemaRef, LHS, RHS, LHSType, RHSType,
1318 LElTy, RElTy, IsCompAssign);
1319
1320 assert(LElTy->isIntegralType(Ctx) && RElTy->isIntegralType(Ctx) &&
1321 "HLSL Vectors can only contain integer or floating point types");
1322 return handleIntegerVectorBinOpConversion(SemaRef, LHS, RHS, LHSType, RHSType,
1323 LElTy, RElTy, IsCompAssign);
1324}
1325
1327 BinaryOperatorKind Opc) {
1328 assert((Opc == BO_LOr || Opc == BO_LAnd) &&
1329 "Called with non-logical operator");
1331 llvm::raw_svector_ostream OS(Buff);
1332 PrintingPolicy PP(SemaRef.getLangOpts());
1333 StringRef NewFnName = Opc == BO_LOr ? "or" : "and";
1334 OS << NewFnName << "(";
1335 LHS->printPretty(OS, nullptr, PP);
1336 OS << ", ";
1337 RHS->printPretty(OS, nullptr, PP);
1338 OS << ")";
1339 SourceRange FullRange = SourceRange(LHS->getBeginLoc(), RHS->getEndLoc());
1340 SemaRef.Diag(LHS->getBeginLoc(), diag::note_function_suggestion)
1341 << NewFnName << FixItHint::CreateReplacement(FullRange, OS.str());
1342}
1343
1344std::pair<IdentifierInfo *, bool>
1346 llvm::hash_code Hash = llvm::hash_value(Signature);
1347 std::string IdStr = "__hlsl_rootsig_decl_" + std::to_string(Hash);
1348 IdentifierInfo *DeclIdent = &(getASTContext().Idents.get(IdStr));
1349
1350 // Check if we have already found a decl of the same name.
1351 LookupResult R(SemaRef, DeclIdent, SourceLocation(),
1353 bool Found = SemaRef.LookupQualifiedName(R, SemaRef.CurContext);
1354 return {DeclIdent, Found};
1355}
1356
1358 SourceLocation Loc, IdentifierInfo *DeclIdent,
1360
1361 if (handleRootSignatureElements(RootElements))
1362 return;
1363
1365 for (auto &RootSigElement : RootElements)
1366 Elements.push_back(RootSigElement.getElement());
1367
1368 auto *SignatureDecl = HLSLRootSignatureDecl::Create(
1369 SemaRef.getASTContext(), /*DeclContext=*/SemaRef.CurContext, Loc,
1370 DeclIdent, SemaRef.getLangOpts().HLSLRootSigVer, Elements);
1371
1372 SignatureDecl->setImplicit();
1373 SemaRef.PushOnScopeChains(SignatureDecl, SemaRef.getCurScope());
1374}
1375
1378 if (RootSigOverrideIdent) {
1379 LookupResult R(SemaRef, RootSigOverrideIdent, SourceLocation(),
1381 if (SemaRef.LookupQualifiedName(R, DC))
1382 return dyn_cast<HLSLRootSignatureDecl>(R.getFoundDecl());
1383 }
1384
1385 return nullptr;
1386}
1387
1388namespace {
1389
1390struct PerVisibilityBindingChecker {
1391 SemaHLSL *S;
1392 // We need one builder per `llvm::dxbc::ShaderVisibility` value.
1393 std::array<llvm::hlsl::BindingInfoBuilder, 8> Builders;
1394
1395 struct ElemInfo {
1396 const hlsl::RootSignatureElement *Elem;
1397 llvm::dxbc::ShaderVisibility Vis;
1398 bool Diagnosed;
1399 };
1400 llvm::SmallVector<ElemInfo> ElemInfoMap;
1401
1402 PerVisibilityBindingChecker(SemaHLSL *S) : S(S) {}
1403
1404 void trackBinding(llvm::dxbc::ShaderVisibility Visibility,
1405 llvm::dxil::ResourceClass RC, uint32_t Space,
1406 uint32_t LowerBound, uint32_t UpperBound,
1407 const hlsl::RootSignatureElement *Elem) {
1408 uint32_t BuilderIndex = llvm::to_underlying(Visibility);
1409 assert(BuilderIndex < Builders.size() &&
1410 "Not enough builders for visibility type");
1411 Builders[BuilderIndex].trackBinding(RC, Space, LowerBound, UpperBound,
1412 static_cast<const void *>(Elem));
1413
1414 static_assert(llvm::to_underlying(llvm::dxbc::ShaderVisibility::All) == 0,
1415 "'All' visibility must come first");
1416 if (Visibility == llvm::dxbc::ShaderVisibility::All)
1417 for (size_t I = 1, E = Builders.size(); I < E; ++I)
1418 Builders[I].trackBinding(RC, Space, LowerBound, UpperBound,
1419 static_cast<const void *>(Elem));
1420
1421 ElemInfoMap.push_back({Elem, Visibility, false});
1422 }
1423
1424 ElemInfo &getInfo(const hlsl::RootSignatureElement *Elem) {
1425 auto It = llvm::lower_bound(
1426 ElemInfoMap, Elem,
1427 [](const auto &LHS, const auto &RHS) { return LHS.Elem < RHS; });
1428 assert(It->Elem == Elem && "Element not in map");
1429 return *It;
1430 }
1431
1432 bool checkOverlap() {
1433 llvm::sort(ElemInfoMap, [](const auto &LHS, const auto &RHS) {
1434 return LHS.Elem < RHS.Elem;
1435 });
1436
1437 bool HadOverlap = false;
1438
1439 using llvm::hlsl::BindingInfoBuilder;
1440 auto ReportOverlap = [this,
1441 &HadOverlap](const BindingInfoBuilder &Builder,
1442 const llvm::hlsl::Binding &Reported) {
1443 HadOverlap = true;
1444
1445 const auto *Elem =
1446 static_cast<const hlsl::RootSignatureElement *>(Reported.Cookie);
1447 const llvm::hlsl::Binding &Previous = Builder.findOverlapping(Reported);
1448 const auto *PrevElem =
1449 static_cast<const hlsl::RootSignatureElement *>(Previous.Cookie);
1450
1451 ElemInfo &Info = getInfo(Elem);
1452 // We will have already diagnosed this binding if there's overlap in the
1453 // "All" visibility as well as any particular visibility.
1454 if (Info.Diagnosed)
1455 return;
1456 Info.Diagnosed = true;
1457
1458 ElemInfo &PrevInfo = getInfo(PrevElem);
1459 llvm::dxbc::ShaderVisibility CommonVis =
1460 Info.Vis == llvm::dxbc::ShaderVisibility::All ? PrevInfo.Vis
1461 : Info.Vis;
1462
1463 this->S->Diag(Elem->getLocation(), diag::err_hlsl_resource_range_overlap)
1464 << llvm::to_underlying(Reported.RC) << Reported.LowerBound
1465 << Reported.isUnbounded() << Reported.UpperBound
1466 << llvm::to_underlying(Previous.RC) << Previous.LowerBound
1467 << Previous.isUnbounded() << Previous.UpperBound << Reported.Space
1468 << CommonVis;
1469
1470 this->S->Diag(PrevElem->getLocation(),
1471 diag::note_hlsl_resource_range_here);
1472 };
1473
1474 for (BindingInfoBuilder &Builder : Builders)
1475 Builder.calculateBindingInfo(ReportOverlap);
1476
1477 return HadOverlap;
1478 }
1479};
1480
1481static CXXMethodDecl *lookupMethod(Sema &S, CXXRecordDecl *RecordDecl,
1482 StringRef Name, SourceLocation Loc) {
1483 DeclarationName DeclName(&S.getASTContext().Idents.get(Name));
1484 LookupResult Result(S, DeclName, Loc, Sema::LookupMemberName);
1485 if (!S.LookupQualifiedName(Result, static_cast<DeclContext *>(RecordDecl)))
1486 return nullptr;
1487 return cast<CXXMethodDecl>(Result.getFoundDecl());
1488}
1489
1490} // end anonymous namespace
1491
1492static bool hasCounterHandle(const CXXRecordDecl *RD) {
1493 if (RD->field_empty())
1494 return false;
1495 auto It = std::next(RD->field_begin());
1496 if (It == RD->field_end())
1497 return false;
1498 const FieldDecl *SecondField = *It;
1499 if (const auto *ResTy =
1500 SecondField->getType()->getAs<HLSLAttributedResourceType>()) {
1501 return ResTy->getAttrs().IsCounter;
1502 }
1503 return false;
1504}
1505
1508 // Define some common error handling functions
1509 bool HadError = false;
1510 auto ReportError = [this, &HadError](SourceLocation Loc, uint32_t LowerBound,
1511 uint32_t UpperBound) {
1512 HadError = true;
1513 this->Diag(Loc, diag::err_hlsl_invalid_rootsig_value)
1514 << LowerBound << UpperBound;
1515 };
1516
1517 auto ReportFloatError = [this, &HadError](SourceLocation Loc,
1518 float LowerBound,
1519 float UpperBound) {
1520 HadError = true;
1521 this->Diag(Loc, diag::err_hlsl_invalid_rootsig_value)
1522 << llvm::formatv("{0:f}", LowerBound).sstr<6>()
1523 << llvm::formatv("{0:f}", UpperBound).sstr<6>();
1524 };
1525
1526 auto VerifyRegister = [ReportError](SourceLocation Loc, uint32_t Register) {
1527 if (!llvm::hlsl::rootsig::verifyRegisterValue(Register))
1528 ReportError(Loc, 0, 0xfffffffe);
1529 };
1530
1531 auto VerifySpace = [ReportError](SourceLocation Loc, uint32_t Space) {
1532 if (!llvm::hlsl::rootsig::verifyRegisterSpace(Space))
1533 ReportError(Loc, 0, 0xffffffef);
1534 };
1535
1536 const uint32_t Version =
1537 llvm::to_underlying(SemaRef.getLangOpts().HLSLRootSigVer);
1538 const uint32_t VersionEnum = Version - 1;
1539 auto ReportFlagError = [this, &HadError, VersionEnum](SourceLocation Loc) {
1540 HadError = true;
1541 this->Diag(Loc, diag::err_hlsl_invalid_rootsig_flag)
1542 << /*version minor*/ VersionEnum;
1543 };
1544
1545 // Iterate through the elements and do basic validations
1546 for (const hlsl::RootSignatureElement &RootSigElem : Elements) {
1547 SourceLocation Loc = RootSigElem.getLocation();
1548 const llvm::hlsl::rootsig::RootElement &Elem = RootSigElem.getElement();
1549 if (const auto *Descriptor =
1550 std::get_if<llvm::hlsl::rootsig::RootDescriptor>(&Elem)) {
1551 VerifyRegister(Loc, Descriptor->Reg.Number);
1552 VerifySpace(Loc, Descriptor->Space);
1553
1554 if (!llvm::hlsl::rootsig::verifyRootDescriptorFlag(Version,
1555 Descriptor->Flags))
1556 ReportFlagError(Loc);
1557 } else if (const auto *Constants =
1558 std::get_if<llvm::hlsl::rootsig::RootConstants>(&Elem)) {
1559 VerifyRegister(Loc, Constants->Reg.Number);
1560 VerifySpace(Loc, Constants->Space);
1561 } else if (const auto *Sampler =
1562 std::get_if<llvm::hlsl::rootsig::StaticSampler>(&Elem)) {
1563 VerifyRegister(Loc, Sampler->Reg.Number);
1564 VerifySpace(Loc, Sampler->Space);
1565
1566 assert(!std::isnan(Sampler->MaxLOD) && !std::isnan(Sampler->MinLOD) &&
1567 "By construction, parseFloatParam can't produce a NaN from a "
1568 "float_literal token");
1569
1570 if (!llvm::hlsl::rootsig::verifyMaxAnisotropy(Sampler->MaxAnisotropy))
1571 ReportError(Loc, 0, 16);
1572 if (!llvm::hlsl::rootsig::verifyMipLODBias(Sampler->MipLODBias))
1573 ReportFloatError(Loc, -16.f, 15.99f);
1574 } else if (const auto *Clause =
1575 std::get_if<llvm::hlsl::rootsig::DescriptorTableClause>(
1576 &Elem)) {
1577 VerifyRegister(Loc, Clause->Reg.Number);
1578 VerifySpace(Loc, Clause->Space);
1579
1580 if (!llvm::hlsl::rootsig::verifyNumDescriptors(Clause->NumDescriptors)) {
1581 // NumDescriptor could techincally be ~0u but that is reserved for
1582 // unbounded, so the diagnostic will not report that as a valid int
1583 // value
1584 ReportError(Loc, 1, 0xfffffffe);
1585 }
1586
1587 if (!llvm::hlsl::rootsig::verifyDescriptorRangeFlag(Version, Clause->Type,
1588 Clause->Flags))
1589 ReportFlagError(Loc);
1590 }
1591 }
1592
1593 PerVisibilityBindingChecker BindingChecker(this);
1594 SmallVector<std::pair<const llvm::hlsl::rootsig::DescriptorTableClause *,
1596 UnboundClauses;
1597
1598 for (const hlsl::RootSignatureElement &RootSigElem : Elements) {
1599 const llvm::hlsl::rootsig::RootElement &Elem = RootSigElem.getElement();
1600 if (const auto *Descriptor =
1601 std::get_if<llvm::hlsl::rootsig::RootDescriptor>(&Elem)) {
1602 uint32_t LowerBound(Descriptor->Reg.Number);
1603 uint32_t UpperBound(LowerBound); // inclusive range
1604
1605 BindingChecker.trackBinding(
1606 Descriptor->Visibility,
1607 static_cast<llvm::dxil::ResourceClass>(Descriptor->Type),
1608 Descriptor->Space, LowerBound, UpperBound, &RootSigElem);
1609 } else if (const auto *Constants =
1610 std::get_if<llvm::hlsl::rootsig::RootConstants>(&Elem)) {
1611 uint32_t LowerBound(Constants->Reg.Number);
1612 uint32_t UpperBound(LowerBound); // inclusive range
1613
1614 BindingChecker.trackBinding(
1615 Constants->Visibility, llvm::dxil::ResourceClass::CBuffer,
1616 Constants->Space, LowerBound, UpperBound, &RootSigElem);
1617 } else if (const auto *Sampler =
1618 std::get_if<llvm::hlsl::rootsig::StaticSampler>(&Elem)) {
1619 uint32_t LowerBound(Sampler->Reg.Number);
1620 uint32_t UpperBound(LowerBound); // inclusive range
1621
1622 BindingChecker.trackBinding(
1623 Sampler->Visibility, llvm::dxil::ResourceClass::Sampler,
1624 Sampler->Space, LowerBound, UpperBound, &RootSigElem);
1625 } else if (const auto *Clause =
1626 std::get_if<llvm::hlsl::rootsig::DescriptorTableClause>(
1627 &Elem)) {
1628 // We'll process these once we see the table element.
1629 UnboundClauses.emplace_back(Clause, &RootSigElem);
1630 } else if (const auto *Table =
1631 std::get_if<llvm::hlsl::rootsig::DescriptorTable>(&Elem)) {
1632 assert(UnboundClauses.size() == Table->NumClauses &&
1633 "Number of unbound elements must match the number of clauses");
1634 bool HasAnySampler = false;
1635 bool HasAnyNonSampler = false;
1636 uint64_t Offset = 0;
1637 bool IsPrevUnbound = false;
1638 for (const auto &[Clause, ClauseElem] : UnboundClauses) {
1639 SourceLocation Loc = ClauseElem->getLocation();
1640 if (Clause->Type == llvm::dxil::ResourceClass::Sampler)
1641 HasAnySampler = true;
1642 else
1643 HasAnyNonSampler = true;
1644
1645 if (HasAnySampler && HasAnyNonSampler)
1646 Diag(Loc, diag::err_hlsl_invalid_mixed_resources);
1647
1648 // Relevant error will have already been reported above and needs to be
1649 // fixed before we can conduct further analysis, so shortcut error
1650 // return
1651 if (Clause->NumDescriptors == 0)
1652 return true;
1653
1654 bool IsAppending =
1655 Clause->Offset == llvm::hlsl::rootsig::DescriptorTableOffsetAppend;
1656 if (!IsAppending)
1657 Offset = Clause->Offset;
1658
1659 uint64_t RangeBound = llvm::hlsl::rootsig::computeRangeBound(
1660 Offset, Clause->NumDescriptors);
1661
1662 if (IsPrevUnbound && IsAppending)
1663 Diag(Loc, diag::err_hlsl_appending_onto_unbound);
1664 else if (!llvm::hlsl::rootsig::verifyNoOverflowedOffset(RangeBound))
1665 Diag(Loc, diag::err_hlsl_offset_overflow) << Offset << RangeBound;
1666
1667 // Update offset to be 1 past this range's bound
1668 Offset = RangeBound + 1;
1669 IsPrevUnbound = Clause->NumDescriptors ==
1670 llvm::hlsl::rootsig::NumDescriptorsUnbounded;
1671
1672 // Compute the register bounds and track resource binding
1673 uint32_t LowerBound(Clause->Reg.Number);
1674 uint32_t UpperBound = llvm::hlsl::rootsig::computeRangeBound(
1675 LowerBound, Clause->NumDescriptors);
1676
1677 BindingChecker.trackBinding(
1678 Table->Visibility,
1679 static_cast<llvm::dxil::ResourceClass>(Clause->Type), Clause->Space,
1680 LowerBound, UpperBound, ClauseElem);
1681 }
1682 UnboundClauses.clear();
1683 }
1684 }
1685
1686 return BindingChecker.checkOverlap();
1687}
1688
1690 if (AL.getNumArgs() != 1) {
1691 Diag(AL.getLoc(), diag::err_attribute_wrong_number_arguments) << AL << 1;
1692 return;
1693 }
1694
1696 if (auto *RS = D->getAttr<RootSignatureAttr>()) {
1697 if (RS->getSignatureIdent() != Ident) {
1698 Diag(AL.getLoc(), diag::err_disallowed_duplicate_attribute) << RS;
1699 return;
1700 }
1701
1702 Diag(AL.getLoc(), diag::warn_duplicate_attribute_exact) << RS;
1703 return;
1704 }
1705
1707 if (SemaRef.LookupQualifiedName(R, D->getDeclContext()))
1708 if (auto *SignatureDecl =
1709 dyn_cast<HLSLRootSignatureDecl>(R.getFoundDecl())) {
1710 D->addAttr(::new (getASTContext()) RootSignatureAttr(
1711 getASTContext(), AL, Ident, SignatureDecl));
1712 }
1713}
1714
1716 llvm::VersionTuple SMVersion =
1717 getASTContext().getTargetInfo().getTriple().getOSVersion();
1718 bool IsDXIL = getASTContext().getTargetInfo().getTriple().getArch() ==
1719 llvm::Triple::dxil;
1720
1721 uint32_t ZMax = 1024;
1722 uint32_t ThreadMax = 1024;
1723 if (IsDXIL && SMVersion.getMajor() <= 4) {
1724 ZMax = 1;
1725 ThreadMax = 768;
1726 } else if (IsDXIL && SMVersion.getMajor() == 5) {
1727 ZMax = 64;
1728 ThreadMax = 1024;
1729 }
1730
1731 uint32_t X;
1732 if (!SemaRef.checkUInt32Argument(AL, AL.getArgAsExpr(0), X))
1733 return;
1734 if (X > 1024) {
1735 Diag(AL.getArgAsExpr(0)->getExprLoc(),
1736 diag::err_hlsl_numthreads_argument_oor)
1737 << 0 << 1024;
1738 return;
1739 }
1740 uint32_t Y;
1741 if (!SemaRef.checkUInt32Argument(AL, AL.getArgAsExpr(1), Y))
1742 return;
1743 if (Y > 1024) {
1744 Diag(AL.getArgAsExpr(1)->getExprLoc(),
1745 diag::err_hlsl_numthreads_argument_oor)
1746 << 1 << 1024;
1747 return;
1748 }
1749 uint32_t Z;
1750 if (!SemaRef.checkUInt32Argument(AL, AL.getArgAsExpr(2), Z))
1751 return;
1752 if (Z > ZMax) {
1753 SemaRef.Diag(AL.getArgAsExpr(2)->getExprLoc(),
1754 diag::err_hlsl_numthreads_argument_oor)
1755 << 2 << ZMax;
1756 return;
1757 }
1758
1759 if (X * Y * Z > ThreadMax) {
1760 Diag(AL.getLoc(), diag::err_hlsl_numthreads_invalid) << ThreadMax;
1761 return;
1762 }
1763
1764 HLSLNumThreadsAttr *NewAttr = mergeNumThreadsAttr(D, AL, X, Y, Z);
1765 if (NewAttr)
1766 D->addAttr(NewAttr);
1767}
1768
1769static bool isValidWaveSizeValue(unsigned Value) {
1770 return llvm::isPowerOf2_32(Value) && Value >= 4 && Value <= 128;
1771}
1772
1774 // validate that the wavesize argument is a power of 2 between 4 and 128
1775 // inclusive
1776 unsigned SpelledArgsCount = AL.getNumArgs();
1777 if (SpelledArgsCount == 0 || SpelledArgsCount > 3)
1778 return;
1779
1780 uint32_t Min;
1781 if (!SemaRef.checkUInt32Argument(AL, AL.getArgAsExpr(0), Min))
1782 return;
1783
1784 uint32_t Max = 0;
1785 if (SpelledArgsCount > 1 &&
1786 !SemaRef.checkUInt32Argument(AL, AL.getArgAsExpr(1), Max))
1787 return;
1788
1789 uint32_t Preferred = 0;
1790 if (SpelledArgsCount > 2 &&
1791 !SemaRef.checkUInt32Argument(AL, AL.getArgAsExpr(2), Preferred))
1792 return;
1793
1794 if (SpelledArgsCount > 2) {
1795 if (!isValidWaveSizeValue(Preferred)) {
1796 Diag(AL.getArgAsExpr(2)->getExprLoc(),
1797 diag::err_attribute_power_of_two_in_range)
1798 << AL << llvm::dxil::MinWaveSize << llvm::dxil::MaxWaveSize
1799 << Preferred;
1800 return;
1801 }
1802 // Preferred not in range.
1803 if (Preferred < Min || Preferred > Max) {
1804 Diag(AL.getArgAsExpr(2)->getExprLoc(),
1805 diag::err_attribute_power_of_two_in_range)
1806 << AL << Min << Max << Preferred;
1807 return;
1808 }
1809 } else if (SpelledArgsCount > 1) {
1810 if (!isValidWaveSizeValue(Max)) {
1811 Diag(AL.getArgAsExpr(1)->getExprLoc(),
1812 diag::err_attribute_power_of_two_in_range)
1813 << AL << llvm::dxil::MinWaveSize << llvm::dxil::MaxWaveSize << Max;
1814 return;
1815 }
1816 if (Max < Min) {
1817 Diag(AL.getLoc(), diag::err_attribute_argument_invalid) << AL << 1;
1818 return;
1819 } else if (Max == Min) {
1820 Diag(AL.getLoc(), diag::warn_attr_min_eq_max) << AL;
1821 }
1822 } else {
1823 if (!isValidWaveSizeValue(Min)) {
1824 Diag(AL.getArgAsExpr(0)->getExprLoc(),
1825 diag::err_attribute_power_of_two_in_range)
1826 << AL << llvm::dxil::MinWaveSize << llvm::dxil::MaxWaveSize << Min;
1827 return;
1828 }
1829 }
1830
1831 HLSLWaveSizeAttr *NewAttr =
1832 mergeWaveSizeAttr(D, AL, Min, Max, Preferred, SpelledArgsCount);
1833 if (NewAttr)
1834 D->addAttr(NewAttr);
1835}
1836
1838 uint32_t ID;
1839 if (!SemaRef.checkUInt32Argument(AL, AL.getArgAsExpr(0), ID))
1840 return;
1841 D->addAttr(::new (getASTContext())
1842 HLSLVkExtBuiltinInputAttr(getASTContext(), AL, ID));
1843}
1844
1846 uint32_t ID;
1847 if (!SemaRef.checkUInt32Argument(AL, AL.getArgAsExpr(0), ID))
1848 return;
1849 D->addAttr(::new (getASTContext())
1850 HLSLVkExtBuiltinOutputAttr(getASTContext(), AL, ID));
1851}
1852
1854 D->addAttr(::new (getASTContext())
1855 HLSLVkPushConstantAttr(getASTContext(), AL));
1856}
1857
1859 uint32_t Id;
1860 if (!SemaRef.checkUInt32Argument(AL, AL.getArgAsExpr(0), Id))
1861 return;
1862 HLSLVkConstantIdAttr *NewAttr = mergeVkConstantIdAttr(D, AL, Id);
1863 if (NewAttr)
1864 D->addAttr(NewAttr);
1865}
1866
1868 uint32_t Binding = 0;
1869 if (!SemaRef.checkUInt32Argument(AL, AL.getArgAsExpr(0), Binding))
1870 return;
1871 uint32_t Set = 0;
1872 if (AL.getNumArgs() > 1 &&
1873 !SemaRef.checkUInt32Argument(AL, AL.getArgAsExpr(1), Set))
1874 return;
1875
1876 D->addAttr(::new (getASTContext())
1877 HLSLVkBindingAttr(getASTContext(), AL, Binding, Set));
1878}
1879
1881 uint32_t Location;
1882 if (!SemaRef.checkUInt32Argument(AL, AL.getArgAsExpr(0), Location))
1883 return;
1884
1885 D->addAttr(::new (getASTContext())
1886 HLSLVkLocationAttr(getASTContext(), AL, Location));
1887}
1888
1890 const auto *VT = T->getAs<VectorType>();
1891
1892 if (!T->hasUnsignedIntegerRepresentation() ||
1893 (VT && VT->getNumElements() > 3)) {
1894 Diag(AL.getLoc(), diag::err_hlsl_attr_invalid_type)
1895 << AL << "uint/uint2/uint3";
1896 return false;
1897 }
1898
1899 return true;
1900}
1901
1903 const auto *VT = T->getAs<VectorType>();
1904 if (!T->hasFloatingRepresentation() || (VT && VT->getNumElements() > 4)) {
1905 Diag(AL.getLoc(), diag::err_hlsl_attr_invalid_type)
1906 << AL << "float/float1/float2/float3/float4";
1907 return false;
1908 }
1909
1910 return true;
1911}
1912
1914 std::optional<unsigned> Index) {
// Validates a system-value ("SV_") semantic attached to D against the type it
// annotates. The semantic name is upper-cased first, so matching is
// case-insensitive. Index carries the semantic index only when one was written
// explicitly (see handleSemanticAttr).
1915 std::string SemanticName = AL.getAttrName()->getName().upper();
1916
// The checked type is the declared type, or the return type for a function.
1917 auto *VD = cast<ValueDecl>(D);
1918 QualType ValueType = VD->getType();
1919 if (auto *FD = dyn_cast<FunctionDecl>(D))
1920 ValueType = FD->getReturnType();
1921
// For an `out` parameter the declared type is a reference: check the
// referenced type instead, and remember that the value is an output.
1922 bool IsOutput = false;
1923 if (HLSLParamModifierAttr *MA = D->getAttr<HLSLParamModifierAttr>()) {
1924 if (MA->isOut()) {
1925 IsOutput = true;
1926 ValueType = cast<ReferenceType>(ValueType)->getPointeeType();
1927 }
1928 }
1929
// SV_DispatchThreadID: uint/uint2/uint3 input, not indexable.
1930 if (SemanticName == "SV_DISPATCHTHREADID") {
1931 diagnoseInputIDType(ValueType, AL);
1932 if (IsOutput)
1933 Diag(AL.getLoc(), diag::err_hlsl_semantic_output_not_supported) << AL;
1934 if (Index.has_value())
1935 Diag(AL.getLoc(), diag::err_hlsl_semantic_indexing_not_supported) << AL;
1937 return;
1938 }
1939
// SV_GroupIndex: input only, not indexable.
1940 if (SemanticName == "SV_GROUPINDEX") {
1941 if (IsOutput)
1942 Diag(AL.getLoc(), diag::err_hlsl_semantic_output_not_supported) << AL;
1943 if (Index.has_value())
1944 Diag(AL.getLoc(), diag::err_hlsl_semantic_indexing_not_supported) << AL;
1946 return;
1947 }
1948
// SV_GroupThreadID: uint/uint2/uint3 input, not indexable.
1949 if (SemanticName == "SV_GROUPTHREADID") {
1950 diagnoseInputIDType(ValueType, AL);
1951 if (IsOutput)
1952 Diag(AL.getLoc(), diag::err_hlsl_semantic_output_not_supported) << AL;
1953 if (Index.has_value())
1954 Diag(AL.getLoc(), diag::err_hlsl_semantic_indexing_not_supported) << AL;
1956 return;
1957 }
1958
// SV_GroupID: uint/uint2/uint3 input, not indexable.
1959 if (SemanticName == "SV_GROUPID") {
1960 diagnoseInputIDType(ValueType, AL);
1961 if (IsOutput)
1962 Diag(AL.getLoc(), diag::err_hlsl_semantic_output_not_supported) << AL;
1963 if (Index.has_value())
1964 Diag(AL.getLoc(), diag::err_hlsl_semantic_indexing_not_supported) << AL;
1966 return;
1967 }
1968
// SV_Position: float/float1..float4.
// NOTE(review): this repeats the check in diagnosePositionType.
1969 if (SemanticName == "SV_POSITION") {
1970 const auto *VT = ValueType->getAs<VectorType>();
1971 if (!ValueType->hasFloatingRepresentation() ||
1972 (VT && VT->getNumElements() > 4))
1973 Diag(AL.getLoc(), diag::err_hlsl_attr_invalid_type)
1974 << AL << "float/float1/float2/float3/float4";
1976 return;
1977 }
1978
// SV_VertexID: must be an unsigned integer type exactly 32 bits wide.
1979 if (SemanticName == "SV_VERTEXID") {
1980 uint64_t SizeInBits = SemaRef.Context.getTypeSize(ValueType);
1981 if (!ValueType->isUnsignedIntegerType() || SizeInBits != 32)
1982 Diag(AL.getLoc(), diag::err_hlsl_attr_invalid_type) << AL << "uint";
1984 return;
1985 }
1986
// SV_Target: float/float1..float4 (same check as SV_Position).
1987 if (SemanticName == "SV_TARGET") {
1988 const auto *VT = ValueType->getAs<VectorType>();
1989 if (!ValueType->hasFloatingRepresentation() ||
1990 (VT && VT->getNumElements() > 4))
1991 Diag(AL.getLoc(), diag::err_hlsl_attr_invalid_type)
1992 << AL << "float/float1/float2/float3/float4";
1994 return;
1995 }
1996
// Any other SV_ name is not a recognized system semantic.
1997 Diag(AL.getLoc(), diag::err_hlsl_unknown_semantic) << AL;
1998}
1999
// Decodes the two integer arguments of an unparsed semantic attribute.
// Argument 0 is the index value; argument 1 acts as a flag that is nonzero when
// the index was written explicitly. Names starting with "SV_" (any case) are
// then validated as system semantics.
2001 uint32_t IndexValue(0), ExplicitIndex(0);
2002 if (!SemaRef.checkUInt32Argument(AL, AL.getArgAsExpr(0), IndexValue) ||
2003 !SemaRef.checkUInt32Argument(AL, AL.getArgAsExpr(1), ExplicitIndex)) {
// NOTE(review): with assertions disabled this falls through and continues with
// whatever was decoded (both values start at 0).
2004 assert(0 && "HLSLUnparsedSemantic is expected to have 2 int arguments.");
2005 }
// A nonzero index value implies that the index was written explicitly.
2006 assert(IndexValue > 0 ? ExplicitIndex : true);
// Only an explicitly written index is forwarded; otherwise Index is empty.
2007 std::optional<unsigned> Index =
2008 ExplicitIndex ? std::optional<unsigned>(IndexValue) : std::nullopt;
2009
2010 if (AL.getAttrName()->getName().starts_with_insensitive("SV_"))
2011 diagnoseSystemSemanticAttr(D, AL, Index);
2012 else
2014}
2015
// Rejects a packoffset placed on anything other than a shader constant in a
// constant buffer (see the diagnostic text).
2018 Diag(AL.getLoc(), diag::err_hlsl_attr_invalid_ast_node)
2019 << AL << "shader constant in a constant buffer";
2020 return;
2021 }
2022
// Argument 0 is the register number (SubComponent); argument 1 is the 32-bit
// component offset within that register (Component).
2023 uint32_t SubComponent;
2024 if (!SemaRef.checkUInt32Argument(AL, AL.getArgAsExpr(0), SubComponent))
2025 return;
2026 uint32_t Component;
2027 if (!SemaRef.checkUInt32Argument(AL, AL.getArgAsExpr(1), Component))
2028 return;
2029
2030 QualType T = cast<VarDecl>(D)->getType().getCanonicalType();
2031 // Check if T is an array or struct type.
2032 // TODO: mark matrix type as aggregate type.
2033 bool IsAggregateTy = (T->isArrayType() || T->isStructureType());
2034
2035 // Check Component is valid for T.
2036 if (Component) {
// Size is in bits (ASTContext::getTypeSize).
2037 unsigned Size = getASTContext().getTypeSize(T);
2038 if (IsAggregateTy) {
2039 Diag(AL.getLoc(), diag::err_hlsl_invalid_register_or_packoffset);
2040 return;
2041 } else {
2042 // Make sure Component + sizeof(T) <= 4.
// (Measured in 32-bit components: a register is 4 x 32 = 128 bits.)
// NOTE(review): presumably Component is 0-3 (x/y/z/w); a very large value
// would wrap Component * 32 in unsigned arithmetic -- confirm with the parser.
2043 if ((Component * 32 + Size) > 128) {
2044 Diag(AL.getLoc(), diag::err_hlsl_packoffset_cross_reg_boundary);
2045 return;
2046 }
// Alignment is checked on the element type for vectors.
2047 QualType EltTy = T;
2048 if (const auto *VT = T->getAs<VectorType>())
2049 EltTy = VT->getElementType();
2050 unsigned Align = getASTContext().getTypeAlign(EltTy);
2051 if (Align > 32 && Component == 1) {
2052 // NOTE: Component 3 will hit err_hlsl_packoffset_cross_reg_boundary.
2053 // So we only need to check Component 1 here.
2054 Diag(AL.getLoc(), diag::err_hlsl_packoffset_alignment_mismatch)
2055 << Align << EltTy;
2056 return;
2057 }
2058 }
2059 }
2060
2061 D->addAttr(::new (getASTContext()) HLSLPackOffsetAttr(
2062 getASTContext(), AL, SubComponent, Component));
2063}
2064
2066 StringRef Str;
2067 SourceLocation ArgLoc;
2068 if (!SemaRef.checkStringLiteralArgumentAttr(AL, 0, Str, &ArgLoc))
2069 return;
2070
2071 llvm::Triple::EnvironmentType ShaderType;
2072 if (!HLSLShaderAttr::ConvertStrToEnvironmentType(Str, ShaderType)) {
2073 Diag(AL.getLoc(), diag::warn_attribute_type_not_supported)
2074 << AL << Str << ArgLoc;
2075 return;
2076 }
2077
2078 // FIXME: check function match the shader stage.
2079
2080 HLSLShaderAttr *NewAttr = mergeShaderAttr(D, AL, ShaderType);
2081 if (NewAttr)
2082 D->addAttr(NewAttr);
2083}
2084
2086 Sema &S, QualType Wrapped, ArrayRef<const Attr *> AttrList,
2087 QualType &ResType, HLSLAttributedResourceLocInfo *LocInfo) {
// Folds the resource type attributes in AttrList into the attributes of an
// HLSLAttributedResourceType that wraps Wrapped.
// Each attribute kind may appear at most once. A repeat is diagnosed (as an
// exact duplicate when the values match) and the function returns false. A
// missing resource class is also an error. LocInfo, if non-null, receives the
// source range of the attribute list and the contained type's TypeSourceInfo,
// but only when a contained type was given.
2088 assert(AttrList.size() && "expected list of resource attributes");
2089
2090 QualType ContainedTy = QualType();
2091 TypeSourceInfo *ContainedTyInfo = nullptr;
// NOTE(review): AttrList[0] (and AttrList.back() below) are dereferenced
// without a null check, although the loop tolerates null entries; presumably
// the first and last entries are never null -- verify.
2092 SourceLocation LocBegin = AttrList[0]->getRange().getBegin();
2093 SourceLocation LocEnd = AttrList[0]->getRange().getEnd();
2094
2095 HLSLAttributedResourceType::Attributes ResAttrs;
2096
2097 bool HasResourceClass = false;
2098 bool HasResourceDimension = false;
2099 for (const Attr *A : AttrList) {
// Null entries are skipped.
2100 if (!A)
2101 continue;
2102 LocEnd = A->getRange().getEnd();
2103 switch (A->getKind()) {
2104 case attr::HLSLResourceClass: {
2105 ResourceClass RC = cast<HLSLResourceClassAttr>(A)->getResourceClass();
2106 if (HasResourceClass) {
2107 S.Diag(A->getLocation(), ResAttrs.ResourceClass == RC
2108 ? diag::warn_duplicate_attribute_exact
2109 : diag::warn_duplicate_attribute)
2110 << A;
2111 return false;
2112 }
2113 ResAttrs.ResourceClass = RC;
2114 HasResourceClass = true;
2115 break;
2116 }
2117 case attr::HLSLResourceDimension: {
2118 llvm::dxil::ResourceDimension RD =
2119 cast<HLSLResourceDimensionAttr>(A)->getDimension();
2120 if (HasResourceDimension) {
2121 S.Diag(A->getLocation(), ResAttrs.ResourceDimension == RD
2122 ? diag::warn_duplicate_attribute_exact
2123 : diag::warn_duplicate_attribute)
2124 << A;
2125 return false;
2126 }
2127 ResAttrs.ResourceDimension = RD;
2128 HasResourceDimension = true;
2129 break;
2130 }
2131 case attr::HLSLROV:
2132 if (ResAttrs.IsROV) {
2133 S.Diag(A->getLocation(), diag::warn_duplicate_attribute_exact) << A;
2134 return false;
2135 }
2136 ResAttrs.IsROV = true;
2137 break;
2138 case attr::HLSLRawBuffer:
2139 if (ResAttrs.RawBuffer) {
2140 S.Diag(A->getLocation(), diag::warn_duplicate_attribute_exact) << A;
2141 return false;
2142 }
2143 ResAttrs.RawBuffer = true;
2144 break;
2145 case attr::HLSLIsCounter:
2146 if (ResAttrs.IsCounter) {
2147 S.Diag(A->getLocation(), diag::warn_duplicate_attribute_exact) << A;
2148 return false;
2149 }
2150 ResAttrs.IsCounter = true;
2151 break;
2152 case attr::HLSLContainedType: {
2153 const HLSLContainedTypeAttr *CTAttr = cast<HLSLContainedTypeAttr>(A);
2154 QualType Ty = CTAttr->getType();
2155 if (!ContainedTy.isNull()) {
2156 S.Diag(A->getLocation(), ContainedTy == Ty
2157 ? diag::warn_duplicate_attribute_exact
2158 : diag::warn_duplicate_attribute)
2159 << A;
2160 return false;
2161 }
2162 ContainedTy = Ty;
2163 ContainedTyInfo = CTAttr->getTypeLoc();
2164 break;
2165 }
2166 default:
2167 llvm_unreachable("unhandled resource attribute type");
2168 }
2169 }
2170
// A resource class is mandatory.
2171 if (!HasResourceClass) {
2172 S.Diag(AttrList.back()->getRange().getEnd(),
2173 diag::err_hlsl_missing_resource_class);
2174 return false;
2175 }
2176
2178 Wrapped, ContainedTy, ResAttrs);
2179
// Location info is reported only if requested and a contained type exists.
2180 if (LocInfo && ContainedTyInfo) {
2181 LocInfo->Range = SourceRange(LocBegin, LocEnd);
2182 LocInfo->ContainedTyInfo = ContainedTyInfo;
2183 }
2184 return true;
2185}
2186
2187// Validates and creates an HLSL attribute that is applied as type attribute on
2188// HLSL resource. The attributes are collected in HLSLResourcesTypeAttrs and at
2189// the end of the declaration they are applied to the declaration type by
2190// wrapping it in HLSLAttributedResourceType.
// Returns true if the attribute was recorded, false if it was rejected.
2192 // only allow resource type attributes on intangible types
2193 if (!T->isHLSLResourceType()) {
2194 Diag(AL.getLoc(), diag::err_hlsl_attribute_needs_intangible_type)
2195 << AL << getASTContext().HLSLResourceTy;
2196 return false;
2197 }
2198
2199 // validate number of arguments
2200 if (!AL.checkExactlyNumArgs(SemaRef, AL.getMinArgs()))
2201 return false;
2202
// A receives the attribute built by the switch below; it is appended to
// HLSLResourcesTypeAttrs at the end.
2203 Attr *A = nullptr;
2204
2208 {
2209 AttributeCommonInfo::AS_CXX11, 0, false /*IsAlignas*/,
2210 false /*IsRegularKeywordAttribute*/
2211 });
2212
2213 switch (AL.getKind()) {
// Resource class: identifier argument converted to a ResourceClass value.
2214 case ParsedAttr::AT_HLSLResourceClass: {
2215 if (!AL.isArgIdent(0)) {
2216 Diag(AL.getLoc(), diag::err_attribute_argument_type)
2217 << AL << AANT_ArgumentIdentifier;
2218 return false;
2219 }
2220
2221 IdentifierLoc *Loc = AL.getArgAsIdent(0);
2222 StringRef Identifier = Loc->getIdentifierInfo()->getName();
2223 SourceLocation ArgLoc = Loc->getLoc();
2224
2225 // Validate resource class value
2226 ResourceClass RC;
2227 if (!HLSLResourceClassAttr::ConvertStrToResourceClass(Identifier, RC)) {
2228 Diag(ArgLoc, diag::warn_attribute_type_not_supported)
2229 << "ResourceClass" << Identifier;
2230 return false;
2231 }
2232 A = HLSLResourceClassAttr::Create(getASTContext(), RC, ACI);
2233 break;
2234 }
2235
// Resource dimension: string-literal argument converted to a dimension.
2236 case ParsedAttr::AT_HLSLResourceDimension: {
2237 StringRef Identifier;
2238 SourceLocation ArgLoc;
2239 if (!SemaRef.checkStringLiteralArgumentAttr(AL, 0, Identifier, &ArgLoc))
2240 return false;
2241
2242 // Validate resource dimension value
2243 llvm::dxil::ResourceDimension RD;
2244 if (!HLSLResourceDimensionAttr::ConvertStrToResourceDimension(Identifier,
2245 RD)) {
2246 Diag(ArgLoc, diag::warn_attribute_type_not_supported)
2247 << "ResourceDimension" << Identifier;
2248 return false;
2249 }
2250 A = HLSLResourceDimensionAttr::Create(getASTContext(), RD, ACI);
2251 break;
2252 }
2253
// Flag attributes carry no value.
2254 case ParsedAttr::AT_HLSLROV:
2255 A = HLSLROVAttr::Create(getASTContext(), ACI);
2256 break;
2257
2258 case ParsedAttr::AT_HLSLRawBuffer:
2259 A = HLSLRawBufferAttr::Create(getASTContext(), ACI);
2260 break;
2261
2262 case ParsedAttr::AT_HLSLIsCounter:
2263 A = HLSLIsCounterAttr::Create(getASTContext(), ACI);
2264 break;
2265
// Contained type: a parsed type argument that must be a complete type.
2266 case ParsedAttr::AT_HLSLContainedType: {
2267 if (AL.getNumArgs() != 1 && !AL.hasParsedType()) {
2268 Diag(AL.getLoc(), diag::err_attribute_wrong_number_arguments) << AL << 1;
2269 return false;
2270 }
2271
2272 TypeSourceInfo *TSI = nullptr;
2273 QualType QT = SemaRef.GetTypeFromParser(AL.getTypeArg(), &TSI);
2274 assert(TSI && "no type source info for attribute argument");
2275 if (SemaRef.RequireCompleteType(TSI->getTypeLoc().getBeginLoc(), QT,
2276 diag::err_incomplete_type))
2277 return false;
2278 A = HLSLContainedTypeAttr::Create(getASTContext(), TSI, ACI);
2279 break;
2280 }
2281
2282 default:
2283 llvm_unreachable("unhandled HLSL attribute");
2284 }
2285
2286 HLSLResourcesTypeAttrs.emplace_back(A);
2287 return true;
2288}
2289
2290// Combines all resource type attributes and creates HLSLAttributedResourceType.
// The pending attributes come from HLSLResourcesTypeAttrs and are applied to
// CurrentType. When nothing is pending, CurrentType is returned unchanged.
// Otherwise the pending list is cleared afterwards, whether or not the
// attributed type was created.
2292 if (!HLSLResourcesTypeAttrs.size())
2293 return CurrentType;
2294
2295 QualType QT = CurrentType;
2298 HLSLResourcesTypeAttrs, QT, &LocInfo)) {
2299 const HLSLAttributedResourceType *RT =
2301
2302 // Temporarily store TypeLoc information for the new type.
2303 // It will be transferred to HLSLAttributedResourceTypeLoc
2304 // shortly after the type is created by TypeSpecLocFiller which
2305 // will call the TakeLocForHLSLAttribute method below.
2306 LocsForHLSLAttributedResources.insert(std::pair(RT, LocInfo));
2307 }
2308 HLSLResourcesTypeAttrs.clear();
2309 return QT;
2310}
2311
2312// Returns source location for the HLSLAttributedResourceType
2314SemaHLSL::TakeLocForHLSLAttribute(const HLSLAttributedResourceType *RT) {
2315 HLSLAttributedResourceLocInfo LocInfo = {};
2316 auto I = LocsForHLSLAttributedResources.find(RT);
2317 if (I != LocsForHLSLAttributedResources.end()) {
2318 LocInfo = I->second;
2319 LocsForHLSLAttributedResources.erase(I);
2320 return LocInfo;
2321 }
2322 LocInfo.Range = SourceRange();
2323 return LocInfo;
2324}
2325
2326// Walks through the fields of the record type RT (the type, or a nested part
2327// of the type, of the global variable VD), recursing into nested records, and
// records in Bindings one DeclBindingInfo per resource class found.
// NOTE(review): only fields are visited; base classes of RT are not scanned
// here (AccumulateHLSLResourceSlots does visit bases) -- confirm intended.
2328void SemaHLSL::collectResourceBindingsOnUserRecordDecl(const VarDecl *VD,
2329 const RecordType *RT) {
2330 const RecordDecl *RD = RT->getDecl()->getDefinitionOrSelf();
2331 for (FieldDecl *FD : RD->fields()) {
2332 const Type *Ty = FD->getType()->getUnqualifiedDesugaredType();
2333
2334 // Unwrap arrays
2335 // FIXME: Calculate array size while unwrapping
2336 assert(!Ty->isIncompleteArrayType() &&
2337 "incomplete arrays inside user defined types are not supported");
2338 while (Ty->isConstantArrayType()) {
2341 }
2342
// Only record types can contain resources.
2343 if (!Ty->isRecordType())
2344 continue;
2345
2346 if (const HLSLAttributedResourceType *AttrResType =
2347 HLSLAttributedResourceType::findHandleTypeOnResource(Ty)) {
2348 // Add a new DeclBindingInfo to Bindings if it does not already exist
2349 ResourceClass RC = AttrResType->getAttrs().ResourceClass;
2350 DeclBindingInfo *DBI = Bindings.getDeclBindingInfo(VD, RC);
2351 if (!DBI)
2352 Bindings.addDeclBindingInfo(VD, RC);
2353 } else if (const RecordType *RT = dyn_cast<RecordType>(Ty)) {
2354 // Recursively scan embedded struct or class; it would be nice to do this
2355 // without recursion, but tricky to correctly calculate the size of the
2356 // binding, which is something we are probably going to need to do later
2357 // on. Hopefully nesting of structs in structs too many levels is
2358 // unlikely.
2359 collectResourceBindingsOnUserRecordDecl(VD, RT);
2360 }
2361 }
2362}
2363
2364// Diagnose localized register binding errors for a single binding; does not
2365// diagnose resource binding on user record types, that will be done later
2366// in processResourceBindingOnDecl based on the information collected in
2367// collectResourceBindingsOnVarDecl.
2368// Returns false if the register binding is not valid.
2370 Decl *D, RegisterType RegType,
2371 bool SpecifiedSpace) {
// Integer form of the register type, streamed into the mismatch diagnostic.
2372 int RegTypeNum = static_cast<int>(RegType);
2373
2374 // check if the decl type is groupshared
2375 if (D->hasAttr<HLSLGroupSharedAddressSpaceAttr>()) {
2376 S.Diag(ArgLoc, diag::err_hlsl_binding_type_mismatch) << RegTypeNum;
2377 return false;
2378 }
2379
2380 // Cbuffers and Tbuffers are HLSLBufferDecl types
// (a cbuffer expects the CBuffer register type, a tbuffer the SRV one).
2381 if (HLSLBufferDecl *CBufferOrTBuffer = dyn_cast<HLSLBufferDecl>(D)) {
2382 ResourceClass RC = CBufferOrTBuffer->isCBuffer() ? ResourceClass::CBuffer
2383 : ResourceClass::SRV;
2384 if (RegType == getRegisterType(RC))
2385 return true;
2386
2387 S.Diag(D->getLocation(), diag::err_hlsl_binding_type_mismatch)
2388 << RegTypeNum;
2389 return false;
2390 }
2391
2392 // Samplers, UAVs, and SRVs are VarDecl types
2393 assert(isa<VarDecl>(D) && "D is expected to be VarDecl or HLSLBufferDecl");
2394 VarDecl *VD = cast<VarDecl>(D);
2395
2396 // Resource
2397 if (const HLSLAttributedResourceType *AttrResType =
2398 HLSLAttributedResourceType::findHandleTypeOnResource(
2399 VD->getType().getTypePtr())) {
2400 if (RegType == getRegisterType(AttrResType))
2401 return true;
2402
2403 S.Diag(D->getLocation(), diag::err_hlsl_binding_type_mismatch)
2404 << RegTypeNum;
2405 return false;
2406 }
2407
2408 const clang::Type *Ty = VD->getType().getTypePtr();
2409 while (Ty->isArrayType())
2411
2412 // Basic types
2413 if (Ty->isArithmeticType() || Ty->isVectorType()) {
2414 bool DeclaredInCOrTBuffer = isa<HLSLBufferDecl>(D->getDeclContext());
2415 if (SpecifiedSpace && !DeclaredInCOrTBuffer)
2416 S.Diag(ArgLoc, diag::err_hlsl_space_on_global_constant);
2417
// NOTE(review): the deprecated-'b' path and the 'c'-with-packoffset path
// below only warn, yet they still reach `return false`, so the caller does
// not attach the binding -- confirm that is intended.
2418 if (!DeclaredInCOrTBuffer && (Ty->isIntegralType(S.getASTContext()) ||
2419 Ty->isFloatingType() || Ty->isVectorType())) {
2420 // Register annotation on default constant buffer declaration ($Globals)
2421 if (RegType == RegisterType::CBuffer)
2422 S.Diag(ArgLoc, diag::warn_hlsl_deprecated_register_type_b);
2423 else if (RegType != RegisterType::C)
2424 S.Diag(ArgLoc, diag::err_hlsl_binding_type_mismatch) << RegTypeNum;
2425 else
2426 return true;
2427 } else {
2428 if (RegType == RegisterType::C)
2429 S.Diag(ArgLoc, diag::warn_hlsl_register_type_c_packoffset);
2430 else
2431 S.Diag(ArgLoc, diag::err_hlsl_binding_type_mismatch) << RegTypeNum;
2432 }
2433 return false;
2434 }
2435 if (Ty->isRecordType())
2436 // RecordTypes will be diagnosed in processResourceBindingOnDecl
2437 // that is called from ActOnVariableDeclarator
2438 return true;
2439
2440 // Anything else is an error
2441 S.Diag(ArgLoc, diag::err_hlsl_binding_type_mismatch) << RegTypeNum;
2442 return false;
2443}
2444
2446 RegisterType regType) {
2447 // make sure that there are no two register annotations
2448 // applied to the decl with the same register type
2449 bool RegisterTypesDetected[5] = {false};
2450 RegisterTypesDetected[static_cast<int>(regType)] = true;
2451
2452 for (auto it = TheDecl->attr_begin(); it != TheDecl->attr_end(); ++it) {
2453 if (HLSLResourceBindingAttr *attr =
2454 dyn_cast<HLSLResourceBindingAttr>(*it)) {
2455
2456 RegisterType otherRegType = attr->getRegisterType();
2457 if (RegisterTypesDetected[static_cast<int>(otherRegType)]) {
2458 int otherRegTypeNum = static_cast<int>(otherRegType);
2459 S.Diag(TheDecl->getLocation(),
2460 diag::err_hlsl_duplicate_register_annotation)
2461 << otherRegTypeNum;
2462 return false;
2463 }
2464 RegisterTypesDetected[static_cast<int>(otherRegType)] = true;
2465 }
2466 }
2467 return true;
2468}
2469
2471 Decl *D, RegisterType RegType,
2472 bool SpecifiedSpace) {
2473
2474 // exactly one of these two types should be set
2475 assert(((isa<VarDecl>(D) && !isa<HLSLBufferDecl>(D)) ||
2476 (!isa<VarDecl>(D) && isa<HLSLBufferDecl>(D))) &&
2477 "expecting VarDecl or HLSLBufferDecl");
2478
2479 // check if the declaration contains resource matching the register type
2480 if (!DiagnoseLocalRegisterBinding(S, ArgLoc, D, RegType, SpecifiedSpace))
2481 return false;
2482
2483 // next, if multiple register annotations exist, check that none conflict.
2484 return ValidateMultipleRegisterAnnotations(S, D, RegType);
2485}
2486
2487// return false if the slot count exceeds the limit, true otherwise
2488static bool AccumulateHLSLResourceSlots(QualType Ty, uint64_t &StartSlot,
2489 const uint64_t &Limit,
2490 const ResourceClass ResClass,
2491 ASTContext &Ctx,
2492 uint64_t ArrayCount = 1) {
2493 Ty = Ty.getCanonicalType();
2494 const Type *T = Ty.getTypePtr();
2495
2496 // Early exit if already overflowed
2497 if (StartSlot > Limit)
2498 return false;
2499
2500 // Case 1: array type
2501 if (const auto *AT = dyn_cast<ArrayType>(T)) {
2502 uint64_t Count = 1;
2503
2504 if (const auto *CAT = dyn_cast<ConstantArrayType>(AT))
2505 Count = CAT->getSize().getZExtValue();
2506
2507 QualType ElemTy = AT->getElementType();
2508 return AccumulateHLSLResourceSlots(ElemTy, StartSlot, Limit, ResClass, Ctx,
2509 ArrayCount * Count);
2510 }
2511
2512 // Case 2: resource leaf
2513 if (auto ResTy = dyn_cast<HLSLAttributedResourceType>(T)) {
2514 // First ensure this resource counts towards the corresponding
2515 // register type limit.
2516 if (ResTy->getAttrs().ResourceClass != ResClass)
2517 return true;
2518
2519 // Validate highest slot used
2520 uint64_t EndSlot = StartSlot + ArrayCount - 1;
2521 if (EndSlot > Limit)
2522 return false;
2523
2524 // Advance SlotCount past the consumed range
2525 StartSlot = EndSlot + 1;
2526 return true;
2527 }
2528
2529 // Case 3: struct / record
2530 if (const auto *RT = dyn_cast<RecordType>(T)) {
2531 const RecordDecl *RD = RT->getDecl();
2532
2533 if (const auto *CXXRD = dyn_cast<CXXRecordDecl>(RD)) {
2534 for (const CXXBaseSpecifier &Base : CXXRD->bases()) {
2535 if (!AccumulateHLSLResourceSlots(Base.getType(), StartSlot, Limit,
2536 ResClass, Ctx, ArrayCount))
2537 return false;
2538 }
2539 }
2540
2541 for (const FieldDecl *Field : RD->fields()) {
2542 if (!AccumulateHLSLResourceSlots(Field->getType(), StartSlot, Limit,
2543 ResClass, Ctx, ArrayCount))
2544 return false;
2545 }
2546
2547 return true;
2548 }
2549
2550 // Case 4: everything else
2551 return true;
2552}
2553
2554// return true if there is something invalid, false otherwise
2555static bool ValidateRegisterNumber(uint64_t SlotNum, Decl *TheDecl,
2556 ASTContext &Ctx, RegisterType RegTy) {
2557 const uint64_t Limit = UINT32_MAX;
2558 if (SlotNum > Limit)
2559 return true;
2560
2561 // after verifying the number doesn't exceed uint32max, we don't need
2562 // to look further into c or i register types
2563 if (RegTy == RegisterType::C || RegTy == RegisterType::I)
2564 return false;
2565
2566 if (VarDecl *VD = dyn_cast<VarDecl>(TheDecl)) {
2567 uint64_t BaseSlot = SlotNum;
2568
2569 if (!AccumulateHLSLResourceSlots(VD->getType(), SlotNum, Limit,
2570 getResourceClass(RegTy), Ctx))
2571 return true;
2572
2573 // After AccumulateHLSLResourceSlots runs, SlotNum is now
2574 // the first free slot; last used was SlotNum - 1
2575 return (BaseSlot > Limit);
2576 }
2577 // handle the cbuffer/tbuffer case
2578 if (isa<HLSLBufferDecl>(TheDecl))
2579 // resources cannot be put within a cbuffer, so no need
2580 // to analyze the structure since the register number
2581 // won't be pushed any higher.
2582 return (SlotNum > Limit);
2583
2584 // we don't expect any other decl type, so fail
2585 llvm_unreachable("unexpected decl type");
2586}
2587
2589 if (VarDecl *VD = dyn_cast<VarDecl>(TheDecl)) {
2590 QualType Ty = VD->getType();
2591 if (const auto *IAT = dyn_cast<IncompleteArrayType>(Ty))
2592 Ty = IAT->getElementType();
2593 if (SemaRef.RequireCompleteType(TheDecl->getBeginLoc(), Ty,
2594 diag::err_incomplete_type))
2595 return;
2596 }
2597
2598 StringRef Slot = "";
2599 StringRef Space = "";
2600 SourceLocation SlotLoc, SpaceLoc;
2601
2602 if (!AL.isArgIdent(0)) {
2603 Diag(AL.getLoc(), diag::err_attribute_argument_type)
2604 << AL << AANT_ArgumentIdentifier;
2605 return;
2606 }
2607 IdentifierLoc *Loc = AL.getArgAsIdent(0);
2608
2609 if (AL.getNumArgs() == 2) {
2610 Slot = Loc->getIdentifierInfo()->getName();
2611 SlotLoc = Loc->getLoc();
2612 if (!AL.isArgIdent(1)) {
2613 Diag(AL.getLoc(), diag::err_attribute_argument_type)
2614 << AL << AANT_ArgumentIdentifier;
2615 return;
2616 }
2617 Loc = AL.getArgAsIdent(1);
2618 Space = Loc->getIdentifierInfo()->getName();
2619 SpaceLoc = Loc->getLoc();
2620 } else {
2621 StringRef Str = Loc->getIdentifierInfo()->getName();
2622 if (Str.starts_with("space")) {
2623 Space = Str;
2624 SpaceLoc = Loc->getLoc();
2625 } else {
2626 Slot = Str;
2627 SlotLoc = Loc->getLoc();
2628 Space = "space0";
2629 }
2630 }
2631
2632 RegisterType RegType = RegisterType::SRV;
2633 std::optional<unsigned> SlotNum;
2634 unsigned SpaceNum = 0;
2635
2636 // Validate slot
2637 if (!Slot.empty()) {
2638 if (!convertToRegisterType(Slot, &RegType)) {
2639 Diag(SlotLoc, diag::err_hlsl_binding_type_invalid) << Slot.substr(0, 1);
2640 return;
2641 }
2642 if (RegType == RegisterType::I) {
2643 Diag(SlotLoc, diag::warn_hlsl_deprecated_register_type_i);
2644 return;
2645 }
2646 const StringRef SlotNumStr = Slot.substr(1);
2647
2648 uint64_t N;
2649
2650 // validate that the slot number is a non-empty number
2651 if (SlotNumStr.getAsInteger(10, N)) {
2652 Diag(SlotLoc, diag::err_hlsl_unsupported_register_number);
2653 return;
2654 }
2655
2656 // Validate register number. It should not exceed UINT32_MAX,
2657 // including if the resource type is an array that starts
2658 // before UINT32_MAX, but ends afterwards.
2659 if (ValidateRegisterNumber(N, TheDecl, getASTContext(), RegType)) {
2660 Diag(SlotLoc, diag::err_hlsl_register_number_too_large);
2661 return;
2662 }
2663
2664 // the slot number has been validated and does not exceed UINT32_MAX
2665 SlotNum = (unsigned)N;
2666 }
2667
2668 // Validate space
2669 if (!Space.starts_with("space")) {
2670 Diag(SpaceLoc, diag::err_hlsl_expected_space) << Space;
2671 return;
2672 }
2673 StringRef SpaceNumStr = Space.substr(5);
2674 if (SpaceNumStr.getAsInteger(10, SpaceNum)) {
2675 Diag(SpaceLoc, diag::err_hlsl_expected_space) << Space;
2676 return;
2677 }
2678
2679 // If we have slot, diagnose it is the right register type for the decl
2680 if (SlotNum.has_value())
2681 if (!DiagnoseHLSLRegisterAttribute(SemaRef, SlotLoc, TheDecl, RegType,
2682 !SpaceLoc.isInvalid()))
2683 return;
2684
2685 HLSLResourceBindingAttr *NewAttr =
2686 HLSLResourceBindingAttr::Create(getASTContext(), Slot, Space, AL);
2687 if (NewAttr) {
2688 NewAttr->setBinding(RegType, SlotNum, SpaceNum);
2689 TheDecl->addAttr(NewAttr);
2690 }
2691}
2692
2694 HLSLParamModifierAttr *NewAttr = mergeParamModifierAttr(
2695 D, AL,
2696 static_cast<HLSLParamModifierAttr::Spelling>(AL.getSemanticSpelling()));
2697 if (NewAttr)
2698 D->addAttr(NewAttr);
2699}
2700
2701namespace {
2702
2703/// This class implements HLSL availability diagnostics for default
2704/// and relaxed mode
2705///
2706/// The goal of this diagnostic is to emit an error or warning when an
2707/// unavailable API is found in code that is reachable from the shader
2708/// entry function or from an exported function (when compiling a shader
2709/// library).
2710///
2711/// This is done by traversing the AST of all shader entry point functions
2712/// and of all exported functions, and any functions that are referenced
2713/// from this AST. In other words, any functions that are reachable from
2714/// the entry points.
2715class DiagnoseHLSLAvailability : public DynamicRecursiveASTVisitor {
2716 Sema &SemaRef;
2717
2718 // Stack of functions to be scanned
2720
2721 // Tracks which environments functions have been scanned in.
2722 //
2723 // Maps FunctionDecl to an unsigned number that represents the set of shader
2724 // environments the function has been scanned for.
2725 // The llvm::Triple::EnvironmentType enum values for shader stages guaranteed
2726 // to be numbered from llvm::Triple::Pixel to llvm::Triple::Amplification
2727 // (verified by static_asserts in Triple.cpp), we can use it to index
2728 // individual bits in the set, as long as we shift the values to start with 0
2729 // by subtracting the value of llvm::Triple::Pixel first.
2730 //
2731 // The N'th bit in the set will be set if the function has been scanned
2732 // in shader environment whose llvm::Triple::EnvironmentType integer value
2733 // equals (llvm::Triple::Pixel + N).
2734 //
2735 // For example, if a function has been scanned in compute and pixel stage
2736 // environment, the value will be 0x21 (100001 binary) because:
2737 //
2738 // (int)(llvm::Triple::Pixel - llvm::Triple::Pixel) == 0
2739 // (int)(llvm::Triple::Compute - llvm::Triple::Pixel) == 5
2740 //
2741 // A FunctionDecl is mapped to 0 (or not included in the map) if it has not
2742 // been scanned in any environment.
2743 llvm::DenseMap<const FunctionDecl *, unsigned> ScannedDecls;
2744
2745 // Do not access these directly, use the get/set methods below to make
2746 // sure the values are in sync
2747 llvm::Triple::EnvironmentType CurrentShaderEnvironment;
2748 unsigned CurrentShaderStageBit;
2749
2750 // True if scanning a function that was already scanned in a different
2751 // shader stage context, and therefore we should not report issues that
2752 // depend only on shader model version because they would be duplicate.
2753 bool ReportOnlyShaderStageIssues;
2754
2755 // Helper methods for dealing with current stage context / environment
// Enter the context of a concrete shader stage; its bit index is the stage's
// offset from llvm::Triple::Pixel.
2756 void SetShaderStageContext(llvm::Triple::EnvironmentType ShaderType) {
2757 static_assert(sizeof(unsigned) >= 4);
2758 assert(HLSLShaderAttr::isValidShaderType(ShaderType));
2759 assert((unsigned)(ShaderType - llvm::Triple::Pixel) < 31 &&
2760 "ShaderType is too big for this bitmap"); // 31 is reserved for
2761 // "unknown"
2762
2763 unsigned bitmapIndex = ShaderType - llvm::Triple::Pixel;
2764 CurrentShaderEnvironment = ShaderType;
2765 CurrentShaderStageBit = (1 << bitmapIndex);
2766 }
2767
// Enter the "unknown stage" context (used for exported library functions);
// it uses the reserved bit 31.
2768 void SetUnknownShaderStageContext() {
2769 CurrentShaderEnvironment = llvm::Triple::UnknownEnvironment;
// NOTE(review): `1 << 31` shifts into the sign bit of a signed int before the
// conversion to unsigned; `1u << 31` would express the intent directly.
2770 CurrentShaderStageBit = (1 << 31);
2771 }
2772
2773 llvm::Triple::EnvironmentType GetCurrentShaderEnvironment() const {
2774 return CurrentShaderEnvironment;
2775 }
2776
2777 bool InUnknownShaderStageContext() const {
2778 return CurrentShaderEnvironment == llvm::Triple::UnknownEnvironment;
2779 }
2780
2781 // Helper methods for dealing with shader stage bitmap
2782 void AddToScannedFunctions(const FunctionDecl *FD) {
2783 unsigned &ScannedStages = ScannedDecls[FD];
2784 ScannedStages |= CurrentShaderStageBit;
2785 }
2786
// DenseMap::operator[] inserts a 0 ("never scanned") entry for a function
// that has not been seen before.
2787 unsigned GetScannedStages(const FunctionDecl *FD) { return ScannedDecls[FD]; }
2788
2789 bool WasAlreadyScannedInCurrentStage(const FunctionDecl *FD) {
2790 return WasAlreadyScannedInCurrentStage(GetScannedStages(FD));
2791 }
2792
// Tests a stage bitmap against the bit of the current stage context.
2793 bool WasAlreadyScannedInCurrentStage(unsigned ScannerStages) {
2794 return ScannerStages & CurrentShaderStageBit;
2795 }
2796
2797 static bool NeverBeenScanned(unsigned ScannedStages) {
2798 return ScannedStages == 0;
2799 }
2800
2801 // Scanning methods
2802 void HandleFunctionOrMethodRef(FunctionDecl *FD, Expr *RefExpr);
2803 void CheckDeclAvailability(NamedDecl *D, const AvailabilityAttr *AA,
2804 SourceRange Range);
2805 const AvailabilityAttr *FindAvailabilityAttr(const Decl *D);
2806 bool HasMatchingEnvironmentOrNone(const AvailabilityAttr *AA);
2807
2808public:
2809 DiagnoseHLSLAvailability(Sema &SemaRef)
2810 : SemaRef(SemaRef),
2811 CurrentShaderEnvironment(llvm::Triple::UnknownEnvironment),
2812 CurrentShaderStageBit(0), ReportOnlyShaderStageIssues(false) {}
2813
2814 // AST traversal methods
2815 void RunOnTranslationUnit(const TranslationUnitDecl *TU);
2816 void RunOnFunction(const FunctionDecl *FD);
2817
// Every reference to a function, by name or as a member, is passed to
// HandleFunctionOrMethodRef. Returning true continues the traversal.
2818 bool VisitDeclRefExpr(DeclRefExpr *DRE) override {
2819 FunctionDecl *FD = llvm::dyn_cast<FunctionDecl>(DRE->getDecl());
2820 if (FD)
2821 HandleFunctionOrMethodRef(FD, DRE);
2822 return true;
2823 }
2824
2825 bool VisitMemberExpr(MemberExpr *ME) override {
2826 FunctionDecl *FD = llvm::dyn_cast<FunctionDecl>(ME->getMemberDecl());
2827 if (FD)
2828 HandleFunctionOrMethodRef(FD, ME);
2829 return true;
2830 }
2831};
2832
2833void DiagnoseHLSLAvailability::HandleFunctionOrMethodRef(FunctionDecl *FD,
2834 Expr *RefExpr) {
2835 assert((isa<DeclRefExpr>(RefExpr) || isa<MemberExpr>(RefExpr)) &&
2836 "expected DeclRefExpr or MemberExpr");
2837
2838 // has a definition -> add to stack to be scanned
2839 const FunctionDecl *FDWithBody = nullptr;
2840 if (FD->hasBody(FDWithBody)) {
2841 if (!WasAlreadyScannedInCurrentStage(FDWithBody))
2842 DeclsToScan.push_back(FDWithBody);
2843 return;
2844 }
2845
2846 // no body -> diagnose availability
2847 const AvailabilityAttr *AA = FindAvailabilityAttr(FD);
2848 if (AA)
2849 CheckDeclAvailability(
2850 FD, AA, SourceRange(RefExpr->getBeginLoc(), RefExpr->getEndLoc()));
2851}
2852
// Finds the roots of the availability scan and runs RunOnFunction on each
// root's definition. Shader entry points are scanned in their own stage
// context; exported functions are scanned in the unknown-stage context.
2853void DiagnoseHLSLAvailability::RunOnTranslationUnit(
2854 const TranslationUnitDecl *TU) {
2855
2856 // Iterate over all shader entry functions and library exports, and for those
2857 // that have a body (definition), run diag scan on each, setting appropriate
2858 // shader environment context based on whether it is a shader entry function
2859 // or an exported function. Exported functions can be in namespaces and in
2860 // export declarations so we need to scan those declaration contexts as well.
2862 DeclContextsToScan.push_back(TU);
2863
2864 while (!DeclContextsToScan.empty()) {
2865 const DeclContext *DC = DeclContextsToScan.pop_back_val();
2866 for (auto &D : DC->decls()) {
2867 // do not scan implicit declaration generated by the implementation
2868 if (D->isImplicit())
2869 continue;
2870
2871 // for namespace or export declaration add the context to the list to be
2872 // scanned later
// (Only these two kinds of context are descended into; definitions nested in
// other contexts, such as classes, are not treated as roots here.)
2873 if (llvm::dyn_cast<NamespaceDecl>(D) || llvm::dyn_cast<ExportDecl>(D)) {
2874 DeclContextsToScan.push_back(llvm::dyn_cast<DeclContext>(D));
2875 continue;
2876 }
2877
2878 // skip over other decls or function decls without body
2879 const FunctionDecl *FD = llvm::dyn_cast<FunctionDecl>(D);
2880 if (!FD || !FD->isThisDeclarationADefinition())
2881 continue;
2882
2883 // shader entry point
2884 if (HLSLShaderAttr *ShaderAttr = FD->getAttr<HLSLShaderAttr>()) {
2885 SetShaderStageContext(ShaderAttr->getType());
2886 RunOnFunction(FD);
2887 continue;
2888 }
2889 // exported library function
2890 // FIXME: replace this loop with external linkage check once issue #92071
2891 // is resolved
// (A function counts as exported if any of its redeclarations is inside an
// export declaration context.)
2892 bool isExport = FD->isInExportDeclContext();
2893 if (!isExport) {
2894 for (const auto *Redecl : FD->redecls()) {
2895 if (Redecl->isInExportDeclContext()) {
2896 isExport = true;
2897 break;
2898 }
2899 }
2900 }
2901 if (isExport) {
2902 SetUnknownShaderStageContext();
2903 RunOnFunction(FD);
2904 continue;
2905 }
2906 }
2907 }
2908}
2909
2910void DiagnoseHLSLAvailability::RunOnFunction(const FunctionDecl *FD) {
2911 assert(DeclsToScan.empty() && "DeclsToScan should be empty");
2912 DeclsToScan.push_back(FD);
2913
2914 while (!DeclsToScan.empty()) {
2915 // Take one decl from the stack and check it by traversing its AST.
2916 // For any CallExpr found during the traversal add it's callee to the top of
2917 // the stack to be processed next. Functions already processed are stored in
2918 // ScannedDecls.
2919 const FunctionDecl *FD = DeclsToScan.pop_back_val();
2920
2921 // Decl was already scanned
2922 const unsigned ScannedStages = GetScannedStages(FD);
2923 if (WasAlreadyScannedInCurrentStage(ScannedStages))
2924 continue;
2925
2926 ReportOnlyShaderStageIssues = !NeverBeenScanned(ScannedStages);
2927
2928 AddToScannedFunctions(FD);
2929 TraverseStmt(FD->getBody());
2930 }
2931}
2932
2933bool DiagnoseHLSLAvailability::HasMatchingEnvironmentOrNone(
2934 const AvailabilityAttr *AA) {
2935 const IdentifierInfo *IIEnvironment = AA->getEnvironment();
2936 if (!IIEnvironment)
2937 return true;
2938
2939 llvm::Triple::EnvironmentType CurrentEnv = GetCurrentShaderEnvironment();
2940 if (CurrentEnv == llvm::Triple::UnknownEnvironment)
2941 return false;
2942
2943 llvm::Triple::EnvironmentType AttrEnv =
2944 AvailabilityAttr::getEnvironmentType(IIEnvironment->getName());
2945
2946 return CurrentEnv == AttrEnv;
2947}
2948
2949const AvailabilityAttr *
2950DiagnoseHLSLAvailability::FindAvailabilityAttr(const Decl *D) {
2951 AvailabilityAttr const *PartialMatch = nullptr;
2952 // Check each AvailabilityAttr to find the one for this platform.
2953 // For multiple attributes with the same platform try to find one for this
2954 // environment.
2955 for (const auto *A : D->attrs()) {
2956 if (const auto *Avail = dyn_cast<AvailabilityAttr>(A)) {
2957 const AvailabilityAttr *EffectiveAvail = Avail->getEffectiveAttr();
2958 StringRef AttrPlatform = EffectiveAvail->getPlatform()->getName();
2959 StringRef TargetPlatform =
2961
2962 // Match the platform name.
2963 if (AttrPlatform == TargetPlatform) {
2964 // Find the best matching attribute for this environment
2965 if (HasMatchingEnvironmentOrNone(EffectiveAvail))
2966 return Avail;
2967 PartialMatch = Avail;
2968 }
2969 }
2970 }
2971 return PartialMatch;
2972}
2973
2974// Check availability against target shader model version and current shader
2975// stage and emit diagnostic
2976void DiagnoseHLSLAvailability::CheckDeclAvailability(NamedDecl *D,
2977 const AvailabilityAttr *AA,
2978 SourceRange Range) {
2979
2980 const IdentifierInfo *IIEnv = AA->getEnvironment();
2981
2982 if (!IIEnv) {
2983 // The availability attribute does not have environment -> it depends only
2984 // on shader model version and not on specific the shader stage.
2985
2986 // Skip emitting the diagnostics if the diagnostic mode is set to
2987 // strict (-fhlsl-strict-availability) because all relevant diagnostics
2988 // were already emitted in the DiagnoseUnguardedAvailability scan
2989 // (SemaAvailability.cpp).
2990 if (SemaRef.getLangOpts().HLSLStrictAvailability)
2991 return;
2992
2993 // Do not report shader-stage-independent issues if scanning a function
2994 // that was already scanned in a different shader stage context (they would
2995 // be duplicate)
2996 if (ReportOnlyShaderStageIssues)
2997 return;
2998
2999 } else {
3000 // The availability attribute has environment -> we need to know
3001 // the current stage context to property diagnose it.
3002 if (InUnknownShaderStageContext())
3003 return;
3004 }
3005
3006 // Check introduced version and if environment matches
3007 bool EnvironmentMatches = HasMatchingEnvironmentOrNone(AA);
3008 VersionTuple Introduced = AA->getIntroduced();
3009 VersionTuple TargetVersion =
3011
3012 if (TargetVersion >= Introduced && EnvironmentMatches)
3013 return;
3014
3015 // Emit diagnostic message
3016 const TargetInfo &TI = SemaRef.getASTContext().getTargetInfo();
3017 llvm::StringRef PlatformName(
3018 AvailabilityAttr::getPrettyPlatformName(TI.getPlatformName()));
3019
3020 llvm::StringRef CurrentEnvStr =
3021 llvm::Triple::getEnvironmentTypeName(GetCurrentShaderEnvironment());
3022
3023 llvm::StringRef AttrEnvStr =
3024 AA->getEnvironment() ? AA->getEnvironment()->getName() : "";
3025 bool UseEnvironment = !AttrEnvStr.empty();
3026
3027 if (EnvironmentMatches) {
3028 SemaRef.Diag(Range.getBegin(), diag::warn_hlsl_availability)
3029 << Range << D << PlatformName << Introduced.getAsString()
3030 << UseEnvironment << CurrentEnvStr;
3031 } else {
3032 SemaRef.Diag(Range.getBegin(), diag::warn_hlsl_availability_unavailable)
3033 << Range << D;
3034 }
3035
3036 SemaRef.Diag(D->getLocation(), diag::note_partial_availability_specified_here)
3037 << D << PlatformName << Introduced.getAsString()
3038 << SemaRef.Context.getTargetInfo().getPlatformMinVersion().getAsString()
3039 << UseEnvironment << AttrEnvStr << CurrentEnvStr;
3040}
3041
3042} // namespace
3043
3045 // process default CBuffer - create buffer layout struct and invoke codegenCGH
3046 if (!DefaultCBufferDecls.empty()) {
3048 SemaRef.getASTContext(), SemaRef.getCurLexicalContext(),
3049 DefaultCBufferDecls);
3050 addImplicitBindingAttrToDecl(SemaRef, DefaultCBuffer, RegisterType::CBuffer,
3052 SemaRef.getCurLexicalContext()->addDecl(DefaultCBuffer);
3054
3055 // Set HasValidPackoffset if any of the decls has a register(c#) annotation;
3056 for (const Decl *VD : DefaultCBufferDecls) {
3057 const HLSLResourceBindingAttr *RBA =
3058 VD->getAttr<HLSLResourceBindingAttr>();
3059 if (RBA && RBA->hasRegisterSlot() &&
3060 RBA->getRegisterType() == HLSLResourceBindingAttr::RegisterType::C) {
3061 DefaultCBuffer->setHasValidPackoffset(true);
3062 break;
3063 }
3064 }
3065
3066 DeclGroupRef DG(DefaultCBuffer);
3067 SemaRef.Consumer.HandleTopLevelDecl(DG);
3068 }
3069 diagnoseAvailabilityViolations(TU);
3070}
3071
3072// For resource member access through a global struct array, verify that the
3073// array index selecting the struct element is a constant integer expression.
3074// Returns false if the member expression is invalid.
3076 assert((ME->getType()->isHLSLResourceRecord() ||
3078 "expected member expr to have resource record type or array of them");
3079
3080 // Walk the AST from MemberExpr to the VarDecl of the parent struct instance
3081 // and take note of any non-constant array indexing along the way. If the
3082 // VarDecl we find is a global variable, report error if there was any
3083 // non-constant array index in the resource member access along the way.
3084 const Expr *NonConstIndexExpr = nullptr;
3085 const Expr *E = ME->getBase();
3086 while (E) {
3087 if (const DeclRefExpr *DRE = dyn_cast<DeclRefExpr>(E)) {
3088 if (!NonConstIndexExpr)
3089 return true;
3090
3091 const VarDecl *VD = cast<VarDecl>(DRE->getDecl());
3092 if (!VD->hasGlobalStorage())
3093 return true;
3094
3095 SemaRef.Diag(NonConstIndexExpr->getExprLoc(),
3096 diag::err_hlsl_resource_member_array_access_not_constant);
3097 return false;
3098 }
3099
3100 if (const auto *ASE = dyn_cast<ArraySubscriptExpr>(E)) {
3101 const Expr *IdxExpr = ASE->getIdx();
3102 if (!IdxExpr->isIntegerConstantExpr(SemaRef.getASTContext()))
3103 NonConstIndexExpr = IdxExpr;
3104 E = ASE->getBase();
3105 } else if (const auto *SubME = dyn_cast<MemberExpr>(E)) {
3106 E = SubME->getBase();
3107 } else if (const auto *ICE = dyn_cast<ImplicitCastExpr>(E)) {
3108 E = ICE->getSubExpr();
3109 } else {
3110 llvm_unreachable("unexpected expr type in resource member access");
3111 }
3112 }
3113 return true;
3114}
3115
3116void SemaHLSL::diagnoseAvailabilityViolations(TranslationUnitDecl *TU) {
3117 // Skip running the diagnostics scan if the diagnostic mode is
3118 // strict (-fhlsl-strict-availability) and the target shader stage is known
3119 // because all relevant diagnostics were already emitted in the
3120 // DiagnoseUnguardedAvailability scan (SemaAvailability.cpp).
3122 if (SemaRef.getLangOpts().HLSLStrictAvailability &&
3123 TI.getTriple().getEnvironment() != llvm::Triple::EnvironmentType::Library)
3124 return;
3125
3126 DiagnoseHLSLAvailability(SemaRef).RunOnTranslationUnit(TU);
3127}
3128
3129static bool CheckAllArgsHaveSameType(Sema *S, CallExpr *TheCall) {
3130 assert(TheCall->getNumArgs() > 1);
3131 QualType ArgTy0 = TheCall->getArg(0)->getType();
3132
3133 for (unsigned I = 1, N = TheCall->getNumArgs(); I < N; ++I) {
3135 ArgTy0, TheCall->getArg(I)->getType())) {
3136 S->Diag(TheCall->getBeginLoc(), diag::err_vec_builtin_incompatible_vector)
3137 << TheCall->getDirectCallee() << /*useAllTerminology*/ true
3138 << SourceRange(TheCall->getArg(0)->getBeginLoc(),
3139 TheCall->getArg(N - 1)->getEndLoc());
3140 return true;
3141 }
3142 }
3143 return false;
3144}
3145
3147 QualType ArgType = Arg->getType();
3149 S->Diag(Arg->getBeginLoc(), diag::err_typecheck_convert_incompatible)
3150 << ArgType << ExpectedType << 1 << 0 << 0;
3151 return true;
3152 }
3153 return false;
3154}
3155
3157 Sema *S, CallExpr *TheCall,
3158 llvm::function_ref<bool(Sema *S, SourceLocation Loc, int ArgOrdinal,
3159 clang::QualType PassedType)>
3160 Check) {
3161 for (unsigned I = 0; I < TheCall->getNumArgs(); ++I) {
3162 Expr *Arg = TheCall->getArg(I);
3163 if (Check(S, Arg->getBeginLoc(), I + 1, Arg->getType()))
3164 return true;
3165 }
3166 return false;
3167}
3168
3170 int ArgOrdinal,
3171 clang::QualType PassedType) {
3172 clang::QualType BaseType =
3173 PassedType->isVectorType()
3174 ? PassedType->castAs<clang::VectorType>()->getElementType()
3175 : PassedType;
3176 if (!BaseType->isFloat32Type())
3177 return S->Diag(Loc, diag::err_builtin_invalid_arg_type)
3178 << ArgOrdinal << /* scalar or vector of */ 5 << /* no int */ 0
3179 << /* float */ 1 << PassedType;
3180 return false;
3181}
3182
3184 int ArgOrdinal,
3185 clang::QualType PassedType) {
3186 clang::QualType BaseType =
3187 PassedType->isVectorType()
3188 ? PassedType->castAs<clang::VectorType>()->getElementType()
3189 : PassedType;
3190 if (!BaseType->isHalfType() && !BaseType->isFloat32Type())
3191 return S->Diag(Loc, diag::err_builtin_invalid_arg_type)
3192 << ArgOrdinal << /* scalar or vector of */ 5 << /* no int */ 0
3193 << /* half or float */ 2 << PassedType;
3194 return false;
3195}
3196
3198 int ArgOrdinal,
3199 clang::QualType PassedType) {
3200 clang::QualType BaseType =
3201 PassedType->isVectorType()
3202 ? PassedType->castAs<clang::VectorType>()->getElementType()
3203 : PassedType->isMatrixType()
3204 ? PassedType->castAs<clang::MatrixType>()->getElementType()
3205 : PassedType;
3206 if (!BaseType->isDoubleType()) {
3207 // FIXME: adopt standard `err_builtin_invalid_arg_type` instead of using
3208 // this custom error.
3209 return S->Diag(Loc, diag::err_builtin_requires_double_type)
3210 << ArgOrdinal << PassedType;
3211 }
3212
3213 return false;
3214}
3215
3216static bool CheckModifiableLValue(Sema *S, CallExpr *TheCall,
3217 unsigned ArgIndex) {
3218 auto *Arg = TheCall->getArg(ArgIndex);
3219 SourceLocation OrigLoc = Arg->getExprLoc();
3220 if (Arg->IgnoreCasts()->isModifiableLvalue(S->Context, &OrigLoc) ==
3222 return false;
3223 S->Diag(OrigLoc, diag::error_hlsl_inout_lvalue) << Arg << 0;
3224 return true;
3225}
3226
3227static bool CheckNoDoubleVectors(Sema *S, SourceLocation Loc, int ArgOrdinal,
3228 clang::QualType PassedType) {
3229 const auto *VecTy = PassedType->getAs<VectorType>();
3230 if (!VecTy)
3231 return false;
3232
3233 if (VecTy->getElementType()->isDoubleType())
3234 return S->Diag(Loc, diag::err_builtin_invalid_arg_type)
3235 << ArgOrdinal << /* scalar */ 1 << /* no int */ 0 << /* fp */ 1
3236 << PassedType;
3237 return false;
3238}
3239
3241 int ArgOrdinal,
3242 clang::QualType PassedType) {
3243 if (!PassedType->hasIntegerRepresentation() &&
3244 !PassedType->hasFloatingRepresentation())
3245 return S->Diag(Loc, diag::err_builtin_invalid_arg_type)
3246 << ArgOrdinal << /* scalar or vector of */ 5 << /* integer */ 1
3247 << /* fp */ 1 << PassedType;
3248 return false;
3249}
3250
3252 int ArgOrdinal,
3253 clang::QualType PassedType) {
3254 if (auto *VecTy = PassedType->getAs<VectorType>())
3255 if (VecTy->getElementType()->isUnsignedIntegerType())
3256 return false;
3257
3258 return S->Diag(Loc, diag::err_builtin_invalid_arg_type)
3259 << ArgOrdinal << /* vector of */ 4 << /* uint */ 3 << /* no fp */ 0
3260 << PassedType;
3261}
3262
3263// checks for unsigned ints of all sizes
3265 int ArgOrdinal,
3266 clang::QualType PassedType) {
3267 if (!PassedType->hasUnsignedIntegerRepresentation())
3268 return S->Diag(Loc, diag::err_builtin_invalid_arg_type)
3269 << ArgOrdinal << /* scalar or vector of */ 5 << /* unsigned int */ 3
3270 << /* no fp */ 0 << PassedType;
3271 return false;
3272}
3273
3274static bool CheckExpectedBitWidth(Sema *S, CallExpr *TheCall,
3275 unsigned ArgOrdinal, unsigned Width) {
3276 QualType ArgTy = TheCall->getArg(0)->getType();
3277 if (auto *VTy = ArgTy->getAs<VectorType>())
3278 ArgTy = VTy->getElementType();
3279 // ensure arg type has expected bit width
3280 uint64_t ElementBitCount =
3282 if (ElementBitCount != Width) {
3283 S->Diag(TheCall->getArg(0)->getBeginLoc(),
3284 diag::err_integer_incorrect_bit_count)
3285 << Width << ElementBitCount;
3286 return true;
3287 }
3288 return false;
3289}
3290
3292 QualType ReturnType) {
3293 auto *VecTyA = TheCall->getArg(0)->getType()->getAs<VectorType>();
3294 if (VecTyA)
3295 ReturnType =
3296 S->Context.getExtVectorType(ReturnType, VecTyA->getNumElements());
3297
3298 TheCall->setType(ReturnType);
3299}
3300
3301static bool CheckScalarOrVector(Sema *S, CallExpr *TheCall, QualType Scalar,
3302 unsigned ArgIndex) {
3303 assert(TheCall->getNumArgs() >= ArgIndex);
3304 QualType ArgType = TheCall->getArg(ArgIndex)->getType();
3305 auto *VTy = ArgType->getAs<VectorType>();
3306 // not the scalar or vector<scalar>
3307 if (!(S->Context.hasSameUnqualifiedType(ArgType, Scalar) ||
3308 (VTy &&
3309 S->Context.hasSameUnqualifiedType(VTy->getElementType(), Scalar)))) {
3310 S->Diag(TheCall->getArg(0)->getBeginLoc(),
3311 diag::err_typecheck_expect_scalar_or_vector)
3312 << ArgType << Scalar;
3313 return true;
3314 }
3315 return false;
3316}
3317
3319 QualType Scalar, unsigned ArgIndex) {
3320 assert(TheCall->getNumArgs() > ArgIndex);
3321
3322 Expr *Arg = TheCall->getArg(ArgIndex);
3323 QualType ArgType = Arg->getType();
3324
3325 // Scalar: T
3326 if (S->Context.hasSameUnqualifiedType(ArgType, Scalar))
3327 return false;
3328
3329 // Vector: vector<T>
3330 if (const auto *VTy = ArgType->getAs<VectorType>()) {
3331 if (S->Context.hasSameUnqualifiedType(VTy->getElementType(), Scalar))
3332 return false;
3333 }
3334
3335 // Matrix: ConstantMatrixType with element type T
3336 if (const auto *MTy = ArgType->getAs<ConstantMatrixType>()) {
3337 if (S->Context.hasSameUnqualifiedType(MTy->getElementType(), Scalar))
3338 return false;
3339 }
3340
3341 // Not a scalar/vector/matrix-of-scalar
3342 S->Diag(Arg->getBeginLoc(),
3343 diag::err_typecheck_expect_scalar_or_vector_or_matrix)
3344 << ArgType << Scalar;
3345 return true;
3346}
3347
3348static bool CheckAnyScalarOrVector(Sema *S, CallExpr *TheCall,
3349 unsigned ArgIndex) {
3350 assert(TheCall->getNumArgs() >= ArgIndex);
3351 QualType ArgType = TheCall->getArg(ArgIndex)->getType();
3352 auto *VTy = ArgType->getAs<VectorType>();
3353 // not the scalar or vector<scalar>
3354 if (!(ArgType->isScalarType() ||
3355 (VTy && VTy->getElementType()->isScalarType()))) {
3356 S->Diag(TheCall->getArg(0)->getBeginLoc(),
3357 diag::err_typecheck_expect_any_scalar_or_vector)
3358 << ArgType << 1;
3359 return true;
3360 }
3361 return false;
3362}
3363
3364// Check that the argument is not a bool or vector<bool>
3365// Returns true on error
3367 unsigned ArgIndex) {
3368 QualType BoolType = S->getASTContext().BoolTy;
3369 assert(ArgIndex < TheCall->getNumArgs());
3370 QualType ArgType = TheCall->getArg(ArgIndex)->getType();
3371 auto *VTy = ArgType->getAs<VectorType>();
3372 // is the bool or vector<bool>
3373 if (S->Context.hasSameUnqualifiedType(ArgType, BoolType) ||
3374 (VTy &&
3375 S->Context.hasSameUnqualifiedType(VTy->getElementType(), BoolType))) {
3376 S->Diag(TheCall->getArg(0)->getBeginLoc(),
3377 diag::err_typecheck_expect_any_scalar_or_vector)
3378 << ArgType << 0;
3379 return true;
3380 }
3381 return false;
3382}
3383
3384static bool CheckWaveActive(Sema *S, CallExpr *TheCall) {
3385 if (CheckNotBoolScalarOrVector(S, TheCall, 0))
3386 return true;
3387 return false;
3388}
3389
3390static bool CheckWavePrefix(Sema *S, CallExpr *TheCall) {
3391 if (CheckNotBoolScalarOrVector(S, TheCall, 0))
3392 return true;
3393 return false;
3394}
3395
3396static bool CheckBoolSelect(Sema *S, CallExpr *TheCall) {
3397 assert(TheCall->getNumArgs() == 3);
3398 Expr *Arg1 = TheCall->getArg(1);
3399 Expr *Arg2 = TheCall->getArg(2);
3400 if (!S->Context.hasSameUnqualifiedType(Arg1->getType(), Arg2->getType())) {
3401 S->Diag(TheCall->getBeginLoc(),
3402 diag::err_typecheck_call_different_arg_types)
3403 << Arg1->getType() << Arg2->getType() << Arg1->getSourceRange()
3404 << Arg2->getSourceRange();
3405 return true;
3406 }
3407
3408 TheCall->setType(Arg1->getType());
3409 return false;
3410}
3411
3412static bool CheckVectorSelect(Sema *S, CallExpr *TheCall) {
3413 assert(TheCall->getNumArgs() == 3);
3414 Expr *Arg1 = TheCall->getArg(1);
3415 QualType Arg1Ty = Arg1->getType();
3416 Expr *Arg2 = TheCall->getArg(2);
3417 QualType Arg2Ty = Arg2->getType();
3418
3419 QualType Arg1ScalarTy = Arg1Ty;
3420 if (auto VTy = Arg1ScalarTy->getAs<VectorType>())
3421 Arg1ScalarTy = VTy->getElementType();
3422
3423 QualType Arg2ScalarTy = Arg2Ty;
3424 if (auto VTy = Arg2ScalarTy->getAs<VectorType>())
3425 Arg2ScalarTy = VTy->getElementType();
3426
3427 if (!S->Context.hasSameUnqualifiedType(Arg1ScalarTy, Arg2ScalarTy))
3428 S->Diag(Arg1->getBeginLoc(), diag::err_hlsl_builtin_scalar_vector_mismatch)
3429 << /* second and third */ 1 << TheCall->getCallee() << Arg1Ty << Arg2Ty;
3430
3431 QualType Arg0Ty = TheCall->getArg(0)->getType();
3432 unsigned Arg0Length = Arg0Ty->getAs<VectorType>()->getNumElements();
3433 unsigned Arg1Length = Arg1Ty->isVectorType()
3434 ? Arg1Ty->getAs<VectorType>()->getNumElements()
3435 : 0;
3436 unsigned Arg2Length = Arg2Ty->isVectorType()
3437 ? Arg2Ty->getAs<VectorType>()->getNumElements()
3438 : 0;
3439 if (Arg1Length > 0 && Arg0Length != Arg1Length) {
3440 S->Diag(TheCall->getBeginLoc(),
3441 diag::err_typecheck_vector_lengths_not_equal)
3442 << Arg0Ty << Arg1Ty << TheCall->getArg(0)->getSourceRange()
3443 << Arg1->getSourceRange();
3444 return true;
3445 }
3446
3447 if (Arg2Length > 0 && Arg0Length != Arg2Length) {
3448 S->Diag(TheCall->getBeginLoc(),
3449 diag::err_typecheck_vector_lengths_not_equal)
3450 << Arg0Ty << Arg2Ty << TheCall->getArg(0)->getSourceRange()
3451 << Arg2->getSourceRange();
3452 return true;
3453 }
3454
3455 TheCall->setType(
3456 S->getASTContext().getExtVectorType(Arg1ScalarTy, Arg0Length));
3457 return false;
3458}
3459
3460static bool CheckIndexType(Sema *S, CallExpr *TheCall, unsigned IndexArgIndex) {
3461 assert(TheCall->getNumArgs() > IndexArgIndex && "Index argument missing");
3462 QualType ArgType = TheCall->getArg(IndexArgIndex)->getType();
3463 QualType IndexTy = ArgType;
3464 unsigned int ActualDim = 1;
3465 if (const auto *VTy = IndexTy->getAs<VectorType>()) {
3466 ActualDim = VTy->getNumElements();
3467 IndexTy = VTy->getElementType();
3468 }
3469 if (!IndexTy->isIntegerType()) {
3470 S->Diag(TheCall->getArg(IndexArgIndex)->getBeginLoc(),
3471 diag::err_typecheck_expect_int)
3472 << ArgType;
3473 return true;
3474 }
3475
3476 QualType ResourceArgTy = TheCall->getArg(0)->getType();
3477 const HLSLAttributedResourceType *ResTy =
3478 ResourceArgTy.getTypePtr()->getAs<HLSLAttributedResourceType>();
3479 assert(ResTy && "Resource argument must be a resource");
3480 HLSLAttributedResourceType::Attributes ResAttrs = ResTy->getAttrs();
3481
3482 unsigned int ExpectedDim = 1;
3483 if (ResAttrs.ResourceDimension != llvm::dxil::ResourceDimension::Unknown)
3484 ExpectedDim = getResourceDimensions(ResAttrs.ResourceDimension);
3485
3486 if (ActualDim != ExpectedDim) {
3487 S->Diag(TheCall->getArg(IndexArgIndex)->getBeginLoc(),
3488 diag::err_hlsl_builtin_resource_coordinate_dimension_mismatch)
3489 << cast<NamedDecl>(TheCall->getCalleeDecl()) << ExpectedDim
3490 << ActualDim;
3491 return true;
3492 }
3493
3494 return false;
3495}
3496
3498 Sema *S, CallExpr *TheCall, unsigned ArgIndex,
3499 llvm::function_ref<bool(const HLSLAttributedResourceType *ResType)> Check =
3500 nullptr) {
3501 assert(TheCall->getNumArgs() >= ArgIndex);
3502 QualType ArgType = TheCall->getArg(ArgIndex)->getType();
3503 const HLSLAttributedResourceType *ResTy =
3504 ArgType.getTypePtr()->getAs<HLSLAttributedResourceType>();
3505 if (!ResTy) {
3506 S->Diag(TheCall->getArg(ArgIndex)->getBeginLoc(),
3507 diag::err_typecheck_expect_hlsl_resource)
3508 << ArgType;
3509 return true;
3510 }
3511 if (Check && Check(ResTy)) {
3512 S->Diag(TheCall->getArg(ArgIndex)->getExprLoc(),
3513 diag::err_invalid_hlsl_resource_type)
3514 << ArgType;
3515 return true;
3516 }
3517 return false;
3518}
3519
3520static bool CheckVectorElementCount(Sema *S, QualType PassedType,
3521 QualType BaseType, unsigned ExpectedCount,
3522 SourceLocation Loc) {
3523 unsigned PassedCount = 1;
3524 if (const auto *VecTy = PassedType->getAs<VectorType>())
3525 PassedCount = VecTy->getNumElements();
3526
3527 if (PassedCount != ExpectedCount) {
3529 S->Context.getExtVectorType(BaseType, ExpectedCount);
3530 S->Diag(Loc, diag::err_typecheck_convert_incompatible)
3531 << PassedType << ExpectedType << 1 << 0 << 0;
3532 return true;
3533 }
3534 return false;
3535}
3536
3537enum class SampleKind { Sample, Bias, Grad, Level, Cmp, CmpLevelZero };
3538
3540 // Check the texture handle.
3541 if (CheckResourceHandle(&S, TheCall, 0,
3542 [](const HLSLAttributedResourceType *ResType) {
3543 return ResType->getAttrs().ResourceDimension ==
3544 llvm::dxil::ResourceDimension::Unknown;
3545 }))
3546 return true;
3547
3548 // Check the sampler handle.
3549 if (CheckResourceHandle(&S, TheCall, 1,
3550 [](const HLSLAttributedResourceType *ResType) {
3551 return ResType->getAttrs().ResourceClass !=
3552 llvm::hlsl::ResourceClass::Sampler;
3553 }))
3554 return true;
3555
3556 auto *ResourceTy =
3557 TheCall->getArg(0)->getType()->castAs<HLSLAttributedResourceType>();
3558
3559 // Check the location.
3560 unsigned ExpectedDim =
3561 getResourceDimensions(ResourceTy->getAttrs().ResourceDimension);
3562 if (CheckVectorElementCount(&S, TheCall->getArg(2)->getType(),
3563 S.Context.FloatTy, ExpectedDim,
3564 TheCall->getBeginLoc()))
3565 return true;
3566
3567 return false;
3568}
3569
3570static bool CheckCalculateLodBuiltin(Sema &S, CallExpr *TheCall) {
3571 if (S.checkArgCount(TheCall, 3))
3572 return true;
3573
3574 if (CheckTextureSamplerAndLocation(S, TheCall))
3575 return true;
3576
3577 TheCall->setType(S.Context.FloatTy);
3578 return false;
3579}
3580
3581static bool CheckGatherBuiltin(Sema &S, CallExpr *TheCall, bool IsCmp) {
3582 if (S.checkArgCountRange(TheCall, IsCmp ? 5 : 4, IsCmp ? 6 : 5))
3583 return true;
3584
3585 if (CheckTextureSamplerAndLocation(S, TheCall))
3586 return true;
3587
3588 unsigned NextIdx = 3;
3589 if (IsCmp) {
3590 // Check the compare value.
3591 QualType CmpTy = TheCall->getArg(NextIdx)->getType();
3592 if (!CmpTy->isFloatingType() || CmpTy->isVectorType()) {
3593 S.Diag(TheCall->getArg(NextIdx)->getBeginLoc(),
3594 diag::err_typecheck_convert_incompatible)
3595 << CmpTy << S.Context.FloatTy << 1 << 0 << 0;
3596 return true;
3597 }
3598 NextIdx++;
3599 }
3600
3601 // Check the component operand.
3602 Expr *ComponentArg = TheCall->getArg(NextIdx);
3603 QualType ComponentTy = ComponentArg->getType();
3604 if (!ComponentTy->isIntegerType() || ComponentTy->isVectorType()) {
3605 S.Diag(ComponentArg->getBeginLoc(),
3606 diag::err_typecheck_convert_incompatible)
3607 << ComponentTy << S.Context.UnsignedIntTy << 1 << 0 << 0;
3608 return true;
3609 }
3610
3611 // GatherCmp operations on Vulkan target must use component 0 (Red).
3612 if (IsCmp && S.getASTContext().getTargetInfo().getTriple().isSPIRV()) {
3613 std::optional<llvm::APSInt> ComponentOpt =
3614 ComponentArg->getIntegerConstantExpr(S.getASTContext());
3615 if (ComponentOpt) {
3616 int64_t ComponentVal = ComponentOpt->getSExtValue();
3617 if (ComponentVal != 0) {
3618 // Issue an error if the component is not 0 (Red).
3619 // 0 -> Red, 1 -> Green, 2 -> Blue, 3 -> Alpha
3620 assert(ComponentVal >= 0 && ComponentVal <= 3 &&
3621 "The component is not in the expected range.");
3622 S.Diag(ComponentArg->getBeginLoc(),
3623 diag::err_hlsl_gathercmp_invalid_component)
3624 << ComponentVal;
3625 return true;
3626 }
3627 }
3628 }
3629
3630 NextIdx++;
3631
3632 // Check the offset operand.
3633 const HLSLAttributedResourceType *ResourceTy =
3634 TheCall->getArg(0)->getType()->castAs<HLSLAttributedResourceType>();
3635 if (TheCall->getNumArgs() > NextIdx) {
3636 unsigned ExpectedDim =
3637 getResourceDimensions(ResourceTy->getAttrs().ResourceDimension);
3638 if (CheckVectorElementCount(&S, TheCall->getArg(NextIdx)->getType(),
3639 S.Context.IntTy, ExpectedDim,
3640 TheCall->getArg(NextIdx)->getBeginLoc()))
3641 return true;
3642 NextIdx++;
3643 }
3644
3645 assert(ResourceTy->hasContainedType() &&
3646 "Expecting a contained type for resource with a dimension "
3647 "attribute.");
3648 QualType ReturnType = ResourceTy->getContainedType();
3649
3650 if (IsCmp) {
3651 if (!ReturnType->hasFloatingRepresentation()) {
3652 S.Diag(TheCall->getBeginLoc(), diag::err_hlsl_samplecmp_requires_float);
3653 return true;
3654 }
3655 }
3656
3657 if (const auto *VecTy = ReturnType->getAs<VectorType>())
3658 ReturnType = VecTy->getElementType();
3659 ReturnType = S.Context.getExtVectorType(ReturnType, 4);
3660
3661 TheCall->setType(ReturnType);
3662
3663 return false;
3664}
3665static bool CheckLoadLevelBuiltin(Sema &S, CallExpr *TheCall) {
3666 if (S.checkArgCountRange(TheCall, 2, 3))
3667 return true;
3668
3669 // Check the texture handle.
3670 if (CheckResourceHandle(&S, TheCall, 0,
3671 [](const HLSLAttributedResourceType *ResType) {
3672 return ResType->getAttrs().ResourceDimension ==
3673 llvm::dxil::ResourceDimension::Unknown;
3674 }))
3675 return true;
3676
3677 auto *ResourceTy =
3678 TheCall->getArg(0)->getType()->castAs<HLSLAttributedResourceType>();
3679
3680 // Check the location + lod (int3 for Texture2D).
3681 unsigned ExpectedDim =
3682 getResourceDimensions(ResourceTy->getAttrs().ResourceDimension);
3683 QualType CoordLODTy = TheCall->getArg(1)->getType();
3684 if (CheckVectorElementCount(&S, CoordLODTy, S.Context.IntTy, ExpectedDim + 1,
3685 TheCall->getArg(1)->getBeginLoc()))
3686 return true;
3687
3688 QualType EltTy = CoordLODTy;
3689 if (const auto *VTy = EltTy->getAs<VectorType>())
3690 EltTy = VTy->getElementType();
3691 if (!EltTy->isIntegerType()) {
3692 S.Diag(TheCall->getArg(1)->getBeginLoc(), diag::err_typecheck_expect_int)
3693 << CoordLODTy;
3694 return true;
3695 }
3696
3697 // Check the offset operand.
3698 if (TheCall->getNumArgs() > 2) {
3699 if (CheckVectorElementCount(&S, TheCall->getArg(2)->getType(),
3700 S.Context.IntTy, ExpectedDim,
3701 TheCall->getArg(2)->getBeginLoc()))
3702 return true;
3703 }
3704
3705 TheCall->setType(ResourceTy->getContainedType());
3706 return false;
3707}
3708
3709static bool CheckSamplingBuiltin(Sema &S, CallExpr *TheCall, SampleKind Kind) {
3710 unsigned MinArgs, MaxArgs;
3711 if (Kind == SampleKind::Sample) {
3712 MinArgs = 3;
3713 MaxArgs = 5;
3714 } else if (Kind == SampleKind::Bias) {
3715 MinArgs = 4;
3716 MaxArgs = 6;
3717 } else if (Kind == SampleKind::Grad) {
3718 MinArgs = 5;
3719 MaxArgs = 7;
3720 } else if (Kind == SampleKind::Level) {
3721 MinArgs = 4;
3722 MaxArgs = 5;
3723 } else if (Kind == SampleKind::Cmp) {
3724 MinArgs = 4;
3725 MaxArgs = 6;
3726 } else {
3727 assert(Kind == SampleKind::CmpLevelZero);
3728 MinArgs = 4;
3729 MaxArgs = 5;
3730 }
3731
3732 if (S.checkArgCountRange(TheCall, MinArgs, MaxArgs))
3733 return true;
3734
3735 if (CheckTextureSamplerAndLocation(S, TheCall))
3736 return true;
3737
3738 const HLSLAttributedResourceType *ResourceTy =
3739 TheCall->getArg(0)->getType()->castAs<HLSLAttributedResourceType>();
3740 unsigned ExpectedDim =
3741 getResourceDimensions(ResourceTy->getAttrs().ResourceDimension);
3742
3743 unsigned NextIdx = 3;
3744 if (Kind == SampleKind::Bias || Kind == SampleKind::Level ||
3745 Kind == SampleKind::Cmp || Kind == SampleKind::CmpLevelZero) {
3746 // Check the bias, lod level, or compare value, depending on the kind.
3747 // All of them must be a scalar float value.
3748 QualType BiasOrLODOrCmpTy = TheCall->getArg(NextIdx)->getType();
3749 if (!BiasOrLODOrCmpTy->isFloatingType() ||
3750 BiasOrLODOrCmpTy->isVectorType()) {
3751 S.Diag(TheCall->getArg(NextIdx)->getBeginLoc(),
3752 diag::err_typecheck_convert_incompatible)
3753 << BiasOrLODOrCmpTy << S.Context.FloatTy << 1 << 0 << 0;
3754 return true;
3755 }
3756 NextIdx++;
3757 } else if (Kind == SampleKind::Grad) {
3758 // Check the DDX operand.
3759 if (CheckVectorElementCount(&S, TheCall->getArg(NextIdx)->getType(),
3760 S.Context.FloatTy, ExpectedDim,
3761 TheCall->getArg(NextIdx)->getBeginLoc()))
3762 return true;
3763
3764 // Check the DDY operand.
3765 if (CheckVectorElementCount(&S, TheCall->getArg(NextIdx + 1)->getType(),
3766 S.Context.FloatTy, ExpectedDim,
3767 TheCall->getArg(NextIdx + 1)->getBeginLoc()))
3768 return true;
3769 NextIdx += 2;
3770 }
3771
3772 // Check the offset operand.
3773 if (TheCall->getNumArgs() > NextIdx) {
3774 if (CheckVectorElementCount(&S, TheCall->getArg(NextIdx)->getType(),
3775 S.Context.IntTy, ExpectedDim,
3776 TheCall->getArg(NextIdx)->getBeginLoc()))
3777 return true;
3778 NextIdx++;
3779 }
3780
3781 // Check the clamp operand.
3782 if (Kind != SampleKind::Level && Kind != SampleKind::CmpLevelZero &&
3783 TheCall->getNumArgs() > NextIdx) {
3784 QualType ClampTy = TheCall->getArg(NextIdx)->getType();
3785 if (!ClampTy->isFloatingType() || ClampTy->isVectorType()) {
3786 S.Diag(TheCall->getArg(NextIdx)->getBeginLoc(),
3787 diag::err_typecheck_convert_incompatible)
3788 << ClampTy << S.Context.FloatTy << 1 << 0 << 0;
3789 return true;
3790 }
3791 }
3792
3793 assert(ResourceTy->hasContainedType() &&
3794 "Expecting a contained type for resource with a dimension "
3795 "attribute.");
3796 QualType ReturnType = ResourceTy->getContainedType();
3797 if (Kind == SampleKind::Cmp || Kind == SampleKind::CmpLevelZero) {
3798 if (!ReturnType->hasFloatingRepresentation()) {
3799 S.Diag(TheCall->getBeginLoc(), diag::err_hlsl_samplecmp_requires_float);
3800 return true;
3801 }
3802 ReturnType = S.Context.FloatTy;
3803 }
3804 TheCall->setType(ReturnType);
3805
3806 return false;
3807}
3808
3809// Note: returning true in this case results in CheckBuiltinFunctionCall
3810// returning an ExprError
// Semantic checks for HLSL builtin calls. For each builtin handled in the
// switch below this validates the argument count and argument types and, for
// many builtins, sets the call's result type via TheCall->setType(). Builtins
// without a case here are accepted unchanged (the function returns false).
3811bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
3812 switch (BuiltinID) {
3813 case Builtin::BI__builtin_hlsl_adduint64: {
3814 if (SemaRef.checkArgCount(TheCall, 2))
3815 return true;
3816
3817 if (CheckAllArgTypesAreCorrect(&SemaRef, TheCall,
3819 return true;
3820
3821 // ensure arg integers are 32-bits
3822 if (CheckExpectedBitWidth(&SemaRef, TheCall, 0, 32))
3823 return true;
3824
3825 // ensure both args are vectors of total bit size of a multiple of 64
// Only 2- or 4-element vectors (64 or 128 bits of 32-bit elements) pass the
// check below. NOTE(review): VTy is dereferenced without a null check; this
// relies on the type predicate passed to CheckAllArgTypesAreCorrect above
// having rejected non-vector arguments -- confirm.
3826 auto *VTy = TheCall->getArg(0)->getType()->getAs<VectorType>();
3827 int NumElementsArg = VTy->getNumElements();
3828 if (NumElementsArg != 2 && NumElementsArg != 4) {
3829 SemaRef.Diag(TheCall->getBeginLoc(), diag::err_vector_incorrect_bit_count)
3830 << 1 /*a multiple of*/ << 64 << NumElementsArg * 32;
3831 return true;
3832 }
3833
3834 // ensure first arg and second arg have the same type
3835 if (CheckAllArgsHaveSameType(&SemaRef, TheCall))
3836 return true;
3837
3838 ExprResult A = TheCall->getArg(0);
3839 QualType ArgTyA = A.get()->getType();
3840 // return type is the same as the input type
3841 TheCall->setType(ArgTyA);
3842 break;
3843 }
// Result: an lvalue pointer, in the hlsl_device address space, to the
// resource's contained type.
3844 case Builtin::BI__builtin_hlsl_resource_getpointer: {
3845 if (SemaRef.checkArgCount(TheCall, 2) ||
3846 CheckResourceHandle(&SemaRef, TheCall, 0) ||
3847 CheckIndexType(&SemaRef, TheCall, 1))
3848 return true;
3849
3850 auto *ResourceTy =
3851 TheCall->getArg(0)->getType()->castAs<HLSLAttributedResourceType>();
3852 QualType ContainedTy = ResourceTy->getContainedType();
3853 auto ReturnType =
3854 SemaRef.Context.getAddrSpaceQualType(ContainedTy, LangAS::hlsl_device);
3855 ReturnType = SemaRef.Context.getPointerType(ReturnType);
3856 TheCall->setType(ReturnType);
3857 TheCall->setValueKind(VK_LValue);
3858
3859 break;
3860 }
// Result: a pointer, in the hlsl_device address space, to the pointee type of
// arg 2. This function inspects only the type of arg 2, not its value.
3861 case Builtin::BI__builtin_hlsl_resource_getpointer_typed: {
3862 if (SemaRef.checkArgCount(TheCall, 3) ||
3863 CheckResourceHandle(&SemaRef, TheCall, 0) ||
3864 CheckIndexType(&SemaRef, TheCall, 1))
3865 return true;
3866
// NOTE(review): the assert checks argument index 2 (the third argument) while
// its message says "second argument". Also, cast<FunctionDecl>(CurContext)
// in the array check asserts if the current context is not a FunctionDecl.
3867 QualType ElementTy = TheCall->getArg(2)->getType();
3868 assert(ElementTy->isPointerType() &&
3869 "expected pointer type for second argument");
3870 ElementTy = ElementTy->getPointeeType();
3871
3872 // Reject array types
3873 if (ElementTy->isArrayType())
3874 return SemaRef.Diag(
3875 cast<FunctionDecl>(SemaRef.CurContext)->getPointOfInstantiation(),
3876 diag::err_invalid_use_of_array_type);
3877
3878 auto ReturnType =
3879 SemaRef.Context.getAddrSpaceQualType(ElementTy, LangAS::hlsl_device);
3880 ReturnType = SemaRef.Context.getPointerType(ReturnType);
3881 TheCall->setType(ReturnType);
3882
3883 break;
3884 }
// Args 1 and 2 must match 'unsigned int'. Arg 2 must also be a modifiable
// lvalue (presumably the status out-parameter). Result: the resource's
// contained type.
3885 case Builtin::BI__builtin_hlsl_resource_load_with_status: {
3886 if (SemaRef.checkArgCount(TheCall, 3) ||
3887 CheckResourceHandle(&SemaRef, TheCall, 0) ||
3888 CheckArgTypeMatches(&SemaRef, TheCall->getArg(1),
3889 SemaRef.getASTContext().UnsignedIntTy) ||
3890 CheckArgTypeMatches(&SemaRef, TheCall->getArg(2),
3891 SemaRef.getASTContext().UnsignedIntTy) ||
3892 CheckModifiableLValue(&SemaRef, TheCall, 2))
3893 return true;
3894
3895 auto *ResourceTy =
3896 TheCall->getArg(0)->getType()->castAs<HLSLAttributedResourceType>();
3897 QualType ReturnType = ResourceTy->getContainedType();
3898 TheCall->setType(ReturnType);
3899
3900 break;
3901 }
// Same checks as load_with_status; the result type is the pointee type of
// arg 3 rather than the resource's contained type.
3902 case Builtin::BI__builtin_hlsl_resource_load_with_status_typed: {
3903 if (SemaRef.checkArgCount(TheCall, 4) ||
3904 CheckResourceHandle(&SemaRef, TheCall, 0) ||
3905 CheckArgTypeMatches(&SemaRef, TheCall->getArg(1),
3906 SemaRef.getASTContext().UnsignedIntTy) ||
3907 CheckArgTypeMatches(&SemaRef, TheCall->getArg(2),
3908 SemaRef.getASTContext().UnsignedIntTy) ||
3909 CheckModifiableLValue(&SemaRef, TheCall, 2))
3910 return true;
3911
// NOTE(review): the assert checks argument index 3 (the fourth argument) while
// its message says "second argument".
3912 QualType ReturnType = TheCall->getArg(3)->getType();
3913 assert(ReturnType->isPointerType() &&
3914 "expected pointer type for second argument");
3915 ReturnType = ReturnType->getPointeeType();
3916
3917 // Reject array types
3918 if (ReturnType->isArrayType())
3919 return SemaRef.Diag(
3920 cast<FunctionDecl>(SemaRef.CurContext)->getPointOfInstantiation(),
3921 diag::err_invalid_use_of_array_type);
3922
3923 TheCall->setType(ReturnType);
3924
3925 break;
3926 }
3927 case Builtin::BI__builtin_hlsl_resource_load_level:
3928 return CheckLoadLevelBuiltin(SemaRef, TheCall);
3929 case Builtin::BI__builtin_hlsl_resource_sample:
3931 case Builtin::BI__builtin_hlsl_resource_sample_bias:
3933 case Builtin::BI__builtin_hlsl_resource_sample_grad:
3935 case Builtin::BI__builtin_hlsl_resource_sample_level:
3937 case Builtin::BI__builtin_hlsl_resource_sample_cmp:
3939 case Builtin::BI__builtin_hlsl_resource_sample_cmp_level_zero:
3941 case Builtin::BI__builtin_hlsl_resource_calculate_lod:
3942 case Builtin::BI__builtin_hlsl_resource_calculate_lod_unclamped:
3943 return CheckCalculateLodBuiltin(SemaRef, TheCall);
3944 case Builtin::BI__builtin_hlsl_resource_gather:
3945 return CheckGatherBuiltin(SemaRef, TheCall, /*IsCmp=*/false);
3946 case Builtin::BI__builtin_hlsl_resource_gather_cmp:
3947 return CheckGatherBuiltin(SemaRef, TheCall, /*IsCmp=*/true);
// The handle-creation builtins below only assert their arity and take their
// result type from arg 0; no diagnostics are emitted here.
3948 case Builtin::BI__builtin_hlsl_resource_uninitializedhandle: {
3949 assert(TheCall->getNumArgs() == 1 && "expected 1 arg");
3950 // Update return type to be the attributed resource type from arg0.
3951 QualType ResourceTy = TheCall->getArg(0)->getType();
3952 TheCall->setType(ResourceTy);
3953 break;
3954 }
3955 case Builtin::BI__builtin_hlsl_resource_handlefrombinding: {
3956 assert(TheCall->getNumArgs() == 6 && "expected 6 args");
3957 // Update return type to be the attributed resource type from arg0.
3958 QualType ResourceTy = TheCall->getArg(0)->getType();
3959 TheCall->setType(ResourceTy);
3960 break;
3961 }
3962 case Builtin::BI__builtin_hlsl_resource_handlefromimplicitbinding: {
3963 assert(TheCall->getNumArgs() == 6 && "expected 6 args");
3964 // Update return type to be the attributed resource type from arg0.
3965 QualType ResourceTy = TheCall->getArg(0)->getType();
3966 TheCall->setType(ResourceTy);
3967 break;
3968 }
3969 case Builtin::BI__builtin_hlsl_resource_counterhandlefromimplicitbinding: {
3970 assert(TheCall->getNumArgs() == 3 && "expected 3 args");
3971 ASTContext &AST = SemaRef.getASTContext();
3972 QualType MainHandleTy = TheCall->getArg(0)->getType();
// NOTE(review): getAs<> can return null; MainResType is dereferenced without a
// check, so this assumes arg 0 is always an HLSLAttributedResourceType.
3973 auto *MainResType = MainHandleTy->getAs<HLSLAttributedResourceType>();
3974 auto MainAttrs = MainResType->getAttrs();
3975 assert(!MainAttrs.IsCounter && "cannot create a counter from a counter");
3976 MainAttrs.IsCounter = true;
3977 QualType CounterHandleTy = AST.getHLSLAttributedResourceType(
3978 MainResType->getWrappedType(), MainResType->getContainedType(),
3979 MainAttrs);
3980 // Update return type to be the attributed resource type from arg0
3981 // with added IsCounter flag.
3982 TheCall->setType(CounterHandleTy);
3983 break;
3984 }
// and/or: two bool scalar/vector/matrix args of the same type; the result
// has that type.
3985 case Builtin::BI__builtin_hlsl_and:
3986 case Builtin::BI__builtin_hlsl_or: {
3987 if (SemaRef.checkArgCount(TheCall, 2))
3988 return true;
3989 if (CheckScalarOrVectorOrMatrix(&SemaRef, TheCall, getASTContext().BoolTy,
3990 0))
3991 return true;
3992 if (CheckAllArgsHaveSameType(&SemaRef, TheCall))
3993 return true;
3994
3995 ExprResult A = TheCall->getArg(0);
3996 QualType ArgTyA = A.get()->getType();
3997 // return type is the same as the input type
3998 TheCall->setType(ArgTyA);
3999 break;
4000 }
4001 case Builtin::BI__builtin_hlsl_all:
4002 case Builtin::BI__builtin_hlsl_any: {
4003 if (SemaRef.checkArgCount(TheCall, 1))
4004 return true;
4005 if (CheckAnyScalarOrVector(&SemaRef, TheCall, 0))
4006 return true;
4007 break;
4008 }
4009 case Builtin::BI__builtin_hlsl_asdouble: {
4010 if (SemaRef.checkArgCount(TheCall, 2))
4011 return true;
4013 &SemaRef, TheCall,
4014 /*only check for uint*/ SemaRef.Context.UnsignedIntTy,
4015 /* arg index */ 0))
4016 return true;
4018 &SemaRef, TheCall,
4019 /*only check for uint*/ SemaRef.Context.UnsignedIntTy,
4020 /* arg index */ 1))
4021 return true;
4022 if (CheckAllArgsHaveSameType(&SemaRef, TheCall))
4023 return true;
4024
4025 SetElementTypeAsReturnType(&SemaRef, TheCall, getASTContext().DoubleTy);
4026 break;
4027 }
4028 case Builtin::BI__builtin_hlsl_elementwise_clamp: {
4029 if (SemaRef.BuiltinElementwiseTernaryMath(
4030 TheCall, /*ArgTyRestr=*/
4032 return true;
4033 break;
4034 }
4035 case Builtin::BI__builtin_hlsl_dot: {
4036 // arg count is checked by BuiltinVectorToScalarMath
4037 if (SemaRef.BuiltinVectorToScalarMath(TheCall))
4038 return true;
4040 return true;
4041 break;
4042 }
// firstbithigh/firstbitlow: integer scalar or vector argument; the result is
// 'unsigned int', or an ext-vector of 'unsigned int' with the same element
// count as the argument.
4043 case Builtin::BI__builtin_hlsl_elementwise_firstbithigh:
4044 case Builtin::BI__builtin_hlsl_elementwise_firstbitlow: {
4045 if (SemaRef.PrepareBuiltinElementwiseMathOneArgCall(TheCall))
4046 return true;
4047
4048 const Expr *Arg = TheCall->getArg(0);
4049 QualType ArgTy = Arg->getType();
4050 QualType EltTy = ArgTy;
4051
4052 QualType ResTy = SemaRef.Context.UnsignedIntTy;
4053
4054 if (auto *VecTy = EltTy->getAs<VectorType>()) {
4055 EltTy = VecTy->getElementType();
4056 ResTy = SemaRef.Context.getExtVectorType(ResTy, VecTy->getNumElements());
4057 }
4058
4059 if (!EltTy->isIntegerType()) {
4060 Diag(Arg->getBeginLoc(), diag::err_builtin_invalid_arg_type)
4061 << 1 << /* scalar or vector of */ 5 << /* integer ty */ 1
4062 << /* no fp */ 0 << ArgTy;
4063 return true;
4064 }
4065
4066 TheCall->setType(ResTy);
4067 break;
4068 }
// select: the condition (arg 0) is a bool scalar or bool vector; the scalar
// and vector forms are checked by different helpers.
4069 case Builtin::BI__builtin_hlsl_select: {
4070 if (SemaRef.checkArgCount(TheCall, 3))
4071 return true;
4072 if (CheckScalarOrVector(&SemaRef, TheCall, getASTContext().BoolTy, 0))
4073 return true;
4074 QualType ArgTy = TheCall->getArg(0)->getType();
4075 if (ArgTy->isBooleanType() && CheckBoolSelect(&SemaRef, TheCall))
4076 return true;
4077 auto *VTy = ArgTy->getAs<VectorType>();
4078 if (VTy && VTy->getElementType()->isBooleanType() &&
4079 CheckVectorSelect(&SemaRef, TheCall))
4080 return true;
4081 break;
4082 }
4083 case Builtin::BI__builtin_hlsl_elementwise_saturate:
4084 case Builtin::BI__builtin_hlsl_elementwise_rcp: {
4085 if (SemaRef.checkArgCount(TheCall, 1))
4086 return true;
4087 if (!TheCall->getArg(0)
4088 ->getType()
4089 ->hasFloatingRepresentation()) // half or float or double
4090 return SemaRef.Diag(TheCall->getArg(0)->getBeginLoc(),
4091 diag::err_builtin_invalid_arg_type)
4092 << /* ordinal */ 1 << /* scalar or vector */ 5 << /* no int */ 0
4093 << /* fp */ 1 << TheCall->getArg(0)->getType();
4094 if (SemaRef.PrepareBuiltinElementwiseMathOneArgCall(TheCall))
4095 return true;
4096 break;
4097 }
4098 case Builtin::BI__builtin_hlsl_elementwise_degrees:
4099 case Builtin::BI__builtin_hlsl_elementwise_radians:
4100 case Builtin::BI__builtin_hlsl_elementwise_rsqrt:
4101 case Builtin::BI__builtin_hlsl_elementwise_frac:
4102 case Builtin::BI__builtin_hlsl_elementwise_ddx_coarse:
4103 case Builtin::BI__builtin_hlsl_elementwise_ddy_coarse:
4104 case Builtin::BI__builtin_hlsl_elementwise_ddx_fine:
4105 case Builtin::BI__builtin_hlsl_elementwise_ddy_fine: {
4106 if (SemaRef.checkArgCount(TheCall, 1))
4107 return true;
4108 if (CheckAllArgTypesAreCorrect(&SemaRef, TheCall,
4110 return true;
4111 if (SemaRef.PrepareBuiltinElementwiseMathOneArgCall(TheCall))
4112 return true;
4113 break;
4114 }
4115 case Builtin::BI__builtin_hlsl_elementwise_isinf:
4116 case Builtin::BI__builtin_hlsl_elementwise_isnan: {
4117 if (SemaRef.checkArgCount(TheCall, 1))
4118 return true;
4119 if (CheckAllArgTypesAreCorrect(&SemaRef, TheCall,
4121 return true;
4122 if (SemaRef.PrepareBuiltinElementwiseMathOneArgCall(TheCall))
4123 return true;
4125 break;
4126 }
4127 case Builtin::BI__builtin_hlsl_lerp: {
4128 if (SemaRef.checkArgCount(TheCall, 3))
4129 return true;
4130 if (CheckAllArgTypesAreCorrect(&SemaRef, TheCall,
4132 return true;
4133 if (CheckAllArgsHaveSameType(&SemaRef, TheCall))
4134 return true;
4135 if (SemaRef.BuiltinElementwiseTernaryMath(TheCall))
4136 return true;
4137 break;
4138 }
4139 case Builtin::BI__builtin_hlsl_mad: {
4140 if (SemaRef.BuiltinElementwiseTernaryMath(
4141 TheCall, /*ArgTyRestr=*/
4143 return true;
4144 break;
4145 }
// mul: computes the result type for vector*matrix, matrix*vector and
// matrix*matrix products.
4146 case Builtin::BI__builtin_hlsl_mul: {
4147 if (SemaRef.checkArgCount(TheCall, 2))
4148 return true;
4149
4150 Expr *Arg0 = TheCall->getArg(0);
4151 Expr *Arg1 = TheCall->getArg(1);
4152 QualType Ty0 = Arg0->getType();
4153 QualType Ty1 = Arg1->getType();
4154
4155 auto getElemType = [](QualType T) -> QualType {
4156 if (const auto *VTy = T->getAs<VectorType>())
4157 return VTy->getElementType();
4158 if (const auto *MTy = T->getAs<ConstantMatrixType>())
4159 return MTy->getElementType();
4160 return T;
4161 };
4162
4163 QualType EltTy0 = getElemType(Ty0);
4164
4165 bool IsVec0 = Ty0->isVectorType();
4166 bool IsMat0 = Ty0->isConstantMatrixType();
4167 bool IsVec1 = Ty1->isVectorType();
4168 bool IsMat1 = Ty1->isConstantMatrixType();
4169
4170 QualType RetTy;
4171
// Only the argument count is validated in this case. Any shape other than
// vec*mat or mat*vec is only asserted to be mat*mat.
// NOTE(review): presumably callers guarantee these shapes -- confirm. The
// result element type is always taken from the first argument.
4172 if (IsVec0 && IsMat1) {
4173 auto *MatTy = Ty1->castAs<ConstantMatrixType>();
4174 RetTy = getASTContext().getExtVectorType(EltTy0, MatTy->getNumColumns());
4175 } else if (IsMat0 && IsVec1) {
4176 auto *MatTy = Ty0->castAs<ConstantMatrixType>();
4177 RetTy = getASTContext().getExtVectorType(EltTy0, MatTy->getNumRows());
4178 } else {
4179 assert(IsMat0 && IsMat1);
4180 auto *MatTy0 = Ty0->castAs<ConstantMatrixType>();
4181 auto *MatTy1 = Ty1->castAs<ConstantMatrixType>();
4183 EltTy0, MatTy0->getNumRows(), MatTy1->getNumColumns());
4184 }
4185
4186 TheCall->setType(RetTy);
4187 break;
4188 }
4189 case Builtin::BI__builtin_hlsl_normalize: {
4190 if (SemaRef.checkArgCount(TheCall, 1))
4191 return true;
4192 if (CheckAllArgTypesAreCorrect(&SemaRef, TheCall,
4194 return true;
4195 ExprResult A = TheCall->getArg(0);
4196 QualType ArgTyA = A.get()->getType();
4197 // return type is the same as the input type
4198 TheCall->setType(ArgTyA);
4199 break;
4200 }
4201 case Builtin::BI__builtin_elementwise_fma: {
4202 if (SemaRef.checkArgCount(TheCall, 3) ||
4203 CheckAllArgsHaveSameType(&SemaRef, TheCall)) {
4204 return true;
4205 }
4206
4207 if (CheckAllArgTypesAreCorrect(&SemaRef, TheCall,
4209 return true;
4210
4211 ExprResult A = TheCall->getArg(0);
4212 QualType ArgTyA = A.get()->getType();
4213 // return type is the same as input type
4214 TheCall->setType(ArgTyA);
4215 break;
4216 }
// transpose: a constant-matrix argument with R rows and C columns yields a
// matrix with C rows and R columns of the same element type.
4217 case Builtin::BI__builtin_hlsl_transpose: {
4218 if (SemaRef.checkArgCount(TheCall, 1))
4219 return true;
4220
4221 Expr *Arg = TheCall->getArg(0);
4222 QualType ArgTy = Arg->getType();
4223
4224 const auto *MatTy = ArgTy->getAs<ConstantMatrixType>();
4225 if (!MatTy) {
4226 SemaRef.Diag(Arg->getBeginLoc(), diag::err_builtin_invalid_arg_type)
4227 << 1 << /* matrix */ 3 << /* no int */ 0 << /* no fp */ 0 << ArgTy;
4228 return true;
4229 }
4230
4232 MatTy->getElementType(), MatTy->getNumColumns(), MatTy->getNumRows());
4233 TheCall->setType(RetTy);
4234 break;
4235 }
4236 case Builtin::BI__builtin_hlsl_elementwise_sign: {
4237 if (SemaRef.PrepareBuiltinElementwiseMathOneArgCall(TheCall))
4238 return true;
4239 if (CheckAllArgTypesAreCorrect(&SemaRef, TheCall,
4241 return true;
4243 break;
4244 }
4245 case Builtin::BI__builtin_hlsl_step: {
4246 if (SemaRef.checkArgCount(TheCall, 2))
4247 return true;
4248 if (CheckAllArgTypesAreCorrect(&SemaRef, TheCall,
4250 return true;
4251
4252 ExprResult A = TheCall->getArg(0);
4253 QualType ArgTyA = A.get()->getType();
4254 // return type is the same as the input type
4255 TheCall->setType(ArgTyA);
4256 break;
4257 }
// wave_active_all_equal: result is bool, or a bool ext-vector with the same
// element count when the argument is an ext-vector.
4258 case Builtin::BI__builtin_hlsl_wave_active_all_equal: {
4259 if (SemaRef.checkArgCount(TheCall, 1))
4260 return true;
4261
4262 // Ensure input expr type is a scalar/vector
4263 if (CheckAnyScalarOrVector(&SemaRef, TheCall, 0))
4264 return true;
4265
4266 QualType InputTy = TheCall->getArg(0)->getType();
4267 ASTContext &Ctx = getASTContext();
4268
4269 QualType RetTy;
4270
4271 // If vector, construct bool vector of same size
4272 if (const auto *VecTy = InputTy->getAs<ExtVectorType>()) {
4273 unsigned NumElts = VecTy->getNumElements();
4274 RetTy = Ctx.getExtVectorType(Ctx.BoolTy, NumElts);
4275 } else {
4276 // Scalar case
4277 RetTy = Ctx.BoolTy;
4278 }
4279
4280 TheCall->setType(RetTy);
4281 break;
4282 }
4283 case Builtin::BI__builtin_hlsl_wave_active_max:
4284 case Builtin::BI__builtin_hlsl_wave_active_min:
4285 case Builtin::BI__builtin_hlsl_wave_active_sum:
4286 case Builtin::BI__builtin_hlsl_wave_active_product: {
4287 if (SemaRef.checkArgCount(TheCall, 1))
4288 return true;
4289
4290 // Ensure input expr type is a scalar/vector and the same as the return type
4291 if (CheckAnyScalarOrVector(&SemaRef, TheCall, 0))
4292 return true;
4293 if (CheckWaveActive(&SemaRef, TheCall))
4294 return true;
4295 ExprResult Expr = TheCall->getArg(0);
4296 QualType ArgTyExpr = Expr.get()->getType();
4297 TheCall->setType(ArgTyExpr);
4298 break;
4299 }
4300 case Builtin::BI__builtin_hlsl_wave_active_bit_or:
4301 case Builtin::BI__builtin_hlsl_wave_active_bit_xor:
4302 case Builtin::BI__builtin_hlsl_wave_active_bit_and: {
4303 if (SemaRef.checkArgCount(TheCall, 1))
4304 return true;
4305
4306 // Ensure input expr type is a scalar/vector
4307 if (CheckAnyScalarOrVector(&SemaRef, TheCall, 0))
4308 return true;
4309
4310 if (CheckWaveActive(&SemaRef, TheCall))
4311 return true;
4312
4313 // Ensure the expr type is interpretable as a uint or vector<uint>
4314 ExprResult Expr = TheCall->getArg(0);
4315 QualType ArgTyExpr = Expr.get()->getType();
4316 auto *VTy = ArgTyExpr->getAs<VectorType>();
4317 if (!(ArgTyExpr->isIntegerType() ||
4318 (VTy && VTy->getElementType()->isIntegerType()))) {
// NOTE(review): the arguments streamed below (type, type, 1, 0, 0) do not
// follow the <ordinal, kind, int, fp, type> order used with
// err_builtin_invalid_arg_type elsewhere in this function (see the
// firstbithigh/firstbitlow case). They look copied from an
// err_typecheck_convert_incompatible diagnostic; likely intended is
// `<< 1 << /* scalar or vector of */ 5 << /* integer ty */ 1 << /* no fp */ 0
// << ArgTyExpr` -- verify against DiagnosticSemaKinds.td.
4319 SemaRef.Diag(TheCall->getArg(0)->getBeginLoc(),
4320 diag::err_builtin_invalid_arg_type)
4321 << ArgTyExpr << SemaRef.Context.UnsignedIntTy << 1 << 0 << 0;
4322 return true;
4323 }
4324
4325 // Ensure input expr type is the same as the return type
4326 TheCall->setType(ArgTyExpr);
4327 break;
4328 }
4329 // Note these are llvm builtins that we want to catch invalid intrinsic
4330 // generation. Normal handling of these builtins will occur elsewhere.
4331 case Builtin::BI__builtin_elementwise_bitreverse: {
4332 // does not include a check for number of arguments
4333 // because that is done previously
4334 if (CheckAllArgTypesAreCorrect(&SemaRef, TheCall,
4336 return true;
4337 break;
4338 }
// wave_prefix_count_bits: the single argument must be a scalar bool.
4339 case Builtin::BI__builtin_hlsl_wave_prefix_count_bits: {
4340 if (SemaRef.checkArgCount(TheCall, 1))
4341 return true;
4342
4343 QualType ArgType = TheCall->getArg(0)->getType();
4344
4345 if (!(ArgType->isScalarType())) {
4346 SemaRef.Diag(TheCall->getArg(0)->getBeginLoc(),
4347 diag::err_typecheck_expect_any_scalar_or_vector)
4348 << ArgType << 0;
4349 return true;
4350 }
4351
// NOTE(review): the non-bool check below reports the same
// err_typecheck_expect_any_scalar_or_vector diagnostic as the non-scalar
// check above, which may not describe the actual problem.
4352 if (!(ArgType->isBooleanType())) {
4353 SemaRef.Diag(TheCall->getArg(0)->getBeginLoc(),
4354 diag::err_typecheck_expect_any_scalar_or_vector)
4355 << ArgType << 0;
4356 return true;
4357 }
4358
4359 break;
4360 }
4361 case Builtin::BI__builtin_hlsl_wave_read_lane_at: {
4362 if (SemaRef.checkArgCount(TheCall, 2))
4363 return true;
4364
4365 // Ensure index parameter type can be interpreted as a uint
4366 ExprResult Index = TheCall->getArg(1);
4367 QualType ArgTyIndex = Index.get()->getType();
4368 if (!ArgTyIndex->isIntegerType()) {
4369 SemaRef.Diag(TheCall->getArg(1)->getBeginLoc(),
4370 diag::err_typecheck_convert_incompatible)
4371 << ArgTyIndex << SemaRef.Context.UnsignedIntTy << 1 << 0 << 0;
4372 return true;
4373 }
4374
4375 // Ensure input expr type is a scalar/vector and the same as the return type
4376 if (CheckAnyScalarOrVector(&SemaRef, TheCall, 0))
4377 return true;
4378
4379 ExprResult Expr = TheCall->getArg(0);
4380 QualType ArgTyExpr = Expr.get()->getType();
4381 TheCall->setType(ArgTyExpr);
4382 break;
4383 }
4384 case Builtin::BI__builtin_hlsl_wave_get_lane_index: {
4385 if (SemaRef.checkArgCount(TheCall, 0))
4386 return true;
4387 break;
4388 }
4389 case Builtin::BI__builtin_hlsl_wave_prefix_sum:
4390 case Builtin::BI__builtin_hlsl_wave_prefix_product: {
4391 if (SemaRef.checkArgCount(TheCall, 1))
4392 return true;
4393
4394 // Ensure input expr type is a scalar/vector and the same as the return type
4395 if (CheckAnyScalarOrVector(&SemaRef, TheCall, 0))
4396 return true;
4397 if (CheckWavePrefix(&SemaRef, TheCall))
4398 return true;
4399 ExprResult Expr = TheCall->getArg(0);
4400 QualType ArgTyExpr = Expr.get()->getType();
4401 TheCall->setType(ArgTyExpr);
4402 break;
4403 }
4404 case Builtin::BI__builtin_hlsl_quad_read_across_x:
4405 case Builtin::BI__builtin_hlsl_quad_read_across_y: {
4406 if (SemaRef.checkArgCount(TheCall, 1))
4407 return true;
4408
4409 if (CheckAnyScalarOrVector(&SemaRef, TheCall, 0))
4410 return true;
4411 if (CheckNotBoolScalarOrVector(&SemaRef, TheCall, 0))
4412 return true;
4413 ExprResult Expr = TheCall->getArg(0);
4414 QualType ArgTyExpr = Expr.get()->getType();
4415 TheCall->setType(ArgTyExpr);
4416 break;
4417 }
// splitdouble: arg 0 is a double scalar/vector. Args 1 and 2 are uint
// scalars/vectors that must be modifiable lvalues.
4418 case Builtin::BI__builtin_hlsl_elementwise_splitdouble: {
4419 if (SemaRef.checkArgCount(TheCall, 3))
4420 return true;
4421
4422 if (CheckScalarOrVector(&SemaRef, TheCall, SemaRef.Context.DoubleTy, 0) ||
4423 CheckScalarOrVector(&SemaRef, TheCall, SemaRef.Context.UnsignedIntTy,
4424 1) ||
4425 CheckScalarOrVector(&SemaRef, TheCall, SemaRef.Context.UnsignedIntTy,
4426 2))
4427 return true;
4428
4429 if (CheckModifiableLValue(&SemaRef, TheCall, 1) ||
4430 CheckModifiableLValue(&SemaRef, TheCall, 2))
4431 return true;
4432 break;
4433 }
4434 case Builtin::BI__builtin_hlsl_elementwise_clip: {
4435 if (SemaRef.checkArgCount(TheCall, 1))
4436 return true;
4437
4438 if (CheckScalarOrVector(&SemaRef, TheCall, SemaRef.Context.FloatTy, 0))
4439 return true;
4440 break;
4441 }
// Generic elementwise math builtins: only the argument types are checked here.
4442 case Builtin::BI__builtin_elementwise_acos:
4443 case Builtin::BI__builtin_elementwise_asin:
4444 case Builtin::BI__builtin_elementwise_atan:
4445 case Builtin::BI__builtin_elementwise_atan2:
4446 case Builtin::BI__builtin_elementwise_ceil:
4447 case Builtin::BI__builtin_elementwise_cos:
4448 case Builtin::BI__builtin_elementwise_cosh:
4449 case Builtin::BI__builtin_elementwise_exp:
4450 case Builtin::BI__builtin_elementwise_exp2:
4451 case Builtin::BI__builtin_elementwise_exp10:
4452 case Builtin::BI__builtin_elementwise_floor:
4453 case Builtin::BI__builtin_elementwise_fmod:
4454 case Builtin::BI__builtin_elementwise_log:
4455 case Builtin::BI__builtin_elementwise_log2:
4456 case Builtin::BI__builtin_elementwise_log10:
4457 case Builtin::BI__builtin_elementwise_pow:
4458 case Builtin::BI__builtin_elementwise_roundeven:
4459 case Builtin::BI__builtin_elementwise_sin:
4460 case Builtin::BI__builtin_elementwise_sinh:
4461 case Builtin::BI__builtin_elementwise_sqrt:
4462 case Builtin::BI__builtin_elementwise_tan:
4463 case Builtin::BI__builtin_elementwise_tanh:
4464 case Builtin::BI__builtin_elementwise_trunc: {
4465 if (CheckAllArgTypesAreCorrect(&SemaRef, TheCall,
4467 return true;
4468 break;
4469 }
// buffer_update_counter: arg 0 must be a UAV raw-buffer handle with a
// contained type. Arg 1 must be an integer constant expression equal to
// 1 or -1.
4470 case Builtin::BI__builtin_hlsl_buffer_update_counter: {
4471 assert(TheCall->getNumArgs() == 2 && "expected 2 args");
4472 auto checkResTy = [](const HLSLAttributedResourceType *ResTy) -> bool {
4473 return !(ResTy->getAttrs().ResourceClass == ResourceClass::UAV &&
4474 ResTy->getAttrs().RawBuffer && ResTy->hasContainedType());
4475 };
4476 if (CheckResourceHandle(&SemaRef, TheCall, 0, checkResTy))
4477 return true;
4478 Expr *OffsetExpr = TheCall->getArg(1);
4479 std::optional<llvm::APSInt> Offset =
4480 OffsetExpr->getIntegerConstantExpr(SemaRef.getASTContext());
// NOTE(review): APSInt::getExtValue() asserts when the value does not fit in
// int64_t, and std::abs(INT64_MIN) is undefined behavior. An oversized or
// INT64_MIN constant offset could hit either. Comparing the APSInt against
// 1 / -1 directly would avoid both.
4481 if (!Offset.has_value() || std::abs(Offset->getExtValue()) != 1) {
4482 SemaRef.Diag(TheCall->getArg(1)->getBeginLoc(),
4483 diag::err_hlsl_expect_arg_const_int_one_or_neg_one)
4484 << 1;
4485 return true;
4486 }
4487 break;
4488 }
// f16tof32: a 32-bit, non-bool integer scalar/vector argument. The result has
// float elements (same shape as the argument, via SetElementTypeAsReturnType).
4489 case Builtin::BI__builtin_hlsl_elementwise_f16tof32: {
4490 if (SemaRef.checkArgCount(TheCall, 1))
4491 return true;
4492 if (CheckAllArgTypesAreCorrect(&SemaRef, TheCall,
4494 return true;
4495 // ensure arg integers are 32 bits
4496 if (CheckExpectedBitWidth(&SemaRef, TheCall, 0, 32))
4497 return true;
4498 // check it wasn't a bool type
4499 QualType ArgTy = TheCall->getArg(0)->getType();
4500 if (auto *VTy = ArgTy->getAs<VectorType>())
4501 ArgTy = VTy->getElementType();
4502 if (ArgTy->isBooleanType()) {
4503 SemaRef.Diag(TheCall->getArg(0)->getBeginLoc(),
4504 diag::err_builtin_invalid_arg_type)
4505 << 1 << /* scalar or vector of */ 5 << /* unsigned int */ 3
4506 << /* no fp */ 0 << TheCall->getArg(0)->getType();
4507 return true;
4508 }
4509
4510 SetElementTypeAsReturnType(&SemaRef, TheCall, getASTContext().FloatTy);
4511 break;
4512 }
4513 case Builtin::BI__builtin_hlsl_elementwise_f32tof16: {
4514 if (SemaRef.checkArgCount(TheCall, 1))
4515 return true;
4517 return true;
4519 getASTContext().UnsignedIntTy);
4520 break;
4521 }
4522 }
4523 return false;
4524}
4525
4529 WorkList.push_back(BaseTy);
4530 while (!WorkList.empty()) {
4531 QualType T = WorkList.pop_back_val();
4532 T = T.getCanonicalType().getUnqualifiedType();
4533 if (const auto *AT = dyn_cast<ConstantArrayType>(T)) {
4534 llvm::SmallVector<QualType, 16> ElementFields;
4535 // Generally I've avoided recursion in this algorithm, but arrays of
4536 // structs could be time-consuming to flatten and churn through on the
4537 // work list. Hopefully nesting arrays of structs containing arrays
4538 // of structs too many levels deep is unlikely.
4539 BuildFlattenedTypeList(AT->getElementType(), ElementFields);
4540 // Repeat the element's field list n times.
4541 for (uint64_t Ct = 0; Ct < AT->getZExtSize(); ++Ct)
4542 llvm::append_range(List, ElementFields);
4543 continue;
4544 }
4545 // Vectors can only have element types that are builtin types, so this can
4546 // add directly to the list instead of to the WorkList.
4547 if (const auto *VT = dyn_cast<VectorType>(T)) {
4548 List.insert(List.end(), VT->getNumElements(), VT->getElementType());
4549 continue;
4550 }
4551 if (const auto *MT = dyn_cast<ConstantMatrixType>(T)) {
4552 List.insert(List.end(), MT->getNumElementsFlattened(),
4553 MT->getElementType());
4554 continue;
4555 }
4556 if (const auto *RD = T->getAsCXXRecordDecl()) {
4557 if (RD->isStandardLayout())
4558 RD = RD->getStandardLayoutBaseWithFields();
4559
4560 // For types that we shouldn't decompose (unions and non-aggregates), just
4561 // add the type itself to the list.
4562 if (RD->isUnion() || !RD->isAggregate()) {
4563 List.push_back(T);
4564 continue;
4565 }
4566
4568 for (const auto *FD : RD->fields())
4569 if (!FD->isUnnamedBitField())
4570 FieldTypes.push_back(FD->getType());
4571 // Reverse the newly added sub-range.
4572 std::reverse(FieldTypes.begin(), FieldTypes.end());
4573 llvm::append_range(WorkList, FieldTypes);
4574
4575 // If this wasn't a standard layout type we may also have some base
4576 // classes to deal with.
4577 if (!RD->isStandardLayout()) {
4578 FieldTypes.clear();
4579 for (const auto &Base : RD->bases())
4580 FieldTypes.push_back(Base.getType());
4581 std::reverse(FieldTypes.begin(), FieldTypes.end());
4582 llvm::append_range(WorkList, FieldTypes);
4583 }
4584 continue;
4585 }
4586 List.push_back(T);
4587 }
4588}
4589
4591 // null and array types are not allowed.
4592 if (QT.isNull() || QT->isArrayType())
4593 return false;
4594
4595 // UDT types are not allowed
4596 if (QT->isRecordType())
4597 return false;
4598
4599 if (QT->isBooleanType() || QT->isEnumeralType())
4600 return false;
4601
4602 // the only other valid builtin types are scalars or vectors
4603 if (QT->isArithmeticType()) {
4604 if (SemaRef.Context.getTypeSize(QT) / 8 > 16)
4605 return false;
4606 return true;
4607 }
4608
4609 if (const VectorType *VT = QT->getAs<VectorType>()) {
4610 int ArraySize = VT->getNumElements();
4611
4612 if (ArraySize > 4)
4613 return false;
4614
4615 QualType ElTy = VT->getElementType();
4616 if (ElTy->isBooleanType())
4617 return false;
4618
4619 if (SemaRef.Context.getTypeSize(QT) / 8 > 16)
4620 return false;
4621 return true;
4622 }
4623
4624 return false;
4625}
4626
4628 if (T1.isNull() || T2.isNull())
4629 return false;
4630
4633
4634 // If both types are the same canonical type, they're obviously compatible.
4635 if (SemaRef.getASTContext().hasSameType(T1, T2))
4636 return true;
4637
4639 BuildFlattenedTypeList(T1, T1Types);
4641 BuildFlattenedTypeList(T2, T2Types);
4642
4643 // Check the flattened type list
4644 return llvm::equal(T1Types, T2Types,
4645 [this](QualType LHS, QualType RHS) -> bool {
4646 return SemaRef.IsLayoutCompatible(LHS, RHS);
4647 });
4648}
4649
4651 FunctionDecl *Old) {
4652 if (New->getNumParams() != Old->getNumParams())
4653 return true;
4654
4655 bool HadError = false;
4656
4657 for (unsigned i = 0, e = New->getNumParams(); i != e; ++i) {
4658 ParmVarDecl *NewParam = New->getParamDecl(i);
4659 ParmVarDecl *OldParam = Old->getParamDecl(i);
4660
4661 // HLSL parameter declarations for inout and out must match between
4662 // declarations. In HLSL inout and out are ambiguous at the call site,
4663 // but have different calling behavior, so you cannot overload a
4664 // method based on a difference between inout and out annotations.
4665 const auto *NDAttr = NewParam->getAttr<HLSLParamModifierAttr>();
4666 unsigned NSpellingIdx = (NDAttr ? NDAttr->getSpellingListIndex() : 0);
4667 const auto *ODAttr = OldParam->getAttr<HLSLParamModifierAttr>();
4668 unsigned OSpellingIdx = (ODAttr ? ODAttr->getSpellingListIndex() : 0);
4669
4670 if (NSpellingIdx != OSpellingIdx) {
4671 SemaRef.Diag(NewParam->getLocation(),
4672 diag::err_hlsl_param_qualifier_mismatch)
4673 << NDAttr << NewParam;
4674 SemaRef.Diag(OldParam->getLocation(), diag::note_previous_declaration_as)
4675 << ODAttr;
4676 HadError = true;
4677 }
4678 }
4679 return HadError;
4680}
4681
4682// Generally follows PerformScalarCast, with cases reordered for
4683// clarity of what types are supported
4685
4686 if (!SrcTy->isScalarType() || !DestTy->isScalarType())
4687 return false;
4688
4689 if (SemaRef.getASTContext().hasSameUnqualifiedType(SrcTy, DestTy))
4690 return true;
4691
4692 switch (SrcTy->getScalarTypeKind()) {
4693 case Type::STK_Bool: // casting from bool is like casting from an integer
4694 case Type::STK_Integral:
4695 switch (DestTy->getScalarTypeKind()) {
4696 case Type::STK_Bool:
4697 case Type::STK_Integral:
4698 case Type::STK_Floating:
4699 return true;
4700 case Type::STK_CPointer:
4704 llvm_unreachable("HLSL doesn't support pointers.");
4707 llvm_unreachable("HLSL doesn't support complex types.");
4709 llvm_unreachable("HLSL doesn't support fixed point types.");
4710 }
4711 llvm_unreachable("Should have returned before this");
4712
4713 case Type::STK_Floating:
4714 switch (DestTy->getScalarTypeKind()) {
4715 case Type::STK_Floating:
4716 case Type::STK_Bool:
4717 case Type::STK_Integral:
4718 return true;
4721 llvm_unreachable("HLSL doesn't support complex types.");
4723 llvm_unreachable("HLSL doesn't support fixed point types.");
4724 case Type::STK_CPointer:
4728 llvm_unreachable("HLSL doesn't support pointers.");
4729 }
4730 llvm_unreachable("Should have returned before this");
4731
4733 case Type::STK_CPointer:
4736 llvm_unreachable("HLSL doesn't support pointers.");
4737
4739 llvm_unreachable("HLSL doesn't support fixed point types.");
4740
4743 llvm_unreachable("HLSL doesn't support complex types.");
4744 }
4745
4746 llvm_unreachable("Unhandled scalar cast");
4747}
4748
4749// Can perform an HLSL Aggregate splat cast if the Dest is an aggregate and the
4750// Src is a scalar, a vector of length 1, or a 1x1 matrix
4751// Or if Dest is a vector and Src is a vector of length 1 or a 1x1 matrix
4753
4754 QualType SrcTy = Src->getType();
4755 // Not a valid HLSL Aggregate Splat cast if Dest is a scalar or if this is
4756 // going to be a vector splat from a scalar.
4757 if ((SrcTy->isScalarType() && DestTy->isVectorType()) ||
4758 DestTy->isScalarType())
4759 return false;
4760
4761 const VectorType *SrcVecTy = SrcTy->getAs<VectorType>();
4762 const ConstantMatrixType *SrcMatTy = SrcTy->getAs<ConstantMatrixType>();
4763
4764 // Src isn't a scalar, a vector of length 1, or a 1x1 matrix
4765 if (!SrcTy->isScalarType() &&
4766 !(SrcVecTy && SrcVecTy->getNumElements() == 1) &&
4767 !(SrcMatTy && SrcMatTy->getNumElementsFlattened() == 1))
4768 return false;
4769
4770 if (SrcVecTy)
4771 SrcTy = SrcVecTy->getElementType();
4772 else if (SrcMatTy)
4773 SrcTy = SrcMatTy->getElementType();
4774
4776 BuildFlattenedTypeList(DestTy, DestTypes);
4777
4778 for (unsigned I = 0, Size = DestTypes.size(); I < Size; ++I) {
4779 if (DestTypes[I]->isUnionType())
4780 return false;
4781 if (!CanPerformScalarCast(SrcTy, DestTypes[I]))
4782 return false;
4783 }
4784 return true;
4785}
4786
4787// Can we perform an HLSL Elementwise cast?
4789
4790 // Don't handle casts where LHS and RHS are any combination of scalar/vector
4791 // There must be an aggregate somewhere
4792 QualType SrcTy = Src->getType();
4793 if (SrcTy->isScalarType()) // always a splat and this cast doesn't handle that
4794 return false;
4795
4796 if (SrcTy->isVectorType() &&
4797 (DestTy->isScalarType() || DestTy->isVectorType()))
4798 return false;
4799
4800 if (SrcTy->isConstantMatrixType() &&
4801 (DestTy->isScalarType() || DestTy->isConstantMatrixType()))
4802 return false;
4803
4805 BuildFlattenedTypeList(DestTy, DestTypes);
4807 BuildFlattenedTypeList(SrcTy, SrcTypes);
4808
4809 // Usually the size of SrcTypes must be greater than or equal to the size of
4810 // DestTypes.
4811 if (SrcTypes.size() < DestTypes.size())
4812 return false;
4813
4814 unsigned SrcSize = SrcTypes.size();
4815 unsigned DstSize = DestTypes.size();
4816 unsigned I;
4817 for (I = 0; I < DstSize && I < SrcSize; I++) {
4818 if (SrcTypes[I]->isUnionType() || DestTypes[I]->isUnionType())
4819 return false;
4820 if (!CanPerformScalarCast(SrcTypes[I], DestTypes[I])) {
4821 return false;
4822 }
4823 }
4824
4825 // check the rest of the source type for unions.
4826 for (; I < SrcSize; I++) {
4827 if (SrcTypes[I]->isUnionType())
4828 return false;
4829 }
4830 return true;
4831}
4832
4834 assert(Param->hasAttr<HLSLParamModifierAttr>() &&
4835 "We should not get here without a parameter modifier expression");
4836 const auto *Attr = Param->getAttr<HLSLParamModifierAttr>();
4837 if (Attr->getABI() == ParameterABI::Ordinary)
4838 return ExprResult(Arg);
4839
4840 bool IsInOut = Attr->getABI() == ParameterABI::HLSLInOut;
4841 if (!Arg->isLValue()) {
4842 SemaRef.Diag(Arg->getBeginLoc(), diag::error_hlsl_inout_lvalue)
4843 << Arg << (IsInOut ? 1 : 0);
4844 return ExprError();
4845 }
4846
4847 ASTContext &Ctx = SemaRef.getASTContext();
4848
4849 QualType Ty = Param->getType().getNonLValueExprType(Ctx);
4850
4851 // HLSL allows implicit conversions from scalars to vectors, but not the
4852 // inverse, so we need to disallow `inout` with scalar->vector or
4853 // scalar->matrix conversions.
4854 if (Arg->getType()->isScalarType() != Ty->isScalarType()) {
4855 SemaRef.Diag(Arg->getBeginLoc(), diag::error_hlsl_inout_scalar_extension)
4856 << Arg << (IsInOut ? 1 : 0);
4857 return ExprError();
4858 }
4859
4860 auto *ArgOpV = new (Ctx) OpaqueValueExpr(Param->getBeginLoc(), Arg->getType(),
4861 VK_LValue, OK_Ordinary, Arg);
4862
4863 // Parameters are initialized via copy initialization. This allows for
4864 // overload resolution of argument constructors.
4865 InitializedEntity Entity =
4867 ExprResult Res =
4868 SemaRef.PerformCopyInitialization(Entity, Param->getBeginLoc(), ArgOpV);
4869 if (Res.isInvalid())
4870 return ExprError();
4871 Expr *Base = Res.get();
4872 // After the cast, drop the reference type when creating the exprs.
4873 Ty = Ty.getNonLValueExprType(Ctx);
4874 auto *OpV = new (Ctx)
4875 OpaqueValueExpr(Param->getBeginLoc(), Ty, VK_LValue, OK_Ordinary, Base);
4876
4877 // Writebacks are performed with `=` binary operator, which allows for
4878 // overload resolution on writeback result expressions.
4879 Res = SemaRef.ActOnBinOp(SemaRef.getCurScope(), Param->getBeginLoc(),
4880 tok::equal, ArgOpV, OpV);
4881
4882 if (Res.isInvalid())
4883 return ExprError();
4884 Expr *Writeback = Res.get();
4885 auto *OutExpr =
4886 HLSLOutArgExpr::Create(Ctx, Ty, ArgOpV, OpV, Writeback, IsInOut);
4887
4888 return ExprResult(OutExpr);
4889}
4890
4892 // If HLSL gains support for references, all the cites that use this will need
4893 // to be updated with semantic checking to produce errors for
4894 // pointers/references.
4895 assert(!Ty->isReferenceType() &&
4896 "Pointer and reference types cannot be inout or out parameters");
4897 Ty = SemaRef.getASTContext().getLValueReferenceType(Ty);
4898 Ty.addRestrict();
4899 return Ty;
4900}
4901
4902// Returns true if the type has a non-empty constant buffer layout (i.e. if it
4903// is a scalar, vector or matrix, or if it contains any of these).
4905 const Type *Ty = QT->getUnqualifiedDesugaredType();
4906 if (Ty->isScalarType() || Ty->isVectorType() || Ty->isMatrixType())
4907 return true;
4908
4910 return false;
4911
4912 if (const auto *RD = Ty->getAsCXXRecordDecl()) {
4913 for (const auto *FD : RD->fields()) {
4915 return true;
4916 }
4917 assert(RD->getNumBases() <= 1 &&
4918 "HLSL doesn't support multiple inheritance");
4919 return RD->getNumBases()
4920 ? hasConstantBufferLayout(RD->bases_begin()->getType())
4921 : false;
4922 }
4923
4924 if (const auto *AT = dyn_cast<ArrayType>(Ty)) {
4925 if (const auto *CAT = dyn_cast<ConstantArrayType>(AT))
4926 if (isZeroSizedArray(CAT))
4927 return false;
4929 }
4930
4931 return false;
4932}
4933
// Returns true if VD is a candidate for the default constant buffer.
// Among the conditions checked by the return expression below, VD must:
//  - be declared directly at translation-unit scope,
//  - have a type with no explicit address space (LangAS::Default),
//  - not be declared static,
//  - not carry the Vulkan constant_id attribute, and
//  - not be a Vulkan push constant.
4934static bool IsDefaultBufferConstantDecl(const ASTContext &Ctx, VarDecl *VD) {
4935 bool IsVulkan =
4936 Ctx.getTargetInfo().getTriple().getOS() == llvm::Triple::Vulkan;
// The push-constant attribute only excludes VD when targeting Vulkan.
4937 bool IsVKPushConstant = IsVulkan && VD->hasAttr<HLSLVkPushConstantAttr>();
4938 QualType QT = VD->getType();
4939 return VD->getDeclContext()->isTranslationUnit() &&
4940 QT.getAddressSpace() == LangAS::Default &&
4941 VD->getStorageClass() != SC_Static &&
4942 !VD->hasAttr<HLSLVkConstantIdAttr>() && !IsVKPushConstant &&
4944}
4945
4947 // The variable already has an address space (groupshared for ex).
4948 if (Decl->getType().hasAddressSpace())
4949 return;
4950
4951 if (Decl->getType()->isDependentType())
4952 return;
4953
4954 QualType Type = Decl->getType();
4955
4956 if (Decl->hasAttr<HLSLVkExtBuiltinInputAttr>()) {
4957 LangAS ImplAS = LangAS::hlsl_input;
4958 Type = SemaRef.getASTContext().getAddrSpaceQualType(Type, ImplAS);
4959 Decl->setType(Type);
4960 return;
4961 }
4962
4963 if (Decl->hasAttr<HLSLVkExtBuiltinOutputAttr>()) {
4964 LangAS ImplAS = LangAS::hlsl_output;
4965 Type = SemaRef.getASTContext().getAddrSpaceQualType(Type, ImplAS);
4966 Decl->setType(Type);
4967
4968 // HLSL uses `static` differently than C++. For BuiltIn output, the static
4969 // does not imply private to the module scope.
4970 // Marking it as external to reflect the semantic this attribute brings.
4971 // See https://github.com/microsoft/hlsl-specs/issues/350
4972 Decl->setStorageClass(SC_Extern);
4973 return;
4974 }
4975
4976 bool IsVulkan = getASTContext().getTargetInfo().getTriple().getOS() ==
4977 llvm::Triple::Vulkan;
4978 if (IsVulkan && Decl->hasAttr<HLSLVkPushConstantAttr>()) {
4979 if (HasDeclaredAPushConstant)
4980 SemaRef.Diag(Decl->getLocation(), diag::err_hlsl_push_constant_unique);
4981
4983 Type = SemaRef.getASTContext().getAddrSpaceQualType(Type, ImplAS);
4984 Decl->setType(Type);
4985 HasDeclaredAPushConstant = true;
4986 return;
4987 }
4988
4989 if (Type->isSamplerT() || Type->isVoidType())
4990 return;
4991
4992 // Resource handles.
4994 return;
4995
4996 // Only static globals belong to the Private address space.
4997 // Non-static globals belongs to the cbuffer.
4998 if (Decl->getStorageClass() != SC_Static && !Decl->isStaticDataMember())
4999 return;
5000
5002 Type = SemaRef.getASTContext().getAddrSpaceQualType(Type, ImplAS);
5003 Decl->setType(Type);
5004}
5005
5006namespace {
5007
5008// Helper class for assigning bindings to resources declared within a struct.
5009// It keeps track of all binding attributes declared on a struct instance, and
5010// the offsets for each register type that have been assigned so far.
5011// Handles both explicit and implicit bindings.
5012class StructBindingContext {
5013 // Bindings and offsets per register type. We only need to support four
5014 // register types - SRV (u), UAV (t), CBuffer (c), and Sampler (s).
5015 HLSLResourceBindingAttr *RegBindingsAttrs[4];
5016 unsigned RegBindingOffset[4];
5017
5018 // Make sure the RegisterType values are what we expect
5019 static_assert(static_cast<unsigned>(RegisterType::SRV) == 0 &&
5020 static_cast<unsigned>(RegisterType::UAV) == 1 &&
5021 static_cast<unsigned>(RegisterType::CBuffer) == 2 &&
5022 static_cast<unsigned>(RegisterType::Sampler) == 3,
5023 "unexpected register type values");
5024
5025 // Vulkan binding attribute does not vary by register type.
5026 HLSLVkBindingAttr *VkBindingAttr;
5027 unsigned VkBindingOffset;
5028
5029public:
5030 // Constructor: gather all binding attributes on a struct instance and
5031 // initialize offsets.
5032 StructBindingContext(VarDecl *VD) {
5033 for (unsigned i = 0; i < 4; ++i) {
5034 RegBindingsAttrs[i] = nullptr;
5035 RegBindingOffset[i] = 0;
5036 }
5037 VkBindingAttr = nullptr;
5038 VkBindingOffset = 0;
5039
5040 ASTContext &AST = VD->getASTContext();
5041 bool IsSpirv = AST.getTargetInfo().getTriple().isSPIRV();
5042
5043 for (Attr *A : VD->attrs()) {
5044 if (auto *RBA = dyn_cast<HLSLResourceBindingAttr>(A)) {
5045 RegisterType RegType = RBA->getRegisterType();
5046 unsigned RegTypeIdx = static_cast<unsigned>(RegType);
5047 // Ignore unsupported register annotations, such as 'c' or 'i'.
5048 if (RegTypeIdx < 4)
5049 RegBindingsAttrs[RegTypeIdx] = RBA;
5050 continue;
5051 }
5052 // Gather the Vulkan binding attributes only if the target is SPIR-V.
5053 if (IsSpirv) {
5054 if (auto *VBA = dyn_cast<HLSLVkBindingAttr>(A))
5055 VkBindingAttr = VBA;
5056 }
5057 }
5058 }
5059
5060 // Creates a binding attribute for a resource based on the gathered attributes
5061 // and the required register type and range.
5062 Attr *createBindingAttr(SemaHLSL &S, ASTContext &AST, RegisterType RegType,
5063 unsigned Range) {
5064 assert(static_cast<unsigned>(RegType) < 4 && "unexpected register type");
5065
5066 if (VkBindingAttr) {
5067 unsigned Offset = VkBindingOffset;
5068 VkBindingOffset += Range;
5069 return HLSLVkBindingAttr::CreateImplicit(
5070 AST, VkBindingAttr->getBinding() + Offset, VkBindingAttr->getSet(),
5071 VkBindingAttr->getRange());
5072 }
5073
5074 HLSLResourceBindingAttr *RBA =
5075 RegBindingsAttrs[static_cast<unsigned>(RegType)];
5076 HLSLResourceBindingAttr *NewAttr = nullptr;
5077
5078 if (RBA && RBA->hasRegisterSlot()) {
5079 // Explicit binding - create a new attribute with offseted slot number
5080 // based on the required register type.
5081 unsigned Offset = RegBindingOffset[static_cast<unsigned>(RegType)];
5082 RegBindingOffset[static_cast<unsigned>(RegType)] += Range;
5083
5084 unsigned NewSlotNumber = RBA->getSlotNumber() + Offset;
5085 StringRef NewSlotNumberStr =
5086 createRegisterString(AST, RBA->getRegisterType(), NewSlotNumber);
5087 NewAttr = HLSLResourceBindingAttr::CreateImplicit(
5088 AST, NewSlotNumberStr, RBA->getSpace(), RBA->getRange());
5089 NewAttr->setBinding(RegType, NewSlotNumber, RBA->getSpaceNumber());
5090 } else {
5091 // No binding attribute or space-only binding - create a binding
5092 // attribute for implicit binding.
5093 NewAttr = HLSLResourceBindingAttr::CreateImplicit(AST, "", "0", {});
5094 NewAttr->setBinding(RegType, std::nullopt,
5095 RBA ? RBA->getSpaceNumber() : 0);
5096 NewAttr->setImplicitBindingOrderID(S.getNextImplicitBindingOrderID());
5097 }
5098 return NewAttr;
5099 }
5100};
5101
5102// Creates a global variable declaration for a resource field embedded in a
5103// struct, assigns it a binding, initializes it, and associates it with the
5104// struct declaration via an HLSLAssociatedResourceDeclAttr. ResTy must be a
// resource record type or an array of resources. Id names the new global.
// Loc is used as its source location.
5105static void createGlobalResourceDeclForStruct(
5106 Sema &S, VarDecl *ParentVD, SourceLocation Loc, IdentifierInfo *Id,
5107 QualType ResTy, StructBindingContext &BindingCtx) {
5108 assert(isResourceRecordTypeOrArrayOf(ResTy) &&
5109 "expected resource type or array of resources");
5110
5111 DeclContext *DC = ParentVD->getNonTransparentDeclContext();
5112 assert(DC->isTranslationUnit() && "expected translation unit decl context");
5113
5114 ASTContext &AST = S.getASTContext();
// The new global is created without an initializer; it is set up below.
5115 VarDecl *ResDecl =
5116 VarDecl::Create(AST, DC, Loc, Loc, Id, ResTy, nullptr, SC_None);
5117
// Range is the number of binding slots the declaration occupies:
//  - 1 for a single resource,
//  - the element count for a constant-size array,
//  - 0 for an array without a constant size.
// The handle type determines which register type the binding uses.
5118 unsigned Range = 1;
5119 const HLSLAttributedResourceType *ResHandleTy = nullptr;
5120 if (const auto *AT = dyn_cast<ArrayType>(ResTy.getTypePtr())) {
5121 const auto *CAT = dyn_cast<ConstantArrayType>(AT);
5122 Range = CAT ? CAT->getSize().getZExtValue() : 0;
5123 ResHandleTy = getResourceArrayHandleType(ResTy);
5124 } else {
5125 ResHandleTy = HLSLAttributedResourceType::findHandleTypeOnResource(
5126 ResTy.getTypePtr());
5127 }
5128 // Add a binding attribute to the global resource declaration.
// BindingCtx advances its per-register-type offset by Range.
5129 Attr *BindingAttr = BindingCtx.createBindingAttr(
5130 S.HLSL(), AST, getRegisterType(ResHandleTy), Range);
5131 ResDecl->addAttr(BindingAttr);
// The generated global is implicit and has internal linkage.
5132 ResDecl->addAttr(InternalLinkageAttr::CreateImplicit(AST));
5133 ResDecl->setImplicit();
5134
// NOTE(review): a constant-size array with exactly one element also has
// Range == 1 and is routed to initGlobalResourceDecl. That function asserts
// that the declaration's type is a (non-array) resource record. Dispatching
// on whether ResTy is an array type, rather than on Range, may be what is
// intended. Confirm.
5135 if (Range == 1)
5136 S.HLSL().initGlobalResourceDecl(ResDecl);
5137 else
5138 S.HLSL().initGlobalResourceArrayDecl(ResDecl);
5139
// Link the generated global back to the struct instance it came from, and
// add it to the translation unit.
5140 ParentVD->addAttr(
5141 HLSLAssociatedResourceDeclAttr::CreateImplicit(AST, ResDecl));
5142 DC->addDecl(ResDecl);
5143
5144 DeclGroupRef DG(ResDecl);
5146}
5147
5148static void handleArrayOfStructWithResources(
5149 Sema &S, VarDecl *ParentVD, const ConstantArrayType *CAT,
5150 EmbeddedResourceNameBuilder &NameBuilder, StructBindingContext &BindingCtx);
5151
5152// Scans base and all fields of a struct/class type to find all embedded
5153// resources or resource arrays. Creates a global variable for each resource
5154// found. NameBuilder holds the name of the member currently being visited.
// Each base/field name is pushed before recursing and popped afterwards.
5155static void handleStructWithResources(Sema &S, VarDecl *ParentVD,
5156 const CXXRecordDecl *RD,
5157 EmbeddedResourceNameBuilder &NameBuilder,
5158 StructBindingContext &BindingCtx) {
5159
5160 // Scan the base class (HLSL allows at most one). Only recurse into it if it
// can contain resources.
5161 assert(RD->getNumBases() <= 1 && "HLSL doesn't support multiple inheritance");
5162 const auto *BasesIt = RD->bases_begin();
5163 if (BasesIt != RD->bases_end()) {
5164 QualType QT = BasesIt->getType();
5165 if (QT->isHLSLIntangibleType()) {
5166 CXXRecordDecl *BaseRD = QT->getAsCXXRecordDecl();
5167 NameBuilder.pushBaseName(BaseRD->getName());
5168 handleStructWithResources(S, ParentVD, BaseRD, NameBuilder, BindingCtx);
5169 NameBuilder.pop();
5170 }
5171 }
5172 // Process the fields of this class.
5173 for (const FieldDecl *FD : RD->fields()) {
5174 QualType FDTy = FD->getType().getCanonicalType();
// Fields that are not intangible cannot contain resources; skip them.
5175 if (!FDTy->isHLSLIntangibleType())
5176 continue;
5177
5178 NameBuilder.pushName(FD->getName());
5179
// Resource (or resource array) field: create a standalone global for it.
// createGlobalResourceDeclForStruct asserts this type requirement.
5181 IdentifierInfo *II = NameBuilder.getNameAsIdentifier(S.getASTContext());
5182 createGlobalResourceDeclForStruct(S, ParentVD, FD->getLocation(), II,
5183 FDTy, BindingCtx);
// Nested struct field: recurse. Note that this RD shadows the parameter RD.
5184 } else if (const auto *RD = FDTy->getAsCXXRecordDecl()) {
5185 handleStructWithResources(S, ParentVD, RD, NameBuilder, BindingCtx);
5186
// Array of structs containing resources.
5187 } else if (const auto *ArrayTy = dyn_cast<ConstantArrayType>(FDTy)) {
5188 assert(!FDTy->isHLSLResourceRecordArray() &&
5189 "resource arrays should have been already handled");
5190 handleArrayOfStructWithResources(S, ParentVD, ArrayTy, NameBuilder,
5191 BindingCtx);
5192 }
5193 NameBuilder.pop();
5194 }
5195}
5196
5197// Processes array of structs with resources.
5198static void
5199handleArrayOfStructWithResources(Sema &S, VarDecl *ParentVD,
5200 const ConstantArrayType *CAT,
5201 EmbeddedResourceNameBuilder &NameBuilder,
5202 StructBindingContext &BindingCtx) {
5203
5204 QualType ElementTy = CAT->getElementType().getCanonicalType();
5205 assert(ElementTy->isHLSLIntangibleType() && "Expected HLSL intangible type");
5206
5207 const ConstantArrayType *SubCAT = dyn_cast<ConstantArrayType>(ElementTy);
5208 const CXXRecordDecl *ElementRD = ElementTy->getAsCXXRecordDecl();
5209
5210 if (!SubCAT && !ElementRD)
5211 return;
5212
5213 for (unsigned I = 0, E = CAT->getSize().getZExtValue(); I < E; ++I) {
5214 NameBuilder.pushArrayIndex(I);
5215 if (ElementRD)
5216 handleStructWithResources(S, ParentVD, ElementRD, NameBuilder,
5217 BindingCtx);
5218 else
5219 handleArrayOfStructWithResources(S, ParentVD, SubCAT, NameBuilder,
5220 BindingCtx);
5221 NameBuilder.pop();
5222 }
5223}
5224
5225} // namespace
5226
5227// Scans all fields of a user-defined struct (or array of structs)
5228// to find all embedded resources or resource arrays. For each resource
5229// a global variable of the resource type is created and associated
5230// with the parent declaration (VD) through a HLSLAssociatedResourceDeclAttr
5231// attribute.
5232void SemaHLSL::handleGlobalStructOrArrayOfWithResources(VarDecl *VD) {
5233 EmbeddedResourceNameBuilder NameBuilder(VD->getName());
5234 StructBindingContext BindingCtx(VD);
5235
5236 const Type *VDTy = VD->getType().getTypePtr();
5237 assert(VDTy->isHLSLIntangibleType() && !isResourceRecordTypeOrArrayOf(VD) &&
5238 "Expected non-resource struct or array type");
5239
5240 if (const CXXRecordDecl *RD = VDTy->getAsCXXRecordDecl()) {
5241 handleStructWithResources(SemaRef, VD, RD, NameBuilder, BindingCtx);
5242 return;
5243 }
5244
5245 if (const auto *CAT = dyn_cast<ConstantArrayType>(VDTy)) {
5246 handleArrayOfStructWithResources(SemaRef, VD, CAT, NameBuilder, BindingCtx);
5247 return;
5248 }
5249}
5250
5252 if (VD->hasGlobalStorage()) {
5253 // make sure the declaration has a complete type
5254 if (SemaRef.RequireCompleteType(
5255 VD->getLocation(),
5256 SemaRef.getASTContext().getBaseElementType(VD->getType()),
5257 diag::err_typecheck_decl_incomplete_type)) {
5258 VD->setInvalidDecl();
5260 return;
5261 }
5262
5263 // Global variables outside a cbuffer block that are not a resource, static,
5264 // groupshared, or an empty array or struct belong to the default constant
5265 // buffer $Globals (to be created at the end of the translation unit).
5267 // update address space to hlsl_constant
5270 VD->setType(NewTy);
5271 DefaultCBufferDecls.push_back(VD);
5272 }
5273
5274 // find all resources bindings on decl
5275 if (VD->getType()->isHLSLIntangibleType())
5276 collectResourceBindingsOnVarDecl(VD);
5277
5278 if (VD->hasAttr<HLSLVkConstantIdAttr>())
5280
5282 VD->getStorageClass() != SC_Static) {
5283 // Add internal linkage attribute to non-static resource variables. The
5284 // global externally visible storage is accessed through the handle, which
5285 // is a member. The variable itself is not externally visible.
5286 VD->addAttr(InternalLinkageAttr::CreateImplicit(getASTContext()));
5287 }
5288
5289 // process explicit bindings
5290 processExplicitBindingsOnDecl(VD);
5291
5292 // Add implicit binding attribute to non-static resource arrays.
5293 if (VD->getType()->isHLSLResourceRecordArray() &&
5294 VD->getStorageClass() != SC_Static) {
5295 // If the resource array does not have an explicit binding attribute,
5296 // create an implicit one. It will be used to transfer implicit binding
5297 // order_ID to codegen.
5298 ResourceBindingAttrs Binding(VD);
5299 if (!Binding.isExplicit()) {
5300 uint32_t OrderID = getNextImplicitBindingOrderID();
5301 if (Binding.hasBinding())
5302 Binding.setImplicitOrderID(OrderID);
5303 else {
5306 OrderID);
5307 // Re-create the binding object to pick up the new attribute.
5308 Binding = ResourceBindingAttrs(VD);
5309 }
5310 }
5311
5312 // Get to the base type of a potentially multi-dimensional array.
5314
5315 const CXXRecordDecl *RD = Ty->getAsCXXRecordDecl();
5316 if (hasCounterHandle(RD)) {
5317 if (!Binding.hasCounterImplicitOrderID()) {
5318 uint32_t OrderID = getNextImplicitBindingOrderID();
5319 Binding.setCounterImplicitOrderID(OrderID);
5320 }
5321 }
5322 }
5323
5324 // Process resources in user-defined structs, or arrays of such structs.
5325 const Type *VDTy = VD->getType().getTypePtr();
5326 if (VD->getStorageClass() != SC_Static && VDTy->isHLSLIntangibleType() &&
5328 handleGlobalStructOrArrayOfWithResources(VD);
5329
5330 // Mark groupshared variables as extern so they will have
5331 // external storage and won't be default initialized
5332 if (VD->hasAttr<HLSLGroupSharedAddressSpaceAttr>())
5334 }
5335
5337}
5338
5340 assert(VD->getType()->isHLSLResourceRecord() &&
5341 "expected resource record type");
5342
5343 ASTContext &AST = SemaRef.getASTContext();
5344 uint64_t UIntTySize = AST.getTypeSize(AST.UnsignedIntTy);
5345 uint64_t IntTySize = AST.getTypeSize(AST.IntTy);
5346
5347 // Gather resource binding attributes.
5348 ResourceBindingAttrs Binding(VD);
5349
5350 // Find correct initialization method and create its arguments.
5351 QualType ResourceTy = VD->getType();
5352 CXXRecordDecl *ResourceDecl = ResourceTy->getAsCXXRecordDecl();
5353 CXXMethodDecl *CreateMethod = nullptr;
5355
5356 bool HasCounter = hasCounterHandle(ResourceDecl);
5357 const char *CreateMethodName;
5358 if (Binding.isExplicit())
5359 CreateMethodName = HasCounter ? "__createFromBindingWithImplicitCounter"
5360 : "__createFromBinding";
5361 else
5362 CreateMethodName = HasCounter
5363 ? "__createFromImplicitBindingWithImplicitCounter"
5364 : "__createFromImplicitBinding";
5365
5366 CreateMethod =
5367 lookupMethod(SemaRef, ResourceDecl, CreateMethodName, VD->getLocation());
5368
5369 if (!CreateMethod)
5370 // This can happen if someone creates a struct that looks like an HLSL
5371 // resource record but does not have the required static create method.
5372 // No binding will be generated for it.
5373 return false;
5374
5375 if (Binding.isExplicit()) {
5376 IntegerLiteral *RegSlot =
5377 IntegerLiteral::Create(AST, llvm::APInt(UIntTySize, Binding.getSlot()),
5379 Args.push_back(RegSlot);
5380 } else {
5381 uint32_t OrderID = (Binding.hasImplicitOrderID())
5382 ? Binding.getImplicitOrderID()
5384 IntegerLiteral *OrderId =
5385 IntegerLiteral::Create(AST, llvm::APInt(UIntTySize, OrderID),
5387 Args.push_back(OrderId);
5388 }
5389
5390 IntegerLiteral *Space =
5391 IntegerLiteral::Create(AST, llvm::APInt(UIntTySize, Binding.getSpace()),
5393 Args.push_back(Space);
5394
5396 AST, llvm::APInt(IntTySize, 1), AST.IntTy, SourceLocation());
5397 Args.push_back(RangeSize);
5398
5400 AST, llvm::APInt(UIntTySize, 0), AST.UnsignedIntTy, SourceLocation());
5401 Args.push_back(Index);
5402
5403 StringRef VarName = VD->getName();
5405 AST, VarName, StringLiteralKind::Ordinary, false,
5406 AST.getStringLiteralArrayType(AST.CharTy.withConst(), VarName.size()),
5407 SourceLocation());
5409 AST, AST.getPointerType(AST.CharTy.withConst()), CK_ArrayToPointerDecay,
5410 Name, nullptr, VK_PRValue, FPOptionsOverride());
5411 Args.push_back(NameCast);
5412
5413 if (HasCounter) {
5414 // Will this be in the correct order?
5415 uint32_t CounterOrderID = getNextImplicitBindingOrderID();
5416 IntegerLiteral *CounterId =
5417 IntegerLiteral::Create(AST, llvm::APInt(UIntTySize, CounterOrderID),
5419 Args.push_back(CounterId);
5420 }
5421
5422 // Make sure the create method template is instantiated and emitted.
5423 if (!CreateMethod->isDefined() && CreateMethod->isTemplateInstantiation())
5424 SemaRef.InstantiateFunctionDefinition(VD->getLocation(), CreateMethod,
5425 true);
5426
5427 // Create CallExpr with a call to the static method and set it as the decl
5428 // initialization.
5430 AST, NestedNameSpecifierLoc(), SourceLocation(), CreateMethod, false,
5431 CreateMethod->getNameInfo(), CreateMethod->getType(), VK_PRValue);
5432
5433 auto *ImpCast = ImplicitCastExpr::Create(
5434 AST, AST.getPointerType(CreateMethod->getType()),
5435 CK_FunctionToPointerDecay, DRE, nullptr, VK_PRValue, FPOptionsOverride());
5436
5437 CallExpr *InitExpr =
5438 CallExpr::Create(AST, ImpCast, Args, ResourceTy, VK_PRValue,
5440 VD->setInit(InitExpr);
5442 SemaRef.CheckCompleteVariableDeclaration(VD);
5443 return true;
5444}
5445
5447 assert(VD->getType()->isHLSLResourceRecordArray() &&
5448 "expected array of resource records");
5449
5450 // Individual resources in a resource array are not initialized here. They
5451 // are initialized later on during codegen when the individual resources are
5452 // accessed. Codegen will emit a call to the resource initialization method
5453 // with the specified array index. We need to make sure though that the method
5454 // for the specific resource type is instantiated, so codegen can emit a call
5455 // to it when the array element is accessed.
5456
5457 // Find correct initialization method based on the resource binding
5458 // information.
5459 ASTContext &AST = SemaRef.getASTContext();
5460 QualType ResElementTy = AST.getBaseElementType(VD->getType());
5461 CXXRecordDecl *ResourceDecl = ResElementTy->getAsCXXRecordDecl();
5462 CXXMethodDecl *CreateMethod = nullptr;
5463
5464 bool HasCounter = hasCounterHandle(ResourceDecl);
5465 ResourceBindingAttrs ResourceAttrs(VD);
5466 if (ResourceAttrs.isExplicit())
5467 // Resource has explicit binding.
5468 CreateMethod =
5469 lookupMethod(SemaRef, ResourceDecl,
5470 HasCounter ? "__createFromBindingWithImplicitCounter"
5471 : "__createFromBinding",
5472 VD->getLocation());
5473 else
5474 // Resource has implicit binding.
5475 CreateMethod = lookupMethod(
5476 SemaRef, ResourceDecl,
5477 HasCounter ? "__createFromImplicitBindingWithImplicitCounter"
5478 : "__createFromImplicitBinding",
5479 VD->getLocation());
5480
5481 if (!CreateMethod)
5482 return false;
5483
5484 // Make sure the create method template is instantiated and emitted.
5485 if (!CreateMethod->isDefined() && CreateMethod->isTemplateInstantiation())
5486 SemaRef.InstantiateFunctionDefinition(VD->getLocation(), CreateMethod,
5487 true);
5488 return true;
5489}
5490
5491// Returns true if the initialization has been handled.
5492// Returns false to use default initialization.
5494 // Objects in the hlsl_constant address space are initialized
5495 // externally, so don't synthesize an implicit initializer.
5497 return true;
5498
5499 // Initialize non-static resources at the global scope.
5500 if (VD->hasGlobalStorage() && VD->getStorageClass() != SC_Static) {
5501 const Type *Ty = VD->getType().getTypePtr();
5502 if (Ty->isHLSLResourceRecord())
5503 return initGlobalResourceDecl(VD);
5504 if (Ty->isHLSLResourceRecordArray())
5505 return initGlobalResourceArrayDecl(VD);
5506 }
5507 return false;
5508}
5509
// Tries to determine which global resource binding the expression E refers
// to. Returns:
//  - std::nullopt when E may refer to more than one binding (the arms of a
//    conditional operator disagree, or either arm is itself ambiguous);
//  - nullptr when no binding can be inferred;
//  - otherwise the DeclBindingInfo recorded in Bindings for the referenced
//    global resource (which may itself be nullptr if none is recorded).
5510std::optional<const DeclBindingInfo *> SemaHLSL::inferGlobalBinding(Expr *E) {
// Both arms of a conditional must resolve to the same binding.
5511 if (auto *Ternary = dyn_cast<ConditionalOperator>(E)) {
5512 auto TrueInfo = inferGlobalBinding(Ternary->getTrueExpr());
5513 auto FalseInfo = inferGlobalBinding(Ternary->getFalseExpr());
5514 if (!TrueInfo || !FalseInfo)
5515 return std::nullopt;
5516 if (*TrueInfo != *FalseInfo)
5517 return std::nullopt;
5518 return TrueInfo;
5519 }
5520
// An element of a resource array shares the array's binding, so look at the
// subscripted base.
// NOTE(review): only one level of subscript is stripped here, whereas
// CheckResourceBinOp strips them in a loop. Multi-dimensional resource array
// accesses then fall through to `return nullptr`. Confirm this is intended.
5521 if (auto *ASE = dyn_cast<ArraySubscriptExpr>(E))
5522 E = ASE->getBase()->IgnoreParenImpCasts();
5523
5524 if (DeclRefExpr *DRE = dyn_cast<DeclRefExpr>(E->IgnoreParens()))
5525 if (VarDecl *VD = dyn_cast<VarDecl>(DRE->getDecl())) {
// For array variables, the resource handle is looked up on the element type.
5526 const Type *Ty = VD->getType()->getUnqualifiedDesugaredType();
5527 if (Ty->isArrayType())
5529
// The resource class from the handle type selects which binding of VD is
// returned.
5530 if (const auto *AttrResType =
5531 HLSLAttributedResourceType::findHandleTypeOnResource(Ty)) {
5532 ResourceClass RC = AttrResType->getAttrs().ResourceClass;
5533 return Bindings.getDeclBindingInfo(VD, RC);
5534 }
5535 }
5536
5537 return nullptr;
5538}
5539
5540void SemaHLSL::trackLocalResource(VarDecl *VD, Expr *E) {
5541 std::optional<const DeclBindingInfo *> ExprBinding = inferGlobalBinding(E);
5542 if (!ExprBinding) {
5543 SemaRef.Diag(E->getBeginLoc(),
5544 diag::warn_hlsl_assigning_local_resource_is_not_unique)
5545 << E << VD;
5546 return; // Expr use multiple resources
5547 }
5548
5549 if (*ExprBinding == nullptr)
5550 return; // No binding could be inferred to track, return without error
5551
5552 auto PrevBinding = Assigns.find(VD);
5553 if (PrevBinding == Assigns.end()) {
5554 // No previous binding recorded, simply record the new assignment
5555 Assigns.insert({VD, *ExprBinding});
5556 return;
5557 }
5558
5559 // Otherwise, warn if the assignment implies different resource bindings
5560 if (*ExprBinding != PrevBinding->second) {
5561 SemaRef.Diag(E->getBeginLoc(),
5562 diag::warn_hlsl_assigning_local_resource_is_not_unique)
5563 << E << VD;
5564 SemaRef.Diag(VD->getLocation(), diag::note_var_declared_here) << VD;
5565 return;
5566 }
5567
5568 return;
5569}
5570
5572 Expr *RHSExpr, SourceLocation Loc) {
5573 assert((LHSExpr->getType()->isHLSLResourceRecord() ||
5574 LHSExpr->getType()->isHLSLResourceRecordArray()) &&
5575 "expected LHS to be a resource record or array of resource records");
5576 if (Opc != BO_Assign)
5577 return true;
5578
5579 // If LHS is an array subscript, get the underlying declaration.
5580 Expr *E = LHSExpr;
5581 while (auto *ASE = dyn_cast<ArraySubscriptExpr>(E))
5582 E = ASE->getBase()->IgnoreParenImpCasts();
5583
5584 // Report error if LHS is a non-static resource declared at a global scope.
5585 if (DeclRefExpr *DRE = dyn_cast<DeclRefExpr>(E->IgnoreParens())) {
5586 if (VarDecl *VD = dyn_cast<VarDecl>(DRE->getDecl())) {
5587 if (VD->hasGlobalStorage() && VD->getStorageClass() != SC_Static) {
5588 // assignment to global resource is not allowed
5589 SemaRef.Diag(Loc, diag::err_hlsl_assign_to_global_resource) << VD;
5590 SemaRef.Diag(VD->getLocation(), diag::note_var_declared_here) << VD;
5591 return false;
5592 }
5593
5594 trackLocalResource(VD, RHSExpr);
5595 }
5596 }
5597 return true;
5598}
5599
5600// Walks through the global variable declaration VD, collects all of its
5601// resource binding requirements and adds them to Bindings.
5602void SemaHLSL::collectResourceBindingsOnVarDecl(VarDecl *VD) {
5603 assert(VD->hasGlobalStorage() && VD->getType()->isHLSLIntangibleType() &&
5604 "expected global variable that contains HLSL resource");
5605
5606 // Cbuffers and Tbuffers are HLSLBufferDecl types. A cbuffer needs a CBuffer
// binding and a tbuffer an SRV binding.
// NOTE(review): VD is a VarDecl. If HLSLBufferDecl does not derive from
// VarDecl (presumably the case), this cast can never succeed and the branch
// is unreachable from here. Confirm.
5607 if (const HLSLBufferDecl *CBufferOrTBuffer = dyn_cast<HLSLBufferDecl>(VD)) {
5608 Bindings.addDeclBindingInfo(VD, CBufferOrTBuffer->isCBuffer()
5609 ? ResourceClass::CBuffer
5610 : ResourceClass::SRV);
5611 return;
5612 }
5613
5614 // Unwrap (possibly nested) array types to reach the element type.
5615 // FIXME: Calculate array size while unwrapping
5616 const Type *Ty = VD->getType()->getUnqualifiedDesugaredType();
5617 while (Ty->isArrayType()) {
5618 const ArrayType *AT = cast<ArrayType>(Ty);
5620 }
5621
5622 // Resource (or array of resources): one binding of the handle's resource
// class.
5623 if (const HLSLAttributedResourceType *AttrResType =
5624 HLSLAttributedResourceType::findHandleTypeOnResource(Ty)) {
5625 Bindings.addDeclBindingInfo(VD, AttrResType->getAttrs().ResourceClass);
5626 return;
5627 }
5628
5629 // User defined record type: collect the bindings of its embedded resources.
5630 if (const RecordType *RT = dyn_cast<RecordType>(Ty))
5631 collectResourceBindingsOnUserRecordDecl(VD, RT);
5632}
5633
5634// Walks through the explicit resource binding attributes on the declaration,
5635// makes sure there is a resource that matches each binding, and updates the
5636// corresponding DeclBindingInfo in Bindings. Warns about resources left
// without any binding.
5637void SemaHLSL::processExplicitBindingsOnDecl(VarDecl *VD) {
5638 assert(VD->hasGlobalStorage() && "expected global variable");
5639
// Set once a Vulkan binding or a register binding with a slot is seen.
5640 bool HasBinding = false;
5641 for (Attr *A : VD->attrs()) {
// A Vulkan binding counts as a binding, but may not be combined with a
// Vulkan push-constant attribute.
5642 if (isa<HLSLVkBindingAttr>(A)) {
5643 HasBinding = true;
5644 if (auto PA = VD->getAttr<HLSLVkPushConstantAttr>())
5645 Diag(PA->getLoc(), diag::err_hlsl_attr_incompatible) << A << PA;
5646 }
5647
// Only register bindings with an explicit slot are processed below.
// Space-only bindings are skipped and do not count as a binding.
5648 HLSLResourceBindingAttr *RBA = dyn_cast<HLSLResourceBindingAttr>(A);
5649 if (!RBA || !RBA->hasRegisterSlot())
5650 continue;
5651 HasBinding = true;
5652
5653 RegisterType RT = RBA->getRegisterType();
5654 assert(RT != RegisterType::I && "invalid or obsolete register type should "
5655 "never have an attribute created");
5656
// 'c' registers do not bind a resource. The missing-member warning is
// emitted for them only when the declaration has resource binding info
// recorded.
5657 if (RT == RegisterType::C) {
5658 if (Bindings.hasBindingInfoForDecl(VD))
5659 SemaRef.Diag(VD->getLocation(),
5660 diag::warn_hlsl_user_defined_type_missing_member)
5661 << static_cast<int>(RT);
5662 continue;
5663 }
5664
5665 // Find DeclBindingInfo for this binding and update it, or report a warning
5666 // if it does not exist (the user type does not contain resources with the
5667 // expected resource class).
5669 if (DeclBindingInfo *BI = Bindings.getDeclBindingInfo(VD, RC)) {
5670 // update binding info
5671 BI->setBindingAttribute(RBA, BindingType::Explicit);
5672 } else {
5673 SemaRef.Diag(VD->getLocation(),
5674 diag::warn_hlsl_user_defined_type_missing_member)
5675 << static_cast<int>(RT);
5676 }
5677 }
5678
// Resources (or resource arrays) with neither a slot binding nor a Vulkan
// binding get an implicit-binding warning.
5679 if (!HasBinding && isResourceRecordTypeOrArrayOf(VD))
5680 SemaRef.Diag(VD->getLocation(), diag::warn_hlsl_implicit_binding);
5681}
5682namespace {
// InitListTransformer rewrites an HLSL initializer list in two phases:
//  1. buildInitializerList() recursively decomposes each initializer into
//     leaf expressions (vector/matrix elements, array elements and record
//     fields are accessed individually) and copy-initializes each leaf to the
//     next entry of DestTypes, the flattened list of the destination type's
//     leaf types. Converted leaves accumulate in ArgExprs.
//  2. generateInitLists() consumes ArgExprs in order and rebuilds nested
//     InitListExprs shaped like the destination type.
5683class InitListTransformer {
5684 Sema &S;
5685 ASTContext &Ctx;
// Destination type (reference stripped); for incomplete arrays this is the
// array's element type, see the constructor.
5686 QualType InitTy;
// Next destination leaf type to convert to (phase 1).
5687 QualType *DstIt = nullptr;
// Next converted leaf to consume (phase 2).
5688 Expr **ArgIt = nullptr;
5689 // Is wrapping the destination type iterator required? This is only used for
5690 // incomplete array types where we loop over the destination type since we
5691 // don't know the full number of elements from the declaration.
5692 bool Wrap;
5693
// Copy-initializes leaf expression E to *DstIt, appends the converted
// expression to ArgExprs and advances DstIt. Returns false only if the
// conversion fails. Once DestTypes is exhausted, E is appended unconverted
// (non-wrapping case) so the caller can diagnose the count mismatch, or DstIt
// restarts at DestTypes.begin() (wrapping case).
5694 bool castInitializer(Expr *E) {
5695 assert(DstIt && "This should always be something!");
5696 if (DstIt == DestTypes.end()) {
5697 if (!Wrap) {
5698 ArgExprs.push_back(E);
5699 // This is odd, but it isn't technically a failure due to conversion, we
5700 // handle mismatched counts of arguments differently.
5701 return true;
5702 }
// NOTE(review): if DestTypes is empty, begin() == end() and *DstIt below
// is invalid; presumably callers never wrap over an empty type list —
// confirm.
5703 DstIt = DestTypes.begin();
5704 }
5705 InitializedEntity Entity = InitializedEntity::InitializeParameter(
5706 Ctx, *DstIt, /* Consumed (ObjC) */ false);
5707 ExprResult Res = S.PerformCopyInitialization(Entity, E->getBeginLoc(), E);
5708 if (Res.isInvalid())
5709 return false;
5710 Expr *Init = Res.get();
5711 ArgExprs.push_back(Init);
5712 DstIt++;
5713 return true;
5714 }
5715
// Phase 1 worker: recursively decomposes E into leaf expressions and passes
// each leaf to castInitializer(). Returns false as soon as building an
// element/member access or a conversion fails.
5716 bool buildInitializerListImpl(Expr *E) {
5717 // If this is an initialization list, traverse the sub initializers.
5718 if (auto *Init = dyn_cast<InitListExpr>(E)) {
5719 for (auto *SubInit : Init->inits())
5720 if (!buildInitializerListImpl(SubInit))
5721 return false;
5722 return true;
5723 }
5724
5725 // Leaf types (scalars, non-aggregate records and the other types tested in the condition below) are converted as a single element.
5726 QualType Ty = E->getType().getDesugaredType(Ctx);
5727
5728 if (Ty->isScalarType() || (Ty->isRecordType() && !Ty->isAggregateType()) ||
5730 return castInitializer(E);
5731
// Vectors: convert each element, accessed with a size_t constant index.
5732 if (auto *VecTy = Ty->getAs<VectorType>()) {
5733 uint64_t Size = VecTy->getNumElements();
5734
5735 QualType SizeTy = Ctx.getSizeType();
5736 uint64_t SizeTySize = Ctx.getTypeSize(SizeTy);
5737 for (uint64_t I = 0; I < Size; ++I) {
5738 auto *Idx = IntegerLiteral::Create(Ctx, llvm::APInt(SizeTySize, I),
5739 SizeTy, SourceLocation());
5740
5742 E, E->getBeginLoc(), Idx, E->getEndLoc());
5743 if (ElExpr.isInvalid())
5744 return false;
5745 if (!castInitializer(ElExpr.get()))
5746 return false;
5747 }
5748 return true;
5749 }
// Matrices: convert each element in row-major order (all columns of row 0,
// then row 1, ...), using int row/column index literals.
5750 if (auto *MTy = Ty->getAs<ConstantMatrixType>()) {
5751 unsigned Rows = MTy->getNumRows();
5752 unsigned Cols = MTy->getNumColumns();
5753 QualType ElemTy = MTy->getElementType();
5754
5755 for (unsigned R = 0; R < Rows; ++R) {
5756 for (unsigned C = 0; C < Cols; ++C) {
5757 // row index literal
5758 Expr *RowIdx = IntegerLiteral::Create(
5759 Ctx, llvm::APInt(Ctx.getIntWidth(Ctx.IntTy), R), Ctx.IntTy,
5760 E->getBeginLoc());
5761 // column index literal
5762 Expr *ColIdx = IntegerLiteral::Create(
5763 Ctx, llvm::APInt(Ctx.getIntWidth(Ctx.IntTy), C), Ctx.IntTy,
5764 E->getBeginLoc());
5766 E, RowIdx, ColIdx, E->getEndLoc());
5767 if (ElExpr.isInvalid())
5768 return false;
5769 if (!castInitializer(ElExpr.get()))
5770 return false;
// NOTE(review): the element type is set after castInitializer() has
// already converted ElExpr; confirm this ordering is intended.
5771 ElExpr.get()->setType(ElemTy);
5772 }
5773 }
5774 return true;
5775 }
5776
// Constant arrays: recurse into each element, since elements may themselves
// be aggregates.
5777 if (auto *ArrTy = dyn_cast<ConstantArrayType>(Ty.getTypePtr())) {
5778 uint64_t Size = ArrTy->getZExtSize();
5779 QualType SizeTy = Ctx.getSizeType();
5780 uint64_t SizeTySize = Ctx.getTypeSize(SizeTy);
5781 for (uint64_t I = 0; I < Size; ++I) {
5782 auto *Idx = IntegerLiteral::Create(Ctx, llvm::APInt(SizeTySize, I),
5783 SizeTy, SourceLocation());
5785 E, E->getBeginLoc(), Idx, E->getEndLoc());
5786 if (ElExpr.isInvalid())
5787 return false;
5788 if (!buildInitializerListImpl(ElExpr.get()))
5789 return false;
5790 }
5791 return true;
5792 }
5793
// Records: collect the (single-inheritance) base chain, then visit fields
// starting from the outermost base and ending with the most-derived class.
// Unnamed bit-fields are skipped.
5794 if (auto *RD = Ty->getAsCXXRecordDecl()) {
5795 llvm::SmallVector<CXXRecordDecl *> RecordDecls;
5796 RecordDecls.push_back(RD);
5797 // If this is a prvalue create an xvalue so the member accesses
5798 // will be xvalues.
5799 if (E->isPRValue())
5800 E = new (Ctx)
5801 MaterializeTemporaryExpr(Ty, E, /*BoundToLvalueReference=*/false);
5802 while (RecordDecls.back()->getNumBases()) {
5803 CXXRecordDecl *D = RecordDecls.back();
5804 assert(D->getNumBases() == 1 &&
5805 "HLSL doesn't support multiple inheritance");
5806 RecordDecls.push_back(
5808 }
5809 while (!RecordDecls.empty()) {
5810 CXXRecordDecl *RD = RecordDecls.pop_back_val();
5811 for (auto *FD : RD->fields()) {
5812 if (FD->isUnnamedBitField())
5813 continue;
5814 DeclAccessPair Found = DeclAccessPair::make(FD, FD->getAccess());
5815 DeclarationNameInfo NameInfo(FD->getDeclName(), E->getBeginLoc());
5817 E, false, E->getBeginLoc(), CXXScopeSpec(), FD, Found, NameInfo);
5818 if (Res.isInvalid())
5819 return false;
5820 if (!buildInitializerListImpl(Res.get()))
5821 return false;
5822 }
5823 }
5824 }
// Records end up here after their fields are processed; any other type
// contributes no leaf initializers.
5825 return true;
5826 }
5827
// Phase 2 worker: consumes converted leaves from ArgExprs (via ArgIt) and
// rebuilds an InitListExpr with the nesting of Ty. Leaf types return the next
// converted argument directly.
5828 Expr *generateInitListsImpl(QualType Ty) {
5829 Ty = Ty.getDesugaredType(Ctx);
5830 assert(ArgIt != ArgExprs.end() && "Something is off in iteration!");
5831 if (Ty->isScalarType() || (Ty->isRecordType() && !Ty->isAggregateType()) ||
5833 return *(ArgIt++);
5834
5835 llvm::SmallVector<Expr *> Inits;
5836 if (Ty->isVectorType() || Ty->isConstantArrayType() ||
5837 Ty->isConstantMatrixType()) {
5838 QualType ElTy;
5839 uint64_t Size = 0;
5840 if (auto *ATy = Ty->getAs<VectorType>()) {
5841 ElTy = ATy->getElementType();
5842 Size = ATy->getNumElements();
5843 } else if (auto *CMTy = Ty->getAs<ConstantMatrixType>()) {
5844 ElTy = CMTy->getElementType();
5845 Size = CMTy->getNumElementsFlattened();
5846 } else {
5847 auto *VTy = cast<ConstantArrayType>(Ty.getTypePtr());
5848 ElTy = VTy->getElementType();
5849 Size = VTy->getZExtSize();
5850 }
5851 for (uint64_t I = 0; I < Size; ++I)
5852 Inits.push_back(generateInitListsImpl(ElTy));
5853 }
// Records: same base-first field order as buildInitializerListImpl().
5854 if (auto *RD = Ty->getAsCXXRecordDecl()) {
5855 llvm::SmallVector<CXXRecordDecl *> RecordDecls;
5856 RecordDecls.push_back(RD);
5857 while (RecordDecls.back()->getNumBases()) {
5858 CXXRecordDecl *D = RecordDecls.back();
5859 assert(D->getNumBases() == 1 &&
5860 "HLSL doesn't support multiple inheritance");
5861 RecordDecls.push_back(
5863 }
5864 while (!RecordDecls.empty()) {
5865 CXXRecordDecl *RD = RecordDecls.pop_back_val();
5866 for (auto *FD : RD->fields())
5867 if (!FD->isUnnamedBitField())
5868 Inits.push_back(generateInitListsImpl(FD->getType()));
5869 }
5870 }
// NOTE(review): Inits.front()/back() assume at least one sub-initializer;
// a type producing none (e.g. a record with no named fields) would make
// these accesses invalid — confirm such types cannot reach here.
5871 auto *NewInit = new (Ctx) InitListExpr(Ctx, Inits.front()->getBeginLoc(),
5872 Inits, Inits.back()->getEndLoc());
5873 NewInit->setType(Ty);
5874 return NewInit;
5875 }
5876
5877public:
// Flattened leaf types of InitTy (filled by the constructor).
5878 llvm::SmallVector<QualType, 16> DestTypes;
// Converted leaf initializers, in order (filled by buildInitializerList()).
5879 llvm::SmallVector<Expr *, 16> ArgExprs;
// Prepares to initialize Entity: strips references, unwraps an incomplete
// array to its element type (setting Wrap), and flattens the result into
// DestTypes.
5880 InitListTransformer(Sema &SemaRef, const InitializedEntity &Entity)
5881 : S(SemaRef), Ctx(SemaRef.getASTContext()),
5882 Wrap(Entity.getType()->isIncompleteArrayType()) {
5883 InitTy = Entity.getType().getNonReferenceType();
5884 // When we're generating initializer lists for incomplete array types we
5885 // need to wrap around both when building the initializers and when
5886 // generating the final initializer lists.
5887 if (Wrap) {
5888 assert(InitTy->isIncompleteArrayType());
5889 const IncompleteArrayType *IAT = Ctx.getAsIncompleteArrayType(InitTy);
5890 InitTy = IAT->getElementType();
5891 }
5892 BuildFlattenedTypeList(InitTy, DestTypes);
5893 DstIt = DestTypes.begin();
5894 }
5895
// Phase 1 entry point: flattens and converts E into ArgExprs. Returns false
// on failure.
5896 bool buildInitializerList(Expr *E) { return buildInitializerListImpl(E); }
5897
// Phase 2 entry point; requires ArgExprs to be populated. In the wrapping
// case the result is an InitListExpr of constant array type whose element
// type is InitTy and whose size is the number of regrouped elements.
5898 Expr *generateInitLists() {
5899 assert(!ArgExprs.empty() &&
5900 "Call buildInitializerList to generate argument expressions.");
5901 ArgIt = ArgExprs.begin();
5902 if (!Wrap)
5903 return generateInitListsImpl(InitTy);
5904 llvm::SmallVector<Expr *> Inits;
5905 while (ArgIt != ArgExprs.end())
5906 Inits.push_back(generateInitListsImpl(InitTy));
5907
5908 auto *NewInit = new (Ctx) InitListExpr(Ctx, Inits.front()->getBeginLoc(),
5909 Inits, Inits.back()->getEndLoc());
5910 llvm::APInt ArySize(64, Inits.size());
5911 NewInit->setType(Ctx.getConstantArrayType(InitTy, ArySize, nullptr,
5912 ArraySizeModifier::Normal, 0));
5913 return NewInit;
5914 }
5915};
5916} // namespace
5917
5918// Recursively detect any incomplete array anywhere in the type graph,
5919// including arrays, struct fields, and base classes.
// Works on the canonical type. Returns true if the type is, or transitively
// contains (through array element types, base classes or fields), an
// incomplete array type; returns false for all other types.
5921 Ty = Ty.getCanonicalType();
5922
5923 // Array types: detect an incomplete array at this level (and, presumably, recurse into the element type).
5924 if (const ArrayType *AT = dyn_cast<ArrayType>(Ty)) {
5926 return true;
5928 }
5929
5930 // Record (struct/class) types
5931 if (const auto *RT = Ty->getAs<RecordType>()) {
5932 const RecordDecl *RD = RT->getDecl();
5933
5934 // Walk base classes (for C++ / HLSL structs with inheritance)
5935 if (const auto *CXXRD = dyn_cast<CXXRecordDecl>(RD)) {
5936 for (const CXXBaseSpecifier &Base : CXXRD->bases()) {
5937 if (containsIncompleteArrayType(Base.getType()))
5938 return true;
5939 }
5940 }
5941
5942 // Walk fields
5943 for (const FieldDecl *F : RD->fields()) {
5944 if (containsIncompleteArrayType(F->getType()))
5945 return true;
5946 }
5947 }
5948
5949 return false;
5950}
5951
5953 InitListExpr *Init) {
// Rewrites Init in place so that its (possibly nested) initializers are
// regrouped to match the shape of Entity's type: every initializer is
// flattened to leaf values, each leaf is converted to the corresponding leaf
// type of the destination, and the results are re-nested.
// Returns false on failure. Count mismatches are diagnosed here; conversion
// failures surface from InitListTransformer::buildInitializerList
// (presumably diagnosed by Sema's copy-initialization).
5954 // A scalar-typed initializer list needs no regrouping; leave it unchanged.
5955 if (Init->getType()->isScalarType())
5956 return true;
5957 ASTContext &Ctx = SemaRef.getASTContext();
5958 InitListTransformer ILT(SemaRef, Entity);
5959
5960 for (unsigned I = 0; I < Init->getNumInits(); ++I) {
5961 Expr *E = Init->getInit(I);
// Initializers with side effects are wrapped in an OpaqueValueExpr
// (record-typed ones are materialized first) and stored back into Init;
// presumably so the expression is evaluated once even though flattening
// may reference it for several element accesses.
5962 if (E->HasSideEffects(Ctx)) {
5963 QualType Ty = E->getType();
5964 if (Ty->isRecordType())
5965 E = new (Ctx) MaterializeTemporaryExpr(Ty, E, E->isLValue());
5966 E = new (Ctx) OpaqueValueExpr(E->getBeginLoc(), Ty, E->getValueKind(),
5967 E->getObjectKind(), E);
5968 Init->setInit(I, E);
5969 }
5970 if (!ILT.buildInitializerList(E))
5971 return false;
5972 }
// ExpectedSize: number of flattened leaf types of the destination (for an
// incomplete array, of one element). ActualSize: number of leaf initializers.
5973 size_t ExpectedSize = ILT.DestTypes.size();
5974 size_t ActualSize = ILT.ArgExprs.size();
5975 if (ExpectedSize == 0 && ActualSize == 0)
5976 return true;
5977
5978 // Reject empty initializer if *any* incomplete array exists structurally
5979 if (ActualSize == 0 && containsIncompleteArrayType(Entity.getType())) {
5980 QualType InitTy = Entity.getType().getNonReferenceType();
5981 if (InitTy.hasAddressSpace())
5982 InitTy = SemaRef.getASTContext().removeAddrSpaceQualType(InitTy);
5983
// ActualSize is 0 in this branch, so TooManyOrFew always evaluates to 0.
5984 SemaRef.Diag(Init->getBeginLoc(), diag::err_hlsl_incorrect_num_initializers)
5985 << /*TooManyOrFew=*/(int)(ExpectedSize < ActualSize) << InitTy
5986 << /*ExpectedSize=*/ExpectedSize << /*ActualSize=*/ActualSize;
5987 return false;
5988 }
5989
5990 // We infer size after validating legality.
5991 // For incomplete arrays it is completely arbitrary to choose whether we think
5992 // the user intended fewer or more elements. This implementation assumes that
5993 // the user intended more, and errors that there are too few initializers to
5994 // complete the final element.
5995 if (Entity.getType()->isIncompleteArrayType()) {
// NOTE(review): this assert is the only guard against dividing by zero
// below; in builds without assertions ExpectedSize == 0 would be UB.
5996 assert(ExpectedSize > 0 &&
5997 "The expected size of an incomplete array type must be at least 1.");
// Round ActualSize up to a whole number of elements.
5998 ExpectedSize =
5999 ((ActualSize + ExpectedSize - 1) / ExpectedSize) * ExpectedSize;
6000 }
6001
6002 // An initializer list might be attempting to initialize a reference or
6003 // rvalue-reference. When checking the initializer we should look through
6004 // the reference.
6005 QualType InitTy = Entity.getType().getNonReferenceType();
6006 if (InitTy.hasAddressSpace())
6007 InitTy = SemaRef.getASTContext().removeAddrSpaceQualType(InitTy);
6008 if (ExpectedSize != ActualSize) {
6009 int TooManyOrFew = ActualSize > ExpectedSize ? 1 : 0;
6010 SemaRef.Diag(Init->getBeginLoc(), diag::err_hlsl_incorrect_num_initializers)
6011 << TooManyOrFew << InitTy << ExpectedSize << ActualSize;
6012 return false;
6013 }
6014
6015 // generateInitLists will always return an InitListExpr here, because the
6016 // scalar case is handled above.
6017 auto *NewInit = cast<InitListExpr>(ILT.generateInitLists());
// Replace Init's initializers with the regrouped ones.
6018 Init->resizeInits(Ctx, NewInit->getNumInits());
6019 for (unsigned I = 0; I < NewInit->getNumInits(); ++I)
6020 Init->updateInit(Ctx, I, NewInit->getInit(I));
6021 return true;
6022}
6023
6024static QualType ReportMatrixInvalidMember(Sema &S, StringRef Name,
6025 StringRef Expected,
6026 SourceLocation OpLoc,
6027 SourceLocation CompLoc) {
6028 S.Diag(OpLoc, diag::err_builtin_matrix_invalid_member)
6029 << Name << Expected << SourceRange(CompLoc);
6030 return QualType();
6031}
6032
6035 const IdentifierInfo *CompName,
6036 SourceLocation CompLoc) {
// Validates the HLSL matrix element accessor CompName on a value of constant
// matrix type baseType. The accessor is a concatenation of zero-based "_mRC"
// chunks (indices 0-3) or one-based "_RC" chunks (indices 1-4); the form is
// chosen by the first chunk.
// On success, returns the matrix element type for a single component, or an
// ext vector of 2-4 elements (presumably preferring a matching ext-vector
// typedef from S.ExtVectorDecls). VK is set to VK_PRValue when any element is
// named more than once.
// On failure, emits a diagnostic and returns a null QualType.
6037 const auto *MT = baseType->castAs<ConstantMatrixType>();
6038 StringRef AccessorName = CompName->getName();
6039 assert(!AccessorName.empty() && "Matrix Accessor must have a name");
6040
6041 unsigned Rows = MT->getNumRows();
6042 unsigned Cols = MT->getNumColumns();
6043 bool IsZeroBasedAccessor = false;
6044 unsigned ChunkLen = 0;
6045 if (AccessorName.size() < 2)
6046 return ReportMatrixInvalidMember(S, AccessorName,
6047 "length 4 for zero based: \'_mRC\' or "
6048 "length 3 for one-based: \'_RC\' accessor",
6049 OpLoc, CompLoc);
6050
// The first chunk's prefix selects the accessor form for the whole name.
6051 if (AccessorName[0] == '_') {
6052 if (AccessorName[1] == 'm') {
6053 IsZeroBasedAccessor = true;
6054 ChunkLen = 4; // zero-based: "_mRC"
6055 } else {
6056 ChunkLen = 3; // one-based: "_RC"
6057 }
6058 } else
6060 S, AccessorName, "zero based: \'_mRC\' or one-based: \'_RC\' accessor",
6061 OpLoc, CompLoc);
6062
// The name must consist of whole chunks; this also guarantees the
// Chunk[...] reads in the loop below stay within the name.
6063 if (AccessorName.size() % ChunkLen != 0) {
6064 const llvm::StringRef Expected = IsZeroBasedAccessor
6065 ? "zero based: '_mRC' accessor"
6066 : "one-based: '_RC' accessor";
6067
6068 return ReportMatrixInvalidMember(S, AccessorName, Expected, OpLoc, CompLoc);
6069 }
6070
6071 auto isDigit = [](char c) { return c >= '0' && c <= '9'; };
6072 auto isZeroBasedIndex = [](unsigned i) { return i <= 3; };
6073 auto isOneBasedIndex = [](unsigned i) { return i >= 1 && i <= 4; };
6074
// Seen tracks which matrix elements (row-major flat index) were already
// named, to detect repeated components.
6075 bool HasRepeated = false;
6076 SmallVector<bool, 16> Seen(Rows * Cols, false);
6077 unsigned NumComponents = 0;
6078 const char *Begin = AccessorName.data();
6079
6080 for (unsigned I = 0, E = AccessorName.size(); I < E; I += ChunkLen) {
6081 const char *Chunk = Begin + I;
6082 char RowChar = 0, ColChar = 0;
6083 if (IsZeroBasedAccessor) {
6084 // Zero-based: "_mRC"
6085 if (Chunk[0] != '_' || Chunk[1] != 'm') {
// NOTE(review): the offset below compares Bad with Chunk[0] by value, so
// a chunk starting with "__" reports offset I + 1 even though the bad
// character is Chunk[1] — confirm.
6086 char Bad = (Chunk[0] != '_') ? Chunk[0] : Chunk[1];
6088 S, StringRef(&Bad, 1), "\'_m\' prefix",
6089 OpLoc.getLocWithOffset(I + (Bad == Chunk[0] ? 1 : 2)), CompLoc);
6090 }
6091 RowChar = Chunk[2];
6092 ColChar = Chunk[3];
6093 } else {
6094 // One-based: "_RC"
6095 if (Chunk[0] != '_')
6097 S, StringRef(&Chunk[0], 1), "\'_\' prefix",
6098 OpLoc.getLocWithOffset(I + 1), CompLoc);
6099 RowChar = Chunk[1];
6100 ColChar = Chunk[2];
6101 }
6102
6103 // Must be digits.
6104 bool IsDigitsError = false;
6105 if (!isDigit(RowChar)) {
6106 unsigned BadPos = IsZeroBasedAccessor ? 2 : 1;
6107 ReportMatrixInvalidMember(S, StringRef(&RowChar, 1), "row as integer",
6108 OpLoc.getLocWithOffset(I + BadPos + 1),
6109 CompLoc);
6110 IsDigitsError = true;
6111 }
6112
6113 if (!isDigit(ColChar)) {
6114 unsigned BadPos = IsZeroBasedAccessor ? 3 : 2;
6115 ReportMatrixInvalidMember(S, StringRef(&ColChar, 1), "column as integer",
6116 OpLoc.getLocWithOffset(I + BadPos + 1),
6117 CompLoc);
6118 IsDigitsError = true;
6119 }
6120 if (IsDigitsError)
6121 return QualType();
6122
6123 unsigned Row = RowChar - '0';
6124 unsigned Col = ColChar - '0';
6125
6126 bool HasIndexingError = false;
6127 if (IsZeroBasedAccessor) {
6128 // 0-based [0..3]
6129 if (!isZeroBasedIndex(Row)) {
6130 S.Diag(OpLoc, diag::err_hlsl_matrix_element_not_in_bounds)
6131 << /*row*/ 0 << /*zero-based*/ 0 << SourceRange(CompLoc);
6132 HasIndexingError = true;
6133 }
6134 if (!isZeroBasedIndex(Col)) {
6135 S.Diag(OpLoc, diag::err_hlsl_matrix_element_not_in_bounds)
6136 << /*col*/ 1 << /*zero-based*/ 0 << SourceRange(CompLoc);
6137 HasIndexingError = true;
6138 }
6139 } else {
6140 // 1-based [1..4]
6141 if (!isOneBasedIndex(Row)) {
6142 S.Diag(OpLoc, diag::err_hlsl_matrix_element_not_in_bounds)
6143 << /*row*/ 0 << /*one-based*/ 1 << SourceRange(CompLoc);
6144 HasIndexingError = true;
6145 }
6146 if (!isOneBasedIndex(Col)) {
6147 S.Diag(OpLoc, diag::err_hlsl_matrix_element_not_in_bounds)
6148 << /*col*/ 1 << /*one-based*/ 1 << SourceRange(CompLoc);
6149 HasIndexingError = true;
6150 }
6151 // Convert to 0-based after range checking.
6152 --Row;
6153 --Col;
6154 }
6155
6156 if (HasIndexingError)
6157 return QualType();
6158
6159 // Note: matrix swizzle index is hard coded. That means Row and Col can
6160 // potentially be larger than Rows and Cols if matrix size is less than
6161 // the max index size.
// NOTE(review): Row/Col are already 0-based here, so for one-based
// accessors the diagnostics below report an index one less than the user
// wrote — confirm this is intended.
6162 bool HasBoundsError = false;
6163 if (Row >= Rows) {
6164 Diag(OpLoc, diag::err_hlsl_matrix_index_out_of_bounds)
6165 << /*Row*/ 0 << Row << Rows << SourceRange(CompLoc);
6166 HasBoundsError = true;
6167 }
6168 if (Col >= Cols) {
6169 Diag(OpLoc, diag::err_hlsl_matrix_index_out_of_bounds)
6170 << /*Col*/ 1 << Col << Cols << SourceRange(CompLoc);
6171 HasBoundsError = true;
6172 }
6173 if (HasBoundsError)
6174 return QualType();
6175
6176 unsigned FlatIndex = Row * Cols + Col;
6177 if (Seen[FlatIndex])
6178 HasRepeated = true;
6179 Seen[FlatIndex] = true;
6180 ++NumComponents;
6181 }
// At most four components may be selected (the result is a scalar or a
// vector of up to 4 elements).
6182 if (NumComponents == 0 || NumComponents > 4) {
6183 S.Diag(OpLoc, diag::err_hlsl_matrix_swizzle_invalid_length)
6184 << NumComponents << SourceRange(CompLoc);
6185 return QualType();
6186 }
6187
6188 QualType ElemTy = MT->getElementType();
6189 if (NumComponents == 1)
6190 return ElemTy;
6191 QualType VT = S.Context.getExtVectorType(ElemTy, NumComponents);
// Presumably a swizzle naming the same element twice must not be assignable,
// hence the downgrade to an rvalue.
6192 if (HasRepeated)
6193 VK = VK_PRValue;
6194
6195 for (Sema::ExtVectorDeclsType::iterator
6197 E = S.ExtVectorDecls.end();
6198 I != E; ++I) {
6199 if ((*I)->getUnderlyingType() == VT)
6201 /*Qualifier=*/std::nullopt, *I);
6202 }
6203
6204 return VT;
6205}
6206
// HLSL-specific handling of VDecl's initializer Init:
//  - local (non-global) resource variables have their resource binding
//    tracked via trackLocalResource();
//  - if VDecl carries HLSLVkConstantIdAttr, Init must be a constant
//    expression; Init is then replaced by a builtin call taking the constant
//    ID and the original initializer, cast back to Init's type when
//    (presumably) the builtin's result type differs.
// Returns false, and marks VDecl invalid, when a specialization constant's
// initializer is not a constant expression; returns true otherwise.
6208 // If initializing a local resource, track the resource binding it is using
6209 if (VDecl->getType()->isHLSLResourceRecord() && !VDecl->hasGlobalStorage())
6210 trackLocalResource(VDecl, Init);
6211
6212 const HLSLVkConstantIdAttr *ConstIdAttr =
6213 VDecl->getAttr<HLSLVkConstantIdAttr>();
6214 if (!ConstIdAttr)
6215 return true;
6216
6217 ASTContext &Context = SemaRef.getASTContext();
6218
// A specialization constant's default value must be a constant expression.
6219 APValue InitValue;
6220 if (!Init->isCXX11ConstantExpr(Context, &InitValue)) {
6221 Diag(VDecl->getLocation(), diag::err_specialization_const);
6222 VDecl->setInvalidDecl();
6223 return false;
6224 }
6225
// Select the specialization-constant builtin, presumably based on VDecl's
// type (cf. getSpecConstBuiltinId).
6226 Builtin::ID BID =
6228
6229 // Argument 1: The ID from the attribute
6230 int ConstantID = ConstIdAttr->getId();
// NOTE(review): ConstantID is a signed int but IDVal is built as an unsigned
// APInt; presumably the attribute guarantees a non-negative ID — confirm.
6231 llvm::APInt IDVal(Context.getIntWidth(Context.IntTy), ConstantID);
6232 Expr *IdExpr = IntegerLiteral::Create(Context, IDVal, Context.IntTy,
6233 ConstIdAttr->getLocation());
6234
// Argument 2: the original (constant) initializer.
6235 SmallVector<Expr *, 2> Args = {IdExpr, Init};
6236 Expr *C = SemaRef.BuildBuiltinCallExpr(Init->getExprLoc(), BID, Args);
6237 if (C->getType()->getCanonicalTypeUnqualified() !=
6239 C = SemaRef
6240 .BuildCStyleCastExpr(SourceLocation(),
6241 Context.getTrivialTypeSourceInfo(
6242 Init->getType(), Init->getExprLoc()),
6243 SourceLocation(), C)
6244 .get();
6245 }
// NOTE(review): the ExprResult from BuildCStyleCastExpr is not checked for
// invalidity; an invalid result would make C, and thus Init, null — confirm
// this cannot happen here.
6246 Init = C;
6247 return true;
6248}
6249
6251 SourceLocation NameLoc) {
// Builds a specialization of Template from its default template arguments,
// using NameLoc for the (empty) argument list. This applies only to implicit
// templates declared directly in namespace `hlsl` that have exactly one
// template parameter with a default argument. In every other case a null
// QualType is returned, so the caller reports the "requires template
// arguments" diagnostic (see the comment further below).
6252 if (!Template)
6253 return QualType();
6254
// The template must live directly in a named namespace called "hlsl".
6255 DeclContext *DC = Template->getDeclContext();
6256 if (!DC->isNamespace() || !cast<NamespaceDecl>(DC)->getIdentifier() ||
6257 cast<NamespaceDecl>(DC)->getName() != "hlsl")
6258 return QualType();
6259
6260 TemplateParameterList *Params = Template->getTemplateParameters();
6261 if (!Params || Params->size() != 1)
6262 return QualType();
6263
6264 if (!Template->isImplicit())
6265 return QualType();
6266
6267 // We manually extract default arguments here instead of letting
6268 // CheckTemplateIdType handle it. This ensures that for resource types that
6269 // lack a default argument (like Buffer), we return a null QualType, which
6270 // triggers the "requires template arguments" error rather than a less
6271 // descriptive "too few template arguments" error.
6272 TemplateArgumentListInfo TemplateArgs(NameLoc, NameLoc);
6273 for (NamedDecl *P : *Params) {
6274 if (auto *TTP = dyn_cast<TemplateTypeParmDecl>(P)) {
6275 if (TTP->hasDefaultArgument()) {
6276 TemplateArgs.addArgument(TTP->getDefaultArgument());
6277 continue;
6278 }
6279 } else if (auto *NTTP = dyn_cast<NonTypeTemplateParmDecl>(P)) {
6280 if (NTTP->hasDefaultArgument()) {
6281 TemplateArgs.addArgument(NTTP->getDefaultArgument());
6282 continue;
6283 }
6284 } else if (auto *TTPD = dyn_cast<TemplateTemplateParmDecl>(P)) {
6285 if (TTPD->hasDefaultArgument()) {
6286 TemplateArgs.addArgument(TTPD->getDefaultArgument());
6287 continue;
6288 }
6289 }
// This parameter has no default argument: give up.
6290 return QualType();
6291 }
6292
// Form the specialization from the collected default arguments.
6293 return SemaRef.CheckTemplateIdType(
6295 TemplateArgs, nullptr, /*ForNestedNameSpecifier=*/false);
6296}
Defines the clang::ASTContext interface.
Defines enum values for all the target-independent builtin functions.
llvm::dxil::ResourceClass ResourceClass
Defines the C++ Decl subclasses, other than those for templates (found in DeclTemplate....
TokenType getType() const
Returns the token's type, e.g.
FormatToken * Previous
The previous token in the unwrapped line.
Defines the clang::IdentifierInfo, clang::IdentifierTable, and clang::Selector interfaces.
#define X(type, name)
Definition Value.h:97
Forward-declares and imports various common LLVM datatypes that clang wants to use unqualified.
llvm::SmallVector< std::pair< const MemRegion *, SVal >, 4 > Bindings
static bool CheckArgTypeMatches(Sema *S, Expr *Arg, QualType ExpectedType)
static void BuildFlattenedTypeList(QualType BaseTy, llvm::SmallVectorImpl< QualType > &List)
static bool CheckUnsignedIntRepresentation(Sema *S, SourceLocation Loc, int ArgOrdinal, clang::QualType PassedType)
static bool containsIncompleteArrayType(QualType Ty)
static QualType handleIntegerVectorBinOpConversion(Sema &SemaRef, ExprResult &LHS, ExprResult &RHS, QualType LHSType, QualType RHSType, QualType LElTy, QualType RElTy, bool IsCompAssign)
static bool convertToRegisterType(StringRef Slot, RegisterType *RT)
Definition SemaHLSL.cpp:82
static StringRef createRegisterString(ASTContext &AST, RegisterType RegType, unsigned N)
Definition SemaHLSL.cpp:184
static bool CheckWaveActive(Sema *S, CallExpr *TheCall)
static void createHostLayoutStructForBuffer(Sema &S, HLSLBufferDecl *BufDecl)
Definition SemaHLSL.cpp:609
static void castVector(Sema &S, ExprResult &E, QualType &Ty, unsigned Sz)
static QualType ReportMatrixInvalidMember(Sema &S, StringRef Name, StringRef Expected, SourceLocation OpLoc, SourceLocation CompLoc)
static bool CheckBoolSelect(Sema *S, CallExpr *TheCall)
static unsigned calculateLegacyCbufferFieldAlign(const ASTContext &Context, QualType T)
Definition SemaHLSL.cpp:246
static bool isZeroSizedArray(const ConstantArrayType *CAT)
Definition SemaHLSL.cpp:365
static bool DiagnoseHLSLRegisterAttribute(Sema &S, SourceLocation &ArgLoc, Decl *D, RegisterType RegType, bool SpecifiedSpace)
static bool hasConstantBufferLayout(QualType QT)
static FieldDecl * createFieldForHostLayoutStruct(Sema &S, const Type *Ty, IdentifierInfo *II, CXXRecordDecl *LayoutStruct)
Definition SemaHLSL.cpp:517
static bool CheckUnsignedIntVecRepresentation(Sema *S, SourceLocation Loc, int ArgOrdinal, clang::QualType PassedType)
SampleKind
static bool isInvalidConstantBufferLeafElementType(const Type *Ty)
Definition SemaHLSL.cpp:399
static bool CheckCalculateLodBuiltin(Sema &S, CallExpr *TheCall)
static Builtin::ID getSpecConstBuiltinId(const Type *Type)
Definition SemaHLSL.cpp:150
static bool CheckFloatingOrIntRepresentation(Sema *S, SourceLocation Loc, int ArgOrdinal, clang::QualType PassedType)
static const Type * createHostLayoutType(Sema &S, const Type *Ty)
Definition SemaHLSL.cpp:490
static bool CheckAnyScalarOrVector(Sema *S, CallExpr *TheCall, unsigned ArgIndex)
static const HLSLAttributedResourceType * getResourceArrayHandleType(QualType QT)
Definition SemaHLSL.cpp:381
static IdentifierInfo * getHostLayoutStructName(Sema &S, NamedDecl *BaseDecl, bool MustBeUnique)
Definition SemaHLSL.cpp:455
static void addImplicitBindingAttrToDecl(Sema &S, Decl *D, RegisterType RT, uint32_t ImplicitBindingOrderID)
Definition SemaHLSL.cpp:653
static void SetElementTypeAsReturnType(Sema *S, CallExpr *TheCall, QualType ReturnType)
static unsigned calculateLegacyCbufferSize(const ASTContext &Context, QualType T)
Definition SemaHLSL.cpp:265
static bool CheckLoadLevelBuiltin(Sema &S, CallExpr *TheCall)
static RegisterType getRegisterType(ResourceClass RC)
Definition SemaHLSL.cpp:62
static bool ValidateRegisterNumber(uint64_t SlotNum, Decl *TheDecl, ASTContext &Ctx, RegisterType RegTy)
static bool isVkPipelineBuiltin(const ASTContext &AstContext, FunctionDecl *FD, HLSLAppliedSemanticAttr *Semantic, bool IsInput)
Definition SemaHLSL.cpp:841
static bool CheckVectorElementCount(Sema *S, QualType PassedType, QualType BaseType, unsigned ExpectedCount, SourceLocation Loc)
static bool CheckModifiableLValue(Sema *S, CallExpr *TheCall, unsigned ArgIndex)
static QualType castElement(Sema &S, ExprResult &E, QualType Ty)
static char getRegisterTypeChar(RegisterType RT)
Definition SemaHLSL.cpp:114
static bool CheckNotBoolScalarOrVector(Sema *S, CallExpr *TheCall, unsigned ArgIndex)
static CXXRecordDecl * findRecordDeclInContext(IdentifierInfo *II, DeclContext *DC)
Definition SemaHLSL.cpp:438
static bool CheckWavePrefix(Sema *S, CallExpr *TheCall)
static bool CheckExpectedBitWidth(Sema *S, CallExpr *TheCall, unsigned ArgOrdinal, unsigned Width)
static bool hasCounterHandle(const CXXRecordDecl *RD)
static bool CheckVectorSelect(Sema *S, CallExpr *TheCall)
static QualType handleFloatVectorBinOpConversion(Sema &SemaRef, ExprResult &LHS, ExprResult &RHS, QualType LHSType, QualType RHSType, QualType LElTy, QualType RElTy, bool IsCompAssign)
static ResourceClass getResourceClass(RegisterType RT)
Definition SemaHLSL.cpp:132
static CXXRecordDecl * createHostLayoutStruct(Sema &S, CXXRecordDecl *StructDecl)
Definition SemaHLSL.cpp:544
static bool CheckScalarOrVector(Sema *S, CallExpr *TheCall, QualType Scalar, unsigned ArgIndex)
static bool CheckSamplingBuiltin(Sema &S, CallExpr *TheCall, SampleKind Kind)
static bool CheckScalarOrVectorOrMatrix(Sema *S, CallExpr *TheCall, QualType Scalar, unsigned ArgIndex)
static bool CheckFloatRepresentation(Sema *S, SourceLocation Loc, int ArgOrdinal, clang::QualType PassedType)
static bool CheckAnyDoubleRepresentation(Sema *S, SourceLocation Loc, int ArgOrdinal, clang::QualType PassedType)
static bool requiresImplicitBufferLayoutStructure(const CXXRecordDecl *RD)
Definition SemaHLSL.cpp:418
static bool CheckResourceHandle(Sema *S, CallExpr *TheCall, unsigned ArgIndex, llvm::function_ref< bool(const HLSLAttributedResourceType *ResType)> Check=nullptr)
static void validatePackoffset(Sema &S, HLSLBufferDecl *BufDecl)
Definition SemaHLSL.cpp:312
static bool IsDefaultBufferConstantDecl(const ASTContext &Ctx, VarDecl *VD)
HLSLResourceBindingAttr::RegisterType RegisterType
Definition SemaHLSL.cpp:57
static CastKind getScalarCastKind(ASTContext &Ctx, QualType DestTy, QualType SrcTy)
static bool CheckGatherBuiltin(Sema &S, CallExpr *TheCall, bool IsCmp)
static bool isValidWaveSizeValue(unsigned Value)
static bool isResourceRecordTypeOrArrayOf(QualType Ty)
Definition SemaHLSL.cpp:372
static bool AccumulateHLSLResourceSlots(QualType Ty, uint64_t &StartSlot, const uint64_t &Limit, const ResourceClass ResClass, ASTContext &Ctx, uint64_t ArrayCount=1)
static bool CheckNoDoubleVectors(Sema *S, SourceLocation Loc, int ArgOrdinal, clang::QualType PassedType)
static bool ValidateMultipleRegisterAnnotations(Sema &S, Decl *TheDecl, RegisterType regType)
static bool CheckTextureSamplerAndLocation(Sema &S, CallExpr *TheCall)
static bool DiagnoseLocalRegisterBinding(Sema &S, SourceLocation &ArgLoc, Decl *D, RegisterType RegType, bool SpecifiedSpace)
static bool CheckIndexType(Sema *S, CallExpr *TheCall, unsigned IndexArgIndex)
This file declares semantic analysis for HLSL constructs.
Defines the clang::SourceLocation class and associated facilities.
Defines various enumerations that describe declaration and type specifiers.
C Language Family Type Representation.
Defines the clang::TypeLoc interface and its subclasses.
C Language Family Type Representation.
static const TypeInfo & getInfo(unsigned id)
Definition Types.cpp:44
__device__ __2f16 float c
return(__x > > __y)|(__x<<(32 - __y))
APValue - This class implements a discriminated union of [uninitialized] [APSInt] [APFloat],...
Definition APValue.h:122
virtual bool HandleTopLevelDecl(DeclGroupRef D)
HandleTopLevelDecl - Handle the specified top-level declaration.
Holds long-lived AST nodes (such as types and decls) that can be referred to throughout the semantic ...
Definition ASTContext.h:226
unsigned getIntWidth(QualType T) const
int getIntegerTypeOrder(QualType LHS, QualType RHS) const
Return the highest ranked integer type, see C99 6.3.1.8p1.
CanQualType FloatTy
QualType getPointerType(QualType T) const
Return the uniqued reference to the type for a pointer to the specified type.
const IncompleteArrayType * getAsIncompleteArrayType(QualType T) const
IdentifierTable & Idents
Definition ASTContext.h:798
QualType getConstantArrayType(QualType EltTy, const llvm::APInt &ArySize, const Expr *SizeExpr, ArraySizeModifier ASM, unsigned IndexTypeQuals) const
Return the unique reference to the type for a constant array of the specified element type.
QualType getBaseElementType(const ArrayType *VAT) const
Return the innermost element type of an array type.
int getFloatingTypeOrder(QualType LHS, QualType RHS) const
Compare the rank of the two specified floating point types, ignoring the domain of the type (i....
CanQualType BoolTy
TypeSourceInfo * getTrivialTypeSourceInfo(QualType T, SourceLocation Loc=SourceLocation()) const
Allocate a TypeSourceInfo where all locations have been initialized to a given location,...
QualType getStringLiteralArrayType(QualType EltTy, unsigned Length) const
Return a type for a constant array for a string literal of the specified element type and length.
CanQualType CharTy
CanQualType IntTy
uint64_t getTypeSize(QualType T) const
Return the size of the specified (complete) type T, in bits.
CharUnits getTypeSizeInChars(QualType T) const
Return the size of the specified (complete) type T, in characters.
CanQualType UnsignedIntTy
QualType getTypedefType(ElaboratedTypeKeyword Keyword, NestedNameSpecifier Qualifier, const TypedefNameDecl *Decl, QualType UnderlyingType=QualType(), std::optional< bool > TypeMatchesDeclOrNone=std::nullopt) const
Return the unique reference to the type for the specified typedef-name decl.
llvm::StringRef backupStr(llvm::StringRef S) const
Definition ASTContext.h:880
QualType getSizeType() const
Return the unique type for "size_t" (C99 7.17), defined in <stddef.h>.
QualType getExtVectorType(QualType VectorType, unsigned NumElts) const
Return the unique reference to an extended vector type of the specified element type and size.
const TargetInfo & getTargetInfo() const
Definition ASTContext.h:917
QualType getHLSLAttributedResourceType(QualType Wrapped, QualType Contained, const HLSLAttributedResourceType::Attributes &Attrs)
QualType getAddrSpaceQualType(QualType T, LangAS AddressSpace) const
Return the uniqued reference to the type for an address space qualified type with the specified type ...
CanQualType getCanonicalTagType(const TagDecl *TD) const
static bool hasSameUnqualifiedType(QualType T1, QualType T2)
Determine whether the given types are equivalent after cvr-qualifiers have been removed.
QualType getConstantMatrixType(QualType ElementType, unsigned NumRows, unsigned NumColumns) const
Return the unique reference to the matrix type of the specified element type and size.
unsigned getTypeAlign(QualType T) const
Return the ABI-specified alignment of a (complete) type T, in bits.
PtrTy get() const
Definition Ownership.h:171
bool isInvalid() const
Definition Ownership.h:167
Represents an array type, per C99 6.7.5.2 - Array Declarators.
Definition TypeBase.h:3772
QualType getElementType() const
Definition TypeBase.h:3784
Attr - This represents one attribute.
Definition Attr.h:46
attr::Kind getKind() const
Definition Attr.h:92
SourceLocation getLocation() const
Definition Attr.h:99
SourceLocation getScopeLoc() const
const IdentifierInfo * getScopeName() const
SourceLocation getLoc() const
const IdentifierInfo * getAttrName() const
Represents a base class of a C++ class.
Definition DeclCXX.h:146
QualType getType() const
Retrieves the type of the base class.
Definition DeclCXX.h:249
Represents a static or instance method of a struct/union/class.
Definition DeclCXX.h:2136
Represents a C++ struct/union/class.
Definition DeclCXX.h:258
bool isHLSLIntangible() const
Returns true if the class contains HLSL intangible type, either as a field or in base class.
Definition DeclCXX.h:1556
static CXXRecordDecl * Create(const ASTContext &C, TagKind TK, DeclContext *DC, SourceLocation StartLoc, SourceLocation IdLoc, IdentifierInfo *Id, CXXRecordDecl *PrevDecl=nullptr)
Definition DeclCXX.cpp:132
void setBases(CXXBaseSpecifier const *const *Bases, unsigned NumBases)
Sets the base classes of this struct or class.
Definition DeclCXX.cpp:184
base_class_iterator bases_end()
Definition DeclCXX.h:617
void completeDefinition() override
Indicates that the definition of this class is now complete.
Definition DeclCXX.cpp:2249
base_class_range bases()
Definition DeclCXX.h:608
unsigned getNumBases() const
Retrieves the number of base classes of this class.
Definition DeclCXX.h:602
base_class_iterator bases_begin()
Definition DeclCXX.h:615
bool isEmpty() const
Determine whether this is an empty class in the sense of (C++11 [meta.unary.prop]).
Definition DeclCXX.h:1186
CallExpr - Represents a function call (C99 6.5.2.2, C++ [expr.call]).
Definition Expr.h:2946
Expr * getArg(unsigned Arg)
getArg - Return the specified argument.
Definition Expr.h:3150
SourceLocation getBeginLoc() const
Definition Expr.h:3280
static CallExpr * Create(const ASTContext &Ctx, Expr *Fn, ArrayRef< Expr * > Args, QualType Ty, ExprValueKind VK, SourceLocation RParenLoc, FPOptionsOverride FPFeatures, unsigned MinNumArgs=0, ADLCallKind UsesADL=NotADL)
Create a call expression.
Definition Expr.cpp:1517
FunctionDecl * getDirectCallee()
If the callee is a FunctionDecl, return it. Otherwise return null.
Definition Expr.h:3129
Expr * getCallee()
Definition Expr.h:3093
unsigned getNumArgs() const
getNumArgs - Return the number of actual arguments to this call.
Definition Expr.h:3137
Decl * getCalleeDecl()
Definition Expr.h:3123
QualType withConst() const
Retrieves a version of this type with const applied.
const T * getTypePtr() const
Retrieve the underlying type pointer, which refers to a canonical type.
QuantityType getQuantity() const
getQuantity - Get the raw integer representation of this quantity.
Definition CharUnits.h:185
Represents the canonical version of C arrays with a specified constant size.
Definition TypeBase.h:3810
bool isZeroSize() const
Return true if the size is zero.
Definition TypeBase.h:3880
llvm::APInt getSize() const
Return the constant array size as an APInt.
Definition TypeBase.h:3866
uint64_t getZExtSize() const
Return the size zero-extended as a uint64_t.
Definition TypeBase.h:3886
Represents a concrete matrix type with constant number of rows and columns.
Definition TypeBase.h:4437
unsigned getNumColumns() const
Returns the number of columns in the matrix.
Definition TypeBase.h:4456
static DeclAccessPair make(NamedDecl *D, AccessSpecifier AS)
DeclContext - This is used only as base class of specific decl types that can act as declaration cont...
Definition DeclBase.h:1462
bool isNamespace() const
Definition DeclBase.h:2211
lookup_result lookup(DeclarationName Name) const
lookup - Find the declarations (if any) with the given Name in this context.
bool isTranslationUnit() const
Definition DeclBase.h:2198
void addDecl(Decl *D)
Add the declaration D into this context.
decl_range decls() const
decls_begin/decls_end - Iterate over the declarations stored in this context.
Definition DeclBase.h:2386
DeclContext * getNonTransparentContext()
A reference to a declared variable, function, enum, etc.
Definition Expr.h:1273
static DeclRefExpr * Create(const ASTContext &Context, NestedNameSpecifierLoc QualifierLoc, SourceLocation TemplateKWLoc, ValueDecl *D, bool RefersToEnclosingVariableOrCapture, SourceLocation NameLoc, QualType T, ExprValueKind VK, NamedDecl *FoundD=nullptr, const TemplateArgumentListInfo *TemplateArgs=nullptr, NonOdrUseReason NOUR=NOUR_None)
Definition Expr.cpp:488
ValueDecl * getDecl()
Definition Expr.h:1341
Decl - This represents one declaration (or definition), e.g.
Definition DeclBase.h:86
T * getAttr() const
Definition DeclBase.h:581
ASTContext & getASTContext() const LLVM_READONLY
Definition DeclBase.cpp:547
void addAttr(Attr *A)
attr_iterator attr_end() const
Definition DeclBase.h:550
bool isImplicit() const
isImplicit - Indicates whether the declaration was implicitly generated by the implementation.
Definition DeclBase.h:601
void setInvalidDecl(bool Invalid=true)
setInvalidDecl - Indicates the Decl had a semantic error.
Definition DeclBase.cpp:178
bool isInExportDeclContext() const
Whether this declaration was exported in a lexical context.
attr_iterator attr_begin() const
Definition DeclBase.h:547
DeclContext * getNonTransparentDeclContext()
Return the non transparent context.
SourceLocation getLocation() const
Definition DeclBase.h:447
void setImplicit(bool I=true)
Definition DeclBase.h:602
DeclContext * getDeclContext()
Definition DeclBase.h:456
attr_range attrs() const
Definition DeclBase.h:543
AccessSpecifier getAccess() const
Definition DeclBase.h:515
SourceLocation getBeginLoc() const LLVM_READONLY
Definition DeclBase.h:439
void dropAttr()
Definition DeclBase.h:564
bool hasAttr() const
Definition DeclBase.h:585
The name of a declaration.
Represents a ValueDecl that came out of a declarator.
Definition Decl.h:780
SourceLocation getBeginLoc() const LLVM_READONLY
Definition Decl.h:831
This represents one expression.
Definition Expr.h:112
bool isIntegerConstantExpr(const ASTContext &Ctx) const
void setType(QualType t)
Definition Expr.h:145
ExprValueKind getValueKind() const
getValueKind - The value kind that this expression produces.
Definition Expr.h:447
Expr * IgnoreParenImpCasts() LLVM_READONLY
Skip past any parentheses and implicit casts which might surround this expression until reaching a fi...
Definition Expr.cpp:3090
Expr * IgnoreParens() LLVM_READONLY
Skip past any parentheses which might surround this expression until reaching a fixed point.
Definition Expr.cpp:3086
std::optional< llvm::APSInt > getIntegerConstantExpr(const ASTContext &Ctx) const
isIntegerConstantExpr - Return the value if this expression is a valid integer constant expression.
bool isPRValue() const
Definition Expr.h:285
bool isLValue() const
isLValue - True if this expression is an "l-value" according to the rules of the current language.
Definition Expr.h:284
ExprObjectKind getObjectKind() const
getObjectKind - The object kind that this expression produces.
Definition Expr.h:454
bool HasSideEffects(const ASTContext &Ctx, bool IncludePossibleEffects=true) const
HasSideEffects - This routine returns true for all those expressions which have any effect other than...
Definition Expr.cpp:3688
void setValueKind(ExprValueKind Cat)
setValueKind - Set the value kind produced by this expression.
Definition Expr.h:464
SourceLocation getExprLoc() const LLVM_READONLY
getExprLoc - Return the preferred location for the arrow when diagnosing a problem with a generic exp...
Definition Expr.cpp:277
@ MLV_Valid
Definition Expr.h:306
QualType getType() const
Definition Expr.h:144
ExtVectorType - Extended vector type.
Definition TypeBase.h:4317
Represents difference between two FPOptions values.
Represents a member of a struct/union/class.
Definition Decl.h:3175
static FieldDecl * Create(const ASTContext &C, DeclContext *DC, SourceLocation StartLoc, SourceLocation IdLoc, const IdentifierInfo *Id, QualType T, TypeSourceInfo *TInfo, Expr *BW, bool Mutable, InClassInitStyle InitStyle)
Definition Decl.cpp:4702
static FixItHint CreateReplacement(CharSourceRange RemoveRange, StringRef Code)
Create a code modification hint that replaces the given source range with the given code string.
Definition Diagnostic.h:141
Represents a function declaration or definition.
Definition Decl.h:2015
const ParmVarDecl * getParamDecl(unsigned i) const
Definition Decl.h:2812
Stmt * getBody(const FunctionDecl *&Definition) const
Retrieve the body (definition) of the function.
Definition Decl.cpp:3281
bool isThisDeclarationADefinition() const
Returns whether this specific declaration of the function is also a definition that does not contain ...
Definition Decl.h:2329
QualType getReturnType() const
Definition Decl.h:2860
ArrayRef< ParmVarDecl * > parameters() const
Definition Decl.h:2789
bool isTemplateInstantiation() const
Determines if the given function was instantiated from a function template.
Definition Decl.cpp:4259
redecl_range redecls() const
Returns an iterator range for all the redeclarations of the same decl.
unsigned getNumParams() const
Return the number of parameters this function must have based on its FunctionType.
Definition Decl.cpp:3828
DeclarationNameInfo getNameInfo() const
Definition Decl.h:2226
bool hasBody(const FunctionDecl *&Definition) const
Returns true if the function has a body.
Definition Decl.cpp:3201
bool isDefined(const FunctionDecl *&Definition, bool CheckForPendingFriendDefinition=false) const
Returns true if the function has a definition that does not need to be instantiated.
Definition Decl.cpp:3248
HLSLBufferDecl - Represent a cbuffer or tbuffer declaration.
Definition Decl.h:5211
static HLSLBufferDecl * Create(ASTContext &C, DeclContext *LexicalParent, bool CBuffer, SourceLocation KwLoc, IdentifierInfo *ID, SourceLocation IDLoc, SourceLocation LBrace)
Definition Decl.cpp:5907
void addLayoutStruct(CXXRecordDecl *LS)
Definition Decl.cpp:5947
void setHasValidPackoffset(bool PO)
Definition Decl.h:5256
static HLSLBufferDecl * CreateDefaultCBuffer(ASTContext &C, DeclContext *LexicalParent, ArrayRef< Decl * > DefaultCBufferDecls)
Definition Decl.cpp:5930
buffer_decl_range buffer_decls() const
Definition Decl.h:5286
static HLSLOutArgExpr * Create(const ASTContext &C, QualType Ty, OpaqueValueExpr *Base, OpaqueValueExpr *OpV, Expr *WB, bool IsInOut)
Definition Expr.cpp:5645
static HLSLRootSignatureDecl * Create(ASTContext &C, DeclContext *DC, SourceLocation Loc, IdentifierInfo *ID, llvm::dxbc::RootSignatureVersion Version, ArrayRef< llvm::hlsl::rootsig::RootElement > RootElements)
Definition Decl.cpp:5993
One of these records is kept for each identifier that is lexed.
StringRef getName() const
Return the actual identifier string.
A simple pair of identifier info and location.
SourceLocation getLoc() const
IdentifierInfo * getIdentifierInfo() const
IdentifierInfo & get(StringRef Name)
Return the identifier token info for the specified named identifier.
ImplicitCastExpr - Allows us to explicitly represent implicit type conversions, which have no direct ...
Definition Expr.h:3856
static ImplicitCastExpr * Create(const ASTContext &Context, QualType T, CastKind Kind, Expr *Operand, const CXXCastPath *BasePath, ExprValueKind Cat, FPOptionsOverride FPO)
Definition Expr.cpp:2073
Describes an C or C++ initializer list.
Definition Expr.h:5302
Describes an entity that is being initialized.
QualType getType() const
Retrieve type being initialized.
static InitializedEntity InitializeParameter(ASTContext &Context, ParmVarDecl *Parm)
Create the initialization entity for a parameter.
static IntegerLiteral * Create(const ASTContext &C, const llvm::APInt &V, QualType type, SourceLocation l)
Returns a new integer literal with value 'V' and type 'type'.
Definition Expr.cpp:975
iterator begin(Source *source, bool LocalOnly=false)
Represents the results of name lookup.
Definition Lookup.h:147
Represents a prvalue temporary that is written into memory so that a reference can bind to it.
Definition ExprCXX.h:4920
Represents a matrix type, as defined in the Matrix Types clang extensions.
Definition TypeBase.h:4387
MemberExpr - [C99 6.5.2.3] Structure and Union Members.
Definition Expr.h:3367
ValueDecl * getMemberDecl() const
Retrieve the member declaration to which this expression refers.
Definition Expr.h:3450
Expr * getBase() const
Definition Expr.h:3444
This represents a decl that may have a name.
Definition Decl.h:274
IdentifierInfo * getIdentifier() const
Get the identifier that names this declaration, if there is one.
Definition Decl.h:295
StringRef getName() const
Get the name of identifier for this declaration as a StringRef.
Definition Decl.h:301
DeclarationName getDeclName() const
Get the actual, stored name of the declaration, which may be a special name.
Definition Decl.h:340
A C++ nested-name-specifier augmented with source location information.
OpaqueValueExpr - An expression referring to an opaque object of a fixed type and value class.
Definition Expr.h:1181
Represents a parameter to a function.
Definition Decl.h:1805
ParsedAttr - Represents a syntactic attribute.
Definition ParsedAttr.h:119
unsigned getSemanticSpelling() const
If the parsed attribute has a semantic equivalent, and it would have a semantic Spelling enumeration ...
unsigned getMinArgs() const
bool checkExactlyNumArgs(class Sema &S, unsigned Num) const
Check if the attribute has exactly as many args as Num.
IdentifierLoc * getArgAsIdent(unsigned Arg) const
Definition ParsedAttr.h:389
bool hasParsedType() const
Definition ParsedAttr.h:337
const ParsedType & getTypeArg() const
Definition ParsedAttr.h:459
unsigned getNumArgs() const
getNumArgs - Return the number of actual arguments to this attribute.
Definition ParsedAttr.h:371
bool isArgIdent(unsigned Arg) const
Definition ParsedAttr.h:385
Expr * getArgAsExpr(unsigned Arg) const
Definition ParsedAttr.h:383
AttributeCommonInfo::Kind getKind() const
Definition ParsedAttr.h:610
A (possibly-)qualified type.
Definition TypeBase.h:937
void addRestrict()
Add the restrict qualifier to this QualType.
Definition TypeBase.h:1178
QualType getNonLValueExprType(const ASTContext &Context) const
Determine the type of a (typically non-lvalue) expression with the specified result type.
Definition Type.cpp:3630
QualType getDesugaredType(const ASTContext &Context) const
Return the specified type with any "sugar" removed from the type.
Definition TypeBase.h:1302
bool isNull() const
Return true if this QualType doesn't point to a type yet.
Definition TypeBase.h:1004
const Type * getTypePtr() const
Retrieves a pointer to the underlying (unqualified) type.
Definition TypeBase.h:8431
LangAS getAddressSpace() const
Return the address space of this type.
Definition TypeBase.h:8557
QualType getNonReferenceType() const
If Type is a reference type (e.g., const int&), returns the type that the reference refers to ("const...
Definition TypeBase.h:8616
QualType getCanonicalType() const
Definition TypeBase.h:8483
QualType getUnqualifiedType() const
Retrieve the unqualified variant of the given type, removing as little sugar as possible.
Definition TypeBase.h:8525
bool hasAddressSpace() const
Check if this type has any address space qualifier.
Definition TypeBase.h:8552
Represents a struct/union/class.
Definition Decl.h:4342
field_iterator field_end() const
Definition Decl.h:4548
field_range fields() const
Definition Decl.h:4545
bool field_empty() const
Definition Decl.h:4553
field_iterator field_begin() const
Definition Decl.cpp:5277
bool hasBindingInfoForDecl(const VarDecl *VD) const
Definition SemaHLSL.cpp:220
DeclBindingInfo * getDeclBindingInfo(const VarDecl *VD, ResourceClass ResClass)
Definition SemaHLSL.cpp:206
DeclBindingInfo * addDeclBindingInfo(const VarDecl *VD, ResourceClass ResClass)
Definition SemaHLSL.cpp:193
Scope - A scope is a transient data structure that is used while parsing the program.
Definition Scope.h:41
SemaBase(Sema &S)
Definition SemaBase.cpp:7
ASTContext & getASTContext() const
Definition SemaBase.cpp:9
Sema & SemaRef
Definition SemaBase.h:40
SemaDiagnosticBuilder Diag(SourceLocation Loc, unsigned DiagID)
Emit a diagnostic.
Definition SemaBase.cpp:61
ExprResult ActOnOutParamExpr(ParmVarDecl *Param, Expr *Arg)
HLSLRootSignatureDecl * lookupRootSignatureOverrideDecl(DeclContext *DC) const
bool CanPerformElementwiseCast(Expr *Src, QualType DestType)
void handleWaveSizeAttr(Decl *D, const ParsedAttr &AL)
void handleVkLocationAttr(Decl *D, const ParsedAttr &AL)
HLSLAttributedResourceLocInfo TakeLocForHLSLAttribute(const HLSLAttributedResourceType *RT)
void handleSemanticAttr(Decl *D, const ParsedAttr &AL)
bool CanPerformScalarCast(QualType SrcTy, QualType DestTy)
QualType ProcessResourceTypeAttributes(QualType Wrapped)
void handleShaderAttr(Decl *D, const ParsedAttr &AL)
uint32_t getNextImplicitBindingOrderID()
Definition SemaHLSL.h:230
void CheckEntryPoint(FunctionDecl *FD)
Definition SemaHLSL.cpp:960
void handleVkExtBuiltinOutputAttr(Decl *D, const ParsedAttr &AL)
void emitLogicalOperatorFixIt(Expr *LHS, Expr *RHS, BinaryOperatorKind Opc)
T * createSemanticAttr(const AttributeCommonInfo &ACI, std::optional< unsigned > Location)
Definition SemaHLSL.h:182
bool initGlobalResourceDecl(VarDecl *VD)
void ActOnEndOfTranslationUnit(TranslationUnitDecl *TU)
bool initGlobalResourceArrayDecl(VarDecl *VD)
HLSLVkConstantIdAttr * mergeVkConstantIdAttr(Decl *D, const AttributeCommonInfo &AL, int Id)
Definition SemaHLSL.cpp:724
HLSLNumThreadsAttr * mergeNumThreadsAttr(Decl *D, const AttributeCommonInfo &AL, int X, int Y, int Z)
Definition SemaHLSL.cpp:690
void deduceAddressSpace(VarDecl *Decl)
std::pair< IdentifierInfo *, bool > ActOnStartRootSignatureDecl(StringRef Signature)
Computes the unique Root Signature identifier from the given signature, then lookup if there is a pre...
void handlePackOffsetAttr(Decl *D, const ParsedAttr &AL)
bool diagnosePositionType(QualType T, const ParsedAttr &AL)
bool handleInitialization(VarDecl *VDecl, Expr *&Init)
bool diagnoseInputIDType(QualType T, const ParsedAttr &AL)
void handleParamModifierAttr(Decl *D, const ParsedAttr &AL)
bool CheckResourceBinOp(BinaryOperatorKind Opc, Expr *LHSExpr, Expr *RHSExpr, SourceLocation Loc)
bool CanPerformAggregateSplatCast(Expr *Src, QualType DestType)
bool ActOnResourceMemberAccessExpr(MemberExpr *ME)
bool IsScalarizedLayoutCompatible(QualType T1, QualType T2) const
QualType ActOnTemplateShorthand(TemplateDecl *Template, SourceLocation NameLoc)
void diagnoseSystemSemanticAttr(Decl *D, const ParsedAttr &AL, std::optional< unsigned > Index)
void handleRootSignatureAttr(Decl *D, const ParsedAttr &AL)
bool CheckCompatibleParameterABI(FunctionDecl *New, FunctionDecl *Old)
QualType handleVectorBinOpConversion(ExprResult &LHS, ExprResult &RHS, QualType LHSType, QualType RHSType, bool IsCompAssign)
QualType checkMatrixComponent(Sema &S, QualType baseType, ExprValueKind &VK, SourceLocation OpLoc, const IdentifierInfo *CompName, SourceLocation CompLoc)
void handleResourceBindingAttr(Decl *D, const ParsedAttr &AL)
bool IsTypedResourceElementCompatible(QualType T1)
bool transformInitList(const InitializedEntity &Entity, InitListExpr *Init)
void handleNumThreadsAttr(Decl *D, const ParsedAttr &AL)
bool ActOnUninitializedVarDecl(VarDecl *D)
void handleVkExtBuiltinInputAttr(Decl *D, const ParsedAttr &AL)
void ActOnTopLevelFunction(FunctionDecl *FD)
Definition SemaHLSL.cpp:793
bool handleResourceTypeAttr(QualType T, const ParsedAttr &AL)
void handleVkPushConstantAttr(Decl *D, const ParsedAttr &AL)
HLSLShaderAttr * mergeShaderAttr(Decl *D, const AttributeCommonInfo &AL, llvm::Triple::EnvironmentType ShaderType)
Definition SemaHLSL.cpp:760
void ActOnFinishBuffer(Decl *Dcl, SourceLocation RBrace)
Definition SemaHLSL.cpp:663
void handleVkBindingAttr(Decl *D, const ParsedAttr &AL)
HLSLParamModifierAttr * mergeParamModifierAttr(Decl *D, const AttributeCommonInfo &AL, HLSLParamModifierAttr::Spelling Spelling)
Definition SemaHLSL.cpp:773
QualType getInoutParameterType(QualType Ty)
SemaHLSL(Sema &S)
Definition SemaHLSL.cpp:224
void handleVkConstantIdAttr(Decl *D, const ParsedAttr &AL)
Decl * ActOnStartBuffer(Scope *BufferScope, bool CBuffer, SourceLocation KwLoc, IdentifierInfo *Ident, SourceLocation IdentLoc, SourceLocation LBrace)
Definition SemaHLSL.cpp:226
HLSLWaveSizeAttr * mergeWaveSizeAttr(Decl *D, const AttributeCommonInfo &AL, int Min, int Max, int Preferred, int SpelledArgsCount)
Definition SemaHLSL.cpp:704
bool handleRootSignatureElements(ArrayRef< hlsl::RootSignatureElement > Elements)
void ActOnFinishRootSignatureDecl(SourceLocation Loc, IdentifierInfo *DeclIdent, ArrayRef< hlsl::RootSignatureElement > Elements)
Creates the Root Signature decl of the parsed Root Signature elements onto the AST and push it onto c...
void ActOnVariableDeclarator(VarDecl *VD)
bool CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall)
Sema - This implements semantic analysis and AST building for C.
Definition Sema.h:868
@ LookupOrdinaryName
Ordinary name lookup, which finds ordinary names (functions, variables, typedefs, etc....
Definition Sema.h:9415
@ LookupMemberName
Member name lookup, which finds the names of class/struct/union members.
Definition Sema.h:9423
ExtVectorDeclsType ExtVectorDecls
ExtVectorDecls - This is a list all the extended vector types.
Definition Sema.h:4954
ASTContext & Context
Definition Sema.h:1308
ASTContext & getASTContext() const
Definition Sema.h:939
ExprResult ImpCastExprToType(Expr *E, QualType Type, CastKind CK, ExprValueKind VK=VK_PRValue, const CXXCastPath *BasePath=nullptr, CheckedConversionKind CCK=CheckedConversionKind::Implicit)
ImpCastExprToType - If Expr is not of type 'Type', insert an implicit cast.
Definition Sema.cpp:762
const LangOptions & getLangOpts() const
Definition Sema.h:932
SemaHLSL & HLSL()
Definition Sema.h:1483
ExprResult BuildFieldReferenceExpr(Expr *BaseExpr, bool IsArrow, SourceLocation OpLoc, const CXXScopeSpec &SS, FieldDecl *Field, DeclAccessPair FoundDecl, const DeclarationNameInfo &MemberNameInfo)
bool checkArgCountRange(CallExpr *Call, unsigned MinArgCount, unsigned MaxArgCount)
Checks that a call expression's argument count is in the desired range.
ExternalSemaSource * getExternalSource() const
Definition Sema.h:942
ASTConsumer & Consumer
Definition Sema.h:1309
bool checkArgCount(CallExpr *Call, unsigned DesiredArgCount)
Checks that a call expression's argument count is the desired number.
ExprResult CreateBuiltinArraySubscriptExpr(Expr *Base, SourceLocation LLoc, Expr *Idx, SourceLocation RLoc)
bool LookupQualifiedName(LookupResult &R, DeclContext *LookupCtx, bool InUnqualifiedLookup=false)
Perform qualified name lookup into a given context.
ExprResult PerformCopyInitialization(const InitializedEntity &Entity, SourceLocation EqualLoc, ExprResult Init, bool TopLevelOfInitList=false, bool AllowExplicit=false)
ExprResult CreateBuiltinMatrixSubscriptExpr(Expr *Base, Expr *RowIdx, Expr *ColumnIdx, SourceLocation RBLoc)
Encodes a location in the source.
SourceLocation getLocWithOffset(IntTy Offset) const
Return a source location with the specified offset from this SourceLocation.
A trivial tuple used to represent a source range.
SourceLocation getEnd() const
SourceLocation getEndLoc() const LLVM_READONLY
Definition Stmt.cpp:367
void printPretty(raw_ostream &OS, PrinterHelper *Helper, const PrintingPolicy &Policy, unsigned Indentation=0, StringRef NewlineSymbol="\n", const ASTContext *Context=nullptr) const
SourceRange getSourceRange() const LLVM_READONLY
SourceLocation tokens are not useful in isolation - they are low level value objects created/interpre...
Definition Stmt.cpp:343
SourceLocation getBeginLoc() const LLVM_READONLY
Definition Stmt.cpp:355
StringLiteral - This represents a string literal expression, e.g.
Definition Expr.h:1802
static StringLiteral * Create(const ASTContext &Ctx, StringRef Str, StringLiteralKind Kind, bool Pascal, QualType Ty, ArrayRef< SourceLocation > Locs)
This is the "fully general" constructor that allows representation of strings formed from one or more...
Definition Expr.cpp:1188
void startDefinition()
Starts the definition of this tag declaration.
Definition Decl.cpp:4908
bool isUnion() const
Definition Decl.h:3943
bool isClass() const
Definition Decl.h:3942
Exposes information about the current target.
Definition TargetInfo.h:227
TargetOptions & getTargetOpts() const
Retrieve the target options.
Definition TargetInfo.h:327
const llvm::Triple & getTriple() const
Returns the target triple of the primary target.
StringRef getPlatformName() const
Retrieve the name of the platform as it is used in the availability attribute.
VersionTuple getPlatformMinVersion() const
Retrieve the minimum desired version of the platform, to which the program should be compiled.
std::string HLSLEntry
The entry point name for HLSL shader being compiled as specified by -E.
A convenient class for passing around template argument information.
void addArgument(const TemplateArgumentLoc &Loc)
The base class of all kinds of template declarations (e.g., class, function, etc.).
Stores a list of template parameters for a TemplateDecl and its derived classes.
The top declaration context.
Definition Decl.h:105
SourceLocation getBeginLoc() const
Get the begin source location.
Definition TypeLoc.cpp:193
A container of type source information.
Definition TypeBase.h:8402
TypeLoc getTypeLoc() const
Return the TypeLoc wrapper for the type source info.
Definition TypeLoc.h:267
The base class of the type hierarchy.
Definition TypeBase.h:1866
bool isVoidType() const
Definition TypeBase.h:9034
bool isBooleanType() const
Definition TypeBase.h:9171
bool isIncompleteArrayType() const
Definition TypeBase.h:8775
CXXRecordDecl * getAsCXXRecordDecl() const
Retrieves the CXXRecordDecl that this type refers to, either because the type is a RecordType or beca...
Definition Type.h:26
bool isConstantArrayType() const
Definition TypeBase.h:8771
bool hasIntegerRepresentation() const
Determine whether this type has an integer representation of some sort, e.g., it is an integer type o...
Definition Type.cpp:2084
bool isArrayType() const
Definition TypeBase.h:8767
CXXRecordDecl * castAsCXXRecordDecl() const
Definition Type.h:36
bool isArithmeticType() const
Definition Type.cpp:2375
bool isConstantMatrixType() const
Definition TypeBase.h:8835
bool isHLSLBuiltinIntangibleType() const
Definition TypeBase.h:8979
bool isPointerType() const
Definition TypeBase.h:8668
CanQualType getCanonicalTypeUnqualified() const
bool isIntegerType() const
isIntegerType() does not include complex integers (a GCC extension).
Definition TypeBase.h:9078
const T * castAs() const
Member-template castAs<specific type>.
Definition TypeBase.h:9328
bool isReferenceType() const
Definition TypeBase.h:8692
bool isHLSLIntangibleType() const
Definition Type.cpp:5461
bool isEnumeralType() const
Definition TypeBase.h:8799
bool isScalarType() const
Definition TypeBase.h:9140
bool isIntegralType(const ASTContext &Ctx) const
Determine whether this type is an integral type.
Definition Type.cpp:2121
const Type * getArrayElementTypeNoTypeQual() const
If this is an array type, return the element type of the array, potentially with type qualifiers miss...
Definition Type.cpp:473
QualType getPointeeType() const
If this is a pointer, ObjC object pointer, or block pointer, this returns the respective pointee.
Definition Type.cpp:754
bool hasUnsignedIntegerRepresentation() const
Determine whether this type has an unsigned integer representation of some sort, e....
Definition Type.cpp:2329
bool isAggregateType() const
Determines whether the type is a C++ aggregate type or C aggregate or union type.
Definition Type.cpp:2456
ScalarTypeKind getScalarTypeKind() const
Given that this is a scalar type, classify it.
Definition Type.cpp:2407
bool hasSignedIntegerRepresentation() const
Determine whether this type has an signed integer representation of some sort, e.g....
Definition Type.cpp:2275
bool isMatrixType() const
Definition TypeBase.h:8831
bool isHLSLResourceRecord() const
Definition Type.cpp:5448
bool hasFloatingRepresentation() const
Determine whether this type has a floating-point representation of some sort, e.g....
Definition Type.cpp:2350
bool isVectorType() const
Definition TypeBase.h:8807
bool isRealFloatingType() const
Floating point categories.
Definition Type.cpp:2358
bool isHLSLAttributedResourceType() const
Definition TypeBase.h:8991
@ STK_FloatingComplex
Definition TypeBase.h:2814
@ STK_ObjCObjectPointer
Definition TypeBase.h:2808
@ STK_IntegralComplex
Definition TypeBase.h:2813
@ STK_MemberPointer
Definition TypeBase.h:2809
bool isFloatingType() const
Definition Type.cpp:2342
bool isSamplerT() const
Definition TypeBase.h:8912
const T * getAs() const
Member-template getAs<specific type>'.
Definition TypeBase.h:9261
const Type * getUnqualifiedDesugaredType() const
Return the specified type with any "sugar" removed from the type, removing any typedefs,...
Definition Type.cpp:655
bool isRecordType() const
Definition TypeBase.h:8795
bool isHLSLResourceRecordArray() const
Definition Type.cpp:5452
void setType(QualType newType)
Definition Decl.h:724
QualType getType() const
Definition Decl.h:723
Represents a variable declaration or definition.
Definition Decl.h:926
static VarDecl * Create(ASTContext &C, DeclContext *DC, SourceLocation StartLoc, SourceLocation IdLoc, const IdentifierInfo *Id, QualType T, TypeSourceInfo *TInfo, StorageClass S)
Definition Decl.cpp:2164
void setInitStyle(InitializationStyle Style)
Definition Decl.h:1467
@ CallInit
Call-style initialization (C++98)
Definition Decl.h:934
void setStorageClass(StorageClass SC)
Definition Decl.cpp:2176
bool hasGlobalStorage() const
Returns true for all variables that do not have local storage.
Definition Decl.h:1241
void setInit(Expr *I)
Definition Decl.cpp:2490
StorageClass getStorageClass() const
Returns the storage class as written in the source.
Definition Decl.h:1168
Represents a GCC generic vector type.
Definition TypeBase.h:4225
unsigned getNumElements() const
Definition TypeBase.h:4240
QualType getElementType() const
Definition TypeBase.h:4239
IdentifierInfo * getNameAsIdentifier(ASTContext &AST) const
Defines the clang::TargetInfo interface.
Definition SPIR.cpp:47
uint32_t getResourceDimensions(llvm::dxil::ResourceDimension Dim)
The JSON file list parser is used to communicate input to InstallAPI.
bool isa(CodeGen::Address addr)
Definition Address.h:330
if(T->getSizeExpr()) TRY_TO(TraverseStmt(const_cast< Expr * >(T -> getSizeExpr())))
static bool CheckFloatOrHalfRepresentation(Sema *S, SourceLocation Loc, int ArgOrdinal, clang::QualType PassedType)
Definition SemaSPIRV.cpp:66
@ ICIS_NoInit
No in-class initializer.
Definition Specifiers.h:273
@ TemplateName
The identifier is a template name. FIXME: Add an annotation for that.
Definition Parser.h:61
@ OK_Ordinary
An ordinary object is located at an address in memory.
Definition Specifiers.h:152
static bool CheckAllArgTypesAreCorrect(Sema *S, CallExpr *TheCall, llvm::ArrayRef< llvm::function_ref< bool(Sema *, SourceLocation, int, QualType)> > Checks)
Definition SemaSPIRV.cpp:49
@ AS_public
Definition Specifiers.h:125
@ AS_none
Definition Specifiers.h:128
@ SC_Extern
Definition Specifiers.h:252
@ SC_Static
Definition Specifiers.h:253
@ SC_None
Definition Specifiers.h:251
@ AANT_ArgumentIdentifier
@ Result
The result type of a method or function.
Definition TypeBase.h:905
@ Ordinary
This parameter uses ordinary ABI rules for its type.
Definition Specifiers.h:383
llvm::Expected< QualType > ExpectedType
@ Template
We are parsing a template declaration.
Definition Parser.h:81
LLVM_READONLY bool isDigit(unsigned char c)
Return true if this character is an ASCII digit: [0-9].
Definition CharInfo.h:114
static bool CheckAllArgsHaveSameType(Sema *S, CallExpr *TheCall)
Definition SemaSPIRV.cpp:32
ExprResult ExprError()
Definition Ownership.h:265
@ Type
The name was classified as a type.
Definition Sema.h:564
LangAS
Defines the address space values used by the address space qualifier of QualType.
bool CreateHLSLAttributedResourceType(Sema &S, QualType Wrapped, ArrayRef< const Attr * > AttrList, QualType &ResType, HLSLAttributedResourceLocInfo *LocInfo=nullptr)
CastKind
CastKind - The kind of operation required for a conversion.
ExprValueKind
The categorization of expression values, currently following the C++11 scheme.
Definition Specifiers.h:133
@ VK_PRValue
A pr-value expression (in the C++11 taxonomy) produces a temporary value.
Definition Specifiers.h:136
@ VK_LValue
An l-value expression is a reference to an object with independent storage.
Definition Specifiers.h:140
DynamicRecursiveASTVisitorBase< false > DynamicRecursiveASTVisitor
U cast(CodeGen::Address addr)
Definition Address.h:327
@ None
No keyword precedes the qualified type name.
Definition TypeBase.h:5977
ActionResult< Expr * > ExprResult
Definition Ownership.h:249
Visibility
Describes the different kinds of visibility that a declaration may have.
Definition Visibility.h:34
unsigned long uint64_t
unsigned int uint32_t
hash_code hash_value(const clang::dependencies::ModuleID &ID)
__DEVICE__ bool isnan(float __x)
__DEVICE__ _Tp abs(const std::complex< _Tp > &__c)
#define false
Definition stdbool.h:26
Describes how types, statements, expressions, and declarations should be printed.
void setCounterImplicitOrderID(unsigned Value) const
void setImplicitOrderID(unsigned Value) const
const SourceLocation & getLocation() const
Definition SemaHLSL.h:48
const llvm::hlsl::rootsig::RootElement & getElement() const
Definition SemaHLSL.h:47