19#include "llvm/Support/Path.h"
27 SmallString<128> fileName;
29 if (mlirModule.getSymName())
30 fileName = llvm::sys::path::filename(mlirModule.getSymName()->str());
35 for (
size_t i = 0; i < fileName.size(); ++i) {
47 mlir::SymbolRefAttr sym = llvm::dyn_cast_if_present<mlir::SymbolRefAttr>(
48 callOp.getCallableForCallee());
51 return dyn_cast_or_null<cir::FuncOp>(
52 mlir::SymbolTable::lookupNearestSymbolFrom(callOp, sym));
/// Pass implementation derived from LoweringPrepareBase. It declares one
/// lowering hook per handled CIR operation kind, helpers for emitting runtime
/// functions and globals, and the per-module state those helpers share.
56struct LoweringPreparePass :
public LoweringPrepareBase<LoweringPreparePass> {
57 LoweringPreparePass() =
default;
58 void runOnOperation()
override;
// Per-operation entry points: runOnOp takes a generic operation, the others
// each take one concrete CIR op type.
60 void runOnOp(mlir::Operation *op);
61 void lowerCastOp(cir::CastOp op);
62 void lowerComplexDivOp(cir::ComplexDivOp op);
63 void lowerComplexMulOp(cir::ComplexMulOp op);
64 void lowerUnaryOp(cir::UnaryOp op);
65 void lowerGlobalOp(cir::GlobalOp op);
66 void lowerDynamicCastOp(cir::DynamicCastOp op);
67 void lowerArrayDtor(cir::ArrayDtor op);
68 void lowerArrayCtor(cir::ArrayCtor op);
// Builds the initializer function for a single global and returns it.
71 cir::FuncOp buildCXXGlobalVarDeclInitFunc(cir::GlobalOp op);
// Module-level steps invoked from runOnOperation.
74 void buildCXXGlobalInitFunc();
77 void buildGlobalCtorDtorList();
// Runtime symbol helpers. When the caller does not specify them, linkage
// defaults to ExternalLinkage and visibility to Default.
79 cir::FuncOp buildRuntimeFunction(
80 mlir::OpBuilder &builder, llvm::StringRef name, mlir::Location loc,
82 cir::GlobalLinkageKind linkage = cir::GlobalLinkageKind::ExternalLinkage);
84 cir::GlobalOp buildRuntimeVariable(
85 mlir::OpBuilder &builder, llvm::StringRef name, mlir::Location loc,
87 cir::GlobalLinkageKind linkage = cir::GlobalLinkageKind::ExternalLinkage,
88 cir::VisibilityKind visibility = cir::VisibilityKind::Default);
// Clang AST context (raw pointer); see setASTContext below.
94 clang::ASTContext *astCtx;
// Shared handle to the C++ ABI specific lowering helper.
97 std::shared_ptr<cir::LoweringPrepareCXXABI> cxxABI;
// Module operation the pass is working on.
100 mlir::ModuleOp mlirModule;
// Counter per initializer function name, and the initializer functions
// collected so far.
103 llvm::StringMap<uint32_t> dynamicInitializerNames;
104 llvm::SmallVector<cir::FuncOp> dynamicInitializers;
// (function name, priority) pairs for global constructors / destructors.
107 llvm::SmallVector<std::pair<std::string, uint32_t>, 4> globalCtorList;
109 llvm::SmallVector<std::pair<std::string, uint32_t>, 4> globalDtorList;
// Inspects the C++ ABI kind reported by the given AST context.
// NOTE(review): the statements inside each case are not visible here; the
// trailing llvm_unreachable("NYI") presumably covers the remaining ABI kinds.
111 void setASTContext(clang::ASTContext *
c) {
113 switch (
c->getCXXABIKind()) {
114 case clang::TargetCXXABI::GenericItanium:
120 case clang::TargetCXXABI::GenericAArch64:
121 case clang::TargetCXXABI::AppleARM64:
126 llvm_unreachable(
"NYI");
/// Looks for a cir::GlobalOp named `name` starting from mlirModule; the
/// creation path below builds one with the given type, linkage and
/// visibility at the builder's current insertion point.
133cir::GlobalOp LoweringPreparePass::buildRuntimeVariable(
134 mlir::OpBuilder &builder, llvm::StringRef name, mlir::Location loc,
135 mlir::Type type, cir::GlobalLinkageKind linkage,
136 cir::VisibilityKind visibility) {
// dyn_cast_or_null leaves g null when the symbol is missing or is not a
// GlobalOp.
137 cir::GlobalOp g = dyn_cast_or_null<cir::GlobalOp>(
138 mlir::SymbolTable::lookupNearestSymbolFrom(
139 mlirModule, mlir::StringAttr::get(mlirModule->getContext(), name)));
// NOTE(review): presumably only executed when the lookup found nothing; the
// guarding condition is not visible here.
141 g = cir::GlobalOp::create(builder, loc, name, type);
143 cir::GlobalLinkageKindAttr::get(builder.getContext(), linkage));
// The MLIR symbol visibility is forced to Private, while the CIR global
// visibility attribute carries the caller's `visibility` value.
144 mlir::SymbolTable::setSymbolVisibility(
145 g, mlir::SymbolTable::Visibility::Private);
146 g.setGlobalVisibilityAttr(
147 cir::VisibilityAttr::get(builder.getContext(), visibility));
/// Looks for a cir::FuncOp named `name` starting from mlirModule; the
/// creation path below builds one with the given function type and linkage
/// at the builder's current insertion point.
152cir::FuncOp LoweringPreparePass::buildRuntimeFunction(
153 mlir::OpBuilder &builder, llvm::StringRef name, mlir::Location loc,
154 cir::FuncType type, cir::GlobalLinkageKind linkage) {
// dyn_cast_or_null leaves f null when the symbol is missing or is not a
// FuncOp.
155 cir::FuncOp f = dyn_cast_or_null<FuncOp>(SymbolTable::lookupNearestSymbolFrom(
156 mlirModule, StringAttr::get(mlirModule->getContext(), name)));
// NOTE(review): presumably only executed when the lookup found nothing; the
// guarding condition is not visible here.
158 f = builder.create<cir::FuncOp>(loc,
name,
type);
160 cir::GlobalLinkageKindAttr::get(builder.getContext(), linkage));
// The MLIR symbol visibility of the function is set to Private.
161 mlir::SymbolTable::setSymbolVisibility(
162 f, mlir::SymbolTable::Visibility::Private);
172 builder.setInsertionPoint(op);
174 mlir::Value src = op.getSrc();
175 mlir::Value imag = builder.
getNullValue(src.getType(), op.getLoc());
181 cir::CastKind elemToBoolKind) {
183 builder.setInsertionPoint(op);
185 mlir::Value src = op.getSrc();
186 if (!mlir::isa<cir::BoolType>(op.getType()))
193 cir::BoolType boolTy = builder.
getBoolTy();
194 mlir::Value srcRealToBool =
195 builder.
createCast(op.getLoc(), elemToBoolKind, srcReal, boolTy);
196 mlir::Value srcImagToBool =
197 builder.
createCast(op.getLoc(), elemToBoolKind, srcImag, boolTy);
198 return builder.
createLogicalOr(op.getLoc(), srcRealToBool, srcImagToBool);
203 cir::CastKind scalarCastKind) {
205 builder.setInsertionPoint(op);
207 mlir::Value src = op.getSrc();
208 auto dstComplexElemTy =
209 mlir::cast<cir::ComplexType>(op.getType()).getElementType();
214 mlir::Value dstReal = builder.
createCast(op.getLoc(), scalarCastKind, srcReal,
216 mlir::Value dstImag = builder.
createCast(op.getLoc(), scalarCastKind, srcImag,
/// Lowers a cir::CastOp whose kind involves complex values: a replacement
/// value is computed by switching on the cast kind, and every use of the
/// cast is redirected to it.
221void LoweringPreparePass::lowerCastOp(cir::CastOp op) {
222 mlir::MLIRContext &ctx = getContext();
// The replacement value comes from a lambda that switches on op.getKind().
// NOTE(review): the body of each case is not visible here; the case labels
// cover scalar->complex, complex->real, complex->bool and complex->complex
// kinds.
223 mlir::Value loweredValue = [&]() -> mlir::Value {
224 switch (op.getKind()) {
225 case cir::CastKind::float_to_complex:
226 case cir::CastKind::int_to_complex:
228 case cir::CastKind::float_complex_to_real:
229 case cir::CastKind::int_complex_to_real:
231 case cir::CastKind::float_complex_to_bool:
233 case cir::CastKind::int_complex_to_bool:
235 case cir::CastKind::float_complex:
237 case cir::CastKind::float_complex_to_int_complex:
239 case cir::CastKind::int_complex:
241 case cir::CastKind::int_complex_to_float_complex:
// Redirect all users of the original cast to the lowered value.
249 op.replaceAllUsesWith(loweredValue);
256 llvm::StringRef (*libFuncNameGetter)(llvm::APFloat::Semantics),
257 mlir::Location loc, cir::ComplexType ty, mlir::Value lhsReal,
258 mlir::Value lhsImag, mlir::Value rhsReal, mlir::Value rhsImag) {
259 cir::FPTypeInterface elementTy =
260 mlir::cast<cir::FPTypeInterface>(ty.getElementType());
262 llvm::StringRef libFuncName = libFuncNameGetter(
263 llvm::APFloat::SemanticsToEnum(elementTy.getFloatSemantics()));
266 cir::FuncType libFuncTy = cir::FuncType::get(libFuncInputTypes, ty);
272 mlir::OpBuilder::InsertionGuard ipGuard{builder};
273 builder.setInsertionPointToStart(pass.mlirModule.getBody());
274 libFunc = pass.buildRuntimeFunction(builder, libFuncName, loc, libFuncTy);
278 builder.
createCallOp(loc, libFunc, {lhsReal, lhsImag, rhsReal, rhsImag});
279 return call.getResult();
282static llvm::StringRef
285 case llvm::APFloat::S_IEEEhalf:
287 case llvm::APFloat::S_IEEEsingle:
289 case llvm::APFloat::S_IEEEdouble:
291 case llvm::APFloat::S_PPCDoubleDouble:
293 case llvm::APFloat::S_x87DoubleExtended:
295 case llvm::APFloat::S_IEEEquad:
298 llvm_unreachable(
"unsupported floating point type");
304 mlir::Value lhsReal, mlir::Value lhsImag,
305 mlir::Value rhsReal, mlir::Value rhsImag) {
307 mlir::Value &a = lhsReal;
308 mlir::Value &
b = lhsImag;
309 mlir::Value &
c = rhsReal;
310 mlir::Value &d = rhsImag;
312 mlir::Value ac = builder.
createBinop(loc, a, cir::BinOpKind::Mul,
c);
313 mlir::Value bd = builder.
createBinop(loc,
b, cir::BinOpKind::Mul, d);
314 mlir::Value cc = builder.
createBinop(loc,
c, cir::BinOpKind::Mul,
c);
315 mlir::Value dd = builder.
createBinop(loc, d, cir::BinOpKind::Mul, d);
317 builder.
createBinop(loc, ac, cir::BinOpKind::Add, bd);
319 builder.
createBinop(loc, cc, cir::BinOpKind::Add, dd);
320 mlir::Value resultReal =
321 builder.
createBinop(loc, acbd, cir::BinOpKind::Div, ccdd);
323 mlir::Value bc = builder.
createBinop(loc,
b, cir::BinOpKind::Mul,
c);
324 mlir::Value ad = builder.
createBinop(loc, a, cir::BinOpKind::Mul, d);
326 builder.
createBinop(loc, bc, cir::BinOpKind::Sub, ad);
327 mlir::Value resultImag =
328 builder.
createBinop(loc, bcad, cir::BinOpKind::Div, ccdd);
334 mlir::Value lhsReal, mlir::Value lhsImag,
335 mlir::Value rhsReal, mlir::Value rhsImag) {
356 mlir::Value &a = lhsReal;
357 mlir::Value &
b = lhsImag;
358 mlir::Value &
c = rhsReal;
359 mlir::Value &d = rhsImag;
361 auto trueBranchBuilder = [&](mlir::OpBuilder &, mlir::Location) {
362 mlir::Value r = builder.
createBinop(loc, d, cir::BinOpKind::Div,
364 mlir::Value rd = builder.
createBinop(loc, r, cir::BinOpKind::Mul, d);
365 mlir::Value tmp = builder.
createBinop(loc,
c, cir::BinOpKind::Add,
368 mlir::Value br = builder.
createBinop(loc,
b, cir::BinOpKind::Mul, r);
370 builder.
createBinop(loc, a, cir::BinOpKind::Add, br);
371 mlir::Value e = builder.
createBinop(loc, abr, cir::BinOpKind::Div, tmp);
373 mlir::Value ar = builder.
createBinop(loc, a, cir::BinOpKind::Mul, r);
376 mlir::Value f = builder.
createBinop(loc, bar, cir::BinOpKind::Div, tmp);
382 auto falseBranchBuilder = [&](mlir::OpBuilder &, mlir::Location) {
383 mlir::Value r = builder.
createBinop(loc,
c, cir::BinOpKind::Div,
385 mlir::Value rc = builder.
createBinop(loc, r, cir::BinOpKind::Mul,
c);
386 mlir::Value tmp = builder.
createBinop(loc, d, cir::BinOpKind::Add,
389 mlir::Value ar = builder.
createBinop(loc, a, cir::BinOpKind::Mul, r);
392 mlir::Value e = builder.
createBinop(loc, arb, cir::BinOpKind::Div, tmp);
394 mlir::Value br = builder.
createBinop(loc,
b, cir::BinOpKind::Mul, r);
396 builder.
createBinop(loc, br, cir::BinOpKind::Sub, a);
397 mlir::Value f = builder.
createBinop(loc, bra, cir::BinOpKind::Div, tmp);
403 auto cFabs = builder.create<cir::FAbsOp>(loc,
c);
404 auto dFabs = builder.create<cir::FAbsOp>(loc, d);
405 cir::CmpOp cmpResult =
406 builder.
createCompare(loc, cir::CmpOpKind::ge, cFabs, dFabs);
407 auto ternary = builder.create<cir::TernaryOp>(
408 loc, cmpResult, trueBranchBuilder, falseBranchBuilder);
410 return ternary.getResult();
417 auto getHigherPrecisionFPType = [&context](mlir::Type type) -> mlir::Type {
418 if (mlir::isa<cir::FP16Type>(type))
419 return cir::SingleType::get(&context);
421 if (mlir::isa<cir::SingleType>(type) || mlir::isa<cir::BF16Type>(type))
422 return cir::DoubleType::get(&context);
424 if (mlir::isa<cir::DoubleType>(type))
425 return cir::LongDoubleType::get(&context, type);
430 auto getFloatTypeSemantics =
431 [&cc](mlir::Type type) ->
const llvm::fltSemantics & {
433 if (mlir::isa<cir::FP16Type>(type))
436 if (mlir::isa<cir::BF16Type>(type))
439 if (mlir::isa<cir::SingleType>(type))
442 if (mlir::isa<cir::DoubleType>(type))
445 if (mlir::isa<cir::LongDoubleType>(type)) {
447 llvm_unreachable(
"NYI Float type semantics with OpenMP");
451 if (mlir::isa<cir::FP128Type>(type)) {
453 llvm_unreachable(
"NYI Float type semantics with OpenMP");
457 assert(
false &&
"Unsupported float type semantics");
460 const mlir::Type higherElementType = getHigherPrecisionFPType(elementType);
461 const llvm::fltSemantics &elementTypeSemantics =
462 getFloatTypeSemantics(elementType);
463 const llvm::fltSemantics &higherElementTypeSemantics =
464 getFloatTypeSemantics(higherElementType);
473 if (llvm::APFloat::semanticsMaxExponent(elementTypeSemantics) * 2 + 1 <=
474 llvm::APFloat::semanticsMaxExponent(higherElementTypeSemantics)) {
475 return higherElementType;
485 mlir::Location loc, cir::ComplexDivOp op, mlir::Value lhsReal,
486 mlir::Value lhsImag, mlir::Value rhsReal, mlir::Value rhsImag,
488 cir::ComplexType complexTy = op.getType();
489 if (mlir::isa<cir::FPTypeInterface>(complexTy.getElementType())) {
490 cir::ComplexRangeKind range = op.getRange();
491 if (range == cir::ComplexRangeKind::Improved)
495 if (range == cir::ComplexRangeKind::Full)
497 loc, complexTy, lhsReal, lhsImag, rhsReal,
500 if (range == cir::ComplexRangeKind::Promoted) {
501 mlir::Type originalElementType = complexTy.getElementType();
502 mlir::Type higherPrecisionElementType =
504 originalElementType);
506 if (!higherPrecisionElementType)
510 cir::CastKind floatingCastKind = cir::CastKind::floating;
511 lhsReal = builder.
createCast(floatingCastKind, lhsReal,
512 higherPrecisionElementType);
513 lhsImag = builder.
createCast(floatingCastKind, lhsImag,
514 higherPrecisionElementType);
515 rhsReal = builder.
createCast(floatingCastKind, rhsReal,
516 higherPrecisionElementType);
517 rhsImag = builder.
createCast(floatingCastKind, rhsImag,
518 higherPrecisionElementType);
521 builder, loc, lhsReal, lhsImag, rhsReal, rhsImag);
526 mlir::Value finalReal =
527 builder.
createCast(floatingCastKind, resultReal, originalElementType);
528 mlir::Value finalImag =
529 builder.
createCast(floatingCastKind, resultImag, originalElementType);
/// Lowers a cir::ComplexDivOp: new IR is emitted right after the op and all
/// uses of the op are redirected to the lowered result.
538void LoweringPreparePass::lowerComplexDivOp(cir::ComplexDivOp op) {
539 cir::CIRBaseBuilderTy builder(getContext());
540 builder.setInsertionPointAfter(op);
541 mlir::Location loc = op.getLoc();
542 mlir::TypedValue<cir::ComplexType> lhs = op.getLhs();
543 mlir::TypedValue<cir::ComplexType> rhs = op.getRhs();
// astCtx is dereferenced in the call below, so it must be non-null here.
// NOTE(review): how lhs/rhs are split into real and imaginary parts is not
// visible here.
549 mlir::Value loweredResult =
551 rhsImag, getContext(), *astCtx);
552 op.replaceAllUsesWith(loweredResult);
556static llvm::StringRef
559 case llvm::APFloat::S_IEEEhalf:
561 case llvm::APFloat::S_IEEEsingle:
563 case llvm::APFloat::S_IEEEdouble:
565 case llvm::APFloat::S_PPCDoubleDouble:
567 case llvm::APFloat::S_x87DoubleExtended:
569 case llvm::APFloat::S_IEEEquad:
572 llvm_unreachable(
"unsupported floating point type");
578 mlir::Location loc, cir::ComplexMulOp op,
579 mlir::Value lhsReal, mlir::Value lhsImag,
580 mlir::Value rhsReal, mlir::Value rhsImag) {
582 mlir::Value resultRealLhs =
583 builder.
createBinop(loc, lhsReal, cir::BinOpKind::Mul, rhsReal);
584 mlir::Value resultRealRhs =
585 builder.
createBinop(loc, lhsImag, cir::BinOpKind::Mul, rhsImag);
586 mlir::Value resultImagLhs =
587 builder.
createBinop(loc, lhsReal, cir::BinOpKind::Mul, rhsImag);
588 mlir::Value resultImagRhs =
589 builder.
createBinop(loc, lhsImag, cir::BinOpKind::Mul, rhsReal);
591 loc, resultRealLhs, cir::BinOpKind::Sub, resultRealRhs);
593 loc, resultImagLhs, cir::BinOpKind::Add, resultImagRhs);
594 mlir::Value algebraicResult =
597 cir::ComplexType complexTy = op.getType();
598 cir::ComplexRangeKind rangeKind = op.getRange();
599 if (mlir::isa<cir::IntType>(complexTy.getElementType()) ||
600 rangeKind == cir::ComplexRangeKind::Basic ||
601 rangeKind == cir::ComplexRangeKind::Improved ||
602 rangeKind == cir::ComplexRangeKind::Promoted)
603 return algebraicResult;
610 mlir::Value resultRealIsNaN = builder.
createIsNaN(loc, resultReal);
611 mlir::Value resultImagIsNaN = builder.
createIsNaN(loc, resultImag);
612 mlir::Value resultRealAndImagAreNaN =
616 .create<cir::TernaryOp>(
617 loc, resultRealAndImagAreNaN,
618 [&](mlir::OpBuilder &, mlir::Location) {
621 lhsReal, lhsImag, rhsReal, rhsImag);
624 [&](mlir::OpBuilder &, mlir::Location) {
/// Lowers a cir::ComplexMulOp: new IR is emitted right after the op through
/// lowerComplexMul and all uses of the op are redirected to its result.
630void LoweringPreparePass::lowerComplexMulOp(cir::ComplexMulOp op) {
631 cir::CIRBaseBuilderTy builder(getContext());
632 builder.setInsertionPointAfter(op);
633 mlir::Location loc = op.getLoc();
634 mlir::TypedValue<cir::ComplexType> lhs = op.getLhs();
635 mlir::TypedValue<cir::ComplexType> rhs = op.getRhs();
// NOTE(review): how lhs/rhs are split into real and imaginary parts is not
// visible here.
640 mlir::Value loweredResult =
lowerComplexMul(*
this, builder, loc, op, lhsReal,
641 lhsImag, rhsReal, rhsImag);
642 op.replaceAllUsesWith(loweredResult);
/// Lowers a cir::UnaryOp on a complex value into operations on its real and
/// imaginary parts, then redirects all uses of the op to the new result.
646void LoweringPreparePass::lowerUnaryOp(cir::UnaryOp op) {
647 mlir::Type ty = op.getType();
// Only complex-typed unary ops are of interest.
// NOTE(review): the statement taken for other types is not visible here.
648 if (!mlir::isa<cir::ComplexType>(ty))
651 mlir::Location loc = op.getLoc();
652 cir::UnaryOpKind opKind = op.getKind();
// New operations are emitted right after the original op.
654 CIRBaseBuilderTy builder(getContext());
655 builder.setInsertionPointAfter(op);
657 mlir::Value operand = op.getInput();
661 mlir::Value resultReal;
662 mlir::Value resultImag;
// Inc/Dec: apply the operation to the real part, keep the imaginary part.
665 case cir::UnaryOpKind::Inc:
666 case cir::UnaryOpKind::Dec:
667 resultReal = builder.
createUnaryOp(loc, opKind, operandReal);
668 resultImag = operandImag;
// Plus/Minus: apply the operation to both parts.
671 case cir::UnaryOpKind::Plus:
672 case cir::UnaryOpKind::Minus:
673 resultReal = builder.
createUnaryOp(loc, opKind, operandReal);
674 resultImag = builder.
createUnaryOp(loc, opKind, operandImag);
// Not: keep the real part and build a negated imaginary part.
// NOTE(review): the assignment target of the negation is not visible here,
// presumably resultImag.
677 case cir::UnaryOpKind::Not:
678 resultReal = operandReal;
680 builder.
createUnaryOp(loc, cir::UnaryOpKind::Minus, operandImag);
685 op.replaceAllUsesWith(result);
/// Builds an internal-linkage `void()` function for the given global, whose
/// name starts with "__cxx_global_var_init". The operations of the global's
/// ctor region (all but the last one) are moved into the function's entry
/// block; when a dtor region exists, a call to "__cxa_atexit" is emitted for
/// the destructor. The function is terminated with a cir::ReturnOp.
690LoweringPreparePass::buildCXXGlobalVarDeclInitFunc(cir::GlobalOp op) {
693 SmallString<256> fnName(
"__cxx_global_var_init");
// Post-increment: cnt is the number of earlier requests for this base name.
695 uint32_t cnt = dynamicInitializerNames[fnName]++;
// The counter is appended after a '.'.
// NOTE(review): any condition guarding this statement is not visible here.
697 fnName +=
"." + llvm::Twine(cnt).str();
// Create the function right after the global.
700 CIRBaseBuilderTy builder(getContext());
701 builder.setInsertionPointAfter(op);
702 cir::VoidType voidTy = builder.
getVoidTy();
703 auto fnType = cir::FuncType::get({}, voidTy);
704 FuncOp f = buildRuntimeFunction(builder, fnName, op.getLoc(), fnType,
705 cir::GlobalLinkageKind::InternalLinkage);
708 mlir::Block *entryBB = f.addEntryBlock();
// Move everything except the last operation of the ctor block to the start
// of the entry block.
709 if (!op.getCtorRegion().empty()) {
710 mlir::Block &block = op.getCtorRegion().front();
711 entryBB->getOperations().splice(entryBB->begin(), block.getOperations(),
712 block.begin(), std::prev(block.end()));
716 mlir::Region &dtorRegion = op.getDtorRegion();
717 if (!dtorRegion.empty()) {
// "__dso_handle" is requested as an external, hidden-visibility i8 global at
// the start of the module body.
721 builder.setInsertionPointToStart(&mlirModule.getBodyRegion().front());
722 cir::GlobalOp handle = buildRuntimeVariable(
723 builder,
"__dso_handle", op.getLoc(), builder.getI8Type(),
724 cir::GlobalLinkageKind::ExternalLinkage, cir::VisibilityKind::Hidden);
// The call ops of the dtor block are scanned in reverse order.
// NOTE(review): the loop body that picks dtorCall is not visible here.
727 mlir::Block &dtorBlock = dtorRegion.front();
728 cir::CallOp dtorCall;
729 for (
auto op : reverse(dtorBlock.getOps<cir::CallOp>())) {
733 assert(dtorCall &&
"Expected a dtor call");
735 assert(dtorFunc &&
"Expected a dtor call");
// "__cxa_atexit" is declared here as
// void(void (*)(void *), void *, <pointer to handle type>).
739 auto voidPtrTy = cir::PointerType::get(voidTy);
740 auto voidFnTy = cir::FuncType::get({voidPtrTy}, voidTy);
741 auto voidFnPtrTy = cir::PointerType::get(voidFnTy);
742 auto handlePtrTy = cir::PointerType::get(handle.getSymType());
744 cir::FuncType::get({voidFnPtrTy, voidPtrTy, handlePtrTy}, voidTy);
745 const char *nameAtExit =
"__cxa_atexit";
746 cir::FuncOp fnAtExit =
747 buildRuntimeFunction(builder, nameAtExit, op.getLoc(), fnAtExitType);
// The argument setup and the registration call are emitted after the
// original destructor call.
751 builder.setInsertionPointAfter(dtorCall);
753 auto dtorPtrTy = cir::PointerType::get(dtorFunc.getFunctionType());
// args[0]: address of the destructor function, bitcast to the generic
// void(void *) function pointer type.
755 args[0] = cir::GetGlobalOp::create(builder, dtorCall.getLoc(), dtorPtrTy,
756 dtorFunc.getSymName());
757 args[0] = cir::CastOp::create(builder, dtorCall.getLoc(), voidFnPtrTy,
758 cir::CastKind::bitcast, args[0]);
760 cir::CastOp::create(builder, dtorCall.getLoc(), voidPtrTy,
761 cir::CastKind::bitcast, dtorCall.getArgOperand(0));
// args[2]: address of the "__dso_handle" global.
762 args[2] = cir::GetGlobalOp::create(builder, handle.getLoc(), handlePtrTy,
763 handle.getSymName());
764 builder.
createCallOp(dtorCall.getLoc(), fnAtExit, args);
766 entryBB->getOperations().splice(entryBB->end(), dtorBlock.getOperations(),
768 std::prev(dtorBlock.end()));
// Terminate the entry block with a return placed at the location of the
// region's trailing op, which is asserted to be a cir::YieldOp.
772 builder.setInsertionPointToEnd(entryBB);
773 mlir::Operation *yieldOp =
nullptr;
774 if (!op.getCtorRegion().empty()) {
775 mlir::Block &block = op.getCtorRegion().front();
776 yieldOp = &block.getOperations().back();
778 assert(!dtorRegion.empty());
779 mlir::Block &block = dtorRegion.front();
780 yieldOp = &block.getOperations().back();
783 assert(isa<cir::YieldOp>(*yieldOp));
784 cir::ReturnOp::create(builder, yieldOp->getLoc());
/// Lowers a cir::GlobalOp that has a ctor and/or dtor region: an initializer
/// function is built for it, both regions are emptied, and the function is
/// appended to dynamicInitializers.
788void LoweringPreparePass::lowerGlobalOp(GlobalOp op) {
789 mlir::Region &ctorRegion = op.getCtorRegion();
790 mlir::Region &dtorRegion = op.getDtorRegion();
// Globals with neither region do not enter this branch.
792 if (!ctorRegion.empty() || !dtorRegion.empty()) {
795 cir::FuncOp f = buildCXXGlobalVarDeclInitFunc(op);
// Drop the remaining blocks of both regions once the init function exists.
798 ctorRegion.getBlocks().clear();
799 dtorRegion.getBlocks().clear();
802 dynamicInitializers.push_back(f);
808template <
typename AttributeTy>
809static llvm::SmallVector<mlir::Attribute>
813 for (
const auto &[name, priority] : list)
814 attrs.push_back(AttributeTy::get(context, name, priority));
/// Publishes the collected global constructor and destructor entries as
/// array attributes on the module. A list that is empty sets no attribute.
/// NOTE(review): the expressions that build globalCtors / globalDtors are
/// not visible here.
818void LoweringPreparePass::buildGlobalCtorDtorList() {
819 if (!globalCtorList.empty()) {
820 llvm::SmallVector<mlir::Attribute> globalCtors =
824 mlirModule->setAttr(cir::CIRDialect::getGlobalCtorsAttrName(),
825 mlir::ArrayAttr::get(&getContext(), globalCtors));
828 if (!globalDtorList.empty()) {
829 llvm::SmallVector<mlir::Attribute> globalDtors =
832 mlirModule->setAttr(cir::CIRDialect::getGlobalDtorsAttrName(),
833 mlir::ArrayAttr::get(&getContext(), globalDtors));
/// Builds the module-level initialization function: an external-linkage
/// `void()` function placed at the end of the module body, registered in
/// globalCtorList with the default constructor priority and terminated with
/// a cir::ReturnOp.
837void LoweringPreparePass::buildCXXGlobalInitFunc() {
// NOTE(review): the statement taken when there are no dynamic initializers
// is not visible here.
838 if (dynamicInitializers.empty())
// The function name is assembled in fnName; the visible pieces are a stream
// writing into it, an Itanium mangle context and the "_GLOBAL__sub_I_" prefix.
// NOTE(review): which of these paths applies under which condition is not
// visible here.
845 SmallString<256> fnName;
853 llvm::raw_svector_ostream
out(fnName);
854 std::unique_ptr<clang::MangleContext> mangleCtx(
856 cast<clang::ItaniumMangleContext>(*mangleCtx)
859 fnName +=
"_GLOBAL__sub_I_";
// Create the function at the end of the last block of the module body.
863 CIRBaseBuilderTy builder(getContext());
864 builder.setInsertionPointToEnd(&mlirModule.getBodyRegion().back());
865 auto fnType = cir::FuncType::get({}, builder.
getVoidTy());
867 buildRuntimeFunction(builder, fnName, mlirModule.getLoc(), fnType,
868 cir::GlobalLinkageKind::ExternalLinkage);
// The body is emitted at the start of the new entry block, iterating over
// the collected initializer functions.
// NOTE(review): the loop body is not visible here.
869 builder.setInsertionPointToStart(f.addEntryBlock());
870 for (cir::FuncOp &f : dynamicInitializers)
874 globalCtorList.emplace_back(fnName,
875 cir::GlobalCtorAttr::getDefaultPriority());
877 cir::ReturnOp::create(builder, f.getLoc());
/// Lowers a cir::DynamicCastOp by delegating to the C++ ABI helper; new IR is
/// emitted right after the op and all of its uses are redirected to the
/// value the helper returns.
880void LoweringPreparePass::lowerDynamicCastOp(DynamicCastOp op) {
881 CIRBaseBuilderTy builder(getContext());
882 builder.setInsertionPointAfter(op);
// The AST context is dereferenced below, hence the assertion. cxxABI is
// dereferenced as well but is not checked here.
884 assert(astCtx &&
"AST context is not available during lowering prepare");
885 auto loweredValue = cxxABI->lowerDynamicCast(builder, *astCtx, op);
887 op.replaceAllUsesWith(loweredValue);
893 mlir::Operation *op, mlir::Type eltTy,
894 mlir::Value arrayAddr, uint64_t arrayLen,
897 mlir::Location loc = op->getLoc();
901 const unsigned sizeTypeSize =
903 uint64_t endOffset = isCtor ? arrayLen : arrayLen - 1;
904 mlir::Value endOffsetVal =
907 auto begin = cir::CastOp::create(builder, loc, eltTy,
908 cir::CastKind::array_to_ptrdecay, arrayAddr);
910 cir::PtrStrideOp::create(builder, loc, eltTy, begin, endOffsetVal);
911 mlir::Value start = isCtor ? begin : end;
912 mlir::Value stop = isCtor ? end : begin;
922 [&](mlir::OpBuilder &
b, mlir::Location loc) {
923 auto currentElement =
b.create<cir::LoadOp>(loc, eltTy, tmpAddr);
924 mlir::Type boolTy = cir::BoolType::get(
b.getContext());
925 auto cmp = builder.create<cir::CmpOp>(loc, boolTy, cir::CmpOpKind::ne,
926 currentElement, stop);
930 [&](mlir::OpBuilder &
b, mlir::Location loc) {
931 auto currentElement =
b.create<cir::LoadOp>(loc, eltTy, tmpAddr);
933 cir::CallOp ctorCall;
934 op->walk([&](cir::CallOp
c) { ctorCall =
c; });
935 assert(ctorCall &&
"expected ctor call");
944 ctorCall->moveBefore(stride.getDefiningOp());
945 ctorCall->setOperand(0, currentElement);
946 auto nextElement = cir::PtrStrideOp::create(builder, loc, eltTy,
947 currentElement, stride);
954 op->replaceAllUsesWith(loop);
/// Lowers a cir::ArrayDtor; new IR is emitted right after the op.
/// NOTE(review): the call that consumes eltTy and the array length is not
/// visible here.
958void LoweringPreparePass::lowerArrayDtor(cir::ArrayDtor op) {
959 CIRBaseBuilderTy builder(getContext());
960 builder.setInsertionPointAfter(op.getOperation());
// The element type is the type of the first argument of the op's region 0.
962 mlir::Type eltTy = op->getRegion(0).getArgument(0).getType();
// The size is read from the array type that the address operand points to;
// mlir::cast requires that pointee to be a cir::ArrayType.
965 mlir::cast<cir::ArrayType>(op.getAddr().getType().getPointee()).getSize();
/// Lowers a cir::ArrayCtor; new IR is emitted right after the op.
/// NOTE(review): the call that consumes eltTy and the array length is not
/// visible here.
970void LoweringPreparePass::lowerArrayCtor(cir::ArrayCtor op) {
971 cir::CIRBaseBuilderTy builder(getContext());
972 builder.setInsertionPointAfter(op.getOperation());
// The element type is the type of the first argument of the op's region 0.
974 mlir::Type eltTy = op->getRegion(0).getArgument(0).getType();
// The size is read from the array type that the address operand points to;
// mlir::cast requires that pointee to be a cir::ArrayType.
977 mlir::cast<cir::ArrayType>(op.getAddr().getType().getPointee()).getSize();
/// Dispatches a single operation to its lowering routine based on its
/// concrete type. The first matching dyn_cast in the chain wins.
/// NOTE(review): the bodies of the CastOp, GlobalOp and UnaryOp branches are
/// not visible here; presumably they call the matching lower* member.
982void LoweringPreparePass::runOnOp(mlir::Operation *op) {
983 if (
auto arrayCtor = dyn_cast<cir::ArrayCtor>(op)) {
984 lowerArrayCtor(arrayCtor);
985 }
else if (
auto arrayDtor = dyn_cast<cir::ArrayDtor>(op)) {
986 lowerArrayDtor(arrayDtor);
987 }
else if (
auto cast = mlir::dyn_cast<cir::CastOp>(op)) {
989 }
else if (
auto complexDiv = mlir::dyn_cast<cir::ComplexDivOp>(op)) {
990 lowerComplexDivOp(complexDiv);
991 }
else if (
auto complexMul = mlir::dyn_cast<cir::ComplexMulOp>(op)) {
992 lowerComplexMulOp(complexMul);
993 }
else if (
auto glob = mlir::dyn_cast<cir::GlobalOp>(op)) {
995 }
else if (
auto dynamicCast = mlir::dyn_cast<cir::DynamicCastOp>(op)) {
996 lowerDynamicCastOp(dynamicCast);
997 }
else if (
auto unary = mlir::dyn_cast<cir::UnaryOp>(op)) {
999 }
else if (
auto fnOp = dyn_cast<cir::FuncOp>(op)) {
// Functions are not rewritten here. One carrying a global-ctor priority is
// recorded in globalCtorList; otherwise one carrying a global-dtor priority
// is recorded in globalDtorList. A function with both is only recorded as a
// constructor.
1000 if (
auto globalCtor = fnOp.getGlobalCtorPriority())
1001 globalCtorList.emplace_back(fnOp.getName(), globalCtor.value());
1002 else if (
auto globalDtor = fnOp.getGlobalDtorPriority())
1003 globalDtorList.emplace_back(fnOp.getName(), globalDtor.value());
/// Pass entry point: remembers the module, collects every operation of a
/// handled kind, processes the collected operations, then builds the global
/// init function and the module ctor/dtor attribute lists.
1007void LoweringPreparePass::runOnOperation() {
1008 mlir::Operation *op = getOperation();
// mlirModule is only assigned when the root operation is a ModuleOp.
1009 if (isa<::mlir::ModuleOp>(op))
1010 mlirModule = cast<::mlir::ModuleOp>(op);
// Operations are gathered into a list during the walk and processed
// afterwards, rather than being transformed inside the walk callback.
1012 llvm::SmallVector<mlir::Operation *> opsToTransform;
1014 op->walk([&](mlir::Operation *op) {
1015 if (mlir::isa<cir::ArrayCtor, cir::ArrayDtor, cir::CastOp,
1016 cir::ComplexMulOp, cir::ComplexDivOp, cir::DynamicCastOp,
1017 cir::FuncOp, cir::GlobalOp, cir::UnaryOp>(op))
1018 opsToTransform.push_back(op);
// NOTE(review): the loop body is not visible here, presumably runOnOp(o).
1021 for (mlir::Operation *o : opsToTransform)
// Module-level finalization runs after the per-operation loop above.
1024 buildCXXGlobalInitFunc();
1025 buildGlobalCtorDtorList();
1029 return std::make_unique<LoweringPreparePass>();
1032std::unique_ptr<Pass>
1034 auto pass = std::make_unique<LoweringPreparePass>();
1035 pass->setASTContext(astCtx);
1036 return std::move(pass);
Defines the clang::ASTContext interface.
static mlir::Value buildRangeReductionComplexDiv(CIRBaseBuilderTy &builder, mlir::Location loc, mlir::Value lhsReal, mlir::Value lhsImag, mlir::Value rhsReal, mlir::Value rhsImag)
static void lowerArrayDtorCtorIntoLoop(cir::CIRBaseBuilderTy &builder, clang::ASTContext *astCtx, mlir::Operation *op, mlir::Type eltTy, mlir::Value arrayAddr, uint64_t arrayLen, bool isCtor)
static llvm::StringRef getComplexDivLibCallName(llvm::APFloat::Semantics semantics)
static llvm::SmallVector< mlir::Attribute > prepareCtorDtorAttrList(mlir::MLIRContext *context, llvm::ArrayRef< std::pair< std::string, uint32_t > > list)
static llvm::StringRef getComplexMulLibCallName(llvm::APFloat::Semantics semantics)
static mlir::Value buildComplexBinOpLibCall(LoweringPreparePass &pass, CIRBaseBuilderTy &builder, llvm::StringRef(*libFuncNameGetter)(llvm::APFloat::Semantics), mlir::Location loc, cir::ComplexType ty, mlir::Value lhsReal, mlir::Value lhsImag, mlir::Value rhsReal, mlir::Value rhsImag)
static mlir::Value lowerComplexMul(LoweringPreparePass &pass, CIRBaseBuilderTy &builder, mlir::Location loc, cir::ComplexMulOp op, mlir::Value lhsReal, mlir::Value lhsImag, mlir::Value rhsReal, mlir::Value rhsImag)
static SmallString< 128 > getTransformedFileName(mlir::ModuleOp mlirModule)
static mlir::Value lowerComplexToComplexCast(mlir::MLIRContext &ctx, cir::CastOp op, cir::CastKind scalarCastKind)
static mlir::Value lowerComplexToScalarCast(mlir::MLIRContext &ctx, cir::CastOp op, cir::CastKind elemToBoolKind)
static mlir::Value buildAlgebraicComplexDiv(CIRBaseBuilderTy &builder, mlir::Location loc, mlir::Value lhsReal, mlir::Value lhsImag, mlir::Value rhsReal, mlir::Value rhsImag)
static cir::FuncOp getCalledFunction(cir::CallOp callOp)
Return the FuncOp called by callOp.
static mlir::Type higherPrecisionElementTypeForComplexArithmetic(mlir::MLIRContext &context, clang::ASTContext &cc, CIRBaseBuilderTy &builder, mlir::Type elementType)
static mlir::Value lowerScalarToComplexCast(mlir::MLIRContext &ctx, cir::CastOp op)
static mlir::Value lowerComplexDiv(LoweringPreparePass &pass, CIRBaseBuilderTy &builder, mlir::Location loc, cir::ComplexDivOp op, mlir::Value lhsReal, mlir::Value lhsImag, mlir::Value rhsReal, mlir::Value rhsImag, mlir::MLIRContext &mlirCx, clang::ASTContext &cc)
Defines the clang::Module class, which describes a module in the source code.
__device__ __2f16 float c
mlir::Value createLogicalOr(mlir::Location loc, mlir::Value lhs, mlir::Value rhs)
cir::ConditionOp createCondition(mlir::Value condition)
Create a loop condition.
cir::VoidType getVoidTy()
cir::ConstantOp getNullValue(mlir::Type ty, mlir::Location loc)
mlir::Value createCast(mlir::Location loc, cir::CastKind kind, mlir::Value src, mlir::Type newTy)
cir::PointerType getPointerTo(mlir::Type ty)
mlir::Value createComplexImag(mlir::Location loc, mlir::Value operand)
cir::DoWhileOp createDoWhile(mlir::Location loc, llvm::function_ref< void(mlir::OpBuilder &, mlir::Location)> condBuilder, llvm::function_ref< void(mlir::OpBuilder &, mlir::Location)> bodyBuilder)
Create a do-while operation.
cir::CallOp createCallOp(mlir::Location loc, mlir::SymbolRefAttr callee, mlir::Type returnType, mlir::ValueRange operands, llvm::ArrayRef< mlir::NamedAttribute > attrs={})
mlir::Value getSignedInt(mlir::Location loc, int64_t val, unsigned numBits)
cir::StoreOp createStore(mlir::Location loc, mlir::Value val, mlir::Value dst, bool isVolatile=false, mlir::IntegerAttr align={}, cir::MemOrderAttr order={})
cir::CmpOp createCompare(mlir::Location loc, cir::CmpOpKind kind, mlir::Value lhs, mlir::Value rhs)
mlir::IntegerAttr getAlignmentAttr(clang::CharUnits alignment)
mlir::Value createBinop(mlir::Location loc, mlir::Value lhs, cir::BinOpKind kind, mlir::Value rhs)
mlir::Value createComplexCreate(mlir::Location loc, mlir::Value real, mlir::Value imag)
mlir::Value createIsNaN(mlir::Location loc, mlir::Value operand)
cir::YieldOp createYield(mlir::Location loc, mlir::ValueRange value={})
Create a yield operation.
mlir::Value createLogicalAnd(mlir::Location loc, mlir::Value lhs, mlir::Value rhs)
mlir::Value createUnaryOp(mlir::Location loc, cir::UnaryOpKind kind, mlir::Value operand)
mlir::Value createAlloca(mlir::Location loc, cir::PointerType addrType, mlir::Type type, llvm::StringRef name, mlir::IntegerAttr alignment, mlir::Value dynAllocSize)
cir::BoolType getBoolTy()
mlir::Value getUnsignedInt(mlir::Location loc, uint64_t val, unsigned numBits)
mlir::Value createComplexReal(mlir::Location loc, mlir::Value operand)
static LoweringPrepareCXXABI * createItaniumABI()
Holds long-lived AST nodes (such as types and decls) that can be referred to throughout the semantic analysis of a file.
MangleContext * createMangleContext(const TargetInfo *T=nullptr)
If T is null pointer, assume the target in ASTContext.
const LangOptions & getLangOpts() const
uint64_t getTypeSize(QualType T) const
Return the size of the specified (complete) type T, in bits.
const TargetInfo & getTargetInfo() const
QualType getSignedSizeType() const
Return the unique signed counterpart of the integer type corresponding to size_t.
Module * getCurrentNamedModule() const
Get module under construction, nullptr if this is not a C++20 module.
bool isModuleImplementation() const
Is this a module implementation.
Exposes information about the current target.
const llvm::fltSemantics & getDoubleFormat() const
const llvm::fltSemantics & getHalfFormat() const
const llvm::fltSemantics & getBFloat16Format() const
const llvm::fltSemantics & getLongDoubleFormat() const
const llvm::fltSemantics & getFloatFormat() const
const llvm::fltSemantics & getFloat128Format() const
Defines the clang::TargetInfo interface.
const internal::VariadicAllOfMatcher< Type > type
Matches Types in the clang AST.
LLVM_READONLY bool isPreprocessingNumberBody(unsigned char c)
Return true if this is the body character of a C preprocessing number, which is [a-zA-Z0-9_.].
std::unique_ptr< Pass > createLoweringPreparePass()
static bool opGlobalThreadLocal()
static bool opGlobalAnnotations()
static bool opGlobalCtorPriority()
static bool loweringPrepareX86CXXABI()
static bool opFuncExtraAttrs()
static bool fastMathFlags()
static bool loweringPrepareAArch64XXABI()
static bool astVarDeclInterface()