// FlattenCFG.cpp — part of the ClangIR (CIR) dialect transforms (clang 23.0.0git).
// NOTE: this file was extracted from a Doxygen-rendered page; the original
// page navigation text has been replaced by this comment.
1//===----------------------------------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements pass that inlines CIR operations regions into the parent
10// function region.
11//
12//===----------------------------------------------------------------------===//
13
14#include "PassDetail.h"
15#include "mlir/Dialect/Func/IR/FuncOps.h"
16#include "mlir/IR/Block.h"
17#include "mlir/IR/Builders.h"
18#include "mlir/IR/PatternMatch.h"
19#include "mlir/Support/LogicalResult.h"
20#include "mlir/Transforms/DialectConversion.h"
21#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
26#include "llvm/ADT/TypeSwitch.h"
27
28using namespace mlir;
29using namespace cir;
30
31namespace mlir {
32#define GEN_PASS_DEF_CIRFLATTENCFG
33#include "clang/CIR/Dialect/Passes.h.inc"
34} // namespace mlir
35
36namespace {
37
38/// Lowers operations with the terminator trait that have a single successor.
39void lowerTerminator(mlir::Operation *op, mlir::Block *dest,
40 mlir::PatternRewriter &rewriter) {
41 assert(op->hasTrait<mlir::OpTrait::IsTerminator>() && "not a terminator");
42 mlir::OpBuilder::InsertionGuard guard(rewriter);
43 rewriter.setInsertionPoint(op);
44 rewriter.replaceOpWithNewOp<cir::BrOp>(op, dest);
45}
46
47/// Walks a region while skipping operations of type `Ops`. This ensures the
48/// callback is not applied to said operations and its children.
49template <typename... Ops>
50void walkRegionSkipping(
51 mlir::Region &region,
52 mlir::function_ref<mlir::WalkResult(mlir::Operation *)> callback) {
53 region.walk<mlir::WalkOrder::PreOrder>([&](mlir::Operation *op) {
54 if (isa<Ops...>(op))
55 return mlir::WalkResult::skip();
56 return callback(op);
57 });
58}
59
/// Pass that flattens CIR structured control-flow operations into an explicit
/// CFG of blocks and branches. The pattern set is applied in
/// runOnOperation(), defined elsewhere in this file.
struct CIRFlattenCFGPass : public impl::CIRFlattenCFGBase<CIRFlattenCFGPass> {

  CIRFlattenCFGPass() = default;
  void runOnOperation() override;
};
65
/// Flattens cir.if into explicit control flow: the block containing the if is
/// split, the then/else regions are inlined before the continuation block,
/// their cir.yield terminators become branches to the continuation, and the
/// original block ends in a cir.brcond selecting between the two bodies.
struct CIRIfFlattening : public mlir::OpRewritePattern<cir::IfOp> {
  using OpRewritePattern<IfOp>::OpRewritePattern;

  mlir::LogicalResult
  matchAndRewrite(cir::IfOp ifOp,
                  mlir::PatternRewriter &rewriter) const override {
    mlir::OpBuilder::InsertionGuard guard(rewriter);
    mlir::Location loc = ifOp.getLoc();
    bool emptyElse = ifOp.getElseRegion().empty();
    mlir::Block *currentBlock = rewriter.getInsertionBlock();
    // Everything after the cir.if moves into its own block, which serves as
    // the join point once both branches are inlined.
    mlir::Block *remainingOpsBlock =
        rewriter.splitBlock(currentBlock, rewriter.getInsertionPoint());
    mlir::Block *continueBlock;
    if (ifOp->getResults().empty())
      continueBlock = remainingOpsBlock;
    else
      llvm_unreachable("NYI"); // Value-producing cir.if not implemented yet.

    // Inline the region
    mlir::Block *thenBeforeBody = &ifOp.getThenRegion().front();
    mlir::Block *thenAfterBody = &ifOp.getThenRegion().back();
    rewriter.inlineRegionBefore(ifOp.getThenRegion(), continueBlock);

    rewriter.setInsertionPointToEnd(thenAfterBody);
    // A then-region ending in cir.yield falls through to the continuation.
    // Any other terminator (e.g. return/unreachable) is left in place.
    if (auto thenYieldOp =
            dyn_cast<cir::YieldOp>(thenAfterBody->getTerminator())) {
      rewriter.replaceOpWithNewOp<cir::BrOp>(thenYieldOp, thenYieldOp.getArgs(),
                                             continueBlock);
    }

    rewriter.setInsertionPointToEnd(continueBlock);

    // Has else region: inline it.
    mlir::Block *elseBeforeBody = nullptr;
    mlir::Block *elseAfterBody = nullptr;
    if (!emptyElse) {
      elseBeforeBody = &ifOp.getElseRegion().front();
      elseAfterBody = &ifOp.getElseRegion().back();
      rewriter.inlineRegionBefore(ifOp.getElseRegion(), continueBlock);
    } else {
      // No else: the false edge of the brcond targets the continuation
      // block directly.
      elseBeforeBody = elseAfterBody = continueBlock;
    }

    rewriter.setInsertionPointToEnd(currentBlock);
    cir::BrCondOp::create(rewriter, loc, ifOp.getCondition(), thenBeforeBody,
                          elseBeforeBody);

    if (!emptyElse) {
      rewriter.setInsertionPointToEnd(elseAfterBody);
      if (auto elseYieldOP =
              dyn_cast<cir::YieldOp>(elseAfterBody->getTerminator())) {
        rewriter.replaceOpWithNewOp<cir::BrOp>(
            elseYieldOP, elseYieldOP.getArgs(), continueBlock);
      }
    }

    rewriter.replaceOp(ifOp, continueBlock->getArguments());
    return mlir::success();
  }
};
126
/// Flattens cir.scope by inlining its body region into the parent region:
/// the current block branches into the body, and the body's cir.yield is
/// rewritten into a branch to the continuation block. Empty scopes are
/// simply erased.
class CIRScopeOpFlattening : public mlir::OpRewritePattern<cir::ScopeOp> {
public:
  using OpRewritePattern<cir::ScopeOp>::OpRewritePattern;

  mlir::LogicalResult
  matchAndRewrite(cir::ScopeOp scopeOp,
                  mlir::PatternRewriter &rewriter) const override {
    mlir::OpBuilder::InsertionGuard guard(rewriter);
    mlir::Location loc = scopeOp.getLoc();

    // Empty scope: just remove it.
    // TODO: Remove this logic once CIR uses MLIR infrastructure to remove
    // trivially dead operations. MLIR canonicalizer is too aggressive and we
    // need to either (a) make sure all our ops model all side-effects and/or
    // (b) have more options in the canonicalizer in MLIR to temper
    // aggressiveness level.
    if (scopeOp.isEmpty()) {
      rewriter.eraseOp(scopeOp);
      return mlir::success();
    }

    // Split the current block before the ScopeOp to create the inlining
    // point.
    mlir::Block *currentBlock = rewriter.getInsertionBlock();
    mlir::Block *continueBlock =
        rewriter.splitBlock(currentBlock, rewriter.getInsertionPoint());
    // Scope results become block arguments of the continuation block.
    if (scopeOp.getNumResults() > 0)
      continueBlock->addArguments(scopeOp.getResultTypes(), loc);

    // Inline body region.
    mlir::Block *beforeBody = &scopeOp.getScopeRegion().front();
    mlir::Block *afterBody = &scopeOp.getScopeRegion().back();
    rewriter.inlineRegionBefore(scopeOp.getScopeRegion(), continueBlock);

    // Save stack and then branch into the body of the region.
    rewriter.setInsertionPointToEnd(currentBlock);
    cir::BrOp::create(rewriter, loc, mlir::ValueRange(), beforeBody);

    // Replace the scopeop return with a branch that jumps out of the body.
    // Stack restore before leaving the body region.
    // Non-yield terminators (e.g. return) are left untouched.
    rewriter.setInsertionPointToEnd(afterBody);
    if (auto yieldOp = dyn_cast<cir::YieldOp>(afterBody->getTerminator())) {
      rewriter.replaceOpWithNewOp<cir::BrOp>(yieldOp, yieldOp.getArgs(),
                                             continueBlock);
    }

    // Replace the op with values return from the body region.
    rewriter.replaceOp(scopeOp, continueBlock->getArguments());

    return mlir::success();
  }
};
180
/// Flattens cir.switch into cir.switchflat plus an explicit CFG: every case
/// region is inlined after the switch, fallthrough and break are rewritten
/// into branches, and case ranges become either an expanded list of values
/// (small ranges) or a computed range-check dispatch block (large ranges).
class CIRSwitchOpFlattening : public mlir::OpRewritePattern<cir::SwitchOp> {
public:
  using OpRewritePattern<cir::SwitchOp>::OpRewritePattern;

  // Replace a cir.yield that leaves a case/switch body with a direct branch
  // to `destination`, forwarding the yield's operands.
  inline void rewriteYieldOp(mlir::PatternRewriter &rewriter,
                             cir::YieldOp yieldOp,
                             mlir::Block *destination) const {
    rewriter.setInsertionPoint(yieldOp);
    rewriter.replaceOpWithNewOp<cir::BrOp>(yieldOp, yieldOp.getOperands(),
                                           destination);
  }

  // Return the new defaultDestination block.
  // Builds a block that tests `(cond - lowerBound) <=u (upper - lower)` and
  // branches to `rangeDestination` on hit, else to the previous
  // `defaultDestination`; chaining these implements large case ranges.
  Block *condBrToRangeDestination(cir::SwitchOp op,
                                  mlir::PatternRewriter &rewriter,
                                  mlir::Block *rangeDestination,
                                  mlir::Block *defaultDestination,
                                  const APInt &lowerBound,
                                  const APInt &upperBound) const {
    assert(lowerBound.sle(upperBound) && "Invalid range");
    mlir::Block *resBlock = rewriter.createBlock(defaultDestination);
    cir::IntType sIntType = cir::IntType::get(op.getContext(), 32, true);
    cir::IntType uIntType = cir::IntType::get(op.getContext(), 32, false);

    cir::ConstantOp rangeLength = cir::ConstantOp::create(
        rewriter, op.getLoc(),
        cir::IntAttr::get(sIntType, upperBound - lowerBound));

    cir::ConstantOp lowerBoundValue = cir::ConstantOp::create(
        rewriter, op.getLoc(), cir::IntAttr::get(sIntType, lowerBound));
    mlir::Value diffValue = cir::SubOp::create(
        rewriter, op.getLoc(), op.getCondition(), lowerBoundValue);

    // Use unsigned comparison to check if the condition is in the range.
    cir::CastOp uDiffValue = cir::CastOp::create(
        rewriter, op.getLoc(), uIntType, CastKind::integral, diffValue);
    cir::CastOp uRangeLength = cir::CastOp::create(
        rewriter, op.getLoc(), uIntType, CastKind::integral, rangeLength);

    cir::CmpOp cmpResult = cir::CmpOp::create(
        rewriter, op.getLoc(), cir::CmpOpKind::le, uDiffValue, uRangeLength);
    cir::BrCondOp::create(rewriter, op.getLoc(), cmpResult, rangeDestination,
                          defaultDestination);
    return resBlock;
  }

  mlir::LogicalResult
  matchAndRewrite(cir::SwitchOp op,
                  mlir::PatternRewriter &rewriter) const override {
    // Cleanup scopes must be lowered before the enclosing switch so that
    // break inside them is properly routed through cleanup.
    // Fail the match so the pattern rewriter will process cleanup scopes first.
    bool hasNestedCleanup = op->walk([&](cir::CleanupScopeOp) {
      return mlir::WalkResult::interrupt();
    }).wasInterrupted();
    if (hasNestedCleanup)
      return mlir::failure();

    llvm::SmallVector<CaseOp> cases;
    op.collectCases(cases);

    // Empty switch statement: just erase it.
    if (cases.empty()) {
      rewriter.eraseOp(op);
      return mlir::success();
    }

    // Create exit block from the next node of cir.switch op.
    // NOTE(review): assumes the switch is never the last operation in its
    // block (getNextNode() is dereferenced unconditionally).
    mlir::Block *exitBlock = rewriter.splitBlock(
        rewriter.getBlock(), op->getNextNode()->getIterator());

    // We lower cir.switch op in the following process:
    // 1. Inline the region from the switch op after switch op.
    // 2. Traverse each cir.case op:
    //    a. Record the entry block, block arguments and condition for every
    //    case. b. Inline the case region after the case op.
    // 3. Replace the empty cir.switch.op with the new cir.switchflat op by the
    //    recorded block and conditions.

    // inline everything from switch body between the switch op and the exit
    // block.
    {
      cir::YieldOp switchYield = nullptr;
      // Clear switch operation.
      // Find the yield (if any) terminating the switch body itself.
      for (mlir::Block &block :
           llvm::make_early_inc_range(op.getBody().getBlocks()))
        if (auto yieldOp = dyn_cast<cir::YieldOp>(block.getTerminator()))
          switchYield = yieldOp;

      assert(!op.getBody().empty());
      mlir::Block *originalBlock = op->getBlock();
      mlir::Block *swopBlock =
          rewriter.splitBlock(originalBlock, op->getIterator());
      rewriter.inlineRegionBefore(op.getBody(), exitBlock);

      if (switchYield)
        rewriteYieldOp(rewriter, switchYield, exitBlock);

      rewriter.setInsertionPointToEnd(originalBlock);
      cir::BrOp::create(rewriter, op.getLoc(), swopBlock);
    }

    // Allocate required data structures (disconsider default case in
    // vectors).
    llvm::SmallVector<mlir::APInt, 8> caseValues;
    llvm::SmallVector<mlir::Block *, 8> caseDestinations;
    llvm::SmallVector<mlir::ValueRange, 8> caseOperands;

    llvm::SmallVector<std::pair<APInt, APInt>> rangeValues;
    llvm::SmallVector<mlir::Block *> rangeDestinations;
    llvm::SmallVector<mlir::ValueRange> rangeOperands;

    // Initialize default case as optional.
    mlir::Block *defaultDestination = exitBlock;
    mlir::ValueRange defaultOperands = exitBlock->getArguments();

    // Digest the case statements values and bodies.
    for (cir::CaseOp caseOp : cases) {
      mlir::Region &region = caseOp.getCaseRegion();

      // Found default case: save destination and operands.
      switch (caseOp.getKind()) {
      case cir::CaseOpKind::Default:
        defaultDestination = &region.front();
        defaultOperands = defaultDestination->getArguments();
        break;
      case cir::CaseOpKind::Range:
        assert(caseOp.getValue().size() == 2 &&
               "Case range should have 2 case value");
        rangeValues.push_back(
            {cast<cir::IntAttr>(caseOp.getValue()[0]).getValue(),
             cast<cir::IntAttr>(caseOp.getValue()[1]).getValue()});
        rangeDestinations.push_back(&region.front());
        rangeOperands.push_back(rangeDestinations.back()->getArguments());
        break;
      case cir::CaseOpKind::Anyof:
      case cir::CaseOpKind::Equal:
        // AnyOf cases kind can have multiple values, hence the loop below.
        for (const mlir::Attribute &value : caseOp.getValue()) {
          caseValues.push_back(cast<cir::IntAttr>(value).getValue());
          caseDestinations.push_back(&region.front());
          caseOperands.push_back(caseDestinations.back()->getArguments());
        }
        break;
      }

      // Handle break statements.
      // Breaks nested in loops/inner switches belong to those constructs,
      // so they are skipped by the walk.
      walkRegionSkipping<cir::LoopOpInterface, cir::SwitchOp>(
          region, [&](mlir::Operation *op) {
            if (!isa<cir::BreakOp>(op))
              return mlir::WalkResult::advance();

            lowerTerminator(op, exitBlock, rewriter);
            return mlir::WalkResult::skip();
          });

      // Track fallthrough in cases.
      // A yield at the end of a case body branches to the start of the next
      // case's code (split off below), implementing C fallthrough.
      for (mlir::Block &blk : region.getBlocks()) {
        if (blk.getNumSuccessors())
          continue;

        if (auto yieldOp = dyn_cast<cir::YieldOp>(blk.getTerminator())) {
          mlir::Operation *nextOp = caseOp->getNextNode();
          assert(nextOp && "caseOp is not expected to be the last op");
          mlir::Block *oldBlock = nextOp->getBlock();
          mlir::Block *newBlock =
              rewriter.splitBlock(oldBlock, nextOp->getIterator());
          rewriter.setInsertionPointToEnd(oldBlock);
          cir::BrOp::create(rewriter, nextOp->getLoc(), mlir::ValueRange(),
                            newBlock);
          rewriteYieldOp(rewriter, yieldOp, newBlock);
        }
      }

      mlir::Block *oldBlock = caseOp->getBlock();
      mlir::Block *newBlock =
          rewriter.splitBlock(oldBlock, caseOp->getIterator());

      mlir::Block &entryBlock = caseOp.getCaseRegion().front();
      rewriter.inlineRegionBefore(caseOp.getCaseRegion(), newBlock);

      // Create a branch to the entry of the inlined region.
      rewriter.setInsertionPointToEnd(oldBlock);
      cir::BrOp::create(rewriter, caseOp.getLoc(), &entryBlock);
    }

    // Remove all cases since we've inlined the regions.
    for (cir::CaseOp caseOp : cases) {
      mlir::Block *caseBlock = caseOp->getBlock();
      // Erase the block with no predecessors here to make the generated code
      // simpler a little bit.
      if (caseBlock->hasNoPredecessors())
        rewriter.eraseBlock(caseBlock);
      else
        rewriter.eraseOp(caseOp);
    }

    for (auto [rangeVal, operand, destination] :
         llvm::zip(rangeValues, rangeOperands, rangeDestinations)) {
      APInt lowerBound = rangeVal.first;
      APInt upperBound = rangeVal.second;

      // The case range is unreachable, skip it.
      if (lowerBound.sgt(upperBound))
        continue;

      // If range is small, add multiple switch instruction cases.
      // This magical number is from the original CGStmt code.
      constexpr int kSmallRangeThreshold = 64;
      if ((upperBound - lowerBound)
              .ult(llvm::APInt(32, kSmallRangeThreshold))) {
        for (APInt iValue = lowerBound; iValue.sle(upperBound); ++iValue) {
          caseValues.push_back(iValue);
          caseOperands.push_back(operand);
          caseDestinations.push_back(destination);
        }
        continue;
      }

      // Large range: route the default path through a range-check block.
      defaultDestination =
          condBrToRangeDestination(op, rewriter, destination,
                                   defaultDestination, lowerBound, upperBound);
      defaultOperands = operand;
    }

    // Set switch op to branch to the newly created blocks.
    rewriter.setInsertionPoint(op);
    rewriter.replaceOpWithNewOp<cir::SwitchFlatOp>(
        op, op.getCondition(), defaultDestination, defaultOperands, caseValues,
        caseDestinations, caseOperands);

    return mlir::success();
  }
};
415
/// Flattens any op implementing cir.LoopOpInterface (for/while/do) into an
/// explicit CFG: entry branches into the loop's entry region, the condition
/// region's cir.condition becomes a cir.brcond between body and exit, and
/// continue/break/yield terminators are rewritten into branches to the
/// step/cond/exit blocks as appropriate.
class CIRLoopOpInterfaceFlattening
    : public mlir::OpInterfaceRewritePattern<cir::LoopOpInterface> {
public:
  using mlir::OpInterfaceRewritePattern<
      cir::LoopOpInterface>::OpInterfaceRewritePattern;

  // Replace a cir.condition with a cir.brcond targeting `body` on true and
  // `exit` on false, preserving the caller's insertion point.
  inline void lowerConditionOp(cir::ConditionOp op, mlir::Block *body,
                               mlir::Block *exit,
                               mlir::PatternRewriter &rewriter) const {
    mlir::OpBuilder::InsertionGuard guard(rewriter);
    rewriter.setInsertionPoint(op);
    rewriter.replaceOpWithNewOp<cir::BrCondOp>(op, op.getCondition(), body,
                                               exit);
  }

  mlir::LogicalResult
  matchAndRewrite(cir::LoopOpInterface op,
                  mlir::PatternRewriter &rewriter) const final {
    // Cleanup scopes must be lowered before the enclosing loop so that
    // break/continue inside them are properly routed through cleanup.
    // Fail the match so the pattern rewriter will process cleanup scopes first.
    bool hasNestedCleanup = op->walk([&](cir::CleanupScopeOp) {
      return mlir::WalkResult::interrupt();
    }).wasInterrupted();
    if (hasNestedCleanup)
      return mlir::failure();

    // Setup CFG blocks.
    mlir::Block *entry = rewriter.getInsertionBlock();
    mlir::Block *exit =
        rewriter.splitBlock(entry, rewriter.getInsertionPoint());
    mlir::Block *cond = &op.getCond().front();
    mlir::Block *body = &op.getBody().front();
    // The step region only exists for `for` loops; null otherwise.
    mlir::Block *step =
        (op.maybeGetStep() ? &op.maybeGetStep()->front() : nullptr);

    // Setup loop entry branch.
    rewriter.setInsertionPointToEnd(entry);
    cir::BrOp::create(rewriter, op.getLoc(), &op.getEntry().front());

    // Branch from condition region to body or exit. The ConditionOp may not
    // be in the first block of the condition region if a cleanup scope was
    // already flattened within it, introducing multiple blocks. The
    // ConditionOp is always the terminator of the last block.
    auto conditionOp =
        cast<cir::ConditionOp>(op.getCond().back().getTerminator());
    lowerConditionOp(conditionOp, body, exit, rewriter);

    // TODO(cir): Remove the walks below. It visits operations unnecessarily.
    // However, to solve this we would likely need a custom DialectConversion
    // driver to customize the order that operations are visited.

    // Lower continue statements.
    // Continue jumps to the step block when present, else back to the cond.
    mlir::Block *dest = (step ? step : cond);
    op.walkBodySkippingNestedLoops([&](mlir::Operation *op) {
      if (!isa<cir::ContinueOp>(op))
        return mlir::WalkResult::advance();

      lowerTerminator(op, dest, rewriter);
      return mlir::WalkResult::skip();
    });

    // Lower break statements.
    // Breaks belonging to nested loops/switches are skipped.
    walkRegionSkipping<cir::LoopOpInterface, cir::SwitchOp>(
        op.getBody(), [&](mlir::Operation *op) {
          if (!isa<cir::BreakOp>(op))
            return mlir::WalkResult::advance();

          lowerTerminator(op, exit, rewriter);
          return mlir::WalkResult::skip();
        });

    // Lower optional body region yield.
    for (mlir::Block &blk : op.getBody().getBlocks()) {
      auto bodyYield = dyn_cast<cir::YieldOp>(blk.getTerminator());
      if (bodyYield)
        lowerTerminator(bodyYield, (step ? step : cond), rewriter);
    }

    // Lower mandatory step region yield. Like the condition region, the
    // YieldOp may be in the last block rather than the first if a cleanup
    // scope was already flattened within the step region.
    if (step)
      lowerTerminator(
          cast<cir::YieldOp>(op.maybeGetStep()->back().getTerminator()), cond,
          rewriter);

    // Move region contents out of the loop op.
    rewriter.inlineRegionBefore(op.getCond(), exit);
    rewriter.inlineRegionBefore(op.getBody(), exit);
    if (step)
      rewriter.inlineRegionBefore(*op.maybeGetStep(), exit);

    rewriter.eraseOp(op);
    return mlir::success();
  }
};
513
514class CIRTernaryOpFlattening : public mlir::OpRewritePattern<cir::TernaryOp> {
515public:
516 using OpRewritePattern<cir::TernaryOp>::OpRewritePattern;
517
518 mlir::LogicalResult
519 matchAndRewrite(cir::TernaryOp op,
520 mlir::PatternRewriter &rewriter) const override {
521 Location loc = op->getLoc();
522 Block *condBlock = rewriter.getInsertionBlock();
523 Block::iterator opPosition = rewriter.getInsertionPoint();
524 Block *remainingOpsBlock = rewriter.splitBlock(condBlock, opPosition);
525 llvm::SmallVector<mlir::Location, 2> locs;
526 // Ternary result is optional, make sure to populate the location only
527 // when relevant.
528 if (op->getResultTypes().size())
529 locs.push_back(loc);
530 Block *continueBlock =
531 rewriter.createBlock(remainingOpsBlock, op->getResultTypes(), locs);
532 cir::BrOp::create(rewriter, loc, remainingOpsBlock);
533
534 Region &trueRegion = op.getTrueRegion();
535 Block *trueBlock = &trueRegion.front();
536 mlir::Operation *trueTerminator = trueRegion.back().getTerminator();
537 rewriter.setInsertionPointToEnd(&trueRegion.back());
538
539 // Handle both yield and unreachable terminators (throw expressions)
540 if (auto trueYieldOp = dyn_cast<cir::YieldOp>(trueTerminator)) {
541 rewriter.replaceOpWithNewOp<cir::BrOp>(trueYieldOp, trueYieldOp.getArgs(),
542 continueBlock);
543 } else if (isa<cir::UnreachableOp>(trueTerminator)) {
544 // Terminator is unreachable (e.g., from throw), just keep it
545 } else {
546 trueTerminator->emitError("unexpected terminator in ternary true region, "
547 "expected yield or unreachable, got: ")
548 << trueTerminator->getName();
549 return mlir::failure();
550 }
551 rewriter.inlineRegionBefore(trueRegion, continueBlock);
552
553 Block *falseBlock = continueBlock;
554 Region &falseRegion = op.getFalseRegion();
555
556 falseBlock = &falseRegion.front();
557 mlir::Operation *falseTerminator = falseRegion.back().getTerminator();
558 rewriter.setInsertionPointToEnd(&falseRegion.back());
559
560 // Handle both yield and unreachable terminators (throw expressions)
561 if (auto falseYieldOp = dyn_cast<cir::YieldOp>(falseTerminator)) {
562 rewriter.replaceOpWithNewOp<cir::BrOp>(
563 falseYieldOp, falseYieldOp.getArgs(), continueBlock);
564 } else if (isa<cir::UnreachableOp>(falseTerminator)) {
565 // Terminator is unreachable (e.g., from throw), just keep it
566 } else {
567 falseTerminator->emitError("unexpected terminator in ternary false "
568 "region, expected yield or unreachable, got: ")
569 << falseTerminator->getName();
570 return mlir::failure();
571 }
572 rewriter.inlineRegionBefore(falseRegion, continueBlock);
573
574 rewriter.setInsertionPointToEnd(condBlock);
575 cir::BrCondOp::create(rewriter, loc, op.getCond(), trueBlock, falseBlock);
576
577 rewriter.replaceOp(op, continueBlock->getArguments());
578
579 // Ok, we're done!
580 return mlir::success();
581 }
582};
583
584// Get or create the cleanup destination slot for a function. This slot is
585// shared across all cleanup scopes in the function to track which exit path
586// to take after running cleanup code when there are multiple exits.
587static cir::AllocaOp getOrCreateCleanupDestSlot(cir::FuncOp funcOp,
588 mlir::PatternRewriter &rewriter,
589 mlir::Location loc) {
590 mlir::Block &entryBlock = funcOp.getBody().front();
591
592 // Look for an existing cleanup dest slot in the entry block.
593 auto it = llvm::find_if(entryBlock, [](auto &op) {
594 return mlir::isa<AllocaOp>(&op) &&
595 mlir::cast<AllocaOp>(&op).getCleanupDestSlot();
596 });
597 if (it != entryBlock.end())
598 return mlir::cast<cir::AllocaOp>(*it);
599
600 // Create a new cleanup dest slot at the start of the entry block.
601 mlir::OpBuilder::InsertionGuard guard(rewriter);
602 rewriter.setInsertionPointToStart(&entryBlock);
603 cir::IntType s32Type =
604 cir::IntType::get(rewriter.getContext(), 32, /*isSigned=*/true);
605 cir::PointerType ptrToS32Type = cir::PointerType::get(s32Type);
606 cir::CIRDataLayout dataLayout(funcOp->getParentOfType<mlir::ModuleOp>());
607 uint64_t alignment = dataLayout.getAlignment(s32Type, true).value();
608 auto allocaOp = cir::AllocaOp::create(
609 rewriter, loc, ptrToS32Type, s32Type, "__cleanup_dest_slot",
610 /*alignment=*/rewriter.getI64IntegerAttr(alignment));
611 allocaOp.setCleanupDestSlot(true);
612 return allocaOp;
613}
614
615/// Shared EH flattening utilities used by both CIRCleanupScopeOpFlattening
616/// and CIRTryOpFlattening.
617
618// Collect all function calls in a region that may throw exceptions and need
619// to be replaced with try_call operations. Skips calls marked nothrow.
620// Nested cleanup scopes and try ops are always flattened before their
621// enclosing parents, so there are no nested regions to skip here.
622static void
623collectThrowingCalls(mlir::Region &region,
624 llvm::SmallVectorImpl<cir::CallOp> &callsToRewrite) {
625 region.walk([&](cir::CallOp callOp) {
626 if (!callOp.getNothrow())
627 callsToRewrite.push_back(callOp);
628 });
629}
630
631// Collect all cir.resume operations in a region that come from
632// already-flattened try or cleanup scope operations. These resume ops need
633// to be chained through this scope's EH handler instead of unwinding
634// directly to the caller. Nested cleanup scopes and try ops are always
635// flattened before their enclosing parents, so there are no nested regions
636// to skip here.
637static void collectResumeOps(mlir::Region &region,
639 region.walk([&](cir::ResumeOp resumeOp) { resumeOps.push_back(resumeOp); });
640}
641
// Replace a cir.call with a cir.try_call that unwinds to the `unwindDest`
// block if an exception is thrown.
//
// The call's block is split so the remaining ops form the normal
// destination; both direct and indirect calls are handled, the original
// call's attributes are carried over (minus ones the builder sets), and
// result uses are redirected to the try_call before the call is erased.
static void replaceCallWithTryCall(cir::CallOp callOp, mlir::Block *unwindDest,
                                   mlir::Location loc,
                                   mlir::PatternRewriter &rewriter) {
  mlir::Block *callBlock = callOp->getBlock();

  assert(!callOp.getNothrow() && "call is not expected to throw");

  // Split the block after the call - remaining ops become the normal
  // destination.
  mlir::Block *normalDest =
      rewriter.splitBlock(callBlock, std::next(callOp->getIterator()));

  // Build the try_call to replace the original call.
  rewriter.setInsertionPoint(callOp);
  cir::TryCallOp tryCallOp;
  if (callOp.isIndirect()) {
    // Indirect call: recover the callee function type from the pointee of
    // the call target's pointer type.
    mlir::Value indTarget = callOp.getIndirectCall();
    auto ptrTy = mlir::cast<cir::PointerType>(indTarget.getType());
    auto resTy = mlir::cast<cir::FuncType>(ptrTy.getPointee());
    tryCallOp =
        cir::TryCallOp::create(rewriter, loc, indTarget, resTy, normalDest,
                               unwindDest, callOp.getArgOperands());
  } else {
    // Direct call: propagate the (single, optional) result type.
    mlir::Type resType = callOp->getNumResults() > 0
                             ? callOp->getResult(0).getType()
                             : mlir::Type();
    tryCallOp =
        cir::TryCallOp::create(rewriter, loc, callOp.getCalleeAttr(), resType,
                               normalDest, unwindDest, callOp.getArgOperands());
  }

  // Copy all attributes from the original call except those already set by
  // TryCallOp::create or that are operation-specific and should not be copied.
  llvm::StringRef excludedAttrs[] = {
      CIRDialect::getCalleeAttrName(), // Set by create()
      CIRDialect::getOperandSegmentSizesAttrName(),
  };
#ifndef NDEBUG
  // We don't expect to ever see any of these attributes on a call that we
  // converted to a try_call.
  llvm::StringRef unexpectedAttrs[] = {
      CIRDialect::getNoThrowAttrName(),
      CIRDialect::getNoUnwindAttrName(),
  };
#endif
  for (mlir::NamedAttribute attr : callOp->getAttrs()) {
    if (llvm::is_contained(excludedAttrs, attr.getName()))
      continue;
    assert(!llvm::is_contained(unexpectedAttrs, attr.getName()) &&
           "unexpected attribute on converted call");
    tryCallOp->setAttr(attr.getName(), attr.getValue());
  }

  // Replace uses of the call result with the try_call result.
  if (callOp->getNumResults() > 0)
    callOp->getResult(0).replaceAllUsesWith(tryCallOp.getResult());

  rewriter.eraseOp(callOp);
}
703
704// Create a shared unwind destination block. The block contains a
705// cir.eh.initiate operation (optionally with the cleanup attribute) and a
706// branch to the given destination block, passing the eh_token.
707static mlir::Block *buildUnwindBlock(mlir::Block *dest, bool hasCleanup,
708 mlir::Location loc,
709 mlir::Block *insertBefore,
710 mlir::PatternRewriter &rewriter) {
711 mlir::Block *unwindBlock = rewriter.createBlock(insertBefore);
712 rewriter.setInsertionPointToEnd(unwindBlock);
713 auto ehInitiate =
714 cir::EhInitiateOp::create(rewriter, loc, /*cleanup=*/hasCleanup);
715 cir::BrOp::create(rewriter, loc, mlir::ValueRange{ehInitiate.getEhToken()},
716 dest);
717 return unwindBlock;
718}
719
720class CIRCleanupScopeOpFlattening
721 : public mlir::OpRewritePattern<cir::CleanupScopeOp> {
722public:
723 using OpRewritePattern<cir::CleanupScopeOp>::OpRewritePattern;
724
  // Pairs an exiting operation with the dispatch id used to route control
  // through the cleanup region when a scope has multiple exit paths.
  struct CleanupExit {
    // An operation that exits the cleanup scope (yield, break, continue,
    // return, etc.)
    mlir::Operation *exitOp;

    // A unique identifier for this exit's destination (used for switch dispatch
    // when there are multiple exits).
    int destinationId;

    CleanupExit(mlir::Operation *op, int id) : exitOp(op), destinationId(id) {}
  };
736
  // Collect all operations that exit a cleanup scope body. Return, goto, break,
  // and continue can all require branches through the cleanup region. When a
  // loop is encountered, only return and goto are collected because break and
  // continue are handled by the loop and stay within the cleanup scope. When a
  // switch is encountered, return, goto and continue are collected because they
  // may all branch through the cleanup, but break is local to the switch. When
  // a nested cleanup scope is encountered, we recursively collect exits since
  // any return, goto, break, or continue from the nested cleanup will also
  // branch through the outer cleanup.
  //
  // Note that goto statements may not necessarily exit the cleanup scope, but
  // for now we conservatively assume that they do. We'll need more nuanced
  // handling of that when multi-exit flattening is implemented.
  //
  // This function assigns unique destination IDs to each exit, which are
  // used when multi-exit cleanup scopes are flattened. `nextId` is an in/out
  // counter so IDs stay unique across nested collection calls.
  void collectExits(mlir::Region &cleanupBodyRegion,
                    llvm::SmallVectorImpl<CleanupExit> &exits,
                    int &nextId) const {
    // Collect yield terminators from the body region. We do this separately
    // because yields in nested operations, including those in nested cleanup
    // scopes, won't branch through the outer cleanup region.
    for (mlir::Block &block : cleanupBodyRegion) {
      auto *terminator = block.getTerminator();
      if (isa<cir::YieldOp>(terminator))
        exits.emplace_back(terminator, nextId++);
    }

    // Lambda to walk a loop and collect only returns and gotos.
    // Break and continue inside loops are handled by the loop itself.
    // Loops don't require special handling for nested switch or cleanup scopes
    // because break and continue never branch out of the loop.
    auto collectExitsInLoop = [&](mlir::Operation *loopOp) {
      loopOp->walk<mlir::WalkOrder::PreOrder>([&](mlir::Operation *nestedOp) {
        if (isa<cir::ReturnOp, cir::GotoOp>(nestedOp))
          exits.emplace_back(nestedOp, nextId++);
        return mlir::WalkResult::advance();
      });
    };

    // Forward declaration for mutual recursion.
    // (Switches can contain cleanups and vice versa, so the two collectors
    // must be able to call each other.)
    std::function<void(mlir::Region &, bool)> collectExitsInCleanup;
    std::function<void(mlir::Operation *)> collectExitsInSwitch;

    // Lambda to collect exits from a switch. Collects return/goto/continue but
    // not break (handled by switch). For nested loops/cleanups, recurses.
    collectExitsInSwitch = [&](mlir::Operation *switchOp) {
      switchOp->walk<mlir::WalkOrder::PreOrder>([&](mlir::Operation *nestedOp) {
        if (isa<cir::CleanupScopeOp>(nestedOp)) {
          // Walk the nested cleanup, but ignore break statements because they
          // will be handled by the switch we are currently walking.
          collectExitsInCleanup(
              cast<cir::CleanupScopeOp>(nestedOp).getBodyRegion(),
              /*ignoreBreak=*/true);
          return mlir::WalkResult::skip();
        } else if (isa<cir::LoopOpInterface>(nestedOp)) {
          collectExitsInLoop(nestedOp);
          return mlir::WalkResult::skip();
        } else if (isa<cir::ReturnOp, cir::GotoOp, cir::ContinueOp>(nestedOp)) {
          exits.emplace_back(nestedOp, nextId++);
        }
        return mlir::WalkResult::advance();
      });
    };

    // Lambda to collect exits from a cleanup scope body region. This collects
    // break (optionally), continue, return, and goto, handling nested loops,
    // switches, and cleanups appropriately.
    collectExitsInCleanup = [&](mlir::Region &region, bool ignoreBreak) {
      region.walk<mlir::WalkOrder::PreOrder>([&](mlir::Operation *op) {
        // We need special handling for break statements because if this cleanup
        // scope was nested within a switch op, break will be handled by the
        // switch operation and therefore won't exit the cleanup scope enclosing
        // the switch. We're only collecting exits from the cleanup that started
        // this walk. Exits from nested cleanups will be handled when we flatten
        // the nested cleanup.
        if (!ignoreBreak && isa<cir::BreakOp>(op)) {
          exits.emplace_back(op, nextId++);
        } else if (isa<cir::ContinueOp, cir::ReturnOp, cir::GotoOp>(op)) {
          exits.emplace_back(op, nextId++);
        } else if (isa<cir::CleanupScopeOp>(op)) {
          // Recurse into nested cleanup's body region.
          collectExitsInCleanup(cast<cir::CleanupScopeOp>(op).getBodyRegion(),
                                /*ignoreBreak=*/ignoreBreak);
          return mlir::WalkResult::skip();
        } else if (isa<cir::LoopOpInterface>(op)) {
          // This kicks off a separate walk rather than continuing to dig deeper
          // in the current walk because we need to handle break and continue
          // differently inside loops.
          collectExitsInLoop(op);
          return mlir::WalkResult::skip();
        } else if (isa<cir::SwitchOp>(op)) {
          // This kicks off a separate walk rather than continuing to dig deeper
          // in the current walk because we need to handle break differently
          // inside switches.
          collectExitsInSwitch(op);
          return mlir::WalkResult::skip();
        }
        return mlir::WalkResult::advance();
      });
    };

    // Collect exits from the body region.
    collectExitsInCleanup(cleanupBodyRegion, /*ignoreBreak=*/false);
  }
842
843 // Check if an operand's defining op should be moved to the destination block.
844 // We only sink constants and simple loads. Anything else should be saved
845 // to a temporary alloca and reloaded at the destination block.
846 static bool shouldSinkReturnOperand(mlir::Value operand,
847 cir::ReturnOp returnOp) {
848 // Block arguments can't be moved
849 mlir::Operation *defOp = operand.getDefiningOp();
850 if (!defOp)
851 return false;
852
853 // Only move constants and loads to the dispatch block. For anything else,
854 // we'll store to a temporary and reload in the dispatch block.
855 if (!mlir::isa<cir::ConstantOp, cir::LoadOp>(defOp))
856 return false;
857
858 // Check if the return is the only user
859 if (!operand.hasOneUse())
860 return false;
861
862 // Only move ops that are in the same block as the return.
863 if (defOp->getBlock() != returnOp->getBlock())
864 return false;
865
866 if (auto loadOp = mlir::dyn_cast<cir::LoadOp>(defOp)) {
867 // Only attempt to move loads of allocas in the entry block.
868 mlir::Value ptr = loadOp.getAddr();
869 auto funcOp = returnOp->getParentOfType<cir::FuncOp>();
870 assert(funcOp && "Return op has no function parent?");
871 mlir::Block &funcEntryBlock = funcOp.getBody().front();
872
873 // Check if it's an alloca in the function entry block
874 if (auto allocaOp =
875 mlir::dyn_cast_if_present<cir::AllocaOp>(ptr.getDefiningOp()))
876 return allocaOp->getBlock() == &funcEntryBlock;
877
878 return false;
879 }
880
881 // Make sure we only fall through to here with constants.
882 assert(mlir::isa<cir::ConstantOp>(defOp) && "Expected constant op");
883 return true;
884 }
885
  // For returns with operands in cleanup dispatch blocks, the operands may not
  // dominate the dispatch block. This function handles that by either sinking
  // the operand's defining op to the dispatch block (for constants and simple
  // loads, per shouldSinkReturnOperand) or by storing to a temporary alloca
  // and reloading it.
  //
  // On entry, the rewriter's insertion block must be the dispatch block that
  // will hold the rebuilt return; `returnValues` is filled with one value per
  // return operand, each valid (dominating) in that block. `exitOp` marks the
  // original exit site: stores are placed immediately before it so the value
  // is captured on the path that actually reached the cleanup.
  void
  getReturnOpOperands(cir::ReturnOp returnOp, mlir::Operation *exitOp,
                      mlir::Location loc, mlir::PatternRewriter &rewriter,
                      llvm::SmallVectorImpl<mlir::Value> &returnValues) const {
    // The dispatch block where the new return will live.
    mlir::Block *destBlock = rewriter.getInsertionBlock();
    auto funcOp = exitOp->getParentOfType<cir::FuncOp>();
    assert(funcOp && "Return op has no function parent?");
    mlir::Block &funcEntryBlock = funcOp.getBody().front();

    for (mlir::Value operand : returnOp.getOperands()) {
      if (shouldSinkReturnOperand(operand, returnOp)) {
        // Sink the defining op to the dispatch block. The operand itself then
        // dominates the new return, so it can be used directly.
        mlir::Operation *defOp = operand.getDefiningOp();
        defOp->moveBefore(destBlock, destBlock->end());
        returnValues.push_back(operand);
      } else {
        // Create an alloca in the function entry block. The guard restores
        // the insertion point (the dispatch block) afterwards.
        cir::AllocaOp alloca;
        {
          mlir::OpBuilder::InsertionGuard guard(rewriter);
          rewriter.setInsertionPointToStart(&funcEntryBlock);
          cir::CIRDataLayout dataLayout(
              funcOp->getParentOfType<mlir::ModuleOp>());
          uint64_t alignment =
              dataLayout.getAlignment(operand.getType(), true).value();
          cir::PointerType ptrType = cir::PointerType::get(operand.getType());
          alloca = cir::AllocaOp::create(rewriter, loc, ptrType,
                                         operand.getType(), "__ret_operand_tmp",
                                         rewriter.getI64IntegerAttr(alignment));
        }

        // Store the operand value at the original return location, before the
        // exit op is rewritten into a branch to the cleanup.
        {
          mlir::OpBuilder::InsertionGuard guard(rewriter);
          rewriter.setInsertionPoint(exitOp);
          cir::StoreOp::create(rewriter, loc, operand, alloca,
                               /*isVolatile=*/false,
                               /*alignment=*/mlir::IntegerAttr(),
                               cir::SyncScopeKindAttr(), cir::MemOrderAttr());
        }

        // Reload the value from the temporary alloca in the destination block.
        rewriter.setInsertionPointToEnd(destBlock);
        auto loaded = cir::LoadOp::create(
            rewriter, loc, alloca, /*isDeref=*/false,
            /*isVolatile=*/false, /*alignment=*/mlir::IntegerAttr(),
            cir::SyncScopeKindAttr(), cir::MemOrderAttr());
        returnValues.push_back(loaded);
      }
    }
  }
941
942 // Create the appropriate terminator for an exit operation in the dispatch
943 // block. For return ops with operands, this handles the dominance issue by
944 // either moving the operand's defining op to the dispatch block (if it's a
945 // trivial use) or by storing to a temporary alloca and loading it.
946 mlir::LogicalResult
947 createExitTerminator(mlir::Operation *exitOp, mlir::Location loc,
948 mlir::Block *continueBlock,
949 mlir::PatternRewriter &rewriter) const {
950 return llvm::TypeSwitch<mlir::Operation *, mlir::LogicalResult>(exitOp)
951 .Case<cir::YieldOp>([&](auto) {
952 // Yield becomes a branch to continue block.
953 cir::BrOp::create(rewriter, loc, continueBlock);
954 return mlir::success();
955 })
956 .Case<cir::BreakOp>([&](auto) {
957 // Break is preserved for later lowering by enclosing switch/loop.
958 cir::BreakOp::create(rewriter, loc);
959 return mlir::success();
960 })
961 .Case<cir::ContinueOp>([&](auto) {
962 // Continue is preserved for later lowering by enclosing loop.
963 cir::ContinueOp::create(rewriter, loc);
964 return mlir::success();
965 })
966 .Case<cir::ReturnOp>([&](auto returnOp) {
967 // Return from the cleanup exit. Note, if this is a return inside a
968 // nested cleanup scope, the flattening of the outer scope will handle
969 // branching through the outer cleanup.
970 if (returnOp.hasOperand()) {
971 llvm::SmallVector<mlir::Value, 2> returnValues;
972 getReturnOpOperands(returnOp, exitOp, loc, rewriter, returnValues);
973 cir::ReturnOp::create(rewriter, loc, returnValues);
974 } else {
975 cir::ReturnOp::create(rewriter, loc);
976 }
977 return mlir::success();
978 })
979 .Case<cir::GotoOp>([&](auto gotoOp) {
980 // Correct goto handling requires determining whether the goto
981 // branches out of the cleanup scope or stays within it.
982 // Although the goto necessarily exits the cleanup scope in the
983 // case where it is the only exit from the scope, it is left
984 // as unimplemented for now so that it can be generalized when
985 // multi-exit flattening is implemented.
986 cir::UnreachableOp::create(rewriter, loc);
987 return gotoOp.emitError(
988 "goto in cleanup scope is not yet implemented");
989 })
990 .Default([&](mlir::Operation *op) {
991 cir::UnreachableOp::create(rewriter, loc);
992 return op->emitError(
993 "unexpected exit operation in cleanup scope body");
994 });
995 }
996
997#ifndef NDEBUG
998 // Check that no block other than the last one in a region exits the region.
999 static bool regionExitsOnlyFromLastBlock(mlir::Region &region) {
1000 for (mlir::Block &block : region) {
1001 if (&block == &region.back())
1002 continue;
1003 bool expectedTerminator =
1004 llvm::TypeSwitch<mlir::Operation *, bool>(block.getTerminator())
1005 // It is theoretically possible to have a cleanup block with
1006 // any of the following exits in non-final blocks, but we won't
1007 // currently generate any CIR that does that, and being able to
1008 // assume that it doesn't happen simplifies the implementation.
1009 // If we ever need to handle this case, the code will need to
1010 // be updated to handle it.
1011 .Case<cir::YieldOp, cir::ReturnOp, cir::ResumeFlatOp,
1012 cir::ContinueOp, cir::BreakOp, cir::GotoOp>(
1013 [](auto) { return false; })
1014 // We expect that call operations have not yet been rewritten
1015 // as try_call operations. A call can unwind out of the cleanup
1016 // scope, but we will be handling that during flattening. The
1017 // only case where a try_call could be present inside an
1018 // unflattened cleanup region is if the cleanup contained a
1019 // nested try-catch region, and that isn't expected as of the
1020 // time of this implementation. If it does, this could be
1021 // updated to tolerate it.
1022 .Case<cir::TryCallOp>([](auto) { return false; })
1023 // Likewise, we don't expect to find an EH dispatch operation
1024 // because we weren't expecting try-catch regions nested in the
1025 // cleanup region.
1026 .Case<cir::EhDispatchOp>([](auto) { return false; })
1027 // In theory, it would be possible to have a flattened switch
1028 // operation that does not exit the cleanup region. For now,
1029 // that's not happening.
1030 .Case<cir::SwitchFlatOp>([](auto) { return false; })
1031 // These aren't expected either, but if they occur, they don't
1032 // exit the region, so that's OK.
1033 .Case<cir::UnreachableOp, cir::TrapOp>([](auto) { return true; })
1034 // Indirect branches are not expected.
1035 .Case<cir::IndirectBrOp>([](auto) { return false; })
1036 // We do expect branches, but we don't expect them to leave
1037 // the region.
1038 .Case<cir::BrOp>([&](cir::BrOp brOp) {
1039 assert(brOp.getDest()->getParent() == &region &&
1040 "branch destination is not in the region");
1041 return true;
1042 })
1043 .Case<cir::BrCondOp>([&](cir::BrCondOp brCondOp) {
1044 assert(brCondOp.getDestTrue()->getParent() == &region &&
1045 "branch destination is not in the region");
1046 assert(brCondOp.getDestFalse()->getParent() == &region &&
1047 "branch destination is not in the region");
1048 return true;
1049 })
1050 // What else could there be?
1051 .Default([](mlir::Operation *) -> bool {
1052 llvm_unreachable("unexpected terminator in cleanup region");
1053 });
1054 if (!expectedTerminator)
1055 return false;
1056 }
1057 return true;
1058 }
1059#endif
1060
  // Build the EH cleanup block structure by cloning the cleanup region. The
  // cloned entry block gets an !cir.eh_token argument and a cir.begin_cleanup
  // inserted at the top. All cir.yield terminators that might exit the cleanup
  // region are replaced with cir.end_cleanup + cir.resume.
  //
  // For a single-block cleanup region, this produces:
  //
  //   ^eh_cleanup(%eh_token : !cir.eh_token):
  //     %ct = cir.begin_cleanup %eh_token : !cir.eh_token -> !cir.cleanup_token
  //     <cloned cleanup operations>
  //     cir.end_cleanup %ct : !cir.cleanup_token
  //     cir.resume %eh_token : !cir.eh_token
  //
  // For a multi-block cleanup region (e.g. containing a flattened cir.if),
  // the same wrapping is applied around the cloned block structure: the entry
  // block gets begin_cleanup and all exit blocks (those terminated by yield)
  // get end_cleanup + resume.
  //
  // If this cleanup scope is nested within a TryOp, the resume will be updated
  // to branch to the catch dispatch block of the enclosing try operation when
  // the TryOp is flattened.
  //
  // Returns the cloned entry block, which becomes the unwind/chaining target
  // for throwing calls and inner resume ops.
  mlir::Block *buildEHCleanupBlocks(cir::CleanupScopeOp cleanupOp,
                                    mlir::Location loc,
                                    mlir::Block *insertBefore,
                                    mlir::PatternRewriter &rewriter) const {
    assert(regionExitsOnlyFromLastBlock(cleanupOp.getCleanupRegion()) &&
           "cleanup region has exits in non-final blocks");

    // Track the block before the insertion point so we can find the cloned
    // blocks after cloning.
    mlir::Block *blockBeforeClone = insertBefore->getPrevNode();

    // Clone the entire cleanup region before insertBefore.
    rewriter.cloneRegionBefore(cleanupOp.getCleanupRegion(), insertBefore);

    // Find the first cloned block. If there was no block before the insertion
    // point, the clones were placed at the front of the parent region.
    mlir::Block *clonedEntry = blockBeforeClone
                                   ? blockBeforeClone->getNextNode()
                                   : &insertBefore->getParent()->front();

    // Add the eh_token argument to the cloned entry block and insert
    // begin_cleanup at the top.
    auto ehTokenType = cir::EhTokenType::get(rewriter.getContext());
    mlir::Value ehToken = clonedEntry->addArgument(ehTokenType, loc);

    rewriter.setInsertionPointToStart(clonedEntry);
    auto beginCleanup = cir::BeginCleanupOp::create(rewriter, loc, ehToken);

    // Replace the yield terminator in the last cloned block with
    // end_cleanup + resume. The assert above guarantees only the last block
    // exits the region, so this is the only terminator to rewrite.
    mlir::Block *lastClonedBlock = insertBefore->getPrevNode();
    auto yieldOp =
        mlir::dyn_cast<cir::YieldOp>(lastClonedBlock->getTerminator());
    if (yieldOp) {
      rewriter.setInsertionPoint(yieldOp);
      cir::EndCleanupOp::create(rewriter, loc, beginCleanup.getCleanupToken());
      rewriter.replaceOpWithNewOp<cir::ResumeOp>(yieldOp, ehToken);
    } else {
      // Defensive: flattenCleanup rejects non-yield-terminated cleanup
      // regions before calling us, so this should not be reached. Emitting
      // the error (rather than asserting) keeps the IR diagnosable.
      cleanupOp->emitError("Not yet implemented: cleanup region terminated "
                           "with non-yield operation");
    }

    return clonedEntry;
  }
1125
  // Flatten a cleanup scope. The body region's exits branch to the cleanup
  // block, and the cleanup block branches to destination blocks whose contents
  // depend on the type of operation that exited the body region. Yield becomes
  // a branch to the block after the cleanup scope, break and continue are
  // preserved for later lowering by enclosing switch or loop, and return
  // is preserved as is.
  //
  // If there are multiple exits from the cleanup body, a destination slot and
  // switch dispatch are used to continue to the correct destination after the
  // cleanup is complete. A destination slot alloca is created at the function
  // entry block. Each exit operation is replaced by a store of its unique ID to
  // the destination slot and a branch to cleanup. The cleanup is then followed
  // by a branch to a dispatch block that loads the destination slot and uses
  // switch.flat to branch to the correct destination.
  //
  // If the cleanup scope requires EH cleanup, any call operations in the body
  // that may throw are replaced with cir.try_call operations that unwind to an
  // EH cleanup block. The cleanup block(s) will be terminated with a cir.resume
  // operation. If this cleanup scope is enclosed by a try operation, the
  // flattening of the try operation will replace the cir.resume with
  // a branch to a catch dispatch block. Otherwise, the cir.resume operation
  // remains in place and will unwind to the caller.
  mlir::LogicalResult
  flattenCleanup(cir::CleanupScopeOp cleanupOp,
                 llvm::SmallVectorImpl<CleanupExit> &exits,
                 llvm::SmallVectorImpl<cir::CallOp> &callsToRewrite,
                 llvm::SmallVectorImpl<cir::ResumeOp> &resumeOpsToChain,
                 mlir::PatternRewriter &rewriter) const {
    mlir::Location loc = cleanupOp.getLoc();
    cir::CleanupKind cleanupKind = cleanupOp.getCleanupKind();
    // Normal cleanup runs on ordinary control-flow exits; EH cleanup runs
    // only on the unwind path. "All" requires both.
    bool hasNormalCleanup = cleanupKind == cir::CleanupKind::Normal ||
                            cleanupKind == cir::CleanupKind::All;
    bool hasEHCleanup = cleanupKind == cir::CleanupKind::EH ||
                        cleanupKind == cir::CleanupKind::All;
    bool isMultiExit = exits.size() > 1;

    // Get references to region blocks before inlining.
    mlir::Block *bodyEntry = &cleanupOp.getBodyRegion().front();
    mlir::Block *cleanupEntry = &cleanupOp.getCleanupRegion().front();
    mlir::Block *cleanupExit = &cleanupOp.getCleanupRegion().back();
    assert(regionExitsOnlyFromLastBlock(cleanupOp.getCleanupRegion()) &&
           "cleanup region has exits in non-final blocks");
    auto cleanupYield = dyn_cast<cir::YieldOp>(cleanupExit->getTerminator());
    if (!cleanupYield) {
      return rewriter.notifyMatchFailure(cleanupOp,
                                         "Not yet implemented: cleanup region "
                                         "terminated with non-yield operation");
    }

    // For multiple exits from the body region, get or create a destination slot
    // at function entry. The slot is shared across all cleanup scopes in the
    // function. This is only needed if the cleanup scope requires normal
    // cleanup.
    cir::AllocaOp destSlot;
    if (isMultiExit && hasNormalCleanup) {
      auto funcOp = cleanupOp->getParentOfType<cir::FuncOp>();
      if (!funcOp)
        return cleanupOp->emitError("cleanup scope not inside a function");
      destSlot = getOrCreateCleanupDestSlot(funcOp, rewriter, loc);
    }

    // Split the current block to create the insertion point.
    mlir::Block *currentBlock = rewriter.getInsertionBlock();
    mlir::Block *continueBlock =
        rewriter.splitBlock(currentBlock, rewriter.getInsertionPoint());

    // Build EH cleanup blocks if needed. This must be done before inlining
    // the cleanup region since buildEHCleanupBlocks clones from it. The unwind
    // block is inserted before the EH cleanup entry so that the final layout
    // is: body -> normal cleanup -> exit -> unwind -> EH cleanup -> continue.
    // EH cleanup blocks are needed when there are throwing calls that need to
    // be rewritten to try_call, or when there are resume ops from
    // already-flattened inner cleanup scopes that need to chain through this
    // cleanup's EH handler.
    mlir::Block *unwindBlock = nullptr;
    mlir::Block *ehCleanupEntry = nullptr;
    if (hasEHCleanup &&
        (!callsToRewrite.empty() || !resumeOpsToChain.empty())) {
      ehCleanupEntry =
          buildEHCleanupBlocks(cleanupOp, loc, continueBlock, rewriter);
      // The unwind block is only needed when there are throwing calls that
      // need a shared unwind destination. Resume ops from inner cleanups
      // branch directly to the EH cleanup entry.
      if (!callsToRewrite.empty())
        unwindBlock = buildUnwindBlock(ehCleanupEntry, /*hasCleanup=*/true, loc,
                                       ehCleanupEntry, rewriter);
    }

    // All normal flow blocks are inserted before this point: either before
    // the unwind block (if it exists), or before the EH cleanup entry (if EH
    // cleanup exists but no unwind block is needed), or before the continue
    // block.
    mlir::Block *normalInsertPt =
        unwindBlock ? unwindBlock
                    : (ehCleanupEntry ? ehCleanupEntry : continueBlock);

    // Inline the body region.
    rewriter.inlineRegionBefore(cleanupOp.getBodyRegion(), normalInsertPt);

    // Inline the cleanup region for the normal cleanup path. (For EH-only
    // cleanup, the original region is left in the op and erased with it; the
    // EH path uses the clone made by buildEHCleanupBlocks.)
    if (hasNormalCleanup)
      rewriter.inlineRegionBefore(cleanupOp.getCleanupRegion(), normalInsertPt);

    // Branch from current block to body entry.
    rewriter.setInsertionPointToEnd(currentBlock);
    cir::BrOp::create(rewriter, loc, bodyEntry);

    // Handle normal exits.
    mlir::LogicalResult result = mlir::success();
    if (hasNormalCleanup) {
      // Create the exit/dispatch block (after cleanup, before continue).
      mlir::Block *exitBlock = rewriter.createBlock(normalInsertPt);

      // Rewrite the cleanup region's yield to branch to exit block.
      rewriter.setInsertionPoint(cleanupYield);
      rewriter.replaceOpWithNewOp<cir::BrOp>(cleanupYield, exitBlock);

      if (isMultiExit) {
        // Build the dispatch switch in the exit block.
        rewriter.setInsertionPointToEnd(exitBlock);

        // Load the destination slot value.
        auto slotValue = cir::LoadOp::create(
            rewriter, loc, destSlot, /*isDeref=*/false,
            /*isVolatile=*/false, /*alignment=*/mlir::IntegerAttr(),
            cir::SyncScopeKindAttr(), cir::MemOrderAttr());

        // Create destination blocks for each exit and collect switch case info.
        llvm::SmallVector<mlir::APInt, 8> caseValues;
        llvm::SmallVector<mlir::Block *, 8> caseDestinations;
        llvm::SmallVector<mlir::ValueRange, 8> caseOperands;
        cir::IntType s32Type =
            cir::IntType::get(rewriter.getContext(), 32, /*isSigned=*/true);

        for (const CleanupExit &exit : exits) {
          // Create a block for this destination.
          mlir::Block *destBlock = rewriter.createBlock(normalInsertPt);
          rewriter.setInsertionPointToEnd(destBlock);
          result =
              createExitTerminator(exit.exitOp, loc, continueBlock, rewriter);

          // Add to switch cases.
          caseValues.push_back(
              llvm::APInt(32, static_cast<uint64_t>(exit.destinationId), true));
          caseDestinations.push_back(destBlock);
          caseOperands.push_back(mlir::ValueRange());

          // Replace the original exit op with: store dest ID, branch to
          // cleanup.
          rewriter.setInsertionPoint(exit.exitOp);
          auto destIdConst = cir::ConstantOp::create(
              rewriter, loc, cir::IntAttr::get(s32Type, exit.destinationId));
          cir::StoreOp::create(rewriter, loc, destIdConst, destSlot,
                               /*isVolatile=*/false,
                               /*alignment=*/mlir::IntegerAttr(),
                               cir::SyncScopeKindAttr(), cir::MemOrderAttr());
          rewriter.replaceOpWithNewOp<cir::BrOp>(exit.exitOp, cleanupEntry);

          // If the exit terminator creation failed, we're going to end up with
          // partially flattened code, but we'll also have reported an error so
          // that's OK. We need to finish out this function to keep the IR in a
          // valid state to help diagnose the error. This is a temporary
          // possibility during development. It shouldn't ever happen after the
          // implementation is complete.
          if (result.failed())
            break;
        }

        // Create the default destination (unreachable).
        mlir::Block *defaultBlock = rewriter.createBlock(normalInsertPt);
        rewriter.setInsertionPointToEnd(defaultBlock);
        cir::UnreachableOp::create(rewriter, loc);

        // Build the switch.flat operation in the exit block.
        rewriter.setInsertionPointToEnd(exitBlock);
        cir::SwitchFlatOp::create(rewriter, loc, slotValue, defaultBlock,
                                  mlir::ValueRange(), caseValues,
                                  caseDestinations, caseOperands);
      } else {
        // Single exit: put the appropriate terminator directly in the exit
        // block. No destination slot or dispatch is needed.
        rewriter.setInsertionPointToEnd(exitBlock);
        mlir::Operation *exitOp = exits[0].exitOp;
        result = createExitTerminator(exitOp, loc, continueBlock, rewriter);

        // Replace body exit with branch to cleanup entry.
        rewriter.setInsertionPoint(exitOp);
        rewriter.replaceOpWithNewOp<cir::BrOp>(exitOp, cleanupEntry);
      }
    } else {
      // EH-only cleanup: normal exits skip the cleanup entirely.
      // Replace yield exits with branches to the continue block.
      for (CleanupExit &exit : exits) {
        if (isa<cir::YieldOp>(exit.exitOp)) {
          rewriter.setInsertionPoint(exit.exitOp);
          rewriter.replaceOpWithNewOp<cir::BrOp>(exit.exitOp, continueBlock);
        }
        // Non-yield exits (break, continue, return) stay as-is since no normal
        // cleanup is needed.
      }
    }

    // Replace non-nothrow calls with try_call operations. All calls within
    // this cleanup scope share the same unwind destination. (unwindBlock is
    // non-null whenever callsToRewrite is non-empty and EH cleanup is on.)
    if (hasEHCleanup) {
      for (cir::CallOp callOp : callsToRewrite)
        replaceCallWithTryCall(callOp, unwindBlock, loc, rewriter);
    }

    // Chain inner EH cleanup resume ops to this cleanup's EH handler.
    // Each cir.resume from an already-flattened inner cleanup is replaced
    // with a branch to the outer EH cleanup entry, passing the eh_token
    // from the inner's begin_cleanup so that the same in-flight exception
    // flows through the outer cleanup before unwinding to the caller.
    if (ehCleanupEntry) {
      for (cir::ResumeOp resumeOp : resumeOpsToChain) {
        mlir::Value ehToken = resumeOp.getEhToken();
        rewriter.setInsertionPoint(resumeOp);
        rewriter.replaceOpWithNewOp<cir::BrOp>(
            resumeOp, mlir::ValueRange{ehToken}, ehCleanupEntry);
      }
    }

    // Erase the original cleanup scope op.
    rewriter.eraseOp(cleanupOp);

    return result;
  }
1354
1355 mlir::LogicalResult
1356 matchAndRewrite(cir::CleanupScopeOp cleanupOp,
1357 mlir::PatternRewriter &rewriter) const override {
1358 mlir::OpBuilder::InsertionGuard guard(rewriter);
1359
1360 // Nested cleanup scopes and try operations must be flattened before the
1361 // enclosing cleanup scope so that EH cleanup inside them is properly
1362 // handled. Fail the match so the pattern rewriter processes them first.
1363 bool hasNestedOps = cleanupOp.getBodyRegion()
1364 .walk([&](mlir::Operation *op) {
1365 if (isa<cir::CleanupScopeOp, cir::TryOp>(op))
1366 return mlir::WalkResult::interrupt();
1367 return mlir::WalkResult::advance();
1368 })
1369 .wasInterrupted();
1370 if (hasNestedOps)
1371 return mlir::failure();
1372
1373 cir::CleanupKind cleanupKind = cleanupOp.getCleanupKind();
1374
1375 // Throwing calls in the cleanup region of an EH-enabled cleanup scope
1376 // are not yet supported. Such calls would need their own EH handling
1377 // (e.g., terminate or nested cleanup) during the unwind path.
1378 if (cleanupKind != cir::CleanupKind::Normal) {
1379 llvm::SmallVector<cir::CallOp> cleanupThrowingCalls;
1380 collectThrowingCalls(cleanupOp.getCleanupRegion(), cleanupThrowingCalls);
1381 if (!cleanupThrowingCalls.empty())
1382 return cleanupOp->emitError(
1383 "throwing calls in cleanup region are not yet implemented");
1384 }
1385
1386 // Collect all exits from the body region.
1387 llvm::SmallVector<CleanupExit> exits;
1388 int nextId = 0;
1389 collectExits(cleanupOp.getBodyRegion(), exits, nextId);
1390
1391 assert(!exits.empty() && "cleanup scope body has no exit");
1392
1393 // Collect non-nothrow calls that need to be converted to try_call.
1394 // This is only needed for EH and All cleanup kinds, but the vector
1395 // will simply be empty for Normal cleanup.
1396 llvm::SmallVector<cir::CallOp> callsToRewrite;
1397 if (cleanupKind != cir::CleanupKind::Normal)
1398 collectThrowingCalls(cleanupOp.getBodyRegion(), callsToRewrite);
1399
1400 // Collect resume ops from already-flattened inner cleanup scopes that
1401 // need to chain through this cleanup's EH handler.
1402 llvm::SmallVector<cir::ResumeOp> resumeOpsToChain;
1403 if (cleanupKind != cir::CleanupKind::Normal)
1404 collectResumeOps(cleanupOp.getBodyRegion(), resumeOpsToChain);
1405
1406 return flattenCleanup(cleanupOp, exits, callsToRewrite, resumeOpsToChain,
1407 rewriter);
1408 }
1409};
1410
1411class CIRTryOpFlattening : public mlir::OpRewritePattern<cir::TryOp> {
1412public:
1413 using OpRewritePattern<cir::TryOp>::OpRewritePattern;
1414
1415 // Build the catch dispatch block with a cir.eh.dispatch operation.
1416 // The dispatch block receives an !cir.eh_token argument and dispatches
1417 // to the appropriate catch handler blocks based on exception types.
1418 mlir::Block *buildCatchDispatchBlock(
1419 cir::TryOp tryOp, mlir::ArrayAttr handlerTypes,
1420 llvm::SmallVectorImpl<mlir::Block *> &catchHandlerBlocks,
1421 mlir::Location loc, mlir::Block *insertBefore,
1422 mlir::PatternRewriter &rewriter) const {
1423 mlir::Block *dispatchBlock = rewriter.createBlock(insertBefore);
1424 auto ehTokenType = cir::EhTokenType::get(rewriter.getContext());
1425 mlir::Value ehToken = dispatchBlock->addArgument(ehTokenType, loc);
1426
1427 rewriter.setInsertionPointToEnd(dispatchBlock);
1428
1429 // Build the catch types and destinations for the dispatch.
1430 llvm::SmallVector<mlir::Attribute> catchTypeAttrs;
1431 llvm::SmallVector<mlir::Block *> catchDests;
1432 mlir::Block *defaultDest = nullptr;
1433 bool defaultIsCatchAll = false;
1434
1435 for (auto [typeAttr, handlerBlock] :
1436 llvm::zip(handlerTypes, catchHandlerBlocks)) {
1437 if (mlir::isa<cir::CatchAllAttr>(typeAttr)) {
1438 assert(!defaultDest && "multiple catch_all or unwind handlers");
1439 defaultDest = handlerBlock;
1440 defaultIsCatchAll = true;
1441 } else if (mlir::isa<cir::UnwindAttr>(typeAttr)) {
1442 assert(!defaultDest && "multiple catch_all or unwind handlers");
1443 defaultDest = handlerBlock;
1444 defaultIsCatchAll = false;
1445 } else {
1446 // This is a typed catch handler (GlobalViewAttr with type info).
1447 catchTypeAttrs.push_back(typeAttr);
1448 catchDests.push_back(handlerBlock);
1449 }
1450 }
1451
1452 assert(defaultDest && "dispatch must have a catch_all or unwind handler");
1453
1454 mlir::ArrayAttr catchTypesArrayAttr;
1455 if (!catchTypeAttrs.empty())
1456 catchTypesArrayAttr = rewriter.getArrayAttr(catchTypeAttrs);
1457
1458 cir::EhDispatchOp::create(rewriter, loc, ehToken, catchTypesArrayAttr,
1459 defaultIsCatchAll, defaultDest, catchDests);
1460
1461 return dispatchBlock;
1462 }
1463
  // Flatten a single catch handler region. Each handler region has an
  // !cir.eh_token argument and starts with cir.begin_catch, followed by
  // a cir.cleanup.scope containing the handler body (with cir.end_catch in
  // its cleanup region), and ending with cir.yield.
  //
  // After flattening, the handler region becomes a block that receives the
  // eh_token, calls begin_catch, runs the handler body inline, calls
  // end_catch, and branches to the continue block.
  //
  // The cleanup scope inside the catch handler is expected to have been
  // flattened before we get here, so what we see in the handler region is
  // already flat code with begin_catch at the top and end_catch in any place
  // that we would exit the catch handler. We just need to inline the region
  // and fix up terminators.
  //
  // Returns the handler's entry block (the dispatch target for this handler).
  mlir::Block *flattenCatchHandler(mlir::Region &handlerRegion,
                                   mlir::Block *continueBlock,
                                   mlir::Location loc,
                                   mlir::Block *insertBefore,
                                   mlir::PatternRewriter &rewriter) const {
    // The handler region entry block has the !cir.eh_token argument.
    // Capture it before inlining detaches the region.
    mlir::Block *handlerEntry = &handlerRegion.front();

    // Inline the handler region before insertBefore.
    rewriter.inlineRegionBefore(handlerRegion, insertBefore);

    // Replace yield terminators in the handler with branches to continue.
    // The inlined blocks now occupy [handlerEntry, insertBefore) in the
    // parent region's block list.
    for (mlir::Block &block : llvm::make_range(handlerEntry->getIterator(),
                                               insertBefore->getIterator())) {
      if (auto yieldOp = dyn_cast<cir::YieldOp>(block.getTerminator())) {
        // Verify that end_catch is the last non-branch operation before
        // this yield. After cleanup scope flattening, end_catch may be in
        // a predecessor block rather than immediately before the yield.
        // Walk back through the single-predecessor chain, verifying that
        // each intermediate block contains only a branch terminator, until
        // we find end_catch as the last non-terminator in some block.
        assert([&]() {
          // Check if end_catch immediately precedes the yield.
          if (mlir::Operation *prev = yieldOp->getPrevNode())
            return isa<cir::EndCatchOp>(prev);
          // The yield is alone in its block. Walk backward through
          // single-predecessor blocks that contain only a branch.
          mlir::Block *b = block.getSinglePredecessor();
          while (b) {
            mlir::Operation *term = b->getTerminator();
            if (mlir::Operation *prev = term->getPrevNode())
              return isa<cir::EndCatchOp>(prev);
            if (!isa<cir::BrOp>(term))
              return false;
            b = b->getSinglePredecessor();
          }
          return false;
        }() && "expected end_catch as last operation before yield "
               "in catch handler, with only branches in between");
        rewriter.setInsertionPoint(yieldOp);
        rewriter.replaceOpWithNewOp<cir::BrOp>(yieldOp, continueBlock);
      }
    }

    return handlerEntry;
  }
1524
1525 // Flatten an unwind handler region. The unwind region just contains a
1526 // cir.resume that continues unwinding. We inline it and leave the resume
1527 // in place. If this try op is nested inside an EH cleanup or another try op,
1528 // the enclosing op will rewrite the resume as a branch to its cleanup or
1529 // dispatch block when it is flattened. Otherwise, the resume will unwind to
1530 // the caller.
1531 mlir::Block *flattenUnwindHandler(mlir::Region &unwindRegion,
1532 mlir::Location loc,
1533 mlir::Block *insertBefore,
1534 mlir::PatternRewriter &rewriter) const {
1535 mlir::Block *unwindEntry = &unwindRegion.front();
1536 rewriter.inlineRegionBefore(unwindRegion, insertBefore);
1537 return unwindEntry;
1538 }
1539
  /// Flattens a cir.try op: inlines the try body and all handler regions into
  /// the parent region, builds the catch dispatch block and (if needed) the
  /// shared unwind block, rewrites throwing calls into try_call ops, and
  /// redirects resume ops from inner cleanup scopes to the dispatch block.
  mlir::LogicalResult
  matchAndRewrite(cir::TryOp tryOp,
                  mlir::PatternRewriter &rewriter) const override {
    // Nested try ops and cleanup scopes must be flattened before the enclosing
    // try so that EH cleanup inside them is properly handled. Fail the match so
    // the pattern rewriter will process nested ops first.
    bool hasNestedOps =
        tryOp
            ->walk([&](mlir::Operation *op) {
              if (isa<cir::CleanupScopeOp, cir::TryOp>(op) && op != tryOp)
                return mlir::WalkResult::interrupt();
              return mlir::WalkResult::advance();
            })
            .wasInterrupted();
    if (hasNestedOps)
      return mlir::failure();

    // Restore the insertion point when this scope exits; the rewrites below
    // move it around freely.
    mlir::OpBuilder::InsertionGuard guard(rewriter);
    mlir::Location loc = tryOp.getLoc();

    // Handler types may be null/empty when the try has no catch clauses.
    mlir::ArrayAttr handlerTypes = tryOp.getHandlerTypesAttr();
    mlir::MutableArrayRef<mlir::Region> handlerRegions =
        tryOp.getHandlerRegions();

    // Collect throwing calls in the try body.
    llvm::SmallVector<cir::CallOp> callsToRewrite;
    collectThrowingCalls(tryOp.getTryRegion(), callsToRewrite);

    // Collect resume ops from already-flattened cleanup scopes in the try body.
    llvm::SmallVector<cir::ResumeOp> resumeOpsToChain;
    collectResumeOps(tryOp.getTryRegion(), resumeOpsToChain);

    // Split the current block at the try op; everything after it becomes the
    // continuation that normal (non-exceptional) control flow reaches.
    mlir::Block *currentBlock = rewriter.getInsertionBlock();
    mlir::Block *continueBlock =
        rewriter.splitBlock(currentBlock, rewriter.getInsertionPoint());

    // Get references to try body blocks before inlining (inlineRegionBefore
    // empties the region).
    mlir::Block *bodyEntry = &tryOp.getTryRegion().front();
    mlir::Block *bodyExit = &tryOp.getTryRegion().back();

    // Inline the try body region before the continue block.
    rewriter.inlineRegionBefore(tryOp.getTryRegion(), continueBlock);

    // Branch from the current block to the body entry.
    rewriter.setInsertionPointToEnd(currentBlock);
    cir::BrOp::create(rewriter, loc, bodyEntry);

    // Replace the try body's yield terminator with a branch to continue.
    if (auto bodyYield = dyn_cast<cir::YieldOp>(bodyExit->getTerminator())) {
      rewriter.setInsertionPoint(bodyYield);
      rewriter.replaceOpWithNewOp<cir::BrOp>(bodyYield, continueBlock);
    }

    // If there are no handlers, we're done.
    if (!handlerTypes || handlerTypes.empty()) {
      rewriter.eraseOp(tryOp);
      return mlir::success();
    }

    // If there are no throwing calls and no resume ops from inner cleanup
    // scopes, exceptions cannot reach the catch handlers. Skip handler and
    // dispatch block creation — the handler regions will be dropped when
    // the try op is erased.
    if (callsToRewrite.empty() && resumeOpsToChain.empty()) {
      rewriter.eraseOp(tryOp);
      return mlir::success();
    }

    // Build the catch handler blocks.

    // First, flatten all handler regions and collect the entry blocks.
    // Entries are kept in handler order so dispatch can match types in order.
    llvm::SmallVector<mlir::Block *> catchHandlerBlocks;

    for (const auto &[idx, typeAttr] : llvm::enumerate(handlerTypes)) {
      mlir::Region &handlerRegion = handlerRegions[idx];

      if (mlir::isa<cir::UnwindAttr>(typeAttr)) {
        mlir::Block *unwindEntry =
            flattenUnwindHandler(handlerRegion, loc, continueBlock, rewriter);
        catchHandlerBlocks.push_back(unwindEntry);
      } else {
        mlir::Block *handlerEntry = flattenCatchHandler(
            handlerRegion, continueBlock, loc, continueBlock, rewriter);
        catchHandlerBlocks.push_back(handlerEntry);
      }
    }

    // Build the catch dispatch block.
    mlir::Block *dispatchBlock =
        buildCatchDispatchBlock(tryOp, handlerTypes, catchHandlerBlocks, loc,
                                catchHandlerBlocks.front(), rewriter);

    // Build a block to be the unwind destination for throwing calls and
    // replace the calls with try_call ops. Note that the unwind block created
    // here is something different than the unwind handler that we may have
    // created above. The unwind handler continues unwinding after uncaught
    // exceptions. This is the block that will eventually become the landing
    // pad for invoke instructions.
    bool hasCleanup = tryOp.getCleanup();
    if (!callsToRewrite.empty()) {
      // Create a shared unwind block for all throwing calls.
      mlir::Block *unwindBlock = buildUnwindBlock(dispatchBlock, hasCleanup,
                                                  loc, dispatchBlock, rewriter);

      for (cir::CallOp callOp : callsToRewrite)
        replaceCallWithTryCall(callOp, unwindBlock, loc, rewriter);
    }

    // Chain resume ops from inner cleanup scopes.
    // Resume ops from already-flattened cleanup scopes within the try body
    // should branch to the catch dispatch block instead of unwinding directly.
    // The eh_token is forwarded as the dispatch block's argument.
    for (cir::ResumeOp resumeOp : resumeOpsToChain) {
      mlir::Value ehToken = resumeOp.getEhToken();
      rewriter.setInsertionPoint(resumeOp);
      rewriter.replaceOpWithNewOp<cir::BrOp>(
          resumeOp, mlir::ValueRange{ehToken}, dispatchBlock);
    }

    // Finally, erase the original try op.
    rewriter.eraseOp(tryOp);

    return mlir::success();
  }
1664};
1665
1666void populateFlattenCFGPatterns(RewritePatternSet &patterns) {
1667 patterns
1668 .add<CIRIfFlattening, CIRLoopOpInterfaceFlattening, CIRScopeOpFlattening,
1669 CIRSwitchOpFlattening, CIRTernaryOpFlattening,
1670 CIRCleanupScopeOpFlattening, CIRTryOpFlattening>(
1671 patterns.getContext());
1672}
1673
1674void CIRFlattenCFGPass::runOnOperation() {
1675 RewritePatternSet patterns(&getContext());
1676 populateFlattenCFGPatterns(patterns);
1677
1678 // Collect operations to apply patterns.
1679 llvm::SmallVector<Operation *, 16> ops;
1680 getOperation()->walk<mlir::WalkOrder::PostOrder>([&](Operation *op) {
1681 if (isa<IfOp, ScopeOp, SwitchOp, LoopOpInterface, TernaryOp, CleanupScopeOp,
1682 TryOp>(op))
1683 ops.push_back(op);
1684 });
1685
1686 // Apply patterns.
1687 if (applyOpPatternsGreedily(ops, std::move(patterns)).failed())
1688 signalPassFailure();
1689}
1690
1691} // namespace
1692
1693namespace mlir {
1694
1695std::unique_ptr<Pass> createCIRFlattenCFGPass() {
1696 return std::make_unique<CIRFlattenCFGPass>();
1697}
1698
1699} // namespace mlir
__device__ __2f16 b
const internal::VariadicAllOfMatcher< Attr > attr
llvm::APInt APInt
Definition FixedPoint.h:19
ASTEdit insertBefore(RangeSelector S, TextGenerator Replacement)
Inserts Replacement before S, leaving the source selected by \S unchanged.
unsigned long uint64_t
std::unique_ptr< Pass > createCIRFlattenCFGPass()
int const char * function
Definition c++config.h:31
float __ovld __cnfn step(float, float)
Returns 0.0 if x < edge, otherwise it returns 1.0.
static bool stackSaveOp()