15#include "mlir/Dialect/Func/IR/FuncOps.h"
16#include "mlir/IR/Block.h"
17#include "mlir/IR/Builders.h"
18#include "mlir/IR/PatternMatch.h"
19#include "mlir/Interfaces/SideEffectInterfaces.h"
20#include "mlir/Support/LogicalResult.h"
21#include "mlir/Transforms/DialectConversion.h"
22#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
28#include "llvm/ADT/TypeSwitch.h"
34#define GEN_PASS_DEF_CIRFLATTENCFG
35#include "clang/CIR/Dialect/Passes.h.inc"
41void lowerTerminator(mlir::Operation *op, mlir::Block *dest,
42 mlir::PatternRewriter &rewriter) {
43 assert(op->hasTrait<mlir::OpTrait::IsTerminator>() &&
"not a terminator");
44 mlir::OpBuilder::InsertionGuard guard(rewriter);
45 rewriter.setInsertionPoint(op);
46 rewriter.replaceOpWithNewOp<cir::BrOp>(op, dest);
51template <
typename... Ops>
52void walkRegionSkipping(
54 mlir::function_ref<mlir::WalkResult(mlir::Operation *)> callback) {
55 region.walk<mlir::WalkOrder::PreOrder>([&](mlir::Operation *op) {
57 return mlir::WalkResult::skip();
71static bool hasNestedOpsToFlatten(mlir::Region ®ion) {
73 .walk([](mlir::Operation *op) {
74 if (op->getNumRegions() > 0 && !isa<cir::CaseOp>(op))
75 return mlir::WalkResult::interrupt();
76 return mlir::WalkResult::advance();
87static bool isNonReturningTerminator(mlir::Operation *op) {
88 return mlir::isa_and_nonnull<cir::UnreachableOp, cir::TrapOp>(op);
106static mlir::LogicalResult
107rewriteRegionExitToContinue(mlir::PatternRewriter &rewriter,
108 mlir::Region ®ion, mlir::Block *continueBlock,
109 llvm::StringRef regionDescription) {
110 mlir::Operation *terminator = region.back().getTerminator();
111 rewriter.setInsertionPointToEnd(®ion.back());
112 if (
auto yieldOp = mlir::dyn_cast<cir::YieldOp>(terminator)) {
113 rewriter.replaceOpWithNewOp<cir::BrOp>(yieldOp, yieldOp.getArgs(),
115 return mlir::success();
117 if (isNonReturningTerminator(terminator))
118 return mlir::success();
119 terminator->emitError(
"unexpected terminator in ")
121 <<
" region, expected yield, unreachable, or trap, got: "
122 << terminator->getName();
123 return mlir::failure();
126struct CIRFlattenCFGPass :
public impl::CIRFlattenCFGBase<CIRFlattenCFGPass> {
128 CIRFlattenCFGPass() =
default;
129 void runOnOperation()
override;
132struct CIRIfFlattening :
public mlir::OpRewritePattern<cir::IfOp> {
133 using OpRewritePattern<IfOp>::OpRewritePattern;
136 matchAndRewrite(cir::IfOp ifOp,
137 mlir::PatternRewriter &rewriter)
const override {
138 mlir::OpBuilder::InsertionGuard guard(rewriter);
139 mlir::Location loc = ifOp.getLoc();
140 bool emptyElse = ifOp.getElseRegion().empty();
141 mlir::Block *currentBlock = rewriter.getInsertionBlock();
142 mlir::Block *remainingOpsBlock =
143 rewriter.splitBlock(currentBlock, rewriter.getInsertionPoint());
144 mlir::Block *continueBlock;
145 if (ifOp->getResults().empty())
146 continueBlock = remainingOpsBlock;
148 llvm_unreachable(
"NYI");
151 mlir::Block *thenBeforeBody = &ifOp.getThenRegion().front();
152 mlir::Block *thenAfterBody = &ifOp.getThenRegion().back();
153 rewriter.inlineRegionBefore(ifOp.getThenRegion(), continueBlock);
155 rewriter.setInsertionPointToEnd(thenAfterBody);
156 if (
auto thenYieldOp =
157 dyn_cast<cir::YieldOp>(thenAfterBody->getTerminator())) {
158 rewriter.replaceOpWithNewOp<cir::BrOp>(thenYieldOp, thenYieldOp.getArgs(),
162 rewriter.setInsertionPointToEnd(continueBlock);
165 mlir::Block *elseBeforeBody =
nullptr;
166 mlir::Block *elseAfterBody =
nullptr;
168 elseBeforeBody = &ifOp.getElseRegion().front();
169 elseAfterBody = &ifOp.getElseRegion().back();
170 rewriter.inlineRegionBefore(ifOp.getElseRegion(), continueBlock);
172 elseBeforeBody = elseAfterBody = continueBlock;
175 rewriter.setInsertionPointToEnd(currentBlock);
176 cir::BrCondOp::create(rewriter, loc, ifOp.getCondition(), thenBeforeBody,
180 rewriter.setInsertionPointToEnd(elseAfterBody);
181 if (
auto elseYieldOP =
182 dyn_cast<cir::YieldOp>(elseAfterBody->getTerminator())) {
183 rewriter.replaceOpWithNewOp<cir::BrOp>(
184 elseYieldOP, elseYieldOP.getArgs(), continueBlock);
188 rewriter.replaceOp(ifOp, continueBlock->getArguments());
189 return mlir::success();
193class CIRScopeOpFlattening :
public mlir::OpRewritePattern<cir::ScopeOp> {
195 using OpRewritePattern<cir::ScopeOp>::OpRewritePattern;
198 matchAndRewrite(cir::ScopeOp scopeOp,
199 mlir::PatternRewriter &rewriter)
const override {
200 mlir::OpBuilder::InsertionGuard guard(rewriter);
201 mlir::Location loc = scopeOp.getLoc();
209 if (scopeOp.isEmpty()) {
210 rewriter.eraseOp(scopeOp);
211 return mlir::success();
216 mlir::Block *currentBlock = rewriter.getInsertionBlock();
217 mlir::Block *continueBlock =
218 rewriter.splitBlock(currentBlock, rewriter.getInsertionPoint());
219 if (scopeOp.getNumResults() > 0)
220 continueBlock->addArguments(scopeOp.getResultTypes(), loc);
223 mlir::Block *beforeBody = &scopeOp.getScopeRegion().front();
224 mlir::Block *afterBody = &scopeOp.getScopeRegion().back();
225 rewriter.inlineRegionBefore(scopeOp.getScopeRegion(), continueBlock);
228 rewriter.setInsertionPointToEnd(currentBlock);
230 cir::BrOp::create(rewriter, loc, mlir::ValueRange(), beforeBody);
234 rewriter.setInsertionPointToEnd(afterBody);
235 if (
auto yieldOp = dyn_cast<cir::YieldOp>(afterBody->getTerminator())) {
236 rewriter.replaceOpWithNewOp<cir::BrOp>(yieldOp, yieldOp.getArgs(),
241 rewriter.replaceOp(scopeOp, continueBlock->getArguments());
243 return mlir::success();
/// Flattens `cir::SwitchOp` into a `cir::SwitchFlatOp` terminator.
///
/// Returns failure while any switch region still holds ops with regions
/// (other than cir::CaseOp), so nested constructs are flattened first. Erases
/// a switch whose single body block holds nothing but its terminator.
/// Otherwise:
/// - `break`s of this switch become branches to an exit block.
/// - The body and every case region are inlined into the parent.
/// - Case yields become branches to the op following the case (fallthrough).
/// - The collected case values and destinations feed a SwitchFlatOp.
/// - Ranges narrower than kSmallRangeThreshold are expanded into single
///   values; wider ones get an unsigned range check placed in front of the
///   default destination.
247class CIRSwitchOpFlattening :
public mlir::OpRewritePattern<cir::SwitchOp> {
249 using OpRewritePattern<cir::SwitchOp>::OpRewritePattern;
// Replaces `yieldOp` with a cir::BrOp to `destination` that forwards the
// yield's operands as block arguments.
251 inline void rewriteYieldOp(mlir::PatternRewriter &rewriter,
252 cir::YieldOp yieldOp,
253 mlir::Block *destination)
const {
254 rewriter.setInsertionPoint(yieldOp);
255 rewriter.replaceOpWithNewOp<cir::BrOp>(yieldOp, yieldOp.getOperands(),
// Creates a new block just before `defaultDestination` that evaluates
//   (u32)(condition - lowerBound) <= (u32)(upperBound - lowerBound)
// with a single unsigned compare, and branches to `rangeDestination` when it
// holds. The false target and the returned value are on lines not visible
// here; presumably they are `defaultDestination` and `resBlock`, and the
// caller uses the result as the new default destination.
// NOTE(review): the constants and casts use 32-bit CIR int types regardless
// of the condition's type, so this assumes a 32-bit switch condition.
260 Block *condBrToRangeDestination(cir::SwitchOp op,
261 mlir::PatternRewriter &rewriter,
262 mlir::Block *rangeDestination,
263 mlir::Block *defaultDestination,
264 const APInt &lowerBound,
265 const APInt &upperBound)
const {
266 assert(lowerBound.sle(upperBound) &&
"Invalid range");
267 mlir::Block *resBlock = rewriter.createBlock(defaultDestination);
268 cir::IntType sIntType = cir::IntType::get(op.getContext(), 32,
true);
269 cir::IntType uIntType = cir::IntType::get(op.getContext(), 32,
false);
271 cir::ConstantOp rangeLength = cir::ConstantOp::create(
272 rewriter, op.getLoc(),
273 cir::IntAttr::get(sIntType, upperBound - lowerBound));
275 cir::ConstantOp lowerBoundValue = cir::ConstantOp::create(
276 rewriter, op.getLoc(), cir::IntAttr::get(sIntType, lowerBound));
277 mlir::Value diffValue = cir::SubOp::create(
278 rewriter, op.getLoc(), op.getCondition(), lowerBoundValue);
// Unsigned view of both sides: values below lowerBound wrap to large
// numbers and fail the <= test.
281 cir::CastOp uDiffValue = cir::CastOp::create(
282 rewriter, op.getLoc(), uIntType, CastKind::integral, diffValue);
283 cir::CastOp uRangeLength = cir::CastOp::create(
284 rewriter, op.getLoc(), uIntType, CastKind::integral, rangeLength);
286 cir::CmpOp cmpResult = cir::CmpOp::create(
287 rewriter, op.getLoc(), cir::CmpOpKind::le, uDiffValue, uRangeLength);
288 cir::BrCondOp::create(rewriter, op.getLoc(), cmpResult, rangeDestination,
294 matchAndRewrite(cir::SwitchOp op,
295 mlir::PatternRewriter &rewriter)
const override {
// Let nested region-holding ops be flattened first.
300 for (mlir::Region ®ion : op->getRegions())
301 if (hasNestedOpsToFlatten(region))
302 return mlir::failure();
// A body consisting only of its terminator has nothing to lower.
305 if (op.getBody().hasOneBlock() &&
306 op.getBody().front().without_terminator().empty()) {
307 rewriter.eraseOp(op);
308 return mlir::success();
311 llvm::SmallVector<CaseOp> cases;
312 op.collectCases(cases);
// Everything after the switch moves to exitBlock; breaks and the body's
// final yield branch there.
315 mlir::Block *exitBlock = rewriter.splitBlock(
316 rewriter.getBlock(), op->getNextNode()->getIterator());
// Lower breaks that belong to this switch. Breaks inside nested loops or
// switches are skipped because they target those constructs.
329 walkRegionSkipping<cir::LoopOpInterface, cir::SwitchOp>(
330 op.getBody(), [&](mlir::Operation *op) {
331 if (!isa<cir::BreakOp>(op))
332 return mlir::WalkResult::advance();
334 lowerTerminator(op, exitBlock, rewriter);
335 return mlir::WalkResult::skip();
// Find the yield terminating the switch body. If several blocks end in a
// yield, the last one found wins.
341 cir::YieldOp switchYield =
nullptr;
343 for (mlir::Block &block :
344 llvm::make_early_inc_range(op.getBody().getBlocks()))
345 if (
auto yieldOp = dyn_cast<cir::YieldOp>(block.getTerminator()))
346 switchYield = yieldOp;
348 assert(!op.getBody().empty());
// Split off the switch op itself and inline the body before exitBlock.
349 mlir::Block *originalBlock = op->getBlock();
350 mlir::Block *swopBlock =
351 rewriter.splitBlock(originalBlock, op->getIterator());
352 rewriter.inlineRegionBefore(op.getBody(), exitBlock);
// NOTE(review): switchYield stays null when no body block ends in a yield.
// rewriteYieldOp would then use a null op unless a guard exists on a line
// not visible here; confirm.
355 rewriteYieldOp(rewriter, switchYield, exitBlock);
357 rewriter.setInsertionPointToEnd(originalBlock);
358 cir::BrOp::create(rewriter, op.getLoc(), swopBlock);
// Single-value cases for the flat switch.
363 llvm::SmallVector<mlir::APInt, 8> caseValues;
364 llvm::SmallVector<mlir::Block *, 8> caseDestinations;
365 llvm::SmallVector<mlir::ValueRange, 8> caseOperands;
// Range cases, handled after all cases are collected.
367 llvm::SmallVector<std::pair<APInt, APInt>> rangeValues;
368 llvm::SmallVector<mlir::Block *> rangeDestinations;
369 llvm::SmallVector<mlir::ValueRange> rangeOperands;
// Without a default case, unmatched values go to the exit block.
372 mlir::Block *defaultDestination = exitBlock;
373 mlir::ValueRange defaultOperands = exitBlock->getArguments();
// Classify each case and record its destination, the front block of its
// region.
376 for (cir::CaseOp caseOp : cases) {
377 mlir::Region ®ion = caseOp.getCaseRegion();
380 switch (caseOp.getKind()) {
381 case cir::CaseOpKind::Default:
382 defaultDestination = ®ion.front();
383 defaultOperands = defaultDestination->getArguments();
385 case cir::CaseOpKind::Range:
386 assert(caseOp.getValue().size() == 2 &&
387 "Case range should have 2 case value");
388 rangeValues.push_back(
389 {cast<cir::IntAttr>(caseOp.getValue()[0]).getValue(),
390 cast<cir::IntAttr>(caseOp.getValue()[1]).getValue()});
391 rangeDestinations.push_back(®ion.front());
392 rangeOperands.push_back(rangeDestinations.back()->getArguments());
394 case cir::CaseOpKind::Anyof:
395 case cir::CaseOpKind::Equal:
// Anyof cases can carry several values, all sharing one destination.
397 for (
const mlir::Attribute &value : caseOp.getValue()) {
398 caseValues.push_back(cast<cir::IntAttr>(value).getValue());
399 caseDestinations.push_back(®ion.front());
400 caseOperands.push_back(caseDestinations.back()->getArguments());
// Fallthrough: a yield ending a case block becomes a branch to the op that
// follows the case, after splitting the parent block there. Blocks whose
// terminator already has successors are excluded by the check below; its
// body is on a line not visible here.
406 for (mlir::Block &blk : region.getBlocks()) {
407 if (blk.getNumSuccessors())
410 if (
auto yieldOp = dyn_cast<cir::YieldOp>(blk.getTerminator())) {
411 mlir::Operation *nextOp = caseOp->getNextNode();
412 assert(nextOp &&
"caseOp is not expected to be the last op");
413 mlir::Block *oldBlock = nextOp->getBlock();
414 mlir::Block *newBlock =
415 rewriter.splitBlock(oldBlock, nextOp->getIterator());
416 rewriter.setInsertionPointToEnd(oldBlock);
417 cir::BrOp::create(rewriter, nextOp->getLoc(), mlir::ValueRange(),
419 rewriteYieldOp(rewriter, yieldOp, newBlock);
// Split at the case op, inline its region there, and branch into it from
// the preceding code.
423 mlir::Block *oldBlock = caseOp->getBlock();
424 mlir::Block *newBlock =
425 rewriter.splitBlock(oldBlock, caseOp->getIterator());
427 mlir::Block &entryBlock = caseOp.getCaseRegion().front();
428 rewriter.inlineRegionBefore(caseOp.getCaseRegion(), newBlock);
431 rewriter.setInsertionPointToEnd(oldBlock);
432 cir::BrOp::create(rewriter, caseOp.getLoc(), &entryBlock);
// Remove the now-empty case ops, plus their blocks when nothing branches to
// them.
436 for (cir::CaseOp caseOp : cases) {
437 mlir::Block *caseBlock = caseOp->getBlock();
440 if (caseBlock->hasNoPredecessors())
441 rewriter.eraseBlock(caseBlock);
443 rewriter.eraseOp(caseOp);
// Lower ranges: expand small ones into single values, and chain a range
// check in front of the default destination for wide ones. Inverted ranges
// (lower > upper) are excluded by the sgt check below.
446 for (
auto [rangeVal, operand, destination] :
447 llvm::zip(rangeValues, rangeOperands, rangeDestinations)) {
448 APInt lowerBound = rangeVal.first;
449 APInt upperBound = rangeVal.second;
452 if (lowerBound.sgt(upperBound))
// NOTE(review): APInt::ult(const APInt &) requires equal bit widths, so the
// comparison against APInt(32, ...) assumes 32-bit case values. The
// uint64_t overload ult(kSmallRangeThreshold) would avoid that assumption.
// NOTE(review): if upperBound is the signed maximum of its width, ++iValue
// wraps to the minimum and iValue.sle(upperBound) stays true, so the
// expansion loop below never terminates.
457 constexpr int kSmallRangeThreshold = 64;
458 if ((upperBound - lowerBound)
459 .ult(llvm::APInt(32, kSmallRangeThreshold))) {
460 for (APInt iValue = lowerBound; iValue.sle(upperBound); ++iValue) {
461 caseValues.push_back(iValue);
462 caseOperands.push_back(operand);
463 caseDestinations.push_back(destination);
469 condBrToRangeDestination(op, rewriter, destination,
470 defaultDestination, lowerBound, upperBound);
471 defaultOperands = operand;
// Replace the structured switch with the flat switch terminator.
475 rewriter.setInsertionPoint(op);
476 rewriter.replaceOpWithNewOp<cir::SwitchFlatOp>(
477 op, op.getCondition(), defaultDestination, defaultOperands, caseValues,
478 caseDestinations, caseOperands);
480 return mlir::success();
/// Flattens any op implementing cir::LoopOpInterface into explicit blocks and
/// branches. The loop has cond and body regions, plus an optional step region.
///
/// Returns failure while nested region-holding ops remain. Otherwise:
/// - The entry block branches to the loop's entry region.
/// - cir::ConditionOp becomes a cir::BrCondOp to body (true) or exit (false).
/// - `continue` branches to step, or to cond when there is no step.
/// - `break` branches to exit.
/// - Body yields branch to step/cond, and the step's yield goes back to cond.
/// The regions are then inlined before exit and the loop op is erased.
484class CIRLoopOpInterfaceFlattening
485 :
public mlir::OpInterfaceRewritePattern<cir::LoopOpInterface> {
487 using mlir::OpInterfaceRewritePattern<
488 cir::LoopOpInterface>::OpInterfaceRewritePattern;
// Replaces `op` (the cond region's cir::ConditionOp) with a conditional branch
// to `body` or to the exit block passed by the caller.
490 inline void lowerConditionOp(cir::ConditionOp op, mlir::Block *body,
492 mlir::PatternRewriter &rewriter)
const {
493 mlir::OpBuilder::InsertionGuard guard(rewriter);
494 rewriter.setInsertionPoint(op);
495 rewriter.replaceOpWithNewOp<cir::BrCondOp>(op, op.getCondition(), body,
500 matchAndRewrite(cir::LoopOpInterface op,
501 mlir::PatternRewriter &rewriter)
const final {
// Let nested region-holding ops be flattened first.
506 for (mlir::Region ®ion : op->getRegions())
507 if (hasNestedOpsToFlatten(region))
508 return mlir::failure();
// Split at the loop. The code following the loop becomes the `exit` block
// (its declaration is on a line not visible here).
511 mlir::Block *entry = rewriter.getInsertionBlock();
513 rewriter.splitBlock(entry, rewriter.getInsertionPoint());
514 mlir::Block *cond = &op.getCond().front();
515 mlir::Block *body = &op.getBody().front();
// The step block is null for loops without a step region.
517 (op.maybeGetStep() ? &op.maybeGetStep()->front() :
nullptr);
// Jump from the code before the loop into the loop's entry region.
520 rewriter.setInsertionPointToEnd(entry);
521 cir::BrOp::create(rewriter, op.getLoc(), &op.getEntry().front());
528 cast<cir::ConditionOp>(op.getCond().back().getTerminator());
529 lowerConditionOp(conditionOp, body, exit, rewriter);
// `continue` resumes at the step region if there is one, else at the
// condition. Nested loops are skipped because their continues are their own.
536 mlir::Block *dest = (
step ?
step : cond);
537 op.walkBodySkippingNestedLoops([&](mlir::Operation *op) {
538 if (!isa<cir::ContinueOp>(op))
539 return mlir::WalkResult::advance();
541 lowerTerminator(op, dest, rewriter);
542 return mlir::WalkResult::skip();
// `break` leaves the loop. Breaks inside nested loops or switches belong to
// those constructs and are skipped.
546 walkRegionSkipping<cir::LoopOpInterface, cir::SwitchOp>(
547 op.getBody(), [&](mlir::Operation *op) {
548 if (!isa<cir::BreakOp>(op))
549 return mlir::WalkResult::advance();
551 lowerTerminator(op, exit, rewriter);
552 return mlir::WalkResult::skip();
// Body blocks ending in a yield fall through to step (or cond).
// NOTE(review): bodyYield comes from dyn_cast and is null for blocks not
// ending in a yield. The null check, if any, is on a line not visible here.
556 for (mlir::Block &blk : op.getBody().getBlocks()) {
557 auto bodyYield = dyn_cast<cir::YieldOp>(blk.getTerminator());
559 lowerTerminator(bodyYield, (
step ?
step : cond), rewriter);
// The step region's yield branches back to the condition.
567 cast<cir::YieldOp>(op.maybeGetStep()->back().getTerminator()), cond,
// Move all loop regions in front of the exit block and drop the loop op.
571 rewriter.inlineRegionBefore(op.getCond(), exit);
572 rewriter.inlineRegionBefore(op.getBody(), exit);
574 rewriter.inlineRegionBefore(*op.maybeGetStep(), exit);
576 rewriter.eraseOp(op);
577 return mlir::success();
/// Flattens `cir::TernaryOp` into a cir::BrCondOp over the inlined true and
/// false regions. The regions are joined at a continuation block whose block
/// arguments carry the ternary's result (if any) and replace its uses.
581class CIRTernaryOpFlattening :
public mlir::OpRewritePattern<cir::TernaryOp> {
583 using OpRewritePattern<cir::TernaryOp>::OpRewritePattern;
586 matchAndRewrite(cir::TernaryOp op,
587 mlir::PatternRewriter &rewriter)
const override {
// Split at the ternary. The rest of the block moves to remainingOpsBlock.
588 Location loc = op->getLoc();
589 Block *condBlock = rewriter.getInsertionBlock();
590 Block::iterator opPosition = rewriter.getInsertionPoint();
591 Block *remainingOpsBlock = rewriter.splitBlock(condBlock, opPosition);
// One location per result type; the list stays empty when the op has no
// result.
592 llvm::SmallVector<mlir::Location, 2> locs;
595 if (op->getResultTypes().size())
// The continuation block takes the results as arguments and falls through to
// the remaining ops. createBlock leaves the insertion point inside it.
597 Block *continueBlock =
598 rewriter.createBlock(remainingOpsBlock, op->getResultTypes(), locs);
599 cir::BrOp::create(rewriter, loc, remainingOpsBlock);
601 Region &trueRegion = op.getTrueRegion();
602 Block *trueBlock = &trueRegion.front();
// NOTE(review): when the exit rewrite fails (an unexpected terminator,
// already diagnosed), this returns success. At that point the op has not been
// replaced and the IR is partially rewritten (block split, continuation
// created). Confirm the pass turns the emitted error into a pass failure.
607 if (failed(rewriteRegionExitToContinue(rewriter, trueRegion, continueBlock,
609 return mlir::success();
610 rewriter.inlineRegionBefore(trueRegion, continueBlock);
612 Block *falseBlock = continueBlock;
613 Region &falseRegion = op.getFalseRegion();
615 falseBlock = &falseRegion.front();
616 if (failed(rewriteRegionExitToContinue(rewriter, falseRegion, continueBlock,
618 return mlir::success();
619 rewriter.inlineRegionBefore(falseRegion, continueBlock);
// Terminate the original block with the conditional branch.
621 rewriter.setInsertionPointToEnd(condBlock);
622 cir::BrCondOp::create(rewriter, loc, op.getCond(), trueBlock, falseBlock);
624 rewriter.replaceOp(op, continueBlock->getArguments());
627 return mlir::success();
634static cir::AllocaOp getOrCreateCleanupDestSlot(cir::FuncOp funcOp,
635 mlir::PatternRewriter &rewriter,
636 mlir::Location loc) {
637 mlir::Block &entryBlock = funcOp.getBody().front();
640 auto it = llvm::find_if(entryBlock, [](
auto &op) {
641 return mlir::isa<AllocaOp>(&op) &&
642 mlir::cast<AllocaOp>(&op).getCleanupDestSlot();
644 if (it != entryBlock.end())
645 return mlir::cast<cir::AllocaOp>(*it);
648 mlir::OpBuilder::InsertionGuard guard(rewriter);
649 rewriter.setInsertionPointToStart(&entryBlock);
650 cir::IntType s32Type =
651 cir::IntType::get(rewriter.getContext(), 32,
true);
652 cir::PointerType ptrToS32Type = cir::PointerType::get(s32Type);
654 uint64_t alignment = dataLayout.getAlignment(s32Type,
true).value();
655 auto allocaOp = cir::AllocaOp::create(
656 rewriter, loc, ptrToS32Type, s32Type,
"__cleanup_dest_slot",
657 rewriter.getI64IntegerAttr(alignment));
658 allocaOp.setCleanupDestSlot(
true);
/// Appends to `callsToRewrite` every cir::CallOp nested anywhere in `region`
/// that is not marked nothrow.
670collectThrowingCalls(mlir::Region ®ion,
672 region.walk([&](cir::CallOp callOp) {
673 if (!callOp.getNothrow())
674 callsToRewrite.push_back(callOp);
/// Appends every cir::ThrowOp nested anywhere in `region` to
/// `throwsToRewrite`.
684collectThrows(mlir::Region ®ion,
687 [&](cir::ThrowOp throwOp) { throwsToRewrite.push_back(throwOp); });
696static void collectResumeOps(mlir::Region ®ion,
698 region.walk([&](cir::ResumeOp resumeOp) { resumeOps.push_back(resumeOp); });
704static mlir::Block *buildUnwindBlock(mlir::Block *dest,
bool isCleanupOnly,
706 mlir::Block *insertBefore,
707 mlir::PatternRewriter &rewriter) {
708 mlir::Block *unwindBlock = rewriter.createBlock(insertBefore);
709 rewriter.setInsertionPointToEnd(unwindBlock);
711 cir::EhInitiateOp::create(rewriter, loc, isCleanupOnly);
712 cir::BrOp::create(rewriter, loc, mlir::ValueRange{ehInitiate.getEhToken()},
720static mlir::Block *buildTerminateUnwindBlock(mlir::Location loc,
721 mlir::Block *insertBefore,
722 mlir::PatternRewriter &rewriter) {
723 mlir::Block *terminateBlock = rewriter.createBlock(insertBefore);
724 rewriter.setInsertionPointToEnd(terminateBlock);
725 auto ehInitiate = cir::EhInitiateOp::create(rewriter, loc,
false);
726 cir::EhTerminateOp::create(rewriter, loc, ehInitiate.getEhToken());
727 return terminateBlock;
730class CIRCleanupScopeOpFlattening
731 :
public mlir::OpRewritePattern<cir::CleanupScopeOp> {
733 using OpRewritePattern<cir::CleanupScopeOp>::OpRewritePattern;
738 mlir::Operation *exitOp;
744 CleanupExit(mlir::Operation *op,
int id) : exitOp(op), destinationId(id) {}
754 static bool gotoTargetsLabelInRegion(cir::GotoOp gotoOp,
755 mlir::Region ®ion) {
756 llvm::StringRef targetLabel = gotoOp.getLabel();
758 .walk([&](cir::LabelOp labelOp) {
759 if (labelOp.getLabel() == targetLabel)
760 return mlir::WalkResult::interrupt();
761 return mlir::WalkResult::advance();
/// Collects, in walk order, every operation by which control can leave the
/// cleanup scope body, and gives each one a fresh id taken from `nextId`.
/// Collected exits:
/// - cir::YieldOp terminators of the body's top-level blocks;
/// - cir::BreakOp (unless ignoreBreak), cir::ContinueOp and cir::ReturnOp;
/// - cir::GotoOp whose label is not inside the body region.
/// Special cases:
/// - Inside nested loops, only returns and escaping gotos count, because
///   break/continue target the loop.
/// - Inside nested switches, returns, continues and escaping gotos count
///   (break targets the switch); nested cleanup scopes and loops there are
///   handled recursively.
/// - Nested cleanup scopes are recursed into. The ignoreBreak value passed
///   for them is on a line not visible here.
784 void collectExits(mlir::Region &cleanupBodyRegion,
785 llvm::SmallVectorImpl<CleanupExit> &exits,
// Normal fallthrough exits: yields ending top-level body blocks.
790 for (mlir::Block &block : cleanupBodyRegion) {
791 auto *terminator = block.getTerminator();
792 if (isa<cir::YieldOp>(terminator))
793 exits.emplace_back(terminator, nextId++);
// A goto leaves the cleanup scope unless its label is inside the body.
800 auto isGotoThatExitsCleanup = [&](mlir::Operation *op) {
801 auto gotoOp = dyn_cast<cir::GotoOp>(op);
802 return gotoOp && !gotoTargetsLabelInRegion(gotoOp, cleanupBodyRegion);
// Within a nested loop, only returns and escaping gotos leave the scope.
809 auto collectExitsInLoop = [&](mlir::Operation *loopOp) {
810 loopOp->walk<mlir::WalkOrder::PreOrder>([&](mlir::Operation *nestedOp) {
811 if (isa<cir::ReturnOp>(nestedOp)) {
812 exits.emplace_back(nestedOp, nextId++);
813 }
else if (isGotoThatExitsCleanup(nestedOp)) {
814 exits.emplace_back(nestedOp, nextId++);
816 return mlir::WalkResult::advance();
// Declared up front so the switch and cleanup collectors can recurse
// into each other.
821 std::function<void(mlir::Region &,
bool)> collectExitsInCleanup;
// Within a nested switch, break targets the switch, but continue, return and
// escaping gotos leave the scope.
826 collectExitsInSwitch = [&](mlir::Operation *switchOp) {
827 switchOp->walk<mlir::WalkOrder::PreOrder>([&](mlir::Operation *nestedOp) {
828 if (isa<cir::CleanupScopeOp>(nestedOp)) {
831 collectExitsInCleanup(
832 cast<cir::CleanupScopeOp>(nestedOp).getBodyRegion(),
834 return mlir::WalkResult::skip();
835 }
else if (isa<cir::LoopOpInterface>(nestedOp)) {
836 collectExitsInLoop(nestedOp);
837 return mlir::WalkResult::skip();
838 }
else if (isa<cir::ReturnOp, cir::ContinueOp>(nestedOp)) {
839 exits.emplace_back(nestedOp, nextId++);
840 }
else if (isGotoThatExitsCleanup(nestedOp)) {
841 exits.emplace_back(nestedOp, nextId++);
843 return mlir::WalkResult::advance();
// Main collector. Nested constructs are delegated, and skip() keeps the walk
// from visiting their contents twice.
850 collectExitsInCleanup = [&](mlir::Region ®ion,
bool ignoreBreak) {
851 region.walk<mlir::WalkOrder::PreOrder>([&](mlir::Operation *op) {
858 if (!ignoreBreak && isa<cir::BreakOp>(op)) {
859 exits.emplace_back(op, nextId++);
860 }
else if (isa<cir::ContinueOp, cir::ReturnOp>(op)) {
861 exits.emplace_back(op, nextId++);
862 }
else if (isGotoThatExitsCleanup(op)) {
863 exits.emplace_back(op, nextId++);
864 }
else if (isa<cir::CleanupScopeOp>(op)) {
866 collectExitsInCleanup(cast<cir::CleanupScopeOp>(op).getBodyRegion(),
868 return mlir::WalkResult::skip();
869 }
else if (isa<cir::LoopOpInterface>(op)) {
873 collectExitsInLoop(op);
874 return mlir::WalkResult::skip();
875 }
else if (isa<cir::SwitchOp>(op)) {
879 collectExitsInSwitch(op);
880 return mlir::WalkResult::skip();
882 return mlir::WalkResult::advance();
// At the top level, breaks leave the scope (ignoreBreak = false).
887 collectExitsInCleanup(cleanupBodyRegion,
false);
/// Decides whether a return operand's defining op can be moved (sunk) into
/// the block that re-creates the return after the cleanup, instead of being
/// spilled through a temporary alloca (see getReturnOpOperands).
/// Checks visible here:
/// - the defining op must be a cir::ConstantOp or cir::LoadOp;
/// - the value must have exactly one use;
/// - it must be defined in the return's own block;
/// - a load must additionally read from a cir::AllocaOp in the function's
///   entry block.
/// The early-exit statements after each check are on lines not visible here.
/// NOTE(review): mlir::isa asserts on a null op. A null-defOp check (operand
/// is a block argument) presumably exists on a missing line; confirm.
893 static bool shouldSinkReturnOperand(mlir::Value operand,
894 cir::ReturnOp returnOp) {
896 mlir::Operation *defOp = operand.getDefiningOp();
902 if (!mlir::isa<cir::ConstantOp, cir::LoadOp>(defOp))
906 if (!operand.hasOneUse())
910 if (defOp->getBlock() != returnOp->getBlock())
913 if (
auto loadOp = mlir::dyn_cast<cir::LoadOp>(defOp)) {
915 mlir::Value ptr = loadOp.getAddr();
916 auto funcOp = returnOp->getParentOfType<cir::FuncOp>();
917 assert(funcOp &&
"Return op has no function parent?");
918 mlir::Block &funcEntryBlock = funcOp.getBody().front();
// Only loads from entry-block allocas qualify.
922 mlir::dyn_cast_if_present<cir::AllocaOp>(ptr.getDefiningOp()))
923 return allocaOp->getBlock() == &funcEntryBlock;
929 assert(mlir::isa<cir::ConstantOp>(defOp) &&
"Expected constant op");
/// Fills `returnValues` with the operands for a cir::ReturnOp re-created in
/// the rewriter's current insertion block (`destBlock`). That block runs
/// after the cleanup on behalf of the original exit `exitOp`.
/// Handling per operand:
/// - Operands accepted by shouldSinkReturnOperand have their defining op
///   moved to the end of destBlock and are used directly.
/// - Other operands are spilled. A fresh "__ret_operand_tmp" alloca is created
///   at the start of the function's entry block, the value is stored right
///   before `exitOp`, and it is reloaded at the end of destBlock.
/// Presumably the spill is needed because destBlock is reached through the
/// shared cleanup code and is not dominated by the operand's definition;
/// confirm.
/// NOTE(review): a new alloca is created per spilled operand on every call.
/// Nothing here reuses an existing temporary.
938 getReturnOpOperands(cir::ReturnOp returnOp, mlir::Operation *exitOp,
939 mlir::Location loc, mlir::PatternRewriter &rewriter,
940 llvm::SmallVectorImpl<mlir::Value> &returnValues)
const {
941 mlir::Block *destBlock = rewriter.getInsertionBlock();
942 auto funcOp = exitOp->getParentOfType<cir::FuncOp>();
943 assert(funcOp &&
"Return op has no function parent?");
944 mlir::Block &funcEntryBlock = funcOp.getBody().front();
946 for (mlir::Value operand : returnOp.getOperands()) {
947 if (shouldSinkReturnOperand(operand, returnOp)) {
// Cheap, single-use value: move its definition next to the new return.
949 mlir::Operation *defOp = operand.getDefiningOp();
950 rewriter.moveOpBefore(defOp, destBlock, destBlock->end());
951 returnValues.push_back(operand);
// Spill path: allocate a temporary in the entry block...
954 cir::AllocaOp alloca;
956 mlir::OpBuilder::InsertionGuard guard(rewriter);
957 rewriter.setInsertionPointToStart(&funcEntryBlock);
958 cir::CIRDataLayout dataLayout(
959 funcOp->getParentOfType<mlir::ModuleOp>());
961 dataLayout.getAlignment(operand.getType(),
true).value();
962 cir::PointerType ptrType = cir::PointerType::get(operand.getType());
963 alloca = cir::AllocaOp::create(rewriter, loc, ptrType,
964 operand.getType(),
"__ret_operand_tmp",
965 rewriter.getI64IntegerAttr(alignment));
// ...store the value before control leaves through exitOp...
970 mlir::OpBuilder::InsertionGuard guard(rewriter);
971 rewriter.setInsertionPoint(exitOp);
972 cir::StoreOp::create(rewriter, loc, operand, alloca,
975 cir::SyncScopeKindAttr(), cir::MemOrderAttr());
// ...and reload it in the destination block.
979 rewriter.setInsertionPointToEnd(destBlock);
980 auto loaded = cir::LoadOp::create(
981 rewriter, loc, alloca,
false,
982 false, mlir::IntegerAttr(),
983 cir::SyncScopeKindAttr(), cir::MemOrderAttr());
984 returnValues.push_back(loaded);
/// Re-creates, at the current insertion point, the control transfer of the
/// original body exit `exitOp`. The new op runs after the cleanup code.
/// Mapping:
/// - cir::YieldOp becomes a cir::BrOp to `continueBlock`.
/// - cir::BreakOp, cir::ContinueOp and cir::GotoOp (same label) are emitted
///   again, presumably so the enclosing construct can lower them later.
/// - cir::ReturnOp is emitted again, with operands from getReturnOpOperands
///   when it returns a value.
/// - Any other op gets a cir::UnreachableOp, and the function returns the
///   error emitted on that op.
994 createExitTerminator(mlir::Operation *exitOp, mlir::Location loc,
995 mlir::Block *continueBlock,
996 mlir::PatternRewriter &rewriter)
const {
997 return llvm::TypeSwitch<mlir::Operation *, mlir::LogicalResult>(exitOp)
998 .Case<cir::YieldOp>([&](
auto) {
1000 cir::BrOp::create(rewriter, loc, continueBlock);
1001 return mlir::success();
1003 .Case<cir::BreakOp>([&](
auto) {
1005 cir::BreakOp::create(rewriter, loc);
1006 return mlir::success();
1008 .Case<cir::ContinueOp>([&](
auto) {
1010 cir::ContinueOp::create(rewriter, loc);
1011 return mlir::success();
1013 .Case<cir::ReturnOp>([&](
auto returnOp) {
1017 if (returnOp.hasOperand()) {
1018 llvm::SmallVector<mlir::Value, 2> returnValues;
1019 getReturnOpOperands(returnOp, exitOp, loc, rewriter, returnValues);
1020 cir::ReturnOp::create(rewriter, loc, returnValues);
1022 cir::ReturnOp::create(rewriter, loc);
1024 return mlir::success();
1026 .Case<cir::GotoOp>([&](
auto gotoOp) {
1031 cir::GotoOp::create(rewriter, loc, gotoOp.getLabel());
1032 return mlir::success();
1034 .
Default([&](mlir::Operation *op) {
// Keep the block well-formed before reporting the error.
1035 cir::UnreachableOp::create(rewriter, loc);
1036 return op->emitError(
1037 "unexpected exit operation in cleanup scope body");
/// Checks that only the last block of `region` leaves the region. The visible
/// callers use it inside asserts.
/// Behavior for each non-final block, by terminator:
/// - Scope-leaving terminators (yield, return, resume_flat, continue, break,
///   goto), cir::TryCallOp, cir::EhDispatchOp, cir::SwitchFlatOp and
///   cir::IndirectBrOp evaluate to false, which makes the check fail.
/// - cir::UnreachableOp and cir::TrapOp evaluate to true.
/// - cir::BrOp and cir::BrCondOp are asserted to stay inside the region.
///   Their result lines are not visible here (presumably true).
/// - Any other terminator hits llvm_unreachable.
1043 static bool regionExitsOnlyFromLastBlock(mlir::Region ®ion) {
1044 for (mlir::Block &block : region) {
// The last block is allowed to exit the region.
1045 if (&block == ®ion.back())
1047 bool expectedTerminator =
1048 llvm::TypeSwitch<mlir::Operation *, bool>(block.getTerminator())
1055 .Case<cir::YieldOp, cir::ReturnOp, cir::ResumeFlatOp,
1056 cir::ContinueOp, cir::BreakOp, cir::GotoOp>(
1057 [](
auto) {
return false; })
1066 .Case<cir::TryCallOp>([](
auto) {
return false; })
1070 .Case<cir::EhDispatchOp>([](
auto) {
return false; })
1074 .Case<cir::SwitchFlatOp>([](
auto) {
return false; })
1077 .Case<cir::UnreachableOp, cir::TrapOp>([](
auto) {
return true; })
1079 .Case<cir::IndirectBrOp>([](
auto) {
return false; })
1082 .Case<cir::BrOp>([&](cir::BrOp brOp) {
1083 assert(brOp.getDest()->getParent() == ®ion &&
1084 "branch destination is not in the region");
1087 .Case<cir::BrCondOp>([&](cir::BrCondOp brCondOp) {
1088 assert(brCondOp.getDestTrue()->getParent() == ®ion &&
1089 "branch destination is not in the region");
1090 assert(brCondOp.getDestFalse()->getParent() == ®ion &&
1091 "branch destination is not in the region");
1095 .
Default([](mlir::Operation *) ->
bool {
1096 llvm_unreachable(
"unexpected terminator in cleanup region");
1098 if (!expectedTerminator)
/// Builds the exception-path copy of the cleanup region in front of
/// `insertBefore` and returns its entry block (presumably `clonedEntry`; the
/// return statement is not visible here).
/// Construction:
/// - The region is cloned, so the original stays available for the normal
///   path.
/// - The clone's entry gets an EH token block argument, and a
///   cir::BeginCleanupOp on that token is placed at its start.
/// - If the last cloned block ends in a cir::YieldOp, a cir::EndCleanupOp is
///   inserted and the yield becomes a cir::ResumeOp on the token.
/// - Any other final terminator is reported as not yet implemented.
1126 mlir::Block *buildEHCleanupBlocks(cir::CleanupScopeOp cleanupOp,
1128 mlir::Block *insertBefore,
1129 mlir::PatternRewriter &rewriter)
const {
1130 assert(regionExitsOnlyFromLastBlock(cleanupOp.getCleanupRegion()) &&
1131 "cleanup region has exits in non-final blocks");
// Remember the block preceding the insertion point so the clone's first
// block can be found afterwards. It is null if insertBefore is the first
// block of its region.
1135 mlir::Block *blockBeforeClone =
insertBefore->getPrevNode();
1138 rewriter.cloneRegionBefore(cleanupOp.getCleanupRegion(), insertBefore);
1141 mlir::Block *clonedEntry = blockBeforeClone
1142 ? blockBeforeClone->getNextNode()
// The EH path enters the cleanup carrying an exception token.
1147 auto ehTokenType = cir::EhTokenType::get(rewriter.getContext());
1148 mlir::Value ehToken = clonedEntry->addArgument(ehTokenType, loc);
1150 rewriter.setInsertionPointToStart(clonedEntry);
1151 auto beginCleanup = cir::BeginCleanupOp::create(rewriter, loc, ehToken);
// The clone ends right before insertBefore. Close the cleanup there and keep
// unwinding.
1155 mlir::Block *lastClonedBlock =
insertBefore->getPrevNode();
1157 mlir::dyn_cast<cir::YieldOp>(lastClonedBlock->getTerminator());
1159 rewriter.setInsertionPoint(yieldOp);
1160 cir::EndCleanupOp::create(rewriter, loc, beginCleanup.getCleanupToken());
1161 rewriter.replaceOpWithNewOp<cir::ResumeOp>(yieldOp, ehToken);
1163 cleanupOp->emitError(
"Not yet implemented: cleanup region terminated "
1164 "with non-yield operation");
/// Inlines a cir::CleanupScopeOp's body (and cleanup) into the parent and
/// routes the scope's exits through the cleanup code.
///
/// Normal cleanup (kind Normal or All):
/// - Every body exit branches to the cleanup entry.
/// - With a single exit, the cleanup's yield leads to a re-created copy of
///   that exit.
/// - With several exits, each exit first stores its id into the function's
///   cleanup dest slot. After the cleanup, a SwitchFlatOp on that slot
///   dispatches to per-exit blocks; unknown ids reach cir::UnreachableOp.
/// - Without a normal cleanup, body yields branch straight to the
///   continuation.
///
/// EH cleanup (kind EH or All):
/// - When the body has throwing calls, throws, or resumes to chain, the
///   cleanup is cloned into a separate EH path before the continuation.
/// - Resumes are redirected into that path.
/// - Throwing ops inside the clone get a terminate unwind block.
1193 flattenCleanup(cir::CleanupScopeOp cleanupOp,
1194 llvm::SmallVectorImpl<CleanupExit> &exits,
1195 llvm::SmallVectorImpl<cir::CallOp> &callsToRewrite,
1196 llvm::SmallVectorImpl<cir::ThrowOp> &throwsToRewrite,
1197 llvm::SmallVectorImpl<cir::ResumeOp> &resumeOpsToChain,
1198 mlir::PatternRewriter &rewriter)
const {
1199 mlir::Location loc = cleanupOp.getLoc();
1200 cir::CleanupKind cleanupKind = cleanupOp.getCleanupKind();
1201 bool hasNormalCleanup = cleanupKind == cir::CleanupKind::Normal ||
1202 cleanupKind == cir::CleanupKind::All;
1203 bool hasEHCleanup = cleanupKind == cir::CleanupKind::EH ||
1204 cleanupKind == cir::CleanupKind::All;
1205 bool isMultiExit = exits.size() > 1;
1208 mlir::Block *bodyEntry = &cleanupOp.getBodyRegion().front();
1209 mlir::Block *cleanupEntry = &cleanupOp.getCleanupRegion().front();
1210 mlir::Block *cleanupExit = &cleanupOp.getCleanupRegion().back();
1211 assert(regionExitsOnlyFromLastBlock(cleanupOp.getCleanupRegion()) &&
1212 "cleanup region has exits in non-final blocks");
// Only cleanup regions that end in cir.yield are supported so far.
1213 auto cleanupYield = dyn_cast<cir::YieldOp>(cleanupExit->getTerminator());
1214 if (!cleanupYield) {
1215 return rewriter.notifyMatchFailure(cleanupOp,
1216 "Not yet implemented: cleanup region "
1217 "terminated with non-yield operation");
// Several exits share one copy of the cleanup, so record which exit was
// taken in the function-wide dest slot.
1224 cir::AllocaOp destSlot;
1225 if (isMultiExit && hasNormalCleanup) {
1226 auto funcOp = cleanupOp->getParentOfType<cir::FuncOp>();
1228 return cleanupOp->emitError(
"cleanup scope not inside a function");
1229 destSlot = getOrCreateCleanupDestSlot(funcOp, rewriter, loc);
// Code after the scope moves to continueBlock.
1233 mlir::Block *currentBlock = rewriter.getInsertionBlock();
1234 mlir::Block *continueBlock =
1235 rewriter.splitBlock(currentBlock, rewriter.getInsertionPoint());
// EH path, built only when something in the body can unwind or resume:
// cloned cleanup blocks placed before continueBlock, with an unwind block in
// front when calls or throws need an unwind destination.
1245 mlir::Block *unwindBlock =
nullptr;
1246 mlir::Block *ehCleanupEntry =
nullptr;
1247 if (hasEHCleanup && (!callsToRewrite.empty() || !throwsToRewrite.empty() ||
1248 !resumeOpsToChain.empty())) {
1250 buildEHCleanupBlocks(cleanupOp, loc, continueBlock, rewriter);
1254 if (!callsToRewrite.empty() || !throwsToRewrite.empty())
1255 unwindBlock = buildUnwindBlock(ehCleanupEntry,
true,
1256 loc, ehCleanupEntry, rewriter);
// Normal-path blocks go in front of the EH blocks, or directly in front of
// continueBlock when there is no EH path.
1263 mlir::Block *normalInsertPt =
1264 unwindBlock ? unwindBlock
1265 : (ehCleanupEntry ? ehCleanupEntry : continueBlock);
1268 rewriter.inlineRegionBefore(cleanupOp.getBodyRegion(), normalInsertPt);
1271 if (hasNormalCleanup)
1272 rewriter.inlineRegionBefore(cleanupOp.getCleanupRegion(), normalInsertPt);
// Enter the body from the code preceding the scope.
1275 rewriter.setInsertionPointToEnd(currentBlock);
1276 cir::BrOp::create(rewriter, loc, bodyEntry);
1279 mlir::LogicalResult result = mlir::success();
1280 if (hasNormalCleanup) {
// The cleanup's yield now leads to exitBlock, which decides where to go next.
1282 mlir::Block *exitBlock = rewriter.createBlock(normalInsertPt);
1285 rewriter.setInsertionPoint(cleanupYield);
1286 rewriter.replaceOpWithNewOp<cir::BrOp>(cleanupYield, exitBlock);
1290 rewriter.setInsertionPointToEnd(exitBlock);
// Multi-exit: read back the id stored by whichever exit was taken.
1293 auto slotValue = cir::LoadOp::create(
1294 rewriter, loc, destSlot,
false,
1295 false, mlir::IntegerAttr(),
1296 cir::SyncScopeKindAttr(), cir::MemOrderAttr());
1299 llvm::SmallVector<mlir::APInt, 8> caseValues;
1300 llvm::SmallVector<mlir::Block *, 8> caseDestinations;
1301 llvm::SmallVector<mlir::ValueRange, 8> caseOperands;
1302 cir::IntType s32Type =
1303 cir::IntType::get(rewriter.getContext(), 32,
true);
// One destination block per exit, keyed by its destinationId, holding the
// re-created exit terminator.
1305 for (
const CleanupExit &exit : exits) {
1307 mlir::Block *destBlock = rewriter.createBlock(normalInsertPt);
1308 rewriter.setInsertionPointToEnd(destBlock);
1310 createExitTerminator(exit.exitOp, loc, continueBlock, rewriter);
1313 caseValues.push_back(
1314 llvm::APInt(32,
static_cast<uint64_t>(exit.destinationId),
true));
1315 caseDestinations.push_back(destBlock);
1316 caseOperands.push_back(mlir::ValueRange());
// At the original exit: store the id, then enter the cleanup instead.
1320 rewriter.setInsertionPoint(exit.exitOp);
1321 auto destIdConst = cir::ConstantOp::create(
1322 rewriter, loc, cir::IntAttr::get(s32Type, exit.destinationId));
1323 cir::StoreOp::create(rewriter, loc, destIdConst, destSlot,
1325 mlir::IntegerAttr(),
1326 cir::SyncScopeKindAttr(), cir::MemOrderAttr());
1327 rewriter.replaceOpWithNewOp<cir::BrOp>(exit.exitOp, cleanupEntry);
1335 if (result.failed())
// Ids that match no exit are unreachable.
1340 mlir::Block *defaultBlock = rewriter.createBlock(normalInsertPt);
1341 rewriter.setInsertionPointToEnd(defaultBlock);
1342 cir::UnreachableOp::create(rewriter, loc);
1345 rewriter.setInsertionPointToEnd(exitBlock);
1346 cir::SwitchFlatOp::create(rewriter, loc, slotValue, defaultBlock,
1347 mlir::ValueRange(), caseValues,
1348 caseDestinations, caseOperands);
// Single exit: after the cleanup, perform that exit directly; the exit
// itself now enters the cleanup.
1352 rewriter.setInsertionPointToEnd(exitBlock);
1353 mlir::Operation *exitOp = exits[0].exitOp;
1354 result = createExitTerminator(exitOp, loc, continueBlock, rewriter);
1357 rewriter.setInsertionPoint(exitOp);
1358 rewriter.replaceOpWithNewOp<cir::BrOp>(exitOp, cleanupEntry);
// No normal cleanup: body yields go straight to the continuation. Other exit
// kinds are left untouched.
1363 for (CleanupExit &exit : exits) {
1364 if (isa<cir::YieldOp>(exit.exitOp)) {
1365 rewriter.setInsertionPoint(exit.exitOp);
1366 rewriter.replaceOpWithNewOp<cir::BrOp>(exit.exitOp, continueBlock);
// NOTE(review): the bodies of these two loops are on lines not visible here.
// Presumably each call/throw is rewritten to unwind to unwindBlock.
1377 for (cir::CallOp callOp : callsToRewrite)
1379 for (cir::ThrowOp throwOp : throwsToRewrite)
// Scan the cloned EH cleanup blocks for ops that can throw. Given the
// insertion points above, [ehCleanupEntry, continueBlock) spans only those
// blocks. When any are found, a terminate unwind block is built before
// continueBlock for them.
1389 if (ehCleanupEntry) {
1390 llvm::SmallVector<cir::CallOp> ehCleanupThrowingCalls;
1391 llvm::SmallVector<cir::ThrowOp> ehCleanupThrows;
1392 for (mlir::Block *block = ehCleanupEntry; block != continueBlock;
1393 block = block->getNextNode()) {
1394 block->walk([&](mlir::Operation *op) {
1395 if (
auto callOp = mlir::dyn_cast<cir::CallOp>(op)) {
1396 if (!callOp.getNothrow())
1397 ehCleanupThrowingCalls.push_back(callOp);
1398 }
else if (
auto throwOp = mlir::dyn_cast<cir::ThrowOp>(op)) {
1399 ehCleanupThrows.push_back(throwOp);
1403 if (!ehCleanupThrowingCalls.empty() || !ehCleanupThrows.empty()) {
1404 mlir::Block *terminateBlock =
1405 buildTerminateUnwindBlock(loc, continueBlock, rewriter);
1406 for (cir::CallOp callOp : ehCleanupThrowingCalls)
1408 for (cir::ThrowOp throwOp : ehCleanupThrows)
// Chain resumes collected from the body into this scope's EH cleanup,
// forwarding their EH token.
1418 if (ehCleanupEntry) {
1419 for (cir::ResumeOp resumeOp : resumeOpsToChain) {
1420 mlir::Value ehToken = resumeOp.getEhToken();
1421 rewriter.setInsertionPoint(resumeOp);
1422 rewriter.replaceOpWithNewOp<cir::BrOp>(
1423 resumeOp, mlir::ValueRange{ehToken}, ehCleanupEntry);
1428 rewriter.eraseOp(cleanupOp);
1433 return mlir::success();
1437 matchAndRewrite(cir::CleanupScopeOp cleanupOp,
1438 mlir::PatternRewriter &rewriter)
const override {
1439 mlir::OpBuilder::InsertionGuard guard(rewriter);
1453 llvm::SmallVector<cir::CleanupScopeOp> deadNestedOps;
1454 cleanupOp.getBodyRegion().walk([&](cir::CleanupScopeOp nested) {
1455 if (mlir::isOpTriviallyDead(nested))
1456 deadNestedOps.push_back(nested);
1458 for (
auto op : deadNestedOps)
1459 rewriter.eraseOp(op);
1461 if (hasNestedOpsToFlatten(cleanupOp.getBodyRegion()))
1462 return mlir::failure();
1464 cir::CleanupKind cleanupKind = cleanupOp.getCleanupKind();
1467 llvm::SmallVector<CleanupExit> exits;
1469 collectExits(cleanupOp.getBodyRegion(), exits, nextId);
1471 assert(!exits.empty() &&
"cleanup scope body has no exit");
1476 llvm::SmallVector<cir::CallOp> callsToRewrite;
1477 llvm::SmallVector<cir::ThrowOp> throwsToRewrite;
1478 if (cleanupKind != cir::CleanupKind::Normal) {
1479 collectThrowingCalls(cleanupOp.getBodyRegion(), callsToRewrite);
1480 collectThrows(cleanupOp.getBodyRegion(), throwsToRewrite);
1485 llvm::SmallVector<cir::ResumeOp> resumeOpsToChain;
1486 if (cleanupKind != cir::CleanupKind::Normal)
1487 collectResumeOps(cleanupOp.getBodyRegion(), resumeOpsToChain);
1489 return flattenCleanup(cleanupOp, exits, callsToRewrite, throwsToRewrite,
1490 resumeOpsToChain, rewriter);
1497static cir::EhInitiateOp traceToEhInitiate(mlir::Value ehToken) {
1499 if (
auto initiate = ehToken.getDefiningOp<cir::EhInitiateOp>())
1501 auto blockArg = mlir::dyn_cast<mlir::BlockArgument>(ehToken);
1504 mlir::Block *pred = blockArg.getOwner()->getSinglePredecessor();
1507 auto brOp = mlir::dyn_cast<cir::BrOp>(pred->getTerminator());
1510 ehToken = brOp.getDestOperands()[blockArg.getArgNumber()];
1515class CIRTryOpFlattening :
public mlir::OpRewritePattern<cir::TryOp> {
1517 using OpRewritePattern<cir::TryOp>::OpRewritePattern;
1522 mlir::Block *buildCatchDispatchBlock(
1523 cir::TryOp tryOp, mlir::ArrayAttr handlerTypes,
1524 llvm::SmallVectorImpl<mlir::Block *> &catchHandlerBlocks,
1525 mlir::Location loc, mlir::Block *insertBefore,
1526 mlir::PatternRewriter &rewriter)
const {
1527 mlir::Block *dispatchBlock = rewriter.createBlock(insertBefore);
1528 auto ehTokenType = cir::EhTokenType::get(rewriter.getContext());
1529 mlir::Value ehToken = dispatchBlock->addArgument(ehTokenType, loc);
1531 rewriter.setInsertionPointToEnd(dispatchBlock);
1534 llvm::SmallVector<mlir::Attribute> catchTypeAttrs;
1535 llvm::SmallVector<mlir::Block *> catchDests;
1536 mlir::Block *defaultDest =
nullptr;
1537 bool defaultIsCatchAll =
false;
1539 for (
auto [typeAttr, handlerBlock] :
1540 llvm::zip(handlerTypes, catchHandlerBlocks)) {
1541 if (mlir::isa<cir::CatchAllAttr>(typeAttr)) {
1542 assert(!defaultDest &&
"multiple catch_all or unwind handlers");
1543 defaultDest = handlerBlock;
1544 defaultIsCatchAll =
true;
1545 }
else if (mlir::isa<cir::UnwindAttr>(typeAttr)) {
1546 assert(!defaultDest &&
"multiple catch_all or unwind handlers");
1547 defaultDest = handlerBlock;
1548 defaultIsCatchAll =
false;
1551 catchTypeAttrs.push_back(typeAttr);
1552 catchDests.push_back(handlerBlock);
1556 assert(defaultDest &&
"dispatch must have a catch_all or unwind handler");
1558 mlir::ArrayAttr catchTypesArrayAttr;
1559 if (!catchTypeAttrs.empty())
1560 catchTypesArrayAttr = rewriter.getArrayAttr(catchTypeAttrs);
1562 cir::EhDispatchOp::create(rewriter, loc, ehToken, catchTypesArrayAttr,
1563 defaultIsCatchAll, defaultDest, catchDests);
1565 return dispatchBlock;
1582 mlir::Block *flattenCatchHandler(mlir::Region &handlerRegion,
1583 mlir::Block *continueBlock,
1585 mlir::Block *insertBefore,
1586 mlir::PatternRewriter &rewriter)
const {
1588 mlir::Block *handlerEntry = &handlerRegion.front();
1591 rewriter.inlineRegionBefore(handlerRegion, insertBefore);
1594 for (mlir::Block &block : llvm::make_range(handlerEntry->getIterator(),
1596 if (
auto yieldOp = dyn_cast<cir::YieldOp>(block.getTerminator())) {
1609 if (mlir::Operation *prev = yieldOp->getPrevNode())
1610 return isa<cir::EndCatchOp>(prev);
1611 llvm::SmallPtrSet<mlir::Block *, 8> visited;
1612 llvm::SmallVector<mlir::Block *, 4> worklist;
1613 for (mlir::Block *pred : block.getPredecessors())
1614 worklist.push_back(pred);
1615 while (!worklist.empty()) {
1616 mlir::Block *
b = worklist.pop_back_val();
1617 if (!visited.insert(
b).second)
1619 mlir::Operation *term =
b->getTerminator();
1620 if (mlir::Operation *prev = term->getPrevNode()) {
1621 if (isa<cir::EndCatchOp>(prev))
1624 for (mlir::Block *pred :
b->getPredecessors())
1625 worklist.push_back(pred);
1629 "expected end_catch reachable before yield "
1630 "in catch handler");
1631 rewriter.setInsertionPoint(yieldOp);
1632 rewriter.replaceOpWithNewOp<cir::BrOp>(yieldOp, continueBlock);
1636 return handlerEntry;
1645 mlir::Block *flattenUnwindHandler(mlir::Region &unwindRegion,
1647 mlir::Block *insertBefore,
1648 mlir::PatternRewriter &rewriter)
const {
1649 mlir::Block *unwindEntry = &unwindRegion.front();
1650 rewriter.inlineRegionBefore(unwindRegion, insertBefore);
1655 matchAndRewrite(cir::TryOp tryOp,
1656 mlir::PatternRewriter &rewriter)
const override {
1663 for (mlir::Region ®ion : tryOp->getRegions())
1664 if (hasNestedOpsToFlatten(region))
1665 return mlir::failure();
1667 mlir::OpBuilder::InsertionGuard guard(rewriter);
1668 mlir::Location loc = tryOp.getLoc();
1670 mlir::ArrayAttr handlerTypes = tryOp.getHandlerTypesAttr();
1671 mlir::MutableArrayRef<mlir::Region> handlerRegions =
1672 tryOp.getHandlerRegions();
1675 llvm::SmallVector<cir::CallOp> callsToRewrite;
1676 collectThrowingCalls(tryOp.getTryRegion(), callsToRewrite);
1677 llvm::SmallVector<cir::ThrowOp> throwsToRewrite;
1678 collectThrows(tryOp.getTryRegion(), throwsToRewrite);
1681 llvm::SmallVector<cir::ResumeOp> resumeOpsToChain;
1682 collectResumeOps(tryOp.getTryRegion(), resumeOpsToChain);
1685 mlir::Block *currentBlock = rewriter.getInsertionBlock();
1686 mlir::Block *continueBlock =
1687 rewriter.splitBlock(currentBlock, rewriter.getInsertionPoint());
1690 mlir::Block *bodyEntry = &tryOp.getTryRegion().front();
1691 mlir::Block *bodyExit = &tryOp.getTryRegion().back();
1694 rewriter.inlineRegionBefore(tryOp.getTryRegion(), continueBlock);
1697 rewriter.setInsertionPointToEnd(currentBlock);
1698 cir::BrOp::create(rewriter, loc, bodyEntry);
1701 if (
auto bodyYield = dyn_cast<cir::YieldOp>(bodyExit->getTerminator())) {
1702 rewriter.setInsertionPoint(bodyYield);
1703 rewriter.replaceOpWithNewOp<cir::BrOp>(bodyYield, continueBlock);
1707 if (!handlerTypes || handlerTypes.empty()) {
1708 rewriter.eraseOp(tryOp);
1709 return mlir::success();
1717 if (callsToRewrite.empty() && throwsToRewrite.empty() &&
1718 resumeOpsToChain.empty()) {
1719 for (mlir::Region &handlerRegion : handlerRegions)
1720 for (mlir::Block &block : handlerRegion)
1721 block.dropAllDefinedValueUses();
1722 rewriter.eraseOp(tryOp);
1723 return mlir::success();
1729 llvm::SmallVector<mlir::Block *> catchHandlerBlocks;
1731 for (
const auto &[idx, typeAttr] : llvm::enumerate(handlerTypes)) {
1732 mlir::Region &handlerRegion = handlerRegions[idx];
1734 if (mlir::isa<cir::UnwindAttr>(typeAttr)) {
1735 mlir::Block *unwindEntry =
1736 flattenUnwindHandler(handlerRegion, loc, continueBlock, rewriter);
1737 catchHandlerBlocks.push_back(unwindEntry);
1739 mlir::Block *handlerEntry = flattenCatchHandler(
1740 handlerRegion, continueBlock, loc, continueBlock, rewriter);
1741 catchHandlerBlocks.push_back(handlerEntry);
1746 mlir::Block *dispatchBlock =
1747 buildCatchDispatchBlock(tryOp, handlerTypes, catchHandlerBlocks, loc,
1748 catchHandlerBlocks.front(), rewriter);
1759 handlerTypes && llvm::any_of(handlerTypes, [](mlir::Attribute attr) {
1760 return mlir::isa<cir::CatchAllAttr>(attr);
1769 bool isCleanupOnly = tryOp.getCleanup() && !hasCatchAll;
1770 if (!callsToRewrite.empty() || !throwsToRewrite.empty()) {
1772 mlir::Block *unwindBlock = buildUnwindBlock(dispatchBlock, isCleanupOnly,
1773 loc, dispatchBlock, rewriter);
1775 for (cir::CallOp callOp : callsToRewrite)
1777 for (cir::ThrowOp throwOp : throwsToRewrite)
1784 for (cir::ResumeOp resumeOp : resumeOpsToChain) {
1789 if (
auto ehInitiate = traceToEhInitiate(resumeOp.getEhToken())) {
1790 rewriter.modifyOpInPlace(ehInitiate,
1791 [&] { ehInitiate.removeCleanupAttr(); });
1795 mlir::Value ehToken = resumeOp.getEhToken();
1796 rewriter.setInsertionPoint(resumeOp);
1797 rewriter.replaceOpWithNewOp<cir::BrOp>(
1798 resumeOp, mlir::ValueRange{ehToken}, dispatchBlock);
1802 rewriter.eraseOp(tryOp);
1804 return mlir::success();
1808void populateFlattenCFGPatterns(RewritePatternSet &patterns) {
1810 .add<CIRIfFlattening, CIRLoopOpInterfaceFlattening, CIRScopeOpFlattening,
1811 CIRSwitchOpFlattening, CIRTernaryOpFlattening,
1812 CIRCleanupScopeOpFlattening, CIRTryOpFlattening>(
1813 patterns.getContext());
1816void CIRFlattenCFGPass::runOnOperation() {
1817 RewritePatternSet patterns(&getContext());
1818 populateFlattenCFGPatterns(patterns);
1821 llvm::SmallVector<Operation *, 16> ops;
1822 getOperation()->walk<mlir::WalkOrder::PostOrder>([&](Operation *op) {
1823 if (isa<IfOp, ScopeOp, SwitchOp, LoopOpInterface, TernaryOp, CleanupScopeOp,
1829 if (applyOpPatternsGreedily(ops, std::move(patterns)).failed())
1830 signalPassFailure();
1838 return std::make_unique<CIRFlattenCFGPass>();
mlir::Block * replaceThrowWithTryThrow(cir::ThrowOp throwOp, mlir::Block *unwindDest, mlir::Location loc, mlir::RewriterBase &rewriter)
Replace a cir::ThrowOp with a cir::TryThrowOp whose unwind destination is unwindDest.
mlir::Block * replaceCallWithTryCall(cir::CallOp callOp, mlir::Block *unwindDest, mlir::Location loc, mlir::RewriterBase &rewriter)
Replace a cir::CallOp with a cir::TryCallOp whose unwind destination is unwindDest.
std::unique_ptr< Pass > createCIRFlattenCFGPass()
int const char * function
float __ovld __cnfn step(float, float)
Returns 0.0 if x < edge, otherwise it returns 1.0.
static bool stackSaveOp()