clang 23.0.0git
CallConvLoweringPass.cpp
Go to the documentation of this file.
1//===- CallConvLoweringPass.cpp - Lower CIR to ABI calling convention ----===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This pass walks every cir.func and cir.call in the module, computes a
10// FunctionClassification for it (via either an ABI target or a pre-built
11// classification injected as a function attribute), and dispatches to
12// CIRABIRewriteContext to perform the actual IR rewriting.
13//
14// Two driver modes (mutually exclusive):
15//
16// target=test
17// Use the MLIR test ABI target (mlir/lib/ABI/Targets/Test/) to classify
18// each function. Predictable rules that approximate x86_64 SysV. Real
19// targets (x86_64, AArch64) will be added once the LLVM ABI library
20// ships them.
21//
22// classification-attr=<name>
23// Read a DictionaryAttr named <name> from each cir.func and parse it via
24// mlir::abi::test::parseClassificationAttr. Used by tests to inject any
25// classification (including shapes the test target itself does not
26// produce) without depending on a real ABI target.
27//
28// The pass requires a `dlti.dl_spec` attribute on the module so the
29// classifier can query type sizes and alignments.
30//
31//===----------------------------------------------------------------------===//
32
33#include "PassDetail.h"
35
36#include "mlir/ABI/ABIRewriteContext.h"
37#include "mlir/ABI/Targets/Test/TestTarget.h"
38#include "mlir/Dialect/DLTI/DLTI.h"
39#include "mlir/IR/Builders.h"
40#include "mlir/IR/BuiltinOps.h"
41#include "mlir/IR/SymbolTable.h"
42#include "mlir/Interfaces/DataLayoutInterfaces.h"
43#include "mlir/Pass/Pass.h"
46
47using namespace mlir;
48using namespace mlir::abi;
49using namespace cir;
50
51namespace mlir {
52#define GEN_PASS_DEF_CALLCONVLOWERING
53#include "clang/CIR/Dialect/Passes.h.inc"
54} // namespace mlir
55
56namespace {
57
58bool needsRewrite(const FunctionClassification &fc) {
59 if ((fc.returnInfo.kind != ArgKind::Direct) || fc.returnInfo.coercedType)
60 return true;
61 for (const ArgClassification &ac : fc.argInfos)
62 if ((ac.kind != ArgKind::Direct) || ac.coercedType)
63 return true;
64 return false;
65}
66
67struct CallConvLoweringPass
68 : public impl::CallConvLoweringBase<CallConvLoweringPass> {
69 using CallConvLoweringBase::CallConvLoweringBase;
70 void runOnOperation() override;
71};
72
73/// Classify \p func using whichever driver mode is configured. Returns
74/// std::nullopt and emits an error on the function if classification fails
75/// (e.g. injection-driver mode but the function is missing the attribute,
76/// or the attribute is malformed).
77std::optional<FunctionClassification>
78classifyFunction(cir::FuncOp func, const DataLayout &dl, StringRef target,
79 StringRef classificationAttrName) {
80 ArrayRef<Type> argTypes = func.getFunctionType().getInputs();
81 Type returnType = func.getFunctionType().getReturnType();
82
83 if (!classificationAttrName.empty()) {
84 auto attr = func->getAttrOfType<DictionaryAttr>(classificationAttrName);
85 if (!attr) {
86 func.emitOpError()
87 << "missing classification attribute '" << classificationAttrName
88 << "' (CallConvLowering driver mode 'classification-attr')";
89 return std::nullopt;
90 }
91 return mlir::abi::test::parseClassificationAttr(
92 attr, [&]() { return func.emitOpError(); });
93 }
94
95 if (target == "test")
96 return mlir::abi::test::classify(argTypes, returnType, dl);
97
98 func.emitOpError() << "unknown target '" << target << "' (supported: test)";
99 return std::nullopt;
100}
101
102/// Find the cir.func declaration matching a direct cir.call / cir.try_call
103/// callee, if any. Returns nullptr if the callee is indirect or the symbol
104/// cannot be resolved. Takes a SymbolTable instead of a ModuleOp so the
105/// symbol lookup is amortized across all the call sites the driver walks
106/// (ModuleOp::lookupSymbol is linear per call).
107cir::FuncOp lookupCallee(Operation *callOp, SymbolTable &symbolTable) {
108 FlatSymbolRefAttr callee;
109 if (auto call = dyn_cast<cir::CallOp>(callOp))
110 callee = call.getCalleeAttr();
111 else if (auto tryCall = dyn_cast<cir::TryCallOp>(callOp))
112 callee = tryCall.getCalleeAttr();
113 else
114 return nullptr;
115 if (!callee)
116 return nullptr;
117 return symbolTable.lookup<cir::FuncOp>(callee.getValue());
118}
119
120void CallConvLoweringPass::runOnOperation() {
121 ModuleOp moduleOp = getOperation();
122 MLIRContext *ctx = &getContext();
123
124 if (target.empty() == classificationAttr.empty()) {
125 moduleOp.emitOpError() << "CallConvLowering requires exactly one of "
126 "'target' or 'classification-attr' pass options";
127 signalPassFailure();
128 return;
129 }
130
131 if (!moduleOp->hasAttr(DLTIDialect::kDataLayoutAttrName)) {
132 moduleOp.emitOpError()
133 << "CallConvLowering requires a DataLayout (dlti.dl_spec attribute "
134 "on the module)";
135 signalPassFailure();
136 return;
137 }
138
139 DataLayout dl(moduleOp);
140 CIRABIRewriteContext rewriteCtx(moduleOp);
141 SymbolTable symbolTable(moduleOp);
142
143 // Classify every cir.func up front. No IR mutation happens here, so
144 // later walks can consult any function's classification regardless of
145 // visitation order.
146 llvm::MapVector<cir::FuncOp, FunctionClassification> classifications;
147 bool anyFailed = false;
148 moduleOp.walk([&](cir::FuncOp f) {
149 auto fc = classifyFunction(f, dl, target, classificationAttr);
150 if (!fc) {
151 anyFailed = true;
152 return;
153 }
154 classifications.insert({f, std::move(*fc)});
155 });
156 if (anyFailed) {
157 signalPassFailure();
158 return;
159 }
160
161 // Build a callee-to-callers index. One module walk collects every direct
162 // cir.call / cir.try_call to each cir.func; the loop below rewrites a
163 // function and all of its call sites together. Indirect or unresolved
164 // callees are skipped here; rewriteCallSite errors on those at the end.
165 llvm::DenseMap<cir::FuncOp, SmallVector<Operation *>> callers;
166 moduleOp.walk([&](Operation *op) {
167 if (!isa<cir::CallOp, cir::TryCallOp>(op))
168 return;
169 if (cir::FuncOp callee = lookupCallee(op, symbolTable))
170 callers[callee].push_back(op);
171 });
172
173 // Rewrite each function together with every direct call to it. By the
174 // time we move on to function F+1, F's signature and every direct call to
175 // F have already been brought into alignment, and F+1..FN are still in
176 // their original (mutually consistent) form, so the IR is verifier-clean
177 // at every outer-iteration boundary.
178 //
179 // There is still a brief inner window where F's signature has been
180 // rewritten but its callers have not yet caught up -- we have no way to
181 // mutate both sides of a call atomically. No verifier runs inside the
182 // pass, and at pass exit the module is verifier-clean. Fusing the inner
183 // loop here keeps the invalid window per-function rather than module-wide.
184 OpBuilder builder(ctx);
185 for (auto &kv : classifications) {
186 cir::FuncOp func = kv.first;
187 const FunctionClassification &fc = kv.second;
188 if (failed(rewriteCtx.rewriteFunctionDefinition(func, fc, builder))) {
189 signalPassFailure();
190 return;
191 }
192 for (Operation *callOp : callers.lookup(func)) {
193 if (failed(rewriteCtx.rewriteCallSite(callOp, fc, builder))) {
194 signalPassFailure();
195 return;
196 }
197 }
198 }
199
200 // Reject indirect calls when the module contains any ABI rewrite that
201 // would need call-site lowering. We cannot strip or coerce operands
202 // without a resolved callee symbol.
203 const FunctionClassification *rewriteFc = nullptr;
204 for (auto &kv : classifications) {
205 if (needsRewrite(kv.second)) {
206 rewriteFc = &kv.second;
207 break;
208 }
209 }
210 if (rewriteFc) {
211 moduleOp.walk([&](cir::CallOp c) {
212 if (!c.isIndirect())
213 return;
214 if (failed(rewriteCtx.rewriteCallSite(c, *rewriteFc, builder)))
215 anyFailed = true;
216 });
217 if (anyFailed) {
218 signalPassFailure();
219 return;
220 }
221 }
222}
223
224} // namespace
225
226std::unique_ptr<Pass> mlir::createCallConvLoweringPass() {
227 return std::make_unique<CallConvLoweringPass>();
228}
__device__ __2f16 float c
const internal::VariadicAllOfMatcher< Attr > attr
std::unique_ptr< Pass > createCallConvLoweringPass()