clang 22.0.0git
CIRSimplify.cpp
Go to the documentation of this file.
1//===----------------------------------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "PassDetail.h"
10#include "mlir/Dialect/Func/IR/FuncOps.h"
11#include "mlir/IR/Block.h"
12#include "mlir/IR/Operation.h"
13#include "mlir/IR/PatternMatch.h"
14#include "mlir/IR/Region.h"
15#include "mlir/Support/LogicalResult.h"
16#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
19#include "llvm/ADT/SmallVector.h"
20
21using namespace mlir;
22using namespace cir;
23
24//===----------------------------------------------------------------------===//
25// Rewrite patterns
26//===----------------------------------------------------------------------===//
27
28namespace {
29
30/// Simplify suitable ternary operations into select operations.
31///
32/// For now we only simplify those ternary operations whose true and false
33/// branches directly yield a value or a constant. That is, both of the true and
34/// the false branch must either contain a cir.yield operation as the only
35/// operation in the branch, or contain a cir.const operation followed by a
36/// cir.yield operation that yields the constant value.
37///
38/// For example, we will simplify the following ternary operation:
39///
40/// %0 = ...
41/// %1 = cir.ternary (%condition, true {
42/// %2 = cir.const ...
43/// cir.yield %2
44/// } false {
45/// cir.yield %0
46///
47/// into the following sequence of operations:
48///
49/// %1 = cir.const ...
50/// %0 = cir.select if %condition then %1 else %2
51struct SimplifyTernary final : public OpRewritePattern<TernaryOp> {
52 using OpRewritePattern<TernaryOp>::OpRewritePattern;
53
54 LogicalResult matchAndRewrite(TernaryOp op,
55 PatternRewriter &rewriter) const override {
56 if (op->getNumResults() != 1)
57 return mlir::failure();
58
59 if (!isSimpleTernaryBranch(op.getTrueRegion()) ||
60 !isSimpleTernaryBranch(op.getFalseRegion()))
61 return mlir::failure();
62
63 cir::YieldOp trueBranchYieldOp =
64 mlir::cast<cir::YieldOp>(op.getTrueRegion().front().getTerminator());
65 cir::YieldOp falseBranchYieldOp =
66 mlir::cast<cir::YieldOp>(op.getFalseRegion().front().getTerminator());
67 mlir::Value trueValue = trueBranchYieldOp.getArgs()[0];
68 mlir::Value falseValue = falseBranchYieldOp.getArgs()[0];
69
70 rewriter.inlineBlockBefore(&op.getTrueRegion().front(), op);
71 rewriter.inlineBlockBefore(&op.getFalseRegion().front(), op);
72 rewriter.eraseOp(trueBranchYieldOp);
73 rewriter.eraseOp(falseBranchYieldOp);
74 rewriter.replaceOpWithNewOp<cir::SelectOp>(op, op.getCond(), trueValue,
75 falseValue);
76
77 return mlir::success();
78 }
79
80private:
81 bool isSimpleTernaryBranch(mlir::Region &region) const {
82 if (!region.hasOneBlock())
83 return false;
84
85 mlir::Block &onlyBlock = region.front();
86 mlir::Block::OpListType &ops = onlyBlock.getOperations();
87
88 // The region/block could only contain at most 2 operations.
89 if (ops.size() > 2)
90 return false;
91
92 if (ops.size() == 1) {
93 // The region/block only contain a cir.yield operation.
94 return true;
95 }
96
97 // Check whether the region/block contains a cir.const followed by a
98 // cir.yield that yields the value.
99 auto yieldOp = mlir::cast<cir::YieldOp>(onlyBlock.getTerminator());
100 auto yieldValueDefOp =
101 yieldOp.getArgs()[0].getDefiningOp<cir::ConstantOp>();
102 return yieldValueDefOp && yieldValueDefOp->getBlock() == &onlyBlock;
103 }
104};
105
106/// Simplify select operations with boolean constants into simpler forms.
107///
108/// This pattern simplifies select operations where both true and false values
109/// are boolean constants. Two specific cases are handled:
110///
111/// 1. When selecting between true and false based on a condition,
112/// the operation simplifies to just the condition itself:
113///
114/// %0 = cir.select if %condition then true else false
115/// ->
116/// (replaced with %condition directly)
117///
118/// 2. When selecting between false and true based on a condition,
119/// the operation simplifies to the logical negation of the condition:
120///
121/// %0 = cir.select if %condition then false else true
122/// ->
123/// %0 = cir.unary not %condition
124struct SimplifySelect : public OpRewritePattern<SelectOp> {
125 using OpRewritePattern<SelectOp>::OpRewritePattern;
126
127 LogicalResult matchAndRewrite(SelectOp op,
128 PatternRewriter &rewriter) const final {
129 auto trueValueOp = op.getTrueValue().getDefiningOp<cir::ConstantOp>();
130 auto falseValueOp = op.getFalseValue().getDefiningOp<cir::ConstantOp>();
131 if (!trueValueOp || !falseValueOp)
132 return mlir::failure();
133
134 auto trueValue = trueValueOp.getValueAttr<cir::BoolAttr>();
135 auto falseValue = falseValueOp.getValueAttr<cir::BoolAttr>();
136 if (!trueValue || !falseValue)
137 return mlir::failure();
138
139 // cir.select if %0 then #true else #false -> %0
140 if (trueValue.getValue() && !falseValue.getValue()) {
141 rewriter.replaceAllUsesWith(op, op.getCondition());
142 rewriter.eraseOp(op);
143 return mlir::success();
144 }
145
146 // cir.select if %0 then #false else #true -> cir.unary not %0
147 if (!trueValue.getValue() && falseValue.getValue()) {
148 rewriter.replaceOpWithNewOp<cir::UnaryOp>(op, cir::UnaryOpKind::Not,
149 op.getCondition());
150 return mlir::success();
151 }
152
153 return mlir::failure();
154 }
155};
156
/// Simplify `cir.switch` operations by folding cascading cases
/// into a single `cir.case` with the `anyof` kind.
///
/// This pattern identifies cascading cases within a `cir.switch` operation.
/// Cascading cases are defined as consecutive `cir.case` operations of kind
/// `equal` whose body consists solely of a `cir.yield` (the code checks that
/// the first operation of the case region is a `cir.yield`).
///
/// The pattern merges these cascading cases into a single `cir.case` operation
/// with kind `anyof`, aggregating all the case values.
///
/// The merging process continues until a `cir.case` with a different body
/// (e.g., containing `cir.break` or a compound statement) is encountered,
/// which breaks the chain:
///   - If the breaking case is itself an `equal` case, the collected values
///     plus its own value are merged into that breaking case.
///   - If the breaking case has any other kind (default, anyof, range) and at
///     least two cascading cases were collected, the collected values are
///     merged into the last cascading case, which keeps its `cir.yield` body.
/// If every case of the switch is a cascading case, all of them are merged
/// into the last one.
///
/// Example:
///
/// Before:
///   cir.case equal, [#cir.int<0> : !s32i] {
///     cir.yield
///   }
///   cir.case equal, [#cir.int<1> : !s32i] {
///     cir.yield
///   }
///   cir.case equal, [#cir.int<2> : !s32i] {
///     cir.break
///   }
///
/// After applying SimplifySwitch:
///   cir.case anyof, [#cir.int<0> : !s32i, #cir.int<1> : !s32i, #cir.int<2> :
///   !s32i] {
///     cir.break
///   }
struct SimplifySwitch : public OpRewritePattern<SwitchOp> {
  using OpRewritePattern<SwitchOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(SwitchOp op,
                                PatternRewriter &rewriter) const override {

    // Becomes success once any case has been rewritten into an `anyof` case.
    LogicalResult changed = mlir::failure();
    SmallVector<CaseOp, 8> cases;
    // The current run of yield-only `equal` cases, in source order.
    SmallVector<CaseOp, 4> cascadingCases;
    // Values to place in the merged `anyof` case.
    SmallVector<mlir::Attribute, 4> cascadingCaseValues;

    op.collectCases(cases);
    if (cases.empty())
      return mlir::failure();

    // Erase every case still listed in `cascadingCases` and reset the run.
    // Callers remove the merge target from the list first, so only the
    // absorbed cases are erased.
    auto flushMergedOps = [&]() {
      for (CaseOp &c : cascadingCases)
        rewriter.eraseOp(c);
      cascadingCases.clear();
      cascadingCaseValues.clear();
    };

    // Turn `target` into an `anyof` case over all collected values.
    auto mergeCascadingInto = [&](CaseOp &target) {
      rewriter.modifyOpInPlace(target, [&]() {
        target.setValueAttr(rewriter.getArrayAttr(cascadingCaseValues));
        target.setKind(CaseOpKind::Anyof);
      });
      changed = mlir::success();
    };

    for (CaseOp c : cases) {
      cir::CaseOpKind kind = c.getKind();
      if (kind == cir::CaseOpKind::Equal &&
          isa<YieldOp>(c.getCaseRegion().front().front())) {
        // If the case contains only a YieldOp, collect it for cascading merge
        cascadingCases.push_back(c);
        cascadingCaseValues.push_back(c.getValue()[0]);
      } else if (kind == cir::CaseOpKind::Equal && !cascadingCases.empty()) {
        // merge previously collected cascading cases
        // The breaking `equal` case `c` becomes the merge target; it was never
        // pushed onto `cascadingCases`, so the flush erases only the absorbed
        // yield-only cases.
        cascadingCaseValues.push_back(c.getValue()[0]);
        mergeCascadingInto(c);
        flushMergedOps();
      } else if (kind != cir::CaseOpKind::Equal && cascadingCases.size() > 1) {
        // If a Default, Anyof or Range case is found and there are previous
        // cascading cases, merge all of them into the last cascading case.
        // We don't currently fold case range statements with other case
        // statements.
        // The target is popped before flushing so that it survives the erase.
        CaseOp lastCascadingCase = cascadingCases.back();
        mergeCascadingInto(lastCascadingCase);
        cascadingCases.pop_back();
        flushMergedOps();
      } else {
        // Reached for a non-yield `equal` case with no pending run, or for a
        // non-`equal` case with at most one pending cascading case: nothing
        // to merge, so start a fresh run.
        cascadingCases.clear();
        cascadingCaseValues.clear();
      }
    }

    // Edge case: all cases are simple cascading cases
    // `cases` is non-empty (checked above), so `back()` is valid here.
    // NOTE(review): a trailing run of cascading cases that does not cover
    // the whole switch is not merged. Also, when the switch has exactly one
    // yield-only `equal` case, this branch rewrites it into a one-value
    // `anyof` case. Confirm both behaviors are intended.
    if (cascadingCases.size() == cases.size()) {
      CaseOp lastCascadingCase = cascadingCases.back();
      mergeCascadingInto(lastCascadingCase);
      cascadingCases.pop_back();
      flushMergedOps();
    }

    return changed;
  }
};
257
258struct SimplifyVecSplat : public OpRewritePattern<VecSplatOp> {
259 using OpRewritePattern<VecSplatOp>::OpRewritePattern;
260 LogicalResult matchAndRewrite(VecSplatOp op,
261 PatternRewriter &rewriter) const override {
262 mlir::Value splatValue = op.getValue();
263 auto constant = splatValue.getDefiningOp<cir::ConstantOp>();
264 if (!constant)
265 return mlir::failure();
266
267 auto value = constant.getValue();
268 if (!mlir::isa_and_nonnull<cir::IntAttr>(value) &&
269 !mlir::isa_and_nonnull<cir::FPAttr>(value))
270 return mlir::failure();
271
272 cir::VectorType resultType = op.getResult().getType();
273 SmallVector<mlir::Attribute, 16> elements(resultType.getSize(), value);
274 auto constVecAttr = cir::ConstVectorAttr::get(
275 resultType, mlir::ArrayAttr::get(getContext(), elements));
276
277 rewriter.replaceOpWithNewOp<cir::ConstantOp>(op, constVecAttr);
278 return mlir::success();
279 }
280};
281
282//===----------------------------------------------------------------------===//
283// CIRSimplifyPass
284//===----------------------------------------------------------------------===//
285
286struct CIRSimplifyPass : public CIRSimplifyBase<CIRSimplifyPass> {
287 using CIRSimplifyBase::CIRSimplifyBase;
288
289 void runOnOperation() override;
290};
291
292void populateMergeCleanupPatterns(RewritePatternSet &patterns) {
293 // clang-format off
294 patterns.add<
295 SimplifyTernary,
296 SimplifySelect,
297 SimplifySwitch,
298 SimplifyVecSplat
299 >(patterns.getContext());
300 // clang-format on
301}
302
303void CIRSimplifyPass::runOnOperation() {
304 // Collect rewrite patterns.
305 RewritePatternSet patterns(&getContext());
306 populateMergeCleanupPatterns(patterns);
307
308 // Collect operations to apply patterns.
310 getOperation()->walk([&](Operation *op) {
311 if (isa<TernaryOp, SelectOp, SwitchOp, VecSplatOp>(op))
312 ops.push_back(op);
313 });
314
315 // Apply patterns.
316 if (applyOpPatternsGreedily(ops, std::move(patterns)).failed())
317 signalPassFailure();
318}
319
320} // namespace
321
322std::unique_ptr<Pass> mlir::createCIRSimplifyPass() {
323 return std::make_unique<CIRSimplifyPass>();
324}
__device__ __2f16 float c
Definition: ABIArgInfo.h:22
unsigned kind
All of the diagnostics that can be emitted by the frontend.
Definition: DiagnosticIDs.h:76
std::unique_ptr< Pass > createCIRSimplifyPass()
static bool foldRangeCase()