16#include "mlir/IR/PatternMatch.h"
17#include "mlir/Support/LLVM.h"
18#include "mlir/Transforms/DialectConversion.h"
28#define GEN_PASS_DEF_TARGETLOWERING
29#include "clang/CIR/Dialect/Passes.h.inc"
// TargetLoweringPass: module pass that lowers target-independent CIR
// constructs (language address spaces, sync scopes, and the types that carry
// them) into their target-specific forms. The work is done in
// TargetLoweringPass::runOnOperation(), defined out of line.
// NOTE(review): this copy of the struct is truncated; its closing `};` is not
// visible here.
34struct TargetLoweringPass
35 :
public impl::TargetLoweringBase<TargetLoweringPass> {
36 TargetLoweringPass() =
default;
37 void runOnOperation()
override;
// Fallback conversion pattern that matches any operation (MatchAnyOpTypeTag)
// except cir::FuncOp and cir::GlobalOp, which have dedicated patterns. When an
// op's operand, result, or region types are not legal for the type converter,
// it rebuilds the op with converted operands, result types and TypeAttr
// attributes, and replaces the original op with the rebuilt one.
// NOTE(review): several source lines of this class are missing from this copy
// (gaps in the leading line numbers), e.g. `bool regionsLegal =` before the
// std::all_of call and the `if (mlir::failed(` that opens the
// convertRegionTypes check.
43class CIRGenericTargetLoweringPattern :
public mlir::ConversionPattern {
45 CIRGenericTargetLoweringPattern(mlir::MLIRContext *context,
46 const mlir::TypeConverter &typeConverter)
47 : mlir::ConversionPattern(typeConverter, MatchAnyOpTypeTag(),
51 matchAndRewrite(mlir::Operation *op, llvm::ArrayRef<mlir::Value> operands,
52 mlir::ConversionPatternRewriter &rewriter)
const override {
// FuncOp and GlobalOp are handled by their own patterns.
54 if (llvm::isa<cir::FuncOp, cir::GlobalOp>(op))
55 return mlir::failure();
57 const mlir::TypeConverter *typeConverter = getTypeConverter();
58 assert(typeConverter &&
59 "CIRGenericTargetLoweringPattern requires a type converter");
// Nothing to rewrite when operand/result types and every region are legal.
60 bool operandsAndResultsLegal = typeConverter->isLegal(op);
62 std::all_of(op->getRegions().begin(), op->getRegions().end(),
63 [typeConverter](mlir::Region &region) {
64 return typeConverter->isLegal(&region);
66 if (operandsAndResultsLegal && regionsLegal)
67 return mlir::failure();
// NOTE(review): this assert contradicts the region-moving loop further down,
// which can only run when assertions are compiled out; keep one or the other.
69 assert(op->getNumRegions() == 0 &&
"CIRGenericTargetLoweringPattern cannot "
70 "deal with operations with regions");
// Rebuild the op under the same name and location, using the operands the
// conversion driver passes in (already remapped to converted values).
72 mlir::OperationState loweredOpState(op->getLoc(), op->getName());
73 loweredOpState.addOperands(operands);
// TypeAttr attributes get their type converted; other attributes are meant to
// be copied unchanged (the `} else {` between the two addAttribute calls,
// source line 83, is missing from this copy).
// NOTE(review): convertType can return a null type (other conversions in this
// file test for that); the results below are used without such a check.
78 for (mlir::NamedAttribute attr : op->getAttrs()) {
79 if (
auto typeAttr = mlir::dyn_cast<mlir::TypeAttr>(
attr.getValue())) {
80 mlir::Type converted = typeConverter->convertType(typeAttr.getValue());
81 loweredOpState.addAttribute(
attr.getName(),
82 mlir::TypeAttr::get(converted));
84 loweredOpState.addAttribute(
attr.getName(),
attr.getValue());
88 loweredOpState.addSuccessors(op->getSuccessors());
90 llvm::SmallVector<mlir::Type> loweredResultTypes;
91 loweredResultTypes.reserve(op->getNumResults());
92 for (mlir::Type result : op->getResultTypes())
93 loweredResultTypes.push_back(typeConverter->convertType(result));
94 loweredOpState.addTypes(loweredResultTypes);
// Move each region into the new op and convert its block-argument types.
96 for (mlir::Region &region : op->getRegions()) {
97 mlir::Region *loweredRegion = loweredOpState.addRegion();
98 rewriter.inlineRegionBefore(region, *loweredRegion, loweredRegion->end());
100 rewriter.convertRegionTypes(loweredRegion, *typeConverter)))
101 return mlir::failure();
104 mlir::Operation *loweredOp = rewriter.create(loweredOpState);
105 rewriter.replaceOp(op, loweredOp);
106 return mlir::success();
// Conversion pattern for cir::GlobalOp. It converts the global's symbol type
// with the type converter, and replaces a cir::LangAddressSpaceAttr address
// space with a cir::TargetAddressSpaceAttr computed by `targetInfo`.
// NOTE(review): several source lines of this class are missing from this copy
// (gaps in the leading line numbers), e.g. the null check on `loweredSymTy`.
113class CIRGlobalOpTargetLowering
114 :
public mlir::OpConversionPattern<cir::GlobalOp> {
// Non-owning reference: the TargetLoweringInfo must outlive this pattern.
115 const cir::TargetLoweringInfo &targetInfo;
118 CIRGlobalOpTargetLowering(mlir::MLIRContext *context,
119 const mlir::TypeConverter &typeConverter,
120 const cir::TargetLoweringInfo &targetInfo)
121 : mlir::OpConversionPattern<cir::GlobalOp>(typeConverter, context,
123 targetInfo(targetInfo) {}
126 matchAndRewrite(cir::GlobalOp op, OpAdaptor adaptor,
127 mlir::ConversionPatternRewriter &rewriter)
const override {
128 mlir::Type loweredSymTy = getTypeConverter()->convertType(op.getSymType());
130 return mlir::failure();
// Map a language-level address space onto the target's address space.
133 mlir::ptr::MemorySpaceAttrInterface addrSpace = op.getAddrSpaceAttr();
135 mlir::dyn_cast_if_present<cir::LangAddressSpaceAttr>(addrSpace)) {
137 targetInfo.getTargetAddrSpaceFromCIRAddrSpace(langAS.getValue());
141 : cir::TargetAddressSpaceAttr::get(op.getContext(), targetAS);
// Nothing changed: report failure so the op is left as is.
145 if (loweredSymTy == op.getSymType() && addrSpace == op.getAddrSpaceAttr())
146 return mlir::failure();
// Clone the whole op, patch the symbol type and address space on the clone,
// and replace the original with it.
148 auto newOp = mlir::cast<cir::GlobalOp>(rewriter.clone(*op.getOperation()));
149 newOp.setSymType(loweredSymTy);
150 newOp.setAddrSpaceAttr(addrSpace);
151 rewriter.replaceOp(op, newOp);
152 return mlir::success();
// Conversion pattern for cir::FuncOp. It converts the argument and return
// types of the function signature, then moves the body into a re-typed copy
// of the function.
// NOTE(review): several source lines are missing from this copy (gaps in the
// leading line numbers), e.g. the null check after converting each argument
// and the type-converter argument of convertRegionTypes.
157class CIRFuncOpTargetLowering :
public mlir::OpConversionPattern<cir::FuncOp> {
// Inherit the base constructors; the pass constructs this pattern as
// (typeConverter, context).
159 using mlir::OpConversionPattern<cir::FuncOp>::OpConversionPattern;
162 matchAndRewrite(cir::FuncOp op, OpAdaptor adaptor,
163 mlir::ConversionPatternRewriter &rewriter)
const override {
164 cir::FuncType opFuncType = op.getFunctionType();
// Record one converted type per original argument; the conversion is later
// handed to convertRegionTypes to re-type the body's entry block.
165 mlir::TypeConverter::SignatureConversion signatureConversion(
166 opFuncType.getNumInputs());
168 for (
const auto &[i, argType] : llvm::enumerate(opFuncType.getInputs())) {
169 mlir::Type loweredArgType = getTypeConverter()->convertType(argType);
171 return mlir::failure();
172 signatureConversion.addInputs(i, loweredArgType);
175 mlir::Type loweredReturnType =
176 getTypeConverter()->convertType(opFuncType.getReturnType());
177 if (!loweredReturnType)
178 return mlir::failure();
180 auto loweredFuncType = cir::FuncType::get(
181 signatureConversion.getConvertedTypes(), loweredReturnType,
182 opFuncType.getVarArg());
// Signature unchanged: nothing to lower.
185 if (loweredFuncType == opFuncType)
186 return mlir::failure();
// Build the replacement without regions, then move the original body into it.
188 cir::FuncOp loweredFuncOp = rewriter.cloneWithoutRegions(op);
189 loweredFuncOp.setFunctionType(loweredFuncType);
190 rewriter.inlineRegionBefore(op.getBody(), loweredFuncOp.getBody(),
191 loweredFuncOp.end());
// NOTE(review): the body has already been moved when this failure path is
// taken; confirm the conversion driver rolls that change back.
192 if (mlir::failed(rewriter.convertRegionTypes(&loweredFuncOp.getBody(),
194 &signatureConversion)))
195 return mlir::failure();
197 rewriter.eraseOp(op);
198 return mlir::success();
// Fragment of convertSyncScopeIfPresent(). The function header and several
// lines are missing from this copy (gaps in the leading line numbers).
// What is visible:
//  - it reads the op's "sync_scope" attribute as a cir::SyncScopeKindAttr;
//  - it computes `convertedSyncScope`; source line 210 is missing, presumably
//    a call to TargetLoweringInfo::convertSyncScope (TODO confirm);
//  - it stores the result back under the same attribute name.
// NOTE(review): cast_if_present asserts if "sync_scope" is present but is not a
// SyncScopeKindAttr; use dyn_cast_if_present if that can happen.
207 mlir::cast_if_present<cir::SyncScopeKindAttr>(op->getAttr(
"sync_scope"));
209 cir::SyncScopeKind convertedSyncScope =
211 syncScopeAttr.getValue());
212 op->setAttr(
"sync_scope", cir::SyncScopeKindAttr::get(op->getContext(),
213 convertedSyncScope));
// Fragment of prepareTargetLoweringTypeConverter(). The function header and
// several lines are missing from this copy (gaps in the leading line numbers).
// Identity conversion, registered first. MLIR tries the most recently added
// conversion first, so this one applies only when no more specific one does.
222 converter.addConversion([](mlir::Type type) {
return type; });
// Pointers: convert the pointee recursively, and remap a
// cir::LangAddressSpaceAttr address space to a cir::TargetAddressSpaceAttr.
// `targetInfo` is captured by reference and must outlive the converter.
224 converter.addConversion([&converter,
225 &targetInfo](cir::PointerType type) -> mlir::Type {
226 mlir::Type pointee = converter.convertType(type.getPointee());
229 auto addrSpace = type.getAddrSpace();
231 mlir::dyn_cast_if_present<cir::LangAddressSpaceAttr>(addrSpace)) {
237 : cir::TargetAddressSpaceAttr::get(type.getContext(), targetAS);
239 return cir::PointerType::get(type.getContext(), pointee, addrSpace);
// Arrays: convert the element type; a null (failed) element conversion is
// checked before the array type is rebuilt.
242 converter.addConversion([&converter](cir::ArrayType type) -> mlir::Type {
243 mlir::Type loweredElementType =
244 converter.convertType(type.getElementType());
245 if (!loweredElementType)
247 return cir::ArrayType::get(loweredElementType, type.getSize());
// Function types: convert every input type and the return type.
250 converter.addConversion([&converter](cir::FuncType type) -> mlir::Type {
252 loweredInputTypes.reserve(type.getNumInputs());
254 converter.convertTypes(type.getInputs(), loweredInputTypes)))
257 mlir::Type loweredReturnType = converter.convertType(type.getReturnType());
258 if (!loweredReturnType)
261 return cir::FuncType::get(loweredInputTypes, loweredReturnType,
268 const mlir::TypeConverter &tc) {
// Fragment of populateTargetLoweringConversionTarget(). The first signature
// line and several body lines are missing from this copy (gaps in the leading
// line numbers). `tc` is captured by reference in every legality callback, so
// it must outlive `target`.
// The module op itself is always legal.
269 target.addLegalOp<mlir::ModuleOp>();
// CIR ops are legal when their types are legal for `tc`. The visible part
// checks that every region is legal; the operand/result check (source lines
// 273-275) is not visible here.
271 target.addDynamicallyLegalDialect<cir::CIRDialect>(
272 [&tc](mlir::Operation *op) {
276 op->getRegions().begin(), op->getRegions().end(),
277 [&tc](mlir::Region &region) { return tc.isLegal(&region); });
// Functions are legal exactly when their signature type is legal.
280 target.addDynamicallyLegalOp<cir::FuncOp>(
281 [&tc](cir::FuncOp op) {
return tc.isLegal(op.getFunctionType()); });
// Globals are legal when their symbol type is legal and they no longer carry a
// language-level (cir::LangAddressSpaceAttr) address space.
283 target.addDynamicallyLegalOp<cir::GlobalOp>([&tc](cir::GlobalOp op) {
284 if (!tc.isLegal(op.getSymType()))
286 return !mlir::isa_and_present<cir::LangAddressSpaceAttr>(
287 op.getAddrSpaceAttr());
// Runs target lowering over the whole module in three steps:
//  1. obtain a LowerModule and its TargetLoweringInfo;
//  2. visit memory/atomic ops;
//  3. apply a partial dialect conversion with the target-aware type converter
//     and patterns.
// NOTE(review): several source lines of this function are missing from this
// copy (gaps in the leading line numbers). Missing pieces include:
//  - the creation of `lowerModule` and its failure check;
//  - the rest of the warning message;
//  - the walk body;
//  - the setup of the type converter and conversion target.
291void TargetLoweringPass::runOnOperation() {
292 auto mod = mlir::cast<mlir::ModuleOp>(getOperation());
296 mod.emitWarning(
"Cannot create a CIR lower module, skipping the ")
301 const auto &targetInfo = lowerModule->getTargetLoweringInfo();
// Memory and atomic ops are visited before the conversion. The body is not
// visible here; presumably it calls convertSyncScopeIfPresent (TODO confirm).
303 mod->walk([&](mlir::Operation *op) {
304 if (mlir::isa<cir::LoadOp, cir::StoreOp, cir::AtomicXchgOp,
305 cir::AtomicCmpXchgOp, cir::AtomicFetchOp>(op))
310 mlir::TypeConverter typeConverter;
// Dedicated patterns for globals and functions; every other op falls through
// to the generic pattern. The patterns hold references to `typeConverter`
// (and CIRGlobalOpTargetLowering to `targetInfo`), so both must stay alive
// until the conversion below finishes.
313 mlir::RewritePatternSet patterns(mod.getContext());
314 patterns.add<CIRGlobalOpTargetLowering>(mod.getContext(), typeConverter,
316 patterns.add<CIRFuncOpTargetLowering>(typeConverter, mod.getContext());
317 patterns.add<CIRGenericTargetLoweringPattern>(mod.getContext(),
320 mlir::ConversionTarget target(*mod.getContext());
// Partial conversion: ops without a matching legality rule are left untouched.
323 if (failed(mlir::applyPartialConversion(mod, target, std::move(patterns))))
// Body of the createTargetLoweringPass() factory: returns a new, owned
// TargetLoweringPass instance.
328 return std::make_unique<TargetLoweringPass>();
static void populateTargetLoweringConversionTarget(mlir::ConversionTarget &target, const mlir::TypeConverter &tc)
static void prepareTargetLoweringTypeConverter(mlir::TypeConverter &converter, const cir::TargetLoweringInfo &targetInfo)
Prepare the type converter for the target lowering pass.
static void convertSyncScopeIfPresent(mlir::Operation *op, cir::LowerModule &lowerModule)
const TargetLoweringInfo & getTargetLoweringInfo()
virtual unsigned getTargetAddrSpaceFromCIRAddrSpace(cir::LangAddressSpace addrSpace) const
virtual cir::SyncScopeKind convertSyncScope(cir::SyncScopeKind syncScope) const
std::unique_ptr< LowerModule > createLowerModule(mlir::ModuleOp module)
const internal::VariadicAllOfMatcher< Attr > attr
StringRef getName(const HeaderType T)
nullptr
This class represents a compute construct, representing a 'Kind' of 'parallel', 'serial',...
std::unique_ptr< Pass > createTargetLoweringPass()