28#include "mlir/IR/Builders.h"
33#include "llvm/ADT/DenseMap.h"
34#include "llvm/ADT/SmallVector.h"
35#include "llvm/TargetParser/Triple.h"
41#define GEN_PASS_DEF_CIREHABILOWERING
42#include "clang/CIR/Dialect/Passes.h.inc"
53static cir::FuncOp getOrCreateRuntimeFuncDecl(mlir::ModuleOp mod,
56 cir::FuncType funcTy) {
57 if (
auto existing = mod.lookupSymbol<cir::FuncOp>(
name))
60 mlir::OpBuilder builder(mod.getContext());
61 builder.setInsertionPointToEnd(mod.getBody());
62 auto funcOp = cir::FuncOp::create(builder, loc,
name, funcTy);
63 funcOp.setLinkage(cir::GlobalLinkageKind::ExternalLinkage);
76 explicit EHABILowering(mlir::ModuleOp mod)
77 : mod(mod), ctx(mod.getContext()), builder(ctx) {}
78 virtual ~EHABILowering() =
default;
81 virtual mlir::LogicalResult
run() = 0;
85 mlir::MLIRContext *ctx;
86 mlir::OpBuilder builder;
100class ItaniumEHLowering :
public EHABILowering {
102 using EHABILowering::EHABILowering;
103 mlir::LogicalResult
run()
override;
108 using EhTokenMap = DenseMap<mlir::Value, std::pair<mlir::Value, mlir::Value>>;
110 cir::VoidType voidType;
111 cir::PointerType voidPtrType;
112 cir::PointerType u8PtrType;
113 cir::IntType u32Type;
117 cir::FuncOp personalityFunc;
118 cir::FuncOp beginCatchFunc;
119 cir::FuncOp endCatchFunc;
120 cir::FuncOp clangCallTerminateFunc;
122 constexpr const static ::llvm::StringLiteral kGxxPersonality =
123 "__gxx_personality_v0";
125 void ensureRuntimeDecls(mlir::Location loc);
126 void ensureClangCallTerminate(mlir::Location loc);
127 mlir::LogicalResult lowerFunc(cir::FuncOp funcOp);
128 void lowerEhInitiate(cir::EhInitiateOp initiateOp, EhTokenMap &ehTokenMap,
129 SmallVectorImpl<mlir::Operation *> &deadOps);
130 void lowerDispatch(cir::EhDispatchOp dispatch, mlir::Value exnPtr,
132 SmallVectorImpl<mlir::Operation *> &deadOps);
136mlir::LogicalResult ItaniumEHLowering::run() {
139 voidType = cir::VoidType::get(ctx);
140 voidPtrType = cir::PointerType::get(voidType);
141 auto u8Type = cir::IntType::get(ctx, 8,
false);
142 u8PtrType = cir::PointerType::get(u8Type);
143 u32Type = cir::IntType::get(ctx, 32,
false);
145 for (cir::FuncOp funcOp : mod.getOps<cir::FuncOp>()) {
146 if (mlir::failed(lowerFunc(funcOp)))
147 return mlir::failure();
149 return mlir::success();
154void ItaniumEHLowering::ensureRuntimeDecls(mlir::Location loc) {
157 if (!personalityFunc) {
158 auto s32Type = cir::IntType::get(ctx, 32,
true);
159 auto personalityFuncTy = cir::FuncType::get({}, s32Type,
true);
160 personalityFunc = getOrCreateRuntimeFuncDecl(mod, loc, kGxxPersonality,
164 if (!beginCatchFunc) {
165 auto beginCatchFuncTy =
166 cir::FuncType::get({voidPtrType}, u8PtrType,
false);
167 beginCatchFunc = getOrCreateRuntimeFuncDecl(mod, loc,
"__cxa_begin_catch",
172 auto endCatchFuncTy = cir::FuncType::get({}, voidType,
false);
174 getOrCreateRuntimeFuncDecl(mod, loc,
"__cxa_end_catch", endCatchFuncTy);
187void ItaniumEHLowering::ensureClangCallTerminate(mlir::Location loc) {
188 if (clangCallTerminateFunc)
191 ensureRuntimeDecls(loc);
193 if (
auto existing = mod.lookupSymbol<cir::FuncOp>(
"__clang_call_terminate")) {
194 clangCallTerminateFunc = existing;
198 auto funcTy = cir::FuncType::get({voidPtrType}, voidType,
false);
199 builder.setInsertionPointToEnd(mod.getBody());
201 cir::FuncOp::create(builder, loc,
"__clang_call_terminate", funcTy);
202 funcOp.setLinkage(cir::GlobalLinkageKind::LinkOnceODRLinkage);
203 funcOp.setGlobalVisibility(cir::VisibilityKind::Hidden);
205 mlir::Block *entryBlock = funcOp.addEntryBlock();
206 builder.setInsertionPointToStart(entryBlock);
207 mlir::Value exnArg = entryBlock->getArgument(0);
209 auto catchCall = cir::CallOp::create(
210 builder, loc, mlir::FlatSymbolRefAttr::get(beginCatchFunc), u8PtrType,
211 mlir::ValueRange{exnArg});
212 catchCall.setNothrowAttr(builder.getUnitAttr());
214 auto terminateFuncDecl = getOrCreateRuntimeFuncDecl(
215 mod, loc,
"_ZSt9terminatev",
216 cir::FuncType::get({}, voidType,
false));
217 terminateFuncDecl->setAttr(cir::CIRDialect::getNoReturnAttrName(),
218 builder.getUnitAttr());
219 auto terminateCall = cir::CallOp::create(
220 builder, loc, mlir::FlatSymbolRefAttr::get(terminateFuncDecl), voidType,
222 terminateCall.setNothrowAttr(builder.getUnitAttr());
223 terminateCall->setAttr(cir::CIRDialect::getNoReturnAttrName(),
224 builder.getUnitAttr());
226 cir::UnreachableOp::create(builder, loc);
228 funcOp->setAttr(cir::CIRDialect::getNoReturnAttrName(),
229 builder.getUnitAttr());
230 clangCallTerminateFunc = funcOp;
234mlir::LogicalResult ItaniumEHLowering::lowerFunc(cir::FuncOp funcOp) {
235 if (funcOp.isDeclaration())
236 return mlir::success();
242 SmallVector<cir::EhInitiateOp> initiateOps;
243 funcOp.walk([&](cir::EhInitiateOp op) { initiateOps.push_back(op); });
244 if (initiateOps.empty())
245 return mlir::success();
247 ensureRuntimeDecls(funcOp.getLoc());
254 if (!funcOp.getPersonality())
255 funcOp.setPersonality(kGxxPersonality);
262 EhTokenMap ehTokenMap;
263 SmallVector<mlir::Operation *> deadOps;
264 for (cir::EhInitiateOp initiateOp : initiateOps)
265 lowerEhInitiate(initiateOp, ehTokenMap, deadOps);
269 for (mlir::Operation *op : deadOps)
274 for (mlir::Block &block : funcOp.getBody()) {
275 for (
int i = block.getNumArguments() - 1; i >= 0; --i) {
276 if (mlir::isa<cir::EhTokenType>(block.getArgument(i).getType()))
277 block.eraseArgument(i);
281 return mlir::success();
310void ItaniumEHLowering::lowerEhInitiate(
311 cir::EhInitiateOp initiateOp, EhTokenMap &ehTokenMap,
312 SmallVectorImpl<mlir::Operation *> &deadOps) {
313 mlir::Value rootToken = initiateOp.getEhToken();
317 builder.setInsertionPoint(initiateOp);
318 auto inflightOp = cir::EhInflightOp::create(
319 builder, initiateOp.getLoc(), initiateOp.getCleanup(),
323 ehTokenMap[rootToken] = {inflightOp.getExceptionPtr(),
324 inflightOp.getTypeId()};
330 SmallVector<mlir::Value> worklist;
331 SmallPtrSet<mlir::Value, 8> visited;
332 worklist.push_back(rootToken);
334 while (!worklist.empty()) {
335 mlir::Value current = worklist.pop_back_val();
336 if (!visited.insert(current).second)
341 SmallVector<mlir::Operation *> users;
342 for (mlir::OpOperand &use : current.getUses())
343 users.push_back(use.getOwner());
347 for (mlir::Operation *user : users) {
353 for (
unsigned s = 0;
s < user->getNumSuccessors(); ++
s) {
354 mlir::Block *succ = user->getSuccessor(
s);
355 for (mlir::BlockArgument arg : succ->getArguments()) {
356 if (!mlir::isa<cir::EhTokenType>(
arg.getType()))
358 if (!ehTokenMap.count(arg)) {
359 mlir::Value ptrArg = succ->addArgument(voidPtrType,
arg.getLoc());
360 mlir::Value u32Arg = succ->addArgument(u32Type,
arg.getLoc());
361 ehTokenMap[
arg] = {ptrArg, u32Arg};
363 worklist.push_back(arg);
367 if (
auto op = mlir::dyn_cast<cir::BeginCleanupOp>(user)) {
370 for (
auto &tokenUsers :
371 llvm::make_early_inc_range(op.getCleanupToken().getUses())) {
373 mlir::dyn_cast<cir::EndCleanupOp>(tokenUsers.getOwner()))
377 }
else if (
auto op = mlir::dyn_cast<cir::BeginCatchOp>(user)) {
380 for (
auto &tokenUsers :
381 llvm::make_early_inc_range(op.getCatchToken().getUses())) {
383 mlir::dyn_cast<cir::EndCatchOp>(tokenUsers.getOwner())) {
384 builder.setInsertionPoint(endOp);
385 cir::CallOp::create(builder, endOp.getLoc(),
386 mlir::FlatSymbolRefAttr::get(endCatchFunc),
387 voidType, mlir::ValueRange{});
392 auto [exnPtr, typeId] = ehTokenMap.lookup(op.getEhToken());
393 builder.setInsertionPoint(op);
394 auto callOp = cir::CallOp::create(
395 builder, op.getLoc(), mlir::FlatSymbolRefAttr::get(beginCatchFunc),
396 u8PtrType, mlir::ValueRange{exnPtr});
397 mlir::Value castResult = callOp.getResult();
398 mlir::Type expectedPtrType = op.getExnPtr().getType();
399 if (castResult.getType() != expectedPtrType)
401 cir::CastOp::create(builder, op.getLoc(), expectedPtrType,
402 cir::CastKind::bitcast, callOp.getResult());
403 op.getExnPtr().replaceAllUsesWith(castResult);
405 }
else if (
auto op = mlir::dyn_cast<cir::EhDispatchOp>(user)) {
407 mlir::ArrayAttr catchTypes = op.getCatchTypesAttr();
408 if (catchTypes && catchTypes.size() > 0) {
409 SmallVector<mlir::Attribute> typeSymbols;
410 for (mlir::Attribute attr : catchTypes)
411 typeSymbols.push_back(
412 mlir::cast<cir::GlobalViewAttr>(attr).getSymbol());
413 inflightOp.setCatchTypeListAttr(builder.getArrayAttr(typeSymbols));
415 if (op.getDefaultIsCatchAll())
416 inflightOp.setCatchAllAttr(builder.getUnitAttr());
420 if (!llvm::is_contained(deadOps, op.getOperation())) {
421 auto [exnPtr, typeId] = ehTokenMap.lookup(op.getEhToken());
422 lowerDispatch(op, exnPtr, typeId, deadOps);
424 }
else if (
auto op = mlir::dyn_cast<cir::EhTerminateOp>(user)) {
425 auto [exnPtr, typeId] = ehTokenMap.lookup(op.getEhToken());
426 ensureClangCallTerminate(op.getLoc());
427 builder.setInsertionPoint(op);
428 auto call = cir::CallOp::create(
429 builder, op.getLoc(),
430 mlir::FlatSymbolRefAttr::get(clangCallTerminateFunc), voidType,
431 mlir::ValueRange{exnPtr});
432 call.setNothrowAttr(builder.getUnitAttr());
433 call->setAttr(cir::CIRDialect::getNoReturnAttrName(),
434 builder.getUnitAttr());
435 cir::UnreachableOp::create(builder, op.getLoc());
437 }
else if (
auto op = mlir::dyn_cast<cir::ResumeOp>(user)) {
438 auto [exnPtr, typeId] = ehTokenMap.lookup(op.getEhToken());
439 builder.setInsertionPoint(op);
440 cir::ResumeFlatOp::create(builder, op.getLoc(), exnPtr, typeId);
442 }
else if (
auto op = mlir::dyn_cast<cir::BrOp>(user)) {
444 SmallVector<mlir::Value> newOperands;
445 bool changed =
false;
446 for (mlir::Value operand : op.getDestOperands()) {
447 auto it = ehTokenMap.find(operand);
448 if (it != ehTokenMap.end()) {
449 newOperands.push_back(it->second.first);
450 newOperands.push_back(it->second.second);
453 newOperands.push_back(operand);
457 builder.setInsertionPoint(op);
458 cir::BrOp::create(builder, op.getLoc(), op.getDest(), newOperands);
471void ItaniumEHLowering::lowerDispatch(
472 cir::EhDispatchOp dispatch, mlir::Value exnPtr, mlir::Value typeId,
473 SmallVectorImpl<mlir::Operation *> &deadOps) {
474 mlir::Location dispLoc = dispatch.getLoc();
475 mlir::Block *defaultDest = dispatch.getDefaultDestination();
476 mlir::ArrayAttr catchTypes = dispatch.getCatchTypesAttr();
477 mlir::SuccessorRange catchDests = dispatch.getCatchDestinations();
478 mlir::Block *dispatchBlock = dispatch->getBlock();
483 if (!catchTypes || catchTypes.empty()) {
485 builder.setInsertionPoint(dispatch);
486 cir::BrOp::create(builder, dispLoc, defaultDest,
487 mlir::ValueRange{exnPtr, typeId});
489 unsigned numCatches = catchTypes.size();
495 mlir::Block *
insertBefore = dispatchBlock->getNextNode();
496 mlir::Block *falseDest = defaultDest;
497 mlir::Block *firstCmpBlock =
nullptr;
498 for (
int i = numCatches - 1; i >= 0; --i) {
499 auto *cmpBlock = builder.createBlock(insertBefore, {voidPtrType, u32Type},
502 mlir::Value cmpExnPtr = cmpBlock->getArgument(0);
503 mlir::Value cmpTypeId = cmpBlock->getArgument(1);
505 auto globalView = mlir::cast<cir::GlobalViewAttr>(catchTypes[i]);
507 cir::EhTypeIdOp::create(builder, dispLoc, globalView.getSymbol());
508 auto cmpOp = cir::CmpOp::create(builder, dispLoc, cir::CmpOpKind::eq,
509 cmpTypeId, ehTypeIdOp.getTypeId());
511 cir::BrCondOp::create(builder, dispLoc, cmpOp, catchDests[i], falseDest,
512 mlir::ValueRange{cmpExnPtr, cmpTypeId},
513 mlir::ValueRange{cmpExnPtr, cmpTypeId});
516 falseDest = cmpBlock;
517 firstCmpBlock = cmpBlock;
521 builder.setInsertionPoint(dispatch);
522 cir::BrOp::create(builder, dispLoc, firstCmpBlock,
523 mlir::ValueRange{exnPtr, typeId});
529 deadOps.push_back(dispatch);
536struct CIREHABILoweringPass
537 :
public impl::CIREHABILoweringBase<CIREHABILoweringPass> {
538 CIREHABILoweringPass() =
default;
539 void runOnOperation()
override;
542void CIREHABILoweringPass::runOnOperation() {
543 auto mod = mlir::cast<mlir::ModuleOp>(getOperation());
548 auto tripleAttr = mlir::dyn_cast_if_present<mlir::StringAttr>(
549 mod->getAttr(cir::CIRDialect::getTripleAttrName()));
551 mod.emitError(
"Module has no target triple");
558 llvm::Triple triple(tripleAttr.getValue());
559 std::unique_ptr<EHABILowering> lowering;
560 if (triple.isWindowsMSVCEnvironment()) {
562 "EH ABI lowering is not yet implemented for the Microsoft ABI");
563 return signalPassFailure();
565 lowering = std::make_unique<ItaniumEHLowering>(mod);
568 if (mlir::failed(lowering->run()))
569 return signalPassFailure();
575 return std::make_unique<CIREHABILoweringPass>();
__device__ __2f16 float __ockl_bool s
std::unique_ptr< Pass > createCIREHABILoweringPass()
__DEVICE__ _Tp arg(const std::complex< _Tp > &__c)