clang 22.0.0git
LoweringPrepare.cpp
Go to the documentation of this file.
1//===- LoweringPrepare.cpp - preparation work for LLVM lowering -----------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "PassDetail.h"
17
18#include <memory>
19
20using namespace mlir;
21using namespace cir;
22
23namespace {
24struct LoweringPreparePass : public LoweringPrepareBase<LoweringPreparePass> {
25 LoweringPreparePass() = default;
26 void runOnOperation() override;
27
28 void runOnOp(mlir::Operation *op);
29 void lowerCastOp(cir::CastOp op);
30 void lowerComplexDivOp(cir::ComplexDivOp op);
31 void lowerComplexMulOp(cir::ComplexMulOp op);
32 void lowerUnaryOp(cir::UnaryOp op);
33 void lowerArrayDtor(cir::ArrayDtor op);
34 void lowerArrayCtor(cir::ArrayCtor op);
35
36 cir::FuncOp buildRuntimeFunction(
37 mlir::OpBuilder &builder, llvm::StringRef name, mlir::Location loc,
38 cir::FuncType type,
39 cir::GlobalLinkageKind linkage = cir::GlobalLinkageKind::ExternalLinkage);
40
41 ///
42 /// AST related
43 /// -----------
44
45 clang::ASTContext *astCtx;
46
47 /// Tracks current module.
48 mlir::ModuleOp mlirModule;
49
50 void setASTContext(clang::ASTContext *c) { astCtx = c; }
51};
52
53} // namespace
54
55cir::FuncOp LoweringPreparePass::buildRuntimeFunction(
56 mlir::OpBuilder &builder, llvm::StringRef name, mlir::Location loc,
57 cir::FuncType type, cir::GlobalLinkageKind linkage) {
58 cir::FuncOp f = dyn_cast_or_null<FuncOp>(SymbolTable::lookupNearestSymbolFrom(
59 mlirModule, StringAttr::get(mlirModule->getContext(), name)));
60 if (!f) {
61 f = builder.create<cir::FuncOp>(loc, name, type);
62 f.setLinkageAttr(
63 cir::GlobalLinkageKindAttr::get(builder.getContext(), linkage));
64 mlir::SymbolTable::setSymbolVisibility(
65 f, mlir::SymbolTable::Visibility::Private);
66
68 }
69 return f;
70}
71
72static mlir::Value lowerScalarToComplexCast(mlir::MLIRContext &ctx,
73 cir::CastOp op) {
74 cir::CIRBaseBuilderTy builder(ctx);
75 builder.setInsertionPoint(op);
76
77 mlir::Value src = op.getSrc();
78 mlir::Value imag = builder.getNullValue(src.getType(), op.getLoc());
79 return builder.createComplexCreate(op.getLoc(), src, imag);
80}
81
82static mlir::Value lowerComplexToScalarCast(mlir::MLIRContext &ctx,
83 cir::CastOp op,
84 cir::CastKind elemToBoolKind) {
85 cir::CIRBaseBuilderTy builder(ctx);
86 builder.setInsertionPoint(op);
87
88 mlir::Value src = op.getSrc();
89 if (!mlir::isa<cir::BoolType>(op.getType()))
90 return builder.createComplexReal(op.getLoc(), src);
91
92 // Complex cast to bool: (bool)(a+bi) => (bool)a || (bool)b
93 mlir::Value srcReal = builder.createComplexReal(op.getLoc(), src);
94 mlir::Value srcImag = builder.createComplexImag(op.getLoc(), src);
95
96 cir::BoolType boolTy = builder.getBoolTy();
97 mlir::Value srcRealToBool =
98 builder.createCast(op.getLoc(), elemToBoolKind, srcReal, boolTy);
99 mlir::Value srcImagToBool =
100 builder.createCast(op.getLoc(), elemToBoolKind, srcImag, boolTy);
101 return builder.createLogicalOr(op.getLoc(), srcRealToBool, srcImagToBool);
102}
103
104static mlir::Value lowerComplexToComplexCast(mlir::MLIRContext &ctx,
105 cir::CastOp op,
106 cir::CastKind scalarCastKind) {
107 CIRBaseBuilderTy builder(ctx);
108 builder.setInsertionPoint(op);
109
110 mlir::Value src = op.getSrc();
111 auto dstComplexElemTy =
112 mlir::cast<cir::ComplexType>(op.getType()).getElementType();
113
114 mlir::Value srcReal = builder.createComplexReal(op.getLoc(), src);
115 mlir::Value srcImag = builder.createComplexImag(op.getLoc(), src);
116
117 mlir::Value dstReal = builder.createCast(op.getLoc(), scalarCastKind, srcReal,
118 dstComplexElemTy);
119 mlir::Value dstImag = builder.createCast(op.getLoc(), scalarCastKind, srcImag,
120 dstComplexElemTy);
121 return builder.createComplexCreate(op.getLoc(), dstReal, dstImag);
122}
123
124void LoweringPreparePass::lowerCastOp(cir::CastOp op) {
125 mlir::MLIRContext &ctx = getContext();
126 mlir::Value loweredValue = [&]() -> mlir::Value {
127 switch (op.getKind()) {
128 case cir::CastKind::float_to_complex:
129 case cir::CastKind::int_to_complex:
130 return lowerScalarToComplexCast(ctx, op);
131 case cir::CastKind::float_complex_to_real:
132 case cir::CastKind::int_complex_to_real:
133 return lowerComplexToScalarCast(ctx, op, op.getKind());
134 case cir::CastKind::float_complex_to_bool:
135 return lowerComplexToScalarCast(ctx, op, cir::CastKind::float_to_bool);
136 case cir::CastKind::int_complex_to_bool:
137 return lowerComplexToScalarCast(ctx, op, cir::CastKind::int_to_bool);
138 case cir::CastKind::float_complex:
139 return lowerComplexToComplexCast(ctx, op, cir::CastKind::floating);
140 case cir::CastKind::float_complex_to_int_complex:
141 return lowerComplexToComplexCast(ctx, op, cir::CastKind::float_to_int);
142 case cir::CastKind::int_complex:
143 return lowerComplexToComplexCast(ctx, op, cir::CastKind::integral);
144 case cir::CastKind::int_complex_to_float_complex:
145 return lowerComplexToComplexCast(ctx, op, cir::CastKind::int_to_float);
146 default:
147 return nullptr;
148 }
149 }();
150
151 if (loweredValue) {
152 op.replaceAllUsesWith(loweredValue);
153 op.erase();
154 }
155}
156
157static mlir::Value buildComplexBinOpLibCall(
158 LoweringPreparePass &pass, CIRBaseBuilderTy &builder,
159 llvm::StringRef (*libFuncNameGetter)(llvm::APFloat::Semantics),
160 mlir::Location loc, cir::ComplexType ty, mlir::Value lhsReal,
161 mlir::Value lhsImag, mlir::Value rhsReal, mlir::Value rhsImag) {
162 cir::FPTypeInterface elementTy =
163 mlir::cast<cir::FPTypeInterface>(ty.getElementType());
164
165 llvm::StringRef libFuncName = libFuncNameGetter(
166 llvm::APFloat::SemanticsToEnum(elementTy.getFloatSemantics()));
167 llvm::SmallVector<mlir::Type, 4> libFuncInputTypes(4, elementTy);
168
169 cir::FuncType libFuncTy = cir::FuncType::get(libFuncInputTypes, ty);
170
171 // Insert a declaration for the runtime function to be used in Complex
172 // multiplication and division when needed
173 cir::FuncOp libFunc;
174 {
175 mlir::OpBuilder::InsertionGuard ipGuard{builder};
176 builder.setInsertionPointToStart(pass.mlirModule.getBody());
177 libFunc = pass.buildRuntimeFunction(builder, libFuncName, loc, libFuncTy);
178 }
179
180 cir::CallOp call =
181 builder.createCallOp(loc, libFunc, {lhsReal, lhsImag, rhsReal, rhsImag});
182 return call.getResult();
183}
184
185static llvm::StringRef
186getComplexDivLibCallName(llvm::APFloat::Semantics semantics) {
187 switch (semantics) {
188 case llvm::APFloat::S_IEEEhalf:
189 return "__divhc3";
190 case llvm::APFloat::S_IEEEsingle:
191 return "__divsc3";
192 case llvm::APFloat::S_IEEEdouble:
193 return "__divdc3";
194 case llvm::APFloat::S_PPCDoubleDouble:
195 return "__divtc3";
196 case llvm::APFloat::S_x87DoubleExtended:
197 return "__divxc3";
198 case llvm::APFloat::S_IEEEquad:
199 return "__divtc3";
200 default:
201 llvm_unreachable("unsupported floating point type");
202 }
203}
204
205static mlir::Value
206buildAlgebraicComplexDiv(CIRBaseBuilderTy &builder, mlir::Location loc,
207 mlir::Value lhsReal, mlir::Value lhsImag,
208 mlir::Value rhsReal, mlir::Value rhsImag) {
209 // (a+bi) / (c+di) = ((ac+bd)/(cc+dd)) + ((bc-ad)/(cc+dd))i
210 mlir::Value &a = lhsReal;
211 mlir::Value &b = lhsImag;
212 mlir::Value &c = rhsReal;
213 mlir::Value &d = rhsImag;
214
215 mlir::Value ac = builder.createBinop(loc, a, cir::BinOpKind::Mul, c); // a*c
216 mlir::Value bd = builder.createBinop(loc, b, cir::BinOpKind::Mul, d); // b*d
217 mlir::Value cc = builder.createBinop(loc, c, cir::BinOpKind::Mul, c); // c*c
218 mlir::Value dd = builder.createBinop(loc, d, cir::BinOpKind::Mul, d); // d*d
219 mlir::Value acbd =
220 builder.createBinop(loc, ac, cir::BinOpKind::Add, bd); // ac+bd
221 mlir::Value ccdd =
222 builder.createBinop(loc, cc, cir::BinOpKind::Add, dd); // cc+dd
223 mlir::Value resultReal =
224 builder.createBinop(loc, acbd, cir::BinOpKind::Div, ccdd);
225
226 mlir::Value bc = builder.createBinop(loc, b, cir::BinOpKind::Mul, c); // b*c
227 mlir::Value ad = builder.createBinop(loc, a, cir::BinOpKind::Mul, d); // a*d
228 mlir::Value bcad =
229 builder.createBinop(loc, bc, cir::BinOpKind::Sub, ad); // bc-ad
230 mlir::Value resultImag =
231 builder.createBinop(loc, bcad, cir::BinOpKind::Div, ccdd);
232 return builder.createComplexCreate(loc, resultReal, resultImag);
233}
234
235static mlir::Value
237 mlir::Value lhsReal, mlir::Value lhsImag,
238 mlir::Value rhsReal, mlir::Value rhsImag) {
239 // Implements Smith's algorithm for complex division.
240 // SMITH, R. L. Algorithm 116: Complex division. Commun. ACM 5, 8 (1962).
241
242 // Let:
243 // - lhs := a+bi
244 // - rhs := c+di
245 // - result := lhs / rhs = e+fi
246 //
247 // The algorithm pseudocode looks like follows:
248 // if fabs(c) >= fabs(d):
249 // r := d / c
250 // tmp := c + r*d
251 // e = (a + b*r) / tmp
252 // f = (b - a*r) / tmp
253 // else:
254 // r := c / d
255 // tmp := d + r*c
256 // e = (a*r + b) / tmp
257 // f = (b*r - a) / tmp
258
259 mlir::Value &a = lhsReal;
260 mlir::Value &b = lhsImag;
261 mlir::Value &c = rhsReal;
262 mlir::Value &d = rhsImag;
263
264 auto trueBranchBuilder = [&](mlir::OpBuilder &, mlir::Location) {
265 mlir::Value r = builder.createBinop(loc, d, cir::BinOpKind::Div,
266 c); // r := d / c
267 mlir::Value rd = builder.createBinop(loc, r, cir::BinOpKind::Mul, d); // r*d
268 mlir::Value tmp = builder.createBinop(loc, c, cir::BinOpKind::Add,
269 rd); // tmp := c + r*d
270
271 mlir::Value br = builder.createBinop(loc, b, cir::BinOpKind::Mul, r); // b*r
272 mlir::Value abr =
273 builder.createBinop(loc, a, cir::BinOpKind::Add, br); // a + b*r
274 mlir::Value e = builder.createBinop(loc, abr, cir::BinOpKind::Div, tmp);
275
276 mlir::Value ar = builder.createBinop(loc, a, cir::BinOpKind::Mul, r); // a*r
277 mlir::Value bar =
278 builder.createBinop(loc, b, cir::BinOpKind::Sub, ar); // b - a*r
279 mlir::Value f = builder.createBinop(loc, bar, cir::BinOpKind::Div, tmp);
280
281 mlir::Value result = builder.createComplexCreate(loc, e, f);
282 builder.createYield(loc, result);
283 };
284
285 auto falseBranchBuilder = [&](mlir::OpBuilder &, mlir::Location) {
286 mlir::Value r = builder.createBinop(loc, c, cir::BinOpKind::Div,
287 d); // r := c / d
288 mlir::Value rc = builder.createBinop(loc, r, cir::BinOpKind::Mul, c); // r*c
289 mlir::Value tmp = builder.createBinop(loc, d, cir::BinOpKind::Add,
290 rc); // tmp := d + r*c
291
292 mlir::Value ar = builder.createBinop(loc, a, cir::BinOpKind::Mul, r); // a*r
293 mlir::Value arb =
294 builder.createBinop(loc, ar, cir::BinOpKind::Add, b); // a*r + b
295 mlir::Value e = builder.createBinop(loc, arb, cir::BinOpKind::Div, tmp);
296
297 mlir::Value br = builder.createBinop(loc, b, cir::BinOpKind::Mul, r); // b*r
298 mlir::Value bra =
299 builder.createBinop(loc, br, cir::BinOpKind::Sub, a); // b*r - a
300 mlir::Value f = builder.createBinop(loc, bra, cir::BinOpKind::Div, tmp);
301
302 mlir::Value result = builder.createComplexCreate(loc, e, f);
303 builder.createYield(loc, result);
304 };
305
306 auto cFabs = builder.create<cir::FAbsOp>(loc, c);
307 auto dFabs = builder.create<cir::FAbsOp>(loc, d);
308 cir::CmpOp cmpResult =
309 builder.createCompare(loc, cir::CmpOpKind::ge, cFabs, dFabs);
310 auto ternary = builder.create<cir::TernaryOp>(
311 loc, cmpResult, trueBranchBuilder, falseBranchBuilder);
312
313 return ternary.getResult();
314}
315
317 mlir::MLIRContext &context, clang::ASTContext &cc,
318 CIRBaseBuilderTy &builder, mlir::Type elementType) {
319
320 auto getHigherPrecisionFPType = [&context](mlir::Type type) -> mlir::Type {
321 if (mlir::isa<cir::FP16Type>(type))
322 return cir::SingleType::get(&context);
323
324 if (mlir::isa<cir::SingleType>(type) || mlir::isa<cir::BF16Type>(type))
325 return cir::DoubleType::get(&context);
326
327 if (mlir::isa<cir::DoubleType>(type))
328 return cir::LongDoubleType::get(&context, type);
329
330 return type;
331 };
332
333 auto getFloatTypeSemantics =
334 [&cc](mlir::Type type) -> const llvm::fltSemantics & {
335 const clang::TargetInfo &info = cc.getTargetInfo();
336 if (mlir::isa<cir::FP16Type>(type))
337 return info.getHalfFormat();
338
339 if (mlir::isa<cir::BF16Type>(type))
340 return info.getBFloat16Format();
341
342 if (mlir::isa<cir::SingleType>(type))
343 return info.getFloatFormat();
344
345 if (mlir::isa<cir::DoubleType>(type))
346 return info.getDoubleFormat();
347
348 if (mlir::isa<cir::LongDoubleType>(type)) {
349 if (cc.getLangOpts().OpenMP && cc.getLangOpts().OpenMPIsTargetDevice)
350 llvm_unreachable("NYI Float type semantics with OpenMP");
351 return info.getLongDoubleFormat();
352 }
353
354 if (mlir::isa<cir::FP128Type>(type)) {
355 if (cc.getLangOpts().OpenMP && cc.getLangOpts().OpenMPIsTargetDevice)
356 llvm_unreachable("NYI Float type semantics with OpenMP");
357 return info.getFloat128Format();
358 }
359
360 assert(false && "Unsupported float type semantics");
361 };
362
363 const mlir::Type higherElementType = getHigherPrecisionFPType(elementType);
364 const llvm::fltSemantics &elementTypeSemantics =
365 getFloatTypeSemantics(elementType);
366 const llvm::fltSemantics &higherElementTypeSemantics =
367 getFloatTypeSemantics(higherElementType);
368
369 // Check that the promoted type can handle the intermediate values without
370 // overflowing. This can be interpreted as:
371 // (SmallerType.LargestFiniteVal * SmallerType.LargestFiniteVal) * 2 <=
372 // LargerType.LargestFiniteVal.
373 // In terms of exponent it gives this formula:
374 // (SmallerType.LargestFiniteVal * SmallerType.LargestFiniteVal
375 // doubles the exponent of SmallerType.LargestFiniteVal)
376 if (llvm::APFloat::semanticsMaxExponent(elementTypeSemantics) * 2 + 1 <=
377 llvm::APFloat::semanticsMaxExponent(higherElementTypeSemantics)) {
378 return higherElementType;
379 }
380
381 // The intermediate values can't be represented in the promoted type
382 // without overflowing.
383 return {};
384}
385
386static mlir::Value
387lowerComplexDiv(LoweringPreparePass &pass, CIRBaseBuilderTy &builder,
388 mlir::Location loc, cir::ComplexDivOp op, mlir::Value lhsReal,
389 mlir::Value lhsImag, mlir::Value rhsReal, mlir::Value rhsImag,
390 mlir::MLIRContext &mlirCx, clang::ASTContext &cc) {
391 cir::ComplexType complexTy = op.getType();
392 if (mlir::isa<cir::FPTypeInterface>(complexTy.getElementType())) {
393 cir::ComplexRangeKind range = op.getRange();
394 if (range == cir::ComplexRangeKind::Improved)
395 return buildRangeReductionComplexDiv(builder, loc, lhsReal, lhsImag,
396 rhsReal, rhsImag);
397
398 if (range == cir::ComplexRangeKind::Full)
400 loc, complexTy, lhsReal, lhsImag, rhsReal,
401 rhsImag);
402
403 if (range == cir::ComplexRangeKind::Promoted) {
404 mlir::Type originalElementType = complexTy.getElementType();
405 mlir::Type higherPrecisionElementType =
407 originalElementType);
408
409 if (!higherPrecisionElementType)
410 return buildRangeReductionComplexDiv(builder, loc, lhsReal, lhsImag,
411 rhsReal, rhsImag);
412
413 cir::CastKind floatingCastKind = cir::CastKind::floating;
414 lhsReal = builder.createCast(floatingCastKind, lhsReal,
415 higherPrecisionElementType);
416 lhsImag = builder.createCast(floatingCastKind, lhsImag,
417 higherPrecisionElementType);
418 rhsReal = builder.createCast(floatingCastKind, rhsReal,
419 higherPrecisionElementType);
420 rhsImag = builder.createCast(floatingCastKind, rhsImag,
421 higherPrecisionElementType);
422
423 mlir::Value algebraicResult = buildAlgebraicComplexDiv(
424 builder, loc, lhsReal, lhsImag, rhsReal, rhsImag);
425
426 mlir::Value resultReal = builder.createComplexReal(loc, algebraicResult);
427 mlir::Value resultImag = builder.createComplexImag(loc, algebraicResult);
428
429 mlir::Value finalReal =
430 builder.createCast(floatingCastKind, resultReal, originalElementType);
431 mlir::Value finalImag =
432 builder.createCast(floatingCastKind, resultImag, originalElementType);
433 return builder.createComplexCreate(loc, finalReal, finalImag);
434 }
435 }
436
437 return buildAlgebraicComplexDiv(builder, loc, lhsReal, lhsImag, rhsReal,
438 rhsImag);
439}
440
441void LoweringPreparePass::lowerComplexDivOp(cir::ComplexDivOp op) {
442 cir::CIRBaseBuilderTy builder(getContext());
443 builder.setInsertionPointAfter(op);
444 mlir::Location loc = op.getLoc();
445 mlir::TypedValue<cir::ComplexType> lhs = op.getLhs();
446 mlir::TypedValue<cir::ComplexType> rhs = op.getRhs();
447 mlir::Value lhsReal = builder.createComplexReal(loc, lhs);
448 mlir::Value lhsImag = builder.createComplexImag(loc, lhs);
449 mlir::Value rhsReal = builder.createComplexReal(loc, rhs);
450 mlir::Value rhsImag = builder.createComplexImag(loc, rhs);
451
452 mlir::Value loweredResult =
453 lowerComplexDiv(*this, builder, loc, op, lhsReal, lhsImag, rhsReal,
454 rhsImag, getContext(), *astCtx);
455 op.replaceAllUsesWith(loweredResult);
456 op.erase();
457}
458
459static llvm::StringRef
460getComplexMulLibCallName(llvm::APFloat::Semantics semantics) {
461 switch (semantics) {
462 case llvm::APFloat::S_IEEEhalf:
463 return "__mulhc3";
464 case llvm::APFloat::S_IEEEsingle:
465 return "__mulsc3";
466 case llvm::APFloat::S_IEEEdouble:
467 return "__muldc3";
468 case llvm::APFloat::S_PPCDoubleDouble:
469 return "__multc3";
470 case llvm::APFloat::S_x87DoubleExtended:
471 return "__mulxc3";
472 case llvm::APFloat::S_IEEEquad:
473 return "__multc3";
474 default:
475 llvm_unreachable("unsupported floating point type");
476 }
477}
478
479static mlir::Value lowerComplexMul(LoweringPreparePass &pass,
480 CIRBaseBuilderTy &builder,
481 mlir::Location loc, cir::ComplexMulOp op,
482 mlir::Value lhsReal, mlir::Value lhsImag,
483 mlir::Value rhsReal, mlir::Value rhsImag) {
484 // (a+bi) * (c+di) = (ac-bd) + (ad+bc)i
485 mlir::Value resultRealLhs =
486 builder.createBinop(loc, lhsReal, cir::BinOpKind::Mul, rhsReal);
487 mlir::Value resultRealRhs =
488 builder.createBinop(loc, lhsImag, cir::BinOpKind::Mul, rhsImag);
489 mlir::Value resultImagLhs =
490 builder.createBinop(loc, lhsReal, cir::BinOpKind::Mul, rhsImag);
491 mlir::Value resultImagRhs =
492 builder.createBinop(loc, lhsImag, cir::BinOpKind::Mul, rhsReal);
493 mlir::Value resultReal = builder.createBinop(
494 loc, resultRealLhs, cir::BinOpKind::Sub, resultRealRhs);
495 mlir::Value resultImag = builder.createBinop(
496 loc, resultImagLhs, cir::BinOpKind::Add, resultImagRhs);
497 mlir::Value algebraicResult =
498 builder.createComplexCreate(loc, resultReal, resultImag);
499
500 cir::ComplexType complexTy = op.getType();
501 cir::ComplexRangeKind rangeKind = op.getRange();
502 if (mlir::isa<cir::IntType>(complexTy.getElementType()) ||
503 rangeKind == cir::ComplexRangeKind::Basic ||
504 rangeKind == cir::ComplexRangeKind::Improved ||
505 rangeKind == cir::ComplexRangeKind::Promoted)
506 return algebraicResult;
507
509
510 // Check whether the real part and the imaginary part of the result are both
511 // NaN. If so, emit a library call to compute the multiplication instead.
512 // We check a value against NaN by comparing the value against itself.
513 mlir::Value resultRealIsNaN = builder.createIsNaN(loc, resultReal);
514 mlir::Value resultImagIsNaN = builder.createIsNaN(loc, resultImag);
515 mlir::Value resultRealAndImagAreNaN =
516 builder.createLogicalAnd(loc, resultRealIsNaN, resultImagIsNaN);
517
518 return builder
519 .create<cir::TernaryOp>(
520 loc, resultRealAndImagAreNaN,
521 [&](mlir::OpBuilder &, mlir::Location) {
522 mlir::Value libCallResult = buildComplexBinOpLibCall(
523 pass, builder, &getComplexMulLibCallName, loc, complexTy,
524 lhsReal, lhsImag, rhsReal, rhsImag);
525 builder.createYield(loc, libCallResult);
526 },
527 [&](mlir::OpBuilder &, mlir::Location) {
528 builder.createYield(loc, algebraicResult);
529 })
530 .getResult();
531}
532
533void LoweringPreparePass::lowerComplexMulOp(cir::ComplexMulOp op) {
534 cir::CIRBaseBuilderTy builder(getContext());
535 builder.setInsertionPointAfter(op);
536 mlir::Location loc = op.getLoc();
537 mlir::TypedValue<cir::ComplexType> lhs = op.getLhs();
538 mlir::TypedValue<cir::ComplexType> rhs = op.getRhs();
539 mlir::Value lhsReal = builder.createComplexReal(loc, lhs);
540 mlir::Value lhsImag = builder.createComplexImag(loc, lhs);
541 mlir::Value rhsReal = builder.createComplexReal(loc, rhs);
542 mlir::Value rhsImag = builder.createComplexImag(loc, rhs);
543 mlir::Value loweredResult = lowerComplexMul(*this, builder, loc, op, lhsReal,
544 lhsImag, rhsReal, rhsImag);
545 op.replaceAllUsesWith(loweredResult);
546 op.erase();
547}
548
549void LoweringPreparePass::lowerUnaryOp(cir::UnaryOp op) {
550 mlir::Type ty = op.getType();
551 if (!mlir::isa<cir::ComplexType>(ty))
552 return;
553
554 mlir::Location loc = op.getLoc();
555 cir::UnaryOpKind opKind = op.getKind();
556
557 CIRBaseBuilderTy builder(getContext());
558 builder.setInsertionPointAfter(op);
559
560 mlir::Value operand = op.getInput();
561 mlir::Value operandReal = builder.createComplexReal(loc, operand);
562 mlir::Value operandImag = builder.createComplexImag(loc, operand);
563
564 mlir::Value resultReal;
565 mlir::Value resultImag;
566
567 switch (opKind) {
568 case cir::UnaryOpKind::Inc:
569 case cir::UnaryOpKind::Dec:
570 resultReal = builder.createUnaryOp(loc, opKind, operandReal);
571 resultImag = operandImag;
572 break;
573
574 case cir::UnaryOpKind::Plus:
575 case cir::UnaryOpKind::Minus:
576 resultReal = builder.createUnaryOp(loc, opKind, operandReal);
577 resultImag = builder.createUnaryOp(loc, opKind, operandImag);
578 break;
579
580 case cir::UnaryOpKind::Not:
581 resultReal = operandReal;
582 resultImag =
583 builder.createUnaryOp(loc, cir::UnaryOpKind::Minus, operandImag);
584 break;
585 }
586
587 mlir::Value result = builder.createComplexCreate(loc, resultReal, resultImag);
588 op.replaceAllUsesWith(result);
589 op.erase();
590}
591
593 clang::ASTContext *astCtx,
594 mlir::Operation *op, mlir::Type eltTy,
595 mlir::Value arrayAddr, uint64_t arrayLen,
596 bool isCtor) {
597 // Generate loop to call into ctor/dtor for every element.
598 mlir::Location loc = op->getLoc();
599
600 // TODO: instead of getting the size from the AST context, create alias for
601 // PtrDiffTy and unify with CIRGen stuff.
602 const unsigned sizeTypeSize =
603 astCtx->getTypeSize(astCtx->getSignedSizeType());
604 uint64_t endOffset = isCtor ? arrayLen : arrayLen - 1;
605 mlir::Value endOffsetVal =
606 builder.getUnsignedInt(loc, endOffset, sizeTypeSize);
607
608 auto begin = cir::CastOp::create(builder, loc, eltTy,
609 cir::CastKind::array_to_ptrdecay, arrayAddr);
610 mlir::Value end =
611 cir::PtrStrideOp::create(builder, loc, eltTy, begin, endOffsetVal);
612 mlir::Value start = isCtor ? begin : end;
613 mlir::Value stop = isCtor ? end : begin;
614
615 mlir::Value tmpAddr = builder.createAlloca(
616 loc, /*addr type*/ builder.getPointerTo(eltTy),
617 /*var type*/ eltTy, "__array_idx", builder.getAlignmentAttr(1));
618 builder.createStore(loc, start, tmpAddr);
619
620 cir::DoWhileOp loop = builder.createDoWhile(
621 loc,
622 /*condBuilder=*/
623 [&](mlir::OpBuilder &b, mlir::Location loc) {
624 auto currentElement = b.create<cir::LoadOp>(loc, eltTy, tmpAddr);
625 mlir::Type boolTy = cir::BoolType::get(b.getContext());
626 auto cmp = builder.create<cir::CmpOp>(loc, boolTy, cir::CmpOpKind::ne,
627 currentElement, stop);
628 builder.createCondition(cmp);
629 },
630 /*bodyBuilder=*/
631 [&](mlir::OpBuilder &b, mlir::Location loc) {
632 auto currentElement = b.create<cir::LoadOp>(loc, eltTy, tmpAddr);
633
634 cir::CallOp ctorCall;
635 op->walk([&](cir::CallOp c) { ctorCall = c; });
636 assert(ctorCall && "expected ctor call");
637
638 // Array elements get constructed in order but destructed in reverse.
639 mlir::Value stride;
640 if (isCtor)
641 stride = builder.getUnsignedInt(loc, 1, sizeTypeSize);
642 else
643 stride = builder.getSignedInt(loc, -1, sizeTypeSize);
644
645 ctorCall->moveBefore(stride.getDefiningOp());
646 ctorCall->setOperand(0, currentElement);
647 auto nextElement = cir::PtrStrideOp::create(builder, loc, eltTy,
648 currentElement, stride);
649
650 // Store the element pointer to the temporary variable
651 builder.createStore(loc, nextElement, tmpAddr);
652 builder.createYield(loc);
653 });
654
655 op->replaceAllUsesWith(loop);
656 op->erase();
657}
658
659void LoweringPreparePass::lowerArrayDtor(cir::ArrayDtor op) {
660 CIRBaseBuilderTy builder(getContext());
661 builder.setInsertionPointAfter(op.getOperation());
662
663 mlir::Type eltTy = op->getRegion(0).getArgument(0).getType();
665 auto arrayLen =
666 mlir::cast<cir::ArrayType>(op.getAddr().getType().getPointee()).getSize();
667 lowerArrayDtorCtorIntoLoop(builder, astCtx, op, eltTy, op.getAddr(), arrayLen,
668 false);
669}
670
671void LoweringPreparePass::lowerArrayCtor(cir::ArrayCtor op) {
672 cir::CIRBaseBuilderTy builder(getContext());
673 builder.setInsertionPointAfter(op.getOperation());
674
675 mlir::Type eltTy = op->getRegion(0).getArgument(0).getType();
677 auto arrayLen =
678 mlir::cast<cir::ArrayType>(op.getAddr().getType().getPointee()).getSize();
679 lowerArrayDtorCtorIntoLoop(builder, astCtx, op, eltTy, op.getAddr(), arrayLen,
680 true);
681}
682
683void LoweringPreparePass::runOnOp(mlir::Operation *op) {
684 if (auto arrayCtor = dyn_cast<ArrayCtor>(op))
685 lowerArrayCtor(arrayCtor);
686 else if (auto arrayDtor = dyn_cast<cir::ArrayDtor>(op))
687 lowerArrayDtor(arrayDtor);
688 else if (auto cast = mlir::dyn_cast<cir::CastOp>(op))
689 lowerCastOp(cast);
690 else if (auto complexDiv = mlir::dyn_cast<cir::ComplexDivOp>(op))
691 lowerComplexDivOp(complexDiv);
692 else if (auto complexMul = mlir::dyn_cast<cir::ComplexMulOp>(op))
693 lowerComplexMulOp(complexMul);
694 else if (auto unary = mlir::dyn_cast<cir::UnaryOp>(op))
695 lowerUnaryOp(unary);
696}
697
698void LoweringPreparePass::runOnOperation() {
699 mlir::Operation *op = getOperation();
700 if (isa<::mlir::ModuleOp>(op))
701 mlirModule = cast<::mlir::ModuleOp>(op);
702
704
705 op->walk([&](mlir::Operation *op) {
706 if (mlir::isa<cir::ArrayCtor, cir::ArrayDtor, cir::CastOp,
707 cir::ComplexMulOp, cir::ComplexDivOp, cir::UnaryOp>(op))
708 opsToTransform.push_back(op);
709 });
710
711 for (mlir::Operation *o : opsToTransform)
712 runOnOp(o);
713}
714
715std::unique_ptr<Pass> mlir::createLoweringPreparePass() {
716 return std::make_unique<LoweringPreparePass>();
717}
718
719std::unique_ptr<Pass>
721 auto pass = std::make_unique<LoweringPreparePass>();
722 pass->setASTContext(astCtx);
723 return std::move(pass);
724}
Defines the clang::ASTContext interface.
static mlir::Value buildRangeReductionComplexDiv(CIRBaseBuilderTy &builder, mlir::Location loc, mlir::Value lhsReal, mlir::Value lhsImag, mlir::Value rhsReal, mlir::Value rhsImag)
static void lowerArrayDtorCtorIntoLoop(cir::CIRBaseBuilderTy &builder, clang::ASTContext *astCtx, mlir::Operation *op, mlir::Type eltTy, mlir::Value arrayAddr, uint64_t arrayLen, bool isCtor)
static llvm::StringRef getComplexDivLibCallName(llvm::APFloat::Semantics semantics)
static llvm::StringRef getComplexMulLibCallName(llvm::APFloat::Semantics semantics)
static mlir::Value buildComplexBinOpLibCall(LoweringPreparePass &pass, CIRBaseBuilderTy &builder, llvm::StringRef(*libFuncNameGetter)(llvm::APFloat::Semantics), mlir::Location loc, cir::ComplexType ty, mlir::Value lhsReal, mlir::Value lhsImag, mlir::Value rhsReal, mlir::Value rhsImag)
static mlir::Value lowerComplexMul(LoweringPreparePass &pass, CIRBaseBuilderTy &builder, mlir::Location loc, cir::ComplexMulOp op, mlir::Value lhsReal, mlir::Value lhsImag, mlir::Value rhsReal, mlir::Value rhsImag)
static mlir::Value lowerComplexToComplexCast(mlir::MLIRContext &ctx, cir::CastOp op, cir::CastKind scalarCastKind)
static mlir::Value lowerComplexToScalarCast(mlir::MLIRContext &ctx, cir::CastOp op, cir::CastKind elemToBoolKind)
static mlir::Value buildAlgebraicComplexDiv(CIRBaseBuilderTy &builder, mlir::Location loc, mlir::Value lhsReal, mlir::Value lhsImag, mlir::Value rhsReal, mlir::Value rhsImag)
static mlir::Type higherPrecisionElementTypeForComplexArithmetic(mlir::MLIRContext &context, clang::ASTContext &cc, CIRBaseBuilderTy &builder, mlir::Type elementType)
static mlir::Value lowerScalarToComplexCast(mlir::MLIRContext &ctx, cir::CastOp op)
static mlir::Value lowerComplexDiv(LoweringPreparePass &pass, CIRBaseBuilderTy &builder, mlir::Location loc, cir::ComplexDivOp op, mlir::Value lhsReal, mlir::Value lhsImag, mlir::Value rhsReal, mlir::Value rhsImag, mlir::MLIRContext &mlirCx, clang::ASTContext &cc)
__device__ __2f16 b
__device__ __2f16 float c
mlir::Value createLogicalOr(mlir::Location loc, mlir::Value lhs, mlir::Value rhs)
cir::ConditionOp createCondition(mlir::Value condition)
Create a loop condition.
cir::ConstantOp getNullValue(mlir::Type ty, mlir::Location loc)
mlir::Value createCast(mlir::Location loc, cir::CastKind kind, mlir::Value src, mlir::Type newTy)
cir::PointerType getPointerTo(mlir::Type ty)
mlir::Value createComplexImag(mlir::Location loc, mlir::Value operand)
cir::DoWhileOp createDoWhile(mlir::Location loc, llvm::function_ref< void(mlir::OpBuilder &, mlir::Location)> condBuilder, llvm::function_ref< void(mlir::OpBuilder &, mlir::Location)> bodyBuilder)
Create a do-while operation.
cir::CallOp createCallOp(mlir::Location loc, mlir::SymbolRefAttr callee, mlir::Type returnType, mlir::ValueRange operands, llvm::ArrayRef< mlir::NamedAttribute > attrs={})
mlir::Value getSignedInt(mlir::Location loc, int64_t val, unsigned numBits)
cir::StoreOp createStore(mlir::Location loc, mlir::Value val, mlir::Value dst, bool isVolatile=false, mlir::IntegerAttr align={}, cir::MemOrderAttr order={})
cir::CmpOp createCompare(mlir::Location loc, cir::CmpOpKind kind, mlir::Value lhs, mlir::Value rhs)
mlir::IntegerAttr getAlignmentAttr(clang::CharUnits alignment)
mlir::Value createBinop(mlir::Location loc, mlir::Value lhs, cir::BinOpKind kind, mlir::Value rhs)
mlir::Value createAlloca(mlir::Location loc, cir::PointerType addrType, mlir::Type type, llvm::StringRef name, mlir::IntegerAttr alignment)
mlir::Value createComplexCreate(mlir::Location loc, mlir::Value real, mlir::Value imag)
mlir::Value createIsNaN(mlir::Location loc, mlir::Value operand)
cir::YieldOp createYield(mlir::Location loc, mlir::ValueRange value={})
Create a yield operation.
mlir::Value createLogicalAnd(mlir::Location loc, mlir::Value lhs, mlir::Value rhs)
mlir::Value createUnaryOp(mlir::Location loc, cir::UnaryOpKind kind, mlir::Value operand)
cir::BoolType getBoolTy()
mlir::Value getUnsignedInt(mlir::Location loc, uint64_t val, unsigned numBits)
mlir::Value createComplexReal(mlir::Location loc, mlir::Value operand)
Holds long-lived AST nodes (such as types and decls) that can be referred to throughout the semantic ...
Definition: ASTContext.h:188
const LangOptions & getLangOpts() const
Definition: ASTContext.h:894
uint64_t getTypeSize(QualType T) const
Return the size of the specified (complete) type T, in bits.
Definition: ASTContext.h:2625
const TargetInfo & getTargetInfo() const
Definition: ASTContext.h:859
QualType getSignedSizeType() const
Return the unique signed counterpart of the integer type corresponding to size_t.
Exposes information about the current target.
Definition: TargetInfo.h:226
const llvm::fltSemantics & getDoubleFormat() const
Definition: TargetInfo.h:798
const llvm::fltSemantics & getHalfFormat() const
Definition: TargetInfo.h:783
const llvm::fltSemantics & getBFloat16Format() const
Definition: TargetInfo.h:793
const llvm::fltSemantics & getLongDoubleFormat() const
Definition: TargetInfo.h:804
const llvm::fltSemantics & getFloatFormat() const
Definition: TargetInfo.h:788
const llvm::fltSemantics & getFloat128Format() const
Definition: TargetInfo.h:812
Defines the clang::TargetInfo interface.
Definition: ABIArgInfo.h:22
const internal::VariadicAllOfMatcher< Type > type
Matches Types in the clang AST.
RangeSelector name(std::string ID)
Given a node with a "name", (like NamedDecl, DeclRefExpr, CxxCtorInitializer, and TypeLoc) selects th...
std::unique_ptr< Pass > createLoweringPreparePass()
static bool opFuncExtraAttrs()
static bool fastMathFlags()