clang 23.0.0git
CIRABIRewriteContext.cpp
Go to the documentation of this file.
1//===- CIRABIRewriteContext.cpp - CIR ABI rewrite context ----------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
10#include "mlir/IR/Builders.h"
13
14using namespace cir;
15using namespace mlir;
16using namespace mlir::abi;
17
// This rewrite context currently supports only the Direct (without a coerced
// type) and Ignore classifications. Every other classification -- including
// Direct with a coerced type -- is reported as an op error here rather than
// silently passed through, because the IR it would produce is wrong
// (e.g. Expand should flatten an aggregate into multiple primitives, not
// pass it through as a single value). Subsequent PRs in the
// CallConvLowering split series add the remaining kinds and the
// signature-shaping behavior that goes with them (sret / byval insert
// extra arguments, struct coercion replaces one argument with several).
26
27namespace {
28
29bool needsRewrite(const FunctionClassification &fc) {
30 if ((fc.returnInfo.kind != ArgKind::Direct) || fc.returnInfo.coercedType)
31 return true;
32 for (const ArgClassification &ac : fc.argInfos)
33 if ((ac.kind != ArgKind::Direct) || ac.coercedType)
34 return true;
35 return false;
36}
37
38SmallVector<unsigned> ignoredArgIndices(const FunctionClassification &fc) {
40 for (auto [idx, ac] : llvm::enumerate(fc.argInfos))
41 if (ac.kind == ArgKind::Ignore)
42 v.push_back(idx);
43 return v;
44}
45
46/// Build the new argument-type list for a function whose ABI classification
47/// is \p fc. This currently handles only Direct (no coercion) and Ignore;
48/// other kinds emit an error. Classifications that add arguments (e.g.
49/// Indirect-sret would prepend a return-pointer arg) are not yet
50/// implemented and will arrive in a subsequent PR.
51LogicalResult buildNewArgTypes(ArrayRef<Type> oldArgTypes,
52 const FunctionClassification &fc,
53 SmallVectorImpl<Type> &newArgTypes,
54 function_ref<InFlightDiagnostic()> emitError) {
55 assert(newArgTypes.empty() && "expected an empty output vector");
56 newArgTypes.reserve(oldArgTypes.size());
57 for (auto [idx, ac] : llvm::enumerate(fc.argInfos)) {
58 Type origTy = oldArgTypes[idx];
59 switch (ac.kind) {
60 case ArgKind::Direct:
61 if (ac.coercedType) {
62 emitError() << "Direct with coerced type at arg " << idx
63 << " not yet implemented in CallConvLowering";
64 return failure();
65 }
66 newArgTypes.push_back(origTy);
67 break;
68 case ArgKind::Ignore:
69 break;
70 case ArgKind::Expand:
71 emitError() << "Expand at arg " << idx
72 << " not yet implemented in CallConvLowering";
73 return failure();
74 case ArgKind::Extend:
75 emitError() << "Extend at arg " << idx
76 << " not yet implemented in CallConvLowering";
77 return failure();
78 case ArgKind::Indirect:
79 emitError() << "Indirect at arg " << idx
80 << " not yet implemented in CallConvLowering";
81 return failure();
82 }
83 }
84 return success();
85}
86
87/// Compute the new return type for a function whose return classification
88/// is \p retInfo. As with `buildNewArgTypes`, only Direct (no coercion)
89/// and Ignore are implemented here; the remaining kinds emit an error.
90Type computeNewReturnType(Type origRetTy, const ArgClassification &retInfo,
91 MLIRContext *ctx,
92 function_ref<InFlightDiagnostic()> emitError) {
93 switch (retInfo.kind) {
94 case ArgKind::Direct:
95 if (retInfo.coercedType) {
96 emitError() << "Direct return with coerced type not yet implemented "
97 << "in CallConvLowering";
98 return nullptr;
99 }
100 return origRetTy;
101 case ArgKind::Ignore:
102 return cir::VoidType::get(ctx);
103 case ArgKind::Expand:
104 emitError() << "Expand return is not allowed (classic codegen rejects "
105 << "it in EmitFunctionEpilog)";
106 return nullptr;
107 case ArgKind::Extend:
108 emitError() << "Extend return not yet implemented in CallConvLowering";
109 return nullptr;
110 case ArgKind::Indirect:
111 emitError() << "Indirect return (sret) not yet implemented in "
112 << "CallConvLowering";
113 return nullptr;
114 }
115 llvm_unreachable("all ArgKind cases handled");
116}
117
118/// Create a typed poison constant to stand in for a value the body of a
119/// function (or the result of a call) still references but whose ABI
120/// classification is Ignore. Using poison is honest -- the value is
121/// genuinely unused at the ABI boundary -- and avoids a fake alloca+load
122/// pattern that would suggest we have a value when we don't.
123Value createIgnoredValue(OpBuilder &builder, Location loc, Type ty) {
124 return cir::ConstantOp::create(builder, loc, ty, cir::PoisonAttr::get(ty));
125}
126
127} // namespace
128
130 FunctionOpInterface funcOpInterface, const FunctionClassification &fc,
131 OpBuilder &builder) {
132 // The pass driver (CallConvLoweringPass) only ever hands us cir.func ops,
133 // and the body of this routine is end-to-end CIR (it creates cir.constant,
134 // cir.return, etc.). Cast once at the top so the rest of the function
135 // reads in CIR's own vocabulary, and so we can dispatch to the
136 // CIRGlobalValueInterface for isDefinition() (FunctionOpInterface alone
137 // does not inherit from CIRGlobalValueInterface).
138 cir::FuncOp funcOp = cast<cir::FuncOp>(funcOpInterface);
139
140 if (!needsRewrite(fc))
141 return success();
142
143 ArrayRef<Type> oldArgTypes = funcOp.getArgumentTypes();
144 ArrayRef<Type> oldResultTypes = funcOp.getResultTypes();
145 MLIRContext *ctx = funcOp->getContext();
146
147 // CIR follows LLVM IR's single-result rule: a function returns either
148 // zero or one value. Document the invariant so a future multi-result
149 // change forces us to revisit the return-handling below.
150 assert(oldResultTypes.size() <= 1 &&
151 "CIR functions return zero or one value");
152
153 SmallVector<Type> newArgTypes;
154 if (failed(buildNewArgTypes(oldArgTypes, fc, newArgTypes,
155 [&]() { return funcOp.emitOpError(); })))
156 return failure();
157
158 Type voidTy = cir::VoidType::get(ctx);
159 Type origRetTy = oldResultTypes.empty() ? voidTy : oldResultTypes[0];
160 Type newRetTy = computeNewReturnType(origRetTy, fc.returnInfo, ctx,
161 [&]() { return funcOp.emitOpError(); });
162 if (!newRetTy)
163 return failure();
164 SmallVector<Type> newResultTypes = {newRetTy};
165
166 if (funcOp.isDefinition()) {
167 Region &body = funcOp->getRegion(0);
168 if (!body.empty()) {
169 Block &entry = body.front();
170
171 // For each Ignored argument: drop the block argument and, if the
172 // body still references it, replace those uses with a poison
173 // constant. Ignore classifications mean the value is empty / not
174 // passed at the ABI level, so any remaining uses are vacuous;
175 // poison says exactly that.
176 SmallVector<unsigned> ignored = ignoredArgIndices(fc);
177 for (unsigned blockIdx : llvm::reverse(ignored)) {
178 if (blockIdx >= entry.getNumArguments())
179 continue;
180 BlockArgument arg = entry.getArgument(blockIdx);
181 if (!arg.use_empty()) {
182 builder.setInsertionPointToStart(&entry);
183 Value poison =
184 createIgnoredValue(builder, funcOp.getLoc(), arg.getType());
185 arg.replaceAllUsesWith(poison);
186 }
187 entry.eraseArgument(blockIdx);
188 }
189 }
190
191 // When the return is classified Ignore but the original function had
192 // a non-void return type, every cir.return becomes a naked return.
193 // This relies on the invariant that computeNewReturnType has set
194 // newRetTy = void for Ignore above, and that the function type is
195 // updated below to match. Asserting this keeps the dependency
196 // explicit.
197 if (fc.returnInfo.kind == ArgKind::Ignore && !oldResultTypes.empty()) {
198 assert(isa<cir::VoidType>(newRetTy) &&
199 "Ignore-return path requires the new return type to be void");
201 funcOp.walk([&](cir::ReturnOp r) { returns.push_back(r); });
202 for (cir::ReturnOp r : returns) {
203 if (r.getNumOperands() == 0)
204 continue;
205 builder.setInsertionPoint(r);
206 cir::ReturnOp::create(builder, r.getLoc());
207 r.erase();
208 }
209 }
210 }
211
212 Type newFnTy = funcOp.cloneTypeWith(newArgTypes, newResultTypes);
213 funcOp.setFunctionTypeAttr(TypeAttr::get(newFnTy));
214
215 // Keep the arg_attrs array in sync with the new argument count by
216 // dropping entries for every Ignored argument. Without this the
217 // attribute array would have stale entries that no longer match any
218 // block argument.
219 SmallVector<unsigned> ignored = ignoredArgIndices(fc);
220 if (!ignored.empty()) {
221 if (auto existing = funcOp->getAttrOfType<ArrayAttr>("arg_attrs")) {
223 kept.reserve(newArgTypes.size());
224 for (auto [oldIdx, attr] : llvm::enumerate(existing.getValue())) {
225 if (oldIdx >= fc.argInfos.size() ||
226 fc.argInfos[oldIdx].kind != ArgKind::Ignore)
227 kept.push_back(attr);
228 }
229 funcOp->setAttr("arg_attrs", ArrayAttr::get(ctx, kept));
230 }
231 }
232
233 return success();
234}
235
237 Operation *callOp, const FunctionClassification &fc, OpBuilder &builder) {
238 if (!needsRewrite(fc))
239 return success();
240
241 if (isa<cir::TryCallOp>(callOp))
242 return callOp->emitOpError()
243 << "TryCallOp not yet implemented in CallConvLowering";
244
245 auto call = cast<cir::CallOp>(callOp);
246 if (call.isIndirect())
247 return call.emitOpError()
248 << "indirect call not yet implemented in CallConvLowering";
249
250 for (auto [idx, ac] : llvm::enumerate(fc.argInfos)) {
251 switch (ac.kind) {
252 case ArgKind::Direct:
253 if (ac.coercedType)
254 return call.emitOpError()
255 << "Direct with coerced type at call-site arg " << idx
256 << " not yet implemented in CallConvLowering";
257 break;
258 case ArgKind::Ignore:
259 break;
260 case ArgKind::Expand:
261 return call.emitOpError() << "Expand at call-site arg " << idx
262 << " not yet implemented in CallConvLowering";
263 case ArgKind::Extend:
264 return call.emitOpError() << "Extend at call-site arg " << idx
265 << " not yet implemented in CallConvLowering";
266 case ArgKind::Indirect:
267 return call.emitOpError() << "Indirect at call-site arg " << idx
268 << " not yet implemented in CallConvLowering";
269 }
270 }
271
272 SmallVector<Value> newArgs;
273 ValueRange argOperands = call.getArgOperands();
274 newArgs.reserve(argOperands.size());
275 if (argOperands.size() > fc.argInfos.size())
276 return call.emitOpError()
277 << "variadic arguments not yet implemented in CallConvLowering";
278 assert(fc.argInfos.size() == argOperands.size() &&
279 "call operand count must match classified arg count");
280 for (auto [idx, ac] : llvm::enumerate(fc.argInfos)) {
281 if (ac.kind == ArgKind::Ignore)
282 continue;
283 newArgs.push_back(argOperands[idx]);
284 }
285
286 bool hasResult = call.getNumResults() > 0;
287 Type origRetTy = hasResult ? call.getResult().getType()
288 : cir::VoidType::get(callOp->getContext());
289 Type callRetTy = origRetTy;
290 if (fc.returnInfo.kind == ArgKind::Ignore && hasResult)
291 callRetTy = cir::VoidType::get(callOp->getContext());
292 if ((fc.returnInfo.kind == ArgKind::Direct ||
293 fc.returnInfo.kind == ArgKind::Extend) &&
294 fc.returnInfo.coercedType)
295 return call.emitOpError() << "Direct/Extend return with coerced type at "
296 << "call-site not yet implemented in "
297 << "CallConvLowering";
298
299 builder.setInsertionPoint(call);
300 auto newCall = cir::CallOp::create(builder, call.getLoc(),
301 call.getCalleeAttr(), callRetTy, newArgs);
302 for (NamedAttribute attr : call->getAttrs())
303 if (!newCall->hasAttr(attr.getName()))
304 newCall->setAttr(attr.getName(), attr.getValue());
305
306 if (hasResult && fc.returnInfo.kind == ArgKind::Ignore) {
307 // The new call returns void, but the original call's result may still
308 // have uses. Substitute a poison constant of the original type so
309 // those uses remain well-formed without pretending we have a real
310 // value at the ABI boundary.
311 if (!call.getResult().use_empty()) {
312 builder.setInsertionPointAfter(newCall);
313 Value poison = createIgnoredValue(builder, call.getLoc(), origRetTy);
314 call.getResult().replaceAllUsesWith(poison);
315 }
316 } else if (hasResult) {
317 call.getResult().replaceAllUsesWith(newCall.getResult());
318 }
319
320 call->erase();
321 return success();
322}
__CUDA_BUILTIN_VAR __cuda_builtin_blockIdx_t blockIdx
mlir::LogicalResult rewriteFunctionDefinition(mlir::FunctionOpInterface funcOp, const mlir::abi::FunctionClassification &fc, mlir::OpBuilder &builder) override
mlir::LogicalResult rewriteCallSite(mlir::Operation *callOp, const mlir::abi::FunctionClassification &fc, mlir::OpBuilder &builder) override