10#include "mlir/IR/Attributes.h"
11#include "mlir/IR/BuiltinAttributeInterfaces.h"
12#include "mlir/IR/IRMapping.h"
13#include "mlir/IR/Location.h"
14#include "mlir/IR/Value.h"
32#include "llvm/ADT/StringRef.h"
33#include "llvm/ADT/TypeSwitch.h"
34#include "llvm/IR/Instructions.h"
35#include "llvm/Support/ErrorHandling.h"
36#include "llvm/Support/MemoryBuffer.h"
37#include "llvm/Support/Path.h"
38#include "llvm/Support/VirtualFileSystem.h"
47#define GEN_PASS_DEF_LOWERINGPREPARE
48#include "clang/CIR/Dialect/Passes.h.inc"
52 SmallString<128> fileName;
54 if (mlirModule.getSymName())
55 fileName = llvm::sys::path::filename(mlirModule.getSymName()->str());
60 for (
size_t i = 0; i < fileName.size(); ++i) {
71struct LoweringPreparePass
72 :
public impl::LoweringPrepareBase<LoweringPreparePass> {
73 LoweringPreparePass() =
default;
83 LoweringPreparePass(
const LoweringPreparePass &other)
84 : impl::LoweringPrepareBase<LoweringPreparePass>(other) {}
86 void runOnOperation()
override;
88 void runOnOp(mlir::Operation *op);
89 void lowerCastOp(cir::CastOp op);
90 void lowerComplexDivOp(cir::ComplexDivOp op);
91 void lowerComplexMulOp(cir::ComplexMulOp op);
92 void lowerUnaryOp(cir::UnaryOpInterface op);
93 void lowerGetGlobalOp(cir::GetGlobalOp op);
94 void lowerGlobalOp(cir::GlobalOp op);
95 void lowerThreeWayCmpOp(cir::CmpThreeWayOp op);
96 void lowerArrayDtor(cir::ArrayDtor op);
97 void lowerArrayCtor(cir::ArrayCtor op);
98 void lowerTrivialCopyCall(cir::CallOp op);
99 void lowerStoreOfConstAggregate(cir::StoreOp op);
100 void lowerLocalInitOp(cir::LocalInitOp op);
105 cir::FuncOp getCalledFunction(cir::CallOp callOp);
114 cir::GlobalOp getOrCreateConstAggregateGlobal(CIRBaseBuilderTy &builder,
116 llvm::StringRef baseName,
118 mlir::TypedAttr constant);
121 cir::FuncOp buildCXXGlobalVarDeclInitFunc(cir::GlobalOp op);
124 void defineGlobalThreadLocalWrapper(cir::GlobalOp op, cir::FuncOp initAlias,
125 bool isVarDefinition);
127 cir::FuncOp defineGlobalThreadLocalInitAlias(cir::GlobalOp op,
128 cir::FuncOp aliasee);
130 cir::FuncOp getOrCreateThreadLocalWrapper(CIRBaseBuilderTy &builder,
136 cir::IfOp buildGlobalTlsGuardCheck(CIRBaseBuilderTy &builder,
137 mlir::Location loc, cir::GlobalOp guard);
139 cir::FuncOp getOrCreateDtorFunc(CIRBaseBuilderTy &builder, cir::GlobalOp op,
140 mlir::Region &dtorRegion,
141 cir::CallOp &dtorCall);
144 void buildCXXGlobalInitFunc();
147 void buildCXXGlobalTlsFunc();
150 void buildGlobalCtorDtorList();
152 cir::FuncOp buildRuntimeFunction(
153 mlir::OpBuilder &builder, llvm::StringRef name, mlir::Location loc,
155 cir::GlobalLinkageKind linkage = cir::GlobalLinkageKind::ExternalLinkage);
157 cir::GlobalOp getOrCreateRuntimeVariable(
158 mlir::OpBuilder &builder, llvm::StringRef name, mlir::Location loc,
160 cir::GlobalLinkageKind linkage = cir::GlobalLinkageKind::ExternalLinkage,
161 cir::VisibilityKind visibility = cir::VisibilityKind::Default);
167 llvm::StringMap<FuncOp> cudaKernelMap;
171 void buildCUDAModuleCtor();
172 std::optional<FuncOp> buildCUDAModuleDtor();
173 std::optional<FuncOp> buildHIPModuleDtor();
174 std::optional<FuncOp> buildCUDARegisterGlobals();
175 void buildCUDARegisterGlobalFunctions(cir::CIRBaseBuilderTy &builder,
176 FuncOp regGlobalFunc);
179 void handleStaticLocal(cir::GlobalOp globalOp, cir::LocalInitOp localInitOp);
188 cir::FuncOp getTlsInitFn();
191 cir::GlobalOp createGlobalThreadLocalGuard(CIRBaseBuilderTy &builder,
195 cir::GlobalOp createGuardGlobalOp(CIRBaseBuilderTy &builder,
196 mlir::Location loc, llvm::StringRef name,
197 cir::IntType guardTy,
198 cir::GlobalLinkageKind linkage);
// Looks up the guard global recorded under `globalSymName` in
// staticLocalDeclGuardMap.
// NOTE(review): the return statements for the found / not-found cases are not
// visible in this excerpt.
201 cir::GlobalOp getStaticLocalDeclGuardAddress(llvm::StringRef globalSymName) {
202 auto it = staticLocalDeclGuardMap.find(globalSymName);
203 if (it != staticLocalDeclGuardMap.end())
// Records `guard` in staticLocalDeclGuardMap under the key `globalSymName`,
// overwriting any entry already stored for that name.
209 void setStaticLocalDeclGuardAddress(llvm::StringRef globalSymName,
210 cir::GlobalOp guard) {
211 staticLocalDeclGuardMap[globalSymName] = guard;
// Returns the guard global associated with `globalOp`, creating and caching it
// via createGuardGlobalOp / setStaticLocalDeclGuardAddress.
//   builder              - builder handed to createGuardGlobalOp.
//   globalOp             - global whose symbol name keys the guard cache and
//                          whose linkage, dso_local flag, TLS model and comdat
//                          state are consulted for the guard.
//   guardName            - symbol name given to a newly created guard.
//   isLocalVarDecl       - consulted only by the comdat decision below.
//   useInt8GuardVariable - selects an 8-bit guard type.
// NOTE(review): several lines (the guardAlignment assignments, the condition
// around guard creation, and the final return) are not visible in this excerpt.
215 cir::GlobalOp getOrCreateStaticLocalDeclGuardAddress(
216 CIRBaseBuilderTy &builder, cir::GlobalOp globalOp, StringRef guardName,
217 bool isLocalVarDecl,
bool useInt8GuardVariable) {
219 cir::CIRDataLayout dataLayout(mlirModule);
220 cir::IntType guardTy;
221 clang::CharUnits guardAlignment;
// Guard type selection: an 8-bit integer when requested explicitly.
224 if (useInt8GuardVariable) {
225 guardTy = cir::IntType::get(&getContext(), 8,
true);
227 }
// Under the ARM guard-variable ABI the width is taken from the target's signed
// size type.
else if (useARMGuardVarABI()) {
229 const unsigned sizeTypeSize =
230 astCtx->getTypeSize(astCtx->getSignedSizeType());
232 cir::IntType::get(&getContext(), sizeTypeSize,
true);
// A 64-bit guard type is assigned here; the branch leading to this line is not
// visible (presumably the generic fallback - confirm).
236 guardTy = cir::IntType::get(&getContext(), 64,
true);
// Both the type and a non-zero alignment must have been chosen by now.
240 assert(guardTy && guardAlignment.
getQuantity() != 0);
// Consult the cache keyed by the global's symbol name.
242 llvm::StringRef globalSymName = globalOp.getSymName();
243 cir::GlobalOp guard = getStaticLocalDeclGuardAddress(globalSymName);
// Create the guard with the global's linkage, zero-initialize it, and mirror
// the global's dso_local flag and TLS model; alignment comes from
// guardAlignment.
246 guard = createGuardGlobalOp(builder, globalOp->getLoc(), guardName,
247 guardTy, globalOp.getLinkage());
248 guard.setInitialValueAttr(cir::IntAttr::get(guardTy, 0));
249 guard.setDSOLocal(globalOp.getDsoLocal());
250 guard.setAlignment(guardAlignment.
getAsAlign().value());
251 guard.setTlsModel(globalOp.getTlsModel());
// Comdat: set when the global has a comdat and either (a) this is not a local
// variable declaration and the target object format is ELF or Wasm, or (b) the
// global is weak for the linker.
257 bool hasComdat = globalOp.getComdat();
258 const llvm::Triple &triple = astCtx->getTargetInfo().getTriple();
261 if (!isLocalVarDecl && hasComdat &&
262 (triple.isOSBinFormatELF() || triple.isOSBinFormatWasm())) {
264 guard.setComdat(
true);
265 }
else if (hasComdat && globalOp.isWeakForLinker()) {
266 guard.setComdat(
true);
// Remember the guard so later requests for the same global reuse it.
269 setStaticLocalDeclGuardAddress(globalSymName, guard);
278 clang::ASTContext *astCtx;
281 mlir::ModuleOp mlirModule;
301 mlir::SymbolTableCollection symbolTables;
304 llvm::StringMap<uint32_t> dynamicInitializerNames;
305 llvm::SmallVector<cir::FuncOp> dynamicInitializers;
306 llvm::SmallVector<cir::FuncOp> globalThreadLocalInitializers;
307 llvm::StringMap<cir::FuncOp> threadLocalWrappers;
308 llvm::StringMap<cir::FuncOp> threadLocalInitAliases;
311 llvm::StringMap<cir::GlobalOp> staticLocalDeclGuardMap;
313 llvm::StringMap<llvm::SmallVector<cir::GlobalOp, 1>> constAggregateGlobals;
316 llvm::SmallVector<std::pair<std::string, uint32_t>, 4> globalCtorList;
318 llvm::SmallVector<std::pair<std::string, uint32_t>, 4> globalDtorList;
// Reports whether the target C++ ABI uses the ARM flavour of guard variables,
// decided by switching on the ASTContext's C++ ABI kind.
// NOTE(review): only the case labels below are visible; the values returned for
// them and for the remaining ABI kinds are not shown in this excerpt.
322 bool useARMGuardVarABI()
const {
323 switch (astCtx->getCXXABIKind()) {
324 case clang::TargetCXXABI::GenericARM:
325 case clang::TargetCXXABI::iOS:
326 case clang::TargetCXXABI::WatchOS:
327 case clang::TargetCXXABI::GenericAArch64:
328 case clang::TargetCXXABI::WebAssembly:
// Emits the at-exit registration for the destructor described by `dtorRegion`
// and moves the region's operations into `entryBB`.
//   builder    - builder used for all created operations; on exit its
//                insertion point is the end of `entryBB`.
//   global     - global whose location is used for the created runtime symbols.
//   dtorRegion - region holding the destructor code.
//   tls        - thread-local flag (its use is on a line not visible here;
//                presumably it selects the thread-local at-exit function).
//   entryBB    - block receiving the spliced destructor operations.
// NOTE(review): declarations of voidPtrTy, fnAtExitType and args are on lines
// missing from this excerpt.
335 void emitGlobalGuardedDtorRegion(CIRBaseBuilderTy &builder,
336 cir::GlobalOp global,
337 mlir::Region &dtorRegion,
bool tls,
338 mlir::Block &entryBB) {
// Get or create the hidden, externally linked `__dso_handle` variable of i8
// type, inserting at the start of the module body if it must be created.
340 builder.setInsertionPointToStart(&mlirModule.getBodyRegion().front());
341 cir::GlobalOp handle = getOrCreateRuntimeVariable(
342 builder,
"__dso_handle", global.getLoc(), builder.getI8Type(),
343 cir::GlobalLinkageKind::ExternalLinkage, cir::VisibilityKind::Hidden);
// Resolve the function to register; dtorCall is filled in by the callee.
349 cir::CallOp dtorCall;
350 cir::FuncOp dtorFunc =
351 getOrCreateDtorFunc(builder, global, dtorRegion, dtorCall);
// Types for the registration call: (void (*)(void *), void *, handle *).
356 cir::PointerType voidFnPtrTy = builder.
getVoidFnPtrTy({voidPtrTy});
357 cir::PointerType handlePtrTy = builder.
getPointerTo(handle.getSymType());
359 builder.
getVoidFnTy({voidFnPtrTy, voidPtrTy, handlePtrTy});
// Default registration entry point; replaced below by `_tlv_atexit` on Darwin
// or `__cxa_thread_atexit` elsewhere (the guarding condition is not visible).
361 llvm::StringLiteral nameAtExit =
"__cxa_atexit";
363 nameAtExit = astCtx->getTargetInfo().getTriple().isOSDarwin()
364 ? llvm::StringLiteral(
"_tlv_atexit")
365 : llvm::StringLiteral(
"__cxa_thread_atexit");
367 cir::FuncOp fnAtExit = buildRuntimeFunction(builder, nameAtExit,
368 global.getLoc(), fnAtExitType);
// Build the call arguments right after the destructor call.
372 builder.setInsertionPointAfter(dtorCall);
// args[0]: address of the destructor function, bitcast to void (*)(void *).
374 auto dtorPtrTy = cir::PointerType::get(dtorFunc.getFunctionType());
375 args[0] = cir::GetGlobalOp::create(builder, dtorCall.getLoc(), dtorPtrTy,
376 dtorFunc.getSymName());
377 args[0] = cir::CastOp::create(builder, dtorCall.getLoc(), voidFnPtrTy,
378 cir::CastKind::bitcast, args[0]);
// The destructor call's first argument bitcast to void *; the assignment target
// is on a missing line (presumably args[1]).
380 cir::CastOp::create(builder, dtorCall.getLoc(), voidPtrTy,
381 cir::CastKind::bitcast, dtorCall.getArgOperand(0));
// args[2]: address of `__dso_handle`.
382 args[2] = cir::GetGlobalOp::create(builder, handle.getLoc(), handlePtrTy,
383 handle.getSymName());
384 builder.
createCallOp(dtorCall.getLoc(), fnAtExit, args);
// Move the region's operations to the end of entryBB, stopping before the
// block's last operation (the start iterator is on a missing line).
386 mlir::Block &dtorBlock = dtorRegion.front();
387 entryBB.getOperations().splice(entryBB.end(), dtorBlock.getOperations(),
389 std::prev(dtorBlock.end()));
392 builder.setInsertionPointToEnd(&entryBB);
// Emits the body of a guarded initialization for `globalOp`: the constructor
// region is inlined at the builder's insertion block and the destructor region
// (if any) is registered through emitGlobalGuardedDtorRegion.
//   ctorRegion / dtorRegion - regions to inline; each may be empty.
//   varDecl                 - AST declaration queried for isLocalVarDecl().
//   guardPtr / guardPtrTy   - guard address and its pointer type.
// NOTE(review): the final parameter line, the definitions of `threadsafe`,
// `acquireCall`, `acquireZero` and `ifOp`, and the guard-store / release code
// are on lines missing from this excerpt.
398 void emitCXXGuardedInitIf(CIRBaseBuilderTy &builder, cir::GlobalOp globalOp,
399 mlir::Region &ctorRegion, mlir::Region &dtorRegion,
400 cir::ASTVarDeclInterface varDecl,
401 mlir::Value guardPtr, cir::PointerType guardPtrTy,
403 auto loc = globalOp->getLoc();
// Inline the constructor region: every operation except the last one of its
// single block is moved to the end of the current insertion block.
423 mlir::Block *insertBlock = builder.getInsertionBlock();
424 if (!ctorRegion.empty()) {
425 assert(ctorRegion.hasOneBlock() &&
"Enforced by MaxSizedRegion<1>");
427 mlir::Block &block = ctorRegion.front();
428 insertBlock->getOperations().splice(
429 insertBlock->end(), block.getOperations(), block.begin(),
430 std::prev(block.end()));
// Register the destructor, if present.
// NOTE(review): the fourth argument fills the callee's `tls` parameter; passing
// `!threadsafe` there looks unrelated to thread-local storage - confirm.
433 if (!dtorRegion.empty()) {
434 assert(dtorRegion.hasOneBlock() &&
"Enforced by MaxSizedRegion<1>");
436 emitGlobalGuardedDtorRegion(builder, globalOp, dtorRegion, !threadsafe,
439 builder.setInsertionPointToEnd(insertBlock);
440 ctorRegion.getBlocks().clear();
// Initialization is needed when the acquire call's result compares not-equal
// to zero.
448 mlir::Value acquireResult = acquireCall.getResult();
451 loc, mlir::cast<cir::IntType>(acquireResult.getType()), 0);
452 auto shouldInit = builder.
createCompare(loc, cir::CmpOpKind::ne,
453 acquireResult, acquireZero);
// An if-op is created with an empty builder callback; its then-region is then
// populated under an insertion guard.
458 cir::IfOp::create(builder, loc, shouldInit,
false,
459 [](mlir::OpBuilder &, mlir::Location) {});
460 mlir::OpBuilder::InsertionGuard insertGuard(builder);
461 builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
475 mlir::ValueRange{guardPtr});
478 }
// Non-threadsafe initialization of non-local variables is reported as not yet
// implemented.
else if (!
varDecl.isLocalVarDecl()) {
484 globalOp->emitError(
"NYI: non-threadsafe init for non-local variables");
// Stores the given clang ASTContext pointer in the pass's astCtx member.
499 void setASTContext(clang::ASTContext *
c) { astCtx =
c; }
// Finds the global named `name` by symbol lookup from the module, or creates
// one of type `type` at the builder's insertion point.
// A created global gets the requested linkage, private symbol visibility, and
// the requested global visibility.
// NOTE(review): the condition guarding creation and the return statement are on
// lines missing from this excerpt (presumably `if (!g)` / `return g`).
504cir::GlobalOp LoweringPreparePass::getOrCreateRuntimeVariable(
505 mlir::OpBuilder &builder, llvm::StringRef name, mlir::Location loc,
506 mlir::Type type, cir::GlobalLinkageKind linkage,
507 cir::VisibilityKind visibility) {
// dyn_cast_or_null: a missing symbol or a symbol of another op kind yields a
// null GlobalOp.
508 cir::GlobalOp g = dyn_cast_or_null<cir::GlobalOp>(
509 mlir::SymbolTable::lookupNearestSymbolFrom(
510 mlirModule, mlir::StringAttr::get(mlirModule->getContext(), name)));
512 g = cir::GlobalOp::create(builder, loc, name, type);
514 cir::GlobalLinkageKindAttr::get(builder.getContext(), linkage));
515 mlir::SymbolTable::setSymbolVisibility(
516 g, mlir::SymbolTable::Visibility::Private);
517 g.setGlobalVisibility(visibility);
// Finds the function named `name` by symbol lookup from the module, or creates
// a FuncOp of type `type` at the builder's insertion point.
// A created function gets the requested linkage and private symbol visibility.
// NOTE(review): the condition guarding creation and the return statement are on
// lines missing from this excerpt.
522cir::FuncOp LoweringPreparePass::buildRuntimeFunction(
523 mlir::OpBuilder &builder, llvm::StringRef name, mlir::Location loc,
524 cir::FuncType type, cir::GlobalLinkageKind linkage) {
// dyn_cast_or_null: a missing symbol or a non-function symbol yields a null
// FuncOp.
525 cir::FuncOp f = dyn_cast_or_null<FuncOp>(SymbolTable::lookupNearestSymbolFrom(
526 mlirModule, StringAttr::get(mlirModule->getContext(), name)));
528 f = cir::FuncOp::create(builder, loc, name, type);
530 cir::GlobalLinkageKindAttr::get(builder.getContext(), linkage));
531 mlir::SymbolTable::setSymbolVisibility(
532 f, mlir::SymbolTable::Visibility::Private);
542 builder.setInsertionPoint(op);
544 mlir::Value src = op.getSrc();
545 mlir::Value imag = builder.
getNullValue(src.getType(), op.getLoc());
551 cir::CastKind elemToBoolKind) {
553 builder.setInsertionPoint(op);
555 mlir::Value src = op.getSrc();
556 if (!mlir::isa<cir::BoolType>(op.getType()))
563 cir::BoolType boolTy = builder.
getBoolTy();
564 mlir::Value srcRealToBool =
565 builder.
createCast(op.getLoc(), elemToBoolKind, srcReal, boolTy);
566 mlir::Value srcImagToBool =
567 builder.
createCast(op.getLoc(), elemToBoolKind, srcImag, boolTy);
568 return builder.
createLogicalOr(op.getLoc(), srcRealToBool, srcImagToBool);
573 cir::CastKind scalarCastKind) {
575 builder.setInsertionPoint(op);
577 mlir::Value src = op.getSrc();
578 auto dstComplexElemTy =
579 mlir::cast<cir::ComplexType>(op.getType()).getElementType();
584 mlir::Value dstReal = builder.
createCast(op.getLoc(), scalarCastKind, srcReal,
586 mlir::Value dstImag = builder.
createCast(op.getLoc(), scalarCastKind, srcImag,
// Lowers a cir.cast whose kind involves complex values: a replacement value is
// computed by a lambda that switches on the cast kind, and all uses of the cast
// are redirected to it.
// NOTE(review): the per-case bodies, the default case, and any erasure of `op`
// are on lines missing from this excerpt.
591void LoweringPreparePass::lowerCastOp(cir::CastOp op) {
592 mlir::MLIRContext &ctx = getContext();
593 mlir::Value loweredValue = [&]() -> mlir::Value {
594 switch (op.getKind()) {
// Scalar -> complex.
595 case cir::CastKind::float_to_complex:
596 case cir::CastKind::int_to_complex:
// Complex -> real component.
598 case cir::CastKind::float_complex_to_real:
599 case cir::CastKind::int_complex_to_real:
// Complex -> bool.
601 case cir::CastKind::float_complex_to_bool:
603 case cir::CastKind::int_complex_to_bool:
// Complex -> complex with a different element type.
605 case cir::CastKind::float_complex:
607 case cir::CastKind::float_complex_to_int_complex:
609 case cir::CastKind::int_complex:
611 case cir::CastKind::int_complex_to_float_complex:
619 op.replaceAllUsesWith(loweredValue);
626 llvm::StringRef (*libFuncNameGetter)(llvm::APFloat::Semantics),
627 mlir::Location loc, cir::ComplexType ty, mlir::Value lhsReal,
628 mlir::Value lhsImag, mlir::Value rhsReal, mlir::Value rhsImag) {
629 cir::FPTypeInterface elementTy =
630 mlir::cast<cir::FPTypeInterface>(ty.getElementType());
632 llvm::StringRef libFuncName = libFuncNameGetter(
633 llvm::APFloat::SemanticsToEnum(elementTy.getFloatSemantics()));
636 cir::FuncType libFuncTy = cir::FuncType::get(libFuncInputTypes, ty);
642 mlir::OpBuilder::InsertionGuard ipGuard{builder};
643 builder.setInsertionPointToStart(pass.mlirModule.getBody());
644 libFunc = pass.buildRuntimeFunction(builder, libFuncName, loc, libFuncTy);
648 builder.
createCallOp(loc, libFunc, {lhsReal, lhsImag, rhsReal, rhsImag});
649 return call.getResult();
652static llvm::StringRef
655 case llvm::APFloat::S_IEEEhalf:
657 case llvm::APFloat::S_IEEEsingle:
659 case llvm::APFloat::S_IEEEdouble:
661 case llvm::APFloat::S_PPCDoubleDouble:
663 case llvm::APFloat::S_x87DoubleExtended:
665 case llvm::APFloat::S_IEEEquad:
668 llvm_unreachable(
"unsupported floating point type");
674 mlir::Value lhsReal, mlir::Value lhsImag,
675 mlir::Value rhsReal, mlir::Value rhsImag) {
677 mlir::Value &a = lhsReal;
678 mlir::Value &
b = lhsImag;
679 mlir::Value &
c = rhsReal;
680 mlir::Value &d = rhsImag;
682 mlir::Value ac = builder.
createMul(loc, a,
c);
683 mlir::Value bd = builder.
createMul(loc,
b, d);
685 mlir::Value dd = builder.
createMul(loc, d, d);
686 mlir::Value acbd = builder.
createAdd(loc, ac, bd);
687 mlir::Value ccdd = builder.
createAdd(loc, cc, dd);
688 mlir::Value resultReal = builder.
createDiv(loc, acbd, ccdd);
691 mlir::Value ad = builder.
createMul(loc, a, d);
692 mlir::Value bcad = builder.
createSub(loc, bc, ad);
693 mlir::Value resultImag = builder.
createDiv(loc, bcad, ccdd);
699 mlir::Value lhsReal, mlir::Value lhsImag,
700 mlir::Value rhsReal, mlir::Value rhsImag) {
721 mlir::Value &a = lhsReal;
722 mlir::Value &
b = lhsImag;
723 mlir::Value &
c = rhsReal;
724 mlir::Value &d = rhsImag;
726 auto trueBranchBuilder = [&](mlir::OpBuilder &, mlir::Location) {
728 mlir::Value rd = builder.
createMul(loc, r, d);
729 mlir::Value tmp = builder.
createAdd(loc,
c, rd);
731 mlir::Value br = builder.
createMul(loc,
b, r);
732 mlir::Value abr = builder.
createAdd(loc, a, br);
733 mlir::Value e = builder.
createDiv(loc, abr, tmp);
735 mlir::Value ar = builder.
createMul(loc, a, r);
736 mlir::Value bar = builder.
createSub(loc,
b, ar);
737 mlir::Value f = builder.
createDiv(loc, bar, tmp);
743 auto falseBranchBuilder = [&](mlir::OpBuilder &, mlir::Location) {
745 mlir::Value rc = builder.
createMul(loc, r,
c);
746 mlir::Value tmp = builder.
createAdd(loc, d, rc);
748 mlir::Value ar = builder.
createMul(loc, a, r);
749 mlir::Value arb = builder.
createAdd(loc, ar,
b);
750 mlir::Value e = builder.
createDiv(loc, arb, tmp);
752 mlir::Value br = builder.
createMul(loc,
b, r);
753 mlir::Value bra = builder.
createSub(loc, br, a);
754 mlir::Value f = builder.
createDiv(loc, bra, tmp);
760 auto cFabs = cir::FAbsOp::create(builder, loc,
c);
761 auto dFabs = cir::FAbsOp::create(builder, loc, d);
762 cir::CmpOp cmpResult =
763 builder.
createCompare(loc, cir::CmpOpKind::ge, cFabs, dFabs);
764 auto ternary = cir::TernaryOp::create(builder, loc, cmpResult,
765 trueBranchBuilder, falseBranchBuilder);
767 return ternary.getResult();
774 auto getHigherPrecisionFPType = [&context](mlir::Type type) -> mlir::Type {
775 if (mlir::isa<cir::FP16Type>(type))
776 return cir::SingleType::get(&context);
778 if (mlir::isa<cir::SingleType>(type) || mlir::isa<cir::BF16Type>(type))
779 return cir::DoubleType::get(&context);
781 if (mlir::isa<cir::DoubleType>(type))
782 return cir::LongDoubleType::get(&context, type);
787 auto getFloatTypeSemantics =
788 [&cc](mlir::Type type) ->
const llvm::fltSemantics & {
790 if (mlir::isa<cir::FP16Type>(type))
793 if (mlir::isa<cir::BF16Type>(type))
796 if (mlir::isa<cir::SingleType>(type))
799 if (mlir::isa<cir::DoubleType>(type))
802 if (mlir::isa<cir::LongDoubleType>(type)) {
804 llvm_unreachable(
"NYI Float type semantics with OpenMP");
808 if (mlir::isa<cir::FP128Type>(type)) {
810 llvm_unreachable(
"NYI Float type semantics with OpenMP");
814 llvm_unreachable(
"Unsupported float type semantics");
817 const mlir::Type higherElementType = getHigherPrecisionFPType(elementType);
818 const llvm::fltSemantics &elementTypeSemantics =
819 getFloatTypeSemantics(elementType);
820 const llvm::fltSemantics &higherElementTypeSemantics =
821 getFloatTypeSemantics(higherElementType);
830 if (llvm::APFloat::semanticsMaxExponent(elementTypeSemantics) * 2 + 1 <=
831 llvm::APFloat::semanticsMaxExponent(higherElementTypeSemantics)) {
832 return higherElementType;
842 mlir::Location loc, cir::ComplexDivOp op, mlir::Value lhsReal,
843 mlir::Value lhsImag, mlir::Value rhsReal, mlir::Value rhsImag,
845 cir::ComplexType complexTy = op.getType();
846 if (mlir::isa<cir::FPTypeInterface>(complexTy.getElementType())) {
847 cir::ComplexRangeKind range = op.getRange();
848 if (range == cir::ComplexRangeKind::Improved)
852 if (range == cir::ComplexRangeKind::Full)
854 loc, complexTy, lhsReal, lhsImag, rhsReal,
857 if (range == cir::ComplexRangeKind::Promoted) {
858 mlir::Type originalElementType = complexTy.getElementType();
859 mlir::Type higherPrecisionElementType =
861 originalElementType);
863 if (!higherPrecisionElementType)
867 cir::CastKind floatingCastKind = cir::CastKind::floating;
868 lhsReal = builder.
createCast(floatingCastKind, lhsReal,
869 higherPrecisionElementType);
870 lhsImag = builder.
createCast(floatingCastKind, lhsImag,
871 higherPrecisionElementType);
872 rhsReal = builder.
createCast(floatingCastKind, rhsReal,
873 higherPrecisionElementType);
874 rhsImag = builder.
createCast(floatingCastKind, rhsImag,
875 higherPrecisionElementType);
878 builder, loc, lhsReal, lhsImag, rhsReal, rhsImag);
883 mlir::Value finalReal =
884 builder.
createCast(floatingCastKind, resultReal, originalElementType);
885 mlir::Value finalImag =
886 builder.
createCast(floatingCastKind, resultImag, originalElementType);
// Lowers a complex division op: new operations are inserted right after `op`
// and every use of `op` is redirected to the lowered result.
// NOTE(review): the extraction of the real/imaginary parts, the callee name of
// the lowering helper, and any erasure of `op` are on lines missing from this
// excerpt.
895void LoweringPreparePass::lowerComplexDivOp(cir::ComplexDivOp op) {
896 cir::CIRBaseBuilderTy builder(getContext());
897 builder.setInsertionPointAfter(op);
898 mlir::Location loc = op.getLoc();
899 mlir::TypedValue<cir::ComplexType> lhs = op.getLhs();
900 mlir::TypedValue<cir::ComplexType> rhs = op.getRhs();
// The helper additionally receives the MLIR context and the clang ASTContext.
906 mlir::Value loweredResult =
908 rhsImag, getContext(), *astCtx);
909 op.replaceAllUsesWith(loweredResult);
913static llvm::StringRef
916 case llvm::APFloat::S_IEEEhalf:
918 case llvm::APFloat::S_IEEEsingle:
920 case llvm::APFloat::S_IEEEdouble:
922 case llvm::APFloat::S_PPCDoubleDouble:
924 case llvm::APFloat::S_x87DoubleExtended:
926 case llvm::APFloat::S_IEEEquad:
929 llvm_unreachable(
"unsupported floating point type");
935 mlir::Location loc, cir::ComplexMulOp op,
936 mlir::Value lhsReal, mlir::Value lhsImag,
937 mlir::Value rhsReal, mlir::Value rhsImag) {
939 mlir::Value resultRealLhs = builder.
createMul(loc, lhsReal, rhsReal);
940 mlir::Value resultRealRhs = builder.
createMul(loc, lhsImag, rhsImag);
941 mlir::Value resultImagLhs = builder.
createMul(loc, lhsReal, rhsImag);
942 mlir::Value resultImagRhs = builder.
createMul(loc, lhsImag, rhsReal);
943 mlir::Value resultReal = builder.
createSub(loc, resultRealLhs, resultRealRhs);
944 mlir::Value resultImag = builder.
createAdd(loc, resultImagLhs, resultImagRhs);
945 mlir::Value algebraicResult =
948 cir::ComplexType complexTy = op.getType();
949 cir::ComplexRangeKind rangeKind = op.getRange();
950 if (mlir::isa<cir::IntType>(complexTy.getElementType()) ||
951 rangeKind == cir::ComplexRangeKind::Basic ||
952 rangeKind == cir::ComplexRangeKind::Improved ||
953 rangeKind == cir::ComplexRangeKind::Promoted)
954 return algebraicResult;
961 mlir::Value resultRealIsNaN = builder.
createIsNaN(loc, resultReal);
962 mlir::Value resultImagIsNaN = builder.
createIsNaN(loc, resultImag);
963 mlir::Value resultRealAndImagAreNaN =
966 return cir::TernaryOp::create(
967 builder, loc, resultRealAndImagAreNaN,
968 [&](mlir::OpBuilder &, mlir::Location) {
971 lhsReal, lhsImag, rhsReal, rhsImag);
974 [&](mlir::OpBuilder &, mlir::Location) {
// Lowers a complex multiplication op: new operations are inserted right after
// `op`, the result is computed by lowerComplexMul from the operands'
// real/imaginary parts, and every use of `op` is redirected to it.
// NOTE(review): the definitions of lhsReal/lhsImag/rhsReal/rhsImag and any
// erasure of `op` are on lines missing from this excerpt.
980void LoweringPreparePass::lowerComplexMulOp(cir::ComplexMulOp op) {
981 cir::CIRBaseBuilderTy builder(getContext());
982 builder.setInsertionPointAfter(op);
983 mlir::Location loc = op.getLoc();
984 mlir::TypedValue<cir::ComplexType> lhs = op.getLhs();
985 mlir::TypedValue<cir::ComplexType> rhs = op.getRhs();
990 mlir::Value loweredResult =
lowerComplexMul(*
this, builder, loc, op, lhsReal,
991 lhsImag, rhsReal, rhsImag);
992 op.replaceAllUsesWith(loweredResult);
// Lowers a unary op on a complex value into operations on its real and
// imaginary parts; ops whose result is not complex are filtered by the first
// check (its body is on a missing line, presumably an early return).
// NOTE(review): the definitions of operandReal/operandImag and `result`, and
// two of the `.Case<...>` headers, are on lines missing from this excerpt.
996void LoweringPreparePass::lowerUnaryOp(cir::UnaryOpInterface op) {
997 if (!mlir::isa<cir::ComplexType>(op.getResult().getType()))
1000 mlir::Location loc = op->getLoc();
1001 CIRBaseBuilderTy builder(getContext());
1002 builder.setInsertionPointAfter(op);
1004 mlir::Value operand = op.getInput();
// Start from the unchanged parts; each case overwrites only what it modifies.
1008 mlir::Value resultReal = operandReal;
1009 mlir::Value resultImag = operandImag;
1011 llvm::TypeSwitch<mlir::Operation *>(op)
// Increment touches only the real part.
1013 [&](
auto) { resultReal = builder.
createInc(loc, operandReal); })
// Decrement touches only the real part.
1015 [&](
auto) { resultReal = builder.
createDec(loc, operandReal); })
// Minus negates both parts.
1016 .Case<cir::MinusOp>([&](
auto) {
1017 resultReal = builder.
createMinus(loc, operandReal);
1018 resultImag = builder.
createMinus(loc, operandImag);
// This case negates only the imaginary part (its op type is on a missing line).
1021 [&](
auto) { resultImag = builder.
createMinus(loc, operandImag); })
1022 .
Default([](
auto) { llvm_unreachable(
"unhandled unary complex op"); });
1025 op->replaceAllUsesWith(mlir::ValueRange{result});
// Produces a function suitable for at-exit registration from the code in
// `dtorRegion`, and reports through `dtorCall` the call that invokes it.
// Two paths are visible:
//   * if the region's block is exactly get_global + single-operand call on that
//     value + yield, the already-called function is returned;
//   * otherwise the block's code is outlined into a new internal function named
//     "__cxx_global_array_dtor" (with a counter-based suffix) and the block is
//     rewritten to call it.
// NOTE(review): the `cir::GlobalOp op` parameter line, the assignment of
// `dtorCall` on the fast path, the condition on the name suffix, and the final
// return are on lines missing from this excerpt.
1029cir::FuncOp LoweringPreparePass::getOrCreateDtorFunc(CIRBaseBuilderTy &builder,
1031 mlir::Region &dtorRegion,
1032 cir::CallOp &dtorCall) {
// Restore the caller's insertion point on exit.
1033 mlir::OpBuilder::InsertionGuard guard(builder);
1036 cir::VoidType voidTy = builder.
getVoidTy();
1037 auto voidPtrTy = cir::PointerType::get(voidTy);
1040 mlir::Block &dtorBlock = dtorRegion.front();
// The first operation of the block must be a get_global (mlir::cast asserts).
1044 auto opIt = dtorBlock.getOperations().begin();
1045 cir::GetGlobalOp ggop = mlir::cast<cir::GetGlobalOp>(*opIt);
// Fast path: the block already is "get_global; call f(get_global); yield".
1056 if (dtorBlock.getOperations().size() == 3) {
1057 auto callOp = mlir::dyn_cast<cir::CallOp>(&*(++opIt));
1058 auto yieldOp = mlir::dyn_cast<cir::YieldOp>(&*(++opIt));
1059 if (yieldOp && callOp && callOp.getNumOperands() == 1 &&
1060 callOp.getArgOperand(0) == ggop) {
1062 return getCalledFunction(callOp);
// Slow path: create the helper after `op`; dynamicInitializerNames supplies a
// per-name counter used for the ".<n>" suffix.
1069 builder.setInsertionPointAfter(op);
1070 SmallString<256> fnName(
"__cxx_global_array_dtor");
1071 uint32_t cnt = dynamicInitializerNames[fnName]++;
1073 fnName +=
"." + std::to_string(cnt);
// Helper signature: void(void *), internal linkage.
1076 auto fnType = cir::FuncType::get({voidPtrTy}, voidTy);
1077 cir::FuncOp dtorFunc =
1078 buildRuntimeFunction(builder, fnName, op.getLoc(), fnType,
1079 cir::GlobalLinkageKind::InternalLinkage);
// Attach an `llvm.noundef` unit attribute to the single parameter.
1081 SmallVector<mlir::NamedAttribute> paramAttrs;
1082 paramAttrs.push_back(
1083 builder.getNamedAttr(
"llvm.noundef", builder.getUnitAttr()));
1084 SmallVector<mlir::Attribute> argAttrDicts;
1085 argAttrDicts.push_back(
1086 mlir::DictionaryAttr::get(builder.getContext(), paramAttrs));
1087 dtorFunc.setArgAttrsAttr(
1088 mlir::ArrayAttr::get(builder.getContext(), argAttrDicts));
1090 mlir::Block *entryBB = dtorFunc.addEntryBlock();
// Move every operation of the region's block into the helper's entry block.
1093 entryBB->getOperations().splice(entryBB->begin(), dtorBlock.getOperations(),
1094 dtorBlock.begin(), dtorBlock.end());
// Put a clone of the get_global back at the start of the now-empty block.
1097 cir::GetGlobalOp dtorGGop =
1098 mlir::cast<cir::GetGlobalOp>(entryBB->getOperations().front());
1099 builder.setInsertionPointToStart(&dtorBlock);
1100 builder.clone(*dtorGGop.getOperation());
// Inside the helper, uses of the moved get_global now refer to the argument.
1104 mlir::Value dtorArg = entryBB->getArgument(0);
1105 dtorGGop.replaceAllUsesWith(dtorArg);
// A return is created just before the yield that terminates the helper's last
// block (removal of that yield is not visible in this excerpt).
1109 mlir::Block &finalBlock = dtorFunc.getBody().back();
1110 auto yieldOp = cast<cir::YieldOp>(finalBlock.getTerminator());
1111 builder.setInsertionPoint(yieldOp);
1112 cir::ReturnOp::create(builder, yieldOp->getLoc());
// Rebuild the region's block as "get_global; call helper(get_global); yield".
1117 cir::GetGlobalOp origGGop =
1118 mlir::cast<cir::GetGlobalOp>(dtorBlock.getOperations().front());
1119 builder.setInsertionPointAfter(origGGop);
1120 mlir::Value ggopResult = origGGop.getResult();
1121 dtorCall = builder.
createCallOp(op.getLoc(), dtorFunc, ggopResult);
1124 auto finalYield = cir::YieldOp::create(builder, op.getLoc());
// Drop anything after the new yield and every block after the first.
1127 dtorBlock.getOperations().erase(std::next(mlir::Block::iterator(finalYield)),
1129 dtorRegion.getBlocks().erase(std::next(dtorRegion.begin()), dtorRegion.end());
// Builds the internal `void()` function "__cxx_global_var_init" (with a
// counter-based suffix) that runs the constructor region of `op` and registers
// its destructor region, optionally wrapped in a TLS guard check.
// NOTE(review): the return type line, the condition on the name suffix, the
// declaration of `guardIf`, one argument of the guard-address call, the
// `else` between the two yield lookups, and the final return are on lines
// missing from this excerpt.
1135LoweringPreparePass::buildCXXGlobalVarDeclInitFunc(cir::GlobalOp op) {
// dynamicInitializerNames supplies a per-name counter for the ".<n>" suffix.
1138 SmallString<256> fnName(
"__cxx_global_var_init");
1140 uint32_t cnt = dynamicInitializerNames[fnName]++;
1142 fnName +=
"." + std::to_string(cnt);
// Create the function right after the global, with internal linkage.
1145 CIRBaseBuilderTy builder(getContext());
1146 builder.setInsertionPointAfter(op);
1147 cir::VoidType voidTy = builder.
getVoidTy();
1148 auto fnType = cir::FuncType::get({}, voidTy);
1149 FuncOp f = buildRuntimeFunction(builder, fnName, op.getLoc(), fnType,
1150 cir::GlobalLinkageKind::InternalLinkage);
1158 mlir::Block *entryBB = f.addEntryBlock();
1159 builder.setInsertionPointToStart(entryBB);
// A TLS guard is needed when the global carries dyn-TLS references that name a
// guard; the body is then emitted inside the guard check's then-region.
1163 bool needsTlsGuard = op.getDynTlsRefs() && op.getDynTlsRefs()->getGuardName();
1165 if (needsTlsGuard) {
1166 guardIf = buildGlobalTlsGuardCheck(
1167 builder, op.getLoc(),
1168 getOrCreateStaticLocalDeclGuardAddress(
1169 builder, op, op.getDynTlsRefs()->getGuardName().getValue(),
1171 op.hasInternalLinkage()));
1172 builder.setInsertionPointToEnd(&guardIf.getThenRegion().front());
// Inline the constructor region: everything except its block's last operation
// is moved to the end of the current block.
1175 if (!op.getCtorRegion().empty()) {
1176 mlir::Block &block = op.getCtorRegion().front();
1177 mlir::Block *insertBlock = builder.getBlock();
1178 insertBlock->getOperations().splice(insertBlock->end(),
1179 block.getOperations(), block.begin(),
1180 std::prev(block.end()));
// Register the destructor region; the tls flag is whether the global has a TLS
// model.
1184 mlir::Region &dtorRegion = op.getDtorRegion();
1185 if (!dtorRegion.empty()) {
1188 emitGlobalGuardedDtorRegion(builder, op, dtorRegion,
1189 op.getTlsModel().has_value(),
1190 *builder.getBlock());
// Terminate the guard's then-region.
1194 if (needsTlsGuard) {
1195 builder.setInsertionPointToEnd(&guardIf.getThenRegion().back());
1196 cir::YieldOp::create(builder, op.getLoc());
// Terminate the function with a return located at the leftover yield of the
// constructor region, or of the destructor region when there is no constructor.
1200 builder.setInsertionPointToEnd(entryBB);
1201 mlir::Operation *yieldOp =
nullptr;
1202 if (!op.getCtorRegion().empty()) {
1203 mlir::Block &block = op.getCtorRegion().front();
1204 yieldOp = &block.getOperations().back();
1206 assert(!dtorRegion.empty());
1207 mlir::Block &block = dtorRegion.front();
1208 yieldOp = &block.getOperations().back();
1211 assert(isa<cir::YieldOp>(*yieldOp));
1212 cir::ReturnOp::create(builder, yieldOp->getLoc());
// Returns the runtime function `__cxa_guard_acquire`, looked up or created at
// the start of the module body with the signature (guardPtrTy) -> 32-bit int
// and the default linkage of buildRuntimeFunction.
// NOTE(review): the return type line is missing from this excerpt.
1217LoweringPreparePass::getGuardAcquireFn(cir::PointerType guardPtrTy) {
// A fresh builder is used, so no caller insertion point is disturbed.
1219 CIRBaseBuilderTy builder(getContext());
1220 mlir::OpBuilder::InsertionGuard ipGuard{builder};
1221 builder.setInsertionPointToStart(mlirModule.getBody());
1222 mlir::Location loc = mlirModule.getLoc();
1223 cir::IntType intTy = cir::IntType::get(&getContext(), 32,
true);
1224 auto fnType = cir::FuncType::get({guardPtrTy}, intTy);
1225 return buildRuntimeFunction(builder,
"__cxa_guard_acquire", loc, fnType);
// Returns the runtime function `__cxa_guard_release`, looked up or created at
// the start of the module body with the signature (guardPtrTy) -> void and the
// default linkage of buildRuntimeFunction.
// NOTE(review): the return type line is missing from this excerpt.
1229LoweringPreparePass::getGuardReleaseFn(cir::PointerType guardPtrTy) {
// A fresh builder is used, so no caller insertion point is disturbed.
1231 CIRBaseBuilderTy builder(getContext());
1232 mlir::OpBuilder::InsertionGuard ipGuard{builder};
1233 builder.setInsertionPointToStart(mlirModule.getBody());
1234 mlir::Location loc = mlirModule.getLoc();
1235 cir::VoidType voidTy = cir::VoidType::get(&getContext());
1236 auto fnType = cir::FuncType::get({guardPtrTy}, voidTy);
1237 return buildRuntimeFunction(builder,
"__cxa_guard_release", loc, fnType);
// Returns the `__tls_init` function with internal linkage, looked up or created
// at the start of the module body.
// NOTE(review): the definition of `fnType` is on a line missing from this
// excerpt.
1240cir::FuncOp LoweringPreparePass::getTlsInitFn() {
// A fresh builder is used, so no caller insertion point is disturbed.
1242 CIRBaseBuilderTy builder(getContext());
1243 mlir::OpBuilder::InsertionGuard _{builder};
1244 builder.setInsertionPointToStart(mlirModule.getBody());
1245 mlir::Location loc = mlirModule.getLoc();
1247 return buildRuntimeFunction(builder,
"__tls_init", loc, fnType,
1248 cir::GlobalLinkageKind::InternalLinkage);
// Creates a new global named `name` of integer type `guardTy` at the start of
// the module body, gives it the requested linkage and private symbol
// visibility, and leaves the caller's insertion point untouched.
// Unlike getOrCreateRuntimeVariable, no symbol lookup is visible here.
// NOTE(review): the start of the linkage-setting statement and the return are
// on lines missing from this excerpt.
1251cir::GlobalOp LoweringPreparePass::createGuardGlobalOp(
1252 CIRBaseBuilderTy &builder, mlir::Location loc, llvm::StringRef name,
1253 cir::IntType guardTy, cir::GlobalLinkageKind linkage) {
// Restore the caller's insertion point when this function returns.
1254 mlir::OpBuilder::InsertionGuard guard(builder);
1255 builder.setInsertionPointToStart(mlirModule.getBody());
1256 cir::GlobalOp g = cir::GlobalOp::create(builder, loc, name, guardTy);
1258 cir::GlobalLinkageKindAttr::get(builder.getContext(), linkage));
1259 mlir::SymbolTable::setSymbolVisibility(
1260 g, mlir::SymbolTable::Visibility::Private);
// Expands `localInitOp` (the initialization of the static local backed by
// `globalOp`) into an explicit guard check: the guard variable is loaded,
// compared against zero, and the initialization is emitted inside an if-op via
// emitCXXGuardedInitIf.
// NOTE(review): many lines are missing from this excerpt, including the
// definitions of `nonTemplateInline`, `maxInlineWidthInBits`, `guardLoad`,
// `one`, `zero`, the full `threadsafe` expression, and the handling of the
// block terminator `ret` between its two push_back calls.
1264void LoweringPreparePass::handleStaticLocal(cir::GlobalOp globalOp,
1265 cir::LocalInitOp localInitOp) {
1266 CIRBaseBuilderTy builder(getContext());
// The global must carry its AST variable declaration.
1268 std::optional<cir::ASTVarDeclInterface> astOption = globalOp.getAst();
1269 assert(astOption.has_value());
1270 cir::ASTVarDeclInterface
varDecl = astOption.value();
1272 builder.setInsertionPointAfter(localInitOp);
1273 mlir::Block *localInitBlock = builder.getInsertionBlock();
// Keep a handle on the block's terminator; it is re-appended later.
1276 mlir::Operation *ret = localInitBlock->getTerminator();
1280 builder.setInsertionPointAfter(localInitOp);
// Guarded initialization of inline namespace-scope variables is reported as not
// yet implemented.
1284 bool nonTemplateInline =
1290 if (nonTemplateInline) {
1291 globalOp->emitError(
1292 "NYI: guarded initialization for inline namespace-scope variables");
// Thread-safe statics require the language option plus a local or
// non-template-inline variable (further conjuncts are on a missing line).
1299 bool threadsafe = astCtx->
getLangOpts().ThreadsafeStatics &&
1300 (
varDecl.isLocalVarDecl() || nonTemplateInline) &&
// A one-byte guard suffices when no thread safety is needed and the global has
// internal linkage.
1305 bool useInt8GuardVariable = !threadsafe && globalOp.hasInternalLinkage();
1308 cir::GlobalOp guard = getOrCreateStaticLocalDeclGuardAddress(
1309 builder, globalOp, globalOp.getStaticLocalGuard()->getName().getValue(),
1310 varDecl.isLocalVarDecl(), useInt8GuardVariable);
1313 localInitBlock->push_back(ret);
1317 mlir::Value guardPtr = builder.
createGetGlobal(guard, localInitOp.getTls());
// Inline guard check: taken when not thread-safe, or when maxInlineWidthInBits
// is non-zero.
1339 unsigned maxInlineWidthInBits =
1342 if (!threadsafe || maxInlineWidthInBits) {
// The guard is accessed through a pointer to a signed 8-bit integer, using the
// guard's alignment.
1344 auto bytePtrTy = cir::PointerType::get(builder.
getSIntNTy(8));
1345 mlir::Value bytePtr = builder.
createBitcast(guardPtr, bytePtrTy);
1347 localInitOp.getLoc(), bytePtr, *guard.getAlignment());
// The load is given acquire ordering with system sync scope (the condition
// guarding this is not visible).
1356 auto loadOp = mlir::cast<cir::LoadOp>(guardLoad.getDefiningOp());
1357 loadOp.setMemOrder(cir::MemOrder::Acquire);
1358 loadOp.setSyncScope(cir::SyncScopeKind::System);
// Under the ARM guard ABI with a non-byte guard, the loaded value is masked
// with `one` before the comparison.
1381 if (useARMGuardVarABI() && !useInt8GuardVariable) {
1383 localInitOp.getLoc(), mlir::cast<cir::IntType>(guardLoad.getType()),
1385 guardLoad = builder.
createAnd(localInitOp.getLoc(), guardLoad, one);
// Initialization is needed when the (masked) guard value equals zero.
1390 localInitOp.getLoc(), mlir::cast<cir::IntType>(guardLoad.getType()), 0);
1391 auto needsInit = builder.
createCompare(localInitOp.getLoc(),
1392 cir::CmpOpKind::eq, guardLoad, zero);
// The guarded body is emitted from the if-op's builder callback.
1396 builder, globalOp.getLoc(), needsInit,
1397 false, [&](mlir::OpBuilder &, mlir::Location) {
1398 emitCXXGuardedInitIf(builder, globalOp, localInitOp.getCtorRegion(),
1399 localInitOp.getDtorRegion(), varDecl, guardPtr,
1400 builder.getPointerTo(guard.getSymType()),
// Targets without inline atomics support are reported as not yet implemented.
1406 globalOp->emitError(
"NYI: guarded init without inline atomics support");
1411 builder.getInsertionBlock()->push_back(ret);
// Lowers a local-init op by resolving the global it references and delegating
// to handleStaticLocal. Ops whose constructor and destructor regions are both
// empty take a separate path whose body is on lines missing from this excerpt.
1414void LoweringPreparePass::lowerLocalInitOp(cir::LocalInitOp initOp) {
1417 if (initOp.getCtorRegion().empty() && initOp.getDtorRegion().empty()) {
// Resolution goes through the pass-level symbolTables collection.
1422 cir::GlobalOp globalOp = initOp.getReferencedGlobal(symbolTables);
1423 assert(globalOp &&
"No global-op found");
1425 handleStaticLocal(globalOp, initOp);
// Fragment: first operand of a returned conjunction. It requires `tls` to be
// the GeneralDynamic TLS model; the remaining operand(s) are not visible here.
1432 return tls == cir::TLS_Model::GeneralDynamic &&
// File-local helper that yields a cir::GlobalLinkageKind for `op`. The line
// carrying the function name and the conditions guarding the first two
// returns are not visible in this listing.
1436static cir::GlobalLinkageKind
// Two early exits reuse the linkage already recorded on `op`.
1439 return op.getLinkage();
1444 return op.getLinkage();
// Fallback: a declaration gets LinkOnceODR linkage, anything else WeakODR.
1448 if (op.isDeclaration())
1449 return cir::GlobalLinkageKind::LinkOnceODRLinkage;
1450 return cir::GlobalLinkageKind::WeakODRLinkage;
// Returns the thread-local wrapper function for `op`, creating it on first
// request and caching it in `threadLocalWrappers` keyed by the wrapper name.
// Part of the signature and several conditions are not visible in this
// listing.
1454LoweringPreparePass::getOrCreateThreadLocalWrapper(CIRBaseBuilderTy &builder,
// Emit at the start of the module body; the InsertionGuard puts the caller's
// insertion point back when this function returns.
1456 mlir::OpBuilder::InsertionGuard insertGuard(builder);
1457 builder.setInsertionPointToStart(&mlirModule.getBodyRegion().front());
// The wrapper's symbol name comes from the global's dynamic-TLS references.
1459 mlir::StringAttr wrapperName = op.getDynTlsRefs()->getWrapperName();
// Reuse a wrapper that was already created under this name.
1461 auto existingWrapperIter = threadLocalWrappers.find(wrapperName.getValue());
1462 if (existingWrapperIter != threadLocalWrappers.end())
1463 return existingWrapperIter->second;
// The wrapper takes no arguments and returns a pointer to the global's type.
1466 auto funcType = cir::FuncType::get({}, builder.
getPointerTo(op.getSymType()));
1468 cir::FuncOp::create(builder, op.getLoc(), wrapperName, funcType);
1470 cir::GlobalLinkageKind linkageKind =
1472 func.setLinkageAttr(
1473 cir::GlobalLinkageKindAttr::get(&getContext(), linkageKind));
// A wrapper that is weak for the linker is placed in a comdat (the leading
// part of this condition is not visible here).
1478 func.isWeakForLinker())
1479 func.setComdat(
true);
1481 mlir::SymbolTable::setSymbolVisibility(
1482 func, mlir::SymbolTable::Visibility::Private);
// Hidden visibility on the global is mirrored onto the wrapper (the leading
// part of this condition is not visible here).
1487 op.getGlobalVisibility() == cir::VisibilityKind::Hidden)
1488 func.setGlobalVisibility(cir::VisibilityKind::Hidden);
1491 op->emitError(
"Unhandled thread wrapper attributes for CC and Nounwind");
// Remember the wrapper so later requests for the same name return it.
1493 threadLocalWrappers.insert({wrapperName.getValue(), func});
// Emits the body of the thread-local wrapper function for `op`.
//   op              - the thread-local global whose wrapper is being defined.
//   initAlias       - init alias whose address is tested before use when the
//                     variable is not defined in this module.
//   isVarDefinition - true when this module defines the variable.
// Several statements are not visible in this listing.
1497void LoweringPreparePass::defineGlobalThreadLocalWrapper(cir::GlobalOp op,
1498 cir::FuncOp initAlias,
1499 bool isVarDefinition) {
// Fetch (or create) the wrapper and start emitting into a fresh entry block.
1500 CIRBaseBuilderTy builder(getContext());
1501 cir::FuncOp wrapper = getOrCreateThreadLocalWrapper(builder, op);
1502 mlir::Block *entryBB = wrapper.addEntryBlock();
1503 builder.setInsertionPointToStart(entryBB);
1507 mlir::Location aliasLoc = initAlias.getLoc();
// Without a local definition, the alias address is compared (ne) against
// `nullCheck`, and the cir.if region only runs when the two differ.
1508 if (!isVarDefinition) {
1510 mlir::Value funcLoad = cir::GetGlobalOp::create(
1511 builder, aliasLoc, cir::PointerType::get(initAlias.getFunctionType()),
1512 initAlias.getSymName());
1513 mlir::Value nullCheck =
1515 mlir::Value cmp = cir::CmpOp::create(
1516 builder, aliasLoc, cir::CmpOpKind::ne, funcLoad, nullCheck);
1517 cir::IfOp::create(builder, aliasLoc, cmp,
false,
1518 [&](mlir::OpBuilder &, mlir::Location loc) {
1520 cir::YieldOp::create(builder, aliasLoc);
// The wrapper returns the single value `get` (its definition is not visible).
1529 cir::ReturnOp::create(builder, op.getLoc(), {get});
// Returns the init alias function for the thread-local global `op`, creating
// it on first request and caching it in `threadLocalInitAliases` by name.
// `aliasee` is the function the alias points at. Some conditions are not
// visible in this listing.
1533LoweringPreparePass::defineGlobalThreadLocalInitAlias(cir::GlobalOp op,
1534 cir::FuncOp aliasee) {
1535 CIRBaseBuilderTy builder(getContext());
// Emit at the start of the module body; the guard restores the insertion
// point on exit.
1536 mlir::OpBuilder::InsertionGuard insertGuard(builder);
1537 builder.setInsertionPointToStart(&mlirModule.getBodyRegion().front());
// The alias name is taken from the global's dynamic-TLS references.
1538 mlir::StringAttr aliasName = op.getDynTlsRefs()->getInitName();
1539 auto existingAliasIter = threadLocalInitAliases.find(aliasName.getValue());
// An alias already created under this name is returned as-is.
1541 if (existingAliasIter != threadLocalInitAliases.end())
1542 return existingAliasIter->second;
// A new alias starts out with the global's own linkage.
1546 cir::FuncOp::create(builder, op.getLoc(), aliasName, funcType);
1547 alias.setLinkage(op.getLinkage());
1550 alias.setAliasee(aliasee.getSymName());
// In a case whose condition is not visible here, the alias is instead given
// ExternalWeak linkage and private MLIR symbol visibility.
1555 alias.setLinkage(cir::GlobalLinkageKind::ExternalWeakLinkage);
1556 mlir::SymbolTable::setSymbolVisibility(
1557 alias, mlir::SymbolTable::Visibility::Private);
// Cache the alias for later lookups by the same name.
1560 threadLocalInitAliases.insert({aliasName.getValue(), alias});
// Lowers a cir.global: turns its ctor/dtor regions into an init function,
// records that function for later emission, and sets up the thread-local init
// alias and wrapper for GeneralDynamic TLS globals. Several statements are not
// visible in this listing.
1564void LoweringPreparePass::lowerGlobalOp(GlobalOp op) {
// Globals carrying a static-local guard are treated separately (the body of
// this branch is not visible here).
1566 if (op.getStaticLocalGuard())
1569 mlir::Region &ctorRegion = op.getCtorRegion();
1570 mlir::Region &dtorRegion = op.getDtorRegion();
1571 cir::FuncOp initAlias;
// Any non-empty ctor/dtor region means the global needs dynamic init.
1573 if (!ctorRegion.empty() || !dtorRegion.empty()) {
1576 cir::FuncOp f = buildCXXGlobalVarDeclInitFunc(op);
// The regions have been consumed by the init function; drop their blocks.
1579 ctorRegion.getBlocks().clear();
1580 dtorRegion.getBlocks().clear();
// GeneralDynamic TLS globals without a static-local guard get an init alias:
// it targets `f` when the TLS refs carry a guard name; the other visible call
// targets the shared TLS init function.
1583 if (op.getTlsModel() == TLS_Model::GeneralDynamic &&
1584 !op.getStaticLocalGuard().has_value()) {
1596 if (op.getDynTlsRefs()->getGuardName()) {
1598 initAlias = defineGlobalThreadLocalInitAlias(op, f);
1601 initAlias = defineGlobalThreadLocalInitAlias(op, getTlsInitFn());
// `f` is queued either as a thread-local initializer or as an ordinary
// dynamic initializer (the branch structure between the two is not visible).
1605 globalThreadLocalInitializers.push_back(f);
1608 dynamicInitializers.push_back(f);
1610 }
else if (op.getTlsModel() == TLS_Model::GeneralDynamic &&
1611 op.getDynTlsRefs() && op.isDeclaration()) {
// Declaration-only TLS global: create the alias with no aliasee.
1614 initAlias = defineGlobalThreadLocalInitAlias(op, {});
// Every GeneralDynamic TLS global with TLS refs gets its wrapper defined; the
// last argument says whether this module defines the variable.
1620 if (op.getTlsModel() == TLS_Model::GeneralDynamic && op.getDynTlsRefs())
1621 defineGlobalThreadLocalWrapper(op, initAlias, !op.isDeclaration());
// Rewrites a cir.get_global that refers to a GeneralDynamic thread-local
// global so that its uses go through the thread-local wrapper function.
// Several statements (including early exits) are not visible in this listing.
1626void LoweringPreparePass::lowerGetGlobalOp(GetGlobalOp op) {
// Resolve the referenced global; mlir::cast asserts that the symbol found is
// a cir::GlobalOp.
1629 auto globalOp = mlir::cast<cir::GlobalOp>(
1630 symbolTables.lookupNearestSymbolFrom(op, op.getNameAttr()));
// Only GeneralDynamic TLS globals that carry TLS refs are of interest (the
// body of this branch is not visible here).
1636 if (globalOp.getTlsModel() != TLS_Model::GeneralDynamic ||
1637 !globalOp.getDynTlsRefs())
// When the op sits directly inside the global itself, check whether it is the
// first operation of the ctor or dtor region (branch bodies not visible).
1655 mlir::Operation *parentOp = op->getParentOp();
1656 if (parentOp == globalOp) {
1657 mlir::Region *ctorRegion = &globalOp.getCtorRegion();
1658 mlir::Region *dtorRegion = &globalOp.getDtorRegion();
1660 if (!ctorRegion->empty() && &*ctorRegion->op_begin() == op.getOperation())
1662 if (!dtorRegion->empty() && &*dtorRegion->op_begin() == op.getOperation())
// Build a call to the wrapper right before the op and redirect all uses of
// the op to the call's result.
1666 CIRBaseBuilderTy builder(getContext());
1667 cir::FuncOp wrapperFunc = getOrCreateThreadLocalWrapper(builder, globalOp);
1669 builder.setInsertionPoint(op);
1671 wrapperFunc.getLoc(),
1672 mlir::FlatSymbolRefAttr::get(wrapperFunc.getSymNameAttr()),
1673 wrapperFunc.getFunctionType().getReturnType(), {});
1674 op->replaceAllUsesWith(call);
// Lowers a three-way comparison into plain compares and selects, then
// replaces all uses of the original op with the selected result. The
// definitions of ltRes/eqRes/gtRes and the compare result names are not
// visible in this listing.
1678void LoweringPreparePass::lowerThreeWayCmpOp(CmpThreeWayOp op) {
1679 CIRBaseBuilderTy builder(getContext());
// New ops are emitted right after the op being lowered.
1680 builder.setInsertionPointAfter(op);
1682 mlir::Location loc = op->getLoc();
1683 cir::CmpThreeWayInfoAttr cmpInfo = op.getInfo();
1692 mlir::Value transformedResult;
// Non-partial ordering: result = eq ? eqRes : (lt ? ltRes : gtRes).
1693 if (cmpInfo.getOrdering() != CmpOrdering::Partial) {
1696 builder.
createCompare(loc, CmpOpKind::lt, op.getLhs(), op.getRhs());
1697 mlir::Value selectOnLt = builder.
createSelect(loc, lt, ltRes, gtRes);
1699 builder.
createCompare(loc, CmpOpKind::eq, op.getLhs(), op.getRhs());
1700 transformedResult = builder.
createSelect(loc, eq, eqRes, selectOnLt);
// Remaining statements use the attribute's unordered value as the fallback:
// result = lt ? ltRes : (gt ? gtRes : (eq ? eqRes : unorderedRes)).
1704 loc, op.getType(), cmpInfo.getUnordered().value());
1707 builder.
createCompare(loc, CmpOpKind::eq, op.getLhs(), op.getRhs());
1708 mlir::Value selectOnEq = builder.
createSelect(loc, eq, eqRes, unorderedRes);
1710 builder.
createCompare(loc, CmpOpKind::gt, op.getLhs(), op.getRhs());
1711 mlir::Value selectOnGt = builder.
createSelect(loc, gt, gtRes, selectOnEq);
1713 builder.
createCompare(loc, CmpOpKind::lt, op.getLhs(), op.getRhs());
1714 transformedResult = builder.
createSelect(loc, lt, ltRes, selectOnGt);
// Redirect every user of the three-way compare to the lowered value.
1717 op.replaceAllUsesWith(transformedResult);
// Builds a list of attributes of kind `AttributeTy`, one per (name, priority)
// entry of `list`. The parameter list and the return statement are not
// visible in this listing.
1721template <
typename AttributeTy>
1722static llvm::SmallVector<mlir::Attribute>
// Each entry becomes AttributeTy::get(context, name, priority), in order.
1726 for (
const auto &[name, priority] : list)
1727 attrs.push_back(AttributeTy::get(context, name, priority));
// Publishes the collected global constructors and destructors as array
// attributes on the module. An empty list leaves its attribute unset.
1731void LoweringPreparePass::buildGlobalCtorDtorList() {
// Constructors: stored under the dialect's global-ctors attribute name.
1732 if (!globalCtorList.empty()) {
1733 llvm::SmallVector<mlir::Attribute> globalCtors =
1737 mlirModule->setAttr(cir::CIRDialect::getGlobalCtorsAttrName(),
1738 mlir::ArrayAttr::get(&getContext(), globalCtors));
// Destructors: stored under the dialect's global-dtors attribute name.
1741 if (!globalDtorList.empty()) {
1742 llvm::SmallVector<mlir::Attribute> globalDtors =
1745 mlirModule->setAttr(cir::CIRDialect::getGlobalDtorsAttrName(),
1746 mlir::ArrayAttr::get(&getContext(), globalDtors));
// Creates the module-level `__tls_guard` global: an 8-bit signed integer with
// internal linkage, GeneralDynamic TLS model and initial value 0. The return
// type and return statement are not visible in this listing.
1751LoweringPreparePass::createGlobalThreadLocalGuard(CIRBaseBuilderTy &builder,
1752 mlir::Location loc) {
// Emit at the start of the module body; the guard restores the caller's
// insertion point on exit.
1753 mlir::OpBuilder::InsertionGuard guard(builder);
1754 builder.setInsertionPointToStart(mlirModule.getBody());
1757 cir::IntType guardTy = builder.
getSIntNTy(8);
1758 auto g = cir::GlobalOp::create(builder, loc,
"__tls_guard", guardTy);
1759 g.setLinkageAttr(cir::GlobalLinkageKindAttr::get(
1760 builder.getContext(), cir::GlobalLinkageKind::InternalLinkage));
1764 g.setTlsModel(TLS_Model::GeneralDynamic);
// Zero marks the guard as "not yet initialized" for the eq-zero guard check
// (NOTE(review): inferred from the guard-check helper; confirm).
1765 g.setInitialValueAttr(cir::IntAttr::get(guardTy, 0));
// Emits the guard test for thread-local initialization and returns the
// resulting cir.if, whose condition is "loaded guard value == zero". Many
// statements (the get-global, the load, the constant zero and the body of the
// then-region) are not visible in this listing.
1769cir::IfOp LoweringPreparePass::buildGlobalTlsGuardCheck(
1770 CIRBaseBuilderTy &builder, mlir::Location loc, cir::GlobalOp guard) {
1772 mlir::Value getGuardValue = getGuard;
// A guard whose type is not an 8-bit signed integer is accessed through a
// pointer to one (the call wrapping these arguments is not visible here).
1777 if (guard.getSymType() != builder.
getSIntNTy(8))
1779 getGuard, cir::PointerType::get(builder.
getSIntNTy(8)));
1781 mlir::Value guardLoad =
1785 builder.
createCompare(loc, cir::CmpOpKind::eq, guardLoad, zero);
1786 return cir::IfOp::create(
1788 false, [&](mlir::OpBuilder &, mlir::Location loc) {
// Inside the then-region a constant 1 of the guard's type is paired with the
// guard address. NOTE(review): presumably a store marking the guard as
// initialized; the enclosing call is not visible, so confirm.
1792 loc, builder.
getConstantInt(loc, guard.getSymType(), 1), getGuard);
// Fills in the body of the TLS init function: a guard check whose then-region
// handles every collected thread-local initializer, followed by a return.
// The early-exit body and the per-initializer statement are not visible in
// this listing.
1796void LoweringPreparePass::buildCXXGlobalTlsFunc() {
// Nothing to emit when no thread-local initializer was collected.
1797 if (globalThreadLocalInitializers.empty())
1803 cir::FuncOp tlsInit = getTlsInitFn();
1804 mlir::Location loc = tlsInit.getLoc();
1805 CIRBaseBuilderTy builder(getContext());
1806 mlir::Block *entryBB = tlsInit.addEntryBlock();
1807 builder.setInsertionPointToStart(entryBB);
// Create the guard global and the cir.if that tests it.
1809 cir::IfOp ifOperation = buildGlobalTlsGuardCheck(
1810 builder, loc, createGlobalThreadLocalGuard(builder, loc));
// Append to the end of the then-region, one step per initializer, then
// terminate the region with a yield.
1813 builder.setInsertionPointToEnd(&ifOperation.getThenRegion().front());
1814 for (cir::FuncOp initFunc : globalThreadLocalInitializers)
1816 cir::YieldOp::create(builder, loc);
// Back in the function body: return after the guarded region.
1818 builder.setInsertionPointAfter(ifOperation);
1819 cir::ReturnOp::create(builder, loc);
// Builds the module's global initialization function, which handles every
// collected dynamic initializer, and registers it as a global constructor
// with the default priority. Several statements are not visible in this
// listing.
1822void LoweringPreparePass::buildCXXGlobalInitFunc() {
// Nothing to emit when no dynamic initializer was collected.
1823 if (dynamicInitializers.empty())
// The function name is assembled in `fnName`; the visible pieces are an
// Itanium mangle context writing through `out` and the "_GLOBAL__sub_I_"
// prefix (how the two combine is not visible here).
1830 SmallString<256> fnName;
1838 llvm::raw_svector_ostream
out(fnName);
1839 std::unique_ptr<clang::MangleContext> mangleCtx(
1841 cast<clang::ItaniumMangleContext>(*mangleCtx)
1844 fnName +=
"_GLOBAL__sub_I_";
// Create the void() function at the end of the module with external linkage.
1848 CIRBaseBuilderTy builder(getContext());
1849 builder.setInsertionPointToEnd(&mlirModule.getBodyRegion().back());
1850 auto fnType = cir::FuncType::get({}, builder.
getVoidTy());
1852 buildRuntimeFunction(builder, fnName, mlirModule.getLoc(), fnType,
1853 cir::GlobalLinkageKind::ExternalLinkage);
1854 builder.setInsertionPointToStart(f.addEntryBlock());
// NOTE(review): the loop variable `f` shadows the outer `f` that owns the
// entry block above and supplies the location of the final return.
1855 for (cir::FuncOp &f : dynamicInitializers)
1859 globalCtorList.emplace_back(fnName,
1860 cir::GlobalCtorAttr::getDefaultPriority());
1862 cir::ReturnOp::create(builder, f.getLoc());
// Lowers a cir.array.ctor or cir.array.dtor into a loop that visits every
// element. Visible parameters:
//   op          - the array ctor/dtor op; its first region is the per-element
//                 body.
//   eltTy       - type used for the element pointers created here.
//   numElements - run-time element count; non-null selects the dynamic path.
//   arrayLen    - static element count for the non-dynamic path.
//   isCtor      - true walks begin -> end, false walks end -> begin.
// The first line of the signature and many statements are not visible in this
// listing.
1871 mlir::Operation *op, mlir::Type eltTy,
1873 mlir::Value numElements,
1874 uint64_t arrayLen,
bool isCtor) {
1875 mlir::Location loc = op->getLoc();
1876 bool isDynamic = numElements !=
nullptr;
1880 const unsigned sizeTypeSize =
// `begin` and `end` delimit the element range: `end` is `begin` strided by
// the run-time count, or by a constant offset after an array-to-pointer decay
// of `addr`.
1886 mlir::Value begin, end;
1889 end = cir::PtrStrideOp::create(builder, loc, eltTy, begin, numElements);
1891 mlir::Value endOffsetVal =
1893 begin = cir::CastOp::create(builder, loc, eltTy,
1894 cir::CastKind::array_to_ptrdecay, addr);
1895 end = cir::PtrStrideOp::create(builder, loc, eltTy, begin, endOffsetVal);
// Construction runs forward, destruction runs backward.
1898 mlir::Value start = isCtor ? begin : end;
1899 mlir::Value stop = isCtor ? end : begin;
// A guard condition (ne compare) wraps the loop in a cir.if; what follows is
// emitted inside its then-region.
1905 mlir::Value guardCond;
1908 guardCond = cir::CmpOp::create(builder, loc, cir::CmpOpKind::ne,
1914 cir::CmpOp::create(builder, loc, cir::CmpOpKind::ne, start, stop);
1916 ifOp = cir::IfOp::create(builder, loc, guardCond,
1918 [&](mlir::OpBuilder &, mlir::Location) {});
1919 builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
1927 mlir::Block *bodyBlock = &op->getRegion(0).front();
// Clones every non-yield op of `srcBlock` at the builder's insertion point,
// with the block's first argument mapped to `replacement`.
// NOTE(review): `®ionOp` below looks like a mis-decoded `&regionOp` (the
// clone call uses the name `regionOp`).
1932 auto cloneRegionBodyInto = [&](mlir::Block *srcBlock,
1933 mlir::Value replacement) {
1934 mlir::IRMapping map;
1935 map.map(srcBlock->getArgument(0), replacement);
1936 for (mlir::Operation ®ionOp : *srcBlock) {
1937 if (!mlir::isa<cir::YieldOp>(®ionOp))
1938 builder.clone(regionOp, map);
// Pick the block used to destroy already-processed elements on the exceptional
// path: the ctor's partial-dtor region when present, or the dtor body itself
// when the destructor may throw.
1942 mlir::Block *partialDtorBlock =
nullptr;
1943 if (
auto arrayCtor = mlir::dyn_cast<cir::ArrayCtor>(op)) {
1944 mlir::Region &partialDtor = arrayCtor.getPartialDtor();
1945 if (!partialDtor.empty())
1946 partialDtorBlock = &partialDtor.front();
1947 }
else if (
auto arrayDtor = mlir::dyn_cast<cir::ArrayDtor>(op)) {
1956 if (arrayDtor.getDtorMayThrow())
1957 partialDtorBlock = bodyBlock;
// Emits the element loop: the condition compares the element loaded from
// `tmpAddr` against `stop`; the body either runs the cloned body and strides
// by +1, or strides by -1 first and then runs the cloned body.
1960 auto emitCtorDtorLoop = [&]() {
1964 [&](mlir::OpBuilder &
b, mlir::Location loc) {
1965 auto currentElement = cir::LoadOp::create(
b, loc, eltTy, tmpAddr);
1966 auto cmp = cir::CmpOp::create(builder, loc, cir::CmpOpKind::ne,
1967 currentElement, stop);
1971 [&](mlir::OpBuilder &
b, mlir::Location loc) {
1972 auto currentElement = cir::LoadOp::create(
b, loc, eltTy, tmpAddr);
1974 cloneRegionBodyInto(bodyBlock, currentElement);
1975 mlir::Value stride = builder.
getUnsignedInt(loc, 1, sizeTypeSize);
1976 auto nextElement = cir::PtrStrideOp::create(builder, loc, eltTy,
1977 currentElement, stride);
1980 mlir::Value stride = builder.
getSignedInt(loc, -1, sizeTypeSize);
1981 auto prevElement = cir::PtrStrideOp::create(builder, loc, eltTy,
1982 currentElement, stride);
1984 cloneRegionBodyInto(bodyBlock, prevElement);
1987 cir::YieldOp::create(
b, loc);
// With a partial-dtor block, the loop is wrapped in an EH cleanup scope whose
// cleanup walks back from the current element towards `begin`, running the
// partial destructor on each previous element.
1991 if (partialDtorBlock) {
1992 cir::CleanupScopeOp::create(
1993 builder, loc, cir::CleanupKind::EH,
1995 [&](mlir::OpBuilder &
b, mlir::Location loc) {
1997 cir::YieldOp::create(
b, loc);
2000 [&](mlir::OpBuilder &
b, mlir::Location loc) {
2001 auto cur = cir::LoadOp::create(
b, loc, eltTy, tmpAddr);
2003 cir::CmpOp::create(builder, loc, cir::CmpOpKind::ne, cur, begin);
2005 builder, loc, cmp,
false,
2006 [&](mlir::OpBuilder &
b, mlir::Location loc) {
2010 [&](mlir::OpBuilder &
b, mlir::Location loc) {
2011 auto el = cir::LoadOp::create(
b, loc, eltTy, tmpAddr);
2012 auto neq = cir::CmpOp::create(
2013 builder, loc, cir::CmpOpKind::ne, el, begin);
2017 [&](mlir::OpBuilder &
b, mlir::Location loc) {
2018 auto el = cir::LoadOp::create(
b, loc, eltTy, tmpAddr);
2019 mlir::Value negOne =
2021 auto prev = cir::PtrStrideOp::create(builder, loc, eltTy,
2024 cloneRegionBodyInto(partialDtorBlock, prev);
2027 cir::YieldOp::create(builder, loc);
2029 cir::YieldOp::create(
b, loc);
2036 cir::YieldOp::create(builder, loc);
// Lowers a cir.array.dtor. The calls that perform the lowering are only
// partly visible in this listing.
2041void LoweringPreparePass::lowerArrayDtor(cir::ArrayDtor op) {
2042 CIRBaseBuilderTy builder(getContext());
// The replacement code is emitted right after the op.
2043 builder.setInsertionPointAfter(op.getOperation());
// The element type is the type of the body region's first argument.
2045 mlir::Type eltTy = op->getRegion(0).getArgument(0).getType();
// With a run-time element count, that count is passed along with a static
// length of 0.
2047 if (op.getNumElements()) {
2049 op.getNumElements(), 0,
// Otherwise the static length comes from the pointee array type of the
// address operand.
2055 mlir::cast<cir::ArrayType>(op.getAddr().getType().getPointee()).getSize();
// Lowers a cir.array.ctor. The calls that perform the lowering are only
// partly visible in this listing.
2061void LoweringPreparePass::lowerArrayCtor(cir::ArrayCtor op) {
2062 cir::CIRBaseBuilderTy builder(getContext());
// The replacement code is emitted right after the op.
2063 builder.setInsertionPointAfter(op.getOperation());
// The element type is the type of the body region's first argument.
2065 mlir::Type eltTy = op->getRegion(0).getArgument(0).getType();
// With a run-time element count, that count is passed along with a static
// length of 0.
2067 if (op.getNumElements()) {
2069 op.getNumElements(), 0,
// Otherwise the static length comes from the pointee array type of the
// address operand.
2075 mlir::cast<cir::ArrayType>(op.getAddr().getType().getPointee()).getSize();
// Resolves the cir::FuncOp a call refers to by symbol. Handling of a callee
// that is not a symbol reference is not visible in this listing.
2081cir::FuncOp LoweringPreparePass::getCalledFunction(cir::CallOp callOp) {
// dyn_cast_if_present tolerates an absent callee and yields a null attribute
// when the callable is not a SymbolRefAttr.
2082 mlir::SymbolRefAttr sym = llvm::dyn_cast_if_present<mlir::SymbolRefAttr>(
2083 callOp.getCallableForCallee());
// Look the symbol up starting from the call site.
2086 return symbolTables.lookupNearestSymbolFrom<cir::FuncOp>(callOp, sym);
// Handles calls to trivial copy constructors. The replacement emitted for
// such a call is not visible in this listing.
2089void LoweringPreparePass::lowerTrivialCopyCall(cir::CallOp op) {
2090 cir::FuncOp funcOp = getCalledFunction(op);
// Only a callee that is a copy constructor and a trivial C++ member function
// qualifies.
2094 std::optional<cir::CtorKind> ctorKind = funcOp.getCxxConstructorKind();
2095 if (ctorKind && *ctorKind == cir::CtorKind::Copy &&
2096 funcOp.isCxxTrivialMemberFunction()) {
2098 CIRBaseBuilderTy builder(getContext());
// Operand 0 is treated as the destination, operand 1 as the source.
2099 mlir::ValueRange operands = op.getOperands();
2100 mlir::Value dest = operands[0];
2101 mlir::Value src = operands[1];
// New ops go immediately before the call.
2102 builder.setInsertionPoint(op);
// Returns a private global of type `ty` initialized with `constant`, reusing
// an existing one when its type and initial value both match. Globals are
// tracked per `baseName` in `constAggregateGlobals`; additional versions get
// a '.' plus a version number appended to the name. Loop headers and return
// statements are not visible in this listing.
2108cir::GlobalOp LoweringPreparePass::getOrCreateConstAggregateGlobal(
2109 CIRBaseBuilderTy &builder, mlir::Location loc, llvm::StringRef baseName,
2110 mlir::Type ty, mlir::TypedAttr constant) {
// All versions recorded so far for this base name (operator[] creates an
// empty list on first use).
2112 llvm::SmallVector<cir::GlobalOp, 1> &versions =
2113 constAggregateGlobals[baseName];
// First try the versions already known to the pass.
2116 for (cir::GlobalOp gv : versions) {
2117 if (gv.getSymType() == ty && gv.getInitialValue() == constant)
// Build a candidate name: the base name, optionally followed by ".<version>".
// `name` is cut back to the base length before a suffix is appended.
2125 llvm::SmallString<128>
name(baseName);
2126 size_t baseLen =
name.size();
2127 unsigned version = versions.size();
2129 name.resize(baseLen);
2131 name.push_back(
'.');
2132 llvm::Twine(version).toVector(name);
// A global that already exists in the module under the candidate name is
// recorded as a version and compared against the request.
2134 auto existingGv = symbolTables.lookupSymbolIn<cir::GlobalOp>(
2135 mlirModule, mlir::StringAttr::get(&getContext(), name));
2138 versions.push_back(existingGv);
2139 if (existingGv.getSymType() == ty &&
2140 existingGv.getInitialValue() == constant)
// No match: create a new private global at the start of the module body.
2146 mlir::OpBuilder::InsertionGuard guard(builder);
2147 builder.setInsertionPointToStart(mlirModule.getBody());
2149 cir::GlobalOp::create(builder, loc, name, ty,
2151 cir::LangAddressSpaceAttr::get(
2152 &getContext(), cir::LangAddressSpace::Default),
2153 cir::GlobalLinkageKind::PrivateLinkage);
2154 mlir::SymbolTable::setSymbolVisibility(
2155 gv, mlir::SymbolTable::Visibility::Private);
2156 gv.setInitialValueAttr(constant);
// Register the new global with the cached symbol table and with the version
// list so later lookups find it.
2160 symbolTables.getSymbolTable(mlirModule).insert(gv);
2162 versions.push_back(gv);
// Handles a store whose value is a constant array or record and whose address
// is an alloca: the constant is materialized as a module-level global named
// "__const.<function>.<variable>". The early exits and the ops that replace
// the store are not visible in this listing.
2166void LoweringPreparePass::lowerStoreOfConstAggregate(cir::StoreOp op) {
// The stored value must be defined by a cir.const ...
2168 auto constOp = op.getValue().getDefiningOp<cir::ConstantOp>();
// ... of array or record type.
2172 mlir::Type ty = constOp.getType();
2173 if (!mlir::isa<cir::ArrayType, cir::RecordType>(ty))
// The destination is looked up as a cir.alloca.
2179 auto alloca = op.getAddr().getDefiningOp<cir::AllocaOp>();
2183 mlir::TypedAttr constant = constOp.getValue();
// The global's base name combines the enclosing function's symbol name with
// the alloca's variable name.
2194 auto func = op->getParentOfType<cir::FuncOp>();
2197 llvm::StringRef funcName = func.getSymName();
2200 llvm::StringRef varName = alloca.getName();
2203 std::string baseName = (
"__const." + funcName +
"." + varName).str();
2204 CIRBaseBuilderTy builder(getContext());
2208 cir::GlobalOp gv = getOrCreateConstAggregateGlobal(builder, op.getLoc(),
2209 baseName, ty, constant);
// Take the global's address right before the store.
2212 builder.setInsertionPoint(op);
2214 auto ptrTy = cir::PointerType::get(ty);
2215 mlir::Value globalPtr =
2216 cir::GetGlobalOp::create(builder, op.getLoc(), ptrTy, gv.getSymName());
// The constant op is only dealt with once nothing uses it any more (the body
// of this branch is not visible here).
2225 if (constOp.use_empty())
// Dispatches a single operation to the matching lowering routine based on
// its concrete op type. Function ops are not lowered here; instead their
// global ctor/dtor priorities and CUDA kernel names are recorded. The closing
// lines of the function are not visible in this listing.
2229void LoweringPreparePass::runOnOp(mlir::Operation *op) {
2230 if (
auto arrayCtor = dyn_cast<cir::ArrayCtor>(op)) {
2231 lowerArrayCtor(arrayCtor);
2232 }
else if (
auto arrayDtor = dyn_cast<cir::ArrayDtor>(op)) {
2233 lowerArrayDtor(arrayDtor);
2234 }
else if (
auto cast = mlir::dyn_cast<cir::CastOp>(op)) {
2236 }
else if (
auto complexDiv = mlir::dyn_cast<cir::ComplexDivOp>(op)) {
2237 lowerComplexDivOp(complexDiv);
2238 }
else if (
auto complexMul = mlir::dyn_cast<cir::ComplexMulOp>(op)) {
2239 lowerComplexMulOp(complexMul);
2240 }
else if (
auto glob = mlir::dyn_cast<cir::GlobalOp>(op)) {
2241 lowerGlobalOp(glob);
2242 }
else if (
auto getGlob = mlir::dyn_cast<cir::GetGlobalOp>(op)) {
2243 lowerGetGlobalOp(getGlob);
2244 }
else if (
auto unaryOp = mlir::dyn_cast<cir::UnaryOpInterface>(op)) {
2245 lowerUnaryOp(unaryOp);
2246 }
else if (
auto callOp = dyn_cast<cir::CallOp>(op)) {
2247 lowerTrivialCopyCall(callOp);
2248 }
else if (
auto storeOp = dyn_cast<cir::StoreOp>(op)) {
2249 lowerStoreOfConstAggregate(storeOp);
2250 }
else if (
auto fnOp = dyn_cast<cir::FuncOp>(op)) {
// A function contributes to at most one of the two lists: the dtor priority
// is only consulted when no ctor priority is set.
2251 if (
auto globalCtor = fnOp.getGlobalCtorPriority())
2252 globalCtorList.emplace_back(fnOp.getName(), globalCtor.value());
2253 else if (
auto globalDtor = fnOp.getGlobalDtorPriority())
2254 globalDtorList.emplace_back(fnOp.getName(), globalDtor.value());
// Functions tagged with a CUDA kernel name are remembered by that name.
// NOTE(review): the dyn_cast result is used on the next line without a null
// check; cast<> would state the "must be this type" intent.
2256 if (mlir::Attribute attr =
2257 fnOp->getAttr(cir::CUDAKernelNameAttr::getMnemonic())) {
2258 auto kernelNameAttr = dyn_cast<CUDAKernelNameAttr>(attr);
2259 llvm::StringRef kernelName = kernelNameAttr.getKernelName();
2260 cudaKernelMap[kernelName] = fnOp;
2262 }
else if (
auto threeWayCmp = dyn_cast<cir::CmpThreeWayOp>(op)) {
2263 lowerThreeWayCmpOp(threeWayCmp);
2264 }
else if (
auto initOp = dyn_cast<cir::LocalInitOp>(op)) {
2265 lowerLocalInitOp(initOp);
// Returns "__" + prefix + name as a std::string. The first line of the
// signature is not visible in this listing.
2276 llvm::StringRef name) {
2277 return (
"__" + prefix + name).str();
// Builds the CUDA/HIP module constructor: embeds the GPU binary as a private
// global, wraps it in a fatbin wrapper record, creates the GPU binary handle
// global, and emits an internal constructor function that registers the
// binary and is added to the global constructor list. Many statements are not
// visible in this listing, so the comments below cover visible steps only.
2299void LoweringPreparePass::buildCUDAModuleCtor() {
// Relocatable device code is not implemented.
2302 if (astCtx->
getLangOpts().GPURelocatableDeviceCode)
2303 llvm_unreachable(
"GPU RDC NYI");
// The recorded kernel map and the module's CUDA binary handle attribute are
// both checked up front (branch bodies not visible here).
2307 if (cudaKernelMap.empty())
2312 mlir::Attribute cudaBinaryHandleAttr =
2313 mlirModule->getAttr(CIRDialect::getCUDABinaryHandleAttrName());
2314 if (!cudaBinaryHandleAttr) {
// Read the GPU binary named by the attribute through the virtual file system;
// a read failure is reported as an error on the module.
2320 llvm::StringRef cudaGPUBinaryName =
2321 mlir::cast<CUDABinaryHandleAttr>(cudaBinaryHandleAttr)
2325 llvm::vfs::FileSystem &vfs =
2327 llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> gpuBinaryOrErr =
2328 vfs.getBufferForFile(cudaGPUBinaryName);
2329 if (std::error_code ec = gpuBinaryOrErr.getError()) {
2330 mlirModule->emitError(
"cannot open GPU binary file: " + cudaGPUBinaryName +
2331 ": " + ec.message());
2334 std::unique_ptr<llvm::MemoryBuffer> gpuBinary =
2335 std::move(gpuBinaryOrErr.get());
2339 mlir::Location loc = mlirModule->getLoc();
2340 CIRBaseBuilderTy builder(getContext());
2341 builder.setInsertionPointToStart(mlirModule.getBody());
2345 PointerType voidPtrPtrTy = builder.
getPointerTo(voidPtrTy);
2347 IntType charTy = cir::IntType::get(&getContext(), astCtx->
getCharWidth(),
// Section names differ between HIP and CUDA.
2353 llvm::StringRef fatbinConstName =
2354 astCtx->
getLangOpts().HIP ?
".hip_fatbin" :
".nv_fatbin";
2356 llvm::StringRef fatbinSectionName =
2357 astCtx->
getLangOpts().HIP ?
".hipFatBinSegment" :
".nvFatBinSegment";
// The binary is stored as a private char array sized to the buffer, aligned
// to 8 and placed in the fatbin constant section.
2361 ArrayType::get(&getContext(), charTy, gpuBinary->getBuffer().size());
2363 GlobalOp fatbinStr = GlobalOp::create(builder, loc, fatbinStrName, fatbinType,
2365 GlobalLinkageKind::PrivateLinkage);
2366 fatbinStr.setAlignment(8);
2367 fatbinStr.setInitialValueAttr(cir::ConstArrayAttr::get(
2368 fatbinType, StringAttr::get(gpuBinary->getBuffer(), fatbinType)));
2369 fatbinStr.setSection(fatbinConstName);
2370 fatbinStr.setPrivate();
// Fatbin wrapper record: { int, int, void*, void* }, initialized below with
// { magic, version, pointer to the binary, unused }.
2374 auto fatbinWrapperType = RecordType::get(
2375 &getContext(), {intTy, intTy, voidPtrTy, voidPtrTy},
2376 false,
false, RecordType::RecordKind::Struct);
2377 std::string fatbinWrapperName =
2379 GlobalOp fatbinWrapper = GlobalOp::create(
2380 builder, loc, fatbinWrapperName, fatbinWrapperType,
2381 true, {}, GlobalLinkageKind::PrivateLinkage);
2382 fatbinWrapper.setSection(fatbinSectionName);
// The magic number identifies the wrapper flavor and depends on `isHIP`.
2384 constexpr unsigned cudaFatMagic = 0x466243b1;
2385 constexpr unsigned hipFatMagic = 0x48495046;
2386 unsigned fatMagic =
isHIP ? hipFatMagic : cudaFatMagic;
2388 auto magicInit = IntAttr::get(intTy, fatMagic);
2389 auto versionInit = IntAttr::get(intTy, 1);
2390 auto fatbinStrSymbol =
2391 mlir::FlatSymbolRefAttr::get(fatbinStr.getSymNameAttr());
2392 auto fatbinInit = GlobalViewAttr::get(voidPtrTy, fatbinStrSymbol);
2394 fatbinWrapper.setInitialValueAttr(cir::ConstRecordAttr::get(
2396 mlir::ArrayAttr::get(&getContext(),
2397 {magicInit, versionInit, fatbinInit, unusedInit})));
// Internal, private global of type void** that holds the GPU binary handle.
2400 std::string gpubinHandleName =
2403 GlobalOp gpuBinHandle = GlobalOp::create(
2404 builder, loc, gpubinHandleName, voidPtrPtrTy,
2405 false, {}, cir::GlobalLinkageKind::InternalLinkage);
2407 gpuBinHandle.setPrivate();
// Runtime registration function: takes a void* and returns a void**.
2412 std::string regFuncName =
2414 FuncType regFuncType = FuncType::get({voidPtrTy}, voidPtrPtrTy);
2415 cir::FuncOp regFunc =
2416 buildRuntimeFunction(builder, regFuncName, loc, regFuncType);
// The module constructor is an internal void() function registered with the
// default constructor priority.
2419 cir::FuncOp moduleCtor = buildRuntimeFunction(
2420 builder, moduleCtorName, loc, FuncType::get({}, voidTy),
2421 GlobalLinkageKind::InternalLinkage);
2423 globalCtorList.emplace_back(moduleCtorName,
2424 cir::GlobalCtorAttr::getDefaultPriority());
2425 builder.setInsertionPointToStart(moduleCtor.addEntryBlock());
// Block-based variant: the entry block tests the handle against null and
// branches to `ifBlock` (register and store the handle) or straight to
// `exitBlock`. NOTE(review): this variant uses buildHIPModuleDtor, so it
// appears to be the HIP path; the selecting condition is not visible.
2433 mlir::Block *entryBlock = builder.getInsertionBlock();
2434 mlir::Region *parent = entryBlock->getParent();
2435 mlir::Block *ifBlock = builder.createBlock(parent);
2436 mlir::Block *exitBlock = builder.createBlock(parent);
2438 mlir::OpBuilder::InsertionGuard guard(builder);
2439 builder.setInsertionPointToEnd(entryBlock);
2440 mlir::Value handle =
2442 auto handlePtrTy = mlir::cast<cir::PointerType>(handle.getType());
2443 mlir::Value nullPtr = builder.
getNullPtr(handlePtrTy, loc);
2444 mlir::Value isNull =
2445 builder.
createCompare(loc, cir::CmpOpKind::eq, handle, nullPtr);
2446 cir::BrCondOp::create(builder, loc, isNull, ifBlock, exitBlock);
2450 mlir::OpBuilder::InsertionGuard guard(builder);
2451 builder.setInsertionPointToStart(ifBlock);
2453 mlir::Value fatbinVoidPtr = builder.
createBitcast(wrapper, voidPtrTy);
2454 cir::CallOp gpuBinaryHandleCall =
2456 mlir::Value gpuBinaryHandle = gpuBinaryHandleCall.getResult();
2458 mlir::Value gpuBinaryHandleGlobal = builder.
createGetGlobal(gpuBinHandle);
2459 builder.
createStore(loc, gpuBinaryHandle, gpuBinaryHandleGlobal);
2460 cir::BrOp::create(builder, loc, exitBlock);
2465 mlir::OpBuilder::InsertionGuard guard(builder);
2466 builder.setInsertionPointToStart(exitBlock);
2467 mlir::Value gHandle =
2470 if (std::optional<FuncOp> regGlobal = buildCUDARegisterGlobals())
// When a module destructor was built, declare `atexit` at module scope and
// take the destructor's address for it.
2473 if (std::optional<FuncOp> dtor = buildHIPModuleDtor()) {
2474 cir::CIRBaseBuilderTy globalBuilder(getContext());
2475 globalBuilder.setInsertionPointToStart(mlirModule.getBody());
2476 FuncOp atexit = buildRuntimeFunction(
2477 globalBuilder,
"atexit", loc,
2478 FuncType::get(PointerType::get(dtor->getFunctionType()), intTy));
2479 mlir::Value dtorFunc = GetGlobalOp::create(
2480 builder, loc, PointerType::get(dtor->getFunctionType()),
2481 mlir::FlatSymbolRefAttr::get(dtor->getSymNameAttr()));
2484 cir::ReturnOp::create(builder, loc);
// Straight-line variant for the non-RDC case: register the fat binary, store
// the returned handle, then call the globals-registration function with it.
2488 if (!astCtx->
getLangOpts().GPURelocatableDeviceCode) {
2496 mlir::Value fatbinVoidPtr = builder.
createBitcast(wrapper, voidPtrTy);
2497 cir::CallOp gpuBinaryHandleCall =
2499 mlir::Value gpuBinaryHandle = gpuBinaryHandleCall.getResult();
2501 mlir::Value gpuBinaryHandleGlobal = builder.
createGetGlobal(gpuBinHandle);
2502 builder.
createStore(loc, gpuBinaryHandle, gpuBinaryHandleGlobal);
2505 if (std::optional<FuncOp> regGlobal = buildCUDARegisterGlobals()) {
2506 builder.
createCallOp(loc, *regGlobal, gpuBinaryHandle);
// `__cudaRegisterFatBinaryEnd` is declared at module scope as void(void**).
2515 cir::CIRBaseBuilderTy globalBuilder(getContext());
2516 globalBuilder.setInsertionPointToStart(mlirModule.getBody());
2518 buildRuntimeFunction(globalBuilder,
"__cudaRegisterFatBinaryEnd", loc,
2519 FuncType::get({voidPtrPtrTy}, voidTy));
2523 llvm_unreachable(
"GPU RDC NYI");
// When a module destructor was built, declare `atexit` at module scope and
// take the destructor's address for it.
2528 if (std::optional<FuncOp> dtor = buildCUDAModuleDtor()) {
2531 cir::CIRBaseBuilderTy globalBuilder(getContext());
2532 globalBuilder.setInsertionPointToStart(mlirModule.getBody());
2533 FuncOp atexit = buildRuntimeFunction(
2534 globalBuilder,
"atexit", loc,
2535 FuncType::get(PointerType::get(dtor->getFunctionType()), intTy));
2536 mlir::Value dtorFunc = GetGlobalOp::create(
2537 builder, loc, PointerType::get(dtor->getFunctionType()),
2538 mlir::FlatSymbolRefAttr::get(dtor->getSymNameAttr()));
2541 cir::ReturnOp::create(builder, loc);
// Builds the CUDA module destructor: an internal void() function that loads
// the GPU binary handle and returns. The call made with the loaded handle,
// the early-exit body and the final return value are not visible in this
// listing.
2544std::optional<FuncOp> LoweringPreparePass::buildCUDAModuleDtor() {
// Requires the module to carry a CUDA binary handle attribute.
2545 if (!mlirModule->getAttr(CIRDialect::getCUDABinaryHandleAttrName()))
2550 VoidType voidTy = VoidType::get(&getContext());
2551 PointerType voidPtrPtrTy = PointerType::get(PointerType::get(voidTy));
2553 mlir::Location loc = mlirModule.getLoc();
2555 cir::CIRBaseBuilderTy builder(getContext());
2556 builder.setInsertionPointToStart(mlirModule.getBody());
// Runtime unregister function: void(void**).
2559 std::string unregisterFuncName =
2561 FuncOp unregisterFunc = buildRuntimeFunction(
2562 builder, unregisterFuncName, loc, FuncType::get({voidPtrPtrTy}, voidTy));
// The destructor itself has internal linkage and takes no arguments.
2571 buildRuntimeFunction(builder, dtorName, loc, FuncType::get({}, voidTy),
2572 GlobalLinkageKind::InternalLinkage);
2574 builder.setInsertionPointToStart(dtor.addEntryBlock());
// The handle global is looked up by name; cast<> asserts that the symbol
// exists and is a GlobalOp.
2580 GlobalOp gpubinGlobal = cast<GlobalOp>(mlirModule.lookupSymbol(gpubinName));
2582 mlir::Value gpubin = builder.
createLoad(loc, gpubinAddress);
2584 ReturnOp::create(builder, loc);
// Builds the HIP module destructor: an internal void() function with three
// blocks. The entry block tests the GPU binary handle against null and
// branches to `ifBlock` when it is non-null; both paths end in `exitBlock`,
// which returns. The contents of `ifBlock` before its branch, the early-exit
// body and the final return value are not visible in this listing.
2601std::optional<FuncOp> LoweringPreparePass::buildHIPModuleDtor() {
// Requires the module to carry a CUDA binary handle attribute.
2602 if (!mlirModule->getAttr(CIRDialect::getCUDABinaryHandleAttrName()))
2607 VoidType voidTy = VoidType::get(&getContext());
2608 PointerType voidPtrPtrTy = PointerType::get(PointerType::get(voidTy));
2610 mlir::Location loc = mlirModule.getLoc();
2612 cir::CIRBaseBuilderTy builder(getContext());
2613 builder.setInsertionPointToStart(mlirModule.getBody());
// Runtime unregister function: void(void**).
2616 std::string unregisterFuncName =
2618 FuncOp unregisterFunc = buildRuntimeFunction(
2619 builder, unregisterFuncName, loc, FuncType::get({voidPtrPtrTy}, voidTy));
2623 buildRuntimeFunction(builder, dtorName, loc, FuncType::get({}, voidTy),
2624 GlobalLinkageKind::InternalLinkage);
// The handle global is looked up by name; cast<> asserts that the symbol
// exists and is a GlobalOp.
2627 GlobalOp gpuBinGlobal = cast<GlobalOp>(mlirModule.lookupSymbol(gpubinName));
2629 mlir::Block *entryBlock = dtor.addEntryBlock();
2630 mlir::Block *ifBlock = builder.createBlock(&dtor.getBody());
2631 mlir::Block *exitBlock = builder.createBlock(&dtor.getBody());
// Entry block: branch on "handle != null".
2633 mlir::OpBuilder::InsertionGuard guard(builder);
2634 builder.setInsertionPointToEnd(entryBlock);
2635 mlir::Value handle =
2637 auto handlePtrTy = mlir::cast<cir::PointerType>(handle.getType());
2638 mlir::Value nullPtr = builder.
getNullPtr(handlePtrTy, loc);
2639 mlir::Value isNotNull =
2640 builder.
createCompare(loc, cir::CmpOpKind::ne, handle, nullPtr);
2641 cir::BrCondOp::create(builder, loc, isNotNull, ifBlock, exitBlock);
// Non-null path: falls through to the exit block once done.
2645 mlir::OpBuilder::InsertionGuard ifGuard(builder);
2646 builder.setInsertionPointToStart(ifBlock);
2649 cir::BrOp::create(builder, loc, exitBlock);
// Exit block: plain return.
2652 mlir::OpBuilder::InsertionGuard exitGuard(builder);
2653 builder.setInsertionPointToStart(exitBlock);
2654 cir::ReturnOp::create(builder, loc);
// Builds the internal void(void**) function that registers the module's CUDA
// kernels with the runtime and returns it. The early-exit body is not visible
// in this listing.
2660std::optional<FuncOp> LoweringPreparePass::buildCUDARegisterGlobals() {
// No kernels recorded means there is nothing to register.
2662 if (cudaKernelMap.empty())
2665 cir::CIRBaseBuilderTy builder(getContext());
2666 builder.setInsertionPointToStart(mlirModule.getBody());
2668 mlir::Location loc = mlirModule.getLoc();
2671 auto voidTy = VoidType::get(&getContext());
2672 auto voidPtrTy = PointerType::get(voidTy);
2673 auto voidPtrPtrTy = PointerType::get(voidPtrTy);
// Create the function with internal linkage and start filling its entry
// block.
2677 std::string regGlobalFuncName =
2679 auto regGlobalFuncTy = FuncType::get({voidPtrPtrTy}, voidTy);
2680 FuncOp regGlobalFunc =
2681 buildRuntimeFunction(builder, regGlobalFuncName, loc, regGlobalFuncTy,
2682 GlobalLinkageKind::InternalLinkage);
2683 builder.setInsertionPointToStart(regGlobalFunc.addEntryBlock());
// Emit the per-kernel registration calls, then terminate the function.
2685 buildCUDARegisterGlobalFunctions(builder, regGlobalFunc);
2689 ReturnOp::create(builder, loc);
2690 return regGlobalFunc;
// Emits, into `regGlobalFunc`, one runtime registration call per kernel in
// `cudaKernelMap`. `builder` is positioned inside `regGlobalFunc`; helper
// declarations and string globals are created through a separate builder at
// the start of the module. Several statements are not visible in this
// listing.
2693void LoweringPreparePass::buildCUDARegisterGlobalFunctions(
2694 cir::CIRBaseBuilderTy &builder, FuncOp regGlobalFunc) {
2695 mlir::Location loc = mlirModule.getLoc();
2697 cir::CIRDataLayout dataLayout(mlirModule);
2699 auto voidTy = VoidType::get(&getContext());
2700 auto voidPtrTy = PointerType::get(voidTy);
2701 auto voidPtrPtrTy = PointerType::get(voidPtrTy);
2703 IntType charTy = cir::IntType::get(&getContext(), astCtx->
getCharWidth(),
// The fat binary handle is the function's first argument.
2707 mlir::Value fatbinHandle = *regGlobalFunc.args_begin();
2709 cir::CIRBaseBuilderTy globalBuilder(getContext());
2710 globalBuilder.setInsertionPointToStart(mlirModule.getBody());
// Declaration of the runtime registration function; its ten parameters are
// (void**, void*, void*, void*, int, void* x5).
2724 FuncOp cudaRegisterFunction = buildRuntimeFunction(
2726 FuncType::get({voidPtrPtrTy, voidPtrTy, voidPtrTy, voidPtrTy, intTy,
2727 voidPtrTy, voidPtrTy, voidPtrTy, voidPtrTy, voidPtrTy},
// Creates a private global char array named ".str<str>" holding `str`; the
// array type has one element more than the string's length.
// NOTE(review): `str + "\0"` appends a C string literal that is empty when
// read up to its first NUL, so the concatenation likely adds no character;
// confirm the terminator is supplied by the array type instead.
2730 auto makeConstantString = [&](llvm::StringRef str) -> GlobalOp {
2731 auto strType = ArrayType::get(&getContext(), charTy, 1 + str.size());
2732 auto tmpString = cir::GlobalOp::create(
2733 globalBuilder, loc, (
".str" + str).str(), strType,
2735 cir::GlobalLinkageKind::PrivateLinkage);
2738 tmpString.setInitialValueAttr(
2739 ConstArrayAttr::get(strType, StringAttr::get(str +
"\0", strType)));
2740 tmpString.setPrivate();
// One null void* constant is shared by all trailing pointer arguments.
2744 cir::ConstantOp cirNullPtr = builder.
getNullPtr(voidPtrTy, loc);
// Per kernel: look up its device stub, create the name string, and compute
// the host-side function value.
2746 for (
auto kernelName : cudaKernelMap.keys()) {
2747 FuncOp deviceStub = cudaKernelMap[kernelName];
2748 GlobalOp deviceFuncStr = makeConstantString(kernelName);
2752 mlir::Value hostFunc;
2759 auto funcHandle = cast<GlobalOp>(mlirModule.lookupSymbol(kernelName));
2764 GetGlobalOp::create(
2765 builder, loc, PointerType::get(deviceStub.getFunctionType()),
2766 mlir::FlatSymbolRefAttr::get(deviceStub.getSymNameAttr())),
// Registration call: handle, host function, the device function name twice,
// the integer -1, then five null pointers.
2770 loc, cudaRegisterFunction,
2771 {fatbinHandle, hostFunc, deviceFunc, deviceFunc,
2772 ConstantOp::create(builder, loc, IntAttr::get(intTy, -1)), cirNullPtr,
2773 cirNullPtr, cirNullPtr, cirNullPtr, cirNullPtr});
// Pass entry point: collects the operations of interest, processes each of
// them, then emits the module-level artifacts (global init function, TLS
// init function, CUDA module constructor, ctor/dtor attribute lists). The
// loop body and the end of the walk callback are not visible in this listing.
2777void LoweringPreparePass::runOnOperation() {
2778 mlir::Operation *op = getOperation();
// Remember the module when the pass is run on one.
2779 if (isa<::mlir::ModuleOp>(op))
2780 mlirModule = cast<::mlir::ModuleOp>(op);
// Collect first, transform afterwards, so the walk does not iterate over IR
// that is being rewritten.
2782 llvm::SmallVector<mlir::Operation *> opsToTransform;
2784 op->walk([&](mlir::Operation *op) {
2785 if (mlir::isa<cir::ArrayCtor, cir::ArrayDtor, cir::CastOp,
2786 cir::ComplexMulOp, cir::ComplexDivOp, cir::DynamicCastOp,
2787 cir::FuncOp, cir::CallOp, cir::GetGlobalOp, cir::GlobalOp,
2788 cir::StoreOp, cir::CmpThreeWayOp, cir::IncOp, cir::DecOp,
2789 cir::MinusOp, cir::NotOp, cir::LocalInitOp>(op))
2790 opsToTransform.push_back(op);
2793 for (mlir::Operation *o : opsToTransform)
// Module-level emission runs after all per-op work; the ctor/dtor lists are
// published last because the builders above may add entries to them.
2796 buildCXXGlobalInitFunc();
2797 buildCXXGlobalTlsFunc();
2799 buildCUDAModuleCtor();
2801 buildGlobalCtorDtorList();
// Fragment of a factory function (signature not visible here): returns a
// default-constructed LoweringPreparePass.
2805 return std::make_unique<LoweringPreparePass>();
// Factory that creates a LoweringPreparePass and attaches the AST context
// `astCtx` to it. The line carrying the function name and parameters is not
// visible in this listing.
2808std::unique_ptr<Pass>
2810 auto pass = std::make_unique<LoweringPreparePass>();
2811 pass->setASTContext(astCtx);
// std::move turns the unique_ptr<LoweringPreparePass> into the
// unique_ptr<Pass> return type.
2812 return std::move(pass);
Defines the clang::ASTContext interface.
static void emitBody(CodeGenFunction &CGF, const Stmt *S, const Stmt *NextLoop, int MaxLevel, int Level=0)
static llvm::FunctionCallee getGuardReleaseFn(CodeGenModule &CGM, llvm::PointerType *GuardPtrTy)
static llvm::FunctionCallee getGuardAcquireFn(CodeGenModule &CGM, llvm::PointerType *GuardPtrTy)
static mlir::Value buildRangeReductionComplexDiv(CIRBaseBuilderTy &builder, mlir::Location loc, mlir::Value lhsReal, mlir::Value lhsImag, mlir::Value rhsReal, mlir::Value rhsImag)
static llvm::StringRef getComplexDivLibCallName(llvm::APFloat::Semantics semantics)
static llvm::SmallVector< mlir::Attribute > prepareCtorDtorAttrList(mlir::MLIRContext *context, llvm::ArrayRef< std::pair< std::string, uint32_t > > list)
static llvm::StringRef getComplexMulLibCallName(llvm::APFloat::Semantics semantics)
static cir::GlobalLinkageKind getThreadLocalWrapperLinkage(GlobalOp op, clang::ASTContext &astCtx)
static mlir::Value buildComplexBinOpLibCall(LoweringPreparePass &pass, CIRBaseBuilderTy &builder, llvm::StringRef(*libFuncNameGetter)(llvm::APFloat::Semantics), mlir::Location loc, cir::ComplexType ty, mlir::Value lhsReal, mlir::Value lhsImag, mlir::Value rhsReal, mlir::Value rhsImag)
static mlir::Value lowerComplexMul(LoweringPreparePass &pass, CIRBaseBuilderTy &builder, mlir::Location loc, cir::ComplexMulOp op, mlir::Value lhsReal, mlir::Value lhsImag, mlir::Value rhsReal, mlir::Value rhsImag)
static std::string addUnderscoredPrefix(llvm::StringRef prefix, llvm::StringRef name)
static SmallString< 128 > getTransformedFileName(mlir::ModuleOp mlirModule)
static mlir::Value lowerComplexToComplexCast(mlir::MLIRContext &ctx, cir::CastOp op, cir::CastKind scalarCastKind)
static void lowerArrayDtorCtorIntoLoop(cir::CIRBaseBuilderTy &builder, clang::ASTContext *astCtx, mlir::Operation *op, mlir::Type eltTy, mlir::Value addr, mlir::Value numElements, uint64_t arrayLen, bool isCtor)
Lower a cir.array.ctor or cir.array.dtor into a do-while loop that iterates over every element.
static bool isThreadWrapperReplaceable(cir::TLS_Model tls, clang::ASTContext &astCtx)
static mlir::Value lowerComplexToScalarCast(mlir::MLIRContext &ctx, cir::CastOp op, cir::CastKind elemToBoolKind)
static mlir::Value buildAlgebraicComplexDiv(CIRBaseBuilderTy &builder, mlir::Location loc, mlir::Value lhsReal, mlir::Value lhsImag, mlir::Value rhsReal, mlir::Value rhsImag)
static llvm::StringRef getCUDAPrefix(clang::ASTContext *astCtx)
static mlir::Type higherPrecisionElementTypeForComplexArithmetic(mlir::MLIRContext &context, clang::ASTContext &cc, CIRBaseBuilderTy &builder, mlir::Type elementType)
static mlir::Value lowerScalarToComplexCast(mlir::MLIRContext &ctx, cir::CastOp op)
static mlir::Value lowerComplexDiv(LoweringPreparePass &pass, CIRBaseBuilderTy &builder, mlir::Location loc, cir::ComplexDivOp op, mlir::Value lhsReal, mlir::Value lhsImag, mlir::Value rhsReal, mlir::Value rhsImag, mlir::MLIRContext &mlirCx, clang::ASTContext &cc)
Defines the clang::Module class, which describes a module in the source code.
static bool compare(const PathDiagnostic &X, const PathDiagnostic &Y)
Defines the SourceManager interface.
Defines various enumerations that describe declaration and type specifiers.
Defines the TargetCXXABI class, which abstracts details of the C++ ABI that we're targeting.
__device__ __2f16 float c
mlir::Value createDiv(mlir::Location loc, mlir::Value lhs, mlir::Value rhs)
mlir::TypedAttr getConstNullPtrAttr(mlir::Type t)
mlir::Value createDec(mlir::Location loc, mlir::Value input, bool nsw=false)
mlir::Value createLogicalOr(mlir::Location loc, mlir::Value lhs, mlir::Value rhs)
mlir::Value createSub(mlir::Location loc, mlir::Value lhs, mlir::Value rhs, OverflowBehavior ob=OverflowBehavior::None)
cir::ConditionOp createCondition(mlir::Value condition)
Create a loop condition.
mlir::Value createInc(mlir::Location loc, mlir::Value input, bool nsw=false)
cir::CopyOp createCopy(mlir::Value dst, mlir::Value src, bool isVolatile=false, bool skipTailPadding=false)
Create a copy with inferred length.
cir::VoidType getVoidTy()
cir::ConstantOp getNullValue(mlir::Type ty, mlir::Location loc)
mlir::Value createCast(mlir::Location loc, cir::CastKind kind, mlir::Value src, mlir::Type newTy)
cir::PointerType getVoidFnPtrTy(mlir::TypeRange argTypes={})
Returns void (*)(T...) as a cir::PointerType.
mlir::Value createAdd(mlir::Location loc, mlir::Value lhs, mlir::Value rhs, OverflowBehavior ob=OverflowBehavior::None)
cir::PointerType getPointerTo(mlir::Type ty)
mlir::Value createComplexImag(mlir::Location loc, mlir::Value operand)
cir::ConstantOp getNullPtr(mlir::Type ty, mlir::Location loc)
cir::DoWhileOp createDoWhile(mlir::Location loc, llvm::function_ref< void(mlir::OpBuilder &, mlir::Location)> condBuilder, llvm::function_ref< void(mlir::OpBuilder &, mlir::Location)> bodyBuilder)
Create a do-while operation.
cir::GetGlobalOp createGetGlobal(mlir::Location loc, cir::GlobalOp global, bool threadLocal=false)
mlir::Value getSignedInt(mlir::Location loc, int64_t val, unsigned numBits)
mlir::Value createAnd(mlir::Location loc, mlir::Value lhs, mlir::Value rhs)
mlir::Value createBitcast(mlir::Value src, mlir::Type newTy)
cir::FuncType getVoidFnTy(mlir::TypeRange argTypes={})
Returns void (T...) as a cir::FuncType.
cir::CmpOp createCompare(mlir::Location loc, cir::CmpOpKind kind, mlir::Value lhs, mlir::Value rhs)
mlir::IntegerAttr getAlignmentAttr(clang::CharUnits alignment)
mlir::Value createSelect(mlir::Location loc, mlir::Value condition, mlir::Value trueValue, mlir::Value falseValue)
mlir::Value createMul(mlir::Location loc, mlir::Value lhs, mlir::Value rhs, OverflowBehavior ob=OverflowBehavior::None)
cir::LoadOp createLoad(mlir::Location loc, mlir::Value ptr, bool isVolatile=false, uint64_t alignment=0)
mlir::Value createMinus(mlir::Location loc, mlir::Value input, bool nsw=false)
cir::ConstantOp getConstantInt(mlir::Location loc, mlir::Type ty, int64_t value)
mlir::Value createComplexCreate(mlir::Location loc, mlir::Value real, mlir::Value imag)
cir::PointerType getVoidPtrTy(clang::LangAS langAS=clang::LangAS::Default)
mlir::Value createIsNaN(mlir::Location loc, mlir::Value operand)
cir::IntType getSIntNTy(int n)
mlir::Value createAlignedLoad(mlir::Location loc, mlir::Value ptr, uint64_t alignment)
cir::CallOp createCallOp(mlir::Location loc, mlir::SymbolRefAttr callee, mlir::Type returnType, mlir::ValueRange operands, llvm::ArrayRef< mlir::NamedAttribute > attrs={}, llvm::ArrayRef< mlir::NamedAttrList > argAttrs={}, llvm::ArrayRef< mlir::NamedAttribute > resAttrs={})
cir::StoreOp createStore(mlir::Location loc, mlir::Value val, mlir::Value dst, bool isVolatile=false, mlir::IntegerAttr align={}, cir::SyncScopeKindAttr scope={}, cir::MemOrderAttr order={})
cir::YieldOp createYield(mlir::Location loc, mlir::ValueRange value={})
Create a yield operation.
mlir::Value createLogicalAnd(mlir::Location loc, mlir::Value lhs, mlir::Value rhs)
mlir::Value createAlloca(mlir::Location loc, cir::PointerType addrType, mlir::Type type, llvm::StringRef name, mlir::IntegerAttr alignment, mlir::Value dynAllocSize)
cir::BoolType getBoolTy()
mlir::Value getUnsignedInt(mlir::Location loc, uint64_t val, unsigned numBits)
mlir::Value createComplexReal(mlir::Location loc, mlir::Value operand)
Holds long-lived AST nodes (such as types and decls) that can be referred to throughout the semantic analysis of a file.
SourceManager & getSourceManager()
MangleContext * createMangleContext(const TargetInfo *T=nullptr)
If T is null pointer, assume the target in ASTContext.
const LangOptions & getLangOpts() const
uint64_t getTypeSize(QualType T) const
Return the size of the specified (complete) type T, in bits.
const TargetInfo & getTargetInfo() const
QualType getSignedSizeType() const
Return the unique signed counterpart of the integer type corresponding to size_t.
Module * getCurrentNamedModule() const
Get module under construction, nullptr if this is not a C++20 module.
uint64_t getCharWidth() const
Return the size of the character type, in bits.
llvm::Align getAsAlign() const
getAsAlign - Returns Quantity as a valid llvm::Align. Beware: llvm::Align assumes power-of-two 8-bit bytes.
QuantityType getQuantity() const
getQuantity - Get the raw integer representation of this quantity.
static CharUnits One()
One - Construct a CharUnits quantity of one.
static CharUnits fromQuantity(QuantityType Quantity)
fromQuantity - Construct a CharUnits quantity from a raw integer type.
llvm::vfs::FileSystem & getVirtualFileSystem() const
bool isModuleImplementation() const
Is this a module implementation.
FileManager & getFileManager() const
Exposes information about the current target.
const llvm::Triple & getTriple() const
Returns the target triple of the primary target.
unsigned getMaxAtomicInlineWidth() const
Return the maximum width lock-free atomic operation which can be inlined given the supported features of the given target.
const llvm::fltSemantics & getDoubleFormat() const
const llvm::fltSemantics & getHalfFormat() const
const llvm::fltSemantics & getBFloat16Format() const
const llvm::fltSemantics & getLongDoubleFormat() const
const llvm::fltSemantics & getFloatFormat() const
const llvm::fltSemantics & getFloat128Format() const
const llvm::VersionTuple & getSDKVersion() const
Defines the clang::TargetInfo interface.
static bool isLocalLinkage(GlobalLinkageKind linkage)
static bool isWeakODRLinkage(GlobalLinkageKind linkage)
static bool isLinkOnceLinkage(GlobalLinkageKind linkage)
const internal::VariadicDynCastAllOfMatcher< Decl, VarDecl > varDecl
Matches variable declarations.
bool isHIP(ID Id)
isHIP - Is this a HIP input.
bool isTemplateInstantiation(TemplateSpecializationKind Kind)
Determine whether this template specialization kind refers to an instantiation of an entity (as opposed to a non-template or an explicit specialization).
bool CudaFeatureEnabled(llvm::VersionTuple, CudaFeature)
LLVM_READONLY bool isPreprocessingNumberBody(unsigned char c)
Return true if this is the body character of a C preprocessing number, which is [a-zA-Z0-9_.].
@ CUDA_USES_FATBIN_REGISTER_END
std::unique_ptr< Pass > createLoweringPreparePass()
static bool hipModuleCtor()
static bool guardAbortOnException()
static bool opGlobalAnnotations()
static bool opGlobalCtorPriority()
static bool shouldSplitConstantStore()
static bool shouldUseMemSetToInitialize()
static bool opFuncExtraAttrs()
static bool shouldUseBZeroPlusStoresToInitialize()
static bool globalRegistration()
static bool fastMathFlags()
static bool astVarDeclInterface()