18#include "mlir/IR/Operation.h"
20#include "clang/AST/Attrs.inc"
27#include "llvm/Support/Casting.h"
41 llvm::StringMap<mlir::Operation *> kernelHandles;
44 llvm::DenseMap<mlir::Operation *, mlir::Operation *> kernelStubs;
49 cir::CUDADeviceVarKind flags;
51 llvm::SmallVector<VarInfo, 16> deviceVars;
54 std::unique_ptr<MangleContext> deviceMC;
57 void emitDeviceStubBodyNew(CIRGenFunction &cgf, cir::FuncOp fn,
58 FunctionArgList &args);
59 mlir::Value prepareKernelArgs(CIRGenFunction &cgf, mlir::Location loc,
60 FunctionArgList &args);
61 mlir::Operation *getKernelHandle(cir::FuncOp fn, GlobalDecl gd)
override;
63 mlir::Operation *getKernelStub(mlir::Operation *handle)
override {
64 auto it = kernelStubs.find(handle);
65 assert(it != kernelStubs.end());
68 std::string addPrefixToName(StringRef funcName)
const;
69 std::string addUnderscoredPrefixToName(StringRef funcName)
const;
72 CIRGenNVCUDARuntime(CIRGenModule &cgm);
73 ~CIRGenNVCUDARuntime();
75 void emitDeviceStub(CIRGenFunction &cgf, cir::FuncOp fn,
76 FunctionArgList &args)
override;
78 void handleVarRegistration(
const VarDecl *vd, cir::GlobalOp var)
override;
79 void finalizeModule()
override;
80 void handleGlobalReplace(cir::GlobalOp oldGV, cir::GlobalOp newGV)
override;
82 void internalizeDeviceSideVar(
const VarDecl *d,
83 cir::GlobalLinkageKind &linkage)
override;
85 std::string getDeviceSideName(
const NamedDecl *nd)
override;
87 void registerDeviceVar(
const VarDecl *vd, cir::GlobalOp &var,
bool isExtern,
90 auto &builder = cgm.getBuilder();
91 var->setAttr(cir::CUDAVarRegistrationInfoAttr::getMnemonic(),
92 cir::CUDAVarRegistrationInfoAttr::get(
95 cir::CUDADeviceVarKind::Variable, isExtern, isConstant,
96 vd->
hasAttr<HIPManagedAttr>()));
97 deviceVars.push_back({
100 cir::CUDADeviceVarKind::Variable,
107std::string CIRGenNVCUDARuntime::addPrefixToName(StringRef funcName)
const {
108 return (prefix + funcName).str();
112CIRGenNVCUDARuntime::addUnderscoredPrefixToName(StringRef funcName)
const {
113 return (
"__" + prefix + funcName).str();
116CIRGenNVCUDARuntime::CIRGenNVCUDARuntime(CIRGenModule &cgm)
117 : CIRGenCUDARuntime(cgm),
118 deviceMC(cgm.getASTContext().cudaNVInitDeviceMC()) {
120 cgm.
errorNYI(
"CIRGenNVCUDARuntime: Offload via LLVM");
127mlir::Value CIRGenNVCUDARuntime::prepareKernelArgs(
CIRGenFunction &cgf,
133 auto voidPtrArrayTy = cir::ArrayType::get(cgm.
voidPtrTy, args.size());
134 mlir::Value kernelArgs =
135 builder.
createAlloca(loc, cir::PointerType::get(voidPtrArrayTy),
138 mlir::Value kernelArgsDecayed =
139 builder.
createCast(cir::CastKind::array_to_ptrdecay, kernelArgs,
142 for (
const auto &[i, arg] : llvm::enumerate(args)) {
145 mlir::Value storePos =
150 builder.CIRBaseBuilderTy::createStore(loc, argAsVoid, storePos);
153 return kernelArgsDecayed;
158void CIRGenNVCUDARuntime::emitDeviceStubBodyNew(CIRGenFunction &cgf,
160 FunctionArgList &args) {
164 cgm.
errorNYI(
"CIRGenNVCUDARuntime: Offload via LLVM");
167 mlir::Location loc = fn.getLoc();
172 mlir::Value kernelArgs = prepareKernelArgs(cgf, loc, args);
189 std::string kernelLaunchAPI =
"LaunchKernel";
191 LangOptions::GPUDefaultStreamKind::PerThread) {
193 kernelLaunchAPI +=
"_spt";
195 kernelLaunchAPI +=
"_ptsz";
198 std::string launchKernelName = addPrefixToName(kernelLaunchAPI);
199 const IdentifierInfo &launchII =
201 FunctionDecl *cudaLaunchKernelFD =
nullptr;
202 for (NamedDecl *result : dc->
lookup(&launchII)) {
203 if (FunctionDecl *fd = dyn_cast<FunctionDecl>(result))
204 cudaLaunchKernelFD = fd;
207 if (cudaLaunchKernelFD ==
nullptr) {
209 "Can't find declaration for " + launchKernelName);
226 builder.
createAlloca(loc, cir::PointerType::get(dim3Ty),
"grid_dim",
229 builder.
createAlloca(loc, cir::PointerType::get(dim3Ty),
"block_dim",
234 loc, cir::PointerType::get(streamTy),
"stream", cgm.
getPointerAlign());
238 sharedMem.getType(), stream.getType()},
240 addUnderscoredPrefixToName(
"PopCallConfiguration"));
251 mlir::Value kernel = [&]() -> mlir::Value {
252 if (cir::GlobalOp globalOp = llvm::dyn_cast_or_null<cir::GlobalOp>(
253 kernelHandles[fn.getSymName()])) {
254 cir::PointerType kernelTy = cir::PointerType::get(globalOp.getSymType());
255 mlir::Value kernelVal = cir::GetGlobalOp::create(builder, loc, kernelTy,
256 globalOp.getSymName());
260 if (cir::FuncOp funcOp = llvm::dyn_cast_or_null<cir::FuncOp>(
261 kernelHandles[fn.getSymName()])) {
262 cir::PointerType kernelTy =
263 cir::PointerType::get(funcOp.getFunctionType());
264 mlir::Value kernelVal =
265 cir::GetGlobalOp::create(builder, loc, kernelTy, funcOp.getSymName());
269 llvm_unreachable(
"Expected stub handle to be cir::GlobalOp or FuncOp");
272 CallArgList launchArgs;
284 RValue::get(builder.CIRBaseBuilderTy::createLoad(loc, sharedMem)),
286 launchArgs.
add(
RValue::get(builder.CIRBaseBuilderTy::createLoad(loc, stream)),
289 mlir::Type launchTy =
293 const CIRGenFunctionInfo &callInfo =
296 ReturnValueSlot(), launchArgs);
300 cgm.
errorNYI(
"MSVC CUDA stub handling");
303void CIRGenNVCUDARuntime::emitDeviceStub(CIRGenFunction &cgf, cir::FuncOp fn,
304 FunctionArgList &args) {
307 llvm::dyn_cast<cir::GlobalOp>(kernelHandles[fn.getSymName()])) {
309 mlir::Type fnPtrTy = globalOp.getSymType();
310 auto sym = mlir::FlatSymbolRefAttr::get(fn.getSymNameAttr());
311 auto gv = cir::GlobalViewAttr::get(fnPtrTy, sym);
313 globalOp->setAttr(
"initial_value", gv);
314 globalOp->removeAttr(
"sym_visibility");
315 globalOp->setAttr(
"alignment", builder.getI64IntegerAttr(
321 CudaFeature::CUDA_USES_NEW_LAUNCH) ||
324 emitDeviceStubBodyNew(cgf, fn, args);
326 cgm.
errorNYI(
"Emit Stub Body Legacy");
330 return new CIRGenNVCUDARuntime(cgm);
333CIRGenNVCUDARuntime::~CIRGenNVCUDARuntime() {}
335mlir::Operation *CIRGenNVCUDARuntime::getKernelHandle(cir::FuncOp fn,
339 auto it = kernelHandles.find(fn.getSymName());
340 if (it != kernelHandles.end()) {
341 mlir::Operation *oldHandle = it->second;
343 if (kernelStubs[oldHandle] == fn)
351 kernelStubs[oldHandle] = fn;
356 kernelStubs.erase(oldHandle);
361 kernelHandles[fn.getSymName()] = fn;
362 kernelStubs[fn] = fn;
370 cir::PointerType fnPtrTy = builder.
getPointerTo(fn.getFunctionType());
371 cir::GlobalOp globalOp =
374 globalOp->setAttr(
"alignment", builder.getI64IntegerAttr(
378 kernelHandles[fn.getSymName()] = globalOp;
379 kernelStubs[globalOp] = fn;
384void CIRGenNVCUDARuntime::internalizeDeviceSideVar(
385 const VarDecl *d, cir::GlobalLinkageKind &linkage) {
388 "internalizeDeviceSideVar: GPU Relocatable Device Code (RDC)");
395 if (d->
hasAttr<CUDADeviceAttr>() || d->
hasAttr<CUDAConstantAttr>() ||
396 d->
hasAttr<CUDASharedAttr>()) {
397 linkage = cir::GlobalLinkageKind::InternalLinkage;
403 "internalizeDeviceSideVar: CUDA Surface/Texture support");
406std::string CIRGenNVCUDARuntime::getDeviceSideName(
const NamedDecl *nd) {
409 if (
auto *fd = dyn_cast<FunctionDecl>(nd))
410 gd = GlobalDecl(fd, KernelReferenceKind::Kernel);
413 std::string deviceSideName;
420 SmallString<256> buffer;
421 llvm::raw_svector_ostream
out(buffer);
423 deviceSideName = std::string(
out.str());
430 SmallString<256> buffer;
431 llvm::raw_svector_ostream
out(buffer);
432 out << deviceSideName;
434 deviceSideName = std::string(
out.str());
436 return deviceSideName;
439void CIRGenNVCUDARuntime::handleVarRegistration(
const VarDecl *vd,
441 if (vd->
hasAttr<CUDADeviceAttr>() || vd->
hasAttr<CUDAConstantAttr>()) {
457 vd->
hasAttr<HIPManagedAttr>()) {
459 vd->
hasAttr<CUDAConstantAttr>());
466 "handleVarRegistration: Surface and Texture registration");
470void CIRGenNVCUDARuntime::handleGlobalReplace(cir::GlobalOp oldGV,
471 cir::GlobalOp newGV) {
472 for (
auto &info : deviceVars) {
473 if (info.var == oldGV)
478void CIRGenNVCUDARuntime::finalizeModule() {
492 for (
auto &&info : deviceVars) {
493 auto kind = info.flags;
494 bool isDecl = info.var.isDeclaration();
496 bool isVarOrSurfaceOrTexture = (
kind == cir::CUDADeviceVarKind::Variable ||
497 kind == cir::CUDADeviceVarKind::Surface ||
498 kind == cir::CUDADeviceVarKind::Texture);
499 bool isUsed = info.d->isUsed();
500 bool hasUsedAttr = info.d->hasAttr<UsedAttr>();
501 if (!isDecl && !isLocalLinkage && isVarOrSurfaceOrTexture && isUsed &&
503 if (
auto globalValue = mlir::dyn_cast<cir::CIRGlobalValueInterface>(
504 info.var.getOperation())) {
Defines the clang::ASTContext interface.
Provides definitions for the various language-specific address spaces.
*collection of selector each with an associated kind and an ordered *collection of selectors A selector has a kind
__CUDA_BUILTIN_VAR __cuda_builtin_blockDim_t blockDim
__CUDA_BUILTIN_VAR __cuda_builtin_gridDim_t gridDim
mlir::Value createCast(mlir::Location loc, cir::CastKind kind, mlir::Value src, mlir::Type newTy)
cir::PtrStrideOp createPtrStride(mlir::Location loc, mlir::Value base, mlir::Value stride)
cir::PointerType getPointerTo(mlir::Type ty)
mlir::Value createAlloca(mlir::Location loc, cir::PointerType addrType, llvm::StringRef name, mlir::IntegerAttr alignment, mlir::Value dynAllocSize)
mlir::Value createBitcast(mlir::Value src, mlir::Type newTy)
TranslationUnitDecl * getTranslationUnitDecl() const
bool shouldExternalize(const Decl *D) const
Whether a C++ static variable or CUDA/HIP kernel should be externalized.
llvm::SetVector< const VarDecl * > CUDADeviceVarODRUsedByHost
Keep track of CUDA/HIP device-side variables ODR-used by host code.
const TargetInfo & getTargetInfo() const
mlir::Value getPointer() const
cir::ConstantOp getConstInt(mlir::Location loc, llvm::APSInt intVal)
clang::MangleContext & getMangleContext()
Gets the mangle context.
static CIRGenCallee forDirect(mlir::Operation *funcPtr, const CIRGenCalleeInfo &abstractInfo=CIRGenCalleeInfo())
CIRGenTypes & getTypes() const
const clang::LangOptions & getLangOpts() const
const clang::Decl * curFuncDecl
Address getAddrOfLocalVar(const clang::VarDecl *vd)
Return the address of a local variable.
RValue emitCall(const CIRGenFunctionInfo &funcInfo, const CIRGenCallee &callee, ReturnValueSlot returnValue, const CallArgList &args, cir::CIRCallOpInterface *callOp, mlir::Location loc)
mlir::Value emitRuntimeCall(mlir::Location loc, cir::FuncOp callee, llvm::ArrayRef< mlir::Value > args={}, mlir::NamedAttrList attrs={})
This class organizes the cross-function state that is used while generating CIR code.
llvm::StringRef getMangledName(clang::GlobalDecl gd)
DiagnosticBuilder errorNYI(SourceLocation, llvm::StringRef)
Helpers to emit "not yet implemented" error diagnostics.
clang::ASTContext & getASTContext() const
CIRGenBuilderTy & getBuilder()
const clang::TargetInfo & getTarget() const
void error(SourceLocation loc, llvm::StringRef error)
Emit a general error that something can't be done.
cir::FuncOp createRuntimeFunction(cir::FuncType ty, llvm::StringRef name, mlir::NamedAttrList extraAttrs={}, bool isLocal=false, bool assumeConvergent=false)
const clang::LangOptions & getLangOpts() const
void printPostfixForExternalizedDecl(llvm::raw_ostream &os, const Decl *d)
Print the postfix for externalized static variable or kernels for single source offloading languages ...
cir::GlobalOp createGlobalOp(mlir::Location loc, llvm::StringRef name, mlir::Type t, bool isConstant=false, mlir::ptr::MemorySpaceAttrInterface addrSpace={}, mlir::Operation *insertPoint=nullptr)
void addCompilerUsedGlobal(cir::CIRGlobalValueInterface gv)
Add a global value to the llvmCompilerUsed list.
CIRGenCXXABI & getCXXABI() const
const CIRGenFunctionInfo & arrangeFunctionDeclaration(const clang::FunctionDecl *fd)
Free functions are functions that are compatible with an ordinary C function pointer type.
mlir::Type convertType(clang::QualType type)
Convert a Clang type into a mlir::Type.
void add(RValue rvalue, clang::QualType type)
Type for representing both the decl and type of parameters to a function.
static RValue get(mlir::Value v)
static RValue getAggregate(Address addr, bool isVolatile=false)
Convert an Address to an RValue.
QuantityType getQuantity() const
getQuantity - Get the raw integer representation of this quantity.
static CharUnits fromQuantity(QuantityType Quantity)
fromQuantity - Construct a CharUnits quantity from a raw integer type.
lookup_result lookup(DeclarationName Name) const
lookup - Find the declarations (if any) with the given Name in this context.
SourceLocation getLocation() const
const ParmVarDecl * getParamDecl(unsigned i) const
GlobalDecl - represents a global declaration.
GlobalDecl getWithKernelReferenceKind(KernelReferenceKind Kind)
StringRef getName() const
Return the actual identifier string.
IdentifierInfo & get(StringRef Name)
Return the identifier token info for the specified named identifier.
GPUDefaultStreamKind GPUDefaultStream
The default stream kind used for HIP kernel launching.
bool shouldMangleDeclName(const NamedDecl *D)
void mangleName(GlobalDecl GD, raw_ostream &)
IdentifierInfo * getIdentifier() const
Get the identifier that names this declaration, if there is one.
bool isMicrosoft() const
Is this ABI an MSVC-compatible ABI?
TargetCXXABI getCXXABI() const
Get the C++ ABI currently in use.
const llvm::VersionTuple & getSDKVersion() const
static DeclContext * castToDeclContext(const TranslationUnitDecl *D)
bool isCUDADeviceBuiltinSurfaceType() const
Check if the type is the CUDA device builtin surface type.
bool isCUDADeviceBuiltinTextureType() const
Check if the type is the CUDA device builtin texture type.
SourceRange getSourceRange() const override LLVM_READONLY
Source range that this declaration covers.
bool isInline() const
Whether this variable is (C++1z) inline.
bool hasExternalStorage() const
Returns true if a variable has extern or private_extern storage.
DefinitionKind hasDefinition(ASTContext &) const
Check whether this variable is defined in this translation unit.
static bool isLocalLinkage(GlobalLinkageKind linkage)
CIRGenCUDARuntime * createNVCUDARuntime(CIRGenModule &cgm)
constexpr Variable var(Literal L)
Returns the variable of L.
@ Address
A pointer to a ValueDecl.
The JSON file list parser is used to communicate input to InstallAPI.
if(T->getSizeExpr()) TRY_TO(TraverseStmt(const_cast< Expr * >(T -> getSizeExpr())))
bool CudaFeatureEnabled(llvm::VersionTuple, CudaFeature)
U cast(CodeGen::Address addr)
clang::CharUnits getPointerAlign() const
clang::CharUnits getSizeAlign() const
cir::PointerType voidPtrTy
void* in address space 0