clang 23.0.0git
CIRDialect.cpp
Go to the documentation of this file.
1//===- CIRDialect.cpp - MLIR CIR ops implementation -----------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements the CIR dialect and its operations.
10//
11//===----------------------------------------------------------------------===//
12
14
18
19#include "mlir/IR/Attributes.h"
20#include "mlir/IR/DialectImplementation.h"
21#include "mlir/IR/PatternMatch.h"
22#include "mlir/Interfaces/ControlFlowInterfaces.h"
23#include "mlir/Interfaces/FunctionImplementation.h"
24#include "mlir/Support/LLVM.h"
25
26#include "clang/CIR/Dialect/IR/CIROpsDialect.cpp.inc"
27#include "clang/CIR/Dialect/IR/CIROpsEnums.cpp.inc"
29#include "llvm/ADT/SetOperations.h"
30#include "llvm/ADT/SmallSet.h"
31#include "llvm/ADT/TypeSwitch.h"
32#include "llvm/Support/LogicalResult.h"
33
34using namespace mlir;
35using namespace cir;
36
37//===----------------------------------------------------------------------===//
38// CIR Dialect
39//===----------------------------------------------------------------------===//
40namespace {
41struct CIROpAsmDialectInterface : public OpAsmDialectInterface {
42 using OpAsmDialectInterface::OpAsmDialectInterface;
43
44 AliasResult getAlias(Type type, raw_ostream &os) const final {
45 if (auto recordType = dyn_cast<cir::RecordType>(type)) {
46 StringAttr nameAttr = recordType.getName();
47 if (!nameAttr)
48 os << "rec_anon_" << recordType.getKindAsStr();
49 else
50 os << "rec_" << nameAttr.getValue();
51 return AliasResult::OverridableAlias;
52 }
53 if (auto intType = dyn_cast<cir::IntType>(type)) {
54 // We only provide alias for standard integer types (i.e. integer types
55 // whose width is a power of 2 and at least 8).
56 unsigned width = intType.getWidth();
57 if (width < 8 || !llvm::isPowerOf2_32(width))
58 return AliasResult::NoAlias;
59 os << intType.getAlias();
60 return AliasResult::OverridableAlias;
61 }
62 if (auto voidType = dyn_cast<cir::VoidType>(type)) {
63 os << voidType.getAlias();
64 return AliasResult::OverridableAlias;
65 }
66
67 return AliasResult::NoAlias;
68 }
69
70 AliasResult getAlias(Attribute attr, raw_ostream &os) const final {
71 if (auto boolAttr = mlir::dyn_cast<cir::BoolAttr>(attr)) {
72 os << (boolAttr.getValue() ? "true" : "false");
73 return AliasResult::FinalAlias;
74 }
75 if (auto bitfield = mlir::dyn_cast<cir::BitfieldInfoAttr>(attr)) {
76 os << "bfi_" << bitfield.getName().str();
77 return AliasResult::FinalAlias;
78 }
79 if (auto dynCastInfoAttr = mlir::dyn_cast<cir::DynamicCastInfoAttr>(attr)) {
80 os << dynCastInfoAttr.getAlias();
81 return AliasResult::FinalAlias;
82 }
83 return AliasResult::NoAlias;
84 }
85};
86} // namespace
87
/// Dialect initialization hook. Calls registerTypes() and
/// registerAttributes(), registers every operation from the generated
/// CIROps.cpp.inc op list, and installs CIROpAsmDialectInterface (the
/// asm-printer alias provider defined earlier in this file).
void cir::CIRDialect::initialize() {
  registerTypes();
  registerAttributes();
  // With GET_OP_LIST defined, the generated .inc presumably expands to the
  // comma-separated list of op classes (the usual ODS convention) that fills
  // this template argument list.
  addOperations<
#define GET_OP_LIST
#include "clang/CIR/Dialect/IR/CIROps.cpp.inc"
      >();
  addInterfaces<CIROpAsmDialectInterface>();
}
97
98Operation *cir::CIRDialect::materializeConstant(mlir::OpBuilder &builder,
99 mlir::Attribute value,
100 mlir::Type type,
101 mlir::Location loc) {
102 return cir::ConstantOp::create(builder, loc, type,
103 mlir::cast<mlir::TypedAttr>(value));
104}
105
106//===----------------------------------------------------------------------===//
107// Helpers
108//===----------------------------------------------------------------------===//
109
110// Parses one of the keywords provided in the list `keywords` and returns the
111// position of the parsed keyword in the list. If none of the keywords from the
112// list is parsed, returns -1.
113static int parseOptionalKeywordAlternative(AsmParser &parser,
114 ArrayRef<llvm::StringRef> keywords) {
115 for (auto en : llvm::enumerate(keywords)) {
116 if (succeeded(parser.parseOptionalKeyword(en.value())))
117 return en.index();
118 }
119 return -1;
120}
121
namespace {
/// Per-enum traits consumed by the generic keyword parsers that follow
/// (parseOptionalCIRKeyword / parseCIRKeyword). The primary template is
/// intentionally empty; REGISTER_ENUM_TYPE specializes it for each supported
/// cir enum.
template <typename Ty> struct EnumTraits {};

// Specializes EnumTraits<cir::Ty> by forwarding to `stringify<Ty>` and
// `cir::getMaxEnumValFor<Ty>` (presumably the ODS-generated enum utilities
// for Ty; they are not defined in this file).
#define REGISTER_ENUM_TYPE(Ty)                                                 \
  template <> struct EnumTraits<cir::Ty> {                                     \
    static llvm::StringRef stringify(cir::Ty value) {                          \
      return stringify##Ty(value);                                             \
    }                                                                          \
    static unsigned getMaxEnumVal() { return cir::getMaxEnumValFor##Ty(); }    \
  }

REGISTER_ENUM_TYPE(GlobalLinkageKind);
REGISTER_ENUM_TYPE(VisibilityKind);
REGISTER_ENUM_TYPE(SideEffect);
} // namespace
137
138/// Parse an enum from the keyword, or default to the provided default value.
139/// The return type is the enum type by default, unless overriden with the
140/// second template argument.
141template <typename EnumTy, typename RetTy = EnumTy>
142static RetTy parseOptionalCIRKeyword(AsmParser &parser, EnumTy defaultValue) {
144 for (unsigned i = 0, e = EnumTraits<EnumTy>::getMaxEnumVal(); i <= e; ++i)
145 names.push_back(EnumTraits<EnumTy>::stringify(static_cast<EnumTy>(i)));
146
147 int index = parseOptionalKeywordAlternative(parser, names);
148 if (index == -1)
149 return static_cast<RetTy>(defaultValue);
150 return static_cast<RetTy>(index);
151}
152
153/// Parse an enum from the keyword, return failure if the keyword is not found.
154template <typename EnumTy, typename RetTy = EnumTy>
155static ParseResult parseCIRKeyword(AsmParser &parser, RetTy &result) {
157 for (unsigned i = 0, e = EnumTraits<EnumTy>::getMaxEnumVal(); i <= e; ++i)
158 names.push_back(EnumTraits<EnumTy>::stringify(static_cast<EnumTy>(i)));
159
160 int index = parseOptionalKeywordAlternative(parser, names);
161 if (index == -1)
162 return failure();
163 result = static_cast<RetTy>(index);
164 return success();
165}
166
167// Check if a region's termination omission is valid and, if so, creates and
168// inserts the omitted terminator into the region.
169static LogicalResult ensureRegionTerm(OpAsmParser &parser, Region &region,
170 SMLoc errLoc) {
171 Location eLoc = parser.getEncodedSourceLoc(parser.getCurrentLocation());
172 OpBuilder builder(parser.getBuilder().getContext());
173
174 // Insert empty block in case the region is empty to ensure the terminator
175 // will be inserted
176 if (region.empty())
177 builder.createBlock(&region);
178
179 Block &block = region.back();
180 // Region is properly terminated: nothing to do.
181 if (!block.empty() && block.back().hasTrait<OpTrait::IsTerminator>())
182 return success();
183
184 // Check for invalid terminator omissions.
185 if (!region.hasOneBlock())
186 return parser.emitError(errLoc,
187 "multi-block region must not omit terminator");
188
189 // Terminator was omitted correctly: recreate it.
190 builder.setInsertionPointToEnd(&block);
191 cir::YieldOp::create(builder, eLoc);
192 return success();
193}
194
195// True if the region's terminator should be omitted.
196static bool omitRegionTerm(mlir::Region &r) {
197 const auto singleNonEmptyBlock = r.hasOneBlock() && !r.back().empty();
198 const auto yieldsNothing = [&r]() {
199 auto y = dyn_cast<cir::YieldOp>(r.back().getTerminator());
200 return y && y.getArgs().empty();
201 };
202 return singleNonEmptyBlock && yieldsNothing();
203}
204
205void printVisibilityAttr(OpAsmPrinter &printer,
206 cir::VisibilityAttr &visibility) {
207 switch (visibility.getValue()) {
208 case cir::VisibilityKind::Hidden:
209 printer << "hidden";
210 break;
211 case cir::VisibilityKind::Protected:
212 printer << "protected";
213 break;
214 case cir::VisibilityKind::Default:
215 break;
216 }
217}
218
219void parseVisibilityAttr(OpAsmParser &parser, cir::VisibilityAttr &visibility) {
220 cir::VisibilityKind visibilityKind =
221 parseOptionalCIRKeyword(parser, cir::VisibilityKind::Default);
222 visibility = cir::VisibilityAttr::get(parser.getContext(), visibilityKind);
223}
224
225//===----------------------------------------------------------------------===//
226// InlineKindAttr (FIXME: remove once FuncOp uses assembly format)
227//===----------------------------------------------------------------------===//
228
229ParseResult parseInlineKindAttr(OpAsmParser &parser,
230 cir::InlineKindAttr &inlineKindAttr) {
231 // Static list of possible inline kind keywords
232 static constexpr llvm::StringRef keywords[] = {"no_inline", "always_inline",
233 "inline_hint"};
234
235 // Parse the inline kind keyword (optional)
236 llvm::StringRef keyword;
237 if (parser.parseOptionalKeyword(&keyword, keywords).failed()) {
238 // Not an inline kind keyword, leave inlineKindAttr empty
239 return success();
240 }
241
242 // Parse the enum value from the keyword
243 auto inlineKindResult = ::cir::symbolizeEnum<::cir::InlineKind>(keyword);
244 if (!inlineKindResult) {
245 return parser.emitError(parser.getCurrentLocation(), "expected one of [")
246 << llvm::join(llvm::ArrayRef(keywords), ", ")
247 << "] for inlineKind, got: " << keyword;
248 }
249
250 inlineKindAttr =
251 ::cir::InlineKindAttr::get(parser.getContext(), *inlineKindResult);
252 return success();
253}
254
255void printInlineKindAttr(OpAsmPrinter &p, cir::InlineKindAttr inlineKindAttr) {
256 if (inlineKindAttr) {
257 p << " " << stringifyInlineKind(inlineKindAttr.getValue());
258 }
259}
260//===----------------------------------------------------------------------===//
261// CIR Custom Parsers/Printers
262//===----------------------------------------------------------------------===//
263
264static mlir::ParseResult parseOmittedTerminatorRegion(mlir::OpAsmParser &parser,
265 mlir::Region &region) {
266 auto regionLoc = parser.getCurrentLocation();
267 if (parser.parseRegion(region))
268 return failure();
269 if (ensureRegionTerm(parser, region, regionLoc).failed())
270 return failure();
271 return success();
272}
273
274static void printOmittedTerminatorRegion(mlir::OpAsmPrinter &printer,
275 cir::ScopeOp &op,
276 mlir::Region &region) {
277 printer.printRegion(region,
278 /*printEntryBlockArgs=*/false,
279 /*printBlockTerminators=*/!omitRegionTerm(region));
280}
281
282//===----------------------------------------------------------------------===//
283// AllocaOp
284//===----------------------------------------------------------------------===//
285
286void cir::AllocaOp::build(mlir::OpBuilder &odsBuilder,
287 mlir::OperationState &odsState, mlir::Type addr,
288 mlir::Type allocaType, llvm::StringRef name,
289 mlir::IntegerAttr alignment) {
290 odsState.addAttribute(getAllocaTypeAttrName(odsState.name),
291 mlir::TypeAttr::get(allocaType));
292 odsState.addAttribute(getNameAttrName(odsState.name),
293 odsBuilder.getStringAttr(name));
294 if (alignment) {
295 odsState.addAttribute(getAlignmentAttrName(odsState.name), alignment);
296 }
297 odsState.addTypes(addr);
298}
299
300//===----------------------------------------------------------------------===//
301// BreakOp
302//===----------------------------------------------------------------------===//
303
304LogicalResult cir::BreakOp::verify() {
305 if (!getOperation()->getParentOfType<LoopOpInterface>() &&
306 !getOperation()->getParentOfType<SwitchOp>())
307 return emitOpError("must be within a loop");
308 return success();
309}
310
311//===----------------------------------------------------------------------===//
312// ConditionOp
313//===----------------------------------------------------------------------===//
314
315//===----------------------------------
// RegionBranchTerminatorOpInterface Methods
317//===----------------------------------
318
319void cir::ConditionOp::getSuccessorRegions(
320 ArrayRef<Attribute> operands, SmallVectorImpl<RegionSuccessor> &regions) {
321 // TODO(cir): The condition value may be folded to a constant, narrowing
322 // down its list of possible successors.
323
324 // Parent is a loop: condition may branch to the body or to the parent op.
325 if (auto loopOp = dyn_cast<LoopOpInterface>(getOperation()->getParentOp())) {
326 regions.emplace_back(&loopOp.getBody());
327 regions.push_back(RegionSuccessor::parent());
328 }
329
330 // Parent is an await: condition may branch to resume or suspend regions.
331 auto await = cast<AwaitOp>(getOperation()->getParentOp());
332 regions.emplace_back(&await.getResume());
333 regions.emplace_back(&await.getSuspend());
334}
335
336MutableOperandRange
337cir::ConditionOp::getMutableSuccessorOperands(RegionSuccessor point) {
338 // No values are yielded to the successor region.
339 return MutableOperandRange(getOperation(), 0, 0);
340}
341
342MutableOperandRange
343cir::ResumeOp::getMutableSuccessorOperands(RegionSuccessor point) {
344 // The eh_token operand is not forwarded to the parent region.
345 return MutableOperandRange(getOperation(), 0, 0);
346}
347
348LogicalResult cir::ConditionOp::verify() {
349 if (!isa<LoopOpInterface, AwaitOp>(getOperation()->getParentOp()))
350 return emitOpError("condition must be within a conditional region");
351 return success();
352}
353
354//===----------------------------------------------------------------------===//
355// ConstantOp
356//===----------------------------------------------------------------------===//
357
358static LogicalResult checkConstantTypes(mlir::Operation *op, mlir::Type opType,
359 mlir::Attribute attrType) {
360 if (isa<cir::ConstPtrAttr>(attrType)) {
361 if (!mlir::isa<cir::PointerType>(opType))
362 return op->emitOpError(
363 "pointer constant initializing a non-pointer type");
364 return success();
365 }
366
367 if (isa<cir::DataMemberAttr, cir::MethodAttr>(attrType)) {
368 // More detailed type verifications are already done in
369 // DataMemberAttr::verify or MethodAttr::verify. Don't need to repeat here.
370 return success();
371 }
372
373 if (isa<cir::ZeroAttr>(attrType)) {
374 if (isa<cir::RecordType, cir::ArrayType, cir::VectorType, cir::ComplexType>(
375 opType))
376 return success();
377 return op->emitOpError(
378 "zero expects struct, array, vector, or complex type");
379 }
380
381 if (mlir::isa<cir::UndefAttr>(attrType)) {
382 if (!mlir::isa<cir::VoidType>(opType))
383 return success();
384 return op->emitOpError("undef expects non-void type");
385 }
386
387 if (mlir::isa<cir::BoolAttr>(attrType)) {
388 if (!mlir::isa<cir::BoolType>(opType))
389 return op->emitOpError("result type (")
390 << opType << ") must be '!cir.bool' for '" << attrType << "'";
391 return success();
392 }
393
394 if (mlir::isa<cir::IntAttr, cir::FPAttr>(attrType)) {
395 auto at = cast<TypedAttr>(attrType);
396 if (at.getType() != opType) {
397 return op->emitOpError("result type (")
398 << opType << ") does not match value type (" << at.getType()
399 << ")";
400 }
401 return success();
402 }
403
404 if (mlir::isa<cir::ConstArrayAttr, cir::ConstVectorAttr,
405 cir::ConstComplexAttr, cir::ConstRecordAttr,
406 cir::GlobalViewAttr, cir::PoisonAttr, cir::TypeInfoAttr,
407 cir::VTableAttr>(attrType))
408 return success();
409
410 assert(isa<TypedAttr>(attrType) && "What else could we be looking at here?");
411 return op->emitOpError("global with type ")
412 << cast<TypedAttr>(attrType).getType() << " not yet supported";
413}
414
415LogicalResult cir::ConstantOp::verify() {
416 // ODS already generates checks to make sure the result type is valid. We just
417 // need to additionally check that the value's attribute type is consistent
418 // with the result type.
419 return checkConstantTypes(getOperation(), getType(), getValue());
420}
421
422OpFoldResult cir::ConstantOp::fold(FoldAdaptor /*adaptor*/) {
423 return getValue();
424}
425
426//===----------------------------------------------------------------------===//
427// ContinueOp
428//===----------------------------------------------------------------------===//
429
430LogicalResult cir::ContinueOp::verify() {
431 if (!getOperation()->getParentOfType<LoopOpInterface>())
432 return emitOpError("must be within a loop");
433 return success();
434}
435
436//===----------------------------------------------------------------------===//
437// CastOp
438//===----------------------------------------------------------------------===//
439
440LogicalResult cir::CastOp::verify() {
441 mlir::Type resType = getType();
442 mlir::Type srcType = getSrc().getType();
443
444 // Verify address space casts for pointer types. given that
445 // casts for within a different address space are illegal.
446 auto srcPtrTy = mlir::dyn_cast<cir::PointerType>(srcType);
447 auto resPtrTy = mlir::dyn_cast<cir::PointerType>(resType);
448 if (srcPtrTy && resPtrTy && (getKind() != cir::CastKind::address_space))
449 if (srcPtrTy.getAddrSpace() != resPtrTy.getAddrSpace()) {
450 return emitOpError() << "result type address space does not match the "
451 "address space of the operand";
452 }
453
454 if (mlir::isa<cir::VectorType>(srcType) &&
455 mlir::isa<cir::VectorType>(resType)) {
456 // Use the element type of the vector to verify the cast kind. (Except for
457 // bitcast, see below.)
458 srcType = mlir::dyn_cast<cir::VectorType>(srcType).getElementType();
459 resType = mlir::dyn_cast<cir::VectorType>(resType).getElementType();
460 }
461
462 switch (getKind()) {
463 case cir::CastKind::int_to_bool: {
464 if (!mlir::isa<cir::BoolType>(resType))
465 return emitOpError() << "requires !cir.bool type for result";
466 if (!mlir::isa<cir::IntType>(srcType))
467 return emitOpError() << "requires !cir.int type for source";
468 return success();
469 }
470 case cir::CastKind::ptr_to_bool: {
471 if (!mlir::isa<cir::BoolType>(resType))
472 return emitOpError() << "requires !cir.bool type for result";
473 if (!mlir::isa<cir::PointerType>(srcType))
474 return emitOpError() << "requires !cir.ptr type for source";
475 return success();
476 }
477 case cir::CastKind::integral: {
478 if (!mlir::isa<cir::IntType>(resType))
479 return emitOpError() << "requires !cir.int type for result";
480 if (!mlir::isa<cir::IntType>(srcType))
481 return emitOpError() << "requires !cir.int type for source";
482 return success();
483 }
484 case cir::CastKind::array_to_ptrdecay: {
485 const auto arrayPtrTy = mlir::dyn_cast<cir::PointerType>(srcType);
486 const auto flatPtrTy = mlir::dyn_cast<cir::PointerType>(resType);
487 if (!arrayPtrTy || !flatPtrTy)
488 return emitOpError() << "requires !cir.ptr type for source and result";
489
490 // TODO(CIR): Make sure the AddrSpace of both types are equals
491 return success();
492 }
493 case cir::CastKind::bitcast: {
494 // Handle the pointer types first.
495 auto srcPtrTy = mlir::dyn_cast<cir::PointerType>(srcType);
496 auto resPtrTy = mlir::dyn_cast<cir::PointerType>(resType);
497
498 if (srcPtrTy && resPtrTy) {
499 return success();
500 }
501
502 return success();
503 }
504 case cir::CastKind::floating: {
505 if (!mlir::isa<cir::FPTypeInterface>(srcType) ||
506 !mlir::isa<cir::FPTypeInterface>(resType))
507 return emitOpError() << "requires !cir.float type for source and result";
508 return success();
509 }
510 case cir::CastKind::float_to_int: {
511 if (!mlir::isa<cir::FPTypeInterface>(srcType))
512 return emitOpError() << "requires !cir.float type for source";
513 if (!mlir::dyn_cast<cir::IntType>(resType))
514 return emitOpError() << "requires !cir.int type for result";
515 return success();
516 }
517 case cir::CastKind::int_to_ptr: {
518 if (!mlir::dyn_cast<cir::IntType>(srcType))
519 return emitOpError() << "requires !cir.int type for source";
520 if (!mlir::dyn_cast<cir::PointerType>(resType))
521 return emitOpError() << "requires !cir.ptr type for result";
522 return success();
523 }
524 case cir::CastKind::ptr_to_int: {
525 if (!mlir::dyn_cast<cir::PointerType>(srcType))
526 return emitOpError() << "requires !cir.ptr type for source";
527 if (!mlir::dyn_cast<cir::IntType>(resType))
528 return emitOpError() << "requires !cir.int type for result";
529 return success();
530 }
531 case cir::CastKind::float_to_bool: {
532 if (!mlir::isa<cir::FPTypeInterface>(srcType))
533 return emitOpError() << "requires !cir.float type for source";
534 if (!mlir::isa<cir::BoolType>(resType))
535 return emitOpError() << "requires !cir.bool type for result";
536 return success();
537 }
538 case cir::CastKind::bool_to_int: {
539 if (!mlir::isa<cir::BoolType>(srcType))
540 return emitOpError() << "requires !cir.bool type for source";
541 if (!mlir::isa<cir::IntType>(resType))
542 return emitOpError() << "requires !cir.int type for result";
543 return success();
544 }
545 case cir::CastKind::int_to_float: {
546 if (!mlir::isa<cir::IntType>(srcType))
547 return emitOpError() << "requires !cir.int type for source";
548 if (!mlir::isa<cir::FPTypeInterface>(resType))
549 return emitOpError() << "requires !cir.float type for result";
550 return success();
551 }
552 case cir::CastKind::bool_to_float: {
553 if (!mlir::isa<cir::BoolType>(srcType))
554 return emitOpError() << "requires !cir.bool type for source";
555 if (!mlir::isa<cir::FPTypeInterface>(resType))
556 return emitOpError() << "requires !cir.float type for result";
557 return success();
558 }
559 case cir::CastKind::address_space: {
560 auto srcPtrTy = mlir::dyn_cast<cir::PointerType>(srcType);
561 auto resPtrTy = mlir::dyn_cast<cir::PointerType>(resType);
562 if (!srcPtrTy || !resPtrTy)
563 return emitOpError() << "requires !cir.ptr type for source and result";
564 if (srcPtrTy.getPointee() != resPtrTy.getPointee())
565 return emitOpError() << "requires two types differ in addrspace only";
566 return success();
567 }
568 case cir::CastKind::float_to_complex: {
569 if (!mlir::isa<cir::FPTypeInterface>(srcType))
570 return emitOpError() << "requires !cir.float type for source";
571 auto resComplexTy = mlir::dyn_cast<cir::ComplexType>(resType);
572 if (!resComplexTy)
573 return emitOpError() << "requires !cir.complex type for result";
574 if (srcType != resComplexTy.getElementType())
575 return emitOpError() << "requires source type match result element type";
576 return success();
577 }
578 case cir::CastKind::int_to_complex: {
579 if (!mlir::isa<cir::IntType>(srcType))
580 return emitOpError() << "requires !cir.int type for source";
581 auto resComplexTy = mlir::dyn_cast<cir::ComplexType>(resType);
582 if (!resComplexTy)
583 return emitOpError() << "requires !cir.complex type for result";
584 if (srcType != resComplexTy.getElementType())
585 return emitOpError() << "requires source type match result element type";
586 return success();
587 }
588 case cir::CastKind::float_complex_to_real: {
589 auto srcComplexTy = mlir::dyn_cast<cir::ComplexType>(srcType);
590 if (!srcComplexTy)
591 return emitOpError() << "requires !cir.complex type for source";
592 if (!mlir::isa<cir::FPTypeInterface>(resType))
593 return emitOpError() << "requires !cir.float type for result";
594 if (srcComplexTy.getElementType() != resType)
595 return emitOpError() << "requires source element type match result type";
596 return success();
597 }
598 case cir::CastKind::int_complex_to_real: {
599 auto srcComplexTy = mlir::dyn_cast<cir::ComplexType>(srcType);
600 if (!srcComplexTy)
601 return emitOpError() << "requires !cir.complex type for source";
602 if (!mlir::isa<cir::IntType>(resType))
603 return emitOpError() << "requires !cir.int type for result";
604 if (srcComplexTy.getElementType() != resType)
605 return emitOpError() << "requires source element type match result type";
606 return success();
607 }
608 case cir::CastKind::float_complex_to_bool: {
609 auto srcComplexTy = mlir::dyn_cast<cir::ComplexType>(srcType);
610 if (!srcComplexTy || !srcComplexTy.isFloatingPointComplex())
611 return emitOpError()
612 << "requires floating point !cir.complex type for source";
613 if (!mlir::isa<cir::BoolType>(resType))
614 return emitOpError() << "requires !cir.bool type for result";
615 return success();
616 }
617 case cir::CastKind::int_complex_to_bool: {
618 auto srcComplexTy = mlir::dyn_cast<cir::ComplexType>(srcType);
619 if (!srcComplexTy || !srcComplexTy.isIntegerComplex())
620 return emitOpError()
621 << "requires floating point !cir.complex type for source";
622 if (!mlir::isa<cir::BoolType>(resType))
623 return emitOpError() << "requires !cir.bool type for result";
624 return success();
625 }
626 case cir::CastKind::float_complex: {
627 auto srcComplexTy = mlir::dyn_cast<cir::ComplexType>(srcType);
628 if (!srcComplexTy || !srcComplexTy.isFloatingPointComplex())
629 return emitOpError()
630 << "requires floating point !cir.complex type for source";
631 auto resComplexTy = mlir::dyn_cast<cir::ComplexType>(resType);
632 if (!resComplexTy || !resComplexTy.isFloatingPointComplex())
633 return emitOpError()
634 << "requires floating point !cir.complex type for result";
635 return success();
636 }
637 case cir::CastKind::float_complex_to_int_complex: {
638 auto srcComplexTy = mlir::dyn_cast<cir::ComplexType>(srcType);
639 if (!srcComplexTy || !srcComplexTy.isFloatingPointComplex())
640 return emitOpError()
641 << "requires floating point !cir.complex type for source";
642 auto resComplexTy = mlir::dyn_cast<cir::ComplexType>(resType);
643 if (!resComplexTy || !resComplexTy.isIntegerComplex())
644 return emitOpError() << "requires integer !cir.complex type for result";
645 return success();
646 }
647 case cir::CastKind::int_complex: {
648 auto srcComplexTy = mlir::dyn_cast<cir::ComplexType>(srcType);
649 if (!srcComplexTy || !srcComplexTy.isIntegerComplex())
650 return emitOpError() << "requires integer !cir.complex type for source";
651 auto resComplexTy = mlir::dyn_cast<cir::ComplexType>(resType);
652 if (!resComplexTy || !resComplexTy.isIntegerComplex())
653 return emitOpError() << "requires integer !cir.complex type for result";
654 return success();
655 }
656 case cir::CastKind::int_complex_to_float_complex: {
657 auto srcComplexTy = mlir::dyn_cast<cir::ComplexType>(srcType);
658 if (!srcComplexTy || !srcComplexTy.isIntegerComplex())
659 return emitOpError() << "requires integer !cir.complex type for source";
660 auto resComplexTy = mlir::dyn_cast<cir::ComplexType>(resType);
661 if (!resComplexTy || !resComplexTy.isFloatingPointComplex())
662 return emitOpError()
663 << "requires floating point !cir.complex type for result";
664 return success();
665 }
666 case cir::CastKind::member_ptr_to_bool: {
667 if (!mlir::isa<cir::DataMemberType, cir::MethodType>(srcType))
668 return emitOpError()
669 << "requires !cir.data_member or !cir.method type for source";
670 if (!mlir::isa<cir::BoolType>(resType))
671 return emitOpError() << "requires !cir.bool type for result";
672 return success();
673 }
674 }
675 llvm_unreachable("Unknown CastOp kind?");
676}
677
678static bool isIntOrBoolCast(cir::CastOp op) {
679 auto kind = op.getKind();
680 return kind == cir::CastKind::bool_to_int ||
681 kind == cir::CastKind::int_to_bool || kind == cir::CastKind::integral;
682}
683
684static Value tryFoldCastChain(cir::CastOp op) {
685 cir::CastOp head = op, tail = op;
686
687 while (op) {
688 if (!isIntOrBoolCast(op))
689 break;
690 head = op;
691 op = head.getSrc().getDefiningOp<cir::CastOp>();
692 }
693
694 if (head == tail)
695 return {};
696
697 // if bool_to_int -> ... -> int_to_bool: take the bool
698 // as we had it was before all casts
699 if (head.getKind() == cir::CastKind::bool_to_int &&
700 tail.getKind() == cir::CastKind::int_to_bool)
701 return head.getSrc();
702
703 // if int_to_bool -> ... -> int_to_bool: take the result
704 // of the first one, as no other casts (and ext casts as well)
705 // don't change the first result
706 if (head.getKind() == cir::CastKind::int_to_bool &&
707 tail.getKind() == cir::CastKind::int_to_bool)
708 return head.getResult();
709
710 return {};
711}
712
713OpFoldResult cir::CastOp::fold(FoldAdaptor adaptor) {
714 if (mlir::isa_and_present<cir::PoisonAttr>(adaptor.getSrc())) {
715 // Propagate poison value
716 return cir::PoisonAttr::get(getContext(), getType());
717 }
718
719 if (getSrc().getType() == getType()) {
720 switch (getKind()) {
721 case cir::CastKind::integral: {
723 auto foldOrder = getSrc().getDefiningOp()->fold(foldResults);
724 if (foldOrder.succeeded() && mlir::isa<mlir::Attribute>(foldResults[0]))
725 return mlir::cast<mlir::Attribute>(foldResults[0]);
726 return {};
727 }
728 case cir::CastKind::bitcast:
729 case cir::CastKind::address_space:
730 case cir::CastKind::float_complex:
731 case cir::CastKind::int_complex: {
732 return getSrc();
733 }
734 default:
735 return {};
736 }
737 }
738
739 // Handle cases where a chain of casts cancel out.
740 Value result = tryFoldCastChain(*this);
741 if (result)
742 return result;
743
744 // Handle simple constant casts.
745 if (auto srcConst = getSrc().getDefiningOp<cir::ConstantOp>()) {
746 switch (getKind()) {
747 case cir::CastKind::integral: {
748 mlir::Type srcTy = getSrc().getType();
749 // Don't try to fold vector casts for now.
750 assert(mlir::isa<cir::VectorType>(srcTy) ==
751 mlir::isa<cir::VectorType>(getType()));
752 if (mlir::isa<cir::VectorType>(srcTy))
753 break;
754
755 auto srcIntTy = mlir::cast<cir::IntType>(srcTy);
756 auto dstIntTy = mlir::cast<cir::IntType>(getType());
757 APInt newVal =
758 srcIntTy.isSigned()
759 ? srcConst.getIntValue().sextOrTrunc(dstIntTy.getWidth())
760 : srcConst.getIntValue().zextOrTrunc(dstIntTy.getWidth());
761 return cir::IntAttr::get(dstIntTy, newVal);
762 }
763 default:
764 break;
765 }
766 }
767 return {};
768}
769
770//===----------------------------------------------------------------------===//
771// CallOp
772//===----------------------------------------------------------------------===//
773
774mlir::OperandRange cir::CallOp::getArgOperands() {
775 if (isIndirect())
776 return getArgs().drop_front(1);
777 return getArgs();
778}
779
780mlir::MutableOperandRange cir::CallOp::getArgOperandsMutable() {
781 mlir::MutableOperandRange args = getArgsMutable();
782 if (isIndirect())
783 return args.slice(1, args.size() - 1);
784 return args;
785}
786
787mlir::Value cir::CallOp::getIndirectCall() {
788 assert(isIndirect());
789 return getOperand(0);
790}
791
792/// Return the operand at index 'i'.
793Value cir::CallOp::getArgOperand(unsigned i) {
794 if (isIndirect())
795 ++i;
796 return getOperand(i);
797}
798
799/// Return the number of operands.
800unsigned cir::CallOp::getNumArgOperands() {
801 if (isIndirect())
802 return this->getOperation()->getNumOperands() - 1;
803 return this->getOperation()->getNumOperands();
804}
805
806static mlir::ParseResult
807parseTryCallDestinations(mlir::OpAsmParser &parser,
808 mlir::OperationState &result) {
809 mlir::Block *normalDestSuccessor;
810 if (parser.parseSuccessor(normalDestSuccessor))
811 return mlir::failure();
812
813 if (parser.parseComma())
814 return mlir::failure();
815
816 mlir::Block *unwindDestSuccessor;
817 if (parser.parseSuccessor(unwindDestSuccessor))
818 return mlir::failure();
819
820 result.addSuccessors(normalDestSuccessor);
821 result.addSuccessors(unwindDestSuccessor);
822 return mlir::success();
823}
824
825static mlir::ParseResult parseCallCommon(mlir::OpAsmParser &parser,
826 mlir::OperationState &result,
827 bool hasDestinationBlocks = false) {
829 llvm::SMLoc opsLoc;
830 mlir::FlatSymbolRefAttr calleeAttr;
831
832 // If we cannot parse a string callee, it means this is an indirect call.
833 if (!parser
834 .parseOptionalAttribute(calleeAttr, CIRDialect::getCalleeAttrName(),
835 result.attributes)
836 .has_value()) {
837 OpAsmParser::UnresolvedOperand indirectVal;
838 // Do not resolve right now, since we need to figure out the type
839 if (parser.parseOperand(indirectVal).failed())
840 return failure();
841 ops.push_back(indirectVal);
842 }
843
844 if (parser.parseLParen())
845 return mlir::failure();
846
847 opsLoc = parser.getCurrentLocation();
848 if (parser.parseOperandList(ops))
849 return mlir::failure();
850 if (parser.parseRParen())
851 return mlir::failure();
852
853 if (hasDestinationBlocks &&
854 parseTryCallDestinations(parser, result).failed()) {
855 return ::mlir::failure();
856 }
857
858 if (parser.parseOptionalKeyword("nothrow").succeeded())
859 result.addAttribute(CIRDialect::getNoThrowAttrName(),
860 mlir::UnitAttr::get(parser.getContext()));
861
862 if (parser.parseOptionalKeyword("side_effect").succeeded()) {
863 if (parser.parseLParen().failed())
864 return failure();
865 cir::SideEffect sideEffect;
866 if (parseCIRKeyword<cir::SideEffect>(parser, sideEffect).failed())
867 return failure();
868 if (parser.parseRParen().failed())
869 return failure();
870 auto attr = cir::SideEffectAttr::get(parser.getContext(), sideEffect);
871 result.addAttribute(CIRDialect::getSideEffectAttrName(), attr);
872 }
873
874 if (parser.parseOptionalAttrDict(result.attributes))
875 return ::mlir::failure();
876
877 if (parser.parseColon())
878 return ::mlir::failure();
879
880 SmallVector<Type> argTypes;
882 SmallVector<Type> resultTypes;
883 SmallVector<DictionaryAttr> resultAttrs;
884 if (call_interface_impl::parseFunctionSignature(parser, argTypes, argAttrs,
885 resultTypes, resultAttrs))
886 return mlir::failure();
887
888 if (resultTypes.size() > 1 || resultAttrs.size() > 1)
889 return parser.emitError(
890 parser.getCurrentLocation(),
891 "functions with multiple return types are not supported");
892
893 result.addTypes(resultTypes);
894
895 if (parser.resolveOperands(ops, argTypes, opsLoc, result.operands))
896 return mlir::failure();
897
898 if (!resultAttrs.empty() && resultAttrs[0])
899 result.addAttribute(
900 CIRDialect::getResAttrsAttrName(),
901 mlir::ArrayAttr::get(parser.getContext(), {resultAttrs[0]}));
902
903 // ArrayAttr requires a vector of 'Attribute', so we have to do the conversion
904 // here into a separate collection.
905 llvm::SmallVector<Attribute> convertedArgAttrs;
906 bool argAttrsEmpty = true;
907
908 llvm::transform(argAttrs, std::back_inserter(convertedArgAttrs),
909 [&](DictionaryAttr da) -> mlir::Attribute {
910 if (da)
911 argAttrsEmpty = false;
912 return da;
913 });
914
915 if (!argAttrsEmpty) {
916 llvm::ArrayRef argAttrsRef = convertedArgAttrs;
917 if (!calleeAttr) {
918 // Fixup for indirect calls, which get an extra entry in the 'args' for
919 // the indirect type, which doesn't get attributes.
920 argAttrsRef = argAttrsRef.drop_front();
921 }
922 result.addAttribute(CIRDialect::getArgAttrsAttrName(),
923 mlir::ArrayAttr::get(parser.getContext(), argAttrsRef));
924 }
925
926 return mlir::success();
927}
928
929static void
930printCallCommon(mlir::Operation *op, mlir::FlatSymbolRefAttr calleeSym,
931 mlir::Value indirectCallee, mlir::OpAsmPrinter &printer,
932 bool isNothrow, cir::SideEffect sideEffect, ArrayAttr argAttrs,
933 ArrayAttr resAttrs, mlir::Block *normalDest = nullptr,
934 mlir::Block *unwindDest = nullptr) {
935 printer << ' ';
936
937 auto callLikeOp = mlir::cast<cir::CIRCallOpInterface>(op);
938 auto ops = callLikeOp.getArgOperands();
939
940 if (calleeSym) {
941 // Direct calls
942 printer.printAttributeWithoutType(calleeSym);
943 } else {
944 // Indirect calls
945 assert(indirectCallee);
946 printer << indirectCallee;
947 }
948
949 printer << "(" << ops << ")";
950
951 if (normalDest) {
952 assert(unwindDest && "expected two successors");
953 auto tryCall = cast<cir::TryCallOp>(op);
954 printer << ' ' << tryCall.getNormalDest();
955 printer << ",";
956 printer << ' ';
957 printer << tryCall.getUnwindDest();
958 }
959
960 if (isNothrow)
961 printer << " nothrow";
962
963 if (sideEffect != cir::SideEffect::All) {
964 printer << " side_effect(";
965 printer << stringifySideEffect(sideEffect);
966 printer << ")";
967 }
968
970 CIRDialect::getCalleeAttrName(),
971 CIRDialect::getNoThrowAttrName(),
972 CIRDialect::getSideEffectAttrName(),
973 CIRDialect::getOperandSegmentSizesAttrName(),
974 llvm::StringRef("res_attrs"),
975 llvm::StringRef("arg_attrs")};
976 printer.printOptionalAttrDict(op->getAttrs(), elidedAttrs);
977 printer << " : ";
978 if (calleeSym || !argAttrs) {
979 call_interface_impl::printFunctionSignature(
980 printer, op->getOperands().getTypes(), argAttrs,
981 /*isVariadic=*/false, op->getResultTypes(), resAttrs);
982 } else {
983 // indirect function calls use an 'arg' type for the type of its indirect
984 // argument. However, we don't store a similar attribute collection. In
985 // order to make `printFunctionSignature` have the attributes line up, we
986 // have to make a 'shimmed' copy of the attributes that have a blank set of
987 // attributes for the indirect argument.
988 llvm::SmallVector<Attribute> shimmedArgAttrs;
989 shimmedArgAttrs.push_back(mlir::DictionaryAttr::get(op->getContext(), {}));
990 shimmedArgAttrs.append(argAttrs.begin(), argAttrs.end());
991 call_interface_impl::printFunctionSignature(
992 printer, op->getOperands().getTypes(),
993 mlir::ArrayAttr::get(op->getContext(), shimmedArgAttrs),
994 /*isVariadic=*/false, op->getResultTypes(), resAttrs);
995 }
996}
997
998mlir::ParseResult cir::CallOp::parse(mlir::OpAsmParser &parser,
999 mlir::OperationState &result) {
1000 return parseCallCommon(parser, result);
1001}
1002
1003void cir::CallOp::print(mlir::OpAsmPrinter &p) {
1004 mlir::Value indirectCallee = isIndirect() ? getIndirectCall() : nullptr;
1005 cir::SideEffect sideEffect = getSideEffect();
1006 printCallCommon(*this, getCalleeAttr(), indirectCallee, p, getNothrow(),
1007 sideEffect, getArgAttrsAttr(), getResAttrsAttr());
1008}
1009
1010static LogicalResult
1011verifyCallCommInSymbolUses(mlir::Operation *op,
1012 SymbolTableCollection &symbolTable) {
1013 auto fnAttr =
1014 op->getAttrOfType<FlatSymbolRefAttr>(CIRDialect::getCalleeAttrName());
1015 if (!fnAttr) {
1016 // This is an indirect call, thus we don't have to check the symbol uses.
1017 return mlir::success();
1018 }
1019
1020 auto fn = symbolTable.lookupNearestSymbolFrom<cir::FuncOp>(op, fnAttr);
1021 if (!fn)
1022 return op->emitOpError() << "'" << fnAttr.getValue()
1023 << "' does not reference a valid function";
1024
1025 auto callIf = dyn_cast<cir::CIRCallOpInterface>(op);
1026 assert(callIf && "expected CIR call interface to be always available");
1027
1028 // Verify that the operand and result types match the callee. Note that
1029 // argument-checking is disabled for functions without a prototype.
1030 auto fnType = fn.getFunctionType();
1031 if (!fn.getNoProto()) {
1032 unsigned numCallOperands = callIf.getNumArgOperands();
1033 unsigned numFnOpOperands = fnType.getNumInputs();
1034
1035 if (!fnType.isVarArg() && numCallOperands != numFnOpOperands)
1036 return op->emitOpError("incorrect number of operands for callee");
1037 if (fnType.isVarArg() && numCallOperands < numFnOpOperands)
1038 return op->emitOpError("too few operands for callee");
1039
1040 for (unsigned i = 0, e = numFnOpOperands; i != e; ++i)
1041 if (callIf.getArgOperand(i).getType() != fnType.getInput(i))
1042 return op->emitOpError("operand type mismatch: expected operand type ")
1043 << fnType.getInput(i) << ", but provided "
1044 << op->getOperand(i).getType() << " for operand number " << i;
1045 }
1046
1048
1049 // Void function must not return any results.
1050 if (fnType.hasVoidReturn() && op->getNumResults() != 0)
1051 return op->emitOpError("callee returns void but call has results");
1052
1053 // Non-void function calls must return exactly one result.
1054 if (!fnType.hasVoidReturn() && op->getNumResults() != 1)
1055 return op->emitOpError("incorrect number of results for callee");
1056
1057 // Parent function and return value types must match.
1058 if (!fnType.hasVoidReturn() &&
1059 op->getResultTypes().front() != fnType.getReturnType()) {
1060 return op->emitOpError("result type mismatch: expected ")
1061 << fnType.getReturnType() << ", but provided "
1062 << op->getResult(0).getType();
1063 }
1064
1065 return mlir::success();
1066}
1067
1068LogicalResult
1069cir::CallOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1070 return verifyCallCommInSymbolUses(*this, symbolTable);
1071}
1072
1073//===----------------------------------------------------------------------===//
1074// TryCallOp
1075//===----------------------------------------------------------------------===//
1076
1077mlir::OperandRange cir::TryCallOp::getArgOperands() {
1078 if (isIndirect())
1079 return getArgs().drop_front(1);
1080 return getArgs();
1081}
1082
1083mlir::MutableOperandRange cir::TryCallOp::getArgOperandsMutable() {
1084 mlir::MutableOperandRange args = getArgsMutable();
1085 if (isIndirect())
1086 return args.slice(1, args.size() - 1);
1087 return args;
1088}
1089
1090mlir::Value cir::TryCallOp::getIndirectCall() {
1091 assert(isIndirect());
1092 return getOperand(0);
1093}
1094
1095/// Return the operand at index 'i'.
1096Value cir::TryCallOp::getArgOperand(unsigned i) {
1097 if (isIndirect())
1098 ++i;
1099 return getOperand(i);
1100}
1101
1102/// Return the number of operands.
1103unsigned cir::TryCallOp::getNumArgOperands() {
1104 if (isIndirect())
1105 return this->getOperation()->getNumOperands() - 1;
1106 return this->getOperation()->getNumOperands();
1107}
1108
1109LogicalResult
1110cir::TryCallOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1111 return verifyCallCommInSymbolUses(*this, symbolTable);
1112}
1113
1114mlir::ParseResult cir::TryCallOp::parse(mlir::OpAsmParser &parser,
1115 mlir::OperationState &result) {
1116 return parseCallCommon(parser, result, /*hasDestinationBlocks=*/true);
1117}
1118
1119void cir::TryCallOp::print(::mlir::OpAsmPrinter &p) {
1120 mlir::Value indirectCallee = isIndirect() ? getIndirectCall() : nullptr;
1121 cir::SideEffect sideEffect = getSideEffect();
1122 printCallCommon(*this, getCalleeAttr(), indirectCallee, p, getNothrow(),
1123 sideEffect, getArgAttrsAttr(), getResAttrsAttr(),
1124 getNormalDest(), getUnwindDest());
1125}
1126
1127//===----------------------------------------------------------------------===//
1128// ReturnOp
1129//===----------------------------------------------------------------------===//
1130
1131static mlir::LogicalResult checkReturnAndFunction(cir::ReturnOp op,
1132 cir::FuncOp function) {
1133 // ReturnOps currently only have a single optional operand.
1134 if (op.getNumOperands() > 1)
1135 return op.emitOpError() << "expects at most 1 return operand";
1136
1137 // Ensure returned type matches the function signature.
1138 auto expectedTy = function.getFunctionType().getReturnType();
1139 auto actualTy =
1140 (op.getNumOperands() == 0 ? cir::VoidType::get(op.getContext())
1141 : op.getOperand(0).getType());
1142 if (actualTy != expectedTy)
1143 return op.emitOpError() << "returns " << actualTy
1144 << " but enclosing function returns " << expectedTy;
1145
1146 return mlir::success();
1147}
1148
1149mlir::LogicalResult cir::ReturnOp::verify() {
1150 // Returns can be present in multiple different scopes, get the
1151 // wrapping function and start from there.
1152 auto *fnOp = getOperation()->getParentOp();
1153 while (!isa<cir::FuncOp>(fnOp))
1154 fnOp = fnOp->getParentOp();
1155
1156 // Make sure return types match function return type.
1157 if (checkReturnAndFunction(*this, cast<cir::FuncOp>(fnOp)).failed())
1158 return failure();
1159
1160 return success();
1161}
1162
1163//===----------------------------------------------------------------------===//
1164// IfOp
1165//===----------------------------------------------------------------------===//
1166
1167ParseResult cir::IfOp::parse(OpAsmParser &parser, OperationState &result) {
1168 // create the regions for 'then'.
1169 result.regions.reserve(2);
1170 Region *thenRegion = result.addRegion();
1171 Region *elseRegion = result.addRegion();
1172
1173 mlir::Builder &builder = parser.getBuilder();
1174 OpAsmParser::UnresolvedOperand cond;
1175 Type boolType = cir::BoolType::get(builder.getContext());
1176
1177 if (parser.parseOperand(cond) ||
1178 parser.resolveOperand(cond, boolType, result.operands))
1179 return failure();
1180
1181 // Parse 'then' region.
1182 mlir::SMLoc parseThenLoc = parser.getCurrentLocation();
1183 if (parser.parseRegion(*thenRegion, /*arguments=*/{}, /*argTypes=*/{}))
1184 return failure();
1185
1186 if (ensureRegionTerm(parser, *thenRegion, parseThenLoc).failed())
1187 return failure();
1188
1189 // If we find an 'else' keyword, parse the 'else' region.
1190 if (!parser.parseOptionalKeyword("else")) {
1191 mlir::SMLoc parseElseLoc = parser.getCurrentLocation();
1192 if (parser.parseRegion(*elseRegion, /*arguments=*/{}, /*argTypes=*/{}))
1193 return failure();
1194 if (ensureRegionTerm(parser, *elseRegion, parseElseLoc).failed())
1195 return failure();
1196 }
1197
1198 // Parse the optional attribute list.
1199 if (parser.parseOptionalAttrDict(result.attributes))
1200 return failure();
1201 return success();
1202}
1203
1204void cir::IfOp::print(OpAsmPrinter &p) {
1205 p << " " << getCondition() << " ";
1206 mlir::Region &thenRegion = this->getThenRegion();
1207 p.printRegion(thenRegion,
1208 /*printEntryBlockArgs=*/false,
1209 /*printBlockTerminators=*/!omitRegionTerm(thenRegion));
1210
1211 // Print the 'else' regions if it exists and has a block.
1212 mlir::Region &elseRegion = this->getElseRegion();
1213 if (!elseRegion.empty()) {
1214 p << " else ";
1215 p.printRegion(elseRegion,
1216 /*printEntryBlockArgs=*/false,
1217 /*printBlockTerminators=*/!omitRegionTerm(elseRegion));
1218 }
1219
1220 p.printOptionalAttrDict(getOperation()->getAttrs());
1221}
1222
1223/// Default callback for IfOp builders.
1224void cir::buildTerminatedBody(OpBuilder &builder, Location loc) {
1225 // add cir.yield to end of the block
1226 cir::YieldOp::create(builder, loc);
1227}
1228
1229/// Given the region at `index`, or the parent operation if `index` is None,
1230/// return the successor regions. These are the regions that may be selected
1231/// during the flow of control. `operands` is a set of optional attributes that
1232/// correspond to a constant value for each operand, or null if that operand is
1233/// not a constant.
1234void cir::IfOp::getSuccessorRegions(mlir::RegionBranchPoint point,
1235 SmallVectorImpl<RegionSuccessor> &regions) {
1236 // The `then` and the `else` region branch back to the parent operation.
1237 if (!point.isParent()) {
1238 regions.push_back(RegionSuccessor::parent());
1239 return;
1240 }
1241
1242 // Don't consider the else region if it is empty.
1243 Region *elseRegion = &this->getElseRegion();
1244 if (elseRegion->empty())
1245 elseRegion = nullptr;
1246
1247 // If the condition isn't constant, both regions may be executed.
1248 regions.push_back(RegionSuccessor(&getThenRegion()));
1249 // If the else region does not exist, it is not a viable successor.
1250 if (elseRegion)
1251 regions.push_back(RegionSuccessor(elseRegion));
1252
1253 return;
1254}
1255
1256mlir::ValueRange cir::IfOp::getSuccessorInputs(RegionSuccessor successor) {
1257 return successor.isParent() ? ValueRange(getOperation()->getResults())
1258 : ValueRange();
1259}
1260
1261void cir::IfOp::build(OpBuilder &builder, OperationState &result, Value cond,
1262 bool withElseRegion, BuilderCallbackRef thenBuilder,
1263 BuilderCallbackRef elseBuilder) {
1264 assert(thenBuilder && "the builder callback for 'then' must be present");
1265 result.addOperands(cond);
1266
1267 OpBuilder::InsertionGuard guard(builder);
1268 Region *thenRegion = result.addRegion();
1269 builder.createBlock(thenRegion);
1270 thenBuilder(builder, result.location);
1271
1272 Region *elseRegion = result.addRegion();
1273 if (!withElseRegion)
1274 return;
1275
1276 builder.createBlock(elseRegion);
1277 elseBuilder(builder, result.location);
1278}
1279
1280//===----------------------------------------------------------------------===//
1281// ScopeOp
1282//===----------------------------------------------------------------------===//
1283
1284/// Given the region at `index`, or the parent operation if `index` is None,
1285/// return the successor regions. These are the regions that may be selected
1286/// during the flow of control. `operands` is a set of optional attributes
1287/// that correspond to a constant value for each operand, or null if that
1288/// operand is not a constant.
1289void cir::ScopeOp::getSuccessorRegions(
1290 mlir::RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
1291 // The only region always branch back to the parent operation.
1292 if (!point.isParent()) {
1293 regions.push_back(RegionSuccessor::parent());
1294 return;
1295 }
1296
1297 // If the condition isn't constant, both regions may be executed.
1298 regions.push_back(RegionSuccessor(&getScopeRegion()));
1299}
1300
1301mlir::ValueRange cir::ScopeOp::getSuccessorInputs(RegionSuccessor successor) {
1302 return successor.isParent() ? ValueRange(getOperation()->getResults())
1303 : ValueRange();
1304}
1305
1306void cir::ScopeOp::build(
1307 OpBuilder &builder, OperationState &result,
1308 function_ref<void(OpBuilder &, Type &, Location)> scopeBuilder) {
1309 assert(scopeBuilder && "the builder callback for 'then' must be present");
1310
1311 OpBuilder::InsertionGuard guard(builder);
1312 Region *scopeRegion = result.addRegion();
1313 builder.createBlock(scopeRegion);
1315
1316 mlir::Type yieldTy;
1317 scopeBuilder(builder, yieldTy, result.location);
1318
1319 if (yieldTy)
1320 result.addTypes(TypeRange{yieldTy});
1321}
1322
1323void cir::ScopeOp::build(
1324 OpBuilder &builder, OperationState &result,
1325 function_ref<void(OpBuilder &, Location)> scopeBuilder) {
1326 assert(scopeBuilder && "the builder callback for 'then' must be present");
1327 OpBuilder::InsertionGuard guard(builder);
1328 Region *scopeRegion = result.addRegion();
1329 builder.createBlock(scopeRegion);
1331 scopeBuilder(builder, result.location);
1332}
1333
1334LogicalResult cir::ScopeOp::verify() {
1335 if (getRegion().empty()) {
1336 return emitOpError() << "cir.scope must not be empty since it should "
1337 "include at least an implicit cir.yield ";
1338 }
1339
1340 mlir::Block &lastBlock = getRegion().back();
1341 if (lastBlock.empty() || !lastBlock.mightHaveTerminator() ||
1342 !lastBlock.getTerminator()->hasTrait<OpTrait::IsTerminator>())
1343 return emitOpError() << "last block of cir.scope must be terminated";
1344 return success();
1345}
1346
1347LogicalResult cir::ScopeOp::fold(FoldAdaptor /*adaptor*/,
1348 SmallVectorImpl<OpFoldResult> &results) {
1349 // Only fold "trivial" scopes: a single block containing only a `cir.yield`.
1350 if (!getRegion().hasOneBlock())
1351 return failure();
1352 Block &block = getRegion().front();
1353 if (block.getOperations().size() != 1)
1354 return failure();
1355
1356 auto yield = dyn_cast<cir::YieldOp>(block.front());
1357 if (!yield)
1358 return failure();
1359
1360 // Only fold when the scope produces a value.
1361 if (getNumResults() != 1 || yield.getNumOperands() != 1)
1362 return failure();
1363
1364 results.push_back(yield.getOperand(0));
1365 return success();
1366}
1367
1368//===----------------------------------------------------------------------===//
1369// CleanupScopeOp
1370//===----------------------------------------------------------------------===//
1371
1372void cir::CleanupScopeOp::getSuccessorRegions(
1373 mlir::RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
1374 if (!point.isParent()) {
1375 regions.push_back(RegionSuccessor::parent());
1376 return;
1377 }
1378
1379 // Execution always proceeds from the body region to the cleanup region.
1380 regions.push_back(RegionSuccessor(&getBodyRegion()));
1381 regions.push_back(RegionSuccessor(&getCleanupRegion()));
1382}
1383
1384mlir::ValueRange
1385cir::CleanupScopeOp::getSuccessorInputs(RegionSuccessor successor) {
1386 return ValueRange();
1387}
1388
1389void cir::CleanupScopeOp::build(
1390 OpBuilder &builder, OperationState &result, CleanupKind cleanupKind,
1391 function_ref<void(OpBuilder &, Location)> bodyBuilder,
1392 function_ref<void(OpBuilder &, Location)> cleanupBuilder) {
1393 result.addAttribute(getCleanupKindAttrName(result.name),
1394 CleanupKindAttr::get(builder.getContext(), cleanupKind));
1395
1396 OpBuilder::InsertionGuard guard(builder);
1397
1398 // Build body region.
1399 Region *bodyRegion = result.addRegion();
1400 builder.createBlock(bodyRegion);
1401 if (bodyBuilder)
1402 bodyBuilder(builder, result.location);
1403
1404 // Build cleanup region.
1405 Region *cleanupRegion = result.addRegion();
1406 builder.createBlock(cleanupRegion);
1407 if (cleanupBuilder)
1408 cleanupBuilder(builder, result.location);
1409}
1410
1411//===----------------------------------------------------------------------===//
1412// BrOp
1413//===----------------------------------------------------------------------===//
1414
1415/// Merges blocks connected by a unique unconditional branch.
1416///
1417/// ^bb0: ^bb0:
1418/// ... ...
1419/// cir.br ^bb1 => ...
1420/// ^bb1: cir.return
1421/// ...
1422/// cir.return
1423LogicalResult cir::BrOp::canonicalize(BrOp op, PatternRewriter &rewriter) {
1424 Block *src = op->getBlock();
1425 Block *dst = op.getDest();
1426
1427 // Do not fold self-loops.
1428 if (src == dst)
1429 return failure();
1430
1431 // Only merge when this is the unique edge between the blocks.
1432 if (src->getNumSuccessors() != 1 || dst->getSinglePredecessor() != src)
1433 return failure();
1434
1435 // Don't merge blocks that start with LabelOp or IndirectBrOp.
1436 // This is to avoid merging blocks that have an indirect predecessor.
1437 if (isa<cir::LabelOp, cir::IndirectBrOp>(dst->front()))
1438 return failure();
1439
1440 auto operands = op.getDestOperands();
1441 rewriter.eraseOp(op);
1442 rewriter.mergeBlocks(dst, src, operands);
1443 return success();
1444}
1445
1446mlir::SuccessorOperands cir::BrOp::getSuccessorOperands(unsigned index) {
1447 assert(index == 0 && "invalid successor index");
1448 return mlir::SuccessorOperands(getDestOperandsMutable());
1449}
1450
1451Block *cir::BrOp::getSuccessorForOperands(ArrayRef<Attribute>) {
1452 return getDest();
1453}
1454
1455//===----------------------------------------------------------------------===//
// IndirectBrOp
1457//===----------------------------------------------------------------------===//
1458
1459mlir::SuccessorOperands
1460cir::IndirectBrOp::getSuccessorOperands(unsigned index) {
1461 assert(index < getNumSuccessors() && "invalid successor index");
1462 return mlir::SuccessorOperands(getSuccOperandsMutable()[index]);
1463}
1464
1466 OpAsmParser &parser, Type &flagType,
1467 SmallVectorImpl<Block *> &succOperandBlocks,
1468 SmallVectorImpl<SmallVector<OpAsmParser::UnresolvedOperand>> &succOperands,
1469 SmallVectorImpl<SmallVector<Type>> &succOperandsTypes) {
1470 if (failed(parser.parseCommaSeparatedList(
1471 OpAsmParser::Delimiter::Square,
1472 [&]() {
1473 Block *destination = nullptr;
1474 SmallVector<OpAsmParser::UnresolvedOperand> operands;
1475 SmallVector<Type> operandTypes;
1476
1477 if (parser.parseSuccessor(destination).failed())
1478 return failure();
1479
1480 if (succeeded(parser.parseOptionalLParen())) {
1481 if (failed(parser.parseOperandList(
1482 operands, OpAsmParser::Delimiter::None)) ||
1483 failed(parser.parseColonTypeList(operandTypes)) ||
1484 failed(parser.parseRParen()))
1485 return failure();
1486 }
1487 succOperandBlocks.push_back(destination);
1488 succOperands.emplace_back(operands);
1489 succOperandsTypes.emplace_back(operandTypes);
1490 return success();
1491 },
1492 "successor blocks")))
1493 return failure();
1494 return success();
1495}
1496
1497void printIndirectBrOpSucessors(OpAsmPrinter &p, cir::IndirectBrOp op,
1498 Type flagType, SuccessorRange succs,
1499 OperandRangeRange succOperands,
1500 const TypeRangeRange &succOperandsTypes) {
1501 p << "[";
1502 llvm::interleave(
1503 llvm::zip(succs, succOperands),
1504 [&](auto i) {
1505 p.printNewline();
1506 p.printSuccessorAndUseList(std::get<0>(i), std::get<1>(i));
1507 },
1508 [&] { p << ','; });
1509 if (!succOperands.empty())
1510 p.printNewline();
1511 p << "]";
1512}
1513
1514//===----------------------------------------------------------------------===//
1515// BrCondOp
1516//===----------------------------------------------------------------------===//
1517
1518mlir::SuccessorOperands cir::BrCondOp::getSuccessorOperands(unsigned index) {
1519 assert(index < getNumSuccessors() && "invalid successor index");
1520 return SuccessorOperands(index == 0 ? getDestOperandsTrueMutable()
1521 : getDestOperandsFalseMutable());
1522}
1523
1524Block *cir::BrCondOp::getSuccessorForOperands(ArrayRef<Attribute> operands) {
1525 if (IntegerAttr condAttr = dyn_cast_if_present<IntegerAttr>(operands.front()))
1526 return condAttr.getValue().isOne() ? getDestTrue() : getDestFalse();
1527 return nullptr;
1528}
1529
1530//===----------------------------------------------------------------------===//
1531// CaseOp
1532//===----------------------------------------------------------------------===//
1533
1534void cir::CaseOp::getSuccessorRegions(
1535 mlir::RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
1536 if (!point.isParent()) {
1537 regions.push_back(RegionSuccessor::parent());
1538 return;
1539 }
1540 regions.push_back(RegionSuccessor(&getCaseRegion()));
1541}
1542
1543mlir::ValueRange cir::CaseOp::getSuccessorInputs(RegionSuccessor successor) {
1544 return successor.isParent() ? ValueRange(getOperation()->getResults())
1545 : ValueRange();
1546}
1547
1548void cir::CaseOp::build(OpBuilder &builder, OperationState &result,
1549 ArrayAttr value, CaseOpKind kind,
1550 OpBuilder::InsertPoint &insertPoint) {
1551 OpBuilder::InsertionGuard guardSwitch(builder);
1552 result.addAttribute("value", value);
1553 result.getOrAddProperties<Properties>().kind =
1554 cir::CaseOpKindAttr::get(builder.getContext(), kind);
1555 Region *caseRegion = result.addRegion();
1556 builder.createBlock(caseRegion);
1557
1558 insertPoint = builder.saveInsertionPoint();
1559}
1560
1561//===----------------------------------------------------------------------===//
1562// SwitchOp
1563//===----------------------------------------------------------------------===//
1564
1565void cir::SwitchOp::getSuccessorRegions(
1566 mlir::RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &region) {
1567 if (!point.isParent()) {
1568 region.push_back(RegionSuccessor::parent());
1569 return;
1570 }
1571
1572 region.push_back(RegionSuccessor(&getBody()));
1573}
1574
1575mlir::ValueRange cir::SwitchOp::getSuccessorInputs(RegionSuccessor successor) {
1576 return successor.isParent() ? ValueRange(getOperation()->getResults())
1577 : ValueRange();
1578}
1579
1580void cir::SwitchOp::build(OpBuilder &builder, OperationState &result,
1581 Value cond, BuilderOpStateCallbackRef switchBuilder) {
1582 assert(switchBuilder && "the builder callback for regions must be present");
1583 OpBuilder::InsertionGuard guardSwitch(builder);
1584 Region *switchRegion = result.addRegion();
1585 builder.createBlock(switchRegion);
1586 result.addOperands({cond});
1587 switchBuilder(builder, result.location, result);
1588}
1589
1590void cir::SwitchOp::collectCases(llvm::SmallVectorImpl<CaseOp> &cases) {
1591 walk<mlir::WalkOrder::PreOrder>([&](mlir::Operation *op) {
1592 // Don't walk in nested switch op.
1593 if (isa<cir::SwitchOp>(op) && op != *this)
1594 return WalkResult::skip();
1595
1596 if (auto caseOp = dyn_cast<cir::CaseOp>(op))
1597 cases.push_back(caseOp);
1598
1599 return WalkResult::advance();
1600 });
1601}
1602
1603bool cir::SwitchOp::isSimpleForm(llvm::SmallVectorImpl<CaseOp> &cases) {
1604 collectCases(cases);
1605
1606 if (getBody().empty())
1607 return false;
1608
1609 if (!isa<YieldOp>(getBody().front().back()))
1610 return false;
1611
1612 if (!llvm::all_of(getBody().front(),
1613 [](Operation &op) { return isa<CaseOp, YieldOp>(op); }))
1614 return false;
1615
1616 return llvm::all_of(cases, [this](CaseOp op) {
1617 return op->getParentOfType<SwitchOp>() == *this;
1618 });
1619}
1620
1621//===----------------------------------------------------------------------===//
1622// SwitchFlatOp
1623//===----------------------------------------------------------------------===//
1624
1625void cir::SwitchFlatOp::build(OpBuilder &builder, OperationState &result,
1626 Value value, Block *defaultDestination,
1627 ValueRange defaultOperands,
1628 ArrayRef<APInt> caseValues,
1629 BlockRange caseDestinations,
1630 ArrayRef<ValueRange> caseOperands) {
1631
1632 std::vector<mlir::Attribute> caseValuesAttrs;
1633 for (const APInt &val : caseValues)
1634 caseValuesAttrs.push_back(cir::IntAttr::get(value.getType(), val));
1635 mlir::ArrayAttr attrs = ArrayAttr::get(builder.getContext(), caseValuesAttrs);
1636
1637 build(builder, result, value, defaultOperands, caseOperands, attrs,
1638 defaultDestination, caseDestinations);
1639}
1640
1641/// <cases> ::= `[` (case (`,` case )* )? `]`
1642/// <case> ::= integer `:` bb-id (`(` ssa-use-and-type-list `)`)?
1643static ParseResult parseSwitchFlatOpCases(
1644 OpAsmParser &parser, Type flagType, mlir::ArrayAttr &caseValues,
1645 SmallVectorImpl<Block *> &caseDestinations,
1647 &caseOperands,
1648 SmallVectorImpl<llvm::SmallVector<Type>> &caseOperandTypes) {
1649 if (failed(parser.parseLSquare()))
1650 return failure();
1651 if (succeeded(parser.parseOptionalRSquare()))
1652 return success();
1654
1655 auto parseCase = [&]() {
1656 int64_t value = 0;
1657 if (failed(parser.parseInteger(value)))
1658 return failure();
1659
1660 values.push_back(cir::IntAttr::get(flagType, value));
1661
1662 Block *destination;
1664 llvm::SmallVector<Type> operandTypes;
1665 if (parser.parseColon() || parser.parseSuccessor(destination))
1666 return failure();
1667 if (!parser.parseOptionalLParen()) {
1668 if (parser.parseOperandList(operands, OpAsmParser::Delimiter::None,
1669 /*allowResultNumber=*/false) ||
1670 parser.parseColonTypeList(operandTypes) || parser.parseRParen())
1671 return failure();
1672 }
1673 caseDestinations.push_back(destination);
1674 caseOperands.emplace_back(operands);
1675 caseOperandTypes.emplace_back(operandTypes);
1676 return success();
1677 };
1678 if (failed(parser.parseCommaSeparatedList(parseCase)))
1679 return failure();
1680
1681 caseValues = ArrayAttr::get(flagType.getContext(), values);
1682
1683 return parser.parseRSquare();
1684}
1685
1686static void printSwitchFlatOpCases(OpAsmPrinter &p, cir::SwitchFlatOp op,
1687 Type flagType, mlir::ArrayAttr caseValues,
1688 SuccessorRange caseDestinations,
1689 OperandRangeRange caseOperands,
1690 const TypeRangeRange &caseOperandTypes) {
1691 p << '[';
1692 p.printNewline();
1693 if (!caseValues) {
1694 p << ']';
1695 return;
1696 }
1697
1698 size_t index = 0;
1699 llvm::interleave(
1700 llvm::zip(caseValues, caseDestinations),
1701 [&](auto i) {
1702 p << " ";
1703 mlir::Attribute a = std::get<0>(i);
1704 p << mlir::cast<cir::IntAttr>(a).getValue();
1705 p << ": ";
1706 p.printSuccessorAndUseList(std::get<1>(i), caseOperands[index++]);
1707 },
1708 [&] {
1709 p << ',';
1710 p.printNewline();
1711 });
1712 p.printNewline();
1713 p << ']';
1714}
1715
1716//===----------------------------------------------------------------------===//
1717// GlobalOp
1718//===----------------------------------------------------------------------===//
1719
1720static ParseResult parseConstantValue(OpAsmParser &parser,
1721 mlir::Attribute &valueAttr) {
1722 NamedAttrList attr;
1723 return parser.parseAttribute(valueAttr, "value", attr);
1724}
1725
1726static void printConstant(OpAsmPrinter &p, Attribute value) {
1727 p.printAttribute(value);
1728}
1729
1730mlir::LogicalResult cir::GlobalOp::verify() {
1731 // Verify that the initial value, if present, is either a unit attribute or
1732 // an attribute CIR supports.
1733 if (getInitialValue().has_value()) {
1734 if (checkConstantTypes(getOperation(), getSymType(), *getInitialValue())
1735 .failed())
1736 return failure();
1737 }
1738
1739 // TODO(CIR): Many other checks for properties that haven't been upstreamed
1740 // yet.
1741
1742 return success();
1743}
1744
1745void cir::GlobalOp::build(
1746 OpBuilder &odsBuilder, OperationState &odsState, llvm::StringRef sym_name,
1747 mlir::Type sym_type, bool isConstant, cir::GlobalLinkageKind linkage,
1748 function_ref<void(OpBuilder &, Location)> ctorBuilder,
1749 function_ref<void(OpBuilder &, Location)> dtorBuilder) {
1750 odsState.addAttribute(getSymNameAttrName(odsState.name),
1751 odsBuilder.getStringAttr(sym_name));
1752 odsState.addAttribute(getSymTypeAttrName(odsState.name),
1753 mlir::TypeAttr::get(sym_type));
1754 if (isConstant)
1755 odsState.addAttribute(getConstantAttrName(odsState.name),
1756 odsBuilder.getUnitAttr());
1757
1758 cir::GlobalLinkageKindAttr linkageAttr =
1759 cir::GlobalLinkageKindAttr::get(odsBuilder.getContext(), linkage);
1760 odsState.addAttribute(getLinkageAttrName(odsState.name), linkageAttr);
1761
1762 Region *ctorRegion = odsState.addRegion();
1763 if (ctorBuilder) {
1764 odsBuilder.createBlock(ctorRegion);
1765 ctorBuilder(odsBuilder, odsState.location);
1766 }
1767
1768 Region *dtorRegion = odsState.addRegion();
1769 if (dtorBuilder) {
1770 odsBuilder.createBlock(dtorRegion);
1771 dtorBuilder(odsBuilder, odsState.location);
1772 }
1773
1774 odsState.addAttribute(getGlobalVisibilityAttrName(odsState.name),
1775 cir::VisibilityAttr::get(odsBuilder.getContext()));
1776}
1777
1778/// Given the region at `index`, or the parent operation if `index` is None,
1779/// return the successor regions. These are the regions that may be selected
1780/// during the flow of control. `operands` is a set of optional attributes that
1781/// correspond to a constant value for each operand, or null if that operand is
1782/// not a constant.
1783void cir::GlobalOp::getSuccessorRegions(
1784 mlir::RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
1785 // The `ctor` and `dtor` regions always branch back to the parent operation.
1786 if (!point.isParent()) {
1787 regions.push_back(RegionSuccessor::parent());
1788 return;
1789 }
1790
1791 // Don't consider the ctor region if it is empty.
1792 Region *ctorRegion = &this->getCtorRegion();
1793 if (ctorRegion->empty())
1794 ctorRegion = nullptr;
1795
1796 // Don't consider the dtor region if it is empty.
1797 Region *dtorRegion = &this->getDtorRegion();
1798 if (dtorRegion->empty())
1799 dtorRegion = nullptr;
1800
1801 // If the condition isn't constant, both regions may be executed.
1802 if (ctorRegion)
1803 regions.push_back(RegionSuccessor(ctorRegion));
1804 if (dtorRegion)
1805 regions.push_back(RegionSuccessor(dtorRegion));
1806}
1807
1808mlir::ValueRange cir::GlobalOp::getSuccessorInputs(RegionSuccessor successor) {
1809 return successor.isParent() ? ValueRange(getOperation()->getResults())
1810 : ValueRange();
1811}
1812
1813static void printGlobalOpTypeAndInitialValue(OpAsmPrinter &p, cir::GlobalOp op,
1814 TypeAttr type, Attribute initAttr,
1815 mlir::Region &ctorRegion,
1816 mlir::Region &dtorRegion) {
1817 auto printType = [&]() { p << ": " << type; };
1818 if (!op.isDeclaration()) {
1819 p << "= ";
1820 if (!ctorRegion.empty()) {
1821 p << "ctor ";
1822 printType();
1823 p << " ";
1824 p.printRegion(ctorRegion,
1825 /*printEntryBlockArgs=*/false,
1826 /*printBlockTerminators=*/false);
1827 } else {
1828 // This also prints the type...
1829 if (initAttr)
1830 printConstant(p, initAttr);
1831 }
1832
1833 if (!dtorRegion.empty()) {
1834 p << " dtor ";
1835 p.printRegion(dtorRegion,
1836 /*printEntryBlockArgs=*/false,
1837 /*printBlockTerminators=*/false);
1838 }
1839 } else {
1840 printType();
1841 }
1842}
1843
1844static ParseResult parseGlobalOpTypeAndInitialValue(OpAsmParser &parser,
1845 TypeAttr &typeAttr,
1846 Attribute &initialValueAttr,
1847 mlir::Region &ctorRegion,
1848 mlir::Region &dtorRegion) {
1849 mlir::Type opTy;
1850 if (parser.parseOptionalEqual().failed()) {
1851 // Absence of equal means a declaration, so we need to parse the type.
1852 // cir.global @a : !cir.int<s, 32>
1853 if (parser.parseColonType(opTy))
1854 return failure();
1855 } else {
1856 // Parse contructor, example:
1857 // cir.global @rgb = ctor : type { ... }
1858 if (!parser.parseOptionalKeyword("ctor")) {
1859 if (parser.parseColonType(opTy))
1860 return failure();
1861 auto parseLoc = parser.getCurrentLocation();
1862 if (parser.parseRegion(ctorRegion, /*arguments=*/{}, /*argTypes=*/{}))
1863 return failure();
1864 if (ensureRegionTerm(parser, ctorRegion, parseLoc).failed())
1865 return failure();
1866 } else {
1867 // Parse constant with initializer, examples:
1868 // cir.global @y = 3.400000e+00 : f32
1869 // cir.global @rgb = #cir.const_array<[...] : !cir.array<i8 x 3>>
1870 if (parseConstantValue(parser, initialValueAttr).failed())
1871 return failure();
1872
1873 assert(mlir::isa<mlir::TypedAttr>(initialValueAttr) &&
1874 "Non-typed attrs shouldn't appear here.");
1875 auto typedAttr = mlir::cast<mlir::TypedAttr>(initialValueAttr);
1876 opTy = typedAttr.getType();
1877 }
1878
1879 // Parse destructor, example:
1880 // dtor { ... }
1881 if (!parser.parseOptionalKeyword("dtor")) {
1882 auto parseLoc = parser.getCurrentLocation();
1883 if (parser.parseRegion(dtorRegion, /*arguments=*/{}, /*argTypes=*/{}))
1884 return failure();
1885 if (ensureRegionTerm(parser, dtorRegion, parseLoc).failed())
1886 return failure();
1887 }
1888 }
1889
1890 typeAttr = TypeAttr::get(opTy);
1891 return success();
1892}
1893
1894//===----------------------------------------------------------------------===//
1895// GetGlobalOp
1896//===----------------------------------------------------------------------===//
1897
1898LogicalResult
1899cir::GetGlobalOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1900 // Verify that the result type underlying pointer type matches the type of
1901 // the referenced cir.global or cir.func op.
1902 mlir::Operation *op =
1903 symbolTable.lookupNearestSymbolFrom(*this, getNameAttr());
1904 if (op == nullptr || !(isa<GlobalOp>(op) || isa<FuncOp>(op)))
1905 return emitOpError("'")
1906 << getName()
1907 << "' does not reference a valid cir.global or cir.func";
1908
1909 mlir::Type symTy;
1910 if (auto g = dyn_cast<GlobalOp>(op)) {
1911 symTy = g.getSymType();
1913 // Verify that for thread local global access, the global needs to
1914 // be marked with tls bits.
1915 if (getTls() && !g.getTlsModel())
1916 return emitOpError("access to global not marked thread local");
1917
1918 // Verify that the static_local attribute on GetGlobalOp matches the
1919 // static_local_guard attribute on GlobalOp. GetGlobalOp uses a UnitAttr,
1920 // GlobalOp uses StaticLocalGuardAttr. Both should be present, or neither.
1921 bool getGlobalIsStaticLocal = getStaticLocal();
1922 bool globalIsStaticLocal = g.getStaticLocalGuard().has_value();
1923 if (getGlobalIsStaticLocal != globalIsStaticLocal &&
1924 !getOperation()->getParentOfType<cir::GlobalOp>())
1925 return emitOpError("static_local attribute mismatch");
1926 } else if (auto f = dyn_cast<FuncOp>(op)) {
1927 symTy = f.getFunctionType();
1928 } else {
1929 llvm_unreachable("Unexpected operation for GetGlobalOp");
1930 }
1931
1932 auto resultType = dyn_cast<PointerType>(getAddr().getType());
1933 if (!resultType || symTy != resultType.getPointee())
1934 return emitOpError("result type pointee type '")
1935 << resultType.getPointee() << "' does not match type " << symTy
1936 << " of the global @" << getName();
1937
1938 return success();
1939}
1940
1941//===----------------------------------------------------------------------===//
1942// VTableAddrPointOp
1943//===----------------------------------------------------------------------===//
1944
1945LogicalResult
1946cir::VTableAddrPointOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1947 StringRef name = getName();
1948
1949 // Verify that the result type underlying pointer type matches the type of
1950 // the referenced cir.global.
1951 auto op =
1952 symbolTable.lookupNearestSymbolFrom<cir::GlobalOp>(*this, getNameAttr());
1953 if (!op)
1954 return emitOpError("'")
1955 << name << "' does not reference a valid cir.global";
1956 std::optional<mlir::Attribute> init = op.getInitialValue();
1957 if (!init)
1958 return success();
1959 if (!isa<cir::VTableAttr>(*init))
1960 return emitOpError("Expected #cir.vtable in initializer for global '")
1961 << name << "'";
1962 return success();
1963}
1964
1965//===----------------------------------------------------------------------===//
1966// VTTAddrPointOp
1967//===----------------------------------------------------------------------===//
1968
1969LogicalResult
1970cir::VTTAddrPointOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1971 // VTT ptr is not coming from a symbol.
1972 if (!getName())
1973 return success();
1974 StringRef name = *getName();
1975
1976 // Verify that the result type underlying pointer type matches the type of
1977 // the referenced cir.global op.
1978 auto op =
1979 symbolTable.lookupNearestSymbolFrom<cir::GlobalOp>(*this, getNameAttr());
1980 if (!op)
1981 return emitOpError("'")
1982 << name << "' does not reference a valid cir.global";
1983 std::optional<mlir::Attribute> init = op.getInitialValue();
1984 if (!init)
1985 return success();
1986 if (!isa<cir::ConstArrayAttr>(*init))
1987 return emitOpError(
1988 "Expected constant array in initializer for global VTT '")
1989 << name << "'";
1990 return success();
1991}
1992
1993LogicalResult cir::VTTAddrPointOp::verify() {
1994 // The operation uses either a symbol or a value to operate, but not both
1995 if (getName() && getSymAddr())
1996 return emitOpError("should use either a symbol or value, but not both");
1997
1998 // If not a symbol, stick with the concrete type used for getSymAddr.
1999 if (getSymAddr())
2000 return success();
2001
2002 mlir::Type resultType = getAddr().getType();
2003 mlir::Type resTy = cir::PointerType::get(
2004 cir::PointerType::get(cir::VoidType::get(getContext())));
2005
2006 if (resultType != resTy)
2007 return emitOpError("result type must be ")
2008 << resTy << ", but provided result type is " << resultType;
2009 return success();
2010}
2011
2012//===----------------------------------------------------------------------===//
2013// FuncOp
2014//===----------------------------------------------------------------------===//
2015
2016/// Returns the name used for the linkage attribute. This *must* correspond to
2017/// the name of the attribute in ODS.
2018static llvm::StringRef getLinkageAttrNameString() { return "linkage"; }
2019
2020void cir::FuncOp::build(OpBuilder &builder, OperationState &result,
2021 StringRef name, FuncType type,
2022 GlobalLinkageKind linkage) {
2023 result.addRegion();
2024 result.addAttribute(SymbolTable::getSymbolAttrName(),
2025 builder.getStringAttr(name));
2026 result.addAttribute(getFunctionTypeAttrName(result.name),
2027 TypeAttr::get(type));
2028 result.addAttribute(
2030 GlobalLinkageKindAttr::get(builder.getContext(), linkage));
2031 result.addAttribute(getGlobalVisibilityAttrName(result.name),
2032 cir::VisibilityAttr::get(builder.getContext()));
2033}
2034
2035ParseResult cir::FuncOp::parse(OpAsmParser &parser, OperationState &state) {
2036 llvm::SMLoc loc = parser.getCurrentLocation();
2037 mlir::Builder &builder = parser.getBuilder();
2038
2039 mlir::StringAttr builtinNameAttr = getBuiltinAttrName(state.name);
2040 mlir::StringAttr coroutineNameAttr = getCoroutineAttrName(state.name);
2041 mlir::StringAttr inlineKindNameAttr = getInlineKindAttrName(state.name);
2042 mlir::StringAttr lambdaNameAttr = getLambdaAttrName(state.name);
2043 mlir::StringAttr noProtoNameAttr = getNoProtoAttrName(state.name);
2044 mlir::StringAttr comdatNameAttr = getComdatAttrName(state.name);
2045 mlir::StringAttr visNameAttr = getSymVisibilityAttrName(state.name);
2046 mlir::StringAttr visibilityNameAttr = getGlobalVisibilityAttrName(state.name);
2047 mlir::StringAttr dsoLocalNameAttr = getDsoLocalAttrName(state.name);
2048 mlir::StringAttr specialMemberAttr = getCxxSpecialMemberAttrName(state.name);
2049
2050 if (::mlir::succeeded(parser.parseOptionalKeyword(builtinNameAttr.strref())))
2051 state.addAttribute(builtinNameAttr, parser.getBuilder().getUnitAttr());
2052 if (::mlir::succeeded(
2053 parser.parseOptionalKeyword(coroutineNameAttr.strref())))
2054 state.addAttribute(coroutineNameAttr, parser.getBuilder().getUnitAttr());
2055
2056 // Parse optional inline kind attribute
2057 cir::InlineKindAttr inlineKindAttr;
2058 if (failed(parseInlineKindAttr(parser, inlineKindAttr)))
2059 return failure();
2060 if (inlineKindAttr)
2061 state.addAttribute(inlineKindNameAttr, inlineKindAttr);
2062
2063 if (::mlir::succeeded(parser.parseOptionalKeyword(lambdaNameAttr.strref())))
2064 state.addAttribute(lambdaNameAttr, parser.getBuilder().getUnitAttr());
2065 if (parser.parseOptionalKeyword(noProtoNameAttr).succeeded())
2066 state.addAttribute(noProtoNameAttr, parser.getBuilder().getUnitAttr());
2067
2068 if (parser.parseOptionalKeyword(comdatNameAttr).succeeded())
2069 state.addAttribute(comdatNameAttr, parser.getBuilder().getUnitAttr());
2070
2071 // Default to external linkage if no keyword is provided.
2072 state.addAttribute(getLinkageAttrNameString(),
2073 GlobalLinkageKindAttr::get(
2074 parser.getContext(),
2076 parser, GlobalLinkageKind::ExternalLinkage)));
2077
2078 ::llvm::StringRef visAttrStr;
2079 if (parser.parseOptionalKeyword(&visAttrStr, {"private", "public", "nested"})
2080 .succeeded()) {
2081 state.addAttribute(visNameAttr,
2082 parser.getBuilder().getStringAttr(visAttrStr));
2083 }
2084
2085 cir::VisibilityAttr cirVisibilityAttr;
2086 parseVisibilityAttr(parser, cirVisibilityAttr);
2087 state.addAttribute(visibilityNameAttr, cirVisibilityAttr);
2088
2089 if (parser.parseOptionalKeyword(dsoLocalNameAttr).succeeded())
2090 state.addAttribute(dsoLocalNameAttr, parser.getBuilder().getUnitAttr());
2091
2092 StringAttr nameAttr;
2093 if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
2094 state.attributes))
2095 return failure();
2099 bool isVariadic = false;
2100 if (function_interface_impl::parseFunctionSignatureWithArguments(
2101 parser, /*allowVariadic=*/true, arguments, isVariadic, resultTypes,
2102 resultAttrs))
2103 return failure();
2106 bool argAttrsEmpty = true;
2107 for (OpAsmParser::Argument &arg : arguments) {
2108 argTypes.push_back(arg.type);
2109 // Add the 'empty' attribute anyway to make sure the arity matches, but we
2110 // only want to 'set' the attribute at the top level if there is SOME data
2111 // along the way.
2112 argAttrs.push_back(arg.attrs);
2113 if (arg.attrs)
2114 argAttrsEmpty = false;
2115 }
2116
2117 // These should be in sync anyway, but test both of them anyway.
2118 if (resultTypes.size() > 1 || resultAttrs.size() > 1)
2119 return parser.emitError(
2120 loc, "functions with multiple return types are not supported");
2121
2122 mlir::Type returnType =
2123 (resultTypes.empty() ? cir::VoidType::get(builder.getContext())
2124 : resultTypes.front());
2125
2126 cir::FuncType fnType = cir::FuncType::get(argTypes, returnType, isVariadic);
2127 if (!fnType)
2128 return failure();
2129
2130 state.addAttribute(getFunctionTypeAttrName(state.name),
2131 TypeAttr::get(fnType));
2132
2133 if (!resultAttrs.empty() && resultAttrs[0])
2134 state.addAttribute(
2135 getResAttrsAttrName(state.name),
2136 mlir::ArrayAttr::get(parser.getContext(), {resultAttrs[0]}));
2137
2138 if (!argAttrsEmpty)
2139 state.addAttribute(getArgAttrsAttrName(state.name),
2140 mlir::ArrayAttr::get(parser.getContext(), argAttrs));
2141
2142 bool hasAlias = false;
2143 mlir::StringAttr aliaseeNameAttr = getAliaseeAttrName(state.name);
2144 if (parser.parseOptionalKeyword("alias").succeeded()) {
2145 if (parser.parseLParen().failed())
2146 return failure();
2147 mlir::StringAttr aliaseeAttr;
2148 if (parser.parseOptionalSymbolName(aliaseeAttr).failed())
2149 return failure();
2150 state.addAttribute(aliaseeNameAttr, FlatSymbolRefAttr::get(aliaseeAttr));
2151 if (parser.parseRParen().failed())
2152 return failure();
2153 hasAlias = true;
2154 }
2155
2156 mlir::StringAttr personalityNameAttr = getPersonalityAttrName(state.name);
2157 if (parser.parseOptionalKeyword("personality").succeeded()) {
2158 if (parser.parseLParen().failed())
2159 return failure();
2160 mlir::StringAttr personalityAttr;
2161 if (parser.parseOptionalSymbolName(personalityAttr).failed())
2162 return failure();
2163 state.addAttribute(personalityNameAttr,
2164 FlatSymbolRefAttr::get(personalityAttr));
2165 if (parser.parseRParen().failed())
2166 return failure();
2167 }
2168
2169 auto parseGlobalDtorCtor =
2170 [&](StringRef keyword,
2171 llvm::function_ref<void(std::optional<int> prio)> createAttr)
2172 -> mlir::LogicalResult {
2173 if (mlir::succeeded(parser.parseOptionalKeyword(keyword))) {
2174 std::optional<int> priority;
2175 if (mlir::succeeded(parser.parseOptionalLParen())) {
2176 auto parsedPriority = mlir::FieldParser<int>::parse(parser);
2177 if (mlir::failed(parsedPriority))
2178 return parser.emitError(parser.getCurrentLocation(),
2179 "failed to parse 'priority', of type 'int'");
2180 priority = parsedPriority.value_or(int());
2181 // Parse literal ')'
2182 if (parser.parseRParen())
2183 return failure();
2184 }
2185 createAttr(priority);
2186 }
2187 return success();
2188 };
2189
2190 // Parse CXXSpecialMember attribute
2191 if (parser.parseOptionalKeyword("special_member").succeeded()) {
2192 cir::CXXCtorAttr ctorAttr;
2193 cir::CXXDtorAttr dtorAttr;
2194 cir::CXXAssignAttr assignAttr;
2195 if (parser.parseLess().failed())
2196 return failure();
2197 if (parser.parseOptionalAttribute(ctorAttr).has_value())
2198 state.addAttribute(specialMemberAttr, ctorAttr);
2199 else if (parser.parseOptionalAttribute(dtorAttr).has_value())
2200 state.addAttribute(specialMemberAttr, dtorAttr);
2201 else if (parser.parseOptionalAttribute(assignAttr).has_value())
2202 state.addAttribute(specialMemberAttr, assignAttr);
2203 if (parser.parseGreater().failed())
2204 return failure();
2205 }
2206
2207 if (parseGlobalDtorCtor("global_ctor", [&](std::optional<int> priority) {
2208 mlir::IntegerAttr globalCtorPriorityAttr =
2209 builder.getI32IntegerAttr(priority.value_or(65535));
2210 state.addAttribute(getGlobalCtorPriorityAttrName(state.name),
2211 globalCtorPriorityAttr);
2212 }).failed())
2213 return failure();
2214
2215 if (parseGlobalDtorCtor("global_dtor", [&](std::optional<int> priority) {
2216 mlir::IntegerAttr globalDtorPriorityAttr =
2217 builder.getI32IntegerAttr(priority.value_or(65535));
2218 state.addAttribute(getGlobalDtorPriorityAttrName(state.name),
2219 globalDtorPriorityAttr);
2220 }).failed())
2221 return failure();
2222
2223 if (parser.parseOptionalKeyword("side_effect").succeeded()) {
2224 cir::SideEffect sideEffect;
2225
2226 if (parser.parseLParen().failed() ||
2227 parseCIRKeyword<cir::SideEffect>(parser, sideEffect).failed() ||
2228 parser.parseRParen().failed())
2229 return failure();
2230
2231 auto attr = cir::SideEffectAttr::get(parser.getContext(), sideEffect);
2232 state.addAttribute(CIRDialect::getSideEffectAttrName(), attr);
2233 }
2234
2235 // Parse the rest of the attributes.
2236 NamedAttrList parsedAttrs;
2237 if (parser.parseOptionalAttrDictWithKeyword(parsedAttrs))
2238 return failure();
2239
2240 for (StringRef disallowed : cir::FuncOp::getAttributeNames()) {
2241 if (parsedAttrs.get(disallowed))
2242 return parser.emitError(loc, "attribute '")
2243 << disallowed
2244 << "' should not be specified in the explicit attribute list";
2245 }
2246
2247 state.attributes.append(parsedAttrs);
2248
2249 // Parse the optional function body.
2250 auto *body = state.addRegion();
2251 OptionalParseResult parseResult = parser.parseOptionalRegion(
2252 *body, arguments, /*enableNameShadowing=*/false);
2253 if (parseResult.has_value()) {
2254 if (hasAlias)
2255 return parser.emitError(loc, "function alias shall not have a body");
2256 if (failed(*parseResult))
2257 return failure();
2258 // Function body was parsed, make sure its not empty.
2259 if (body->empty())
2260 return parser.emitError(loc, "expected non-empty function body");
2261 }
2262
2263 return success();
2264}
2265
// This function corresponds to `llvm::GlobalValue::isDeclaration` and should
// have a similar implementation. We don't currently support ifuncs or
// materializable functions, but those should be handled here as they are
// implemented.
2269bool cir::FuncOp::isDeclaration() {
2271
2272 std::optional<StringRef> aliasee = getAliasee();
2273 if (!aliasee)
2274 return getFunctionBody().empty();
2275
2276 // Aliases are always definitions.
2277 return false;
2278}
2279
2280bool cir::FuncOp::isCXXSpecialMemberFunction() {
2281 return getCxxSpecialMemberAttr() != nullptr;
2282}
2283
2284bool cir::FuncOp::isCxxConstructor() {
2285 auto attr = getCxxSpecialMemberAttr();
2286 return attr && dyn_cast<CXXCtorAttr>(attr);
2287}
2288
2289bool cir::FuncOp::isCxxDestructor() {
2290 auto attr = getCxxSpecialMemberAttr();
2291 return attr && dyn_cast<CXXDtorAttr>(attr);
2292}
2293
2294bool cir::FuncOp::isCxxSpecialAssignment() {
2295 auto attr = getCxxSpecialMemberAttr();
2296 return attr && dyn_cast<CXXAssignAttr>(attr);
2297}
2298
2299std::optional<CtorKind> cir::FuncOp::getCxxConstructorKind() {
2300 mlir::Attribute attr = getCxxSpecialMemberAttr();
2301 if (attr) {
2302 if (auto ctor = dyn_cast<CXXCtorAttr>(attr))
2303 return ctor.getCtorKind();
2304 }
2305 return std::nullopt;
2306}
2307
2308std::optional<AssignKind> cir::FuncOp::getCxxSpecialAssignKind() {
2309 mlir::Attribute attr = getCxxSpecialMemberAttr();
2310 if (attr) {
2311 if (auto assign = dyn_cast<CXXAssignAttr>(attr))
2312 return assign.getAssignKind();
2313 }
2314 return std::nullopt;
2315}
2316
2317bool cir::FuncOp::isCxxTrivialMemberFunction() {
2318 mlir::Attribute attr = getCxxSpecialMemberAttr();
2319 if (attr) {
2320 if (auto ctor = dyn_cast<CXXCtorAttr>(attr))
2321 return ctor.getIsTrivial();
2322 if (auto dtor = dyn_cast<CXXDtorAttr>(attr))
2323 return dtor.getIsTrivial();
2324 if (auto assign = dyn_cast<CXXAssignAttr>(attr))
2325 return assign.getIsTrivial();
2326 }
2327 return false;
2328}
2329
2330mlir::Region *cir::FuncOp::getCallableRegion() {
2331 // TODO(CIR): This function will have special handling for aliases and a
2332 // check for an external function, once those features have been upstreamed.
2333 return &getBody();
2334}
2335
2336void cir::FuncOp::print(OpAsmPrinter &p) {
2337 if (getBuiltin())
2338 p << " builtin";
2339
2340 if (getCoroutine())
2341 p << " coroutine";
2342
2343 printInlineKindAttr(p, getInlineKindAttr());
2344
2345 if (getLambda())
2346 p << " lambda";
2347
2348 if (getNoProto())
2349 p << " no_proto";
2350
2351 if (getComdat())
2352 p << " comdat";
2353
2354 if (getLinkage() != GlobalLinkageKind::ExternalLinkage)
2355 p << ' ' << stringifyGlobalLinkageKind(getLinkage());
2356
2357 mlir::SymbolTable::Visibility vis = getVisibility();
2358 if (vis != mlir::SymbolTable::Visibility::Public)
2359 p << ' ' << vis;
2360
2361 cir::VisibilityAttr cirVisibilityAttr = getGlobalVisibilityAttr();
2362 if (!cirVisibilityAttr.isDefault()) {
2363 p << ' ';
2364 printVisibilityAttr(p, cirVisibilityAttr);
2365 }
2366
2367 if (getDsoLocal())
2368 p << " dso_local";
2369
2370 p << ' ';
2371 p.printSymbolName(getSymName());
2372 cir::FuncType fnType = getFunctionType();
2373 function_interface_impl::printFunctionSignature(
2374 p, *this, fnType.getInputs(), fnType.isVarArg(), fnType.getReturnTypes());
2375
2376 if (std::optional<StringRef> aliaseeName = getAliasee()) {
2377 p << " alias(";
2378 p.printSymbolName(*aliaseeName);
2379 p << ")";
2380 }
2381
2382 if (std::optional<StringRef> personalityName = getPersonality()) {
2383 p << " personality(";
2384 p.printSymbolName(*personalityName);
2385 p << ")";
2386 }
2387
2388 if (auto specialMemberAttr = getCxxSpecialMember()) {
2389 p << " special_member<";
2390 p.printAttribute(*specialMemberAttr);
2391 p << '>';
2392 }
2393
2394 if (auto globalCtorPriority = getGlobalCtorPriority()) {
2395 p << " global_ctor";
2396 if (globalCtorPriority.value() != 65535)
2397 p << "(" << globalCtorPriority.value() << ")";
2398 }
2399
2400 if (auto globalDtorPriority = getGlobalDtorPriority()) {
2401 p << " global_dtor";
2402 if (globalDtorPriority.value() != 65535)
2403 p << "(" << globalDtorPriority.value() << ")";
2404 }
2405
2406 if (std::optional<cir::SideEffect> sideEffect = getSideEffect();
2407 sideEffect && *sideEffect != cir::SideEffect::All) {
2408 p << " side_effect(";
2409 p << stringifySideEffect(*sideEffect);
2410 p << ")";
2411 }
2412
2413 function_interface_impl::printFunctionAttributes(
2414 p, *this, cir::FuncOp::getAttributeNames());
2415
2416 // Print the body if this is not an external function.
2417 Region &body = getOperation()->getRegion(0);
2418 if (!body.empty()) {
2419 p << ' ';
2420 p.printRegion(body, /*printEntryBlockArgs=*/false,
2421 /*printBlockTerminators=*/true);
2422 }
2423}
2424
2425mlir::LogicalResult cir::FuncOp::verify() {
2426
2427 if (!isDeclaration() && getCoroutine()) {
2428 bool foundAwait = false;
2429 this->walk([&](Operation *op) {
2430 if (auto await = dyn_cast<AwaitOp>(op)) {
2431 foundAwait = true;
2432 return;
2433 }
2434 });
2435 if (!foundAwait)
2436 return emitOpError()
2437 << "coroutine body must use at least one cir.await op";
2438 }
2439
2440 llvm::SmallSet<llvm::StringRef, 16> labels;
2441 llvm::SmallSet<llvm::StringRef, 16> gotos;
2442 llvm::SmallSet<llvm::StringRef, 16> blockAddresses;
2443 bool invalidBlockAddress = false;
2444 getOperation()->walk([&](mlir::Operation *op) {
2445 if (auto lab = dyn_cast<cir::LabelOp>(op)) {
2446 labels.insert(lab.getLabel());
2447 } else if (auto goTo = dyn_cast<cir::GotoOp>(op)) {
2448 gotos.insert(goTo.getLabel());
2449 } else if (auto blkAdd = dyn_cast<cir::BlockAddressOp>(op)) {
2450 if (blkAdd.getBlockAddrInfoAttr().getFunc().getAttr() != getSymName()) {
2451 // Stop the walk early, no need to continue
2452 invalidBlockAddress = true;
2453 return mlir::WalkResult::interrupt();
2454 }
2455 blockAddresses.insert(blkAdd.getBlockAddrInfoAttr().getLabel());
2456 }
2457 return mlir::WalkResult::advance();
2458 });
2459
2460 if (invalidBlockAddress)
2461 return emitOpError() << "blockaddress references a different function";
2462
2463 llvm::SmallSet<llvm::StringRef, 16> mismatched;
2464 if (!labels.empty() || !gotos.empty()) {
2465 mismatched = llvm::set_difference(gotos, labels);
2466
2467 if (!mismatched.empty())
2468 return emitOpError() << "goto/label mismatch";
2469 }
2470
2471 mismatched.clear();
2472
2473 if (!labels.empty() || !blockAddresses.empty()) {
2474 mismatched = llvm::set_difference(blockAddresses, labels);
2475
2476 if (!mismatched.empty())
2477 return emitOpError()
2478 << "expects an existing label target in the referenced function";
2479 }
2480
2481 return success();
2482}
2483
2484//===----------------------------------------------------------------------===//
2485// AddOp / SubOp / MulOp
2486//===----------------------------------------------------------------------===//
2487
2488static LogicalResult verifyBinaryOverflowOp(mlir::Operation *op,
2489 bool noSignedWrap,
2490 bool noUnsignedWrap, bool saturated,
2491 bool hasSat) {
2492 bool noWrap = noSignedWrap || noUnsignedWrap;
2493 if (!isa<cir::IntType>(op->getResultTypes()[0]) && noWrap)
2494 return op->emitError()
2495 << "only operations on integer values may have nsw/nuw flags";
2496 if (hasSat && saturated && !isa<cir::IntType>(op->getResultTypes()[0]))
2497 return op->emitError()
2498 << "only operations on integer values may have sat flag";
2499 if (hasSat && noWrap && saturated)
2500 return op->emitError()
2501 << "the nsw/nuw flags and the saturated flag are mutually exclusive";
2502 return mlir::success();
2503}
2504
2505LogicalResult cir::AddOp::verify() {
2506 return verifyBinaryOverflowOp(getOperation(), getNoSignedWrap(),
2507 getNoUnsignedWrap(), getSaturated(),
2508 /*hasSat=*/true);
2509}
2510
2511LogicalResult cir::SubOp::verify() {
2512 return verifyBinaryOverflowOp(getOperation(), getNoSignedWrap(),
2513 getNoUnsignedWrap(), getSaturated(),
2514 /*hasSat=*/true);
2515}
2516
2517LogicalResult cir::MulOp::verify() {
2518 return verifyBinaryOverflowOp(getOperation(), getNoSignedWrap(),
2519 getNoUnsignedWrap(), /*saturated=*/false,
2520 /*hasSat=*/false);
2521}
2522
2523//===----------------------------------------------------------------------===//
2524// TernaryOp
2525//===----------------------------------------------------------------------===//
2526
2527/// Given the region at `point`, or the parent operation if `point` is None,
2528/// return the successor regions. These are the regions that may be selected
2529/// during the flow of control. `operands` is a set of optional attributes that
2530/// correspond to a constant value for each operand, or null if that operand is
2531/// not a constant.
2532void cir::TernaryOp::getSuccessorRegions(
2533 mlir::RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
2534 // The `true` and the `false` region branch back to the parent operation.
2535 if (!point.isParent()) {
2536 regions.push_back(RegionSuccessor::parent());
2537 return;
2538 }
2539
2540 // When branching from the parent operation, both the true and false
2541 // regions are considered possible successors
2542 regions.push_back(RegionSuccessor(&getTrueRegion()));
2543 regions.push_back(RegionSuccessor(&getFalseRegion()));
2544}
2545
2546mlir::ValueRange cir::TernaryOp::getSuccessorInputs(RegionSuccessor successor) {
2547 return successor.isParent() ? ValueRange(getOperation()->getResults())
2548 : ValueRange();
2549}
2550
2551void cir::TernaryOp::build(
2552 OpBuilder &builder, OperationState &result, Value cond,
2553 function_ref<void(OpBuilder &, Location)> trueBuilder,
2554 function_ref<void(OpBuilder &, Location)> falseBuilder) {
2555 result.addOperands(cond);
2556 OpBuilder::InsertionGuard guard(builder);
2557 Region *trueRegion = result.addRegion();
2558 builder.createBlock(trueRegion);
2559 trueBuilder(builder, result.location);
2560 Region *falseRegion = result.addRegion();
2561 builder.createBlock(falseRegion);
2562 falseBuilder(builder, result.location);
2563
2564 // Get result type from whichever branch has a yield (the other may have
2565 // unreachable from a throw expression)
2566 cir::YieldOp yield;
2567 if (trueRegion->back().mightHaveTerminator())
2568 yield = dyn_cast_or_null<cir::YieldOp>(trueRegion->back().getTerminator());
2569 if (!yield && falseRegion->back().mightHaveTerminator())
2570 yield = dyn_cast_or_null<cir::YieldOp>(falseRegion->back().getTerminator());
2571
2572 assert((!yield || yield.getNumOperands() <= 1) &&
2573 "expected zero or one result type");
2574 if (yield && yield.getNumOperands() == 1)
2575 result.addTypes(TypeRange{yield.getOperandTypes().front()});
2576}
2577
2578//===----------------------------------------------------------------------===//
2579// SelectOp
2580//===----------------------------------------------------------------------===//
2581
2582OpFoldResult cir::SelectOp::fold(FoldAdaptor adaptor) {
2583 mlir::Attribute condition = adaptor.getCondition();
2584 if (condition) {
2585 bool conditionValue = mlir::cast<cir::BoolAttr>(condition).getValue();
2586 return conditionValue ? getTrueValue() : getFalseValue();
2587 }
2588
2589 // cir.select if %0 then x else x -> x
2590 mlir::Attribute trueValue = adaptor.getTrueValue();
2591 mlir::Attribute falseValue = adaptor.getFalseValue();
2592 if (trueValue == falseValue)
2593 return trueValue;
2594 if (getTrueValue() == getFalseValue())
2595 return getTrueValue();
2596
2597 return {};
2598}
2599
2600LogicalResult cir::SelectOp::verify() {
2601 // AllTypesMatch already guarantees trueVal and falseVal have matching types.
2602 auto condTy = dyn_cast<cir::VectorType>(getCondition().getType());
2603
2604 // If condition is not a vector, no further checks are needed.
2605 if (!condTy)
2606 return success();
2607
2608 // When condition is a vector, both other operands must also be vectors.
2609 if (!isa<cir::VectorType>(getTrueValue().getType()) ||
2610 !isa<cir::VectorType>(getFalseValue().getType())) {
2611 return emitOpError()
2612 << "expected both true and false operands to be vector types "
2613 "when the condition is a vector boolean type";
2614 }
2615
2616 return success();
2617}
2618
2619//===----------------------------------------------------------------------===//
2620// ShiftOp
2621//===----------------------------------------------------------------------===//
2622LogicalResult cir::ShiftOp::verify() {
2623 mlir::Operation *op = getOperation();
2624 auto op0VecTy = mlir::dyn_cast<cir::VectorType>(op->getOperand(0).getType());
2625 auto op1VecTy = mlir::dyn_cast<cir::VectorType>(op->getOperand(1).getType());
2626 if (!op0VecTy ^ !op1VecTy)
2627 return emitOpError() << "input types cannot be one vector and one scalar";
2628
2629 if (op0VecTy) {
2630 if (op0VecTy.getSize() != op1VecTy.getSize())
2631 return emitOpError() << "input vector types must have the same size";
2632
2633 auto opResultTy = mlir::dyn_cast<cir::VectorType>(getType());
2634 if (!opResultTy)
2635 return emitOpError() << "the type of the result must be a vector "
2636 << "if it is vector shift";
2637
2638 auto op0VecEleTy = mlir::cast<cir::IntType>(op0VecTy.getElementType());
2639 auto op1VecEleTy = mlir::cast<cir::IntType>(op1VecTy.getElementType());
2640 if (op0VecEleTy.getWidth() != op1VecEleTy.getWidth())
2641 return emitOpError()
2642 << "vector operands do not have the same elements sizes";
2643
2644 auto resVecEleTy = mlir::cast<cir::IntType>(opResultTy.getElementType());
2645 if (op0VecEleTy.getWidth() != resVecEleTy.getWidth())
2646 return emitOpError() << "vector operands and result type do not have the "
2647 "same elements sizes";
2648 }
2649
2650 return mlir::success();
2651}
2652
2653//===----------------------------------------------------------------------===//
2654// LabelOp Definitions
2655//===----------------------------------------------------------------------===//
2656
2657LogicalResult cir::LabelOp::verify() {
2658 mlir::Operation *op = getOperation();
2659 mlir::Block *blk = op->getBlock();
2660 if (&blk->front() != op)
2661 return emitError() << "must be the first operation in a block";
2662
2663 return mlir::success();
2664}
2665
2666//===----------------------------------------------------------------------===//
2667// UnaryOp
2668//===----------------------------------------------------------------------===//
2669
2670LogicalResult cir::UnaryOp::verify() {
2671 switch (getKind()) {
2672 case cir::UnaryOpKind::Inc:
2673 case cir::UnaryOpKind::Dec:
2674 case cir::UnaryOpKind::Plus:
2675 case cir::UnaryOpKind::Minus:
2676 case cir::UnaryOpKind::Not:
2677 // Nothing to verify.
2678 return success();
2679 }
2680
2681 llvm_unreachable("Unknown UnaryOp kind?");
2682}
2683
2684static bool isBoolNot(cir::UnaryOp op) {
2685 return isa<cir::BoolType>(op.getInput().getType()) &&
2686 op.getKind() == cir::UnaryOpKind::Not;
2687}
2688
2689// This folder simplifies the sequential boolean not operations.
2690// For instance, the next two unary operations will be eliminated:
2691//
2692// ```mlir
2693// %1 = cir.unary(not, %0) : !cir.bool, !cir.bool
2694// %2 = cir.unary(not, %1) : !cir.bool, !cir.bool
2695// ```
2696//
2697// and the argument of the first one (%0) will be used instead.
2698OpFoldResult cir::UnaryOp::fold(FoldAdaptor adaptor) {
2699 if (auto poison =
2700 mlir::dyn_cast_if_present<cir::PoisonAttr>(adaptor.getInput())) {
2701 // Propagate poison values
2702 return poison;
2703 }
2704
2705 if (isBoolNot(*this))
2706 if (auto previous = getInput().getDefiningOp<cir::UnaryOp>())
2707 if (isBoolNot(previous))
2708 return previous.getInput();
2709
2710 // Avoid introducing unnecessary duplicate constants in cases where we are
2711 // just folding the operation to its input value. If we return the
2712 // input attribute from the adapter, a new constant is materialized, but
2713 // if we return the input value directly, it avoids that.
2714 if (auto srcConst = getInput().getDefiningOp<cir::ConstantOp>()) {
2715 if (getKind() == cir::UnaryOpKind::Plus ||
2716 (mlir::isa<cir::BoolType>(srcConst.getType()) &&
2717 getKind() == cir::UnaryOpKind::Minus))
2718 return srcConst.getResult();
2719 }
2720
2721 // Fold unary operations with constant inputs. If the input is a ConstantOp,
2722 // it "folds" to its value attribute. If it was some other operation that
2723 // was folded, it will be an mlir::Attribute that hasn't yet been
2724 // materialized. If it was a value that couldn't be folded, it will be null.
2725 if (mlir::Attribute attr = adaptor.getInput()) {
2726 // For now, we only attempt to fold simple scalar values.
2727 OpFoldResult result =
2728 llvm::TypeSwitch<mlir::Attribute, OpFoldResult>(attr)
2729 .Case<cir::IntAttr>([&](cir::IntAttr attrT) {
2730 switch (getKind()) {
2731 case cir::UnaryOpKind::Not: {
2732 APInt val = attrT.getValue();
2733 val.flipAllBits();
2734 return cir::IntAttr::get(getType(), val);
2735 }
2736 case cir::UnaryOpKind::Plus:
2737 return attrT;
2738 case cir::UnaryOpKind::Minus: {
2739 APInt val = attrT.getValue();
2740 val.negate();
2741 return cir::IntAttr::get(getType(), val);
2742 }
2743 default:
2744 return cir::IntAttr{};
2745 }
2746 })
2747 .Case<cir::FPAttr>([&](cir::FPAttr attrT) {
2748 switch (getKind()) {
2749 case cir::UnaryOpKind::Plus:
2750 return attrT;
2751 case cir::UnaryOpKind::Minus: {
2752 APFloat val = attrT.getValue();
2753 val.changeSign();
2754 return cir::FPAttr::get(getType(), val);
2755 }
2756 default:
2757 return cir::FPAttr{};
2758 }
2759 })
2760 .Case<cir::BoolAttr>([&](cir::BoolAttr attrT) {
2761 switch (getKind()) {
2762 case cir::UnaryOpKind::Not:
2763 return cir::BoolAttr::get(getContext(), !attrT.getValue());
2764 case cir::UnaryOpKind::Plus:
2765 case cir::UnaryOpKind::Minus:
2766 return attrT;
2767 default:
2768 return cir::BoolAttr{};
2769 }
2770 })
2771 .Default([&](auto attrT) { return mlir::Attribute{}; });
2772 if (result)
2773 return result;
2774 }
2775
2776 return {};
2777}
2778
2779//===----------------------------------------------------------------------===//
2780// BaseDataMemberOp & DerivedDataMemberOp
2781//===----------------------------------------------------------------------===//
2782
2783static LogicalResult verifyMemberPtrCast(Operation *op, mlir::Value src,
2784 mlir::Type resultTy) {
2785 // Let the operand type be T1 C1::*, let the result type be T2 C2::*.
2786 // Verify that T1 and T2 are the same type.
2787 mlir::Type inputMemberTy;
2788 mlir::Type resultMemberTy;
2789 if (mlir::isa<cir::DataMemberType>(src.getType())) {
2790 inputMemberTy =
2791 mlir::cast<cir::DataMemberType>(src.getType()).getMemberTy();
2792 resultMemberTy = mlir::cast<cir::DataMemberType>(resultTy).getMemberTy();
2793 }
2795 if (inputMemberTy != resultMemberTy)
2796 return op->emitOpError()
2797 << "member types of the operand and the result do not match";
2798
2799 return mlir::success();
2800}
2801
2802LogicalResult cir::BaseDataMemberOp::verify() {
2803 return verifyMemberPtrCast(getOperation(), getSrc(), getType());
2804}
2805
2806LogicalResult cir::DerivedDataMemberOp::verify() {
2807 return verifyMemberPtrCast(getOperation(), getSrc(), getType());
2808}
2809
2810//===----------------------------------------------------------------------===//
2811// BaseMethodOp & DerivedMethodOp
2812//===----------------------------------------------------------------------===//
2813
2814LogicalResult cir::BaseMethodOp::verify() {
2815 return verifyMemberPtrCast(getOperation(), getSrc(), getType());
2816}
2817
2818LogicalResult cir::DerivedMethodOp::verify() {
2819 return verifyMemberPtrCast(getOperation(), getSrc(), getType());
2820}
2821
2822//===----------------------------------------------------------------------===//
2823// AwaitOp
2824//===----------------------------------------------------------------------===//
2825
2826void cir::AwaitOp::build(OpBuilder &builder, OperationState &result,
2827 cir::AwaitKind kind, BuilderCallbackRef readyBuilder,
2828 BuilderCallbackRef suspendBuilder,
2829 BuilderCallbackRef resumeBuilder) {
2830 result.addAttribute(getKindAttrName(result.name),
2831 cir::AwaitKindAttr::get(builder.getContext(), kind));
2832 {
2833 OpBuilder::InsertionGuard guard(builder);
2834 Region *readyRegion = result.addRegion();
2835 builder.createBlock(readyRegion);
2836 readyBuilder(builder, result.location);
2837 }
2838
2839 {
2840 OpBuilder::InsertionGuard guard(builder);
2841 Region *suspendRegion = result.addRegion();
2842 builder.createBlock(suspendRegion);
2843 suspendBuilder(builder, result.location);
2844 }
2845
2846 {
2847 OpBuilder::InsertionGuard guard(builder);
2848 Region *resumeRegion = result.addRegion();
2849 builder.createBlock(resumeRegion);
2850 resumeBuilder(builder, result.location);
2851 }
2852}
2853
2854void cir::AwaitOp::getSuccessorRegions(
2855 mlir::RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
2856 // If any index all the underlying regions branch back to the parent
2857 // operation.
2858 if (!point.isParent()) {
2859 regions.push_back(RegionSuccessor::parent());
2860 return;
2861 }
2862
2863 // TODO: retrieve information from the promise and only push the
2864 // necessary ones. Example: `std::suspend_never` on initial or final
2865 // await's might allow suspend region to be skipped.
2866 regions.push_back(RegionSuccessor(&this->getReady()));
2867 regions.push_back(RegionSuccessor(&this->getSuspend()));
2868 regions.push_back(RegionSuccessor(&this->getResume()));
2869}
2870
2871mlir::ValueRange cir::AwaitOp::getSuccessorInputs(RegionSuccessor successor) {
2872 if (successor.isParent())
2873 return getOperation()->getResults();
2874 if (successor == &getReady())
2875 return getReady().getArguments();
2876 if (successor == &getSuspend())
2877 return getSuspend().getArguments();
2878 if (successor == &getResume())
2879 return getResume().getArguments();
2880 llvm_unreachable("invalid region successor");
2881}
2882
2883LogicalResult cir::AwaitOp::verify() {
2884 if (!isa<ConditionOp>(this->getReady().back().getTerminator()))
2885 return emitOpError("ready region must end with cir.condition");
2886 return success();
2887}
2888
2889//===----------------------------------------------------------------------===//
2890// CopyOp Definitions
2891//===----------------------------------------------------------------------===//
2892
2893LogicalResult cir::CopyOp::verify() {
2894 // A data layout is required for us to know the number of bytes to be copied.
2895 if (!getType().getPointee().hasTrait<DataLayoutTypeInterface::Trait>())
2896 return emitError() << "missing data layout for pointee type";
2897
2898 if (getSrc() == getDst())
2899 return emitError() << "source and destination are the same";
2900
2901 return mlir::success();
2902}
2903
2904//===----------------------------------------------------------------------===//
2905// GetRuntimeMemberOp Definitions
2906//===----------------------------------------------------------------------===//
2907
2908LogicalResult cir::GetRuntimeMemberOp::verify() {
2909 auto recordTy = mlir::cast<RecordType>(getAddr().getType().getPointee());
2910 cir::DataMemberType memberPtrTy = getMember().getType();
2911
2912 if (recordTy != memberPtrTy.getClassTy())
2913 return emitError() << "record type does not match the member pointer type";
2914 if (getType().getPointee() != memberPtrTy.getMemberTy())
2915 return emitError() << "result type does not match the member pointer type";
2916 return mlir::success();
2917}
2918
2919//===----------------------------------------------------------------------===//
2920// GetMethodOp Definitions
2921//===----------------------------------------------------------------------===//
2922
2923LogicalResult cir::GetMethodOp::verify() {
2924 cir::MethodType methodTy = getMethod().getType();
2925
2926 // Assume objectTy is !cir.ptr<!T>
2927 cir::PointerType objectPtrTy = getObject().getType();
2928 mlir::Type objectTy = objectPtrTy.getPointee();
2929
2930 if (methodTy.getClassTy() != objectTy)
2931 return emitError() << "method class type and object type do not match";
2932
2933 // Assume methodFuncTy is !cir.func<!Ret (!Args)>
2934 auto calleeTy = mlir::cast<cir::FuncType>(getCallee().getType().getPointee());
2935 cir::FuncType methodFuncTy = methodTy.getMemberFuncTy();
2936
2937 // We verify at here that calleeTy is !cir.func<!Ret (!cir.ptr<!void>, !Args)>
2938 // Note that the first parameter type of the callee is !cir.ptr<!void> instead
2939 // of !cir.ptr<!T> because the "this" pointer may be adjusted before calling
2940 // the callee.
2941
2942 if (methodFuncTy.getReturnType() != calleeTy.getReturnType())
2943 return emitError()
2944 << "method return type and callee return type do not match";
2945
2946 llvm::ArrayRef<mlir::Type> calleeArgsTy = calleeTy.getInputs();
2947 llvm::ArrayRef<mlir::Type> methodFuncArgsTy = methodFuncTy.getInputs();
2948
2949 if (calleeArgsTy.empty())
2950 return emitError() << "callee parameter list lacks receiver object ptr";
2951
2952 auto calleeThisArgPtrTy = mlir::dyn_cast<cir::PointerType>(calleeArgsTy[0]);
2953 if (!calleeThisArgPtrTy ||
2954 !mlir::isa<cir::VoidType>(calleeThisArgPtrTy.getPointee())) {
2955 return emitError()
2956 << "the first parameter of callee must be a void pointer";
2957 }
2958
2959 if (calleeArgsTy.slice(1) != methodFuncArgsTy)
2960 return emitError()
2961 << "callee parameters and method parameters do not match";
2962
2963 return mlir::success();
2964}
2965
2966//===----------------------------------------------------------------------===//
2967// GetMemberOp Definitions
2968//===----------------------------------------------------------------------===//
2969
2970LogicalResult cir::GetMemberOp::verify() {
2971 const auto recordTy = dyn_cast<RecordType>(getAddrTy().getPointee());
2972 if (!recordTy)
2973 return emitError() << "expected pointer to a record type";
2974
2975 if (recordTy.getMembers().size() <= getIndex())
2976 return emitError() << "member index out of bounds";
2977
2978 if (recordTy.getMembers()[getIndex()] != getType().getPointee())
2979 return emitError() << "member type mismatch";
2980
2981 return mlir::success();
2982}
2983
2984//===----------------------------------------------------------------------===//
2985// ExtractMemberOp Definitions
2986//===----------------------------------------------------------------------===//
2987
2988LogicalResult cir::ExtractMemberOp::verify() {
2989 auto recordTy = mlir::cast<cir::RecordType>(getRecord().getType());
2990 if (recordTy.getKind() == cir::RecordType::Union)
2991 return emitError()
2992 << "cir.extract_member currently does not support unions";
2993 if (recordTy.getMembers().size() <= getIndex())
2994 return emitError() << "member index out of bounds";
2995 if (recordTy.getMembers()[getIndex()] != getType())
2996 return emitError() << "member type mismatch";
2997 return mlir::success();
2998}
2999
3000//===----------------------------------------------------------------------===//
3001// InsertMemberOp Definitions
3002//===----------------------------------------------------------------------===//
3003
3004LogicalResult cir::InsertMemberOp::verify() {
3005 auto recordTy = mlir::cast<cir::RecordType>(getRecord().getType());
3006 if (recordTy.getKind() == cir::RecordType::Union)
3007 return emitError() << "cir.insert_member currently does not support unions";
3008 if (recordTy.getMembers().size() <= getIndex())
3009 return emitError() << "member index out of bounds";
3010 if (recordTy.getMembers()[getIndex()] != getValue().getType())
3011 return emitError() << "member type mismatch";
3012 // The op trait already checks that the types of $result and $record match.
3013 return mlir::success();
3014}
3015
3016//===----------------------------------------------------------------------===//
3017// VecCreateOp
3018//===----------------------------------------------------------------------===//
3019
3020OpFoldResult cir::VecCreateOp::fold(FoldAdaptor adaptor) {
3021 if (llvm::any_of(getElements(), [](mlir::Value value) {
3022 return !value.getDefiningOp<cir::ConstantOp>();
3023 }))
3024 return {};
3025
3026 return cir::ConstVectorAttr::get(
3027 getType(), mlir::ArrayAttr::get(getContext(), adaptor.getElements()));
3028}
3029
3030LogicalResult cir::VecCreateOp::verify() {
3031 // Verify that the number of arguments matches the number of elements in the
3032 // vector, and that the type of all the arguments matches the type of the
3033 // elements in the vector.
3034 const cir::VectorType vecTy = getType();
3035 if (getElements().size() != vecTy.getSize()) {
3036 return emitOpError() << "operand count of " << getElements().size()
3037 << " doesn't match vector type " << vecTy
3038 << " element count of " << vecTy.getSize();
3039 }
3040
3041 const mlir::Type elementType = vecTy.getElementType();
3042 for (const mlir::Value element : getElements()) {
3043 if (element.getType() != elementType) {
3044 return emitOpError() << "operand type " << element.getType()
3045 << " doesn't match vector element type "
3046 << elementType;
3047 }
3048 }
3049
3050 return success();
3051}
3052
3053//===----------------------------------------------------------------------===//
3054// VecExtractOp
3055//===----------------------------------------------------------------------===//
3056
3057OpFoldResult cir::VecExtractOp::fold(FoldAdaptor adaptor) {
3058 const auto vectorAttr =
3059 llvm::dyn_cast_if_present<cir::ConstVectorAttr>(adaptor.getVec());
3060 if (!vectorAttr)
3061 return {};
3062
3063 const auto indexAttr =
3064 llvm::dyn_cast_if_present<cir::IntAttr>(adaptor.getIndex());
3065 if (!indexAttr)
3066 return {};
3067
3068 const mlir::ArrayAttr elements = vectorAttr.getElts();
3069 const uint64_t index = indexAttr.getUInt();
3070 if (index >= elements.size())
3071 return {};
3072
3073 return elements[index];
3074}
3075
3076//===----------------------------------------------------------------------===//
3077// VecCmpOp
3078//===----------------------------------------------------------------------===//
3079
3080OpFoldResult cir::VecCmpOp::fold(FoldAdaptor adaptor) {
3081 auto lhsVecAttr =
3082 mlir::dyn_cast_if_present<cir::ConstVectorAttr>(adaptor.getLhs());
3083 auto rhsVecAttr =
3084 mlir::dyn_cast_if_present<cir::ConstVectorAttr>(adaptor.getRhs());
3085 if (!lhsVecAttr || !rhsVecAttr)
3086 return {};
3087
3088 mlir::Type inputElemTy =
3089 mlir::cast<cir::VectorType>(lhsVecAttr.getType()).getElementType();
3090 if (!isAnyIntegerOrFloatingPointType(inputElemTy))
3091 return {};
3092
3093 cir::CmpOpKind opKind = adaptor.getKind();
3094 mlir::ArrayAttr lhsVecElhs = lhsVecAttr.getElts();
3095 mlir::ArrayAttr rhsVecElhs = rhsVecAttr.getElts();
3096 uint64_t vecSize = lhsVecElhs.size();
3097
3098 SmallVector<mlir::Attribute, 16> elements(vecSize);
3099 bool isIntAttr = vecSize && mlir::isa<cir::IntAttr>(lhsVecElhs[0]);
3100 for (uint64_t i = 0; i < vecSize; i++) {
3101 mlir::Attribute lhsAttr = lhsVecElhs[i];
3102 mlir::Attribute rhsAttr = rhsVecElhs[i];
3103 int cmpResult = 0;
3104 switch (opKind) {
3105 case cir::CmpOpKind::lt: {
3106 if (isIntAttr) {
3107 cmpResult = mlir::cast<cir::IntAttr>(lhsAttr).getSInt() <
3108 mlir::cast<cir::IntAttr>(rhsAttr).getSInt();
3109 } else {
3110 cmpResult = mlir::cast<cir::FPAttr>(lhsAttr).getValue() <
3111 mlir::cast<cir::FPAttr>(rhsAttr).getValue();
3112 }
3113 break;
3114 }
3115 case cir::CmpOpKind::le: {
3116 if (isIntAttr) {
3117 cmpResult = mlir::cast<cir::IntAttr>(lhsAttr).getSInt() <=
3118 mlir::cast<cir::IntAttr>(rhsAttr).getSInt();
3119 } else {
3120 cmpResult = mlir::cast<cir::FPAttr>(lhsAttr).getValue() <=
3121 mlir::cast<cir::FPAttr>(rhsAttr).getValue();
3122 }
3123 break;
3124 }
3125 case cir::CmpOpKind::gt: {
3126 if (isIntAttr) {
3127 cmpResult = mlir::cast<cir::IntAttr>(lhsAttr).getSInt() >
3128 mlir::cast<cir::IntAttr>(rhsAttr).getSInt();
3129 } else {
3130 cmpResult = mlir::cast<cir::FPAttr>(lhsAttr).getValue() >
3131 mlir::cast<cir::FPAttr>(rhsAttr).getValue();
3132 }
3133 break;
3134 }
3135 case cir::CmpOpKind::ge: {
3136 if (isIntAttr) {
3137 cmpResult = mlir::cast<cir::IntAttr>(lhsAttr).getSInt() >=
3138 mlir::cast<cir::IntAttr>(rhsAttr).getSInt();
3139 } else {
3140 cmpResult = mlir::cast<cir::FPAttr>(lhsAttr).getValue() >=
3141 mlir::cast<cir::FPAttr>(rhsAttr).getValue();
3142 }
3143 break;
3144 }
3145 case cir::CmpOpKind::eq: {
3146 if (isIntAttr) {
3147 cmpResult = mlir::cast<cir::IntAttr>(lhsAttr).getSInt() ==
3148 mlir::cast<cir::IntAttr>(rhsAttr).getSInt();
3149 } else {
3150 cmpResult = mlir::cast<cir::FPAttr>(lhsAttr).getValue() ==
3151 mlir::cast<cir::FPAttr>(rhsAttr).getValue();
3152 }
3153 break;
3154 }
3155 case cir::CmpOpKind::ne: {
3156 if (isIntAttr) {
3157 cmpResult = mlir::cast<cir::IntAttr>(lhsAttr).getSInt() !=
3158 mlir::cast<cir::IntAttr>(rhsAttr).getSInt();
3159 } else {
3160 cmpResult = mlir::cast<cir::FPAttr>(lhsAttr).getValue() !=
3161 mlir::cast<cir::FPAttr>(rhsAttr).getValue();
3162 }
3163 break;
3164 }
3165 }
3166
3167 elements[i] = cir::IntAttr::get(getType().getElementType(), cmpResult);
3168 }
3169
3170 return cir::ConstVectorAttr::get(
3171 getType(), mlir::ArrayAttr::get(getContext(), elements));
3172}
3173
3174//===----------------------------------------------------------------------===//
3175// VecShuffleOp
3176//===----------------------------------------------------------------------===//
3177
3178OpFoldResult cir::VecShuffleOp::fold(FoldAdaptor adaptor) {
3179 auto vec1Attr =
3180 mlir::dyn_cast_if_present<cir::ConstVectorAttr>(adaptor.getVec1());
3181 auto vec2Attr =
3182 mlir::dyn_cast_if_present<cir::ConstVectorAttr>(adaptor.getVec2());
3183 if (!vec1Attr || !vec2Attr)
3184 return {};
3185
3186 mlir::Type vec1ElemTy =
3187 mlir::cast<cir::VectorType>(vec1Attr.getType()).getElementType();
3188
3189 mlir::ArrayAttr vec1Elts = vec1Attr.getElts();
3190 mlir::ArrayAttr vec2Elts = vec2Attr.getElts();
3191 mlir::ArrayAttr indicesElts = adaptor.getIndices();
3192
3194 elements.reserve(indicesElts.size());
3195
3196 uint64_t vec1Size = vec1Elts.size();
3197 for (const auto &idxAttr : indicesElts.getAsRange<cir::IntAttr>()) {
3198 if (idxAttr.getSInt() == -1) {
3199 elements.push_back(cir::UndefAttr::get(vec1ElemTy));
3200 continue;
3201 }
3202
3203 uint64_t idxValue = idxAttr.getUInt();
3204 elements.push_back(idxValue < vec1Size ? vec1Elts[idxValue]
3205 : vec2Elts[idxValue - vec1Size]);
3206 }
3207
3208 return cir::ConstVectorAttr::get(
3209 getType(), mlir::ArrayAttr::get(getContext(), elements));
3210}
3211
3212LogicalResult cir::VecShuffleOp::verify() {
3213 // The number of elements in the indices array must match the number of
3214 // elements in the result type.
3215 if (getIndices().size() != getResult().getType().getSize()) {
3216 return emitOpError() << ": the number of elements in " << getIndices()
3217 << " and " << getResult().getType() << " don't match";
3218 }
3219
3220 // The element types of the two input vectors and of the result type must
3221 // match.
3222 if (getVec1().getType().getElementType() !=
3223 getResult().getType().getElementType()) {
3224 return emitOpError() << ": element types of " << getVec1().getType()
3225 << " and " << getResult().getType() << " don't match";
3226 }
3227
3228 const uint64_t maxValidIndex =
3229 getVec1().getType().getSize() + getVec2().getType().getSize() - 1;
3230 if (llvm::any_of(
3231 getIndices().getAsRange<cir::IntAttr>(), [&](cir::IntAttr idxAttr) {
3232 return idxAttr.getSInt() != -1 && idxAttr.getUInt() > maxValidIndex;
3233 })) {
3234 return emitOpError() << ": index for __builtin_shufflevector must be "
3235 "less than the total number of vector elements";
3236 }
3237 return success();
3238}
3239
3240//===----------------------------------------------------------------------===//
3241// VecShuffleDynamicOp
3242//===----------------------------------------------------------------------===//
3243
3244OpFoldResult cir::VecShuffleDynamicOp::fold(FoldAdaptor adaptor) {
3245 mlir::Attribute vec = adaptor.getVec();
3246 mlir::Attribute indices = adaptor.getIndices();
3247 if (mlir::isa_and_nonnull<cir::ConstVectorAttr>(vec) &&
3248 mlir::isa_and_nonnull<cir::ConstVectorAttr>(indices)) {
3249 auto vecAttr = mlir::cast<cir::ConstVectorAttr>(vec);
3250 auto indicesAttr = mlir::cast<cir::ConstVectorAttr>(indices);
3251
3252 mlir::ArrayAttr vecElts = vecAttr.getElts();
3253 mlir::ArrayAttr indicesElts = indicesAttr.getElts();
3254
3255 const uint64_t numElements = vecElts.size();
3256
3258 elements.reserve(numElements);
3259
3260 const uint64_t maskBits = llvm::NextPowerOf2(numElements - 1) - 1;
3261 for (const auto &idxAttr : indicesElts.getAsRange<cir::IntAttr>()) {
3262 uint64_t idxValue = idxAttr.getUInt();
3263 uint64_t newIdx = idxValue & maskBits;
3264 elements.push_back(vecElts[newIdx]);
3265 }
3266
3267 return cir::ConstVectorAttr::get(
3268 getType(), mlir::ArrayAttr::get(getContext(), elements));
3269 }
3270
3271 return {};
3272}
3273
3274LogicalResult cir::VecShuffleDynamicOp::verify() {
3275 // The number of elements in the two input vectors must match.
3276 if (getVec().getType().getSize() !=
3277 mlir::cast<cir::VectorType>(getIndices().getType()).getSize()) {
3278 return emitOpError() << ": the number of elements in " << getVec().getType()
3279 << " and " << getIndices().getType() << " don't match";
3280 }
3281 return success();
3282}
3283
3284//===----------------------------------------------------------------------===//
3285// VecTernaryOp
3286//===----------------------------------------------------------------------===//
3287
3288LogicalResult cir::VecTernaryOp::verify() {
3289 // Verify that the condition operand has the same number of elements as the
3290 // other operands. (The automatic verification already checked that all
3291 // operands are vector types and that the second and third operands are the
3292 // same type.)
3293 if (getCond().getType().getSize() != getLhs().getType().getSize()) {
3294 return emitOpError() << ": the number of elements in "
3295 << getCond().getType() << " and " << getLhs().getType()
3296 << " don't match";
3297 }
3298 return success();
3299}
3300
3301OpFoldResult cir::VecTernaryOp::fold(FoldAdaptor adaptor) {
3302 mlir::Attribute cond = adaptor.getCond();
3303 mlir::Attribute lhs = adaptor.getLhs();
3304 mlir::Attribute rhs = adaptor.getRhs();
3305
3306 if (!mlir::isa_and_nonnull<cir::ConstVectorAttr>(cond) ||
3307 !mlir::isa_and_nonnull<cir::ConstVectorAttr>(lhs) ||
3308 !mlir::isa_and_nonnull<cir::ConstVectorAttr>(rhs))
3309 return {};
3310 auto condVec = mlir::cast<cir::ConstVectorAttr>(cond);
3311 auto lhsVec = mlir::cast<cir::ConstVectorAttr>(lhs);
3312 auto rhsVec = mlir::cast<cir::ConstVectorAttr>(rhs);
3313
3314 mlir::ArrayAttr condElts = condVec.getElts();
3315
3317 elements.reserve(condElts.size());
3318
3319 for (const auto &[idx, condAttr] :
3320 llvm::enumerate(condElts.getAsRange<cir::IntAttr>())) {
3321 if (condAttr.getSInt()) {
3322 elements.push_back(lhsVec.getElts()[idx]);
3323 } else {
3324 elements.push_back(rhsVec.getElts()[idx]);
3325 }
3326 }
3327
3328 cir::VectorType vecTy = getLhs().getType();
3329 return cir::ConstVectorAttr::get(
3330 vecTy, mlir::ArrayAttr::get(getContext(), elements));
3331}
3332
3333//===----------------------------------------------------------------------===//
3334// ComplexCreateOp
3335//===----------------------------------------------------------------------===//
3336
3337LogicalResult cir::ComplexCreateOp::verify() {
3338 if (getType().getElementType() != getReal().getType()) {
3339 emitOpError()
3340 << "operand type of cir.complex.create does not match its result type";
3341 return failure();
3342 }
3343
3344 return success();
3345}
3346
3347OpFoldResult cir::ComplexCreateOp::fold(FoldAdaptor adaptor) {
3348 mlir::Attribute real = adaptor.getReal();
3349 mlir::Attribute imag = adaptor.getImag();
3350 if (!real || !imag)
3351 return {};
3352
3353 // When both of real and imag are constants, we can fold the operation into an
3354 // `#cir.const_complex` operation.
3355 auto realAttr = mlir::cast<mlir::TypedAttr>(real);
3356 auto imagAttr = mlir::cast<mlir::TypedAttr>(imag);
3357 return cir::ConstComplexAttr::get(realAttr, imagAttr);
3358}
3359
3360//===----------------------------------------------------------------------===//
3361// ComplexRealOp
3362//===----------------------------------------------------------------------===//
3363
3364LogicalResult cir::ComplexRealOp::verify() {
3365 mlir::Type operandTy = getOperand().getType();
3366 if (auto complexOperandTy = mlir::dyn_cast<cir::ComplexType>(operandTy))
3367 operandTy = complexOperandTy.getElementType();
3368
3369 if (getType() != operandTy) {
3370 emitOpError() << ": result type does not match operand type";
3371 return failure();
3372 }
3373
3374 return success();
3375}
3376
3377OpFoldResult cir::ComplexRealOp::fold(FoldAdaptor adaptor) {
3378 if (!mlir::isa<cir::ComplexType>(getOperand().getType()))
3379 return nullptr;
3380
3381 if (auto complexCreateOp = getOperand().getDefiningOp<cir::ComplexCreateOp>())
3382 return complexCreateOp.getOperand(0);
3383
3384 auto complex =
3385 mlir::cast_if_present<cir::ConstComplexAttr>(adaptor.getOperand());
3386 return complex ? complex.getReal() : nullptr;
3387}
3388
3389//===----------------------------------------------------------------------===//
3390// ComplexImagOp
3391//===----------------------------------------------------------------------===//
3392
3393LogicalResult cir::ComplexImagOp::verify() {
3394 mlir::Type operandTy = getOperand().getType();
3395 if (auto complexOperandTy = mlir::dyn_cast<cir::ComplexType>(operandTy))
3396 operandTy = complexOperandTy.getElementType();
3397
3398 if (getType() != operandTy) {
3399 emitOpError() << ": result type does not match operand type";
3400 return failure();
3401 }
3402
3403 return success();
3404}
3405
3406OpFoldResult cir::ComplexImagOp::fold(FoldAdaptor adaptor) {
3407 if (!mlir::isa<cir::ComplexType>(getOperand().getType()))
3408 return nullptr;
3409
3410 if (auto complexCreateOp = getOperand().getDefiningOp<cir::ComplexCreateOp>())
3411 return complexCreateOp.getOperand(1);
3412
3413 auto complex =
3414 mlir::cast_if_present<cir::ConstComplexAttr>(adaptor.getOperand());
3415 return complex ? complex.getImag() : nullptr;
3416}
3417
3418//===----------------------------------------------------------------------===//
3419// ComplexRealPtrOp
3420//===----------------------------------------------------------------------===//
3421
3422LogicalResult cir::ComplexRealPtrOp::verify() {
3423 mlir::Type resultPointeeTy = getType().getPointee();
3424 cir::PointerType operandPtrTy = getOperand().getType();
3425 auto operandPointeeTy =
3426 mlir::cast<cir::ComplexType>(operandPtrTy.getPointee());
3427
3428 if (resultPointeeTy != operandPointeeTy.getElementType()) {
3429 return emitOpError() << ": result type does not match operand type";
3430 }
3431
3432 return success();
3433}
3434
3435//===----------------------------------------------------------------------===//
3436// ComplexImagPtrOp
3437//===----------------------------------------------------------------------===//
3438
3439LogicalResult cir::ComplexImagPtrOp::verify() {
3440 mlir::Type resultPointeeTy = getType().getPointee();
3441 cir::PointerType operandPtrTy = getOperand().getType();
3442 auto operandPointeeTy =
3443 mlir::cast<cir::ComplexType>(operandPtrTy.getPointee());
3444
3445 if (resultPointeeTy != operandPointeeTy.getElementType()) {
3446 return emitOpError()
3447 << "cir.complex.imag_ptr result type does not match operand type";
3448 }
3449 return success();
3450}
3451
3452//===----------------------------------------------------------------------===//
3453// Bit manipulation operations
3454//===----------------------------------------------------------------------===//
3455
3456static OpFoldResult
3457foldUnaryBitOp(mlir::Attribute inputAttr,
3458 llvm::function_ref<llvm::APInt(const llvm::APInt &)> func,
3459 bool poisonZero = false) {
3460 if (mlir::isa_and_present<cir::PoisonAttr>(inputAttr)) {
3461 // Propagate poison value
3462 return inputAttr;
3463 }
3464
3465 auto input = mlir::dyn_cast_if_present<IntAttr>(inputAttr);
3466 if (!input)
3467 return nullptr;
3468
3469 llvm::APInt inputValue = input.getValue();
3470 if (poisonZero && inputValue.isZero())
3471 return cir::PoisonAttr::get(input.getType());
3472
3473 llvm::APInt resultValue = func(inputValue);
3474 return IntAttr::get(input.getType(), resultValue);
3475}
3476
3477OpFoldResult BitClrsbOp::fold(FoldAdaptor adaptor) {
3478 return foldUnaryBitOp(adaptor.getInput(), [](const llvm::APInt &inputValue) {
3479 unsigned resultValue =
3480 inputValue.getBitWidth() - inputValue.getSignificantBits();
3481 return llvm::APInt(inputValue.getBitWidth(), resultValue);
3482 });
3483}
3484
3485OpFoldResult BitClzOp::fold(FoldAdaptor adaptor) {
3486 return foldUnaryBitOp(
3487 adaptor.getInput(),
3488 [](const llvm::APInt &inputValue) {
3489 unsigned resultValue = inputValue.countLeadingZeros();
3490 return llvm::APInt(inputValue.getBitWidth(), resultValue);
3491 },
3492 getPoisonZero());
3493}
3494
3495OpFoldResult BitCtzOp::fold(FoldAdaptor adaptor) {
3496 return foldUnaryBitOp(
3497 adaptor.getInput(),
3498 [](const llvm::APInt &inputValue) {
3499 return llvm::APInt(inputValue.getBitWidth(),
3500 inputValue.countTrailingZeros());
3501 },
3502 getPoisonZero());
3503}
3504
3505OpFoldResult BitFfsOp::fold(FoldAdaptor adaptor) {
3506 return foldUnaryBitOp(adaptor.getInput(), [](const llvm::APInt &inputValue) {
3507 unsigned trailingZeros = inputValue.countTrailingZeros();
3508 unsigned result =
3509 trailingZeros == inputValue.getBitWidth() ? 0 : trailingZeros + 1;
3510 return llvm::APInt(inputValue.getBitWidth(), result);
3511 });
3512}
3513
3514OpFoldResult BitParityOp::fold(FoldAdaptor adaptor) {
3515 return foldUnaryBitOp(adaptor.getInput(), [](const llvm::APInt &inputValue) {
3516 return llvm::APInt(inputValue.getBitWidth(), inputValue.popcount() % 2);
3517 });
3518}
3519
3520OpFoldResult BitPopcountOp::fold(FoldAdaptor adaptor) {
3521 return foldUnaryBitOp(adaptor.getInput(), [](const llvm::APInt &inputValue) {
3522 return llvm::APInt(inputValue.getBitWidth(), inputValue.popcount());
3523 });
3524}
3525
3526OpFoldResult BitReverseOp::fold(FoldAdaptor adaptor) {
3527 return foldUnaryBitOp(adaptor.getInput(), [](const llvm::APInt &inputValue) {
3528 return inputValue.reverseBits();
3529 });
3530}
3531
3532OpFoldResult ByteSwapOp::fold(FoldAdaptor adaptor) {
3533 return foldUnaryBitOp(adaptor.getInput(), [](const llvm::APInt &inputValue) {
3534 return inputValue.byteSwap();
3535 });
3536}
3537
3538OpFoldResult RotateOp::fold(FoldAdaptor adaptor) {
3539 if (mlir::isa_and_present<cir::PoisonAttr>(adaptor.getInput()) ||
3540 mlir::isa_and_present<cir::PoisonAttr>(adaptor.getAmount())) {
3541 // Propagate poison values
3542 return cir::PoisonAttr::get(getType());
3543 }
3544
3545 auto input = mlir::dyn_cast_if_present<IntAttr>(adaptor.getInput());
3546 auto amount = mlir::dyn_cast_if_present<IntAttr>(adaptor.getAmount());
3547 if (!input && !amount)
3548 return nullptr;
3549
3550 // We could fold cir.rotate even if one of its two operands is not a constant:
3551 // - `cir.rotate left/right %0, 0` could be folded into just %0 even if %0
3552 // is not a constant.
3553 // - `cir.rotate left/right 0/0b111...111, %0` could be folded into 0 or
3554 // 0b111...111 even if %0 is not a constant.
3555
3556 llvm::APInt inputValue;
3557 if (input) {
3558 inputValue = input.getValue();
3559 if (inputValue.isZero() || inputValue.isAllOnes()) {
3560 // An input value of all 0s or all 1s will not change after rotation
3561 return input;
3562 }
3563 }
3564
3565 uint64_t amountValue;
3566 if (amount) {
3567 amountValue = amount.getValue().urem(getInput().getType().getWidth());
3568 if (amountValue == 0) {
3569 // A shift amount of 0 will not change the input value
3570 return getInput();
3571 }
3572 }
3573
3574 if (!input || !amount)
3575 return nullptr;
3576
3577 assert(inputValue.getBitWidth() == getInput().getType().getWidth() &&
3578 "input value must have the same bit width as the input type");
3579
3580 llvm::APInt resultValue;
3581 if (isRotateLeft())
3582 resultValue = inputValue.rotl(amountValue);
3583 else
3584 resultValue = inputValue.rotr(amountValue);
3585
3586 return IntAttr::get(input.getContext(), input.getType(), resultValue);
3587}
3588
3589//===----------------------------------------------------------------------===//
3590// InlineAsmOp
3591//===----------------------------------------------------------------------===//
3592
3593void cir::InlineAsmOp::print(OpAsmPrinter &p) {
3594 p << '(' << getAsmFlavor() << ", ";
3595 p.increaseIndent();
3596 p.printNewline();
3597
3598 llvm::SmallVector<std::string, 3> names{"out", "in", "in_out"};
3599 auto *nameIt = names.begin();
3600 auto *attrIt = getOperandAttrs().begin();
3601
3602 for (mlir::OperandRange ops : getAsmOperands()) {
3603 p << *nameIt << " = ";
3604
3605 p << '[';
3606 llvm::interleaveComma(llvm::make_range(ops.begin(), ops.end()), p,
3607 [&](Value value) {
3608 p.printOperand(value);
3609 p << " : " << value.getType();
3610 if (*attrIt)
3611 p << " (maybe_memory)";
3612 attrIt++;
3613 });
3614 p << "],";
3615 p.printNewline();
3616 ++nameIt;
3617 }
3618
3619 p << "{";
3620 p.printString(getAsmString());
3621 p << " ";
3622 p.printString(getConstraints());
3623 p << "}";
3624 p.decreaseIndent();
3625 p << ')';
3626 if (getSideEffects())
3627 p << " side_effects";
3628
3629 std::array elidedAttrs{
3630 llvm::StringRef("asm_flavor"), llvm::StringRef("asm_string"),
3631 llvm::StringRef("constraints"), llvm::StringRef("operand_attrs"),
3632 llvm::StringRef("operands_segments"), llvm::StringRef("side_effects")};
3633 p.printOptionalAttrDict(getOperation()->getAttrs(), elidedAttrs);
3634
3635 if (auto v = getRes())
3636 p << " -> " << v.getType();
3637}
3638
3639void cir::InlineAsmOp::build(OpBuilder &odsBuilder, OperationState &odsState,
3640 ArrayRef<ValueRange> asmOperands,
3641 StringRef asmString, StringRef constraints,
3642 bool sideEffects, cir::AsmFlavor asmFlavor,
3643 ArrayRef<Attribute> operandAttrs) {
3644 // Set up the operands_segments for VariadicOfVariadic
3645 SmallVector<int32_t> segments;
3646 for (auto operandRange : asmOperands) {
3647 segments.push_back(operandRange.size());
3648 odsState.addOperands(operandRange);
3649 }
3650
3651 odsState.addAttribute(
3652 "operands_segments",
3653 DenseI32ArrayAttr::get(odsBuilder.getContext(), segments));
3654 odsState.addAttribute("asm_string", odsBuilder.getStringAttr(asmString));
3655 odsState.addAttribute("constraints", odsBuilder.getStringAttr(constraints));
3656 odsState.addAttribute("asm_flavor",
3657 AsmFlavorAttr::get(odsBuilder.getContext(), asmFlavor));
3658
3659 if (sideEffects)
3660 odsState.addAttribute("side_effects", odsBuilder.getUnitAttr());
3661
3662 odsState.addAttribute("operand_attrs", odsBuilder.getArrayAttr(operandAttrs));
3663}
3664
3665ParseResult cir::InlineAsmOp::parse(OpAsmParser &parser,
3666 OperationState &result) {
3668 llvm::SmallVector<int32_t> operandsGroupSizes;
3669 std::string asmString, constraints;
3670 Type resType;
3671 MLIRContext *ctxt = parser.getBuilder().getContext();
3672
3673 auto error = [&](const Twine &msg) -> LogicalResult {
3674 return parser.emitError(parser.getCurrentLocation(), msg);
3675 };
3676
3677 auto expected = [&](const std::string &c) {
3678 return error("expected '" + c + "'");
3679 };
3680
3681 if (parser.parseLParen().failed())
3682 return expected("(");
3683
3684 auto flavor = FieldParser<AsmFlavor, AsmFlavor>::parse(parser);
3685 if (failed(flavor))
3686 return error("Unknown AsmFlavor");
3687
3688 if (parser.parseComma().failed())
3689 return expected(",");
3690
3691 auto parseValue = [&](Value &v) {
3692 OpAsmParser::UnresolvedOperand op;
3693
3694 if (parser.parseOperand(op) || parser.parseColon())
3695 return error("can't parse operand");
3696
3697 Type typ;
3698 if (parser.parseType(typ).failed())
3699 return error("can't parse operand type");
3701 if (parser.resolveOperand(op, typ, tmp))
3702 return error("can't resolve operand");
3703 v = tmp[0];
3704 return mlir::success();
3705 };
3706
3707 auto parseOperands = [&](llvm::StringRef name) {
3708 if (parser.parseKeyword(name).failed())
3709 return error("expected " + name + " operands here");
3710 if (parser.parseEqual().failed())
3711 return expected("=");
3712 if (parser.parseLSquare().failed())
3713 return expected("[");
3714
3715 int size = 0;
3716 if (parser.parseOptionalRSquare().succeeded()) {
3717 operandsGroupSizes.push_back(size);
3718 if (parser.parseComma())
3719 return expected(",");
3720 return mlir::success();
3721 }
3722
3723 auto parseOperand = [&]() {
3724 Value val;
3725 if (parseValue(val).succeeded()) {
3726 result.operands.push_back(val);
3727 size++;
3728
3729 if (parser.parseOptionalLParen().failed()) {
3730 operandAttrs.push_back(mlir::Attribute());
3731 return mlir::success();
3732 }
3733
3734 if (parser.parseKeyword("maybe_memory").succeeded()) {
3735 operandAttrs.push_back(mlir::UnitAttr::get(ctxt));
3736 if (parser.parseRParen())
3737 return expected(")");
3738 return mlir::success();
3739 } else {
3740 return expected("maybe_memory");
3741 }
3742 }
3743 return mlir::failure();
3744 };
3745
3746 if (parser.parseCommaSeparatedList(parseOperand).failed())
3747 return mlir::failure();
3748
3749 if (parser.parseRSquare().failed() || parser.parseComma().failed())
3750 return expected("]");
3751 operandsGroupSizes.push_back(size);
3752 return mlir::success();
3753 };
3754
3755 if (parseOperands("out").failed() || parseOperands("in").failed() ||
3756 parseOperands("in_out").failed())
3757 return error("failed to parse operands");
3758
3759 if (parser.parseLBrace())
3760 return expected("{");
3761 if (parser.parseString(&asmString))
3762 return error("asm string parsing failed");
3763 if (parser.parseString(&constraints))
3764 return error("constraints string parsing failed");
3765 if (parser.parseRBrace())
3766 return expected("}");
3767 if (parser.parseRParen())
3768 return expected(")");
3769
3770 if (parser.parseOptionalKeyword("side_effects").succeeded())
3771 result.attributes.set("side_effects", UnitAttr::get(ctxt));
3772
3773 if (parser.parseOptionalArrow().succeeded() &&
3774 parser.parseType(resType).failed())
3775 return mlir::failure();
3776
3777 if (parser.parseOptionalAttrDict(result.attributes).failed())
3778 return mlir::failure();
3779
3780 result.attributes.set("asm_flavor", AsmFlavorAttr::get(ctxt, *flavor));
3781 result.attributes.set("asm_string", StringAttr::get(ctxt, asmString));
3782 result.attributes.set("constraints", StringAttr::get(ctxt, constraints));
3783 result.attributes.set("operand_attrs", ArrayAttr::get(ctxt, operandAttrs));
3784 result.getOrAddProperties<InlineAsmOp::Properties>().operands_segments =
3785 parser.getBuilder().getDenseI32ArrayAttr(operandsGroupSizes);
3786 if (resType)
3787 result.addTypes(TypeRange{resType});
3788
3789 return mlir::success();
3790}
3791
3792//===----------------------------------------------------------------------===//
3793// ThrowOp
3794//===----------------------------------------------------------------------===//
3795
3796mlir::LogicalResult cir::ThrowOp::verify() {
3797 // For the no-rethrow version, it must have at least the exception pointer.
3798 if (rethrows())
3799 return success();
3800
3801 if (getNumOperands() != 0) {
3802 if (getTypeInfo())
3803 return success();
3804 return emitOpError() << "'type_info' symbol attribute missing";
3805 }
3806
3807 return failure();
3808}
3809
3810//===----------------------------------------------------------------------===//
3811// AtomicFetchOp
3812//===----------------------------------------------------------------------===//
3813
3814LogicalResult cir::AtomicFetchOp::verify() {
3815 if (getBinop() != cir::AtomicFetchKind::Add &&
3816 getBinop() != cir::AtomicFetchKind::Sub &&
3817 getBinop() != cir::AtomicFetchKind::Max &&
3818 getBinop() != cir::AtomicFetchKind::Min &&
3819 !mlir::isa<cir::IntType>(getVal().getType()))
3820 return emitError("only atomic add, sub, max, and min operation could "
3821 "operate on floating-point values");
3822 return success();
3823}
3824
3825//===----------------------------------------------------------------------===//
3826// TypeInfoAttr
3827//===----------------------------------------------------------------------===//
3828
3829LogicalResult cir::TypeInfoAttr::verify(
3830 ::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError,
3831 ::mlir::Type type, ::mlir::ArrayAttr typeInfoData) {
3832
3833 if (cir::ConstRecordAttr::verify(emitError, type, typeInfoData).failed())
3834 return failure();
3835
3836 return success();
3837}
3838
3839//===----------------------------------------------------------------------===//
3840// TryOp
3841//===----------------------------------------------------------------------===//
3842
3843void cir::TryOp::getSuccessorRegions(
3844 mlir::RegionBranchPoint point,
3846 // The `try` and the `catchers` region branch back to the parent operation.
3847 if (!point.isParent()) {
3848 regions.push_back(RegionSuccessor::parent());
3849 return;
3850 }
3851
3852 regions.push_back(mlir::RegionSuccessor(&getTryRegion()));
3853
3854 // TODO(CIR): If we know a target function never throws a specific type, we
3855 // can remove the catch handler.
3856 for (mlir::Region &handlerRegion : this->getHandlerRegions())
3857 regions.push_back(mlir::RegionSuccessor(&handlerRegion));
3858}
3859
3860mlir::ValueRange cir::TryOp::getSuccessorInputs(RegionSuccessor successor) {
3861 return successor.isParent() ? ValueRange(getOperation()->getResults())
3862 : ValueRange();
3863}
3864
3865LogicalResult cir::TryOp::verify() {
3866 mlir::ArrayAttr handlerTypes = getHandlerTypes();
3867 if (!handlerTypes) {
3868 if (!getHandlerRegions().empty())
3869 return emitOpError(
3870 "handler regions must be empty when no handler types are present");
3871 return success();
3872 }
3873
3874 mlir::MutableArrayRef<mlir::Region> handlerRegions = getHandlerRegions();
3875
3876 // The parser and builder won't allow this to happen, but the loop below
3877 // relies on the sizes being the same, so we check it here.
3878 if (handlerRegions.size() != handlerTypes.size())
3879 return emitOpError(
3880 "number of handler regions and handler types must match");
3881
3882 for (const auto &[typeAttr, handlerRegion] :
3883 llvm::zip(handlerTypes, handlerRegions)) {
3884 // Verify that handler regions have a !cir.eh_token block argument.
3885 mlir::Block &entryBlock = handlerRegion.front();
3886 if (entryBlock.getNumArguments() != 1 ||
3887 !mlir::isa<cir::EhTokenType>(entryBlock.getArgument(0).getType()))
3888 return emitOpError(
3889 "handler region must have a single '!cir.eh_token' argument");
3890
3891 // The unwind region does not require a cir.begin_catch.
3892 if (mlir::isa<cir::UnwindAttr>(typeAttr))
3893 continue;
3894
3895 if (entryBlock.empty() || !mlir::isa<cir::BeginCatchOp>(entryBlock.front()))
3896 return emitOpError(
3897 "catch handler region must start with 'cir.begin_catch'");
3898 }
3899
3900 return success();
3901}
3902
3903static void
3904printTryHandlerRegions(mlir::OpAsmPrinter &printer, cir::TryOp op,
3905 mlir::MutableArrayRef<mlir::Region> handlerRegions,
3906 mlir::ArrayAttr handlerTypes) {
3907 if (!handlerTypes)
3908 return;
3909
3910 for (const auto [typeIdx, typeAttr] : llvm::enumerate(handlerTypes)) {
3911 if (typeIdx)
3912 printer << " ";
3913
3914 if (mlir::isa<cir::CatchAllAttr>(typeAttr)) {
3915 printer << "catch all ";
3916 } else if (mlir::isa<cir::UnwindAttr>(typeAttr)) {
3917 printer << "unwind ";
3918 } else {
3919 printer << "catch [type ";
3920 printer.printAttribute(typeAttr);
3921 printer << "] ";
3922 }
3923
3924 // Print the handler region's !cir.eh_token block argument.
3925 mlir::Region &region = handlerRegions[typeIdx];
3926 if (!region.empty() && region.front().getNumArguments() > 0) {
3927 printer << "(";
3928 printer.printRegionArgument(region.front().getArgument(0));
3929 printer << ") ";
3930 }
3931
3932 printer.printRegion(region,
3933 /*printEntryBLockArgs=*/false,
3934 /*printBlockTerminators=*/true);
3935 }
3936}
3937
3938static mlir::ParseResult parseTryHandlerRegions(
3939 mlir::OpAsmParser &parser,
3940 llvm::SmallVectorImpl<std::unique_ptr<mlir::Region>> &handlerRegions,
3941 mlir::ArrayAttr &handlerTypes) {
3942
3943 auto parseCheckedCatcherRegion = [&]() -> mlir::ParseResult {
3944 handlerRegions.emplace_back(new mlir::Region);
3945
3946 mlir::Region &currRegion = *handlerRegions.back();
3947
3948 // Parse the required region argument: (%eh_token : !cir.eh_token)
3950 if (parser.parseLParen())
3951 return failure();
3952 mlir::OpAsmParser::Argument arg;
3953 if (parser.parseArgument(arg, /*allowType=*/true))
3954 return failure();
3955 regionArgs.push_back(arg);
3956 if (parser.parseRParen())
3957 return failure();
3958
3959 mlir::SMLoc regionLoc = parser.getCurrentLocation();
3960 if (parser.parseRegion(currRegion, regionArgs)) {
3961 handlerRegions.clear();
3962 return failure();
3963 }
3964
3965 if (currRegion.empty())
3966 return parser.emitError(regionLoc, "handler region shall not be empty");
3967
3968 if (!(currRegion.back().mightHaveTerminator() &&
3969 currRegion.back().getTerminator()))
3970 return parser.emitError(
3971 regionLoc, "blocks are expected to be explicitly terminated");
3972
3973 return success();
3974 };
3975
3976 bool hasCatchAll = false;
3978 while (parser.parseOptionalKeyword("catch").succeeded()) {
3979 bool hasLSquare = parser.parseOptionalLSquare().succeeded();
3980
3981 llvm::StringRef attrStr;
3982 if (parser.parseOptionalKeyword(&attrStr, {"all", "type"}).failed())
3983 return parser.emitError(parser.getCurrentLocation(),
3984 "expected 'all' or 'type' keyword");
3985
3986 bool isCatchAll = attrStr == "all";
3987 if (isCatchAll) {
3988 if (hasCatchAll)
3989 return parser.emitError(parser.getCurrentLocation(),
3990 "can't have more than one catch all");
3991 hasCatchAll = true;
3992 }
3993
3994 mlir::Attribute exceptionRTTIAttr;
3995 if (!isCatchAll && parser.parseAttribute(exceptionRTTIAttr).failed())
3996 return parser.emitError(parser.getCurrentLocation(),
3997 "expected valid RTTI info attribute");
3998
3999 catcherAttrs.push_back(isCatchAll
4000 ? cir::CatchAllAttr::get(parser.getContext())
4001 : exceptionRTTIAttr);
4002
4003 if (hasLSquare && isCatchAll)
4004 return parser.emitError(parser.getCurrentLocation(),
4005 "catch all dosen't need RTTI info attribute");
4006
4007 if (hasLSquare && parser.parseRSquare().failed())
4008 return parser.emitError(parser.getCurrentLocation(),
4009 "expected `]` after RTTI info attribute");
4010
4011 if (parseCheckedCatcherRegion().failed())
4012 return mlir::failure();
4013 }
4014
4015 if (parser.parseOptionalKeyword("unwind").succeeded()) {
4016 if (hasCatchAll)
4017 return parser.emitError(parser.getCurrentLocation(),
4018 "unwind can't be used with catch all");
4019
4020 catcherAttrs.push_back(cir::UnwindAttr::get(parser.getContext()));
4021 if (parseCheckedCatcherRegion().failed())
4022 return mlir::failure();
4023 }
4024
4025 handlerTypes = parser.getBuilder().getArrayAttr(catcherAttrs);
4026 return mlir::success();
4027}
4028
4029//===----------------------------------------------------------------------===//
4030// EhTypeIdOp
4031//===----------------------------------------------------------------------===//
4032
4033LogicalResult
4034cir::EhTypeIdOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
4035 Operation *op = symbolTable.lookupNearestSymbolFrom(*this, getTypeSymAttr());
4036 if (!isa_and_nonnull<GlobalOp>(op))
4037 return emitOpError("'")
4038 << getTypeSym() << "' does not reference a valid cir.global";
4039 return success();
4040}
4041
4042//===----------------------------------------------------------------------===//
4043// EhDispatchOp
4044//===----------------------------------------------------------------------===//
4045
4046static ParseResult
4047parseEhDispatchDestinations(OpAsmParser &parser, mlir::ArrayAttr &catchTypes,
4048 SmallVectorImpl<Block *> &catchDestinations,
4049 Block *&defaultDestination,
4050 mlir::UnitAttr &defaultIsCatchAll) {
4051 // Parse: [ ... ]
4052 if (parser.parseLSquare())
4053 return failure();
4054
4055 SmallVector<Attribute> handlerTypes;
4056 bool hasCatchAll = false;
4057 bool hasUnwind = false;
4058
4059 // Parse handler list.
4060 auto parseHandler = [&]() -> ParseResult {
4061 // Check for 'catch_all' or 'unwind' keywords.
4062 if (succeeded(parser.parseOptionalKeyword("catch_all"))) {
4063 if (hasCatchAll)
4064 return parser.emitError(parser.getCurrentLocation(),
4065 "duplicate 'catch_all' handler");
4066 if (hasUnwind)
4067 return parser.emitError(parser.getCurrentLocation(),
4068 "cannot have both 'catch_all' and 'unwind'");
4069 hasCatchAll = true;
4070
4071 if (parser.parseColon().failed())
4072 return failure();
4073
4074 if (parser.parseSuccessor(defaultDestination).failed())
4075 return failure();
4076
4077 return success();
4078 }
4079
4080 if (succeeded(parser.parseOptionalKeyword("unwind"))) {
4081 if (hasUnwind)
4082 return parser.emitError(parser.getCurrentLocation(),
4083 "duplicate 'unwind' handler");
4084 if (hasCatchAll)
4085 return parser.emitError(parser.getCurrentLocation(),
4086 "cannot have both 'catch_all' and 'unwind'");
4087 hasUnwind = true;
4088
4089 if (parser.parseColon().failed())
4090 return failure();
4091
4092 if (parser.parseSuccessor(defaultDestination).failed())
4093 return failure();
4094 return success();
4095 }
4096
4097 // Otherwise, expect 'catch(<attr> : <type>) : ^block'.
4098 // The 'catch(...)' wrapper allows the attribute to include its type
4099 // without conflicting with the ':' used for the block destination.
4100 if (parser.parseKeyword("catch").failed())
4101 return failure();
4102
4103 if (parser.parseLParen().failed())
4104 return failure();
4105
4106 mlir::Attribute catchTypeAttr;
4107 if (parser.parseAttribute(catchTypeAttr).failed())
4108 return failure();
4109 handlerTypes.push_back(catchTypeAttr);
4110
4111 if (parser.parseRParen().failed())
4112 return failure();
4113
4114 if (parser.parseColon().failed())
4115 return failure();
4116
4117 Block *dest;
4118 if (parser.parseSuccessor(dest).failed())
4119 return failure();
4120 catchDestinations.push_back(dest);
4121 return success();
4122 };
4123
4124 if (parser.parseCommaSeparatedList(parseHandler).failed())
4125 return failure();
4126
4127 if (parser.parseRSquare().failed())
4128 return failure();
4129
4130 // Verify we have catch_all or unwind.
4131 if (!hasCatchAll && !hasUnwind)
4132 return parser.emitError(parser.getCurrentLocation(),
4133 "must have either 'catch_all' or 'unwind' handler");
4134
4135 // Add attributes and successors.
4136 if (!handlerTypes.empty())
4137 catchTypes = parser.getBuilder().getArrayAttr(handlerTypes);
4138
4139 if (hasCatchAll)
4140 defaultIsCatchAll = parser.getBuilder().getUnitAttr();
4141
4142 return success();
4143}
4144
4145static void printEhDispatchDestinations(OpAsmPrinter &p, cir::EhDispatchOp op,
4146 mlir::ArrayAttr catchTypes,
4147 SuccessorRange catchDestinations,
4148 Block *defaultDestination,
4149 mlir::UnitAttr defaultIsCatchAll) {
4150 p << " [";
4151 p.printNewline();
4152
4153 // If we have at least one catch type, print them.
4154 if (catchTypes) {
4155 // Print type handlers using 'catch(<attr>) : ^block' syntax.
4156 llvm::interleave(
4157 llvm::zip(catchTypes, catchDestinations),
4158 [&](auto i) {
4159 p << " catch(";
4160 p.printAttribute(std::get<0>(i));
4161 p << ") : ";
4162 p.printSuccessor(std::get<1>(i));
4163 },
4164 [&] {
4165 p << ',';
4166 p.printNewline();
4167 });
4168
4169 p << ", ";
4170 p.printNewline();
4171 }
4172
4173 // Print catch_all or unwind handler.
4174 if (defaultIsCatchAll)
4175 p << " catch_all : ";
4176 else
4177 p << " unwind : ";
4178 p.printSuccessor(defaultDestination);
4179 p.printNewline();
4180
4181 p << "]";
4182}
4183
4184//===----------------------------------------------------------------------===//
4185// TableGen'd op method definitions
4186//===----------------------------------------------------------------------===//
4187
4188#define GET_OP_CLASSES
4189#include "clang/CIR/Dialect/IR/CIROps.cpp.inc"
static const MemRegion * getRegion(const CallEvent &Call, const MutexDescriptor &Descriptor, bool IsLock)
static void printEhDispatchDestinations(OpAsmPrinter &p, cir::EhDispatchOp op, mlir::ArrayAttr catchTypes, SuccessorRange catchDestinations, Block *defaultDestination, mlir::UnitAttr defaultIsCatchAll)
static mlir::LogicalResult checkReturnAndFunction(cir::ReturnOp op, cir::FuncOp function)
static LogicalResult verifyMemberPtrCast(Operation *op, mlir::Value src, mlir::Type resultTy)
static bool isBoolNot(cir::UnaryOp op)
static mlir::ParseResult parseCallCommon(mlir::OpAsmParser &parser, mlir::OperationState &result, bool hasDestinationBlocks=false)
static bool isIntOrBoolCast(cir::CastOp op)
static ParseResult parseEhDispatchDestinations(OpAsmParser &parser, mlir::ArrayAttr &catchTypes, SmallVectorImpl< Block * > &catchDestinations, Block *&defaultDestination, mlir::UnitAttr &defaultIsCatchAll)
static void printConstant(OpAsmPrinter &p, Attribute value)
static mlir::ParseResult parseOmittedTerminatorRegion(mlir::OpAsmParser &parser, mlir::Region &region)
ParseResult parseInlineKindAttr(OpAsmParser &parser, cir::InlineKindAttr &inlineKindAttr)
void printInlineKindAttr(OpAsmPrinter &p, cir::InlineKindAttr inlineKindAttr)
void printVisibilityAttr(OpAsmPrinter &printer, cir::VisibilityAttr &visibility)
static ParseResult parseSwitchFlatOpCases(OpAsmParser &parser, Type flagType, mlir::ArrayAttr &caseValues, SmallVectorImpl< Block * > &caseDestinations, SmallVectorImpl< llvm::SmallVector< OpAsmParser::UnresolvedOperand > > &caseOperands, SmallVectorImpl< llvm::SmallVector< Type > > &caseOperandTypes)
<cases> ::= [ (case (, case )* )?
static void printCallCommon(mlir::Operation *op, mlir::FlatSymbolRefAttr calleeSym, mlir::Value indirectCallee, mlir::OpAsmPrinter &printer, bool isNothrow, cir::SideEffect sideEffect, ArrayAttr argAttrs, ArrayAttr resAttrs, mlir::Block *normalDest=nullptr, mlir::Block *unwindDest=nullptr)
static LogicalResult verifyCallCommInSymbolUses(mlir::Operation *op, SymbolTableCollection &symbolTable)
static LogicalResult ensureRegionTerm(OpAsmParser &parser, Region &region, SMLoc errLoc)
static ParseResult parseGlobalOpTypeAndInitialValue(OpAsmParser &parser, TypeAttr &typeAttr, Attribute &initialValueAttr, mlir::Region &ctorRegion, mlir::Region &dtorRegion)
void printIndirectBrOpSucessors(OpAsmPrinter &p, cir::IndirectBrOp op, Type flagType, SuccessorRange succs, OperandRangeRange succOperands, const TypeRangeRange &succOperandsTypes)
static OpFoldResult foldUnaryBitOp(mlir::Attribute inputAttr, llvm::function_ref< llvm::APInt(const llvm::APInt &)> func, bool poisonZero=false)
static llvm::StringRef getLinkageAttrNameString()
Returns the name used for the linkage attribute.
static RetTy parseOptionalCIRKeyword(AsmParser &parser, EnumTy defaultValue)
Parse an enum from the keyword, or default to the provided default value.
static void printSwitchFlatOpCases(OpAsmPrinter &p, cir::SwitchFlatOp op, Type flagType, mlir::ArrayAttr caseValues, SuccessorRange caseDestinations, OperandRangeRange caseOperands, const TypeRangeRange &caseOperandTypes)
static mlir::ParseResult parseTryCallDestinations(mlir::OpAsmParser &parser, mlir::OperationState &result)
static void printGlobalOpTypeAndInitialValue(OpAsmPrinter &p, cir::GlobalOp op, TypeAttr type, Attribute initAttr, mlir::Region &ctorRegion, mlir::Region &dtorRegion)
static ParseResult parseCIRKeyword(AsmParser &parser, RetTy &result)
Parse an enum from the keyword, return failure if the keyword is not found.
static Value tryFoldCastChain(cir::CastOp op)
void parseVisibilityAttr(OpAsmParser &parser, cir::VisibilityAttr &visibility)
static void printTryHandlerRegions(mlir::OpAsmPrinter &printer, cir::TryOp op, mlir::MutableArrayRef< mlir::Region > handlerRegions, mlir::ArrayAttr handlerTypes)
ParseResult parseIndirectBrOpSucessors(OpAsmParser &parser, Type &flagType, SmallVectorImpl< Block * > &succOperandBlocks, SmallVectorImpl< SmallVector< OpAsmParser::UnresolvedOperand > > &succOperands, SmallVectorImpl< SmallVector< Type > > &succOperandsTypes)
static bool omitRegionTerm(mlir::Region &r)
static LogicalResult verifyBinaryOverflowOp(mlir::Operation *op, bool noSignedWrap, bool noUnsignedWrap, bool saturated, bool hasSat)
static void printOmittedTerminatorRegion(mlir::OpAsmPrinter &printer, cir::ScopeOp &op, mlir::Region &region)
static ParseResult parseConstantValue(OpAsmParser &parser, mlir::Attribute &valueAttr)
static LogicalResult checkConstantTypes(mlir::Operation *op, mlir::Type opType, mlir::Attribute attrType)
static mlir::ParseResult parseTryHandlerRegions(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< std::unique_ptr< mlir::Region > > &handlerRegions, mlir::ArrayAttr &handlerTypes)
#define REGISTER_ENUM_TYPE(Ty)
static int parseOptionalKeywordAlternative(AsmParser &parser, ArrayRef< llvm::StringRef > keywords)
llvm::function_ref< void(mlir::OpBuilder &, mlir::Location)> BuilderCallbackRef
Definition CIRDialect.h:37
llvm::function_ref< void( mlir::OpBuilder &, mlir::Location, mlir::OperationState &)> BuilderOpStateCallbackRef
Definition CIRDialect.h:39
static std::optional< NonLoc > getIndex(ProgramStateRef State, const ElementRegion *ER, CharKind CK)
static Decl::Kind getKind(const Decl *D)
TokenType getType() const
Returns the token's type, e.g.
*collection of selector each with an associated kind and an ordered *collection of selectors A selector has a an optional score condition
*collection of selector each with an associated kind and an ordered *collection of selectors A selector has a kind
__device__ __2f16 float c
void buildTerminatedBody(mlir::OpBuilder &builder, mlir::Location loc)
const internal::VariadicAllOfMatcher< Attr > attr
const AstTypeMatcher< RecordType > recordType
StringRef getName(const HeaderType T)
Definition HeaderFile.h:38
RangeSelector name(std::string ID)
Given a node with a "name", (like NamedDecl, DeclRefExpr, CxxCtorInitializer, and TypeLoc) selects th...
nullptr
This class represents a compute construct, representing a 'Kind' of ‘parallel’, 'serial',...
const half4 dst(half4 Src0, half4 Src1)
__DEVICE__ _Tp arg(const std::complex< _Tp > &__c)
static bool memberFuncPtrCast()
static bool addressSpace()
static bool opCallCallConv()
static bool opScopeCleanupRegion()
static bool supportIFuncAttr()