clang 23.0.0git
CIRDialect.cpp
Go to the documentation of this file.
1//===- CIRDialect.cpp - MLIR CIR ops implementation -----------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements the CIR dialect and its operations.
10//
11//===----------------------------------------------------------------------===//
12
14
18
19#include "mlir/IR/Attributes.h"
20#include "mlir/IR/DialectImplementation.h"
21#include "mlir/IR/PatternMatch.h"
22#include "mlir/Interfaces/ControlFlowInterfaces.h"
23#include "mlir/Interfaces/FunctionImplementation.h"
24#include "mlir/Support/LLVM.h"
25
26#include "clang/CIR/Dialect/IR/CIROpsDialect.cpp.inc"
27#include "clang/CIR/Dialect/IR/CIROpsEnums.cpp.inc"
29#include "llvm/ADT/SetOperations.h"
30#include "llvm/ADT/SmallSet.h"
31#include "llvm/ADT/TypeSwitch.h"
32#include "llvm/Support/LogicalResult.h"
33
34using namespace mlir;
35using namespace cir;
36
37//===----------------------------------------------------------------------===//
38// CIR Dialect
39//===----------------------------------------------------------------------===//
40namespace {
41struct CIROpAsmDialectInterface : public OpAsmDialectInterface {
42 using OpAsmDialectInterface::OpAsmDialectInterface;
43
44 AliasResult getAlias(Type type, raw_ostream &os) const final {
45 if (auto recordType = dyn_cast<cir::RecordType>(type)) {
46 StringAttr nameAttr = recordType.getName();
47 if (!nameAttr)
48 os << "rec_anon_" << recordType.getKindAsStr();
49 else
50 os << "rec_" << nameAttr.getValue();
51 return AliasResult::OverridableAlias;
52 }
53 if (auto intType = dyn_cast<cir::IntType>(type)) {
54 // We only provide alias for standard integer types (i.e. integer types
55 // whose width is a power of 2 and at least 8).
56 unsigned width = intType.getWidth();
57 if (width < 8 || !llvm::isPowerOf2_32(width))
58 return AliasResult::NoAlias;
59 os << intType.getAlias();
60 return AliasResult::OverridableAlias;
61 }
62 if (auto voidType = dyn_cast<cir::VoidType>(type)) {
63 os << voidType.getAlias();
64 return AliasResult::OverridableAlias;
65 }
66
67 return AliasResult::NoAlias;
68 }
69
70 AliasResult getAlias(Attribute attr, raw_ostream &os) const final {
71 if (auto boolAttr = mlir::dyn_cast<cir::BoolAttr>(attr)) {
72 os << (boolAttr.getValue() ? "true" : "false");
73 return AliasResult::FinalAlias;
74 }
75 if (auto bitfield = mlir::dyn_cast<cir::BitfieldInfoAttr>(attr)) {
76 os << "bfi_" << bitfield.getName().str();
77 return AliasResult::FinalAlias;
78 }
79 if (auto dynCastInfoAttr = mlir::dyn_cast<cir::DynamicCastInfoAttr>(attr)) {
80 os << dynCastInfoAttr.getAlias();
81 return AliasResult::FinalAlias;
82 }
83 if (auto cmpThreeWayInfoAttr =
84 mlir::dyn_cast<cir::CmpThreeWayInfoAttr>(attr)) {
85 os << cmpThreeWayInfoAttr.getAlias();
86 return AliasResult::FinalAlias;
87 }
88 return AliasResult::NoAlias;
89 }
90};
91} // namespace
92
93void cir::CIRDialect::initialize() {
94 registerTypes();
95 registerAttributes();
96 addOperations<
97#define GET_OP_LIST
98#include "clang/CIR/Dialect/IR/CIROps.cpp.inc"
99 >();
100 addInterfaces<CIROpAsmDialectInterface>();
101}
102
103Operation *cir::CIRDialect::materializeConstant(mlir::OpBuilder &builder,
104 mlir::Attribute value,
105 mlir::Type type,
106 mlir::Location loc) {
107 return cir::ConstantOp::create(builder, loc, type,
108 mlir::cast<mlir::TypedAttr>(value));
109}
110
111//===----------------------------------------------------------------------===//
112// Helpers
113//===----------------------------------------------------------------------===//
114
115// Parses one of the keywords provided in the list `keywords` and returns the
116// position of the parsed keyword in the list. If none of the keywords from the
117// list is parsed, returns -1.
118static int parseOptionalKeywordAlternative(AsmParser &parser,
119 ArrayRef<llvm::StringRef> keywords) {
120 for (auto en : llvm::enumerate(keywords)) {
121 if (succeeded(parser.parseOptionalKeyword(en.value())))
122 return en.index();
123 }
124 return -1;
125}
126
127namespace {
128template <typename Ty> struct EnumTraits {};
129
130#define REGISTER_ENUM_TYPE(Ty) \
131 template <> struct EnumTraits<cir::Ty> { \
132 static llvm::StringRef stringify(cir::Ty value) { \
133 return stringify##Ty(value); \
134 } \
135 static unsigned getMaxEnumVal() { return cir::getMaxEnumValFor##Ty(); } \
136 }
137
138REGISTER_ENUM_TYPE(GlobalLinkageKind);
139REGISTER_ENUM_TYPE(VisibilityKind);
140REGISTER_ENUM_TYPE(SideEffect);
141REGISTER_ENUM_TYPE(CallingConv);
142} // namespace
143
144/// Parse an enum from the keyword, or default to the provided default value.
145/// The return type is the enum type by default, unless overriden with the
146/// second template argument.
147template <typename EnumTy, typename RetTy = EnumTy>
148static RetTy parseOptionalCIRKeyword(AsmParser &parser, EnumTy defaultValue) {
150 for (unsigned i = 0, e = EnumTraits<EnumTy>::getMaxEnumVal(); i <= e; ++i)
151 names.push_back(EnumTraits<EnumTy>::stringify(static_cast<EnumTy>(i)));
152
153 int index = parseOptionalKeywordAlternative(parser, names);
154 if (index == -1)
155 return static_cast<RetTy>(defaultValue);
156 return static_cast<RetTy>(index);
157}
158
159/// Parse an enum from the keyword, return failure if the keyword is not found.
160template <typename EnumTy, typename RetTy = EnumTy>
161static ParseResult parseCIRKeyword(AsmParser &parser, RetTy &result) {
163 for (unsigned i = 0, e = EnumTraits<EnumTy>::getMaxEnumVal(); i <= e; ++i)
164 names.push_back(EnumTraits<EnumTy>::stringify(static_cast<EnumTy>(i)));
165
166 int index = parseOptionalKeywordAlternative(parser, names);
167 if (index == -1)
168 return failure();
169 result = static_cast<RetTy>(index);
170 return success();
171}
172
173// Check if a region's termination omission is valid and, if so, creates and
174// inserts the omitted terminator into the region.
175static LogicalResult ensureRegionTerm(OpAsmParser &parser, Region &region,
176 SMLoc errLoc) {
177 Location eLoc = parser.getEncodedSourceLoc(parser.getCurrentLocation());
178 OpBuilder builder(parser.getBuilder().getContext());
179
180 // Insert empty block in case the region is empty to ensure the terminator
181 // will be inserted
182 if (region.empty())
183 builder.createBlock(&region);
184
185 Block &block = region.back();
186 // Region is properly terminated: nothing to do.
187 if (!block.empty() && block.back().hasTrait<OpTrait::IsTerminator>())
188 return success();
189
190 // Check for invalid terminator omissions.
191 if (!region.hasOneBlock())
192 return parser.emitError(errLoc,
193 "multi-block region must not omit terminator");
194
195 // Terminator was omitted correctly: recreate it.
196 builder.setInsertionPointToEnd(&block);
197 cir::YieldOp::create(builder, eLoc);
198 return success();
199}
200
201// True if the region's terminator should be omitted.
202static bool omitRegionTerm(mlir::Region &r) {
203 const auto singleNonEmptyBlock = r.hasOneBlock() && !r.back().empty();
204 const auto yieldsNothing = [&r]() {
205 auto y = dyn_cast<cir::YieldOp>(r.back().getTerminator());
206 return y && y.getArgs().empty();
207 };
208 return singleNonEmptyBlock && yieldsNothing();
209}
210
211//===----------------------------------------------------------------------===//
212// InlineKindAttr (FIXME: remove once FuncOp uses assembly format)
213//===----------------------------------------------------------------------===//
214
215ParseResult parseInlineKindAttr(OpAsmParser &parser,
216 cir::InlineKindAttr &inlineKindAttr) {
217 // Static list of possible inline kind keywords
218 static constexpr llvm::StringRef keywords[] = {"no_inline", "always_inline",
219 "inline_hint"};
220
221 // Parse the inline kind keyword (optional)
222 llvm::StringRef keyword;
223 if (parser.parseOptionalKeyword(&keyword, keywords).failed()) {
224 // Not an inline kind keyword, leave inlineKindAttr empty
225 return success();
226 }
227
228 // Parse the enum value from the keyword
229 auto inlineKindResult = ::cir::symbolizeEnum<::cir::InlineKind>(keyword);
230 if (!inlineKindResult) {
231 return parser.emitError(parser.getCurrentLocation(), "expected one of [")
232 << llvm::join(llvm::ArrayRef(keywords), ", ")
233 << "] for inlineKind, got: " << keyword;
234 }
235
236 inlineKindAttr =
237 ::cir::InlineKindAttr::get(parser.getContext(), *inlineKindResult);
238 return success();
239}
240
241void printInlineKindAttr(OpAsmPrinter &p, cir::InlineKindAttr inlineKindAttr) {
242 if (inlineKindAttr) {
243 p << " " << stringifyInlineKind(inlineKindAttr.getValue());
244 }
245}
246//===----------------------------------------------------------------------===//
247// CIR Custom Parsers/Printers
248//===----------------------------------------------------------------------===//
249
250static mlir::ParseResult parseOmittedTerminatorRegion(mlir::OpAsmParser &parser,
251 mlir::Region &region) {
252 auto regionLoc = parser.getCurrentLocation();
253 if (parser.parseRegion(region))
254 return failure();
255 if (ensureRegionTerm(parser, region, regionLoc).failed())
256 return failure();
257 return success();
258}
259
260static void printOmittedTerminatorRegion(mlir::OpAsmPrinter &printer,
261 cir::ScopeOp &op,
262 mlir::Region &region) {
263 printer.printRegion(region,
264 /*printEntryBlockArgs=*/false,
265 /*printBlockTerminators=*/!omitRegionTerm(region));
266}
267
268mlir::OptionalParseResult
269parseGlobalAddressSpaceValue(mlir::AsmParser &p,
270 mlir::ptr::MemorySpaceAttrInterface &attr);
271
272void printGlobalAddressSpaceValue(mlir::AsmPrinter &printer, cir::GlobalOp op,
273 mlir::ptr::MemorySpaceAttrInterface attr);
274
275//===----------------------------------------------------------------------===//
276// AllocaOp
277//===----------------------------------------------------------------------===//
278
279void cir::AllocaOp::build(mlir::OpBuilder &odsBuilder,
280 mlir::OperationState &odsState, mlir::Type addr,
281 mlir::Type allocaType, llvm::StringRef name,
282 mlir::IntegerAttr alignment) {
283 odsState.addAttribute(getAllocaTypeAttrName(odsState.name),
284 mlir::TypeAttr::get(allocaType));
285 odsState.addAttribute(getNameAttrName(odsState.name),
286 odsBuilder.getStringAttr(name));
287 if (alignment) {
288 odsState.addAttribute(getAlignmentAttrName(odsState.name), alignment);
289 }
290 odsState.addTypes(addr);
291}
292
293//===----------------------------------------------------------------------===//
294// ArrayCtor & ArrayDtor
295//===----------------------------------------------------------------------===//
296
297template <typename Op> static LogicalResult verifyArrayCtorDtor(Op op) {
298 auto ptrTy = mlir::cast<cir::PointerType>(op.getAddr().getType());
299 mlir::Type pointeeTy = ptrTy.getPointee();
300
301 mlir::Block &body = op.getBody().front();
302 if (body.getNumArguments() != 1)
303 return op.emitOpError("body must have exactly one block argument");
304
305 auto expectedEltPtrTy =
306 mlir::dyn_cast<cir::PointerType>(body.getArgument(0).getType());
307 if (!expectedEltPtrTy)
308 return op.emitOpError("block argument must be a !cir.ptr type");
309
310 if (op.getNumElements()) {
311 auto recTy = mlir::dyn_cast<cir::RecordType>(pointeeTy);
312 if (!recTy)
313 return op.emitOpError(
314 "when 'num_elements' is present, 'addr' must be a pointer to a "
315 "!cir.record type");
316
317 if (expectedEltPtrTy != ptrTy)
318 return op.emitOpError("when 'num_elements' is present, 'addr' type must "
319 "match the block argument type");
320 } else {
321 auto arrayTy = mlir::dyn_cast<cir::ArrayType>(pointeeTy);
322 if (!arrayTy)
323 return op.emitOpError(
324 "when 'num_elements' is absent, 'addr' must be a pointer to a "
325 "!cir.array type");
326
327 mlir::Type innerEltTy = arrayTy.getElementType();
328 while (auto nested = mlir::dyn_cast<cir::ArrayType>(innerEltTy))
329 innerEltTy = nested.getElementType();
330
331 auto recTy = mlir::dyn_cast<cir::RecordType>(innerEltTy);
332 if (!recTy)
333 return op.emitOpError(
334 "the block argument type must be a pointer to a !cir.record type");
335
336 if (expectedEltPtrTy.getPointee() != innerEltTy)
337 return op.emitOpError(
338 "block argument pointee type must match the innermost array "
339 "element type");
340 }
341
342 return success();
343}
344
345LogicalResult cir::ArrayCtor::verify() {
346 if (failed(verifyArrayCtorDtor(*this)))
347 return failure();
348
349 mlir::Region &partialDtor = getPartialDtor();
350 if (!partialDtor.empty()) {
351 mlir::Block &dtorBlock = partialDtor.front();
352 if (dtorBlock.getNumArguments() != 1)
353 return emitOpError("partial_dtor must have exactly one block argument");
354
355 auto bodyArgTy = getBody().front().getArgument(0).getType();
356 if (dtorBlock.getArgument(0).getType() != bodyArgTy)
357 return emitOpError("partial_dtor block argument type must match "
358 "the body block argument type");
359 }
360 return success();
361}
362LogicalResult cir::ArrayDtor::verify() { return verifyArrayCtorDtor(*this); }
363
364//===----------------------------------------------------------------------===//
365// DeleteArrayOp
366//===----------------------------------------------------------------------===//
367
368LogicalResult cir::DeleteArrayOp::verify() {
369 if (getDtorMayThrow() && !getElementDtorAttr())
370 return emitOpError(
371 "'dtor_may_throw' requires an 'element_dtor' to be present");
372 return success();
373}
374
375//===----------------------------------------------------------------------===//
376// BreakOp
377//===----------------------------------------------------------------------===//
378
379LogicalResult cir::BreakOp::verify() {
380 if (!getOperation()->getParentOfType<LoopOpInterface>() &&
381 !getOperation()->getParentOfType<SwitchOp>())
382 return emitOpError("must be within a loop");
383 return success();
384}
385
386//===----------------------------------------------------------------------===//
387// LocalInitOp
388//===----------------------------------------------------------------------===//
389
390LogicalResult cir::LocalInitOp::verify() {
391 if (!getOperation()->getParentOfType<FuncOp>())
392 return emitOpError("must be inside of a 'cir.func'");
393 return success();
394}
395
396LogicalResult
397cir::LocalInitOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
398 cir::GlobalOp global = getReferencedGlobal(symbolTable);
399 if (!global)
400 return emitOpError("'")
401 << getGlobalName() << "' does not reference a valid cir.global";
402
403 if (getTls() && !global.getTlsModel())
404 return emitOpError("access to global not marked thread local");
405
406 if (!global.getStaticLocalGuard().has_value())
407 return emitOpError("static_local attribute mismatch");
408
409 return success();
410}
411
412//===----------------------------------------------------------------------===//
413// ConditionOp
414//===----------------------------------------------------------------------===//
415
416//===----------------------------------
417// BranchOpTerminatorInterface Methods
418//===----------------------------------
419
420void cir::ConditionOp::getSuccessorRegions(
421 ArrayRef<Attribute> operands, SmallVectorImpl<RegionSuccessor> &regions) {
422 // TODO(cir): The condition value may be folded to a constant, narrowing
423 // down its list of possible successors.
424
425 // Parent is a loop: condition may branch to the body or to the parent op.
426 if (auto loopOp = dyn_cast<LoopOpInterface>(getOperation()->getParentOp())) {
427 regions.emplace_back(&loopOp.getBody());
428 regions.push_back(RegionSuccessor::parent());
429 return;
430 }
431
432 // Parent is an await: condition may branch to resume or suspend regions.
433 auto await = cast<AwaitOp>(getOperation()->getParentOp());
434 regions.emplace_back(&await.getResume());
435 regions.emplace_back(&await.getSuspend());
436}
437
438MutableOperandRange
439cir::ConditionOp::getMutableSuccessorOperands(RegionSuccessor point) {
440 // No values are yielded to the successor region.
441 return MutableOperandRange(getOperation(), 0, 0);
442}
443
444MutableOperandRange
445cir::ResumeOp::getMutableSuccessorOperands(RegionSuccessor point) {
446 // The eh_token operand is not forwarded to the parent region.
447 return MutableOperandRange(getOperation(), 0, 0);
448}
449
450LogicalResult cir::ConditionOp::verify() {
451 if (!isa<LoopOpInterface, AwaitOp>(getOperation()->getParentOp()))
452 return emitOpError("condition must be within a conditional region");
453 return success();
454}
455
456//===----------------------------------------------------------------------===//
457// ConstantOp
458//===----------------------------------------------------------------------===//
459
460static LogicalResult checkConstantTypes(mlir::Operation *op, mlir::Type opType,
461 mlir::Attribute attrType) {
462 if (isa<cir::ConstPtrAttr>(attrType)) {
463 if (!mlir::isa<cir::PointerType>(opType))
464 return op->emitOpError(
465 "pointer constant initializing a non-pointer type");
466 return success();
467 }
468
469 if (isa<cir::DataMemberAttr, cir::MethodAttr>(attrType)) {
470 // More detailed type verifications are already done in
471 // DataMemberAttr::verify or MethodAttr::verify. Don't need to repeat here.
472 return success();
473 }
474
475 if (isa<cir::ZeroAttr>(attrType)) {
476 if (isa<cir::RecordType, cir::ArrayType, cir::VectorType, cir::ComplexType>(
477 opType))
478 return success();
479 return op->emitOpError(
480 "zero expects struct, array, vector, or complex type");
481 }
482
483 if (mlir::isa<cir::UndefAttr>(attrType)) {
484 if (!mlir::isa<cir::VoidType>(opType))
485 return success();
486 return op->emitOpError("undef expects non-void type");
487 }
488
489 if (mlir::isa<cir::BoolAttr>(attrType)) {
490 if (!mlir::isa<cir::BoolType>(opType))
491 return op->emitOpError("result type (")
492 << opType << ") must be '!cir.bool' for '" << attrType << "'";
493 return success();
494 }
495
496 if (mlir::isa<cir::IntAttr, cir::FPAttr>(attrType)) {
497 auto at = cast<TypedAttr>(attrType);
498 if (at.getType() != opType) {
499 return op->emitOpError("result type (")
500 << opType << ") does not match value type (" << at.getType()
501 << ")";
502 }
503 return success();
504 }
505
506 if (mlir::isa<cir::ConstArrayAttr, cir::ConstVectorAttr,
507 cir::ConstComplexAttr, cir::ConstRecordAttr,
508 cir::GlobalViewAttr, cir::PoisonAttr, cir::TypeInfoAttr,
509 cir::VTableAttr>(attrType))
510 return success();
511
512 assert(isa<TypedAttr>(attrType) && "What else could we be looking at here?");
513 return op->emitOpError("global with type ")
514 << cast<TypedAttr>(attrType).getType() << " not yet supported";
515}
516
517LogicalResult cir::ConstantOp::verify() {
518 // ODS already generates checks to make sure the result type is valid. We just
519 // need to additionally check that the value's attribute type is consistent
520 // with the result type.
521 return checkConstantTypes(getOperation(), getType(), getValue());
522}
523
524OpFoldResult cir::ConstantOp::fold(FoldAdaptor /*adaptor*/) {
525 return getValue();
526}
527
528//===----------------------------------------------------------------------===//
529// ContinueOp
530//===----------------------------------------------------------------------===//
531
532LogicalResult cir::ContinueOp::verify() {
533 if (!getOperation()->getParentOfType<LoopOpInterface>())
534 return emitOpError("must be within a loop");
535 return success();
536}
537
538//===----------------------------------------------------------------------===//
539// CastOp
540//===----------------------------------------------------------------------===//
541
542LogicalResult cir::CastOp::verify() {
543 mlir::Type resType = getType();
544 mlir::Type srcType = getSrc().getType();
545
546 // Verify address space casts for pointer types. given that
547 // casts for within a different address space are illegal.
548 auto srcPtrTy = mlir::dyn_cast<cir::PointerType>(srcType);
549 auto resPtrTy = mlir::dyn_cast<cir::PointerType>(resType);
550 if (srcPtrTy && resPtrTy && (getKind() != cir::CastKind::address_space))
551 if (srcPtrTy.getAddrSpace() != resPtrTy.getAddrSpace()) {
552 return emitOpError() << "result type address space does not match the "
553 "address space of the operand";
554 }
555
556 if (mlir::isa<cir::VectorType>(srcType) &&
557 mlir::isa<cir::VectorType>(resType)) {
558 // Use the element type of the vector to verify the cast kind. (Except for
559 // bitcast, see below.)
560 srcType = mlir::dyn_cast<cir::VectorType>(srcType).getElementType();
561 resType = mlir::dyn_cast<cir::VectorType>(resType).getElementType();
562 }
563
564 switch (getKind()) {
565 case cir::CastKind::int_to_bool: {
566 if (!mlir::isa<cir::BoolType>(resType))
567 return emitOpError() << "requires !cir.bool type for result";
568 if (!mlir::isa<cir::IntType>(srcType))
569 return emitOpError() << "requires !cir.int type for source";
570 return success();
571 }
572 case cir::CastKind::ptr_to_bool: {
573 if (!mlir::isa<cir::BoolType>(resType))
574 return emitOpError() << "requires !cir.bool type for result";
575 if (!mlir::isa<cir::PointerType>(srcType))
576 return emitOpError() << "requires !cir.ptr type for source";
577 return success();
578 }
579 case cir::CastKind::integral: {
580 if (!mlir::isa<cir::IntType>(resType))
581 return emitOpError() << "requires !cir.int type for result";
582 if (!mlir::isa<cir::IntType>(srcType))
583 return emitOpError() << "requires !cir.int type for source";
584 return success();
585 }
586 case cir::CastKind::array_to_ptrdecay: {
587 const auto arrayPtrTy = mlir::dyn_cast<cir::PointerType>(srcType);
588 const auto flatPtrTy = mlir::dyn_cast<cir::PointerType>(resType);
589 if (!arrayPtrTy || !flatPtrTy)
590 return emitOpError() << "requires !cir.ptr type for source and result";
591
592 // TODO(CIR): Make sure the AddrSpace of both types are equals
593 return success();
594 }
595 case cir::CastKind::bitcast: {
596 // Handle the pointer types first.
597 auto srcPtrTy = mlir::dyn_cast<cir::PointerType>(srcType);
598 auto resPtrTy = mlir::dyn_cast<cir::PointerType>(resType);
599
600 if (srcPtrTy && resPtrTy) {
601 return success();
602 }
603
604 return success();
605 }
606 case cir::CastKind::floating: {
607 if (!mlir::isa<cir::FPTypeInterface>(srcType) ||
608 !mlir::isa<cir::FPTypeInterface>(resType))
609 return emitOpError() << "requires !cir.float type for source and result";
610 return success();
611 }
612 case cir::CastKind::float_to_int: {
613 if (!mlir::isa<cir::FPTypeInterface>(srcType))
614 return emitOpError() << "requires !cir.float type for source";
615 if (!mlir::dyn_cast<cir::IntType>(resType))
616 return emitOpError() << "requires !cir.int type for result";
617 return success();
618 }
619 case cir::CastKind::int_to_ptr: {
620 if (!mlir::dyn_cast<cir::IntType>(srcType))
621 return emitOpError() << "requires !cir.int type for source";
622 if (!mlir::dyn_cast<cir::PointerType>(resType))
623 return emitOpError() << "requires !cir.ptr type for result";
624 return success();
625 }
626 case cir::CastKind::ptr_to_int: {
627 if (!mlir::dyn_cast<cir::PointerType>(srcType))
628 return emitOpError() << "requires !cir.ptr type for source";
629 if (!mlir::dyn_cast<cir::IntType>(resType))
630 return emitOpError() << "requires !cir.int type for result";
631 return success();
632 }
633 case cir::CastKind::float_to_bool: {
634 if (!mlir::isa<cir::FPTypeInterface>(srcType))
635 return emitOpError() << "requires !cir.float type for source";
636 if (!mlir::isa<cir::BoolType>(resType))
637 return emitOpError() << "requires !cir.bool type for result";
638 return success();
639 }
640 case cir::CastKind::bool_to_int: {
641 if (!mlir::isa<cir::BoolType>(srcType))
642 return emitOpError() << "requires !cir.bool type for source";
643 if (!mlir::isa<cir::IntType>(resType))
644 return emitOpError() << "requires !cir.int type for result";
645 return success();
646 }
647 case cir::CastKind::int_to_float: {
648 if (!mlir::isa<cir::IntType>(srcType))
649 return emitOpError() << "requires !cir.int type for source";
650 if (!mlir::isa<cir::FPTypeInterface>(resType))
651 return emitOpError() << "requires !cir.float type for result";
652 return success();
653 }
654 case cir::CastKind::bool_to_float: {
655 if (!mlir::isa<cir::BoolType>(srcType))
656 return emitOpError() << "requires !cir.bool type for source";
657 if (!mlir::isa<cir::FPTypeInterface>(resType))
658 return emitOpError() << "requires !cir.float type for result";
659 return success();
660 }
661 case cir::CastKind::address_space: {
662 auto srcPtrTy = mlir::dyn_cast<cir::PointerType>(srcType);
663 auto resPtrTy = mlir::dyn_cast<cir::PointerType>(resType);
664 if (!srcPtrTy || !resPtrTy)
665 return emitOpError() << "requires !cir.ptr type for source and result";
666 if (srcPtrTy.getPointee() != resPtrTy.getPointee())
667 return emitOpError() << "requires two types differ in addrspace only";
668 return success();
669 }
670 case cir::CastKind::float_to_complex: {
671 if (!mlir::isa<cir::FPTypeInterface>(srcType))
672 return emitOpError() << "requires !cir.float type for source";
673 auto resComplexTy = mlir::dyn_cast<cir::ComplexType>(resType);
674 if (!resComplexTy)
675 return emitOpError() << "requires !cir.complex type for result";
676 if (srcType != resComplexTy.getElementType())
677 return emitOpError() << "requires source type match result element type";
678 return success();
679 }
680 case cir::CastKind::int_to_complex: {
681 if (!mlir::isa<cir::IntType>(srcType))
682 return emitOpError() << "requires !cir.int type for source";
683 auto resComplexTy = mlir::dyn_cast<cir::ComplexType>(resType);
684 if (!resComplexTy)
685 return emitOpError() << "requires !cir.complex type for result";
686 if (srcType != resComplexTy.getElementType())
687 return emitOpError() << "requires source type match result element type";
688 return success();
689 }
690 case cir::CastKind::float_complex_to_real: {
691 auto srcComplexTy = mlir::dyn_cast<cir::ComplexType>(srcType);
692 if (!srcComplexTy)
693 return emitOpError() << "requires !cir.complex type for source";
694 if (!mlir::isa<cir::FPTypeInterface>(resType))
695 return emitOpError() << "requires !cir.float type for result";
696 if (srcComplexTy.getElementType() != resType)
697 return emitOpError() << "requires source element type match result type";
698 return success();
699 }
700 case cir::CastKind::int_complex_to_real: {
701 auto srcComplexTy = mlir::dyn_cast<cir::ComplexType>(srcType);
702 if (!srcComplexTy)
703 return emitOpError() << "requires !cir.complex type for source";
704 if (!mlir::isa<cir::IntType>(resType))
705 return emitOpError() << "requires !cir.int type for result";
706 if (srcComplexTy.getElementType() != resType)
707 return emitOpError() << "requires source element type match result type";
708 return success();
709 }
710 case cir::CastKind::float_complex_to_bool: {
711 auto srcComplexTy = mlir::dyn_cast<cir::ComplexType>(srcType);
712 if (!srcComplexTy || !srcComplexTy.isFloatingPointComplex())
713 return emitOpError()
714 << "requires floating point !cir.complex type for source";
715 if (!mlir::isa<cir::BoolType>(resType))
716 return emitOpError() << "requires !cir.bool type for result";
717 return success();
718 }
719 case cir::CastKind::int_complex_to_bool: {
720 auto srcComplexTy = mlir::dyn_cast<cir::ComplexType>(srcType);
721 if (!srcComplexTy || !srcComplexTy.isIntegerComplex())
722 return emitOpError()
723 << "requires floating point !cir.complex type for source";
724 if (!mlir::isa<cir::BoolType>(resType))
725 return emitOpError() << "requires !cir.bool type for result";
726 return success();
727 }
728 case cir::CastKind::float_complex: {
729 auto srcComplexTy = mlir::dyn_cast<cir::ComplexType>(srcType);
730 if (!srcComplexTy || !srcComplexTy.isFloatingPointComplex())
731 return emitOpError()
732 << "requires floating point !cir.complex type for source";
733 auto resComplexTy = mlir::dyn_cast<cir::ComplexType>(resType);
734 if (!resComplexTy || !resComplexTy.isFloatingPointComplex())
735 return emitOpError()
736 << "requires floating point !cir.complex type for result";
737 return success();
738 }
739 case cir::CastKind::float_complex_to_int_complex: {
740 auto srcComplexTy = mlir::dyn_cast<cir::ComplexType>(srcType);
741 if (!srcComplexTy || !srcComplexTy.isFloatingPointComplex())
742 return emitOpError()
743 << "requires floating point !cir.complex type for source";
744 auto resComplexTy = mlir::dyn_cast<cir::ComplexType>(resType);
745 if (!resComplexTy || !resComplexTy.isIntegerComplex())
746 return emitOpError() << "requires integer !cir.complex type for result";
747 return success();
748 }
749 case cir::CastKind::int_complex: {
750 auto srcComplexTy = mlir::dyn_cast<cir::ComplexType>(srcType);
751 if (!srcComplexTy || !srcComplexTy.isIntegerComplex())
752 return emitOpError() << "requires integer !cir.complex type for source";
753 auto resComplexTy = mlir::dyn_cast<cir::ComplexType>(resType);
754 if (!resComplexTy || !resComplexTy.isIntegerComplex())
755 return emitOpError() << "requires integer !cir.complex type for result";
756 return success();
757 }
758 case cir::CastKind::int_complex_to_float_complex: {
759 auto srcComplexTy = mlir::dyn_cast<cir::ComplexType>(srcType);
760 if (!srcComplexTy || !srcComplexTy.isIntegerComplex())
761 return emitOpError() << "requires integer !cir.complex type for source";
762 auto resComplexTy = mlir::dyn_cast<cir::ComplexType>(resType);
763 if (!resComplexTy || !resComplexTy.isFloatingPointComplex())
764 return emitOpError()
765 << "requires floating point !cir.complex type for result";
766 return success();
767 }
768 case cir::CastKind::member_ptr_to_bool: {
769 if (!mlir::isa<cir::DataMemberType, cir::MethodType>(srcType))
770 return emitOpError()
771 << "requires !cir.data_member or !cir.method type for source";
772 if (!mlir::isa<cir::BoolType>(resType))
773 return emitOpError() << "requires !cir.bool type for result";
774 return success();
775 }
776 }
777 llvm_unreachable("Unknown CastOp kind?");
778}
779
780static bool isIntOrBoolCast(cir::CastOp op) {
781 auto kind = op.getKind();
782 return kind == cir::CastKind::bool_to_int ||
783 kind == cir::CastKind::int_to_bool || kind == cir::CastKind::integral;
784}
785
786static bool isCirFunctionPointerType(mlir::Type ty) {
787 const auto ptrTy = mlir::dyn_cast<cir::PointerType>(ty);
788 return ptrTy && mlir::isa<cir::FuncType>(ptrTy.getPointee());
789}
790
791static Value tryFoldCastChain(cir::CastOp op) {
792 cir::CastOp head = op, tail = op;
793
794 while (op) {
795 if (!isIntOrBoolCast(op))
796 break;
797 head = op;
798 op = head.getSrc().getDefiningOp<cir::CastOp>();
799 }
800
801 if (head != tail) {
802 // if bool_to_int -> ... -> int_to_bool: take the bool
803 // as we had it was before all casts
804 if (head.getKind() == cir::CastKind::bool_to_int &&
805 tail.getKind() == cir::CastKind::int_to_bool)
806 return head.getSrc();
807
808 // if int_to_bool -> ... -> int_to_bool: take the result
809 // of the first one, as no other casts (and ext casts as well)
810 // don't change the first result
811 if (head.getKind() == cir::CastKind::int_to_bool &&
812 tail.getKind() == cir::CastKind::int_to_bool)
813 return head.getResult();
814
815 return {};
816 }
817
818 // Bitcast round-trip on function pointers: T0 -> T1 -> T0 (e.g. no-proto
819 // redeclaration vs. actual prototype). Restrict to function pointers so
820 // other pointer bitcast chains are unchanged.
821 if (tail.getKind() == cir::CastKind::bitcast) {
822 auto *inner = tail.getSrc().getDefiningOp();
823 if (inner && isCirFunctionPointerType(tail.getType())) {
824 auto innerCast = mlir::dyn_cast<cir::CastOp>(inner);
825 if (innerCast && innerCast.getKind() == cir::CastKind::bitcast &&
826 innerCast.getSrc().getType() == tail.getType() &&
827 innerCast.getType() == tail.getSrc().getType()) {
828 return innerCast.getSrc();
829 }
830 }
831 }
832
833 return {};
834}
835
836OpFoldResult cir::CastOp::fold(FoldAdaptor adaptor) {
837 if (mlir::isa_and_present<cir::PoisonAttr>(adaptor.getSrc())) {
838 // Propagate poison value
839 return cir::PoisonAttr::get(getContext(), getType());
840 }
841
842 if (getSrc().getType() == getType()) {
843 switch (getKind()) {
844 case cir::CastKind::integral: {
846 auto foldOrder = getSrc().getDefiningOp()->fold(foldResults);
847 if (foldOrder.succeeded() && mlir::isa<mlir::Attribute>(foldResults[0]))
848 return mlir::cast<mlir::Attribute>(foldResults[0]);
849 return {};
850 }
851 case cir::CastKind::bitcast:
852 case cir::CastKind::address_space:
853 case cir::CastKind::float_complex:
854 case cir::CastKind::int_complex: {
855 return getSrc();
856 }
857 default:
858 return {};
859 }
860 }
861
862 // Handle cases where a chain of casts cancel out.
863 Value result = tryFoldCastChain(*this);
864 if (result)
865 return result;
866
867 // Handle simple constant casts.
868 if (auto srcConst = getSrc().getDefiningOp<cir::ConstantOp>()) {
869 switch (getKind()) {
870 case cir::CastKind::integral: {
871 mlir::Type srcTy = getSrc().getType();
872 // Don't try to fold vector casts for now.
873 assert(mlir::isa<cir::VectorType>(srcTy) ==
874 mlir::isa<cir::VectorType>(getType()));
875 if (mlir::isa<cir::VectorType>(srcTy))
876 break;
877
878 auto srcIntTy = mlir::cast<cir::IntType>(srcTy);
879 auto dstIntTy = mlir::cast<cir::IntType>(getType());
880 APInt newVal =
881 srcIntTy.isSigned()
882 ? srcConst.getIntValue().sextOrTrunc(dstIntTy.getWidth())
883 : srcConst.getIntValue().zextOrTrunc(dstIntTy.getWidth());
884 return cir::IntAttr::get(dstIntTy, newVal);
885 }
886 default:
887 break;
888 }
889 }
890 return {};
891}
892
893//===----------------------------------------------------------------------===//
894// CallOp
895//===----------------------------------------------------------------------===//
896
897mlir::OperandRange cir::CallOp::getArgOperands() {
898 if (isIndirect())
899 return getArgs().drop_front(1);
900 return getArgs();
901}
902
903mlir::MutableOperandRange cir::CallOp::getArgOperandsMutable() {
904 mlir::MutableOperandRange args = getArgsMutable();
905 if (isIndirect())
906 return args.slice(1, args.size() - 1);
907 return args;
908}
909
910mlir::Value cir::CallOp::getIndirectCall() {
911 assert(isIndirect());
912 return getOperand(0);
913}
914
915/// Return the operand at index 'i'.
916Value cir::CallOp::getArgOperand(unsigned i) {
917 if (isIndirect())
918 ++i;
919 return getOperand(i);
920}
921
922/// Return the number of operands.
923unsigned cir::CallOp::getNumArgOperands() {
924 if (isIndirect())
925 return this->getOperation()->getNumOperands() - 1;
926 return this->getOperation()->getNumOperands();
927}
928
929static mlir::ParseResult
930parseTryCallDestinations(mlir::OpAsmParser &parser,
931 mlir::OperationState &result) {
932 mlir::Block *normalDestSuccessor;
933 if (parser.parseSuccessor(normalDestSuccessor))
934 return mlir::failure();
935
936 if (parser.parseComma())
937 return mlir::failure();
938
939 mlir::Block *unwindDestSuccessor;
940 if (parser.parseSuccessor(unwindDestSuccessor))
941 return mlir::failure();
942
943 result.addSuccessors(normalDestSuccessor);
944 result.addSuccessors(unwindDestSuccessor);
945 return mlir::success();
946}
947
/// Shared custom parser for `cir.call` and `cir.try_call`.
///
/// Accepted syntax (bracketed pieces are optional, and the keywords must
/// appear in this order):
///   (@callee | %indirect_callee) `(` operands `)` [^normal `,` ^unwind]
///   [`musttail`] [`nothrow`] [`side_effect` `(` kind `)`] [attr-dict]
///   `:` function-signature
///
/// The destination pair is only parsed when \p hasDestinationBlocks is true.
/// For an indirect call the callee value is placed first in the operand list,
/// so it is resolved against the first type of the parsed signature. The
/// signature may declare at most one result.
static mlir::ParseResult parseCallCommon(mlir::OpAsmParser &parser,
                                         mlir::OperationState &result,
                                         bool hasDestinationBlocks = false) {
  llvm::SMLoc opsLoc;
  mlir::FlatSymbolRefAttr calleeAttr;

  // If we cannot parse a string callee, it means this is an indirect call.
  // NOTE(review): when a callee attribute is present but fails to parse
  // (has_value() with a failed result), that failure is not returned here.
  if (!parser
           .parseOptionalAttribute(calleeAttr, CIRDialect::getCalleeAttrName(),
                                   result.attributes)
           .has_value()) {
    OpAsmParser::UnresolvedOperand indirectVal;
    // Do not resolve right now, since we need to figure out the type
    if (parser.parseOperand(indirectVal).failed())
      return failure();
    ops.push_back(indirectVal);
  }

  if (parser.parseLParen())
    return mlir::failure();

  // Remember where the argument list starts so operand resolution errors point
  // at it.
  opsLoc = parser.getCurrentLocation();
  if (parser.parseOperandList(ops))
    return mlir::failure();
  if (parser.parseRParen())
    return mlir::failure();

  if (hasDestinationBlocks &&
      parseTryCallDestinations(parser, result).failed()) {
    return ::mlir::failure();
  }

  // Optional unit-attribute keywords.
  if (parser.parseOptionalKeyword("musttail").succeeded())
    result.addAttribute(CIRDialect::getMustTailAttrName(),
                        mlir::UnitAttr::get(parser.getContext()));

  if (parser.parseOptionalKeyword("nothrow").succeeded())
    result.addAttribute(CIRDialect::getNoThrowAttrName(),
                        mlir::UnitAttr::get(parser.getContext()));

  // Optional `side_effect(<kind>)` clause.
  if (parser.parseOptionalKeyword("side_effect").succeeded()) {
    if (parser.parseLParen().failed())
      return failure();
    cir::SideEffect sideEffect;
    if (parseCIRKeyword<cir::SideEffect>(parser, sideEffect).failed())
      return failure();
    if (parser.parseRParen().failed())
      return failure();
    auto attr = cir::SideEffectAttr::get(parser.getContext(), sideEffect);
    result.addAttribute(CIRDialect::getSideEffectAttrName(), attr);
  }

  if (parser.parseOptionalAttrDict(result.attributes))
    return ::mlir::failure();

  if (parser.parseColon())
    return ::mlir::failure();

  // The signature supplies operand types (including the callee type for
  // indirect calls), result types, and per-argument/per-result attributes.
  SmallVector<Type> argTypes;
  SmallVector<Type> resultTypes;
  SmallVector<DictionaryAttr> resultAttrs;
  if (call_interface_impl::parseFunctionSignature(parser, argTypes, argAttrs,
                                                  resultTypes, resultAttrs))
    return mlir::failure();

  if (resultTypes.size() > 1 || resultAttrs.size() > 1)
    return parser.emitError(
        parser.getCurrentLocation(),
        "functions with multiple return types are not supported");

  result.addTypes(resultTypes);

  if (parser.resolveOperands(ops, argTypes, opsLoc, result.operands))
    return mlir::failure();

  if (!resultAttrs.empty() && resultAttrs[0])
    result.addAttribute(
        CIRDialect::getResAttrsAttrName(),
        mlir::ArrayAttr::get(parser.getContext(), {resultAttrs[0]}));

  // ArrayAttr requires a vector of 'Attribute', so we have to do the conversion
  // here into a separate collection.
  llvm::SmallVector<Attribute> convertedArgAttrs;
  bool argAttrsEmpty = true;

  // Convert while recording whether any entry is non-null.
  llvm::transform(argAttrs, std::back_inserter(convertedArgAttrs),
                  [&](DictionaryAttr da) -> mlir::Attribute {
                    if (da)
                      argAttrsEmpty = false;
                    return da;
                  });

  if (!argAttrsEmpty) {
    llvm::ArrayRef argAttrsRef = convertedArgAttrs;
    if (!calleeAttr) {
      // Fixup for indirect calls, which get an extra entry in the 'args' for
      // the indirect type, which doesn't get attributes.
      argAttrsRef = argAttrsRef.drop_front();
    }
    result.addAttribute(CIRDialect::getArgAttrsAttrName(),
                        mlir::ArrayAttr::get(parser.getContext(), argAttrsRef));
  }

  return mlir::success();
}
1055
/// Shared custom printer for `cir.call` and `cir.try_call`; the inverse of
/// parseCallCommon.
///
/// Prints the callee (symbol for direct calls, value for indirect calls), the
/// argument list, the `^normal, ^unwind` pair when \p normalDest is given
/// (only valid for cir::TryCallOp), the `musttail` / `nothrow` /
/// `side_effect(...)` keywords, the remaining attributes, and the signature.
static void
printCallCommon(mlir::Operation *op, mlir::FlatSymbolRefAttr calleeSym,
                mlir::Value indirectCallee, mlir::OpAsmPrinter &printer,
                bool isNothrow, cir::SideEffect sideEffect, ArrayAttr argAttrs,
                ArrayAttr resAttrs, mlir::Block *normalDest = nullptr,
                mlir::Block *unwindDest = nullptr) {
  printer << ' ';

  auto callLikeOp = mlir::cast<cir::CIRCallOpInterface>(op);
  // Arguments only: the indirect callee operand is printed separately below.
  auto ops = callLikeOp.getArgOperands();

  if (calleeSym) {
    // Direct calls
    printer.printAttributeWithoutType(calleeSym);
  } else {
    // Indirect calls
    assert(indirectCallee);
    printer << indirectCallee;
  }

  printer << "(" << ops << ")";

  if (normalDest) {
    assert(unwindDest && "expected two successors");
    auto tryCall = cast<cir::TryCallOp>(op);
    printer << ' ' << tryCall.getNormalDest();
    printer << ",";
    printer << ' ';
    printer << tryCall.getUnwindDest();
  }

  if (op->hasAttr(CIRDialect::getMustTailAttrName()))
    printer << " musttail";

  if (isNothrow)
    printer << " nothrow";

  // `All` is the implicit default and is not spelled out.
  if (sideEffect != cir::SideEffect::All) {
    printer << " side_effect(";
    printer << stringifySideEffect(sideEffect);
    printer << ")";
  }

      CIRDialect::getCalleeAttrName(),
      CIRDialect::getMustTailAttrName(),
      CIRDialect::getNoThrowAttrName(),
      CIRDialect::getSideEffectAttrName(),
      CIRDialect::getOperandSegmentSizesAttrName(),
      llvm::StringRef("res_attrs"),
      llvm::StringRef("arg_attrs")};
  // Attributes already expressed by the custom syntax above or by the
  // signature below are elided from the trailing attribute dictionary.
  printer.printOptionalAttrDict(op->getAttrs(), elidedAttrs);
  printer << " : ";
  if (calleeSym || !argAttrs) {
    call_interface_impl::printFunctionSignature(
        printer, op->getOperands().getTypes(), argAttrs,
        /*isVariadic=*/false, op->getResultTypes(), resAttrs);
  } else {
    // indirect function calls use an 'arg' type for the type of its indirect
    // argument. However, we don't store a similar attribute collection. In
    // order to make `printFunctionSignature` have the attributes line up, we
    // have to make a 'shimmed' copy of the attributes that have a blank set of
    // attributes for the indirect argument.
    llvm::SmallVector<Attribute> shimmedArgAttrs;
    shimmedArgAttrs.push_back(mlir::DictionaryAttr::get(op->getContext(), {}));
    shimmedArgAttrs.append(argAttrs.begin(), argAttrs.end());
    call_interface_impl::printFunctionSignature(
        printer, op->getOperands().getTypes(),
        mlir::ArrayAttr::get(op->getContext(), shimmedArgAttrs),
        /*isVariadic=*/false, op->getResultTypes(), resAttrs);
  }
}
1128
1129mlir::ParseResult cir::CallOp::parse(mlir::OpAsmParser &parser,
1130 mlir::OperationState &result) {
1131 return parseCallCommon(parser, result);
1132}
1133
1134void cir::CallOp::print(mlir::OpAsmPrinter &p) {
1135 mlir::Value indirectCallee = isIndirect() ? getIndirectCall() : nullptr;
1136 cir::SideEffect sideEffect = getSideEffect();
1137 printCallCommon(*this, getCalleeAttr(), indirectCallee, p, getNothrow(),
1138 sideEffect, getArgAttrsAttr(), getResAttrsAttr());
1139}
1140
/// Verify a call-like op against the `cir.func` named by its callee symbol.
///
/// Ops without a callee symbol attribute (indirect calls) are accepted as-is.
/// Otherwise the symbol must resolve to a cir::FuncOp. Unless that function is
/// marked no-proto, the argument count and the types of the declared
/// parameters are checked. The result count and result type are always
/// checked against the callee's return type.
static LogicalResult
verifyCallCommInSymbolUses(mlir::Operation *op,
                           SymbolTableCollection &symbolTable) {
  auto fnAttr =
      op->getAttrOfType<FlatSymbolRefAttr>(CIRDialect::getCalleeAttrName());
  if (!fnAttr) {
    // This is an indirect call, thus we don't have to check the symbol uses.
    return mlir::success();
  }

  auto fn = symbolTable.lookupNearestSymbolFrom<cir::FuncOp>(op, fnAttr);
  if (!fn)
    return op->emitOpError() << "'" << fnAttr.getValue()
                             << "' does not reference a valid function";

  auto callIf = dyn_cast<cir::CIRCallOpInterface>(op);
  assert(callIf && "expected CIR call interface to be always available");

  // Verify that the operand and result types match the callee. Note that
  // argument-checking is disabled for functions without a prototype.
  auto fnType = fn.getFunctionType();
  if (!fn.getNoProto()) {
    unsigned numCallOperands = callIf.getNumArgOperands();
    unsigned numFnOpOperands = fnType.getNumInputs();

    // Variadic callees only require at least the declared parameters.
    if (!fnType.isVarArg() && numCallOperands != numFnOpOperands)
      return op->emitOpError("incorrect number of operands for callee");
    if (fnType.isVarArg() && numCallOperands < numFnOpOperands)
      return op->emitOpError("too few operands for callee");

    // Only the declared parameters are type-checked; extra variadic arguments
    // are not.
    // NOTE(review): the diagnostic below reads op->getOperand(i) instead of
    // callIf.getArgOperand(i). The two presumably coincide here because calls
    // without a callee symbol returned early above.
    for (unsigned i = 0, e = numFnOpOperands; i != e; ++i)
      if (callIf.getArgOperand(i).getType() != fnType.getInput(i))
        return op->emitOpError("operand type mismatch: expected operand type ")
               << fnType.getInput(i) << ", but provided "
               << op->getOperand(i).getType() << " for operand number " << i;
  }


  // Void function must not return any results.
  if (fnType.hasVoidReturn() && op->getNumResults() != 0)
    return op->emitOpError("callee returns void but call has results");

  // Non-void function calls must return exactly one result.
  if (!fnType.hasVoidReturn() && op->getNumResults() != 1)
    return op->emitOpError("incorrect number of results for callee");

  // Parent function and return value types must match.
  if (!fnType.hasVoidReturn() &&
      op->getResultTypes().front() != fnType.getReturnType()) {
    return op->emitOpError("result type mismatch: expected ")
           << fnType.getReturnType() << ", but provided "
           << op->getResult(0).getType();
  }

  return mlir::success();
}
1198
1199LogicalResult
1200cir::CallOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1201 return verifyCallCommInSymbolUses(*this, symbolTable);
1202}
1203
1204//===----------------------------------------------------------------------===//
1205// TryCallOp
1206//===----------------------------------------------------------------------===//
1207
1208mlir::OperandRange cir::TryCallOp::getArgOperands() {
1209 if (isIndirect())
1210 return getArgs().drop_front(1);
1211 return getArgs();
1212}
1213
1214mlir::MutableOperandRange cir::TryCallOp::getArgOperandsMutable() {
1215 mlir::MutableOperandRange args = getArgsMutable();
1216 if (isIndirect())
1217 return args.slice(1, args.size() - 1);
1218 return args;
1219}
1220
1221mlir::Value cir::TryCallOp::getIndirectCall() {
1222 assert(isIndirect());
1223 return getOperand(0);
1224}
1225
1226/// Return the operand at index 'i'.
1227Value cir::TryCallOp::getArgOperand(unsigned i) {
1228 if (isIndirect())
1229 ++i;
1230 return getOperand(i);
1231}
1232
1233/// Return the number of operands.
1234unsigned cir::TryCallOp::getNumArgOperands() {
1235 if (isIndirect())
1236 return this->getOperation()->getNumOperands() - 1;
1237 return this->getOperation()->getNumOperands();
1238}
1239
1240LogicalResult
1241cir::TryCallOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1242 return verifyCallCommInSymbolUses(*this, symbolTable);
1243}
1244
1245mlir::ParseResult cir::TryCallOp::parse(mlir::OpAsmParser &parser,
1246 mlir::OperationState &result) {
1247 return parseCallCommon(parser, result, /*hasDestinationBlocks=*/true);
1248}
1249
1250void cir::TryCallOp::print(::mlir::OpAsmPrinter &p) {
1251 mlir::Value indirectCallee = isIndirect() ? getIndirectCall() : nullptr;
1252 cir::SideEffect sideEffect = getSideEffect();
1253 printCallCommon(*this, getCalleeAttr(), indirectCallee, p, getNothrow(),
1254 sideEffect, getArgAttrsAttr(), getResAttrsAttr(),
1255 getNormalDest(), getUnwindDest());
1256}
1257
1258//===----------------------------------------------------------------------===//
1259// ReturnOp
1260//===----------------------------------------------------------------------===//
1261
1262static mlir::LogicalResult checkReturnAndFunction(cir::ReturnOp op,
1263 cir::FuncOp function) {
1264 // ReturnOps currently only have a single optional operand.
1265 if (op.getNumOperands() > 1)
1266 return op.emitOpError() << "expects at most 1 return operand";
1267
1268 // Ensure returned type matches the function signature.
1269 auto expectedTy = function.getFunctionType().getReturnType();
1270 auto actualTy =
1271 (op.getNumOperands() == 0 ? cir::VoidType::get(op.getContext())
1272 : op.getOperand(0).getType());
1273 if (actualTy != expectedTy)
1274 return op.emitOpError() << "returns " << actualTy
1275 << " but enclosing function returns " << expectedTy;
1276
1277 return mlir::success();
1278}
1279
1280mlir::LogicalResult cir::ReturnOp::verify() {
1281 // Returns can be present in multiple different scopes, get the
1282 // wrapping function and start from there.
1283 auto *fnOp = getOperation()->getParentOp();
1284 while (!isa<cir::FuncOp>(fnOp))
1285 fnOp = fnOp->getParentOp();
1286
1287 // Make sure return types match function return type.
1288 if (checkReturnAndFunction(*this, cast<cir::FuncOp>(fnOp)).failed())
1289 return failure();
1290
1291 return success();
1292}
1293
1294//===----------------------------------------------------------------------===//
1295// IfOp
1296//===----------------------------------------------------------------------===//
1297
1298ParseResult cir::IfOp::parse(OpAsmParser &parser, OperationState &result) {
1299 // create the regions for 'then'.
1300 result.regions.reserve(2);
1301 Region *thenRegion = result.addRegion();
1302 Region *elseRegion = result.addRegion();
1303
1304 mlir::Builder &builder = parser.getBuilder();
1305 OpAsmParser::UnresolvedOperand cond;
1306 Type boolType = cir::BoolType::get(builder.getContext());
1307
1308 if (parser.parseOperand(cond) ||
1309 parser.resolveOperand(cond, boolType, result.operands))
1310 return failure();
1311
1312 // Parse 'then' region.
1313 mlir::SMLoc parseThenLoc = parser.getCurrentLocation();
1314 if (parser.parseRegion(*thenRegion, /*arguments=*/{}, /*argTypes=*/{}))
1315 return failure();
1316
1317 if (ensureRegionTerm(parser, *thenRegion, parseThenLoc).failed())
1318 return failure();
1319
1320 // If we find an 'else' keyword, parse the 'else' region.
1321 if (!parser.parseOptionalKeyword("else")) {
1322 mlir::SMLoc parseElseLoc = parser.getCurrentLocation();
1323 if (parser.parseRegion(*elseRegion, /*arguments=*/{}, /*argTypes=*/{}))
1324 return failure();
1325 if (ensureRegionTerm(parser, *elseRegion, parseElseLoc).failed())
1326 return failure();
1327 }
1328
1329 // Parse the optional attribute list.
1330 if (parser.parseOptionalAttrDict(result.attributes))
1331 return failure();
1332 return success();
1333}
1334
1335void cir::IfOp::print(OpAsmPrinter &p) {
1336 p << " " << getCondition() << " ";
1337 mlir::Region &thenRegion = this->getThenRegion();
1338 p.printRegion(thenRegion,
1339 /*printEntryBlockArgs=*/false,
1340 /*printBlockTerminators=*/!omitRegionTerm(thenRegion));
1341
1342 // Print the 'else' regions if it exists and has a block.
1343 mlir::Region &elseRegion = this->getElseRegion();
1344 if (!elseRegion.empty()) {
1345 p << " else ";
1346 p.printRegion(elseRegion,
1347 /*printEntryBlockArgs=*/false,
1348 /*printBlockTerminators=*/!omitRegionTerm(elseRegion));
1349 }
1350
1351 p.printOptionalAttrDict(getOperation()->getAttrs());
1352}
1353
/// Default callback for IfOp builders.
///
/// Emits a bare `cir.yield` (no operands) at the builder's current insertion
/// point, producing a region body that only terminates.
void cir::buildTerminatedBody(OpBuilder &builder, Location loc) {
  // add cir.yield to end of the block
  cir::YieldOp::create(builder, loc);
}
1359
1360/// Given the region at `index`, or the parent operation if `index` is None,
1361/// return the successor regions. These are the regions that may be selected
1362/// during the flow of control. `operands` is a set of optional attributes that
1363/// correspond to a constant value for each operand, or null if that operand is
1364/// not a constant.
1365void cir::IfOp::getSuccessorRegions(mlir::RegionBranchPoint point,
1366 SmallVectorImpl<RegionSuccessor> &regions) {
1367 // The `then` and the `else` region branch back to the parent operation.
1368 if (!point.isParent()) {
1369 regions.push_back(RegionSuccessor::parent());
1370 return;
1371 }
1372
1373 // Don't consider the else region if it is empty.
1374 Region *elseRegion = &this->getElseRegion();
1375 if (elseRegion->empty())
1376 elseRegion = nullptr;
1377
1378 // If the condition isn't constant, both regions may be executed.
1379 regions.push_back(RegionSuccessor(&getThenRegion()));
1380 if (elseRegion)
1381 regions.push_back(RegionSuccessor(elseRegion));
1382 else
1383 regions.push_back(RegionSuccessor::parent());
1384}
1385
1386mlir::ValueRange cir::IfOp::getSuccessorInputs(RegionSuccessor successor) {
1387 return successor.isParent() ? ValueRange(getOperation()->getResults())
1388 : ValueRange();
1389}
1390
1391void cir::IfOp::build(OpBuilder &builder, OperationState &result, Value cond,
1392 bool withElseRegion, BuilderCallbackRef thenBuilder,
1393 BuilderCallbackRef elseBuilder) {
1394 assert(thenBuilder && "the builder callback for 'then' must be present");
1395 result.addOperands(cond);
1396
1397 OpBuilder::InsertionGuard guard(builder);
1398 Region *thenRegion = result.addRegion();
1399 builder.createBlock(thenRegion);
1400 thenBuilder(builder, result.location);
1401
1402 Region *elseRegion = result.addRegion();
1403 if (!withElseRegion)
1404 return;
1405
1406 builder.createBlock(elseRegion);
1407 elseBuilder(builder, result.location);
1408}
1409
1410//===----------------------------------------------------------------------===//
1411// ScopeOp
1412//===----------------------------------------------------------------------===//
1413
1414/// Given the region at `index`, or the parent operation if `index` is None,
1415/// return the successor regions. These are the regions that may be selected
1416/// during the flow of control. `operands` is a set of optional attributes
1417/// that correspond to a constant value for each operand, or null if that
1418/// operand is not a constant.
1419void cir::ScopeOp::getSuccessorRegions(
1420 mlir::RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
1421 // The only region always branch back to the parent operation.
1422 if (!point.isParent()) {
1423 regions.push_back(RegionSuccessor::parent());
1424 return;
1425 }
1426
1427 // If the condition isn't constant, both regions may be executed.
1428 regions.push_back(RegionSuccessor(&getScopeRegion()));
1429}
1430
1431mlir::ValueRange cir::ScopeOp::getSuccessorInputs(RegionSuccessor successor) {
1432 return successor.isParent() ? ValueRange(getOperation()->getResults())
1433 : ValueRange();
1434}
1435
/// Build a `cir.scope` whose single region is populated by \p scopeBuilder.
///
/// The callback receives a `Type &` initialized to null. If it sets that type,
/// the type becomes the scope's single result type; otherwise the scope has no
/// results.
void cir::ScopeOp::build(
    OpBuilder &builder, OperationState &result,
    function_ref<void(OpBuilder &, Type &, Location)> scopeBuilder) {
  // NOTE(review): the assertion text mentions 'then', although a scope has no
  // then-region.
  assert(scopeBuilder && "the builder callback for 'then' must be present");

  // Restore the caller's insertion point once the body has been built.
  OpBuilder::InsertionGuard guard(builder);
  Region *scopeRegion = result.addRegion();
  builder.createBlock(scopeRegion);

  mlir::Type yieldTy;
  scopeBuilder(builder, yieldTy, result.location);

  if (yieldTy)
    result.addTypes(TypeRange{yieldTy});
}
1452
/// Build a result-less `cir.scope` whose single region is populated by
/// \p scopeBuilder.
void cir::ScopeOp::build(
    OpBuilder &builder, OperationState &result,
    function_ref<void(OpBuilder &, Location)> scopeBuilder) {
  // NOTE(review): the assertion text mentions 'then', although a scope has no
  // then-region.
  assert(scopeBuilder && "the builder callback for 'then' must be present");
  // Restore the caller's insertion point once the body has been built.
  OpBuilder::InsertionGuard guard(builder);
  Region *scopeRegion = result.addRegion();
  builder.createBlock(scopeRegion);
  scopeBuilder(builder, result.location);
}
1463
1464LogicalResult cir::ScopeOp::verify() {
1465 if (getRegion().empty()) {
1466 return emitOpError() << "cir.scope must not be empty since it should "
1467 "include at least an implicit cir.yield ";
1468 }
1469
1470 mlir::Block &lastBlock = getRegion().back();
1471 if (lastBlock.empty() || !lastBlock.mightHaveTerminator() ||
1472 !lastBlock.getTerminator()->hasTrait<OpTrait::IsTerminator>())
1473 return emitOpError() << "last block of cir.scope must be terminated";
1474 return success();
1475}
1476
1477LogicalResult cir::ScopeOp::fold(FoldAdaptor /*adaptor*/,
1478 SmallVectorImpl<OpFoldResult> &results) {
1479 // Only fold "trivial" scopes: a single block containing only a `cir.yield`.
1480 if (!getRegion().hasOneBlock())
1481 return failure();
1482 Block &block = getRegion().front();
1483 if (block.getOperations().size() != 1)
1484 return failure();
1485
1486 auto yield = dyn_cast<cir::YieldOp>(block.front());
1487 if (!yield)
1488 return failure();
1489
1490 // Only fold when the scope produces a value.
1491 if (getNumResults() != 1 || yield.getNumOperands() != 1)
1492 return failure();
1493
1494 results.push_back(yield.getOperand(0));
1495 return success();
1496}
1497
1498//===----------------------------------------------------------------------===//
1499// CleanupScopeOp
1500//===----------------------------------------------------------------------===//
1501
1502void cir::CleanupScopeOp::getSuccessorRegions(
1503 mlir::RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
1504 if (!point.isParent()) {
1505 regions.push_back(RegionSuccessor::parent());
1506 return;
1507 }
1508
1509 // Execution always proceeds from the body region to the cleanup region.
1510 regions.push_back(RegionSuccessor(&getBodyRegion()));
1511 regions.push_back(RegionSuccessor(&getCleanupRegion()));
1512}
1513
1514mlir::ValueRange
1515cir::CleanupScopeOp::getSuccessorInputs(RegionSuccessor successor) {
1516 return ValueRange();
1517}
1518
1519LogicalResult cir::CleanupScopeOp::canonicalize(CleanupScopeOp op,
1520 PatternRewriter &rewriter) {
1521 auto isRegionTrivial = [](Region &region) {
1522 assert(!region.empty() && "CleanupScopeOp regions must not be empty");
1523 if (!region.hasOneBlock())
1524 return false;
1525 Block &block = llvm::getSingleElement(region);
1526 return llvm::hasSingleElement(block) &&
1527 isa<cir::YieldOp>(llvm::getSingleElement(block));
1528 };
1529
1530 Region &body = op.getBodyRegion();
1531 Region &cleanup = op.getCleanupRegion();
1532
1533 // An EH-only cleanup scope with an empty body can never trigger its cleanup
1534 // region — there are no operations in the body that could throw. Erase it.
1535 if (op.getCleanupKind() == CleanupKind::EH && isRegionTrivial(body)) {
1536 rewriter.eraseOp(op);
1537 return success();
1538 }
1539
1540 // A cleanup scope with a trivial cleanup region has no cleanup to perform.
1541 // Inline the body into the parent block and erase the scope.
1542 if (!isRegionTrivial(cleanup) || !body.hasOneBlock())
1543 return failure();
1544
1545 Block &bodyBlock = body.front();
1546 if (!isa<cir::YieldOp>(bodyBlock.getTerminator()))
1547 return failure();
1548
1549 Operation *yield = bodyBlock.getTerminator();
1550 rewriter.inlineBlockBefore(&bodyBlock, op);
1551 rewriter.eraseOp(yield);
1552 rewriter.eraseOp(op);
1553 return success();
1554}
1555
1556void cir::CleanupScopeOp::build(
1557 OpBuilder &builder, OperationState &result, CleanupKind cleanupKind,
1558 function_ref<void(OpBuilder &, Location)> bodyBuilder,
1559 function_ref<void(OpBuilder &, Location)> cleanupBuilder) {
1560 result.addAttribute(getCleanupKindAttrName(result.name),
1561 CleanupKindAttr::get(builder.getContext(), cleanupKind));
1562
1563 OpBuilder::InsertionGuard guard(builder);
1564
1565 // Build body region.
1566 Region *bodyRegion = result.addRegion();
1567 builder.createBlock(bodyRegion);
1568 if (bodyBuilder)
1569 bodyBuilder(builder, result.location);
1570
1571 // Build cleanup region.
1572 Region *cleanupRegion = result.addRegion();
1573 builder.createBlock(cleanupRegion);
1574 if (cleanupBuilder)
1575 cleanupBuilder(builder, result.location);
1576}
1577
1578//===----------------------------------------------------------------------===//
1579// BrOp
1580//===----------------------------------------------------------------------===//
1581
1582/// Merges blocks connected by a unique unconditional branch.
1583///
1584/// ^bb0: ^bb0:
1585/// ... ...
1586/// cir.br ^bb1 => ...
1587/// ^bb1: cir.return
1588/// ...
1589/// cir.return
1590LogicalResult cir::BrOp::canonicalize(BrOp op, PatternRewriter &rewriter) {
1591 Block *src = op->getBlock();
1592 Block *dst = op.getDest();
1593
1594 // Do not fold self-loops.
1595 if (src == dst)
1596 return failure();
1597
1598 // Only merge when this is the unique edge between the blocks.
1599 if (src->getNumSuccessors() != 1 || dst->getSinglePredecessor() != src)
1600 return failure();
1601
1602 // Don't merge blocks that start with LabelOp or IndirectBrOp.
1603 // This is to avoid merging blocks that have an indirect predecessor.
1604 if (isa<cir::LabelOp, cir::IndirectBrOp>(dst->front()))
1605 return failure();
1606
1607 auto operands = op.getDestOperands();
1608 rewriter.eraseOp(op);
1609 rewriter.mergeBlocks(dst, src, operands);
1610 return success();
1611}
1612
1613mlir::SuccessorOperands cir::BrOp::getSuccessorOperands(unsigned index) {
1614 assert(index == 0 && "invalid successor index");
1615 return mlir::SuccessorOperands(getDestOperandsMutable());
1616}
1617
1618Block *cir::BrOp::getSuccessorForOperands(ArrayRef<Attribute>) {
1619 return getDest();
1620}
1621
1622//===----------------------------------------------------------------------===//
// IndirectBrOp
1624//===----------------------------------------------------------------------===//
1625
1626mlir::SuccessorOperands
1627cir::IndirectBrOp::getSuccessorOperands(unsigned index) {
1628 assert(index < getNumSuccessors() && "invalid successor index");
1629 return mlir::SuccessorOperands(getSuccOperandsMutable()[index]);
1630}
1631
1633 OpAsmParser &parser, Type &flagType,
1634 SmallVectorImpl<Block *> &succOperandBlocks,
1635 SmallVectorImpl<SmallVector<OpAsmParser::UnresolvedOperand>> &succOperands,
1636 SmallVectorImpl<SmallVector<Type>> &succOperandsTypes) {
1637 if (failed(parser.parseCommaSeparatedList(
1638 OpAsmParser::Delimiter::Square,
1639 [&]() {
1640 Block *destination = nullptr;
1641 SmallVector<OpAsmParser::UnresolvedOperand> operands;
1642 SmallVector<Type> operandTypes;
1643
1644 if (parser.parseSuccessor(destination).failed())
1645 return failure();
1646
1647 if (succeeded(parser.parseOptionalLParen())) {
1648 if (failed(parser.parseOperandList(
1649 operands, OpAsmParser::Delimiter::None)) ||
1650 failed(parser.parseColonTypeList(operandTypes)) ||
1651 failed(parser.parseRParen()))
1652 return failure();
1653 }
1654 succOperandBlocks.push_back(destination);
1655 succOperands.emplace_back(operands);
1656 succOperandsTypes.emplace_back(operandTypes);
1657 return success();
1658 },
1659 "successor blocks")))
1660 return failure();
1661 return success();
1662}
1663
1664void printIndirectBrOpSucessors(OpAsmPrinter &p, cir::IndirectBrOp op,
1665 Type flagType, SuccessorRange succs,
1666 OperandRangeRange succOperands,
1667 const TypeRangeRange &succOperandsTypes) {
1668 p << "[";
1669 llvm::interleave(
1670 llvm::zip(succs, succOperands),
1671 [&](auto i) {
1672 p.printNewline();
1673 p.printSuccessorAndUseList(std::get<0>(i), std::get<1>(i));
1674 },
1675 [&] { p << ','; });
1676 if (!succOperands.empty())
1677 p.printNewline();
1678 p << "]";
1679}
1680
1681//===----------------------------------------------------------------------===//
1682// BrCondOp
1683//===----------------------------------------------------------------------===//
1684
1685mlir::SuccessorOperands cir::BrCondOp::getSuccessorOperands(unsigned index) {
1686 assert(index < getNumSuccessors() && "invalid successor index");
1687 return SuccessorOperands(index == 0 ? getDestOperandsTrueMutable()
1688 : getDestOperandsFalseMutable());
1689}
1690
1691Block *cir::BrCondOp::getSuccessorForOperands(ArrayRef<Attribute> operands) {
1692 if (IntegerAttr condAttr = dyn_cast_if_present<IntegerAttr>(operands.front()))
1693 return condAttr.getValue().isOne() ? getDestTrue() : getDestFalse();
1694 return nullptr;
1695}
1696
1697//===----------------------------------------------------------------------===//
1698// CaseOp
1699//===----------------------------------------------------------------------===//
1700
1701void cir::CaseOp::getSuccessorRegions(
1702 mlir::RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
1703 if (!point.isParent()) {
1704 regions.push_back(RegionSuccessor::parent());
1705 return;
1706 }
1707 regions.push_back(RegionSuccessor(&getCaseRegion()));
1708}
1709
1710mlir::ValueRange cir::CaseOp::getSuccessorInputs(RegionSuccessor successor) {
1711 return successor.isParent() ? ValueRange(getOperation()->getResults())
1712 : ValueRange();
1713}
1714
1715void cir::CaseOp::build(OpBuilder &builder, OperationState &result,
1716 ArrayAttr value, CaseOpKind kind,
1717 OpBuilder::InsertPoint &insertPoint) {
1718 OpBuilder::InsertionGuard guardSwitch(builder);
1719 result.addAttribute("value", value);
1720 result.getOrAddProperties<Properties>().kind =
1721 cir::CaseOpKindAttr::get(builder.getContext(), kind);
1722 Region *caseRegion = result.addRegion();
1723 builder.createBlock(caseRegion);
1724
1725 insertPoint = builder.saveInsertionPoint();
1726}
1727
1728//===----------------------------------------------------------------------===//
1729// SwitchOp
1730//===----------------------------------------------------------------------===//
1731
1732void cir::SwitchOp::getSuccessorRegions(
1733 mlir::RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &region) {
1734 if (!point.isParent()) {
1735 region.push_back(RegionSuccessor::parent());
1736 return;
1737 }
1738
1739 region.push_back(RegionSuccessor(&getBody()));
1740}
1741
1742mlir::ValueRange cir::SwitchOp::getSuccessorInputs(RegionSuccessor successor) {
1743 return successor.isParent() ? ValueRange(getOperation()->getResults())
1744 : ValueRange();
1745}
1746
1747void cir::SwitchOp::build(OpBuilder &builder, OperationState &result,
1748 Value cond, BuilderOpStateCallbackRef switchBuilder) {
1749 assert(switchBuilder && "the builder callback for regions must be present");
1750 OpBuilder::InsertionGuard guardSwitch(builder);
1751 Region *switchRegion = result.addRegion();
1752 builder.createBlock(switchRegion);
1753 result.addOperands({cond});
1754 switchBuilder(builder, result.location, result);
1755}
1756
1757void cir::SwitchOp::collectCases(llvm::SmallVectorImpl<CaseOp> &cases) {
1758 walk<mlir::WalkOrder::PreOrder>([&](mlir::Operation *op) {
1759 // Don't walk in nested switch op.
1760 if (isa<cir::SwitchOp>(op) && op != *this)
1761 return WalkResult::skip();
1762
1763 if (auto caseOp = dyn_cast<cir::CaseOp>(op))
1764 cases.push_back(caseOp);
1765
1766 return WalkResult::advance();
1767 });
1768}
1769
1770bool cir::SwitchOp::isSimpleForm(llvm::SmallVectorImpl<CaseOp> &cases) {
1771 collectCases(cases);
1772
1773 if (getBody().empty())
1774 return false;
1775
1776 if (!isa<YieldOp>(getBody().front().back()))
1777 return false;
1778
1779 if (!llvm::all_of(getBody().front(),
1780 [](Operation &op) { return isa<CaseOp, YieldOp>(op); }))
1781 return false;
1782
1783 return llvm::all_of(cases, [this](CaseOp op) {
1784 return op->getParentOfType<SwitchOp>() == *this;
1785 });
1786}
1787
1788//===----------------------------------------------------------------------===//
1789// SwitchFlatOp
1790//===----------------------------------------------------------------------===//
1791
1792void cir::SwitchFlatOp::build(OpBuilder &builder, OperationState &result,
1793 Value value, Block *defaultDestination,
1794 ValueRange defaultOperands,
1795 ArrayRef<APInt> caseValues,
1796 BlockRange caseDestinations,
1797 ArrayRef<ValueRange> caseOperands) {
1798
1799 std::vector<mlir::Attribute> caseValuesAttrs;
1800 for (const APInt &val : caseValues)
1801 caseValuesAttrs.push_back(cir::IntAttr::get(value.getType(), val));
1802 mlir::ArrayAttr attrs = ArrayAttr::get(builder.getContext(), caseValuesAttrs);
1803
1804 build(builder, result, value, defaultOperands, caseOperands, attrs,
1805 defaultDestination, caseDestinations);
1806}
1807
/// Parses the case list of a `cir.switch.flat` operation.
///
/// <cases> ::= `[` (case (`,` case )* )? `]`
/// <case> ::= integer `:` bb-id (`(` ssa-use-and-type-list `)`)?
///
/// Each integer is wrapped in a `#cir.int` attribute of `flagType`, and the
/// collected values become `caseValues`. For every case, the successor block,
/// its operand list and its operand types are appended to `caseDestinations`,
/// `caseOperands` and `caseOperandTypes` respectively. One entry is appended
/// per case even when no parenthesized operands follow.
static ParseResult parseSwitchFlatOpCases(
    OpAsmParser &parser, Type flagType, mlir::ArrayAttr &caseValues,
    SmallVectorImpl<Block *> &caseDestinations,
    &caseOperands,
    SmallVectorImpl<llvm::SmallVector<Type>> &caseOperandTypes) {
  if (failed(parser.parseLSquare()))
    return failure();
  // `[]`: no cases. `caseValues` is left unassigned on this path (the printer
  // handles a null `caseValues`).
  if (succeeded(parser.parseOptionalRSquare()))
    return success();

  auto parseCase = [&]() {
    int64_t value = 0;
    if (failed(parser.parseInteger(value)))
      return failure();

    values.push_back(cir::IntAttr::get(flagType, value));

    Block *destination;
    llvm::SmallVector<Type> operandTypes;
    if (parser.parseColon() || parser.parseSuccessor(destination))
      return failure();
    // Optional `(operands : types)` forwarded to the destination block.
    if (!parser.parseOptionalLParen()) {
      if (parser.parseOperandList(operands, OpAsmParser::Delimiter::None,
                                  /*allowResultNumber=*/false) ||
          parser.parseColonTypeList(operandTypes) || parser.parseRParen())
        return failure();
    }
    caseDestinations.push_back(destination);
    caseOperands.emplace_back(operands);
    caseOperandTypes.emplace_back(operandTypes);
    return success();
  };
  if (failed(parser.parseCommaSeparatedList(parseCase)))
    return failure();

  caseValues = ArrayAttr::get(flagType.getContext(), values);

  return parser.parseRSquare();
}
1852
1853static void printSwitchFlatOpCases(OpAsmPrinter &p, cir::SwitchFlatOp op,
1854 Type flagType, mlir::ArrayAttr caseValues,
1855 SuccessorRange caseDestinations,
1856 OperandRangeRange caseOperands,
1857 const TypeRangeRange &caseOperandTypes) {
1858 p << '[';
1859 p.printNewline();
1860 if (!caseValues) {
1861 p << ']';
1862 return;
1863 }
1864
1865 size_t index = 0;
1866 llvm::interleave(
1867 llvm::zip(caseValues, caseDestinations),
1868 [&](auto i) {
1869 p << " ";
1870 mlir::Attribute a = std::get<0>(i);
1871 p << mlir::cast<cir::IntAttr>(a).getValue();
1872 p << ": ";
1873 p.printSuccessorAndUseList(std::get<1>(i), caseOperands[index++]);
1874 },
1875 [&] {
1876 p << ',';
1877 p.printNewline();
1878 });
1879 p.printNewline();
1880 p << ']';
1881}
1882
1883//===----------------------------------------------------------------------===//
1884// GlobalOp
1885//===----------------------------------------------------------------------===//
1886
1887static ParseResult parseConstantValue(OpAsmParser &parser,
1888 mlir::Attribute &valueAttr) {
1889 NamedAttrList attr;
1890 return parser.parseAttribute(valueAttr, "value", attr);
1891}
1892
1893static void printConstant(OpAsmPrinter &p, Attribute value) {
1894 p.printAttribute(value);
1895}
1896
1897mlir::LogicalResult cir::GlobalOp::verify() {
1898 // Verify that the initial value, if present, is either a unit attribute or
1899 // an attribute CIR supports.
1900 if (getInitialValue().has_value()) {
1901 if (checkConstantTypes(getOperation(), getSymType(), *getInitialValue())
1902 .failed())
1903 return failure();
1904 }
1905
1906 if ((getStaticLocalGuard().has_value() || getTlsModel()) &&
1907 (!getCtorRegion().empty() || !getDtorRegion().empty()))
1908 return emitOpError(
1909 "Cannot have a thread-local or static-local global-op "
1910 "with a constructor or destructor, they require in-function "
1911 "initialization via LocalInitOp");
1912
1913 if (getAliasee().has_value()) {
1914 if (getInitialValue().has_value() || !getCtorRegion().empty() ||
1915 !getDtorRegion().empty())
1916 return emitOpError("global alias shall not have an initializer or "
1917 "constructor/destructor regions");
1918 }
1919
1920 // TODO(CIR): Many other checks for properties that haven't been upstreamed
1921 // yet.
1922
1923 return success();
1924}
1925
1926void cir::GlobalOp::build(
1927 OpBuilder &odsBuilder, OperationState &odsState, llvm::StringRef sym_name,
1928 mlir::Type sym_type, bool isConstant,
1929 mlir::ptr::MemorySpaceAttrInterface addrSpace,
1930 cir::GlobalLinkageKind linkage,
1931 function_ref<void(OpBuilder &, Location)> ctorBuilder,
1932 function_ref<void(OpBuilder &, Location)> dtorBuilder) {
1933 odsState.addAttribute(getSymNameAttrName(odsState.name),
1934 odsBuilder.getStringAttr(sym_name));
1935 odsState.addAttribute(getSymTypeAttrName(odsState.name),
1936 mlir::TypeAttr::get(sym_type));
1937 auto &properties = odsState.getOrAddProperties<cir::GlobalOp::Properties>();
1938 properties.setConstant(isConstant);
1939
1940 addrSpace = normalizeDefaultAddressSpace(addrSpace);
1941 if (addrSpace)
1942 odsState.addAttribute(getAddrSpaceAttrName(odsState.name), addrSpace);
1943
1944 cir::GlobalLinkageKindAttr linkageAttr =
1945 cir::GlobalLinkageKindAttr::get(odsBuilder.getContext(), linkage);
1946 odsState.addAttribute(getLinkageAttrName(odsState.name), linkageAttr);
1947
1948 Region *ctorRegion = odsState.addRegion();
1949 if (ctorBuilder) {
1950 odsBuilder.createBlock(ctorRegion);
1951 ctorBuilder(odsBuilder, odsState.location);
1952 }
1953
1954 Region *dtorRegion = odsState.addRegion();
1955 if (dtorBuilder) {
1956 odsBuilder.createBlock(dtorRegion);
1957 dtorBuilder(odsBuilder, odsState.location);
1958 }
1959}
1960
1961/// Given the region at `index`, or the parent operation if `index` is None,
1962/// return the successor regions. These are the regions that may be selected
1963/// during the flow of control. `operands` is a set of optional attributes that
1964/// correspond to a constant value for each operand, or null if that operand is
1965/// not a constant.
1966void cir::GlobalOp::getSuccessorRegions(
1967 mlir::RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
1968 // The `ctor` and `dtor` regions always branch back to the parent operation.
1969 if (!point.isParent()) {
1970 regions.push_back(RegionSuccessor::parent());
1971 return;
1972 }
1973
1974 // Don't consider the ctor region if it is empty.
1975 Region *ctorRegion = &this->getCtorRegion();
1976 if (ctorRegion->empty())
1977 ctorRegion = nullptr;
1978
1979 // Don't consider the dtor region if it is empty.
1980 Region *dtorRegion = &this->getDtorRegion();
1981 if (dtorRegion->empty())
1982 dtorRegion = nullptr;
1983
1984 // If the condition isn't constant, both regions may be executed.
1985 if (ctorRegion)
1986 regions.push_back(RegionSuccessor(ctorRegion));
1987 if (dtorRegion)
1988 regions.push_back(RegionSuccessor(dtorRegion));
1989}
1990
1991mlir::ValueRange cir::GlobalOp::getSuccessorInputs(RegionSuccessor successor) {
1992 return successor.isParent() ? ValueRange(getOperation()->getResults())
1993 : ValueRange();
1994}
1995
1996static void printGlobalOpTypeAndInitialValue(OpAsmPrinter &p, cir::GlobalOp op,
1997 TypeAttr type, Attribute initAttr,
1998 mlir::Region &ctorRegion,
1999 mlir::Region &dtorRegion) {
2000 auto printType = [&]() { p << ": " << type; };
2001 // Aliases are definitions but they have no initial value or ctor/dtor; the
2002 // assembly prints them like declarations (`: type`).
2003 if (op.isDeclaration() || op.getAliasee()) {
2004 printType();
2005 return;
2006 }
2007
2008 p << "= ";
2009 if (!ctorRegion.empty()) {
2010 p << "ctor ";
2011 printType();
2012 p << " ";
2013 p.printRegion(ctorRegion,
2014 /*printEntryBlockArgs=*/false,
2015 /*printBlockTerminators=*/false);
2016 } else {
2017 // This also prints the type...
2018 if (initAttr)
2019 printConstant(p, initAttr);
2020 }
2021
2022 if (!dtorRegion.empty()) {
2023 p << " dtor ";
2024 p.printRegion(dtorRegion,
2025 /*printEntryBlockArgs=*/false,
2026 /*printBlockTerminators=*/false);
2027 }
2028}
2029
2030static ParseResult parseGlobalOpTypeAndInitialValue(OpAsmParser &parser,
2031 TypeAttr &typeAttr,
2032 Attribute &initialValueAttr,
2033 mlir::Region &ctorRegion,
2034 mlir::Region &dtorRegion) {
2035 mlir::Type opTy;
2036 if (parser.parseOptionalEqual().failed()) {
2037 // Absence of equal means a declaration, so we need to parse the type.
2038 // cir.global @a : !cir.int<s, 32>
2039 if (parser.parseColonType(opTy))
2040 return failure();
2041 } else {
2042 // Parse contructor, example:
2043 // cir.global @rgb = ctor : type { ... }
2044 if (!parser.parseOptionalKeyword("ctor")) {
2045 if (parser.parseColonType(opTy))
2046 return failure();
2047 auto parseLoc = parser.getCurrentLocation();
2048 if (parser.parseRegion(ctorRegion, /*arguments=*/{}, /*argTypes=*/{}))
2049 return failure();
2050 if (ensureRegionTerm(parser, ctorRegion, parseLoc).failed())
2051 return failure();
2052 } else {
2053 // Parse constant with initializer, examples:
2054 // cir.global @y = 3.400000e+00 : f32
2055 // cir.global @rgb = #cir.const_array<[...] : !cir.array<i8 x 3>>
2056 if (parseConstantValue(parser, initialValueAttr).failed())
2057 return failure();
2058
2059 assert(mlir::isa<mlir::TypedAttr>(initialValueAttr) &&
2060 "Non-typed attrs shouldn't appear here.");
2061 auto typedAttr = mlir::cast<mlir::TypedAttr>(initialValueAttr);
2062 opTy = typedAttr.getType();
2063 }
2064
2065 // Parse destructor, example:
2066 // dtor { ... }
2067 if (!parser.parseOptionalKeyword("dtor")) {
2068 auto parseLoc = parser.getCurrentLocation();
2069 if (parser.parseRegion(dtorRegion, /*arguments=*/{}, /*argTypes=*/{}))
2070 return failure();
2071 if (ensureRegionTerm(parser, dtorRegion, parseLoc).failed())
2072 return failure();
2073 }
2074 }
2075
2076 typeAttr = TypeAttr::get(opTy);
2077 return success();
2078}
2079
2080//===----------------------------------------------------------------------===//
2081// GetGlobalOp
2082//===----------------------------------------------------------------------===//
2083
2084LogicalResult
2085cir::GetGlobalOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
2086 // Verify that the result type underlying pointer type matches the type of
2087 // the referenced cir.global or cir.func op.
2088 mlir::Operation *op =
2089 symbolTable.lookupNearestSymbolFrom(*this, getNameAttr());
2090 if (op == nullptr || !(isa<GlobalOp>(op) || isa<FuncOp>(op)))
2091 return emitOpError("'")
2092 << getName()
2093 << "' does not reference a valid cir.global or cir.func";
2094
2095 mlir::Type symTy;
2096 mlir::ptr::MemorySpaceAttrInterface symAddrSpaceAttr{};
2097 if (auto g = dyn_cast<GlobalOp>(op)) {
2098 symTy = g.getSymType();
2099 symAddrSpaceAttr = g.getAddrSpaceAttr();
2100 // Verify that for thread local global access, the global needs to
2101 // be marked with tls bits.
2102 if (getTls() && !g.getTlsModel())
2103 return emitOpError("access to global not marked thread local");
2104
2105 // Verify that the static_local attribute on GetGlobalOp matches the
2106 // static_local_guard attribute on GlobalOp. GetGlobalOp uses a UnitAttr,
2107 // GlobalOp uses StaticLocalGuardAttr. Both should be present, or neither.
2108 bool getGlobalIsStaticLocal = getStaticLocal();
2109 bool globalIsStaticLocal = g.getStaticLocalGuard().has_value();
2110 if (getGlobalIsStaticLocal != globalIsStaticLocal &&
2111 !getOperation()->getParentOfType<cir::GlobalOp>())
2112 return emitOpError("static_local attribute mismatch");
2113 } else if (auto f = dyn_cast<FuncOp>(op)) {
2114 symTy = f.getFunctionType();
2115 } else {
2116 llvm_unreachable("Unexpected operation for GetGlobalOp");
2117 }
2118
2119 auto resultType = dyn_cast<PointerType>(getAddr().getType());
2120 if (!resultType || symTy != resultType.getPointee())
2121 return emitOpError("result type pointee type '")
2122 << resultType.getPointee() << "' does not match type " << symTy
2123 << " of the global @" << getName();
2124
2125 if (symAddrSpaceAttr != resultType.getAddrSpace()) {
2126 return emitOpError()
2127 << "result type address space does not match the address "
2128 "space of the global @"
2129 << getName();
2130 }
2131
2132 return success();
2133}
2134
2135//===----------------------------------------------------------------------===//
2136// VTableAddrPointOp
2137//===----------------------------------------------------------------------===//
2138
2139LogicalResult
2140cir::VTableAddrPointOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
2141 StringRef name = getName();
2142
2143 // Verify that the result type underlying pointer type matches the type of
2144 // the referenced cir.global.
2145 auto op =
2146 symbolTable.lookupNearestSymbolFrom<cir::GlobalOp>(*this, getNameAttr());
2147 if (!op)
2148 return emitOpError("'")
2149 << name << "' does not reference a valid cir.global";
2150 std::optional<mlir::Attribute> init = op.getInitialValue();
2151 if (!init)
2152 return success();
2153 if (!isa<cir::VTableAttr>(*init))
2154 return emitOpError("Expected #cir.vtable in initializer for global '")
2155 << name << "'";
2156 return success();
2157}
2158
2159//===----------------------------------------------------------------------===//
2160// VTTAddrPointOp
2161//===----------------------------------------------------------------------===//
2162
2163LogicalResult
2164cir::VTTAddrPointOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
2165 // VTT ptr is not coming from a symbol.
2166 if (!getName())
2167 return success();
2168 StringRef name = *getName();
2169
2170 // Verify that the result type underlying pointer type matches the type of
2171 // the referenced cir.global op.
2172 auto op =
2173 symbolTable.lookupNearestSymbolFrom<cir::GlobalOp>(*this, getNameAttr());
2174 if (!op)
2175 return emitOpError("'")
2176 << name << "' does not reference a valid cir.global";
2177 std::optional<mlir::Attribute> init = op.getInitialValue();
2178 if (!init)
2179 return success();
2180 if (!isa<cir::ConstArrayAttr>(*init))
2181 return emitOpError(
2182 "Expected constant array in initializer for global VTT '")
2183 << name << "'";
2184 return success();
2185}
2186
2187LogicalResult cir::VTTAddrPointOp::verify() {
2188 // The operation uses either a symbol or a value to operate, but not both
2189 if (getName() && getSymAddr())
2190 return emitOpError("should use either a symbol or value, but not both");
2191
2192 // If not a symbol, stick with the concrete type used for getSymAddr.
2193 if (getSymAddr())
2194 return success();
2195
2196 mlir::Type resultType = getAddr().getType();
2197 mlir::Type resTy = cir::PointerType::get(
2198 cir::PointerType::get(cir::VoidType::get(getContext())));
2199
2200 if (resultType != resTy)
2201 return emitOpError("result type must be ")
2202 << resTy << ", but provided result type is " << resultType;
2203 return success();
2204}
2205
2206//===----------------------------------------------------------------------===//
2207// FuncOp
2208//===----------------------------------------------------------------------===//
2209
2210/// Returns the name used for the linkage attribute. This *must* correspond to
2211/// the name of the attribute in ODS.
2212static llvm::StringRef getLinkageAttrNameString() { return "linkage"; }
2213
/// Builds a `cir.func` named `name` with function type `type`, the given
/// `linkage` and calling convention `callingConv`. The body region is added
/// but left empty (no blocks are created), i.e. the result is a declaration
/// until a body is built.
void cir::FuncOp::build(OpBuilder &builder, OperationState &result,
                        StringRef name, FuncType type,
                        GlobalLinkageKind linkage, CallingConv callingConv) {
  // Body region: intentionally empty here.
  result.addRegion();
  result.addAttribute(SymbolTable::getSymbolAttrName(),
                      builder.getStringAttr(name));
  result.addAttribute(getFunctionTypeAttrName(result.name),
                      TypeAttr::get(type));
  // NOTE(review): in this view the linkage addAttribute call shows no
  // attribute-name argument; confirm it uses the ODS linkage attribute name.
  result.addAttribute(
      GlobalLinkageKindAttr::get(builder.getContext(), linkage));
  result.addAttribute(getCallingConvAttrName(result.name),
                      CallingConvAttr::get(builder.getContext(), callingConv));
}
2228
2229//===----------------------------------------------------------------------===//
2230// AnnotationAttr
2231//===----------------------------------------------------------------------===//
2232
2233LogicalResult
2234cir::AnnotationAttr::verify(function_ref<InFlightDiagnostic()> emitError,
2235 mlir::StringAttr name, mlir::ArrayAttr args) {
2236 if (!args)
2237 return success();
2238 for (mlir::Attribute arg : args) {
2239 if (!isa<mlir::StringAttr, mlir::IntegerAttr>(arg))
2240 return emitError() << "annotation args must be StringAttr or IntegerAttr,"
2241 << " got " << arg;
2242 }
2243 return success();
2244}
2245
/// Parses a `cir.func`. Optional keywords are accepted only in this order
/// (FuncOp::print should emit them in the same order):
///   [builtin] [coroutine] [inline kind] [lambda] [no_proto] [comdat]
///   [linkage] [private|public|nested] [global visibility] [dso_local]
///   @name <function signature>
///   [alias(@sym)] [personality(@sym)] [cc(kind)]
///   [special_member<attr>] [global_ctor[(prio)]] [global_dtor[(prio)]]
///   [side_effect(kind)] [annotations array] [attributes {...}] [body]
/// Missing linkage defaults to external, a missing calling convention to C,
/// and a ctor/dtor priority without parentheses to 65535.
ParseResult cir::FuncOp::parse(OpAsmParser &parser, OperationState &state) {
  llvm::SMLoc loc = parser.getCurrentLocation();
  mlir::Builder &builder = parser.getBuilder();

  mlir::StringAttr builtinNameAttr = getBuiltinAttrName(state.name);
  mlir::StringAttr coroutineNameAttr = getCoroutineAttrName(state.name);
  mlir::StringAttr inlineKindNameAttr = getInlineKindAttrName(state.name);
  mlir::StringAttr lambdaNameAttr = getLambdaAttrName(state.name);
  mlir::StringAttr noProtoNameAttr = getNoProtoAttrName(state.name);
  mlir::StringAttr comdatNameAttr = getComdatAttrName(state.name);
  mlir::StringAttr visNameAttr = getSymVisibilityAttrName(state.name);
  mlir::StringAttr dsoLocalNameAttr = getDsoLocalAttrName(state.name);
  mlir::StringAttr specialMemberAttr = getCxxSpecialMemberAttrName(state.name);

  // Leading unit-attribute keywords.
  if (::mlir::succeeded(parser.parseOptionalKeyword(builtinNameAttr.strref())))
    state.addAttribute(builtinNameAttr, parser.getBuilder().getUnitAttr());
  if (::mlir::succeeded(
          parser.parseOptionalKeyword(coroutineNameAttr.strref())))
    state.addAttribute(coroutineNameAttr, parser.getBuilder().getUnitAttr());

  // Parse optional inline kind attribute
  cir::InlineKindAttr inlineKindAttr;
  if (failed(parseInlineKindAttr(parser, inlineKindAttr)))
    return failure();
  if (inlineKindAttr)
    state.addAttribute(inlineKindNameAttr, inlineKindAttr);

  if (::mlir::succeeded(parser.parseOptionalKeyword(lambdaNameAttr.strref())))
    state.addAttribute(lambdaNameAttr, parser.getBuilder().getUnitAttr());
  if (parser.parseOptionalKeyword(noProtoNameAttr).succeeded())
    state.addAttribute(noProtoNameAttr, parser.getBuilder().getUnitAttr());

  if (parser.parseOptionalKeyword(comdatNameAttr).succeeded())
    state.addAttribute(comdatNameAttr, parser.getBuilder().getUnitAttr());

  // Default to external linkage if no keyword is provided.
  state.addAttribute(getLinkageAttrNameString(),
                     GlobalLinkageKindAttr::get(
                         parser.getContext(),
                         parser, GlobalLinkageKind::ExternalLinkage)));

  // Optional MLIR symbol visibility keyword.
  ::llvm::StringRef visAttrStr;
  if (parser.parseOptionalKeyword(&visAttrStr, {"private", "public", "nested"})
          .succeeded()) {
    state.addAttribute(visNameAttr,
                       parser.getBuilder().getStringAttr(visAttrStr));
  }

  // Optional CIR global visibility; cir::VisibilityKind::Default is passed as
  // the fallback value.
  state.getOrAddProperties<cir::FuncOp::Properties>().global_visibility =
      parseOptionalCIRKeyword(parser, cir::VisibilityKind::Default);

  if (parser.parseOptionalKeyword(dsoLocalNameAttr).succeeded())
    state.addAttribute(dsoLocalNameAttr, parser.getBuilder().getUnitAttr());

  // Symbol name and signature.
  StringAttr nameAttr;
  if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
                             state.attributes))
    return failure();
  bool isVariadic = false;
  if (function_interface_impl::parseFunctionSignatureWithArguments(
          parser, /*allowVariadic=*/true, arguments, isVariadic, resultTypes,
          resultAttrs))
    return failure();
  bool argAttrsEmpty = true;
  for (OpAsmParser::Argument &arg : arguments) {
    argTypes.push_back(arg.type);
    // Add the 'empty' attribute anyway to make sure the arity matches, but we
    // only want to 'set' the attribute at the top level if there is SOME data
    // along the way.
    argAttrs.push_back(arg.attrs);
    if (arg.attrs)
      argAttrsEmpty = false;
  }

  // These should be in sync anyway, but test both of them anyway.
  if (resultTypes.size() > 1 || resultAttrs.size() > 1)
    return parser.emitError(
        loc, "functions with multiple return types are not supported");

  // No result type in the signature means a void-returning function.
  mlir::Type returnType =
      (resultTypes.empty() ? cir::VoidType::get(builder.getContext())
                           : resultTypes.front());

  cir::FuncType fnType = cir::FuncType::get(argTypes, returnType, isVariadic);
  if (!fnType)
    return failure();

  state.addAttribute(getFunctionTypeAttrName(state.name),
                     TypeAttr::get(fnType));

  if (!resultAttrs.empty() && resultAttrs[0])
    state.addAttribute(
        getResAttrsAttrName(state.name),
        mlir::ArrayAttr::get(parser.getContext(), {resultAttrs[0]}));

  if (!argAttrsEmpty)
    state.addAttribute(getArgAttrsAttrName(state.name),
                       mlir::ArrayAttr::get(parser.getContext(), argAttrs));

  // Optional `alias(@sym)`; an alias may not have a body (checked below).
  bool hasAlias = false;
  mlir::StringAttr aliaseeNameAttr = getAliaseeAttrName(state.name);
  if (parser.parseOptionalKeyword("alias").succeeded()) {
    if (parser.parseLParen().failed())
      return failure();
    mlir::StringAttr aliaseeAttr;
    if (parser.parseOptionalSymbolName(aliaseeAttr).failed())
      return failure();
    state.addAttribute(aliaseeNameAttr, FlatSymbolRefAttr::get(aliaseeAttr));
    if (parser.parseRParen().failed())
      return failure();
    hasAlias = true;
  }

  // Optional `personality(@sym)`; must appear before `cc(...)`.
  mlir::StringAttr personalityNameAttr = getPersonalityAttrName(state.name);
  if (parser.parseOptionalKeyword("personality").succeeded()) {
    if (parser.parseLParen().failed())
      return failure();
    mlir::StringAttr personalityAttr;
    if (parser.parseOptionalSymbolName(personalityAttr).failed())
      return failure();
    state.addAttribute(personalityNameAttr,
                       FlatSymbolRefAttr::get(personalityAttr));
    if (parser.parseRParen().failed())
      return failure();
  }

  // Default to C calling convention if no keyword is provided.
  mlir::StringAttr callConvNameAttr = getCallingConvAttrName(state.name);
  cir::CallingConv callConv = cir::CallingConv::C;
  if (parser.parseOptionalKeyword("cc").succeeded()) {
    if (parser.parseLParen().failed())
      return failure();
    if (parseCIRKeyword<cir::CallingConv>(parser, callConv).failed())
      return parser.emitError(loc) << "unknown calling convention";
    if (parser.parseRParen().failed())
      return failure();
  }
  state.addAttribute(callConvNameAttr,
                     cir::CallingConvAttr::get(parser.getContext(), callConv));

  // Parses `keyword[(int)]`. When the keyword is present, `createAttr` is
  // called with the priority, or std::nullopt if no parenthesized value
  // followed.
  auto parseGlobalDtorCtor =
      [&](StringRef keyword,
          llvm::function_ref<void(std::optional<int> prio)> createAttr)
      -> mlir::LogicalResult {
    if (mlir::succeeded(parser.parseOptionalKeyword(keyword))) {
      std::optional<int> priority;
      if (mlir::succeeded(parser.parseOptionalLParen())) {
        auto parsedPriority = mlir::FieldParser<int>::parse(parser);
        if (mlir::failed(parsedPriority))
          return parser.emitError(parser.getCurrentLocation(),
                                  "failed to parse 'priority', of type 'int'");
        priority = parsedPriority.value_or(int());
        // Parse literal ')'
        if (parser.parseRParen())
          return failure();
      }
      createAttr(priority);
    }
    return success();
  };

  // Parse CXXSpecialMember attribute
  if (parser.parseOptionalKeyword("special_member").succeeded()) {
    if (parser.parseLess().failed())
      return failure();

    mlir::Attribute attr;
    if (parser.parseAttribute(attr).failed())
      return failure();
    if (!mlir::isa<cir::CXXCtorAttr, cir::CXXDtorAttr, cir::CXXAssignAttr>(
            attr))
      return parser.emitError(parser.getCurrentLocation(),
                              "expected a C++ special member attribute");
    state.addAttribute(specialMemberAttr, attr);

    if (parser.parseGreater().failed())
      return failure();
  }

  // `global_ctor` / `global_dtor` without a priority default to 65535.
  if (parseGlobalDtorCtor("global_ctor", [&](std::optional<int> priority) {
        mlir::IntegerAttr globalCtorPriorityAttr =
            builder.getI32IntegerAttr(priority.value_or(65535));
        state.addAttribute(getGlobalCtorPriorityAttrName(state.name),
                           globalCtorPriorityAttr);
      }).failed())
    return failure();

  if (parseGlobalDtorCtor("global_dtor", [&](std::optional<int> priority) {
        mlir::IntegerAttr globalDtorPriorityAttr =
            builder.getI32IntegerAttr(priority.value_or(65535));
        state.addAttribute(getGlobalDtorPriorityAttrName(state.name),
                           globalDtorPriorityAttr);
      }).failed())
    return failure();

  // Optional `side_effect(kind)`.
  if (parser.parseOptionalKeyword("side_effect").succeeded()) {
    cir::SideEffect sideEffect;

    if (parser.parseLParen().failed() ||
        parseCIRKeyword<cir::SideEffect>(parser, sideEffect).failed() ||
        parser.parseRParen().failed())
      return failure();

    auto attr = cir::SideEffectAttr::get(parser.getContext(), sideEffect);
    state.addAttribute(CIRDialect::getSideEffectAttrName(), attr);
  }

  // Parse optional annotations attribute (an ArrayAttr of AnnotationAttr).
  // NOTE(review): if an attribute is present but fails to parse, has_value()
  // is true and annotationsAttr stays null, so the failure is not propagated
  // here; confirm whether that is intended.
  mlir::StringAttr annotationsNameAttr = getAnnotationsAttrName(state.name);
  mlir::ArrayAttr annotationsAttr;
  if (parser.parseOptionalAttribute(annotationsAttr).has_value() &&
      annotationsAttr)
    state.addAttribute(annotationsNameAttr, annotationsAttr);

  // Parse the rest of the attributes.
  NamedAttrList parsedAttrs;
  if (parser.parseOptionalAttrDictWithKeyword(parsedAttrs))
    return failure();

  // Attributes that have dedicated syntax may not be repeated in the dict.
  for (StringRef disallowed : cir::FuncOp::getAttributeNames()) {
    if (parsedAttrs.get(disallowed))
      return parser.emitError(loc, "attribute '")
             << disallowed
             << "' should not be specified in the explicit attribute list";
  }

  state.attributes.append(parsedAttrs);

  // Parse the optional function body.
  auto *body = state.addRegion();
  OptionalParseResult parseResult = parser.parseOptionalRegion(
      *body, arguments, /*enableNameShadowing=*/false);
  if (parseResult.has_value()) {
    if (hasAlias)
      return parser.emitError(loc, "function alias shall not have a body");
    if (failed(*parseResult))
      return failure();
    // Function body was parsed, make sure its not empty.
    if (body->empty())
      return parser.emitError(loc, "expected non-empty function body");
  }

  return success();
}
2496
// This function corresponds to `llvm::GlobalValue::isDeclaration` and should
// have a similar implementation. We don't currently support ifuncs or
// materializable functions, but those should be handled here as they are
// implemented.
//
// A function without an aliasee is a declaration exactly when its body region
// is empty; a function with an aliasee is always a definition.
bool cir::FuncOp::isDeclaration() {

  std::optional<StringRef> aliasee = getAliasee();
  if (!aliasee)
    return getFunctionBody().empty();

  // Aliases are always definitions.
  return false;
}
2510
2511bool cir::FuncOp::isCXXSpecialMemberFunction() {
2512 return getCxxSpecialMemberAttr() != nullptr;
2513}
2514
2515bool cir::FuncOp::isCxxConstructor() {
2516 auto attr = getCxxSpecialMemberAttr();
2517 return attr && dyn_cast<CXXCtorAttr>(attr);
2518}
2519
2520bool cir::FuncOp::isCxxDestructor() {
2521 auto attr = getCxxSpecialMemberAttr();
2522 return attr && dyn_cast<CXXDtorAttr>(attr);
2523}
2524
2525bool cir::FuncOp::isCxxSpecialAssignment() {
2526 auto attr = getCxxSpecialMemberAttr();
2527 return attr && dyn_cast<CXXAssignAttr>(attr);
2528}
2529
2530std::optional<CtorKind> cir::FuncOp::getCxxConstructorKind() {
2531 mlir::Attribute attr = getCxxSpecialMemberAttr();
2532 if (attr) {
2533 if (auto ctor = dyn_cast<CXXCtorAttr>(attr))
2534 return ctor.getCtorKind();
2535 }
2536 return std::nullopt;
2537}
2538
2539std::optional<AssignKind> cir::FuncOp::getCxxSpecialAssignKind() {
2540 mlir::Attribute attr = getCxxSpecialMemberAttr();
2541 if (attr) {
2542 if (auto assign = dyn_cast<CXXAssignAttr>(attr))
2543 return assign.getAssignKind();
2544 }
2545 return std::nullopt;
2546}
2547
2548bool cir::FuncOp::isCxxTrivialMemberFunction() {
2549 mlir::Attribute attr = getCxxSpecialMemberAttr();
2550 if (attr) {
2551 if (auto ctor = dyn_cast<CXXCtorAttr>(attr))
2552 return ctor.getIsTrivial();
2553 if (auto dtor = dyn_cast<CXXDtorAttr>(attr))
2554 return dtor.getIsTrivial();
2555 if (auto assign = dyn_cast<CXXAssignAttr>(attr))
2556 return assign.getIsTrivial();
2557 }
2558 return false;
2559}
2560
2561mlir::Region *cir::FuncOp::getCallableRegion() {
2562 // TODO(CIR): This function will have special handling for aliases and a
2563 // check for an external function, once those features have been upstreamed.
2564 return &getBody();
2565}
2566
2567void cir::FuncOp::print(OpAsmPrinter &p) {
2568 if (getBuiltin())
2569 p << " builtin";
2570
2571 if (getCoroutine())
2572 p << " coroutine";
2573
2574 printInlineKindAttr(p, getInlineKindAttr());
2575
2576 if (getLambda())
2577 p << " lambda";
2578
2579 if (getNoProto())
2580 p << " no_proto";
2581
2582 if (getComdat())
2583 p << " comdat";
2584
2585 if (getLinkage() != GlobalLinkageKind::ExternalLinkage)
2586 p << ' ' << stringifyGlobalLinkageKind(getLinkage());
2587
2588 mlir::SymbolTable::Visibility vis = getVisibility();
2589 if (vis != mlir::SymbolTable::Visibility::Public)
2590 p << ' ' << vis;
2591
2592 if (getGlobalVisibility() != cir::VisibilityKind::Default)
2593 p << ' ' << stringifyVisibilityKind(getGlobalVisibility());
2594
2595 if (getDsoLocal())
2596 p << " dso_local";
2597
2598 p << ' ';
2599 p.printSymbolName(getSymName());
2600 cir::FuncType fnType = getFunctionType();
2601 function_interface_impl::printFunctionSignature(
2602 p, *this, fnType.getInputs(), fnType.isVarArg(), fnType.getReturnTypes());
2603
2604 if (std::optional<StringRef> aliaseeName = getAliasee()) {
2605 p << " alias(";
2606 p.printSymbolName(*aliaseeName);
2607 p << ")";
2608 }
2609
2610 if (getCallingConv() != cir::CallingConv::C) {
2611 p << " cc(";
2612 p << stringifyCallingConv(getCallingConv());
2613 p << ")";
2614 }
2615
2616 if (std::optional<StringRef> personalityName = getPersonality()) {
2617 p << " personality(";
2618 p.printSymbolName(*personalityName);
2619 p << ")";
2620 }
2621
2622 if (auto specialMemberAttr = getCxxSpecialMember()) {
2623 p << " special_member<";
2624 p.printAttribute(*specialMemberAttr);
2625 p << '>';
2626 }
2627
2628 if (auto globalCtorPriority = getGlobalCtorPriority()) {
2629 p << " global_ctor";
2630 if (globalCtorPriority.value() != 65535)
2631 p << "(" << globalCtorPriority.value() << ")";
2632 }
2633
2634 if (auto globalDtorPriority = getGlobalDtorPriority()) {
2635 p << " global_dtor";
2636 if (globalDtorPriority.value() != 65535)
2637 p << "(" << globalDtorPriority.value() << ")";
2638 }
2639
2640 if (std::optional<cir::SideEffect> sideEffect = getSideEffect();
2641 sideEffect && *sideEffect != cir::SideEffect::All) {
2642 p << " side_effect(";
2643 p << stringifySideEffect(*sideEffect);
2644 p << ")";
2645 }
2646
2647 if (mlir::ArrayAttr annotations = getAnnotationsAttr()) {
2648 p << ' ';
2649 p.printAttribute(annotations);
2650 }
2651
2652 function_interface_impl::printFunctionAttributes(
2653 p, *this, cir::FuncOp::getAttributeNames());
2654
2655 // Print the body if this is not an external function.
2656 Region &body = getOperation()->getRegion(0);
2657 if (!body.empty()) {
2658 p << ' ';
2659 p.printRegion(body, /*printEntryBlockArgs=*/false,
2660 /*printBlockTerminators=*/true);
2661 }
2662}
2663
2664mlir::LogicalResult cir::FuncOp::verify() {
2665
2666 if (!isDeclaration() && getCoroutine()) {
2667 bool foundAwait = false;
2668 int coroBodyCount = 0;
2669 this->walk([&](Operation *op) {
2670 if (auto await = dyn_cast<AwaitOp>(op)) {
2671 foundAwait = true;
2672 } else if (isa<CoroBodyOp>(op)) {
2673 coroBodyCount++;
2674 if (coroBodyCount > 1) {
2675 return mlir::WalkResult::interrupt();
2676 }
2677 }
2678 return mlir::WalkResult::advance();
2679 });
2680 if (!foundAwait)
2681 return emitOpError()
2682 << "coroutine body must use at least one cir.await op";
2683 if (coroBodyCount != 1)
2684 return emitOpError()
2685 << "coroutine function must have exactly one cir.body op";
2686 }
2687
2688 llvm::SmallSet<llvm::StringRef, 16> labels;
2689 llvm::SmallSet<llvm::StringRef, 16> gotos;
2690 llvm::SmallSet<llvm::StringRef, 16> blockAddresses;
2691 bool invalidBlockAddress = false;
2692 getOperation()->walk([&](mlir::Operation *op) {
2693 if (auto lab = dyn_cast<cir::LabelOp>(op)) {
2694 labels.insert(lab.getLabel());
2695 } else if (auto goTo = dyn_cast<cir::GotoOp>(op)) {
2696 gotos.insert(goTo.getLabel());
2697 } else if (auto blkAdd = dyn_cast<cir::BlockAddressOp>(op)) {
2698 if (blkAdd.getBlockAddrInfoAttr().getFunc().getAttr() != getSymName()) {
2699 // Stop the walk early, no need to continue
2700 invalidBlockAddress = true;
2701 return mlir::WalkResult::interrupt();
2702 }
2703 blockAddresses.insert(blkAdd.getBlockAddrInfoAttr().getLabel());
2704 }
2705 return mlir::WalkResult::advance();
2706 });
2707
2708 if (invalidBlockAddress)
2709 return emitOpError() << "blockaddress references a different function";
2710
2711 llvm::SmallSet<llvm::StringRef, 16> mismatched;
2712 if (!labels.empty() || !gotos.empty()) {
2713 mismatched = llvm::set_difference(gotos, labels);
2714
2715 if (!mismatched.empty())
2716 return emitOpError() << "goto/label mismatch";
2717 }
2718
2719 mismatched.clear();
2720
2721 if (!labels.empty() || !blockAddresses.empty()) {
2722 mismatched = llvm::set_difference(blockAddresses, labels);
2723
2724 if (!mismatched.empty())
2725 return emitOpError()
2726 << "expects an existing label target in the referenced function";
2727 }
2728
2729 return success();
2730}
2731
2732//===----------------------------------------------------------------------===//
2733// AddOp / SubOp / MulOp
2734//===----------------------------------------------------------------------===//
2735
2736static LogicalResult verifyBinaryOverflowOp(mlir::Operation *op,
2737 bool noSignedWrap,
2738 bool noUnsignedWrap, bool saturated,
2739 bool hasSat) {
2740 bool noWrap = noSignedWrap || noUnsignedWrap;
2741 if (!isa<cir::IntType>(op->getResultTypes()[0]) && noWrap)
2742 return op->emitError()
2743 << "only operations on integer values may have nsw/nuw flags";
2744 if (hasSat && saturated && !isa<cir::IntType>(op->getResultTypes()[0]))
2745 return op->emitError()
2746 << "only operations on integer values may have sat flag";
2747 if (hasSat && noWrap && saturated)
2748 return op->emitError()
2749 << "the nsw/nuw flags and the saturated flag are mutually exclusive";
2750 return mlir::success();
2751}
2752
2753LogicalResult cir::AddOp::verify() {
2754 return verifyBinaryOverflowOp(getOperation(), getNoSignedWrap(),
2755 getNoUnsignedWrap(), getSaturated(),
2756 /*hasSat=*/true);
2757}
2758
2759LogicalResult cir::SubOp::verify() {
2760 return verifyBinaryOverflowOp(getOperation(), getNoSignedWrap(),
2761 getNoUnsignedWrap(), getSaturated(),
2762 /*hasSat=*/true);
2763}
2764
2765LogicalResult cir::MulOp::verify() {
2766 return verifyBinaryOverflowOp(getOperation(), getNoSignedWrap(),
2767 getNoUnsignedWrap(), /*saturated=*/false,
2768 /*hasSat=*/false);
2769}
2770
2771//===----------------------------------------------------------------------===//
2772// TernaryOp
2773//===----------------------------------------------------------------------===//
2774
2775/// Given the region at `point`, or the parent operation if `point` is None,
2776/// return the successor regions. These are the regions that may be selected
2777/// during the flow of control. `operands` is a set of optional attributes that
2778/// correspond to a constant value for each operand, or null if that operand is
2779/// not a constant.
2780void cir::TernaryOp::getSuccessorRegions(
2781 mlir::RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
2782 // The `true` and the `false` region branch back to the parent operation.
2783 if (!point.isParent()) {
2784 regions.push_back(RegionSuccessor::parent());
2785 return;
2786 }
2787
2788 // When branching from the parent operation, both the true and false
2789 // regions are considered possible successors
2790 regions.push_back(RegionSuccessor(&getTrueRegion()));
2791 regions.push_back(RegionSuccessor(&getFalseRegion()));
2792}
2793
2794mlir::ValueRange cir::TernaryOp::getSuccessorInputs(RegionSuccessor successor) {
2795 return successor.isParent() ? ValueRange(getOperation()->getResults())
2796 : ValueRange();
2797}
2798
2799void cir::TernaryOp::build(
2800 OpBuilder &builder, OperationState &result, Value cond,
2801 function_ref<void(OpBuilder &, Location)> trueBuilder,
2802 function_ref<void(OpBuilder &, Location)> falseBuilder) {
2803 result.addOperands(cond);
2804 OpBuilder::InsertionGuard guard(builder);
2805 Region *trueRegion = result.addRegion();
2806 builder.createBlock(trueRegion);
2807 trueBuilder(builder, result.location);
2808 Region *falseRegion = result.addRegion();
2809 builder.createBlock(falseRegion);
2810 falseBuilder(builder, result.location);
2811
2812 // Get result type from whichever branch has a yield (the other may have
2813 // unreachable from a throw expression)
2814 cir::YieldOp yield;
2815 if (trueRegion->back().mightHaveTerminator())
2816 yield = dyn_cast_or_null<cir::YieldOp>(trueRegion->back().getTerminator());
2817 if (!yield && falseRegion->back().mightHaveTerminator())
2818 yield = dyn_cast_or_null<cir::YieldOp>(falseRegion->back().getTerminator());
2819
2820 assert((!yield || yield.getNumOperands() <= 1) &&
2821 "expected zero or one result type");
2822 if (yield && yield.getNumOperands() == 1)
2823 result.addTypes(TypeRange{yield.getOperandTypes().front()});
2824}
2825
2826//===----------------------------------------------------------------------===//
2827// SelectOp
2828//===----------------------------------------------------------------------===//
2829
2830OpFoldResult cir::SelectOp::fold(FoldAdaptor adaptor) {
2831 mlir::Attribute condition = adaptor.getCondition();
2832 if (condition) {
2833 bool conditionValue = mlir::cast<cir::BoolAttr>(condition).getValue();
2834 return conditionValue ? getTrueValue() : getFalseValue();
2835 }
2836
2837 // cir.select if %0 then x else x -> x
2838 mlir::Attribute trueValue = adaptor.getTrueValue();
2839 mlir::Attribute falseValue = adaptor.getFalseValue();
2840 if (trueValue == falseValue)
2841 return trueValue;
2842 if (getTrueValue() == getFalseValue())
2843 return getTrueValue();
2844
2845 return {};
2846}
2847
2848LogicalResult cir::SelectOp::verify() {
2849 // AllTypesMatch already guarantees trueVal and falseVal have matching types.
2850 auto condTy = dyn_cast<cir::VectorType>(getCondition().getType());
2851
2852 // If condition is not a vector, no further checks are needed.
2853 if (!condTy)
2854 return success();
2855
2856 // When condition is a vector, both other operands must also be vectors.
2857 if (!isa<cir::VectorType>(getTrueValue().getType()) ||
2858 !isa<cir::VectorType>(getFalseValue().getType())) {
2859 return emitOpError()
2860 << "expected both true and false operands to be vector types "
2861 "when the condition is a vector boolean type";
2862 }
2863
2864 return success();
2865}
2866
2867//===----------------------------------------------------------------------===//
2868// ShiftOp
2869//===----------------------------------------------------------------------===//
2870LogicalResult cir::ShiftOp::verify() {
2871 mlir::Operation *op = getOperation();
2872 auto op0VecTy = mlir::dyn_cast<cir::VectorType>(op->getOperand(0).getType());
2873 auto op1VecTy = mlir::dyn_cast<cir::VectorType>(op->getOperand(1).getType());
2874 if (!op0VecTy ^ !op1VecTy)
2875 return emitOpError() << "input types cannot be one vector and one scalar";
2876
2877 if (op0VecTy) {
2878 if (op0VecTy.getSize() != op1VecTy.getSize())
2879 return emitOpError() << "input vector types must have the same size";
2880
2881 auto opResultTy = mlir::dyn_cast<cir::VectorType>(getType());
2882 if (!opResultTy)
2883 return emitOpError() << "the type of the result must be a vector "
2884 << "if it is vector shift";
2885
2886 auto op0VecEleTy = mlir::cast<cir::IntType>(op0VecTy.getElementType());
2887 auto op1VecEleTy = mlir::cast<cir::IntType>(op1VecTy.getElementType());
2888 if (op0VecEleTy.getWidth() != op1VecEleTy.getWidth())
2889 return emitOpError()
2890 << "vector operands do not have the same elements sizes";
2891
2892 auto resVecEleTy = mlir::cast<cir::IntType>(opResultTy.getElementType());
2893 if (op0VecEleTy.getWidth() != resVecEleTy.getWidth())
2894 return emitOpError() << "vector operands and result type do not have the "
2895 "same elements sizes";
2896 }
2897
2898 return mlir::success();
2899}
2900
2901//===----------------------------------------------------------------------===//
2902// LabelOp Definitions
2903//===----------------------------------------------------------------------===//
2904
2905LogicalResult cir::LabelOp::verify() {
2906 mlir::Operation *op = getOperation();
2907 mlir::Block *blk = op->getBlock();
2908 if (&blk->front() != op)
2909 return emitError() << "must be the first operation in a block";
2910
2911 return mlir::success();
2912}
2913
2914//===----------------------------------------------------------------------===//
2915// IncOp
2916//===----------------------------------------------------------------------===//
2917
2918OpFoldResult cir::IncOp::fold(FoldAdaptor adaptor) {
2919 if (mlir::isa_and_present<cir::PoisonAttr>(adaptor.getInput()))
2920 return adaptor.getInput();
2921 return {};
2922}
2923
2924//===----------------------------------------------------------------------===//
2925// DecOp
2926//===----------------------------------------------------------------------===//
2927
2928OpFoldResult cir::DecOp::fold(FoldAdaptor adaptor) {
2929 if (mlir::isa_and_present<cir::PoisonAttr>(adaptor.getInput()))
2930 return adaptor.getInput();
2931 return {};
2932}
2933
2934//===----------------------------------------------------------------------===//
2935// MinusOp
2936//===----------------------------------------------------------------------===//
2937
2938OpFoldResult cir::MinusOp::fold(FoldAdaptor adaptor) {
2939 if (mlir::isa_and_present<cir::PoisonAttr>(adaptor.getInput()))
2940 return adaptor.getInput();
2941
2942 // Avoid materializing a duplicate constant for bool minus (identity).
2943 if (auto srcConst = getInput().getDefiningOp<cir::ConstantOp>())
2944 if (mlir::isa<cir::BoolType>(srcConst.getType()))
2945 return srcConst.getResult();
2946
2947 // Fold with constant inputs.
2948 if (mlir::Attribute attr = adaptor.getInput()) {
2949 if (auto intAttr = mlir::dyn_cast<cir::IntAttr>(attr)) {
2950 APInt val = intAttr.getValue();
2951 val.negate();
2952 return cir::IntAttr::get(getType(), val);
2953 }
2954 if (auto fpAttr = mlir::dyn_cast<cir::FPAttr>(attr)) {
2955 APFloat val = fpAttr.getValue();
2956 val.changeSign();
2957 return cir::FPAttr::get(getType(), val);
2958 }
2959 }
2960
2961 return {};
2962}
2963
2964//===----------------------------------------------------------------------===//
2965// NotOp
2966//===----------------------------------------------------------------------===//
2967
2968OpFoldResult cir::NotOp::fold(FoldAdaptor adaptor) {
2969 if (mlir::isa_and_present<cir::PoisonAttr>(adaptor.getInput()))
2970 return adaptor.getInput();
2971
2972 // not(not(x)) -> x is handled by the Involution trait.
2973
2974 // Fold with constant inputs.
2975 if (mlir::Attribute attr = adaptor.getInput()) {
2976 if (auto intAttr = mlir::dyn_cast<cir::IntAttr>(attr)) {
2977 APInt val = intAttr.getValue();
2978 val.flipAllBits();
2979 return cir::IntAttr::get(getType(), val);
2980 }
2981 if (auto boolAttr = mlir::dyn_cast<cir::BoolAttr>(attr))
2982 return cir::BoolAttr::get(getContext(), !boolAttr.getValue());
2983 }
2984
2985 return {};
2986}
2987
2988//===----------------------------------------------------------------------===//
2989// BaseDataMemberOp & DerivedDataMemberOp
2990//===----------------------------------------------------------------------===//
2991
2992static LogicalResult verifyMemberPtrCast(Operation *op, mlir::Value src,
2993 mlir::Type resultTy) {
2994 // Let the operand type be T1 C1::*, let the result type be T2 C2::*.
2995 // Verify that T1 and T2 are the same type.
2996 mlir::Type inputMemberTy;
2997 mlir::Type resultMemberTy;
2998 if (mlir::isa<cir::DataMemberType>(src.getType())) {
2999 inputMemberTy =
3000 mlir::cast<cir::DataMemberType>(src.getType()).getMemberTy();
3001 resultMemberTy = mlir::cast<cir::DataMemberType>(resultTy).getMemberTy();
3002 }
3004 if (inputMemberTy != resultMemberTy)
3005 return op->emitOpError()
3006 << "member types of the operand and the result do not match";
3007
3008 return mlir::success();
3009}
3010
3011LogicalResult cir::BaseDataMemberOp::verify() {
3012 return verifyMemberPtrCast(getOperation(), getSrc(), getType());
3013}
3014
3015LogicalResult cir::DerivedDataMemberOp::verify() {
3016 return verifyMemberPtrCast(getOperation(), getSrc(), getType());
3017}
3018
3019//===----------------------------------------------------------------------===//
3020// BaseMethodOp & DerivedMethodOp
3021//===----------------------------------------------------------------------===//
3022
3023LogicalResult cir::BaseMethodOp::verify() {
3024 return verifyMemberPtrCast(getOperation(), getSrc(), getType());
3025}
3026
3027LogicalResult cir::DerivedMethodOp::verify() {
3028 return verifyMemberPtrCast(getOperation(), getSrc(), getType());
3029}
3030
3031//===----------------------------------------------------------------------===//
3032// AwaitOp
3033//===----------------------------------------------------------------------===//
3034
3035void cir::AwaitOp::build(OpBuilder &builder, OperationState &result,
3036 cir::AwaitKind kind, BuilderCallbackRef readyBuilder,
3037 BuilderCallbackRef suspendBuilder,
3038 BuilderCallbackRef resumeBuilder) {
3039 result.addAttribute(getKindAttrName(result.name),
3040 cir::AwaitKindAttr::get(builder.getContext(), kind));
3041 {
3042 OpBuilder::InsertionGuard guard(builder);
3043 Region *readyRegion = result.addRegion();
3044 builder.createBlock(readyRegion);
3045 readyBuilder(builder, result.location);
3046 }
3047
3048 {
3049 OpBuilder::InsertionGuard guard(builder);
3050 Region *suspendRegion = result.addRegion();
3051 builder.createBlock(suspendRegion);
3052 suspendBuilder(builder, result.location);
3053 }
3054
3055 {
3056 OpBuilder::InsertionGuard guard(builder);
3057 Region *resumeRegion = result.addRegion();
3058 builder.createBlock(resumeRegion);
3059 resumeBuilder(builder, result.location);
3060 }
3061}
3062
3063void cir::AwaitOp::getSuccessorRegions(
3064 mlir::RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
3065 // If any index all the underlying regions branch back to the parent
3066 // operation.
3067 if (!point.isParent()) {
3068 regions.push_back(RegionSuccessor::parent());
3069 return;
3070 }
3071
3072 // TODO: retrieve information from the promise and only push the
3073 // necessary ones. Example: `std::suspend_never` on initial or final
3074 // await's might allow suspend region to be skipped.
3075 regions.push_back(RegionSuccessor(&this->getReady()));
3076 regions.push_back(RegionSuccessor(&this->getSuspend()));
3077 regions.push_back(RegionSuccessor(&this->getResume()));
3078}
3079
3080mlir::ValueRange cir::AwaitOp::getSuccessorInputs(RegionSuccessor successor) {
3081 if (successor.isParent())
3082 return getOperation()->getResults();
3083 if (successor == &getReady())
3084 return getReady().getArguments();
3085 if (successor == &getSuspend())
3086 return getSuspend().getArguments();
3087 if (successor == &getResume())
3088 return getResume().getArguments();
3089 llvm_unreachable("invalid region successor");
3090}
3091
3092LogicalResult cir::AwaitOp::verify() {
3093 if (!isa<ConditionOp>(this->getReady().back().getTerminator()))
3094 return emitOpError("ready region must end with cir.condition");
3095 return success();
3096}
3097
3098LogicalResult cir::CoReturnOp::verify() {
3099 if (!getOperation()->getParentOfType<CoroBodyOp>())
3100 return emitOpError("must be inside a cir.coro.body");
3101 return success();
3102}
3103
3104//===----------------------------------------------------------------------===//
3105// CoroBody
3106//===----------------------------------------------------------------------===//
3107
3108void cir::CoroBodyOp::getSuccessorRegions(
3109 mlir::RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
3110 if (!point.isParent()) {
3111 regions.push_back(RegionSuccessor::parent());
3112 return;
3113 }
3114
3115 regions.push_back(RegionSuccessor(&getBody()));
3116}
3117
3118mlir::ValueRange
3119cir::CoroBodyOp::getSuccessorInputs(RegionSuccessor successor) {
3120 return ValueRange();
3121}
3122
3123LogicalResult cir::CoroBodyOp::verify() {
3124 if (!getOperation()->getParentOfType<FuncOp>().getCoroutine())
3125 return emitOpError("enclosing function must be a coroutine");
3126 return success();
3127}
3128
3129void cir::CoroBodyOp::build(OpBuilder &builder, OperationState &result,
3130 BuilderCallbackRef bodyBuilder) {
3131 assert(bodyBuilder &&
3132 "the builder callback for 'CoroBodyOp' must be present");
3133 OpBuilder::InsertionGuard guard(builder);
3134
3135 Region *bodyRegion = result.addRegion();
3136 builder.createBlock(bodyRegion);
3137 bodyBuilder(builder, result.location);
3138}
3139
3140//===----------------------------------------------------------------------===//
3141// CopyOp Definitions
3142//===----------------------------------------------------------------------===//
3143
3144LogicalResult cir::CopyOp::verify() {
3145 // A data layout is required for us to know the number of bytes to be copied.
3146 if (!getType().getPointee().hasTrait<DataLayoutTypeInterface::Trait>())
3147 return emitError() << "missing data layout for pointee type";
3148
3149 if (getSkipTailPadding() &&
3150 !mlir::isa<cir::RecordType>(getType().getPointee()))
3151 return emitError()
3152 << "skip_tail_padding is only valid for record pointee types";
3153
3154 return mlir::success();
3155}
3156
3157//===----------------------------------------------------------------------===//
3158// GetRuntimeMemberOp Definitions
3159//===----------------------------------------------------------------------===//
3160
3161LogicalResult cir::GetRuntimeMemberOp::verify() {
3162 auto recordTy = mlir::cast<RecordType>(getAddr().getType().getPointee());
3163 cir::DataMemberType memberPtrTy = getMember().getType();
3164
3165 if (recordTy != memberPtrTy.getClassTy())
3166 return emitError() << "record type does not match the member pointer type";
3167 if (getType().getPointee() != memberPtrTy.getMemberTy())
3168 return emitError() << "result type does not match the member pointer type";
3169 return mlir::success();
3170}
3171
3172//===----------------------------------------------------------------------===//
3173// GetMethodOp Definitions
3174//===----------------------------------------------------------------------===//
3175
3176LogicalResult cir::GetMethodOp::verify() {
3177 cir::MethodType methodTy = getMethod().getType();
3178
3179 // Assume objectTy is !cir.ptr<!T>
3180 cir::PointerType objectPtrTy = getObject().getType();
3181 mlir::Type objectTy = objectPtrTy.getPointee();
3182
3183 if (methodTy.getClassTy() != objectTy)
3184 return emitError() << "method class type and object type do not match";
3185
3186 // Assume methodFuncTy is !cir.func<!Ret (!Args)>
3187 auto calleeTy = mlir::cast<cir::FuncType>(getCallee().getType().getPointee());
3188 cir::FuncType methodFuncTy = methodTy.getMemberFuncTy();
3189
3190 // We verify at here that calleeTy is !cir.func<!Ret (!cir.ptr<!void>, !Args)>
3191 // Note that the first parameter type of the callee is !cir.ptr<!void> instead
3192 // of !cir.ptr<!T> because the "this" pointer may be adjusted before calling
3193 // the callee.
3194
3195 if (methodFuncTy.getReturnType() != calleeTy.getReturnType())
3196 return emitError()
3197 << "method return type and callee return type do not match";
3198
3199 llvm::ArrayRef<mlir::Type> calleeArgsTy = calleeTy.getInputs();
3200 llvm::ArrayRef<mlir::Type> methodFuncArgsTy = methodFuncTy.getInputs();
3201
3202 if (calleeArgsTy.empty())
3203 return emitError() << "callee parameter list lacks receiver object ptr";
3204
3205 auto calleeThisArgPtrTy = mlir::dyn_cast<cir::PointerType>(calleeArgsTy[0]);
3206 if (!calleeThisArgPtrTy ||
3207 !mlir::isa<cir::VoidType>(calleeThisArgPtrTy.getPointee())) {
3208 return emitError()
3209 << "the first parameter of callee must be a void pointer";
3210 }
3211
3212 if (calleeArgsTy.slice(1) != methodFuncArgsTy)
3213 return emitError()
3214 << "callee parameters and method parameters do not match";
3215
3216 return mlir::success();
3217}
3218
3219//===----------------------------------------------------------------------===//
3220// GetMemberOp Definitions
3221//===----------------------------------------------------------------------===//
3222
3223LogicalResult cir::GetMemberOp::verify() {
3224 const auto recordTy = dyn_cast<RecordType>(getAddrTy().getPointee());
3225 if (!recordTy)
3226 return emitError() << "expected pointer to a record type";
3227
3228 if (recordTy.getMembers().size() <= getIndex())
3229 return emitError() << "member index out of bounds";
3230
3231 if (recordTy.getMembers()[getIndex()] != getType().getPointee())
3232 return emitError() << "member type mismatch";
3233
3234 return mlir::success();
3235}
3236
3237//===----------------------------------------------------------------------===//
3238// ExtractMemberOp Definitions
3239//===----------------------------------------------------------------------===//
3240
3241LogicalResult cir::ExtractMemberOp::verify() {
3242 auto recordTy = mlir::cast<cir::RecordType>(getRecord().getType());
3243 if (recordTy.getKind() == cir::RecordType::Union)
3244 return emitError()
3245 << "cir.extract_member currently does not support unions";
3246 if (recordTy.getMembers().size() <= getIndex())
3247 return emitError() << "member index out of bounds";
3248 if (recordTy.getMembers()[getIndex()] != getType())
3249 return emitError() << "member type mismatch";
3250 return mlir::success();
3251}
3252
3253//===----------------------------------------------------------------------===//
3254// InsertMemberOp Definitions
3255//===----------------------------------------------------------------------===//
3256
3257LogicalResult cir::InsertMemberOp::verify() {
3258 auto recordTy = mlir::cast<cir::RecordType>(getRecord().getType());
3259 if (recordTy.getKind() == cir::RecordType::Union)
3260 return emitError() << "cir.insert_member currently does not support unions";
3261 if (recordTy.getMembers().size() <= getIndex())
3262 return emitError() << "member index out of bounds";
3263 if (recordTy.getMembers()[getIndex()] != getValue().getType())
3264 return emitError() << "member type mismatch";
3265 // The op trait already checks that the types of $result and $record match.
3266 return mlir::success();
3267}
3268
3269//===----------------------------------------------------------------------===//
3270// VecCreateOp
3271//===----------------------------------------------------------------------===//
3272
3273OpFoldResult cir::VecCreateOp::fold(FoldAdaptor adaptor) {
3274 if (llvm::any_of(getElements(), [](mlir::Value value) {
3275 return !value.getDefiningOp<cir::ConstantOp>();
3276 }))
3277 return {};
3278
3279 return cir::ConstVectorAttr::get(
3280 getType(), mlir::ArrayAttr::get(getContext(), adaptor.getElements()));
3281}
3282
3283LogicalResult cir::VecCreateOp::verify() {
3284 // Verify that the number of arguments matches the number of elements in the
3285 // vector, and that the type of all the arguments matches the type of the
3286 // elements in the vector.
3287 const cir::VectorType vecTy = getType();
3288 if (getElements().size() != vecTy.getSize()) {
3289 return emitOpError() << "operand count of " << getElements().size()
3290 << " doesn't match vector type " << vecTy
3291 << " element count of " << vecTy.getSize();
3292 }
3293
3294 const mlir::Type elementType = vecTy.getElementType();
3295 for (const mlir::Value element : getElements()) {
3296 if (element.getType() != elementType) {
3297 return emitOpError() << "operand type " << element.getType()
3298 << " doesn't match vector element type "
3299 << elementType;
3300 }
3301 }
3302
3303 return success();
3304}
3305
3306//===----------------------------------------------------------------------===//
3307// VecExtractOp
3308//===----------------------------------------------------------------------===//
3309
3310OpFoldResult cir::VecExtractOp::fold(FoldAdaptor adaptor) {
3311 const auto vectorAttr =
3312 llvm::dyn_cast_if_present<cir::ConstVectorAttr>(adaptor.getVec());
3313 if (!vectorAttr)
3314 return {};
3315
3316 const auto indexAttr =
3317 llvm::dyn_cast_if_present<cir::IntAttr>(adaptor.getIndex());
3318 if (!indexAttr)
3319 return {};
3320
3321 const mlir::ArrayAttr elements = vectorAttr.getElts();
3322 const uint64_t index = indexAttr.getUInt();
3323 if (index >= elements.size())
3324 return {};
3325
3326 return elements[index];
3327}
3328
3329//===----------------------------------------------------------------------===//
3330// VecCmpOp
3331//===----------------------------------------------------------------------===//
3332
3333OpFoldResult cir::VecCmpOp::fold(FoldAdaptor adaptor) {
3334 auto lhsVecAttr =
3335 mlir::dyn_cast_if_present<cir::ConstVectorAttr>(adaptor.getLhs());
3336 auto rhsVecAttr =
3337 mlir::dyn_cast_if_present<cir::ConstVectorAttr>(adaptor.getRhs());
3338 if (!lhsVecAttr || !rhsVecAttr)
3339 return {};
3340
3341 mlir::Type inputElemTy =
3342 mlir::cast<cir::VectorType>(lhsVecAttr.getType()).getElementType();
3343 if (!isAnyIntegerOrFloatingPointType(inputElemTy))
3344 return {};
3345
3346 cir::CmpOpKind opKind = adaptor.getKind();
3347 mlir::ArrayAttr lhsVecElhs = lhsVecAttr.getElts();
3348 mlir::ArrayAttr rhsVecElhs = rhsVecAttr.getElts();
3349 uint64_t vecSize = lhsVecElhs.size();
3350
3351 SmallVector<mlir::Attribute, 16> elements(vecSize);
3352 bool isIntAttr = vecSize && mlir::isa<cir::IntAttr>(lhsVecElhs[0]);
3353 for (uint64_t i = 0; i < vecSize; i++) {
3354 mlir::Attribute lhsAttr = lhsVecElhs[i];
3355 mlir::Attribute rhsAttr = rhsVecElhs[i];
3356 int cmpResult = 0;
3357 switch (opKind) {
3358 case cir::CmpOpKind::lt: {
3359 if (isIntAttr) {
3360 cmpResult = mlir::cast<cir::IntAttr>(lhsAttr).getSInt() <
3361 mlir::cast<cir::IntAttr>(rhsAttr).getSInt();
3362 } else {
3363 cmpResult = mlir::cast<cir::FPAttr>(lhsAttr).getValue() <
3364 mlir::cast<cir::FPAttr>(rhsAttr).getValue();
3365 }
3366 break;
3367 }
3368 case cir::CmpOpKind::le: {
3369 if (isIntAttr) {
3370 cmpResult = mlir::cast<cir::IntAttr>(lhsAttr).getSInt() <=
3371 mlir::cast<cir::IntAttr>(rhsAttr).getSInt();
3372 } else {
3373 cmpResult = mlir::cast<cir::FPAttr>(lhsAttr).getValue() <=
3374 mlir::cast<cir::FPAttr>(rhsAttr).getValue();
3375 }
3376 break;
3377 }
3378 case cir::CmpOpKind::gt: {
3379 if (isIntAttr) {
3380 cmpResult = mlir::cast<cir::IntAttr>(lhsAttr).getSInt() >
3381 mlir::cast<cir::IntAttr>(rhsAttr).getSInt();
3382 } else {
3383 cmpResult = mlir::cast<cir::FPAttr>(lhsAttr).getValue() >
3384 mlir::cast<cir::FPAttr>(rhsAttr).getValue();
3385 }
3386 break;
3387 }
3388 case cir::CmpOpKind::ge: {
3389 if (isIntAttr) {
3390 cmpResult = mlir::cast<cir::IntAttr>(lhsAttr).getSInt() >=
3391 mlir::cast<cir::IntAttr>(rhsAttr).getSInt();
3392 } else {
3393 cmpResult = mlir::cast<cir::FPAttr>(lhsAttr).getValue() >=
3394 mlir::cast<cir::FPAttr>(rhsAttr).getValue();
3395 }
3396 break;
3397 }
3398 case cir::CmpOpKind::eq: {
3399 if (isIntAttr) {
3400 cmpResult = mlir::cast<cir::IntAttr>(lhsAttr).getSInt() ==
3401 mlir::cast<cir::IntAttr>(rhsAttr).getSInt();
3402 } else {
3403 cmpResult = mlir::cast<cir::FPAttr>(lhsAttr).getValue() ==
3404 mlir::cast<cir::FPAttr>(rhsAttr).getValue();
3405 }
3406 break;
3407 }
3408 case cir::CmpOpKind::ne: {
3409 if (isIntAttr) {
3410 cmpResult = mlir::cast<cir::IntAttr>(lhsAttr).getSInt() !=
3411 mlir::cast<cir::IntAttr>(rhsAttr).getSInt();
3412 } else {
3413 cmpResult = mlir::cast<cir::FPAttr>(lhsAttr).getValue() !=
3414 mlir::cast<cir::FPAttr>(rhsAttr).getValue();
3415 }
3416 break;
3417 }
3418 case cir::CmpOpKind::one: {
3419 llvm::APFloat::cmpResult cr =
3420 mlir::cast<cir::FPAttr>(lhsAttr).getValue().compare(
3421 mlir::cast<cir::FPAttr>(rhsAttr).getValue());
3422 cmpResult =
3423 cr != llvm::APFloat::cmpUnordered && cr != llvm::APFloat::cmpEqual;
3424 break;
3425 }
3426 case cir::CmpOpKind::uno: {
3427 cmpResult = mlir::cast<cir::FPAttr>(lhsAttr).getValue().compare(
3428 mlir::cast<cir::FPAttr>(rhsAttr).getValue()) ==
3429 llvm::APFloat::cmpUnordered;
3430 break;
3431 }
3432 }
3433
3434 elements[i] = cir::IntAttr::get(getType().getElementType(), cmpResult);
3435 }
3436
3437 return cir::ConstVectorAttr::get(
3438 getType(), mlir::ArrayAttr::get(getContext(), elements));
3439}
3440
3441//===----------------------------------------------------------------------===//
3442// VecShuffleOp
3443//===----------------------------------------------------------------------===//
3444
3445OpFoldResult cir::VecShuffleOp::fold(FoldAdaptor adaptor) {
3446 auto vec1Attr =
3447 mlir::dyn_cast_if_present<cir::ConstVectorAttr>(adaptor.getVec1());
3448 auto vec2Attr =
3449 mlir::dyn_cast_if_present<cir::ConstVectorAttr>(adaptor.getVec2());
3450 if (!vec1Attr || !vec2Attr)
3451 return {};
3452
3453 mlir::Type vec1ElemTy =
3454 mlir::cast<cir::VectorType>(vec1Attr.getType()).getElementType();
3455
3456 mlir::ArrayAttr vec1Elts = vec1Attr.getElts();
3457 mlir::ArrayAttr vec2Elts = vec2Attr.getElts();
3458 mlir::ArrayAttr indicesElts = adaptor.getIndices();
3459
3461 elements.reserve(indicesElts.size());
3462
3463 uint64_t vec1Size = vec1Elts.size();
3464 for (const auto &idxAttr : indicesElts.getAsRange<cir::IntAttr>()) {
3465 if (idxAttr.getSInt() == -1) {
3466 elements.push_back(cir::UndefAttr::get(vec1ElemTy));
3467 continue;
3468 }
3469
3470 uint64_t idxValue = idxAttr.getUInt();
3471 elements.push_back(idxValue < vec1Size ? vec1Elts[idxValue]
3472 : vec2Elts[idxValue - vec1Size]);
3473 }
3474
3475 return cir::ConstVectorAttr::get(
3476 getType(), mlir::ArrayAttr::get(getContext(), elements));
3477}
3478
3479LogicalResult cir::VecShuffleOp::verify() {
3480 // The number of elements in the indices array must match the number of
3481 // elements in the result type.
3482 if (getIndices().size() != getResult().getType().getSize()) {
3483 return emitOpError() << ": the number of elements in " << getIndices()
3484 << " and " << getResult().getType() << " don't match";
3485 }
3486
3487 // The element types of the two input vectors and of the result type must
3488 // match.
3489 if (getVec1().getType().getElementType() !=
3490 getResult().getType().getElementType()) {
3491 return emitOpError() << ": element types of " << getVec1().getType()
3492 << " and " << getResult().getType() << " don't match";
3493 }
3494
3495 const uint64_t maxValidIndex =
3496 getVec1().getType().getSize() + getVec2().getType().getSize() - 1;
3497 if (llvm::any_of(
3498 getIndices().getAsRange<cir::IntAttr>(), [&](cir::IntAttr idxAttr) {
3499 return idxAttr.getSInt() != -1 && idxAttr.getUInt() > maxValidIndex;
3500 })) {
3501 return emitOpError() << ": index for __builtin_shufflevector must be "
3502 "less than the total number of vector elements";
3503 }
3504 return success();
3505}
3506
3507//===----------------------------------------------------------------------===//
3508// VecShuffleDynamicOp
3509//===----------------------------------------------------------------------===//
3510
3511OpFoldResult cir::VecShuffleDynamicOp::fold(FoldAdaptor adaptor) {
3512 mlir::Attribute vec = adaptor.getVec();
3513 mlir::Attribute indices = adaptor.getIndices();
3514 if (mlir::isa_and_nonnull<cir::ConstVectorAttr>(vec) &&
3515 mlir::isa_and_nonnull<cir::ConstVectorAttr>(indices)) {
3516 auto vecAttr = mlir::cast<cir::ConstVectorAttr>(vec);
3517 auto indicesAttr = mlir::cast<cir::ConstVectorAttr>(indices);
3518
3519 mlir::ArrayAttr vecElts = vecAttr.getElts();
3520 mlir::ArrayAttr indicesElts = indicesAttr.getElts();
3521
3522 const uint64_t numElements = vecElts.size();
3523
3525 elements.reserve(numElements);
3526
3527 const uint64_t maskBits = llvm::NextPowerOf2(numElements - 1) - 1;
3528 for (const auto &idxAttr : indicesElts.getAsRange<cir::IntAttr>()) {
3529 uint64_t idxValue = idxAttr.getUInt();
3530 uint64_t newIdx = idxValue & maskBits;
3531 elements.push_back(vecElts[newIdx]);
3532 }
3533
3534 return cir::ConstVectorAttr::get(
3535 getType(), mlir::ArrayAttr::get(getContext(), elements));
3536 }
3537
3538 return {};
3539}
3540
3541LogicalResult cir::VecShuffleDynamicOp::verify() {
3542 // The number of elements in the two input vectors must match.
3543 if (getVec().getType().getSize() !=
3544 mlir::cast<cir::VectorType>(getIndices().getType()).getSize()) {
3545 return emitOpError() << ": the number of elements in " << getVec().getType()
3546 << " and " << getIndices().getType() << " don't match";
3547 }
3548 return success();
3549}
3550
3551//===----------------------------------------------------------------------===//
3552// VecTernaryOp
3553//===----------------------------------------------------------------------===//
3554
3555LogicalResult cir::VecTernaryOp::verify() {
3556 // Verify that the condition operand has the same number of elements as the
3557 // other operands. (The automatic verification already checked that all
3558 // operands are vector types and that the second and third operands are the
3559 // same type.)
3560 if (getCond().getType().getSize() != getLhs().getType().getSize()) {
3561 return emitOpError() << ": the number of elements in "
3562 << getCond().getType() << " and " << getLhs().getType()
3563 << " don't match";
3564 }
3565 return success();
3566}
3567
3568OpFoldResult cir::VecTernaryOp::fold(FoldAdaptor adaptor) {
3569 mlir::Attribute cond = adaptor.getCond();
3570 mlir::Attribute lhs = adaptor.getLhs();
3571 mlir::Attribute rhs = adaptor.getRhs();
3572
3573 if (!mlir::isa_and_nonnull<cir::ConstVectorAttr>(cond) ||
3574 !mlir::isa_and_nonnull<cir::ConstVectorAttr>(lhs) ||
3575 !mlir::isa_and_nonnull<cir::ConstVectorAttr>(rhs))
3576 return {};
3577 auto condVec = mlir::cast<cir::ConstVectorAttr>(cond);
3578 auto lhsVec = mlir::cast<cir::ConstVectorAttr>(lhs);
3579 auto rhsVec = mlir::cast<cir::ConstVectorAttr>(rhs);
3580
3581 mlir::ArrayAttr condElts = condVec.getElts();
3582
3584 elements.reserve(condElts.size());
3585
3586 for (const auto &[idx, condAttr] :
3587 llvm::enumerate(condElts.getAsRange<cir::IntAttr>())) {
3588 if (condAttr.getSInt()) {
3589 elements.push_back(lhsVec.getElts()[idx]);
3590 } else {
3591 elements.push_back(rhsVec.getElts()[idx]);
3592 }
3593 }
3594
3595 cir::VectorType vecTy = getLhs().getType();
3596 return cir::ConstVectorAttr::get(
3597 vecTy, mlir::ArrayAttr::get(getContext(), elements));
3598}
3599
3600//===----------------------------------------------------------------------===//
3601// ComplexCreateOp
3602//===----------------------------------------------------------------------===//
3603
3604LogicalResult cir::ComplexCreateOp::verify() {
3605 if (getType().getElementType() != getReal().getType()) {
3606 emitOpError()
3607 << "operand type of cir.complex.create does not match its result type";
3608 return failure();
3609 }
3610
3611 return success();
3612}
3613
3614OpFoldResult cir::ComplexCreateOp::fold(FoldAdaptor adaptor) {
3615 mlir::Attribute real = adaptor.getReal();
3616 mlir::Attribute imag = adaptor.getImag();
3617 if (!real || !imag)
3618 return {};
3619
3620 // When both of real and imag are constants, we can fold the operation into an
3621 // `#cir.const_complex` operation.
3622 auto realAttr = mlir::cast<mlir::TypedAttr>(real);
3623 auto imagAttr = mlir::cast<mlir::TypedAttr>(imag);
3624 return cir::ConstComplexAttr::get(realAttr, imagAttr);
3625}
3626
3627//===----------------------------------------------------------------------===//
3628// ComplexRealOp
3629//===----------------------------------------------------------------------===//
3630
3631LogicalResult cir::ComplexRealOp::verify() {
3632 mlir::Type operandTy = getOperand().getType();
3633 if (auto complexOperandTy = mlir::dyn_cast<cir::ComplexType>(operandTy))
3634 operandTy = complexOperandTy.getElementType();
3635
3636 if (getType() != operandTy) {
3637 emitOpError() << ": result type does not match operand type";
3638 return failure();
3639 }
3640
3641 return success();
3642}
3643
3644OpFoldResult cir::ComplexRealOp::fold(FoldAdaptor adaptor) {
3645 if (!mlir::isa<cir::ComplexType>(getOperand().getType()))
3646 return nullptr;
3647
3648 if (auto complexCreateOp = getOperand().getDefiningOp<cir::ComplexCreateOp>())
3649 return complexCreateOp.getOperand(0);
3650
3651 auto complex =
3652 mlir::cast_if_present<cir::ConstComplexAttr>(adaptor.getOperand());
3653 return complex ? complex.getReal() : nullptr;
3654}
3655
3656//===----------------------------------------------------------------------===//
3657// ComplexImagOp
3658//===----------------------------------------------------------------------===//
3659
3660LogicalResult cir::ComplexImagOp::verify() {
3661 mlir::Type operandTy = getOperand().getType();
3662 if (auto complexOperandTy = mlir::dyn_cast<cir::ComplexType>(operandTy))
3663 operandTy = complexOperandTy.getElementType();
3664
3665 if (getType() != operandTy) {
3666 emitOpError() << ": result type does not match operand type";
3667 return failure();
3668 }
3669
3670 return success();
3671}
3672
3673OpFoldResult cir::ComplexImagOp::fold(FoldAdaptor adaptor) {
3674 if (!mlir::isa<cir::ComplexType>(getOperand().getType()))
3675 return nullptr;
3676
3677 if (auto complexCreateOp = getOperand().getDefiningOp<cir::ComplexCreateOp>())
3678 return complexCreateOp.getOperand(1);
3679
3680 auto complex =
3681 mlir::cast_if_present<cir::ConstComplexAttr>(adaptor.getOperand());
3682 return complex ? complex.getImag() : nullptr;
3683}
3684
3685//===----------------------------------------------------------------------===//
3686// ComplexRealPtrOp
3687//===----------------------------------------------------------------------===//
3688
3689LogicalResult cir::ComplexRealPtrOp::verify() {
3690 mlir::Type resultPointeeTy = getType().getPointee();
3691 cir::PointerType operandPtrTy = getOperand().getType();
3692 auto operandPointeeTy =
3693 mlir::cast<cir::ComplexType>(operandPtrTy.getPointee());
3694
3695 if (resultPointeeTy != operandPointeeTy.getElementType()) {
3696 return emitOpError() << ": result type does not match operand type";
3697 }
3698
3699 return success();
3700}
3701
3702//===----------------------------------------------------------------------===//
3703// ComplexImagPtrOp
3704//===----------------------------------------------------------------------===//
3705
3706LogicalResult cir::ComplexImagPtrOp::verify() {
3707 mlir::Type resultPointeeTy = getType().getPointee();
3708 cir::PointerType operandPtrTy = getOperand().getType();
3709 auto operandPointeeTy =
3710 mlir::cast<cir::ComplexType>(operandPtrTy.getPointee());
3711
3712 if (resultPointeeTy != operandPointeeTy.getElementType()) {
3713 return emitOpError()
3714 << "cir.complex.imag_ptr result type does not match operand type";
3715 }
3716 return success();
3717}
3718
3719//===----------------------------------------------------------------------===//
3720// Bit manipulation operations
3721//===----------------------------------------------------------------------===//
3722
3723static OpFoldResult
3724foldUnaryBitOp(mlir::Attribute inputAttr,
3725 llvm::function_ref<llvm::APInt(const llvm::APInt &)> func,
3726 bool poisonZero = false) {
3727 if (mlir::isa_and_present<cir::PoisonAttr>(inputAttr)) {
3728 // Propagate poison value
3729 return inputAttr;
3730 }
3731
3732 auto input = mlir::dyn_cast_if_present<IntAttr>(inputAttr);
3733 if (!input)
3734 return nullptr;
3735
3736 llvm::APInt inputValue = input.getValue();
3737 if (poisonZero && inputValue.isZero())
3738 return cir::PoisonAttr::get(input.getType());
3739
3740 llvm::APInt resultValue = func(inputValue);
3741 return IntAttr::get(input.getType(), resultValue);
3742}
3743
3744OpFoldResult BitClrsbOp::fold(FoldAdaptor adaptor) {
3745 return foldUnaryBitOp(adaptor.getInput(), [](const llvm::APInt &inputValue) {
3746 unsigned resultValue =
3747 inputValue.getBitWidth() - inputValue.getSignificantBits();
3748 return llvm::APInt(inputValue.getBitWidth(), resultValue);
3749 });
3750}
3751
3752OpFoldResult BitClzOp::fold(FoldAdaptor adaptor) {
3753 return foldUnaryBitOp(
3754 adaptor.getInput(),
3755 [](const llvm::APInt &inputValue) {
3756 unsigned resultValue = inputValue.countLeadingZeros();
3757 return llvm::APInt(inputValue.getBitWidth(), resultValue);
3758 },
3759 getPoisonZero());
3760}
3761
3762OpFoldResult BitCtzOp::fold(FoldAdaptor adaptor) {
3763 return foldUnaryBitOp(
3764 adaptor.getInput(),
3765 [](const llvm::APInt &inputValue) {
3766 return llvm::APInt(inputValue.getBitWidth(),
3767 inputValue.countTrailingZeros());
3768 },
3769 getPoisonZero());
3770}
3771
3772OpFoldResult BitFfsOp::fold(FoldAdaptor adaptor) {
3773 return foldUnaryBitOp(adaptor.getInput(), [](const llvm::APInt &inputValue) {
3774 unsigned trailingZeros = inputValue.countTrailingZeros();
3775 unsigned result =
3776 trailingZeros == inputValue.getBitWidth() ? 0 : trailingZeros + 1;
3777 return llvm::APInt(inputValue.getBitWidth(), result);
3778 });
3779}
3780
3781OpFoldResult BitParityOp::fold(FoldAdaptor adaptor) {
3782 return foldUnaryBitOp(adaptor.getInput(), [](const llvm::APInt &inputValue) {
3783 return llvm::APInt(inputValue.getBitWidth(), inputValue.popcount() % 2);
3784 });
3785}
3786
3787OpFoldResult BitPopcountOp::fold(FoldAdaptor adaptor) {
3788 return foldUnaryBitOp(adaptor.getInput(), [](const llvm::APInt &inputValue) {
3789 return llvm::APInt(inputValue.getBitWidth(), inputValue.popcount());
3790 });
3791}
3792
3793OpFoldResult BitReverseOp::fold(FoldAdaptor adaptor) {
3794 return foldUnaryBitOp(adaptor.getInput(), [](const llvm::APInt &inputValue) {
3795 return inputValue.reverseBits();
3796 });
3797}
3798
3799OpFoldResult ByteSwapOp::fold(FoldAdaptor adaptor) {
3800 return foldUnaryBitOp(adaptor.getInput(), [](const llvm::APInt &inputValue) {
3801 return inputValue.byteSwap();
3802 });
3803}
3804
3805OpFoldResult RotateOp::fold(FoldAdaptor adaptor) {
3806 if (mlir::isa_and_present<cir::PoisonAttr>(adaptor.getInput()) ||
3807 mlir::isa_and_present<cir::PoisonAttr>(adaptor.getAmount())) {
3808 // Propagate poison values
3809 return cir::PoisonAttr::get(getType());
3810 }
3811
3812 auto input = mlir::dyn_cast_if_present<IntAttr>(adaptor.getInput());
3813 auto amount = mlir::dyn_cast_if_present<IntAttr>(adaptor.getAmount());
3814 if (!input && !amount)
3815 return nullptr;
3816
3817 // We could fold cir.rotate even if one of its two operands is not a constant:
3818 // - `cir.rotate left/right %0, 0` could be folded into just %0 even if %0
3819 // is not a constant.
3820 // - `cir.rotate left/right 0/0b111...111, %0` could be folded into 0 or
3821 // 0b111...111 even if %0 is not a constant.
3822
3823 llvm::APInt inputValue;
3824 if (input) {
3825 inputValue = input.getValue();
3826 if (inputValue.isZero() || inputValue.isAllOnes()) {
3827 // An input value of all 0s or all 1s will not change after rotation
3828 return input;
3829 }
3830 }
3831
3832 uint64_t amountValue;
3833 if (amount) {
3834 amountValue = amount.getValue().urem(getInput().getType().getWidth());
3835 if (amountValue == 0) {
3836 // A shift amount of 0 will not change the input value
3837 return getInput();
3838 }
3839 }
3840
3841 if (!input || !amount)
3842 return nullptr;
3843
3844 assert(inputValue.getBitWidth() == getInput().getType().getWidth() &&
3845 "input value must have the same bit width as the input type");
3846
3847 llvm::APInt resultValue;
3848 if (isRotateLeft())
3849 resultValue = inputValue.rotl(amountValue);
3850 else
3851 resultValue = inputValue.rotr(amountValue);
3852
3853 return IntAttr::get(input.getContext(), input.getType(), resultValue);
3854}
3855
3856//===----------------------------------------------------------------------===//
3857// InlineAsmOp
3858//===----------------------------------------------------------------------===//
3859
3860void cir::InlineAsmOp::print(OpAsmPrinter &p) {
3861 p << '(' << getAsmFlavor() << ", ";
3862 p.increaseIndent();
3863 p.printNewline();
3864
3865 llvm::SmallVector<std::string, 3> names{"out", "in", "in_out"};
3866 auto *nameIt = names.begin();
3867 auto *attrIt = getOperandAttrs().begin();
3868
3869 for (mlir::OperandRange ops : getAsmOperands()) {
3870 p << *nameIt << " = ";
3871
3872 p << '[';
3873 llvm::interleaveComma(llvm::make_range(ops.begin(), ops.end()), p,
3874 [&](Value value) {
3875 p.printOperand(value);
3876 p << " : " << value.getType();
3877 if (mlir::isa<mlir::UnitAttr>(*attrIt))
3878 p << " (maybe_memory)";
3879 attrIt++;
3880 });
3881 p << "],";
3882 p.printNewline();
3883 ++nameIt;
3884 }
3885
3886 p << "{";
3887 p.printString(getAsmString());
3888 p << " ";
3889 p.printString(getConstraints());
3890 p << "}";
3891 p.decreaseIndent();
3892 p << ')';
3893 if (getSideEffects())
3894 p << " side_effects";
3895
3896 std::array elidedAttrs{
3897 llvm::StringRef("asm_flavor"), llvm::StringRef("asm_string"),
3898 llvm::StringRef("constraints"), llvm::StringRef("operand_attrs"),
3899 llvm::StringRef("operands_segments"), llvm::StringRef("side_effects")};
3900 p.printOptionalAttrDict(getOperation()->getAttrs(), elidedAttrs);
3901
3902 if (auto v = getRes())
3903 p << " -> " << v.getType();
3904}
3905
3906void cir::InlineAsmOp::build(OpBuilder &odsBuilder, OperationState &odsState,
3907 ArrayRef<ValueRange> asmOperands,
3908 StringRef asmString, StringRef constraints,
3909 bool sideEffects, cir::AsmFlavor asmFlavor,
3910 ArrayRef<Attribute> operandAttrs) {
3911 // Set up the operands_segments for VariadicOfVariadic
3912 SmallVector<int32_t> segments;
3913 for (auto operandRange : asmOperands) {
3914 segments.push_back(operandRange.size());
3915 odsState.addOperands(operandRange);
3916 }
3917
3918 odsState.addAttribute(
3919 "operands_segments",
3920 DenseI32ArrayAttr::get(odsBuilder.getContext(), segments));
3921 odsState.addAttribute("asm_string", odsBuilder.getStringAttr(asmString));
3922 odsState.addAttribute("constraints", odsBuilder.getStringAttr(constraints));
3923 odsState.addAttribute("asm_flavor",
3924 AsmFlavorAttr::get(odsBuilder.getContext(), asmFlavor));
3925
3926 if (sideEffects)
3927 odsState.addAttribute("side_effects", odsBuilder.getUnitAttr());
3928
3929 odsState.addAttribute("operand_attrs", odsBuilder.getArrayAttr(operandAttrs));
3930}
3931
/// Parses the custom form of cir.asm produced by InlineAsmOp::print:
///
///   (<flavor>,
///     out = [<operand> : <type> (maybe_memory)?, ...],
///     in = [...],
///     in_out = [...],
///     {"<asm string>" "<constraints>"}) side_effects? attr-dict (-> <type>)?
///
/// Operands from all three groups are appended to `result.operands` in order.
/// The per-group counts become the `operands_segments` property. One entry
/// per operand is collected into `operand_attrs`: a UnitAttr for operands
/// marked `(maybe_memory)`, and an empty DictionaryAttr otherwise.
ParseResult cir::InlineAsmOp::parse(OpAsmParser &parser,
                                    OperationState &result) {
  llvm::SmallVector<int32_t> operandsGroupSizes;
  std::string asmString, constraints;
  Type resType;
  MLIRContext *ctxt = parser.getBuilder().getContext();

  // Emits `msg` at the current parser location; yields failure.
  auto error = [&](const Twine &msg) -> LogicalResult {
    return parser.emitError(parser.getCurrentLocation(), msg);
  };

  // Shorthand for an "expected '<c>'" diagnostic.
  auto expected = [&](const std::string &c) {
    return error("expected '" + c + "'");
  };

  if (parser.parseLParen().failed())
    return expected("(");

  auto flavor = FieldParser<AsmFlavor, AsmFlavor>::parse(parser);
  if (failed(flavor))
    return error("Unknown AsmFlavor");

  if (parser.parseComma().failed())
    return expected(",");

  // Parses `<operand> : <type>` and resolves it into `v`.
  auto parseValue = [&](Value &v) {
    OpAsmParser::UnresolvedOperand op;

    if (parser.parseOperand(op) || parser.parseColon())
      return error("can't parse operand");

    Type typ;
    if (parser.parseType(typ).failed())
      return error("can't parse operand type");
    if (parser.resolveOperand(op, typ, tmp))
      return error("can't resolve operand");
    v = tmp[0];
    return mlir::success();
  };

  // Parses one operand group, `<name> = [<operand> (maybe_memory)?, ...],`,
  // and appends the group's size to operandsGroupSizes.
  auto parseOperands = [&](llvm::StringRef name) {
    if (parser.parseKeyword(name).failed())
      return error("expected " + name + " operands here");
    if (parser.parseEqual().failed())
      return expected("=");
    if (parser.parseLSquare().failed())
      return expected("[");

    int size = 0;
    // Empty group: `[]` followed by the group-separating comma.
    if (parser.parseOptionalRSquare().succeeded()) {
      operandsGroupSizes.push_back(size);
      if (parser.parseComma())
        return expected(",");
      return mlir::success();
    }

    // Parses a single operand and its optional `(maybe_memory)` marker.
    auto parseOperand = [&]() {
      Value val;
      if (parseValue(val).succeeded()) {
        result.operands.push_back(val);
        size++;

        // No marker: the operand's attribute is an empty dictionary.
        if (parser.parseOptionalLParen().failed()) {
          operandAttrs.push_back(mlir::DictionaryAttr::get(ctxt));
          return mlir::success();
        }

        // Marker present: record a UnitAttr (printed back as
        // `(maybe_memory)`).
        if (parser.parseKeyword("maybe_memory").succeeded()) {
          operandAttrs.push_back(mlir::UnitAttr::get(ctxt));
          if (parser.parseRParen())
            return expected(")");
          return mlir::success();
        } else {
          return expected("maybe_memory");
        }
      }
      return mlir::failure();
    };

    if (parser.parseCommaSeparatedList(parseOperand).failed())
      return mlir::failure();

    // Both a missing `]` and a missing trailing comma report "expected ']'".
    if (parser.parseRSquare().failed() || parser.parseComma().failed())
      return expected("]");
    operandsGroupSizes.push_back(size);
    return mlir::success();
  };

  // All three groups are required and must appear in this order.
  if (parseOperands("out").failed() || parseOperands("in").failed() ||
      parseOperands("in_out").failed())
    return error("failed to parse operands");

  // `{"<asm string>" "<constraints>"}` followed by the closing `)`.
  if (parser.parseLBrace())
    return expected("{");
  if (parser.parseString(&asmString))
    return error("asm string parsing failed");
  if (parser.parseString(&constraints))
    return error("constraints string parsing failed");
  if (parser.parseRBrace())
    return expected("}");
  if (parser.parseRParen())
    return expected(")");

  if (parser.parseOptionalKeyword("side_effects").succeeded())
    result.attributes.set("side_effects", UnitAttr::get(ctxt));

  if (parser.parseOptionalAttrDict(result.attributes).failed())
    return mlir::failure();

  // Optional `-> <result type>`.
  if (parser.parseOptionalArrow().succeeded() &&
      parser.parseType(resType).failed())
    return mlir::failure();

  result.attributes.set("asm_flavor", AsmFlavorAttr::get(ctxt, *flavor));
  result.attributes.set("asm_string", StringAttr::get(ctxt, asmString));
  result.attributes.set("constraints", StringAttr::get(ctxt, constraints));
  result.attributes.set("operand_attrs", ArrayAttr::get(ctxt, operandAttrs));
  // The group sizes are stored in the op's properties rather than as a
  // regular attribute.
  result.getOrAddProperties<InlineAsmOp::Properties>().operands_segments =
      parser.getBuilder().getDenseI32ArrayAttr(operandsGroupSizes);
  if (resType)
    result.addTypes(TypeRange{resType});

  return mlir::success();
}
4058
4059//===----------------------------------------------------------------------===//
4060// ThrowOp
4061//===----------------------------------------------------------------------===//
4062
4063mlir::LogicalResult cir::ThrowOp::verify() {
4064 // For the no-rethrow version, it must have at least the exception pointer.
4065 if (rethrows())
4066 return success();
4067
4068 if (getNumOperands() != 0) {
4069 if (getTypeInfo())
4070 return success();
4071 return emitOpError() << "'type_info' symbol attribute missing";
4072 }
4073
4074 return failure();
4075}
4076
4077//===----------------------------------------------------------------------===//
4078// AtomicFetchOp
4079//===----------------------------------------------------------------------===//
4080
4081LogicalResult cir::AtomicFetchOp::verify() {
4082 if (getBinop() != cir::AtomicFetchKind::Add &&
4083 getBinop() != cir::AtomicFetchKind::Sub &&
4084 getBinop() != cir::AtomicFetchKind::Max &&
4085 getBinop() != cir::AtomicFetchKind::Min &&
4086 !mlir::isa<cir::IntType>(getVal().getType()))
4087 return emitError("only atomic add, sub, max, and min operation could "
4088 "operate on floating-point values");
4089 return success();
4090}
4091
4092//===----------------------------------------------------------------------===//
4093// TypeInfoAttr
4094//===----------------------------------------------------------------------===//
4095
4096LogicalResult cir::TypeInfoAttr::verify(
4097 ::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError,
4098 ::mlir::Type type, ::mlir::ArrayAttr typeInfoData) {
4099
4100 if (cir::ConstRecordAttr::verify(emitError, type, typeInfoData).failed())
4101 return failure();
4102
4103 return success();
4104}
4105
4106//===----------------------------------------------------------------------===//
4107// TryOp
4108//===----------------------------------------------------------------------===//
4109
4110void cir::TryOp::getSuccessorRegions(
4111 mlir::RegionBranchPoint point,
4113 // The `try` and the `catchers` region branch back to the parent operation.
4114 if (!point.isParent()) {
4115 regions.push_back(RegionSuccessor::parent());
4116 return;
4117 }
4118
4119 regions.push_back(mlir::RegionSuccessor(&getTryRegion()));
4120
4121 // TODO(CIR): If we know a target function never throws a specific type, we
4122 // can remove the catch handler.
4123 for (mlir::Region &handlerRegion : this->getHandlerRegions())
4124 regions.push_back(mlir::RegionSuccessor(&handlerRegion));
4125}
4126
4127mlir::ValueRange cir::TryOp::getSuccessorInputs(RegionSuccessor successor) {
4128 return successor.isParent() ? ValueRange(getOperation()->getResults())
4129 : ValueRange();
4130}
4131
4132LogicalResult cir::TryOp::verify() {
4133 mlir::ArrayAttr handlerTypes = getHandlerTypes();
4134 if (!handlerTypes) {
4135 if (!getHandlerRegions().empty())
4136 return emitOpError(
4137 "handler regions must be empty when no handler types are present");
4138 return success();
4139 }
4140
4141 mlir::MutableArrayRef<mlir::Region> handlerRegions = getHandlerRegions();
4142
4143 // The parser and builder won't allow this to happen, but the loop below
4144 // relies on the sizes being the same, so we check it here.
4145 if (handlerRegions.size() != handlerTypes.size())
4146 return emitOpError(
4147 "number of handler regions and handler types must match");
4148
4149 for (const auto &[typeAttr, handlerRegion] :
4150 llvm::zip(handlerTypes, handlerRegions)) {
4151 // Verify that handler regions have a !cir.eh_token block argument.
4152 mlir::Block &entryBlock = handlerRegion.front();
4153 if (entryBlock.getNumArguments() != 1 ||
4154 !mlir::isa<cir::EhTokenType>(entryBlock.getArgument(0).getType()))
4155 return emitOpError(
4156 "handler region must have a single '!cir.eh_token' argument");
4157
4158 // The unwind region does not require a cir.begin_catch.
4159 if (mlir::isa<cir::UnwindAttr>(typeAttr))
4160 continue;
4161
4162 // A catch handler region must start with cir.begin_catch, optionally
4163 // preceded by a single cir.construct_catch_param that performs any
4164 // pre-begin_catch initialization for the catch parameter.
4165 if (entryBlock.empty())
4166 return emitOpError("catch handler region must not be empty");
4167 mlir::Operation *firstOp = &entryBlock.front();
4168 if (mlir::isa_and_present<cir::ConstructCatchParamOp>(firstOp))
4169 firstOp = firstOp->getNextNode();
4170 if (!firstOp || !mlir::isa<cir::BeginCatchOp>(firstOp))
4171 return emitOpError(
4172 "catch handler region must start with 'cir.begin_catch'");
4173 }
4174
4175 return success();
4176}
4177
4178static void
4179printTryHandlerRegions(mlir::OpAsmPrinter &printer, cir::TryOp op,
4180 mlir::MutableArrayRef<mlir::Region> handlerRegions,
4181 mlir::ArrayAttr handlerTypes) {
4182 if (!handlerTypes)
4183 return;
4184
4185 for (const auto [typeIdx, typeAttr] : llvm::enumerate(handlerTypes)) {
4186 if (typeIdx)
4187 printer << " ";
4188
4189 if (mlir::isa<cir::CatchAllAttr>(typeAttr)) {
4190 printer << "catch all ";
4191 } else if (mlir::isa<cir::UnwindAttr>(typeAttr)) {
4192 printer << "unwind ";
4193 } else {
4194 printer << "catch [type ";
4195 printer.printAttribute(typeAttr);
4196 printer << "] ";
4197 }
4198
4199 // Print the handler region's !cir.eh_token block argument.
4200 mlir::Region &region = handlerRegions[typeIdx];
4201 if (!region.empty() && region.front().getNumArguments() > 0) {
4202 printer << "(";
4203 printer.printRegionArgument(region.front().getArgument(0));
4204 printer << ") ";
4205 }
4206
4207 printer.printRegion(region,
4208 /*printEntryBLockArgs=*/false,
4209 /*printBlockTerminators=*/true);
4210 }
4211}
4212
/// Parses the handler list of a cir.try op (the inverse of
/// printTryHandlerRegions):
///
///   (catch [type <rtti-attr>] | catch all) (%tok : !cir.eh_token) {region}
///   ...
///   (unwind (%tok : !cir.eh_token) {region})?
///
/// The square brackets around `type <rtti-attr>` are optional, and at most
/// one `catch all` is accepted. `unwind` may only appear last and may not be
/// combined with `catch all`. Each handler appends one region to
/// `handlerRegions`. `handlerTypes` receives the matching attributes (the
/// RTTI attribute, a CatchAllAttr, or an UnwindAttr) in the same order.
static mlir::ParseResult parseTryHandlerRegions(
    mlir::OpAsmParser &parser,
    llvm::SmallVectorImpl<std::unique_ptr<mlir::Region>> &handlerRegions,
    mlir::ArrayAttr &handlerTypes) {

  // Parses `(%arg : type) { ... }` into a freshly appended region and checks
  // that the region is non-empty and its last block is terminated.
  auto parseCheckedCatcherRegion = [&]() -> mlir::ParseResult {
    handlerRegions.emplace_back(new mlir::Region);

    mlir::Region &currRegion = *handlerRegions.back();

    // Parse the required region argument: (%eh_token : !cir.eh_token)
    if (parser.parseLParen())
      return failure();
    mlir::OpAsmParser::Argument arg;
    if (parser.parseArgument(arg, /*allowType=*/true))
      return failure();
    regionArgs.push_back(arg);
    if (parser.parseRParen())
      return failure();

    mlir::SMLoc regionLoc = parser.getCurrentLocation();
    if (parser.parseRegion(currRegion, regionArgs)) {
      // Discard every handler region parsed so far.
      handlerRegions.clear();
      return failure();
    }

    if (currRegion.empty())
      return parser.emitError(regionLoc, "handler region shall not be empty");

    if (!(currRegion.back().mightHaveTerminator() &&
          currRegion.back().getTerminator()))
      return parser.emitError(
          regionLoc, "blocks are expected to be explicitly terminated");

    return success();
  };

  bool hasCatchAll = false;
  while (parser.parseOptionalKeyword("catch").succeeded()) {
    bool hasLSquare = parser.parseOptionalLSquare().succeeded();

    llvm::StringRef attrStr;
    if (parser.parseOptionalKeyword(&attrStr, {"all", "type"}).failed())
      return parser.emitError(parser.getCurrentLocation(),
                              "expected 'all' or 'type' keyword");

    bool isCatchAll = attrStr == "all";
    if (isCatchAll) {
      if (hasCatchAll)
        return parser.emitError(parser.getCurrentLocation(),
                                "can't have more than one catch all");
      hasCatchAll = true;
    }

    // Typed catches carry an RTTI attribute; catch-all uses CatchAllAttr.
    mlir::Attribute exceptionRTTIAttr;
    if (!isCatchAll && parser.parseAttribute(exceptionRTTIAttr).failed())
      return parser.emitError(parser.getCurrentLocation(),
                              "expected valid RTTI info attribute");

    catcherAttrs.push_back(isCatchAll
                               ? cir::CatchAllAttr::get(parser.getContext())
                               : exceptionRTTIAttr);

    if (hasLSquare && isCatchAll)
      return parser.emitError(parser.getCurrentLocation(),
                              "catch all dosen't need RTTI info attribute");

    if (hasLSquare && parser.parseRSquare().failed())
      return parser.emitError(parser.getCurrentLocation(),
                              "expected `]` after RTTI info attribute");

    if (parseCheckedCatcherRegion().failed())
      return mlir::failure();
  }

  // Optional trailing unwind handler.
  if (parser.parseOptionalKeyword("unwind").succeeded()) {
    if (hasCatchAll)
      return parser.emitError(parser.getCurrentLocation(),
                              "unwind can't be used with catch all");

    catcherAttrs.push_back(cir::UnwindAttr::get(parser.getContext()));
    if (parseCheckedCatcherRegion().failed())
      return mlir::failure();
  }

  handlerTypes = parser.getBuilder().getArrayAttr(catcherAttrs);
  return mlir::success();
}
4303
4304//===----------------------------------------------------------------------===//
4305// EhTypeIdOp
4306//===----------------------------------------------------------------------===//
4307
4308LogicalResult
4309cir::EhTypeIdOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
4310 Operation *op = symbolTable.lookupNearestSymbolFrom(*this, getTypeSymAttr());
4311 if (!isa_and_nonnull<GlobalOp>(op))
4312 return emitOpError("'")
4313 << getTypeSym() << "' does not reference a valid cir.global";
4314 return success();
4315}
4316
4317//===----------------------------------------------------------------------===//
4318// ConstructCatchParamOp
4319//===----------------------------------------------------------------------===//
4320
4321LogicalResult cir::ConstructCatchParamOp::verifySymbolUses(
4322 SymbolTableCollection &symbolTable) {
4323 auto fn =
4324 symbolTable.lookupNearestSymbolFrom<cir::FuncOp>(*this, getCopyFnAttr());
4325 if (!fn)
4326 return emitOpError("'")
4327 << getCopyFn() << "' does not reference a valid cir.func";
4328
4329 if (!fn->hasAttr(cir::CIRDialect::getCatchCopyThunkAttrName()))
4330 return emitOpError("catch-init copy_fn must be tagged with the ")
4331 << cir::CIRDialect::getCatchCopyThunkAttrName() << " attribute";
4332
4333 cir::FuncType fnType = fn.getFunctionType();
4334 if (fnType.getNumInputs() != 2 || !fnType.hasVoidReturn())
4335 return emitOpError("catch-init copy_fn must take two pointer arguments and "
4336 "return void");
4337
4338 if (fnType.getInput(0) != getParamAddr().getType())
4339 return emitOpError("first argument of catch-init copy_fn must match the "
4340 "type of 'param_addr'");
4341
4342 if (fnType.getInput(1) != getParamAddr().getType())
4343 return emitOpError(
4344 "second argument of catch-init copy_fn must be a pointer "
4345 "to the catch type");
4346
4347 return success();
4348}
4349
4350//===----------------------------------------------------------------------===//
4351// EhDispatchOp
4352//===----------------------------------------------------------------------===//
4353
4354static ParseResult
4355parseEhDispatchDestinations(OpAsmParser &parser, mlir::ArrayAttr &catchTypes,
4356 SmallVectorImpl<Block *> &catchDestinations,
4357 Block *&defaultDestination,
4358 mlir::UnitAttr &defaultIsCatchAll) {
4359 // Parse: [ ... ]
4360 if (parser.parseLSquare())
4361 return failure();
4362
4363 SmallVector<Attribute> handlerTypes;
4364 bool hasCatchAll = false;
4365 bool hasUnwind = false;
4366
4367 // Parse handler list.
4368 auto parseHandler = [&]() -> ParseResult {
4369 // Check for 'catch_all' or 'unwind' keywords.
4370 if (succeeded(parser.parseOptionalKeyword("catch_all"))) {
4371 if (hasCatchAll)
4372 return parser.emitError(parser.getCurrentLocation(),
4373 "duplicate 'catch_all' handler");
4374 if (hasUnwind)
4375 return parser.emitError(parser.getCurrentLocation(),
4376 "cannot have both 'catch_all' and 'unwind'");
4377 hasCatchAll = true;
4378
4379 if (parser.parseColon().failed())
4380 return failure();
4381
4382 if (parser.parseSuccessor(defaultDestination).failed())
4383 return failure();
4384
4385 return success();
4386 }
4387
4388 if (succeeded(parser.parseOptionalKeyword("unwind"))) {
4389 if (hasUnwind)
4390 return parser.emitError(parser.getCurrentLocation(),
4391 "duplicate 'unwind' handler");
4392 if (hasCatchAll)
4393 return parser.emitError(parser.getCurrentLocation(),
4394 "cannot have both 'catch_all' and 'unwind'");
4395 hasUnwind = true;
4396
4397 if (parser.parseColon().failed())
4398 return failure();
4399
4400 if (parser.parseSuccessor(defaultDestination).failed())
4401 return failure();
4402 return success();
4403 }
4404
4405 // Otherwise, expect 'catch(<attr> : <type>) : ^block'.
4406 // The 'catch(...)' wrapper allows the attribute to include its type
4407 // without conflicting with the ':' used for the block destination.
4408 if (parser.parseKeyword("catch").failed())
4409 return failure();
4410
4411 if (parser.parseLParen().failed())
4412 return failure();
4413
4414 mlir::Attribute catchTypeAttr;
4415 if (parser.parseAttribute(catchTypeAttr).failed())
4416 return failure();
4417 handlerTypes.push_back(catchTypeAttr);
4418
4419 if (parser.parseRParen().failed())
4420 return failure();
4421
4422 if (parser.parseColon().failed())
4423 return failure();
4424
4425 Block *dest;
4426 if (parser.parseSuccessor(dest).failed())
4427 return failure();
4428 catchDestinations.push_back(dest);
4429 return success();
4430 };
4431
4432 if (parser.parseCommaSeparatedList(parseHandler).failed())
4433 return failure();
4434
4435 if (parser.parseRSquare().failed())
4436 return failure();
4437
4438 // Verify we have catch_all or unwind.
4439 if (!hasCatchAll && !hasUnwind)
4440 return parser.emitError(parser.getCurrentLocation(),
4441 "must have either 'catch_all' or 'unwind' handler");
4442
4443 // Add attributes and successors.
4444 if (!handlerTypes.empty())
4445 catchTypes = parser.getBuilder().getArrayAttr(handlerTypes);
4446
4447 if (hasCatchAll)
4448 defaultIsCatchAll = parser.getBuilder().getUnitAttr();
4449
4450 return success();
4451}
4452
4453static void printEhDispatchDestinations(OpAsmPrinter &p, cir::EhDispatchOp op,
4454 mlir::ArrayAttr catchTypes,
4455 SuccessorRange catchDestinations,
4456 Block *defaultDestination,
4457 mlir::UnitAttr defaultIsCatchAll) {
4458 p << " [";
4459 p.printNewline();
4460
4461 // If we have at least one catch type, print them.
4462 if (catchTypes) {
4463 // Print type handlers using 'catch(<attr>) : ^block' syntax.
4464 llvm::interleave(
4465 llvm::zip(catchTypes, catchDestinations),
4466 [&](auto i) {
4467 p << " catch(";
4468 p.printAttribute(std::get<0>(i));
4469 p << ") : ";
4470 p.printSuccessor(std::get<1>(i));
4471 },
4472 [&] {
4473 p << ',';
4474 p.printNewline();
4475 });
4476
4477 p << ", ";
4478 p.printNewline();
4479 }
4480
4481 // Print catch_all or unwind handler.
4482 if (defaultIsCatchAll)
4483 p << " catch_all : ";
4484 else
4485 p << " unwind : ";
4486 p.printSuccessor(defaultDestination);
4487 p.printNewline();
4488
4489 p << "]";
4490}
4491
4492//===----------------------------------------------------------------------===//
4493// TableGen'd op method definitions
4494//===----------------------------------------------------------------------===//
4495
4496#define GET_OP_CLASSES
4497#include "clang/CIR/Dialect/IR/CIROps.cpp.inc"
static const MemRegion * getRegion(const CallEvent &Call, const MutexDescriptor &Descriptor, bool IsLock)
static void printEhDispatchDestinations(OpAsmPrinter &p, cir::EhDispatchOp op, mlir::ArrayAttr catchTypes, SuccessorRange catchDestinations, Block *defaultDestination, mlir::UnitAttr defaultIsCatchAll)
static mlir::LogicalResult checkReturnAndFunction(cir::ReturnOp op, cir::FuncOp function)
static bool isCirFunctionPointerType(mlir::Type ty)
static LogicalResult verifyMemberPtrCast(Operation *op, mlir::Value src, mlir::Type resultTy)
static mlir::ParseResult parseCallCommon(mlir::OpAsmParser &parser, mlir::OperationState &result, bool hasDestinationBlocks=false)
static bool isIntOrBoolCast(cir::CastOp op)
static ParseResult parseEhDispatchDestinations(OpAsmParser &parser, mlir::ArrayAttr &catchTypes, SmallVectorImpl< Block * > &catchDestinations, Block *&defaultDestination, mlir::UnitAttr &defaultIsCatchAll)
static void printConstant(OpAsmPrinter &p, Attribute value)
static mlir::ParseResult parseOmittedTerminatorRegion(mlir::OpAsmParser &parser, mlir::Region &region)
ParseResult parseInlineKindAttr(OpAsmParser &parser, cir::InlineKindAttr &inlineKindAttr)
void printInlineKindAttr(OpAsmPrinter &p, cir::InlineKindAttr inlineKindAttr)
static ParseResult parseSwitchFlatOpCases(OpAsmParser &parser, Type flagType, mlir::ArrayAttr &caseValues, SmallVectorImpl< Block * > &caseDestinations, SmallVectorImpl< llvm::SmallVector< OpAsmParser::UnresolvedOperand > > &caseOperands, SmallVectorImpl< llvm::SmallVector< Type > > &caseOperandTypes)
<cases> ::= [ (case (, case )* )?
void printGlobalAddressSpaceValue(mlir::AsmPrinter &printer, cir::GlobalOp op, mlir::ptr::MemorySpaceAttrInterface attr)
static void printCallCommon(mlir::Operation *op, mlir::FlatSymbolRefAttr calleeSym, mlir::Value indirectCallee, mlir::OpAsmPrinter &printer, bool isNothrow, cir::SideEffect sideEffect, ArrayAttr argAttrs, ArrayAttr resAttrs, mlir::Block *normalDest=nullptr, mlir::Block *unwindDest=nullptr)
static LogicalResult verifyCallCommInSymbolUses(mlir::Operation *op, SymbolTableCollection &symbolTable)
static LogicalResult ensureRegionTerm(OpAsmParser &parser, Region &region, SMLoc errLoc)
static ParseResult parseGlobalOpTypeAndInitialValue(OpAsmParser &parser, TypeAttr &typeAttr, Attribute &initialValueAttr, mlir::Region &ctorRegion, mlir::Region &dtorRegion)
void printIndirectBrOpSucessors(OpAsmPrinter &p, cir::IndirectBrOp op, Type flagType, SuccessorRange succs, OperandRangeRange succOperands, const TypeRangeRange &succOperandsTypes)
static OpFoldResult foldUnaryBitOp(mlir::Attribute inputAttr, llvm::function_ref< llvm::APInt(const llvm::APInt &)> func, bool poisonZero=false)
static llvm::StringRef getLinkageAttrNameString()
Returns the name used for the linkage attribute.
static RetTy parseOptionalCIRKeyword(AsmParser &parser, EnumTy defaultValue)
Parse an enum from the keyword, or default to the provided default value.
mlir::OptionalParseResult parseGlobalAddressSpaceValue(mlir::AsmParser &p, mlir::ptr::MemorySpaceAttrInterface &attr)
static void printSwitchFlatOpCases(OpAsmPrinter &p, cir::SwitchFlatOp op, Type flagType, mlir::ArrayAttr caseValues, SuccessorRange caseDestinations, OperandRangeRange caseOperands, const TypeRangeRange &caseOperandTypes)
static mlir::ParseResult parseTryCallDestinations(mlir::OpAsmParser &parser, mlir::OperationState &result)
static void printGlobalOpTypeAndInitialValue(OpAsmPrinter &p, cir::GlobalOp op, TypeAttr type, Attribute initAttr, mlir::Region &ctorRegion, mlir::Region &dtorRegion)
static ParseResult parseCIRKeyword(AsmParser &parser, RetTy &result)
Parse an enum from the keyword, return failure if the keyword is not found.
static Value tryFoldCastChain(cir::CastOp op)
static void printTryHandlerRegions(mlir::OpAsmPrinter &printer, cir::TryOp op, mlir::MutableArrayRef< mlir::Region > handlerRegions, mlir::ArrayAttr handlerTypes)
ParseResult parseIndirectBrOpSucessors(OpAsmParser &parser, Type &flagType, SmallVectorImpl< Block * > &succOperandBlocks, SmallVectorImpl< SmallVector< OpAsmParser::UnresolvedOperand > > &succOperands, SmallVectorImpl< SmallVector< Type > > &succOperandsTypes)
static bool omitRegionTerm(mlir::Region &r)
static LogicalResult verifyBinaryOverflowOp(mlir::Operation *op, bool noSignedWrap, bool noUnsignedWrap, bool saturated, bool hasSat)
static void printOmittedTerminatorRegion(mlir::OpAsmPrinter &printer, cir::ScopeOp &op, mlir::Region &region)
static ParseResult parseConstantValue(OpAsmParser &parser, mlir::Attribute &valueAttr)
static LogicalResult verifyArrayCtorDtor(Op op)
static LogicalResult checkConstantTypes(mlir::Operation *op, mlir::Type opType, mlir::Attribute attrType)
static mlir::ParseResult parseTryHandlerRegions(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< std::unique_ptr< mlir::Region > > &handlerRegions, mlir::ArrayAttr &handlerTypes)
#define REGISTER_ENUM_TYPE(Ty)
static int parseOptionalKeywordAlternative(AsmParser &parser, ArrayRef< llvm::StringRef > keywords)
llvm::function_ref< void(mlir::OpBuilder &, mlir::Location)> BuilderCallbackRef
Definition CIRDialect.h:37
llvm::function_ref< void( mlir::OpBuilder &, mlir::Location, mlir::OperationState &)> BuilderOpStateCallbackRef
Definition CIRDialect.h:39
static std::optional< NonLoc > getIndex(ProgramStateRef State, const ElementRegion *ER, CharKind CK)
static Decl::Kind getKind(const Decl *D)
TokenType getType() const
Returns the token's type, e.g.
tooling::Replacements cleanup(const FormatStyle &Style, StringRef Code, ArrayRef< tooling::Range > Ranges, StringRef FileName="<stdin>")
Clean up any erroneous/redundant code in the given Ranges in Code.
*collection of selector each with an associated kind and an ordered *collection of selectors A selector has a an optional score condition
*collection of selector each with an associated kind and an ordered *collection of selectors A selector has a kind
__device__ __2f16 float c
void buildTerminatedBody(mlir::OpBuilder &builder, mlir::Location loc)
mlir::ptr::MemorySpaceAttrInterface normalizeDefaultAddressSpace(mlir::ptr::MemorySpaceAttrInterface addrSpace)
Normalize LangAddressSpace::Default to null (empty attribute).
const internal::VariadicAllOfMatcher< Attr > attr
const AstTypeMatcher< RecordType > recordType
StringRef getName(const HeaderType T)
Definition HeaderFile.h:38
RangeSelector name(std::string ID)
Given a node with a "name", (like NamedDecl, DeclRefExpr, CxxCtorInitializer, and TypeLoc) selects th...
nullptr
This class represents a compute construct, representing a 'Kind' of ‘parallel’, 'serial',...
__DEVICE__ _Tp arg(const std::complex< _Tp > &__c)
static bool memberFuncPtrCast()
static bool opCallCallConv()
static bool opScopeCleanupRegion()
static bool supportIFuncAttr()