17#include "mlir/IR/Operation.h"
25#include "llvm/Support/Casting.h"
/// Kernel handle recorded for each function, keyed by the function's symbol
/// name. The entries written in this file are either the cir::FuncOp itself
/// or a cir::GlobalOp created for it (see getKernelHandle).
39 llvm::StringMap<mlir::Operation *> kernelHandles;
/// Maps a kernel handle to the function registered for it. Written by
/// getKernelHandle and read by getKernelStub.
42 llvm::DenseMap<mlir::Operation *, mlir::Operation *> kernelStubs;
/// Mangle context obtained from ASTContext::cudaNVInitDeviceMC() in the
/// constructor and owned by this runtime object.
/// NOTE(review): presumably used for device-side name mangling; no use is
/// visible in this listing.
44 std::unique_ptr<MangleContext> deviceMC;
/// Emits the stub body for \p fn using the prefixed "LaunchKernel" API: the
/// kernel arguments in \p args are packed via prepareKernelArgs and a call
/// argument list for the launch function is assembled.
47 void emitDeviceStubBodyNew(CIRGenFunction &cgf, cir::FuncOp fn,
48 FunctionArgList &args);
/// Builds a "kernel_args" array of void pointers with one slot per entry in
/// \p args, stores a void-pointer value for each argument into its slot, and
/// returns the array decayed to a pointer.
49 mlir::Value prepareKernelArgs(CIRGenFunction &cgf, mlir::Location loc,
50 FunctionArgList &args);
/// Looks up the kernel handle recorded for \p fn and, where required,
/// registers one in kernelHandles / kernelStubs. See the out-of-line
/// definition for the individual cases.
51 mlir::Operation *getKernelHandle(cir::FuncOp fn, GlobalDecl gd)
override;
/// Looks up the entry recorded for \p handle in kernelStubs. The handle must
/// already be registered; this is enforced by an assert only.
/// NOTE(review): the return statement and closing brace are not part of this
/// listing; presumably the function returns it->second.
53 mlir::Operation *getKernelStub(mlir::Operation *handle)
override {
54 auto it = kernelStubs.find(handle);
55 assert(it != kernelStubs.end());
/// Returns \p funcName with `prefix` prepended, i.e. prefix + funcName.
/// (`prefix` is declared outside this listing.)
58 std::string addPrefixToName(StringRef funcName)
const;
/// Returns "__" + prefix + \p funcName, i.e. the prefixed name with two
/// leading underscores. (`prefix` is declared outside this listing.)
59 std::string addUnderscoredPrefixToName(StringRef funcName)
const;
/// Constructs the runtime for the module \p cgm; the definition forwards
/// \p cgm to the CIRGenCUDARuntime base and initializes deviceMC.
62 CIRGenNVCUDARuntime(CIRGenModule &cgm);
/// Destructor; defined out of line with an empty body.
63 ~CIRGenNVCUDARuntime();
/// Emits the host-side device stub for \p fn. If the recorded kernel handle
/// is a cir::GlobalOp, its attributes are updated to refer to \p fn, and the
/// body is then emitted through emitDeviceStubBodyNew when the new launch
/// API is selected.
65 void emitDeviceStub(CIRGenFunction &cgf, cir::FuncOp fn,
66 FunctionArgList &args)
override;
// Returns funcName with `prefix` prepended. The concatenation is converted to
// an owning std::string via .str() before being returned.
71std::string CIRGenNVCUDARuntime::addPrefixToName(StringRef funcName)
const {
72 return (prefix + funcName).str();
// Returns "__" + prefix + funcName as an owning std::string.
// (The return-type line of this definition is not part of this listing.)
76CIRGenNVCUDARuntime::addUnderscoredPrefixToName(StringRef funcName)
const {
77 return (
"__" + prefix + funcName).str();
// Forwards cgm to the CIRGenCUDARuntime base and takes the mangle context
// from the AST context's cudaNVInitDeviceMC().
// NOTE(review): the condition guarding the "Offload via LLVM" not-yet-
// implemented diagnostic below is not part of this listing.
80CIRGenNVCUDARuntime::CIRGenNVCUDARuntime(CIRGenModule &cgm)
81 : CIRGenCUDARuntime(cgm),
82 deviceMC(cgm.getASTContext().cudaNVInitDeviceMC()) {
84 cgm.
errorNYI(
"CIRGenNVCUDARuntime: Offload via LLVM");
// Packs the kernel arguments into an array of void pointers named
// "kernel_args" (one slot per entry in args) and returns that array decayed
// to a pointer.
// NOTE(review): several lines of this definition (remaining parameters, the
// start of the allocation call, the computation of storePos / argAsVoid) are
// not part of this listing.
91mlir::Value CIRGenNVCUDARuntime::prepareKernelArgs(
CIRGenFunction &cgf,
97 auto voidPtrArrayTy = cir::ArrayType::get(cgm.
voidPtrTy, args.size());
99 loc, cir::PointerType::get(voidPtrArrayTy), voidPtrArrayTy,
"kernel_args",
// Array-to-pointer decay of the storage; this is the value returned below.
102 mlir::Value kernelArgsDecayed =
103 builder.
createCast(cir::CastKind::array_to_ptrdecay, kernelArgs,
// Visit the arguments together with their index and store one void-pointer
// value per argument at its store position.
106 for (
const auto &[i, arg] : llvm::enumerate(args)) {
109 mlir::Value storePos =
114 builder.CIRBaseBuilderTy::createStore(loc, argAsVoid, storePos);
117 return kernelArgsDecayed;
// Emits the stub body that launches the kernel through the prefixed
// "LaunchKernel" API. Visible steps: pack the kernel arguments, find the
// declaration of the launch function by name, create storage for the launch
// configuration, materialize the kernel handle as a value, and assemble the
// call argument list.
// NOTE(review): many lines of this definition (conditions, call sites,
// closing braces) are not part of this listing; the comments below describe
// only what the visible lines show.
122void CIRGenNVCUDARuntime::emitDeviceStubBodyNew(CIRGenFunction &cgf,
124 FunctionArgList &args) {
128 cgm.
errorNYI(
"CIRGenNVCUDARuntime: Offload via LLVM");
131 mlir::Location loc = fn.getLoc();
136 mlir::Value kernelArgs = prepareKernelArgs(cgf, loc, args);
// Base name of the launch API; the runtime prefix is added further down.
153 StringRef kernelLaunchAPI =
"LaunchKernel";
// A per-thread default stream is reported as not yet implemented.
155 LangOptions::GPUDefaultStreamKind::PerThread)
156 cgm.
errorNYI(
"CUDA/HIP Stream per thread");
158 std::string launchKernelName = addPrefixToName(kernelLaunchAPI);
// Look the launch function up by identifier and keep a FunctionDecl found
// among the results.
159 const IdentifierInfo &launchII =
161 FunctionDecl *cudaLaunchKernelFD =
nullptr;
162 for (NamedDecl *result : dc->
lookup(&launchII)) {
163 if (FunctionDecl *fd = dyn_cast<FunctionDecl>(result))
164 cudaLaunchKernelFD = fd;
// No declaration found: the message below names the missing function (the
// call it is passed to is not part of this listing).
167 if (cudaLaunchKernelFD ==
nullptr) {
169 "Can't find declaration for " + launchKernelName);
// Storage for the launch configuration: two dim3 slots, sharedMem and a
// stream slot. Their types are then used together with the
// "__<prefix>PopCallConfiguration" name.
// NOTE(review): presumably that runtime function is called to fill the
// slots — confirm against the full source.
186 builder.
createAlloca(loc, cir::PointerType::get(dim3Ty), dim3Ty,
189 builder.
createAlloca(loc, cir::PointerType::get(dim3Ty), dim3Ty,
191 mlir::Value sharedMem =
195 builder.
createAlloca(loc, cir::PointerType::get(streamTy), streamTy,
200 sharedMem.getType(), stream.getType()},
202 addUnderscoredPrefixToName(
"PopCallConfiguration"));
// Materialize the kernel handle recorded for fn as a value. The handle may be
// a cir::GlobalOp or a cir::FuncOp; anything else is treated as unreachable.
// dyn_cast_or_null is used because the map lookup can yield a null pointer.
213 mlir::Value kernel = [&]() -> mlir::Value {
214 if (cir::GlobalOp globalOp = llvm::dyn_cast_or_null<cir::GlobalOp>(
215 kernelHandles[fn.getSymName()])) {
216 cir::PointerType kernelTy = cir::PointerType::get(globalOp.getSymType());
217 mlir::Value kernelVal = cir::GetGlobalOp::create(builder, loc, kernelTy,
218 globalOp.getSymName());
222 if (cir::FuncOp funcOp = llvm::dyn_cast_or_null<cir::FuncOp>(
223 kernelHandles[fn.getSymName()])) {
224 cir::PointerType kernelTy =
225 cir::PointerType::get(funcOp.getFunctionType());
226 mlir::Value kernelVal =
227 cir::GetGlobalOp::create(builder, loc, kernelTy, funcOp.getSymName());
231 llvm_unreachable(
"Expected stub handle to be cir::GlobalOp or FuncOp");
// Argument list for the launch call; the loaded sharedMem and stream values
// are among the arguments added.
234 CallArgList launchArgs;
246 RValue::get(builder.CIRBaseBuilderTy::createLoad(loc, sharedMem)),
248 launchArgs.
add(
RValue::get(builder.CIRBaseBuilderTy::createLoad(loc, stream)),
251 mlir::Type launchTy =
255 const CIRGenFunctionInfo &callInfo =
258 ReturnValueSlot(), launchArgs);
// NOTE(review): the guard for this diagnostic is not part of this listing.
262 cgm.
errorNYI(
"MSVC CUDA stub handling");
// Emits the device stub for fn. When the handle recorded for fn is a
// cir::GlobalOp, the global is given an initial value that refers to fn (a
// GlobalViewAttr on fn's symbol), its "sym_visibility" attribute is removed
// and its "alignment" attribute is set. The body is then emitted with
// emitDeviceStubBodyNew when the CUDA_USES_NEW_LAUNCH feature check (or a
// second condition not visible here) holds; "Emit Stub Body Legacy" is
// reported as not yet implemented, presumably on the other branch.
// NOTE(review): llvm::dyn_cast asserts on a null pointer, and operator[] on
// the map yields a null entry for an unknown key. This relies on a handle
// already being registered for fn — confirm, or use dyn_cast_or_null as
// emitDeviceStubBodyNew does.
265void CIRGenNVCUDARuntime::emitDeviceStub(CIRGenFunction &cgf, cir::FuncOp fn,
266 FunctionArgList &args) {
269 llvm::dyn_cast<cir::GlobalOp>(kernelHandles[fn.getSymName()])) {
271 mlir::Type fnPtrTy = globalOp.getSymType();
272 auto sym = mlir::FlatSymbolRefAttr::get(fn.getSymNameAttr());
273 auto gv = cir::GlobalViewAttr::get(fnPtrTy, sym);
275 globalOp->setAttr(
"initial_value", gv);
276 globalOp->removeAttr(
"sym_visibility");
277 globalOp->setAttr(
"alignment", builder.getI64IntegerAttr(
283 CudaFeature::CUDA_USES_NEW_LAUNCH) ||
286 emitDeviceStubBodyNew(cgf, fn, args);
288 cgm.
errorNYI(
"Emit Stub Body Legacy");
// Heap-allocates the NV CUDA runtime for cgm and returns it as a raw pointer,
// so the caller is responsible for the object's lifetime.
// NOTE(review): the enclosing function's signature is not part of this
// listing; presumably this is the body of createNVCUDARuntime.
292 return new CIRGenNVCUDARuntime(cgm);
// Empty out-of-line destructor: no explicit cleanup is performed here.
295CIRGenNVCUDARuntime::~CIRGenNVCUDARuntime() {}
// Looks up the handle recorded for fn's symbol name. When one exists, the
// kernelStubs entry for that handle is compared against fn and then either
// updated to fn or erased. A handle is registered in both maps as either fn
// itself or a newly created cir::GlobalOp (named globalName, typed with fn's
// function type, with an "alignment" attribute set).
// NOTE(review): the conditions selecting between these paths, the remaining
// parameter, and the return statements are not part of this listing —
// confirm against the full source.
297mlir::Operation *CIRGenNVCUDARuntime::getKernelHandle(cir::FuncOp fn,
301 auto it = kernelHandles.find(fn.getSymName());
302 if (it != kernelHandles.end()) {
303 mlir::Operation *oldHandle = it->second;
305 if (kernelStubs[oldHandle] == fn)
313 kernelStubs[oldHandle] = fn;
318 kernelStubs.erase(oldHandle);
// Case: the function serves as its own handle.
323 kernelHandles[fn.getSymName()] = fn;
324 kernelStubs[fn] = fn;
// Case: a separate global is created and recorded as the handle for fn.
333 cgm, fn.getLoc(), globalName, fn.getFunctionType(),
336 globalOp->setAttr(
"alignment", builder.getI64IntegerAttr(
340 kernelHandles[fn.getSymName()] = globalOp;
341 kernelStubs[globalOp] = fn;
Defines the clang::ASTContext interface.
Provides definitions for the various language-specific address spaces.
__CUDA_BUILTIN_VAR __cuda_builtin_blockDim_t blockDim
__CUDA_BUILTIN_VAR __cuda_builtin_gridDim_t gridDim
mlir::Value createCast(mlir::Location loc, cir::CastKind kind, mlir::Value src, mlir::Type newTy)
cir::PtrStrideOp createPtrStride(mlir::Location loc, mlir::Value base, mlir::Value stride)
mlir::Value createBitcast(mlir::Value src, mlir::Type newTy)
mlir::Value createAlloca(mlir::Location loc, cir::PointerType addrType, mlir::Type type, llvm::StringRef name, mlir::IntegerAttr alignment, mlir::Value dynAllocSize)
TranslationUnitDecl * getTranslationUnitDecl() const
const TargetInfo & getTargetInfo() const
mlir::Value getPointer() const
cir::ConstantOp getConstInt(mlir::Location loc, llvm::APSInt intVal)
static CIRGenCallee forDirect(mlir::Operation *funcPtr, const CIRGenCalleeInfo &abstractInfo=CIRGenCalleeInfo())
CIRGenTypes & getTypes() const
const clang::LangOptions & getLangOpts() const
const clang::Decl * curFuncDecl
Address getAddrOfLocalVar(const clang::VarDecl *vd)
Return the address of a local variable.
RValue emitCall(const CIRGenFunctionInfo &funcInfo, const CIRGenCallee &callee, ReturnValueSlot returnValue, const CallArgList &args, cir::CIRCallOpInterface *callOp, mlir::Location loc)
mlir::Value emitRuntimeCall(mlir::Location loc, cir::FuncOp callee, llvm::ArrayRef< mlir::Value > args={})
This class organizes the cross-function state that is used while generating CIR code.
llvm::StringRef getMangledName(clang::GlobalDecl gd)
DiagnosticBuilder errorNYI(SourceLocation, llvm::StringRef)
Helpers to emit "not yet implemented" error diagnostics.
clang::ASTContext & getASTContext() const
cir::FuncOp createRuntimeFunction(cir::FuncType ty, llvm::StringRef name, mlir::ArrayAttr={}, bool isLocal=false, bool assumeConvergent=false)
CIRGenBuilderTy & getBuilder()
const clang::TargetInfo & getTarget() const
void error(SourceLocation loc, llvm::StringRef error)
Emit a general error that something can't be done.
const clang::LangOptions & getLangOpts() const
static cir::GlobalOp createGlobalOp(CIRGenModule &cgm, mlir::Location loc, llvm::StringRef name, mlir::Type t, bool isConstant=false, mlir::Operation *insertPoint=nullptr)
const CIRGenFunctionInfo & arrangeFunctionDeclaration(const clang::FunctionDecl *fd)
Free functions are functions that are compatible with an ordinary C function pointer type.
mlir::Type convertType(clang::QualType type)
Convert a Clang type into a mlir::Type.
void add(RValue rvalue, clang::QualType type)
Type for representing both the decl and type of parameters to a function.
static RValue get(mlir::Value v)
static RValue getAggregate(Address addr, bool isVolatile=false)
Convert an Address to an RValue.
QuantityType getQuantity() const
getQuantity - Get the raw integer representation of this quantity.
static CharUnits fromQuantity(QuantityType Quantity)
fromQuantity - Construct a CharUnits quantity from a raw integer type.
lookup_result lookup(DeclarationName Name) const
lookup - Find the declarations (if any) with the given Name in this context.
SourceLocation getLocation() const
const ParmVarDecl * getParamDecl(unsigned i) const
GlobalDecl - represents a global declaration.
GlobalDecl getWithKernelReferenceKind(KernelReferenceKind Kind)
IdentifierInfo & get(StringRef Name)
Return the identifier token info for the specified named identifier.
GPUDefaultStreamKind GPUDefaultStream
The default stream kind used for HIP kernel launching.
bool isMicrosoft() const
Is this ABI an MSVC-compatible ABI?
TargetCXXABI getCXXABI() const
Get the C++ ABI currently in use.
const llvm::VersionTuple & getSDKVersion() const
static DeclContext * castToDeclContext(const TranslationUnitDecl *D)
CIRGenCUDARuntime * createNVCUDARuntime(CIRGenModule &cgm)
The JSON file list parser is used to communicate input to InstallAPI.
if(T->getSizeExpr()) TRY_TO(TraverseStmt(const_cast< Expr * >(T -> getSizeExpr())))
bool CudaFeatureEnabled(llvm::VersionTuple, CudaFeature)
U cast(CodeGen::Address addr)
clang::CharUnits getPointerAlign() const
clang::CharUnits getSizeAlign() const
cir::PointerType voidPtrTy
void* in address space 0