36#include "mlir/ABI/ABIRewriteContext.h"
37#include "mlir/ABI/Targets/Test/TestTarget.h"
38#include "mlir/Dialect/DLTI/DLTI.h"
39#include "mlir/IR/Builders.h"
40#include "mlir/IR/BuiltinOps.h"
41#include "mlir/IR/SymbolTable.h"
42#include "mlir/Interfaces/DataLayoutInterfaces.h"
43#include "mlir/Pass/Pass.h"
48using namespace mlir::abi;
52#define GEN_PASS_DEF_CALLCONVLOWERING
53#include "clang/CIR/Dialect/Passes.h.inc"
// Reports whether a function classification requires any ABI rewriting:
// either the return value or some argument is classified as something other
// than ArgKind::Direct, or carries a coerced type.
// NOTE(review): the return statements and closing brace of this function are
// not visible in this view; presumably it returns true on the first such
// entry and false otherwise — confirm.
58bool needsRewrite(
const FunctionClassification &fc) {
// The return value is checked first...
59 if ((fc.returnInfo.kind != ArgKind::Direct) || fc.returnInfo.coercedType)
// ...then each argument classification, in order.
61 for (
const ArgClassification &ac : fc.argInfos)
62 if ((ac.kind != ArgKind::Direct) || ac.coercedType)
// The calling-convention lowering pass; it runs on a ModuleOp (see
// runOnOperation). Its base class, impl::CallConvLoweringBase, is generated by
// TableGen via GEN_PASS_DEF_CALLCONVLOWERING. Presumably that base also
// declares the `target` and `classificationAttr` options read by
// runOnOperation — confirm against the pass's .td definition.
67struct CallConvLoweringPass
68 :
public impl::CallConvLoweringBase<CallConvLoweringPass> {
// Reuse the generated base-class constructors.
69 using CallConvLoweringBase::CallConvLoweringBase;
70 void runOnOperation()
override;
// Produces the ABI classification for `func`. Diagnostics are emitted on
// `func`; presumably std::nullopt is returned after an error (those return
// statements are not visible in this view) — confirm.
//
// Two classification sources are visible:
//  - classification-attr mode (`classificationAttrName` non-empty): the
//    classification is read from a DictionaryAttr of that name on `func` and
//    decoded by mlir::abi::test::parseClassificationAttr;
//  - target mode: the classification is computed by
//    mlir::abi::test::classify from the argument types, the return type and
//    the module DataLayout `dl`. Other target names are rejected.
// NOTE(review): the attribute null-check, the construction of `argTypes` and
// the comparison against `target` are not visible in this view.
77std::optional<FunctionClassification>
78classifyFunction(cir::FuncOp func,
const DataLayout &dl, StringRef target,
79 StringRef classificationAttrName) {
81 Type returnType = func.getFunctionType().getReturnType();
// classification-attr mode: the attribute must be present on the function.
83 if (!classificationAttrName.empty()) {
84 auto attr = func->getAttrOfType<DictionaryAttr>(classificationAttrName);
87 <<
"missing classification attribute '" << classificationAttrName
88 <<
"' (CallConvLowering driver mode 'classification-attr')";
// Decode the attribute. The callback lets the parser attach its diagnostics
// to `func`.
91 return mlir::abi::test::parseClassificationAttr(
92 attr, [&]() {
return func.emitOpError(); });
// Target mode, test target: classify from the signature and the DataLayout.
96 return mlir::abi::test::classify(argTypes, returnType, dl);
// Any other target name is reported as unsupported.
98 func.emitOpError() <<
"unknown target '" << target <<
"' (supported: test)";
// Resolves the statically named callee of a cir::CallOp or cir::TryCallOp to
// its cir::FuncOp through `symbolTable`.
// NOTE(review): the lines handling other operation kinds and a missing callee
// attribute (e.g. an indirect call) are not visible in this view. The caller
// in runOnOperation null-checks the result, so presumably a null FuncOp is
// returned in those cases — confirm.
107cir::FuncOp lookupCallee(Operation *callOp, SymbolTable &symbolTable) {
108 FlatSymbolRefAttr callee;
109 if (
auto call = dyn_cast<cir::CallOp>(callOp))
110 callee = call.getCalleeAttr();
111 else if (
auto tryCall = dyn_cast<cir::TryCallOp>(callOp))
112 callee = tryCall.getCalleeAttr();
// SymbolTable::lookup<T> yields a null op when the symbol is absent or is not
// a cir::FuncOp.
117 return symbolTable.lookup<cir::FuncOp>(callee.getValue());
// Pass entry point: lowers the calling convention of the cir::FuncOp
// operations in the module and of the calls that target them.
//
// Visible flow:
//  1. Validate the pass options and require an explicit module DataLayout.
//  2. Classify every function (classifyFunction) before rewriting anything.
//  3. Record, for each resolvable callee, the cir::CallOp / cir::TryCallOp
//     operations that reference it.
//  4. Rewrite each function definition and its recorded call sites using that
//     function's classification.
//  5. Apply one selected classification to further cir::CallOp operations.
// NOTE(review): many lines of this function (early returns, failure handling,
// loop and lambda closers) are not visible in this view. The comments below
// describe only the visible statements.
120void CallConvLoweringPass::runOnOperation() {
121 ModuleOp moduleOp = getOperation();
122 MLIRContext *ctx = &getContext();
// Exactly one classification source must be selected. Having both options
// set, or neither, is reported as an error.
124 if (target.empty() == classificationAttr.empty()) {
125 moduleOp.emitOpError() <<
"CallConvLowering requires exactly one of "
126 "'target' or 'classification-attr' pass options";
// The module must carry an explicit data layout spec; it backs `dl` below.
131 if (!moduleOp->hasAttr(DLTIDialect::kDataLayoutAttrName)) {
132 moduleOp.emitOpError()
133 <<
"CallConvLowering requires a DataLayout (dlti.dl_spec attribute "
// Module-wide state: layout queries for classification, the CIR rewrite
// driver, and symbol resolution for call sites.
139 DataLayout dl(moduleOp);
140 CIRABIRewriteContext rewriteCtx(moduleOp);
141 SymbolTable symbolTable(moduleOp);
// Step 2: classify all functions up front. llvm::MapVector iterates in
// insertion (walk) order, so the rewrites below run in a deterministic order.
// NOTE(review): how `anyFailed` is set and checked is not visible in this
// view.
146 llvm::MapVector<cir::FuncOp, FunctionClassification> classifications;
147 bool anyFailed =
false;
148 moduleOp.walk([&](cir::FuncOp f) {
149 auto fc = classifyFunction(f, dl, target, classificationAttr);
154 classifications.insert({f, std::move(*fc)});
// Step 3: index call sites by callee while the IR is still unmodified. Only
// cir::CallOp and cir::TryCallOp are considered, and only when lookupCallee
// resolves them to a cir::FuncOp.
165 llvm::DenseMap<cir::FuncOp, SmallVector<Operation *>> callers;
166 moduleOp.walk([&](Operation *op) {
167 if (!isa<cir::CallOp, cir::TryCallOp>(op))
169 if (cir::FuncOp callee = lookupCallee(op, symbolTable))
170 callers[callee].push_back(op);
// Step 4: for each classified function, rewrite its definition, then every
// recorded call site of it, using the same classification.
// NOTE(review): the handling inside the `failed(...)` branches is not visible
// in this view.
184 OpBuilder builder(ctx);
185 for (
auto &kv : classifications) {
186 cir::FuncOp func = kv.first;
187 const FunctionClassification &fc = kv.second;
188 if (failed(rewriteCtx.rewriteFunctionDefinition(func, fc, builder))) {
// NOTE(review): DenseMap::lookup returns the SmallVector by value, so this
// copies the call-site list for each function. `find` would avoid the copy.
192 for (Operation *callOp : callers.lookup(func)) {
193 if (failed(rewriteCtx.rewriteCallSite(callOp, fc, builder))) {
// Step 5: pick the first classification (in MapVector order) that needs a
// rewrite, and apply it to cir::CallOp operations visited by the walk below.
// NOTE(review): the filter deciding which calls reach rewriteCallSite here,
// and the handling of a null `rewriteFc`, are not visible in this view.
// Reusing one function's classification for other calls is only sound if
// those calls share its signature — confirm the intended scope.
203 const FunctionClassification *rewriteFc =
nullptr;
204 for (
auto &kv : classifications) {
205 if (needsRewrite(kv.second)) {
206 rewriteFc = &kv.second;
211 moduleOp.walk([&](cir::CallOp
c) {
214 if (failed(rewriteCtx.rewriteCallSite(
c, *rewriteFc, builder)))
// Factory body: returns a fresh, default-constructed CallConvLoweringPass.
227 return std::make_unique<CallConvLoweringPass>();
__device__ __2f16 float c
const internal::VariadicAllOfMatcher< Attr > attr
std::unique_ptr< Pass > createCallConvLoweringPass()