clang 23.0.0git
LoweringPrepare.cpp
Go to the documentation of this file.
//===- LoweringPrepare.cpp - preparation work for LLVM lowering -----------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "PassDetail.h"
10#include "mlir/IR/Attributes.h"
11#include "mlir/IR/BuiltinAttributeInterfaces.h"
12#include "mlir/IR/IRMapping.h"
13#include "mlir/IR/Location.h"
14#include "mlir/IR/Value.h"
16#include "clang/AST/Mangle.h"
17#include "clang/Basic/Cuda.h"
18#include "clang/Basic/Module.h"
32#include "llvm/ADT/StringRef.h"
33#include "llvm/ADT/TypeSwitch.h"
34#include "llvm/IR/Instructions.h"
35#include "llvm/Support/ErrorHandling.h"
36#include "llvm/Support/MemoryBuffer.h"
37#include "llvm/Support/Path.h"
38#include "llvm/Support/VirtualFileSystem.h"
39
40#include <memory>
41#include <optional>
42
43using namespace mlir;
44using namespace cir;
45
46namespace mlir {
47#define GEN_PASS_DEF_LOWERINGPREPARE
48#include "clang/CIR/Dialect/Passes.h.inc"
49} // namespace mlir
50
51static SmallString<128> getTransformedFileName(mlir::ModuleOp mlirModule) {
52 SmallString<128> fileName;
53
54 if (mlirModule.getSymName())
55 fileName = llvm::sys::path::filename(mlirModule.getSymName()->str());
56
57 if (fileName.empty())
58 fileName = "<null>";
59
60 for (size_t i = 0; i < fileName.size(); ++i) {
61 // Replace everything that's not [a-zA-Z0-9._] with a _. This set happens
62 // to be the set of C preprocessing numbers.
63 if (!clang::isPreprocessingNumberBody(fileName[i]))
64 fileName[i] = '_';
65 }
66
67 return fileName;
68}
69
70/// Return the FuncOp called by `callOp`.
71static cir::FuncOp getCalledFunction(cir::CallOp callOp) {
72 mlir::SymbolRefAttr sym = llvm::dyn_cast_if_present<mlir::SymbolRefAttr>(
73 callOp.getCallableForCallee());
74 if (!sym)
75 return nullptr;
76 return dyn_cast_or_null<cir::FuncOp>(
77 mlir::SymbolTable::lookupNearestSymbolFrom(callOp, sym));
78}
79
namespace {
/// Pass that prepares CIR for LLVM lowering: it rewrites complex-number
/// arithmetic, cast ops, global ctor/dtor regions, static-local guard
/// variables, and CUDA registration constructs into simpler CIR.
struct LoweringPreparePass
    : public impl::LoweringPrepareBase<LoweringPreparePass> {
  LoweringPreparePass() = default;
  void runOnOperation() override;

  /// Per-operation lowering helpers, dispatched from runOnOperation().
  void runOnOp(mlir::Operation *op);
  void lowerCastOp(cir::CastOp op);
  void lowerComplexDivOp(cir::ComplexDivOp op);
  void lowerComplexMulOp(cir::ComplexMulOp op);
  void lowerUnaryOp(cir::UnaryOpInterface op);
  void lowerGlobalOp(cir::GlobalOp op);
  void lowerThreeWayCmpOp(cir::CmpThreeWayOp op);
  void lowerArrayDtor(cir::ArrayDtor op);
  void lowerArrayCtor(cir::ArrayCtor op);
  void lowerTrivialCopyCall(cir::CallOp op);
  void lowerStoreOfConstAggregate(cir::StoreOp op);

  /// Build the function that initializes the specified global
  cir::FuncOp buildCXXGlobalVarDeclInitFunc(cir::GlobalOp op);

  /// Handle the dtor region by registering destructor with __cxa_atexit
  cir::FuncOp getOrCreateDtorFunc(CIRBaseBuilderTy &builder, cir::GlobalOp op,
                                  mlir::Region &dtorRegion,
                                  cir::CallOp &dtorCall);

  /// Build a module init function that calls all the dynamic initializers.
  void buildCXXGlobalInitFunc();

  /// Materialize global ctor/dtor list
  void buildGlobalCtorDtorList();

  /// Find or declare a runtime function `name` of type `type` at module
  /// scope. Newly created declarations get private symbol visibility.
  cir::FuncOp buildRuntimeFunction(
      mlir::OpBuilder &builder, llvm::StringRef name, mlir::Location loc,
      cir::FuncType type,
      cir::GlobalLinkageKind linkage = cir::GlobalLinkageKind::ExternalLinkage);

  /// Find or declare a runtime global variable `name` of type `type` at
  /// module scope. Newly created globals get private symbol visibility.
  cir::GlobalOp buildRuntimeVariable(
      mlir::OpBuilder &builder, llvm::StringRef name, mlir::Location loc,
      mlir::Type type,
      cir::GlobalLinkageKind linkage = cir::GlobalLinkageKind::ExternalLinkage,
      cir::VisibilityKind visibility = cir::VisibilityKind::Default);

  /// ------------
  /// CUDA registration related
  /// ------------

  // Maps kernel names to their device-stub FuncOps for registration.
  llvm::StringMap<FuncOp> cudaKernelMap;

  /// Build the CUDA module constructor that registers the fat binary
  /// with the CUDA runtime.
  void buildCUDAModuleCtor();
  std::optional<FuncOp> buildCUDAModuleDtor();
  std::optional<FuncOp> buildCUDARegisterGlobals();
  void buildCUDARegisterGlobalFunctions(cir::CIRBaseBuilderTy &builder,
                                        FuncOp regGlobalFunc);

  /// Handle static local variable initialization with guard variables.
  void handleStaticLocal(cir::GlobalOp globalOp, cir::GetGlobalOp getGlobalOp);

  /// Get or create __cxa_guard_acquire function.
  cir::FuncOp getGuardAcquireFn(cir::PointerType guardPtrTy);

  /// Get or create __cxa_guard_release function.
  cir::FuncOp getGuardReleaseFn(cir::PointerType guardPtrTy);

  /// Create a guard global variable for a static local.
  cir::GlobalOp createGuardGlobalOp(CIRBaseBuilderTy &builder,
                                    mlir::Location loc, llvm::StringRef name,
                                    cir::IntType guardTy,
                                    cir::GlobalLinkageKind linkage);

  /// Get the guard variable for a static local declaration.
  /// Returns a null GlobalOp when none has been created yet.
  cir::GlobalOp getStaticLocalDeclGuardAddress(llvm::StringRef globalSymName) {
    auto it = staticLocalDeclGuardMap.find(globalSymName);
    if (it != staticLocalDeclGuardMap.end())
      return it->second;
    return nullptr;
  }

  /// Set the guard variable for a static local declaration.
  void setStaticLocalDeclGuardAddress(llvm::StringRef globalSymName,
                                      cir::GlobalOp guard) {
    staticLocalDeclGuardMap[globalSymName] = guard;
  }

  /// Get or create the guard variable for a static local declaration.
  cir::GlobalOp getOrCreateStaticLocalDeclGuardAddress(
      CIRBaseBuilderTy &builder, cir::GlobalOp globalOp,
      cir::ASTVarDeclInterface varDecl, cir::IntType guardTy,
      clang::CharUnits guardAlignment) {
    llvm::StringRef globalSymName = globalOp.getSymName();
    cir::GlobalOp guard = getStaticLocalDeclGuardAddress(globalSymName);
    if (!guard) {
      // Get the guard name from the static_local attribute.
      llvm::StringRef guardName =
          globalOp.getStaticLocalGuard()->getName().getValue();

      // Create the guard variable with a zero-initializer.
      // Linkage, dso_local and alignment mirror the guarded global.
      guard = createGuardGlobalOp(builder, globalOp->getLoc(), guardName,
                                  guardTy, globalOp.getLinkage());
      guard.setInitialValueAttr(cir::IntAttr::get(guardTy, 0));
      guard.setDSOLocal(globalOp.getDsoLocal());
      guard.setAlignment(guardAlignment.getAsAlign().value());

      // The ABI says: "It is suggested that it be emitted in the same COMDAT
      // group as the associated data object." In practice, this doesn't work
      // for non-ELF and non-Wasm object formats, so only do it for ELF and
      // Wasm.
      bool hasComdat = globalOp.getComdat();
      const llvm::Triple &triple = astCtx->getTargetInfo().getTriple();
      if (!varDecl.isLocalVarDecl() && hasComdat &&
          (triple.isOSBinFormatELF() || triple.isOSBinFormatWasm())) {
        globalOp->emitError("NYI: guard COMDAT for non-local variables");
        return {};
      } else if (hasComdat && globalOp.isWeakForLinker()) {
        globalOp->emitError("NYI: guard COMDAT for weak linkage");
        return {};
      }

      setStaticLocalDeclGuardAddress(globalSymName, guard);
    }
    return guard;
  }

  ///
  /// AST related
  /// -----------

  clang::ASTContext *astCtx;

  /// Tracks current module.
  mlir::ModuleOp mlirModule;

  /// Tracks existing dynamic initializers.
  llvm::StringMap<uint32_t> dynamicInitializerNames;
  llvm::SmallVector<cir::FuncOp> dynamicInitializers;

  /// Tracks guard variables for static locals (keyed by global symbol name).
  llvm::StringMap<cir::GlobalOp> staticLocalDeclGuardMap;

  /// List of ctors and their priorities to be called before main()
  llvm::SmallVector<std::pair<std::string, uint32_t>, 4> globalCtorList;
  /// List of dtors and their priorities to be called when unloading module.
  llvm::SmallVector<std::pair<std::string, uint32_t>, 4> globalDtorList;

  /// Returns true if the target uses ARM-style guard variables for static
  /// local initialization (32-bit guard, check bit 0 only).
  bool useARMGuardVarABI() const {
    switch (astCtx->getCXXABIKind()) {
    case clang::TargetCXXABI::GenericARM:
    case clang::TargetCXXABI::iOS:
    case clang::TargetCXXABI::WatchOS:
    case clang::TargetCXXABI::GenericAArch64:
    case clang::TargetCXXABI::WebAssembly:
      return true;
    default:
      return false;
    }
  }

  /// Emit the guarded initialization for a static local variable.
  /// This handles the if/else structure after the guard byte check,
  /// following OG's ItaniumCXXABI::EmitGuardedInit skeleton.
  void emitCXXGuardedInitIf(CIRBaseBuilderTy &builder, cir::GlobalOp globalOp,
                            cir::ASTVarDeclInterface varDecl,
                            mlir::Value guardPtr, cir::PointerType guardPtrTy,
                            bool threadsafe) {
    auto loc = globalOp->getLoc();

    // The semantics of dynamic initialization of variables with static or
    // thread storage duration depends on whether they are declared at
    // block-scope. The initialization of such variables at block-scope can be
    // aborted with an exception and later retried (per C++20 [stmt.dcl]p4),
    // and recursive entry to their initialization has undefined behavior (also
    // per C++20 [stmt.dcl]p4). For such variables declared at non-block scope,
    // exceptions lead to termination (per C++20 [except.terminate]p1), and
    // recursive references to the variables are governed only by the lifetime
    // rules (per C++20 [class.cdtor]p2), which means such references are
    // perfectly fine as long as they avoid touching memory. As a result,
    // block-scope variables must not be marked as initialized until after
    // initialization completes (unless the mark is reverted following an
    // exception), but non-block-scope variables must be marked prior to
    // initialization so that recursive accesses during initialization do not
    // restart initialization.

    // Variables used when coping with thread-safe statics and exceptions.
    if (threadsafe) {
      // Call __cxa_guard_acquire.
      cir::CallOp acquireCall = builder.createCallOp(
          loc, getGuardAcquireFn(guardPtrTy), mlir::ValueRange{guardPtr});
      mlir::Value acquireResult = acquireCall.getResult();

      // __cxa_guard_acquire returns nonzero when this thread must run the
      // initializer.
      auto acquireZero = builder.getConstantInt(
          loc, mlir::cast<cir::IntType>(acquireResult.getType()), 0);
      auto shouldInit = builder.createCompare(loc, cir::CmpOpKind::ne,
                                              acquireResult, acquireZero);

      // Create the IfOp for the shouldInit check.
      // Pass an empty callback to avoid auto-creating a yield terminator.
      auto ifOp =
          cir::IfOp::create(builder, loc, shouldInit, /*withElseRegion=*/false,
                            [](mlir::OpBuilder &, mlir::Location) {});
      mlir::OpBuilder::InsertionGuard insertGuard(builder);
      builder.setInsertionPointToStart(&ifOp.getThenRegion().front());

      // Call __cxa_guard_abort along the exceptional edge.
      // OG: CGF.EHStack.pushCleanup<CallGuardAbort>(EHCleanup, guard);

      // Emit the initializer and add a global destructor if appropriate.
      // The ctor region's single block is spliced into the "then" block,
      // excluding its terminator (std::prev(block.end())); the emptied
      // region is then discarded.
      auto &ctorRegion = globalOp.getCtorRegion();
      assert(!ctorRegion.empty() && "This should never be empty here.");
      if (!ctorRegion.hasOneBlock())
        llvm_unreachable("Multiple blocks NYI");
      mlir::Block &block = ctorRegion.front();
      mlir::Block *insertBlock = builder.getInsertionBlock();
      insertBlock->getOperations().splice(insertBlock->end(),
                                          block.getOperations(), block.begin(),
                                          std::prev(block.end()));
      builder.setInsertionPointToEnd(insertBlock);
      ctorRegion.getBlocks().clear();

      // Pop the guard-abort cleanup if we pushed one.
      // OG: CGF.PopCleanupBlock();

      // Call __cxa_guard_release. This cannot throw.
      builder.createCallOp(loc, getGuardReleaseFn(guardPtrTy),
                           mlir::ValueRange{guardPtr});

      builder.createYield(loc);
    } else if (!varDecl.isLocalVarDecl()) {
      // For non-local variables, store 1 into the first byte of the guard
      // variable before the object initialization begins so that references
      // to the variable during initialization don't restart initialization.
      // OG: Builder.CreateStore(llvm::ConstantInt::get(CGM.Int8Ty, 1), ...);
      // Then: CGF.EmitCXXGlobalVarDeclInit(D, var, shouldPerformInit);
      globalOp->emitError("NYI: non-threadsafe init for non-local variables");
      return;
    } else {
      // For local variables, store 1 into the first byte of the guard variable
      // after the object initialization completes so that initialization is
      // retried if initialization is interrupted by an exception.
      globalOp->emitError("NYI: non-threadsafe init for local variables");
      return;
    }

    builder.createYield(loc); // Outermost IfOp
  }

  void setASTContext(clang::ASTContext *c) { astCtx = c; }
};

} // namespace
335
336cir::GlobalOp LoweringPreparePass::buildRuntimeVariable(
337 mlir::OpBuilder &builder, llvm::StringRef name, mlir::Location loc,
338 mlir::Type type, cir::GlobalLinkageKind linkage,
339 cir::VisibilityKind visibility) {
340 cir::GlobalOp g = dyn_cast_or_null<cir::GlobalOp>(
341 mlir::SymbolTable::lookupNearestSymbolFrom(
342 mlirModule, mlir::StringAttr::get(mlirModule->getContext(), name)));
343 if (!g) {
344 g = cir::GlobalOp::create(builder, loc, name, type);
345 g.setLinkageAttr(
346 cir::GlobalLinkageKindAttr::get(builder.getContext(), linkage));
347 mlir::SymbolTable::setSymbolVisibility(
348 g, mlir::SymbolTable::Visibility::Private);
349 g.setGlobalVisibility(visibility);
350 }
351 return g;
352}
353
/// Find or declare the runtime function `name` of type `type`. A newly
/// created declaration gets the requested linkage and private symbol
/// visibility; an existing symbol is returned as-is (null if it is not a
/// FuncOp).
cir::FuncOp LoweringPreparePass::buildRuntimeFunction(
    mlir::OpBuilder &builder, llvm::StringRef name, mlir::Location loc,
    cir::FuncType type, cir::GlobalLinkageKind linkage) {
  cir::FuncOp f = dyn_cast_or_null<FuncOp>(SymbolTable::lookupNearestSymbolFrom(
      mlirModule, StringAttr::get(mlirModule->getContext(), name)));
  if (!f) {
    f = cir::FuncOp::create(builder, loc, name, type);
    f.setLinkageAttr(
        cir::GlobalLinkageKindAttr::get(builder.getContext(), linkage));
    mlir::SymbolTable::setSymbolVisibility(
        f, mlir::SymbolTable::Visibility::Private);
    // NOTE(review): one source line is missing from this excerpt here
    // (original line 366) — likely additional attribute setup; confirm
    // against upstream.
  }
  return f;
}
370
371static mlir::Value lowerScalarToComplexCast(mlir::MLIRContext &ctx,
372 cir::CastOp op) {
373 cir::CIRBaseBuilderTy builder(ctx);
374 builder.setInsertionPoint(op);
375
376 mlir::Value src = op.getSrc();
377 mlir::Value imag = builder.getNullValue(src.getType(), op.getLoc());
378 return builder.createComplexCreate(op.getLoc(), src, imag);
379}
380
381static mlir::Value lowerComplexToScalarCast(mlir::MLIRContext &ctx,
382 cir::CastOp op,
383 cir::CastKind elemToBoolKind) {
384 cir::CIRBaseBuilderTy builder(ctx);
385 builder.setInsertionPoint(op);
386
387 mlir::Value src = op.getSrc();
388 if (!mlir::isa<cir::BoolType>(op.getType()))
389 return builder.createComplexReal(op.getLoc(), src);
390
391 // Complex cast to bool: (bool)(a+bi) => (bool)a || (bool)b
392 mlir::Value srcReal = builder.createComplexReal(op.getLoc(), src);
393 mlir::Value srcImag = builder.createComplexImag(op.getLoc(), src);
394
395 cir::BoolType boolTy = builder.getBoolTy();
396 mlir::Value srcRealToBool =
397 builder.createCast(op.getLoc(), elemToBoolKind, srcReal, boolTy);
398 mlir::Value srcImagToBool =
399 builder.createCast(op.getLoc(), elemToBoolKind, srcImag, boolTy);
400 return builder.createLogicalOr(op.getLoc(), srcRealToBool, srcImagToBool);
401}
402
403static mlir::Value lowerComplexToComplexCast(mlir::MLIRContext &ctx,
404 cir::CastOp op,
405 cir::CastKind scalarCastKind) {
406 CIRBaseBuilderTy builder(ctx);
407 builder.setInsertionPoint(op);
408
409 mlir::Value src = op.getSrc();
410 auto dstComplexElemTy =
411 mlir::cast<cir::ComplexType>(op.getType()).getElementType();
412
413 mlir::Value srcReal = builder.createComplexReal(op.getLoc(), src);
414 mlir::Value srcImag = builder.createComplexImag(op.getLoc(), src);
415
416 mlir::Value dstReal = builder.createCast(op.getLoc(), scalarCastKind, srcReal,
417 dstComplexElemTy);
418 mlir::Value dstImag = builder.createCast(op.getLoc(), scalarCastKind, srcImag,
419 dstComplexElemTy);
420 return builder.createComplexCreate(op.getLoc(), dstReal, dstImag);
421}
422
423void LoweringPreparePass::lowerCastOp(cir::CastOp op) {
424 mlir::MLIRContext &ctx = getContext();
425 mlir::Value loweredValue = [&]() -> mlir::Value {
426 switch (op.getKind()) {
427 case cir::CastKind::float_to_complex:
428 case cir::CastKind::int_to_complex:
429 return lowerScalarToComplexCast(ctx, op);
430 case cir::CastKind::float_complex_to_real:
431 case cir::CastKind::int_complex_to_real:
432 return lowerComplexToScalarCast(ctx, op, op.getKind());
433 case cir::CastKind::float_complex_to_bool:
434 return lowerComplexToScalarCast(ctx, op, cir::CastKind::float_to_bool);
435 case cir::CastKind::int_complex_to_bool:
436 return lowerComplexToScalarCast(ctx, op, cir::CastKind::int_to_bool);
437 case cir::CastKind::float_complex:
438 return lowerComplexToComplexCast(ctx, op, cir::CastKind::floating);
439 case cir::CastKind::float_complex_to_int_complex:
440 return lowerComplexToComplexCast(ctx, op, cir::CastKind::float_to_int);
441 case cir::CastKind::int_complex:
442 return lowerComplexToComplexCast(ctx, op, cir::CastKind::integral);
443 case cir::CastKind::int_complex_to_float_complex:
444 return lowerComplexToComplexCast(ctx, op, cir::CastKind::int_to_float);
445 default:
446 return nullptr;
447 }
448 }();
449
450 if (loweredValue) {
451 op.replaceAllUsesWith(loweredValue);
452 op.erase();
453 }
454}
455
456static mlir::Value buildComplexBinOpLibCall(
457 LoweringPreparePass &pass, CIRBaseBuilderTy &builder,
458 llvm::StringRef (*libFuncNameGetter)(llvm::APFloat::Semantics),
459 mlir::Location loc, cir::ComplexType ty, mlir::Value lhsReal,
460 mlir::Value lhsImag, mlir::Value rhsReal, mlir::Value rhsImag) {
461 cir::FPTypeInterface elementTy =
462 mlir::cast<cir::FPTypeInterface>(ty.getElementType());
463
464 llvm::StringRef libFuncName = libFuncNameGetter(
465 llvm::APFloat::SemanticsToEnum(elementTy.getFloatSemantics()));
466 llvm::SmallVector<mlir::Type, 4> libFuncInputTypes(4, elementTy);
467
468 cir::FuncType libFuncTy = cir::FuncType::get(libFuncInputTypes, ty);
469
470 // Insert a declaration for the runtime function to be used in Complex
471 // multiplication and division when needed
472 cir::FuncOp libFunc;
473 {
474 mlir::OpBuilder::InsertionGuard ipGuard{builder};
475 builder.setInsertionPointToStart(pass.mlirModule.getBody());
476 libFunc = pass.buildRuntimeFunction(builder, libFuncName, loc, libFuncTy);
477 }
478
479 cir::CallOp call =
480 builder.createCallOp(loc, libFunc, {lhsReal, lhsImag, rhsReal, rhsImag});
481 return call.getResult();
482}
483
484static llvm::StringRef
485getComplexDivLibCallName(llvm::APFloat::Semantics semantics) {
486 switch (semantics) {
487 case llvm::APFloat::S_IEEEhalf:
488 return "__divhc3";
489 case llvm::APFloat::S_IEEEsingle:
490 return "__divsc3";
491 case llvm::APFloat::S_IEEEdouble:
492 return "__divdc3";
493 case llvm::APFloat::S_PPCDoubleDouble:
494 return "__divtc3";
495 case llvm::APFloat::S_x87DoubleExtended:
496 return "__divxc3";
497 case llvm::APFloat::S_IEEEquad:
498 return "__divtc3";
499 default:
500 llvm_unreachable("unsupported floating point type");
501 }
502}
503
504static mlir::Value
505buildAlgebraicComplexDiv(CIRBaseBuilderTy &builder, mlir::Location loc,
506 mlir::Value lhsReal, mlir::Value lhsImag,
507 mlir::Value rhsReal, mlir::Value rhsImag) {
508 // (a+bi) / (c+di) = ((ac+bd)/(cc+dd)) + ((bc-ad)/(cc+dd))i
509 mlir::Value &a = lhsReal;
510 mlir::Value &b = lhsImag;
511 mlir::Value &c = rhsReal;
512 mlir::Value &d = rhsImag;
513
514 mlir::Value ac = builder.createMul(loc, a, c); // a*c
515 mlir::Value bd = builder.createMul(loc, b, d); // b*d
516 mlir::Value cc = builder.createMul(loc, c, c); // c*c
517 mlir::Value dd = builder.createMul(loc, d, d); // d*d
518 mlir::Value acbd = builder.createAdd(loc, ac, bd); // ac+bd
519 mlir::Value ccdd = builder.createAdd(loc, cc, dd); // cc+dd
520 mlir::Value resultReal = builder.createDiv(loc, acbd, ccdd);
521
522 mlir::Value bc = builder.createMul(loc, b, c); // b*c
523 mlir::Value ad = builder.createMul(loc, a, d); // a*d
524 mlir::Value bcad = builder.createSub(loc, bc, ad); // bc-ad
525 mlir::Value resultImag = builder.createDiv(loc, bcad, ccdd);
526 return builder.createComplexCreate(loc, resultReal, resultImag);
527}
528
/// Emit a range-reduced complex division (Smith's algorithm) that avoids
/// overflow in the intermediate products by scaling with r = d/c or c/d,
/// branching at runtime on which divisor component has larger magnitude.
static mlir::Value
// NOTE(review): the declarator line is missing from this excerpt (original
// line 530). Call sites use buildRangeReductionComplexDiv(builder, loc,
// lhsReal, lhsImag, rhsReal, rhsImag), so it is presumably
// `buildRangeReductionComplexDiv(CIRBaseBuilderTy &builder, mlir::Location
// loc,` — confirm against upstream.
                         mlir::Value lhsReal, mlir::Value lhsImag,
                         mlir::Value rhsReal, mlir::Value rhsImag) {
  // Implements Smith's algorithm for complex division.
  // SMITH, R. L. Algorithm 116: Complex division. Commun. ACM 5, 8 (1962).

  // Let:
  //   - lhs := a+bi
  //   - rhs := c+di
  //   - result := lhs / rhs = e+fi
  //
  // The algorithm pseudocode looks like follows:
  //   if fabs(c) >= fabs(d):
  //     r := d / c
  //     tmp := c + r*d
  //     e = (a + b*r) / tmp
  //     f = (b - a*r) / tmp
  //   else:
  //     r := c / d
  //     tmp := d + r*c
  //     e = (a*r + b) / tmp
  //     f = (b*r - a) / tmp

  mlir::Value &a = lhsReal;
  mlir::Value &b = lhsImag;
  mlir::Value &c = rhsReal;
  mlir::Value &d = rhsImag;

  // |c| >= |d| branch: scale by r = d/c.
  auto trueBranchBuilder = [&](mlir::OpBuilder &, mlir::Location) {
    mlir::Value r = builder.createDiv(loc, d, c);    // r := d / c
    mlir::Value rd = builder.createMul(loc, r, d);   // r*d
    mlir::Value tmp = builder.createAdd(loc, c, rd); // tmp := c + r*d

    mlir::Value br = builder.createMul(loc, b, r);   // b*r
    mlir::Value abr = builder.createAdd(loc, a, br); // a + b*r
    mlir::Value e = builder.createDiv(loc, abr, tmp);

    mlir::Value ar = builder.createMul(loc, a, r);   // a*r
    mlir::Value bar = builder.createSub(loc, b, ar); // b - a*r
    mlir::Value f = builder.createDiv(loc, bar, tmp);

    mlir::Value result = builder.createComplexCreate(loc, e, f);
    builder.createYield(loc, result);
  };

  // |c| < |d| branch: scale by r = c/d.
  auto falseBranchBuilder = [&](mlir::OpBuilder &, mlir::Location) {
    mlir::Value r = builder.createDiv(loc, c, d);    // r := c / d
    mlir::Value rc = builder.createMul(loc, r, c);   // r*c
    mlir::Value tmp = builder.createAdd(loc, d, rc); // tmp := d + r*c

    mlir::Value ar = builder.createMul(loc, a, r);   // a*r
    mlir::Value arb = builder.createAdd(loc, ar, b); // a*r + b
    mlir::Value e = builder.createDiv(loc, arb, tmp);

    mlir::Value br = builder.createMul(loc, b, r);   // b*r
    mlir::Value bra = builder.createSub(loc, br, a); // b*r - a
    mlir::Value f = builder.createDiv(loc, bra, tmp);

    mlir::Value result = builder.createComplexCreate(loc, e, f);
    builder.createYield(loc, result);
  };

  // The branch selection (|c| >= |d|) is a runtime ternary.
  auto cFabs = cir::FAbsOp::create(builder, loc, c);
  auto dFabs = cir::FAbsOp::create(builder, loc, d);
  cir::CmpOp cmpResult =
      builder.createCompare(loc, cir::CmpOpKind::ge, cFabs, dFabs);
  auto ternary = cir::TernaryOp::create(builder, loc, cmpResult,
                                        trueBranchBuilder, falseBranchBuilder);

  return ternary.getResult();
}
601
/// Pick a wider floating-point type for promoted complex arithmetic, or a
/// null type when no promotion can hold the intermediate products without
/// overflow (checked via the types' maximum exponents).
// NOTE(review): the declarator line is missing from this excerpt (original
// line 602); presumably `static mlir::Type
// higherPrecisionElementTypeForComplexArithmetic(` — confirm upstream.
    mlir::MLIRContext &context, clang::ASTContext &cc,
    CIRBaseBuilderTy &builder, mlir::Type elementType) {

  // Promotion ladder: half -> single; single/bfloat16 -> double;
  // double -> long double; anything else stays as-is.
  auto getHigherPrecisionFPType = [&context](mlir::Type type) -> mlir::Type {
    if (mlir::isa<cir::FP16Type>(type))
      return cir::SingleType::get(&context);

    if (mlir::isa<cir::SingleType>(type) || mlir::isa<cir::BF16Type>(type))
      return cir::DoubleType::get(&context);

    if (mlir::isa<cir::DoubleType>(type))
      return cir::LongDoubleType::get(&context, type);

    return type;
  };

  // Query the target's float semantics for a CIR float type.
  auto getFloatTypeSemantics =
      [&cc](mlir::Type type) -> const llvm::fltSemantics & {
    const clang::TargetInfo &info = cc.getTargetInfo();
    if (mlir::isa<cir::FP16Type>(type))
      return info.getHalfFormat();

    if (mlir::isa<cir::BF16Type>(type))
      return info.getBFloat16Format();

    if (mlir::isa<cir::SingleType>(type))
      return info.getFloatFormat();

    if (mlir::isa<cir::DoubleType>(type))
      return info.getDoubleFormat();

    if (mlir::isa<cir::LongDoubleType>(type)) {
      if (cc.getLangOpts().OpenMP && cc.getLangOpts().OpenMPIsTargetDevice)
        llvm_unreachable("NYI Float type semantics with OpenMP");
      return info.getLongDoubleFormat();
    }

    if (mlir::isa<cir::FP128Type>(type)) {
      if (cc.getLangOpts().OpenMP && cc.getLangOpts().OpenMPIsTargetDevice)
        llvm_unreachable("NYI Float type semantics with OpenMP");
      return info.getFloat128Format();
    }

    llvm_unreachable("Unsupported float type semantics");
  };

  const mlir::Type higherElementType = getHigherPrecisionFPType(elementType);
  const llvm::fltSemantics &elementTypeSemantics =
      getFloatTypeSemantics(elementType);
  const llvm::fltSemantics &higherElementTypeSemantics =
      getFloatTypeSemantics(higherElementType);

  // Check that the promoted type can handle the intermediate values without
  // overflowing. This can be interpreted as:
  // (SmallerType.LargestFiniteVal * SmallerType.LargestFiniteVal) * 2 <=
  // LargerType.LargestFiniteVal.
  // In terms of exponent it gives this formula:
  // (SmallerType.LargestFiniteVal * SmallerType.LargestFiniteVal
  // doubles the exponent of SmallerType.LargestFiniteVal)
  if (llvm::APFloat::semanticsMaxExponent(elementTypeSemantics) * 2 + 1 <=
      llvm::APFloat::semanticsMaxExponent(higherElementTypeSemantics)) {
    return higherElementType;
  }

  // The intermediate values can't be represented in the promoted type
  // without overflowing.
  return {};
}
671
/// Lower a cir.complex.div of the given decomposed operands, choosing the
/// strategy from the op's complex-range kind: Improved -> Smith's algorithm,
/// Full -> compiler-rt libcall, Promoted -> algebraic division in a wider
/// type (falling back to Smith's when no wider type exists), otherwise the
/// plain algebraic formula. Integer-element complex types always use the
/// algebraic formula.
static mlir::Value
lowerComplexDiv(LoweringPreparePass &pass, CIRBaseBuilderTy &builder,
                mlir::Location loc, cir::ComplexDivOp op, mlir::Value lhsReal,
                mlir::Value lhsImag, mlir::Value rhsReal, mlir::Value rhsImag,
                mlir::MLIRContext &mlirCx, clang::ASTContext &cc) {
  cir::ComplexType complexTy = op.getType();
  if (mlir::isa<cir::FPTypeInterface>(complexTy.getElementType())) {
    cir::ComplexRangeKind range = op.getRange();
    if (range == cir::ComplexRangeKind::Improved)
      return buildRangeReductionComplexDiv(builder, loc, lhsReal, lhsImag,
                                           rhsReal, rhsImag);

    if (range == cir::ComplexRangeKind::Full)
      // NOTE(review): the line beginning this call is missing from the
      // excerpt (original line 685); presumably `return
      // buildComplexBinOpLibCall(pass, builder, &getComplexDivLibCallName,`.
                                      loc, complexTy, lhsReal, lhsImag, rhsReal,
                                      rhsImag);

    if (range == cir::ComplexRangeKind::Promoted) {
      mlir::Type originalElementType = complexTy.getElementType();
      mlir::Type higherPrecisionElementType =
          // NOTE(review): the call line is missing from the excerpt (original
          // line 692); presumably
          // `higherPrecisionElementTypeForComplexArithmetic(mlirCx, cc,
          // builder,`.
              originalElementType);

      // No wider type can hold the intermediates: fall back to Smith's.
      if (!higherPrecisionElementType)
        return buildRangeReductionComplexDiv(builder, loc, lhsReal, lhsImag,
                                             rhsReal, rhsImag);

      // Promote all four components, divide algebraically in the wider
      // type, then truncate the result back.
      cir::CastKind floatingCastKind = cir::CastKind::floating;
      lhsReal = builder.createCast(floatingCastKind, lhsReal,
                                   higherPrecisionElementType);
      lhsImag = builder.createCast(floatingCastKind, lhsImag,
                                   higherPrecisionElementType);
      rhsReal = builder.createCast(floatingCastKind, rhsReal,
                                   higherPrecisionElementType);
      rhsImag = builder.createCast(floatingCastKind, rhsImag,
                                   higherPrecisionElementType);

      mlir::Value algebraicResult = buildAlgebraicComplexDiv(
          builder, loc, lhsReal, lhsImag, rhsReal, rhsImag);

      mlir::Value resultReal = builder.createComplexReal(loc, algebraicResult);
      mlir::Value resultImag = builder.createComplexImag(loc, algebraicResult);

      mlir::Value finalReal =
          builder.createCast(floatingCastKind, resultReal, originalElementType);
      mlir::Value finalImag =
          builder.createCast(floatingCastKind, resultImag, originalElementType);
      return builder.createComplexCreate(loc, finalReal, finalImag);
    }
  }

  return buildAlgebraicComplexDiv(builder, loc, lhsReal, lhsImag, rhsReal,
                                  rhsImag);
}
726
727void LoweringPreparePass::lowerComplexDivOp(cir::ComplexDivOp op) {
728 cir::CIRBaseBuilderTy builder(getContext());
729 builder.setInsertionPointAfter(op);
730 mlir::Location loc = op.getLoc();
731 mlir::TypedValue<cir::ComplexType> lhs = op.getLhs();
732 mlir::TypedValue<cir::ComplexType> rhs = op.getRhs();
733 mlir::Value lhsReal = builder.createComplexReal(loc, lhs);
734 mlir::Value lhsImag = builder.createComplexImag(loc, lhs);
735 mlir::Value rhsReal = builder.createComplexReal(loc, rhs);
736 mlir::Value rhsImag = builder.createComplexImag(loc, rhs);
737
738 mlir::Value loweredResult =
739 lowerComplexDiv(*this, builder, loc, op, lhsReal, lhsImag, rhsReal,
740 rhsImag, getContext(), *astCtx);
741 op.replaceAllUsesWith(loweredResult);
742 op.erase();
743}
744
745static llvm::StringRef
746getComplexMulLibCallName(llvm::APFloat::Semantics semantics) {
747 switch (semantics) {
748 case llvm::APFloat::S_IEEEhalf:
749 return "__mulhc3";
750 case llvm::APFloat::S_IEEEsingle:
751 return "__mulsc3";
752 case llvm::APFloat::S_IEEEdouble:
753 return "__muldc3";
754 case llvm::APFloat::S_PPCDoubleDouble:
755 return "__multc3";
756 case llvm::APFloat::S_x87DoubleExtended:
757 return "__mulxc3";
758 case llvm::APFloat::S_IEEEquad:
759 return "__multc3";
760 default:
761 llvm_unreachable("unsupported floating point type");
762 }
763}
764
/// Lower a cir.complex.mul of the given decomposed operands. Always emits
/// the algebraic product (ac-bd) + (ad+bc)i; for floating-point elements
/// under the Full range kind, additionally guards against NaN results with a
/// runtime fallback to the compiler-rt __mul*c3 helper.
static mlir::Value lowerComplexMul(LoweringPreparePass &pass,
                                   CIRBaseBuilderTy &builder,
                                   mlir::Location loc, cir::ComplexMulOp op,
                                   mlir::Value lhsReal, mlir::Value lhsImag,
                                   mlir::Value rhsReal, mlir::Value rhsImag) {
  // (a+bi) * (c+di) = (ac-bd) + (ad+bc)i
  mlir::Value resultRealLhs = builder.createMul(loc, lhsReal, rhsReal); // ac
  mlir::Value resultRealRhs = builder.createMul(loc, lhsImag, rhsImag); // bd
  mlir::Value resultImagLhs = builder.createMul(loc, lhsReal, rhsImag); // ad
  mlir::Value resultImagRhs = builder.createMul(loc, lhsImag, rhsReal); // bc
  mlir::Value resultReal = builder.createSub(loc, resultRealLhs, resultRealRhs);
  mlir::Value resultImag = builder.createAdd(loc, resultImagLhs, resultImagRhs);
  mlir::Value algebraicResult =
      builder.createComplexCreate(loc, resultReal, resultImag);

  // The algebraic result is final for integer elements and for the Basic,
  // Improved and Promoted range kinds; only Full needs the NaN guard below.
  cir::ComplexType complexTy = op.getType();
  cir::ComplexRangeKind rangeKind = op.getRange();
  if (mlir::isa<cir::IntType>(complexTy.getElementType()) ||
      rangeKind == cir::ComplexRangeKind::Basic ||
      rangeKind == cir::ComplexRangeKind::Improved ||
      rangeKind == cir::ComplexRangeKind::Promoted)
    return algebraicResult;

  // NOTE(review): one or two source lines are missing from this excerpt here
  // (original lines 787-788) — confirm against upstream.

  // Check whether the real part and the imaginary part of the result are both
  // NaN. If so, emit a library call to compute the multiplication instead.
  // We check a value against NaN by comparing the value against itself.
  mlir::Value resultRealIsNaN = builder.createIsNaN(loc, resultReal);
  mlir::Value resultImagIsNaN = builder.createIsNaN(loc, resultImag);
  mlir::Value resultRealAndImagAreNaN =
      builder.createLogicalAnd(loc, resultRealIsNaN, resultImagIsNaN);

  return cir::TernaryOp::create(
             builder, loc, resultRealAndImagAreNaN,
             [&](mlir::OpBuilder &, mlir::Location) {
               mlir::Value libCallResult = buildComplexBinOpLibCall(
                   pass, builder, &getComplexMulLibCallName, loc, complexTy,
                   lhsReal, lhsImag, rhsReal, rhsImag);
               builder.createYield(loc, libCallResult);
             },
             [&](mlir::OpBuilder &, mlir::Location) {
               builder.createYield(loc, algebraicResult);
             })
      .getResult();
}
811
812void LoweringPreparePass::lowerComplexMulOp(cir::ComplexMulOp op) {
813 cir::CIRBaseBuilderTy builder(getContext());
814 builder.setInsertionPointAfter(op);
815 mlir::Location loc = op.getLoc();
816 mlir::TypedValue<cir::ComplexType> lhs = op.getLhs();
817 mlir::TypedValue<cir::ComplexType> rhs = op.getRhs();
818 mlir::Value lhsReal = builder.createComplexReal(loc, lhs);
819 mlir::Value lhsImag = builder.createComplexImag(loc, lhs);
820 mlir::Value rhsReal = builder.createComplexReal(loc, rhs);
821 mlir::Value rhsImag = builder.createComplexImag(loc, rhs);
822 mlir::Value loweredResult = lowerComplexMul(*this, builder, loc, op, lhsReal,
823 lhsImag, rhsReal, rhsImag);
824 op.replaceAllUsesWith(loweredResult);
825 op.erase();
826}
827
828void LoweringPreparePass::lowerUnaryOp(cir::UnaryOpInterface op) {
829 if (!mlir::isa<cir::ComplexType>(op.getResult().getType()))
830 return;
831
832 mlir::Location loc = op->getLoc();
833 CIRBaseBuilderTy builder(getContext());
834 builder.setInsertionPointAfter(op);
835
836 mlir::Value operand = op.getInput();
837 mlir::Value operandReal = builder.createComplexReal(loc, operand);
838 mlir::Value operandImag = builder.createComplexImag(loc, operand);
839
840 mlir::Value resultReal = operandReal;
841 mlir::Value resultImag = operandImag;
842
843 llvm::TypeSwitch<mlir::Operation *>(op)
844 .Case<cir::IncOp>(
845 [&](auto) { resultReal = builder.createInc(loc, operandReal); })
846 .Case<cir::DecOp>(
847 [&](auto) { resultReal = builder.createDec(loc, operandReal); })
848 .Case<cir::MinusOp>([&](auto) {
849 resultReal = builder.createMinus(loc, operandReal);
850 resultImag = builder.createMinus(loc, operandImag);
851 })
852 .Case<cir::NotOp>(
853 [&](auto) { resultImag = builder.createMinus(loc, operandImag); })
854 .Default([](auto) { llvm_unreachable("unhandled unary complex op"); });
855
856 mlir::Value result = builder.createComplexCreate(loc, resultReal, resultImag);
857 op->replaceAllUsesWith(mlir::ValueRange{result});
858 op->erase();
859}
860
/// Return a destructor function suitable for registration with __cxa_atexit
/// for global `op`, based on the contents of `dtorRegion`.
///
/// In the simple case (get_global + single-operand call + yield) the function
/// already called in the region is returned directly. Otherwise the region
/// body is outlined into a new internal "__cxx_global_array_dtor" helper that
/// takes the object address as its only argument, and the region is rewritten
/// to call that helper. On return, `dtorCall` is set to the call op that
/// remains in the dtor region (either the original destructor call or the new
/// helper call).
cir::FuncOp LoweringPreparePass::getOrCreateDtorFunc(CIRBaseBuilderTy &builder,
                                                     cir::GlobalOp op,
                                                     mlir::Region &dtorRegion,
                                                     cir::CallOp &dtorCall) {
  // Restore the caller's insertion point when we're done.
  mlir::OpBuilder::InsertionGuard guard(builder);

  cir::VoidType voidTy = builder.getVoidTy();
  auto voidPtrTy = cir::PointerType::get(voidTy);

  // Look for operations in dtorBlock
  mlir::Block &dtorBlock = dtorRegion.front();

  // The first operation should be a get_global to retrieve the address
  // of the global variable we're destroying.
  auto opIt = dtorBlock.getOperations().begin();
  cir::GetGlobalOp ggop = mlir::cast<cir::GetGlobalOp>(*opIt);

  // The simple case is just a call to a destructor, like this:
  //
  //   %0 = cir.get_global %globalS : !cir.ptr<!rec_S>
  //   cir.call %_ZN1SD1Ev(%0) : (!cir.ptr<!rec_S>) -> ()
  //   (implicit cir.yield)
  //
  // That is, if the second operation is a call that takes the get_global result
  // as its only operand, and the only other operation is a yield, then we can
  // just return the called function.
  if (dtorBlock.getOperations().size() == 3) {
    auto callOp = mlir::dyn_cast<cir::CallOp>(&*(++opIt));
    auto yieldOp = mlir::dyn_cast<cir::YieldOp>(&*(++opIt));
    if (yieldOp && callOp && callOp.getNumOperands() == 1 &&
        callOp.getArgOperand(0) == ggop) {
      dtorCall = callOp;
      return getCalledFunction(callOp);
    }
  }

  // Otherwise, we need to create a helper function to replace the dtor region.
  // This name is kind of arbitrary, but it matches the name that classic
  // codegen uses, based on the expected case that gets us here.
  builder.setInsertionPointAfter(op);
  SmallString<256> fnName("__cxx_global_array_dtor");
  // Uniquify repeated helper names with a ".N" suffix.
  uint32_t cnt = dynamicInitializerNames[fnName]++;
  if (cnt)
    fnName += "." + std::to_string(cnt);

  // Create the helper function.
  auto fnType = cir::FuncType::get({voidPtrTy}, voidTy);
  cir::FuncOp dtorFunc =
      buildRuntimeFunction(builder, fnName, op.getLoc(), fnType,
                           cir::GlobalLinkageKind::InternalLinkage);

  // Mark the pointer argument noundef, matching classic codegen.
  SmallVector<mlir::NamedAttribute> paramAttrs;
  paramAttrs.push_back(
      builder.getNamedAttr("llvm.noundef", builder.getUnitAttr()));
  SmallVector<mlir::Attribute> argAttrDicts;
  argAttrDicts.push_back(
      mlir::DictionaryAttr::get(builder.getContext(), paramAttrs));
  dtorFunc.setArgAttrsAttr(
      mlir::ArrayAttr::get(builder.getContext(), argAttrDicts));

  mlir::Block *entryBB = dtorFunc.addEntryBlock();

  // Move everything from the dtor region into the helper function.
  entryBB->getOperations().splice(entryBB->begin(), dtorBlock.getOperations(),
                                  dtorBlock.begin(), dtorBlock.end());

  // Before erasing this, clone it back into the dtor region
  cir::GetGlobalOp dtorGGop =
      mlir::cast<cir::GetGlobalOp>(entryBB->getOperations().front());
  builder.setInsertionPointToStart(&dtorBlock);
  builder.clone(*dtorGGop.getOperation());

  // Replace all uses of the help function's get_global with the function
  // argument.
  mlir::Value dtorArg = entryBB->getArgument(0);
  dtorGGop.replaceAllUsesWith(dtorArg);
  dtorGGop.erase();

  // Replace the yield in the final block with a return
  mlir::Block &finalBlock = dtorFunc.getBody().back();
  auto yieldOp = cast<cir::YieldOp>(finalBlock.getTerminator());
  builder.setInsertionPoint(yieldOp);
  cir::ReturnOp::create(builder, yieldOp->getLoc());
  yieldOp->erase();

  // Create a call to the helper function, passing the original get_global op
  // as the argument.
  cir::GetGlobalOp origGGop =
      mlir::cast<cir::GetGlobalOp>(dtorBlock.getOperations().front());
  builder.setInsertionPointAfter(origGGop);
  mlir::Value ggopResult = origGGop.getResult();
  dtorCall = builder.createCallOp(op.getLoc(), dtorFunc, ggopResult);

  // Add a yield after the call.
  auto finalYield = cir::YieldOp::create(builder, op.getLoc());

  // Erase everything after the yield.
  dtorBlock.getOperations().erase(std::next(mlir::Block::iterator(finalYield)),
                                  dtorBlock.end());
  dtorRegion.getBlocks().erase(std::next(dtorRegion.begin()), dtorRegion.end());

  return dtorFunc;
}
966
/// Build a "__cxx_global_var_init[.N]" function that runs the dynamic
/// initialization (ctor region) of global `op` and, when a dtor region is
/// present, registers the destruction with __cxa_atexit. The region contents
/// are moved into the new function; the caller is expected to clear the
/// regions afterwards.
cir::FuncOp
LoweringPreparePass::buildCXXGlobalVarDeclInitFunc(cir::GlobalOp op) {
  // TODO(cir): Store this in the GlobalOp.
  // This should come from the MangleContext, but for now I'm hardcoding it.
  SmallString<256> fnName("__cxx_global_var_init");
  // Get a unique name
  uint32_t cnt = dynamicInitializerNames[fnName]++;
  if (cnt)
    fnName += "." + std::to_string(cnt);

  // Create a variable initialization function.
  CIRBaseBuilderTy builder(getContext());
  builder.setInsertionPointAfter(op);
  cir::VoidType voidTy = builder.getVoidTy();
  auto fnType = cir::FuncType::get({}, voidTy);
  FuncOp f = buildRuntimeFunction(builder, fnName, op.getLoc(), fnType,
                                  cir::GlobalLinkageKind::InternalLinkage);

  // Move over the initialization code of the ctor region.
  mlir::Block *entryBB = f.addEntryBlock();
  if (!op.getCtorRegion().empty()) {
    mlir::Block &block = op.getCtorRegion().front();
    // Move everything but the terminator; the yield is replaced below.
    entryBB->getOperations().splice(entryBB->begin(), block.getOperations(),
                                    block.begin(), std::prev(block.end()));
  }

  // Register the destructor call with __cxa_atexit
  mlir::Region &dtorRegion = op.getDtorRegion();
  if (!dtorRegion.empty()) {

    // Create a variable that binds the atexit to this shared object.
    builder.setInsertionPointToStart(&mlirModule.getBodyRegion().front());
    cir::GlobalOp handle = buildRuntimeVariable(
        builder, "__dso_handle", op.getLoc(), builder.getI8Type(),
        cir::GlobalLinkageKind::ExternalLinkage, cir::VisibilityKind::Hidden);

    // If this is a simple call to a destructor, get the called function.
    // Otherwise, create a helper function for the entire dtor region,
    // replacing the current dtor region body with a call to the helper
    // function.
    cir::CallOp dtorCall;
    cir::FuncOp dtorFunc =
        getOrCreateDtorFunc(builder, op, dtorRegion, dtorCall);

    // Create a runtime helper function:
    //   extern "C" int __cxa_atexit(void (*f)(void *), void *p, void *d);
    auto voidPtrTy = cir::PointerType::get(voidTy);
    auto voidFnTy = cir::FuncType::get({voidPtrTy}, voidTy);
    auto voidFnPtrTy = cir::PointerType::get(voidFnTy);
    auto handlePtrTy = cir::PointerType::get(handle.getSymType());
    auto fnAtExitType =
        cir::FuncType::get({voidFnPtrTy, voidPtrTy, handlePtrTy}, voidTy);
    const char *nameAtExit = "__cxa_atexit";
    cir::FuncOp fnAtExit =
        buildRuntimeFunction(builder, nameAtExit, op.getLoc(), fnAtExitType);

    // Replace the dtor (or helper) call with a call to
    // __cxa_atexit(&dtor, &var, &__dso_handle)
    builder.setInsertionPointAfter(dtorCall);
    mlir::Value args[3];
    auto dtorPtrTy = cir::PointerType::get(dtorFunc.getFunctionType());
    // &dtor: take the destructor's address and bitcast it to void(*)(void *).
    args[0] = cir::GetGlobalOp::create(builder, dtorCall.getLoc(), dtorPtrTy,
                                       dtorFunc.getSymName());
    args[0] = cir::CastOp::create(builder, dtorCall.getLoc(), voidFnPtrTy,
                                  cir::CastKind::bitcast, args[0]);
    // &var: the object address is the dtor call's only argument.
    args[1] =
        cir::CastOp::create(builder, dtorCall.getLoc(), voidPtrTy,
                            cir::CastKind::bitcast, dtorCall.getArgOperand(0));
    // &__dso_handle.
    args[2] = cir::GetGlobalOp::create(builder, handle.getLoc(), handlePtrTy,
                                       handle.getSymName());
    builder.createCallOp(dtorCall.getLoc(), fnAtExit, args);
    dtorCall->erase();
    // Move the remaining dtor-region ops into the init function, leaving the
    // terminator behind.
    mlir::Block &dtorBlock = dtorRegion.front();
    entryBB->getOperations().splice(entryBB->end(), dtorBlock.getOperations(),
                                    dtorBlock.begin(),
                                    std::prev(dtorBlock.end()));
  }

  // Replace cir.yield with cir.return
  builder.setInsertionPointToEnd(entryBB);
  mlir::Operation *yieldOp = nullptr;
  if (!op.getCtorRegion().empty()) {
    mlir::Block &block = op.getCtorRegion().front();
    yieldOp = &block.getOperations().back();
  } else {
    assert(!dtorRegion.empty());
    mlir::Block &block = dtorRegion.front();
    yieldOp = &block.getOperations().back();
  }

  assert(isa<cir::YieldOp>(*yieldOp));
  cir::ReturnOp::create(builder, yieldOp->getLoc());
  return f;
}
1064
1065cir::FuncOp
1066LoweringPreparePass::getGuardAcquireFn(cir::PointerType guardPtrTy) {
1067 // int __cxa_guard_acquire(__guard *guard_object);
1068 CIRBaseBuilderTy builder(getContext());
1069 mlir::OpBuilder::InsertionGuard ipGuard{builder};
1070 builder.setInsertionPointToStart(mlirModule.getBody());
1071 mlir::Location loc = mlirModule.getLoc();
1072 cir::IntType intTy = cir::IntType::get(&getContext(), 32, /*isSigned=*/true);
1073 auto fnType = cir::FuncType::get({guardPtrTy}, intTy);
1074 return buildRuntimeFunction(builder, "__cxa_guard_acquire", loc, fnType);
1075}
1076
1077cir::FuncOp
1078LoweringPreparePass::getGuardReleaseFn(cir::PointerType guardPtrTy) {
1079 // void __cxa_guard_release(__guard *guard_object);
1080 CIRBaseBuilderTy builder(getContext());
1081 mlir::OpBuilder::InsertionGuard ipGuard{builder};
1082 builder.setInsertionPointToStart(mlirModule.getBody());
1083 mlir::Location loc = mlirModule.getLoc();
1084 cir::VoidType voidTy = cir::VoidType::get(&getContext());
1085 auto fnType = cir::FuncType::get({guardPtrTy}, voidTy);
1086 return buildRuntimeFunction(builder, "__cxa_guard_release", loc, fnType);
1087}
1088
1089cir::GlobalOp LoweringPreparePass::createGuardGlobalOp(
1090 CIRBaseBuilderTy &builder, mlir::Location loc, llvm::StringRef name,
1091 cir::IntType guardTy, cir::GlobalLinkageKind linkage) {
1092 mlir::OpBuilder::InsertionGuard guard(builder);
1093 builder.setInsertionPointToStart(mlirModule.getBody());
1094 cir::GlobalOp g = cir::GlobalOp::create(builder, loc, name, guardTy);
1095 g.setLinkageAttr(
1096 cir::GlobalLinkageKindAttr::get(builder.getContext(), linkage));
1097 mlir::SymbolTable::setSymbolVisibility(
1098 g, mlir::SymbolTable::Visibility::Private);
1099 return g;
1100}
1101
/// Lower a reference (`getGlobalOp`) to a guarded static local `globalOp`
/// into the Itanium guarded-initialization sequence (guard-byte load and
/// compare, with the actual init emitted by emitCXXGuardedInitIf).
/// Unsupported configurations (ARM guard ABI, TLS, inline variables, int8
/// guards, no inline atomics) emit an NYI error and bail out.
void LoweringPreparePass::handleStaticLocal(cir::GlobalOp globalOp,
                                            cir::GetGlobalOp getGlobalOp) {
  CIRBaseBuilderTy builder(getContext());

  // The AST declaration drives the ABI decisions below.
  std::optional<cir::ASTVarDeclInterface> astOption = globalOp.getAst();
  assert(astOption.has_value());
  cir::ASTVarDeclInterface varDecl = astOption.value();

  builder.setInsertionPointAfter(getGlobalOp);
  mlir::Block *getGlobalOpBlock = builder.getInsertionBlock();

  // Remove the terminator temporarily - we'll add it back at the end.
  // NOTE(review): several early returns below exit without re-inserting the
  // detached terminator - confirm that is acceptable on the emitError paths.
  mlir::Operation *ret = getGlobalOpBlock->getTerminator();
  ret->remove();
  builder.setInsertionPointAfter(getGlobalOp);

  // Inline variables that weren't instantiated from variable templates have
  // partially-ordered initialization within their translation unit.
  bool nonTemplateInline =
      varDecl.isInline() &&
      !clang::isTemplateInstantiation(varDecl.getTemplateSpecializationKind());

  // Inline namespace-scope variables require guarded initialization in a
  // __cxx_global_var_init function. This is not yet implemented.
  if (nonTemplateInline) {
    globalOp->emitError(
        "NYI: guarded initialization for inline namespace-scope variables");
    return;
  }

  // We only need to use thread-safe statics for local non-TLS variables and
  // inline variables; other global initialization is always single-threaded
  // or (through lazy dynamic loading in multiple threads) unsequenced.
  bool threadsafe = astCtx->getLangOpts().ThreadsafeStatics &&
                    (varDecl.isLocalVarDecl() || nonTemplateInline) &&
                    !varDecl.getTLSKind();

  // TLS variables need special handling - the guard must also be thread-local.
  if (varDecl.getTLSKind()) {
    globalOp->emitError("NYI: guarded initialization for thread-local statics");
    return;
  }

  // If we have a global variable with internal linkage and thread-safe statics
  // are disabled, we can just let the guard variable be of type i8.
  bool useInt8GuardVariable = !threadsafe && globalOp.hasInternalLinkage();
  if (useInt8GuardVariable) {
    globalOp->emitError("NYI: int8 guard variables for non-threadsafe statics");
    return;
  }

  // Guard variables are 64 bits in the generic ABI and size width on ARM
  // (i.e. 32-bit on AArch32, 64-bit on AArch64).
  if (useARMGuardVarABI()) {
    globalOp->emitError("NYI: ARM-style guard variables for static locals");
    return;
  }
  cir::IntType guardTy =
      cir::IntType::get(&getContext(), 64, /*isSigned=*/true);
  cir::CIRDataLayout dataLayout(mlirModule);
  clang::CharUnits guardAlignment =
      clang::CharUnits::fromQuantity(dataLayout.getABITypeAlign(guardTy));
  auto guardPtrTy = cir::PointerType::get(guardTy);

  // Create the guard variable if we don't already have it.
  cir::GlobalOp guard = getOrCreateStaticLocalDeclGuardAddress(
      builder, globalOp, varDecl, guardTy, guardAlignment);
  if (!guard) {
    // Error was already emitted, just restore the terminator and return.
    getGlobalOpBlock->push_back(ret);
    return;
  }

  mlir::Value guardPtr = builder.createGetGlobal(guard, /*threadLocal*/ false);

  // Test whether the variable has completed initialization.
  //
  // Itanium C++ ABI 3.3.2:
  //   The following is pseudo-code showing how these functions can be used:
  //     if (obj_guard.first_byte == 0) {
  //       if ( __cxa_guard_acquire (&obj_guard) ) {
  //         try {
  //           ... initialize the object ...;
  //         } catch (...) {
  //           __cxa_guard_abort (&obj_guard);
  //           throw;
  //         }
  //         ... queue object destructor with __cxa_atexit() ...;
  //         __cxa_guard_release (&obj_guard);
  //       }
  //     }
  //
  // If threadsafe statics are enabled, but we don't have inline atomics, just
  // call __cxa_guard_acquire unconditionally. The "inline" check isn't
  // actually inline, and the user might not expect calls to __atomic libcalls.
  unsigned maxInlineWidthInBits =

  if (!threadsafe || maxInlineWidthInBits) {
    // Load the first byte of the guard variable.
    auto bytePtrTy = cir::PointerType::get(builder.getSIntNTy(8));
    mlir::Value bytePtr = builder.createBitcast(guardPtr, bytePtrTy);
    mlir::Value guardLoad = builder.createAlignedLoad(
        getGlobalOp.getLoc(), bytePtr, guardAlignment.getAsAlign().value());

    // Itanium ABI:
    //   An implementation supporting thread-safety on multiprocessor
    //   systems must also guarantee that references to the initialized
    //   object do not occur before the load of the initialization flag.
    //
    // In LLVM, we do this by marking the load Acquire.
    if (threadsafe) {
      auto loadOp = mlir::cast<cir::LoadOp>(guardLoad.getDefiningOp());
      loadOp.setMemOrder(cir::MemOrder::Acquire);
      loadOp.setSyncScope(cir::SyncScopeKind::System);
    }

    // For ARM, we should only check the first bit, rather than the entire byte:
    //
    // ARM C++ ABI 3.2.3.1:
    //   To support the potential use of initialization guard variables
    //   as semaphores that are the target of ARM SWP and LDREX/STREX
    //   synchronizing instructions we define a static initialization
    //   guard variable to be a 4-byte aligned, 4-byte word with the
    //   following inline access protocol.
    //     #define INITIALIZED 1
    //     if ((obj_guard & INITIALIZED) != INITIALIZED) {
    //       if (__cxa_guard_acquire(&obj_guard))
    //         ...
    //     }
    //
    // and similarly for ARM64:
    //
    // ARM64 C++ ABI 3.2.2:
    //   This ABI instead only specifies the value bit 0 of the static guard
    //   variable; all other bits are platform defined. Bit 0 shall be 0 when
    //   the variable is not initialized and 1 when it is.
    if (useARMGuardVarABI()) {
      globalOp->emitError(
          "NYI: ARM-style guard variable check (bit 0 only) for static locals");
      return;
    }

    // Check if the first byte of the guard variable is zero.
    auto zero = builder.getConstantInt(
        getGlobalOp.getLoc(), mlir::cast<cir::IntType>(guardLoad.getType()), 0);
    auto needsInit = builder.createCompare(getGlobalOp.getLoc(),
                                           cir::CmpOpKind::eq, guardLoad, zero);

    // Build the guarded initialization inside an if block.
    cir::IfOp::create(builder, globalOp.getLoc(), needsInit,
                      /*withElseRegion=*/false,
                      [&](mlir::OpBuilder &, mlir::Location) {
                        emitCXXGuardedInitIf(builder, globalOp, varDecl,
                                             guardPtr, guardPtrTy, threadsafe);
                      });
  } else {
    // Threadsafe statics without inline atomics - call __cxa_guard_acquire
    // unconditionally without the initial guard byte check.
    globalOp->emitError("NYI: guarded init without inline atomics support");
    return;
  }

  // Insert the removed terminator back.
  builder.getInsertionBlock()->push_back(ret);
}
1268
/// Lower a global's ctor/dtor regions by outlining them into a
/// __cxx_global_var_init function; the function is recorded in
/// dynamicInitializers for later emission into the module init function.
void LoweringPreparePass::lowerGlobalOp(GlobalOp op) {
  // Static locals are handled separately via guard variables.
  if (op.getStaticLocalGuard())
    return;

  mlir::Region &ctorRegion = op.getCtorRegion();
  mlir::Region &dtorRegion = op.getDtorRegion();

  if (!ctorRegion.empty() || !dtorRegion.empty()) {
    // Build a variable initialization function and move the initialization
    // code in the ctor region over.
    cir::FuncOp f = buildCXXGlobalVarDeclInitFunc(op);

    // Clear the ctor and dtor region
    ctorRegion.getBlocks().clear();
    dtorRegion.getBlocks().clear();

    dynamicInitializers.push_back(f);
  }

}
1292
1293void LoweringPreparePass::lowerThreeWayCmpOp(CmpThreeWayOp op) {
1294 CIRBaseBuilderTy builder(getContext());
1295 builder.setInsertionPointAfter(op);
1296
1297 mlir::Location loc = op->getLoc();
1298 cir::CmpThreeWayInfoAttr cmpInfo = op.getInfo();
1299
1300 mlir::Value ltRes =
1301 builder.getConstantInt(loc, op.getType(), cmpInfo.getLt());
1302 mlir::Value eqRes =
1303 builder.getConstantInt(loc, op.getType(), cmpInfo.getEq());
1304 mlir::Value gtRes =
1305 builder.getConstantInt(loc, op.getType(), cmpInfo.getGt());
1306
1307 mlir::Value transformedResult;
1308 if (cmpInfo.getOrdering() != CmpOrdering::Partial) {
1309 // Total ordering
1310 mlir::Value lt =
1311 builder.createCompare(loc, CmpOpKind::lt, op.getLhs(), op.getRhs());
1312 mlir::Value selectOnLt = builder.createSelect(loc, lt, ltRes, gtRes);
1313 mlir::Value eq =
1314 builder.createCompare(loc, CmpOpKind::eq, op.getLhs(), op.getRhs());
1315 transformedResult = builder.createSelect(loc, eq, eqRes, selectOnLt);
1316 } else {
1317 // Partial ordering
1318 cir::ConstantOp unorderedRes = builder.getConstantInt(
1319 loc, op.getType(), cmpInfo.getUnordered().value());
1320
1321 mlir::Value eq =
1322 builder.createCompare(loc, CmpOpKind::eq, op.getLhs(), op.getRhs());
1323 mlir::Value selectOnEq = builder.createSelect(loc, eq, eqRes, unorderedRes);
1324 mlir::Value gt =
1325 builder.createCompare(loc, CmpOpKind::gt, op.getLhs(), op.getRhs());
1326 mlir::Value selectOnGt = builder.createSelect(loc, gt, gtRes, selectOnEq);
1327 mlir::Value lt =
1328 builder.createCompare(loc, CmpOpKind::lt, op.getLhs(), op.getRhs());
1329 transformedResult = builder.createSelect(loc, lt, ltRes, selectOnGt);
1330 }
1331
1332 op.replaceAllUsesWith(transformedResult);
1333 op.erase();
1334}
1335
/// Convert a list of (symbol name, priority) pairs into an array of
/// AttributeTy attributes (one AttributeTy::get per entry), suitable for
/// attaching to the module as a global ctor/dtor list.
template <typename AttributeTy>
static llvm::SmallVector<mlir::Attribute>
prepareCtorDtorAttrList(mlir::MLIRContext *context,
                        llvm::ArrayRef<std::pair<std::string, uint32_t>> list) {
  for (const auto &[name, priority] : list)
    attrs.push_back(AttributeTy::get(context, name, priority));
  return attrs;
}
1345
/// Attach the collected global ctor/dtor lists to the module as the CIR
/// global_ctors / global_dtors array attributes. Empty lists produce no
/// attribute.
void LoweringPreparePass::buildGlobalCtorDtorList() {
  if (!globalCtorList.empty()) {
    llvm::SmallVector<mlir::Attribute> globalCtors =
                                globalCtorList);

    mlirModule->setAttr(cir::CIRDialect::getGlobalCtorsAttrName(),
                        mlir::ArrayAttr::get(&getContext(), globalCtors));
  }

  if (!globalDtorList.empty()) {
    llvm::SmallVector<mlir::Attribute> globalDtors =
                                globalDtorList);
    mlirModule->setAttr(cir::CIRDialect::getGlobalDtorsAttrName(),
                        mlir::ArrayAttr::get(&getContext(), globalDtors));
  }
}
1364
/// Emit the module-level initialization function (either a mangled C++20
/// module initializer or "_GLOBAL__sub_I_<file>") that calls each collected
/// dynamic initializer in order, and register it in the global ctor list.
void LoweringPreparePass::buildCXXGlobalInitFunc() {
  if (dynamicInitializers.empty())
    return;

  // TODO: handle globals with a user-specified initialization priority.
  // TODO: handle default priority more nicely.

  SmallString<256> fnName;
  // Include the filename in the symbol name. Including "sub_" matches gcc
  // and makes sure these symbols appear lexicographically behind the symbols
  // with priority (TBD). Module implementation units behave the same
  // way as a non-modular TU with imports.
  // TODO: check CXX20ModuleInits
  if (astCtx->getCurrentNamedModule() &&
    // Named modules use the Itanium module-initializer mangling.
    llvm::raw_svector_ostream out(fnName);
    std::unique_ptr<clang::MangleContext> mangleCtx(
        astCtx->createMangleContext());
    cast<clang::ItaniumMangleContext>(*mangleCtx)
        .mangleModuleInitializer(astCtx->getCurrentNamedModule(), out);
  } else {
    fnName += "_GLOBAL__sub_I_";
    fnName += getTransformedFileName(mlirModule);
  }

  CIRBaseBuilderTy builder(getContext());
  builder.setInsertionPointToEnd(&mlirModule.getBodyRegion().back());
  auto fnType = cir::FuncType::get({}, builder.getVoidTy());
  cir::FuncOp f =
      buildRuntimeFunction(builder, fnName, mlirModule.getLoc(), fnType,
                           cir::GlobalLinkageKind::ExternalLinkage);
  builder.setInsertionPointToStart(f.addEntryBlock());
  // Call each per-variable initializer in collection order.
  for (cir::FuncOp &f : dynamicInitializers)
    builder.createCallOp(f.getLoc(), f, {});
  // Add the global init function (not the individual ctor functions) to the
  // global ctor list.
  globalCtorList.emplace_back(fnName,
                              cir::GlobalCtorAttr::getDefaultPriority());

  cir::ReturnOp::create(builder, f.getLoc());
}
1407
1408/// Lower a cir.array.ctor or cir.array.dtor into a do-while loop that
1409/// iterates over every element. For cir.array.ctor ops whose partial_dtor
1410/// region is non-empty, the ctor loop is wrapped in a cir.cleanup.scope whose
1411/// EH cleanup performs a reverse destruction loop using the partial dtor body.
                                       clang::ASTContext *astCtx,
                                       mlir::Operation *op, mlir::Type eltTy,
                                       mlir::Value addr,
                                       mlir::Value numElements,
                                       uint64_t arrayLen, bool isCtor) {
  mlir::Location loc = op->getLoc();
  // A non-null numElements marks a dynamically-sized array; constant-size
  // arrays pass arrayLen instead.
  bool isDynamic = numElements != nullptr;

  // TODO: instead of getting the size from the AST context, create alias for
  // PtrDiffTy and unify with CIRGen stuff.
  const unsigned sizeTypeSize =
      astCtx->getTypeSize(astCtx->getSignedSizeType());

  // Both constructors and destructors use end = begin + numElements.
  // Constructors iterate forward [begin, end). Destructors iterate backward
  // from end, decrementing before calling the destructor on each element.
  mlir::Value begin, end;
  if (isDynamic) {
    begin = addr;
    end = cir::PtrStrideOp::create(builder, loc, eltTy, begin, numElements);
  } else {
    mlir::Value endOffsetVal =
        builder.getUnsignedInt(loc, arrayLen, sizeTypeSize);
    begin = cir::CastOp::create(builder, loc, eltTy,
                                cir::CastKind::array_to_ptrdecay, addr);
    end = cir::PtrStrideOp::create(builder, loc, eltTy, begin, endOffsetVal);
  }

  mlir::Value start = isCtor ? begin : end;
  mlir::Value stop = isCtor ? end : begin;

  // For dynamic destructors, guard against zero elements.
  // This places the destructor loop emitted below inside the if block.
  cir::IfOp ifOp;
  if (isDynamic) {
    mlir::Value guardCond;
    if (isCtor) {
      mlir::Value zero = builder.getUnsignedInt(loc, 0, sizeTypeSize);
      guardCond = cir::CmpOp::create(builder, loc, cir::CmpOpKind::ne,
                                     numElements, zero);
    } else {
      // We could check for numElements != 0 in this case too, but this matches
      // what classic codegen does.
      guardCond =
          cir::CmpOp::create(builder, loc, cir::CmpOpKind::ne, start, stop);
    }
    ifOp = cir::IfOp::create(builder, loc, guardCond,
                             /*withElseRegion=*/false,
                             [&](mlir::OpBuilder &, mlir::Location) {});
    builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
  }

  // Stack slot holding the loop's current element pointer.
  mlir::Value tmpAddr = builder.createAlloca(
      loc, /*addr type*/ builder.getPointerTo(eltTy),
      /*var type*/ eltTy, "__array_idx", builder.getAlignmentAttr(1));
  builder.createStore(loc, start, tmpAddr);

  mlir::Block *bodyBlock = &op->getRegion(0).front();

  // Clone the region body (ctor/dtor call and any setup ops like per-element
  // zero-init) into the loop, remapping the block argument to the current
  // element pointer.
  auto cloneRegionBodyInto = [&](mlir::Block *srcBlock,
                                 mlir::Value replacement) {
    mlir::IRMapping map;
    map.map(srcBlock->getArgument(0), replacement);
    for (mlir::Operation &regionOp : *srcBlock) {
      if (!mlir::isa<cir::YieldOp>(&regionOp))
        builder.clone(regionOp, map);
    }
  };

  // For ctors, a non-empty partial_dtor region supplies the EH cleanup that
  // destroys the already-constructed elements if a later constructor throws.
  mlir::Block *partialDtorBlock = nullptr;
  if (auto arrayCtor = mlir::dyn_cast<cir::ArrayCtor>(op)) {
    mlir::Region &partialDtor = arrayCtor.getPartialDtor();
    if (!partialDtor.empty())
      partialDtorBlock = &partialDtor.front();
  }

  auto emitCtorDtorLoop = [&]() {
    builder.createDoWhile(
        loc,
        /*condBuilder=*/
        [&](mlir::OpBuilder &b, mlir::Location loc) {
          auto currentElement = cir::LoadOp::create(b, loc, eltTy, tmpAddr);
          auto cmp = cir::CmpOp::create(builder, loc, cir::CmpOpKind::ne,
                                        currentElement, stop);
          builder.createCondition(cmp);
        },
        /*bodyBuilder=*/
        [&](mlir::OpBuilder &b, mlir::Location loc) {
          auto currentElement = cir::LoadOp::create(b, loc, eltTy, tmpAddr);
          if (isCtor) {
            // Construct the current element, then advance the pointer.
            cloneRegionBodyInto(bodyBlock, currentElement);
            mlir::Value stride = builder.getUnsignedInt(loc, 1, sizeTypeSize);
            auto nextElement = cir::PtrStrideOp::create(builder, loc, eltTy,
                                                        currentElement, stride);
            builder.createStore(loc, nextElement, tmpAddr);
          } else {
            // Decrement first, then destroy the element just stepped onto.
            mlir::Value stride = builder.getSignedInt(loc, -1, sizeTypeSize);
            auto prevElement = cir::PtrStrideOp::create(builder, loc, eltTy,
                                                        currentElement, stride);
            builder.createStore(loc, prevElement, tmpAddr);
            cloneRegionBodyInto(bodyBlock, prevElement);
          }

          cir::YieldOp::create(b, loc);
        });
  };

  if (partialDtorBlock) {
    cir::CleanupScopeOp::create(
        builder, loc, cir::CleanupKind::EH,
        /*bodyBuilder=*/
        [&](mlir::OpBuilder &b, mlir::Location loc) {
          emitCtorDtorLoop();
          cir::YieldOp::create(b, loc);
        },
        /*cleanupBuilder=*/
        [&](mlir::OpBuilder &b, mlir::Location loc) {
          // On unwind, destroy the already-constructed prefix in reverse
          // order, stopping when the pointer reaches begin.
          auto cur = cir::LoadOp::create(b, loc, eltTy, tmpAddr);
          auto cmp =
              cir::CmpOp::create(builder, loc, cir::CmpOpKind::ne, cur, begin);
          cir::IfOp::create(
              builder, loc, cmp, /*withElseRegion=*/false,
              [&](mlir::OpBuilder &b, mlir::Location loc) {
                builder.createDoWhile(
                    loc,
                    /*condBuilder=*/
                    [&](mlir::OpBuilder &b, mlir::Location loc) {
                      auto el = cir::LoadOp::create(b, loc, eltTy, tmpAddr);
                      auto neq = cir::CmpOp::create(
                          builder, loc, cir::CmpOpKind::ne, el, begin);
                      builder.createCondition(neq);
                    },
                    /*bodyBuilder=*/
                    [&](mlir::OpBuilder &b, mlir::Location loc) {
                      auto el = cir::LoadOp::create(b, loc, eltTy, tmpAddr);
                      mlir::Value negOne =
                          builder.getSignedInt(loc, -1, sizeTypeSize);
                      auto prev = cir::PtrStrideOp::create(builder, loc, eltTy,
                                                           el, negOne);
                      builder.createStore(loc, prev, tmpAddr);
                      cloneRegionBodyInto(partialDtorBlock, prev);
                      builder.createYield(loc);
                    });
                cir::YieldOp::create(builder, loc);
              });
          cir::YieldOp::create(b, loc);
        });
  } else {
    emitCtorDtorLoop();
  }

  // Terminate the zero-element guard region, if one was created.
  if (ifOp)
    cir::YieldOp::create(builder, loc);

  op->erase();
}
1572
1573void LoweringPreparePass::lowerArrayDtor(cir::ArrayDtor op) {
1574 CIRBaseBuilderTy builder(getContext());
1575 builder.setInsertionPointAfter(op.getOperation());
1576
1577 mlir::Type eltTy = op->getRegion(0).getArgument(0).getType();
1578
1579 if (op.getNumElements()) {
1580 lowerArrayDtorCtorIntoLoop(builder, astCtx, op, eltTy, op.getAddr(),
1581 op.getNumElements(), /*arrayLen=*/0,
1582 /*isCtor=*/false);
1583 return;
1584 }
1585
1586 auto arrayLen =
1587 mlir::cast<cir::ArrayType>(op.getAddr().getType().getPointee()).getSize();
1588 lowerArrayDtorCtorIntoLoop(builder, astCtx, op, eltTy, op.getAddr(),
1589 /*numElements=*/nullptr, arrayLen,
1590 /*isCtor=*/false);
1591}
1592
1593void LoweringPreparePass::lowerArrayCtor(cir::ArrayCtor op) {
1594 cir::CIRBaseBuilderTy builder(getContext());
1595 builder.setInsertionPointAfter(op.getOperation());
1596
1597 mlir::Type eltTy = op->getRegion(0).getArgument(0).getType();
1598
1599 if (op.getNumElements()) {
1600 lowerArrayDtorCtorIntoLoop(builder, astCtx, op, eltTy, op.getAddr(),
1601 op.getNumElements(), /*arrayLen=*/0,
1602 /*isCtor=*/true);
1603 return;
1604 }
1605
1606 auto arrayLen =
1607 mlir::cast<cir::ArrayType>(op.getAddr().getType().getPointee()).getSize();
1608 lowerArrayDtorCtorIntoLoop(builder, astCtx, op, eltTy, op.getAddr(),
1609 /*numElements=*/nullptr, arrayLen,
1610 /*isCtor=*/true);
1611}
1612
1613void LoweringPreparePass::lowerTrivialCopyCall(cir::CallOp op) {
1614 cir::FuncOp funcOp = getCalledFunction(op);
1615 if (!funcOp)
1616 return;
1617
1618 std::optional<cir::CtorKind> ctorKind = funcOp.getCxxConstructorKind();
1619 if (ctorKind && *ctorKind == cir::CtorKind::Copy &&
1620 funcOp.isCxxTrivialMemberFunction()) {
1621 // Replace the trivial copy constructor call with a `CopyOp`
1622 CIRBaseBuilderTy builder(getContext());
1623 mlir::ValueRange operands = op.getOperands();
1624 mlir::Value dest = operands[0];
1625 mlir::Value src = operands[1];
1626 builder.setInsertionPoint(op);
1627 builder.createCopy(dest, src);
1628 op.erase();
1629 }
1630}
1631
/// Rewrite a `cir.store` of a constant aggregate (array or record) into a
/// private module-level constant global plus a `cir.copy` from that global
/// into the destination alloca — the "memcpy from global" strategy for
/// initializing constant locals.
void LoweringPreparePass::lowerStoreOfConstAggregate(cir::StoreOp op) {
  // Check if the value operand is a cir.const with aggregate type.
  auto constOp = op.getValue().getDefiningOp<cir::ConstantOp>();
  if (!constOp)
    return;

  mlir::Type ty = constOp.getType();
  if (!mlir::isa<cir::ArrayType, cir::RecordType>(ty))
    return;

  // Only transform stores to local variables (backed by cir.alloca).
  // Stores to other addresses (e.g. base_class_addr) should not be
  // transformed as they may be partial initializations.
  auto alloca = op.getAddr().getDefiningOp<cir::AllocaOp>();
  if (!alloca)
    return;

  mlir::TypedAttr constant = constOp.getValue();

  // OG implements several optimization tiers for constant aggregate
  // initialization. For now we always create a global constant + memcpy
  // (shouldCreateMemCpyFromGlobal). Future work can add the intermediate
  // tiers.

  // Get function name from parent cir.func.
  auto func = op->getParentOfType<cir::FuncOp>();
  if (!func)
    return;
  llvm::StringRef funcName = func.getSymName();

  // Get variable name from the alloca.
  llvm::StringRef varName = alloca.getName();

  // Build name: __const.<func>.<var>
  std::string name = ("__const." + funcName + "." + varName).str();

  // Create the global constant.
  CIRBaseBuilderTy builder(getContext());

  // Use InsertionGuard to create the global at module level.
  builder.setInsertionPointToStart(mlirModule.getBody());

  // If a global with this name already exists (e.g. CIRGen materializes
  // constexpr locals as globals when their address is taken), reuse it.
  if (!mlir::SymbolTable::lookupSymbolIn(
          mlirModule, mlir::StringAttr::get(&getContext(), name))) {
    auto gv = cir::GlobalOp::create(
        builder, op.getLoc(), name, ty,
        /*isConstant=*/true,
        cir::LangAddressSpaceAttr::get(&getContext(),
                                       cir::LangAddressSpace::Default),
        cir::GlobalLinkageKind::PrivateLinkage);
    // Private visibility keeps the helper global out of the symbol table
    // seen by other modules.
    mlir::SymbolTable::setSymbolVisibility(
        gv, mlir::SymbolTable::Visibility::Private);
    gv.setInitialValueAttr(constant);
  }

  // Now replace the store with get_global + copy.
  builder.setInsertionPoint(op);

  auto ptrTy = cir::PointerType::get(ty);
  mlir::Value globalPtr =
      cir::GetGlobalOp::create(builder, op.getLoc(), ptrTy, name);

  // Replace store with copy.
  builder.createCopy(op.getAddr(), globalPtr);

  // Erase the original store.
  op.erase();

  // Erase the cir.const if it has no remaining users.
  if (constOp.use_empty())
    constOp.erase();
}
1709
/// Dispatch one collected operation to the matching lowering routine.
/// Invoked from runOnOperation() after the walk has gathered candidates, so
/// handlers are free to create and erase operations here.
void LoweringPreparePass::runOnOp(mlir::Operation *op) {
  if (auto arrayCtor = dyn_cast<cir::ArrayCtor>(op)) {
    lowerArrayCtor(arrayCtor);
  } else if (auto arrayDtor = dyn_cast<cir::ArrayDtor>(op)) {
    lowerArrayDtor(arrayDtor);
  } else if (auto cast = mlir::dyn_cast<cir::CastOp>(op)) {
    lowerCastOp(cast);
  } else if (auto complexDiv = mlir::dyn_cast<cir::ComplexDivOp>(op)) {
    lowerComplexDivOp(complexDiv);
  } else if (auto complexMul = mlir::dyn_cast<cir::ComplexMulOp>(op)) {
    lowerComplexMulOp(complexMul);
  } else if (auto glob = mlir::dyn_cast<cir::GlobalOp>(op)) {
    lowerGlobalOp(glob);
  } else if (auto getGlobal = mlir::dyn_cast<cir::GetGlobalOp>(op)) {
    // Handle static local variables with guard variables.
    // Only process GetGlobalOps inside function bodies, not in GlobalOp
    // regions.
    if (getGlobal.getStaticLocal() &&
        getGlobal->getParentOfType<cir::FuncOp>()) {
      auto globalOp = mlir::dyn_cast_or_null<cir::GlobalOp>(
          mlir::SymbolTable::lookupNearestSymbolFrom(getGlobal,
                                                     getGlobal.getNameAttr()));
      // Only process if the GlobalOp has static_local and the ctor region is
      // not empty. After handleStaticLocal processes a static local, the ctor
      // region is cleared. GetGlobalOps that were spliced from the ctor region
      // into the function will be skipped on subsequent iterations.
      if (globalOp && globalOp.getStaticLocalGuard() &&
          !globalOp.getCtorRegion().empty())
        handleStaticLocal(globalOp, getGlobal);
    }
  } else if (auto unaryOp = mlir::dyn_cast<cir::UnaryOpInterface>(op)) {
    lowerUnaryOp(unaryOp);
  } else if (auto callOp = dyn_cast<cir::CallOp>(op)) {
    lowerTrivialCopyCall(callOp);
  } else if (auto storeOp = dyn_cast<cir::StoreOp>(op)) {
    lowerStoreOfConstAggregate(storeOp);
  } else if (auto fnOp = dyn_cast<cir::FuncOp>(op)) {
    // Record global ctors/dtors so buildGlobalCtorDtorList() can emit the
    // module-level lists later.
    if (auto globalCtor = fnOp.getGlobalCtorPriority())
      globalCtorList.emplace_back(fnOp.getName(), globalCtor.value());
    else if (auto globalDtor = fnOp.getGlobalDtorPriority())
      globalDtorList.emplace_back(fnOp.getName(), globalDtor.value());

    // Remember CUDA kernel device stubs so buildCUDAModuleCtor() can
    // register them with the runtime.
    if (mlir::Attribute attr =
            fnOp->getAttr(cir::CUDAKernelNameAttr::getMnemonic())) {
      auto kernelNameAttr = dyn_cast<CUDAKernelNameAttr>(attr);
      llvm::StringRef kernelName = kernelNameAttr.getKernelName();
      cudaKernelMap[kernelName] = fnOp;
    }
  } else if (auto threeWayCmp = dyn_cast<cir::CmpThreeWayOp>(op)) {
    lowerThreeWayCmpOp(threeWayCmp);
  }
}
1762
1763static llvm::StringRef getCUDAPrefix(clang::ASTContext *astCtx) {
1764 if (astCtx->getLangOpts().HIP)
1765 return "hip";
1766 return "cuda";
1767}
1768
1769static std::string addUnderscoredPrefix(llvm::StringRef prefix,
1770 llvm::StringRef name) {
1771 return ("__" + prefix + name).str();
1772}
1773
1774/// Creates a global constructor function for the module:
1775///
1776/// For CUDA:
1777/// \code
1778/// void __cuda_module_ctor() {
1779/// Handle = __cudaRegisterFatBinary(GpuBinaryBlob);
1780/// __cuda_register_globals(Handle);
1781/// }
1782/// \endcode
1783///
1784/// For HIP:
1785/// \code
1786/// void __hip_module_ctor() {
1787/// if (__hip_gpubin_handle == 0) {
1788/// __hip_gpubin_handle = __hipRegisterFatBinary(GpuBinaryBlob);
1789/// __hip_register_globals(__hip_gpubin_handle);
1790/// }
1791/// }
1792/// \endcode
void LoweringPreparePass::buildCUDAModuleCtor() {
  // NOTE(review): several source lines are elided in this rendering (the
  // initializer of `vfs`, the condition opening the RegisterFatBinaryEnd
  // block, and a few others) — confirm against the full upstream file
  // before relying on the exact control flow shown here.
  bool isHIP = astCtx->getLangOpts().HIP;

  if (isHIP)
    if (astCtx->getLangOpts().GPURelocatableDeviceCode)
      llvm_unreachable("GPU RDC NYI");

  // For CUDA without -fgpu-rdc, it's safe to stop generating ctor
  // if there's nothing to register.
  if (cudaKernelMap.empty())
    return;

  // There's no device-side binary, so no need to proceed for CUDA.
  // HIP has to create an external symbol in this case, which is NYI.
  mlir::Attribute cudaBinaryHandleAttr =
      mlirModule->getAttr(CIRDialect::getCUDABinaryHandleAttrName());
  if (!cudaBinaryHandleAttr) {
    if (isHIP)
    return;
  }

  llvm::StringRef cudaGPUBinaryName =
      mlir::cast<CUDABinaryHandleAttr>(cudaBinaryHandleAttr)
          .getName()
          .getValue();

  // Read the device binary through the virtual file system (initializer
  // elided in this view) so VFS overlays are honored.
  llvm::vfs::FileSystem &vfs =
  llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> gpuBinaryOrErr =
      vfs.getBufferForFile(cudaGPUBinaryName);
  if (std::error_code ec = gpuBinaryOrErr.getError()) {
    mlirModule->emitError("cannot open GPU binary file: " + cudaGPUBinaryName +
                          ": " + ec.message());
    return;
  }
  std::unique_ptr<llvm::MemoryBuffer> gpuBinary =
      std::move(gpuBinaryOrErr.get());

  // Set up common types and builder.
  llvm::StringRef cudaPrefix = getCUDAPrefix(astCtx);
  mlir::Location loc = mlirModule->getLoc();
  CIRBaseBuilderTy builder(getContext());
  builder.setInsertionPointToStart(mlirModule.getBody());

  Type voidTy = builder.getVoidTy();
  PointerType voidPtrTy = builder.getVoidPtrTy();
  PointerType voidPtrPtrTy = builder.getPointerTo(voidPtrTy);
  IntType intTy = builder.getSIntNTy(32);
  IntType charTy = cir::IntType::get(&getContext(), astCtx->getCharWidth(),
                                     /*isSigned=*/false);

  // --- Create fatbin globals ---

  // The section names are different for MAC OS X.
  llvm::StringRef fatbinConstName =
      astCtx->getLangOpts().HIP ? ".hip_fatbin" : ".nv_fatbin";

  llvm::StringRef fatbinSectionName =
      astCtx->getLangOpts().HIP ? ".hipFatBinSegment" : ".nvFatBinSegment";

  // Create the fatbin string constant with GPU binary contents.
  auto fatbinType =
      ArrayType::get(&getContext(), charTy, gpuBinary->getBuffer().size());
  std::string fatbinStrName = addUnderscoredPrefix(cudaPrefix, "_fatbin_str");
  GlobalOp fatbinStr = GlobalOp::create(builder, loc, fatbinStrName, fatbinType,
                                        /*isConstant=*/true, {},
                                        GlobalLinkageKind::PrivateLinkage);
  fatbinStr.setAlignment(8);
  fatbinStr.setInitialValueAttr(cir::ConstArrayAttr::get(
      fatbinType, builder.getStringAttr(gpuBinary->getBuffer())));
  fatbinStr.setSection(fatbinConstName);
  fatbinStr.setPrivate();

  // Create the fatbin wrapper struct:
  // struct { int magic; int version; void *fatbin; void *unused; };
  auto fatbinWrapperType = RecordType::get(
      &getContext(), {intTy, intTy, voidPtrTy, voidPtrTy},
      /*packed=*/false, /*padded=*/false, RecordType::RecordKind::Struct);
  std::string fatbinWrapperName =
      addUnderscoredPrefix(cudaPrefix, "_fatbin_wrapper");
  GlobalOp fatbinWrapper = GlobalOp::create(
      builder, loc, fatbinWrapperName, fatbinWrapperType,
      /*isConstant=*/true, {}, GlobalLinkageKind::PrivateLinkage);
  fatbinWrapper.setSection(fatbinSectionName);

  // Magic numbers identifying the wrapper format to the runtimes.
  constexpr unsigned cudaFatMagic = 0x466243b1;
  constexpr unsigned hipFatMagic = 0x48495046;
  unsigned fatMagic = isHIP ? hipFatMagic : cudaFatMagic;

  auto magicInit = IntAttr::get(intTy, fatMagic);
  auto versionInit = IntAttr::get(intTy, 1);
  auto fatbinStrSymbol =
      mlir::FlatSymbolRefAttr::get(fatbinStr.getSymNameAttr());
  auto fatbinInit = GlobalViewAttr::get(voidPtrTy, fatbinStrSymbol);
  mlir::TypedAttr unusedInit = builder.getConstNullPtrAttr(voidPtrTy);
  fatbinWrapper.setInitialValueAttr(cir::ConstRecordAttr::get(
      fatbinWrapperType,
      mlir::ArrayAttr::get(&getContext(),
                           {magicInit, versionInit, fatbinInit, unusedInit})));

  // Create the GPU binary handle global variable.
  std::string gpubinHandleName =
      addUnderscoredPrefix(cudaPrefix, "_gpubin_handle");

  GlobalOp gpuBinHandle = GlobalOp::create(
      builder, loc, gpubinHandleName, voidPtrPtrTy,
      /*isConstant=*/false, {}, cir::GlobalLinkageKind::InternalLinkage);
  gpuBinHandle.setInitialValueAttr(builder.getConstNullPtrAttr(voidPtrPtrTy));
  gpuBinHandle.setPrivate();

  // Declare this function:
  // void **__{cuda|hip}RegisterFatBinary(void *);

  std::string regFuncName =
      addUnderscoredPrefix(cudaPrefix, "RegisterFatBinary");
  FuncType regFuncType = FuncType::get({voidPtrTy}, voidPtrPtrTy);
  cir::FuncOp regFunc =
      buildRuntimeFunction(builder, regFuncName, loc, regFuncType);

  // void __{cuda|hip}_module_ctor() — the function whose body we build next.
  std::string moduleCtorName = addUnderscoredPrefix(cudaPrefix, "_module_ctor");
  cir::FuncOp moduleCtor = buildRuntimeFunction(
      builder, moduleCtorName, loc, FuncType::get({}, voidTy),
      GlobalLinkageKind::InternalLinkage);

  // Schedule the ctor at default priority via the module ctor list.
  globalCtorList.emplace_back(moduleCtorName,
                              cir::GlobalCtorAttr::getDefaultPriority());
  builder.setInsertionPointToStart(moduleCtor.addEntryBlock());
  if (isHIP) {
    llvm_unreachable("HIP Module Constructor Support");
  } else if (!astCtx->getLangOpts().GPURelocatableDeviceCode) {

    // --- Create CUDA CTOR-DTOR ---
    // Register binary with CUDA runtime. This is substantially different in
    // default mode vs. separate compilation.
    // Corresponding code:
    // gpuBinaryHandle = __cudaRegisterFatBinary(&fatbinWrapper);
    mlir::Value wrapper = builder.createGetGlobal(fatbinWrapper);
    mlir::Value fatbinVoidPtr = builder.createBitcast(wrapper, voidPtrTy);
    cir::CallOp gpuBinaryHandleCall =
        builder.createCallOp(loc, regFunc, fatbinVoidPtr);
    mlir::Value gpuBinaryHandle = gpuBinaryHandleCall.getResult();
    // Store the value back to the global `__cuda_gpubin_handle`.
    mlir::Value gpuBinaryHandleGlobal = builder.createGetGlobal(gpuBinHandle);
    builder.createStore(loc, gpuBinaryHandle, gpuBinaryHandleGlobal);

    // --- Generate __cuda_register_globals and call it ---
    if (std::optional<FuncOp> regGlobal = buildCUDARegisterGlobals()) {
      builder.createCallOp(loc, *regGlobal, gpuBinaryHandle);
    }

    // From CUDA 10.1 onwards, we must call this function to end registration:
    // void __cudaRegisterFatBinaryEnd(void **fatbinHandle);
    // This is CUDA-specific, so no need to use `addUnderscoredPrefix`.
    // (The `if (CudaFeatureEnabled(...))` condition is partially elided in
    // this rendering.)
        astCtx->getTargetInfo().getSDKVersion(),
      cir::CIRBaseBuilderTy globalBuilder(getContext());
      globalBuilder.setInsertionPointToStart(mlirModule.getBody());
      FuncOp endFunc =
          buildRuntimeFunction(globalBuilder, "__cudaRegisterFatBinaryEnd", loc,
                               FuncType::get({voidPtrPtrTy}, voidTy));
      builder.createCallOp(loc, endFunc, gpuBinaryHandle);
    }
  } else
    llvm_unreachable("GPU RDC NYI");

  // Create destructor and register it with atexit() the way NVCC does it. Doing
  // it during regular destructor phase worked in CUDA before 9.2 but results in
  // double-free in 9.2.
  if (std::optional<FuncOp> dtor = buildCUDAModuleDtor()) {

    // extern "C" int atexit(void (*f)(void));
    cir::CIRBaseBuilderTy globalBuilder(getContext());
    globalBuilder.setInsertionPointToStart(mlirModule.getBody());
    FuncOp atexit = buildRuntimeFunction(
        globalBuilder, "atexit", loc,
        FuncType::get(PointerType::get(dtor->getFunctionType()), intTy));
    mlir::Value dtorFunc = GetGlobalOp::create(
        builder, loc, PointerType::get(dtor->getFunctionType()),
        mlir::FlatSymbolRefAttr::get(dtor->getSymNameAttr()));
    builder.createCallOp(loc, atexit, dtorFunc);
  }
  cir::ReturnOp::create(builder, loc);
}
1980
1981std::optional<FuncOp> LoweringPreparePass::buildCUDAModuleDtor() {
1982 if (!mlirModule->getAttr(CIRDialect::getCUDABinaryHandleAttrName()))
1983 return {};
1984
1985 llvm::StringRef prefix = getCUDAPrefix(astCtx);
1986
1987 VoidType voidTy = VoidType::get(&getContext());
1988 PointerType voidPtrPtrTy = PointerType::get(PointerType::get(voidTy));
1989
1990 mlir::Location loc = mlirModule.getLoc();
1991
1992 cir::CIRBaseBuilderTy builder(getContext());
1993 builder.setInsertionPointToStart(mlirModule.getBody());
1994
1995 // define: void __cudaUnregisterFatBinary(void ** handle);
1996 std::string unregisterFuncName =
1997 addUnderscoredPrefix(prefix, "UnregisterFatBinary");
1998 FuncOp unregisterFunc = buildRuntimeFunction(
1999 builder, unregisterFuncName, loc, FuncType::get({voidPtrPtrTy}, voidTy));
2000
2001 // void __cuda_module_dtor();
2002 // Despite the name, OG doesn't treat it as a destructor, so it shouldn't be
2003 // put into globalDtorList. If it were a real dtor, then it would cause
2004 // double free above CUDA 9.2. The way to use it is to manually call
2005 // atexit() at end of module ctor.
2006 std::string dtorName = addUnderscoredPrefix(prefix, "_module_dtor");
2007 FuncOp dtor =
2008 buildRuntimeFunction(builder, dtorName, loc, FuncType::get({}, voidTy),
2009 GlobalLinkageKind::InternalLinkage);
2010
2011 builder.setInsertionPointToStart(dtor.addEntryBlock());
2012
2013 // For dtor, we only need to call:
2014 // __cudaUnregisterFatBinary(__cuda_gpubin_handle);
2015
2016 std::string gpubinName = addUnderscoredPrefix(prefix, "_gpubin_handle");
2017 GlobalOp gpubinGlobal = cast<GlobalOp>(mlirModule.lookupSymbol(gpubinName));
2018 mlir::Value gpubinAddress = builder.createGetGlobal(gpubinGlobal);
2019 mlir::Value gpubin = builder.createLoad(loc, gpubinAddress);
2020 builder.createCallOp(loc, unregisterFunc, gpubin);
2021 ReturnOp::create(builder, loc);
2022
2023 return dtor;
2024}
2025
2026std::optional<FuncOp> LoweringPreparePass::buildCUDARegisterGlobals() {
2027 // There is nothing to register.
2028 if (cudaKernelMap.empty())
2029 return {};
2030
2031 cir::CIRBaseBuilderTy builder(getContext());
2032 builder.setInsertionPointToStart(mlirModule.getBody());
2033
2034 mlir::Location loc = mlirModule.getLoc();
2035 llvm::StringRef cudaPrefix = getCUDAPrefix(astCtx);
2036
2037 auto voidTy = VoidType::get(&getContext());
2038 auto voidPtrTy = PointerType::get(voidTy);
2039 auto voidPtrPtrTy = PointerType::get(voidPtrTy);
2040
2041 // Create the function:
2042 // void __cuda_register_globals(void **fatbinHandle)
2043 std::string regGlobalFuncName =
2044 addUnderscoredPrefix(cudaPrefix, "_register_globals");
2045 auto regGlobalFuncTy = FuncType::get({voidPtrPtrTy}, voidTy);
2046 FuncOp regGlobalFunc =
2047 buildRuntimeFunction(builder, regGlobalFuncName, loc, regGlobalFuncTy,
2048 /*linkage=*/GlobalLinkageKind::InternalLinkage);
2049 builder.setInsertionPointToStart(regGlobalFunc.addEntryBlock());
2050
2051 buildCUDARegisterGlobalFunctions(builder, regGlobalFunc);
2052 // TODO: Handle shadow registration
2054
2055 ReturnOp::create(builder, loc);
2056 return regGlobalFunc;
2057}
2058
/// Emit, into the body of `regGlobalFunc`, one __{cuda|hip}RegisterFunction
/// call per kernel recorded in cudaKernelMap, passing the fatbin handle
/// argument plus the host stub address and the kernel-name string.
void LoweringPreparePass::buildCUDARegisterGlobalFunctions(
    cir::CIRBaseBuilderTy &builder, FuncOp regGlobalFunc) {
  mlir::Location loc = mlirModule.getLoc();
  llvm::StringRef cudaPrefix = getCUDAPrefix(astCtx);
  // NOTE(review): dataLayout appears unused in this function — confirm
  // whether it can be dropped.
  cir::CIRDataLayout dataLayout(mlirModule);

  auto voidTy = VoidType::get(&getContext());
  auto voidPtrTy = PointerType::get(voidTy);
  auto voidPtrPtrTy = PointerType::get(voidPtrTy);
  IntType intTy = builder.getSIntNTy(32);
  IntType charTy = cir::IntType::get(&getContext(), astCtx->getCharWidth(),
                                     /*isSigned=*/false);

  // Extract the GPU binary handle argument.
  mlir::Value fatbinHandle = *regGlobalFunc.args_begin();

  // A second builder pinned to module scope for declarations and string
  // globals; `builder` stays inside regGlobalFunc's body.
  cir::CIRBaseBuilderTy globalBuilder(getContext());
  globalBuilder.setInsertionPointToStart(mlirModule.getBody());

  // Declare CUDA internal functions:
  // int __cudaRegisterFunction(
  //     void **fatbinHandle,
  //     const char *hostFunc,
  //     char *deviceFunc,
  //     const char *deviceName,
  //     int threadLimit,
  //     uint3 *tid, uint3 *bid, dim3 *bDim, dim3 *gDim,
  //     int *wsize
  // )
  // OG doesn't care about the types at all. They're treated as void*.

  FuncOp cudaRegisterFunction = buildRuntimeFunction(
      globalBuilder, addUnderscoredPrefix(cudaPrefix, "RegisterFunction"), loc,
      FuncType::get({voidPtrPtrTy, voidPtrTy, voidPtrTy, voidPtrTy, intTy,
                     voidPtrTy, voidPtrTy, voidPtrTy, voidPtrTy, voidPtrTy},
                    intTy));

  // Materialize `str` as a private constant char-array global ".str<str>",
  // sized one byte larger than `str` to leave room for a terminator.
  auto makeConstantString = [&](llvm::StringRef str) -> GlobalOp {
    auto strType = ArrayType::get(&getContext(), charTy, 1 + str.size());
    auto tmpString = cir::GlobalOp::create(
        globalBuilder, loc, (".str" + str).str(), strType,
        /*isConstant=*/true, {},
        /*linkage=*/cir::GlobalLinkageKind::PrivateLinkage);

    // We must make the string zero-terminated.
    // NOTE(review): `str + "\0"` concatenates a C string literal, which is
    // empty under NUL-terminated semantics, so this likely appends nothing;
    // the terminator presumably comes from the array being one element
    // larger than `str`. Verify the trailing byte is zero-filled.
    tmpString.setInitialValueAttr(ConstArrayAttr::get(
        strType, StringAttr::get(&getContext(), str + "\0")));
    tmpString.setPrivate();
    return tmpString;
  };

  cir::ConstantOp cirNullPtr = builder.getNullPtr(voidPtrTy, loc);
  bool isHIP = astCtx->getLangOpts().HIP;
  for (auto kernelName : cudaKernelMap.keys()) {
    FuncOp deviceStub = cudaKernelMap[kernelName];
    GlobalOp deviceFuncStr = makeConstantString(kernelName);
    mlir::Value deviceFunc = builder.createBitcast(
        builder.createGetGlobal(deviceFuncStr), voidPtrTy);

    if (isHIP) {
      llvm_unreachable("HIP kernel registration NYI");
    } else {
      // Take the address of the host-side stub and pass it as an opaque
      // pointer alongside the kernel name (used for both deviceFunc and
      // deviceName below).
      mlir::Value hostFunc = builder.createBitcast(
          GetGlobalOp::create(
              builder, loc, PointerType::get(deviceStub.getFunctionType()),
              mlir::FlatSymbolRefAttr::get(deviceStub.getSymNameAttr())),
          voidPtrTy);
      builder.createCallOp(
          loc, cudaRegisterFunction,
          {fatbinHandle, hostFunc, deviceFunc, deviceFunc,
           ConstantOp::create(builder, loc, IntAttr::get(intTy, -1)),
           cirNullPtr, cirNullPtr, cirNullPtr, cirNullPtr, cirNullPtr});
    }
  }
}
2134
2135void LoweringPreparePass::runOnOperation() {
2136 mlir::Operation *op = getOperation();
2137 if (isa<::mlir::ModuleOp>(op))
2138 mlirModule = cast<::mlir::ModuleOp>(op);
2139
2140 llvm::SmallVector<mlir::Operation *> opsToTransform;
2141
2142 op->walk([&](mlir::Operation *op) {
2143 if (mlir::isa<cir::ArrayCtor, cir::ArrayDtor, cir::CastOp,
2144 cir::ComplexMulOp, cir::ComplexDivOp, cir::DynamicCastOp,
2145 cir::FuncOp, cir::CallOp, cir::GetGlobalOp, cir::GlobalOp,
2146 cir::StoreOp, cir::CmpThreeWayOp, cir::IncOp, cir::DecOp,
2147 cir::MinusOp, cir::NotOp>(op))
2148 opsToTransform.push_back(op);
2149 });
2150
2151 for (mlir::Operation *o : opsToTransform)
2152 runOnOp(o);
2153
2154 buildCXXGlobalInitFunc();
2155 if (astCtx->getLangOpts().CUDA && !astCtx->getLangOpts().CUDAIsDevice)
2156 buildCUDAModuleCtor();
2157
2158 buildGlobalCtorDtorList();
2159}
2160
/// Factory for the pass without an ASTContext attached; see the overload
/// below that also wires up the AST context.
std::unique_ptr<Pass> mlir::createLoweringPreparePass() {
  return std::make_unique<LoweringPreparePass>();
}
2164
// NOTE(review): the declarator line of this factory overload is elided in
// this rendering; presumably it is the createLoweringPreparePass overload
// taking a clang::ASTContext*. Confirm against the full source.
std::unique_ptr<Pass>
  auto pass = std::make_unique<LoweringPreparePass>();
  // Attach the AST context; the pass queries it for LangOpts/TargetInfo.
  pass->setASTContext(astCtx);
  // Explicit move converts unique_ptr<LoweringPreparePass> to the
  // unique_ptr<Pass> return type.
  return std::move(pass);
}
Defines the clang::ASTContext interface.
static llvm::FunctionCallee getGuardReleaseFn(CodeGenModule &CGM, llvm::PointerType *GuardPtrTy)
static llvm::FunctionCallee getGuardAcquireFn(CodeGenModule &CGM, llvm::PointerType *GuardPtrTy)
static mlir::Value buildRangeReductionComplexDiv(CIRBaseBuilderTy &builder, mlir::Location loc, mlir::Value lhsReal, mlir::Value lhsImag, mlir::Value rhsReal, mlir::Value rhsImag)
static llvm::StringRef getComplexDivLibCallName(llvm::APFloat::Semantics semantics)
static llvm::SmallVector< mlir::Attribute > prepareCtorDtorAttrList(mlir::MLIRContext *context, llvm::ArrayRef< std::pair< std::string, uint32_t > > list)
static llvm::StringRef getComplexMulLibCallName(llvm::APFloat::Semantics semantics)
static mlir::Value buildComplexBinOpLibCall(LoweringPreparePass &pass, CIRBaseBuilderTy &builder, llvm::StringRef(*libFuncNameGetter)(llvm::APFloat::Semantics), mlir::Location loc, cir::ComplexType ty, mlir::Value lhsReal, mlir::Value lhsImag, mlir::Value rhsReal, mlir::Value rhsImag)
static mlir::Value lowerComplexMul(LoweringPreparePass &pass, CIRBaseBuilderTy &builder, mlir::Location loc, cir::ComplexMulOp op, mlir::Value lhsReal, mlir::Value lhsImag, mlir::Value rhsReal, mlir::Value rhsImag)
static std::string addUnderscoredPrefix(llvm::StringRef prefix, llvm::StringRef name)
static SmallString< 128 > getTransformedFileName(mlir::ModuleOp mlirModule)
static mlir::Value lowerComplexToComplexCast(mlir::MLIRContext &ctx, cir::CastOp op, cir::CastKind scalarCastKind)
static void lowerArrayDtorCtorIntoLoop(cir::CIRBaseBuilderTy &builder, clang::ASTContext *astCtx, mlir::Operation *op, mlir::Type eltTy, mlir::Value addr, mlir::Value numElements, uint64_t arrayLen, bool isCtor)
Lower a cir.array.ctor or cir.array.dtor into a do-while loop that iterates over every element.
static mlir::Value lowerComplexToScalarCast(mlir::MLIRContext &ctx, cir::CastOp op, cir::CastKind elemToBoolKind)
static mlir::Value buildAlgebraicComplexDiv(CIRBaseBuilderTy &builder, mlir::Location loc, mlir::Value lhsReal, mlir::Value lhsImag, mlir::Value rhsReal, mlir::Value rhsImag)
static llvm::StringRef getCUDAPrefix(clang::ASTContext *astCtx)
static cir::FuncOp getCalledFunction(cir::CallOp callOp)
Return the FuncOp called by callOp.
static mlir::Type higherPrecisionElementTypeForComplexArithmetic(mlir::MLIRContext &context, clang::ASTContext &cc, CIRBaseBuilderTy &builder, mlir::Type elementType)
static mlir::Value lowerScalarToComplexCast(mlir::MLIRContext &ctx, cir::CastOp op)
static mlir::Value lowerComplexDiv(LoweringPreparePass &pass, CIRBaseBuilderTy &builder, mlir::Location loc, cir::ComplexDivOp op, mlir::Value lhsReal, mlir::Value lhsImag, mlir::Value rhsReal, mlir::Value rhsImag, mlir::MLIRContext &mlirCx, clang::ASTContext &cc)
Defines the clang::Module class, which describes a module in the source code.
Defines the SourceManager interface.
Defines various enumerations that describe declaration and type specifiers.
Defines the TargetCXXABI class, which abstracts details of the C++ ABI that we're targeting.
__device__ __2f16 b
__device__ __2f16 float c
mlir::Value createDiv(mlir::Location loc, mlir::Value lhs, mlir::Value rhs)
mlir::TypedAttr getConstNullPtrAttr(mlir::Type t)
mlir::Value createDec(mlir::Location loc, mlir::Value input, bool nsw=false)
mlir::Value createLogicalOr(mlir::Location loc, mlir::Value lhs, mlir::Value rhs)
mlir::Value createSub(mlir::Location loc, mlir::Value lhs, mlir::Value rhs, OverflowBehavior ob=OverflowBehavior::None)
cir::ConditionOp createCondition(mlir::Value condition)
Create a loop condition.
mlir::Value createInc(mlir::Location loc, mlir::Value input, bool nsw=false)
cir::CopyOp createCopy(mlir::Value dst, mlir::Value src, bool isVolatile=false, bool skipTailPadding=false)
Create a copy with inferred length.
cir::VoidType getVoidTy()
cir::ConstantOp getNullValue(mlir::Type ty, mlir::Location loc)
mlir::Value createCast(mlir::Location loc, cir::CastKind kind, mlir::Value src, mlir::Type newTy)
mlir::Value createAdd(mlir::Location loc, mlir::Value lhs, mlir::Value rhs, OverflowBehavior ob=OverflowBehavior::None)
cir::PointerType getPointerTo(mlir::Type ty)
mlir::Value createComplexImag(mlir::Location loc, mlir::Value operand)
cir::ConstantOp getNullPtr(mlir::Type ty, mlir::Location loc)
cir::DoWhileOp createDoWhile(mlir::Location loc, llvm::function_ref< void(mlir::OpBuilder &, mlir::Location)> condBuilder, llvm::function_ref< void(mlir::OpBuilder &, mlir::Location)> bodyBuilder)
Create a do-while operation.
mlir::Value getSignedInt(mlir::Location loc, int64_t val, unsigned numBits)
mlir::Value createBitcast(mlir::Value src, mlir::Type newTy)
mlir::Value createGetGlobal(mlir::Location loc, cir::GlobalOp global, bool threadLocal=false)
cir::CmpOp createCompare(mlir::Location loc, cir::CmpOpKind kind, mlir::Value lhs, mlir::Value rhs)
mlir::IntegerAttr getAlignmentAttr(clang::CharUnits alignment)
mlir::Value createSelect(mlir::Location loc, mlir::Value condition, mlir::Value trueValue, mlir::Value falseValue)
mlir::Value createMul(mlir::Location loc, mlir::Value lhs, mlir::Value rhs, OverflowBehavior ob=OverflowBehavior::None)
cir::LoadOp createLoad(mlir::Location loc, mlir::Value ptr, bool isVolatile=false, uint64_t alignment=0)
mlir::Value createMinus(mlir::Location loc, mlir::Value input, bool nsw=false)
cir::ConstantOp getConstantInt(mlir::Location loc, mlir::Type ty, int64_t value)
mlir::Value createComplexCreate(mlir::Location loc, mlir::Value real, mlir::Value imag)
cir::PointerType getVoidPtrTy(clang::LangAS langAS=clang::LangAS::Default)
mlir::Value createIsNaN(mlir::Location loc, mlir::Value operand)
cir::IntType getSIntNTy(int n)
mlir::Value createAlignedLoad(mlir::Location loc, mlir::Value ptr, uint64_t alignment)
cir::CallOp createCallOp(mlir::Location loc, mlir::SymbolRefAttr callee, mlir::Type returnType, mlir::ValueRange operands, llvm::ArrayRef< mlir::NamedAttribute > attrs={}, llvm::ArrayRef< mlir::NamedAttrList > argAttrs={}, llvm::ArrayRef< mlir::NamedAttribute > resAttrs={})
cir::StoreOp createStore(mlir::Location loc, mlir::Value val, mlir::Value dst, bool isVolatile=false, mlir::IntegerAttr align={}, cir::SyncScopeKindAttr scope={}, cir::MemOrderAttr order={})
cir::YieldOp createYield(mlir::Location loc, mlir::ValueRange value={})
Create a yield operation.
mlir::Value createLogicalAnd(mlir::Location loc, mlir::Value lhs, mlir::Value rhs)
mlir::Value createAlloca(mlir::Location loc, cir::PointerType addrType, mlir::Type type, llvm::StringRef name, mlir::IntegerAttr alignment, mlir::Value dynAllocSize)
cir::BoolType getBoolTy()
mlir::Value getUnsignedInt(mlir::Location loc, uint64_t val, unsigned numBits)
mlir::Value createComplexReal(mlir::Location loc, mlir::Value operand)
Holds long-lived AST nodes (such as types and decls) that can be referred to throughout the semantic ...
Definition ASTContext.h:226
SourceManager & getSourceManager()
Definition ASTContext.h:859
MangleContext * createMangleContext(const TargetInfo *T=nullptr)
If T is null pointer, assume the target in ASTContext.
const LangOptions & getLangOpts() const
Definition ASTContext.h:952
uint64_t getTypeSize(QualType T) const
Return the size of the specified (complete) type T, in bits.
const TargetInfo & getTargetInfo() const
Definition ASTContext.h:917
QualType getSignedSizeType() const
Return the unique signed counterpart of the integer type corresponding to size_t.
Module * getCurrentNamedModule() const
Get module under construction, nullptr if this is not a C++20 module.
uint64_t getCharWidth() const
Return the size of the character type, in bits.
llvm::Align getAsAlign() const
getAsAlign - Returns Quantity as a valid llvm::Align, Beware llvm::Align assumes power of two 8-bit b...
Definition CharUnits.h:189
static CharUnits fromQuantity(QuantityType Quantity)
fromQuantity - Construct a CharUnits quantity from a raw integer type.
Definition CharUnits.h:63
llvm::vfs::FileSystem & getVirtualFileSystem() const
bool isModuleImplementation() const
Is this a module implementation.
Definition Module.h:770
FileManager & getFileManager() const
Exposes information about the current target.
Definition TargetInfo.h:227
unsigned getMaxAtomicInlineWidth() const
Return the maximum width lock-free atomic operation which can be inlined given the supported features...
Definition TargetInfo.h:859
const llvm::fltSemantics & getDoubleFormat() const
Definition TargetInfo.h:804
const llvm::fltSemantics & getHalfFormat() const
Definition TargetInfo.h:789
const llvm::fltSemantics & getBFloat16Format() const
Definition TargetInfo.h:799
const llvm::fltSemantics & getLongDoubleFormat() const
Definition TargetInfo.h:810
const llvm::fltSemantics & getFloatFormat() const
Definition TargetInfo.h:794
const llvm::fltSemantics & getFloat128Format() const
Definition TargetInfo.h:818
const llvm::VersionTuple & getSDKVersion() const
Defines the clang::TargetInfo interface.
const internal::VariadicDynCastAllOfMatcher< Decl, VarDecl > varDecl
Matches variable declarations.
bool isHIP(ID Id)
isHIP - Is this a HIP input.
Definition Types.cpp:291
RangeSelector name(std::string ID)
Given a node with a "name", (like NamedDecl, DeclRefExpr, CxxCtorInitializer, and TypeLoc) selects th...
bool isTemplateInstantiation(TemplateSpecializationKind Kind)
Determine whether this template specialization kind refers to an instantiation of an entity (as oppos...
Definition Specifiers.h:213
bool CudaFeatureEnabled(llvm::VersionTuple, CudaFeature)
Definition Cuda.cpp:163
LLVM_READONLY bool isPreprocessingNumberBody(unsigned char c)
Return true if this is the body character of a C preprocessing number, which is [a-zA-Z0-9_.
Definition CharInfo.h:168
@ CUDA_USES_FATBIN_REGISTER_END
Definition Cuda.h:80
unsigned int uint32_t
std::unique_ptr< Pass > createLoweringPreparePass()
static bool opGlobalThreadLocal()
static bool hipModuleCtor()
static bool guardAbortOnException()
static bool opGlobalAnnotations()
static bool opGlobalCtorPriority()
static bool shouldSplitConstantStore()
static bool shouldUseMemSetToInitialize()
static bool opFuncExtraAttrs()
static bool shouldUseBZeroPlusStoresToInitialize()
static bool globalRegistration()
static bool fastMathFlags()
static bool astVarDeclInterface()