clang 23.0.0git
FlattenCFG.cpp
Go to the documentation of this file.
1//===----------------------------------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
// This file implements a pass that inlines the regions of CIR operations into
// the parent function region.
11//
12//===----------------------------------------------------------------------===//
13
14#include "PassDetail.h"
15#include "mlir/Dialect/Func/IR/FuncOps.h"
16#include "mlir/IR/Block.h"
17#include "mlir/IR/Builders.h"
18#include "mlir/IR/PatternMatch.h"
19#include "mlir/Interfaces/SideEffectInterfaces.h"
20#include "mlir/Support/LogicalResult.h"
21#include "mlir/Transforms/DialectConversion.h"
22#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
28#include "llvm/ADT/TypeSwitch.h"
29
30using namespace mlir;
31using namespace cir;
32
33namespace mlir {
34#define GEN_PASS_DEF_CIRFLATTENCFG
35#include "clang/CIR/Dialect/Passes.h.inc"
36} // namespace mlir
37
38namespace {
39
40/// Lowers operations with the terminator trait that have a single successor.
41void lowerTerminator(mlir::Operation *op, mlir::Block *dest,
42 mlir::PatternRewriter &rewriter) {
43 assert(op->hasTrait<mlir::OpTrait::IsTerminator>() && "not a terminator");
44 mlir::OpBuilder::InsertionGuard guard(rewriter);
45 rewriter.setInsertionPoint(op);
46 rewriter.replaceOpWithNewOp<cir::BrOp>(op, dest);
47}
48
49/// Walks a region while skipping operations of type `Ops`. This ensures the
50/// callback is not applied to said operations and its children.
51template <typename... Ops>
52void walkRegionSkipping(
53 mlir::Region &region,
54 mlir::function_ref<mlir::WalkResult(mlir::Operation *)> callback) {
55 region.walk<mlir::WalkOrder::PreOrder>([&](mlir::Operation *op) {
56 if (isa<Ops...>(op))
57 return mlir::WalkResult::skip();
58 return callback(op);
59 });
60}
61
62/// Check whether a region contains any nested op with regions (i.e. structured
63/// CIR ops that must be flattened before their parent). The greedy pattern
64/// rewriter doesn't guarantee inside-out processing order — when a pattern
65/// fires and modifies IR, newly created ops go onto the worklist and can be
66/// visited in any order. So each flattening pattern must explicitly defer
67/// until its nested structured ops are flat.
68///
69/// CaseOps are excluded because they are structural children of SwitchOp and
70/// are handled by the SwitchOp flattening pattern.
71static bool hasNestedOpsToFlatten(mlir::Region &region) {
72 return region
73 .walk([](mlir::Operation *op) {
74 if (op->getNumRegions() > 0 && !isa<cir::CaseOp>(op))
75 return mlir::WalkResult::interrupt();
76 return mlir::WalkResult::advance();
77 })
78 .wasInterrupted();
79}
80
81/// True if `op` is a non-returning terminator — currently `cir.unreachable`
82/// or `cir.trap`. Such terminators don't fall through and don't yield a
83/// value, so when flattening a region they can be left in place rather than
84/// being replaced with a branch to the continuation block. Add new ops here
85/// (e.g. a hypothetical `cir.abort`) so every flattening pattern picks them
86/// up at once.
87static bool isNonReturningTerminator(mlir::Operation *op) {
88 return mlir::isa_and_nonnull<cir::UnreachableOp, cir::TrapOp>(op);
89}
90
91/// Rewrite the terminator of `region`'s exit block so that, after
92/// flattening, control falls through to `continueBlock`. The exit
93/// terminator is expected to be either:
94/// - `cir.yield`: replaced with `cir.br` to `continueBlock` (yielded
95/// args become the destination block's arguments).
96/// - non-returning (`cir.unreachable`, `cir.trap`): left in place — no
97/// branch is needed.
98///
99/// On success returns `success()`. If the terminator is anything else, an
100/// error is emitted and `failure()` is returned. NOTE: callers in this
101/// file have typically already mutated IR (splitBlock / createBlock) by
102/// the time this is invoked, so the MLIR pattern rewriter contract
103/// requires them to still return `success()` from the surrounding
104/// pattern; the `failure()` here just signals "stop trying to wire up
105/// this region".
106static mlir::LogicalResult
107rewriteRegionExitToContinue(mlir::PatternRewriter &rewriter,
108 mlir::Region &region, mlir::Block *continueBlock,
109 llvm::StringRef regionDescription) {
110 mlir::Operation *terminator = region.back().getTerminator();
111 rewriter.setInsertionPointToEnd(&region.back());
112 if (auto yieldOp = mlir::dyn_cast<cir::YieldOp>(terminator)) {
113 rewriter.replaceOpWithNewOp<cir::BrOp>(yieldOp, yieldOp.getArgs(),
114 continueBlock);
115 return mlir::success();
116 }
117 if (isNonReturningTerminator(terminator))
118 return mlir::success();
119 terminator->emitError("unexpected terminator in ")
120 << regionDescription
121 << " region, expected yield, unreachable, or trap, got: "
122 << terminator->getName();
123 return mlir::failure();
124}
125
126struct CIRFlattenCFGPass : public impl::CIRFlattenCFGBase<CIRFlattenCFGPass> {
127
128 CIRFlattenCFGPass() = default;
129 void runOnOperation() override;
130};
131
132struct CIRIfFlattening : public mlir::OpRewritePattern<cir::IfOp> {
133 using OpRewritePattern<IfOp>::OpRewritePattern;
134
135 mlir::LogicalResult
136 matchAndRewrite(cir::IfOp ifOp,
137 mlir::PatternRewriter &rewriter) const override {
138 mlir::OpBuilder::InsertionGuard guard(rewriter);
139 mlir::Location loc = ifOp.getLoc();
140 bool emptyElse = ifOp.getElseRegion().empty();
141 mlir::Block *currentBlock = rewriter.getInsertionBlock();
142 mlir::Block *remainingOpsBlock =
143 rewriter.splitBlock(currentBlock, rewriter.getInsertionPoint());
144 mlir::Block *continueBlock;
145 if (ifOp->getResults().empty())
146 continueBlock = remainingOpsBlock;
147 else
148 llvm_unreachable("NYI");
149
150 // Inline the region
151 mlir::Block *thenBeforeBody = &ifOp.getThenRegion().front();
152 mlir::Block *thenAfterBody = &ifOp.getThenRegion().back();
153 rewriter.inlineRegionBefore(ifOp.getThenRegion(), continueBlock);
154
155 rewriter.setInsertionPointToEnd(thenAfterBody);
156 if (auto thenYieldOp =
157 dyn_cast<cir::YieldOp>(thenAfterBody->getTerminator())) {
158 rewriter.replaceOpWithNewOp<cir::BrOp>(thenYieldOp, thenYieldOp.getArgs(),
159 continueBlock);
160 }
161
162 rewriter.setInsertionPointToEnd(continueBlock);
163
164 // Has else region: inline it.
165 mlir::Block *elseBeforeBody = nullptr;
166 mlir::Block *elseAfterBody = nullptr;
167 if (!emptyElse) {
168 elseBeforeBody = &ifOp.getElseRegion().front();
169 elseAfterBody = &ifOp.getElseRegion().back();
170 rewriter.inlineRegionBefore(ifOp.getElseRegion(), continueBlock);
171 } else {
172 elseBeforeBody = elseAfterBody = continueBlock;
173 }
174
175 rewriter.setInsertionPointToEnd(currentBlock);
176 cir::BrCondOp::create(rewriter, loc, ifOp.getCondition(), thenBeforeBody,
177 elseBeforeBody);
178
179 if (!emptyElse) {
180 rewriter.setInsertionPointToEnd(elseAfterBody);
181 if (auto elseYieldOP =
182 dyn_cast<cir::YieldOp>(elseAfterBody->getTerminator())) {
183 rewriter.replaceOpWithNewOp<cir::BrOp>(
184 elseYieldOP, elseYieldOP.getArgs(), continueBlock);
185 }
186 }
187
188 rewriter.replaceOp(ifOp, continueBlock->getArguments());
189 return mlir::success();
190 }
191};
192
193class CIRScopeOpFlattening : public mlir::OpRewritePattern<cir::ScopeOp> {
194public:
195 using OpRewritePattern<cir::ScopeOp>::OpRewritePattern;
196
197 mlir::LogicalResult
198 matchAndRewrite(cir::ScopeOp scopeOp,
199 mlir::PatternRewriter &rewriter) const override {
200 mlir::OpBuilder::InsertionGuard guard(rewriter);
201 mlir::Location loc = scopeOp.getLoc();
202
203 // Empty scope: just remove it.
204 // TODO: Remove this logic once CIR uses MLIR infrastructure to remove
205 // trivially dead operations. MLIR canonicalizer is too aggressive and we
206 // need to either (a) make sure all our ops model all side-effects and/or
207 // (b) have more options in the canonicalizer in MLIR to temper
208 // aggressiveness level.
209 if (scopeOp.isEmpty()) {
210 rewriter.eraseOp(scopeOp);
211 return mlir::success();
212 }
213
214 // Split the current block before the ScopeOp to create the inlining
215 // point.
216 mlir::Block *currentBlock = rewriter.getInsertionBlock();
217 mlir::Block *continueBlock =
218 rewriter.splitBlock(currentBlock, rewriter.getInsertionPoint());
219 if (scopeOp.getNumResults() > 0)
220 continueBlock->addArguments(scopeOp.getResultTypes(), loc);
221
222 // Inline body region.
223 mlir::Block *beforeBody = &scopeOp.getScopeRegion().front();
224 mlir::Block *afterBody = &scopeOp.getScopeRegion().back();
225 rewriter.inlineRegionBefore(scopeOp.getScopeRegion(), continueBlock);
226
227 // Save stack and then branch into the body of the region.
228 rewriter.setInsertionPointToEnd(currentBlock);
230 cir::BrOp::create(rewriter, loc, mlir::ValueRange(), beforeBody);
231
232 // Replace the scopeop return with a branch that jumps out of the body.
233 // Stack restore before leaving the body region.
234 rewriter.setInsertionPointToEnd(afterBody);
235 if (auto yieldOp = dyn_cast<cir::YieldOp>(afterBody->getTerminator())) {
236 rewriter.replaceOpWithNewOp<cir::BrOp>(yieldOp, yieldOp.getArgs(),
237 continueBlock);
238 }
239
240 // Replace the op with values return from the body region.
241 rewriter.replaceOp(scopeOp, continueBlock->getArguments());
242
243 return mlir::success();
244 }
245};
246
247class CIRSwitchOpFlattening : public mlir::OpRewritePattern<cir::SwitchOp> {
248public:
249 using OpRewritePattern<cir::SwitchOp>::OpRewritePattern;
250
251 inline void rewriteYieldOp(mlir::PatternRewriter &rewriter,
252 cir::YieldOp yieldOp,
253 mlir::Block *destination) const {
254 rewriter.setInsertionPoint(yieldOp);
255 rewriter.replaceOpWithNewOp<cir::BrOp>(yieldOp, yieldOp.getOperands(),
256 destination);
257 }
258
259 // Return the new defaultDestination block.
260 Block *condBrToRangeDestination(cir::SwitchOp op,
261 mlir::PatternRewriter &rewriter,
262 mlir::Block *rangeDestination,
263 mlir::Block *defaultDestination,
264 const APInt &lowerBound,
265 const APInt &upperBound) const {
266 assert(lowerBound.sle(upperBound) && "Invalid range");
267 mlir::Block *resBlock = rewriter.createBlock(defaultDestination);
268 cir::IntType sIntType = cir::IntType::get(op.getContext(), 32, true);
269 cir::IntType uIntType = cir::IntType::get(op.getContext(), 32, false);
270
271 cir::ConstantOp rangeLength = cir::ConstantOp::create(
272 rewriter, op.getLoc(),
273 cir::IntAttr::get(sIntType, upperBound - lowerBound));
274
275 cir::ConstantOp lowerBoundValue = cir::ConstantOp::create(
276 rewriter, op.getLoc(), cir::IntAttr::get(sIntType, lowerBound));
277 mlir::Value diffValue = cir::SubOp::create(
278 rewriter, op.getLoc(), op.getCondition(), lowerBoundValue);
279
280 // Use unsigned comparison to check if the condition is in the range.
281 cir::CastOp uDiffValue = cir::CastOp::create(
282 rewriter, op.getLoc(), uIntType, CastKind::integral, diffValue);
283 cir::CastOp uRangeLength = cir::CastOp::create(
284 rewriter, op.getLoc(), uIntType, CastKind::integral, rangeLength);
285
286 cir::CmpOp cmpResult = cir::CmpOp::create(
287 rewriter, op.getLoc(), cir::CmpOpKind::le, uDiffValue, uRangeLength);
288 cir::BrCondOp::create(rewriter, op.getLoc(), cmpResult, rangeDestination,
289 defaultDestination);
290 return resBlock;
291 }
292
293 mlir::LogicalResult
294 matchAndRewrite(cir::SwitchOp op,
295 mlir::PatternRewriter &rewriter) const override {
296 // All nested structured CIR ops must be flattened before the switch.
297 // Break statements inside nested structured ops would create branches to
298 // blocks outside those ops' regions, which is invalid. Fail the match so
299 // the pattern rewriter will process them first.
300 for (mlir::Region &region : op->getRegions())
301 if (hasNestedOpsToFlatten(region))
302 return mlir::failure();
303
304 // Empty switch statement: just erase it.
305 if (op.getBody().hasOneBlock() &&
306 op.getBody().front().without_terminator().empty()) {
307 rewriter.eraseOp(op);
308 return mlir::success();
309 }
310
311 llvm::SmallVector<CaseOp> cases;
312 op.collectCases(cases);
313
314 // Create exit block from the next node of cir.switch op.
315 mlir::Block *exitBlock = rewriter.splitBlock(
316 rewriter.getBlock(), op->getNextNode()->getIterator());
317
318 // We lower cir.switch op in the following process:
319 // 1. Inline the region from the switch op after switch op.
320 // 2. Traverse each cir.case op:
321 // a. Record the entry block, block arguments and condition for every
322 // case. b. Inline the case region after the case op.
323 // 3. Replace the empty cir.switch.op with the new cir.switchflat op by the
324 // recorded block and conditions.
325
326 // First we have to handle the rewrite of all of the 'break' ops to make
327 // sure they now go to the right place, including the ones in the pre-case
328 // blcoks.
329 walkRegionSkipping<cir::LoopOpInterface, cir::SwitchOp>(
330 op.getBody(), [&](mlir::Operation *op) {
331 if (!isa<cir::BreakOp>(op))
332 return mlir::WalkResult::advance();
333
334 lowerTerminator(op, exitBlock, rewriter);
335 return mlir::WalkResult::skip();
336 });
337
338 // inline everything from switch body between the switch op and the exit
339 // block.
340 {
341 cir::YieldOp switchYield = nullptr;
342 // Clear switch operation.
343 for (mlir::Block &block :
344 llvm::make_early_inc_range(op.getBody().getBlocks()))
345 if (auto yieldOp = dyn_cast<cir::YieldOp>(block.getTerminator()))
346 switchYield = yieldOp;
347
348 assert(!op.getBody().empty());
349 mlir::Block *originalBlock = op->getBlock();
350 mlir::Block *swopBlock =
351 rewriter.splitBlock(originalBlock, op->getIterator());
352 rewriter.inlineRegionBefore(op.getBody(), exitBlock);
353
354 if (switchYield)
355 rewriteYieldOp(rewriter, switchYield, exitBlock);
356
357 rewriter.setInsertionPointToEnd(originalBlock);
358 cir::BrOp::create(rewriter, op.getLoc(), swopBlock);
359 }
360
361 // Allocate required data structures (disconsider default case in
362 // vectors).
363 llvm::SmallVector<mlir::APInt, 8> caseValues;
364 llvm::SmallVector<mlir::Block *, 8> caseDestinations;
365 llvm::SmallVector<mlir::ValueRange, 8> caseOperands;
366
367 llvm::SmallVector<std::pair<APInt, APInt>> rangeValues;
368 llvm::SmallVector<mlir::Block *> rangeDestinations;
369 llvm::SmallVector<mlir::ValueRange> rangeOperands;
370
371 // Initialize default case as optional.
372 mlir::Block *defaultDestination = exitBlock;
373 mlir::ValueRange defaultOperands = exitBlock->getArguments();
374
375 // Digest the case statements values and bodies.
376 for (cir::CaseOp caseOp : cases) {
377 mlir::Region &region = caseOp.getCaseRegion();
378
379 // Found default case: save destination and operands.
380 switch (caseOp.getKind()) {
381 case cir::CaseOpKind::Default:
382 defaultDestination = &region.front();
383 defaultOperands = defaultDestination->getArguments();
384 break;
385 case cir::CaseOpKind::Range:
386 assert(caseOp.getValue().size() == 2 &&
387 "Case range should have 2 case value");
388 rangeValues.push_back(
389 {cast<cir::IntAttr>(caseOp.getValue()[0]).getValue(),
390 cast<cir::IntAttr>(caseOp.getValue()[1]).getValue()});
391 rangeDestinations.push_back(&region.front());
392 rangeOperands.push_back(rangeDestinations.back()->getArguments());
393 break;
394 case cir::CaseOpKind::Anyof:
395 case cir::CaseOpKind::Equal:
396 // AnyOf cases kind can have multiple values, hence the loop below.
397 for (const mlir::Attribute &value : caseOp.getValue()) {
398 caseValues.push_back(cast<cir::IntAttr>(value).getValue());
399 caseDestinations.push_back(&region.front());
400 caseOperands.push_back(caseDestinations.back()->getArguments());
401 }
402 break;
403 }
404
405 // Track fallthrough in cases.
406 for (mlir::Block &blk : region.getBlocks()) {
407 if (blk.getNumSuccessors())
408 continue;
409
410 if (auto yieldOp = dyn_cast<cir::YieldOp>(blk.getTerminator())) {
411 mlir::Operation *nextOp = caseOp->getNextNode();
412 assert(nextOp && "caseOp is not expected to be the last op");
413 mlir::Block *oldBlock = nextOp->getBlock();
414 mlir::Block *newBlock =
415 rewriter.splitBlock(oldBlock, nextOp->getIterator());
416 rewriter.setInsertionPointToEnd(oldBlock);
417 cir::BrOp::create(rewriter, nextOp->getLoc(), mlir::ValueRange(),
418 newBlock);
419 rewriteYieldOp(rewriter, yieldOp, newBlock);
420 }
421 }
422
423 mlir::Block *oldBlock = caseOp->getBlock();
424 mlir::Block *newBlock =
425 rewriter.splitBlock(oldBlock, caseOp->getIterator());
426
427 mlir::Block &entryBlock = caseOp.getCaseRegion().front();
428 rewriter.inlineRegionBefore(caseOp.getCaseRegion(), newBlock);
429
430 // Create a branch to the entry of the inlined region.
431 rewriter.setInsertionPointToEnd(oldBlock);
432 cir::BrOp::create(rewriter, caseOp.getLoc(), &entryBlock);
433 }
434
435 // Remove all cases since we've inlined the regions.
436 for (cir::CaseOp caseOp : cases) {
437 mlir::Block *caseBlock = caseOp->getBlock();
438 // Erase the block with no predecessors here to make the generated code
439 // simpler a little bit.
440 if (caseBlock->hasNoPredecessors())
441 rewriter.eraseBlock(caseBlock);
442 else
443 rewriter.eraseOp(caseOp);
444 }
445
446 for (auto [rangeVal, operand, destination] :
447 llvm::zip(rangeValues, rangeOperands, rangeDestinations)) {
448 APInt lowerBound = rangeVal.first;
449 APInt upperBound = rangeVal.second;
450
451 // The case range is unreachable, skip it.
452 if (lowerBound.sgt(upperBound))
453 continue;
454
455 // If range is small, add multiple switch instruction cases.
456 // This magical number is from the original CGStmt code.
457 constexpr int kSmallRangeThreshold = 64;
458 if ((upperBound - lowerBound)
459 .ult(llvm::APInt(32, kSmallRangeThreshold))) {
460 for (APInt iValue = lowerBound; iValue.sle(upperBound); ++iValue) {
461 caseValues.push_back(iValue);
462 caseOperands.push_back(operand);
463 caseDestinations.push_back(destination);
464 }
465 continue;
466 }
467
468 defaultDestination =
469 condBrToRangeDestination(op, rewriter, destination,
470 defaultDestination, lowerBound, upperBound);
471 defaultOperands = operand;
472 }
473
474 // Set switch op to branch to the newly created blocks.
475 rewriter.setInsertionPoint(op);
476 rewriter.replaceOpWithNewOp<cir::SwitchFlatOp>(
477 op, op.getCondition(), defaultDestination, defaultOperands, caseValues,
478 caseDestinations, caseOperands);
479
480 return mlir::success();
481 }
482};
483
484class CIRLoopOpInterfaceFlattening
485 : public mlir::OpInterfaceRewritePattern<cir::LoopOpInterface> {
486public:
487 using mlir::OpInterfaceRewritePattern<
488 cir::LoopOpInterface>::OpInterfaceRewritePattern;
489
490 inline void lowerConditionOp(cir::ConditionOp op, mlir::Block *body,
491 mlir::Block *exit,
492 mlir::PatternRewriter &rewriter) const {
493 mlir::OpBuilder::InsertionGuard guard(rewriter);
494 rewriter.setInsertionPoint(op);
495 rewriter.replaceOpWithNewOp<cir::BrCondOp>(op, op.getCondition(), body,
496 exit);
497 }
498
499 mlir::LogicalResult
500 matchAndRewrite(cir::LoopOpInterface op,
501 mlir::PatternRewriter &rewriter) const final {
502 // All nested structured CIR ops must be flattened before the loop.
503 // Break/continue statements inside nested structured ops would create
504 // branches to blocks outside those ops' regions, which is invalid. Fail
505 // the match so the pattern rewriter will process them first.
506 for (mlir::Region &region : op->getRegions())
507 if (hasNestedOpsToFlatten(region))
508 return mlir::failure();
509
510 // Setup CFG blocks.
511 mlir::Block *entry = rewriter.getInsertionBlock();
512 mlir::Block *exit =
513 rewriter.splitBlock(entry, rewriter.getInsertionPoint());
514 mlir::Block *cond = &op.getCond().front();
515 mlir::Block *body = &op.getBody().front();
516 mlir::Block *step =
517 (op.maybeGetStep() ? &op.maybeGetStep()->front() : nullptr);
518
519 // Setup loop entry branch.
520 rewriter.setInsertionPointToEnd(entry);
521 cir::BrOp::create(rewriter, op.getLoc(), &op.getEntry().front());
522
523 // Branch from condition region to body or exit. The ConditionOp may not
524 // be in the first block of the condition region if a cleanup scope was
525 // already flattened within it, introducing multiple blocks. The
526 // ConditionOp is always the terminator of the last block.
527 auto conditionOp =
528 cast<cir::ConditionOp>(op.getCond().back().getTerminator());
529 lowerConditionOp(conditionOp, body, exit, rewriter);
530
531 // TODO(cir): Remove the walks below. It visits operations unnecessarily.
532 // However, to solve this we would likely need a custom DialectConversion
533 // driver to customize the order that operations are visited.
534
535 // Lower continue statements.
536 mlir::Block *dest = (step ? step : cond);
537 op.walkBodySkippingNestedLoops([&](mlir::Operation *op) {
538 if (!isa<cir::ContinueOp>(op))
539 return mlir::WalkResult::advance();
540
541 lowerTerminator(op, dest, rewriter);
542 return mlir::WalkResult::skip();
543 });
544
545 // Lower break statements.
546 walkRegionSkipping<cir::LoopOpInterface, cir::SwitchOp>(
547 op.getBody(), [&](mlir::Operation *op) {
548 if (!isa<cir::BreakOp>(op))
549 return mlir::WalkResult::advance();
550
551 lowerTerminator(op, exit, rewriter);
552 return mlir::WalkResult::skip();
553 });
554
555 // Lower optional body region yield.
556 for (mlir::Block &blk : op.getBody().getBlocks()) {
557 auto bodyYield = dyn_cast<cir::YieldOp>(blk.getTerminator());
558 if (bodyYield)
559 lowerTerminator(bodyYield, (step ? step : cond), rewriter);
560 }
561
562 // Lower mandatory step region yield. Like the condition region, the
563 // YieldOp may be in the last block rather than the first if a cleanup
564 // scope was already flattened within the step region.
565 if (step)
566 lowerTerminator(
567 cast<cir::YieldOp>(op.maybeGetStep()->back().getTerminator()), cond,
568 rewriter);
569
570 // Move region contents out of the loop op.
571 rewriter.inlineRegionBefore(op.getCond(), exit);
572 rewriter.inlineRegionBefore(op.getBody(), exit);
573 if (step)
574 rewriter.inlineRegionBefore(*op.maybeGetStep(), exit);
575
576 rewriter.eraseOp(op);
577 return mlir::success();
578 }
579};
580
581class CIRTernaryOpFlattening : public mlir::OpRewritePattern<cir::TernaryOp> {
582public:
583 using OpRewritePattern<cir::TernaryOp>::OpRewritePattern;
584
585 mlir::LogicalResult
586 matchAndRewrite(cir::TernaryOp op,
587 mlir::PatternRewriter &rewriter) const override {
588 Location loc = op->getLoc();
589 Block *condBlock = rewriter.getInsertionBlock();
590 Block::iterator opPosition = rewriter.getInsertionPoint();
591 Block *remainingOpsBlock = rewriter.splitBlock(condBlock, opPosition);
592 llvm::SmallVector<mlir::Location, 2> locs;
593 // Ternary result is optional, make sure to populate the location only
594 // when relevant.
595 if (op->getResultTypes().size())
596 locs.push_back(loc);
597 Block *continueBlock =
598 rewriter.createBlock(remainingOpsBlock, op->getResultTypes(), locs);
599 cir::BrOp::create(rewriter, loc, remainingOpsBlock);
600
601 Region &trueRegion = op.getTrueRegion();
602 Block *trueBlock = &trueRegion.front();
603 // Wire up the true region's exit (cir.yield -> br, cir.unreachable /
604 // cir.trap kept as-is). IR has already been modified by splitBlock /
605 // createBlock above, so per the MLIR pattern rewriter contract we must
606 // still return success() if the terminator turns out to be unexpected.
607 if (failed(rewriteRegionExitToContinue(rewriter, trueRegion, continueBlock,
608 "ternary true")))
609 return mlir::success();
610 rewriter.inlineRegionBefore(trueRegion, continueBlock);
611
612 Block *falseBlock = continueBlock;
613 Region &falseRegion = op.getFalseRegion();
614
615 falseBlock = &falseRegion.front();
616 if (failed(rewriteRegionExitToContinue(rewriter, falseRegion, continueBlock,
617 "ternary false")))
618 return mlir::success();
619 rewriter.inlineRegionBefore(falseRegion, continueBlock);
620
621 rewriter.setInsertionPointToEnd(condBlock);
622 cir::BrCondOp::create(rewriter, loc, op.getCond(), trueBlock, falseBlock);
623
624 rewriter.replaceOp(op, continueBlock->getArguments());
625
626 // Ok, we're done!
627 return mlir::success();
628 }
629};
630
631// Get or create the cleanup destination slot for a function. This slot is
632// shared across all cleanup scopes in the function to track which exit path
633// to take after running cleanup code when there are multiple exits.
634static cir::AllocaOp getOrCreateCleanupDestSlot(cir::FuncOp funcOp,
635 mlir::PatternRewriter &rewriter,
636 mlir::Location loc) {
637 mlir::Block &entryBlock = funcOp.getBody().front();
638
639 // Look for an existing cleanup dest slot in the entry block.
640 auto it = llvm::find_if(entryBlock, [](auto &op) {
641 return mlir::isa<AllocaOp>(&op) &&
642 mlir::cast<AllocaOp>(&op).getCleanupDestSlot();
643 });
644 if (it != entryBlock.end())
645 return mlir::cast<cir::AllocaOp>(*it);
646
647 // Create a new cleanup dest slot at the start of the entry block.
648 mlir::OpBuilder::InsertionGuard guard(rewriter);
649 rewriter.setInsertionPointToStart(&entryBlock);
650 cir::IntType s32Type =
651 cir::IntType::get(rewriter.getContext(), 32, /*isSigned=*/true);
652 cir::PointerType ptrToS32Type = cir::PointerType::get(s32Type);
653 cir::CIRDataLayout dataLayout(funcOp->getParentOfType<mlir::ModuleOp>());
654 uint64_t alignment = dataLayout.getAlignment(s32Type, true).value();
655 auto allocaOp = cir::AllocaOp::create(
656 rewriter, loc, ptrToS32Type, s32Type, "__cleanup_dest_slot",
657 /*alignment=*/rewriter.getI64IntegerAttr(alignment));
658 allocaOp.setCleanupDestSlot(true);
659 return allocaOp;
660}
661
662/// Shared EH flattening utilities used by both CIRCleanupScopeOpFlattening
663/// and CIRTryOpFlattening.
664
665// Collect all function calls in a region that may throw exceptions and need
666// to be replaced with try_call operations. Skips calls marked nothrow.
667// Nested cleanup scopes and try ops are always flattened before their
668// enclosing parents, so there are no nested regions to skip here.
669static void
670collectThrowingCalls(mlir::Region &region,
671 llvm::SmallVectorImpl<cir::CallOp> &callsToRewrite) {
672 region.walk([&](cir::CallOp callOp) {
673 if (!callOp.getNothrow())
674 callsToRewrite.push_back(callOp);
675 });
676}
677
678// Collect all cir.throw operations in a region that need to be replaced
679// with cir.try_throw operations so they can unwind through an enclosing
680// cleanup or catch handler. Nested cleanup scopes and try ops are always
681// flattened before their enclosing parents, so there are no nested
682// regions to skip here.
683static void
684collectThrows(mlir::Region &region,
685 llvm::SmallVectorImpl<cir::ThrowOp> &throwsToRewrite) {
686 region.walk(
687 [&](cir::ThrowOp throwOp) { throwsToRewrite.push_back(throwOp); });
688}
689
690// Collect all cir.resume operations in a region that come from
691// already-flattened try or cleanup scope operations. These resume ops need
692// to be chained through this scope's EH handler instead of unwinding
693// directly to the caller. Nested cleanup scopes and try ops are always
694// flattened before their enclosing parents, so there are no nested regions
695// to skip here.
696static void collectResumeOps(mlir::Region &region,
698 region.walk([&](cir::ResumeOp resumeOp) { resumeOps.push_back(resumeOp); });
699}
700
701// Create a shared unwind destination block. The block contains a
702// cir.eh.initiate operation (optionally with the cleanup attribute) and a
703// branch to the given destination block, passing the eh_token.
704static mlir::Block *buildUnwindBlock(mlir::Block *dest, bool isCleanupOnly,
705 mlir::Location loc,
706 mlir::Block *insertBefore,
707 mlir::PatternRewriter &rewriter) {
708 mlir::Block *unwindBlock = rewriter.createBlock(insertBefore);
709 rewriter.setInsertionPointToEnd(unwindBlock);
710 auto ehInitiate =
711 cir::EhInitiateOp::create(rewriter, loc, /*cleanup=*/isCleanupOnly);
712 cir::BrOp::create(rewriter, loc, mlir::ValueRange{ehInitiate.getEhToken()},
713 dest);
714 return unwindBlock;
715}
716
717// Create a shared terminate unwind block for throwing calls in EH cleanup
718// regions. When an exception is thrown during cleanup (unwinding), the C++
719// standard requires that std::terminate() be called.
720static mlir::Block *buildTerminateUnwindBlock(mlir::Location loc,
721 mlir::Block *insertBefore,
722 mlir::PatternRewriter &rewriter) {
723 mlir::Block *terminateBlock = rewriter.createBlock(insertBefore);
724 rewriter.setInsertionPointToEnd(terminateBlock);
725 auto ehInitiate = cir::EhInitiateOp::create(rewriter, loc, /*cleanup=*/false);
726 cir::EhTerminateOp::create(rewriter, loc, ehInitiate.getEhToken());
727 return terminateBlock;
728}
729
730class CIRCleanupScopeOpFlattening
731 : public mlir::OpRewritePattern<cir::CleanupScopeOp> {
732public:
733 using OpRewritePattern<cir::CleanupScopeOp>::OpRewritePattern;
734
735 struct CleanupExit {
736 // An operation that exits the cleanup scope (yield, break, continue,
737 // return, etc.)
738 mlir::Operation *exitOp;
739
740 // A unique identifier for this exit's destination (used for switch dispatch
741 // when there are multiple exits).
742 int destinationId;
743
744 CleanupExit(mlir::Operation *op, int id) : exitOp(op), destinationId(id) {}
745 };
746
747 // Determine whether a goto operation transfers control to a label that
748 // exists somewhere inside the given region (or any of its nested regions).
749 // Label names are unique within a function, so finding a matching cir.label
750 // inside the region implies that the goto definitely targets that label and
751 // therefore stays within the region. If no match is found, the goto either
752 // exits the region or its target is unknown; in either case the caller must
753 // treat it as exiting the region.
754 static bool gotoTargetsLabelInRegion(cir::GotoOp gotoOp,
755 mlir::Region &region) {
756 llvm::StringRef targetLabel = gotoOp.getLabel();
757 return region
758 .walk([&](cir::LabelOp labelOp) {
759 if (labelOp.getLabel() == targetLabel)
760 return mlir::WalkResult::interrupt();
761 return mlir::WalkResult::advance();
762 })
763 .wasInterrupted();
764 }
765
766 // Collect all operations that exit a cleanup scope body. Return, goto, break,
767 // and continue can all require branches through the cleanup region. When a
768 // loop is encountered, only return and goto are collected because break and
769 // continue are handled by the loop and stay within the cleanup scope. When a
770 // switch is encountered, return, goto and continue are collected because they
771 // may all branch through the cleanup, but break is local to the switch. When
772 // a nested cleanup scope is encountered, we recursively collect exits since
773 // any return, goto, break, or continue from the nested cleanup will also
774 // branch through the outer cleanup.
775 //
776 // A goto is only treated as an exit if its target label is not somewhere
777 // inside the cleanup body region. Gotos whose target label is within the
778 // cleanup body stay inside the cleanup scope and need no special handling
779 // during flattening; they are simply inlined along with the rest of the
780 // body region.
781 //
782 // This function assigns unique destination IDs to each exit, which are
783 // used when multi-exit cleanup scopes are flattened.
784 void collectExits(mlir::Region &cleanupBodyRegion,
785 llvm::SmallVectorImpl<CleanupExit> &exits,
786 int &nextId) const {
787 // Collect yield terminators from the body region. We do this separately
788 // because yields in nested operations, including those in nested cleanup
789 // scopes, won't branch through the outer cleanup region.
790 for (mlir::Block &block : cleanupBodyRegion) {
791 auto *terminator = block.getTerminator();
792 if (isa<cir::YieldOp>(terminator))
793 exits.emplace_back(terminator, nextId++);
794 }
795
796 // Helper to decide whether an op is a goto that needs to be treated as an
797 // exit from the cleanup scope being flattened. If op is a goto and targets
798 // a label inside the cleanup body region, control stays within the cleanup
799 // and we leave the goto in place.
800 auto isGotoThatExitsCleanup = [&](mlir::Operation *op) {
801 auto gotoOp = dyn_cast<cir::GotoOp>(op);
802 return gotoOp && !gotoTargetsLabelInRegion(gotoOp, cleanupBodyRegion);
803 };
804
805 // Lambda to walk a loop and collect only returns and gotos.
806 // Break and continue inside loops are handled by the loop itself.
807 // Loops don't require special handling for nested switch or cleanup scopes
808 // because break and continue never branch out of the loop.
809 auto collectExitsInLoop = [&](mlir::Operation *loopOp) {
810 loopOp->walk<mlir::WalkOrder::PreOrder>([&](mlir::Operation *nestedOp) {
811 if (isa<cir::ReturnOp>(nestedOp)) {
812 exits.emplace_back(nestedOp, nextId++);
813 } else if (isGotoThatExitsCleanup(nestedOp)) {
814 exits.emplace_back(nestedOp, nextId++);
815 }
816 return mlir::WalkResult::advance();
817 });
818 };
819
820 // Forward declaration for mutual recursion.
821 std::function<void(mlir::Region &, bool)> collectExitsInCleanup;
822 std::function<void(mlir::Operation *)> collectExitsInSwitch;
823
824 // Lambda to collect exits from a switch. Collects return/goto/continue but
825 // not break (handled by switch). For nested loops/cleanups, recurses.
826 collectExitsInSwitch = [&](mlir::Operation *switchOp) {
827 switchOp->walk<mlir::WalkOrder::PreOrder>([&](mlir::Operation *nestedOp) {
828 if (isa<cir::CleanupScopeOp>(nestedOp)) {
829 // Walk the nested cleanup, but ignore break statements because they
830 // will be handled by the switch we are currently walking.
831 collectExitsInCleanup(
832 cast<cir::CleanupScopeOp>(nestedOp).getBodyRegion(),
833 /*ignoreBreak=*/true);
834 return mlir::WalkResult::skip();
835 } else if (isa<cir::LoopOpInterface>(nestedOp)) {
836 collectExitsInLoop(nestedOp);
837 return mlir::WalkResult::skip();
838 } else if (isa<cir::ReturnOp, cir::ContinueOp>(nestedOp)) {
839 exits.emplace_back(nestedOp, nextId++);
840 } else if (isGotoThatExitsCleanup(nestedOp)) {
841 exits.emplace_back(nestedOp, nextId++);
842 }
843 return mlir::WalkResult::advance();
844 });
845 };
846
847 // Lambda to collect exits from a cleanup scope body region. This collects
848 // break (optionally), continue, return, and goto, handling nested loops,
849 // switches, and cleanups appropriately.
850 collectExitsInCleanup = [&](mlir::Region &region, bool ignoreBreak) {
851 region.walk<mlir::WalkOrder::PreOrder>([&](mlir::Operation *op) {
852 // We need special handling for break statements because if this cleanup
853 // scope was nested within a switch op, break will be handled by the
854 // switch operation and therefore won't exit the cleanup scope enclosing
855 // the switch. We're only collecting exits from the cleanup that started
856 // this walk. Exits from nested cleanups will be handled when we flatten
857 // the nested cleanup.
858 if (!ignoreBreak && isa<cir::BreakOp>(op)) {
859 exits.emplace_back(op, nextId++);
860 } else if (isa<cir::ContinueOp, cir::ReturnOp>(op)) {
861 exits.emplace_back(op, nextId++);
862 } else if (isGotoThatExitsCleanup(op)) {
863 exits.emplace_back(op, nextId++);
864 } else if (isa<cir::CleanupScopeOp>(op)) {
865 // Recurse into nested cleanup's body region.
866 collectExitsInCleanup(cast<cir::CleanupScopeOp>(op).getBodyRegion(),
867 /*ignoreBreak=*/ignoreBreak);
868 return mlir::WalkResult::skip();
869 } else if (isa<cir::LoopOpInterface>(op)) {
870 // This kicks off a separate walk rather than continuing to dig deeper
871 // in the current walk because we need to handle break and continue
872 // differently inside loops.
873 collectExitsInLoop(op);
874 return mlir::WalkResult::skip();
875 } else if (isa<cir::SwitchOp>(op)) {
876 // This kicks off a separate walk rather than continuing to dig deeper
877 // in the current walk because we need to handle break differently
878 // inside switches.
879 collectExitsInSwitch(op);
880 return mlir::WalkResult::skip();
881 }
882 return mlir::WalkResult::advance();
883 });
884 };
885
886 // Collect exits from the body region.
887 collectExitsInCleanup(cleanupBodyRegion, /*ignoreBreak=*/false);
888 }
889
890 // Check if an operand's defining op should be moved to the destination block.
891 // We only sink constants and simple loads. Anything else should be saved
892 // to a temporary alloca and reloaded at the destination block.
893 static bool shouldSinkReturnOperand(mlir::Value operand,
894 cir::ReturnOp returnOp) {
895 // Block arguments can't be moved
896 mlir::Operation *defOp = operand.getDefiningOp();
897 if (!defOp)
898 return false;
899
900 // Only move constants and loads to the dispatch block. For anything else,
901 // we'll store to a temporary and reload in the dispatch block.
902 if (!mlir::isa<cir::ConstantOp, cir::LoadOp>(defOp))
903 return false;
904
905 // Check if the return is the only user
906 if (!operand.hasOneUse())
907 return false;
908
909 // Only move ops that are in the same block as the return.
910 if (defOp->getBlock() != returnOp->getBlock())
911 return false;
912
913 if (auto loadOp = mlir::dyn_cast<cir::LoadOp>(defOp)) {
914 // Only attempt to move loads of allocas in the entry block.
915 mlir::Value ptr = loadOp.getAddr();
916 auto funcOp = returnOp->getParentOfType<cir::FuncOp>();
917 assert(funcOp && "Return op has no function parent?");
918 mlir::Block &funcEntryBlock = funcOp.getBody().front();
919
920 // Check if it's an alloca in the function entry block
921 if (auto allocaOp =
922 mlir::dyn_cast_if_present<cir::AllocaOp>(ptr.getDefiningOp()))
923 return allocaOp->getBlock() == &funcEntryBlock;
924
925 return false;
926 }
927
928 // Make sure we only fall through to here with constants.
929 assert(mlir::isa<cir::ConstantOp>(defOp) && "Expected constant op");
930 return true;
931 }
932
933 // For returns with operands in cleanup dispatch blocks, the operands may not
934 // dominate the dispatch block. This function handles that by either sinking
935 // the operand's defining op to the dispatch block (for constants and simple
936 // loads) or by storing to a temporary alloca and reloading it.
937 void
938 getReturnOpOperands(cir::ReturnOp returnOp, mlir::Operation *exitOp,
939 mlir::Location loc, mlir::PatternRewriter &rewriter,
940 llvm::SmallVectorImpl<mlir::Value> &returnValues) const {
941 mlir::Block *destBlock = rewriter.getInsertionBlock();
942 auto funcOp = exitOp->getParentOfType<cir::FuncOp>();
943 assert(funcOp && "Return op has no function parent?");
944 mlir::Block &funcEntryBlock = funcOp.getBody().front();
945
946 for (mlir::Value operand : returnOp.getOperands()) {
947 if (shouldSinkReturnOperand(operand, returnOp)) {
948 // Sink the defining op to the dispatch block.
949 mlir::Operation *defOp = operand.getDefiningOp();
950 rewriter.moveOpBefore(defOp, destBlock, destBlock->end());
951 returnValues.push_back(operand);
952 } else {
953 // Create an alloca in the function entry block.
954 cir::AllocaOp alloca;
955 {
956 mlir::OpBuilder::InsertionGuard guard(rewriter);
957 rewriter.setInsertionPointToStart(&funcEntryBlock);
958 cir::CIRDataLayout dataLayout(
959 funcOp->getParentOfType<mlir::ModuleOp>());
960 uint64_t alignment =
961 dataLayout.getAlignment(operand.getType(), true).value();
962 cir::PointerType ptrType = cir::PointerType::get(operand.getType());
963 alloca = cir::AllocaOp::create(rewriter, loc, ptrType,
964 operand.getType(), "__ret_operand_tmp",
965 rewriter.getI64IntegerAttr(alignment));
966 }
967
968 // Store the operand value at the original return location.
969 {
970 mlir::OpBuilder::InsertionGuard guard(rewriter);
971 rewriter.setInsertionPoint(exitOp);
972 cir::StoreOp::create(rewriter, loc, operand, alloca,
973 /*isVolatile=*/false,
974 /*alignment=*/mlir::IntegerAttr(),
975 cir::SyncScopeKindAttr(), cir::MemOrderAttr());
976 }
977
978 // Reload the value from the temporary alloca in the destination block.
979 rewriter.setInsertionPointToEnd(destBlock);
980 auto loaded = cir::LoadOp::create(
981 rewriter, loc, alloca, /*isDeref=*/false,
982 /*isVolatile=*/false, /*alignment=*/mlir::IntegerAttr(),
983 cir::SyncScopeKindAttr(), cir::MemOrderAttr());
984 returnValues.push_back(loaded);
985 }
986 }
987 }
988
989 // Create the appropriate terminator for an exit operation in the dispatch
990 // block. For return ops with operands, this handles the dominance issue by
991 // either moving the operand's defining op to the dispatch block (if it's a
992 // trivial use) or by storing to a temporary alloca and loading it.
993 mlir::LogicalResult
994 createExitTerminator(mlir::Operation *exitOp, mlir::Location loc,
995 mlir::Block *continueBlock,
996 mlir::PatternRewriter &rewriter) const {
997 return llvm::TypeSwitch<mlir::Operation *, mlir::LogicalResult>(exitOp)
998 .Case<cir::YieldOp>([&](auto) {
999 // Yield becomes a branch to continue block.
1000 cir::BrOp::create(rewriter, loc, continueBlock);
1001 return mlir::success();
1002 })
1003 .Case<cir::BreakOp>([&](auto) {
1004 // Break is preserved for later lowering by enclosing switch/loop.
1005 cir::BreakOp::create(rewriter, loc);
1006 return mlir::success();
1007 })
1008 .Case<cir::ContinueOp>([&](auto) {
1009 // Continue is preserved for later lowering by enclosing loop.
1010 cir::ContinueOp::create(rewriter, loc);
1011 return mlir::success();
1012 })
1013 .Case<cir::ReturnOp>([&](auto returnOp) {
1014 // Return from the cleanup exit. Note, if this is a return inside a
1015 // nested cleanup scope, the flattening of the outer scope will handle
1016 // branching through the outer cleanup.
1017 if (returnOp.hasOperand()) {
1018 llvm::SmallVector<mlir::Value, 2> returnValues;
1019 getReturnOpOperands(returnOp, exitOp, loc, rewriter, returnValues);
1020 cir::ReturnOp::create(rewriter, loc, returnValues);
1021 } else {
1022 cir::ReturnOp::create(rewriter, loc);
1023 }
1024 return mlir::success();
1025 })
1026 .Case<cir::GotoOp>([&](auto gotoOp) {
1027 // Gotos that target a label within the cleanup body region are
1028 // filtered out by collectExits and never reach this code, so any
1029 // goto that does reach here transfers control out of the cleanup
1030 // scope. The goto is just moved to the exit block.
1031 cir::GotoOp::create(rewriter, loc, gotoOp.getLabel());
1032 return mlir::success();
1033 })
1034 .Default([&](mlir::Operation *op) {
1035 cir::UnreachableOp::create(rewriter, loc);
1036 return op->emitError(
1037 "unexpected exit operation in cleanup scope body");
1038 });
1039 }
1040
1041#ifndef NDEBUG
1042 // Check that no block other than the last one in a region exits the region.
1043 static bool regionExitsOnlyFromLastBlock(mlir::Region &region) {
1044 for (mlir::Block &block : region) {
1045 if (&block == &region.back())
1046 continue;
1047 bool expectedTerminator =
1048 llvm::TypeSwitch<mlir::Operation *, bool>(block.getTerminator())
1049 // It is theoretically possible to have a cleanup block with
1050 // any of the following exits in non-final blocks, but we won't
1051 // currently generate any CIR that does that, and being able to
1052 // assume that it doesn't happen simplifies the implementation.
1053 // If we ever need to handle this case, the code will need to
1054 // be updated to handle it.
1055 .Case<cir::YieldOp, cir::ReturnOp, cir::ResumeFlatOp,
1056 cir::ContinueOp, cir::BreakOp, cir::GotoOp>(
1057 [](auto) { return false; })
1058 // We expect that call operations have not yet been rewritten
1059 // as try_call operations. A call can unwind out of the cleanup
1060 // scope, but we will be handling that during flattening. The
1061 // only case where a try_call could be present inside an
1062 // unflattened cleanup region is if the cleanup contained a
1063 // nested try-catch region, and that isn't expected as of the
1064 // time of this implementation. If it does, this could be
1065 // updated to tolerate it.
1066 .Case<cir::TryCallOp>([](auto) { return false; })
1067 // Likewise, we don't expect to find an EH dispatch operation
1068 // because we weren't expecting try-catch regions nested in the
1069 // cleanup region.
1070 .Case<cir::EhDispatchOp>([](auto) { return false; })
1071 // In theory, it would be possible to have a flattened switch
1072 // operation that does not exit the cleanup region. For now,
1073 // that's not happening.
1074 .Case<cir::SwitchFlatOp>([](auto) { return false; })
1075 // These aren't expected either, but if they occur, they don't
1076 // exit the region, so that's OK.
1077 .Case<cir::UnreachableOp, cir::TrapOp>([](auto) { return true; })
1078 // Indirect branches are not expected.
1079 .Case<cir::IndirectBrOp>([](auto) { return false; })
1080 // We do expect branches, but we don't expect them to leave
1081 // the region.
1082 .Case<cir::BrOp>([&](cir::BrOp brOp) {
1083 assert(brOp.getDest()->getParent() == &region &&
1084 "branch destination is not in the region");
1085 return true;
1086 })
1087 .Case<cir::BrCondOp>([&](cir::BrCondOp brCondOp) {
1088 assert(brCondOp.getDestTrue()->getParent() == &region &&
1089 "branch destination is not in the region");
1090 assert(brCondOp.getDestFalse()->getParent() == &region &&
1091 "branch destination is not in the region");
1092 return true;
1093 })
1094 // What else could there be?
1095 .Default([](mlir::Operation *) -> bool {
1096 llvm_unreachable("unexpected terminator in cleanup region");
1097 });
1098 if (!expectedTerminator)
1099 return false;
1100 }
1101 return true;
1102 }
1103#endif
1104
1105 // Build the EH cleanup block structure by cloning the cleanup region. The
1106 // cloned entry block gets an !cir.eh_token argument and a cir.begin_cleanup
1107 // inserted at the top. All cir.yield terminators that might exit the cleanup
1108 // region are replaced with cir.end_cleanup + cir.resume.
1109 //
1110 // For a single-block cleanup region, this produces:
1111 //
1112 // ^eh_cleanup(%eh_token : !cir.eh_token):
1113 // %ct = cir.begin_cleanup %eh_token : !cir.eh_token -> !cir.cleanup_token
1114 // <cloned cleanup operations>
1115 // cir.end_cleanup %ct : !cir.cleanup_token
1116 // cir.resume %eh_token : !cir.eh_token
1117 //
1118 // For a multi-block cleanup region (e.g. containing a flattened cir.if),
1119 // the same wrapping is applied around the cloned block structure: the entry
1120 // block gets begin_cleanup and all exit blocks (those terminated by yield)
1121 // get end_cleanup + resume.
1122 //
1123 // If this cleanup scope is nested within a TryOp, the resume will be updated
1124 // to branch to the catch dispatch block of the enclosing try operation when
1125 // the TryOp is flattened.
1126 mlir::Block *buildEHCleanupBlocks(cir::CleanupScopeOp cleanupOp,
1127 mlir::Location loc,
1128 mlir::Block *insertBefore,
1129 mlir::PatternRewriter &rewriter) const {
1130 assert(regionExitsOnlyFromLastBlock(cleanupOp.getCleanupRegion()) &&
1131 "cleanup region has exits in non-final blocks");
1132
1133 // Track the block before the insertion point so we can find the cloned
1134 // blocks after cloning.
1135 mlir::Block *blockBeforeClone = insertBefore->getPrevNode();
1136
1137 // Clone the entire cleanup region before insertBefore.
1138 rewriter.cloneRegionBefore(cleanupOp.getCleanupRegion(), insertBefore);
1139
1140 // Find the first cloned block.
1141 mlir::Block *clonedEntry = blockBeforeClone
1142 ? blockBeforeClone->getNextNode()
1143 : &insertBefore->getParent()->front();
1144
1145 // Add the eh_token argument to the cloned entry block and insert
1146 // begin_cleanup at the top.
1147 auto ehTokenType = cir::EhTokenType::get(rewriter.getContext());
1148 mlir::Value ehToken = clonedEntry->addArgument(ehTokenType, loc);
1149
1150 rewriter.setInsertionPointToStart(clonedEntry);
1151 auto beginCleanup = cir::BeginCleanupOp::create(rewriter, loc, ehToken);
1152
1153 // Replace the yield terminator in the last cloned block with
1154 // end_cleanup + resume.
1155 mlir::Block *lastClonedBlock = insertBefore->getPrevNode();
1156 auto yieldOp =
1157 mlir::dyn_cast<cir::YieldOp>(lastClonedBlock->getTerminator());
1158 if (yieldOp) {
1159 rewriter.setInsertionPoint(yieldOp);
1160 cir::EndCleanupOp::create(rewriter, loc, beginCleanup.getCleanupToken());
1161 rewriter.replaceOpWithNewOp<cir::ResumeOp>(yieldOp, ehToken);
1162 } else {
1163 cleanupOp->emitError("Not yet implemented: cleanup region terminated "
1164 "with non-yield operation");
1165 }
1166
1167 return clonedEntry;
1168 }
1169
1170 // Flatten a cleanup scope. The body region's exits branch to the cleanup
1171 // block, and the cleanup block branches to destination blocks whose contents
1172 // depend on the type of operation that exited the body region. Yield becomes
1173 // a branch to the block after the cleanup scope, break and continue are
1174 // preserved for later lowering by enclosing switch or loop, and return
1175 // is preserved as is.
1176 //
1177 // If there are multiple exits from the cleanup body, a destination slot and
1178 // switch dispatch are used to continue to the correct destination after the
1179 // cleanup is complete. A destination slot alloca is created at the function
1180 // entry block. Each exit operation is replaced by a store of its unique ID to
  // the destination slot and a branch to cleanup. The cleanup region's
  // terminating yield is replaced with a branch to a dispatch block that
  // loads the destination slot and uses switch.flat to branch to the correct
  // destination.
1184 //
1185 // If the cleanup scope requires EH cleanup, any call operations in the body
1186 // that may throw are replaced with cir.try_call operations that unwind to an
1187 // EH cleanup block. The cleanup block(s) will be terminated with a cir.resume
  // operation. If this cleanup scope is enclosed by a try operation,
  // flattening the try operation will replace the cir.resume with a branch
  // to a catch dispatch block. Otherwise, the cir.resume operation
1191 // remains in place and will unwind to the caller.
  mlir::LogicalResult
  flattenCleanup(cir::CleanupScopeOp cleanupOp,
                 llvm::SmallVectorImpl<CleanupExit> &exits,
                 llvm::SmallVectorImpl<cir::CallOp> &callsToRewrite,
                 llvm::SmallVectorImpl<cir::ThrowOp> &throwsToRewrite,
                 llvm::SmallVectorImpl<cir::ResumeOp> &resumeOpsToChain,
                 mlir::PatternRewriter &rewriter) const {
    // Inputs (gathered by matchAndRewrite):
    //   exits            - ops leaving the body region, each paired with the
    //                      destination ID used by the multi-exit dispatch.
    //   callsToRewrite   - body calls to convert to cir.try_call.
    //   throwsToRewrite  - body throws to convert to try_throw.
    //   resumeOpsToChain - cir.resume ops of already-flattened inner cleanup
    //                      scopes, redirected into this scope's EH cleanup.
    // Returns failure only from the early bail-outs below, which happen
    // before any IR is modified; otherwise always returns success().
    mlir::Location loc = cleanupOp.getLoc();
    cir::CleanupKind cleanupKind = cleanupOp.getCleanupKind();
    // CleanupKind::All needs both the normal and the EH path.
    bool hasNormalCleanup = cleanupKind == cir::CleanupKind::Normal ||
                            cleanupKind == cir::CleanupKind::All;
    bool hasEHCleanup = cleanupKind == cir::CleanupKind::EH ||
                        cleanupKind == cir::CleanupKind::All;
    // A single exit can be branched to directly after the cleanup; several
    // exits need the destination slot + switch.flat dispatch.
    bool isMultiExit = exits.size() > 1;

    // Get references to region blocks before inlining.
    mlir::Block *bodyEntry = &cleanupOp.getBodyRegion().front();
    mlir::Block *cleanupEntry = &cleanupOp.getCleanupRegion().front();
    mlir::Block *cleanupExit = &cleanupOp.getCleanupRegion().back();
    assert(regionExitsOnlyFromLastBlock(cleanupOp.getCleanupRegion()) &&
           "cleanup region has exits in non-final blocks");
    // The cleanup region's final yield is the single point where normal
    // cleanup finishes; other terminators are not supported yet.
    auto cleanupYield = dyn_cast<cir::YieldOp>(cleanupExit->getTerminator());
    if (!cleanupYield) {
      return rewriter.notifyMatchFailure(cleanupOp,
                                         "Not yet implemented: cleanup region "
                                         "terminated with non-yield operation");
    }

    // For multiple exits from the body region, get or create a destination slot
    // at function entry. The slot is shared across all cleanup scopes in the
    // function. This is only needed if the cleanup scope requires normal
    // cleanup.
    cir::AllocaOp destSlot;
    if (isMultiExit && hasNormalCleanup) {
      auto funcOp = cleanupOp->getParentOfType<cir::FuncOp>();
      if (!funcOp)
        return cleanupOp->emitError("cleanup scope not inside a function");
      destSlot = getOrCreateCleanupDestSlot(funcOp, rewriter, loc);
    }

    // Split the current block to create the insertion point. Ops from the
    // rewriter's insertion point onward move into continueBlock, which is
    // where yield exits resume.
    mlir::Block *currentBlock = rewriter.getInsertionBlock();
    mlir::Block *continueBlock =
        rewriter.splitBlock(currentBlock, rewriter.getInsertionPoint());

    // Build EH cleanup blocks if needed. This must be done before inlining
    // the cleanup region since buildEHCleanupBlocks clones from it. The unwind
    // block is inserted before the EH cleanup entry so that the final layout
    // is: body -> normal cleanup -> exit -> unwind -> EH cleanup -> continue.
    // EH cleanup blocks are needed when there are throwing calls or throws
    // that need to be rewritten, or when there are resume ops from
    // already-flattened inner cleanup scopes that need to chain through this
    // cleanup's EH handler.
    mlir::Block *unwindBlock = nullptr;
    mlir::Block *ehCleanupEntry = nullptr;
    if (hasEHCleanup && (!callsToRewrite.empty() || !throwsToRewrite.empty() ||
                         !resumeOpsToChain.empty())) {
      ehCleanupEntry =
          buildEHCleanupBlocks(cleanupOp, loc, continueBlock, rewriter);
      // The unwind block is only needed when there are throwing calls or
      // throws that need a shared unwind destination. Resume ops from inner
      // cleanups branch directly to the EH cleanup entry.
      if (!callsToRewrite.empty() || !throwsToRewrite.empty())
        unwindBlock = buildUnwindBlock(ehCleanupEntry, /*isCleanupOnly=*/true,
                                       loc, ehCleanupEntry, rewriter);
    }

    // All normal flow blocks are inserted before this point — either before
    // the unwind block (if it exists), or before the EH cleanup entry (if EH
    // cleanup exists but no unwind block is needed), or before the continue
    // block.
    mlir::Block *normalInsertPt =
        unwindBlock ? unwindBlock
                    : (ehCleanupEntry ? ehCleanupEntry : continueBlock);

    // Inline the body region.
    rewriter.inlineRegionBefore(cleanupOp.getBodyRegion(), normalInsertPt);

    // Inline the cleanup region for the normal cleanup path. (For EH-only
    // scopes the original cleanup region is not inlined; it goes away with
    // cleanupOp at the end.)
    if (hasNormalCleanup)
      rewriter.inlineRegionBefore(cleanupOp.getCleanupRegion(), normalInsertPt);

    // Branch from current block to body entry.
    rewriter.setInsertionPointToEnd(currentBlock);
    cir::BrOp::create(rewriter, loc, bodyEntry);

    // Handle normal exits.
    mlir::LogicalResult result = mlir::success();
    if (hasNormalCleanup) {
      // Create the exit/dispatch block (after cleanup, before continue).
      mlir::Block *exitBlock = rewriter.createBlock(normalInsertPt);

      // Rewrite the cleanup region's yield to branch to exit block.
      rewriter.setInsertionPoint(cleanupYield);
      rewriter.replaceOpWithNewOp<cir::BrOp>(cleanupYield, exitBlock);

      if (isMultiExit) {
        // Build the dispatch switch in the exit block.
        rewriter.setInsertionPointToEnd(exitBlock);

        // Load the destination slot value.
        auto slotValue = cir::LoadOp::create(
            rewriter, loc, destSlot, /*isDeref=*/false,
            /*isVolatile=*/false, /*alignment=*/mlir::IntegerAttr(),
            cir::SyncScopeKindAttr(), cir::MemOrderAttr());

        // Create destination blocks for each exit and collect switch case info.
        llvm::SmallVector<mlir::APInt, 8> caseValues;
        llvm::SmallVector<mlir::Block *, 8> caseDestinations;
        llvm::SmallVector<mlir::ValueRange, 8> caseOperands;
        cir::IntType s32Type =
            cir::IntType::get(rewriter.getContext(), 32, /*isSigned=*/true);

        for (const CleanupExit &exit : exits) {
          // Create a block for this destination.
          mlir::Block *destBlock = rewriter.createBlock(normalInsertPt);
          rewriter.setInsertionPointToEnd(destBlock);
          // Build the terminator before erasing the original exit op: for
          // returns it may still need the original op's operands.
          result =
              createExitTerminator(exit.exitOp, loc, continueBlock, rewriter);

          // Add to switch cases.
          caseValues.push_back(
              llvm::APInt(32, static_cast<uint64_t>(exit.destinationId), true));
          caseDestinations.push_back(destBlock);
          caseOperands.push_back(mlir::ValueRange());

          // Replace the original exit op with: store dest ID, branch to
          // cleanup.
          rewriter.setInsertionPoint(exit.exitOp);
          auto destIdConst = cir::ConstantOp::create(
              rewriter, loc, cir::IntAttr::get(s32Type, exit.destinationId));
          cir::StoreOp::create(rewriter, loc, destIdConst, destSlot,
                               /*isVolatile=*/false,
                               /*alignment=*/mlir::IntegerAttr(),
                               cir::SyncScopeKindAttr(), cir::MemOrderAttr());
          rewriter.replaceOpWithNewOp<cir::BrOp>(exit.exitOp, cleanupEntry);

          // If the exit terminator creation failed, we're going to end up with
          // partially flattened code, but we'll also have reported an error so
          // that's OK. We need to finish out this function to keep the IR in a
          // valid state to help diagnose the error. This is a temporary
          // possibility during development. It shouldn't ever happen after the
          // implementation is complete.
          if (result.failed())
            break;
        }

        // Create the default destination (unreachable).
        mlir::Block *defaultBlock = rewriter.createBlock(normalInsertPt);
        rewriter.setInsertionPointToEnd(defaultBlock);
        cir::UnreachableOp::create(rewriter, loc);

        // Build the switch.flat operation in the exit block.
        rewriter.setInsertionPointToEnd(exitBlock);
        cir::SwitchFlatOp::create(rewriter, loc, slotValue, defaultBlock,
                                  mlir::ValueRange(), caseValues,
                                  caseDestinations, caseOperands);
      } else {
        // Single exit: put the appropriate terminator directly in the exit
        // block.
        rewriter.setInsertionPointToEnd(exitBlock);
        mlir::Operation *exitOp = exits[0].exitOp;
        result = createExitTerminator(exitOp, loc, continueBlock, rewriter);

        // Replace body exit with branch to cleanup entry.
        rewriter.setInsertionPoint(exitOp);
        rewriter.replaceOpWithNewOp<cir::BrOp>(exitOp, cleanupEntry);
      }
    } else {
      // EH-only cleanup: normal exits skip the cleanup entirely.
      // Replace yield exits with branches to the continue block.
      for (CleanupExit &exit : exits) {
        if (isa<cir::YieldOp>(exit.exitOp)) {
          rewriter.setInsertionPoint(exit.exitOp);
          rewriter.replaceOpWithNewOp<cir::BrOp>(exit.exitOp, continueBlock);
        }
        // Non-yield exits (break, continue, return) stay as-is since no normal
        // cleanup is needed.
      }
    }

    // Replace non-nothrow calls and throws with try_call/try_throw
    // operations. All calls and throws within this cleanup scope share the
    // same unwind destination. (unwindBlock is non-null whenever either
    // list is non-empty; see where it is built above.)
    if (hasEHCleanup) {
      for (cir::CallOp callOp : callsToRewrite)
        replaceCallWithTryCall(callOp, unwindBlock, loc, rewriter);
      for (cir::ThrowOp throwOp : throwsToRewrite)
        replaceThrowWithTryThrow(throwOp, unwindBlock, loc, rewriter);
    }

    // Handle throwing calls and throws in EH cleanup blocks. When an
    // exception is thrown during cleanup code that runs on the exception
    // unwind path, the C++ standard requires that std::terminate() be
    // called. Replace such calls and throws with try_call/try_throw
    // operations that unwind to a terminate block containing
    // cir.eh.initiate + cir.eh.terminate.
    if (ehCleanupEntry) {
      llvm::SmallVector<cir::CallOp> ehCleanupThrowingCalls;
      llvm::SmallVector<cir::ThrowOp> ehCleanupThrows;
      // Normal-flow blocks were inserted before the unwind/EH blocks, so the
      // range [ehCleanupEntry, continueBlock) holds only the cloned EH
      // cleanup blocks.
      for (mlir::Block *block = ehCleanupEntry; block != continueBlock;
           block = block->getNextNode()) {
        block->walk([&](mlir::Operation *op) {
          if (auto callOp = mlir::dyn_cast<cir::CallOp>(op)) {
            if (!callOp.getNothrow())
              ehCleanupThrowingCalls.push_back(callOp);
          } else if (auto throwOp = mlir::dyn_cast<cir::ThrowOp>(op)) {
            ehCleanupThrows.push_back(throwOp);
          }
        });
      }
      if (!ehCleanupThrowingCalls.empty() || !ehCleanupThrows.empty()) {
        mlir::Block *terminateBlock =
            buildTerminateUnwindBlock(loc, continueBlock, rewriter);
        for (cir::CallOp callOp : ehCleanupThrowingCalls)
          replaceCallWithTryCall(callOp, terminateBlock, loc, rewriter);
        for (cir::ThrowOp throwOp : ehCleanupThrows)
          replaceThrowWithTryThrow(throwOp, terminateBlock, loc, rewriter);
      }
    }

    // Chain inner EH cleanup resume ops to this cleanup's EH handler.
    // Each cir.resume from an already-flattened inner cleanup is replaced
    // with a branch to the outer EH cleanup entry, passing the eh_token
    // from the inner's begin_cleanup so that the same in-flight exception
    // flows through the outer cleanup before unwinding to the caller.
    if (ehCleanupEntry) {
      for (cir::ResumeOp resumeOp : resumeOpsToChain) {
        mlir::Value ehToken = resumeOp.getEhToken();
        rewriter.setInsertionPoint(resumeOp);
        rewriter.replaceOpWithNewOp<cir::BrOp>(
            resumeOp, mlir::ValueRange{ehToken}, ehCleanupEntry);
      }
    }

    // Erase the original cleanup scope op.
    rewriter.eraseOp(cleanupOp);

    // Always return success because the IR has been modified (blocks split,
    // regions inlined, ops erased, etc.). The MLIR pattern rewriter contract
    // requires that if a pattern modifies IR, it must return success().
    return mlir::success();
  }
1435
1436 mlir::LogicalResult
1437 matchAndRewrite(cir::CleanupScopeOp cleanupOp,
1438 mlir::PatternRewriter &rewriter) const override {
1439 mlir::OpBuilder::InsertionGuard guard(rewriter);
1440
1441 // All nested structured CIR ops must be flattened before the cleanup scope.
1442 // Operations like loops, switches, scopes, and ifs may contain exits
1443 // (return, break, continue) that the cleanup scope will replace with
1444 // branches to the cleanup entry. If those exits are inside a structured
1445 // op's region, the branch would reference a block outside that region,
1446 // which is invalid. Fail the match so they are processed first.
1447 //
1448 // Before checking, erase any trivially dead nested cleanup scopes. These
1449 // arise from deactivated cleanups (e.g. partial-construction guards for
1450 // lambda captures). The greedy rewriter may have already DCE'd them, but
1451 // when a trivially dead nested op is erased first, the parent isn't always
1452 // re-added to the worklist, so we handle it here.
1453 llvm::SmallVector<cir::CleanupScopeOp> deadNestedOps;
1454 cleanupOp.getBodyRegion().walk([&](cir::CleanupScopeOp nested) {
1455 if (mlir::isOpTriviallyDead(nested))
1456 deadNestedOps.push_back(nested);
1457 });
1458 for (auto op : deadNestedOps)
1459 rewriter.eraseOp(op);
1460
1461 if (hasNestedOpsToFlatten(cleanupOp.getBodyRegion()))
1462 return mlir::failure();
1463
1464 cir::CleanupKind cleanupKind = cleanupOp.getCleanupKind();
1465
1466 // Collect all exits from the body region.
1467 llvm::SmallVector<CleanupExit> exits;
1468 int nextId = 0;
1469 collectExits(cleanupOp.getBodyRegion(), exits, nextId);
1470
1471 assert(!exits.empty() && "cleanup scope body has no exit");
1472
1473 // Collect non-nothrow calls and throws that need to be converted to
1474 // try_call/try_throw. This is only needed for EH and All cleanup kinds,
1475 // but the vectors will simply be empty for Normal cleanup.
1476 llvm::SmallVector<cir::CallOp> callsToRewrite;
1477 llvm::SmallVector<cir::ThrowOp> throwsToRewrite;
1478 if (cleanupKind != cir::CleanupKind::Normal) {
1479 collectThrowingCalls(cleanupOp.getBodyRegion(), callsToRewrite);
1480 collectThrows(cleanupOp.getBodyRegion(), throwsToRewrite);
1481 }
1482
1483 // Collect resume ops from already-flattened inner cleanup scopes that
1484 // need to chain through this cleanup's EH handler.
1485 llvm::SmallVector<cir::ResumeOp> resumeOpsToChain;
1486 if (cleanupKind != cir::CleanupKind::Normal)
1487 collectResumeOps(cleanupOp.getBodyRegion(), resumeOpsToChain);
1488
1489 return flattenCleanup(cleanupOp, exits, callsToRewrite, throwsToRewrite,
1490 resumeOpsToChain, rewriter);
1491 }
1492};
1493
1494// Trace an !cir.eh_token value back through block arguments to find the
1495// cir.eh.initiate operation that defines it. Returns {} if the defining op
1496// cannot be found (e.g. multiple predecessors).
1497static cir::EhInitiateOp traceToEhInitiate(mlir::Value ehToken) {
1498 while (ehToken) {
1499 if (auto initiate = ehToken.getDefiningOp<cir::EhInitiateOp>())
1500 return initiate;
1501 auto blockArg = mlir::dyn_cast<mlir::BlockArgument>(ehToken);
1502 if (!blockArg)
1503 return {};
1504 mlir::Block *pred = blockArg.getOwner()->getSinglePredecessor();
1505 if (!pred)
1506 return {};
1507 auto brOp = mlir::dyn_cast<cir::BrOp>(pred->getTerminator());
1508 if (!brOp)
1509 return {};
1510 ehToken = brOp.getDestOperands()[blockArg.getArgNumber()];
1511 }
1512 return {};
1513}
1514
1515class CIRTryOpFlattening : public mlir::OpRewritePattern<cir::TryOp> {
1516public:
1517 using OpRewritePattern<cir::TryOp>::OpRewritePattern;
1518
1519 // Build the catch dispatch block with a cir.eh.dispatch operation.
1520 // The dispatch block receives an !cir.eh_token argument and dispatches
1521 // to the appropriate catch handler blocks based on exception types.
1522 mlir::Block *buildCatchDispatchBlock(
1523 cir::TryOp tryOp, mlir::ArrayAttr handlerTypes,
1524 llvm::SmallVectorImpl<mlir::Block *> &catchHandlerBlocks,
1525 mlir::Location loc, mlir::Block *insertBefore,
1526 mlir::PatternRewriter &rewriter) const {
1527 mlir::Block *dispatchBlock = rewriter.createBlock(insertBefore);
1528 auto ehTokenType = cir::EhTokenType::get(rewriter.getContext());
1529 mlir::Value ehToken = dispatchBlock->addArgument(ehTokenType, loc);
1530
1531 rewriter.setInsertionPointToEnd(dispatchBlock);
1532
1533 // Build the catch types and destinations for the dispatch.
1534 llvm::SmallVector<mlir::Attribute> catchTypeAttrs;
1535 llvm::SmallVector<mlir::Block *> catchDests;
1536 mlir::Block *defaultDest = nullptr;
1537 bool defaultIsCatchAll = false;
1538
1539 for (auto [typeAttr, handlerBlock] :
1540 llvm::zip(handlerTypes, catchHandlerBlocks)) {
1541 if (mlir::isa<cir::CatchAllAttr>(typeAttr)) {
1542 assert(!defaultDest && "multiple catch_all or unwind handlers");
1543 defaultDest = handlerBlock;
1544 defaultIsCatchAll = true;
1545 } else if (mlir::isa<cir::UnwindAttr>(typeAttr)) {
1546 assert(!defaultDest && "multiple catch_all or unwind handlers");
1547 defaultDest = handlerBlock;
1548 defaultIsCatchAll = false;
1549 } else {
1550 // This is a typed catch handler (GlobalViewAttr with type info).
1551 catchTypeAttrs.push_back(typeAttr);
1552 catchDests.push_back(handlerBlock);
1553 }
1554 }
1555
1556 assert(defaultDest && "dispatch must have a catch_all or unwind handler");
1557
1558 mlir::ArrayAttr catchTypesArrayAttr;
1559 if (!catchTypeAttrs.empty())
1560 catchTypesArrayAttr = rewriter.getArrayAttr(catchTypeAttrs);
1561
1562 cir::EhDispatchOp::create(rewriter, loc, ehToken, catchTypesArrayAttr,
1563 defaultIsCatchAll, defaultDest, catchDests);
1564
1565 return dispatchBlock;
1566 }
1567
1568 // Flatten a single catch handler region. Each handler region has an
1569 // !cir.eh_token argument and starts with cir.begin_catch, followed by
1570 // a cir.cleanup.scope containing the handler body (with cir.end_catch in
1571 // its cleanup region), and ending with cir.yield.
1572 //
1573 // After flattening, the handler region becomes a block that receives the
1574 // eh_token, calls begin_catch, runs the handler body inline, calls
1575 // end_catch, and branches to the continue block.
1576 //
1577 // The cleanup scope inside the catch handler is expected to have been
1578 // flattened before we get here, so what we see in the handler region is
1579 // already flat code with begin_catch at the top and end_catch in any place
1580 // that we would exit the catch handler. We just need to inline the region
1581 // and fix up terminators.
1582 mlir::Block *flattenCatchHandler(mlir::Region &handlerRegion,
1583 mlir::Block *continueBlock,
1584 mlir::Location loc,
1585 mlir::Block *insertBefore,
1586 mlir::PatternRewriter &rewriter) const {
1587 // The handler region entry block has the !cir.eh_token argument.
1588 mlir::Block *handlerEntry = &handlerRegion.front();
1589
1590 // Inline the handler region before insertBefore.
1591 rewriter.inlineRegionBefore(handlerRegion, insertBefore);
1592
1593 // Replace yield terminators in the handler with branches to continue.
1594 for (mlir::Block &block : llvm::make_range(handlerEntry->getIterator(),
1595 insertBefore->getIterator())) {
1596 if (auto yieldOp = dyn_cast<cir::YieldOp>(block.getTerminator())) {
1597 // Verify that end_catch is the last non-branch operation before
1598 // this yield. After cleanup scope flattening, end_catch may be
1599 // in a predecessor block rather than immediately before the yield.
1600 // Walk back through predecessors (including multi-predecessor
1601 // blocks), verifying that each intermediate block contains only a
1602 // branch terminator, until we find end_catch as the last
1603 // non-terminator in some block.
1604 // Verify that end_catch is reachable on some predecessor path
1605 // before this yield. After cleanup scope flattening, end_catch
1606 // may be separated from yield by conditional branches (e.g.,
1607 // from flattened cir.if inside the catch body).
1608 assert(([&]() {
1609 if (mlir::Operation *prev = yieldOp->getPrevNode())
1610 return isa<cir::EndCatchOp>(prev);
1611 llvm::SmallPtrSet<mlir::Block *, 8> visited;
1612 llvm::SmallVector<mlir::Block *, 4> worklist;
1613 for (mlir::Block *pred : block.getPredecessors())
1614 worklist.push_back(pred);
1615 while (!worklist.empty()) {
1616 mlir::Block *b = worklist.pop_back_val();
1617 if (!visited.insert(b).second)
1618 continue;
1619 mlir::Operation *term = b->getTerminator();
1620 if (mlir::Operation *prev = term->getPrevNode()) {
1621 if (isa<cir::EndCatchOp>(prev))
1622 return true;
1623 }
1624 for (mlir::Block *pred : b->getPredecessors())
1625 worklist.push_back(pred);
1626 }
1627 return false;
1628 }()) &&
1629 "expected end_catch reachable before yield "
1630 "in catch handler");
1631 rewriter.setInsertionPoint(yieldOp);
1632 rewriter.replaceOpWithNewOp<cir::BrOp>(yieldOp, continueBlock);
1633 }
1634 }
1635
1636 return handlerEntry;
1637 }
1638
1639 // Flatten an unwind handler region. The unwind region just contains a
1640 // cir.resume that continues unwinding. We inline it and leave the resume
1641 // in place. If this try op is nested inside an EH cleanup or another try op,
1642 // the enclosing op will rewrite the resume as a branch to its cleanup or
1643 // dispatch block when it is flattened. Otherwise, the resume will unwind to
1644 // the caller.
1645 mlir::Block *flattenUnwindHandler(mlir::Region &unwindRegion,
1646 mlir::Location loc,
1647 mlir::Block *insertBefore,
1648 mlir::PatternRewriter &rewriter) const {
1649 mlir::Block *unwindEntry = &unwindRegion.front();
1650 rewriter.inlineRegionBefore(unwindRegion, insertBefore);
1651 return unwindEntry;
1652 }
1653
1654 mlir::LogicalResult
1655 matchAndRewrite(cir::TryOp tryOp,
1656 mlir::PatternRewriter &rewriter) const override {
1657 // All nested structured CIR ops must be flattened before the try op.
1658 // Cleanup scopes and nested try ops need to be flat so EH cleanup is
1659 // properly handled. Other structured ops (scopes, ifs, loops, switches,
1660 // ternaries) must be flat because replaceCallWithTryCall creates try_call
1661 // ops whose unwind destination is outside the structured op's region,
1662 // which would be an invalid cross-region reference.
1663 for (mlir::Region &region : tryOp->getRegions())
1664 if (hasNestedOpsToFlatten(region))
1665 return mlir::failure();
1666
1667 mlir::OpBuilder::InsertionGuard guard(rewriter);
1668 mlir::Location loc = tryOp.getLoc();
1669
1670 mlir::ArrayAttr handlerTypes = tryOp.getHandlerTypesAttr();
1671 mlir::MutableArrayRef<mlir::Region> handlerRegions =
1672 tryOp.getHandlerRegions();
1673
1674 // Collect throwing calls and throws in the try body.
1675 llvm::SmallVector<cir::CallOp> callsToRewrite;
1676 collectThrowingCalls(tryOp.getTryRegion(), callsToRewrite);
1677 llvm::SmallVector<cir::ThrowOp> throwsToRewrite;
1678 collectThrows(tryOp.getTryRegion(), throwsToRewrite);
1679
1680 // Collect resume ops from already-flattened cleanup scopes in the try body.
1681 llvm::SmallVector<cir::ResumeOp> resumeOpsToChain;
1682 collectResumeOps(tryOp.getTryRegion(), resumeOpsToChain);
1683
1684 // Split the current block and inline the try body.
1685 mlir::Block *currentBlock = rewriter.getInsertionBlock();
1686 mlir::Block *continueBlock =
1687 rewriter.splitBlock(currentBlock, rewriter.getInsertionPoint());
1688
1689 // Get references to try body blocks before inlining.
1690 mlir::Block *bodyEntry = &tryOp.getTryRegion().front();
1691 mlir::Block *bodyExit = &tryOp.getTryRegion().back();
1692
1693 // Inline the try body region before the continue block.
1694 rewriter.inlineRegionBefore(tryOp.getTryRegion(), continueBlock);
1695
1696 // Branch from the current block to the body entry.
1697 rewriter.setInsertionPointToEnd(currentBlock);
1698 cir::BrOp::create(rewriter, loc, bodyEntry);
1699
1700 // Replace the try body's yield terminator with a branch to continue.
1701 if (auto bodyYield = dyn_cast<cir::YieldOp>(bodyExit->getTerminator())) {
1702 rewriter.setInsertionPoint(bodyYield);
1703 rewriter.replaceOpWithNewOp<cir::BrOp>(bodyYield, continueBlock);
1704 }
1705
1706 // If there are no handlers, we're done.
1707 if (!handlerTypes || handlerTypes.empty()) {
1708 rewriter.eraseOp(tryOp);
1709 return mlir::success();
1710 }
1711
1712 // If there are no throwing calls, no throws, and no resume ops from
1713 // inner cleanup scopes, exceptions cannot reach the catch handlers.
1714 // Drop all uses from the (unreachable) handler regions before erasing
1715 // the try op, since handler ops may reference values that were inlined
1716 // from the try body into the parent block.
1717 if (callsToRewrite.empty() && throwsToRewrite.empty() &&
1718 resumeOpsToChain.empty()) {
1719 for (mlir::Region &handlerRegion : handlerRegions)
1720 for (mlir::Block &block : handlerRegion)
1721 block.dropAllDefinedValueUses();
1722 rewriter.eraseOp(tryOp);
1723 return mlir::success();
1724 }
1725
1726 // Build the catch handler blocks.
1727
1728 // First, flatten all handler regions and collect the entry blocks.
1729 llvm::SmallVector<mlir::Block *> catchHandlerBlocks;
1730
1731 for (const auto &[idx, typeAttr] : llvm::enumerate(handlerTypes)) {
1732 mlir::Region &handlerRegion = handlerRegions[idx];
1733
1734 if (mlir::isa<cir::UnwindAttr>(typeAttr)) {
1735 mlir::Block *unwindEntry =
1736 flattenUnwindHandler(handlerRegion, loc, continueBlock, rewriter);
1737 catchHandlerBlocks.push_back(unwindEntry);
1738 } else {
1739 mlir::Block *handlerEntry = flattenCatchHandler(
1740 handlerRegion, continueBlock, loc, continueBlock, rewriter);
1741 catchHandlerBlocks.push_back(handlerEntry);
1742 }
1743 }
1744
1745 // Build the catch dispatch block.
1746 mlir::Block *dispatchBlock =
1747 buildCatchDispatchBlock(tryOp, handlerTypes, catchHandlerBlocks, loc,
1748 catchHandlerBlocks.front(), rewriter);
1749
1750 // Check whether the try has a catch-all handler. When catch-all is
1751 // present, the personality function will always stop unwinding at this
1752 // frame (because catch-all matches every exception type). The LLVM
1753 // landingpad therefore needs "catch ptr null" rather than "cleanup".
1754 // The downstream pipeline (EHABILowering + LowerToLLVM) emits
1755 // "catch ptr null" when the EhInitiateOp has neither cleanup nor typed
1756 // catch types, so we clear the cleanup flag on every EhInitiateOp that
1757 // feeds into a dispatch with a catch-all handler.
1758 bool hasCatchAll =
1759 handlerTypes && llvm::any_of(handlerTypes, [](mlir::Attribute attr) {
1760 return mlir::isa<cir::CatchAllAttr>(attr);
1761 });
1762
1763 // Build a block to be the unwind desination for throwing calls/throws
1764 // and replace the calls/throws with try_call/try_throw ops. Note that
1765 // the unwind block created here is something different than the unwind
1766 // handler that we may have created above. The unwind handler continues
1767 // unwinding after uncaught exceptions. This is the block that will
1768 // eventually become the landing pad for invoke instructions.
1769 bool isCleanupOnly = tryOp.getCleanup() && !hasCatchAll;
1770 if (!callsToRewrite.empty() || !throwsToRewrite.empty()) {
1771 // Create a shared unwind block for all throwing calls/throws.
1772 mlir::Block *unwindBlock = buildUnwindBlock(dispatchBlock, isCleanupOnly,
1773 loc, dispatchBlock, rewriter);
1774
1775 for (cir::CallOp callOp : callsToRewrite)
1776 replaceCallWithTryCall(callOp, unwindBlock, loc, rewriter);
1777 for (cir::ThrowOp throwOp : throwsToRewrite)
1778 replaceThrowWithTryThrow(throwOp, unwindBlock, loc, rewriter);
1779 }
1780
1781 // Chain resume ops from inner cleanup scopes.
1782 // Resume ops from already-flattened cleanup scopes within the try body
1783 // should branch to the catch dispatch block instead of unwinding directly.
1784 for (cir::ResumeOp resumeOp : resumeOpsToChain) {
1785 // When there is a catch-all handler, clear the cleanup flag on the
1786 // cir.eh.initiate that produced this token. With catch-all, the LLVM
1787 // landingpad needs "catch ptr null" instead of "cleanup".
1788 if (hasCatchAll) {
1789 if (auto ehInitiate = traceToEhInitiate(resumeOp.getEhToken())) {
1790 rewriter.modifyOpInPlace(ehInitiate,
1791 [&] { ehInitiate.removeCleanupAttr(); });
1792 }
1793 }
1794
1795 mlir::Value ehToken = resumeOp.getEhToken();
1796 rewriter.setInsertionPoint(resumeOp);
1797 rewriter.replaceOpWithNewOp<cir::BrOp>(
1798 resumeOp, mlir::ValueRange{ehToken}, dispatchBlock);
1799 }
1800
1801 // Finally, erase the original try op ----
1802 rewriter.eraseOp(tryOp);
1803
1804 return mlir::success();
1805 }
1806};
1807
1808void populateFlattenCFGPatterns(RewritePatternSet &patterns) {
1809 patterns
1810 .add<CIRIfFlattening, CIRLoopOpInterfaceFlattening, CIRScopeOpFlattening,
1811 CIRSwitchOpFlattening, CIRTernaryOpFlattening,
1812 CIRCleanupScopeOpFlattening, CIRTryOpFlattening>(
1813 patterns.getContext());
1814}
1815
1816void CIRFlattenCFGPass::runOnOperation() {
1817 RewritePatternSet patterns(&getContext());
1818 populateFlattenCFGPatterns(patterns);
1819
1820 // Collect operations to apply patterns.
1821 llvm::SmallVector<Operation *, 16> ops;
1822 getOperation()->walk<mlir::WalkOrder::PostOrder>([&](Operation *op) {
1823 if (isa<IfOp, ScopeOp, SwitchOp, LoopOpInterface, TernaryOp, CleanupScopeOp,
1824 TryOp>(op))
1825 ops.push_back(op);
1826 });
1827
1828 // Apply patterns.
1829 if (applyOpPatternsGreedily(ops, std::move(patterns)).failed())
1830 signalPassFailure();
1831}
1832
1833} // namespace
1834
1835namespace mlir {
1836
1837std::unique_ptr<Pass> createCIRFlattenCFGPass() {
1838 return std::make_unique<CIRFlattenCFGPass>();
1839}
1840
1841} // namespace mlir
__device__ __2f16 b
mlir::Block * replaceThrowWithTryThrow(cir::ThrowOp throwOp, mlir::Block *unwindDest, mlir::Location loc, mlir::RewriterBase &rewriter)
Replace a cir::ThrowOp with a cir::TryThrowOp whose unwind destination is unwindDest.
mlir::Block * replaceCallWithTryCall(cir::CallOp callOp, mlir::Block *unwindDest, mlir::Location loc, mlir::RewriterBase &rewriter)
Replace a cir::CallOp with a cir::TryCallOp whose unwind destination is unwindDest.
llvm::APInt APInt
Definition FixedPoint.h:19
ASTEdit insertBefore(RangeSelector S, TextGenerator Replacement)
Inserts Replacement before S, leaving the source selected by \S unchanged.
unsigned long uint64_t
std::unique_ptr< Pass > createCIRFlattenCFGPass()
int const char * function
Definition c++config.h:31
float __ovld __cnfn step(float, float)
Returns 0.0 if x < edge, otherwise it returns 1.0.
static bool stackSaveOp()