clang 23.0.0git
LoweringPrepare.cpp
Go to the documentation of this file.
//===- LoweringPrepare.cpp - preparation work for LLVM lowering -----------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "PassDetail.h"
10#include "mlir/IR/Attributes.h"
11#include "mlir/IR/BuiltinAttributeInterfaces.h"
12#include "mlir/IR/IRMapping.h"
13#include "mlir/IR/Location.h"
14#include "mlir/IR/Value.h"
16#include "clang/AST/Mangle.h"
17#include "clang/Basic/Cuda.h"
18#include "clang/Basic/Module.h"
32#include "llvm/ADT/StringRef.h"
33#include "llvm/ADT/TypeSwitch.h"
34#include "llvm/IR/Instructions.h"
35#include "llvm/Support/ErrorHandling.h"
36#include "llvm/Support/MemoryBuffer.h"
37#include "llvm/Support/Path.h"
38#include "llvm/Support/VirtualFileSystem.h"
39
40#include <memory>
41#include <optional>
42
using namespace mlir;
using namespace cir;

// Instantiate the TableGen-generated pass base class
// (impl::LoweringPrepareBase) that LoweringPreparePass below derives from.
namespace mlir {
#define GEN_PASS_DEF_LOWERINGPREPARE
#include "clang/CIR/Dialect/Passes.h.inc"
} // namespace mlir
50
51static SmallString<128> getTransformedFileName(mlir::ModuleOp mlirModule) {
52 SmallString<128> fileName;
53
54 if (mlirModule.getSymName())
55 fileName = llvm::sys::path::filename(mlirModule.getSymName()->str());
56
57 if (fileName.empty())
58 fileName = "<null>";
59
60 for (size_t i = 0; i < fileName.size(); ++i) {
61 // Replace everything that's not [a-zA-Z0-9._] with a _. This set happens
62 // to be the set of C preprocessing numbers.
63 if (!clang::isPreprocessingNumberBody(fileName[i]))
64 fileName[i] = '_';
65 }
66
67 return fileName;
68}
69
70/// Return the FuncOp called by `callOp`.
71static cir::FuncOp getCalledFunction(cir::CallOp callOp) {
72 mlir::SymbolRefAttr sym = llvm::dyn_cast_if_present<mlir::SymbolRefAttr>(
73 callOp.getCallableForCallee());
74 if (!sym)
75 return nullptr;
76 return dyn_cast_or_null<cir::FuncOp>(
77 mlir::SymbolTable::lookupNearestSymbolFrom(callOp, sym));
78}
79
namespace {
/// Pass that rewrites CIR constructs with no direct LLVM lowering (complex
/// arithmetic, global ctor/dtor handling, guarded static locals, CUDA
/// registration, ...) into simpler CIR before LLVM dialect conversion.
struct LoweringPreparePass
    : public impl::LoweringPrepareBase<LoweringPreparePass> {
  LoweringPreparePass() = default;
  void runOnOperation() override;

  void runOnOp(mlir::Operation *op, mlir::SymbolTableCollection &symbolTables);
  void lowerCastOp(cir::CastOp op);
  void lowerComplexDivOp(cir::ComplexDivOp op);
  void lowerComplexMulOp(cir::ComplexMulOp op);
  void lowerUnaryOp(cir::UnaryOpInterface op);
  void lowerGlobalOp(cir::GlobalOp op);
  void lowerThreeWayCmpOp(cir::CmpThreeWayOp op);
  void lowerArrayDtor(cir::ArrayDtor op);
  void lowerArrayCtor(cir::ArrayCtor op);
  void lowerTrivialCopyCall(cir::CallOp op);
  void lowerStoreOfConstAggregate(cir::StoreOp op,
                                  mlir::SymbolTableCollection &symbolTables);
  void lowerLocalInitOp(cir::LocalInitOp op,
                        mlir::SymbolTableCollection &symbolTables);

  /// Return a private constant cir::GlobalOp with the given type and initial
  /// value, suitable for backing a memcpy-initialized local aggregate.
  ///
  /// If a global with `baseName` (or one of its `.<n>` versioned siblings)
  /// already has a matching type and initial value, that global is reused.
  /// Otherwise a new global is created with the next available `.<n>` suffix
  /// (matching CIRGenBuilder::createVersionedGlobal and OGCG behavior).
  cir::GlobalOp
  getOrCreateConstAggregateGlobal(CIRBaseBuilderTy &builder,
                                  mlir::SymbolTableCollection &symbolTables,
                                  mlir::Location loc, llvm::StringRef baseName,
                                  mlir::Type ty, mlir::TypedAttr constant);

  /// Build the function that initializes the specified global
  cir::FuncOp buildCXXGlobalVarDeclInitFunc(cir::GlobalOp op);

  /// Handle the dtor region by registering destructor with __cxa_atexit
  cir::FuncOp getOrCreateDtorFunc(CIRBaseBuilderTy &builder, cir::GlobalOp op,
                                  mlir::Region &dtorRegion,
                                  cir::CallOp &dtorCall);

  /// Build a module init function that calls all the dynamic initializers.
  void buildCXXGlobalInitFunc();

  /// Materialize global ctor/dtor list
  void buildGlobalCtorDtorList();

  /// Look up or declare the runtime function `name` with the given type and
  /// linkage.
  cir::FuncOp buildRuntimeFunction(
      mlir::OpBuilder &builder, llvm::StringRef name, mlir::Location loc,
      cir::FuncType type,
      cir::GlobalLinkageKind linkage = cir::GlobalLinkageKind::ExternalLinkage);

  /// Look up or create the runtime global variable `name` with the given
  /// type, linkage, and visibility.
  cir::GlobalOp getOrCreateRuntimeVariable(
      mlir::OpBuilder &builder, llvm::StringRef name, mlir::Location loc,
      mlir::Type type,
      cir::GlobalLinkageKind linkage = cir::GlobalLinkageKind::ExternalLinkage,
      cir::VisibilityKind visibility = cir::VisibilityKind::Default);

  /// ------------
  /// CUDA registration related
  /// ------------

  // Maps kernel symbol names to their device-stub FuncOps.
  llvm::StringMap<FuncOp> cudaKernelMap;

  /// Build the CUDA module constructor that registers the fat binary
  /// with the CUDA runtime.
  void buildCUDAModuleCtor();
  std::optional<FuncOp> buildCUDAModuleDtor();
  std::optional<FuncOp> buildCUDARegisterGlobals();
  void buildCUDARegisterGlobalFunctions(cir::CIRBaseBuilderTy &builder,
                                        FuncOp regGlobalFunc);

  /// Handle static local variable initialization with guard variables.
  void handleStaticLocal(cir::GlobalOp globalOp, cir::LocalInitOp localInitOp);

  /// Get or create __cxa_guard_acquire function.
  cir::FuncOp getGuardAcquireFn(cir::PointerType guardPtrTy);

  /// Get or create __cxa_guard_release function.
  cir::FuncOp getGuardReleaseFn(cir::PointerType guardPtrTy);

  /// Create a guard global variable for a static local.
  cir::GlobalOp createGuardGlobalOp(CIRBaseBuilderTy &builder,
                                    mlir::Location loc, llvm::StringRef name,
                                    cir::IntType guardTy,
                                    cir::GlobalLinkageKind linkage);

  /// Get the guard variable for a static local declaration.
  /// Returns a null op when no guard has been recorded yet.
  cir::GlobalOp getStaticLocalDeclGuardAddress(llvm::StringRef globalSymName) {
    auto it = staticLocalDeclGuardMap.find(globalSymName);
    if (it != staticLocalDeclGuardMap.end())
      return it->second;
    return nullptr;
  }

  /// Set the guard variable for a static local declaration.
  void setStaticLocalDeclGuardAddress(llvm::StringRef globalSymName,
                                      cir::GlobalOp guard) {
    staticLocalDeclGuardMap[globalSymName] = guard;
  }

  /// Get or create the guard variable for a static local declaration.
  /// Newly created guards are zero-initialized and inherit dso_local,
  /// alignment, TLS model, and (conditionally) COMDAT from the guarded
  /// global.  Returns a null op on the NYI error path.
  cir::GlobalOp getOrCreateStaticLocalDeclGuardAddress(
      CIRBaseBuilderTy &builder, cir::GlobalOp globalOp,
      cir::ASTVarDeclInterface varDecl, cir::IntType guardTy,
      clang::CharUnits guardAlignment) {
    llvm::StringRef globalSymName = globalOp.getSymName();
    cir::GlobalOp guard = getStaticLocalDeclGuardAddress(globalSymName);
    if (!guard) {
      // Get the guard name from the static_local attribute.
      llvm::StringRef guardName =
          globalOp.getStaticLocalGuard()->getName().getValue();

      // Create the guard variable with a zero-initializer.
      guard = createGuardGlobalOp(builder, globalOp->getLoc(), guardName,
                                  guardTy, globalOp.getLinkage());
      guard.setInitialValueAttr(cir::IntAttr::get(guardTy, 0));
      guard.setDSOLocal(globalOp.getDsoLocal());
      guard.setAlignment(guardAlignment.getAsAlign().value());
      guard.setTlsModel(globalOp.getTlsModel());

      // The ABI says: "It is suggested that it be emitted in the same COMDAT
      // group as the associated data object." In practice, this doesn't work
      // for non-ELF and non-Wasm object formats, so only do it for ELF and
      // Wasm.
      bool hasComdat = globalOp.getComdat();
      const llvm::Triple &triple = astCtx->getTargetInfo().getTriple();
      if (!varDecl.isLocalVarDecl() && hasComdat &&
          (triple.isOSBinFormatELF() || triple.isOSBinFormatWasm())) {
        globalOp->emitError("NYI: guard COMDAT for non-local variables");
        return {};
      } else if (hasComdat && globalOp.isWeakForLinker()) {
        guard.setComdat(true);
      }

      setStaticLocalDeclGuardAddress(globalSymName, guard);
    }
    return guard;
  }

  ///
  /// AST related
  /// -----------

  // Clang AST context, used for target/ABI and language-option queries.
  clang::ASTContext *astCtx;

  /// Tracks current module.
  mlir::ModuleOp mlirModule;

  /// Tracks existing dynamic initializers.
  llvm::StringMap<uint32_t> dynamicInitializerNames;
  llvm::SmallVector<cir::FuncOp> dynamicInitializers;

  /// Tracks guard variables for static locals (keyed by global symbol name).
  llvm::StringMap<cir::GlobalOp> staticLocalDeclGuardMap;

  // Cache of constant-aggregate backing globals, keyed by base symbol name.
  llvm::StringMap<llvm::SmallVector<cir::GlobalOp, 1>> constAggregateGlobals;

  /// List of ctors and their priorities to be called before main()
  llvm::SmallVector<std::pair<std::string, uint32_t>, 4> globalCtorList;
  /// List of dtors and their priorities to be called when unloading module.
  llvm::SmallVector<std::pair<std::string, uint32_t>, 4> globalDtorList;

  /// Returns true if the target uses ARM-style guard variables for static
  /// local initialization (32-bit guard, check bit 0 only).
  bool useARMGuardVarABI() const {
    switch (astCtx->getCXXABIKind()) {
    case clang::TargetCXXABI::GenericARM:
    case clang::TargetCXXABI::iOS:
    case clang::TargetCXXABI::WatchOS:
    case clang::TargetCXXABI::GenericAArch64:
    case clang::TargetCXXABI::WebAssembly:
      return true;
    default:
      return false;
    }
  }

  /// Register the dtor region of a guarded global with the runtime
  /// (__cxa_atexit, or __cxa_thread_atexit/_tlv_atexit when `tls` is set),
  /// then splice the rewritten dtor block into `entryBB`.
  void emitGlobalGuardedDtorRegion(CIRBaseBuilderTy &builder,
                                   cir::GlobalOp global,
                                   mlir::Region &dtorRegion, bool tls,
                                   mlir::Block &entryBB) {
    // Create a variable that binds the atexit to this shared object.
    builder.setInsertionPointToStart(&mlirModule.getBodyRegion().front());
    cir::GlobalOp handle = getOrCreateRuntimeVariable(
        builder, "__dso_handle", global.getLoc(), builder.getI8Type(),
        cir::GlobalLinkageKind::ExternalLinkage, cir::VisibilityKind::Hidden);

    // If this is a simple call to a destructor, get the called function.
    // Otherwise, create a helper function for the entire dtor region,
    // replacing the current dtor region body with a call to the helper
    // function.
    cir::CallOp dtorCall;
    cir::FuncOp dtorFunc =
        getOrCreateDtorFunc(builder, global, dtorRegion, dtorCall);

    // Create a runtime helper function:
    // extern "C" int __cxa_atexit(void (*f)(void *), void *p, void *d);
    cir::PointerType voidPtrTy = builder.getVoidPtrTy();
    cir::PointerType voidFnPtrTy = builder.getVoidFnPtrTy({voidPtrTy});
    cir::PointerType handlePtrTy = builder.getPointerTo(handle.getSymType());
    auto fnAtExitType =
        builder.getVoidFnTy({voidFnPtrTy, voidPtrTy, handlePtrTy});

    llvm::StringLiteral nameAtExit = "__cxa_atexit";
    if (tls)
      nameAtExit = astCtx->getTargetInfo().getTriple().isOSDarwin()
                       ? llvm::StringLiteral("_tlv_atexit")
                       : llvm::StringLiteral("__cxa_thread_atexit");

    cir::FuncOp fnAtExit = buildRuntimeFunction(builder, nameAtExit,
                                                global.getLoc(), fnAtExitType);

    // Replace the dtor (or helper) call with a call to
    // __cxa_atexit(&dtor, &var, &__dso_handle)
    builder.setInsertionPointAfter(dtorCall);
    mlir::Value args[3];
    auto dtorPtrTy = cir::PointerType::get(dtorFunc.getFunctionType());
    args[0] = cir::GetGlobalOp::create(builder, dtorCall.getLoc(), dtorPtrTy,
                                       dtorFunc.getSymName());
    args[0] = cir::CastOp::create(builder, dtorCall.getLoc(), voidFnPtrTy,
                                  cir::CastKind::bitcast, args[0]);
    args[1] =
        cir::CastOp::create(builder, dtorCall.getLoc(), voidPtrTy,
                            cir::CastKind::bitcast, dtorCall.getArgOperand(0));
    args[2] = cir::GetGlobalOp::create(builder, handle.getLoc(), handlePtrTy,
                                       handle.getSymName());
    builder.createCallOp(dtorCall.getLoc(), fnAtExit, args);
    dtorCall->erase();
    // Splice everything except the terminator into the entry block.
    mlir::Block &dtorBlock = dtorRegion.front();
    entryBB.getOperations().splice(entryBB.end(), dtorBlock.getOperations(),
                                   dtorBlock.begin(),
                                   std::prev(dtorBlock.end()));
  }

  /// Emit the guarded initialization for a static local variable.
  /// This handles the if/else structure after the guard byte check,
  /// following OG's ItaniumCXXABI::EmitGuardedInit skeleton.
  void emitCXXGuardedInitIf(CIRBaseBuilderTy &builder, cir::GlobalOp globalOp,
                            mlir::Region &ctorRegion, mlir::Region &dtorRegion,
                            cir::ASTVarDeclInterface varDecl,
                            mlir::Value guardPtr, cir::PointerType guardPtrTy,
                            bool threadsafe) {
    auto loc = globalOp->getLoc();

    // The semantics of dynamic initialization of variables with static or
    // thread storage duration depends on whether they are declared at
    // block-scope. The initialization of such variables at block-scope can be
    // aborted with an exception and later retried (per C++20 [stmt.dcl]p4),
    // and recursive entry to their initialization has undefined behavior (also
    // per C++20 [stmt.dcl]p4). For such variables declared at non-block scope,
    // exceptions lead to termination (per C++20 [except.terminate]p1), and
    // recursive references to the variables are governed only by the lifetime
    // rules (per C++20 [class.cdtor]p2), which means such references are
    // perfectly fine as long as they avoid touching memory. As a result,
    // block-scope variables must not be marked as initialized until after
    // initialization completes (unless the mark is reverted following an
    // exception), but non-block-scope variables must be marked prior to
    // initialization so that recursive accesses during initialization do not
    // restart initialization.

    auto emitBody = [&]() {
      // Emit the initializer and add a global destructor if appropriate.
      mlir::Block *insertBlock = builder.getInsertionBlock();
      if (!ctorRegion.empty()) {
        assert(ctorRegion.hasOneBlock() && "Enforced by MaxSizedRegion<1>");

        mlir::Block &block = ctorRegion.front();
        insertBlock->getOperations().splice(
            insertBlock->end(), block.getOperations(), block.begin(),
            std::prev(block.end()));
      }

      if (!dtorRegion.empty()) {
        assert(dtorRegion.hasOneBlock() && "Enforced by MaxSizedRegion<1>");

        // NOTE(review): `!threadsafe` is passed as the `tls` parameter of
        // emitGlobalGuardedDtorRegion, which selects __cxa_thread_atexit /
        // _tlv_atexit.  Thread-safety of the guard and thread-local storage
        // are distinct properties — confirm this is intentional.
        emitGlobalGuardedDtorRegion(builder, globalOp, dtorRegion, !threadsafe,
                                    *insertBlock);
      }
      builder.setInsertionPointToEnd(insertBlock);
      ctorRegion.getBlocks().clear();
    };

    // Variables used when coping with thread-safe statics and exceptions.
    if (threadsafe) {
      // Call __cxa_guard_acquire.
      cir::CallOp acquireCall = builder.createCallOp(
          loc, getGuardAcquireFn(guardPtrTy), mlir::ValueRange{guardPtr});
      mlir::Value acquireResult = acquireCall.getResult();

      auto acquireZero = builder.getConstantInt(
          loc, mlir::cast<cir::IntType>(acquireResult.getType()), 0);
      auto shouldInit = builder.createCompare(loc, cir::CmpOpKind::ne,
                                              acquireResult, acquireZero);

      // Create the IfOp for the shouldInit check.
      // Pass an empty callback to avoid auto-creating a yield terminator.
      auto ifOp =
          cir::IfOp::create(builder, loc, shouldInit, /*withElseRegion=*/false,
                            [](mlir::OpBuilder &, mlir::Location) {});
      mlir::OpBuilder::InsertionGuard insertGuard(builder);
      builder.setInsertionPointToStart(&ifOp.getThenRegion().front());

      // Call __cxa_guard_abort along the exceptional edge.
      // OG: CGF.EHStack.pushCleanup<CallGuardAbort>(EHCleanup, guard);

      emitBody();

      // Pop the guard-abort cleanup if we pushed one.
      // OG: CGF.PopCleanupBlock();

      // Call __cxa_guard_release. This cannot throw.
      builder.createCallOp(loc, getGuardReleaseFn(guardPtrTy),
                           mlir::ValueRange{guardPtr});

      builder.createYield(loc);
    } else if (!varDecl.isLocalVarDecl()) {
      // For non-local variables, store 1 into the first byte of the guard
      // variable before the object initialization begins so that references
      // to the variable during initialization don't restart initialization.
      // OG: Builder.CreateStore(llvm::ConstantInt::get(CGM.Int8Ty, 1), ...);
      // Then: CGF.EmitCXXGlobalVarDeclInit(D, var, shouldPerformInit);
      globalOp->emitError("NYI: non-threadsafe init for non-local variables");
      return;
    } else {
      emitBody();
      // For local variables, store 1 into the first byte of the guard variable
      // after the object initialization completes so that initialization is
      // retried if initialization is interrupted by an exception.
      builder.createStore(
          loc, builder.getConstantInt(loc, guardPtrTy.getPointee(), 1),
          guardPtr);
    }

    builder.createYield(loc); // Outermost IfOp
  }

  /// Record the clang AST context used for target/ABI queries.
  void setASTContext(clang::ASTContext *c) { astCtx = c; }
};

} // namespace
424
425cir::GlobalOp LoweringPreparePass::getOrCreateRuntimeVariable(
426 mlir::OpBuilder &builder, llvm::StringRef name, mlir::Location loc,
427 mlir::Type type, cir::GlobalLinkageKind linkage,
428 cir::VisibilityKind visibility) {
429 cir::GlobalOp g = dyn_cast_or_null<cir::GlobalOp>(
430 mlir::SymbolTable::lookupNearestSymbolFrom(
431 mlirModule, mlir::StringAttr::get(mlirModule->getContext(), name)));
432 if (!g) {
433 g = cir::GlobalOp::create(builder, loc, name, type);
434 g.setLinkageAttr(
435 cir::GlobalLinkageKindAttr::get(builder.getContext(), linkage));
436 mlir::SymbolTable::setSymbolVisibility(
437 g, mlir::SymbolTable::Visibility::Private);
438 g.setGlobalVisibility(visibility);
439 }
440 return g;
441}
442
443cir::FuncOp LoweringPreparePass::buildRuntimeFunction(
444 mlir::OpBuilder &builder, llvm::StringRef name, mlir::Location loc,
445 cir::FuncType type, cir::GlobalLinkageKind linkage) {
446 cir::FuncOp f = dyn_cast_or_null<FuncOp>(SymbolTable::lookupNearestSymbolFrom(
447 mlirModule, StringAttr::get(mlirModule->getContext(), name)));
448 if (!f) {
449 f = cir::FuncOp::create(builder, loc, name, type);
450 f.setLinkageAttr(
451 cir::GlobalLinkageKindAttr::get(builder.getContext(), linkage));
452 mlir::SymbolTable::setSymbolVisibility(
453 f, mlir::SymbolTable::Visibility::Private);
454
456 }
457 return f;
458}
459
460static mlir::Value lowerScalarToComplexCast(mlir::MLIRContext &ctx,
461 cir::CastOp op) {
462 cir::CIRBaseBuilderTy builder(ctx);
463 builder.setInsertionPoint(op);
464
465 mlir::Value src = op.getSrc();
466 mlir::Value imag = builder.getNullValue(src.getType(), op.getLoc());
467 return builder.createComplexCreate(op.getLoc(), src, imag);
468}
469
470static mlir::Value lowerComplexToScalarCast(mlir::MLIRContext &ctx,
471 cir::CastOp op,
472 cir::CastKind elemToBoolKind) {
473 cir::CIRBaseBuilderTy builder(ctx);
474 builder.setInsertionPoint(op);
475
476 mlir::Value src = op.getSrc();
477 if (!mlir::isa<cir::BoolType>(op.getType()))
478 return builder.createComplexReal(op.getLoc(), src);
479
480 // Complex cast to bool: (bool)(a+bi) => (bool)a || (bool)b
481 mlir::Value srcReal = builder.createComplexReal(op.getLoc(), src);
482 mlir::Value srcImag = builder.createComplexImag(op.getLoc(), src);
483
484 cir::BoolType boolTy = builder.getBoolTy();
485 mlir::Value srcRealToBool =
486 builder.createCast(op.getLoc(), elemToBoolKind, srcReal, boolTy);
487 mlir::Value srcImagToBool =
488 builder.createCast(op.getLoc(), elemToBoolKind, srcImag, boolTy);
489 return builder.createLogicalOr(op.getLoc(), srcRealToBool, srcImagToBool);
490}
491
492static mlir::Value lowerComplexToComplexCast(mlir::MLIRContext &ctx,
493 cir::CastOp op,
494 cir::CastKind scalarCastKind) {
495 CIRBaseBuilderTy builder(ctx);
496 builder.setInsertionPoint(op);
497
498 mlir::Value src = op.getSrc();
499 auto dstComplexElemTy =
500 mlir::cast<cir::ComplexType>(op.getType()).getElementType();
501
502 mlir::Value srcReal = builder.createComplexReal(op.getLoc(), src);
503 mlir::Value srcImag = builder.createComplexImag(op.getLoc(), src);
504
505 mlir::Value dstReal = builder.createCast(op.getLoc(), scalarCastKind, srcReal,
506 dstComplexElemTy);
507 mlir::Value dstImag = builder.createCast(op.getLoc(), scalarCastKind, srcImag,
508 dstComplexElemTy);
509 return builder.createComplexCreate(op.getLoc(), dstReal, dstImag);
510}
511
512void LoweringPreparePass::lowerCastOp(cir::CastOp op) {
513 mlir::MLIRContext &ctx = getContext();
514 mlir::Value loweredValue = [&]() -> mlir::Value {
515 switch (op.getKind()) {
516 case cir::CastKind::float_to_complex:
517 case cir::CastKind::int_to_complex:
518 return lowerScalarToComplexCast(ctx, op);
519 case cir::CastKind::float_complex_to_real:
520 case cir::CastKind::int_complex_to_real:
521 return lowerComplexToScalarCast(ctx, op, op.getKind());
522 case cir::CastKind::float_complex_to_bool:
523 return lowerComplexToScalarCast(ctx, op, cir::CastKind::float_to_bool);
524 case cir::CastKind::int_complex_to_bool:
525 return lowerComplexToScalarCast(ctx, op, cir::CastKind::int_to_bool);
526 case cir::CastKind::float_complex:
527 return lowerComplexToComplexCast(ctx, op, cir::CastKind::floating);
528 case cir::CastKind::float_complex_to_int_complex:
529 return lowerComplexToComplexCast(ctx, op, cir::CastKind::float_to_int);
530 case cir::CastKind::int_complex:
531 return lowerComplexToComplexCast(ctx, op, cir::CastKind::integral);
532 case cir::CastKind::int_complex_to_float_complex:
533 return lowerComplexToComplexCast(ctx, op, cir::CastKind::int_to_float);
534 default:
535 return nullptr;
536 }
537 }();
538
539 if (loweredValue) {
540 op.replaceAllUsesWith(loweredValue);
541 op.erase();
542 }
543}
544
545static mlir::Value buildComplexBinOpLibCall(
546 LoweringPreparePass &pass, CIRBaseBuilderTy &builder,
547 llvm::StringRef (*libFuncNameGetter)(llvm::APFloat::Semantics),
548 mlir::Location loc, cir::ComplexType ty, mlir::Value lhsReal,
549 mlir::Value lhsImag, mlir::Value rhsReal, mlir::Value rhsImag) {
550 cir::FPTypeInterface elementTy =
551 mlir::cast<cir::FPTypeInterface>(ty.getElementType());
552
553 llvm::StringRef libFuncName = libFuncNameGetter(
554 llvm::APFloat::SemanticsToEnum(elementTy.getFloatSemantics()));
555 llvm::SmallVector<mlir::Type, 4> libFuncInputTypes(4, elementTy);
556
557 cir::FuncType libFuncTy = cir::FuncType::get(libFuncInputTypes, ty);
558
559 // Insert a declaration for the runtime function to be used in Complex
560 // multiplication and division when needed
561 cir::FuncOp libFunc;
562 {
563 mlir::OpBuilder::InsertionGuard ipGuard{builder};
564 builder.setInsertionPointToStart(pass.mlirModule.getBody());
565 libFunc = pass.buildRuntimeFunction(builder, libFuncName, loc, libFuncTy);
566 }
567
568 cir::CallOp call =
569 builder.createCallOp(loc, libFunc, {lhsReal, lhsImag, rhsReal, rhsImag});
570 return call.getResult();
571}
572
573static llvm::StringRef
574getComplexDivLibCallName(llvm::APFloat::Semantics semantics) {
575 switch (semantics) {
576 case llvm::APFloat::S_IEEEhalf:
577 return "__divhc3";
578 case llvm::APFloat::S_IEEEsingle:
579 return "__divsc3";
580 case llvm::APFloat::S_IEEEdouble:
581 return "__divdc3";
582 case llvm::APFloat::S_PPCDoubleDouble:
583 return "__divtc3";
584 case llvm::APFloat::S_x87DoubleExtended:
585 return "__divxc3";
586 case llvm::APFloat::S_IEEEquad:
587 return "__divtc3";
588 default:
589 llvm_unreachable("unsupported floating point type");
590 }
591}
592
593static mlir::Value
594buildAlgebraicComplexDiv(CIRBaseBuilderTy &builder, mlir::Location loc,
595 mlir::Value lhsReal, mlir::Value lhsImag,
596 mlir::Value rhsReal, mlir::Value rhsImag) {
597 // (a+bi) / (c+di) = ((ac+bd)/(cc+dd)) + ((bc-ad)/(cc+dd))i
598 mlir::Value &a = lhsReal;
599 mlir::Value &b = lhsImag;
600 mlir::Value &c = rhsReal;
601 mlir::Value &d = rhsImag;
602
603 mlir::Value ac = builder.createMul(loc, a, c); // a*c
604 mlir::Value bd = builder.createMul(loc, b, d); // b*d
605 mlir::Value cc = builder.createMul(loc, c, c); // c*c
606 mlir::Value dd = builder.createMul(loc, d, d); // d*d
607 mlir::Value acbd = builder.createAdd(loc, ac, bd); // ac+bd
608 mlir::Value ccdd = builder.createAdd(loc, cc, dd); // cc+dd
609 mlir::Value resultReal = builder.createDiv(loc, acbd, ccdd);
610
611 mlir::Value bc = builder.createMul(loc, b, c); // b*c
612 mlir::Value ad = builder.createMul(loc, a, d); // a*d
613 mlir::Value bcad = builder.createSub(loc, bc, ad); // bc-ad
614 mlir::Value resultImag = builder.createDiv(loc, bcad, ccdd);
615 return builder.createComplexCreate(loc, resultReal, resultImag);
616}
617
618static mlir::Value
620 mlir::Value lhsReal, mlir::Value lhsImag,
621 mlir::Value rhsReal, mlir::Value rhsImag) {
622 // Implements Smith's algorithm for complex division.
623 // SMITH, R. L. Algorithm 116: Complex division. Commun. ACM 5, 8 (1962).
624
625 // Let:
626 // - lhs := a+bi
627 // - rhs := c+di
628 // - result := lhs / rhs = e+fi
629 //
630 // The algorithm pseudocode looks like follows:
631 // if fabs(c) >= fabs(d):
632 // r := d / c
633 // tmp := c + r*d
634 // e = (a + b*r) / tmp
635 // f = (b - a*r) / tmp
636 // else:
637 // r := c / d
638 // tmp := d + r*c
639 // e = (a*r + b) / tmp
640 // f = (b*r - a) / tmp
641
642 mlir::Value &a = lhsReal;
643 mlir::Value &b = lhsImag;
644 mlir::Value &c = rhsReal;
645 mlir::Value &d = rhsImag;
646
647 auto trueBranchBuilder = [&](mlir::OpBuilder &, mlir::Location) {
648 mlir::Value r = builder.createDiv(loc, d, c); // r := d / c
649 mlir::Value rd = builder.createMul(loc, r, d); // r*d
650 mlir::Value tmp = builder.createAdd(loc, c, rd); // tmp := c + r*d
651
652 mlir::Value br = builder.createMul(loc, b, r); // b*r
653 mlir::Value abr = builder.createAdd(loc, a, br); // a + b*r
654 mlir::Value e = builder.createDiv(loc, abr, tmp);
655
656 mlir::Value ar = builder.createMul(loc, a, r); // a*r
657 mlir::Value bar = builder.createSub(loc, b, ar); // b - a*r
658 mlir::Value f = builder.createDiv(loc, bar, tmp);
659
660 mlir::Value result = builder.createComplexCreate(loc, e, f);
661 builder.createYield(loc, result);
662 };
663
664 auto falseBranchBuilder = [&](mlir::OpBuilder &, mlir::Location) {
665 mlir::Value r = builder.createDiv(loc, c, d); // r := c / d
666 mlir::Value rc = builder.createMul(loc, r, c); // r*c
667 mlir::Value tmp = builder.createAdd(loc, d, rc); // tmp := d + r*c
668
669 mlir::Value ar = builder.createMul(loc, a, r); // a*r
670 mlir::Value arb = builder.createAdd(loc, ar, b); // a*r + b
671 mlir::Value e = builder.createDiv(loc, arb, tmp);
672
673 mlir::Value br = builder.createMul(loc, b, r); // b*r
674 mlir::Value bra = builder.createSub(loc, br, a); // b*r - a
675 mlir::Value f = builder.createDiv(loc, bra, tmp);
676
677 mlir::Value result = builder.createComplexCreate(loc, e, f);
678 builder.createYield(loc, result);
679 };
680
681 auto cFabs = cir::FAbsOp::create(builder, loc, c);
682 auto dFabs = cir::FAbsOp::create(builder, loc, d);
683 cir::CmpOp cmpResult =
684 builder.createCompare(loc, cir::CmpOpKind::ge, cFabs, dFabs);
685 auto ternary = cir::TernaryOp::create(builder, loc, cmpResult,
686 trueBranchBuilder, falseBranchBuilder);
687
688 return ternary.getResult();
689}
690
692 mlir::MLIRContext &context, clang::ASTContext &cc,
693 CIRBaseBuilderTy &builder, mlir::Type elementType) {
694
695 auto getHigherPrecisionFPType = [&context](mlir::Type type) -> mlir::Type {
696 if (mlir::isa<cir::FP16Type>(type))
697 return cir::SingleType::get(&context);
698
699 if (mlir::isa<cir::SingleType>(type) || mlir::isa<cir::BF16Type>(type))
700 return cir::DoubleType::get(&context);
701
702 if (mlir::isa<cir::DoubleType>(type))
703 return cir::LongDoubleType::get(&context, type);
704
705 return type;
706 };
707
708 auto getFloatTypeSemantics =
709 [&cc](mlir::Type type) -> const llvm::fltSemantics & {
710 const clang::TargetInfo &info = cc.getTargetInfo();
711 if (mlir::isa<cir::FP16Type>(type))
712 return info.getHalfFormat();
713
714 if (mlir::isa<cir::BF16Type>(type))
715 return info.getBFloat16Format();
716
717 if (mlir::isa<cir::SingleType>(type))
718 return info.getFloatFormat();
719
720 if (mlir::isa<cir::DoubleType>(type))
721 return info.getDoubleFormat();
722
723 if (mlir::isa<cir::LongDoubleType>(type)) {
724 if (cc.getLangOpts().OpenMP && cc.getLangOpts().OpenMPIsTargetDevice)
725 llvm_unreachable("NYI Float type semantics with OpenMP");
726 return info.getLongDoubleFormat();
727 }
728
729 if (mlir::isa<cir::FP128Type>(type)) {
730 if (cc.getLangOpts().OpenMP && cc.getLangOpts().OpenMPIsTargetDevice)
731 llvm_unreachable("NYI Float type semantics with OpenMP");
732 return info.getFloat128Format();
733 }
734
735 llvm_unreachable("Unsupported float type semantics");
736 };
737
738 const mlir::Type higherElementType = getHigherPrecisionFPType(elementType);
739 const llvm::fltSemantics &elementTypeSemantics =
740 getFloatTypeSemantics(elementType);
741 const llvm::fltSemantics &higherElementTypeSemantics =
742 getFloatTypeSemantics(higherElementType);
743
744 // Check that the promoted type can handle the intermediate values without
745 // overflowing. This can be interpreted as:
746 // (SmallerType.LargestFiniteVal * SmallerType.LargestFiniteVal) * 2 <=
747 // LargerType.LargestFiniteVal.
748 // In terms of exponent it gives this formula:
749 // (SmallerType.LargestFiniteVal * SmallerType.LargestFiniteVal
750 // doubles the exponent of SmallerType.LargestFiniteVal)
751 if (llvm::APFloat::semanticsMaxExponent(elementTypeSemantics) * 2 + 1 <=
752 llvm::APFloat::semanticsMaxExponent(higherElementTypeSemantics)) {
753 return higherElementType;
754 }
755
756 // The intermediate values can't be represented in the promoted type
757 // without overflowing.
758 return {};
759}
760
761static mlir::Value
762lowerComplexDiv(LoweringPreparePass &pass, CIRBaseBuilderTy &builder,
763 mlir::Location loc, cir::ComplexDivOp op, mlir::Value lhsReal,
764 mlir::Value lhsImag, mlir::Value rhsReal, mlir::Value rhsImag,
765 mlir::MLIRContext &mlirCx, clang::ASTContext &cc) {
766 cir::ComplexType complexTy = op.getType();
767 if (mlir::isa<cir::FPTypeInterface>(complexTy.getElementType())) {
768 cir::ComplexRangeKind range = op.getRange();
769 if (range == cir::ComplexRangeKind::Improved)
770 return buildRangeReductionComplexDiv(builder, loc, lhsReal, lhsImag,
771 rhsReal, rhsImag);
772
773 if (range == cir::ComplexRangeKind::Full)
775 loc, complexTy, lhsReal, lhsImag, rhsReal,
776 rhsImag);
777
778 if (range == cir::ComplexRangeKind::Promoted) {
779 mlir::Type originalElementType = complexTy.getElementType();
780 mlir::Type higherPrecisionElementType =
782 originalElementType);
783
784 if (!higherPrecisionElementType)
785 return buildRangeReductionComplexDiv(builder, loc, lhsReal, lhsImag,
786 rhsReal, rhsImag);
787
788 cir::CastKind floatingCastKind = cir::CastKind::floating;
789 lhsReal = builder.createCast(floatingCastKind, lhsReal,
790 higherPrecisionElementType);
791 lhsImag = builder.createCast(floatingCastKind, lhsImag,
792 higherPrecisionElementType);
793 rhsReal = builder.createCast(floatingCastKind, rhsReal,
794 higherPrecisionElementType);
795 rhsImag = builder.createCast(floatingCastKind, rhsImag,
796 higherPrecisionElementType);
797
798 mlir::Value algebraicResult = buildAlgebraicComplexDiv(
799 builder, loc, lhsReal, lhsImag, rhsReal, rhsImag);
800
801 mlir::Value resultReal = builder.createComplexReal(loc, algebraicResult);
802 mlir::Value resultImag = builder.createComplexImag(loc, algebraicResult);
803
804 mlir::Value finalReal =
805 builder.createCast(floatingCastKind, resultReal, originalElementType);
806 mlir::Value finalImag =
807 builder.createCast(floatingCastKind, resultImag, originalElementType);
808 return builder.createComplexCreate(loc, finalReal, finalImag);
809 }
810 }
811
812 return buildAlgebraicComplexDiv(builder, loc, lhsReal, lhsImag, rhsReal,
813 rhsImag);
814}
815
816void LoweringPreparePass::lowerComplexDivOp(cir::ComplexDivOp op) {
817 cir::CIRBaseBuilderTy builder(getContext());
818 builder.setInsertionPointAfter(op);
819 mlir::Location loc = op.getLoc();
820 mlir::TypedValue<cir::ComplexType> lhs = op.getLhs();
821 mlir::TypedValue<cir::ComplexType> rhs = op.getRhs();
822 mlir::Value lhsReal = builder.createComplexReal(loc, lhs);
823 mlir::Value lhsImag = builder.createComplexImag(loc, lhs);
824 mlir::Value rhsReal = builder.createComplexReal(loc, rhs);
825 mlir::Value rhsImag = builder.createComplexImag(loc, rhs);
826
827 mlir::Value loweredResult =
828 lowerComplexDiv(*this, builder, loc, op, lhsReal, lhsImag, rhsReal,
829 rhsImag, getContext(), *astCtx);
830 op.replaceAllUsesWith(loweredResult);
831 op.erase();
832}
833
834static llvm::StringRef
835getComplexMulLibCallName(llvm::APFloat::Semantics semantics) {
836 switch (semantics) {
837 case llvm::APFloat::S_IEEEhalf:
838 return "__mulhc3";
839 case llvm::APFloat::S_IEEEsingle:
840 return "__mulsc3";
841 case llvm::APFloat::S_IEEEdouble:
842 return "__muldc3";
843 case llvm::APFloat::S_PPCDoubleDouble:
844 return "__multc3";
845 case llvm::APFloat::S_x87DoubleExtended:
846 return "__mulxc3";
847 case llvm::APFloat::S_IEEEquad:
848 return "__multc3";
849 default:
850 llvm_unreachable("unsupported floating point type");
851 }
852}
853
/// Lower a cir.complex.mul into scalar arithmetic on the component values.
///
/// The algebraic identity (a+bi) * (c+di) = (ac-bd) + (ad+bc)i is emitted
/// unconditionally. For integer complex types, or for the Basic / Improved /
/// Promoted complex range kinds, that algebraic result is final. Otherwise,
/// the result is additionally guarded at run time: when both components of
/// the product are NaN, the value is recomputed via the __mul*c3 compiler-rt
/// libcall (selected by getComplexMulLibCallName).
static mlir::Value lowerComplexMul(LoweringPreparePass &pass,
                                   CIRBaseBuilderTy &builder,
                                   mlir::Location loc, cir::ComplexMulOp op,
                                   mlir::Value lhsReal, mlir::Value lhsImag,
                                   mlir::Value rhsReal, mlir::Value rhsImag) {
  // (a+bi) * (c+di) = (ac-bd) + (ad+bc)i
  mlir::Value resultRealLhs = builder.createMul(loc, lhsReal, rhsReal); // ac
  mlir::Value resultRealRhs = builder.createMul(loc, lhsImag, rhsImag); // bd
  mlir::Value resultImagLhs = builder.createMul(loc, lhsReal, rhsImag); // ad
  mlir::Value resultImagRhs = builder.createMul(loc, lhsImag, rhsReal); // bc
  mlir::Value resultReal = builder.createSub(loc, resultRealLhs, resultRealRhs);
  mlir::Value resultImag = builder.createAdd(loc, resultImagLhs, resultImagRhs);
  mlir::Value algebraicResult =
      builder.createComplexCreate(loc, resultReal, resultImag);

  // The cheap algebraic form is sufficient for integer components and for
  // every range kind that does not require the NaN fixup below.
  cir::ComplexType complexTy = op.getType();
  cir::ComplexRangeKind rangeKind = op.getRange();
  if (mlir::isa<cir::IntType>(complexTy.getElementType()) ||
      rangeKind == cir::ComplexRangeKind::Basic ||
      rangeKind == cir::ComplexRangeKind::Improved ||
      rangeKind == cir::ComplexRangeKind::Promoted)
    return algebraicResult;

  // Check whether the real part and the imaginary part of the result are both
  // NaN. If so, emit a library call to compute the multiplication instead.
  // We check a value against NaN by comparing the value against itself.
  mlir::Value resultRealIsNaN = builder.createIsNaN(loc, resultReal);
  mlir::Value resultImagIsNaN = builder.createIsNaN(loc, resultImag);
  mlir::Value resultRealAndImagAreNaN =
      builder.createLogicalAnd(loc, resultRealIsNaN, resultImagIsNaN);

  // Select at run time between the libcall result and the algebraic result.
  return cir::TernaryOp::create(
             builder, loc, resultRealAndImagAreNaN,
             [&](mlir::OpBuilder &, mlir::Location) {
               mlir::Value libCallResult = buildComplexBinOpLibCall(
                   pass, builder, &getComplexMulLibCallName, loc, complexTy,
                   lhsReal, lhsImag, rhsReal, rhsImag);
               builder.createYield(loc, libCallResult);
             },
             [&](mlir::OpBuilder &, mlir::Location) {
               builder.createYield(loc, algebraicResult);
             })
      .getResult();
}
900
901void LoweringPreparePass::lowerComplexMulOp(cir::ComplexMulOp op) {
902 cir::CIRBaseBuilderTy builder(getContext());
903 builder.setInsertionPointAfter(op);
904 mlir::Location loc = op.getLoc();
905 mlir::TypedValue<cir::ComplexType> lhs = op.getLhs();
906 mlir::TypedValue<cir::ComplexType> rhs = op.getRhs();
907 mlir::Value lhsReal = builder.createComplexReal(loc, lhs);
908 mlir::Value lhsImag = builder.createComplexImag(loc, lhs);
909 mlir::Value rhsReal = builder.createComplexReal(loc, rhs);
910 mlir::Value rhsImag = builder.createComplexImag(loc, rhs);
911 mlir::Value loweredResult = lowerComplexMul(*this, builder, loc, op, lhsReal,
912 lhsImag, rhsReal, rhsImag);
913 op.replaceAllUsesWith(loweredResult);
914 op.erase();
915}
916
917void LoweringPreparePass::lowerUnaryOp(cir::UnaryOpInterface op) {
918 if (!mlir::isa<cir::ComplexType>(op.getResult().getType()))
919 return;
920
921 mlir::Location loc = op->getLoc();
922 CIRBaseBuilderTy builder(getContext());
923 builder.setInsertionPointAfter(op);
924
925 mlir::Value operand = op.getInput();
926 mlir::Value operandReal = builder.createComplexReal(loc, operand);
927 mlir::Value operandImag = builder.createComplexImag(loc, operand);
928
929 mlir::Value resultReal = operandReal;
930 mlir::Value resultImag = operandImag;
931
932 llvm::TypeSwitch<mlir::Operation *>(op)
933 .Case<cir::IncOp>(
934 [&](auto) { resultReal = builder.createInc(loc, operandReal); })
935 .Case<cir::DecOp>(
936 [&](auto) { resultReal = builder.createDec(loc, operandReal); })
937 .Case<cir::MinusOp>([&](auto) {
938 resultReal = builder.createMinus(loc, operandReal);
939 resultImag = builder.createMinus(loc, operandImag);
940 })
941 .Case<cir::NotOp>(
942 [&](auto) { resultImag = builder.createMinus(loc, operandImag); })
943 .Default([](auto) { llvm_unreachable("unhandled unary complex op"); });
944
945 mlir::Value result = builder.createComplexCreate(loc, resultReal, resultImag);
946 op->replaceAllUsesWith(mlir::ValueRange{result});
947 op->erase();
948}
949
/// Return a function implementing the destruction code in `dtorRegion`.
///
/// When the region is just a single destructor call on the global (plus the
/// implicit yield), the called destructor itself is returned. Otherwise the
/// region body is outlined into a new internal helper
/// ("__cxx_global_array_dtor[.N]") taking a void*, and the region is
/// rewritten to call that helper. On return, `dtorCall` is set to the call
/// operation that the caller must register (e.g. with __cxa_atexit).
cir::FuncOp LoweringPreparePass::getOrCreateDtorFunc(CIRBaseBuilderTy &builder,
                                                     cir::GlobalOp op,
                                                     mlir::Region &dtorRegion,
                                                     cir::CallOp &dtorCall) {
  mlir::OpBuilder::InsertionGuard guard(builder);

  cir::VoidType voidTy = builder.getVoidTy();
  auto voidPtrTy = cir::PointerType::get(voidTy);

  // Look for operations in dtorBlock
  mlir::Block &dtorBlock = dtorRegion.front();

  // The first operation should be a get_global to retrieve the address
  // of the global variable we're destroying.
  auto opIt = dtorBlock.getOperations().begin();
  cir::GetGlobalOp ggop = mlir::cast<cir::GetGlobalOp>(*opIt);

  // The simple case is just a call to a destructor, like this:
  //
  //   %0 = cir.get_global %globalS : !cir.ptr<!rec_S>
  //   cir.call %_ZN1SD1Ev(%0) : (!cir.ptr<!rec_S>) -> ()
  //   (implicit cir.yield)
  //
  // That is, if the second operation is a call that takes the get_global
  // result as its only operand, and the only other operation is a yield, then
  // we can just return the called function.
  if (dtorBlock.getOperations().size() == 3) {
    auto callOp = mlir::dyn_cast<cir::CallOp>(&*(++opIt));
    auto yieldOp = mlir::dyn_cast<cir::YieldOp>(&*(++opIt));
    if (yieldOp && callOp && callOp.getNumOperands() == 1 &&
        callOp.getArgOperand(0) == ggop) {
      dtorCall = callOp;
      return getCalledFunction(callOp);
    }
  }

  // Otherwise, we need to create a helper function to replace the dtor region.
  // This name is kind of arbitrary, but it matches the name that classic
  // codegen uses, based on the expected case that gets us here.
  builder.setInsertionPointAfter(op);
  SmallString<256> fnName("__cxx_global_array_dtor");
  // Uniquify repeated helper names with a ".N" suffix.
  uint32_t cnt = dynamicInitializerNames[fnName]++;
  if (cnt)
    fnName += "." + std::to_string(cnt);

  // Create the helper function.
  auto fnType = cir::FuncType::get({voidPtrTy}, voidTy);
  cir::FuncOp dtorFunc =
      buildRuntimeFunction(builder, fnName, op.getLoc(), fnType,
                           cir::GlobalLinkageKind::InternalLinkage);

  // Mark the single pointer argument noundef via an argument attribute dict.
  SmallVector<mlir::NamedAttribute> paramAttrs;
  paramAttrs.push_back(
      builder.getNamedAttr("llvm.noundef", builder.getUnitAttr()));
  SmallVector<mlir::Attribute> argAttrDicts;
  argAttrDicts.push_back(
      mlir::DictionaryAttr::get(builder.getContext(), paramAttrs));
  dtorFunc.setArgAttrsAttr(
      mlir::ArrayAttr::get(builder.getContext(), argAttrDicts));

  mlir::Block *entryBB = dtorFunc.addEntryBlock();

  // Move everything from the dtor region into the helper function.
  entryBB->getOperations().splice(entryBB->begin(), dtorBlock.getOperations(),
                                  dtorBlock.begin(), dtorBlock.end());

  // Before erasing this, clone it back into the dtor region
  cir::GetGlobalOp dtorGGop =
      mlir::cast<cir::GetGlobalOp>(entryBB->getOperations().front());
  builder.setInsertionPointToStart(&dtorBlock);
  builder.clone(*dtorGGop.getOperation());

  // Replace all uses of the help function's get_global with the function
  // argument.
  mlir::Value dtorArg = entryBB->getArgument(0);
  dtorGGop.replaceAllUsesWith(dtorArg);
  dtorGGop.erase();

  // Replace the yield in the final block with a return
  mlir::Block &finalBlock = dtorFunc.getBody().back();
  auto yieldOp = cast<cir::YieldOp>(finalBlock.getTerminator());
  builder.setInsertionPoint(yieldOp);
  cir::ReturnOp::create(builder, yieldOp->getLoc());
  yieldOp->erase();

  // Create a call to the helper function, passing the original get_global op
  // as the argument.
  cir::GetGlobalOp origGGop =
      mlir::cast<cir::GetGlobalOp>(dtorBlock.getOperations().front());
  builder.setInsertionPointAfter(origGGop);
  mlir::Value ggopResult = origGGop.getResult();
  dtorCall = builder.createCallOp(op.getLoc(), dtorFunc, ggopResult);

  // Add a yield after the call.
  auto finalYield = cir::YieldOp::create(builder, op.getLoc());

  // Erase everything after the yield.
  dtorBlock.getOperations().erase(std::next(mlir::Block::iterator(finalYield)),
                                  dtorBlock.end());
  dtorRegion.getBlocks().erase(std::next(dtorRegion.begin()), dtorRegion.end());

  return dtorFunc;
}
1055
/// Outline a global's ctor/dtor regions into a fresh internal
/// "__cxx_global_var_init[.N]" function and return it. The ctor code is
/// spliced into the new function's entry block; the dtor region (if any) is
/// registered via emitGlobalGuardedDtorRegion; the trailing cir.yield is
/// rewritten into a cir.return.
cir::FuncOp
LoweringPreparePass::buildCXXGlobalVarDeclInitFunc(cir::GlobalOp op) {
  // TODO(cir): Store this in the GlobalOp.
  // This should come from the MangleContext, but for now I'm hardcoding it.
  SmallString<256> fnName("__cxx_global_var_init");
  // Get a unique name by appending a ".N" counter suffix.
  uint32_t cnt = dynamicInitializerNames[fnName]++;
  if (cnt)
    fnName += "." + std::to_string(cnt);

  // Create a variable initialization function.
  CIRBaseBuilderTy builder(getContext());
  builder.setInsertionPointAfter(op);
  cir::VoidType voidTy = builder.getVoidTy();
  auto fnType = cir::FuncType::get({}, voidTy);
  FuncOp f = buildRuntimeFunction(builder, fnName, op.getLoc(), fnType,
                                  cir::GlobalLinkageKind::InternalLinkage);

  // Move over the initialization code of the ctor region.
  // The ctor region may have multiple blocks when exception handling
  // scaffolding creates extra blocks (e.g., unreachable/trap blocks).
  // We move all operations from the first block (minus the yield) into
  // the function entry, and discard extra blocks (which contain only
  // unreachable terminators from EH cleanup paths).
  mlir::Block *entryBB = f.addEntryBlock();
  if (!op.getCtorRegion().empty()) {
    mlir::Block &block = op.getCtorRegion().front();
    entryBB->getOperations().splice(entryBB->begin(), block.getOperations(),
                                    block.begin(), std::prev(block.end()));
  }

  // Register the destructor call with __cxa_atexit
  mlir::Region &dtorRegion = op.getDtorRegion();
  if (!dtorRegion.empty()) {

    emitGlobalGuardedDtorRegion(builder, op, dtorRegion,
                                op.getTlsModel().has_value(), *entryBB);
  }

  // Replace cir.yield with cir.return
  builder.setInsertionPointToEnd(entryBB);
  mlir::Operation *yieldOp = nullptr;
  if (!op.getCtorRegion().empty()) {
    // The ctor region's yield was intentionally left behind by the splice
    // above (std::prev(block.end())), so it is still the block's last op.
    mlir::Block &block = op.getCtorRegion().front();
    yieldOp = &block.getOperations().back();
  } else {
    assert(!dtorRegion.empty());
    mlir::Block &block = dtorRegion.front();
    yieldOp = &block.getOperations().back();
  }

  assert(isa<cir::YieldOp>(*yieldOp));
  cir::ReturnOp::create(builder, yieldOp->getLoc());
  return f;
}
1113
1114cir::FuncOp
1115LoweringPreparePass::getGuardAcquireFn(cir::PointerType guardPtrTy) {
1116 // int __cxa_guard_acquire(__guard *guard_object);
1117 CIRBaseBuilderTy builder(getContext());
1118 mlir::OpBuilder::InsertionGuard ipGuard{builder};
1119 builder.setInsertionPointToStart(mlirModule.getBody());
1120 mlir::Location loc = mlirModule.getLoc();
1121 cir::IntType intTy = cir::IntType::get(&getContext(), 32, /*isSigned=*/true);
1122 auto fnType = cir::FuncType::get({guardPtrTy}, intTy);
1123 return buildRuntimeFunction(builder, "__cxa_guard_acquire", loc, fnType);
1124}
1125
1126cir::FuncOp
1127LoweringPreparePass::getGuardReleaseFn(cir::PointerType guardPtrTy) {
1128 // void __cxa_guard_release(__guard *guard_object);
1129 CIRBaseBuilderTy builder(getContext());
1130 mlir::OpBuilder::InsertionGuard ipGuard{builder};
1131 builder.setInsertionPointToStart(mlirModule.getBody());
1132 mlir::Location loc = mlirModule.getLoc();
1133 cir::VoidType voidTy = cir::VoidType::get(&getContext());
1134 auto fnType = cir::FuncType::get({guardPtrTy}, voidTy);
1135 return buildRuntimeFunction(builder, "__cxa_guard_release", loc, fnType);
1136}
1137
1138cir::GlobalOp LoweringPreparePass::createGuardGlobalOp(
1139 CIRBaseBuilderTy &builder, mlir::Location loc, llvm::StringRef name,
1140 cir::IntType guardTy, cir::GlobalLinkageKind linkage) {
1141 mlir::OpBuilder::InsertionGuard guard(builder);
1142 builder.setInsertionPointToStart(mlirModule.getBody());
1143 cir::GlobalOp g = cir::GlobalOp::create(builder, loc, name, guardTy);
1144 g.setLinkageAttr(
1145 cir::GlobalLinkageKindAttr::get(builder.getContext(), linkage));
1146 mlir::SymbolTable::setSymbolVisibility(
1147 g, mlir::SymbolTable::Visibility::Private);
1148 return g;
1149}
1150
/// Emit Itanium-style guarded initialization for a function-local static
/// variable: create (or reuse) a guard global, test it, and run the ctor/dtor
/// regions of `localInitOp` inside the guarded "needs init" branch.
void LoweringPreparePass::handleStaticLocal(cir::GlobalOp globalOp,
                                            cir::LocalInitOp localInitOp) {
  CIRBaseBuilderTy builder(getContext());

  std::optional<cir::ASTVarDeclInterface> astOption = globalOp.getAst();
  assert(astOption.has_value());
  cir::ASTVarDeclInterface varDecl = astOption.value();

  builder.setInsertionPointAfter(localInitOp);
  mlir::Block *localInitBlock = builder.getInsertionBlock();

  // Remove the terminator temporarily - we'll add it back at the end.
  mlir::Operation *ret = localInitBlock->getTerminator();
  ret->remove();
  // Note: These two insert-point-after sets are necessary, as the 'trailing'
  // operation has changed thanks to the terminator removal.
  builder.setInsertionPointAfter(localInitOp);

  // Inline variables that weren't instantiated from variable templates have
  // partially-ordered initialization within their translation unit.
  bool nonTemplateInline =
      varDecl.isInline() &&
      !clang::isTemplateInstantiation(varDecl.getTemplateSpecializationKind());

  // Inline namespace-scope variables require guarded initialization in a
  // __cxx_global_var_init function. This is not yet implemented.
  // NOTE(review): this early return leaves `localInitBlock` without its
  // terminator (`ret` was removed above and is not pushed back here) —
  // confirm whether it should be restored before returning.
  if (nonTemplateInline) {
    globalOp->emitError(
        "NYI: guarded initialization for inline namespace-scope variables");
    return;
  }

  // We only need to use thread-safe statics for local non-TLS variables and
  // inline variables; other global initialization is always single-threaded
  // or (through lazy dynamic loading in multiple threads) unsequenced.
  bool threadsafe = astCtx->getLangOpts().ThreadsafeStatics &&
                    (varDecl.isLocalVarDecl() || nonTemplateInline) &&
                    !varDecl.getTLSKind();

  // If we have a global variable with internal linkage and thread-safe statics
  // are disabled, we can just let the guard variable be of type i8.
  bool useInt8GuardVariable = !threadsafe && globalOp.hasInternalLinkage();
  cir::CIRDataLayout dataLayout(mlirModule);
  cir::IntType guardTy;
  clang::CharUnits guardAlignment;
  // Guard variables are 64 bits in the generic ABI and size width on ARM
  // (i.e. 32-bit on AArch32, 64-bit on AArch64).
  if (useInt8GuardVariable) {
    guardTy = cir::IntType::get(&getContext(), 8, /*isSigned=*/true);
    guardAlignment = clang::CharUnits::One();
  } else if (useARMGuardVarABI()) {
    // Guard variables are size width on ARM (32-bit AArch32, 64-bit AArch64).
    const unsigned sizeTypeSize =
        astCtx->getTypeSize(astCtx->getSignedSizeType());
    guardTy = cir::IntType::get(&getContext(), sizeTypeSize, /*isSigned=*/true);
    guardAlignment =
        clang::CharUnits::fromQuantity(dataLayout.getABITypeAlign(guardTy));
  } else {
    guardTy = cir::IntType::get(&getContext(), 64, /*isSigned=*/true);
    guardAlignment =
        clang::CharUnits::fromQuantity(dataLayout.getABITypeAlign(guardTy));
  }
  assert(guardTy && guardAlignment.getQuantity() != 0);

  auto guardPtrTy = cir::PointerType::get(guardTy);

  // Create the guard variable if we don't already have it.
  cir::GlobalOp guard = getOrCreateStaticLocalDeclGuardAddress(
      builder, globalOp, varDecl, guardTy, guardAlignment);
  if (!guard) {
    // Error was already emitted, just restore the terminator and return.
    localInitBlock->push_back(ret);
    return;
  }

  mlir::Value guardPtr = builder.createGetGlobal(guard, localInitOp.getTls());

  // Test whether the variable has completed initialization.
  //
  // Itanium C++ ABI 3.3.2:
  //   The following is pseudo-code showing how these functions can be used:
  //     if (obj_guard.first_byte == 0) {
  //       if ( __cxa_guard_acquire (&obj_guard) ) {
  //         try {
  //           ... initialize the object ...;
  //         } catch (...) {
  //            __cxa_guard_abort (&obj_guard);
  //            throw;
  //         }
  //         ... queue object destructor with __cxa_atexit() ...;
  //         __cxa_guard_release (&obj_guard);
  //       }
  //     }
  //
  // If threadsafe statics are enabled, but we don't have inline atomics, just
  // call __cxa_guard_acquire unconditionally. The "inline" check isn't
  // actually inline, and the user might not expect calls to __atomic libcalls.
  unsigned maxInlineWidthInBits =

  if (!threadsafe || maxInlineWidthInBits) {
    // Load the first byte of the guard variable.
    auto bytePtrTy = cir::PointerType::get(builder.getSIntNTy(8));
    mlir::Value bytePtr = builder.createBitcast(guardPtr, bytePtrTy);
    mlir::Value guardLoad = builder.createAlignedLoad(
        localInitOp.getLoc(), bytePtr, guardAlignment.getAsAlign().value());

    // Itanium ABI:
    //   An implementation supporting thread-safety on multiprocessor
    //   systems must also guarantee that references to the initialized
    //   object do not occur before the load of the initialization flag.
    //
    // In LLVM, we do this by marking the load Acquire.
    if (threadsafe) {
      auto loadOp = mlir::cast<cir::LoadOp>(guardLoad.getDefiningOp());
      loadOp.setMemOrder(cir::MemOrder::Acquire);
      loadOp.setSyncScope(cir::SyncScopeKind::System);
    }

    // For ARM, we should only check the first bit, rather than the entire
    // byte:
    //
    // ARM C++ ABI 3.2.3.1:
    //   To support the potential use of initialization guard variables
    //   as semaphores that are the target of ARM SWP and LDREX/STREX
    //   synchronizing instructions we define a static initialization
    //   guard variable to be a 4-byte aligned, 4-byte word with the
    //   following inline access protocol.
    //     #define INITIALIZED 1
    //     if ((obj_guard & INITIALIZED) != INITIALIZED) {
    //       if (__cxa_guard_acquire(&obj_guard))
    //         ...
    //     }
    //
    // and similarly for ARM64:
    //
    // ARM64 C++ ABI 3.2.2:
    //   This ABI instead only specifies the value bit 0 of the static guard
    //   variable; all other bits are platform defined. Bit 0 shall be 0 when
    //   the variable is not initialized and 1 when it is.
    if (useARMGuardVarABI() && !useInt8GuardVariable) {
      auto one = builder.getConstantInt(
          localInitOp.getLoc(), mlir::cast<cir::IntType>(guardLoad.getType()),
          1);
      guardLoad = builder.createAnd(localInitOp.getLoc(), guardLoad, one);
    }

    // Check if the first byte of the guard variable is zero.
    auto zero = builder.getConstantInt(
        localInitOp.getLoc(), mlir::cast<cir::IntType>(guardLoad.getType()), 0);
    auto needsInit = builder.createCompare(localInitOp.getLoc(),
                                           cir::CmpOpKind::eq, guardLoad, zero);

    // Build the guarded initialization inside an if block.
    cir::IfOp::create(
        builder, globalOp.getLoc(), needsInit,
        /*withElseRegion=*/false, [&](mlir::OpBuilder &, mlir::Location) {
          emitCXXGuardedInitIf(builder, globalOp, localInitOp.getCtorRegion(),
                               localInitOp.getDtorRegion(), varDecl, guardPtr,
                               guardPtrTy, threadsafe);
        });
  } else {
    // Threadsafe statics without inline atomics - call __cxa_guard_acquire
    // unconditionally without the initial guard byte check.
    // NOTE(review): like the inline-variable NYI path above, this returns
    // without restoring the removed terminator — confirm intended.
    globalOp->emitError("NYI: guarded init without inline atomics support");
    return;
  }

  // Insert the removed terminator back.
  builder.getInsertionBlock()->push_back(ret);
}
1321
1322void LoweringPreparePass::lowerLocalInitOp(
1323 cir::LocalInitOp initOp, mlir::SymbolTableCollection &symbolTables) {
1324
1325 // If we don't actually need to initialize anything anymore, we're done here.
1326 if (initOp.getCtorRegion().empty() && initOp.getDtorRegion().empty()) {
1327 initOp.erase();
1328 return;
1329 }
1330
1331 cir::GlobalOp globalOp = initOp.getReferencedGlobal(symbolTables);
1332 assert(globalOp && "No global-op found");
1333
1334 handleStaticLocal(globalOp, initOp);
1335
1336 // Remove the init local op, now that we've done everything we need with it.
1337 initOp.erase();
1338}
1339
/// Lower a cir.global: if it carries ctor/dtor regions, outline them into a
/// __cxx_global_var_init function and record it as a dynamic initializer.
void LoweringPreparePass::lowerGlobalOp(GlobalOp op) {
  // Static locals are handled separately via guard variables.
  if (op.getStaticLocalGuard())
    return;

  mlir::Region &ctorRegion = op.getCtorRegion();
  mlir::Region &dtorRegion = op.getDtorRegion();

  if (!ctorRegion.empty() || !dtorRegion.empty()) {
    // Build a variable initialization function and move the initialization
    // code in the ctor region over.
    cir::FuncOp f = buildCXXGlobalVarDeclInitFunc(op);

    // Clear the ctor and dtor region
    ctorRegion.getBlocks().clear();
    dtorRegion.getBlocks().clear();

    // Remember it so buildCXXGlobalInitFunc can call it from the per-TU
    // _GLOBAL__sub_I_* function.
    dynamicInitializers.push_back(f);
  }

}
1363
1364void LoweringPreparePass::lowerThreeWayCmpOp(CmpThreeWayOp op) {
1365 CIRBaseBuilderTy builder(getContext());
1366 builder.setInsertionPointAfter(op);
1367
1368 mlir::Location loc = op->getLoc();
1369 cir::CmpThreeWayInfoAttr cmpInfo = op.getInfo();
1370
1371 mlir::Value ltRes =
1372 builder.getConstantInt(loc, op.getType(), cmpInfo.getLt());
1373 mlir::Value eqRes =
1374 builder.getConstantInt(loc, op.getType(), cmpInfo.getEq());
1375 mlir::Value gtRes =
1376 builder.getConstantInt(loc, op.getType(), cmpInfo.getGt());
1377
1378 mlir::Value transformedResult;
1379 if (cmpInfo.getOrdering() != CmpOrdering::Partial) {
1380 // Total ordering
1381 mlir::Value lt =
1382 builder.createCompare(loc, CmpOpKind::lt, op.getLhs(), op.getRhs());
1383 mlir::Value selectOnLt = builder.createSelect(loc, lt, ltRes, gtRes);
1384 mlir::Value eq =
1385 builder.createCompare(loc, CmpOpKind::eq, op.getLhs(), op.getRhs());
1386 transformedResult = builder.createSelect(loc, eq, eqRes, selectOnLt);
1387 } else {
1388 // Partial ordering
1389 cir::ConstantOp unorderedRes = builder.getConstantInt(
1390 loc, op.getType(), cmpInfo.getUnordered().value());
1391
1392 mlir::Value eq =
1393 builder.createCompare(loc, CmpOpKind::eq, op.getLhs(), op.getRhs());
1394 mlir::Value selectOnEq = builder.createSelect(loc, eq, eqRes, unorderedRes);
1395 mlir::Value gt =
1396 builder.createCompare(loc, CmpOpKind::gt, op.getLhs(), op.getRhs());
1397 mlir::Value selectOnGt = builder.createSelect(loc, gt, gtRes, selectOnEq);
1398 mlir::Value lt =
1399 builder.createCompare(loc, CmpOpKind::lt, op.getLhs(), op.getRhs());
1400 transformedResult = builder.createSelect(loc, lt, ltRes, selectOnGt);
1401 }
1402
1403 op.replaceAllUsesWith(transformedResult);
1404 op.erase();
1405}
1406
1407template <typename AttributeTy>
1408static llvm::SmallVector<mlir::Attribute>
1409prepareCtorDtorAttrList(mlir::MLIRContext *context,
1410 llvm::ArrayRef<std::pair<std::string, uint32_t>> list) {
1412 for (const auto &[name, priority] : list)
1413 attrs.push_back(AttributeTy::get(context, name, priority));
1414 return attrs;
1415}
1416
/// Attach the collected constructor/destructor lists to the module as the
/// CIR global_ctors / global_dtors array attributes (consumed during LLVM
/// lowering).
void LoweringPreparePass::buildGlobalCtorDtorList() {
  if (!globalCtorList.empty()) {
    // Convert the (name, priority) pairs into ctor attributes.
    llvm::SmallVector<mlir::Attribute> globalCtors =
            globalCtorList);

    mlirModule->setAttr(cir::CIRDialect::getGlobalCtorsAttrName(),
                        mlir::ArrayAttr::get(&getContext(), globalCtors));
  }

  if (!globalDtorList.empty()) {
    // Same conversion for the dtor side.
    llvm::SmallVector<mlir::Attribute> globalDtors =
            globalDtorList);
    mlirModule->setAttr(cir::CIRDialect::getGlobalDtorsAttrName(),
                        mlir::ArrayAttr::get(&getContext(), globalDtors));
  }
}
1435
/// Emit the per-TU global initialization function that calls every collected
/// dynamic initializer in order, and register that function in the global
/// ctor list under the default priority.
void LoweringPreparePass::buildCXXGlobalInitFunc() {
  if (dynamicInitializers.empty())
    return;

  // TODO: handle globals with a user-specified initialization priority.
  // TODO: handle default priority more nicely.

  SmallString<256> fnName;
  // Include the filename in the symbol name. Including "sub_" matches gcc
  // and makes sure these symbols appear lexicographically behind the symbols
  // with priority (TBD). Module implementation units behave the same
  // way as a non-modular TU with imports.
  // TODO: check CXX20ModuleInits
  if (astCtx->getCurrentNamedModule() &&
    // For a named module, use the ABI-mangled module initializer name.
    llvm::raw_svector_ostream out(fnName);
    std::unique_ptr<clang::MangleContext> mangleCtx(
        astCtx->createMangleContext());
    cast<clang::ItaniumMangleContext>(*mangleCtx)
        .mangleModuleInitializer(astCtx->getCurrentNamedModule(), out);
  } else {
    fnName += "_GLOBAL__sub_I_";
    fnName += getTransformedFileName(mlirModule);
  }

  CIRBaseBuilderTy builder(getContext());
  builder.setInsertionPointToEnd(&mlirModule.getBodyRegion().back());
  auto fnType = cir::FuncType::get({}, builder.getVoidTy());
  cir::FuncOp f =
      buildRuntimeFunction(builder, fnName, mlirModule.getLoc(), fnType,
                           cir::GlobalLinkageKind::ExternalLinkage);
  builder.setInsertionPointToStart(f.addEntryBlock());
  // NOTE(review): the loop variable shadows the enclosing `f`; harmless here
  // since the outer `f` is only used again after the loop, but worth renaming.
  for (cir::FuncOp &f : dynamicInitializers)
    builder.createCallOp(f.getLoc(), f, {});
  // Add the global init function (not the individual ctor functions) to the
  // global ctor list.
  globalCtorList.emplace_back(fnName,
                              cir::GlobalCtorAttr::getDefaultPriority());

  cir::ReturnOp::create(builder, f.getLoc());
}
1478
1479/// Lower a cir.array.ctor or cir.array.dtor into a do-while loop that
1480/// iterates over every element. For cir.array.ctor ops whose partial_dtor
1481/// region is non-empty, the ctor loop is wrapped in a cir.cleanup.scope whose
1482/// EH cleanup performs a reverse destruction loop using the partial dtor body.
1484 clang::ASTContext *astCtx,
1485 mlir::Operation *op, mlir::Type eltTy,
1486 mlir::Value addr,
1487 mlir::Value numElements,
1488 uint64_t arrayLen, bool isCtor) {
1489 mlir::Location loc = op->getLoc();
1490 bool isDynamic = numElements != nullptr;
1491
1492 // TODO: instead of getting the size from the AST context, create alias for
1493 // PtrDiffTy and unify with CIRGen stuff.
1494 const unsigned sizeTypeSize =
1495 astCtx->getTypeSize(astCtx->getSignedSizeType());
1496
1497 // Both constructors and destructors use end = begin + numElements.
1498 // Constructors iterate forward [begin, end). Destructors iterate backward
1499 // from end, decrementing before calling the destructor on each element.
1500 mlir::Value begin, end;
1501 if (isDynamic) {
1502 begin = addr;
1503 end = cir::PtrStrideOp::create(builder, loc, eltTy, begin, numElements);
1504 } else {
1505 mlir::Value endOffsetVal =
1506 builder.getUnsignedInt(loc, arrayLen, sizeTypeSize);
1507 begin = cir::CastOp::create(builder, loc, eltTy,
1508 cir::CastKind::array_to_ptrdecay, addr);
1509 end = cir::PtrStrideOp::create(builder, loc, eltTy, begin, endOffsetVal);
1510 }
1511
1512 mlir::Value start = isCtor ? begin : end;
1513 mlir::Value stop = isCtor ? end : begin;
1514
1515 // For dynamic destructors, guard against zero elements.
1516 // This places the destructor loop emitted below inside the if block.
1517 cir::IfOp ifOp;
1518 if (isDynamic) {
1519 mlir::Value guardCond;
1520 if (isCtor) {
1521 mlir::Value zero = builder.getUnsignedInt(loc, 0, sizeTypeSize);
1522 guardCond = cir::CmpOp::create(builder, loc, cir::CmpOpKind::ne,
1523 numElements, zero);
1524 } else {
1525 // We could check for numElements != 0 in this case too, but this matches
1526 // what classic codegen does.
1527 guardCond =
1528 cir::CmpOp::create(builder, loc, cir::CmpOpKind::ne, start, stop);
1529 }
1530 ifOp = cir::IfOp::create(builder, loc, guardCond,
1531 /*withElseRegion=*/false,
1532 [&](mlir::OpBuilder &, mlir::Location) {});
1533 builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
1534 }
1535
1536 mlir::Value tmpAddr = builder.createAlloca(
1537 loc, /*addr type*/ builder.getPointerTo(eltTy),
1538 /*var type*/ eltTy, "__array_idx", builder.getAlignmentAttr(1));
1539 builder.createStore(loc, start, tmpAddr);
1540
1541 mlir::Block *bodyBlock = &op->getRegion(0).front();
1542
1543 // Clone the region body (ctor/dtor call and any setup ops like per-element
1544 // zero-init) into the loop, remapping the block argument to the current
1545 // element pointer.
1546 auto cloneRegionBodyInto = [&](mlir::Block *srcBlock,
1547 mlir::Value replacement) {
1548 mlir::IRMapping map;
1549 map.map(srcBlock->getArgument(0), replacement);
1550 for (mlir::Operation &regionOp : *srcBlock) {
1551 if (!mlir::isa<cir::YieldOp>(&regionOp))
1552 builder.clone(regionOp, map);
1553 }
1554 };
1555
1556 mlir::Block *partialDtorBlock = nullptr;
1557 if (auto arrayCtor = mlir::dyn_cast<cir::ArrayCtor>(op)) {
1558 mlir::Region &partialDtor = arrayCtor.getPartialDtor();
1559 if (!partialDtor.empty())
1560 partialDtorBlock = &partialDtor.front();
1561 } else if (auto arrayDtor = mlir::dyn_cast<cir::ArrayDtor>(op)) {
1562 // When the element destructor may throw, reuse the body block as the
1563 // partial-dtor block so that an exception thrown by an element's dtor
1564 // continues the reverse-destruction loop in the EH cleanup region. The
1565 // body block already stores the next element pointer to `tmpAddr`
1566 // before invoking the dtor, so when an exception unwinds from the
1567 // dtor call `tmpAddr` already points at the element that threw, and
1568 // the cleanup loop picks up from `tmpAddr - 1` and walks back to
1569 // `begin`.
1570 if (arrayDtor.getDtorMayThrow())
1571 partialDtorBlock = bodyBlock;
1572 }
1573
1574 auto emitCtorDtorLoop = [&]() {
1575 builder.createDoWhile(
1576 loc,
1577 /*condBuilder=*/
1578 [&](mlir::OpBuilder &b, mlir::Location loc) {
1579 auto currentElement = cir::LoadOp::create(b, loc, eltTy, tmpAddr);
1580 auto cmp = cir::CmpOp::create(builder, loc, cir::CmpOpKind::ne,
1581 currentElement, stop);
1582 builder.createCondition(cmp);
1583 },
1584 /*bodyBuilder=*/
1585 [&](mlir::OpBuilder &b, mlir::Location loc) {
1586 auto currentElement = cir::LoadOp::create(b, loc, eltTy, tmpAddr);
1587 if (isCtor) {
1588 cloneRegionBodyInto(bodyBlock, currentElement);
1589 mlir::Value stride = builder.getUnsignedInt(loc, 1, sizeTypeSize);
1590 auto nextElement = cir::PtrStrideOp::create(builder, loc, eltTy,
1591 currentElement, stride);
1592 builder.createStore(loc, nextElement, tmpAddr);
1593 } else {
1594 mlir::Value stride = builder.getSignedInt(loc, -1, sizeTypeSize);
1595 auto prevElement = cir::PtrStrideOp::create(builder, loc, eltTy,
1596 currentElement, stride);
1597 builder.createStore(loc, prevElement, tmpAddr);
1598 cloneRegionBodyInto(bodyBlock, prevElement);
1599 }
1600
1601 cir::YieldOp::create(b, loc);
1602 });
1603 };
1604
1605 if (partialDtorBlock) {
1606 cir::CleanupScopeOp::create(
1607 builder, loc, cir::CleanupKind::EH,
1608 /*bodyBuilder=*/
1609 [&](mlir::OpBuilder &b, mlir::Location loc) {
1610 emitCtorDtorLoop();
1611 cir::YieldOp::create(b, loc);
1612 },
1613 /*cleanupBuilder=*/
1614 [&](mlir::OpBuilder &b, mlir::Location loc) {
1615 auto cur = cir::LoadOp::create(b, loc, eltTy, tmpAddr);
1616 auto cmp =
1617 cir::CmpOp::create(builder, loc, cir::CmpOpKind::ne, cur, begin);
1618 cir::IfOp::create(
1619 builder, loc, cmp, /*withElseRegion=*/false,
1620 [&](mlir::OpBuilder &b, mlir::Location loc) {
1621 builder.createDoWhile(
1622 loc,
1623 /*condBuilder=*/
1624 [&](mlir::OpBuilder &b, mlir::Location loc) {
1625 auto el = cir::LoadOp::create(b, loc, eltTy, tmpAddr);
1626 auto neq = cir::CmpOp::create(
1627 builder, loc, cir::CmpOpKind::ne, el, begin);
1628 builder.createCondition(neq);
1629 },
1630 /*bodyBuilder=*/
1631 [&](mlir::OpBuilder &b, mlir::Location loc) {
1632 auto el = cir::LoadOp::create(b, loc, eltTy, tmpAddr);
1633 mlir::Value negOne =
1634 builder.getSignedInt(loc, -1, sizeTypeSize);
1635 auto prev = cir::PtrStrideOp::create(builder, loc, eltTy,
1636 el, negOne);
1637 builder.createStore(loc, prev, tmpAddr);
1638 cloneRegionBodyInto(partialDtorBlock, prev);
1639 builder.createYield(loc);
1640 });
1641 cir::YieldOp::create(builder, loc);
1642 });
1643 cir::YieldOp::create(b, loc);
1644 });
1645 } else {
1646 emitCtorDtorLoop();
1647 }
1648
1649 if (ifOp)
1650 cir::YieldOp::create(builder, loc);
1651
1652 op->erase();
1653}
1654
1655void LoweringPreparePass::lowerArrayDtor(cir::ArrayDtor op) {
1656 CIRBaseBuilderTy builder(getContext());
1657 builder.setInsertionPointAfter(op.getOperation());
1658
1659 mlir::Type eltTy = op->getRegion(0).getArgument(0).getType();
1660
1661 if (op.getNumElements()) {
1662 lowerArrayDtorCtorIntoLoop(builder, astCtx, op, eltTy, op.getAddr(),
1663 op.getNumElements(), /*arrayLen=*/0,
1664 /*isCtor=*/false);
1665 return;
1666 }
1667
1668 auto arrayLen =
1669 mlir::cast<cir::ArrayType>(op.getAddr().getType().getPointee()).getSize();
1670 lowerArrayDtorCtorIntoLoop(builder, astCtx, op, eltTy, op.getAddr(),
1671 /*numElements=*/nullptr, arrayLen,
1672 /*isCtor=*/false);
1673}
1674
1675void LoweringPreparePass::lowerArrayCtor(cir::ArrayCtor op) {
1676 cir::CIRBaseBuilderTy builder(getContext());
1677 builder.setInsertionPointAfter(op.getOperation());
1678
1679 mlir::Type eltTy = op->getRegion(0).getArgument(0).getType();
1680
1681 if (op.getNumElements()) {
1682 lowerArrayDtorCtorIntoLoop(builder, astCtx, op, eltTy, op.getAddr(),
1683 op.getNumElements(), /*arrayLen=*/0,
1684 /*isCtor=*/true);
1685 return;
1686 }
1687
1688 auto arrayLen =
1689 mlir::cast<cir::ArrayType>(op.getAddr().getType().getPointee()).getSize();
1690 lowerArrayDtorCtorIntoLoop(builder, astCtx, op, eltTy, op.getAddr(),
1691 /*numElements=*/nullptr, arrayLen,
1692 /*isCtor=*/true);
1693}
1694
1695void LoweringPreparePass::lowerTrivialCopyCall(cir::CallOp op) {
1696 cir::FuncOp funcOp = getCalledFunction(op);
1697 if (!funcOp)
1698 return;
1699
1700 std::optional<cir::CtorKind> ctorKind = funcOp.getCxxConstructorKind();
1701 if (ctorKind && *ctorKind == cir::CtorKind::Copy &&
1702 funcOp.isCxxTrivialMemberFunction()) {
1703 // Replace the trivial copy constructor call with a `CopyOp`
1704 CIRBaseBuilderTy builder(getContext());
1705 mlir::ValueRange operands = op.getOperands();
1706 mlir::Value dest = operands[0];
1707 mlir::Value src = operands[1];
1708 builder.setInsertionPoint(op);
1709 builder.createCopy(dest, src);
1710 op.erase();
1711 }
1712}
1713
/// Return a private constant global named `baseName` (or a ".N"-suffixed
/// variant of it) whose type is `ty` and whose initializer is `constant`,
/// creating one at the top of the module if no matching global exists.
///
/// Matches are memoized per base name in `constAggregateGlobals`, so repeated
/// requests for the same (type, initializer) pair reuse a single global.
cir::GlobalOp LoweringPreparePass::getOrCreateConstAggregateGlobal(
    CIRBaseBuilderTy &builder, mlir::SymbolTableCollection &symbolTables,
    mlir::Location loc, llvm::StringRef baseName, mlir::Type ty,
    mlir::TypedAttr constant) {
  // Look up (and lazily populate) the per-base-name cache.
  llvm::SmallVector<cir::GlobalOp, 1> &versions =
      constAggregateGlobals[baseName];

  // First, check globals we've already discovered for this base name.
  for (cir::GlobalOp gv : versions) {
    if (gv.getSymType() == ty && gv.getInitialValue() == constant)
      return gv;
  }

  // No cached match. Scan the module's symbol table starting from the next
  // unscanned version. In practice this should usually exit on the first
  // iteration, but it's possible that some other pass or a previous
  // invocation of this pass created globals using this same logic.
  llvm::SmallString<128> name(baseName);
  size_t baseLen = name.size();
  unsigned version = versions.size();
  while (true) {
    // Candidate symbol is `baseName` for version 0 and `baseName.N` after.
    name.resize(baseLen);
    if (version != 0) {
      name.push_back('.');
      llvm::Twine(version).toVector(name);
    }
    auto existingGv = symbolTables.lookupSymbolIn<cir::GlobalOp>(
        mlirModule, mlir::StringAttr::get(&getContext(), name));
    if (!existingGv)
      break;
    // Cache every global we encounter so later calls skip this rescan.
    versions.push_back(existingGv);
    if (existingGv.getSymType() == ty &&
        existingGv.getInitialValue() == constant)
      return existingGv;
    ++version;
  }

  // No match found, create a new global. The loop above found an unused name.
  mlir::OpBuilder::InsertionGuard guard(builder);
  builder.setInsertionPointToStart(mlirModule.getBody());
  auto gv =
      cir::GlobalOp::create(builder, loc, name, ty,
                            /*isConstant=*/true,
                            cir::LangAddressSpaceAttr::get(
                                &getContext(), cir::LangAddressSpace::Default),
                            cir::GlobalLinkageKind::PrivateLinkage);
  mlir::SymbolTable::setSymbolVisibility(
      gv, mlir::SymbolTable::Visibility::Private);
  gv.setInitialValueAttr(constant);

  // Keep the cached symbol table in sync with the new global so subsequent
  // lookups for other base names find it.
  symbolTables.getSymbolTable(mlirModule).insert(gv);

  versions.push_back(gv);
  return gv;
}
1772
/// Replace a store of a constant aggregate (array/record) into a local
/// cir.alloca with a private constant global plus a `cir.copy` from it,
/// mirroring classic codegen's memcpy-from-global strategy.
///
/// NOTE(review): a few original lines are elided from this excerpt between
/// the "tiers" comment and the function-name lookup; comments here describe
/// only the visible code.
void LoweringPreparePass::lowerStoreOfConstAggregate(
    cir::StoreOp op, mlir::SymbolTableCollection &symbolTables) {
  // Check if the value operand is a cir.const with aggregate type.
  auto constOp = op.getValue().getDefiningOp<cir::ConstantOp>();
  if (!constOp)
    return;

  mlir::Type ty = constOp.getType();
  if (!mlir::isa<cir::ArrayType, cir::RecordType>(ty))
    return;

  // Only transform stores to local variables (backed by cir.alloca).
  // Stores to other addresses (e.g. base_class_addr) should not be
  // transformed as they may be partial initializations.
  auto alloca = op.getAddr().getDefiningOp<cir::AllocaOp>();
  if (!alloca)
    return;

  mlir::TypedAttr constant = constOp.getValue();

  // OG implements several optimization tiers for constant aggregate
  // initialization. For now we always create a global constant + memcpy
  // (shouldCreateMemCpyFromGlobal). Future work can add the intermediate
  // tiers.

  // Get function name from parent cir.func.
  auto func = op->getParentOfType<cir::FuncOp>();
  if (!func)
    return;
  llvm::StringRef funcName = func.getSymName();

  // Get variable name from the alloca.
  llvm::StringRef varName = alloca.getName();

  // Build base name: __const.<func>.<var>
  std::string baseName = ("__const." + funcName + "." + varName).str();
  CIRBaseBuilderTy builder(getContext());

  // Check for existing globals and create a new global with a unique name
  // if no match is found.
  cir::GlobalOp gv = getOrCreateConstAggregateGlobal(
      builder, symbolTables, op.getLoc(), baseName, ty, constant);

  // Now replace the store with get_global + copy.
  builder.setInsertionPoint(op);

  auto ptrTy = cir::PointerType::get(ty);
  mlir::Value globalPtr =
      cir::GetGlobalOp::create(builder, op.getLoc(), ptrTy, gv.getSymName());

  // Replace store with copy.
  builder.createCopy(op.getAddr(), globalPtr);

  // Erase the original store.
  op.erase();

  // Erase the cir.const if it has no remaining users.
  if (constOp.use_empty())
    constOp.erase();
}
1836
1837void LoweringPreparePass::runOnOp(mlir::Operation *op,
1838 mlir::SymbolTableCollection &symbolTables) {
1839 if (auto arrayCtor = dyn_cast<cir::ArrayCtor>(op)) {
1840 lowerArrayCtor(arrayCtor);
1841 } else if (auto arrayDtor = dyn_cast<cir::ArrayDtor>(op)) {
1842 lowerArrayDtor(arrayDtor);
1843 } else if (auto cast = mlir::dyn_cast<cir::CastOp>(op)) {
1844 lowerCastOp(cast);
1845 } else if (auto complexDiv = mlir::dyn_cast<cir::ComplexDivOp>(op)) {
1846 lowerComplexDivOp(complexDiv);
1847 } else if (auto complexMul = mlir::dyn_cast<cir::ComplexMulOp>(op)) {
1848 lowerComplexMulOp(complexMul);
1849 } else if (auto glob = mlir::dyn_cast<cir::GlobalOp>(op)) {
1850 lowerGlobalOp(glob);
1851 } else if (auto unaryOp = mlir::dyn_cast<cir::UnaryOpInterface>(op)) {
1852 lowerUnaryOp(unaryOp);
1853 } else if (auto callOp = dyn_cast<cir::CallOp>(op)) {
1854 lowerTrivialCopyCall(callOp);
1855 } else if (auto storeOp = dyn_cast<cir::StoreOp>(op)) {
1856 lowerStoreOfConstAggregate(storeOp, symbolTables);
1857 } else if (auto fnOp = dyn_cast<cir::FuncOp>(op)) {
1858 if (auto globalCtor = fnOp.getGlobalCtorPriority())
1859 globalCtorList.emplace_back(fnOp.getName(), globalCtor.value());
1860 else if (auto globalDtor = fnOp.getGlobalDtorPriority())
1861 globalDtorList.emplace_back(fnOp.getName(), globalDtor.value());
1862
1863 if (mlir::Attribute attr =
1864 fnOp->getAttr(cir::CUDAKernelNameAttr::getMnemonic())) {
1865 auto kernelNameAttr = dyn_cast<CUDAKernelNameAttr>(attr);
1866 llvm::StringRef kernelName = kernelNameAttr.getKernelName();
1867 cudaKernelMap[kernelName] = fnOp;
1868 }
1869 } else if (auto threeWayCmp = dyn_cast<cir::CmpThreeWayOp>(op)) {
1870 lowerThreeWayCmpOp(threeWayCmp);
1871 } else if (auto initOp = dyn_cast<cir::LocalInitOp>(op)) {
1872 lowerLocalInitOp(initOp, symbolTables);
1873 }
1874}
1875
1876static llvm::StringRef getCUDAPrefix(clang::ASTContext *astCtx) {
1877 if (astCtx->getLangOpts().HIP)
1878 return "hip";
1879 return "cuda";
1880}
1881
1882static std::string addUnderscoredPrefix(llvm::StringRef prefix,
1883 llvm::StringRef name) {
1884 return ("__" + prefix + name).str();
1885}
1886
/// Creates a global constructor function for the module:
///
/// For CUDA:
/// \code
/// void __cuda_module_ctor() {
///   Handle = __cudaRegisterFatBinary(GpuBinaryBlob);
///   __cuda_register_globals(Handle);
/// }
/// \endcode
///
/// For HIP:
/// \code
/// void __hip_module_ctor() {
///   if (__hip_gpubin_handle == 0) {
///     __hip_gpubin_handle = __hipRegisterFatBinary(GpuBinaryBlob);
///     __hip_register_globals(__hip_gpubin_handle);
///   }
/// }
/// \endcode
///
/// NOTE(review): this excerpt elides a handful of original source lines
/// (notably the initializer of `vfs` and the CUDA SDK-version gating around
/// __cudaRegisterFatBinaryEnd); comments below describe only visible code.
void LoweringPreparePass::buildCUDAModuleCtor() {
  bool isHIP = astCtx->getLangOpts().HIP;

  if (isHIP)
    if (astCtx->getLangOpts().GPURelocatableDeviceCode)
      llvm_unreachable("GPU RDC NYI");

  // For CUDA without -fgpu-rdc, it's safe to stop generating ctor
  // if there's nothing to register.
  if (cudaKernelMap.empty())
    return;

  // There's no device-side binary, so no need to proceed for CUDA.
  // HIP has to create an external symbol in this case, which is NYI.
  mlir::Attribute cudaBinaryHandleAttr =
      mlirModule->getAttr(CIRDialect::getCUDABinaryHandleAttrName());
  if (!cudaBinaryHandleAttr) {
    // NOTE(review): a line is elided here in this excerpt; as written, the
    // early return only follows the HIP check.
    if (isHIP)
      return;
  }

  llvm::StringRef cudaGPUBinaryName =
      mlir::cast<CUDABinaryHandleAttr>(cudaBinaryHandleAttr)
          .getName()
          .getValue();

  // Read the device binary through the compiler's virtual file system.
  llvm::vfs::FileSystem &vfs =
  // NOTE(review): the initializer of `vfs` is elided in this excerpt.
  llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> gpuBinaryOrErr =
      vfs.getBufferForFile(cudaGPUBinaryName);
  if (std::error_code ec = gpuBinaryOrErr.getError()) {
    mlirModule->emitError("cannot open GPU binary file: " + cudaGPUBinaryName +
                          ": " + ec.message());
    return;
  }
  std::unique_ptr<llvm::MemoryBuffer> gpuBinary =
      std::move(gpuBinaryOrErr.get());

  // Set up common types and builder.
  llvm::StringRef cudaPrefix = getCUDAPrefix(astCtx);
  mlir::Location loc = mlirModule->getLoc();
  CIRBaseBuilderTy builder(getContext());
  builder.setInsertionPointToStart(mlirModule.getBody());

  Type voidTy = builder.getVoidTy();
  PointerType voidPtrTy = builder.getVoidPtrTy();
  PointerType voidPtrPtrTy = builder.getPointerTo(voidPtrTy);
  IntType intTy = builder.getSIntNTy(32);
  IntType charTy = cir::IntType::get(&getContext(), astCtx->getCharWidth(),
                                     /*isSigned=*/false);

  // --- Create fatbin globals ---

  // The section names are different for MAC OS X.
  llvm::StringRef fatbinConstName =
      astCtx->getLangOpts().HIP ? ".hip_fatbin" : ".nv_fatbin";

  llvm::StringRef fatbinSectionName =
      astCtx->getLangOpts().HIP ? ".hipFatBinSegment" : ".nvFatBinSegment";

  // Create the fatbin string constant with GPU binary contents.
  auto fatbinType =
      ArrayType::get(&getContext(), charTy, gpuBinary->getBuffer().size());
  std::string fatbinStrName = addUnderscoredPrefix(cudaPrefix, "_fatbin_str");
  GlobalOp fatbinStr = GlobalOp::create(builder, loc, fatbinStrName, fatbinType,
                                        /*isConstant=*/true, {},
                                        GlobalLinkageKind::PrivateLinkage);
  fatbinStr.setAlignment(8);
  fatbinStr.setInitialValueAttr(cir::ConstArrayAttr::get(
      fatbinType, StringAttr::get(gpuBinary->getBuffer(), fatbinType)));
  fatbinStr.setSection(fatbinConstName);
  fatbinStr.setPrivate();

  // Create the fatbin wrapper struct:
  // struct { int magic; int version; void *fatbin; void *unused; };
  auto fatbinWrapperType = RecordType::get(
      &getContext(), {intTy, intTy, voidPtrTy, voidPtrTy},
      /*packed=*/false, /*padded=*/false, RecordType::RecordKind::Struct);
  std::string fatbinWrapperName =
      addUnderscoredPrefix(cudaPrefix, "_fatbin_wrapper");
  GlobalOp fatbinWrapper = GlobalOp::create(
      builder, loc, fatbinWrapperName, fatbinWrapperType,
      /*isConstant=*/true, {}, GlobalLinkageKind::PrivateLinkage);
  fatbinWrapper.setSection(fatbinSectionName);

  constexpr unsigned cudaFatMagic = 0x466243b1;
  constexpr unsigned hipFatMagic = 0x48495046;
  unsigned fatMagic = isHIP ? hipFatMagic : cudaFatMagic;

  auto magicInit = IntAttr::get(intTy, fatMagic);
  auto versionInit = IntAttr::get(intTy, 1);
  auto fatbinStrSymbol =
      mlir::FlatSymbolRefAttr::get(fatbinStr.getSymNameAttr());
  auto fatbinInit = GlobalViewAttr::get(voidPtrTy, fatbinStrSymbol);
  mlir::TypedAttr unusedInit = builder.getConstNullPtrAttr(voidPtrTy);
  fatbinWrapper.setInitialValueAttr(cir::ConstRecordAttr::get(
      fatbinWrapperType,
      mlir::ArrayAttr::get(&getContext(),
                           {magicInit, versionInit, fatbinInit, unusedInit})));

  // Create the GPU binary handle global variable.
  std::string gpubinHandleName =
      addUnderscoredPrefix(cudaPrefix, "_gpubin_handle");

  GlobalOp gpuBinHandle = GlobalOp::create(
      builder, loc, gpubinHandleName, voidPtrPtrTy,
      /*isConstant=*/false, {}, cir::GlobalLinkageKind::InternalLinkage);
  gpuBinHandle.setInitialValueAttr(builder.getConstNullPtrAttr(voidPtrPtrTy));
  gpuBinHandle.setPrivate();

  // Declare this function:
  // void **__{cuda|hip}RegisterFatBinary(void *);

  std::string regFuncName =
      addUnderscoredPrefix(cudaPrefix, "RegisterFatBinary");
  FuncType regFuncType = FuncType::get({voidPtrTy}, voidPtrPtrTy);
  cir::FuncOp regFunc =
      buildRuntimeFunction(builder, regFuncName, loc, regFuncType);

  std::string moduleCtorName = addUnderscoredPrefix(cudaPrefix, "_module_ctor");
  cir::FuncOp moduleCtor = buildRuntimeFunction(
      builder, moduleCtorName, loc, FuncType::get({}, voidTy),
      GlobalLinkageKind::InternalLinkage);

  // Register the ctor so buildGlobalCtorDtorList emits it as a global ctor.
  globalCtorList.emplace_back(moduleCtorName,
                              cir::GlobalCtorAttr::getDefaultPriority());
  builder.setInsertionPointToStart(moduleCtor.addEntryBlock());
  if (isHIP) {
    llvm_unreachable("HIP Module Constructor Support");
  } else if (!astCtx->getLangOpts().GPURelocatableDeviceCode) {

    // --- Create CUDA CTOR-DTOR ---
    // Register binary with CUDA runtime. This is substantially different in
    // default mode vs. separate compilation.
    // Corresponding code:
    // gpuBinaryHandle = __cudaRegisterFatBinary(&fatbinWrapper);
    mlir::Value wrapper = builder.createGetGlobal(fatbinWrapper);
    mlir::Value fatbinVoidPtr = builder.createBitcast(wrapper, voidPtrTy);
    cir::CallOp gpuBinaryHandleCall =
        builder.createCallOp(loc, regFunc, fatbinVoidPtr);
    mlir::Value gpuBinaryHandle = gpuBinaryHandleCall.getResult();
    // Store the value back to the global `__cuda_gpubin_handle`.
    mlir::Value gpuBinaryHandleGlobal = builder.createGetGlobal(gpuBinHandle);
    builder.createStore(loc, gpuBinaryHandle, gpuBinaryHandleGlobal);

    // --- Generate __cuda_register_globals and call it ---
    if (std::optional<FuncOp> regGlobal = buildCUDARegisterGlobals()) {
      builder.createCallOp(loc, *regGlobal, gpuBinaryHandle);
    }

    // From CUDA 10.1 onwards, we must call this function to end registration:
    // void __cudaRegisterFatBinaryEnd(void **fatbinHandle);
    // This is CUDA-specific, so no need to use `addUnderscoredPrefix`.
    // NOTE(review): the SDK-version feature check wrapping this block is
    // elided in this excerpt.
    astCtx->getTargetInfo().getSDKVersion(),
    cir::CIRBaseBuilderTy globalBuilder(getContext());
    globalBuilder.setInsertionPointToStart(mlirModule.getBody());
    FuncOp endFunc =
        buildRuntimeFunction(globalBuilder, "__cudaRegisterFatBinaryEnd", loc,
                             FuncType::get({voidPtrPtrTy}, voidTy));
    builder.createCallOp(loc, endFunc, gpuBinaryHandle);
    }
  } else
    llvm_unreachable("GPU RDC NYI");

  // Create destructor and register it with atexit() the way NVCC does it. Doing
  // it during regular destructor phase worked in CUDA before 9.2 but results in
  // double-free in 9.2.
  if (std::optional<FuncOp> dtor = buildCUDAModuleDtor()) {

    // extern "C" int atexit(void (*f)(void));
    cir::CIRBaseBuilderTy globalBuilder(getContext());
    globalBuilder.setInsertionPointToStart(mlirModule.getBody());
    FuncOp atexit = buildRuntimeFunction(
        globalBuilder, "atexit", loc,
        FuncType::get(PointerType::get(dtor->getFunctionType()), intTy));
    mlir::Value dtorFunc = GetGlobalOp::create(
        builder, loc, PointerType::get(dtor->getFunctionType()),
        mlir::FlatSymbolRefAttr::get(dtor->getSymNameAttr()));
    builder.createCallOp(loc, atexit, dtorFunc);
  }
  cir::ReturnOp::create(builder, loc);
}
2093
2094std::optional<FuncOp> LoweringPreparePass::buildCUDAModuleDtor() {
2095 if (!mlirModule->getAttr(CIRDialect::getCUDABinaryHandleAttrName()))
2096 return {};
2097
2098 llvm::StringRef prefix = getCUDAPrefix(astCtx);
2099
2100 VoidType voidTy = VoidType::get(&getContext());
2101 PointerType voidPtrPtrTy = PointerType::get(PointerType::get(voidTy));
2102
2103 mlir::Location loc = mlirModule.getLoc();
2104
2105 cir::CIRBaseBuilderTy builder(getContext());
2106 builder.setInsertionPointToStart(mlirModule.getBody());
2107
2108 // define: void __cudaUnregisterFatBinary(void ** handle);
2109 std::string unregisterFuncName =
2110 addUnderscoredPrefix(prefix, "UnregisterFatBinary");
2111 FuncOp unregisterFunc = buildRuntimeFunction(
2112 builder, unregisterFuncName, loc, FuncType::get({voidPtrPtrTy}, voidTy));
2113
2114 // void __cuda_module_dtor();
2115 // Despite the name, OG doesn't treat it as a destructor, so it shouldn't be
2116 // put into globalDtorList. If it were a real dtor, then it would cause
2117 // double free above CUDA 9.2. The way to use it is to manually call
2118 // atexit() at end of module ctor.
2119 std::string dtorName = addUnderscoredPrefix(prefix, "_module_dtor");
2120 FuncOp dtor =
2121 buildRuntimeFunction(builder, dtorName, loc, FuncType::get({}, voidTy),
2122 GlobalLinkageKind::InternalLinkage);
2123
2124 builder.setInsertionPointToStart(dtor.addEntryBlock());
2125
2126 // For dtor, we only need to call:
2127 // __cudaUnregisterFatBinary(__cuda_gpubin_handle);
2128
2129 std::string gpubinName = addUnderscoredPrefix(prefix, "_gpubin_handle");
2130 GlobalOp gpubinGlobal = cast<GlobalOp>(mlirModule.lookupSymbol(gpubinName));
2131 mlir::Value gpubinAddress = builder.createGetGlobal(gpubinGlobal);
2132 mlir::Value gpubin = builder.createLoad(loc, gpubinAddress);
2133 builder.createCallOp(loc, unregisterFunc, gpubin);
2134 ReturnOp::create(builder, loc);
2135
2136 return dtor;
2137}
2138
2139std::optional<FuncOp> LoweringPreparePass::buildCUDARegisterGlobals() {
2140 // There is nothing to register.
2141 if (cudaKernelMap.empty())
2142 return {};
2143
2144 cir::CIRBaseBuilderTy builder(getContext());
2145 builder.setInsertionPointToStart(mlirModule.getBody());
2146
2147 mlir::Location loc = mlirModule.getLoc();
2148 llvm::StringRef cudaPrefix = getCUDAPrefix(astCtx);
2149
2150 auto voidTy = VoidType::get(&getContext());
2151 auto voidPtrTy = PointerType::get(voidTy);
2152 auto voidPtrPtrTy = PointerType::get(voidPtrTy);
2153
2154 // Create the function:
2155 // void __cuda_register_globals(void **fatbinHandle)
2156 std::string regGlobalFuncName =
2157 addUnderscoredPrefix(cudaPrefix, "_register_globals");
2158 auto regGlobalFuncTy = FuncType::get({voidPtrPtrTy}, voidTy);
2159 FuncOp regGlobalFunc =
2160 buildRuntimeFunction(builder, regGlobalFuncName, loc, regGlobalFuncTy,
2161 /*linkage=*/GlobalLinkageKind::InternalLinkage);
2162 builder.setInsertionPointToStart(regGlobalFunc.addEntryBlock());
2163
2164 buildCUDARegisterGlobalFunctions(builder, regGlobalFunc);
2165 // TODO: Handle shadow registration
2167
2168 ReturnOp::create(builder, loc);
2169 return regGlobalFunc;
2170}
2171
2172void LoweringPreparePass::buildCUDARegisterGlobalFunctions(
2173 cir::CIRBaseBuilderTy &builder, FuncOp regGlobalFunc) {
2174 mlir::Location loc = mlirModule.getLoc();
2175 llvm::StringRef cudaPrefix = getCUDAPrefix(astCtx);
2176 cir::CIRDataLayout dataLayout(mlirModule);
2177
2178 auto voidTy = VoidType::get(&getContext());
2179 auto voidPtrTy = PointerType::get(voidTy);
2180 auto voidPtrPtrTy = PointerType::get(voidPtrTy);
2181 IntType intTy = builder.getSIntNTy(32);
2182 IntType charTy = cir::IntType::get(&getContext(), astCtx->getCharWidth(),
2183 /*isSigned=*/false);
2184
2185 // Extract the GPU binary handle argument.
2186 mlir::Value fatbinHandle = *regGlobalFunc.args_begin();
2187
2188 cir::CIRBaseBuilderTy globalBuilder(getContext());
2189 globalBuilder.setInsertionPointToStart(mlirModule.getBody());
2190
2191 // Declare CUDA internal functions:
2192 // int __cudaRegisterFunction(
2193 // void **fatbinHandle,
2194 // const char *hostFunc,
2195 // char *deviceFunc,
2196 // const char *deviceName,
2197 // int threadLimit,
2198 // uint3 *tid, uint3 *bid, dim3 *bDim, dim3 *gDim,
2199 // int *wsize
2200 // )
2201 // OG doesn't care about the types at all. They're treated as void*.
2202
2203 FuncOp cudaRegisterFunction = buildRuntimeFunction(
2204 globalBuilder, addUnderscoredPrefix(cudaPrefix, "RegisterFunction"), loc,
2205 FuncType::get({voidPtrPtrTy, voidPtrTy, voidPtrTy, voidPtrTy, intTy,
2206 voidPtrTy, voidPtrTy, voidPtrTy, voidPtrTy, voidPtrTy},
2207 intTy));
2208
2209 auto makeConstantString = [&](llvm::StringRef str) -> GlobalOp {
2210 auto strType = ArrayType::get(&getContext(), charTy, 1 + str.size());
2211 auto tmpString = cir::GlobalOp::create(
2212 globalBuilder, loc, (".str" + str).str(), strType,
2213 /*isConstant=*/true, {},
2214 /*linkage=*/cir::GlobalLinkageKind::PrivateLinkage);
2215
2216 // We must make the string zero-terminated.
2217 tmpString.setInitialValueAttr(
2218 ConstArrayAttr::get(strType, StringAttr::get(str + "\0", strType)));
2219 tmpString.setPrivate();
2220 return tmpString;
2221 };
2222
2223 cir::ConstantOp cirNullPtr = builder.getNullPtr(voidPtrTy, loc);
2224 bool isHIP = astCtx->getLangOpts().HIP;
2225 for (auto kernelName : cudaKernelMap.keys()) {
2226 FuncOp deviceStub = cudaKernelMap[kernelName];
2227 GlobalOp deviceFuncStr = makeConstantString(kernelName);
2228 mlir::Value deviceFunc = builder.createBitcast(
2229 builder.createGetGlobal(deviceFuncStr), voidPtrTy);
2230
2231 if (isHIP) {
2232 llvm_unreachable("HIP kernel registration NYI");
2233 } else {
2234 mlir::Value hostFunc = builder.createBitcast(
2235 GetGlobalOp::create(
2236 builder, loc, PointerType::get(deviceStub.getFunctionType()),
2237 mlir::FlatSymbolRefAttr::get(deviceStub.getSymNameAttr())),
2238 voidPtrTy);
2239 builder.createCallOp(
2240 loc, cudaRegisterFunction,
2241 {fatbinHandle, hostFunc, deviceFunc, deviceFunc,
2242 ConstantOp::create(builder, loc, IntAttr::get(intTy, -1)),
2243 cirNullPtr, cirNullPtr, cirNullPtr, cirNullPtr, cirNullPtr});
2244 }
2245 }
2246}
2247
2248void LoweringPreparePass::runOnOperation() {
2249 mlir::Operation *op = getOperation();
2250 if (isa<::mlir::ModuleOp>(op))
2251 mlirModule = cast<::mlir::ModuleOp>(op);
2252
2253 llvm::SmallVector<mlir::Operation *> opsToTransform;
2254 mlir::SymbolTableCollection symbolTables;
2255
2256 op->walk([&](mlir::Operation *op) {
2257 if (mlir::isa<cir::ArrayCtor, cir::ArrayDtor, cir::CastOp,
2258 cir::ComplexMulOp, cir::ComplexDivOp, cir::DynamicCastOp,
2259 cir::FuncOp, cir::CallOp, cir::GlobalOp, cir::StoreOp,
2260 cir::CmpThreeWayOp, cir::IncOp, cir::DecOp, cir::MinusOp,
2261 cir::NotOp, cir::LocalInitOp>(op))
2262 opsToTransform.push_back(op);
2263 });
2264
2265 for (mlir::Operation *o : opsToTransform)
2266 runOnOp(o, symbolTables);
2267
2268 buildCXXGlobalInitFunc();
2269 if (astCtx->getLangOpts().CUDA && !astCtx->getLangOpts().CUDAIsDevice)
2270 buildCUDAModuleCtor();
2271
2272 buildGlobalCtorDtorList();
2273}
2274
2275std::unique_ptr<Pass> mlir::createLoweringPreparePass() {
2276 return std::make_unique<LoweringPreparePass>();
2277}
2278
// Factory overload that attaches the clang AST context to the pass before
// returning it.
// NOTE(review): the signature line of this overload (taking
// clang::ASTContext *astCtx) is elided from this excerpt.
std::unique_ptr<Pass>
  auto pass = std::make_unique<LoweringPreparePass>();
  pass->setASTContext(astCtx);
  return std::move(pass);
}
Defines the clang::ASTContext interface.
static void emitBody(CodeGenFunction &CGF, const Stmt *S, const Stmt *NextLoop, int MaxLevel, int Level=0)
static llvm::FunctionCallee getGuardReleaseFn(CodeGenModule &CGM, llvm::PointerType *GuardPtrTy)
static llvm::FunctionCallee getGuardAcquireFn(CodeGenModule &CGM, llvm::PointerType *GuardPtrTy)
static mlir::Value buildRangeReductionComplexDiv(CIRBaseBuilderTy &builder, mlir::Location loc, mlir::Value lhsReal, mlir::Value lhsImag, mlir::Value rhsReal, mlir::Value rhsImag)
static llvm::StringRef getComplexDivLibCallName(llvm::APFloat::Semantics semantics)
static llvm::SmallVector< mlir::Attribute > prepareCtorDtorAttrList(mlir::MLIRContext *context, llvm::ArrayRef< std::pair< std::string, uint32_t > > list)
static llvm::StringRef getComplexMulLibCallName(llvm::APFloat::Semantics semantics)
static mlir::Value buildComplexBinOpLibCall(LoweringPreparePass &pass, CIRBaseBuilderTy &builder, llvm::StringRef(*libFuncNameGetter)(llvm::APFloat::Semantics), mlir::Location loc, cir::ComplexType ty, mlir::Value lhsReal, mlir::Value lhsImag, mlir::Value rhsReal, mlir::Value rhsImag)
static mlir::Value lowerComplexMul(LoweringPreparePass &pass, CIRBaseBuilderTy &builder, mlir::Location loc, cir::ComplexMulOp op, mlir::Value lhsReal, mlir::Value lhsImag, mlir::Value rhsReal, mlir::Value rhsImag)
static std::string addUnderscoredPrefix(llvm::StringRef prefix, llvm::StringRef name)
static SmallString< 128 > getTransformedFileName(mlir::ModuleOp mlirModule)
static mlir::Value lowerComplexToComplexCast(mlir::MLIRContext &ctx, cir::CastOp op, cir::CastKind scalarCastKind)
static void lowerArrayDtorCtorIntoLoop(cir::CIRBaseBuilderTy &builder, clang::ASTContext *astCtx, mlir::Operation *op, mlir::Type eltTy, mlir::Value addr, mlir::Value numElements, uint64_t arrayLen, bool isCtor)
Lower a cir.array.ctor or cir.array.dtor into a do-while loop that iterates over every element.
static mlir::Value lowerComplexToScalarCast(mlir::MLIRContext &ctx, cir::CastOp op, cir::CastKind elemToBoolKind)
static mlir::Value buildAlgebraicComplexDiv(CIRBaseBuilderTy &builder, mlir::Location loc, mlir::Value lhsReal, mlir::Value lhsImag, mlir::Value rhsReal, mlir::Value rhsImag)
static llvm::StringRef getCUDAPrefix(clang::ASTContext *astCtx)
static cir::FuncOp getCalledFunction(cir::CallOp callOp)
Return the FuncOp called by callOp.
static mlir::Type higherPrecisionElementTypeForComplexArithmetic(mlir::MLIRContext &context, clang::ASTContext &cc, CIRBaseBuilderTy &builder, mlir::Type elementType)
static mlir::Value lowerScalarToComplexCast(mlir::MLIRContext &ctx, cir::CastOp op)
static mlir::Value lowerComplexDiv(LoweringPreparePass &pass, CIRBaseBuilderTy &builder, mlir::Location loc, cir::ComplexDivOp op, mlir::Value lhsReal, mlir::Value lhsImag, mlir::Value rhsReal, mlir::Value rhsImag, mlir::MLIRContext &mlirCx, clang::ASTContext &cc)
Defines the clang::Module class, which describes a module in the source code.
Defines the SourceManager interface.
Defines various enumerations that describe declaration and type specifiers.
Defines the TargetCXXABI class, which abstracts details of the C++ ABI that we're targeting.
__device__ __2f16 b
__device__ __2f16 float c
mlir::Value createDiv(mlir::Location loc, mlir::Value lhs, mlir::Value rhs)
mlir::TypedAttr getConstNullPtrAttr(mlir::Type t)
mlir::Value createDec(mlir::Location loc, mlir::Value input, bool nsw=false)
mlir::Value createLogicalOr(mlir::Location loc, mlir::Value lhs, mlir::Value rhs)
mlir::Value createSub(mlir::Location loc, mlir::Value lhs, mlir::Value rhs, OverflowBehavior ob=OverflowBehavior::None)
cir::ConditionOp createCondition(mlir::Value condition)
Create a loop condition.
mlir::Value createInc(mlir::Location loc, mlir::Value input, bool nsw=false)
cir::CopyOp createCopy(mlir::Value dst, mlir::Value src, bool isVolatile=false, bool skipTailPadding=false)
Create a copy with inferred length.
cir::VoidType getVoidTy()
cir::ConstantOp getNullValue(mlir::Type ty, mlir::Location loc)
mlir::Value createCast(mlir::Location loc, cir::CastKind kind, mlir::Value src, mlir::Type newTy)
cir::PointerType getVoidFnPtrTy(mlir::TypeRange argTypes={})
Returns void (*)(T...) as a cir::PointerType.
mlir::Value createAdd(mlir::Location loc, mlir::Value lhs, mlir::Value rhs, OverflowBehavior ob=OverflowBehavior::None)
cir::PointerType getPointerTo(mlir::Type ty)
mlir::Value createComplexImag(mlir::Location loc, mlir::Value operand)
cir::ConstantOp getNullPtr(mlir::Type ty, mlir::Location loc)
cir::DoWhileOp createDoWhile(mlir::Location loc, llvm::function_ref< void(mlir::OpBuilder &, mlir::Location)> condBuilder, llvm::function_ref< void(mlir::OpBuilder &, mlir::Location)> bodyBuilder)
Create a do-while operation.
cir::GetGlobalOp createGetGlobal(mlir::Location loc, cir::GlobalOp global, bool threadLocal=false)
mlir::Value getSignedInt(mlir::Location loc, int64_t val, unsigned numBits)
mlir::Value createAnd(mlir::Location loc, mlir::Value lhs, mlir::Value rhs)
mlir::Value createBitcast(mlir::Value src, mlir::Type newTy)
cir::FuncType getVoidFnTy(mlir::TypeRange argTypes={})
Returns void (T...) as a cir::FuncType.
cir::CmpOp createCompare(mlir::Location loc, cir::CmpOpKind kind, mlir::Value lhs, mlir::Value rhs)
mlir::IntegerAttr getAlignmentAttr(clang::CharUnits alignment)
mlir::Value createSelect(mlir::Location loc, mlir::Value condition, mlir::Value trueValue, mlir::Value falseValue)
mlir::Value createMul(mlir::Location loc, mlir::Value lhs, mlir::Value rhs, OverflowBehavior ob=OverflowBehavior::None)
cir::LoadOp createLoad(mlir::Location loc, mlir::Value ptr, bool isVolatile=false, uint64_t alignment=0)
mlir::Value createMinus(mlir::Location loc, mlir::Value input, bool nsw=false)
cir::ConstantOp getConstantInt(mlir::Location loc, mlir::Type ty, int64_t value)
mlir::Value createComplexCreate(mlir::Location loc, mlir::Value real, mlir::Value imag)
cir::PointerType getVoidPtrTy(clang::LangAS langAS=clang::LangAS::Default)
mlir::Value createIsNaN(mlir::Location loc, mlir::Value operand)
cir::IntType getSIntNTy(int n)
mlir::Value createAlignedLoad(mlir::Location loc, mlir::Value ptr, uint64_t alignment)
cir::CallOp createCallOp(mlir::Location loc, mlir::SymbolRefAttr callee, mlir::Type returnType, mlir::ValueRange operands, llvm::ArrayRef< mlir::NamedAttribute > attrs={}, llvm::ArrayRef< mlir::NamedAttrList > argAttrs={}, llvm::ArrayRef< mlir::NamedAttribute > resAttrs={})
cir::StoreOp createStore(mlir::Location loc, mlir::Value val, mlir::Value dst, bool isVolatile=false, mlir::IntegerAttr align={}, cir::SyncScopeKindAttr scope={}, cir::MemOrderAttr order={})
cir::YieldOp createYield(mlir::Location loc, mlir::ValueRange value={})
Create a yield operation.
mlir::Value createLogicalAnd(mlir::Location loc, mlir::Value lhs, mlir::Value rhs)
mlir::Value createAlloca(mlir::Location loc, cir::PointerType addrType, mlir::Type type, llvm::StringRef name, mlir::IntegerAttr alignment, mlir::Value dynAllocSize)
cir::BoolType getBoolTy()
mlir::Value getUnsignedInt(mlir::Location loc, uint64_t val, unsigned numBits)
mlir::Value createComplexReal(mlir::Location loc, mlir::Value operand)
Holds long-lived AST nodes (such as types and decls) that can be referred to throughout the semantic ...
Definition ASTContext.h:227
SourceManager & getSourceManager()
Definition ASTContext.h:866
MangleContext * createMangleContext(const TargetInfo *T=nullptr)
If T is null pointer, assume the target in ASTContext.
const LangOptions & getLangOpts() const
Definition ASTContext.h:959
uint64_t getTypeSize(QualType T) const
Return the size of the specified (complete) type T, in bits.
const TargetInfo & getTargetInfo() const
Definition ASTContext.h:924
QualType getSignedSizeType() const
Return the unique signed counterpart of the integer type corresponding to size_t.
Module * getCurrentNamedModule() const
Get module under construction, nullptr if this is not a C++20 module.
uint64_t getCharWidth() const
Return the size of the character type, in bits.
llvm::Align getAsAlign() const
getAsAlign - Returns Quantity as a valid llvm::Align, Beware llvm::Align assumes power of two 8-bit bytes.
Definition CharUnits.h:189
QuantityType getQuantity() const
getQuantity - Get the raw integer representation of this quantity.
Definition CharUnits.h:185
static CharUnits One()
One - Construct a CharUnits quantity of one.
Definition CharUnits.h:58
static CharUnits fromQuantity(QuantityType Quantity)
fromQuantity - Construct a CharUnits quantity from a raw integer type.
Definition CharUnits.h:63
llvm::vfs::FileSystem & getVirtualFileSystem() const
bool isModuleImplementation() const
Is this a module implementation.
Definition Module.h:843
FileManager & getFileManager() const
Exposes information about the current target.
Definition TargetInfo.h:227
unsigned getMaxAtomicInlineWidth() const
Return the maximum width lock-free atomic operation which can be inlined given the supported features...
Definition TargetInfo.h:859
const llvm::fltSemantics & getDoubleFormat() const
Definition TargetInfo.h:804
const llvm::fltSemantics & getHalfFormat() const
Definition TargetInfo.h:789
const llvm::fltSemantics & getBFloat16Format() const
Definition TargetInfo.h:799
const llvm::fltSemantics & getLongDoubleFormat() const
Definition TargetInfo.h:810
const llvm::fltSemantics & getFloatFormat() const
Definition TargetInfo.h:794
const llvm::fltSemantics & getFloat128Format() const
Definition TargetInfo.h:818
const llvm::VersionTuple & getSDKVersion() const
Defines the clang::TargetInfo interface.
const internal::VariadicDynCastAllOfMatcher< Decl, VarDecl > varDecl
Matches variable declarations.
bool isHIP(ID Id)
isHIP - Is this a HIP input.
Definition Types.cpp:291
RangeSelector name(std::string ID)
Given a node with a "name", (like NamedDecl, DeclRefExpr, CxxCtorInitializer, and TypeLoc) selects th...
bool isTemplateInstantiation(TemplateSpecializationKind Kind)
Determine whether this template specialization kind refers to an instantiation of an entity (as oppos...
Definition Specifiers.h:213
bool CudaFeatureEnabled(llvm::VersionTuple, CudaFeature)
Definition Cuda.cpp:163
LLVM_READONLY bool isPreprocessingNumberBody(unsigned char c)
Return true if this is the body character of a C preprocessing number, which is [a-zA-Z0-9_.].
Definition CharInfo.h:168
@ CUDA_USES_FATBIN_REGISTER_END
Definition Cuda.h:80
unsigned int uint32_t
std::unique_ptr< Pass > createLoweringPreparePass()
static bool opGlobalThreadLocal()
static bool hipModuleCtor()
static bool guardAbortOnException()
static bool opGlobalAnnotations()
static bool opGlobalCtorPriority()
static bool shouldSplitConstantStore()
static bool shouldUseMemSetToInitialize()
static bool opFuncExtraAttrs()
static bool shouldUseBZeroPlusStoresToInitialize()
static bool globalRegistration()
static bool fastMathFlags()
static bool astVarDeclInterface()