#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/Block.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Region.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/SmallVector.h"
#include <memory>
#include <utility>

#define GEN_PASS_DEF_CIRSIMPLIFY
#include "clang/CIR/Dialect/Passes.h.inc"
56struct SimplifyTernary final :
public OpRewritePattern<TernaryOp> {
57 using OpRewritePattern<TernaryOp>::OpRewritePattern;
59 LogicalResult matchAndRewrite(TernaryOp op,
60 PatternRewriter &rewriter)
const override {
61 if (op->getNumResults() != 1)
62 return mlir::failure();
64 if (!isSimpleTernaryBranch(op.getTrueRegion()) ||
65 !isSimpleTernaryBranch(op.getFalseRegion()))
66 return mlir::failure();
68 cir::YieldOp trueBranchYieldOp =
69 mlir::cast<cir::YieldOp>(op.getTrueRegion().front().getTerminator());
70 cir::YieldOp falseBranchYieldOp =
71 mlir::cast<cir::YieldOp>(op.getFalseRegion().front().getTerminator());
72 mlir::Value trueValue = trueBranchYieldOp.getArgs()[0];
73 mlir::Value falseValue = falseBranchYieldOp.getArgs()[0];
75 rewriter.inlineBlockBefore(&op.getTrueRegion().front(), op);
76 rewriter.inlineBlockBefore(&op.getFalseRegion().front(), op);
77 rewriter.eraseOp(trueBranchYieldOp);
78 rewriter.eraseOp(falseBranchYieldOp);
79 rewriter.replaceOpWithNewOp<cir::SelectOp>(op, op.getCond(), trueValue,
82 return mlir::success();
86 bool isSimpleTernaryBranch(mlir::Region ®ion)
const {
87 if (!region.hasOneBlock())
90 mlir::Block &onlyBlock = region.front();
91 mlir::Block::OpListType &ops = onlyBlock.getOperations();
97 if (ops.size() == 1) {
104 auto yieldOp = mlir::cast<cir::YieldOp>(onlyBlock.getTerminator());
105 auto yieldValueDefOp =
106 yieldOp.getArgs()[0].getDefiningOp<cir::ConstantOp>();
107 return yieldValueDefOp && yieldValueDefOp->getBlock() == &onlyBlock;
129struct SimplifySelect :
public OpRewritePattern<SelectOp> {
130 using OpRewritePattern<SelectOp>::OpRewritePattern;
132 LogicalResult matchAndRewrite(SelectOp op,
133 PatternRewriter &rewriter)
const final {
134 auto trueValueOp = op.getTrueValue().getDefiningOp<cir::ConstantOp>();
135 auto falseValueOp = op.getFalseValue().getDefiningOp<cir::ConstantOp>();
136 if (!trueValueOp || !falseValueOp)
137 return mlir::failure();
139 auto trueValue = trueValueOp.getValueAttr<cir::BoolAttr>();
140 auto falseValue = falseValueOp.getValueAttr<cir::BoolAttr>();
141 if (!trueValue || !falseValue)
142 return mlir::failure();
145 if (trueValue.getValue() && !falseValue.getValue()) {
146 rewriter.replaceAllUsesWith(op, op.getCondition());
147 rewriter.eraseOp(op);
148 return mlir::success();
152 if (!trueValue.getValue() && falseValue.getValue()) {
153 rewriter.replaceOpWithNewOp<cir::NotOp>(op, op.getCondition());
154 return mlir::success();
157 return mlir::failure();
193struct SimplifySwitch :
public OpRewritePattern<SwitchOp> {
194 using OpRewritePattern<SwitchOp>::OpRewritePattern;
195 LogicalResult matchAndRewrite(SwitchOp op,
196 PatternRewriter &rewriter)
const override {
198 LogicalResult changed = mlir::failure();
199 SmallVector<CaseOp, 8> cases;
200 SmallVector<CaseOp, 4> cascadingCases;
201 SmallVector<mlir::Attribute, 4> cascadingCaseValues;
203 op.collectCases(cases);
205 return mlir::failure();
207 auto flushMergedOps = [&]() {
208 for (CaseOp &
c : cascadingCases)
210 cascadingCases.clear();
211 cascadingCaseValues.clear();
214 auto mergeCascadingInto = [&](CaseOp &target) {
215 rewriter.modifyOpInPlace(target, [&]() {
216 target.setValueAttr(rewriter.getArrayAttr(cascadingCaseValues));
217 target.setKind(CaseOpKind::Anyof);
219 changed = mlir::success();
222 for (CaseOp
c : cases) {
223 cir::CaseOpKind
kind =
c.getKind();
224 if (
kind == cir::CaseOpKind::Equal &&
225 isa<YieldOp>(
c.getCaseRegion().front().front())) {
227 cascadingCases.push_back(
c);
228 cascadingCaseValues.push_back(
c.getValue()[0]);
229 }
else if (
kind == cir::CaseOpKind::Equal && !cascadingCases.empty()) {
231 cascadingCaseValues.push_back(
c.getValue()[0]);
232 mergeCascadingInto(
c);
234 }
else if (
kind != cir::CaseOpKind::Equal && cascadingCases.size() > 1) {
240 CaseOp lastCascadingCase = cascadingCases.back();
241 mergeCascadingInto(lastCascadingCase);
242 cascadingCases.pop_back();
245 cascadingCases.clear();
246 cascadingCaseValues.clear();
251 if (cascadingCases.size() == cases.size()) {
252 CaseOp lastCascadingCase = cascadingCases.back();
253 mergeCascadingInto(lastCascadingCase);
254 cascadingCases.pop_back();
262struct SimplifyVecSplat :
public OpRewritePattern<VecSplatOp> {
263 using OpRewritePattern<VecSplatOp>::OpRewritePattern;
264 LogicalResult matchAndRewrite(VecSplatOp op,
265 PatternRewriter &rewriter)
const override {
266 mlir::Value splatValue = op.getValue();
267 auto constant = splatValue.getDefiningOp<cir::ConstantOp>();
269 return mlir::failure();
271 auto value = constant.getValue();
272 if (!mlir::isa_and_nonnull<cir::IntAttr>(value) &&
273 !mlir::isa_and_nonnull<cir::FPAttr>(value))
274 return mlir::failure();
276 cir::VectorType resultType = op.getResult().getType();
277 SmallVector<mlir::Attribute, 16> elements(resultType.getSize(), value);
278 auto constVecAttr = cir::ConstVectorAttr::get(
279 resultType, mlir::ArrayAttr::get(getContext(), elements));
281 rewriter.replaceOpWithNewOp<cir::ConstantOp>(op, constVecAttr);
282 return mlir::success();
290struct CIRSimplifyPass :
public impl::CIRSimplifyBase<CIRSimplifyPass> {
291 using CIRSimplifyBase::CIRSimplifyBase;
293 void runOnOperation()
override;
296void populateMergeCleanupPatterns(RewritePatternSet &patterns) {
303 >(patterns.getContext());
307void CIRSimplifyPass::runOnOperation() {
309 RewritePatternSet patterns(&getContext());
310 populateMergeCleanupPatterns(patterns);
313 llvm::SmallVector<Operation *, 16> ops;
314 getOperation()->walk([&](Operation *op) {
315 if (isa<TernaryOp, SelectOp, SwitchOp, VecSplatOp>(op))
320 if (applyOpPatternsGreedily(ops, std::move(patterns)).failed())
327 return std::make_unique<CIRSimplifyPass>();
*collection of selector each with an associated kind and an ordered *collection of selectors A selector has a kind
__device__ __2f16 float c
std::unique_ptr< Pass > createCIRSimplifyPass()
static bool foldRangeCase()