clang 23.0.0git
LoweringPrepare.cpp
Go to the documentation of this file.
//===- LoweringPrepare.cpp - preparation work for LLVM lowering -----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
8
9#include "PassDetail.h"
10#include "mlir/IR/Attributes.h"
11#include "mlir/IR/BuiltinAttributeInterfaces.h"
12#include "mlir/IR/IRMapping.h"
13#include "mlir/IR/Location.h"
14#include "mlir/IR/Value.h"
16#include "clang/AST/Mangle.h"
17#include "clang/Basic/Cuda.h"
18#include "clang/Basic/Module.h"
32#include "llvm/ADT/StringRef.h"
33#include "llvm/ADT/TypeSwitch.h"
34#include "llvm/IR/Instructions.h"
35#include "llvm/Support/ErrorHandling.h"
36#include "llvm/Support/MemoryBuffer.h"
37#include "llvm/Support/Path.h"
38#include "llvm/Support/VirtualFileSystem.h"
39
40#include <memory>
41#include <optional>
42
43using namespace mlir;
44using namespace cir;
45
46namespace mlir {
47#define GEN_PASS_DEF_LOWERINGPREPARE
48#include "clang/CIR/Dialect/Passes.h.inc"
49} // namespace mlir
50
51static SmallString<128> getTransformedFileName(mlir::ModuleOp mlirModule) {
52 SmallString<128> fileName;
53
54 if (mlirModule.getSymName())
55 fileName = llvm::sys::path::filename(mlirModule.getSymName()->str());
56
57 if (fileName.empty())
58 fileName = "<null>";
59
60 for (size_t i = 0; i < fileName.size(); ++i) {
61 // Replace everything that's not [a-zA-Z0-9._] with a _. This set happens
62 // to be the set of C preprocessing numbers.
63 if (!clang::isPreprocessingNumberBody(fileName[i]))
64 fileName[i] = '_';
65 }
66
67 return fileName;
68}
69
70namespace {
71struct LoweringPreparePass
72 : public impl::LoweringPrepareBase<LoweringPreparePass> {
73 LoweringPreparePass() = default;
74
75 // `mlir::SymbolTableCollection` is move-only (it owns lazily-created
76 // `unique_ptr<SymbolTable>` entries), which makes the implicit copy
77 // constructor ill-formed. MLIR's `clonePass()` requires copy
78 // construction, so define one explicitly. Per-run state members
79 // (dynamic initializers, guard maps, symbol-table cache, etc.) all
80 // start fresh in the cloned pass, which matches MLIR convention for
81 // pass clones and is more correct than the previous default-generated
82 // behavior that silently copied them.
83 LoweringPreparePass(const LoweringPreparePass &other)
84 : impl::LoweringPrepareBase<LoweringPreparePass>(other) {}
85
86 void runOnOperation() override;
87
88 void runOnOp(mlir::Operation *op);
89 void lowerCastOp(cir::CastOp op);
90 void lowerComplexConjOp(cir::ComplexConjOp op);
91 void lowerComplexDivOp(cir::ComplexDivOp op);
92 void lowerComplexMulOp(cir::ComplexMulOp op);
93 void lowerGetGlobalOp(cir::GetGlobalOp op);
94 void lowerGlobalOp(cir::GlobalOp op);
95 void lowerThreeWayCmpOp(cir::CmpThreeWayOp op);
96 void lowerArrayDtor(cir::ArrayDtor op);
97 void lowerArrayCtor(cir::ArrayCtor op);
98 void lowerTrivialCopyCall(cir::CallOp op);
99 void lowerStoreOfConstAggregate(cir::StoreOp op);
100 void lowerLocalInitOp(cir::LocalInitOp op);
101
102 /// Return the FuncOp called by `callOp`. Uses the cached `symbolTables`
103 /// member to avoid the O(M) module-wide scan that the static
104 /// `mlir::SymbolTable::lookupNearestSymbolFrom` would do per call.
105 cir::FuncOp getCalledFunction(cir::CallOp callOp);
106
107 /// Return a private constant cir::GlobalOp with the given type and initial
108 /// value, suitable for backing a memcpy-initialized local aggregate.
109 ///
110 /// If a global with `baseName` (or one of its `.<n>` versioned siblings)
111 /// already has a matching type and initial value, that global is reused.
112 /// Otherwise a new global is created with the next available `.<n>` suffix
113 /// (matching CIRGenBuilder::createVersionedGlobal and OGCG behavior).
114 cir::GlobalOp getOrCreateConstAggregateGlobal(CIRBaseBuilderTy &builder,
115 mlir::Location loc,
116 llvm::StringRef baseName,
117 mlir::Type ty,
118 mlir::TypedAttr constant);
119
120 /// Build the function that initializes the specified global
121 cir::FuncOp buildCXXGlobalVarDeclInitFunc(cir::GlobalOp op);
122
123 /// When looking at the 'global' op, create the wrapper function.
124 void defineGlobalThreadLocalWrapper(cir::GlobalOp op, cir::FuncOp initAlias,
125 bool isVarDefinition);
126 /// Create an initialization alias for a thread-local variable.
127 cir::FuncOp defineGlobalThreadLocalInitAlias(cir::GlobalOp op,
128 cir::FuncOp aliasee);
129 /// Get the declaration for the 'wrapper' function for a global-TLS variable.
130 cir::FuncOp getOrCreateThreadLocalWrapper(CIRBaseBuilderTy &builder,
131 cir::GlobalOp op);
132 // Function that generates the guard global variable, get-global, and 'if'
133 // condition for global TLS init function generation. This inserts an 'if'
134 // with the store at the beginning of the 'then' region, so inserts into the
135 // body should happen after that.
136 cir::IfOp buildGlobalTlsGuardCheck(CIRBaseBuilderTy &builder,
137 mlir::Location loc, cir::GlobalOp guard);
138 /// Handle the dtor region by registering destructor with __cxa_atexit
139 cir::FuncOp getOrCreateDtorFunc(CIRBaseBuilderTy &builder, cir::GlobalOp op,
140 mlir::Region &dtorRegion,
141 cir::CallOp &dtorCall);
142
143 /// Build a module init function that calls all the dynamic initializers.
144 void buildCXXGlobalInitFunc();
145 // Build an init function for all of the ordered global thread local storage
146 // variables.
147 void buildCXXGlobalTlsFunc();
148
149 /// Materialize global ctor/dtor list
150 void buildGlobalCtorDtorList();
151
152 cir::FuncOp buildRuntimeFunction(
153 mlir::OpBuilder &builder, llvm::StringRef name, mlir::Location loc,
154 cir::FuncType type,
155 cir::GlobalLinkageKind linkage = cir::GlobalLinkageKind::ExternalLinkage);
156
157 cir::GlobalOp getOrCreateRuntimeVariable(
158 mlir::OpBuilder &builder, llvm::StringRef name, mlir::Location loc,
159 mlir::Type type,
160 cir::GlobalLinkageKind linkage = cir::GlobalLinkageKind::ExternalLinkage,
161 cir::VisibilityKind visibility = cir::VisibilityKind::Default);
162
163 /// ------------
164 /// CUDA registration related
165 /// ------------
166
167 llvm::StringMap<FuncOp> cudaKernelMap;
168 llvm::SmallVector<std::pair<cir::GlobalOp, cir::CUDAVarRegistrationInfoAttr>>
169 cudaDeviceVars;
170
171 /// Build the CUDA module constructor that registers the fat binary
172 /// with the CUDA runtime.
173 void buildCUDAModuleCtor();
174 std::optional<FuncOp> buildCUDAModuleDtor();
175 std::optional<FuncOp> buildHIPModuleDtor();
176 std::optional<FuncOp> buildCUDARegisterGlobals();
177 void buildCUDARegisterVars(cir::CIRBaseBuilderTy &builder,
178 FuncOp regGlobalFunc);
179 void buildCUDARegisterGlobalFunctions(cir::CIRBaseBuilderTy &builder,
180 FuncOp regGlobalFunc);
181
182 /// Handle static local variable initialization with guard variables.
183 void handleStaticLocal(cir::GlobalOp globalOp, cir::LocalInitOp localInitOp);
184
185 /// Get or create __cxa_guard_acquire function.
186 cir::FuncOp getGuardAcquireFn(cir::PointerType guardPtrTy);
187
188 /// Get or create __cxa_guard_release function.
189 cir::FuncOp getGuardReleaseFn(cir::PointerType guardPtrTy);
190
191 /// Get or create the __init_tls function.
192 cir::FuncOp getTlsInitFn();
193
194 // Create the __tls_guard variable.
195 cir::GlobalOp createGlobalThreadLocalGuard(CIRBaseBuilderTy &builder,
196 mlir::Location loc);
197
198 /// Create a guard global variable for a static local.
199 cir::GlobalOp createGuardGlobalOp(CIRBaseBuilderTy &builder,
200 mlir::Location loc, llvm::StringRef name,
201 cir::IntType guardTy,
202 cir::GlobalLinkageKind linkage);
203
204 /// Get the guard variable for a static local declaration.
205 cir::GlobalOp getStaticLocalDeclGuardAddress(llvm::StringRef globalSymName) {
206 auto it = staticLocalDeclGuardMap.find(globalSymName);
207 if (it != staticLocalDeclGuardMap.end())
208 return it->second;
209 return nullptr;
210 }
211
212 /// Set the guard variable for a static local declaration.
213 void setStaticLocalDeclGuardAddress(llvm::StringRef globalSymName,
214 cir::GlobalOp guard) {
215 staticLocalDeclGuardMap[globalSymName] = guard;
216 }
217
218 /// Get or create the guard variable for a static local declaration.
219 cir::GlobalOp getOrCreateStaticLocalDeclGuardAddress(
220 CIRBaseBuilderTy &builder, cir::GlobalOp globalOp, StringRef guardName,
221 bool isLocalVarDecl, bool useInt8GuardVariable) {
222
223 cir::CIRDataLayout dataLayout(mlirModule);
224 cir::IntType guardTy;
225 clang::CharUnits guardAlignment;
226 // Guard variables are 64 bits in the generic ABI and size width on ARM
227 // (i.e. 32-bit on AArch32, 64-bit on AArch64).
228 if (useInt8GuardVariable) {
229 guardTy = cir::IntType::get(&getContext(), 8, /*isSigned=*/true);
230 guardAlignment = clang::CharUnits::One();
231 } else if (useARMGuardVarABI()) {
232 // Guard variables are size width on ARM (32-bit AArch32, 64-bit AArch64).
233 const unsigned sizeTypeSize =
234 astCtx->getTypeSize(astCtx->getSignedSizeType());
235 guardTy =
236 cir::IntType::get(&getContext(), sizeTypeSize, /*isSigned=*/true);
237 guardAlignment =
238 clang::CharUnits::fromQuantity(dataLayout.getABITypeAlign(guardTy));
239 } else {
240 guardTy = cir::IntType::get(&getContext(), 64, /*isSigned=*/true);
241 guardAlignment =
242 clang::CharUnits::fromQuantity(dataLayout.getABITypeAlign(guardTy));
243 }
244 assert(guardTy && guardAlignment.getQuantity() != 0);
245
246 llvm::StringRef globalSymName = globalOp.getSymName();
247 cir::GlobalOp guard = getStaticLocalDeclGuardAddress(globalSymName);
248 if (!guard) {
249 // Create the guard variable with a zero-initializer.
250 guard = createGuardGlobalOp(builder, globalOp->getLoc(), guardName,
251 guardTy, globalOp.getLinkage());
252 guard.setInitialValueAttr(cir::IntAttr::get(guardTy, 0));
253 guard.setDSOLocal(globalOp.getDsoLocal());
254 guard.setAlignment(guardAlignment.getAsAlign().value());
255 guard.setTlsModel(globalOp.getTlsModel());
256
257 // The ABI says: "It is suggested that it be emitted in the same COMDAT
258 // group as the associated data object." In practice, this doesn't work
259 // for non-ELF and non-Wasm object formats, so only do it for ELF and
260 // Wasm.
261 bool hasComdat = globalOp.getComdat();
262 const llvm::Triple &triple = astCtx->getTargetInfo().getTriple();
263 // TODO(cir): for now, we're just setting comdat to true, but it should
264 // contain a comdat reference name here instead.
265 if (!isLocalVarDecl && hasComdat &&
266 (triple.isOSBinFormatELF() || triple.isOSBinFormatWasm())) {
267 // This should be a comdat for the variable.
268 guard.setComdat(true);
269 } else if (hasComdat && globalOp.isWeakForLinker()) {
270 guard.setComdat(true);
271 }
272
273 setStaticLocalDeclGuardAddress(globalSymName, guard);
274 }
275 return guard;
276 }
277
278 ///
279 /// AST related
280 /// -----------
281
282 clang::ASTContext *astCtx;
283
284 /// Tracks current module.
285 mlir::ModuleOp mlirModule;
286
287 /// Cached symbol tables used to avoid repeated O(M) module-wide scans
288 /// during per-call/per-global symbol lookups. Lazily populated on first
289 /// use. Pass methods access this directly rather than threading it
290 /// through helper signatures (see PR feedback on #195919).
291 ///
292 /// Invariant: every site that mutates the module's symbol table either
293 /// (a) keeps `symbolTables` in sync via
294 /// `symbolTables.getSymbolTable(mlirModule).insert(...)` (as
295 /// `getOrCreateConstAggregateGlobal` does), or (b) creates a symbol
296 /// that is never resolved through the cache later. Today
297 /// `buildRuntimeFunction` and `getOrCreateRuntimeVariable` fall in the
298 /// (b) bucket: their callers either use a separate map
299 /// (`cudaKernelMap`, `staticLocalDeclGuardMap`, `dynamicInitializers`)
300 /// or the static `mlir::SymbolTable::lookupNearestSymbolFrom`, never
301 /// the cached path. If a future change adds a cached lookup of a
302 /// freshly created symbol, the corresponding create site MUST move
303 /// to bucket (a) (insert into the cache or call
304 /// `invalidateSymbolTable`).
305 mlir::SymbolTableCollection symbolTables;
306
307 /// Tracks existing dynamic initializers.
308 llvm::StringMap<uint32_t> dynamicInitializerNames;
309 llvm::SmallVector<cir::FuncOp> dynamicInitializers;
310 llvm::SmallVector<cir::FuncOp> globalThreadLocalInitializers;
311 llvm::StringMap<cir::FuncOp> threadLocalWrappers;
312 llvm::StringMap<cir::FuncOp> threadLocalInitAliases;
313
314 /// Tracks guard variables for static locals (keyed by global symbol name).
315 llvm::StringMap<cir::GlobalOp> staticLocalDeclGuardMap;
316
317 llvm::StringMap<llvm::SmallVector<cir::GlobalOp, 1>> constAggregateGlobals;
318
319 /// List of ctors and their priorities to be called before main()
320 llvm::SmallVector<std::pair<std::string, uint32_t>, 4> globalCtorList;
321 /// List of dtors and their priorities to be called when unloading module.
322 llvm::SmallVector<std::pair<std::string, uint32_t>, 4> globalDtorList;
323
324 /// Returns true if the target uses ARM-style guard variables for static
325 /// local initialization (32-bit guard, check bit 0 only).
326 bool useARMGuardVarABI() const {
327 switch (astCtx->getCXXABIKind()) {
328 case clang::TargetCXXABI::GenericARM:
329 case clang::TargetCXXABI::iOS:
330 case clang::TargetCXXABI::WatchOS:
331 case clang::TargetCXXABI::GenericAArch64:
332 case clang::TargetCXXABI::WebAssembly:
333 return true;
334 default:
335 return false;
336 }
337 }
338
339 void emitGlobalGuardedDtorRegion(CIRBaseBuilderTy &builder,
340 cir::GlobalOp global,
341 mlir::Region &dtorRegion, bool tls,
342 mlir::Block &entryBB) {
343 // Create a variable that binds the atexit to this shared object.
344 builder.setInsertionPointToStart(&mlirModule.getBodyRegion().front());
345 cir::GlobalOp handle = getOrCreateRuntimeVariable(
346 builder, "__dso_handle", global.getLoc(), builder.getI8Type(),
347 cir::GlobalLinkageKind::ExternalLinkage, cir::VisibilityKind::Hidden);
348
349 // If this is a simple call to a destructor, get the called function.
350 // Otherwise, create a helper function for the entire dtor region,
351 // replacing the current dtor region body with a call to the helper
352 // function.
353 cir::CallOp dtorCall;
354 cir::FuncOp dtorFunc =
355 getOrCreateDtorFunc(builder, global, dtorRegion, dtorCall);
356
357 // Create a runtime helper function:
358 // extern "C" int __cxa_atexit(void (*f)(void *), void *p, void *d);
359 cir::PointerType voidPtrTy = builder.getVoidPtrTy();
360 cir::PointerType voidFnPtrTy = builder.getVoidFnPtrTy({voidPtrTy});
361 cir::PointerType handlePtrTy = builder.getPointerTo(handle.getSymType());
362 auto fnAtExitType =
363 builder.getVoidFnTy({voidFnPtrTy, voidPtrTy, handlePtrTy});
364
365 llvm::StringLiteral nameAtExit = "__cxa_atexit";
366 if (tls)
367 nameAtExit = astCtx->getTargetInfo().getTriple().isOSDarwin()
368 ? llvm::StringLiteral("_tlv_atexit")
369 : llvm::StringLiteral("__cxa_thread_atexit");
370
371 cir::FuncOp fnAtExit = buildRuntimeFunction(builder, nameAtExit,
372 global.getLoc(), fnAtExitType);
373
374 // Replace the dtor (or helper) call with a call to
375 // __cxa_atexit(&dtor, &var, &__dso_handle)
376 builder.setInsertionPointAfter(dtorCall);
377 mlir::Value args[3];
378 auto dtorPtrTy = cir::PointerType::get(dtorFunc.getFunctionType());
379 args[0] = cir::GetGlobalOp::create(builder, dtorCall.getLoc(), dtorPtrTy,
380 dtorFunc.getSymName());
381 args[0] = cir::CastOp::create(builder, dtorCall.getLoc(), voidFnPtrTy,
382 cir::CastKind::bitcast, args[0]);
383 args[1] =
384 cir::CastOp::create(builder, dtorCall.getLoc(), voidPtrTy,
385 cir::CastKind::bitcast, dtorCall.getArgOperand(0));
386 args[2] = cir::GetGlobalOp::create(builder, handle.getLoc(), handlePtrTy,
387 handle.getSymName());
388 builder.createCallOp(dtorCall.getLoc(), fnAtExit, args);
389 dtorCall->erase();
390 mlir::Block &dtorBlock = dtorRegion.front();
391 entryBB.getOperations().splice(entryBB.end(), dtorBlock.getOperations(),
392 dtorBlock.begin(),
393 std::prev(dtorBlock.end()));
394 // make sure we leave the insert location after the operations we just
395 // inserted.
396 builder.setInsertionPointToEnd(&entryBB);
397 }
398
399 /// Emit the guarded initialization for a static local variable.
400 /// This handles the if/else structure after the guard byte check,
401 /// following OG's ItaniumCXXABI::EmitGuardedInit skeleton.
402 void emitCXXGuardedInitIf(CIRBaseBuilderTy &builder, cir::GlobalOp globalOp,
403 mlir::Region &ctorRegion, mlir::Region &dtorRegion,
404 cir::ASTVarDeclInterface varDecl,
405 mlir::Value guardPtr, cir::PointerType guardPtrTy,
406 bool threadsafe) {
407 auto loc = globalOp->getLoc();
408
409 // The semantics of dynamic initialization of variables with static or
410 // thread storage duration depends on whether they are declared at
411 // block-scope. The initialization of such variables at block-scope can be
412 // aborted with an exception and later retried (per C++20 [stmt.dcl]p4),
413 // and recursive entry to their initialization has undefined behavior (also
414 // per C++20 [stmt.dcl]p4). For such variables declared at non-block scope,
415 // exceptions lead to termination (per C++20 [except.terminate]p1), and
416 // recursive references to the variables are governed only by the lifetime
417 // rules (per C++20 [class.cdtor]p2), which means such references are
418 // perfectly fine as long as they avoid touching memory. As a result,
419 // block-scope variables must not be marked as initialized until after
420 // initialization completes (unless the mark is reverted following an
421 // exception), but non-block-scope variables must be marked prior to
422 // initialization so that recursive accesses during initialization do not
423 // restart initialization.
424
425 auto emitBody = [&]() {
426 // Emit the initializer and add a global destructor if appropriate.
427 mlir::Block *insertBlock = builder.getInsertionBlock();
428 if (!ctorRegion.empty()) {
429 assert(ctorRegion.hasOneBlock() && "Enforced by MaxSizedRegion<1>");
430
431 mlir::Block &block = ctorRegion.front();
432 insertBlock->getOperations().splice(
433 insertBlock->end(), block.getOperations(), block.begin(),
434 std::prev(block.end()));
435 }
436
437 if (!dtorRegion.empty()) {
438 assert(dtorRegion.hasOneBlock() && "Enforced by MaxSizedRegion<1>");
439
440 emitGlobalGuardedDtorRegion(builder, globalOp, dtorRegion, !threadsafe,
441 *insertBlock);
442 }
443 builder.setInsertionPointToEnd(insertBlock);
444 ctorRegion.getBlocks().clear();
445 };
446
447 // Variables used when coping with thread-safe statics and exceptions.
448 if (threadsafe) {
449 // Call __cxa_guard_acquire.
450 cir::CallOp acquireCall = builder.createCallOp(
451 loc, getGuardAcquireFn(guardPtrTy), mlir::ValueRange{guardPtr});
452 mlir::Value acquireResult = acquireCall.getResult();
453
454 auto acquireZero = builder.getConstantInt(
455 loc, mlir::cast<cir::IntType>(acquireResult.getType()), 0);
456 auto shouldInit = builder.createCompare(loc, cir::CmpOpKind::ne,
457 acquireResult, acquireZero);
458
459 // Create the IfOp for the shouldInit check.
460 // Pass an empty callback to avoid auto-creating a yield terminator.
461 auto ifOp =
462 cir::IfOp::create(builder, loc, shouldInit, /*withElseRegion=*/false,
463 [](mlir::OpBuilder &, mlir::Location) {});
464 mlir::OpBuilder::InsertionGuard insertGuard(builder);
465 builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
466
467 // Call __cxa_guard_abort along the exceptional edge.
468 // OG: CGF.EHStack.pushCleanup<CallGuardAbort>(EHCleanup, guard);
470
471 emitBody();
472
473 // Pop the guard-abort cleanup if we pushed one.
474 // OG: CGF.PopCleanupBlock();
476
477 // Call __cxa_guard_release. This cannot throw.
478 builder.createCallOp(loc, getGuardReleaseFn(guardPtrTy),
479 mlir::ValueRange{guardPtr});
480
481 builder.createYield(loc);
482 } else if (!varDecl.isLocalVarDecl()) {
483 // For non-local variables, store 1 into the first byte of the guard
484 // variable before the object initialization begins so that references
485 // to the variable during initialization don't restart initialization.
486 // OG: Builder.CreateStore(llvm::ConstantInt::get(CGM.Int8Ty, 1), ...);
487 // Then: CGF.EmitCXXGlobalVarDeclInit(D, var, shouldPerformInit);
488 globalOp->emitError("NYI: non-threadsafe init for non-local variables");
489 return;
490 } else {
491 emitBody();
492 // For local variables, store 1 into the first byte of the guard variable
493 // after the object initialization completes so that initialization is
494 // retried if initialization is interrupted by an exception.
495 builder.createStore(
496 loc, builder.getConstantInt(loc, guardPtrTy.getPointee(), 1),
497 guardPtr);
498 }
499
500 builder.createYield(loc); // Outermost IfOp
501 }
502
503 void setASTContext(clang::ASTContext *c) { astCtx = c; }
504};
505
506} // namespace
507
508cir::GlobalOp LoweringPreparePass::getOrCreateRuntimeVariable(
509 mlir::OpBuilder &builder, llvm::StringRef name, mlir::Location loc,
510 mlir::Type type, cir::GlobalLinkageKind linkage,
511 cir::VisibilityKind visibility) {
512 cir::GlobalOp g = dyn_cast_or_null<cir::GlobalOp>(
513 mlir::SymbolTable::lookupNearestSymbolFrom(
514 mlirModule, mlir::StringAttr::get(mlirModule->getContext(), name)));
515 if (!g) {
516 g = cir::GlobalOp::create(builder, loc, name, type);
517 g.setLinkageAttr(
518 cir::GlobalLinkageKindAttr::get(builder.getContext(), linkage));
519 mlir::SymbolTable::setSymbolVisibility(
520 g, mlir::SymbolTable::Visibility::Private);
521 g.setGlobalVisibility(visibility);
522 }
523 return g;
524}
525
526cir::FuncOp LoweringPreparePass::buildRuntimeFunction(
527 mlir::OpBuilder &builder, llvm::StringRef name, mlir::Location loc,
528 cir::FuncType type, cir::GlobalLinkageKind linkage) {
529 cir::FuncOp f = dyn_cast_or_null<FuncOp>(SymbolTable::lookupNearestSymbolFrom(
530 mlirModule, StringAttr::get(mlirModule->getContext(), name)));
531 if (!f) {
532 f = cir::FuncOp::create(builder, loc, name, type);
533 f.setLinkageAttr(
534 cir::GlobalLinkageKindAttr::get(builder.getContext(), linkage));
535 mlir::SymbolTable::setSymbolVisibility(
536 f, mlir::SymbolTable::Visibility::Private);
537
539 }
540 return f;
541}
542
543static mlir::Value lowerScalarToComplexCast(mlir::MLIRContext &ctx,
544 cir::CastOp op) {
545 cir::CIRBaseBuilderTy builder(ctx);
546 builder.setInsertionPoint(op);
547
548 mlir::Value src = op.getSrc();
549 mlir::Value imag = builder.getNullValue(src.getType(), op.getLoc());
550 return builder.createComplexCreate(op.getLoc(), src, imag);
551}
552
553static mlir::Value lowerComplexToScalarCast(mlir::MLIRContext &ctx,
554 cir::CastOp op,
555 cir::CastKind elemToBoolKind) {
556 cir::CIRBaseBuilderTy builder(ctx);
557 builder.setInsertionPoint(op);
558
559 mlir::Value src = op.getSrc();
560 if (!mlir::isa<cir::BoolType>(op.getType()))
561 return builder.createComplexReal(op.getLoc(), src);
562
563 // Complex cast to bool: (bool)(a+bi) => (bool)a || (bool)b
564 mlir::Value srcReal = builder.createComplexReal(op.getLoc(), src);
565 mlir::Value srcImag = builder.createComplexImag(op.getLoc(), src);
566
567 cir::BoolType boolTy = builder.getBoolTy();
568 mlir::Value srcRealToBool =
569 builder.createCast(op.getLoc(), elemToBoolKind, srcReal, boolTy);
570 mlir::Value srcImagToBool =
571 builder.createCast(op.getLoc(), elemToBoolKind, srcImag, boolTy);
572 return builder.createLogicalOr(op.getLoc(), srcRealToBool, srcImagToBool);
573}
574
575static mlir::Value lowerComplexToComplexCast(mlir::MLIRContext &ctx,
576 cir::CastOp op,
577 cir::CastKind scalarCastKind) {
578 CIRBaseBuilderTy builder(ctx);
579 builder.setInsertionPoint(op);
580
581 mlir::Value src = op.getSrc();
582 auto dstComplexElemTy =
583 mlir::cast<cir::ComplexType>(op.getType()).getElementType();
584
585 mlir::Value srcReal = builder.createComplexReal(op.getLoc(), src);
586 mlir::Value srcImag = builder.createComplexImag(op.getLoc(), src);
587
588 mlir::Value dstReal = builder.createCast(op.getLoc(), scalarCastKind, srcReal,
589 dstComplexElemTy);
590 mlir::Value dstImag = builder.createCast(op.getLoc(), scalarCastKind, srcImag,
591 dstComplexElemTy);
592 return builder.createComplexCreate(op.getLoc(), dstReal, dstImag);
593}
594
595void LoweringPreparePass::lowerCastOp(cir::CastOp op) {
596 mlir::MLIRContext &ctx = getContext();
597 mlir::Value loweredValue = [&]() -> mlir::Value {
598 switch (op.getKind()) {
599 case cir::CastKind::float_to_complex:
600 case cir::CastKind::int_to_complex:
601 return lowerScalarToComplexCast(ctx, op);
602 case cir::CastKind::float_complex_to_real:
603 case cir::CastKind::int_complex_to_real:
604 return lowerComplexToScalarCast(ctx, op, op.getKind());
605 case cir::CastKind::float_complex_to_bool:
606 return lowerComplexToScalarCast(ctx, op, cir::CastKind::float_to_bool);
607 case cir::CastKind::int_complex_to_bool:
608 return lowerComplexToScalarCast(ctx, op, cir::CastKind::int_to_bool);
609 case cir::CastKind::float_complex:
610 return lowerComplexToComplexCast(ctx, op, cir::CastKind::floating);
611 case cir::CastKind::float_complex_to_int_complex:
612 return lowerComplexToComplexCast(ctx, op, cir::CastKind::float_to_int);
613 case cir::CastKind::int_complex:
614 return lowerComplexToComplexCast(ctx, op, cir::CastKind::integral);
615 case cir::CastKind::int_complex_to_float_complex:
616 return lowerComplexToComplexCast(ctx, op, cir::CastKind::int_to_float);
617 default:
618 return nullptr;
619 }
620 }();
621
622 if (loweredValue) {
623 op.replaceAllUsesWith(loweredValue);
624 op.erase();
625 }
626}
627
628static mlir::Value buildComplexBinOpLibCall(
629 LoweringPreparePass &pass, CIRBaseBuilderTy &builder,
630 llvm::StringRef (*libFuncNameGetter)(llvm::APFloat::Semantics),
631 mlir::Location loc, cir::ComplexType ty, mlir::Value lhsReal,
632 mlir::Value lhsImag, mlir::Value rhsReal, mlir::Value rhsImag) {
633 cir::FPTypeInterface elementTy =
634 mlir::cast<cir::FPTypeInterface>(ty.getElementType());
635
636 llvm::StringRef libFuncName = libFuncNameGetter(
637 llvm::APFloat::SemanticsToEnum(elementTy.getFloatSemantics()));
638 llvm::SmallVector<mlir::Type, 4> libFuncInputTypes(4, elementTy);
639
640 cir::FuncType libFuncTy = cir::FuncType::get(libFuncInputTypes, ty);
641
642 // Insert a declaration for the runtime function to be used in Complex
643 // multiplication and division when needed
644 cir::FuncOp libFunc;
645 {
646 mlir::OpBuilder::InsertionGuard ipGuard{builder};
647 builder.setInsertionPointToStart(pass.mlirModule.getBody());
648 libFunc = pass.buildRuntimeFunction(builder, libFuncName, loc, libFuncTy);
649 }
650
651 cir::CallOp call =
652 builder.createCallOp(loc, libFunc, {lhsReal, lhsImag, rhsReal, rhsImag});
653 return call.getResult();
654}
655
656static llvm::StringRef
657getComplexDivLibCallName(llvm::APFloat::Semantics semantics) {
658 switch (semantics) {
659 case llvm::APFloat::S_IEEEhalf:
660 return "__divhc3";
661 case llvm::APFloat::S_IEEEsingle:
662 return "__divsc3";
663 case llvm::APFloat::S_IEEEdouble:
664 return "__divdc3";
665 case llvm::APFloat::S_PPCDoubleDouble:
666 return "__divtc3";
667 case llvm::APFloat::S_x87DoubleExtended:
668 return "__divxc3";
669 case llvm::APFloat::S_IEEEquad:
670 return "__divtc3";
671 default:
672 llvm_unreachable("unsupported floating point type");
673 }
674}
675
676static mlir::Value
677buildAlgebraicComplexDiv(CIRBaseBuilderTy &builder, mlir::Location loc,
678 mlir::Value lhsReal, mlir::Value lhsImag,
679 mlir::Value rhsReal, mlir::Value rhsImag) {
680 // (a+bi) / (c+di) = ((ac+bd)/(cc+dd)) + ((bc-ad)/(cc+dd))i
681 mlir::Value &a = lhsReal;
682 mlir::Value &b = lhsImag;
683 mlir::Value &c = rhsReal;
684 mlir::Value &d = rhsImag;
685
686 // The element type of the complex (lhs/rhs) determines whether floating
687 // point or integer ops are needed.
688 bool isFP = cir::isFPOrVectorOfFPType(a.getType());
689 auto mul = [&](mlir::Location l, mlir::Value x, mlir::Value y) {
690 return isFP ? builder.createFMul(l, x, y) : builder.createMul(l, x, y);
691 };
692 auto add = [&](mlir::Location l, mlir::Value x, mlir::Value y) {
693 return isFP ? builder.createFAdd(l, x, y) : builder.createAdd(l, x, y);
694 };
695 auto sub = [&](mlir::Location l, mlir::Value x, mlir::Value y) {
696 return isFP ? builder.createFSub(l, x, y) : builder.createSub(l, x, y);
697 };
698 auto div = [&](mlir::Location l, mlir::Value x, mlir::Value y) {
699 return isFP ? builder.createFDiv(l, x, y) : builder.createDiv(l, x, y);
700 };
701
702 mlir::Value ac = mul(loc, a, c); // a*c
703 mlir::Value bd = mul(loc, b, d); // b*d
704 mlir::Value cc = mul(loc, c, c); // c*c
705 mlir::Value dd = mul(loc, d, d); // d*d
706 mlir::Value acbd = add(loc, ac, bd); // ac+bd
707 mlir::Value ccdd = add(loc, cc, dd); // cc+dd
708 mlir::Value resultReal = div(loc, acbd, ccdd);
709
710 mlir::Value bc = mul(loc, b, c); // b*c
711 mlir::Value ad = mul(loc, a, d); // a*d
712 mlir::Value bcad = sub(loc, bc, ad); // bc-ad
713 mlir::Value resultImag = div(loc, bcad, ccdd);
714 return builder.createComplexCreate(loc, resultReal, resultImag);
715}
716
717static mlir::Value
719 mlir::Value lhsReal, mlir::Value lhsImag,
720 mlir::Value rhsReal, mlir::Value rhsImag) {
721 // Implements Smith's algorithm for complex division.
722 // SMITH, R. L. Algorithm 116: Complex division. Commun. ACM 5, 8 (1962).
723
724 // Let:
725 // - lhs := a+bi
726 // - rhs := c+di
727 // - result := lhs / rhs = e+fi
728 //
729 // The algorithm pseudocode looks like follows:
730 // if fabs(c) >= fabs(d):
731 // r := d / c
732 // tmp := c + r*d
733 // e = (a + b*r) / tmp
734 // f = (b - a*r) / tmp
735 // else:
736 // r := c / d
737 // tmp := d + r*c
738 // e = (a*r + b) / tmp
739 // f = (b*r - a) / tmp
740
741 mlir::Value &a = lhsReal;
742 mlir::Value &b = lhsImag;
743 mlir::Value &c = rhsReal;
744 mlir::Value &d = rhsImag;
745
746 // Smith's algorithm is only used for floating-point complex division.
747 assert(cir::isFPOrVectorOfFPType(a.getType()) &&
748 "range-reduction complex divide expects floating-point operands");
749
750 auto trueBranchBuilder = [&](mlir::OpBuilder &, mlir::Location) {
751 mlir::Value r = builder.createFDiv(loc, d, c); // r := d / c
752 mlir::Value rd = builder.createFMul(loc, r, d); // r*d
753 mlir::Value tmp = builder.createFAdd(loc, c, rd); // tmp := c + r*d
754
755 mlir::Value br = builder.createFMul(loc, b, r); // b*r
756 mlir::Value abr = builder.createFAdd(loc, a, br); // a + b*r
757 mlir::Value e = builder.createFDiv(loc, abr, tmp);
758
759 mlir::Value ar = builder.createFMul(loc, a, r); // a*r
760 mlir::Value bar = builder.createFSub(loc, b, ar); // b - a*r
761 mlir::Value f = builder.createFDiv(loc, bar, tmp);
762
763 mlir::Value result = builder.createComplexCreate(loc, e, f);
764 builder.createYield(loc, result);
765 };
766
767 auto falseBranchBuilder = [&](mlir::OpBuilder &, mlir::Location) {
768 mlir::Value r = builder.createFDiv(loc, c, d); // r := c / d
769 mlir::Value rc = builder.createFMul(loc, r, c); // r*c
770 mlir::Value tmp = builder.createFAdd(loc, d, rc); // tmp := d + r*c
771
772 mlir::Value ar = builder.createFMul(loc, a, r); // a*r
773 mlir::Value arb = builder.createFAdd(loc, ar, b); // a*r + b
774 mlir::Value e = builder.createFDiv(loc, arb, tmp);
775
776 mlir::Value br = builder.createFMul(loc, b, r); // b*r
777 mlir::Value bra = builder.createFSub(loc, br, a); // b*r - a
778 mlir::Value f = builder.createFDiv(loc, bra, tmp);
779
780 mlir::Value result = builder.createComplexCreate(loc, e, f);
781 builder.createYield(loc, result);
782 };
783
784 auto cFabs = cir::FAbsOp::create(builder, loc, c);
785 auto dFabs = cir::FAbsOp::create(builder, loc, d);
786 cir::CmpOp cmpResult =
787 builder.createCompare(loc, cir::CmpOpKind::ge, cFabs, dFabs);
788 auto ternary = cir::TernaryOp::create(builder, loc, cmpResult,
789 trueBranchBuilder, falseBranchBuilder);
790
791 return ternary.getResult();
792}
793
795 mlir::MLIRContext &context, clang::ASTContext &cc,
796 CIRBaseBuilderTy &builder, mlir::Type elementType) {
// Picks the element type that "promoted" complex division (see
// lowerComplexDiv) should compute in. Returns the wider type when its
// exponent range leaves enough headroom for the intermediate products of the
// algebraic formula; otherwise returns a null mlir::Type, and the caller then
// falls back to range-reduction (Smith's algorithm) division.
797
// Promotion ladder: FP16 -> single, single/bf16 -> double, double -> long
// double. Any other type maps to itself, which can never satisfy the exponent
// check below (2*E + 1 <= E is false for a positive max exponent E).
798 auto getHigherPrecisionFPType = [&context](mlir::Type type) -> mlir::Type {
799 if (mlir::isa<cir::FP16Type>(type))
800 return cir::SingleType::get(&context);
801
802 if (mlir::isa<cir::SingleType>(type) || mlir::isa<cir::BF16Type>(type))
803 return cir::DoubleType::get(&context);
804
// NOTE(review): the long double is created with `type` (double) as its
// underlying representation, whereas getFloatTypeSemantics below measures
// headroom with the target's long-double format. On targets where long
// double is wider than double these two may disagree -- confirm intended.
805 if (mlir::isa<cir::DoubleType>(type))
806 return cir::LongDoubleType::get(&context, type);
807
808 return type;
809 };
810
// Maps a CIR floating-point type to the target's APFloat semantics for it.
// Long double / fp128 under OpenMP target-device compilation is not handled.
811 auto getFloatTypeSemantics =
812 [&cc](mlir::Type type) -> const llvm::fltSemantics & {
813 const clang::TargetInfo &info = cc.getTargetInfo();
814 if (mlir::isa<cir::FP16Type>(type))
815 return info.getHalfFormat();
816
817 if (mlir::isa<cir::BF16Type>(type))
818 return info.getBFloat16Format();
819
820 if (mlir::isa<cir::SingleType>(type))
821 return info.getFloatFormat();
822
823 if (mlir::isa<cir::DoubleType>(type))
824 return info.getDoubleFormat();
825
826 if (mlir::isa<cir::LongDoubleType>(type)) {
827 if (cc.getLangOpts().OpenMP && cc.getLangOpts().OpenMPIsTargetDevice)
828 llvm_unreachable("NYI Float type semantics with OpenMP");
829 return info.getLongDoubleFormat();
830 }
831
832 if (mlir::isa<cir::FP128Type>(type)) {
833 if (cc.getLangOpts().OpenMP && cc.getLangOpts().OpenMPIsTargetDevice)
834 llvm_unreachable("NYI Float type semantics with OpenMP");
835 return info.getFloat128Format();
836 }
837
838 llvm_unreachable("Unsupported float type semantics");
839 };
840
841 const mlir::Type higherElementType = getHigherPrecisionFPType(elementType);
842 const llvm::fltSemantics &elementTypeSemantics =
843 getFloatTypeSemantics(elementType);
844 const llvm::fltSemantics &higherElementTypeSemantics =
845 getFloatTypeSemantics(higherElementType);
846
847 // Check that the promoted type can handle the intermediate values without
848 // overflowing. This can be interpreted as:
849 // (SmallerType.LargestFiniteVal * SmallerType.LargestFiniteVal) * 2 <=
850 // LargerType.LargestFiniteVal.
851 // In terms of exponent it gives this formula:
852 // (SmallerType.LargestFiniteVal * SmallerType.LargestFiniteVal
853 // doubles the exponent of SmallerType.LargestFiniteVal)
// Example: IEEE single (max exponent 127) into IEEE double (max exponent
// 1023): 127 * 2 + 1 = 255 <= 1023, so the promotion is accepted.
854 if (llvm::APFloat::semanticsMaxExponent(elementTypeSemantics) * 2 + 1 <=
855 llvm::APFloat::semanticsMaxExponent(higherElementTypeSemantics)) {
856 return higherElementType;
857 }
858
859 // The intermediate values can't be represented in the promoted type
860 // without overflowing.
861 return {};
862}
863
// Expands a complex division into component arithmetic, selecting the
// algorithm from the op's range kind when the element type is floating point:
//   Improved -> range reduction (buildRangeReductionComplexDiv)
//   Full     -> the call built on the (not visible here) line after the
//               `Full` check; presumably a runtime library call -- confirm
//   Promoted -> widen the operands, use the algebraic formula, narrow the
//               result; falls back to range reduction when no suitably wide
//               type exists
// Any other range kind, and non-floating-point elements, use the plain
// algebraic formula (buildAlgebraicComplexDiv).
864static mlir::Value
865lowerComplexDiv(LoweringPreparePass &pass, CIRBaseBuilderTy &builder,
866 mlir::Location loc, cir::ComplexDivOp op, mlir::Value lhsReal,
867 mlir::Value lhsImag, mlir::Value rhsReal, mlir::Value rhsImag,
868 mlir::MLIRContext &mlirCx, clang::ASTContext &cc) {
869 cir::ComplexType complexTy = op.getType();
870 if (mlir::isa<cir::FPTypeInterface>(complexTy.getElementType())) {
871 cir::ComplexRangeKind range = op.getRange();
872 if (range == cir::ComplexRangeKind::Improved)
873 return buildRangeReductionComplexDiv(builder, loc, lhsReal, lhsImag,
874 rhsReal, rhsImag);
875
876 if (range == cir::ComplexRangeKind::Full)
878 loc, complexTy, lhsReal, lhsImag, rhsReal,
879 rhsImag);
880
881 if (range == cir::ComplexRangeKind::Promoted) {
882 mlir::Type originalElementType = complexTy.getElementType();
883 mlir::Type higherPrecisionElementType =
885 originalElementType);
886
// A null type means no wider type has enough exponent headroom.
887 if (!higherPrecisionElementType)
888 return buildRangeReductionComplexDiv(builder, loc, lhsReal, lhsImag,
889 rhsReal, rhsImag);
890
// Widen all four components before doing the division.
891 cir::CastKind floatingCastKind = cir::CastKind::floating;
892 lhsReal = builder.createCast(floatingCastKind, lhsReal,
893 higherPrecisionElementType);
894 lhsImag = builder.createCast(floatingCastKind, lhsImag,
895 higherPrecisionElementType);
896 rhsReal = builder.createCast(floatingCastKind, rhsReal,
897 higherPrecisionElementType);
898 rhsImag = builder.createCast(floatingCastKind, rhsImag,
899 higherPrecisionElementType);
900
901 mlir::Value algebraicResult = buildAlgebraicComplexDiv(
902 builder, loc, lhsReal, lhsImag, rhsReal, rhsImag);
903
// Narrow the wide result back to the original element type.
904 mlir::Value resultReal = builder.createComplexReal(loc, algebraicResult);
905 mlir::Value resultImag = builder.createComplexImag(loc, algebraicResult);
906
907 mlir::Value finalReal =
908 builder.createCast(floatingCastKind, resultReal, originalElementType);
909 mlir::Value finalImag =
910 builder.createCast(floatingCastKind, resultImag, originalElementType);
911 return builder.createComplexCreate(loc, finalReal, finalImag);
912 }
913 }
914
// Integer elements and remaining FP range kinds: plain algebraic formula.
915 return buildAlgebraicComplexDiv(builder, loc, lhsReal, lhsImag, rhsReal,
916 rhsImag);
917}
918
919void LoweringPreparePass::lowerComplexDivOp(cir::ComplexDivOp op) {
920 cir::CIRBaseBuilderTy builder(getContext());
921 builder.setInsertionPointAfter(op);
922 mlir::Location loc = op.getLoc();
923 mlir::TypedValue<cir::ComplexType> lhs = op.getLhs();
924 mlir::TypedValue<cir::ComplexType> rhs = op.getRhs();
925 mlir::Value lhsReal = builder.createComplexReal(loc, lhs);
926 mlir::Value lhsImag = builder.createComplexImag(loc, lhs);
927 mlir::Value rhsReal = builder.createComplexReal(loc, rhs);
928 mlir::Value rhsImag = builder.createComplexImag(loc, rhs);
929
930 mlir::Value loweredResult =
931 lowerComplexDiv(*this, builder, loc, op, lhsReal, lhsImag, rhsReal,
932 rhsImag, getContext(), *astCtx);
933 op.replaceAllUsesWith(loweredResult);
934 op.erase();
935}
936
937static llvm::StringRef
938getComplexMulLibCallName(llvm::APFloat::Semantics semantics) {
939 switch (semantics) {
940 case llvm::APFloat::S_IEEEhalf:
941 return "__mulhc3";
942 case llvm::APFloat::S_IEEEsingle:
943 return "__mulsc3";
944 case llvm::APFloat::S_IEEEdouble:
945 return "__muldc3";
946 case llvm::APFloat::S_PPCDoubleDouble:
947 return "__multc3";
948 case llvm::APFloat::S_x87DoubleExtended:
949 return "__mulxc3";
950 case llvm::APFloat::S_IEEEquad:
951 return "__multc3";
952 default:
953 llvm_unreachable("unsupported floating point type");
954 }
955}
956
// Expands the complex product (a+bi)*(c+di) into component arithmetic.
// The algebraic result is returned directly for integer elements and for the
// Basic, Improved and Promoted range kinds. For the remaining range kind
// (presumably Full; the guarding line is not visible here) the algebraic
// result is kept unless both of its components are NaN, in which case the
// value produced by a runtime library call (see getComplexMulLibCallName) is
// selected instead.
957static mlir::Value lowerComplexMul(LoweringPreparePass &pass,
958 CIRBaseBuilderTy &builder,
959 mlir::Location loc, cir::ComplexMulOp op,
960 mlir::Value lhsReal, mlir::Value lhsImag,
961 mlir::Value rhsReal, mlir::Value rhsImag) {
962 // (a+bi) * (c+di) = (ac-bd) + (ad+bc)i
// Choose floating-point or integer arithmetic once, from the component type.
963 bool isFP = cir::isFPOrVectorOfFPType(lhsReal.getType());
964 auto mul = [&](mlir::Location l, mlir::Value x, mlir::Value y) {
965 return isFP ? builder.createFMul(l, x, y) : builder.createMul(l, x, y);
966 };
967 auto add = [&](mlir::Location l, mlir::Value x, mlir::Value y) {
968 return isFP ? builder.createFAdd(l, x, y) : builder.createAdd(l, x, y);
969 };
970 auto sub = [&](mlir::Location l, mlir::Value x, mlir::Value y) {
971 return isFP ? builder.createFSub(l, x, y) : builder.createSub(l, x, y);
972 };
973
974 mlir::Value resultRealLhs = mul(loc, lhsReal, rhsReal); // ac
975 mlir::Value resultRealRhs = mul(loc, lhsImag, rhsImag); // bd
976 mlir::Value resultImagLhs = mul(loc, lhsReal, rhsImag); // ad
977 mlir::Value resultImagRhs = mul(loc, lhsImag, rhsReal); // bc
978 mlir::Value resultReal = sub(loc, resultRealLhs, resultRealRhs);
979 mlir::Value resultImag = add(loc, resultImagLhs, resultImagRhs);
980 mlir::Value algebraicResult =
981 builder.createComplexCreate(loc, resultReal, resultImag);
982
// The algebraic result is final unless NaN recovery is required.
983 cir::ComplexType complexTy = op.getType();
984 cir::ComplexRangeKind rangeKind = op.getRange();
985 if (mlir::isa<cir::IntType>(complexTy.getElementType()) ||
986 rangeKind == cir::ComplexRangeKind::Basic ||
987 rangeKind == cir::ComplexRangeKind::Improved ||
988 rangeKind == cir::ComplexRangeKind::Promoted)
989 return algebraicResult;
990
992
993 // Check whether the real part and the imaginary part of the result are both
994 // NaN. If so, emit a library call to compute the multiplication instead.
995 // We check a value against NaN by comparing the value against itself.
996 mlir::Value resultRealIsNaN = builder.createIsNaN(loc, resultReal);
997 mlir::Value resultImagIsNaN = builder.createIsNaN(loc, resultImag);
998 mlir::Value resultRealAndImagAreNaN =
999 builder.createLogicalAnd(loc, resultRealIsNaN, resultImagIsNaN);
1000
// The library call is emitted only inside the "both NaN" branch.
1001 return cir::TernaryOp::create(
1002 builder, loc, resultRealAndImagAreNaN,
1003 [&](mlir::OpBuilder &, mlir::Location) {
1004 mlir::Value libCallResult = buildComplexBinOpLibCall(
1005 pass, builder, &getComplexMulLibCallName, loc, complexTy,
1006 lhsReal, lhsImag, rhsReal, rhsImag);
1007 builder.createYield(loc, libCallResult);
1008 },
1009 [&](mlir::OpBuilder &, mlir::Location) {
1010 builder.createYield(loc, algebraicResult);
1011 })
1012 .getResult();
1013}
1014
1015void LoweringPreparePass::lowerComplexMulOp(cir::ComplexMulOp op) {
1016 cir::CIRBaseBuilderTy builder(getContext());
1017 builder.setInsertionPointAfter(op);
1018 mlir::Location loc = op.getLoc();
1019 mlir::TypedValue<cir::ComplexType> lhs = op.getLhs();
1020 mlir::TypedValue<cir::ComplexType> rhs = op.getRhs();
1021 mlir::Value lhsReal = builder.createComplexReal(loc, lhs);
1022 mlir::Value lhsImag = builder.createComplexImag(loc, lhs);
1023 mlir::Value rhsReal = builder.createComplexReal(loc, rhs);
1024 mlir::Value rhsImag = builder.createComplexImag(loc, rhs);
1025 mlir::Value loweredResult = lowerComplexMul(*this, builder, loc, op, lhsReal,
1026 lhsImag, rhsReal, rhsImag);
1027 op.replaceAllUsesWith(loweredResult);
1028 op.erase();
1029}
1030
1031void LoweringPreparePass::lowerComplexConjOp(cir::ComplexConjOp op) {
1032 mlir::Location loc = op.getLoc();
1033 CIRBaseBuilderTy builder(getContext());
1034 builder.setInsertionPointAfter(op);
1035
1036 mlir::Value operand = op.getOperand();
1037 mlir::Value operandReal = builder.createComplexReal(loc, operand);
1038 mlir::Value operandImag = builder.createComplexImag(loc, operand);
1039
1040 // The complex conjugate is formed by negating the imaginary component.
1041 const bool isFP = cir::isFPOrVectorOfFPType(operandReal.getType());
1042 mlir::Value resultImag = isFP ? builder.createFNeg(loc, operandImag)
1043 : builder.createMinus(loc, operandImag);
1044
1045 mlir::Value result =
1046 builder.createComplexCreate(loc, operandReal, resultImag);
1047 op->replaceAllUsesWith(mlir::ValueRange{result});
1048 op->erase();
1049}
1050
// Returns the function to register as the destructor of global `op` and
// stores in `dtorCall` the call to it that remains in `dtorRegion`.
// If the region's first block is exactly `get_global; call f(get_global);
// yield`, f is returned and the region is left untouched. Otherwise the
// block's operations are outlined into a new internal helper named
// "__cxx_global_array_dtor" (with a ".N" suffix after the first), whose single
// void-pointer parameter replaces the get_global, and the region is rewritten
// to `get_global; call helper(get_global); yield`.
// The caller's insertion point is restored on return (InsertionGuard).
1051cir::FuncOp LoweringPreparePass::getOrCreateDtorFunc(CIRBaseBuilderTy &builder,
1052 cir::GlobalOp op,
1053 mlir::Region &dtorRegion,
1054 cir::CallOp &dtorCall) {
1055 mlir::OpBuilder::InsertionGuard guard(builder);
1057
1058 cir::VoidType voidTy = builder.getVoidTy();
1059 auto voidPtrTy = cir::PointerType::get(voidTy);
1060
1061 // Look for operations in dtorBlock
1062 mlir::Block &dtorBlock = dtorRegion.front();
1063
1064 // The first operation should be a get_global to retrieve the address
1065 // of the global variable we're destroying.
1066 auto opIt = dtorBlock.getOperations().begin();
1067 cir::GetGlobalOp ggop = mlir::cast<cir::GetGlobalOp>(*opIt);
1068
1069 // The simple case is just a call to a destructor, like this:
1070 //
1071 // %0 = cir.get_global %globalS : !cir.ptr<!rec_S>
1072 // cir.call %_ZN1SD1Ev(%0) : (!cir.ptr<!rec_S>) -> ()
1073 // (implicit cir.yield)
1074 //
1075 // That is, if the second operation is a call that takes the get_global result
1076 // as its only operand, and the only other operation is a yield, then we can
1077 // just return the called function.
1078 if (dtorBlock.getOperations().size() == 3) {
// Advance the iterator to the 2nd and then the 3rd operation.
1079 auto callOp = mlir::dyn_cast<cir::CallOp>(&*(++opIt));
1080 auto yieldOp = mlir::dyn_cast<cir::YieldOp>(&*(++opIt));
1081 if (yieldOp && callOp && callOp.getNumOperands() == 1 &&
1082 callOp.getArgOperand(0) == ggop) {
1083 dtorCall = callOp;
1084 return getCalledFunction(callOp);
1085 }
1086 }
1087
1088 // Otherwise, we need to create a helper function to replace the dtor region.
1089 // This name is kind of arbitrary, but it matches the name that classic
1090 // codegen uses, based on the expected case that gets us here.
1091 builder.setInsertionPointAfter(op);
1092 SmallString<256> fnName("__cxx_global_array_dtor");
1093 uint32_t cnt = dynamicInitializerNames[fnName]++;
1094 if (cnt)
1095 fnName += "." + std::to_string(cnt);
1096
1097 // Create the helper function.
1098 auto fnType = cir::FuncType::get({voidPtrTy}, voidTy);
1099 cir::FuncOp dtorFunc =
1100 buildRuntimeFunction(builder, fnName, op.getLoc(), fnType,
1101 cir::GlobalLinkageKind::InternalLinkage);
1102
// Mark the helper's single parameter with `llvm.noundef`.
1103 SmallVector<mlir::NamedAttribute> paramAttrs;
1104 paramAttrs.push_back(
1105 builder.getNamedAttr("llvm.noundef", builder.getUnitAttr()));
1106 SmallVector<mlir::Attribute> argAttrDicts;
1107 argAttrDicts.push_back(
1108 mlir::DictionaryAttr::get(builder.getContext(), paramAttrs));
1109 dtorFunc.setArgAttrsAttr(
1110 mlir::ArrayAttr::get(builder.getContext(), argAttrDicts));
1111
1112 mlir::Block *entryBB = dtorFunc.addEntryBlock();
1113
1114 // Move everything from the dtor region into the helper function.
// NOTE(review): only the operations of the region's FIRST block are moved.
// If the dtor region had more than one block, the first block would end in
// a branch and the YieldOp cast below would fail; the trailing blocks are
// also erased (not moved) at the end. Confirm multi-block dtor regions can't
// reach this path.
1115 entryBB->getOperations().splice(entryBB->begin(), dtorBlock.getOperations(),
1116 dtorBlock.begin(), dtorBlock.end());
1117
1118 // Before erasing this, clone it back into the dtor region
1119 cir::GetGlobalOp dtorGGop =
1120 mlir::cast<cir::GetGlobalOp>(entryBB->getOperations().front());
1121 builder.setInsertionPointToStart(&dtorBlock);
1122 builder.clone(*dtorGGop.getOperation());
1123
1124 // Replace all uses of the helper function's get_global with the function
1125 // argument.
1126 mlir::Value dtorArg = entryBB->getArgument(0);
1127 dtorGGop.replaceAllUsesWith(dtorArg);
1128 dtorGGop.erase();
1129
1130 // Replace the yield in the final block with a return
1131 mlir::Block &finalBlock = dtorFunc.getBody().back();
1132 auto yieldOp = cast<cir::YieldOp>(finalBlock.getTerminator());
1133 builder.setInsertionPoint(yieldOp);
1134 cir::ReturnOp::create(builder, yieldOp->getLoc());
1135 yieldOp->erase();
1136
1137 // Create a call to the helper function, passing the original get_global op
1138 // as the argument.
1139 cir::GetGlobalOp origGGop =
1140 mlir::cast<cir::GetGlobalOp>(dtorBlock.getOperations().front());
1141 builder.setInsertionPointAfter(origGGop);
1142 mlir::Value ggopResult = origGGop.getResult();
1143 dtorCall = builder.createCallOp(op.getLoc(), dtorFunc, ggopResult);
1144
1145 // Add a yield after the call.
1146 auto finalYield = cir::YieldOp::create(builder, op.getLoc());
1147
1148 // Erase everything after the yield.
1149 dtorBlock.getOperations().erase(std::next(mlir::Block::iterator(finalYield)),
1150 dtorBlock.end());
1151 dtorRegion.getBlocks().erase(std::next(dtorRegion.begin()), dtorRegion.end());
1152
1153 return dtorFunc;
1154}
1155
// Creates an internal, void() "__cxx_global_var_init[.N]" function that
// carries out the dynamic initialization of global `op`:
//  - if `op` has a dynamic-TLS guard name, the body is wrapped in the guard
//    check produced by buildGlobalTlsGuardCheck;
//  - the ctor region's first block (minus its terminator) is moved in;
//  - a non-empty dtor region is handed to emitGlobalGuardedDtorRegion.
// The ctor/dtor regions of `op` are not cleared here; the caller is expected
// to do that (lowerGlobalOp does).
1156cir::FuncOp
1157LoweringPreparePass::buildCXXGlobalVarDeclInitFunc(cir::GlobalOp op) {
1158 // TODO(cir): Store this in the GlobalOp.
1159 // This should come from the MangleContext, but for now I'm hardcoding it.
1160 SmallString<256> fnName("__cxx_global_var_init");
1161 // Get a unique name
1162 uint32_t cnt = dynamicInitializerNames[fnName]++;
1163 if (cnt)
1164 fnName += "." + std::to_string(cnt);
1165
1166 // Create a variable initialization function.
1167 CIRBaseBuilderTy builder(getContext());
1168 builder.setInsertionPointAfter(op);
1169 cir::VoidType voidTy = builder.getVoidTy();
1170 auto fnType = cir::FuncType::get({}, voidTy);
1171 FuncOp f = buildRuntimeFunction(builder, fnName, op.getLoc(), fnType,
1172 cir::GlobalLinkageKind::InternalLinkage);
1173
1174 // Move over the initialization code of the ctor region.
1175 // The ctor region may have multiple blocks when exception handling
1176 // scaffolding creates extra blocks (e.g., unreachable/trap blocks).
1177 // We move all operations from the first block (minus the yield) into
1178 // the function entry, and discard extra blocks (which contain only
1179 // unreachable terminators from EH cleanup paths).
// NOTE(review): nothing in this function discards the extra blocks; they
// stay in the ctor region until the caller clears it.
1180 mlir::Block *entryBB = f.addEntryBlock();
1181 builder.setInsertionPointToStart(entryBB);
1182
1183 // If this is a global TLS variable (that is, declared at namespace scope), we
1184 // have to emit the guard variable here.
1185 bool needsTlsGuard = op.getDynTlsRefs() && op.getDynTlsRefs()->getGuardName();
1186 cir::IfOp guardIf;
1187 if (needsTlsGuard) {
1188 guardIf = buildGlobalTlsGuardCheck(
1189 builder, op.getLoc(),
1190 getOrCreateStaticLocalDeclGuardAddress(
1191 builder, op, op.getDynTlsRefs()->getGuardName().getValue(),
1192 /*isLocalVarDecl=*/false,
1193 /*useInt8GuardVariable=*/op.hasInternalLinkage()));
1194 builder.setInsertionPointToEnd(&guardIf.getThenRegion().front());
1195 }
1196
// Move the ctor code (everything but the trailing terminator) to the current
// insertion block: the entry block, or the guard's 'then' block.
1197 if (!op.getCtorRegion().empty()) {
1198 mlir::Block &block = op.getCtorRegion().front();
1199 mlir::Block *insertBlock = builder.getBlock();
1200 insertBlock->getOperations().splice(insertBlock->end(),
1201 block.getOperations(), block.begin(),
1202 std::prev(block.end()));
1203 }
1204
1205 // Register the destructor call with __cxa_atexit
1206 mlir::Region &dtorRegion = op.getDtorRegion();
1207 if (!dtorRegion.empty()) {
1209
1210 emitGlobalGuardedDtorRegion(builder, op, dtorRegion,
1211 op.getTlsModel().has_value(),
1212 *builder.getBlock());
1213 }
1214
1215 // If we're actually in the 'if' above, create a yield.
1216 if (needsTlsGuard) {
1217 builder.setInsertionPointToEnd(&guardIf.getThenRegion().back());
1218 cir::YieldOp::create(builder, op.getLoc());
1219 }
1220
1221 // Replace cir.yield with cir.return
// (The region's trailing yield only supplies the return's location; the
// yield itself is not erased here.)
1222 builder.setInsertionPointToEnd(entryBB);
1223 mlir::Operation *yieldOp = nullptr;
1224 if (!op.getCtorRegion().empty()) {
1225 mlir::Block &block = op.getCtorRegion().front();
1226 yieldOp = &block.getOperations().back();
1227 } else {
1228 assert(!dtorRegion.empty());
1229 mlir::Block &block = dtorRegion.front();
1230 yieldOp = &block.getOperations().back();
1231 }
1232
1233 assert(isa<cir::YieldOp>(*yieldOp));
1234 cir::ReturnOp::create(builder, yieldOp->getLoc());
1235 return f;
1236}
1237
1238cir::FuncOp
1239LoweringPreparePass::getGuardAcquireFn(cir::PointerType guardPtrTy) {
1240 // int __cxa_guard_acquire(__guard *guard_object);
1241 CIRBaseBuilderTy builder(getContext());
1242 mlir::OpBuilder::InsertionGuard ipGuard{builder};
1243 builder.setInsertionPointToStart(mlirModule.getBody());
1244 mlir::Location loc = mlirModule.getLoc();
1245 cir::IntType intTy = cir::IntType::get(&getContext(), 32, /*isSigned=*/true);
1246 auto fnType = cir::FuncType::get({guardPtrTy}, intTy);
1247 return buildRuntimeFunction(builder, "__cxa_guard_acquire", loc, fnType);
1248}
1249
1250cir::FuncOp
1251LoweringPreparePass::getGuardReleaseFn(cir::PointerType guardPtrTy) {
1252 // void __cxa_guard_release(__guard *guard_object);
1253 CIRBaseBuilderTy builder(getContext());
1254 mlir::OpBuilder::InsertionGuard ipGuard{builder};
1255 builder.setInsertionPointToStart(mlirModule.getBody());
1256 mlir::Location loc = mlirModule.getLoc();
1257 cir::VoidType voidTy = cir::VoidType::get(&getContext());
1258 auto fnType = cir::FuncType::get({guardPtrTy}, voidTy);
1259 return buildRuntimeFunction(builder, "__cxa_guard_release", loc, fnType);
1260}
1261
1262cir::FuncOp LoweringPreparePass::getTlsInitFn() {
1263 // void __tls_init(void);
1264 CIRBaseBuilderTy builder(getContext());
1265 mlir::OpBuilder::InsertionGuard _{builder};
1266 builder.setInsertionPointToStart(mlirModule.getBody());
1267 mlir::Location loc = mlirModule.getLoc();
1268 auto fnType = builder.getVoidFnTy();
1269 return buildRuntimeFunction(builder, "__tls_init", loc, fnType,
1270 cir::GlobalLinkageKind::InternalLinkage);
1271}
1272
1273cir::GlobalOp LoweringPreparePass::createGuardGlobalOp(
1274 CIRBaseBuilderTy &builder, mlir::Location loc, llvm::StringRef name,
1275 cir::IntType guardTy, cir::GlobalLinkageKind linkage) {
1276 mlir::OpBuilder::InsertionGuard guard(builder);
1277 builder.setInsertionPointToStart(mlirModule.getBody());
1278 cir::GlobalOp g = cir::GlobalOp::create(builder, loc, name, guardTy);
1279 g.setLinkageAttr(
1280 cir::GlobalLinkageKindAttr::get(builder.getContext(), linkage));
1281 mlir::SymbolTable::setSymbolVisibility(
1282 g, mlir::SymbolTable::Visibility::Private);
1283 return g;
1284}
1285
// Emits guarded one-time initialization for the static local `globalOp` at
// the position of `localInitOp`. It obtains the guard variable, loads the
// guard's first byte (acquire-ordered when thread-safe statics apply, masked
// to bit 0 under the ARM guard ABI) and, if that is zero, runs the ctor/dtor
// regions through emitCXXGuardedInitIf inside a cir.if. The block's
// terminator is detached while this IR is appended and re-attached at the end.
1286void LoweringPreparePass::handleStaticLocal(cir::GlobalOp globalOp,
1287 cir::LocalInitOp localInitOp) {
1288 CIRBaseBuilderTy builder(getContext());
1289
1290 std::optional<cir::ASTVarDeclInterface> astOption = globalOp.getAst();
1291 assert(astOption.has_value());
1292 cir::ASTVarDeclInterface varDecl = astOption.value();
1293
1294 builder.setInsertionPointAfter(localInitOp);
1295 mlir::Block *localInitBlock = builder.getInsertionBlock();
1296
1297 // Remove the terminator temporarily - we'll add it back at the end.
1298 mlir::Operation *ret = localInitBlock->getTerminator();
1299 ret->remove();
1300 // Note: These two insert-point-after sets are necessary, as the 'trailing'
1301 // operation has changed thanks to the terminator removal.
1302 builder.setInsertionPointAfter(localInitOp);
1303
1304 // Inline variables that weren't instantiated from variable templates have
1305 // partially-ordered initialization within their translation unit.
1306 bool nonTemplateInline =
1307 varDecl.isInline() &&
1308 !clang::isTemplateInstantiation(varDecl.getTemplateSpecializationKind());
1309
1310 // Inline namespace-scope variables require guarded initialization in a
1311 // __cxx_global_var_init function. This is not yet implemented.
1312 if (nonTemplateInline) {
1313 globalOp->emitError(
1314 "NYI: guarded initialization for inline namespace-scope variables");
// NOTE(review): unlike the `!guard` path below, this return does not
// re-insert `ret`, leaving the block without its terminator and the detached
// op unowned. Consider `localInitBlock->push_back(ret);` before returning, or
// doing this check before the terminator is removed.
1315 return;
1316 }
1317
1318 // We only need to use thread-safe statics for local non-TLS variables and
1319 // inline variables; other global initialization is always single-threaded
1320 // or (through lazy dynamic loading in multiple threads) unsequenced.
1321 bool threadsafe = astCtx->getLangOpts().ThreadsafeStatics &&
1322 (varDecl.isLocalVarDecl() || nonTemplateInline) &&
1323 !varDecl.getTLSKind();
1324
1325 // If we have a global variable with internal linkage and thread-safe statics
1326 // are disabled, we can just let the guard variable be of type i8.
1327 bool useInt8GuardVariable = !threadsafe && globalOp.hasInternalLinkage();
1328
1329 // Create the guard variable if we don't already have it.
1330 cir::GlobalOp guard = getOrCreateStaticLocalDeclGuardAddress(
1331 builder, globalOp, globalOp.getStaticLocalGuard()->getName().getValue(),
1332 varDecl.isLocalVarDecl(), useInt8GuardVariable);
1333 if (!guard) {
1334 // Error was already emitted, just restore the terminator and return.
1335 localInitBlock->push_back(ret);
1336 return;
1337 }
1338
1339 mlir::Value guardPtr = builder.createGetGlobal(guard, localInitOp.getTls());
1340
1341 // Test whether the variable has completed initialization.
1342 //
1343 // Itanium C++ ABI 3.3.2:
1344 // The following is pseudo-code showing how these functions can be used:
1345 // if (obj_guard.first_byte == 0) {
1346 // if ( __cxa_guard_acquire (&obj_guard) ) {
1347 // try {
1348 // ... initialize the object ...;
1349 // } catch (...) {
1350 // __cxa_guard_abort (&obj_guard);
1351 // throw;
1352 // }
1353 // ... queue object destructor with __cxa_atexit() ...;
1354 // __cxa_guard_release (&obj_guard);
1355 // }
1356 // }
1357 //
1358 // If threadsafe statics are enabled, but we don't have inline atomics, just
1359 // call __cxa_guard_acquire unconditionally. The "inline" check isn't
1360 // actually inline, and the user might not expect calls to __atomic libcalls.
1361 unsigned maxInlineWidthInBits =
1363
// Fast path: inline check of the guard byte before entering the guarded
// initialization.
1364 if (!threadsafe || maxInlineWidthInBits) {
1365 // Load the first byte of the guard variable.
1366 auto bytePtrTy = cir::PointerType::get(builder.getSIntNTy(8));
1367 mlir::Value bytePtr = builder.createBitcast(guardPtr, bytePtrTy);
1368 mlir::Value guardLoad = builder.createAlignedLoad(
1369 localInitOp.getLoc(), bytePtr, *guard.getAlignment());
1370
1371 // Itanium ABI:
1372 // An implementation supporting thread-safety on multiprocessor
1373 // systems must also guarantee that references to the initialized
1374 // object do not occur before the load of the initialization flag.
1375 //
1376 // In LLVM, we do this by marking the load Acquire.
1377 if (threadsafe) {
1378 auto loadOp = mlir::cast<cir::LoadOp>(guardLoad.getDefiningOp());
1379 loadOp.setMemOrder(cir::MemOrder::Acquire);
1380 loadOp.setSyncScope(cir::SyncScopeKind::System);
1381 }
1382
1383 // For ARM, we should only check the first bit, rather than the entire byte:
1384 //
1385 // ARM C++ ABI 3.2.3.1:
1386 // To support the potential use of initialization guard variables
1387 // as semaphores that are the target of ARM SWP and LDREX/STREX
1388 // synchronizing instructions we define a static initialization
1389 // guard variable to be a 4-byte aligned, 4-byte word with the
1390 // following inline access protocol.
1391 // #define INITIALIZED 1
1392 // if ((obj_guard & INITIALIZED) != INITIALIZED) {
1393 // if (__cxa_guard_acquire(&obj_guard))
1394 // ...
1395 // }
1396 //
1397 // and similarly for ARM64:
1398 //
1399 // ARM64 C++ ABI 3.2.2:
1400 // This ABI instead only specifies the value bit 0 of the static guard
1401 // variable; all other bits are platform defined. Bit 0 shall be 0 when
1402 // the variable is not initialized and 1 when it is.
1403 if (useARMGuardVarABI() && !useInt8GuardVariable) {
1404 auto one = builder.getConstantInt(
1405 localInitOp.getLoc(), mlir::cast<cir::IntType>(guardLoad.getType()),
1406 1);
1407 guardLoad = builder.createAnd(localInitOp.getLoc(), guardLoad, one);
1408 }
1409
1410 // Check if the first byte of the guard variable is zero.
1411 auto zero = builder.getConstantInt(
1412 localInitOp.getLoc(), mlir::cast<cir::IntType>(guardLoad.getType()), 0);
1413 auto needsInit = builder.createCompare(localInitOp.getLoc(),
1414 cir::CmpOpKind::eq, guardLoad, zero);
1415
1416 // Build the guarded initialization inside an if block.
1417 cir::IfOp::create(
1418 builder, globalOp.getLoc(), needsInit,
1419 /*withElseRegion=*/false, [&](mlir::OpBuilder &, mlir::Location) {
1420 emitCXXGuardedInitIf(builder, globalOp, localInitOp.getCtorRegion(),
1421 localInitOp.getDtorRegion(), varDecl, guardPtr,
1422 builder.getPointerTo(guard.getSymType()),
1423 threadsafe);
1424 });
1425 } else {
1426 // Threadsafe statics without inline atomics - call __cxa_guard_acquire
1427 // unconditionally without the initial guard byte check.
1428 globalOp->emitError("NYI: guarded init without inline atomics support");
// NOTE(review): as with the nonTemplateInline path, `ret` is not re-inserted
// before this return.
1429 return;
1430 }
1431
1432 // Insert the removed terminator back.
1433 builder.getInsertionBlock()->push_back(ret);
1434}
1435
1436void LoweringPreparePass::lowerLocalInitOp(cir::LocalInitOp initOp) {
1437
1438 // If we don't actually need to initialize anything anymore, we're done here.
1439 if (initOp.getCtorRegion().empty() && initOp.getDtorRegion().empty()) {
1440 initOp.erase();
1441 return;
1442 }
1443
1444 cir::GlobalOp globalOp = initOp.getReferencedGlobal(symbolTables);
1445 assert(globalOp && "No global-op found");
1446
1447 handleStaticLocal(globalOp, initOp);
1448
1449 // Remove the init local op, now that we've done everything we need with it.
1450 initOp.erase();
1451}
1452static bool isThreadWrapperReplaceable(cir::TLS_Model tls,
1453 clang::ASTContext &astCtx) {
1454 return tls == cir::TLS_Model::GeneralDynamic &&
1455 astCtx.getTargetInfo().getTriple().isOSDarwin();
1456}
1457
// Chooses the linkage of the thread-local wrapper function for `op`:
//  - a local-linkage variable keeps its own linkage;
//  - on Darwin general-dynamic TLS (isThreadWrapperReplaceable), the
//    variable's linkage is kept unless it is linkonce or weak_odr;
//  - otherwise linkonce_odr for declarations, weak_odr for definitions.
1458static cir::GlobalLinkageKind
1460 if (isLocalLinkage(op.getLinkage()))
1461 return op.getLinkage();
1462
// Note: the inner `if` is the whole body of the outer one; when its condition
// fails, control falls through to the declaration check below.
1463 if (isThreadWrapperReplaceable(*op.getTlsModel(), astCtx))
1464 if (!isLinkOnceLinkage(op.getLinkage()) &&
1465 !isWeakODRLinkage(op.getLinkage()))
1466 return op.getLinkage();
1467
1468 // If this isn't a TU in which this variable is defined, the thread wrapper is
1469 // discardable.
1470 if (op.isDeclaration())
1471 return cir::GlobalLinkageKind::LinkOnceODRLinkage;
1472 return cir::GlobalLinkageKind::WeakODRLinkage;
1473}
1474
1475cir::FuncOp
1476LoweringPreparePass::getOrCreateThreadLocalWrapper(CIRBaseBuilderTy &builder,
1477 GlobalOp op) {
1478 mlir::OpBuilder::InsertionGuard insertGuard(builder);
1479 builder.setInsertionPointToStart(&mlirModule.getBodyRegion().front());
1480
1481 mlir::StringAttr wrapperName = op.getDynTlsRefs()->getWrapperName();
1482
1483 auto existingWrapperIter = threadLocalWrappers.find(wrapperName.getValue());
1484 if (existingWrapperIter != threadLocalWrappers.end())
1485 return existingWrapperIter->second;
1486
1487 // type is ptr-to-global-type(void);
1488 auto funcType = cir::FuncType::get({}, builder.getPointerTo(op.getSymType()));
1489 cir::FuncOp func =
1490 cir::FuncOp::create(builder, op.getLoc(), wrapperName, funcType);
1491
1492 cir::GlobalLinkageKind linkageKind =
1493 getThreadLocalWrapperLinkage(op, *astCtx);
1494 func.setLinkageAttr(
1495 cir::GlobalLinkageKindAttr::get(&getContext(), linkageKind));
1496
1497 // TODO(cir): This is supposed to refer to the comdat of the global symbol,
1498 // but that isn't in CIR yet.
1499 if (astCtx->getTargetInfo().getTriple().supportsCOMDAT() &&
1500 func.isWeakForLinker())
1501 func.setComdat(true);
1502
1503 mlir::SymbolTable::setSymbolVisibility(
1504 func, mlir::SymbolTable::Visibility::Private);
1505
1506 if (!isLocalLinkage(linkageKind)) {
1507 if (!isThreadWrapperReplaceable(*op.getTlsModel(), *astCtx) ||
1508 isLinkOnceLinkage(linkageKind) || isWeakODRLinkage(linkageKind) ||
1509 op.getGlobalVisibility() == cir::VisibilityKind::Hidden)
1510 func.setGlobalVisibility(cir::VisibilityKind::Hidden);
1511 }
1512 if (isThreadWrapperReplaceable(*op.getTlsModel(), *astCtx))
1513 op->emitError("Unhandled thread wrapper attributes for CC and Nounwind");
1514
1515 threadLocalWrappers.insert({wrapperName.getValue(), func});
1516 return func;
1517}
1518
1519void LoweringPreparePass::defineGlobalThreadLocalWrapper(cir::GlobalOp op,
1520 cir::FuncOp initAlias,
1521 bool isVarDefinition) {
1522 CIRBaseBuilderTy builder(getContext());
1523 cir::FuncOp wrapper = getOrCreateThreadLocalWrapper(builder, op);
1524 mlir::Block *entryBB = wrapper.addEntryBlock();
1525 builder.setInsertionPointToStart(entryBB);
1526 // If we are a situation where we have/need one, emit a call to the init
1527 // function.
1528 if (initAlias) {
1529 mlir::Location aliasLoc = initAlias.getLoc();
1530 if (!isVarDefinition) {
1531 // If this isn't a definition, we have to check that the alias exists.
1532 mlir::Value funcLoad = cir::GetGlobalOp::create(
1533 builder, aliasLoc, cir::PointerType::get(initAlias.getFunctionType()),
1534 initAlias.getSymName());
1535 mlir::Value nullCheck =
1536 builder.getNullValue(funcLoad.getType(), aliasLoc);
1537 mlir::Value cmp = cir::CmpOp::create(
1538 builder, aliasLoc, cir::CmpOpKind::ne, funcLoad, nullCheck);
1539 cir::IfOp::create(builder, aliasLoc, cmp, /*withElseRegion=*/false,
1540 [&](mlir::OpBuilder &, mlir::Location loc) {
1541 builder.createCallOp(aliasLoc, initAlias, {});
1542 cir::YieldOp::create(builder, aliasLoc);
1543 });
1544 } else {
1545 // If this IS a definition, we know the alias exists, so we can just emit
1546 // a call to it.
1547 builder.createCallOp(aliasLoc, initAlias, {});
1548 }
1549 }
1550 cir::GetGlobalOp get = builder.createGetGlobal(op, /*tls=*/true);
1551 cir::ReturnOp::create(builder, op.getLoc(), {get});
1552}
1553
1554cir::FuncOp
1555LoweringPreparePass::defineGlobalThreadLocalInitAlias(cir::GlobalOp op,
1556 cir::FuncOp aliasee) {
1557 CIRBaseBuilderTy builder(getContext());
1558 mlir::OpBuilder::InsertionGuard insertGuard(builder);
1559 builder.setInsertionPointToStart(&mlirModule.getBodyRegion().front());
1560 mlir::StringAttr aliasName = op.getDynTlsRefs()->getInitName();
1561 auto existingAliasIter = threadLocalInitAliases.find(aliasName.getValue());
1562
1563 if (existingAliasIter != threadLocalInitAliases.end())
1564 return existingAliasIter->second;
1565
1566 auto funcType = builder.getVoidFnTy();
1567 cir::FuncOp alias =
1568 cir::FuncOp::create(builder, op.getLoc(), aliasName, funcType);
1569 alias.setLinkage(op.getLinkage());
1570
1571 if (aliasee) {
1572 alias.setAliasee(aliasee.getSymName());
1573 } else {
1574 // If we don't have anything to alias (because this isn't a variable
1575 // definition!), we set this as just a function definition with no alias,
1576 // and extern-weak.
1577 alias.setLinkage(cir::GlobalLinkageKind::ExternalWeakLinkage);
1578 mlir::SymbolTable::setSymbolVisibility(
1579 alias, mlir::SymbolTable::Visibility::Private);
1580 }
1581
1582 threadLocalInitAliases.insert({aliasName.getValue(), alias});
1583 return alias;
1584}
1585
1586void LoweringPreparePass::lowerGlobalOp(GlobalOp op) {
1587 // Static locals are handled separately via guard variables.
1588 if (op.getStaticLocalGuard())
1589 return;
1590
1591 mlir::Region &ctorRegion = op.getCtorRegion();
1592 mlir::Region &dtorRegion = op.getDtorRegion();
1593 cir::FuncOp initAlias;
1594
1595 if (!ctorRegion.empty() || !dtorRegion.empty()) {
1596 // Build a variable initialization function and move the initialzation code
1597 // in the ctor region over.
1598 cir::FuncOp f = buildCXXGlobalVarDeclInitFunc(op);
1599
1600 // Clear the ctor and dtor region
1601 ctorRegion.getBlocks().clear();
1602 dtorRegion.getBlocks().clear();
1603
1605 if (op.getTlsModel() == TLS_Model::GeneralDynamic &&
1606 !op.getStaticLocalGuard().has_value()) {
1607 // There are two types of global TLS variables: 'ordered' and 'unordered'.
1608 // 'ordered' are the common case. A call to any of them causes all of the
1609 // initializers for all other 'ordered' ones to be called, via a
1610 // `__tls_init` function. So the 'init alias' that gets called in the
1611 // wrapper for these goes directly to `__tls_init`.
1612
1613 // 'Unordered' values are the case for variable templates. In this case,
1614 // their init alias goes directly to their init function. The FE generates
1615 // a guard variable for them (since they cannot use the global guard), so
1616 // we differentiate them that way.
1617
1618 if (op.getDynTlsRefs()->getGuardName()) {
1619 // Unordered: the alias is the function we just generated.
1620 initAlias = defineGlobalThreadLocalInitAlias(op, f);
1621 } else {
1622 // Ordered: Get the __tls_init, and make the alias to that.
1623 initAlias = defineGlobalThreadLocalInitAlias(op, getTlsInitFn());
1624 // Ordered inits also need to get called from the __tls_init function,
1625 // so we add the init function to the list, so that we can add them to
1626 // it later.
1627 globalThreadLocalInitializers.push_back(f);
1628 }
1629 } else {
1630 dynamicInitializers.push_back(f);
1631 }
1632 } else if (op.getTlsModel() == TLS_Model::GeneralDynamic &&
1633 op.getDynTlsRefs() && op.isDeclaration()) {
1634 // If this is a declaration and has no init function, we probably DO have to
1635 // create an alias that needs checking, so create it as extern-weak.
1636 initAlias = defineGlobalThreadLocalInitAlias(op, {});
1637 }
1638
1639 // We need a wrapper for TLS globals that MIGHT have a non-constant
1640 // initialization. The FE will have generated the DynTlsRefs for any with
1641 // known dynamic init, or unknown (extern) init.
1642 if (op.getTlsModel() == TLS_Model::GeneralDynamic && op.getDynTlsRefs())
1643 defineGlobalThreadLocalWrapper(op, initAlias, !op.isDeclaration());
1644
1646}
1647
1648void LoweringPreparePass::lowerGetGlobalOp(GetGlobalOp op) {
1649 if (!op.getTls())
1650 return;
1651 auto globalOp = mlir::cast<cir::GlobalOp>(
1652 symbolTables.lookupNearestSymbolFrom(op, op.getNameAttr()));
1653
1654 // Only global/namespace scope thread local variables need to have their
1655 // get-global operations rewritten to be calls to a wrapper function. If
1656 // we're not in a dynamic TLS (or one without the TLS markers), we can leave
1657 // this one as a get-global and return early.
1658 if (globalOp.getTlsModel() != TLS_Model::GeneralDynamic ||
1659 !globalOp.getDynTlsRefs())
1660 return;
1661
1662 // If this is a global TLS, we need to replace the call to 'get_global' with a
1663 // call to the wrapper function. Classic codegen figures out some cases where
1664 // we can omit this, but for now we're going to always put it in, as it is
1665 // effectively a no-op.
1666
1667 // The first 'GetGlobalOp' at the beginning of a ctor/dtor region on one of
1668 // these is for the purpose of creating/destroying. We want to skip replacing
1669 // THAT one, but leave all other get-global-ops in place, else
1670 // self-referential ops won't work right.
1671
1672 // Note that ctors/dtors are removed during this pass. We get away with these
1673 // checks because the only time that these situations can actually be true
1674 // (that is, the ctor/dtor region exist) is if we're in the process of
1675 // converting the ctor/dtor for this. If we're NOT doing that, the ctor/dtor
1676 // will have already disappeared.
1677 mlir::Operation *parentOp = op->getParentOp();
1678 if (parentOp == globalOp) {
1679 mlir::Region *ctorRegion = &globalOp.getCtorRegion();
1680 mlir::Region *dtorRegion = &globalOp.getDtorRegion();
1681
1682 if (!ctorRegion->empty() && &*ctorRegion->op_begin() == op.getOperation())
1683 return;
1684 if (!dtorRegion->empty() && &*dtorRegion->op_begin() == op.getOperation())
1685 return;
1686 }
1687
1688 CIRBaseBuilderTy builder(getContext());
1689 cir::FuncOp wrapperFunc = getOrCreateThreadLocalWrapper(builder, globalOp);
1690
1691 builder.setInsertionPoint(op);
1692 cir::CallOp call = builder.createCallOp(
1693 wrapperFunc.getLoc(),
1694 mlir::FlatSymbolRefAttr::get(wrapperFunc.getSymNameAttr()),
1695 wrapperFunc.getFunctionType().getReturnType(), {});
1696 op->replaceAllUsesWith(call);
1697 op.erase();
1698}
1699
1700void LoweringPreparePass::lowerThreeWayCmpOp(CmpThreeWayOp op) {
1701 CIRBaseBuilderTy builder(getContext());
1702 builder.setInsertionPointAfter(op);
1703
1704 mlir::Location loc = op->getLoc();
1705 cir::CmpThreeWayInfoAttr cmpInfo = op.getInfo();
1706
1707 mlir::Value ltRes =
1708 builder.getConstantInt(loc, op.getType(), cmpInfo.getLt());
1709 mlir::Value eqRes =
1710 builder.getConstantInt(loc, op.getType(), cmpInfo.getEq());
1711 mlir::Value gtRes =
1712 builder.getConstantInt(loc, op.getType(), cmpInfo.getGt());
1713
1714 mlir::Value transformedResult;
1715 if (cmpInfo.getOrdering() != CmpOrdering::Partial) {
1716 // Total ordering
1717 mlir::Value lt =
1718 builder.createCompare(loc, CmpOpKind::lt, op.getLhs(), op.getRhs());
1719 mlir::Value selectOnLt = builder.createSelect(loc, lt, ltRes, gtRes);
1720 mlir::Value eq =
1721 builder.createCompare(loc, CmpOpKind::eq, op.getLhs(), op.getRhs());
1722 transformedResult = builder.createSelect(loc, eq, eqRes, selectOnLt);
1723 } else {
1724 // Partial ordering
1725 cir::ConstantOp unorderedRes = builder.getConstantInt(
1726 loc, op.getType(), cmpInfo.getUnordered().value());
1727
1728 mlir::Value eq =
1729 builder.createCompare(loc, CmpOpKind::eq, op.getLhs(), op.getRhs());
1730 mlir::Value selectOnEq = builder.createSelect(loc, eq, eqRes, unorderedRes);
1731 mlir::Value gt =
1732 builder.createCompare(loc, CmpOpKind::gt, op.getLhs(), op.getRhs());
1733 mlir::Value selectOnGt = builder.createSelect(loc, gt, gtRes, selectOnEq);
1734 mlir::Value lt =
1735 builder.createCompare(loc, CmpOpKind::lt, op.getLhs(), op.getRhs());
1736 transformedResult = builder.createSelect(loc, lt, ltRes, selectOnGt);
1737 }
1738
1739 op.replaceAllUsesWith(transformedResult);
1740 op.erase();
1741}
1742
1743template <typename AttributeTy>
1744static llvm::SmallVector<mlir::Attribute>
1745prepareCtorDtorAttrList(mlir::MLIRContext *context,
1746 llvm::ArrayRef<std::pair<std::string, uint32_t>> list) {
1748 for (const auto &[name, priority] : list)
1749 attrs.push_back(AttributeTy::get(context, name, priority));
1750 return attrs;
1751}
1752
1753void LoweringPreparePass::buildGlobalCtorDtorList() {
1754 if (!globalCtorList.empty()) {
1755 llvm::SmallVector<mlir::Attribute> globalCtors =
1757 globalCtorList);
1758
1759 mlirModule->setAttr(cir::CIRDialect::getGlobalCtorsAttrName(),
1760 mlir::ArrayAttr::get(&getContext(), globalCtors));
1761 }
1762
1763 if (!globalDtorList.empty()) {
1764 llvm::SmallVector<mlir::Attribute> globalDtors =
1766 globalDtorList);
1767 mlirModule->setAttr(cir::CIRDialect::getGlobalDtorsAttrName(),
1768 mlir::ArrayAttr::get(&getContext(), globalDtors));
1769 }
1770}
1771
1772cir::GlobalOp
1773LoweringPreparePass::createGlobalThreadLocalGuard(CIRBaseBuilderTy &builder,
1774 mlir::Location loc) {
1775 mlir::OpBuilder::InsertionGuard guard(builder);
1776 builder.setInsertionPointToStart(mlirModule.getBody());
1777
1778 // The TLS Guard is always an Int8Ty.
1779 cir::IntType guardTy = builder.getSIntNTy(8);
1780 auto g = cir::GlobalOp::create(builder, loc, "__tls_guard", guardTy);
1781 g.setLinkageAttr(cir::GlobalLinkageKindAttr::get(
1782 builder.getContext(), cir::GlobalLinkageKind::InternalLinkage));
1783 g.setAlignment(clang::CharUnits::One().getAsAlign().value());
1784 // At the moment, we only have implementation for this mode, as it is the
1785 // default. At one point we might need to load this mode from the module.
1786 g.setTlsModel(TLS_Model::GeneralDynamic);
1787 g.setInitialValueAttr(cir::IntAttr::get(guardTy, 0));
1788 return g;
1789}
1790
1791cir::IfOp LoweringPreparePass::buildGlobalTlsGuardCheck(
1792 CIRBaseBuilderTy &builder, mlir::Location loc, cir::GlobalOp guard) {
1793 cir::GetGlobalOp getGuard = builder.createGetGlobal(guard, /*tls=*/true);
1794 mlir::Value getGuardValue = getGuard;
1795
1796 // Classic codegen always just loads the first byte of the guard instead of
1797 // the whole thing. __tls_guard is already only 8 bits, but for the case of
1798 // unordered TLS, it gets created as 64 bits.
1799 if (guard.getSymType() != builder.getSIntNTy(8))
1800 getGuardValue = builder.createBitcast(
1801 getGuard, cir::PointerType::get(builder.getSIntNTy(8)));
1802
1803 mlir::Value guardLoad =
1804 builder.createAlignedLoad(loc, getGuardValue, *guard.getAlignment());
1805 auto zero = builder.getConstantInt(loc, builder.getSIntNTy(8), 0);
1806 cir::CmpOp compare =
1807 builder.createCompare(loc, cir::CmpOpKind::eq, guardLoad, zero);
1808 return cir::IfOp::create(
1809 builder, loc, compare,
1810 /*withElseRegion=*/false, [&](mlir::OpBuilder &, mlir::Location loc) {
1811 // Classic codegen still does this store as a i8, but it doesn't seem
1812 // reasonable to do an i8 store into a 64 bit value?
1813 builder.createStore(
1814 loc, builder.getConstantInt(loc, guard.getSymType(), 1), getGuard);
1815 });
1816}
1817
1818void LoweringPreparePass::buildCXXGlobalTlsFunc() {
1819 if (globalThreadLocalInitializers.empty())
1820 return;
1821
1822 // The global-ordered-init function for TLS variables just calls each of the
1823 // init-functions in order after doing a guard.
1824
1825 cir::FuncOp tlsInit = getTlsInitFn();
1826 mlir::Location loc = tlsInit.getLoc();
1827 CIRBaseBuilderTy builder(getContext());
1828 mlir::Block *entryBB = tlsInit.addEntryBlock();
1829 builder.setInsertionPointToStart(entryBB);
1830
1831 cir::IfOp ifOperation = buildGlobalTlsGuardCheck(
1832 builder, loc, createGlobalThreadLocalGuard(builder, loc));
1833
1834 // Emit the body of the guarded spot.
1835 builder.setInsertionPointToEnd(&ifOperation.getThenRegion().front());
1836 for (cir::FuncOp initFunc : globalThreadLocalInitializers)
1837 builder.createCallOp(loc, initFunc, {});
1838 cir::YieldOp::create(builder, loc);
1839
1840 builder.setInsertionPointAfter(ifOperation);
1841 cir::ReturnOp::create(builder, loc);
1842}
1843
1844void LoweringPreparePass::buildCXXGlobalInitFunc() {
1845 if (dynamicInitializers.empty())
1846 return;
1847
1848 // TODO: handle globals with a user-specified initialzation priority.
1849 // TODO: handle default priority more nicely.
1851
1852 SmallString<256> fnName;
1853 cir::GlobalLinkageKind linkage;
1854 // Include the filename in the symbol name. Including "sub_" matches gcc
1855 // and makes sure these symbols appear lexicographically behind the symbols
1856 // with priority (TBD). Module implementation units behave the same
1857 // way as a non-modular TU with imports.
1858 // TODO: check CXX20ModuleInits
1859 if (astCtx->getCurrentNamedModule() &&
1861 llvm::raw_svector_ostream out(fnName);
1862 std::unique_ptr<clang::MangleContext> mangleCtx(
1863 astCtx->createMangleContext());
1864 cast<clang::ItaniumMangleContext>(*mangleCtx)
1865 .mangleModuleInitializer(astCtx->getCurrentNamedModule(), out);
1866 linkage = cir::GlobalLinkageKind::ExternalLinkage;
1867 } else {
1868 fnName += "_GLOBAL__sub_I_";
1869 fnName += getTransformedFileName(mlirModule);
1870 linkage = cir::GlobalLinkageKind::InternalLinkage;
1871 }
1872
1873 CIRBaseBuilderTy builder(getContext());
1874 builder.setInsertionPointToEnd(&mlirModule.getBodyRegion().back());
1875 auto fnType = cir::FuncType::get({}, builder.getVoidTy());
1876 cir::FuncOp f = buildRuntimeFunction(builder, fnName, mlirModule.getLoc(),
1877 fnType, linkage);
1878 builder.setInsertionPointToStart(f.addEntryBlock());
1879 for (cir::FuncOp &f : dynamicInitializers)
1880 builder.createCallOp(f.getLoc(), f, {});
1881 // Add the global init function (not the individual ctor functions) to the
1882 // global ctor list.
1883 globalCtorList.emplace_back(fnName,
1884 cir::GlobalCtorAttr::getDefaultPriority());
1885
1886 cir::ReturnOp::create(builder, f.getLoc());
1887}
1888
1889/// Lower a cir.array.ctor or cir.array.dtor into a do-while loop that
1890/// iterates over every element. For cir.array.ctor ops whose partial_dtor
1891/// region is non-empty, the ctor loop is wrapped in a cir.cleanup.scope whose
1892/// EH cleanup performs a reverse destruction loop using the partial dtor body.
1894 clang::ASTContext *astCtx,
1895 mlir::Operation *op, mlir::Type eltTy,
1896 mlir::Value addr,
1897 mlir::Value numElements,
1898 uint64_t arrayLen, bool isCtor) {
1899 mlir::Location loc = op->getLoc();
1900 bool isDynamic = numElements != nullptr;
1901
1902 // TODO: instead of getting the size from the AST context, create alias for
1903 // PtrDiffTy and unify with CIRGen stuff.
1904 const unsigned sizeTypeSize =
1905 astCtx->getTypeSize(astCtx->getSignedSizeType());
1906
1907 // Both constructors and destructors use end = begin + numElements.
1908 // Constructors iterate forward [begin, end). Destructors iterate backward
1909 // from end, decrementing before calling the destructor on each element.
1910 mlir::Value begin, end;
1911 if (isDynamic) {
1912 begin = addr;
1913 end = cir::PtrStrideOp::create(builder, loc, eltTy, begin, numElements);
1914 } else {
1915 mlir::Value endOffsetVal =
1916 builder.getUnsignedInt(loc, arrayLen, sizeTypeSize);
1917 begin = cir::CastOp::create(builder, loc, eltTy,
1918 cir::CastKind::array_to_ptrdecay, addr);
1919 end = cir::PtrStrideOp::create(builder, loc, eltTy, begin, endOffsetVal);
1920 }
1921
1922 mlir::Value start = isCtor ? begin : end;
1923 mlir::Value stop = isCtor ? end : begin;
1924
1925 // For dynamic destructors, guard against zero elements.
1926 // This places the destructor loop emitted below inside the if block.
1927 cir::IfOp ifOp;
1928 if (isDynamic) {
1929 mlir::Value guardCond;
1930 if (isCtor) {
1931 mlir::Value zero = builder.getUnsignedInt(loc, 0, sizeTypeSize);
1932 guardCond = cir::CmpOp::create(builder, loc, cir::CmpOpKind::ne,
1933 numElements, zero);
1934 } else {
1935 // We could check for numElements != 0 in this case too, but this matches
1936 // what classic codegen does.
1937 guardCond =
1938 cir::CmpOp::create(builder, loc, cir::CmpOpKind::ne, start, stop);
1939 }
1940 ifOp = cir::IfOp::create(builder, loc, guardCond,
1941 /*withElseRegion=*/false,
1942 [&](mlir::OpBuilder &, mlir::Location) {});
1943 builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
1944 }
1945
1946 mlir::Value tmpAddr =
1947 builder.createAlloca(loc, /*addr type*/ builder.getPointerTo(eltTy),
1948 "__array_idx", builder.getAlignmentAttr(1));
1949 builder.createStore(loc, start, tmpAddr);
1950
1951 mlir::Block *bodyBlock = &op->getRegion(0).front();
1952
1953 // Clone the region body (ctor/dtor call and any setup ops like per-element
1954 // zero-init) into the loop, remapping the block argument to the current
1955 // element pointer.
1956 auto cloneRegionBodyInto = [&](mlir::Block *srcBlock,
1957 mlir::Value replacement) {
1958 mlir::IRMapping map;
1959 map.map(srcBlock->getArgument(0), replacement);
1960 for (mlir::Operation &regionOp : *srcBlock) {
1961 if (!mlir::isa<cir::YieldOp>(&regionOp))
1962 builder.clone(regionOp, map);
1963 }
1964 };
1965
1966 mlir::Block *partialDtorBlock = nullptr;
1967 if (auto arrayCtor = mlir::dyn_cast<cir::ArrayCtor>(op)) {
1968 mlir::Region &partialDtor = arrayCtor.getPartialDtor();
1969 if (!partialDtor.empty())
1970 partialDtorBlock = &partialDtor.front();
1971 } else if (auto arrayDtor = mlir::dyn_cast<cir::ArrayDtor>(op)) {
1972 // When the element destructor may throw, reuse the body block as the
1973 // partial-dtor block so that an exception thrown by an element's dtor
1974 // continues the reverse-destruction loop in the EH cleanup region. The
1975 // body block already stores the next element pointer to `tmpAddr`
1976 // before invoking the dtor, so when an exception unwinds from the
1977 // dtor call `tmpAddr` already points at the element that threw, and
1978 // the cleanup loop picks up from `tmpAddr - 1` and walks back to
1979 // `begin`.
1980 if (arrayDtor.getDtorMayThrow())
1981 partialDtorBlock = bodyBlock;
1982 }
1983
1984 auto emitCtorDtorLoop = [&]() {
1985 builder.createDoWhile(
1986 loc,
1987 /*condBuilder=*/
1988 [&](mlir::OpBuilder &b, mlir::Location loc) {
1989 auto currentElement = cir::LoadOp::create(b, loc, eltTy, tmpAddr);
1990 auto cmp = cir::CmpOp::create(builder, loc, cir::CmpOpKind::ne,
1991 currentElement, stop);
1992 builder.createCondition(cmp);
1993 },
1994 /*bodyBuilder=*/
1995 [&](mlir::OpBuilder &b, mlir::Location loc) {
1996 auto currentElement = cir::LoadOp::create(b, loc, eltTy, tmpAddr);
1997 if (isCtor) {
1998 cloneRegionBodyInto(bodyBlock, currentElement);
1999 mlir::Value stride = builder.getUnsignedInt(loc, 1, sizeTypeSize);
2000 auto nextElement = cir::PtrStrideOp::create(builder, loc, eltTy,
2001 currentElement, stride);
2002 builder.createStore(loc, nextElement, tmpAddr);
2003 } else {
2004 mlir::Value stride = builder.getSignedInt(loc, -1, sizeTypeSize);
2005 auto prevElement = cir::PtrStrideOp::create(builder, loc, eltTy,
2006 currentElement, stride);
2007 builder.createStore(loc, prevElement, tmpAddr);
2008 cloneRegionBodyInto(bodyBlock, prevElement);
2009 }
2010
2011 cir::YieldOp::create(b, loc);
2012 });
2013 };
2014
2015 if (partialDtorBlock) {
2016 cir::CleanupScopeOp::create(
2017 builder, loc, cir::CleanupKind::EH,
2018 /*bodyBuilder=*/
2019 [&](mlir::OpBuilder &b, mlir::Location loc) {
2020 emitCtorDtorLoop();
2021 cir::YieldOp::create(b, loc);
2022 },
2023 /*cleanupBuilder=*/
2024 [&](mlir::OpBuilder &b, mlir::Location loc) {
2025 auto cur = cir::LoadOp::create(b, loc, eltTy, tmpAddr);
2026 auto cmp =
2027 cir::CmpOp::create(builder, loc, cir::CmpOpKind::ne, cur, begin);
2028 cir::IfOp::create(
2029 builder, loc, cmp, /*withElseRegion=*/false,
2030 [&](mlir::OpBuilder &b, mlir::Location loc) {
2031 builder.createDoWhile(
2032 loc,
2033 /*condBuilder=*/
2034 [&](mlir::OpBuilder &b, mlir::Location loc) {
2035 auto el = cir::LoadOp::create(b, loc, eltTy, tmpAddr);
2036 auto neq = cir::CmpOp::create(
2037 builder, loc, cir::CmpOpKind::ne, el, begin);
2038 builder.createCondition(neq);
2039 },
2040 /*bodyBuilder=*/
2041 [&](mlir::OpBuilder &b, mlir::Location loc) {
2042 auto el = cir::LoadOp::create(b, loc, eltTy, tmpAddr);
2043 mlir::Value negOne =
2044 builder.getSignedInt(loc, -1, sizeTypeSize);
2045 auto prev = cir::PtrStrideOp::create(builder, loc, eltTy,
2046 el, negOne);
2047 builder.createStore(loc, prev, tmpAddr);
2048 cloneRegionBodyInto(partialDtorBlock, prev);
2049 builder.createYield(loc);
2050 });
2051 cir::YieldOp::create(builder, loc);
2052 });
2053 cir::YieldOp::create(b, loc);
2054 });
2055 } else {
2056 emitCtorDtorLoop();
2057 }
2058
2059 if (ifOp)
2060 cir::YieldOp::create(builder, loc);
2061
2062 op->erase();
2063}
2064
2065void LoweringPreparePass::lowerArrayDtor(cir::ArrayDtor op) {
2066 CIRBaseBuilderTy builder(getContext());
2067 builder.setInsertionPointAfter(op.getOperation());
2068
2069 mlir::Type eltTy = op->getRegion(0).getArgument(0).getType();
2070
2071 if (op.getNumElements()) {
2072 lowerArrayDtorCtorIntoLoop(builder, astCtx, op, eltTy, op.getAddr(),
2073 op.getNumElements(), /*arrayLen=*/0,
2074 /*isCtor=*/false);
2075 return;
2076 }
2077
2078 auto arrayLen =
2079 mlir::cast<cir::ArrayType>(op.getAddr().getType().getPointee()).getSize();
2080 lowerArrayDtorCtorIntoLoop(builder, astCtx, op, eltTy, op.getAddr(),
2081 /*numElements=*/nullptr, arrayLen,
2082 /*isCtor=*/false);
2083}
2084
2085void LoweringPreparePass::lowerArrayCtor(cir::ArrayCtor op) {
2086 cir::CIRBaseBuilderTy builder(getContext());
2087 builder.setInsertionPointAfter(op.getOperation());
2088
2089 mlir::Type eltTy = op->getRegion(0).getArgument(0).getType();
2090
2091 if (op.getNumElements()) {
2092 lowerArrayDtorCtorIntoLoop(builder, astCtx, op, eltTy, op.getAddr(),
2093 op.getNumElements(), /*arrayLen=*/0,
2094 /*isCtor=*/true);
2095 return;
2096 }
2097
2098 auto arrayLen =
2099 mlir::cast<cir::ArrayType>(op.getAddr().getType().getPointee()).getSize();
2100 lowerArrayDtorCtorIntoLoop(builder, astCtx, op, eltTy, op.getAddr(),
2101 /*numElements=*/nullptr, arrayLen,
2102 /*isCtor=*/true);
2103}
2104
2105cir::FuncOp LoweringPreparePass::getCalledFunction(cir::CallOp callOp) {
2106 mlir::SymbolRefAttr sym = llvm::dyn_cast_if_present<mlir::SymbolRefAttr>(
2107 callOp.getCallableForCallee());
2108 if (!sym)
2109 return nullptr;
2110 return symbolTables.lookupNearestSymbolFrom<cir::FuncOp>(callOp, sym);
2111}
2112
2113void LoweringPreparePass::lowerTrivialCopyCall(cir::CallOp op) {
2114 cir::FuncOp funcOp = getCalledFunction(op);
2115 if (!funcOp)
2116 return;
2117
2118 std::optional<cir::CtorKind> ctorKind = funcOp.getCxxConstructorKind();
2119 if (ctorKind && *ctorKind == cir::CtorKind::Copy &&
2120 funcOp.isCxxTrivialMemberFunction()) {
2121 // Replace the trivial copy constructor call with a `CopyOp`
2122 CIRBaseBuilderTy builder(getContext());
2123 mlir::ValueRange operands = op.getOperands();
2124 mlir::Value dest = operands[0];
2125 mlir::Value src = operands[1];
2126 builder.setInsertionPoint(op);
2127 builder.createCopy(dest, src);
2128 op.erase();
2129 }
2130}
2131
2132cir::GlobalOp LoweringPreparePass::getOrCreateConstAggregateGlobal(
2133 CIRBaseBuilderTy &builder, mlir::Location loc, llvm::StringRef baseName,
2134 mlir::Type ty, mlir::TypedAttr constant) {
2135 // Look up (and lazily populate) the per-base-name cache.
2136 llvm::SmallVector<cir::GlobalOp, 1> &versions =
2137 constAggregateGlobals[baseName];
2138
2139 // First, check globals we've already discovered for this base name.
2140 for (cir::GlobalOp gv : versions) {
2141 if (gv.getSymType() == ty && gv.getInitialValue() == constant)
2142 return gv;
2143 }
2144
2145 // No cached match. Scan the module's symbol table starting from the next
2146 // unscanned version. In practice this should usually exit on the first
2147 // iteration, but it's possible that some other pass or a previous
2148 // invocation of this pass created globals using this same logic.
2149 llvm::SmallString<128> name(baseName);
2150 size_t baseLen = name.size();
2151 unsigned version = versions.size();
2152 while (true) {
2153 name.resize(baseLen);
2154 if (version != 0) {
2155 name.push_back('.');
2156 llvm::Twine(version).toVector(name);
2157 }
2158 auto existingGv = symbolTables.lookupSymbolIn<cir::GlobalOp>(
2159 mlirModule, mlir::StringAttr::get(&getContext(), name));
2160 if (!existingGv)
2161 break;
2162 versions.push_back(existingGv);
2163 if (existingGv.getSymType() == ty &&
2164 existingGv.getInitialValue() == constant)
2165 return existingGv;
2166 ++version;
2167 }
2168
2169 // No match found, create a new global. The loop above found an unused name.
2170 mlir::OpBuilder::InsertionGuard guard(builder);
2171 builder.setInsertionPointToStart(mlirModule.getBody());
2172 auto gv =
2173 cir::GlobalOp::create(builder, loc, name, ty,
2174 /*isConstant=*/true,
2175 cir::LangAddressSpaceAttr::get(
2176 &getContext(), cir::LangAddressSpace::Default),
2177 cir::GlobalLinkageKind::PrivateLinkage);
2178 mlir::SymbolTable::setSymbolVisibility(
2179 gv, mlir::SymbolTable::Visibility::Private);
2180 gv.setInitialValueAttr(constant);
2181
2182 // Keep the cached symbol table in sync with the new global so subsequent
2183 // lookups for other base names find it.
2184 symbolTables.getSymbolTable(mlirModule).insert(gv);
2185
2186 versions.push_back(gv);
2187 return gv;
2188}
2189
2190void LoweringPreparePass::lowerStoreOfConstAggregate(cir::StoreOp op) {
2191 // Check if the value operand is a cir.const with aggregate type.
2192 auto constOp = op.getValue().getDefiningOp<cir::ConstantOp>();
2193 if (!constOp)
2194 return;
2195
2196 mlir::Type ty = constOp.getType();
2197 if (!mlir::isa<cir::ArrayType, cir::RecordType>(ty))
2198 return;
2199
2200 // Only transform stores to local variables (backed by cir.alloca).
2201 // Stores to other addresses (e.g. base_class_addr) should not be
2202 // transformed as they may be partial initializations.
2203 auto alloca = op.getAddr().getDefiningOp<cir::AllocaOp>();
2204 if (!alloca)
2205 return;
2206
2207 mlir::TypedAttr constant = constOp.getValue();
2208
2209 // OG implements several optimization tiers for constant aggregate
2210 // initialization. For now we always create a global constant + memcpy
2211 // (shouldCreateMemCpyFromGlobal). Future work can add the intermediate
2212 // tiers.
2216
2217 // Get function name from parent cir.func.
2218 auto func = op->getParentOfType<cir::FuncOp>();
2219 if (!func)
2220 return;
2221 llvm::StringRef funcName = func.getSymName();
2222
2223 // Get variable name from the alloca.
2224 llvm::StringRef varName = alloca.getName();
2225
2226 // Build base name: __const.<func>.<var>
2227 std::string baseName = ("__const." + funcName + "." + varName).str();
2228 CIRBaseBuilderTy builder(getContext());
2229
2230 // Check for existing globals and create a new global with a unique name
2231 // if no match is found.
2232 cir::GlobalOp gv = getOrCreateConstAggregateGlobal(builder, op.getLoc(),
2233 baseName, ty, constant);
2234
2235 // Now replace the store with get_global + copy.
2236 builder.setInsertionPoint(op);
2237
2238 auto ptrTy = cir::PointerType::get(ty);
2239 mlir::Value globalPtr =
2240 cir::GetGlobalOp::create(builder, op.getLoc(), ptrTy, gv.getSymName());
2241
2242 // Replace store with copy.
2243 builder.createCopy(op.getAddr(), globalPtr);
2244
2245 // Erase the original store.
2246 op.erase();
2247
2248 // Erase the cir.const if it has no remaining users.
2249 if (constOp.use_empty())
2250 constOp.erase();
2251}
2252
2253void LoweringPreparePass::runOnOp(mlir::Operation *op) {
2254 if (auto arrayCtor = dyn_cast<cir::ArrayCtor>(op)) {
2255 lowerArrayCtor(arrayCtor);
2256 } else if (auto arrayDtor = dyn_cast<cir::ArrayDtor>(op)) {
2257 lowerArrayDtor(arrayDtor);
2258 } else if (auto cast = mlir::dyn_cast<cir::CastOp>(op)) {
2259 lowerCastOp(cast);
2260 } else if (auto complexConj = mlir::dyn_cast<cir::ComplexConjOp>(op)) {
2261 lowerComplexConjOp(complexConj);
2262 } else if (auto complexDiv = mlir::dyn_cast<cir::ComplexDivOp>(op)) {
2263 lowerComplexDivOp(complexDiv);
2264 } else if (auto complexMul = mlir::dyn_cast<cir::ComplexMulOp>(op)) {
2265 lowerComplexMulOp(complexMul);
2266 } else if (auto glob = mlir::dyn_cast<cir::GlobalOp>(op)) {
2267 lowerGlobalOp(glob);
2268 if (auto regAttr = glob->getAttrOfType<CUDAVarRegistrationInfoAttr>(
2269 CUDAVarRegistrationInfoAttr::getMnemonic()))
2270 cudaDeviceVars.emplace_back(glob, regAttr);
2271 } else if (auto getGlob = mlir::dyn_cast<cir::GetGlobalOp>(op)) {
2272 lowerGetGlobalOp(getGlob);
2273 } else if (auto callOp = dyn_cast<cir::CallOp>(op)) {
2274 lowerTrivialCopyCall(callOp);
2275 } else if (auto storeOp = dyn_cast<cir::StoreOp>(op)) {
2276 lowerStoreOfConstAggregate(storeOp);
2277 } else if (auto fnOp = dyn_cast<cir::FuncOp>(op)) {
2278 if (auto globalCtor = fnOp.getGlobalCtorPriority())
2279 globalCtorList.emplace_back(fnOp.getName(), globalCtor.value());
2280 else if (auto globalDtor = fnOp.getGlobalDtorPriority())
2281 globalDtorList.emplace_back(fnOp.getName(), globalDtor.value());
2282
2283 if (mlir::Attribute attr =
2284 fnOp->getAttr(cir::CUDAKernelNameAttr::getMnemonic())) {
2285 auto kernelNameAttr = dyn_cast<CUDAKernelNameAttr>(attr);
2286 llvm::StringRef kernelName = kernelNameAttr.getKernelName();
2287 cudaKernelMap[kernelName] = fnOp;
2288 }
2289 } else if (auto threeWayCmp = dyn_cast<cir::CmpThreeWayOp>(op)) {
2290 lowerThreeWayCmpOp(threeWayCmp);
2291 } else if (auto initOp = dyn_cast<cir::LocalInitOp>(op)) {
2292 lowerLocalInitOp(initOp);
2293 }
2294}
2295
2296static llvm::StringRef getCUDAPrefix(clang::ASTContext *astCtx) {
2297 if (astCtx->getLangOpts().HIP)
2298 return "hip";
2299 return "cuda";
2300}
2301
2302static std::string addUnderscoredPrefix(llvm::StringRef prefix,
2303 llvm::StringRef name) {
2304 return ("__" + prefix + name).str();
2305}
2306
2307/// Creates a global constructor function for the module:
2308///
2309/// For CUDA:
2310/// \code
2311/// void __cuda_module_ctor() {
2312/// Handle = __cudaRegisterFatBinary(GpuBinaryBlob);
2313/// __cuda_register_globals(Handle);
2314/// }
2315/// \endcode
2316///
2317/// For HIP:
2318/// \code
2319/// void __hip_module_ctor() {
2320/// if (__hip_gpubin_handle == 0) {
2321/// __hip_gpubin_handle = __hipRegisterFatBinary(GpuBinaryBlob);
2322/// __hip_register_globals(__hip_gpubin_handle);
2323/// }
2324/// }
2325/// \endcode
2326void LoweringPreparePass::buildCUDAModuleCtor() {
2327 bool isHIP = astCtx->getLangOpts().HIP;
2328
2329 if (astCtx->getLangOpts().GPURelocatableDeviceCode)
2330 llvm_unreachable("GPU RDC NYI");
2331
2332 // For CUDA without -fgpu-rdc, it's safe to stop generating ctor
2333 // if there's nothing to register.
2334 if (cudaKernelMap.empty() && cudaDeviceVars.empty())
2335 return;
2336
2337 // There's no device-side binary, so no need to proceed for CUDA.
2338 // HIP has to create an external symbol in this case, which is NYI.
2339 mlir::Attribute cudaBinaryHandleAttr =
2340 mlirModule->getAttr(CIRDialect::getCUDABinaryHandleAttrName());
2341 if (!cudaBinaryHandleAttr) {
2342 if (isHIP)
2344 return;
2345 }
2346
2347 llvm::StringRef cudaGPUBinaryName =
2348 mlir::cast<CUDABinaryHandleAttr>(cudaBinaryHandleAttr)
2349 .getName()
2350 .getValue();
2351
2352 llvm::vfs::FileSystem &vfs =
2354 llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> gpuBinaryOrErr =
2355 vfs.getBufferForFile(cudaGPUBinaryName);
2356 if (std::error_code ec = gpuBinaryOrErr.getError()) {
2357 mlirModule->emitError("cannot open GPU binary file: " + cudaGPUBinaryName +
2358 ": " + ec.message());
2359 return;
2360 }
2361 std::unique_ptr<llvm::MemoryBuffer> gpuBinary =
2362 std::move(gpuBinaryOrErr.get());
2363
2364 // Set up common types and builder.
2365 llvm::StringRef cudaPrefix = getCUDAPrefix(astCtx);
2366 mlir::Location loc = mlirModule->getLoc();
2367 CIRBaseBuilderTy builder(getContext());
2368 builder.setInsertionPointToStart(mlirModule.getBody());
2369
2370 Type voidTy = builder.getVoidTy();
2371 PointerType voidPtrTy = builder.getVoidPtrTy();
2372 PointerType voidPtrPtrTy = builder.getPointerTo(voidPtrTy);
2373 IntType intTy = builder.getSIntNTy(32);
2374 IntType charTy = cir::IntType::get(&getContext(), astCtx->getCharWidth(),
2375 /*isSigned=*/false);
2376
2377 // --- Create fatbin globals ---
2378
2379 // The section names are different for MAC OS X.
2380 llvm::StringRef fatbinConstName =
2381 astCtx->getLangOpts().HIP ? ".hip_fatbin" : ".nv_fatbin";
2382
2383 llvm::StringRef fatbinSectionName =
2384 astCtx->getLangOpts().HIP ? ".hipFatBinSegment" : ".nvFatBinSegment";
2385
2386 // Create the fatbin string constant with GPU binary contents.
2387 auto fatbinType =
2388 ArrayType::get(&getContext(), charTy, gpuBinary->getBuffer().size());
2389 std::string fatbinStrName = addUnderscoredPrefix(cudaPrefix, "_fatbin_str");
2390 GlobalOp fatbinStr = GlobalOp::create(builder, loc, fatbinStrName, fatbinType,
2391 /*isConstant=*/true, {},
2392 GlobalLinkageKind::PrivateLinkage);
2393 fatbinStr.setAlignment(8);
2394 fatbinStr.setInitialValueAttr(cir::ConstArrayAttr::get(
2395 fatbinType, StringAttr::get(gpuBinary->getBuffer(), fatbinType)));
2396 fatbinStr.setSection(fatbinConstName);
2397 fatbinStr.setPrivate();
2398
2399 // Create the fatbin wrapper struct:
2400 // struct { int magic; int version; void *fatbin; void *unused; };
2401 auto fatbinWrapperType = cir::StructType::get(
2402 &getContext(), {intTy, intTy, voidPtrTy, voidPtrTy},
2403 /*packed=*/false, /*padded=*/false, /*is_class=*/false);
2404 std::string fatbinWrapperName =
2405 addUnderscoredPrefix(cudaPrefix, "_fatbin_wrapper");
2406 GlobalOp fatbinWrapper = GlobalOp::create(
2407 builder, loc, fatbinWrapperName, fatbinWrapperType,
2408 /*isConstant=*/true, {}, GlobalLinkageKind::PrivateLinkage);
2409 fatbinWrapper.setSection(fatbinSectionName);
2410
2411 constexpr unsigned cudaFatMagic = 0x466243b1;
2412 constexpr unsigned hipFatMagic = 0x48495046;
2413 unsigned fatMagic = isHIP ? hipFatMagic : cudaFatMagic;
2414
2415 auto magicInit = IntAttr::get(intTy, fatMagic);
2416 auto versionInit = IntAttr::get(intTy, 1);
2417 auto fatbinStrSymbol =
2418 mlir::FlatSymbolRefAttr::get(fatbinStr.getSymNameAttr());
2419 auto fatbinInit = GlobalViewAttr::get(voidPtrTy, fatbinStrSymbol);
2420 mlir::TypedAttr unusedInit = builder.getConstNullPtrAttr(voidPtrTy);
2421 fatbinWrapper.setInitialValueAttr(cir::ConstRecordAttr::get(
2422 fatbinWrapperType,
2423 mlir::ArrayAttr::get(&getContext(),
2424 {magicInit, versionInit, fatbinInit, unusedInit})));
2425
2426 // Create the GPU binary handle global variable.
2427 std::string gpubinHandleName =
2428 addUnderscoredPrefix(cudaPrefix, "_gpubin_handle");
2429
2430 GlobalOp gpuBinHandle = GlobalOp::create(
2431 builder, loc, gpubinHandleName, voidPtrPtrTy,
2432 /*isConstant=*/false, {}, cir::GlobalLinkageKind::InternalLinkage);
2433 gpuBinHandle.setInitialValueAttr(builder.getConstNullPtrAttr(voidPtrPtrTy));
2434 gpuBinHandle.setPrivate();
2435
2436 // Declare this function:
2437 // void **__{cuda|hip}RegisterFatBinary(void *);
2438
2439 std::string regFuncName =
2440 addUnderscoredPrefix(cudaPrefix, "RegisterFatBinary");
2441 FuncType regFuncType = FuncType::get({voidPtrTy}, voidPtrPtrTy);
2442 cir::FuncOp regFunc =
2443 buildRuntimeFunction(builder, regFuncName, loc, regFuncType);
2444
2445 std::string moduleCtorName = addUnderscoredPrefix(cudaPrefix, "_module_ctor");
2446 cir::FuncOp moduleCtor = buildRuntimeFunction(
2447 builder, moduleCtorName, loc, FuncType::get({}, voidTy),
2448 GlobalLinkageKind::InternalLinkage);
2449
2450 globalCtorList.emplace_back(moduleCtorName,
2451 cir::GlobalCtorAttr::getDefaultPriority());
2452 builder.setInsertionPointToStart(moduleCtor.addEntryBlock());
2454 if (isHIP) {
2455 // --- Create HIP CTOR ---
2456 // if (__hip_gpubin_handle == nullptr)
2457 // __hip_gpubin_handle = __hipRegisterFatBinary(&fatbinWrapper);
2458 // __hip_register_globals(__hip_gpubin_handle);
2459 // atexit(__hip_module_dtor);
2460 mlir::Block *entryBlock = builder.getInsertionBlock();
2461 mlir::Region *parent = entryBlock->getParent();
2462 mlir::Block *ifBlock = builder.createBlock(parent);
2463 mlir::Block *exitBlock = builder.createBlock(parent);
2464 {
2465 mlir::OpBuilder::InsertionGuard guard(builder);
2466 builder.setInsertionPointToEnd(entryBlock);
2467 mlir::Value handle =
2468 builder.createLoad(loc, builder.createGetGlobal(gpuBinHandle));
2469 auto handlePtrTy = mlir::cast<cir::PointerType>(handle.getType());
2470 mlir::Value nullPtr = builder.getNullPtr(handlePtrTy, loc);
2471 mlir::Value isNull =
2472 builder.createCompare(loc, cir::CmpOpKind::eq, handle, nullPtr);
2473 cir::BrCondOp::create(builder, loc, isNull, ifBlock, exitBlock);
2474 }
2475 {
2476 // Handle is null: load the fatbin and register it.
2477 mlir::OpBuilder::InsertionGuard guard(builder);
2478 builder.setInsertionPointToStart(ifBlock);
2479 mlir::Value wrapper = builder.createGetGlobal(fatbinWrapper);
2480 mlir::Value fatbinVoidPtr = builder.createBitcast(wrapper, voidPtrTy);
2481 cir::CallOp gpuBinaryHandleCall =
2482 builder.createCallOp(loc, regFunc, fatbinVoidPtr);
2483 mlir::Value gpuBinaryHandle = gpuBinaryHandleCall.getResult();
2484 // Store the value back to the global `__hip_gpubin_handle`.
2485 mlir::Value gpuBinaryHandleGlobal = builder.createGetGlobal(gpuBinHandle);
2486 builder.createStore(loc, gpuBinaryHandle, gpuBinaryHandleGlobal);
2487 cir::BrOp::create(builder, loc, exitBlock);
2488 }
2489 {
2490 // Exit block: load the (possibly newly-registered) handle, call
2491 // __hip_register_globals, and register the module dtor with atexit().
2492 mlir::OpBuilder::InsertionGuard guard(builder);
2493 builder.setInsertionPointToStart(exitBlock);
2494 mlir::Value gHandle =
2495 builder.createLoad(loc, builder.createGetGlobal(gpuBinHandle));
2496
2497 if (std::optional<FuncOp> regGlobal = buildCUDARegisterGlobals())
2498 builder.createCallOp(loc, *regGlobal, gHandle);
2499
2500 if (std::optional<FuncOp> dtor = buildHIPModuleDtor()) {
2501 cir::CIRBaseBuilderTy globalBuilder(getContext());
2502 globalBuilder.setInsertionPointToStart(mlirModule.getBody());
2503 FuncOp atexit = buildRuntimeFunction(
2504 globalBuilder, "atexit", loc,
2505 FuncType::get(PointerType::get(dtor->getFunctionType()), intTy));
2506 mlir::Value dtorFunc = GetGlobalOp::create(
2507 builder, loc, PointerType::get(dtor->getFunctionType()),
2508 mlir::FlatSymbolRefAttr::get(dtor->getSymNameAttr()));
2509 builder.createCallOp(loc, atexit, dtorFunc);
2510 }
2511 cir::ReturnOp::create(builder, loc);
2512 }
2513 return;
2514 }
2515 if (!astCtx->getLangOpts().GPURelocatableDeviceCode) {
2516
2517 // --- Create CUDA CTOR-DTOR ---
2518 // Register binary with CUDA runtime. This is substantially different in
2519 // default mode vs. separate compilation.
2520 // Corresponding code:
2521 // gpuBinaryHandle = __cudaRegisterFatBinary(&fatbinWrapper);
2522 mlir::Value wrapper = builder.createGetGlobal(fatbinWrapper);
2523 mlir::Value fatbinVoidPtr = builder.createBitcast(wrapper, voidPtrTy);
2524 cir::CallOp gpuBinaryHandleCall =
2525 builder.createCallOp(loc, regFunc, fatbinVoidPtr);
2526 mlir::Value gpuBinaryHandle = gpuBinaryHandleCall.getResult();
2527 // Store the value back to the global `__cuda_gpubin_handle`.
2528 mlir::Value gpuBinaryHandleGlobal = builder.createGetGlobal(gpuBinHandle);
2529 builder.createStore(loc, gpuBinaryHandle, gpuBinaryHandleGlobal);
2530
2531 // --- Generate __cuda_register_globals and call it ---
2532 if (std::optional<FuncOp> regGlobal = buildCUDARegisterGlobals()) {
2533 builder.createCallOp(loc, *regGlobal, gpuBinaryHandle);
2534 }
2535
2536 // From CUDA 10.1 onwards, we must call this function to end registration:
2537 // void __cudaRegisterFatBinaryEnd(void **fatbinHandle);
2538 // This is CUDA-specific, so no need to use `addUnderscoredPrefix`.
2540 astCtx->getTargetInfo().getSDKVersion(),
2542 cir::CIRBaseBuilderTy globalBuilder(getContext());
2543 globalBuilder.setInsertionPointToStart(mlirModule.getBody());
2544 FuncOp endFunc =
2545 buildRuntimeFunction(globalBuilder, "__cudaRegisterFatBinaryEnd", loc,
2546 FuncType::get({voidPtrPtrTy}, voidTy));
2547 builder.createCallOp(loc, endFunc, gpuBinaryHandle);
2548 }
2549 } else
2550 llvm_unreachable("GPU RDC NYI");
2551
2552 // Create destructor and register it with atexit() the way NVCC does it. Doing
2553 // it during regular destructor phase worked in CUDA before 9.2 but results in
2554 // double-free in 9.2.
2555 if (std::optional<FuncOp> dtor = buildCUDAModuleDtor()) {
2556
2557 // extern "C" int atexit(void (*f)(void));
2558 cir::CIRBaseBuilderTy globalBuilder(getContext());
2559 globalBuilder.setInsertionPointToStart(mlirModule.getBody());
2560 FuncOp atexit = buildRuntimeFunction(
2561 globalBuilder, "atexit", loc,
2562 FuncType::get(PointerType::get(dtor->getFunctionType()), intTy));
2563 mlir::Value dtorFunc = GetGlobalOp::create(
2564 builder, loc, PointerType::get(dtor->getFunctionType()),
2565 mlir::FlatSymbolRefAttr::get(dtor->getSymNameAttr()));
2566 builder.createCallOp(loc, atexit, dtorFunc);
2567 }
2568 cir::ReturnOp::create(builder, loc);
2569}
2570
2571std::optional<FuncOp> LoweringPreparePass::buildCUDAModuleDtor() {
2572 if (!mlirModule->getAttr(CIRDialect::getCUDABinaryHandleAttrName()))
2573 return {};
2574
2575 llvm::StringRef prefix = getCUDAPrefix(astCtx);
2576
2577 VoidType voidTy = VoidType::get(&getContext());
2578 PointerType voidPtrPtrTy = PointerType::get(PointerType::get(voidTy));
2579
2580 mlir::Location loc = mlirModule.getLoc();
2581
2582 cir::CIRBaseBuilderTy builder(getContext());
2583 builder.setInsertionPointToStart(mlirModule.getBody());
2584
2585 // define: void __cudaUnregisterFatBinary(void ** handle);
2586 std::string unregisterFuncName =
2587 addUnderscoredPrefix(prefix, "UnregisterFatBinary");
2588 FuncOp unregisterFunc = buildRuntimeFunction(
2589 builder, unregisterFuncName, loc, FuncType::get({voidPtrPtrTy}, voidTy));
2590
2591 // void __cuda_module_dtor();
2592 // Despite the name, OG doesn't treat it as a destructor, so it shouldn't be
2593 // put into globalDtorList. If it were a real dtor, then it would cause
2594 // double free above CUDA 9.2. The way to use it is to manually call
2595 // atexit() at end of module ctor.
2596 std::string dtorName = addUnderscoredPrefix(prefix, "_module_dtor");
2597 FuncOp dtor =
2598 buildRuntimeFunction(builder, dtorName, loc, FuncType::get({}, voidTy),
2599 GlobalLinkageKind::InternalLinkage);
2600
2601 builder.setInsertionPointToStart(dtor.addEntryBlock());
2602
2603 // For dtor, we only need to call:
2604 // __cudaUnregisterFatBinary(__cuda_gpubin_handle);
2605
2606 std::string gpubinName = addUnderscoredPrefix(prefix, "_gpubin_handle");
2607 GlobalOp gpubinGlobal = cast<GlobalOp>(mlirModule.lookupSymbol(gpubinName));
2608 mlir::Value gpubinAddress = builder.createGetGlobal(gpubinGlobal);
2609 mlir::Value gpubin = builder.createLoad(loc, gpubinAddress);
2610 builder.createCallOp(loc, unregisterFunc, gpubin);
2611 ReturnOp::create(builder, loc);
2612
2613 return dtor;
2614}
2615
2616/// Build the HIP module dtor:
2617///
2618/// void __hip_module_dtor() {
2619/// if (__hip_gpubin_handle != nullptr) {
2620/// __hipUnregisterFatBinary(__hip_gpubin_handle);
2621/// __hip_gpubin_handle = nullptr;
2622/// }
2623/// }
2624///
2625/// Despite the name, OG doesn't treat this as a real destructor: putting it on
2626/// the dtor list would cause a double-free. It is meant to be registered via
2627/// atexit() at the end of the module ctor.
2628std::optional<FuncOp> LoweringPreparePass::buildHIPModuleDtor() {
2629 if (!mlirModule->getAttr(CIRDialect::getCUDABinaryHandleAttrName()))
2630 return {};
2631
2632 llvm::StringRef prefix = getCUDAPrefix(astCtx);
2633
2634 VoidType voidTy = VoidType::get(&getContext());
2635 PointerType voidPtrPtrTy = PointerType::get(PointerType::get(voidTy));
2636
2637 mlir::Location loc = mlirModule.getLoc();
2638
2639 cir::CIRBaseBuilderTy builder(getContext());
2640 builder.setInsertionPointToStart(mlirModule.getBody());
2641
2642 // void __hipUnregisterFatBinary(void ** handle);
2643 std::string unregisterFuncName =
2644 addUnderscoredPrefix(prefix, "UnregisterFatBinary");
2645 FuncOp unregisterFunc = buildRuntimeFunction(
2646 builder, unregisterFuncName, loc, FuncType::get({voidPtrPtrTy}, voidTy));
2647
2648 std::string dtorName = addUnderscoredPrefix(prefix, "_module_dtor");
2649 FuncOp dtor =
2650 buildRuntimeFunction(builder, dtorName, loc, FuncType::get({}, voidTy),
2651 GlobalLinkageKind::InternalLinkage);
2652
2653 std::string gpubinName = addUnderscoredPrefix(prefix, "_gpubin_handle");
2654 GlobalOp gpuBinGlobal = cast<GlobalOp>(mlirModule.lookupSymbol(gpubinName));
2655
2656 mlir::Block *entryBlock = dtor.addEntryBlock();
2657 mlir::Block *ifBlock = builder.createBlock(&dtor.getBody());
2658 mlir::Block *exitBlock = builder.createBlock(&dtor.getBody());
2659
2660 mlir::OpBuilder::InsertionGuard guard(builder);
2661 builder.setInsertionPointToEnd(entryBlock);
2662 mlir::Value handle =
2663 builder.createLoad(loc, builder.createGetGlobal(gpuBinGlobal));
2664 auto handlePtrTy = mlir::cast<cir::PointerType>(handle.getType());
2665 mlir::Value nullPtr = builder.getNullPtr(handlePtrTy, loc);
2666 mlir::Value isNotNull =
2667 builder.createCompare(loc, cir::CmpOpKind::ne, handle, nullPtr);
2668 cir::BrCondOp::create(builder, loc, isNotNull, ifBlock, exitBlock);
2669
2670 {
2671 // Handle is non-null: unregister and clear it.
2672 mlir::OpBuilder::InsertionGuard ifGuard(builder);
2673 builder.setInsertionPointToStart(ifBlock);
2674 builder.createCallOp(loc, unregisterFunc, handle);
2675 builder.createStore(loc, nullPtr, builder.createGetGlobal(gpuBinGlobal));
2676 cir::BrOp::create(builder, loc, exitBlock);
2677 }
2678 {
2679 mlir::OpBuilder::InsertionGuard exitGuard(builder);
2680 builder.setInsertionPointToStart(exitBlock);
2681 cir::ReturnOp::create(builder, loc);
2682 }
2683
2684 return dtor;
2685}
2686
2687std::optional<FuncOp> LoweringPreparePass::buildCUDARegisterGlobals() {
2688 if (cudaKernelMap.empty() && cudaDeviceVars.empty())
2689 return {};
2690
2691 cir::CIRBaseBuilderTy builder(getContext());
2692 builder.setInsertionPointToStart(mlirModule.getBody());
2693
2694 mlir::Location loc = mlirModule.getLoc();
2695 llvm::StringRef cudaPrefix = getCUDAPrefix(astCtx);
2696
2697 auto voidTy = VoidType::get(&getContext());
2698 auto voidPtrTy = PointerType::get(voidTy);
2699 auto voidPtrPtrTy = PointerType::get(voidPtrTy);
2700
2701 // Create the function:
2702 // void __cuda_register_globals(void **fatbinHandle)
2703 std::string regGlobalFuncName =
2704 addUnderscoredPrefix(cudaPrefix, "_register_globals");
2705 auto regGlobalFuncTy = FuncType::get({voidPtrPtrTy}, voidTy);
2706 FuncOp regGlobalFunc =
2707 buildRuntimeFunction(builder, regGlobalFuncName, loc, regGlobalFuncTy,
2708 /*linkage=*/GlobalLinkageKind::InternalLinkage);
2709 builder.setInsertionPointToStart(regGlobalFunc.addEntryBlock());
2710
2711 buildCUDARegisterGlobalFunctions(builder, regGlobalFunc);
2712 buildCUDARegisterVars(builder, regGlobalFunc);
2713
2714 ReturnOp::create(builder, loc);
2715 return regGlobalFunc;
2716}
2717
2718void LoweringPreparePass::buildCUDARegisterGlobalFunctions(
2719 cir::CIRBaseBuilderTy &builder, FuncOp regGlobalFunc) {
2720 mlir::Location loc = mlirModule.getLoc();
2721 llvm::StringRef cudaPrefix = getCUDAPrefix(astCtx);
2722 cir::CIRDataLayout dataLayout(mlirModule);
2723
2724 auto voidTy = VoidType::get(&getContext());
2725 auto voidPtrTy = PointerType::get(voidTy);
2726 auto voidPtrPtrTy = PointerType::get(voidPtrTy);
2727 IntType intTy = builder.getSIntNTy(32);
2728 IntType charTy = cir::IntType::get(&getContext(), astCtx->getCharWidth(),
2729 /*isSigned=*/false);
2730
2731 // Extract the GPU binary handle argument.
2732 mlir::Value fatbinHandle = *regGlobalFunc.args_begin();
2733
2734 cir::CIRBaseBuilderTy globalBuilder(getContext());
2735 globalBuilder.setInsertionPointToStart(mlirModule.getBody());
2736
2737 // Declare CUDA internal functions:
2738 // int __cudaRegisterFunction(
2739 // void **fatbinHandle,
2740 // const char *hostFunc,
2741 // char *deviceFunc,
2742 // const char *deviceName,
2743 // int threadLimit,
2744 // uint3 *tid, uint3 *bid, dim3 *bDim, dim3 *gDim,
2745 // int *wsize
2746 // )
2747 // OG doesn't care about the types at all. They're treated as void*.
2748
2749 FuncOp cudaRegisterFunction = buildRuntimeFunction(
2750 globalBuilder, addUnderscoredPrefix(cudaPrefix, "RegisterFunction"), loc,
2751 FuncType::get({voidPtrPtrTy, voidPtrTy, voidPtrTy, voidPtrTy, intTy,
2752 voidPtrTy, voidPtrTy, voidPtrTy, voidPtrTy, voidPtrTy},
2753 intTy));
2754
2755 auto makeConstantString = [&](llvm::StringRef str) -> GlobalOp {
2756 auto strType = ArrayType::get(&getContext(), charTy, 1 + str.size());
2757 auto tmpString = cir::GlobalOp::create(
2758 globalBuilder, loc, (".str" + str).str(), strType,
2759 /*isConstant=*/true, {},
2760 /*linkage=*/cir::GlobalLinkageKind::PrivateLinkage);
2761
2762 // We must make the string zero-terminated.
2763 tmpString.setInitialValueAttr(
2764 ConstArrayAttr::get(strType, StringAttr::get(str + "\0", strType)));
2765 tmpString.setPrivate();
2766 return tmpString;
2767 };
2768
2769 cir::ConstantOp cirNullPtr = builder.getNullPtr(voidPtrTy, loc);
2770 bool isHIP = astCtx->getLangOpts().HIP;
2771 for (auto kernelName : cudaKernelMap.keys()) {
2772 FuncOp deviceStub = cudaKernelMap[kernelName];
2773 GlobalOp deviceFuncStr = makeConstantString(kernelName);
2774 mlir::Value deviceFunc = builder.createBitcast(
2775 builder.createGetGlobal(deviceFuncStr), voidPtrTy);
2776
2777 mlir::Value hostFunc;
2778 if (isHIP) {
2779 // Under HIP, the kernel-handle is a GlobalOp shadow created by CIR
2780 // codegen and named with the kernel-reference mangled name (e.g.
2781 // `@_Z2fnv` pointing at the device-stub function
2782 // `_Z17__device_stub__fnv`). The CUDAKernelNameAttr on the device-stub
2783 // uses the same name, so we can resolve the shadow by symbol lookup.
2784 auto funcHandle = cast<GlobalOp>(mlirModule.lookupSymbol(kernelName));
2785 hostFunc =
2786 builder.createBitcast(builder.createGetGlobal(funcHandle), voidPtrTy);
2787 } else {
2788 hostFunc = builder.createBitcast(
2789 GetGlobalOp::create(
2790 builder, loc, PointerType::get(deviceStub.getFunctionType()),
2791 mlir::FlatSymbolRefAttr::get(deviceStub.getSymNameAttr())),
2792 voidPtrTy);
2793 }
2794 builder.createCallOp(
2795 loc, cudaRegisterFunction,
2796 {fatbinHandle, hostFunc, deviceFunc, deviceFunc,
2797 ConstantOp::create(builder, loc, IntAttr::get(intTy, -1)), cirNullPtr,
2798 cirNullPtr, cirNullPtr, cirNullPtr, cirNullPtr});
2799 }
2800}
2801
2802// Emit `__{cuda|hip}RegisterVar` calls inside `__{cuda|hip}_register_globals`
2803// for every device-side shadow that carries a `cu.var_registration` attribute
2804// (attached by `CIRGenNVCUDARuntime::handleVarRegistration`).
2805void LoweringPreparePass::buildCUDARegisterVars(cir::CIRBaseBuilderTy &builder,
2806 FuncOp regGlobalFunc) {
2807 mlir::Location loc = mlirModule.getLoc();
2808 llvm::StringRef cudaPrefix = getCUDAPrefix(astCtx);
2809 cir::CIRDataLayout dataLayout(mlirModule);
2810
2811 PointerType voidPtrTy = builder.getVoidPtrTy();
2812 PointerType voidPtrPtrTy = builder.getPointerTo(voidPtrTy);
2813 IntType intTy = builder.getSIntNTy(32);
2814 IntType sizeTy =
2815 builder.getUIntNTy(astCtx->getTargetInfo().getMaxPointerWidth());
2816 IntType charTy = cir::IntType::get(&getContext(), astCtx->getCharWidth(),
2817 /*isSigned=*/false);
2818
2819 if (cudaDeviceVars.empty())
2820 return;
2821
2822 cir::CIRBaseBuilderTy globalBuilder(getContext());
2823 globalBuilder.setInsertionPointToStart(mlirModule.getBody());
2824
2825 // void __{cuda|hip}RegisterVar(void **fatbinHandle,
2826 // char *hostVar, char *deviceAddress,
2827 // const char *deviceName, int ext,
2828 // size_t size, int constant, int normalized);
2829 // OG ignores parameter types, treating pointers as void*.
2830 cir::VoidType voidTy = builder.getVoidTy();
2831 FuncOp cudaRegisterVar = buildRuntimeFunction(
2832 globalBuilder, addUnderscoredPrefix(cudaPrefix, "RegisterVar"), loc,
2833 FuncType::get({voidPtrPtrTy, voidPtrTy, voidPtrTy, voidPtrTy, intTy,
2834 sizeTy, intTy, intTy},
2835 voidTy));
2836
2837 auto makeConstantString = [&](llvm::StringRef str) -> GlobalOp {
2838 auto strType = ArrayType::get(&getContext(), charTy, 1 + str.size());
2839 auto tmpString = cir::GlobalOp::create(
2840 globalBuilder, loc, (".str" + str).str(), strType,
2841 /*isConstant=*/true, {},
2842 /*linkage=*/cir::GlobalLinkageKind::PrivateLinkage);
2843 tmpString.setInitialValueAttr(
2844 ConstArrayAttr::get(strType, StringAttr::get(str + "\0", strType)));
2845 tmpString.setPrivate();
2846 return tmpString;
2847 };
2848
2849 mlir::Value fatbinHandle = *regGlobalFunc.args_begin();
2850
2851 for (auto &[global, regAttr] : cudaDeviceVars) {
2852 switch (regAttr.getKind()) {
2853 case cir::CUDADeviceVarKind::Variable:
2854 break;
2855 case cir::CUDADeviceVarKind::Surface:
2856 llvm_unreachable("Surface registration NYI");
2857 case cir::CUDADeviceVarKind::Texture:
2858 llvm_unreachable("Texture registration NYI");
2859 }
2860
2861 if (regAttr.getIsManaged())
2862 llvm_unreachable("Managed variable registration NYI");
2863
2864 GlobalOp deviceNameStr = makeConstantString(regAttr.getDeviceSideName());
2865 mlir::Value deviceName = builder.createBitcast(
2866 builder.createGetGlobal(deviceNameStr), voidPtrTy);
2867 mlir::Value hostVar =
2868 builder.createBitcast(builder.createGetGlobal(global), voidPtrTy);
2869
2870 auto isExtern = ConstantOp::create(
2871 builder, loc, IntAttr::get(intTy, regAttr.getIsExtern() ? 1 : 0));
2872 llvm::TypeSize size = dataLayout.getTypeAllocSize(global.getSymType());
2873 auto varSize = ConstantOp::create(
2874 builder, loc, IntAttr::get(sizeTy, size.getFixedValue()));
2875 auto isConstant = ConstantOp::create(
2876 builder, loc, IntAttr::get(intTy, regAttr.getIsConstant() ? 1 : 0));
2877 auto normalized = ConstantOp::create(builder, loc, IntAttr::get(intTy, 0));
2878 builder.createCallOp(loc, cudaRegisterVar,
2879 {fatbinHandle, hostVar, deviceName, deviceName,
2880 isExtern, varSize, isConstant, normalized});
2881 }
2882}
2883
2884void LoweringPreparePass::runOnOperation() {
2885 mlir::Operation *op = getOperation();
2886 if (isa<::mlir::ModuleOp>(op))
2887 mlirModule = cast<::mlir::ModuleOp>(op);
2888
2889 llvm::SmallVector<mlir::Operation *> opsToTransform;
2890
2891 op->walk([&](mlir::Operation *op) {
2892 if (mlir::isa<cir::ArrayCtor, cir::ArrayDtor, cir::CastOp,
2893 cir::ComplexConjOp, cir::ComplexMulOp, cir::ComplexDivOp,
2894 cir::DynamicCastOp, cir::FuncOp, cir::CallOp,
2895 cir::GetGlobalOp, cir::GlobalOp, cir::StoreOp,
2896 cir::CmpThreeWayOp, cir::LocalInitOp>(op))
2897 opsToTransform.push_back(op);
2898 });
2899
2900 for (mlir::Operation *o : opsToTransform)
2901 runOnOp(o);
2902
2903 buildCXXGlobalInitFunc();
2904 buildCXXGlobalTlsFunc();
2905 if (astCtx->getLangOpts().CUDA && !astCtx->getLangOpts().CUDAIsDevice)
2906 buildCUDAModuleCtor();
2907
2908 buildGlobalCtorDtorList();
2909}
2910
2911std::unique_ptr<Pass> mlir::createLoweringPreparePass() {
2912 return std::make_unique<LoweringPreparePass>();
2913}
2914
2915std::unique_ptr<Pass>
2917 auto pass = std::make_unique<LoweringPreparePass>();
2918 pass->setASTContext(astCtx);
2919 return std::move(pass);
2920}
Defines the clang::ASTContext interface.
static void emitBody(CodeGenFunction &CGF, const Stmt *S, const Stmt *NextLoop, int MaxLevel, int Level=0)
static llvm::FunctionCallee getGuardReleaseFn(CodeGenModule &CGM, llvm::PointerType *GuardPtrTy)
static llvm::FunctionCallee getGuardAcquireFn(CodeGenModule &CGM, llvm::PointerType *GuardPtrTy)
static mlir::Value buildRangeReductionComplexDiv(CIRBaseBuilderTy &builder, mlir::Location loc, mlir::Value lhsReal, mlir::Value lhsImag, mlir::Value rhsReal, mlir::Value rhsImag)
static llvm::StringRef getComplexDivLibCallName(llvm::APFloat::Semantics semantics)
static llvm::SmallVector< mlir::Attribute > prepareCtorDtorAttrList(mlir::MLIRContext *context, llvm::ArrayRef< std::pair< std::string, uint32_t > > list)
static llvm::StringRef getComplexMulLibCallName(llvm::APFloat::Semantics semantics)
static cir::GlobalLinkageKind getThreadLocalWrapperLinkage(GlobalOp op, clang::ASTContext &astCtx)
static mlir::Value buildComplexBinOpLibCall(LoweringPreparePass &pass, CIRBaseBuilderTy &builder, llvm::StringRef(*libFuncNameGetter)(llvm::APFloat::Semantics), mlir::Location loc, cir::ComplexType ty, mlir::Value lhsReal, mlir::Value lhsImag, mlir::Value rhsReal, mlir::Value rhsImag)
static mlir::Value lowerComplexMul(LoweringPreparePass &pass, CIRBaseBuilderTy &builder, mlir::Location loc, cir::ComplexMulOp op, mlir::Value lhsReal, mlir::Value lhsImag, mlir::Value rhsReal, mlir::Value rhsImag)
static std::string addUnderscoredPrefix(llvm::StringRef prefix, llvm::StringRef name)
static SmallString< 128 > getTransformedFileName(mlir::ModuleOp mlirModule)
static mlir::Value lowerComplexToComplexCast(mlir::MLIRContext &ctx, cir::CastOp op, cir::CastKind scalarCastKind)
static void lowerArrayDtorCtorIntoLoop(cir::CIRBaseBuilderTy &builder, clang::ASTContext *astCtx, mlir::Operation *op, mlir::Type eltTy, mlir::Value addr, mlir::Value numElements, uint64_t arrayLen, bool isCtor)
Lower a cir.array.ctor or cir.array.dtor into a do-while loop that iterates over every element.
static bool isThreadWrapperReplaceable(cir::TLS_Model tls, clang::ASTContext &astCtx)
static mlir::Value lowerComplexToScalarCast(mlir::MLIRContext &ctx, cir::CastOp op, cir::CastKind elemToBoolKind)
static mlir::Value buildAlgebraicComplexDiv(CIRBaseBuilderTy &builder, mlir::Location loc, mlir::Value lhsReal, mlir::Value lhsImag, mlir::Value rhsReal, mlir::Value rhsImag)
static llvm::StringRef getCUDAPrefix(clang::ASTContext *astCtx)
static mlir::Type higherPrecisionElementTypeForComplexArithmetic(mlir::MLIRContext &context, clang::ASTContext &cc, CIRBaseBuilderTy &builder, mlir::Type elementType)
static mlir::Value lowerScalarToComplexCast(mlir::MLIRContext &ctx, cir::CastOp op)
static mlir::Value lowerComplexDiv(LoweringPreparePass &pass, CIRBaseBuilderTy &builder, mlir::Location loc, cir::ComplexDivOp op, mlir::Value lhsReal, mlir::Value lhsImag, mlir::Value rhsReal, mlir::Value rhsImag, mlir::MLIRContext &mlirCx, clang::ASTContext &cc)
Defines the clang::Module class, which describes a module in the source code.
static bool compare(const PathDiagnostic &X, const PathDiagnostic &Y)
Defines the SourceManager interface.
Defines various enumerations that describe declaration and type specifiers.
Defines the TargetCXXABI class, which abstracts details of the C++ ABI that we're targeting.
mlir::Value createDiv(mlir::Location loc, mlir::Value lhs, mlir::Value rhs)
mlir::TypedAttr getConstNullPtrAttr(mlir::Type t)
mlir::Value createLogicalOr(mlir::Location loc, mlir::Value lhs, mlir::Value rhs)
mlir::Value createSub(mlir::Location loc, mlir::Value lhs, mlir::Value rhs, OverflowBehavior ob=OverflowBehavior::None)
cir::ConditionOp createCondition(mlir::Value condition)
Create a loop condition.
cir::CopyOp createCopy(mlir::Value dst, mlir::Value src, bool isVolatile=false, bool skipTailPadding=false)
Create a copy with inferred length.
cir::VoidType getVoidTy()
cir::ConstantOp getNullValue(mlir::Type ty, mlir::Location loc)
mlir::Value createCast(mlir::Location loc, cir::CastKind kind, mlir::Value src, mlir::Type newTy)
cir::PointerType getVoidFnPtrTy(mlir::TypeRange argTypes={})
Returns void (*)(T...) as a cir::PointerType.
mlir::Value createFDiv(mlir::Location loc, mlir::Value lhs, mlir::Value rhs)
mlir::Value createAdd(mlir::Location loc, mlir::Value lhs, mlir::Value rhs, OverflowBehavior ob=OverflowBehavior::None)
cir::PointerType getPointerTo(mlir::Type ty)
mlir::Value createFNeg(mlir::Location loc, mlir::Value operand)
mlir::Value createFAdd(mlir::Location loc, mlir::Value lhs, mlir::Value rhs)
mlir::Value createComplexImag(mlir::Location loc, mlir::Value operand)
cir::ConstantOp getNullPtr(mlir::Type ty, mlir::Location loc)
cir::IntType getUIntNTy(int n)
cir::DoWhileOp createDoWhile(mlir::Location loc, llvm::function_ref< void(mlir::OpBuilder &, mlir::Location)> condBuilder, llvm::function_ref< void(mlir::OpBuilder &, mlir::Location)> bodyBuilder)
Create a do-while operation.
cir::GetGlobalOp createGetGlobal(mlir::Location loc, cir::GlobalOp global, bool threadLocal=false)
mlir::Value createAlloca(mlir::Location loc, cir::PointerType addrType, llvm::StringRef name, mlir::IntegerAttr alignment, mlir::Value dynAllocSize)
mlir::Value getSignedInt(mlir::Location loc, int64_t val, unsigned numBits)
mlir::Value createAnd(mlir::Location loc, mlir::Value lhs, mlir::Value rhs)
mlir::Value createBitcast(mlir::Value src, mlir::Type newTy)
mlir::Value createFMul(mlir::Location loc, mlir::Value lhs, mlir::Value rhs)
cir::FuncType getVoidFnTy(mlir::TypeRange argTypes={})
Returns void (T...) as a cir::FuncType.
cir::CmpOp createCompare(mlir::Location loc, cir::CmpOpKind kind, mlir::Value lhs, mlir::Value rhs)
mlir::IntegerAttr getAlignmentAttr(clang::CharUnits alignment)
mlir::Value createSelect(mlir::Location loc, mlir::Value condition, mlir::Value trueValue, mlir::Value falseValue)
mlir::Value createMul(mlir::Location loc, mlir::Value lhs, mlir::Value rhs, OverflowBehavior ob=OverflowBehavior::None)
cir::LoadOp createLoad(mlir::Location loc, mlir::Value ptr, bool isVolatile=false, uint64_t alignment=0)
mlir::Value createMinus(mlir::Location loc, mlir::Value input, bool nsw=false)
cir::ConstantOp getConstantInt(mlir::Location loc, mlir::Type ty, int64_t value)
mlir::Value createComplexCreate(mlir::Location loc, mlir::Value real, mlir::Value imag)
cir::PointerType getVoidPtrTy(clang::LangAS langAS=clang::LangAS::Default)
mlir::Value createIsNaN(mlir::Location loc, mlir::Value operand)
cir::IntType getSIntNTy(int n)
mlir::Value createAlignedLoad(mlir::Location loc, mlir::Value ptr, uint64_t alignment)
cir::CallOp createCallOp(mlir::Location loc, mlir::SymbolRefAttr callee, mlir::Type returnType, mlir::ValueRange operands, llvm::ArrayRef< mlir::NamedAttribute > attrs={}, llvm::ArrayRef< mlir::NamedAttrList > argAttrs={}, llvm::ArrayRef< mlir::NamedAttribute > resAttrs={})
cir::StoreOp createStore(mlir::Location loc, mlir::Value val, mlir::Value dst, bool isVolatile=false, mlir::IntegerAttr align={}, cir::SyncScopeKindAttr scope={}, cir::MemOrderAttr order={})
cir::YieldOp createYield(mlir::Location loc, mlir::ValueRange value={})
Create a yield operation.
mlir::Value createLogicalAnd(mlir::Location loc, mlir::Value lhs, mlir::Value rhs)
mlir::Value createFSub(mlir::Location loc, mlir::Value lhs, mlir::Value rhs)
cir::BoolType getBoolTy()
mlir::Value getUnsignedInt(mlir::Location loc, uint64_t val, unsigned numBits)
mlir::Value createComplexReal(mlir::Location loc, mlir::Value operand)
Holds long-lived AST nodes (such as types and decls) that can be referred to throughout the semantic ...
Definition ASTContext.h:223
SourceManager & getSourceManager()
Definition ASTContext.h:863
MangleContext * createMangleContext(const TargetInfo *T=nullptr)
If T is null pointer, assume the target in ASTContext.
const LangOptions & getLangOpts() const
Definition ASTContext.h:959
uint64_t getTypeSize(QualType T) const
Return the size of the specified (complete) type T, in bits.
const TargetInfo & getTargetInfo() const
Definition ASTContext.h:921
QualType getSignedSizeType() const
Return the unique signed counterpart of the integer type corresponding to size_t.
Module * getCurrentNamedModule() const
Get module under construction, nullptr if this is not a C++20 module.
uint64_t getCharWidth() const
Return the size of the character type, in bits.
llvm::Align getAsAlign() const
getAsAlign - Returns Quantity as a valid llvm::Align, Beware llvm::Align assumes power of two 8-bit b...
Definition CharUnits.h:189
QuantityType getQuantity() const
getQuantity - Get the raw integer representation of this quantity.
Definition CharUnits.h:185
static CharUnits One()
One - Construct a CharUnits quantity of one.
Definition CharUnits.h:58
static CharUnits fromQuantity(QuantityType Quantity)
fromQuantity - Construct a CharUnits quantity from a raw integer type.
Definition CharUnits.h:63
llvm::vfs::FileSystem & getVirtualFileSystem() const
bool isModuleImplementation() const
Is this a module implementation.
Definition Module.h:882
FileManager & getFileManager() const
Exposes information about the current target.
Definition TargetInfo.h:227
const llvm::Triple & getTriple() const
Returns the target triple of the primary target.
unsigned getMaxAtomicInlineWidth() const
Return the maximum width lock-free atomic operation which can be inlined given the supported features...
Definition TargetInfo.h:859
const llvm::fltSemantics & getDoubleFormat() const
Definition TargetInfo.h:804
const llvm::fltSemantics & getHalfFormat() const
Definition TargetInfo.h:789
const llvm::fltSemantics & getBFloat16Format() const
Definition TargetInfo.h:799
const llvm::fltSemantics & getLongDoubleFormat() const
Definition TargetInfo.h:810
const llvm::fltSemantics & getFloatFormat() const
Definition TargetInfo.h:794
virtual uint64_t getMaxPointerWidth() const
Return the maximum width of pointers on this target.
Definition TargetInfo.h:500
const llvm::fltSemantics & getFloat128Format() const
Definition TargetInfo.h:818
const llvm::VersionTuple & getSDKVersion() const
Defines the clang::TargetInfo interface.
static bool isLocalLinkage(GlobalLinkageKind linkage)
Definition CIROpsEnums.h:51
static bool isWeakODRLinkage(GlobalLinkageKind linkage)
Definition CIROpsEnums.h:39
static bool isLinkOnceLinkage(GlobalLinkageKind linkage)
Definition CIROpsEnums.h:33
const internal::VariadicDynCastAllOfMatcher< Decl, VarDecl > varDecl
Matches variable declarations.
bool isHIP(ID Id)
isHIP - Is this a HIP input.
Definition Types.cpp:314
RangeSelector name(std::string ID)
Given a node with a "name", (like NamedDecl, DeclRefExpr, CxxCtorInitializer, and TypeLoc) selects th...
bool isTemplateInstantiation(TemplateSpecializationKind Kind)
Determine whether this template specialization kind refers to an instantiation of an entity (as oppos...
Definition Specifiers.h:213
bool CudaFeatureEnabled(llvm::VersionTuple, CudaFeature)
Definition Cuda.cpp:172
LLVM_READONLY bool isPreprocessingNumberBody(unsigned char c)
Return true if this is the body character of a C preprocessing number, which is [a-zA-Z0-9_.
Definition CharInfo.h:168
@ CUDA_USES_FATBIN_REGISTER_END
Definition Cuda.h:82
std::unique_ptr< Pass > createLoweringPreparePass()
__packed_splat4 __packed_splat2 __packed_splat8 __packed_splat4 __packed_splat2 __packed_splat4 __packed_splat2 __packed_splat8 __packed_splat4 uint32_t
static bool hipModuleCtor()
static bool guardAbortOnException()
static bool opGlobalAnnotations()
static bool opGlobalCtorPriority()
static bool shouldSplitConstantStore()
static bool shouldUseMemSetToInitialize()
static bool opFuncExtraAttrs()
static bool shouldUseBZeroPlusStoresToInitialize()
static bool fastMathFlags()
static bool astVarDeclInterface()