10#include "mlir/Dialect/Func/IR/FuncOps.h"
11#include "mlir/IR/Block.h"
12#include "mlir/IR/Operation.h"
13#include "mlir/IR/PatternMatch.h"
14#include "mlir/IR/Region.h"
15#include "mlir/Support/LogicalResult.h"
16#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
19#include "llvm/ADT/SmallVector.h"
25#define GEN_PASS_DEF_CIRSIMPLIFY
26#include "clang/CIR/Dialect/Passes.h.inc"
56struct SimplifyTernary final :
public OpRewritePattern<TernaryOp> {
57 using OpRewritePattern<TernaryOp>::OpRewritePattern;
59 LogicalResult matchAndRewrite(TernaryOp op,
60 PatternRewriter &rewriter)
const override {
61 if (op->getNumResults() != 1)
62 return mlir::failure();
64 if (!isSimpleTernaryBranch(op.getTrueRegion()) ||
65 !isSimpleTernaryBranch(op.getFalseRegion()))
66 return mlir::failure();
68 cir::YieldOp trueBranchYieldOp =
69 mlir::cast<cir::YieldOp>(op.getTrueRegion().front().getTerminator());
70 cir::YieldOp falseBranchYieldOp =
71 mlir::cast<cir::YieldOp>(op.getFalseRegion().front().getTerminator());
72 mlir::Value trueValue = trueBranchYieldOp.getArgs()[0];
73 mlir::Value falseValue = falseBranchYieldOp.getArgs()[0];
75 rewriter.inlineBlockBefore(&op.getTrueRegion().front(), op);
76 rewriter.inlineBlockBefore(&op.getFalseRegion().front(), op);
77 rewriter.eraseOp(trueBranchYieldOp);
78 rewriter.eraseOp(falseBranchYieldOp);
79 rewriter.replaceOpWithNewOp<cir::SelectOp>(op, op.getCond(), trueValue,
82 return mlir::success();
86 bool isSimpleTernaryBranch(mlir::Region ®ion)
const {
87 if (!region.hasOneBlock())
90 mlir::Block &onlyBlock = region.front();
91 mlir::Block::OpListType &ops = onlyBlock.getOperations();
97 if (ops.size() == 1) {
104 auto yieldOp = mlir::cast<cir::YieldOp>(onlyBlock.getTerminator());
105 auto yieldValueDefOp =
106 yieldOp.getArgs()[0].getDefiningOp<cir::ConstantOp>();
107 return yieldValueDefOp && yieldValueDefOp->getBlock() == &onlyBlock;
129struct SimplifySelect :
public OpRewritePattern<SelectOp> {
130 using OpRewritePattern<SelectOp>::OpRewritePattern;
132 LogicalResult matchAndRewrite(SelectOp op,
133 PatternRewriter &rewriter)
const final {
134 auto trueValueOp = op.getTrueValue().getDefiningOp<cir::ConstantOp>();
135 auto falseValueOp = op.getFalseValue().getDefiningOp<cir::ConstantOp>();
136 if (!trueValueOp || !falseValueOp)
137 return mlir::failure();
139 auto trueValue = trueValueOp.getValueAttr<cir::BoolAttr>();
140 auto falseValue = falseValueOp.getValueAttr<cir::BoolAttr>();
141 if (!trueValue || !falseValue)
142 return mlir::failure();
145 if (trueValue.getValue() && !falseValue.getValue()) {
146 rewriter.replaceAllUsesWith(op, op.getCondition());
147 rewriter.eraseOp(op);
148 return mlir::success();
152 if (!trueValue.getValue() && falseValue.getValue()) {
153 rewriter.replaceOpWithNewOp<cir::UnaryOp>(op, cir::UnaryOpKind::Not,
155 return mlir::success();
158 return mlir::failure();
194struct SimplifySwitch :
public OpRewritePattern<SwitchOp> {
195 using OpRewritePattern<SwitchOp>::OpRewritePattern;
196 LogicalResult matchAndRewrite(SwitchOp op,
197 PatternRewriter &rewriter)
const override {
199 LogicalResult changed = mlir::failure();
200 SmallVector<CaseOp, 8> cases;
201 SmallVector<CaseOp, 4> cascadingCases;
202 SmallVector<mlir::Attribute, 4> cascadingCaseValues;
204 op.collectCases(cases);
206 return mlir::failure();
208 auto flushMergedOps = [&]() {
209 for (CaseOp &
c : cascadingCases)
211 cascadingCases.clear();
212 cascadingCaseValues.clear();
215 auto mergeCascadingInto = [&](CaseOp &target) {
216 rewriter.modifyOpInPlace(target, [&]() {
217 target.setValueAttr(rewriter.getArrayAttr(cascadingCaseValues));
218 target.setKind(CaseOpKind::Anyof);
220 changed = mlir::success();
223 for (CaseOp
c : cases) {
224 cir::CaseOpKind
kind =
c.getKind();
225 if (kind == cir::CaseOpKind::Equal &&
226 isa<YieldOp>(
c.getCaseRegion().front().front())) {
228 cascadingCases.push_back(
c);
229 cascadingCaseValues.push_back(
c.getValue()[0]);
230 }
else if (kind == cir::CaseOpKind::Equal && !cascadingCases.empty()) {
232 cascadingCaseValues.push_back(
c.getValue()[0]);
233 mergeCascadingInto(
c);
235 }
else if (kind != cir::CaseOpKind::Equal && cascadingCases.size() > 1) {
241 CaseOp lastCascadingCase = cascadingCases.back();
242 mergeCascadingInto(lastCascadingCase);
243 cascadingCases.pop_back();
246 cascadingCases.clear();
247 cascadingCaseValues.clear();
252 if (cascadingCases.size() == cases.size()) {
253 CaseOp lastCascadingCase = cascadingCases.back();
254 mergeCascadingInto(lastCascadingCase);
255 cascadingCases.pop_back();
263struct SimplifyVecSplat :
public OpRewritePattern<VecSplatOp> {
264 using OpRewritePattern<VecSplatOp>::OpRewritePattern;
265 LogicalResult matchAndRewrite(VecSplatOp op,
266 PatternRewriter &rewriter)
const override {
267 mlir::Value splatValue = op.getValue();
268 auto constant = splatValue.getDefiningOp<cir::ConstantOp>();
270 return mlir::failure();
272 auto value = constant.getValue();
273 if (!mlir::isa_and_nonnull<cir::IntAttr>(value) &&
274 !mlir::isa_and_nonnull<cir::FPAttr>(value))
275 return mlir::failure();
277 cir::VectorType resultType = op.getResult().getType();
278 SmallVector<mlir::Attribute, 16> elements(resultType.getSize(), value);
279 auto constVecAttr = cir::ConstVectorAttr::get(
280 resultType, mlir::ArrayAttr::get(getContext(), elements));
282 rewriter.replaceOpWithNewOp<cir::ConstantOp>(op, constVecAttr);
283 return mlir::success();
291struct CIRSimplifyPass :
public impl::CIRSimplifyBase<CIRSimplifyPass> {
292 using CIRSimplifyBase::CIRSimplifyBase;
294 void runOnOperation()
override;
297void populateMergeCleanupPatterns(RewritePatternSet &patterns) {
304 >(patterns.getContext());
308void CIRSimplifyPass::runOnOperation() {
310 RewritePatternSet patterns(&getContext());
311 populateMergeCleanupPatterns(patterns);
314 llvm::SmallVector<Operation *, 16> ops;
315 getOperation()->walk([&](Operation *op) {
316 if (isa<TernaryOp, SelectOp, SwitchOp, VecSplatOp>(op))
321 if (applyOpPatternsGreedily(ops, std::move(patterns)).failed())
328 return std::make_unique<CIRSimplifyPass>();
__device__ __2f16 float c
unsigned kind
All of the diagnostics that can be emitted by the frontend.
std::unique_ptr< Pass > createCIRSimplifyPass()
static bool foldRangeCase()