16#include "mlir/Dialect/Func/IR/FuncOps.h"
17#include "mlir/IR/Block.h"
18#include "mlir/IR/Builders.h"
19#include "mlir/IR/PatternMatch.h"
20#include "mlir/Interfaces/SideEffectInterfaces.h"
21#include "mlir/Support/LogicalResult.h"
22#include "mlir/Transforms/DialectConversion.h"
23#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
28#include "llvm/ADT/TypeSwitch.h"
34#define GEN_PASS_DEF_CIRFLATTENCFG
35#include "clang/CIR/Dialect/Passes.h.inc"
41void lowerTerminator(mlir::Operation *op, mlir::Block *dest,
42 mlir::PatternRewriter &rewriter) {
43 assert(op->hasTrait<mlir::OpTrait::IsTerminator>() &&
"not a terminator");
44 mlir::OpBuilder::InsertionGuard guard(rewriter);
45 rewriter.setInsertionPoint(op);
46 rewriter.replaceOpWithNewOp<cir::BrOp>(op, dest);
51template <
typename... Ops>
52void walkRegionSkipping(
54 mlir::function_ref<mlir::WalkResult(mlir::Operation *)> callback) {
55 region.walk<mlir::WalkOrder::PreOrder>([&](mlir::Operation *op) {
57 return mlir::WalkResult::skip();
71static bool hasNestedOpsToFlatten(mlir::Region ®ion) {
73 .walk([](mlir::Operation *op) {
74 if (op->getNumRegions() > 0 && !isa<cir::CaseOp>(op))
75 return mlir::WalkResult::interrupt();
76 return mlir::WalkResult::advance();
87static bool isNonReturningTerminator(mlir::Operation *op) {
88 return mlir::isa_and_nonnull<cir::UnreachableOp, cir::TrapOp>(op);
106static mlir::LogicalResult
107rewriteRegionExitToContinue(mlir::PatternRewriter &rewriter,
108 mlir::Region ®ion, mlir::Block *continueBlock,
109 llvm::StringRef regionDescription) {
110 mlir::Operation *terminator = region.back().getTerminator();
111 rewriter.setInsertionPointToEnd(®ion.back());
112 if (
auto yieldOp = mlir::dyn_cast<cir::YieldOp>(terminator)) {
113 rewriter.replaceOpWithNewOp<cir::BrOp>(yieldOp, yieldOp.getArgs(),
115 return mlir::success();
117 if (isNonReturningTerminator(terminator))
118 return mlir::success();
119 terminator->emitError(
"unexpected terminator in ")
121 <<
" region, expected yield, unreachable, or trap, got: "
122 << terminator->getName();
123 return mlir::failure();
126struct CIRFlattenCFGPass :
public impl::CIRFlattenCFGBase<CIRFlattenCFGPass> {
128 CIRFlattenCFGPass() =
default;
129 void runOnOperation()
override;
// Flattens cir.if into explicit blocks: split the current block, inline the
// then/else regions, and connect them with cir.brcond / cir.br.
// NOTE(review): this extraction is missing several interior lines (control
// markers such as the `else` before llvm_unreachable and the branch operands);
// comments below describe only what is visible.
132struct CIRIfFlattening :
public mlir::OpRewritePattern<cir::IfOp> {
133 using OpRewritePattern<IfOp>::OpRewritePattern;
136 matchAndRewrite(cir::IfOp ifOp,
137 mlir::PatternRewriter &rewriter)
const override {
138 mlir::OpBuilder::InsertionGuard guard(rewriter);
139 mlir::Location loc = ifOp.getLoc();
140 bool emptyElse = ifOp.getElseRegion().empty();
// Split off the ops following the cir.if into their own block; results-free
// ifs continue directly there.
141 mlir::Block *currentBlock = rewriter.getInsertionBlock();
142 mlir::Block *remainingOpsBlock =
143 rewriter.splitBlock(currentBlock, rewriter.getInsertionPoint());
144 mlir::Block *continueBlock;
145 if (ifOp->getResults().empty())
146 continueBlock = remainingOpsBlock;
// cir.if with results is not implemented yet.
148 llvm_unreachable(
"NYI");
// Inline the then region before the continuation and redirect its yield.
151 mlir::Block *thenBeforeBody = &ifOp.getThenRegion().front();
152 mlir::Block *thenAfterBody = &ifOp.getThenRegion().back();
153 rewriter.inlineRegionBefore(ifOp.getThenRegion(), continueBlock);
155 rewriter.setInsertionPointToEnd(thenAfterBody);
156 if (
auto thenYieldOp =
157 dyn_cast<cir::YieldOp>(thenAfterBody->getTerminator())) {
158 rewriter.replaceOpWithNewOp<cir::BrOp>(thenYieldOp, thenYieldOp.getArgs(),
162 rewriter.setInsertionPointToEnd(continueBlock);
// Same for the else region; when it is empty, the "else" target is simply
// the continuation block (see `emptyElse` above — TODO confirm, the guard
// line is elided in this extraction).
165 mlir::Block *elseBeforeBody =
nullptr;
166 mlir::Block *elseAfterBody =
nullptr;
168 elseBeforeBody = &ifOp.getElseRegion().front();
169 elseAfterBody = &ifOp.getElseRegion().back();
170 rewriter.inlineRegionBefore(ifOp.getElseRegion(), continueBlock);
172 elseBeforeBody = elseAfterBody = continueBlock;
// Replace the structured op with a conditional branch from the original
// block into the inlined then/else bodies.
175 rewriter.setInsertionPointToEnd(currentBlock);
176 cir::BrCondOp::create(rewriter, loc, ifOp.getCondition(), thenBeforeBody,
180 rewriter.setInsertionPointToEnd(elseAfterBody);
181 if (
auto elseYieldOP =
182 dyn_cast<cir::YieldOp>(elseAfterBody->getTerminator())) {
183 rewriter.replaceOpWithNewOp<cir::BrOp>(
184 elseYieldOP, elseYieldOP.getArgs(), continueBlock);
188 rewriter.replaceOp(ifOp, continueBlock->getArguments());
189 return mlir::success();
// Flattens cir.scope: empty scopes are erased; otherwise the scope region is
// inlined between the current block and a continuation block, with the
// scope's yield rewritten into a branch carrying the results.
// NOTE(review): some interior lines are elided in this extraction.
193class CIRScopeOpFlattening :
public mlir::OpRewritePattern<cir::ScopeOp> {
195 using OpRewritePattern<cir::ScopeOp>::OpRewritePattern;
198 matchAndRewrite(cir::ScopeOp scopeOp,
199 mlir::PatternRewriter &rewriter)
const override {
200 mlir::OpBuilder::InsertionGuard guard(rewriter);
201 mlir::Location loc = scopeOp.getLoc();
// An empty scope has no observable behavior — just drop it.
209 if (scopeOp.isEmpty()) {
210 rewriter.eraseOp(scopeOp);
211 return mlir::success();
// Split the current block; the continuation receives the scope results as
// block arguments.
216 mlir::Block *currentBlock = rewriter.getInsertionBlock();
217 mlir::Block *continueBlock =
218 rewriter.splitBlock(currentBlock, rewriter.getInsertionPoint());
219 if (scopeOp.getNumResults() > 0)
220 continueBlock->addArguments(scopeOp.getResultTypes(), loc);
// Inline the scope body and branch into it from the original block.
223 mlir::Block *beforeBody = &scopeOp.getScopeRegion().front();
224 mlir::Block *afterBody = &scopeOp.getScopeRegion().back();
225 rewriter.inlineRegionBefore(scopeOp.getScopeRegion(), continueBlock);
228 rewriter.setInsertionPointToEnd(currentBlock);
230 cir::BrOp::create(rewriter, loc, mlir::ValueRange(), beforeBody);
// The scope's trailing yield becomes a branch into the continuation.
234 rewriter.setInsertionPointToEnd(afterBody);
235 if (
auto yieldOp = dyn_cast<cir::YieldOp>(afterBody->getTerminator())) {
236 rewriter.replaceOpWithNewOp<cir::BrOp>(yieldOp, yieldOp.getArgs(),
241 rewriter.replaceOp(scopeOp, continueBlock->getArguments());
243 return mlir::success();
// Flattens cir.switch into cir.switch.flat: case regions are inlined as
// blocks, break/yield terminators are lowered to branches, and case ranges
// are either expanded into individual case values (small ranges) or lowered
// to a compare-and-branch chain in front of the default destination.
// NOTE(review): many interior lines of this class are elided in this
// extraction; comments describe only the visible logic.
247class CIRSwitchOpFlattening :
public mlir::OpRewritePattern<cir::SwitchOp> {
249 using OpRewritePattern<cir::SwitchOp>::OpRewritePattern;
// Replace a cir.yield with a branch (carrying its operands) to `destination`.
251 inline void rewriteYieldOp(mlir::PatternRewriter &rewriter,
252 cir::YieldOp yieldOp,
253 mlir::Block *destination)
const {
254 rewriter.setInsertionPoint(yieldOp);
255 rewriter.replaceOpWithNewOp<cir::BrOp>(yieldOp, yieldOp.getOperands(),
// Lower a case range [lowerBound, upperBound] to: (cond - lower) as unsigned
// <= (upper - lower) as unsigned, branching to the range destination or on
// to the default destination. Returns the new comparison block.
260 Block *condBrToRangeDestination(cir::SwitchOp op,
261 mlir::PatternRewriter &rewriter,
262 mlir::Block *rangeDestination,
263 mlir::Block *defaultDestination,
264 const APInt &lowerBound,
265 const APInt &upperBound)
const {
266 assert(lowerBound.sle(upperBound) &&
"Invalid range");
267 mlir::Block *resBlock = rewriter.createBlock(defaultDestination);
268 cir::IntType sIntType = cir::IntType::get(op.getContext(), 32,
true);
269 cir::IntType uIntType = cir::IntType::get(op.getContext(), 32,
false);
271 cir::ConstantOp rangeLength = cir::ConstantOp::create(
272 rewriter, op.getLoc(),
273 cir::IntAttr::get(sIntType, upperBound - lowerBound));
275 cir::ConstantOp lowerBoundValue = cir::ConstantOp::create(
276 rewriter, op.getLoc(), cir::IntAttr::get(sIntType, lowerBound));
277 mlir::Value diffValue = cir::SubOp::create(
278 rewriter, op.getLoc(), op.getCondition(), lowerBoundValue);
// The unsigned casts turn the signed range test into a single compare.
281 cir::CastOp uDiffValue = cir::CastOp::create(
282 rewriter, op.getLoc(), uIntType, CastKind::integral, diffValue);
283 cir::CastOp uRangeLength = cir::CastOp::create(
284 rewriter, op.getLoc(), uIntType, CastKind::integral, rangeLength);
286 cir::CmpOp cmpResult = cir::CmpOp::create(
287 rewriter, op.getLoc(), cir::CmpOpKind::le, uDiffValue, uRangeLength);
288 cir::BrCondOp::create(rewriter, op.getLoc(), cmpResult, rangeDestination,
294 matchAndRewrite(cir::SwitchOp op,
295 mlir::PatternRewriter &rewriter)
const override {
// Require inner structured control flow to be flattened first.
300 for (mlir::Region &region : op->getRegions())
301 if (hasNestedOpsToFlatten(region))
302 return mlir::failure();
304 llvm::SmallVector<CaseOp> cases;
305 op.collectCases(cases);
// (Elided guard) a switch with no cases is simply erased.
309 rewriter.eraseOp(op);
310 return mlir::success();
// Everything after the switch becomes the exit block.
314 mlir::Block *exitBlock = rewriter.splitBlock(
315 rewriter.getBlock(), op->getNextNode()->getIterator());
// Remember a yield inside the switch body (if any) to retarget at the exit.
328 cir::YieldOp switchYield =
nullptr;
330 for (mlir::Block &block :
331 llvm::make_early_inc_range(op.getBody().getBlocks()))
332 if (
auto yieldOp = dyn_cast<cir::YieldOp>(block.getTerminator()))
333 switchYield = yieldOp;
// Inline the switch body between the original block and the exit block.
335 assert(!op.getBody().empty());
336 mlir::Block *originalBlock = op->getBlock();
337 mlir::Block *swopBlock =
338 rewriter.splitBlock(originalBlock, op->getIterator());
339 rewriter.inlineRegionBefore(op.getBody(), exitBlock);
342 rewriteYieldOp(rewriter, switchYield, exitBlock);
344 rewriter.setInsertionPointToEnd(originalBlock);
345 cir::BrOp::create(rewriter, op.getLoc(), swopBlock);
// Accumulators for the flat switch: plain cases and range cases.
350 llvm::SmallVector<mlir::APInt, 8> caseValues;
351 llvm::SmallVector<mlir::Block *, 8> caseDestinations;
352 llvm::SmallVector<mlir::ValueRange, 8> caseOperands;
354 llvm::SmallVector<std::pair<APInt, APInt>> rangeValues;
355 llvm::SmallVector<mlir::Block *> rangeDestinations;
356 llvm::SmallVector<mlir::ValueRange> rangeOperands;
// Without an explicit default case, fall through to the exit block.
359 mlir::Block *defaultDestination = exitBlock;
360 mlir::ValueRange defaultOperands = exitBlock->getArguments();
363 for (cir::CaseOp caseOp : cases) {
364 mlir::Region &region = caseOp.getCaseRegion();
// Record each case's destination by kind.
367 switch (caseOp.getKind()) {
368 case cir::CaseOpKind::Default:
369 defaultDestination = &region.front();
370 defaultOperands = defaultDestination->getArguments();
372 case cir::CaseOpKind::Range:
373 assert(caseOp.getValue().size() == 2 &&
374 "Case range should have 2 case value");
375 rangeValues.push_back(
376 {cast<cir::IntAttr>(caseOp.getValue()[0]).getValue(),
377 cast<cir::IntAttr>(caseOp.getValue()[1]).getValue()});
378 rangeDestinations.push_back(&region.front());
379 rangeOperands.push_back(rangeDestinations.back()->getArguments());
381 case cir::CaseOpKind::Anyof:
382 case cir::CaseOpKind::Equal:
// Anyof/Equal may carry several values mapping to the same region.
384 for (
const mlir::Attribute &value : caseOp.getValue()) {
385 caseValues.push_back(cast<cir::IntAttr>(value).getValue());
386 caseDestinations.push_back(&region.front());
387 caseOperands.push_back(caseDestinations.back()->getArguments());
// Lower break ops inside the case (but not inside nested loops/switches)
// into branches to the exit block.
393 walkRegionSkipping<cir::LoopOpInterface, cir::SwitchOp>(
394 region, [&](mlir::Operation *op) {
395 if (!isa<cir::BreakOp>(op))
396 return mlir::WalkResult::advance();
398 lowerTerminator(op, exitBlock, rewriter);
399 return mlir::WalkResult::skip();
// A case that ends in a yield falls through to the next case: split after
// this caseOp and branch the yield there.
403 for (mlir::Block &blk : region.getBlocks()) {
404 if (blk.getNumSuccessors())
407 if (
auto yieldOp = dyn_cast<cir::YieldOp>(blk.getTerminator())) {
408 mlir::Operation *nextOp = caseOp->getNextNode();
409 assert(nextOp &&
"caseOp is not expected to be the last op");
410 mlir::Block *oldBlock = nextOp->getBlock();
411 mlir::Block *newBlock =
412 rewriter.splitBlock(oldBlock, nextOp->getIterator());
413 rewriter.setInsertionPointToEnd(oldBlock);
414 cir::BrOp::create(rewriter, nextOp->getLoc(), mlir::ValueRange(),
416 rewriteYieldOp(rewriter, yieldOp, newBlock);
// Inline this case's region in place of the caseOp and branch into it.
420 mlir::Block *oldBlock = caseOp->getBlock();
421 mlir::Block *newBlock =
422 rewriter.splitBlock(oldBlock, caseOp->getIterator());
424 mlir::Block &entryBlock = caseOp.getCaseRegion().front();
425 rewriter.inlineRegionBefore(caseOp.getCaseRegion(), newBlock);
428 rewriter.setInsertionPointToEnd(oldBlock);
429 cir::BrOp::create(rewriter, caseOp.getLoc(), &entryBlock);
// Erase the now-empty case ops (and their unreachable holder blocks).
433 for (cir::CaseOp caseOp : cases) {
434 mlir::Block *caseBlock = caseOp->getBlock();
437 if (caseBlock->hasNoPredecessors())
438 rewriter.eraseBlock(caseBlock);
440 rewriter.eraseOp(caseOp);
// Lower case ranges: small ranges are expanded to individual values, larger
// ones become a compare chain in front of the default destination.
443 for (
auto [rangeVal, operand, destination] :
444 llvm::zip(rangeValues, rangeOperands, rangeDestinations)) {
445 APInt lowerBound = rangeVal.first;
446 APInt upperBound = rangeVal.second;
// Empty (inverted) ranges can never match.
449 if (lowerBound.sgt(upperBound))
454 constexpr int kSmallRangeThreshold = 64;
455 if ((upperBound - lowerBound)
456 .ult(llvm::APInt(32, kSmallRangeThreshold))) {
457 for (APInt iValue = lowerBound; iValue.sle(upperBound); ++iValue) {
458 caseValues.push_back(iValue);
459 caseOperands.push_back(operand);
460 caseDestinations.push_back(destination);
466 condBrToRangeDestination(op, rewriter, destination,
467 defaultDestination, lowerBound, upperBound);
468 defaultOperands = operand;
// Finally replace cir.switch with the flat form.
472 rewriter.setInsertionPoint(op);
473 rewriter.replaceOpWithNewOp<cir::SwitchFlatOp>(
474 op, op.getCondition(), defaultDestination, defaultOperands, caseValues,
475 caseDestinations, caseOperands);
477 return mlir::success();
// Flattens any op implementing cir::LoopOpInterface (for/while/do) into
// explicit blocks: entry -> cond -> body -> (step ->) cond, with break
// lowered to the exit block and continue lowered to the step (or cond).
// NOTE(review): several interior lines (the exit block computation, the
// conditionOp declaration, step handling guards) are elided here.
481class CIRLoopOpInterfaceFlattening
482 :
public mlir::OpInterfaceRewritePattern<cir::LoopOpInterface> {
484 using mlir::OpInterfaceRewritePattern<
485 cir::LoopOpInterface>::OpInterfaceRewritePattern;
// Replace a cir.condition with a conditional branch into the loop body or
// the loop exit.
487 inline void lowerConditionOp(cir::ConditionOp op, mlir::Block *body,
489 mlir::PatternRewriter &rewriter)
const {
490 mlir::OpBuilder::InsertionGuard guard(rewriter);
491 rewriter.setInsertionPoint(op);
492 rewriter.replaceOpWithNewOp<cir::BrCondOp>(op, op.getCondition(), body,
497 matchAndRewrite(cir::LoopOpInterface op,
498 mlir::PatternRewriter &rewriter)
const final {
// Inner structured control flow must be flattened first.
503 for (mlir::Region &region : op->getRegions())
504 if (hasNestedOpsToFlatten(region))
505 return mlir::failure();
508 mlir::Block *entry = rewriter.getInsertionBlock();
510 rewriter.splitBlock(entry, rewriter.getInsertionPoint());
511 mlir::Block *cond = &op.getCond().front();
512 mlir::Block *body = &op.getBody().front();
// `step` is null for loops without a step region (e.g. while loops).
514 (op.maybeGetStep() ? &op.maybeGetStep()->front() :
nullptr);
517 rewriter.setInsertionPointToEnd(entry);
518 cir::BrOp::create(rewriter, op.getLoc(), &op.getEntry().front());
525 cast<cir::ConditionOp>(op.getCond().back().getTerminator());
526 lowerConditionOp(conditionOp, body, exit, rewriter);
// `continue` jumps to the step block when present, else back to cond.
533 mlir::Block *dest = (
step ?
step : cond);
534 op.walkBodySkippingNestedLoops([&](mlir::Operation *op) {
535 if (!isa<cir::ContinueOp>(op))
536 return mlir::WalkResult::advance();
538 lowerTerminator(op, dest, rewriter);
539 return mlir::WalkResult::skip();
// `break` jumps to the exit block; skip nested loops and switches, which
// own their break semantics.
543 walkRegionSkipping<cir::LoopOpInterface, cir::SwitchOp>(
544 op.getBody(), [&](mlir::Operation *op) {
545 if (!isa<cir::BreakOp>(op))
546 return mlir::WalkResult::advance();
548 lowerTerminator(op, exit, rewriter);
549 return mlir::WalkResult::skip();
// Body yields loop back around via the step (or cond) block.
553 for (mlir::Block &blk : op.getBody().getBlocks()) {
554 auto bodyYield = dyn_cast<cir::YieldOp>(blk.getTerminator());
556 lowerTerminator(bodyYield, (
step ?
step : cond), rewriter);
// The step region's yield branches back to the condition check.
564 cast<cir::YieldOp>(op.maybeGetStep()->back().getTerminator()), cond,
// Inline all loop regions before the exit and drop the structured op.
568 rewriter.inlineRegionBefore(op.getCond(), exit);
569 rewriter.inlineRegionBefore(op.getBody(), exit);
571 rewriter.inlineRegionBefore(*op.maybeGetStep(), exit);
573 rewriter.eraseOp(op);
574 return mlir::success();
// Flattens cir.ternary: the true/false regions are inlined as blocks that
// both branch to a continuation block carrying the result as a block
// argument; the original block ends in a cir.brcond on the condition.
// NOTE(review): some interior lines (locs population, region-description
// string arguments) are elided in this extraction.
578class CIRTernaryOpFlattening :
public mlir::OpRewritePattern<cir::TernaryOp> {
580 using OpRewritePattern<cir::TernaryOp>::OpRewritePattern;
583 matchAndRewrite(cir::TernaryOp op,
584 mlir::PatternRewriter &rewriter)
const override {
585 Location loc = op->getLoc();
586 Block *condBlock = rewriter.getInsertionBlock();
587 Block::iterator opPosition = rewriter.getInsertionPoint();
588 Block *remainingOpsBlock = rewriter.splitBlock(condBlock, opPosition);
// One location per ternary result, used for the continuation's block args.
589 llvm::SmallVector<mlir::Location, 2> locs;
592 if (op->getResultTypes().size())
594 Block *continueBlock =
595 rewriter.createBlock(remainingOpsBlock, op->getResultTypes(), locs);
596 cir::BrOp::create(rewriter, loc, remainingOpsBlock);
// Rewrite each region's exit to branch into the continuation, then inline.
// A failed rewrite still returns success: the error was already emitted by
// rewriteRegionExitToContinue — TODO confirm intended contract.
598 Region &trueRegion = op.getTrueRegion();
599 Block *trueBlock = &trueRegion.front();
604 if (failed(rewriteRegionExitToContinue(rewriter, trueRegion, continueBlock,
606 return mlir::success();
607 rewriter.inlineRegionBefore(trueRegion, continueBlock);
609 Block *falseBlock = continueBlock;
610 Region &falseRegion = op.getFalseRegion();
612 falseBlock = &falseRegion.front();
613 if (failed(rewriteRegionExitToContinue(rewriter, falseRegion, continueBlock,
615 return mlir::success();
616 rewriter.inlineRegionBefore(falseRegion, continueBlock);
// Dispatch on the condition and replace results with the continuation args.
618 rewriter.setInsertionPointToEnd(condBlock);
619 cir::BrCondOp::create(rewriter, loc, op.getCond(), trueBlock, falseBlock);
621 rewriter.replaceOp(op, continueBlock->getArguments());
624 return mlir::success();
631static cir::AllocaOp getOrCreateCleanupDestSlot(cir::FuncOp funcOp,
632 mlir::PatternRewriter &rewriter,
633 mlir::Location loc) {
634 mlir::Block &entryBlock = funcOp.getBody().front();
637 auto it = llvm::find_if(entryBlock, [](
auto &op) {
638 return mlir::isa<AllocaOp>(&op) &&
639 mlir::cast<AllocaOp>(&op).getCleanupDestSlot();
641 if (it != entryBlock.end())
642 return mlir::cast<cir::AllocaOp>(*it);
645 mlir::OpBuilder::InsertionGuard guard(rewriter);
646 rewriter.setInsertionPointToStart(&entryBlock);
647 cir::IntType s32Type =
648 cir::IntType::get(rewriter.getContext(), 32,
true);
649 cir::PointerType ptrToS32Type = cir::PointerType::get(s32Type);
651 uint64_t alignment = dataLayout.getAlignment(s32Type,
true).value();
652 auto allocaOp = cir::AllocaOp::create(
653 rewriter, loc, ptrToS32Type, s32Type,
"__cleanup_dest_slot",
654 rewriter.getI64IntegerAttr(alignment));
655 allocaOp.setCleanupDestSlot(
true);
667collectThrowingCalls(mlir::Region ®ion,
669 region.walk([&](cir::CallOp callOp) {
670 if (!callOp.getNothrow())
671 callsToRewrite.push_back(callOp);
681static void collectResumeOps(mlir::Region ®ion,
683 region.walk([&](cir::ResumeOp resumeOp) { resumeOps.push_back(resumeOp); });
689static mlir::Block *buildUnwindBlock(mlir::Block *dest,
bool isCleanupOnly,
691 mlir::Block *insertBefore,
692 mlir::PatternRewriter &rewriter) {
693 mlir::Block *unwindBlock = rewriter.createBlock(insertBefore);
694 rewriter.setInsertionPointToEnd(unwindBlock);
696 cir::EhInitiateOp::create(rewriter, loc, isCleanupOnly);
697 cir::BrOp::create(rewriter, loc, mlir::ValueRange{ehInitiate.getEhToken()},
705static mlir::Block *buildTerminateUnwindBlock(mlir::Location loc,
706 mlir::Block *insertBefore,
707 mlir::PatternRewriter &rewriter) {
708 mlir::Block *terminateBlock = rewriter.createBlock(insertBefore);
709 rewriter.setInsertionPointToEnd(terminateBlock);
710 auto ehInitiate = cir::EhInitiateOp::create(rewriter, loc,
false);
711 cir::EhTerminateOp::create(rewriter, loc, ehInitiate.getEhToken());
712 return terminateBlock;
// Flattens cir.cleanup_scope. The body region is inlined; every way control
// can leave the body (yield, break, continue, return, goto) is recorded as a
// CleanupExit, routed through the cleanup region, and then dispatched to its
// real destination — via a per-function destination slot + switch when there
// are multiple exits. EH cleanups get a cloned cleanup region reachable from
// throwing calls and chained resumes.
// NOTE(review): this extraction elides many interior lines; comments below
// describe only the visible logic.
715class CIRCleanupScopeOpFlattening
716 :
public mlir::OpRewritePattern<cir::CleanupScopeOp> {
718 using OpRewritePattern<cir::CleanupScopeOp>::OpRewritePattern;
// One recorded exit from the cleanup body: the exiting terminator and the
// integer id stored into the destination slot to select it after cleanup.
723 mlir::Operation *exitOp;
729 CleanupExit(mlir::Operation *op,
int id) : exitOp(op), destinationId(id) {}
// Collect all operations that exit `cleanupBodyRegion`, assigning each a
// fresh destination id. Nested loops/switches/cleanup scopes are handled by
// dedicated lambdas so only exits that actually leave THIS scope count.
748 void collectExits(mlir::Region &cleanupBodyRegion,
749 llvm::SmallVectorImpl<CleanupExit> &exits,
// Direct block terminators: a yield leaves the scope normally.
754 for (mlir::Block &block : cleanupBodyRegion) {
755 auto *terminator = block.getTerminator();
756 if (isa<cir::YieldOp>(terminator))
757 exits.emplace_back(terminator, nextId++);
// Inside a nested loop, break/continue belong to the loop; only
// return/goto escape the cleanup scope.
764 auto collectExitsInLoop = [&](mlir::Operation *loopOp) {
765 loopOp->walk<mlir::WalkOrder::PreOrder>([&](mlir::Operation *nestedOp) {
766 if (isa<cir::ReturnOp, cir::GotoOp>(nestedOp))
767 exits.emplace_back(nestedOp, nextId++);
768 return mlir::WalkResult::advance();
// Declared up front so collectExitsInSwitch and collectExitsInCleanup can
// recurse into each other.
773 std::function<void(mlir::Region &,
bool)> collectExitsInCleanup;
// Inside a nested switch, break belongs to the switch; return/goto/continue
// escape. Nested cleanup scopes and loops are delegated.
778 collectExitsInSwitch = [&](mlir::Operation *switchOp) {
779 switchOp->walk<mlir::WalkOrder::PreOrder>([&](mlir::Operation *nestedOp) {
780 if (isa<cir::CleanupScopeOp>(nestedOp)) {
783 collectExitsInCleanup(
784 cast<cir::CleanupScopeOp>(nestedOp).getBodyRegion(),
786 return mlir::WalkResult::skip();
787 }
else if (isa<cir::LoopOpInterface>(nestedOp)) {
788 collectExitsInLoop(nestedOp);
789 return mlir::WalkResult::skip();
790 }
else if (isa<cir::ReturnOp, cir::GotoOp, cir::ContinueOp>(nestedOp)) {
791 exits.emplace_back(nestedOp, nextId++);
793 return mlir::WalkResult::advance();
// Main walker over the (possibly nested) cleanup body. `ignoreBreak`
// suppresses break collection when the break targets an inner construct.
800 collectExitsInCleanup = [&](mlir::Region &region,
bool ignoreBreak) {
801 region.walk<mlir::WalkOrder::PreOrder>([&](mlir::Operation *op) {
808 if (!ignoreBreak && isa<cir::BreakOp>(op)) {
809 exits.emplace_back(op, nextId++);
810 }
else if (isa<cir::ContinueOp, cir::ReturnOp, cir::GotoOp>(op)) {
811 exits.emplace_back(op, nextId++);
812 }
else if (isa<cir::CleanupScopeOp>(op)) {
814 collectExitsInCleanup(cast<cir::CleanupScopeOp>(op).getBodyRegion(),
816 return mlir::WalkResult::skip();
817 }
else if (isa<cir::LoopOpInterface>(op)) {
821 collectExitsInLoop(op);
822 return mlir::WalkResult::skip();
823 }
else if (isa<cir::SwitchOp>(op)) {
827 collectExitsInSwitch(op);
828 return mlir::WalkResult::skip();
830 return mlir::WalkResult::advance();
835 collectExitsInCleanup(cleanupBodyRegion,
false);
// Decide whether a return operand can simply be moved ("sunk") past the
// cleanup code instead of being spilled to a temporary alloca. Only
// single-use constants, or loads from entry-block allocas, in the same
// block as the return qualify.
841 static bool shouldSinkReturnOperand(mlir::Value operand,
842 cir::ReturnOp returnOp) {
844 mlir::Operation *defOp = operand.getDefiningOp();
850 if (!mlir::isa<cir::ConstantOp, cir::LoadOp>(defOp))
854 if (!operand.hasOneUse())
858 if (defOp->getBlock() != returnOp->getBlock())
861 if (
auto loadOp = mlir::dyn_cast<cir::LoadOp>(defOp)) {
863 mlir::Value ptr = loadOp.getAddr();
864 auto funcOp = returnOp->getParentOfType<cir::FuncOp>();
865 assert(funcOp &&
"Return op has no function parent?");
866 mlir::Block &funcEntryBlock = funcOp.getBody().front();
// Loads are safe to sink only from stable (entry-block) allocas.
870 mlir::dyn_cast_if_present<cir::AllocaOp>(ptr.getDefiningOp()))
871 return allocaOp->getBlock() == &funcEntryBlock;
877 assert(mlir::isa<cir::ConstantOp>(defOp) &&
"Expected constant op");
// Materialize the values a rebuilt return needs at the dispatch block:
// sinkable operands are moved; others are spilled to an entry-block alloca
// at the original exit and reloaded at the dispatch block.
886 getReturnOpOperands(cir::ReturnOp returnOp, mlir::Operation *exitOp,
887 mlir::Location loc, mlir::PatternRewriter &rewriter,
888 llvm::SmallVectorImpl<mlir::Value> &returnValues)
const {
889 mlir::Block *destBlock = rewriter.getInsertionBlock();
890 auto funcOp = exitOp->getParentOfType<cir::FuncOp>();
891 assert(funcOp &&
"Return op has no function parent?");
892 mlir::Block &funcEntryBlock = funcOp.getBody().front();
894 for (mlir::Value operand : returnOp.getOperands()) {
895 if (shouldSinkReturnOperand(operand, returnOp)) {
897 mlir::Operation *defOp = operand.getDefiningOp();
898 rewriter.moveOpBefore(defOp, destBlock, destBlock->end());
899 returnValues.push_back(operand);
// Spill path: alloca in the entry block, store at the exit, reload at
// the dispatch block.
902 cir::AllocaOp alloca;
904 mlir::OpBuilder::InsertionGuard guard(rewriter);
905 rewriter.setInsertionPointToStart(&funcEntryBlock);
906 cir::CIRDataLayout dataLayout(
907 funcOp->getParentOfType<mlir::ModuleOp>());
909 dataLayout.getAlignment(operand.getType(),
true).value();
910 cir::PointerType ptrType = cir::PointerType::get(operand.getType());
911 alloca = cir::AllocaOp::create(rewriter, loc, ptrType,
912 operand.getType(),
"__ret_operand_tmp",
913 rewriter.getI64IntegerAttr(alignment));
918 mlir::OpBuilder::InsertionGuard guard(rewriter);
919 rewriter.setInsertionPoint(exitOp);
920 cir::StoreOp::create(rewriter, loc, operand, alloca,
923 cir::SyncScopeKindAttr(), cir::MemOrderAttr());
927 rewriter.setInsertionPointToEnd(destBlock);
928 auto loaded = cir::LoadOp::create(
929 rewriter, loc, alloca,
false,
930 false, mlir::IntegerAttr(),
931 cir::SyncScopeKindAttr(), cir::MemOrderAttr());
932 returnValues.push_back(loaded);
// Recreate the semantics of an original exit op at the current insertion
// point (the per-exit dispatch block after cleanup has run).
942 createExitTerminator(mlir::Operation *exitOp, mlir::Location loc,
943 mlir::Block *continueBlock,
944 mlir::PatternRewriter &rewriter)
const {
945 return llvm::TypeSwitch<mlir::Operation *, mlir::LogicalResult>(exitOp)
946 .Case<cir::YieldOp>([&](
auto) {
948 cir::BrOp::create(rewriter, loc, continueBlock);
949 return mlir::success();
951 .Case<cir::BreakOp>([&](
auto) {
953 cir::BreakOp::create(rewriter, loc);
954 return mlir::success();
956 .Case<cir::ContinueOp>([&](
auto) {
958 cir::ContinueOp::create(rewriter, loc);
959 return mlir::success();
961 .Case<cir::ReturnOp>([&](
auto returnOp) {
965 if (returnOp.hasOperand()) {
966 llvm::SmallVector<mlir::Value, 2> returnValues;
967 getReturnOpOperands(returnOp, exitOp, loc, rewriter, returnValues);
968 cir::ReturnOp::create(rewriter, loc, returnValues);
970 cir::ReturnOp::create(rewriter, loc);
972 return mlir::success();
974 .Case<cir::GotoOp>([&](
auto gotoOp) {
// Placeholder terminator keeps the block well-formed while erroring out.
981 cir::UnreachableOp::create(rewriter, loc);
982 return gotoOp.emitError(
983 "goto in cleanup scope is not yet implemented");
985 .
Default([&](mlir::Operation *op) {
986 cir::UnreachableOp::create(rewriter, loc);
987 return op->emitError(
988 "unexpected exit operation in cleanup scope body");
// Verify the structural invariant that only the last block of a cleanup
// region exits it; all other blocks must branch within the region.
994 static bool regionExitsOnlyFromLastBlock(mlir::Region &region) {
995 for (mlir::Block &block : region) {
996 if (&block == &region.back())
998 bool expectedTerminator =
999 llvm::TypeSwitch<mlir::Operation *, bool>(block.getTerminator())
1006 .Case<cir::YieldOp, cir::ReturnOp, cir::ResumeFlatOp,
1007 cir::ContinueOp, cir::BreakOp, cir::GotoOp>(
1008 [](
auto) {
return false; })
1017 .Case<cir::TryCallOp>([](
auto) {
return false; })
1021 .Case<cir::EhDispatchOp>([](
auto) {
return false; })
1025 .Case<cir::SwitchFlatOp>([](
auto) {
return false; })
1028 .Case<cir::UnreachableOp, cir::TrapOp>([](
auto) {
return true; })
1030 .Case<cir::IndirectBrOp>([](
auto) {
return false; })
1033 .Case<cir::BrOp>([&](cir::BrOp brOp) {
1034 assert(brOp.getDest()->getParent() == &region &&
1035 "branch destination is not in the region");
1038 .Case<cir::BrCondOp>([&](cir::BrCondOp brCondOp) {
1039 assert(brCondOp.getDestTrue()->getParent() == &region &&
1040 "branch destination is not in the region");
1041 assert(brCondOp.getDestFalse()->getParent() == &region &&
1042 "branch destination is not in the region");
1046 .
Default([](mlir::Operation *) ->
bool {
1047 llvm_unreachable(
"unexpected terminator in cleanup region");
1049 if (!expectedTerminator)
// Clone the cleanup region as the EH (exceptional) path: the clone's entry
// gains an EH-token argument, is bracketed by begin/end_cleanup, and its
// final yield is turned into a cir.resume.
1077 mlir::Block *buildEHCleanupBlocks(cir::CleanupScopeOp cleanupOp,
1079 mlir::Block *insertBefore,
1080 mlir::PatternRewriter &rewriter)
const {
1081 assert(regionExitsOnlyFromLastBlock(cleanupOp.getCleanupRegion()) &&
1082 "cleanup region has exits in non-final blocks");
// Track the block before the clone so its entry can be located afterwards.
1086 mlir::Block *blockBeforeClone =
insertBefore->getPrevNode();
1089 rewriter.cloneRegionBefore(cleanupOp.getCleanupRegion(), insertBefore);
1092 mlir::Block *clonedEntry = blockBeforeClone
1093 ? blockBeforeClone->getNextNode()
1098 auto ehTokenType = cir::EhTokenType::get(rewriter.getContext());
1099 mlir::Value ehToken = clonedEntry->addArgument(ehTokenType, loc);
1101 rewriter.setInsertionPointToStart(clonedEntry);
1102 auto beginCleanup = cir::BeginCleanupOp::create(rewriter, loc, ehToken);
1106 mlir::Block *lastClonedBlock =
insertBefore->getPrevNode();
1108 mlir::dyn_cast<cir::YieldOp>(lastClonedBlock->getTerminator());
1110 rewriter.setInsertionPoint(yieldOp);
1111 cir::EndCleanupOp::create(rewriter, loc, beginCleanup.getCleanupToken());
1112 rewriter.replaceOpWithNewOp<cir::ResumeOp>(yieldOp, ehToken);
1114 cleanupOp->emitError(
"Not yet implemented: cleanup region terminated "
1115 "with non-yield operation");
// The main lowering: inline body and cleanup regions, route every recorded
// exit through the cleanup code, and dispatch on the destination slot when
// more than one exit exists.
1144 flattenCleanup(cir::CleanupScopeOp cleanupOp,
1145 llvm::SmallVectorImpl<CleanupExit> &exits,
1146 llvm::SmallVectorImpl<cir::CallOp> &callsToRewrite,
1147 llvm::SmallVectorImpl<cir::ResumeOp> &resumeOpsToChain,
1148 mlir::PatternRewriter &rewriter)
const {
1149 mlir::Location loc = cleanupOp.getLoc();
1150 cir::CleanupKind cleanupKind = cleanupOp.getCleanupKind();
1151 bool hasNormalCleanup = cleanupKind == cir::CleanupKind::Normal ||
1152 cleanupKind == cir::CleanupKind::All;
1153 bool hasEHCleanup = cleanupKind == cir::CleanupKind::EH ||
1154 cleanupKind == cir::CleanupKind::All;
1155 bool isMultiExit = exits.size() > 1;
1158 mlir::Block *bodyEntry = &cleanupOp.getBodyRegion().front();
1159 mlir::Block *cleanupEntry = &cleanupOp.getCleanupRegion().front();
1160 mlir::Block *cleanupExit = &cleanupOp.getCleanupRegion().back();
1161 assert(regionExitsOnlyFromLastBlock(cleanupOp.getCleanupRegion()) &&
1162 "cleanup region has exits in non-final blocks");
1163 auto cleanupYield = dyn_cast<cir::YieldOp>(cleanupExit->getTerminator());
1164 if (!cleanupYield) {
1165 return rewriter.notifyMatchFailure(cleanupOp,
1166 "Not yet implemented: cleanup region "
1167 "terminated with non-yield operation");
// Multi-exit normal cleanups need the shared destination slot to remember
// which exit was taken.
1174 cir::AllocaOp destSlot;
1175 if (isMultiExit && hasNormalCleanup) {
1176 auto funcOp = cleanupOp->getParentOfType<cir::FuncOp>();
1178 return cleanupOp->emitError(
"cleanup scope not inside a function");
1179 destSlot = getOrCreateCleanupDestSlot(funcOp, rewriter, loc);
1183 mlir::Block *currentBlock = rewriter.getInsertionBlock();
1184 mlir::Block *continueBlock =
1185 rewriter.splitBlock(currentBlock, rewriter.getInsertionPoint());
// EH path: clone the cleanup region and, for throwing calls, add an unwind
// entry block that initiates EH before running the cloned cleanup.
1195 mlir::Block *unwindBlock =
nullptr;
1196 mlir::Block *ehCleanupEntry =
nullptr;
1198 (!callsToRewrite.empty() || !resumeOpsToChain.empty())) {
1200 buildEHCleanupBlocks(cleanupOp, loc, continueBlock, rewriter);
1204 if (!callsToRewrite.empty())
1205 unwindBlock = buildUnwindBlock(ehCleanupEntry,
true,
1206 loc, ehCleanupEntry, rewriter);
// Normal-path blocks are inlined before whichever EH block comes first.
1213 mlir::Block *normalInsertPt =
1214 unwindBlock ? unwindBlock
1215 : (ehCleanupEntry ? ehCleanupEntry : continueBlock);
1218 rewriter.inlineRegionBefore(cleanupOp.getBodyRegion(), normalInsertPt);
1221 if (hasNormalCleanup)
1222 rewriter.inlineRegionBefore(cleanupOp.getCleanupRegion(), normalInsertPt);
1225 rewriter.setInsertionPointToEnd(currentBlock);
1226 cir::BrOp::create(rewriter, loc, bodyEntry);
1229 mlir::LogicalResult result = mlir::success();
1230 if (hasNormalCleanup) {
// After the cleanup code runs, control reaches exitBlock, which re-creates
// the original exit (directly, or via a switch on the destination slot).
1232 mlir::Block *exitBlock = rewriter.createBlock(normalInsertPt);
1235 rewriter.setInsertionPoint(cleanupYield);
1236 rewriter.replaceOpWithNewOp<cir::BrOp>(cleanupYield, exitBlock);
1240 rewriter.setInsertionPointToEnd(exitBlock);
// Multi-exit: load the id stored at the exit and switch on it.
1243 auto slotValue = cir::LoadOp::create(
1244 rewriter, loc, destSlot,
false,
1245 false, mlir::IntegerAttr(),
1246 cir::SyncScopeKindAttr(), cir::MemOrderAttr());
1249 llvm::SmallVector<mlir::APInt, 8> caseValues;
1250 llvm::SmallVector<mlir::Block *, 8> caseDestinations;
1251 llvm::SmallVector<mlir::ValueRange, 8> caseOperands;
1252 cir::IntType s32Type =
1253 cir::IntType::get(rewriter.getContext(), 32,
true);
1255 for (
const CleanupExit &exit : exits) {
// One dispatch block per exit, ending in the re-created terminator.
1257 mlir::Block *destBlock = rewriter.createBlock(normalInsertPt);
1258 rewriter.setInsertionPointToEnd(destBlock);
1260 createExitTerminator(exit.exitOp, loc, continueBlock, rewriter);
1263 caseValues.push_back(
1264 llvm::APInt(32,
static_cast<uint64_t>(exit.destinationId),
true));
1265 caseDestinations.push_back(destBlock);
1266 caseOperands.push_back(mlir::ValueRange());
// At the original exit: store this exit's id, then run the cleanup.
1270 rewriter.setInsertionPoint(exit.exitOp);
1271 auto destIdConst = cir::ConstantOp::create(
1272 rewriter, loc, cir::IntAttr::get(s32Type, exit.destinationId));
1273 cir::StoreOp::create(rewriter, loc, destIdConst, destSlot,
1275 mlir::IntegerAttr(),
1276 cir::SyncScopeKindAttr(), cir::MemOrderAttr());
1277 rewriter.replaceOpWithNewOp<cir::BrOp>(exit.exitOp, cleanupEntry);
1285 if (result.failed())
// The default case is unreachable: every stored id has a case.
1290 mlir::Block *defaultBlock = rewriter.createBlock(normalInsertPt);
1291 rewriter.setInsertionPointToEnd(defaultBlock);
1292 cir::UnreachableOp::create(rewriter, loc);
1295 rewriter.setInsertionPointToEnd(exitBlock);
1296 cir::SwitchFlatOp::create(rewriter, loc, slotValue, defaultBlock,
1297 mlir::ValueRange(), caseValues,
1298 caseDestinations, caseOperands);
// Single-exit: no slot/switch needed, just re-create the one terminator.
1302 rewriter.setInsertionPointToEnd(exitBlock);
1303 mlir::Operation *exitOp = exits[0].exitOp;
1304 result = createExitTerminator(exitOp, loc, continueBlock, rewriter);
1307 rewriter.setInsertionPoint(exitOp);
1308 rewriter.replaceOpWithNewOp<cir::BrOp>(exitOp, cleanupEntry);
// EH-only cleanups: normal yields bypass cleanup and go straight on.
1313 for (CleanupExit &exit : exits) {
1314 if (isa<cir::YieldOp>(exit.exitOp)) {
1315 rewriter.setInsertionPoint(exit.exitOp);
1316 rewriter.replaceOpWithNewOp<cir::BrOp>(exit.exitOp, continueBlock);
// Throwing calls in the body are rewritten to unwind into the EH path
// (rewrite elided in this extraction).
1326 for (cir::CallOp callOp : callsToRewrite)
// Throwing calls inside the cloned EH cleanup itself must terminate rather
// than unwind recursively.
1335 if (ehCleanupEntry) {
1336 llvm::SmallVector<cir::CallOp> ehCleanupThrowingCalls;
1337 for (mlir::Block *block = ehCleanupEntry; block != continueBlock;
1338 block = block->getNextNode()) {
1339 block->walk([&](cir::CallOp callOp) {
1340 if (!callOp.getNothrow())
1341 ehCleanupThrowingCalls.push_back(callOp);
1344 if (!ehCleanupThrowingCalls.empty()) {
1345 mlir::Block *terminateBlock =
1346 buildTerminateUnwindBlock(loc, continueBlock, rewriter);
1347 for (cir::CallOp callOp : ehCleanupThrowingCalls)
// Chain inner resumes into this scope's EH cleanup, forwarding the token.
1357 if (ehCleanupEntry) {
1358 for (cir::ResumeOp resumeOp : resumeOpsToChain) {
1359 mlir::Value ehToken = resumeOp.getEhToken();
1360 rewriter.setInsertionPoint(resumeOp);
1361 rewriter.replaceOpWithNewOp<cir::BrOp>(
1362 resumeOp, mlir::ValueRange{ehToken}, ehCleanupEntry);
1367 rewriter.eraseOp(cleanupOp);
1375 return mlir::success();
// Entry point: prune trivially dead nested cleanup scopes, require inner
// control flow to be flattened first, then collect exits/calls/resumes and
// delegate to flattenCleanup.
1379 matchAndRewrite(cir::CleanupScopeOp cleanupOp,
1380 mlir::PatternRewriter &rewriter)
const override {
1381 mlir::OpBuilder::InsertionGuard guard(rewriter);
1395 llvm::SmallVector<cir::CleanupScopeOp> deadNestedOps;
1396 cleanupOp.getBodyRegion().walk([&](cir::CleanupScopeOp nested) {
1397 if (mlir::isOpTriviallyDead(nested))
1398 deadNestedOps.push_back(nested);
1400 for (
auto op : deadNestedOps)
1401 rewriter.eraseOp(op);
1403 if (hasNestedOpsToFlatten(cleanupOp.getBodyRegion()))
1404 return mlir::failure();
1406 cir::CleanupKind cleanupKind = cleanupOp.getCleanupKind();
1409 llvm::SmallVector<CleanupExit> exits;
1411 collectExits(cleanupOp.getBodyRegion(), exits, nextId);
1413 assert(!exits.empty() &&
"cleanup scope body has no exit");
// EH-relevant scopes additionally need the throwing calls and resumes.
1418 llvm::SmallVector<cir::CallOp> callsToRewrite;
1419 if (cleanupKind != cir::CleanupKind::Normal)
1420 collectThrowingCalls(cleanupOp.getBodyRegion(), callsToRewrite);
1424 llvm::SmallVector<cir::ResumeOp> resumeOpsToChain;
1425 if (cleanupKind != cir::CleanupKind::Normal)
1426 collectResumeOps(cleanupOp.getBodyRegion(), resumeOpsToChain);
1428 return flattenCleanup(cleanupOp, exits, callsToRewrite, resumeOpsToChain,
1436static cir::EhInitiateOp traceToEhInitiate(mlir::Value ehToken) {
1438 if (
auto initiate = ehToken.getDefiningOp<cir::EhInitiateOp>())
1440 auto blockArg = mlir::dyn_cast<mlir::BlockArgument>(ehToken);
1443 mlir::Block *pred = blockArg.getOwner()->getSinglePredecessor();
1446 auto brOp = mlir::dyn_cast<cir::BrOp>(pred->getTerminator());
1449 ehToken = brOp.getDestOperands()[blockArg.getArgNumber()];
1454class CIRTryOpFlattening :
public mlir::OpRewritePattern<cir::TryOp> {
1456 using OpRewritePattern<cir::TryOp>::OpRewritePattern;
1461 mlir::Block *buildCatchDispatchBlock(
1462 cir::TryOp tryOp, mlir::ArrayAttr handlerTypes,
1463 llvm::SmallVectorImpl<mlir::Block *> &catchHandlerBlocks,
1464 mlir::Location loc, mlir::Block *insertBefore,
1465 mlir::PatternRewriter &rewriter)
const {
1466 mlir::Block *dispatchBlock = rewriter.createBlock(insertBefore);
1467 auto ehTokenType = cir::EhTokenType::get(rewriter.getContext());
1468 mlir::Value ehToken = dispatchBlock->addArgument(ehTokenType, loc);
1470 rewriter.setInsertionPointToEnd(dispatchBlock);
1473 llvm::SmallVector<mlir::Attribute> catchTypeAttrs;
1474 llvm::SmallVector<mlir::Block *> catchDests;
1475 mlir::Block *defaultDest =
nullptr;
1476 bool defaultIsCatchAll =
false;
1478 for (
auto [typeAttr, handlerBlock] :
1479 llvm::zip(handlerTypes, catchHandlerBlocks)) {
1480 if (mlir::isa<cir::CatchAllAttr>(typeAttr)) {
1481 assert(!defaultDest &&
"multiple catch_all or unwind handlers");
1482 defaultDest = handlerBlock;
1483 defaultIsCatchAll =
true;
1484 }
else if (mlir::isa<cir::UnwindAttr>(typeAttr)) {
1485 assert(!defaultDest &&
"multiple catch_all or unwind handlers");
1486 defaultDest = handlerBlock;
1487 defaultIsCatchAll =
false;
1490 catchTypeAttrs.push_back(typeAttr);
1491 catchDests.push_back(handlerBlock);
1495 assert(defaultDest &&
"dispatch must have a catch_all or unwind handler");
1497 mlir::ArrayAttr catchTypesArrayAttr;
1498 if (!catchTypeAttrs.empty())
1499 catchTypesArrayAttr = rewriter.getArrayAttr(catchTypeAttrs);
1501 cir::EhDispatchOp::create(rewriter, loc, ehToken, catchTypesArrayAttr,
1502 defaultIsCatchAll, defaultDest, catchDests);
1504 return dispatchBlock;
1521 mlir::Block *flattenCatchHandler(mlir::Region &handlerRegion,
1522 mlir::Block *continueBlock,
1524 mlir::Block *insertBefore,
1525 mlir::PatternRewriter &rewriter)
const {
1527 mlir::Block *handlerEntry = &handlerRegion.front();
1530 rewriter.inlineRegionBefore(handlerRegion, insertBefore);
1533 for (mlir::Block &block : llvm::make_range(handlerEntry->getIterator(),
1535 if (
auto yieldOp = dyn_cast<cir::YieldOp>(block.getTerminator())) {
1548 if (mlir::Operation *prev = yieldOp->getPrevNode())
1549 return isa<cir::EndCatchOp>(prev);
1550 llvm::SmallPtrSet<mlir::Block *, 8> visited;
1551 llvm::SmallVector<mlir::Block *, 4> worklist;
1552 for (mlir::Block *pred : block.getPredecessors())
1553 worklist.push_back(pred);
1554 while (!worklist.empty()) {
1555 mlir::Block *
b = worklist.pop_back_val();
1556 if (!visited.insert(
b).second)
1558 mlir::Operation *term =
b->getTerminator();
1559 if (mlir::Operation *prev = term->getPrevNode()) {
1560 if (isa<cir::EndCatchOp>(prev))
1563 for (mlir::Block *pred :
b->getPredecessors())
1564 worklist.push_back(pred);
1568 "expected end_catch reachable before yield "
1569 "in catch handler");
1570 rewriter.setInsertionPoint(yieldOp);
1571 rewriter.replaceOpWithNewOp<cir::BrOp>(yieldOp, continueBlock);
1575 return handlerEntry;
1584 mlir::Block *flattenUnwindHandler(mlir::Region &unwindRegion,
1586 mlir::Block *insertBefore,
1587 mlir::PatternRewriter &rewriter)
const {
1588 mlir::Block *unwindEntry = &unwindRegion.front();
1589 rewriter.inlineRegionBefore(unwindRegion, insertBefore);
1594 matchAndRewrite(cir::TryOp tryOp,
1595 mlir::PatternRewriter &rewriter)
const override {
1602 for (mlir::Region ®ion : tryOp->getRegions())
1603 if (hasNestedOpsToFlatten(region))
1604 return mlir::failure();
1606 mlir::OpBuilder::InsertionGuard guard(rewriter);
1607 mlir::Location loc = tryOp.getLoc();
1609 mlir::ArrayAttr handlerTypes = tryOp.getHandlerTypesAttr();
1610 mlir::MutableArrayRef<mlir::Region> handlerRegions =
1611 tryOp.getHandlerRegions();
1614 llvm::SmallVector<cir::CallOp> callsToRewrite;
1615 collectThrowingCalls(tryOp.getTryRegion(), callsToRewrite);
1618 llvm::SmallVector<cir::ResumeOp> resumeOpsToChain;
1619 collectResumeOps(tryOp.getTryRegion(), resumeOpsToChain);
1622 mlir::Block *currentBlock = rewriter.getInsertionBlock();
1623 mlir::Block *continueBlock =
1624 rewriter.splitBlock(currentBlock, rewriter.getInsertionPoint());
1627 mlir::Block *bodyEntry = &tryOp.getTryRegion().front();
1628 mlir::Block *bodyExit = &tryOp.getTryRegion().back();
1631 rewriter.inlineRegionBefore(tryOp.getTryRegion(), continueBlock);
1634 rewriter.setInsertionPointToEnd(currentBlock);
1635 cir::BrOp::create(rewriter, loc, bodyEntry);
1638 if (
auto bodyYield = dyn_cast<cir::YieldOp>(bodyExit->getTerminator())) {
1639 rewriter.setInsertionPoint(bodyYield);
1640 rewriter.replaceOpWithNewOp<cir::BrOp>(bodyYield, continueBlock);
1644 if (!handlerTypes || handlerTypes.empty()) {
1645 rewriter.eraseOp(tryOp);
1646 return mlir::success();
1654 if (callsToRewrite.empty() && resumeOpsToChain.empty()) {
1655 for (mlir::Region &handlerRegion : handlerRegions)
1656 for (mlir::Block &block : handlerRegion)
1657 block.dropAllDefinedValueUses();
1658 rewriter.eraseOp(tryOp);
1659 return mlir::success();
1665 llvm::SmallVector<mlir::Block *> catchHandlerBlocks;
1667 for (
const auto &[idx, typeAttr] : llvm::enumerate(handlerTypes)) {
1668 mlir::Region &handlerRegion = handlerRegions[idx];
1670 if (mlir::isa<cir::UnwindAttr>(typeAttr)) {
1671 mlir::Block *unwindEntry =
1672 flattenUnwindHandler(handlerRegion, loc, continueBlock, rewriter);
1673 catchHandlerBlocks.push_back(unwindEntry);
1675 mlir::Block *handlerEntry = flattenCatchHandler(
1676 handlerRegion, continueBlock, loc, continueBlock, rewriter);
1677 catchHandlerBlocks.push_back(handlerEntry);
1682 mlir::Block *dispatchBlock =
1683 buildCatchDispatchBlock(tryOp, handlerTypes, catchHandlerBlocks, loc,
1684 catchHandlerBlocks.front(), rewriter);
1695 handlerTypes && llvm::any_of(handlerTypes, [](mlir::Attribute attr) {
1696 return mlir::isa<cir::CatchAllAttr>(attr);
1705 bool isCleanupOnly = tryOp.getCleanup() && !hasCatchAll;
1706 if (!callsToRewrite.empty()) {
1708 mlir::Block *unwindBlock = buildUnwindBlock(dispatchBlock, isCleanupOnly,
1709 loc, dispatchBlock, rewriter);
1711 for (cir::CallOp callOp : callsToRewrite)
1718 for (cir::ResumeOp resumeOp : resumeOpsToChain) {
1723 if (
auto ehInitiate = traceToEhInitiate(resumeOp.getEhToken())) {
1724 rewriter.modifyOpInPlace(ehInitiate,
1725 [&] { ehInitiate.removeCleanupAttr(); });
1729 mlir::Value ehToken = resumeOp.getEhToken();
1730 rewriter.setInsertionPoint(resumeOp);
1731 rewriter.replaceOpWithNewOp<cir::BrOp>(
1732 resumeOp, mlir::ValueRange{ehToken}, dispatchBlock);
1736 rewriter.eraseOp(tryOp);
1738 return mlir::success();
1742void populateFlattenCFGPatterns(RewritePatternSet &patterns) {
1744 .add<CIRIfFlattening, CIRLoopOpInterfaceFlattening, CIRScopeOpFlattening,
1745 CIRSwitchOpFlattening, CIRTernaryOpFlattening,
1746 CIRCleanupScopeOpFlattening, CIRTryOpFlattening>(
1747 patterns.getContext());
1750void CIRFlattenCFGPass::runOnOperation() {
1751 RewritePatternSet patterns(&getContext());
1752 populateFlattenCFGPatterns(patterns);
1755 llvm::SmallVector<Operation *, 16> ops;
1756 getOperation()->walk<mlir::WalkOrder::PostOrder>([&](Operation *op) {
1757 if (isa<IfOp, ScopeOp, SwitchOp, LoopOpInterface, TernaryOp, CleanupScopeOp,
1763 if (applyOpPatternsGreedily(ops, std::move(patterns)).failed())
1764 signalPassFailure();
1772 return std::make_unique<CIRFlattenCFGPass>();
mlir::Block * replaceCallWithTryCall(cir::CallOp callOp, mlir::Block *unwindDest, mlir::Location loc, mlir::RewriterBase &rewriter)
Replace a cir::CallOp with a cir::TryCallOp whose unwind destination is unwindDest.
std::unique_ptr< Pass > createCIRFlattenCFGPass()
int const char * function
float __ovld __cnfn step(float, float)
Returns 0.0 if x < edge, otherwise it returns 1.0.
static bool stackSaveOp()