clang 23.0.0git
TargetLowering.cpp
Go to the documentation of this file.
1//===- TargetLowering.cpp -------------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements the cir-target-lowering pass.
10//
11//===----------------------------------------------------------------------===//
12
15
16#include "mlir/IR/PatternMatch.h"
17#include "mlir/Support/LLVM.h"
18#include "mlir/Transforms/DialectConversion.h"
23
24using namespace mlir;
25using namespace cir;
26
27namespace mlir {
28#define GEN_PASS_DEF_TARGETLOWERING
29#include "clang/CIR/Dialect/Passes.h.inc"
30} // namespace mlir
31
32namespace {
33
34struct TargetLoweringPass
35 : public impl::TargetLoweringBase<TargetLoweringPass> {
36 TargetLoweringPass() = default;
37 void runOnOperation() override;
38};
39
40/// A generic target lowering pattern that matches any CIR op whose operand or
41/// result types need address space conversion. Clones the op with converted
42/// types.
43class CIRGenericTargetLoweringPattern : public mlir::ConversionPattern {
44public:
45 CIRGenericTargetLoweringPattern(mlir::MLIRContext *context,
46 const mlir::TypeConverter &typeConverter)
47 : mlir::ConversionPattern(typeConverter, MatchAnyOpTypeTag(),
48 /*benefit=*/1, context) {}
49
50 mlir::LogicalResult
51 matchAndRewrite(mlir::Operation *op, llvm::ArrayRef<mlir::Value> operands,
52 mlir::ConversionPatternRewriter &rewriter) const override {
53 // Do not match on operations that have dedicated lowering patterns.
54 if (llvm::isa<cir::FuncOp, cir::GlobalOp>(op))
55 return mlir::failure();
56
57 const mlir::TypeConverter *typeConverter = getTypeConverter();
58 assert(typeConverter &&
59 "CIRGenericTargetLoweringPattern requires a type converter");
60 bool operandsAndResultsLegal = typeConverter->isLegal(op);
61 bool regionsLegal =
62 std::all_of(op->getRegions().begin(), op->getRegions().end(),
63 [typeConverter](mlir::Region &region) {
64 return typeConverter->isLegal(&region);
65 });
66 if (operandsAndResultsLegal && regionsLegal)
67 return mlir::failure();
68
69 assert(op->getNumRegions() == 0 && "CIRGenericTargetLoweringPattern cannot "
70 "deal with operations with regions");
71
72 mlir::OperationState loweredOpState(op->getLoc(), op->getName());
73 loweredOpState.addOperands(operands);
74
75 // Copy attributes, converting any TypeAttr through the type converter so
76 // that address-space-bearing types (e.g. AllocaOp's allocaType) stay in
77 // sync with the converted result types.
78 for (mlir::NamedAttribute attr : op->getAttrs()) {
79 if (auto typeAttr = mlir::dyn_cast<mlir::TypeAttr>(attr.getValue())) {
80 mlir::Type converted = typeConverter->convertType(typeAttr.getValue());
81 loweredOpState.addAttribute(attr.getName(),
82 mlir::TypeAttr::get(converted));
83 } else {
84 loweredOpState.addAttribute(attr.getName(), attr.getValue());
85 }
86 }
87
88 loweredOpState.addSuccessors(op->getSuccessors());
89
90 llvm::SmallVector<mlir::Type> loweredResultTypes;
91 loweredResultTypes.reserve(op->getNumResults());
92 for (mlir::Type result : op->getResultTypes())
93 loweredResultTypes.push_back(typeConverter->convertType(result));
94 loweredOpState.addTypes(loweredResultTypes);
95
96 for (mlir::Region &region : op->getRegions()) {
97 mlir::Region *loweredRegion = loweredOpState.addRegion();
98 rewriter.inlineRegionBefore(region, *loweredRegion, loweredRegion->end());
99 if (mlir::failed(
100 rewriter.convertRegionTypes(loweredRegion, *getTypeConverter())))
101 return mlir::failure();
102 }
103
104 mlir::Operation *loweredOp = rewriter.create(loweredOpState);
105 rewriter.replaceOp(op, loweredOp);
106 return mlir::success();
107 }
108};
109
110/// Pattern to lower GlobalOp address space attributes. GlobalOp carries
111/// addr_space as a standalone attribute (not inside a type), so the
112/// TypeConverter won't reach it automatically.
113class CIRGlobalOpTargetLowering
114 : public mlir::OpConversionPattern<cir::GlobalOp> {
115 const cir::TargetLoweringInfo &targetInfo;
116
117public:
118 CIRGlobalOpTargetLowering(mlir::MLIRContext *context,
119 const mlir::TypeConverter &typeConverter,
120 const cir::TargetLoweringInfo &targetInfo)
121 : mlir::OpConversionPattern<cir::GlobalOp>(typeConverter, context,
122 /*benefit=*/1),
123 targetInfo(targetInfo) {}
124
125 mlir::LogicalResult
126 matchAndRewrite(cir::GlobalOp op, OpAdaptor adaptor,
127 mlir::ConversionPatternRewriter &rewriter) const override {
128 mlir::Type loweredSymTy = getTypeConverter()->convertType(op.getSymType());
129 if (!loweredSymTy)
130 return mlir::failure();
131
132 // Convert the addr_space attribute.
133 mlir::ptr::MemorySpaceAttrInterface addrSpace = op.getAddrSpaceAttr();
134 if (auto langAS =
135 mlir::dyn_cast_if_present<cir::LangAddressSpaceAttr>(addrSpace)) {
136 unsigned targetAS =
137 targetInfo.getTargetAddrSpaceFromCIRAddrSpace(langAS.getValue());
138 addrSpace =
139 targetAS == 0
140 ? nullptr
141 : cir::TargetAddressSpaceAttr::get(op.getContext(), targetAS);
142 }
143
144 // Only rewrite if something actually changed.
145 if (loweredSymTy == op.getSymType() && addrSpace == op.getAddrSpaceAttr())
146 return mlir::failure();
147
148 auto newOp = mlir::cast<cir::GlobalOp>(rewriter.clone(*op.getOperation()));
149 newOp.setSymType(loweredSymTy);
150 newOp.setAddrSpaceAttr(addrSpace);
151 rewriter.replaceOp(op, newOp);
152 return mlir::success();
153 }
154};
155
156/// Pattern to lower FuncOp types that contain address spaces.
157class CIRFuncOpTargetLowering : public mlir::OpConversionPattern<cir::FuncOp> {
158public:
159 using mlir::OpConversionPattern<cir::FuncOp>::OpConversionPattern;
160
161 mlir::LogicalResult
162 matchAndRewrite(cir::FuncOp op, OpAdaptor adaptor,
163 mlir::ConversionPatternRewriter &rewriter) const override {
164 cir::FuncType opFuncType = op.getFunctionType();
165 mlir::TypeConverter::SignatureConversion signatureConversion(
166 opFuncType.getNumInputs());
167
168 for (const auto &[i, argType] : llvm::enumerate(opFuncType.getInputs())) {
169 mlir::Type loweredArgType = getTypeConverter()->convertType(argType);
170 if (!loweredArgType)
171 return mlir::failure();
172 signatureConversion.addInputs(i, loweredArgType);
173 }
174
175 mlir::Type loweredReturnType =
176 getTypeConverter()->convertType(opFuncType.getReturnType());
177 if (!loweredReturnType)
178 return mlir::failure();
179
180 auto loweredFuncType = cir::FuncType::get(
181 signatureConversion.getConvertedTypes(), loweredReturnType,
182 /*isVarArg=*/opFuncType.getVarArg());
183
184 // Nothing changed, skip.
185 if (loweredFuncType == opFuncType)
186 return mlir::failure();
187
188 cir::FuncOp loweredFuncOp = rewriter.cloneWithoutRegions(op);
189 loweredFuncOp.setFunctionType(loweredFuncType);
190 rewriter.inlineRegionBefore(op.getBody(), loweredFuncOp.getBody(),
191 loweredFuncOp.end());
192 if (mlir::failed(rewriter.convertRegionTypes(&loweredFuncOp.getBody(),
193 *getTypeConverter(),
194 &signatureConversion)))
195 return mlir::failure();
196
197 rewriter.eraseOp(op);
198 return mlir::success();
199 }
200};
201
202} // namespace
203
204static void convertSyncScopeIfPresent(mlir::Operation *op,
205 cir::LowerModule &lowerModule) {
206 auto syncScopeAttr =
207 mlir::cast_if_present<cir::SyncScopeKindAttr>(op->getAttr("sync_scope"));
208 if (syncScopeAttr) {
209 cir::SyncScopeKind convertedSyncScope =
211 syncScopeAttr.getValue());
212 op->setAttr("sync_scope", cir::SyncScopeKindAttr::get(op->getContext(),
213 convertedSyncScope));
214 }
215}
216
217/// Prepare the type converter for the target lowering pass.
218/// Converts LangAddressSpaceAttr → TargetAddressSpaceAttr inside pointer types.
219static void
220prepareTargetLoweringTypeConverter(mlir::TypeConverter &converter,
221 const cir::TargetLoweringInfo &targetInfo) {
222 converter.addConversion([](mlir::Type type) { return type; });
223
224 converter.addConversion([&converter,
225 &targetInfo](cir::PointerType type) -> mlir::Type {
226 mlir::Type pointee = converter.convertType(type.getPointee());
227 if (!pointee)
228 return {};
229 auto addrSpace = type.getAddrSpace();
230 if (auto langAS =
231 mlir::dyn_cast_if_present<cir::LangAddressSpaceAttr>(addrSpace)) {
232 unsigned targetAS =
233 targetInfo.getTargetAddrSpaceFromCIRAddrSpace(langAS.getValue());
234 addrSpace =
235 targetAS == 0
236 ? nullptr
237 : cir::TargetAddressSpaceAttr::get(type.getContext(), targetAS);
238 }
239 return cir::PointerType::get(type.getContext(), pointee, addrSpace);
240 });
241
242 converter.addConversion([&converter](cir::ArrayType type) -> mlir::Type {
243 mlir::Type loweredElementType =
244 converter.convertType(type.getElementType());
245 if (!loweredElementType)
246 return {};
247 return cir::ArrayType::get(loweredElementType, type.getSize());
248 });
249
250 converter.addConversion([&converter](cir::FuncType type) -> mlir::Type {
251 llvm::SmallVector<mlir::Type> loweredInputTypes;
252 loweredInputTypes.reserve(type.getNumInputs());
253 if (mlir::failed(
254 converter.convertTypes(type.getInputs(), loweredInputTypes)))
255 return {};
256
257 mlir::Type loweredReturnType = converter.convertType(type.getReturnType());
258 if (!loweredReturnType)
259 return {};
260
261 return cir::FuncType::get(loweredInputTypes, loweredReturnType,
262 /*isVarArg=*/type.getVarArg());
263 });
264}
265
266static void
267populateTargetLoweringConversionTarget(mlir::ConversionTarget &target,
268 const mlir::TypeConverter &tc) {
269 target.addLegalOp<mlir::ModuleOp>();
270
271 target.addDynamicallyLegalDialect<cir::CIRDialect>(
272 [&tc](mlir::Operation *op) {
273 if (!tc.isLegal(op))
274 return false;
275 return std::all_of(
276 op->getRegions().begin(), op->getRegions().end(),
277 [&tc](mlir::Region &region) { return tc.isLegal(&region); });
278 });
279
280 target.addDynamicallyLegalOp<cir::FuncOp>(
281 [&tc](cir::FuncOp op) { return tc.isLegal(op.getFunctionType()); });
282
283 target.addDynamicallyLegalOp<cir::GlobalOp>([&tc](cir::GlobalOp op) {
284 if (!tc.isLegal(op.getSymType()))
285 return false;
286 return !mlir::isa_and_present<cir::LangAddressSpaceAttr>(
287 op.getAddrSpaceAttr());
288 });
289}
290
291void TargetLoweringPass::runOnOperation() {
292 auto mod = mlir::cast<mlir::ModuleOp>(getOperation());
293 std::unique_ptr<cir::LowerModule> lowerModule = cir::createLowerModule(mod);
294 // If lower module is not available, skip the target lowering pass.
295 if (!lowerModule) {
296 mod.emitWarning("Cannot create a CIR lower module, skipping the ")
297 << getName() << " pass";
298 return;
299 }
300
301 const auto &targetInfo = lowerModule->getTargetLoweringInfo();
302
303 mod->walk([&](mlir::Operation *op) {
304 if (mlir::isa<cir::LoadOp, cir::StoreOp, cir::AtomicXchgOp,
305 cir::AtomicCmpXchgOp, cir::AtomicFetchOp>(op))
306 convertSyncScopeIfPresent(op, *lowerModule);
307 });
308
309 // Address space conversion: LangAddressSpaceAttr → TargetAddressSpaceAttr.
310 mlir::TypeConverter typeConverter;
311 prepareTargetLoweringTypeConverter(typeConverter, targetInfo);
312
313 mlir::RewritePatternSet patterns(mod.getContext());
314 patterns.add<CIRGlobalOpTargetLowering>(mod.getContext(), typeConverter,
315 targetInfo);
316 patterns.add<CIRFuncOpTargetLowering>(typeConverter, mod.getContext());
317 patterns.add<CIRGenericTargetLoweringPattern>(mod.getContext(),
318 typeConverter);
319
320 mlir::ConversionTarget target(*mod.getContext());
321 populateTargetLoweringConversionTarget(target, typeConverter);
322
323 if (failed(mlir::applyPartialConversion(mod, target, std::move(patterns))))
324 signalPassFailure();
325}
326
327std::unique_ptr<Pass> mlir::createTargetLoweringPass() {
328 return std::make_unique<TargetLoweringPass>();
329}
static void populateTargetLoweringConversionTarget(mlir::ConversionTarget &target, const mlir::TypeConverter &tc)
static void prepareTargetLoweringTypeConverter(mlir::TypeConverter &converter, const cir::TargetLoweringInfo &targetInfo)
Prepare the type converter for the target lowering pass.
static void convertSyncScopeIfPresent(mlir::Operation *op, cir::LowerModule &lowerModule)
const TargetLoweringInfo & getTargetLoweringInfo()
virtual unsigned getTargetAddrSpaceFromCIRAddrSpace(cir::LangAddressSpace addrSpace) const
virtual cir::SyncScopeKind convertSyncScope(cir::SyncScopeKind syncScope) const
std::unique_ptr< LowerModule > createLowerModule(mlir::ModuleOp module)
const internal::VariadicAllOfMatcher< Attr > attr
StringRef getName(const HeaderType T)
Definition HeaderFile.h:38
nullptr
This class represents a compute construct, representing a 'Kind' of ‘parallel’, 'serial',...
std::unique_ptr< Pass > createTargetLoweringPass()