10#include "mlir/Dialect/Func/IR/FuncOps.h"
11#include "mlir/IR/Block.h"
12#include "mlir/IR/Operation.h"
13#include "mlir/IR/PatternMatch.h"
14#include "mlir/IR/Region.h"
15#include "mlir/Support/LogicalResult.h"
16#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
19#include "llvm/ADT/SmallVector.h"
51struct SimplifyTernary final :
public OpRewritePattern<TernaryOp> {
52 using OpRewritePattern<TernaryOp>::OpRewritePattern;
54 LogicalResult matchAndRewrite(TernaryOp op,
55 PatternRewriter &rewriter)
const override {
56 if (op->getNumResults() != 1)
57 return mlir::failure();
59 if (!isSimpleTernaryBranch(op.getTrueRegion()) ||
60 !isSimpleTernaryBranch(op.getFalseRegion()))
61 return mlir::failure();
63 cir::YieldOp trueBranchYieldOp =
64 mlir::cast<cir::YieldOp>(op.getTrueRegion().front().getTerminator());
65 cir::YieldOp falseBranchYieldOp =
66 mlir::cast<cir::YieldOp>(op.getFalseRegion().front().getTerminator());
67 mlir::Value trueValue = trueBranchYieldOp.getArgs()[0];
68 mlir::Value falseValue = falseBranchYieldOp.getArgs()[0];
70 rewriter.inlineBlockBefore(&op.getTrueRegion().front(), op);
71 rewriter.inlineBlockBefore(&op.getFalseRegion().front(), op);
72 rewriter.eraseOp(trueBranchYieldOp);
73 rewriter.eraseOp(falseBranchYieldOp);
74 rewriter.replaceOpWithNewOp<cir::SelectOp>(op, op.getCond(), trueValue,
77 return mlir::success();
81 bool isSimpleTernaryBranch(mlir::Region ®ion)
const {
82 if (!region.hasOneBlock())
85 mlir::Block &onlyBlock = region.front();
86 mlir::Block::OpListType &ops = onlyBlock.getOperations();
92 if (ops.size() == 1) {
99 auto yieldOp = mlir::cast<cir::YieldOp>(onlyBlock.getTerminator());
100 auto yieldValueDefOp =
101 yieldOp.getArgs()[0].getDefiningOp<cir::ConstantOp>();
102 return yieldValueDefOp && yieldValueDefOp->getBlock() == &onlyBlock;
124struct SimplifySelect :
public OpRewritePattern<SelectOp> {
125 using OpRewritePattern<SelectOp>::OpRewritePattern;
127 LogicalResult matchAndRewrite(SelectOp op,
128 PatternRewriter &rewriter)
const final {
129 auto trueValueOp = op.getTrueValue().getDefiningOp<cir::ConstantOp>();
130 auto falseValueOp = op.getFalseValue().getDefiningOp<cir::ConstantOp>();
131 if (!trueValueOp || !falseValueOp)
132 return mlir::failure();
134 auto trueValue = trueValueOp.getValueAttr<cir::BoolAttr>();
135 auto falseValue = falseValueOp.getValueAttr<cir::BoolAttr>();
136 if (!trueValue || !falseValue)
137 return mlir::failure();
140 if (trueValue.getValue() && !falseValue.getValue()) {
141 rewriter.replaceAllUsesWith(op, op.getCondition());
142 rewriter.eraseOp(op);
143 return mlir::success();
147 if (!trueValue.getValue() && falseValue.getValue()) {
148 rewriter.replaceOpWithNewOp<cir::UnaryOp>(op, cir::UnaryOpKind::Not,
150 return mlir::success();
153 return mlir::failure();
189struct SimplifySwitch :
public OpRewritePattern<SwitchOp> {
190 using OpRewritePattern<SwitchOp>::OpRewritePattern;
191 LogicalResult matchAndRewrite(SwitchOp op,
192 PatternRewriter &rewriter)
const override {
194 LogicalResult changed = mlir::failure();
195 SmallVector<CaseOp, 8> cases;
196 SmallVector<CaseOp, 4> cascadingCases;
197 SmallVector<mlir::Attribute, 4> cascadingCaseValues;
199 op.collectCases(cases);
201 return mlir::failure();
203 auto flushMergedOps = [&]() {
204 for (CaseOp &
c : cascadingCases)
206 cascadingCases.clear();
207 cascadingCaseValues.clear();
210 auto mergeCascadingInto = [&](CaseOp &target) {
211 rewriter.modifyOpInPlace(target, [&]() {
212 target.setValueAttr(rewriter.getArrayAttr(cascadingCaseValues));
213 target.setKind(CaseOpKind::Anyof);
215 changed = mlir::success();
218 for (CaseOp
c : cases) {
219 cir::CaseOpKind
kind =
c.getKind();
220 if (kind == cir::CaseOpKind::Equal &&
221 isa<YieldOp>(
c.getCaseRegion().front().front())) {
223 cascadingCases.push_back(
c);
224 cascadingCaseValues.push_back(
c.getValue()[0]);
225 }
else if (kind == cir::CaseOpKind::Equal && !cascadingCases.empty()) {
227 cascadingCaseValues.push_back(
c.getValue()[0]);
228 mergeCascadingInto(
c);
230 }
else if (kind != cir::CaseOpKind::Equal && cascadingCases.size() > 1) {
236 CaseOp lastCascadingCase = cascadingCases.back();
237 mergeCascadingInto(lastCascadingCase);
238 cascadingCases.pop_back();
241 cascadingCases.clear();
242 cascadingCaseValues.clear();
247 if (cascadingCases.size() == cases.size()) {
248 CaseOp lastCascadingCase = cascadingCases.back();
249 mergeCascadingInto(lastCascadingCase);
250 cascadingCases.pop_back();
258struct SimplifyVecSplat :
public OpRewritePattern<VecSplatOp> {
259 using OpRewritePattern<VecSplatOp>::OpRewritePattern;
260 LogicalResult matchAndRewrite(VecSplatOp op,
261 PatternRewriter &rewriter)
const override {
262 mlir::Value splatValue = op.getValue();
263 auto constant = splatValue.getDefiningOp<cir::ConstantOp>();
265 return mlir::failure();
267 auto value = constant.getValue();
268 if (!mlir::isa_and_nonnull<cir::IntAttr>(value) &&
269 !mlir::isa_and_nonnull<cir::FPAttr>(value))
270 return mlir::failure();
272 cir::VectorType resultType = op.getResult().getType();
273 SmallVector<mlir::Attribute, 16> elements(resultType.getSize(), value);
274 auto constVecAttr = cir::ConstVectorAttr::get(
275 resultType, mlir::ArrayAttr::get(getContext(), elements));
277 rewriter.replaceOpWithNewOp<cir::ConstantOp>(op, constVecAttr);
278 return mlir::success();
286struct CIRSimplifyPass :
public CIRSimplifyBase<CIRSimplifyPass> {
287 using CIRSimplifyBase::CIRSimplifyBase;
289 void runOnOperation()
override;
292void populateMergeCleanupPatterns(RewritePatternSet &patterns) {
299 >(patterns.getContext());
303void CIRSimplifyPass::runOnOperation() {
305 RewritePatternSet patterns(&getContext());
306 populateMergeCleanupPatterns(patterns);
310 getOperation()->walk([&](Operation *op) {
311 if (isa<TernaryOp, SelectOp, SwitchOp, VecSplatOp>(op))
316 if (applyOpPatternsGreedily(ops, std::move(patterns)).failed())
323 return std::make_unique<CIRSimplifyPass>();
__device__ __2f16 float c
unsigned kind
All of the diagnostics that can be emitted by the frontend.
std::unique_ptr< Pass > createCIRSimplifyPass()
static bool foldRangeCase()