clang 22.0.0git
FlattenCFG.cpp
Go to the documentation of this file.
1//===----------------------------------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements pass that inlines CIR operations regions into the parent
10// function region.
11//
12//===----------------------------------------------------------------------===//
13
14#include "PassDetail.h"
15#include "mlir/Dialect/Func/IR/FuncOps.h"
16#include "mlir/IR/Block.h"
17#include "mlir/IR/Builders.h"
18#include "mlir/IR/PatternMatch.h"
19#include "mlir/Support/LogicalResult.h"
20#include "mlir/Transforms/DialectConversion.h"
21#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
25
26using namespace mlir;
27using namespace cir;
28
29namespace {
30
/// Lowers operations with the terminator trait that have a single successor.
///
/// Replaces \p op with an unconditional cir.br to \p dest. The insertion
/// guard restores the rewriter's previous insertion point on exit.
void lowerTerminator(mlir::Operation *op, mlir::Block *dest,
                     mlir::PatternRewriter &rewriter) {
  assert(op->hasTrait<mlir::OpTrait::IsTerminator>() && "not a terminator");
  // Insert the replacement branch exactly where the old terminator sits.
  mlir::OpBuilder::InsertionGuard guard(rewriter);
  rewriter.setInsertionPoint(op);
  rewriter.replaceOpWithNewOp<cir::BrOp>(op, dest);
}
39
40/// Walks a region while skipping operations of type `Ops`. This ensures the
41/// callback is not applied to said operations and its children.
42template <typename... Ops>
43void walkRegionSkipping(
44 mlir::Region &region,
45 mlir::function_ref<mlir::WalkResult(mlir::Operation *)> callback) {
46 region.walk<mlir::WalkOrder::PreOrder>([&](mlir::Operation *op) {
47 if (isa<Ops...>(op))
48 return mlir::WalkResult::skip();
49 return callback(op);
50 });
51}
52
/// Pass that flattens CIR structured control flow: the regions of structured
/// ops (if, scope, switch, loops, ternary, try) are inlined into the parent
/// function region as explicit CFG blocks. See runOnOperation below.
struct CIRFlattenCFGPass : public CIRFlattenCFGBase<CIRFlattenCFGPass> {

  CIRFlattenCFGPass() = default;
  void runOnOperation() override;
};
58
/// Flattens cir.if into explicit control flow: the then/else regions are
/// inlined into the parent region and the op is replaced by a cir.brcond
/// that selects between their entry blocks.
struct CIRIfFlattening : public mlir::OpRewritePattern<cir::IfOp> {
  using OpRewritePattern<IfOp>::OpRewritePattern;

  mlir::LogicalResult
  matchAndRewrite(cir::IfOp ifOp,
                  mlir::PatternRewriter &rewriter) const override {
    mlir::OpBuilder::InsertionGuard guard(rewriter);
    mlir::Location loc = ifOp.getLoc();
    bool emptyElse = ifOp.getElseRegion().empty();
    // Split off everything after the if into the block where control flow
    // rejoins.
    mlir::Block *currentBlock = rewriter.getInsertionBlock();
    mlir::Block *remainingOpsBlock =
        rewriter.splitBlock(currentBlock, rewriter.getInsertionPoint());
    mlir::Block *continueBlock;
    if (ifOp->getResults().empty())
      continueBlock = remainingOpsBlock;
    else
      llvm_unreachable("NYI"); // Value-producing cir.if not yet supported.

    // Inline the region
    mlir::Block *thenBeforeBody = &ifOp.getThenRegion().front();
    mlir::Block *thenAfterBody = &ifOp.getThenRegion().back();
    rewriter.inlineRegionBefore(ifOp.getThenRegion(), continueBlock);

    // If the then region ends in a yield, turn it into a branch to the
    // continuation block (other terminators, e.g. return, are kept).
    rewriter.setInsertionPointToEnd(thenAfterBody);
    if (auto thenYieldOp =
            dyn_cast<cir::YieldOp>(thenAfterBody->getTerminator())) {
      rewriter.replaceOpWithNewOp<cir::BrOp>(thenYieldOp, thenYieldOp.getArgs(),
                                             continueBlock);
    }

    rewriter.setInsertionPointToEnd(continueBlock);

    // Has else region: inline it.
    mlir::Block *elseBeforeBody = nullptr;
    mlir::Block *elseAfterBody = nullptr;
    if (!emptyElse) {
      elseBeforeBody = &ifOp.getElseRegion().front();
      elseAfterBody = &ifOp.getElseRegion().back();
      rewriter.inlineRegionBefore(ifOp.getElseRegion(), continueBlock);
    } else {
      // No else: the false edge of the conditional branch goes straight to
      // the continuation block.
      elseBeforeBody = elseAfterBody = continueBlock;
    }

    // Terminate the original block with the conditional branch.
    rewriter.setInsertionPointToEnd(currentBlock);
    cir::BrCondOp::create(rewriter, loc, ifOp.getCondition(), thenBeforeBody,
                          elseBeforeBody);

    if (!emptyElse) {
      // Same yield-to-branch rewrite for the else region.
      rewriter.setInsertionPointToEnd(elseAfterBody);
      if (auto elseYieldOP =
              dyn_cast<cir::YieldOp>(elseAfterBody->getTerminator())) {
        rewriter.replaceOpWithNewOp<cir::BrOp>(
            elseYieldOP, elseYieldOP.getArgs(), continueBlock);
      }
    }

    rewriter.replaceOp(ifOp, continueBlock->getArguments());
    return mlir::success();
  }
};
119
/// Flattens cir.scope by inlining its single region into the parent region
/// and replacing the terminating yield with a branch to the continuation
/// block. Scope results become block arguments of the continuation block.
class CIRScopeOpFlattening : public mlir::OpRewritePattern<cir::ScopeOp> {
public:
  using OpRewritePattern<cir::ScopeOp>::OpRewritePattern;

  mlir::LogicalResult
  matchAndRewrite(cir::ScopeOp scopeOp,
                  mlir::PatternRewriter &rewriter) const override {
    mlir::OpBuilder::InsertionGuard guard(rewriter);
    mlir::Location loc = scopeOp.getLoc();

    // Empty scope: just remove it.
    // TODO: Remove this logic once CIR uses MLIR infrastructure to remove
    // trivially dead operations. MLIR canonicalizer is too aggressive and we
    // need to either (a) make sure all our ops model all side-effects and/or
    // (b) have more options in the canonicalizer in MLIR to temper
    // aggressiveness level.
    if (scopeOp.isEmpty()) {
      rewriter.eraseOp(scopeOp);
      return mlir::success();
    }

    // Split the current block before the ScopeOp to create the inlining
    // point.
    mlir::Block *currentBlock = rewriter.getInsertionBlock();
    mlir::Block *continueBlock =
        rewriter.splitBlock(currentBlock, rewriter.getInsertionPoint());
    // Scope results are carried into the continuation block as arguments.
    if (scopeOp.getNumResults() > 0)
      continueBlock->addArguments(scopeOp.getResultTypes(), loc);

    // Inline body region.
    mlir::Block *beforeBody = &scopeOp.getScopeRegion().front();
    mlir::Block *afterBody = &scopeOp.getScopeRegion().back();
    rewriter.inlineRegionBefore(scopeOp.getScopeRegion(), continueBlock);

    // Branch into the body of the region.
    // NOTE(review): an earlier comment mentioned stack save/restore here,
    // but this pattern emits no stack operations — confirm whether they are
    // meant to be added later.
    rewriter.setInsertionPointToEnd(currentBlock);
    cir::BrOp::create(rewriter, loc, mlir::ValueRange(), beforeBody);

    // Replace the scopeop return with a branch that jumps out of the body.
    rewriter.setInsertionPointToEnd(afterBody);
    if (auto yieldOp = dyn_cast<cir::YieldOp>(afterBody->getTerminator())) {
      rewriter.replaceOpWithNewOp<cir::BrOp>(yieldOp, yieldOp.getArgs(),
                                             continueBlock);
    }

    // Replace the op with values return from the body region.
    rewriter.replaceOp(scopeOp, continueBlock->getArguments());

    return mlir::success();
  }
};
173
174class CIRSwitchOpFlattening : public mlir::OpRewritePattern<cir::SwitchOp> {
175public:
176 using OpRewritePattern<cir::SwitchOp>::OpRewritePattern;
177
178 inline void rewriteYieldOp(mlir::PatternRewriter &rewriter,
179 cir::YieldOp yieldOp,
180 mlir::Block *destination) const {
181 rewriter.setInsertionPoint(yieldOp);
182 rewriter.replaceOpWithNewOp<cir::BrOp>(yieldOp, yieldOp.getOperands(),
183 destination);
184 }
185
186 // Return the new defaultDestination block.
187 Block *condBrToRangeDestination(cir::SwitchOp op,
188 mlir::PatternRewriter &rewriter,
189 mlir::Block *rangeDestination,
190 mlir::Block *defaultDestination,
191 const APInt &lowerBound,
192 const APInt &upperBound) const {
193 assert(lowerBound.sle(upperBound) && "Invalid range");
194 mlir::Block *resBlock = rewriter.createBlock(defaultDestination);
195 cir::IntType sIntType = cir::IntType::get(op.getContext(), 32, true);
196 cir::IntType uIntType = cir::IntType::get(op.getContext(), 32, false);
197
198 cir::ConstantOp rangeLength = cir::ConstantOp::create(
199 rewriter, op.getLoc(),
200 cir::IntAttr::get(sIntType, upperBound - lowerBound));
201
202 cir::ConstantOp lowerBoundValue = cir::ConstantOp::create(
203 rewriter, op.getLoc(), cir::IntAttr::get(sIntType, lowerBound));
204 cir::BinOp diffValue =
205 cir::BinOp::create(rewriter, op.getLoc(), sIntType, cir::BinOpKind::Sub,
206 op.getCondition(), lowerBoundValue);
207
208 // Use unsigned comparison to check if the condition is in the range.
209 cir::CastOp uDiffValue = cir::CastOp::create(
210 rewriter, op.getLoc(), uIntType, CastKind::integral, diffValue);
211 cir::CastOp uRangeLength = cir::CastOp::create(
212 rewriter, op.getLoc(), uIntType, CastKind::integral, rangeLength);
213
214 cir::CmpOp cmpResult = cir::CmpOp::create(
215 rewriter, op.getLoc(), cir::BoolType::get(op.getContext()),
216 cir::CmpOpKind::le, uDiffValue, uRangeLength);
217 cir::BrCondOp::create(rewriter, op.getLoc(), cmpResult, rangeDestination,
218 defaultDestination);
219 return resBlock;
220 }
221
222 mlir::LogicalResult
223 matchAndRewrite(cir::SwitchOp op,
224 mlir::PatternRewriter &rewriter) const override {
225 llvm::SmallVector<CaseOp> cases;
226 op.collectCases(cases);
227
228 // Empty switch statement: just erase it.
229 if (cases.empty()) {
230 rewriter.eraseOp(op);
231 return mlir::success();
232 }
233
234 // Create exit block from the next node of cir.switch op.
235 mlir::Block *exitBlock = rewriter.splitBlock(
236 rewriter.getBlock(), op->getNextNode()->getIterator());
237
238 // We lower cir.switch op in the following process:
239 // 1. Inline the region from the switch op after switch op.
240 // 2. Traverse each cir.case op:
241 // a. Record the entry block, block arguments and condition for every
242 // case. b. Inline the case region after the case op.
243 // 3. Replace the empty cir.switch.op with the new cir.switchflat op by the
244 // recorded block and conditions.
245
246 // inline everything from switch body between the switch op and the exit
247 // block.
248 {
249 cir::YieldOp switchYield = nullptr;
250 // Clear switch operation.
251 for (mlir::Block &block :
252 llvm::make_early_inc_range(op.getBody().getBlocks()))
253 if (auto yieldOp = dyn_cast<cir::YieldOp>(block.getTerminator()))
254 switchYield = yieldOp;
255
256 assert(!op.getBody().empty());
257 mlir::Block *originalBlock = op->getBlock();
258 mlir::Block *swopBlock =
259 rewriter.splitBlock(originalBlock, op->getIterator());
260 rewriter.inlineRegionBefore(op.getBody(), exitBlock);
261
262 if (switchYield)
263 rewriteYieldOp(rewriter, switchYield, exitBlock);
264
265 rewriter.setInsertionPointToEnd(originalBlock);
266 cir::BrOp::create(rewriter, op.getLoc(), swopBlock);
267 }
268
269 // Allocate required data structures (disconsider default case in
270 // vectors).
271 llvm::SmallVector<mlir::APInt, 8> caseValues;
272 llvm::SmallVector<mlir::Block *, 8> caseDestinations;
273 llvm::SmallVector<mlir::ValueRange, 8> caseOperands;
274
275 llvm::SmallVector<std::pair<APInt, APInt>> rangeValues;
276 llvm::SmallVector<mlir::Block *> rangeDestinations;
277 llvm::SmallVector<mlir::ValueRange> rangeOperands;
278
279 // Initialize default case as optional.
280 mlir::Block *defaultDestination = exitBlock;
281 mlir::ValueRange defaultOperands = exitBlock->getArguments();
282
283 // Digest the case statements values and bodies.
284 for (cir::CaseOp caseOp : cases) {
285 mlir::Region &region = caseOp.getCaseRegion();
286
287 // Found default case: save destination and operands.
288 switch (caseOp.getKind()) {
289 case cir::CaseOpKind::Default:
290 defaultDestination = &region.front();
291 defaultOperands = defaultDestination->getArguments();
292 break;
293 case cir::CaseOpKind::Range:
294 assert(caseOp.getValue().size() == 2 &&
295 "Case range should have 2 case value");
296 rangeValues.push_back(
297 {cast<cir::IntAttr>(caseOp.getValue()[0]).getValue(),
298 cast<cir::IntAttr>(caseOp.getValue()[1]).getValue()});
299 rangeDestinations.push_back(&region.front());
300 rangeOperands.push_back(rangeDestinations.back()->getArguments());
301 break;
302 case cir::CaseOpKind::Anyof:
303 case cir::CaseOpKind::Equal:
304 // AnyOf cases kind can have multiple values, hence the loop below.
305 for (const mlir::Attribute &value : caseOp.getValue()) {
306 caseValues.push_back(cast<cir::IntAttr>(value).getValue());
307 caseDestinations.push_back(&region.front());
308 caseOperands.push_back(caseDestinations.back()->getArguments());
309 }
310 break;
311 }
312
313 // Handle break statements.
314 walkRegionSkipping<cir::LoopOpInterface, cir::SwitchOp>(
315 region, [&](mlir::Operation *op) {
316 if (!isa<cir::BreakOp>(op))
317 return mlir::WalkResult::advance();
318
319 lowerTerminator(op, exitBlock, rewriter);
320 return mlir::WalkResult::skip();
321 });
322
323 // Track fallthrough in cases.
324 for (mlir::Block &blk : region.getBlocks()) {
325 if (blk.getNumSuccessors())
326 continue;
327
328 if (auto yieldOp = dyn_cast<cir::YieldOp>(blk.getTerminator())) {
329 mlir::Operation *nextOp = caseOp->getNextNode();
330 assert(nextOp && "caseOp is not expected to be the last op");
331 mlir::Block *oldBlock = nextOp->getBlock();
332 mlir::Block *newBlock =
333 rewriter.splitBlock(oldBlock, nextOp->getIterator());
334 rewriter.setInsertionPointToEnd(oldBlock);
335 cir::BrOp::create(rewriter, nextOp->getLoc(), mlir::ValueRange(),
336 newBlock);
337 rewriteYieldOp(rewriter, yieldOp, newBlock);
338 }
339 }
340
341 mlir::Block *oldBlock = caseOp->getBlock();
342 mlir::Block *newBlock =
343 rewriter.splitBlock(oldBlock, caseOp->getIterator());
344
345 mlir::Block &entryBlock = caseOp.getCaseRegion().front();
346 rewriter.inlineRegionBefore(caseOp.getCaseRegion(), newBlock);
347
348 // Create a branch to the entry of the inlined region.
349 rewriter.setInsertionPointToEnd(oldBlock);
350 cir::BrOp::create(rewriter, caseOp.getLoc(), &entryBlock);
351 }
352
353 // Remove all cases since we've inlined the regions.
354 for (cir::CaseOp caseOp : cases) {
355 mlir::Block *caseBlock = caseOp->getBlock();
356 // Erase the block with no predecessors here to make the generated code
357 // simpler a little bit.
358 if (caseBlock->hasNoPredecessors())
359 rewriter.eraseBlock(caseBlock);
360 else
361 rewriter.eraseOp(caseOp);
362 }
363
364 for (auto [rangeVal, operand, destination] :
365 llvm::zip(rangeValues, rangeOperands, rangeDestinations)) {
366 APInt lowerBound = rangeVal.first;
367 APInt upperBound = rangeVal.second;
368
369 // The case range is unreachable, skip it.
370 if (lowerBound.sgt(upperBound))
371 continue;
372
373 // If range is small, add multiple switch instruction cases.
374 // This magical number is from the original CGStmt code.
375 constexpr int kSmallRangeThreshold = 64;
376 if ((upperBound - lowerBound)
377 .ult(llvm::APInt(32, kSmallRangeThreshold))) {
378 for (APInt iValue = lowerBound; iValue.sle(upperBound); ++iValue) {
379 caseValues.push_back(iValue);
380 caseOperands.push_back(operand);
381 caseDestinations.push_back(destination);
382 }
383 continue;
384 }
385
386 defaultDestination =
387 condBrToRangeDestination(op, rewriter, destination,
388 defaultDestination, lowerBound, upperBound);
389 defaultOperands = operand;
390 }
391
392 // Set switch op to branch to the newly created blocks.
393 rewriter.setInsertionPoint(op);
394 rewriter.replaceOpWithNewOp<cir::SwitchFlatOp>(
395 op, op.getCondition(), defaultDestination, defaultOperands, caseValues,
396 caseDestinations, caseOperands);
397
398 return mlir::success();
399 }
400};
401
/// Flattens any op implementing cir::LoopOpInterface into explicit CFG:
/// the cond/body/(optional) step regions are inlined into the parent region,
/// and continue/break/yield/condition terminators are rewritten into
/// branches between those blocks.
class CIRLoopOpInterfaceFlattening
    : public mlir::OpInterfaceRewritePattern<cir::LoopOpInterface> {
public:
  using mlir::OpInterfaceRewritePattern<
      cir::LoopOpInterface>::OpInterfaceRewritePattern;

  /// Replaces a cir.condition terminator with a conditional branch to the
  /// loop body (true edge) or the exit block (false edge).
  inline void lowerConditionOp(cir::ConditionOp op, mlir::Block *body,
                               mlir::Block *exit,
                               mlir::PatternRewriter &rewriter) const {
    mlir::OpBuilder::InsertionGuard guard(rewriter);
    rewriter.setInsertionPoint(op);
    rewriter.replaceOpWithNewOp<cir::BrCondOp>(op, op.getCondition(), body,
                                               exit);
  }

  mlir::LogicalResult
  matchAndRewrite(cir::LoopOpInterface op,
                  mlir::PatternRewriter &rewriter) const final {
    // Setup CFG blocks.
    mlir::Block *entry = rewriter.getInsertionBlock();
    mlir::Block *exit =
        rewriter.splitBlock(entry, rewriter.getInsertionPoint());
    mlir::Block *cond = &op.getCond().front();
    mlir::Block *body = &op.getBody().front();
    // The step region is optional (only present for `for`-style loops).
    mlir::Block *step =
        (op.maybeGetStep() ? &op.maybeGetStep()->front() : nullptr);

    // Setup loop entry branch.
    rewriter.setInsertionPointToEnd(entry);
    cir::BrOp::create(rewriter, op.getLoc(), &op.getEntry().front());

    // Branch from condition region to body or exit.
    auto conditionOp = cast<cir::ConditionOp>(cond->getTerminator());
    lowerConditionOp(conditionOp, body, exit, rewriter);

    // TODO(cir): Remove the walks below. It visits operations unnecessarily.
    // However, to solve this we would likely need a custom DialectConversion
    // driver to customize the order that operations are visited.

    // Lower continue statements. A continue jumps to the step block when one
    // exists, otherwise straight back to the condition block.
    mlir::Block *dest = (step ? step : cond);
    op.walkBodySkippingNestedLoops([&](mlir::Operation *op) {
      if (!isa<cir::ContinueOp>(op))
        return mlir::WalkResult::advance();

      lowerTerminator(op, dest, rewriter);
      return mlir::WalkResult::skip();
    });

    // Lower break statements. Breaks belonging to nested loops are skipped.
    walkRegionSkipping<cir::LoopOpInterface>(
        op.getBody(), [&](mlir::Operation *op) {
          if (!isa<cir::BreakOp>(op))
            return mlir::WalkResult::advance();

          lowerTerminator(op, exit, rewriter);
          return mlir::WalkResult::skip();
        });

    // Lower optional body region yield.
    for (mlir::Block &blk : op.getBody().getBlocks()) {
      auto bodyYield = dyn_cast<cir::YieldOp>(blk.getTerminator());
      if (bodyYield)
        lowerTerminator(bodyYield, (step ? step : cond), rewriter);
    }

    // Lower mandatory step region yield.
    if (step)
      lowerTerminator(cast<cir::YieldOp>(step->getTerminator()), cond,
                      rewriter);

    // Move region contents out of the loop op.
    rewriter.inlineRegionBefore(op.getCond(), exit);
    rewriter.inlineRegionBefore(op.getBody(), exit);
    if (step)
      rewriter.inlineRegionBefore(*op.maybeGetStep(), exit);

    rewriter.eraseOp(op);
    return mlir::success();
  }
};
484
/// Flattens cir.ternary: both value regions are inlined into the parent and
/// a conditional branch selects between them; the result (if any) is carried
/// into the continuation block as a block argument.
class CIRTernaryOpFlattening : public mlir::OpRewritePattern<cir::TernaryOp> {
public:
  using OpRewritePattern<cir::TernaryOp>::OpRewritePattern;

  mlir::LogicalResult
  matchAndRewrite(cir::TernaryOp op,
                  mlir::PatternRewriter &rewriter) const override {
    Location loc = op->getLoc();
    Block *condBlock = rewriter.getInsertionBlock();
    Block::iterator opPosition = rewriter.getInsertionPoint();
    Block *remainingOpsBlock = rewriter.splitBlock(condBlock, opPosition);
    llvm::SmallVector<mlir::Location, 2> locs;
    // Ternary result is optional, make sure to populate the location only
    // when relevant.
    if (op->getResultTypes().size())
      locs.push_back(loc);
    // The continuation block receives the selected value (if any) as a
    // block argument and falls through into the remaining ops.
    Block *continueBlock =
        rewriter.createBlock(remainingOpsBlock, op->getResultTypes(), locs);
    cir::BrOp::create(rewriter, loc, remainingOpsBlock);

    Region &trueRegion = op.getTrueRegion();
    Block *trueBlock = &trueRegion.front();
    mlir::Operation *trueTerminator = trueRegion.back().getTerminator();
    rewriter.setInsertionPointToEnd(&trueRegion.back());

    // Handle both yield and unreachable terminators (throw expressions)
    if (auto trueYieldOp = dyn_cast<cir::YieldOp>(trueTerminator)) {
      rewriter.replaceOpWithNewOp<cir::BrOp>(trueYieldOp, trueYieldOp.getArgs(),
                                             continueBlock);
    } else if (isa<cir::UnreachableOp>(trueTerminator)) {
      // Terminator is unreachable (e.g., from throw), just keep it
    } else {
      trueTerminator->emitError("unexpected terminator in ternary true region, "
                                "expected yield or unreachable, got: ")
          << trueTerminator->getName();
      return mlir::failure();
    }
    rewriter.inlineRegionBefore(trueRegion, continueBlock);

    Block *falseBlock = continueBlock;
    Region &falseRegion = op.getFalseRegion();

    falseBlock = &falseRegion.front();
    mlir::Operation *falseTerminator = falseRegion.back().getTerminator();
    rewriter.setInsertionPointToEnd(&falseRegion.back());

    // Handle both yield and unreachable terminators (throw expressions)
    if (auto falseYieldOp = dyn_cast<cir::YieldOp>(falseTerminator)) {
      rewriter.replaceOpWithNewOp<cir::BrOp>(
          falseYieldOp, falseYieldOp.getArgs(), continueBlock);
    } else if (isa<cir::UnreachableOp>(falseTerminator)) {
      // Terminator is unreachable (e.g., from throw), just keep it
    } else {
      falseTerminator->emitError("unexpected terminator in ternary false "
                                 "region, expected yield or unreachable, got: ")
          << falseTerminator->getName();
      return mlir::failure();
    }
    rewriter.inlineRegionBefore(falseRegion, continueBlock);

    // Terminate the original block with the conditional branch into the
    // inlined regions and forward the continuation arguments as results.
    rewriter.setInsertionPointToEnd(condBlock);
    cir::BrCondOp::create(rewriter, loc, op.getCond(), trueBlock, falseBlock);

    rewriter.replaceOp(op, continueBlock->getArguments());

    // Ok, we're done!
    return mlir::success();
  }
};
554
/// Flattens cir.try: the try region is inlined into the parent region and
/// its trailing yield becomes a branch to the block after the try. Handler
/// lowering and rewriting `cir.call exception` into `cir.try_call` are
/// still NYI (guarded by llvm_unreachable below).
class CIRTryOpFlattening : public mlir::OpRewritePattern<cir::TryOp> {
public:
  using OpRewritePattern<cir::TryOp>::OpRewritePattern;

  /// Inlines the try region right before the block that follows the try op
  /// and branches into it. Returns the block holding the ops after the try.
  mlir::Block *buildTryBody(cir::TryOp tryOp,
                            mlir::PatternRewriter &rewriter) const {
    // Split the current block before the TryOp to create the inlining
    // point.
    mlir::Block *beforeTryScopeBlock = rewriter.getInsertionBlock();
    mlir::Block *afterTry =
        rewriter.splitBlock(beforeTryScopeBlock, rewriter.getInsertionPoint());

    // Inline body region.
    mlir::Block *beforeBody = &tryOp.getTryRegion().front();
    rewriter.inlineRegionBefore(tryOp.getTryRegion(), afterTry);

    // Branch into the body of the region.
    rewriter.setInsertionPointToEnd(beforeTryScopeBlock);
    cir::BrOp::create(rewriter, tryOp.getLoc(), mlir::ValueRange(), beforeBody);
    return afterTry;
  }

  /// Rewrites the try-body terminator into a branch past the try; when catch
  /// handlers are present, landing-pad construction is NYI.
  void buildHandlers(cir::TryOp tryOp, mlir::PatternRewriter &rewriter,
                     mlir::Block *afterBody, mlir::Block *afterTry,
                     SmallVectorImpl<cir::CallOp> &callsToRewrite,
                     SmallVectorImpl<mlir::Block *> &landingPads) const {
    // Replace the tryOp return with a branch that jumps out of the body.
    rewriter.setInsertionPointToEnd(afterBody);

    mlir::Block *beforeCatch = rewriter.getInsertionBlock();
    rewriter.setInsertionPointToEnd(beforeCatch);

    // Check if the terminator is a YieldOp because there could be another
    // terminator, e.g. unreachable
    if (auto tryBodyYield = dyn_cast<cir::YieldOp>(afterBody->getTerminator()))
      rewriter.replaceOpWithNewOp<cir::BrOp>(tryBodyYield, afterTry);

    mlir::ArrayAttr handlers = tryOp.getHandlerTypesAttr();
    if (!handlers || handlers.empty())
      return;

    llvm_unreachable("TryOpFlattening buildHandlers with CallsOp is NYI");
  }

  mlir::LogicalResult
  matchAndRewrite(cir::TryOp tryOp,
                  mlir::PatternRewriter &rewriter) const override {
    mlir::OpBuilder::InsertionGuard guard(rewriter);
    mlir::Block *afterBody = &tryOp.getTryRegion().back();

    // Grab the collection of `cir.call exception`s to rewrite to
    // `cir.try_call`.
    llvm::SmallVector<cir::CallOp, 4> callsToRewrite;
    tryOp.getTryRegion().walk([&](CallOp op) {
      if (op.getNothrow())
        return;

      // Only grab calls within immediate closest TryOp scope.
      if (op->getParentOfType<cir::TryOp>() != tryOp)
        return;
      callsToRewrite.push_back(op);
    });

    if (!callsToRewrite.empty())
      llvm_unreachable(
          "TryOpFlattening with try block that contains CallOps is NYI");

    // Build try body.
    mlir::Block *afterTry = buildTryBody(tryOp, rewriter);

    // Build handlers.
    llvm::SmallVector<mlir::Block *, 4> landingPads;
    buildHandlers(tryOp, rewriter, afterBody, afterTry, callsToRewrite,
                  landingPads);

    rewriter.eraseOp(tryOp);

    // Both are empty today (calls are rejected above, handlers are NYI).
    assert((landingPads.size() == callsToRewrite.size()) &&
           "expected matching number of entries");

    // Quick block cleanup: no indirection to the post try block.
    auto brOp = dyn_cast<cir::BrOp>(afterTry->getTerminator());
    if (brOp && brOp.getDest()->hasNoPredecessors()) {
      mlir::Block *srcBlock = brOp.getDest();
      rewriter.eraseOp(brOp);
      rewriter.mergeBlocks(srcBlock, afterTry);
    }

    return mlir::success();
  }
};
646
647void populateFlattenCFGPatterns(RewritePatternSet &patterns) {
648 patterns
649 .add<CIRIfFlattening, CIRLoopOpInterfaceFlattening, CIRScopeOpFlattening,
650 CIRSwitchOpFlattening, CIRTernaryOpFlattening, CIRTryOpFlattening>(
651 patterns.getContext());
652}
653
654void CIRFlattenCFGPass::runOnOperation() {
655 RewritePatternSet patterns(&getContext());
656 populateFlattenCFGPatterns(patterns);
657
658 // Collect operations to apply patterns.
659 llvm::SmallVector<Operation *, 16> ops;
660 getOperation()->walk<mlir::WalkOrder::PostOrder>([&](Operation *op) {
664 if (isa<IfOp, ScopeOp, SwitchOp, LoopOpInterface, TernaryOp, TryOp>(op))
665 ops.push_back(op);
666 });
667
668 // Apply patterns.
669 if (applyOpPatternsGreedily(ops, std::move(patterns)).failed())
670 signalPassFailure();
671}
672
673} // namespace
674
675namespace mlir {
676
677std::unique_ptr<Pass> createCIRFlattenCFGPass() {
678 return std::make_unique<CIRFlattenCFGPass>();
679}
680
681} // namespace mlir
llvm::APInt APInt
Definition FixedPoint.h:19
std::unique_ptr< Pass > createCIRFlattenCFGPass()
float __ovld __cnfn step(float, float)
Returns 0.0 if x < edge, otherwise it returns 1.0.
static bool stackSaveOp()