clang 22.0.0git
FlattenCFG.cpp
Go to the documentation of this file.
1//===----------------------------------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements pass that inlines CIR operations regions into the parent
10// function region.
11//
12//===----------------------------------------------------------------------===//
13
14#include "PassDetail.h"
15#include "mlir/Dialect/Func/IR/FuncOps.h"
16#include "mlir/IR/Block.h"
17#include "mlir/IR/Builders.h"
18#include "mlir/IR/PatternMatch.h"
19#include "mlir/Support/LogicalResult.h"
20#include "mlir/Transforms/DialectConversion.h"
21#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
25
26using namespace mlir;
27using namespace cir;
28
29namespace mlir {
30#define GEN_PASS_DEF_CIRFLATTENCFG
31#include "clang/CIR/Dialect/Passes.h.inc"
32} // namespace mlir
33
34namespace {
35
36/// Lowers operations with the terminator trait that have a single successor.
37void lowerTerminator(mlir::Operation *op, mlir::Block *dest,
38 mlir::PatternRewriter &rewriter) {
39 assert(op->hasTrait<mlir::OpTrait::IsTerminator>() && "not a terminator");
40 mlir::OpBuilder::InsertionGuard guard(rewriter);
41 rewriter.setInsertionPoint(op);
42 rewriter.replaceOpWithNewOp<cir::BrOp>(op, dest);
43}
44
45/// Walks a region while skipping operations of type `Ops`. This ensures the
46/// callback is not applied to said operations and its children.
47template <typename... Ops>
48void walkRegionSkipping(
49 mlir::Region &region,
50 mlir::function_ref<mlir::WalkResult(mlir::Operation *)> callback) {
51 region.walk<mlir::WalkOrder::PreOrder>([&](mlir::Operation *op) {
52 if (isa<Ops...>(op))
53 return mlir::WalkResult::skip();
54 return callback(op);
55 });
56}
57
/// Pass that flattens CIR structured control-flow operations (if, scope,
/// switch, loops, ternary, try) into an explicit CFG of blocks and branches.
struct CIRFlattenCFGPass : public impl::CIRFlattenCFGBase<CIRFlattenCFGPass> {

  CIRFlattenCFGPass() = default;
  // Defined out-of-line below; collects structured control-flow ops and
  // applies the flattening rewrite patterns to them.
  void runOnOperation() override;
};
63
/// Flattens cir.if: inlines the then/else regions into the parent region and
/// connects them with cir.brcond / cir.br operations.
struct CIRIfFlattening : public mlir::OpRewritePattern<cir::IfOp> {
  using OpRewritePattern<IfOp>::OpRewritePattern;

  mlir::LogicalResult
  matchAndRewrite(cir::IfOp ifOp,
                  mlir::PatternRewriter &rewriter) const override {
    mlir::OpBuilder::InsertionGuard guard(rewriter);
    mlir::Location loc = ifOp.getLoc();
    bool emptyElse = ifOp.getElseRegion().empty();
    mlir::Block *currentBlock = rewriter.getInsertionBlock();
    // Split off everything after the cir.if; control rejoins there once the
    // selected branch finishes.
    mlir::Block *remainingOpsBlock =
        rewriter.splitBlock(currentBlock, rewriter.getInsertionPoint());
    mlir::Block *continueBlock;
    if (ifOp->getResults().empty())
      continueBlock = remainingOpsBlock;
    else
      llvm_unreachable("NYI"); // value-producing cir.if not yet supported

    // Inline the region
    mlir::Block *thenBeforeBody = &ifOp.getThenRegion().front();
    mlir::Block *thenAfterBody = &ifOp.getThenRegion().back();
    rewriter.inlineRegionBefore(ifOp.getThenRegion(), continueBlock);

    // If the then region ends in a yield, turn it into a branch to the
    // continue block, forwarding any yielded values.
    rewriter.setInsertionPointToEnd(thenAfterBody);
    if (auto thenYieldOp =
            dyn_cast<cir::YieldOp>(thenAfterBody->getTerminator())) {
      rewriter.replaceOpWithNewOp<cir::BrOp>(thenYieldOp, thenYieldOp.getArgs(),
                                             continueBlock);
    }

    rewriter.setInsertionPointToEnd(continueBlock);

    // Has else region: inline it.
    mlir::Block *elseBeforeBody = nullptr;
    mlir::Block *elseAfterBody = nullptr;
    if (!emptyElse) {
      elseBeforeBody = &ifOp.getElseRegion().front();
      elseAfterBody = &ifOp.getElseRegion().back();
      rewriter.inlineRegionBefore(ifOp.getElseRegion(), continueBlock);
    } else {
      // No else region: the false edge goes straight to the continue block.
      elseBeforeBody = elseAfterBody = continueBlock;
    }

    // Terminate the original block with a conditional branch into the
    // inlined regions.
    rewriter.setInsertionPointToEnd(currentBlock);
    cir::BrCondOp::create(rewriter, loc, ifOp.getCondition(), thenBeforeBody,
                          elseBeforeBody);

    if (!emptyElse) {
      // Same yield-to-branch rewrite for the else region's terminator.
      rewriter.setInsertionPointToEnd(elseAfterBody);
      if (auto elseYieldOP =
              dyn_cast<cir::YieldOp>(elseAfterBody->getTerminator())) {
        rewriter.replaceOpWithNewOp<cir::BrOp>(
            elseYieldOP, elseYieldOP.getArgs(), continueBlock);
      }
    }

    rewriter.replaceOp(ifOp, continueBlock->getArguments());
    return mlir::success();
  }
};
124
/// Flattens cir.scope: inlines the scope body into the parent region with
/// branches in and out; scope results become continue-block arguments.
class CIRScopeOpFlattening : public mlir::OpRewritePattern<cir::ScopeOp> {
public:
  using OpRewritePattern<cir::ScopeOp>::OpRewritePattern;

  mlir::LogicalResult
  matchAndRewrite(cir::ScopeOp scopeOp,
                  mlir::PatternRewriter &rewriter) const override {
    mlir::OpBuilder::InsertionGuard guard(rewriter);
    mlir::Location loc = scopeOp.getLoc();

    // Empty scope: just remove it.
    // TODO: Remove this logic once CIR uses MLIR infrastructure to remove
    // trivially dead operations. MLIR canonicalizer is too aggressive and we
    // need to either (a) make sure all our ops model all side-effects and/or
    // (b) have more options in the canonicalizer in MLIR to temper
    // aggressiveness level.
    if (scopeOp.isEmpty()) {
      rewriter.eraseOp(scopeOp);
      return mlir::success();
    }

    // Split the current block before the ScopeOp to create the inlining
    // point.
    mlir::Block *currentBlock = rewriter.getInsertionBlock();
    mlir::Block *continueBlock =
        rewriter.splitBlock(currentBlock, rewriter.getInsertionPoint());
    // Scope results are carried into the continuation as block arguments.
    if (scopeOp.getNumResults() > 0)
      continueBlock->addArguments(scopeOp.getResultTypes(), loc);

    // Inline body region.
    mlir::Block *beforeBody = &scopeOp.getScopeRegion().front();
    mlir::Block *afterBody = &scopeOp.getScopeRegion().back();
    rewriter.inlineRegionBefore(scopeOp.getScopeRegion(), continueBlock);

    // Save stack and then branch into the body of the region.
    rewriter.setInsertionPointToEnd(currentBlock);
    cir::BrOp::create(rewriter, loc, mlir::ValueRange(), beforeBody);

    // Replace the scopeop return with a branch that jumps out of the body.
    // Stack restore before leaving the body region.
    rewriter.setInsertionPointToEnd(afterBody);
    if (auto yieldOp = dyn_cast<cir::YieldOp>(afterBody->getTerminator())) {
      rewriter.replaceOpWithNewOp<cir::BrOp>(yieldOp, yieldOp.getArgs(),
                                             continueBlock);
    }

    // Replace the op with values return from the body region.
    rewriter.replaceOp(scopeOp, continueBlock->getArguments());

    return mlir::success();
  }
};
178
/// Flattens cir.switch into a cir.switchflat plus an explicit CFG, handling
/// default, equality, any-of, and range case kinds.
class CIRSwitchOpFlattening : public mlir::OpRewritePattern<cir::SwitchOp> {
public:
  using OpRewritePattern<cir::SwitchOp>::OpRewritePattern;

  // Replaces a yield with an unconditional branch to `destination`,
  // forwarding the yield operands as block arguments.
  inline void rewriteYieldOp(mlir::PatternRewriter &rewriter,
                             cir::YieldOp yieldOp,
                             mlir::Block *destination) const {
    rewriter.setInsertionPoint(yieldOp);
    rewriter.replaceOpWithNewOp<cir::BrOp>(yieldOp, yieldOp.getOperands(),
                                           destination);
  }

  // Return the new defaultDestination block.
  // Emits, in a fresh block placed before `defaultDestination`, the check
  // `(cond - lowerBound) <=u (upperBound - lowerBound)` branching to
  // `rangeDestination` when the condition falls inside the range and to the
  // previous `defaultDestination` otherwise.
  Block *condBrToRangeDestination(cir::SwitchOp op,
                                  mlir::PatternRewriter &rewriter,
                                  mlir::Block *rangeDestination,
                                  mlir::Block *defaultDestination,
                                  const APInt &lowerBound,
                                  const APInt &upperBound) const {
    assert(lowerBound.sle(upperBound) && "Invalid range");
    mlir::Block *resBlock = rewriter.createBlock(defaultDestination);
    cir::IntType sIntType = cir::IntType::get(op.getContext(), 32, true);
    cir::IntType uIntType = cir::IntType::get(op.getContext(), 32, false);

    cir::ConstantOp rangeLength = cir::ConstantOp::create(
        rewriter, op.getLoc(),
        cir::IntAttr::get(sIntType, upperBound - lowerBound));

    cir::ConstantOp lowerBoundValue = cir::ConstantOp::create(
        rewriter, op.getLoc(), cir::IntAttr::get(sIntType, lowerBound));
    cir::BinOp diffValue =
        cir::BinOp::create(rewriter, op.getLoc(), sIntType, cir::BinOpKind::Sub,
                           op.getCondition(), lowerBoundValue);

    // Use unsigned comparison to check if the condition is in the range.
    cir::CastOp uDiffValue = cir::CastOp::create(
        rewriter, op.getLoc(), uIntType, CastKind::integral, diffValue);
    cir::CastOp uRangeLength = cir::CastOp::create(
        rewriter, op.getLoc(), uIntType, CastKind::integral, rangeLength);

    cir::CmpOp cmpResult = cir::CmpOp::create(
        rewriter, op.getLoc(), cir::BoolType::get(op.getContext()),
        cir::CmpOpKind::le, uDiffValue, uRangeLength);
    cir::BrCondOp::create(rewriter, op.getLoc(), cmpResult, rangeDestination,
                          defaultDestination);
    return resBlock;
  }

  mlir::LogicalResult
  matchAndRewrite(cir::SwitchOp op,
                  mlir::PatternRewriter &rewriter) const override {
    llvm::SmallVector<CaseOp> cases;
    op.collectCases(cases);

    // Empty switch statement: just erase it.
    if (cases.empty()) {
      rewriter.eraseOp(op);
      return mlir::success();
    }

    // Create exit block from the next node of cir.switch op.
    mlir::Block *exitBlock = rewriter.splitBlock(
        rewriter.getBlock(), op->getNextNode()->getIterator());

    // We lower cir.switch op in the following process:
    // 1. Inline the region from the switch op after switch op.
    // 2. Traverse each cir.case op:
    //    a. Record the entry block, block arguments and condition for every
    //       case.
    //    b. Inline the case region after the case op.
    // 3. Replace the empty cir.switch.op with the new cir.switchflat op by the
    //    recorded block and conditions.

    // inline everything from switch body between the switch op and the exit
    // block.
    {
      cir::YieldOp switchYield = nullptr;
      // Clear switch operation.
      for (mlir::Block &block :
           llvm::make_early_inc_range(op.getBody().getBlocks()))
        if (auto yieldOp = dyn_cast<cir::YieldOp>(block.getTerminator()))
          switchYield = yieldOp;

      assert(!op.getBody().empty());
      mlir::Block *originalBlock = op->getBlock();
      mlir::Block *swopBlock =
          rewriter.splitBlock(originalBlock, op->getIterator());
      rewriter.inlineRegionBefore(op.getBody(), exitBlock);

      if (switchYield)
        rewriteYieldOp(rewriter, switchYield, exitBlock);

      rewriter.setInsertionPointToEnd(originalBlock);
      cir::BrOp::create(rewriter, op.getLoc(), swopBlock);
    }

    // Allocate required data structures (disconsider default case in
    // vectors).
    llvm::SmallVector<mlir::APInt, 8> caseValues;
    llvm::SmallVector<mlir::Block *, 8> caseDestinations;
    llvm::SmallVector<mlir::ValueRange, 8> caseOperands;

    llvm::SmallVector<std::pair<APInt, APInt>> rangeValues;
    llvm::SmallVector<mlir::Block *> rangeDestinations;
    llvm::SmallVector<mlir::ValueRange> rangeOperands;

    // Initialize default case as optional.
    mlir::Block *defaultDestination = exitBlock;
    mlir::ValueRange defaultOperands = exitBlock->getArguments();

    // Digest the case statements values and bodies.
    for (cir::CaseOp caseOp : cases) {
      mlir::Region &region = caseOp.getCaseRegion();

      // Found default case: save destination and operands.
      switch (caseOp.getKind()) {
      case cir::CaseOpKind::Default:
        defaultDestination = &region.front();
        defaultOperands = defaultDestination->getArguments();
        break;
      case cir::CaseOpKind::Range:
        // Range cases carry a {lower, upper} bound pair; they are expanded
        // after this loop (small ranges into individual cases, large ranges
        // into an explicit range-check block).
        assert(caseOp.getValue().size() == 2 &&
               "Case range should have 2 case value");
        rangeValues.push_back(
            {cast<cir::IntAttr>(caseOp.getValue()[0]).getValue(),
             cast<cir::IntAttr>(caseOp.getValue()[1]).getValue()});
        rangeDestinations.push_back(&region.front());
        rangeOperands.push_back(rangeDestinations.back()->getArguments());
        break;
      case cir::CaseOpKind::Anyof:
      case cir::CaseOpKind::Equal:
        // AnyOf cases kind can have multiple values, hence the loop below.
        for (const mlir::Attribute &value : caseOp.getValue()) {
          caseValues.push_back(cast<cir::IntAttr>(value).getValue());
          caseDestinations.push_back(&region.front());
          caseOperands.push_back(caseDestinations.back()->getArguments());
        }
        break;
      }

      // Handle break statements.
      walkRegionSkipping<cir::LoopOpInterface, cir::SwitchOp>(
          region, [&](mlir::Operation *op) {
            if (!isa<cir::BreakOp>(op))
              return mlir::WalkResult::advance();

            lowerTerminator(op, exitBlock, rewriter);
            return mlir::WalkResult::skip();
          });

      // Track fallthrough in cases.
      for (mlir::Block &blk : region.getBlocks()) {
        if (blk.getNumSuccessors())
          continue;

        // A yield in a case region means fallthrough: branch to the code
        // directly after this cir.case (the start of the next case body).
        if (auto yieldOp = dyn_cast<cir::YieldOp>(blk.getTerminator())) {
          mlir::Operation *nextOp = caseOp->getNextNode();
          assert(nextOp && "caseOp is not expected to be the last op");
          mlir::Block *oldBlock = nextOp->getBlock();
          mlir::Block *newBlock =
              rewriter.splitBlock(oldBlock, nextOp->getIterator());
          rewriter.setInsertionPointToEnd(oldBlock);
          cir::BrOp::create(rewriter, nextOp->getLoc(), mlir::ValueRange(),
                            newBlock);
          rewriteYieldOp(rewriter, yieldOp, newBlock);
        }
      }

      mlir::Block *oldBlock = caseOp->getBlock();
      mlir::Block *newBlock =
          rewriter.splitBlock(oldBlock, caseOp->getIterator());

      mlir::Block &entryBlock = caseOp.getCaseRegion().front();
      rewriter.inlineRegionBefore(caseOp.getCaseRegion(), newBlock);

      // Create a branch to the entry of the inlined region.
      rewriter.setInsertionPointToEnd(oldBlock);
      cir::BrOp::create(rewriter, caseOp.getLoc(), &entryBlock);
    }

    // Remove all cases since we've inlined the regions.
    for (cir::CaseOp caseOp : cases) {
      mlir::Block *caseBlock = caseOp->getBlock();
      // Erase the block with no predecessors here to make the generated code
      // simpler a little bit.
      if (caseBlock->hasNoPredecessors())
        rewriter.eraseBlock(caseBlock);
      else
        rewriter.eraseOp(caseOp);
    }

    // Expand the recorded range cases.
    for (auto [rangeVal, operand, destination] :
         llvm::zip(rangeValues, rangeOperands, rangeDestinations)) {
      APInt lowerBound = rangeVal.first;
      APInt upperBound = rangeVal.second;

      // The case range is unreachable, skip it.
      if (lowerBound.sgt(upperBound))
        continue;

      // If range is small, add multiple switch instruction cases.
      // This magical number is from the original CGStmt code.
      constexpr int kSmallRangeThreshold = 64;
      if ((upperBound - lowerBound)
              .ult(llvm::APInt(32, kSmallRangeThreshold))) {
        for (APInt iValue = lowerBound; iValue.sle(upperBound); ++iValue) {
          caseValues.push_back(iValue);
          caseOperands.push_back(operand);
          caseDestinations.push_back(destination);
        }
        continue;
      }

      // Large range: chain an explicit range-check block in front of the
      // current default destination.
      defaultDestination =
          condBrToRangeDestination(op, rewriter, destination,
                                   defaultDestination, lowerBound, upperBound);
      defaultOperands = operand;
    }

    // Set switch op to branch to the newly created blocks.
    rewriter.setInsertionPoint(op);
    rewriter.replaceOpWithNewOp<cir::SwitchFlatOp>(
        op, op.getCondition(), defaultDestination, defaultOperands, caseValues,
        caseDestinations, caseOperands);

    return mlir::success();
  }
};
406
/// Flattens any op implementing cir::LoopOpInterface into an explicit CFG of
/// entry/cond/body/(optional step)/exit blocks.
class CIRLoopOpInterfaceFlattening
    : public mlir::OpInterfaceRewritePattern<cir::LoopOpInterface> {
public:
  using mlir::OpInterfaceRewritePattern<
      cir::LoopOpInterface>::OpInterfaceRewritePattern;

  // Replaces a cir.condition with a conditional branch to the loop body
  // (true) or the exit block (false).
  inline void lowerConditionOp(cir::ConditionOp op, mlir::Block *body,
                               mlir::Block *exit,
                               mlir::PatternRewriter &rewriter) const {
    mlir::OpBuilder::InsertionGuard guard(rewriter);
    rewriter.setInsertionPoint(op);
    rewriter.replaceOpWithNewOp<cir::BrCondOp>(op, op.getCondition(), body,
                                               exit);
  }

  mlir::LogicalResult
  matchAndRewrite(cir::LoopOpInterface op,
                  mlir::PatternRewriter &rewriter) const final {
    // Setup CFG blocks.
    mlir::Block *entry = rewriter.getInsertionBlock();
    mlir::Block *exit =
        rewriter.splitBlock(entry, rewriter.getInsertionPoint());
    mlir::Block *cond = &op.getCond().front();
    mlir::Block *body = &op.getBody().front();
    // The step region is optional (present on for-loops).
    mlir::Block *step =
        (op.maybeGetStep() ? &op.maybeGetStep()->front() : nullptr);

    // Setup loop entry branch.
    rewriter.setInsertionPointToEnd(entry);
    cir::BrOp::create(rewriter, op.getLoc(), &op.getEntry().front());

    // Branch from condition region to body or exit.
    auto conditionOp = cast<cir::ConditionOp>(cond->getTerminator());
    lowerConditionOp(conditionOp, body, exit, rewriter);

    // TODO(cir): Remove the walks below. It visits operations unnecessarily.
    // However, to solve this we would likely need a custom DialectConversion
    // driver to customize the order that operations are visited.

    // Lower continue statements.
    // Continue jumps to the step block when present, else to the condition.
    mlir::Block *dest = (step ? step : cond);
    op.walkBodySkippingNestedLoops([&](mlir::Operation *op) {
      if (!isa<cir::ContinueOp>(op))
        return mlir::WalkResult::advance();

      lowerTerminator(op, dest, rewriter);
      return mlir::WalkResult::skip();
    });

    // Lower break statements.
    walkRegionSkipping<cir::LoopOpInterface>(
        op.getBody(), [&](mlir::Operation *op) {
          if (!isa<cir::BreakOp>(op))
            return mlir::WalkResult::advance();

          lowerTerminator(op, exit, rewriter);
          return mlir::WalkResult::skip();
        });

    // Lower optional body region yield.
    for (mlir::Block &blk : op.getBody().getBlocks()) {
      auto bodyYield = dyn_cast<cir::YieldOp>(blk.getTerminator());
      if (bodyYield)
        lowerTerminator(bodyYield, (step ? step : cond), rewriter);
    }

    // Lower mandatory step region yield.
    if (step)
      lowerTerminator(cast<cir::YieldOp>(step->getTerminator()), cond,
                      rewriter);

    // Move region contents out of the loop op.
    rewriter.inlineRegionBefore(op.getCond(), exit);
    rewriter.inlineRegionBefore(op.getBody(), exit);
    if (step)
      rewriter.inlineRegionBefore(*op.maybeGetStep(), exit);

    rewriter.eraseOp(op);
    return mlir::success();
  }
};
489
490class CIRTernaryOpFlattening : public mlir::OpRewritePattern<cir::TernaryOp> {
491public:
492 using OpRewritePattern<cir::TernaryOp>::OpRewritePattern;
493
494 mlir::LogicalResult
495 matchAndRewrite(cir::TernaryOp op,
496 mlir::PatternRewriter &rewriter) const override {
497 Location loc = op->getLoc();
498 Block *condBlock = rewriter.getInsertionBlock();
499 Block::iterator opPosition = rewriter.getInsertionPoint();
500 Block *remainingOpsBlock = rewriter.splitBlock(condBlock, opPosition);
501 llvm::SmallVector<mlir::Location, 2> locs;
502 // Ternary result is optional, make sure to populate the location only
503 // when relevant.
504 if (op->getResultTypes().size())
505 locs.push_back(loc);
506 Block *continueBlock =
507 rewriter.createBlock(remainingOpsBlock, op->getResultTypes(), locs);
508 cir::BrOp::create(rewriter, loc, remainingOpsBlock);
509
510 Region &trueRegion = op.getTrueRegion();
511 Block *trueBlock = &trueRegion.front();
512 mlir::Operation *trueTerminator = trueRegion.back().getTerminator();
513 rewriter.setInsertionPointToEnd(&trueRegion.back());
514
515 // Handle both yield and unreachable terminators (throw expressions)
516 if (auto trueYieldOp = dyn_cast<cir::YieldOp>(trueTerminator)) {
517 rewriter.replaceOpWithNewOp<cir::BrOp>(trueYieldOp, trueYieldOp.getArgs(),
518 continueBlock);
519 } else if (isa<cir::UnreachableOp>(trueTerminator)) {
520 // Terminator is unreachable (e.g., from throw), just keep it
521 } else {
522 trueTerminator->emitError("unexpected terminator in ternary true region, "
523 "expected yield or unreachable, got: ")
524 << trueTerminator->getName();
525 return mlir::failure();
526 }
527 rewriter.inlineRegionBefore(trueRegion, continueBlock);
528
529 Block *falseBlock = continueBlock;
530 Region &falseRegion = op.getFalseRegion();
531
532 falseBlock = &falseRegion.front();
533 mlir::Operation *falseTerminator = falseRegion.back().getTerminator();
534 rewriter.setInsertionPointToEnd(&falseRegion.back());
535
536 // Handle both yield and unreachable terminators (throw expressions)
537 if (auto falseYieldOp = dyn_cast<cir::YieldOp>(falseTerminator)) {
538 rewriter.replaceOpWithNewOp<cir::BrOp>(
539 falseYieldOp, falseYieldOp.getArgs(), continueBlock);
540 } else if (isa<cir::UnreachableOp>(falseTerminator)) {
541 // Terminator is unreachable (e.g., from throw), just keep it
542 } else {
543 falseTerminator->emitError("unexpected terminator in ternary false "
544 "region, expected yield or unreachable, got: ")
545 << falseTerminator->getName();
546 return mlir::failure();
547 }
548 rewriter.inlineRegionBefore(falseRegion, continueBlock);
549
550 rewriter.setInsertionPointToEnd(condBlock);
551 cir::BrCondOp::create(rewriter, loc, op.getCond(), trueBlock, falseBlock);
552
553 rewriter.replaceOp(op, continueBlock->getArguments());
554
555 // Ok, we're done!
556 return mlir::success();
557 }
558};
559
/// Flattens cir.try: inlines the try body into the parent region. Lowering
/// of catch handlers and rewriting of exception-throwing calls into
/// cir.try_call are still NYI (guarded by llvm_unreachable below).
class CIRTryOpFlattening : public mlir::OpRewritePattern<cir::TryOp> {
public:
  using OpRewritePattern<cir::TryOp>::OpRewritePattern;

  // Inlines the try region before the code that follows the try op and
  // branches into it. Returns the block holding the post-try ops.
  mlir::Block *buildTryBody(cir::TryOp tryOp,
                            mlir::PatternRewriter &rewriter) const {
    // Split the current block before the TryOp to create the inlining
    // point.
    mlir::Block *beforeTryScopeBlock = rewriter.getInsertionBlock();
    mlir::Block *afterTry =
        rewriter.splitBlock(beforeTryScopeBlock, rewriter.getInsertionPoint());

    // Inline body region.
    mlir::Block *beforeBody = &tryOp.getTryRegion().front();
    rewriter.inlineRegionBefore(tryOp.getTryRegion(), afterTry);

    // Branch into the body of the region.
    rewriter.setInsertionPointToEnd(beforeTryScopeBlock);
    cir::BrOp::create(rewriter, tryOp.getLoc(), mlir::ValueRange(), beforeBody);
    return afterTry;
  }

  // Rewrites the try-body yield into a branch past the try. Actual handler
  // (catch clause) lowering is not implemented yet.
  void buildHandlers(cir::TryOp tryOp, mlir::PatternRewriter &rewriter,
                     mlir::Block *afterBody, mlir::Block *afterTry,
                     SmallVectorImpl<cir::CallOp> &callsToRewrite,
                     SmallVectorImpl<mlir::Block *> &landingPads) const {
    // Replace the tryOp return with a branch that jumps out of the body.
    rewriter.setInsertionPointToEnd(afterBody);

    mlir::Block *beforeCatch = rewriter.getInsertionBlock();
    rewriter.setInsertionPointToEnd(beforeCatch);

    // Check if the terminator is a YieldOp because there could be another
    // terminator, e.g. unreachable
    if (auto tryBodyYield = dyn_cast<cir::YieldOp>(afterBody->getTerminator()))
      rewriter.replaceOpWithNewOp<cir::BrOp>(tryBodyYield, afterTry);

    mlir::ArrayAttr handlers = tryOp.getHandlerTypesAttr();
    if (!handlers || handlers.empty())
      return;

    llvm_unreachable("TryOpFlattening buildHandlers with CallsOp is NYI");
  }

  mlir::LogicalResult
  matchAndRewrite(cir::TryOp tryOp,
                  mlir::PatternRewriter &rewriter) const override {
    mlir::OpBuilder::InsertionGuard guard(rewriter);
    mlir::Block *afterBody = &tryOp.getTryRegion().back();

    // Grab the collection of `cir.call exception`s to rewrite to
    // `cir.try_call`.
    llvm::SmallVector<cir::CallOp, 4> callsToRewrite;
    tryOp.getTryRegion().walk([&](CallOp op) {
      if (op.getNothrow())
        return;

      // Only grab calls within immediate closest TryOp scope.
      if (op->getParentOfType<cir::TryOp>() != tryOp)
        return;
      callsToRewrite.push_back(op);
    });

    if (!callsToRewrite.empty())
      llvm_unreachable(
          "TryOpFlattening with try block that contains CallOps is NYI");

    // Build try body.
    mlir::Block *afterTry = buildTryBody(tryOp, rewriter);

    // Build handlers.
    llvm::SmallVector<mlir::Block *, 4> landingPads;
    buildHandlers(tryOp, rewriter, afterBody, afterTry, callsToRewrite,
                  landingPads);

    rewriter.eraseOp(tryOp);

    // Both lists are currently required to be empty (see NYI guards above).
    assert((landingPads.size() == callsToRewrite.size()) &&
           "expected matching number of entries");

    // Quick block cleanup: no indirection to the post try block.
    auto brOp = dyn_cast<cir::BrOp>(afterTry->getTerminator());
    if (brOp && brOp.getDest()->hasNoPredecessors()) {
      mlir::Block *srcBlock = brOp.getDest();
      rewriter.eraseOp(brOp);
      rewriter.mergeBlocks(srcBlock, afterTry);
    }

    return mlir::success();
  }
};
651
652void populateFlattenCFGPatterns(RewritePatternSet &patterns) {
653 patterns
654 .add<CIRIfFlattening, CIRLoopOpInterfaceFlattening, CIRScopeOpFlattening,
655 CIRSwitchOpFlattening, CIRTernaryOpFlattening, CIRTryOpFlattening>(
656 patterns.getContext());
657}
658
659void CIRFlattenCFGPass::runOnOperation() {
660 RewritePatternSet patterns(&getContext());
661 populateFlattenCFGPatterns(patterns);
662
663 // Collect operations to apply patterns.
664 llvm::SmallVector<Operation *, 16> ops;
665 getOperation()->walk<mlir::WalkOrder::PostOrder>([&](Operation *op) {
669 if (isa<IfOp, ScopeOp, SwitchOp, LoopOpInterface, TernaryOp, TryOp>(op))
670 ops.push_back(op);
671 });
672
673 // Apply patterns.
674 if (applyOpPatternsGreedily(ops, std::move(patterns)).failed())
675 signalPassFailure();
676}
677
678} // namespace
679
680namespace mlir {
681
682std::unique_ptr<Pass> createCIRFlattenCFGPass() {
683 return std::make_unique<CIRFlattenCFGPass>();
684}
685
686} // namespace mlir
llvm::APInt APInt
Definition FixedPoint.h:19
std::unique_ptr< Pass > createCIRFlattenCFGPass()
float __ovld __cnfn step(float, float)
Returns 0.0 if x < edge, otherwise it returns 1.0.
static bool stackSaveOp()