clang 23.0.0git
FlattenCFG.cpp
Go to the documentation of this file.
1//===----------------------------------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements pass that inlines CIR operations regions into the parent
10// function region.
11//
12//===----------------------------------------------------------------------===//
13
14#include "PassDetail.h"
15#include "mlir/Dialect/Func/IR/FuncOps.h"
16#include "mlir/IR/Block.h"
17#include "mlir/IR/Builders.h"
18#include "mlir/IR/PatternMatch.h"
19#include "mlir/Interfaces/SideEffectInterfaces.h"
20#include "mlir/Support/LogicalResult.h"
21#include "mlir/Transforms/DialectConversion.h"
22#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
27#include "llvm/ADT/TypeSwitch.h"
28
29using namespace mlir;
30using namespace cir;
31
32namespace mlir {
33#define GEN_PASS_DEF_CIRFLATTENCFG
34#include "clang/CIR/Dialect/Passes.h.inc"
35} // namespace mlir
36
37namespace {
38
/// Lowers operations with the terminator trait that have a single successor:
/// the terminator is replaced in place by an unconditional cir.br to `dest`.
/// The insertion guard restores the rewriter's insertion point afterwards.
void lowerTerminator(mlir::Operation *op, mlir::Block *dest,
                     mlir::PatternRewriter &rewriter) {
  assert(op->hasTrait<mlir::OpTrait::IsTerminator>() && "not a terminator");
  mlir::OpBuilder::InsertionGuard guard(rewriter);
  rewriter.setInsertionPoint(op);
  rewriter.replaceOpWithNewOp<cir::BrOp>(op, dest);
}
47
/// Walks a region while skipping operations of type `Ops`. This ensures the
/// callback is not applied to said operations and its children.
///
/// The pre-order walk returns WalkResult::skip() for matching ops, which
/// prevents the walk from descending into their nested regions; all other
/// ops are forwarded to `callback`, whose result controls the walk.
template <typename... Ops>
void walkRegionSkipping(
    mlir::Region &region,
    mlir::function_ref<mlir::WalkResult(mlir::Operation *)> callback) {
  region.walk<mlir::WalkOrder::PreOrder>([&](mlir::Operation *op) {
    if (isa<Ops...>(op))
      return mlir::WalkResult::skip();
    return callback(op);
  });
}
60
/// Pass that flattens CIR structured control-flow operations (if, scope,
/// switch, loops, ternary, cleanup scopes) into explicit basic-block CFG.
/// The pattern set and driver live in runOnOperation() (defined below /
/// elsewhere in this file).
struct CIRFlattenCFGPass : public impl::CIRFlattenCFGBase<CIRFlattenCFGPass> {

  CIRFlattenCFGPass() = default;
  void runOnOperation() override;
};
66
/// Flattens cir.if into explicit CFG: the parent block is split at the
/// cir.if, the then/else regions are inlined into the parent region, and a
/// cir.brcond in the original block dispatches to the region entry blocks.
/// Both branches rejoin at the block holding the ops that followed the if.
struct CIRIfFlattening : public mlir::OpRewritePattern<cir::IfOp> {
  using OpRewritePattern<IfOp>::OpRewritePattern;

  mlir::LogicalResult
  matchAndRewrite(cir::IfOp ifOp,
                  mlir::PatternRewriter &rewriter) const override {
    mlir::OpBuilder::InsertionGuard guard(rewriter);
    mlir::Location loc = ifOp.getLoc();
    bool emptyElse = ifOp.getElseRegion().empty();
    mlir::Block *currentBlock = rewriter.getInsertionBlock();
    // Everything after the cir.if moves into a fresh block, which becomes
    // the merge point for both branches.
    mlir::Block *remainingOpsBlock =
        rewriter.splitBlock(currentBlock, rewriter.getInsertionPoint());
    mlir::Block *continueBlock;
    if (ifOp->getResults().empty())
      continueBlock = remainingOpsBlock;
    else
      llvm_unreachable("NYI"); // Value-producing cir.if not implemented yet.

    // Inline the then region before the merge block.
    mlir::Block *thenBeforeBody = &ifOp.getThenRegion().front();
    mlir::Block *thenAfterBody = &ifOp.getThenRegion().back();
    rewriter.inlineRegionBefore(ifOp.getThenRegion(), continueBlock);

    // A cir.yield terminating the then region becomes a branch to the merge
    // block; any other terminator (e.g. return) is left untouched.
    rewriter.setInsertionPointToEnd(thenAfterBody);
    if (auto thenYieldOp =
            dyn_cast<cir::YieldOp>(thenAfterBody->getTerminator())) {
      rewriter.replaceOpWithNewOp<cir::BrOp>(thenYieldOp, thenYieldOp.getArgs(),
                                             continueBlock);
    }

    rewriter.setInsertionPointToEnd(continueBlock);

    // Has else region: inline it.
    mlir::Block *elseBeforeBody = nullptr;
    mlir::Block *elseAfterBody = nullptr;
    if (!emptyElse) {
      elseBeforeBody = &ifOp.getElseRegion().front();
      elseAfterBody = &ifOp.getElseRegion().back();
      rewriter.inlineRegionBefore(ifOp.getElseRegion(), continueBlock);
    } else {
      // No else region: the false edge jumps straight to the merge block.
      elseBeforeBody = elseAfterBody = continueBlock;
    }

    // Emit the conditional branch that replaces the structured if.
    rewriter.setInsertionPointToEnd(currentBlock);
    cir::BrCondOp::create(rewriter, loc, ifOp.getCondition(), thenBeforeBody,
                          elseBeforeBody);

    if (!emptyElse) {
      // Same yield-to-branch rewrite for the else region's exit.
      rewriter.setInsertionPointToEnd(elseAfterBody);
      if (auto elseYieldOP =
              dyn_cast<cir::YieldOp>(elseAfterBody->getTerminator())) {
        rewriter.replaceOpWithNewOp<cir::BrOp>(
            elseYieldOP, elseYieldOP.getArgs(), continueBlock);
      }
    }

    rewriter.replaceOp(ifOp, continueBlock->getArguments());
    return mlir::success();
  }
};
127
/// Flattens cir.scope: the scope body region is inlined into the parent
/// region between the split point and the block holding the remaining ops,
/// with plain branches stitching entry and exit. Empty scopes are erased.
class CIRScopeOpFlattening : public mlir::OpRewritePattern<cir::ScopeOp> {
public:
  using OpRewritePattern<cir::ScopeOp>::OpRewritePattern;

  mlir::LogicalResult
  matchAndRewrite(cir::ScopeOp scopeOp,
                  mlir::PatternRewriter &rewriter) const override {
    mlir::OpBuilder::InsertionGuard guard(rewriter);
    mlir::Location loc = scopeOp.getLoc();

    // Empty scope: just remove it.
    // TODO: Remove this logic once CIR uses MLIR infrastructure to remove
    // trivially dead operations. MLIR canonicalizer is too aggressive and we
    // need to either (a) make sure all our ops model all side-effects and/or
    // (b) have more options in the canonicalizer in MLIR to temper
    // aggressiveness level.
    if (scopeOp.isEmpty()) {
      rewriter.eraseOp(scopeOp);
      return mlir::success();
    }

    // Split the current block before the ScopeOp to create the inlining
    // point.
    mlir::Block *currentBlock = rewriter.getInsertionBlock();
    mlir::Block *continueBlock =
        rewriter.splitBlock(currentBlock, rewriter.getInsertionPoint());
    // Scope results become block arguments of the continuation block.
    if (scopeOp.getNumResults() > 0)
      continueBlock->addArguments(scopeOp.getResultTypes(), loc);

    // Inline body region.
    mlir::Block *beforeBody = &scopeOp.getScopeRegion().front();
    mlir::Block *afterBody = &scopeOp.getScopeRegion().back();
    rewriter.inlineRegionBefore(scopeOp.getScopeRegion(), continueBlock);

    // Save stack and then branch into the body of the region.
    rewriter.setInsertionPointToEnd(currentBlock);
    cir::BrOp::create(rewriter, loc, mlir::ValueRange(), beforeBody);

    // Replace the scopeop return with a branch that jumps out of the body.
    // Stack restore before leaving the body region.
    rewriter.setInsertionPointToEnd(afterBody);
    if (auto yieldOp = dyn_cast<cir::YieldOp>(afterBody->getTerminator())) {
      rewriter.replaceOpWithNewOp<cir::BrOp>(yieldOp, yieldOp.getArgs(),
                                             continueBlock);
    }

    // Replace the op with values return from the body region.
    rewriter.replaceOp(scopeOp, continueBlock->getArguments());

    return mlir::success();
  }
};
181
/// Flattens cir.switch into cir.switchflat plus explicit CFG. Each cir.case
/// region is inlined after its case op, break statements are routed to the
/// exit block, fallthrough yields are chained, and case ranges are either
/// expanded into individual case values (small ranges) or lowered into a
/// chained range-check block that becomes the new default destination.
class CIRSwitchOpFlattening : public mlir::OpRewritePattern<cir::SwitchOp> {
public:
  using OpRewritePattern<cir::SwitchOp>::OpRewritePattern;

  // Replace a cir.yield with a branch (forwarding its operands) to
  // `destination`.
  inline void rewriteYieldOp(mlir::PatternRewriter &rewriter,
                             cir::YieldOp yieldOp,
                             mlir::Block *destination) const {
    rewriter.setInsertionPoint(yieldOp);
    rewriter.replaceOpWithNewOp<cir::BrOp>(yieldOp, yieldOp.getOperands(),
                                           destination);
  }

  // Lower one case range by emitting an unsigned bounds check:
  //   (cond - lowerBound) <=u (upperBound - lowerBound)
  // branching to rangeDestination on success and to the previous
  // defaultDestination otherwise.
  // Return the new defaultDestination block.
  Block *condBrToRangeDestination(cir::SwitchOp op,
                                  mlir::PatternRewriter &rewriter,
                                  mlir::Block *rangeDestination,
                                  mlir::Block *defaultDestination,
                                  const APInt &lowerBound,
                                  const APInt &upperBound) const {
    assert(lowerBound.sle(upperBound) && "Invalid range");
    mlir::Block *resBlock = rewriter.createBlock(defaultDestination);
    cir::IntType sIntType = cir::IntType::get(op.getContext(), 32, true);
    cir::IntType uIntType = cir::IntType::get(op.getContext(), 32, false);

    cir::ConstantOp rangeLength = cir::ConstantOp::create(
        rewriter, op.getLoc(),
        cir::IntAttr::get(sIntType, upperBound - lowerBound));

    cir::ConstantOp lowerBoundValue = cir::ConstantOp::create(
        rewriter, op.getLoc(), cir::IntAttr::get(sIntType, lowerBound));
    mlir::Value diffValue = cir::SubOp::create(
        rewriter, op.getLoc(), op.getCondition(), lowerBoundValue);

    // Use unsigned comparison to check if the condition is in the range.
    cir::CastOp uDiffValue = cir::CastOp::create(
        rewriter, op.getLoc(), uIntType, CastKind::integral, diffValue);
    cir::CastOp uRangeLength = cir::CastOp::create(
        rewriter, op.getLoc(), uIntType, CastKind::integral, rangeLength);

    cir::CmpOp cmpResult = cir::CmpOp::create(
        rewriter, op.getLoc(), cir::CmpOpKind::le, uDiffValue, uRangeLength);
    cir::BrCondOp::create(rewriter, op.getLoc(), cmpResult, rangeDestination,
                          defaultDestination);
    return resBlock;
  }

  mlir::LogicalResult
  matchAndRewrite(cir::SwitchOp op,
                  mlir::PatternRewriter &rewriter) const override {
    // Cleanup scopes must be lowered before the enclosing switch so that
    // break inside them is properly routed through cleanup.
    // Fail the match so the pattern rewriter will process cleanup scopes first.
    bool hasNestedCleanup = op->walk([&](cir::CleanupScopeOp) {
      return mlir::WalkResult::interrupt();
    }).wasInterrupted();
    if (hasNestedCleanup)
      return mlir::failure();

    llvm::SmallVector<CaseOp> cases;
    op.collectCases(cases);

    // Empty switch statement: just erase it.
    if (cases.empty()) {
      rewriter.eraseOp(op);
      return mlir::success();
    }

    // Create exit block from the next node of cir.switch op.
    mlir::Block *exitBlock = rewriter.splitBlock(
        rewriter.getBlock(), op->getNextNode()->getIterator());

    // We lower cir.switch op in the following process:
    // 1. Inline the region from the switch op after switch op.
    // 2. Traverse each cir.case op:
    //    a. Record the entry block, block arguments and condition for every
    //       case.
    //    b. Inline the case region after the case op.
    // 3. Replace the empty cir.switch.op with the new cir.switchflat op by the
    //    recorded block and conditions.

    // Inline everything from switch body between the switch op and the exit
    // block.
    {
      cir::YieldOp switchYield = nullptr;
      // Clear switch operation.
      for (mlir::Block &block :
           llvm::make_early_inc_range(op.getBody().getBlocks()))
        if (auto yieldOp = dyn_cast<cir::YieldOp>(block.getTerminator()))
          switchYield = yieldOp;

      assert(!op.getBody().empty());
      mlir::Block *originalBlock = op->getBlock();
      mlir::Block *swopBlock =
          rewriter.splitBlock(originalBlock, op->getIterator());
      rewriter.inlineRegionBefore(op.getBody(), exitBlock);

      if (switchYield)
        rewriteYieldOp(rewriter, switchYield, exitBlock);

      rewriter.setInsertionPointToEnd(originalBlock);
      cir::BrOp::create(rewriter, op.getLoc(), swopBlock);
    }

    // Allocate required data structures (the default case is not stored in
    // these vectors).
    llvm::SmallVector<mlir::APInt, 8> caseValues;
    llvm::SmallVector<mlir::Block *, 8> caseDestinations;
    llvm::SmallVector<mlir::ValueRange, 8> caseOperands;

    llvm::SmallVector<std::pair<APInt, APInt>> rangeValues;
    llvm::SmallVector<mlir::Block *> rangeDestinations;
    llvm::SmallVector<mlir::ValueRange> rangeOperands;

    // Initialize default case as optional.
    mlir::Block *defaultDestination = exitBlock;
    mlir::ValueRange defaultOperands = exitBlock->getArguments();

    // Digest the case statements values and bodies.
    for (cir::CaseOp caseOp : cases) {
      mlir::Region &region = caseOp.getCaseRegion();

      // Found default case: save destination and operands.
      switch (caseOp.getKind()) {
      case cir::CaseOpKind::Default:
        defaultDestination = &region.front();
        defaultOperands = defaultDestination->getArguments();
        break;
      case cir::CaseOpKind::Range:
        assert(caseOp.getValue().size() == 2 &&
               "Case range should have 2 case value");
        rangeValues.push_back(
            {cast<cir::IntAttr>(caseOp.getValue()[0]).getValue(),
             cast<cir::IntAttr>(caseOp.getValue()[1]).getValue()});
        rangeDestinations.push_back(&region.front());
        rangeOperands.push_back(rangeDestinations.back()->getArguments());
        break;
      case cir::CaseOpKind::Anyof:
      case cir::CaseOpKind::Equal:
        // AnyOf cases kind can have multiple values, hence the loop below.
        for (const mlir::Attribute &value : caseOp.getValue()) {
          caseValues.push_back(cast<cir::IntAttr>(value).getValue());
          caseDestinations.push_back(&region.front());
          caseOperands.push_back(caseDestinations.back()->getArguments());
        }
        break;
      }

      // Handle break statements: any break in the case region (not nested
      // inside a loop or inner switch) branches to the switch exit block.
      walkRegionSkipping<cir::LoopOpInterface, cir::SwitchOp>(
          region, [&](mlir::Operation *op) {
            if (!isa<cir::BreakOp>(op))
              return mlir::WalkResult::advance();

            lowerTerminator(op, exitBlock, rewriter);
            return mlir::WalkResult::skip();
          });

      // Track fallthrough in cases: a yield terminating a case block falls
      // through to the ops following this case op.
      for (mlir::Block &blk : region.getBlocks()) {
        if (blk.getNumSuccessors())
          continue;

        if (auto yieldOp = dyn_cast<cir::YieldOp>(blk.getTerminator())) {
          mlir::Operation *nextOp = caseOp->getNextNode();
          assert(nextOp && "caseOp is not expected to be the last op");
          mlir::Block *oldBlock = nextOp->getBlock();
          mlir::Block *newBlock =
              rewriter.splitBlock(oldBlock, nextOp->getIterator());
          rewriter.setInsertionPointToEnd(oldBlock);
          cir::BrOp::create(rewriter, nextOp->getLoc(), mlir::ValueRange(),
                            newBlock);
          rewriteYieldOp(rewriter, yieldOp, newBlock);
        }
      }

      mlir::Block *oldBlock = caseOp->getBlock();
      mlir::Block *newBlock =
          rewriter.splitBlock(oldBlock, caseOp->getIterator());

      mlir::Block &entryBlock = caseOp.getCaseRegion().front();
      rewriter.inlineRegionBefore(caseOp.getCaseRegion(), newBlock);

      // Create a branch to the entry of the inlined region.
      rewriter.setInsertionPointToEnd(oldBlock);
      cir::BrOp::create(rewriter, caseOp.getLoc(), &entryBlock);
    }

    // Remove all cases since we've inlined the regions.
    for (cir::CaseOp caseOp : cases) {
      mlir::Block *caseBlock = caseOp->getBlock();
      // Erase the block with no predecessors here to make the generated code
      // simpler a little bit.
      if (caseBlock->hasNoPredecessors())
        rewriter.eraseBlock(caseBlock);
      else
        rewriter.eraseOp(caseOp);
    }

    for (auto [rangeVal, operand, destination] :
         llvm::zip(rangeValues, rangeOperands, rangeDestinations)) {
      APInt lowerBound = rangeVal.first;
      APInt upperBound = rangeVal.second;

      // The case range is unreachable, skip it.
      if (lowerBound.sgt(upperBound))
        continue;

      // If range is small, add multiple switch instruction cases.
      // This magical number is from the original CGStmt code.
      constexpr int kSmallRangeThreshold = 64;
      if ((upperBound - lowerBound)
              .ult(llvm::APInt(32, kSmallRangeThreshold))) {
        for (APInt iValue = lowerBound; iValue.sle(upperBound); ++iValue) {
          caseValues.push_back(iValue);
          caseOperands.push_back(operand);
          caseDestinations.push_back(destination);
        }
        continue;
      }

      // Large range: emit an explicit bounds-check block chained in front of
      // the current default destination.
      defaultDestination =
          condBrToRangeDestination(op, rewriter, destination,
                                   defaultDestination, lowerBound, upperBound);
      defaultOperands = operand;
    }

    // Set switch op to branch to the newly created blocks.
    rewriter.setInsertionPoint(op);
    rewriter.replaceOpWithNewOp<cir::SwitchFlatOp>(
        op, op.getCondition(), defaultDestination, defaultOperands, caseValues,
        caseDestinations, caseOperands);

    return mlir::success();
  }
};
416
/// Flattens any op implementing cir.LoopOpInterface (for/while/do) into
/// explicit CFG: entry branches into the loop's entry region, the condition
/// region's cir.condition becomes a cir.brcond to body/exit, continue and
/// break are lowered to branches, and the cond/body/step regions are inlined
/// before the exit block.
class CIRLoopOpInterfaceFlattening
    : public mlir::OpInterfaceRewritePattern<cir::LoopOpInterface> {
public:
  using mlir::OpInterfaceRewritePattern<
      cir::LoopOpInterface>::OpInterfaceRewritePattern;

  // Replace a cir.condition with a conditional branch to `body` or `exit`.
  inline void lowerConditionOp(cir::ConditionOp op, mlir::Block *body,
                               mlir::Block *exit,
                               mlir::PatternRewriter &rewriter) const {
    mlir::OpBuilder::InsertionGuard guard(rewriter);
    rewriter.setInsertionPoint(op);
    rewriter.replaceOpWithNewOp<cir::BrCondOp>(op, op.getCondition(), body,
                                               exit);
  }

  mlir::LogicalResult
  matchAndRewrite(cir::LoopOpInterface op,
                  mlir::PatternRewriter &rewriter) const final {
    // Cleanup scopes must be lowered before the enclosing loop so that
    // break/continue inside them are properly routed through cleanup.
    // Fail the match so the pattern rewriter will process cleanup scopes first.
    bool hasNestedCleanup = op->walk([&](cir::CleanupScopeOp) {
      return mlir::WalkResult::interrupt();
    }).wasInterrupted();
    if (hasNestedCleanup)
      return mlir::failure();

    // Setup CFG blocks.
    mlir::Block *entry = rewriter.getInsertionBlock();
    mlir::Block *exit =
        rewriter.splitBlock(entry, rewriter.getInsertionPoint());
    mlir::Block *cond = &op.getCond().front();
    mlir::Block *body = &op.getBody().front();
    mlir::Block *step =
        (op.maybeGetStep() ? &op.maybeGetStep()->front() : nullptr);

    // Setup loop entry branch.
    rewriter.setInsertionPointToEnd(entry);
    cir::BrOp::create(rewriter, op.getLoc(), &op.getEntry().front());

    // Branch from condition region to body or exit. The ConditionOp may not
    // be in the first block of the condition region if a cleanup scope was
    // already flattened within it, introducing multiple blocks. The
    // ConditionOp is always the terminator of the last block.
    auto conditionOp =
        cast<cir::ConditionOp>(op.getCond().back().getTerminator());
    lowerConditionOp(conditionOp, body, exit, rewriter);

    // TODO(cir): Remove the walks below. It visits operations unnecessarily.
    // However, to solve this we would likely need a custom DialectConversion
    // driver to customize the order that operations are visited.

    // Lower continue statements: continue jumps to the step region if one
    // exists (for loops), otherwise back to the condition.
    mlir::Block *dest = (step ? step : cond);
    op.walkBodySkippingNestedLoops([&](mlir::Operation *op) {
      if (!isa<cir::ContinueOp>(op))
        return mlir::WalkResult::advance();

      lowerTerminator(op, dest, rewriter);
      return mlir::WalkResult::skip();
    });

    // Lower break statements.
    walkRegionSkipping<cir::LoopOpInterface, cir::SwitchOp>(
        op.getBody(), [&](mlir::Operation *op) {
          if (!isa<cir::BreakOp>(op))
            return mlir::WalkResult::advance();

          lowerTerminator(op, exit, rewriter);
          return mlir::WalkResult::skip();
        });

    // Lower optional body region yield.
    for (mlir::Block &blk : op.getBody().getBlocks()) {
      auto bodyYield = dyn_cast<cir::YieldOp>(blk.getTerminator());
      if (bodyYield)
        lowerTerminator(bodyYield, (step ? step : cond), rewriter);
    }

    // Lower mandatory step region yield. Like the condition region, the
    // YieldOp may be in the last block rather than the first if a cleanup
    // scope was already flattened within the step region.
    if (step)
      lowerTerminator(
          cast<cir::YieldOp>(op.maybeGetStep()->back().getTerminator()), cond,
          rewriter);

    // Move region contents out of the loop op.
    rewriter.inlineRegionBefore(op.getCond(), exit);
    rewriter.inlineRegionBefore(op.getBody(), exit);
    if (step)
      rewriter.inlineRegionBefore(*op.maybeGetStep(), exit);

    rewriter.eraseOp(op);
    return mlir::success();
  }
};
514
/// Flattens cir.ternary into explicit CFG: a cir.brcond dispatches to the
/// inlined true/false regions, which rejoin at a continuation block whose
/// block arguments (if any) carry the ternary's result. Regions terminated
/// by cir.unreachable (e.g. throw expressions) are left as-is.
class CIRTernaryOpFlattening : public mlir::OpRewritePattern<cir::TernaryOp> {
public:
  using OpRewritePattern<cir::TernaryOp>::OpRewritePattern;

  mlir::LogicalResult
  matchAndRewrite(cir::TernaryOp op,
                  mlir::PatternRewriter &rewriter) const override {
    Location loc = op->getLoc();
    Block *condBlock = rewriter.getInsertionBlock();
    Block::iterator opPosition = rewriter.getInsertionPoint();
    Block *remainingOpsBlock = rewriter.splitBlock(condBlock, opPosition);
    llvm::SmallVector<mlir::Location, 2> locs;
    // Ternary result is optional, make sure to populate the location only
    // when relevant.
    if (op->getResultTypes().size())
      locs.push_back(loc);
    Block *continueBlock =
        rewriter.createBlock(remainingOpsBlock, op->getResultTypes(), locs);
    cir::BrOp::create(rewriter, loc, remainingOpsBlock);

    Region &trueRegion = op.getTrueRegion();
    Block *trueBlock = &trueRegion.front();
    mlir::Operation *trueTerminator = trueRegion.back().getTerminator();
    rewriter.setInsertionPointToEnd(&trueRegion.back());

    // Handle both yield and unreachable terminators (throw expressions).
    // Note: IR has already been modified (splitBlock, createBlock above), so
    // we must not return failure() from this point onward per the MLIR pattern
    // rewriter contract.
    if (auto trueYieldOp = dyn_cast<cir::YieldOp>(trueTerminator)) {
      rewriter.replaceOpWithNewOp<cir::BrOp>(trueYieldOp, trueYieldOp.getArgs(),
                                             continueBlock);
    } else if (isa<cir::UnreachableOp>(trueTerminator)) {
      // Terminator is unreachable (e.g., from throw), just keep it
    } else {
      trueTerminator->emitError("unexpected terminator in ternary true region, "
                                "expected yield or unreachable, got: ")
          << trueTerminator->getName();
      // Return success because IR was already modified
      // (splitBlock/createBlock).
      return mlir::success();
    }
    rewriter.inlineRegionBefore(trueRegion, continueBlock);

    Block *falseBlock = continueBlock;
    Region &falseRegion = op.getFalseRegion();

    falseBlock = &falseRegion.front();
    mlir::Operation *falseTerminator = falseRegion.back().getTerminator();
    rewriter.setInsertionPointToEnd(&falseRegion.back());

    // Handle both yield and unreachable terminators (throw expressions)
    if (auto falseYieldOp = dyn_cast<cir::YieldOp>(falseTerminator)) {
      rewriter.replaceOpWithNewOp<cir::BrOp>(
          falseYieldOp, falseYieldOp.getArgs(), continueBlock);
    } else if (isa<cir::UnreachableOp>(falseTerminator)) {
      // Terminator is unreachable (e.g., from throw), just keep it
    } else {
      falseTerminator->emitError("unexpected terminator in ternary false "
                                 "region, expected yield or unreachable, got: ")
          << falseTerminator->getName();
      // Return success because IR was already modified
      // (splitBlock/createBlock).
      return mlir::success();
    }
    rewriter.inlineRegionBefore(falseRegion, continueBlock);

    // Dispatch from the original block to the inlined region entries.
    rewriter.setInsertionPointToEnd(condBlock);
    cir::BrCondOp::create(rewriter, loc, op.getCond(), trueBlock, falseBlock);

    rewriter.replaceOp(op, continueBlock->getArguments());

    // Ok, we're done!
    return mlir::success();
  }
};
591
// Get or create the cleanup destination slot for a function. This slot is
// shared across all cleanup scopes in the function to track which exit path
// to take after running cleanup code when there are multiple exits.
//
// The slot is a signed-32-bit alloca tagged with the cleanup_dest_slot
// attribute; the entry block is scanned for an existing one before a new
// alloca is created at the start of the entry block.
static cir::AllocaOp getOrCreateCleanupDestSlot(cir::FuncOp funcOp,
                                                mlir::PatternRewriter &rewriter,
                                                mlir::Location loc) {
  mlir::Block &entryBlock = funcOp.getBody().front();

  // Look for an existing cleanup dest slot in the entry block.
  auto it = llvm::find_if(entryBlock, [](auto &op) {
    return mlir::isa<AllocaOp>(&op) &&
           mlir::cast<AllocaOp>(&op).getCleanupDestSlot();
  });
  if (it != entryBlock.end())
    return mlir::cast<cir::AllocaOp>(*it);

  // Create a new cleanup dest slot at the start of the entry block.
  mlir::OpBuilder::InsertionGuard guard(rewriter);
  rewriter.setInsertionPointToStart(&entryBlock);
  cir::IntType s32Type =
      cir::IntType::get(rewriter.getContext(), 32, /*isSigned=*/true);
  cir::PointerType ptrToS32Type = cir::PointerType::get(s32Type);
  // Query the module data layout for the natural alignment of the slot type.
  cir::CIRDataLayout dataLayout(funcOp->getParentOfType<mlir::ModuleOp>());
  uint64_t alignment = dataLayout.getAlignment(s32Type, true).value();
  auto allocaOp = cir::AllocaOp::create(
      rewriter, loc, ptrToS32Type, s32Type, "__cleanup_dest_slot",
      /*alignment=*/rewriter.getI64IntegerAttr(alignment));
  allocaOp.setCleanupDestSlot(true);
  return allocaOp;
}
622
623/// Shared EH flattening utilities used by both CIRCleanupScopeOpFlattening
624/// and CIRTryOpFlattening.
625
626// Collect all function calls in a region that may throw exceptions and need
627// to be replaced with try_call operations. Skips calls marked nothrow.
628// Nested cleanup scopes and try ops are always flattened before their
629// enclosing parents, so there are no nested regions to skip here.
630static void
631collectThrowingCalls(mlir::Region &region,
632 llvm::SmallVectorImpl<cir::CallOp> &callsToRewrite) {
633 region.walk([&](cir::CallOp callOp) {
634 if (!callOp.getNothrow())
635 callsToRewrite.push_back(callOp);
636 });
637}
638
639// Collect all cir.resume operations in a region that come from
640// already-flattened try or cleanup scope operations. These resume ops need
641// to be chained through this scope's EH handler instead of unwinding
642// directly to the caller. Nested cleanup scopes and try ops are always
643// flattened before their enclosing parents, so there are no nested regions
644// to skip here.
645static void collectResumeOps(mlir::Region &region,
647 region.walk([&](cir::ResumeOp resumeOp) { resumeOps.push_back(resumeOp); });
648}
649
// Replace a cir.call with a cir.try_call that unwinds to the `unwindDest`
// block if an exception is thrown.
//
// The call's block is split right after the call so the remaining ops form
// the normal destination. Both direct and indirect calls are handled; all
// attributes of the original call are carried over except the ones the
// try_call builder sets itself, and result uses are rewired before the
// original call is erased.
static void replaceCallWithTryCall(cir::CallOp callOp, mlir::Block *unwindDest,
                                   mlir::Location loc,
                                   mlir::PatternRewriter &rewriter) {
  mlir::Block *callBlock = callOp->getBlock();

  assert(!callOp.getNothrow() && "call is not expected to throw");

  // Split the block after the call - remaining ops become the normal
  // destination.
  mlir::Block *normalDest =
      rewriter.splitBlock(callBlock, std::next(callOp->getIterator()));

  // Build the try_call to replace the original call.
  rewriter.setInsertionPoint(callOp);
  cir::TryCallOp tryCallOp;
  if (callOp.isIndirect()) {
    // Indirect call: recover the callee function type from the pointee type
    // of the call target.
    mlir::Value indTarget = callOp.getIndirectCall();
    auto ptrTy = mlir::cast<cir::PointerType>(indTarget.getType());
    auto resTy = mlir::cast<cir::FuncType>(ptrTy.getPointee());
    tryCallOp =
        cir::TryCallOp::create(rewriter, loc, indTarget, resTy, normalDest,
                               unwindDest, callOp.getArgOperands());
  } else {
    mlir::Type resType = callOp->getNumResults() > 0
                             ? callOp->getResult(0).getType()
                             : mlir::Type();
    tryCallOp =
        cir::TryCallOp::create(rewriter, loc, callOp.getCalleeAttr(), resType,
                               normalDest, unwindDest, callOp.getArgOperands());
  }

  // Copy all attributes from the original call except those already set by
  // TryCallOp::create or that are operation-specific and should not be copied.
  llvm::StringRef excludedAttrs[] = {
      CIRDialect::getCalleeAttrName(), // Set by create()
      CIRDialect::getOperandSegmentSizesAttrName(),
  };
#ifndef NDEBUG
  // We don't expect to ever see any of these attributes on a call that we
  // converted to a try_call.
  llvm::StringRef unexpectedAttrs[] = {
      CIRDialect::getNoThrowAttrName(),
      CIRDialect::getNoUnwindAttrName(),
  };
#endif
  for (mlir::NamedAttribute attr : callOp->getAttrs()) {
    if (llvm::is_contained(excludedAttrs, attr.getName()))
      continue;
    assert(!llvm::is_contained(unexpectedAttrs, attr.getName()) &&
           "unexpected attribute on converted call");
    tryCallOp->setAttr(attr.getName(), attr.getValue());
  }

  // Replace uses of the call result with the try_call result.
  if (callOp->getNumResults() > 0)
    callOp->getResult(0).replaceAllUsesWith(tryCallOp.getResult());

  rewriter.eraseOp(callOp);
}
711
712// Create a shared unwind destination block. The block contains a
713// cir.eh.initiate operation (optionally with the cleanup attribute) and a
714// branch to the given destination block, passing the eh_token.
715static mlir::Block *buildUnwindBlock(mlir::Block *dest, bool isCleanupOnly,
716 mlir::Location loc,
717 mlir::Block *insertBefore,
718 mlir::PatternRewriter &rewriter) {
719 mlir::Block *unwindBlock = rewriter.createBlock(insertBefore);
720 rewriter.setInsertionPointToEnd(unwindBlock);
721 auto ehInitiate =
722 cir::EhInitiateOp::create(rewriter, loc, /*cleanup=*/isCleanupOnly);
723 cir::BrOp::create(rewriter, loc, mlir::ValueRange{ehInitiate.getEhToken()},
724 dest);
725 return unwindBlock;
726}
727
728// Create a shared terminate unwind block for throwing calls in EH cleanup
729// regions. When an exception is thrown during cleanup (unwinding), the C++
730// standard requires that std::terminate() be called.
731static mlir::Block *buildTerminateUnwindBlock(mlir::Location loc,
732 mlir::Block *insertBefore,
733 mlir::PatternRewriter &rewriter) {
734 mlir::Block *terminateBlock = rewriter.createBlock(insertBefore);
735 rewriter.setInsertionPointToEnd(terminateBlock);
736 auto ehInitiate = cir::EhInitiateOp::create(rewriter, loc, /*cleanup=*/false);
737 cir::EhTerminateOp::create(rewriter, loc, ehInitiate.getEhToken());
738 return terminateBlock;
739}
740
741class CIRCleanupScopeOpFlattening
742 : public mlir::OpRewritePattern<cir::CleanupScopeOp> {
743public:
744 using OpRewritePattern<cir::CleanupScopeOp>::OpRewritePattern;
745
  // Pairs an exiting operation with the dispatch identifier used when a
  // cleanup scope with multiple exits is flattened into a switch.
  struct CleanupExit {
    // An operation that exits the cleanup scope (yield, break, continue,
    // return, etc.)
    mlir::Operation *exitOp;

    // A unique identifier for this exit's destination (used for switch dispatch
    // when there are multiple exits).
    int destinationId;

    CleanupExit(mlir::Operation *op, int id) : exitOp(op), destinationId(id) {}
  };
757
758 // Collect all operations that exit a cleanup scope body. Return, goto, break,
759 // and continue can all require branches through the cleanup region. When a
760 // loop is encountered, only return and goto are collected because break and
761 // continue are handled by the loop and stay within the cleanup scope. When a
762 // switch is encountered, return, goto and continue are collected because they
763 // may all branch through the cleanup, but break is local to the switch. When
764 // a nested cleanup scope is encountered, we recursively collect exits since
765 // any return, goto, break, or continue from the nested cleanup will also
766 // branch through the outer cleanup.
767 //
768 // Note that goto statements may not necessarily exit the cleanup scope, but
769 // for now we conservatively assume that they do. We'll need more nuanced
770 // handling of that when multi-exit flattening is implemented.
771 //
772 // This function assigns unique destination IDs to each exit, which are
773 // used when multi-exit cleanup scopes are flattened.
  void collectExits(mlir::Region &cleanupBodyRegion,
                    llvm::SmallVectorImpl<CleanupExit> &exits,
                    int &nextId) const {
    // Destination IDs are taken from the caller-owned counter in encounter
    // order, so they are stable for a given traversal order.
    //
    // Collect yield terminators from the body region. We do this separately
    // because yields in nested operations, including those in nested cleanup
    // scopes, won't branch through the outer cleanup region.
    for (mlir::Block &block : cleanupBodyRegion) {
      auto *terminator = block.getTerminator();
      if (isa<cir::YieldOp>(terminator))
        exits.emplace_back(terminator, nextId++);
    }

    // Lambda to walk a loop and collect only returns and gotos.
    // Break and continue inside loops are handled by the loop itself.
    // Loops don't require special handling for nested switch or cleanup scopes
    // because break and continue never branch out of the loop.
    auto collectExitsInLoop = [&](mlir::Operation *loopOp) {
      loopOp->walk<mlir::WalkOrder::PreOrder>([&](mlir::Operation *nestedOp) {
        if (isa<cir::ReturnOp, cir::GotoOp>(nestedOp))
          exits.emplace_back(nestedOp, nextId++);
        return mlir::WalkResult::advance();
      });
    };

    // Forward declaration for mutual recursion. std::function is required
    // here because the two lambdas call each other, so each must be nameable
    // before its body is defined.
    std::function<void(mlir::Region &, bool)> collectExitsInCleanup;
    std::function<void(mlir::Operation *)> collectExitsInSwitch;

    // Lambda to collect exits from a switch. Collects return/goto/continue but
    // not break (handled by switch). For nested loops/cleanups, recurses.
    collectExitsInSwitch = [&](mlir::Operation *switchOp) {
      switchOp->walk<mlir::WalkOrder::PreOrder>([&](mlir::Operation *nestedOp) {
        if (isa<cir::CleanupScopeOp>(nestedOp)) {
          // Walk the nested cleanup, but ignore break statements because they
          // will be handled by the switch we are currently walking.
          collectExitsInCleanup(
              cast<cir::CleanupScopeOp>(nestedOp).getBodyRegion(),
              /*ignoreBreak=*/true);
          return mlir::WalkResult::skip();
        } else if (isa<cir::LoopOpInterface>(nestedOp)) {
          collectExitsInLoop(nestedOp);
          return mlir::WalkResult::skip();
        } else if (isa<cir::ReturnOp, cir::GotoOp, cir::ContinueOp>(nestedOp)) {
          exits.emplace_back(nestedOp, nextId++);
        }
        return mlir::WalkResult::advance();
      });
    };

    // Lambda to collect exits from a cleanup scope body region. This collects
    // break (optionally), continue, return, and goto, handling nested loops,
    // switches, and cleanups appropriately.
    collectExitsInCleanup = [&](mlir::Region &region, bool ignoreBreak) {
      region.walk<mlir::WalkOrder::PreOrder>([&](mlir::Operation *op) {
        // We need special handling for break statements because if this cleanup
        // scope was nested within a switch op, break will be handled by the
        // switch operation and therefore won't exit the cleanup scope enclosing
        // the switch. We're only collecting exits from the cleanup that started
        // this walk. Exits from nested cleanups will be handled when we flatten
        // the nested cleanup.
        if (!ignoreBreak && isa<cir::BreakOp>(op)) {
          exits.emplace_back(op, nextId++);
        } else if (isa<cir::ContinueOp, cir::ReturnOp, cir::GotoOp>(op)) {
          exits.emplace_back(op, nextId++);
        } else if (isa<cir::CleanupScopeOp>(op)) {
          // Recurse into nested cleanup's body region. Note that ignoreBreak
          // is propagated: a break in a nested cleanup still belongs to the
          // same enclosing switch (if any).
          collectExitsInCleanup(cast<cir::CleanupScopeOp>(op).getBodyRegion(),
                                /*ignoreBreak=*/ignoreBreak);
          return mlir::WalkResult::skip();
        } else if (isa<cir::LoopOpInterface>(op)) {
          // This kicks off a separate walk rather than continuing to dig deeper
          // in the current walk because we need to handle break and continue
          // differently inside loops.
          collectExitsInLoop(op);
          return mlir::WalkResult::skip();
        } else if (isa<cir::SwitchOp>(op)) {
          // This kicks off a separate walk rather than continuing to dig deeper
          // in the current walk because we need to handle break differently
          // inside switches.
          collectExitsInSwitch(op);
          return mlir::WalkResult::skip();
        }
        return mlir::WalkResult::advance();
      });
    };

    // Collect exits from the body region.
    collectExitsInCleanup(cleanupBodyRegion, /*ignoreBreak=*/false);
  }
863
864 // Check if an operand's defining op should be moved to the destination block.
865 // We only sink constants and simple loads. Anything else should be saved
866 // to a temporary alloca and reloaded at the destination block.
867 static bool shouldSinkReturnOperand(mlir::Value operand,
868 cir::ReturnOp returnOp) {
869 // Block arguments can't be moved
870 mlir::Operation *defOp = operand.getDefiningOp();
871 if (!defOp)
872 return false;
873
874 // Only move constants and loads to the dispatch block. For anything else,
875 // we'll store to a temporary and reload in the dispatch block.
876 if (!mlir::isa<cir::ConstantOp, cir::LoadOp>(defOp))
877 return false;
878
879 // Check if the return is the only user
880 if (!operand.hasOneUse())
881 return false;
882
883 // Only move ops that are in the same block as the return.
884 if (defOp->getBlock() != returnOp->getBlock())
885 return false;
886
887 if (auto loadOp = mlir::dyn_cast<cir::LoadOp>(defOp)) {
888 // Only attempt to move loads of allocas in the entry block.
889 mlir::Value ptr = loadOp.getAddr();
890 auto funcOp = returnOp->getParentOfType<cir::FuncOp>();
891 assert(funcOp && "Return op has no function parent?");
892 mlir::Block &funcEntryBlock = funcOp.getBody().front();
893
894 // Check if it's an alloca in the function entry block
895 if (auto allocaOp =
896 mlir::dyn_cast_if_present<cir::AllocaOp>(ptr.getDefiningOp()))
897 return allocaOp->getBlock() == &funcEntryBlock;
898
899 return false;
900 }
901
902 // Make sure we only fall through to here with constants.
903 assert(mlir::isa<cir::ConstantOp>(defOp) && "Expected constant op");
904 return true;
905 }
906
907 // For returns with operands in cleanup dispatch blocks, the operands may not
908 // dominate the dispatch block. This function handles that by either sinking
909 // the operand's defining op to the dispatch block (for constants and simple
910 // loads) or by storing to a temporary alloca and reloading it.
  void
  getReturnOpOperands(cir::ReturnOp returnOp, mlir::Operation *exitOp,
                      mlir::Location loc, mlir::PatternRewriter &rewriter,
                      llvm::SmallVectorImpl<mlir::Value> &returnValues) const {
    // The destination (dispatch) block is wherever the caller has positioned
    // the rewriter; the fixed-up return values are appended to returnValues.
    mlir::Block *destBlock = rewriter.getInsertionBlock();
    auto funcOp = exitOp->getParentOfType<cir::FuncOp>();
    assert(funcOp && "Return op has no function parent?");
    mlir::Block &funcEntryBlock = funcOp.getBody().front();

    for (mlir::Value operand : returnOp.getOperands()) {
      if (shouldSinkReturnOperand(operand, returnOp)) {
        // Sink the defining op to the dispatch block.
        mlir::Operation *defOp = operand.getDefiningOp();
        rewriter.moveOpBefore(defOp, destBlock, destBlock->end());
        returnValues.push_back(operand);
      } else {
        // Create an alloca in the function entry block.
        cir::AllocaOp alloca;
        {
          mlir::OpBuilder::InsertionGuard guard(rewriter);
          rewriter.setInsertionPointToStart(&funcEntryBlock);
          cir::CIRDataLayout dataLayout(
              funcOp->getParentOfType<mlir::ModuleOp>());
          // NOTE(review): the bool argument presumably selects ABI vs.
          // preferred alignment — confirm against CIRDataLayout::getAlignment.
          uint64_t alignment =
              dataLayout.getAlignment(operand.getType(), true).value();
          cir::PointerType ptrType = cir::PointerType::get(operand.getType());
          alloca = cir::AllocaOp::create(rewriter, loc, ptrType,
                                         operand.getType(), "__ret_operand_tmp",
                                         rewriter.getI64IntegerAttr(alignment));
        }

        // Store the operand value at the original return location, i.e. just
        // before the exit op that will be rewritten into a branch to cleanup.
        {
          mlir::OpBuilder::InsertionGuard guard(rewriter);
          rewriter.setInsertionPoint(exitOp);
          cir::StoreOp::create(rewriter, loc, operand, alloca,
                               /*isVolatile=*/false,
                               /*alignment=*/mlir::IntegerAttr(),
                               cir::SyncScopeKindAttr(), cir::MemOrderAttr());
        }

        // Reload the value from the temporary alloca in the destination block.
        // No insertion guard here: the rewriter is deliberately left at the
        // end of destBlock for subsequent operands and the final terminator.
        rewriter.setInsertionPointToEnd(destBlock);
        auto loaded = cir::LoadOp::create(
            rewriter, loc, alloca, /*isDeref=*/false,
            /*isVolatile=*/false, /*alignment=*/mlir::IntegerAttr(),
            cir::SyncScopeKindAttr(), cir::MemOrderAttr());
        returnValues.push_back(loaded);
      }
    }
  }
962
963 // Create the appropriate terminator for an exit operation in the dispatch
964 // block. For return ops with operands, this handles the dominance issue by
965 // either moving the operand's defining op to the dispatch block (if it's a
966 // trivial use) or by storing to a temporary alloca and loading it.
967 mlir::LogicalResult
968 createExitTerminator(mlir::Operation *exitOp, mlir::Location loc,
969 mlir::Block *continueBlock,
970 mlir::PatternRewriter &rewriter) const {
971 return llvm::TypeSwitch<mlir::Operation *, mlir::LogicalResult>(exitOp)
972 .Case<cir::YieldOp>([&](auto) {
973 // Yield becomes a branch to continue block.
974 cir::BrOp::create(rewriter, loc, continueBlock);
975 return mlir::success();
976 })
977 .Case<cir::BreakOp>([&](auto) {
978 // Break is preserved for later lowering by enclosing switch/loop.
979 cir::BreakOp::create(rewriter, loc);
980 return mlir::success();
981 })
982 .Case<cir::ContinueOp>([&](auto) {
983 // Continue is preserved for later lowering by enclosing loop.
984 cir::ContinueOp::create(rewriter, loc);
985 return mlir::success();
986 })
987 .Case<cir::ReturnOp>([&](auto returnOp) {
988 // Return from the cleanup exit. Note, if this is a return inside a
989 // nested cleanup scope, the flattening of the outer scope will handle
990 // branching through the outer cleanup.
991 if (returnOp.hasOperand()) {
992 llvm::SmallVector<mlir::Value, 2> returnValues;
993 getReturnOpOperands(returnOp, exitOp, loc, rewriter, returnValues);
994 cir::ReturnOp::create(rewriter, loc, returnValues);
995 } else {
996 cir::ReturnOp::create(rewriter, loc);
997 }
998 return mlir::success();
999 })
1000 .Case<cir::GotoOp>([&](auto gotoOp) {
1001 // Correct goto handling requires determining whether the goto
1002 // branches out of the cleanup scope or stays within it.
1003 // Although the goto necessarily exits the cleanup scope in the
1004 // case where it is the only exit from the scope, it is left
1005 // as unimplemented for now so that it can be generalized when
1006 // multi-exit flattening is implemented.
1007 cir::UnreachableOp::create(rewriter, loc);
1008 return gotoOp.emitError(
1009 "goto in cleanup scope is not yet implemented");
1010 })
1011 .Default([&](mlir::Operation *op) {
1012 cir::UnreachableOp::create(rewriter, loc);
1013 return op->emitError(
1014 "unexpected exit operation in cleanup scope body");
1015 });
1016 }
1017
1018#ifndef NDEBUG
1019 // Check that no block other than the last one in a region exits the region.
1020 static bool regionExitsOnlyFromLastBlock(mlir::Region &region) {
1021 for (mlir::Block &block : region) {
1022 if (&block == &region.back())
1023 continue;
1024 bool expectedTerminator =
1025 llvm::TypeSwitch<mlir::Operation *, bool>(block.getTerminator())
1026 // It is theoretically possible to have a cleanup block with
1027 // any of the following exits in non-final blocks, but we won't
1028 // currently generate any CIR that does that, and being able to
1029 // assume that it doesn't happen simplifies the implementation.
1030 // If we ever need to handle this case, the code will need to
1031 // be updated to handle it.
1032 .Case<cir::YieldOp, cir::ReturnOp, cir::ResumeFlatOp,
1033 cir::ContinueOp, cir::BreakOp, cir::GotoOp>(
1034 [](auto) { return false; })
1035 // We expect that call operations have not yet been rewritten
1036 // as try_call operations. A call can unwind out of the cleanup
1037 // scope, but we will be handling that during flattening. The
1038 // only case where a try_call could be present inside an
1039 // unflattened cleanup region is if the cleanup contained a
1040 // nested try-catch region, and that isn't expected as of the
1041 // time of this implementation. If it does, this could be
1042 // updated to tolerate it.
1043 .Case<cir::TryCallOp>([](auto) { return false; })
1044 // Likewise, we don't expect to find an EH dispatch operation
1045 // because we weren't expecting try-catch regions nested in the
1046 // cleanup region.
1047 .Case<cir::EhDispatchOp>([](auto) { return false; })
1048 // In theory, it would be possible to have a flattened switch
1049 // operation that does not exit the cleanup region. For now,
1050 // that's not happening.
1051 .Case<cir::SwitchFlatOp>([](auto) { return false; })
1052 // These aren't expected either, but if they occur, they don't
1053 // exit the region, so that's OK.
1054 .Case<cir::UnreachableOp, cir::TrapOp>([](auto) { return true; })
1055 // Indirect branches are not expected.
1056 .Case<cir::IndirectBrOp>([](auto) { return false; })
1057 // We do expect branches, but we don't expect them to leave
1058 // the region.
1059 .Case<cir::BrOp>([&](cir::BrOp brOp) {
1060 assert(brOp.getDest()->getParent() == &region &&
1061 "branch destination is not in the region");
1062 return true;
1063 })
1064 .Case<cir::BrCondOp>([&](cir::BrCondOp brCondOp) {
1065 assert(brCondOp.getDestTrue()->getParent() == &region &&
1066 "branch destination is not in the region");
1067 assert(brCondOp.getDestFalse()->getParent() == &region &&
1068 "branch destination is not in the region");
1069 return true;
1070 })
1071 // What else could there be?
1072 .Default([](mlir::Operation *) -> bool {
1073 llvm_unreachable("unexpected terminator in cleanup region");
1074 });
1075 if (!expectedTerminator)
1076 return false;
1077 }
1078 return true;
1079 }
1080#endif
1081
1082 // Build the EH cleanup block structure by cloning the cleanup region. The
1083 // cloned entry block gets an !cir.eh_token argument and a cir.begin_cleanup
1084 // inserted at the top. All cir.yield terminators that might exit the cleanup
1085 // region are replaced with cir.end_cleanup + cir.resume.
1086 //
1087 // For a single-block cleanup region, this produces:
1088 //
1089 // ^eh_cleanup(%eh_token : !cir.eh_token):
1090 // %ct = cir.begin_cleanup %eh_token : !cir.eh_token -> !cir.cleanup_token
1091 // <cloned cleanup operations>
1092 // cir.end_cleanup %ct : !cir.cleanup_token
1093 // cir.resume %eh_token : !cir.eh_token
1094 //
1095 // For a multi-block cleanup region (e.g. containing a flattened cir.if),
1096 // the same wrapping is applied around the cloned block structure: the entry
1097 // block gets begin_cleanup and all exit blocks (those terminated by yield)
1098 // get end_cleanup + resume.
1099 //
1100 // If this cleanup scope is nested within a TryOp, the resume will be updated
1101 // to branch to the catch dispatch block of the enclosing try operation when
1102 // the TryOp is flattened.
  mlir::Block *buildEHCleanupBlocks(cir::CleanupScopeOp cleanupOp,
                                    mlir::Location loc,
                                    mlir::Block *insertBefore,
                                    mlir::PatternRewriter &rewriter) const {
    assert(regionExitsOnlyFromLastBlock(cleanupOp.getCleanupRegion()) &&
           "cleanup region has exits in non-final blocks");

    // Track the block before the insertion point so we can find the cloned
    // blocks after cloning.
    mlir::Block *blockBeforeClone = insertBefore->getPrevNode();

    // Clone the entire cleanup region before insertBefore.
    rewriter.cloneRegionBefore(cleanupOp.getCleanupRegion(), insertBefore);

    // Find the first cloned block. If insertBefore was the region's first
    // block, the clones now begin at the region's front.
    mlir::Block *clonedEntry = blockBeforeClone
                                   ? blockBeforeClone->getNextNode()
                                   : &insertBefore->getParent()->front();

    // Add the eh_token argument to the cloned entry block and insert
    // begin_cleanup at the top.
    auto ehTokenType = cir::EhTokenType::get(rewriter.getContext());
    mlir::Value ehToken = clonedEntry->addArgument(ehTokenType, loc);

    rewriter.setInsertionPointToStart(clonedEntry);
    auto beginCleanup = cir::BeginCleanupOp::create(rewriter, loc, ehToken);

    // Replace the yield terminator in the last cloned block with
    // end_cleanup + resume.
    mlir::Block *lastClonedBlock = insertBefore->getPrevNode();
    auto yieldOp =
        mlir::dyn_cast<cir::YieldOp>(lastClonedBlock->getTerminator());
    if (yieldOp) {
      rewriter.setInsertionPoint(yieldOp);
      cir::EndCleanupOp::create(rewriter, loc, beginCleanup.getCleanupToken());
      rewriter.replaceOpWithNewOp<cir::ResumeOp>(yieldOp, ehToken);
    } else {
      // NOTE(review): this reports the error but still returns the cloned
      // entry with its original (non-resume) terminator intact; the caller
      // does not currently check for this condition.
      cleanupOp->emitError("Not yet implemented: cleanup region terminated "
                           "with non-yield operation");
    }

    return clonedEntry;
  }
1146
1147 // Flatten a cleanup scope. The body region's exits branch to the cleanup
1148 // block, and the cleanup block branches to destination blocks whose contents
1149 // depend on the type of operation that exited the body region. Yield becomes
1150 // a branch to the block after the cleanup scope, break and continue are
1151 // preserved for later lowering by enclosing switch or loop, and return
1152 // is preserved as is.
1153 //
1154 // If there are multiple exits from the cleanup body, a destination slot and
1155 // switch dispatch are used to continue to the correct destination after the
1156 // cleanup is complete. A destination slot alloca is created at the function
1157 // entry block. Each exit operation is replaced by a store of its unique ID to
1158 // the destination slot and a branch to cleanup. An operation is appended to
1159 // the to branch to a dispatch block that loads the destination slot and uses
1160 // switch.flat to branch to the correct destination.
1161 //
1162 // If the cleanup scope requires EH cleanup, any call operations in the body
1163 // that may throw are replaced with cir.try_call operations that unwind to an
1164 // EH cleanup block. The cleanup block(s) will be terminated with a cir.resume
1165 // operation. If this cleanup scope is enclosed by a try operation, the
1166 // flattening of the try operation flattening will replace the cir.resume with
1167 // a branch to a catch dispatch block. Otherwise, the cir.resume operation
1168 // remains in place and will unwind to the caller.
  mlir::LogicalResult
  flattenCleanup(cir::CleanupScopeOp cleanupOp,
                 llvm::SmallVectorImpl<CleanupExit> &exits,
                 llvm::SmallVectorImpl<cir::CallOp> &callsToRewrite,
                 llvm::SmallVectorImpl<cir::ResumeOp> &resumeOpsToChain,
                 mlir::PatternRewriter &rewriter) const {
    mlir::Location loc = cleanupOp.getLoc();
    cir::CleanupKind cleanupKind = cleanupOp.getCleanupKind();
    bool hasNormalCleanup = cleanupKind == cir::CleanupKind::Normal ||
                            cleanupKind == cir::CleanupKind::All;
    bool hasEHCleanup = cleanupKind == cir::CleanupKind::EH ||
                        cleanupKind == cir::CleanupKind::All;
    // Multiple exits require the destination-slot + switch dispatch scheme.
    bool isMultiExit = exits.size() > 1;

    // Get references to region blocks before inlining.
    mlir::Block *bodyEntry = &cleanupOp.getBodyRegion().front();
    mlir::Block *cleanupEntry = &cleanupOp.getCleanupRegion().front();
    mlir::Block *cleanupExit = &cleanupOp.getCleanupRegion().back();
    assert(regionExitsOnlyFromLastBlock(cleanupOp.getCleanupRegion()) &&
           "cleanup region has exits in non-final blocks");
    auto cleanupYield = dyn_cast<cir::YieldOp>(cleanupExit->getTerminator());
    if (!cleanupYield) {
      // Bail out before any IR modification so the match failure is clean.
      return rewriter.notifyMatchFailure(cleanupOp,
                                         "Not yet implemented: cleanup region "
                                         "terminated with non-yield operation");
    }

    // For multiple exits from the body region, get or create a destination slot
    // at function entry. The slot is shared across all cleanup scopes in the
    // function. This is only needed if the cleanup scope requires normal
    // cleanup.
    cir::AllocaOp destSlot;
    if (isMultiExit && hasNormalCleanup) {
      auto funcOp = cleanupOp->getParentOfType<cir::FuncOp>();
      if (!funcOp)
        return cleanupOp->emitError("cleanup scope not inside a function");
      destSlot = getOrCreateCleanupDestSlot(funcOp, rewriter, loc);
    }

    // Split the current block to create the insertion point.
    mlir::Block *currentBlock = rewriter.getInsertionBlock();
    mlir::Block *continueBlock =
        rewriter.splitBlock(currentBlock, rewriter.getInsertionPoint());

    // Build EH cleanup blocks if needed. This must be done before inlining
    // the cleanup region since buildEHCleanupBlocks clones from it. The unwind
    // block is inserted before the EH cleanup entry so that the final layout
    // is: body -> normal cleanup -> exit -> unwind -> EH cleanup -> continue.
    // EH cleanup blocks are needed when there are throwing calls that need to
    // be rewritten to try_call, or when there are resume ops from
    // already-flattened inner cleanup scopes that need to chain through this
    // cleanup's EH handler.
    mlir::Block *unwindBlock = nullptr;
    mlir::Block *ehCleanupEntry = nullptr;
    if (hasEHCleanup &&
        (!callsToRewrite.empty() || !resumeOpsToChain.empty())) {
      ehCleanupEntry =
          buildEHCleanupBlocks(cleanupOp, loc, continueBlock, rewriter);
      // The unwind block is only needed when there are throwing calls that
      // need a shared unwind destination. Resume ops from inner cleanups
      // branch directly to the EH cleanup entry.
      if (!callsToRewrite.empty())
        unwindBlock = buildUnwindBlock(ehCleanupEntry, /*isCleanupOnly=*/true,
                                       loc, ehCleanupEntry, rewriter);
    }

    // All normal flow blocks are inserted before this point — either before
    // the unwind block (if it exists), or before the EH cleanup entry (if EH
    // cleanup exists but no unwind block is needed), or before the continue
    // block.
    mlir::Block *normalInsertPt =
        unwindBlock ? unwindBlock
                    : (ehCleanupEntry ? ehCleanupEntry : continueBlock);

    // Inline the body region.
    rewriter.inlineRegionBefore(cleanupOp.getBodyRegion(), normalInsertPt);

    // Inline the cleanup region for the normal cleanup path.
    if (hasNormalCleanup)
      rewriter.inlineRegionBefore(cleanupOp.getCleanupRegion(), normalInsertPt);

    // Branch from current block to body entry.
    rewriter.setInsertionPointToEnd(currentBlock);
    cir::BrOp::create(rewriter, loc, bodyEntry);

    // Handle normal exits.
    mlir::LogicalResult result = mlir::success();
    if (hasNormalCleanup) {
      // Create the exit/dispatch block (after cleanup, before continue).
      mlir::Block *exitBlock = rewriter.createBlock(normalInsertPt);

      // Rewrite the cleanup region's yield to branch to exit block.
      rewriter.setInsertionPoint(cleanupYield);
      rewriter.replaceOpWithNewOp<cir::BrOp>(cleanupYield, exitBlock);

      if (isMultiExit) {
        // Build the dispatch switch in the exit block.
        rewriter.setInsertionPointToEnd(exitBlock);

        // Load the destination slot value.
        auto slotValue = cir::LoadOp::create(
            rewriter, loc, destSlot, /*isDeref=*/false,
            /*isVolatile=*/false, /*alignment=*/mlir::IntegerAttr(),
            cir::SyncScopeKindAttr(), cir::MemOrderAttr());

        // Create destination blocks for each exit and collect switch case info.
        llvm::SmallVector<mlir::APInt, 8> caseValues;
        llvm::SmallVector<mlir::Block *, 8> caseDestinations;
        llvm::SmallVector<mlir::ValueRange, 8> caseOperands;
        cir::IntType s32Type =
            cir::IntType::get(rewriter.getContext(), 32, /*isSigned=*/true);

        for (const CleanupExit &exit : exits) {
          // Create a block for this destination.
          mlir::Block *destBlock = rewriter.createBlock(normalInsertPt);
          rewriter.setInsertionPointToEnd(destBlock);
          result =
              createExitTerminator(exit.exitOp, loc, continueBlock, rewriter);

          // Add to switch cases.
          caseValues.push_back(
              llvm::APInt(32, static_cast<uint64_t>(exit.destinationId), true));
          caseDestinations.push_back(destBlock);
          caseOperands.push_back(mlir::ValueRange());

          // Replace the original exit op with: store dest ID, branch to
          // cleanup.
          rewriter.setInsertionPoint(exit.exitOp);
          auto destIdConst = cir::ConstantOp::create(
              rewriter, loc, cir::IntAttr::get(s32Type, exit.destinationId));
          cir::StoreOp::create(rewriter, loc, destIdConst, destSlot,
                               /*isVolatile=*/false,
                               /*alignment=*/mlir::IntegerAttr(),
                               cir::SyncScopeKindAttr(), cir::MemOrderAttr());
          rewriter.replaceOpWithNewOp<cir::BrOp>(exit.exitOp, cleanupEntry);

          // If the exit terminator creation failed, we're going to end up with
          // partially flattened code, but we'll also have reported an error so
          // that's OK. We need to finish out this function to keep the IR in a
          // valid state to help diagnose the error. This is a temporary
          // possibility during development. It shouldn't ever happen after the
          // implementation is complete.
          if (result.failed())
            break;
        }

        // Create the default destination (unreachable).
        mlir::Block *defaultBlock = rewriter.createBlock(normalInsertPt);
        rewriter.setInsertionPointToEnd(defaultBlock);
        cir::UnreachableOp::create(rewriter, loc);

        // Build the switch.flat operation in the exit block.
        rewriter.setInsertionPointToEnd(exitBlock);
        cir::SwitchFlatOp::create(rewriter, loc, slotValue, defaultBlock,
                                  mlir::ValueRange(), caseValues,
                                  caseDestinations, caseOperands);
      } else {
        // Single exit: put the appropriate terminator directly in the exit
        // block. No destination slot or dispatch switch is needed.
        rewriter.setInsertionPointToEnd(exitBlock);
        mlir::Operation *exitOp = exits[0].exitOp;
        result = createExitTerminator(exitOp, loc, continueBlock, rewriter);

        // Replace body exit with branch to cleanup entry.
        rewriter.setInsertionPoint(exitOp);
        rewriter.replaceOpWithNewOp<cir::BrOp>(exitOp, cleanupEntry);
      }
    } else {
      // EH-only cleanup: normal exits skip the cleanup entirely.
      // Replace yield exits with branches to the continue block.
      for (CleanupExit &exit : exits) {
        if (isa<cir::YieldOp>(exit.exitOp)) {
          rewriter.setInsertionPoint(exit.exitOp);
          rewriter.replaceOpWithNewOp<cir::BrOp>(exit.exitOp, continueBlock);
        }
        // Non-yield exits (break, continue, return) stay as-is since no normal
        // cleanup is needed.
      }
    }

    // Replace non-nothrow calls with try_call operations. All calls within
    // this cleanup scope share the same unwind destination.
    if (hasEHCleanup) {
      for (cir::CallOp callOp : callsToRewrite)
        replaceCallWithTryCall(callOp, unwindBlock, loc, rewriter);
    }

    // Handle throwing calls in EH cleanup blocks. When an exception is thrown
    // during cleanup code that runs on the exception unwind path, the C++
    // standard requires that std::terminate() be called. Replace such calls
    // with try_call operations that unwind to a terminate block containing
    // cir.eh.initiate + cir.eh.terminate.
    if (ehCleanupEntry) {
      llvm::SmallVector<cir::CallOp> ehCleanupThrowingCalls;
      // The EH cleanup blocks span from ehCleanupEntry up to (but not
      // including) continueBlock in the block list.
      for (mlir::Block *block = ehCleanupEntry; block != continueBlock;
           block = block->getNextNode()) {
        block->walk([&](cir::CallOp callOp) {
          if (!callOp.getNothrow())
            ehCleanupThrowingCalls.push_back(callOp);
        });
      }
      if (!ehCleanupThrowingCalls.empty()) {
        mlir::Block *terminateBlock =
            buildTerminateUnwindBlock(loc, continueBlock, rewriter);
        for (cir::CallOp callOp : ehCleanupThrowingCalls)
          replaceCallWithTryCall(callOp, terminateBlock, loc, rewriter);
      }
    }

    // Chain inner EH cleanup resume ops to this cleanup's EH handler.
    // Each cir.resume from an already-flattened inner cleanup is replaced
    // with a branch to the outer EH cleanup entry, passing the eh_token
    // from the inner's begin_cleanup so that the same in-flight exception
    // flows through the outer cleanup before unwinding to the caller.
    if (ehCleanupEntry) {
      for (cir::ResumeOp resumeOp : resumeOpsToChain) {
        mlir::Value ehToken = resumeOp.getEhToken();
        rewriter.setInsertionPoint(resumeOp);
        rewriter.replaceOpWithNewOp<cir::BrOp>(
            resumeOp, mlir::ValueRange{ehToken}, ehCleanupEntry);
      }
    }

    // Erase the original cleanup scope op.
    rewriter.eraseOp(cleanupOp);

    // Always return success because the IR has been modified (blocks split,
    // regions inlined, ops erased, etc.). The MLIR pattern rewriter contract
    // requires that if a pattern modifies IR, it must return success(). Any
    // errors from unsupported exit operations (e.g. goto) have already been
    // reported via emitError and an unreachable terminator was placed as a
    // placeholder.
    return mlir::success();
  }
1403
1404 mlir::LogicalResult
1405 matchAndRewrite(cir::CleanupScopeOp cleanupOp,
1406 mlir::PatternRewriter &rewriter) const override {
1407 mlir::OpBuilder::InsertionGuard guard(rewriter);
1408
1409 // Nested cleanup scopes and try operations must be flattened before the
1410 // enclosing cleanup scope so that EH cleanup inside them is properly
1411 // handled. Fail the match so the pattern rewriter processes them first.
1412 //
1413 // Before checking, erase any trivially dead nested cleanup scopes. These
1414 // arise from deactivated cleanups (e.g. partial-construction guards for
1415 // lambda captures). The greedy rewriter may have already DCE'd them, but
1416 // when a trivially dead nested op is erased first, the parent isn't always
1417 // re-added to the worklist, so we handle it here. These types of operations
1418 // will normally be removed by the canonicalizer, but we handle it here
1419 // also, because DCE can run between pattern matches in the current pass,
1420 // and if a trivially dead operation makes it this far, we will fail.
1421 llvm::SmallVector<cir::CleanupScopeOp> deadNestedOps;
1422 cleanupOp.getBodyRegion().walk([&](cir::CleanupScopeOp nested) {
1423 if (mlir::isOpTriviallyDead(nested))
1424 deadNestedOps.push_back(nested);
1425 });
1426 for (auto op : deadNestedOps)
1427 rewriter.eraseOp(op);
1428
1429 bool hasNestedOps = cleanupOp.getBodyRegion()
1430 .walk([&](mlir::Operation *op) {
1431 if (isa<cir::CleanupScopeOp, cir::TryOp>(op))
1432 return mlir::WalkResult::interrupt();
1433 return mlir::WalkResult::advance();
1434 })
1435 .wasInterrupted();
1436 if (hasNestedOps)
1437 return mlir::failure();
1438
1439 cir::CleanupKind cleanupKind = cleanupOp.getCleanupKind();
1440
1441 // Collect all exits from the body region.
1442 llvm::SmallVector<CleanupExit> exits;
1443 int nextId = 0;
1444 collectExits(cleanupOp.getBodyRegion(), exits, nextId);
1445
1446 assert(!exits.empty() && "cleanup scope body has no exit");
1447
1448 // Collect non-nothrow calls that need to be converted to try_call.
1449 // This is only needed for EH and All cleanup kinds, but the vector
1450 // will simply be empty for Normal cleanup.
1451 llvm::SmallVector<cir::CallOp> callsToRewrite;
1452 if (cleanupKind != cir::CleanupKind::Normal)
1453 collectThrowingCalls(cleanupOp.getBodyRegion(), callsToRewrite);
1454
1455 // Collect resume ops from already-flattened inner cleanup scopes that
1456 // need to chain through this cleanup's EH handler.
1457 llvm::SmallVector<cir::ResumeOp> resumeOpsToChain;
1458 if (cleanupKind != cir::CleanupKind::Normal)
1459 collectResumeOps(cleanupOp.getBodyRegion(), resumeOpsToChain);
1460
1461 return flattenCleanup(cleanupOp, exits, callsToRewrite, resumeOpsToChain,
1462 rewriter);
1463 }
1464};
1465
1466// Trace an !cir.eh_token value back through block arguments to find the
1467// cir.eh.initiate operation that defines it. Returns {} if the defining op
1468// cannot be found (e.g. multiple predecessors).
1469static cir::EhInitiateOp traceToEhInitiate(mlir::Value ehToken) {
1470 while (ehToken) {
1471 if (auto initiate = ehToken.getDefiningOp<cir::EhInitiateOp>())
1472 return initiate;
1473 auto blockArg = mlir::dyn_cast<mlir::BlockArgument>(ehToken);
1474 if (!blockArg)
1475 return {};
1476 mlir::Block *pred = blockArg.getOwner()->getSinglePredecessor();
1477 if (!pred)
1478 return {};
1479 auto brOp = mlir::dyn_cast<cir::BrOp>(pred->getTerminator());
1480 if (!brOp)
1481 return {};
1482 ehToken = brOp.getDestOperands()[blockArg.getArgNumber()];
1483 }
1484 return {};
1485}
1486
/// Flattens cir.try into explicit CFG form: the try body is inlined into the
/// parent region, throwing calls become cir.try_call ops with an unwind
/// destination, catch handlers become plain blocks, and a cir.eh.dispatch
/// block routes the exception token to the matching handler.
class CIRTryOpFlattening : public mlir::OpRewritePattern<cir::TryOp> {
public:
  using OpRewritePattern<cir::TryOp>::OpRewritePattern;

  // Build the catch dispatch block with a cir.eh.dispatch operation.
  // The dispatch block receives an !cir.eh_token argument and dispatches
  // to the appropriate catch handler blocks based on exception types.
  mlir::Block *buildCatchDispatchBlock(
      cir::TryOp tryOp, mlir::ArrayAttr handlerTypes,
      llvm::SmallVectorImpl<mlir::Block *> &catchHandlerBlocks,
      mlir::Location loc, mlir::Block *insertBefore,
      mlir::PatternRewriter &rewriter) const {
    mlir::Block *dispatchBlock = rewriter.createBlock(insertBefore);
    auto ehTokenType = cir::EhTokenType::get(rewriter.getContext());
    mlir::Value ehToken = dispatchBlock->addArgument(ehTokenType, loc);

    rewriter.setInsertionPointToEnd(dispatchBlock);

    // Build the catch types and destinations for the dispatch.
    llvm::SmallVector<mlir::Attribute> catchTypeAttrs;
    llvm::SmallVector<mlir::Block *> catchDests;
    mlir::Block *defaultDest = nullptr;
    bool defaultIsCatchAll = false;

    // Partition the handlers: typed catches become (type, dest) pairs; at
    // most one catch_all or unwind handler becomes the default destination.
    for (auto [typeAttr, handlerBlock] :
         llvm::zip(handlerTypes, catchHandlerBlocks)) {
      if (mlir::isa<cir::CatchAllAttr>(typeAttr)) {
        assert(!defaultDest && "multiple catch_all or unwind handlers");
        defaultDest = handlerBlock;
        defaultIsCatchAll = true;
      } else if (mlir::isa<cir::UnwindAttr>(typeAttr)) {
        assert(!defaultDest && "multiple catch_all or unwind handlers");
        defaultDest = handlerBlock;
        defaultIsCatchAll = false;
      } else {
        // This is a typed catch handler (GlobalViewAttr with type info).
        catchTypeAttrs.push_back(typeAttr);
        catchDests.push_back(handlerBlock);
      }
    }

    assert(defaultDest && "dispatch must have a catch_all or unwind handler");

    // Leave the array attribute null (not an empty array) when there are no
    // typed catches.
    mlir::ArrayAttr catchTypesArrayAttr;
    if (!catchTypeAttrs.empty())
      catchTypesArrayAttr = rewriter.getArrayAttr(catchTypeAttrs);

    cir::EhDispatchOp::create(rewriter, loc, ehToken, catchTypesArrayAttr,
                              defaultIsCatchAll, defaultDest, catchDests);

    return dispatchBlock;
  }

  // Flatten a single catch handler region. Each handler region has an
  // !cir.eh_token argument and starts with cir.begin_catch, followed by
  // a cir.cleanup.scope containing the handler body (with cir.end_catch in
  // its cleanup region), and ending with cir.yield.
  //
  // After flattening, the handler region becomes a block that receives the
  // eh_token, calls begin_catch, runs the handler body inline, calls
  // end_catch, and branches to the continue block.
  //
  // The cleanup scope inside the catch handler is expected to have been
  // flattened before we get here, so what we see in the handler region is
  // already flat code with begin_catch at the top and end_catch in any place
  // that we would exit the catch handler. We just need to inline the region
  // and fix up terminators.
  mlir::Block *flattenCatchHandler(mlir::Region &handlerRegion,
                                   mlir::Block *continueBlock,
                                   mlir::Location loc,
                                   mlir::Block *insertBefore,
                                   mlir::PatternRewriter &rewriter) const {
    // The handler region entry block has the !cir.eh_token argument.
    mlir::Block *handlerEntry = &handlerRegion.front();

    // Inline the handler region before insertBefore.
    rewriter.inlineRegionBefore(handlerRegion, insertBefore);

    // Replace yield terminators in the handler with branches to continue.
    // The iterator range covers exactly the blocks just inlined: from the
    // former entry up to (but not including) insertBefore.
    for (mlir::Block &block : llvm::make_range(handlerEntry->getIterator(),
                                               insertBefore->getIterator())) {
      if (auto yieldOp = dyn_cast<cir::YieldOp>(block.getTerminator())) {
        // Verify that end_catch is the last non-branch operation before
        // this yield. After cleanup scope flattening, end_catch may be in
        // a predecessor block rather than immediately before the yield.
        // Walk back through the single-predecessor chain, verifying that
        // each intermediate block contains only a branch terminator, until
        // we find end_catch as the last non-terminator in some block.
        // (Debug-only structural check; compiled out in release builds.)
        assert([&]() {
          // Check if end_catch immediately precedes the yield.
          if (mlir::Operation *prev = yieldOp->getPrevNode())
            return isa<cir::EndCatchOp>(prev);
          // The yield is alone in its block. Walk backward through
          // single-predecessor blocks that contain only a branch.
          mlir::Block *b = block.getSinglePredecessor();
          while (b) {
            mlir::Operation *term = b->getTerminator();
            if (mlir::Operation *prev = term->getPrevNode())
              return isa<cir::EndCatchOp>(prev);
            if (!isa<cir::BrOp>(term))
              return false;
            b = b->getSinglePredecessor();
          }
          return false;
        }() && "expected end_catch as last operation before yield "
               "in catch handler, with only branches in between");
        rewriter.setInsertionPoint(yieldOp);
        rewriter.replaceOpWithNewOp<cir::BrOp>(yieldOp, continueBlock);
      }
    }

    return handlerEntry;
  }

  // Flatten an unwind handler region. The unwind region just contains a
  // cir.resume that continues unwinding. We inline it and leave the resume
  // in place. If this try op is nested inside an EH cleanup or another try op,
  // the enclosing op will rewrite the resume as a branch to its cleanup or
  // dispatch block when it is flattened. Otherwise, the resume will unwind to
  // the caller.
  mlir::Block *flattenUnwindHandler(mlir::Region &unwindRegion,
                                    mlir::Location loc,
                                    mlir::Block *insertBefore,
                                    mlir::PatternRewriter &rewriter) const {
    mlir::Block *unwindEntry = &unwindRegion.front();
    rewriter.inlineRegionBefore(unwindRegion, insertBefore);
    return unwindEntry;
  }

  mlir::LogicalResult
  matchAndRewrite(cir::TryOp tryOp,
                  mlir::PatternRewriter &rewriter) const override {
    // Nested try ops and cleanup scopes must be flattened before the enclosing
    // try so that EH cleanup inside them is properly handled. Fail the match so
    // the pattern rewriter will process nested ops first.
    bool hasNestedOps =
        tryOp
            ->walk([&](mlir::Operation *op) {
              if (isa<cir::CleanupScopeOp, cir::TryOp>(op) && op != tryOp)
                return mlir::WalkResult::interrupt();
              return mlir::WalkResult::advance();
            })
            .wasInterrupted();
    if (hasNestedOps)
      return mlir::failure();

    mlir::OpBuilder::InsertionGuard guard(rewriter);
    mlir::Location loc = tryOp.getLoc();

    mlir::ArrayAttr handlerTypes = tryOp.getHandlerTypesAttr();
    mlir::MutableArrayRef<mlir::Region> handlerRegions =
        tryOp.getHandlerRegions();

    // Collect throwing calls in the try body.
    // NOTE(review): collectThrowingCalls/collectResumeOps/buildUnwindBlock/
    // replaceCallWithTryCall are helpers presumably defined earlier in this
    // file — not visible here.
    llvm::SmallVector<cir::CallOp> callsToRewrite;
    collectThrowingCalls(tryOp.getTryRegion(), callsToRewrite);

    // Collect resume ops from already-flattened cleanup scopes in the try body.
    llvm::SmallVector<cir::ResumeOp> resumeOpsToChain;
    collectResumeOps(tryOp.getTryRegion(), resumeOpsToChain);

    // Split the current block and inline the try body.
    mlir::Block *currentBlock = rewriter.getInsertionBlock();
    mlir::Block *continueBlock =
        rewriter.splitBlock(currentBlock, rewriter.getInsertionPoint());

    // Get references to try body blocks before inlining.
    mlir::Block *bodyEntry = &tryOp.getTryRegion().front();
    mlir::Block *bodyExit = &tryOp.getTryRegion().back();

    // Inline the try body region before the continue block.
    rewriter.inlineRegionBefore(tryOp.getTryRegion(), continueBlock);

    // Branch from the current block to the body entry.
    rewriter.setInsertionPointToEnd(currentBlock);
    cir::BrOp::create(rewriter, loc, bodyEntry);

    // Replace the try body's yield terminator with a branch to continue.
    if (auto bodyYield = dyn_cast<cir::YieldOp>(bodyExit->getTerminator())) {
      rewriter.setInsertionPoint(bodyYield);
      rewriter.replaceOpWithNewOp<cir::BrOp>(bodyYield, continueBlock);
    }

    // If there are no handlers, we're done.
    if (!handlerTypes || handlerTypes.empty()) {
      rewriter.eraseOp(tryOp);
      return mlir::success();
    }

    // If there are no throwing calls and no resume ops from inner cleanup
    // scopes, exceptions cannot reach the catch handlers. Skip handler and
    // dispatch block creation — the handler regions will be dropped when
    // the try op is erased.
    if (callsToRewrite.empty() && resumeOpsToChain.empty()) {
      rewriter.eraseOp(tryOp);
      return mlir::success();
    }

    // Build the catch handler blocks.

    // First, flatten all handler regions and collect the entry blocks.
    llvm::SmallVector<mlir::Block *> catchHandlerBlocks;

    for (const auto &[idx, typeAttr] : llvm::enumerate(handlerTypes)) {
      mlir::Region &handlerRegion = handlerRegions[idx];

      if (mlir::isa<cir::UnwindAttr>(typeAttr)) {
        mlir::Block *unwindEntry =
            flattenUnwindHandler(handlerRegion, loc, continueBlock, rewriter);
        catchHandlerBlocks.push_back(unwindEntry);
      } else {
        mlir::Block *handlerEntry = flattenCatchHandler(
            handlerRegion, continueBlock, loc, continueBlock, rewriter);
        catchHandlerBlocks.push_back(handlerEntry);
      }
    }

    // Build the catch dispatch block. It is inserted just before the first
    // handler block so the dispatch precedes its destinations.
    mlir::Block *dispatchBlock =
        buildCatchDispatchBlock(tryOp, handlerTypes, catchHandlerBlocks, loc,
                                catchHandlerBlocks.front(), rewriter);

    // Check whether the try has a catch-all handler. When catch-all is
    // present, the personality function will always stop unwinding at this
    // frame (because catch-all matches every exception type). The LLVM
    // landingpad therefore needs "catch ptr null" rather than "cleanup".
    // The downstream pipeline (EHABILowering + LowerToLLVM) emits
    // "catch ptr null" when the EhInitiateOp has neither cleanup nor typed
    // catch types, so we clear the cleanup flag on every EhInitiateOp that
    // feeds into a dispatch with a catch-all handler.
    bool hasCatchAll =
        handlerTypes && llvm::any_of(handlerTypes, [](mlir::Attribute attr) {
          return mlir::isa<cir::CatchAllAttr>(attr);
        });

    // Build a block to be the unwind destination for throwing calls and
    // replace the calls with try_call ops. Note that the unwind block created
    // here is something different than the unwind handler that we may have
    // created above. The unwind handler continues unwinding after uncaught
    // exceptions. This is the block that will eventually become the landing
    // pad for invoke instructions.
    bool isCleanupOnly = tryOp.getCleanup() && !hasCatchAll;
    if (!callsToRewrite.empty()) {
      // Create a shared unwind block for all throwing calls.
      mlir::Block *unwindBlock = buildUnwindBlock(dispatchBlock, isCleanupOnly,
                                                  loc, dispatchBlock, rewriter);

      for (cir::CallOp callOp : callsToRewrite)
        replaceCallWithTryCall(callOp, unwindBlock, loc, rewriter);
    }

    // Chain resume ops from inner cleanup scopes.
    // Resume ops from already-flattened cleanup scopes within the try body
    // should branch to the catch dispatch block instead of unwinding directly.
    for (cir::ResumeOp resumeOp : resumeOpsToChain) {
      // When there is a catch-all handler, clear the cleanup flag on the
      // cir.eh.initiate that produced this token. With catch-all, the LLVM
      // landingpad needs "catch ptr null" instead of "cleanup".
      if (hasCatchAll) {
        if (auto ehInitiate = traceToEhInitiate(resumeOp.getEhToken()))
          ehInitiate.removeCleanupAttr();
      }

      mlir::Value ehToken = resumeOp.getEhToken();
      rewriter.setInsertionPoint(resumeOp);
      rewriter.replaceOpWithNewOp<cir::BrOp>(
          resumeOp, mlir::ValueRange{ehToken}, dispatchBlock);
    }

    // Finally, erase the original try op.
    rewriter.eraseOp(tryOp);

    return mlir::success();
  }
};
1762
1763void populateFlattenCFGPatterns(RewritePatternSet &patterns) {
1764 patterns
1765 .add<CIRIfFlattening, CIRLoopOpInterfaceFlattening, CIRScopeOpFlattening,
1766 CIRSwitchOpFlattening, CIRTernaryOpFlattening,
1767 CIRCleanupScopeOpFlattening, CIRTryOpFlattening>(
1768 patterns.getContext());
1769}
1770
1771void CIRFlattenCFGPass::runOnOperation() {
1772 RewritePatternSet patterns(&getContext());
1773 populateFlattenCFGPatterns(patterns);
1774
1775 // Collect operations to apply patterns.
1776 llvm::SmallVector<Operation *, 16> ops;
1777 getOperation()->walk<mlir::WalkOrder::PostOrder>([&](Operation *op) {
1778 if (isa<IfOp, ScopeOp, SwitchOp, LoopOpInterface, TernaryOp, CleanupScopeOp,
1779 TryOp>(op))
1780 ops.push_back(op);
1781 });
1782
1783 // Apply patterns.
1784 if (applyOpPatternsGreedily(ops, std::move(patterns)).failed())
1785 signalPassFailure();
1786}
1787
1788} // namespace
1789
1790namespace mlir {
1791
1792std::unique_ptr<Pass> createCIRFlattenCFGPass() {
1793 return std::make_unique<CIRFlattenCFGPass>();
1794}
1795
1796} // namespace mlir
__device__ __2f16 b
const internal::VariadicAllOfMatcher< Attr > attr
llvm::APInt APInt
Definition FixedPoint.h:19
ASTEdit insertBefore(RangeSelector S, TextGenerator Replacement)
Inserts Replacement before S, leaving the source selected by \S unchanged.
unsigned long uint64_t
std::unique_ptr< Pass > createCIRFlattenCFGPass()
int const char * function
Definition c++config.h:31
float __ovld __cnfn step(float, float)
Returns 0.0 if x < edge, otherwise it returns 1.0.
static bool stackSaveOp()