clang 23.0.0git
SemaHLSL.cpp
Go to the documentation of this file.
1//===- SemaHLSL.cpp - Semantic Analysis for HLSL constructs ---------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8// This implements Semantic Analysis for HLSL constructs.
9//===----------------------------------------------------------------------===//
10
11#include "clang/Sema/SemaHLSL.h"
14#include "clang/AST/Attr.h"
15#include "clang/AST/Decl.h"
16#include "clang/AST/DeclBase.h"
17#include "clang/AST/DeclCXX.h"
20#include "clang/AST/Expr.h"
22#include "clang/AST/Type.h"
23#include "clang/AST/TypeBase.h"
24#include "clang/AST/TypeLoc.h"
28#include "clang/Basic/LLVM.h"
33#include "clang/Sema/Lookup.h"
35#include "clang/Sema/Sema.h"
36#include "clang/Sema/Template.h"
37#include "llvm/ADT/ArrayRef.h"
38#include "llvm/ADT/STLExtras.h"
39#include "llvm/ADT/SmallVector.h"
40#include "llvm/ADT/StringExtras.h"
41#include "llvm/ADT/StringRef.h"
42#include "llvm/ADT/Twine.h"
43#include "llvm/Frontend/HLSL/HLSLBinding.h"
44#include "llvm/Frontend/HLSL/RootSignatureValidations.h"
45#include "llvm/Support/Casting.h"
46#include "llvm/Support/DXILABI.h"
47#include "llvm/Support/ErrorHandling.h"
48#include "llvm/Support/FormatVariadic.h"
49#include "llvm/TargetParser/Triple.h"
50#include <cmath>
51#include <cstddef>
52#include <iterator>
53#include <utility>
54
55using namespace clang;
56using namespace clang::hlsl;
57using RegisterType = HLSLResourceBindingAttr::RegisterType;
58
60 CXXRecordDecl *StructDecl);
61
63 switch (RC) {
64 case ResourceClass::SRV:
65 return RegisterType::SRV;
66 case ResourceClass::UAV:
67 return RegisterType::UAV;
68 case ResourceClass::CBuffer:
69 return RegisterType::CBuffer;
70 case ResourceClass::Sampler:
71 return RegisterType::Sampler;
72 }
73 llvm_unreachable("unexpected ResourceClass value");
74}
75
76static RegisterType getRegisterType(const HLSLAttributedResourceType *ResTy) {
77 return getRegisterType(ResTy->getAttrs().ResourceClass);
78}
79
80// Converts the first letter of string Slot to RegisterType.
81// Returns false if the letter does not correspond to a valid register type.
82static bool convertToRegisterType(StringRef Slot, RegisterType *RT) {
83 assert(RT != nullptr);
84 switch (Slot[0]) {
85 case 't':
86 case 'T':
87 *RT = RegisterType::SRV;
88 return true;
89 case 'u':
90 case 'U':
91 *RT = RegisterType::UAV;
92 return true;
93 case 'b':
94 case 'B':
95 *RT = RegisterType::CBuffer;
96 return true;
97 case 's':
98 case 'S':
99 *RT = RegisterType::Sampler;
100 return true;
101 case 'c':
102 case 'C':
103 *RT = RegisterType::C;
104 return true;
105 case 'i':
106 case 'I':
107 *RT = RegisterType::I;
108 return true;
109 default:
110 return false;
111 }
112}
113
115 switch (RT) {
116 case RegisterType::SRV:
117 return ResourceClass::SRV;
118 case RegisterType::UAV:
119 return ResourceClass::UAV;
120 case RegisterType::CBuffer:
121 return ResourceClass::CBuffer;
122 case RegisterType::Sampler:
123 return ResourceClass::Sampler;
124 case RegisterType::C:
125 case RegisterType::I:
126 // Deliberately falling through to the unreachable below.
127 break;
128 }
129 llvm_unreachable("unexpected RegisterType value");
130}
131
133 const auto *BT = dyn_cast<BuiltinType>(Type);
134 if (!BT) {
135 if (!Type->isEnumeralType())
136 return Builtin::NotBuiltin;
137 return Builtin::BI__builtin_get_spirv_spec_constant_int;
138 }
139
140 switch (BT->getKind()) {
141 case BuiltinType::Bool:
142 return Builtin::BI__builtin_get_spirv_spec_constant_bool;
143 case BuiltinType::Short:
144 return Builtin::BI__builtin_get_spirv_spec_constant_short;
145 case BuiltinType::Int:
146 return Builtin::BI__builtin_get_spirv_spec_constant_int;
147 case BuiltinType::LongLong:
148 return Builtin::BI__builtin_get_spirv_spec_constant_longlong;
149 case BuiltinType::UShort:
150 return Builtin::BI__builtin_get_spirv_spec_constant_ushort;
151 case BuiltinType::UInt:
152 return Builtin::BI__builtin_get_spirv_spec_constant_uint;
153 case BuiltinType::ULongLong:
154 return Builtin::BI__builtin_get_spirv_spec_constant_ulonglong;
155 case BuiltinType::Half:
156 return Builtin::BI__builtin_get_spirv_spec_constant_half;
157 case BuiltinType::Float:
158 return Builtin::BI__builtin_get_spirv_spec_constant_float;
159 case BuiltinType::Double:
160 return Builtin::BI__builtin_get_spirv_spec_constant_double;
161 default:
162 return Builtin::NotBuiltin;
163 }
164}
165
167 ResourceClass ResClass) {
168 assert(getDeclBindingInfo(VD, ResClass) == nullptr &&
169 "DeclBindingInfo already added");
170 assert(!hasBindingInfoForDecl(VD) || BindingsList.back().Decl == VD);
171 // VarDecl may have multiple entries for different resource classes.
172 // DeclToBindingListIndex stores the index of the first binding we saw
173 // for this decl. If there are any additional ones then that index
174 // shouldn't be updated.
175 DeclToBindingListIndex.try_emplace(VD, BindingsList.size());
176 return &BindingsList.emplace_back(VD, ResClass);
177}
178
180 ResourceClass ResClass) {
181 auto Entry = DeclToBindingListIndex.find(VD);
182 if (Entry != DeclToBindingListIndex.end()) {
183 for (unsigned Index = Entry->getSecond();
184 Index < BindingsList.size() && BindingsList[Index].Decl == VD;
185 ++Index) {
186 if (BindingsList[Index].ResClass == ResClass)
187 return &BindingsList[Index];
188 }
189 }
190 return nullptr;
191}
192
194 return DeclToBindingListIndex.contains(VD);
195}
196
198
199Decl *SemaHLSL::ActOnStartBuffer(Scope *BufferScope, bool CBuffer,
200 SourceLocation KwLoc, IdentifierInfo *Ident,
201 SourceLocation IdentLoc,
202 SourceLocation LBrace) {
203 // For anonymous namespace, take the location of the left brace.
204 DeclContext *LexicalParent = SemaRef.getCurLexicalContext();
206 getASTContext(), LexicalParent, CBuffer, KwLoc, Ident, IdentLoc, LBrace);
207
208 // if CBuffer is false, then it's a TBuffer
209 auto RC = CBuffer ? llvm::hlsl::ResourceClass::CBuffer
210 : llvm::hlsl::ResourceClass::SRV;
211 Result->addAttr(HLSLResourceClassAttr::CreateImplicit(getASTContext(), RC));
212
213 SemaRef.PushOnScopeChains(Result, BufferScope);
214 SemaRef.PushDeclContext(BufferScope, Result);
215
216 return Result;
217}
218
219static unsigned calculateLegacyCbufferFieldAlign(const ASTContext &Context,
220 QualType T) {
221 // Arrays, Matrices, and Structs are always aligned to new buffer rows
222 if (T->isArrayType() || T->isStructureType() || T->isConstantMatrixType())
223 return 16;
224
225 // Vectors are aligned to the type they contain
226 if (const VectorType *VT = T->getAs<VectorType>())
227 return calculateLegacyCbufferFieldAlign(Context, VT->getElementType());
228
229 assert(Context.getTypeSize(T) <= 64 &&
230 "Scalar bit widths larger than 64 not supported");
231
232 // Scalar types are aligned to their byte width
233 return Context.getTypeSize(T) / 8;
234}
235
236// Calculate the size of a legacy cbuffer type in bytes based on
237// https://learn.microsoft.com/en-us/windows/win32/direct3dhlsl/dx-graphics-hlsl-packing-rules
238static unsigned calculateLegacyCbufferSize(const ASTContext &Context,
239 QualType T) {
240 constexpr unsigned CBufferAlign = 16;
241 if (const auto *RD = T->getAsRecordDecl()) {
242 unsigned Size = 0;
243 for (const FieldDecl *Field : RD->fields()) {
244 QualType Ty = Field->getType();
245 unsigned FieldSize = calculateLegacyCbufferSize(Context, Ty);
246 unsigned FieldAlign = calculateLegacyCbufferFieldAlign(Context, Ty);
247
248 // If the field crosses the row boundary after alignment it drops to the
249 // next row
250 unsigned AlignSize = llvm::alignTo(Size, FieldAlign);
251 if ((AlignSize % CBufferAlign) + FieldSize > CBufferAlign) {
252 FieldAlign = CBufferAlign;
253 }
254
255 Size = llvm::alignTo(Size, FieldAlign);
256 Size += FieldSize;
257 }
258 return Size;
259 }
260
261 if (const ConstantArrayType *AT = Context.getAsConstantArrayType(T)) {
262 unsigned ElementCount = AT->getSize().getZExtValue();
263 if (ElementCount == 0)
264 return 0;
265
266 unsigned ElementSize =
267 calculateLegacyCbufferSize(Context, AT->getElementType());
268 unsigned AlignedElementSize = llvm::alignTo(ElementSize, CBufferAlign);
269 return AlignedElementSize * (ElementCount - 1) + ElementSize;
270 }
271
272 if (const VectorType *VT = T->getAs<VectorType>()) {
273 unsigned ElementCount = VT->getNumElements();
274 unsigned ElementSize =
275 calculateLegacyCbufferSize(Context, VT->getElementType());
276 return ElementSize * ElementCount;
277 }
278
279 return Context.getTypeSize(T) / 8;
280}
281
282// Validate packoffset:
283// - if packoffset it used it must be set on all declarations inside the buffer
284// - packoffset ranges must not overlap
285static void validatePackoffset(Sema &S, HLSLBufferDecl *BufDecl) {
287
288 // Make sure the packoffset annotations are either on all declarations
289 // or on none.
290 bool HasPackOffset = false;
291 bool HasNonPackOffset = false;
292 for (auto *Field : BufDecl->buffer_decls()) {
293 VarDecl *Var = dyn_cast<VarDecl>(Field);
294 if (!Var)
295 continue;
296 if (Field->hasAttr<HLSLPackOffsetAttr>()) {
297 PackOffsetVec.emplace_back(Var, Field->getAttr<HLSLPackOffsetAttr>());
298 HasPackOffset = true;
299 } else {
300 HasNonPackOffset = true;
301 }
302 }
303
304 if (!HasPackOffset)
305 return;
306
307 if (HasNonPackOffset)
308 S.Diag(BufDecl->getLocation(), diag::warn_hlsl_packoffset_mix);
309
310 // Make sure there is no overlap in packoffset - sort PackOffsetVec by offset
311 // and compare adjacent values.
312 bool IsValid = true;
313 ASTContext &Context = S.getASTContext();
314 std::sort(PackOffsetVec.begin(), PackOffsetVec.end(),
315 [](const std::pair<VarDecl *, HLSLPackOffsetAttr *> &LHS,
316 const std::pair<VarDecl *, HLSLPackOffsetAttr *> &RHS) {
317 return LHS.second->getOffsetInBytes() <
318 RHS.second->getOffsetInBytes();
319 });
320 for (unsigned i = 0; i < PackOffsetVec.size() - 1; i++) {
321 VarDecl *Var = PackOffsetVec[i].first;
322 HLSLPackOffsetAttr *Attr = PackOffsetVec[i].second;
323 unsigned Size = calculateLegacyCbufferSize(Context, Var->getType());
324 unsigned Begin = Attr->getOffsetInBytes();
325 unsigned End = Begin + Size;
326 unsigned NextBegin = PackOffsetVec[i + 1].second->getOffsetInBytes();
327 if (End > NextBegin) {
328 VarDecl *NextVar = PackOffsetVec[i + 1].first;
329 S.Diag(NextVar->getLocation(), diag::err_hlsl_packoffset_overlap)
330 << NextVar << Var;
331 IsValid = false;
332 }
333 }
334 BufDecl->setHasValidPackoffset(IsValid);
335}
336
337// Returns true if the array has a zero size = if any of the dimensions is 0
338static bool isZeroSizedArray(const ConstantArrayType *CAT) {
339 while (CAT && !CAT->isZeroSize())
340 CAT = dyn_cast<ConstantArrayType>(
342 return CAT != nullptr;
343}
344
346 const Type *Ty = VD->getType().getTypePtr();
348}
349
350static const HLSLAttributedResourceType *
352 assert(VD->getType()->isHLSLResourceRecordArray() &&
353 "expected array of resource records");
354 const Type *Ty = VD->getType()->getUnqualifiedDesugaredType();
355 while (const ArrayType *AT = dyn_cast<ArrayType>(Ty))
357 return HLSLAttributedResourceType::findHandleTypeOnResource(Ty);
358}
359
360// Returns true if the type is a leaf element type that is not valid to be
361// included in HLSL Buffer, such as a resource class, empty struct, zero-sized
362// array, or a builtin intangible type. Returns false it is a valid leaf element
363// type or if it is a record type that needs to be inspected further.
367 return true;
368 if (const auto *RD = Ty->getAsCXXRecordDecl())
369 return RD->isEmpty();
370 if (Ty->isConstantArrayType() &&
372 return true;
374 return true;
375 return false;
376}
377
378// Returns true if the struct contains at least one element that prevents it
379// from being included inside HLSL Buffer as is, such as an intangible type,
380// empty struct, or zero-sized array. If it does, a new implicit layout struct
381// needs to be created for HLSL Buffer use that will exclude these unwanted
382// declarations (see createHostLayoutStruct function).
384 if (RD->isHLSLIntangible() || RD->isEmpty())
385 return true;
386 // check fields
387 for (const FieldDecl *Field : RD->fields()) {
388 QualType Ty = Field->getType();
390 return true;
391 if (const auto *RD = Ty->getAsCXXRecordDecl();
393 return true;
394 }
395 // check bases
396 for (const CXXBaseSpecifier &Base : RD->bases())
398 Base.getType()->castAsCXXRecordDecl()))
399 return true;
400 return false;
401}
402
404 DeclContext *DC) {
405 CXXRecordDecl *RD = nullptr;
406 for (NamedDecl *Decl :
408 if (CXXRecordDecl *FoundRD = dyn_cast<CXXRecordDecl>(Decl)) {
409 assert(RD == nullptr &&
410 "there should be at most 1 record by a given name in a scope");
411 RD = FoundRD;
412 }
413 }
414 return RD;
415}
416
417// Creates a name for buffer layout struct using the provide name base.
418// If the name must be unique (not previously defined), a suffix is added
419// until a unique name is found.
421 bool MustBeUnique) {
422 ASTContext &AST = S.getASTContext();
423
424 IdentifierInfo *NameBaseII = BaseDecl->getIdentifier();
425 llvm::SmallString<64> Name("__cblayout_");
426 if (NameBaseII) {
427 Name.append(NameBaseII->getName());
428 } else {
429 // anonymous struct
430 Name.append("anon");
431 MustBeUnique = true;
432 }
433
434 size_t NameLength = Name.size();
435 IdentifierInfo *II = &AST.Idents.get(Name, tok::TokenKind::identifier);
436 if (!MustBeUnique)
437 return II;
438
439 unsigned suffix = 0;
440 while (true) {
441 if (suffix != 0) {
442 Name.append("_");
443 Name.append(llvm::Twine(suffix).str());
444 II = &AST.Idents.get(Name, tok::TokenKind::identifier);
445 }
446 if (!findRecordDeclInContext(II, BaseDecl->getDeclContext()))
447 return II;
448 // declaration with that name already exists - increment suffix and try
449 // again until unique name is found
450 suffix++;
451 Name.truncate(NameLength);
452 };
453}
454
455// Creates a field declaration of given name and type for HLSL buffer layout
456// struct. Returns nullptr if the type cannot be use in HLSL Buffer layout.
458 IdentifierInfo *II,
459 CXXRecordDecl *LayoutStruct) {
461 return nullptr;
462
463 if (auto *RD = Ty->getAsCXXRecordDecl()) {
465 RD = createHostLayoutStruct(S, RD);
466 if (!RD)
467 return nullptr;
469 }
470 }
471
472 QualType QT = QualType(Ty, 0);
473 ASTContext &AST = S.getASTContext();
475 auto *Field = FieldDecl::Create(AST, LayoutStruct, SourceLocation(),
476 SourceLocation(), II, QT, TSI, nullptr, false,
478 Field->setAccess(AccessSpecifier::AS_public);
479 return Field;
480}
481
482// Creates host layout struct for a struct included in HLSL Buffer.
483// The layout struct will include only fields that are allowed in HLSL buffer.
484// These fields will be filtered out:
485// - resource classes
486// - empty structs
487// - zero-sized arrays
488// Returns nullptr if the resulting layout struct would be empty.
490 CXXRecordDecl *StructDecl) {
491 assert(requiresImplicitBufferLayoutStructure(StructDecl) &&
492 "struct is already HLSL buffer compatible");
493
494 ASTContext &AST = S.getASTContext();
495 DeclContext *DC = StructDecl->getDeclContext();
496 IdentifierInfo *II = getHostLayoutStructName(S, StructDecl, false);
497
498 // reuse existing if the layout struct if it already exists
499 if (CXXRecordDecl *RD = findRecordDeclInContext(II, DC))
500 return RD;
501
502 CXXRecordDecl *LS =
503 CXXRecordDecl::Create(AST, TagDecl::TagKind::Struct, DC, SourceLocation(),
504 SourceLocation(), II);
505 LS->setImplicit(true);
506 LS->addAttr(PackedAttr::CreateImplicit(AST));
507 LS->startDefinition();
508
509 // copy base struct, create HLSL Buffer compatible version if needed
510 if (unsigned NumBases = StructDecl->getNumBases()) {
511 assert(NumBases == 1 && "HLSL supports only one base type");
512 (void)NumBases;
513 CXXBaseSpecifier Base = *StructDecl->bases_begin();
514 CXXRecordDecl *BaseDecl = Base.getType()->castAsCXXRecordDecl();
516 BaseDecl = createHostLayoutStruct(S, BaseDecl);
517 if (BaseDecl) {
518 TypeSourceInfo *TSI =
520 Base = CXXBaseSpecifier(SourceRange(), false, StructDecl->isClass(),
521 AS_none, TSI, SourceLocation());
522 }
523 }
524 if (BaseDecl) {
525 const CXXBaseSpecifier *BasesArray[1] = {&Base};
526 LS->setBases(BasesArray, 1);
527 }
528 }
529
530 // filter struct fields
531 for (const FieldDecl *FD : StructDecl->fields()) {
532 const Type *Ty = FD->getType()->getUnqualifiedDesugaredType();
533 if (FieldDecl *NewFD =
534 createFieldForHostLayoutStruct(S, Ty, FD->getIdentifier(), LS))
535 LS->addDecl(NewFD);
536 }
537 LS->completeDefinition();
538
539 if (LS->field_empty() && LS->getNumBases() == 0)
540 return nullptr;
541
542 DC->addDecl(LS);
543 return LS;
544}
545
546// Creates host layout struct for HLSL Buffer. The struct will include only
547// fields of types that are allowed in HLSL buffer and it will filter out:
548// - static or groupshared variable declarations
549// - resource classes
550// - empty structs
551// - zero-sized arrays
552// - non-variable declarations
553// The layout struct will be added to the HLSLBufferDecl declarations.
555 ASTContext &AST = S.getASTContext();
556 IdentifierInfo *II = getHostLayoutStructName(S, BufDecl, true);
557
558 CXXRecordDecl *LS =
559 CXXRecordDecl::Create(AST, TagDecl::TagKind::Struct, BufDecl,
561 LS->addAttr(PackedAttr::CreateImplicit(AST));
562 LS->setImplicit(true);
563 LS->startDefinition();
564
565 for (Decl *D : BufDecl->buffer_decls()) {
566 VarDecl *VD = dyn_cast<VarDecl>(D);
567 if (!VD || VD->getStorageClass() == SC_Static ||
569 continue;
570 const Type *Ty = VD->getType()->getUnqualifiedDesugaredType();
571 if (FieldDecl *FD =
573 // add the field decl to the layout struct
574 LS->addDecl(FD);
575 // update address space of the original decl to hlsl_constant
576 QualType NewTy =
578 VD->setType(NewTy);
579 }
580 }
581 LS->completeDefinition();
582 BufDecl->addLayoutStruct(LS);
583}
584
586 uint32_t ImplicitBindingOrderID) {
587 auto *Attr =
588 HLSLResourceBindingAttr::CreateImplicit(S.getASTContext(), "", "0", {});
589 Attr->setBinding(RT, std::nullopt, 0);
590 Attr->setImplicitBindingOrderID(ImplicitBindingOrderID);
591 D->addAttr(Attr);
592}
593
594// Handle end of cbuffer/tbuffer declaration
596 auto *BufDecl = cast<HLSLBufferDecl>(Dcl);
597 BufDecl->setRBraceLoc(RBrace);
598
599 validatePackoffset(SemaRef, BufDecl);
600
602
603 // Handle implicit binding if needed.
604 ResourceBindingAttrs ResourceAttrs(Dcl);
605 if (!ResourceAttrs.isExplicit()) {
606 SemaRef.Diag(Dcl->getLocation(), diag::warn_hlsl_implicit_binding);
607 // Use HLSLResourceBindingAttr to transfer implicit binding order_ID
608 // to codegen. If it does not exist, create an implicit attribute.
609 uint32_t OrderID = getNextImplicitBindingOrderID();
610 if (ResourceAttrs.hasBinding())
611 ResourceAttrs.setImplicitOrderID(OrderID);
612 else
614 BufDecl->isCBuffer() ? RegisterType::CBuffer
615 : RegisterType::SRV,
616 OrderID);
617 }
618
619 SemaRef.PopDeclContext();
620}
621
622HLSLNumThreadsAttr *SemaHLSL::mergeNumThreadsAttr(Decl *D,
623 const AttributeCommonInfo &AL,
624 int X, int Y, int Z) {
625 if (HLSLNumThreadsAttr *NT = D->getAttr<HLSLNumThreadsAttr>()) {
626 if (NT->getX() != X || NT->getY() != Y || NT->getZ() != Z) {
627 Diag(NT->getLocation(), diag::err_hlsl_attribute_param_mismatch) << AL;
628 Diag(AL.getLoc(), diag::note_conflicting_attribute);
629 }
630 return nullptr;
631 }
632 return ::new (getASTContext())
633 HLSLNumThreadsAttr(getASTContext(), AL, X, Y, Z);
634}
635
637 const AttributeCommonInfo &AL,
638 int Min, int Max, int Preferred,
639 int SpelledArgsCount) {
640 if (HLSLWaveSizeAttr *WS = D->getAttr<HLSLWaveSizeAttr>()) {
641 if (WS->getMin() != Min || WS->getMax() != Max ||
642 WS->getPreferred() != Preferred ||
643 WS->getSpelledArgsCount() != SpelledArgsCount) {
644 Diag(WS->getLocation(), diag::err_hlsl_attribute_param_mismatch) << AL;
645 Diag(AL.getLoc(), diag::note_conflicting_attribute);
646 }
647 return nullptr;
648 }
649 HLSLWaveSizeAttr *Result = ::new (getASTContext())
650 HLSLWaveSizeAttr(getASTContext(), AL, Min, Max, Preferred);
651 Result->setSpelledArgsCount(SpelledArgsCount);
652 return Result;
653}
654
655HLSLVkConstantIdAttr *
657 int Id) {
658
660 if (TargetInfo.getTriple().getArch() != llvm::Triple::spirv) {
661 Diag(AL.getLoc(), diag::warn_attribute_ignored) << AL;
662 return nullptr;
663 }
664
665 auto *VD = cast<VarDecl>(D);
666
667 if (getSpecConstBuiltinId(VD->getType()->getUnqualifiedDesugaredType()) ==
669 Diag(VD->getLocation(), diag::err_specialization_const);
670 return nullptr;
671 }
672
673 if (!VD->getType().isConstQualified()) {
674 Diag(VD->getLocation(), diag::err_specialization_const);
675 return nullptr;
676 }
677
678 if (HLSLVkConstantIdAttr *CI = D->getAttr<HLSLVkConstantIdAttr>()) {
679 if (CI->getId() != Id) {
680 Diag(CI->getLocation(), diag::err_hlsl_attribute_param_mismatch) << AL;
681 Diag(AL.getLoc(), diag::note_conflicting_attribute);
682 }
683 return nullptr;
684 }
685
686 HLSLVkConstantIdAttr *Result =
687 ::new (getASTContext()) HLSLVkConstantIdAttr(getASTContext(), AL, Id);
688 return Result;
689}
690
691HLSLShaderAttr *
693 llvm::Triple::EnvironmentType ShaderType) {
694 if (HLSLShaderAttr *NT = D->getAttr<HLSLShaderAttr>()) {
695 if (NT->getType() != ShaderType) {
696 Diag(NT->getLocation(), diag::err_hlsl_attribute_param_mismatch) << AL;
697 Diag(AL.getLoc(), diag::note_conflicting_attribute);
698 }
699 return nullptr;
700 }
701 return HLSLShaderAttr::Create(getASTContext(), ShaderType, AL);
702}
703
704HLSLParamModifierAttr *
706 HLSLParamModifierAttr::Spelling Spelling) {
707 // We can only merge an `in` attribute with an `out` attribute. All other
708 // combinations of duplicated attributes are ill-formed.
709 if (HLSLParamModifierAttr *PA = D->getAttr<HLSLParamModifierAttr>()) {
710 if ((PA->isIn() && Spelling == HLSLParamModifierAttr::Keyword_out) ||
711 (PA->isOut() && Spelling == HLSLParamModifierAttr::Keyword_in)) {
712 D->dropAttr<HLSLParamModifierAttr>();
713 SourceRange AdjustedRange = {PA->getLocation(), AL.getRange().getEnd()};
714 return HLSLParamModifierAttr::Create(
715 getASTContext(), /*MergedSpelling=*/true, AdjustedRange,
716 HLSLParamModifierAttr::Keyword_inout);
717 }
718 Diag(AL.getLoc(), diag::err_hlsl_duplicate_parameter_modifier) << AL;
719 Diag(PA->getLocation(), diag::note_conflicting_attribute);
720 return nullptr;
721 }
722 return HLSLParamModifierAttr::Create(getASTContext(), AL);
723}
724
727
729 return;
730
731 // If we have specified a root signature to override the entry function then
732 // attach it now
733 HLSLRootSignatureDecl *SignatureDecl =
735 if (SignatureDecl) {
736 FD->dropAttr<RootSignatureAttr>();
737 // We could look up the SourceRange of the macro here as well
738 AttributeCommonInfo AL(RootSigOverrideIdent, AttributeScopeInfo(),
739 SourceRange(), ParsedAttr::Form::Microsoft());
740 FD->addAttr(::new (getASTContext()) RootSignatureAttr(
741 getASTContext(), AL, RootSigOverrideIdent, SignatureDecl));
742 }
743
744 llvm::Triple::EnvironmentType Env = TargetInfo.getTriple().getEnvironment();
745 if (HLSLShaderAttr::isValidShaderType(Env) && Env != llvm::Triple::Library) {
746 if (const auto *Shader = FD->getAttr<HLSLShaderAttr>()) {
747 // The entry point is already annotated - check that it matches the
748 // triple.
749 if (Shader->getType() != Env) {
750 Diag(Shader->getLocation(), diag::err_hlsl_entry_shader_attr_mismatch)
751 << Shader;
752 FD->setInvalidDecl();
753 }
754 } else {
755 // Implicitly add the shader attribute if the entry function isn't
756 // explicitly annotated.
757 FD->addAttr(HLSLShaderAttr::CreateImplicit(getASTContext(), Env,
758 FD->getBeginLoc()));
759 }
760 } else {
761 switch (Env) {
762 case llvm::Triple::UnknownEnvironment:
763 case llvm::Triple::Library:
764 break;
765 case llvm::Triple::RootSignature:
766 llvm_unreachable("rootsig environment has no functions");
767 default:
768 llvm_unreachable("Unhandled environment in triple");
769 }
770 }
771}
772
773static bool isVkPipelineBuiltin(const ASTContext &AstContext, FunctionDecl *FD,
774 HLSLAppliedSemanticAttr *Semantic,
775 bool IsInput) {
776 if (AstContext.getTargetInfo().getTriple().getOS() != llvm::Triple::Vulkan)
777 return false;
778
779 const auto *ShaderAttr = FD->getAttr<HLSLShaderAttr>();
780 assert(ShaderAttr && "Entry point has no shader attribute");
781 llvm::Triple::EnvironmentType ST = ShaderAttr->getType();
782 auto SemanticName = Semantic->getSemanticName().upper();
783
784 // The SV_Position semantic is lowered to:
785 // - Position built-in for vertex output.
786 // - FragCoord built-in for fragment input.
787 if (SemanticName == "SV_POSITION") {
788 return (ST == llvm::Triple::Vertex && !IsInput) ||
789 (ST == llvm::Triple::Pixel && IsInput);
790 }
791
792 return false;
793}
794
795bool SemaHLSL::determineActiveSemanticOnScalar(FunctionDecl *FD,
796 DeclaratorDecl *OutputDecl,
798 SemanticInfo &ActiveSemantic,
799 SemaHLSL::SemanticContext &SC) {
800 if (ActiveSemantic.Semantic == nullptr) {
801 ActiveSemantic.Semantic = D->getAttr<HLSLParsedSemanticAttr>();
802 if (ActiveSemantic.Semantic)
803 ActiveSemantic.Index = ActiveSemantic.Semantic->getSemanticIndex();
804 }
805
806 if (!ActiveSemantic.Semantic) {
807 Diag(D->getLocation(), diag::err_hlsl_missing_semantic_annotation);
808 return false;
809 }
810
811 auto *A = ::new (getASTContext())
812 HLSLAppliedSemanticAttr(getASTContext(), *ActiveSemantic.Semantic,
813 ActiveSemantic.Semantic->getAttrName()->getName(),
814 ActiveSemantic.Index.value_or(0));
815 if (!A)
817
818 checkSemanticAnnotation(FD, D, A, SC);
819 OutputDecl->addAttr(A);
820
821 unsigned Location = ActiveSemantic.Index.value_or(0);
822
824 SC.CurrentIOType & IOType::In)) {
825 bool HasVkLocation = false;
826 if (auto *A = D->getAttr<HLSLVkLocationAttr>()) {
827 HasVkLocation = true;
828 Location = A->getLocation();
829 }
830
831 if (SC.UsesExplicitVkLocations.value_or(HasVkLocation) != HasVkLocation) {
832 Diag(D->getLocation(), diag::err_hlsl_semantic_partial_explicit_indexing);
833 return false;
834 }
835 SC.UsesExplicitVkLocations = HasVkLocation;
836 }
837
838 const ConstantArrayType *AT = dyn_cast<ConstantArrayType>(D->getType());
839 unsigned ElementCount = AT ? AT->getZExtSize() : 1;
840 ActiveSemantic.Index = Location + ElementCount;
841
842 Twine BaseName = Twine(ActiveSemantic.Semantic->getAttrName()->getName());
843 for (unsigned I = 0; I < ElementCount; ++I) {
844 Twine VariableName = BaseName.concat(Twine(Location + I));
845
846 auto [_, Inserted] = SC.ActiveSemantics.insert(VariableName.str());
847 if (!Inserted) {
848 Diag(D->getLocation(), diag::err_hlsl_semantic_index_overlap)
849 << VariableName.str();
850 return false;
851 }
852 }
853
854 return true;
855}
856
857bool SemaHLSL::determineActiveSemantic(FunctionDecl *FD,
858 DeclaratorDecl *OutputDecl,
860 SemanticInfo &ActiveSemantic,
861 SemaHLSL::SemanticContext &SC) {
862 if (ActiveSemantic.Semantic == nullptr) {
863 ActiveSemantic.Semantic = D->getAttr<HLSLParsedSemanticAttr>();
864 if (ActiveSemantic.Semantic)
865 ActiveSemantic.Index = ActiveSemantic.Semantic->getSemanticIndex();
866 }
867
868 const Type *T = D == FD ? &*FD->getReturnType() : &*D->getType();
870
871 const RecordType *RT = dyn_cast<RecordType>(T);
872 if (!RT)
873 return determineActiveSemanticOnScalar(FD, OutputDecl, D, ActiveSemantic,
874 SC);
875
876 const RecordDecl *RD = RT->getDecl();
877 for (FieldDecl *Field : RD->fields()) {
878 SemanticInfo Info = ActiveSemantic;
879 if (!determineActiveSemantic(FD, OutputDecl, Field, Info, SC)) {
880 Diag(Field->getLocation(), diag::note_hlsl_semantic_used_here) << Field;
881 return false;
882 }
883 if (ActiveSemantic.Semantic)
884 ActiveSemantic = Info;
885 }
886
887 return true;
888}
889
891 const auto *ShaderAttr = FD->getAttr<HLSLShaderAttr>();
892 assert(ShaderAttr && "Entry point has no shader attribute");
893 llvm::Triple::EnvironmentType ST = ShaderAttr->getType();
895 VersionTuple Ver = TargetInfo.getTriple().getOSVersion();
896 switch (ST) {
897 case llvm::Triple::Pixel:
898 case llvm::Triple::Vertex:
899 case llvm::Triple::Geometry:
900 case llvm::Triple::Hull:
901 case llvm::Triple::Domain:
902 case llvm::Triple::RayGeneration:
903 case llvm::Triple::Intersection:
904 case llvm::Triple::AnyHit:
905 case llvm::Triple::ClosestHit:
906 case llvm::Triple::Miss:
907 case llvm::Triple::Callable:
908 if (const auto *NT = FD->getAttr<HLSLNumThreadsAttr>()) {
909 diagnoseAttrStageMismatch(NT, ST,
910 {llvm::Triple::Compute,
911 llvm::Triple::Amplification,
912 llvm::Triple::Mesh});
913 FD->setInvalidDecl();
914 }
915 if (const auto *WS = FD->getAttr<HLSLWaveSizeAttr>()) {
916 diagnoseAttrStageMismatch(WS, ST,
917 {llvm::Triple::Compute,
918 llvm::Triple::Amplification,
919 llvm::Triple::Mesh});
920 FD->setInvalidDecl();
921 }
922 break;
923
924 case llvm::Triple::Compute:
925 case llvm::Triple::Amplification:
926 case llvm::Triple::Mesh:
927 if (!FD->hasAttr<HLSLNumThreadsAttr>()) {
928 Diag(FD->getLocation(), diag::err_hlsl_missing_numthreads)
929 << llvm::Triple::getEnvironmentTypeName(ST);
930 FD->setInvalidDecl();
931 }
932 if (const auto *WS = FD->getAttr<HLSLWaveSizeAttr>()) {
933 if (Ver < VersionTuple(6, 6)) {
934 Diag(WS->getLocation(), diag::err_hlsl_attribute_in_wrong_shader_model)
935 << WS << "6.6";
936 FD->setInvalidDecl();
937 } else if (WS->getSpelledArgsCount() > 1 && Ver < VersionTuple(6, 8)) {
938 Diag(
939 WS->getLocation(),
940 diag::err_hlsl_attribute_number_arguments_insufficient_shader_model)
941 << WS << WS->getSpelledArgsCount() << "6.8";
942 FD->setInvalidDecl();
943 }
944 }
945 break;
946 case llvm::Triple::RootSignature:
947 llvm_unreachable("rootsig environment has no function entry point");
948 default:
949 llvm_unreachable("Unhandled environment in triple");
950 }
951
952 SemaHLSL::SemanticContext InputSC = {};
953 InputSC.CurrentIOType = IOType::In;
954
955 for (ParmVarDecl *Param : FD->parameters()) {
956 SemanticInfo ActiveSemantic;
957 ActiveSemantic.Semantic = Param->getAttr<HLSLParsedSemanticAttr>();
958 if (ActiveSemantic.Semantic)
959 ActiveSemantic.Index = ActiveSemantic.Semantic->getSemanticIndex();
960
961 // FIXME: Verify output semantics in parameters.
962 if (!determineActiveSemantic(FD, Param, Param, ActiveSemantic, InputSC)) {
963 Diag(Param->getLocation(), diag::note_previous_decl) << Param;
964 FD->setInvalidDecl();
965 }
966 }
967
968 SemanticInfo ActiveSemantic;
969 SemaHLSL::SemanticContext OutputSC = {};
970 OutputSC.CurrentIOType = IOType::Out;
971 ActiveSemantic.Semantic = FD->getAttr<HLSLParsedSemanticAttr>();
972 if (ActiveSemantic.Semantic)
973 ActiveSemantic.Index = ActiveSemantic.Semantic->getSemanticIndex();
974 if (!FD->getReturnType()->isVoidType())
975 determineActiveSemantic(FD, FD, FD, ActiveSemantic, OutputSC);
976}
977
978void SemaHLSL::checkSemanticAnnotation(
979 FunctionDecl *EntryPoint, const Decl *Param,
980 const HLSLAppliedSemanticAttr *SemanticAttr, const SemanticContext &SC) {
981 auto *ShaderAttr = EntryPoint->getAttr<HLSLShaderAttr>();
982 assert(ShaderAttr && "Entry point has no shader attribute");
983 llvm::Triple::EnvironmentType ST = ShaderAttr->getType();
984
985 auto SemanticName = SemanticAttr->getSemanticName().upper();
986 if (SemanticName == "SV_DISPATCHTHREADID" ||
987 SemanticName == "SV_GROUPINDEX" || SemanticName == "SV_GROUPTHREADID" ||
988 SemanticName == "SV_GROUPID") {
989
990 if (ST != llvm::Triple::Compute)
991 diagnoseSemanticStageMismatch(SemanticAttr, ST, SC.CurrentIOType,
992 {{llvm::Triple::Compute, IOType::In}});
993
994 if (SemanticAttr->getSemanticIndex() != 0) {
995 std::string PrettyName =
996 "'" + SemanticAttr->getSemanticName().str() + "'";
997 Diag(SemanticAttr->getLoc(),
998 diag::err_hlsl_semantic_indexing_not_supported)
999 << PrettyName;
1000 }
1001 return;
1002 }
1003
1004 if (SemanticName == "SV_POSITION") {
1005 // SV_Position can be an input or output in vertex shaders,
1006 // but only an input in pixel shaders.
1007 diagnoseSemanticStageMismatch(SemanticAttr, ST, SC.CurrentIOType,
1008 {{llvm::Triple::Vertex, IOType::InOut},
1009 {llvm::Triple::Pixel, IOType::In}});
1010 return;
1011 }
1012
1013 if (SemanticName == "SV_TARGET") {
1014 diagnoseSemanticStageMismatch(SemanticAttr, ST, SC.CurrentIOType,
1015 {{llvm::Triple::Pixel, IOType::Out}});
1016 return;
1017 }
1018
1019 // FIXME: catch-all for non-implemented system semantics reaching this
1020 // location.
1021 if (SemanticAttr->getAttrName()->getName().starts_with_insensitive("SV_"))
1022 llvm_unreachable("Unknown SemanticAttr");
1023}
1024
1025void SemaHLSL::diagnoseAttrStageMismatch(
1026 const Attr *A, llvm::Triple::EnvironmentType Stage,
1027 std::initializer_list<llvm::Triple::EnvironmentType> AllowedStages) {
1028 SmallVector<StringRef, 8> StageStrings;
1029 llvm::transform(AllowedStages, std::back_inserter(StageStrings),
1030 [](llvm::Triple::EnvironmentType ST) {
1031 return StringRef(
1032 HLSLShaderAttr::ConvertEnvironmentTypeToStr(ST));
1033 });
1034 Diag(A->getLoc(), diag::err_hlsl_attr_unsupported_in_stage)
1035 << A->getAttrName() << llvm::Triple::getEnvironmentTypeName(Stage)
1036 << (AllowedStages.size() != 1) << join(StageStrings, ", ");
1037}
1038
1039void SemaHLSL::diagnoseSemanticStageMismatch(
1040 const Attr *A, llvm::Triple::EnvironmentType Stage, IOType CurrentIOType,
1041 std::initializer_list<SemanticStageInfo> Allowed) {
1042
1043 for (auto &Case : Allowed) {
1044 if (Case.Stage != Stage)
1045 continue;
1046
1047 if (CurrentIOType & Case.AllowedIOTypesMask)
1048 return;
1049
1050 SmallVector<std::string, 8> ValidCases;
1051 llvm::transform(
1052 Allowed, std::back_inserter(ValidCases), [](SemanticStageInfo Case) {
1053 SmallVector<std::string, 2> ValidType;
1054 if (Case.AllowedIOTypesMask & IOType::In)
1055 ValidType.push_back("input");
1056 if (Case.AllowedIOTypesMask & IOType::Out)
1057 ValidType.push_back("output");
1058 return std::string(
1059 HLSLShaderAttr::ConvertEnvironmentTypeToStr(Case.Stage)) +
1060 " " + join(ValidType, "/");
1061 });
1062 Diag(A->getLoc(), diag::err_hlsl_semantic_unsupported_iotype_for_stage)
1063 << A->getAttrName() << (CurrentIOType & IOType::In ? "input" : "output")
1064 << llvm::Triple::getEnvironmentTypeName(Case.Stage)
1065 << join(ValidCases, ", ");
1066 return;
1067 }
1068
1069 SmallVector<StringRef, 8> StageStrings;
1070 llvm::transform(
1071 Allowed, std::back_inserter(StageStrings), [](SemanticStageInfo Case) {
1072 return StringRef(
1073 HLSLShaderAttr::ConvertEnvironmentTypeToStr(Case.Stage));
1074 });
1075
1076 Diag(A->getLoc(), diag::err_hlsl_attr_unsupported_in_stage)
1077 << A->getAttrName() << llvm::Triple::getEnvironmentTypeName(Stage)
1078 << (Allowed.size() != 1) << join(StageStrings, ", ");
1079}
1080
1081template <CastKind Kind>
1082static void castVector(Sema &S, ExprResult &E, QualType &Ty, unsigned Sz) {
1083 if (const auto *VTy = Ty->getAs<VectorType>())
1084 Ty = VTy->getElementType();
1085 Ty = S.getASTContext().getExtVectorType(Ty, Sz);
1086 E = S.ImpCastExprToType(E.get(), Ty, Kind);
1087}
1088
1089template <CastKind Kind>
1091 E = S.ImpCastExprToType(E.get(), Ty, Kind);
1092 return Ty;
1093}
1094
1096 Sema &SemaRef, ExprResult &LHS, ExprResult &RHS, QualType LHSType,
1097 QualType RHSType, QualType LElTy, QualType RElTy, bool IsCompAssign) {
1098 bool LHSFloat = LElTy->isRealFloatingType();
1099 bool RHSFloat = RElTy->isRealFloatingType();
1100
1101 if (LHSFloat && RHSFloat) {
1102 if (IsCompAssign ||
1103 SemaRef.getASTContext().getFloatingTypeOrder(LElTy, RElTy) > 0)
1104 return castElement<CK_FloatingCast>(SemaRef, RHS, LHSType);
1105
1106 return castElement<CK_FloatingCast>(SemaRef, LHS, RHSType);
1107 }
1108
1109 if (LHSFloat)
1110 return castElement<CK_IntegralToFloating>(SemaRef, RHS, LHSType);
1111
1112 assert(RHSFloat);
1113 if (IsCompAssign)
1114 return castElement<clang::CK_FloatingToIntegral>(SemaRef, RHS, LHSType);
1115
1116 return castElement<CK_IntegralToFloating>(SemaRef, LHS, RHSType);
1117}
1118
1120 Sema &SemaRef, ExprResult &LHS, ExprResult &RHS, QualType LHSType,
1121 QualType RHSType, QualType LElTy, QualType RElTy, bool IsCompAssign) {
1122
1123 int IntOrder = SemaRef.Context.getIntegerTypeOrder(LElTy, RElTy);
1124 bool LHSSigned = LElTy->hasSignedIntegerRepresentation();
1125 bool RHSSigned = RElTy->hasSignedIntegerRepresentation();
1126 auto &Ctx = SemaRef.getASTContext();
1127
1128 // If both types have the same signedness, use the higher ranked type.
1129 if (LHSSigned == RHSSigned) {
1130 if (IsCompAssign || IntOrder >= 0)
1131 return castElement<CK_IntegralCast>(SemaRef, RHS, LHSType);
1132
1133 return castElement<CK_IntegralCast>(SemaRef, LHS, RHSType);
1134 }
1135
1136 // If the unsigned type has greater than or equal rank of the signed type, use
1137 // the unsigned type.
1138 if (IntOrder != (LHSSigned ? 1 : -1)) {
1139 if (IsCompAssign || RHSSigned)
1140 return castElement<CK_IntegralCast>(SemaRef, RHS, LHSType);
1141 return castElement<CK_IntegralCast>(SemaRef, LHS, RHSType);
1142 }
1143
1144 // At this point the signed type has higher rank than the unsigned type, which
1145 // means it will be the same size or bigger. If the signed type is bigger, it
1146 // can represent all the values of the unsigned type, so select it.
1147 if (Ctx.getIntWidth(LElTy) != Ctx.getIntWidth(RElTy)) {
1148 if (IsCompAssign || LHSSigned)
1149 return castElement<CK_IntegralCast>(SemaRef, RHS, LHSType);
1150 return castElement<CK_IntegralCast>(SemaRef, LHS, RHSType);
1151 }
1152
1153 // This is a bit of an odd duck case in HLSL. It shouldn't happen, but can due
1154 // to C/C++ leaking through. The place this happens today is long vs long
1155 // long. When arguments are vector<unsigned long, N> and vector<long long, N>,
1156 // the long long has higher rank than long even though they are the same size.
1157
1158 // If this is a compound assignment cast the right hand side to the left hand
1159 // side's type.
1160 if (IsCompAssign)
1161 return castElement<CK_IntegralCast>(SemaRef, RHS, LHSType);
1162
1163 // If this isn't a compound assignment we convert to unsigned long long.
1164 QualType ElTy = Ctx.getCorrespondingUnsignedType(LHSSigned ? LElTy : RElTy);
1165 QualType NewTy = Ctx.getExtVectorType(
1166 ElTy, RHSType->castAs<VectorType>()->getNumElements());
1167 (void)castElement<CK_IntegralCast>(SemaRef, RHS, NewTy);
1168
1169 return castElement<CK_IntegralCast>(SemaRef, LHS, NewTy);
1170}
1171
1173 QualType SrcTy) {
1174 if (DestTy->isRealFloatingType() && SrcTy->isRealFloatingType())
1175 return CK_FloatingCast;
1176 if (DestTy->isIntegralType(Ctx) && SrcTy->isIntegralType(Ctx))
1177 return CK_IntegralCast;
1178 if (DestTy->isRealFloatingType())
1179 return CK_IntegralToFloating;
1180 assert(SrcTy->isRealFloatingType() && DestTy->isIntegralType(Ctx));
1181 return CK_FloatingToIntegral;
1182}
1183
1185 QualType LHSType,
1186 QualType RHSType,
1187 bool IsCompAssign) {
1188 const auto *LVecTy = LHSType->getAs<VectorType>();
1189 const auto *RVecTy = RHSType->getAs<VectorType>();
1190 auto &Ctx = getASTContext();
1191
1192 // If the LHS is not a vector and this is a compound assignment, we truncate
1193 // the argument to a scalar then convert it to the LHS's type.
1194 if (!LVecTy && IsCompAssign) {
1195 QualType RElTy = RHSType->castAs<VectorType>()->getElementType();
1196 RHS = SemaRef.ImpCastExprToType(RHS.get(), RElTy, CK_HLSLVectorTruncation);
1197 RHSType = RHS.get()->getType();
1198 if (Ctx.hasSameUnqualifiedType(LHSType, RHSType))
1199 return LHSType;
1200 RHS = SemaRef.ImpCastExprToType(RHS.get(), LHSType,
1201 getScalarCastKind(Ctx, LHSType, RHSType));
1202 return LHSType;
1203 }
1204
1205 unsigned EndSz = std::numeric_limits<unsigned>::max();
1206 unsigned LSz = 0;
1207 if (LVecTy)
1208 LSz = EndSz = LVecTy->getNumElements();
1209 if (RVecTy)
1210 EndSz = std::min(RVecTy->getNumElements(), EndSz);
1211 assert(EndSz != std::numeric_limits<unsigned>::max() &&
1212 "one of the above should have had a value");
1213
1214 // In a compound assignment, the left operand does not change type, the right
1215 // operand is converted to the type of the left operand.
1216 if (IsCompAssign && LSz != EndSz) {
1217 Diag(LHS.get()->getBeginLoc(),
1218 diag::err_hlsl_vector_compound_assignment_truncation)
1219 << LHSType << RHSType;
1220 return QualType();
1221 }
1222
1223 if (RVecTy && RVecTy->getNumElements() > EndSz)
1224 castVector<CK_HLSLVectorTruncation>(SemaRef, RHS, RHSType, EndSz);
1225 if (!IsCompAssign && LVecTy && LVecTy->getNumElements() > EndSz)
1226 castVector<CK_HLSLVectorTruncation>(SemaRef, LHS, LHSType, EndSz);
1227
1228 if (!RVecTy)
1229 castVector<CK_VectorSplat>(SemaRef, RHS, RHSType, EndSz);
1230 if (!IsCompAssign && !LVecTy)
1231 castVector<CK_VectorSplat>(SemaRef, LHS, LHSType, EndSz);
1232
1233 // If we're at the same type after resizing we can stop here.
1234 if (Ctx.hasSameUnqualifiedType(LHSType, RHSType))
1235 return Ctx.getCommonSugaredType(LHSType, RHSType);
1236
1237 QualType LElTy = LHSType->castAs<VectorType>()->getElementType();
1238 QualType RElTy = RHSType->castAs<VectorType>()->getElementType();
1239
1240 // Handle conversion for floating point vectors.
1241 if (LElTy->isRealFloatingType() || RElTy->isRealFloatingType())
1242 return handleFloatVectorBinOpConversion(SemaRef, LHS, RHS, LHSType, RHSType,
1243 LElTy, RElTy, IsCompAssign);
1244
1245 assert(LElTy->isIntegralType(Ctx) && RElTy->isIntegralType(Ctx) &&
1246 "HLSL Vectors can only contain integer or floating point types");
1247 return handleIntegerVectorBinOpConversion(SemaRef, LHS, RHS, LHSType, RHSType,
1248 LElTy, RElTy, IsCompAssign);
1249}
1250
1252 BinaryOperatorKind Opc) {
1253 assert((Opc == BO_LOr || Opc == BO_LAnd) &&
1254 "Called with non-logical operator");
1256 llvm::raw_svector_ostream OS(Buff);
1257 PrintingPolicy PP(SemaRef.getLangOpts());
1258 StringRef NewFnName = Opc == BO_LOr ? "or" : "and";
1259 OS << NewFnName << "(";
1260 LHS->printPretty(OS, nullptr, PP);
1261 OS << ", ";
1262 RHS->printPretty(OS, nullptr, PP);
1263 OS << ")";
1264 SourceRange FullRange = SourceRange(LHS->getBeginLoc(), RHS->getEndLoc());
1265 SemaRef.Diag(LHS->getBeginLoc(), diag::note_function_suggestion)
1266 << NewFnName << FixItHint::CreateReplacement(FullRange, OS.str());
1267}
1268
1269std::pair<IdentifierInfo *, bool>
1271 llvm::hash_code Hash = llvm::hash_value(Signature);
1272 std::string IdStr = "__hlsl_rootsig_decl_" + std::to_string(Hash);
1273 IdentifierInfo *DeclIdent = &(getASTContext().Idents.get(IdStr));
1274
1275 // Check if we have already found a decl of the same name.
1276 LookupResult R(SemaRef, DeclIdent, SourceLocation(),
1278 bool Found = SemaRef.LookupQualifiedName(R, SemaRef.CurContext);
1279 return {DeclIdent, Found};
1280}
1281
1283 SourceLocation Loc, IdentifierInfo *DeclIdent,
1285
1286 if (handleRootSignatureElements(RootElements))
1287 return;
1288
1290 for (auto &RootSigElement : RootElements)
1291 Elements.push_back(RootSigElement.getElement());
1292
1293 auto *SignatureDecl = HLSLRootSignatureDecl::Create(
1294 SemaRef.getASTContext(), /*DeclContext=*/SemaRef.CurContext, Loc,
1295 DeclIdent, SemaRef.getLangOpts().HLSLRootSigVer, Elements);
1296
1297 SignatureDecl->setImplicit();
1298 SemaRef.PushOnScopeChains(SignatureDecl, SemaRef.getCurScope());
1299}
1300
1303 if (RootSigOverrideIdent) {
1304 LookupResult R(SemaRef, RootSigOverrideIdent, SourceLocation(),
1306 if (SemaRef.LookupQualifiedName(R, DC))
1307 return dyn_cast<HLSLRootSignatureDecl>(R.getFoundDecl());
1308 }
1309
1310 return nullptr;
1311}
1312
1313namespace {
1314
1315struct PerVisibilityBindingChecker {
1316 SemaHLSL *S;
1317 // We need one builder per `llvm::dxbc::ShaderVisibility` value.
1318 std::array<llvm::hlsl::BindingInfoBuilder, 8> Builders;
1319
1320 struct ElemInfo {
1321 const hlsl::RootSignatureElement *Elem;
1322 llvm::dxbc::ShaderVisibility Vis;
1323 bool Diagnosed;
1324 };
1325 llvm::SmallVector<ElemInfo> ElemInfoMap;
1326
1327 PerVisibilityBindingChecker(SemaHLSL *S) : S(S) {}
1328
1329 void trackBinding(llvm::dxbc::ShaderVisibility Visibility,
1330 llvm::dxil::ResourceClass RC, uint32_t Space,
1331 uint32_t LowerBound, uint32_t UpperBound,
1332 const hlsl::RootSignatureElement *Elem) {
1333 uint32_t BuilderIndex = llvm::to_underlying(Visibility);
1334 assert(BuilderIndex < Builders.size() &&
1335 "Not enough builders for visibility type");
1336 Builders[BuilderIndex].trackBinding(RC, Space, LowerBound, UpperBound,
1337 static_cast<const void *>(Elem));
1338
1339 static_assert(llvm::to_underlying(llvm::dxbc::ShaderVisibility::All) == 0,
1340 "'All' visibility must come first");
1341 if (Visibility == llvm::dxbc::ShaderVisibility::All)
1342 for (size_t I = 1, E = Builders.size(); I < E; ++I)
1343 Builders[I].trackBinding(RC, Space, LowerBound, UpperBound,
1344 static_cast<const void *>(Elem));
1345
1346 ElemInfoMap.push_back({Elem, Visibility, false});
1347 }
1348
1349 ElemInfo &getInfo(const hlsl::RootSignatureElement *Elem) {
1350 auto It = llvm::lower_bound(
1351 ElemInfoMap, Elem,
1352 [](const auto &LHS, const auto &RHS) { return LHS.Elem < RHS; });
1353 assert(It->Elem == Elem && "Element not in map");
1354 return *It;
1355 }
1356
1357 bool checkOverlap() {
1358 llvm::sort(ElemInfoMap, [](const auto &LHS, const auto &RHS) {
1359 return LHS.Elem < RHS.Elem;
1360 });
1361
1362 bool HadOverlap = false;
1363
1364 using llvm::hlsl::BindingInfoBuilder;
1365 auto ReportOverlap = [this,
1366 &HadOverlap](const BindingInfoBuilder &Builder,
1367 const llvm::hlsl::Binding &Reported) {
1368 HadOverlap = true;
1369
1370 const auto *Elem =
1371 static_cast<const hlsl::RootSignatureElement *>(Reported.Cookie);
1372 const llvm::hlsl::Binding &Previous = Builder.findOverlapping(Reported);
1373 const auto *PrevElem =
1374 static_cast<const hlsl::RootSignatureElement *>(Previous.Cookie);
1375
1376 ElemInfo &Info = getInfo(Elem);
1377 // We will have already diagnosed this binding if there's overlap in the
1378 // "All" visibility as well as any particular visibility.
1379 if (Info.Diagnosed)
1380 return;
1381 Info.Diagnosed = true;
1382
1383 ElemInfo &PrevInfo = getInfo(PrevElem);
1384 llvm::dxbc::ShaderVisibility CommonVis =
1385 Info.Vis == llvm::dxbc::ShaderVisibility::All ? PrevInfo.Vis
1386 : Info.Vis;
1387
1388 this->S->Diag(Elem->getLocation(), diag::err_hlsl_resource_range_overlap)
1389 << llvm::to_underlying(Reported.RC) << Reported.LowerBound
1390 << Reported.isUnbounded() << Reported.UpperBound
1391 << llvm::to_underlying(Previous.RC) << Previous.LowerBound
1392 << Previous.isUnbounded() << Previous.UpperBound << Reported.Space
1393 << CommonVis;
1394
1395 this->S->Diag(PrevElem->getLocation(),
1396 diag::note_hlsl_resource_range_here);
1397 };
1398
1399 for (BindingInfoBuilder &Builder : Builders)
1400 Builder.calculateBindingInfo(ReportOverlap);
1401
1402 return HadOverlap;
1403 }
1404};
1405
1406static CXXMethodDecl *lookupMethod(Sema &S, CXXRecordDecl *RecordDecl,
1407 StringRef Name, SourceLocation Loc) {
1408 DeclarationName DeclName(&S.getASTContext().Idents.get(Name));
1409 LookupResult Result(S, DeclName, Loc, Sema::LookupMemberName);
1410 if (!S.LookupQualifiedName(Result, static_cast<DeclContext *>(RecordDecl)))
1411 return nullptr;
1412 return cast<CXXMethodDecl>(Result.getFoundDecl());
1413}
1414
1415} // end anonymous namespace
1416
1417static bool hasCounterHandle(const CXXRecordDecl *RD) {
1418 if (RD->field_empty())
1419 return false;
1420 auto It = std::next(RD->field_begin());
1421 if (It == RD->field_end())
1422 return false;
1423 const FieldDecl *SecondField = *It;
1424 if (const auto *ResTy =
1425 SecondField->getType()->getAs<HLSLAttributedResourceType>()) {
1426 return ResTy->getAttrs().IsCounter;
1427 }
1428 return false;
1429}
1430
1433 // Define some common error handling functions
1434 bool HadError = false;
1435 auto ReportError = [this, &HadError](SourceLocation Loc, uint32_t LowerBound,
1436 uint32_t UpperBound) {
1437 HadError = true;
1438 this->Diag(Loc, diag::err_hlsl_invalid_rootsig_value)
1439 << LowerBound << UpperBound;
1440 };
1441
1442 auto ReportFloatError = [this, &HadError](SourceLocation Loc,
1443 float LowerBound,
1444 float UpperBound) {
1445 HadError = true;
1446 this->Diag(Loc, diag::err_hlsl_invalid_rootsig_value)
1447 << llvm::formatv("{0:f}", LowerBound).sstr<6>()
1448 << llvm::formatv("{0:f}", UpperBound).sstr<6>();
1449 };
1450
1451 auto VerifyRegister = [ReportError](SourceLocation Loc, uint32_t Register) {
1452 if (!llvm::hlsl::rootsig::verifyRegisterValue(Register))
1453 ReportError(Loc, 0, 0xfffffffe);
1454 };
1455
1456 auto VerifySpace = [ReportError](SourceLocation Loc, uint32_t Space) {
1457 if (!llvm::hlsl::rootsig::verifyRegisterSpace(Space))
1458 ReportError(Loc, 0, 0xffffffef);
1459 };
1460
1461 const uint32_t Version =
1462 llvm::to_underlying(SemaRef.getLangOpts().HLSLRootSigVer);
1463 const uint32_t VersionEnum = Version - 1;
1464 auto ReportFlagError = [this, &HadError, VersionEnum](SourceLocation Loc) {
1465 HadError = true;
1466 this->Diag(Loc, diag::err_hlsl_invalid_rootsig_flag)
1467 << /*version minor*/ VersionEnum;
1468 };
1469
1470 // Iterate through the elements and do basic validations
1471 for (const hlsl::RootSignatureElement &RootSigElem : Elements) {
1472 SourceLocation Loc = RootSigElem.getLocation();
1473 const llvm::hlsl::rootsig::RootElement &Elem = RootSigElem.getElement();
1474 if (const auto *Descriptor =
1475 std::get_if<llvm::hlsl::rootsig::RootDescriptor>(&Elem)) {
1476 VerifyRegister(Loc, Descriptor->Reg.Number);
1477 VerifySpace(Loc, Descriptor->Space);
1478
1479 if (!llvm::hlsl::rootsig::verifyRootDescriptorFlag(Version,
1480 Descriptor->Flags))
1481 ReportFlagError(Loc);
1482 } else if (const auto *Constants =
1483 std::get_if<llvm::hlsl::rootsig::RootConstants>(&Elem)) {
1484 VerifyRegister(Loc, Constants->Reg.Number);
1485 VerifySpace(Loc, Constants->Space);
1486 } else if (const auto *Sampler =
1487 std::get_if<llvm::hlsl::rootsig::StaticSampler>(&Elem)) {
1488 VerifyRegister(Loc, Sampler->Reg.Number);
1489 VerifySpace(Loc, Sampler->Space);
1490
1491 assert(!std::isnan(Sampler->MaxLOD) && !std::isnan(Sampler->MinLOD) &&
1492 "By construction, parseFloatParam can't produce a NaN from a "
1493 "float_literal token");
1494
1495 if (!llvm::hlsl::rootsig::verifyMaxAnisotropy(Sampler->MaxAnisotropy))
1496 ReportError(Loc, 0, 16);
1497 if (!llvm::hlsl::rootsig::verifyMipLODBias(Sampler->MipLODBias))
1498 ReportFloatError(Loc, -16.f, 15.99f);
1499 } else if (const auto *Clause =
1500 std::get_if<llvm::hlsl::rootsig::DescriptorTableClause>(
1501 &Elem)) {
1502 VerifyRegister(Loc, Clause->Reg.Number);
1503 VerifySpace(Loc, Clause->Space);
1504
1505 if (!llvm::hlsl::rootsig::verifyNumDescriptors(Clause->NumDescriptors)) {
1506 // NumDescriptor could techincally be ~0u but that is reserved for
1507 // unbounded, so the diagnostic will not report that as a valid int
1508 // value
1509 ReportError(Loc, 1, 0xfffffffe);
1510 }
1511
1512 if (!llvm::hlsl::rootsig::verifyDescriptorRangeFlag(Version, Clause->Type,
1513 Clause->Flags))
1514 ReportFlagError(Loc);
1515 }
1516 }
1517
1518 PerVisibilityBindingChecker BindingChecker(this);
1519 SmallVector<std::pair<const llvm::hlsl::rootsig::DescriptorTableClause *,
1521 UnboundClauses;
1522
1523 for (const hlsl::RootSignatureElement &RootSigElem : Elements) {
1524 const llvm::hlsl::rootsig::RootElement &Elem = RootSigElem.getElement();
1525 if (const auto *Descriptor =
1526 std::get_if<llvm::hlsl::rootsig::RootDescriptor>(&Elem)) {
1527 uint32_t LowerBound(Descriptor->Reg.Number);
1528 uint32_t UpperBound(LowerBound); // inclusive range
1529
1530 BindingChecker.trackBinding(
1531 Descriptor->Visibility,
1532 static_cast<llvm::dxil::ResourceClass>(Descriptor->Type),
1533 Descriptor->Space, LowerBound, UpperBound, &RootSigElem);
1534 } else if (const auto *Constants =
1535 std::get_if<llvm::hlsl::rootsig::RootConstants>(&Elem)) {
1536 uint32_t LowerBound(Constants->Reg.Number);
1537 uint32_t UpperBound(LowerBound); // inclusive range
1538
1539 BindingChecker.trackBinding(
1540 Constants->Visibility, llvm::dxil::ResourceClass::CBuffer,
1541 Constants->Space, LowerBound, UpperBound, &RootSigElem);
1542 } else if (const auto *Sampler =
1543 std::get_if<llvm::hlsl::rootsig::StaticSampler>(&Elem)) {
1544 uint32_t LowerBound(Sampler->Reg.Number);
1545 uint32_t UpperBound(LowerBound); // inclusive range
1546
1547 BindingChecker.trackBinding(
1548 Sampler->Visibility, llvm::dxil::ResourceClass::Sampler,
1549 Sampler->Space, LowerBound, UpperBound, &RootSigElem);
1550 } else if (const auto *Clause =
1551 std::get_if<llvm::hlsl::rootsig::DescriptorTableClause>(
1552 &Elem)) {
1553 // We'll process these once we see the table element.
1554 UnboundClauses.emplace_back(Clause, &RootSigElem);
1555 } else if (const auto *Table =
1556 std::get_if<llvm::hlsl::rootsig::DescriptorTable>(&Elem)) {
1557 assert(UnboundClauses.size() == Table->NumClauses &&
1558 "Number of unbound elements must match the number of clauses");
1559 bool HasAnySampler = false;
1560 bool HasAnyNonSampler = false;
1561 uint64_t Offset = 0;
1562 bool IsPrevUnbound = false;
1563 for (const auto &[Clause, ClauseElem] : UnboundClauses) {
1564 SourceLocation Loc = ClauseElem->getLocation();
1565 if (Clause->Type == llvm::dxil::ResourceClass::Sampler)
1566 HasAnySampler = true;
1567 else
1568 HasAnyNonSampler = true;
1569
1570 if (HasAnySampler && HasAnyNonSampler)
1571 Diag(Loc, diag::err_hlsl_invalid_mixed_resources);
1572
1573 // Relevant error will have already been reported above and needs to be
1574 // fixed before we can conduct further analysis, so shortcut error
1575 // return
1576 if (Clause->NumDescriptors == 0)
1577 return true;
1578
1579 bool IsAppending =
1580 Clause->Offset == llvm::hlsl::rootsig::DescriptorTableOffsetAppend;
1581 if (!IsAppending)
1582 Offset = Clause->Offset;
1583
1584 uint64_t RangeBound = llvm::hlsl::rootsig::computeRangeBound(
1585 Offset, Clause->NumDescriptors);
1586
1587 if (IsPrevUnbound && IsAppending)
1588 Diag(Loc, diag::err_hlsl_appending_onto_unbound);
1589 else if (!llvm::hlsl::rootsig::verifyNoOverflowedOffset(RangeBound))
1590 Diag(Loc, diag::err_hlsl_offset_overflow) << Offset << RangeBound;
1591
1592 // Update offset to be 1 past this range's bound
1593 Offset = RangeBound + 1;
1594 IsPrevUnbound = Clause->NumDescriptors ==
1595 llvm::hlsl::rootsig::NumDescriptorsUnbounded;
1596
1597 // Compute the register bounds and track resource binding
1598 uint32_t LowerBound(Clause->Reg.Number);
1599 uint32_t UpperBound = llvm::hlsl::rootsig::computeRangeBound(
1600 LowerBound, Clause->NumDescriptors);
1601
1602 BindingChecker.trackBinding(
1603 Table->Visibility,
1604 static_cast<llvm::dxil::ResourceClass>(Clause->Type), Clause->Space,
1605 LowerBound, UpperBound, ClauseElem);
1606 }
1607 UnboundClauses.clear();
1608 }
1609 }
1610
1611 return BindingChecker.checkOverlap();
1612}
1613
1615 if (AL.getNumArgs() != 1) {
1616 Diag(AL.getLoc(), diag::err_attribute_wrong_number_arguments) << AL << 1;
1617 return;
1618 }
1619
1621 if (auto *RS = D->getAttr<RootSignatureAttr>()) {
1622 if (RS->getSignatureIdent() != Ident) {
1623 Diag(AL.getLoc(), diag::err_disallowed_duplicate_attribute) << RS;
1624 return;
1625 }
1626
1627 Diag(AL.getLoc(), diag::warn_duplicate_attribute_exact) << RS;
1628 return;
1629 }
1630
1632 if (SemaRef.LookupQualifiedName(R, D->getDeclContext()))
1633 if (auto *SignatureDecl =
1634 dyn_cast<HLSLRootSignatureDecl>(R.getFoundDecl())) {
1635 D->addAttr(::new (getASTContext()) RootSignatureAttr(
1636 getASTContext(), AL, Ident, SignatureDecl));
1637 }
1638}
1639
1641 llvm::VersionTuple SMVersion =
1642 getASTContext().getTargetInfo().getTriple().getOSVersion();
1643 bool IsDXIL = getASTContext().getTargetInfo().getTriple().getArch() ==
1644 llvm::Triple::dxil;
1645
1646 uint32_t ZMax = 1024;
1647 uint32_t ThreadMax = 1024;
1648 if (IsDXIL && SMVersion.getMajor() <= 4) {
1649 ZMax = 1;
1650 ThreadMax = 768;
1651 } else if (IsDXIL && SMVersion.getMajor() == 5) {
1652 ZMax = 64;
1653 ThreadMax = 1024;
1654 }
1655
1656 uint32_t X;
1657 if (!SemaRef.checkUInt32Argument(AL, AL.getArgAsExpr(0), X))
1658 return;
1659 if (X > 1024) {
1660 Diag(AL.getArgAsExpr(0)->getExprLoc(),
1661 diag::err_hlsl_numthreads_argument_oor)
1662 << 0 << 1024;
1663 return;
1664 }
1665 uint32_t Y;
1666 if (!SemaRef.checkUInt32Argument(AL, AL.getArgAsExpr(1), Y))
1667 return;
1668 if (Y > 1024) {
1669 Diag(AL.getArgAsExpr(1)->getExprLoc(),
1670 diag::err_hlsl_numthreads_argument_oor)
1671 << 1 << 1024;
1672 return;
1673 }
1674 uint32_t Z;
1675 if (!SemaRef.checkUInt32Argument(AL, AL.getArgAsExpr(2), Z))
1676 return;
1677 if (Z > ZMax) {
1678 SemaRef.Diag(AL.getArgAsExpr(2)->getExprLoc(),
1679 diag::err_hlsl_numthreads_argument_oor)
1680 << 2 << ZMax;
1681 return;
1682 }
1683
1684 if (X * Y * Z > ThreadMax) {
1685 Diag(AL.getLoc(), diag::err_hlsl_numthreads_invalid) << ThreadMax;
1686 return;
1687 }
1688
1689 HLSLNumThreadsAttr *NewAttr = mergeNumThreadsAttr(D, AL, X, Y, Z);
1690 if (NewAttr)
1691 D->addAttr(NewAttr);
1692}
1693
1694static bool isValidWaveSizeValue(unsigned Value) {
1695 return llvm::isPowerOf2_32(Value) && Value >= 4 && Value <= 128;
1696}
1697
1699 // validate that the wavesize argument is a power of 2 between 4 and 128
1700 // inclusive
1701 unsigned SpelledArgsCount = AL.getNumArgs();
1702 if (SpelledArgsCount == 0 || SpelledArgsCount > 3)
1703 return;
1704
1705 uint32_t Min;
1706 if (!SemaRef.checkUInt32Argument(AL, AL.getArgAsExpr(0), Min))
1707 return;
1708
1709 uint32_t Max = 0;
1710 if (SpelledArgsCount > 1 &&
1711 !SemaRef.checkUInt32Argument(AL, AL.getArgAsExpr(1), Max))
1712 return;
1713
1714 uint32_t Preferred = 0;
1715 if (SpelledArgsCount > 2 &&
1716 !SemaRef.checkUInt32Argument(AL, AL.getArgAsExpr(2), Preferred))
1717 return;
1718
1719 if (SpelledArgsCount > 2) {
1720 if (!isValidWaveSizeValue(Preferred)) {
1721 Diag(AL.getArgAsExpr(2)->getExprLoc(),
1722 diag::err_attribute_power_of_two_in_range)
1723 << AL << llvm::dxil::MinWaveSize << llvm::dxil::MaxWaveSize
1724 << Preferred;
1725 return;
1726 }
1727 // Preferred not in range.
1728 if (Preferred < Min || Preferred > Max) {
1729 Diag(AL.getArgAsExpr(2)->getExprLoc(),
1730 diag::err_attribute_power_of_two_in_range)
1731 << AL << Min << Max << Preferred;
1732 return;
1733 }
1734 } else if (SpelledArgsCount > 1) {
1735 if (!isValidWaveSizeValue(Max)) {
1736 Diag(AL.getArgAsExpr(1)->getExprLoc(),
1737 diag::err_attribute_power_of_two_in_range)
1738 << AL << llvm::dxil::MinWaveSize << llvm::dxil::MaxWaveSize << Max;
1739 return;
1740 }
1741 if (Max < Min) {
1742 Diag(AL.getLoc(), diag::err_attribute_argument_invalid) << AL << 1;
1743 return;
1744 } else if (Max == Min) {
1745 Diag(AL.getLoc(), diag::warn_attr_min_eq_max) << AL;
1746 }
1747 } else {
1748 if (!isValidWaveSizeValue(Min)) {
1749 Diag(AL.getArgAsExpr(0)->getExprLoc(),
1750 diag::err_attribute_power_of_two_in_range)
1751 << AL << llvm::dxil::MinWaveSize << llvm::dxil::MaxWaveSize << Min;
1752 return;
1753 }
1754 }
1755
1756 HLSLWaveSizeAttr *NewAttr =
1757 mergeWaveSizeAttr(D, AL, Min, Max, Preferred, SpelledArgsCount);
1758 if (NewAttr)
1759 D->addAttr(NewAttr);
1760}
1761
1763 uint32_t ID;
1764 if (!SemaRef.checkUInt32Argument(AL, AL.getArgAsExpr(0), ID))
1765 return;
1766 D->addAttr(::new (getASTContext())
1767 HLSLVkExtBuiltinInputAttr(getASTContext(), AL, ID));
1768}
1769
1771 D->addAttr(::new (getASTContext())
1772 HLSLVkPushConstantAttr(getASTContext(), AL));
1773}
1774
1776 uint32_t Id;
1777 if (!SemaRef.checkUInt32Argument(AL, AL.getArgAsExpr(0), Id))
1778 return;
1779 HLSLVkConstantIdAttr *NewAttr = mergeVkConstantIdAttr(D, AL, Id);
1780 if (NewAttr)
1781 D->addAttr(NewAttr);
1782}
1783
1785 uint32_t Binding = 0;
1786 if (!SemaRef.checkUInt32Argument(AL, AL.getArgAsExpr(0), Binding))
1787 return;
1788 uint32_t Set = 0;
1789 if (AL.getNumArgs() > 1 &&
1790 !SemaRef.checkUInt32Argument(AL, AL.getArgAsExpr(1), Set))
1791 return;
1792
1793 D->addAttr(::new (getASTContext())
1794 HLSLVkBindingAttr(getASTContext(), AL, Binding, Set));
1795}
1796
1798 uint32_t Location;
1799 if (!SemaRef.checkUInt32Argument(AL, AL.getArgAsExpr(0), Location))
1800 return;
1801
1802 D->addAttr(::new (getASTContext())
1803 HLSLVkLocationAttr(getASTContext(), AL, Location));
1804}
1805
1807 const auto *VT = T->getAs<VectorType>();
1808
1809 if (!T->hasUnsignedIntegerRepresentation() ||
1810 (VT && VT->getNumElements() > 3)) {
1811 Diag(AL.getLoc(), diag::err_hlsl_attr_invalid_type)
1812 << AL << "uint/uint2/uint3";
1813 return false;
1814 }
1815
1816 return true;
1817}
1818
1820 const auto *VT = T->getAs<VectorType>();
1821 if (!T->hasFloatingRepresentation() || (VT && VT->getNumElements() > 4)) {
1822 Diag(AL.getLoc(), diag::err_hlsl_attr_invalid_type)
1823 << AL << "float/float1/float2/float3/float4";
1824 return false;
1825 }
1826
1827 return true;
1828}
1829
1831 std::optional<unsigned> Index) {
1832 std::string SemanticName = AL.getAttrName()->getName().upper();
1833
1834 auto *VD = cast<ValueDecl>(D);
1835 QualType ValueType = VD->getType();
1836 if (auto *FD = dyn_cast<FunctionDecl>(D))
1837 ValueType = FD->getReturnType();
1838
1839 bool IsOutput = false;
1840 if (HLSLParamModifierAttr *MA = D->getAttr<HLSLParamModifierAttr>()) {
1841 if (MA->isOut()) {
1842 IsOutput = true;
1843 ValueType = cast<ReferenceType>(ValueType)->getPointeeType();
1844 }
1845 }
1846
1847 if (SemanticName == "SV_DISPATCHTHREADID") {
1848 diagnoseInputIDType(ValueType, AL);
1849 if (IsOutput)
1850 Diag(AL.getLoc(), diag::err_hlsl_semantic_output_not_supported) << AL;
1851 if (Index.has_value())
1852 Diag(AL.getLoc(), diag::err_hlsl_semantic_indexing_not_supported) << AL;
1854 return;
1855 }
1856
1857 if (SemanticName == "SV_GROUPINDEX") {
1858 if (IsOutput)
1859 Diag(AL.getLoc(), diag::err_hlsl_semantic_output_not_supported) << AL;
1860 if (Index.has_value())
1861 Diag(AL.getLoc(), diag::err_hlsl_semantic_indexing_not_supported) << AL;
1863 return;
1864 }
1865
1866 if (SemanticName == "SV_GROUPTHREADID") {
1867 diagnoseInputIDType(ValueType, AL);
1868 if (IsOutput)
1869 Diag(AL.getLoc(), diag::err_hlsl_semantic_output_not_supported) << AL;
1870 if (Index.has_value())
1871 Diag(AL.getLoc(), diag::err_hlsl_semantic_indexing_not_supported) << AL;
1873 return;
1874 }
1875
1876 if (SemanticName == "SV_GROUPID") {
1877 diagnoseInputIDType(ValueType, AL);
1878 if (IsOutput)
1879 Diag(AL.getLoc(), diag::err_hlsl_semantic_output_not_supported) << AL;
1880 if (Index.has_value())
1881 Diag(AL.getLoc(), diag::err_hlsl_semantic_indexing_not_supported) << AL;
1883 return;
1884 }
1885
1886 if (SemanticName == "SV_POSITION") {
1887 const auto *VT = ValueType->getAs<VectorType>();
1888 if (!ValueType->hasFloatingRepresentation() ||
1889 (VT && VT->getNumElements() > 4))
1890 Diag(AL.getLoc(), diag::err_hlsl_attr_invalid_type)
1891 << AL << "float/float1/float2/float3/float4";
1893 return;
1894 }
1895
1896 if (SemanticName == "SV_TARGET") {
1897 const auto *VT = ValueType->getAs<VectorType>();
1898 if (!ValueType->hasFloatingRepresentation() ||
1899 (VT && VT->getNumElements() > 4))
1900 Diag(AL.getLoc(), diag::err_hlsl_attr_invalid_type)
1901 << AL << "float/float1/float2/float3/float4";
1903 return;
1904 }
1905
1906 Diag(AL.getLoc(), diag::err_hlsl_unknown_semantic) << AL;
1907}
1908
1910 uint32_t IndexValue(0), ExplicitIndex(0);
1911 if (!SemaRef.checkUInt32Argument(AL, AL.getArgAsExpr(0), IndexValue) ||
1912 !SemaRef.checkUInt32Argument(AL, AL.getArgAsExpr(1), ExplicitIndex)) {
1913 assert(0 && "HLSLUnparsedSemantic is expected to have 2 int arguments.");
1914 }
1915 assert(IndexValue > 0 ? ExplicitIndex : true);
1916 std::optional<unsigned> Index =
1917 ExplicitIndex ? std::optional<unsigned>(IndexValue) : std::nullopt;
1918
1919 if (AL.getAttrName()->getName().starts_with_insensitive("SV_"))
1920 diagnoseSystemSemanticAttr(D, AL, Index);
1921 else
1923}
1924
1927 Diag(AL.getLoc(), diag::err_hlsl_attr_invalid_ast_node)
1928 << AL << "shader constant in a constant buffer";
1929 return;
1930 }
1931
1932 uint32_t SubComponent;
1933 if (!SemaRef.checkUInt32Argument(AL, AL.getArgAsExpr(0), SubComponent))
1934 return;
1935 uint32_t Component;
1936 if (!SemaRef.checkUInt32Argument(AL, AL.getArgAsExpr(1), Component))
1937 return;
1938
1939 QualType T = cast<VarDecl>(D)->getType().getCanonicalType();
1940 // Check if T is an array or struct type.
1941 // TODO: mark matrix type as aggregate type.
1942 bool IsAggregateTy = (T->isArrayType() || T->isStructureType());
1943
1944 // Check Component is valid for T.
1945 if (Component) {
1946 unsigned Size = getASTContext().getTypeSize(T);
1947 if (IsAggregateTy) {
1948 Diag(AL.getLoc(), diag::err_hlsl_invalid_register_or_packoffset);
1949 return;
1950 } else {
1951 // Make sure Component + sizeof(T) <= 4.
1952 if ((Component * 32 + Size) > 128) {
1953 Diag(AL.getLoc(), diag::err_hlsl_packoffset_cross_reg_boundary);
1954 return;
1955 }
1956 QualType EltTy = T;
1957 if (const auto *VT = T->getAs<VectorType>())
1958 EltTy = VT->getElementType();
1959 unsigned Align = getASTContext().getTypeAlign(EltTy);
1960 if (Align > 32 && Component == 1) {
1961 // NOTE: Component 3 will hit err_hlsl_packoffset_cross_reg_boundary.
1962 // So we only need to check Component 1 here.
1963 Diag(AL.getLoc(), diag::err_hlsl_packoffset_alignment_mismatch)
1964 << Align << EltTy;
1965 return;
1966 }
1967 }
1968 }
1969
1970 D->addAttr(::new (getASTContext()) HLSLPackOffsetAttr(
1971 getASTContext(), AL, SubComponent, Component));
1972}
1973
1975 StringRef Str;
1976 SourceLocation ArgLoc;
1977 if (!SemaRef.checkStringLiteralArgumentAttr(AL, 0, Str, &ArgLoc))
1978 return;
1979
1980 llvm::Triple::EnvironmentType ShaderType;
1981 if (!HLSLShaderAttr::ConvertStrToEnvironmentType(Str, ShaderType)) {
1982 Diag(AL.getLoc(), diag::warn_attribute_type_not_supported)
1983 << AL << Str << ArgLoc;
1984 return;
1985 }
1986
1987 // FIXME: check function match the shader stage.
1988
1989 HLSLShaderAttr *NewAttr = mergeShaderAttr(D, AL, ShaderType);
1990 if (NewAttr)
1991 D->addAttr(NewAttr);
1992}
1993
1995 Sema &S, QualType Wrapped, ArrayRef<const Attr *> AttrList,
1996 QualType &ResType, HLSLAttributedResourceLocInfo *LocInfo) {
1997 assert(AttrList.size() && "expected list of resource attributes");
1998
1999 QualType ContainedTy = QualType();
2000 TypeSourceInfo *ContainedTyInfo = nullptr;
2001 SourceLocation LocBegin = AttrList[0]->getRange().getBegin();
2002 SourceLocation LocEnd = AttrList[0]->getRange().getEnd();
2003
2004 HLSLAttributedResourceType::Attributes ResAttrs;
2005
2006 bool HasResourceClass = false;
2007 bool HasResourceDimension = false;
2008 for (const Attr *A : AttrList) {
2009 if (!A)
2010 continue;
2011 LocEnd = A->getRange().getEnd();
2012 switch (A->getKind()) {
2013 case attr::HLSLResourceClass: {
2014 ResourceClass RC = cast<HLSLResourceClassAttr>(A)->getResourceClass();
2015 if (HasResourceClass) {
2016 S.Diag(A->getLocation(), ResAttrs.ResourceClass == RC
2017 ? diag::warn_duplicate_attribute_exact
2018 : diag::warn_duplicate_attribute)
2019 << A;
2020 return false;
2021 }
2022 ResAttrs.ResourceClass = RC;
2023 HasResourceClass = true;
2024 break;
2025 }
2026 case attr::HLSLResourceDimension: {
2027 llvm::dxil::ResourceDimension RD =
2028 cast<HLSLResourceDimensionAttr>(A)->getDimension();
2029 if (HasResourceDimension) {
2030 S.Diag(A->getLocation(), ResAttrs.ResourceDimension == RD
2031 ? diag::warn_duplicate_attribute_exact
2032 : diag::warn_duplicate_attribute)
2033 << A;
2034 return false;
2035 }
2036 ResAttrs.ResourceDimension = RD;
2037 HasResourceDimension = true;
2038 break;
2039 }
2040 case attr::HLSLROV:
2041 if (ResAttrs.IsROV) {
2042 S.Diag(A->getLocation(), diag::warn_duplicate_attribute_exact) << A;
2043 return false;
2044 }
2045 ResAttrs.IsROV = true;
2046 break;
2047 case attr::HLSLRawBuffer:
2048 if (ResAttrs.RawBuffer) {
2049 S.Diag(A->getLocation(), diag::warn_duplicate_attribute_exact) << A;
2050 return false;
2051 }
2052 ResAttrs.RawBuffer = true;
2053 break;
2054 case attr::HLSLIsCounter:
2055 if (ResAttrs.IsCounter) {
2056 S.Diag(A->getLocation(), diag::warn_duplicate_attribute_exact) << A;
2057 return false;
2058 }
2059 ResAttrs.IsCounter = true;
2060 break;
2061 case attr::HLSLContainedType: {
2062 const HLSLContainedTypeAttr *CTAttr = cast<HLSLContainedTypeAttr>(A);
2063 QualType Ty = CTAttr->getType();
2064 if (!ContainedTy.isNull()) {
2065 S.Diag(A->getLocation(), ContainedTy == Ty
2066 ? diag::warn_duplicate_attribute_exact
2067 : diag::warn_duplicate_attribute)
2068 << A;
2069 return false;
2070 }
2071 ContainedTy = Ty;
2072 ContainedTyInfo = CTAttr->getTypeLoc();
2073 break;
2074 }
2075 default:
2076 llvm_unreachable("unhandled resource attribute type");
2077 }
2078 }
2079
2080 if (!HasResourceClass) {
2081 S.Diag(AttrList.back()->getRange().getEnd(),
2082 diag::err_hlsl_missing_resource_class);
2083 return false;
2084 }
2085
2087 Wrapped, ContainedTy, ResAttrs);
2088
2089 if (LocInfo && ContainedTyInfo) {
2090 LocInfo->Range = SourceRange(LocBegin, LocEnd);
2091 LocInfo->ContainedTyInfo = ContainedTyInfo;
2092 }
2093 return true;
2094}
2095
2096// Validates and creates an HLSL attribute that is applied as type attribute on
2097// HLSL resource. The attributes are collected in HLSLResourcesTypeAttrs and at
2098// the end of the declaration they are applied to the declaration type by
2099// wrapping it in HLSLAttributedResourceType.
2101 // only allow resource type attributes on intangible types
2102 if (!T->isHLSLResourceType()) {
2103 Diag(AL.getLoc(), diag::err_hlsl_attribute_needs_intangible_type)
2104 << AL << getASTContext().HLSLResourceTy;
2105 return false;
2106 }
2107
2108 // validate number of arguments
2109 if (!AL.checkExactlyNumArgs(SemaRef, AL.getMinArgs()))
2110 return false;
2111
2112 Attr *A = nullptr;
2113
2117 {
2118 AttributeCommonInfo::AS_CXX11, 0, false /*IsAlignas*/,
2119 false /*IsRegularKeywordAttribute*/
2120 });
2121
2122 switch (AL.getKind()) {
2123 case ParsedAttr::AT_HLSLResourceClass: {
2124 if (!AL.isArgIdent(0)) {
2125 Diag(AL.getLoc(), diag::err_attribute_argument_type)
2126 << AL << AANT_ArgumentIdentifier;
2127 return false;
2128 }
2129
2130 IdentifierLoc *Loc = AL.getArgAsIdent(0);
2131 StringRef Identifier = Loc->getIdentifierInfo()->getName();
2132 SourceLocation ArgLoc = Loc->getLoc();
2133
2134 // Validate resource class value
2135 ResourceClass RC;
2136 if (!HLSLResourceClassAttr::ConvertStrToResourceClass(Identifier, RC)) {
2137 Diag(ArgLoc, diag::warn_attribute_type_not_supported)
2138 << "ResourceClass" << Identifier;
2139 return false;
2140 }
2141 A = HLSLResourceClassAttr::Create(getASTContext(), RC, ACI);
2142 break;
2143 }
2144
2145 case ParsedAttr::AT_HLSLROV:
2146 A = HLSLROVAttr::Create(getASTContext(), ACI);
2147 break;
2148
2149 case ParsedAttr::AT_HLSLRawBuffer:
2150 A = HLSLRawBufferAttr::Create(getASTContext(), ACI);
2151 break;
2152
2153 case ParsedAttr::AT_HLSLIsCounter:
2154 A = HLSLIsCounterAttr::Create(getASTContext(), ACI);
2155 break;
2156
2157 case ParsedAttr::AT_HLSLContainedType: {
2158 if (AL.getNumArgs() != 1 && !AL.hasParsedType()) {
2159 Diag(AL.getLoc(), diag::err_attribute_wrong_number_arguments) << AL << 1;
2160 return false;
2161 }
2162
2163 TypeSourceInfo *TSI = nullptr;
2164 QualType QT = SemaRef.GetTypeFromParser(AL.getTypeArg(), &TSI);
2165 assert(TSI && "no type source info for attribute argument");
2166 if (SemaRef.RequireCompleteType(TSI->getTypeLoc().getBeginLoc(), QT,
2167 diag::err_incomplete_type))
2168 return false;
2169 A = HLSLContainedTypeAttr::Create(getASTContext(), TSI, ACI);
2170 break;
2171 }
2172
2173 default:
2174 llvm_unreachable("unhandled HLSL attribute");
2175 }
2176
2177 HLSLResourcesTypeAttrs.emplace_back(A);
2178 return true;
2179}
2180
2181// Combines all resource type attributes and creates HLSLAttributedResourceType.
2183 if (!HLSLResourcesTypeAttrs.size())
2184 return CurrentType;
2185
2186 QualType QT = CurrentType;
2189 HLSLResourcesTypeAttrs, QT, &LocInfo)) {
2190 const HLSLAttributedResourceType *RT =
2192
2193 // Temporarily store TypeLoc information for the new type.
2194 // It will be transferred to HLSLAttributesResourceTypeLoc
2195 // shortly after the type is created by TypeSpecLocFiller which
2196 // will call the TakeLocForHLSLAttribute method below.
2197 LocsForHLSLAttributedResources.insert(std::pair(RT, LocInfo));
2198 }
2199 HLSLResourcesTypeAttrs.clear();
2200 return QT;
2201}
2202
2203// Returns source location for the HLSLAttributedResourceType
2205SemaHLSL::TakeLocForHLSLAttribute(const HLSLAttributedResourceType *RT) {
2206 HLSLAttributedResourceLocInfo LocInfo = {};
2207 auto I = LocsForHLSLAttributedResources.find(RT);
2208 if (I != LocsForHLSLAttributedResources.end()) {
2209 LocInfo = I->second;
2210 LocsForHLSLAttributedResources.erase(I);
2211 return LocInfo;
2212 }
2213 LocInfo.Range = SourceRange();
2214 return LocInfo;
2215}
2216
2217// Walks though the global variable declaration, collects all resource binding
2218// requirements and adds them to Bindings
2219void SemaHLSL::collectResourceBindingsOnUserRecordDecl(const VarDecl *VD,
2220 const RecordType *RT) {
2221 const RecordDecl *RD = RT->getDecl()->getDefinitionOrSelf();
2222 for (FieldDecl *FD : RD->fields()) {
2223 const Type *Ty = FD->getType()->getUnqualifiedDesugaredType();
2224
2225 // Unwrap arrays
2226 // FIXME: Calculate array size while unwrapping
2227 assert(!Ty->isIncompleteArrayType() &&
2228 "incomplete arrays inside user defined types are not supported");
2229 while (Ty->isConstantArrayType()) {
2232 }
2233
2234 if (!Ty->isRecordType())
2235 continue;
2236
2237 if (const HLSLAttributedResourceType *AttrResType =
2238 HLSLAttributedResourceType::findHandleTypeOnResource(Ty)) {
2239 // Add a new DeclBindingInfo to Bindings if it does not already exist
2240 ResourceClass RC = AttrResType->getAttrs().ResourceClass;
2241 DeclBindingInfo *DBI = Bindings.getDeclBindingInfo(VD, RC);
2242 if (!DBI)
2243 Bindings.addDeclBindingInfo(VD, RC);
2244 } else if (const RecordType *RT = dyn_cast<RecordType>(Ty)) {
2245 // Recursively scan embedded struct or class; it would be nice to do this
2246 // without recursion, but tricky to correctly calculate the size of the
2247 // binding, which is something we are probably going to need to do later
2248 // on. Hopefully nesting of structs in structs too many levels is
2249 // unlikely.
2250 collectResourceBindingsOnUserRecordDecl(VD, RT);
2251 }
2252 }
2253}
2254
2255// Diagnose localized register binding errors for a single binding; does not
2256// diagnose resource binding on user record types, that will be done later
2257// in processResourceBindingOnDecl based on the information collected in
2258// collectResourceBindingsOnVarDecl.
2259// Returns false if the register binding is not valid.
2261 Decl *D, RegisterType RegType,
2262 bool SpecifiedSpace) {
2263 int RegTypeNum = static_cast<int>(RegType);
2264
2265 // check if the decl type is groupshared
2266 if (D->hasAttr<HLSLGroupSharedAddressSpaceAttr>()) {
2267 S.Diag(ArgLoc, diag::err_hlsl_binding_type_mismatch) << RegTypeNum;
2268 return false;
2269 }
2270
2271 // Cbuffers and Tbuffers are HLSLBufferDecl types
2272 if (HLSLBufferDecl *CBufferOrTBuffer = dyn_cast<HLSLBufferDecl>(D)) {
2273 ResourceClass RC = CBufferOrTBuffer->isCBuffer() ? ResourceClass::CBuffer
2274 : ResourceClass::SRV;
2275 if (RegType == getRegisterType(RC))
2276 return true;
2277
2278 S.Diag(D->getLocation(), diag::err_hlsl_binding_type_mismatch)
2279 << RegTypeNum;
2280 return false;
2281 }
2282
2283 // Samplers, UAVs, and SRVs are VarDecl types
2284 assert(isa<VarDecl>(D) && "D is expected to be VarDecl or HLSLBufferDecl");
2285 VarDecl *VD = cast<VarDecl>(D);
2286
2287 // Resource
2288 if (const HLSLAttributedResourceType *AttrResType =
2289 HLSLAttributedResourceType::findHandleTypeOnResource(
2290 VD->getType().getTypePtr())) {
2291 if (RegType == getRegisterType(AttrResType))
2292 return true;
2293
2294 S.Diag(D->getLocation(), diag::err_hlsl_binding_type_mismatch)
2295 << RegTypeNum;
2296 return false;
2297 }
2298
2299 const clang::Type *Ty = VD->getType().getTypePtr();
2300 while (Ty->isArrayType())
2302
2303 // Basic types
2304 if (Ty->isArithmeticType() || Ty->isVectorType()) {
2305 bool DeclaredInCOrTBuffer = isa<HLSLBufferDecl>(D->getDeclContext());
2306 if (SpecifiedSpace && !DeclaredInCOrTBuffer)
2307 S.Diag(ArgLoc, diag::err_hlsl_space_on_global_constant);
2308
2309 if (!DeclaredInCOrTBuffer && (Ty->isIntegralType(S.getASTContext()) ||
2310 Ty->isFloatingType() || Ty->isVectorType())) {
2311 // Register annotation on default constant buffer declaration ($Globals)
2312 if (RegType == RegisterType::CBuffer)
2313 S.Diag(ArgLoc, diag::warn_hlsl_deprecated_register_type_b);
2314 else if (RegType != RegisterType::C)
2315 S.Diag(ArgLoc, diag::err_hlsl_binding_type_mismatch) << RegTypeNum;
2316 else
2317 return true;
2318 } else {
2319 if (RegType == RegisterType::C)
2320 S.Diag(ArgLoc, diag::warn_hlsl_register_type_c_packoffset);
2321 else
2322 S.Diag(ArgLoc, diag::err_hlsl_binding_type_mismatch) << RegTypeNum;
2323 }
2324 return false;
2325 }
2326 if (Ty->isRecordType())
2327 // RecordTypes will be diagnosed in processResourceBindingOnDecl
2328 // that is called from ActOnVariableDeclarator
2329 return true;
2330
2331 // Anything else is an error
2332 S.Diag(ArgLoc, diag::err_hlsl_binding_type_mismatch) << RegTypeNum;
2333 return false;
2334}
2335
2337 RegisterType regType) {
2338 // make sure that there are no two register annotations
2339 // applied to the decl with the same register type
2340 bool RegisterTypesDetected[5] = {false};
2341 RegisterTypesDetected[static_cast<int>(regType)] = true;
2342
2343 for (auto it = TheDecl->attr_begin(); it != TheDecl->attr_end(); ++it) {
2344 if (HLSLResourceBindingAttr *attr =
2345 dyn_cast<HLSLResourceBindingAttr>(*it)) {
2346
2347 RegisterType otherRegType = attr->getRegisterType();
2348 if (RegisterTypesDetected[static_cast<int>(otherRegType)]) {
2349 int otherRegTypeNum = static_cast<int>(otherRegType);
2350 S.Diag(TheDecl->getLocation(),
2351 diag::err_hlsl_duplicate_register_annotation)
2352 << otherRegTypeNum;
2353 return false;
2354 }
2355 RegisterTypesDetected[static_cast<int>(otherRegType)] = true;
2356 }
2357 }
2358 return true;
2359}
2360
2362 Decl *D, RegisterType RegType,
2363 bool SpecifiedSpace) {
2364
2365 // exactly one of these two types should be set
2366 assert(((isa<VarDecl>(D) && !isa<HLSLBufferDecl>(D)) ||
2367 (!isa<VarDecl>(D) && isa<HLSLBufferDecl>(D))) &&
2368 "expecting VarDecl or HLSLBufferDecl");
2369
2370 // check if the declaration contains resource matching the register type
2371 if (!DiagnoseLocalRegisterBinding(S, ArgLoc, D, RegType, SpecifiedSpace))
2372 return false;
2373
2374 // next, if multiple register annotations exist, check that none conflict.
2375 return ValidateMultipleRegisterAnnotations(S, D, RegType);
2376}
2377
2378// return false if the slot count exceeds the limit, true otherwise
2379static bool AccumulateHLSLResourceSlots(QualType Ty, uint64_t &StartSlot,
2380 const uint64_t &Limit,
2381 const ResourceClass ResClass,
2382 ASTContext &Ctx,
2383 uint64_t ArrayCount = 1) {
2384 Ty = Ty.getCanonicalType();
2385 const Type *T = Ty.getTypePtr();
2386
2387 // Early exit if already overflowed
2388 if (StartSlot > Limit)
2389 return false;
2390
2391 // Case 1: array type
2392 if (const auto *AT = dyn_cast<ArrayType>(T)) {
2393 uint64_t Count = 1;
2394
2395 if (const auto *CAT = dyn_cast<ConstantArrayType>(AT))
2396 Count = CAT->getSize().getZExtValue();
2397
2398 QualType ElemTy = AT->getElementType();
2399 return AccumulateHLSLResourceSlots(ElemTy, StartSlot, Limit, ResClass, Ctx,
2400 ArrayCount * Count);
2401 }
2402
2403 // Case 2: resource leaf
2404 if (auto ResTy = dyn_cast<HLSLAttributedResourceType>(T)) {
2405 // First ensure this resource counts towards the corresponding
2406 // register type limit.
2407 if (ResTy->getAttrs().ResourceClass != ResClass)
2408 return true;
2409
2410 // Validate highest slot used
2411 uint64_t EndSlot = StartSlot + ArrayCount - 1;
2412 if (EndSlot > Limit)
2413 return false;
2414
2415 // Advance SlotCount past the consumed range
2416 StartSlot = EndSlot + 1;
2417 return true;
2418 }
2419
2420 // Case 3: struct / record
2421 if (const auto *RT = dyn_cast<RecordType>(T)) {
2422 const RecordDecl *RD = RT->getDecl();
2423
2424 if (const auto *CXXRD = dyn_cast<CXXRecordDecl>(RD)) {
2425 for (const CXXBaseSpecifier &Base : CXXRD->bases()) {
2426 if (!AccumulateHLSLResourceSlots(Base.getType(), StartSlot, Limit,
2427 ResClass, Ctx, ArrayCount))
2428 return false;
2429 }
2430 }
2431
2432 for (const FieldDecl *Field : RD->fields()) {
2433 if (!AccumulateHLSLResourceSlots(Field->getType(), StartSlot, Limit,
2434 ResClass, Ctx, ArrayCount))
2435 return false;
2436 }
2437
2438 return true;
2439 }
2440
2441 // Case 4: everything else
2442 return true;
2443}
2444
2445// return true if there is something invalid, false otherwise
2446static bool ValidateRegisterNumber(uint64_t SlotNum, Decl *TheDecl,
2447 ASTContext &Ctx, RegisterType RegTy) {
2448 const uint64_t Limit = UINT32_MAX;
2449 if (SlotNum > Limit)
2450 return true;
2451
2452 // after verifying the number doesn't exceed uint32max, we don't need
2453 // to look further into c or i register types
2454 if (RegTy == RegisterType::C || RegTy == RegisterType::I)
2455 return false;
2456
2457 if (VarDecl *VD = dyn_cast<VarDecl>(TheDecl)) {
2458 uint64_t BaseSlot = SlotNum;
2459
2460 if (!AccumulateHLSLResourceSlots(VD->getType(), SlotNum, Limit,
2461 getResourceClass(RegTy), Ctx))
2462 return true;
2463
2464 // After AccumulateHLSLResourceSlots runs, SlotNum is now
2465 // the first free slot; last used was SlotNum - 1
2466 return (BaseSlot > Limit);
2467 }
2468 // handle the cbuffer/tbuffer case
2469 if (isa<HLSLBufferDecl>(TheDecl))
2470 // resources cannot be put within a cbuffer, so no need
2471 // to analyze the structure since the register number
2472 // won't be pushed any higher.
2473 return (SlotNum > Limit);
2474
2475 // we don't expect any other decl type, so fail
2476 llvm_unreachable("unexpected decl type");
2477}
2478
2480 if (VarDecl *VD = dyn_cast<VarDecl>(TheDecl)) {
2481 QualType Ty = VD->getType();
2482 if (const auto *IAT = dyn_cast<IncompleteArrayType>(Ty))
2483 Ty = IAT->getElementType();
2484 if (SemaRef.RequireCompleteType(TheDecl->getBeginLoc(), Ty,
2485 diag::err_incomplete_type))
2486 return;
2487 }
2488
2489 StringRef Slot = "";
2490 StringRef Space = "";
2491 SourceLocation SlotLoc, SpaceLoc;
2492
2493 if (!AL.isArgIdent(0)) {
2494 Diag(AL.getLoc(), diag::err_attribute_argument_type)
2495 << AL << AANT_ArgumentIdentifier;
2496 return;
2497 }
2498 IdentifierLoc *Loc = AL.getArgAsIdent(0);
2499
2500 if (AL.getNumArgs() == 2) {
2501 Slot = Loc->getIdentifierInfo()->getName();
2502 SlotLoc = Loc->getLoc();
2503 if (!AL.isArgIdent(1)) {
2504 Diag(AL.getLoc(), diag::err_attribute_argument_type)
2505 << AL << AANT_ArgumentIdentifier;
2506 return;
2507 }
2508 Loc = AL.getArgAsIdent(1);
2509 Space = Loc->getIdentifierInfo()->getName();
2510 SpaceLoc = Loc->getLoc();
2511 } else {
2512 StringRef Str = Loc->getIdentifierInfo()->getName();
2513 if (Str.starts_with("space")) {
2514 Space = Str;
2515 SpaceLoc = Loc->getLoc();
2516 } else {
2517 Slot = Str;
2518 SlotLoc = Loc->getLoc();
2519 Space = "space0";
2520 }
2521 }
2522
2523 RegisterType RegType = RegisterType::SRV;
2524 std::optional<unsigned> SlotNum;
2525 unsigned SpaceNum = 0;
2526
2527 // Validate slot
2528 if (!Slot.empty()) {
2529 if (!convertToRegisterType(Slot, &RegType)) {
2530 Diag(SlotLoc, diag::err_hlsl_binding_type_invalid) << Slot.substr(0, 1);
2531 return;
2532 }
2533 if (RegType == RegisterType::I) {
2534 Diag(SlotLoc, diag::warn_hlsl_deprecated_register_type_i);
2535 return;
2536 }
2537 const StringRef SlotNumStr = Slot.substr(1);
2538
2539 uint64_t N;
2540
2541 // validate that the slot number is a non-empty number
2542 if (SlotNumStr.getAsInteger(10, N)) {
2543 Diag(SlotLoc, diag::err_hlsl_unsupported_register_number);
2544 return;
2545 }
2546
2547 // Validate register number. It should not exceed UINT32_MAX,
2548 // including if the resource type is an array that starts
2549 // before UINT32_MAX, but ends afterwards.
2550 if (ValidateRegisterNumber(N, TheDecl, getASTContext(), RegType)) {
2551 Diag(SlotLoc, diag::err_hlsl_register_number_too_large);
2552 return;
2553 }
2554
2555 // the slot number has been validated and does not exceed UINT32_MAX
2556 SlotNum = (unsigned)N;
2557 }
2558
2559 // Validate space
2560 if (!Space.starts_with("space")) {
2561 Diag(SpaceLoc, diag::err_hlsl_expected_space) << Space;
2562 return;
2563 }
2564 StringRef SpaceNumStr = Space.substr(5);
2565 if (SpaceNumStr.getAsInteger(10, SpaceNum)) {
2566 Diag(SpaceLoc, diag::err_hlsl_expected_space) << Space;
2567 return;
2568 }
2569
2570 // If we have slot, diagnose it is the right register type for the decl
2571 if (SlotNum.has_value())
2572 if (!DiagnoseHLSLRegisterAttribute(SemaRef, SlotLoc, TheDecl, RegType,
2573 !SpaceLoc.isInvalid()))
2574 return;
2575
2576 HLSLResourceBindingAttr *NewAttr =
2577 HLSLResourceBindingAttr::Create(getASTContext(), Slot, Space, AL);
2578 if (NewAttr) {
2579 NewAttr->setBinding(RegType, SlotNum, SpaceNum);
2580 TheDecl->addAttr(NewAttr);
2581 }
2582}
2583
2585 HLSLParamModifierAttr *NewAttr = mergeParamModifierAttr(
2586 D, AL,
2587 static_cast<HLSLParamModifierAttr::Spelling>(AL.getSemanticSpelling()));
2588 if (NewAttr)
2589 D->addAttr(NewAttr);
2590}
2591
2592namespace {
2593
2594/// This class implements HLSL availability diagnostics for default
2595/// and relaxed mode
2596///
2597/// The goal of this diagnostic is to emit an error or warning when an
2598/// unavailable API is found in code that is reachable from the shader
2599/// entry function or from an exported function (when compiling a shader
2600/// library).
2601///
2602/// This is done by traversing the AST of all shader entry point functions
2603/// and of all exported functions, and any functions that are referenced
2604/// from this AST. In other words, any functions that are reachable from
2605/// the entry points.
2606class DiagnoseHLSLAvailability : public DynamicRecursiveASTVisitor {
2607 Sema &SemaRef;
2608
2609 // Stack of functions to be scaned
2611
2612 // Tracks which environments functions have been scanned in.
2613 //
2614 // Maps FunctionDecl to an unsigned number that represents the set of shader
2615 // environments the function has been scanned for.
2616 // The llvm::Triple::EnvironmentType enum values for shader stages guaranteed
2617 // to be numbered from llvm::Triple::Pixel to llvm::Triple::Amplification
2618 // (verified by static_asserts in Triple.cpp), we can use it to index
2619 // individual bits in the set, as long as we shift the values to start with 0
2620 // by subtracting the value of llvm::Triple::Pixel first.
2621 //
2622 // The N'th bit in the set will be set if the function has been scanned
2623 // in shader environment whose llvm::Triple::EnvironmentType integer value
2624 // equals (llvm::Triple::Pixel + N).
2625 //
2626 // For example, if a function has been scanned in compute and pixel stage
2627 // environment, the value will be 0x21 (100001 binary) because:
2628 //
2629 // (int)(llvm::Triple::Pixel - llvm::Triple::Pixel) == 0
2630 // (int)(llvm::Triple::Compute - llvm::Triple::Pixel) == 5
2631 //
2632 // A FunctionDecl is mapped to 0 (or not included in the map) if it has not
2633 // been scanned in any environment.
2634 llvm::DenseMap<const FunctionDecl *, unsigned> ScannedDecls;
2635
2636 // Do not access these directly, use the get/set methods below to make
2637 // sure the values are in sync
2638 llvm::Triple::EnvironmentType CurrentShaderEnvironment;
2639 unsigned CurrentShaderStageBit;
2640
2641 // True if scanning a function that was already scanned in a different
2642 // shader stage context, and therefore we should not report issues that
2643 // depend only on shader model version because they would be duplicate.
2644 bool ReportOnlyShaderStageIssues;
2645
2646 // Helper methods for dealing with current stage context / environment
2647 void SetShaderStageContext(llvm::Triple::EnvironmentType ShaderType) {
2648 static_assert(sizeof(unsigned) >= 4);
2649 assert(HLSLShaderAttr::isValidShaderType(ShaderType));
2650 assert((unsigned)(ShaderType - llvm::Triple::Pixel) < 31 &&
2651 "ShaderType is too big for this bitmap"); // 31 is reserved for
2652 // "unknown"
2653
2654 unsigned bitmapIndex = ShaderType - llvm::Triple::Pixel;
2655 CurrentShaderEnvironment = ShaderType;
2656 CurrentShaderStageBit = (1 << bitmapIndex);
2657 }
2658
2659 void SetUnknownShaderStageContext() {
2660 CurrentShaderEnvironment = llvm::Triple::UnknownEnvironment;
2661 CurrentShaderStageBit = (1 << 31);
2662 }
2663
2664 llvm::Triple::EnvironmentType GetCurrentShaderEnvironment() const {
2665 return CurrentShaderEnvironment;
2666 }
2667
2668 bool InUnknownShaderStageContext() const {
2669 return CurrentShaderEnvironment == llvm::Triple::UnknownEnvironment;
2670 }
2671
2672 // Helper methods for dealing with shader stage bitmap
2673 void AddToScannedFunctions(const FunctionDecl *FD) {
2674 unsigned &ScannedStages = ScannedDecls[FD];
2675 ScannedStages |= CurrentShaderStageBit;
2676 }
2677
2678 unsigned GetScannedStages(const FunctionDecl *FD) { return ScannedDecls[FD]; }
2679
2680 bool WasAlreadyScannedInCurrentStage(const FunctionDecl *FD) {
2681 return WasAlreadyScannedInCurrentStage(GetScannedStages(FD));
2682 }
2683
2684 bool WasAlreadyScannedInCurrentStage(unsigned ScannerStages) {
2685 return ScannerStages & CurrentShaderStageBit;
2686 }
2687
2688 static bool NeverBeenScanned(unsigned ScannedStages) {
2689 return ScannedStages == 0;
2690 }
2691
2692 // Scanning methods
2693 void HandleFunctionOrMethodRef(FunctionDecl *FD, Expr *RefExpr);
2694 void CheckDeclAvailability(NamedDecl *D, const AvailabilityAttr *AA,
2695 SourceRange Range);
2696 const AvailabilityAttr *FindAvailabilityAttr(const Decl *D);
2697 bool HasMatchingEnvironmentOrNone(const AvailabilityAttr *AA);
2698
2699public:
2700 DiagnoseHLSLAvailability(Sema &SemaRef)
2701 : SemaRef(SemaRef),
2702 CurrentShaderEnvironment(llvm::Triple::UnknownEnvironment),
2703 CurrentShaderStageBit(0), ReportOnlyShaderStageIssues(false) {}
2704
2705 // AST traversal methods
2706 void RunOnTranslationUnit(const TranslationUnitDecl *TU);
2707 void RunOnFunction(const FunctionDecl *FD);
2708
2709 bool VisitDeclRefExpr(DeclRefExpr *DRE) override {
2710 FunctionDecl *FD = llvm::dyn_cast<FunctionDecl>(DRE->getDecl());
2711 if (FD)
2712 HandleFunctionOrMethodRef(FD, DRE);
2713 return true;
2714 }
2715
2716 bool VisitMemberExpr(MemberExpr *ME) override {
2717 FunctionDecl *FD = llvm::dyn_cast<FunctionDecl>(ME->getMemberDecl());
2718 if (FD)
2719 HandleFunctionOrMethodRef(FD, ME);
2720 return true;
2721 }
2722};
2723
2724void DiagnoseHLSLAvailability::HandleFunctionOrMethodRef(FunctionDecl *FD,
2725 Expr *RefExpr) {
2726 assert((isa<DeclRefExpr>(RefExpr) || isa<MemberExpr>(RefExpr)) &&
2727 "expected DeclRefExpr or MemberExpr");
2728
2729 // has a definition -> add to stack to be scanned
2730 const FunctionDecl *FDWithBody = nullptr;
2731 if (FD->hasBody(FDWithBody)) {
2732 if (!WasAlreadyScannedInCurrentStage(FDWithBody))
2733 DeclsToScan.push_back(FDWithBody);
2734 return;
2735 }
2736
2737 // no body -> diagnose availability
2738 const AvailabilityAttr *AA = FindAvailabilityAttr(FD);
2739 if (AA)
2740 CheckDeclAvailability(
2741 FD, AA, SourceRange(RefExpr->getBeginLoc(), RefExpr->getEndLoc()));
2742}
2743
2744void DiagnoseHLSLAvailability::RunOnTranslationUnit(
2745 const TranslationUnitDecl *TU) {
2746
2747 // Iterate over all shader entry functions and library exports, and for those
2748 // that have a body (definiton), run diag scan on each, setting appropriate
2749 // shader environment context based on whether it is a shader entry function
2750 // or an exported function. Exported functions can be in namespaces and in
2751 // export declarations so we need to scan those declaration contexts as well.
2753 DeclContextsToScan.push_back(TU);
2754
2755 while (!DeclContextsToScan.empty()) {
2756 const DeclContext *DC = DeclContextsToScan.pop_back_val();
2757 for (auto &D : DC->decls()) {
2758 // do not scan implicit declaration generated by the implementation
2759 if (D->isImplicit())
2760 continue;
2761
2762 // for namespace or export declaration add the context to the list to be
2763 // scanned later
2764 if (llvm::dyn_cast<NamespaceDecl>(D) || llvm::dyn_cast<ExportDecl>(D)) {
2765 DeclContextsToScan.push_back(llvm::dyn_cast<DeclContext>(D));
2766 continue;
2767 }
2768
2769 // skip over other decls or function decls without body
2770 const FunctionDecl *FD = llvm::dyn_cast<FunctionDecl>(D);
2771 if (!FD || !FD->isThisDeclarationADefinition())
2772 continue;
2773
2774 // shader entry point
2775 if (HLSLShaderAttr *ShaderAttr = FD->getAttr<HLSLShaderAttr>()) {
2776 SetShaderStageContext(ShaderAttr->getType());
2777 RunOnFunction(FD);
2778 continue;
2779 }
2780 // exported library function
2781 // FIXME: replace this loop with external linkage check once issue #92071
2782 // is resolved
2783 bool isExport = FD->isInExportDeclContext();
2784 if (!isExport) {
2785 for (const auto *Redecl : FD->redecls()) {
2786 if (Redecl->isInExportDeclContext()) {
2787 isExport = true;
2788 break;
2789 }
2790 }
2791 }
2792 if (isExport) {
2793 SetUnknownShaderStageContext();
2794 RunOnFunction(FD);
2795 continue;
2796 }
2797 }
2798 }
2799}
2800
2801void DiagnoseHLSLAvailability::RunOnFunction(const FunctionDecl *FD) {
2802 assert(DeclsToScan.empty() && "DeclsToScan should be empty");
2803 DeclsToScan.push_back(FD);
2804
2805 while (!DeclsToScan.empty()) {
2806 // Take one decl from the stack and check it by traversing its AST.
2807 // For any CallExpr found during the traversal add it's callee to the top of
2808 // the stack to be processed next. Functions already processed are stored in
2809 // ScannedDecls.
2810 const FunctionDecl *FD = DeclsToScan.pop_back_val();
2811
2812 // Decl was already scanned
2813 const unsigned ScannedStages = GetScannedStages(FD);
2814 if (WasAlreadyScannedInCurrentStage(ScannedStages))
2815 continue;
2816
2817 ReportOnlyShaderStageIssues = !NeverBeenScanned(ScannedStages);
2818
2819 AddToScannedFunctions(FD);
2820 TraverseStmt(FD->getBody());
2821 }
2822}
2823
2824bool DiagnoseHLSLAvailability::HasMatchingEnvironmentOrNone(
2825 const AvailabilityAttr *AA) {
2826 const IdentifierInfo *IIEnvironment = AA->getEnvironment();
2827 if (!IIEnvironment)
2828 return true;
2829
2830 llvm::Triple::EnvironmentType CurrentEnv = GetCurrentShaderEnvironment();
2831 if (CurrentEnv == llvm::Triple::UnknownEnvironment)
2832 return false;
2833
2834 llvm::Triple::EnvironmentType AttrEnv =
2835 AvailabilityAttr::getEnvironmentType(IIEnvironment->getName());
2836
2837 return CurrentEnv == AttrEnv;
2838}
2839
// Returns the AvailabilityAttr on D whose platform name equals the target
// platform name. An attribute that also matches the current shader
// environment (or names no environment) is returned immediately. Otherwise
// the last attribute that matched only the platform is returned, or nullptr
// if no attribute matches the platform.
// NOTE(review): the right-hand side of the TargetPlatform initializer
// (original line 2850) is missing from this listing.
2840const AvailabilityAttr *
2841DiagnoseHLSLAvailability::FindAvailabilityAttr(const Decl *D) {
2842 AvailabilityAttr const *PartialMatch = nullptr;
2843 // Check each AvailabilityAttr to find the one for this platform.
2844 // For multiple attributes with the same platform try to find one for this
2845 // environment.
2846 for (const auto *A : D->attrs()) {
2847 if (const auto *Avail = dyn_cast<AvailabilityAttr>(A)) {
2848 StringRef AttrPlatform = Avail->getPlatform()->getName();
2849 StringRef TargetPlatform =
2851
2852 // Match the platform name.
2853 if (AttrPlatform == TargetPlatform) {
2854 // Find the best matching attribute for this environment
2855 if (HasMatchingEnvironmentOrNone(Avail))
2856 return Avail;
// Platform matches but environment does not: remember it as a fallback.
2857 PartialMatch = Avail;
2858 }
2859 }
2860 }
2861 return PartialMatch;
2862}
2863
2864// Check availability against target shader model version and current shader
2865// stage and emit diagnostic
// Attributes without an environment depend only on the version. They are
// skipped in strict mode (-fhlsl-strict-availability) and when re-scanning a
// function already scanned in another stage. Attributes with an environment
// are skipped while the shader stage context is unknown. If D is not
// available, emits warn_hlsl_availability (environment matches, version too
// low) or warn_hlsl_availability_unavailable (environment mismatch), followed
// by note_partial_availability_specified_here at D's location.
2866void DiagnoseHLSLAvailability::CheckDeclAvailability(NamedDecl *D,
2867 const AvailabilityAttr *AA,
2868 SourceRange Range) {
2869
2870 const IdentifierInfo *IIEnv = AA->getEnvironment();
2871
2872 if (!IIEnv) {
2873 // The availability attribute does not have environment -> it depends only
2874 // on shader model version and not on the specific shader stage.
2875
2876 // Skip emitting the diagnostics if the diagnostic mode is set to
2877 // strict (-fhlsl-strict-availability) because all relevant diagnostics
2878 // were already emitted in the DiagnoseUnguardedAvailability scan
2879 // (SemaAvailability.cpp).
2880 if (SemaRef.getLangOpts().HLSLStrictAvailability)
2881 return;
2882
2883 // Do not report shader-stage-independent issues if scanning a function
2884 // that was already scanned in a different shader stage context (they would
2885 // be duplicates)
2886 if (ReportOnlyShaderStageIssues)
2887 return;
2888
2889 } else {
2890 // The availability attribute has environment -> we need to know
2891 // the current stage context to properly diagnose it.
2892 if (InUnknownShaderStageContext())
2893 return;
2894 }
2895
2896 // Check introduced version and if environment matches
2897 bool EnvironmentMatches = HasMatchingEnvironmentOrNone(AA);
2898 VersionTuple Introduced = AA->getIntroduced();
// NOTE(review): the initializer of TargetVersion (original line 2900) is
// missing from this listing; presumably the target shader model version.
2899 VersionTuple TargetVersion =
2901
2902 if (TargetVersion >= Introduced && EnvironmentMatches)
2903 return;
2904
2905 // Emit diagnostic message
2906 const TargetInfo &TI = SemaRef.getASTContext().getTargetInfo();
2907 llvm::StringRef PlatformName(
2908 AvailabilityAttr::getPrettyPlatformName(TI.getPlatformName()));
2909
2910 llvm::StringRef CurrentEnvStr =
2911 llvm::Triple::getEnvironmentTypeName(GetCurrentShaderEnvironment());
2912
2913 llvm::StringRef AttrEnvStr =
2914 AA->getEnvironment() ? AA->getEnvironment()->getName() : "";
2915 bool UseEnvironment = !AttrEnvStr.empty();
2916
// Reaching here with a matching environment means only the version is too
// low; otherwise D is unavailable in the current shader environment.
2917 if (EnvironmentMatches) {
2918 SemaRef.Diag(Range.getBegin(), diag::warn_hlsl_availability)
2919 << Range << D << PlatformName << Introduced.getAsString()
2920 << UseEnvironment << CurrentEnvStr;
2921 } else {
2922 SemaRef.Diag(Range.getBegin(), diag::warn_hlsl_availability_unavailable)
2923 << Range << D;
2924 }
2925
2926 SemaRef.Diag(D->getLocation(), diag::note_partial_availability_specified_here)
2927 << D << PlatformName << Introduced.getAsString()
2928 << SemaRef.Context.getTargetInfo().getPlatformMinVersion().getAsString()
2929 << UseEnvironment << AttrEnvStr << CurrentEnvStr;
2930}
2931
2932} // namespace
2933
2935 // process default CBuffer - create buffer layout struct and invoke codegenCGH
2936 if (!DefaultCBufferDecls.empty()) {
2938 SemaRef.getASTContext(), SemaRef.getCurLexicalContext(),
2939 DefaultCBufferDecls);
2940 addImplicitBindingAttrToDecl(SemaRef, DefaultCBuffer, RegisterType::CBuffer,
2941 getNextImplicitBindingOrderID());
2942 SemaRef.getCurLexicalContext()->addDecl(DefaultCBuffer);
2944
2945 // Set HasValidPackoffset if any of the decls has a register(c#) annotation;
2946 for (const Decl *VD : DefaultCBufferDecls) {
2947 const HLSLResourceBindingAttr *RBA =
2948 VD->getAttr<HLSLResourceBindingAttr>();
2949 if (RBA && RBA->hasRegisterSlot() &&
2950 RBA->getRegisterType() == HLSLResourceBindingAttr::RegisterType::C) {
2951 DefaultCBuffer->setHasValidPackoffset(true);
2952 break;
2953 }
2954 }
2955
2956 DeclGroupRef DG(DefaultCBuffer);
2957 SemaRef.Consumer.HandleTopLevelDecl(DG);
2958 }
2959 diagnoseAvailabilityViolations(TU);
2960}
2961
// Runs the DiagnoseHLSLAvailability scan over the whole translation unit TU.
// The scan is skipped when strict availability is enabled and the target
// triple's environment is not Library.
2962void SemaHLSL::diagnoseAvailabilityViolations(TranslationUnitDecl *TU) {
2963 // Skip running the diagnostics scan if the diagnostic mode is
2964 // strict (-fhlsl-strict-availability) and the target shader stage is known
2965 // because all relevant diagnostics were already emitted in the
2966 // DiagnoseUnguardedAvailability scan (SemaAvailability.cpp).
// NOTE(review): the declaration of TI (original line 2967) is missing from
// this listing; presumably the ASTContext's TargetInfo — confirm.
2968 if (SemaRef.getLangOpts().HLSLStrictAvailability &&
2969 TI.getTriple().getEnvironment() != llvm::Triple::EnvironmentType::Library)
2970 return;
2971
2972 DiagnoseHLSLAvailability(SemaRef).RunOnTranslationUnit(TU);
2973}
2974
// Returns true, after emitting err_vec_builtin_incompatible_vector over the
// range of all arguments, if some argument's type fails the comparison
// against the first argument's type. Returns false otherwise. Requires at
// least two arguments (asserted).
// NOTE(review): the opening of the comparison (original line 2980) is missing
// from this listing; presumably an unqualified-type equality test.
2975static bool CheckAllArgsHaveSameType(Sema *S, CallExpr *TheCall) {
2976 assert(TheCall->getNumArgs() > 1);
2977 QualType ArgTy0 = TheCall->getArg(0)->getType();
2978
2979 for (unsigned I = 1, N = TheCall->getNumArgs(); I < N; ++I) {
2981 ArgTy0, TheCall->getArg(I)->getType())) {
2982 S->Diag(TheCall->getBeginLoc(), diag::err_vec_builtin_incompatible_vector)
2983 << TheCall->getDirectCallee() << /*useAllTerminology*/ true
2984 << SourceRange(TheCall->getArg(0)->getBeginLoc(),
2985 TheCall->getArg(N - 1)->getEndLoc());
2986 return true;
2987 }
2988 }
2989 return false;
2990}
2991
2993 QualType ArgType = Arg->getType();
2995 S->Diag(Arg->getBeginLoc(), diag::err_typecheck_convert_incompatible)
2996 << ArgType << ExpectedType << 1 << 0 << 0;
2997 return true;
2998 }
2999 return false;
3000}
3001
3003 Sema *S, CallExpr *TheCall,
3004 llvm::function_ref<bool(Sema *S, SourceLocation Loc, int ArgOrdinal,
3005 clang::QualType PassedType)>
3006 Check) {
3007 for (unsigned I = 0; I < TheCall->getNumArgs(); ++I) {
3008 Expr *Arg = TheCall->getArg(I);
3009 if (Check(S, Arg->getBeginLoc(), I + 1, Arg->getType()))
3010 return true;
3011 }
3012 return false;
3013}
3014
3016 int ArgOrdinal,
3017 clang::QualType PassedType) {
3018 clang::QualType BaseType =
3019 PassedType->isVectorType()
3020 ? PassedType->castAs<clang::VectorType>()->getElementType()
3021 : PassedType;
3022 if (!BaseType->isFloat32Type())
3023 return S->Diag(Loc, diag::err_builtin_invalid_arg_type)
3024 << ArgOrdinal << /* scalar or vector of */ 5 << /* no int */ 0
3025 << /* float */ 1 << PassedType;
3026 return false;
3027}
3028
3030 int ArgOrdinal,
3031 clang::QualType PassedType) {
3032 clang::QualType BaseType =
3033 PassedType->isVectorType()
3034 ? PassedType->castAs<clang::VectorType>()->getElementType()
3035 : PassedType;
3036 if (!BaseType->isHalfType() && !BaseType->isFloat32Type())
3037 return S->Diag(Loc, diag::err_builtin_invalid_arg_type)
3038 << ArgOrdinal << /* scalar or vector of */ 5 << /* no int */ 0
3039 << /* half or float */ 2 << PassedType;
3040 return false;
3041}
3042
// Returns true, after emitting error_hlsl_inout_lvalue, unless argument
// ArgIndex (with casts stripped) is a modifiable lvalue.
// NOTE(review): the right-hand side of the comparison (original line 3048) is
// missing from this listing; presumably Expr::MLV_Valid.
3043static bool CheckModifiableLValue(Sema *S, CallExpr *TheCall,
3044 unsigned ArgIndex) {
3045 auto *Arg = TheCall->getArg(ArgIndex);
// OrigLoc is passed by pointer to isModifiableLvalue, which may update it; the
// diagnostic is emitted at whatever location results.
3046 SourceLocation OrigLoc = Arg->getExprLoc();
3047 if (Arg->IgnoreCasts()->isModifiableLvalue(S->Context, &OrigLoc) ==
3049 return false;
3050 S->Diag(OrigLoc, diag::error_hlsl_inout_lvalue) << Arg << 0;
3051 return true;
3052}
3053
3054static bool CheckNoDoubleVectors(Sema *S, SourceLocation Loc, int ArgOrdinal,
3055 clang::QualType PassedType) {
3056 const auto *VecTy = PassedType->getAs<VectorType>();
3057 if (!VecTy)
3058 return false;
3059
3060 if (VecTy->getElementType()->isDoubleType())
3061 return S->Diag(Loc, diag::err_builtin_invalid_arg_type)
3062 << ArgOrdinal << /* scalar */ 1 << /* no int */ 0 << /* fp */ 1
3063 << PassedType;
3064 return false;
3065}
3066
3068 int ArgOrdinal,
3069 clang::QualType PassedType) {
3070 if (!PassedType->hasIntegerRepresentation() &&
3071 !PassedType->hasFloatingRepresentation())
3072 return S->Diag(Loc, diag::err_builtin_invalid_arg_type)
3073 << ArgOrdinal << /* scalar or vector of */ 5 << /* integer */ 1
3074 << /* fp */ 1 << PassedType;
3075 return false;
3076}
3077
3079 int ArgOrdinal,
3080 clang::QualType PassedType) {
3081 if (auto *VecTy = PassedType->getAs<VectorType>())
3082 if (VecTy->getElementType()->isUnsignedIntegerType())
3083 return false;
3084
3085 return S->Diag(Loc, diag::err_builtin_invalid_arg_type)
3086 << ArgOrdinal << /* vector of */ 4 << /* uint */ 3 << /* no fp */ 0
3087 << PassedType;
3088}
3089
3090// checks for unsigned ints of all sizes
3092 int ArgOrdinal,
3093 clang::QualType PassedType) {
3094 if (!PassedType->hasUnsignedIntegerRepresentation())
3095 return S->Diag(Loc, diag::err_builtin_invalid_arg_type)
3096 << ArgOrdinal << /* scalar or vector of */ 5 << /* unsigned int */ 3
3097 << /* no fp */ 0 << PassedType;
3098 return false;
3099}
3100
// Returns true, after emitting err_integer_incorrect_bit_count, if the bit
// width of argument 0's element type (the type itself for scalars) differs
// from Width.
// NOTE(review): ArgOrdinal is never read; argument 0 is always checked and
// diagnosed. The initializer of ElementBitCount (original line 3108) is
// missing from this listing.
3101static bool CheckExpectedBitWidth(Sema *S, CallExpr *TheCall,
3102 unsigned ArgOrdinal, unsigned Width) {
3103 QualType ArgTy = TheCall->getArg(0)->getType();
3104 if (auto *VTy = ArgTy->getAs<VectorType>())
3105 ArgTy = VTy->getElementType();
3106 // ensure arg type has expected bit width
3107 uint64_t ElementBitCount =
3109 if (ElementBitCount != Width) {
3110 S->Diag(TheCall->getArg(0)->getBeginLoc(),
3111 diag::err_integer_incorrect_bit_count)
3112 << Width << ElementBitCount;
3113 return true;
3114 }
3115 return false;
3116}
3117
3119 QualType ReturnType) {
3120 auto *VecTyA = TheCall->getArg(0)->getType()->getAs<VectorType>();
3121 if (VecTyA)
3122 ReturnType =
3123 S->Context.getExtVectorType(ReturnType, VecTyA->getNumElements());
3124
3125 TheCall->setType(ReturnType);
3126}
3127
3128static bool CheckScalarOrVector(Sema *S, CallExpr *TheCall, QualType Scalar,
3129 unsigned ArgIndex) {
3130 assert(TheCall->getNumArgs() >= ArgIndex);
3131 QualType ArgType = TheCall->getArg(ArgIndex)->getType();
3132 auto *VTy = ArgType->getAs<VectorType>();
3133 // not the scalar or vector<scalar>
3134 if (!(S->Context.hasSameUnqualifiedType(ArgType, Scalar) ||
3135 (VTy &&
3136 S->Context.hasSameUnqualifiedType(VTy->getElementType(), Scalar)))) {
3137 S->Diag(TheCall->getArg(0)->getBeginLoc(),
3138 diag::err_typecheck_expect_scalar_or_vector)
3139 << ArgType << Scalar;
3140 return true;
3141 }
3142 return false;
3143}
3144
3146 QualType Scalar, unsigned ArgIndex) {
3147 assert(TheCall->getNumArgs() > ArgIndex);
3148
3149 Expr *Arg = TheCall->getArg(ArgIndex);
3150 QualType ArgType = Arg->getType();
3151
3152 // Scalar: T
3153 if (S->Context.hasSameUnqualifiedType(ArgType, Scalar))
3154 return false;
3155
3156 // Vector: vector<T>
3157 if (const auto *VTy = ArgType->getAs<VectorType>()) {
3158 if (S->Context.hasSameUnqualifiedType(VTy->getElementType(), Scalar))
3159 return false;
3160 }
3161
3162 // Matrix: ConstantMatrixType with element type T
3163 if (const auto *MTy = ArgType->getAs<ConstantMatrixType>()) {
3164 if (S->Context.hasSameUnqualifiedType(MTy->getElementType(), Scalar))
3165 return false;
3166 }
3167
3168 // Not a scalar/vector/matrix-of-scalar
3169 S->Diag(Arg->getBeginLoc(),
3170 diag::err_typecheck_expect_scalar_or_vector_or_matrix)
3171 << ArgType << Scalar;
3172 return true;
3173}
3174
3175static bool CheckAnyScalarOrVector(Sema *S, CallExpr *TheCall,
3176 unsigned ArgIndex) {
3177 assert(TheCall->getNumArgs() >= ArgIndex);
3178 QualType ArgType = TheCall->getArg(ArgIndex)->getType();
3179 auto *VTy = ArgType->getAs<VectorType>();
3180 // not the scalar or vector<scalar>
3181 if (!(ArgType->isScalarType() ||
3182 (VTy && VTy->getElementType()->isScalarType()))) {
3183 S->Diag(TheCall->getArg(0)->getBeginLoc(),
3184 diag::err_typecheck_expect_any_scalar_or_vector)
3185 << ArgType << 1;
3186 return true;
3187 }
3188 return false;
3189}
3190
3191// Check that the argument is not a bool or vector<bool>
3192// Returns true on error
3194 unsigned ArgIndex) {
3195 QualType BoolType = S->getASTContext().BoolTy;
3196 assert(ArgIndex < TheCall->getNumArgs());
3197 QualType ArgType = TheCall->getArg(ArgIndex)->getType();
3198 auto *VTy = ArgType->getAs<VectorType>();
3199 // is the bool or vector<bool>
3200 if (S->Context.hasSameUnqualifiedType(ArgType, BoolType) ||
3201 (VTy &&
3202 S->Context.hasSameUnqualifiedType(VTy->getElementType(), BoolType))) {
3203 S->Diag(TheCall->getArg(0)->getBeginLoc(),
3204 diag::err_typecheck_expect_any_scalar_or_vector)
3205 << ArgType << 0;
3206 return true;
3207 }
3208 return false;
3209}
3210
3211static bool CheckWaveActive(Sema *S, CallExpr *TheCall) {
3212 if (CheckNotBoolScalarOrVector(S, TheCall, 0))
3213 return true;
3214 return false;
3215}
3216
3217static bool CheckWavePrefix(Sema *S, CallExpr *TheCall) {
3218 if (CheckNotBoolScalarOrVector(S, TheCall, 0))
3219 return true;
3220 return false;
3221}
3222
3223static bool CheckBoolSelect(Sema *S, CallExpr *TheCall) {
3224 assert(TheCall->getNumArgs() == 3);
3225 Expr *Arg1 = TheCall->getArg(1);
3226 Expr *Arg2 = TheCall->getArg(2);
3227 if (!S->Context.hasSameUnqualifiedType(Arg1->getType(), Arg2->getType())) {
3228 S->Diag(TheCall->getBeginLoc(),
3229 diag::err_typecheck_call_different_arg_types)
3230 << Arg1->getType() << Arg2->getType() << Arg1->getSourceRange()
3231 << Arg2->getSourceRange();
3232 return true;
3233 }
3234
3235 TheCall->setType(Arg1->getType());
3236 return false;
3237}
3238
3239static bool CheckVectorSelect(Sema *S, CallExpr *TheCall) {
3240 assert(TheCall->getNumArgs() == 3);
3241 Expr *Arg1 = TheCall->getArg(1);
3242 QualType Arg1Ty = Arg1->getType();
3243 Expr *Arg2 = TheCall->getArg(2);
3244 QualType Arg2Ty = Arg2->getType();
3245
3246 QualType Arg1ScalarTy = Arg1Ty;
3247 if (auto VTy = Arg1ScalarTy->getAs<VectorType>())
3248 Arg1ScalarTy = VTy->getElementType();
3249
3250 QualType Arg2ScalarTy = Arg2Ty;
3251 if (auto VTy = Arg2ScalarTy->getAs<VectorType>())
3252 Arg2ScalarTy = VTy->getElementType();
3253
3254 if (!S->Context.hasSameUnqualifiedType(Arg1ScalarTy, Arg2ScalarTy))
3255 S->Diag(Arg1->getBeginLoc(), diag::err_hlsl_builtin_scalar_vector_mismatch)
3256 << /* second and third */ 1 << TheCall->getCallee() << Arg1Ty << Arg2Ty;
3257
3258 QualType Arg0Ty = TheCall->getArg(0)->getType();
3259 unsigned Arg0Length = Arg0Ty->getAs<VectorType>()->getNumElements();
3260 unsigned Arg1Length = Arg1Ty->isVectorType()
3261 ? Arg1Ty->getAs<VectorType>()->getNumElements()
3262 : 0;
3263 unsigned Arg2Length = Arg2Ty->isVectorType()
3264 ? Arg2Ty->getAs<VectorType>()->getNumElements()
3265 : 0;
3266 if (Arg1Length > 0 && Arg0Length != Arg1Length) {
3267 S->Diag(TheCall->getBeginLoc(),
3268 diag::err_typecheck_vector_lengths_not_equal)
3269 << Arg0Ty << Arg1Ty << TheCall->getArg(0)->getSourceRange()
3270 << Arg1->getSourceRange();
3271 return true;
3272 }
3273
3274 if (Arg2Length > 0 && Arg0Length != Arg2Length) {
3275 S->Diag(TheCall->getBeginLoc(),
3276 diag::err_typecheck_vector_lengths_not_equal)
3277 << Arg0Ty << Arg2Ty << TheCall->getArg(0)->getSourceRange()
3278 << Arg2->getSourceRange();
3279 return true;
3280 }
3281
3282 TheCall->setType(
3283 S->getASTContext().getExtVectorType(Arg1ScalarTy, Arg0Length));
3284 return false;
3285}
3286
3288 Sema *S, CallExpr *TheCall, unsigned ArgIndex,
3289 llvm::function_ref<bool(const HLSLAttributedResourceType *ResType)> Check =
3290 nullptr) {
3291 assert(TheCall->getNumArgs() >= ArgIndex);
3292 QualType ArgType = TheCall->getArg(ArgIndex)->getType();
3293 const HLSLAttributedResourceType *ResTy =
3294 ArgType.getTypePtr()->getAs<HLSLAttributedResourceType>();
3295 if (!ResTy) {
3296 S->Diag(TheCall->getArg(ArgIndex)->getBeginLoc(),
3297 diag::err_typecheck_expect_hlsl_resource)
3298 << ArgType;
3299 return true;
3300 }
3301 if (Check && Check(ResTy)) {
3302 S->Diag(TheCall->getArg(ArgIndex)->getExprLoc(),
3303 diag::err_invalid_hlsl_resource_type)
3304 << ArgType;
3305 return true;
3306 }
3307 return false;
3308}
3309
// Returns true, after emitting err_typecheck_convert_incompatible at Loc, if
// PassedType's element count differs from ExpectedCount. A non-vector type
// counts as one element. The diagnostic reports the expected type as an
// ext-vector of BaseType with ExpectedCount elements. Only the count is
// compared here; PassedType's element type is not checked.
// NOTE(review): the declaration line of ExpectedType (original line 3318) is
// missing from this listing.
3310static bool CheckVectorElementCount(Sema *S, QualType PassedType,
3311 QualType BaseType, unsigned ExpectedCount,
3312 SourceLocation Loc) {
3313 unsigned PassedCount = 1;
3314 if (const auto *VecTy = PassedType->getAs<VectorType>())
3315 PassedCount = VecTy->getNumElements();
3316
3317 if (PassedCount != ExpectedCount) {
3319 S->Context.getExtVectorType(BaseType, ExpectedCount);
3320 S->Diag(Loc, diag::err_typecheck_convert_incompatible)
3321 << PassedType << ExpectedType << 1 << 0 << 0;
3322 return true;
3323 }
3324 return false;
3325}
3326
/// The sampling builtin variants validated by CheckSamplingBuiltin.
enum class SampleKind {
  Sample,      // location only (plus optional offset / clamp)
  Bias,        // adds a scalar float bias operand
  Grad,        // adds DDX and DDY operands
  Level,       // adds a scalar float LOD level operand
  Cmp,         // adds a scalar float compare value operand
  CmpLevelZero // adds a scalar float compare value operand
};
3328
3329static bool CheckSamplingBuiltin(Sema &S, CallExpr *TheCall, SampleKind Kind) {
3330 unsigned MinArgs, MaxArgs;
3331 if (Kind == SampleKind::Sample) {
3332 MinArgs = 3;
3333 MaxArgs = 5;
3334 } else if (Kind == SampleKind::Bias) {
3335 MinArgs = 4;
3336 MaxArgs = 6;
3337 } else if (Kind == SampleKind::Grad) {
3338 MinArgs = 5;
3339 MaxArgs = 7;
3340 } else if (Kind == SampleKind::Level) {
3341 MinArgs = 4;
3342 MaxArgs = 5;
3343 } else if (Kind == SampleKind::Cmp) {
3344 MinArgs = 4;
3345 MaxArgs = 6;
3346 } else {
3347 assert(Kind == SampleKind::CmpLevelZero);
3348 MinArgs = 4;
3349 MaxArgs = 5;
3350 }
3351
3352 if (S.checkArgCountRange(TheCall, MinArgs, MaxArgs))
3353 return true;
3354
3355 // Check the texture handle.
3356 if (CheckResourceHandle(&S, TheCall, 0,
3357 [](const HLSLAttributedResourceType *ResType) {
3358 return ResType->getAttrs().ResourceDimension ==
3359 llvm::dxil::ResourceDimension::Unknown;
3360 }))
3361 return true;
3362
3363 // Check the sampler handle.
3364 if (CheckResourceHandle(&S, TheCall, 1,
3365 [](const HLSLAttributedResourceType *ResType) {
3366 return ResType->getAttrs().ResourceClass !=
3367 llvm::hlsl::ResourceClass::Sampler;
3368 }))
3369 return true;
3370
3371 auto *ResourceTy =
3372 TheCall->getArg(0)->getType()->castAs<HLSLAttributedResourceType>();
3373
3374 // Check the location.
3375 unsigned ExpectedDim =
3376 getResourceDimensions(ResourceTy->getAttrs().ResourceDimension);
3377 if (CheckVectorElementCount(&S, TheCall->getArg(2)->getType(),
3378 S.Context.FloatTy, ExpectedDim,
3379 TheCall->getBeginLoc()))
3380 return true;
3381
3382 unsigned NextIdx = 3;
3383 if (Kind == SampleKind::Bias || Kind == SampleKind::Level ||
3384 Kind == SampleKind::Cmp || Kind == SampleKind::CmpLevelZero) {
3385 // Check the bias, lod level, or compare value, depending on the kind.
3386 // All of them must be a scalar float value.
3387 QualType BiasOrLODOrCmpTy = TheCall->getArg(NextIdx)->getType();
3388 if (!BiasOrLODOrCmpTy->isFloatingType() ||
3389 BiasOrLODOrCmpTy->isVectorType()) {
3390 S.Diag(TheCall->getArg(NextIdx)->getBeginLoc(),
3391 diag::err_typecheck_convert_incompatible)
3392 << BiasOrLODOrCmpTy << S.Context.FloatTy << 1 << 0 << 0;
3393 return true;
3394 }
3395 NextIdx++;
3396 } else if (Kind == SampleKind::Grad) {
3397 // Check the DDX operand.
3398 if (CheckVectorElementCount(&S, TheCall->getArg(NextIdx)->getType(),
3399 S.Context.FloatTy, ExpectedDim,
3400 TheCall->getArg(NextIdx)->getBeginLoc()))
3401 return true;
3402
3403 // Check the DDY operand.
3404 if (CheckVectorElementCount(&S, TheCall->getArg(NextIdx + 1)->getType(),
3405 S.Context.FloatTy, ExpectedDim,
3406 TheCall->getArg(NextIdx + 1)->getBeginLoc()))
3407 return true;
3408 NextIdx += 2;
3409 }
3410
3411 // Check the offset operand.
3412 if (TheCall->getNumArgs() > NextIdx) {
3413 if (CheckVectorElementCount(&S, TheCall->getArg(NextIdx)->getType(),
3414 S.Context.IntTy, ExpectedDim,
3415 TheCall->getArg(NextIdx)->getBeginLoc()))
3416 return true;
3417 NextIdx++;
3418 }
3419
3420 // Check the clamp operand.
3421 if (Kind != SampleKind::Level && Kind != SampleKind::CmpLevelZero &&
3422 TheCall->getNumArgs() > NextIdx) {
3423 QualType ClampTy = TheCall->getArg(NextIdx)->getType();
3424 if (!ClampTy->isFloatingType() || ClampTy->isVectorType()) {
3425 S.Diag(TheCall->getArg(NextIdx)->getBeginLoc(),
3426 diag::err_typecheck_convert_incompatible)
3427 << ClampTy << S.Context.FloatTy << 1 << 0 << 0;
3428 return true;
3429 }
3430 }
3431
3432 assert(ResourceTy->hasContainedType() &&
3433 "Expecting a contained type for resource with a dimension "
3434 "attribute.");
3435 QualType ReturnType = ResourceTy->getContainedType();
3436 if (Kind == SampleKind::Cmp || Kind == SampleKind::CmpLevelZero) {
3437 if (!ReturnType->hasFloatingRepresentation()) {
3438 S.Diag(TheCall->getBeginLoc(), diag::err_hlsl_samplecmp_requires_float);
3439 return true;
3440 }
3441 ReturnType = S.Context.FloatTy;
3442 }
3443 TheCall->setType(ReturnType);
3444
3445 return false;
3446}
3447
3448// Note: returning true in this case results in CheckBuiltinFunctionCall
3449// returning an ExprError
3450bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
3451 switch (BuiltinID) {
3452 case Builtin::BI__builtin_hlsl_adduint64: {
3453 if (SemaRef.checkArgCount(TheCall, 2))
3454 return true;
3455
3456 if (CheckAllArgTypesAreCorrect(&SemaRef, TheCall,
3458 return true;
3459
3460 // ensure arg integers are 32-bits
3461 if (CheckExpectedBitWidth(&SemaRef, TheCall, 0, 32))
3462 return true;
3463
3464 // ensure both args are vectors of total bit size of a multiple of 64
3465 auto *VTy = TheCall->getArg(0)->getType()->getAs<VectorType>();
3466 int NumElementsArg = VTy->getNumElements();
3467 if (NumElementsArg != 2 && NumElementsArg != 4) {
3468 SemaRef.Diag(TheCall->getBeginLoc(), diag::err_vector_incorrect_bit_count)
3469 << 1 /*a multiple of*/ << 64 << NumElementsArg * 32;
3470 return true;
3471 }
3472
3473 // ensure first arg and second arg have the same type
3474 if (CheckAllArgsHaveSameType(&SemaRef, TheCall))
3475 return true;
3476
3477 ExprResult A = TheCall->getArg(0);
3478 QualType ArgTyA = A.get()->getType();
3479 // return type is the same as the input type
3480 TheCall->setType(ArgTyA);
3481 break;
3482 }
3483 case Builtin::BI__builtin_hlsl_resource_getpointer: {
3484 if (SemaRef.checkArgCount(TheCall, 2) ||
3485 CheckResourceHandle(&SemaRef, TheCall, 0) ||
3486 CheckArgTypeMatches(&SemaRef, TheCall->getArg(1),
3487 SemaRef.getASTContext().UnsignedIntTy))
3488 return true;
3489
3490 auto *ResourceTy =
3491 TheCall->getArg(0)->getType()->castAs<HLSLAttributedResourceType>();
3492 QualType ContainedTy = ResourceTy->getContainedType();
3493 auto ReturnType =
3494 SemaRef.Context.getAddrSpaceQualType(ContainedTy, LangAS::hlsl_device);
3495 ReturnType = SemaRef.Context.getPointerType(ReturnType);
3496 TheCall->setType(ReturnType);
3497 TheCall->setValueKind(VK_LValue);
3498
3499 break;
3500 }
3501 case Builtin::BI__builtin_hlsl_resource_getpointer_typed: {
3502 if (SemaRef.checkArgCount(TheCall, 3) ||
3503 CheckResourceHandle(&SemaRef, TheCall, 0) ||
3504 CheckArgTypeMatches(&SemaRef, TheCall->getArg(1),
3505 SemaRef.getASTContext().UnsignedIntTy))
3506 return true;
3507
3508 QualType ElementTy = TheCall->getArg(2)->getType();
3509 assert(ElementTy->isPointerType() &&
3510 "expected pointer type for second argument");
3511 ElementTy = ElementTy->getPointeeType();
3512
3513 // Reject array types
3514 if (ElementTy->isArrayType())
3515 return SemaRef.Diag(
3516 cast<FunctionDecl>(SemaRef.CurContext)->getPointOfInstantiation(),
3517 diag::err_invalid_use_of_array_type);
3518
3519 auto ReturnType =
3520 SemaRef.Context.getAddrSpaceQualType(ElementTy, LangAS::hlsl_device);
3521 ReturnType = SemaRef.Context.getPointerType(ReturnType);
3522 TheCall->setType(ReturnType);
3523
3524 break;
3525 }
3526 case Builtin::BI__builtin_hlsl_resource_load_with_status: {
3527 if (SemaRef.checkArgCount(TheCall, 3) ||
3528 CheckResourceHandle(&SemaRef, TheCall, 0) ||
3529 CheckArgTypeMatches(&SemaRef, TheCall->getArg(1),
3530 SemaRef.getASTContext().UnsignedIntTy) ||
3531 CheckArgTypeMatches(&SemaRef, TheCall->getArg(2),
3532 SemaRef.getASTContext().UnsignedIntTy) ||
3533 CheckModifiableLValue(&SemaRef, TheCall, 2))
3534 return true;
3535
3536 auto *ResourceTy =
3537 TheCall->getArg(0)->getType()->castAs<HLSLAttributedResourceType>();
3538 QualType ReturnType = ResourceTy->getContainedType();
3539 TheCall->setType(ReturnType);
3540
3541 break;
3542 }
3543 case Builtin::BI__builtin_hlsl_resource_load_with_status_typed: {
3544 if (SemaRef.checkArgCount(TheCall, 4) ||
3545 CheckResourceHandle(&SemaRef, TheCall, 0) ||
3546 CheckArgTypeMatches(&SemaRef, TheCall->getArg(1),
3547 SemaRef.getASTContext().UnsignedIntTy) ||
3548 CheckArgTypeMatches(&SemaRef, TheCall->getArg(2),
3549 SemaRef.getASTContext().UnsignedIntTy) ||
3550 CheckModifiableLValue(&SemaRef, TheCall, 2))
3551 return true;
3552
3553 QualType ReturnType = TheCall->getArg(3)->getType();
3554 assert(ReturnType->isPointerType() &&
3555 "expected pointer type for second argument");
3556 ReturnType = ReturnType->getPointeeType();
3557
3558 // Reject array types
3559 if (ReturnType->isArrayType())
3560 return SemaRef.Diag(
3561 cast<FunctionDecl>(SemaRef.CurContext)->getPointOfInstantiation(),
3562 diag::err_invalid_use_of_array_type);
3563
3564 TheCall->setType(ReturnType);
3565
3566 break;
3567 }
3568 case Builtin::BI__builtin_hlsl_resource_sample:
3570 case Builtin::BI__builtin_hlsl_resource_sample_bias:
3572 case Builtin::BI__builtin_hlsl_resource_sample_grad:
3574 case Builtin::BI__builtin_hlsl_resource_sample_level:
3576 case Builtin::BI__builtin_hlsl_resource_sample_cmp:
3578 case Builtin::BI__builtin_hlsl_resource_sample_cmp_level_zero:
3580 case Builtin::BI__builtin_hlsl_resource_uninitializedhandle: {
3581 assert(TheCall->getNumArgs() == 1 && "expected 1 arg");
3582 // Update return type to be the attributed resource type from arg0.
3583 QualType ResourceTy = TheCall->getArg(0)->getType();
3584 TheCall->setType(ResourceTy);
3585 break;
3586 }
3587 case Builtin::BI__builtin_hlsl_resource_handlefrombinding: {
3588 assert(TheCall->getNumArgs() == 6 && "expected 6 args");
3589 // Update return type to be the attributed resource type from arg0.
3590 QualType ResourceTy = TheCall->getArg(0)->getType();
3591 TheCall->setType(ResourceTy);
3592 break;
3593 }
3594 case Builtin::BI__builtin_hlsl_resource_handlefromimplicitbinding: {
3595 assert(TheCall->getNumArgs() == 6 && "expected 6 args");
3596 // Update return type to be the attributed resource type from arg0.
3597 QualType ResourceTy = TheCall->getArg(0)->getType();
3598 TheCall->setType(ResourceTy);
3599 break;
3600 }
3601 case Builtin::BI__builtin_hlsl_resource_counterhandlefromimplicitbinding: {
3602 assert(TheCall->getNumArgs() == 3 && "expected 3 args");
3603 ASTContext &AST = SemaRef.getASTContext();
3604 QualType MainHandleTy = TheCall->getArg(0)->getType();
3605 auto *MainResType = MainHandleTy->getAs<HLSLAttributedResourceType>();
3606 auto MainAttrs = MainResType->getAttrs();
3607 assert(!MainAttrs.IsCounter && "cannot create a counter from a counter");
3608 MainAttrs.IsCounter = true;
3609 QualType CounterHandleTy = AST.getHLSLAttributedResourceType(
3610 MainResType->getWrappedType(), MainResType->getContainedType(),
3611 MainAttrs);
3612 // Update return type to be the attributed resource type from arg0
3613 // with added IsCounter flag.
3614 TheCall->setType(CounterHandleTy);
3615 break;
3616 }
3617 case Builtin::BI__builtin_hlsl_and:
3618 case Builtin::BI__builtin_hlsl_or: {
3619 if (SemaRef.checkArgCount(TheCall, 2))
3620 return true;
3621 if (CheckScalarOrVectorOrMatrix(&SemaRef, TheCall, getASTContext().BoolTy,
3622 0))
3623 return true;
3624 if (CheckAllArgsHaveSameType(&SemaRef, TheCall))
3625 return true;
3626
3627 ExprResult A = TheCall->getArg(0);
3628 QualType ArgTyA = A.get()->getType();
3629 // return type is the same as the input type
3630 TheCall->setType(ArgTyA);
3631 break;
3632 }
3633 case Builtin::BI__builtin_hlsl_all:
3634 case Builtin::BI__builtin_hlsl_any: {
3635 if (SemaRef.checkArgCount(TheCall, 1))
3636 return true;
3637 if (CheckAnyScalarOrVector(&SemaRef, TheCall, 0))
3638 return true;
3639 break;
3640 }
3641 case Builtin::BI__builtin_hlsl_asdouble: {
3642 if (SemaRef.checkArgCount(TheCall, 2))
3643 return true;
3645 &SemaRef, TheCall,
3646 /*only check for uint*/ SemaRef.Context.UnsignedIntTy,
3647 /* arg index */ 0))
3648 return true;
3650 &SemaRef, TheCall,
3651 /*only check for uint*/ SemaRef.Context.UnsignedIntTy,
3652 /* arg index */ 1))
3653 return true;
3654 if (CheckAllArgsHaveSameType(&SemaRef, TheCall))
3655 return true;
3656
3657 SetElementTypeAsReturnType(&SemaRef, TheCall, getASTContext().DoubleTy);
3658 break;
3659 }
3660 case Builtin::BI__builtin_hlsl_elementwise_clamp: {
3661 if (SemaRef.BuiltinElementwiseTernaryMath(
3662 TheCall, /*ArgTyRestr=*/
3664 return true;
3665 break;
3666 }
3667 case Builtin::BI__builtin_hlsl_dot: {
3668 // arg count is checked by BuiltinVectorToScalarMath
3669 if (SemaRef.BuiltinVectorToScalarMath(TheCall))
3670 return true;
3672 return true;
3673 break;
3674 }
3675 case Builtin::BI__builtin_hlsl_elementwise_firstbithigh:
3676 case Builtin::BI__builtin_hlsl_elementwise_firstbitlow: {
3677 if (SemaRef.PrepareBuiltinElementwiseMathOneArgCall(TheCall))
3678 return true;
3679
3680 const Expr *Arg = TheCall->getArg(0);
3681 QualType ArgTy = Arg->getType();
3682 QualType EltTy = ArgTy;
3683
3684 QualType ResTy = SemaRef.Context.UnsignedIntTy;
3685
3686 if (auto *VecTy = EltTy->getAs<VectorType>()) {
3687 EltTy = VecTy->getElementType();
3688 ResTy = SemaRef.Context.getExtVectorType(ResTy, VecTy->getNumElements());
3689 }
3690
3691 if (!EltTy->isIntegerType()) {
3692 Diag(Arg->getBeginLoc(), diag::err_builtin_invalid_arg_type)
3693 << 1 << /* scalar or vector of */ 5 << /* integer ty */ 1
3694 << /* no fp */ 0 << ArgTy;
3695 return true;
3696 }
3697
3698 TheCall->setType(ResTy);
3699 break;
3700 }
3701 case Builtin::BI__builtin_hlsl_select: {
3702 if (SemaRef.checkArgCount(TheCall, 3))
3703 return true;
3704 if (CheckScalarOrVector(&SemaRef, TheCall, getASTContext().BoolTy, 0))
3705 return true;
3706 QualType ArgTy = TheCall->getArg(0)->getType();
3707 if (ArgTy->isBooleanType() && CheckBoolSelect(&SemaRef, TheCall))
3708 return true;
3709 auto *VTy = ArgTy->getAs<VectorType>();
3710 if (VTy && VTy->getElementType()->isBooleanType() &&
3711 CheckVectorSelect(&SemaRef, TheCall))
3712 return true;
3713 break;
3714 }
3715 case Builtin::BI__builtin_hlsl_elementwise_saturate:
3716 case Builtin::BI__builtin_hlsl_elementwise_rcp: {
3717 if (SemaRef.checkArgCount(TheCall, 1))
3718 return true;
3719 if (!TheCall->getArg(0)
3720 ->getType()
3721 ->hasFloatingRepresentation()) // half or float or double
3722 return SemaRef.Diag(TheCall->getArg(0)->getBeginLoc(),
3723 diag::err_builtin_invalid_arg_type)
3724 << /* ordinal */ 1 << /* scalar or vector */ 5 << /* no int */ 0
3725 << /* fp */ 1 << TheCall->getArg(0)->getType();
3726 if (SemaRef.PrepareBuiltinElementwiseMathOneArgCall(TheCall))
3727 return true;
3728 break;
3729 }
3730 case Builtin::BI__builtin_hlsl_elementwise_degrees:
3731 case Builtin::BI__builtin_hlsl_elementwise_radians:
3732 case Builtin::BI__builtin_hlsl_elementwise_rsqrt:
3733 case Builtin::BI__builtin_hlsl_elementwise_frac:
3734 case Builtin::BI__builtin_hlsl_elementwise_ddx_coarse:
3735 case Builtin::BI__builtin_hlsl_elementwise_ddy_coarse:
3736 case Builtin::BI__builtin_hlsl_elementwise_ddx_fine:
3737 case Builtin::BI__builtin_hlsl_elementwise_ddy_fine: {
3738 if (SemaRef.checkArgCount(TheCall, 1))
3739 return true;
3740 if (CheckAllArgTypesAreCorrect(&SemaRef, TheCall,
3742 return true;
3743 if (SemaRef.PrepareBuiltinElementwiseMathOneArgCall(TheCall))
3744 return true;
3745 break;
3746 }
3747 case Builtin::BI__builtin_hlsl_elementwise_isinf:
3748 case Builtin::BI__builtin_hlsl_elementwise_isnan: {
3749 if (SemaRef.checkArgCount(TheCall, 1))
3750 return true;
3751 if (CheckAllArgTypesAreCorrect(&SemaRef, TheCall,
3753 return true;
3754 if (SemaRef.PrepareBuiltinElementwiseMathOneArgCall(TheCall))
3755 return true;
3757 break;
3758 }
3759 case Builtin::BI__builtin_hlsl_lerp: {
3760 if (SemaRef.checkArgCount(TheCall, 3))
3761 return true;
3762 if (CheckAllArgTypesAreCorrect(&SemaRef, TheCall,
3764 return true;
3765 if (CheckAllArgsHaveSameType(&SemaRef, TheCall))
3766 return true;
3767 if (SemaRef.BuiltinElementwiseTernaryMath(TheCall))
3768 return true;
3769 break;
3770 }
3771 case Builtin::BI__builtin_hlsl_mad: {
3772 if (SemaRef.BuiltinElementwiseTernaryMath(
3773 TheCall, /*ArgTyRestr=*/
3775 return true;
3776 break;
3777 }
3778 case Builtin::BI__builtin_hlsl_normalize: {
3779 if (SemaRef.checkArgCount(TheCall, 1))
3780 return true;
3781 if (CheckAllArgTypesAreCorrect(&SemaRef, TheCall,
3783 return true;
3784 ExprResult A = TheCall->getArg(0);
3785 QualType ArgTyA = A.get()->getType();
3786 // return type is the same as the input type
3787 TheCall->setType(ArgTyA);
3788 break;
3789 }
3790 case Builtin::BI__builtin_hlsl_elementwise_sign: {
3791 if (SemaRef.PrepareBuiltinElementwiseMathOneArgCall(TheCall))
3792 return true;
3793 if (CheckAllArgTypesAreCorrect(&SemaRef, TheCall,
3795 return true;
3797 break;
3798 }
3799 case Builtin::BI__builtin_hlsl_step: {
3800 if (SemaRef.checkArgCount(TheCall, 2))
3801 return true;
3802 if (CheckAllArgTypesAreCorrect(&SemaRef, TheCall,
3804 return true;
3805
3806 ExprResult A = TheCall->getArg(0);
3807 QualType ArgTyA = A.get()->getType();
3808 // return type is the same as the input type
3809 TheCall->setType(ArgTyA);
3810 break;
3811 }
3812 case Builtin::BI__builtin_hlsl_wave_active_max:
3813 case Builtin::BI__builtin_hlsl_wave_active_min:
3814 case Builtin::BI__builtin_hlsl_wave_active_sum: {
3815 if (SemaRef.checkArgCount(TheCall, 1))
3816 return true;
3817
3818 // Ensure input expr type is a scalar/vector and the same as the return type
3819 if (CheckAnyScalarOrVector(&SemaRef, TheCall, 0))
3820 return true;
3821 if (CheckWaveActive(&SemaRef, TheCall))
3822 return true;
3823 ExprResult Expr = TheCall->getArg(0);
3824 QualType ArgTyExpr = Expr.get()->getType();
3825 TheCall->setType(ArgTyExpr);
3826 break;
3827 }
3828 // Note these are llvm builtins that we want to catch invalid intrinsic
3829 // generation. Normal handling of these builtins will occur elsewhere.
3830 case Builtin::BI__builtin_elementwise_bitreverse: {
3831 // does not include a check for number of arguments
3832 // because that is done previously
3833 if (CheckAllArgTypesAreCorrect(&SemaRef, TheCall,
3835 return true;
3836 break;
3837 }
3838 case Builtin::BI__builtin_hlsl_wave_prefix_count_bits: {
3839 if (SemaRef.checkArgCount(TheCall, 1))
3840 return true;
3841
3842 QualType ArgType = TheCall->getArg(0)->getType();
3843
3844 if (!(ArgType->isScalarType())) {
3845 SemaRef.Diag(TheCall->getArg(0)->getBeginLoc(),
3846 diag::err_typecheck_expect_any_scalar_or_vector)
3847 << ArgType << 0;
3848 return true;
3849 }
3850
3851 if (!(ArgType->isBooleanType())) {
3852 SemaRef.Diag(TheCall->getArg(0)->getBeginLoc(),
3853 diag::err_typecheck_expect_any_scalar_or_vector)
3854 << ArgType << 0;
3855 return true;
3856 }
3857
3858 break;
3859 }
3860 case Builtin::BI__builtin_hlsl_wave_read_lane_at: {
3861 if (SemaRef.checkArgCount(TheCall, 2))
3862 return true;
3863
3864 // Ensure index parameter type can be interpreted as a uint
3865 ExprResult Index = TheCall->getArg(1);
3866 QualType ArgTyIndex = Index.get()->getType();
3867 if (!ArgTyIndex->isIntegerType()) {
3868 SemaRef.Diag(TheCall->getArg(1)->getBeginLoc(),
3869 diag::err_typecheck_convert_incompatible)
3870 << ArgTyIndex << SemaRef.Context.UnsignedIntTy << 1 << 0 << 0;
3871 return true;
3872 }
3873
3874 // Ensure input expr type is a scalar/vector and the same as the return type
3875 if (CheckAnyScalarOrVector(&SemaRef, TheCall, 0))
3876 return true;
3877
3878 ExprResult Expr = TheCall->getArg(0);
3879 QualType ArgTyExpr = Expr.get()->getType();
3880 TheCall->setType(ArgTyExpr);
3881 break;
3882 }
3883 case Builtin::BI__builtin_hlsl_wave_get_lane_index: {
3884 if (SemaRef.checkArgCount(TheCall, 0))
3885 return true;
3886 break;
3887 }
3888 case Builtin::BI__builtin_hlsl_wave_prefix_sum:
3889 case Builtin::BI__builtin_hlsl_wave_prefix_product: {
3890 if (SemaRef.checkArgCount(TheCall, 1))
3891 return true;
3892
3893 // Ensure input expr type is a scalar/vector and the same as the return type
3894 if (CheckAnyScalarOrVector(&SemaRef, TheCall, 0))
3895 return true;
3896 if (CheckWavePrefix(&SemaRef, TheCall))
3897 return true;
3898 ExprResult Expr = TheCall->getArg(0);
3899 QualType ArgTyExpr = Expr.get()->getType();
3900 TheCall->setType(ArgTyExpr);
3901 break;
3902 }
3903 case Builtin::BI__builtin_hlsl_elementwise_splitdouble: {
3904 if (SemaRef.checkArgCount(TheCall, 3))
3905 return true;
3906
3907 if (CheckScalarOrVector(&SemaRef, TheCall, SemaRef.Context.DoubleTy, 0) ||
3908 CheckScalarOrVector(&SemaRef, TheCall, SemaRef.Context.UnsignedIntTy,
3909 1) ||
3910 CheckScalarOrVector(&SemaRef, TheCall, SemaRef.Context.UnsignedIntTy,
3911 2))
3912 return true;
3913
3914 if (CheckModifiableLValue(&SemaRef, TheCall, 1) ||
3915 CheckModifiableLValue(&SemaRef, TheCall, 2))
3916 return true;
3917 break;
3918 }
3919 case Builtin::BI__builtin_hlsl_elementwise_clip: {
3920 if (SemaRef.checkArgCount(TheCall, 1))
3921 return true;
3922
3923 if (CheckScalarOrVector(&SemaRef, TheCall, SemaRef.Context.FloatTy, 0))
3924 return true;
3925 break;
3926 }
3927 case Builtin::BI__builtin_elementwise_acos:
3928 case Builtin::BI__builtin_elementwise_asin:
3929 case Builtin::BI__builtin_elementwise_atan:
3930 case Builtin::BI__builtin_elementwise_atan2:
3931 case Builtin::BI__builtin_elementwise_ceil:
3932 case Builtin::BI__builtin_elementwise_cos:
3933 case Builtin::BI__builtin_elementwise_cosh:
3934 case Builtin::BI__builtin_elementwise_exp:
3935 case Builtin::BI__builtin_elementwise_exp2:
3936 case Builtin::BI__builtin_elementwise_exp10:
3937 case Builtin::BI__builtin_elementwise_floor:
3938 case Builtin::BI__builtin_elementwise_fmod:
3939 case Builtin::BI__builtin_elementwise_log:
3940 case Builtin::BI__builtin_elementwise_log2:
3941 case Builtin::BI__builtin_elementwise_log10:
3942 case Builtin::BI__builtin_elementwise_pow:
3943 case Builtin::BI__builtin_elementwise_roundeven:
3944 case Builtin::BI__builtin_elementwise_sin:
3945 case Builtin::BI__builtin_elementwise_sinh:
3946 case Builtin::BI__builtin_elementwise_sqrt:
3947 case Builtin::BI__builtin_elementwise_tan:
3948 case Builtin::BI__builtin_elementwise_tanh:
3949 case Builtin::BI__builtin_elementwise_trunc: {
3950 if (CheckAllArgTypesAreCorrect(&SemaRef, TheCall,
3952 return true;
3953 break;
3954 }
3955 case Builtin::BI__builtin_hlsl_buffer_update_counter: {
3956 assert(TheCall->getNumArgs() == 2 && "expected 2 args");
3957 auto checkResTy = [](const HLSLAttributedResourceType *ResTy) -> bool {
3958 return !(ResTy->getAttrs().ResourceClass == ResourceClass::UAV &&
3959 ResTy->getAttrs().RawBuffer && ResTy->hasContainedType());
3960 };
3961 if (CheckResourceHandle(&SemaRef, TheCall, 0, checkResTy))
3962 return true;
3963 Expr *OffsetExpr = TheCall->getArg(1);
3964 std::optional<llvm::APSInt> Offset =
3965 OffsetExpr->getIntegerConstantExpr(SemaRef.getASTContext());
3966 if (!Offset.has_value() || std::abs(Offset->getExtValue()) != 1) {
3967 SemaRef.Diag(TheCall->getArg(1)->getBeginLoc(),
3968 diag::err_hlsl_expect_arg_const_int_one_or_neg_one)
3969 << 1;
3970 return true;
3971 }
3972 break;
3973 }
3974 case Builtin::BI__builtin_hlsl_elementwise_f16tof32: {
3975 if (SemaRef.checkArgCount(TheCall, 1))
3976 return true;
3977 if (CheckAllArgTypesAreCorrect(&SemaRef, TheCall,
3979 return true;
3980 // ensure arg integers are 32 bits
3981 if (CheckExpectedBitWidth(&SemaRef, TheCall, 0, 32))
3982 return true;
3983 // check it wasn't a bool type
3984 QualType ArgTy = TheCall->getArg(0)->getType();
3985 if (auto *VTy = ArgTy->getAs<VectorType>())
3986 ArgTy = VTy->getElementType();
3987 if (ArgTy->isBooleanType()) {
3988 SemaRef.Diag(TheCall->getArg(0)->getBeginLoc(),
3989 diag::err_builtin_invalid_arg_type)
3990 << 1 << /* scalar or vector of */ 5 << /* unsigned int */ 3
3991 << /* no fp */ 0 << TheCall->getArg(0)->getType();
3992 return true;
3993 }
3994
3995 SetElementTypeAsReturnType(&SemaRef, TheCall, getASTContext().FloatTy);
3996 break;
3997 }
3998 case Builtin::BI__builtin_hlsl_elementwise_f32tof16: {
3999 if (SemaRef.checkArgCount(TheCall, 1))
4000 return true;
4002 return true;
4004 getASTContext().UnsignedIntTy);
4005 break;
4006 }
4007 }
4008 return false;
4009}
4010
4014 WorkList.push_back(BaseTy);
4015 while (!WorkList.empty()) {
4016 QualType T = WorkList.pop_back_val();
4017 T = T.getCanonicalType().getUnqualifiedType();
4018 if (const auto *AT = dyn_cast<ConstantArrayType>(T)) {
4019 llvm::SmallVector<QualType, 16> ElementFields;
4020 // Generally I've avoided recursion in this algorithm, but arrays of
4021 // structs could be time-consuming to flatten and churn through on the
4022 // work list. Hopefully nesting arrays of structs containing arrays
4023 // of structs too many levels deep is unlikely.
4024 BuildFlattenedTypeList(AT->getElementType(), ElementFields);
4025 // Repeat the element's field list n times.
4026 for (uint64_t Ct = 0; Ct < AT->getZExtSize(); ++Ct)
4027 llvm::append_range(List, ElementFields);
4028 continue;
4029 }
4030 // Vectors can only have element types that are builtin types, so this can
4031 // add directly to the list instead of to the WorkList.
4032 if (const auto *VT = dyn_cast<VectorType>(T)) {
4033 List.insert(List.end(), VT->getNumElements(), VT->getElementType());
4034 continue;
4035 }
4036 if (const auto *MT = dyn_cast<ConstantMatrixType>(T)) {
4037 List.insert(List.end(), MT->getNumElementsFlattened(),
4038 MT->getElementType());
4039 continue;
4040 }
4041 if (const auto *RD = T->getAsCXXRecordDecl()) {
4042 if (RD->isStandardLayout())
4043 RD = RD->getStandardLayoutBaseWithFields();
4044
4045 // For types that we shouldn't decompose (unions and non-aggregates), just
4046 // add the type itself to the list.
4047 if (RD->isUnion() || !RD->isAggregate()) {
4048 List.push_back(T);
4049 continue;
4050 }
4051
4053 for (const auto *FD : RD->fields())
4054 if (!FD->isUnnamedBitField())
4055 FieldTypes.push_back(FD->getType());
4056 // Reverse the newly added sub-range.
4057 std::reverse(FieldTypes.begin(), FieldTypes.end());
4058 llvm::append_range(WorkList, FieldTypes);
4059
4060 // If this wasn't a standard layout type we may also have some base
4061 // classes to deal with.
4062 if (!RD->isStandardLayout()) {
4063 FieldTypes.clear();
4064 for (const auto &Base : RD->bases())
4065 FieldTypes.push_back(Base.getType());
4066 std::reverse(FieldTypes.begin(), FieldTypes.end());
4067 llvm::append_range(WorkList, FieldTypes);
4068 }
4069 continue;
4070 }
4071 List.push_back(T);
4072 }
4073}
4074
4076 // null and array types are not allowed.
4077 if (QT.isNull() || QT->isArrayType())
4078 return false;
4079
4080 // UDT types are not allowed
4081 if (QT->isRecordType())
4082 return false;
4083
4084 if (QT->isBooleanType() || QT->isEnumeralType())
4085 return false;
4086
4087 // the only other valid builtin types are scalars or vectors
4088 if (QT->isArithmeticType()) {
4089 if (SemaRef.Context.getTypeSize(QT) / 8 > 16)
4090 return false;
4091 return true;
4092 }
4093
4094 if (const VectorType *VT = QT->getAs<VectorType>()) {
4095 int ArraySize = VT->getNumElements();
4096
4097 if (ArraySize > 4)
4098 return false;
4099
4100 QualType ElTy = VT->getElementType();
4101 if (ElTy->isBooleanType())
4102 return false;
4103
4104 if (SemaRef.Context.getTypeSize(QT) / 8 > 16)
4105 return false;
4106 return true;
4107 }
4108
4109 return false;
4110}
4111
4113 if (T1.isNull() || T2.isNull())
4114 return false;
4115
4118
4119 // If both types are the same canonical type, they're obviously compatible.
4120 if (SemaRef.getASTContext().hasSameType(T1, T2))
4121 return true;
4122
4124 BuildFlattenedTypeList(T1, T1Types);
4126 BuildFlattenedTypeList(T2, T2Types);
4127
4128 // Check the flattened type list
4129 return llvm::equal(T1Types, T2Types,
4130 [this](QualType LHS, QualType RHS) -> bool {
4131 return SemaRef.IsLayoutCompatible(LHS, RHS);
4132 });
4133}
4134
4136 FunctionDecl *Old) {
4137 if (New->getNumParams() != Old->getNumParams())
4138 return true;
4139
4140 bool HadError = false;
4141
4142 for (unsigned i = 0, e = New->getNumParams(); i != e; ++i) {
4143 ParmVarDecl *NewParam = New->getParamDecl(i);
4144 ParmVarDecl *OldParam = Old->getParamDecl(i);
4145
4146 // HLSL parameter declarations for inout and out must match between
4147 // declarations. In HLSL inout and out are ambiguous at the call site,
4148 // but have different calling behavior, so you cannot overload a
4149 // method based on a difference between inout and out annotations.
4150 const auto *NDAttr = NewParam->getAttr<HLSLParamModifierAttr>();
4151 unsigned NSpellingIdx = (NDAttr ? NDAttr->getSpellingListIndex() : 0);
4152 const auto *ODAttr = OldParam->getAttr<HLSLParamModifierAttr>();
4153 unsigned OSpellingIdx = (ODAttr ? ODAttr->getSpellingListIndex() : 0);
4154
4155 if (NSpellingIdx != OSpellingIdx) {
4156 SemaRef.Diag(NewParam->getLocation(),
4157 diag::err_hlsl_param_qualifier_mismatch)
4158 << NDAttr << NewParam;
4159 SemaRef.Diag(OldParam->getLocation(), diag::note_previous_declaration_as)
4160 << ODAttr;
4161 HadError = true;
4162 }
4163 }
4164 return HadError;
4165}
4166
4167// Generally follows PerformScalarCast, with cases reordered for
4168// clarity of what types are supported
4170
4171 if (!SrcTy->isScalarType() || !DestTy->isScalarType())
4172 return false;
4173
4174 if (SemaRef.getASTContext().hasSameUnqualifiedType(SrcTy, DestTy))
4175 return true;
4176
4177 switch (SrcTy->getScalarTypeKind()) {
4178 case Type::STK_Bool: // casting from bool is like casting from an integer
4179 case Type::STK_Integral:
4180 switch (DestTy->getScalarTypeKind()) {
4181 case Type::STK_Bool:
4182 case Type::STK_Integral:
4183 case Type::STK_Floating:
4184 return true;
4185 case Type::STK_CPointer:
4189 llvm_unreachable("HLSL doesn't support pointers.");
4192 llvm_unreachable("HLSL doesn't support complex types.");
4194 llvm_unreachable("HLSL doesn't support fixed point types.");
4195 }
4196 llvm_unreachable("Should have returned before this");
4197
4198 case Type::STK_Floating:
4199 switch (DestTy->getScalarTypeKind()) {
4200 case Type::STK_Floating:
4201 case Type::STK_Bool:
4202 case Type::STK_Integral:
4203 return true;
4206 llvm_unreachable("HLSL doesn't support complex types.");
4208 llvm_unreachable("HLSL doesn't support fixed point types.");
4209 case Type::STK_CPointer:
4213 llvm_unreachable("HLSL doesn't support pointers.");
4214 }
4215 llvm_unreachable("Should have returned before this");
4216
4218 case Type::STK_CPointer:
4221 llvm_unreachable("HLSL doesn't support pointers.");
4222
4224 llvm_unreachable("HLSL doesn't support fixed point types.");
4225
4228 llvm_unreachable("HLSL doesn't support complex types.");
4229 }
4230
4231 llvm_unreachable("Unhandled scalar cast");
4232}
4233
4234// Can perform an HLSL Aggregate splat cast if the Dest is an aggregate and the
4235// Src is a scalar or a vector of length 1
4236// Or if Dest is a vector and Src is a vector of length 1
4238
4239 QualType SrcTy = Src->getType();
4240 // Not a valid HLSL Aggregate Splat cast if Dest is a scalar or if this is
4241 // going to be a vector splat from a scalar.
4242 if ((SrcTy->isScalarType() && DestTy->isVectorType()) ||
4243 DestTy->isScalarType())
4244 return false;
4245
4246 const VectorType *SrcVecTy = SrcTy->getAs<VectorType>();
4247
4248 // Src isn't a scalar or a vector of length 1
4249 if (!SrcTy->isScalarType() && !(SrcVecTy && SrcVecTy->getNumElements() == 1))
4250 return false;
4251
4252 if (SrcVecTy)
4253 SrcTy = SrcVecTy->getElementType();
4254
4256 BuildFlattenedTypeList(DestTy, DestTypes);
4257
4258 for (unsigned I = 0, Size = DestTypes.size(); I < Size; ++I) {
4259 if (DestTypes[I]->isUnionType())
4260 return false;
4261 if (!CanPerformScalarCast(SrcTy, DestTypes[I]))
4262 return false;
4263 }
4264 return true;
4265}
4266
4267// Can we perform an HLSL Elementwise cast?
4269
4270 // Don't handle casts where LHS and RHS are any combination of scalar/vector
4271 // There must be an aggregate somewhere
4272 QualType SrcTy = Src->getType();
4273 if (SrcTy->isScalarType()) // always a splat and this cast doesn't handle that
4274 return false;
4275
4276 if (SrcTy->isVectorType() &&
4277 (DestTy->isScalarType() || DestTy->isVectorType()))
4278 return false;
4279
4280 if (SrcTy->isConstantMatrixType() &&
4281 (DestTy->isScalarType() || DestTy->isConstantMatrixType()))
4282 return false;
4283
4285 BuildFlattenedTypeList(DestTy, DestTypes);
4287 BuildFlattenedTypeList(SrcTy, SrcTypes);
4288
4289 // Usually the size of SrcTypes must be greater than or equal to the size of
4290 // DestTypes.
4291 if (SrcTypes.size() < DestTypes.size())
4292 return false;
4293
4294 unsigned SrcSize = SrcTypes.size();
4295 unsigned DstSize = DestTypes.size();
4296 unsigned I;
4297 for (I = 0; I < DstSize && I < SrcSize; I++) {
4298 if (SrcTypes[I]->isUnionType() || DestTypes[I]->isUnionType())
4299 return false;
4300 if (!CanPerformScalarCast(SrcTypes[I], DestTypes[I])) {
4301 return false;
4302 }
4303 }
4304
4305 // check the rest of the source type for unions.
4306 for (; I < SrcSize; I++) {
4307 if (SrcTypes[I]->isUnionType())
4308 return false;
4309 }
4310 return true;
4311}
4312
4314 assert(Param->hasAttr<HLSLParamModifierAttr>() &&
4315 "We should not get here without a parameter modifier expression");
4316 const auto *Attr = Param->getAttr<HLSLParamModifierAttr>();
4317 if (Attr->getABI() == ParameterABI::Ordinary)
4318 return ExprResult(Arg);
4319
4320 bool IsInOut = Attr->getABI() == ParameterABI::HLSLInOut;
4321 if (!Arg->isLValue()) {
4322 SemaRef.Diag(Arg->getBeginLoc(), diag::error_hlsl_inout_lvalue)
4323 << Arg << (IsInOut ? 1 : 0);
4324 return ExprError();
4325 }
4326
4327 ASTContext &Ctx = SemaRef.getASTContext();
4328
4329 QualType Ty = Param->getType().getNonLValueExprType(Ctx);
4330
4331 // HLSL allows implicit conversions from scalars to vectors, but not the
4332 // inverse, so we need to disallow `inout` with scalar->vector or
4333 // scalar->matrix conversions.
4334 if (Arg->getType()->isScalarType() != Ty->isScalarType()) {
4335 SemaRef.Diag(Arg->getBeginLoc(), diag::error_hlsl_inout_scalar_extension)
4336 << Arg << (IsInOut ? 1 : 0);
4337 return ExprError();
4338 }
4339
4340 auto *ArgOpV = new (Ctx) OpaqueValueExpr(Param->getBeginLoc(), Arg->getType(),
4341 VK_LValue, OK_Ordinary, Arg);
4342
4343 // Parameters are initialized via copy initialization. This allows for
4344 // overload resolution of argument constructors.
4345 InitializedEntity Entity =
4347 ExprResult Res =
4348 SemaRef.PerformCopyInitialization(Entity, Param->getBeginLoc(), ArgOpV);
4349 if (Res.isInvalid())
4350 return ExprError();
4351 Expr *Base = Res.get();
4352 // After the cast, drop the reference type when creating the exprs.
4353 Ty = Ty.getNonLValueExprType(Ctx);
4354 auto *OpV = new (Ctx)
4355 OpaqueValueExpr(Param->getBeginLoc(), Ty, VK_LValue, OK_Ordinary, Base);
4356
4357 // Writebacks are performed with `=` binary operator, which allows for
4358 // overload resolution on writeback result expressions.
4359 Res = SemaRef.ActOnBinOp(SemaRef.getCurScope(), Param->getBeginLoc(),
4360 tok::equal, ArgOpV, OpV);
4361
4362 if (Res.isInvalid())
4363 return ExprError();
4364 Expr *Writeback = Res.get();
4365 auto *OutExpr =
4366 HLSLOutArgExpr::Create(Ctx, Ty, ArgOpV, OpV, Writeback, IsInOut);
4367
4368 return ExprResult(OutExpr);
4369}
4370
4372 // If HLSL gains support for references, all the sites that use this will need
4373 // to be updated with semantic checking to produce errors for
4374 // pointers/references.
4375 assert(!Ty->isReferenceType() &&
4376 "Pointer and reference types cannot be inout or out parameters");
4377 Ty = SemaRef.getASTContext().getLValueReferenceType(Ty);
4378 Ty.addRestrict();
4379 return Ty;
4380}
4381
4382static bool IsDefaultBufferConstantDecl(const ASTContext &Ctx, VarDecl *VD) {
4383 bool IsVulkan =
4384 Ctx.getTargetInfo().getTriple().getOS() == llvm::Triple::Vulkan;
4385 bool IsVKPushConstant = IsVulkan && VD->hasAttr<HLSLVkPushConstantAttr>();
4386 QualType QT = VD->getType();
4387 return VD->getDeclContext()->isTranslationUnit() &&
4388 QT.getAddressSpace() == LangAS::Default &&
4389 VD->getStorageClass() != SC_Static &&
4390 !VD->hasAttr<HLSLVkConstantIdAttr>() && !IsVKPushConstant &&
4392}
4393
4395 // The variable already has an address space (groupshared for ex).
4396 if (Decl->getType().hasAddressSpace())
4397 return;
4398
4399 if (Decl->getType()->isDependentType())
4400 return;
4401
4402 QualType Type = Decl->getType();
4403
4404 if (Decl->hasAttr<HLSLVkExtBuiltinInputAttr>()) {
4405 LangAS ImplAS = LangAS::hlsl_input;
4406 Type = SemaRef.getASTContext().getAddrSpaceQualType(Type, ImplAS);
4407 Decl->setType(Type);
4408 return;
4409 }
4410
4411 bool IsVulkan = getASTContext().getTargetInfo().getTriple().getOS() ==
4412 llvm::Triple::Vulkan;
4413 if (IsVulkan && Decl->hasAttr<HLSLVkPushConstantAttr>()) {
4414 if (HasDeclaredAPushConstant)
4415 SemaRef.Diag(Decl->getLocation(), diag::err_hlsl_push_constant_unique);
4416
4418 Type = SemaRef.getASTContext().getAddrSpaceQualType(Type, ImplAS);
4419 Decl->setType(Type);
4420 HasDeclaredAPushConstant = true;
4421 return;
4422 }
4423
4424 if (Type->isSamplerT() || Type->isVoidType())
4425 return;
4426
4427 // Resource handles.
4429 return;
4430
4431 // Only static globals belong to the Private address space.
4432 // Non-static globals belongs to the cbuffer.
4433 if (Decl->getStorageClass() != SC_Static && !Decl->isStaticDataMember())
4434 return;
4435
4437 Type = SemaRef.getASTContext().getAddrSpaceQualType(Type, ImplAS);
4438 Decl->setType(Type);
4439}
4440
4442 if (VD->hasGlobalStorage()) {
4443 // make sure the declaration has a complete type
4444 if (SemaRef.RequireCompleteType(
4445 VD->getLocation(),
4446 SemaRef.getASTContext().getBaseElementType(VD->getType()),
4447 diag::err_typecheck_decl_incomplete_type)) {
4448 VD->setInvalidDecl();
4450 return;
4451 }
4452
4453 // Global variables outside a cbuffer block that are not a resource, static,
4454 // groupshared, or an empty array or struct belong to the default constant
4455 // buffer $Globals (to be created at the end of the translation unit).
4457 // update address space to hlsl_constant
4460 VD->setType(NewTy);
4461 DefaultCBufferDecls.push_back(VD);
4462 }
4463
4464 // find all resources bindings on decl
4465 if (VD->getType()->isHLSLIntangibleType())
4466 collectResourceBindingsOnVarDecl(VD);
4467
4468 if (VD->hasAttr<HLSLVkConstantIdAttr>())
4470
4472 VD->getStorageClass() != SC_Static) {
4473 // Add internal linkage attribute to non-static resource variables. The
4474 // global externally visible storage is accessed through the handle, which
4475 // is a member. The variable itself is not externally visible.
4476 VD->addAttr(InternalLinkageAttr::CreateImplicit(getASTContext()));
4477 }
4478
4479 // process explicit bindings
4480 processExplicitBindingsOnDecl(VD);
4481
4482 // Add implicit binding attribute to non-static resource arrays.
4483 if (VD->getType()->isHLSLResourceRecordArray() &&
4484 VD->getStorageClass() != SC_Static) {
4485 // If the resource array does not have an explicit binding attribute,
4486 // create an implicit one. It will be used to transfer implicit binding
4487 // order_ID to codegen.
4488 ResourceBindingAttrs Binding(VD);
4489 if (!Binding.isExplicit()) {
4490 uint32_t OrderID = getNextImplicitBindingOrderID();
4491 if (Binding.hasBinding())
4492 Binding.setImplicitOrderID(OrderID);
4493 else {
4496 OrderID);
4497 // Re-create the binding object to pick up the new attribute.
4498 Binding = ResourceBindingAttrs(VD);
4499 }
4500 }
4501
4502 // Get to the base type of a potentially multi-dimensional array.
4504
4505 const CXXRecordDecl *RD = Ty->getAsCXXRecordDecl();
4506 if (hasCounterHandle(RD)) {
4507 if (!Binding.hasCounterImplicitOrderID()) {
4508 uint32_t OrderID = getNextImplicitBindingOrderID();
4509 Binding.setCounterImplicitOrderID(OrderID);
4510 }
4511 }
4512 }
4513 }
4514
4516}
4517
4518bool SemaHLSL::initGlobalResourceDecl(VarDecl *VD) {
4519 assert(VD->getType()->isHLSLResourceRecord() &&
4520 "expected resource record type");
4521
4523 uint64_t UIntTySize = AST.getTypeSize(AST.UnsignedIntTy);
4524 uint64_t IntTySize = AST.getTypeSize(AST.IntTy);
4525
4526 // Gather resource binding attributes.
4527 ResourceBindingAttrs Binding(VD);
4528
4529 // Find correct initialization method and create its arguments.
4530 QualType ResourceTy = VD->getType();
4531 CXXRecordDecl *ResourceDecl = ResourceTy->getAsCXXRecordDecl();
4532 CXXMethodDecl *CreateMethod = nullptr;
4534
4535 bool HasCounter = hasCounterHandle(ResourceDecl);
4536 const char *CreateMethodName;
4537 if (Binding.isExplicit())
4538 CreateMethodName = HasCounter ? "__createFromBindingWithImplicitCounter"
4539 : "__createFromBinding";
4540 else
4541 CreateMethodName = HasCounter
4542 ? "__createFromImplicitBindingWithImplicitCounter"
4543 : "__createFromImplicitBinding";
4544
4545 CreateMethod =
4546 lookupMethod(SemaRef, ResourceDecl, CreateMethodName, VD->getLocation());
4547
4548 if (!CreateMethod)
4549 // This can happen if someone creates a struct that looks like an HLSL
4550 // resource record but does not have the required static create method.
4551 // No binding will be generated for it.
4552 return false;
4553
4554 if (Binding.isExplicit()) {
4555 IntegerLiteral *RegSlot =
4556 IntegerLiteral::Create(AST, llvm::APInt(UIntTySize, Binding.getSlot()),
4558 Args.push_back(RegSlot);
4559 } else {
4560 uint32_t OrderID = (Binding.hasImplicitOrderID())
4561 ? Binding.getImplicitOrderID()
4562 : getNextImplicitBindingOrderID();
4563 IntegerLiteral *OrderId =
4564 IntegerLiteral::Create(AST, llvm::APInt(UIntTySize, OrderID),
4566 Args.push_back(OrderId);
4567 }
4568
4569 IntegerLiteral *Space =
4570 IntegerLiteral::Create(AST, llvm::APInt(UIntTySize, Binding.getSpace()),
4571 AST.UnsignedIntTy, SourceLocation());
4572 Args.push_back(Space);
4573
4574 IntegerLiteral *RangeSize = IntegerLiteral::Create(
4575 AST, llvm::APInt(IntTySize, 1), AST.IntTy, SourceLocation());
4576 Args.push_back(RangeSize);
4577
4578 IntegerLiteral *Index = IntegerLiteral::Create(
4579 AST, llvm::APInt(UIntTySize, 0), AST.UnsignedIntTy, SourceLocation());
4580 Args.push_back(Index);
4581
4582 StringRef VarName = VD->getName();
4583 StringLiteral *Name = StringLiteral::Create(
4584 AST, VarName, StringLiteralKind::Ordinary, false,
4585 AST.getStringLiteralArrayType(AST.CharTy.withConst(), VarName.size()),
4586 SourceLocation());
4587 ImplicitCastExpr *NameCast = ImplicitCastExpr::Create(
4588 AST, AST.getPointerType(AST.CharTy.withConst()), CK_ArrayToPointerDecay,
4589 Name, nullptr, VK_PRValue, FPOptionsOverride());
4590 Args.push_back(NameCast);
4591
4592 if (HasCounter) {
4593 // Will this be in the correct order?
4594 uint32_t CounterOrderID = getNextImplicitBindingOrderID();
4595 IntegerLiteral *CounterId =
4596 IntegerLiteral::Create(AST, llvm::APInt(UIntTySize, CounterOrderID),
4597 AST.UnsignedIntTy, SourceLocation());
4598 Args.push_back(CounterId);
4599 }
4600
4601 // Make sure the create method template is instantiated and emitted.
4602 if (!CreateMethod->isDefined() && CreateMethod->isTemplateInstantiation())
4603 SemaRef.InstantiateFunctionDefinition(VD->getLocation(), CreateMethod,
4604 true);
4605
4606 // Create CallExpr with a call to the static method and set it as the decl
4607 // initialization.
4608 DeclRefExpr *DRE = DeclRefExpr::Create(
4609 AST, NestedNameSpecifierLoc(), SourceLocation(), CreateMethod, false,
4610 CreateMethod->getNameInfo(), CreateMethod->getType(), VK_PRValue);
4611
4612 auto *ImpCast = ImplicitCastExpr::Create(
4613 AST, AST.getPointerType(CreateMethod->getType()),
4614 CK_FunctionToPointerDecay, DRE, nullptr, VK_PRValue, FPOptionsOverride());
4615
4616 CallExpr *InitExpr =
4617 CallExpr::Create(AST, ImpCast, Args, ResourceTy, VK_PRValue,
4618 SourceLocation(), FPOptionsOverride());
4619 VD->setInit(InitExpr);
4621 SemaRef.CheckCompleteVariableDeclaration(VD);
4622 return true;
4623}
4624
4625bool SemaHLSL::initGlobalResourceArrayDecl(VarDecl *VD) {
4626 assert(VD->getType()->isHLSLResourceRecordArray() &&
4627 "expected array of resource records");
4628
4629 // Individual resources in a resource array are not initialized here. They
4630 // are initialized later on during codegen when the individual resources are
4631 // accessed. Codegen will emit a call to the resource initialization method
4632 // with the specified array index. We need to make sure though that the method
4633 // for the specific resource type is instantiated, so codegen can emit a call
4634 // to it when the array element is accessed.
4635
4636 // Find correct initialization method based on the resource binding
4637 // information.
4638 ASTContext &AST = SemaRef.getASTContext();
4639 QualType ResElementTy = AST.getBaseElementType(VD->getType());
4640 CXXRecordDecl *ResourceDecl = ResElementTy->getAsCXXRecordDecl();
4641 CXXMethodDecl *CreateMethod = nullptr;
4642
4643 bool HasCounter = hasCounterHandle(ResourceDecl);
4644 ResourceBindingAttrs ResourceAttrs(VD);
4645 if (ResourceAttrs.isExplicit())
4646 // Resource has explicit binding.
4647 CreateMethod =
4648 lookupMethod(SemaRef, ResourceDecl,
4649 HasCounter ? "__createFromBindingWithImplicitCounter"
4650 : "__createFromBinding",
4651 VD->getLocation());
4652 else
4653 // Resource has implicit binding.
4654 CreateMethod = lookupMethod(
4655 SemaRef, ResourceDecl,
4656 HasCounter ? "__createFromImplicitBindingWithImplicitCounter"
4657 : "__createFromImplicitBinding",
4658 VD->getLocation());
4659
4660 if (!CreateMethod)
4661 return false;
4662
4663 // Make sure the create method template is instantiated and emitted.
4664 if (!CreateMethod->isDefined() && CreateMethod->isTemplateInstantiation())
4665 SemaRef.InstantiateFunctionDefinition(VD->getLocation(), CreateMethod,
4666 true);
4667 return true;
4668}
4669
4670// Returns true if the initialization has been handled.
4671// Returns false to use default initialization.
4673 // Objects in the hlsl_constant address space are initialized
4674 // externally, so don't synthesize an implicit initializer.
4676 return true;
4677
4678 // Initialize non-static resources at the global scope.
4679 if (VD->hasGlobalStorage() && VD->getStorageClass() != SC_Static) {
4680 const Type *Ty = VD->getType().getTypePtr();
4681 if (Ty->isHLSLResourceRecord())
4682 return initGlobalResourceDecl(VD);
4683 if (Ty->isHLSLResourceRecordArray())
4684 return initGlobalResourceArrayDecl(VD);
4685 }
4686 return false;
4687}
4688
4689// Return true if everything is ok; returns false if there was an error.
4691 Expr *RHSExpr, SourceLocation Loc) {
4692 assert((LHSExpr->getType()->isHLSLResourceRecord() ||
4693 LHSExpr->getType()->isHLSLResourceRecordArray()) &&
4694 "expected LHS to be a resource record or array of resource records");
4695 if (Opc != BO_Assign)
4696 return true;
4697
4698 // If LHS is an array subscript, get the underlying declaration.
4699 Expr *E = LHSExpr;
4700 while (auto *ASE = dyn_cast<ArraySubscriptExpr>(E))
4701 E = ASE->getBase()->IgnoreParenImpCasts();
4702
4703 // Report error if LHS is a non-static resource declared at a global scope.
4704 if (DeclRefExpr *DRE = dyn_cast<DeclRefExpr>(E->IgnoreParens())) {
4705 if (VarDecl *VD = dyn_cast<VarDecl>(DRE->getDecl())) {
4706 if (VD->hasGlobalStorage() && VD->getStorageClass() != SC_Static) {
4707 // assignment to global resource is not allowed
4708 SemaRef.Diag(Loc, diag::err_hlsl_assign_to_global_resource) << VD;
4709 SemaRef.Diag(VD->getLocation(), diag::note_var_declared_here) << VD;
4710 return false;
4711 }
4712 }
4713 }
4714 return true;
4715}
4716
4717// Walks through the global variable declaration, collects all resource binding
4718// requirements and adds them to Bindings
4719void SemaHLSL::collectResourceBindingsOnVarDecl(VarDecl *VD) {
4720 assert(VD->hasGlobalStorage() && VD->getType()->isHLSLIntangibleType() &&
4721 "expected global variable that contains HLSL resource");
4722
4723 // Cbuffers and Tbuffers are HLSLBufferDecl types
4724 if (const HLSLBufferDecl *CBufferOrTBuffer = dyn_cast<HLSLBufferDecl>(VD)) {
4725 Bindings.addDeclBindingInfo(VD, CBufferOrTBuffer->isCBuffer()
4726 ? ResourceClass::CBuffer
4727 : ResourceClass::SRV);
4728 return;
4729 }
4730
4731 // Unwrap arrays
4732 // FIXME: Calculate array size while unwrapping
4733 const Type *Ty = VD->getType()->getUnqualifiedDesugaredType();
4734 while (Ty->isArrayType()) {
4735 const ArrayType *AT = cast<ArrayType>(Ty);
4737 }
4738
4739 // Resource (or array of resources)
4740 if (const HLSLAttributedResourceType *AttrResType =
4741 HLSLAttributedResourceType::findHandleTypeOnResource(Ty)) {
4742 Bindings.addDeclBindingInfo(VD, AttrResType->getAttrs().ResourceClass);
4743 return;
4744 }
4745
4746 // User defined record type
4747 if (const RecordType *RT = dyn_cast<RecordType>(Ty))
4748 collectResourceBindingsOnUserRecordDecl(VD, RT);
4749}
4750
4751// Walks through the explicit resource binding attributes on the declaration,
4752// makes sure there is a resource that matches each binding, and updates
4753// DeclBindingInfoLists.
//
// For every HLSLResourceBindingAttr that has a register slot, the matching
// DeclBindingInfo (looked up by resource class) is marked as an explicit
// binding. If no such entry exists, warn_hlsl_user_defined_type_missing_member
// is emitted instead. Constant-register (`c`) annotations are never matched;
// they only warn when the declaration has binding info recorded.
// An HLSLVkBindingAttr also counts as an explicit binding. It is diagnosed
// (err_hlsl_attr_incompatible) when an HLSLVkPushConstantAttr is also present.
// A resource (or array of resources) left with no explicit binding at all
// gets warn_hlsl_implicit_binding.
4754void SemaHLSL::processExplicitBindingsOnDecl(VarDecl *VD) {
4755 assert(VD->hasGlobalStorage() && "expected global variable");
4756
4757 bool HasBinding = false;
4758 for (Attr *A : VD->attrs()) {
4759 if (isa<HLSLVkBindingAttr>(A)) {
4760 HasBinding = true;
4761 if (auto PA = VD->getAttr<HLSLVkPushConstantAttr>())
4762 Diag(PA->getLoc(), diag::err_hlsl_attr_incompatible) << A << PA;
4763 }
4764
4765 HLSLResourceBindingAttr *RBA = dyn_cast<HLSLResourceBindingAttr>(A);
4766 if (!RBA || !RBA->hasRegisterSlot())
4767 continue;
4768 HasBinding = true;
4769
4770 RegisterType RT = RBA->getRegisterType();
4771 assert(RT != RegisterType::I && "invalid or obsolete register type should "
4772 "never have an attribute created");
4773
// Constant-register bindings are not matched against DeclBindingInfo; they
// only warn when the declaration also has resource binding info.
4774 if (RT == RegisterType::C) {
4775 if (Bindings.hasBindingInfoForDecl(VD))
4776 SemaRef.Diag(VD->getLocation(),
4777 diag::warn_hlsl_user_defined_type_missing_member)
4778 << static_cast<int>(RT);
4779 continue;
4780 }
4781
4782 // Find DeclBindingInfo for this binding and update it, or report error
4783 // if it does not exist (user type does not contain resources with the
4784 // expected resource class).
4786 if (DeclBindingInfo *BI = Bindings.getDeclBindingInfo(VD, RC)) {
4787 // update binding info
4788 BI->setBindingAttribute(RBA, BindingType::Explicit);
4789 } else {
4790 SemaRef.Diag(VD->getLocation(),
4791 diag::warn_hlsl_user_defined_type_missing_member)
4792 << static_cast<int>(RT);
4793 }
4794 }
4795
4796 if (!HasBinding && isResourceRecordTypeOrArrayOf(VD))
4797 SemaRef.Diag(VD->getLocation(), diag::warn_hlsl_implicit_binding);
4798}
4799namespace {
// Flattens an HLSL initializer list into leaf initializers (scalars and
// non-aggregate records), then rebuilds it as nested InitListExprs shaped
// like the destination type.
//
// Usage:
//  1. Construct the transformer with the entity being initialized.
//  2. Call buildInitializerList() for each top-level initializer. Each leaf is
//     copy-initialized to the next entry of DestTypes and appended to
//     ArgExprs.
//  3. Call generateInitLists() to regroup ArgExprs into InitListExprs that
//     follow the structure of the destination type.
4800class InitListTransformer {
4801 Sema &S;
4802 ASTContext &Ctx;
4803 QualType InitTy;
4804 QualType *DstIt = nullptr;
4805 Expr **ArgIt = nullptr;
4806 // Is wrapping the destination type iterator required? This is only used for
4807 // incomplete array types where we loop over the destination type since we
4808 // don't know the full number of elements from the declaration.
4809 bool Wrap;
4810
// Copy-initializes E to the current destination type, appends the result to
// ArgExprs and advances DstIt. Once every destination type has been consumed,
// behavior depends on Wrap:
//  - Wrap unset: E is appended unconverted, so the caller's count check can
//    report the mismatch.
//  - Wrap set: DstIt restarts at the first destination type.
// Returns false only when the conversion itself fails.
4811 bool castInitializer(Expr *E) {
4812 assert(DstIt && "This should always be something!");
4813 if (DstIt == DestTypes.end()) {
4814 if (!Wrap) {
4815 ArgExprs.push_back(E);
4816 // This is odd, but it isn't technically a failure due to conversion, we
4817 // handle mismatched counts of arguments differently.
4818 return true;
4819 }
4820 DstIt = DestTypes.begin();
4821 }
4822 InitializedEntity Entity = InitializedEntity::InitializeParameter(
4823 Ctx, *DstIt, /* Consumed (ObjC) */ false);
4824 ExprResult Res = S.PerformCopyInitialization(Entity, E->getBeginLoc(), E);
4825 if (Res.isInvalid())
4826 return false;
4827 Expr *Init = Res.get();
4828 ArgExprs.push_back(Init);
4829 DstIt++;
4830 return true;
4831 }
4832
// Recursively decomposes E into leaf initializers:
//  - nested init lists are walked;
//  - scalars and non-aggregate records are converted directly;
//  - vectors and constant matrices are split element by element;
//  - constant arrays are split into subscripts;
//  - aggregate records are split into their named fields, base class first.
// The same E is referenced by every element/member access built here.
// Returns false if building or converting any sub-expression fails.
4833 bool buildInitializerListImpl(Expr *E) {
4834 // If this is an initialization list, traverse the sub initializers.
4835 if (auto *Init = dyn_cast<InitListExpr>(E)) {
4836 for (auto *SubInit : Init->inits())
4837 if (!buildInitializerListImpl(SubInit))
4838 return false;
4839 return true;
4840 }
4841
4842 // Scalars and non-aggregate records are leaves: convert and enqueue them.
4843 QualType Ty = E->getType();
4844
4845 if (Ty->isScalarType() || (Ty->isRecordType() && !Ty->isAggregateType()))
4846 return castInitializer(E);
4847
4848 if (auto *VecTy = Ty->getAs<VectorType>()) {
4849 uint64_t Size = VecTy->getNumElements();
4850
4851 QualType SizeTy = Ctx.getSizeType();
4852 uint64_t SizeTySize = Ctx.getTypeSize(SizeTy);
4853 for (uint64_t I = 0; I < Size; ++I) {
4854 auto *Idx = IntegerLiteral::Create(Ctx, llvm::APInt(SizeTySize, I),
4855 SizeTy, SourceLocation());
4856
4858 E, E->getBeginLoc(), Idx, E->getEndLoc());
4859 if (ElExpr.isInvalid())
4860 return false;
4861 if (!castInitializer(ElExpr.get()))
4862 return false;
4863 }
4864 return true;
4865 }
// Matrices are flattened row by row (R outer, C inner).
4866 if (auto *MTy = Ty->getAs<ConstantMatrixType>()) {
4867 unsigned Rows = MTy->getNumRows();
4868 unsigned Cols = MTy->getNumColumns();
4869 QualType ElemTy = MTy->getElementType();
4870
4871 for (unsigned R = 0; R < Rows; ++R) {
4872 for (unsigned C = 0; C < Cols; ++C) {
4873 // row index literal
4874 Expr *RowIdx = IntegerLiteral::Create(
4875 Ctx, llvm::APInt(Ctx.getIntWidth(Ctx.IntTy), R), Ctx.IntTy,
4876 E->getBeginLoc());
4877 // column index literal
4878 Expr *ColIdx = IntegerLiteral::Create(
4879 Ctx, llvm::APInt(Ctx.getIntWidth(Ctx.IntTy), C), Ctx.IntTy,
4880 E->getBeginLoc());
4882 E, RowIdx, ColIdx, E->getEndLoc());
4883 if (ElExpr.isInvalid())
4884 return false;
4885 if (!castInitializer(ElExpr.get()))
4886 return false;
// NOTE(review): the type is overwritten after castInitializer has already
// converted ElExpr; confirm whether this should happen before the cast.
4887 ElExpr.get()->setType(ElemTy);
4888 }
4889 }
4890 return true;
4891 }
4892
4893 if (auto *ArrTy = dyn_cast<ConstantArrayType>(Ty.getTypePtr())) {
4894 uint64_t Size = ArrTy->getZExtSize();
4895 QualType SizeTy = Ctx.getSizeType();
4896 uint64_t SizeTySize = Ctx.getTypeSize(SizeTy);
4897 for (uint64_t I = 0; I < Size; ++I) {
4898 auto *Idx = IntegerLiteral::Create(Ctx, llvm::APInt(SizeTySize, I),
4899 SizeTy, SourceLocation());
4901 E, E->getBeginLoc(), Idx, E->getEndLoc());
4902 if (ElExpr.isInvalid())
4903 return false;
4904 if (!buildInitializerListImpl(ElExpr.get()))
4905 return false;
4906 }
4907 return true;
4908 }
4909
4910 if (auto *RD = Ty->getAsCXXRecordDecl()) {
// Collect the single-inheritance chain (derived first), then pop it so the
// most-base class's fields are visited first.
4911 llvm::SmallVector<CXXRecordDecl *> RecordDecls;
4912 RecordDecls.push_back(RD);
4913 while (RecordDecls.back()->getNumBases()) {
4914 CXXRecordDecl *D = RecordDecls.back();
4915 assert(D->getNumBases() == 1 &&
4916 "HLSL doesn't support multiple inheritance");
4917 RecordDecls.push_back(
4919 }
4920 while (!RecordDecls.empty()) {
4921 CXXRecordDecl *RD = RecordDecls.pop_back_val();
4922 for (auto *FD : RD->fields()) {
4923 if (FD->isUnnamedBitField())
4924 continue;
4925 DeclAccessPair Found = DeclAccessPair::make(FD, FD->getAccess());
4926 DeclarationNameInfo NameInfo(FD->getDeclName(), E->getBeginLoc());
4928 E, false, E->getBeginLoc(), CXXScopeSpec(), FD, Found, NameInfo);
4929 if (Res.isInvalid())
4930 return false;
4931 if (!buildInitializerListImpl(Res.get()))
4932 return false;
4933 }
4934 }
4935 }
4936 return true;
4937 }
4938
// Rebuilds an initializer of type Ty from ArgExprs, consuming leaf
// expressions in order through ArgIt.
//  - Leaves (scalars and non-aggregate records) are returned directly.
//  - Vectors, constant matrices, constant arrays and records become an
//    InitListExpr of the desugared Ty, with elements built recursively.
//  - Record fields are visited base class first; unnamed bit-fields are
//    skipped.
// NOTE(review): Inits.front()/back() assume at least one element was
// produced; a record with no named fields would violate that — confirm such
// types cannot reach here.
4939 Expr *generateInitListsImpl(QualType Ty) {
4940 assert(ArgIt != ArgExprs.end() && "Something is off in iteration!");
4941 if (Ty->isScalarType() || (Ty->isRecordType() && !Ty->isAggregateType()))
4942 return *(ArgIt++);
4943
4944 llvm::SmallVector<Expr *> Inits;
4945 Ty = Ty.getDesugaredType(Ctx);
4946 if (Ty->isVectorType() || Ty->isConstantArrayType() ||
4947 Ty->isConstantMatrixType()) {
4948 QualType ElTy;
4949 uint64_t Size = 0;
4950 if (auto *ATy = Ty->getAs<VectorType>()) {
4951 ElTy = ATy->getElementType();
4952 Size = ATy->getNumElements();
4953 } else if (auto *CMTy = Ty->getAs<ConstantMatrixType>()) {
4954 ElTy = CMTy->getElementType();
4955 Size = CMTy->getNumElementsFlattened();
4956 } else {
4957 auto *VTy = cast<ConstantArrayType>(Ty.getTypePtr());
4958 ElTy = VTy->getElementType();
4959 Size = VTy->getZExtSize();
4960 }
4961 for (uint64_t I = 0; I < Size; ++I)
4962 Inits.push_back(generateInitListsImpl(ElTy));
4963 }
4964 if (auto *RD = Ty->getAsCXXRecordDecl()) {
4965 llvm::SmallVector<CXXRecordDecl *> RecordDecls;
4966 RecordDecls.push_back(RD);
4967 while (RecordDecls.back()->getNumBases()) {
4968 CXXRecordDecl *D = RecordDecls.back();
4969 assert(D->getNumBases() == 1 &&
4970 "HLSL doesn't support multiple inheritance");
4971 RecordDecls.push_back(
4973 }
4974 while (!RecordDecls.empty()) {
4975 CXXRecordDecl *RD = RecordDecls.pop_back_val();
4976 for (auto *FD : RD->fields())
4977 if (!FD->isUnnamedBitField())
4978 Inits.push_back(generateInitListsImpl(FD->getType()));
4979 }
4980 }
4981 auto *NewInit = new (Ctx) InitListExpr(Ctx, Inits.front()->getBeginLoc(),
4982 Inits, Inits.back()->getEndLoc());
4983 NewInit->setType(Ty);
4984 return NewInit;
4985 }
4986
4987public:
// Flattened leaf types of the destination type. When Wrap is set this covers
// a single array element.
4988 llvm::SmallVector<QualType, 16> DestTypes;
// Converted leaf initializers, in order, collected by buildInitializerList().
4989 llvm::SmallVector<Expr *, 16> ArgExprs;
4990 InitListTransformer(Sema &SemaRef, const InitializedEntity &Entity)
4991 : S(SemaRef), Ctx(SemaRef.getASTContext()),
4992 Wrap(Entity.getType()->isIncompleteArrayType()) {
4993 InitTy = Entity.getType().getNonReferenceType();
4994 // When we're generating initializer lists for incomplete array types we
4995 // need to wrap around both when building the initializers and when
4996 // generating the final initializer lists.
4997 if (Wrap) {
4998 assert(InitTy->isIncompleteArrayType());
4999 const IncompleteArrayType *IAT = Ctx.getAsIncompleteArrayType(InitTy);
5000 InitTy = IAT->getElementType();
5001 }
5002 BuildFlattenedTypeList(InitTy, DestTypes);
5003 DstIt = DestTypes.begin();
5004 }
5005
5006 bool buildInitializerList(Expr *E) { return buildInitializerListImpl(E); }
5007
// Regroups ArgExprs into the final initializer. For an incomplete-array
// destination, element initializers are generated until ArgExprs is
// exhausted. They are then wrapped in a list typed as a constant array with
// that many elements.
5008 Expr *generateInitLists() {
5009 assert(!ArgExprs.empty() &&
5010 "Call buildInitializerList to generate argument expressions.");
5011 ArgIt = ArgExprs.begin();
5012 if (!Wrap)
5013 return generateInitListsImpl(InitTy);
5014 llvm::SmallVector<Expr *> Inits;
5015 while (ArgIt != ArgExprs.end())
5016 Inits.push_back(generateInitListsImpl(InitTy));
5017
5018 auto *NewInit = new (Ctx) InitListExpr(Ctx, Inits.front()->getBeginLoc(),
5019 Inits, Inits.back()->getEndLoc());
5020 llvm::APInt ArySize(64, Inits.size());
5021 NewInit->setType(Ctx.getConstantArrayType(InitTy, ArySize, nullptr,
5022 ArraySizeModifier::Normal, 0));
5023 return NewInit;
5024 }
5025};
5026} // namespace
5027
5028// Recursively detect any incomplete array anywhere in the type graph,
5029// including arrays, struct fields, and base classes.
// Returns true if Ty is, or transitively contains, an incomplete array type.
// Ty is canonicalized first, so typedefs and other sugar are looked through.
5031 Ty = Ty.getCanonicalType();
5032
5033 // Array types
5034 if (const ArrayType *AT = dyn_cast<ArrayType>(Ty)) {
5036 return true;
5038 }
5039
5040 // Record (struct/class) types
5041 if (const auto *RT = Ty->getAs<RecordType>()) {
5042 const RecordDecl *RD = RT->getDecl();
5043
5044 // Walk base classes (for C++ / HLSL structs with inheritance)
5045 if (const auto *CXXRD = dyn_cast<CXXRecordDecl>(RD)) {
5046 for (const CXXBaseSpecifier &Base : CXXRD->bases()) {
5047 if (containsIncompleteArrayType(Base.getType()))
5048 return true;
5049 }
5050 }
5051
5052 // Walk fields
5053 for (const FieldDecl *F : RD->fields()) {
5054 if (containsIncompleteArrayType(F->getType()))
5055 return true;
5056 }
5057 }
5058
5059 return false;
5060}
5061
5063 InitListExpr *Init) {
// Flattens Init against the flattened leaf types of Entity and checks that
// the number of leaf initializers matches. On success, Init is rewritten in
// place with nested init lists shaped like the destination type. Returns
// false after a diagnostic if a conversion fails or the counts disagree.
5064 // If the initializer is a scalar there is nothing to transform.
5065 if (Init->getType()->isScalarType())
5066 return true;
5067 ASTContext &Ctx = SemaRef.getASTContext();
5068 InitListTransformer ILT(SemaRef, Entity);
5069
5070 for (unsigned I = 0; I < Init->getNumInits(); ++I) {
5071 Expr *E = Init->getInit(I);
// The transformer builds one element/member access per flattened leaf on
// top of the same E. Side-effecting initializers are therefore wrapped in an
// OpaqueValueExpr (materialized first for records) so those accesses share
// a single evaluation.
5072 if (E->HasSideEffects(Ctx)) {
5073 QualType Ty = E->getType();
5074 if (Ty->isRecordType())
5075 E = new (Ctx) MaterializeTemporaryExpr(Ty, E, E->isLValue());
5076 E = new (Ctx) OpaqueValueExpr(E->getBeginLoc(), Ty, E->getValueKind(),
5077 E->getObjectKind(), E);
5078 Init->setInit(I, E);
5079 }
5080 if (!ILT.buildInitializerList(E))
5081 return false;
5082 }
5083 size_t ExpectedSize = ILT.DestTypes.size();
5084 size_t ActualSize = ILT.ArgExprs.size();
5085 if (ExpectedSize == 0 && ActualSize == 0)
5086 return true;
5087
5088 // Reject empty initializer if *any* incomplete array exists structurally
// (ActualSize is 0 here, so TooManyOrFew always reports "too few").
5089 if (ActualSize == 0 && containsIncompleteArrayType(Entity.getType())) {
5090 QualType InitTy = Entity.getType().getNonReferenceType();
5091 if (InitTy.hasAddressSpace())
5092 InitTy = SemaRef.getASTContext().removeAddrSpaceQualType(InitTy);
5093
5094 SemaRef.Diag(Init->getBeginLoc(), diag::err_hlsl_incorrect_num_initializers)
5095 << /*TooManyOrFew=*/(int)(ExpectedSize < ActualSize) << InitTy
5096 << /*ExpectedSize=*/ExpectedSize << /*ActualSize=*/ActualSize;
5097 return false;
5098 }
5099
5100 // We infer size after validating legality.
5101 // For incomplete arrays it is completely arbitrary to choose whether we think
5102 // the user intended fewer or more elements. This implementation assumes that
5103 // the user intended more, and errors that there are too few initializers to
5104 // complete the final element.
// Concretely, ExpectedSize (leaves per element) is rounded ActualSize up to
// a whole number of elements.
5105 if (Entity.getType()->isIncompleteArrayType()) {
5106 assert(ExpectedSize > 0 &&
5107 "The expected size of an incomplete array type must be at least 1.");
5108 ExpectedSize =
5109 ((ActualSize + ExpectedSize - 1) / ExpectedSize) * ExpectedSize;
5110 }
5111
5112 // An initializer list might be attempting to initialize a reference or
5113 // rvalue-reference. When checking the initializer we should look through
5114 // the reference.
5115 QualType InitTy = Entity.getType().getNonReferenceType();
5116 if (InitTy.hasAddressSpace())
5117 InitTy = SemaRef.getASTContext().removeAddrSpaceQualType(InitTy);
5118 if (ExpectedSize != ActualSize) {
5119 int TooManyOrFew = ActualSize > ExpectedSize ? 1 : 0;
5120 SemaRef.Diag(Init->getBeginLoc(), diag::err_hlsl_incorrect_num_initializers)
5121 << TooManyOrFew << InitTy << ExpectedSize << ActualSize;
5122 return false;
5123 }
5124
5125 // generateInitListsImpl will always return an InitListExpr here, because the
5126 // scalar case is handled above.
5127 auto *NewInit = cast<InitListExpr>(ILT.generateInitLists());
5128 Init->resizeInits(Ctx, NewInit->getNumInits());
5129 for (unsigned I = 0; I < NewInit->getNumInits(); ++I)
5130 Init->updateInit(Ctx, I, NewInit->getInit(I));
5131 return true;
5132}
5133
5134static QualType ReportMatrixInvalidMember(Sema &S, StringRef Name,
5135 StringRef Expected,
5136 SourceLocation OpLoc,
5137 SourceLocation CompLoc) {
5138 S.Diag(OpLoc, diag::err_builtin_matrix_invalid_member)
5139 << Name << Expected << SourceRange(CompLoc);
5140 return QualType();
5141}
5142
5145 const IdentifierInfo *CompName,
5146 SourceLocation CompLoc) {
// Type-checks an HLSL matrix member accessor. Two forms are accepted:
//  - zero-based, 4-character chunks, e.g. `_m00_m11`;
//  - one-based, 3-character chunks, e.g. `_11_22`.
// Each chunk selects one element of the matrix, and 1 to 4 chunks are
// allowed. Return values:
//  - the element type for a single chunk;
//  - an ext-vector of the element type for several chunks (sugared with an
//    existing ext-vector typedef when one matches);
//  - a null QualType after a diagnostic.
// VK is forced to VK_PRValue when any element is selected more than once.
5147 const auto *MT = baseType->castAs<ConstantMatrixType>();
5148 StringRef AccessorName = CompName->getName();
5149 assert(!AccessorName.empty() && "Matrix Accessor must have a name");
5150
5151 unsigned Rows = MT->getNumRows();
5152 unsigned Cols = MT->getNumColumns();
5153 bool IsZeroBasedAccessor = false;
5154 unsigned ChunkLen = 0;
5155 if (AccessorName.size() < 2)
5156 return ReportMatrixInvalidMember(S, AccessorName,
5157 "length 4 for zero based: \'_mRC\' or "
5158 "length 3 for one-based: \'_RC\' accessor",
5159 OpLoc, CompLoc);
5160
// The first chunk's prefix picks the accessor form for the whole name.
5161 if (AccessorName[0] == '_') {
5162 if (AccessorName[1] == 'm') {
5163 IsZeroBasedAccessor = true;
5164 ChunkLen = 4; // zero-based: "_mRC"
5165 } else {
5166 ChunkLen = 3; // one-based: "_RC"
5167 }
5168 } else
5170 S, AccessorName, "zero based: \'_mRC\' or one-based: \'_RC\' accessor",
5171 OpLoc, CompLoc);
5172
5173 if (AccessorName.size() % ChunkLen != 0) {
5174 const llvm::StringRef Expected = IsZeroBasedAccessor
5175 ? "zero based: '_mRC' accessor"
5176 : "one-based: '_RC' accessor";
5177
5178 return ReportMatrixInvalidMember(S, AccessorName, Expected, OpLoc, CompLoc);
5179 }
5180
5181 auto isDigit = [](char c) { return c >= '0' && c <= '9'; };
5182 auto isZeroBasedIndex = [](unsigned i) { return i <= 3; };
5183 auto isOneBasedIndex = [](unsigned i) { return i >= 1 && i <= 4; };
5184
5185 bool HasRepeated = false;
5186 SmallVector<bool, 16> Seen(Rows * Cols, false);
5187 unsigned NumComponents = 0;
5188 const char *Begin = AccessorName.data();
5189
// Validate one chunk per iteration. Every chunk must use the form chosen
// above; the size check guarantees each chunk is complete.
5190 for (unsigned I = 0, E = AccessorName.size(); I < E; I += ChunkLen) {
5191 const char *Chunk = Begin + I;
5192 char RowChar = 0, ColChar = 0;
5193 if (IsZeroBasedAccessor) {
5194 // Zero-based: "_mRC"
5195 if (Chunk[0] != '_' || Chunk[1] != 'm') {
5196 char Bad = (Chunk[0] != '_') ? Chunk[0] : Chunk[1];
5198 S, StringRef(&Bad, 1), "\'_m\' prefix",
5199 OpLoc.getLocWithOffset(I + (Bad == Chunk[0] ? 1 : 2)), CompLoc);
5200 }
5201 RowChar = Chunk[2];
5202 ColChar = Chunk[3];
5203 } else {
5204 // One-based: "_RC"
5205 if (Chunk[0] != '_')
5207 S, StringRef(&Chunk[0], 1), "\'_\' prefix",
5208 OpLoc.getLocWithOffset(I + 1), CompLoc);
5209 RowChar = Chunk[1];
5210 ColChar = Chunk[2];
5211 }
5212
5213 // Must be digits.
5214 bool IsDigitsError = false;
5215 if (!isDigit(RowChar)) {
5216 unsigned BadPos = IsZeroBasedAccessor ? 2 : 1;
5217 ReportMatrixInvalidMember(S, StringRef(&RowChar, 1), "row as integer",
5218 OpLoc.getLocWithOffset(I + BadPos + 1),
5219 CompLoc);
5220 IsDigitsError = true;
5221 }
5222
5223 if (!isDigit(ColChar)) {
5224 unsigned BadPos = IsZeroBasedAccessor ? 3 : 2;
5225 ReportMatrixInvalidMember(S, StringRef(&ColChar, 1), "column as integer",
5226 OpLoc.getLocWithOffset(I + BadPos + 1),
5227 CompLoc);
5228 IsDigitsError = true;
5229 }
5230 if (IsDigitsError)
5231 return QualType();
5232
5233 unsigned Row = RowChar - '0';
5234 unsigned Col = ColChar - '0';
5235
5236 bool HasIndexingError = false;
5237 if (IsZeroBasedAccessor) {
5238 // 0-based [0..3]
5239 if (!isZeroBasedIndex(Row)) {
5240 S.Diag(OpLoc, diag::err_hlsl_matrix_element_not_in_bounds)
5241 << /*row*/ 0 << /*zero-based*/ 0 << SourceRange(CompLoc);
5242 HasIndexingError = true;
5243 }
5244 if (!isZeroBasedIndex(Col)) {
5245 S.Diag(OpLoc, diag::err_hlsl_matrix_element_not_in_bounds)
5246 << /*col*/ 1 << /*zero-based*/ 0 << SourceRange(CompLoc);
5247 HasIndexingError = true;
5248 }
5249 } else {
5250 // 1-based [1..4]
5251 if (!isOneBasedIndex(Row)) {
5252 S.Diag(OpLoc, diag::err_hlsl_matrix_element_not_in_bounds)
5253 << /*row*/ 0 << /*one-based*/ 1 << SourceRange(CompLoc);
5254 HasIndexingError = true;
5255 }
5256 if (!isOneBasedIndex(Col)) {
5257 S.Diag(OpLoc, diag::err_hlsl_matrix_element_not_in_bounds)
5258 << /*col*/ 1 << /*one-based*/ 1 << SourceRange(CompLoc);
5259 HasIndexingError = true;
5260 }
5261 // Convert to 0-based after range checking.
5262 --Row;
5263 --Col;
5264 }
5265
5266 if (HasIndexingError)
5267 return QualType();
5268
5269 // Note: matrix swizzle index is hard coded. That means Row and Col can
5270 // potentially be larger than Rows and Cols if matrix size is less than
5271 // the max index size.
// NOTE(review): for one-based accessors Row/Col were already decremented
// above, so the diagnostics below print the zero-based index (e.g. `_44` on
// a 3x3 matrix reports 3). Confirm this is the intended wording.
5272 bool HasBoundsError = false;
5273 if (Row >= Rows) {
5274 Diag(OpLoc, diag::err_hlsl_matrix_index_out_of_bounds)
5275 << /*Row*/ 0 << Row << Rows << SourceRange(CompLoc);
5276 HasBoundsError = true;
5277 }
5278 if (Col >= Cols) {
5279 Diag(OpLoc, diag::err_hlsl_matrix_index_out_of_bounds)
5280 << /*Col*/ 1 << Col << Cols << SourceRange(CompLoc);
5281 HasBoundsError = true;
5282 }
5283 if (HasBoundsError)
5284 return QualType();
5285
// Track selected elements (row-major flat index) to detect repeats.
5286 unsigned FlatIndex = Row * Cols + Col;
5287 if (Seen[FlatIndex])
5288 HasRepeated = true;
5289 Seen[FlatIndex] = true;
5290 ++NumComponents;
5291 }
5292 if (NumComponents == 0 || NumComponents > 4) {
5293 S.Diag(OpLoc, diag::err_hlsl_matrix_swizzle_invalid_length)
5294 << NumComponents << SourceRange(CompLoc);
5295 return QualType();
5296 }
5297
5298 QualType ElemTy = MT->getElementType();
5299 if (NumComponents == 1)
5300 return ElemTy;
5301 QualType VT = S.Context.getExtVectorType(ElemTy, NumComponents);
// A swizzle naming the same element twice cannot be assigned through, so
// the result is downgraded to a prvalue.
5302 if (HasRepeated)
5303 VK = VK_PRValue;
5304
// Prefer an already-declared ext-vector typedef whose underlying type is VT,
// presumably so diagnostics print the user-visible typedef name.
5305 for (Sema::ExtVectorDeclsType::iterator
5307 E = S.ExtVectorDecls.end();
5308 I != E; ++I) {
5309 if ((*I)->getUnderlyingType() == VT)
5311 /*Qualifier=*/std::nullopt, *I);
5312 }
5313
5314 return VT;
5315}
5316
// Rewrites the initializer of a specialization constant.
//  - Declarations without HLSLVkConstantIdAttr are left untouched (returns
//    true).
//  - If Init is not a constant expression, err_specialization_const is
//    emitted, VDecl is marked invalid, and false is returned.
//  - Otherwise Init is replaced by a call to a spec-constant builtin (BID is
//    presumably chosen from Init's type) taking the attribute's ID and the
//    original initializer. When the call's canonical unqualified type
//    differs, the call is wrapped in a C-style cast to Init's type.
// NOTE(review): the result of BuildCStyleCastExpr(...).get() is used without
// checking for an invalid ExprResult.
5318 const HLSLVkConstantIdAttr *ConstIdAttr =
5319 VDecl->getAttr<HLSLVkConstantIdAttr>();
5320 if (!ConstIdAttr)
5321 return true;
5322
5323 ASTContext &Context = SemaRef.getASTContext();
5324
5325 APValue InitValue;
5326 if (!Init->isCXX11ConstantExpr(Context, &InitValue)) {
5327 Diag(VDecl->getLocation(), diag::err_specialization_const);
5328 VDecl->setInvalidDecl();
5329 return false;
5330 }
5331
5332 Builtin::ID BID =
5334
5335 // Argument 1: The ID from the attribute
5336 int ConstantID = ConstIdAttr->getId();
5337 llvm::APInt IDVal(Context.getIntWidth(Context.IntTy), ConstantID);
5338 Expr *IdExpr = IntegerLiteral::Create(Context, IDVal, Context.IntTy,
5339 ConstIdAttr->getLocation());
5340
// Argument 2: the original initializer expression.
5341 SmallVector<Expr *, 2> Args = {IdExpr, Init};
5342 Expr *C = SemaRef.BuildBuiltinCallExpr(Init->getExprLoc(), BID, Args);
5343 if (C->getType()->getCanonicalTypeUnqualified() !=
5345 C = SemaRef
5346 .BuildCStyleCastExpr(SourceLocation(),
5347 Context.getTrivialTypeSourceInfo(
5348 Init->getType(), Init->getExprLoc()),
5349 SourceLocation(), C)
5350 .get();
5351 }
5352 Init = C;
5353 return true;
5354}
Defines the clang::ASTContext interface.
Defines enum values for all the target-independent builtin functions.
llvm::dxil::ResourceClass ResourceClass
Defines the C++ Decl subclasses, other than those for templates (found in DeclTemplate....
TokenType getType() const
Returns the token's type, e.g.
FormatToken * Previous
The previous token in the unwrapped line.
Defines the clang::IdentifierInfo, clang::IdentifierTable, and clang::Selector interfaces.
#define X(type, name)
Definition Value.h:97
Forward-declares and imports various common LLVM datatypes that clang wants to use unqualified.
llvm::SmallVector< std::pair< const MemRegion *, SVal >, 4 > Bindings
static bool CheckArgTypeMatches(Sema *S, Expr *Arg, QualType ExpectedType)
static void BuildFlattenedTypeList(QualType BaseTy, llvm::SmallVectorImpl< QualType > &List)
static bool CheckUnsignedIntRepresentation(Sema *S, SourceLocation Loc, int ArgOrdinal, clang::QualType PassedType)
static bool containsIncompleteArrayType(QualType Ty)
static QualType handleIntegerVectorBinOpConversion(Sema &SemaRef, ExprResult &LHS, ExprResult &RHS, QualType LHSType, QualType RHSType, QualType LElTy, QualType RElTy, bool IsCompAssign)
static bool convertToRegisterType(StringRef Slot, RegisterType *RT)
Definition SemaHLSL.cpp:82
static bool CheckWaveActive(Sema *S, CallExpr *TheCall)
static void castVector(Sema &S, ExprResult &E, QualType &Ty, unsigned Sz)
static QualType ReportMatrixInvalidMember(Sema &S, StringRef Name, StringRef Expected, SourceLocation OpLoc, SourceLocation CompLoc)
static bool CheckBoolSelect(Sema *S, CallExpr *TheCall)
static unsigned calculateLegacyCbufferFieldAlign(const ASTContext &Context, QualType T)
Definition SemaHLSL.cpp:219
static bool isZeroSizedArray(const ConstantArrayType *CAT)
Definition SemaHLSL.cpp:338
static bool DiagnoseHLSLRegisterAttribute(Sema &S, SourceLocation &ArgLoc, Decl *D, RegisterType RegType, bool SpecifiedSpace)
static FieldDecl * createFieldForHostLayoutStruct(Sema &S, const Type *Ty, IdentifierInfo *II, CXXRecordDecl *LayoutStruct)
Definition SemaHLSL.cpp:457
static bool CheckUnsignedIntVecRepresentation(Sema *S, SourceLocation Loc, int ArgOrdinal, clang::QualType PassedType)
SampleKind
static bool isInvalidConstantBufferLeafElementType(const Type *Ty)
Definition SemaHLSL.cpp:364
static Builtin::ID getSpecConstBuiltinId(const Type *Type)
Definition SemaHLSL.cpp:132
static bool CheckFloatingOrIntRepresentation(Sema *S, SourceLocation Loc, int ArgOrdinal, clang::QualType PassedType)
static bool CheckAnyScalarOrVector(Sema *S, CallExpr *TheCall, unsigned ArgIndex)
static IdentifierInfo * getHostLayoutStructName(Sema &S, NamedDecl *BaseDecl, bool MustBeUnique)
Definition SemaHLSL.cpp:420
static void addImplicitBindingAttrToDecl(Sema &S, Decl *D, RegisterType RT, uint32_t ImplicitBindingOrderID)
Definition SemaHLSL.cpp:585
static void SetElementTypeAsReturnType(Sema *S, CallExpr *TheCall, QualType ReturnType)
static bool isResourceRecordTypeOrArrayOf(VarDecl *VD)
Definition SemaHLSL.cpp:345
static unsigned calculateLegacyCbufferSize(const ASTContext &Context, QualType T)
Definition SemaHLSL.cpp:238
static const HLSLAttributedResourceType * getResourceArrayHandleType(VarDecl *VD)
Definition SemaHLSL.cpp:351
static RegisterType getRegisterType(ResourceClass RC)
Definition SemaHLSL.cpp:62
static bool ValidateRegisterNumber(uint64_t SlotNum, Decl *TheDecl, ASTContext &Ctx, RegisterType RegTy)
static bool isVkPipelineBuiltin(const ASTContext &AstContext, FunctionDecl *FD, HLSLAppliedSemanticAttr *Semantic, bool IsInput)
Definition SemaHLSL.cpp:773
static bool CheckVectorElementCount(Sema *S, QualType PassedType, QualType BaseType, unsigned ExpectedCount, SourceLocation Loc)
static bool CheckModifiableLValue(Sema *S, CallExpr *TheCall, unsigned ArgIndex)
static QualType castElement(Sema &S, ExprResult &E, QualType Ty)
static bool CheckNotBoolScalarOrVector(Sema *S, CallExpr *TheCall, unsigned ArgIndex)
static CXXRecordDecl * findRecordDeclInContext(IdentifierInfo *II, DeclContext *DC)
Definition SemaHLSL.cpp:403
static bool CheckWavePrefix(Sema *S, CallExpr *TheCall)
static bool CheckExpectedBitWidth(Sema *S, CallExpr *TheCall, unsigned ArgOrdinal, unsigned Width)
static bool hasCounterHandle(const CXXRecordDecl *RD)
static bool CheckVectorSelect(Sema *S, CallExpr *TheCall)
static QualType handleFloatVectorBinOpConversion(Sema &SemaRef, ExprResult &LHS, ExprResult &RHS, QualType LHSType, QualType RHSType, QualType LElTy, QualType RElTy, bool IsCompAssign)
static ResourceClass getResourceClass(RegisterType RT)
Definition SemaHLSL.cpp:114
static CXXRecordDecl * createHostLayoutStruct(Sema &S, CXXRecordDecl *StructDecl)
Definition SemaHLSL.cpp:489
static bool CheckScalarOrVector(Sema *S, CallExpr *TheCall, QualType Scalar, unsigned ArgIndex)
static bool CheckSamplingBuiltin(Sema &S, CallExpr *TheCall, SampleKind Kind)
static bool CheckScalarOrVectorOrMatrix(Sema *S, CallExpr *TheCall, QualType Scalar, unsigned ArgIndex)
static bool CheckFloatRepresentation(Sema *S, SourceLocation Loc, int ArgOrdinal, clang::QualType PassedType)
void createHostLayoutStructForBuffer(Sema &S, HLSLBufferDecl *BufDecl)
Definition SemaHLSL.cpp:554
static bool requiresImplicitBufferLayoutStructure(const CXXRecordDecl *RD)
Definition SemaHLSL.cpp:383
static bool CheckResourceHandle(Sema *S, CallExpr *TheCall, unsigned ArgIndex, llvm::function_ref< bool(const HLSLAttributedResourceType *ResType)> Check=nullptr)
static void validatePackoffset(Sema &S, HLSLBufferDecl *BufDecl)
Definition SemaHLSL.cpp:285
static bool IsDefaultBufferConstantDecl(const ASTContext &Ctx, VarDecl *VD)
HLSLResourceBindingAttr::RegisterType RegisterType
Definition SemaHLSL.cpp:57
static CastKind getScalarCastKind(ASTContext &Ctx, QualType DestTy, QualType SrcTy)
static bool isValidWaveSizeValue(unsigned Value)
static bool AccumulateHLSLResourceSlots(QualType Ty, uint64_t &StartSlot, const uint64_t &Limit, const ResourceClass ResClass, ASTContext &Ctx, uint64_t ArrayCount=1)
static bool CheckNoDoubleVectors(Sema *S, SourceLocation Loc, int ArgOrdinal, clang::QualType PassedType)
static bool ValidateMultipleRegisterAnnotations(Sema &S, Decl *TheDecl, RegisterType regType)
static bool DiagnoseLocalRegisterBinding(Sema &S, SourceLocation &ArgLoc, Decl *D, RegisterType RegType, bool SpecifiedSpace)
This file declares semantic analysis for HLSL constructs.
Defines the clang::SourceLocation class and associated facilities.
Defines various enumerations that describe declaration and type specifiers.
C Language Family Type Representation.
Defines the clang::TypeLoc interface and its subclasses.
C Language Family Type Representation.
static const TypeInfo & getInfo(unsigned id)
Definition Types.cpp:44
__device__ __2f16 float c
return(__x > > __y)|(__x<<(32 - __y))
APValue - This class implements a discriminated union of [uninitialized] [APSInt] [APFloat],...
Definition APValue.h:122
Holds long-lived AST nodes (such as types and decls) that can be referred to throughout the semantic ...
Definition ASTContext.h:226
unsigned getIntWidth(QualType T) const
int getIntegerTypeOrder(QualType LHS, QualType RHS) const
Return the highest ranked integer type, see C99 6.3.1.8p1.
CanQualType FloatTy
QualType getPointerType(QualType T) const
Return the uniqued reference to the type for a pointer to the specified type.
const IncompleteArrayType * getAsIncompleteArrayType(QualType T) const
IdentifierTable & Idents
Definition ASTContext.h:797
QualType getConstantArrayType(QualType EltTy, const llvm::APInt &ArySize, const Expr *SizeExpr, ArraySizeModifier ASM, unsigned IndexTypeQuals) const
Return the unique reference to the type for a constant array of the specified element type.
QualType getBaseElementType(const ArrayType *VAT) const
Return the innermost element type of an array type.
int getFloatingTypeOrder(QualType LHS, QualType RHS) const
Compare the rank of the two specified floating point types, ignoring the domain of the type (i....
CanQualType BoolTy
TypeSourceInfo * getTrivialTypeSourceInfo(QualType T, SourceLocation Loc=SourceLocation()) const
Allocate a TypeSourceInfo where all locations have been initialized to a given location,...
QualType getStringLiteralArrayType(QualType EltTy, unsigned Length) const
Return a type for a constant array for a string literal of the specified element type and length.
CanQualType CharTy
CanQualType IntTy
uint64_t getTypeSize(QualType T) const
Return the size of the specified (complete) type T, in bits.
CharUnits getTypeSizeInChars(QualType T) const
Return the size of the specified (complete) type T, in characters.
CanQualType UnsignedIntTy
QualType getTypedefType(ElaboratedTypeKeyword Keyword, NestedNameSpecifier Qualifier, const TypedefNameDecl *Decl, QualType UnderlyingType=QualType(), std::optional< bool > TypeMatchesDeclOrNone=std::nullopt) const
Return the unique reference to the type for the specified typedef-name decl.
QualType getSizeType() const
Return the unique type for "size_t" (C99 7.17), defined in <stddef.h>.
QualType getExtVectorType(QualType VectorType, unsigned NumElts) const
Return the unique reference to an extended vector type of the specified element type and size.
const TargetInfo & getTargetInfo() const
Definition ASTContext.h:916
QualType getHLSLAttributedResourceType(QualType Wrapped, QualType Contained, const HLSLAttributedResourceType::Attributes &Attrs)
QualType getAddrSpaceQualType(QualType T, LangAS AddressSpace) const
Return the uniqued reference to the type for an address space qualified type with the specified type ...
CanQualType getCanonicalTagType(const TagDecl *TD) const
static bool hasSameUnqualifiedType(QualType T1, QualType T2)
Determine whether the given types are equivalent after cvr-qualifiers have been removed.
unsigned getTypeAlign(QualType T) const
Return the ABI-specified alignment of a (complete) type T, in bits.
PtrTy get() const
Definition Ownership.h:171
bool isInvalid() const
Definition Ownership.h:167
Represents an array type, per C99 6.7.5.2 - Array Declarators.
Definition TypeBase.h:3730
QualType getElementType() const
Definition TypeBase.h:3742
Attr - This represents one attribute.
Definition Attr.h:46
attr::Kind getKind() const
Definition Attr.h:92
SourceLocation getLocation() const
Definition Attr.h:99
SourceLocation getScopeLoc() const
const IdentifierInfo * getScopeName() const
SourceLocation getLoc() const
const IdentifierInfo * getAttrName() const
Represents a base class of a C++ class.
Definition DeclCXX.h:146
QualType getType() const
Retrieves the type of the base class.
Definition DeclCXX.h:249
Represents a static or instance method of a struct/union/class.
Definition DeclCXX.h:2136
Represents a C++ struct/union/class.
Definition DeclCXX.h:258
bool isHLSLIntangible() const
Returns true if the class contains HLSL intangible type, either as a field or in base class.
Definition DeclCXX.h:1556
static CXXRecordDecl * Create(const ASTContext &C, TagKind TK, DeclContext *DC, SourceLocation StartLoc, SourceLocation IdLoc, IdentifierInfo *Id, CXXRecordDecl *PrevDecl=nullptr)
Definition DeclCXX.cpp:132
void setBases(CXXBaseSpecifier const *const *Bases, unsigned NumBases)
Sets the base classes of this struct or class.
Definition DeclCXX.cpp:184
void completeDefinition() override
Indicates that the definition of this class is now complete.
Definition DeclCXX.cpp:2249
base_class_range bases()
Definition DeclCXX.h:608
unsigned getNumBases() const
Retrieves the number of base classes of this class.
Definition DeclCXX.h:602
base_class_iterator bases_begin()
Definition DeclCXX.h:615
bool isEmpty() const
Determine whether this is an empty class in the sense of (C++11 [meta.unary.prop]).
Definition DeclCXX.h:1186
CallExpr - Represents a function call (C99 6.5.2.2, C++ [expr.call]).
Definition Expr.h:2946
Expr * getArg(unsigned Arg)
getArg - Return the specified argument.
Definition Expr.h:3150
SourceLocation getBeginLoc() const
Definition Expr.h:3280
static CallExpr * Create(const ASTContext &Ctx, Expr *Fn, ArrayRef< Expr * > Args, QualType Ty, ExprValueKind VK, SourceLocation RParenLoc, FPOptionsOverride FPFeatures, unsigned MinNumArgs=0, ADLCallKind UsesADL=NotADL)
Create a call expression.
Definition Expr.cpp:1517
FunctionDecl * getDirectCallee()
If the callee is a FunctionDecl, return it. Otherwise return null.
Definition Expr.h:3129
Expr * getCallee()
Definition Expr.h:3093
unsigned getNumArgs() const
getNumArgs - Return the number of actual arguments to this call.
Definition Expr.h:3137
QualType withConst() const
Retrieves a version of this type with const applied.
const T * getTypePtr() const
Retrieve the underlying type pointer, which refers to a canonical type.
QuantityType getQuantity() const
getQuantity - Get the raw integer representation of this quantity.
Definition CharUnits.h:185
Represents the canonical version of C arrays with a specified constant size.
Definition TypeBase.h:3768
bool isZeroSize() const
Return true if the size is zero.
Definition TypeBase.h:3838
llvm::APInt getSize() const
Return the constant array size as an APInt.
Definition TypeBase.h:3824
uint64_t getZExtSize() const
Return the size zero-extended as a uint64_t.
Definition TypeBase.h:3844
Represents a concrete matrix type with constant number of rows and columns.
Definition TypeBase.h:4395
static DeclAccessPair make(NamedDecl *D, AccessSpecifier AS)
DeclContext - This is used only as base class of specific decl types that can act as declaration cont...
Definition DeclBase.h:1449
lookup_result lookup(DeclarationName Name) const
lookup - Find the declarations (if any) with the given Name in this context.
bool isTranslationUnit() const
Definition DeclBase.h:2185
void addDecl(Decl *D)
Add the declaration D into this context.
decl_range decls() const
decls_begin/decls_end - Iterate over the declarations stored in this context.
Definition DeclBase.h:2373
DeclContext * getNonTransparentContext()
A reference to a declared variable, function, enum, etc.
Definition Expr.h:1273
static DeclRefExpr * Create(const ASTContext &Context, NestedNameSpecifierLoc QualifierLoc, SourceLocation TemplateKWLoc, ValueDecl *D, bool RefersToEnclosingVariableOrCapture, SourceLocation NameLoc, QualType T, ExprValueKind VK, NamedDecl *FoundD=nullptr, const TemplateArgumentListInfo *TemplateArgs=nullptr, NonOdrUseReason NOUR=NOUR_None)
Definition Expr.cpp:488
ValueDecl * getDecl()
Definition Expr.h:1341
Decl - This represents one declaration (or definition), e.g.
Definition DeclBase.h:86
T * getAttr() const
Definition DeclBase.h:573
void addAttr(Attr *A)
attr_iterator attr_end() const
Definition DeclBase.h:542
bool isImplicit() const
isImplicit - Indicates whether the declaration was implicitly generated by the implementation.
Definition DeclBase.h:593
void setInvalidDecl(bool Invalid=true)
setInvalidDecl - Indicates the Decl had a semantic error.
Definition DeclBase.cpp:178
bool isInExportDeclContext() const
Whether this declaration was exported in a lexical context.
attr_iterator attr_begin() const
Definition DeclBase.h:539
SourceLocation getLocation() const
Definition DeclBase.h:439
void setImplicit(bool I=true)
Definition DeclBase.h:594
DeclContext * getDeclContext()
Definition DeclBase.h:448
attr_range attrs() const
Definition DeclBase.h:535
AccessSpecifier getAccess() const
Definition DeclBase.h:507
SourceLocation getBeginLoc() const LLVM_READONLY
Definition DeclBase.h:431
void dropAttr()
Definition DeclBase.h:556
bool hasAttr() const
Definition DeclBase.h:577
The name of a declaration.
Represents a ValueDecl that came out of a declarator.
Definition Decl.h:780
SourceLocation getBeginLoc() const LLVM_READONLY
Definition Decl.h:831
This represents one expression.
Definition Expr.h:112
void setType(QualType t)
Definition Expr.h:145
ExprValueKind getValueKind() const
getValueKind - The value kind that this expression produces.
Definition Expr.h:447
Expr * IgnoreParenImpCasts() LLVM_READONLY
Skip past any parentheses and implicit casts which might surround this expression until reaching a fi...
Definition Expr.cpp:3090
Expr * IgnoreParens() LLVM_READONLY
Skip past any parentheses which might surround this expression until reaching a fixed point.
Definition Expr.cpp:3086
std::optional< llvm::APSInt > getIntegerConstantExpr(const ASTContext &Ctx) const
isIntegerConstantExpr - Return the value if this expression is a valid integer constant expression.
bool isLValue() const
isLValue - True if this expression is an "l-value" according to the rules of the current language.
Definition Expr.h:284
ExprObjectKind getObjectKind() const
getObjectKind - The object kind that this expression produces.
Definition Expr.h:454
bool HasSideEffects(const ASTContext &Ctx, bool IncludePossibleEffects=true) const
HasSideEffects - This routine returns true for all those expressions which have any effect other than...
Definition Expr.cpp:3670
void setValueKind(ExprValueKind Cat)
setValueKind - Set the value kind produced by this expression.
Definition Expr.h:464
SourceLocation getExprLoc() const LLVM_READONLY
getExprLoc - Return the preferred location for the arrow when diagnosing a problem with a generic exp...
Definition Expr.cpp:277
@ MLV_Valid
Definition Expr.h:306
QualType getType() const
Definition Expr.h:144
Represents a member of a struct/union/class.
Definition Decl.h:3160
static FieldDecl * Create(const ASTContext &C, DeclContext *DC, SourceLocation StartLoc, SourceLocation IdLoc, const IdentifierInfo *Id, QualType T, TypeSourceInfo *TInfo, Expr *BW, bool Mutable, InClassInitStyle InitStyle)
Definition Decl.cpp:4701
static FixItHint CreateReplacement(CharSourceRange RemoveRange, StringRef Code)
Create a code modification hint that replaces the given source range with the given code string.
Definition Diagnostic.h:140
Represents a function declaration or definition.
Definition Decl.h:2000
const ParmVarDecl * getParamDecl(unsigned i) const
Definition Decl.h:2797
Stmt * getBody(const FunctionDecl *&Definition) const
Retrieve the body (definition) of the function.
Definition Decl.cpp:3280
bool isThisDeclarationADefinition() const
Returns whether this specific declaration of the function is also a definition that does not contain ...
Definition Decl.h:2314
QualType getReturnType() const
Definition Decl.h:2845
ArrayRef< ParmVarDecl * > parameters() const
Definition Decl.h:2774
bool isTemplateInstantiation() const
Determines if the given function was instantiated from a function template.
Definition Decl.cpp:4258
redecl_range redecls() const
Returns an iterator range for all the redeclarations of the same decl.
unsigned getNumParams() const
Return the number of parameters this function must have based on its FunctionType.
Definition Decl.cpp:3827
DeclarationNameInfo getNameInfo() const
Definition Decl.h:2211
bool hasBody(const FunctionDecl *&Definition) const
Returns true if the function has a body.
Definition Decl.cpp:3200
bool isDefined(const FunctionDecl *&Definition, bool CheckForPendingFriendDefinition=false) const
Returns true if the function has a definition that does not need to be instantiated.
Definition Decl.cpp:3247
HLSLBufferDecl - Represent a cbuffer or tbuffer declaration.
Definition Decl.h:5196
static HLSLBufferDecl * Create(ASTContext &C, DeclContext *LexicalParent, bool CBuffer, SourceLocation KwLoc, IdentifierInfo *ID, SourceLocation IDLoc, SourceLocation LBrace)
Definition Decl.cpp:5906
void addLayoutStruct(CXXRecordDecl *LS)
Definition Decl.cpp:5946
void setHasValidPackoffset(bool PO)
Definition Decl.h:5241
static HLSLBufferDecl * CreateDefaultCBuffer(ASTContext &C, DeclContext *LexicalParent, ArrayRef< Decl * > DefaultCBufferDecls)
Definition Decl.cpp:5929
buffer_decl_range buffer_decls() const
Definition Decl.h:5271
static HLSLOutArgExpr * Create(const ASTContext &C, QualType Ty, OpaqueValueExpr *Base, OpaqueValueExpr *OpV, Expr *WB, bool IsInOut)
Definition Expr.cpp:5627
static HLSLRootSignatureDecl * Create(ASTContext &C, DeclContext *DC, SourceLocation Loc, IdentifierInfo *ID, llvm::dxbc::RootSignatureVersion Version, ArrayRef< llvm::hlsl::rootsig::RootElement > RootElements)
Definition Decl.cpp:5992
One of these records is kept for each identifier that is lexed.
StringRef getName() const
Return the actual identifier string.
A simple pair of identifier info and location.
SourceLocation getLoc() const
IdentifierInfo * getIdentifierInfo() const
IdentifierInfo & get(StringRef Name)
Return the identifier token info for the specified named identifier.
static ImplicitCastExpr * Create(const ASTContext &Context, QualType T, CastKind Kind, Expr *Operand, const CXXCastPath *BasePath, ExprValueKind Cat, FPOptionsOverride FPO)
Definition Expr.cpp:2073
Describes an C or C++ initializer list.
Definition Expr.h:5302
Describes an entity that is being initialized.
QualType getType() const
Retrieve type being initialized.
static InitializedEntity InitializeParameter(ASTContext &Context, ParmVarDecl *Parm)
Create the initialization entity for a parameter.
static IntegerLiteral * Create(const ASTContext &C, const llvm::APInt &V, QualType type, SourceLocation l)
Returns a new integer literal with value 'V' and type 'type'.
Definition Expr.cpp:975
iterator begin(Source *source, bool LocalOnly=false)
Represents the results of name lookup.
Definition Lookup.h:147
NamedDecl * getFoundDecl() const
Fetch the unique decl found by this lookup.
Definition Lookup.h:569
Represents a prvalue temporary that is written into memory so that a reference can bind to it.
Definition ExprCXX.h:4921
ValueDecl * getMemberDecl() const
Retrieve the member declaration to which this expression refers.
Definition Expr.h:3450
This represents a decl that may have a name.
Definition Decl.h:274
IdentifierInfo * getIdentifier() const
Get the identifier that names this declaration, if there is one.
Definition Decl.h:295
StringRef getName() const
Get the name of identifier for this declaration as a StringRef.
Definition Decl.h:301
DeclarationName getDeclName() const
Get the actual, stored name of the declaration, which may be a special name.
Definition Decl.h:340
OpaqueValueExpr - An expression referring to an opaque object of a fixed type and value class.
Definition Expr.h:1181
Represents a parameter to a function.
Definition Decl.h:1790
ParsedAttr - Represents a syntactic attribute.
Definition ParsedAttr.h:119
unsigned getSemanticSpelling() const
If the parsed attribute has a semantic equivalent, and it would have a semantic Spelling enumeration ...
unsigned getMinArgs() const
bool checkExactlyNumArgs(class Sema &S, unsigned Num) const
Check if the attribute has exactly as many args as Num.
IdentifierLoc * getArgAsIdent(unsigned Arg) const
Definition ParsedAttr.h:389
bool hasParsedType() const
Definition ParsedAttr.h:337
const ParsedType & getTypeArg() const
Definition ParsedAttr.h:459
unsigned getNumArgs() const
getNumArgs - Return the number of actual arguments to this attribute.
Definition ParsedAttr.h:371
bool isArgIdent(unsigned Arg) const
Definition ParsedAttr.h:385
Expr * getArgAsExpr(unsigned Arg) const
Definition ParsedAttr.h:383
AttributeCommonInfo::Kind getKind() const
Definition ParsedAttr.h:610
A (possibly-)qualified type.
Definition TypeBase.h:937
void addRestrict()
Add the restrict qualifier to this QualType.
Definition TypeBase.h:1178
QualType getNonLValueExprType(const ASTContext &Context) const
Determine the type of a (typically non-lvalue) expression with the specified result type.
Definition Type.cpp:3620
QualType getDesugaredType(const ASTContext &Context) const
Return the specified type with any "sugar" removed from the type.
Definition TypeBase.h:1302
bool isNull() const
Return true if this QualType doesn't point to a type yet.
Definition TypeBase.h:1004
const Type * getTypePtr() const
Retrieves a pointer to the underlying (unqualified) type.
Definition TypeBase.h:8349
LangAS getAddressSpace() const
Return the address space of this type.
Definition TypeBase.h:8475
QualType getNonReferenceType() const
If Type is a reference type (e.g., const int&), returns the type that the reference refers to ("const...
Definition TypeBase.h:8534
QualType getCanonicalType() const
Definition TypeBase.h:8401
QualType getUnqualifiedType() const
Retrieve the unqualified variant of the given type, removing as little sugar as possible.
Definition TypeBase.h:8443
bool hasAddressSpace() const
Check if this type has any address space qualifier.
Definition TypeBase.h:8470
Represents a struct/union/class.
Definition Decl.h:4327
field_iterator field_end() const
Definition Decl.h:4533
field_range fields() const
Definition Decl.h:4530
bool field_empty() const
Definition Decl.h:4538
field_iterator field_begin() const
Definition Decl.cpp:5276
bool hasBindingInfoForDecl(const VarDecl *VD) const
Definition SemaHLSL.cpp:193
DeclBindingInfo * getDeclBindingInfo(const VarDecl *VD, ResourceClass ResClass)
Definition SemaHLSL.cpp:179
DeclBindingInfo * addDeclBindingInfo(const VarDecl *VD, ResourceClass ResClass)
Definition SemaHLSL.cpp:166
Scope - A scope is a transient data structure that is used while parsing the program.
Definition Scope.h:41
SemaBase(Sema &S)
Definition SemaBase.cpp:7
ASTContext & getASTContext() const
Definition SemaBase.cpp:9
Sema & SemaRef
Definition SemaBase.h:40
SemaDiagnosticBuilder Diag(SourceLocation Loc, unsigned DiagID)
Emit a diagnostic.
Definition SemaBase.cpp:61
ExprResult ActOnOutParamExpr(ParmVarDecl *Param, Expr *Arg)
HLSLRootSignatureDecl * lookupRootSignatureOverrideDecl(DeclContext *DC) const
bool CanPerformElementwiseCast(Expr *Src, QualType DestType)
void handleWaveSizeAttr(Decl *D, const ParsedAttr &AL)
void handleVkLocationAttr(Decl *D, const ParsedAttr &AL)
HLSLAttributedResourceLocInfo TakeLocForHLSLAttribute(const HLSLAttributedResourceType *RT)
void handleSemanticAttr(Decl *D, const ParsedAttr &AL)
bool CanPerformScalarCast(QualType SrcTy, QualType DestTy)
QualType ProcessResourceTypeAttributes(QualType Wrapped)
void handleShaderAttr(Decl *D, const ParsedAttr &AL)
void CheckEntryPoint(FunctionDecl *FD)
Definition SemaHLSL.cpp:890
void emitLogicalOperatorFixIt(Expr *LHS, Expr *RHS, BinaryOperatorKind Opc)
T * createSemanticAttr(const AttributeCommonInfo &ACI, std::optional< unsigned > Location)
Definition SemaHLSL.h:179
void ActOnEndOfTranslationUnit(TranslationUnitDecl *TU)
HLSLVkConstantIdAttr * mergeVkConstantIdAttr(Decl *D, const AttributeCommonInfo &AL, int Id)
Definition SemaHLSL.cpp:656
HLSLNumThreadsAttr * mergeNumThreadsAttr(Decl *D, const AttributeCommonInfo &AL, int X, int Y, int Z)
Definition SemaHLSL.cpp:622
void deduceAddressSpace(VarDecl *Decl)
std::pair< IdentifierInfo *, bool > ActOnStartRootSignatureDecl(StringRef Signature)
Computes the unique Root Signature identifier from the given signature, then lookup if there is a pre...
void handlePackOffsetAttr(Decl *D, const ParsedAttr &AL)
bool diagnosePositionType(QualType T, const ParsedAttr &AL)
bool handleInitialization(VarDecl *VDecl, Expr *&Init)
bool diagnoseInputIDType(QualType T, const ParsedAttr &AL)
void handleParamModifierAttr(Decl *D, const ParsedAttr &AL)
bool CheckResourceBinOp(BinaryOperatorKind Opc, Expr *LHSExpr, Expr *RHSExpr, SourceLocation Loc)
bool CanPerformAggregateSplatCast(Expr *Src, QualType DestType)
bool IsScalarizedLayoutCompatible(QualType T1, QualType T2) const
void diagnoseSystemSemanticAttr(Decl *D, const ParsedAttr &AL, std::optional< unsigned > Index)
void handleRootSignatureAttr(Decl *D, const ParsedAttr &AL)
bool CheckCompatibleParameterABI(FunctionDecl *New, FunctionDecl *Old)
QualType handleVectorBinOpConversion(ExprResult &LHS, ExprResult &RHS, QualType LHSType, QualType RHSType, bool IsCompAssign)
QualType checkMatrixComponent(Sema &S, QualType baseType, ExprValueKind &VK, SourceLocation OpLoc, const IdentifierInfo *CompName, SourceLocation CompLoc)
void handleResourceBindingAttr(Decl *D, const ParsedAttr &AL)
bool IsTypedResourceElementCompatible(QualType T1)
bool transformInitList(const InitializedEntity &Entity, InitListExpr *Init)
void handleNumThreadsAttr(Decl *D, const ParsedAttr &AL)
bool ActOnUninitializedVarDecl(VarDecl *D)
void handleVkExtBuiltinInputAttr(Decl *D, const ParsedAttr &AL)
void ActOnTopLevelFunction(FunctionDecl *FD)
Definition SemaHLSL.cpp:725
bool handleResourceTypeAttr(QualType T, const ParsedAttr &AL)
void handleVkPushConstantAttr(Decl *D, const ParsedAttr &AL)
HLSLShaderAttr * mergeShaderAttr(Decl *D, const AttributeCommonInfo &AL, llvm::Triple::EnvironmentType ShaderType)
Definition SemaHLSL.cpp:692
void ActOnFinishBuffer(Decl *Dcl, SourceLocation RBrace)
Definition SemaHLSL.cpp:595
void handleVkBindingAttr(Decl *D, const ParsedAttr &AL)
HLSLParamModifierAttr * mergeParamModifierAttr(Decl *D, const AttributeCommonInfo &AL, HLSLParamModifierAttr::Spelling Spelling)
Definition SemaHLSL.cpp:705
QualType getInoutParameterType(QualType Ty)
SemaHLSL(Sema &S)
Definition SemaHLSL.cpp:197
void handleVkConstantIdAttr(Decl *D, const ParsedAttr &AL)
Decl * ActOnStartBuffer(Scope *BufferScope, bool CBuffer, SourceLocation KwLoc, IdentifierInfo *Ident, SourceLocation IdentLoc, SourceLocation LBrace)
Definition SemaHLSL.cpp:199
HLSLWaveSizeAttr * mergeWaveSizeAttr(Decl *D, const AttributeCommonInfo &AL, int Min, int Max, int Preferred, int SpelledArgsCount)
Definition SemaHLSL.cpp:636
bool handleRootSignatureElements(ArrayRef< hlsl::RootSignatureElement > Elements)
void ActOnFinishRootSignatureDecl(SourceLocation Loc, IdentifierInfo *DeclIdent, ArrayRef< hlsl::RootSignatureElement > Elements)
Creates the Root Signature decl of the parsed Root Signature elements onto the AST and push it onto c...
void ActOnVariableDeclarator(VarDecl *VD)
bool CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall)
Sema - This implements semantic analysis and AST building for C.
Definition Sema.h:868
@ LookupOrdinaryName
Ordinary name lookup, which finds ordinary names (functions, variables, typedefs, etc....
Definition Sema.h:9381
@ LookupMemberName
Member name lookup, which finds the names of class/struct/union members.
Definition Sema.h:9389
ExtVectorDeclsType ExtVectorDecls
ExtVectorDecls - This is a list all the extended vector types.
Definition Sema.h:4936
ASTContext & Context
Definition Sema.h:1300
ASTContext & getASTContext() const
Definition Sema.h:939
ExprResult ImpCastExprToType(Expr *E, QualType Type, CastKind CK, ExprValueKind VK=VK_PRValue, const CXXCastPath *BasePath=nullptr, CheckedConversionKind CCK=CheckedConversionKind::Implicit)
ImpCastExprToType - If Expr is not of type 'Type', insert an implicit cast.
Definition Sema.cpp:756
const LangOptions & getLangOpts() const
Definition Sema.h:932
ExprResult BuildFieldReferenceExpr(Expr *BaseExpr, bool IsArrow, SourceLocation OpLoc, const CXXScopeSpec &SS, FieldDecl *Field, DeclAccessPair FoundDecl, const DeclarationNameInfo &MemberNameInfo)
bool checkArgCountRange(CallExpr *Call, unsigned MinArgCount, unsigned MaxArgCount)
Checks that a call expression's argument count is in the desired range.
ExternalSemaSource * getExternalSource() const
Definition Sema.h:942
ExprResult CreateBuiltinArraySubscriptExpr(Expr *Base, SourceLocation LLoc, Expr *Idx, SourceLocation RLoc)
bool LookupQualifiedName(LookupResult &R, DeclContext *LookupCtx, bool InUnqualifiedLookup=false)
Perform qualified name lookup into a given context.
ExprResult PerformCopyInitialization(const InitializedEntity &Entity, SourceLocation EqualLoc, ExprResult Init, bool TopLevelOfInitList=false, bool AllowExplicit=false)
ExprResult CreateBuiltinMatrixSubscriptExpr(Expr *Base, Expr *RowIdx, Expr *ColumnIdx, SourceLocation RBLoc)
Encodes a location in the source.
SourceLocation getLocWithOffset(IntTy Offset) const
Return a source location with the specified offset from this SourceLocation.
A trivial tuple used to represent a source range.
SourceLocation getEnd() const
SourceLocation getEndLoc() const LLVM_READONLY
Definition Stmt.cpp:367
void printPretty(raw_ostream &OS, PrinterHelper *Helper, const PrintingPolicy &Policy, unsigned Indentation=0, StringRef NewlineSymbol="\n", const ASTContext *Context=nullptr) const
SourceRange getSourceRange() const LLVM_READONLY
SourceLocation tokens are not useful in isolation - they are low level value objects created/interpre...
Definition Stmt.cpp:343
SourceLocation getBeginLoc() const LLVM_READONLY
Definition Stmt.cpp:355
static StringLiteral * Create(const ASTContext &Ctx, StringRef Str, StringLiteralKind Kind, bool Pascal, QualType Ty, ArrayRef< SourceLocation > Locs)
This is the "fully general" constructor that allows representation of strings formed from one or more...
Definition Expr.cpp:1188
void startDefinition()
Starts the definition of this tag declaration.
Definition Decl.cpp:4907
bool isUnion() const
Definition Decl.h:3928
bool isClass() const
Definition Decl.h:3927
Exposes information about the current target.
Definition TargetInfo.h:226
TargetOptions & getTargetOpts() const
Retrieve the target options.
Definition TargetInfo.h:326
const llvm::Triple & getTriple() const
Returns the target triple of the primary target.
StringRef getPlatformName() const
Retrieve the name of the platform as it is used in the availability attribute.
VersionTuple getPlatformMinVersion() const
Retrieve the minimum desired version of the platform, to which the program should be compiled.
std::string HLSLEntry
The entry point name for HLSL shader being compiled as specified by -E.
The top declaration context.
Definition Decl.h:105
SourceLocation getBeginLoc() const
Get the begin source location.
Definition TypeLoc.cpp:193
A container of type source information.
Definition TypeBase.h:8320
TypeLoc getTypeLoc() const
Return the TypeLoc wrapper for the type source info.
Definition TypeLoc.h:267
The base class of the type hierarchy.
Definition TypeBase.h:1839
bool isVoidType() const
Definition TypeBase.h:8952
bool isBooleanType() const
Definition TypeBase.h:9089
bool isIncompleteArrayType() const
Definition TypeBase.h:8693
CXXRecordDecl * getAsCXXRecordDecl() const
Retrieves the CXXRecordDecl that this type refers to, either because the type is a RecordType or beca...
Definition Type.h:26
bool isConstantArrayType() const
Definition TypeBase.h:8689
bool hasIntegerRepresentation() const
Determine whether this type has an integer representation of some sort, e.g., it is an integer type o...
Definition Type.cpp:2083
bool isArrayType() const
Definition TypeBase.h:8685
CXXRecordDecl * castAsCXXRecordDecl() const
Definition Type.h:36
bool isArithmeticType() const
Definition Type.cpp:2374
bool isConstantMatrixType() const
Definition TypeBase.h:8753
bool isHLSLBuiltinIntangibleType() const
Definition TypeBase.h:8897
bool isPointerType() const
Definition TypeBase.h:8586
CanQualType getCanonicalTypeUnqualified() const
bool isIntegerType() const
isIntegerType() does not include complex integers (a GCC extension).
Definition TypeBase.h:8996
const T * castAs() const
Member-template castAs<specific type>.
Definition TypeBase.h:9246
bool isReferenceType() const
Definition TypeBase.h:8610
bool isHLSLIntangibleType() const
Definition Type.cpp:5452
bool isEnumeralType() const
Definition TypeBase.h:8717
bool isScalarType() const
Definition TypeBase.h:9058
bool isIntegralType(const ASTContext &Ctx) const
Determine whether this type is an integral type.
Definition Type.cpp:2120
const Type * getArrayElementTypeNoTypeQual() const
If this is an array type, return the element type of the array, potentially with type qualifiers miss...
Definition Type.cpp:472
QualType getPointeeType() const
If this is a pointer, ObjC object pointer, or block pointer, this returns the respective pointee.
Definition Type.cpp:753
bool hasUnsignedIntegerRepresentation() const
Determine whether this type has an unsigned integer representation of some sort, e....
Definition Type.cpp:2328
bool isAggregateType() const
Determines whether the type is a C++ aggregate type or C aggregate or union type.
Definition Type.cpp:2455
ScalarTypeKind getScalarTypeKind() const
Given that this is a scalar type, classify it.
Definition Type.cpp:2406
bool hasSignedIntegerRepresentation() const
Determine whether this type has an signed integer representation of some sort, e.g....
Definition Type.cpp:2274
bool isHLSLResourceRecord() const
Definition Type.cpp:5439
bool hasFloatingRepresentation() const
Determine whether this type has a floating-point representation of some sort, e.g....
Definition Type.cpp:2349
bool isVectorType() const
Definition TypeBase.h:8725
bool isRealFloatingType() const
Floating point categories.
Definition Type.cpp:2357
bool isHLSLAttributedResourceType() const
Definition TypeBase.h:8909
@ STK_FloatingComplex
Definition TypeBase.h:2772
@ STK_ObjCObjectPointer
Definition TypeBase.h:2766
@ STK_IntegralComplex
Definition TypeBase.h:2771
@ STK_MemberPointer
Definition TypeBase.h:2767
bool isFloatingType() const
Definition Type.cpp:2341
bool isSamplerT() const
Definition TypeBase.h:8830
const T * getAs() const
Member-template getAs<specific type>'.
Definition TypeBase.h:9179
const Type * getUnqualifiedDesugaredType() const
Return the specified type with any "sugar" removed from the type, removing any typedefs,...
Definition Type.cpp:654
bool isRecordType() const
Definition TypeBase.h:8713
bool isHLSLResourceRecordArray() const
Definition Type.cpp:5443
void setType(QualType newType)
Definition Decl.h:724
QualType getType() const
Definition Decl.h:723
Represents a variable declaration or definition.
Definition Decl.h:926
void setInitStyle(InitializationStyle Style)
Definition Decl.h:1452
@ CallInit
Call-style initialization (C++98)
Definition Decl.h:934
void setStorageClass(StorageClass SC)
Definition Decl.cpp:2175
bool hasGlobalStorage() const
Returns true for all variables that do not have local storage.
Definition Decl.h:1226
void setInit(Expr *I)
Definition Decl.cpp:2489
StorageClass getStorageClass() const
Returns the storage class as written in the source.
Definition Decl.h:1168
Represents a GCC generic vector type.
Definition TypeBase.h:4183
unsigned getNumElements() const
Definition TypeBase.h:4198
QualType getElementType() const
Definition TypeBase.h:4197
Defines the clang::TargetInfo interface.
Definition SPIR.cpp:47
uint32_t getResourceDimensions(llvm::dxil::ResourceDimension Dim)
The JSON file list parser is used to communicate input to InstallAPI.
bool isa(CodeGen::Address addr)
Definition Address.h:330
if(T->getSizeExpr()) TRY_TO(TraverseStmt(const_cast< Expr * >(T -> getSizeExpr())))
static bool CheckFloatOrHalfRepresentation(Sema *S, SourceLocation Loc, int ArgOrdinal, clang::QualType PassedType)
Definition SemaSPIRV.cpp:66
@ ICIS_NoInit
No in-class initializer.
Definition Specifiers.h:272
@ OK_Ordinary
An ordinary object is located at an address in memory.
Definition Specifiers.h:151
static bool CheckAllArgTypesAreCorrect(Sema *S, CallExpr *TheCall, llvm::ArrayRef< llvm::function_ref< bool(Sema *, SourceLocation, int, QualType)> > Checks)
Definition SemaSPIRV.cpp:49
@ AS_public
Definition Specifiers.h:124
@ AS_none
Definition Specifiers.h:127
@ SC_Static
Definition Specifiers.h:252
@ AANT_ArgumentIdentifier
@ Result
The result type of a method or function.
Definition TypeBase.h:905
@ Ordinary
This parameter uses ordinary ABI rules for its type.
Definition Specifiers.h:380
llvm::Expected< QualType > ExpectedType
LLVM_READONLY bool isDigit(unsigned char c)
Return true if this character is an ASCII digit: [0-9].
Definition CharInfo.h:114
static bool CheckAllArgsHaveSameType(Sema *S, CallExpr *TheCall)
Definition SemaSPIRV.cpp:32
ExprResult ExprError()
Definition Ownership.h:265
@ Type
The name was classified as a type.
Definition Sema.h:564
LangAS
Defines the address space values used by the address space qualifier of QualType.
bool CreateHLSLAttributedResourceType(Sema &S, QualType Wrapped, ArrayRef< const Attr * > AttrList, QualType &ResType, HLSLAttributedResourceLocInfo *LocInfo=nullptr)
CastKind
CastKind - The kind of operation required for a conversion.
ExprValueKind
The categorization of expression values, currently following the C++11 scheme.
Definition Specifiers.h:132
@ VK_PRValue
A pr-value expression (in the C++11 taxonomy) produces a temporary value.
Definition Specifiers.h:135
@ VK_LValue
An l-value expression is a reference to an object with independent storage.
Definition Specifiers.h:139
DynamicRecursiveASTVisitorBase< false > DynamicRecursiveASTVisitor
U cast(CodeGen::Address addr)
Definition Address.h:327
@ None
No keyword precedes the qualified type name.
Definition TypeBase.h:5896
ActionResult< Expr * > ExprResult
Definition Ownership.h:249
Visibility
Describes the different kinds of visibility that a declaration may have.
Definition Visibility.h:34
unsigned long uint64_t
unsigned int uint32_t
hash_code hash_value(const clang::dependencies::ModuleID &ID)
__DEVICE__ bool isnan(float __x)
__DEVICE__ _Tp abs(const std::complex< _Tp > &__c)
#define false
Definition stdbool.h:26
Describes how types, statements, expressions, and declarations should be printed.
void setCounterImplicitOrderID(unsigned Value) const
void setImplicitOrderID(unsigned Value) const
const SourceLocation & getLocation() const
Definition SemaHLSL.h:48
const llvm::hlsl::rootsig::RootElement & getElement() const
Definition SemaHLSL.h:47