15#include "mlir/Dialect/Func/IR/FuncOps.h"
16#include "mlir/IR/Block.h"
17#include "mlir/IR/Builders.h"
18#include "mlir/IR/PatternMatch.h"
19#include "mlir/Interfaces/SideEffectInterfaces.h"
20#include "mlir/Support/LogicalResult.h"
21#include "mlir/Transforms/DialectConversion.h"
22#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
27#include "llvm/ADT/TypeSwitch.h"
33#define GEN_PASS_DEF_CIRFLATTENCFG
34#include "clang/CIR/Dialect/Passes.h.inc"
40void lowerTerminator(mlir::Operation *op, mlir::Block *dest,
41 mlir::PatternRewriter &rewriter) {
42 assert(op->hasTrait<mlir::OpTrait::IsTerminator>() &&
"not a terminator");
43 mlir::OpBuilder::InsertionGuard guard(rewriter);
44 rewriter.setInsertionPoint(op);
45 rewriter.replaceOpWithNewOp<cir::BrOp>(op, dest);
50template <
typename... Ops>
51void walkRegionSkipping(
53 mlir::function_ref<mlir::WalkResult(mlir::Operation *)> callback) {
54 region.walk<mlir::WalkOrder::PreOrder>([&](mlir::Operation *op) {
56 return mlir::WalkResult::skip();
61struct CIRFlattenCFGPass :
public impl::CIRFlattenCFGBase<CIRFlattenCFGPass> {
63 CIRFlattenCFGPass() =
default;
64 void runOnOperation()
override;
67struct CIRIfFlattening :
public mlir::OpRewritePattern<cir::IfOp> {
68 using OpRewritePattern<IfOp>::OpRewritePattern;
71 matchAndRewrite(cir::IfOp ifOp,
72 mlir::PatternRewriter &rewriter)
const override {
73 mlir::OpBuilder::InsertionGuard guard(rewriter);
74 mlir::Location loc = ifOp.getLoc();
75 bool emptyElse = ifOp.getElseRegion().empty();
76 mlir::Block *currentBlock = rewriter.getInsertionBlock();
77 mlir::Block *remainingOpsBlock =
78 rewriter.splitBlock(currentBlock, rewriter.getInsertionPoint());
79 mlir::Block *continueBlock;
80 if (ifOp->getResults().empty())
81 continueBlock = remainingOpsBlock;
83 llvm_unreachable(
"NYI");
86 mlir::Block *thenBeforeBody = &ifOp.getThenRegion().front();
87 mlir::Block *thenAfterBody = &ifOp.getThenRegion().back();
88 rewriter.inlineRegionBefore(ifOp.getThenRegion(), continueBlock);
90 rewriter.setInsertionPointToEnd(thenAfterBody);
91 if (
auto thenYieldOp =
92 dyn_cast<cir::YieldOp>(thenAfterBody->getTerminator())) {
93 rewriter.replaceOpWithNewOp<cir::BrOp>(thenYieldOp, thenYieldOp.getArgs(),
97 rewriter.setInsertionPointToEnd(continueBlock);
100 mlir::Block *elseBeforeBody =
nullptr;
101 mlir::Block *elseAfterBody =
nullptr;
103 elseBeforeBody = &ifOp.getElseRegion().front();
104 elseAfterBody = &ifOp.getElseRegion().back();
105 rewriter.inlineRegionBefore(ifOp.getElseRegion(), continueBlock);
107 elseBeforeBody = elseAfterBody = continueBlock;
110 rewriter.setInsertionPointToEnd(currentBlock);
111 cir::BrCondOp::create(rewriter, loc, ifOp.getCondition(), thenBeforeBody,
115 rewriter.setInsertionPointToEnd(elseAfterBody);
116 if (
auto elseYieldOP =
117 dyn_cast<cir::YieldOp>(elseAfterBody->getTerminator())) {
118 rewriter.replaceOpWithNewOp<cir::BrOp>(
119 elseYieldOP, elseYieldOP.getArgs(), continueBlock);
123 rewriter.replaceOp(ifOp, continueBlock->getArguments());
124 return mlir::success();
128class CIRScopeOpFlattening :
public mlir::OpRewritePattern<cir::ScopeOp> {
130 using OpRewritePattern<cir::ScopeOp>::OpRewritePattern;
133 matchAndRewrite(cir::ScopeOp scopeOp,
134 mlir::PatternRewriter &rewriter)
const override {
135 mlir::OpBuilder::InsertionGuard guard(rewriter);
136 mlir::Location loc = scopeOp.getLoc();
144 if (scopeOp.isEmpty()) {
145 rewriter.eraseOp(scopeOp);
146 return mlir::success();
151 mlir::Block *currentBlock = rewriter.getInsertionBlock();
152 mlir::Block *continueBlock =
153 rewriter.splitBlock(currentBlock, rewriter.getInsertionPoint());
154 if (scopeOp.getNumResults() > 0)
155 continueBlock->addArguments(scopeOp.getResultTypes(), loc);
158 mlir::Block *beforeBody = &scopeOp.getScopeRegion().front();
159 mlir::Block *afterBody = &scopeOp.getScopeRegion().back();
160 rewriter.inlineRegionBefore(scopeOp.getScopeRegion(), continueBlock);
163 rewriter.setInsertionPointToEnd(currentBlock);
165 cir::BrOp::create(rewriter, loc, mlir::ValueRange(), beforeBody);
169 rewriter.setInsertionPointToEnd(afterBody);
170 if (
auto yieldOp = dyn_cast<cir::YieldOp>(afterBody->getTerminator())) {
171 rewriter.replaceOpWithNewOp<cir::BrOp>(yieldOp, yieldOp.getArgs(),
176 rewriter.replaceOp(scopeOp, continueBlock->getArguments());
178 return mlir::success();
182class CIRSwitchOpFlattening :
public mlir::OpRewritePattern<cir::SwitchOp> {
184 using OpRewritePattern<cir::SwitchOp>::OpRewritePattern;
186 inline void rewriteYieldOp(mlir::PatternRewriter &rewriter,
187 cir::YieldOp yieldOp,
188 mlir::Block *destination)
const {
189 rewriter.setInsertionPoint(yieldOp);
190 rewriter.replaceOpWithNewOp<cir::BrOp>(yieldOp, yieldOp.getOperands(),
195 Block *condBrToRangeDestination(cir::SwitchOp op,
196 mlir::PatternRewriter &rewriter,
197 mlir::Block *rangeDestination,
198 mlir::Block *defaultDestination,
199 const APInt &lowerBound,
200 const APInt &upperBound)
const {
201 assert(lowerBound.sle(upperBound) &&
"Invalid range");
202 mlir::Block *resBlock = rewriter.createBlock(defaultDestination);
203 cir::IntType sIntType = cir::IntType::get(op.getContext(), 32,
true);
204 cir::IntType uIntType = cir::IntType::get(op.getContext(), 32,
false);
206 cir::ConstantOp rangeLength = cir::ConstantOp::create(
207 rewriter, op.getLoc(),
208 cir::IntAttr::get(sIntType, upperBound - lowerBound));
210 cir::ConstantOp lowerBoundValue = cir::ConstantOp::create(
211 rewriter, op.getLoc(), cir::IntAttr::get(sIntType, lowerBound));
212 mlir::Value diffValue = cir::SubOp::create(
213 rewriter, op.getLoc(), op.getCondition(), lowerBoundValue);
216 cir::CastOp uDiffValue = cir::CastOp::create(
217 rewriter, op.getLoc(), uIntType, CastKind::integral, diffValue);
218 cir::CastOp uRangeLength = cir::CastOp::create(
219 rewriter, op.getLoc(), uIntType, CastKind::integral, rangeLength);
221 cir::CmpOp cmpResult = cir::CmpOp::create(
222 rewriter, op.getLoc(), cir::CmpOpKind::le, uDiffValue, uRangeLength);
223 cir::BrCondOp::create(rewriter, op.getLoc(), cmpResult, rangeDestination,
229 matchAndRewrite(cir::SwitchOp op,
230 mlir::PatternRewriter &rewriter)
const override {
234 bool hasNestedCleanup = op->walk([&](cir::CleanupScopeOp) {
235 return mlir::WalkResult::interrupt();
237 if (hasNestedCleanup)
238 return mlir::failure();
240 llvm::SmallVector<CaseOp> cases;
241 op.collectCases(cases);
245 rewriter.eraseOp(op);
246 return mlir::success();
250 mlir::Block *exitBlock = rewriter.splitBlock(
251 rewriter.getBlock(), op->getNextNode()->getIterator());
264 cir::YieldOp switchYield =
nullptr;
266 for (mlir::Block &block :
267 llvm::make_early_inc_range(op.getBody().getBlocks()))
268 if (
auto yieldOp = dyn_cast<cir::YieldOp>(block.getTerminator()))
269 switchYield = yieldOp;
271 assert(!op.getBody().empty());
272 mlir::Block *originalBlock = op->getBlock();
273 mlir::Block *swopBlock =
274 rewriter.splitBlock(originalBlock, op->getIterator());
275 rewriter.inlineRegionBefore(op.getBody(), exitBlock);
278 rewriteYieldOp(rewriter, switchYield, exitBlock);
280 rewriter.setInsertionPointToEnd(originalBlock);
281 cir::BrOp::create(rewriter, op.getLoc(), swopBlock);
286 llvm::SmallVector<mlir::APInt, 8> caseValues;
287 llvm::SmallVector<mlir::Block *, 8> caseDestinations;
288 llvm::SmallVector<mlir::ValueRange, 8> caseOperands;
290 llvm::SmallVector<std::pair<APInt, APInt>> rangeValues;
291 llvm::SmallVector<mlir::Block *> rangeDestinations;
292 llvm::SmallVector<mlir::ValueRange> rangeOperands;
295 mlir::Block *defaultDestination = exitBlock;
296 mlir::ValueRange defaultOperands = exitBlock->getArguments();
299 for (cir::CaseOp caseOp : cases) {
300 mlir::Region ®ion = caseOp.getCaseRegion();
303 switch (caseOp.getKind()) {
304 case cir::CaseOpKind::Default:
305 defaultDestination = ®ion.front();
306 defaultOperands = defaultDestination->getArguments();
308 case cir::CaseOpKind::Range:
309 assert(caseOp.getValue().size() == 2 &&
310 "Case range should have 2 case value");
311 rangeValues.push_back(
312 {cast<cir::IntAttr>(caseOp.getValue()[0]).getValue(),
313 cast<cir::IntAttr>(caseOp.getValue()[1]).getValue()});
314 rangeDestinations.push_back(®ion.front());
315 rangeOperands.push_back(rangeDestinations.back()->getArguments());
317 case cir::CaseOpKind::Anyof:
318 case cir::CaseOpKind::Equal:
320 for (
const mlir::Attribute &value : caseOp.getValue()) {
321 caseValues.push_back(cast<cir::IntAttr>(value).getValue());
322 caseDestinations.push_back(®ion.front());
323 caseOperands.push_back(caseDestinations.back()->getArguments());
329 walkRegionSkipping<cir::LoopOpInterface, cir::SwitchOp>(
330 region, [&](mlir::Operation *op) {
331 if (!isa<cir::BreakOp>(op))
332 return mlir::WalkResult::advance();
334 lowerTerminator(op, exitBlock, rewriter);
335 return mlir::WalkResult::skip();
339 for (mlir::Block &blk : region.getBlocks()) {
340 if (blk.getNumSuccessors())
343 if (
auto yieldOp = dyn_cast<cir::YieldOp>(blk.getTerminator())) {
344 mlir::Operation *nextOp = caseOp->getNextNode();
345 assert(nextOp &&
"caseOp is not expected to be the last op");
346 mlir::Block *oldBlock = nextOp->getBlock();
347 mlir::Block *newBlock =
348 rewriter.splitBlock(oldBlock, nextOp->getIterator());
349 rewriter.setInsertionPointToEnd(oldBlock);
350 cir::BrOp::create(rewriter, nextOp->getLoc(), mlir::ValueRange(),
352 rewriteYieldOp(rewriter, yieldOp, newBlock);
356 mlir::Block *oldBlock = caseOp->getBlock();
357 mlir::Block *newBlock =
358 rewriter.splitBlock(oldBlock, caseOp->getIterator());
360 mlir::Block &entryBlock = caseOp.getCaseRegion().front();
361 rewriter.inlineRegionBefore(caseOp.getCaseRegion(), newBlock);
364 rewriter.setInsertionPointToEnd(oldBlock);
365 cir::BrOp::create(rewriter, caseOp.getLoc(), &entryBlock);
369 for (cir::CaseOp caseOp : cases) {
370 mlir::Block *caseBlock = caseOp->getBlock();
373 if (caseBlock->hasNoPredecessors())
374 rewriter.eraseBlock(caseBlock);
376 rewriter.eraseOp(caseOp);
379 for (
auto [rangeVal, operand, destination] :
380 llvm::zip(rangeValues, rangeOperands, rangeDestinations)) {
381 APInt lowerBound = rangeVal.first;
382 APInt upperBound = rangeVal.second;
385 if (lowerBound.sgt(upperBound))
390 constexpr int kSmallRangeThreshold = 64;
391 if ((upperBound - lowerBound)
392 .ult(llvm::APInt(32, kSmallRangeThreshold))) {
393 for (APInt iValue = lowerBound; iValue.sle(upperBound); ++iValue) {
394 caseValues.push_back(iValue);
395 caseOperands.push_back(operand);
396 caseDestinations.push_back(destination);
402 condBrToRangeDestination(op, rewriter, destination,
403 defaultDestination, lowerBound, upperBound);
404 defaultOperands = operand;
408 rewriter.setInsertionPoint(op);
409 rewriter.replaceOpWithNewOp<cir::SwitchFlatOp>(
410 op, op.getCondition(), defaultDestination, defaultOperands, caseValues,
411 caseDestinations, caseOperands);
413 return mlir::success();
417class CIRLoopOpInterfaceFlattening
418 :
public mlir::OpInterfaceRewritePattern<cir::LoopOpInterface> {
420 using mlir::OpInterfaceRewritePattern<
421 cir::LoopOpInterface>::OpInterfaceRewritePattern;
423 inline void lowerConditionOp(cir::ConditionOp op, mlir::Block *body,
425 mlir::PatternRewriter &rewriter)
const {
426 mlir::OpBuilder::InsertionGuard guard(rewriter);
427 rewriter.setInsertionPoint(op);
428 rewriter.replaceOpWithNewOp<cir::BrCondOp>(op, op.getCondition(), body,
433 matchAndRewrite(cir::LoopOpInterface op,
434 mlir::PatternRewriter &rewriter)
const final {
438 bool hasNestedCleanup = op->walk([&](cir::CleanupScopeOp) {
439 return mlir::WalkResult::interrupt();
441 if (hasNestedCleanup)
442 return mlir::failure();
445 mlir::Block *entry = rewriter.getInsertionBlock();
447 rewriter.splitBlock(entry, rewriter.getInsertionPoint());
448 mlir::Block *cond = &op.getCond().front();
449 mlir::Block *body = &op.getBody().front();
451 (op.maybeGetStep() ? &op.maybeGetStep()->front() :
nullptr);
454 rewriter.setInsertionPointToEnd(entry);
455 cir::BrOp::create(rewriter, op.getLoc(), &op.getEntry().front());
462 cast<cir::ConditionOp>(op.getCond().back().getTerminator());
463 lowerConditionOp(conditionOp, body, exit, rewriter);
470 mlir::Block *dest = (
step ?
step : cond);
471 op.walkBodySkippingNestedLoops([&](mlir::Operation *op) {
472 if (!isa<cir::ContinueOp>(op))
473 return mlir::WalkResult::advance();
475 lowerTerminator(op, dest, rewriter);
476 return mlir::WalkResult::skip();
480 walkRegionSkipping<cir::LoopOpInterface, cir::SwitchOp>(
481 op.getBody(), [&](mlir::Operation *op) {
482 if (!isa<cir::BreakOp>(op))
483 return mlir::WalkResult::advance();
485 lowerTerminator(op, exit, rewriter);
486 return mlir::WalkResult::skip();
490 for (mlir::Block &blk : op.getBody().getBlocks()) {
491 auto bodyYield = dyn_cast<cir::YieldOp>(blk.getTerminator());
493 lowerTerminator(bodyYield, (
step ?
step : cond), rewriter);
501 cast<cir::YieldOp>(op.maybeGetStep()->back().getTerminator()), cond,
505 rewriter.inlineRegionBefore(op.getCond(), exit);
506 rewriter.inlineRegionBefore(op.getBody(), exit);
508 rewriter.inlineRegionBefore(*op.maybeGetStep(), exit);
510 rewriter.eraseOp(op);
511 return mlir::success();
// Flattens cir.ternary into a cir.brcond over the inlined true/false
// regions, merging into a continuation block that carries the result as a
// block argument. NOTE(review): this pasted chunk elides several lines
// (continuation of the then-yield branch, the `else` arms after the
// UnreachableOp checks, and closing braces); code is annotated in place
// without reconstructing the missing statements.
515class CIRTernaryOpFlattening :
public mlir::OpRewritePattern<cir::TernaryOp> {
517 using OpRewritePattern<cir::TernaryOp>::OpRewritePattern;
520 matchAndRewrite(cir::TernaryOp op,
521 mlir::PatternRewriter &rewriter)
const override {
522 Location loc = op->getLoc();
// Split the current block at the ternary; what follows it becomes the
// remainder block.
523 Block *condBlock = rewriter.getInsertionBlock();
524 Block::iterator opPosition = rewriter.getInsertionPoint();
525 Block *remainingOpsBlock = rewriter.splitBlock(condBlock, opPosition);
526 llvm::SmallVector<mlir::Location, 2> locs;
// The continuation block receives one argument per ternary result.
529 if (op->getResultTypes().size())
531 Block *continueBlock =
532 rewriter.createBlock(remainingOpsBlock, op->getResultTypes(), locs);
533 cir::BrOp::create(rewriter, loc, remainingOpsBlock);
// True region: rewrite its yield into a branch to the continuation block.
535 Region &trueRegion = op.getTrueRegion();
536 Block *trueBlock = &trueRegion.front();
537 mlir::Operation *trueTerminator = trueRegion.back().getTerminator();
538 rewriter.setInsertionPointToEnd(&trueRegion.back());
544 if (
auto trueYieldOp = dyn_cast<cir::YieldOp>(trueTerminator)) {
545 rewriter.replaceOpWithNewOp<cir::BrOp>(trueYieldOp, trueYieldOp.getArgs(),
547 }
// An unreachable terminator is left as-is; any other terminator is a
// malformed ternary and is diagnosed below.
else if (isa<cir::UnreachableOp>(trueTerminator)) {
550 trueTerminator->emitError(
"unexpected terminator in ternary true region, "
551 "expected yield or unreachable, got: ")
552 << trueTerminator->getName();
// NOTE(review): returning success right after emitError — elided lines
// 553-554 may hold the branch structure; confirm against the original.
555 return mlir::success();
557 rewriter.inlineRegionBefore(trueRegion, continueBlock);
// False region: same treatment as the true region.
559 Block *falseBlock = continueBlock;
560 Region &falseRegion = op.getFalseRegion();
562 falseBlock = &falseRegion.front();
563 mlir::Operation *falseTerminator = falseRegion.back().getTerminator();
564 rewriter.setInsertionPointToEnd(&falseRegion.back());
567 if (
auto falseYieldOp = dyn_cast<cir::YieldOp>(falseTerminator)) {
568 rewriter.replaceOpWithNewOp<cir::BrOp>(
569 falseYieldOp, falseYieldOp.getArgs(), continueBlock);
570 }
else if (isa<cir::UnreachableOp>(falseTerminator)) {
573 falseTerminator->emitError(
"unexpected terminator in ternary false "
574 "region, expected yield or unreachable, got: ")
575 << falseTerminator->getName();
578 return mlir::success();
580 rewriter.inlineRegionBefore(falseRegion, continueBlock);
// Branch on the condition into the inlined regions and replace the
// ternary's results with the continuation block arguments.
582 rewriter.setInsertionPointToEnd(condBlock);
583 cir::BrCondOp::create(rewriter, loc, op.getCond(), trueBlock, falseBlock);
585 rewriter.replaceOp(op, continueBlock->getArguments());
588 return mlir::success();
595static cir::AllocaOp getOrCreateCleanupDestSlot(cir::FuncOp funcOp,
596 mlir::PatternRewriter &rewriter,
597 mlir::Location loc) {
598 mlir::Block &entryBlock = funcOp.getBody().front();
601 auto it = llvm::find_if(entryBlock, [](
auto &op) {
602 return mlir::isa<AllocaOp>(&op) &&
603 mlir::cast<AllocaOp>(&op).getCleanupDestSlot();
605 if (it != entryBlock.end())
606 return mlir::cast<cir::AllocaOp>(*it);
609 mlir::OpBuilder::InsertionGuard guard(rewriter);
610 rewriter.setInsertionPointToStart(&entryBlock);
611 cir::IntType s32Type =
612 cir::IntType::get(rewriter.getContext(), 32,
true);
613 cir::PointerType ptrToS32Type = cir::PointerType::get(s32Type);
615 uint64_t alignment = dataLayout.getAlignment(s32Type,
true).value();
616 auto allocaOp = cir::AllocaOp::create(
617 rewriter, loc, ptrToS32Type, s32Type,
"__cleanup_dest_slot",
618 rewriter.getI64IntegerAttr(alignment));
619 allocaOp.setCleanupDestSlot(
true);
631collectThrowingCalls(mlir::Region ®ion,
633 region.walk([&](cir::CallOp callOp) {
634 if (!callOp.getNothrow())
635 callsToRewrite.push_back(callOp);
645static void collectResumeOps(mlir::Region ®ion,
647 region.walk([&](cir::ResumeOp resumeOp) { resumeOps.push_back(resumeOp); });
// Converts a potentially-throwing cir.call into a cir.try_call with an
// explicit normal successor (the split-off remainder of the call's block)
// and the given `unwindDest`. Attributes are copied over, minus the ones
// the try_call builder already sets. NOTE(review): several original lines
// are elided here (the `loc` parameter line, the non-result `: ...` arm of
// the resType conditional, array terminators and closing braces); the code
// is annotated in place without reconstruction.
652static void replaceCallWithTryCall(cir::CallOp callOp, mlir::Block *unwindDest,
654 mlir::PatternRewriter &rewriter) {
655 mlir::Block *callBlock = callOp->getBlock();
657 assert(!callOp.getNothrow() &&
"call is not expected to throw");
// Everything after the call becomes the normal-path successor block.
661 mlir::Block *normalDest =
662 rewriter.splitBlock(callBlock, std::next(callOp->getIterator()));
665 rewriter.setInsertionPoint(callOp);
666 cir::TryCallOp tryCallOp;
// Indirect calls pass the callee value and its function type; direct
// calls pass the callee symbol and the (single) result type.
667 if (callOp.isIndirect()) {
668 mlir::Value indTarget = callOp.getIndirectCall();
669 auto ptrTy = mlir::cast<cir::PointerType>(indTarget.getType());
670 auto resTy = mlir::cast<cir::FuncType>(ptrTy.getPointee());
672 cir::TryCallOp::create(rewriter, loc, indTarget, resTy, normalDest,
673 unwindDest, callOp.getArgOperands());
675 mlir::Type resType = callOp->getNumResults() > 0
676 ? callOp->getResult(0).getType()
679 cir::TryCallOp::create(rewriter, loc, callOp.getCalleeAttr(), resType,
680 normalDest, unwindDest, callOp.getArgOperands());
// Attributes the try_call builder already materialized must not be
// copied again.
685 llvm::StringRef excludedAttrs[] = {
686 CIRDialect::getCalleeAttrName(),
687 CIRDialect::getOperandSegmentSizesAttrName(),
// Attributes that contradict a throwing call (collectThrowingCalls only
// feeds non-nothrow calls here).
692 llvm::StringRef unexpectedAttrs[] = {
693 CIRDialect::getNoThrowAttrName(),
694 CIRDialect::getNoUnwindAttrName(),
697 for (mlir::NamedAttribute attr : callOp->getAttrs()) {
698 if (llvm::is_contained(excludedAttrs,
attr.getName()))
700 assert(!llvm::is_contained(unexpectedAttrs,
attr.getName()) &&
701 "unexpected attribute on converted call");
702 tryCallOp->setAttr(
attr.getName(),
attr.getValue());
// Redirect the call's result uses to the try_call before erasing it.
706 if (callOp->getNumResults() > 0)
707 callOp->getResult(0).replaceAllUsesWith(tryCallOp.getResult());
709 rewriter.eraseOp(callOp);
715static mlir::Block *buildUnwindBlock(mlir::Block *dest,
bool isCleanupOnly,
717 mlir::Block *insertBefore,
718 mlir::PatternRewriter &rewriter) {
719 mlir::Block *unwindBlock = rewriter.createBlock(insertBefore);
720 rewriter.setInsertionPointToEnd(unwindBlock);
722 cir::EhInitiateOp::create(rewriter, loc, isCleanupOnly);
723 cir::BrOp::create(rewriter, loc, mlir::ValueRange{ehInitiate.getEhToken()},
731static mlir::Block *buildTerminateUnwindBlock(mlir::Location loc,
732 mlir::Block *insertBefore,
733 mlir::PatternRewriter &rewriter) {
734 mlir::Block *terminateBlock = rewriter.createBlock(insertBefore);
735 rewriter.setInsertionPointToEnd(terminateBlock);
736 auto ehInitiate = cir::EhInitiateOp::create(rewriter, loc,
false);
737 cir::EhTerminateOp::create(rewriter, loc, ehInitiate.getEhToken());
738 return terminateBlock;
// Flattens cir.cleanup_scope: inlines its body and cleanup regions, routes
// each scope exit (yield/break/continue/return/goto) through the cleanup
// code, and builds the EH cleanup path when throwing calls or resumes are
// present. NOTE(review): this pasted chunk elides several original lines
// (the CleanupExit struct header, its destinationId field declaration, and
// access specifiers); annotations describe only what is visible.
741class CIRCleanupScopeOpFlattening
742 :
public mlir::OpRewritePattern<cir::CleanupScopeOp> {
744 using OpRewritePattern<cir::CleanupScopeOp>::OpRewritePattern;
// The operation that leaves the cleanup scope; paired with a numeric
// destination id used to dispatch after the shared cleanup code runs.
749 mlir::Operation *exitOp;
755 CleanupExit(mlir::Operation *op,
int id) : exitOp(op), destinationId(id) {}
// Collects, in visitation order, every operation that exits the cleanup
// scope's body region (yields ending blocks, plus break/continue/return/
// goto found by the nested walkers below), assigning each a sequential
// destination id via `nextId`. Nested cleanup scopes, loops, and switches
// are recursed into with the appropriate subset of exit kinds, since e.g.
// a break inside a nested loop does not exit this scope.
// NOTE(review): elided lines include the `nextId` parameter/initializer,
// the loop/switch lambda declarations' full signatures, and several
// closing braces; code is annotated in place without reconstruction.
774 void collectExits(mlir::Region &cleanupBodyRegion,
775 llvm::SmallVectorImpl<CleanupExit> &exits,
// Yields terminating body blocks are direct exits.
780 for (mlir::Block &block : cleanupBodyRegion) {
781 auto *terminator = block.getTerminator();
782 if (isa<cir::YieldOp>(terminator))
783 exits.emplace_back(terminator, nextId++);
// Inside a nested loop only return/goto escape this cleanup scope.
790 auto collectExitsInLoop = [&](mlir::Operation *loopOp) {
791 loopOp->walk<mlir::WalkOrder::PreOrder>([&](mlir::Operation *nestedOp) {
792 if (isa<cir::ReturnOp, cir::GotoOp>(nestedOp))
793 exits.emplace_back(nestedOp, nextId++);
794 return mlir::WalkResult::advance();
// Declared ahead so the switch walker can recurse into nested cleanups.
799 std::function<void(mlir::Region &,
bool)> collectExitsInCleanup;
// Inside a nested switch: break stays local, but return/goto/continue
// escape; nested cleanups/loops are delegated and then skipped.
804 collectExitsInSwitch = [&](mlir::Operation *switchOp) {
805 switchOp->walk<mlir::WalkOrder::PreOrder>([&](mlir::Operation *nestedOp) {
806 if (isa<cir::CleanupScopeOp>(nestedOp)) {
809 collectExitsInCleanup(
810 cast<cir::CleanupScopeOp>(nestedOp).getBodyRegion(),
812 return mlir::WalkResult::skip();
813 }
else if (isa<cir::LoopOpInterface>(nestedOp)) {
814 collectExitsInLoop(nestedOp);
815 return mlir::WalkResult::skip();
816 }
else if (isa<cir::ReturnOp, cir::GotoOp, cir::ContinueOp>(nestedOp)) {
817 exits.emplace_back(nestedOp, nextId++);
819 return mlir::WalkResult::advance();
// Main recursive walker over a cleanup body region; `ignoreBreak`
// controls whether break counts as an exit at this nesting level.
826 collectExitsInCleanup = [&](mlir::Region &region,
bool ignoreBreak) {
827 region.walk<mlir::WalkOrder::PreOrder>([&](mlir::Operation *op) {
834 if (!ignoreBreak && isa<cir::BreakOp>(op)) {
835 exits.emplace_back(op, nextId++);
836 }
else if (isa<cir::ContinueOp, cir::ReturnOp, cir::GotoOp>(op)) {
837 exits.emplace_back(op, nextId++);
838 }
else if (isa<cir::CleanupScopeOp>(op)) {
840 collectExitsInCleanup(cast<cir::CleanupScopeOp>(op).getBodyRegion(),
842 return mlir::WalkResult::skip();
843 }
else if (isa<cir::LoopOpInterface>(op)) {
847 collectExitsInLoop(op);
848 return mlir::WalkResult::skip();
849 }
else if (isa<cir::SwitchOp>(op)) {
853 collectExitsInSwitch(op);
854 return mlir::WalkResult::skip();
856 return mlir::WalkResult::advance();
// Kick off the walk on the scope's own body (breaks do count here).
861 collectExitsInCleanup(cleanupBodyRegion,
false);
// Decides whether a return operand's defining op can simply be moved
// ("sunk") next to the re-created return instead of being spilled through
// a temporary alloca. Only single-use constants and loads in the return's
// own block qualify; a load additionally must read from an alloca in the
// function entry block (so moving it cannot observe different memory).
// NOTE(review): elided lines include the early `return` statements of the
// guard clauses and the final return for the constant case; annotations
// are in place without reconstruction.
867 static bool shouldSinkReturnOperand(mlir::Value operand,
868 cir::ReturnOp returnOp) {
870 mlir::Operation *defOp = operand.getDefiningOp();
// Only constants and loads are candidates at all.
876 if (!mlir::isa<cir::ConstantOp, cir::LoadOp>(defOp))
880 if (!operand.hasOneUse())
884 if (defOp->getBlock() != returnOp->getBlock())
887 if (
auto loadOp = mlir::dyn_cast<cir::LoadOp>(defOp)) {
889 mlir::Value ptr = loadOp.getAddr();
890 auto funcOp = returnOp->getParentOfType<cir::FuncOp>();
891 assert(funcOp &&
"Return op has no function parent?");
892 mlir::Block &funcEntryBlock = funcOp.getBody().front();
// A load is sinkable only when its source alloca lives in the entry
// block (i.e. it is a stable function-local slot).
896 mlir::dyn_cast_if_present<cir::AllocaOp>(ptr.getDefiningOp()))
897 return allocaOp->getBlock() == &funcEntryBlock;
903 assert(mlir::isa<cir::ConstantOp>(defOp) &&
"Expected constant op");
// Produces, in `returnValues`, values equivalent to `returnOp`'s operands
// but usable at the current insertion block (`destBlock`, the shared exit
// of the flattened cleanup scope). Sinkable operands are moved there
// directly; everything else is spilled: stored to a fresh entry-block
// "__ret_operand_tmp" alloca right before the original exit op and
// re-loaded in the destination block.
// NOTE(review): elided lines include the method's return type line, the
// `alignment` declaration, store/load flag arguments, and closing braces;
// annotations are in place without reconstruction.
912 getReturnOpOperands(cir::ReturnOp returnOp, mlir::Operation *exitOp,
913 mlir::Location loc, mlir::PatternRewriter &rewriter,
914 llvm::SmallVectorImpl<mlir::Value> &returnValues)
const {
915 mlir::Block *destBlock = rewriter.getInsertionBlock();
916 auto funcOp = exitOp->getParentOfType<cir::FuncOp>();
917 assert(funcOp &&
"Return op has no function parent?");
918 mlir::Block &funcEntryBlock = funcOp.getBody().front();
920 for (mlir::Value operand : returnOp.getOperands()) {
// Cheap case: move the defining op into the destination block.
921 if (shouldSinkReturnOperand(operand, returnOp)) {
923 mlir::Operation *defOp = operand.getDefiningOp();
924 rewriter.moveOpBefore(defOp, destBlock, destBlock->end());
925 returnValues.push_back(operand);
// General case: spill through a properly-aligned entry-block alloca.
928 cir::AllocaOp alloca;
930 mlir::OpBuilder::InsertionGuard guard(rewriter);
931 rewriter.setInsertionPointToStart(&funcEntryBlock);
932 cir::CIRDataLayout dataLayout(
933 funcOp->getParentOfType<mlir::ModuleOp>());
935 dataLayout.getAlignment(operand.getType(),
true).value();
936 cir::PointerType ptrType = cir::PointerType::get(operand.getType());
937 alloca = cir::AllocaOp::create(rewriter, loc, ptrType,
938 operand.getType(),
"__ret_operand_tmp",
939 rewriter.getI64IntegerAttr(alignment));
// Store the live value just before the original exit op...
944 mlir::OpBuilder::InsertionGuard guard(rewriter);
945 rewriter.setInsertionPoint(exitOp);
946 cir::StoreOp::create(rewriter, loc, operand, alloca,
949 cir::SyncScopeKindAttr(), cir::MemOrderAttr());
// ...and reload it in the destination block for the new return.
953 rewriter.setInsertionPointToEnd(destBlock);
954 auto loaded = cir::LoadOp::create(
955 rewriter, loc, alloca,
false,
956 false, mlir::IntegerAttr(),
957 cir::SyncScopeKindAttr(), cir::MemOrderAttr());
958 returnValues.push_back(loaded);
// Re-creates, at the current insertion point, the terminator matching the
// kind of recorded exit op: yield becomes a branch to `continueBlock`,
// break/continue/return are re-emitted (return operands go through
// getReturnOpOperands), and goto is not yet implemented — it emits an
// unreachable plus a diagnostic. Returns success/failure accordingly.
// NOTE(review): elided lines include the method's return type line and the
// `})` closers of each TypeSwitch case; annotations are in place without
// reconstruction.
968 createExitTerminator(mlir::Operation *exitOp, mlir::Location loc,
969 mlir::Block *continueBlock,
970 mlir::PatternRewriter &rewriter)
const {
971 return llvm::TypeSwitch<mlir::Operation *, mlir::LogicalResult>(exitOp)
972 .Case<cir::YieldOp>([&](
auto) {
974 cir::BrOp::create(rewriter, loc, continueBlock);
975 return mlir::success();
977 .Case<cir::BreakOp>([&](
auto) {
979 cir::BreakOp::create(rewriter, loc);
980 return mlir::success();
982 .Case<cir::ContinueOp>([&](
auto) {
984 cir::ContinueOp::create(rewriter, loc);
985 return mlir::success();
987 .Case<cir::ReturnOp>([&](
auto returnOp) {
// Return operands must be re-materialized at the exit block.
991 if (returnOp.hasOperand()) {
992 llvm::SmallVector<mlir::Value, 2> returnValues;
993 getReturnOpOperands(returnOp, exitOp, loc, rewriter, returnValues);
994 cir::ReturnOp::create(rewriter, loc, returnValues);
996 cir::ReturnOp::create(rewriter, loc);
998 return mlir::success();
1000 .Case<cir::GotoOp>([&](
auto gotoOp) {
// NYI: keep the block terminated and surface a diagnostic.
1007 cir::UnreachableOp::create(rewriter, loc);
1008 return gotoOp.emitError(
1009 "goto in cleanup scope is not yet implemented");
1011 .
Default([&](mlir::Operation *op) {
1012 cir::UnreachableOp::create(rewriter, loc);
1013 return op->emitError(
1014 "unexpected exit operation in cleanup scope body");
// Verifies that only the final block of `region` can leave the region:
// every non-final block must end in a terminator that stays inside
// (br/brcond with in-region destinations, or unreachable/trap). Any
// region-exiting terminator (yield, return, resume, continue, break, goto,
// try_call, eh_dispatch, switch_flat, indirectbr) in a non-final block
// makes this return false.
// NOTE(review): elided lines include the `continue` for the final block,
// the `return true;` bodies of the br/brcond cases, the closing `})`s, and
// the function's final returns; annotations are in place without
// reconstruction.
1020 static bool regionExitsOnlyFromLastBlock(mlir::Region &region) {
1021 for (mlir::Block &block : region) {
// The last block is allowed to exit; only earlier blocks are checked.
1022 if (&block == &region.back())
1024 bool expectedTerminator =
1025 llvm::TypeSwitch<mlir::Operation *, bool>(block.getTerminator())
// All of these leave the region: not allowed before the last block.
1032 .Case<cir::YieldOp, cir::ReturnOp, cir::ResumeFlatOp,
1033 cir::ContinueOp, cir::BreakOp, cir::GotoOp>(
1034 [](
auto) {
return false; })
1043 .Case<cir::TryCallOp>([](
auto) {
return false; })
1047 .Case<cir::EhDispatchOp>([](
auto) {
return false; })
1051 .Case<cir::SwitchFlatOp>([](
auto) {
return false; })
// These never transfer control out of the region.
1054 .Case<cir::UnreachableOp, cir::TrapOp>([](
auto) {
return true; })
1056 .Case<cir::IndirectBrOp>([](
auto) {
return false; })
// Plain branches are fine as long as their targets stay in-region.
1059 .Case<cir::BrOp>([&](cir::BrOp brOp) {
1060 assert(brOp.getDest()->getParent() == &region &&
1061 "branch destination is not in the region");
1064 .Case<cir::BrCondOp>([&](cir::BrCondOp brCondOp) {
1065 assert(brCondOp.getDestTrue()->getParent() == &region &&
1066 "branch destination is not in the region");
1067 assert(brCondOp.getDestFalse()->getParent() == &region &&
1068 "branch destination is not in the region");
1072 .
Default([](mlir::Operation *) ->
bool {
1073 llvm_unreachable(
"unexpected terminator in cleanup region");
1075 if (!expectedTerminator)
// Builds the exception-path copy of the cleanup code: clones the cleanup
// region before `insertBefore`, gives the cloned entry block an EH-token
// block argument, brackets the clone with cir.begin_cleanup /
// cir.end_cleanup, and turns its final yield into a cir.resume that
// forwards the token. Presumably returns the cloned entry block — the
// return statement is among the elided lines; confirm against the
// original. NOTE(review): also elided: the `loc` parameter line, the `:`
// arm of the clonedEntry conditional, the `if (yieldOp)`/`else`
// structure, and closing braces.
1103 mlir::Block *buildEHCleanupBlocks(cir::CleanupScopeOp cleanupOp,
1105 mlir::Block *insertBefore,
1106 mlir::PatternRewriter &rewriter)
const {
1107 assert(regionExitsOnlyFromLastBlock(cleanupOp.getCleanupRegion()) &&
1108 "cleanup region has exits in non-final blocks");
// Remember the block preceding the insertion point so the first cloned
// block can be located after cloning.
1112 mlir::Block *blockBeforeClone =
insertBefore->getPrevNode();
1115 rewriter.cloneRegionBefore(cleanupOp.getCleanupRegion(), insertBefore);
1118 mlir::Block *clonedEntry = blockBeforeClone
1119 ? blockBeforeClone->getNextNode()
// The EH token flows into the cloned cleanup as a block argument.
1124 auto ehTokenType = cir::EhTokenType::get(rewriter.getContext());
1125 mlir::Value ehToken = clonedEntry->addArgument(ehTokenType, loc);
1127 rewriter.setInsertionPointToStart(clonedEntry);
1128 auto beginCleanup = cir::BeginCleanupOp::create(rewriter, loc, ehToken);
// The clone's last block should end in the cleanup region's yield; close
// the cleanup and resume unwinding with the token.
1132 mlir::Block *lastClonedBlock =
insertBefore->getPrevNode();
1134 mlir::dyn_cast<cir::YieldOp>(lastClonedBlock->getTerminator());
1136 rewriter.setInsertionPoint(yieldOp);
1137 cir::EndCleanupOp::create(rewriter, loc, beginCleanup.getCleanupToken());
1138 rewriter.replaceOpWithNewOp<cir::ResumeOp>(yieldOp, ehToken);
1140 cleanupOp->emitError(
"Not yet implemented: cleanup region terminated "
1141 "with non-yield operation");
1170 flattenCleanup(cir::CleanupScopeOp cleanupOp,
1171 llvm::SmallVectorImpl<CleanupExit> &exits,
1172 llvm::SmallVectorImpl<cir::CallOp> &callsToRewrite,
1173 llvm::SmallVectorImpl<cir::ResumeOp> &resumeOpsToChain,
1174 mlir::PatternRewriter &rewriter)
const {
1175 mlir::Location loc = cleanupOp.getLoc();
1176 cir::CleanupKind cleanupKind = cleanupOp.getCleanupKind();
1177 bool hasNormalCleanup = cleanupKind == cir::CleanupKind::Normal ||
1178 cleanupKind == cir::CleanupKind::All;
1179 bool hasEHCleanup = cleanupKind == cir::CleanupKind::EH ||
1180 cleanupKind == cir::CleanupKind::All;
1181 bool isMultiExit = exits.size() > 1;
1184 mlir::Block *bodyEntry = &cleanupOp.getBodyRegion().front();
1185 mlir::Block *cleanupEntry = &cleanupOp.getCleanupRegion().front();
1186 mlir::Block *cleanupExit = &cleanupOp.getCleanupRegion().back();
1187 assert(regionExitsOnlyFromLastBlock(cleanupOp.getCleanupRegion()) &&
1188 "cleanup region has exits in non-final blocks");
1189 auto cleanupYield = dyn_cast<cir::YieldOp>(cleanupExit->getTerminator());
1190 if (!cleanupYield) {
1191 return rewriter.notifyMatchFailure(cleanupOp,
1192 "Not yet implemented: cleanup region "
1193 "terminated with non-yield operation");
1200 cir::AllocaOp destSlot;
1201 if (isMultiExit && hasNormalCleanup) {
1202 auto funcOp = cleanupOp->getParentOfType<cir::FuncOp>();
1204 return cleanupOp->emitError(
"cleanup scope not inside a function");
1205 destSlot = getOrCreateCleanupDestSlot(funcOp, rewriter, loc);
1209 mlir::Block *currentBlock = rewriter.getInsertionBlock();
1210 mlir::Block *continueBlock =
1211 rewriter.splitBlock(currentBlock, rewriter.getInsertionPoint());
1221 mlir::Block *unwindBlock =
nullptr;
1222 mlir::Block *ehCleanupEntry =
nullptr;
1224 (!callsToRewrite.empty() || !resumeOpsToChain.empty())) {
1226 buildEHCleanupBlocks(cleanupOp, loc, continueBlock, rewriter);
1230 if (!callsToRewrite.empty())
1231 unwindBlock = buildUnwindBlock(ehCleanupEntry,
true,
1232 loc, ehCleanupEntry, rewriter);
1239 mlir::Block *normalInsertPt =
1240 unwindBlock ? unwindBlock
1241 : (ehCleanupEntry ? ehCleanupEntry : continueBlock);
1244 rewriter.inlineRegionBefore(cleanupOp.getBodyRegion(), normalInsertPt);
1247 if (hasNormalCleanup)
1248 rewriter.inlineRegionBefore(cleanupOp.getCleanupRegion(), normalInsertPt);
1251 rewriter.setInsertionPointToEnd(currentBlock);
1252 cir::BrOp::create(rewriter, loc, bodyEntry);
1255 mlir::LogicalResult result = mlir::success();
1256 if (hasNormalCleanup) {
1258 mlir::Block *exitBlock = rewriter.createBlock(normalInsertPt);
1261 rewriter.setInsertionPoint(cleanupYield);
1262 rewriter.replaceOpWithNewOp<cir::BrOp>(cleanupYield, exitBlock);
1266 rewriter.setInsertionPointToEnd(exitBlock);
1269 auto slotValue = cir::LoadOp::create(
1270 rewriter, loc, destSlot,
false,
1271 false, mlir::IntegerAttr(),
1272 cir::SyncScopeKindAttr(), cir::MemOrderAttr());
1275 llvm::SmallVector<mlir::APInt, 8> caseValues;
1276 llvm::SmallVector<mlir::Block *, 8> caseDestinations;
1277 llvm::SmallVector<mlir::ValueRange, 8> caseOperands;
1278 cir::IntType s32Type =
1279 cir::IntType::get(rewriter.getContext(), 32,
true);
1281 for (
const CleanupExit &exit : exits) {
1283 mlir::Block *destBlock = rewriter.createBlock(normalInsertPt);
1284 rewriter.setInsertionPointToEnd(destBlock);
1286 createExitTerminator(exit.exitOp, loc, continueBlock, rewriter);
1289 caseValues.push_back(
1290 llvm::APInt(32,
static_cast<uint64_t>(exit.destinationId),
true));
1291 caseDestinations.push_back(destBlock);
1292 caseOperands.push_back(mlir::ValueRange());
1296 rewriter.setInsertionPoint(exit.exitOp);
1297 auto destIdConst = cir::ConstantOp::create(
1298 rewriter, loc, cir::IntAttr::get(s32Type, exit.destinationId));
1299 cir::StoreOp::create(rewriter, loc, destIdConst, destSlot,
1301 mlir::IntegerAttr(),
1302 cir::SyncScopeKindAttr(), cir::MemOrderAttr());
1303 rewriter.replaceOpWithNewOp<cir::BrOp>(exit.exitOp, cleanupEntry);
1311 if (result.failed())
1316 mlir::Block *defaultBlock = rewriter.createBlock(normalInsertPt);
1317 rewriter.setInsertionPointToEnd(defaultBlock);
1318 cir::UnreachableOp::create(rewriter, loc);
1321 rewriter.setInsertionPointToEnd(exitBlock);
1322 cir::SwitchFlatOp::create(rewriter, loc, slotValue, defaultBlock,
1323 mlir::ValueRange(), caseValues,
1324 caseDestinations, caseOperands);
1328 rewriter.setInsertionPointToEnd(exitBlock);
1329 mlir::Operation *exitOp = exits[0].exitOp;
1330 result = createExitTerminator(exitOp, loc, continueBlock, rewriter);
1333 rewriter.setInsertionPoint(exitOp);
1334 rewriter.replaceOpWithNewOp<cir::BrOp>(exitOp, cleanupEntry);
1339 for (CleanupExit &exit : exits) {
1340 if (isa<cir::YieldOp>(exit.exitOp)) {
1341 rewriter.setInsertionPoint(exit.exitOp);
1342 rewriter.replaceOpWithNewOp<cir::BrOp>(exit.exitOp, continueBlock);
1352 for (cir::CallOp callOp : callsToRewrite)
1353 replaceCallWithTryCall(callOp, unwindBlock, loc, rewriter);
1361 if (ehCleanupEntry) {
1362 llvm::SmallVector<cir::CallOp> ehCleanupThrowingCalls;
1363 for (mlir::Block *block = ehCleanupEntry; block != continueBlock;
1364 block = block->getNextNode()) {
1365 block->walk([&](cir::CallOp callOp) {
1366 if (!callOp.getNothrow())
1367 ehCleanupThrowingCalls.push_back(callOp);
1370 if (!ehCleanupThrowingCalls.empty()) {
1371 mlir::Block *terminateBlock =
1372 buildTerminateUnwindBlock(loc, continueBlock, rewriter);
1373 for (cir::CallOp callOp : ehCleanupThrowingCalls)
1374 replaceCallWithTryCall(callOp, terminateBlock, loc, rewriter);
1383 if (ehCleanupEntry) {
1384 for (cir::ResumeOp resumeOp : resumeOpsToChain) {
1385 mlir::Value ehToken = resumeOp.getEhToken();
1386 rewriter.setInsertionPoint(resumeOp);
1387 rewriter.replaceOpWithNewOp<cir::BrOp>(
1388 resumeOp, mlir::ValueRange{ehToken}, ehCleanupEntry);
1393 rewriter.eraseOp(cleanupOp);
1401 return mlir::success();
// NOTE(review): extraction artifact — the leading integers below are original
// source-file line numbers fused into the text; jumps in them mark elided
// lines. Code tokens are kept byte-identical to the extract.
//
// Rewrite entry point that flattens a cir.cleanup_scope into explicit CFG.
// Visible steps: erase trivially dead nested cleanup scopes, refuse to match
// while nested CleanupScopeOp/TryOp remain (so inner scopes flatten first),
// then collect the scope's exits, throwing calls and resume ops, and hand the
// actual CFG surgery off to flattenCleanup().
 1405   matchAndRewrite(cir::CleanupScopeOp cleanupOp,
 1406                   mlir::PatternRewriter &rewriter)
const override {
// Restore the caller's insertion point when this pattern returns.
 1407     mlir::OpBuilder::InsertionGuard guard(rewriter);
// Erase nested cleanup scopes that are trivially dead so they cannot block
// the "no nested ops" bail-out below.
 1421     llvm::SmallVector<cir::CleanupScopeOp> deadNestedOps;
 1422     cleanupOp.getBodyRegion().walk([&](cir::CleanupScopeOp nested) {
 1423       if (mlir::isOpTriviallyDead(nested))
 1424         deadNestedOps.push_back(nested);
 1426     for (
auto op : deadNestedOps)
 1427       rewriter.eraseOp(op);
// Walk the body looking for remaining nested scopes/tries; the walk
// interrupts on the first hit.
 1429     bool hasNestedOps = cleanupOp.getBodyRegion()
 1430                             .walk([&](mlir::Operation *op) {
 1431                               if (isa<cir::CleanupScopeOp, cir::TryOp>(op))
 1432                                 return mlir::WalkResult::interrupt();
 1433                               return mlir::WalkResult::advance();
// NOTE(review): the guard condition for this failure return is elided in the
// extract — presumably `if (hasNestedOps)`, deferring to inner scopes first;
// confirm against the full source.
 1437       return mlir::failure();
 1439     cir::CleanupKind cleanupKind = cleanupOp.getCleanupKind();
// Gather every distinct exit out of the cleanup body (ids assigned via
// nextId, whose declaration is elided here).
 1442     llvm::SmallVector<CleanupExit> exits;
 1444     collectExits(cleanupOp.getBodyRegion(), exits, nextId);
 1446     assert(!exits.empty() &&
"cleanup scope body has no exit");
// Non-Normal (EH-relevant) cleanups also need the body's throwing calls and
// resume ops so flattenCleanup can wire up the unwind path.
 1451     llvm::SmallVector<cir::CallOp> callsToRewrite;
 1452     if (cleanupKind != cir::CleanupKind::Normal)
 1453       collectThrowingCalls(cleanupOp.getBodyRegion(), callsToRewrite);
 1457     llvm::SmallVector<cir::ResumeOp> resumeOpsToChain;
 1458     if (cleanupKind != cir::CleanupKind::Normal)
 1459       collectResumeOps(cleanupOp.getBodyRegion(), resumeOpsToChain);
// Delegate the rewrite (trailing arguments elided in this extract).
 1461     return flattenCleanup(cleanupOp, exits, callsToRewrite, resumeOpsToChain,
// Trace an EH token value back to the cir.eh_initiate operation that produced
// it. Visible logic: if the token is directly defined by an EhInitiateOp,
// use it; otherwise, when the token is a block argument whose block has a
// single predecessor terminated by a cir.br, follow the corresponding branch
// operand backwards. NOTE(review): the enclosing loop, early returns, and the
// final return are elided in this extract (embedded numbers jump 1471→1473→
// 1476→1479→1482) — the branch-following step presumably repeats until an
// EhInitiateOp or a dead end is found; confirm against the full source.
1469static cir::EhInitiateOp traceToEhInitiate(mlir::Value ehToken) {
 1471   if (
auto initiate = ehToken.getDefiningOp<cir::EhInitiateOp>())
 1473   auto blockArg = mlir::dyn_cast<mlir::BlockArgument>(ehToken);
 1476   mlir::Block *pred = blockArg.getOwner()->getSinglePredecessor();
 1479   auto brOp = mlir::dyn_cast<cir::BrOp>(pred->getTerminator());
// Map the block argument to the branch operand feeding it in the predecessor.
 1482   ehToken = brOp.getDestOperands()[blockArg.getArgNumber()];
1487class CIRTryOpFlattening :
public mlir::OpRewritePattern<cir::TryOp> {
1489 using OpRewritePattern<cir::TryOp>::OpRewritePattern;
// Build the dispatch block for a flattened cir.try: creates a new block with
// a single EH-token block argument and terminates it with a cir.eh_dispatch
// that routes to the catch handler blocks. Exactly one catch_all OR unwind
// handler becomes the default destination (asserted); all remaining typed
// handlers become (type attr, dest block) dispatch cases.
// Returns the newly created dispatch block.
// NOTE(review): extraction artifact — embedded integers are original line
// numbers; small gaps (e.g. 1517→1523) mark elided lines such as the
// else-branch pushing typed handlers; code tokens kept byte-identical.
 1494   mlir::Block *buildCatchDispatchBlock(
 1495       cir::TryOp tryOp, mlir::ArrayAttr handlerTypes,
 1496       llvm::SmallVectorImpl<mlir::Block *> &catchHandlerBlocks,
 1497       mlir::Location loc, mlir::Block *insertBefore,
 1498       mlir::PatternRewriter &rewriter)
const {
// The dispatch block receives the in-flight exception as an EH-token arg.
 1499     mlir::Block *dispatchBlock = rewriter.createBlock(insertBefore);
 1500     auto ehTokenType = cir::EhTokenType::get(rewriter.getContext());
 1501     mlir::Value ehToken = dispatchBlock->addArgument(ehTokenType, loc);
 1503     rewriter.setInsertionPointToEnd(dispatchBlock);
// Partition handlers: typed catches go into the case lists, catch_all/unwind
// becomes the (unique) default destination.
 1506     llvm::SmallVector<mlir::Attribute> catchTypeAttrs;
 1507     llvm::SmallVector<mlir::Block *> catchDests;
 1508     mlir::Block *defaultDest =
nullptr;
 1509     bool defaultIsCatchAll =
false;
 1511     for (
auto [typeAttr, handlerBlock] :
 1512          llvm::zip(handlerTypes, catchHandlerBlocks)) {
 1513       if (mlir::isa<cir::CatchAllAttr>(typeAttr)) {
 1514         assert(!defaultDest &&
"multiple catch_all or unwind handlers");
 1515         defaultDest = handlerBlock;
 1516         defaultIsCatchAll =
true;
 1517       }
else if (mlir::isa<cir::UnwindAttr>(typeAttr)) {
 1518         assert(!defaultDest &&
"multiple catch_all or unwind handlers");
 1519         defaultDest = handlerBlock;
 1520         defaultIsCatchAll =
false;
// (else-branch header elided in extract) — typed handlers are recorded as
// dispatch cases here.
 1523         catchTypeAttrs.push_back(typeAttr);
 1524         catchDests.push_back(handlerBlock);
 1528     assert(defaultDest &&
"dispatch must have a catch_all or unwind handler");
// Leave the attr null when there are no typed cases.
 1530     mlir::ArrayAttr catchTypesArrayAttr;
 1531     if (!catchTypeAttrs.empty())
 1532       catchTypesArrayAttr = rewriter.getArrayAttr(catchTypeAttrs);
 1534     cir::EhDispatchOp::create(rewriter, loc, ehToken, catchTypesArrayAttr,
 1535                               defaultIsCatchAll, defaultDest, catchDests);
 1537     return dispatchBlock;
// Inline one catch-handler region before `insertBefore` and rewrite each
// cir.yield terminator in the inlined blocks into a branch to
// `continueBlock`. The embedded assert (lambda elided at its start in this
// extract) checks that an cir.end_catch immediately precedes the yield —
// possibly through a chain of single-predecessor cir.br blocks.
// Returns the handler's entry block.
// NOTE(review): extraction artifact — leading integers are original line
// numbers; gaps (1555→1557, 1566→1568, 1586→1588) mark elided lines,
// including one parameter (likely the mlir::Location) and parts of the
// predecessor-walking loop. Code tokens kept byte-identical.
 1554   mlir::Block *flattenCatchHandler(mlir::Region &handlerRegion,
 1555                                    mlir::Block *continueBlock,
 1557                                    mlir::Block *insertBefore,
 1558                                    mlir::PatternRewriter &rewriter)
const {
 1560     mlir::Block *handlerEntry = &handlerRegion.front();
 1563     rewriter.inlineRegionBefore(handlerRegion, insertBefore);
// Visit the inlined blocks (range end elided in extract) looking for yields.
 1566     for (mlir::Block &block : llvm::make_range(handlerEntry->getIterator(),
 1568       if (
auto yieldOp = dyn_cast<cir::YieldOp>(block.getTerminator())) {
// Fast path: end_catch directly precedes the yield in the same block.
 1577           if (mlir::Operation *prev = yieldOp->getPrevNode())
 1578             return isa<cir::EndCatchOp>(prev);
// Otherwise walk back through single-predecessor branch-only blocks.
 1581           mlir::Block *
b = block.getSinglePredecessor();
 1583             mlir::Operation *term =
b->getTerminator();
 1584             if (mlir::Operation *prev = term->getPrevNode())
 1585               return isa<cir::EndCatchOp>(prev);
 1586             if (!isa<cir::BrOp>(term))
 1588             b =
b->getSinglePredecessor();
 1591         }() &&
"expected end_catch as last operation before yield "
                  "in catch handler, with only branches in between");
// Replace the yield with an explicit branch to the continuation block.
 1593         rewriter.setInsertionPoint(yieldOp);
 1594         rewriter.replaceOpWithNewOp<cir::BrOp>(yieldOp, continueBlock);
 1598     return handlerEntry;
// Inline the unwind-handler region before `insertBefore` and return its
// entry block. NOTE(review): one parameter line (1608, likely the
// mlir::Location — the call site passes `loc`) and the trailing
// `return unwindEntry;` are elided in this extract; code tokens are kept
// byte-identical. Unlike flattenCatchHandler, no yield rewriting is visible
// here.
 1607   mlir::Block *flattenUnwindHandler(mlir::Region &unwindRegion,
 1609                                     mlir::Block *insertBefore,
 1610                                     mlir::PatternRewriter &rewriter)
const {
 1611     mlir::Block *unwindEntry = &unwindRegion.front();
 1612     rewriter.inlineRegionBefore(unwindRegion, insertBefore);
// CIRTryOpFlattening::matchAndRewrite — flattens a cir.try into explicit CFG:
// inline the try body, rewrite its yield into a branch to the continuation,
// inline catch/unwind handlers, build an eh_dispatch block, convert throwing
// calls into try-calls targeting a new unwind block, and chain cir.resume ops
// into the dispatch block.
// NOTE(review): extraction artifact — leading integers are original source
// line numbers; gaps mark elided lines (e.g. the `walk` receiver at 1619-1623
// and the guard before the failure return at 1628-1630). Code tokens are kept
// byte-identical.
 1617   matchAndRewrite(cir::TryOp tryOp,
 1618                   mlir::PatternRewriter &rewriter)
const override {
// (walk receiver elided) — scan for nested CleanupScopeOp/TryOp other than
// this op; presence presumably defers this pattern (failure below) so inner
// constructs flatten first. TODO confirm against full source.
 1624         ->walk([&](mlir::Operation *op) {
 1625           if (isa<cir::CleanupScopeOp, cir::TryOp>(op) && op != tryOp)
 1626             return mlir::WalkResult::interrupt();
 1627           return mlir::WalkResult::advance();
 1631       return mlir::failure();
 1633     mlir::OpBuilder::InsertionGuard guard(rewriter);
 1634     mlir::Location loc = tryOp.getLoc();
 1636     mlir::ArrayAttr handlerTypes = tryOp.getHandlerTypesAttr();
 1637     mlir::MutableArrayRef<mlir::Region> handlerRegions =
 1638         tryOp.getHandlerRegions();
// Collect calls that may throw and resume ops inside the try body.
 1641     llvm::SmallVector<cir::CallOp> callsToRewrite;
 1642     collectThrowingCalls(tryOp.getTryRegion(), callsToRewrite);
 1645     llvm::SmallVector<cir::ResumeOp> resumeOpsToChain;
 1646     collectResumeOps(tryOp.getTryRegion(), resumeOpsToChain);
// Split the current block; everything after the try becomes continueBlock.
 1649     mlir::Block *currentBlock = rewriter.getInsertionBlock();
 1650     mlir::Block *continueBlock =
 1651         rewriter.splitBlock(currentBlock, rewriter.getInsertionPoint());
// Inline the try body between currentBlock and continueBlock, then branch
// into its entry.
 1654     mlir::Block *bodyEntry = &tryOp.getTryRegion().front();
 1655     mlir::Block *bodyExit = &tryOp.getTryRegion().back();
 1658     rewriter.inlineRegionBefore(tryOp.getTryRegion(), continueBlock);
 1661     rewriter.setInsertionPointToEnd(currentBlock);
 1662     cir::BrOp::create(rewriter, loc, bodyEntry);
// The body's terminating yield (if any) becomes a branch to the continuation.
 1665     if (
auto bodyYield = dyn_cast<cir::YieldOp>(bodyExit->getTerminator())) {
 1666       rewriter.setInsertionPoint(bodyYield);
 1667       rewriter.replaceOpWithNewOp<cir::BrOp>(bodyYield, continueBlock);
// No handlers: the try reduces to its inlined body.
 1671     if (!handlerTypes || handlerTypes.empty()) {
 1672       rewriter.eraseOp(tryOp);
 1673       return mlir::success();
// Nothing can reach the handlers: drop them too.
 1680     if (callsToRewrite.empty() && resumeOpsToChain.empty()) {
 1681       rewriter.eraseOp(tryOp);
 1682       return mlir::success();
// Inline each handler region; unwind handlers and catch handlers use
// different flattening helpers.
 1688     llvm::SmallVector<mlir::Block *> catchHandlerBlocks;
 1690     for (
const auto &[idx, typeAttr] : llvm::enumerate(handlerTypes)) {
 1691       mlir::Region &handlerRegion = handlerRegions[idx];
 1693       if (mlir::isa<cir::UnwindAttr>(typeAttr)) {
 1694         mlir::Block *unwindEntry =
 1695             flattenUnwindHandler(handlerRegion, loc, continueBlock, rewriter);
 1696         catchHandlerBlocks.push_back(unwindEntry);
// (else-branch header elided in extract)
 1698         mlir::Block *handlerEntry = flattenCatchHandler(
 1699             handlerRegion, continueBlock, loc, continueBlock, rewriter);
 1700         catchHandlerBlocks.push_back(handlerEntry);
// Build the dispatch block routing the EH token to the right handler.
 1705     mlir::Block *dispatchBlock =
 1706         buildCatchDispatchBlock(tryOp, handlerTypes, catchHandlerBlocks, loc,
 1707                                 catchHandlerBlocks.front(), rewriter);
// (lhs of this condition elided; presumably `bool hasCatchAll =`.)
 1718         handlerTypes && llvm::any_of(handlerTypes, [](mlir::Attribute attr) {
 1719           return mlir::isa<cir::CatchAllAttr>(attr);
// A try with only a cleanup attr and no catch_all is "cleanup only".
 1728     bool isCleanupOnly = tryOp.getCleanup() && !hasCatchAll;
 1729     if (!callsToRewrite.empty()) {
 1731       mlir::Block *unwindBlock = buildUnwindBlock(dispatchBlock, isCleanupOnly,
 1732                                                   loc, dispatchBlock, rewriter);
 1734       for (cir::CallOp callOp : callsToRewrite)
 1735         replaceCallWithTryCall(callOp, unwindBlock, loc, rewriter);
// Chain inner resumes into this try's dispatch instead of resuming outward.
 1741     for (cir::ResumeOp resumeOp : resumeOpsToChain) {
// If the token traces to an eh_initiate, its cleanup attribute is dropped —
// this dispatch now handles the exception.
 1746       if (
auto ehInitiate = traceToEhInitiate(resumeOp.getEhToken()))
 1747         ehInitiate.removeCleanupAttr();
 1750       mlir::Value ehToken = resumeOp.getEhToken();
 1751       rewriter.setInsertionPoint(resumeOp);
 1752       rewriter.replaceOpWithNewOp<cir::BrOp>(
 1753           resumeOp, mlir::ValueRange{ehToken}, dispatchBlock);
 1757     rewriter.eraseOp(tryOp);
 1759     return mlir::success();
// Register all CFG-flattening rewrite patterns (if/loop/scope/switch/ternary/
// cleanup-scope/try) into the given pattern set.
// NOTE(review): line 1764 (the `patterns` receiver of this .add chain) and
// the closing brace are elided in this extract; code tokens kept
// byte-identical.
1763void populateFlattenCFGPatterns(RewritePatternSet &patterns) {
 1765       .add<CIRIfFlattening, CIRLoopOpInterfaceFlattening, CIRScopeOpFlattening,
 1766            CIRSwitchOpFlattening, CIRTernaryOpFlattening,
 1767            CIRCleanupScopeOpFlattening, CIRTryOpFlattening>(
 1768           patterns.getContext());
// Pass driver: collect (post-order) every structured-control-flow op the
// flattening patterns handle, then apply the patterns greedily to just those
// ops; any failure fails the pass.
// NOTE(review): the tail of the isa<> list, the walk-body push_back/return,
// and the walk's closing lines (1779-1783) are elided in this extract; code
// tokens kept byte-identical.
1771void CIRFlattenCFGPass::runOnOperation() {
 1772   RewritePatternSet patterns(&getContext());
 1773   populateFlattenCFGPatterns(patterns);
// Gather target ops up front rather than walking during rewriting.
 1776   llvm::SmallVector<Operation *, 16> ops;
 1777   getOperation()->walk<mlir::WalkOrder::PostOrder>([&](Operation *op) {
 1778     if (isa<IfOp, ScopeOp, SwitchOp, LoopOpInterface, TernaryOp, CleanupScopeOp,
 1784   if (applyOpPatternsGreedily(ops, std::move(patterns)).failed())
 1785     signalPassFailure();
1793 return std::make_unique<CIRFlattenCFGPass>();
const internal::VariadicAllOfMatcher< Attr > attr
std::unique_ptr< Pass > createCIRFlattenCFGPass()
int const char * function
float __ovld __cnfn step(float, float)
Returns 0.0 if x < edge, otherwise it returns 1.0.
static bool stackSaveOp()