10#include "mlir/IR/Builders.h"
11#include "mlir/IR/Dominance.h"
17using namespace mlir::abi;
/// Checks a function classification for anything that differs from a plain
/// Direct pass-through (a non-Direct kind, or a coerced type being set) on the
/// return slot or on any argument.
// NOTE(review): the embedded listing numbers jump (31 -> 36 -> 38 -> 39), so
// the statements guarded by these conditions are not visible here; presumably
// they return true, followed by a final return false -- confirm in the full
// file.
31bool needsRewrite(
const FunctionClassification &fc) {
// Return value: flagged when it is not Direct or carries a coerced type.
36 if ((fc.returnInfo.kind != ArgKind::Direct) || fc.returnInfo.coercedType)
// Arguments: the same test, applied to every entry of fc.argInfos.
38 for (
const ArgClassification &ac : fc.argInfos)
39 if ((ac.kind != ArgKind::Direct) || ac.coercedType)
// Parameter list and body of the argument-type builder. The line carrying the
// function name and its first parameter is not visible in this listing;
// presumably this is buildNewArgTypes, which a later call invokes with
// (oldArgTypes, fc, newArgTypes, error callback) -- confirm.
//
// Fills newArgTypes from the per-argument classification in fc, reading the
// original types from oldArgTypes. Unsupported kinds are reported through
// emitError and make the function return mlir::failure(); otherwise it
// returns mlir::success().
51 const FunctionClassification &fc,
52 SmallVectorImpl<mlir::Type> &newArgTypes,
53 function_ref<mlir::InFlightDiagnostic()> emitError) {
// The output vector is built from scratch, so it must arrive empty.
54 assert(newArgTypes.empty() &&
"expected an empty output vector");
55 newArgTypes.reserve(oldArgTypes.size());
56 for (
auto [idx, ac] : llvm::enumerate(fc.argInfos)) {
57 mlir::Type origTy = oldArgTypes[idx];
// Prefer the coerced type when the classification supplies one, otherwise
// keep the original type. NOTE(review): the switch/case label that selects
// this line is not visible here.
64 newArgTypes.push_back(ac.coercedType ? ac.coercedType : origTy);
// Diagnostic for the Expand kind (its case label is not visible here); the
// message names the offending argument index.
69 emitError() <<
"Expand at arg " << idx
70 <<
" not yet implemented in CallConvLowering";
71 return mlir::failure();
// This branch keeps the original argument type unchanged.
// NOTE(review): its case label is not visible in this listing.
79 newArgTypes.push_back(origTy);
// Indirect arguments are rejected the same way as Expand.
81 case ArgKind::Indirect:
82 emitError() <<
"Indirect at arg " << idx
83 <<
" not yet implemented in CallConvLowering";
84 return mlir::failure();
// Reached when no argument hit one of the failure returns above.
87 return mlir::success();
/// Maps the original return type to the return type of the rewritten function,
/// switching on the classification kind in retInfo. Problems are reported via
/// emitError.
// NOTE(review): the return-type line of this signature and several case
// bodies are not visible in this listing (the embedded numbers skip).
95computeNewReturnType(mlir::Type origRetTy,
const ArgClassification &retInfo,
96 mlir::MLIRContext *ctx,
97 function_ref<mlir::InFlightDiagnostic()> emitError) {
98 switch (retInfo.kind) {
// Coerced type if one is set, else the original return type.
// NOTE(review): the case label for this line is not visible here.
102 return retInfo.coercedType ? retInfo.coercedType : origRetTy;
// An ignored return value turns the function into a void function.
103 case ArgKind::Ignore:
104 return cir::VoidType::get(ctx);
// Expand is diagnosed as invalid for return values; what is returned after
// the diagnostic is not visible in this listing.
105 case ArgKind::Expand:
106 emitError() <<
"Expand return is not allowed (classic codegen rejects "
107 <<
"it in EmitFunctionEpilog)";
// NOTE(review): the body of the Extend case is not visible here.
109 case ArgKind::Extend:
// The visible part of the Indirect case yields a void return type.
114 case ArgKind::Indirect:
119 return cir::VoidType::get(ctx);
// Every enumerator is expected to return from inside the switch.
121 llvm_unreachable(
"all ArgKind cases handled");
/// Creates a cir::ConstantOp of type `ty` holding a cir::PoisonAttr, using the
/// given builder and location, and returns it.
// NOTE(review): the `ty` parameter is declared on a signature line that is not
// visible in this listing.
129mlir::Value createIgnoredValue(mlir::OpBuilder &builder, mlir::Location loc,
131 return cir::ConstantOp::create(builder, loc, ty, cir::PoisonAttr::get(ty));
/// Builds a new per-argument attribute array from an existing one (which may
/// be null) and the argument classifications in fc. Arguments classified as
/// Extend get an `llvm.signext` or `llvm.zeroext` unit attribute added.
// NOTE(review): the declarations of `newArgAttrs` and `attrs`, and the
// statement guarded by the Ignore check, are not visible in this listing.
137mlir::ArrayAttr updateArgAttrs(mlir::MLIRContext *ctx,
138 mlir::ArrayAttr existingArgAttrs,
139 const FunctionClassification &fc) {
141 newArgAttrs.reserve(fc.argInfos.size());
142 for (
auto [oldIdx, ac] : llvm::enumerate(fc.argInfos)) {
// Ignore arguments are special-cased first; presumably they are skipped so
// no entry is emitted for them -- the guarded statement is not visible.
143 if (ac.kind == ArgKind::Ignore)
// Start from an empty dictionary and replace it with the existing entry for
// this argument when the incoming array has one at this index.
145 mlir::DictionaryAttr existing = mlir::DictionaryAttr::get(ctx);
146 if (existingArgAttrs && oldIdx < existingArgAttrs.size())
147 existing = mlir::cast<mlir::DictionaryAttr>(existingArgAttrs[oldIdx]);
148 if (ac.kind == ArgKind::Extend) {
// signExtend selects between the sign- and zero-extension attribute names.
149 StringRef attrName = ac.signExtend ?
"llvm.signext" :
"llvm.zeroext";
150 mlir::NamedAttribute extAttr(mlir::StringAttr::get(ctx, attrName),
151 mlir::UnitAttr::get(ctx));
// No prior attributes: the dictionary holds just the extension attribute.
152 if (existing.empty()) {
153 newArgAttrs.push_back(mlir::DictionaryAttr::get(ctx, {extAttr}));
// Otherwise the extension attribute is appended to `attrs` (presumably a
// copy of the existing entries -- its construction is not visible here).
157 attrs.push_back(extAttr);
158 newArgAttrs.push_back(mlir::DictionaryAttr::get(ctx, attrs));
// Non-Extend arguments carry their existing dictionary over unchanged.
161 newArgAttrs.push_back(existing);
164 return mlir::ArrayAttr::get(ctx, newArgAttrs);
/// Returns the result-attribute array for a return value classified by
/// retInfo. Anything other than Extend leaves the existing array untouched;
/// Extend produces a one-element array whose dictionary includes
/// `llvm.signext` or `llvm.zeroext`.
// NOTE(review): the declaration of `attrs` and the body of the copy loop are
// not visible in this listing.
170mlir::ArrayAttr updateResAttrs(mlir::MLIRContext *ctx,
171 mlir::ArrayAttr existingResAttrs,
172 const ArgClassification &retInfo) {
173 if (retInfo.kind != ArgKind::Extend)
174 return existingResAttrs;
// Walk the named attributes of the first (index 0) existing result entry;
// presumably each one is copied into `attrs` -- loop body not visible.
177 if (existingResAttrs && !existingResAttrs.empty())
178 for (mlir::NamedAttribute na :
179 mlir::cast<mlir::DictionaryAttr>(existingResAttrs[0]))
// signExtend selects between the sign- and zero-extension attribute names.
181 StringRef attrName = retInfo.signExtend ?
"llvm.signext" :
"llvm.zeroext";
182 attrs.push_back(mlir::NamedAttribute(mlir::StringAttr::get(ctx, attrName),
183 mlir::UnitAttr::get(ctx)));
184 return mlir::ArrayAttr::get(ctx, {mlir::DictionaryAttr::get(ctx, attrs)});
/// Converts `src` to `dstTy` by going through a stack slot: an alloca is
/// created at the start of funcOp's entry block, `src` is stored into it, and
/// a value of type `dstTy` is loaded back. The alloca, any pointer casts, the
/// store and the load are all inserted into `createdOps`.
///
/// Callers must ensure the source and destination types differ (asserted).
// NOTE(review): the declaration of `slotTy`, the assignments of the cast
// results to srcSlot/dstSlot, and the final return are not visible in this
// listing; presumably the loaded value is returned.
207mlir::Value emitCoercion(mlir::OpBuilder &builder, mlir::Location loc,
208 mlir::Type dstTy, mlir::Value src,
209 mlir::FunctionOpInterface funcOp,
210 const mlir::DataLayout &dl,
211 SmallPtrSetImpl<mlir::Operation *> &createdOps) {
212 mlir::Type srcTy = src.getType();
213 assert(srcTy != dstTy &&
214 "emitCoercion callers must pre-check that the types differ");
// The slot uses the larger of the two ABI alignments.
216 uint64_t srcAlign = dl.getTypeABIAlignment(srcTy);
217 uint64_t dstAlign = dl.getTypeABIAlignment(dstTy);
218 uint64_t allocaAlign = std::max(srcAlign, dstAlign);
// Picks the type with the larger size (the source type on a tie); the
// left-hand side of this initializer is not visible -- presumably slotTy.
220 dl.getTypeSize(srcTy) >= dl.getTypeSize(dstTy) ? srcTy : dstTy;
222 auto slotPtrTy = cir::PointerType::get(slotTy);
223 auto srcPtrTy = cir::PointerType::get(srcTy);
224 auto dstPtrTy = cir::PointerType::get(dstTy);
226 cir::AllocaOp alloca;
// The guard restores the caller's insertion point once the alloca has been
// placed at the start of the entry block.
228 mlir::OpBuilder::InsertionGuard guard(builder);
229 mlir::Block &entry = funcOp->getRegion(0).front();
230 builder.setInsertionPointToStart(&entry);
231 alloca = cir::AllocaOp::create(builder, loc, slotPtrTy,
232 builder.getStringAttr(
"coerce"),
233 builder.getI64IntegerAttr(allocaAlign));
235 createdOps.insert(alloca);
// Store side: bitcast the slot pointer when the slot type is not the source
// type, then store the source value.
238 mlir::Value srcSlot = alloca;
239 if (slotTy != srcTy) {
240 auto srcCast = cir::CastOp::create(builder, loc, srcPtrTy,
241 cir::CastKind::bitcast, alloca);
242 createdOps.insert(srcCast);
245 auto store = cir::StoreOp::create(builder, loc, src, srcSlot);
246 createdOps.insert(store);
// Load side: bitcast the slot pointer when the slot type is not the
// destination type, then load the destination value.
249 mlir::Value dstSlot = alloca;
250 if (slotTy != dstTy) {
251 auto dstCast = cir::CastOp::create(builder, loc, dstPtrTy,
252 cir::CastKind::bitcast, alloca);
253 createdOps.insert(dstCast);
256 auto load = cir::LoadOp::create(builder, loc, dstSlot);
257 createdOps.insert(load);
/// Convenience overload for callers that do not need the set of created
/// operations: forwards to the seven-argument emitCoercion with a local,
/// discarded set.
263mlir::Value emitCoercion(mlir::OpBuilder &builder, mlir::Location loc,
264 mlir::Type dstTy, mlir::Value src,
265 mlir::FunctionOpInterface funcOp,
266 const mlir::DataLayout &dl) {
// Scratch set; its contents are dropped when this function returns.
267 SmallPtrSet<mlir::Operation *, 4> ignored;
268 return emitCoercion(builder, loc, dstTy, src, funcOp, dl, ignored);
/// Rewrites the cir.return ops inside funcOp so that their first operand has
/// type coercedRetTy, inserting a coercion immediately before each return
/// that needs one.
// NOTE(review): the declaration of `returns` and the statements guarded by
// the two early checks in the loop are not visible in this listing;
// presumably both checks skip the current return.
273void insertReturnCoercion(mlir::FunctionOpInterface funcOp,
274 mlir::Type origRetTy, mlir::Type coercedRetTy,
275 mlir::OpBuilder &builder,
276 const mlir::DataLayout &dl) {
// Collect the returns first, then rewrite them in a separate loop.
278 funcOp.walk([&](cir::ReturnOp r) { returns.push_back(r); });
279 for (cir::ReturnOp r : returns) {
// Returns without an operand have nothing to coerce.
280 if (r.getInput().empty())
282 mlir::Value origVal = r.getInput()[0];
// The operand already has the coerced type.
283 if (origVal.getType() == coercedRetTy)
// Emit the coercion right before the return and make it the new operand.
285 builder.setInsertionPoint(r);
286 mlir::Value coerced =
287 emitCoercion(builder, r.getLoc(), coercedRetTy, origVal, funcOp, dl);
288 r->setOperand(0, coerced);
/// For every argument classified as Direct with a coerced type, retypes the
/// matching entry-block argument to the coerced type and routes its existing
/// uses through a coercion back to the original type.
// NOTE(review): the `hasSRetArg` parameter and the statements guarded by the
// early checks in the loop are not visible in this listing; presumably each
// check skips the current argument.
301void insertArgCoercion(mlir::FunctionOpInterface funcOp,
302 const FunctionClassification &fc,
303 mlir::OpBuilder &builder,
const mlir::DataLayout &dl,
305 mlir::Region &body = funcOp->getRegion(0);
308 mlir::Block &entry = body.front();
310 for (
auto [idx, ac] : llvm::enumerate(fc.argInfos)) {
// Only Direct arguments that carry a coerced type are handled here.
311 if (ac.kind != ArgKind::Direct || !ac.coercedType)
// Block-argument index = classification index shifted by hasSRetArg.
313 unsigned blockIdx = idx + hasSRetArg;
// Bounds check against the entry block's argument list.
314 if (
blockIdx >= entry.getNumArguments())
317 mlir::BlockArgument blockArg = entry.getArgument(
blockIdx);
318 mlir::Type oldArgTy = blockArg.getType();
319 mlir::Type newArgTy = ac.coercedType;
// Nothing to do when the argument already has the coerced type.
320 if (oldArgTy == newArgTy)
323 blockArg.setType(newArgTy);
// Rebuild a value of the old type at the top of the entry block.
325 builder.setInsertionPointToStart(&entry);
326 SmallPtrSet<mlir::Operation *, 4> coercionOps;
327 mlir::Value adapted = emitCoercion(builder, funcOp.getLoc(), oldArgTy,
328 blockArg, funcOp, dl, coercionOps);
// The ops created by the coercion must keep reading the retyped block
// argument, so they are excluded from the replacement.
334 blockArg.replaceAllUsesExcept(adapted, coercionOps);
/// Rewrites the returns of a function whose first argument is used as the
/// sret pointer: uses of the alloca that feeds the returned load are
/// redirected to that pointer, and an operand-less cir.return is created
/// before each original return.
// NOTE(review): many lines of this body are not visible in this listing (the
// embedded numbers skip), including the declaration of `returnOps`, whatever
// conditions guard the alloca handling, and what happens to the original
// return and load afterwards.
366void insertSRetStores(mlir::FunctionOpInterface funcOp, mlir::Type origRetTy,
367 mlir::OpBuilder &builder) {
// The sret pointer is taken from function argument 0.
368 mlir::Value sretPtr = funcOp.getArguments()[0];
371 funcOp->walk([&](cir::ReturnOp retOp) { returnOps.push_back(retOp); });
373 cir::AllocaOp retAlloca =
nullptr;
374 for (cir::ReturnOp retOp : returnOps) {
377 assert(!retOp.getInput().empty() &&
378 "cir.return in sret function must have an operand");
// mlir::cast asserts that the returned value is produced by a cir.load.
380 cir::LoadOp retLoad =
381 mlir::cast<cir::LoadOp>(retOp.getInput()[0].getDefiningOp());
// Likewise the load's address is required to come from a cir.alloca; every
// use of that alloca is then redirected to the sret pointer.
389 retAlloca = mlir::cast<cir::AllocaOp>(retLoad.getAddr().getDefiningOp());
390 retAlloca.getResult().replaceAllUsesWith(sretPtr);
// Create the operand-less return immediately before the original one.
396 builder.setInsertionPoint(retOp);
397 cir::ReturnOp::create(builder, retOp.getLoc());
// Presumably the load is removed once it has no remaining uses -- the
// guarded statement is not visible here.
399 if (retLoad.use_empty())
// Fragments of the code that builds the attribute list for an sret slot. The
// enclosing signature is not visible in this listing; presumably this is
// buildSretSlotAttrs, which applySretSlotAttrs calls with
// (builder, retTy, align, bool) -- confirm.
// NOTE(review): only the line at 424 shows its `attrs.push_back(` prefix; the
// other entries presumably go through the same call, and any conditions
// selecting them are not visible here.
// llvm.sret carries the returned type as a TypeAttr.
418 builder.getNamedAttr(
"llvm.sret", mlir::TypeAttr::get(retTy)));
// llvm.align carries `align` as a 64-bit integer attribute.
420 builder.getNamedAttr(
"llvm.align", builder.getI64IntegerAttr(align)));
// The remaining entries are unit attributes.
423 builder.getNamedAttr(
"llvm.noalias", builder.getUnitAttr()));
424 attrs.push_back(builder.getNamedAttr(
"llvm.writable", builder.getUnitAttr()));
426 builder.getNamedAttr(
"llvm.dead_on_unwind", builder.getUnitAttr()));
/// Sets the "arg_attrs" attribute on a rewritten call whose first operand is
/// the sret slot: entry 0 is the sret attribute dictionary, followed by the
/// entries of `argAttrs`, padded with empty dictionaries up to the call's
/// operand count.
// NOTE(review): the declarations of `sretAttrs` and `newArgAttrs`, and any
// guard in front of the append (argAttrs may be null at the call site), are
// not visible in this listing.
436void applySretSlotAttrs(cir::CallOp newCall, mlir::ArrayAttr argAttrs,
437 mlir::Type retTy, uint64_t align,
438 mlir::OpBuilder &builder) {
439 mlir::MLIRContext *ctx = newCall->getContext();
// Call sites pass `false` as the last argument, while the function
// definition path passes funcOp.isDefinition().
441 buildSretSlotAttrs(builder, retTy, align,
false);
444 newArgAttrs.reserve(newCall.getArgOperands().size());
// Slot 0 belongs to the sret pointer operand.
445 newArgAttrs.push_back(mlir::DictionaryAttr::get(ctx, sretAttrs));
447 llvm::append_range(newArgAttrs, argAttrs);
448 assert(newArgAttrs.size() <= newCall.getArgOperands().size() &&
449 "arg_attrs wider than the rewritten call's operand list");
// Pad with empty dictionaries so there is one entry per operand.
450 newArgAttrs.resize(newCall.getArgOperands().size(),
451 mlir::DictionaryAttr::get(ctx));
452 newCall->setAttr(
"arg_attrs", mlir::ArrayAttr::get(ctx, newArgAttrs));
/// Rewrites a call whose result is returned indirectly: a void call is
/// created that takes a pointer to a result slot as its first operand
/// followed by `newArgs`, and the original result is replaced by a load from
/// that slot placed after the new call.
// NOTE(review): the `newArgs` parameter, the declaration of `sretArgs`, the
// rest of the slot-reuse branch, the condition guarding the fresh alloca, and
// any erasure of the old call are not visible in this listing.
462void rewriteIndirectReturnCall(cir::CallOp call,
463 const FunctionClassification &fc,
465 mlir::Type origRetTy, mlir::OpBuilder &builder) {
466 mlir::MLIRContext *ctx = call->getContext();
467 auto ptrTy = cir::PointerType::get(origRetTy);
468 builder.setInsertionPoint(call);
// .value() requires the classification to carry an indirect alignment.
469 uint64_t sretAlign = fc.returnInfo.indirectAlign.value();
482 mlir::Value sretSlot =
nullptr;
483 cir::StoreOp reuseStore =
nullptr;
// Slot reuse: when the call result's only user is a store of that result to
// an address of the matching pointer type that properly dominates the call,
// that address serves as the sret slot.
484 if (call.getResult().hasOneUse()) {
485 mlir::Operation *user = *call.getResult().getUsers().begin();
486 if (
auto store = mlir::dyn_cast<cir::StoreOp>(user))
487 if (store.getValue() == call.getResult() &&
488 store.getAddr().getType() == ptrTy &&
489 mlir::DominanceInfo().properlyDominates(store.getAddr(), call)) {
490 sretSlot = store.getAddr();
// Fresh slot named "sret"; presumably used only when no reusable store was
// found -- the guarding condition is not visible here.
495 auto alloca = cir::AllocaOp::create(
496 builder, call.getLoc(), ptrTy,
497 builder.getStringAttr(
"sret"),
498 builder.getI64IntegerAttr(sretAlign));
// Operand list of the new call: slot pointer first, then the arguments.
503 sretArgs.push_back(sretSlot);
504 sretArgs.append(newArgs.begin(), newArgs.end());
506 mlir::Type sretVoidTy = cir::VoidType::get(ctx);
507 auto newCall = cir::CallOp::create(
508 builder, call.getLoc(), call.getCalleeAttr(), sretVoidTy, sretArgs);
// Carry over every attribute of the old call that the new call lacks.
509 for (mlir::NamedAttribute attr : call->getAttrs())
510 if (!newCall->hasAttr(
attr.getName()))
511 newCall->setAttr(
attr.getName(),
attr.getValue());
// Argument attributes are rebuilt only when some argument is Ignore or
// Extend; the sret slot's attributes are then placed in front of them.
516 mlir::ArrayAttr argAttrs = call->getAttrOfType<mlir::ArrayAttr>(
"arg_attrs");
517 bool needsArgAttrUpdate =
518 llvm::any_of(fc.argInfos, [](
const ArgClassification &ac) {
519 return ac.kind == ArgKind::Ignore || ac.kind == ArgKind::Extend;
521 if (needsArgAttrUpdate)
522 argAttrs = updateArgAttrs(ctx, argAttrs, fc);
523 applySretSlotAttrs(newCall, argAttrs, origRetTy, sretAlign, builder);
// Read the result back from the slot right after the new call and hand it
// to the users of the original call result.
531 builder.setInsertionPointAfter(newCall);
532 auto load = cir::LoadOp::create(builder, call.getLoc(), origRetTy, sretSlot,
536 cir::SyncScopeKindAttr(),
537 cir::MemOrderAttr());
538 call.getResult().replaceAllUsesWith(load);
// Parameter list and body of the function-level rewrite. The line carrying
// the function name is not visible in this listing; presumably this is
// rewriteFunctionDefinition -- confirm.
//
// Visible steps: bail out when no rewrite is needed, compute the new argument
// and return types, adapt the body of a definition (sret argument, argument
// and return coercions, ignored arguments, ignored return value), install the
// new function type, and update "arg_attrs" / "res_attrs".
// NOTE(review): a large share of this body is not visible here (the embedded
// numbers skip), including the declarations of oldArgTypes, oldResultTypes,
// newArgTypes, newResultTypes, dl, hasSRet, returns, poison, sretAttrs and
// withSret, several guarding conditions, and any erase calls.
546 mlir::FunctionOpInterface funcOpInterface,
const FunctionClassification &fc,
547 mlir::OpBuilder &builder) {
// mlir::cast asserts that the operation is a cir::FuncOp.
553 cir::FuncOp funcOp = mlir::cast<cir::FuncOp>(funcOpInterface);
// Nothing to do for functions whose classification needs no rewrite.
555 if (!needsRewrite(fc))
556 return mlir::success();
560 mlir::MLIRContext *ctx = funcOp->getContext();
565 assert(oldResultTypes.size() <= 1 &&
566 "CIR functions return zero or one value");
// Diagnostics from the type builder are emitted on the function op.
569 if (mlir::failed(buildNewArgTypes(oldArgTypes, fc, newArgTypes,
570 [&]() {
return funcOp.emitOpError(); })))
571 return mlir::failure();
// A function without results is treated as returning void.
573 mlir::Type voidTy = cir::VoidType::get(ctx);
574 mlir::Type origRetTy = oldResultTypes.empty() ? voidTy : oldResultTypes[0];
575 mlir::Type newRetTy = computeNewReturnType(
576 origRetTy, fc.returnInfo, ctx, [&]() { return funcOp.emitOpError(); });
// Presumably taken when no new return type could be computed -- the
// guarding condition is not visible here.
578 return mlir::failure();
// Indirect return of a non-void result; presumably the initializer of
// hasSRet -- the left-hand side is not visible here.
588 fc.returnInfo.kind == ArgKind::Indirect && !oldResultTypes.empty();
// The sret pointer becomes the first argument type; presumably guarded by
// hasSRet -- the condition is not visible here.
590 newArgTypes.insert(newArgTypes.begin(), cir::PointerType::get(origRetTy));
// Body adaptations apply to definitions only.
592 if (funcOp.isDefinition()) {
593 mlir::Region &body = funcOp->getRegion(0);
// Add the matching sret block argument at index 0, then rewrite the returns.
599 body.front().insertArgument(0u, cir::PointerType::get(origRetTy),
601 insertSRetStores(funcOp, origRetTy, builder);
611 insertArgCoercion(funcOp, fc, builder, dl, hasSRet);
// Return coercion applies to a Direct, non-void return whose coerced type
// differs from the original type.
616 if (fc.returnInfo.kind == ArgKind::Direct && fc.returnInfo.coercedType &&
617 !oldResultTypes.empty() && fc.returnInfo.coercedType != origRetTy)
618 insertReturnCoercion(funcOp, origRetTy, fc.returnInfo.coercedType,
621 mlir::Block &entry = body.front();
// Ignored arguments are visited from last to first; presumably each block
// argument is erased afterwards, which would make the reverse order matter
// -- the erase call is not visible here.
629 for (
int argInfoIdx =
static_cast<int>(fc.argInfos.size()) - 1;
630 argInfoIdx >= 0; --argInfoIdx) {
631 if (fc.argInfos[argInfoIdx].kind != ArgKind::Ignore)
// Same index shift as in insertArgCoercion.
633 unsigned blockIdx =
static_cast<unsigned>(argInfoIdx) + hasSRet;
634 if (
blockIdx >= entry.getNumArguments())
636 mlir::BlockArgument arg = entry.getArgument(
blockIdx);
// Remaining uses of an ignored argument are redirected to a poison value
// created at the start of the entry block.
637 if (!arg.use_empty()) {
638 builder.setInsertionPointToStart(&entry);
640 createIgnoredValue(builder, funcOp.getLoc(), arg.getType());
641 arg.replaceAllUsesWith(poison);
// Ignored return value: an operand-less return is created before every
// return that still has an operand.
653 if (fc.returnInfo.kind == ArgKind::Ignore && !oldResultTypes.empty()) {
654 assert(mlir::isa<cir::VoidType>(newRetTy) &&
655 "Ignore-return path requires the new return type to be void");
657 funcOp.walk([&](cir::ReturnOp r) { returns.push_back(r); });
658 for (cir::ReturnOp r : returns) {
659 if (r.getNumOperands() == 0)
661 builder.setInsertionPoint(r);
662 cir::ReturnOp::create(builder, r.getLoc());
// Install the rewritten signature on the function.
668 mlir::Type newFnTy = funcOp.cloneTypeWith(newArgTypes, newResultTypes);
669 funcOp.setFunctionTypeAttr(mlir::TypeAttr::get(newFnTy));
// Argument attributes are rebuilt when an sret argument was added or when
// any argument is Ignore or Extend.
674 bool needsArgAttrUpdate =
675 hasSRet || llvm::any_of(fc.argInfos, [](
const ArgClassification &ac) {
676 return ac.kind == ArgKind::Ignore || ac.kind == ArgKind::Extend;
678 if (needsArgAttrUpdate) {
679 auto existing = funcOp->getAttrOfType<mlir::ArrayAttr>(
"arg_attrs");
680 mlir::ArrayAttr updated = updateArgAttrs(ctx, existing, fc);
// Arguments of the sret attribute builder (callee name not visible here);
// .value() requires the classification to carry an indirect alignment.
686 builder, origRetTy, fc.returnInfo.indirectAlign.value(),
687 funcOp.isDefinition());
// With an sret argument its dictionary is placed first, followed by the
// updated per-argument entries.
689 withSret.push_back(mlir::DictionaryAttr::get(ctx, sretAttrs));
690 llvm::append_range(withSret, updated);
691 funcOp->setAttr(
"arg_attrs", mlir::ArrayAttr::get(ctx, withSret));
// Presumably the no-sret branch -- the `else` is not visible here.
693 funcOp->setAttr(
"arg_attrs", updated);
// An Extend return adds the extension attribute to "res_attrs".
699 if (fc.returnInfo.kind == ArgKind::Extend) {
700 auto existing = funcOp->getAttrOfType<mlir::ArrayAttr>(
"res_attrs");
701 funcOp->setAttr(
"res_attrs", updateResAttrs(ctx, existing, fc.returnInfo));
704 return mlir::success();
// Parameter list and body of the call-site rewrite. The line carrying the
// function name and the `callOp` parameter is not visible in this listing;
// presumably this is rewriteCallSite -- confirm.
//
// Visible steps: reject unsupported call forms and argument kinds with a
// diagnostic, build the new operand list (dropping Ignore arguments and
// coercing Direct ones), then create a replacement call and rewire the users
// of the original result according to the return classification.
// NOTE(review): a large share of this body is not visible here (the embedded
// numbers skip), including the declarations of newArgs and poison, the switch
// header, the trailing arguments of the emitCoercion calls, several guarded
// statements, and any erasure of the original call.
709 const FunctionClassification &fc,
710 mlir::OpBuilder &builder) {
711 if (!needsRewrite(fc))
712 return mlir::success();
// cir::TryCallOp is not supported yet.
714 if (mlir::isa<cir::TryCallOp>(callOp))
715 return callOp->emitOpError()
716 <<
"TryCallOp not yet implemented in CallConvLowering";
// Anything else is required to be a cir::CallOp (mlir::cast asserts this);
// indirect calls are not supported yet.
718 auto call = mlir::cast<cir::CallOp>(callOp);
719 if (call.isIndirect())
720 return call.emitOpError()
721 <<
"indirect call not yet implemented in CallConvLowering";
723 mlir::MLIRContext *ctx = callOp->getContext();
724 auto enclosingFunc = call->getParentOfType<mlir::FunctionOpInterface>();
// First pass over the arguments: reject unsupported kinds before anything is
// created. The switch header is not visible here.
726 for (
auto [idx, ac] : llvm::enumerate(fc.argInfos)) {
728 case ArgKind::Direct:
729 case ArgKind::Ignore:
731 case ArgKind::Expand:
732 return call.emitOpError() <<
"Expand at call-site arg " << idx
733 <<
" not yet implemented in CallConvLowering";
734 case ArgKind::Extend:
738 case ArgKind::Indirect:
739 return call.emitOpError() <<
"Indirect at call-site arg " << idx
740 <<
" not yet implemented in CallConvLowering";
744 builder.setInsertionPoint(call);
747 mlir::ValueRange argOperands = call.getArgOperands();
748 newArgs.reserve(argOperands.size());
// More operands than classified arguments is reported as an unsupported
// variadic call; after that the two counts must match exactly.
749 if (argOperands.size() > fc.argInfos.size())
750 return call.emitOpError()
751 <<
"variadic arguments not yet implemented in CallConvLowering";
752 assert(fc.argInfos.size() == argOperands.size() &&
753 "call operand count must match classified arg count");
// Second pass: build the new operand list.
754 for (
auto [idx, ac] : llvm::enumerate(fc.argInfos)) {
// Presumably Ignore arguments are dropped from the operand list -- the
// guarded statement is not visible here.
755 if (ac.kind == ArgKind::Ignore)
757 mlir::Value arg = argOperands[idx];
// Direct arguments whose type differs from the coerced type are coerced.
758 if (ac.kind == ArgKind::Direct && ac.coercedType &&
759 arg.getType() != ac.coercedType)
760 arg = emitCoercion(builder, call.getLoc(), ac.coercedType, arg,
762 newArgs.push_back(arg);
// A call without results is treated as returning void.
765 bool hasResult = call.getNumResults() > 0;
766 mlir::Type origRetTy =
767 hasResult ? call.getResult().getType() : cir::VoidType::get(ctx);
// Indirect returns take the dedicated sret path and finish here.
773 if (fc.returnInfo.kind == ArgKind::Indirect && hasResult) {
774 rewriteIndirectReturnCall(call, fc, newArgs, origRetTy, builder);
775 return mlir::success();
// Result type of the new call: void for an ignored result, the coerced type
// when the return needs coercion, otherwise the original type.
778 mlir::Type callRetTy = origRetTy;
779 if (fc.returnInfo.kind == ArgKind::Ignore && hasResult)
780 callRetTy = cir::VoidType::get(ctx);
781 bool returnNeedsCoercion =
782 hasResult && fc.returnInfo.kind == ArgKind::Direct &&
783 fc.returnInfo.coercedType && fc.returnInfo.coercedType != origRetTy;
784 if (returnNeedsCoercion)
785 callRetTy = fc.returnInfo.coercedType;
787 builder.setInsertionPoint(call);
788 auto newCall = cir::CallOp::create(builder, call.getLoc(),
789 call.getCalleeAttr(), callRetTy, newArgs);
// Carry over every attribute of the old call that the new call lacks.
790 for (mlir::NamedAttribute attr : call->getAttrs())
791 if (!newCall->hasAttr(attr.getName()))
792 newCall->setAttr(attr.getName(), attr.getValue());
// Coerce the new call's result back to the original type for existing users.
796 if (returnNeedsCoercion) {
797 builder.setInsertionPointAfter(newCall);
798 mlir::Value coercedBack =
799 emitCoercion(builder, call.getLoc(), origRetTy, newCall.getResult(),
801 call.getResult().replaceAllUsesWith(coercedBack);
// Argument attributes are rebuilt only when some argument is Ignore or
// Extend; they are read from the old call and set on the new one.
807 bool needsArgAttrUpdate =
808 llvm::any_of(fc.argInfos, [](
const ArgClassification &ac) {
809 return ac.kind == ArgKind::Ignore || ac.kind == ArgKind::Extend;
811 if (needsArgAttrUpdate) {
812 auto existing = call->getAttrOfType<mlir::ArrayAttr>(
"arg_attrs");
813 newCall->setAttr(
"arg_attrs", updateArgAttrs(ctx, existing, fc));
// An Extend return adds the extension attribute to "res_attrs".
815 if (fc.returnInfo.kind == ArgKind::Extend) {
816 auto existing = call->getAttrOfType<mlir::ArrayAttr>(
"res_attrs");
817 newCall->setAttr(
"res_attrs", updateResAttrs(ctx, existing, fc.returnInfo));
// Ignored result: any remaining users receive a poison value created after
// the new call.
820 if (hasResult && fc.returnInfo.kind == ArgKind::Ignore) {
825 if (!call.getResult().use_empty()) {
826 builder.setInsertionPointAfter(newCall);
828 createIgnoredValue(builder, call.getLoc(), origRetTy);
829 call.getResult().replaceAllUsesWith(poison);
831 }
// Result kept as-is: users are redirected straight to the new call's result.
else if (hasResult && !returnNeedsCoercion) {
833 call.getResult().replaceAllUsesWith(newCall.getResult());
837 return mlir::success();
__CUDA_BUILTIN_VAR __cuda_builtin_blockIdx_t blockIdx
mlir::LogicalResult rewriteFunctionDefinition(mlir::FunctionOpInterface funcOp, const mlir::abi::FunctionClassification &fc, mlir::OpBuilder &builder) override
mlir::LogicalResult rewriteCallSite(mlir::Operation *callOp, const mlir::abi::FunctionClassification &fc, mlir::OpBuilder &builder) override
const internal::VariadicAllOfMatcher< Attr > attr