clang 23.0.0git
LoweringPrepare.cpp
Go to the documentation of this file.
1//===- LoweringPrepare.cpp - preparation work for LLVM lowering -----------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "PassDetail.h"
10#include "mlir/IR/Attributes.h"
12#include "clang/Basic/Module.h"
20#include "llvm/Support/Path.h"
21
22#include <memory>
23
24using namespace mlir;
25using namespace cir;
26
27namespace mlir {
28#define GEN_PASS_DEF_LOWERINGPREPARE
29#include "clang/CIR/Dialect/Passes.h.inc"
30} // namespace mlir
31
32static SmallString<128> getTransformedFileName(mlir::ModuleOp mlirModule) {
33 SmallString<128> fileName;
34
35 if (mlirModule.getSymName())
36 fileName = llvm::sys::path::filename(mlirModule.getSymName()->str());
37
38 if (fileName.empty())
39 fileName = "<null>";
40
41 for (size_t i = 0; i < fileName.size(); ++i) {
42 // Replace everything that's not [a-zA-Z0-9._] with a _. This set happens
43 // to be the set of C preprocessing numbers.
44 if (!clang::isPreprocessingNumberBody(fileName[i]))
45 fileName[i] = '_';
46 }
47
48 return fileName;
49}
50
51/// Return the FuncOp called by `callOp`.
52static cir::FuncOp getCalledFunction(cir::CallOp callOp) {
53 mlir::SymbolRefAttr sym = llvm::dyn_cast_if_present<mlir::SymbolRefAttr>(
54 callOp.getCallableForCallee());
55 if (!sym)
56 return nullptr;
57 return dyn_cast_or_null<cir::FuncOp>(
58 mlir::SymbolTable::lookupNearestSymbolFrom(callOp, sym));
59}
60
61namespace {
62struct LoweringPreparePass
63 : public impl::LoweringPrepareBase<LoweringPreparePass> {
64 LoweringPreparePass() = default;
65 void runOnOperation() override;
66
67 void runOnOp(mlir::Operation *op);
68 void lowerCastOp(cir::CastOp op);
69 void lowerComplexDivOp(cir::ComplexDivOp op);
70 void lowerComplexMulOp(cir::ComplexMulOp op);
71 void lowerUnaryOp(cir::UnaryOp op);
72 void lowerGlobalOp(cir::GlobalOp op);
73 void lowerArrayDtor(cir::ArrayDtor op);
74 void lowerArrayCtor(cir::ArrayCtor op);
75 void lowerTrivialCopyCall(cir::CallOp op);
76
77 /// Build the function that initializes the specified global
78 cir::FuncOp buildCXXGlobalVarDeclInitFunc(cir::GlobalOp op);
79
80 /// Handle the dtor region by registering destructor with __cxa_atexit
81 cir::FuncOp getOrCreateDtorFunc(CIRBaseBuilderTy &builder, cir::GlobalOp op,
82 mlir::Region &dtorRegion,
83 cir::CallOp &dtorCall);
84
85 /// Build a module init function that calls all the dynamic initializers.
86 void buildCXXGlobalInitFunc();
87
88 /// Materialize global ctor/dtor list
89 void buildGlobalCtorDtorList();
90
91 cir::FuncOp buildRuntimeFunction(
92 mlir::OpBuilder &builder, llvm::StringRef name, mlir::Location loc,
93 cir::FuncType type,
94 cir::GlobalLinkageKind linkage = cir::GlobalLinkageKind::ExternalLinkage);
95
96 cir::GlobalOp buildRuntimeVariable(
97 mlir::OpBuilder &builder, llvm::StringRef name, mlir::Location loc,
98 mlir::Type type,
99 cir::GlobalLinkageKind linkage = cir::GlobalLinkageKind::ExternalLinkage,
100 cir::VisibilityKind visibility = cir::VisibilityKind::Default);
101
102 ///
103 /// AST related
104 /// -----------
105
106 clang::ASTContext *astCtx;
107
108 /// Tracks current module.
109 mlir::ModuleOp mlirModule;
110
111 /// Tracks existing dynamic initializers.
112 llvm::StringMap<uint32_t> dynamicInitializerNames;
113 llvm::SmallVector<cir::FuncOp> dynamicInitializers;
114
115 /// List of ctors and their priorities to be called before main()
116 llvm::SmallVector<std::pair<std::string, uint32_t>, 4> globalCtorList;
117 /// List of dtors and their priorities to be called when unloading module.
118 llvm::SmallVector<std::pair<std::string, uint32_t>, 4> globalDtorList;
119
120 void setASTContext(clang::ASTContext *c) { astCtx = c; }
121};
122
123} // namespace
124
125cir::GlobalOp LoweringPreparePass::buildRuntimeVariable(
126 mlir::OpBuilder &builder, llvm::StringRef name, mlir::Location loc,
127 mlir::Type type, cir::GlobalLinkageKind linkage,
128 cir::VisibilityKind visibility) {
129 cir::GlobalOp g = dyn_cast_or_null<cir::GlobalOp>(
130 mlir::SymbolTable::lookupNearestSymbolFrom(
131 mlirModule, mlir::StringAttr::get(mlirModule->getContext(), name)));
132 if (!g) {
133 g = cir::GlobalOp::create(builder, loc, name, type);
134 g.setLinkageAttr(
135 cir::GlobalLinkageKindAttr::get(builder.getContext(), linkage));
136 mlir::SymbolTable::setSymbolVisibility(
137 g, mlir::SymbolTable::Visibility::Private);
138 g.setGlobalVisibilityAttr(
139 cir::VisibilityAttr::get(builder.getContext(), visibility));
140 }
141 return g;
142}
143
144cir::FuncOp LoweringPreparePass::buildRuntimeFunction(
145 mlir::OpBuilder &builder, llvm::StringRef name, mlir::Location loc,
146 cir::FuncType type, cir::GlobalLinkageKind linkage) {
147 cir::FuncOp f = dyn_cast_or_null<FuncOp>(SymbolTable::lookupNearestSymbolFrom(
148 mlirModule, StringAttr::get(mlirModule->getContext(), name)));
149 if (!f) {
150 f = cir::FuncOp::create(builder, loc, name, type);
151 f.setLinkageAttr(
152 cir::GlobalLinkageKindAttr::get(builder.getContext(), linkage));
153 mlir::SymbolTable::setSymbolVisibility(
154 f, mlir::SymbolTable::Visibility::Private);
155
157 }
158 return f;
159}
160
161static mlir::Value lowerScalarToComplexCast(mlir::MLIRContext &ctx,
162 cir::CastOp op) {
163 cir::CIRBaseBuilderTy builder(ctx);
164 builder.setInsertionPoint(op);
165
166 mlir::Value src = op.getSrc();
167 mlir::Value imag = builder.getNullValue(src.getType(), op.getLoc());
168 return builder.createComplexCreate(op.getLoc(), src, imag);
169}
170
171static mlir::Value lowerComplexToScalarCast(mlir::MLIRContext &ctx,
172 cir::CastOp op,
173 cir::CastKind elemToBoolKind) {
174 cir::CIRBaseBuilderTy builder(ctx);
175 builder.setInsertionPoint(op);
176
177 mlir::Value src = op.getSrc();
178 if (!mlir::isa<cir::BoolType>(op.getType()))
179 return builder.createComplexReal(op.getLoc(), src);
180
181 // Complex cast to bool: (bool)(a+bi) => (bool)a || (bool)b
182 mlir::Value srcReal = builder.createComplexReal(op.getLoc(), src);
183 mlir::Value srcImag = builder.createComplexImag(op.getLoc(), src);
184
185 cir::BoolType boolTy = builder.getBoolTy();
186 mlir::Value srcRealToBool =
187 builder.createCast(op.getLoc(), elemToBoolKind, srcReal, boolTy);
188 mlir::Value srcImagToBool =
189 builder.createCast(op.getLoc(), elemToBoolKind, srcImag, boolTy);
190 return builder.createLogicalOr(op.getLoc(), srcRealToBool, srcImagToBool);
191}
192
193static mlir::Value lowerComplexToComplexCast(mlir::MLIRContext &ctx,
194 cir::CastOp op,
195 cir::CastKind scalarCastKind) {
196 CIRBaseBuilderTy builder(ctx);
197 builder.setInsertionPoint(op);
198
199 mlir::Value src = op.getSrc();
200 auto dstComplexElemTy =
201 mlir::cast<cir::ComplexType>(op.getType()).getElementType();
202
203 mlir::Value srcReal = builder.createComplexReal(op.getLoc(), src);
204 mlir::Value srcImag = builder.createComplexImag(op.getLoc(), src);
205
206 mlir::Value dstReal = builder.createCast(op.getLoc(), scalarCastKind, srcReal,
207 dstComplexElemTy);
208 mlir::Value dstImag = builder.createCast(op.getLoc(), scalarCastKind, srcImag,
209 dstComplexElemTy);
210 return builder.createComplexCreate(op.getLoc(), dstReal, dstImag);
211}
212
213void LoweringPreparePass::lowerCastOp(cir::CastOp op) {
214 mlir::MLIRContext &ctx = getContext();
215 mlir::Value loweredValue = [&]() -> mlir::Value {
216 switch (op.getKind()) {
217 case cir::CastKind::float_to_complex:
218 case cir::CastKind::int_to_complex:
219 return lowerScalarToComplexCast(ctx, op);
220 case cir::CastKind::float_complex_to_real:
221 case cir::CastKind::int_complex_to_real:
222 return lowerComplexToScalarCast(ctx, op, op.getKind());
223 case cir::CastKind::float_complex_to_bool:
224 return lowerComplexToScalarCast(ctx, op, cir::CastKind::float_to_bool);
225 case cir::CastKind::int_complex_to_bool:
226 return lowerComplexToScalarCast(ctx, op, cir::CastKind::int_to_bool);
227 case cir::CastKind::float_complex:
228 return lowerComplexToComplexCast(ctx, op, cir::CastKind::floating);
229 case cir::CastKind::float_complex_to_int_complex:
230 return lowerComplexToComplexCast(ctx, op, cir::CastKind::float_to_int);
231 case cir::CastKind::int_complex:
232 return lowerComplexToComplexCast(ctx, op, cir::CastKind::integral);
233 case cir::CastKind::int_complex_to_float_complex:
234 return lowerComplexToComplexCast(ctx, op, cir::CastKind::int_to_float);
235 default:
236 return nullptr;
237 }
238 }();
239
240 if (loweredValue) {
241 op.replaceAllUsesWith(loweredValue);
242 op.erase();
243 }
244}
245
246static mlir::Value buildComplexBinOpLibCall(
247 LoweringPreparePass &pass, CIRBaseBuilderTy &builder,
248 llvm::StringRef (*libFuncNameGetter)(llvm::APFloat::Semantics),
249 mlir::Location loc, cir::ComplexType ty, mlir::Value lhsReal,
250 mlir::Value lhsImag, mlir::Value rhsReal, mlir::Value rhsImag) {
251 cir::FPTypeInterface elementTy =
252 mlir::cast<cir::FPTypeInterface>(ty.getElementType());
253
254 llvm::StringRef libFuncName = libFuncNameGetter(
255 llvm::APFloat::SemanticsToEnum(elementTy.getFloatSemantics()));
256 llvm::SmallVector<mlir::Type, 4> libFuncInputTypes(4, elementTy);
257
258 cir::FuncType libFuncTy = cir::FuncType::get(libFuncInputTypes, ty);
259
260 // Insert a declaration for the runtime function to be used in Complex
261 // multiplication and division when needed
262 cir::FuncOp libFunc;
263 {
264 mlir::OpBuilder::InsertionGuard ipGuard{builder};
265 builder.setInsertionPointToStart(pass.mlirModule.getBody());
266 libFunc = pass.buildRuntimeFunction(builder, libFuncName, loc, libFuncTy);
267 }
268
269 cir::CallOp call =
270 builder.createCallOp(loc, libFunc, {lhsReal, lhsImag, rhsReal, rhsImag});
271 return call.getResult();
272}
273
274static llvm::StringRef
275getComplexDivLibCallName(llvm::APFloat::Semantics semantics) {
276 switch (semantics) {
277 case llvm::APFloat::S_IEEEhalf:
278 return "__divhc3";
279 case llvm::APFloat::S_IEEEsingle:
280 return "__divsc3";
281 case llvm::APFloat::S_IEEEdouble:
282 return "__divdc3";
283 case llvm::APFloat::S_PPCDoubleDouble:
284 return "__divtc3";
285 case llvm::APFloat::S_x87DoubleExtended:
286 return "__divxc3";
287 case llvm::APFloat::S_IEEEquad:
288 return "__divtc3";
289 default:
290 llvm_unreachable("unsupported floating point type");
291 }
292}
293
294static mlir::Value
295buildAlgebraicComplexDiv(CIRBaseBuilderTy &builder, mlir::Location loc,
296 mlir::Value lhsReal, mlir::Value lhsImag,
297 mlir::Value rhsReal, mlir::Value rhsImag) {
298 // (a+bi) / (c+di) = ((ac+bd)/(cc+dd)) + ((bc-ad)/(cc+dd))i
299 mlir::Value &a = lhsReal;
300 mlir::Value &b = lhsImag;
301 mlir::Value &c = rhsReal;
302 mlir::Value &d = rhsImag;
303
304 mlir::Value ac = builder.createBinop(loc, a, cir::BinOpKind::Mul, c); // a*c
305 mlir::Value bd = builder.createBinop(loc, b, cir::BinOpKind::Mul, d); // b*d
306 mlir::Value cc = builder.createBinop(loc, c, cir::BinOpKind::Mul, c); // c*c
307 mlir::Value dd = builder.createBinop(loc, d, cir::BinOpKind::Mul, d); // d*d
308 mlir::Value acbd =
309 builder.createBinop(loc, ac, cir::BinOpKind::Add, bd); // ac+bd
310 mlir::Value ccdd =
311 builder.createBinop(loc, cc, cir::BinOpKind::Add, dd); // cc+dd
312 mlir::Value resultReal =
313 builder.createBinop(loc, acbd, cir::BinOpKind::Div, ccdd);
314
315 mlir::Value bc = builder.createBinop(loc, b, cir::BinOpKind::Mul, c); // b*c
316 mlir::Value ad = builder.createBinop(loc, a, cir::BinOpKind::Mul, d); // a*d
317 mlir::Value bcad =
318 builder.createBinop(loc, bc, cir::BinOpKind::Sub, ad); // bc-ad
319 mlir::Value resultImag =
320 builder.createBinop(loc, bcad, cir::BinOpKind::Div, ccdd);
321 return builder.createComplexCreate(loc, resultReal, resultImag);
322}
323
324static mlir::Value
326 mlir::Value lhsReal, mlir::Value lhsImag,
327 mlir::Value rhsReal, mlir::Value rhsImag) {
328 // Implements Smith's algorithm for complex division.
329 // SMITH, R. L. Algorithm 116: Complex division. Commun. ACM 5, 8 (1962).
330
331 // Let:
332 // - lhs := a+bi
333 // - rhs := c+di
334 // - result := lhs / rhs = e+fi
335 //
336 // The algorithm pseudocode looks like follows:
337 // if fabs(c) >= fabs(d):
338 // r := d / c
339 // tmp := c + r*d
340 // e = (a + b*r) / tmp
341 // f = (b - a*r) / tmp
342 // else:
343 // r := c / d
344 // tmp := d + r*c
345 // e = (a*r + b) / tmp
346 // f = (b*r - a) / tmp
347
348 mlir::Value &a = lhsReal;
349 mlir::Value &b = lhsImag;
350 mlir::Value &c = rhsReal;
351 mlir::Value &d = rhsImag;
352
353 auto trueBranchBuilder = [&](mlir::OpBuilder &, mlir::Location) {
354 mlir::Value r = builder.createBinop(loc, d, cir::BinOpKind::Div,
355 c); // r := d / c
356 mlir::Value rd = builder.createBinop(loc, r, cir::BinOpKind::Mul, d); // r*d
357 mlir::Value tmp = builder.createBinop(loc, c, cir::BinOpKind::Add,
358 rd); // tmp := c + r*d
359
360 mlir::Value br = builder.createBinop(loc, b, cir::BinOpKind::Mul, r); // b*r
361 mlir::Value abr =
362 builder.createBinop(loc, a, cir::BinOpKind::Add, br); // a + b*r
363 mlir::Value e = builder.createBinop(loc, abr, cir::BinOpKind::Div, tmp);
364
365 mlir::Value ar = builder.createBinop(loc, a, cir::BinOpKind::Mul, r); // a*r
366 mlir::Value bar =
367 builder.createBinop(loc, b, cir::BinOpKind::Sub, ar); // b - a*r
368 mlir::Value f = builder.createBinop(loc, bar, cir::BinOpKind::Div, tmp);
369
370 mlir::Value result = builder.createComplexCreate(loc, e, f);
371 builder.createYield(loc, result);
372 };
373
374 auto falseBranchBuilder = [&](mlir::OpBuilder &, mlir::Location) {
375 mlir::Value r = builder.createBinop(loc, c, cir::BinOpKind::Div,
376 d); // r := c / d
377 mlir::Value rc = builder.createBinop(loc, r, cir::BinOpKind::Mul, c); // r*c
378 mlir::Value tmp = builder.createBinop(loc, d, cir::BinOpKind::Add,
379 rc); // tmp := d + r*c
380
381 mlir::Value ar = builder.createBinop(loc, a, cir::BinOpKind::Mul, r); // a*r
382 mlir::Value arb =
383 builder.createBinop(loc, ar, cir::BinOpKind::Add, b); // a*r + b
384 mlir::Value e = builder.createBinop(loc, arb, cir::BinOpKind::Div, tmp);
385
386 mlir::Value br = builder.createBinop(loc, b, cir::BinOpKind::Mul, r); // b*r
387 mlir::Value bra =
388 builder.createBinop(loc, br, cir::BinOpKind::Sub, a); // b*r - a
389 mlir::Value f = builder.createBinop(loc, bra, cir::BinOpKind::Div, tmp);
390
391 mlir::Value result = builder.createComplexCreate(loc, e, f);
392 builder.createYield(loc, result);
393 };
394
395 auto cFabs = cir::FAbsOp::create(builder, loc, c);
396 auto dFabs = cir::FAbsOp::create(builder, loc, d);
397 cir::CmpOp cmpResult =
398 builder.createCompare(loc, cir::CmpOpKind::ge, cFabs, dFabs);
399 auto ternary = cir::TernaryOp::create(builder, loc, cmpResult,
400 trueBranchBuilder, falseBranchBuilder);
401
402 return ternary.getResult();
403}
404
406 mlir::MLIRContext &context, clang::ASTContext &cc,
407 CIRBaseBuilderTy &builder, mlir::Type elementType) {
408
409 auto getHigherPrecisionFPType = [&context](mlir::Type type) -> mlir::Type {
410 if (mlir::isa<cir::FP16Type>(type))
411 return cir::SingleType::get(&context);
412
413 if (mlir::isa<cir::SingleType>(type) || mlir::isa<cir::BF16Type>(type))
414 return cir::DoubleType::get(&context);
415
416 if (mlir::isa<cir::DoubleType>(type))
417 return cir::LongDoubleType::get(&context, type);
418
419 return type;
420 };
421
422 auto getFloatTypeSemantics =
423 [&cc](mlir::Type type) -> const llvm::fltSemantics & {
424 const clang::TargetInfo &info = cc.getTargetInfo();
425 if (mlir::isa<cir::FP16Type>(type))
426 return info.getHalfFormat();
427
428 if (mlir::isa<cir::BF16Type>(type))
429 return info.getBFloat16Format();
430
431 if (mlir::isa<cir::SingleType>(type))
432 return info.getFloatFormat();
433
434 if (mlir::isa<cir::DoubleType>(type))
435 return info.getDoubleFormat();
436
437 if (mlir::isa<cir::LongDoubleType>(type)) {
438 if (cc.getLangOpts().OpenMP && cc.getLangOpts().OpenMPIsTargetDevice)
439 llvm_unreachable("NYI Float type semantics with OpenMP");
440 return info.getLongDoubleFormat();
441 }
442
443 if (mlir::isa<cir::FP128Type>(type)) {
444 if (cc.getLangOpts().OpenMP && cc.getLangOpts().OpenMPIsTargetDevice)
445 llvm_unreachable("NYI Float type semantics with OpenMP");
446 return info.getFloat128Format();
447 }
448
449 llvm_unreachable("Unsupported float type semantics");
450 };
451
452 const mlir::Type higherElementType = getHigherPrecisionFPType(elementType);
453 const llvm::fltSemantics &elementTypeSemantics =
454 getFloatTypeSemantics(elementType);
455 const llvm::fltSemantics &higherElementTypeSemantics =
456 getFloatTypeSemantics(higherElementType);
457
458 // Check that the promoted type can handle the intermediate values without
459 // overflowing. This can be interpreted as:
460 // (SmallerType.LargestFiniteVal * SmallerType.LargestFiniteVal) * 2 <=
461 // LargerType.LargestFiniteVal.
462 // In terms of exponent it gives this formula:
463 // (SmallerType.LargestFiniteVal * SmallerType.LargestFiniteVal
464 // doubles the exponent of SmallerType.LargestFiniteVal)
465 if (llvm::APFloat::semanticsMaxExponent(elementTypeSemantics) * 2 + 1 <=
466 llvm::APFloat::semanticsMaxExponent(higherElementTypeSemantics)) {
467 return higherElementType;
468 }
469
470 // The intermediate values can't be represented in the promoted type
471 // without overflowing.
472 return {};
473}
474
475static mlir::Value
476lowerComplexDiv(LoweringPreparePass &pass, CIRBaseBuilderTy &builder,
477 mlir::Location loc, cir::ComplexDivOp op, mlir::Value lhsReal,
478 mlir::Value lhsImag, mlir::Value rhsReal, mlir::Value rhsImag,
479 mlir::MLIRContext &mlirCx, clang::ASTContext &cc) {
480 cir::ComplexType complexTy = op.getType();
481 if (mlir::isa<cir::FPTypeInterface>(complexTy.getElementType())) {
482 cir::ComplexRangeKind range = op.getRange();
483 if (range == cir::ComplexRangeKind::Improved)
484 return buildRangeReductionComplexDiv(builder, loc, lhsReal, lhsImag,
485 rhsReal, rhsImag);
486
487 if (range == cir::ComplexRangeKind::Full)
489 loc, complexTy, lhsReal, lhsImag, rhsReal,
490 rhsImag);
491
492 if (range == cir::ComplexRangeKind::Promoted) {
493 mlir::Type originalElementType = complexTy.getElementType();
494 mlir::Type higherPrecisionElementType =
496 originalElementType);
497
498 if (!higherPrecisionElementType)
499 return buildRangeReductionComplexDiv(builder, loc, lhsReal, lhsImag,
500 rhsReal, rhsImag);
501
502 cir::CastKind floatingCastKind = cir::CastKind::floating;
503 lhsReal = builder.createCast(floatingCastKind, lhsReal,
504 higherPrecisionElementType);
505 lhsImag = builder.createCast(floatingCastKind, lhsImag,
506 higherPrecisionElementType);
507 rhsReal = builder.createCast(floatingCastKind, rhsReal,
508 higherPrecisionElementType);
509 rhsImag = builder.createCast(floatingCastKind, rhsImag,
510 higherPrecisionElementType);
511
512 mlir::Value algebraicResult = buildAlgebraicComplexDiv(
513 builder, loc, lhsReal, lhsImag, rhsReal, rhsImag);
514
515 mlir::Value resultReal = builder.createComplexReal(loc, algebraicResult);
516 mlir::Value resultImag = builder.createComplexImag(loc, algebraicResult);
517
518 mlir::Value finalReal =
519 builder.createCast(floatingCastKind, resultReal, originalElementType);
520 mlir::Value finalImag =
521 builder.createCast(floatingCastKind, resultImag, originalElementType);
522 return builder.createComplexCreate(loc, finalReal, finalImag);
523 }
524 }
525
526 return buildAlgebraicComplexDiv(builder, loc, lhsReal, lhsImag, rhsReal,
527 rhsImag);
528}
529
530void LoweringPreparePass::lowerComplexDivOp(cir::ComplexDivOp op) {
531 cir::CIRBaseBuilderTy builder(getContext());
532 builder.setInsertionPointAfter(op);
533 mlir::Location loc = op.getLoc();
534 mlir::TypedValue<cir::ComplexType> lhs = op.getLhs();
535 mlir::TypedValue<cir::ComplexType> rhs = op.getRhs();
536 mlir::Value lhsReal = builder.createComplexReal(loc, lhs);
537 mlir::Value lhsImag = builder.createComplexImag(loc, lhs);
538 mlir::Value rhsReal = builder.createComplexReal(loc, rhs);
539 mlir::Value rhsImag = builder.createComplexImag(loc, rhs);
540
541 mlir::Value loweredResult =
542 lowerComplexDiv(*this, builder, loc, op, lhsReal, lhsImag, rhsReal,
543 rhsImag, getContext(), *astCtx);
544 op.replaceAllUsesWith(loweredResult);
545 op.erase();
546}
547
548static llvm::StringRef
549getComplexMulLibCallName(llvm::APFloat::Semantics semantics) {
550 switch (semantics) {
551 case llvm::APFloat::S_IEEEhalf:
552 return "__mulhc3";
553 case llvm::APFloat::S_IEEEsingle:
554 return "__mulsc3";
555 case llvm::APFloat::S_IEEEdouble:
556 return "__muldc3";
557 case llvm::APFloat::S_PPCDoubleDouble:
558 return "__multc3";
559 case llvm::APFloat::S_x87DoubleExtended:
560 return "__mulxc3";
561 case llvm::APFloat::S_IEEEquad:
562 return "__multc3";
563 default:
564 llvm_unreachable("unsupported floating point type");
565 }
566}
567
568static mlir::Value lowerComplexMul(LoweringPreparePass &pass,
569 CIRBaseBuilderTy &builder,
570 mlir::Location loc, cir::ComplexMulOp op,
571 mlir::Value lhsReal, mlir::Value lhsImag,
572 mlir::Value rhsReal, mlir::Value rhsImag) {
573 // (a+bi) * (c+di) = (ac-bd) + (ad+bc)i
574 mlir::Value resultRealLhs =
575 builder.createBinop(loc, lhsReal, cir::BinOpKind::Mul, rhsReal);
576 mlir::Value resultRealRhs =
577 builder.createBinop(loc, lhsImag, cir::BinOpKind::Mul, rhsImag);
578 mlir::Value resultImagLhs =
579 builder.createBinop(loc, lhsReal, cir::BinOpKind::Mul, rhsImag);
580 mlir::Value resultImagRhs =
581 builder.createBinop(loc, lhsImag, cir::BinOpKind::Mul, rhsReal);
582 mlir::Value resultReal = builder.createBinop(
583 loc, resultRealLhs, cir::BinOpKind::Sub, resultRealRhs);
584 mlir::Value resultImag = builder.createBinop(
585 loc, resultImagLhs, cir::BinOpKind::Add, resultImagRhs);
586 mlir::Value algebraicResult =
587 builder.createComplexCreate(loc, resultReal, resultImag);
588
589 cir::ComplexType complexTy = op.getType();
590 cir::ComplexRangeKind rangeKind = op.getRange();
591 if (mlir::isa<cir::IntType>(complexTy.getElementType()) ||
592 rangeKind == cir::ComplexRangeKind::Basic ||
593 rangeKind == cir::ComplexRangeKind::Improved ||
594 rangeKind == cir::ComplexRangeKind::Promoted)
595 return algebraicResult;
596
598
599 // Check whether the real part and the imaginary part of the result are both
600 // NaN. If so, emit a library call to compute the multiplication instead.
601 // We check a value against NaN by comparing the value against itself.
602 mlir::Value resultRealIsNaN = builder.createIsNaN(loc, resultReal);
603 mlir::Value resultImagIsNaN = builder.createIsNaN(loc, resultImag);
604 mlir::Value resultRealAndImagAreNaN =
605 builder.createLogicalAnd(loc, resultRealIsNaN, resultImagIsNaN);
606
607 return cir::TernaryOp::create(
608 builder, loc, resultRealAndImagAreNaN,
609 [&](mlir::OpBuilder &, mlir::Location) {
610 mlir::Value libCallResult = buildComplexBinOpLibCall(
611 pass, builder, &getComplexMulLibCallName, loc, complexTy,
612 lhsReal, lhsImag, rhsReal, rhsImag);
613 builder.createYield(loc, libCallResult);
614 },
615 [&](mlir::OpBuilder &, mlir::Location) {
616 builder.createYield(loc, algebraicResult);
617 })
618 .getResult();
619}
620
621void LoweringPreparePass::lowerComplexMulOp(cir::ComplexMulOp op) {
622 cir::CIRBaseBuilderTy builder(getContext());
623 builder.setInsertionPointAfter(op);
624 mlir::Location loc = op.getLoc();
625 mlir::TypedValue<cir::ComplexType> lhs = op.getLhs();
626 mlir::TypedValue<cir::ComplexType> rhs = op.getRhs();
627 mlir::Value lhsReal = builder.createComplexReal(loc, lhs);
628 mlir::Value lhsImag = builder.createComplexImag(loc, lhs);
629 mlir::Value rhsReal = builder.createComplexReal(loc, rhs);
630 mlir::Value rhsImag = builder.createComplexImag(loc, rhs);
631 mlir::Value loweredResult = lowerComplexMul(*this, builder, loc, op, lhsReal,
632 lhsImag, rhsReal, rhsImag);
633 op.replaceAllUsesWith(loweredResult);
634 op.erase();
635}
636
637void LoweringPreparePass::lowerUnaryOp(cir::UnaryOp op) {
638 mlir::Type ty = op.getType();
639 if (!mlir::isa<cir::ComplexType>(ty))
640 return;
641
642 mlir::Location loc = op.getLoc();
643 cir::UnaryOpKind opKind = op.getKind();
644
645 CIRBaseBuilderTy builder(getContext());
646 builder.setInsertionPointAfter(op);
647
648 mlir::Value operand = op.getInput();
649 mlir::Value operandReal = builder.createComplexReal(loc, operand);
650 mlir::Value operandImag = builder.createComplexImag(loc, operand);
651
652 mlir::Value resultReal;
653 mlir::Value resultImag;
654
655 switch (opKind) {
656 case cir::UnaryOpKind::Inc:
657 case cir::UnaryOpKind::Dec:
658 resultReal = builder.createUnaryOp(loc, opKind, operandReal);
659 resultImag = operandImag;
660 break;
661
662 case cir::UnaryOpKind::Plus:
663 case cir::UnaryOpKind::Minus:
664 resultReal = builder.createUnaryOp(loc, opKind, operandReal);
665 resultImag = builder.createUnaryOp(loc, opKind, operandImag);
666 break;
667
668 case cir::UnaryOpKind::Not:
669 resultReal = operandReal;
670 resultImag =
671 builder.createUnaryOp(loc, cir::UnaryOpKind::Minus, operandImag);
672 break;
673 }
674
675 mlir::Value result = builder.createComplexCreate(loc, resultReal, resultImag);
676 op.replaceAllUsesWith(result);
677 op.erase();
678}
679
680cir::FuncOp LoweringPreparePass::getOrCreateDtorFunc(CIRBaseBuilderTy &builder,
681 cir::GlobalOp op,
682 mlir::Region &dtorRegion,
683 cir::CallOp &dtorCall) {
684 mlir::OpBuilder::InsertionGuard guard(builder);
687
688 cir::VoidType voidTy = builder.getVoidTy();
689 auto voidPtrTy = cir::PointerType::get(voidTy);
690
691 // Look for operations in dtorBlock
692 mlir::Block &dtorBlock = dtorRegion.front();
693
694 // The first operation should be a get_global to retrieve the address
695 // of the global variable we're destroying.
696 auto opIt = dtorBlock.getOperations().begin();
697 cir::GetGlobalOp ggop = mlir::cast<cir::GetGlobalOp>(*opIt);
698
699 // The simple case is just a call to a destructor, like this:
700 //
701 // %0 = cir.get_global %globalS : !cir.ptr<!rec_S>
702 // cir.call %_ZN1SD1Ev(%0) : (!cir.ptr<!rec_S>) -> ()
703 // (implicit cir.yield)
704 //
705 // That is, if the second operation is a call that takes the get_global result
706 // as its only operand, and the only other operation is a yield, then we can
707 // just return the called function.
708 if (dtorBlock.getOperations().size() == 3) {
709 auto callOp = mlir::dyn_cast<cir::CallOp>(&*(++opIt));
710 auto yieldOp = mlir::dyn_cast<cir::YieldOp>(&*(++opIt));
711 if (yieldOp && callOp && callOp.getNumOperands() == 1 &&
712 callOp.getArgOperand(0) == ggop) {
713 dtorCall = callOp;
714 return getCalledFunction(callOp);
715 }
716 }
717
718 // Otherwise, we need to create a helper function to replace the dtor region.
719 // This name is kind of arbitrary, but it matches the name that classic
720 // codegen uses, based on the expected case that gets us here.
721 builder.setInsertionPointAfter(op);
722 SmallString<256> fnName("__cxx_global_array_dtor");
723 uint32_t cnt = dynamicInitializerNames[fnName]++;
724 if (cnt)
725 fnName += "." + std::to_string(cnt);
726
727 // Create the helper function.
728 auto fnType = cir::FuncType::get({voidPtrTy}, voidTy);
729 cir::FuncOp dtorFunc =
730 buildRuntimeFunction(builder, fnName, op.getLoc(), fnType,
731 cir::GlobalLinkageKind::InternalLinkage);
732 mlir::Block *entryBB = dtorFunc.addEntryBlock();
733
734 // Move everything from the dtor region into the helper function.
735 entryBB->getOperations().splice(entryBB->begin(), dtorBlock.getOperations(),
736 dtorBlock.begin(), dtorBlock.end());
737
738 // Before erasing this, clone it back into the dtor region
739 cir::GetGlobalOp dtorGGop =
740 mlir::cast<cir::GetGlobalOp>(entryBB->getOperations().front());
741 builder.setInsertionPointToStart(&dtorBlock);
742 builder.clone(*dtorGGop.getOperation());
743
744 // Replace all uses of the help function's get_global with the function
745 // argument.
746 mlir::Value dtorArg = entryBB->getArgument(0);
747 dtorGGop.replaceAllUsesWith(dtorArg);
748 dtorGGop.erase();
749
750 // Replace the yield in the final block with a return
751 mlir::Block &finalBlock = dtorFunc.getBody().back();
752 auto yieldOp = cast<cir::YieldOp>(finalBlock.getTerminator());
753 builder.setInsertionPoint(yieldOp);
754 cir::ReturnOp::create(builder, yieldOp->getLoc());
755 yieldOp->erase();
756
757 // Create a call to the helper function, passing the original get_global op
758 // as the argument.
759 cir::GetGlobalOp origGGop =
760 mlir::cast<cir::GetGlobalOp>(dtorBlock.getOperations().front());
761 builder.setInsertionPointAfter(origGGop);
762 mlir::Value ggopResult = origGGop.getResult();
763 dtorCall = builder.createCallOp(op.getLoc(), dtorFunc, ggopResult);
764
765 // Add a yield after the call.
766 auto finalYield = cir::YieldOp::create(builder, op.getLoc());
767
768 // Erase everything after the yield.
769 dtorBlock.getOperations().erase(std::next(mlir::Block::iterator(finalYield)),
770 dtorBlock.end());
771 dtorRegion.getBlocks().erase(std::next(dtorRegion.begin()), dtorRegion.end());
772
773 return dtorFunc;
774}
775
776cir::FuncOp
777LoweringPreparePass::buildCXXGlobalVarDeclInitFunc(cir::GlobalOp op) {
778 // TODO(cir): Store this in the GlobalOp.
779 // This should come from the MangleContext, but for now I'm hardcoding it.
780 SmallString<256> fnName("__cxx_global_var_init");
781 // Get a unique name
782 uint32_t cnt = dynamicInitializerNames[fnName]++;
783 if (cnt)
784 fnName += "." + std::to_string(cnt);
785
786 // Create a variable initialization function.
787 CIRBaseBuilderTy builder(getContext());
788 builder.setInsertionPointAfter(op);
789 cir::VoidType voidTy = builder.getVoidTy();
790 auto fnType = cir::FuncType::get({}, voidTy);
791 FuncOp f = buildRuntimeFunction(builder, fnName, op.getLoc(), fnType,
792 cir::GlobalLinkageKind::InternalLinkage);
793
794 // Move over the initialzation code of the ctor region.
795 mlir::Block *entryBB = f.addEntryBlock();
796 if (!op.getCtorRegion().empty()) {
797 mlir::Block &block = op.getCtorRegion().front();
798 entryBB->getOperations().splice(entryBB->begin(), block.getOperations(),
799 block.begin(), std::prev(block.end()));
800 }
801
802 // Register the destructor call with __cxa_atexit
803 mlir::Region &dtorRegion = op.getDtorRegion();
804 if (!dtorRegion.empty()) {
807
808 // Create a variable that binds the atexit to this shared object.
809 builder.setInsertionPointToStart(&mlirModule.getBodyRegion().front());
810 cir::GlobalOp handle = buildRuntimeVariable(
811 builder, "__dso_handle", op.getLoc(), builder.getI8Type(),
812 cir::GlobalLinkageKind::ExternalLinkage, cir::VisibilityKind::Hidden);
813
814 // If this is a simple call to a destructor, get the called function.
815 // Otherwise, create a helper function for the entire dtor region,
816 // replacing the current dtor region body with a call to the helper
817 // function.
818 cir::CallOp dtorCall;
819 cir::FuncOp dtorFunc =
820 getOrCreateDtorFunc(builder, op, dtorRegion, dtorCall);
821
822 // Create a runtime helper function:
823 // extern "C" int __cxa_atexit(void (*f)(void *), void *p, void *d);
824 auto voidPtrTy = cir::PointerType::get(voidTy);
825 auto voidFnTy = cir::FuncType::get({voidPtrTy}, voidTy);
826 auto voidFnPtrTy = cir::PointerType::get(voidFnTy);
827 auto handlePtrTy = cir::PointerType::get(handle.getSymType());
828 auto fnAtExitType =
829 cir::FuncType::get({voidFnPtrTy, voidPtrTy, handlePtrTy}, voidTy);
830 const char *nameAtExit = "__cxa_atexit";
831 cir::FuncOp fnAtExit =
832 buildRuntimeFunction(builder, nameAtExit, op.getLoc(), fnAtExitType);
833
834 // Replace the dtor (or helper) call with a call to
835 // __cxa_atexit(&dtor, &var, &__dso_handle)
836 builder.setInsertionPointAfter(dtorCall);
837 mlir::Value args[3];
838 auto dtorPtrTy = cir::PointerType::get(dtorFunc.getFunctionType());
839 // dtorPtrTy
840 args[0] = cir::GetGlobalOp::create(builder, dtorCall.getLoc(), dtorPtrTy,
841 dtorFunc.getSymName());
842 args[0] = cir::CastOp::create(builder, dtorCall.getLoc(), voidFnPtrTy,
843 cir::CastKind::bitcast, args[0]);
844 args[1] =
845 cir::CastOp::create(builder, dtorCall.getLoc(), voidPtrTy,
846 cir::CastKind::bitcast, dtorCall.getArgOperand(0));
847 args[2] = cir::GetGlobalOp::create(builder, handle.getLoc(), handlePtrTy,
848 handle.getSymName());
849 builder.createCallOp(dtorCall.getLoc(), fnAtExit, args);
850 dtorCall->erase();
851 mlir::Block &dtorBlock = dtorRegion.front();
852 entryBB->getOperations().splice(entryBB->end(), dtorBlock.getOperations(),
853 dtorBlock.begin(),
854 std::prev(dtorBlock.end()));
855 }
856
857 // Replace cir.yield with cir.return
858 builder.setInsertionPointToEnd(entryBB);
859 mlir::Operation *yieldOp = nullptr;
860 if (!op.getCtorRegion().empty()) {
861 mlir::Block &block = op.getCtorRegion().front();
862 yieldOp = &block.getOperations().back();
863 } else {
864 assert(!dtorRegion.empty());
865 mlir::Block &block = dtorRegion.front();
866 yieldOp = &block.getOperations().back();
867 }
868
869 assert(isa<cir::YieldOp>(*yieldOp));
870 cir::ReturnOp::create(builder, yieldOp->getLoc());
871 return f;
872}
873
/// Lower a cir.global that carries a ctor and/or dtor region: the region code
/// is moved into a separate init function built by
/// buildCXXGlobalVarDeclInitFunc, both regions are then emptied, and the init
/// function is queued in `dynamicInitializers` so that buildCXXGlobalInitFunc
/// can call it from the module's global initializer.
874void LoweringPreparePass::lowerGlobalOp(GlobalOp op) {
875 mlir::Region &ctorRegion = op.getCtorRegion();
876 mlir::Region &dtorRegion = op.getDtorRegion();
877
878 if (!ctorRegion.empty() || !dtorRegion.empty()) {
879 // Build a variable initialization function and move the initialization code
880 // in the ctor region over.
881 cir::FuncOp f = buildCXXGlobalVarDeclInitFunc(op);
882
883 // Clear the ctor and dtor regions; their bodies were spliced into `f` and
// only their terminators are left behind.
884 ctorRegion.getBlocks().clear();
885 dtorRegion.getBlocks().clear();
886
888 dynamicInitializers.push_back(f);
889 }
890
892}
893
/// Build one `AttributeTy` per (symbol name, priority) pair in `list`, keeping
/// the order of `list`. `AttributeTy` must provide
/// `get(MLIRContext *, name, priority)`; presumably this is the CIR global
/// ctor/dtor attribute type (confirm at the call sites).
894template <typename AttributeTy>
895static llvm::SmallVector<mlir::Attribute>
896prepareCtorDtorAttrList(mlir::MLIRContext *context,
897 llvm::ArrayRef<std::pair<std::string, uint32_t>> list) {
899 for (const auto &[name, priority] : list)
900 attrs.push_back(AttributeTy::get(context, name, priority));
901 return attrs;
902}
903
/// Publish the collected global constructor/destructor lists on the module.
/// `globalCtorList` becomes the array attribute named by
/// CIRDialect::getGlobalCtorsAttrName(), and `globalDtorList` becomes the one
/// named by CIRDialect::getGlobalDtorsAttrName(). Each attribute is set only
/// when its list is non-empty. runOnOperation calls this after
/// buildCXXGlobalInitFunc, which may append the module init function to
/// `globalCtorList`.
904void LoweringPreparePass::buildGlobalCtorDtorList() {
905 if (!globalCtorList.empty()) {
906 llvm::SmallVector<mlir::Attribute> globalCtors =
908 globalCtorList);
909
910 mlirModule->setAttr(cir::CIRDialect::getGlobalCtorsAttrName(),
911 mlir::ArrayAttr::get(&getContext(), globalCtors));
912 }
913
914 if (!globalDtorList.empty()) {
915 llvm::SmallVector<mlir::Attribute> globalDtors =
917 globalDtorList);
918 mlirModule->setAttr(cir::CIRDialect::getGlobalDtorsAttrName(),
919 mlir::ArrayAttr::get(&getContext(), globalDtors));
920 }
921}
922
/// Emit the module-level global initializer function. It is an externally
/// visible `void()` function that calls every per-global init function
/// queued in `dynamicInitializers`, in queue order. The function itself is
/// appended to `globalCtorList` with the default priority. Nothing is emitted
/// when no dynamic initializers were queued.
///
/// Naming: for a C++20 named module (subject to the rest of the `if`
/// condition below), the Itanium module-initializer mangling is used.
/// Otherwise the name is `_GLOBAL__sub_I_` followed by the module's file name
/// as sanitized by getTransformedFileName.
923void LoweringPreparePass::buildCXXGlobalInitFunc() {
924 if (dynamicInitializers.empty())
925 return;
926
927 // TODO: handle globals with a user-specified initialization priority.
928 // TODO: handle default priority more nicely.
930
931 SmallString<256> fnName;
932 // Include the filename in the symbol name. Including "sub_" matches gcc
933 // and makes sure these symbols appear lexicographically behind the symbols
934 // with priority (TBD). Module implementation units behave the same
935 // way as a non-modular TU with imports.
936 // TODO: check CXX20ModuleInits
937 if (astCtx->getCurrentNamedModule() &&
939 llvm::raw_svector_ostream out(fnName);
940 std::unique_ptr<clang::MangleContext> mangleCtx(
941 astCtx->createMangleContext());
942 cast<clang::ItaniumMangleContext>(*mangleCtx)
943 .mangleModuleInitializer(astCtx->getCurrentNamedModule(), out);
944 } else {
945 fnName += "_GLOBAL__sub_I_";
946 fnName += getTransformedFileName(mlirModule);
947 }
948
949 CIRBaseBuilderTy builder(getContext());
950 builder.setInsertionPointToEnd(&mlirModule.getBodyRegion().back());
951 auto fnType = cir::FuncType::get({}, builder.getVoidTy());
952 cir::FuncOp f =
953 buildRuntimeFunction(builder, fnName, mlirModule.getLoc(), fnType,
954 cir::GlobalLinkageKind::ExternalLinkage);
955 builder.setInsertionPointToStart(f.addEntryBlock());
// NOTE(review): the loop variable `f` shadows the outer global init function
// `f`. Inside the loop it names each queued initializer; the ReturnOp after
// the loop refers to the outer function again. Consider renaming.
956 for (cir::FuncOp &f : dynamicInitializers)
957 builder.createCallOp(f.getLoc(), f, {});
958 // Add the global init function (not the individual ctor functions) to the
959 // global ctor list.
960 globalCtorList.emplace_back(fnName,
961 cir::GlobalCtorAttr::getDefaultPriority());
962
963 cir::ReturnOp::create(builder, f.getLoc());
964}
965
967 clang::ASTContext *astCtx,
968 mlir::Operation *op, mlir::Type eltTy,
969 mlir::Value arrayAddr, uint64_t arrayLen,
970 bool isCtor) {
971 // Generate loop to call into ctor/dtor for every element.
972 mlir::Location loc = op->getLoc();
973
974 // TODO: instead of getting the size from the AST context, create alias for
975 // PtrDiffTy and unify with CIRGen stuff.
976 const unsigned sizeTypeSize =
977 astCtx->getTypeSize(astCtx->getSignedSizeType());
978 uint64_t endOffset = isCtor ? arrayLen : arrayLen - 1;
979 mlir::Value endOffsetVal =
980 builder.getUnsignedInt(loc, endOffset, sizeTypeSize);
981
982 auto begin = cir::CastOp::create(builder, loc, eltTy,
983 cir::CastKind::array_to_ptrdecay, arrayAddr);
984 mlir::Value end =
985 cir::PtrStrideOp::create(builder, loc, eltTy, begin, endOffsetVal);
986 mlir::Value start = isCtor ? begin : end;
987 mlir::Value stop = isCtor ? end : begin;
988
989 mlir::Value tmpAddr = builder.createAlloca(
990 loc, /*addr type*/ builder.getPointerTo(eltTy),
991 /*var type*/ eltTy, "__array_idx", builder.getAlignmentAttr(1));
992 builder.createStore(loc, start, tmpAddr);
993
994 cir::DoWhileOp loop = builder.createDoWhile(
995 loc,
996 /*condBuilder=*/
997 [&](mlir::OpBuilder &b, mlir::Location loc) {
998 auto currentElement = cir::LoadOp::create(b, loc, eltTy, tmpAddr);
999 auto cmp = cir::CmpOp::create(builder, loc, cir::CmpOpKind::ne,
1000 currentElement, stop);
1001 builder.createCondition(cmp);
1002 },
1003 /*bodyBuilder=*/
1004 [&](mlir::OpBuilder &b, mlir::Location loc) {
1005 auto currentElement = cir::LoadOp::create(b, loc, eltTy, tmpAddr);
1006
1007 cir::CallOp ctorCall;
1008 op->walk([&](cir::CallOp c) { ctorCall = c; });
1009 assert(ctorCall && "expected ctor call");
1010
1011 // Array elements get constructed in order but destructed in reverse.
1012 mlir::Value stride;
1013 if (isCtor)
1014 stride = builder.getUnsignedInt(loc, 1, sizeTypeSize);
1015 else
1016 stride = builder.getSignedInt(loc, -1, sizeTypeSize);
1017
1018 ctorCall->moveBefore(stride.getDefiningOp());
1019 ctorCall->setOperand(0, currentElement);
1020 auto nextElement = cir::PtrStrideOp::create(builder, loc, eltTy,
1021 currentElement, stride);
1022
1023 // Store the element pointer to the temporary variable
1024 builder.createStore(loc, nextElement, tmpAddr);
1025 builder.createYield(loc);
1026 });
1027
1028 op->replaceAllUsesWith(loop);
1029 op->erase();
1030}
1031
1032void LoweringPreparePass::lowerArrayDtor(cir::ArrayDtor op) {
1033 CIRBaseBuilderTy builder(getContext());
1034 builder.setInsertionPointAfter(op.getOperation());
1035
1036 mlir::Type eltTy = op->getRegion(0).getArgument(0).getType();
1037 assert(!cir::MissingFeatures::vlas());
1038 auto arrayLen =
1039 mlir::cast<cir::ArrayType>(op.getAddr().getType().getPointee()).getSize();
1040 lowerArrayDtorCtorIntoLoop(builder, astCtx, op, eltTy, op.getAddr(), arrayLen,
1041 false);
1042}
1043
1044void LoweringPreparePass::lowerArrayCtor(cir::ArrayCtor op) {
1045 cir::CIRBaseBuilderTy builder(getContext());
1046 builder.setInsertionPointAfter(op.getOperation());
1047
1048 mlir::Type eltTy = op->getRegion(0).getArgument(0).getType();
1049 assert(!cir::MissingFeatures::vlas());
1050 auto arrayLen =
1051 mlir::cast<cir::ArrayType>(op.getAddr().getType().getPointee()).getSize();
1052 lowerArrayDtorCtorIntoLoop(builder, astCtx, op, eltTy, op.getAddr(), arrayLen,
1053 true);
1054}
1055
1056void LoweringPreparePass::lowerTrivialCopyCall(cir::CallOp op) {
1057 cir::FuncOp funcOp = getCalledFunction(op);
1058 if (!funcOp)
1059 return;
1060
1061 std::optional<cir::CtorKind> ctorKind = funcOp.getCxxConstructorKind();
1062 if (ctorKind && *ctorKind == cir::CtorKind::Copy &&
1063 funcOp.isCxxTrivialMemberFunction()) {
1064 // Replace the trivial copy constructor call with a `CopyOp`
1065 CIRBaseBuilderTy builder(getContext());
1066 mlir::ValueRange operands = op.getOperands();
1067 mlir::Value dest = operands[0];
1068 mlir::Value src = operands[1];
1069 builder.setInsertionPoint(op);
1070 builder.createCopy(dest, src);
1071 op.erase();
1072 }
1073}
1074
1075void LoweringPreparePass::runOnOp(mlir::Operation *op) {
1076 if (auto arrayCtor = dyn_cast<cir::ArrayCtor>(op)) {
1077 lowerArrayCtor(arrayCtor);
1078 } else if (auto arrayDtor = dyn_cast<cir::ArrayDtor>(op)) {
1079 lowerArrayDtor(arrayDtor);
1080 } else if (auto cast = mlir::dyn_cast<cir::CastOp>(op)) {
1081 lowerCastOp(cast);
1082 } else if (auto complexDiv = mlir::dyn_cast<cir::ComplexDivOp>(op)) {
1083 lowerComplexDivOp(complexDiv);
1084 } else if (auto complexMul = mlir::dyn_cast<cir::ComplexMulOp>(op)) {
1085 lowerComplexMulOp(complexMul);
1086 } else if (auto glob = mlir::dyn_cast<cir::GlobalOp>(op)) {
1087 lowerGlobalOp(glob);
1088 } else if (auto unary = mlir::dyn_cast<cir::UnaryOp>(op)) {
1089 lowerUnaryOp(unary);
1090 } else if (auto callOp = dyn_cast<cir::CallOp>(op)) {
1091 lowerTrivialCopyCall(callOp);
1092 } else if (auto fnOp = dyn_cast<cir::FuncOp>(op)) {
1093 if (auto globalCtor = fnOp.getGlobalCtorPriority())
1094 globalCtorList.emplace_back(fnOp.getName(), globalCtor.value());
1095 else if (auto globalDtor = fnOp.getGlobalDtorPriority())
1096 globalDtorList.emplace_back(fnOp.getName(), globalDtor.value());
1097 }
1098}
1099
1100void LoweringPreparePass::runOnOperation() {
1101 mlir::Operation *op = getOperation();
1102 if (isa<::mlir::ModuleOp>(op))
1103 mlirModule = cast<::mlir::ModuleOp>(op);
1104
1105 llvm::SmallVector<mlir::Operation *> opsToTransform;
1106
1107 op->walk([&](mlir::Operation *op) {
1108 if (mlir::isa<cir::ArrayCtor, cir::ArrayDtor, cir::CastOp,
1109 cir::ComplexMulOp, cir::ComplexDivOp, cir::DynamicCastOp,
1110 cir::FuncOp, cir::CallOp, cir::GlobalOp, cir::UnaryOp>(op))
1111 opsToTransform.push_back(op);
1112 });
1113
1114 for (mlir::Operation *o : opsToTransform)
1115 runOnOp(o);
1116
1117 buildCXXGlobalInitFunc();
1118 buildGlobalCtorDtorList();
1119}
1120
1121std::unique_ptr<Pass> mlir::createLoweringPreparePass() {
1122 return std::make_unique<LoweringPreparePass>();
1123}
1124
/// Create the LoweringPrepare pass and attach the given clang AST context.
/// The pass dereferences that context in several lowerings: it sizes the
/// stride integer type of array ctor/dtor loops, and it mangles the module
/// initializer name in buildCXXGlobalInitFunc.
1125std::unique_ptr<Pass>
1127 auto pass = std::make_unique<LoweringPreparePass>();
1128 pass->setASTContext(astCtx);
1129 return std::move(pass);
1130}
Defines the clang::ASTContext interface.
static mlir::Value buildRangeReductionComplexDiv(CIRBaseBuilderTy &builder, mlir::Location loc, mlir::Value lhsReal, mlir::Value lhsImag, mlir::Value rhsReal, mlir::Value rhsImag)
static void lowerArrayDtorCtorIntoLoop(cir::CIRBaseBuilderTy &builder, clang::ASTContext *astCtx, mlir::Operation *op, mlir::Type eltTy, mlir::Value arrayAddr, uint64_t arrayLen, bool isCtor)
static llvm::StringRef getComplexDivLibCallName(llvm::APFloat::Semantics semantics)
static llvm::SmallVector< mlir::Attribute > prepareCtorDtorAttrList(mlir::MLIRContext *context, llvm::ArrayRef< std::pair< std::string, uint32_t > > list)
static llvm::StringRef getComplexMulLibCallName(llvm::APFloat::Semantics semantics)
static mlir::Value buildComplexBinOpLibCall(LoweringPreparePass &pass, CIRBaseBuilderTy &builder, llvm::StringRef(*libFuncNameGetter)(llvm::APFloat::Semantics), mlir::Location loc, cir::ComplexType ty, mlir::Value lhsReal, mlir::Value lhsImag, mlir::Value rhsReal, mlir::Value rhsImag)
static mlir::Value lowerComplexMul(LoweringPreparePass &pass, CIRBaseBuilderTy &builder, mlir::Location loc, cir::ComplexMulOp op, mlir::Value lhsReal, mlir::Value lhsImag, mlir::Value rhsReal, mlir::Value rhsImag)
static SmallString< 128 > getTransformedFileName(mlir::ModuleOp mlirModule)
static mlir::Value lowerComplexToComplexCast(mlir::MLIRContext &ctx, cir::CastOp op, cir::CastKind scalarCastKind)
static mlir::Value lowerComplexToScalarCast(mlir::MLIRContext &ctx, cir::CastOp op, cir::CastKind elemToBoolKind)
static mlir::Value buildAlgebraicComplexDiv(CIRBaseBuilderTy &builder, mlir::Location loc, mlir::Value lhsReal, mlir::Value lhsImag, mlir::Value rhsReal, mlir::Value rhsImag)
static cir::FuncOp getCalledFunction(cir::CallOp callOp)
Return the FuncOp called by callOp.
static mlir::Type higherPrecisionElementTypeForComplexArithmetic(mlir::MLIRContext &context, clang::ASTContext &cc, CIRBaseBuilderTy &builder, mlir::Type elementType)
static mlir::Value lowerScalarToComplexCast(mlir::MLIRContext &ctx, cir::CastOp op)
static mlir::Value lowerComplexDiv(LoweringPreparePass &pass, CIRBaseBuilderTy &builder, mlir::Location loc, cir::ComplexDivOp op, mlir::Value lhsReal, mlir::Value lhsImag, mlir::Value rhsReal, mlir::Value rhsImag, mlir::MLIRContext &mlirCx, clang::ASTContext &cc)
Defines the clang::Module class, which describes a module in the source code.
__device__ __2f16 b
__device__ __2f16 float c
mlir::Value createLogicalOr(mlir::Location loc, mlir::Value lhs, mlir::Value rhs)
cir::ConditionOp createCondition(mlir::Value condition)
Create a loop condition.
cir::VoidType getVoidTy()
cir::ConstantOp getNullValue(mlir::Type ty, mlir::Location loc)
mlir::Value createCast(mlir::Location loc, cir::CastKind kind, mlir::Value src, mlir::Type newTy)
cir::PointerType getPointerTo(mlir::Type ty)
mlir::Value createComplexImag(mlir::Location loc, mlir::Value operand)
cir::DoWhileOp createDoWhile(mlir::Location loc, llvm::function_ref< void(mlir::OpBuilder &, mlir::Location)> condBuilder, llvm::function_ref< void(mlir::OpBuilder &, mlir::Location)> bodyBuilder)
Create a do-while operation.
cir::CopyOp createCopy(mlir::Value dst, mlir::Value src, bool isVolatile=false)
Create a copy with inferred length.
cir::CallOp createCallOp(mlir::Location loc, mlir::SymbolRefAttr callee, mlir::Type returnType, mlir::ValueRange operands, llvm::ArrayRef< mlir::NamedAttribute > attrs={})
mlir::Value getSignedInt(mlir::Location loc, int64_t val, unsigned numBits)
cir::CmpOp createCompare(mlir::Location loc, cir::CmpOpKind kind, mlir::Value lhs, mlir::Value rhs)
mlir::IntegerAttr getAlignmentAttr(clang::CharUnits alignment)
mlir::Value createBinop(mlir::Location loc, mlir::Value lhs, cir::BinOpKind kind, mlir::Value rhs)
mlir::Value createComplexCreate(mlir::Location loc, mlir::Value real, mlir::Value imag)
mlir::Value createIsNaN(mlir::Location loc, mlir::Value operand)
cir::StoreOp createStore(mlir::Location loc, mlir::Value val, mlir::Value dst, bool isVolatile=false, mlir::IntegerAttr align={}, cir::SyncScopeKindAttr scope={}, cir::MemOrderAttr order={})
cir::YieldOp createYield(mlir::Location loc, mlir::ValueRange value={})
Create a yield operation.
mlir::Value createLogicalAnd(mlir::Location loc, mlir::Value lhs, mlir::Value rhs)
mlir::Value createUnaryOp(mlir::Location loc, cir::UnaryOpKind kind, mlir::Value operand)
mlir::Value createAlloca(mlir::Location loc, cir::PointerType addrType, mlir::Type type, llvm::StringRef name, mlir::IntegerAttr alignment, mlir::Value dynAllocSize)
cir::BoolType getBoolTy()
mlir::Value getUnsignedInt(mlir::Location loc, uint64_t val, unsigned numBits)
mlir::Value createComplexReal(mlir::Location loc, mlir::Value operand)
Holds long-lived AST nodes (such as types and decls) that can be referred to throughout the semantic ...
Definition ASTContext.h:220
MangleContext * createMangleContext(const TargetInfo *T=nullptr)
If T is a null pointer, assume the target in the ASTContext.
const LangOptions & getLangOpts() const
Definition ASTContext.h:944
uint64_t getTypeSize(QualType T) const
Return the size of the specified (complete) type T, in bits.
const TargetInfo & getTargetInfo() const
Definition ASTContext.h:909
QualType getSignedSizeType() const
Return the unique signed counterpart of the integer type corresponding to size_t.
Module * getCurrentNamedModule() const
Get module under construction, nullptr if this is not a C++20 module.
bool isModuleImplementation() const
Is this a module implementation.
Definition Module.h:664
Exposes information about the current target.
Definition TargetInfo.h:226
const llvm::fltSemantics & getDoubleFormat() const
Definition TargetInfo.h:803
const llvm::fltSemantics & getHalfFormat() const
Definition TargetInfo.h:788
const llvm::fltSemantics & getBFloat16Format() const
Definition TargetInfo.h:798
const llvm::fltSemantics & getLongDoubleFormat() const
Definition TargetInfo.h:809
const llvm::fltSemantics & getFloatFormat() const
Definition TargetInfo.h:793
const llvm::fltSemantics & getFloat128Format() const
Definition TargetInfo.h:817
Defines the clang::TargetInfo interface.
LLVM_READONLY bool isPreprocessingNumberBody(unsigned char c)
Return true if this is the body character of a C preprocessing number, which is [a-zA-Z0-9_.].
Definition CharInfo.h:168
unsigned int uint32_t
std::unique_ptr< Pass > createLoweringPreparePass()
static bool opGlobalThreadLocal()
static bool opGlobalAnnotations()
static bool opGlobalCtorPriority()
static bool opFuncExtraAttrs()
static bool fastMathFlags()
static bool astVarDeclInterface()