clang 19.0.0git
CGHLSLRuntime.cpp
Go to the documentation of this file.
1//===----- CGHLSLRuntime.cpp - Interface to HLSL Runtimes -----------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This provides an abstract class for HLSL code generation. Concrete
10// subclasses of this implement code generation for specific HLSL
11// runtime libraries.
12//
13//===----------------------------------------------------------------------===//
14
15#include "CGHLSLRuntime.h"
16#include "CGDebugInfo.h"
17#include "CodeGenModule.h"
18#include "clang/AST/Decl.h"
20#include "llvm/IR/IntrinsicsDirectX.h"
21#include "llvm/IR/IntrinsicsSPIRV.h"
22#include "llvm/IR/Metadata.h"
23#include "llvm/IR/Module.h"
24#include "llvm/Support/FormatVariadic.h"
25
26using namespace clang;
27using namespace CodeGen;
28using namespace clang::hlsl;
29using namespace llvm;
30
31namespace {
32
33void addDxilValVersion(StringRef ValVersionStr, llvm::Module &M) {
34 // The validation of ValVersionStr is done at HLSLToolChain::TranslateArgs.
35 // Assume ValVersionStr is legal here.
36 VersionTuple Version;
37 if (Version.tryParse(ValVersionStr) || Version.getBuild() ||
38 Version.getSubminor() || !Version.getMinor()) {
39 return;
40 }
41
42 uint64_t Major = Version.getMajor();
43 uint64_t Minor = *Version.getMinor();
44
45 auto &Ctx = M.getContext();
46 IRBuilder<> B(M.getContext());
47 MDNode *Val = MDNode::get(Ctx, {ConstantAsMetadata::get(B.getInt32(Major)),
48 ConstantAsMetadata::get(B.getInt32(Minor))});
49 StringRef DXILValKey = "dx.valver";
50 auto *DXILValMD = M.getOrInsertNamedMetadata(DXILValKey);
51 DXILValMD->addOperand(Val);
52}
53void addDisableOptimizations(llvm::Module &M) {
54 StringRef Key = "dx.disable_optimizations";
55 M.addModuleFlag(llvm::Module::ModFlagBehavior::Override, Key, 1);
56}
// A cbuffer is translated into a global variable in a special address space.
// Expressed in C, the declaration
//   cbuffer A {
//     float a;
//     float b;
//   }
//   float foo() { return a + b; }
//
// is translated into
//
//   struct A {
//     float a;
//     float b;
//   } cbuffer_A __attribute__((address_space(4)));
//   float foo() { return cbuffer_A.a + cbuffer_A.b; }
//
// layoutBuffer creates the struct A type.
// replaceBuffer replaces uses of the global variables a and b with
// cbuffer_A.a and cbuffer_A.b.
//
77void layoutBuffer(CGHLSLRuntime::Buffer &Buf, const DataLayout &DL) {
78 if (Buf.Constants.empty())
79 return;
80
81 std::vector<llvm::Type *> EltTys;
82 for (auto &Const : Buf.Constants) {
83 GlobalVariable *GV = Const.first;
84 Const.second = EltTys.size();
85 llvm::Type *Ty = GV->getValueType();
86 EltTys.emplace_back(Ty);
87 }
88 Buf.LayoutStruct = llvm::StructType::get(EltTys[0]->getContext(), EltTys);
89}
90
91GlobalVariable *replaceBuffer(CGHLSLRuntime::Buffer &Buf) {
92 // Create global variable for CB.
93 GlobalVariable *CBGV = new GlobalVariable(
94 Buf.LayoutStruct, /*isConstant*/ true,
95 GlobalValue::LinkageTypes::ExternalLinkage, nullptr,
96 llvm::formatv("{0}{1}", Buf.Name, Buf.IsCBuffer ? ".cb." : ".tb."),
97 GlobalValue::NotThreadLocal);
98
99 IRBuilder<> B(CBGV->getContext());
100 Value *ZeroIdx = B.getInt32(0);
101 // Replace Const use with CB use.
102 for (auto &[GV, Offset] : Buf.Constants) {
103 Value *GEP =
104 B.CreateGEP(Buf.LayoutStruct, CBGV, {ZeroIdx, B.getInt32(Offset)});
105
106 assert(Buf.LayoutStruct->getElementType(Offset) == GV->getValueType() &&
107 "constant type mismatch");
108
109 // Replace.
110 GV->replaceAllUsesWith(GEP);
111 // Erase GV.
112 GV->removeDeadConstantUsers();
113 GV->eraseFromParent();
114 }
115 return CBGV;
116}
117
118} // namespace
119
120void CGHLSLRuntime::addConstant(VarDecl *D, Buffer &CB) {
121 if (D->getStorageClass() == SC_Static) {
122 // For static inside cbuffer, take as global static.
123 // Don't add to cbuffer.
124 CGM.EmitGlobal(D);
125 return;
126 }
127
128 auto *GV = cast<GlobalVariable>(CGM.GetAddrOfGlobalVar(D));
129 // Add debug info for constVal.
131 if (CGM.getCodeGenOpts().getDebugInfo() >=
132 codegenoptions::DebugInfoKind::LimitedDebugInfo)
133 DI->EmitGlobalVariable(cast<GlobalVariable>(GV), D);
134
135 // FIXME: support packoffset.
136 // See https://github.com/llvm/llvm-project/issues/57914.
137 uint32_t Offset = 0;
138 bool HasUserOffset = false;
139
140 unsigned LowerBound = HasUserOffset ? Offset : UINT_MAX;
141 CB.Constants.emplace_back(std::make_pair(GV, LowerBound));
142}
143
144void CGHLSLRuntime::addBufferDecls(const DeclContext *DC, Buffer &CB) {
145 for (Decl *it : DC->decls()) {
146 if (auto *ConstDecl = dyn_cast<VarDecl>(it)) {
147 addConstant(ConstDecl, CB);
148 } else if (isa<CXXRecordDecl, EmptyDecl>(it)) {
149 // Nothing to do for this declaration.
150 } else if (isa<FunctionDecl>(it)) {
151 // A function within an cbuffer is effectively a top-level function,
152 // as it only refers to globally scoped declarations.
154 }
155 }
156}
157
159 Buffers.emplace_back(Buffer(D));
160 addBufferDecls(D, Buffers.back());
161}
162
164 auto &TargetOpts = CGM.getTarget().getTargetOpts();
165 llvm::Module &M = CGM.getModule();
166 Triple T(M.getTargetTriple());
167 if (T.getArch() == Triple::ArchType::dxil)
168 addDxilValVersion(TargetOpts.DxilValidatorVersion, M);
169
171 if (CGM.getCodeGenOpts().OptimizationLevel == 0)
172 addDisableOptimizations(M);
173
174 const DataLayout &DL = M.getDataLayout();
175
176 for (auto &Buf : Buffers) {
177 layoutBuffer(Buf, DL);
178 GlobalVariable *GV = replaceBuffer(Buf);
179 M.insertGlobalVariable(GV);
180 llvm::hlsl::ResourceClass RC = Buf.IsCBuffer
181 ? llvm::hlsl::ResourceClass::CBuffer
182 : llvm::hlsl::ResourceClass::SRV;
183 llvm::hlsl::ResourceKind RK = Buf.IsCBuffer
184 ? llvm::hlsl::ResourceKind::CBuffer
185 : llvm::hlsl::ResourceKind::TBuffer;
186 addBufferResourceAnnotation(GV, RC, RK, /*IsROV=*/false,
187 llvm::hlsl::ElementType::Invalid, Buf.Binding);
188 }
189}
190
192 : Name(D->getName()), IsCBuffer(D->isCBuffer()),
193 Binding(D->getAttr<HLSLResourceBindingAttr>()) {}
194
195void CGHLSLRuntime::addBufferResourceAnnotation(llvm::GlobalVariable *GV,
196 llvm::hlsl::ResourceClass RC,
197 llvm::hlsl::ResourceKind RK,
198 bool IsROV,
199 llvm::hlsl::ElementType ET,
200 BufferResBinding &Binding) {
201 llvm::Module &M = CGM.getModule();
202
203 NamedMDNode *ResourceMD = nullptr;
204 switch (RC) {
205 case llvm::hlsl::ResourceClass::UAV:
206 ResourceMD = M.getOrInsertNamedMetadata("hlsl.uavs");
207 break;
208 case llvm::hlsl::ResourceClass::SRV:
209 ResourceMD = M.getOrInsertNamedMetadata("hlsl.srvs");
210 break;
211 case llvm::hlsl::ResourceClass::CBuffer:
212 ResourceMD = M.getOrInsertNamedMetadata("hlsl.cbufs");
213 break;
214 default:
215 assert(false && "Unsupported buffer type!");
216 return;
217 }
218 assert(ResourceMD != nullptr &&
219 "ResourceMD must have been set by the switch above.");
220
221 llvm::hlsl::FrontendResource Res(
222 GV, RK, ET, IsROV, Binding.Reg.value_or(UINT_MAX), Binding.Space);
223 ResourceMD->addOperand(Res.getMetadata());
224}
225
226static llvm::hlsl::ElementType
227calculateElementType(const ASTContext &Context, const clang::Type *ResourceTy) {
228 using llvm::hlsl::ElementType;
229
230 // TODO: We may need to update this when we add things like ByteAddressBuffer
231 // that don't have a template parameter (or, indeed, an element type).
232 const auto *TST = ResourceTy->getAs<TemplateSpecializationType>();
233 assert(TST && "Resource types must be template specializations");
234 ArrayRef<TemplateArgument> Args = TST->template_arguments();
235 assert(!Args.empty() && "Resource has no element type");
236
237 // At this point we have a resource with an element type, so we can assume
238 // that it's valid or we would have diagnosed the error earlier.
239 QualType ElTy = Args[0].getAsType();
240
241 // We should either have a basic type or a vector of a basic type.
242 if (const auto *VecTy = ElTy->getAs<clang::VectorType>())
243 ElTy = VecTy->getElementType();
244
245 if (ElTy->isSignedIntegerType()) {
246 switch (Context.getTypeSize(ElTy)) {
247 case 16:
248 return ElementType::I16;
249 case 32:
250 return ElementType::I32;
251 case 64:
252 return ElementType::I64;
253 }
254 } else if (ElTy->isUnsignedIntegerType()) {
255 switch (Context.getTypeSize(ElTy)) {
256 case 16:
257 return ElementType::U16;
258 case 32:
259 return ElementType::U32;
260 case 64:
261 return ElementType::U64;
262 }
263 } else if (ElTy->isSpecificBuiltinType(BuiltinType::Half))
264 return ElementType::F16;
265 else if (ElTy->isSpecificBuiltinType(BuiltinType::Float))
266 return ElementType::F32;
267 else if (ElTy->isSpecificBuiltinType(BuiltinType::Double))
268 return ElementType::F64;
269
270 // TODO: We need to handle unorm/snorm float types here once we support them
271 llvm_unreachable("Invalid element type for resource");
272}
273
274void CGHLSLRuntime::annotateHLSLResource(const VarDecl *D, GlobalVariable *GV) {
275 const Type *Ty = D->getType()->getPointeeOrArrayElementType();
276 if (!Ty)
277 return;
278 const auto *RD = Ty->getAsCXXRecordDecl();
279 if (!RD)
280 return;
281 const auto *Attr = RD->getAttr<HLSLResourceAttr>();
282 if (!Attr)
283 return;
284
285 llvm::hlsl::ResourceClass RC = Attr->getResourceClass();
286 llvm::hlsl::ResourceKind RK = Attr->getResourceKind();
287 bool IsROV = Attr->getIsROV();
288 llvm::hlsl::ElementType ET = calculateElementType(CGM.getContext(), Ty);
289
290 BufferResBinding Binding(D->getAttr<HLSLResourceBindingAttr>());
291 addBufferResourceAnnotation(GV, RC, RK, IsROV, ET, Binding);
292}
293
295 HLSLResourceBindingAttr *Binding) {
296 if (Binding) {
297 llvm::APInt RegInt(64, 0);
298 Binding->getSlot().substr(1).getAsInteger(10, RegInt);
299 Reg = RegInt.getLimitedValue();
300 llvm::APInt SpaceInt(64, 0);
301 Binding->getSpace().substr(5).getAsInteger(10, SpaceInt);
302 Space = SpaceInt.getLimitedValue();
303 } else {
304 Space = 0;
305 }
306}
307
309 const FunctionDecl *FD, llvm::Function *Fn) {
310 const auto *ShaderAttr = FD->getAttr<HLSLShaderAttr>();
311 assert(ShaderAttr && "All entry functions must have a HLSLShaderAttr");
312 const StringRef ShaderAttrKindStr = "hlsl.shader";
313 Fn->addFnAttr(ShaderAttrKindStr,
314 ShaderAttr->ConvertShaderTypeToStr(ShaderAttr->getType()));
315 if (HLSLNumThreadsAttr *NumThreadsAttr = FD->getAttr<HLSLNumThreadsAttr>()) {
316 const StringRef NumThreadsKindStr = "hlsl.numthreads";
317 std::string NumThreadsStr =
318 formatv("{0},{1},{2}", NumThreadsAttr->getX(), NumThreadsAttr->getY(),
319 NumThreadsAttr->getZ());
320 Fn->addFnAttr(NumThreadsKindStr, NumThreadsStr);
321 }
322}
323
324static Value *buildVectorInput(IRBuilder<> &B, Function *F, llvm::Type *Ty) {
325 if (const auto *VT = dyn_cast<FixedVectorType>(Ty)) {
326 Value *Result = PoisonValue::get(Ty);
327 for (unsigned I = 0; I < VT->getNumElements(); ++I) {
328 Value *Elt = B.CreateCall(F, {B.getInt32(I)});
329 Result = B.CreateInsertElement(Result, Elt, I);
330 }
331 return Result;
332 }
333 return B.CreateCall(F, {B.getInt32(0)});
334}
335
336llvm::Value *CGHLSLRuntime::emitInputSemantic(IRBuilder<> &B,
337 const ParmVarDecl &D,
338 llvm::Type *Ty) {
339 assert(D.hasAttrs() && "Entry parameter missing annotation attribute!");
340 if (D.hasAttr<HLSLSV_GroupIndexAttr>()) {
341 llvm::Function *DxGroupIndex =
342 CGM.getIntrinsic(Intrinsic::dx_flattened_thread_id_in_group);
343 return B.CreateCall(FunctionCallee(DxGroupIndex));
344 }
345 if (D.hasAttr<HLSLSV_DispatchThreadIDAttr>()) {
346 llvm::Function *ThreadIDIntrinsic;
347 switch (CGM.getTarget().getTriple().getArch()) {
348 case llvm::Triple::dxil:
349 ThreadIDIntrinsic = CGM.getIntrinsic(Intrinsic::dx_thread_id);
350 break;
351 case llvm::Triple::spirv:
352 ThreadIDIntrinsic = CGM.getIntrinsic(Intrinsic::spv_thread_id);
353 break;
354 default:
355 llvm_unreachable("Input semantic not supported by target");
356 break;
357 }
358 return buildVectorInput(B, ThreadIDIntrinsic, Ty);
359 }
360 assert(false && "Unhandled parameter attribute");
361 return nullptr;
362}
363
365 llvm::Function *Fn) {
366 llvm::Module &M = CGM.getModule();
367 llvm::LLVMContext &Ctx = M.getContext();
368 auto *EntryTy = llvm::FunctionType::get(llvm::Type::getVoidTy(Ctx), false);
369 Function *EntryFn =
370 Function::Create(EntryTy, Function::ExternalLinkage, FD->getName(), &M);
371
372 // Copy function attributes over, we have no argument or return attributes
373 // that can be valid on the real entry.
374 AttributeList NewAttrs = AttributeList::get(Ctx, AttributeList::FunctionIndex,
375 Fn->getAttributes().getFnAttrs());
376 EntryFn->setAttributes(NewAttrs);
377 setHLSLEntryAttributes(FD, EntryFn);
378
379 // Set the called function as internal linkage.
380 Fn->setLinkage(GlobalValue::InternalLinkage);
381
382 BasicBlock *BB = BasicBlock::Create(Ctx, "entry", EntryFn);
383 IRBuilder<> B(BB);
385 // FIXME: support struct parameters where semantics are on members.
386 // See: https://github.com/llvm/llvm-project/issues/57874
387 unsigned SRetOffset = 0;
388 for (const auto &Param : Fn->args()) {
389 if (Param.hasStructRetAttr()) {
390 // FIXME: support output.
391 // See: https://github.com/llvm/llvm-project/issues/57874
392 SRetOffset = 1;
393 Args.emplace_back(PoisonValue::get(Param.getType()));
394 continue;
395 }
396 const ParmVarDecl *PD = FD->getParamDecl(Param.getArgNo() - SRetOffset);
397 Args.push_back(emitInputSemantic(B, *PD, Param.getType()));
398 }
399
400 CallInst *CI = B.CreateCall(FunctionCallee(Fn), Args);
401 (void)CI;
402 // FIXME: Handle codegen for return type semantics.
403 // See: https://github.com/llvm/llvm-project/issues/57875
404 B.CreateRetVoid();
405}
406
407static void gatherFunctions(SmallVectorImpl<Function *> &Fns, llvm::Module &M,
408 bool CtorOrDtor) {
409 const auto *GV =
410 M.getNamedGlobal(CtorOrDtor ? "llvm.global_ctors" : "llvm.global_dtors");
411 if (!GV)
412 return;
413 const auto *CA = dyn_cast<ConstantArray>(GV->getInitializer());
414 if (!CA)
415 return;
416 // The global_ctor array elements are a struct [Priority, Fn *, COMDat].
417 // HLSL neither supports priorities or COMDat values, so we will check those
418 // in an assert but not handle them.
419
421 for (const auto &Ctor : CA->operands()) {
422 if (isa<ConstantAggregateZero>(Ctor))
423 continue;
424 ConstantStruct *CS = cast<ConstantStruct>(Ctor);
425
426 assert(cast<ConstantInt>(CS->getOperand(0))->getValue() == 65535 &&
427 "HLSL doesn't support setting priority for global ctors.");
428 assert(isa<ConstantPointerNull>(CS->getOperand(2)) &&
429 "HLSL doesn't support COMDat for global ctors.");
430 Fns.push_back(cast<Function>(CS->getOperand(1)));
431 }
432}
433
435 llvm::Module &M = CGM.getModule();
438 gatherFunctions(CtorFns, M, true);
439 gatherFunctions(DtorFns, M, false);
440
441 // Insert a call to the global constructor at the beginning of the entry block
442 // to externally exported functions. This is a bit of a hack, but HLSL allows
443 // global constructors, but doesn't support driver initialization of globals.
444 for (auto &F : M.functions()) {
445 if (!F.hasFnAttribute("hlsl.shader"))
446 continue;
447 IRBuilder<> B(&F.getEntryBlock(), F.getEntryBlock().begin());
448 for (auto *Fn : CtorFns)
449 B.CreateCall(FunctionCallee(Fn));
450
451 // Insert global dtors before the terminator of the last instruction
452 B.SetInsertPoint(F.back().getTerminator());
453 for (auto *Fn : DtorFns)
454 B.CreateCall(FunctionCallee(Fn));
455 }
456
457 // No need to keep global ctors/dtors for non-lib profile after call to
458 // ctors/dtors added for entry.
459 Triple T(M.getTargetTriple());
460 if (T.getEnvironment() != Triple::EnvironmentType::Library) {
461 if (auto *GV = M.getNamedGlobal("llvm.global_ctors"))
462 GV->eraseFromParent();
463 if (auto *GV = M.getNamedGlobal("llvm.global_dtors"))
464 GV->eraseFromParent();
465 }
466}
static llvm::hlsl::ElementType calculateElementType(const ASTContext &Context, const clang::Type *ResourceTy)
static void gatherFunctions(SmallVectorImpl< Function * > &Fns, llvm::Module &M, bool CtorOrDtor)
static Value * buildVectorInput(IRBuilder<> &B, Function *F, llvm::Type *Ty)
static std::string getName(const CallEvent &Call)
Defines the clang::TargetOptions class.
Holds long-lived AST nodes (such as types and decls) that can be referred to throughout the semantic ...
Definition: ASTContext.h:182
uint64_t getTypeSize(QualType T) const
Return the size of the specified (complete) type T, in bits.
Definition: ASTContext.h:2322
Attr - This represents one attribute.
Definition: Attr.h:42
This class gathers all debug information during compilation and is responsible for emitting to llvm g...
Definition: CGDebugInfo.h:55
void setHLSLEntryAttributes(const FunctionDecl *FD, llvm::Function *Fn)
void emitEntryFunction(const FunctionDecl *FD, llvm::Function *Fn)
llvm::Value * emitInputSemantic(llvm::IRBuilder<> &B, const ParmVarDecl &D, llvm::Type *Ty)
void annotateHLSLResource(const VarDecl *D, llvm::GlobalVariable *GV)
void addBuffer(const HLSLBufferDecl *D)
llvm::Module & getModule() const
CGDebugInfo * getModuleDebugInfo()
const TargetInfo & getTarget() const
void EmitGlobal(GlobalDecl D)
Emit code for a single global function or var decl.
ASTContext & getContext() const
llvm::Constant * GetAddrOfGlobalVar(const VarDecl *D, llvm::Type *Ty=nullptr, ForDefinition_t IsForDefinition=NotForDefinition)
Return the llvm::Constant for the address of the given global variable.
const CodeGenOptions & getCodeGenOpts() const
llvm::Function * getIntrinsic(unsigned IID, ArrayRef< llvm::Type * > Tys=std::nullopt)
void EmitTopLevelDecl(Decl *D)
Emit code for a single top level declaration.
DeclContext - This is used only as base class of specific decl types that can act as declaration cont...
Definition: DeclBase.h:1446
decl_range decls() const
decls_begin/decls_end - Iterate over the declarations stored in this context.
Definition: DeclBase.h:2332
Decl - This represents one declaration (or definition), e.g.
Definition: DeclBase.h:85
T * getAttr() const
Definition: DeclBase.h:578
bool hasAttrs() const
Definition: DeclBase.h:523
bool hasAttr() const
Definition: DeclBase.h:582
Represents a function declaration or definition.
Definition: Decl.h:1959
const ParmVarDecl * getParamDecl(unsigned i) const
Definition: Decl.h:2674
HLSLBufferDecl - Represent a cbuffer or tbuffer declaration.
Definition: Decl.h:4905
StringRef getName() const
Get the name of identifier for this declaration as a StringRef.
Definition: Decl.h:276
Represents a parameter to a function.
Definition: Decl.h:1749
A (possibly-)qualified type.
Definition: Type.h:738
TargetOptions & getTargetOpts() const
Retrieve the target options.
Definition: TargetInfo.h:304
const llvm::Triple & getTriple() const
Returns the target triple of the primary target.
Definition: TargetInfo.h:1220
Represents a type template specialization; the template must be a class template, a type alias templa...
Definition: Type.h:5849
The base class of the type hierarchy.
Definition: Type.h:1607
CXXRecordDecl * getAsCXXRecordDecl() const
Retrieves the CXXRecordDecl that this type refers to, either because the type is a RecordType or beca...
Definition: Type.cpp:1862
const Type * getPointeeOrArrayElementType() const
If this is a pointer type, return the pointee type.
Definition: Type.h:7835
bool isSignedIntegerType() const
Return true if this is an integer type that is signed, according to C99 6.2.5p4 [char,...
Definition: Type.cpp:2126
bool isSpecificBuiltinType(unsigned K) const
Test for a particular builtin type.
Definition: Type.h:7629
bool isUnsignedIntegerType() const
Return true if this is an integer type that is unsigned, according to C99 6.2.5p6 [which returns true...
Definition: Type.cpp:2176
const T * getAs() const
Member-template getAs<specific type>'.
Definition: Type.h:7878
QualType getType() const
Definition: Decl.h:717
Represents a variable declaration or definition.
Definition: Decl.h:918
StorageClass getStorageClass() const
Returns the storage class as written in the source.
Definition: Decl.h:1152
Represents a GCC generic vector type.
Definition: Type.h:3729
#define UINT_MAX
Definition: limits.h:60
bool Const(InterpState &S, CodePtr OpPC, const T &Arg)
Definition: Interp.h:928
The JSON file list parser is used to communicate input to InstallAPI.
@ SC_Static
Definition: Specifiers.h:249
@ Result
The result type of a method or function.
unsigned long uint64_t
YAML serialization mapping.
Definition: Dominators.h:30
BufferResBinding(HLSLResourceBindingAttr *Attr)
std::vector< std::pair< llvm::GlobalVariable *, unsigned > > Constants
Definition: CGHLSLRuntime.h:66