12#include "mlir/IR/PatternMatch.h"
13#include "mlir/Interfaces/DataLayoutInterfaces.h"
14#include "mlir/Pass/Pass.h"
15#include "mlir/Transforms/DialectConversion.h"
28#define GEN_PASS_DEF_CXXABILOWERING
29#include "clang/CIR/Dialect/Passes.h.inc"
35class CIROpCXXABILoweringPattern :
public mlir::OpConversionPattern<Op> {
37 mlir::DataLayout *dataLayout;
38 cir::LowerModule *lowerModule;
41 CIROpCXXABILoweringPattern(mlir::MLIRContext *context,
42 const mlir::TypeConverter &typeConverter,
43 mlir::DataLayout &dataLayout,
44 cir::LowerModule &lowerModule)
45 : mlir::OpConversionPattern<Op>(typeConverter, context),
46 dataLayout(&dataLayout), lowerModule(&lowerModule) {}
50#define CIR_CXXABI_LOWERING_PATTERN(name, operation) \
51 struct name : CIROpCXXABILoweringPattern<operation> { \
52 using CIROpCXXABILoweringPattern<operation>::CIROpCXXABILoweringPattern; \
55 matchAndRewrite(operation op, OpAdaptor adaptor, \
56 mlir::ConversionPatternRewriter &rewriter) const override; \
63 cir::GetRuntimeMemberOp);
65#undef CIR_CXXABI_LOWERING_PATTERN
67struct CXXABILoweringPass
68 :
public impl::CXXABILoweringBase<CXXABILoweringPass> {
69 CXXABILoweringPass() =
default;
70 void runOnOperation()
override;
78class CIRGenericCXXABILoweringPattern :
public mlir::ConversionPattern {
80 CIRGenericCXXABILoweringPattern(mlir::MLIRContext *context,
81 const mlir::TypeConverter &typeConverter)
82 : mlir::ConversionPattern(typeConverter, MatchAnyOpTypeTag(),
86 matchAndRewrite(mlir::Operation *op, llvm::ArrayRef<mlir::Value> operands,
87 mlir::ConversionPatternRewriter &rewriter)
const override {
89 if (llvm::isa<cir::AllocaOp, cir::ConstantOp, cir::FuncOp,
90 cir::GetRuntimeMemberOp, cir::GlobalOp>(op))
91 return mlir::failure();
93 const mlir::TypeConverter *typeConverter = getTypeConverter();
94 assert(typeConverter &&
95 "CIRGenericCXXABILoweringPattern requires a type converter");
96 bool operandsAndResultsLegal = typeConverter->isLegal(op);
98 std::all_of(op->getRegions().begin(), op->getRegions().end(),
99 [typeConverter](mlir::Region ®ion) {
100 return typeConverter->isLegal(®ion);
102 if (operandsAndResultsLegal && regionsLegal) {
105 return mlir::failure();
108 assert(op->getNumRegions() == 0 &&
"CIRGenericCXXABILoweringPattern cannot "
109 "deal with operations with regions");
111 mlir::OperationState loweredOpState(op->getLoc(), op->getName());
112 loweredOpState.addOperands(operands);
113 loweredOpState.addAttributes(op->getAttrs());
114 loweredOpState.addSuccessors(op->getSuccessors());
117 llvm::SmallVector<mlir::Type> loweredResultTypes;
118 loweredResultTypes.reserve(op->getNumResults());
119 for (mlir::Type result : op->getResultTypes())
120 loweredResultTypes.push_back(typeConverter->convertType(result));
121 loweredOpState.addTypes(loweredResultTypes);
124 for (mlir::Region ®ion : op->getRegions()) {
125 mlir::Region *loweredRegion = loweredOpState.addRegion();
126 rewriter.inlineRegionBefore(region, *loweredRegion, loweredRegion->end());
128 rewriter.convertRegionTypes(loweredRegion, *getTypeConverter())))
129 return mlir::failure();
133 mlir::Operation *loweredOp = rewriter.create(loweredOpState);
135 rewriter.replaceOp(op, loweredOp);
136 return mlir::success();
142mlir::LogicalResult CIRAllocaOpABILowering::matchAndRewrite(
143 cir::AllocaOp op, OpAdaptor adaptor,
144 mlir::ConversionPatternRewriter &rewriter)
const {
145 mlir::Type allocaPtrTy = op.getType();
146 mlir::Type allocaTy = op.getAllocaType();
147 mlir::Type loweredAllocaPtrTy = getTypeConverter()->convertType(allocaPtrTy);
148 mlir::Type loweredAllocaTy = getTypeConverter()->convertType(allocaTy);
150 cir::AllocaOp loweredOp = cir::AllocaOp::create(
151 rewriter, op.getLoc(), loweredAllocaPtrTy, loweredAllocaTy, op.getName(),
152 op.getAlignmentAttr(), adaptor.getDynAllocSize());
153 loweredOp.setInit(op.getInit());
154 loweredOp.setConstant(op.getConstant());
155 loweredOp.setAnnotationsAttr(op.getAnnotationsAttr());
157 rewriter.replaceOp(op, loweredOp);
158 return mlir::success();
161mlir::LogicalResult CIRConstantOpABILowering::matchAndRewrite(
162 cir::ConstantOp op, OpAdaptor adaptor,
163 mlir::ConversionPatternRewriter &rewriter)
const {
165 if (mlir::isa<cir::DataMemberType>(op.getType())) {
166 auto dataMember = mlir::cast<cir::DataMemberAttr>(op.getValue());
167 mlir::DataLayout layout(op->getParentOfType<mlir::ModuleOp>());
168 mlir::TypedAttr abiValue = lowerModule->getCXXABI().lowerDataMemberConstant(
169 dataMember, layout, *getTypeConverter());
170 rewriter.replaceOpWithNewOp<ConstantOp>(op, abiValue);
171 return mlir::success();
174 llvm_unreachable(
"constant operand is not an CXXABI-dependent type");
177mlir::LogicalResult CIRFuncOpABILowering::matchAndRewrite(
178 cir::FuncOp op, OpAdaptor adaptor,
179 mlir::ConversionPatternRewriter &rewriter)
const {
180 cir::FuncType opFuncType = op.getFunctionType();
181 mlir::TypeConverter::SignatureConversion signatureConversion(
182 opFuncType.getNumInputs());
184 for (
const auto &[i, argType] : llvm::enumerate(opFuncType.getInputs())) {
185 mlir::Type loweredArgType = getTypeConverter()->convertType(argType);
187 return mlir::failure();
188 signatureConversion.addInputs(i, loweredArgType);
191 mlir::Type loweredResultType =
192 getTypeConverter()->convertType(opFuncType.getReturnType());
193 if (!loweredResultType)
194 return mlir::failure();
196 auto loweredFuncType =
197 cir::FuncType::get(signatureConversion.getConvertedTypes(),
198 loweredResultType, opFuncType.isVarArg());
201 cir::FuncOp loweredFuncOp = rewriter.cloneWithoutRegions(op);
202 loweredFuncOp.setFunctionType(loweredFuncType);
203 rewriter.inlineRegionBefore(op.getBody(), loweredFuncOp.getBody(),
204 loweredFuncOp.end());
205 if (mlir::failed(rewriter.convertRegionTypes(
206 &loweredFuncOp.getBody(), *getTypeConverter(), &signatureConversion)))
207 return mlir::failure();
209 rewriter.eraseOp(op);
210 return mlir::success();
213mlir::LogicalResult CIRGlobalOpABILowering::matchAndRewrite(
214 cir::GlobalOp op, OpAdaptor adaptor,
215 mlir::ConversionPatternRewriter &rewriter)
const {
216 mlir::Type ty = op.getSymType();
217 mlir::Type loweredTy = getTypeConverter()->convertType(ty);
219 return mlir::failure();
221 mlir::DataLayout layout(op->getParentOfType<mlir::ModuleOp>());
223 mlir::Attribute loweredInit;
224 if (mlir::isa<cir::DataMemberType>(ty)) {
225 cir::DataMemberAttr init =
226 mlir::cast_if_present<cir::DataMemberAttr>(op.getInitialValueAttr());
227 loweredInit = lowerModule->getCXXABI().lowerDataMemberConstant(
228 init, layout, *getTypeConverter());
231 "inputs to cir.global in ABI lowering must be data member or method");
234 auto newOp = mlir::cast<cir::GlobalOp>(rewriter.clone(*op.getOperation()));
235 newOp.setInitialValueAttr(loweredInit);
236 newOp.setSymType(loweredTy);
237 rewriter.replaceOp(op, newOp);
238 return mlir::success();
241mlir::LogicalResult CIRGetRuntimeMemberOpABILowering::matchAndRewrite(
242 cir::GetRuntimeMemberOp op, OpAdaptor adaptor,
243 mlir::ConversionPatternRewriter &rewriter)
const {
244 mlir::Type resTy = getTypeConverter()->convertType(op.getType());
245 mlir::Operation *newOp = lowerModule->getCXXABI().lowerGetRuntimeMember(
246 op, resTy, adaptor.getAddr(), adaptor.getMember(), rewriter);
247 rewriter.replaceOp(op, newOp);
248 return mlir::success();
254 mlir::DataLayout &dataLayout,
256 converter.addConversion([&](mlir::Type type) -> mlir::Type {
return type; });
259 converter.addConversion([&](cir::PointerType type) -> mlir::Type {
260 mlir::Type loweredPointeeType = converter.convertType(type.getPointee());
261 if (!loweredPointeeType)
263 return cir::PointerType::get(type.getContext(), loweredPointeeType,
264 type.getAddrSpace());
266 converter.addConversion([&](cir::DataMemberType type) -> mlir::Type {
269 return converter.convertType(abiType);
273 converter.addConversion([&](cir::FuncType type) -> mlir::Type {
275 loweredInputTypes.reserve(type.getNumInputs());
277 converter.convertTypes(type.getInputs(), loweredInputTypes)))
280 mlir::Type loweredReturnType = converter.convertType(type.getReturnType());
281 if (!loweredReturnType)
284 return cir::FuncType::get(loweredInputTypes, loweredReturnType,
291 const mlir::TypeConverter &typeConverter) {
292 target.addLegalOp<mlir::ModuleOp>();
297 target.addDynamicallyLegalDialect<cir::CIRDialect>(
298 [&typeConverter](mlir::Operation *op) {
299 if (!typeConverter.isLegal(op))
301 return std::all_of(op->getRegions().begin(), op->getRegions().end(),
302 [&typeConverter](mlir::Region ®ion) {
303 return typeConverter.isLegal(®ion);
308 target.addDynamicallyLegalOp<cir::FuncOp>([&typeConverter](cir::FuncOp op) {
309 return typeConverter.isLegal(op.getFunctionType());
311 target.addDynamicallyLegalOp<cir::GlobalOp>(
312 [&typeConverter](cir::GlobalOp op) {
313 return typeConverter.isLegal(op.getSymType());
321void CXXABILoweringPass::runOnOperation() {
322 auto module = mlir::cast<mlir::ModuleOp>(getOperation());
323 mlir::MLIRContext *ctx =
module.getContext();
329 if (!module->hasAttr(cir::CIRDialect::getTripleAttrName()))
332 mlir::PatternRewriter rewriter(ctx);
333 std::unique_ptr<cir::LowerModule> lowerModule =
336 mlir::DataLayout dataLayout(module);
337 mlir::TypeConverter typeConverter;
340 mlir::RewritePatternSet patterns(ctx);
341 patterns.add<CIRGenericCXXABILoweringPattern>(patterns.getContext(),
345 CIRAllocaOpABILowering,
346 CIRConstantOpABILowering,
347 CIRFuncOpABILowering,
348 CIRGetRuntimeMemberOpABILowering,
349 CIRGlobalOpABILowering
351 >(patterns.getContext(), typeConverter, dataLayout, *lowerModule);
353 mlir::ConversionTarget target(*ctx);
356 if (failed(mlir::applyPartialConversion(module, target, std::move(patterns))))
361 return std::make_unique<CXXABILoweringPass>();
static void populateCXXABIConversionTarget(mlir::ConversionTarget &target, const mlir::TypeConverter &typeConverter)
static void prepareCXXABITypeConverter(mlir::TypeConverter &converter, mlir::DataLayout &dataLayout, cir::LowerModule &lowerModule)
#define CIR_CXXABI_LOWERING_PATTERN(name, operation)
virtual mlir::Type lowerDataMemberType(cir::DataMemberType type, const mlir::TypeConverter &typeConverter) const =0
Lower the given data member pointer type to its ABI type.
CIRCXXABI & getCXXABI() const
Defines the clang::TargetInfo interface.
std::unique_ptr< LowerModule > createLowerModule(mlir::ModuleOp module, mlir::PatternRewriter &rewriter)
std::unique_ptr< Pass > createCXXABILoweringPass()
static bool makeTripleAlwaysPresent()