clang 18.0.0git
CGHLSLRuntime.cpp
Go to the documentation of this file.
1//===----- CGHLSLRuntime.cpp - Interface to HLSL Runtimes -----------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This provides an abstract class for HLSL code generation. Concrete
10// subclasses of this implement code generation for specific HLSL
11// runtime libraries.
12//
13//===----------------------------------------------------------------------===//
14
15#include "CGHLSLRuntime.h"
16#include "CGDebugInfo.h"
17#include "CodeGenModule.h"
18#include "clang/AST/Decl.h"
20#include "llvm/IR/IntrinsicsDirectX.h"
21#include "llvm/IR/Metadata.h"
22#include "llvm/IR/Module.h"
23#include "llvm/Support/FormatVariadic.h"
24
25using namespace clang;
26using namespace CodeGen;
27using namespace clang::hlsl;
28using namespace llvm;
29
30namespace {
31
32void addDxilValVersion(StringRef ValVersionStr, llvm::Module &M) {
33 // The validation of ValVersionStr is done at HLSLToolChain::TranslateArgs.
34 // Assume ValVersionStr is legal here.
35 VersionTuple Version;
36 if (Version.tryParse(ValVersionStr) || Version.getBuild() ||
37 Version.getSubminor() || !Version.getMinor()) {
38 return;
39 }
40
41 uint64_t Major = Version.getMajor();
42 uint64_t Minor = *Version.getMinor();
43
44 auto &Ctx = M.getContext();
45 IRBuilder<> B(M.getContext());
46 MDNode *Val = MDNode::get(Ctx, {ConstantAsMetadata::get(B.getInt32(Major)),
47 ConstantAsMetadata::get(B.getInt32(Minor))});
48 StringRef DXILValKey = "dx.valver";
49 auto *DXILValMD = M.getOrInsertNamedMetadata(DXILValKey);
50 DXILValMD->addOperand(Val);
51}
52void addDisableOptimizations(llvm::Module &M) {
53 StringRef Key = "dx.disable_optimizations";
54 M.addModuleFlag(llvm::Module::ModFlagBehavior::Override, Key, 1);
55}
56// cbuffer will be translated into global variable in special address space.
57// If translate into C,
58// cbuffer A {
59// float a;
60// float b;
61// }
62// float foo() { return a + b; }
63//
64// will be translated into
65//
66// struct A {
67// float a;
68// float b;
69// } cbuffer_A __attribute__((address_space(4)));
70// float foo() { return cbuffer_A.a + cbuffer_A.b; }
71//
72// layoutBuffer will create the struct A type.
73// replaceBuffer will replace use of global variable a and b with cbuffer_A.a
74// and cbuffer_A.b.
75//
76void layoutBuffer(CGHLSLRuntime::Buffer &Buf, const DataLayout &DL) {
77 if (Buf.Constants.empty())
78 return;
79
80 std::vector<llvm::Type *> EltTys;
81 for (auto &Const : Buf.Constants) {
82 GlobalVariable *GV = Const.first;
83 Const.second = EltTys.size();
84 llvm::Type *Ty = GV->getValueType();
85 EltTys.emplace_back(Ty);
86 }
87 Buf.LayoutStruct = llvm::StructType::get(EltTys[0]->getContext(), EltTys);
88}
89
90GlobalVariable *replaceBuffer(CGHLSLRuntime::Buffer &Buf) {
91 // Create global variable for CB.
92 GlobalVariable *CBGV = new GlobalVariable(
93 Buf.LayoutStruct, /*isConstant*/ true,
94 GlobalValue::LinkageTypes::ExternalLinkage, nullptr,
95 llvm::formatv("{0}{1}", Buf.Name, Buf.IsCBuffer ? ".cb." : ".tb."),
96 GlobalValue::NotThreadLocal);
97
98 IRBuilder<> B(CBGV->getContext());
99 Value *ZeroIdx = B.getInt32(0);
100 // Replace Const use with CB use.
101 for (auto &[GV, Offset] : Buf.Constants) {
102 Value *GEP =
103 B.CreateGEP(Buf.LayoutStruct, CBGV, {ZeroIdx, B.getInt32(Offset)});
104
105 assert(Buf.LayoutStruct->getElementType(Offset) == GV->getValueType() &&
106 "constant type mismatch");
107
108 // Replace.
109 GV->replaceAllUsesWith(GEP);
110 // Erase GV.
111 GV->removeDeadConstantUsers();
112 GV->eraseFromParent();
113 }
114 return CBGV;
115}
116
117} // namespace
118
119void CGHLSLRuntime::addConstant(VarDecl *D, Buffer &CB) {
120 if (D->getStorageClass() == SC_Static) {
121 // For static inside cbuffer, take as global static.
122 // Don't add to cbuffer.
123 CGM.EmitGlobal(D);
124 return;
125 }
126
127 auto *GV = cast<GlobalVariable>(CGM.GetAddrOfGlobalVar(D));
128 // Add debug info for constVal.
130 if (CGM.getCodeGenOpts().getDebugInfo() >=
131 codegenoptions::DebugInfoKind::LimitedDebugInfo)
132 DI->EmitGlobalVariable(cast<GlobalVariable>(GV), D);
133
134 // FIXME: support packoffset.
135 // See https://github.com/llvm/llvm-project/issues/57914.
136 uint32_t Offset = 0;
137 bool HasUserOffset = false;
138
139 unsigned LowerBound = HasUserOffset ? Offset : UINT_MAX;
140 CB.Constants.emplace_back(std::make_pair(GV, LowerBound));
141}
142
143void CGHLSLRuntime::addBufferDecls(const DeclContext *DC, Buffer &CB) {
144 for (Decl *it : DC->decls()) {
145 if (auto *ConstDecl = dyn_cast<VarDecl>(it)) {
146 addConstant(ConstDecl, CB);
147 } else if (isa<CXXRecordDecl, EmptyDecl>(it)) {
148 // Nothing to do for this declaration.
149 } else if (isa<FunctionDecl>(it)) {
150 // A function within an cbuffer is effectively a top-level function,
151 // as it only refers to globally scoped declarations.
153 }
154 }
155}
156
158 Buffers.emplace_back(Buffer(D));
159 addBufferDecls(D, Buffers.back());
160}
161
163 auto &TargetOpts = CGM.getTarget().getTargetOpts();
164 llvm::Module &M = CGM.getModule();
165 Triple T(M.getTargetTriple());
166 if (T.getArch() == Triple::ArchType::dxil)
167 addDxilValVersion(TargetOpts.DxilValidatorVersion, M);
168
170 if (CGM.getCodeGenOpts().OptimizationLevel == 0)
171 addDisableOptimizations(M);
172
173 const DataLayout &DL = M.getDataLayout();
174
175 for (auto &Buf : Buffers) {
176 layoutBuffer(Buf, DL);
177 GlobalVariable *GV = replaceBuffer(Buf);
178 M.insertGlobalVariable(GV);
179 llvm::hlsl::ResourceClass RC = Buf.IsCBuffer
180 ? llvm::hlsl::ResourceClass::CBuffer
181 : llvm::hlsl::ResourceClass::SRV;
182 llvm::hlsl::ResourceKind RK = Buf.IsCBuffer
183 ? llvm::hlsl::ResourceKind::CBuffer
184 : llvm::hlsl::ResourceKind::TBuffer;
185 std::string TyName =
186 Buf.Name.str() + (Buf.IsCBuffer ? ".cb." : ".tb.") + "ty";
187 addBufferResourceAnnotation(GV, TyName, RC, RK, Buf.Binding);
188 }
189}
190
192 : Name(D->getName()), IsCBuffer(D->isCBuffer()),
193 Binding(D->getAttr<HLSLResourceBindingAttr>()) {}
194
195void CGHLSLRuntime::addBufferResourceAnnotation(llvm::GlobalVariable *GV,
196 llvm::StringRef TyName,
197 llvm::hlsl::ResourceClass RC,
198 llvm::hlsl::ResourceKind RK,
199 BufferResBinding &Binding) {
200 llvm::Module &M = CGM.getModule();
201
202 NamedMDNode *ResourceMD = nullptr;
203 switch (RC) {
204 case llvm::hlsl::ResourceClass::UAV:
205 ResourceMD = M.getOrInsertNamedMetadata("hlsl.uavs");
206 break;
207 case llvm::hlsl::ResourceClass::SRV:
208 ResourceMD = M.getOrInsertNamedMetadata("hlsl.srvs");
209 break;
210 case llvm::hlsl::ResourceClass::CBuffer:
211 ResourceMD = M.getOrInsertNamedMetadata("hlsl.cbufs");
212 break;
213 default:
214 assert(false && "Unsupported buffer type!");
215 return;
216 }
217
218 assert(ResourceMD != nullptr &&
219 "ResourceMD must have been set by the switch above.");
220
221 llvm::hlsl::FrontendResource Res(
222 GV, TyName, RK, Binding.Reg.value_or(UINT_MAX), Binding.Space);
223 ResourceMD->addOperand(Res.getMetadata());
224}
225
226void CGHLSLRuntime::annotateHLSLResource(const VarDecl *D, GlobalVariable *GV) {
227 const Type *Ty = D->getType()->getPointeeOrArrayElementType();
228 if (!Ty)
229 return;
230 const auto *RD = Ty->getAsCXXRecordDecl();
231 if (!RD)
232 return;
233 const auto *Attr = RD->getAttr<HLSLResourceAttr>();
234 if (!Attr)
235 return;
236
237 llvm::hlsl::ResourceClass RC = Attr->getResourceClass();
238 llvm::hlsl::ResourceKind RK = Attr->getResourceKind();
239
240 QualType QT(Ty, 0);
241 BufferResBinding Binding(D->getAttr<HLSLResourceBindingAttr>());
242 addBufferResourceAnnotation(GV, QT.getAsString(), RC, RK, Binding);
243}
244
246 HLSLResourceBindingAttr *Binding) {
247 if (Binding) {
248 llvm::APInt RegInt(64, 0);
249 Binding->getSlot().substr(1).getAsInteger(10, RegInt);
250 Reg = RegInt.getLimitedValue();
251 llvm::APInt SpaceInt(64, 0);
252 Binding->getSpace().substr(5).getAsInteger(10, SpaceInt);
253 Space = SpaceInt.getLimitedValue();
254 } else {
255 Space = 0;
256 }
257}
258
260 const FunctionDecl *FD, llvm::Function *Fn) {
261 const auto *ShaderAttr = FD->getAttr<HLSLShaderAttr>();
262 assert(ShaderAttr && "All entry functions must have a HLSLShaderAttr");
263 const StringRef ShaderAttrKindStr = "hlsl.shader";
264 Fn->addFnAttr(ShaderAttrKindStr,
265 ShaderAttr->ConvertShaderTypeToStr(ShaderAttr->getType()));
266 if (HLSLNumThreadsAttr *NumThreadsAttr = FD->getAttr<HLSLNumThreadsAttr>()) {
267 const StringRef NumThreadsKindStr = "hlsl.numthreads";
268 std::string NumThreadsStr =
269 formatv("{0},{1},{2}", NumThreadsAttr->getX(), NumThreadsAttr->getY(),
270 NumThreadsAttr->getZ());
271 Fn->addFnAttr(NumThreadsKindStr, NumThreadsStr);
272 }
273}
274
275static Value *buildVectorInput(IRBuilder<> &B, Function *F, llvm::Type *Ty) {
276 if (const auto *VT = dyn_cast<FixedVectorType>(Ty)) {
277 Value *Result = PoisonValue::get(Ty);
278 for (unsigned I = 0; I < VT->getNumElements(); ++I) {
279 Value *Elt = B.CreateCall(F, {B.getInt32(I)});
280 Result = B.CreateInsertElement(Result, Elt, I);
281 }
282 return Result;
283 }
284 return B.CreateCall(F, {B.getInt32(0)});
285}
286
287llvm::Value *CGHLSLRuntime::emitInputSemantic(IRBuilder<> &B,
288 const ParmVarDecl &D,
289 llvm::Type *Ty) {
290 assert(D.hasAttrs() && "Entry parameter missing annotation attribute!");
291 if (D.hasAttr<HLSLSV_GroupIndexAttr>()) {
292 llvm::Function *DxGroupIndex =
293 CGM.getIntrinsic(Intrinsic::dx_flattened_thread_id_in_group);
294 return B.CreateCall(FunctionCallee(DxGroupIndex));
295 }
296 if (D.hasAttr<HLSLSV_DispatchThreadIDAttr>()) {
297 llvm::Function *DxThreadID = CGM.getIntrinsic(Intrinsic::dx_thread_id);
298 return buildVectorInput(B, DxThreadID, Ty);
299 }
300 assert(false && "Unhandled parameter attribute");
301 return nullptr;
302}
303
305 llvm::Function *Fn) {
306 llvm::Module &M = CGM.getModule();
307 llvm::LLVMContext &Ctx = M.getContext();
308 auto *EntryTy = llvm::FunctionType::get(llvm::Type::getVoidTy(Ctx), false);
309 Function *EntryFn =
310 Function::Create(EntryTy, Function::ExternalLinkage, FD->getName(), &M);
311
312 // Copy function attributes over, we have no argument or return attributes
313 // that can be valid on the real entry.
314 AttributeList NewAttrs = AttributeList::get(Ctx, AttributeList::FunctionIndex,
315 Fn->getAttributes().getFnAttrs());
316 EntryFn->setAttributes(NewAttrs);
317 setHLSLEntryAttributes(FD, EntryFn);
318
319 // Set the called function as internal linkage.
320 Fn->setLinkage(GlobalValue::InternalLinkage);
321
322 BasicBlock *BB = BasicBlock::Create(Ctx, "entry", EntryFn);
323 IRBuilder<> B(BB);
325 // FIXME: support struct parameters where semantics are on members.
326 // See: https://github.com/llvm/llvm-project/issues/57874
327 unsigned SRetOffset = 0;
328 for (const auto &Param : Fn->args()) {
329 if (Param.hasStructRetAttr()) {
330 // FIXME: support output.
331 // See: https://github.com/llvm/llvm-project/issues/57874
332 SRetOffset = 1;
333 Args.emplace_back(PoisonValue::get(Param.getType()));
334 continue;
335 }
336 const ParmVarDecl *PD = FD->getParamDecl(Param.getArgNo() - SRetOffset);
337 Args.push_back(emitInputSemantic(B, *PD, Param.getType()));
338 }
339
340 CallInst *CI = B.CreateCall(FunctionCallee(Fn), Args);
341 (void)CI;
342 // FIXME: Handle codegen for return type semantics.
343 // See: https://github.com/llvm/llvm-project/issues/57875
344 B.CreateRetVoid();
345}
346
347static void gatherFunctions(SmallVectorImpl<Function *> &Fns, llvm::Module &M,
348 bool CtorOrDtor) {
349 const auto *GV =
350 M.getNamedGlobal(CtorOrDtor ? "llvm.global_ctors" : "llvm.global_dtors");
351 if (!GV)
352 return;
353 const auto *CA = dyn_cast<ConstantArray>(GV->getInitializer());
354 if (!CA)
355 return;
356 // The global_ctor array elements are a struct [Priority, Fn *, COMDat].
357 // HLSL neither supports priorities or COMDat values, so we will check those
358 // in an assert but not handle them.
359
361 for (const auto &Ctor : CA->operands()) {
362 if (isa<ConstantAggregateZero>(Ctor))
363 continue;
364 ConstantStruct *CS = cast<ConstantStruct>(Ctor);
365
366 assert(cast<ConstantInt>(CS->getOperand(0))->getValue() == 65535 &&
367 "HLSL doesn't support setting priority for global ctors.");
368 assert(isa<ConstantPointerNull>(CS->getOperand(2)) &&
369 "HLSL doesn't support COMDat for global ctors.");
370 Fns.push_back(cast<Function>(CS->getOperand(1)));
371 }
372}
373
375 llvm::Module &M = CGM.getModule();
378 gatherFunctions(CtorFns, M, true);
379 gatherFunctions(DtorFns, M, false);
380
381 // Insert a call to the global constructor at the beginning of the entry block
382 // to externally exported functions. This is a bit of a hack, but HLSL allows
383 // global constructors, but doesn't support driver initialization of globals.
384 for (auto &F : M.functions()) {
385 if (!F.hasFnAttribute("hlsl.shader"))
386 continue;
387 IRBuilder<> B(&F.getEntryBlock(), F.getEntryBlock().begin());
388 for (auto *Fn : CtorFns)
389 B.CreateCall(FunctionCallee(Fn));
390
391 // Insert global dtors before the terminator of the last instruction
392 B.SetInsertPoint(F.back().getTerminator());
393 for (auto *Fn : DtorFns)
394 B.CreateCall(FunctionCallee(Fn));
395 }
396
397 // No need to keep global ctors/dtors for non-lib profile after call to
398 // ctors/dtors added for entry.
399 Triple T(M.getTargetTriple());
400 if (T.getEnvironment() != Triple::EnvironmentType::Library) {
401 if (auto *GV = M.getNamedGlobal("llvm.global_ctors"))
402 GV->eraseFromParent();
403 if (auto *GV = M.getNamedGlobal("llvm.global_dtors"))
404 GV->eraseFromParent();
405 }
406}
static void gatherFunctions(SmallVectorImpl< Function * > &Fns, llvm::Module &M, bool CtorOrDtor)
static Value * buildVectorInput(IRBuilder<> &B, Function *F, llvm::Type *Ty)
static std::string getName(const CallEvent &Call)
Defines the clang::TargetOptions class.
Attr - This represents one attribute.
Definition: Attr.h:41
This class gathers all debug information during compilation and is responsible for emitting to llvm g...
Definition: CGDebugInfo.h:55
void setHLSLEntryAttributes(const FunctionDecl *FD, llvm::Function *Fn)
void emitEntryFunction(const FunctionDecl *FD, llvm::Function *Fn)
llvm::Value * emitInputSemantic(llvm::IRBuilder<> &B, const ParmVarDecl &D, llvm::Type *Ty)
void annotateHLSLResource(const VarDecl *D, llvm::GlobalVariable *GV)
void addBuffer(const HLSLBufferDecl *D)
llvm::Module & getModule() const
CGDebugInfo * getModuleDebugInfo()
const TargetInfo & getTarget() const
void EmitGlobal(GlobalDecl D)
Emit code for a single global function or var decl.
llvm::Constant * GetAddrOfGlobalVar(const VarDecl *D, llvm::Type *Ty=nullptr, ForDefinition_t IsForDefinition=NotForDefinition)
Return the llvm::Constant for the address of the given global variable.
const CodeGenOptions & getCodeGenOpts() const
llvm::Function * getIntrinsic(unsigned IID, ArrayRef< llvm::Type * > Tys=std::nullopt)
void EmitTopLevelDecl(Decl *D)
Emit code for a single top level declaration.
DeclContext - This is used only as base class of specific decl types that can act as declaration cont...
Definition: DeclBase.h:1435
decl_range decls() const
decls_begin/decls_end - Iterate over the declarations stored in this context.
Definition: DeclBase.h:2320
Decl - This represents one declaration (or definition), e.g.
Definition: DeclBase.h:85
T * getAttr() const
Definition: DeclBase.h:577
bool hasAttrs() const
Definition: DeclBase.h:523
bool hasAttr() const
Definition: DeclBase.h:581
Represents a function declaration or definition.
Definition: Decl.h:1957
const ParmVarDecl * getParamDecl(unsigned i) const
Definition: Decl.h:2664
HLSLBufferDecl - Represent a cbuffer or tbuffer declaration.
Definition: Decl.h:4909
StringRef getName() const
Get the name of identifier for this declaration as a StringRef.
Definition: Decl.h:275
Represents a parameter to a function.
Definition: Decl.h:1747
A (possibly-)qualified type.
Definition: Type.h:736
static std::string getAsString(SplitQualType split, const PrintingPolicy &Policy)
Definition: Type.h:1120
TargetOptions & getTargetOpts() const
Retrieve the target options.
Definition: TargetInfo.h:304
The base class of the type hierarchy.
Definition: Type.h:1602
CXXRecordDecl * getAsCXXRecordDecl() const
Retrieves the CXXRecordDecl that this type refers to, either because the type is a RecordType or beca...
Definition: Type.cpp:1819
const Type * getPointeeOrArrayElementType() const
If this is a pointer type, return the pointee type.
Definition: Type.h:7515
QualType getType() const
Definition: Decl.h:715
Represents a variable declaration or definition.
Definition: Decl.h:916
StorageClass getStorageClass() const
Returns the storage class as written in the source.
Definition: Decl.h:1150
#define UINT_MAX
Definition: limits.h:60
bool Const(InterpState &S, CodePtr OpPC, const T &Arg)
Definition: Interp.h:886
@ SC_Static
Definition: Specifiers.h:247
@ Result
The result type of a method or function.
unsigned long uint64_t
YAML serialization mapping.
Definition: Dominators.h:30
BufferResBinding(HLSLResourceBindingAttr *Attr)
std::vector< std::pair< llvm::GlobalVariable *, unsigned > > Constants
Definition: CGHLSLRuntime.h:66