clang 23.0.0git
SemaHLSL.cpp
Go to the documentation of this file.
1//===- SemaHLSL.cpp - Semantic Analysis for HLSL constructs ---------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8// This implements Semantic Analysis for HLSL constructs.
9//===----------------------------------------------------------------------===//
10
11#include "clang/Sema/SemaHLSL.h"
14#include "clang/AST/Attr.h"
15#include "clang/AST/Decl.h"
16#include "clang/AST/DeclBase.h"
17#include "clang/AST/DeclCXX.h"
20#include "clang/AST/Expr.h"
22#include "clang/AST/Type.h"
23#include "clang/AST/TypeBase.h"
24#include "clang/AST/TypeLoc.h"
28#include "clang/Basic/LLVM.h"
33#include "clang/Sema/Lookup.h"
35#include "clang/Sema/Sema.h"
36#include "clang/Sema/Template.h"
37#include "llvm/ADT/ArrayRef.h"
38#include "llvm/ADT/STLExtras.h"
39#include "llvm/ADT/SmallVector.h"
40#include "llvm/ADT/StringExtras.h"
41#include "llvm/ADT/StringRef.h"
42#include "llvm/ADT/Twine.h"
43#include "llvm/Frontend/HLSL/HLSLBinding.h"
44#include "llvm/Frontend/HLSL/RootSignatureValidations.h"
45#include "llvm/Support/Casting.h"
46#include "llvm/Support/DXILABI.h"
47#include "llvm/Support/ErrorHandling.h"
48#include "llvm/Support/FormatVariadic.h"
49#include "llvm/TargetParser/Triple.h"
50#include <cmath>
51#include <cstddef>
52#include <iterator>
53#include <utility>
54
55using namespace clang;
56using namespace clang::hlsl;
57using RegisterType = HLSLResourceBindingAttr::RegisterType;
58
60 CXXRecordDecl *StructDecl);
61
63 switch (RC) {
64 case ResourceClass::SRV:
65 return RegisterType::SRV;
66 case ResourceClass::UAV:
67 return RegisterType::UAV;
68 case ResourceClass::CBuffer:
69 return RegisterType::CBuffer;
70 case ResourceClass::Sampler:
71 return RegisterType::Sampler;
72 }
73 llvm_unreachable("unexpected ResourceClass value");
74}
75
76static RegisterType getRegisterType(const HLSLAttributedResourceType *ResTy) {
77 return getRegisterType(ResTy->getAttrs().ResourceClass);
78}
79
80// Converts the first letter of string Slot to RegisterType.
81// Returns false if the letter does not correspond to a valid register type.
82static bool convertToRegisterType(StringRef Slot, RegisterType *RT) {
83 assert(RT != nullptr);
84 switch (Slot[0]) {
85 case 't':
86 case 'T':
87 *RT = RegisterType::SRV;
88 return true;
89 case 'u':
90 case 'U':
91 *RT = RegisterType::UAV;
92 return true;
93 case 'b':
94 case 'B':
95 *RT = RegisterType::CBuffer;
96 return true;
97 case 's':
98 case 'S':
99 *RT = RegisterType::Sampler;
100 return true;
101 case 'c':
102 case 'C':
103 *RT = RegisterType::C;
104 return true;
105 case 'i':
106 case 'I':
107 *RT = RegisterType::I;
108 return true;
109 default:
110 return false;
111 }
112}
113
115 switch (RT) {
116 case RegisterType::SRV:
117 return 't';
118 case RegisterType::UAV:
119 return 'u';
120 case RegisterType::CBuffer:
121 return 'b';
122 case RegisterType::Sampler:
123 return 's';
124 case RegisterType::C:
125 return 'c';
126 case RegisterType::I:
127 return 'i';
128 }
129 llvm_unreachable("unexpected RegisterType value");
130}
131
133 switch (RT) {
134 case RegisterType::SRV:
135 return ResourceClass::SRV;
136 case RegisterType::UAV:
137 return ResourceClass::UAV;
138 case RegisterType::CBuffer:
139 return ResourceClass::CBuffer;
140 case RegisterType::Sampler:
141 return ResourceClass::Sampler;
142 case RegisterType::C:
143 case RegisterType::I:
144 // Deliberately falling through to the unreachable below.
145 break;
146 }
147 llvm_unreachable("unexpected RegisterType value");
148}
149
151 const auto *BT = dyn_cast<BuiltinType>(Type);
152 if (!BT) {
153 if (!Type->isEnumeralType())
154 return Builtin::NotBuiltin;
155 return Builtin::BI__builtin_get_spirv_spec_constant_int;
156 }
157
158 switch (BT->getKind()) {
159 case BuiltinType::Bool:
160 return Builtin::BI__builtin_get_spirv_spec_constant_bool;
161 case BuiltinType::Short:
162 return Builtin::BI__builtin_get_spirv_spec_constant_short;
163 case BuiltinType::Int:
164 return Builtin::BI__builtin_get_spirv_spec_constant_int;
165 case BuiltinType::LongLong:
166 return Builtin::BI__builtin_get_spirv_spec_constant_longlong;
167 case BuiltinType::UShort:
168 return Builtin::BI__builtin_get_spirv_spec_constant_ushort;
169 case BuiltinType::UInt:
170 return Builtin::BI__builtin_get_spirv_spec_constant_uint;
171 case BuiltinType::ULongLong:
172 return Builtin::BI__builtin_get_spirv_spec_constant_ulonglong;
173 case BuiltinType::Half:
174 return Builtin::BI__builtin_get_spirv_spec_constant_half;
175 case BuiltinType::Float:
176 return Builtin::BI__builtin_get_spirv_spec_constant_float;
177 case BuiltinType::Double:
178 return Builtin::BI__builtin_get_spirv_spec_constant_double;
179 default:
180 return Builtin::NotBuiltin;
181 }
182}
183
184static StringRef createRegisterString(ASTContext &AST, RegisterType RegType,
185 unsigned N) {
187 llvm::raw_svector_ostream OS(Buffer);
188 OS << getRegisterTypeChar(RegType);
189 OS << N;
190 return AST.backupStr(OS.str());
191}
192
194 ResourceClass ResClass) {
195 assert(getDeclBindingInfo(VD, ResClass) == nullptr &&
196 "DeclBindingInfo already added");
197 assert(!hasBindingInfoForDecl(VD) || BindingsList.back().Decl == VD);
198 // VarDecl may have multiple entries for different resource classes.
199 // DeclToBindingListIndex stores the index of the first binding we saw
200 // for this decl. If there are any additional ones then that index
201 // shouldn't be updated.
202 DeclToBindingListIndex.try_emplace(VD, BindingsList.size());
203 return &BindingsList.emplace_back(VD, ResClass);
204}
205
207 ResourceClass ResClass) {
208 auto Entry = DeclToBindingListIndex.find(VD);
209 if (Entry != DeclToBindingListIndex.end()) {
210 for (unsigned Index = Entry->getSecond();
211 Index < BindingsList.size() && BindingsList[Index].Decl == VD;
212 ++Index) {
213 if (BindingsList[Index].ResClass == ResClass)
214 return &BindingsList[Index];
215 }
216 }
217 return nullptr;
218}
219
221 return DeclToBindingListIndex.contains(VD);
222}
223
225
226Decl *SemaHLSL::ActOnStartBuffer(Scope *BufferScope, bool CBuffer,
227 SourceLocation KwLoc, IdentifierInfo *Ident,
228 SourceLocation IdentLoc,
229 SourceLocation LBrace) {
230 // For anonymous namespace, take the location of the left brace.
231 DeclContext *LexicalParent = SemaRef.getCurLexicalContext();
233 getASTContext(), LexicalParent, CBuffer, KwLoc, Ident, IdentLoc, LBrace);
234
235 // if CBuffer is false, then it's a TBuffer
236 auto RC = CBuffer ? llvm::hlsl::ResourceClass::CBuffer
237 : llvm::hlsl::ResourceClass::SRV;
238 Result->addAttr(HLSLResourceClassAttr::CreateImplicit(getASTContext(), RC));
239
240 SemaRef.PushOnScopeChains(Result, BufferScope);
241 SemaRef.PushDeclContext(BufferScope, Result);
242
243 return Result;
244}
245
246static unsigned calculateLegacyCbufferFieldAlign(const ASTContext &Context,
247 QualType T) {
248 // Arrays, Matrices, and Structs are always aligned to new buffer rows
249 if (T->isArrayType() || T->isStructureType() || T->isConstantMatrixType())
250 return 16;
251
252 // Vectors are aligned to the type they contain
253 if (const VectorType *VT = T->getAs<VectorType>())
254 return calculateLegacyCbufferFieldAlign(Context, VT->getElementType());
255
256 assert(Context.getTypeSize(T) <= 64 &&
257 "Scalar bit widths larger than 64 not supported");
258
259 // Scalar types are aligned to their byte width
260 return Context.getTypeSize(T) / 8;
261}
262
263// Calculate the size of a legacy cbuffer type in bytes based on
264// https://learn.microsoft.com/en-us/windows/win32/direct3dhlsl/dx-graphics-hlsl-packing-rules
265static unsigned calculateLegacyCbufferSize(const ASTContext &Context,
266 QualType T) {
267 constexpr unsigned CBufferAlign = 16;
268 if (const auto *RD = T->getAsRecordDecl()) {
269 unsigned Size = 0;
270 for (const FieldDecl *Field : RD->fields()) {
271 QualType Ty = Field->getType();
272 unsigned FieldSize = calculateLegacyCbufferSize(Context, Ty);
273 unsigned FieldAlign = calculateLegacyCbufferFieldAlign(Context, Ty);
274
275 // If the field crosses the row boundary after alignment it drops to the
276 // next row
277 unsigned AlignSize = llvm::alignTo(Size, FieldAlign);
278 if ((AlignSize % CBufferAlign) + FieldSize > CBufferAlign) {
279 FieldAlign = CBufferAlign;
280 }
281
282 Size = llvm::alignTo(Size, FieldAlign);
283 Size += FieldSize;
284 }
285 return Size;
286 }
287
288 if (const ConstantArrayType *AT = Context.getAsConstantArrayType(T)) {
289 unsigned ElementCount = AT->getSize().getZExtValue();
290 if (ElementCount == 0)
291 return 0;
292
293 unsigned ElementSize =
294 calculateLegacyCbufferSize(Context, AT->getElementType());
295 unsigned AlignedElementSize = llvm::alignTo(ElementSize, CBufferAlign);
296 return AlignedElementSize * (ElementCount - 1) + ElementSize;
297 }
298
299 if (const VectorType *VT = T->getAs<VectorType>()) {
300 unsigned ElementCount = VT->getNumElements();
301 unsigned ElementSize =
302 calculateLegacyCbufferSize(Context, VT->getElementType());
303 return ElementSize * ElementCount;
304 }
305
306 return Context.getTypeSize(T) / 8;
307}
308
309// Validate packoffset:
310// - if packoffset is used, it must be set on all declarations inside the buffer
311// - packoffset ranges must not overlap
312static void validatePackoffset(Sema &S, HLSLBufferDecl *BufDecl) {
314
315 // Make sure the packoffset annotations are either on all declarations
316 // or on none.
317 bool HasPackOffset = false;
318 bool HasNonPackOffset = false;
319 for (auto *Field : BufDecl->buffer_decls()) {
320 VarDecl *Var = dyn_cast<VarDecl>(Field);
321 if (!Var)
322 continue;
323 if (Field->hasAttr<HLSLPackOffsetAttr>()) {
324 PackOffsetVec.emplace_back(Var, Field->getAttr<HLSLPackOffsetAttr>());
325 HasPackOffset = true;
326 } else {
327 HasNonPackOffset = true;
328 }
329 }
330
331 if (!HasPackOffset)
332 return;
333
334 if (HasNonPackOffset)
335 S.Diag(BufDecl->getLocation(), diag::warn_hlsl_packoffset_mix);
336
337 // Make sure there is no overlap in packoffset - sort PackOffsetVec by offset
338 // and compare adjacent values.
339 bool IsValid = true;
340 ASTContext &Context = S.getASTContext();
341 std::sort(PackOffsetVec.begin(), PackOffsetVec.end(),
342 [](const std::pair<VarDecl *, HLSLPackOffsetAttr *> &LHS,
343 const std::pair<VarDecl *, HLSLPackOffsetAttr *> &RHS) {
344 return LHS.second->getOffsetInBytes() <
345 RHS.second->getOffsetInBytes();
346 });
347 for (unsigned i = 0; i < PackOffsetVec.size() - 1; i++) {
348 VarDecl *Var = PackOffsetVec[i].first;
349 HLSLPackOffsetAttr *Attr = PackOffsetVec[i].second;
350 unsigned Size = calculateLegacyCbufferSize(Context, Var->getType());
351 unsigned Begin = Attr->getOffsetInBytes();
352 unsigned End = Begin + Size;
353 unsigned NextBegin = PackOffsetVec[i + 1].second->getOffsetInBytes();
354 if (End > NextBegin) {
355 VarDecl *NextVar = PackOffsetVec[i + 1].first;
356 S.Diag(NextVar->getLocation(), diag::err_hlsl_packoffset_overlap)
357 << NextVar << Var;
358 IsValid = false;
359 }
360 }
361 BufDecl->setHasValidPackoffset(IsValid);
362}
363
364// Returns true if the array has zero size, i.e. if any of its dimensions is 0
365static bool isZeroSizedArray(const ConstantArrayType *CAT) {
366 while (CAT && !CAT->isZeroSize())
367 CAT = dyn_cast<ConstantArrayType>(
369 return CAT != nullptr;
370}
371
375
379
380static const HLSLAttributedResourceType *
382 assert(QT->isHLSLResourceRecordArray() &&
383 "expected array of resource records");
384 const Type *Ty = QT->getUnqualifiedDesugaredType();
385 while (const ArrayType *AT = dyn_cast<ArrayType>(Ty))
387 return HLSLAttributedResourceType::findHandleTypeOnResource(Ty);
388}
389
390static const HLSLAttributedResourceType *
394
395// Returns true if the type is a leaf element type that is not valid to be
396// included in HLSL Buffer, such as a resource class, empty struct, zero-sized
397// array, or a builtin intangible type. Returns false if it is a valid leaf element
398// type or if it is a record type that needs to be inspected further.
402 return true;
403 if (const auto *RD = Ty->getAsCXXRecordDecl())
404 return RD->isEmpty();
405 if (Ty->isConstantArrayType() &&
407 return true;
409 return true;
410 return false;
411}
412
413// Returns true if the struct contains at least one element that prevents it
414// from being included inside HLSL Buffer as is, such as an intangible type,
415// empty struct, or zero-sized array. If it does, a new implicit layout struct
416// needs to be created for HLSL Buffer use that will exclude these unwanted
417// declarations (see createHostLayoutStruct function).
419 if (RD->isHLSLIntangible() || RD->isEmpty())
420 return true;
421 // check fields
422 for (const FieldDecl *Field : RD->fields()) {
423 QualType Ty = Field->getType();
425 return true;
426 if (const auto *RD = Ty->getAsCXXRecordDecl();
428 return true;
429 }
430 // check bases
431 for (const CXXBaseSpecifier &Base : RD->bases())
433 Base.getType()->castAsCXXRecordDecl()))
434 return true;
435 return false;
436}
437
439 DeclContext *DC) {
440 CXXRecordDecl *RD = nullptr;
441 for (NamedDecl *Decl :
443 if (CXXRecordDecl *FoundRD = dyn_cast<CXXRecordDecl>(Decl)) {
444 assert(RD == nullptr &&
445 "there should be at most 1 record by a given name in a scope");
446 RD = FoundRD;
447 }
448 }
449 return RD;
450}
451
452// Creates a name for a buffer layout struct using the provided name base.
453// If the name must be unique (not previously defined), a suffix is added
454// until a unique name is found.
456 bool MustBeUnique) {
457 ASTContext &AST = S.getASTContext();
458
459 IdentifierInfo *NameBaseII = BaseDecl->getIdentifier();
460 llvm::SmallString<64> Name("__cblayout_");
461 if (NameBaseII) {
462 Name.append(NameBaseII->getName());
463 } else {
464 // anonymous struct
465 Name.append("anon");
466 MustBeUnique = true;
467 }
468
469 size_t NameLength = Name.size();
470 IdentifierInfo *II = &AST.Idents.get(Name, tok::TokenKind::identifier);
471 if (!MustBeUnique)
472 return II;
473
474 unsigned suffix = 0;
475 while (true) {
476 if (suffix != 0) {
477 Name.append("_");
478 Name.append(llvm::Twine(suffix).str());
479 II = &AST.Idents.get(Name, tok::TokenKind::identifier);
480 }
481 if (!findRecordDeclInContext(II, BaseDecl->getDeclContext()))
482 return II;
483 // declaration with that name already exists - increment suffix and try
484 // again until unique name is found
485 suffix++;
486 Name.truncate(NameLength);
487 };
488}
489
490static const Type *createHostLayoutType(Sema &S, const Type *Ty) {
491 ASTContext &AST = S.getASTContext();
492 if (auto *RD = Ty->getAsCXXRecordDecl()) {
494 return Ty;
495 RD = createHostLayoutStruct(S, RD);
496 if (!RD)
497 return nullptr;
498 return AST.getCanonicalTagType(RD)->getTypePtr();
499 }
500
501 if (const auto *CAT = dyn_cast<ConstantArrayType>(Ty)) {
502 const Type *ElementTy = createHostLayoutType(
503 S, CAT->getElementType()->getUnqualifiedDesugaredType());
504 if (!ElementTy)
505 return nullptr;
506 return AST
507 .getConstantArrayType(QualType(ElementTy, 0), CAT->getSize(), nullptr,
508 CAT->getSizeModifier(),
509 CAT->getIndexTypeCVRQualifiers())
510 .getTypePtr();
511 }
512 return Ty;
513}
514
515// Creates a field declaration of given name and type for HLSL buffer layout
516// struct. Returns nullptr if the type cannot be used in HLSL Buffer layout.
518 IdentifierInfo *II,
519 CXXRecordDecl *LayoutStruct) {
521 return nullptr;
522
523 Ty = createHostLayoutType(S, Ty);
524 if (!Ty)
525 return nullptr;
526
527 QualType QT = QualType(Ty, 0);
528 ASTContext &AST = S.getASTContext();
530 auto *Field = FieldDecl::Create(AST, LayoutStruct, SourceLocation(),
531 SourceLocation(), II, QT, TSI, nullptr, false,
533 Field->setAccess(AccessSpecifier::AS_public);
534 return Field;
535}
536
537// Creates host layout struct for a struct included in HLSL Buffer.
538// The layout struct will include only fields that are allowed in HLSL buffer.
539// These fields will be filtered out:
540// - resource classes
541// - empty structs
542// - zero-sized arrays
543// Returns nullptr if the resulting layout struct would be empty.
545 CXXRecordDecl *StructDecl) {
546 assert(requiresImplicitBufferLayoutStructure(StructDecl) &&
547 "struct is already HLSL buffer compatible");
548
549 ASTContext &AST = S.getASTContext();
550 DeclContext *DC = StructDecl->getDeclContext();
551 IdentifierInfo *II = getHostLayoutStructName(S, StructDecl, false);
552
553 // reuse existing if the layout struct if it already exists
554 if (CXXRecordDecl *RD = findRecordDeclInContext(II, DC))
555 return RD;
556
557 CXXRecordDecl *LS =
558 CXXRecordDecl::Create(AST, TagDecl::TagKind::Struct, DC, SourceLocation(),
559 SourceLocation(), II);
560 LS->setImplicit(true);
561 LS->addAttr(PackedAttr::CreateImplicit(AST));
562 LS->startDefinition();
563
564 // copy base struct, create HLSL Buffer compatible version if needed
565 if (unsigned NumBases = StructDecl->getNumBases()) {
566 assert(NumBases == 1 && "HLSL supports only one base type");
567 (void)NumBases;
568 CXXBaseSpecifier Base = *StructDecl->bases_begin();
569 CXXRecordDecl *BaseDecl = Base.getType()->castAsCXXRecordDecl();
571 BaseDecl = createHostLayoutStruct(S, BaseDecl);
572 if (BaseDecl) {
573 TypeSourceInfo *TSI =
575 Base = CXXBaseSpecifier(SourceRange(), false, StructDecl->isClass(),
576 AS_none, TSI, SourceLocation());
577 }
578 }
579 if (BaseDecl) {
580 const CXXBaseSpecifier *BasesArray[1] = {&Base};
581 LS->setBases(BasesArray, 1);
582 }
583 }
584
585 // filter struct fields
586 for (const FieldDecl *FD : StructDecl->fields()) {
587 const Type *Ty = FD->getType()->getUnqualifiedDesugaredType();
588 if (FieldDecl *NewFD =
589 createFieldForHostLayoutStruct(S, Ty, FD->getIdentifier(), LS))
590 LS->addDecl(NewFD);
591 }
592 LS->completeDefinition();
593
594 if (LS->field_empty() && LS->getNumBases() == 0)
595 return nullptr;
596
597 DC->addDecl(LS);
598 return LS;
599}
600
601// Creates host layout struct for HLSL Buffer. The struct will include only
602// fields of types that are allowed in HLSL buffer and it will filter out:
603// - static or groupshared variable declarations
604// - resource classes
605// - empty structs
606// - zero-sized arrays
607// - non-variable declarations
608// The layout struct will be added to the HLSLBufferDecl declarations.
610 ASTContext &AST = S.getASTContext();
611 IdentifierInfo *II = getHostLayoutStructName(S, BufDecl, true);
612
613 CXXRecordDecl *LS =
614 CXXRecordDecl::Create(AST, TagDecl::TagKind::Struct, BufDecl,
616 LS->addAttr(PackedAttr::CreateImplicit(AST));
617 LS->setImplicit(true);
618 LS->startDefinition();
619
620 for (Decl *D : BufDecl->buffer_decls()) {
621 VarDecl *VD = dyn_cast<VarDecl>(D);
622 if (!VD || VD->getStorageClass() == SC_Static ||
624 continue;
625 const Type *Ty = VD->getType()->getUnqualifiedDesugaredType();
626
627 FieldDecl *FD =
629 // Declarations collected for the default $Globals constant buffer have
630 // already been checked to have non-empty cbuffer layout, so
631 // createFieldForHostLayoutStruct should always succeed. These declarations
632 // already have their address space set to hlsl_constant.
633 // For declarations in a named cbuffer block
634 // createFieldForHostLayoutStruct can still return nullptr if the type
635 // is empty (does not have a cbuffer layout).
636 assert((FD || VD->getType().getAddressSpace() != LangAS::hlsl_constant) &&
637 "host layout field for $Globals decl failed to be created");
638 if (FD) {
639 // Add the field decl to the layout struct.
640 LS->addDecl(FD);
642 // Update address space of the original decl to hlsl_constant.
643 QualType NewTy =
645 VD->setType(NewTy);
646 }
647 }
648 }
649 LS->completeDefinition();
650 BufDecl->addLayoutStruct(LS);
651}
652
654 uint32_t ImplicitBindingOrderID) {
655 auto *Attr =
656 HLSLResourceBindingAttr::CreateImplicit(S.getASTContext(), "", "0", {});
657 Attr->setBinding(RT, std::nullopt, 0);
658 Attr->setImplicitBindingOrderID(ImplicitBindingOrderID);
659 D->addAttr(Attr);
660}
661
662// Handle end of cbuffer/tbuffer declaration
664 auto *BufDecl = cast<HLSLBufferDecl>(Dcl);
665 BufDecl->setRBraceLoc(RBrace);
666
667 validatePackoffset(SemaRef, BufDecl);
668
670
671 // Handle implicit binding if needed.
672 ResourceBindingAttrs ResourceAttrs(Dcl);
673 if (!ResourceAttrs.isExplicit()) {
674 SemaRef.Diag(Dcl->getLocation(), diag::warn_hlsl_implicit_binding);
675 // Use HLSLResourceBindingAttr to transfer implicit binding order_ID
676 // to codegen. If it does not exist, create an implicit attribute.
677 uint32_t OrderID = getNextImplicitBindingOrderID();
678 if (ResourceAttrs.hasBinding())
679 ResourceAttrs.setImplicitOrderID(OrderID);
680 else
682 BufDecl->isCBuffer() ? RegisterType::CBuffer
683 : RegisterType::SRV,
684 OrderID);
685 }
686
687 SemaRef.PopDeclContext();
688}
689
690HLSLNumThreadsAttr *SemaHLSL::mergeNumThreadsAttr(Decl *D,
691 const AttributeCommonInfo &AL,
692 int X, int Y, int Z) {
693 if (HLSLNumThreadsAttr *NT = D->getAttr<HLSLNumThreadsAttr>()) {
694 if (NT->getX() != X || NT->getY() != Y || NT->getZ() != Z) {
695 Diag(NT->getLocation(), diag::err_hlsl_attribute_param_mismatch) << AL;
696 Diag(AL.getLoc(), diag::note_conflicting_attribute);
697 }
698 return nullptr;
699 }
700 return ::new (getASTContext())
701 HLSLNumThreadsAttr(getASTContext(), AL, X, Y, Z);
702}
703
705 const AttributeCommonInfo &AL,
706 int Min, int Max, int Preferred,
707 int SpelledArgsCount) {
708 if (HLSLWaveSizeAttr *WS = D->getAttr<HLSLWaveSizeAttr>()) {
709 if (WS->getMin() != Min || WS->getMax() != Max ||
710 WS->getPreferred() != Preferred ||
711 WS->getSpelledArgsCount() != SpelledArgsCount) {
712 Diag(WS->getLocation(), diag::err_hlsl_attribute_param_mismatch) << AL;
713 Diag(AL.getLoc(), diag::note_conflicting_attribute);
714 }
715 return nullptr;
716 }
717 HLSLWaveSizeAttr *Result = ::new (getASTContext())
718 HLSLWaveSizeAttr(getASTContext(), AL, Min, Max, Preferred);
719 Result->setSpelledArgsCount(SpelledArgsCount);
720 return Result;
721}
722
723HLSLVkConstantIdAttr *
725 int Id) {
726
728 if (TargetInfo.getTriple().getArch() != llvm::Triple::spirv) {
729 Diag(AL.getLoc(), diag::warn_attribute_ignored) << AL;
730 return nullptr;
731 }
732
733 auto *VD = cast<VarDecl>(D);
734
735 if (getSpecConstBuiltinId(VD->getType()->getUnqualifiedDesugaredType()) ==
737 Diag(VD->getLocation(), diag::err_specialization_const);
738 return nullptr;
739 }
740
741 if (!VD->getType().isConstQualified()) {
742 Diag(VD->getLocation(), diag::err_specialization_const);
743 return nullptr;
744 }
745
746 if (HLSLVkConstantIdAttr *CI = D->getAttr<HLSLVkConstantIdAttr>()) {
747 if (CI->getId() != Id) {
748 Diag(CI->getLocation(), diag::err_hlsl_attribute_param_mismatch) << AL;
749 Diag(AL.getLoc(), diag::note_conflicting_attribute);
750 }
751 return nullptr;
752 }
753
754 HLSLVkConstantIdAttr *Result =
755 ::new (getASTContext()) HLSLVkConstantIdAttr(getASTContext(), AL, Id);
756 return Result;
757}
758
759HLSLShaderAttr *
761 llvm::Triple::EnvironmentType ShaderType) {
762 if (HLSLShaderAttr *NT = D->getAttr<HLSLShaderAttr>()) {
763 if (NT->getType() != ShaderType) {
764 Diag(NT->getLocation(), diag::err_hlsl_attribute_param_mismatch) << AL;
765 Diag(AL.getLoc(), diag::note_conflicting_attribute);
766 }
767 return nullptr;
768 }
769 return HLSLShaderAttr::Create(getASTContext(), ShaderType, AL);
770}
771
772HLSLParamModifierAttr *
774 HLSLParamModifierAttr::Spelling Spelling) {
775 // We can only merge an `in` attribute with an `out` attribute. All other
776 // combinations of duplicated attributes are ill-formed.
777 if (HLSLParamModifierAttr *PA = D->getAttr<HLSLParamModifierAttr>()) {
778 if ((PA->isIn() && Spelling == HLSLParamModifierAttr::Keyword_out) ||
779 (PA->isOut() && Spelling == HLSLParamModifierAttr::Keyword_in)) {
780 D->dropAttr<HLSLParamModifierAttr>();
781 SourceRange AdjustedRange = {PA->getLocation(), AL.getRange().getEnd()};
782 return HLSLParamModifierAttr::Create(
783 getASTContext(), /*MergedSpelling=*/true, AdjustedRange,
784 HLSLParamModifierAttr::Keyword_inout);
785 }
786 Diag(AL.getLoc(), diag::err_hlsl_duplicate_parameter_modifier) << AL;
787 Diag(PA->getLocation(), diag::note_conflicting_attribute);
788 return nullptr;
789 }
790 return HLSLParamModifierAttr::Create(getASTContext(), AL);
791}
792
795
797 return;
798
799 // If we have specified a root signature to override the entry function then
800 // attach it now
801 HLSLRootSignatureDecl *SignatureDecl =
803 if (SignatureDecl) {
804 FD->dropAttr<RootSignatureAttr>();
805 // We could look up the SourceRange of the macro here as well
806 AttributeCommonInfo AL(RootSigOverrideIdent, AttributeScopeInfo(),
807 SourceRange(), ParsedAttr::Form::Microsoft());
808 FD->addAttr(::new (getASTContext()) RootSignatureAttr(
809 getASTContext(), AL, RootSigOverrideIdent, SignatureDecl));
810 }
811
812 llvm::Triple::EnvironmentType Env = TargetInfo.getTriple().getEnvironment();
813 if (HLSLShaderAttr::isValidShaderType(Env) && Env != llvm::Triple::Library) {
814 if (const auto *Shader = FD->getAttr<HLSLShaderAttr>()) {
815 // The entry point is already annotated - check that it matches the
816 // triple.
817 if (Shader->getType() != Env) {
818 Diag(Shader->getLocation(), diag::err_hlsl_entry_shader_attr_mismatch)
819 << Shader;
820 FD->setInvalidDecl();
821 }
822 } else {
823 // Implicitly add the shader attribute if the entry function isn't
824 // explicitly annotated.
825 FD->addAttr(HLSLShaderAttr::CreateImplicit(getASTContext(), Env,
826 FD->getBeginLoc()));
827 }
828 } else {
829 switch (Env) {
830 case llvm::Triple::UnknownEnvironment:
831 case llvm::Triple::Library:
832 break;
833 case llvm::Triple::RootSignature:
834 llvm_unreachable("rootsig environment has no functions");
835 default:
836 llvm_unreachable("Unhandled environment in triple");
837 }
838 }
839}
840
841static bool isVkPipelineBuiltin(const ASTContext &AstContext, FunctionDecl *FD,
842 HLSLAppliedSemanticAttr *Semantic,
843 bool IsInput) {
844 if (AstContext.getTargetInfo().getTriple().getOS() != llvm::Triple::Vulkan)
845 return false;
846
847 const auto *ShaderAttr = FD->getAttr<HLSLShaderAttr>();
848 assert(ShaderAttr && "Entry point has no shader attribute");
849 llvm::Triple::EnvironmentType ST = ShaderAttr->getType();
850 auto SemanticName = Semantic->getSemanticName().upper();
851
852 // The SV_Position semantic is lowered to:
853 // - Position built-in for vertex output.
854 // - FragCoord built-in for fragment input.
855 if (SemanticName == "SV_POSITION") {
856 return (ST == llvm::Triple::Vertex && !IsInput) ||
857 (ST == llvm::Triple::Pixel && IsInput);
858 }
859 if (SemanticName == "SV_VERTEXID")
860 return true;
861
862 return false;
863}
864
865bool SemaHLSL::determineActiveSemanticOnScalar(FunctionDecl *FD,
866 DeclaratorDecl *OutputDecl,
868 SemanticInfo &ActiveSemantic,
869 SemaHLSL::SemanticContext &SC) {
870 if (ActiveSemantic.Semantic == nullptr) {
871 ActiveSemantic.Semantic = D->getAttr<HLSLParsedSemanticAttr>();
872 if (ActiveSemantic.Semantic)
873 ActiveSemantic.Index = ActiveSemantic.Semantic->getSemanticIndex();
874 }
875
876 if (!ActiveSemantic.Semantic) {
877 Diag(D->getLocation(), diag::err_hlsl_missing_semantic_annotation);
878 return false;
879 }
880
881 auto *A = ::new (getASTContext())
882 HLSLAppliedSemanticAttr(getASTContext(), *ActiveSemantic.Semantic,
883 ActiveSemantic.Semantic->getAttrName()->getName(),
884 ActiveSemantic.Index.value_or(0));
885 if (!A)
887
888 checkSemanticAnnotation(FD, D, A, SC);
889 OutputDecl->addAttr(A);
890
891 unsigned Location = ActiveSemantic.Index.value_or(0);
892
894 SC.CurrentIOType & IOType::In)) {
895 bool HasVkLocation = false;
896 if (auto *A = D->getAttr<HLSLVkLocationAttr>()) {
897 HasVkLocation = true;
898 Location = A->getLocation();
899 }
900
901 if (SC.UsesExplicitVkLocations.value_or(HasVkLocation) != HasVkLocation) {
902 Diag(D->getLocation(), diag::err_hlsl_semantic_partial_explicit_indexing);
903 return false;
904 }
905 SC.UsesExplicitVkLocations = HasVkLocation;
906 }
907
908 const ConstantArrayType *AT = dyn_cast<ConstantArrayType>(D->getType());
909 unsigned ElementCount = AT ? AT->getZExtSize() : 1;
910 ActiveSemantic.Index = Location + ElementCount;
911
912 Twine BaseName = Twine(ActiveSemantic.Semantic->getAttrName()->getName());
913 for (unsigned I = 0; I < ElementCount; ++I) {
914 Twine VariableName = BaseName.concat(Twine(Location + I));
915
916 auto [_, Inserted] = SC.ActiveSemantics.insert(VariableName.str());
917 if (!Inserted) {
918 Diag(D->getLocation(), diag::err_hlsl_semantic_index_overlap)
919 << VariableName.str();
920 return false;
921 }
922 }
923
924 return true;
925}
926
927bool SemaHLSL::determineActiveSemantic(FunctionDecl *FD,
928 DeclaratorDecl *OutputDecl,
930 SemanticInfo &ActiveSemantic,
931 SemaHLSL::SemanticContext &SC) {
932 if (ActiveSemantic.Semantic == nullptr) {
933 ActiveSemantic.Semantic = D->getAttr<HLSLParsedSemanticAttr>();
934 if (ActiveSemantic.Semantic)
935 ActiveSemantic.Index = ActiveSemantic.Semantic->getSemanticIndex();
936 }
937
938 const Type *T = D == FD ? &*FD->getReturnType() : &*D->getType();
940
941 const RecordType *RT = dyn_cast<RecordType>(T);
942 if (!RT)
943 return determineActiveSemanticOnScalar(FD, OutputDecl, D, ActiveSemantic,
944 SC);
945
946 const RecordDecl *RD = RT->getDecl();
947 for (FieldDecl *Field : RD->fields()) {
948 SemanticInfo Info = ActiveSemantic;
949 if (!determineActiveSemantic(FD, OutputDecl, Field, Info, SC)) {
950 Diag(Field->getLocation(), diag::note_hlsl_semantic_used_here) << Field;
951 return false;
952 }
953 if (ActiveSemantic.Semantic)
954 ActiveSemantic = Info;
955 }
956
957 return true;
958}
959
961 const auto *ShaderAttr = FD->getAttr<HLSLShaderAttr>();
962 assert(ShaderAttr && "Entry point has no shader attribute");
963 llvm::Triple::EnvironmentType ST = ShaderAttr->getType();
965 VersionTuple Ver = TargetInfo.getTriple().getOSVersion();
966 switch (ST) {
967 case llvm::Triple::Pixel:
968 case llvm::Triple::Vertex:
969 case llvm::Triple::Geometry:
970 case llvm::Triple::Hull:
971 case llvm::Triple::Domain:
972 case llvm::Triple::RayGeneration:
973 case llvm::Triple::Intersection:
974 case llvm::Triple::AnyHit:
975 case llvm::Triple::ClosestHit:
976 case llvm::Triple::Miss:
977 case llvm::Triple::Callable:
978 if (const auto *NT = FD->getAttr<HLSLNumThreadsAttr>()) {
979 diagnoseAttrStageMismatch(NT, ST,
980 {llvm::Triple::Compute,
981 llvm::Triple::Amplification,
982 llvm::Triple::Mesh});
983 FD->setInvalidDecl();
984 }
985 if (const auto *WS = FD->getAttr<HLSLWaveSizeAttr>()) {
986 diagnoseAttrStageMismatch(WS, ST,
987 {llvm::Triple::Compute,
988 llvm::Triple::Amplification,
989 llvm::Triple::Mesh});
990 FD->setInvalidDecl();
991 }
992 break;
993
994 case llvm::Triple::Compute:
995 case llvm::Triple::Amplification:
996 case llvm::Triple::Mesh:
997 if (!FD->hasAttr<HLSLNumThreadsAttr>()) {
998 Diag(FD->getLocation(), diag::err_hlsl_missing_numthreads)
999 << llvm::Triple::getEnvironmentTypeName(ST);
1000 FD->setInvalidDecl();
1001 }
1002 if (const auto *WS = FD->getAttr<HLSLWaveSizeAttr>()) {
1003 if (Ver < VersionTuple(6, 6)) {
1004 Diag(WS->getLocation(), diag::err_hlsl_attribute_in_wrong_shader_model)
1005 << WS << "6.6";
1006 FD->setInvalidDecl();
1007 } else if (WS->getSpelledArgsCount() > 1 && Ver < VersionTuple(6, 8)) {
1008 Diag(
1009 WS->getLocation(),
1010 diag::err_hlsl_attribute_number_arguments_insufficient_shader_model)
1011 << WS << WS->getSpelledArgsCount() << "6.8";
1012 FD->setInvalidDecl();
1013 }
1014 }
1015 break;
1016 case llvm::Triple::RootSignature:
1017 llvm_unreachable("rootsig environment has no function entry point");
1018 default:
1019 llvm_unreachable("Unhandled environment in triple");
1020 }
1021
1022 SemaHLSL::SemanticContext InputSC = {};
1023 InputSC.CurrentIOType = IOType::In;
1024
1025 for (ParmVarDecl *Param : FD->parameters()) {
1026 SemanticInfo ActiveSemantic;
1027 ActiveSemantic.Semantic = Param->getAttr<HLSLParsedSemanticAttr>();
1028 if (ActiveSemantic.Semantic)
1029 ActiveSemantic.Index = ActiveSemantic.Semantic->getSemanticIndex();
1030
1031 // FIXME: Verify output semantics in parameters.
1032 if (!determineActiveSemantic(FD, Param, Param, ActiveSemantic, InputSC)) {
1033 Diag(Param->getLocation(), diag::note_previous_decl) << Param;
1034 FD->setInvalidDecl();
1035 }
1036 }
1037
1038 SemanticInfo ActiveSemantic;
1039 SemaHLSL::SemanticContext OutputSC = {};
1040 OutputSC.CurrentIOType = IOType::Out;
1041 ActiveSemantic.Semantic = FD->getAttr<HLSLParsedSemanticAttr>();
1042 if (ActiveSemantic.Semantic)
1043 ActiveSemantic.Index = ActiveSemantic.Semantic->getSemanticIndex();
1044 if (!FD->getReturnType()->isVoidType())
1045 determineActiveSemantic(FD, FD, FD, ActiveSemantic, OutputSC);
1046}
1047
1048void SemaHLSL::checkSemanticAnnotation(
1049 FunctionDecl *EntryPoint, const Decl *Param,
1050 const HLSLAppliedSemanticAttr *SemanticAttr, const SemanticContext &SC) {
1051 auto *ShaderAttr = EntryPoint->getAttr<HLSLShaderAttr>();
1052 assert(ShaderAttr && "Entry point has no shader attribute");
1053 llvm::Triple::EnvironmentType ST = ShaderAttr->getType();
1054
1055 auto SemanticName = SemanticAttr->getSemanticName().upper();
1056 if (SemanticName == "SV_DISPATCHTHREADID" ||
1057 SemanticName == "SV_GROUPINDEX" || SemanticName == "SV_GROUPTHREADID" ||
1058 SemanticName == "SV_GROUPID") {
1059
1060 if (ST != llvm::Triple::Compute)
1061 diagnoseSemanticStageMismatch(SemanticAttr, ST, SC.CurrentIOType,
1062 {{llvm::Triple::Compute, IOType::In}});
1063
1064 if (SemanticAttr->getSemanticIndex() != 0) {
1065 std::string PrettyName =
1066 "'" + SemanticAttr->getSemanticName().str() + "'";
1067 Diag(SemanticAttr->getLoc(),
1068 diag::err_hlsl_semantic_indexing_not_supported)
1069 << PrettyName;
1070 }
1071 return;
1072 }
1073
1074 if (SemanticName == "SV_POSITION") {
1075 // SV_Position can be an input or output in vertex shaders,
1076 // but only an input in pixel shaders.
1077 diagnoseSemanticStageMismatch(SemanticAttr, ST, SC.CurrentIOType,
1078 {{llvm::Triple::Vertex, IOType::InOut},
1079 {llvm::Triple::Pixel, IOType::In}});
1080 return;
1081 }
1082 if (SemanticName == "SV_VERTEXID") {
1083 diagnoseSemanticStageMismatch(SemanticAttr, ST, SC.CurrentIOType,
1084 {{llvm::Triple::Vertex, IOType::In}});
1085 return;
1086 }
1087
1088 if (SemanticName == "SV_TARGET") {
1089 diagnoseSemanticStageMismatch(SemanticAttr, ST, SC.CurrentIOType,
1090 {{llvm::Triple::Pixel, IOType::Out}});
1091 return;
1092 }
1093
1094 // FIXME: catch-all for non-implemented system semantics reaching this
1095 // location.
1096 if (SemanticAttr->getAttrName()->getName().starts_with_insensitive("SV_"))
1097 llvm_unreachable("Unknown SemanticAttr");
1098}
1099
1100void SemaHLSL::diagnoseAttrStageMismatch(
1101 const Attr *A, llvm::Triple::EnvironmentType Stage,
1102 std::initializer_list<llvm::Triple::EnvironmentType> AllowedStages) {
1103 SmallVector<StringRef, 8> StageStrings;
1104 llvm::transform(AllowedStages, std::back_inserter(StageStrings),
1105 [](llvm::Triple::EnvironmentType ST) {
1106 return StringRef(
1107 HLSLShaderAttr::ConvertEnvironmentTypeToStr(ST));
1108 });
1109 Diag(A->getLoc(), diag::err_hlsl_attr_unsupported_in_stage)
1110 << A->getAttrName() << llvm::Triple::getEnvironmentTypeName(Stage)
1111 << (AllowedStages.size() != 1) << join(StageStrings, ", ");
1112}
1113
1114void SemaHLSL::diagnoseSemanticStageMismatch(
1115 const Attr *A, llvm::Triple::EnvironmentType Stage, IOType CurrentIOType,
1116 std::initializer_list<SemanticStageInfo> Allowed) {
1117
1118 for (auto &Case : Allowed) {
1119 if (Case.Stage != Stage)
1120 continue;
1121
1122 if (CurrentIOType & Case.AllowedIOTypesMask)
1123 return;
1124
1125 SmallVector<std::string, 8> ValidCases;
1126 llvm::transform(
1127 Allowed, std::back_inserter(ValidCases), [](SemanticStageInfo Case) {
1128 SmallVector<std::string, 2> ValidType;
1129 if (Case.AllowedIOTypesMask & IOType::In)
1130 ValidType.push_back("input");
1131 if (Case.AllowedIOTypesMask & IOType::Out)
1132 ValidType.push_back("output");
1133 return std::string(
1134 HLSLShaderAttr::ConvertEnvironmentTypeToStr(Case.Stage)) +
1135 " " + join(ValidType, "/");
1136 });
1137 Diag(A->getLoc(), diag::err_hlsl_semantic_unsupported_iotype_for_stage)
1138 << A->getAttrName() << (CurrentIOType & IOType::In ? "input" : "output")
1139 << llvm::Triple::getEnvironmentTypeName(Case.Stage)
1140 << join(ValidCases, ", ");
1141 return;
1142 }
1143
1144 SmallVector<StringRef, 8> StageStrings;
1145 llvm::transform(
1146 Allowed, std::back_inserter(StageStrings), [](SemanticStageInfo Case) {
1147 return StringRef(
1148 HLSLShaderAttr::ConvertEnvironmentTypeToStr(Case.Stage));
1149 });
1150
1151 Diag(A->getLoc(), diag::err_hlsl_attr_unsupported_in_stage)
1152 << A->getAttrName() << llvm::Triple::getEnvironmentTypeName(Stage)
1153 << (Allowed.size() != 1) << join(StageStrings, ", ");
1154}
1155
1156template <CastKind Kind>
1157static void castVector(Sema &S, ExprResult &E, QualType &Ty, unsigned Sz) {
1158 if (const auto *VTy = Ty->getAs<VectorType>())
1159 Ty = VTy->getElementType();
1160 Ty = S.getASTContext().getExtVectorType(Ty, Sz);
1161 E = S.ImpCastExprToType(E.get(), Ty, Kind);
1162}
1163
1164template <CastKind Kind>
1166 E = S.ImpCastExprToType(E.get(), Ty, Kind);
1167 return Ty;
1168}
1169
1171 Sema &SemaRef, ExprResult &LHS, ExprResult &RHS, QualType LHSType,
1172 QualType RHSType, QualType LElTy, QualType RElTy, bool IsCompAssign) {
1173 bool LHSFloat = LElTy->isRealFloatingType();
1174 bool RHSFloat = RElTy->isRealFloatingType();
1175
1176 if (LHSFloat && RHSFloat) {
1177 if (IsCompAssign ||
1178 SemaRef.getASTContext().getFloatingTypeOrder(LElTy, RElTy) > 0)
1179 return castElement<CK_FloatingCast>(SemaRef, RHS, LHSType);
1180
1181 return castElement<CK_FloatingCast>(SemaRef, LHS, RHSType);
1182 }
1183
1184 if (LHSFloat)
1185 return castElement<CK_IntegralToFloating>(SemaRef, RHS, LHSType);
1186
1187 assert(RHSFloat);
1188 if (IsCompAssign)
1189 return castElement<clang::CK_FloatingToIntegral>(SemaRef, RHS, LHSType);
1190
1191 return castElement<CK_IntegralToFloating>(SemaRef, LHS, RHSType);
1192}
1193
1195 Sema &SemaRef, ExprResult &LHS, ExprResult &RHS, QualType LHSType,
1196 QualType RHSType, QualType LElTy, QualType RElTy, bool IsCompAssign) {
1197
1198 int IntOrder = SemaRef.Context.getIntegerTypeOrder(LElTy, RElTy);
1199 bool LHSSigned = LElTy->hasSignedIntegerRepresentation();
1200 bool RHSSigned = RElTy->hasSignedIntegerRepresentation();
1201 auto &Ctx = SemaRef.getASTContext();
1202
1203 // If both types have the same signedness, use the higher ranked type.
1204 if (LHSSigned == RHSSigned) {
1205 if (IsCompAssign || IntOrder >= 0)
1206 return castElement<CK_IntegralCast>(SemaRef, RHS, LHSType);
1207
1208 return castElement<CK_IntegralCast>(SemaRef, LHS, RHSType);
1209 }
1210
1211 // If the unsigned type has greater than or equal rank of the signed type, use
1212 // the unsigned type.
1213 if (IntOrder != (LHSSigned ? 1 : -1)) {
1214 if (IsCompAssign || RHSSigned)
1215 return castElement<CK_IntegralCast>(SemaRef, RHS, LHSType);
1216 return castElement<CK_IntegralCast>(SemaRef, LHS, RHSType);
1217 }
1218
1219 // At this point the signed type has higher rank than the unsigned type, which
1220 // means it will be the same size or bigger. If the signed type is bigger, it
1221 // can represent all the values of the unsigned type, so select it.
1222 if (Ctx.getIntWidth(LElTy) != Ctx.getIntWidth(RElTy)) {
1223 if (IsCompAssign || LHSSigned)
1224 return castElement<CK_IntegralCast>(SemaRef, RHS, LHSType);
1225 return castElement<CK_IntegralCast>(SemaRef, LHS, RHSType);
1226 }
1227
1228 // This is a bit of an odd duck case in HLSL. It shouldn't happen, but can due
1229 // to C/C++ leaking through. The place this happens today is long vs long
1230 // long. When arguments are vector<unsigned long, N> and vector<long long, N>,
1231 // the long long has higher rank than long even though they are the same size.
1232
1233 // If this is a compound assignment cast the right hand side to the left hand
1234 // side's type.
1235 if (IsCompAssign)
1236 return castElement<CK_IntegralCast>(SemaRef, RHS, LHSType);
1237
1238 // If this isn't a compound assignment we convert to unsigned long long.
1239 QualType ElTy = Ctx.getCorrespondingUnsignedType(LHSSigned ? LElTy : RElTy);
1240 QualType NewTy = Ctx.getExtVectorType(
1241 ElTy, RHSType->castAs<VectorType>()->getNumElements());
1242 (void)castElement<CK_IntegralCast>(SemaRef, RHS, NewTy);
1243
1244 return castElement<CK_IntegralCast>(SemaRef, LHS, NewTy);
1245}
1246
1248 QualType SrcTy) {
1249 if (DestTy->isRealFloatingType() && SrcTy->isRealFloatingType())
1250 return CK_FloatingCast;
1251 if (DestTy->isIntegralType(Ctx) && SrcTy->isIntegralType(Ctx))
1252 return CK_IntegralCast;
1253 if (DestTy->isRealFloatingType())
1254 return CK_IntegralToFloating;
1255 assert(SrcTy->isRealFloatingType() && DestTy->isIntegralType(Ctx));
1256 return CK_FloatingToIntegral;
1257}
1258
1260 QualType LHSType,
1261 QualType RHSType,
1262 bool IsCompAssign) {
1263 const auto *LVecTy = LHSType->getAs<VectorType>();
1264 const auto *RVecTy = RHSType->getAs<VectorType>();
1265 auto &Ctx = getASTContext();
1266
1267 // If the LHS is not a vector and this is a compound assignment, we truncate
1268 // the argument to a scalar then convert it to the LHS's type.
1269 if (!LVecTy && IsCompAssign) {
1270 QualType RElTy = RHSType->castAs<VectorType>()->getElementType();
1271 RHS = SemaRef.ImpCastExprToType(RHS.get(), RElTy, CK_HLSLVectorTruncation);
1272 RHSType = RHS.get()->getType();
1273 if (Ctx.hasSameUnqualifiedType(LHSType, RHSType))
1274 return LHSType;
1275 RHS = SemaRef.ImpCastExprToType(RHS.get(), LHSType,
1276 getScalarCastKind(Ctx, LHSType, RHSType));
1277 return LHSType;
1278 }
1279
1280 unsigned EndSz = std::numeric_limits<unsigned>::max();
1281 unsigned LSz = 0;
1282 if (LVecTy)
1283 LSz = EndSz = LVecTy->getNumElements();
1284 if (RVecTy)
1285 EndSz = std::min(RVecTy->getNumElements(), EndSz);
1286 assert(EndSz != std::numeric_limits<unsigned>::max() &&
1287 "one of the above should have had a value");
1288
1289 // In a compound assignment, the left operand does not change type, the right
1290 // operand is converted to the type of the left operand.
1291 if (IsCompAssign && LSz != EndSz) {
1292 Diag(LHS.get()->getBeginLoc(),
1293 diag::err_hlsl_vector_compound_assignment_truncation)
1294 << LHSType << RHSType;
1295 return QualType();
1296 }
1297
1298 if (RVecTy && RVecTy->getNumElements() > EndSz)
1299 castVector<CK_HLSLVectorTruncation>(SemaRef, RHS, RHSType, EndSz);
1300 if (!IsCompAssign && LVecTy && LVecTy->getNumElements() > EndSz)
1301 castVector<CK_HLSLVectorTruncation>(SemaRef, LHS, LHSType, EndSz);
1302
1303 if (!RVecTy)
1304 castVector<CK_VectorSplat>(SemaRef, RHS, RHSType, EndSz);
1305 if (!IsCompAssign && !LVecTy)
1306 castVector<CK_VectorSplat>(SemaRef, LHS, LHSType, EndSz);
1307
1308 // If we're at the same type after resizing we can stop here.
1309 if (Ctx.hasSameUnqualifiedType(LHSType, RHSType))
1310 return Ctx.getCommonSugaredType(LHSType, RHSType);
1311
1312 QualType LElTy = LHSType->castAs<VectorType>()->getElementType();
1313 QualType RElTy = RHSType->castAs<VectorType>()->getElementType();
1314
1315 // Handle conversion for floating point vectors.
1316 if (LElTy->isRealFloatingType() || RElTy->isRealFloatingType())
1317 return handleFloatVectorBinOpConversion(SemaRef, LHS, RHS, LHSType, RHSType,
1318 LElTy, RElTy, IsCompAssign);
1319
1320 assert(LElTy->isIntegralType(Ctx) && RElTy->isIntegralType(Ctx) &&
1321 "HLSL Vectors can only contain integer or floating point types");
1322 return handleIntegerVectorBinOpConversion(SemaRef, LHS, RHS, LHSType, RHSType,
1323 LElTy, RElTy, IsCompAssign);
1324}
1325
1327 BinaryOperatorKind Opc) {
1328 assert((Opc == BO_LOr || Opc == BO_LAnd) &&
1329 "Called with non-logical operator");
1331 llvm::raw_svector_ostream OS(Buff);
1332 PrintingPolicy PP(SemaRef.getLangOpts());
1333 StringRef NewFnName = Opc == BO_LOr ? "or" : "and";
1334 OS << NewFnName << "(";
1335 LHS->printPretty(OS, nullptr, PP);
1336 OS << ", ";
1337 RHS->printPretty(OS, nullptr, PP);
1338 OS << ")";
1339 SourceRange FullRange = SourceRange(LHS->getBeginLoc(), RHS->getEndLoc());
1340 SemaRef.Diag(LHS->getBeginLoc(), diag::note_function_suggestion)
1341 << NewFnName << FixItHint::CreateReplacement(FullRange, OS.str());
1342}
1343
1344std::pair<IdentifierInfo *, bool>
1346 llvm::hash_code Hash = llvm::hash_value(Signature);
1347 std::string IdStr = "__hlsl_rootsig_decl_" + std::to_string(Hash);
1348 IdentifierInfo *DeclIdent = &(getASTContext().Idents.get(IdStr));
1349
1350 // Check if we have already found a decl of the same name.
1351 LookupResult R(SemaRef, DeclIdent, SourceLocation(),
1353 bool Found = SemaRef.LookupQualifiedName(R, SemaRef.CurContext);
1354 return {DeclIdent, Found};
1355}
1356
1358 SourceLocation Loc, IdentifierInfo *DeclIdent,
1360
1361 if (handleRootSignatureElements(RootElements))
1362 return;
1363
1365 for (auto &RootSigElement : RootElements)
1366 Elements.push_back(RootSigElement.getElement());
1367
1368 auto *SignatureDecl = HLSLRootSignatureDecl::Create(
1369 SemaRef.getASTContext(), /*DeclContext=*/SemaRef.CurContext, Loc,
1370 DeclIdent, SemaRef.getLangOpts().HLSLRootSigVer, Elements);
1371
1372 SignatureDecl->setImplicit();
1373 SemaRef.PushOnScopeChains(SignatureDecl, SemaRef.getCurScope());
1374}
1375
1378 if (RootSigOverrideIdent) {
1379 LookupResult R(SemaRef, RootSigOverrideIdent, SourceLocation(),
1381 if (SemaRef.LookupQualifiedName(R, DC))
1382 return dyn_cast<HLSLRootSignatureDecl>(R.getFoundDecl());
1383 }
1384
1385 return nullptr;
1386}
1387
1388namespace {
1389
1390struct PerVisibilityBindingChecker {
1391 SemaHLSL *S;
1392 // We need one builder per `llvm::dxbc::ShaderVisibility` value.
1393 std::array<llvm::hlsl::BindingInfoBuilder, 8> Builders;
1394
1395 struct ElemInfo {
1396 const hlsl::RootSignatureElement *Elem;
1397 llvm::dxbc::ShaderVisibility Vis;
1398 bool Diagnosed;
1399 };
1400 llvm::SmallVector<ElemInfo> ElemInfoMap;
1401
1402 PerVisibilityBindingChecker(SemaHLSL *S) : S(S) {}
1403
1404 void trackBinding(llvm::dxbc::ShaderVisibility Visibility,
1405 llvm::dxil::ResourceClass RC, uint32_t Space,
1406 uint32_t LowerBound, uint32_t UpperBound,
1407 const hlsl::RootSignatureElement *Elem) {
1408 uint32_t BuilderIndex = llvm::to_underlying(Visibility);
1409 assert(BuilderIndex < Builders.size() &&
1410 "Not enough builders for visibility type");
1411 Builders[BuilderIndex].trackBinding(RC, Space, LowerBound, UpperBound,
1412 static_cast<const void *>(Elem));
1413
1414 static_assert(llvm::to_underlying(llvm::dxbc::ShaderVisibility::All) == 0,
1415 "'All' visibility must come first");
1416 if (Visibility == llvm::dxbc::ShaderVisibility::All)
1417 for (size_t I = 1, E = Builders.size(); I < E; ++I)
1418 Builders[I].trackBinding(RC, Space, LowerBound, UpperBound,
1419 static_cast<const void *>(Elem));
1420
1421 ElemInfoMap.push_back({Elem, Visibility, false});
1422 }
1423
1424 ElemInfo &getInfo(const hlsl::RootSignatureElement *Elem) {
1425 auto It = llvm::lower_bound(
1426 ElemInfoMap, Elem,
1427 [](const auto &LHS, const auto &RHS) { return LHS.Elem < RHS; });
1428 assert(It->Elem == Elem && "Element not in map");
1429 return *It;
1430 }
1431
1432 bool checkOverlap() {
1433 llvm::sort(ElemInfoMap, [](const auto &LHS, const auto &RHS) {
1434 return LHS.Elem < RHS.Elem;
1435 });
1436
1437 bool HadOverlap = false;
1438
1439 using llvm::hlsl::BindingInfoBuilder;
1440 auto ReportOverlap = [this,
1441 &HadOverlap](const BindingInfoBuilder &Builder,
1442 const llvm::hlsl::Binding &Reported) {
1443 HadOverlap = true;
1444
1445 const auto *Elem =
1446 static_cast<const hlsl::RootSignatureElement *>(Reported.Cookie);
1447 const llvm::hlsl::Binding &Previous = Builder.findOverlapping(Reported);
1448 const auto *PrevElem =
1449 static_cast<const hlsl::RootSignatureElement *>(Previous.Cookie);
1450
1451 ElemInfo &Info = getInfo(Elem);
1452 // We will have already diagnosed this binding if there's overlap in the
1453 // "All" visibility as well as any particular visibility.
1454 if (Info.Diagnosed)
1455 return;
1456 Info.Diagnosed = true;
1457
1458 ElemInfo &PrevInfo = getInfo(PrevElem);
1459 llvm::dxbc::ShaderVisibility CommonVis =
1460 Info.Vis == llvm::dxbc::ShaderVisibility::All ? PrevInfo.Vis
1461 : Info.Vis;
1462
1463 this->S->Diag(Elem->getLocation(), diag::err_hlsl_resource_range_overlap)
1464 << llvm::to_underlying(Reported.RC) << Reported.LowerBound
1465 << Reported.isUnbounded() << Reported.UpperBound
1466 << llvm::to_underlying(Previous.RC) << Previous.LowerBound
1467 << Previous.isUnbounded() << Previous.UpperBound << Reported.Space
1468 << CommonVis;
1469
1470 this->S->Diag(PrevElem->getLocation(),
1471 diag::note_hlsl_resource_range_here);
1472 };
1473
1474 for (BindingInfoBuilder &Builder : Builders)
1475 Builder.calculateBindingInfo(ReportOverlap);
1476
1477 return HadOverlap;
1478 }
1479};
1480
1481static CXXMethodDecl *lookupMethod(Sema &S, CXXRecordDecl *RecordDecl,
1482 StringRef Name, SourceLocation Loc) {
1483 DeclarationName DeclName(&S.getASTContext().Idents.get(Name));
1484 LookupResult Result(S, DeclName, Loc, Sema::LookupMemberName);
1485 if (!S.LookupQualifiedName(Result, static_cast<DeclContext *>(RecordDecl)))
1486 return nullptr;
1487 return cast<CXXMethodDecl>(Result.getFoundDecl());
1488}
1489
1490} // end anonymous namespace
1491
1494 // Define some common error handling functions
1495 bool HadError = false;
1496 auto ReportError = [this, &HadError](SourceLocation Loc, uint32_t LowerBound,
1497 uint32_t UpperBound) {
1498 HadError = true;
1499 this->Diag(Loc, diag::err_hlsl_invalid_rootsig_value)
1500 << LowerBound << UpperBound;
1501 };
1502
1503 auto ReportFloatError = [this, &HadError](SourceLocation Loc,
1504 float LowerBound,
1505 float UpperBound) {
1506 HadError = true;
1507 this->Diag(Loc, diag::err_hlsl_invalid_rootsig_value)
1508 << llvm::formatv("{0:f}", LowerBound).sstr<6>()
1509 << llvm::formatv("{0:f}", UpperBound).sstr<6>();
1510 };
1511
1512 auto VerifyRegister = [ReportError](SourceLocation Loc, uint32_t Register) {
1513 if (!llvm::hlsl::rootsig::verifyRegisterValue(Register))
1514 ReportError(Loc, 0, 0xfffffffe);
1515 };
1516
1517 auto VerifySpace = [ReportError](SourceLocation Loc, uint32_t Space) {
1518 if (!llvm::hlsl::rootsig::verifyRegisterSpace(Space))
1519 ReportError(Loc, 0, 0xffffffef);
1520 };
1521
1522 const uint32_t Version =
1523 llvm::to_underlying(SemaRef.getLangOpts().HLSLRootSigVer);
1524 const uint32_t VersionEnum = Version - 1;
1525 auto ReportFlagError = [this, &HadError, VersionEnum](SourceLocation Loc) {
1526 HadError = true;
1527 this->Diag(Loc, diag::err_hlsl_invalid_rootsig_flag)
1528 << /*version minor*/ VersionEnum;
1529 };
1530
1531 // Iterate through the elements and do basic validations
1532 for (const hlsl::RootSignatureElement &RootSigElem : Elements) {
1533 SourceLocation Loc = RootSigElem.getLocation();
1534 const llvm::hlsl::rootsig::RootElement &Elem = RootSigElem.getElement();
1535 if (const auto *Descriptor =
1536 std::get_if<llvm::hlsl::rootsig::RootDescriptor>(&Elem)) {
1537 VerifyRegister(Loc, Descriptor->Reg.Number);
1538 VerifySpace(Loc, Descriptor->Space);
1539
1540 if (!llvm::hlsl::rootsig::verifyRootDescriptorFlag(Version,
1541 Descriptor->Flags))
1542 ReportFlagError(Loc);
1543 } else if (const auto *Constants =
1544 std::get_if<llvm::hlsl::rootsig::RootConstants>(&Elem)) {
1545 VerifyRegister(Loc, Constants->Reg.Number);
1546 VerifySpace(Loc, Constants->Space);
1547 } else if (const auto *Sampler =
1548 std::get_if<llvm::hlsl::rootsig::StaticSampler>(&Elem)) {
1549 VerifyRegister(Loc, Sampler->Reg.Number);
1550 VerifySpace(Loc, Sampler->Space);
1551
1552 assert(!std::isnan(Sampler->MaxLOD) && !std::isnan(Sampler->MinLOD) &&
1553 "By construction, parseFloatParam can't produce a NaN from a "
1554 "float_literal token");
1555
1556 if (!llvm::hlsl::rootsig::verifyMaxAnisotropy(Sampler->MaxAnisotropy))
1557 ReportError(Loc, 0, 16);
1558 if (!llvm::hlsl::rootsig::verifyMipLODBias(Sampler->MipLODBias))
1559 ReportFloatError(Loc, -16.f, 15.99f);
1560 } else if (const auto *Clause =
1561 std::get_if<llvm::hlsl::rootsig::DescriptorTableClause>(
1562 &Elem)) {
1563 VerifyRegister(Loc, Clause->Reg.Number);
1564 VerifySpace(Loc, Clause->Space);
1565
1566 if (!llvm::hlsl::rootsig::verifyNumDescriptors(Clause->NumDescriptors)) {
1567 // NumDescriptor could techincally be ~0u but that is reserved for
1568 // unbounded, so the diagnostic will not report that as a valid int
1569 // value
1570 ReportError(Loc, 1, 0xfffffffe);
1571 }
1572
1573 if (!llvm::hlsl::rootsig::verifyDescriptorRangeFlag(Version, Clause->Type,
1574 Clause->Flags))
1575 ReportFlagError(Loc);
1576 }
1577 }
1578
1579 PerVisibilityBindingChecker BindingChecker(this);
1580 SmallVector<std::pair<const llvm::hlsl::rootsig::DescriptorTableClause *,
1582 UnboundClauses;
1583
1584 for (const hlsl::RootSignatureElement &RootSigElem : Elements) {
1585 const llvm::hlsl::rootsig::RootElement &Elem = RootSigElem.getElement();
1586 if (const auto *Descriptor =
1587 std::get_if<llvm::hlsl::rootsig::RootDescriptor>(&Elem)) {
1588 uint32_t LowerBound(Descriptor->Reg.Number);
1589 uint32_t UpperBound(LowerBound); // inclusive range
1590
1591 BindingChecker.trackBinding(
1592 Descriptor->Visibility,
1593 static_cast<llvm::dxil::ResourceClass>(Descriptor->Type),
1594 Descriptor->Space, LowerBound, UpperBound, &RootSigElem);
1595 } else if (const auto *Constants =
1596 std::get_if<llvm::hlsl::rootsig::RootConstants>(&Elem)) {
1597 uint32_t LowerBound(Constants->Reg.Number);
1598 uint32_t UpperBound(LowerBound); // inclusive range
1599
1600 BindingChecker.trackBinding(
1601 Constants->Visibility, llvm::dxil::ResourceClass::CBuffer,
1602 Constants->Space, LowerBound, UpperBound, &RootSigElem);
1603 } else if (const auto *Sampler =
1604 std::get_if<llvm::hlsl::rootsig::StaticSampler>(&Elem)) {
1605 uint32_t LowerBound(Sampler->Reg.Number);
1606 uint32_t UpperBound(LowerBound); // inclusive range
1607
1608 BindingChecker.trackBinding(
1609 Sampler->Visibility, llvm::dxil::ResourceClass::Sampler,
1610 Sampler->Space, LowerBound, UpperBound, &RootSigElem);
1611 } else if (const auto *Clause =
1612 std::get_if<llvm::hlsl::rootsig::DescriptorTableClause>(
1613 &Elem)) {
1614 // We'll process these once we see the table element.
1615 UnboundClauses.emplace_back(Clause, &RootSigElem);
1616 } else if (const auto *Table =
1617 std::get_if<llvm::hlsl::rootsig::DescriptorTable>(&Elem)) {
1618 assert(UnboundClauses.size() == Table->NumClauses &&
1619 "Number of unbound elements must match the number of clauses");
1620 bool HasAnySampler = false;
1621 bool HasAnyNonSampler = false;
1622 uint64_t Offset = 0;
1623 bool IsPrevUnbound = false;
1624 for (const auto &[Clause, ClauseElem] : UnboundClauses) {
1625 SourceLocation Loc = ClauseElem->getLocation();
1626 if (Clause->Type == llvm::dxil::ResourceClass::Sampler)
1627 HasAnySampler = true;
1628 else
1629 HasAnyNonSampler = true;
1630
1631 if (HasAnySampler && HasAnyNonSampler)
1632 Diag(Loc, diag::err_hlsl_invalid_mixed_resources);
1633
1634 // Relevant error will have already been reported above and needs to be
1635 // fixed before we can conduct further analysis, so shortcut error
1636 // return
1637 if (Clause->NumDescriptors == 0)
1638 return true;
1639
1640 bool IsAppending =
1641 Clause->Offset == llvm::hlsl::rootsig::DescriptorTableOffsetAppend;
1642 if (!IsAppending)
1643 Offset = Clause->Offset;
1644
1645 uint64_t RangeBound = llvm::hlsl::rootsig::computeRangeBound(
1646 Offset, Clause->NumDescriptors);
1647
1648 if (IsPrevUnbound && IsAppending)
1649 Diag(Loc, diag::err_hlsl_appending_onto_unbound);
1650 else if (!llvm::hlsl::rootsig::verifyNoOverflowedOffset(RangeBound))
1651 Diag(Loc, diag::err_hlsl_offset_overflow) << Offset << RangeBound;
1652
1653 // Update offset to be 1 past this range's bound
1654 Offset = RangeBound + 1;
1655 IsPrevUnbound = Clause->NumDescriptors ==
1656 llvm::hlsl::rootsig::NumDescriptorsUnbounded;
1657
1658 // Compute the register bounds and track resource binding
1659 uint32_t LowerBound(Clause->Reg.Number);
1660 uint32_t UpperBound = llvm::hlsl::rootsig::computeRangeBound(
1661 LowerBound, Clause->NumDescriptors);
1662
1663 BindingChecker.trackBinding(
1664 Table->Visibility,
1665 static_cast<llvm::dxil::ResourceClass>(Clause->Type), Clause->Space,
1666 LowerBound, UpperBound, ClauseElem);
1667 }
1668 UnboundClauses.clear();
1669 }
1670 }
1671
1672 return BindingChecker.checkOverlap();
1673}
1674
1676 if (AL.getNumArgs() != 1) {
1677 Diag(AL.getLoc(), diag::err_attribute_wrong_number_arguments) << AL << 1;
1678 return;
1679 }
1680
1682 if (auto *RS = D->getAttr<RootSignatureAttr>()) {
1683 if (RS->getSignatureIdent() != Ident) {
1684 Diag(AL.getLoc(), diag::err_disallowed_duplicate_attribute) << RS;
1685 return;
1686 }
1687
1688 Diag(AL.getLoc(), diag::warn_duplicate_attribute_exact) << RS;
1689 return;
1690 }
1691
1693 if (SemaRef.LookupQualifiedName(R, D->getDeclContext()))
1694 if (auto *SignatureDecl =
1695 dyn_cast<HLSLRootSignatureDecl>(R.getFoundDecl())) {
1696 D->addAttr(::new (getASTContext()) RootSignatureAttr(
1697 getASTContext(), AL, Ident, SignatureDecl));
1698 }
1699}
1700
1702 llvm::VersionTuple SMVersion =
1703 getASTContext().getTargetInfo().getTriple().getOSVersion();
1704 bool IsDXIL = getASTContext().getTargetInfo().getTriple().getArch() ==
1705 llvm::Triple::dxil;
1706
1707 uint32_t ZMax = 1024;
1708 uint32_t ThreadMax = 1024;
1709 if (IsDXIL && SMVersion.getMajor() <= 4) {
1710 ZMax = 1;
1711 ThreadMax = 768;
1712 } else if (IsDXIL && SMVersion.getMajor() == 5) {
1713 ZMax = 64;
1714 ThreadMax = 1024;
1715 }
1716
1717 uint32_t X;
1718 if (!SemaRef.checkUInt32Argument(AL, AL.getArgAsExpr(0), X))
1719 return;
1720 if (X > 1024) {
1721 Diag(AL.getArgAsExpr(0)->getExprLoc(),
1722 diag::err_hlsl_numthreads_argument_oor)
1723 << 0 << 1024;
1724 return;
1725 }
1726 uint32_t Y;
1727 if (!SemaRef.checkUInt32Argument(AL, AL.getArgAsExpr(1), Y))
1728 return;
1729 if (Y > 1024) {
1730 Diag(AL.getArgAsExpr(1)->getExprLoc(),
1731 diag::err_hlsl_numthreads_argument_oor)
1732 << 1 << 1024;
1733 return;
1734 }
1735 uint32_t Z;
1736 if (!SemaRef.checkUInt32Argument(AL, AL.getArgAsExpr(2), Z))
1737 return;
1738 if (Z > ZMax) {
1739 SemaRef.Diag(AL.getArgAsExpr(2)->getExprLoc(),
1740 diag::err_hlsl_numthreads_argument_oor)
1741 << 2 << ZMax;
1742 return;
1743 }
1744
1745 if (X * Y * Z > ThreadMax) {
1746 Diag(AL.getLoc(), diag::err_hlsl_numthreads_invalid) << ThreadMax;
1747 return;
1748 }
1749
1750 HLSLNumThreadsAttr *NewAttr = mergeNumThreadsAttr(D, AL, X, Y, Z);
1751 if (NewAttr)
1752 D->addAttr(NewAttr);
1753}
1754
/// Returns true if \p Value is a legal wave size: a power of two in the
/// inclusive range [4, 128].
static bool isValidWaveSizeValue(unsigned Value) {
  // Range check first; inside [4, 128] (which excludes 0) a value is a power
  // of two exactly when it has a single bit set.
  if (Value < 4 || Value > 128)
    return false;
  return (Value & (Value - 1)) == 0;
}
1758
1760 // validate that the wavesize argument is a power of 2 between 4 and 128
1761 // inclusive
1762 unsigned SpelledArgsCount = AL.getNumArgs();
1763 if (SpelledArgsCount == 0 || SpelledArgsCount > 3)
1764 return;
1765
1766 uint32_t Min;
1767 if (!SemaRef.checkUInt32Argument(AL, AL.getArgAsExpr(0), Min))
1768 return;
1769
1770 uint32_t Max = 0;
1771 if (SpelledArgsCount > 1 &&
1772 !SemaRef.checkUInt32Argument(AL, AL.getArgAsExpr(1), Max))
1773 return;
1774
1775 uint32_t Preferred = 0;
1776 if (SpelledArgsCount > 2 &&
1777 !SemaRef.checkUInt32Argument(AL, AL.getArgAsExpr(2), Preferred))
1778 return;
1779
1780 if (SpelledArgsCount > 2) {
1781 if (!isValidWaveSizeValue(Preferred)) {
1782 Diag(AL.getArgAsExpr(2)->getExprLoc(),
1783 diag::err_attribute_power_of_two_in_range)
1784 << AL << llvm::dxil::MinWaveSize << llvm::dxil::MaxWaveSize
1785 << Preferred;
1786 return;
1787 }
1788 // Preferred not in range.
1789 if (Preferred < Min || Preferred > Max) {
1790 Diag(AL.getArgAsExpr(2)->getExprLoc(),
1791 diag::err_attribute_power_of_two_in_range)
1792 << AL << Min << Max << Preferred;
1793 return;
1794 }
1795 } else if (SpelledArgsCount > 1) {
1796 if (!isValidWaveSizeValue(Max)) {
1797 Diag(AL.getArgAsExpr(1)->getExprLoc(),
1798 diag::err_attribute_power_of_two_in_range)
1799 << AL << llvm::dxil::MinWaveSize << llvm::dxil::MaxWaveSize << Max;
1800 return;
1801 }
1802 if (Max < Min) {
1803 Diag(AL.getLoc(), diag::err_attribute_argument_invalid) << AL << 1;
1804 return;
1805 } else if (Max == Min) {
1806 Diag(AL.getLoc(), diag::warn_attr_min_eq_max) << AL;
1807 }
1808 } else {
1809 if (!isValidWaveSizeValue(Min)) {
1810 Diag(AL.getArgAsExpr(0)->getExprLoc(),
1811 diag::err_attribute_power_of_two_in_range)
1812 << AL << llvm::dxil::MinWaveSize << llvm::dxil::MaxWaveSize << Min;
1813 return;
1814 }
1815 }
1816
1817 HLSLWaveSizeAttr *NewAttr =
1818 mergeWaveSizeAttr(D, AL, Min, Max, Preferred, SpelledArgsCount);
1819 if (NewAttr)
1820 D->addAttr(NewAttr);
1821}
1822
1824 uint32_t ID;
1825 if (!SemaRef.checkUInt32Argument(AL, AL.getArgAsExpr(0), ID))
1826 return;
1827 D->addAttr(::new (getASTContext())
1828 HLSLVkExtBuiltinInputAttr(getASTContext(), AL, ID));
1829}
1830
1832 uint32_t ID;
1833 if (!SemaRef.checkUInt32Argument(AL, AL.getArgAsExpr(0), ID))
1834 return;
1835 D->addAttr(::new (getASTContext())
1836 HLSLVkExtBuiltinOutputAttr(getASTContext(), AL, ID));
1837}
1838
1840 D->addAttr(::new (getASTContext())
1841 HLSLVkPushConstantAttr(getASTContext(), AL));
1842}
1843
1845 uint32_t Id;
1846 if (!SemaRef.checkUInt32Argument(AL, AL.getArgAsExpr(0), Id))
1847 return;
1848 HLSLVkConstantIdAttr *NewAttr = mergeVkConstantIdAttr(D, AL, Id);
1849 if (NewAttr)
1850 D->addAttr(NewAttr);
1851}
1852
1854 uint32_t Binding = 0;
1855 if (!SemaRef.checkUInt32Argument(AL, AL.getArgAsExpr(0), Binding))
1856 return;
1857 uint32_t Set = 0;
1858 if (AL.getNumArgs() > 1 &&
1859 !SemaRef.checkUInt32Argument(AL, AL.getArgAsExpr(1), Set))
1860 return;
1861
1862 D->addAttr(::new (getASTContext())
1863 HLSLVkBindingAttr(getASTContext(), AL, Binding, Set));
1864}
1865
1867 uint32_t Location;
1868 if (!SemaRef.checkUInt32Argument(AL, AL.getArgAsExpr(0), Location))
1869 return;
1870
1871 D->addAttr(::new (getASTContext())
1872 HLSLVkLocationAttr(getASTContext(), AL, Location));
1873}
1874
1876 const auto *VT = T->getAs<VectorType>();
1877
1878 if (!T->hasUnsignedIntegerRepresentation() ||
1879 (VT && VT->getNumElements() > 3)) {
1880 Diag(AL.getLoc(), diag::err_hlsl_attr_invalid_type)
1881 << AL << "uint/uint2/uint3";
1882 return false;
1883 }
1884
1885 return true;
1886}
1887
1889 const auto *VT = T->getAs<VectorType>();
1890 if (!T->hasFloatingRepresentation() || (VT && VT->getNumElements() > 4)) {
1891 Diag(AL.getLoc(), diag::err_hlsl_attr_invalid_type)
1892 << AL << "float/float1/float2/float3/float4";
1893 return false;
1894 }
1895
1896 return true;
1897}
1898
1900 std::optional<unsigned> Index) {
  // Diagnoses a system-value (SV_*) semantic. The attribute name is matched
  // case-insensitively by upper-casing it first. The type checked is the
  // declaration's type, or the return type for a function, or the referenced
  // type for an `out` parameter.
1901 std::string SemanticName = AL.getAttrName()->getName().upper();
1902
1903 auto *VD = cast<ValueDecl>(D);
1904 QualType ValueType = VD->getType();
1905 if (auto *FD = dyn_cast<FunctionDecl>(D))
1906 ValueType = FD->getReturnType();
1907
  // `out` parameters are expected to have reference type here (the cast<>
  // asserts otherwise); the semantic applies to the referenced type.
1908 bool IsOutput = false;
1909 if (HLSLParamModifierAttr *MA = D->getAttr<HLSLParamModifierAttr>()) {
1910 if (MA->isOut()) {
1911 IsOutput = true;
1912 ValueType = cast<ReferenceType>(ValueType)->getPointeeType();
1913 }
1914 }
1915
  // The ID/index semantics below are input-only and cannot carry a semantic
  // index; violations are diagnosed but processing continues.
1916 if (SemanticName == "SV_DISPATCHTHREADID") {
1917 diagnoseInputIDType(ValueType, AL);
1918 if (IsOutput)
1919 Diag(AL.getLoc(), diag::err_hlsl_semantic_output_not_supported) << AL;
1920 if (Index.has_value())
1921 Diag(AL.getLoc(), diag::err_hlsl_semantic_indexing_not_supported) << AL;
1923 return;
1924 }
1925
1926 if (SemanticName == "SV_GROUPINDEX") {
1927 if (IsOutput)
1928 Diag(AL.getLoc(), diag::err_hlsl_semantic_output_not_supported) << AL;
1929 if (Index.has_value())
1930 Diag(AL.getLoc(), diag::err_hlsl_semantic_indexing_not_supported) << AL;
1932 return;
1933 }
1934
1935 if (SemanticName == "SV_GROUPTHREADID") {
1936 diagnoseInputIDType(ValueType, AL);
1937 if (IsOutput)
1938 Diag(AL.getLoc(), diag::err_hlsl_semantic_output_not_supported) << AL;
1939 if (Index.has_value())
1940 Diag(AL.getLoc(), diag::err_hlsl_semantic_indexing_not_supported) << AL;
1942 return;
1943 }
1944
1945 if (SemanticName == "SV_GROUPID") {
1946 diagnoseInputIDType(ValueType, AL);
1947 if (IsOutput)
1948 Diag(AL.getLoc(), diag::err_hlsl_semantic_output_not_supported) << AL;
1949 if (Index.has_value())
1950 Diag(AL.getLoc(), diag::err_hlsl_semantic_indexing_not_supported) << AL;
1952 return;
1953 }
1954
  // NOTE(review): the SV_POSITION and SV_TARGET checks repeat the same
  // float/float1..float4 validation inline; the float-type helper defined
  // just before this function performs that exact check and could be reused.
1955 if (SemanticName == "SV_POSITION") {
1956 const auto *VT = ValueType->getAs<VectorType>();
1957 if (!ValueType->hasFloatingRepresentation() ||
1958 (VT && VT->getNumElements() > 4))
1959 Diag(AL.getLoc(), diag::err_hlsl_attr_invalid_type)
1960 << AL << "float/float1/float2/float3/float4";
1962 return;
1963 }
1964
  // SV_VertexID must be a 32-bit unsigned integer.
1965 if (SemanticName == "SV_VERTEXID") {
1966 uint64_t SizeInBits = SemaRef.Context.getTypeSize(ValueType);
1967 if (!ValueType->isUnsignedIntegerType() || SizeInBits != 32)
1968 Diag(AL.getLoc(), diag::err_hlsl_attr_invalid_type) << AL << "uint";
1970 return;
1971 }
1972
1973 if (SemanticName == "SV_TARGET") {
1974 const auto *VT = ValueType->getAs<VectorType>();
1975 if (!ValueType->hasFloatingRepresentation() ||
1976 (VT && VT->getNumElements() > 4))
1977 Diag(AL.getLoc(), diag::err_hlsl_attr_invalid_type)
1978 << AL << "float/float1/float2/float3/float4";
1980 return;
1981 }
1982
  // Any other SV_ name is not a recognized system-value semantic.
1983 Diag(AL.getLoc(), diag::err_hlsl_unknown_semantic) << AL;
1984}
1985
  // Argument 0 is the semantic index value; argument 1 is nonzero when the
  // index was written explicitly. Both are expected to be present (asserted).
  // NOTE(review): without assertions, a failed argument check falls through
  // with the values still at their zero initializers (unless written).
1987 uint32_t IndexValue(0), ExplicitIndex(0);
1988 if (!SemaRef.checkUInt32Argument(AL, AL.getArgAsExpr(0), IndexValue) ||
1989 !SemaRef.checkUInt32Argument(AL, AL.getArgAsExpr(1), ExplicitIndex)) {
1990 assert(0 && "HLSLUnparsedSemantic is expected to have 2 int arguments.");
1991 }
  // A nonzero index can only come from an explicit index.
1992 assert(IndexValue > 0 ? ExplicitIndex : true);
1993 std::optional<unsigned> Index =
1994 ExplicitIndex ? std::optional<unsigned>(IndexValue) : std::nullopt;
1995
  // Names starting with "SV_" (any letter case) are system-value semantics.
1996 if (AL.getAttrName()->getName().starts_with_insensitive("SV_"))
1997 diagnoseSystemSemanticAttr(D, AL, Index);
1998 else
2000}
2001
2004 Diag(AL.getLoc(), diag::err_hlsl_attr_invalid_ast_node)
2005 << AL << "shader constant in a constant buffer";
2006 return;
2007 }
2008
  // Argument 0 is the register (SubComponent); argument 1 is the component
  // within that register, in 32-bit units.
2009 uint32_t SubComponent;
2010 if (!SemaRef.checkUInt32Argument(AL, AL.getArgAsExpr(0), SubComponent))
2011 return;
2012 uint32_t Component;
2013 if (!SemaRef.checkUInt32Argument(AL, AL.getArgAsExpr(1), Component))
2014 return;
2015
2016 QualType T = cast<VarDecl>(D)->getType().getCanonicalType();
2017 // Check if T is an array or struct type.
2018 // TODO: mark matrix type as aggregate type.
2019 bool IsAggregateTy = (T->isArrayType() || T->isStructureType());
2020
2021 // Check Component is valid for T.
2022 if (Component) {
2023 unsigned Size = getASTContext().getTypeSize(T);
2024 if (IsAggregateTy) {
2025 Diag(AL.getLoc(), diag::err_hlsl_invalid_register_or_packoffset);
2026 return;
2027 } else {
2028 // Make sure the value ends within one 128-bit register: Component is in
  // 32-bit units and Size (from getTypeSize) is in bits.
2029 if ((Component * 32 + Size) > 128) {
2030 Diag(AL.getLoc(), diag::err_hlsl_packoffset_cross_reg_boundary);
2031 return;
2032 }
2033 QualType EltTy = T;
2034 if (const auto *VT = T->getAs<VectorType>())
2035 EltTy = VT->getElementType();
2036 unsigned Align = getASTContext().getTypeAlign(EltTy);
2037 if (Align > 32 && Component == 1) {
2038 // NOTE: Component 3 will hit err_hlsl_packoffset_cross_reg_boundary.
2039 // So we only need to check Component 1 here.
2040 Diag(AL.getLoc(), diag::err_hlsl_packoffset_alignment_mismatch)
2041 << Align << EltTy;
2042 return;
2043 }
2044 }
2045 }
2046
2047 D->addAttr(::new (getASTContext()) HLSLPackOffsetAttr(
2048 getASTContext(), AL, SubComponent, Component));
2049}
2050
  // The argument is a string literal naming the shader stage. Names that do
  // not convert to an environment type are diagnosed with
  // warn_attribute_type_not_supported and the attribute is dropped.
2052 StringRef Str;
2053 SourceLocation ArgLoc;
2054 if (!SemaRef.checkStringLiteralArgumentAttr(AL, 0, Str, &ArgLoc))
2055 return;
2056
2057 llvm::Triple::EnvironmentType ShaderType;
2058 if (!HLSLShaderAttr::ConvertStrToEnvironmentType(Str, ShaderType)) {
2059 Diag(AL.getLoc(), diag::warn_attribute_type_not_supported)
2060 << AL << Str << ArgLoc;
2061 return;
2062 }
2063
2064 // FIXME: check function match the shader stage.
2065
  // A null result from mergeShaderAttr means no new attribute is added.
2066 HLSLShaderAttr *NewAttr = mergeShaderAttr(D, AL, ShaderType);
2067 if (NewAttr)
2068 D->addAttr(NewAttr);
2069}
2070
2072 Sema &S, QualType Wrapped, ArrayRef<const Attr *> AttrList,
2073 QualType &ResType, HLSLAttributedResourceLocInfo *LocInfo) {
  // Folds the resource type attributes in AttrList into one
  // HLSLAttributedResourceType::Attributes set. A repeated attribute is
  // diagnosed (warn_duplicate_attribute_exact when identical, otherwise
  // warn_duplicate_attribute) and the function returns false. A missing
  // resource class is an error. On success the attributed type is built from
  // Wrapped, ContainedTy and ResAttrs (presumably into ResType), and, when
  // both LocInfo and a contained-type TypeSourceInfo exist, LocInfo receives
  // the source range spanning the attributes.
2074 assert(AttrList.size() && "expected list of resource attributes");
2075
2076 QualType ContainedTy = QualType();
2077 TypeSourceInfo *ContainedTyInfo = nullptr;
2078 SourceLocation LocBegin = AttrList[0]->getRange().getBegin();
2079 SourceLocation LocEnd = AttrList[0]->getRange().getEnd();
2080
2081 HLSLAttributedResourceType::Attributes ResAttrs;
2082
2083 bool HasResourceClass = false;
2084 bool HasResourceDimension = false;
2085 for (const Attr *A : AttrList) {
2086 if (!A)
2087 continue;
2088 LocEnd = A->getRange().getEnd();
2089 switch (A->getKind()) {
2090 case attr::HLSLResourceClass: {
2091 ResourceClass RC = cast<HLSLResourceClassAttr>(A)->getResourceClass();
2092 if (HasResourceClass) {
2093 S.Diag(A->getLocation(), ResAttrs.ResourceClass == RC
2094 ? diag::warn_duplicate_attribute_exact
2095 : diag::warn_duplicate_attribute)
2096 << A;
2097 return false;
2098 }
2099 ResAttrs.ResourceClass = RC;
2100 HasResourceClass = true;
2101 break;
2102 }
2103 case attr::HLSLResourceDimension: {
2104 llvm::dxil::ResourceDimension RD =
2105 cast<HLSLResourceDimensionAttr>(A)->getDimension();
2106 if (HasResourceDimension) {
2107 S.Diag(A->getLocation(), ResAttrs.ResourceDimension == RD
2108 ? diag::warn_duplicate_attribute_exact
2109 : diag::warn_duplicate_attribute)
2110 << A;
2111 return false;
2112 }
2113 ResAttrs.ResourceDimension = RD;
2114 HasResourceDimension = true;
2115 break;
2116 }
2117 case attr::HLSLROV:
2118 if (ResAttrs.IsROV) {
2119 S.Diag(A->getLocation(), diag::warn_duplicate_attribute_exact) << A;
2120 return false;
2121 }
2122 ResAttrs.IsROV = true;
2123 break;
2124 case attr::HLSLRawBuffer:
2125 if (ResAttrs.RawBuffer) {
2126 S.Diag(A->getLocation(), diag::warn_duplicate_attribute_exact) << A;
2127 return false;
2128 }
2129 ResAttrs.RawBuffer = true;
2130 break;
2131 case attr::HLSLIsCounter:
2132 if (ResAttrs.IsCounter) {
2133 S.Diag(A->getLocation(), diag::warn_duplicate_attribute_exact) << A;
2134 return false;
2135 }
2136 ResAttrs.IsCounter = true;
2137 break;
2138 case attr::HLSLContainedType: {
2139 const HLSLContainedTypeAttr *CTAttr = cast<HLSLContainedTypeAttr>(A);
2140 QualType Ty = CTAttr->getType();
2141 if (!ContainedTy.isNull()) {
2142 S.Diag(A->getLocation(), ContainedTy == Ty
2143 ? diag::warn_duplicate_attribute_exact
2144 : diag::warn_duplicate_attribute)
2145 << A;
2146 return false;
2147 }
2148 ContainedTy = Ty;
2149 ContainedTyInfo = CTAttr->getTypeLoc();
2150 break;
2151 }
2152 default:
2153 llvm_unreachable("unhandled resource attribute type");
2154 }
2155 }
2156
  // A resource class attribute is mandatory.
  // NOTE(review): the loop above tolerates null entries, but AttrList[0] and
  // AttrList.back() are dereferenced without a null check; confirm callers
  // never pass a null first/last element.
2157 if (!HasResourceClass) {
2158 S.Diag(AttrList.back()->getRange().getEnd(),
2159 diag::err_hlsl_missing_resource_class);
2160 return false;
2161 }
2162
2164 Wrapped, ContainedTy, ResAttrs);
2165
2166 if (LocInfo && ContainedTyInfo) {
2167 LocInfo->Range = SourceRange(LocBegin, LocEnd);
2168 LocInfo->ContainedTyInfo = ContainedTyInfo;
2169 }
2170 return true;
2171}
2172
2173// Validates and creates an HLSL attribute that is applied as a type attribute
2174// on an HLSL resource. The attributes are collected in HLSLResourcesTypeAttrs
2175// and at the end of the declaration they are applied to the declaration type
2176// by wrapping it in HLSLAttributedResourceType. Returns false after emitting a
// diagnostic if the attribute is not valid for the type.
2178 // only allow resource type attributes on intangible types
2179 if (!T->isHLSLResourceType()) {
2180 Diag(AL.getLoc(), diag::err_hlsl_attribute_needs_intangible_type)
2181 << AL << getASTContext().HLSLResourceTy;
2182 return false;
2183 }
2184
2185 // validate number of arguments
2186 if (!AL.checkExactlyNumArgs(SemaRef, AL.getMinArgs()))
2187 return false;
2188
2189 Attr *A = nullptr;
2190
2194 {
2195 AttributeCommonInfo::AS_CXX11, 0, false /*IsAlignas*/,
2196 false /*IsRegularKeywordAttribute*/
2197 });
2198
  // Build the attribute for this kind; invalid arguments are diagnosed and
  // reject the attribute.
2199 switch (AL.getKind()) {
2200 case ParsedAttr::AT_HLSLResourceClass: {
2201 if (!AL.isArgIdent(0)) {
2202 Diag(AL.getLoc(), diag::err_attribute_argument_type)
2203 << AL << AANT_ArgumentIdentifier;
2204 return false;
2205 }
2206
2207 IdentifierLoc *Loc = AL.getArgAsIdent(0);
2208 StringRef Identifier = Loc->getIdentifierInfo()->getName();
2209 SourceLocation ArgLoc = Loc->getLoc();
2210
2211 // Validate resource class value
2212 ResourceClass RC;
2213 if (!HLSLResourceClassAttr::ConvertStrToResourceClass(Identifier, RC)) {
2214 Diag(ArgLoc, diag::warn_attribute_type_not_supported)
2215 << "ResourceClass" << Identifier;
2216 return false;
2217 }
2218 A = HLSLResourceClassAttr::Create(getASTContext(), RC, ACI);
2219 break;
2220 }
2221
2222 case ParsedAttr::AT_HLSLResourceDimension: {
2223 StringRef Identifier;
2224 SourceLocation ArgLoc;
2225 if (!SemaRef.checkStringLiteralArgumentAttr(AL, 0, Identifier, &ArgLoc))
2226 return false;
2227
2228 // Validate resource dimension value
2229 llvm::dxil::ResourceDimension RD;
2230 if (!HLSLResourceDimensionAttr::ConvertStrToResourceDimension(Identifier,
2231 RD)) {
2232 Diag(ArgLoc, diag::warn_attribute_type_not_supported)
2233 << "ResourceDimension" << Identifier;
2234 return false;
2235 }
2236 A = HLSLResourceDimensionAttr::Create(getASTContext(), RD, ACI);
2237 break;
2238 }
2239
2240 case ParsedAttr::AT_HLSLROV:
2241 A = HLSLROVAttr::Create(getASTContext(), ACI);
2242 break;
2243
2244 case ParsedAttr::AT_HLSLRawBuffer:
2245 A = HLSLRawBufferAttr::Create(getASTContext(), ACI);
2246 break;
2247
2248 case ParsedAttr::AT_HLSLIsCounter:
2249 A = HLSLIsCounterAttr::Create(getASTContext(), ACI);
2250 break;
2251
2252 case ParsedAttr::AT_HLSLContainedType: {
2253 if (AL.getNumArgs() != 1 && !AL.hasParsedType()) {
2254 Diag(AL.getLoc(), diag::err_attribute_wrong_number_arguments) << AL << 1;
2255 return false;
2256 }
2257
  // The contained type must be complete.
2258 TypeSourceInfo *TSI = nullptr;
2259 QualType QT = SemaRef.GetTypeFromParser(AL.getTypeArg(), &TSI);
2260 assert(TSI && "no type source info for attribute argument");
2261 if (SemaRef.RequireCompleteType(TSI->getTypeLoc().getBeginLoc(), QT,
2262 diag::err_incomplete_type))
2263 return false;
2264 A = HLSLContainedTypeAttr::Create(getASTContext(), TSI, ACI);
2265 break;
2266 }
2267
2268 default:
2269 llvm_unreachable("unhandled HLSL attribute");
2270 }
2271
  // Defer applying the attribute until the whole declaration has been parsed.
2272 HLSLResourcesTypeAttrs.emplace_back(A);
2273 return true;
2274}
2275
2276// Combines all collected resource type attributes and wraps CurrentType in an
// HLSLAttributedResourceType. Returns CurrentType unchanged when no attributes
// were collected; the collected attribute list is cleared afterwards.
2278 if (!HLSLResourcesTypeAttrs.size())
2279 return CurrentType;
2280
2281 QualType QT = CurrentType;
2284 HLSLResourcesTypeAttrs, QT, &LocInfo)) {
2285 const HLSLAttributedResourceType *RT =
2287
2288 // Temporarily store TypeLoc information for the new type.
2289 // It will be transferred to HLSLAttributedResourceTypeLoc
2290 // shortly after the type is created by TypeSpecLocFiller which
2291 // will call the TakeLocForHLSLAttribute method below.
2292 LocsForHLSLAttributedResources.insert(std::pair(RT, LocInfo));
2293 }
2294 HLSLResourcesTypeAttrs.clear();
2295 return QT;
2296}
2297
2298// Returns the location info stored for RT when its attributed resource type was
// created, removing it from LocsForHLSLAttributedResources. If nothing was
// stored, returns a value-initialized info with an empty source range.
2300SemaHLSL::TakeLocForHLSLAttribute(const HLSLAttributedResourceType *RT) {
2301 HLSLAttributedResourceLocInfo LocInfo = {};
2302 auto I = LocsForHLSLAttributedResources.find(RT);
2303 if (I != LocsForHLSLAttributedResources.end()) {
2304 LocInfo = I->second;
2305 LocsForHLSLAttributedResources.erase(I);
2306 return LocInfo;
2307 }
  // NOTE(review): LocInfo is already value-initialized above, so this reset
  // is redundant.
2308 LocInfo.Range = SourceRange();
2309 return LocInfo;
2310}
2311
2312// Walks through the fields of the record type RT of global variable VD, collects
2313// all resource binding requirements and adds them to Bindings. Nested records
// are scanned recursively; array fields are unwrapped to their element type.
2314void SemaHLSL::collectResourceBindingsOnUserRecordDecl(const VarDecl *VD,
2315 const RecordType *RT) {
2316 const RecordDecl *RD = RT->getDecl()->getDefinitionOrSelf();
2317 for (FieldDecl *FD : RD->fields()) {
2318 const Type *Ty = FD->getType()->getUnqualifiedDesugaredType();
2319
2320 // Unwrap arrays
2321 // FIXME: Calculate array size while unwrapping
2322 assert(!Ty->isIncompleteArrayType() &&
2323 "incomplete arrays inside user defined types are not supported");
2324 while (Ty->isConstantArrayType()) {
2327 }
2328
2329 if (!Ty->isRecordType())
2330 continue;
2331
2332 if (const HLSLAttributedResourceType *AttrResType =
2333 HLSLAttributedResourceType::findHandleTypeOnResource(Ty)) {
2334 // Add a new DeclBindingInfo to Bindings if it does not already exist
2335 ResourceClass RC = AttrResType->getAttrs().ResourceClass;
2336 DeclBindingInfo *DBI = Bindings.getDeclBindingInfo(VD, RC);
2337 if (!DBI)
2338 Bindings.addDeclBindingInfo(VD, RC);
2339 } else if (const RecordType *RT = dyn_cast<RecordType>(Ty)) {
2340 // Recursively scan embedded struct or class; it would be nice to do this
2341 // without recursion, but tricky to correctly calculate the size of the
2342 // binding, which is something we are probably going to need to do later
2343 // on. Hopefully nesting of structs in structs too many levels is
2344 // unlikely.
      // NOTE(review): this RT shadows the function parameter of the same name.
2345 collectResourceBindingsOnUserRecordDecl(VD, RT);
2346 }
2347 }
2348}
2349
2350// Diagnose localized register binding errors for a single binding; does not
2351// diagnose resource binding on user record types, that will be done later
2352// in processResourceBindingOnDecl based on the information collected in
2353// collectResourceBindingsOnVarDecl.
2354// Returns false if the register binding is not valid.
2356 Decl *D, RegisterType RegType,
2357 bool SpecifiedSpace) {
  // RegTypeNum is the numeric register type streamed into the mismatch
  // diagnostics below.
2358 int RegTypeNum = static_cast<int>(RegType);
2359
2360 // check if the decl type is groupshared
2361 if (D->hasAttr<HLSLGroupSharedAddressSpaceAttr>()) {
2362 S.Diag(ArgLoc, diag::err_hlsl_binding_type_mismatch) << RegTypeNum;
2363 return false;
2364 }
2365
2366 // Cbuffers and Tbuffers are HLSLBufferDecl types
2367 if (HLSLBufferDecl *CBufferOrTBuffer = dyn_cast<HLSLBufferDecl>(D)) {
2368 ResourceClass RC = CBufferOrTBuffer->isCBuffer() ? ResourceClass::CBuffer
2369 : ResourceClass::SRV;
2370 if (RegType == getRegisterType(RC))
2371 return true;
2372
2373 S.Diag(D->getLocation(), diag::err_hlsl_binding_type_mismatch)
2374 << RegTypeNum;
2375 return false;
2376 }
2377
2378 // Samplers, UAVs, and SRVs are VarDecl types
2379 assert(isa<VarDecl>(D) && "D is expected to be VarDecl or HLSLBufferDecl");
2380 VarDecl *VD = cast<VarDecl>(D);
2381
2382 // Resource
2383 if (const HLSLAttributedResourceType *AttrResType =
2384 HLSLAttributedResourceType::findHandleTypeOnResource(
2385 VD->getType().getTypePtr())) {
2386 if (RegType == getRegisterType(AttrResType))
2387 return true;
2388
2389 S.Diag(D->getLocation(), diag::err_hlsl_binding_type_mismatch)
2390 << RegTypeNum;
2391 return false;
2392 }
2393
2394 const clang::Type *Ty = VD->getType().getTypePtr();
2395 while (Ty->isArrayType())
2397
2398 // Basic types
2399 if (Ty->isArithmeticType() || Ty->isVectorType()) {
2400 bool DeclaredInCOrTBuffer = isa<HLSLBufferDecl>(D->getDeclContext());
2401 if (SpecifiedSpace && !DeclaredInCOrTBuffer)
2402 S.Diag(ArgLoc, diag::err_hlsl_space_on_global_constant);
2403
2404 if (!DeclaredInCOrTBuffer && (Ty->isIntegralType(S.getASTContext()) ||
2405 Ty->isFloatingType() || Ty->isVectorType())) {
2406 // Register annotation on default constant buffer declaration ($Globals)
      // Only a `c` register is accepted here. A `b` register is only warned
      // about, but still falls through to `return false` below.
2407 if (RegType == RegisterType::CBuffer)
2408 S.Diag(ArgLoc, diag::warn_hlsl_deprecated_register_type_b);
2409 else if (RegType != RegisterType::C)
2410 S.Diag(ArgLoc, diag::err_hlsl_binding_type_mismatch) << RegTypeNum;
2411 else
2412 return true;
2413 } else {
2414 if (RegType == RegisterType::C)
2415 S.Diag(ArgLoc, diag::warn_hlsl_register_type_c_packoffset);
2416 else
2417 S.Diag(ArgLoc, diag::err_hlsl_binding_type_mismatch) << RegTypeNum;
2418 }
2419 return false;
2420 }
2421 if (Ty->isRecordType())
2422 // RecordTypes will be diagnosed in processResourceBindingOnDecl
2423 // that is called from ActOnVariableDeclarator
2424 return true;
2425
2426 // Anything else is an error
2427 S.Diag(ArgLoc, diag::err_hlsl_binding_type_mismatch) << RegTypeNum;
2428 return false;
2429}
2430
2432 RegisterType regType) {
  // Returns false, after diagnosing err_hlsl_duplicate_register_annotation,
  // if regType or any two existing HLSLResourceBindingAttrs on TheDecl share
  // a register type; returns true otherwise.
  // NOTE(review): this file names six RegisterType enumerators (SRV, UAV,
  // CBuffer, Sampler, C, I). If I has value 5, indexing this 5-element array
  // with it is out of bounds. The register-binding handler rejects `i` before
  // reaching here, but sizing the array from the enum would be safer.
2433 // make sure that there are no two register annotations
2434 // applied to the decl with the same register type
2435 bool RegisterTypesDetected[5] = {false};
2436 RegisterTypesDetected[static_cast<int>(regType)] = true;
2437
2438 for (auto it = TheDecl->attr_begin(); it != TheDecl->attr_end(); ++it) {
2439 if (HLSLResourceBindingAttr *attr =
2440 dyn_cast<HLSLResourceBindingAttr>(*it)) {
2441
2442 RegisterType otherRegType = attr->getRegisterType();
2443 if (RegisterTypesDetected[static_cast<int>(otherRegType)]) {
2444 int otherRegTypeNum = static_cast<int>(otherRegType);
2445 S.Diag(TheDecl->getLocation(),
2446 diag::err_hlsl_duplicate_register_annotation)
2447 << otherRegTypeNum;
2448 return false;
2449 }
2450 RegisterTypesDetected[static_cast<int>(otherRegType)] = true;
2451 }
2452 }
2453 return true;
2454}
2455
2457 Decl *D, RegisterType RegType,
2458 bool SpecifiedSpace) {
2459
  // Returns false (after diagnosing) if the register type does not fit the
  // declaration, or if it repeats the register type of another binding
  // attribute already on D; returns true otherwise.
2460 // exactly one of these two types should be set
2461 assert(((isa<VarDecl>(D) && !isa<HLSLBufferDecl>(D)) ||
2462 (!isa<VarDecl>(D) && isa<HLSLBufferDecl>(D))) &&
2463 "expecting VarDecl or HLSLBufferDecl");
2464
2465 // check if the declaration contains resource matching the register type
2466 if (!DiagnoseLocalRegisterBinding(S, ArgLoc, D, RegType, SpecifiedSpace))
2467 return false;
2468
2469 // next, if multiple register annotations exist, check that none conflict.
2470 return ValidateMultipleRegisterAnnotations(S, D, RegType);
2471}
2472
2473// return false if the slot count exceeds the limit, true otherwise
2474static bool AccumulateHLSLResourceSlots(QualType Ty, uint64_t &StartSlot,
2475 const uint64_t &Limit,
2476 const ResourceClass ResClass,
2477 ASTContext &Ctx,
2478 uint64_t ArrayCount = 1) {
2479 Ty = Ty.getCanonicalType();
2480 const Type *T = Ty.getTypePtr();
2481
2482 // Early exit if already overflowed
2483 if (StartSlot > Limit)
2484 return false;
2485
2486 // Case 1: array type
2487 if (const auto *AT = dyn_cast<ArrayType>(T)) {
2488 uint64_t Count = 1;
2489
2490 if (const auto *CAT = dyn_cast<ConstantArrayType>(AT))
2491 Count = CAT->getSize().getZExtValue();
2492
2493 QualType ElemTy = AT->getElementType();
2494 return AccumulateHLSLResourceSlots(ElemTy, StartSlot, Limit, ResClass, Ctx,
2495 ArrayCount * Count);
2496 }
2497
2498 // Case 2: resource leaf
2499 if (auto ResTy = dyn_cast<HLSLAttributedResourceType>(T)) {
2500 // First ensure this resource counts towards the corresponding
2501 // register type limit.
2502 if (ResTy->getAttrs().ResourceClass != ResClass)
2503 return true;
2504
2505 // Validate highest slot used
2506 uint64_t EndSlot = StartSlot + ArrayCount - 1;
2507 if (EndSlot > Limit)
2508 return false;
2509
2510 // Advance SlotCount past the consumed range
2511 StartSlot = EndSlot + 1;
2512 return true;
2513 }
2514
2515 // Case 3: struct / record
2516 if (const auto *RT = dyn_cast<RecordType>(T)) {
2517 const RecordDecl *RD = RT->getDecl();
2518
2519 if (const auto *CXXRD = dyn_cast<CXXRecordDecl>(RD)) {
2520 for (const CXXBaseSpecifier &Base : CXXRD->bases()) {
2521 if (!AccumulateHLSLResourceSlots(Base.getType(), StartSlot, Limit,
2522 ResClass, Ctx, ArrayCount))
2523 return false;
2524 }
2525 }
2526
2527 for (const FieldDecl *Field : RD->fields()) {
2528 if (!AccumulateHLSLResourceSlots(Field->getType(), StartSlot, Limit,
2529 ResClass, Ctx, ArrayCount))
2530 return false;
2531 }
2532
2533 return true;
2534 }
2535
2536 // Case 4: everything else
2537 return true;
2538}
2539
2540// return true if there is something invalid, false otherwise
2541static bool ValidateRegisterNumber(uint64_t SlotNum, Decl *TheDecl,
2542 ASTContext &Ctx, RegisterType RegTy) {
2543 const uint64_t Limit = UINT32_MAX;
2544 if (SlotNum > Limit)
2545 return true;
2546
2547 // after verifying the number doesn't exceed uint32max, we don't need
2548 // to look further into c or i register types
2549 if (RegTy == RegisterType::C || RegTy == RegisterType::I)
2550 return false;
2551
2552 if (VarDecl *VD = dyn_cast<VarDecl>(TheDecl)) {
2553 uint64_t BaseSlot = SlotNum;
2554
2555 if (!AccumulateHLSLResourceSlots(VD->getType(), SlotNum, Limit,
2556 getResourceClass(RegTy), Ctx))
2557 return true;
2558
2559 // After AccumulateHLSLResourceSlots runs, SlotNum is now
2560 // the first free slot; last used was SlotNum - 1
2561 return (BaseSlot > Limit);
2562 }
2563 // handle the cbuffer/tbuffer case
2564 if (isa<HLSLBufferDecl>(TheDecl))
2565 // resources cannot be put within a cbuffer, so no need
2566 // to analyze the structure since the register number
2567 // won't be pushed any higher.
2568 return (SlotNum > Limit);
2569
2570 // we don't expect any other decl type, so fail
2571 llvm_unreachable("unexpected decl type");
2572}
2573
  // A variable's type (or its element type, for an unsized array) must be
  // complete before its register slots can be validated below.
2575 if (VarDecl *VD = dyn_cast<VarDecl>(TheDecl)) {
2576 QualType Ty = VD->getType();
2577 if (const auto *IAT = dyn_cast<IncompleteArrayType>(Ty))
2578 Ty = IAT->getElementType();
2579 if (SemaRef.RequireCompleteType(TheDecl->getBeginLoc(), Ty,
2580 diag::err_incomplete_type))
2581 return;
2582 }
2583
2584 StringRef Slot = "";
2585 StringRef Space = "";
2586 SourceLocation SlotLoc, SpaceLoc;
2587
2588 if (!AL.isArgIdent(0)) {
2589 Diag(AL.getLoc(), diag::err_attribute_argument_type)
2590 << AL << AANT_ArgumentIdentifier;
2591 return;
2592 }
2593 IdentifierLoc *Loc = AL.getArgAsIdent(0);
2594
  // Two arguments are (slot, space). A single argument is a space if it
  // starts with "space"; otherwise it is a slot and the space is "space0".
2595 if (AL.getNumArgs() == 2) {
2596 Slot = Loc->getIdentifierInfo()->getName();
2597 SlotLoc = Loc->getLoc();
2598 if (!AL.isArgIdent(1)) {
2599 Diag(AL.getLoc(), diag::err_attribute_argument_type)
2600 << AL << AANT_ArgumentIdentifier;
2601 return;
2602 }
2603 Loc = AL.getArgAsIdent(1);
2604 Space = Loc->getIdentifierInfo()->getName();
2605 SpaceLoc = Loc->getLoc();
2606 } else {
2607 StringRef Str = Loc->getIdentifierInfo()->getName();
2608 if (Str.starts_with("space")) {
2609 Space = Str;
2610 SpaceLoc = Loc->getLoc();
2611 } else {
2612 Slot = Str;
2613 SlotLoc = Loc->getLoc();
2614 Space = "space0";
2615 }
2616 }
2617
2618 RegisterType RegType = RegisterType::SRV;
2619 std::optional<unsigned> SlotNum;
2620 unsigned SpaceNum = 0;
2621
2622 // Validate slot
2623 if (!Slot.empty()) {
    // The first character selects the register type; the rest is the number.
2624 if (!convertToRegisterType(Slot, &RegType)) {
2625 Diag(SlotLoc, diag::err_hlsl_binding_type_invalid) << Slot.substr(0, 1);
2626 return;
2627 }
    // The `i` register type is only warned about; no attribute is added.
2628 if (RegType == RegisterType::I) {
2629 Diag(SlotLoc, diag::warn_hlsl_deprecated_register_type_i);
2630 return;
2631 }
2632 const StringRef SlotNumStr = Slot.substr(1);
2633
2634 uint64_t N;
2635
2636 // validate that the slot number is a non-empty number
2637 if (SlotNumStr.getAsInteger(10, N)) {
2638 Diag(SlotLoc, diag::err_hlsl_unsupported_register_number);
2639 return;
2640 }
2641
2642 // Validate register number. It should not exceed UINT32_MAX,
2643 // including if the resource type is an array that starts
2644 // before UINT32_MAX, but ends afterwards.
2645 if (ValidateRegisterNumber(N, TheDecl, getASTContext(), RegType)) {
2646 Diag(SlotLoc, diag::err_hlsl_register_number_too_large);
2647 return;
2648 }
2649
2650 // the slot number has been validated and does not exceed UINT32_MAX
    // NOTE(review): static_cast<unsigned>(N) would be the idiomatic spelling.
2651 SlotNum = (unsigned)N;
2652 }
2653
2654 // Validate space
2655 if (!Space.starts_with("space")) {
2656 Diag(SpaceLoc, diag::err_hlsl_expected_space) << Space;
2657 return;
2658 }
2659 StringRef SpaceNumStr = Space.substr(5);
2660 if (SpaceNumStr.getAsInteger(10, SpaceNum)) {
2661 Diag(SpaceLoc, diag::err_hlsl_expected_space) << Space;
2662 return;
2663 }
2664
2665 // If we have slot, diagnose it is the right register type for the decl
2666 if (SlotNum.has_value())
2667 if (!DiagnoseHLSLRegisterAttribute(SemaRef, SlotLoc, TheDecl, RegType,
2668 !SpaceLoc.isInvalid()))
2669 return;
2670
2671 HLSLResourceBindingAttr *NewAttr =
2672 HLSLResourceBindingAttr::Create(getASTContext(), Slot, Space, AL);
2673 if (NewAttr) {
2674 NewAttr->setBinding(RegType, SlotNum, SpaceNum);
2675 TheDecl->addAttr(NewAttr);
2676 }
2677}
2678
  // Merge the parameter modifier, using the spelling taken from the parsed
  // attribute (presumably against any modifier already on D). A null result
  // means no new attribute is added.
2680 HLSLParamModifierAttr *NewAttr = mergeParamModifierAttr(
2681 D, AL,
2682 static_cast<HLSLParamModifierAttr::Spelling>(AL.getSemanticSpelling()));
2683 if (NewAttr)
2684 D->addAttr(NewAttr);
2685}
2686
2687namespace {
2688
2689/// This class implements HLSL availability diagnostics for default
2690/// and relaxed mode.
2691///
2692/// The goal of this diagnostic is to emit an error or warning when an
2693/// unavailable API is found in code that is reachable from the shader
2694/// entry function or from an exported function (when compiling a shader
2695/// library).
2696///
2697/// This is done by traversing the AST of all shader entry point functions
2698/// and of all exported functions, and any functions that are referenced
2699/// from this AST. In other words, any functions that are reachable from
2700/// the entry points.
2701class DiagnoseHLSLAvailability : public DynamicRecursiveASTVisitor {
2702 Sema &SemaRef;
2703
2704 // Stack of functions to be scanned
2706
2707 // Tracks which environments functions have been scanned in.
2708 //
2709 // Maps FunctionDecl to an unsigned number that represents the set of shader
2710 // environments the function has been scanned for.
2711 // Because the llvm::Triple::EnvironmentType enum values for shader stages are
2712 // guaranteed to be numbered from llvm::Triple::Pixel to llvm::Triple::Amplification
2713 // (verified by static_asserts in Triple.cpp), we can use it to index
2714 // individual bits in the set, as long as we shift the values to start with 0
2715 // by subtracting the value of llvm::Triple::Pixel first.
2716 //
2717 // The N'th bit in the set will be set if the function has been scanned
2718 // in shader environment whose llvm::Triple::EnvironmentType integer value
2719 // equals (llvm::Triple::Pixel + N).
2720 //
2721 // For example, if a function has been scanned in compute and pixel stage
2722 // environment, the value will be 0x21 (100001 binary) because:
2723 //
2724 // (int)(llvm::Triple::Pixel - llvm::Triple::Pixel) == 0
2725 // (int)(llvm::Triple::Compute - llvm::Triple::Pixel) == 5
2726 //
2727 // A FunctionDecl is mapped to 0 (or not included in the map) if it has not
2728 // been scanned in any environment.
2729 llvm::DenseMap<const FunctionDecl *, unsigned> ScannedDecls;
2730
2731 // Do not access these directly, use the get/set methods below to make
2732 // sure the values are in sync
2733 llvm::Triple::EnvironmentType CurrentShaderEnvironment;
2734 unsigned CurrentShaderStageBit;
2735
2736 // True if scanning a function that was already scanned in a different
2737 // shader stage context, and therefore we should not report issues that
2738 // depend only on shader model version because they would be duplicate.
2739 bool ReportOnlyShaderStageIssues;
2740
2741 // Helper methods for dealing with current stage context / environment
2742 void SetShaderStageContext(llvm::Triple::EnvironmentType ShaderType) {
2743 static_assert(sizeof(unsigned) >= 4);
2744 assert(HLSLShaderAttr::isValidShaderType(ShaderType));
2745 assert((unsigned)(ShaderType - llvm::Triple::Pixel) < 31 &&
2746 "ShaderType is too big for this bitmap"); // 31 is reserved for
2747 // "unknown"
2748
2749 unsigned bitmapIndex = ShaderType - llvm::Triple::Pixel;
2750 CurrentShaderEnvironment = ShaderType;
2751 CurrentShaderStageBit = (1 << bitmapIndex);
2752 }
2753
  // Exported library functions are scanned in this "unknown stage" context,
  // which uses the reserved bit 31.
2754 void SetUnknownShaderStageContext() {
2755 CurrentShaderEnvironment = llvm::Triple::UnknownEnvironment;
    // NOTE(review): `1 << 31` shifts into the sign bit of a signed int before
    // conversion to unsigned; `1u << 31` would state the intent directly.
2756 CurrentShaderStageBit = (1 << 31);
2757 }
2758
2759 llvm::Triple::EnvironmentType GetCurrentShaderEnvironment() const {
2760 return CurrentShaderEnvironment;
2761 }
2762
2763 bool InUnknownShaderStageContext() const {
2764 return CurrentShaderEnvironment == llvm::Triple::UnknownEnvironment;
2765 }
2766
2767 // Helper methods for dealing with shader stage bitmap
2768 void AddToScannedFunctions(const FunctionDecl *FD) {
2769 unsigned &ScannedStages = ScannedDecls[FD];
2770 ScannedStages |= CurrentShaderStageBit;
2771 }
2772
  // NOTE(review): operator[] inserts a zero entry for unseen functions;
  // ScannedDecls.lookup(FD) would return the same value without inserting.
2773 unsigned GetScannedStages(const FunctionDecl *FD) { return ScannedDecls[FD]; }
2774
2775 bool WasAlreadyScannedInCurrentStage(const FunctionDecl *FD) {
2776 return WasAlreadyScannedInCurrentStage(GetScannedStages(FD));
2777 }
2778
2779 bool WasAlreadyScannedInCurrentStage(unsigned ScannerStages) {
2780 return ScannerStages & CurrentShaderStageBit;
2781 }
2782
2783 static bool NeverBeenScanned(unsigned ScannedStages) {
2784 return ScannedStages == 0;
2785 }
2786
2787 // Scanning methods
2788 void HandleFunctionOrMethodRef(FunctionDecl *FD, Expr *RefExpr);
2789 void CheckDeclAvailability(NamedDecl *D, const AvailabilityAttr *AA,
2790 SourceRange Range);
2791 const AvailabilityAttr *FindAvailabilityAttr(const Decl *D);
2792 bool HasMatchingEnvironmentOrNone(const AvailabilityAttr *AA);
2793
2794public:
2795 DiagnoseHLSLAvailability(Sema &SemaRef)
2796 : SemaRef(SemaRef),
2797 CurrentShaderEnvironment(llvm::Triple::UnknownEnvironment),
2798 CurrentShaderStageBit(0), ReportOnlyShaderStageIssues(false) {}
2799
2800 // AST traversal methods
2801 void RunOnTranslationUnit(const TranslationUnitDecl *TU);
2802 void RunOnFunction(const FunctionDecl *FD);
2803
  // Both visitors forward references to functions (free functions or
  // methods) to HandleFunctionOrMethodRef and always continue the traversal.
2804 bool VisitDeclRefExpr(DeclRefExpr *DRE) override {
2805 FunctionDecl *FD = llvm::dyn_cast<FunctionDecl>(DRE->getDecl());
2806 if (FD)
2807 HandleFunctionOrMethodRef(FD, DRE);
2808 return true;
2809 }
2810
2811 bool VisitMemberExpr(MemberExpr *ME) override {
2812 FunctionDecl *FD = llvm::dyn_cast<FunctionDecl>(ME->getMemberDecl());
2813 if (FD)
2814 HandleFunctionOrMethodRef(FD, ME);
2815 return true;
2816 }
2817};
2818
2819void DiagnoseHLSLAvailability::HandleFunctionOrMethodRef(FunctionDecl *FD,
2820 Expr *RefExpr) {
2821 assert((isa<DeclRefExpr>(RefExpr) || isa<MemberExpr>(RefExpr)) &&
2822 "expected DeclRefExpr or MemberExpr");
2823
2824 // has a definition -> add to stack to be scanned
2825 const FunctionDecl *FDWithBody = nullptr;
2826 if (FD->hasBody(FDWithBody)) {
2827 if (!WasAlreadyScannedInCurrentStage(FDWithBody))
2828 DeclsToScan.push_back(FDWithBody);
2829 return;
2830 }
2831
2832 // no body -> diagnose availability
2833 const AvailabilityAttr *AA = FindAvailabilityAttr(FD);
2834 if (AA)
2835 CheckDeclAvailability(
2836 FD, AA, SourceRange(RefExpr->getBeginLoc(), RefExpr->getEndLoc()));
2837}
2838
2839void DiagnoseHLSLAvailability::RunOnTranslationUnit(
2840 const TranslationUnitDecl *TU) {
2841
2842 // Iterate over all shader entry functions and library exports, and for those
2843 // that have a body (definition), run diag scan on each, setting appropriate
2844 // shader environment context based on whether it is a shader entry function
2845 // or an exported function. Exported functions can be in namespaces and in
2846 // export declarations so we need to scan those declaration contexts as well.
2848 DeclContextsToScan.push_back(TU);
2849
2850 while (!DeclContextsToScan.empty()) {
2851 const DeclContext *DC = DeclContextsToScan.pop_back_val();
2852 for (auto &D : DC->decls()) {
2853 // do not scan implicit declaration generated by the implementation
2854 if (D->isImplicit())
2855 continue;
2856
2857 // for namespace or export declaration add the context to the list to be
2858 // scanned later
      // NOTE(review): dyn_cast is used here only as a type test;
      // isa<NamespaceDecl, ExportDecl>(D) would express that directly.
2859 if (llvm::dyn_cast<NamespaceDecl>(D) || llvm::dyn_cast<ExportDecl>(D)) {
2860 DeclContextsToScan.push_back(llvm::dyn_cast<DeclContext>(D));
2861 continue;
2862 }
2863
2864 // skip over other decls or function decls without body
2865 const FunctionDecl *FD = llvm::dyn_cast<FunctionDecl>(D);
2866 if (!FD || !FD->isThisDeclarationADefinition())
2867 continue;
2868
2869 // shader entry point
2870 if (HLSLShaderAttr *ShaderAttr = FD->getAttr<HLSLShaderAttr>()) {
2871 SetShaderStageContext(ShaderAttr->getType());
2872 RunOnFunction(FD);
2873 continue;
2874 }
2875 // exported library function
2876 // FIXME: replace this loop with external linkage check once issue #92071
2877 // is resolved
2878 bool isExport = FD->isInExportDeclContext();
2879 if (!isExport) {
2880 for (const auto *Redecl : FD->redecls()) {
2881 if (Redecl->isInExportDeclContext()) {
2882 isExport = true;
2883 break;
2884 }
2885 }
2886 }
      // Exported functions may be called from any stage, so they are scanned
      // in the unknown-stage context.
2887 if (isExport) {
2888 SetUnknownShaderStageContext();
2889 RunOnFunction(FD);
2890 continue;
2891 }
2892 }
2893 }
2894}
2895
2896void DiagnoseHLSLAvailability::RunOnFunction(const FunctionDecl *FD) {
2897 assert(DeclsToScan.empty() && "DeclsToScan should be empty");
2898 DeclsToScan.push_back(FD);
2899
2900 while (!DeclsToScan.empty()) {
2901 // Take one decl from the stack and check it by traversing its AST.
2902 // For any CallExpr found during the traversal add it's callee to the top of
2903 // the stack to be processed next. Functions already processed are stored in
2904 // ScannedDecls.
2905 const FunctionDecl *FD = DeclsToScan.pop_back_val();
2906
2907 // Decl was already scanned
2908 const unsigned ScannedStages = GetScannedStages(FD);
2909 if (WasAlreadyScannedInCurrentStage(ScannedStages))
2910 continue;
2911
2912 ReportOnlyShaderStageIssues = !NeverBeenScanned(ScannedStages);
2913
2914 AddToScannedFunctions(FD);
2915 TraverseStmt(FD->getBody());
2916 }
2917}
2918
2919bool DiagnoseHLSLAvailability::HasMatchingEnvironmentOrNone(
2920 const AvailabilityAttr *AA) {
2921 const IdentifierInfo *IIEnvironment = AA->getEnvironment();
2922 if (!IIEnvironment)
2923 return true;
2924
2925 llvm::Triple::EnvironmentType CurrentEnv = GetCurrentShaderEnvironment();
2926 if (CurrentEnv == llvm::Triple::UnknownEnvironment)
2927 return false;
2928
2929 llvm::Triple::EnvironmentType AttrEnv =
2930 AvailabilityAttr::getEnvironmentType(IIEnvironment->getName());
2931
2932 return CurrentEnv == AttrEnv;
2933}
2934
// Find the availability attribute on D for the target platform. The first
// platform match whose environment matches the current shader stage (or that
// has no environment) is returned immediately. Otherwise the last attribute
// that matched only the platform is returned, or nullptr if none matched.
2935const AvailabilityAttr *
2936DiagnoseHLSLAvailability::FindAvailabilityAttr(const Decl *D) {
2937 AvailabilityAttr const *PartialMatch = nullptr;
2938 // Check each AvailabilityAttr to find the one for this platform.
2939 // For multiple attributes with the same platform try to find one for this
2940 // environment.
2941 for (const auto *A : D->attrs()) {
2942 if (const auto *Avail = dyn_cast<AvailabilityAttr>(A)) {
// Matching uses the effective attribute, but the attribute found on D
// itself (Avail) is what gets returned.
2943 const AvailabilityAttr *EffectiveAvail = Avail->getEffectiveAttr();
2944 StringRef AttrPlatform = EffectiveAvail->getPlatform()->getName();
// NOTE(review): the initializer of TargetPlatform is not visible here;
// presumably it is the target's platform name - confirm.
2945 StringRef TargetPlatform =
2947
2948 // Match the platform name.
2949 if (AttrPlatform == TargetPlatform) {
2950 // Find the best matching attribute for this environment
2951 if (HasMatchingEnvironmentOrNone(EffectiveAvail))
2952 return Avail;
2953 PartialMatch = Avail;
2954 }
2955 }
2956 }
2957 return PartialMatch;
2958}
2959
2960// Check availability against target shader model version and current shader
2961// stage and emit diagnostic
// D is the referenced declaration, AA its availability attribute, and Range
// the source range of the referencing expression used for the warning.
2962void DiagnoseHLSLAvailability::CheckDeclAvailability(NamedDecl *D,
2963 const AvailabilityAttr *AA,
2964 SourceRange Range) {
2965
2966 const IdentifierInfo *IIEnv = AA->getEnvironment();
2967
2968 if (!IIEnv) {
2969 // The availability attribute does not have environment -> it depends only
2970 // on shader model version and not on the specific shader stage.
2971
2972 // Skip emitting the diagnostics if the diagnostic mode is set to
2973 // strict (-fhlsl-strict-availability) because all relevant diagnostics
2974 // were already emitted in the DiagnoseUnguardedAvailability scan
2975 // (SemaAvailability.cpp).
2976 if (SemaRef.getLangOpts().HLSLStrictAvailability)
2977 return;
2978
2979 // Do not report shader-stage-independent issues if scanning a function
2980 // that was already scanned in a different shader stage context (they would
2981 // be duplicate)
2982 if (ReportOnlyShaderStageIssues)
2983 return;
2984
2985 } else {
2986 // The availability attribute has environment -> we need to know
2987 // the current stage context to properly diagnose it.
2988 if (InUnknownShaderStageContext())
2989 return;
2990 }
2991
2992 // Check introduced version and if environment matches
2993 bool EnvironmentMatches = HasMatchingEnvironmentOrNone(AA);
2994 VersionTuple Introduced = AA->getIntroduced();
// NOTE(review): the initializer of TargetVersion is not visible here;
// presumably the target's minimum platform version - confirm.
2995 VersionTuple TargetVersion =
2997
// Nothing to report when the target is new enough and the environment fits.
2998 if (TargetVersion >= Introduced && EnvironmentMatches)
2999 return;
3000
3001 // Emit diagnostic message
3002 const TargetInfo &TI = SemaRef.getASTContext().getTargetInfo();
3003 llvm::StringRef PlatformName(
3004 AvailabilityAttr::getPrettyPlatformName(TI.getPlatformName()));
3005
3006 llvm::StringRef CurrentEnvStr =
3007 llvm::Triple::getEnvironmentTypeName(GetCurrentShaderEnvironment());
3008
3009 llvm::StringRef AttrEnvStr =
3010 AA->getEnvironment() ? AA->getEnvironment()->getName() : "";
3011 bool UseEnvironment = !AttrEnvStr.empty();
3012
// Matching environment: the version is too old. Otherwise the declaration is
// unavailable in this stage altogether.
3013 if (EnvironmentMatches) {
3014 SemaRef.Diag(Range.getBegin(), diag::warn_hlsl_availability)
3015 << Range << D << PlatformName << Introduced.getAsString()
3016 << UseEnvironment << CurrentEnvStr;
3017 } else {
3018 SemaRef.Diag(Range.getBegin(), diag::warn_hlsl_availability_unavailable)
3019 << Range << D;
3020 }
3021
// Either warning is followed by a note at D's declaration.
3022 SemaRef.Diag(D->getLocation(), diag::note_partial_availability_specified_here)
3023 << D << PlatformName << Introduced.getAsString()
3024 << SemaRef.Context.getTargetInfo().getPlatformMinVersion().getAsString()
3025 << UseEnvironment << AttrEnvStr << CurrentEnvStr;
3026}
3027
3028} // namespace
3029
3031 // process default CBuffer - create buffer layout struct and invoke codegenCGH
3032 if (!DefaultCBufferDecls.empty()) {
// NOTE(review): the statement creating DefaultCBuffer from
// DefaultCBufferDecls is only partially visible here - confirm against the
// full definition.
3034 SemaRef.getASTContext(), SemaRef.getCurLexicalContext(),
3035 DefaultCBufferDecls);
3036 addImplicitBindingAttrToDecl(SemaRef, DefaultCBuffer, RegisterType::CBuffer,
3038 SemaRef.getCurLexicalContext()->addDecl(DefaultCBuffer);
3040
3041 // Set HasValidPackoffset if any of the decls has a register(c#) annotation.
3042 for (const Decl *VD : DefaultCBufferDecls) {
3043 const HLSLResourceBindingAttr *RBA =
3044 VD->getAttr<HLSLResourceBindingAttr>();
3045 if (RBA && RBA->hasRegisterSlot() &&
3046 RBA->getRegisterType() == HLSLResourceBindingAttr::RegisterType::C) {
3047 DefaultCBuffer->setHasValidPackoffset(true);
3048 break;
3049 }
3050 }
3051
// Hand the synthesized buffer declaration to the AST consumer as a
// top-level declaration.
3052 DeclGroupRef DG(DefaultCBuffer);
3053 SemaRef.Consumer.HandleTopLevelDecl(DG);
3054 }
// Availability diagnostics run whether or not a default cbuffer was created.
3055 diagnoseAvailabilityViolations(TU);
3056}
3057
3058// For resource member access through a global struct array, verify that the
3059// array index selecting the struct element is a constant integer expression.
3060// Returns false if the member expression is invalid.
// NOTE(review): the signature line is not visible here; ME is presumably the
// MemberExpr being checked - confirm.
3062 assert((ME->getType()->isHLSLResourceRecord() ||
3064 "expected member expr to have resource record type or array of them");
3065
3066 // Walk the AST from MemberExpr to the VarDecl of the parent struct instance
3067 // and take note of any non-constant array indexing along the way. If the
3068 // VarDecl we find is a global variable, report error if there was any
3069 // non-constant array index in the resource member access along the way.
// NonConstIndexExpr is overwritten at each non-constant subscript, so the
// reported index is the last one found, i.e. the one closest to the base.
3070 const Expr *NonConstIndexExpr = nullptr;
3071 const Expr *E = ME->getBase();
3072 while (E) {
3073 if (const DeclRefExpr *DRE = dyn_cast<DeclRefExpr>(E)) {
3074 if (!NonConstIndexExpr)
3075 return true;
3076
// cast<> asserts the base declaration is a variable.
3077 const VarDecl *VD = cast<VarDecl>(DRE->getDecl());
3078 if (!VD->hasGlobalStorage())
3079 return true;
3080
3081 SemaRef.Diag(NonConstIndexExpr->getExprLoc(),
3082 diag::err_hlsl_resource_member_array_access_not_constant);
3083 return false;
3084 }
3085
3086 if (const auto *ASE = dyn_cast<ArraySubscriptExpr>(E)) {
3087 const Expr *IdxExpr = ASE->getIdx();
3088 if (!IdxExpr->isIntegerConstantExpr(SemaRef.getASTContext()))
3089 NonConstIndexExpr = IdxExpr;
3090 E = ASE->getBase();
3091 } else if (const auto *SubME = dyn_cast<MemberExpr>(E)) {
3092 E = SubME->getBase();
3093 } else if (const auto *ICE = dyn_cast<ImplicitCastExpr>(E)) {
3094 E = ICE->getSubExpr();
3095 } else {
// NOTE(review): any other expression kind in the chain (for example a
// ParenExpr such as `(Arr[i]).Res`) reaches this unreachable - confirm
// callers can never produce one.
3096 llvm_unreachable("unexpected expr type in resource member access");
3097 }
3098 }
3099 return true;
3100}
3101
// Run the HLSL availability scan over TU unless strict mode already covered it.
3102void SemaHLSL::diagnoseAvailabilityViolations(TranslationUnitDecl *TU) {
3103 // Skip running the diagnostics scan if the diagnostic mode is
3104 // strict (-fhlsl-strict-availability) and the target shader stage is known
3105 // because all relevant diagnostics were already emitted in the
3106 // DiagnoseUnguardedAvailability scan (SemaAvailability.cpp).
// NOTE(review): the declaration of TI is not visible here; presumably the
// target info - confirm.
3108 if (SemaRef.getLangOpts().HLSLStrictAvailability &&
3109 TI.getTriple().getEnvironment() != llvm::Triple::EnvironmentType::Library)
3110 return;
3111
3112 DiagnoseHLSLAvailability(SemaRef).RunOnTranslationUnit(TU);
3113}
3114
// Diagnose err_vec_builtin_incompatible_vector (over the whole argument range)
// at the first argument whose type differs from argument 0's. Returns true on
// error. Requires at least two arguments.
3115static bool CheckAllArgsHaveSameType(Sema *S, CallExpr *TheCall) {
3116 assert(TheCall->getNumArgs() > 1);
3117 QualType ArgTy0 = TheCall->getArg(0)->getType();
3118
// NOTE(review): the start of the type comparison in this loop is not
// visible here - confirm which comparison is used.
3119 for (unsigned I = 1, N = TheCall->getNumArgs(); I < N; ++I) {
3121 ArgTy0, TheCall->getArg(I)->getType())) {
3122 S->Diag(TheCall->getBeginLoc(), diag::err_vec_builtin_incompatible_vector)
3123 << TheCall->getDirectCallee() << /*useAllTerminology*/ true
3124 << SourceRange(TheCall->getArg(0)->getBeginLoc(),
3125 TheCall->getArg(N - 1)->getEndLoc());
3126 return true;
3127 }
3128 }
3129 return false;
3130}
3131
// NOTE(review): the signature and the type comparison of this helper are not
// visible here. It appears to diagnose err_typecheck_convert_incompatible when
// Arg's type does not match ExpectedType and return true in that case - confirm.
3133 QualType ArgType = Arg->getType();
3135 S->Diag(Arg->getBeginLoc(), diag::err_typecheck_convert_incompatible)
3136 << ArgType << ExpectedType << 1 << 0 << 0;
3137 return true;
3138 }
3139 return false;
3140}
3141
3143 Sema *S, CallExpr *TheCall,
3144 llvm::function_ref<bool(Sema *S, SourceLocation Loc, int ArgOrdinal,
3145 clang::QualType PassedType)>
3146 Check) {
// Apply Check to every argument in order, passing a 1-based ordinal. Stop
// and return true at the first argument Check rejects.
3147 for (unsigned I = 0; I < TheCall->getNumArgs(); ++I) {
3148 Expr *Arg = TheCall->getArg(I);
3149 if (Check(S, Arg->getBeginLoc(), I + 1, Arg->getType()))
3150 return true;
3151 }
3152 return false;
3153}
3154
3156 int ArgOrdinal,
3157 clang::QualType PassedType) {
// Accept only a 32-bit float scalar or a vector of 32-bit floats; otherwise
// diagnose and return true.
3158 clang::QualType BaseType =
3159 PassedType->isVectorType()
3160 ? PassedType->castAs<clang::VectorType>()->getElementType()
3161 : PassedType;
3162 if (!BaseType->isFloat32Type())
3163 return S->Diag(Loc, diag::err_builtin_invalid_arg_type)
3164 << ArgOrdinal << /* scalar or vector of */ 5 << /* no int */ 0
3165 << /* float */ 1 << PassedType;
3166 return false;
3167}
3168
3170 int ArgOrdinal,
3171 clang::QualType PassedType) {
// Accept a half or 32-bit float scalar, or a vector of either; otherwise
// diagnose and return true.
3172 clang::QualType BaseType =
3173 PassedType->isVectorType()
3174 ? PassedType->castAs<clang::VectorType>()->getElementType()
3175 : PassedType;
3176 if (!BaseType->isHalfType() && !BaseType->isFloat32Type())
3177 return S->Diag(Loc, diag::err_builtin_invalid_arg_type)
3178 << ArgOrdinal << /* scalar or vector of */ 5 << /* no int */ 0
3179 << /* half or float */ 2 << PassedType;
3180 return false;
3181}
3182
3184 int ArgOrdinal,
3185 clang::QualType PassedType) {
// Accept a double scalar, or a vector or matrix whose element type is double;
// otherwise diagnose and return true.
3186 clang::QualType BaseType =
3187 PassedType->isVectorType()
3188 ? PassedType->castAs<clang::VectorType>()->getElementType()
3189 : PassedType->isMatrixType()
3190 ? PassedType->castAs<clang::MatrixType>()->getElementType()
3191 : PassedType;
3192 if (!BaseType->isDoubleType()) {
3193 // FIXME: adopt standard `err_builtin_invalid_arg_type` instead of using
3194 // this custom error.
3195 return S->Diag(Loc, diag::err_builtin_requires_double_type)
3196 << ArgOrdinal << PassedType;
3197 }
3198
3199 return false;
3200}
3201
// Diagnose error_hlsl_inout_lvalue unless argument ArgIndex (with casts
// stripped) is a modifiable lvalue. Returns true on error.
3202static bool CheckModifiableLValue(Sema *S, CallExpr *TheCall,
3203 unsigned ArgIndex) {
3204 auto *Arg = TheCall->getArg(ArgIndex);
3205 SourceLocation OrigLoc = Arg->getExprLoc();
// OrigLoc is passed by pointer so isModifiableLvalue can presumably refine
// it. The (possibly updated) location is where the error is reported.
// NOTE(review): the compared-against value is not visible here - confirm.
3206 if (Arg->IgnoreCasts()->isModifiableLvalue(S->Context, &OrigLoc) ==
3208 return false;
3209 S->Diag(OrigLoc, diag::error_hlsl_inout_lvalue) << Arg << 0;
3210 return true;
3211}
3212
3213static bool CheckNoDoubleVectors(Sema *S, SourceLocation Loc, int ArgOrdinal,
3214 clang::QualType PassedType) {
3215 const auto *VecTy = PassedType->getAs<VectorType>();
3216 if (!VecTy)
3217 return false;
3218
3219 if (VecTy->getElementType()->isDoubleType())
3220 return S->Diag(Loc, diag::err_builtin_invalid_arg_type)
3221 << ArgOrdinal << /* scalar */ 1 << /* no int */ 0 << /* fp */ 1
3222 << PassedType;
3223 return false;
3224}
3225
3227 int ArgOrdinal,
3228 clang::QualType PassedType) {
// Accept any type with an integer or floating-point representation (scalar
// or vector); otherwise diagnose and return true.
3229 if (!PassedType->hasIntegerRepresentation() &&
3230 !PassedType->hasFloatingRepresentation())
3231 return S->Diag(Loc, diag::err_builtin_invalid_arg_type)
3232 << ArgOrdinal << /* scalar or vector of */ 5 << /* integer */ 1
3233 << /* fp */ 1 << PassedType;
3234 return false;
3235}
3236
3238 int ArgOrdinal,
3239 clang::QualType PassedType) {
// Accept only a vector of unsigned integers; scalars are rejected. Returns
// true after diagnosing otherwise.
3240 if (auto *VecTy = PassedType->getAs<VectorType>())
3241 if (VecTy->getElementType()->isUnsignedIntegerType())
3242 return false;
3243
3244 return S->Diag(Loc, diag::err_builtin_invalid_arg_type)
3245 << ArgOrdinal << /* vector of */ 4 << /* uint */ 3 << /* no fp */ 0
3246 << PassedType;
3247}
3248
3249// Checks for unsigned integer scalars or vectors of any bit width.
3251 int ArgOrdinal,
3252 clang::QualType PassedType) {
// Returns true after diagnosing when PassedType has no unsigned integer
// representation.
3253 if (!PassedType->hasUnsignedIntegerRepresentation())
3254 return S->Diag(Loc, diag::err_builtin_invalid_arg_type)
3255 << ArgOrdinal << /* scalar or vector of */ 5 << /* unsigned int */ 3
3256 << /* no fp */ 0 << PassedType;
3257 return false;
3258}
3259
// Diagnose err_integer_incorrect_bit_count unless the element type of
// argument 0 (or its type, if scalar) is Width bits wide. Returns true on
// error.
// NOTE(review): ArgOrdinal is not used; the check always inspects argument 0.
3260static bool CheckExpectedBitWidth(Sema *S, CallExpr *TheCall,
3261 unsigned ArgOrdinal, unsigned Width) {
3262 QualType ArgTy = TheCall->getArg(0)->getType();
3263 if (auto *VTy = ArgTy->getAs<VectorType>())
3264 ArgTy = VTy->getElementType();
3265 // ensure arg type has expected bit width
// NOTE(review): the initializer of ElementBitCount is not visible here;
// presumably the bit size of ArgTy - confirm.
3266 uint64_t ElementBitCount =
3268 if (ElementBitCount != Width) {
3269 S->Diag(TheCall->getArg(0)->getBeginLoc(),
3270 diag::err_integer_incorrect_bit_count)
3271 << Width << ElementBitCount;
3272 return true;
3273 }
3274 return false;
3275}
3276
3278 QualType ReturnType) {
// Set the call's type to ReturnType, widened to an ext-vector with argument
// 0's element count when argument 0 is a vector.
3279 auto *VecTyA = TheCall->getArg(0)->getType()->getAs<VectorType>();
3280 if (VecTyA)
3281 ReturnType =
3282 S->Context.getExtVectorType(ReturnType, VecTyA->getNumElements());
3283
3284 TheCall->setType(ReturnType);
3285}
3286
3287static bool CheckScalarOrVector(Sema *S, CallExpr *TheCall, QualType Scalar,
3288 unsigned ArgIndex) {
3289 assert(TheCall->getNumArgs() >= ArgIndex);
3290 QualType ArgType = TheCall->getArg(ArgIndex)->getType();
3291 auto *VTy = ArgType->getAs<VectorType>();
3292 // not the scalar or vector<scalar>
3293 if (!(S->Context.hasSameUnqualifiedType(ArgType, Scalar) ||
3294 (VTy &&
3295 S->Context.hasSameUnqualifiedType(VTy->getElementType(), Scalar)))) {
3296 S->Diag(TheCall->getArg(0)->getBeginLoc(),
3297 diag::err_typecheck_expect_scalar_or_vector)
3298 << ArgType << Scalar;
3299 return true;
3300 }
3301 return false;
3302}
3303
3305 QualType Scalar, unsigned ArgIndex) {
// Accept Scalar itself, a vector of Scalar, or a constant matrix of Scalar
// (qualifiers ignored) as argument ArgIndex. Returns true after diagnosing at
// that argument otherwise.
3306 assert(TheCall->getNumArgs() > ArgIndex);
3307
3308 Expr *Arg = TheCall->getArg(ArgIndex);
3309 QualType ArgType = Arg->getType();
3310
3311 // Scalar: T
3312 if (S->Context.hasSameUnqualifiedType(ArgType, Scalar))
3313 return false;
3314
3315 // Vector: vector<T>
3316 if (const auto *VTy = ArgType->getAs<VectorType>()) {
3317 if (S->Context.hasSameUnqualifiedType(VTy->getElementType(), Scalar))
3318 return false;
3319 }
3320
3321 // Matrix: ConstantMatrixType with element type T
3322 if (const auto *MTy = ArgType->getAs<ConstantMatrixType>()) {
3323 if (S->Context.hasSameUnqualifiedType(MTy->getElementType(), Scalar))
3324 return false;
3325 }
3326
3327 // Not a scalar/vector/matrix-of-scalar
3328 S->Diag(Arg->getBeginLoc(),
3329 diag::err_typecheck_expect_scalar_or_vector_or_matrix)
3330 << ArgType << Scalar;
3331 return true;
3332}
3333
3334static bool CheckAnyScalarOrVector(Sema *S, CallExpr *TheCall,
3335 unsigned ArgIndex) {
3336 assert(TheCall->getNumArgs() >= ArgIndex);
3337 QualType ArgType = TheCall->getArg(ArgIndex)->getType();
3338 auto *VTy = ArgType->getAs<VectorType>();
3339 // not the scalar or vector<scalar>
3340 if (!(ArgType->isScalarType() ||
3341 (VTy && VTy->getElementType()->isScalarType()))) {
3342 S->Diag(TheCall->getArg(0)->getBeginLoc(),
3343 diag::err_typecheck_expect_any_scalar_or_vector)
3344 << ArgType << 1;
3345 return true;
3346 }
3347 return false;
3348}
3349
3350// Check that the argument is not a bool or vector<bool>
3351// Returns true on error
3353 unsigned ArgIndex) {
3354 QualType BoolType = S->getASTContext().BoolTy;
3355 assert(ArgIndex < TheCall->getNumArgs());
3356 QualType ArgType = TheCall->getArg(ArgIndex)->getType();
3357 auto *VTy = ArgType->getAs<VectorType>();
3358 // is the bool or vector<bool>
3359 if (S->Context.hasSameUnqualifiedType(ArgType, BoolType) ||
3360 (VTy &&
3361 S->Context.hasSameUnqualifiedType(VTy->getElementType(), BoolType))) {
// NOTE(review): the diagnostic is placed at argument 0 even when ArgIndex
// is non-zero - confirm this is intended.
3362 S->Diag(TheCall->getArg(0)->getBeginLoc(),
3363 diag::err_typecheck_expect_any_scalar_or_vector)
3364 << ArgType << 0;
3365 return true;
3366 }
3367 return false;
3368}
3369
3370static bool CheckWaveActive(Sema *S, CallExpr *TheCall) {
3371 if (CheckNotBoolScalarOrVector(S, TheCall, 0))
3372 return true;
3373 return false;
3374}
3375
3376static bool CheckWavePrefix(Sema *S, CallExpr *TheCall) {
3377 if (CheckNotBoolScalarOrVector(S, TheCall, 0))
3378 return true;
3379 return false;
3380}
3381
3382static bool CheckBoolSelect(Sema *S, CallExpr *TheCall) {
3383 assert(TheCall->getNumArgs() == 3);
3384 Expr *Arg1 = TheCall->getArg(1);
3385 Expr *Arg2 = TheCall->getArg(2);
3386 if (!S->Context.hasSameUnqualifiedType(Arg1->getType(), Arg2->getType())) {
3387 S->Diag(TheCall->getBeginLoc(),
3388 diag::err_typecheck_call_different_arg_types)
3389 << Arg1->getType() << Arg2->getType() << Arg1->getSourceRange()
3390 << Arg2->getSourceRange();
3391 return true;
3392 }
3393
3394 TheCall->setType(Arg1->getType());
3395 return false;
3396}
3397
3398static bool CheckVectorSelect(Sema *S, CallExpr *TheCall) {
3399 assert(TheCall->getNumArgs() == 3);
3400 Expr *Arg1 = TheCall->getArg(1);
3401 QualType Arg1Ty = Arg1->getType();
3402 Expr *Arg2 = TheCall->getArg(2);
3403 QualType Arg2Ty = Arg2->getType();
3404
3405 QualType Arg1ScalarTy = Arg1Ty;
3406 if (auto VTy = Arg1ScalarTy->getAs<VectorType>())
3407 Arg1ScalarTy = VTy->getElementType();
3408
3409 QualType Arg2ScalarTy = Arg2Ty;
3410 if (auto VTy = Arg2ScalarTy->getAs<VectorType>())
3411 Arg2ScalarTy = VTy->getElementType();
3412
3413 if (!S->Context.hasSameUnqualifiedType(Arg1ScalarTy, Arg2ScalarTy))
3414 S->Diag(Arg1->getBeginLoc(), diag::err_hlsl_builtin_scalar_vector_mismatch)
3415 << /* second and third */ 1 << TheCall->getCallee() << Arg1Ty << Arg2Ty;
3416
3417 QualType Arg0Ty = TheCall->getArg(0)->getType();
3418 unsigned Arg0Length = Arg0Ty->getAs<VectorType>()->getNumElements();
3419 unsigned Arg1Length = Arg1Ty->isVectorType()
3420 ? Arg1Ty->getAs<VectorType>()->getNumElements()
3421 : 0;
3422 unsigned Arg2Length = Arg2Ty->isVectorType()
3423 ? Arg2Ty->getAs<VectorType>()->getNumElements()
3424 : 0;
3425 if (Arg1Length > 0 && Arg0Length != Arg1Length) {
3426 S->Diag(TheCall->getBeginLoc(),
3427 diag::err_typecheck_vector_lengths_not_equal)
3428 << Arg0Ty << Arg1Ty << TheCall->getArg(0)->getSourceRange()
3429 << Arg1->getSourceRange();
3430 return true;
3431 }
3432
3433 if (Arg2Length > 0 && Arg0Length != Arg2Length) {
3434 S->Diag(TheCall->getBeginLoc(),
3435 diag::err_typecheck_vector_lengths_not_equal)
3436 << Arg0Ty << Arg2Ty << TheCall->getArg(0)->getSourceRange()
3437 << Arg2->getSourceRange();
3438 return true;
3439 }
3440
3441 TheCall->setType(
3442 S->getASTContext().getExtVectorType(Arg1ScalarTy, Arg0Length));
3443 return false;
3444}
3445
3446static bool CheckIndexType(Sema *S, CallExpr *TheCall, unsigned IndexArgIndex) {
3447 assert(TheCall->getNumArgs() > IndexArgIndex && "Index argument missing");
3448 QualType ArgType = TheCall->getArg(IndexArgIndex)->getType();
3449 QualType IndexTy = ArgType;
3450 unsigned int ActualDim = 1;
3451 if (const auto *VTy = IndexTy->getAs<VectorType>()) {
3452 ActualDim = VTy->getNumElements();
3453 IndexTy = VTy->getElementType();
3454 }
3455 if (!IndexTy->isIntegerType()) {
3456 S->Diag(TheCall->getArg(IndexArgIndex)->getBeginLoc(),
3457 diag::err_typecheck_expect_int)
3458 << ArgType;
3459 return true;
3460 }
3461
3462 QualType ResourceArgTy = TheCall->getArg(0)->getType();
3463 const HLSLAttributedResourceType *ResTy =
3464 ResourceArgTy.getTypePtr()->getAs<HLSLAttributedResourceType>();
3465 assert(ResTy && "Resource argument must be a resource");
3466 HLSLAttributedResourceType::Attributes ResAttrs = ResTy->getAttrs();
3467
3468 unsigned int ExpectedDim = 1;
3469 if (ResAttrs.ResourceDimension != llvm::dxil::ResourceDimension::Unknown)
3470 ExpectedDim = getResourceDimensions(ResAttrs.ResourceDimension);
3471
3472 if (ActualDim != ExpectedDim) {
3473 S->Diag(TheCall->getArg(IndexArgIndex)->getBeginLoc(),
3474 diag::err_hlsl_builtin_resource_coordinate_dimension_mismatch)
3475 << cast<NamedDecl>(TheCall->getCalleeDecl()) << ExpectedDim
3476 << ActualDim;
3477 return true;
3478 }
3479
3480 return false;
3481}
3482
3484 Sema *S, CallExpr *TheCall, unsigned ArgIndex,
3485 llvm::function_ref<bool(const HLSLAttributedResourceType *ResType)> Check =
3486 nullptr) {
// Require argument ArgIndex to have an HLSLAttributedResourceType. If Check is
// provided and returns true for that type, diagnose it as an invalid resource
// type. Returns true on error.
// NOTE(review): getArg(ArgIndex) needs ArgIndex < getNumArgs(), but this
// assert only checks >=, so it does not catch ArgIndex == getNumArgs().
3487 assert(TheCall->getNumArgs() >= ArgIndex);
3488 QualType ArgType = TheCall->getArg(ArgIndex)->getType();
3489 const HLSLAttributedResourceType *ResTy =
3490 ArgType.getTypePtr()->getAs<HLSLAttributedResourceType>();
3491 if (!ResTy) {
3492 S->Diag(TheCall->getArg(ArgIndex)->getBeginLoc(),
3493 diag::err_typecheck_expect_hlsl_resource)
3494 << ArgType;
3495 return true;
3496 }
3497 if (Check && Check(ResTy)) {
3498 S->Diag(TheCall->getArg(ArgIndex)->getExprLoc(),
3499 diag::err_invalid_hlsl_resource_type)
3500 << ArgType;
3501 return true;
3502 }
3503 return false;
3504}
3505
// Diagnose err_typecheck_convert_incompatible unless PassedType has
// ExpectedCount components (a non-vector counts as 1). Only the component
// count is checked; BaseType is used to build the expected type shown in the
// diagnostic. Returns true on error.
3506static bool CheckVectorElementCount(Sema *S, QualType PassedType,
3507 QualType BaseType, unsigned ExpectedCount,
3508 SourceLocation Loc) {
3509 unsigned PassedCount = 1;
3510 if (const auto *VecTy = PassedType->getAs<VectorType>())
3511 PassedCount = VecTy->getNumElements();
3512
3513 if (PassedCount != ExpectedCount) {
// NOTE(review): the declaration line of ExpectedType is not visible here.
3515 S->Context.getExtVectorType(BaseType, ExpectedCount);
3516 S->Diag(Loc, diag::err_typecheck_convert_incompatible)
3517 << PassedType << ExpectedType << 1 << 0 << 0;
3518 return true;
3519 }
3520 return false;
3521}
3522
/// Texture sampling builtin variants distinguished by CheckSamplingBuiltin.
enum class SampleKind {
  Sample,      ///< No extra operand before the optional offset.
  Bias,        ///< Takes a scalar float bias operand.
  Grad,        ///< Takes DDX and DDY gradient operands.
  Level,       ///< Takes a scalar float LOD operand; no clamp operand.
  Cmp,         ///< Takes a scalar float compare value; result is float.
  CmpLevelZero ///< Takes a compare value; no clamp operand; result is float.
};
3524
// Validate the leading (texture, sampler, location) operands shared by the
// sampling, gather and LOD builtins. Returns true on error.
3526 // Check the texture handle.
// The texture must be a resource with a known dimension.
3527 if (CheckResourceHandle(&S, TheCall, 0,
3528 [](const HLSLAttributedResourceType *ResType) {
3529 return ResType->getAttrs().ResourceDimension ==
3530 llvm::dxil::ResourceDimension::Unknown;
3531 }))
3532 return true;
3533
3534 // Check the sampler handle.
// The sampler must be a resource of the Sampler resource class.
3535 if (CheckResourceHandle(&S, TheCall, 1,
3536 [](const HLSLAttributedResourceType *ResType) {
3537 return ResType->getAttrs().ResourceClass !=
3538 llvm::hlsl::ResourceClass::Sampler;
3539 }))
3540 return true;
3541
3542 auto *ResourceTy =
3543 TheCall->getArg(0)->getType()->castAs<HLSLAttributedResourceType>();
3544
3545 // Check the location.
// The location needs one component per texture dimension.
// NOTE(review): unlike the other operand checks, this diagnostic is placed at
// the call's start rather than at argument 2 - confirm this is intended.
3546 unsigned ExpectedDim =
3547 getResourceDimensions(ResourceTy->getAttrs().ResourceDimension);
3548 if (CheckVectorElementCount(&S, TheCall->getArg(2)->getType(),
3549 S.Context.FloatTy, ExpectedDim,
3550 TheCall->getBeginLoc()))
3551 return true;
3552
3553 return false;
3554}
3555
3556static bool CheckCalculateLodBuiltin(Sema &S, CallExpr *TheCall) {
3557 if (S.checkArgCount(TheCall, 3))
3558 return true;
3559
3560 if (CheckTextureSamplerAndLocation(S, TheCall))
3561 return true;
3562
3563 TheCall->setType(S.Context.FloatTy);
3564 return false;
3565}
3566
3567static bool CheckGatherBuiltin(Sema &S, CallExpr *TheCall, bool IsCmp) {
3568 if (S.checkArgCountRange(TheCall, IsCmp ? 5 : 4, IsCmp ? 6 : 5))
3569 return true;
3570
3571 if (CheckTextureSamplerAndLocation(S, TheCall))
3572 return true;
3573
3574 unsigned NextIdx = 3;
3575 if (IsCmp) {
3576 // Check the compare value.
3577 QualType CmpTy = TheCall->getArg(NextIdx)->getType();
3578 if (!CmpTy->isFloatingType() || CmpTy->isVectorType()) {
3579 S.Diag(TheCall->getArg(NextIdx)->getBeginLoc(),
3580 diag::err_typecheck_convert_incompatible)
3581 << CmpTy << S.Context.FloatTy << 1 << 0 << 0;
3582 return true;
3583 }
3584 NextIdx++;
3585 }
3586
3587 // Check the component operand.
3588 Expr *ComponentArg = TheCall->getArg(NextIdx);
3589 QualType ComponentTy = ComponentArg->getType();
3590 if (!ComponentTy->isIntegerType() || ComponentTy->isVectorType()) {
3591 S.Diag(ComponentArg->getBeginLoc(),
3592 diag::err_typecheck_convert_incompatible)
3593 << ComponentTy << S.Context.UnsignedIntTy << 1 << 0 << 0;
3594 return true;
3595 }
3596
3597 // GatherCmp operations on Vulkan target must use component 0 (Red).
3598 if (IsCmp && S.getASTContext().getTargetInfo().getTriple().isSPIRV()) {
3599 std::optional<llvm::APSInt> ComponentOpt =
3600 ComponentArg->getIntegerConstantExpr(S.getASTContext());
3601 if (ComponentOpt) {
3602 int64_t ComponentVal = ComponentOpt->getSExtValue();
3603 if (ComponentVal != 0) {
3604 // Issue an error if the component is not 0 (Red).
3605 // 0 -> Red, 1 -> Green, 2 -> Blue, 3 -> Alpha
3606 assert(ComponentVal >= 0 && ComponentVal <= 3 &&
3607 "The component is not in the expected range.");
3608 S.Diag(ComponentArg->getBeginLoc(),
3609 diag::err_hlsl_gathercmp_invalid_component)
3610 << ComponentVal;
3611 return true;
3612 }
3613 }
3614 }
3615
3616 NextIdx++;
3617
3618 // Check the offset operand.
3619 const HLSLAttributedResourceType *ResourceTy =
3620 TheCall->getArg(0)->getType()->castAs<HLSLAttributedResourceType>();
3621 if (TheCall->getNumArgs() > NextIdx) {
3622 unsigned ExpectedDim =
3623 getResourceDimensions(ResourceTy->getAttrs().ResourceDimension);
3624 if (CheckVectorElementCount(&S, TheCall->getArg(NextIdx)->getType(),
3625 S.Context.IntTy, ExpectedDim,
3626 TheCall->getArg(NextIdx)->getBeginLoc()))
3627 return true;
3628 NextIdx++;
3629 }
3630
3631 assert(ResourceTy->hasContainedType() &&
3632 "Expecting a contained type for resource with a dimension "
3633 "attribute.");
3634 QualType ReturnType = ResourceTy->getContainedType();
3635
3636 if (IsCmp) {
3637 if (!ReturnType->hasFloatingRepresentation()) {
3638 S.Diag(TheCall->getBeginLoc(), diag::err_hlsl_samplecmp_requires_float);
3639 return true;
3640 }
3641 }
3642
3643 if (const auto *VecTy = ReturnType->getAs<VectorType>())
3644 ReturnType = VecTy->getElementType();
3645 ReturnType = S.Context.getExtVectorType(ReturnType, 4);
3646
3647 TheCall->setType(ReturnType);
3648
3649 return false;
3650}
3651static bool CheckLoadLevelBuiltin(Sema &S, CallExpr *TheCall) {
3652 if (S.checkArgCountRange(TheCall, 2, 3))
3653 return true;
3654
3655 // Check the texture handle.
3656 if (CheckResourceHandle(&S, TheCall, 0,
3657 [](const HLSLAttributedResourceType *ResType) {
3658 return ResType->getAttrs().ResourceDimension ==
3659 llvm::dxil::ResourceDimension::Unknown;
3660 }))
3661 return true;
3662
3663 auto *ResourceTy =
3664 TheCall->getArg(0)->getType()->castAs<HLSLAttributedResourceType>();
3665
3666 // Check the location + lod (int3 for Texture2D).
3667 unsigned ExpectedDim =
3668 getResourceDimensions(ResourceTy->getAttrs().ResourceDimension);
3669 QualType CoordLODTy = TheCall->getArg(1)->getType();
3670 if (CheckVectorElementCount(&S, CoordLODTy, S.Context.IntTy, ExpectedDim + 1,
3671 TheCall->getArg(1)->getBeginLoc()))
3672 return true;
3673
3674 QualType EltTy = CoordLODTy;
3675 if (const auto *VTy = EltTy->getAs<VectorType>())
3676 EltTy = VTy->getElementType();
3677 if (!EltTy->isIntegerType()) {
3678 S.Diag(TheCall->getArg(1)->getBeginLoc(), diag::err_typecheck_expect_int)
3679 << CoordLODTy;
3680 return true;
3681 }
3682
3683 // Check the offset operand.
3684 if (TheCall->getNumArgs() > 2) {
3685 if (CheckVectorElementCount(&S, TheCall->getArg(2)->getType(),
3686 S.Context.IntTy, ExpectedDim,
3687 TheCall->getArg(2)->getBeginLoc()))
3688 return true;
3689 }
3690
3691 TheCall->setType(ResourceTy->getContainedType());
3692 return false;
3693}
3694
3695static bool CheckSamplingBuiltin(Sema &S, CallExpr *TheCall, SampleKind Kind) {
3696 unsigned MinArgs, MaxArgs;
3697 if (Kind == SampleKind::Sample) {
3698 MinArgs = 3;
3699 MaxArgs = 5;
3700 } else if (Kind == SampleKind::Bias) {
3701 MinArgs = 4;
3702 MaxArgs = 6;
3703 } else if (Kind == SampleKind::Grad) {
3704 MinArgs = 5;
3705 MaxArgs = 7;
3706 } else if (Kind == SampleKind::Level) {
3707 MinArgs = 4;
3708 MaxArgs = 5;
3709 } else if (Kind == SampleKind::Cmp) {
3710 MinArgs = 4;
3711 MaxArgs = 6;
3712 } else {
3713 assert(Kind == SampleKind::CmpLevelZero);
3714 MinArgs = 4;
3715 MaxArgs = 5;
3716 }
3717
3718 if (S.checkArgCountRange(TheCall, MinArgs, MaxArgs))
3719 return true;
3720
3721 if (CheckTextureSamplerAndLocation(S, TheCall))
3722 return true;
3723
3724 const HLSLAttributedResourceType *ResourceTy =
3725 TheCall->getArg(0)->getType()->castAs<HLSLAttributedResourceType>();
3726 unsigned ExpectedDim =
3727 getResourceDimensions(ResourceTy->getAttrs().ResourceDimension);
3728
3729 unsigned NextIdx = 3;
3730 if (Kind == SampleKind::Bias || Kind == SampleKind::Level ||
3731 Kind == SampleKind::Cmp || Kind == SampleKind::CmpLevelZero) {
3732 // Check the bias, lod level, or compare value, depending on the kind.
3733 // All of them must be a scalar float value.
3734 QualType BiasOrLODOrCmpTy = TheCall->getArg(NextIdx)->getType();
3735 if (!BiasOrLODOrCmpTy->isFloatingType() ||
3736 BiasOrLODOrCmpTy->isVectorType()) {
3737 S.Diag(TheCall->getArg(NextIdx)->getBeginLoc(),
3738 diag::err_typecheck_convert_incompatible)
3739 << BiasOrLODOrCmpTy << S.Context.FloatTy << 1 << 0 << 0;
3740 return true;
3741 }
3742 NextIdx++;
3743 } else if (Kind == SampleKind::Grad) {
3744 // Check the DDX operand.
3745 if (CheckVectorElementCount(&S, TheCall->getArg(NextIdx)->getType(),
3746 S.Context.FloatTy, ExpectedDim,
3747 TheCall->getArg(NextIdx)->getBeginLoc()))
3748 return true;
3749
3750 // Check the DDY operand.
3751 if (CheckVectorElementCount(&S, TheCall->getArg(NextIdx + 1)->getType(),
3752 S.Context.FloatTy, ExpectedDim,
3753 TheCall->getArg(NextIdx + 1)->getBeginLoc()))
3754 return true;
3755 NextIdx += 2;
3756 }
3757
3758 // Check the offset operand.
3759 if (TheCall->getNumArgs() > NextIdx) {
3760 if (CheckVectorElementCount(&S, TheCall->getArg(NextIdx)->getType(),
3761 S.Context.IntTy, ExpectedDim,
3762 TheCall->getArg(NextIdx)->getBeginLoc()))
3763 return true;
3764 NextIdx++;
3765 }
3766
3767 // Check the clamp operand.
3768 if (Kind != SampleKind::Level && Kind != SampleKind::CmpLevelZero &&
3769 TheCall->getNumArgs() > NextIdx) {
3770 QualType ClampTy = TheCall->getArg(NextIdx)->getType();
3771 if (!ClampTy->isFloatingType() || ClampTy->isVectorType()) {
3772 S.Diag(TheCall->getArg(NextIdx)->getBeginLoc(),
3773 diag::err_typecheck_convert_incompatible)
3774 << ClampTy << S.Context.FloatTy << 1 << 0 << 0;
3775 return true;
3776 }
3777 }
3778
3779 assert(ResourceTy->hasContainedType() &&
3780 "Expecting a contained type for resource with a dimension "
3781 "attribute.");
3782 QualType ReturnType = ResourceTy->getContainedType();
3783 if (Kind == SampleKind::Cmp || Kind == SampleKind::CmpLevelZero) {
3784 if (!ReturnType->hasFloatingRepresentation()) {
3785 S.Diag(TheCall->getBeginLoc(), diag::err_hlsl_samplecmp_requires_float);
3786 return true;
3787 }
3788 ReturnType = S.Context.FloatTy;
3789 }
3790 TheCall->setType(ReturnType);
3791
3792 return false;
3793}
3794
3795// Note: returning true in this case results in CheckBuiltinFunctionCall
3796// returning an ExprError
3797bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
3798 switch (BuiltinID) {
3799 case Builtin::BI__builtin_hlsl_adduint64: {
3800 if (SemaRef.checkArgCount(TheCall, 2))
3801 return true;
3802
3803 if (CheckAllArgTypesAreCorrect(&SemaRef, TheCall,
3805 return true;
3806
3807 // ensure arg integers are 32-bits
3808 if (CheckExpectedBitWidth(&SemaRef, TheCall, 0, 32))
3809 return true;
3810
3811 // ensure both args are vectors of total bit size of a multiple of 64
3812 auto *VTy = TheCall->getArg(0)->getType()->getAs<VectorType>();
3813 int NumElementsArg = VTy->getNumElements();
3814 if (NumElementsArg != 2 && NumElementsArg != 4) {
3815 SemaRef.Diag(TheCall->getBeginLoc(), diag::err_vector_incorrect_bit_count)
3816 << 1 /*a multiple of*/ << 64 << NumElementsArg * 32;
3817 return true;
3818 }
3819
3820 // ensure first arg and second arg have the same type
3821 if (CheckAllArgsHaveSameType(&SemaRef, TheCall))
3822 return true;
3823
3824 ExprResult A = TheCall->getArg(0);
3825 QualType ArgTyA = A.get()->getType();
3826 // return type is the same as the input type
3827 TheCall->setType(ArgTyA);
3828 break;
3829 }
3830 case Builtin::BI__builtin_hlsl_resource_getpointer: {
3831 if (SemaRef.checkArgCount(TheCall, 2) ||
3832 CheckResourceHandle(&SemaRef, TheCall, 0) ||
3833 CheckIndexType(&SemaRef, TheCall, 1))
3834 return true;
3835
3836 auto *ResourceTy =
3837 TheCall->getArg(0)->getType()->castAs<HLSLAttributedResourceType>();
3838 QualType ContainedTy = ResourceTy->getContainedType();
3839 auto ReturnType =
3840 SemaRef.Context.getAddrSpaceQualType(ContainedTy, LangAS::hlsl_device);
3841 ReturnType = SemaRef.Context.getPointerType(ReturnType);
3842 TheCall->setType(ReturnType);
3843 TheCall->setValueKind(VK_LValue);
3844
3845 break;
3846 }
3847 case Builtin::BI__builtin_hlsl_resource_getpointer_typed: {
3848 if (SemaRef.checkArgCount(TheCall, 3) ||
3849 CheckResourceHandle(&SemaRef, TheCall, 0) ||
3850 CheckIndexType(&SemaRef, TheCall, 1))
3851 return true;
3852
3853 QualType ElementTy = TheCall->getArg(2)->getType();
3854 assert(ElementTy->isPointerType() &&
3855 "expected pointer type for second argument");
3856 ElementTy = ElementTy->getPointeeType();
3857
3858 // Reject array types
3859 if (ElementTy->isArrayType())
3860 return SemaRef.Diag(
3861 cast<FunctionDecl>(SemaRef.CurContext)->getPointOfInstantiation(),
3862 diag::err_invalid_use_of_array_type);
3863
3864 auto ReturnType =
3865 SemaRef.Context.getAddrSpaceQualType(ElementTy, LangAS::hlsl_device);
3866 ReturnType = SemaRef.Context.getPointerType(ReturnType);
3867 TheCall->setType(ReturnType);
3868
3869 break;
3870 }
3871 case Builtin::BI__builtin_hlsl_resource_load_with_status: {
3872 if (SemaRef.checkArgCount(TheCall, 3) ||
3873 CheckResourceHandle(&SemaRef, TheCall, 0) ||
3874 CheckArgTypeMatches(&SemaRef, TheCall->getArg(1),
3875 SemaRef.getASTContext().UnsignedIntTy) ||
3876 CheckArgTypeMatches(&SemaRef, TheCall->getArg(2),
3877 SemaRef.getASTContext().UnsignedIntTy) ||
3878 CheckModifiableLValue(&SemaRef, TheCall, 2))
3879 return true;
3880
3881 auto *ResourceTy =
3882 TheCall->getArg(0)->getType()->castAs<HLSLAttributedResourceType>();
3883 QualType ReturnType = ResourceTy->getContainedType();
3884 TheCall->setType(ReturnType);
3885
3886 break;
3887 }
3888 case Builtin::BI__builtin_hlsl_resource_load_with_status_typed: {
3889 if (SemaRef.checkArgCount(TheCall, 4) ||
3890 CheckResourceHandle(&SemaRef, TheCall, 0) ||
3891 CheckArgTypeMatches(&SemaRef, TheCall->getArg(1),
3892 SemaRef.getASTContext().UnsignedIntTy) ||
3893 CheckArgTypeMatches(&SemaRef, TheCall->getArg(2),
3894 SemaRef.getASTContext().UnsignedIntTy) ||
3895 CheckModifiableLValue(&SemaRef, TheCall, 2))
3896 return true;
3897
3898 QualType ReturnType = TheCall->getArg(3)->getType();
3899 assert(ReturnType->isPointerType() &&
3900 "expected pointer type for second argument");
3901 ReturnType = ReturnType->getPointeeType();
3902
3903 // Reject array types
3904 if (ReturnType->isArrayType())
3905 return SemaRef.Diag(
3906 cast<FunctionDecl>(SemaRef.CurContext)->getPointOfInstantiation(),
3907 diag::err_invalid_use_of_array_type);
3908
3909 TheCall->setType(ReturnType);
3910
3911 break;
3912 }
3913 case Builtin::BI__builtin_hlsl_resource_load_level:
3914 return CheckLoadLevelBuiltin(SemaRef, TheCall);
3915 case Builtin::BI__builtin_hlsl_resource_sample:
3917 case Builtin::BI__builtin_hlsl_resource_sample_bias:
3919 case Builtin::BI__builtin_hlsl_resource_sample_grad:
3921 case Builtin::BI__builtin_hlsl_resource_sample_level:
3923 case Builtin::BI__builtin_hlsl_resource_sample_cmp:
3925 case Builtin::BI__builtin_hlsl_resource_sample_cmp_level_zero:
3927 case Builtin::BI__builtin_hlsl_resource_calculate_lod:
3928 case Builtin::BI__builtin_hlsl_resource_calculate_lod_unclamped:
3929 return CheckCalculateLodBuiltin(SemaRef, TheCall);
3930 case Builtin::BI__builtin_hlsl_resource_gather:
3931 return CheckGatherBuiltin(SemaRef, TheCall, /*IsCmp=*/false);
3932 case Builtin::BI__builtin_hlsl_resource_gather_cmp:
3933 return CheckGatherBuiltin(SemaRef, TheCall, /*IsCmp=*/true);
3934 case Builtin::BI__builtin_hlsl_resource_uninitializedhandle: {
3935 assert(TheCall->getNumArgs() == 1 && "expected 1 arg");
3936 // Update return type to be the attributed resource type from arg0.
3937 QualType ResourceTy = TheCall->getArg(0)->getType();
3938 TheCall->setType(ResourceTy);
3939 break;
3940 }
3941 case Builtin::BI__builtin_hlsl_resource_handlefrombinding: {
3942 assert(TheCall->getNumArgs() == 6 && "expected 6 args");
3943 // Update return type to be the attributed resource type from arg0.
3944 QualType ResourceTy = TheCall->getArg(0)->getType();
3945 TheCall->setType(ResourceTy);
3946 break;
3947 }
3948 case Builtin::BI__builtin_hlsl_resource_handlefromimplicitbinding: {
3949 assert(TheCall->getNumArgs() == 6 && "expected 6 args");
3950 // Update return type to be the attributed resource type from arg0.
3951 QualType ResourceTy = TheCall->getArg(0)->getType();
3952 TheCall->setType(ResourceTy);
3953 break;
3954 }
3955 case Builtin::BI__builtin_hlsl_resource_counterhandlefromimplicitbinding: {
3956 assert(TheCall->getNumArgs() == 3 && "expected 3 args");
3957 ASTContext &AST = SemaRef.getASTContext();
3958 QualType MainHandleTy = TheCall->getArg(0)->getType();
3959 auto *MainResType = MainHandleTy->getAs<HLSLAttributedResourceType>();
3960 auto MainAttrs = MainResType->getAttrs();
3961 assert(!MainAttrs.IsCounter && "cannot create a counter from a counter");
3962 MainAttrs.IsCounter = true;
3963 QualType CounterHandleTy = AST.getHLSLAttributedResourceType(
3964 MainResType->getWrappedType(), MainResType->getContainedType(),
3965 MainAttrs);
3966 // Update return type to be the attributed resource type from arg0
3967 // with added IsCounter flag.
3968 TheCall->setType(CounterHandleTy);
3969 break;
3970 }
3971 case Builtin::BI__builtin_hlsl_and:
3972 case Builtin::BI__builtin_hlsl_or: {
3973 if (SemaRef.checkArgCount(TheCall, 2))
3974 return true;
3975 if (CheckScalarOrVectorOrMatrix(&SemaRef, TheCall, getASTContext().BoolTy,
3976 0))
3977 return true;
3978 if (CheckAllArgsHaveSameType(&SemaRef, TheCall))
3979 return true;
3980
3981 ExprResult A = TheCall->getArg(0);
3982 QualType ArgTyA = A.get()->getType();
3983 // return type is the same as the input type
3984 TheCall->setType(ArgTyA);
3985 break;
3986 }
3987 case Builtin::BI__builtin_hlsl_all:
3988 case Builtin::BI__builtin_hlsl_any: {
3989 if (SemaRef.checkArgCount(TheCall, 1))
3990 return true;
3991 if (CheckAnyScalarOrVector(&SemaRef, TheCall, 0))
3992 return true;
3993 break;
3994 }
3995 case Builtin::BI__builtin_hlsl_asdouble: {
3996 if (SemaRef.checkArgCount(TheCall, 2))
3997 return true;
3999 &SemaRef, TheCall,
4000 /*only check for uint*/ SemaRef.Context.UnsignedIntTy,
4001 /* arg index */ 0))
4002 return true;
4004 &SemaRef, TheCall,
4005 /*only check for uint*/ SemaRef.Context.UnsignedIntTy,
4006 /* arg index */ 1))
4007 return true;
4008 if (CheckAllArgsHaveSameType(&SemaRef, TheCall))
4009 return true;
4010
4011 SetElementTypeAsReturnType(&SemaRef, TheCall, getASTContext().DoubleTy);
4012 break;
4013 }
4014 case Builtin::BI__builtin_hlsl_elementwise_clamp: {
4015 if (SemaRef.BuiltinElementwiseTernaryMath(
4016 TheCall, /*ArgTyRestr=*/
4018 return true;
4019 break;
4020 }
4021 case Builtin::BI__builtin_hlsl_dot: {
4022 // arg count is checked by BuiltinVectorToScalarMath
4023 if (SemaRef.BuiltinVectorToScalarMath(TheCall))
4024 return true;
4026 return true;
4027 break;
4028 }
4029 case Builtin::BI__builtin_hlsl_elementwise_firstbithigh:
4030 case Builtin::BI__builtin_hlsl_elementwise_firstbitlow: {
4031 if (SemaRef.PrepareBuiltinElementwiseMathOneArgCall(TheCall))
4032 return true;
4033
4034 const Expr *Arg = TheCall->getArg(0);
4035 QualType ArgTy = Arg->getType();
4036 QualType EltTy = ArgTy;
4037
4038 QualType ResTy = SemaRef.Context.UnsignedIntTy;
4039
4040 if (auto *VecTy = EltTy->getAs<VectorType>()) {
4041 EltTy = VecTy->getElementType();
4042 ResTy = SemaRef.Context.getExtVectorType(ResTy, VecTy->getNumElements());
4043 }
4044
4045 if (!EltTy->isIntegerType()) {
4046 Diag(Arg->getBeginLoc(), diag::err_builtin_invalid_arg_type)
4047 << 1 << /* scalar or vector of */ 5 << /* integer ty */ 1
4048 << /* no fp */ 0 << ArgTy;
4049 return true;
4050 }
4051
4052 TheCall->setType(ResTy);
4053 break;
4054 }
4055 case Builtin::BI__builtin_hlsl_select: {
4056 if (SemaRef.checkArgCount(TheCall, 3))
4057 return true;
4058 if (CheckScalarOrVector(&SemaRef, TheCall, getASTContext().BoolTy, 0))
4059 return true;
4060 QualType ArgTy = TheCall->getArg(0)->getType();
4061 if (ArgTy->isBooleanType() && CheckBoolSelect(&SemaRef, TheCall))
4062 return true;
4063 auto *VTy = ArgTy->getAs<VectorType>();
4064 if (VTy && VTy->getElementType()->isBooleanType() &&
4065 CheckVectorSelect(&SemaRef, TheCall))
4066 return true;
4067 break;
4068 }
4069 case Builtin::BI__builtin_hlsl_elementwise_saturate:
4070 case Builtin::BI__builtin_hlsl_elementwise_rcp: {
4071 if (SemaRef.checkArgCount(TheCall, 1))
4072 return true;
4073 if (!TheCall->getArg(0)
4074 ->getType()
4075 ->hasFloatingRepresentation()) // half or float or double
4076 return SemaRef.Diag(TheCall->getArg(0)->getBeginLoc(),
4077 diag::err_builtin_invalid_arg_type)
4078 << /* ordinal */ 1 << /* scalar or vector */ 5 << /* no int */ 0
4079 << /* fp */ 1 << TheCall->getArg(0)->getType();
4080 if (SemaRef.PrepareBuiltinElementwiseMathOneArgCall(TheCall))
4081 return true;
4082 break;
4083 }
4084 case Builtin::BI__builtin_hlsl_elementwise_degrees:
4085 case Builtin::BI__builtin_hlsl_elementwise_radians:
4086 case Builtin::BI__builtin_hlsl_elementwise_rsqrt:
4087 case Builtin::BI__builtin_hlsl_elementwise_frac:
4088 case Builtin::BI__builtin_hlsl_elementwise_ddx_coarse:
4089 case Builtin::BI__builtin_hlsl_elementwise_ddy_coarse:
4090 case Builtin::BI__builtin_hlsl_elementwise_ddx_fine:
4091 case Builtin::BI__builtin_hlsl_elementwise_ddy_fine: {
4092 if (SemaRef.checkArgCount(TheCall, 1))
4093 return true;
4094 if (CheckAllArgTypesAreCorrect(&SemaRef, TheCall,
4096 return true;
4097 if (SemaRef.PrepareBuiltinElementwiseMathOneArgCall(TheCall))
4098 return true;
4099 break;
4100 }
4101 case Builtin::BI__builtin_hlsl_elementwise_isinf:
4102 case Builtin::BI__builtin_hlsl_elementwise_isnan: {
4103 if (SemaRef.checkArgCount(TheCall, 1))
4104 return true;
4105 if (CheckAllArgTypesAreCorrect(&SemaRef, TheCall,
4107 return true;
4108 if (SemaRef.PrepareBuiltinElementwiseMathOneArgCall(TheCall))
4109 return true;
4111 break;
4112 }
4113 case Builtin::BI__builtin_hlsl_lerp: {
4114 if (SemaRef.checkArgCount(TheCall, 3))
4115 return true;
4116 if (CheckAllArgTypesAreCorrect(&SemaRef, TheCall,
4118 return true;
4119 if (CheckAllArgsHaveSameType(&SemaRef, TheCall))
4120 return true;
4121 if (SemaRef.BuiltinElementwiseTernaryMath(TheCall))
4122 return true;
4123 break;
4124 }
4125 case Builtin::BI__builtin_hlsl_mad: {
4126 if (SemaRef.BuiltinElementwiseTernaryMath(
4127 TheCall, /*ArgTyRestr=*/
4129 return true;
4130 break;
4131 }
4132 case Builtin::BI__builtin_hlsl_mul: {
4133 if (SemaRef.checkArgCount(TheCall, 2))
4134 return true;
4135
4136 Expr *Arg0 = TheCall->getArg(0);
4137 Expr *Arg1 = TheCall->getArg(1);
4138 QualType Ty0 = Arg0->getType();
4139 QualType Ty1 = Arg1->getType();
4140
4141 auto getElemType = [](QualType T) -> QualType {
4142 if (const auto *VTy = T->getAs<VectorType>())
4143 return VTy->getElementType();
4144 if (const auto *MTy = T->getAs<ConstantMatrixType>())
4145 return MTy->getElementType();
4146 return T;
4147 };
4148
4149 QualType EltTy0 = getElemType(Ty0);
4150
4151 bool IsVec0 = Ty0->isVectorType();
4152 bool IsMat0 = Ty0->isConstantMatrixType();
4153 bool IsVec1 = Ty1->isVectorType();
4154 bool IsMat1 = Ty1->isConstantMatrixType();
4155
4156 QualType RetTy;
4157
4158 if (IsVec0 && IsMat1) {
4159 auto *MatTy = Ty1->castAs<ConstantMatrixType>();
4160 RetTy = getASTContext().getExtVectorType(EltTy0, MatTy->getNumColumns());
4161 } else if (IsMat0 && IsVec1) {
4162 auto *MatTy = Ty0->castAs<ConstantMatrixType>();
4163 RetTy = getASTContext().getExtVectorType(EltTy0, MatTy->getNumRows());
4164 } else {
4165 assert(IsMat0 && IsMat1);
4166 auto *MatTy0 = Ty0->castAs<ConstantMatrixType>();
4167 auto *MatTy1 = Ty1->castAs<ConstantMatrixType>();
4169 EltTy0, MatTy0->getNumRows(), MatTy1->getNumColumns());
4170 }
4171
4172 TheCall->setType(RetTy);
4173 break;
4174 }
4175 case Builtin::BI__builtin_hlsl_normalize: {
4176 if (SemaRef.checkArgCount(TheCall, 1))
4177 return true;
4178 if (CheckAllArgTypesAreCorrect(&SemaRef, TheCall,
4180 return true;
4181 ExprResult A = TheCall->getArg(0);
4182 QualType ArgTyA = A.get()->getType();
4183 // return type is the same as the input type
4184 TheCall->setType(ArgTyA);
4185 break;
4186 }
4187 case Builtin::BI__builtin_elementwise_fma: {
4188 if (SemaRef.checkArgCount(TheCall, 3) ||
4189 CheckAllArgsHaveSameType(&SemaRef, TheCall)) {
4190 return true;
4191 }
4192
4193 if (CheckAllArgTypesAreCorrect(&SemaRef, TheCall,
4195 return true;
4196
4197 ExprResult A = TheCall->getArg(0);
4198 QualType ArgTyA = A.get()->getType();
4199 // return type is the same as input type
4200 TheCall->setType(ArgTyA);
4201 break;
4202 }
4203 case Builtin::BI__builtin_hlsl_transpose: {
4204 if (SemaRef.checkArgCount(TheCall, 1))
4205 return true;
4206
4207 Expr *Arg = TheCall->getArg(0);
4208 QualType ArgTy = Arg->getType();
4209
4210 const auto *MatTy = ArgTy->getAs<ConstantMatrixType>();
4211 if (!MatTy) {
4212 SemaRef.Diag(Arg->getBeginLoc(), diag::err_builtin_invalid_arg_type)
4213 << 1 << /* matrix */ 3 << /* no int */ 0 << /* no fp */ 0 << ArgTy;
4214 return true;
4215 }
4216
4218 MatTy->getElementType(), MatTy->getNumColumns(), MatTy->getNumRows());
4219 TheCall->setType(RetTy);
4220 break;
4221 }
4222 case Builtin::BI__builtin_hlsl_elementwise_sign: {
4223 if (SemaRef.PrepareBuiltinElementwiseMathOneArgCall(TheCall))
4224 return true;
4225 if (CheckAllArgTypesAreCorrect(&SemaRef, TheCall,
4227 return true;
4229 break;
4230 }
4231 case Builtin::BI__builtin_hlsl_step: {
4232 if (SemaRef.checkArgCount(TheCall, 2))
4233 return true;
4234 if (CheckAllArgTypesAreCorrect(&SemaRef, TheCall,
4236 return true;
4237
4238 ExprResult A = TheCall->getArg(0);
4239 QualType ArgTyA = A.get()->getType();
4240 // return type is the same as the input type
4241 TheCall->setType(ArgTyA);
4242 break;
4243 }
4244 case Builtin::BI__builtin_hlsl_wave_active_all_equal: {
4245 if (SemaRef.checkArgCount(TheCall, 1))
4246 return true;
4247
4248 // Ensure input expr type is a scalar/vector
4249 if (CheckAnyScalarOrVector(&SemaRef, TheCall, 0))
4250 return true;
4251
4252 QualType InputTy = TheCall->getArg(0)->getType();
4253 ASTContext &Ctx = getASTContext();
4254
4255 QualType RetTy;
4256
4257 // If vector, construct bool vector of same size
4258 if (const auto *VecTy = InputTy->getAs<ExtVectorType>()) {
4259 unsigned NumElts = VecTy->getNumElements();
4260 RetTy = Ctx.getExtVectorType(Ctx.BoolTy, NumElts);
4261 } else {
4262 // Scalar case
4263 RetTy = Ctx.BoolTy;
4264 }
4265
4266 TheCall->setType(RetTy);
4267 break;
4268 }
4269 case Builtin::BI__builtin_hlsl_wave_active_max:
4270 case Builtin::BI__builtin_hlsl_wave_active_min:
4271 case Builtin::BI__builtin_hlsl_wave_active_sum:
4272 case Builtin::BI__builtin_hlsl_wave_active_product: {
4273 if (SemaRef.checkArgCount(TheCall, 1))
4274 return true;
4275
4276 // Ensure input expr type is a scalar/vector and the same as the return type
4277 if (CheckAnyScalarOrVector(&SemaRef, TheCall, 0))
4278 return true;
4279 if (CheckWaveActive(&SemaRef, TheCall))
4280 return true;
4281 ExprResult Expr = TheCall->getArg(0);
4282 QualType ArgTyExpr = Expr.get()->getType();
4283 TheCall->setType(ArgTyExpr);
4284 break;
4285 }
4286 case Builtin::BI__builtin_hlsl_wave_active_bit_or:
4287 case Builtin::BI__builtin_hlsl_wave_active_bit_xor:
4288 case Builtin::BI__builtin_hlsl_wave_active_bit_and: {
4289 if (SemaRef.checkArgCount(TheCall, 1))
4290 return true;
4291
4292 // Ensure input expr type is a scalar/vector
4293 if (CheckAnyScalarOrVector(&SemaRef, TheCall, 0))
4294 return true;
4295
4296 if (CheckWaveActive(&SemaRef, TheCall))
4297 return true;
4298
4299 // Ensure the expr type is interpretable as a uint or vector<uint>
4300 ExprResult Expr = TheCall->getArg(0);
4301 QualType ArgTyExpr = Expr.get()->getType();
4302 auto *VTy = ArgTyExpr->getAs<VectorType>();
4303 if (!(ArgTyExpr->isIntegerType() ||
4304 (VTy && VTy->getElementType()->isIntegerType()))) {
4305 SemaRef.Diag(TheCall->getArg(0)->getBeginLoc(),
4306 diag::err_builtin_invalid_arg_type)
4307 << ArgTyExpr << SemaRef.Context.UnsignedIntTy << 1 << 0 << 0;
4308 return true;
4309 }
4310
4311 // Ensure input expr type is the same as the return type
4312 TheCall->setType(ArgTyExpr);
4313 break;
4314 }
4315 // Note these are llvm builtins that we want to catch invalid intrinsic
4316 // generation. Normal handling of these builtins will occur elsewhere.
4317 case Builtin::BI__builtin_elementwise_bitreverse: {
4318 // does not include a check for number of arguments
4319 // because that is done previously
4320 if (CheckAllArgTypesAreCorrect(&SemaRef, TheCall,
4322 return true;
4323 break;
4324 }
4325 case Builtin::BI__builtin_hlsl_wave_prefix_count_bits: {
4326 if (SemaRef.checkArgCount(TheCall, 1))
4327 return true;
4328
4329 QualType ArgType = TheCall->getArg(0)->getType();
4330
4331 if (!(ArgType->isScalarType())) {
4332 SemaRef.Diag(TheCall->getArg(0)->getBeginLoc(),
4333 diag::err_typecheck_expect_any_scalar_or_vector)
4334 << ArgType << 0;
4335 return true;
4336 }
4337
4338 if (!(ArgType->isBooleanType())) {
4339 SemaRef.Diag(TheCall->getArg(0)->getBeginLoc(),
4340 diag::err_typecheck_expect_any_scalar_or_vector)
4341 << ArgType << 0;
4342 return true;
4343 }
4344
4345 break;
4346 }
4347 case Builtin::BI__builtin_hlsl_wave_read_lane_at: {
4348 if (SemaRef.checkArgCount(TheCall, 2))
4349 return true;
4350
4351 // Ensure index parameter type can be interpreted as a uint
4352 ExprResult Index = TheCall->getArg(1);
4353 QualType ArgTyIndex = Index.get()->getType();
4354 if (!ArgTyIndex->isIntegerType()) {
4355 SemaRef.Diag(TheCall->getArg(1)->getBeginLoc(),
4356 diag::err_typecheck_convert_incompatible)
4357 << ArgTyIndex << SemaRef.Context.UnsignedIntTy << 1 << 0 << 0;
4358 return true;
4359 }
4360
4361 // Ensure input expr type is a scalar/vector and the same as the return type
4362 if (CheckAnyScalarOrVector(&SemaRef, TheCall, 0))
4363 return true;
4364
4365 ExprResult Expr = TheCall->getArg(0);
4366 QualType ArgTyExpr = Expr.get()->getType();
4367 TheCall->setType(ArgTyExpr);
4368 break;
4369 }
4370 case Builtin::BI__builtin_hlsl_wave_get_lane_index: {
4371 if (SemaRef.checkArgCount(TheCall, 0))
4372 return true;
4373 break;
4374 }
4375 case Builtin::BI__builtin_hlsl_wave_prefix_sum:
4376 case Builtin::BI__builtin_hlsl_wave_prefix_product: {
4377 if (SemaRef.checkArgCount(TheCall, 1))
4378 return true;
4379
4380 // Ensure input expr type is a scalar/vector and the same as the return type
4381 if (CheckAnyScalarOrVector(&SemaRef, TheCall, 0))
4382 return true;
4383 if (CheckWavePrefix(&SemaRef, TheCall))
4384 return true;
4385 ExprResult Expr = TheCall->getArg(0);
4386 QualType ArgTyExpr = Expr.get()->getType();
4387 TheCall->setType(ArgTyExpr);
4388 break;
4389 }
4390 case Builtin::BI__builtin_hlsl_quad_read_across_x:
4391 case Builtin::BI__builtin_hlsl_quad_read_across_y: {
4392 if (SemaRef.checkArgCount(TheCall, 1))
4393 return true;
4394
4395 if (CheckAnyScalarOrVector(&SemaRef, TheCall, 0))
4396 return true;
4397 if (CheckNotBoolScalarOrVector(&SemaRef, TheCall, 0))
4398 return true;
4399 ExprResult Expr = TheCall->getArg(0);
4400 QualType ArgTyExpr = Expr.get()->getType();
4401 TheCall->setType(ArgTyExpr);
4402 break;
4403 }
4404 case Builtin::BI__builtin_hlsl_elementwise_splitdouble: {
4405 if (SemaRef.checkArgCount(TheCall, 3))
4406 return true;
4407
4408 if (CheckScalarOrVector(&SemaRef, TheCall, SemaRef.Context.DoubleTy, 0) ||
4409 CheckScalarOrVector(&SemaRef, TheCall, SemaRef.Context.UnsignedIntTy,
4410 1) ||
4411 CheckScalarOrVector(&SemaRef, TheCall, SemaRef.Context.UnsignedIntTy,
4412 2))
4413 return true;
4414
4415 if (CheckModifiableLValue(&SemaRef, TheCall, 1) ||
4416 CheckModifiableLValue(&SemaRef, TheCall, 2))
4417 return true;
4418 break;
4419 }
4420 case Builtin::BI__builtin_hlsl_elementwise_clip: {
4421 if (SemaRef.checkArgCount(TheCall, 1))
4422 return true;
4423
4424 if (CheckScalarOrVector(&SemaRef, TheCall, SemaRef.Context.FloatTy, 0))
4425 return true;
4426 break;
4427 }
4428 case Builtin::BI__builtin_elementwise_acos:
4429 case Builtin::BI__builtin_elementwise_asin:
4430 case Builtin::BI__builtin_elementwise_atan:
4431 case Builtin::BI__builtin_elementwise_atan2:
4432 case Builtin::BI__builtin_elementwise_ceil:
4433 case Builtin::BI__builtin_elementwise_cos:
4434 case Builtin::BI__builtin_elementwise_cosh:
4435 case Builtin::BI__builtin_elementwise_exp:
4436 case Builtin::BI__builtin_elementwise_exp2:
4437 case Builtin::BI__builtin_elementwise_exp10:
4438 case Builtin::BI__builtin_elementwise_floor:
4439 case Builtin::BI__builtin_elementwise_fmod:
4440 case Builtin::BI__builtin_elementwise_log:
4441 case Builtin::BI__builtin_elementwise_log2:
4442 case Builtin::BI__builtin_elementwise_log10:
4443 case Builtin::BI__builtin_elementwise_pow:
4444 case Builtin::BI__builtin_elementwise_roundeven:
4445 case Builtin::BI__builtin_elementwise_sin:
4446 case Builtin::BI__builtin_elementwise_sinh:
4447 case Builtin::BI__builtin_elementwise_sqrt:
4448 case Builtin::BI__builtin_elementwise_tan:
4449 case Builtin::BI__builtin_elementwise_tanh:
4450 case Builtin::BI__builtin_elementwise_trunc: {
4451 if (CheckAllArgTypesAreCorrect(&SemaRef, TheCall,
4453 return true;
4454 break;
4455 }
4456 case Builtin::BI__builtin_hlsl_buffer_update_counter: {
4457 assert(TheCall->getNumArgs() == 2 && "expected 2 args");
4458 auto checkResTy = [](const HLSLAttributedResourceType *ResTy) -> bool {
4459 return !(ResTy->getAttrs().ResourceClass == ResourceClass::UAV &&
4460 ResTy->getAttrs().RawBuffer && ResTy->hasContainedType());
4461 };
4462 if (CheckResourceHandle(&SemaRef, TheCall, 0, checkResTy))
4463 return true;
4464 Expr *OffsetExpr = TheCall->getArg(1);
4465 std::optional<llvm::APSInt> Offset =
4466 OffsetExpr->getIntegerConstantExpr(SemaRef.getASTContext());
4467 if (!Offset.has_value() || std::abs(Offset->getExtValue()) != 1) {
4468 SemaRef.Diag(TheCall->getArg(1)->getBeginLoc(),
4469 diag::err_hlsl_expect_arg_const_int_one_or_neg_one)
4470 << 1;
4471 return true;
4472 }
4473 break;
4474 }
4475 case Builtin::BI__builtin_hlsl_elementwise_f16tof32: {
4476 if (SemaRef.checkArgCount(TheCall, 1))
4477 return true;
4478 if (CheckAllArgTypesAreCorrect(&SemaRef, TheCall,
4480 return true;
4481 // ensure arg integers are 32 bits
4482 if (CheckExpectedBitWidth(&SemaRef, TheCall, 0, 32))
4483 return true;
4484 // check it wasn't a bool type
4485 QualType ArgTy = TheCall->getArg(0)->getType();
4486 if (auto *VTy = ArgTy->getAs<VectorType>())
4487 ArgTy = VTy->getElementType();
4488 if (ArgTy->isBooleanType()) {
4489 SemaRef.Diag(TheCall->getArg(0)->getBeginLoc(),
4490 diag::err_builtin_invalid_arg_type)
4491 << 1 << /* scalar or vector of */ 5 << /* unsigned int */ 3
4492 << /* no fp */ 0 << TheCall->getArg(0)->getType();
4493 return true;
4494 }
4495
4496 SetElementTypeAsReturnType(&SemaRef, TheCall, getASTContext().FloatTy);
4497 break;
4498 }
4499 case Builtin::BI__builtin_hlsl_elementwise_f32tof16: {
4500 if (SemaRef.checkArgCount(TheCall, 1))
4501 return true;
4503 return true;
4505 getASTContext().UnsignedIntTy);
4506 break;
4507 }
4508 }
4509 return false;
4510}
4511
4515 WorkList.push_back(BaseTy);
4516 while (!WorkList.empty()) {
4517 QualType T = WorkList.pop_back_val();
4518 T = T.getCanonicalType().getUnqualifiedType();
4519 if (const auto *AT = dyn_cast<ConstantArrayType>(T)) {
4520 llvm::SmallVector<QualType, 16> ElementFields;
4521 // Generally I've avoided recursion in this algorithm, but arrays of
4522 // structs could be time-consuming to flatten and churn through on the
4523 // work list. Hopefully nesting arrays of structs containing arrays
4524 // of structs too many levels deep is unlikely.
4525 BuildFlattenedTypeList(AT->getElementType(), ElementFields);
4526 // Repeat the element's field list n times.
4527 for (uint64_t Ct = 0; Ct < AT->getZExtSize(); ++Ct)
4528 llvm::append_range(List, ElementFields);
4529 continue;
4530 }
4531 // Vectors can only have element types that are builtin types, so this can
4532 // add directly to the list instead of to the WorkList.
4533 if (const auto *VT = dyn_cast<VectorType>(T)) {
4534 List.insert(List.end(), VT->getNumElements(), VT->getElementType());
4535 continue;
4536 }
4537 if (const auto *MT = dyn_cast<ConstantMatrixType>(T)) {
4538 List.insert(List.end(), MT->getNumElementsFlattened(),
4539 MT->getElementType());
4540 continue;
4541 }
4542 if (const auto *RD = T->getAsCXXRecordDecl()) {
4543 if (RD->isStandardLayout())
4544 RD = RD->getStandardLayoutBaseWithFields();
4545
4546 // For types that we shouldn't decompose (unions and non-aggregates), just
4547 // add the type itself to the list.
4548 if (RD->isUnion() || !RD->isAggregate()) {
4549 List.push_back(T);
4550 continue;
4551 }
4552
4554 for (const auto *FD : RD->fields())
4555 if (!FD->isUnnamedBitField())
4556 FieldTypes.push_back(FD->getType());
4557 // Reverse the newly added sub-range.
4558 std::reverse(FieldTypes.begin(), FieldTypes.end());
4559 llvm::append_range(WorkList, FieldTypes);
4560
4561 // If this wasn't a standard layout type we may also have some base
4562 // classes to deal with.
4563 if (!RD->isStandardLayout()) {
4564 FieldTypes.clear();
4565 for (const auto &Base : RD->bases())
4566 FieldTypes.push_back(Base.getType());
4567 std::reverse(FieldTypes.begin(), FieldTypes.end());
4568 llvm::append_range(WorkList, FieldTypes);
4569 }
4570 continue;
4571 }
4572 List.push_back(T);
4573 }
4574}
4575
4577 // null and array types are not allowed.
4578 if (QT.isNull() || QT->isArrayType())
4579 return false;
4580
4581 // UDT types are not allowed
4582 if (QT->isRecordType())
4583 return false;
4584
4585 if (QT->isBooleanType() || QT->isEnumeralType())
4586 return false;
4587
4588 // the only other valid builtin types are scalars or vectors
4589 if (QT->isArithmeticType()) {
4590 if (SemaRef.Context.getTypeSize(QT) / 8 > 16)
4591 return false;
4592 return true;
4593 }
4594
4595 if (const VectorType *VT = QT->getAs<VectorType>()) {
4596 int ArraySize = VT->getNumElements();
4597
4598 if (ArraySize > 4)
4599 return false;
4600
4601 QualType ElTy = VT->getElementType();
4602 if (ElTy->isBooleanType())
4603 return false;
4604
4605 if (SemaRef.Context.getTypeSize(QT) / 8 > 16)
4606 return false;
4607 return true;
4608 }
4609
4610 return false;
4611}
4612
4614 if (T1.isNull() || T2.isNull())
4615 return false;
4616
4619
4620 // If both types are the same canonical type, they're obviously compatible.
4621 if (SemaRef.getASTContext().hasSameType(T1, T2))
4622 return true;
4623
4625 BuildFlattenedTypeList(T1, T1Types);
4627 BuildFlattenedTypeList(T2, T2Types);
4628
4629 // Check the flattened type list
4630 return llvm::equal(T1Types, T2Types,
4631 [this](QualType LHS, QualType RHS) -> bool {
4632 return SemaRef.IsLayoutCompatible(LHS, RHS);
4633 });
4634}
4635
4637 FunctionDecl *Old) {
4638 if (New->getNumParams() != Old->getNumParams())
4639 return true;
4640
4641 bool HadError = false;
4642
4643 for (unsigned i = 0, e = New->getNumParams(); i != e; ++i) {
4644 ParmVarDecl *NewParam = New->getParamDecl(i);
4645 ParmVarDecl *OldParam = Old->getParamDecl(i);
4646
4647 // HLSL parameter declarations for inout and out must match between
4648 // declarations. In HLSL inout and out are ambiguous at the call site,
4649 // but have different calling behavior, so you cannot overload a
4650 // method based on a difference between inout and out annotations.
4651 const auto *NDAttr = NewParam->getAttr<HLSLParamModifierAttr>();
4652 unsigned NSpellingIdx = (NDAttr ? NDAttr->getSpellingListIndex() : 0);
4653 const auto *ODAttr = OldParam->getAttr<HLSLParamModifierAttr>();
4654 unsigned OSpellingIdx = (ODAttr ? ODAttr->getSpellingListIndex() : 0);
4655
4656 if (NSpellingIdx != OSpellingIdx) {
4657 SemaRef.Diag(NewParam->getLocation(),
4658 diag::err_hlsl_param_qualifier_mismatch)
4659 << NDAttr << NewParam;
4660 SemaRef.Diag(OldParam->getLocation(), diag::note_previous_declaration_as)
4661 << ODAttr;
4662 HadError = true;
4663 }
4664 }
4665 return HadError;
4666}
4667
4668// Generally follows PerformScalarCast, with cases reordered for
4669// clarity of what types are supported
4671
4672 if (!SrcTy->isScalarType() || !DestTy->isScalarType())
4673 return false;
4674
4675 if (SemaRef.getASTContext().hasSameUnqualifiedType(SrcTy, DestTy))
4676 return true;
4677
4678 switch (SrcTy->getScalarTypeKind()) {
4679 case Type::STK_Bool: // casting from bool is like casting from an integer
4680 case Type::STK_Integral:
4681 switch (DestTy->getScalarTypeKind()) {
4682 case Type::STK_Bool:
4683 case Type::STK_Integral:
4684 case Type::STK_Floating:
4685 return true;
4686 case Type::STK_CPointer:
4690 llvm_unreachable("HLSL doesn't support pointers.");
4693 llvm_unreachable("HLSL doesn't support complex types.");
4695 llvm_unreachable("HLSL doesn't support fixed point types.");
4696 }
4697 llvm_unreachable("Should have returned before this");
4698
4699 case Type::STK_Floating:
4700 switch (DestTy->getScalarTypeKind()) {
4701 case Type::STK_Floating:
4702 case Type::STK_Bool:
4703 case Type::STK_Integral:
4704 return true;
4707 llvm_unreachable("HLSL doesn't support complex types.");
4709 llvm_unreachable("HLSL doesn't support fixed point types.");
4710 case Type::STK_CPointer:
4714 llvm_unreachable("HLSL doesn't support pointers.");
4715 }
4716 llvm_unreachable("Should have returned before this");
4717
4719 case Type::STK_CPointer:
4722 llvm_unreachable("HLSL doesn't support pointers.");
4723
4725 llvm_unreachable("HLSL doesn't support fixed point types.");
4726
4729 llvm_unreachable("HLSL doesn't support complex types.");
4730 }
4731
4732 llvm_unreachable("Unhandled scalar cast");
4733}
4734
4735// Can perform an HLSL Aggregate splat cast if the Dest is an aggregate and the
4736// Src is a scalar, a vector of length 1, or a 1x1 matrix
4737// Or if Dest is a vector and Src is a vector of length 1 or a 1x1 matrix
4739
4740 QualType SrcTy = Src->getType();
4741 // Not a valid HLSL Aggregate Splat cast if Dest is a scalar or if this is
4742 // going to be a vector splat from a scalar.
4743 if ((SrcTy->isScalarType() && DestTy->isVectorType()) ||
4744 DestTy->isScalarType())
4745 return false;
4746
4747 const VectorType *SrcVecTy = SrcTy->getAs<VectorType>();
4748 const ConstantMatrixType *SrcMatTy = SrcTy->getAs<ConstantMatrixType>();
4749
4750 // Src isn't a scalar, a vector of length 1, or a 1x1 matrix
4751 if (!SrcTy->isScalarType() &&
4752 !(SrcVecTy && SrcVecTy->getNumElements() == 1) &&
4753 !(SrcMatTy && SrcMatTy->getNumElementsFlattened() == 1))
4754 return false;
4755
4756 if (SrcVecTy)
4757 SrcTy = SrcVecTy->getElementType();
4758 else if (SrcMatTy)
4759 SrcTy = SrcMatTy->getElementType();
4760
4762 BuildFlattenedTypeList(DestTy, DestTypes);
4763
4764 for (unsigned I = 0, Size = DestTypes.size(); I < Size; ++I) {
4765 if (DestTypes[I]->isUnionType())
4766 return false;
4767 if (!CanPerformScalarCast(SrcTy, DestTypes[I]))
4768 return false;
4769 }
4770 return true;
4771}
4772
4773// Can we perform an HLSL Elementwise cast?
4775
4776 // Don't handle casts where LHS and RHS are any combination of scalar/vector
4777 // There must be an aggregate somewhere
4778 QualType SrcTy = Src->getType();
4779 if (SrcTy->isScalarType()) // always a splat and this cast doesn't handle that
4780 return false;
4781
4782 if (SrcTy->isVectorType() &&
4783 (DestTy->isScalarType() || DestTy->isVectorType()))
4784 return false;
4785
4786 if (SrcTy->isConstantMatrixType() &&
4787 (DestTy->isScalarType() || DestTy->isConstantMatrixType()))
4788 return false;
4789
4791 BuildFlattenedTypeList(DestTy, DestTypes);
4793 BuildFlattenedTypeList(SrcTy, SrcTypes);
4794
4795 // Usually the size of SrcTypes must be greater than or equal to the size of
4796 // DestTypes.
4797 if (SrcTypes.size() < DestTypes.size())
4798 return false;
4799
4800 unsigned SrcSize = SrcTypes.size();
4801 unsigned DstSize = DestTypes.size();
4802 unsigned I;
4803 for (I = 0; I < DstSize && I < SrcSize; I++) {
4804 if (SrcTypes[I]->isUnionType() || DestTypes[I]->isUnionType())
4805 return false;
4806 if (!CanPerformScalarCast(SrcTypes[I], DestTypes[I])) {
4807 return false;
4808 }
4809 }
4810
4811 // check the rest of the source type for unions.
4812 for (; I < SrcSize; I++) {
4813 if (SrcTypes[I]->isUnionType())
4814 return false;
4815 }
4816 return true;
4817}
4818
4820 assert(Param->hasAttr<HLSLParamModifierAttr>() &&
4821 "We should not get here without a parameter modifier expression");
4822 const auto *Attr = Param->getAttr<HLSLParamModifierAttr>();
4823 if (Attr->getABI() == ParameterABI::Ordinary)
4824 return ExprResult(Arg);
4825
4826 bool IsInOut = Attr->getABI() == ParameterABI::HLSLInOut;
4827 if (!Arg->isLValue()) {
4828 SemaRef.Diag(Arg->getBeginLoc(), diag::error_hlsl_inout_lvalue)
4829 << Arg << (IsInOut ? 1 : 0);
4830 return ExprError();
4831 }
4832
4833 ASTContext &Ctx = SemaRef.getASTContext();
4834
4835 QualType Ty = Param->getType().getNonLValueExprType(Ctx);
4836
4837 // HLSL allows implicit conversions from scalars to vectors, but not the
4838 // inverse, so we need to disallow `inout` with scalar->vector or
4839 // scalar->matrix conversions.
4840 if (Arg->getType()->isScalarType() != Ty->isScalarType()) {
4841 SemaRef.Diag(Arg->getBeginLoc(), diag::error_hlsl_inout_scalar_extension)
4842 << Arg << (IsInOut ? 1 : 0);
4843 return ExprError();
4844 }
4845
4846 auto *ArgOpV = new (Ctx) OpaqueValueExpr(Param->getBeginLoc(), Arg->getType(),
4847 VK_LValue, OK_Ordinary, Arg);
4848
4849 // Parameters are initialized via copy initialization. This allows for
4850 // overload resolution of argument constructors.
4851 InitializedEntity Entity =
4853 ExprResult Res =
4854 SemaRef.PerformCopyInitialization(Entity, Param->getBeginLoc(), ArgOpV);
4855 if (Res.isInvalid())
4856 return ExprError();
4857 Expr *Base = Res.get();
4858 // After the cast, drop the reference type when creating the exprs.
4859 Ty = Ty.getNonLValueExprType(Ctx);
4860 auto *OpV = new (Ctx)
4861 OpaqueValueExpr(Param->getBeginLoc(), Ty, VK_LValue, OK_Ordinary, Base);
4862
4863 // Writebacks are performed with `=` binary operator, which allows for
4864 // overload resolution on writeback result expressions.
4865 Res = SemaRef.ActOnBinOp(SemaRef.getCurScope(), Param->getBeginLoc(),
4866 tok::equal, ArgOpV, OpV);
4867
4868 if (Res.isInvalid())
4869 return ExprError();
4870 Expr *Writeback = Res.get();
4871 auto *OutExpr =
4872 HLSLOutArgExpr::Create(Ctx, Ty, ArgOpV, OpV, Writeback, IsInOut);
4873
4874 return ExprResult(OutExpr);
4875}
4876
4878 // If HLSL gains support for references, all the sites that use this will need
4879 // to be updated with semantic checking to produce errors for
4880 // pointers/references.
4881 assert(!Ty->isReferenceType() &&
4882 "Pointer and reference types cannot be inout or out parameters");
4883 Ty = SemaRef.getASTContext().getLValueReferenceType(Ty);
4884 Ty.addRestrict();
4885 return Ty;
4886}
4887
4888// Returns true if the type has a non-empty constant buffer layout (i.e. if it
4889// is a scalar, vector or matrix, or if it contains any of these).
4891 const Type *Ty = QT->getUnqualifiedDesugaredType();
4892 if (Ty->isScalarType() || Ty->isVectorType() || Ty->isMatrixType())
4893 return true;
4894
4896 return false;
4897
4898 if (const auto *RD = Ty->getAsCXXRecordDecl()) {
4899 for (const auto *FD : RD->fields()) {
4901 return true;
4902 }
4903 assert(RD->getNumBases() <= 1 &&
4904 "HLSL doesn't support multiple inheritance");
4905 return RD->getNumBases()
4906 ? hasConstantBufferLayout(RD->bases_begin()->getType())
4907 : false;
4908 }
4909
4910 if (const auto *AT = dyn_cast<ArrayType>(Ty)) {
4911 if (const auto *CAT = dyn_cast<ConstantArrayType>(AT))
4912 if (isZeroSizedArray(CAT))
4913 return false;
4915 }
4916
4917 return false;
4918}
4919
4920static bool IsDefaultBufferConstantDecl(const ASTContext &Ctx, VarDecl *VD) {
4921 bool IsVulkan =
4922 Ctx.getTargetInfo().getTriple().getOS() == llvm::Triple::Vulkan;
4923 bool IsVKPushConstant = IsVulkan && VD->hasAttr<HLSLVkPushConstantAttr>();
4924 QualType QT = VD->getType();
4925 return VD->getDeclContext()->isTranslationUnit() &&
4926 QT.getAddressSpace() == LangAS::Default &&
4927 VD->getStorageClass() != SC_Static &&
4928 !VD->hasAttr<HLSLVkConstantIdAttr>() && !IsVKPushConstant &&
4930}
4931
4933 // The variable already has an address space (groupshared for ex).
4934 if (Decl->getType().hasAddressSpace())
4935 return;
4936
4937 if (Decl->getType()->isDependentType())
4938 return;
4939
4940 QualType Type = Decl->getType();
4941
4942 if (Decl->hasAttr<HLSLVkExtBuiltinInputAttr>()) {
4943 LangAS ImplAS = LangAS::hlsl_input;
4944 Type = SemaRef.getASTContext().getAddrSpaceQualType(Type, ImplAS);
4945 Decl->setType(Type);
4946 return;
4947 }
4948
4949 if (Decl->hasAttr<HLSLVkExtBuiltinOutputAttr>()) {
4950 LangAS ImplAS = LangAS::hlsl_output;
4951 Type = SemaRef.getASTContext().getAddrSpaceQualType(Type, ImplAS);
4952 Decl->setType(Type);
4953
4954 // HLSL uses `static` differently than C++. For BuiltIn output, the static
4955 // does not imply private to the module scope.
4956 // Marking it as external to reflect the semantic this attribute brings.
4957 // See https://github.com/microsoft/hlsl-specs/issues/350
4958 Decl->setStorageClass(SC_Extern);
4959 return;
4960 }
4961
4962 bool IsVulkan = getASTContext().getTargetInfo().getTriple().getOS() ==
4963 llvm::Triple::Vulkan;
4964 if (IsVulkan && Decl->hasAttr<HLSLVkPushConstantAttr>()) {
4965 if (HasDeclaredAPushConstant)
4966 SemaRef.Diag(Decl->getLocation(), diag::err_hlsl_push_constant_unique);
4967
4969 Type = SemaRef.getASTContext().getAddrSpaceQualType(Type, ImplAS);
4970 Decl->setType(Type);
4971 HasDeclaredAPushConstant = true;
4972 return;
4973 }
4974
4975 if (Type->isSamplerT() || Type->isVoidType())
4976 return;
4977
4978 // Resource handles.
4980 return;
4981
4982 // Only static globals belong to the Private address space.
4983 // Non-static globals belong to the cbuffer.
4984 if (Decl->getStorageClass() != SC_Static && !Decl->isStaticDataMember())
4985 return;
4986
4988 Type = SemaRef.getASTContext().getAddrSpaceQualType(Type, ImplAS);
4989 Decl->setType(Type);
4990}
4991
4992namespace {
4993
4994// Helper class for assigning bindings to resources declared within a struct.
4995// It keeps track of all binding attributes declared on a struct instance, and
4996// the offsets for each register type that have been assigned so far.
4997// Handles both explicit and implicit bindings.
4998class StructBindingContext {
4999 // Bindings and offsets per register type. We only need to support four
5000 // register types - SRV (t), UAV (u), CBuffer (b), and Sampler (s).
5001 HLSLResourceBindingAttr *RegBindingsAttrs[4];
5002 unsigned RegBindingOffset[4];
5003
5004 // Make sure the RegisterType values are what we expect
5005 static_assert(static_cast<unsigned>(RegisterType::SRV) == 0 &&
5006 static_cast<unsigned>(RegisterType::UAV) == 1 &&
5007 static_cast<unsigned>(RegisterType::CBuffer) == 2 &&
5008 static_cast<unsigned>(RegisterType::Sampler) == 3,
5009 "unexpected register type values");
5010
5011 // Vulkan binding attribute does not vary by register type.
5012 HLSLVkBindingAttr *VkBindingAttr;
5013 unsigned VkBindingOffset;
5014
5015public:
5016 // Constructor: gather all binding attributes on a struct instance and
5017 // initialize offsets.
5018 StructBindingContext(VarDecl *VD) {
5019 for (unsigned i = 0; i < 4; ++i) {
5020 RegBindingsAttrs[i] = nullptr;
5021 RegBindingOffset[i] = 0;
5022 }
5023 VkBindingAttr = nullptr;
5024 VkBindingOffset = 0;
5025
5026 ASTContext &AST = VD->getASTContext();
5027 bool IsSpirv = AST.getTargetInfo().getTriple().isSPIRV();
5028
5029 for (Attr *A : VD->attrs()) {
5030 if (auto *RBA = dyn_cast<HLSLResourceBindingAttr>(A)) {
5031 RegisterType RegType = RBA->getRegisterType();
5032 unsigned RegTypeIdx = static_cast<unsigned>(RegType);
5033 // Ignore unsupported register annotations, such as 'c' or 'i'.
5034 if (RegTypeIdx < 4)
5035 RegBindingsAttrs[RegTypeIdx] = RBA;
5036 continue;
5037 }
5038 // Gather the Vulkan binding attributes only if the target is SPIR-V.
5039 if (IsSpirv) {
5040 if (auto *VBA = dyn_cast<HLSLVkBindingAttr>(A))
5041 VkBindingAttr = VBA;
5042 }
5043 }
5044 }
5045
5046 // Creates a binding attribute for a resource based on the gathered attributes
5047 // and the required register type and range.
5048 Attr *createBindingAttr(SemaHLSL &S, ASTContext &AST, RegisterType RegType,
5049 unsigned Range, bool HasCounter) {
5050 assert(static_cast<unsigned>(RegType) < 4 && "unexpected register type");
5051
5052 if (VkBindingAttr) {
5053 unsigned Offset = VkBindingOffset;
5054 VkBindingOffset += Range;
5055 return HLSLVkBindingAttr::CreateImplicit(
5056 AST, VkBindingAttr->getBinding() + Offset, VkBindingAttr->getSet(),
5057 VkBindingAttr->getRange());
5058 }
5059
5060 HLSLResourceBindingAttr *RBA =
5061 RegBindingsAttrs[static_cast<unsigned>(RegType)];
5062 HLSLResourceBindingAttr *NewAttr = nullptr;
5063
5064 if (RBA && RBA->hasRegisterSlot()) {
5065 // Explicit binding - create a new attribute with offseted slot number
5066 // based on the required register type.
5067 unsigned Offset = RegBindingOffset[static_cast<unsigned>(RegType)];
5068 RegBindingOffset[static_cast<unsigned>(RegType)] += Range;
5069
5070 unsigned NewSlotNumber = RBA->getSlotNumber() + Offset;
5071 StringRef NewSlotNumberStr =
5072 createRegisterString(AST, RBA->getRegisterType(), NewSlotNumber);
5073 NewAttr = HLSLResourceBindingAttr::CreateImplicit(
5074 AST, NewSlotNumberStr, RBA->getSpace(), RBA->getRange());
5075 NewAttr->setBinding(RegType, NewSlotNumber, RBA->getSpaceNumber());
5076 } else {
5077 // No binding attribute or space-only binding - create a binding
5078 // attribute for implicit binding.
5079 NewAttr = HLSLResourceBindingAttr::CreateImplicit(AST, "", "0", {});
5080 NewAttr->setBinding(RegType, std::nullopt,
5081 RBA ? RBA->getSpaceNumber() : 0);
5082 NewAttr->setImplicitBindingOrderID(S.getNextImplicitBindingOrderID());
5083 }
5084 if (HasCounter)
5085 NewAttr->setImplicitCounterBindingOrderID(
5087 return NewAttr;
5088 }
5089};
5090
5091// Creates a global variable declaration for a resource field embedded in a
5092// struct, assigns it a binding, initializes it, and associates it with the
5093// struct declaration via an HLSLAssociatedResourceDeclAttr.
5094static void createGlobalResourceDeclForStruct(
5095 Sema &S, VarDecl *ParentVD, SourceLocation Loc, IdentifierInfo *Id,
5096 QualType ResTy, StructBindingContext &BindingCtx) {
5097 assert(isResourceRecordTypeOrArrayOf(ResTy) &&
5098 "expected resource type or array of resources");
5099
5100 DeclContext *DC = ParentVD->getNonTransparentDeclContext();
5101 assert(DC->isTranslationUnit() && "expected translation unit decl context");
5102
5103 ASTContext &AST = S.getASTContext();
5104 VarDecl *ResDecl =
5105 VarDecl::Create(AST, DC, Loc, Loc, Id, ResTy, nullptr, SC_None);
5106
5107 unsigned Range = 1;
5108 const Type *SingleResTy = ResTy.getTypePtr()->getUnqualifiedDesugaredType();
5109 while (const auto *AT = dyn_cast<ArrayType>(SingleResTy)) {
5110 const auto *CAT = dyn_cast<ConstantArrayType>(AT);
5111 Range = CAT ? (Range * CAT->getSize().getZExtValue()) : 0;
5112 SingleResTy =
5114 }
5115 const HLSLAttributedResourceType *ResHandleTy =
5116 HLSLAttributedResourceType::findHandleTypeOnResource(SingleResTy);
5117
5118 // Add a binding attribute to the global resource declaration.
5119 bool HasCounter = hasCounterHandle(SingleResTy->getAsCXXRecordDecl());
5120 Attr *BindingAttr = BindingCtx.createBindingAttr(
5121 S.HLSL(), AST, getRegisterType(ResHandleTy), Range, HasCounter);
5122 ResDecl->addAttr(BindingAttr);
5123 ResDecl->addAttr(InternalLinkageAttr::CreateImplicit(AST));
5124 ResDecl->setImplicit();
5125
5126 if (Range == 1)
5127 S.HLSL().initGlobalResourceDecl(ResDecl);
5128 else
5129 S.HLSL().initGlobalResourceArrayDecl(ResDecl);
5130
5131 ParentVD->addAttr(
5132 HLSLAssociatedResourceDeclAttr::CreateImplicit(AST, ResDecl));
5133 DC->addDecl(ResDecl);
5134
5135 DeclGroupRef DG(ResDecl);
5137}
5138
5139static void handleArrayOfStructWithResources(
5140 Sema &S, VarDecl *ParentVD, const ConstantArrayType *CAT,
5141 EmbeddedResourceNameBuilder &NameBuilder, StructBindingContext &BindingCtx);
5142
5143// Scans base and all fields of a struct/class type to find all embedded
5144// resources or resource arrays. Creates a global variable for each resource
5145// found.
5146static void handleStructWithResources(Sema &S, VarDecl *ParentVD,
5147 const CXXRecordDecl *RD,
5148 EmbeddedResourceNameBuilder &NameBuilder,
5149 StructBindingContext &BindingCtx) {
5150
5151 // Scan the base classes.
5152 assert(RD->getNumBases() <= 1 && "HLSL doesn't support multiple inheritance");
5153 const auto *BasesIt = RD->bases_begin();
5154 if (BasesIt != RD->bases_end()) {
5155 QualType QT = BasesIt->getType();
5156 if (QT->isHLSLIntangibleType()) {
5157 CXXRecordDecl *BaseRD = QT->getAsCXXRecordDecl();
5158 NameBuilder.pushBaseName(BaseRD->getName());
5159 handleStructWithResources(S, ParentVD, BaseRD, NameBuilder, BindingCtx);
5160 NameBuilder.pop();
5161 }
5162 }
5163 // Process this class's fields.
5164 for (const FieldDecl *FD : RD->fields()) {
5165 QualType FDTy = FD->getType().getCanonicalType();
5166 if (!FDTy->isHLSLIntangibleType())
5167 continue;
5168
5169 NameBuilder.pushName(FD->getName());
5170
5172 IdentifierInfo *II = NameBuilder.getNameAsIdentifier(S.getASTContext());
5173 createGlobalResourceDeclForStruct(S, ParentVD, FD->getLocation(), II,
5174 FDTy, BindingCtx);
5175 } else if (const auto *RD = FDTy->getAsCXXRecordDecl()) {
5176 handleStructWithResources(S, ParentVD, RD, NameBuilder, BindingCtx);
5177
5178 } else if (const auto *ArrayTy = dyn_cast<ConstantArrayType>(FDTy)) {
5179 assert(!FDTy->isHLSLResourceRecordArray() &&
5180 "resource arrays should have been already handled");
5181 handleArrayOfStructWithResources(S, ParentVD, ArrayTy, NameBuilder,
5182 BindingCtx);
5183 }
5184 NameBuilder.pop();
5185 }
5186}
5187
5188// Processes array of structs with resources.
5189static void
5190handleArrayOfStructWithResources(Sema &S, VarDecl *ParentVD,
5191 const ConstantArrayType *CAT,
5192 EmbeddedResourceNameBuilder &NameBuilder,
5193 StructBindingContext &BindingCtx) {
5194
5195 QualType ElementTy = CAT->getElementType().getCanonicalType();
5196 assert(ElementTy->isHLSLIntangibleType() && "Expected HLSL intangible type");
5197
5198 const ConstantArrayType *SubCAT = dyn_cast<ConstantArrayType>(ElementTy);
5199 const CXXRecordDecl *ElementRD = ElementTy->getAsCXXRecordDecl();
5200
5201 if (!SubCAT && !ElementRD)
5202 return;
5203
5204 for (unsigned I = 0, E = CAT->getSize().getZExtValue(); I < E; ++I) {
5205 NameBuilder.pushArrayIndex(I);
5206 if (ElementRD)
5207 handleStructWithResources(S, ParentVD, ElementRD, NameBuilder,
5208 BindingCtx);
5209 else
5210 handleArrayOfStructWithResources(S, ParentVD, SubCAT, NameBuilder,
5211 BindingCtx);
5212 NameBuilder.pop();
5213 }
5214}
5215
5216} // namespace
5217
5218// Scans all fields of a user-defined struct (or array of structs)
5219// to find all embedded resources or resource arrays. For each resource
5220// a global variable of the resource type is created and associated
5221// with the parent declaration (VD) through a HLSLAssociatedResourceDeclAttr
5222// attribute.
5223void SemaHLSL::handleGlobalStructOrArrayOfWithResources(VarDecl *VD) {
5224 EmbeddedResourceNameBuilder NameBuilder(VD->getName());
5225 StructBindingContext BindingCtx(VD);
5226
5227 const Type *VDTy = VD->getType().getTypePtr();
5228 assert(VDTy->isHLSLIntangibleType() && !isResourceRecordTypeOrArrayOf(VD) &&
5229 "Expected non-resource struct or array type");
5230
5231 if (const CXXRecordDecl *RD = VDTy->getAsCXXRecordDecl()) {
5232 handleStructWithResources(SemaRef, VD, RD, NameBuilder, BindingCtx);
5233 return;
5234 }
5235
5236 if (const auto *CAT = dyn_cast<ConstantArrayType>(VDTy)) {
5237 handleArrayOfStructWithResources(SemaRef, VD, CAT, NameBuilder, BindingCtx);
5238 return;
5239 }
5240}
5241
5243 if (VD->hasGlobalStorage()) {
5244 // make sure the declaration has a complete type
5245 if (SemaRef.RequireCompleteType(
5246 VD->getLocation(),
5247 SemaRef.getASTContext().getBaseElementType(VD->getType()),
5248 diag::err_typecheck_decl_incomplete_type)) {
5249 VD->setInvalidDecl();
5251 return;
5252 }
5253
5254 // Global variables outside a cbuffer block that are not a resource, static,
5255 // groupshared, or an empty array or struct belong to the default constant
5256 // buffer $Globals (to be created at the end of the translation unit).
5258 // update address space to hlsl_constant
5261 VD->setType(NewTy);
5262 DefaultCBufferDecls.push_back(VD);
5263 }
5264
5265 // find all resources bindings on decl
5266 if (VD->getType()->isHLSLIntangibleType())
5267 collectResourceBindingsOnVarDecl(VD);
5268
5269 if (VD->hasAttr<HLSLVkConstantIdAttr>())
5271
5273 VD->getStorageClass() != SC_Static) {
5274 // Add internal linkage attribute to non-static resource variables. The
5275 // global externally visible storage is accessed through the handle, which
5276 // is a member. The variable itself is not externally visible.
5277 VD->addAttr(InternalLinkageAttr::CreateImplicit(getASTContext()));
5278 }
5279
5280 // process explicit bindings
5281 processExplicitBindingsOnDecl(VD);
5282
5283 // Add implicit binding attribute to non-static resource arrays.
5284 if (VD->getType()->isHLSLResourceRecordArray() &&
5285 VD->getStorageClass() != SC_Static) {
5286 // If the resource array does not have an explicit binding attribute,
5287 // create an implicit one. It will be used to transfer implicit binding
5288 // order_ID to codegen.
5289 ResourceBindingAttrs Binding(VD);
5290 if (!Binding.isExplicit()) {
5291 uint32_t OrderID = getNextImplicitBindingOrderID();
5292 if (Binding.hasBinding())
5293 Binding.setImplicitOrderID(OrderID);
5294 else {
5297 OrderID);
5298 // Re-create the binding object to pick up the new attribute.
5299 Binding = ResourceBindingAttrs(VD);
5300 }
5301 }
5302
5303 // Get to the base type of a potentially multi-dimensional array.
5305
5306 const CXXRecordDecl *RD = Ty->getAsCXXRecordDecl();
5307 if (hasCounterHandle(RD)) {
5308 if (!Binding.hasCounterImplicitOrderID()) {
5309 uint32_t OrderID = getNextImplicitBindingOrderID();
5310 Binding.setCounterImplicitOrderID(OrderID);
5311 }
5312 }
5313 }
5314
5315 // Process resources in user-defined structs, or arrays of such structs.
5316 const Type *VDTy = VD->getType().getTypePtr();
5317 if (VD->getStorageClass() != SC_Static && VDTy->isHLSLIntangibleType() &&
5319 handleGlobalStructOrArrayOfWithResources(VD);
5320
5321 // Mark groupshared variables as extern so they will have
5322 // external storage and won't be default initialized
5323 if (VD->hasAttr<HLSLGroupSharedAddressSpaceAttr>())
5325 }
5326
5328}
5329
5331 assert(VD->getType()->isHLSLResourceRecord() &&
5332 "expected resource record type");
5333
5334 ASTContext &AST = SemaRef.getASTContext();
5335 uint64_t UIntTySize = AST.getTypeSize(AST.UnsignedIntTy);
5336 uint64_t IntTySize = AST.getTypeSize(AST.IntTy);
5337
5338 // Gather resource binding attributes.
5339 ResourceBindingAttrs Binding(VD);
5340
5341 // Find correct initialization method and create its arguments.
5342 QualType ResourceTy = VD->getType();
5343 CXXRecordDecl *ResourceDecl = ResourceTy->getAsCXXRecordDecl();
5344 CXXMethodDecl *CreateMethod = nullptr;
5346
5347 bool HasCounter = hasCounterHandle(ResourceDecl);
5348 const char *CreateMethodName;
5349 if (Binding.isExplicit())
5350 CreateMethodName = HasCounter ? "__createFromBindingWithImplicitCounter"
5351 : "__createFromBinding";
5352 else
5353 CreateMethodName = HasCounter
5354 ? "__createFromImplicitBindingWithImplicitCounter"
5355 : "__createFromImplicitBinding";
5356
5357 CreateMethod =
5358 lookupMethod(SemaRef, ResourceDecl, CreateMethodName, VD->getLocation());
5359
5360 if (!CreateMethod) {
5361 // This can happen if someone creates a struct that looks like an HLSL
5362 // resource record but does not have the required static create method.
5363 // No binding will be generated for it.
5364 assert(!ResourceDecl->isImplicit() &&
5365 "create method lookup should always succeed for built-in resource "
5366 "records");
5367 return false;
5368 }
5369
5370 if (Binding.isExplicit()) {
5371 IntegerLiteral *RegSlot =
5372 IntegerLiteral::Create(AST, llvm::APInt(UIntTySize, Binding.getSlot()),
5374 Args.push_back(RegSlot);
5375 } else {
5376 uint32_t OrderID = (Binding.hasImplicitOrderID())
5377 ? Binding.getImplicitOrderID()
5379 IntegerLiteral *OrderId =
5380 IntegerLiteral::Create(AST, llvm::APInt(UIntTySize, OrderID),
5382 Args.push_back(OrderId);
5383 }
5384
5385 IntegerLiteral *Space =
5386 IntegerLiteral::Create(AST, llvm::APInt(UIntTySize, Binding.getSpace()),
5388 Args.push_back(Space);
5389
5391 AST, llvm::APInt(IntTySize, 1), AST.IntTy, SourceLocation());
5392 Args.push_back(RangeSize);
5393
5395 AST, llvm::APInt(UIntTySize, 0), AST.UnsignedIntTy, SourceLocation());
5396 Args.push_back(Index);
5397
5398 StringRef VarName = VD->getName();
5400 AST, VarName, StringLiteralKind::Ordinary, false,
5401 AST.getStringLiteralArrayType(AST.CharTy.withConst(), VarName.size()),
5402 SourceLocation());
5404 AST, AST.getPointerType(AST.CharTy.withConst()), CK_ArrayToPointerDecay,
5405 Name, nullptr, VK_PRValue, FPOptionsOverride());
5406 Args.push_back(NameCast);
5407
5408 if (HasCounter) {
5409 // Will this be in the correct order?
5410 uint32_t CounterOrderID = getNextImplicitBindingOrderID();
5411 IntegerLiteral *CounterId =
5412 IntegerLiteral::Create(AST, llvm::APInt(UIntTySize, CounterOrderID),
5414 Args.push_back(CounterId);
5415 }
5416
5417 // Make sure the create method template is instantiated and emitted.
5418 if (!CreateMethod->isDefined() && CreateMethod->isTemplateInstantiation())
5419 SemaRef.InstantiateFunctionDefinition(VD->getLocation(), CreateMethod,
5420 true);
5421
5422 // Create CallExpr with a call to the static method and set it as the decl
5423 // initialization.
5425 AST, NestedNameSpecifierLoc(), SourceLocation(), CreateMethod, false,
5426 CreateMethod->getNameInfo(), CreateMethod->getType(), VK_PRValue);
5427
5428 auto *ImpCast = ImplicitCastExpr::Create(
5429 AST, AST.getPointerType(CreateMethod->getType()),
5430 CK_FunctionToPointerDecay, DRE, nullptr, VK_PRValue, FPOptionsOverride());
5431
5432 CallExpr *InitExpr =
5433 CallExpr::Create(AST, ImpCast, Args, ResourceTy, VK_PRValue,
5435 VD->setInit(InitExpr);
5437 SemaRef.CheckCompleteVariableDeclaration(VD);
5438 return true;
5439}
5440
5442 assert(VD->getType()->isHLSLResourceRecordArray() &&
5443 "expected array of resource records");
5444
5445 // Individual resources in a resource array are not initialized here. They
5446 // are initialized later on during codegen when the individual resources are
5447 // accessed. Codegen will emit a call to the resource initialization method
5448 // with the specified array index. We need to make sure though that the method
5449 // for the specific resource type is instantiated, so codegen can emit a call
5450 // to it when the array element is accessed.
5451
5452 // Find correct initialization method based on the resource binding
5453 // information.
5454 ASTContext &AST = SemaRef.getASTContext();
5455 QualType ResElementTy = AST.getBaseElementType(VD->getType());
5456 CXXRecordDecl *ResourceDecl = ResElementTy->getAsCXXRecordDecl();
5457 CXXMethodDecl *CreateMethod = nullptr;
5458
5459 bool HasCounter = hasCounterHandle(ResourceDecl);
5460 ResourceBindingAttrs ResourceAttrs(VD);
5461 if (ResourceAttrs.isExplicit())
5462 // Resource has explicit binding.
5463 CreateMethod =
5464 lookupMethod(SemaRef, ResourceDecl,
5465 HasCounter ? "__createFromBindingWithImplicitCounter"
5466 : "__createFromBinding",
5467 VD->getLocation());
5468 else
5469 // Resource has implicit binding.
5470 CreateMethod = lookupMethod(
5471 SemaRef, ResourceDecl,
5472 HasCounter ? "__createFromImplicitBindingWithImplicitCounter"
5473 : "__createFromImplicitBinding",
5474 VD->getLocation());
5475
5476 if (!CreateMethod)
5477 return false;
5478
5479 // Make sure the create method template is instantiated and emitted.
5480 if (!CreateMethod->isDefined() && CreateMethod->isTemplateInstantiation())
5481 SemaRef.InstantiateFunctionDefinition(VD->getLocation(), CreateMethod,
5482 true);
5483 return true;
5484}
5485
5486// Returns true if the initialization has been handled.
5487// Returns false to use default initialization.
5489 // Objects in the hlsl_constant address space are initialized
5490 // externally, so don't synthesize an implicit initializer.
5492 return true;
5493
5494 // Initialize non-static resources at the global scope.
5495 if (VD->hasGlobalStorage() && VD->getStorageClass() != SC_Static) {
5496 const Type *Ty = VD->getType().getTypePtr();
5497 if (Ty->isHLSLResourceRecord())
5498 return initGlobalResourceDecl(VD);
5499 if (Ty->isHLSLResourceRecordArray())
5500 return initGlobalResourceArrayDecl(VD);
5501 }
5502 return false;
5503}
5504
5505std::optional<const DeclBindingInfo *> SemaHLSL::inferGlobalBinding(Expr *E) {
5506 if (auto *Ternary = dyn_cast<ConditionalOperator>(E)) {
5507 auto TrueInfo = inferGlobalBinding(Ternary->getTrueExpr());
5508 auto FalseInfo = inferGlobalBinding(Ternary->getFalseExpr());
5509 if (!TrueInfo || !FalseInfo)
5510 return std::nullopt;
5511 if (*TrueInfo != *FalseInfo)
5512 return std::nullopt;
5513 return TrueInfo;
5514 }
5515
5516 if (auto *ASE = dyn_cast<ArraySubscriptExpr>(E))
5517 E = ASE->getBase()->IgnoreParenImpCasts();
5518
5519 if (DeclRefExpr *DRE = dyn_cast<DeclRefExpr>(E->IgnoreParens()))
5520 if (VarDecl *VD = dyn_cast<VarDecl>(DRE->getDecl())) {
5521 const Type *Ty = VD->getType()->getUnqualifiedDesugaredType();
5522 if (Ty->isArrayType())
5524
5525 if (const auto *AttrResType =
5526 HLSLAttributedResourceType::findHandleTypeOnResource(Ty)) {
5527 ResourceClass RC = AttrResType->getAttrs().ResourceClass;
5528 return Bindings.getDeclBindingInfo(VD, RC);
5529 }
5530 }
5531
5532 return nullptr;
5533}
5534
5535void SemaHLSL::trackLocalResource(VarDecl *VD, Expr *E) {
5536 std::optional<const DeclBindingInfo *> ExprBinding = inferGlobalBinding(E);
5537 if (!ExprBinding) {
5538 SemaRef.Diag(E->getBeginLoc(),
5539 diag::warn_hlsl_assigning_local_resource_is_not_unique)
5540 << E << VD;
5541 return; // Expr use multiple resources
5542 }
5543
5544 if (*ExprBinding == nullptr)
5545 return; // No binding could be inferred to track, return without error
5546
5547 auto PrevBinding = Assigns.find(VD);
5548 if (PrevBinding == Assigns.end()) {
5549 // No previous binding recorded, simply record the new assignment
5550 Assigns.insert({VD, *ExprBinding});
5551 return;
5552 }
5553
5554 // Otherwise, warn if the assignment implies different resource bindings
5555 if (*ExprBinding != PrevBinding->second) {
5556 SemaRef.Diag(E->getBeginLoc(),
5557 diag::warn_hlsl_assigning_local_resource_is_not_unique)
5558 << E << VD;
5559 SemaRef.Diag(VD->getLocation(), diag::note_var_declared_here) << VD;
5560 return;
5561 }
5562
5563 return;
5564}
5565
5567 Expr *RHSExpr, SourceLocation Loc) {
5568 assert((LHSExpr->getType()->isHLSLResourceRecord() ||
5569 LHSExpr->getType()->isHLSLResourceRecordArray()) &&
5570 "expected LHS to be a resource record or array of resource records");
5571 if (Opc != BO_Assign)
5572 return true;
5573
5574 // If LHS is an array subscript, get the underlying declaration.
5575 Expr *E = LHSExpr;
5576 while (auto *ASE = dyn_cast<ArraySubscriptExpr>(E))
5577 E = ASE->getBase()->IgnoreParenImpCasts();
5578
5579 // Report error if LHS is a non-static resource declared at a global scope.
5580 if (DeclRefExpr *DRE = dyn_cast<DeclRefExpr>(E->IgnoreParens())) {
5581 if (VarDecl *VD = dyn_cast<VarDecl>(DRE->getDecl())) {
5582 if (VD->hasGlobalStorage() && VD->getStorageClass() != SC_Static) {
5583 // assignment to global resource is not allowed
5584 SemaRef.Diag(Loc, diag::err_hlsl_assign_to_global_resource) << VD;
5585 SemaRef.Diag(VD->getLocation(), diag::note_var_declared_here) << VD;
5586 return false;
5587 }
5588
5589 trackLocalResource(VD, RHSExpr);
5590 }
5591 }
5592 return true;
5593}
5594
5595// Walks through the global variable declaration, collects all resource binding
5596// requirements and adds them to Bindings.
5597void SemaHLSL::collectResourceBindingsOnVarDecl(VarDecl *VD) {
5598 assert(VD->hasGlobalStorage() && VD->getType()->isHLSLIntangibleType() &&
5599 "expected global variable that contains HLSL resource");
5600
5601 // Cbuffers and Tbuffers are HLSLBufferDecl types
5602 if (const HLSLBufferDecl *CBufferOrTBuffer = dyn_cast<HLSLBufferDecl>(VD)) {
5603 Bindings.addDeclBindingInfo(VD, CBufferOrTBuffer->isCBuffer()
5604 ? ResourceClass::CBuffer
5605 : ResourceClass::SRV);
5606 return;
5607 }
5608
5609 // Unwrap arrays
5610 // FIXME: Calculate array size while unwrapping
5611 const Type *Ty = VD->getType()->getUnqualifiedDesugaredType();
5612 while (Ty->isArrayType()) {
5613 const ArrayType *AT = cast<ArrayType>(Ty);
5615 }
5616
5617 // Resource (or array of resources)
5618 if (const HLSLAttributedResourceType *AttrResType =
5619 HLSLAttributedResourceType::findHandleTypeOnResource(Ty)) {
5620 Bindings.addDeclBindingInfo(VD, AttrResType->getAttrs().ResourceClass);
5621 return;
5622 }
5623
5624 // User defined record type
5625 if (const RecordType *RT = dyn_cast<RecordType>(Ty))
5626 collectResourceBindingsOnUserRecordDecl(VD, RT);
5627}
5628
5629// Walks through the explicit resource binding attributes on the declaration,
5630// makes sure there is a resource that matches each binding, and updates
5631// DeclBindingInfoLists.
5632void SemaHLSL::processExplicitBindingsOnDecl(VarDecl *VD) {
5633 assert(VD->hasGlobalStorage() && "expected global variable");
5634
5635 bool HasBinding = false;
5636 for (Attr *A : VD->attrs()) {
5637 if (isa<HLSLVkBindingAttr>(A)) {
5638 HasBinding = true;
5639 if (auto PA = VD->getAttr<HLSLVkPushConstantAttr>())
5640 Diag(PA->getLoc(), diag::err_hlsl_attr_incompatible) << A << PA;
5641 }
5642
5643 HLSLResourceBindingAttr *RBA = dyn_cast<HLSLResourceBindingAttr>(A);
5644 if (!RBA || !RBA->hasRegisterSlot())
5645 continue;
5646 HasBinding = true;
5647
5648 RegisterType RT = RBA->getRegisterType();
5649 assert(RT != RegisterType::I && "invalid or obsolete register type should "
5650 "never have an attribute created");
5651
5652 if (RT == RegisterType::C) {
5653 if (Bindings.hasBindingInfoForDecl(VD))
5654 SemaRef.Diag(VD->getLocation(),
5655 diag::warn_hlsl_user_defined_type_missing_member)
5656 << static_cast<int>(RT);
5657 continue;
5658 }
5659
5660 // Find DeclBindingInfo for this binding and update it, or report error
5661 // if it does not exist (user type does to contain resources with the
5662 // expected resource class).
5664 if (DeclBindingInfo *BI = Bindings.getDeclBindingInfo(VD, RC)) {
5665 // update binding info
5666 BI->setBindingAttribute(RBA, BindingType::Explicit);
5667 } else {
5668 SemaRef.Diag(VD->getLocation(),
5669 diag::warn_hlsl_user_defined_type_missing_member)
5670 << static_cast<int>(RT);
5671 }
5672 }
5673
5674 if (!HasBinding && isResourceRecordTypeOrArrayOf(VD))
5675 SemaRef.Diag(VD->getLocation(), diag::warn_hlsl_implicit_binding);
5676}
5677namespace {
5678class InitListTransformer {
5679 Sema &S;
5680 ASTContext &Ctx;
5681 QualType InitTy;
5682 QualType *DstIt = nullptr;
5683 Expr **ArgIt = nullptr;
5684 // Is wrapping the destination type iterator required? This is only used for
5685 // incomplete array types where we loop over the destination type since we
5686 // don't know the full number of elements from the declaration.
5687 bool Wrap;
5688
5689 bool castInitializer(Expr *E) {
5690 assert(DstIt && "This should always be something!");
5691 if (DstIt == DestTypes.end()) {
5692 if (!Wrap) {
5693 ArgExprs.push_back(E);
5694 // This is odd, but it isn't technically a failure due to conversion, we
5695 // handle mismatched counts of arguments differently.
5696 return true;
5697 }
5698 DstIt = DestTypes.begin();
5699 }
5700 InitializedEntity Entity = InitializedEntity::InitializeParameter(
5701 Ctx, *DstIt, /* Consumed (ObjC) */ false);
5702 ExprResult Res = S.PerformCopyInitialization(Entity, E->getBeginLoc(), E);
5703 if (Res.isInvalid())
5704 return false;
5705 Expr *Init = Res.get();
5706 ArgExprs.push_back(Init);
5707 DstIt++;
5708 return true;
5709 }
5710
5711 bool buildInitializerListImpl(Expr *E) {
5712 // If this is an initialization list, traverse the sub initializers.
5713 if (auto *Init = dyn_cast<InitListExpr>(E)) {
5714 for (auto *SubInit : Init->inits())
5715 if (!buildInitializerListImpl(SubInit))
5716 return false;
5717 return true;
5718 }
5719
5720 // If this is a scalar type, just enqueue the expression.
5721 QualType Ty = E->getType().getDesugaredType(Ctx);
5722
5723 if (Ty->isScalarType() || (Ty->isRecordType() && !Ty->isAggregateType()) ||
5725 return castInitializer(E);
5726
5727 // If this is an aggregate type and a prvalue, create an xvalue temporary
5728 // so the member accesses will be xvalues. Wrap it in OpaqueExpr to make
5729 // sure codegen will not generate duplicate copies.
5730 if (E->isPRValue() && Ty->isAggregateType()) {
5732 if (TmpExpr.isInvalid())
5733 return false;
5734 E = TmpExpr.get();
5735 E = new (Ctx) OpaqueValueExpr(E->getBeginLoc(), E->getType(),
5736 E->getValueKind(), E->getObjectKind(), E);
5737 }
5738
5739 if (auto *VecTy = Ty->getAs<VectorType>()) {
5740 uint64_t Size = VecTy->getNumElements();
5741
5742 QualType SizeTy = Ctx.getSizeType();
5743 uint64_t SizeTySize = Ctx.getTypeSize(SizeTy);
5744 for (uint64_t I = 0; I < Size; ++I) {
5745 auto *Idx = IntegerLiteral::Create(Ctx, llvm::APInt(SizeTySize, I),
5746 SizeTy, SourceLocation());
5747
5749 E, E->getBeginLoc(), Idx, E->getEndLoc());
5750 if (ElExpr.isInvalid())
5751 return false;
5752 if (!castInitializer(ElExpr.get()))
5753 return false;
5754 }
5755 return true;
5756 }
5757 if (auto *MTy = Ty->getAs<ConstantMatrixType>()) {
5758 unsigned Rows = MTy->getNumRows();
5759 unsigned Cols = MTy->getNumColumns();
5760 QualType ElemTy = MTy->getElementType();
5761
5762 for (unsigned R = 0; R < Rows; ++R) {
5763 for (unsigned C = 0; C < Cols; ++C) {
5764 // row index literal
5765 Expr *RowIdx = IntegerLiteral::Create(
5766 Ctx, llvm::APInt(Ctx.getIntWidth(Ctx.IntTy), R), Ctx.IntTy,
5767 E->getBeginLoc());
5768 // column index literal
5769 Expr *ColIdx = IntegerLiteral::Create(
5770 Ctx, llvm::APInt(Ctx.getIntWidth(Ctx.IntTy), C), Ctx.IntTy,
5771 E->getBeginLoc());
5773 E, RowIdx, ColIdx, E->getEndLoc());
5774 if (ElExpr.isInvalid())
5775 return false;
5776 if (!castInitializer(ElExpr.get()))
5777 return false;
5778 ElExpr.get()->setType(ElemTy);
5779 }
5780 }
5781 return true;
5782 }
5783
5784 if (auto *ArrTy = dyn_cast<ConstantArrayType>(Ty.getTypePtr())) {
5785 uint64_t Size = ArrTy->getZExtSize();
5786 QualType SizeTy = Ctx.getSizeType();
5787 uint64_t SizeTySize = Ctx.getTypeSize(SizeTy);
5788 for (uint64_t I = 0; I < Size; ++I) {
5789 auto *Idx = IntegerLiteral::Create(Ctx, llvm::APInt(SizeTySize, I),
5790 SizeTy, SourceLocation());
5792 E, E->getBeginLoc(), Idx, E->getEndLoc());
5793 if (ElExpr.isInvalid())
5794 return false;
5795 if (!buildInitializerListImpl(ElExpr.get()))
5796 return false;
5797 }
5798 return true;
5799 }
5800
5801 if (auto *RD = Ty->getAsCXXRecordDecl()) {
5802 llvm::SmallVector<CXXRecordDecl *> RecordDecls;
5803 RecordDecls.push_back(RD);
5804 while (RecordDecls.back()->getNumBases()) {
5805 CXXRecordDecl *D = RecordDecls.back();
5806 assert(D->getNumBases() == 1 &&
5807 "HLSL doesn't support multiple inheritance");
5808 RecordDecls.push_back(
5810 }
5811 while (!RecordDecls.empty()) {
5812 CXXRecordDecl *RD = RecordDecls.pop_back_val();
5813 for (auto *FD : RD->fields()) {
5814 if (FD->isUnnamedBitField())
5815 continue;
5816 DeclAccessPair Found = DeclAccessPair::make(FD, FD->getAccess());
5817 DeclarationNameInfo NameInfo(FD->getDeclName(), E->getBeginLoc());
5819 E, false, E->getBeginLoc(), CXXScopeSpec(), FD, Found, NameInfo);
5820 if (Res.isInvalid())
5821 return false;
5822 if (!buildInitializerListImpl(Res.get()))
5823 return false;
5824 }
5825 }
5826 }
5827 return true;
5828 }
5829
5830 Expr *generateInitListsImpl(QualType Ty) {
5831 Ty = Ty.getDesugaredType(Ctx);
5832 assert(ArgIt != ArgExprs.end() && "Something is off in iteration!");
5833 if (Ty->isScalarType() || (Ty->isRecordType() && !Ty->isAggregateType()) ||
5835 return *(ArgIt++);
5836
5837 llvm::SmallVector<Expr *> Inits;
5838 if (Ty->isVectorType() || Ty->isConstantArrayType() ||
5839 Ty->isConstantMatrixType()) {
5840 QualType ElTy;
5841 uint64_t Size = 0;
5842 if (auto *ATy = Ty->getAs<VectorType>()) {
5843 ElTy = ATy->getElementType();
5844 Size = ATy->getNumElements();
5845 } else if (auto *CMTy = Ty->getAs<ConstantMatrixType>()) {
5846 ElTy = CMTy->getElementType();
5847 Size = CMTy->getNumElementsFlattened();
5848 } else {
5849 auto *VTy = cast<ConstantArrayType>(Ty.getTypePtr());
5850 ElTy = VTy->getElementType();
5851 Size = VTy->getZExtSize();
5852 }
5853 for (uint64_t I = 0; I < Size; ++I)
5854 Inits.push_back(generateInitListsImpl(ElTy));
5855 }
5856 if (auto *RD = Ty->getAsCXXRecordDecl()) {
5857 llvm::SmallVector<CXXRecordDecl *> RecordDecls;
5858 RecordDecls.push_back(RD);
5859 while (RecordDecls.back()->getNumBases()) {
5860 CXXRecordDecl *D = RecordDecls.back();
5861 assert(D->getNumBases() == 1 &&
5862 "HLSL doesn't support multiple inheritance");
5863 RecordDecls.push_back(
5865 }
5866 while (!RecordDecls.empty()) {
5867 CXXRecordDecl *RD = RecordDecls.pop_back_val();
5868 for (auto *FD : RD->fields())
5869 if (!FD->isUnnamedBitField())
5870 Inits.push_back(generateInitListsImpl(FD->getType()));
5871 }
5872 }
5873 auto *NewInit = new (Ctx) InitListExpr(Ctx, Inits.front()->getBeginLoc(),
5874 Inits, Inits.back()->getEndLoc());
5875 NewInit->setType(Ty);
5876 return NewInit;
5877 }
5878
5879public:
5880 llvm::SmallVector<QualType, 16> DestTypes;
5881 llvm::SmallVector<Expr *, 16> ArgExprs;
5882 InitListTransformer(Sema &SemaRef, const InitializedEntity &Entity)
5883 : S(SemaRef), Ctx(SemaRef.getASTContext()),
5884 Wrap(Entity.getType()->isIncompleteArrayType()) {
5885 InitTy = Entity.getType().getNonReferenceType();
5886 // When we're generating initializer lists for incomplete array types we
5887 // need to wrap around both when building the initializers and when
5888 // generating the final initializer lists.
5889 if (Wrap) {
5890 assert(InitTy->isIncompleteArrayType());
5891 const IncompleteArrayType *IAT = Ctx.getAsIncompleteArrayType(InitTy);
5892 InitTy = IAT->getElementType();
5893 }
5894 BuildFlattenedTypeList(InitTy, DestTypes);
5895 DstIt = DestTypes.begin();
5896 }
5897
5898 bool buildInitializerList(Expr *E) { return buildInitializerListImpl(E); }
5899
5900 Expr *generateInitLists() {
5901 assert(!ArgExprs.empty() &&
5902 "Call buildInitializerList to generate argument expressions.");
5903 ArgIt = ArgExprs.begin();
5904 if (!Wrap)
5905 return generateInitListsImpl(InitTy);
5906 llvm::SmallVector<Expr *> Inits;
5907 while (ArgIt != ArgExprs.end())
5908 Inits.push_back(generateInitListsImpl(InitTy));
5909
5910 auto *NewInit = new (Ctx) InitListExpr(Ctx, Inits.front()->getBeginLoc(),
5911 Inits, Inits.back()->getEndLoc());
5912 llvm::APInt ArySize(64, Inits.size());
5913 NewInit->setType(Ctx.getConstantArrayType(InitTy, ArySize, nullptr,
5914 ArraySizeModifier::Normal, 0));
5915 return NewInit;
5916 }
5917};
5918} // namespace
5919
5920// Recursively detect any incomplete array anywhere in the type graph,
5921// including arrays, struct fields, and base classes.
5923 Ty = Ty.getCanonicalType();
5924
5925 // Array types
5926 if (const ArrayType *AT = dyn_cast<ArrayType>(Ty)) {
5928 return true;
5930 }
5931
5932 // Record (struct/class) types
5933 if (const auto *RT = Ty->getAs<RecordType>()) {
5934 const RecordDecl *RD = RT->getDecl();
5935
5936 // Walk base classes (for C++ / HLSL structs with inheritance)
5937 if (const auto *CXXRD = dyn_cast<CXXRecordDecl>(RD)) {
5938 for (const CXXBaseSpecifier &Base : CXXRD->bases()) {
5939 if (containsIncompleteArrayType(Base.getType()))
5940 return true;
5941 }
5942 }
5943
5944 // Walk fields
5945 for (const FieldDecl *F : RD->fields()) {
5946 if (containsIncompleteArrayType(F->getType()))
5947 return true;
5948 }
5949 }
5950
5951 return false;
5952}
5953
5955 InitListExpr *Init) {
5956 // If the initializer is a scalar, just return it.
5957 if (Init->getType()->isScalarType())
5958 return true;
5959 ASTContext &Ctx = SemaRef.getASTContext();
5960 InitListTransformer ILT(SemaRef, Entity);
5961
5962 for (unsigned I = 0; I < Init->getNumInits(); ++I) {
5963 Expr *E = Init->getInit(I);
5964 if (E->HasSideEffects(Ctx)) {
5965 QualType Ty = E->getType();
5966 if (Ty->isRecordType())
5967 E = new (Ctx) MaterializeTemporaryExpr(Ty, E, E->isLValue());
5968 E = new (Ctx) OpaqueValueExpr(E->getBeginLoc(), Ty, E->getValueKind(),
5969 E->getObjectKind(), E);
5970 Init->setInit(I, E);
5971 }
5972 if (!ILT.buildInitializerList(E))
5973 return false;
5974 }
5975 size_t ExpectedSize = ILT.DestTypes.size();
5976 size_t ActualSize = ILT.ArgExprs.size();
5977 if (ExpectedSize == 0 && ActualSize == 0)
5978 return true;
5979
5980 // Reject empty initializer if *any* incomplete array exists structurally
5981 if (ActualSize == 0 && containsIncompleteArrayType(Entity.getType())) {
5982 QualType InitTy = Entity.getType().getNonReferenceType();
5983 if (InitTy.hasAddressSpace())
5984 InitTy = SemaRef.getASTContext().removeAddrSpaceQualType(InitTy);
5985
5986 SemaRef.Diag(Init->getBeginLoc(), diag::err_hlsl_incorrect_num_initializers)
5987 << /*TooManyOrFew=*/(int)(ExpectedSize < ActualSize) << InitTy
5988 << /*ExpectedSize=*/ExpectedSize << /*ActualSize=*/ActualSize;
5989 return false;
5990 }
5991
5992 // We infer size after validating legality.
5993 // For incomplete arrays it is completely arbitrary to choose whether we think
5994 // the user intended fewer or more elements. This implementation assumes that
5995 // the user intended more, and errors that there are too few initializers to
5996 // complete the final element.
5997 if (Entity.getType()->isIncompleteArrayType()) {
5998 assert(ExpectedSize > 0 &&
5999 "The expected size of an incomplete array type must be at least 1.");
6000 ExpectedSize =
6001 ((ActualSize + ExpectedSize - 1) / ExpectedSize) * ExpectedSize;
6002 }
6003
6004 // An initializer list might be attempting to initialize a reference or
6005 // rvalue-reference. When checking the initializer we should look through
6006 // the reference.
6007 QualType InitTy = Entity.getType().getNonReferenceType();
6008 if (InitTy.hasAddressSpace())
6009 InitTy = SemaRef.getASTContext().removeAddrSpaceQualType(InitTy);
6010 if (ExpectedSize != ActualSize) {
6011 int TooManyOrFew = ActualSize > ExpectedSize ? 1 : 0;
6012 SemaRef.Diag(Init->getBeginLoc(), diag::err_hlsl_incorrect_num_initializers)
6013 << TooManyOrFew << InitTy << ExpectedSize << ActualSize;
6014 return false;
6015 }
6016
6017 // generateInitListsImpl will always return an InitListExpr here, because the
6018 // scalar case is handled above.
6019 auto *NewInit = cast<InitListExpr>(ILT.generateInitLists());
6020 Init->resizeInits(Ctx, NewInit->getNumInits());
6021 for (unsigned I = 0; I < NewInit->getNumInits(); ++I)
6022 Init->updateInit(Ctx, I, NewInit->getInit(I));
6023 return true;
6024}
6025
6026static QualType ReportMatrixInvalidMember(Sema &S, StringRef Name,
6027 StringRef Expected,
6028 SourceLocation OpLoc,
6029 SourceLocation CompLoc) {
6030 S.Diag(OpLoc, diag::err_builtin_matrix_invalid_member)
6031 << Name << Expected << SourceRange(CompLoc);
6032 return QualType();
6033}
6034
6037 const IdentifierInfo *CompName,
6038 SourceLocation CompLoc) {
6039 const auto *MT = baseType->castAs<ConstantMatrixType>();
6040 StringRef AccessorName = CompName->getName();
6041 assert(!AccessorName.empty() && "Matrix Accessor must have a name");
6042
6043 unsigned Rows = MT->getNumRows();
6044 unsigned Cols = MT->getNumColumns();
6045 bool IsZeroBasedAccessor = false;
6046 unsigned ChunkLen = 0;
6047 if (AccessorName.size() < 2)
6048 return ReportMatrixInvalidMember(S, AccessorName,
6049 "length 4 for zero based: \'_mRC\' or "
6050 "length 3 for one-based: \'_RC\' accessor",
6051 OpLoc, CompLoc);
6052
6053 if (AccessorName[0] == '_') {
6054 if (AccessorName[1] == 'm') {
6055 IsZeroBasedAccessor = true;
6056 ChunkLen = 4; // zero-based: "_mRC"
6057 } else {
6058 ChunkLen = 3; // one-based: "_RC"
6059 }
6060 } else
6062 S, AccessorName, "zero based: \'_mRC\' or one-based: \'_RC\' accessor",
6063 OpLoc, CompLoc);
6064
6065 if (AccessorName.size() % ChunkLen != 0) {
6066 const llvm::StringRef Expected = IsZeroBasedAccessor
6067 ? "zero based: '_mRC' accessor"
6068 : "one-based: '_RC' accessor";
6069
6070 return ReportMatrixInvalidMember(S, AccessorName, Expected, OpLoc, CompLoc);
6071 }
6072
6073 auto isDigit = [](char c) { return c >= '0' && c <= '9'; };
6074 auto isZeroBasedIndex = [](unsigned i) { return i <= 3; };
6075 auto isOneBasedIndex = [](unsigned i) { return i >= 1 && i <= 4; };
6076
6077 bool HasRepeated = false;
6078 SmallVector<bool, 16> Seen(Rows * Cols, false);
6079 unsigned NumComponents = 0;
6080 const char *Begin = AccessorName.data();
6081
6082 for (unsigned I = 0, E = AccessorName.size(); I < E; I += ChunkLen) {
6083 const char *Chunk = Begin + I;
6084 char RowChar = 0, ColChar = 0;
6085 if (IsZeroBasedAccessor) {
6086 // Zero-based: "_mRC"
6087 if (Chunk[0] != '_' || Chunk[1] != 'm') {
6088 char Bad = (Chunk[0] != '_') ? Chunk[0] : Chunk[1];
6090 S, StringRef(&Bad, 1), "\'_m\' prefix",
6091 OpLoc.getLocWithOffset(I + (Bad == Chunk[0] ? 1 : 2)), CompLoc);
6092 }
6093 RowChar = Chunk[2];
6094 ColChar = Chunk[3];
6095 } else {
6096 // One-based: "_RC"
6097 if (Chunk[0] != '_')
6099 S, StringRef(&Chunk[0], 1), "\'_\' prefix",
6100 OpLoc.getLocWithOffset(I + 1), CompLoc);
6101 RowChar = Chunk[1];
6102 ColChar = Chunk[2];
6103 }
6104
6105 // Must be digits.
6106 bool IsDigitsError = false;
6107 if (!isDigit(RowChar)) {
6108 unsigned BadPos = IsZeroBasedAccessor ? 2 : 1;
6109 ReportMatrixInvalidMember(S, StringRef(&RowChar, 1), "row as integer",
6110 OpLoc.getLocWithOffset(I + BadPos + 1),
6111 CompLoc);
6112 IsDigitsError = true;
6113 }
6114
6115 if (!isDigit(ColChar)) {
6116 unsigned BadPos = IsZeroBasedAccessor ? 3 : 2;
6117 ReportMatrixInvalidMember(S, StringRef(&ColChar, 1), "column as integer",
6118 OpLoc.getLocWithOffset(I + BadPos + 1),
6119 CompLoc);
6120 IsDigitsError = true;
6121 }
6122 if (IsDigitsError)
6123 return QualType();
6124
6125 unsigned Row = RowChar - '0';
6126 unsigned Col = ColChar - '0';
6127
6128 bool HasIndexingError = false;
6129 if (IsZeroBasedAccessor) {
6130 // 0-based [0..3]
6131 if (!isZeroBasedIndex(Row)) {
6132 S.Diag(OpLoc, diag::err_hlsl_matrix_element_not_in_bounds)
6133 << /*row*/ 0 << /*zero-based*/ 0 << SourceRange(CompLoc);
6134 HasIndexingError = true;
6135 }
6136 if (!isZeroBasedIndex(Col)) {
6137 S.Diag(OpLoc, diag::err_hlsl_matrix_element_not_in_bounds)
6138 << /*col*/ 1 << /*zero-based*/ 0 << SourceRange(CompLoc);
6139 HasIndexingError = true;
6140 }
6141 } else {
6142 // 1-based [1..4]
6143 if (!isOneBasedIndex(Row)) {
6144 S.Diag(OpLoc, diag::err_hlsl_matrix_element_not_in_bounds)
6145 << /*row*/ 0 << /*one-based*/ 1 << SourceRange(CompLoc);
6146 HasIndexingError = true;
6147 }
6148 if (!isOneBasedIndex(Col)) {
6149 S.Diag(OpLoc, diag::err_hlsl_matrix_element_not_in_bounds)
6150 << /*col*/ 1 << /*one-based*/ 1 << SourceRange(CompLoc);
6151 HasIndexingError = true;
6152 }
6153 // Convert to 0-based after range checking.
6154 --Row;
6155 --Col;
6156 }
6157
6158 if (HasIndexingError)
6159 return QualType();
6160
6161 // Note: matrix swizzle index is hard coded. That means Row and Col can
6162 // potentially be larger than Rows and Cols if matrix size is less than
6163 // the max index size.
6164 bool HasBoundsError = false;
6165 if (Row >= Rows) {
6166 Diag(OpLoc, diag::err_hlsl_matrix_index_out_of_bounds)
6167 << /*Row*/ 0 << Row << Rows << SourceRange(CompLoc);
6168 HasBoundsError = true;
6169 }
6170 if (Col >= Cols) {
6171 Diag(OpLoc, diag::err_hlsl_matrix_index_out_of_bounds)
6172 << /*Col*/ 1 << Col << Cols << SourceRange(CompLoc);
6173 HasBoundsError = true;
6174 }
6175 if (HasBoundsError)
6176 return QualType();
6177
6178 unsigned FlatIndex = Row * Cols + Col;
6179 if (Seen[FlatIndex])
6180 HasRepeated = true;
6181 Seen[FlatIndex] = true;
6182 ++NumComponents;
6183 }
6184 if (NumComponents == 0 || NumComponents > 4) {
6185 S.Diag(OpLoc, diag::err_hlsl_matrix_swizzle_invalid_length)
6186 << NumComponents << SourceRange(CompLoc);
6187 return QualType();
6188 }
6189
6190 QualType ElemTy = MT->getElementType();
6191 if (NumComponents == 1)
6192 return ElemTy;
6193 QualType VT = S.Context.getExtVectorType(ElemTy, NumComponents);
6194 if (HasRepeated)
6195 VK = VK_PRValue;
6196
6197 for (Sema::ExtVectorDeclsType::iterator
6199 E = S.ExtVectorDecls.end();
6200 I != E; ++I) {
6201 if ((*I)->getUnderlyingType() == VT)
6203 /*Qualifier=*/std::nullopt, *I);
6204 }
6205
6206 return VT;
6207}
6208
6210 // If initializing a local resource, track the resource binding it is using
6211 if (VDecl->getType()->isHLSLResourceRecord() && !VDecl->hasGlobalStorage())
6212 trackLocalResource(VDecl, Init);
6213
6214 const HLSLVkConstantIdAttr *ConstIdAttr =
6215 VDecl->getAttr<HLSLVkConstantIdAttr>();
6216 if (!ConstIdAttr)
6217 return true;
6218
6219 ASTContext &Context = SemaRef.getASTContext();
6220
6221 APValue InitValue;
6222 if (!Init->isCXX11ConstantExpr(Context, &InitValue)) {
6223 Diag(VDecl->getLocation(), diag::err_specialization_const);
6224 VDecl->setInvalidDecl();
6225 return false;
6226 }
6227
6228 Builtin::ID BID =
6230
6231 // Argument 1: The ID from the attribute
6232 int ConstantID = ConstIdAttr->getId();
6233 llvm::APInt IDVal(Context.getIntWidth(Context.IntTy), ConstantID);
6234 Expr *IdExpr = IntegerLiteral::Create(Context, IDVal, Context.IntTy,
6235 ConstIdAttr->getLocation());
6236
6237 SmallVector<Expr *, 2> Args = {IdExpr, Init};
6238 Expr *C = SemaRef.BuildBuiltinCallExpr(Init->getExprLoc(), BID, Args);
6239 if (C->getType()->getCanonicalTypeUnqualified() !=
6241 C = SemaRef
6242 .BuildCStyleCastExpr(SourceLocation(),
6243 Context.getTrivialTypeSourceInfo(
6244 Init->getType(), Init->getExprLoc()),
6245 SourceLocation(), C)
6246 .get();
6247 }
6248 Init = C;
6249 return true;
6250}
6251
6253 SourceLocation NameLoc) {
6254 if (!Template)
6255 return QualType();
6256
6257 DeclContext *DC = Template->getDeclContext();
6258 if (!DC->isNamespace() || !cast<NamespaceDecl>(DC)->getIdentifier() ||
6259 cast<NamespaceDecl>(DC)->getName() != "hlsl")
6260 return QualType();
6261
6262 TemplateParameterList *Params = Template->getTemplateParameters();
6263 if (!Params || Params->size() != 1)
6264 return QualType();
6265
6266 if (!Template->isImplicit())
6267 return QualType();
6268
6269 // We manually extract default arguments here instead of letting
6270 // CheckTemplateIdType handle it. This ensures that for resource types that
6271 // lack a default argument (like Buffer), we return a null QualType, which
6272 // triggers the "requires template arguments" error rather than a less
6273 // descriptive "too few template arguments" error.
6274 TemplateArgumentListInfo TemplateArgs(NameLoc, NameLoc);
6275 for (NamedDecl *P : *Params) {
6276 if (auto *TTP = dyn_cast<TemplateTypeParmDecl>(P)) {
6277 if (TTP->hasDefaultArgument()) {
6278 TemplateArgs.addArgument(TTP->getDefaultArgument());
6279 continue;
6280 }
6281 } else if (auto *NTTP = dyn_cast<NonTypeTemplateParmDecl>(P)) {
6282 if (NTTP->hasDefaultArgument()) {
6283 TemplateArgs.addArgument(NTTP->getDefaultArgument());
6284 continue;
6285 }
6286 } else if (auto *TTPD = dyn_cast<TemplateTemplateParmDecl>(P)) {
6287 if (TTPD->hasDefaultArgument()) {
6288 TemplateArgs.addArgument(TTPD->getDefaultArgument());
6289 continue;
6290 }
6291 }
6292 return QualType();
6293 }
6294
6295 return SemaRef.CheckTemplateIdType(
6297 TemplateArgs, nullptr, /*ForNestedNameSpecifier=*/false);
6298}
Defines the clang::ASTContext interface.
Defines enum values for all the target-independent builtin functions.
llvm::dxil::ResourceClass ResourceClass
Defines the C++ Decl subclasses, other than those for templates (found in DeclTemplate....
TokenType getType() const
Returns the token's type, e.g.
FormatToken * Previous
The previous token in the unwrapped line.
Defines the clang::IdentifierInfo, clang::IdentifierTable, and clang::Selector interfaces.
#define X(type, name)
Definition Value.h:97
Forward-declares and imports various common LLVM datatypes that clang wants to use unqualified.
llvm::SmallVector< std::pair< const MemRegion *, SVal >, 4 > Bindings
static bool CheckArgTypeMatches(Sema *S, Expr *Arg, QualType ExpectedType)
static void BuildFlattenedTypeList(QualType BaseTy, llvm::SmallVectorImpl< QualType > &List)
static bool CheckUnsignedIntRepresentation(Sema *S, SourceLocation Loc, int ArgOrdinal, clang::QualType PassedType)
static bool containsIncompleteArrayType(QualType Ty)
static QualType handleIntegerVectorBinOpConversion(Sema &SemaRef, ExprResult &LHS, ExprResult &RHS, QualType LHSType, QualType RHSType, QualType LElTy, QualType RElTy, bool IsCompAssign)
static bool convertToRegisterType(StringRef Slot, RegisterType *RT)
Definition SemaHLSL.cpp:82
static StringRef createRegisterString(ASTContext &AST, RegisterType RegType, unsigned N)
Definition SemaHLSL.cpp:184
static bool CheckWaveActive(Sema *S, CallExpr *TheCall)
static void createHostLayoutStructForBuffer(Sema &S, HLSLBufferDecl *BufDecl)
Definition SemaHLSL.cpp:609
static void castVector(Sema &S, ExprResult &E, QualType &Ty, unsigned Sz)
static QualType ReportMatrixInvalidMember(Sema &S, StringRef Name, StringRef Expected, SourceLocation OpLoc, SourceLocation CompLoc)
static bool CheckBoolSelect(Sema *S, CallExpr *TheCall)
static unsigned calculateLegacyCbufferFieldAlign(const ASTContext &Context, QualType T)
Definition SemaHLSL.cpp:246
static bool isZeroSizedArray(const ConstantArrayType *CAT)
Definition SemaHLSL.cpp:365
static bool DiagnoseHLSLRegisterAttribute(Sema &S, SourceLocation &ArgLoc, Decl *D, RegisterType RegType, bool SpecifiedSpace)
static bool hasConstantBufferLayout(QualType QT)
static FieldDecl * createFieldForHostLayoutStruct(Sema &S, const Type *Ty, IdentifierInfo *II, CXXRecordDecl *LayoutStruct)
Definition SemaHLSL.cpp:517
static bool CheckUnsignedIntVecRepresentation(Sema *S, SourceLocation Loc, int ArgOrdinal, clang::QualType PassedType)
SampleKind
static bool isInvalidConstantBufferLeafElementType(const Type *Ty)
Definition SemaHLSL.cpp:399
static bool CheckCalculateLodBuiltin(Sema &S, CallExpr *TheCall)
static Builtin::ID getSpecConstBuiltinId(const Type *Type)
Definition SemaHLSL.cpp:150
static bool CheckFloatingOrIntRepresentation(Sema *S, SourceLocation Loc, int ArgOrdinal, clang::QualType PassedType)
static const Type * createHostLayoutType(Sema &S, const Type *Ty)
Definition SemaHLSL.cpp:490
static bool CheckAnyScalarOrVector(Sema *S, CallExpr *TheCall, unsigned ArgIndex)
static const HLSLAttributedResourceType * getResourceArrayHandleType(QualType QT)
Definition SemaHLSL.cpp:381
static IdentifierInfo * getHostLayoutStructName(Sema &S, NamedDecl *BaseDecl, bool MustBeUnique)
Definition SemaHLSL.cpp:455
static void addImplicitBindingAttrToDecl(Sema &S, Decl *D, RegisterType RT, uint32_t ImplicitBindingOrderID)
Definition SemaHLSL.cpp:653
static void SetElementTypeAsReturnType(Sema *S, CallExpr *TheCall, QualType ReturnType)
static unsigned calculateLegacyCbufferSize(const ASTContext &Context, QualType T)
Definition SemaHLSL.cpp:265
static bool CheckLoadLevelBuiltin(Sema &S, CallExpr *TheCall)
static RegisterType getRegisterType(ResourceClass RC)
Definition SemaHLSL.cpp:62
static bool ValidateRegisterNumber(uint64_t SlotNum, Decl *TheDecl, ASTContext &Ctx, RegisterType RegTy)
static bool isVkPipelineBuiltin(const ASTContext &AstContext, FunctionDecl *FD, HLSLAppliedSemanticAttr *Semantic, bool IsInput)
Definition SemaHLSL.cpp:841
static bool CheckVectorElementCount(Sema *S, QualType PassedType, QualType BaseType, unsigned ExpectedCount, SourceLocation Loc)
static bool CheckModifiableLValue(Sema *S, CallExpr *TheCall, unsigned ArgIndex)
static QualType castElement(Sema &S, ExprResult &E, QualType Ty)
static char getRegisterTypeChar(RegisterType RT)
Definition SemaHLSL.cpp:114
static bool CheckNotBoolScalarOrVector(Sema *S, CallExpr *TheCall, unsigned ArgIndex)
static CXXRecordDecl * findRecordDeclInContext(IdentifierInfo *II, DeclContext *DC)
Definition SemaHLSL.cpp:438
static bool CheckWavePrefix(Sema *S, CallExpr *TheCall)
static bool CheckExpectedBitWidth(Sema *S, CallExpr *TheCall, unsigned ArgOrdinal, unsigned Width)
static bool CheckVectorSelect(Sema *S, CallExpr *TheCall)
static QualType handleFloatVectorBinOpConversion(Sema &SemaRef, ExprResult &LHS, ExprResult &RHS, QualType LHSType, QualType RHSType, QualType LElTy, QualType RElTy, bool IsCompAssign)
static ResourceClass getResourceClass(RegisterType RT)
Definition SemaHLSL.cpp:132
static CXXRecordDecl * createHostLayoutStruct(Sema &S, CXXRecordDecl *StructDecl)
Definition SemaHLSL.cpp:544
static bool CheckScalarOrVector(Sema *S, CallExpr *TheCall, QualType Scalar, unsigned ArgIndex)
static bool CheckSamplingBuiltin(Sema &S, CallExpr *TheCall, SampleKind Kind)
static bool CheckScalarOrVectorOrMatrix(Sema *S, CallExpr *TheCall, QualType Scalar, unsigned ArgIndex)
static bool CheckFloatRepresentation(Sema *S, SourceLocation Loc, int ArgOrdinal, clang::QualType PassedType)
static bool CheckAnyDoubleRepresentation(Sema *S, SourceLocation Loc, int ArgOrdinal, clang::QualType PassedType)
static bool requiresImplicitBufferLayoutStructure(const CXXRecordDecl *RD)
Definition SemaHLSL.cpp:418
static bool CheckResourceHandle(Sema *S, CallExpr *TheCall, unsigned ArgIndex, llvm::function_ref< bool(const HLSLAttributedResourceType *ResType)> Check=nullptr)
static void validatePackoffset(Sema &S, HLSLBufferDecl *BufDecl)
Definition SemaHLSL.cpp:312
static bool IsDefaultBufferConstantDecl(const ASTContext &Ctx, VarDecl *VD)
HLSLResourceBindingAttr::RegisterType RegisterType
Definition SemaHLSL.cpp:57
static CastKind getScalarCastKind(ASTContext &Ctx, QualType DestTy, QualType SrcTy)
static bool CheckGatherBuiltin(Sema &S, CallExpr *TheCall, bool IsCmp)
static bool isValidWaveSizeValue(unsigned Value)
static bool isResourceRecordTypeOrArrayOf(QualType Ty)
Definition SemaHLSL.cpp:372
static bool AccumulateHLSLResourceSlots(QualType Ty, uint64_t &StartSlot, const uint64_t &Limit, const ResourceClass ResClass, ASTContext &Ctx, uint64_t ArrayCount=1)
static bool CheckNoDoubleVectors(Sema *S, SourceLocation Loc, int ArgOrdinal, clang::QualType PassedType)
static bool ValidateMultipleRegisterAnnotations(Sema &S, Decl *TheDecl, RegisterType regType)
static bool CheckTextureSamplerAndLocation(Sema &S, CallExpr *TheCall)
static bool DiagnoseLocalRegisterBinding(Sema &S, SourceLocation &ArgLoc, Decl *D, RegisterType RegType, bool SpecifiedSpace)
static bool CheckIndexType(Sema *S, CallExpr *TheCall, unsigned IndexArgIndex)
This file declares semantic analysis for HLSL constructs.
Defines the clang::SourceLocation class and associated facilities.
Defines various enumerations that describe declaration and type specifiers.
C Language Family Type Representation.
Defines the clang::TypeLoc interface and its subclasses.
C Language Family Type Representation.
static const TypeInfo & getInfo(unsigned id)
Definition Types.cpp:44
__device__ __2f16 float c
return(__x > > __y)|(__x<<(32 - __y))
APValue - This class implements a discriminated union of [uninitialized] [APSInt] [APFloat],...
Definition APValue.h:122
virtual bool HandleTopLevelDecl(DeclGroupRef D)
HandleTopLevelDecl - Handle the specified top-level declaration.
Holds long-lived AST nodes (such as types and decls) that can be referred to throughout the semantic ...
Definition ASTContext.h:227
unsigned getIntWidth(QualType T) const
int getIntegerTypeOrder(QualType LHS, QualType RHS) const
Return the highest ranked integer type, see C99 6.3.1.8p1.
CanQualType FloatTy
QualType getPointerType(QualType T) const
Return the uniqued reference to the type for a pointer to the specified type.
const IncompleteArrayType * getAsIncompleteArrayType(QualType T) const
IdentifierTable & Idents
Definition ASTContext.h:805
QualType getConstantArrayType(QualType EltTy, const llvm::APInt &ArySize, const Expr *SizeExpr, ArraySizeModifier ASM, unsigned IndexTypeQuals) const
Return the unique reference to the type for a constant array of the specified element type.
QualType getBaseElementType(const ArrayType *VAT) const
Return the innermost element type of an array type.
int getFloatingTypeOrder(QualType LHS, QualType RHS) const
Compare the rank of the two specified floating point types, ignoring the domain of the type (i....
CanQualType BoolTy
TypeSourceInfo * getTrivialTypeSourceInfo(QualType T, SourceLocation Loc=SourceLocation()) const
Allocate a TypeSourceInfo where all locations have been initialized to a given location,...
QualType getStringLiteralArrayType(QualType EltTy, unsigned Length) const
Return a type for a constant array for a string literal of the specified element type and length.
CanQualType CharTy
CanQualType IntTy
uint64_t getTypeSize(QualType T) const
Return the size of the specified (complete) type T, in bits.
CharUnits getTypeSizeInChars(QualType T) const
Return the size of the specified (complete) type T, in characters.
CanQualType UnsignedIntTy
QualType getTypedefType(ElaboratedTypeKeyword Keyword, NestedNameSpecifier Qualifier, const TypedefNameDecl *Decl, QualType UnderlyingType=QualType(), std::optional< bool > TypeMatchesDeclOrNone=std::nullopt) const
Return the unique reference to the type for the specified typedef-name decl.
llvm::StringRef backupStr(llvm::StringRef S) const
Definition ASTContext.h:887
QualType getSizeType() const
Return the unique type for "size_t" (C99 7.17), defined in <stddef.h>.
QualType getExtVectorType(QualType VectorType, unsigned NumElts) const
Return the unique reference to an extended vector type of the specified element type and size.
const TargetInfo & getTargetInfo() const
Definition ASTContext.h:924
QualType getHLSLAttributedResourceType(QualType Wrapped, QualType Contained, const HLSLAttributedResourceType::Attributes &Attrs)
QualType getAddrSpaceQualType(QualType T, LangAS AddressSpace) const
Return the uniqued reference to the type for an address space qualified type with the specified type ...
CanQualType getCanonicalTagType(const TagDecl *TD) const
static bool hasSameUnqualifiedType(QualType T1, QualType T2)
Determine whether the given types are equivalent after cvr-qualifiers have been removed.
QualType getConstantMatrixType(QualType ElementType, unsigned NumRows, unsigned NumColumns) const
Return the unique reference to the matrix type of the specified element type and size.
unsigned getTypeAlign(QualType T) const
Return the ABI-specified alignment of a (complete) type T, in bits.
PtrTy get() const
Definition Ownership.h:171
bool isInvalid() const
Definition Ownership.h:167
Represents an array type, per C99 6.7.5.2 - Array Declarators.
Definition TypeBase.h:3777
QualType getElementType() const
Definition TypeBase.h:3789
Attr - This represents one attribute.
Definition Attr.h:46
attr::Kind getKind() const
Definition Attr.h:92
SourceLocation getLocation() const
Definition Attr.h:99
SourceLocation getScopeLoc() const
const IdentifierInfo * getScopeName() const
SourceLocation getLoc() const
const IdentifierInfo * getAttrName() const
Represents a base class of a C++ class.
Definition DeclCXX.h:146
QualType getType() const
Retrieves the type of the base class.
Definition DeclCXX.h:249
Represents a static or instance method of a struct/union/class.
Definition DeclCXX.h:2136
Represents a C++ struct/union/class.
Definition DeclCXX.h:258
bool isHLSLIntangible() const
Returns true if the class contains HLSL intangible type, either as a field or in base class.
Definition DeclCXX.h:1556
static CXXRecordDecl * Create(const ASTContext &C, TagKind TK, DeclContext *DC, SourceLocation StartLoc, SourceLocation IdLoc, IdentifierInfo *Id, CXXRecordDecl *PrevDecl=nullptr)
Definition DeclCXX.cpp:132
void setBases(CXXBaseSpecifier const *const *Bases, unsigned NumBases)
Sets the base classes of this struct or class.
Definition DeclCXX.cpp:184
base_class_iterator bases_end()
Definition DeclCXX.h:617
void completeDefinition() override
Indicates that the definition of this class is now complete.
Definition DeclCXX.cpp:2249
base_class_range bases()
Definition DeclCXX.h:608
unsigned getNumBases() const
Retrieves the number of base classes of this class.
Definition DeclCXX.h:602
base_class_iterator bases_begin()
Definition DeclCXX.h:615
bool isEmpty() const
Determine whether this is an empty class in the sense of (C++11 [meta.unary.prop]).
Definition DeclCXX.h:1186
CallExpr - Represents a function call (C99 6.5.2.2, C++ [expr.call]).
Definition Expr.h:2946
Expr * getArg(unsigned Arg)
getArg - Return the specified argument.
Definition Expr.h:3150
SourceLocation getBeginLoc() const
Definition Expr.h:3280
static CallExpr * Create(const ASTContext &Ctx, Expr *Fn, ArrayRef< Expr * > Args, QualType Ty, ExprValueKind VK, SourceLocation RParenLoc, FPOptionsOverride FPFeatures, unsigned MinNumArgs=0, ADLCallKind UsesADL=NotADL)
Create a call expression.
Definition Expr.cpp:1517
FunctionDecl * getDirectCallee()
If the callee is a FunctionDecl, return it. Otherwise return null.
Definition Expr.h:3129
Expr * getCallee()
Definition Expr.h:3093
unsigned getNumArgs() const
getNumArgs - Return the number of actual arguments to this call.
Definition Expr.h:3137
Decl * getCalleeDecl()
Definition Expr.h:3123
QualType withConst() const
Retrieves a version of this type with const applied.
const T * getTypePtr() const
Retrieve the underlying type pointer, which refers to a canonical type.
QuantityType getQuantity() const
getQuantity - Get the raw integer representation of this quantity.
Definition CharUnits.h:185
Represents the canonical version of C arrays with a specified constant size.
Definition TypeBase.h:3815
bool isZeroSize() const
Return true if the size is zero.
Definition TypeBase.h:3885
llvm::APInt getSize() const
Return the constant array size as an APInt.
Definition TypeBase.h:3871
uint64_t getZExtSize() const
Return the size zero-extended as a uint64_t.
Definition TypeBase.h:3891
Represents a concrete matrix type with constant number of rows and columns.
Definition TypeBase.h:4442
unsigned getNumColumns() const
Returns the number of columns in the matrix.
Definition TypeBase.h:4461
static DeclAccessPair make(NamedDecl *D, AccessSpecifier AS)
DeclContext - This is used only as base class of specific decl types that can act as declaration cont...
Definition DeclBase.h:1462
bool isNamespace() const
Definition DeclBase.h:2211
lookup_result lookup(DeclarationName Name) const
lookup - Find the declarations (if any) with the given Name in this context.
bool isTranslationUnit() const
Definition DeclBase.h:2198
void addDecl(Decl *D)
Add the declaration D into this context.
decl_range decls() const
decls_begin/decls_end - Iterate over the declarations stored in this context.
Definition DeclBase.h:2386
DeclContext * getNonTransparentContext()
A reference to a declared variable, function, enum, etc.
Definition Expr.h:1273
static DeclRefExpr * Create(const ASTContext &Context, NestedNameSpecifierLoc QualifierLoc, SourceLocation TemplateKWLoc, ValueDecl *D, bool RefersToEnclosingVariableOrCapture, SourceLocation NameLoc, QualType T, ExprValueKind VK, NamedDecl *FoundD=nullptr, const TemplateArgumentListInfo *TemplateArgs=nullptr, NonOdrUseReason NOUR=NOUR_None)
Definition Expr.cpp:488
ValueDecl * getDecl()
Definition Expr.h:1341
Decl - This represents one declaration (or definition), e.g.
Definition DeclBase.h:86
T * getAttr() const
Definition DeclBase.h:581
ASTContext & getASTContext() const LLVM_READONLY
Definition DeclBase.cpp:547
void addAttr(Attr *A)
attr_iterator attr_end() const
Definition DeclBase.h:550
bool isImplicit() const
isImplicit - Indicates whether the declaration was implicitly generated by the implementation.
Definition DeclBase.h:601
void setInvalidDecl(bool Invalid=true)
setInvalidDecl - Indicates the Decl had a semantic error.
Definition DeclBase.cpp:178
bool isInExportDeclContext() const
Whether this declaration was exported in a lexical context.
attr_iterator attr_begin() const
Definition DeclBase.h:547
DeclContext * getNonTransparentDeclContext()
Return the non transparent context.
SourceLocation getLocation() const
Definition DeclBase.h:447
void setImplicit(bool I=true)
Definition DeclBase.h:602
DeclContext * getDeclContext()
Definition DeclBase.h:456
attr_range attrs() const
Definition DeclBase.h:543
AccessSpecifier getAccess() const
Definition DeclBase.h:515
SourceLocation getBeginLoc() const LLVM_READONLY
Definition DeclBase.h:439
void dropAttr()
Definition DeclBase.h:564
bool hasAttr() const
Definition DeclBase.h:585
The name of a declaration.
Represents a ValueDecl that came out of a declarator.
Definition Decl.h:780
SourceLocation getBeginLoc() const LLVM_READONLY
Definition Decl.h:831
This represents one expression.
Definition Expr.h:112
bool isIntegerConstantExpr(const ASTContext &Ctx) const
void setType(QualType t)
Definition Expr.h:145
ExprValueKind getValueKind() const
getValueKind - The value kind that this expression produces.
Definition Expr.h:447
Expr * IgnoreParenImpCasts() LLVM_READONLY
Skip past any parentheses and implicit casts which might surround this expression until reaching a fi...
Definition Expr.cpp:3090
Expr * IgnoreParens() LLVM_READONLY
Skip past any parentheses which might surround this expression until reaching a fixed point.
Definition Expr.cpp:3086
std::optional< llvm::APSInt > getIntegerConstantExpr(const ASTContext &Ctx) const
isIntegerConstantExpr - Return the value if this expression is a valid integer constant expression.
bool isPRValue() const
Definition Expr.h:285
bool isLValue() const
isLValue - True if this expression is an "l-value" according to the rules of the current language.
Definition Expr.h:284
ExprObjectKind getObjectKind() const
getObjectKind - The object kind that this expression produces.
Definition Expr.h:454
bool HasSideEffects(const ASTContext &Ctx, bool IncludePossibleEffects=true) const
HasSideEffects - This routine returns true for all those expressions which have any effect other than...
Definition Expr.cpp:3688
void setValueKind(ExprValueKind Cat)
setValueKind - Set the value kind produced by this expression.
Definition Expr.h:464
SourceLocation getExprLoc() const LLVM_READONLY
getExprLoc - Return the preferred location for the arrow when diagnosing a problem with a generic exp...
Definition Expr.cpp:277
@ MLV_Valid
Definition Expr.h:306
QualType getType() const
Definition Expr.h:144
ExtVectorType - Extended vector type.
Definition TypeBase.h:4322
Represents difference between two FPOptions values.
Represents a member of a struct/union/class.
Definition Decl.h:3175
static FieldDecl * Create(const ASTContext &C, DeclContext *DC, SourceLocation StartLoc, SourceLocation IdLoc, const IdentifierInfo *Id, QualType T, TypeSourceInfo *TInfo, Expr *BW, bool Mutable, InClassInitStyle InitStyle)
Definition Decl.cpp:4666
static FixItHint CreateReplacement(CharSourceRange RemoveRange, StringRef Code)
Create a code modification hint that replaces the given source range with the given code string.
Definition Diagnostic.h:141
Represents a function declaration or definition.
Definition Decl.h:2015
const ParmVarDecl * getParamDecl(unsigned i) const
Definition Decl.h:2812
Stmt * getBody(const FunctionDecl *&Definition) const
Retrieve the body (definition) of the function.
Definition Decl.cpp:3245
bool isThisDeclarationADefinition() const
Returns whether this specific declaration of the function is also a definition that does not contain ...
Definition Decl.h:2329
QualType getReturnType() const
Definition Decl.h:2860
ArrayRef< ParmVarDecl * > parameters() const
Definition Decl.h:2789
bool isTemplateInstantiation() const
Determines if the given function was instantiated from a function template.
Definition Decl.cpp:4223
redecl_range redecls() const
Returns an iterator range for all the redeclarations of the same decl.
unsigned getNumParams() const
Return the number of parameters this function must have based on its FunctionType.
Definition Decl.cpp:3792
DeclarationNameInfo getNameInfo() const
Definition Decl.h:2226
bool hasBody(const FunctionDecl *&Definition) const
Returns true if the function has a body.
Definition Decl.cpp:3165
bool isDefined(const FunctionDecl *&Definition, bool CheckForPendingFriendDefinition=false) const
Returns true if the function has a definition that does not need to be instantiated.
Definition Decl.cpp:3212
HLSLBufferDecl - Represent a cbuffer or tbuffer declaration.
Definition Decl.h:5211
static HLSLBufferDecl * Create(ASTContext &C, DeclContext *LexicalParent, bool CBuffer, SourceLocation KwLoc, IdentifierInfo *ID, SourceLocation IDLoc, SourceLocation LBrace)
Definition Decl.cpp:5871
void addLayoutStruct(CXXRecordDecl *LS)
Definition Decl.cpp:5911
void setHasValidPackoffset(bool PO)
Definition Decl.h:5256
static HLSLBufferDecl * CreateDefaultCBuffer(ASTContext &C, DeclContext *LexicalParent, ArrayRef< Decl * > DefaultCBufferDecls)
Definition Decl.cpp:5894
buffer_decl_range buffer_decls() const
Definition Decl.h:5286
static HLSLOutArgExpr * Create(const ASTContext &C, QualType Ty, OpaqueValueExpr *Base, OpaqueValueExpr *OpV, Expr *WB, bool IsInOut)
Definition Expr.cpp:5645
static HLSLRootSignatureDecl * Create(ASTContext &C, DeclContext *DC, SourceLocation Loc, IdentifierInfo *ID, llvm::dxbc::RootSignatureVersion Version, ArrayRef< llvm::hlsl::rootsig::RootElement > RootElements)
Definition Decl.cpp:5957
One of these records is kept for each identifier that is lexed.
StringRef getName() const
Return the actual identifier string.
A simple pair of identifier info and location.
SourceLocation getLoc() const
IdentifierInfo * getIdentifierInfo() const
IdentifierInfo & get(StringRef Name)
Return the identifier token info for the specified named identifier.
ImplicitCastExpr - Allows us to explicitly represent implicit type conversions, which have no direct ...
Definition Expr.h:3856
static ImplicitCastExpr * Create(const ASTContext &Context, QualType T, CastKind Kind, Expr *Operand, const CXXCastPath *BasePath, ExprValueKind Cat, FPOptionsOverride FPO)
Definition Expr.cpp:2073
Describes an C or C++ initializer list.
Definition Expr.h:5302
Describes an entity that is being initialized.
QualType getType() const
Retrieve type being initialized.
static InitializedEntity InitializeParameter(ASTContext &Context, ParmVarDecl *Parm)
Create the initialization entity for a parameter.
static IntegerLiteral * Create(const ASTContext &C, const llvm::APInt &V, QualType type, SourceLocation l)
Returns a new integer literal with value 'V' and type 'type'.
Definition Expr.cpp:975
iterator begin(Source *source, bool LocalOnly=false)
Represents the results of name lookup.
Definition Lookup.h:147
Represents a prvalue temporary that is written into memory so that a reference can bind to it.
Definition ExprCXX.h:4920
Represents a matrix type, as defined in the Matrix Types clang extensions.
Definition TypeBase.h:4392
MemberExpr - [C99 6.5.2.3] Structure and Union Members.
Definition Expr.h:3367
ValueDecl * getMemberDecl() const
Retrieve the member declaration to which this expression refers.
Definition Expr.h:3450
Expr * getBase() const
Definition Expr.h:3444
This represents a decl that may have a name.
Definition Decl.h:274
IdentifierInfo * getIdentifier() const
Get the identifier that names this declaration, if there is one.
Definition Decl.h:295
StringRef getName() const
Get the name of identifier for this declaration as a StringRef.
Definition Decl.h:301
DeclarationName getDeclName() const
Get the actual, stored name of the declaration, which may be a special name.
Definition Decl.h:340
A C++ nested-name-specifier augmented with source location information.
OpaqueValueExpr - An expression referring to an opaque object of a fixed type and value class.
Definition Expr.h:1181
Represents a parameter to a function.
Definition Decl.h:1805
ParsedAttr - Represents a syntactic attribute.
Definition ParsedAttr.h:119
unsigned getSemanticSpelling() const
If the parsed attribute has a semantic equivalent, and it would have a semantic Spelling enumeration ...
unsigned getMinArgs() const
bool checkExactlyNumArgs(class Sema &S, unsigned Num) const
Check if the attribute has exactly as many args as Num.
IdentifierLoc * getArgAsIdent(unsigned Arg) const
Definition ParsedAttr.h:389
bool hasParsedType() const
Definition ParsedAttr.h:337
const ParsedType & getTypeArg() const
Definition ParsedAttr.h:459
unsigned getNumArgs() const
getNumArgs - Return the number of actual arguments to this attribute.
Definition ParsedAttr.h:371
bool isArgIdent(unsigned Arg) const
Definition ParsedAttr.h:385
Expr * getArgAsExpr(unsigned Arg) const
Definition ParsedAttr.h:383
AttributeCommonInfo::Kind getKind() const
Definition ParsedAttr.h:610
A (possibly-)qualified type.
Definition TypeBase.h:937
void addRestrict()
Add the restrict qualifier to this QualType.
Definition TypeBase.h:1183
QualType getNonLValueExprType(const ASTContext &Context) const
Determine the type of a (typically non-lvalue) expression with the specified result type.
Definition Type.cpp:3665
QualType getDesugaredType(const ASTContext &Context) const
Return the specified type with any "sugar" removed from the type.
Definition TypeBase.h:1307
bool isNull() const
Return true if this QualType doesn't point to a type yet.
Definition TypeBase.h:1004
const Type * getTypePtr() const
Retrieves a pointer to the underlying (unqualified) type.
Definition TypeBase.h:8436
LangAS getAddressSpace() const
Return the address space of this type.
Definition TypeBase.h:8562
QualType getNonReferenceType() const
If Type is a reference type (e.g., const int&), returns the type that the reference refers to ("const...
Definition TypeBase.h:8621
QualType getCanonicalType() const
Definition TypeBase.h:8488
QualType getUnqualifiedType() const
Retrieve the unqualified variant of the given type, removing as little sugar as possible.
Definition TypeBase.h:8530
bool hasAddressSpace() const
Check if this type has any address space qualifier.
Definition TypeBase.h:8557
Represents a struct/union/class.
Definition Decl.h:4342
field_range fields() const
Definition Decl.h:4545
bool field_empty() const
Definition Decl.h:4553
bool hasBindingInfoForDecl(const VarDecl *VD) const
Definition SemaHLSL.cpp:220
DeclBindingInfo * getDeclBindingInfo(const VarDecl *VD, ResourceClass ResClass)
Definition SemaHLSL.cpp:206
DeclBindingInfo * addDeclBindingInfo(const VarDecl *VD, ResourceClass ResClass)
Definition SemaHLSL.cpp:193
Scope - A scope is a transient data structure that is used while parsing the program.
Definition Scope.h:41
SemaBase(Sema &S)
Definition SemaBase.cpp:7
ASTContext & getASTContext() const
Definition SemaBase.cpp:9
Sema & SemaRef
Definition SemaBase.h:40
SemaDiagnosticBuilder Diag(SourceLocation Loc, unsigned DiagID)
Emit a diagnostic.
Definition SemaBase.cpp:61
ExprResult ActOnOutParamExpr(ParmVarDecl *Param, Expr *Arg)
HLSLRootSignatureDecl * lookupRootSignatureOverrideDecl(DeclContext *DC) const
bool CanPerformElementwiseCast(Expr *Src, QualType DestType)
void handleWaveSizeAttr(Decl *D, const ParsedAttr &AL)
void handleVkLocationAttr(Decl *D, const ParsedAttr &AL)
HLSLAttributedResourceLocInfo TakeLocForHLSLAttribute(const HLSLAttributedResourceType *RT)
void handleSemanticAttr(Decl *D, const ParsedAttr &AL)
bool CanPerformScalarCast(QualType SrcTy, QualType DestTy)
QualType ProcessResourceTypeAttributes(QualType Wrapped)
void handleShaderAttr(Decl *D, const ParsedAttr &AL)
uint32_t getNextImplicitBindingOrderID()
Definition SemaHLSL.h:230
void CheckEntryPoint(FunctionDecl *FD)
Definition SemaHLSL.cpp:960
void handleVkExtBuiltinOutputAttr(Decl *D, const ParsedAttr &AL)
void emitLogicalOperatorFixIt(Expr *LHS, Expr *RHS, BinaryOperatorKind Opc)
T * createSemanticAttr(const AttributeCommonInfo &ACI, std::optional< unsigned > Location)
Definition SemaHLSL.h:182
bool initGlobalResourceDecl(VarDecl *VD)
void ActOnEndOfTranslationUnit(TranslationUnitDecl *TU)
bool initGlobalResourceArrayDecl(VarDecl *VD)
HLSLVkConstantIdAttr * mergeVkConstantIdAttr(Decl *D, const AttributeCommonInfo &AL, int Id)
Definition SemaHLSL.cpp:724
HLSLNumThreadsAttr * mergeNumThreadsAttr(Decl *D, const AttributeCommonInfo &AL, int X, int Y, int Z)
Definition SemaHLSL.cpp:690
void deduceAddressSpace(VarDecl *Decl)
std::pair< IdentifierInfo *, bool > ActOnStartRootSignatureDecl(StringRef Signature)
Computes the unique Root Signature identifier from the given signature, then lookup if there is a pre...
void handlePackOffsetAttr(Decl *D, const ParsedAttr &AL)
bool diagnosePositionType(QualType T, const ParsedAttr &AL)
bool handleInitialization(VarDecl *VDecl, Expr *&Init)
bool diagnoseInputIDType(QualType T, const ParsedAttr &AL)
void handleParamModifierAttr(Decl *D, const ParsedAttr &AL)
bool CheckResourceBinOp(BinaryOperatorKind Opc, Expr *LHSExpr, Expr *RHSExpr, SourceLocation Loc)
bool CanPerformAggregateSplatCast(Expr *Src, QualType DestType)
bool ActOnResourceMemberAccessExpr(MemberExpr *ME)
bool IsScalarizedLayoutCompatible(QualType T1, QualType T2) const
QualType ActOnTemplateShorthand(TemplateDecl *Template, SourceLocation NameLoc)
void diagnoseSystemSemanticAttr(Decl *D, const ParsedAttr &AL, std::optional< unsigned > Index)
void handleRootSignatureAttr(Decl *D, const ParsedAttr &AL)
bool CheckCompatibleParameterABI(FunctionDecl *New, FunctionDecl *Old)
QualType handleVectorBinOpConversion(ExprResult &LHS, ExprResult &RHS, QualType LHSType, QualType RHSType, bool IsCompAssign)
QualType checkMatrixComponent(Sema &S, QualType baseType, ExprValueKind &VK, SourceLocation OpLoc, const IdentifierInfo *CompName, SourceLocation CompLoc)
void handleResourceBindingAttr(Decl *D, const ParsedAttr &AL)
bool IsTypedResourceElementCompatible(QualType T1)
bool transformInitList(const InitializedEntity &Entity, InitListExpr *Init)
void handleNumThreadsAttr(Decl *D, const ParsedAttr &AL)
bool ActOnUninitializedVarDecl(VarDecl *D)
void handleVkExtBuiltinInputAttr(Decl *D, const ParsedAttr &AL)
void ActOnTopLevelFunction(FunctionDecl *FD)
Definition SemaHLSL.cpp:793
bool handleResourceTypeAttr(QualType T, const ParsedAttr &AL)
void handleVkPushConstantAttr(Decl *D, const ParsedAttr &AL)
HLSLShaderAttr * mergeShaderAttr(Decl *D, const AttributeCommonInfo &AL, llvm::Triple::EnvironmentType ShaderType)
Definition SemaHLSL.cpp:760
void ActOnFinishBuffer(Decl *Dcl, SourceLocation RBrace)
Definition SemaHLSL.cpp:663
void handleVkBindingAttr(Decl *D, const ParsedAttr &AL)
HLSLParamModifierAttr * mergeParamModifierAttr(Decl *D, const AttributeCommonInfo &AL, HLSLParamModifierAttr::Spelling Spelling)
Definition SemaHLSL.cpp:773
QualType getInoutParameterType(QualType Ty)
SemaHLSL(Sema &S)
Definition SemaHLSL.cpp:224
void handleVkConstantIdAttr(Decl *D, const ParsedAttr &AL)
Decl * ActOnStartBuffer(Scope *BufferScope, bool CBuffer, SourceLocation KwLoc, IdentifierInfo *Ident, SourceLocation IdentLoc, SourceLocation LBrace)
Definition SemaHLSL.cpp:226
HLSLWaveSizeAttr * mergeWaveSizeAttr(Decl *D, const AttributeCommonInfo &AL, int Min, int Max, int Preferred, int SpelledArgsCount)
Definition SemaHLSL.cpp:704
bool handleRootSignatureElements(ArrayRef< hlsl::RootSignatureElement > Elements)
void ActOnFinishRootSignatureDecl(SourceLocation Loc, IdentifierInfo *DeclIdent, ArrayRef< hlsl::RootSignatureElement > Elements)
Creates the Root Signature decl of the parsed Root Signature elements onto the AST and push it onto c...
void ActOnVariableDeclarator(VarDecl *VD)
bool CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall)
Sema - This implements semantic analysis and AST building for C.
Definition Sema.h:868
@ LookupOrdinaryName
Ordinary name lookup, which finds ordinary names (functions, variables, typedefs, etc....
Definition Sema.h:9414
@ LookupMemberName
Member name lookup, which finds the names of class/struct/union members.
Definition Sema.h:9422
ExtVectorDeclsType ExtVectorDecls
ExtVectorDecls - This is a list all the extended vector types.
Definition Sema.h:4954
ASTContext & Context
Definition Sema.h:1308
ASTContext & getASTContext() const
Definition Sema.h:939
ExprResult ImpCastExprToType(Expr *E, QualType Type, CastKind CK, ExprValueKind VK=VK_PRValue, const CXXCastPath *BasePath=nullptr, CheckedConversionKind CCK=CheckedConversionKind::Implicit)
ImpCastExprToType - If Expr is not of type 'Type', insert an implicit cast.
Definition Sema.cpp:762
const LangOptions & getLangOpts() const
Definition Sema.h:932
ExprResult TemporaryMaterializationConversion(Expr *E)
If E is a prvalue denoting an unmaterialized temporary, materialize it as an xvalue.
SemaHLSL & HLSL()
Definition Sema.h:1483
ExprResult BuildFieldReferenceExpr(Expr *BaseExpr, bool IsArrow, SourceLocation OpLoc, const CXXScopeSpec &SS, FieldDecl *Field, DeclAccessPair FoundDecl, const DeclarationNameInfo &MemberNameInfo)
bool checkArgCountRange(CallExpr *Call, unsigned MinArgCount, unsigned MaxArgCount)
Checks that a call expression's argument count is in the desired range.
ExternalSemaSource * getExternalSource() const
Definition Sema.h:942
ASTConsumer & Consumer
Definition Sema.h:1309
bool checkArgCount(CallExpr *Call, unsigned DesiredArgCount)
Checks that a call expression's argument count is the desired number.
ExprResult CreateBuiltinArraySubscriptExpr(Expr *Base, SourceLocation LLoc, Expr *Idx, SourceLocation RLoc)
bool LookupQualifiedName(LookupResult &R, DeclContext *LookupCtx, bool InUnqualifiedLookup=false)
Perform qualified name lookup into a given context.
ExprResult PerformCopyInitialization(const InitializedEntity &Entity, SourceLocation EqualLoc, ExprResult Init, bool TopLevelOfInitList=false, bool AllowExplicit=false)
ExprResult CreateBuiltinMatrixSubscriptExpr(Expr *Base, Expr *RowIdx, Expr *ColumnIdx, SourceLocation RBLoc)
Encodes a location in the source.
SourceLocation getLocWithOffset(IntTy Offset) const
Return a source location with the specified offset from this SourceLocation.
A trivial tuple used to represent a source range.
SourceLocation getEnd() const
SourceLocation getEndLoc() const LLVM_READONLY
Definition Stmt.cpp:367
void printPretty(raw_ostream &OS, PrinterHelper *Helper, const PrintingPolicy &Policy, unsigned Indentation=0, StringRef NewlineSymbol="\n", const ASTContext *Context=nullptr) const
SourceRange getSourceRange() const LLVM_READONLY
SourceLocation tokens are not useful in isolation - they are low level value objects created/interpre...
Definition Stmt.cpp:343
SourceLocation getBeginLoc() const LLVM_READONLY
Definition Stmt.cpp:355
StringLiteral - This represents a string literal expression, e.g.
Definition Expr.h:1802
static StringLiteral * Create(const ASTContext &Ctx, StringRef Str, StringLiteralKind Kind, bool Pascal, QualType Ty, ArrayRef< SourceLocation > Locs)
This is the "fully general" constructor that allows representation of strings formed from one or more...
Definition Expr.cpp:1188
void startDefinition()
Starts the definition of this tag declaration.
Definition Decl.cpp:4872
bool isUnion() const
Definition Decl.h:3943
bool isClass() const
Definition Decl.h:3942
Exposes information about the current target.
Definition TargetInfo.h:227
TargetOptions & getTargetOpts() const
Retrieve the target options.
Definition TargetInfo.h:327
const llvm::Triple & getTriple() const
Returns the target triple of the primary target.
StringRef getPlatformName() const
Retrieve the name of the platform as it is used in the availability attribute.
VersionTuple getPlatformMinVersion() const
Retrieve the minimum desired version of the platform, to which the program should be compiled.
std::string HLSLEntry
The entry point name for HLSL shader being compiled as specified by -E.
A convenient class for passing around template argument information.
void addArgument(const TemplateArgumentLoc &Loc)
The base class of all kinds of template declarations (e.g., class, function, etc.).
Stores a list of template parameters for a TemplateDecl and its derived classes.
The top declaration context.
Definition Decl.h:105
SourceLocation getBeginLoc() const
Get the begin source location.
Definition TypeLoc.cpp:193
A container of type source information.
Definition TypeBase.h:8407
TypeLoc getTypeLoc() const
Return the TypeLoc wrapper for the type source info.
Definition TypeLoc.h:267
The base class of the type hierarchy.
Definition TypeBase.h:1871
bool isVoidType() const
Definition TypeBase.h:9039
bool isBooleanType() const
Definition TypeBase.h:9176
bool isIncompleteArrayType() const
Definition TypeBase.h:8780
CXXRecordDecl * getAsCXXRecordDecl() const
Retrieves the CXXRecordDecl that this type refers to, either because the type is a RecordType or beca...
Definition Type.h:26
bool isConstantArrayType() const
Definition TypeBase.h:8776
bool hasIntegerRepresentation() const
Determine whether this type has an integer representation of some sort, e.g., it is an integer type o...
Definition Type.cpp:2119
bool isArrayType() const
Definition TypeBase.h:8772
CXXRecordDecl * castAsCXXRecordDecl() const
Definition Type.h:36
bool isArithmeticType() const
Definition Type.cpp:2410
bool isConstantMatrixType() const
Definition TypeBase.h:8840
bool isHLSLBuiltinIntangibleType() const
Definition TypeBase.h:8984
bool isPointerType() const
Definition TypeBase.h:8673
CanQualType getCanonicalTypeUnqualified() const
bool isIntegerType() const
isIntegerType() does not include complex integers (a GCC extension).
Definition TypeBase.h:9083
const T * castAs() const
Member-template castAs<specific type>.
Definition TypeBase.h:9333
bool isReferenceType() const
Definition TypeBase.h:8697
bool isHLSLIntangibleType() const
Definition Type.cpp:5496
bool isEnumeralType() const
Definition TypeBase.h:8804
bool isScalarType() const
Definition TypeBase.h:9145
bool isIntegralType(const ASTContext &Ctx) const
Determine whether this type is an integral type.
Definition Type.cpp:2156
const Type * getArrayElementTypeNoTypeQual() const
If this is an array type, return the element type of the array, potentially with type qualifiers miss...
Definition Type.cpp:508
QualType getPointeeType() const
If this is a pointer, ObjC object pointer, or block pointer, this returns the respective pointee.
Definition Type.cpp:789
bool hasUnsignedIntegerRepresentation() const
Determine whether this type has an unsigned integer representation of some sort, e....
Definition Type.cpp:2364
bool isAggregateType() const
Determines whether the type is a C++ aggregate type or C aggregate or union type.
Definition Type.cpp:2491
ScalarTypeKind getScalarTypeKind() const
Given that this is a scalar type, classify it.
Definition Type.cpp:2442
bool hasSignedIntegerRepresentation() const
Determine whether this type has an signed integer representation of some sort, e.g....
Definition Type.cpp:2310
bool isMatrixType() const
Definition TypeBase.h:8836
bool isHLSLResourceRecord() const
Definition Type.cpp:5483
bool hasFloatingRepresentation() const
Determine whether this type has a floating-point representation of some sort, e.g....
Definition Type.cpp:2385
bool isVectorType() const
Definition TypeBase.h:8812
bool isRealFloatingType() const
Floating point categories.
Definition Type.cpp:2393
bool isHLSLAttributedResourceType() const
Definition TypeBase.h:8996
@ STK_FloatingComplex
Definition TypeBase.h:2819
@ STK_ObjCObjectPointer
Definition TypeBase.h:2813
@ STK_IntegralComplex
Definition TypeBase.h:2818
@ STK_MemberPointer
Definition TypeBase.h:2814
bool isFloatingType() const
Definition Type.cpp:2377
bool isSamplerT() const
Definition TypeBase.h:8917
const T * getAs() const
Member-template getAs<specific type>'.
Definition TypeBase.h:9266
const Type * getUnqualifiedDesugaredType() const
Return the specified type with any "sugar" removed from the type, removing any typedefs,...
Definition Type.cpp:690
bool isRecordType() const
Definition TypeBase.h:8800
bool isHLSLResourceRecordArray() const
Definition Type.cpp:5487
void setType(QualType newType)
Definition Decl.h:724
QualType getType() const
Definition Decl.h:723
Represents a variable declaration or definition.
Definition Decl.h:926
static VarDecl * Create(ASTContext &C, DeclContext *DC, SourceLocation StartLoc, SourceLocation IdLoc, const IdentifierInfo *Id, QualType T, TypeSourceInfo *TInfo, StorageClass S)
Definition Decl.cpp:2128
void setInitStyle(InitializationStyle Style)
Definition Decl.h:1467
@ CallInit
Call-style initialization (C++98)
Definition Decl.h:934
void setStorageClass(StorageClass SC)
Definition Decl.cpp:2140
bool hasGlobalStorage() const
Returns true for all variables that do not have local storage.
Definition Decl.h:1241
void setInit(Expr *I)
Definition Decl.cpp:2454
StorageClass getStorageClass() const
Returns the storage class as written in the source.
Definition Decl.h:1168
Represents a GCC generic vector type.
Definition TypeBase.h:4230
unsigned getNumElements() const
Definition TypeBase.h:4245
QualType getElementType() const
Definition TypeBase.h:4244
IdentifierInfo * getNameAsIdentifier(ASTContext &AST) const
Defines the clang::TargetInfo interface.
Definition SPIR.cpp:47
uint32_t getResourceDimensions(llvm::dxil::ResourceDimension Dim)
bool hasCounterHandle(const CXXRecordDecl *RD)
The JSON file list parser is used to communicate input to InstallAPI.
bool isa(CodeGen::Address addr)
Definition Address.h:330
if(T->getSizeExpr()) TRY_TO(TraverseStmt(const_cast< Expr * >(T -> getSizeExpr())))
static bool CheckFloatOrHalfRepresentation(Sema *S, SourceLocation Loc, int ArgOrdinal, clang::QualType PassedType)
Definition SemaSPIRV.cpp:66
@ ICIS_NoInit
No in-class initializer.
Definition Specifiers.h:273
@ TemplateName
The identifier is a template name. FIXME: Add an annotation for that.
Definition Parser.h:61
@ OK_Ordinary
An ordinary object is located at an address in memory.
Definition Specifiers.h:152
static bool CheckAllArgTypesAreCorrect(Sema *S, CallExpr *TheCall, llvm::ArrayRef< llvm::function_ref< bool(Sema *, SourceLocation, int, QualType)> > Checks)
Definition SemaSPIRV.cpp:49
@ AS_public
Definition Specifiers.h:125
@ AS_none
Definition Specifiers.h:128
@ SC_Extern
Definition Specifiers.h:252
@ SC_Static
Definition Specifiers.h:253
@ SC_None
Definition Specifiers.h:251
@ AANT_ArgumentIdentifier
@ Result
The result type of a method or function.
Definition TypeBase.h:905
@ Ordinary
This parameter uses ordinary ABI rules for its type.
Definition Specifiers.h:383
llvm::Expected< QualType > ExpectedType
@ Template
We are parsing a template declaration.
Definition Parser.h:81
LLVM_READONLY bool isDigit(unsigned char c)
Return true if this character is an ASCII digit: [0-9].
Definition CharInfo.h:114
static bool CheckAllArgsHaveSameType(Sema *S, CallExpr *TheCall)
Definition SemaSPIRV.cpp:32
ExprResult ExprError()
Definition Ownership.h:265
@ Type
The name was classified as a type.
Definition Sema.h:564
LangAS
Defines the address space values used by the address space qualifier of QualType.
bool CreateHLSLAttributedResourceType(Sema &S, QualType Wrapped, ArrayRef< const Attr * > AttrList, QualType &ResType, HLSLAttributedResourceLocInfo *LocInfo=nullptr)
CastKind
CastKind - The kind of operation required for a conversion.
ExprValueKind
The categorization of expression values, currently following the C++11 scheme.
Definition Specifiers.h:133
@ VK_PRValue
A pr-value expression (in the C++11 taxonomy) produces a temporary value.
Definition Specifiers.h:136
@ VK_LValue
An l-value expression is a reference to an object with independent storage.
Definition Specifiers.h:140
DynamicRecursiveASTVisitorBase< false > DynamicRecursiveASTVisitor
U cast(CodeGen::Address addr)
Definition Address.h:327
@ None
No keyword precedes the qualified type name.
Definition TypeBase.h:5982
ActionResult< Expr * > ExprResult
Definition Ownership.h:249
Visibility
Describes the different kinds of visibility that a declaration may have.
Definition Visibility.h:34
unsigned long uint64_t
unsigned int uint32_t
hash_code hash_value(const clang::dependencies::ModuleID &ID)
__DEVICE__ bool isnan(float __x)
__DEVICE__ _Tp abs(const std::complex< _Tp > &__c)
#define false
Definition stdbool.h:26
Describes how types, statements, expressions, and declarations should be printed.
void setCounterImplicitOrderID(unsigned Value) const
void setImplicitOrderID(unsigned Value) const
const SourceLocation & getLocation() const
Definition SemaHLSL.h:48
const llvm::hlsl::rootsig::RootElement & getElement() const
Definition SemaHLSL.h:47