//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a pass that inlines CIR operation regions into the
// parent function region.
//
//===----------------------------------------------------------------------===//

#include "CIRTransformUtils.h"
#include "PassDetail.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/Block.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/TypeSwitch.h"

using namespace mlir;
using namespace cir;

namespace mlir {
#define GEN_PASS_DEF_CIRFLATTENCFG
#include "clang/CIR/Dialect/Passes.h.inc"
} // namespace mlir

namespace {

/// Lowers operations with the terminator trait that have a single successor.
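///
/// For illustration (a hand-written sketch, not pass output; block names
/// hypothetical), a `cir.break` terminating a loop body block becomes an
/// unconditional branch to the loop's exit block:
///
///   ^body:                      ^body:
///     ...                -->      ...
///     cir.break                   cir.br ^loop.exit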
void lowerTerminator(mlir::Operation *op, mlir::Block *dest,
                     mlir::PatternRewriter &rewriter) {
  assert(op->hasTrait<mlir::OpTrait::IsTerminator>() && "not a terminator");
  mlir::OpBuilder::InsertionGuard guard(rewriter);
  rewriter.setInsertionPoint(op);
  rewriter.replaceOpWithNewOp<cir::BrOp>(op, dest);
}

/// Walks a region while skipping operations of type `Ops`. This ensures the
/// callback is not applied to said operations or their children.
template <typename... Ops>
void walkRegionSkipping(
    mlir::Region &region,
    mlir::function_ref<mlir::WalkResult(mlir::Operation *)> callback) {
  region.walk<mlir::WalkOrder::PreOrder>([&](mlir::Operation *op) {
    if (isa<Ops...>(op))
      return mlir::WalkResult::skip();
    return callback(op);
  });
}

/// Check whether a region contains any nested op with regions (i.e. structured
/// CIR ops that must be flattened before their parent). The greedy pattern
/// rewriter doesn't guarantee inside-out processing order: when a pattern
/// fires and modifies IR, newly created ops go onto the worklist and can be
/// visited in any order. So each flattening pattern must explicitly defer
/// until its nested structured ops are flat.
///
/// CaseOps are excluded because they are structural children of SwitchOp and
/// are handled by the SwitchOp flattening pattern.
static bool hasNestedOpsToFlatten(mlir::Region &region) {
  return region
      .walk([](mlir::Operation *op) {
        if (op->getNumRegions() > 0 && !isa<cir::CaseOp>(op))
          return mlir::WalkResult::interrupt();
        return mlir::WalkResult::advance();
      })
      .wasInterrupted();
}

/// True if `op` is a non-returning terminator, currently `cir.unreachable`
/// or `cir.trap`. Such terminators don't fall through and don't yield a
/// value, so when flattening a region they can be left in place rather than
/// being replaced with a branch to the continuation block. Add new ops here
/// (e.g. a hypothetical `cir.abort`) so every flattening pattern picks them
/// up at once.
static bool isNonReturningTerminator(mlir::Operation *op) {
  return mlir::isa_and_nonnull<cir::UnreachableOp, cir::TrapOp>(op);
}

/// Rewrite the terminator of `region`'s exit block so that, after
/// flattening, control falls through to `continueBlock`. The exit
/// terminator is expected to be either:
///   - `cir.yield`: replaced with `cir.br` to `continueBlock` (yielded
///     args become the destination block's arguments).
///   - non-returning (`cir.unreachable`, `cir.trap`): left in place; no
///     branch is needed.
///
/// On success returns `success()`. If the terminator is anything else, an
/// error is emitted and `failure()` is returned. NOTE: callers in this
/// file have typically already mutated IR (splitBlock / createBlock) by
/// the time this is invoked, so the MLIR pattern rewriter contract
/// requires them to still return `success()` from the surrounding
/// pattern; the `failure()` here just signals "stop trying to wire up
/// this region".
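///
/// For example (sketch only, block names hypothetical), a region exiting
/// with `cir.yield %v : !s32i` is rewritten so that `%v` flows into the
/// continuation block as a block argument:
///
///   cir.yield %v : !s32i   -->   cir.br ^continue(%v : !s32i)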
static mlir::LogicalResult
rewriteRegionExitToContinue(mlir::PatternRewriter &rewriter,
                            mlir::Region &region, mlir::Block *continueBlock,
                            llvm::StringRef regionDescription) {
  mlir::Operation *terminator = region.back().getTerminator();
  rewriter.setInsertionPointToEnd(&region.back());
  if (auto yieldOp = mlir::dyn_cast<cir::YieldOp>(terminator)) {
    rewriter.replaceOpWithNewOp<cir::BrOp>(yieldOp, yieldOp.getArgs(),
                                           continueBlock);
    return mlir::success();
  }
  if (isNonReturningTerminator(terminator))
    return mlir::success();
  terminator->emitError("unexpected terminator in ")
      << regionDescription
      << " region, expected yield, unreachable, or trap, got: "
      << terminator->getName();
  return mlir::failure();
}

struct CIRFlattenCFGPass : public impl::CIRFlattenCFGBase<CIRFlattenCFGPass> {

  CIRFlattenCFGPass() = default;
  void runOnOperation() override;
};

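/// Flattens `cir.if` into explicit control flow. A rough sketch for an if
/// with no results (CIR syntax approximate, block names hypothetical):
///
///   cir.if %cond {            cir.brcond %cond ^then, ^else
///     <then ops>              ^then:
///     cir.yield          -->    <then ops>
///   } else {                    cir.br ^continue
///     <else ops>              ^else:
///     cir.yield                 <else ops>
///   }                           cir.br ^continue
///                             ^continue: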
struct CIRIfFlattening : public mlir::OpRewritePattern<cir::IfOp> {
  using OpRewritePattern<IfOp>::OpRewritePattern;

  mlir::LogicalResult
  matchAndRewrite(cir::IfOp ifOp,
                  mlir::PatternRewriter &rewriter) const override {
    mlir::OpBuilder::InsertionGuard guard(rewriter);
    mlir::Location loc = ifOp.getLoc();
    bool emptyElse = ifOp.getElseRegion().empty();
    mlir::Block *currentBlock = rewriter.getInsertionBlock();
    mlir::Block *remainingOpsBlock =
        rewriter.splitBlock(currentBlock, rewriter.getInsertionPoint());
    mlir::Block *continueBlock;
    if (ifOp->getResults().empty())
      continueBlock = remainingOpsBlock;
    else
      llvm_unreachable("NYI");

    // Inline the then region.
    mlir::Block *thenBeforeBody = &ifOp.getThenRegion().front();
    mlir::Block *thenAfterBody = &ifOp.getThenRegion().back();
    rewriter.inlineRegionBefore(ifOp.getThenRegion(), continueBlock);

    rewriter.setInsertionPointToEnd(thenAfterBody);
    if (auto thenYieldOp =
            dyn_cast<cir::YieldOp>(thenAfterBody->getTerminator())) {
      rewriter.replaceOpWithNewOp<cir::BrOp>(thenYieldOp, thenYieldOp.getArgs(),
                                             continueBlock);
    }

    rewriter.setInsertionPointToEnd(continueBlock);

    // Has else region: inline it.
    mlir::Block *elseBeforeBody = nullptr;
    mlir::Block *elseAfterBody = nullptr;
    if (!emptyElse) {
      elseBeforeBody = &ifOp.getElseRegion().front();
      elseAfterBody = &ifOp.getElseRegion().back();
      rewriter.inlineRegionBefore(ifOp.getElseRegion(), continueBlock);
    } else {
      elseBeforeBody = elseAfterBody = continueBlock;
    }

    rewriter.setInsertionPointToEnd(currentBlock);
    cir::BrCondOp::create(rewriter, loc, ifOp.getCondition(), thenBeforeBody,
                          elseBeforeBody);

    if (!emptyElse) {
      rewriter.setInsertionPointToEnd(elseAfterBody);
      if (auto elseYieldOp =
              dyn_cast<cir::YieldOp>(elseAfterBody->getTerminator())) {
        rewriter.replaceOpWithNewOp<cir::BrOp>(
            elseYieldOp, elseYieldOp.getArgs(), continueBlock);
      }
    }

    rewriter.replaceOp(ifOp, continueBlock->getArguments());
    return mlir::success();
  }
};

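/// Flattens `cir.scope` by inlining its body between the current block and a
/// continuation block. A rough sketch (CIR syntax approximate, block names
/// hypothetical):
///
///   cir.scope {               cir.br ^scope.body
///     <body ops>              ^scope.body:
///     cir.yield          -->    <body ops>
///   }                           cir.br ^continue
///                             ^continue: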
class CIRScopeOpFlattening : public mlir::OpRewritePattern<cir::ScopeOp> {
public:
  using OpRewritePattern<cir::ScopeOp>::OpRewritePattern;

  mlir::LogicalResult
  matchAndRewrite(cir::ScopeOp scopeOp,
                  mlir::PatternRewriter &rewriter) const override {
    mlir::OpBuilder::InsertionGuard guard(rewriter);
    mlir::Location loc = scopeOp.getLoc();

    // Empty scope: just remove it.
    // TODO: Remove this logic once CIR uses MLIR infrastructure to remove
    // trivially dead operations. The MLIR canonicalizer is too aggressive, so
    // we need to either (a) make sure all our ops model all side effects
    // and/or (b) have more options in the MLIR canonicalizer to temper its
    // aggressiveness.
    if (scopeOp.isEmpty()) {
      rewriter.eraseOp(scopeOp);
      return mlir::success();
    }

    // Split the current block before the ScopeOp to create the inlining
    // point.
    mlir::Block *currentBlock = rewriter.getInsertionBlock();
    mlir::Block *continueBlock =
        rewriter.splitBlock(currentBlock, rewriter.getInsertionPoint());
    if (scopeOp.getNumResults() > 0)
      continueBlock->addArguments(scopeOp.getResultTypes(), loc);

    // Inline the body region.
    mlir::Block *beforeBody = &scopeOp.getScopeRegion().front();
    mlir::Block *afterBody = &scopeOp.getScopeRegion().back();
    rewriter.inlineRegionBefore(scopeOp.getScopeRegion(), continueBlock);

    // Save the stack and then branch into the body of the region.
    rewriter.setInsertionPointToEnd(currentBlock);
    cir::BrOp::create(rewriter, loc, mlir::ValueRange(), beforeBody);

    // Replace the scope op's yield with a branch that jumps out of the body.
    // The stack is restored before leaving the body region.
    rewriter.setInsertionPointToEnd(afterBody);
    if (auto yieldOp = dyn_cast<cir::YieldOp>(afterBody->getTerminator())) {
      rewriter.replaceOpWithNewOp<cir::BrOp>(yieldOp, yieldOp.getArgs(),
                                             continueBlock);
    }

    // Replace the op with the values returned from the body region.
    rewriter.replaceOp(scopeOp, continueBlock->getArguments());

    return mlir::success();
  }
};

class CIRSwitchOpFlattening : public mlir::OpRewritePattern<cir::SwitchOp> {
public:
  using OpRewritePattern<cir::SwitchOp>::OpRewritePattern;

  inline void rewriteYieldOp(mlir::PatternRewriter &rewriter,
                             cir::YieldOp yieldOp,
                             mlir::Block *destination) const {
    rewriter.setInsertionPoint(yieldOp);
    rewriter.replaceOpWithNewOp<cir::BrOp>(yieldOp, yieldOp.getOperands(),
                                           destination);
  }

  // Return the new defaultDestination block.
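  //
  // The range check uses the classic unsigned range-compare trick: for a
  // case range [lo, hi], the condition is in range iff
  // (cond - lo) <=u (hi - lo). For example (hypothetical values), the range
  // [3, 10] becomes the single unsigned comparison (cond - 3) <= 7 rather
  // than a pair of signed comparisons.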
  Block *condBrToRangeDestination(cir::SwitchOp op,
                                  mlir::PatternRewriter &rewriter,
                                  mlir::Block *rangeDestination,
                                  mlir::Block *defaultDestination,
                                  const APInt &lowerBound,
                                  const APInt &upperBound) const {
    assert(lowerBound.sle(upperBound) && "Invalid range");
    mlir::Block *resBlock = rewriter.createBlock(defaultDestination);
    cir::IntType sIntType = cir::IntType::get(op.getContext(), 32, true);
    cir::IntType uIntType = cir::IntType::get(op.getContext(), 32, false);

    cir::ConstantOp rangeLength = cir::ConstantOp::create(
        rewriter, op.getLoc(),
        cir::IntAttr::get(sIntType, upperBound - lowerBound));

    cir::ConstantOp lowerBoundValue = cir::ConstantOp::create(
        rewriter, op.getLoc(), cir::IntAttr::get(sIntType, lowerBound));
    mlir::Value diffValue = cir::SubOp::create(
        rewriter, op.getLoc(), op.getCondition(), lowerBoundValue);

    // Use unsigned comparison to check if the condition is in the range.
    cir::CastOp uDiffValue = cir::CastOp::create(
        rewriter, op.getLoc(), uIntType, CastKind::integral, diffValue);
    cir::CastOp uRangeLength = cir::CastOp::create(
        rewriter, op.getLoc(), uIntType, CastKind::integral, rangeLength);

    cir::CmpOp cmpResult = cir::CmpOp::create(
        rewriter, op.getLoc(), cir::CmpOpKind::le, uDiffValue, uRangeLength);
    cir::BrCondOp::create(rewriter, op.getLoc(), cmpResult, rangeDestination,
                          defaultDestination);
    return resBlock;
  }

  mlir::LogicalResult
  matchAndRewrite(cir::SwitchOp op,
                  mlir::PatternRewriter &rewriter) const override {
    // All nested structured CIR ops must be flattened before the switch.
    // Break statements inside nested structured ops would create branches to
    // blocks outside those ops' regions, which is invalid. Fail the match so
    // the pattern rewriter will process them first.
    for (mlir::Region &region : op->getRegions())
      if (hasNestedOpsToFlatten(region))
        return mlir::failure();

    llvm::SmallVector<CaseOp> cases;
    op.collectCases(cases);

    // Empty switch statement: just erase it.
    if (cases.empty()) {
      rewriter.eraseOp(op);
      return mlir::success();
    }

    // Create the exit block from the node following the cir.switch op.
    mlir::Block *exitBlock = rewriter.splitBlock(
        rewriter.getBlock(), op->getNextNode()->getIterator());

    // We lower the cir.switch op in the following steps:
    //   1. Inline the region from the switch op after the switch op.
    //   2. Traverse each cir.case op:
    //      a. Record the entry block, block arguments, and condition for
    //         every case.
    //      b. Inline the case region after the case op.
    //   3. Replace the empty cir.switch op with the new cir.switch.flat op
    //      using the recorded blocks and conditions.

    // Inline everything from the switch body between the switch op and the
    // exit block.
    {
      cir::YieldOp switchYield = nullptr;
      // Clear switch operation.
      for (mlir::Block &block :
           llvm::make_early_inc_range(op.getBody().getBlocks()))
        if (auto yieldOp = dyn_cast<cir::YieldOp>(block.getTerminator()))
          switchYield = yieldOp;

      assert(!op.getBody().empty());
      mlir::Block *originalBlock = op->getBlock();
      mlir::Block *swopBlock =
          rewriter.splitBlock(originalBlock, op->getIterator());
      rewriter.inlineRegionBefore(op.getBody(), exitBlock);

      if (switchYield)
        rewriteYieldOp(rewriter, switchYield, exitBlock);

      rewriter.setInsertionPointToEnd(originalBlock);
      cir::BrOp::create(rewriter, op.getLoc(), swopBlock);
    }

    // Allocate the required data structures (the default case is not
    // included in these vectors).
    llvm::SmallVector<mlir::APInt, 8> caseValues;
    llvm::SmallVector<mlir::Block *, 8> caseDestinations;
    llvm::SmallVector<mlir::ValueRange, 8> caseOperands;

    llvm::SmallVector<std::pair<APInt, APInt>> rangeValues;
    llvm::SmallVector<mlir::Block *> rangeDestinations;
    llvm::SmallVector<mlir::ValueRange> rangeOperands;

    // Initialize the default case as optional.
    mlir::Block *defaultDestination = exitBlock;
    mlir::ValueRange defaultOperands = exitBlock->getArguments();

    // Digest the case statement values and bodies.
    for (cir::CaseOp caseOp : cases) {
      mlir::Region &region = caseOp.getCaseRegion();

      // Found default case: save destination and operands.
      switch (caseOp.getKind()) {
      case cir::CaseOpKind::Default:
        defaultDestination = &region.front();
        defaultOperands = defaultDestination->getArguments();
        break;
      case cir::CaseOpKind::Range:
        assert(caseOp.getValue().size() == 2 &&
               "Case range should have 2 case values");
        rangeValues.push_back(
            {cast<cir::IntAttr>(caseOp.getValue()[0]).getValue(),
             cast<cir::IntAttr>(caseOp.getValue()[1]).getValue()});
        rangeDestinations.push_back(&region.front());
        rangeOperands.push_back(rangeDestinations.back()->getArguments());
        break;
      case cir::CaseOpKind::Anyof:
      case cir::CaseOpKind::Equal:
        // The AnyOf case kind can have multiple values, hence the loop below.
        for (const mlir::Attribute &value : caseOp.getValue()) {
          caseValues.push_back(cast<cir::IntAttr>(value).getValue());
          caseDestinations.push_back(&region.front());
          caseOperands.push_back(caseDestinations.back()->getArguments());
        }
        break;
      }

      // Handle break statements.
      walkRegionSkipping<cir::LoopOpInterface, cir::SwitchOp>(
          region, [&](mlir::Operation *op) {
            if (!isa<cir::BreakOp>(op))
              return mlir::WalkResult::advance();

            lowerTerminator(op, exitBlock, rewriter);
            return mlir::WalkResult::skip();
          });

      // Track fallthrough between cases.
      for (mlir::Block &blk : region.getBlocks()) {
        if (blk.getNumSuccessors())
          continue;

        if (auto yieldOp = dyn_cast<cir::YieldOp>(blk.getTerminator())) {
          mlir::Operation *nextOp = caseOp->getNextNode();
          assert(nextOp && "caseOp is not expected to be the last op");
          mlir::Block *oldBlock = nextOp->getBlock();
          mlir::Block *newBlock =
              rewriter.splitBlock(oldBlock, nextOp->getIterator());
          rewriter.setInsertionPointToEnd(oldBlock);
          cir::BrOp::create(rewriter, nextOp->getLoc(), mlir::ValueRange(),
                            newBlock);
          rewriteYieldOp(rewriter, yieldOp, newBlock);
        }
      }

      mlir::Block *oldBlock = caseOp->getBlock();
      mlir::Block *newBlock =
          rewriter.splitBlock(oldBlock, caseOp->getIterator());

      mlir::Block &entryBlock = caseOp.getCaseRegion().front();
      rewriter.inlineRegionBefore(caseOp.getCaseRegion(), newBlock);

      // Create a branch to the entry of the inlined region.
      rewriter.setInsertionPointToEnd(oldBlock);
      cir::BrOp::create(rewriter, caseOp.getLoc(), &entryBlock);
    }

    // Remove all cases since we've inlined the regions.
    for (cir::CaseOp caseOp : cases) {
      mlir::Block *caseBlock = caseOp->getBlock();
      // Erase blocks with no predecessors here to make the generated code a
      // little simpler.
      if (caseBlock->hasNoPredecessors())
        rewriter.eraseBlock(caseBlock);
      else
        rewriter.eraseOp(caseOp);
    }

    for (auto [rangeVal, operand, destination] :
         llvm::zip(rangeValues, rangeOperands, rangeDestinations)) {
      APInt lowerBound = rangeVal.first;
      APInt upperBound = rangeVal.second;

      // The case range is unreachable, skip it.
      if (lowerBound.sgt(upperBound))
        continue;

      // If the range is small, add multiple switch instruction cases.
      // This magic number comes from the original CGStmt code.
      constexpr int kSmallRangeThreshold = 64;
      if ((upperBound - lowerBound)
              .ult(llvm::APInt(32, kSmallRangeThreshold))) {
        for (APInt iValue = lowerBound; iValue.sle(upperBound); ++iValue) {
          caseValues.push_back(iValue);
          caseOperands.push_back(operand);
          caseDestinations.push_back(destination);
        }
        continue;
      }

      defaultDestination =
          condBrToRangeDestination(op, rewriter, destination,
                                   defaultDestination, lowerBound, upperBound);
      defaultOperands = operand;
    }

    // Set the switch op to branch to the newly created blocks.
    rewriter.setInsertionPoint(op);
    rewriter.replaceOpWithNewOp<cir::SwitchFlatOp>(
        op, op.getCondition(), defaultDestination, defaultOperands, caseValues,
        caseDestinations, caseOperands);

    return mlir::success();
  }
};

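/// Flattens ops implementing `cir::LoopOpInterface` (for, while, do-while).
/// A rough sketch of the resulting CFG for a while loop with no step region
/// (CIR syntax approximate, block names hypothetical):
///
///   ^entry:
///     cir.br ^cond
///   ^cond:
///     ...
///     cir.brcond %c ^body, ^exit    // lowered from cir.condition
///   ^body:
///     ...
///     cir.br ^cond                  // lowered from cir.yield
///   ^exit:
///     ...                           // code following the loop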
class CIRLoopOpInterfaceFlattening
    : public mlir::OpInterfaceRewritePattern<cir::LoopOpInterface> {
public:
  using mlir::OpInterfaceRewritePattern<
      cir::LoopOpInterface>::OpInterfaceRewritePattern;

  inline void lowerConditionOp(cir::ConditionOp op, mlir::Block *body,
                               mlir::Block *exit,
                               mlir::PatternRewriter &rewriter) const {
    mlir::OpBuilder::InsertionGuard guard(rewriter);
    rewriter.setInsertionPoint(op);
    rewriter.replaceOpWithNewOp<cir::BrCondOp>(op, op.getCondition(), body,
                                               exit);
  }

  mlir::LogicalResult
  matchAndRewrite(cir::LoopOpInterface op,
                  mlir::PatternRewriter &rewriter) const final {
    // All nested structured CIR ops must be flattened before the loop.
    // Break/continue statements inside nested structured ops would create
    // branches to blocks outside those ops' regions, which is invalid. Fail
    // the match so the pattern rewriter will process them first.
    for (mlir::Region &region : op->getRegions())
      if (hasNestedOpsToFlatten(region))
        return mlir::failure();

    // Set up CFG blocks.
    mlir::Block *entry = rewriter.getInsertionBlock();
    mlir::Block *exit =
        rewriter.splitBlock(entry, rewriter.getInsertionPoint());
    mlir::Block *cond = &op.getCond().front();
    mlir::Block *body = &op.getBody().front();
    mlir::Block *step =
        (op.maybeGetStep() ? &op.maybeGetStep()->front() : nullptr);

    // Set up the loop entry branch.
    rewriter.setInsertionPointToEnd(entry);
    cir::BrOp::create(rewriter, op.getLoc(), &op.getEntry().front());

    // Branch from the condition region to the body or exit. The ConditionOp
    // may not be in the first block of the condition region if a cleanup
    // scope was already flattened within it, introducing multiple blocks.
    // The ConditionOp is always the terminator of the last block.
    auto conditionOp =
        cast<cir::ConditionOp>(op.getCond().back().getTerminator());
    lowerConditionOp(conditionOp, body, exit, rewriter);

    // TODO(cir): Remove the walks below. They visit operations unnecessarily.
    // However, to solve this we would likely need a custom DialectConversion
    // driver to customize the order in which operations are visited.

    // Lower continue statements.
    mlir::Block *dest = (step ? step : cond);
    op.walkBodySkippingNestedLoops([&](mlir::Operation *op) {
      if (!isa<cir::ContinueOp>(op))
        return mlir::WalkResult::advance();

      lowerTerminator(op, dest, rewriter);
      return mlir::WalkResult::skip();
    });

    // Lower break statements.
    walkRegionSkipping<cir::LoopOpInterface, cir::SwitchOp>(
        op.getBody(), [&](mlir::Operation *op) {
          if (!isa<cir::BreakOp>(op))
            return mlir::WalkResult::advance();

          lowerTerminator(op, exit, rewriter);
          return mlir::WalkResult::skip();
        });

    // Lower the optional body region yield.
    for (mlir::Block &blk : op.getBody().getBlocks()) {
      auto bodyYield = dyn_cast<cir::YieldOp>(blk.getTerminator());
      if (bodyYield)
        lowerTerminator(bodyYield, (step ? step : cond), rewriter);
    }

    // Lower the mandatory step region yield. Like the condition region, the
    // YieldOp may be in the last block rather than the first if a cleanup
    // scope was already flattened within the step region.
    if (step)
      lowerTerminator(
          cast<cir::YieldOp>(op.maybeGetStep()->back().getTerminator()), cond,
          rewriter);

    // Move the region contents out of the loop op.
    rewriter.inlineRegionBefore(op.getCond(), exit);
    rewriter.inlineRegionBefore(op.getBody(), exit);
    if (step)
      rewriter.inlineRegionBefore(*op.maybeGetStep(), exit);

    rewriter.eraseOp(op);
    return mlir::success();
  }
};

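/// Flattens `cir.ternary` into a conditional branch whose arms forward their
/// yielded value to a continuation block argument. A rough sketch (CIR
/// syntax approximate, block names hypothetical):
///
///   %r = cir.ternary(%c, true {        cir.brcond %c ^true, ^false
///     cir.yield %a : !s32i             ^true:
///   }, false {                  -->      cir.br ^continue(%a : !s32i)
///     cir.yield %b : !s32i             ^false:
///   }) : (!cir.bool) -> !s32i            cir.br ^continue(%b : !s32i)
///                                      ^continue(%r : !s32i):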
class CIRTernaryOpFlattening : public mlir::OpRewritePattern<cir::TernaryOp> {
public:
  using OpRewritePattern<cir::TernaryOp>::OpRewritePattern;

  mlir::LogicalResult
  matchAndRewrite(cir::TernaryOp op,
                  mlir::PatternRewriter &rewriter) const override {
    Location loc = op->getLoc();
    Block *condBlock = rewriter.getInsertionBlock();
    Block::iterator opPosition = rewriter.getInsertionPoint();
    Block *remainingOpsBlock = rewriter.splitBlock(condBlock, opPosition);
    llvm::SmallVector<mlir::Location, 2> locs;
    // The ternary result is optional; make sure to populate the location
    // only when relevant.
    if (op->getResultTypes().size())
      locs.push_back(loc);
    Block *continueBlock =
        rewriter.createBlock(remainingOpsBlock, op->getResultTypes(), locs);
    cir::BrOp::create(rewriter, loc, remainingOpsBlock);

    Region &trueRegion = op.getTrueRegion();
    Block *trueBlock = &trueRegion.front();
    // Wire up the true region's exit (cir.yield -> br, cir.unreachable /
    // cir.trap kept as-is). IR has already been modified by splitBlock /
    // createBlock above, so per the MLIR pattern rewriter contract we must
    // still return success() if the terminator turns out to be unexpected.
    if (failed(rewriteRegionExitToContinue(rewriter, trueRegion, continueBlock,
                                           "ternary true")))
      return mlir::success();
    rewriter.inlineRegionBefore(trueRegion, continueBlock);

    Block *falseBlock = continueBlock;
    Region &falseRegion = op.getFalseRegion();

    falseBlock = &falseRegion.front();
    if (failed(rewriteRegionExitToContinue(rewriter, falseRegion,
                                           continueBlock, "ternary false")))
      return mlir::success();
    rewriter.inlineRegionBefore(falseRegion, continueBlock);

    rewriter.setInsertionPointToEnd(condBlock);
    cir::BrCondOp::create(rewriter, loc, op.getCond(), trueBlock, falseBlock);

    rewriter.replaceOp(op, continueBlock->getArguments());

    // Ok, we're done!
    return mlir::success();
  }
};

// Get or create the cleanup destination slot for a function. This slot is
// shared across all cleanup scopes in the function to track which exit path
// to take after running cleanup code when there are multiple exits.
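//
// As an illustrative sketch (CIR syntax approximate, names hypothetical), a
// multi-exit cleanup stores a distinct ID into this slot at each exit and
// dispatches on it after the cleanup code has run:
//
//   %slot = cir.alloca !s32i, !cir.ptr<!s32i>, ["__cleanup_dest_slot"]
//   ...
//   ^exit.a:                      // e.g. a yield out of the body
//     cir.store %id0, %slot
//     cir.br ^cleanup
//   ^exit.b:                      // e.g. a return out of the body
//     cir.store %id1, %slot
//     cir.br ^cleanup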
static cir::AllocaOp getOrCreateCleanupDestSlot(cir::FuncOp funcOp,
                                                mlir::PatternRewriter &rewriter,
                                                mlir::Location loc) {
  mlir::Block &entryBlock = funcOp.getBody().front();

  // Look for an existing cleanup dest slot in the entry block.
  auto it = llvm::find_if(entryBlock, [](auto &op) {
    return mlir::isa<AllocaOp>(&op) &&
           mlir::cast<AllocaOp>(&op).getCleanupDestSlot();
  });
  if (it != entryBlock.end())
    return mlir::cast<cir::AllocaOp>(*it);

  // Create a new cleanup dest slot at the start of the entry block.
  mlir::OpBuilder::InsertionGuard guard(rewriter);
  rewriter.setInsertionPointToStart(&entryBlock);
  cir::IntType s32Type =
      cir::IntType::get(rewriter.getContext(), 32, /*isSigned=*/true);
  cir::PointerType ptrToS32Type = cir::PointerType::get(s32Type);
  cir::CIRDataLayout dataLayout(funcOp->getParentOfType<mlir::ModuleOp>());
  uint64_t alignment = dataLayout.getAlignment(s32Type, true).value();
  auto allocaOp = cir::AllocaOp::create(
      rewriter, loc, ptrToS32Type, s32Type, "__cleanup_dest_slot",
      /*alignment=*/rewriter.getI64IntegerAttr(alignment));
  allocaOp.setCleanupDestSlot(true);
  return allocaOp;
}

/// Shared EH flattening utilities used by both CIRCleanupScopeOpFlattening
/// and CIRTryOpFlattening.

// Collect all function calls in a region that may throw exceptions and need
// to be replaced with try_call operations. Skips calls marked nothrow.
// Nested cleanup scopes and try ops are always flattened before their
// enclosing parents, so there are no nested regions to skip here.
static void
collectThrowingCalls(mlir::Region &region,
                     llvm::SmallVectorImpl<cir::CallOp> &callsToRewrite) {
  region.walk([&](cir::CallOp callOp) {
    if (!callOp.getNothrow())
      callsToRewrite.push_back(callOp);
  });
}

// Collect all cir.resume operations in a region that come from
// already-flattened try or cleanup scope operations. These resume ops need
// to be chained through this scope's EH handler instead of unwinding
// directly to the caller. Nested cleanup scopes and try ops are always
// flattened before their enclosing parents, so there are no nested regions
// to skip here.
static void collectResumeOps(mlir::Region &region,
                             llvm::SmallVectorImpl<cir::ResumeOp> &resumeOps) {
  region.walk([&](cir::ResumeOp resumeOp) { resumeOps.push_back(resumeOp); });
}

// Create a shared unwind destination block. The block contains a
// cir.eh.initiate operation (optionally with the cleanup attribute) and a
// branch to the given destination block, passing the eh_token.
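//
// A sketch of the produced block (CIR syntax approximate, block names
// hypothetical):
//
//   ^unwind:
//     %tok = cir.eh.initiate       // carries the cleanup attribute when
//                                  // isCleanupOnly is set
//     cir.br ^dest(%tok : !cir.eh_token)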
static mlir::Block *buildUnwindBlock(mlir::Block *dest, bool isCleanupOnly,
                                     mlir::Location loc,
                                     mlir::Block *insertBefore,
                                     mlir::PatternRewriter &rewriter) {
  mlir::Block *unwindBlock = rewriter.createBlock(insertBefore);
  rewriter.setInsertionPointToEnd(unwindBlock);
  auto ehInitiate =
      cir::EhInitiateOp::create(rewriter, loc, /*cleanup=*/isCleanupOnly);
  cir::BrOp::create(rewriter, loc, mlir::ValueRange{ehInitiate.getEhToken()},
                    dest);
  return unwindBlock;
}

// Create a shared terminate unwind block for throwing calls in EH cleanup
// regions. When an exception is thrown during cleanup (unwinding), the C++
// standard requires that std::terminate() be called.
static mlir::Block *buildTerminateUnwindBlock(mlir::Location loc,
                                              mlir::Block *insertBefore,
                                              mlir::PatternRewriter &rewriter) {
  mlir::Block *terminateBlock = rewriter.createBlock(insertBefore);
  rewriter.setInsertionPointToEnd(terminateBlock);
  auto ehInitiate = cir::EhInitiateOp::create(rewriter, loc, /*cleanup=*/false);
  cir::EhTerminateOp::create(rewriter, loc, ehInitiate.getEhToken());
  return terminateBlock;
}

class CIRCleanupScopeOpFlattening
    : public mlir::OpRewritePattern<cir::CleanupScopeOp> {
public:
  using OpRewritePattern<cir::CleanupScopeOp>::OpRewritePattern;

  struct CleanupExit {
    // An operation that exits the cleanup scope (yield, break, continue,
    // return, etc.)
    mlir::Operation *exitOp;

    // A unique identifier for this exit's destination (used for switch
    // dispatch when there are multiple exits).
    int destinationId;

    CleanupExit(mlir::Operation *op, int id) : exitOp(op), destinationId(id) {}
  };

  // Collect all operations that exit a cleanup scope body. Return, goto,
  // break, and continue can all require branches through the cleanup region.
  // When a loop is encountered, only return and goto are collected because
  // break and continue are handled by the loop and stay within the cleanup
  // scope. When a switch is encountered, return, goto, and continue are
  // collected because they may all branch through the cleanup, but break is
  // local to the switch. When a nested cleanup scope is encountered, we
  // recursively collect exits since any return, goto, break, or continue
  // from the nested cleanup will also branch through the outer cleanup.
  //
  // Note that goto statements may not necessarily exit the cleanup scope,
  // but for now we conservatively assume that they do. We'll need more
  // nuanced handling of that when multi-exit flattening is implemented.
  //
  // This function assigns unique destination IDs to each exit, which are
  // used when multi-exit cleanup scopes are flattened.
  void collectExits(mlir::Region &cleanupBodyRegion,
                    llvm::SmallVectorImpl<CleanupExit> &exits,
                    int &nextId) const {
    // Collect yield terminators from the body region. We do this separately
    // because yields in nested operations, including those in nested cleanup
    // scopes, won't branch through the outer cleanup region.
    for (mlir::Block &block : cleanupBodyRegion) {
      auto *terminator = block.getTerminator();
      if (isa<cir::YieldOp>(terminator))
        exits.emplace_back(terminator, nextId++);
    }

    // Lambda to walk a loop and collect only returns and gotos.
    // Break and continue inside loops are handled by the loop itself.
    // Loops don't require special handling for nested switch or cleanup
    // scopes because break and continue never branch out of the loop.
    auto collectExitsInLoop = [&](mlir::Operation *loopOp) {
      loopOp->walk<mlir::WalkOrder::PreOrder>([&](mlir::Operation *nestedOp) {
        if (isa<cir::ReturnOp, cir::GotoOp>(nestedOp))
          exits.emplace_back(nestedOp, nextId++);
        return mlir::WalkResult::advance();
      });
    };

    // Forward declarations for mutual recursion.
    std::function<void(mlir::Region &, bool)> collectExitsInCleanup;
    std::function<void(mlir::Operation *)> collectExitsInSwitch;

    // Lambda to collect exits from a switch. Collects return/goto/continue
    // but not break (handled by the switch). For nested loops/cleanups,
    // recurses.
    collectExitsInSwitch = [&](mlir::Operation *switchOp) {
      switchOp->walk<mlir::WalkOrder::PreOrder>([&](mlir::Operation *nestedOp) {
        if (isa<cir::CleanupScopeOp>(nestedOp)) {
          // Walk the nested cleanup, but ignore break statements because they
          // will be handled by the switch we are currently walking.
          collectExitsInCleanup(
              cast<cir::CleanupScopeOp>(nestedOp).getBodyRegion(),
              /*ignoreBreak=*/true);
          return mlir::WalkResult::skip();
        } else if (isa<cir::LoopOpInterface>(nestedOp)) {
          collectExitsInLoop(nestedOp);
          return mlir::WalkResult::skip();
        } else if (isa<cir::ReturnOp, cir::GotoOp, cir::ContinueOp>(nestedOp)) {
          exits.emplace_back(nestedOp, nextId++);
        }
        return mlir::WalkResult::advance();
      });
    };

    // Lambda to collect exits from a cleanup scope body region. This collects
    // break (optionally), continue, return, and goto, handling nested loops,
    // switches, and cleanups appropriately.
    collectExitsInCleanup = [&](mlir::Region &region, bool ignoreBreak) {
      region.walk<mlir::WalkOrder::PreOrder>([&](mlir::Operation *op) {
        // We need special handling for break statements because if this
        // cleanup scope was nested within a switch op, break will be handled
        // by the switch operation and therefore won't exit the cleanup scope
        // enclosing the switch. We're only collecting exits from the cleanup
        // that started this walk. Exits from nested cleanups will be handled
        // when we flatten the nested cleanup.
        if (!ignoreBreak && isa<cir::BreakOp>(op)) {
          exits.emplace_back(op, nextId++);
        } else if (isa<cir::ContinueOp, cir::ReturnOp, cir::GotoOp>(op)) {
          exits.emplace_back(op, nextId++);
        } else if (isa<cir::CleanupScopeOp>(op)) {
          // Recurse into the nested cleanup's body region.
          collectExitsInCleanup(cast<cir::CleanupScopeOp>(op).getBodyRegion(),
                                /*ignoreBreak=*/ignoreBreak);
          return mlir::WalkResult::skip();
        } else if (isa<cir::LoopOpInterface>(op)) {
          // This kicks off a separate walk rather than continuing to dig
          // deeper in the current walk because we need to handle break and
          // continue differently inside loops.
          collectExitsInLoop(op);
          return mlir::WalkResult::skip();
        } else if (isa<cir::SwitchOp>(op)) {
          // This kicks off a separate walk rather than continuing to dig
          // deeper in the current walk because we need to handle break
          // differently inside switches.
          collectExitsInSwitch(op);
          return mlir::WalkResult::skip();
        }
        return mlir::WalkResult::advance();
      });
    };

    // Collect exits from the body region.
    collectExitsInCleanup(cleanupBodyRegion, /*ignoreBreak=*/false);
  }

  // Check if an operand's defining op should be moved to the destination
  // block. We only sink constants and simple loads. Anything else should be
  // saved to a temporary alloca and reloaded at the destination block.
  static bool shouldSinkReturnOperand(mlir::Value operand,
                                      cir::ReturnOp returnOp) {
    // Block arguments can't be moved.
    mlir::Operation *defOp = operand.getDefiningOp();
    if (!defOp)
      return false;

    // Only move constants and loads to the dispatch block. For anything else,
    // we'll store to a temporary and reload in the dispatch block.
    if (!mlir::isa<cir::ConstantOp, cir::LoadOp>(defOp))
      return false;

    // Check if the return is the only user.
    if (!operand.hasOneUse())
      return false;

    // Only move ops that are in the same block as the return.
    if (defOp->getBlock() != returnOp->getBlock())
      return false;

    if (auto loadOp = mlir::dyn_cast<cir::LoadOp>(defOp)) {
      // Only attempt to move loads of allocas in the entry block.
      mlir::Value ptr = loadOp.getAddr();
      auto funcOp = returnOp->getParentOfType<cir::FuncOp>();
      assert(funcOp && "Return op has no function parent?");
      mlir::Block &funcEntryBlock = funcOp.getBody().front();

      // Check if it's an alloca in the function entry block.
      if (auto allocaOp =
              mlir::dyn_cast_if_present<cir::AllocaOp>(ptr.getDefiningOp()))
        return allocaOp->getBlock() == &funcEntryBlock;

      return false;
    }

    // Make sure we only fall through to here with constants.
    assert(mlir::isa<cir::ConstantOp>(defOp) && "Expected constant op");
    return true;
  }

  // For returns with operands in cleanup dispatch blocks, the operands may
  // not dominate the dispatch block. This function handles that by either
  // sinking the operand's defining op to the dispatch block (for constants
  // and simple loads) or by storing to a temporary alloca and reloading it.
  void
  getReturnOpOperands(cir::ReturnOp returnOp, mlir::Operation *exitOp,
                      mlir::Location loc, mlir::PatternRewriter &rewriter,
                      llvm::SmallVectorImpl<mlir::Value> &returnValues) const {
    mlir::Block *destBlock = rewriter.getInsertionBlock();
    auto funcOp = exitOp->getParentOfType<cir::FuncOp>();
    assert(funcOp && "Return op has no function parent?");
    mlir::Block &funcEntryBlock = funcOp.getBody().front();

    for (mlir::Value operand : returnOp.getOperands()) {
      if (shouldSinkReturnOperand(operand, returnOp)) {
        // Sink the defining op to the dispatch block.
        mlir::Operation *defOp = operand.getDefiningOp();
        rewriter.moveOpBefore(defOp, destBlock, destBlock->end());
        returnValues.push_back(operand);
      } else {
        // Create an alloca in the function entry block.
        cir::AllocaOp alloca;
        {
          mlir::OpBuilder::InsertionGuard guard(rewriter);
          rewriter.setInsertionPointToStart(&funcEntryBlock);
          cir::CIRDataLayout dataLayout(
              funcOp->getParentOfType<mlir::ModuleOp>());
          uint64_t alignment =
              dataLayout.getAlignment(operand.getType(), true).value();
          cir::PointerType ptrType = cir::PointerType::get(operand.getType());
          alloca = cir::AllocaOp::create(rewriter, loc, ptrType,
                                         operand.getType(), "__ret_operand_tmp",
                                         rewriter.getI64IntegerAttr(alignment));
        }

        // Store the operand value at the original return location.
        {
          mlir::OpBuilder::InsertionGuard guard(rewriter);
          rewriter.setInsertionPoint(exitOp);
          cir::StoreOp::create(rewriter, loc, operand, alloca,
                               /*isVolatile=*/false,
                               /*alignment=*/mlir::IntegerAttr(),
                               cir::SyncScopeKindAttr(), cir::MemOrderAttr());
        }

        // Reload the value from the temporary alloca in the destination
        // block.
        rewriter.setInsertionPointToEnd(destBlock);
        auto loaded = cir::LoadOp::create(
            rewriter, loc, alloca, /*isDeref=*/false,
            /*isVolatile=*/false, /*alignment=*/mlir::IntegerAttr(),
            cir::SyncScopeKindAttr(), cir::MemOrderAttr());
        returnValues.push_back(loaded);
      }
    }
  }

  // Create the appropriate terminator for an exit operation in the dispatch
  // block. For return ops with operands, this handles the dominance issue by
  // either moving the operand's defining op to the dispatch block (if it's a
  // trivial use) or by storing to a temporary alloca and loading it.
  mlir::LogicalResult
  createExitTerminator(mlir::Operation *exitOp, mlir::Location loc,
                       mlir::Block *continueBlock,
                       mlir::PatternRewriter &rewriter) const {
    return llvm::TypeSwitch<mlir::Operation *, mlir::LogicalResult>(exitOp)
        .Case<cir::YieldOp>([&](auto) {
          // Yield becomes a branch to the continue block.
          cir::BrOp::create(rewriter, loc, continueBlock);
          return mlir::success();
        })
        .Case<cir::BreakOp>([&](auto) {
          // Break is preserved for later lowering by the enclosing
          // switch/loop.
          cir::BreakOp::create(rewriter, loc);
          return mlir::success();
        })
        .Case<cir::ContinueOp>([&](auto) {
          // Continue is preserved for later lowering by the enclosing loop.
          cir::ContinueOp::create(rewriter, loc);
          return mlir::success();
        })
        .Case<cir::ReturnOp>([&](auto returnOp) {
          // Return from the cleanup exit. Note, if this is a return inside a
          // nested cleanup scope, the flattening of the outer scope will
          // handle branching through the outer cleanup.
          if (returnOp.hasOperand()) {
            llvm::SmallVector<mlir::Value, 2> returnValues;
            getReturnOpOperands(returnOp, exitOp, loc, rewriter, returnValues);
            cir::ReturnOp::create(rewriter, loc, returnValues);
          } else {
            cir::ReturnOp::create(rewriter, loc);
          }
          return mlir::success();
        })
        .Case<cir::GotoOp>([&](auto gotoOp) {
          // Correct goto handling requires determining whether the goto
          // branches out of the cleanup scope or stays within it.
          // Although the goto necessarily exits the cleanup scope in the
          // case where it is the only exit from the scope, it is left
          // unimplemented for now so that it can be generalized when
          // multi-exit flattening is implemented.
          cir::UnreachableOp::create(rewriter, loc);
          return gotoOp.emitError(
              "goto in cleanup scope is not yet implemented");
        })
        .Default([&](mlir::Operation *op) {
          cir::UnreachableOp::create(rewriter, loc);
          return op->emitError(
              "unexpected exit operation in cleanup scope body");
        });
  }

#ifndef NDEBUG
  // Check that no block other than the last one in a region exits the region.
  static bool regionExitsOnlyFromLastBlock(mlir::Region &region) {
    for (mlir::Block &block : region) {
      if (&block == &region.back())
        continue;
      bool expectedTerminator =
          llvm::TypeSwitch<mlir::Operation *, bool>(block.getTerminator())
              // It is theoretically possible to have a cleanup block with
              // any of the following exits in non-final blocks, but we won't
              // currently generate any CIR that does that, and being able to
              // assume that it doesn't happen simplifies the implementation.
              // If we ever need to handle this case, the code will need to
              // be updated to handle it.
              .Case<cir::YieldOp, cir::ReturnOp, cir::ResumeFlatOp,
                    cir::ContinueOp, cir::BreakOp, cir::GotoOp>(
                  [](auto) { return false; })
              // We expect that call operations have not yet been rewritten
              // as try_call operations. A call can unwind out of the cleanup
              // scope, but we will be handling that during flattening. The
              // only case where a try_call could be present inside an
              // unflattened cleanup region is if the cleanup contained a
              // nested try-catch region, and that isn't expected as of the
              // time of this implementation. If it does occur, this could be
              // updated to tolerate it.
              .Case<cir::TryCallOp>([](auto) { return false; })
              // Likewise, we don't expect to find an EH dispatch operation
              // because we weren't expecting try-catch regions nested in the
              // cleanup region.
              .Case<cir::EhDispatchOp>([](auto) { return false; })
              // In theory, it would be possible to have a flattened switch
              // operation that does not exit the cleanup region. For now,
              // that's not happening.
              .Case<cir::SwitchFlatOp>([](auto) { return false; })
              // These aren't expected either, but if they occur, they don't
              // exit the region, so that's OK.
              .Case<cir::UnreachableOp, cir::TrapOp>([](auto) { return true; })
              // Indirect branches are not expected.
              .Case<cir::IndirectBrOp>([](auto) { return false; })
              // We do expect branches, but we don't expect them to leave
              // the region.
              .Case<cir::BrOp>([&](cir::BrOp brOp) {
                assert(brOp.getDest()->getParent() == &region &&
                       "branch destination is not in the region");
                return true;
              })
              .Case<cir::BrCondOp>([&](cir::BrCondOp brCondOp) {
                assert(brCondOp.getDestTrue()->getParent() == &region &&
                       "branch destination is not in the region");
                assert(brCondOp.getDestFalse()->getParent() == &region &&
                       "branch destination is not in the region");
                return true;
              })
              // What else could there be?
              .Default([](mlir::Operation *) -> bool {
                llvm_unreachable("unexpected terminator in cleanup region");
              });
      if (!expectedTerminator)
        return false;
    }
    return true;
  }
#endif

  // Build the EH cleanup block structure by cloning the cleanup region. The
  // cloned entry block gets a !cir.eh_token argument and a cir.begin_cleanup
  // inserted at the top. All cir.yield terminators that might exit the
  // cleanup region are replaced with cir.end_cleanup + cir.resume.
  //
  // For a single-block cleanup region, this produces:
  //
  //   ^eh_cleanup(%eh_token : !cir.eh_token):
  //     %ct = cir.begin_cleanup %eh_token : !cir.eh_token -> !cir.cleanup_token
  //     <cloned cleanup operations>
  //     cir.end_cleanup %ct : !cir.cleanup_token
  //     cir.resume %eh_token : !cir.eh_token
  //
  // For a multi-block cleanup region (e.g. containing a flattened cir.if),
  // the same wrapping is applied around the cloned block structure: the entry
  // block gets begin_cleanup and all exit blocks (those terminated by yield)
  // get end_cleanup + resume.
  //
  // If this cleanup scope is nested within a TryOp, the resume will be
  // updated to branch to the catch dispatch block of the enclosing try
  // operation when the TryOp is flattened.
  mlir::Block *buildEHCleanupBlocks(cir::CleanupScopeOp cleanupOp,
                                    mlir::Location loc,
                                    mlir::Block *insertBefore,
                                    mlir::PatternRewriter &rewriter) const {
    assert(regionExitsOnlyFromLastBlock(cleanupOp.getCleanupRegion()) &&
           "cleanup region has exits in non-final blocks");

    // Track the block before the insertion point so we can find the cloned
    // blocks after cloning.
    mlir::Block *blockBeforeClone = insertBefore->getPrevNode();

    // Clone the entire cleanup region before insertBefore.
    rewriter.cloneRegionBefore(cleanupOp.getCleanupRegion(), insertBefore);

    // Find the first cloned block.
    mlir::Block *clonedEntry = blockBeforeClone
                                   ? blockBeforeClone->getNextNode()
                                   : &insertBefore->getParent()->front();

    // Add the eh_token argument to the cloned entry block and insert
    // begin_cleanup at the top.
    auto ehTokenType = cir::EhTokenType::get(rewriter.getContext());
    mlir::Value ehToken = clonedEntry->addArgument(ehTokenType, loc);

    rewriter.setInsertionPointToStart(clonedEntry);
    auto beginCleanup = cir::BeginCleanupOp::create(rewriter, loc, ehToken);

    // Replace the yield terminator in the last cloned block with
    // end_cleanup + resume.
    mlir::Block *lastClonedBlock = insertBefore->getPrevNode();
    auto yieldOp =
        mlir::dyn_cast<cir::YieldOp>(lastClonedBlock->getTerminator());
    if (yieldOp) {
      rewriter.setInsertionPoint(yieldOp);
      cir::EndCleanupOp::create(rewriter, loc, beginCleanup.getCleanupToken());
      rewriter.replaceOpWithNewOp<cir::ResumeOp>(yieldOp, ehToken);
    } else {
      cleanupOp->emitError("Not yet implemented: cleanup region terminated "
                           "with non-yield operation");
    }

    return clonedEntry;
  }

  // Flatten a cleanup scope. The body region's exits branch to the cleanup
  // block, and the cleanup block branches to destination blocks whose
  // contents depend on the type of operation that exited the body region.
  // Yield becomes a branch to the block after the cleanup scope, break and
  // continue are preserved for later lowering by the enclosing switch or
  // loop, and return is preserved as is.
  //
  // If there are multiple exits from the cleanup body, a destination slot
  // and switch dispatch are used to continue to the correct destination
  // after the cleanup is complete. A destination slot alloca is created at
  // the function entry block. Each exit operation is replaced by a store of
  // its unique ID to the destination slot and a branch to the cleanup. The
  // cleanup region's terminator is then rewritten to branch to a dispatch
  // block that loads the destination slot and uses switch.flat to branch to
  // the correct destination.
  //
  // If the cleanup scope requires EH cleanup, any call operations in the body
  // that may throw are replaced with cir.try_call operations that unwind to
  // an EH cleanup block. The cleanup block(s) will be terminated with a
  // cir.resume operation. If this cleanup scope is enclosed by a try
  // operation, the flattening of the try operation will replace the
  // cir.resume with a branch to a catch dispatch block. Otherwise, the
  // cir.resume operation remains in place and will unwind to the caller.
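  //
  // An illustrative sketch of the multi-exit dispatch (CIR syntax
  // approximate, block names and case IDs hypothetical):
  //
  //   ^cleanup:
  //     <cleanup ops>
  //     cir.br ^dispatch
  //   ^dispatch:
  //     %d = cir.load %slot
  //     cir.switch.flat %d : !s32i, ^unreachable [
  //       0: ^continue,
  //       1: ^return.dest
  //     ]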
1143 mlir::LogicalResult
1144 flattenCleanup(cir::CleanupScopeOp cleanupOp,
1145 llvm::SmallVectorImpl<CleanupExit> &exits,
1146 llvm::SmallVectorImpl<cir::CallOp> &callsToRewrite,
1147 llvm::SmallVectorImpl<cir::ResumeOp> &resumeOpsToChain,
1148 mlir::PatternRewriter &rewriter) const {
1149 mlir::Location loc = cleanupOp.getLoc();
1150 cir::CleanupKind cleanupKind = cleanupOp.getCleanupKind();
1151 bool hasNormalCleanup = cleanupKind == cir::CleanupKind::Normal ||
1152 cleanupKind == cir::CleanupKind::All;
1153 bool hasEHCleanup = cleanupKind == cir::CleanupKind::EH ||
1154 cleanupKind == cir::CleanupKind::All;
1155 bool isMultiExit = exits.size() > 1;
1156
1157 // Get references to region blocks before inlining.
1158 mlir::Block *bodyEntry = &cleanupOp.getBodyRegion().front();
1159 mlir::Block *cleanupEntry = &cleanupOp.getCleanupRegion().front();
1160 mlir::Block *cleanupExit = &cleanupOp.getCleanupRegion().back();
1161 assert(regionExitsOnlyFromLastBlock(cleanupOp.getCleanupRegion()) &&
1162 "cleanup region has exits in non-final blocks");
1163 auto cleanupYield = dyn_cast<cir::YieldOp>(cleanupExit->getTerminator());
1164 if (!cleanupYield) {
1165 return rewriter.notifyMatchFailure(cleanupOp,
1166 "Not yet implemented: cleanup region "
1167 "terminated with non-yield operation");
1168 }
1169
1170 // For multiple exits from the body region, get or create a destination slot
1171 // at function entry. The slot is shared across all cleanup scopes in the
1172 // function. This is only needed if the cleanup scope requires normal
1173 // cleanup.
1174 cir::AllocaOp destSlot;
1175 if (isMultiExit && hasNormalCleanup) {
1176 auto funcOp = cleanupOp->getParentOfType<cir::FuncOp>();
1177 if (!funcOp)
1178 return cleanupOp->emitError("cleanup scope not inside a function");
1179 destSlot = getOrCreateCleanupDestSlot(funcOp, rewriter, loc);
1180 }
1181
1182 // Split the current block to create the insertion point.
1183 mlir::Block *currentBlock = rewriter.getInsertionBlock();
1184 mlir::Block *continueBlock =
1185 rewriter.splitBlock(currentBlock, rewriter.getInsertionPoint());
1186
1187 // Build EH cleanup blocks if needed. This must be done before inlining
1188 // the cleanup region since buildEHCleanupBlocks clones from it. The unwind
1189 // block is inserted before the EH cleanup entry so that the final layout
1190 // is: body -> normal cleanup -> exit -> unwind -> EH cleanup -> continue.
1191 // EH cleanup blocks are needed when there are throwing calls that need to
1192 // be rewritten to try_call, or when there are resume ops from
1193 // already-flattened inner cleanup scopes that need to chain through this
1194 // cleanup's EH handler.
1195 mlir::Block *unwindBlock = nullptr;
1196 mlir::Block *ehCleanupEntry = nullptr;
1197 if (hasEHCleanup &&
1198 (!callsToRewrite.empty() || !resumeOpsToChain.empty())) {
1199 ehCleanupEntry =
1200 buildEHCleanupBlocks(cleanupOp, loc, continueBlock, rewriter);
1201 // The unwind block is only needed when there are throwing calls that
1202 // need a shared unwind destination. Resume ops from inner cleanups
1203 // branch directly to the EH cleanup entry.
1204 if (!callsToRewrite.empty())
1205 unwindBlock = buildUnwindBlock(ehCleanupEntry, /*isCleanupOnly=*/true,
1206 loc, ehCleanupEntry, rewriter);
1207 }
1208
1209 // All normal flow blocks are inserted before this point — either before
1210 // the unwind block (if it exists), or before the EH cleanup entry (if EH
1211 // cleanup exists but no unwind block is needed), or before the continue
1212 // block.
1213 mlir::Block *normalInsertPt =
1214 unwindBlock ? unwindBlock
1215 : (ehCleanupEntry ? ehCleanupEntry : continueBlock);
1216
1217 // Inline the body region.
1218 rewriter.inlineRegionBefore(cleanupOp.getBodyRegion(), normalInsertPt);
1219
1220 // Inline the cleanup region for the normal cleanup path.
1221 if (hasNormalCleanup)
1222 rewriter.inlineRegionBefore(cleanupOp.getCleanupRegion(), normalInsertPt);
1223
1224 // Branch from current block to body entry.
1225 rewriter.setInsertionPointToEnd(currentBlock);
1226 cir::BrOp::create(rewriter, loc, bodyEntry);
1227
1228 // Handle normal exits.
1229 mlir::LogicalResult result = mlir::success();
1230 if (hasNormalCleanup) {
1231 // Create the exit/dispatch block (after cleanup, before continue).
1232 mlir::Block *exitBlock = rewriter.createBlock(normalInsertPt);
1233
1234 // Rewrite the cleanup region's yield to branch to exit block.
1235 rewriter.setInsertionPoint(cleanupYield);
1236 rewriter.replaceOpWithNewOp<cir::BrOp>(cleanupYield, exitBlock);
1237
1238 if (isMultiExit) {
1239 // Build the dispatch switch in the exit block.
1240 rewriter.setInsertionPointToEnd(exitBlock);
1241
1242 // Load the destination slot value.
1243 auto slotValue = cir::LoadOp::create(
1244 rewriter, loc, destSlot, /*isDeref=*/false,
1245 /*isVolatile=*/false, /*alignment=*/mlir::IntegerAttr(),
1246 cir::SyncScopeKindAttr(), cir::MemOrderAttr());
1247
1248 // Create destination blocks for each exit and collect switch case info.
1249 llvm::SmallVector<mlir::APInt, 8> caseValues;
1250 llvm::SmallVector<mlir::Block *, 8> caseDestinations;
1251 llvm::SmallVector<mlir::ValueRange, 8> caseOperands;
1252 cir::IntType s32Type =
1253 cir::IntType::get(rewriter.getContext(), 32, /*isSigned=*/true);
1254
1255 for (const CleanupExit &exit : exits) {
1256 // Create a block for this destination.
1257 mlir::Block *destBlock = rewriter.createBlock(normalInsertPt);
1258 rewriter.setInsertionPointToEnd(destBlock);
1259 result =
1260 createExitTerminator(exit.exitOp, loc, continueBlock, rewriter);
1261
1262 // Add to switch cases.
1263 caseValues.push_back(
1264 llvm::APInt(32, static_cast<uint64_t>(exit.destinationId), true));
1265 caseDestinations.push_back(destBlock);
1266 caseOperands.push_back(mlir::ValueRange());
1267
1268 // Replace the original exit op with: store dest ID, branch to
1269 // cleanup.
1270 rewriter.setInsertionPoint(exit.exitOp);
1271 auto destIdConst = cir::ConstantOp::create(
1272 rewriter, loc, cir::IntAttr::get(s32Type, exit.destinationId));
1273 cir::StoreOp::create(rewriter, loc, destIdConst, destSlot,
1274 /*isVolatile=*/false,
1275 /*alignment=*/mlir::IntegerAttr(),
1276 cir::SyncScopeKindAttr(), cir::MemOrderAttr());
1277 rewriter.replaceOpWithNewOp<cir::BrOp>(exit.exitOp, cleanupEntry);
1278
1279 // If the exit terminator creation failed, we're going to end up with
1280 // partially flattened code, but we'll also have reported an error so
1281 // that's OK. We need to finish out this function to keep the IR in a
1282 // valid state to help diagnose the error. This is a temporary
1283 // possibility during development. It shouldn't ever happen after the
1284 // implementation is complete.
1285 if (result.failed())
1286 break;
1287 }
1288
1289 // Create the default destination (unreachable).
1290 mlir::Block *defaultBlock = rewriter.createBlock(normalInsertPt);
1291 rewriter.setInsertionPointToEnd(defaultBlock);
1292 cir::UnreachableOp::create(rewriter, loc);
1293
1294 // Build the switch.flat operation in the exit block.
1295 rewriter.setInsertionPointToEnd(exitBlock);
1296 cir::SwitchFlatOp::create(rewriter, loc, slotValue, defaultBlock,
1297 mlir::ValueRange(), caseValues,
1298 caseDestinations, caseOperands);
1299 } else {
1300 // Single exit: put the appropriate terminator directly in the exit
1301 // block.
1302 rewriter.setInsertionPointToEnd(exitBlock);
1303 mlir::Operation *exitOp = exits[0].exitOp;
1304 result = createExitTerminator(exitOp, loc, continueBlock, rewriter);
1305
1306 // Replace body exit with branch to cleanup entry.
1307 rewriter.setInsertionPoint(exitOp);
1308 rewriter.replaceOpWithNewOp<cir::BrOp>(exitOp, cleanupEntry);
1309 }
1310 } else {
1311 // EH-only cleanup: normal exits skip the cleanup entirely.
1312 // Replace yield exits with branches to the continue block.
1313 for (CleanupExit &exit : exits) {
1314 if (isa<cir::YieldOp>(exit.exitOp)) {
1315 rewriter.setInsertionPoint(exit.exitOp);
1316 rewriter.replaceOpWithNewOp<cir::BrOp>(exit.exitOp, continueBlock);
1317 }
1318 // Non-yield exits (break, continue, return) stay as-is since no normal
1319 // cleanup is needed.
1320 }
1321 }
1322
1323 // Replace non-nothrow calls with try_call operations. All calls within
1324 // this cleanup scope share the same unwind destination.
1325 if (hasEHCleanup) {
1326 for (cir::CallOp callOp : callsToRewrite)
1327 replaceCallWithTryCall(callOp, unwindBlock, loc, rewriter);
1328 }
1329
1330 // Handle throwing calls in EH cleanup blocks. When an exception is thrown
1331 // during cleanup code that runs on the exception unwind path, the C++
1332 // standard requires that std::terminate() be called. Replace such calls
1333 // with try_call operations that unwind to a terminate block containing
1334 // cir.eh.initiate + cir.eh.terminate.
1335 if (ehCleanupEntry) {
1336 llvm::SmallVector<cir::CallOp> ehCleanupThrowingCalls;
1337 for (mlir::Block *block = ehCleanupEntry; block != continueBlock;
1338 block = block->getNextNode()) {
1339 block->walk([&](cir::CallOp callOp) {
1340 if (!callOp.getNothrow())
1341 ehCleanupThrowingCalls.push_back(callOp);
1342 });
1343 }
1344 if (!ehCleanupThrowingCalls.empty()) {
1345 mlir::Block *terminateBlock =
1346 buildTerminateUnwindBlock(loc, continueBlock, rewriter);
1347 for (cir::CallOp callOp : ehCleanupThrowingCalls)
1348 replaceCallWithTryCall(callOp, terminateBlock, loc, rewriter);
1349 }
1350 }
1351
1352 // Chain inner EH cleanup resume ops to this cleanup's EH handler.
1353 // Each cir.resume from an already-flattened inner cleanup is replaced
1354 // with a branch to the outer EH cleanup entry, passing the eh_token
1355 // from the inner's begin_cleanup so that the same in-flight exception
1356 // flows through the outer cleanup before unwinding to the caller.
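// e.g. an inner cleanup's "cir.resume %tok" becomes
// "cir.br ^outer.eh.cleanup(%tok : !cir.eh_token)", where the block
// name is illustrative.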
1357 if (ehCleanupEntry) {
1358 for (cir::ResumeOp resumeOp : resumeOpsToChain) {
1359 mlir::Value ehToken = resumeOp.getEhToken();
1360 rewriter.setInsertionPoint(resumeOp);
1361 rewriter.replaceOpWithNewOp<cir::BrOp>(
1362 resumeOp, mlir::ValueRange{ehToken}, ehCleanupEntry);
1363 }
1364 }
1365
1366 // Erase the original cleanup scope op.
1367 rewriter.eraseOp(cleanupOp);
1368
1369 // Always return success because the IR has been modified (blocks split,
1370 // regions inlined, ops erased, etc.). The MLIR pattern rewriter contract
1371 // requires that if a pattern modifies IR, it must return success(). Any
1372 // errors from unsupported exit operations (e.g. goto) have already been
1373 // reported via emitError and an unreachable terminator was placed as a
1374 // placeholder.
1375 return mlir::success();
1376 }
1377
1378 mlir::LogicalResult
1379 matchAndRewrite(cir::CleanupScopeOp cleanupOp,
1380 mlir::PatternRewriter &rewriter) const override {
1381 mlir::OpBuilder::InsertionGuard guard(rewriter);
1382
1383 // All nested structured CIR ops must be flattened before the cleanup scope.
1384 // Operations like loops, switches, scopes, and ifs may contain exits
1385 // (return, break, continue) that the cleanup scope will replace with
1386 // branches to the cleanup entry. If those exits are inside a structured
1387 // op's region, the branch would reference a block outside that region,
1388 // which is invalid. Fail the match so they are processed first.
1389 //
1390 // Before checking, erase any trivially dead nested cleanup scopes. These
1391 // arise from deactivated cleanups (e.g. partial-construction guards for
1392 // lambda captures). The greedy rewriter may have already DCE'd them, but
1393 // when a trivially dead nested op is erased first, the parent isn't always
1394 // re-added to the worklist, so we handle it here.
1395 llvm::SmallVector<cir::CleanupScopeOp> deadNestedOps;
1396 cleanupOp.getBodyRegion().walk([&](cir::CleanupScopeOp nested) {
1397 if (mlir::isOpTriviallyDead(nested))
1398 deadNestedOps.push_back(nested);
1399 });
1400 for (auto op : deadNestedOps)
1401 rewriter.eraseOp(op);
1402
1403 if (hasNestedOpsToFlatten(cleanupOp.getBodyRegion()))
1404 return mlir::failure();
1405
1406 cir::CleanupKind cleanupKind = cleanupOp.getCleanupKind();
1407
1408 // Collect all exits from the body region.
1409 llvm::SmallVector<CleanupExit> exits;
1410 int nextId = 0;
1411 collectExits(cleanupOp.getBodyRegion(), exits, nextId);
1412
1413 assert(!exits.empty() && "cleanup scope body has no exit");
1414
1415 // Collect non-nothrow calls that need to be converted to try_call.
1416 // This is only needed for the EH and All cleanup kinds; for Normal
1417 // cleanup the collection is skipped and the vector stays empty.
1418 llvm::SmallVector<cir::CallOp> callsToRewrite;
1419 if (cleanupKind != cir::CleanupKind::Normal)
1420 collectThrowingCalls(cleanupOp.getBodyRegion(), callsToRewrite);
1421
1422 // Collect resume ops from already-flattened inner cleanup scopes that
1423 // need to chain through this cleanup's EH handler.
1424 llvm::SmallVector<cir::ResumeOp> resumeOpsToChain;
1425 if (cleanupKind != cir::CleanupKind::Normal)
1426 collectResumeOps(cleanupOp.getBodyRegion(), resumeOpsToChain);
1427
1428 return flattenCleanup(cleanupOp, exits, callsToRewrite, resumeOpsToChain,
1429 rewriter);
1430 }
1431};
1432
1433// Trace an !cir.eh_token value back through block arguments to find the
1434// cir.eh.initiate operation that defines it. Returns {} if the defining op
1435 // cannot be found (e.g. when a block has multiple predecessors).
1436static cir::EhInitiateOp traceToEhInitiate(mlir::Value ehToken) {
1437 while (ehToken) {
1438 if (auto initiate = ehToken.getDefiningOp<cir::EhInitiateOp>())
1439 return initiate;
1440 auto blockArg = mlir::dyn_cast<mlir::BlockArgument>(ehToken);
1441 if (!blockArg)
1442 return {};
1443 mlir::Block *pred = blockArg.getOwner()->getSinglePredecessor();
1444 if (!pred)
1445 return {};
1446 auto brOp = mlir::dyn_cast<cir::BrOp>(pred->getTerminator());
1447 if (!brOp)
1448 return {};
1449 ehToken = brOp.getDestOperands()[blockArg.getArgNumber()];
1450 }
1451 return {};
1452}
1453
1454class CIRTryOpFlattening : public mlir::OpRewritePattern<cir::TryOp> {
1455public:
1456 using OpRewritePattern<cir::TryOp>::OpRewritePattern;
1457
1458 // Build the catch dispatch block with a cir.eh.dispatch operation.
1459 // The dispatch block receives an !cir.eh_token argument and dispatches
1460 // to the appropriate catch handler blocks based on exception types.
1461 mlir::Block *buildCatchDispatchBlock(
1462 cir::TryOp tryOp, mlir::ArrayAttr handlerTypes,
1463 llvm::SmallVectorImpl<mlir::Block *> &catchHandlerBlocks,
1464 mlir::Location loc, mlir::Block *insertBefore,
1465 mlir::PatternRewriter &rewriter) const {
1466 mlir::Block *dispatchBlock = rewriter.createBlock(insertBefore);
1467 auto ehTokenType = cir::EhTokenType::get(rewriter.getContext());
1468 mlir::Value ehToken = dispatchBlock->addArgument(ehTokenType, loc);
1469
1470 rewriter.setInsertionPointToEnd(dispatchBlock);
1471
1472 // Build the catch types and destinations for the dispatch.
1473 llvm::SmallVector<mlir::Attribute> catchTypeAttrs;
1474 llvm::SmallVector<mlir::Block *> catchDests;
1475 mlir::Block *defaultDest = nullptr;
1476 bool defaultIsCatchAll = false;
1477
1478 for (auto [typeAttr, handlerBlock] :
1479 llvm::zip(handlerTypes, catchHandlerBlocks)) {
1480 if (mlir::isa<cir::CatchAllAttr>(typeAttr)) {
1481 assert(!defaultDest && "multiple catch_all or unwind handlers");
1482 defaultDest = handlerBlock;
1483 defaultIsCatchAll = true;
1484 } else if (mlir::isa<cir::UnwindAttr>(typeAttr)) {
1485 assert(!defaultDest && "multiple catch_all or unwind handlers");
1486 defaultDest = handlerBlock;
1487 defaultIsCatchAll = false;
1488 } else {
1489 // This is a typed catch handler (GlobalViewAttr with type info).
1490 catchTypeAttrs.push_back(typeAttr);
1491 catchDests.push_back(handlerBlock);
1492 }
1493 }
1494
1495 assert(defaultDest && "dispatch must have a catch_all or unwind handler");
1496
1497 mlir::ArrayAttr catchTypesArrayAttr;
1498 if (!catchTypeAttrs.empty())
1499 catchTypesArrayAttr = rewriter.getArrayAttr(catchTypeAttrs);
1500
1501 cir::EhDispatchOp::create(rewriter, loc, ehToken, catchTypesArrayAttr,
1502 defaultIsCatchAll, defaultDest, catchDests);
1503
1504 return dispatchBlock;
1505 }
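// An illustrative sketch (the cir.eh.dispatch syntax here is
// approximate) of the dispatch built for "catch (int) ... catch (...)":
//
//   ^dispatch(%tok : !cir.eh_token):
//     cir.eh.dispatch %tok, [@typeid.int : ^catch.int], default ^catch.all
//
// where @typeid.int stands for the GlobalViewAttr referencing the RTTI
// for int. Typed handlers keep their source order and the catch-all
// (or unwind) handler becomes the default destination.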
1506
1507 // Flatten a single catch handler region. Each handler region has an
1508 // !cir.eh_token argument and starts with cir.begin_catch, followed by
1509 // a cir.cleanup.scope containing the handler body (with cir.end_catch in
1510 // its cleanup region), and ending with cir.yield.
1511 //
1512 // After flattening, the handler region becomes a block that receives the
1513 // eh_token, calls begin_catch, runs the handler body inline, calls
1514 // end_catch, and branches to the continue block.
1515 //
1516 // The cleanup scope inside the catch handler is expected to have been
1517 // flattened before we get here, so what we see in the handler region is
1518 // already flat code with begin_catch at the top and end_catch in any place
1519 // that we would exit the catch handler. We just need to inline the region
1520 // and fix up terminators.
1521 mlir::Block *flattenCatchHandler(mlir::Region &handlerRegion,
1522 mlir::Block *continueBlock,
1523 mlir::Location loc,
1524 mlir::Block *insertBefore,
1525 mlir::PatternRewriter &rewriter) const {
1526 // The handler region entry block has the !cir.eh_token argument.
1527 mlir::Block *handlerEntry = &handlerRegion.front();
1528
1529 // Inline the handler region before insertBefore.
1530 rewriter.inlineRegionBefore(handlerRegion, insertBefore);
1531
1532 // Replace yield terminators in the handler with branches to continue.
1533 for (mlir::Block &block : llvm::make_range(handlerEntry->getIterator(),
1534 insertBefore->getIterator())) {
1535 if (auto yieldOp = dyn_cast<cir::YieldOp>(block.getTerminator())) {
1536 // Verify that end_catch is reachable on some predecessor path
1537 // before this yield. After cleanup scope flattening, end_catch is
1538 // not necessarily immediately before the yield: it may sit in a
1539 // predecessor block, separated from the yield by branches (e.g.
1540 // branches produced by a flattened cir.if inside the catch body),
1541 // and the predecessor chain may fork at blocks with multiple
1542 // predecessors. So if the op directly before the yield is not an
1543 // end_catch, walk back through all predecessor blocks, looking for
1544 // one whose last operation before its terminator is an end_catch.
1545 // The whole check lives inside an assert, so it compiles away in
1546 // release builds.
1547 assert(([&]() {
1548 if (mlir::Operation *prev = yieldOp->getPrevNode())
1549 return isa<cir::EndCatchOp>(prev);
1550 llvm::SmallPtrSet<mlir::Block *, 8> visited;
1551 llvm::SmallVector<mlir::Block *, 4> worklist;
1552 for (mlir::Block *pred : block.getPredecessors())
1553 worklist.push_back(pred);
1554 while (!worklist.empty()) {
1555 mlir::Block *b = worklist.pop_back_val();
1556 if (!visited.insert(b).second)
1557 continue;
1558 mlir::Operation *term = b->getTerminator();
1559 if (mlir::Operation *prev = term->getPrevNode()) {
1560 if (isa<cir::EndCatchOp>(prev))
1561 return true;
1562 }
1563 for (mlir::Block *pred : b->getPredecessors())
1564 worklist.push_back(pred);
1565 }
1566 return false;
1567 }()) &&
1568 "expected end_catch reachable before yield "
1569 "in catch handler");
1570 rewriter.setInsertionPoint(yieldOp);
1571 rewriter.replaceOpWithNewOp<cir::BrOp>(yieldOp, continueBlock);
1572 }
1573 }
1574
1575 return handlerEntry;
1576 }
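// After flattening, a handler for "catch (int e) { ... }" looks roughly
// like this (syntax abbreviated):
//
//   ^handler(%tok : !cir.eh_token):
//     cir.begin_catch %tok
//     ... handler body ...
//     cir.end_catch
//     cir.br ^continue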
1577
1578 // Flatten an unwind handler region. The unwind region just contains a
1579 // cir.resume that continues unwinding. We inline it and leave the resume
1580 // in place. If this try op is nested inside an EH cleanup or another try op,
1581 // the enclosing op will rewrite the resume as a branch to its cleanup or
1582 // dispatch block when it is flattened. Otherwise, the resume will unwind to
1583 // the caller.
1584 mlir::Block *flattenUnwindHandler(mlir::Region &unwindRegion,
1585 mlir::Location loc,
1586 mlir::Block *insertBefore,
1587 mlir::PatternRewriter &rewriter) const {
1588 mlir::Block *unwindEntry = &unwindRegion.front();
1589 rewriter.inlineRegionBefore(unwindRegion, insertBefore);
1590 return unwindEntry;
1591 }
1592
1593 mlir::LogicalResult
1594 matchAndRewrite(cir::TryOp tryOp,
1595 mlir::PatternRewriter &rewriter) const override {
1596 // All nested structured CIR ops must be flattened before the try op.
1597 // Cleanup scopes and nested try ops need to be flat so EH cleanup is
1598 // properly handled. Other structured ops (scopes, ifs, loops, switches,
1599 // ternaries) must be flat because replaceCallWithTryCall creates try_call
1600 // ops whose unwind destination is outside the structured op's region,
1601 // which would be an invalid cross-region reference.
1602 for (mlir::Region &region : tryOp->getRegions())
1603 if (hasNestedOpsToFlatten(region))
1604 return mlir::failure();
1605
1606 mlir::OpBuilder::InsertionGuard guard(rewriter);
1607 mlir::Location loc = tryOp.getLoc();
1608
1609 mlir::ArrayAttr handlerTypes = tryOp.getHandlerTypesAttr();
1610 mlir::MutableArrayRef<mlir::Region> handlerRegions =
1611 tryOp.getHandlerRegions();
1612
1613 // Collect throwing calls in the try body.
1614 llvm::SmallVector<cir::CallOp> callsToRewrite;
1615 collectThrowingCalls(tryOp.getTryRegion(), callsToRewrite);
1616
1617 // Collect resume ops from already-flattened cleanup scopes in the try body.
1618 llvm::SmallVector<cir::ResumeOp> resumeOpsToChain;
1619 collectResumeOps(tryOp.getTryRegion(), resumeOpsToChain);
1620
1621 // Split the current block and inline the try body.
1622 mlir::Block *currentBlock = rewriter.getInsertionBlock();
1623 mlir::Block *continueBlock =
1624 rewriter.splitBlock(currentBlock, rewriter.getInsertionPoint());
1625
1626 // Get references to try body blocks before inlining.
1627 mlir::Block *bodyEntry = &tryOp.getTryRegion().front();
1628 mlir::Block *bodyExit = &tryOp.getTryRegion().back();
1629
1630 // Inline the try body region before the continue block.
1631 rewriter.inlineRegionBefore(tryOp.getTryRegion(), continueBlock);
1632
1633 // Branch from the current block to the body entry.
1634 rewriter.setInsertionPointToEnd(currentBlock);
1635 cir::BrOp::create(rewriter, loc, bodyEntry);
1636
1637 // Replace the try body's yield terminator with a branch to continue.
1638 if (auto bodyYield = dyn_cast<cir::YieldOp>(bodyExit->getTerminator())) {
1639 rewriter.setInsertionPoint(bodyYield);
1640 rewriter.replaceOpWithNewOp<cir::BrOp>(bodyYield, continueBlock);
1641 }
1642
1643 // If there are no handlers, we're done.
1644 if (!handlerTypes || handlerTypes.empty()) {
1645 rewriter.eraseOp(tryOp);
1646 return mlir::success();
1647 }
1648
1649 // If there are no throwing calls and no resume ops from inner cleanup
1650 // scopes, exceptions cannot reach the catch handlers. Drop all uses
1651 // from the (unreachable) handler regions before erasing the try op,
1652 // since handler ops may reference values that were inlined from the
1653 // try body into the parent block.
1654 if (callsToRewrite.empty() && resumeOpsToChain.empty()) {
1655 for (mlir::Region &handlerRegion : handlerRegions)
1656 for (mlir::Block &block : handlerRegion)
1657 block.dropAllDefinedValueUses();
1658 rewriter.eraseOp(tryOp);
1659 return mlir::success();
1660 }
1661
1662 // Build the catch handler blocks.
1663
1664 // First, flatten all handler regions and collect the entry blocks.
1665 llvm::SmallVector<mlir::Block *> catchHandlerBlocks;
1666
1667 for (const auto &[idx, typeAttr] : llvm::enumerate(handlerTypes)) {
1668 mlir::Region &handlerRegion = handlerRegions[idx];
1669
1670 if (mlir::isa<cir::UnwindAttr>(typeAttr)) {
1671 mlir::Block *unwindEntry =
1672 flattenUnwindHandler(handlerRegion, loc, continueBlock, rewriter);
1673 catchHandlerBlocks.push_back(unwindEntry);
1674 } else {
1675 mlir::Block *handlerEntry = flattenCatchHandler(
1676 handlerRegion, continueBlock, loc, continueBlock, rewriter);
1677 catchHandlerBlocks.push_back(handlerEntry);
1678 }
1679 }
1680
1681 // Build the catch dispatch block.
1682 mlir::Block *dispatchBlock =
1683 buildCatchDispatchBlock(tryOp, handlerTypes, catchHandlerBlocks, loc,
1684 catchHandlerBlocks.front(), rewriter);
1685
1686 // Check whether the try has a catch-all handler. When catch-all is
1687 // present, the personality function will always stop unwinding at this
1688 // frame (because catch-all matches every exception type). The LLVM
1689 // landingpad therefore needs "catch ptr null" rather than "cleanup".
1690 // The downstream pipeline (EHABILowering + LowerToLLVM) emits
1691 // "catch ptr null" when the EhInitiateOp has neither cleanup nor typed
1692 // catch types, so we clear the cleanup flag on every EhInitiateOp that
1693 // feeds into a dispatch with a catch-all handler.
1694 bool hasCatchAll =
1695 handlerTypes && llvm::any_of(handlerTypes, [](mlir::Attribute attr) {
1696 return mlir::isa<cir::CatchAllAttr>(attr);
1697 });
1698
1699 // Build a block to be the unwind destination for throwing calls and
1700 // replace the calls with try_call ops. Note that the unwind block created
1701 // here is different from the unwind handler we may have created above:
1702 // the unwind handler continues unwinding after uncaught exceptions,
1703 // whereas this block will eventually become the landing pad for invoke
1704 // instructions.
1705 bool isCleanupOnly = tryOp.getCleanup() && !hasCatchAll;
1706 if (!callsToRewrite.empty()) {
1707 // Create a shared unwind block for all throwing calls.
1708 mlir::Block *unwindBlock = buildUnwindBlock(dispatchBlock, isCleanupOnly,
1709 loc, dispatchBlock, rewriter);
1710
1711 for (cir::CallOp callOp : callsToRewrite)
1712 replaceCallWithTryCall(callOp, unwindBlock, loc, rewriter);
1713 }
1714
1715 // Chain resume ops from inner cleanup scopes.
1716 // Resume ops from already-flattened cleanup scopes within the try body
1717 // should branch to the catch dispatch block instead of unwinding directly.
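// e.g. "cir.resume %tok" from an inner flattened cleanup becomes
// "cir.br ^dispatch(%tok : !cir.eh_token)", so the in-flight exception
// is type-matched against this try's handlers instead of unwinding
// straight to the caller.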
1718 for (cir::ResumeOp resumeOp : resumeOpsToChain) {
1719 // When there is a catch-all handler, clear the cleanup flag on the
1720 // cir.eh.initiate that produced this token. With catch-all, the LLVM
1721 // landingpad needs "catch ptr null" instead of "cleanup".
1722 if (hasCatchAll) {
1723 if (auto ehInitiate = traceToEhInitiate(resumeOp.getEhToken())) {
1724 rewriter.modifyOpInPlace(ehInitiate,
1725 [&] { ehInitiate.removeCleanupAttr(); });
1726 }
1727 }
1728
1729 mlir::Value ehToken = resumeOp.getEhToken();
1730 rewriter.setInsertionPoint(resumeOp);
1731 rewriter.replaceOpWithNewOp<cir::BrOp>(
1732 resumeOp, mlir::ValueRange{ehToken}, dispatchBlock);
1733 }
1734
1735 // Finally, erase the original try op.
1736 rewriter.eraseOp(tryOp);
1737
1738 return mlir::success();
1739 }
1740};
1741
1742void populateFlattenCFGPatterns(RewritePatternSet &patterns) {
1743 patterns
1744 .add<CIRIfFlattening, CIRLoopOpInterfaceFlattening, CIRScopeOpFlattening,
1745 CIRSwitchOpFlattening, CIRTernaryOpFlattening,
1746 CIRCleanupScopeOpFlattening, CIRTryOpFlattening>(
1747 patterns.getContext());
1748}
1749
1750void CIRFlattenCFGPass::runOnOperation() {
1751 RewritePatternSet patterns(&getContext());
1752 populateFlattenCFGPatterns(patterns);
1753
1754 // Collect operations to apply patterns.
1755 llvm::SmallVector<Operation *, 16> ops;
1756 getOperation()->walk<mlir::WalkOrder::PostOrder>([&](Operation *op) {
1757 if (isa<IfOp, ScopeOp, SwitchOp, LoopOpInterface, TernaryOp, CleanupScopeOp,
1758 TryOp>(op))
1759 ops.push_back(op);
1760 });
1761
1762 // Apply patterns.
1763 if (applyOpPatternsGreedily(ops, std::move(patterns)).failed())
1764 signalPassFailure();
1765}
1766
1767} // namespace
1768
1769namespace mlir {
1770
1771std::unique_ptr<Pass> createCIRFlattenCFGPass() {
1772 return std::make_unique<CIRFlattenCFGPass>();
1773}
1774
1775} // namespace mlir