clang 23.0.0git
SemaHLSL.cpp
Go to the documentation of this file.
1//===- SemaHLSL.cpp - Semantic Analysis for HLSL constructs ---------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8// This implements Semantic Analysis for HLSL constructs.
9//===----------------------------------------------------------------------===//
10
11#include "clang/Sema/SemaHLSL.h"
14#include "clang/AST/Attr.h"
15#include "clang/AST/Attrs.inc"
16#include "clang/AST/Decl.h"
17#include "clang/AST/DeclBase.h"
18#include "clang/AST/DeclCXX.h"
21#include "clang/AST/Expr.h"
23#include "clang/AST/Type.h"
24#include "clang/AST/TypeBase.h"
25#include "clang/AST/TypeLoc.h"
29#include "clang/Basic/LLVM.h"
34#include "clang/Sema/Lookup.h"
36#include "clang/Sema/Sema.h"
37#include "clang/Sema/Template.h"
38#include "llvm/ADT/ArrayRef.h"
39#include "llvm/ADT/STLExtras.h"
40#include "llvm/ADT/SmallVector.h"
41#include "llvm/ADT/StringExtras.h"
42#include "llvm/ADT/StringRef.h"
43#include "llvm/ADT/Twine.h"
44#include "llvm/Frontend/HLSL/HLSLBinding.h"
45#include "llvm/Frontend/HLSL/RootSignatureValidations.h"
46#include "llvm/Support/Casting.h"
47#include "llvm/Support/DXILABI.h"
48#include "llvm/Support/ErrorHandling.h"
49#include "llvm/Support/FormatVariadic.h"
50#include "llvm/TargetParser/Triple.h"
51#include <cmath>
52#include <cstddef>
53#include <iterator>
54#include <utility>
55
56using namespace clang;
57using namespace clang::hlsl;
58using RegisterType = HLSLResourceBindingAttr::RegisterType;
59
61 CXXRecordDecl *StructDecl);
62
64 switch (RC) {
65 case ResourceClass::SRV:
66 return RegisterType::SRV;
67 case ResourceClass::UAV:
68 return RegisterType::UAV;
69 case ResourceClass::CBuffer:
70 return RegisterType::CBuffer;
71 case ResourceClass::Sampler:
72 return RegisterType::Sampler;
73 }
74 llvm_unreachable("unexpected ResourceClass value");
75}
76
77static RegisterType getRegisterType(const HLSLAttributedResourceType *ResTy) {
78 return getRegisterType(ResTy->getAttrs().ResourceClass);
79}
80
81// Converts the first letter of string Slot to RegisterType.
82// Returns false if the letter does not correspond to a valid register type.
83static bool convertToRegisterType(StringRef Slot, RegisterType *RT) {
84 assert(RT != nullptr);
85 switch (Slot[0]) {
86 case 't':
87 case 'T':
88 *RT = RegisterType::SRV;
89 return true;
90 case 'u':
91 case 'U':
92 *RT = RegisterType::UAV;
93 return true;
94 case 'b':
95 case 'B':
96 *RT = RegisterType::CBuffer;
97 return true;
98 case 's':
99 case 'S':
100 *RT = RegisterType::Sampler;
101 return true;
102 case 'c':
103 case 'C':
104 *RT = RegisterType::C;
105 return true;
106 case 'i':
107 case 'I':
108 *RT = RegisterType::I;
109 return true;
110 default:
111 return false;
112 }
113}
114
116 switch (RT) {
117 case RegisterType::SRV:
118 return ResourceClass::SRV;
119 case RegisterType::UAV:
120 return ResourceClass::UAV;
121 case RegisterType::CBuffer:
122 return ResourceClass::CBuffer;
123 case RegisterType::Sampler:
124 return ResourceClass::Sampler;
125 case RegisterType::C:
126 case RegisterType::I:
127 // Deliberately falling through to the unreachable below.
128 break;
129 }
130 llvm_unreachable("unexpected RegisterType value");
131}
132
134 const auto *BT = dyn_cast<BuiltinType>(Type);
135 if (!BT) {
136 if (!Type->isEnumeralType())
137 return Builtin::NotBuiltin;
138 return Builtin::BI__builtin_get_spirv_spec_constant_int;
139 }
140
141 switch (BT->getKind()) {
142 case BuiltinType::Bool:
143 return Builtin::BI__builtin_get_spirv_spec_constant_bool;
144 case BuiltinType::Short:
145 return Builtin::BI__builtin_get_spirv_spec_constant_short;
146 case BuiltinType::Int:
147 return Builtin::BI__builtin_get_spirv_spec_constant_int;
148 case BuiltinType::LongLong:
149 return Builtin::BI__builtin_get_spirv_spec_constant_longlong;
150 case BuiltinType::UShort:
151 return Builtin::BI__builtin_get_spirv_spec_constant_ushort;
152 case BuiltinType::UInt:
153 return Builtin::BI__builtin_get_spirv_spec_constant_uint;
154 case BuiltinType::ULongLong:
155 return Builtin::BI__builtin_get_spirv_spec_constant_ulonglong;
156 case BuiltinType::Half:
157 return Builtin::BI__builtin_get_spirv_spec_constant_half;
158 case BuiltinType::Float:
159 return Builtin::BI__builtin_get_spirv_spec_constant_float;
160 case BuiltinType::Double:
161 return Builtin::BI__builtin_get_spirv_spec_constant_double;
162 default:
163 return Builtin::NotBuiltin;
164 }
165}
166
168 ResourceClass ResClass) {
169 assert(getDeclBindingInfo(VD, ResClass) == nullptr &&
170 "DeclBindingInfo already added");
171 assert(!hasBindingInfoForDecl(VD) || BindingsList.back().Decl == VD);
172 // VarDecl may have multiple entries for different resource classes.
173 // DeclToBindingListIndex stores the index of the first binding we saw
174 // for this decl. If there are any additional ones then that index
175 // shouldn't be updated.
176 DeclToBindingListIndex.try_emplace(VD, BindingsList.size());
177 return &BindingsList.emplace_back(VD, ResClass);
178}
179
181 ResourceClass ResClass) {
182 auto Entry = DeclToBindingListIndex.find(VD);
183 if (Entry != DeclToBindingListIndex.end()) {
184 for (unsigned Index = Entry->getSecond();
185 Index < BindingsList.size() && BindingsList[Index].Decl == VD;
186 ++Index) {
187 if (BindingsList[Index].ResClass == ResClass)
188 return &BindingsList[Index];
189 }
190 }
191 return nullptr;
192}
193
195 return DeclToBindingListIndex.contains(VD);
196}
197
199
200Decl *SemaHLSL::ActOnStartBuffer(Scope *BufferScope, bool CBuffer,
201 SourceLocation KwLoc, IdentifierInfo *Ident,
202 SourceLocation IdentLoc,
203 SourceLocation LBrace) {
204 // For anonymous namespace, take the location of the left brace.
205 DeclContext *LexicalParent = SemaRef.getCurLexicalContext();
207 getASTContext(), LexicalParent, CBuffer, KwLoc, Ident, IdentLoc, LBrace);
208
209 // if CBuffer is false, then it's a TBuffer
210 auto RC = CBuffer ? llvm::hlsl::ResourceClass::CBuffer
211 : llvm::hlsl::ResourceClass::SRV;
212 Result->addAttr(HLSLResourceClassAttr::CreateImplicit(getASTContext(), RC));
213
214 SemaRef.PushOnScopeChains(Result, BufferScope);
215 SemaRef.PushDeclContext(BufferScope, Result);
216
217 return Result;
218}
219
220static unsigned calculateLegacyCbufferFieldAlign(const ASTContext &Context,
221 QualType T) {
222 // Arrays and Structs are always aligned to new buffer rows
223 if (T->isArrayType() || T->isStructureType())
224 return 16;
225
226 // Vectors are aligned to the type they contain
227 if (const VectorType *VT = T->getAs<VectorType>())
228 return calculateLegacyCbufferFieldAlign(Context, VT->getElementType());
229
230 assert(Context.getTypeSize(T) <= 64 &&
231 "Scalar bit widths larger than 64 not supported");
232
233 // Scalar types are aligned to their byte width
234 return Context.getTypeSize(T) / 8;
235}
236
237// Calculate the size of a legacy cbuffer type in bytes based on
238// https://learn.microsoft.com/en-us/windows/win32/direct3dhlsl/dx-graphics-hlsl-packing-rules
239static unsigned calculateLegacyCbufferSize(const ASTContext &Context,
240 QualType T) {
241 constexpr unsigned CBufferAlign = 16;
242 if (const auto *RD = T->getAsRecordDecl()) {
243 unsigned Size = 0;
244 for (const FieldDecl *Field : RD->fields()) {
245 QualType Ty = Field->getType();
246 unsigned FieldSize = calculateLegacyCbufferSize(Context, Ty);
247 unsigned FieldAlign = calculateLegacyCbufferFieldAlign(Context, Ty);
248
249 // If the field crosses the row boundary after alignment it drops to the
250 // next row
251 unsigned AlignSize = llvm::alignTo(Size, FieldAlign);
252 if ((AlignSize % CBufferAlign) + FieldSize > CBufferAlign) {
253 FieldAlign = CBufferAlign;
254 }
255
256 Size = llvm::alignTo(Size, FieldAlign);
257 Size += FieldSize;
258 }
259 return Size;
260 }
261
262 if (const ConstantArrayType *AT = Context.getAsConstantArrayType(T)) {
263 unsigned ElementCount = AT->getSize().getZExtValue();
264 if (ElementCount == 0)
265 return 0;
266
267 unsigned ElementSize =
268 calculateLegacyCbufferSize(Context, AT->getElementType());
269 unsigned AlignedElementSize = llvm::alignTo(ElementSize, CBufferAlign);
270 return AlignedElementSize * (ElementCount - 1) + ElementSize;
271 }
272
273 if (const VectorType *VT = T->getAs<VectorType>()) {
274 unsigned ElementCount = VT->getNumElements();
275 unsigned ElementSize =
276 calculateLegacyCbufferSize(Context, VT->getElementType());
277 return ElementSize * ElementCount;
278 }
279
280 return Context.getTypeSize(T) / 8;
281}
282
283// Validate packoffset:
284// - if packoffset is used, it must be set on all declarations inside the buffer
285// - packoffset ranges must not overlap
286static void validatePackoffset(Sema &S, HLSLBufferDecl *BufDecl) {
288
289 // Make sure the packoffset annotations are either on all declarations
290 // or on none.
291 bool HasPackOffset = false;
292 bool HasNonPackOffset = false;
293 for (auto *Field : BufDecl->buffer_decls()) {
294 VarDecl *Var = dyn_cast<VarDecl>(Field);
295 if (!Var)
296 continue;
297 if (Field->hasAttr<HLSLPackOffsetAttr>()) {
298 PackOffsetVec.emplace_back(Var, Field->getAttr<HLSLPackOffsetAttr>());
299 HasPackOffset = true;
300 } else {
301 HasNonPackOffset = true;
302 }
303 }
304
305 if (!HasPackOffset)
306 return;
307
308 if (HasNonPackOffset)
309 S.Diag(BufDecl->getLocation(), diag::warn_hlsl_packoffset_mix);
310
311 // Make sure there is no overlap in packoffset - sort PackOffsetVec by offset
312 // and compare adjacent values.
313 bool IsValid = true;
314 ASTContext &Context = S.getASTContext();
315 std::sort(PackOffsetVec.begin(), PackOffsetVec.end(),
316 [](const std::pair<VarDecl *, HLSLPackOffsetAttr *> &LHS,
317 const std::pair<VarDecl *, HLSLPackOffsetAttr *> &RHS) {
318 return LHS.second->getOffsetInBytes() <
319 RHS.second->getOffsetInBytes();
320 });
321 for (unsigned i = 0; i < PackOffsetVec.size() - 1; i++) {
322 VarDecl *Var = PackOffsetVec[i].first;
323 HLSLPackOffsetAttr *Attr = PackOffsetVec[i].second;
324 unsigned Size = calculateLegacyCbufferSize(Context, Var->getType());
325 unsigned Begin = Attr->getOffsetInBytes();
326 unsigned End = Begin + Size;
327 unsigned NextBegin = PackOffsetVec[i + 1].second->getOffsetInBytes();
328 if (End > NextBegin) {
329 VarDecl *NextVar = PackOffsetVec[i + 1].first;
330 S.Diag(NextVar->getLocation(), diag::err_hlsl_packoffset_overlap)
331 << NextVar << Var;
332 IsValid = false;
333 }
334 }
335 BufDecl->setHasValidPackoffset(IsValid);
336}
337
338// Returns true if the array has a zero size = if any of the dimensions is 0
339static bool isZeroSizedArray(const ConstantArrayType *CAT) {
340 while (CAT && !CAT->isZeroSize())
341 CAT = dyn_cast<ConstantArrayType>(
343 return CAT != nullptr;
344}
345
347 const Type *Ty = VD->getType().getTypePtr();
349}
350
351static const HLSLAttributedResourceType *
353 assert(VD->getType()->isHLSLResourceRecordArray() &&
354 "expected array of resource records");
355 const Type *Ty = VD->getType()->getUnqualifiedDesugaredType();
356 while (const ArrayType *AT = dyn_cast<ArrayType>(Ty))
358 return HLSLAttributedResourceType::findHandleTypeOnResource(Ty);
359}
360
361// Returns true if the type is a leaf element type that is not valid to be
362// included in HLSL Buffer, such as a resource class, empty struct, zero-sized
363// array, or a builtin intangible type. Returns false if it is a valid leaf element
364// type or if it is a record type that needs to be inspected further.
368 return true;
369 if (const auto *RD = Ty->getAsCXXRecordDecl())
370 return RD->isEmpty();
371 if (Ty->isConstantArrayType() &&
373 return true;
375 return true;
376 return false;
377}
378
379// Returns true if the struct contains at least one element that prevents it
380// from being included inside HLSL Buffer as is, such as an intangible type,
381// empty struct, or zero-sized array. If it does, a new implicit layout struct
382// needs to be created for HLSL Buffer use that will exclude these unwanted
383// declarations (see createHostLayoutStruct function).
385 if (RD->isHLSLIntangible() || RD->isEmpty())
386 return true;
387 // check fields
388 for (const FieldDecl *Field : RD->fields()) {
389 QualType Ty = Field->getType();
391 return true;
392 if (const auto *RD = Ty->getAsCXXRecordDecl();
394 return true;
395 }
396 // check bases
397 for (const CXXBaseSpecifier &Base : RD->bases())
399 Base.getType()->castAsCXXRecordDecl()))
400 return true;
401 return false;
402}
403
405 DeclContext *DC) {
406 CXXRecordDecl *RD = nullptr;
407 for (NamedDecl *Decl :
409 if (CXXRecordDecl *FoundRD = dyn_cast<CXXRecordDecl>(Decl)) {
410 assert(RD == nullptr &&
411 "there should be at most 1 record by a given name in a scope");
412 RD = FoundRD;
413 }
414 }
415 return RD;
416}
417
418// Creates a name for the buffer layout struct using the provided name base.
419// If the name must be unique (not previously defined), a suffix is added
420// until a unique name is found.
422 bool MustBeUnique) {
423 ASTContext &AST = S.getASTContext();
424
425 IdentifierInfo *NameBaseII = BaseDecl->getIdentifier();
426 llvm::SmallString<64> Name("__cblayout_");
427 if (NameBaseII) {
428 Name.append(NameBaseII->getName());
429 } else {
430 // anonymous struct
431 Name.append("anon");
432 MustBeUnique = true;
433 }
434
435 size_t NameLength = Name.size();
436 IdentifierInfo *II = &AST.Idents.get(Name, tok::TokenKind::identifier);
437 if (!MustBeUnique)
438 return II;
439
440 unsigned suffix = 0;
441 while (true) {
442 if (suffix != 0) {
443 Name.append("_");
444 Name.append(llvm::Twine(suffix).str());
445 II = &AST.Idents.get(Name, tok::TokenKind::identifier);
446 }
447 if (!findRecordDeclInContext(II, BaseDecl->getDeclContext()))
448 return II;
449 // declaration with that name already exists - increment suffix and try
450 // again until unique name is found
451 suffix++;
452 Name.truncate(NameLength);
453 };
454}
455
456// Creates a field declaration of given name and type for HLSL buffer layout
457// struct. Returns nullptr if the type cannot be used in HLSL Buffer layout.
459 IdentifierInfo *II,
460 CXXRecordDecl *LayoutStruct) {
462 return nullptr;
463
464 if (auto *RD = Ty->getAsCXXRecordDecl()) {
466 RD = createHostLayoutStruct(S, RD);
467 if (!RD)
468 return nullptr;
470 }
471 }
472
473 QualType QT = QualType(Ty, 0);
474 ASTContext &AST = S.getASTContext();
476 auto *Field = FieldDecl::Create(AST, LayoutStruct, SourceLocation(),
477 SourceLocation(), II, QT, TSI, nullptr, false,
479 Field->setAccess(AccessSpecifier::AS_public);
480 return Field;
481}
482
483// Creates host layout struct for a struct included in HLSL Buffer.
484// The layout struct will include only fields that are allowed in HLSL buffer.
485// These fields will be filtered out:
486// - resource classes
487// - empty structs
488// - zero-sized arrays
489// Returns nullptr if the resulting layout struct would be empty.
491 CXXRecordDecl *StructDecl) {
492 assert(requiresImplicitBufferLayoutStructure(StructDecl) &&
493 "struct is already HLSL buffer compatible");
494
495 ASTContext &AST = S.getASTContext();
496 DeclContext *DC = StructDecl->getDeclContext();
497 IdentifierInfo *II = getHostLayoutStructName(S, StructDecl, false);
498
499 // reuse existing if the layout struct if it already exists
500 if (CXXRecordDecl *RD = findRecordDeclInContext(II, DC))
501 return RD;
502
503 CXXRecordDecl *LS =
504 CXXRecordDecl::Create(AST, TagDecl::TagKind::Struct, DC, SourceLocation(),
505 SourceLocation(), II);
506 LS->setImplicit(true);
507 LS->addAttr(PackedAttr::CreateImplicit(AST));
508 LS->startDefinition();
509
510 // copy base struct, create HLSL Buffer compatible version if needed
511 if (unsigned NumBases = StructDecl->getNumBases()) {
512 assert(NumBases == 1 && "HLSL supports only one base type");
513 (void)NumBases;
514 CXXBaseSpecifier Base = *StructDecl->bases_begin();
515 CXXRecordDecl *BaseDecl = Base.getType()->castAsCXXRecordDecl();
517 BaseDecl = createHostLayoutStruct(S, BaseDecl);
518 if (BaseDecl) {
519 TypeSourceInfo *TSI =
521 Base = CXXBaseSpecifier(SourceRange(), false, StructDecl->isClass(),
522 AS_none, TSI, SourceLocation());
523 }
524 }
525 if (BaseDecl) {
526 const CXXBaseSpecifier *BasesArray[1] = {&Base};
527 LS->setBases(BasesArray, 1);
528 }
529 }
530
531 // filter struct fields
532 for (const FieldDecl *FD : StructDecl->fields()) {
533 const Type *Ty = FD->getType()->getUnqualifiedDesugaredType();
534 if (FieldDecl *NewFD =
535 createFieldForHostLayoutStruct(S, Ty, FD->getIdentifier(), LS))
536 LS->addDecl(NewFD);
537 }
538 LS->completeDefinition();
539
540 if (LS->field_empty() && LS->getNumBases() == 0)
541 return nullptr;
542
543 DC->addDecl(LS);
544 return LS;
545}
546
547// Creates host layout struct for HLSL Buffer. The struct will include only
548// fields of types that are allowed in HLSL buffer and it will filter out:
549// - static or groupshared variable declarations
550// - resource classes
551// - empty structs
552// - zero-sized arrays
553// - non-variable declarations
554// The layout struct will be added to the HLSLBufferDecl declarations.
556 ASTContext &AST = S.getASTContext();
557 IdentifierInfo *II = getHostLayoutStructName(S, BufDecl, true);
558
559 CXXRecordDecl *LS =
560 CXXRecordDecl::Create(AST, TagDecl::TagKind::Struct, BufDecl,
562 LS->addAttr(PackedAttr::CreateImplicit(AST));
563 LS->setImplicit(true);
564 LS->startDefinition();
565
566 for (Decl *D : BufDecl->buffer_decls()) {
567 VarDecl *VD = dyn_cast<VarDecl>(D);
568 if (!VD || VD->getStorageClass() == SC_Static ||
570 continue;
571 const Type *Ty = VD->getType()->getUnqualifiedDesugaredType();
572 if (FieldDecl *FD =
574 // add the field decl to the layout struct
575 LS->addDecl(FD);
576 // update address space of the original decl to hlsl_constant
577 QualType NewTy =
579 VD->setType(NewTy);
580 }
581 }
582 LS->completeDefinition();
583 BufDecl->addLayoutStruct(LS);
584}
585
587 uint32_t ImplicitBindingOrderID) {
588 auto *Attr =
589 HLSLResourceBindingAttr::CreateImplicit(S.getASTContext(), "", "0", {});
590 Attr->setBinding(RT, std::nullopt, 0);
591 Attr->setImplicitBindingOrderID(ImplicitBindingOrderID);
592 D->addAttr(Attr);
593}
594
595// Handle end of cbuffer/tbuffer declaration
597 auto *BufDecl = cast<HLSLBufferDecl>(Dcl);
598 BufDecl->setRBraceLoc(RBrace);
599
600 validatePackoffset(SemaRef, BufDecl);
601
603
604 // Handle implicit binding if needed.
605 ResourceBindingAttrs ResourceAttrs(Dcl);
606 if (!ResourceAttrs.isExplicit()) {
607 SemaRef.Diag(Dcl->getLocation(), diag::warn_hlsl_implicit_binding);
608 // Use HLSLResourceBindingAttr to transfer implicit binding order_ID
609 // to codegen. If it does not exist, create an implicit attribute.
610 uint32_t OrderID = getNextImplicitBindingOrderID();
611 if (ResourceAttrs.hasBinding())
612 ResourceAttrs.setImplicitOrderID(OrderID);
613 else
615 BufDecl->isCBuffer() ? RegisterType::CBuffer
616 : RegisterType::SRV,
617 OrderID);
618 }
619
620 SemaRef.PopDeclContext();
621}
622
623HLSLNumThreadsAttr *SemaHLSL::mergeNumThreadsAttr(Decl *D,
624 const AttributeCommonInfo &AL,
625 int X, int Y, int Z) {
626 if (HLSLNumThreadsAttr *NT = D->getAttr<HLSLNumThreadsAttr>()) {
627 if (NT->getX() != X || NT->getY() != Y || NT->getZ() != Z) {
628 Diag(NT->getLocation(), diag::err_hlsl_attribute_param_mismatch) << AL;
629 Diag(AL.getLoc(), diag::note_conflicting_attribute);
630 }
631 return nullptr;
632 }
633 return ::new (getASTContext())
634 HLSLNumThreadsAttr(getASTContext(), AL, X, Y, Z);
635}
636
638 const AttributeCommonInfo &AL,
639 int Min, int Max, int Preferred,
640 int SpelledArgsCount) {
641 if (HLSLWaveSizeAttr *WS = D->getAttr<HLSLWaveSizeAttr>()) {
642 if (WS->getMin() != Min || WS->getMax() != Max ||
643 WS->getPreferred() != Preferred ||
644 WS->getSpelledArgsCount() != SpelledArgsCount) {
645 Diag(WS->getLocation(), diag::err_hlsl_attribute_param_mismatch) << AL;
646 Diag(AL.getLoc(), diag::note_conflicting_attribute);
647 }
648 return nullptr;
649 }
650 HLSLWaveSizeAttr *Result = ::new (getASTContext())
651 HLSLWaveSizeAttr(getASTContext(), AL, Min, Max, Preferred);
652 Result->setSpelledArgsCount(SpelledArgsCount);
653 return Result;
654}
655
656HLSLVkConstantIdAttr *
658 int Id) {
659
661 if (TargetInfo.getTriple().getArch() != llvm::Triple::spirv) {
662 Diag(AL.getLoc(), diag::warn_attribute_ignored) << AL;
663 return nullptr;
664 }
665
666 auto *VD = cast<VarDecl>(D);
667
668 if (getSpecConstBuiltinId(VD->getType()->getUnqualifiedDesugaredType()) ==
670 Diag(VD->getLocation(), diag::err_specialization_const);
671 return nullptr;
672 }
673
674 if (!VD->getType().isConstQualified()) {
675 Diag(VD->getLocation(), diag::err_specialization_const);
676 return nullptr;
677 }
678
679 if (HLSLVkConstantIdAttr *CI = D->getAttr<HLSLVkConstantIdAttr>()) {
680 if (CI->getId() != Id) {
681 Diag(CI->getLocation(), diag::err_hlsl_attribute_param_mismatch) << AL;
682 Diag(AL.getLoc(), diag::note_conflicting_attribute);
683 }
684 return nullptr;
685 }
686
687 HLSLVkConstantIdAttr *Result =
688 ::new (getASTContext()) HLSLVkConstantIdAttr(getASTContext(), AL, Id);
689 return Result;
690}
691
692HLSLShaderAttr *
694 llvm::Triple::EnvironmentType ShaderType) {
695 if (HLSLShaderAttr *NT = D->getAttr<HLSLShaderAttr>()) {
696 if (NT->getType() != ShaderType) {
697 Diag(NT->getLocation(), diag::err_hlsl_attribute_param_mismatch) << AL;
698 Diag(AL.getLoc(), diag::note_conflicting_attribute);
699 }
700 return nullptr;
701 }
702 return HLSLShaderAttr::Create(getASTContext(), ShaderType, AL);
703}
704
705HLSLParamModifierAttr *
707 HLSLParamModifierAttr::Spelling Spelling) {
708 // We can only merge an `in` attribute with an `out` attribute. All other
709 // combinations of duplicated attributes are ill-formed.
710 if (HLSLParamModifierAttr *PA = D->getAttr<HLSLParamModifierAttr>()) {
711 if ((PA->isIn() && Spelling == HLSLParamModifierAttr::Keyword_out) ||
712 (PA->isOut() && Spelling == HLSLParamModifierAttr::Keyword_in)) {
713 D->dropAttr<HLSLParamModifierAttr>();
714 SourceRange AdjustedRange = {PA->getLocation(), AL.getRange().getEnd()};
715 return HLSLParamModifierAttr::Create(
716 getASTContext(), /*MergedSpelling=*/true, AdjustedRange,
717 HLSLParamModifierAttr::Keyword_inout);
718 }
719 Diag(AL.getLoc(), diag::err_hlsl_duplicate_parameter_modifier) << AL;
720 Diag(PA->getLocation(), diag::note_conflicting_attribute);
721 return nullptr;
722 }
723 return HLSLParamModifierAttr::Create(getASTContext(), AL);
724}
725
728
730 return;
731
732 // If we have specified a root signature to override the entry function then
733 // attach it now
734 HLSLRootSignatureDecl *SignatureDecl =
736 if (SignatureDecl) {
737 FD->dropAttr<RootSignatureAttr>();
738 // We could look up the SourceRange of the macro here as well
739 AttributeCommonInfo AL(RootSigOverrideIdent, AttributeScopeInfo(),
740 SourceRange(), ParsedAttr::Form::Microsoft());
741 FD->addAttr(::new (getASTContext()) RootSignatureAttr(
742 getASTContext(), AL, RootSigOverrideIdent, SignatureDecl));
743 }
744
745 llvm::Triple::EnvironmentType Env = TargetInfo.getTriple().getEnvironment();
746 if (HLSLShaderAttr::isValidShaderType(Env) && Env != llvm::Triple::Library) {
747 if (const auto *Shader = FD->getAttr<HLSLShaderAttr>()) {
748 // The entry point is already annotated - check that it matches the
749 // triple.
750 if (Shader->getType() != Env) {
751 Diag(Shader->getLocation(), diag::err_hlsl_entry_shader_attr_mismatch)
752 << Shader;
753 FD->setInvalidDecl();
754 }
755 } else {
756 // Implicitly add the shader attribute if the entry function isn't
757 // explicitly annotated.
758 FD->addAttr(HLSLShaderAttr::CreateImplicit(getASTContext(), Env,
759 FD->getBeginLoc()));
760 }
761 } else {
762 switch (Env) {
763 case llvm::Triple::UnknownEnvironment:
764 case llvm::Triple::Library:
765 break;
766 case llvm::Triple::RootSignature:
767 llvm_unreachable("rootsig environment has no functions");
768 default:
769 llvm_unreachable("Unhandled environment in triple");
770 }
771 }
772}
773
774static bool isVkPipelineBuiltin(const ASTContext &AstContext, FunctionDecl *FD,
775 HLSLAppliedSemanticAttr *Semantic,
776 bool IsInput) {
777 if (AstContext.getTargetInfo().getTriple().getOS() != llvm::Triple::Vulkan)
778 return false;
779
780 const auto *ShaderAttr = FD->getAttr<HLSLShaderAttr>();
781 assert(ShaderAttr && "Entry point has no shader attribute");
782 llvm::Triple::EnvironmentType ST = ShaderAttr->getType();
783 auto SemanticName = Semantic->getSemanticName().upper();
784
785 // The SV_Position semantic is lowered to:
786 // - Position built-in for vertex output.
787 // - FragCoord built-in for fragment input.
788 if (SemanticName == "SV_POSITION") {
789 return (ST == llvm::Triple::Vertex && !IsInput) ||
790 (ST == llvm::Triple::Pixel && IsInput);
791 }
792
793 return false;
794}
795
796bool SemaHLSL::determineActiveSemanticOnScalar(FunctionDecl *FD,
797 DeclaratorDecl *OutputDecl,
799 SemanticInfo &ActiveSemantic,
800 SemaHLSL::SemanticContext &SC) {
801 if (ActiveSemantic.Semantic == nullptr) {
802 ActiveSemantic.Semantic = D->getAttr<HLSLParsedSemanticAttr>();
803 if (ActiveSemantic.Semantic)
804 ActiveSemantic.Index = ActiveSemantic.Semantic->getSemanticIndex();
805 }
806
807 if (!ActiveSemantic.Semantic) {
808 Diag(D->getLocation(), diag::err_hlsl_missing_semantic_annotation);
809 return false;
810 }
811
812 auto *A = ::new (getASTContext())
813 HLSLAppliedSemanticAttr(getASTContext(), *ActiveSemantic.Semantic,
814 ActiveSemantic.Semantic->getAttrName()->getName(),
815 ActiveSemantic.Index.value_or(0));
816 if (!A)
818
819 checkSemanticAnnotation(FD, D, A, SC);
820 OutputDecl->addAttr(A);
821
822 unsigned Location = ActiveSemantic.Index.value_or(0);
823
825 SC.CurrentIOType & IOType::In)) {
826 bool HasVkLocation = false;
827 if (auto *A = D->getAttr<HLSLVkLocationAttr>()) {
828 HasVkLocation = true;
829 Location = A->getLocation();
830 }
831
832 if (SC.UsesExplicitVkLocations.value_or(HasVkLocation) != HasVkLocation) {
833 Diag(D->getLocation(), diag::err_hlsl_semantic_partial_explicit_indexing);
834 return false;
835 }
836 SC.UsesExplicitVkLocations = HasVkLocation;
837 }
838
839 const ConstantArrayType *AT = dyn_cast<ConstantArrayType>(D->getType());
840 unsigned ElementCount = AT ? AT->getZExtSize() : 1;
841 ActiveSemantic.Index = Location + ElementCount;
842
843 Twine BaseName = Twine(ActiveSemantic.Semantic->getAttrName()->getName());
844 for (unsigned I = 0; I < ElementCount; ++I) {
845 Twine VariableName = BaseName.concat(Twine(Location + I));
846
847 auto [_, Inserted] = SC.ActiveSemantics.insert(VariableName.str());
848 if (!Inserted) {
849 Diag(D->getLocation(), diag::err_hlsl_semantic_index_overlap)
850 << VariableName.str();
851 return false;
852 }
853 }
854
855 return true;
856}
857
858bool SemaHLSL::determineActiveSemantic(FunctionDecl *FD,
859 DeclaratorDecl *OutputDecl,
861 SemanticInfo &ActiveSemantic,
862 SemaHLSL::SemanticContext &SC) {
863 if (ActiveSemantic.Semantic == nullptr) {
864 ActiveSemantic.Semantic = D->getAttr<HLSLParsedSemanticAttr>();
865 if (ActiveSemantic.Semantic)
866 ActiveSemantic.Index = ActiveSemantic.Semantic->getSemanticIndex();
867 }
868
869 const Type *T = D == FD ? &*FD->getReturnType() : &*D->getType();
871
872 const RecordType *RT = dyn_cast<RecordType>(T);
873 if (!RT)
874 return determineActiveSemanticOnScalar(FD, OutputDecl, D, ActiveSemantic,
875 SC);
876
877 const RecordDecl *RD = RT->getDecl();
878 for (FieldDecl *Field : RD->fields()) {
879 SemanticInfo Info = ActiveSemantic;
880 if (!determineActiveSemantic(FD, OutputDecl, Field, Info, SC)) {
881 Diag(Field->getLocation(), diag::note_hlsl_semantic_used_here) << Field;
882 return false;
883 }
884 if (ActiveSemantic.Semantic)
885 ActiveSemantic = Info;
886 }
887
888 return true;
889}
890
892 const auto *ShaderAttr = FD->getAttr<HLSLShaderAttr>();
893 assert(ShaderAttr && "Entry point has no shader attribute");
894 llvm::Triple::EnvironmentType ST = ShaderAttr->getType();
896 VersionTuple Ver = TargetInfo.getTriple().getOSVersion();
897 switch (ST) {
898 case llvm::Triple::Pixel:
899 case llvm::Triple::Vertex:
900 case llvm::Triple::Geometry:
901 case llvm::Triple::Hull:
902 case llvm::Triple::Domain:
903 case llvm::Triple::RayGeneration:
904 case llvm::Triple::Intersection:
905 case llvm::Triple::AnyHit:
906 case llvm::Triple::ClosestHit:
907 case llvm::Triple::Miss:
908 case llvm::Triple::Callable:
909 if (const auto *NT = FD->getAttr<HLSLNumThreadsAttr>()) {
910 diagnoseAttrStageMismatch(NT, ST,
911 {llvm::Triple::Compute,
912 llvm::Triple::Amplification,
913 llvm::Triple::Mesh});
914 FD->setInvalidDecl();
915 }
916 if (const auto *WS = FD->getAttr<HLSLWaveSizeAttr>()) {
917 diagnoseAttrStageMismatch(WS, ST,
918 {llvm::Triple::Compute,
919 llvm::Triple::Amplification,
920 llvm::Triple::Mesh});
921 FD->setInvalidDecl();
922 }
923 break;
924
925 case llvm::Triple::Compute:
926 case llvm::Triple::Amplification:
927 case llvm::Triple::Mesh:
928 if (!FD->hasAttr<HLSLNumThreadsAttr>()) {
929 Diag(FD->getLocation(), diag::err_hlsl_missing_numthreads)
930 << llvm::Triple::getEnvironmentTypeName(ST);
931 FD->setInvalidDecl();
932 }
933 if (const auto *WS = FD->getAttr<HLSLWaveSizeAttr>()) {
934 if (Ver < VersionTuple(6, 6)) {
935 Diag(WS->getLocation(), diag::err_hlsl_attribute_in_wrong_shader_model)
936 << WS << "6.6";
937 FD->setInvalidDecl();
938 } else if (WS->getSpelledArgsCount() > 1 && Ver < VersionTuple(6, 8)) {
939 Diag(
940 WS->getLocation(),
941 diag::err_hlsl_attribute_number_arguments_insufficient_shader_model)
942 << WS << WS->getSpelledArgsCount() << "6.8";
943 FD->setInvalidDecl();
944 }
945 }
946 break;
947 case llvm::Triple::RootSignature:
948 llvm_unreachable("rootsig environment has no function entry point");
949 default:
950 llvm_unreachable("Unhandled environment in triple");
951 }
952
953 SemaHLSL::SemanticContext InputSC = {};
954 InputSC.CurrentIOType = IOType::In;
955
956 for (ParmVarDecl *Param : FD->parameters()) {
957 SemanticInfo ActiveSemantic;
958 ActiveSemantic.Semantic = Param->getAttr<HLSLParsedSemanticAttr>();
959 if (ActiveSemantic.Semantic)
960 ActiveSemantic.Index = ActiveSemantic.Semantic->getSemanticIndex();
961
962 // FIXME: Verify output semantics in parameters.
963 if (!determineActiveSemantic(FD, Param, Param, ActiveSemantic, InputSC)) {
964 Diag(Param->getLocation(), diag::note_previous_decl) << Param;
965 FD->setInvalidDecl();
966 }
967 }
968
969 SemanticInfo ActiveSemantic;
970 SemaHLSL::SemanticContext OutputSC = {};
971 OutputSC.CurrentIOType = IOType::Out;
972 ActiveSemantic.Semantic = FD->getAttr<HLSLParsedSemanticAttr>();
973 if (ActiveSemantic.Semantic)
974 ActiveSemantic.Index = ActiveSemantic.Semantic->getSemanticIndex();
975 if (!FD->getReturnType()->isVoidType())
976 determineActiveSemantic(FD, FD, FD, ActiveSemantic, OutputSC);
977}
978
979void SemaHLSL::checkSemanticAnnotation(
980 FunctionDecl *EntryPoint, const Decl *Param,
981 const HLSLAppliedSemanticAttr *SemanticAttr, const SemanticContext &SC) {
982 auto *ShaderAttr = EntryPoint->getAttr<HLSLShaderAttr>();
983 assert(ShaderAttr && "Entry point has no shader attribute");
984 llvm::Triple::EnvironmentType ST = ShaderAttr->getType();
985
986 auto SemanticName = SemanticAttr->getSemanticName().upper();
987 if (SemanticName == "SV_DISPATCHTHREADID" ||
988 SemanticName == "SV_GROUPINDEX" || SemanticName == "SV_GROUPTHREADID" ||
989 SemanticName == "SV_GROUPID") {
990
991 if (ST != llvm::Triple::Compute)
992 diagnoseSemanticStageMismatch(SemanticAttr, ST, SC.CurrentIOType,
993 {{llvm::Triple::Compute, IOType::In}});
994
995 if (SemanticAttr->getSemanticIndex() != 0) {
996 std::string PrettyName =
997 "'" + SemanticAttr->getSemanticName().str() + "'";
998 Diag(SemanticAttr->getLoc(),
999 diag::err_hlsl_semantic_indexing_not_supported)
1000 << PrettyName;
1001 }
1002 return;
1003 }
1004
1005 if (SemanticName == "SV_POSITION") {
1006 // SV_Position can be an input or output in vertex shaders,
1007 // but only an input in pixel shaders.
1008 diagnoseSemanticStageMismatch(SemanticAttr, ST, SC.CurrentIOType,
1009 {{llvm::Triple::Vertex, IOType::InOut},
1010 {llvm::Triple::Pixel, IOType::In}});
1011 return;
1012 }
1013
1014 if (SemanticName == "SV_TARGET") {
1015 diagnoseSemanticStageMismatch(SemanticAttr, ST, SC.CurrentIOType,
1016 {{llvm::Triple::Pixel, IOType::Out}});
1017 return;
1018 }
1019
1020 // FIXME: catch-all for non-implemented system semantics reaching this
1021 // location.
1022 if (SemanticAttr->getAttrName()->getName().starts_with_insensitive("SV_"))
1023 llvm_unreachable("Unknown SemanticAttr");
1024}
1025
1026void SemaHLSL::diagnoseAttrStageMismatch(
1027 const Attr *A, llvm::Triple::EnvironmentType Stage,
1028 std::initializer_list<llvm::Triple::EnvironmentType> AllowedStages) {
1029 SmallVector<StringRef, 8> StageStrings;
1030 llvm::transform(AllowedStages, std::back_inserter(StageStrings),
1031 [](llvm::Triple::EnvironmentType ST) {
1032 return StringRef(
1033 HLSLShaderAttr::ConvertEnvironmentTypeToStr(ST));
1034 });
1035 Diag(A->getLoc(), diag::err_hlsl_attr_unsupported_in_stage)
1036 << A->getAttrName() << llvm::Triple::getEnvironmentTypeName(Stage)
1037 << (AllowedStages.size() != 1) << join(StageStrings, ", ");
1038}
1039
1040void SemaHLSL::diagnoseSemanticStageMismatch(
1041 const Attr *A, llvm::Triple::EnvironmentType Stage, IOType CurrentIOType,
1042 std::initializer_list<SemanticStageInfo> Allowed) {
1043
1044 for (auto &Case : Allowed) {
1045 if (Case.Stage != Stage)
1046 continue;
1047
1048 if (CurrentIOType & Case.AllowedIOTypesMask)
1049 return;
1050
1051 SmallVector<std::string, 8> ValidCases;
1052 llvm::transform(
1053 Allowed, std::back_inserter(ValidCases), [](SemanticStageInfo Case) {
1054 SmallVector<std::string, 2> ValidType;
1055 if (Case.AllowedIOTypesMask & IOType::In)
1056 ValidType.push_back("input");
1057 if (Case.AllowedIOTypesMask & IOType::Out)
1058 ValidType.push_back("output");
1059 return std::string(
1060 HLSLShaderAttr::ConvertEnvironmentTypeToStr(Case.Stage)) +
1061 " " + join(ValidType, "/");
1062 });
1063 Diag(A->getLoc(), diag::err_hlsl_semantic_unsupported_iotype_for_stage)
1064 << A->getAttrName() << (CurrentIOType & IOType::In ? "input" : "output")
1065 << llvm::Triple::getEnvironmentTypeName(Case.Stage)
1066 << join(ValidCases, ", ");
1067 return;
1068 }
1069
1070 SmallVector<StringRef, 8> StageStrings;
1071 llvm::transform(
1072 Allowed, std::back_inserter(StageStrings), [](SemanticStageInfo Case) {
1073 return StringRef(
1074 HLSLShaderAttr::ConvertEnvironmentTypeToStr(Case.Stage));
1075 });
1076
1077 Diag(A->getLoc(), diag::err_hlsl_attr_unsupported_in_stage)
1078 << A->getAttrName() << llvm::Triple::getEnvironmentTypeName(Stage)
1079 << (Allowed.size() != 1) << join(StageStrings, ", ");
1080}
1081
1082template <CastKind Kind>
1083static void castVector(Sema &S, ExprResult &E, QualType &Ty, unsigned Sz) {
1084 if (const auto *VTy = Ty->getAs<VectorType>())
1085 Ty = VTy->getElementType();
1086 Ty = S.getASTContext().getExtVectorType(Ty, Sz);
1087 E = S.ImpCastExprToType(E.get(), Ty, Kind);
1088}
1089
1090template <CastKind Kind>
1092 E = S.ImpCastExprToType(E.get(), Ty, Kind);
1093 return Ty;
1094}
1095
1097 Sema &SemaRef, ExprResult &LHS, ExprResult &RHS, QualType LHSType,
1098 QualType RHSType, QualType LElTy, QualType RElTy, bool IsCompAssign) {
1099 bool LHSFloat = LElTy->isRealFloatingType();
1100 bool RHSFloat = RElTy->isRealFloatingType();
1101
1102 if (LHSFloat && RHSFloat) {
1103 if (IsCompAssign ||
1104 SemaRef.getASTContext().getFloatingTypeOrder(LElTy, RElTy) > 0)
1105 return castElement<CK_FloatingCast>(SemaRef, RHS, LHSType);
1106
1107 return castElement<CK_FloatingCast>(SemaRef, LHS, RHSType);
1108 }
1109
1110 if (LHSFloat)
1111 return castElement<CK_IntegralToFloating>(SemaRef, RHS, LHSType);
1112
1113 assert(RHSFloat);
1114 if (IsCompAssign)
1115 return castElement<clang::CK_FloatingToIntegral>(SemaRef, RHS, LHSType);
1116
1117 return castElement<CK_IntegralToFloating>(SemaRef, LHS, RHSType);
1118}
1119
1121 Sema &SemaRef, ExprResult &LHS, ExprResult &RHS, QualType LHSType,
1122 QualType RHSType, QualType LElTy, QualType RElTy, bool IsCompAssign) {
1123
1124 int IntOrder = SemaRef.Context.getIntegerTypeOrder(LElTy, RElTy);
1125 bool LHSSigned = LElTy->hasSignedIntegerRepresentation();
1126 bool RHSSigned = RElTy->hasSignedIntegerRepresentation();
1127 auto &Ctx = SemaRef.getASTContext();
1128
1129 // If both types have the same signedness, use the higher ranked type.
1130 if (LHSSigned == RHSSigned) {
1131 if (IsCompAssign || IntOrder >= 0)
1132 return castElement<CK_IntegralCast>(SemaRef, RHS, LHSType);
1133
1134 return castElement<CK_IntegralCast>(SemaRef, LHS, RHSType);
1135 }
1136
1137 // If the unsigned type has greater than or equal rank of the signed type, use
1138 // the unsigned type.
1139 if (IntOrder != (LHSSigned ? 1 : -1)) {
1140 if (IsCompAssign || RHSSigned)
1141 return castElement<CK_IntegralCast>(SemaRef, RHS, LHSType);
1142 return castElement<CK_IntegralCast>(SemaRef, LHS, RHSType);
1143 }
1144
1145 // At this point the signed type has higher rank than the unsigned type, which
1146 // means it will be the same size or bigger. If the signed type is bigger, it
1147 // can represent all the values of the unsigned type, so select it.
1148 if (Ctx.getIntWidth(LElTy) != Ctx.getIntWidth(RElTy)) {
1149 if (IsCompAssign || LHSSigned)
1150 return castElement<CK_IntegralCast>(SemaRef, RHS, LHSType);
1151 return castElement<CK_IntegralCast>(SemaRef, LHS, RHSType);
1152 }
1153
1154 // This is a bit of an odd duck case in HLSL. It shouldn't happen, but can due
1155 // to C/C++ leaking through. The place this happens today is long vs long
1156 // long. When arguments are vector<unsigned long, N> and vector<long long, N>,
1157 // the long long has higher rank than long even though they are the same size.
1158
1159 // If this is a compound assignment cast the right hand side to the left hand
1160 // side's type.
1161 if (IsCompAssign)
1162 return castElement<CK_IntegralCast>(SemaRef, RHS, LHSType);
1163
1164 // If this isn't a compound assignment we convert to unsigned long long.
1165 QualType ElTy = Ctx.getCorrespondingUnsignedType(LHSSigned ? LElTy : RElTy);
1166 QualType NewTy = Ctx.getExtVectorType(
1167 ElTy, RHSType->castAs<VectorType>()->getNumElements());
1168 (void)castElement<CK_IntegralCast>(SemaRef, RHS, NewTy);
1169
1170 return castElement<CK_IntegralCast>(SemaRef, LHS, NewTy);
1171}
1172
1174 QualType SrcTy) {
1175 if (DestTy->isRealFloatingType() && SrcTy->isRealFloatingType())
1176 return CK_FloatingCast;
1177 if (DestTy->isIntegralType(Ctx) && SrcTy->isIntegralType(Ctx))
1178 return CK_IntegralCast;
1179 if (DestTy->isRealFloatingType())
1180 return CK_IntegralToFloating;
1181 assert(SrcTy->isRealFloatingType() && DestTy->isIntegralType(Ctx));
1182 return CK_FloatingToIntegral;
1183}
1184
1186 QualType LHSType,
1187 QualType RHSType,
1188 bool IsCompAssign) {
1189 const auto *LVecTy = LHSType->getAs<VectorType>();
1190 const auto *RVecTy = RHSType->getAs<VectorType>();
1191 auto &Ctx = getASTContext();
1192
1193 // If the LHS is not a vector and this is a compound assignment, we truncate
1194 // the argument to a scalar then convert it to the LHS's type.
1195 if (!LVecTy && IsCompAssign) {
1196 QualType RElTy = RHSType->castAs<VectorType>()->getElementType();
1197 RHS = SemaRef.ImpCastExprToType(RHS.get(), RElTy, CK_HLSLVectorTruncation);
1198 RHSType = RHS.get()->getType();
1199 if (Ctx.hasSameUnqualifiedType(LHSType, RHSType))
1200 return LHSType;
1201 RHS = SemaRef.ImpCastExprToType(RHS.get(), LHSType,
1202 getScalarCastKind(Ctx, LHSType, RHSType));
1203 return LHSType;
1204 }
1205
1206 unsigned EndSz = std::numeric_limits<unsigned>::max();
1207 unsigned LSz = 0;
1208 if (LVecTy)
1209 LSz = EndSz = LVecTy->getNumElements();
1210 if (RVecTy)
1211 EndSz = std::min(RVecTy->getNumElements(), EndSz);
1212 assert(EndSz != std::numeric_limits<unsigned>::max() &&
1213 "one of the above should have had a value");
1214
1215 // In a compound assignment, the left operand does not change type, the right
1216 // operand is converted to the type of the left operand.
1217 if (IsCompAssign && LSz != EndSz) {
1218 Diag(LHS.get()->getBeginLoc(),
1219 diag::err_hlsl_vector_compound_assignment_truncation)
1220 << LHSType << RHSType;
1221 return QualType();
1222 }
1223
1224 if (RVecTy && RVecTy->getNumElements() > EndSz)
1225 castVector<CK_HLSLVectorTruncation>(SemaRef, RHS, RHSType, EndSz);
1226 if (!IsCompAssign && LVecTy && LVecTy->getNumElements() > EndSz)
1227 castVector<CK_HLSLVectorTruncation>(SemaRef, LHS, LHSType, EndSz);
1228
1229 if (!RVecTy)
1230 castVector<CK_VectorSplat>(SemaRef, RHS, RHSType, EndSz);
1231 if (!IsCompAssign && !LVecTy)
1232 castVector<CK_VectorSplat>(SemaRef, LHS, LHSType, EndSz);
1233
1234 // If we're at the same type after resizing we can stop here.
1235 if (Ctx.hasSameUnqualifiedType(LHSType, RHSType))
1236 return Ctx.getCommonSugaredType(LHSType, RHSType);
1237
1238 QualType LElTy = LHSType->castAs<VectorType>()->getElementType();
1239 QualType RElTy = RHSType->castAs<VectorType>()->getElementType();
1240
1241 // Handle conversion for floating point vectors.
1242 if (LElTy->isRealFloatingType() || RElTy->isRealFloatingType())
1243 return handleFloatVectorBinOpConversion(SemaRef, LHS, RHS, LHSType, RHSType,
1244 LElTy, RElTy, IsCompAssign);
1245
1246 assert(LElTy->isIntegralType(Ctx) && RElTy->isIntegralType(Ctx) &&
1247 "HLSL Vectors can only contain integer or floating point types");
1248 return handleIntegerVectorBinOpConversion(SemaRef, LHS, RHS, LHSType, RHSType,
1249 LElTy, RElTy, IsCompAssign);
1250}
1251
1253 BinaryOperatorKind Opc) {
1254 assert((Opc == BO_LOr || Opc == BO_LAnd) &&
1255 "Called with non-logical operator");
1257 llvm::raw_svector_ostream OS(Buff);
1258 PrintingPolicy PP(SemaRef.getLangOpts());
1259 StringRef NewFnName = Opc == BO_LOr ? "or" : "and";
1260 OS << NewFnName << "(";
1261 LHS->printPretty(OS, nullptr, PP);
1262 OS << ", ";
1263 RHS->printPretty(OS, nullptr, PP);
1264 OS << ")";
1265 SourceRange FullRange = SourceRange(LHS->getBeginLoc(), RHS->getEndLoc());
1266 SemaRef.Diag(LHS->getBeginLoc(), diag::note_function_suggestion)
1267 << NewFnName << FixItHint::CreateReplacement(FullRange, OS.str());
1268}
1269
1270std::pair<IdentifierInfo *, bool>
1272 llvm::hash_code Hash = llvm::hash_value(Signature);
1273 std::string IdStr = "__hlsl_rootsig_decl_" + std::to_string(Hash);
1274 IdentifierInfo *DeclIdent = &(getASTContext().Idents.get(IdStr));
1275
1276 // Check if we have already found a decl of the same name.
1277 LookupResult R(SemaRef, DeclIdent, SourceLocation(),
1279 bool Found = SemaRef.LookupQualifiedName(R, SemaRef.CurContext);
1280 return {DeclIdent, Found};
1281}
1282
1284 SourceLocation Loc, IdentifierInfo *DeclIdent,
1286
1287 if (handleRootSignatureElements(RootElements))
1288 return;
1289
1291 for (auto &RootSigElement : RootElements)
1292 Elements.push_back(RootSigElement.getElement());
1293
1294 auto *SignatureDecl = HLSLRootSignatureDecl::Create(
1295 SemaRef.getASTContext(), /*DeclContext=*/SemaRef.CurContext, Loc,
1296 DeclIdent, SemaRef.getLangOpts().HLSLRootSigVer, Elements);
1297
1298 SignatureDecl->setImplicit();
1299 SemaRef.PushOnScopeChains(SignatureDecl, SemaRef.getCurScope());
1300}
1301
1304 if (RootSigOverrideIdent) {
1305 LookupResult R(SemaRef, RootSigOverrideIdent, SourceLocation(),
1307 if (SemaRef.LookupQualifiedName(R, DC))
1308 return dyn_cast<HLSLRootSignatureDecl>(R.getFoundDecl());
1309 }
1310
1311 return nullptr;
1312}
1313
1314namespace {
1315
1316struct PerVisibilityBindingChecker {
1317 SemaHLSL *S;
1318 // We need one builder per `llvm::dxbc::ShaderVisibility` value.
1319 std::array<llvm::hlsl::BindingInfoBuilder, 8> Builders;
1320
1321 struct ElemInfo {
1322 const hlsl::RootSignatureElement *Elem;
1323 llvm::dxbc::ShaderVisibility Vis;
1324 bool Diagnosed;
1325 };
1326 llvm::SmallVector<ElemInfo> ElemInfoMap;
1327
1328 PerVisibilityBindingChecker(SemaHLSL *S) : S(S) {}
1329
1330 void trackBinding(llvm::dxbc::ShaderVisibility Visibility,
1331 llvm::dxil::ResourceClass RC, uint32_t Space,
1332 uint32_t LowerBound, uint32_t UpperBound,
1333 const hlsl::RootSignatureElement *Elem) {
1334 uint32_t BuilderIndex = llvm::to_underlying(Visibility);
1335 assert(BuilderIndex < Builders.size() &&
1336 "Not enough builders for visibility type");
1337 Builders[BuilderIndex].trackBinding(RC, Space, LowerBound, UpperBound,
1338 static_cast<const void *>(Elem));
1339
1340 static_assert(llvm::to_underlying(llvm::dxbc::ShaderVisibility::All) == 0,
1341 "'All' visibility must come first");
1342 if (Visibility == llvm::dxbc::ShaderVisibility::All)
1343 for (size_t I = 1, E = Builders.size(); I < E; ++I)
1344 Builders[I].trackBinding(RC, Space, LowerBound, UpperBound,
1345 static_cast<const void *>(Elem));
1346
1347 ElemInfoMap.push_back({Elem, Visibility, false});
1348 }
1349
1350 ElemInfo &getInfo(const hlsl::RootSignatureElement *Elem) {
1351 auto It = llvm::lower_bound(
1352 ElemInfoMap, Elem,
1353 [](const auto &LHS, const auto &RHS) { return LHS.Elem < RHS; });
1354 assert(It->Elem == Elem && "Element not in map");
1355 return *It;
1356 }
1357
1358 bool checkOverlap() {
1359 llvm::sort(ElemInfoMap, [](const auto &LHS, const auto &RHS) {
1360 return LHS.Elem < RHS.Elem;
1361 });
1362
1363 bool HadOverlap = false;
1364
1365 using llvm::hlsl::BindingInfoBuilder;
1366 auto ReportOverlap = [this,
1367 &HadOverlap](const BindingInfoBuilder &Builder,
1368 const llvm::hlsl::Binding &Reported) {
1369 HadOverlap = true;
1370
1371 const auto *Elem =
1372 static_cast<const hlsl::RootSignatureElement *>(Reported.Cookie);
1373 const llvm::hlsl::Binding &Previous = Builder.findOverlapping(Reported);
1374 const auto *PrevElem =
1375 static_cast<const hlsl::RootSignatureElement *>(Previous.Cookie);
1376
1377 ElemInfo &Info = getInfo(Elem);
1378 // We will have already diagnosed this binding if there's overlap in the
1379 // "All" visibility as well as any particular visibility.
1380 if (Info.Diagnosed)
1381 return;
1382 Info.Diagnosed = true;
1383
1384 ElemInfo &PrevInfo = getInfo(PrevElem);
1385 llvm::dxbc::ShaderVisibility CommonVis =
1386 Info.Vis == llvm::dxbc::ShaderVisibility::All ? PrevInfo.Vis
1387 : Info.Vis;
1388
1389 this->S->Diag(Elem->getLocation(), diag::err_hlsl_resource_range_overlap)
1390 << llvm::to_underlying(Reported.RC) << Reported.LowerBound
1391 << Reported.isUnbounded() << Reported.UpperBound
1392 << llvm::to_underlying(Previous.RC) << Previous.LowerBound
1393 << Previous.isUnbounded() << Previous.UpperBound << Reported.Space
1394 << CommonVis;
1395
1396 this->S->Diag(PrevElem->getLocation(),
1397 diag::note_hlsl_resource_range_here);
1398 };
1399
1400 for (BindingInfoBuilder &Builder : Builders)
1401 Builder.calculateBindingInfo(ReportOverlap);
1402
1403 return HadOverlap;
1404 }
1405};
1406
1407static CXXMethodDecl *lookupMethod(Sema &S, CXXRecordDecl *RecordDecl,
1408 StringRef Name, SourceLocation Loc) {
1409 DeclarationName DeclName(&S.getASTContext().Idents.get(Name));
1410 LookupResult Result(S, DeclName, Loc, Sema::LookupMemberName);
1411 if (!S.LookupQualifiedName(Result, static_cast<DeclContext *>(RecordDecl)))
1412 return nullptr;
1413 return cast<CXXMethodDecl>(Result.getFoundDecl());
1414}
1415
1416} // end anonymous namespace
1417
1418static bool hasCounterHandle(const CXXRecordDecl *RD) {
1419 if (RD->field_empty())
1420 return false;
1421 auto It = std::next(RD->field_begin());
1422 if (It == RD->field_end())
1423 return false;
1424 const FieldDecl *SecondField = *It;
1425 if (const auto *ResTy =
1426 SecondField->getType()->getAs<HLSLAttributedResourceType>()) {
1427 return ResTy->getAttrs().IsCounter;
1428 }
1429 return false;
1430}
1431
1434 // Define some common error handling functions
1435 bool HadError = false;
1436 auto ReportError = [this, &HadError](SourceLocation Loc, uint32_t LowerBound,
1437 uint32_t UpperBound) {
1438 HadError = true;
1439 this->Diag(Loc, diag::err_hlsl_invalid_rootsig_value)
1440 << LowerBound << UpperBound;
1441 };
1442
1443 auto ReportFloatError = [this, &HadError](SourceLocation Loc,
1444 float LowerBound,
1445 float UpperBound) {
1446 HadError = true;
1447 this->Diag(Loc, diag::err_hlsl_invalid_rootsig_value)
1448 << llvm::formatv("{0:f}", LowerBound).sstr<6>()
1449 << llvm::formatv("{0:f}", UpperBound).sstr<6>();
1450 };
1451
1452 auto VerifyRegister = [ReportError](SourceLocation Loc, uint32_t Register) {
1453 if (!llvm::hlsl::rootsig::verifyRegisterValue(Register))
1454 ReportError(Loc, 0, 0xfffffffe);
1455 };
1456
1457 auto VerifySpace = [ReportError](SourceLocation Loc, uint32_t Space) {
1458 if (!llvm::hlsl::rootsig::verifyRegisterSpace(Space))
1459 ReportError(Loc, 0, 0xffffffef);
1460 };
1461
1462 const uint32_t Version =
1463 llvm::to_underlying(SemaRef.getLangOpts().HLSLRootSigVer);
1464 const uint32_t VersionEnum = Version - 1;
1465 auto ReportFlagError = [this, &HadError, VersionEnum](SourceLocation Loc) {
1466 HadError = true;
1467 this->Diag(Loc, diag::err_hlsl_invalid_rootsig_flag)
1468 << /*version minor*/ VersionEnum;
1469 };
1470
1471 // Iterate through the elements and do basic validations
1472 for (const hlsl::RootSignatureElement &RootSigElem : Elements) {
1473 SourceLocation Loc = RootSigElem.getLocation();
1474 const llvm::hlsl::rootsig::RootElement &Elem = RootSigElem.getElement();
1475 if (const auto *Descriptor =
1476 std::get_if<llvm::hlsl::rootsig::RootDescriptor>(&Elem)) {
1477 VerifyRegister(Loc, Descriptor->Reg.Number);
1478 VerifySpace(Loc, Descriptor->Space);
1479
1480 if (!llvm::hlsl::rootsig::verifyRootDescriptorFlag(Version,
1481 Descriptor->Flags))
1482 ReportFlagError(Loc);
1483 } else if (const auto *Constants =
1484 std::get_if<llvm::hlsl::rootsig::RootConstants>(&Elem)) {
1485 VerifyRegister(Loc, Constants->Reg.Number);
1486 VerifySpace(Loc, Constants->Space);
1487 } else if (const auto *Sampler =
1488 std::get_if<llvm::hlsl::rootsig::StaticSampler>(&Elem)) {
1489 VerifyRegister(Loc, Sampler->Reg.Number);
1490 VerifySpace(Loc, Sampler->Space);
1491
1492 assert(!std::isnan(Sampler->MaxLOD) && !std::isnan(Sampler->MinLOD) &&
1493 "By construction, parseFloatParam can't produce a NaN from a "
1494 "float_literal token");
1495
1496 if (!llvm::hlsl::rootsig::verifyMaxAnisotropy(Sampler->MaxAnisotropy))
1497 ReportError(Loc, 0, 16);
1498 if (!llvm::hlsl::rootsig::verifyMipLODBias(Sampler->MipLODBias))
1499 ReportFloatError(Loc, -16.f, 15.99f);
1500 } else if (const auto *Clause =
1501 std::get_if<llvm::hlsl::rootsig::DescriptorTableClause>(
1502 &Elem)) {
1503 VerifyRegister(Loc, Clause->Reg.Number);
1504 VerifySpace(Loc, Clause->Space);
1505
1506 if (!llvm::hlsl::rootsig::verifyNumDescriptors(Clause->NumDescriptors)) {
1507 // NumDescriptor could techincally be ~0u but that is reserved for
1508 // unbounded, so the diagnostic will not report that as a valid int
1509 // value
1510 ReportError(Loc, 1, 0xfffffffe);
1511 }
1512
1513 if (!llvm::hlsl::rootsig::verifyDescriptorRangeFlag(Version, Clause->Type,
1514 Clause->Flags))
1515 ReportFlagError(Loc);
1516 }
1517 }
1518
1519 PerVisibilityBindingChecker BindingChecker(this);
1520 SmallVector<std::pair<const llvm::hlsl::rootsig::DescriptorTableClause *,
1522 UnboundClauses;
1523
1524 for (const hlsl::RootSignatureElement &RootSigElem : Elements) {
1525 const llvm::hlsl::rootsig::RootElement &Elem = RootSigElem.getElement();
1526 if (const auto *Descriptor =
1527 std::get_if<llvm::hlsl::rootsig::RootDescriptor>(&Elem)) {
1528 uint32_t LowerBound(Descriptor->Reg.Number);
1529 uint32_t UpperBound(LowerBound); // inclusive range
1530
1531 BindingChecker.trackBinding(
1532 Descriptor->Visibility,
1533 static_cast<llvm::dxil::ResourceClass>(Descriptor->Type),
1534 Descriptor->Space, LowerBound, UpperBound, &RootSigElem);
1535 } else if (const auto *Constants =
1536 std::get_if<llvm::hlsl::rootsig::RootConstants>(&Elem)) {
1537 uint32_t LowerBound(Constants->Reg.Number);
1538 uint32_t UpperBound(LowerBound); // inclusive range
1539
1540 BindingChecker.trackBinding(
1541 Constants->Visibility, llvm::dxil::ResourceClass::CBuffer,
1542 Constants->Space, LowerBound, UpperBound, &RootSigElem);
1543 } else if (const auto *Sampler =
1544 std::get_if<llvm::hlsl::rootsig::StaticSampler>(&Elem)) {
1545 uint32_t LowerBound(Sampler->Reg.Number);
1546 uint32_t UpperBound(LowerBound); // inclusive range
1547
1548 BindingChecker.trackBinding(
1549 Sampler->Visibility, llvm::dxil::ResourceClass::Sampler,
1550 Sampler->Space, LowerBound, UpperBound, &RootSigElem);
1551 } else if (const auto *Clause =
1552 std::get_if<llvm::hlsl::rootsig::DescriptorTableClause>(
1553 &Elem)) {
1554 // We'll process these once we see the table element.
1555 UnboundClauses.emplace_back(Clause, &RootSigElem);
1556 } else if (const auto *Table =
1557 std::get_if<llvm::hlsl::rootsig::DescriptorTable>(&Elem)) {
1558 assert(UnboundClauses.size() == Table->NumClauses &&
1559 "Number of unbound elements must match the number of clauses");
1560 bool HasAnySampler = false;
1561 bool HasAnyNonSampler = false;
1562 uint64_t Offset = 0;
1563 bool IsPrevUnbound = false;
1564 for (const auto &[Clause, ClauseElem] : UnboundClauses) {
1565 SourceLocation Loc = ClauseElem->getLocation();
1566 if (Clause->Type == llvm::dxil::ResourceClass::Sampler)
1567 HasAnySampler = true;
1568 else
1569 HasAnyNonSampler = true;
1570
1571 if (HasAnySampler && HasAnyNonSampler)
1572 Diag(Loc, diag::err_hlsl_invalid_mixed_resources);
1573
1574 // Relevant error will have already been reported above and needs to be
1575 // fixed before we can conduct further analysis, so shortcut error
1576 // return
1577 if (Clause->NumDescriptors == 0)
1578 return true;
1579
1580 bool IsAppending =
1581 Clause->Offset == llvm::hlsl::rootsig::DescriptorTableOffsetAppend;
1582 if (!IsAppending)
1583 Offset = Clause->Offset;
1584
1585 uint64_t RangeBound = llvm::hlsl::rootsig::computeRangeBound(
1586 Offset, Clause->NumDescriptors);
1587
1588 if (IsPrevUnbound && IsAppending)
1589 Diag(Loc, diag::err_hlsl_appending_onto_unbound);
1590 else if (!llvm::hlsl::rootsig::verifyNoOverflowedOffset(RangeBound))
1591 Diag(Loc, diag::err_hlsl_offset_overflow) << Offset << RangeBound;
1592
1593 // Update offset to be 1 past this range's bound
1594 Offset = RangeBound + 1;
1595 IsPrevUnbound = Clause->NumDescriptors ==
1596 llvm::hlsl::rootsig::NumDescriptorsUnbounded;
1597
1598 // Compute the register bounds and track resource binding
1599 uint32_t LowerBound(Clause->Reg.Number);
1600 uint32_t UpperBound = llvm::hlsl::rootsig::computeRangeBound(
1601 LowerBound, Clause->NumDescriptors);
1602
1603 BindingChecker.trackBinding(
1604 Table->Visibility,
1605 static_cast<llvm::dxil::ResourceClass>(Clause->Type), Clause->Space,
1606 LowerBound, UpperBound, ClauseElem);
1607 }
1608 UnboundClauses.clear();
1609 }
1610 }
1611
1612 return BindingChecker.checkOverlap();
1613}
1614
1616 if (AL.getNumArgs() != 1) {
1617 Diag(AL.getLoc(), diag::err_attribute_wrong_number_arguments) << AL << 1;
1618 return;
1619 }
1620
1622 if (auto *RS = D->getAttr<RootSignatureAttr>()) {
1623 if (RS->getSignatureIdent() != Ident) {
1624 Diag(AL.getLoc(), diag::err_disallowed_duplicate_attribute) << RS;
1625 return;
1626 }
1627
1628 Diag(AL.getLoc(), diag::warn_duplicate_attribute_exact) << RS;
1629 return;
1630 }
1631
1633 if (SemaRef.LookupQualifiedName(R, D->getDeclContext()))
1634 if (auto *SignatureDecl =
1635 dyn_cast<HLSLRootSignatureDecl>(R.getFoundDecl())) {
1636 D->addAttr(::new (getASTContext()) RootSignatureAttr(
1637 getASTContext(), AL, Ident, SignatureDecl));
1638 }
1639}
1640
1642 llvm::VersionTuple SMVersion =
1643 getASTContext().getTargetInfo().getTriple().getOSVersion();
1644 bool IsDXIL = getASTContext().getTargetInfo().getTriple().getArch() ==
1645 llvm::Triple::dxil;
1646
1647 uint32_t ZMax = 1024;
1648 uint32_t ThreadMax = 1024;
1649 if (IsDXIL && SMVersion.getMajor() <= 4) {
1650 ZMax = 1;
1651 ThreadMax = 768;
1652 } else if (IsDXIL && SMVersion.getMajor() == 5) {
1653 ZMax = 64;
1654 ThreadMax = 1024;
1655 }
1656
1657 uint32_t X;
1658 if (!SemaRef.checkUInt32Argument(AL, AL.getArgAsExpr(0), X))
1659 return;
1660 if (X > 1024) {
1661 Diag(AL.getArgAsExpr(0)->getExprLoc(),
1662 diag::err_hlsl_numthreads_argument_oor)
1663 << 0 << 1024;
1664 return;
1665 }
1666 uint32_t Y;
1667 if (!SemaRef.checkUInt32Argument(AL, AL.getArgAsExpr(1), Y))
1668 return;
1669 if (Y > 1024) {
1670 Diag(AL.getArgAsExpr(1)->getExprLoc(),
1671 diag::err_hlsl_numthreads_argument_oor)
1672 << 1 << 1024;
1673 return;
1674 }
1675 uint32_t Z;
1676 if (!SemaRef.checkUInt32Argument(AL, AL.getArgAsExpr(2), Z))
1677 return;
1678 if (Z > ZMax) {
1679 SemaRef.Diag(AL.getArgAsExpr(2)->getExprLoc(),
1680 diag::err_hlsl_numthreads_argument_oor)
1681 << 2 << ZMax;
1682 return;
1683 }
1684
1685 if (X * Y * Z > ThreadMax) {
1686 Diag(AL.getLoc(), diag::err_hlsl_numthreads_invalid) << ThreadMax;
1687 return;
1688 }
1689
1690 HLSLNumThreadsAttr *NewAttr = mergeNumThreadsAttr(D, AL, X, Y, Z);
1691 if (NewAttr)
1692 D->addAttr(NewAttr);
1693}
1694
1695static bool isValidWaveSizeValue(unsigned Value) {
1696 return llvm::isPowerOf2_32(Value) && Value >= 4 && Value <= 128;
1697}
1698
1700 // validate that the wavesize argument is a power of 2 between 4 and 128
1701 // inclusive
1702 unsigned SpelledArgsCount = AL.getNumArgs();
1703 if (SpelledArgsCount == 0 || SpelledArgsCount > 3)
1704 return;
1705
1706 uint32_t Min;
1707 if (!SemaRef.checkUInt32Argument(AL, AL.getArgAsExpr(0), Min))
1708 return;
1709
1710 uint32_t Max = 0;
1711 if (SpelledArgsCount > 1 &&
1712 !SemaRef.checkUInt32Argument(AL, AL.getArgAsExpr(1), Max))
1713 return;
1714
1715 uint32_t Preferred = 0;
1716 if (SpelledArgsCount > 2 &&
1717 !SemaRef.checkUInt32Argument(AL, AL.getArgAsExpr(2), Preferred))
1718 return;
1719
1720 if (SpelledArgsCount > 2) {
1721 if (!isValidWaveSizeValue(Preferred)) {
1722 Diag(AL.getArgAsExpr(2)->getExprLoc(),
1723 diag::err_attribute_power_of_two_in_range)
1724 << AL << llvm::dxil::MinWaveSize << llvm::dxil::MaxWaveSize
1725 << Preferred;
1726 return;
1727 }
1728 // Preferred not in range.
1729 if (Preferred < Min || Preferred > Max) {
1730 Diag(AL.getArgAsExpr(2)->getExprLoc(),
1731 diag::err_attribute_power_of_two_in_range)
1732 << AL << Min << Max << Preferred;
1733 return;
1734 }
1735 } else if (SpelledArgsCount > 1) {
1736 if (!isValidWaveSizeValue(Max)) {
1737 Diag(AL.getArgAsExpr(1)->getExprLoc(),
1738 diag::err_attribute_power_of_two_in_range)
1739 << AL << llvm::dxil::MinWaveSize << llvm::dxil::MaxWaveSize << Max;
1740 return;
1741 }
1742 if (Max < Min) {
1743 Diag(AL.getLoc(), diag::err_attribute_argument_invalid) << AL << 1;
1744 return;
1745 } else if (Max == Min) {
1746 Diag(AL.getLoc(), diag::warn_attr_min_eq_max) << AL;
1747 }
1748 } else {
1749 if (!isValidWaveSizeValue(Min)) {
1750 Diag(AL.getArgAsExpr(0)->getExprLoc(),
1751 diag::err_attribute_power_of_two_in_range)
1752 << AL << llvm::dxil::MinWaveSize << llvm::dxil::MaxWaveSize << Min;
1753 return;
1754 }
1755 }
1756
1757 HLSLWaveSizeAttr *NewAttr =
1758 mergeWaveSizeAttr(D, AL, Min, Max, Preferred, SpelledArgsCount);
1759 if (NewAttr)
1760 D->addAttr(NewAttr);
1761}
1762
1764 uint32_t ID;
1765 if (!SemaRef.checkUInt32Argument(AL, AL.getArgAsExpr(0), ID))
1766 return;
1767 D->addAttr(::new (getASTContext())
1768 HLSLVkExtBuiltinInputAttr(getASTContext(), AL, ID));
1769}
1770
1772 D->addAttr(::new (getASTContext())
1773 HLSLVkPushConstantAttr(getASTContext(), AL));
1774}
1775
1777 uint32_t Id;
1778 if (!SemaRef.checkUInt32Argument(AL, AL.getArgAsExpr(0), Id))
1779 return;
1780 HLSLVkConstantIdAttr *NewAttr = mergeVkConstantIdAttr(D, AL, Id);
1781 if (NewAttr)
1782 D->addAttr(NewAttr);
1783}
1784
1786 uint32_t Binding = 0;
1787 if (!SemaRef.checkUInt32Argument(AL, AL.getArgAsExpr(0), Binding))
1788 return;
1789 uint32_t Set = 0;
1790 if (AL.getNumArgs() > 1 &&
1791 !SemaRef.checkUInt32Argument(AL, AL.getArgAsExpr(1), Set))
1792 return;
1793
1794 D->addAttr(::new (getASTContext())
1795 HLSLVkBindingAttr(getASTContext(), AL, Binding, Set));
1796}
1797
1799 uint32_t Location;
1800 if (!SemaRef.checkUInt32Argument(AL, AL.getArgAsExpr(0), Location))
1801 return;
1802
1803 D->addAttr(::new (getASTContext())
1804 HLSLVkLocationAttr(getASTContext(), AL, Location));
1805}
1806
1808 const auto *VT = T->getAs<VectorType>();
1809
1810 if (!T->hasUnsignedIntegerRepresentation() ||
1811 (VT && VT->getNumElements() > 3)) {
1812 Diag(AL.getLoc(), diag::err_hlsl_attr_invalid_type)
1813 << AL << "uint/uint2/uint3";
1814 return false;
1815 }
1816
1817 return true;
1818}
1819
1821 const auto *VT = T->getAs<VectorType>();
1822 if (!T->hasFloatingRepresentation() || (VT && VT->getNumElements() > 4)) {
1823 Diag(AL.getLoc(), diag::err_hlsl_attr_invalid_type)
1824 << AL << "float/float1/float2/float3/float4";
1825 return false;
1826 }
1827
1828 return true;
1829}
1830
1832 std::optional<unsigned> Index) {
1833 std::string SemanticName = AL.getAttrName()->getName().upper();
1834
1835 auto *VD = cast<ValueDecl>(D);
1836 QualType ValueType = VD->getType();
1837 if (auto *FD = dyn_cast<FunctionDecl>(D))
1838 ValueType = FD->getReturnType();
1839
1840 bool IsOutput = false;
1841 if (HLSLParamModifierAttr *MA = D->getAttr<HLSLParamModifierAttr>()) {
1842 if (MA->isOut()) {
1843 IsOutput = true;
1844 ValueType = cast<ReferenceType>(ValueType)->getPointeeType();
1845 }
1846 }
1847
1848 if (SemanticName == "SV_DISPATCHTHREADID") {
1849 diagnoseInputIDType(ValueType, AL);
1850 if (IsOutput)
1851 Diag(AL.getLoc(), diag::err_hlsl_semantic_output_not_supported) << AL;
1852 if (Index.has_value())
1853 Diag(AL.getLoc(), diag::err_hlsl_semantic_indexing_not_supported) << AL;
1855 return;
1856 }
1857
1858 if (SemanticName == "SV_GROUPINDEX") {
1859 if (IsOutput)
1860 Diag(AL.getLoc(), diag::err_hlsl_semantic_output_not_supported) << AL;
1861 if (Index.has_value())
1862 Diag(AL.getLoc(), diag::err_hlsl_semantic_indexing_not_supported) << AL;
1864 return;
1865 }
1866
1867 if (SemanticName == "SV_GROUPTHREADID") {
1868 diagnoseInputIDType(ValueType, AL);
1869 if (IsOutput)
1870 Diag(AL.getLoc(), diag::err_hlsl_semantic_output_not_supported) << AL;
1871 if (Index.has_value())
1872 Diag(AL.getLoc(), diag::err_hlsl_semantic_indexing_not_supported) << AL;
1874 return;
1875 }
1876
1877 if (SemanticName == "SV_GROUPID") {
1878 diagnoseInputIDType(ValueType, AL);
1879 if (IsOutput)
1880 Diag(AL.getLoc(), diag::err_hlsl_semantic_output_not_supported) << AL;
1881 if (Index.has_value())
1882 Diag(AL.getLoc(), diag::err_hlsl_semantic_indexing_not_supported) << AL;
1884 return;
1885 }
1886
1887 if (SemanticName == "SV_POSITION") {
1888 const auto *VT = ValueType->getAs<VectorType>();
1889 if (!ValueType->hasFloatingRepresentation() ||
1890 (VT && VT->getNumElements() > 4))
1891 Diag(AL.getLoc(), diag::err_hlsl_attr_invalid_type)
1892 << AL << "float/float1/float2/float3/float4";
1894 return;
1895 }
1896
1897 if (SemanticName == "SV_TARGET") {
1898 const auto *VT = ValueType->getAs<VectorType>();
1899 if (!ValueType->hasFloatingRepresentation() ||
1900 (VT && VT->getNumElements() > 4))
1901 Diag(AL.getLoc(), diag::err_hlsl_attr_invalid_type)
1902 << AL << "float/float1/float2/float3/float4";
1904 return;
1905 }
1906
1907 Diag(AL.getLoc(), diag::err_hlsl_unknown_semantic) << AL;
1908}
1909
1911 uint32_t IndexValue(0), ExplicitIndex(0);
1912 if (!SemaRef.checkUInt32Argument(AL, AL.getArgAsExpr(0), IndexValue) ||
1913 !SemaRef.checkUInt32Argument(AL, AL.getArgAsExpr(1), ExplicitIndex)) {
1914 assert(0 && "HLSLUnparsedSemantic is expected to have 2 int arguments.");
1915 }
1916 assert(IndexValue > 0 ? ExplicitIndex : true);
1917 std::optional<unsigned> Index =
1918 ExplicitIndex ? std::optional<unsigned>(IndexValue) : std::nullopt;
1919
1920 if (AL.getAttrName()->getName().starts_with_insensitive("SV_"))
1921 diagnoseSystemSemanticAttr(D, AL, Index);
1922 else
1924}
1925
1928 Diag(AL.getLoc(), diag::err_hlsl_attr_invalid_ast_node)
1929 << AL << "shader constant in a constant buffer";
1930 return;
1931 }
1932
1933 uint32_t SubComponent;
1934 if (!SemaRef.checkUInt32Argument(AL, AL.getArgAsExpr(0), SubComponent))
1935 return;
1936 uint32_t Component;
1937 if (!SemaRef.checkUInt32Argument(AL, AL.getArgAsExpr(1), Component))
1938 return;
1939
1940 QualType T = cast<VarDecl>(D)->getType().getCanonicalType();
1941 // Check if T is an array or struct type.
1942 // TODO: mark matrix type as aggregate type.
1943 bool IsAggregateTy = (T->isArrayType() || T->isStructureType());
1944
1945 // Check Component is valid for T.
1946 if (Component) {
1947 unsigned Size = getASTContext().getTypeSize(T);
1948 if (IsAggregateTy) {
1949 Diag(AL.getLoc(), diag::err_hlsl_invalid_register_or_packoffset);
1950 return;
1951 } else {
1952 // Make sure Component + sizeof(T) <= 4.
1953 if ((Component * 32 + Size) > 128) {
1954 Diag(AL.getLoc(), diag::err_hlsl_packoffset_cross_reg_boundary);
1955 return;
1956 }
1957 QualType EltTy = T;
1958 if (const auto *VT = T->getAs<VectorType>())
1959 EltTy = VT->getElementType();
1960 unsigned Align = getASTContext().getTypeAlign(EltTy);
1961 if (Align > 32 && Component == 1) {
1962 // NOTE: Component 3 will hit err_hlsl_packoffset_cross_reg_boundary.
1963 // So we only need to check Component 1 here.
1964 Diag(AL.getLoc(), diag::err_hlsl_packoffset_alignment_mismatch)
1965 << Align << EltTy;
1966 return;
1967 }
1968 }
1969 }
1970
1971 D->addAttr(::new (getASTContext()) HLSLPackOffsetAttr(
1972 getASTContext(), AL, SubComponent, Component));
1973}
1974
1976 StringRef Str;
1977 SourceLocation ArgLoc;
1978 if (!SemaRef.checkStringLiteralArgumentAttr(AL, 0, Str, &ArgLoc))
1979 return;
1980
1981 llvm::Triple::EnvironmentType ShaderType;
1982 if (!HLSLShaderAttr::ConvertStrToEnvironmentType(Str, ShaderType)) {
1983 Diag(AL.getLoc(), diag::warn_attribute_type_not_supported)
1984 << AL << Str << ArgLoc;
1985 return;
1986 }
1987
1988 // FIXME: check function match the shader stage.
1989
1990 HLSLShaderAttr *NewAttr = mergeShaderAttr(D, AL, ShaderType);
1991 if (NewAttr)
1992 D->addAttr(NewAttr);
1993}
1994
1996 Sema &S, QualType Wrapped, ArrayRef<const Attr *> AttrList,
1997 QualType &ResType, HLSLAttributedResourceLocInfo *LocInfo) {
1998 assert(AttrList.size() && "expected list of resource attributes");
1999
2000 QualType ContainedTy = QualType();
2001 TypeSourceInfo *ContainedTyInfo = nullptr;
2002 SourceLocation LocBegin = AttrList[0]->getRange().getBegin();
2003 SourceLocation LocEnd = AttrList[0]->getRange().getEnd();
2004
2005 HLSLAttributedResourceType::Attributes ResAttrs;
2006
2007 bool HasResourceClass = false;
2008 for (const Attr *A : AttrList) {
2009 if (!A)
2010 continue;
2011 LocEnd = A->getRange().getEnd();
2012 switch (A->getKind()) {
2013 case attr::HLSLResourceClass: {
2014 ResourceClass RC = cast<HLSLResourceClassAttr>(A)->getResourceClass();
2015 if (HasResourceClass) {
2016 S.Diag(A->getLocation(), ResAttrs.ResourceClass == RC
2017 ? diag::warn_duplicate_attribute_exact
2018 : diag::warn_duplicate_attribute)
2019 << A;
2020 return false;
2021 }
2022 ResAttrs.ResourceClass = RC;
2023 HasResourceClass = true;
2024 break;
2025 }
2026 case attr::HLSLROV:
2027 if (ResAttrs.IsROV) {
2028 S.Diag(A->getLocation(), diag::warn_duplicate_attribute_exact) << A;
2029 return false;
2030 }
2031 ResAttrs.IsROV = true;
2032 break;
2033 case attr::HLSLRawBuffer:
2034 if (ResAttrs.RawBuffer) {
2035 S.Diag(A->getLocation(), diag::warn_duplicate_attribute_exact) << A;
2036 return false;
2037 }
2038 ResAttrs.RawBuffer = true;
2039 break;
2040 case attr::HLSLIsCounter:
2041 if (ResAttrs.IsCounter) {
2042 S.Diag(A->getLocation(), diag::warn_duplicate_attribute_exact) << A;
2043 return false;
2044 }
2045 ResAttrs.IsCounter = true;
2046 break;
2047 case attr::HLSLContainedType: {
2048 const HLSLContainedTypeAttr *CTAttr = cast<HLSLContainedTypeAttr>(A);
2049 QualType Ty = CTAttr->getType();
2050 if (!ContainedTy.isNull()) {
2051 S.Diag(A->getLocation(), ContainedTy == Ty
2052 ? diag::warn_duplicate_attribute_exact
2053 : diag::warn_duplicate_attribute)
2054 << A;
2055 return false;
2056 }
2057 ContainedTy = Ty;
2058 ContainedTyInfo = CTAttr->getTypeLoc();
2059 break;
2060 }
2061 default:
2062 llvm_unreachable("unhandled resource attribute type");
2063 }
2064 }
2065
2066 if (!HasResourceClass) {
2067 S.Diag(AttrList.back()->getRange().getEnd(),
2068 diag::err_hlsl_missing_resource_class);
2069 return false;
2070 }
2071
2073 Wrapped, ContainedTy, ResAttrs);
2074
2075 if (LocInfo && ContainedTyInfo) {
2076 LocInfo->Range = SourceRange(LocBegin, LocEnd);
2077 LocInfo->ContainedTyInfo = ContainedTyInfo;
2078 }
2079 return true;
2080}
2081
2082// Validates and creates an HLSL attribute that is applied as type attribute on
2083// HLSL resource. The attributes are collected in HLSLResourcesTypeAttrs and at
2084// the end of the declaration they are applied to the declaration type by
2085// wrapping it in HLSLAttributedResourceType.
2087 // only allow resource type attributes on intangible types
2088 if (!T->isHLSLResourceType()) {
2089 Diag(AL.getLoc(), diag::err_hlsl_attribute_needs_intangible_type)
2090 << AL << getASTContext().HLSLResourceTy;
2091 return false;
2092 }
2093
2094 // validate number of arguments
2095 if (!AL.checkExactlyNumArgs(SemaRef, AL.getMinArgs()))
2096 return false;
2097
2098 Attr *A = nullptr;
2099
2103 {
2104 AttributeCommonInfo::AS_CXX11, 0, false /*IsAlignas*/,
2105 false /*IsRegularKeywordAttribute*/
2106 });
2107
2108 switch (AL.getKind()) {
2109 case ParsedAttr::AT_HLSLResourceClass: {
2110 if (!AL.isArgIdent(0)) {
2111 Diag(AL.getLoc(), diag::err_attribute_argument_type)
2112 << AL << AANT_ArgumentIdentifier;
2113 return false;
2114 }
2115
2116 IdentifierLoc *Loc = AL.getArgAsIdent(0);
2117 StringRef Identifier = Loc->getIdentifierInfo()->getName();
2118 SourceLocation ArgLoc = Loc->getLoc();
2119
2120 // Validate resource class value
2121 ResourceClass RC;
2122 if (!HLSLResourceClassAttr::ConvertStrToResourceClass(Identifier, RC)) {
2123 Diag(ArgLoc, diag::warn_attribute_type_not_supported)
2124 << "ResourceClass" << Identifier;
2125 return false;
2126 }
2127 A = HLSLResourceClassAttr::Create(getASTContext(), RC, ACI);
2128 break;
2129 }
2130
2131 case ParsedAttr::AT_HLSLROV:
2132 A = HLSLROVAttr::Create(getASTContext(), ACI);
2133 break;
2134
2135 case ParsedAttr::AT_HLSLRawBuffer:
2136 A = HLSLRawBufferAttr::Create(getASTContext(), ACI);
2137 break;
2138
2139 case ParsedAttr::AT_HLSLIsCounter:
2140 A = HLSLIsCounterAttr::Create(getASTContext(), ACI);
2141 break;
2142
2143 case ParsedAttr::AT_HLSLContainedType: {
2144 if (AL.getNumArgs() != 1 && !AL.hasParsedType()) {
2145 Diag(AL.getLoc(), diag::err_attribute_wrong_number_arguments) << AL << 1;
2146 return false;
2147 }
2148
2149 TypeSourceInfo *TSI = nullptr;
2150 QualType QT = SemaRef.GetTypeFromParser(AL.getTypeArg(), &TSI);
2151 assert(TSI && "no type source info for attribute argument");
2152 if (SemaRef.RequireCompleteType(TSI->getTypeLoc().getBeginLoc(), QT,
2153 diag::err_incomplete_type))
2154 return false;
2155 A = HLSLContainedTypeAttr::Create(getASTContext(), TSI, ACI);
2156 break;
2157 }
2158
2159 default:
2160 llvm_unreachable("unhandled HLSL attribute");
2161 }
2162
2163 HLSLResourcesTypeAttrs.emplace_back(A);
2164 return true;
2165}
2166
2167// Combines all resource type attributes and creates HLSLAttributedResourceType.
2169 if (!HLSLResourcesTypeAttrs.size())
2170 return CurrentType;
2171
2172 QualType QT = CurrentType;
2175 HLSLResourcesTypeAttrs, QT, &LocInfo)) {
2176 const HLSLAttributedResourceType *RT =
2178
2179 // Temporarily store TypeLoc information for the new type.
2180 // It will be transferred to HLSLAttributesResourceTypeLoc
2181 // shortly after the type is created by TypeSpecLocFiller which
2182 // will call the TakeLocForHLSLAttribute method below.
2183 LocsForHLSLAttributedResources.insert(std::pair(RT, LocInfo));
2184 }
2185 HLSLResourcesTypeAttrs.clear();
2186 return QT;
2187}
2188
2189// Returns source location for the HLSLAttributedResourceType
2191SemaHLSL::TakeLocForHLSLAttribute(const HLSLAttributedResourceType *RT) {
2192 HLSLAttributedResourceLocInfo LocInfo = {};
2193 auto I = LocsForHLSLAttributedResources.find(RT);
2194 if (I != LocsForHLSLAttributedResources.end()) {
2195 LocInfo = I->second;
2196 LocsForHLSLAttributedResources.erase(I);
2197 return LocInfo;
2198 }
2199 LocInfo.Range = SourceRange();
2200 return LocInfo;
2201}
2202
// Walks through the global variable declaration, collects all resource binding
// requirements and adds them to Bindings
2205void SemaHLSL::collectResourceBindingsOnUserRecordDecl(const VarDecl *VD,
2206 const RecordType *RT) {
2207 const RecordDecl *RD = RT->getDecl()->getDefinitionOrSelf();
2208 for (FieldDecl *FD : RD->fields()) {
2209 const Type *Ty = FD->getType()->getUnqualifiedDesugaredType();
2210
2211 // Unwrap arrays
2212 // FIXME: Calculate array size while unwrapping
2213 assert(!Ty->isIncompleteArrayType() &&
2214 "incomplete arrays inside user defined types are not supported");
2215 while (Ty->isConstantArrayType()) {
2218 }
2219
2220 if (!Ty->isRecordType())
2221 continue;
2222
2223 if (const HLSLAttributedResourceType *AttrResType =
2224 HLSLAttributedResourceType::findHandleTypeOnResource(Ty)) {
2225 // Add a new DeclBindingInfo to Bindings if it does not already exist
2226 ResourceClass RC = AttrResType->getAttrs().ResourceClass;
2227 DeclBindingInfo *DBI = Bindings.getDeclBindingInfo(VD, RC);
2228 if (!DBI)
2229 Bindings.addDeclBindingInfo(VD, RC);
2230 } else if (const RecordType *RT = dyn_cast<RecordType>(Ty)) {
2231 // Recursively scan embedded struct or class; it would be nice to do this
2232 // without recursion, but tricky to correctly calculate the size of the
2233 // binding, which is something we are probably going to need to do later
2234 // on. Hopefully nesting of structs in structs too many levels is
2235 // unlikely.
2236 collectResourceBindingsOnUserRecordDecl(VD, RT);
2237 }
2238 }
2239}
2240
2241// Diagnose localized register binding errors for a single binding; does not
2242// diagnose resource binding on user record types, that will be done later
2243// in processResourceBindingOnDecl based on the information collected in
2244// collectResourceBindingsOnVarDecl.
2245// Returns false if the register binding is not valid.
2247 Decl *D, RegisterType RegType,
2248 bool SpecifiedSpace) {
2249 int RegTypeNum = static_cast<int>(RegType);
2250
2251 // check if the decl type is groupshared
2252 if (D->hasAttr<HLSLGroupSharedAddressSpaceAttr>()) {
2253 S.Diag(ArgLoc, diag::err_hlsl_binding_type_mismatch) << RegTypeNum;
2254 return false;
2255 }
2256
2257 // Cbuffers and Tbuffers are HLSLBufferDecl types
2258 if (HLSLBufferDecl *CBufferOrTBuffer = dyn_cast<HLSLBufferDecl>(D)) {
2259 ResourceClass RC = CBufferOrTBuffer->isCBuffer() ? ResourceClass::CBuffer
2260 : ResourceClass::SRV;
2261 if (RegType == getRegisterType(RC))
2262 return true;
2263
2264 S.Diag(D->getLocation(), diag::err_hlsl_binding_type_mismatch)
2265 << RegTypeNum;
2266 return false;
2267 }
2268
2269 // Samplers, UAVs, and SRVs are VarDecl types
2270 assert(isa<VarDecl>(D) && "D is expected to be VarDecl or HLSLBufferDecl");
2271 VarDecl *VD = cast<VarDecl>(D);
2272
2273 // Resource
2274 if (const HLSLAttributedResourceType *AttrResType =
2275 HLSLAttributedResourceType::findHandleTypeOnResource(
2276 VD->getType().getTypePtr())) {
2277 if (RegType == getRegisterType(AttrResType))
2278 return true;
2279
2280 S.Diag(D->getLocation(), diag::err_hlsl_binding_type_mismatch)
2281 << RegTypeNum;
2282 return false;
2283 }
2284
2285 const clang::Type *Ty = VD->getType().getTypePtr();
2286 while (Ty->isArrayType())
2288
2289 // Basic types
2290 if (Ty->isArithmeticType() || Ty->isVectorType()) {
2291 bool DeclaredInCOrTBuffer = isa<HLSLBufferDecl>(D->getDeclContext());
2292 if (SpecifiedSpace && !DeclaredInCOrTBuffer)
2293 S.Diag(ArgLoc, diag::err_hlsl_space_on_global_constant);
2294
2295 if (!DeclaredInCOrTBuffer && (Ty->isIntegralType(S.getASTContext()) ||
2296 Ty->isFloatingType() || Ty->isVectorType())) {
2297 // Register annotation on default constant buffer declaration ($Globals)
2298 if (RegType == RegisterType::CBuffer)
2299 S.Diag(ArgLoc, diag::warn_hlsl_deprecated_register_type_b);
2300 else if (RegType != RegisterType::C)
2301 S.Diag(ArgLoc, diag::err_hlsl_binding_type_mismatch) << RegTypeNum;
2302 else
2303 return true;
2304 } else {
2305 if (RegType == RegisterType::C)
2306 S.Diag(ArgLoc, diag::warn_hlsl_register_type_c_packoffset);
2307 else
2308 S.Diag(ArgLoc, diag::err_hlsl_binding_type_mismatch) << RegTypeNum;
2309 }
2310 return false;
2311 }
2312 if (Ty->isRecordType())
2313 // RecordTypes will be diagnosed in processResourceBindingOnDecl
2314 // that is called from ActOnVariableDeclarator
2315 return true;
2316
2317 // Anything else is an error
2318 S.Diag(ArgLoc, diag::err_hlsl_binding_type_mismatch) << RegTypeNum;
2319 return false;
2320}
2321
2323 RegisterType regType) {
2324 // make sure that there are no two register annotations
2325 // applied to the decl with the same register type
2326 bool RegisterTypesDetected[5] = {false};
2327 RegisterTypesDetected[static_cast<int>(regType)] = true;
2328
2329 for (auto it = TheDecl->attr_begin(); it != TheDecl->attr_end(); ++it) {
2330 if (HLSLResourceBindingAttr *attr =
2331 dyn_cast<HLSLResourceBindingAttr>(*it)) {
2332
2333 RegisterType otherRegType = attr->getRegisterType();
2334 if (RegisterTypesDetected[static_cast<int>(otherRegType)]) {
2335 int otherRegTypeNum = static_cast<int>(otherRegType);
2336 S.Diag(TheDecl->getLocation(),
2337 diag::err_hlsl_duplicate_register_annotation)
2338 << otherRegTypeNum;
2339 return false;
2340 }
2341 RegisterTypesDetected[static_cast<int>(otherRegType)] = true;
2342 }
2343 }
2344 return true;
2345}
2346
2348 Decl *D, RegisterType RegType,
2349 bool SpecifiedSpace) {
2350
2351 // exactly one of these two types should be set
2352 assert(((isa<VarDecl>(D) && !isa<HLSLBufferDecl>(D)) ||
2353 (!isa<VarDecl>(D) && isa<HLSLBufferDecl>(D))) &&
2354 "expecting VarDecl or HLSLBufferDecl");
2355
2356 // check if the declaration contains resource matching the register type
2357 if (!DiagnoseLocalRegisterBinding(S, ArgLoc, D, RegType, SpecifiedSpace))
2358 return false;
2359
2360 // next, if multiple register annotations exist, check that none conflict.
2361 return ValidateMultipleRegisterAnnotations(S, D, RegType);
2362}
2363
2364// return false if the slot count exceeds the limit, true otherwise
2365static bool AccumulateHLSLResourceSlots(QualType Ty, uint64_t &StartSlot,
2366 const uint64_t &Limit,
2367 const ResourceClass ResClass,
2368 ASTContext &Ctx,
2369 uint64_t ArrayCount = 1) {
2370 Ty = Ty.getCanonicalType();
2371 const Type *T = Ty.getTypePtr();
2372
2373 // Early exit if already overflowed
2374 if (StartSlot > Limit)
2375 return false;
2376
2377 // Case 1: array type
2378 if (const auto *AT = dyn_cast<ArrayType>(T)) {
2379 uint64_t Count = 1;
2380
2381 if (const auto *CAT = dyn_cast<ConstantArrayType>(AT))
2382 Count = CAT->getSize().getZExtValue();
2383
2384 QualType ElemTy = AT->getElementType();
2385 return AccumulateHLSLResourceSlots(ElemTy, StartSlot, Limit, ResClass, Ctx,
2386 ArrayCount * Count);
2387 }
2388
2389 // Case 2: resource leaf
2390 if (auto ResTy = dyn_cast<HLSLAttributedResourceType>(T)) {
2391 // First ensure this resource counts towards the corresponding
2392 // register type limit.
2393 if (ResTy->getAttrs().ResourceClass != ResClass)
2394 return true;
2395
2396 // Validate highest slot used
2397 uint64_t EndSlot = StartSlot + ArrayCount - 1;
2398 if (EndSlot > Limit)
2399 return false;
2400
2401 // Advance SlotCount past the consumed range
2402 StartSlot = EndSlot + 1;
2403 return true;
2404 }
2405
2406 // Case 3: struct / record
2407 if (const auto *RT = dyn_cast<RecordType>(T)) {
2408 const RecordDecl *RD = RT->getDecl();
2409
2410 if (const auto *CXXRD = dyn_cast<CXXRecordDecl>(RD)) {
2411 for (const CXXBaseSpecifier &Base : CXXRD->bases()) {
2412 if (!AccumulateHLSLResourceSlots(Base.getType(), StartSlot, Limit,
2413 ResClass, Ctx, ArrayCount))
2414 return false;
2415 }
2416 }
2417
2418 for (const FieldDecl *Field : RD->fields()) {
2419 if (!AccumulateHLSLResourceSlots(Field->getType(), StartSlot, Limit,
2420 ResClass, Ctx, ArrayCount))
2421 return false;
2422 }
2423
2424 return true;
2425 }
2426
2427 // Case 4: everything else
2428 return true;
2429}
2430
2431// return true if there is something invalid, false otherwise
2432static bool ValidateRegisterNumber(uint64_t SlotNum, Decl *TheDecl,
2433 ASTContext &Ctx, RegisterType RegTy) {
2434 const uint64_t Limit = UINT32_MAX;
2435 if (SlotNum > Limit)
2436 return true;
2437
2438 // after verifying the number doesn't exceed uint32max, we don't need
2439 // to look further into c or i register types
2440 if (RegTy == RegisterType::C || RegTy == RegisterType::I)
2441 return false;
2442
2443 if (VarDecl *VD = dyn_cast<VarDecl>(TheDecl)) {
2444 uint64_t BaseSlot = SlotNum;
2445
2446 if (!AccumulateHLSLResourceSlots(VD->getType(), SlotNum, Limit,
2447 getResourceClass(RegTy), Ctx))
2448 return true;
2449
2450 // After AccumulateHLSLResourceSlots runs, SlotNum is now
2451 // the first free slot; last used was SlotNum - 1
2452 return (BaseSlot > Limit);
2453 }
2454 // handle the cbuffer/tbuffer case
2455 if (isa<HLSLBufferDecl>(TheDecl))
2456 // resources cannot be put within a cbuffer, so no need
2457 // to analyze the structure since the register number
2458 // won't be pushed any higher.
2459 return (SlotNum > Limit);
2460
2461 // we don't expect any other decl type, so fail
2462 llvm_unreachable("unexpected decl type");
2463}
2464
2466 if (VarDecl *VD = dyn_cast<VarDecl>(TheDecl)) {
2467 QualType Ty = VD->getType();
2468 if (const auto *IAT = dyn_cast<IncompleteArrayType>(Ty))
2469 Ty = IAT->getElementType();
2470 if (SemaRef.RequireCompleteType(TheDecl->getBeginLoc(), Ty,
2471 diag::err_incomplete_type))
2472 return;
2473 }
2474
2475 StringRef Slot = "";
2476 StringRef Space = "";
2477 SourceLocation SlotLoc, SpaceLoc;
2478
2479 if (!AL.isArgIdent(0)) {
2480 Diag(AL.getLoc(), diag::err_attribute_argument_type)
2481 << AL << AANT_ArgumentIdentifier;
2482 return;
2483 }
2484 IdentifierLoc *Loc = AL.getArgAsIdent(0);
2485
2486 if (AL.getNumArgs() == 2) {
2487 Slot = Loc->getIdentifierInfo()->getName();
2488 SlotLoc = Loc->getLoc();
2489 if (!AL.isArgIdent(1)) {
2490 Diag(AL.getLoc(), diag::err_attribute_argument_type)
2491 << AL << AANT_ArgumentIdentifier;
2492 return;
2493 }
2494 Loc = AL.getArgAsIdent(1);
2495 Space = Loc->getIdentifierInfo()->getName();
2496 SpaceLoc = Loc->getLoc();
2497 } else {
2498 StringRef Str = Loc->getIdentifierInfo()->getName();
2499 if (Str.starts_with("space")) {
2500 Space = Str;
2501 SpaceLoc = Loc->getLoc();
2502 } else {
2503 Slot = Str;
2504 SlotLoc = Loc->getLoc();
2505 Space = "space0";
2506 }
2507 }
2508
2509 RegisterType RegType = RegisterType::SRV;
2510 std::optional<unsigned> SlotNum;
2511 unsigned SpaceNum = 0;
2512
2513 // Validate slot
2514 if (!Slot.empty()) {
2515 if (!convertToRegisterType(Slot, &RegType)) {
2516 Diag(SlotLoc, diag::err_hlsl_binding_type_invalid) << Slot.substr(0, 1);
2517 return;
2518 }
2519 if (RegType == RegisterType::I) {
2520 Diag(SlotLoc, diag::warn_hlsl_deprecated_register_type_i);
2521 return;
2522 }
2523 const StringRef SlotNumStr = Slot.substr(1);
2524
2525 uint64_t N;
2526
2527 // validate that the slot number is a non-empty number
2528 if (SlotNumStr.getAsInteger(10, N)) {
2529 Diag(SlotLoc, diag::err_hlsl_unsupported_register_number);
2530 return;
2531 }
2532
2533 // Validate register number. It should not exceed UINT32_MAX,
2534 // including if the resource type is an array that starts
2535 // before UINT32_MAX, but ends afterwards.
2536 if (ValidateRegisterNumber(N, TheDecl, getASTContext(), RegType)) {
2537 Diag(SlotLoc, diag::err_hlsl_register_number_too_large);
2538 return;
2539 }
2540
2541 // the slot number has been validated and does not exceed UINT32_MAX
2542 SlotNum = (unsigned)N;
2543 }
2544
2545 // Validate space
2546 if (!Space.starts_with("space")) {
2547 Diag(SpaceLoc, diag::err_hlsl_expected_space) << Space;
2548 return;
2549 }
2550 StringRef SpaceNumStr = Space.substr(5);
2551 if (SpaceNumStr.getAsInteger(10, SpaceNum)) {
2552 Diag(SpaceLoc, diag::err_hlsl_expected_space) << Space;
2553 return;
2554 }
2555
2556 // If we have slot, diagnose it is the right register type for the decl
2557 if (SlotNum.has_value())
2558 if (!DiagnoseHLSLRegisterAttribute(SemaRef, SlotLoc, TheDecl, RegType,
2559 !SpaceLoc.isInvalid()))
2560 return;
2561
2562 HLSLResourceBindingAttr *NewAttr =
2563 HLSLResourceBindingAttr::Create(getASTContext(), Slot, Space, AL);
2564 if (NewAttr) {
2565 NewAttr->setBinding(RegType, SlotNum, SpaceNum);
2566 TheDecl->addAttr(NewAttr);
2567 }
2568}
2569
2571 HLSLParamModifierAttr *NewAttr = mergeParamModifierAttr(
2572 D, AL,
2573 static_cast<HLSLParamModifierAttr::Spelling>(AL.getSemanticSpelling()));
2574 if (NewAttr)
2575 D->addAttr(NewAttr);
2576}
2577
2578namespace {
2579
2580/// This class implements HLSL availability diagnostics for default
2581/// and relaxed mode
2582///
2583/// The goal of this diagnostic is to emit an error or warning when an
2584/// unavailable API is found in code that is reachable from the shader
2585/// entry function or from an exported function (when compiling a shader
2586/// library).
2587///
2588/// This is done by traversing the AST of all shader entry point functions
2589/// and of all exported functions, and any functions that are referenced
2590/// from this AST. In other words, any functions that are reachable from
2591/// the entry points.
2592class DiagnoseHLSLAvailability : public DynamicRecursiveASTVisitor {
2593 Sema &SemaRef;
2594
2595 // Stack of functions to be scaned
2597
2598 // Tracks which environments functions have been scanned in.
2599 //
2600 // Maps FunctionDecl to an unsigned number that represents the set of shader
2601 // environments the function has been scanned for.
2602 // The llvm::Triple::EnvironmentType enum values for shader stages guaranteed
2603 // to be numbered from llvm::Triple::Pixel to llvm::Triple::Amplification
2604 // (verified by static_asserts in Triple.cpp), we can use it to index
2605 // individual bits in the set, as long as we shift the values to start with 0
2606 // by subtracting the value of llvm::Triple::Pixel first.
2607 //
2608 // The N'th bit in the set will be set if the function has been scanned
2609 // in shader environment whose llvm::Triple::EnvironmentType integer value
2610 // equals (llvm::Triple::Pixel + N).
2611 //
2612 // For example, if a function has been scanned in compute and pixel stage
2613 // environment, the value will be 0x21 (100001 binary) because:
2614 //
2615 // (int)(llvm::Triple::Pixel - llvm::Triple::Pixel) == 0
2616 // (int)(llvm::Triple::Compute - llvm::Triple::Pixel) == 5
2617 //
2618 // A FunctionDecl is mapped to 0 (or not included in the map) if it has not
2619 // been scanned in any environment.
2620 llvm::DenseMap<const FunctionDecl *, unsigned> ScannedDecls;
2621
2622 // Do not access these directly, use the get/set methods below to make
2623 // sure the values are in sync
2624 llvm::Triple::EnvironmentType CurrentShaderEnvironment;
2625 unsigned CurrentShaderStageBit;
2626
2627 // True if scanning a function that was already scanned in a different
2628 // shader stage context, and therefore we should not report issues that
2629 // depend only on shader model version because they would be duplicate.
2630 bool ReportOnlyShaderStageIssues;
2631
2632 // Helper methods for dealing with current stage context / environment
2633 void SetShaderStageContext(llvm::Triple::EnvironmentType ShaderType) {
2634 static_assert(sizeof(unsigned) >= 4);
2635 assert(HLSLShaderAttr::isValidShaderType(ShaderType));
2636 assert((unsigned)(ShaderType - llvm::Triple::Pixel) < 31 &&
2637 "ShaderType is too big for this bitmap"); // 31 is reserved for
2638 // "unknown"
2639
2640 unsigned bitmapIndex = ShaderType - llvm::Triple::Pixel;
2641 CurrentShaderEnvironment = ShaderType;
2642 CurrentShaderStageBit = (1 << bitmapIndex);
2643 }
2644
2645 void SetUnknownShaderStageContext() {
2646 CurrentShaderEnvironment = llvm::Triple::UnknownEnvironment;
2647 CurrentShaderStageBit = (1 << 31);
2648 }
2649
2650 llvm::Triple::EnvironmentType GetCurrentShaderEnvironment() const {
2651 return CurrentShaderEnvironment;
2652 }
2653
2654 bool InUnknownShaderStageContext() const {
2655 return CurrentShaderEnvironment == llvm::Triple::UnknownEnvironment;
2656 }
2657
2658 // Helper methods for dealing with shader stage bitmap
2659 void AddToScannedFunctions(const FunctionDecl *FD) {
2660 unsigned &ScannedStages = ScannedDecls[FD];
2661 ScannedStages |= CurrentShaderStageBit;
2662 }
2663
2664 unsigned GetScannedStages(const FunctionDecl *FD) { return ScannedDecls[FD]; }
2665
2666 bool WasAlreadyScannedInCurrentStage(const FunctionDecl *FD) {
2667 return WasAlreadyScannedInCurrentStage(GetScannedStages(FD));
2668 }
2669
2670 bool WasAlreadyScannedInCurrentStage(unsigned ScannerStages) {
2671 return ScannerStages & CurrentShaderStageBit;
2672 }
2673
2674 static bool NeverBeenScanned(unsigned ScannedStages) {
2675 return ScannedStages == 0;
2676 }
2677
2678 // Scanning methods
2679 void HandleFunctionOrMethodRef(FunctionDecl *FD, Expr *RefExpr);
2680 void CheckDeclAvailability(NamedDecl *D, const AvailabilityAttr *AA,
2681 SourceRange Range);
2682 const AvailabilityAttr *FindAvailabilityAttr(const Decl *D);
2683 bool HasMatchingEnvironmentOrNone(const AvailabilityAttr *AA);
2684
2685public:
2686 DiagnoseHLSLAvailability(Sema &SemaRef)
2687 : SemaRef(SemaRef),
2688 CurrentShaderEnvironment(llvm::Triple::UnknownEnvironment),
2689 CurrentShaderStageBit(0), ReportOnlyShaderStageIssues(false) {}
2690
2691 // AST traversal methods
2692 void RunOnTranslationUnit(const TranslationUnitDecl *TU);
2693 void RunOnFunction(const FunctionDecl *FD);
2694
2695 bool VisitDeclRefExpr(DeclRefExpr *DRE) override {
2696 FunctionDecl *FD = llvm::dyn_cast<FunctionDecl>(DRE->getDecl());
2697 if (FD)
2698 HandleFunctionOrMethodRef(FD, DRE);
2699 return true;
2700 }
2701
2702 bool VisitMemberExpr(MemberExpr *ME) override {
2703 FunctionDecl *FD = llvm::dyn_cast<FunctionDecl>(ME->getMemberDecl());
2704 if (FD)
2705 HandleFunctionOrMethodRef(FD, ME);
2706 return true;
2707 }
2708};
2709
2710void DiagnoseHLSLAvailability::HandleFunctionOrMethodRef(FunctionDecl *FD,
2711 Expr *RefExpr) {
2712 assert((isa<DeclRefExpr>(RefExpr) || isa<MemberExpr>(RefExpr)) &&
2713 "expected DeclRefExpr or MemberExpr");
2714
2715 // has a definition -> add to stack to be scanned
2716 const FunctionDecl *FDWithBody = nullptr;
2717 if (FD->hasBody(FDWithBody)) {
2718 if (!WasAlreadyScannedInCurrentStage(FDWithBody))
2719 DeclsToScan.push_back(FDWithBody);
2720 return;
2721 }
2722
2723 // no body -> diagnose availability
2724 const AvailabilityAttr *AA = FindAvailabilityAttr(FD);
2725 if (AA)
2726 CheckDeclAvailability(
2727 FD, AA, SourceRange(RefExpr->getBeginLoc(), RefExpr->getEndLoc()));
2728}
2729
2730void DiagnoseHLSLAvailability::RunOnTranslationUnit(
2731 const TranslationUnitDecl *TU) {
2732
2733 // Iterate over all shader entry functions and library exports, and for those
2734 // that have a body (definiton), run diag scan on each, setting appropriate
2735 // shader environment context based on whether it is a shader entry function
2736 // or an exported function. Exported functions can be in namespaces and in
2737 // export declarations so we need to scan those declaration contexts as well.
2739 DeclContextsToScan.push_back(TU);
2740
2741 while (!DeclContextsToScan.empty()) {
2742 const DeclContext *DC = DeclContextsToScan.pop_back_val();
2743 for (auto &D : DC->decls()) {
2744 // do not scan implicit declaration generated by the implementation
2745 if (D->isImplicit())
2746 continue;
2747
2748 // for namespace or export declaration add the context to the list to be
2749 // scanned later
2750 if (llvm::dyn_cast<NamespaceDecl>(D) || llvm::dyn_cast<ExportDecl>(D)) {
2751 DeclContextsToScan.push_back(llvm::dyn_cast<DeclContext>(D));
2752 continue;
2753 }
2754
2755 // skip over other decls or function decls without body
2756 const FunctionDecl *FD = llvm::dyn_cast<FunctionDecl>(D);
2757 if (!FD || !FD->isThisDeclarationADefinition())
2758 continue;
2759
2760 // shader entry point
2761 if (HLSLShaderAttr *ShaderAttr = FD->getAttr<HLSLShaderAttr>()) {
2762 SetShaderStageContext(ShaderAttr->getType());
2763 RunOnFunction(FD);
2764 continue;
2765 }
2766 // exported library function
2767 // FIXME: replace this loop with external linkage check once issue #92071
2768 // is resolved
2769 bool isExport = FD->isInExportDeclContext();
2770 if (!isExport) {
2771 for (const auto *Redecl : FD->redecls()) {
2772 if (Redecl->isInExportDeclContext()) {
2773 isExport = true;
2774 break;
2775 }
2776 }
2777 }
2778 if (isExport) {
2779 SetUnknownShaderStageContext();
2780 RunOnFunction(FD);
2781 continue;
2782 }
2783 }
2784 }
2785}
2786
2787void DiagnoseHLSLAvailability::RunOnFunction(const FunctionDecl *FD) {
2788 assert(DeclsToScan.empty() && "DeclsToScan should be empty");
2789 DeclsToScan.push_back(FD);
2790
2791 while (!DeclsToScan.empty()) {
2792 // Take one decl from the stack and check it by traversing its AST.
2793 // For any CallExpr found during the traversal add it's callee to the top of
2794 // the stack to be processed next. Functions already processed are stored in
2795 // ScannedDecls.
2796 const FunctionDecl *FD = DeclsToScan.pop_back_val();
2797
2798 // Decl was already scanned
2799 const unsigned ScannedStages = GetScannedStages(FD);
2800 if (WasAlreadyScannedInCurrentStage(ScannedStages))
2801 continue;
2802
2803 ReportOnlyShaderStageIssues = !NeverBeenScanned(ScannedStages);
2804
2805 AddToScannedFunctions(FD);
2806 TraverseStmt(FD->getBody());
2807 }
2808}
2809
2810bool DiagnoseHLSLAvailability::HasMatchingEnvironmentOrNone(
2811 const AvailabilityAttr *AA) {
2812 const IdentifierInfo *IIEnvironment = AA->getEnvironment();
2813 if (!IIEnvironment)
2814 return true;
2815
2816 llvm::Triple::EnvironmentType CurrentEnv = GetCurrentShaderEnvironment();
2817 if (CurrentEnv == llvm::Triple::UnknownEnvironment)
2818 return false;
2819
2820 llvm::Triple::EnvironmentType AttrEnv =
2821 AvailabilityAttr::getEnvironmentType(IIEnvironment->getName());
2822
2823 return CurrentEnv == AttrEnv;
2824}
2825
2826const AvailabilityAttr *
2827DiagnoseHLSLAvailability::FindAvailabilityAttr(const Decl *D) {
2828 AvailabilityAttr const *PartialMatch = nullptr;
2829 // Check each AvailabilityAttr to find the one for this platform.
2830 // For multiple attributes with the same platform try to find one for this
2831 // environment.
2832 for (const auto *A : D->attrs()) {
2833 if (const auto *Avail = dyn_cast<AvailabilityAttr>(A)) {
2834 StringRef AttrPlatform = Avail->getPlatform()->getName();
2835 StringRef TargetPlatform =
2837
2838 // Match the platform name.
2839 if (AttrPlatform == TargetPlatform) {
2840 // Find the best matching attribute for this environment
2841 if (HasMatchingEnvironmentOrNone(Avail))
2842 return Avail;
2843 PartialMatch = Avail;
2844 }
2845 }
2846 }
2847 return PartialMatch;
2848}
2849
2850// Check availability against target shader model version and current shader
2851// stage and emit diagnostic
2852void DiagnoseHLSLAvailability::CheckDeclAvailability(NamedDecl *D,
2853 const AvailabilityAttr *AA,
2854 SourceRange Range) {
2855
2856 const IdentifierInfo *IIEnv = AA->getEnvironment();
2857
2858 if (!IIEnv) {
2859 // The availability attribute does not have environment -> it depends only
2860 // on shader model version and not on specific the shader stage.
2861
2862 // Skip emitting the diagnostics if the diagnostic mode is set to
2863 // strict (-fhlsl-strict-availability) because all relevant diagnostics
2864 // were already emitted in the DiagnoseUnguardedAvailability scan
2865 // (SemaAvailability.cpp).
2866 if (SemaRef.getLangOpts().HLSLStrictAvailability)
2867 return;
2868
2869 // Do not report shader-stage-independent issues if scanning a function
2870 // that was already scanned in a different shader stage context (they would
2871 // be duplicate)
2872 if (ReportOnlyShaderStageIssues)
2873 return;
2874
2875 } else {
2876 // The availability attribute has environment -> we need to know
2877 // the current stage context to property diagnose it.
2878 if (InUnknownShaderStageContext())
2879 return;
2880 }
2881
2882 // Check introduced version and if environment matches
2883 bool EnvironmentMatches = HasMatchingEnvironmentOrNone(AA);
2884 VersionTuple Introduced = AA->getIntroduced();
2885 VersionTuple TargetVersion =
2887
2888 if (TargetVersion >= Introduced && EnvironmentMatches)
2889 return;
2890
2891 // Emit diagnostic message
2892 const TargetInfo &TI = SemaRef.getASTContext().getTargetInfo();
2893 llvm::StringRef PlatformName(
2894 AvailabilityAttr::getPrettyPlatformName(TI.getPlatformName()));
2895
2896 llvm::StringRef CurrentEnvStr =
2897 llvm::Triple::getEnvironmentTypeName(GetCurrentShaderEnvironment());
2898
2899 llvm::StringRef AttrEnvStr =
2900 AA->getEnvironment() ? AA->getEnvironment()->getName() : "";
2901 bool UseEnvironment = !AttrEnvStr.empty();
2902
2903 if (EnvironmentMatches) {
2904 SemaRef.Diag(Range.getBegin(), diag::warn_hlsl_availability)
2905 << Range << D << PlatformName << Introduced.getAsString()
2906 << UseEnvironment << CurrentEnvStr;
2907 } else {
2908 SemaRef.Diag(Range.getBegin(), diag::warn_hlsl_availability_unavailable)
2909 << Range << D;
2910 }
2911
2912 SemaRef.Diag(D->getLocation(), diag::note_partial_availability_specified_here)
2913 << D << PlatformName << Introduced.getAsString()
2914 << SemaRef.Context.getTargetInfo().getPlatformMinVersion().getAsString()
2915 << UseEnvironment << AttrEnvStr << CurrentEnvStr;
2916}
2917
2918} // namespace
2919
2921 // process default CBuffer - create buffer layout struct and invoke codegenCGH
2922 if (!DefaultCBufferDecls.empty()) {
2924 SemaRef.getASTContext(), SemaRef.getCurLexicalContext(),
2925 DefaultCBufferDecls);
2926 addImplicitBindingAttrToDecl(SemaRef, DefaultCBuffer, RegisterType::CBuffer,
2927 getNextImplicitBindingOrderID());
2928 SemaRef.getCurLexicalContext()->addDecl(DefaultCBuffer);
2930
2931 // Set HasValidPackoffset if any of the decls has a register(c#) annotation;
2932 for (const Decl *VD : DefaultCBufferDecls) {
2933 const HLSLResourceBindingAttr *RBA =
2934 VD->getAttr<HLSLResourceBindingAttr>();
2935 if (RBA && RBA->hasRegisterSlot() &&
2936 RBA->getRegisterType() == HLSLResourceBindingAttr::RegisterType::C) {
2937 DefaultCBuffer->setHasValidPackoffset(true);
2938 break;
2939 }
2940 }
2941
2942 DeclGroupRef DG(DefaultCBuffer);
2943 SemaRef.Consumer.HandleTopLevelDecl(DG);
2944 }
2945 diagnoseAvailabilityViolations(TU);
2946}
2947
2948void SemaHLSL::diagnoseAvailabilityViolations(TranslationUnitDecl *TU) {
2949 // Skip running the diagnostics scan if the diagnostic mode is
2950 // strict (-fhlsl-strict-availability) and the target shader stage is known
2951 // because all relevant diagnostics were already emitted in the
2952 // DiagnoseUnguardedAvailability scan (SemaAvailability.cpp).
2954 if (SemaRef.getLangOpts().HLSLStrictAvailability &&
2955 TI.getTriple().getEnvironment() != llvm::Triple::EnvironmentType::Library)
2956 return;
2957
2958 DiagnoseHLSLAvailability(SemaRef).RunOnTranslationUnit(TU);
2959}
2960
2961static bool CheckAllArgsHaveSameType(Sema *S, CallExpr *TheCall) {
2962 assert(TheCall->getNumArgs() > 1);
2963 QualType ArgTy0 = TheCall->getArg(0)->getType();
2964
2965 for (unsigned I = 1, N = TheCall->getNumArgs(); I < N; ++I) {
2967 ArgTy0, TheCall->getArg(I)->getType())) {
2968 S->Diag(TheCall->getBeginLoc(), diag::err_vec_builtin_incompatible_vector)
2969 << TheCall->getDirectCallee() << /*useAllTerminology*/ true
2970 << SourceRange(TheCall->getArg(0)->getBeginLoc(),
2971 TheCall->getArg(N - 1)->getEndLoc());
2972 return true;
2973 }
2974 }
2975 return false;
2976}
2977
2979 QualType ArgType = Arg->getType();
2981 S->Diag(Arg->getBeginLoc(), diag::err_typecheck_convert_incompatible)
2982 << ArgType << ExpectedType << 1 << 0 << 0;
2983 return true;
2984 }
2985 return false;
2986}
2987
2989 Sema *S, CallExpr *TheCall,
2990 llvm::function_ref<bool(Sema *S, SourceLocation Loc, int ArgOrdinal,
2991 clang::QualType PassedType)>
2992 Check) {
2993 for (unsigned I = 0; I < TheCall->getNumArgs(); ++I) {
2994 Expr *Arg = TheCall->getArg(I);
2995 if (Check(S, Arg->getBeginLoc(), I + 1, Arg->getType()))
2996 return true;
2997 }
2998 return false;
2999}
3000
3002 int ArgOrdinal,
3003 clang::QualType PassedType) {
3004 clang::QualType BaseType =
3005 PassedType->isVectorType()
3006 ? PassedType->castAs<clang::VectorType>()->getElementType()
3007 : PassedType;
3008 if (!BaseType->isHalfType() && !BaseType->isFloat32Type())
3009 return S->Diag(Loc, diag::err_builtin_invalid_arg_type)
3010 << ArgOrdinal << /* scalar or vector of */ 5 << /* no int */ 0
3011 << /* half or float */ 2 << PassedType;
3012 return false;
3013}
3014
3015static bool CheckModifiableLValue(Sema *S, CallExpr *TheCall,
3016 unsigned ArgIndex) {
3017 auto *Arg = TheCall->getArg(ArgIndex);
3018 SourceLocation OrigLoc = Arg->getExprLoc();
3019 if (Arg->IgnoreCasts()->isModifiableLvalue(S->Context, &OrigLoc) ==
3021 return false;
3022 S->Diag(OrigLoc, diag::error_hlsl_inout_lvalue) << Arg << 0;
3023 return true;
3024}
3025
3026static bool CheckNoDoubleVectors(Sema *S, SourceLocation Loc, int ArgOrdinal,
3027 clang::QualType PassedType) {
3028 const auto *VecTy = PassedType->getAs<VectorType>();
3029 if (!VecTy)
3030 return false;
3031
3032 if (VecTy->getElementType()->isDoubleType())
3033 return S->Diag(Loc, diag::err_builtin_invalid_arg_type)
3034 << ArgOrdinal << /* scalar */ 1 << /* no int */ 0 << /* fp */ 1
3035 << PassedType;
3036 return false;
3037}
3038
3040 int ArgOrdinal,
3041 clang::QualType PassedType) {
3042 if (!PassedType->hasIntegerRepresentation() &&
3043 !PassedType->hasFloatingRepresentation())
3044 return S->Diag(Loc, diag::err_builtin_invalid_arg_type)
3045 << ArgOrdinal << /* scalar or vector of */ 5 << /* integer */ 1
3046 << /* fp */ 1 << PassedType;
3047 return false;
3048}
3049
3051 int ArgOrdinal,
3052 clang::QualType PassedType) {
3053 if (auto *VecTy = PassedType->getAs<VectorType>())
3054 if (VecTy->getElementType()->isUnsignedIntegerType())
3055 return false;
3056
3057 return S->Diag(Loc, diag::err_builtin_invalid_arg_type)
3058 << ArgOrdinal << /* vector of */ 4 << /* uint */ 3 << /* no fp */ 0
3059 << PassedType;
3060}
3061
3062// checks for unsigned ints of all sizes
3064 int ArgOrdinal,
3065 clang::QualType PassedType) {
3066 if (!PassedType->hasUnsignedIntegerRepresentation())
3067 return S->Diag(Loc, diag::err_builtin_invalid_arg_type)
3068 << ArgOrdinal << /* scalar or vector of */ 5 << /* unsigned int */ 3
3069 << /* no fp */ 0 << PassedType;
3070 return false;
3071}
3072
3073static bool CheckExpectedBitWidth(Sema *S, CallExpr *TheCall,
3074 unsigned ArgOrdinal, unsigned Width) {
3075 QualType ArgTy = TheCall->getArg(0)->getType();
3076 if (auto *VTy = ArgTy->getAs<VectorType>())
3077 ArgTy = VTy->getElementType();
3078 // ensure arg type has expected bit width
3079 uint64_t ElementBitCount =
3081 if (ElementBitCount != Width) {
3082 S->Diag(TheCall->getArg(0)->getBeginLoc(),
3083 diag::err_integer_incorrect_bit_count)
3084 << Width << ElementBitCount;
3085 return true;
3086 }
3087 return false;
3088}
3089
3091 QualType ReturnType) {
3092 auto *VecTyA = TheCall->getArg(0)->getType()->getAs<VectorType>();
3093 if (VecTyA)
3094 ReturnType =
3095 S->Context.getExtVectorType(ReturnType, VecTyA->getNumElements());
3096
3097 TheCall->setType(ReturnType);
3098}
3099
3100static bool CheckScalarOrVector(Sema *S, CallExpr *TheCall, QualType Scalar,
3101 unsigned ArgIndex) {
3102 assert(TheCall->getNumArgs() >= ArgIndex);
3103 QualType ArgType = TheCall->getArg(ArgIndex)->getType();
3104 auto *VTy = ArgType->getAs<VectorType>();
3105 // not the scalar or vector<scalar>
3106 if (!(S->Context.hasSameUnqualifiedType(ArgType, Scalar) ||
3107 (VTy &&
3108 S->Context.hasSameUnqualifiedType(VTy->getElementType(), Scalar)))) {
3109 S->Diag(TheCall->getArg(0)->getBeginLoc(),
3110 diag::err_typecheck_expect_scalar_or_vector)
3111 << ArgType << Scalar;
3112 return true;
3113 }
3114 return false;
3115}
3116
3118 QualType Scalar, unsigned ArgIndex) {
3119 assert(TheCall->getNumArgs() > ArgIndex);
3120
3121 Expr *Arg = TheCall->getArg(ArgIndex);
3122 QualType ArgType = Arg->getType();
3123
3124 // Scalar: T
3125 if (S->Context.hasSameUnqualifiedType(ArgType, Scalar))
3126 return false;
3127
3128 // Vector: vector<T>
3129 if (const auto *VTy = ArgType->getAs<VectorType>()) {
3130 if (S->Context.hasSameUnqualifiedType(VTy->getElementType(), Scalar))
3131 return false;
3132 }
3133
3134 // Matrix: ConstantMatrixType with element type T
3135 if (const auto *MTy = ArgType->getAs<ConstantMatrixType>()) {
3136 if (S->Context.hasSameUnqualifiedType(MTy->getElementType(), Scalar))
3137 return false;
3138 }
3139
3140 // Not a scalar/vector/matrix-of-scalar
3141 S->Diag(Arg->getBeginLoc(),
3142 diag::err_typecheck_expect_scalar_or_vector_or_matrix)
3143 << ArgType << Scalar;
3144 return true;
3145}
3146
3147static bool CheckAnyScalarOrVector(Sema *S, CallExpr *TheCall,
3148 unsigned ArgIndex) {
3149 assert(TheCall->getNumArgs() >= ArgIndex);
3150 QualType ArgType = TheCall->getArg(ArgIndex)->getType();
3151 auto *VTy = ArgType->getAs<VectorType>();
3152 // not the scalar or vector<scalar>
3153 if (!(ArgType->isScalarType() ||
3154 (VTy && VTy->getElementType()->isScalarType()))) {
3155 S->Diag(TheCall->getArg(0)->getBeginLoc(),
3156 diag::err_typecheck_expect_any_scalar_or_vector)
3157 << ArgType << 1;
3158 return true;
3159 }
3160 return false;
3161}
3162
3163static bool CheckWaveActive(Sema *S, CallExpr *TheCall) {
3164 QualType BoolType = S->getASTContext().BoolTy;
3165 assert(TheCall->getNumArgs() >= 1);
3166 QualType ArgType = TheCall->getArg(0)->getType();
3167 auto *VTy = ArgType->getAs<VectorType>();
3168 // is the bool or vector<bool>
3169 if (S->Context.hasSameUnqualifiedType(ArgType, BoolType) ||
3170 (VTy &&
3171 S->Context.hasSameUnqualifiedType(VTy->getElementType(), BoolType))) {
3172 S->Diag(TheCall->getArg(0)->getBeginLoc(),
3173 diag::err_typecheck_expect_any_scalar_or_vector)
3174 << ArgType << 0;
3175 return true;
3176 }
3177 return false;
3178}
3179
3180static bool CheckBoolSelect(Sema *S, CallExpr *TheCall) {
3181 assert(TheCall->getNumArgs() == 3);
3182 Expr *Arg1 = TheCall->getArg(1);
3183 Expr *Arg2 = TheCall->getArg(2);
3184 if (!S->Context.hasSameUnqualifiedType(Arg1->getType(), Arg2->getType())) {
3185 S->Diag(TheCall->getBeginLoc(),
3186 diag::err_typecheck_call_different_arg_types)
3187 << Arg1->getType() << Arg2->getType() << Arg1->getSourceRange()
3188 << Arg2->getSourceRange();
3189 return true;
3190 }
3191
3192 TheCall->setType(Arg1->getType());
3193 return false;
3194}
3195
3196static bool CheckVectorSelect(Sema *S, CallExpr *TheCall) {
3197 assert(TheCall->getNumArgs() == 3);
3198 Expr *Arg1 = TheCall->getArg(1);
3199 QualType Arg1Ty = Arg1->getType();
3200 Expr *Arg2 = TheCall->getArg(2);
3201 QualType Arg2Ty = Arg2->getType();
3202
3203 QualType Arg1ScalarTy = Arg1Ty;
3204 if (auto VTy = Arg1ScalarTy->getAs<VectorType>())
3205 Arg1ScalarTy = VTy->getElementType();
3206
3207 QualType Arg2ScalarTy = Arg2Ty;
3208 if (auto VTy = Arg2ScalarTy->getAs<VectorType>())
3209 Arg2ScalarTy = VTy->getElementType();
3210
3211 if (!S->Context.hasSameUnqualifiedType(Arg1ScalarTy, Arg2ScalarTy))
3212 S->Diag(Arg1->getBeginLoc(), diag::err_hlsl_builtin_scalar_vector_mismatch)
3213 << /* second and third */ 1 << TheCall->getCallee() << Arg1Ty << Arg2Ty;
3214
3215 QualType Arg0Ty = TheCall->getArg(0)->getType();
3216 unsigned Arg0Length = Arg0Ty->getAs<VectorType>()->getNumElements();
3217 unsigned Arg1Length = Arg1Ty->isVectorType()
3218 ? Arg1Ty->getAs<VectorType>()->getNumElements()
3219 : 0;
3220 unsigned Arg2Length = Arg2Ty->isVectorType()
3221 ? Arg2Ty->getAs<VectorType>()->getNumElements()
3222 : 0;
3223 if (Arg1Length > 0 && Arg0Length != Arg1Length) {
3224 S->Diag(TheCall->getBeginLoc(),
3225 diag::err_typecheck_vector_lengths_not_equal)
3226 << Arg0Ty << Arg1Ty << TheCall->getArg(0)->getSourceRange()
3227 << Arg1->getSourceRange();
3228 return true;
3229 }
3230
3231 if (Arg2Length > 0 && Arg0Length != Arg2Length) {
3232 S->Diag(TheCall->getBeginLoc(),
3233 diag::err_typecheck_vector_lengths_not_equal)
3234 << Arg0Ty << Arg2Ty << TheCall->getArg(0)->getSourceRange()
3235 << Arg2->getSourceRange();
3236 return true;
3237 }
3238
3239 TheCall->setType(
3240 S->getASTContext().getExtVectorType(Arg1ScalarTy, Arg0Length));
3241 return false;
3242}
3243
3245 Sema *S, CallExpr *TheCall, unsigned ArgIndex,
3246 llvm::function_ref<bool(const HLSLAttributedResourceType *ResType)> Check =
3247 nullptr) {
3248 assert(TheCall->getNumArgs() >= ArgIndex);
3249 QualType ArgType = TheCall->getArg(ArgIndex)->getType();
3250 const HLSLAttributedResourceType *ResTy =
3251 ArgType.getTypePtr()->getAs<HLSLAttributedResourceType>();
3252 if (!ResTy) {
3253 S->Diag(TheCall->getArg(ArgIndex)->getBeginLoc(),
3254 diag::err_typecheck_expect_hlsl_resource)
3255 << ArgType;
3256 return true;
3257 }
3258 if (Check && Check(ResTy)) {
3259 S->Diag(TheCall->getArg(ArgIndex)->getExprLoc(),
3260 diag::err_invalid_hlsl_resource_type)
3261 << ArgType;
3262 return true;
3263 }
3264 return false;
3265}
3266
3267// Note: returning true in this case results in CheckBuiltinFunctionCall
3268// returning an ExprError
3269bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
3270 switch (BuiltinID) {
3271 case Builtin::BI__builtin_hlsl_adduint64: {
3272 if (SemaRef.checkArgCount(TheCall, 2))
3273 return true;
3274
3275 if (CheckAllArgTypesAreCorrect(&SemaRef, TheCall,
3277 return true;
3278
3279 // ensure arg integers are 32-bits
3280 if (CheckExpectedBitWidth(&SemaRef, TheCall, 0, 32))
3281 return true;
3282
3283 // ensure both args are vectors of total bit size of a multiple of 64
3284 auto *VTy = TheCall->getArg(0)->getType()->getAs<VectorType>();
3285 int NumElementsArg = VTy->getNumElements();
3286 if (NumElementsArg != 2 && NumElementsArg != 4) {
3287 SemaRef.Diag(TheCall->getBeginLoc(), diag::err_vector_incorrect_bit_count)
3288 << 1 /*a multiple of*/ << 64 << NumElementsArg * 32;
3289 return true;
3290 }
3291
3292 // ensure first arg and second arg have the same type
3293 if (CheckAllArgsHaveSameType(&SemaRef, TheCall))
3294 return true;
3295
3296 ExprResult A = TheCall->getArg(0);
3297 QualType ArgTyA = A.get()->getType();
3298 // return type is the same as the input type
3299 TheCall->setType(ArgTyA);
3300 break;
3301 }
3302 case Builtin::BI__builtin_hlsl_resource_getpointer: {
3303 if (SemaRef.checkArgCount(TheCall, 2) ||
3304 CheckResourceHandle(&SemaRef, TheCall, 0) ||
3305 CheckArgTypeMatches(&SemaRef, TheCall->getArg(1),
3306 SemaRef.getASTContext().UnsignedIntTy))
3307 return true;
3308
3309 auto *ResourceTy =
3310 TheCall->getArg(0)->getType()->castAs<HLSLAttributedResourceType>();
3311 QualType ContainedTy = ResourceTy->getContainedType();
3312 auto ReturnType =
3313 SemaRef.Context.getAddrSpaceQualType(ContainedTy, LangAS::hlsl_device);
3314 ReturnType = SemaRef.Context.getPointerType(ReturnType);
3315 TheCall->setType(ReturnType);
3316 TheCall->setValueKind(VK_LValue);
3317
3318 break;
3319 }
3320 case Builtin::BI__builtin_hlsl_resource_load_with_status: {
3321 if (SemaRef.checkArgCount(TheCall, 3) ||
3322 CheckResourceHandle(&SemaRef, TheCall, 0) ||
3323 CheckArgTypeMatches(&SemaRef, TheCall->getArg(1),
3324 SemaRef.getASTContext().UnsignedIntTy) ||
3325 CheckArgTypeMatches(&SemaRef, TheCall->getArg(2),
3326 SemaRef.getASTContext().UnsignedIntTy) ||
3327 CheckModifiableLValue(&SemaRef, TheCall, 2))
3328 return true;
3329
3330 auto *ResourceTy =
3331 TheCall->getArg(0)->getType()->castAs<HLSLAttributedResourceType>();
3332 QualType ReturnType = ResourceTy->getContainedType();
3333 TheCall->setType(ReturnType);
3334
3335 break;
3336 }
3337
3338 case Builtin::BI__builtin_hlsl_resource_uninitializedhandle: {
3339 assert(TheCall->getNumArgs() == 1 && "expected 1 arg");
3340 // Update return type to be the attributed resource type from arg0.
3341 QualType ResourceTy = TheCall->getArg(0)->getType();
3342 TheCall->setType(ResourceTy);
3343 break;
3344 }
3345 case Builtin::BI__builtin_hlsl_resource_handlefrombinding: {
3346 assert(TheCall->getNumArgs() == 6 && "expected 6 args");
3347 // Update return type to be the attributed resource type from arg0.
3348 QualType ResourceTy = TheCall->getArg(0)->getType();
3349 TheCall->setType(ResourceTy);
3350 break;
3351 }
3352 case Builtin::BI__builtin_hlsl_resource_handlefromimplicitbinding: {
3353 assert(TheCall->getNumArgs() == 6 && "expected 6 args");
3354 // Update return type to be the attributed resource type from arg0.
3355 QualType ResourceTy = TheCall->getArg(0)->getType();
3356 TheCall->setType(ResourceTy);
3357 break;
3358 }
3359 case Builtin::BI__builtin_hlsl_resource_counterhandlefromimplicitbinding: {
3360 assert(TheCall->getNumArgs() == 3 && "expected 3 args");
3361 ASTContext &AST = SemaRef.getASTContext();
3362 QualType MainHandleTy = TheCall->getArg(0)->getType();
3363 auto *MainResType = MainHandleTy->getAs<HLSLAttributedResourceType>();
3364 auto MainAttrs = MainResType->getAttrs();
3365 assert(!MainAttrs.IsCounter && "cannot create a counter from a counter");
3366 MainAttrs.IsCounter = true;
3367 QualType CounterHandleTy = AST.getHLSLAttributedResourceType(
3368 MainResType->getWrappedType(), MainResType->getContainedType(),
3369 MainAttrs);
3370 // Update return type to be the attributed resource type from arg0
3371 // with added IsCounter flag.
3372 TheCall->setType(CounterHandleTy);
3373 break;
3374 }
3375 case Builtin::BI__builtin_hlsl_and:
3376 case Builtin::BI__builtin_hlsl_or: {
3377 if (SemaRef.checkArgCount(TheCall, 2))
3378 return true;
3379 if (CheckScalarOrVectorOrMatrix(&SemaRef, TheCall, getASTContext().BoolTy,
3380 0))
3381 return true;
3382 if (CheckAllArgsHaveSameType(&SemaRef, TheCall))
3383 return true;
3384
3385 ExprResult A = TheCall->getArg(0);
3386 QualType ArgTyA = A.get()->getType();
3387 // return type is the same as the input type
3388 TheCall->setType(ArgTyA);
3389 break;
3390 }
3391 case Builtin::BI__builtin_hlsl_all:
3392 case Builtin::BI__builtin_hlsl_any: {
3393 if (SemaRef.checkArgCount(TheCall, 1))
3394 return true;
3395 if (CheckAnyScalarOrVector(&SemaRef, TheCall, 0))
3396 return true;
3397 break;
3398 }
3399 case Builtin::BI__builtin_hlsl_asdouble: {
3400 if (SemaRef.checkArgCount(TheCall, 2))
3401 return true;
3403 &SemaRef, TheCall,
3404 /*only check for uint*/ SemaRef.Context.UnsignedIntTy,
3405 /* arg index */ 0))
3406 return true;
3408 &SemaRef, TheCall,
3409 /*only check for uint*/ SemaRef.Context.UnsignedIntTy,
3410 /* arg index */ 1))
3411 return true;
3412 if (CheckAllArgsHaveSameType(&SemaRef, TheCall))
3413 return true;
3414
3415 SetElementTypeAsReturnType(&SemaRef, TheCall, getASTContext().DoubleTy);
3416 break;
3417 }
3418 case Builtin::BI__builtin_hlsl_elementwise_clamp: {
3419 if (SemaRef.BuiltinElementwiseTernaryMath(
3420 TheCall, /*ArgTyRestr=*/
3422 return true;
3423 break;
3424 }
3425 case Builtin::BI__builtin_hlsl_dot: {
3426 // arg count is checked by BuiltinVectorToScalarMath
3427 if (SemaRef.BuiltinVectorToScalarMath(TheCall))
3428 return true;
3430 return true;
3431 break;
3432 }
3433 case Builtin::BI__builtin_hlsl_elementwise_firstbithigh:
3434 case Builtin::BI__builtin_hlsl_elementwise_firstbitlow: {
3435 if (SemaRef.PrepareBuiltinElementwiseMathOneArgCall(TheCall))
3436 return true;
3437
3438 const Expr *Arg = TheCall->getArg(0);
3439 QualType ArgTy = Arg->getType();
3440 QualType EltTy = ArgTy;
3441
3442 QualType ResTy = SemaRef.Context.UnsignedIntTy;
3443
3444 if (auto *VecTy = EltTy->getAs<VectorType>()) {
3445 EltTy = VecTy->getElementType();
3446 ResTy = SemaRef.Context.getExtVectorType(ResTy, VecTy->getNumElements());
3447 }
3448
3449 if (!EltTy->isIntegerType()) {
3450 Diag(Arg->getBeginLoc(), diag::err_builtin_invalid_arg_type)
3451 << 1 << /* scalar or vector of */ 5 << /* integer ty */ 1
3452 << /* no fp */ 0 << ArgTy;
3453 return true;
3454 }
3455
3456 TheCall->setType(ResTy);
3457 break;
3458 }
3459 case Builtin::BI__builtin_hlsl_select: {
3460 if (SemaRef.checkArgCount(TheCall, 3))
3461 return true;
3462 if (CheckScalarOrVector(&SemaRef, TheCall, getASTContext().BoolTy, 0))
3463 return true;
3464 QualType ArgTy = TheCall->getArg(0)->getType();
3465 if (ArgTy->isBooleanType() && CheckBoolSelect(&SemaRef, TheCall))
3466 return true;
3467 auto *VTy = ArgTy->getAs<VectorType>();
3468 if (VTy && VTy->getElementType()->isBooleanType() &&
3469 CheckVectorSelect(&SemaRef, TheCall))
3470 return true;
3471 break;
3472 }
3473 case Builtin::BI__builtin_hlsl_elementwise_saturate:
3474 case Builtin::BI__builtin_hlsl_elementwise_rcp: {
3475 if (SemaRef.checkArgCount(TheCall, 1))
3476 return true;
3477 if (!TheCall->getArg(0)
3478 ->getType()
3479 ->hasFloatingRepresentation()) // half or float or double
3480 return SemaRef.Diag(TheCall->getArg(0)->getBeginLoc(),
3481 diag::err_builtin_invalid_arg_type)
3482 << /* ordinal */ 1 << /* scalar or vector */ 5 << /* no int */ 0
3483 << /* fp */ 1 << TheCall->getArg(0)->getType();
3484 if (SemaRef.PrepareBuiltinElementwiseMathOneArgCall(TheCall))
3485 return true;
3486 break;
3487 }
3488 case Builtin::BI__builtin_hlsl_elementwise_degrees:
3489 case Builtin::BI__builtin_hlsl_elementwise_radians:
3490 case Builtin::BI__builtin_hlsl_elementwise_rsqrt:
3491 case Builtin::BI__builtin_hlsl_elementwise_frac:
3492 case Builtin::BI__builtin_hlsl_elementwise_ddx_coarse:
3493 case Builtin::BI__builtin_hlsl_elementwise_ddy_coarse:
3494 case Builtin::BI__builtin_hlsl_elementwise_ddx_fine:
3495 case Builtin::BI__builtin_hlsl_elementwise_ddy_fine: {
3496 if (SemaRef.checkArgCount(TheCall, 1))
3497 return true;
3498 if (CheckAllArgTypesAreCorrect(&SemaRef, TheCall,
3500 return true;
3501 if (SemaRef.PrepareBuiltinElementwiseMathOneArgCall(TheCall))
3502 return true;
3503 break;
3504 }
3505 case Builtin::BI__builtin_hlsl_elementwise_isinf:
3506 case Builtin::BI__builtin_hlsl_elementwise_isnan: {
3507 if (SemaRef.checkArgCount(TheCall, 1))
3508 return true;
3509 if (CheckAllArgTypesAreCorrect(&SemaRef, TheCall,
3511 return true;
3512 if (SemaRef.PrepareBuiltinElementwiseMathOneArgCall(TheCall))
3513 return true;
3515 break;
3516 }
3517 case Builtin::BI__builtin_hlsl_lerp: {
3518 if (SemaRef.checkArgCount(TheCall, 3))
3519 return true;
3520 if (CheckAllArgTypesAreCorrect(&SemaRef, TheCall,
3522 return true;
3523 if (CheckAllArgsHaveSameType(&SemaRef, TheCall))
3524 return true;
3525 if (SemaRef.BuiltinElementwiseTernaryMath(TheCall))
3526 return true;
3527 break;
3528 }
3529 case Builtin::BI__builtin_hlsl_mad: {
3530 if (SemaRef.BuiltinElementwiseTernaryMath(
3531 TheCall, /*ArgTyRestr=*/
3533 return true;
3534 break;
3535 }
3536 case Builtin::BI__builtin_hlsl_normalize: {
3537 if (SemaRef.checkArgCount(TheCall, 1))
3538 return true;
3539 if (CheckAllArgTypesAreCorrect(&SemaRef, TheCall,
3541 return true;
3542 ExprResult A = TheCall->getArg(0);
3543 QualType ArgTyA = A.get()->getType();
3544 // return type is the same as the input type
3545 TheCall->setType(ArgTyA);
3546 break;
3547 }
3548 case Builtin::BI__builtin_hlsl_elementwise_sign: {
3549 if (SemaRef.PrepareBuiltinElementwiseMathOneArgCall(TheCall))
3550 return true;
3551 if (CheckAllArgTypesAreCorrect(&SemaRef, TheCall,
3553 return true;
3555 break;
3556 }
3557 case Builtin::BI__builtin_hlsl_step: {
3558 if (SemaRef.checkArgCount(TheCall, 2))
3559 return true;
3560 if (CheckAllArgTypesAreCorrect(&SemaRef, TheCall,
3562 return true;
3563
3564 ExprResult A = TheCall->getArg(0);
3565 QualType ArgTyA = A.get()->getType();
3566 // return type is the same as the input type
3567 TheCall->setType(ArgTyA);
3568 break;
3569 }
3570 case Builtin::BI__builtin_hlsl_wave_active_max:
3571 case Builtin::BI__builtin_hlsl_wave_active_min:
3572 case Builtin::BI__builtin_hlsl_wave_active_sum: {
3573 if (SemaRef.checkArgCount(TheCall, 1))
3574 return true;
3575
3576 // Ensure input expr type is a scalar/vector and the same as the return type
3577 if (CheckAnyScalarOrVector(&SemaRef, TheCall, 0))
3578 return true;
3579 if (CheckWaveActive(&SemaRef, TheCall))
3580 return true;
3581 ExprResult Expr = TheCall->getArg(0);
3582 QualType ArgTyExpr = Expr.get()->getType();
3583 TheCall->setType(ArgTyExpr);
3584 break;
3585 }
3586 // Note these are llvm builtins that we want to catch invalid intrinsic
3587 // generation. Normal handling of these builtins will occur elsewhere.
3588 case Builtin::BI__builtin_elementwise_bitreverse: {
3589 // does not include a check for number of arguments
3590 // because that is done previously
3591 if (CheckAllArgTypesAreCorrect(&SemaRef, TheCall,
3593 return true;
3594 break;
3595 }
3596 case Builtin::BI__builtin_hlsl_wave_read_lane_at: {
3597 if (SemaRef.checkArgCount(TheCall, 2))
3598 return true;
3599
3600 // Ensure index parameter type can be interpreted as a uint
3601 ExprResult Index = TheCall->getArg(1);
3602 QualType ArgTyIndex = Index.get()->getType();
3603 if (!ArgTyIndex->isIntegerType()) {
3604 SemaRef.Diag(TheCall->getArg(1)->getBeginLoc(),
3605 diag::err_typecheck_convert_incompatible)
3606 << ArgTyIndex << SemaRef.Context.UnsignedIntTy << 1 << 0 << 0;
3607 return true;
3608 }
3609
3610 // Ensure input expr type is a scalar/vector and the same as the return type
3611 if (CheckAnyScalarOrVector(&SemaRef, TheCall, 0))
3612 return true;
3613
3614 ExprResult Expr = TheCall->getArg(0);
3615 QualType ArgTyExpr = Expr.get()->getType();
3616 TheCall->setType(ArgTyExpr);
3617 break;
3618 }
3619 case Builtin::BI__builtin_hlsl_wave_get_lane_index: {
3620 if (SemaRef.checkArgCount(TheCall, 0))
3621 return true;
3622 break;
3623 }
3624 case Builtin::BI__builtin_hlsl_elementwise_splitdouble: {
3625 if (SemaRef.checkArgCount(TheCall, 3))
3626 return true;
3627
3628 if (CheckScalarOrVector(&SemaRef, TheCall, SemaRef.Context.DoubleTy, 0) ||
3629 CheckScalarOrVector(&SemaRef, TheCall, SemaRef.Context.UnsignedIntTy,
3630 1) ||
3631 CheckScalarOrVector(&SemaRef, TheCall, SemaRef.Context.UnsignedIntTy,
3632 2))
3633 return true;
3634
3635 if (CheckModifiableLValue(&SemaRef, TheCall, 1) ||
3636 CheckModifiableLValue(&SemaRef, TheCall, 2))
3637 return true;
3638 break;
3639 }
3640 case Builtin::BI__builtin_hlsl_elementwise_clip: {
3641 if (SemaRef.checkArgCount(TheCall, 1))
3642 return true;
3643
3644 if (CheckScalarOrVector(&SemaRef, TheCall, SemaRef.Context.FloatTy, 0))
3645 return true;
3646 break;
3647 }
3648 case Builtin::BI__builtin_elementwise_acos:
3649 case Builtin::BI__builtin_elementwise_asin:
3650 case Builtin::BI__builtin_elementwise_atan:
3651 case Builtin::BI__builtin_elementwise_atan2:
3652 case Builtin::BI__builtin_elementwise_ceil:
3653 case Builtin::BI__builtin_elementwise_cos:
3654 case Builtin::BI__builtin_elementwise_cosh:
3655 case Builtin::BI__builtin_elementwise_exp:
3656 case Builtin::BI__builtin_elementwise_exp2:
3657 case Builtin::BI__builtin_elementwise_exp10:
3658 case Builtin::BI__builtin_elementwise_floor:
3659 case Builtin::BI__builtin_elementwise_fmod:
3660 case Builtin::BI__builtin_elementwise_log:
3661 case Builtin::BI__builtin_elementwise_log2:
3662 case Builtin::BI__builtin_elementwise_log10:
3663 case Builtin::BI__builtin_elementwise_pow:
3664 case Builtin::BI__builtin_elementwise_roundeven:
3665 case Builtin::BI__builtin_elementwise_sin:
3666 case Builtin::BI__builtin_elementwise_sinh:
3667 case Builtin::BI__builtin_elementwise_sqrt:
3668 case Builtin::BI__builtin_elementwise_tan:
3669 case Builtin::BI__builtin_elementwise_tanh:
3670 case Builtin::BI__builtin_elementwise_trunc: {
3671 if (CheckAllArgTypesAreCorrect(&SemaRef, TheCall,
3673 return true;
3674 break;
3675 }
3676 case Builtin::BI__builtin_hlsl_buffer_update_counter: {
3677 assert(TheCall->getNumArgs() == 2 && "expected 2 args");
3678 auto checkResTy = [](const HLSLAttributedResourceType *ResTy) -> bool {
3679 return !(ResTy->getAttrs().ResourceClass == ResourceClass::UAV &&
3680 ResTy->getAttrs().RawBuffer && ResTy->hasContainedType());
3681 };
3682 if (CheckResourceHandle(&SemaRef, TheCall, 0, checkResTy))
3683 return true;
3684 Expr *OffsetExpr = TheCall->getArg(1);
3685 std::optional<llvm::APSInt> Offset =
3686 OffsetExpr->getIntegerConstantExpr(SemaRef.getASTContext());
3687 if (!Offset.has_value() || std::abs(Offset->getExtValue()) != 1) {
3688 SemaRef.Diag(TheCall->getArg(1)->getBeginLoc(),
3689 diag::err_hlsl_expect_arg_const_int_one_or_neg_one)
3690 << 1;
3691 return true;
3692 }
3693 break;
3694 }
3695 case Builtin::BI__builtin_hlsl_elementwise_f16tof32: {
3696 if (SemaRef.checkArgCount(TheCall, 1))
3697 return true;
3698 if (CheckAllArgTypesAreCorrect(&SemaRef, TheCall,
3700 return true;
3701 // ensure arg integers are 32 bits
3702 if (CheckExpectedBitWidth(&SemaRef, TheCall, 0, 32))
3703 return true;
3704 // check it wasn't a bool type
3705 QualType ArgTy = TheCall->getArg(0)->getType();
3706 if (auto *VTy = ArgTy->getAs<VectorType>())
3707 ArgTy = VTy->getElementType();
3708 if (ArgTy->isBooleanType()) {
3709 SemaRef.Diag(TheCall->getArg(0)->getBeginLoc(),
3710 diag::err_builtin_invalid_arg_type)
3711 << 1 << /* scalar or vector of */ 5 << /* unsigned int */ 3
3712 << /* no fp */ 0 << TheCall->getArg(0)->getType();
3713 return true;
3714 }
3715
3716 SetElementTypeAsReturnType(&SemaRef, TheCall, getASTContext().FloatTy);
3717 break;
3718 }
3719 }
3720 return false;
3721}
3722
3726 WorkList.push_back(BaseTy);
3727 while (!WorkList.empty()) {
3728 QualType T = WorkList.pop_back_val();
3729 T = T.getCanonicalType().getUnqualifiedType();
3730 if (const auto *AT = dyn_cast<ConstantArrayType>(T)) {
3731 llvm::SmallVector<QualType, 16> ElementFields;
3732 // Generally I've avoided recursion in this algorithm, but arrays of
3733 // structs could be time-consuming to flatten and churn through on the
3734 // work list. Hopefully nesting arrays of structs containing arrays
3735 // of structs too many levels deep is unlikely.
3736 BuildFlattenedTypeList(AT->getElementType(), ElementFields);
3737 // Repeat the element's field list n times.
3738 for (uint64_t Ct = 0; Ct < AT->getZExtSize(); ++Ct)
3739 llvm::append_range(List, ElementFields);
3740 continue;
3741 }
3742 // Vectors can only have element types that are builtin types, so this can
3743 // add directly to the list instead of to the WorkList.
3744 if (const auto *VT = dyn_cast<VectorType>(T)) {
3745 List.insert(List.end(), VT->getNumElements(), VT->getElementType());
3746 continue;
3747 }
3748 if (const auto *MT = dyn_cast<ConstantMatrixType>(T)) {
3749 List.insert(List.end(), MT->getNumElementsFlattened(),
3750 MT->getElementType());
3751 continue;
3752 }
3753 if (const auto *RD = T->getAsCXXRecordDecl()) {
3754 if (RD->isStandardLayout())
3755 RD = RD->getStandardLayoutBaseWithFields();
3756
3757 // For types that we shouldn't decompose (unions and non-aggregates), just
3758 // add the type itself to the list.
3759 if (RD->isUnion() || !RD->isAggregate()) {
3760 List.push_back(T);
3761 continue;
3762 }
3763
3765 for (const auto *FD : RD->fields())
3766 if (!FD->isUnnamedBitField())
3767 FieldTypes.push_back(FD->getType());
3768 // Reverse the newly added sub-range.
3769 std::reverse(FieldTypes.begin(), FieldTypes.end());
3770 llvm::append_range(WorkList, FieldTypes);
3771
3772 // If this wasn't a standard layout type we may also have some base
3773 // classes to deal with.
3774 if (!RD->isStandardLayout()) {
3775 FieldTypes.clear();
3776 for (const auto &Base : RD->bases())
3777 FieldTypes.push_back(Base.getType());
3778 std::reverse(FieldTypes.begin(), FieldTypes.end());
3779 llvm::append_range(WorkList, FieldTypes);
3780 }
3781 continue;
3782 }
3783 List.push_back(T);
3784 }
3785}
3786
3788 // null and array types are not allowed.
3789 if (QT.isNull() || QT->isArrayType())
3790 return false;
3791
3792 // UDT types are not allowed
3793 if (QT->isRecordType())
3794 return false;
3795
3796 if (QT->isBooleanType() || QT->isEnumeralType())
3797 return false;
3798
3799 // the only other valid builtin types are scalars or vectors
3800 if (QT->isArithmeticType()) {
3801 if (SemaRef.Context.getTypeSize(QT) / 8 > 16)
3802 return false;
3803 return true;
3804 }
3805
3806 if (const VectorType *VT = QT->getAs<VectorType>()) {
3807 int ArraySize = VT->getNumElements();
3808
3809 if (ArraySize > 4)
3810 return false;
3811
3812 QualType ElTy = VT->getElementType();
3813 if (ElTy->isBooleanType())
3814 return false;
3815
3816 if (SemaRef.Context.getTypeSize(QT) / 8 > 16)
3817 return false;
3818 return true;
3819 }
3820
3821 return false;
3822}
3823
3825 if (T1.isNull() || T2.isNull())
3826 return false;
3827
3830
3831 // If both types are the same canonical type, they're obviously compatible.
3832 if (SemaRef.getASTContext().hasSameType(T1, T2))
3833 return true;
3834
3836 BuildFlattenedTypeList(T1, T1Types);
3838 BuildFlattenedTypeList(T2, T2Types);
3839
3840 // Check the flattened type list
3841 return llvm::equal(T1Types, T2Types,
3842 [this](QualType LHS, QualType RHS) -> bool {
3843 return SemaRef.IsLayoutCompatible(LHS, RHS);
3844 });
3845}
3846
3848 FunctionDecl *Old) {
3849 if (New->getNumParams() != Old->getNumParams())
3850 return true;
3851
3852 bool HadError = false;
3853
3854 for (unsigned i = 0, e = New->getNumParams(); i != e; ++i) {
3855 ParmVarDecl *NewParam = New->getParamDecl(i);
3856 ParmVarDecl *OldParam = Old->getParamDecl(i);
3857
3858 // HLSL parameter declarations for inout and out must match between
3859 // declarations. In HLSL inout and out are ambiguous at the call site,
3860 // but have different calling behavior, so you cannot overload a
3861 // method based on a difference between inout and out annotations.
3862 const auto *NDAttr = NewParam->getAttr<HLSLParamModifierAttr>();
3863 unsigned NSpellingIdx = (NDAttr ? NDAttr->getSpellingListIndex() : 0);
3864 const auto *ODAttr = OldParam->getAttr<HLSLParamModifierAttr>();
3865 unsigned OSpellingIdx = (ODAttr ? ODAttr->getSpellingListIndex() : 0);
3866
3867 if (NSpellingIdx != OSpellingIdx) {
3868 SemaRef.Diag(NewParam->getLocation(),
3869 diag::err_hlsl_param_qualifier_mismatch)
3870 << NDAttr << NewParam;
3871 SemaRef.Diag(OldParam->getLocation(), diag::note_previous_declaration_as)
3872 << ODAttr;
3873 HadError = true;
3874 }
3875 }
3876 return HadError;
3877}
3878
3879// Generally follows PerformScalarCast, with cases reordered for
3880// clarity of what types are supported
3882
3883 if (!SrcTy->isScalarType() || !DestTy->isScalarType())
3884 return false;
3885
3886 if (SemaRef.getASTContext().hasSameUnqualifiedType(SrcTy, DestTy))
3887 return true;
3888
3889 switch (SrcTy->getScalarTypeKind()) {
3890 case Type::STK_Bool: // casting from bool is like casting from an integer
3891 case Type::STK_Integral:
3892 switch (DestTy->getScalarTypeKind()) {
3893 case Type::STK_Bool:
3894 case Type::STK_Integral:
3895 case Type::STK_Floating:
3896 return true;
3897 case Type::STK_CPointer:
3901 llvm_unreachable("HLSL doesn't support pointers.");
3904 llvm_unreachable("HLSL doesn't support complex types.");
3906 llvm_unreachable("HLSL doesn't support fixed point types.");
3907 }
3908 llvm_unreachable("Should have returned before this");
3909
3910 case Type::STK_Floating:
3911 switch (DestTy->getScalarTypeKind()) {
3912 case Type::STK_Floating:
3913 case Type::STK_Bool:
3914 case Type::STK_Integral:
3915 return true;
3918 llvm_unreachable("HLSL doesn't support complex types.");
3920 llvm_unreachable("HLSL doesn't support fixed point types.");
3921 case Type::STK_CPointer:
3925 llvm_unreachable("HLSL doesn't support pointers.");
3926 }
3927 llvm_unreachable("Should have returned before this");
3928
3930 case Type::STK_CPointer:
3933 llvm_unreachable("HLSL doesn't support pointers.");
3934
3936 llvm_unreachable("HLSL doesn't support fixed point types.");
3937
3940 llvm_unreachable("HLSL doesn't support complex types.");
3941 }
3942
3943 llvm_unreachable("Unhandled scalar cast");
3944}
3945
3946// Can perform an HLSL Aggregate splat cast if the Dest is an aggregate and the
3947// Src is a scalar or a vector of length 1
3948// Or if Dest is a vector and Src is a vector of length 1
3950
3951 QualType SrcTy = Src->getType();
3952 // Not a valid HLSL Aggregate Splat cast if Dest is a scalar or if this is
3953 // going to be a vector splat from a scalar.
3954 if ((SrcTy->isScalarType() && DestTy->isVectorType()) ||
3955 DestTy->isScalarType())
3956 return false;
3957
3958 const VectorType *SrcVecTy = SrcTy->getAs<VectorType>();
3959
3960 // Src isn't a scalar or a vector of length 1
3961 if (!SrcTy->isScalarType() && !(SrcVecTy && SrcVecTy->getNumElements() == 1))
3962 return false;
3963
3964 if (SrcVecTy)
3965 SrcTy = SrcVecTy->getElementType();
3966
3968 BuildFlattenedTypeList(DestTy, DestTypes);
3969
3970 for (unsigned I = 0, Size = DestTypes.size(); I < Size; ++I) {
3971 if (DestTypes[I]->isUnionType())
3972 return false;
3973 if (!CanPerformScalarCast(SrcTy, DestTypes[I]))
3974 return false;
3975 }
3976 return true;
3977}
3978
3979// Can we perform an HLSL Elementwise cast?
3981
3982 // Don't handle casts where LHS and RHS are any combination of scalar/vector
3983 // There must be an aggregate somewhere
3984 QualType SrcTy = Src->getType();
3985 if (SrcTy->isScalarType()) // always a splat and this cast doesn't handle that
3986 return false;
3987
3988 if (SrcTy->isVectorType() &&
3989 (DestTy->isScalarType() || DestTy->isVectorType()))
3990 return false;
3991
3992 if (SrcTy->isConstantMatrixType() &&
3993 (DestTy->isScalarType() || DestTy->isConstantMatrixType()))
3994 return false;
3995
3997 BuildFlattenedTypeList(DestTy, DestTypes);
3999 BuildFlattenedTypeList(SrcTy, SrcTypes);
4000
4001 // Usually the size of SrcTypes must be greater than or equal to the size of
4002 // DestTypes.
4003 if (SrcTypes.size() < DestTypes.size())
4004 return false;
4005
4006 unsigned SrcSize = SrcTypes.size();
4007 unsigned DstSize = DestTypes.size();
4008 unsigned I;
4009 for (I = 0; I < DstSize && I < SrcSize; I++) {
4010 if (SrcTypes[I]->isUnionType() || DestTypes[I]->isUnionType())
4011 return false;
4012 if (!CanPerformScalarCast(SrcTypes[I], DestTypes[I])) {
4013 return false;
4014 }
4015 }
4016
4017 // check the rest of the source type for unions.
4018 for (; I < SrcSize; I++) {
4019 if (SrcTypes[I]->isUnionType())
4020 return false;
4021 }
4022 return true;
4023}
4024
4026 assert(Param->hasAttr<HLSLParamModifierAttr>() &&
4027 "We should not get here without a parameter modifier expression");
4028 const auto *Attr = Param->getAttr<HLSLParamModifierAttr>();
4029 if (Attr->getABI() == ParameterABI::Ordinary)
4030 return ExprResult(Arg);
4031
4032 bool IsInOut = Attr->getABI() == ParameterABI::HLSLInOut;
4033 if (!Arg->isLValue()) {
4034 SemaRef.Diag(Arg->getBeginLoc(), diag::error_hlsl_inout_lvalue)
4035 << Arg << (IsInOut ? 1 : 0);
4036 return ExprError();
4037 }
4038
4039 ASTContext &Ctx = SemaRef.getASTContext();
4040
4041 QualType Ty = Param->getType().getNonLValueExprType(Ctx);
4042
4043 // HLSL allows implicit conversions from scalars to vectors, but not the
4044 // inverse, so we need to disallow `inout` with scalar->vector or
4045 // scalar->matrix conversions.
4046 if (Arg->getType()->isScalarType() != Ty->isScalarType()) {
4047 SemaRef.Diag(Arg->getBeginLoc(), diag::error_hlsl_inout_scalar_extension)
4048 << Arg << (IsInOut ? 1 : 0);
4049 return ExprError();
4050 }
4051
4052 auto *ArgOpV = new (Ctx) OpaqueValueExpr(Param->getBeginLoc(), Arg->getType(),
4053 VK_LValue, OK_Ordinary, Arg);
4054
4055 // Parameters are initialized via copy initialization. This allows for
4056 // overload resolution of argument constructors.
4057 InitializedEntity Entity =
4059 ExprResult Res =
4060 SemaRef.PerformCopyInitialization(Entity, Param->getBeginLoc(), ArgOpV);
4061 if (Res.isInvalid())
4062 return ExprError();
4063 Expr *Base = Res.get();
4064 // After the cast, drop the reference type when creating the exprs.
4065 Ty = Ty.getNonLValueExprType(Ctx);
4066 auto *OpV = new (Ctx)
4067 OpaqueValueExpr(Param->getBeginLoc(), Ty, VK_LValue, OK_Ordinary, Base);
4068
4069 // Writebacks are performed with `=` binary operator, which allows for
4070 // overload resolution on writeback result expressions.
4071 Res = SemaRef.ActOnBinOp(SemaRef.getCurScope(), Param->getBeginLoc(),
4072 tok::equal, ArgOpV, OpV);
4073
4074 if (Res.isInvalid())
4075 return ExprError();
4076 Expr *Writeback = Res.get();
4077 auto *OutExpr =
4078 HLSLOutArgExpr::Create(Ctx, Ty, ArgOpV, OpV, Writeback, IsInOut);
4079
4080 return ExprResult(OutExpr);
4081}
4082
4084 // If HLSL gains support for references, all the sites that use this will need
4085 // to be updated with semantic checking to produce errors for
4086 // pointers/references.
4087 assert(!Ty->isReferenceType() &&
4088 "Pointer and reference types cannot be inout or out parameters");
4089 Ty = SemaRef.getASTContext().getLValueReferenceType(Ty);
4090 Ty.addRestrict();
4091 return Ty;
4092}
4093
4094static bool IsDefaultBufferConstantDecl(const ASTContext &Ctx, VarDecl *VD) {
4095 bool IsVulkan =
4096 Ctx.getTargetInfo().getTriple().getOS() == llvm::Triple::Vulkan;
4097 bool IsVKPushConstant = IsVulkan && VD->hasAttr<HLSLVkPushConstantAttr>();
4098 QualType QT = VD->getType();
4099 return VD->getDeclContext()->isTranslationUnit() &&
4100 QT.getAddressSpace() == LangAS::Default &&
4101 VD->getStorageClass() != SC_Static &&
4102 !VD->hasAttr<HLSLVkConstantIdAttr>() && !IsVKPushConstant &&
4104}
4105
4107 // The variable already has an address space (groupshared for ex).
4108 if (Decl->getType().hasAddressSpace())
4109 return;
4110
4111 if (Decl->getType()->isDependentType())
4112 return;
4113
4114 QualType Type = Decl->getType();
4115
4116 if (Decl->hasAttr<HLSLVkExtBuiltinInputAttr>()) {
4117 LangAS ImplAS = LangAS::hlsl_input;
4118 Type = SemaRef.getASTContext().getAddrSpaceQualType(Type, ImplAS);
4119 Decl->setType(Type);
4120 return;
4121 }
4122
4123 bool IsVulkan = getASTContext().getTargetInfo().getTriple().getOS() ==
4124 llvm::Triple::Vulkan;
4125 if (IsVulkan && Decl->hasAttr<HLSLVkPushConstantAttr>()) {
4126 if (HasDeclaredAPushConstant)
4127 SemaRef.Diag(Decl->getLocation(), diag::err_hlsl_push_constant_unique);
4128
4130 Type = SemaRef.getASTContext().getAddrSpaceQualType(Type, ImplAS);
4131 Decl->setType(Type);
4132 HasDeclaredAPushConstant = true;
4133 return;
4134 }
4135
4136 if (Type->isSamplerT() || Type->isVoidType())
4137 return;
4138
4139 // Resource handles.
4141 return;
4142
4143 // Only static globals belong to the Private address space.
4144 // Non-static globals belongs to the cbuffer.
4145 if (Decl->getStorageClass() != SC_Static && !Decl->isStaticDataMember())
4146 return;
4147
4149 Type = SemaRef.getASTContext().getAddrSpaceQualType(Type, ImplAS);
4150 Decl->setType(Type);
4151}
4152
4154 if (VD->hasGlobalStorage()) {
4155 // make sure the declaration has a complete type
4156 if (SemaRef.RequireCompleteType(
4157 VD->getLocation(),
4158 SemaRef.getASTContext().getBaseElementType(VD->getType()),
4159 diag::err_typecheck_decl_incomplete_type)) {
4160 VD->setInvalidDecl();
4162 return;
4163 }
4164
4165 // Global variables outside a cbuffer block that are not a resource, static,
4166 // groupshared, or an empty array or struct belong to the default constant
4167 // buffer $Globals (to be created at the end of the translation unit).
4169 // update address space to hlsl_constant
4172 VD->setType(NewTy);
4173 DefaultCBufferDecls.push_back(VD);
4174 }
4175
4176 // find all resources bindings on decl
4177 if (VD->getType()->isHLSLIntangibleType())
4178 collectResourceBindingsOnVarDecl(VD);
4179
4180 if (VD->hasAttr<HLSLVkConstantIdAttr>())
4182
4184 VD->getStorageClass() != SC_Static) {
4185 // Add internal linkage attribute to non-static resource variables. The
4186 // global externally visible storage is accessed through the handle, which
4187 // is a member. The variable itself is not externally visible.
4188 VD->addAttr(InternalLinkageAttr::CreateImplicit(getASTContext()));
4189 }
4190
4191 // process explicit bindings
4192 processExplicitBindingsOnDecl(VD);
4193
4194 // Add implicit binding attribute to non-static resource arrays.
4195 if (VD->getType()->isHLSLResourceRecordArray() &&
4196 VD->getStorageClass() != SC_Static) {
4197 // If the resource array does not have an explicit binding attribute,
4198 // create an implicit one. It will be used to transfer implicit binding
4199 // order_ID to codegen.
4200 ResourceBindingAttrs Binding(VD);
4201 if (!Binding.isExplicit()) {
4202 uint32_t OrderID = getNextImplicitBindingOrderID();
4203 if (Binding.hasBinding())
4204 Binding.setImplicitOrderID(OrderID);
4205 else {
4208 OrderID);
4209 // Re-create the binding object to pick up the new attribute.
4210 Binding = ResourceBindingAttrs(VD);
4211 }
4212 }
4213
4214 // Get to the base type of a potentially multi-dimensional array.
4216
4217 const CXXRecordDecl *RD = Ty->getAsCXXRecordDecl();
4218 if (hasCounterHandle(RD)) {
4219 if (!Binding.hasCounterImplicitOrderID()) {
4220 uint32_t OrderID = getNextImplicitBindingOrderID();
4221 Binding.setCounterImplicitOrderID(OrderID);
4222 }
4223 }
4224 }
4225 }
4226
4228}
4229
4230bool SemaHLSL::initGlobalResourceDecl(VarDecl *VD) {
4231 assert(VD->getType()->isHLSLResourceRecord() &&
4232 "expected resource record type");
4233
4235 uint64_t UIntTySize = AST.getTypeSize(AST.UnsignedIntTy);
4236 uint64_t IntTySize = AST.getTypeSize(AST.IntTy);
4237
4238 // Gather resource binding attributes.
4239 ResourceBindingAttrs Binding(VD);
4240
4241 // Find correct initialization method and create its arguments.
4242 QualType ResourceTy = VD->getType();
4243 CXXRecordDecl *ResourceDecl = ResourceTy->getAsCXXRecordDecl();
4244 CXXMethodDecl *CreateMethod = nullptr;
4246
4247 bool HasCounter = hasCounterHandle(ResourceDecl);
4248 const char *CreateMethodName;
4249 if (Binding.isExplicit())
4250 CreateMethodName = HasCounter ? "__createFromBindingWithImplicitCounter"
4251 : "__createFromBinding";
4252 else
4253 CreateMethodName = HasCounter
4254 ? "__createFromImplicitBindingWithImplicitCounter"
4255 : "__createFromImplicitBinding";
4256
4257 CreateMethod =
4258 lookupMethod(SemaRef, ResourceDecl, CreateMethodName, VD->getLocation());
4259
4260 if (!CreateMethod)
4261 // This can happen if someone creates a struct that looks like an HLSL
4262 // resource record but does not have the required static create method.
4263 // No binding will be generated for it.
4264 return false;
4265
4266 if (Binding.isExplicit()) {
4267 IntegerLiteral *RegSlot =
4268 IntegerLiteral::Create(AST, llvm::APInt(UIntTySize, Binding.getSlot()),
4270 Args.push_back(RegSlot);
4271 } else {
4272 uint32_t OrderID = (Binding.hasImplicitOrderID())
4273 ? Binding.getImplicitOrderID()
4274 : getNextImplicitBindingOrderID();
4275 IntegerLiteral *OrderId =
4276 IntegerLiteral::Create(AST, llvm::APInt(UIntTySize, OrderID),
4278 Args.push_back(OrderId);
4279 }
4280
4281 IntegerLiteral *Space =
4282 IntegerLiteral::Create(AST, llvm::APInt(UIntTySize, Binding.getSpace()),
4283 AST.UnsignedIntTy, SourceLocation());
4284 Args.push_back(Space);
4285
4286 IntegerLiteral *RangeSize = IntegerLiteral::Create(
4287 AST, llvm::APInt(IntTySize, 1), AST.IntTy, SourceLocation());
4288 Args.push_back(RangeSize);
4289
4290 IntegerLiteral *Index = IntegerLiteral::Create(
4291 AST, llvm::APInt(UIntTySize, 0), AST.UnsignedIntTy, SourceLocation());
4292 Args.push_back(Index);
4293
4294 StringRef VarName = VD->getName();
4295 StringLiteral *Name = StringLiteral::Create(
4296 AST, VarName, StringLiteralKind::Ordinary, false,
4297 AST.getStringLiteralArrayType(AST.CharTy.withConst(), VarName.size()),
4298 SourceLocation());
4299 ImplicitCastExpr *NameCast = ImplicitCastExpr::Create(
4300 AST, AST.getPointerType(AST.CharTy.withConst()), CK_ArrayToPointerDecay,
4301 Name, nullptr, VK_PRValue, FPOptionsOverride());
4302 Args.push_back(NameCast);
4303
4304 if (HasCounter) {
4305 // Will this be in the correct order?
4306 uint32_t CounterOrderID = getNextImplicitBindingOrderID();
4307 IntegerLiteral *CounterId =
4308 IntegerLiteral::Create(AST, llvm::APInt(UIntTySize, CounterOrderID),
4309 AST.UnsignedIntTy, SourceLocation());
4310 Args.push_back(CounterId);
4311 }
4312
4313 // Make sure the create method template is instantiated and emitted.
4314 if (!CreateMethod->isDefined() && CreateMethod->isTemplateInstantiation())
4315 SemaRef.InstantiateFunctionDefinition(VD->getLocation(), CreateMethod,
4316 true);
4317
4318 // Create CallExpr with a call to the static method and set it as the decl
4319 // initialization.
4320 DeclRefExpr *DRE = DeclRefExpr::Create(
4321 AST, NestedNameSpecifierLoc(), SourceLocation(), CreateMethod, false,
4322 CreateMethod->getNameInfo(), CreateMethod->getType(), VK_PRValue);
4323
4324 auto *ImpCast = ImplicitCastExpr::Create(
4325 AST, AST.getPointerType(CreateMethod->getType()),
4326 CK_FunctionToPointerDecay, DRE, nullptr, VK_PRValue, FPOptionsOverride());
4327
4328 CallExpr *InitExpr =
4329 CallExpr::Create(AST, ImpCast, Args, ResourceTy, VK_PRValue,
4330 SourceLocation(), FPOptionsOverride());
4331 VD->setInit(InitExpr);
4333 SemaRef.CheckCompleteVariableDeclaration(VD);
4334 return true;
4335}
4336
4337bool SemaHLSL::initGlobalResourceArrayDecl(VarDecl *VD) {
4338 assert(VD->getType()->isHLSLResourceRecordArray() &&
4339 "expected array of resource records");
4340
4341 // Individual resources in a resource array are not initialized here. They
4342 // are initialized later on during codegen when the individual resources are
4343 // accessed. Codegen will emit a call to the resource initialization method
4344 // with the specified array index. We need to make sure though that the method
4345 // for the specific resource type is instantiated, so codegen can emit a call
4346 // to it when the array element is accessed.
4347
4348 // Find correct initialization method based on the resource binding
4349 // information.
4350 ASTContext &AST = SemaRef.getASTContext();
4351 QualType ResElementTy = AST.getBaseElementType(VD->getType());
4352 CXXRecordDecl *ResourceDecl = ResElementTy->getAsCXXRecordDecl();
4353 CXXMethodDecl *CreateMethod = nullptr;
4354
4355 bool HasCounter = hasCounterHandle(ResourceDecl);
4356 ResourceBindingAttrs ResourceAttrs(VD);
4357 if (ResourceAttrs.isExplicit())
4358 // Resource has explicit binding.
4359 CreateMethod =
4360 lookupMethod(SemaRef, ResourceDecl,
4361 HasCounter ? "__createFromBindingWithImplicitCounter"
4362 : "__createFromBinding",
4363 VD->getLocation());
4364 else
4365 // Resource has implicit binding.
4366 CreateMethod = lookupMethod(
4367 SemaRef, ResourceDecl,
4368 HasCounter ? "__createFromImplicitBindingWithImplicitCounter"
4369 : "__createFromImplicitBinding",
4370 VD->getLocation());
4371
4372 if (!CreateMethod)
4373 return false;
4374
4375 // Make sure the create method template is instantiated and emitted.
4376 if (!CreateMethod->isDefined() && CreateMethod->isTemplateInstantiation())
4377 SemaRef.InstantiateFunctionDefinition(VD->getLocation(), CreateMethod,
4378 true);
4379 return true;
4380}
4381
4382// Returns true if the initialization has been handled.
4383// Returns false to use default initialization.
4385 // Objects in the hlsl_constant address space are initialized
4386 // externally, so don't synthesize an implicit initializer.
4388 return true;
4389
4390 // Initialize non-static resources at the global scope.
4391 if (VD->hasGlobalStorage() && VD->getStorageClass() != SC_Static) {
4392 const Type *Ty = VD->getType().getTypePtr();
4393 if (Ty->isHLSLResourceRecord())
4394 return initGlobalResourceDecl(VD);
4395 if (Ty->isHLSLResourceRecordArray())
4396 return initGlobalResourceArrayDecl(VD);
4397 }
4398 return false;
4399}
4400
4401// Return true if everything is ok; returns false if there was an error.
4403 Expr *RHSExpr, SourceLocation Loc) {
4404 assert((LHSExpr->getType()->isHLSLResourceRecord() ||
4405 LHSExpr->getType()->isHLSLResourceRecordArray()) &&
4406 "expected LHS to be a resource record or array of resource records");
4407 if (Opc != BO_Assign)
4408 return true;
4409
4410 // If LHS is an array subscript, get the underlying declaration.
4411 Expr *E = LHSExpr;
4412 while (auto *ASE = dyn_cast<ArraySubscriptExpr>(E))
4413 E = ASE->getBase()->IgnoreParenImpCasts();
4414
4415 // Report error if LHS is a non-static resource declared at a global scope.
4416 if (DeclRefExpr *DRE = dyn_cast<DeclRefExpr>(E->IgnoreParens())) {
4417 if (VarDecl *VD = dyn_cast<VarDecl>(DRE->getDecl())) {
4418 if (VD->hasGlobalStorage() && VD->getStorageClass() != SC_Static) {
4419 // assignment to global resource is not allowed
4420 SemaRef.Diag(Loc, diag::err_hlsl_assign_to_global_resource) << VD;
4421 SemaRef.Diag(VD->getLocation(), diag::note_var_declared_here) << VD;
4422 return false;
4423 }
4424 }
4425 }
4426 return true;
4427}
4428
4429// Walks through the global variable declaration, collects all resource binding
4430// requirements and adds them to Bindings.
4431void SemaHLSL::collectResourceBindingsOnVarDecl(VarDecl *VD) {
4432 assert(VD->hasGlobalStorage() && VD->getType()->isHLSLIntangibleType() &&
4433 "expected global variable that contains HLSL resource");
4434
4435 // Cbuffers and Tbuffers are HLSLBufferDecl types
4436 if (const HLSLBufferDecl *CBufferOrTBuffer = dyn_cast<HLSLBufferDecl>(VD)) {
4437 Bindings.addDeclBindingInfo(VD, CBufferOrTBuffer->isCBuffer()
4438 ? ResourceClass::CBuffer
4439 : ResourceClass::SRV);
4440 return;
4441 }
4442
4443 // Unwrap arrays
4444 // FIXME: Calculate array size while unwrapping
4445 const Type *Ty = VD->getType()->getUnqualifiedDesugaredType();
4446 while (Ty->isArrayType()) {
4447 const ArrayType *AT = cast<ArrayType>(Ty);
4449 }
4450
4451 // Resource (or array of resources)
4452 if (const HLSLAttributedResourceType *AttrResType =
4453 HLSLAttributedResourceType::findHandleTypeOnResource(Ty)) {
4454 Bindings.addDeclBindingInfo(VD, AttrResType->getAttrs().ResourceClass);
4455 return;
4456 }
4457
4458 // User defined record type
4459 if (const RecordType *RT = dyn_cast<RecordType>(Ty))
4460 collectResourceBindingsOnUserRecordDecl(VD, RT);
4461}
4462
4463// Walks through the explicit resource binding attributes on the declaration,
4464// makes sure there is a resource that matches each binding, and updates the
4465// DeclBindingInfoLists.
4466void SemaHLSL::processExplicitBindingsOnDecl(VarDecl *VD) {
4467 assert(VD->hasGlobalStorage() && "expected global variable");
4468
4469 bool HasBinding = false;
4470 for (Attr *A : VD->attrs()) {
4471 if (isa<HLSLVkBindingAttr>(A)) {
4472 HasBinding = true;
4473 if (auto PA = VD->getAttr<HLSLVkPushConstantAttr>())
4474 Diag(PA->getLoc(), diag::err_hlsl_attr_incompatible) << A << PA;
4475 }
4476
4477 HLSLResourceBindingAttr *RBA = dyn_cast<HLSLResourceBindingAttr>(A);
4478 if (!RBA || !RBA->hasRegisterSlot())
4479 continue;
4480 HasBinding = true;
4481
4482 RegisterType RT = RBA->getRegisterType();
4483 assert(RT != RegisterType::I && "invalid or obsolete register type should "
4484 "never have an attribute created");
4485
4486 if (RT == RegisterType::C) {
4487 if (Bindings.hasBindingInfoForDecl(VD))
4488 SemaRef.Diag(VD->getLocation(),
4489 diag::warn_hlsl_user_defined_type_missing_member)
4490 << static_cast<int>(RT);
4491 continue;
4492 }
4493
4494 // Find DeclBindingInfo for this binding and update it, or report an error
4495 // if it does not exist (the user type does not contain resources with the
4496 // expected resource class).
4498 if (DeclBindingInfo *BI = Bindings.getDeclBindingInfo(VD, RC)) {
4499 // update binding info
4500 BI->setBindingAttribute(RBA, BindingType::Explicit);
4501 } else {
4502 SemaRef.Diag(VD->getLocation(),
4503 diag::warn_hlsl_user_defined_type_missing_member)
4504 << static_cast<int>(RT);
4505 }
4506 }
4507
4508 if (!HasBinding && isResourceRecordTypeOrArrayOf(VD))
4509 SemaRef.Diag(VD->getLocation(), diag::warn_hlsl_implicit_binding);
4510}
4511namespace {
4512class InitListTransformer {
4513 Sema &S;
4514 ASTContext &Ctx;
4515 QualType InitTy;
4516 QualType *DstIt = nullptr;
4517 Expr **ArgIt = nullptr;
4518 // Is wrapping the destination type iterator required? This is only used for
4519 // incomplete array types where we loop over the destination type since we
4520 // don't know the full number of elements from the declaration.
4521 bool Wrap;
4522
4523 bool castInitializer(Expr *E) {
4524 assert(DstIt && "This should always be something!");
4525 if (DstIt == DestTypes.end()) {
4526 if (!Wrap) {
4527 ArgExprs.push_back(E);
4528 // This is odd, but it isn't technically a failure due to conversion, we
4529 // handle mismatched counts of arguments differently.
4530 return true;
4531 }
4532 DstIt = DestTypes.begin();
4533 }
4534 InitializedEntity Entity = InitializedEntity::InitializeParameter(
4535 Ctx, *DstIt, /* Consumed (ObjC) */ false);
4536 ExprResult Res = S.PerformCopyInitialization(Entity, E->getBeginLoc(), E);
4537 if (Res.isInvalid())
4538 return false;
4539 Expr *Init = Res.get();
4540 ArgExprs.push_back(Init);
4541 DstIt++;
4542 return true;
4543 }
4544
4545 bool buildInitializerListImpl(Expr *E) {
4546 // If this is an initialization list, traverse the sub initializers.
4547 if (auto *Init = dyn_cast<InitListExpr>(E)) {
4548 for (auto *SubInit : Init->inits())
4549 if (!buildInitializerListImpl(SubInit))
4550 return false;
4551 return true;
4552 }
4553
4554 // If this is a scalar type, just enqueue the expression.
4555 QualType Ty = E->getType();
4556
4557 if (Ty->isScalarType() || (Ty->isRecordType() && !Ty->isAggregateType()))
4558 return castInitializer(E);
4559
4560 if (auto *VecTy = Ty->getAs<VectorType>()) {
4561 uint64_t Size = VecTy->getNumElements();
4562
4563 QualType SizeTy = Ctx.getSizeType();
4564 uint64_t SizeTySize = Ctx.getTypeSize(SizeTy);
4565 for (uint64_t I = 0; I < Size; ++I) {
4566 auto *Idx = IntegerLiteral::Create(Ctx, llvm::APInt(SizeTySize, I),
4567 SizeTy, SourceLocation());
4568
4570 E, E->getBeginLoc(), Idx, E->getEndLoc());
4571 if (ElExpr.isInvalid())
4572 return false;
4573 if (!castInitializer(ElExpr.get()))
4574 return false;
4575 }
4576 return true;
4577 }
4578 if (auto *MTy = Ty->getAs<ConstantMatrixType>()) {
4579 unsigned Rows = MTy->getNumRows();
4580 unsigned Cols = MTy->getNumColumns();
4581 QualType ElemTy = MTy->getElementType();
4582
4583 for (unsigned C = 0; C < Cols; ++C) {
4584 for (unsigned R = 0; R < Rows; ++R) {
4585 // row index literal
4586 Expr *RowIdx = IntegerLiteral::Create(
4587 Ctx, llvm::APInt(Ctx.getIntWidth(Ctx.IntTy), R), Ctx.IntTy,
4588 E->getBeginLoc());
4589 // column index literal
4590 Expr *ColIdx = IntegerLiteral::Create(
4591 Ctx, llvm::APInt(Ctx.getIntWidth(Ctx.IntTy), C), Ctx.IntTy,
4592 E->getBeginLoc());
4594 E, RowIdx, ColIdx, E->getEndLoc());
4595 if (ElExpr.isInvalid())
4596 return false;
4597 if (!castInitializer(ElExpr.get()))
4598 return false;
4599 ElExpr.get()->setType(ElemTy);
4600 }
4601 }
4602 return true;
4603 }
4604
4605 if (auto *ArrTy = dyn_cast<ConstantArrayType>(Ty.getTypePtr())) {
4606 uint64_t Size = ArrTy->getZExtSize();
4607 QualType SizeTy = Ctx.getSizeType();
4608 uint64_t SizeTySize = Ctx.getTypeSize(SizeTy);
4609 for (uint64_t I = 0; I < Size; ++I) {
4610 auto *Idx = IntegerLiteral::Create(Ctx, llvm::APInt(SizeTySize, I),
4611 SizeTy, SourceLocation());
4613 E, E->getBeginLoc(), Idx, E->getEndLoc());
4614 if (ElExpr.isInvalid())
4615 return false;
4616 if (!buildInitializerListImpl(ElExpr.get()))
4617 return false;
4618 }
4619 return true;
4620 }
4621
4622 if (auto *RD = Ty->getAsCXXRecordDecl()) {
4623 llvm::SmallVector<CXXRecordDecl *> RecordDecls;
4624 RecordDecls.push_back(RD);
4625 while (RecordDecls.back()->getNumBases()) {
4626 CXXRecordDecl *D = RecordDecls.back();
4627 assert(D->getNumBases() == 1 &&
4628 "HLSL doesn't support multiple inheritance");
4629 RecordDecls.push_back(
4631 }
4632 while (!RecordDecls.empty()) {
4633 CXXRecordDecl *RD = RecordDecls.pop_back_val();
4634 for (auto *FD : RD->fields()) {
4635 if (FD->isUnnamedBitField())
4636 continue;
4637 DeclAccessPair Found = DeclAccessPair::make(FD, FD->getAccess());
4638 DeclarationNameInfo NameInfo(FD->getDeclName(), E->getBeginLoc());
4640 E, false, E->getBeginLoc(), CXXScopeSpec(), FD, Found, NameInfo);
4641 if (Res.isInvalid())
4642 return false;
4643 if (!buildInitializerListImpl(Res.get()))
4644 return false;
4645 }
4646 }
4647 }
4648 return true;
4649 }
4650
4651 Expr *generateInitListsImpl(QualType Ty) {
4652 assert(ArgIt != ArgExprs.end() && "Something is off in iteration!");
4653 if (Ty->isScalarType() || (Ty->isRecordType() && !Ty->isAggregateType()))
4654 return *(ArgIt++);
4655
4656 llvm::SmallVector<Expr *> Inits;
4657 Ty = Ty.getDesugaredType(Ctx);
4658 if (Ty->isVectorType() || Ty->isConstantArrayType() ||
4659 Ty->isConstantMatrixType()) {
4660 QualType ElTy;
4661 uint64_t Size = 0;
4662 if (auto *ATy = Ty->getAs<VectorType>()) {
4663 ElTy = ATy->getElementType();
4664 Size = ATy->getNumElements();
4665 } else if (auto *CMTy = Ty->getAs<ConstantMatrixType>()) {
4666 ElTy = CMTy->getElementType();
4667 Size = CMTy->getNumElementsFlattened();
4668 } else {
4669 auto *VTy = cast<ConstantArrayType>(Ty.getTypePtr());
4670 ElTy = VTy->getElementType();
4671 Size = VTy->getZExtSize();
4672 }
4673 for (uint64_t I = 0; I < Size; ++I)
4674 Inits.push_back(generateInitListsImpl(ElTy));
4675 }
4676 if (auto *RD = Ty->getAsCXXRecordDecl()) {
4677 llvm::SmallVector<CXXRecordDecl *> RecordDecls;
4678 RecordDecls.push_back(RD);
4679 while (RecordDecls.back()->getNumBases()) {
4680 CXXRecordDecl *D = RecordDecls.back();
4681 assert(D->getNumBases() == 1 &&
4682 "HLSL doesn't support multiple inheritance");
4683 RecordDecls.push_back(
4685 }
4686 while (!RecordDecls.empty()) {
4687 CXXRecordDecl *RD = RecordDecls.pop_back_val();
4688 for (auto *FD : RD->fields())
4689 if (!FD->isUnnamedBitField())
4690 Inits.push_back(generateInitListsImpl(FD->getType()));
4691 }
4692 }
4693 auto *NewInit = new (Ctx) InitListExpr(Ctx, Inits.front()->getBeginLoc(),
4694 Inits, Inits.back()->getEndLoc());
4695 NewInit->setType(Ty);
4696 return NewInit;
4697 }
4698
4699public:
4700 llvm::SmallVector<QualType, 16> DestTypes;
4701 llvm::SmallVector<Expr *, 16> ArgExprs;
4702 InitListTransformer(Sema &SemaRef, const InitializedEntity &Entity)
4703 : S(SemaRef), Ctx(SemaRef.getASTContext()),
4704 Wrap(Entity.getType()->isIncompleteArrayType()) {
4705 InitTy = Entity.getType().getNonReferenceType();
4706 // When we're generating initializer lists for incomplete array types we
4707 // need to wrap around both when building the initializers and when
4708 // generating the final initializer lists.
4709 if (Wrap) {
4710 assert(InitTy->isIncompleteArrayType());
4711 const IncompleteArrayType *IAT = Ctx.getAsIncompleteArrayType(InitTy);
4712 InitTy = IAT->getElementType();
4713 }
4714 BuildFlattenedTypeList(InitTy, DestTypes);
4715 DstIt = DestTypes.begin();
4716 }
4717
4718 bool buildInitializerList(Expr *E) { return buildInitializerListImpl(E); }
4719
4720 Expr *generateInitLists() {
4721 assert(!ArgExprs.empty() &&
4722 "Call buildInitializerList to generate argument expressions.");
4723 ArgIt = ArgExprs.begin();
4724 if (!Wrap)
4725 return generateInitListsImpl(InitTy);
4726 llvm::SmallVector<Expr *> Inits;
4727 while (ArgIt != ArgExprs.end())
4728 Inits.push_back(generateInitListsImpl(InitTy));
4729
4730 auto *NewInit = new (Ctx) InitListExpr(Ctx, Inits.front()->getBeginLoc(),
4731 Inits, Inits.back()->getEndLoc());
4732 llvm::APInt ArySize(64, Inits.size());
4733 NewInit->setType(Ctx.getConstantArrayType(InitTy, ArySize, nullptr,
4734 ArraySizeModifier::Normal, 0));
4735 return NewInit;
4736 }
4737};
4738} // namespace
4739
4741 InitListExpr *Init) {
4742 // If the initializer is a scalar, just return it.
4743 if (Init->getType()->isScalarType())
4744 return true;
4745 ASTContext &Ctx = SemaRef.getASTContext();
4746 InitListTransformer ILT(SemaRef, Entity);
4747
4748 for (unsigned I = 0; I < Init->getNumInits(); ++I) {
4749 Expr *E = Init->getInit(I);
4750 if (E->HasSideEffects(Ctx)) {
4751 QualType Ty = E->getType();
4752 if (Ty->isRecordType())
4753 E = new (Ctx) MaterializeTemporaryExpr(Ty, E, E->isLValue());
4754 E = new (Ctx) OpaqueValueExpr(E->getBeginLoc(), Ty, E->getValueKind(),
4755 E->getObjectKind(), E);
4756 Init->setInit(I, E);
4757 }
4758 if (!ILT.buildInitializerList(E))
4759 return false;
4760 }
4761 size_t ExpectedSize = ILT.DestTypes.size();
4762 size_t ActualSize = ILT.ArgExprs.size();
4763 if (ExpectedSize == 0 && ActualSize == 0)
4764 return true;
4765
4766 // For incomplete arrays it is completely arbitrary to choose whether we think
4767 // the user intended fewer or more elements. This implementation assumes that
4768 // the user intended more, and errors that there are too few initializers to
4769 // complete the final element.
4770 if (Entity.getType()->isIncompleteArrayType()) {
4771 assert(ExpectedSize > 0 &&
4772 "The expected size of an incomplete array type must be at least 1.");
4773 ExpectedSize =
4774 ((ActualSize + ExpectedSize - 1) / ExpectedSize) * ExpectedSize;
4775 }
4776
4777 // An initializer list might be attempting to initialize a reference or
4778 // rvalue-reference. When checking the initializer we should look through
4779 // the reference.
4780 QualType InitTy = Entity.getType().getNonReferenceType();
4781 if (InitTy.hasAddressSpace())
4782 InitTy = SemaRef.getASTContext().removeAddrSpaceQualType(InitTy);
4783 if (ExpectedSize != ActualSize) {
4784 int TooManyOrFew = ActualSize > ExpectedSize ? 1 : 0;
4785 SemaRef.Diag(Init->getBeginLoc(), diag::err_hlsl_incorrect_num_initializers)
4786 << TooManyOrFew << InitTy << ExpectedSize << ActualSize;
4787 return false;
4788 }
4789
4790 // generateInitListsImpl will always return an InitListExpr here, because the
4791 // scalar case is handled above.
4792 auto *NewInit = cast<InitListExpr>(ILT.generateInitLists());
4793 Init->resizeInits(Ctx, NewInit->getNumInits());
4794 for (unsigned I = 0; I < NewInit->getNumInits(); ++I)
4795 Init->updateInit(Ctx, I, NewInit->getInit(I));
4796 return true;
4797}
4798
4800 const HLSLVkConstantIdAttr *ConstIdAttr =
4801 VDecl->getAttr<HLSLVkConstantIdAttr>();
4802 if (!ConstIdAttr)
4803 return true;
4804
4805 ASTContext &Context = SemaRef.getASTContext();
4806
4807 APValue InitValue;
4808 if (!Init->isCXX11ConstantExpr(Context, &InitValue)) {
4809 Diag(VDecl->getLocation(), diag::err_specialization_const);
4810 VDecl->setInvalidDecl();
4811 return false;
4812 }
4813
4814 Builtin::ID BID =
4816
4817 // Argument 1: The ID from the attribute
4818 int ConstantID = ConstIdAttr->getId();
4819 llvm::APInt IDVal(Context.getIntWidth(Context.IntTy), ConstantID);
4820 Expr *IdExpr = IntegerLiteral::Create(Context, IDVal, Context.IntTy,
4821 ConstIdAttr->getLocation());
4822
4823 SmallVector<Expr *, 2> Args = {IdExpr, Init};
4824 Expr *C = SemaRef.BuildBuiltinCallExpr(Init->getExprLoc(), BID, Args);
4825 if (C->getType()->getCanonicalTypeUnqualified() !=
4827 C = SemaRef
4828 .BuildCStyleCastExpr(SourceLocation(),
4829 Context.getTrivialTypeSourceInfo(
4830 Init->getType(), Init->getExprLoc()),
4831 SourceLocation(), C)
4832 .get();
4833 }
4834 Init = C;
4835 return true;
4836}
Defines the clang::ASTContext interface.
Defines enum values for all the target-independent builtin functions.
llvm::dxil::ResourceClass ResourceClass
Defines the C++ Decl subclasses, other than those for templates (found in DeclTemplate....
TokenType getType() const
Returns the token's type, e.g.
FormatToken * Previous
The previous token in the unwrapped line.
Defines the clang::IdentifierInfo, clang::IdentifierTable, and clang::Selector interfaces.
#define X(type, name)
Definition Value.h:97
Forward-declares and imports various common LLVM datatypes that clang wants to use unqualified.
llvm::SmallVector< std::pair< const MemRegion *, SVal >, 4 > Bindings
static bool CheckArgTypeMatches(Sema *S, Expr *Arg, QualType ExpectedType)
static void BuildFlattenedTypeList(QualType BaseTy, llvm::SmallVectorImpl< QualType > &List)
static bool CheckUnsignedIntRepresentation(Sema *S, SourceLocation Loc, int ArgOrdinal, clang::QualType PassedType)
static QualType handleIntegerVectorBinOpConversion(Sema &SemaRef, ExprResult &LHS, ExprResult &RHS, QualType LHSType, QualType RHSType, QualType LElTy, QualType RElTy, bool IsCompAssign)
static bool convertToRegisterType(StringRef Slot, RegisterType *RT)
Definition SemaHLSL.cpp:83
static bool CheckWaveActive(Sema *S, CallExpr *TheCall)
static void castVector(Sema &S, ExprResult &E, QualType &Ty, unsigned Sz)
static bool CheckBoolSelect(Sema *S, CallExpr *TheCall)
static unsigned calculateLegacyCbufferFieldAlign(const ASTContext &Context, QualType T)
Definition SemaHLSL.cpp:220
static bool isZeroSizedArray(const ConstantArrayType *CAT)
Definition SemaHLSL.cpp:339
static bool DiagnoseHLSLRegisterAttribute(Sema &S, SourceLocation &ArgLoc, Decl *D, RegisterType RegType, bool SpecifiedSpace)
static FieldDecl * createFieldForHostLayoutStruct(Sema &S, const Type *Ty, IdentifierInfo *II, CXXRecordDecl *LayoutStruct)
Definition SemaHLSL.cpp:458
static bool CheckUnsignedIntVecRepresentation(Sema *S, SourceLocation Loc, int ArgOrdinal, clang::QualType PassedType)
static bool isInvalidConstantBufferLeafElementType(const Type *Ty)
Definition SemaHLSL.cpp:365
static Builtin::ID getSpecConstBuiltinId(const Type *Type)
Definition SemaHLSL.cpp:133
static bool CheckFloatingOrIntRepresentation(Sema *S, SourceLocation Loc, int ArgOrdinal, clang::QualType PassedType)
static bool CheckAnyScalarOrVector(Sema *S, CallExpr *TheCall, unsigned ArgIndex)
static IdentifierInfo * getHostLayoutStructName(Sema &S, NamedDecl *BaseDecl, bool MustBeUnique)
Definition SemaHLSL.cpp:421
static void addImplicitBindingAttrToDecl(Sema &S, Decl *D, RegisterType RT, uint32_t ImplicitBindingOrderID)
Definition SemaHLSL.cpp:586
static void SetElementTypeAsReturnType(Sema *S, CallExpr *TheCall, QualType ReturnType)
static bool isResourceRecordTypeOrArrayOf(VarDecl *VD)
Definition SemaHLSL.cpp:346
static unsigned calculateLegacyCbufferSize(const ASTContext &Context, QualType T)
Definition SemaHLSL.cpp:239
static const HLSLAttributedResourceType * getResourceArrayHandleType(VarDecl *VD)
Definition SemaHLSL.cpp:352
static RegisterType getRegisterType(ResourceClass RC)
Definition SemaHLSL.cpp:63
static bool ValidateRegisterNumber(uint64_t SlotNum, Decl *TheDecl, ASTContext &Ctx, RegisterType RegTy)
static bool isVkPipelineBuiltin(const ASTContext &AstContext, FunctionDecl *FD, HLSLAppliedSemanticAttr *Semantic, bool IsInput)
Definition SemaHLSL.cpp:774
static bool CheckModifiableLValue(Sema *S, CallExpr *TheCall, unsigned ArgIndex)
static QualType castElement(Sema &S, ExprResult &E, QualType Ty)
static CXXRecordDecl * findRecordDeclInContext(IdentifierInfo *II, DeclContext *DC)
Definition SemaHLSL.cpp:404
static bool CheckExpectedBitWidth(Sema *S, CallExpr *TheCall, unsigned ArgOrdinal, unsigned Width)
static bool hasCounterHandle(const CXXRecordDecl *RD)
static bool CheckVectorSelect(Sema *S, CallExpr *TheCall)
static QualType handleFloatVectorBinOpConversion(Sema &SemaRef, ExprResult &LHS, ExprResult &RHS, QualType LHSType, QualType RHSType, QualType LElTy, QualType RElTy, bool IsCompAssign)
static ResourceClass getResourceClass(RegisterType RT)
Definition SemaHLSL.cpp:115
static CXXRecordDecl * createHostLayoutStruct(Sema &S, CXXRecordDecl *StructDecl)
Definition SemaHLSL.cpp:490
static bool CheckScalarOrVector(Sema *S, CallExpr *TheCall, QualType Scalar, unsigned ArgIndex)
static bool CheckScalarOrVectorOrMatrix(Sema *S, CallExpr *TheCall, QualType Scalar, unsigned ArgIndex)
void createHostLayoutStructForBuffer(Sema &S, HLSLBufferDecl *BufDecl)
Definition SemaHLSL.cpp:555
static bool requiresImplicitBufferLayoutStructure(const CXXRecordDecl *RD)
Definition SemaHLSL.cpp:384
static bool CheckResourceHandle(Sema *S, CallExpr *TheCall, unsigned ArgIndex, llvm::function_ref< bool(const HLSLAttributedResourceType *ResType)> Check=nullptr)
static void validatePackoffset(Sema &S, HLSLBufferDecl *BufDecl)
Definition SemaHLSL.cpp:286
static bool IsDefaultBufferConstantDecl(const ASTContext &Ctx, VarDecl *VD)
HLSLResourceBindingAttr::RegisterType RegisterType
Definition SemaHLSL.cpp:58
static CastKind getScalarCastKind(ASTContext &Ctx, QualType DestTy, QualType SrcTy)
static bool isValidWaveSizeValue(unsigned Value)
static bool AccumulateHLSLResourceSlots(QualType Ty, uint64_t &StartSlot, const uint64_t &Limit, const ResourceClass ResClass, ASTContext &Ctx, uint64_t ArrayCount=1)
static bool CheckNoDoubleVectors(Sema *S, SourceLocation Loc, int ArgOrdinal, clang::QualType PassedType)
static bool ValidateMultipleRegisterAnnotations(Sema &S, Decl *TheDecl, RegisterType regType)
static bool DiagnoseLocalRegisterBinding(Sema &S, SourceLocation &ArgLoc, Decl *D, RegisterType RegType, bool SpecifiedSpace)
This file declares semantic analysis for HLSL constructs.
Defines the clang::SourceLocation class and associated facilities.
Defines various enumerations that describe declaration and type specifiers.
C Language Family Type Representation.
Defines the clang::TypeLoc interface and its subclasses.
C Language Family Type Representation.
static const TypeInfo & getInfo(unsigned id)
Definition Types.cpp:44
return(__x > > __y)|(__x<<(32 - __y))
APValue - This class implements a discriminated union of [uninitialized] [APSInt] [APFloat],...
Definition APValue.h:122
Holds long-lived AST nodes (such as types and decls) that can be referred to throughout the semantic ...
Definition ASTContext.h:220
unsigned getIntWidth(QualType T) const
int getIntegerTypeOrder(QualType LHS, QualType RHS) const
Return the highest ranked integer type, see C99 6.3.1.8p1.
QualType getPointerType(QualType T) const
Return the uniqued reference to the type for a pointer to the specified type.
const IncompleteArrayType * getAsIncompleteArrayType(QualType T) const
IdentifierTable & Idents
Definition ASTContext.h:790
QualType getConstantArrayType(QualType EltTy, const llvm::APInt &ArySize, const Expr *SizeExpr, ArraySizeModifier ASM, unsigned IndexTypeQuals) const
Return the unique reference to the type for a constant array of the specified element type.
QualType getBaseElementType(const ArrayType *VAT) const
Return the innermost element type of an array type.
int getFloatingTypeOrder(QualType LHS, QualType RHS) const
Compare the rank of the two specified floating point types, ignoring the domain of the type (i....
CanQualType BoolTy
TypeSourceInfo * getTrivialTypeSourceInfo(QualType T, SourceLocation Loc=SourceLocation()) const
Allocate a TypeSourceInfo where all locations have been initialized to a given location,...
QualType getStringLiteralArrayType(QualType EltTy, unsigned Length) const
Return a type for a constant array for a string literal of the specified element type and length.
CanQualType CharTy
CanQualType IntTy
uint64_t getTypeSize(QualType T) const
Return the size of the specified (complete) type T, in bits.
CharUnits getTypeSizeInChars(QualType T) const
Return the size of the specified (complete) type T, in characters.
CanQualType UnsignedIntTy
QualType getSizeType() const
Return the unique type for "size_t" (C99 7.17), defined in <stddef.h>.
QualType getExtVectorType(QualType VectorType, unsigned NumElts) const
Return the unique reference to an extended vector type of the specified element type and size.
const TargetInfo & getTargetInfo() const
Definition ASTContext.h:909
QualType getHLSLAttributedResourceType(QualType Wrapped, QualType Contained, const HLSLAttributedResourceType::Attributes &Attrs)
QualType getAddrSpaceQualType(QualType T, LangAS AddressSpace) const
Return the uniqued reference to the type for an address space qualified type with the specified type ...
CanQualType getCanonicalTagType(const TagDecl *TD) const
static bool hasSameUnqualifiedType(QualType T1, QualType T2)
Determine whether the given types are equivalent after cvr-qualifiers have been removed.
unsigned getTypeAlign(QualType T) const
Return the ABI-specified alignment of a (complete) type T, in bits.
PtrTy get() const
Definition Ownership.h:171
bool isInvalid() const
Definition Ownership.h:167
Represents an array type, per C99 6.7.5.2 - Array Declarators.
Definition TypeBase.h:3723
QualType getElementType() const
Definition TypeBase.h:3735
Attr - This represents one attribute.
Definition Attr.h:45
attr::Kind getKind() const
Definition Attr.h:91
SourceLocation getLocation() const
Definition Attr.h:98
SourceLocation getScopeLoc() const
const IdentifierInfo * getScopeName() const
SourceLocation getLoc() const
const IdentifierInfo * getAttrName() const
Represents a base class of a C++ class.
Definition DeclCXX.h:146
QualType getType() const
Retrieves the type of the base class.
Definition DeclCXX.h:249
Represents a static or instance method of a struct/union/class.
Definition DeclCXX.h:2129
Represents a C++ struct/union/class.
Definition DeclCXX.h:258
bool isHLSLIntangible() const
Returns true if the class contains HLSL intangible type, either as a field or in base class.
Definition DeclCXX.h:1550
static CXXRecordDecl * Create(const ASTContext &C, TagKind TK, DeclContext *DC, SourceLocation StartLoc, SourceLocation IdLoc, IdentifierInfo *Id, CXXRecordDecl *PrevDecl=nullptr)
Definition DeclCXX.cpp:132
void setBases(CXXBaseSpecifier const *const *Bases, unsigned NumBases)
Sets the base classes of this struct or class.
Definition DeclCXX.cpp:184
void completeDefinition() override
Indicates that the definition of this class is now complete.
Definition DeclCXX.cpp:2239
base_class_range bases()
Definition DeclCXX.h:608
unsigned getNumBases() const
Retrieves the number of base classes of this class.
Definition DeclCXX.h:602
base_class_iterator bases_begin()
Definition DeclCXX.h:615
bool isEmpty() const
Determine whether this is an empty class in the sense of (C++11 [meta.unary.prop]).
Definition DeclCXX.h:1186
CallExpr - Represents a function call (C99 6.5.2.2, C++ [expr.call]).
Definition Expr.h:2943
Expr * getArg(unsigned Arg)
getArg - Return the specified argument.
Definition Expr.h:3147
SourceLocation getBeginLoc() const
Definition Expr.h:3277
static CallExpr * Create(const ASTContext &Ctx, Expr *Fn, ArrayRef< Expr * > Args, QualType Ty, ExprValueKind VK, SourceLocation RParenLoc, FPOptionsOverride FPFeatures, unsigned MinNumArgs=0, ADLCallKind UsesADL=NotADL)
Create a call expression.
Definition Expr.cpp:1516
FunctionDecl * getDirectCallee()
If the callee is a FunctionDecl, return it. Otherwise return null.
Definition Expr.h:3126
Expr * getCallee()
Definition Expr.h:3090
unsigned getNumArgs() const
getNumArgs - Return the number of actual arguments to this call.
Definition Expr.h:3134
QualType withConst() const
Retrieves a version of this type with const applied.
const T * getTypePtr() const
Retrieve the underlying type pointer, which refers to a canonical type.
QuantityType getQuantity() const
getQuantity - Get the raw integer representation of this quantity.
Definition CharUnits.h:185
Represents the canonical version of C arrays with a specified constant size.
Definition TypeBase.h:3761
bool isZeroSize() const
Return true if the size is zero.
Definition TypeBase.h:3831
llvm::APInt getSize() const
Return the constant array size as an APInt.
Definition TypeBase.h:3817
uint64_t getZExtSize() const
Return the size zero-extended as a uint64_t.
Definition TypeBase.h:3837
Represents a concrete matrix type with constant number of rows and columns.
Definition TypeBase.h:4388
static DeclAccessPair make(NamedDecl *D, AccessSpecifier AS)
DeclContext - This is used only as base class of specific decl types that can act as declaration cont...
Definition DeclBase.h:1449
lookup_result lookup(DeclarationName Name) const
lookup - Find the declarations (if any) with the given Name in this context.
bool isTranslationUnit() const
Definition DeclBase.h:2185
void addDecl(Decl *D)
Add the declaration D into this context.
decl_range decls() const
decls_begin/decls_end - Iterate over the declarations stored in this context.
Definition DeclBase.h:2373
DeclContext * getNonTransparentContext()
A reference to a declared variable, function, enum, etc.
Definition Expr.h:1270
static DeclRefExpr * Create(const ASTContext &Context, NestedNameSpecifierLoc QualifierLoc, SourceLocation TemplateKWLoc, ValueDecl *D, bool RefersToEnclosingVariableOrCapture, SourceLocation NameLoc, QualType T, ExprValueKind VK, NamedDecl *FoundD=nullptr, const TemplateArgumentListInfo *TemplateArgs=nullptr, NonOdrUseReason NOUR=NOUR_None)
Definition Expr.cpp:487
ValueDecl * getDecl()
Definition Expr.h:1338
Decl - This represents one declaration (or definition), e.g.
Definition DeclBase.h:86
T * getAttr() const
Definition DeclBase.h:573
void addAttr(Attr *A)
attr_iterator attr_end() const
Definition DeclBase.h:542
bool isImplicit() const
isImplicit - Indicates whether the declaration was implicitly generated by the implementation.
Definition DeclBase.h:593
void setInvalidDecl(bool Invalid=true)
setInvalidDecl - Indicates the Decl had a semantic error.
Definition DeclBase.cpp:178
bool isInExportDeclContext() const
Whether this declaration was exported in a lexical context.
attr_iterator attr_begin() const
Definition DeclBase.h:539
SourceLocation getLocation() const
Definition DeclBase.h:439
void setImplicit(bool I=true)
Definition DeclBase.h:594
DeclContext * getDeclContext()
Definition DeclBase.h:448
attr_range attrs() const
Definition DeclBase.h:535
AccessSpecifier getAccess() const
Definition DeclBase.h:507
SourceLocation getBeginLoc() const LLVM_READONLY
Definition DeclBase.h:431
void dropAttr()
Definition DeclBase.h:556
bool hasAttr() const
Definition DeclBase.h:577
The name of a declaration.
Represents a ValueDecl that came out of a declarator.
Definition Decl.h:780
SourceLocation getBeginLoc() const LLVM_READONLY
Definition Decl.h:831
This represents one expression.
Definition Expr.h:112
void setType(QualType t)
Definition Expr.h:145
ExprValueKind getValueKind() const
getValueKind - The value kind that this expression produces.
Definition Expr.h:444
Expr * IgnoreParenImpCasts() LLVM_READONLY
Skip past any parentheses and implicit casts which might surround this expression until reaching a fi...
Definition Expr.cpp:3089
Expr * IgnoreParens() LLVM_READONLY
Skip past any parentheses which might surround this expression until reaching a fixed point.
Definition Expr.cpp:3085
std::optional< llvm::APSInt > getIntegerConstantExpr(const ASTContext &Ctx) const
isIntegerConstantExpr - Return the value if this expression is a valid integer constant expression.
bool isLValue() const
isLValue - True if this expression is an "l-value" according to the rules of the current language.
Definition Expr.h:284
ExprObjectKind getObjectKind() const
getObjectKind - The object kind that this expression produces.
Definition Expr.h:451
bool HasSideEffects(const ASTContext &Ctx, bool IncludePossibleEffects=true) const
HasSideEffects - This routine returns true for all those expressions which have any effect other than...
Definition Expr.cpp:3669
void setValueKind(ExprValueKind Cat)
setValueKind - Set the value kind produced by this expression.
Definition Expr.h:461
SourceLocation getExprLoc() const LLVM_READONLY
getExprLoc - Return the preferred location for the arrow when diagnosing a problem with a generic exp...
Definition Expr.cpp:276
@ MLV_Valid
Definition Expr.h:305
QualType getType() const
Definition Expr.h:144
Represents a member of a struct/union/class.
Definition Decl.h:3160
static FieldDecl * Create(const ASTContext &C, DeclContext *DC, SourceLocation StartLoc, SourceLocation IdLoc, const IdentifierInfo *Id, QualType T, TypeSourceInfo *TInfo, Expr *BW, bool Mutable, InClassInitStyle InitStyle)
Definition Decl.cpp:4700
static FixItHint CreateReplacement(CharSourceRange RemoveRange, StringRef Code)
Create a code modification hint that replaces the given source range with the given code string.
Definition Diagnostic.h:140
Represents a function declaration or definition.
Definition Decl.h:2000
const ParmVarDecl * getParamDecl(unsigned i) const
Definition Decl.h:2797
Stmt * getBody(const FunctionDecl *&Definition) const
Retrieve the body (definition) of the function.
Definition Decl.cpp:3279
bool isThisDeclarationADefinition() const
Returns whether this specific declaration of the function is also a definition that does not contain ...
Definition Decl.h:2314
QualType getReturnType() const
Definition Decl.h:2845
ArrayRef< ParmVarDecl * > parameters() const
Definition Decl.h:2774
bool isTemplateInstantiation() const
Determines if the given function was instantiated from a function template.
Definition Decl.cpp:4257
redecl_range redecls() const
Returns an iterator range for all the redeclarations of the same decl.
unsigned getNumParams() const
Return the number of parameters this function must have based on its FunctionType.
Definition Decl.cpp:3826
DeclarationNameInfo getNameInfo() const
Definition Decl.h:2211
bool hasBody(const FunctionDecl *&Definition) const
Returns true if the function has a body.
Definition Decl.cpp:3199
bool isDefined(const FunctionDecl *&Definition, bool CheckForPendingFriendDefinition=false) const
Returns true if the function has a definition that does not need to be instantiated.
Definition Decl.cpp:3246
HLSLBufferDecl - Represent a cbuffer or tbuffer declaration.
Definition Decl.h:5193
static HLSLBufferDecl * Create(ASTContext &C, DeclContext *LexicalParent, bool CBuffer, SourceLocation KwLoc, IdentifierInfo *ID, SourceLocation IDLoc, SourceLocation LBrace)
Definition Decl.cpp:5900
void addLayoutStruct(CXXRecordDecl *LS)
Definition Decl.cpp:5940
void setHasValidPackoffset(bool PO)
Definition Decl.h:5238
static HLSLBufferDecl * CreateDefaultCBuffer(ASTContext &C, DeclContext *LexicalParent, ArrayRef< Decl * > DefaultCBufferDecls)
Definition Decl.cpp:5923
buffer_decl_range buffer_decls() const
Definition Decl.h:5268
static HLSLOutArgExpr * Create(const ASTContext &C, QualType Ty, OpaqueValueExpr *Base, OpaqueValueExpr *OpV, Expr *WB, bool IsInOut)
Definition Expr.cpp:5531
static HLSLRootSignatureDecl * Create(ASTContext &C, DeclContext *DC, SourceLocation Loc, IdentifierInfo *ID, llvm::dxbc::RootSignatureVersion Version, ArrayRef< llvm::hlsl::rootsig::RootElement > RootElements)
Definition Decl.cpp:5986
One of these records is kept for each identifier that is lexed.
StringRef getName() const
Return the actual identifier string.
A simple pair of identifier info and location.
SourceLocation getLoc() const
IdentifierInfo * getIdentifierInfo() const
IdentifierInfo & get(StringRef Name)
Return the identifier token info for the specified named identifier.
static ImplicitCastExpr * Create(const ASTContext &Context, QualType T, CastKind Kind, Expr *Operand, const CXXCastPath *BasePath, ExprValueKind Cat, FPOptionsOverride FPO)
Definition Expr.cpp:2072
Describes an C or C++ initializer list.
Definition Expr.h:5299
Describes an entity that is being initialized.
QualType getType() const
Retrieve type being initialized.
static InitializedEntity InitializeParameter(ASTContext &Context, ParmVarDecl *Parm)
Create the initialization entity for a parameter.
static IntegerLiteral * Create(const ASTContext &C, const llvm::APInt &V, QualType type, SourceLocation l)
Returns a new integer literal with value 'V' and type 'type'.
Definition Expr.cpp:974
Represents the results of name lookup.
Definition Lookup.h:147
NamedDecl * getFoundDecl() const
Fetch the unique decl found by this lookup.
Definition Lookup.h:569
Represents a prvalue temporary that is written into memory so that a reference can bind to it.
Definition ExprCXX.h:4920
ValueDecl * getMemberDecl() const
Retrieve the member declaration to which this expression refers.
Definition Expr.h:3447
This represents a decl that may have a name.
Definition Decl.h:274
IdentifierInfo * getIdentifier() const
Get the identifier that names this declaration, if there is one.
Definition Decl.h:295
StringRef getName() const
Get the name of identifier for this declaration as a StringRef.
Definition Decl.h:301
DeclarationName getDeclName() const
Get the actual, stored name of the declaration, which may be a special name.
Definition Decl.h:340
OpaqueValueExpr - An expression referring to an opaque object of a fixed type and value class.
Definition Expr.h:1178
Represents a parameter to a function.
Definition Decl.h:1790
ParsedAttr - Represents a syntactic attribute.
Definition ParsedAttr.h:119
unsigned getSemanticSpelling() const
If the parsed attribute has a semantic equivalent, and it would have a semantic Spelling enumeration ...
unsigned getMinArgs() const
bool checkExactlyNumArgs(class Sema &S, unsigned Num) const
Check if the attribute has exactly as many args as Num.
IdentifierLoc * getArgAsIdent(unsigned Arg) const
Definition ParsedAttr.h:389
bool hasParsedType() const
Definition ParsedAttr.h:337
const ParsedType & getTypeArg() const
Definition ParsedAttr.h:459
unsigned getNumArgs() const
getNumArgs - Return the number of actual arguments to this attribute.
Definition ParsedAttr.h:371
bool isArgIdent(unsigned Arg) const
Definition ParsedAttr.h:385
Expr * getArgAsExpr(unsigned Arg) const
Definition ParsedAttr.h:383
AttributeCommonInfo::Kind getKind() const
Definition ParsedAttr.h:610
A (possibly-)qualified type.
Definition TypeBase.h:937
void addRestrict()
Add the restrict qualifier to this QualType.
Definition TypeBase.h:1172
QualType getNonLValueExprType(const ASTContext &Context) const
Determine the type of a (typically non-lvalue) expression with the specified result type.
Definition Type.cpp:3556
QualType getDesugaredType(const ASTContext &Context) const
Return the specified type with any "sugar" removed from the type.
Definition TypeBase.h:1296
bool isNull() const
Return true if this QualType doesn't point to a type yet.
Definition TypeBase.h:1004
const Type * getTypePtr() const
Retrieves a pointer to the underlying (unqualified) type.
Definition TypeBase.h:8292
LangAS getAddressSpace() const
Return the address space of this type.
Definition TypeBase.h:8418
QualType getNonReferenceType() const
If Type is a reference type (e.g., const int&), returns the type that the reference refers to ("const...
Definition TypeBase.h:8477
QualType getCanonicalType() const
Definition TypeBase.h:8344
QualType getUnqualifiedType() const
Retrieve the unqualified variant of the given type, removing as little sugar as possible.
Definition TypeBase.h:8386
bool hasAddressSpace() const
Check if this type has any address space qualifier.
Definition TypeBase.h:8413
Represents a struct/union/class.
Definition Decl.h:4324
field_iterator field_end() const
Definition Decl.h:4530
field_range fields() const
Definition Decl.h:4527
bool field_empty() const
Definition Decl.h:4535
field_iterator field_begin() const
Definition Decl.cpp:5270
bool hasBindingInfoForDecl(const VarDecl *VD) const
Definition SemaHLSL.cpp:194
DeclBindingInfo * getDeclBindingInfo(const VarDecl *VD, ResourceClass ResClass)
Definition SemaHLSL.cpp:180
DeclBindingInfo * addDeclBindingInfo(const VarDecl *VD, ResourceClass ResClass)
Definition SemaHLSL.cpp:167
Scope - A scope is a transient data structure that is used while parsing the program.
Definition Scope.h:41
SemaBase(Sema &S)
Definition SemaBase.cpp:7
ASTContext & getASTContext() const
Definition SemaBase.cpp:9
Sema & SemaRef
Definition SemaBase.h:40
SemaDiagnosticBuilder Diag(SourceLocation Loc, unsigned DiagID)
Emit a diagnostic.
Definition SemaBase.cpp:61
ExprResult ActOnOutParamExpr(ParmVarDecl *Param, Expr *Arg)
HLSLRootSignatureDecl * lookupRootSignatureOverrideDecl(DeclContext *DC) const
bool CanPerformElementwiseCast(Expr *Src, QualType DestType)
void handleWaveSizeAttr(Decl *D, const ParsedAttr &AL)
void handleVkLocationAttr(Decl *D, const ParsedAttr &AL)
HLSLAttributedResourceLocInfo TakeLocForHLSLAttribute(const HLSLAttributedResourceType *RT)
void handleSemanticAttr(Decl *D, const ParsedAttr &AL)
bool CanPerformScalarCast(QualType SrcTy, QualType DestTy)
QualType ProcessResourceTypeAttributes(QualType Wrapped)
void handleShaderAttr(Decl *D, const ParsedAttr &AL)
void CheckEntryPoint(FunctionDecl *FD)
Definition SemaHLSL.cpp:891
void emitLogicalOperatorFixIt(Expr *LHS, Expr *RHS, BinaryOperatorKind Opc)
T * createSemanticAttr(const AttributeCommonInfo &ACI, std::optional< unsigned > Location)
Definition SemaHLSL.h:179
void ActOnEndOfTranslationUnit(TranslationUnitDecl *TU)
HLSLVkConstantIdAttr * mergeVkConstantIdAttr(Decl *D, const AttributeCommonInfo &AL, int Id)
Definition SemaHLSL.cpp:657
HLSLNumThreadsAttr * mergeNumThreadsAttr(Decl *D, const AttributeCommonInfo &AL, int X, int Y, int Z)
Definition SemaHLSL.cpp:623
void deduceAddressSpace(VarDecl *Decl)
std::pair< IdentifierInfo *, bool > ActOnStartRootSignatureDecl(StringRef Signature)
Computes the unique Root Signature identifier from the given signature, then lookup if there is a pre...
void handlePackOffsetAttr(Decl *D, const ParsedAttr &AL)
bool diagnosePositionType(QualType T, const ParsedAttr &AL)
bool handleInitialization(VarDecl *VDecl, Expr *&Init)
bool diagnoseInputIDType(QualType T, const ParsedAttr &AL)
void handleParamModifierAttr(Decl *D, const ParsedAttr &AL)
bool CheckResourceBinOp(BinaryOperatorKind Opc, Expr *LHSExpr, Expr *RHSExpr, SourceLocation Loc)
bool CanPerformAggregateSplatCast(Expr *Src, QualType DestType)
bool IsScalarizedLayoutCompatible(QualType T1, QualType T2) const
void diagnoseSystemSemanticAttr(Decl *D, const ParsedAttr &AL, std::optional< unsigned > Index)
void handleRootSignatureAttr(Decl *D, const ParsedAttr &AL)
bool CheckCompatibleParameterABI(FunctionDecl *New, FunctionDecl *Old)
QualType handleVectorBinOpConversion(ExprResult &LHS, ExprResult &RHS, QualType LHSType, QualType RHSType, bool IsCompAssign)
void handleResourceBindingAttr(Decl *D, const ParsedAttr &AL)
bool IsTypedResourceElementCompatible(QualType T1)
bool transformInitList(const InitializedEntity &Entity, InitListExpr *Init)
void handleNumThreadsAttr(Decl *D, const ParsedAttr &AL)
bool ActOnUninitializedVarDecl(VarDecl *D)
void handleVkExtBuiltinInputAttr(Decl *D, const ParsedAttr &AL)
void ActOnTopLevelFunction(FunctionDecl *FD)
Definition SemaHLSL.cpp:726
bool handleResourceTypeAttr(QualType T, const ParsedAttr &AL)
void handleVkPushConstantAttr(Decl *D, const ParsedAttr &AL)
HLSLShaderAttr * mergeShaderAttr(Decl *D, const AttributeCommonInfo &AL, llvm::Triple::EnvironmentType ShaderType)
Definition SemaHLSL.cpp:693
void ActOnFinishBuffer(Decl *Dcl, SourceLocation RBrace)
Definition SemaHLSL.cpp:596
void handleVkBindingAttr(Decl *D, const ParsedAttr &AL)
HLSLParamModifierAttr * mergeParamModifierAttr(Decl *D, const AttributeCommonInfo &AL, HLSLParamModifierAttr::Spelling Spelling)
Definition SemaHLSL.cpp:706
QualType getInoutParameterType(QualType Ty)
SemaHLSL(Sema &S)
Definition SemaHLSL.cpp:198
void handleVkConstantIdAttr(Decl *D, const ParsedAttr &AL)
Decl * ActOnStartBuffer(Scope *BufferScope, bool CBuffer, SourceLocation KwLoc, IdentifierInfo *Ident, SourceLocation IdentLoc, SourceLocation LBrace)
Definition SemaHLSL.cpp:200
HLSLWaveSizeAttr * mergeWaveSizeAttr(Decl *D, const AttributeCommonInfo &AL, int Min, int Max, int Preferred, int SpelledArgsCount)
Definition SemaHLSL.cpp:637
bool handleRootSignatureElements(ArrayRef< hlsl::RootSignatureElement > Elements)
void ActOnFinishRootSignatureDecl(SourceLocation Loc, IdentifierInfo *DeclIdent, ArrayRef< hlsl::RootSignatureElement > Elements)
Creates the Root Signature decl of the parsed Root Signature elements onto the AST and push it onto c...
void ActOnVariableDeclarator(VarDecl *VD)
bool CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall)
Sema - This implements semantic analysis and AST building for C.
Definition Sema.h:855
@ LookupOrdinaryName
Ordinary name lookup, which finds ordinary names (functions, variables, typedefs, etc....
Definition Sema.h:9327
@ LookupMemberName
Member name lookup, which finds the names of class/struct/union members.
Definition Sema.h:9335
ASTContext & Context
Definition Sema.h:1283
ASTContext & getASTContext() const
Definition Sema.h:926
ExprResult ImpCastExprToType(Expr *E, QualType Type, CastKind CK, ExprValueKind VK=VK_PRValue, const CXXCastPath *BasePath=nullptr, CheckedConversionKind CCK=CheckedConversionKind::Implicit)
ImpCastExprToType - If Expr is not of type 'Type', insert an implicit cast.
Definition Sema.cpp:756
const LangOptions & getLangOpts() const
Definition Sema.h:919
ExprResult BuildFieldReferenceExpr(Expr *BaseExpr, bool IsArrow, SourceLocation OpLoc, const CXXScopeSpec &SS, FieldDecl *Field, DeclAccessPair FoundDecl, const DeclarationNameInfo &MemberNameInfo)
ExprResult CreateBuiltinArraySubscriptExpr(Expr *Base, SourceLocation LLoc, Expr *Idx, SourceLocation RLoc)
bool LookupQualifiedName(LookupResult &R, DeclContext *LookupCtx, bool InUnqualifiedLookup=false)
Perform qualified name lookup into a given context.
ExprResult PerformCopyInitialization(const InitializedEntity &Entity, SourceLocation EqualLoc, ExprResult Init, bool TopLevelOfInitList=false, bool AllowExplicit=false)
ExprResult CreateBuiltinMatrixSubscriptExpr(Expr *Base, Expr *RowIdx, Expr *ColumnIdx, SourceLocation RBLoc)
Encodes a location in the source.
A trivial tuple used to represent a source range.
SourceLocation getEnd() const
SourceLocation getEndLoc() const LLVM_READONLY
Definition Stmt.cpp:362
void printPretty(raw_ostream &OS, PrinterHelper *Helper, const PrintingPolicy &Policy, unsigned Indentation=0, StringRef NewlineSymbol="\n", const ASTContext *Context=nullptr) const
SourceRange getSourceRange() const LLVM_READONLY
SourceLocation tokens are not useful in isolation - they are low level value objects created/interpre...
Definition Stmt.cpp:338
SourceLocation getBeginLoc() const LLVM_READONLY
Definition Stmt.cpp:350
static StringLiteral * Create(const ASTContext &Ctx, StringRef Str, StringLiteralKind Kind, bool Pascal, QualType Ty, ArrayRef< SourceLocation > Locs)
This is the "fully general" constructor that allows representation of strings formed from one or more...
Definition Expr.cpp:1187
void startDefinition()
Starts the definition of this tag declaration.
Definition Decl.cpp:4906
bool isUnion() const
Definition Decl.h:3925
bool isClass() const
Definition Decl.h:3924
Exposes information about the current target.
Definition TargetInfo.h:226
TargetOptions & getTargetOpts() const
Retrieve the target options.
Definition TargetInfo.h:326
const llvm::Triple & getTriple() const
Returns the target triple of the primary target.
StringRef getPlatformName() const
Retrieve the name of the platform as it is used in the availability attribute.
VersionTuple getPlatformMinVersion() const
Retrieve the minimum desired version of the platform, to which the program should be compiled.
std::string HLSLEntry
The entry point name for HLSL shader being compiled as specified by -E.
The top declaration context.
Definition Decl.h:105
SourceLocation getBeginLoc() const
Get the begin source location.
Definition TypeLoc.cpp:193
A container of type source information.
Definition TypeBase.h:8263
TypeLoc getTypeLoc() const
Return the TypeLoc wrapper for the type source info.
Definition TypeLoc.h:267
The base class of the type hierarchy.
Definition TypeBase.h:1833
bool isVoidType() const
Definition TypeBase.h:8891
bool isBooleanType() const
Definition TypeBase.h:9021
bool isIncompleteArrayType() const
Definition TypeBase.h:8636
CXXRecordDecl * getAsCXXRecordDecl() const
Retrieves the CXXRecordDecl that this type refers to, either because the type is a RecordType or beca...
Definition Type.h:26
bool isConstantArrayType() const
Definition TypeBase.h:8632
bool hasIntegerRepresentation() const
Determine whether this type has an integer representation of some sort, e.g., it is an integer type o...
Definition Type.cpp:2067
bool isArrayType() const
Definition TypeBase.h:8628
CXXRecordDecl * castAsCXXRecordDecl() const
Definition Type.h:36
bool isArithmeticType() const
Definition Type.cpp:2338
bool isConstantMatrixType() const
Definition TypeBase.h:8696
bool isHLSLBuiltinIntangibleType() const
Definition TypeBase.h:8836
CanQualType getCanonicalTypeUnqualified() const
bool isIntegerType() const
isIntegerType() does not include complex integers (a GCC extension).
Definition TypeBase.h:8935
const T * castAs() const
Member-template castAs<specific type>.
Definition TypeBase.h:9178
bool isReferenceType() const
Definition TypeBase.h:8553
bool isHLSLIntangibleType() const
Definition Type.cpp:5376
bool isEnumeralType() const
Definition TypeBase.h:8660
bool isScalarType() const
Definition TypeBase.h:8993
bool isIntegralType(const ASTContext &Ctx) const
Determine whether this type is an integral type.
Definition Type.cpp:2104
const Type * getArrayElementTypeNoTypeQual() const
If this is an array type, return the element type of the array, potentially with type qualifiers miss...
Definition Type.cpp:472
bool hasUnsignedIntegerRepresentation() const
Determine whether this type has an unsigned integer representation of some sort, e....
Definition Type.cpp:2292
bool isAggregateType() const
Determines whether the type is a C++ aggregate type or C aggregate or union type.
Definition Type.cpp:2412
ScalarTypeKind getScalarTypeKind() const
Given that this is a scalar type, classify it.
Definition Type.cpp:2365
bool hasSignedIntegerRepresentation() const
Determine whether this type has an signed integer representation of some sort, e.g....
Definition Type.cpp:2244
bool isHLSLResourceRecord() const
Definition Type.cpp:5363
bool hasFloatingRepresentation() const
Determine whether this type has a floating-point representation of some sort, e.g....
Definition Type.cpp:2313
bool isVectorType() const
Definition TypeBase.h:8668
bool isRealFloatingType() const
Floating point categories.
Definition Type.cpp:2321
bool isHLSLAttributedResourceType() const
Definition TypeBase.h:8848
@ STK_FloatingComplex
Definition TypeBase.h:2765
@ STK_ObjCObjectPointer
Definition TypeBase.h:2759
@ STK_IntegralComplex
Definition TypeBase.h:2764
@ STK_MemberPointer
Definition TypeBase.h:2760
bool isFloatingType() const
Definition Type.cpp:2305
bool isSamplerT() const
Definition TypeBase.h:8769
const T * getAs() const
Member-template getAs<specific type>'.
Definition TypeBase.h:9111
const Type * getUnqualifiedDesugaredType() const
Return the specified type with any "sugar" removed from the type, removing any typedefs,...
Definition Type.cpp:654
bool isRecordType() const
Definition TypeBase.h:8656
bool isHLSLResourceRecordArray() const
Definition Type.cpp:5367
void setType(QualType newType)
Definition Decl.h:724
QualType getType() const
Definition Decl.h:723
Represents a variable declaration or definition.
Definition Decl.h:926
void setInitStyle(InitializationStyle Style)
Definition Decl.h:1452
@ CallInit
Call-style initialization (C++98)
Definition Decl.h:934
void setStorageClass(StorageClass SC)
Definition Decl.cpp:2174
bool hasGlobalStorage() const
Returns true for all variables that do not have local storage.
Definition Decl.h:1226
void setInit(Expr *I)
Definition Decl.cpp:2488
StorageClass getStorageClass() const
Returns the storage class as written in the source.
Definition Decl.h:1168
Represents a GCC generic vector type.
Definition TypeBase.h:4176
unsigned getNumElements() const
Definition TypeBase.h:4191
QualType getElementType() const
Definition TypeBase.h:4190
Defines the clang::TargetInfo interface.
The JSON file list parser is used to communicate input to InstallAPI.
bool isa(CodeGen::Address addr)
Definition Address.h:330
if(T->getSizeExpr()) TRY_TO(TraverseStmt(const_cast< Expr * >(T -> getSizeExpr())))
static bool CheckFloatOrHalfRepresentation(Sema *S, SourceLocation Loc, int ArgOrdinal, clang::QualType PassedType)
Definition SemaSPIRV.cpp:66
@ ICIS_NoInit
No in-class initializer.
Definition Specifiers.h:272
@ OK_Ordinary
An ordinary object is located at an address in memory.
Definition Specifiers.h:151
static bool CheckAllArgTypesAreCorrect(Sema *S, CallExpr *TheCall, llvm::ArrayRef< llvm::function_ref< bool(Sema *, SourceLocation, int, QualType)> > Checks)
Definition SemaSPIRV.cpp:49
@ AS_public
Definition Specifiers.h:124
@ AS_none
Definition Specifiers.h:127
@ SC_Static
Definition Specifiers.h:252
@ AANT_ArgumentIdentifier
@ Result
The result type of a method or function.
Definition TypeBase.h:905
@ Ordinary
This parameter uses ordinary ABI rules for its type.
Definition Specifiers.h:380
const FunctionProtoType * T
llvm::Expected< QualType > ExpectedType
static bool CheckAllArgsHaveSameType(Sema *S, CallExpr *TheCall)
Definition SemaSPIRV.cpp:32
ExprResult ExprError()
Definition Ownership.h:265
@ Type
The name was classified as a type.
Definition Sema.h:563
LangAS
Defines the address space values used by the address space qualifier of QualType.
bool CreateHLSLAttributedResourceType(Sema &S, QualType Wrapped, ArrayRef< const Attr * > AttrList, QualType &ResType, HLSLAttributedResourceLocInfo *LocInfo=nullptr)
CastKind
CastKind - The kind of operation required for a conversion.
@ VK_PRValue
A pr-value expression (in the C++11 taxonomy) produces a temporary value.
Definition Specifiers.h:135
@ VK_LValue
An l-value expression is a reference to an object with independent storage.
Definition Specifiers.h:139
DynamicRecursiveASTVisitorBase< false > DynamicRecursiveASTVisitor
U cast(CodeGen::Address addr)
Definition Address.h:327
ActionResult< Expr * > ExprResult
Definition Ownership.h:249
Visibility
Describes the different kinds of visibility that a declaration may have.
Definition Visibility.h:34
unsigned long uint64_t
unsigned int uint32_t
hash_code hash_value(const clang::dependencies::ModuleID &ID)
__DEVICE__ bool isnan(float __x)
__DEVICE__ _Tp abs(const std::complex< _Tp > &__c)
#define false
Definition stdbool.h:26
Describes how types, statements, expressions, and declarations should be printed.
void setCounterImplicitOrderID(unsigned Value) const
void setImplicitOrderID(unsigned Value) const
const SourceLocation & getLocation() const
Definition SemaHLSL.h:48
const llvm::hlsl::rootsig::RootElement & getElement() const
Definition SemaHLSL.h:47