18#include "mlir/IR/Operation.h"
20#include "clang/AST/Attrs.inc"
27#include "llvm/Support/Casting.h"
41 llvm::StringMap<mlir::Operation *> kernelHandles;
44 llvm::DenseMap<mlir::Operation *, mlir::Operation *> kernelStubs;
49 cir::CUDADeviceVarKind flags;
51 llvm::SmallVector<VarInfo, 16> deviceVars;
54 std::unique_ptr<MangleContext> deviceMC;
57 void emitDeviceStubBodyNew(CIRGenFunction &cgf, cir::FuncOp fn,
58 FunctionArgList &args);
59 mlir::Value prepareKernelArgs(CIRGenFunction &cgf, mlir::Location loc,
60 FunctionArgList &args);
61 mlir::Operation *getKernelHandle(cir::FuncOp fn, GlobalDecl gd)
override;
63 mlir::Operation *getKernelStub(mlir::Operation *handle)
override {
64 auto it = kernelStubs.find(handle);
65 assert(it != kernelStubs.end());
68 std::string addPrefixToName(StringRef funcName)
const;
69 std::string addUnderscoredPrefixToName(StringRef funcName)
const;
72 CIRGenNVCUDARuntime(CIRGenModule &cgm);
73 ~CIRGenNVCUDARuntime();
75 void emitDeviceStub(CIRGenFunction &cgf, cir::FuncOp fn,
76 FunctionArgList &args)
override;
78 void handleVarRegistration(
const VarDecl *vd, cir::GlobalOp var)
override;
79 void finalizeModule()
override;
80 void handleGlobalReplace(cir::GlobalOp oldGV, cir::GlobalOp newGV)
override;
82 void internalizeDeviceSideVar(
const VarDecl *d,
83 cir::GlobalLinkageKind &linkage)
override;
85 std::string getDeviceSideName(
const NamedDecl *nd)
override;
87 void registerDeviceVar(
const VarDecl *vd, cir::GlobalOp &var,
bool isExtern,
90 auto &builder = cgm.getBuilder();
91 var->setAttr(cir::CUDAVarRegistrationInfoAttr::getMnemonic(),
92 cir::CUDAVarRegistrationInfoAttr::get(
95 cir::CUDADeviceVarKind::Variable, isExtern, isConstant,
96 vd->
hasAttr<HIPManagedAttr>()));
97 deviceVars.push_back({
100 cir::CUDADeviceVarKind::Variable,
107std::string CIRGenNVCUDARuntime::addPrefixToName(StringRef funcName)
const {
108 return (prefix + funcName).str();
112CIRGenNVCUDARuntime::addUnderscoredPrefixToName(StringRef funcName)
const {
113 return (
"__" + prefix + funcName).str();
116CIRGenNVCUDARuntime::CIRGenNVCUDARuntime(CIRGenModule &cgm)
117 : CIRGenCUDARuntime(cgm),
118 deviceMC(cgm.getASTContext().cudaNVInitDeviceMC()) {
120 cgm.
errorNYI(
"CIRGenNVCUDARuntime: Offload via LLVM");
127mlir::Value CIRGenNVCUDARuntime::prepareKernelArgs(
CIRGenFunction &cgf,
133 auto voidPtrArrayTy = cir::ArrayType::get(cgm.
voidPtrTy, args.size());
135 loc, cir::PointerType::get(voidPtrArrayTy), voidPtrArrayTy,
"kernel_args",
138 mlir::Value kernelArgsDecayed =
139 builder.
createCast(cir::CastKind::array_to_ptrdecay, kernelArgs,
142 for (
const auto &[i, arg] : llvm::enumerate(args)) {
145 mlir::Value storePos =
150 builder.CIRBaseBuilderTy::createStore(loc, argAsVoid, storePos);
153 return kernelArgsDecayed;
158void CIRGenNVCUDARuntime::emitDeviceStubBodyNew(CIRGenFunction &cgf,
160 FunctionArgList &args) {
164 cgm.
errorNYI(
"CIRGenNVCUDARuntime: Offload via LLVM");
167 mlir::Location loc = fn.getLoc();
172 mlir::Value kernelArgs = prepareKernelArgs(cgf, loc, args);
189 std::string kernelLaunchAPI =
"LaunchKernel";
191 LangOptions::GPUDefaultStreamKind::PerThread) {
193 kernelLaunchAPI +=
"_spt";
195 kernelLaunchAPI +=
"_ptsz";
198 std::string launchKernelName = addPrefixToName(kernelLaunchAPI);
199 const IdentifierInfo &launchII =
201 FunctionDecl *cudaLaunchKernelFD =
nullptr;
202 for (NamedDecl *result : dc->
lookup(&launchII)) {
203 if (FunctionDecl *fd = dyn_cast<FunctionDecl>(result))
204 cudaLaunchKernelFD = fd;
207 if (cudaLaunchKernelFD ==
nullptr) {
209 "Can't find declaration for " + launchKernelName);
226 builder.
createAlloca(loc, cir::PointerType::get(dim3Ty), dim3Ty,
229 builder.
createAlloca(loc, cir::PointerType::get(dim3Ty), dim3Ty,
231 mlir::Value sharedMem =
235 builder.
createAlloca(loc, cir::PointerType::get(streamTy), streamTy,
240 sharedMem.getType(), stream.getType()},
242 addUnderscoredPrefixToName(
"PopCallConfiguration"));
253 mlir::Value kernel = [&]() -> mlir::Value {
254 if (cir::GlobalOp globalOp = llvm::dyn_cast_or_null<cir::GlobalOp>(
255 kernelHandles[fn.getSymName()])) {
256 cir::PointerType kernelTy = cir::PointerType::get(globalOp.getSymType());
257 mlir::Value kernelVal = cir::GetGlobalOp::create(builder, loc, kernelTy,
258 globalOp.getSymName());
262 if (cir::FuncOp funcOp = llvm::dyn_cast_or_null<cir::FuncOp>(
263 kernelHandles[fn.getSymName()])) {
264 cir::PointerType kernelTy =
265 cir::PointerType::get(funcOp.getFunctionType());
266 mlir::Value kernelVal =
267 cir::GetGlobalOp::create(builder, loc, kernelTy, funcOp.getSymName());
271 llvm_unreachable(
"Expected stub handle to be cir::GlobalOp or FuncOp");
274 CallArgList launchArgs;
286 RValue::get(builder.CIRBaseBuilderTy::createLoad(loc, sharedMem)),
288 launchArgs.
add(
RValue::get(builder.CIRBaseBuilderTy::createLoad(loc, stream)),
291 mlir::Type launchTy =
295 const CIRGenFunctionInfo &callInfo =
298 ReturnValueSlot(), launchArgs);
302 cgm.
errorNYI(
"MSVC CUDA stub handling");
305void CIRGenNVCUDARuntime::emitDeviceStub(CIRGenFunction &cgf, cir::FuncOp fn,
306 FunctionArgList &args) {
309 llvm::dyn_cast<cir::GlobalOp>(kernelHandles[fn.getSymName()])) {
311 mlir::Type fnPtrTy = globalOp.getSymType();
312 auto sym = mlir::FlatSymbolRefAttr::get(fn.getSymNameAttr());
313 auto gv = cir::GlobalViewAttr::get(fnPtrTy, sym);
315 globalOp->setAttr(
"initial_value", gv);
316 globalOp->removeAttr(
"sym_visibility");
317 globalOp->setAttr(
"alignment", builder.getI64IntegerAttr(
323 CudaFeature::CUDA_USES_NEW_LAUNCH) ||
326 emitDeviceStubBodyNew(cgf, fn, args);
328 cgm.
errorNYI(
"Emit Stub Body Legacy");
332 return new CIRGenNVCUDARuntime(cgm);
335CIRGenNVCUDARuntime::~CIRGenNVCUDARuntime() {}
337mlir::Operation *CIRGenNVCUDARuntime::getKernelHandle(cir::FuncOp fn,
341 auto it = kernelHandles.find(fn.getSymName());
342 if (it != kernelHandles.end()) {
343 mlir::Operation *oldHandle = it->second;
345 if (kernelStubs[oldHandle] == fn)
353 kernelStubs[oldHandle] = fn;
358 kernelStubs.erase(oldHandle);
363 kernelHandles[fn.getSymName()] = fn;
364 kernelStubs[fn] = fn;
372 cir::PointerType fnPtrTy = builder.
getPointerTo(fn.getFunctionType());
373 cir::GlobalOp globalOp =
376 globalOp->setAttr(
"alignment", builder.getI64IntegerAttr(
380 kernelHandles[fn.getSymName()] = globalOp;
381 kernelStubs[globalOp] = fn;
386void CIRGenNVCUDARuntime::internalizeDeviceSideVar(
387 const VarDecl *d, cir::GlobalLinkageKind &linkage) {
390 "internalizeDeviceSideVar: GPU Relocatable Device Code (RDC)");
397 if (d->
hasAttr<CUDADeviceAttr>() || d->
hasAttr<CUDAConstantAttr>() ||
398 d->
hasAttr<CUDASharedAttr>()) {
399 linkage = cir::GlobalLinkageKind::InternalLinkage;
405 "internalizeDeviceSideVar: CUDA Surface/Texture support");
408std::string CIRGenNVCUDARuntime::getDeviceSideName(
const NamedDecl *nd) {
411 if (
auto *fd = dyn_cast<FunctionDecl>(nd))
412 gd = GlobalDecl(fd, KernelReferenceKind::Kernel);
415 std::string deviceSideName;
422 SmallString<256> buffer;
423 llvm::raw_svector_ostream
out(buffer);
425 deviceSideName = std::string(
out.str());
432 SmallString<256> buffer;
433 llvm::raw_svector_ostream
out(buffer);
434 out << deviceSideName;
436 deviceSideName = std::string(
out.str());
438 return deviceSideName;
441void CIRGenNVCUDARuntime::handleVarRegistration(
const VarDecl *vd,
443 if (vd->
hasAttr<CUDADeviceAttr>() || vd->
hasAttr<CUDAConstantAttr>()) {
459 vd->
hasAttr<HIPManagedAttr>()) {
461 vd->
hasAttr<CUDAConstantAttr>());
468 "handleVarRegistration: Surface and Texture registration");
472void CIRGenNVCUDARuntime::handleGlobalReplace(cir::GlobalOp oldGV,
473 cir::GlobalOp newGV) {
474 for (
auto &info : deviceVars) {
475 if (info.var == oldGV)
480void CIRGenNVCUDARuntime::finalizeModule() {
494 for (
auto &&info : deviceVars) {
495 auto kind = info.flags;
496 bool isDecl = info.var.isDeclaration();
498 bool isVarOrSurfaceOrTexture = (
kind == cir::CUDADeviceVarKind::Variable ||
499 kind == cir::CUDADeviceVarKind::Surface ||
500 kind == cir::CUDADeviceVarKind::Texture);
501 bool isUsed = info.d->isUsed();
502 bool hasUsedAttr = info.d->hasAttr<UsedAttr>();
503 if (!isDecl && !isLocalLinkage && isVarOrSurfaceOrTexture && isUsed &&
505 if (
auto globalValue = mlir::dyn_cast<cir::CIRGlobalValueInterface>(
506 info.var.getOperation())) {
Defines the clang::ASTContext interface.
Provides definitions for the various language-specific address spaces.
*collection of selector each with an associated kind and an ordered *collection of selectors A selector has a kind
__CUDA_BUILTIN_VAR __cuda_builtin_blockDim_t blockDim
__CUDA_BUILTIN_VAR __cuda_builtin_gridDim_t gridDim
mlir::Value createCast(mlir::Location loc, cir::CastKind kind, mlir::Value src, mlir::Type newTy)
cir::PtrStrideOp createPtrStride(mlir::Location loc, mlir::Value base, mlir::Value stride)
cir::PointerType getPointerTo(mlir::Type ty)
mlir::Value createBitcast(mlir::Value src, mlir::Type newTy)
mlir::Value createAlloca(mlir::Location loc, cir::PointerType addrType, mlir::Type type, llvm::StringRef name, mlir::IntegerAttr alignment, mlir::Value dynAllocSize)
TranslationUnitDecl * getTranslationUnitDecl() const
bool shouldExternalize(const Decl *D) const
Whether a C++ static variable or CUDA/HIP kernel should be externalized.
llvm::SetVector< const VarDecl * > CUDADeviceVarODRUsedByHost
Keep track of CUDA/HIP device-side variables ODR-used by host code.
const TargetInfo & getTargetInfo() const
mlir::Value getPointer() const
cir::ConstantOp getConstInt(mlir::Location loc, llvm::APSInt intVal)
clang::MangleContext & getMangleContext()
Gets the mangle context.
static CIRGenCallee forDirect(mlir::Operation *funcPtr, const CIRGenCalleeInfo &abstractInfo=CIRGenCalleeInfo())
CIRGenTypes & getTypes() const
const clang::LangOptions & getLangOpts() const
const clang::Decl * curFuncDecl
Address getAddrOfLocalVar(const clang::VarDecl *vd)
Return the address of a local variable.
RValue emitCall(const CIRGenFunctionInfo &funcInfo, const CIRGenCallee &callee, ReturnValueSlot returnValue, const CallArgList &args, cir::CIRCallOpInterface *callOp, mlir::Location loc)
mlir::Value emitRuntimeCall(mlir::Location loc, cir::FuncOp callee, llvm::ArrayRef< mlir::Value > args={}, mlir::NamedAttrList attrs={})
This class organizes the cross-function state that is used while generating CIR code.
llvm::StringRef getMangledName(clang::GlobalDecl gd)
DiagnosticBuilder errorNYI(SourceLocation, llvm::StringRef)
Helpers to emit "not yet implemented" error diagnostics.
clang::ASTContext & getASTContext() const
CIRGenBuilderTy & getBuilder()
const clang::TargetInfo & getTarget() const
void error(SourceLocation loc, llvm::StringRef error)
Emit a general error that something can't be done.
cir::FuncOp createRuntimeFunction(cir::FuncType ty, llvm::StringRef name, mlir::NamedAttrList extraAttrs={}, bool isLocal=false, bool assumeConvergent=false)
const clang::LangOptions & getLangOpts() const
void printPostfixForExternalizedDecl(llvm::raw_ostream &os, const Decl *d)
Print the postfix for externalized static variable or kernels for single source offloading languages ...
cir::GlobalOp createGlobalOp(mlir::Location loc, llvm::StringRef name, mlir::Type t, bool isConstant=false, mlir::ptr::MemorySpaceAttrInterface addrSpace={}, mlir::Operation *insertPoint=nullptr)
void addCompilerUsedGlobal(cir::CIRGlobalValueInterface gv)
Add a global value to the llvmCompilerUsed list.
CIRGenCXXABI & getCXXABI() const
const CIRGenFunctionInfo & arrangeFunctionDeclaration(const clang::FunctionDecl *fd)
Free functions are functions that are compatible with an ordinary C function pointer type.
mlir::Type convertType(clang::QualType type)
Convert a Clang type into a mlir::Type.
void add(RValue rvalue, clang::QualType type)
Type for representing both the decl and type of parameters to a function.
static RValue get(mlir::Value v)
static RValue getAggregate(Address addr, bool isVolatile=false)
Convert an Address to an RValue.
QuantityType getQuantity() const
getQuantity - Get the raw integer representation of this quantity.
static CharUnits fromQuantity(QuantityType Quantity)
fromQuantity - Construct a CharUnits quantity from a raw integer type.
lookup_result lookup(DeclarationName Name) const
lookup - Find the declarations (if any) with the given Name in this context.
SourceLocation getLocation() const
const ParmVarDecl * getParamDecl(unsigned i) const
GlobalDecl - represents a global declaration.
GlobalDecl getWithKernelReferenceKind(KernelReferenceKind Kind)
StringRef getName() const
Return the actual identifier string.
IdentifierInfo & get(StringRef Name)
Return the identifier token info for the specified named identifier.
GPUDefaultStreamKind GPUDefaultStream
The default stream kind used for HIP kernel launching.
bool shouldMangleDeclName(const NamedDecl *D)
void mangleName(GlobalDecl GD, raw_ostream &)
IdentifierInfo * getIdentifier() const
Get the identifier that names this declaration, if there is one.
bool isMicrosoft() const
Is this ABI an MSVC-compatible ABI?
TargetCXXABI getCXXABI() const
Get the C++ ABI currently in use.
const llvm::VersionTuple & getSDKVersion() const
static DeclContext * castToDeclContext(const TranslationUnitDecl *D)
bool isCUDADeviceBuiltinSurfaceType() const
Check if the type is the CUDA device builtin surface type.
bool isCUDADeviceBuiltinTextureType() const
Check if the type is the CUDA device builtin texture type.
SourceRange getSourceRange() const override LLVM_READONLY
Source range that this declaration covers.
bool isInline() const
Whether this variable is (C++1z) inline.
bool hasExternalStorage() const
Returns true if a variable has extern or private_extern storage.
DefinitionKind hasDefinition(ASTContext &) const
Check whether this variable is defined in this translation unit.
static bool isLocalLinkage(GlobalLinkageKind linkage)
CIRGenCUDARuntime * createNVCUDARuntime(CIRGenModule &cgm)
constexpr Variable var(Literal L)
Returns the variable of L.
@ Address
A pointer to a ValueDecl.
The JSON file list parser is used to communicate input to InstallAPI.
if(T->getSizeExpr()) TRY_TO(TraverseStmt(const_cast< Expr * >(T -> getSizeExpr())))
bool CudaFeatureEnabled(llvm::VersionTuple, CudaFeature)
U cast(CodeGen::Address addr)
clang::CharUnits getPointerAlign() const
clang::CharUnits getSizeAlign() const
cir::PointerType voidPtrTy
void* in address space 0