clang 23.0.0git
CIRDialect.cpp
Go to the documentation of this file.
1//===- CIRDialect.cpp - MLIR CIR ops implementation -----------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements the CIR dialect and its operations.
10//
11//===----------------------------------------------------------------------===//
12
14
18
19#include "mlir/IR/Attributes.h"
20#include "mlir/IR/DialectImplementation.h"
21#include "mlir/IR/PatternMatch.h"
22#include "mlir/Interfaces/ControlFlowInterfaces.h"
23#include "mlir/Interfaces/FunctionImplementation.h"
24#include "mlir/Support/LLVM.h"
25
26#include "clang/CIR/Dialect/IR/CIROpsDialect.cpp.inc"
27#include "clang/CIR/Dialect/IR/CIROpsEnums.cpp.inc"
29#include "llvm/ADT/SetOperations.h"
30#include "llvm/ADT/SmallSet.h"
31#include "llvm/ADT/TypeSwitch.h"
32#include "llvm/Support/LogicalResult.h"
33
34using namespace mlir;
35using namespace cir;
36
37//===----------------------------------------------------------------------===//
38// CIR Dialect
39//===----------------------------------------------------------------------===//
40namespace {
41struct CIROpAsmDialectInterface : public OpAsmDialectInterface {
42 using OpAsmDialectInterface::OpAsmDialectInterface;
43
44 AliasResult getAlias(Type type, raw_ostream &os) const final {
45 if (auto recordType = dyn_cast<cir::RecordType>(type)) {
46 StringAttr nameAttr = recordType.getName();
47 if (!nameAttr)
48 os << "rec_anon_" << recordType.getKindAsStr();
49 else
50 os << "rec_" << nameAttr.getValue();
51 return AliasResult::OverridableAlias;
52 }
53 if (auto intType = dyn_cast<cir::IntType>(type)) {
54 // We only provide alias for standard integer types (i.e. integer types
55 // whose width is a power of 2 and at least 8).
56 unsigned width = intType.getWidth();
57 if (width < 8 || !llvm::isPowerOf2_32(width))
58 return AliasResult::NoAlias;
59 os << intType.getAlias();
60 return AliasResult::OverridableAlias;
61 }
62 if (auto voidType = dyn_cast<cir::VoidType>(type)) {
63 os << voidType.getAlias();
64 return AliasResult::OverridableAlias;
65 }
66
67 return AliasResult::NoAlias;
68 }
69
70 AliasResult getAlias(Attribute attr, raw_ostream &os) const final {
71 if (auto boolAttr = mlir::dyn_cast<cir::BoolAttr>(attr)) {
72 os << (boolAttr.getValue() ? "true" : "false");
73 return AliasResult::FinalAlias;
74 }
75 if (auto bitfield = mlir::dyn_cast<cir::BitfieldInfoAttr>(attr)) {
76 os << "bfi_" << bitfield.getName().str();
77 return AliasResult::FinalAlias;
78 }
79 if (auto dynCastInfoAttr = mlir::dyn_cast<cir::DynamicCastInfoAttr>(attr)) {
80 os << dynCastInfoAttr.getAlias();
81 return AliasResult::FinalAlias;
82 }
83 if (auto cmpThreeWayInfoAttr =
84 mlir::dyn_cast<cir::CmpThreeWayInfoAttr>(attr)) {
85 os << cmpThreeWayInfoAttr.getAlias();
86 return AliasResult::FinalAlias;
87 }
88 return AliasResult::NoAlias;
89 }
90};
91} // namespace
92
93void cir::CIRDialect::initialize() {
94 registerTypes();
95 registerAttributes();
96 addOperations<
97#define GET_OP_LIST
98#include "clang/CIR/Dialect/IR/CIROps.cpp.inc"
99 >();
100 addInterfaces<CIROpAsmDialectInterface>();
101}
102
103Operation *cir::CIRDialect::materializeConstant(mlir::OpBuilder &builder,
104 mlir::Attribute value,
105 mlir::Type type,
106 mlir::Location loc) {
107 return cir::ConstantOp::create(builder, loc, type,
108 mlir::cast<mlir::TypedAttr>(value));
109}
110
111//===----------------------------------------------------------------------===//
112// Helpers
113//===----------------------------------------------------------------------===//
114
115// Parses one of the keywords provided in the list `keywords` and returns the
116// position of the parsed keyword in the list. If none of the keywords from the
117// list is parsed, returns -1.
118static int parseOptionalKeywordAlternative(AsmParser &parser,
119 ArrayRef<llvm::StringRef> keywords) {
120 for (auto en : llvm::enumerate(keywords)) {
121 if (succeeded(parser.parseOptionalKeyword(en.value())))
122 return en.index();
123 }
124 return -1;
125}
126
127namespace {
128template <typename Ty> struct EnumTraits {};
129
130#define REGISTER_ENUM_TYPE(Ty) \
131 template <> struct EnumTraits<cir::Ty> { \
132 static llvm::StringRef stringify(cir::Ty value) { \
133 return stringify##Ty(value); \
134 } \
135 static unsigned getMaxEnumVal() { return cir::getMaxEnumValFor##Ty(); } \
136 }
137
138REGISTER_ENUM_TYPE(GlobalLinkageKind);
139REGISTER_ENUM_TYPE(VisibilityKind);
140REGISTER_ENUM_TYPE(SideEffect);
141REGISTER_ENUM_TYPE(CallingConv);
142} // namespace
143
144/// Parse an enum from the keyword, or default to the provided default value.
145/// The return type is the enum type by default, unless overriden with the
146/// second template argument.
147template <typename EnumTy, typename RetTy = EnumTy>
148static RetTy parseOptionalCIRKeyword(AsmParser &parser, EnumTy defaultValue) {
150 for (unsigned i = 0, e = EnumTraits<EnumTy>::getMaxEnumVal(); i <= e; ++i)
151 names.push_back(EnumTraits<EnumTy>::stringify(static_cast<EnumTy>(i)));
152
153 int index = parseOptionalKeywordAlternative(parser, names);
154 if (index == -1)
155 return static_cast<RetTy>(defaultValue);
156 return static_cast<RetTy>(index);
157}
158
159/// Parse an enum from the keyword, return failure if the keyword is not found.
160template <typename EnumTy, typename RetTy = EnumTy>
161static ParseResult parseCIRKeyword(AsmParser &parser, RetTy &result) {
163 for (unsigned i = 0, e = EnumTraits<EnumTy>::getMaxEnumVal(); i <= e; ++i)
164 names.push_back(EnumTraits<EnumTy>::stringify(static_cast<EnumTy>(i)));
165
166 int index = parseOptionalKeywordAlternative(parser, names);
167 if (index == -1)
168 return failure();
169 result = static_cast<RetTy>(index);
170 return success();
171}
172
173// Check if a region's termination omission is valid and, if so, creates and
174// inserts the omitted terminator into the region.
175static LogicalResult ensureRegionTerm(OpAsmParser &parser, Region &region,
176 SMLoc errLoc) {
177 Location eLoc = parser.getEncodedSourceLoc(parser.getCurrentLocation());
178 OpBuilder builder(parser.getBuilder().getContext());
179
180 // Insert empty block in case the region is empty to ensure the terminator
181 // will be inserted
182 if (region.empty())
183 builder.createBlock(&region);
184
185 Block &block = region.back();
186 // Region is properly terminated: nothing to do.
187 if (!block.empty() && block.back().hasTrait<OpTrait::IsTerminator>())
188 return success();
189
190 // Check for invalid terminator omissions.
191 if (!region.hasOneBlock())
192 return parser.emitError(errLoc,
193 "multi-block region must not omit terminator");
194
195 // Terminator was omitted correctly: recreate it.
196 builder.setInsertionPointToEnd(&block);
197 cir::YieldOp::create(builder, eLoc);
198 return success();
199}
200
201// True if the region's terminator should be omitted.
202static bool omitRegionTerm(mlir::Region &r) {
203 const auto singleNonEmptyBlock = r.hasOneBlock() && !r.back().empty();
204 const auto yieldsNothing = [&r]() {
205 auto y = dyn_cast<cir::YieldOp>(r.back().getTerminator());
206 return y && y.getArgs().empty();
207 };
208 return singleNonEmptyBlock && yieldsNothing();
209}
210
211//===----------------------------------------------------------------------===//
212// InlineKindAttr (FIXME: remove once FuncOp uses assembly format)
213//===----------------------------------------------------------------------===//
214
215ParseResult parseInlineKindAttr(OpAsmParser &parser,
216 cir::InlineKindAttr &inlineKindAttr) {
217 // Static list of possible inline kind keywords
218 static constexpr llvm::StringRef keywords[] = {"no_inline", "always_inline",
219 "inline_hint"};
220
221 // Parse the inline kind keyword (optional)
222 llvm::StringRef keyword;
223 if (parser.parseOptionalKeyword(&keyword, keywords).failed()) {
224 // Not an inline kind keyword, leave inlineKindAttr empty
225 return success();
226 }
227
228 // Parse the enum value from the keyword
229 auto inlineKindResult = ::cir::symbolizeEnum<::cir::InlineKind>(keyword);
230 if (!inlineKindResult) {
231 return parser.emitError(parser.getCurrentLocation(), "expected one of [")
232 << llvm::join(llvm::ArrayRef(keywords), ", ")
233 << "] for inlineKind, got: " << keyword;
234 }
235
236 inlineKindAttr =
237 ::cir::InlineKindAttr::get(parser.getContext(), *inlineKindResult);
238 return success();
239}
240
241void printInlineKindAttr(OpAsmPrinter &p, cir::InlineKindAttr inlineKindAttr) {
242 if (inlineKindAttr) {
243 p << " " << stringifyInlineKind(inlineKindAttr.getValue());
244 }
245}
246//===----------------------------------------------------------------------===//
247// CIR Custom Parsers/Printers
248//===----------------------------------------------------------------------===//
249
250static mlir::ParseResult parseOmittedTerminatorRegion(mlir::OpAsmParser &parser,
251 mlir::Region &region) {
252 auto regionLoc = parser.getCurrentLocation();
253 if (parser.parseRegion(region))
254 return failure();
255 if (ensureRegionTerm(parser, region, regionLoc).failed())
256 return failure();
257 return success();
258}
259
260static void printOmittedTerminatorRegion(mlir::OpAsmPrinter &printer,
261 cir::ScopeOp &op,
262 mlir::Region &region) {
263 printer.printRegion(region,
264 /*printEntryBlockArgs=*/false,
265 /*printBlockTerminators=*/!omitRegionTerm(region));
266}
267
268mlir::OptionalParseResult
269parseGlobalAddressSpaceValue(mlir::AsmParser &p,
270 mlir::ptr::MemorySpaceAttrInterface &attr);
271
272void printGlobalAddressSpaceValue(mlir::AsmPrinter &printer, cir::GlobalOp op,
273 mlir::ptr::MemorySpaceAttrInterface attr);
274
275//===----------------------------------------------------------------------===//
276// AllocaOp
277//===----------------------------------------------------------------------===//
278
279void cir::AllocaOp::build(mlir::OpBuilder &odsBuilder,
280 mlir::OperationState &odsState, mlir::Type addr,
281 mlir::Type allocaType, llvm::StringRef name,
282 mlir::IntegerAttr alignment) {
283 odsState.addAttribute(getAllocaTypeAttrName(odsState.name),
284 mlir::TypeAttr::get(allocaType));
285 odsState.addAttribute(getNameAttrName(odsState.name),
286 odsBuilder.getStringAttr(name));
287 if (alignment) {
288 odsState.addAttribute(getAlignmentAttrName(odsState.name), alignment);
289 }
290 odsState.addTypes(addr);
291}
292
293//===----------------------------------------------------------------------===//
294// ArrayCtor & ArrayDtor
295//===----------------------------------------------------------------------===//
296
297template <typename Op> static LogicalResult verifyArrayCtorDtor(Op op) {
298 auto ptrTy = mlir::cast<cir::PointerType>(op.getAddr().getType());
299 mlir::Type pointeeTy = ptrTy.getPointee();
300
301 mlir::Block &body = op.getBody().front();
302 if (body.getNumArguments() != 1)
303 return op.emitOpError("body must have exactly one block argument");
304
305 auto expectedEltPtrTy =
306 mlir::dyn_cast<cir::PointerType>(body.getArgument(0).getType());
307 if (!expectedEltPtrTy)
308 return op.emitOpError("block argument must be a !cir.ptr type");
309
310 if (op.getNumElements()) {
311 auto recTy = mlir::dyn_cast<cir::RecordType>(pointeeTy);
312 if (!recTy)
313 return op.emitOpError(
314 "when 'num_elements' is present, 'addr' must be a pointer to a "
315 "!cir.record type");
316
317 if (expectedEltPtrTy != ptrTy)
318 return op.emitOpError("when 'num_elements' is present, 'addr' type must "
319 "match the block argument type");
320 } else {
321 auto arrayTy = mlir::dyn_cast<cir::ArrayType>(pointeeTy);
322 if (!arrayTy)
323 return op.emitOpError(
324 "when 'num_elements' is absent, 'addr' must be a pointer to a "
325 "!cir.array type");
326
327 mlir::Type innerEltTy = arrayTy.getElementType();
328 while (auto nested = mlir::dyn_cast<cir::ArrayType>(innerEltTy))
329 innerEltTy = nested.getElementType();
330
331 auto recTy = mlir::dyn_cast<cir::RecordType>(innerEltTy);
332 if (!recTy)
333 return op.emitOpError(
334 "the block argument type must be a pointer to a !cir.record type");
335
336 if (expectedEltPtrTy.getPointee() != innerEltTy)
337 return op.emitOpError(
338 "block argument pointee type must match the innermost array "
339 "element type");
340 }
341
342 return success();
343}
344
345LogicalResult cir::ArrayCtor::verify() {
346 if (failed(verifyArrayCtorDtor(*this)))
347 return failure();
348
349 mlir::Region &partialDtor = getPartialDtor();
350 if (!partialDtor.empty()) {
351 mlir::Block &dtorBlock = partialDtor.front();
352 if (dtorBlock.getNumArguments() != 1)
353 return emitOpError("partial_dtor must have exactly one block argument");
354
355 auto bodyArgTy = getBody().front().getArgument(0).getType();
356 if (dtorBlock.getArgument(0).getType() != bodyArgTy)
357 return emitOpError("partial_dtor block argument type must match "
358 "the body block argument type");
359 }
360 return success();
361}
362LogicalResult cir::ArrayDtor::verify() { return verifyArrayCtorDtor(*this); }
363
364//===----------------------------------------------------------------------===//
365// DeleteArrayOp
366//===----------------------------------------------------------------------===//
367
368LogicalResult cir::DeleteArrayOp::verify() {
369 if (getDtorMayThrow() && !getElementDtorAttr())
370 return emitOpError(
371 "'dtor_may_throw' requires an 'element_dtor' to be present");
372 return success();
373}
374
375//===----------------------------------------------------------------------===//
376// AssumeOp
377//===----------------------------------------------------------------------===//
378
379static void printAssumeBundle(OpAsmPrinter &p, cir::AssumeOp op,
380 cir::AssumeBundleKindAttr kindAttr,
381 OperandRange bundleArgs,
382 TypeRange bundleArgTypes) {
383 cir::AssumeBundleKind kind = kindAttr.getValue();
384 if (kind == cir::AssumeBundleKind::None)
385 return;
386
387 p << " " << cir::stringifyAssumeBundleKind(kind);
388 if (bundleArgs.empty())
389 return;
390
391 p << "(";
392 p.printOperands(bundleArgs);
393 p << " : ";
394 llvm::interleaveComma(bundleArgTypes, p);
395 p << ")";
396}
397
398static ParseResult parseAssumeBundle(
399 OpAsmParser &p, cir::AssumeBundleKindAttr &bundleKindAttr,
401 llvm::SmallVector<mlir::Type, 1> &bundleArgTypes) {
402 StringRef keyword;
403 auto loc = p.getCurrentLocation();
404 if (failed(p.parseOptionalKeyword(&keyword))) {
405 bundleKindAttr = cir::AssumeBundleKindAttr::get(
406 p.getContext(), cir::AssumeBundleKind::None);
407 return success();
408 }
409
410 std::optional<cir::AssumeBundleKind> parsedKind =
411 cir::symbolizeAssumeBundleKind(keyword);
412 if (!parsedKind)
413 return p.emitError(loc, "unknown assume bundle kind '") << keyword << "'";
414
415 bundleKindAttr = cir::AssumeBundleKindAttr::get(p.getContext(), *parsedKind);
416
417 if (p.parseOptionalLParen())
418 return success();
419
420 if (p.parseOperandList(bundleArgs) || p.parseColon() ||
421 p.parseTypeList(bundleArgTypes) || p.parseRParen())
422 return failure();
423
424 return success();
425}
426
427LogicalResult cir::AssumeOp::verify() {
428 cir::AssumeBundleKind kind = getBundleKind();
429 size_t numArgs = getBundleArgs().size();
430
431 if (kind == cir::AssumeBundleKind::None) {
432 if (numArgs != 0)
433 return emitOpError("unexpected bundle operands for kind 'none'");
434 return success();
435 }
436
437 if (numArgs == 0)
438 return emitOpError("expected bundle operands for kind '")
439 << cir::stringifyAssumeBundleKind(kind) << "'";
440
441 switch (kind) {
442 case cir::AssumeBundleKind::Align:
443 if (numArgs != 2 && numArgs != 3)
444 return emitOpError("align bundle expects 2 or 3 operands");
445 break;
446 case cir::AssumeBundleKind::SeparateStorage:
447 if (numArgs != 2)
448 return emitOpError("separate_storage bundle expects 2 operands");
449 break;
450 case cir::AssumeBundleKind::Dereferenceable:
451 if (numArgs != 2)
452 return emitOpError("dereferenceable bundle expects 2 operands");
453 break;
454 default:
455 break;
456 }
457 return success();
458}
459
460//===----------------------------------------------------------------------===//
461// LocalInitOp
462//===----------------------------------------------------------------------===//
463
464LogicalResult
465cir::LocalInitOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
466 cir::GlobalOp global = getReferencedGlobal(symbolTable);
467 if (!global)
468 return emitOpError("'")
469 << getGlobalName() << "' does not reference a valid cir.global";
470
471 if (getTls() && !global.getTlsModel())
472 return emitOpError("access to global not marked thread local");
473
474 if (!global.getStaticLocalGuard().has_value())
475 return emitOpError("static_local attribute mismatch");
476
477 return success();
478}
479
480//===----------------------------------------------------------------------===//
481// ConditionOp
482//===----------------------------------------------------------------------===//
483
484//===----------------------------------
485// BranchOpTerminatorInterface Methods
486//===----------------------------------
487
488void cir::ConditionOp::getSuccessorRegions(
489 ArrayRef<Attribute> operands, SmallVectorImpl<RegionSuccessor> &regions) {
490 // TODO(cir): The condition value may be folded to a constant, narrowing
491 // down its list of possible successors.
492
493 // Parent is a loop: condition may branch to the body or to the parent op.
494 if (auto loopOp = dyn_cast<LoopOpInterface>(getOperation()->getParentOp())) {
495 regions.emplace_back(&loopOp.getBody());
496 regions.push_back(RegionSuccessor::parent());
497 return;
498 }
499
500 // Parent is an await: condition may branch to resume or suspend regions.
501 auto await = cast<AwaitOp>(getOperation()->getParentOp());
502 regions.emplace_back(&await.getResume());
503 regions.emplace_back(&await.getSuspend());
504}
505
506MutableOperandRange
507cir::ConditionOp::getMutableSuccessorOperands(RegionSuccessor point) {
508 // No values are yielded to the successor region.
509 return MutableOperandRange(getOperation(), 0, 0);
510}
511
512MutableOperandRange
513cir::ResumeOp::getMutableSuccessorOperands(RegionSuccessor point) {
514 // The eh_token operand is not forwarded to the parent region.
515 return MutableOperandRange(getOperation(), 0, 0);
516}
517
518LogicalResult cir::ConditionOp::verify() {
519 if (!isa<LoopOpInterface, AwaitOp>(getOperation()->getParentOp()))
520 return emitOpError("condition must be within a conditional region");
521 return success();
522}
523
524//===----------------------------------------------------------------------===//
525// ConstantOp
526//===----------------------------------------------------------------------===//
527
528static LogicalResult checkConstantTypes(mlir::Operation *op, mlir::Type opType,
529 mlir::Attribute attrType) {
530 if (isa<cir::ConstPtrAttr>(attrType)) {
531 if (!mlir::isa<cir::PointerType>(opType))
532 return op->emitOpError(
533 "pointer constant initializing a non-pointer type");
534 return success();
535 }
536
537 if (isa<cir::DataMemberAttr, cir::MethodAttr>(attrType)) {
538 // More detailed type verifications are already done in
539 // DataMemberAttr::verify or MethodAttr::verify. Don't need to repeat here.
540 return success();
541 }
542
543 if (isa<cir::ZeroAttr>(attrType)) {
544 if (isa<cir::RecordType, cir::ArrayType, cir::VectorType, cir::ComplexType>(
545 opType))
546 return success();
547 return op->emitOpError(
548 "zero expects struct, array, vector, or complex type");
549 }
550
551 if (mlir::isa<cir::UndefAttr>(attrType)) {
552 if (!mlir::isa<cir::VoidType>(opType))
553 return success();
554 return op->emitOpError("undef expects non-void type");
555 }
556
557 if (mlir::isa<cir::BoolAttr>(attrType)) {
558 if (!mlir::isa<cir::BoolType>(opType))
559 return op->emitOpError("result type (")
560 << opType << ") must be '!cir.bool' for '" << attrType << "'";
561 return success();
562 }
563
564 if (mlir::isa<cir::IntAttr, cir::FPAttr>(attrType)) {
565 auto at = cast<TypedAttr>(attrType);
566 if (at.getType() != opType) {
567 return op->emitOpError("result type (")
568 << opType << ") does not match value type (" << at.getType()
569 << ")";
570 }
571 return success();
572 }
573
574 if (mlir::isa<cir::ConstArrayAttr, cir::ConstVectorAttr,
575 cir::ConstComplexAttr, cir::ConstRecordAttr,
576 cir::GlobalViewAttr, cir::PoisonAttr, cir::TypeInfoAttr,
577 cir::VTableAttr>(attrType))
578 return success();
579
580 assert(isa<TypedAttr>(attrType) && "What else could we be looking at here?");
581 return op->emitOpError("global with type ")
582 << cast<TypedAttr>(attrType).getType() << " not yet supported";
583}
584
585LogicalResult cir::ConstantOp::verify() {
586 // ODS already generates checks to make sure the result type is valid. We just
587 // need to additionally check that the value's attribute type is consistent
588 // with the result type.
589 return checkConstantTypes(getOperation(), getType(), getValue());
590}
591
592OpFoldResult cir::ConstantOp::fold(FoldAdaptor /*adaptor*/) {
593 return getValue();
594}
595
596//===----------------------------------------------------------------------===//
597// CastOp
598//===----------------------------------------------------------------------===//
599
600LogicalResult cir::CastOp::verify() {
601 mlir::Type resType = getType();
602 mlir::Type srcType = getSrc().getType();
603
604 // Verify address space casts for pointer types. given that
605 // casts for within a different address space are illegal.
606 auto srcPtrTy = mlir::dyn_cast<cir::PointerType>(srcType);
607 auto resPtrTy = mlir::dyn_cast<cir::PointerType>(resType);
608 if (srcPtrTy && resPtrTy && (getKind() != cir::CastKind::address_space))
609 if (srcPtrTy.getAddrSpace() != resPtrTy.getAddrSpace()) {
610 return emitOpError() << "result type address space does not match the "
611 "address space of the operand";
612 }
613
614 if (mlir::isa<cir::VectorType>(srcType) &&
615 mlir::isa<cir::VectorType>(resType)) {
616 // Use the element type of the vector to verify the cast kind. (Except for
617 // bitcast, see below.)
618 srcType = mlir::dyn_cast<cir::VectorType>(srcType).getElementType();
619 resType = mlir::dyn_cast<cir::VectorType>(resType).getElementType();
620 }
621
622 switch (getKind()) {
623 case cir::CastKind::int_to_bool: {
624 if (!mlir::isa<cir::BoolType>(resType))
625 return emitOpError() << "requires !cir.bool type for result";
626 if (!mlir::isa<cir::IntType>(srcType))
627 return emitOpError() << "requires !cir.int type for source";
628 return success();
629 }
630 case cir::CastKind::ptr_to_bool: {
631 if (!mlir::isa<cir::BoolType>(resType))
632 return emitOpError() << "requires !cir.bool type for result";
633 if (!mlir::isa<cir::PointerType>(srcType))
634 return emitOpError() << "requires !cir.ptr type for source";
635 return success();
636 }
637 case cir::CastKind::integral: {
638 if (!mlir::isa<cir::IntType>(resType))
639 return emitOpError() << "requires !cir.int type for result";
640 if (!mlir::isa<cir::IntType>(srcType))
641 return emitOpError() << "requires !cir.int type for source";
642 return success();
643 }
644 case cir::CastKind::array_to_ptrdecay: {
645 const auto arrayPtrTy = mlir::dyn_cast<cir::PointerType>(srcType);
646 const auto flatPtrTy = mlir::dyn_cast<cir::PointerType>(resType);
647 if (!arrayPtrTy || !flatPtrTy)
648 return emitOpError() << "requires !cir.ptr type for source and result";
649
650 // TODO(CIR): Make sure the AddrSpace of both types are equals
651 return success();
652 }
653 case cir::CastKind::bitcast: {
654 // Handle the pointer types first.
655 auto srcPtrTy = mlir::dyn_cast<cir::PointerType>(srcType);
656 auto resPtrTy = mlir::dyn_cast<cir::PointerType>(resType);
657
658 if (srcPtrTy && resPtrTy) {
659 return success();
660 }
661
662 return success();
663 }
664 case cir::CastKind::floating: {
665 if (!mlir::isa<cir::FPTypeInterface>(srcType) ||
666 !mlir::isa<cir::FPTypeInterface>(resType))
667 return emitOpError() << "requires !cir.float type for source and result";
668 return success();
669 }
670 case cir::CastKind::float_to_int: {
671 if (!mlir::isa<cir::FPTypeInterface>(srcType))
672 return emitOpError() << "requires !cir.float type for source";
673 if (!mlir::dyn_cast<cir::IntType>(resType))
674 return emitOpError() << "requires !cir.int type for result";
675 return success();
676 }
677 case cir::CastKind::int_to_ptr: {
678 if (!mlir::dyn_cast<cir::IntType>(srcType))
679 return emitOpError() << "requires !cir.int type for source";
680 if (!mlir::dyn_cast<cir::PointerType>(resType))
681 return emitOpError() << "requires !cir.ptr type for result";
682 return success();
683 }
684 case cir::CastKind::ptr_to_int: {
685 if (!mlir::dyn_cast<cir::PointerType>(srcType))
686 return emitOpError() << "requires !cir.ptr type for source";
687 if (!mlir::dyn_cast<cir::IntType>(resType))
688 return emitOpError() << "requires !cir.int type for result";
689 return success();
690 }
691 case cir::CastKind::float_to_bool: {
692 if (!mlir::isa<cir::FPTypeInterface>(srcType))
693 return emitOpError() << "requires !cir.float type for source";
694 if (!mlir::isa<cir::BoolType>(resType))
695 return emitOpError() << "requires !cir.bool type for result";
696 return success();
697 }
698 case cir::CastKind::bool_to_int: {
699 if (!mlir::isa<cir::BoolType>(srcType))
700 return emitOpError() << "requires !cir.bool type for source";
701 if (!mlir::isa<cir::IntType>(resType))
702 return emitOpError() << "requires !cir.int type for result";
703 return success();
704 }
705 case cir::CastKind::int_to_float: {
706 if (!mlir::isa<cir::IntType>(srcType))
707 return emitOpError() << "requires !cir.int type for source";
708 if (!mlir::isa<cir::FPTypeInterface>(resType))
709 return emitOpError() << "requires !cir.float type for result";
710 return success();
711 }
712 case cir::CastKind::bool_to_float: {
713 if (!mlir::isa<cir::BoolType>(srcType))
714 return emitOpError() << "requires !cir.bool type for source";
715 if (!mlir::isa<cir::FPTypeInterface>(resType))
716 return emitOpError() << "requires !cir.float type for result";
717 return success();
718 }
719 case cir::CastKind::address_space: {
720 auto srcPtrTy = mlir::dyn_cast<cir::PointerType>(srcType);
721 auto resPtrTy = mlir::dyn_cast<cir::PointerType>(resType);
722 if (!srcPtrTy || !resPtrTy)
723 return emitOpError() << "requires !cir.ptr type for source and result";
724 if (srcPtrTy.getPointee() != resPtrTy.getPointee())
725 return emitOpError() << "requires two types differ in addrspace only";
726 return success();
727 }
728 case cir::CastKind::float_to_complex: {
729 if (!mlir::isa<cir::FPTypeInterface>(srcType))
730 return emitOpError() << "requires !cir.float type for source";
731 auto resComplexTy = mlir::dyn_cast<cir::ComplexType>(resType);
732 if (!resComplexTy)
733 return emitOpError() << "requires !cir.complex type for result";
734 if (srcType != resComplexTy.getElementType())
735 return emitOpError() << "requires source type match result element type";
736 return success();
737 }
738 case cir::CastKind::int_to_complex: {
739 if (!mlir::isa<cir::IntType>(srcType))
740 return emitOpError() << "requires !cir.int type for source";
741 auto resComplexTy = mlir::dyn_cast<cir::ComplexType>(resType);
742 if (!resComplexTy)
743 return emitOpError() << "requires !cir.complex type for result";
744 if (srcType != resComplexTy.getElementType())
745 return emitOpError() << "requires source type match result element type";
746 return success();
747 }
748 case cir::CastKind::float_complex_to_real: {
749 auto srcComplexTy = mlir::dyn_cast<cir::ComplexType>(srcType);
750 if (!srcComplexTy)
751 return emitOpError() << "requires !cir.complex type for source";
752 if (!mlir::isa<cir::FPTypeInterface>(resType))
753 return emitOpError() << "requires !cir.float type for result";
754 if (srcComplexTy.getElementType() != resType)
755 return emitOpError() << "requires source element type match result type";
756 return success();
757 }
758 case cir::CastKind::int_complex_to_real: {
759 auto srcComplexTy = mlir::dyn_cast<cir::ComplexType>(srcType);
760 if (!srcComplexTy)
761 return emitOpError() << "requires !cir.complex type for source";
762 if (!mlir::isa<cir::IntType>(resType))
763 return emitOpError() << "requires !cir.int type for result";
764 if (srcComplexTy.getElementType() != resType)
765 return emitOpError() << "requires source element type match result type";
766 return success();
767 }
768 case cir::CastKind::float_complex_to_bool: {
769 auto srcComplexTy = mlir::dyn_cast<cir::ComplexType>(srcType);
770 if (!srcComplexTy || !srcComplexTy.isFloatingPointComplex())
771 return emitOpError()
772 << "requires floating point !cir.complex type for source";
773 if (!mlir::isa<cir::BoolType>(resType))
774 return emitOpError() << "requires !cir.bool type for result";
775 return success();
776 }
777 case cir::CastKind::int_complex_to_bool: {
778 auto srcComplexTy = mlir::dyn_cast<cir::ComplexType>(srcType);
779 if (!srcComplexTy || !srcComplexTy.isIntegerComplex())
780 return emitOpError()
781 << "requires floating point !cir.complex type for source";
782 if (!mlir::isa<cir::BoolType>(resType))
783 return emitOpError() << "requires !cir.bool type for result";
784 return success();
785 }
786 case cir::CastKind::float_complex: {
787 auto srcComplexTy = mlir::dyn_cast<cir::ComplexType>(srcType);
788 if (!srcComplexTy || !srcComplexTy.isFloatingPointComplex())
789 return emitOpError()
790 << "requires floating point !cir.complex type for source";
791 auto resComplexTy = mlir::dyn_cast<cir::ComplexType>(resType);
792 if (!resComplexTy || !resComplexTy.isFloatingPointComplex())
793 return emitOpError()
794 << "requires floating point !cir.complex type for result";
795 return success();
796 }
797 case cir::CastKind::float_complex_to_int_complex: {
798 auto srcComplexTy = mlir::dyn_cast<cir::ComplexType>(srcType);
799 if (!srcComplexTy || !srcComplexTy.isFloatingPointComplex())
800 return emitOpError()
801 << "requires floating point !cir.complex type for source";
802 auto resComplexTy = mlir::dyn_cast<cir::ComplexType>(resType);
803 if (!resComplexTy || !resComplexTy.isIntegerComplex())
804 return emitOpError() << "requires integer !cir.complex type for result";
805 return success();
806 }
807 case cir::CastKind::int_complex: {
808 auto srcComplexTy = mlir::dyn_cast<cir::ComplexType>(srcType);
809 if (!srcComplexTy || !srcComplexTy.isIntegerComplex())
810 return emitOpError() << "requires integer !cir.complex type for source";
811 auto resComplexTy = mlir::dyn_cast<cir::ComplexType>(resType);
812 if (!resComplexTy || !resComplexTy.isIntegerComplex())
813 return emitOpError() << "requires integer !cir.complex type for result";
814 return success();
815 }
816 case cir::CastKind::int_complex_to_float_complex: {
817 auto srcComplexTy = mlir::dyn_cast<cir::ComplexType>(srcType);
818 if (!srcComplexTy || !srcComplexTy.isIntegerComplex())
819 return emitOpError() << "requires integer !cir.complex type for source";
820 auto resComplexTy = mlir::dyn_cast<cir::ComplexType>(resType);
821 if (!resComplexTy || !resComplexTy.isFloatingPointComplex())
822 return emitOpError()
823 << "requires floating point !cir.complex type for result";
824 return success();
825 }
826 case cir::CastKind::member_ptr_to_bool: {
827 if (!mlir::isa<cir::DataMemberType, cir::MethodType>(srcType))
828 return emitOpError()
829 << "requires !cir.data_member or !cir.method type for source";
830 if (!mlir::isa<cir::BoolType>(resType))
831 return emitOpError() << "requires !cir.bool type for result";
832 return success();
833 }
834 }
835 llvm_unreachable("Unknown CastOp kind?");
836}
837
838static bool isIntOrBoolCast(cir::CastOp op) {
839 auto kind = op.getKind();
840 return kind == cir::CastKind::bool_to_int ||
841 kind == cir::CastKind::int_to_bool || kind == cir::CastKind::integral;
842}
843
844static bool isCirFunctionPointerType(mlir::Type ty) {
845 const auto ptrTy = mlir::dyn_cast<cir::PointerType>(ty);
846 return ptrTy && mlir::isa<cir::FuncType>(ptrTy.getPointee());
847}
848
849static Value tryFoldCastChain(cir::CastOp op) {
850 cir::CastOp head = op, tail = op;
851
852 while (op) {
853 if (!isIntOrBoolCast(op))
854 break;
855 head = op;
856 op = head.getSrc().getDefiningOp<cir::CastOp>();
857 }
858
859 if (head != tail) {
860 // if bool_to_int -> ... -> int_to_bool: take the bool
861 // as we had it was before all casts
862 if (head.getKind() == cir::CastKind::bool_to_int &&
863 tail.getKind() == cir::CastKind::int_to_bool)
864 return head.getSrc();
865
866 // if int_to_bool -> ... -> int_to_bool: take the result
867 // of the first one, as no other casts (and ext casts as well)
868 // don't change the first result
869 if (head.getKind() == cir::CastKind::int_to_bool &&
870 tail.getKind() == cir::CastKind::int_to_bool)
871 return head.getResult();
872
873 return {};
874 }
875
876 // Bitcast round-trip on function pointers: T0 -> T1 -> T0 (e.g. no-proto
877 // redeclaration vs. actual prototype). Restrict to function pointers so
878 // other pointer bitcast chains are unchanged.
879 if (tail.getKind() == cir::CastKind::bitcast) {
880 auto *inner = tail.getSrc().getDefiningOp();
881 if (inner && isCirFunctionPointerType(tail.getType())) {
882 auto innerCast = mlir::dyn_cast<cir::CastOp>(inner);
883 if (innerCast && innerCast.getKind() == cir::CastKind::bitcast &&
884 innerCast.getSrc().getType() == tail.getType() &&
885 innerCast.getType() == tail.getSrc().getType()) {
886 return innerCast.getSrc();
887 }
888 }
889 }
890
891 return {};
892}
893
894OpFoldResult cir::CastOp::fold(FoldAdaptor adaptor) {
895 if (mlir::isa_and_present<cir::PoisonAttr>(adaptor.getSrc())) {
896 // Propagate poison value
897 return cir::PoisonAttr::get(getContext(), getType());
898 }
899
900 if (getSrc().getType() == getType()) {
901 switch (getKind()) {
902 case cir::CastKind::integral: {
904 auto foldOrder = getSrc().getDefiningOp()->fold(foldResults);
905 if (foldOrder.succeeded() && mlir::isa<mlir::Attribute>(foldResults[0]))
906 return mlir::cast<mlir::Attribute>(foldResults[0]);
907 return {};
908 }
909 case cir::CastKind::bitcast:
910 case cir::CastKind::address_space:
911 case cir::CastKind::float_complex:
912 case cir::CastKind::int_complex: {
913 return getSrc();
914 }
915 default:
916 return {};
917 }
918 }
919
920 // Handle cases where a chain of casts cancel out.
921 Value result = tryFoldCastChain(*this);
922 if (result)
923 return result;
924
925 // Handle simple constant casts.
926 if (auto srcConst = getSrc().getDefiningOp<cir::ConstantOp>()) {
927 switch (getKind()) {
928 case cir::CastKind::integral: {
929 mlir::Type srcTy = getSrc().getType();
930 // Don't try to fold vector casts for now.
931 assert(mlir::isa<cir::VectorType>(srcTy) ==
932 mlir::isa<cir::VectorType>(getType()));
933 if (mlir::isa<cir::VectorType>(srcTy))
934 break;
935
936 auto srcIntTy = mlir::cast<cir::IntType>(srcTy);
937 auto dstIntTy = mlir::cast<cir::IntType>(getType());
938 APInt newVal =
939 srcIntTy.isSigned()
940 ? srcConst.getIntValue().sextOrTrunc(dstIntTy.getWidth())
941 : srcConst.getIntValue().zextOrTrunc(dstIntTy.getWidth());
942 return cir::IntAttr::get(dstIntTy, newVal);
943 }
944 default:
945 break;
946 }
947 }
948 return {};
949}
950
951//===----------------------------------------------------------------------===//
952// CallOp
953//===----------------------------------------------------------------------===//
954
955mlir::OperandRange cir::CallOp::getArgOperands() {
956 if (isIndirect())
957 return getArgs().drop_front(1);
958 return getArgs();
959}
960
961mlir::MutableOperandRange cir::CallOp::getArgOperandsMutable() {
962 mlir::MutableOperandRange args = getArgsMutable();
963 if (isIndirect())
964 return args.slice(1, args.size() - 1);
965 return args;
966}
967
968mlir::Value cir::CallOp::getIndirectCall() {
969 assert(isIndirect());
970 return getOperand(0);
971}
972
973/// Return the operand at index 'i'.
974Value cir::CallOp::getArgOperand(unsigned i) {
975 if (isIndirect())
976 ++i;
977 return getOperand(i);
978}
979
980/// Return the number of operands.
981unsigned cir::CallOp::getNumArgOperands() {
982 if (isIndirect())
983 return this->getOperation()->getNumOperands() - 1;
984 return this->getOperation()->getNumOperands();
985}
986
987static mlir::ParseResult
988parseTryCallDestinations(mlir::OpAsmParser &parser,
989 mlir::OperationState &result) {
990 mlir::Block *normalDestSuccessor;
991 if (parser.parseSuccessor(normalDestSuccessor))
992 return mlir::failure();
993
994 if (parser.parseComma())
995 return mlir::failure();
996
997 mlir::Block *unwindDestSuccessor;
998 if (parser.parseSuccessor(unwindDestSuccessor))
999 return mlir::failure();
1000
1001 result.addSuccessors(normalDestSuccessor);
1002 result.addSuccessors(unwindDestSuccessor);
1003 return mlir::success();
1004}
1005
1006static mlir::ParseResult parseCallCommon(mlir::OpAsmParser &parser,
1007 mlir::OperationState &result,
1008 bool hasDestinationBlocks = false) {
1010 llvm::SMLoc opsLoc;
1011 mlir::FlatSymbolRefAttr calleeAttr;
1012
1013 // If we cannot parse a string callee, it means this is an indirect call.
1014 if (!parser
1015 .parseOptionalAttribute(calleeAttr, CIRDialect::getCalleeAttrName(),
1016 result.attributes)
1017 .has_value()) {
1018 OpAsmParser::UnresolvedOperand indirectVal;
1019 // Do not resolve right now, since we need to figure out the type
1020 if (parser.parseOperand(indirectVal).failed())
1021 return failure();
1022 ops.push_back(indirectVal);
1023 }
1024
1025 if (parser.parseLParen())
1026 return mlir::failure();
1027
1028 opsLoc = parser.getCurrentLocation();
1029 if (parser.parseOperandList(ops))
1030 return mlir::failure();
1031 if (parser.parseRParen())
1032 return mlir::failure();
1033
1034 if (hasDestinationBlocks &&
1035 parseTryCallDestinations(parser, result).failed()) {
1036 return ::mlir::failure();
1037 }
1038
1039 if (parser.parseOptionalKeyword("musttail").succeeded())
1040 result.addAttribute(CIRDialect::getMustTailAttrName(),
1041 mlir::UnitAttr::get(parser.getContext()));
1042
1043 if (parser.parseOptionalKeyword("nothrow").succeeded())
1044 result.addAttribute(CIRDialect::getNoThrowAttrName(),
1045 mlir::UnitAttr::get(parser.getContext()));
1046
1047 if (parser.parseOptionalKeyword("side_effect").succeeded()) {
1048 if (parser.parseLParen().failed())
1049 return failure();
1050 cir::SideEffect sideEffect;
1051 if (parseCIRKeyword<cir::SideEffect>(parser, sideEffect).failed())
1052 return failure();
1053 if (parser.parseRParen().failed())
1054 return failure();
1055 auto attr = cir::SideEffectAttr::get(parser.getContext(), sideEffect);
1056 result.addAttribute(CIRDialect::getSideEffectAttrName(), attr);
1057 }
1058
1059 if (parser.parseOptionalAttrDict(result.attributes))
1060 return ::mlir::failure();
1061
1062 if (parser.parseColon())
1063 return ::mlir::failure();
1064
1065 SmallVector<Type> argTypes;
1067 SmallVector<Type> resultTypes;
1068 SmallVector<DictionaryAttr> resultAttrs;
1069 if (call_interface_impl::parseFunctionSignature(parser, argTypes, argAttrs,
1070 resultTypes, resultAttrs))
1071 return mlir::failure();
1072
1073 if (resultTypes.size() > 1 || resultAttrs.size() > 1)
1074 return parser.emitError(
1075 parser.getCurrentLocation(),
1076 "functions with multiple return types are not supported");
1077
1078 result.addTypes(resultTypes);
1079
1080 if (parser.resolveOperands(ops, argTypes, opsLoc, result.operands))
1081 return mlir::failure();
1082
1083 if (!resultAttrs.empty() && resultAttrs[0])
1084 result.addAttribute(
1085 CIRDialect::getResAttrsAttrName(),
1086 mlir::ArrayAttr::get(parser.getContext(), {resultAttrs[0]}));
1087
1088 // ArrayAttr requires a vector of 'Attribute', so we have to do the conversion
1089 // here into a separate collection.
1090 llvm::SmallVector<Attribute> convertedArgAttrs;
1091 bool argAttrsEmpty = true;
1092
1093 llvm::transform(argAttrs, std::back_inserter(convertedArgAttrs),
1094 [&](DictionaryAttr da) -> mlir::Attribute {
1095 if (da)
1096 argAttrsEmpty = false;
1097 return da;
1098 });
1099
1100 if (!argAttrsEmpty) {
1101 llvm::ArrayRef argAttrsRef = convertedArgAttrs;
1102 if (!calleeAttr) {
1103 // Fixup for indirect calls, which get an extra entry in the 'args' for
1104 // the indirect type, which doesn't get attributes.
1105 argAttrsRef = argAttrsRef.drop_front();
1106 }
1107 result.addAttribute(CIRDialect::getArgAttrsAttrName(),
1108 mlir::ArrayAttr::get(parser.getContext(), argAttrsRef));
1109 }
1110
1111 return mlir::success();
1112}
1113
1114static void
1115printCallCommon(mlir::Operation *op, mlir::FlatSymbolRefAttr calleeSym,
1116 mlir::Value indirectCallee, mlir::OpAsmPrinter &printer,
1117 bool isNothrow, cir::SideEffect sideEffect, ArrayAttr argAttrs,
1118 ArrayAttr resAttrs, mlir::Block *normalDest = nullptr,
1119 mlir::Block *unwindDest = nullptr) {
1120 printer << ' ';
1121
1122 auto callLikeOp = mlir::cast<cir::CIRCallOpInterface>(op);
1123 auto ops = callLikeOp.getArgOperands();
1124
1125 if (calleeSym) {
1126 // Direct calls
1127 printer.printAttributeWithoutType(calleeSym);
1128 } else {
1129 // Indirect calls
1130 assert(indirectCallee);
1131 printer << indirectCallee;
1132 }
1133
1134 printer << "(" << ops << ")";
1135
1136 if (normalDest) {
1137 assert(unwindDest && "expected two successors");
1138 auto tryCall = cast<cir::TryCallOp>(op);
1139 printer << ' ' << tryCall.getNormalDest();
1140 printer << ",";
1141 printer << ' ';
1142 printer << tryCall.getUnwindDest();
1143 }
1144
1145 if (op->hasAttr(CIRDialect::getMustTailAttrName()))
1146 printer << " musttail";
1147
1148 if (isNothrow)
1149 printer << " nothrow";
1150
1151 if (sideEffect != cir::SideEffect::All) {
1152 printer << " side_effect(";
1153 printer << stringifySideEffect(sideEffect);
1154 printer << ")";
1155 }
1156
1158 CIRDialect::getCalleeAttrName(),
1159 CIRDialect::getMustTailAttrName(),
1160 CIRDialect::getNoThrowAttrName(),
1161 CIRDialect::getSideEffectAttrName(),
1162 CIRDialect::getOperandSegmentSizesAttrName(),
1163 llvm::StringRef("res_attrs"),
1164 llvm::StringRef("arg_attrs")};
1165 printer.printOptionalAttrDict(op->getAttrs(), elidedAttrs);
1166 printer << " : ";
1167 if (calleeSym || !argAttrs) {
1168 call_interface_impl::printFunctionSignature(
1169 printer, op->getOperands().getTypes(), argAttrs,
1170 /*isVariadic=*/false, op->getResultTypes(), resAttrs);
1171 } else {
1172 // indirect function calls use an 'arg' type for the type of its indirect
1173 // argument. However, we don't store a similar attribute collection. In
1174 // order to make `printFunctionSignature` have the attributes line up, we
1175 // have to make a 'shimmed' copy of the attributes that have a blank set of
1176 // attributes for the indirect argument.
1177 llvm::SmallVector<Attribute> shimmedArgAttrs;
1178 shimmedArgAttrs.push_back(mlir::DictionaryAttr::get(op->getContext(), {}));
1179 shimmedArgAttrs.append(argAttrs.begin(), argAttrs.end());
1180 call_interface_impl::printFunctionSignature(
1181 printer, op->getOperands().getTypes(),
1182 mlir::ArrayAttr::get(op->getContext(), shimmedArgAttrs),
1183 /*isVariadic=*/false, op->getResultTypes(), resAttrs);
1184 }
1185}
1186
1187mlir::ParseResult cir::CallOp::parse(mlir::OpAsmParser &parser,
1188 mlir::OperationState &result) {
1189 return parseCallCommon(parser, result);
1190}
1191
1192void cir::CallOp::print(mlir::OpAsmPrinter &p) {
1193 mlir::Value indirectCallee = isIndirect() ? getIndirectCall() : nullptr;
1194 cir::SideEffect sideEffect = getSideEffect();
1195 printCallCommon(*this, getCalleeAttr(), indirectCallee, p, getNothrow(),
1196 sideEffect, getArgAttrsAttr(), getResAttrsAttr());
1197}
1198
1199static LogicalResult
1200verifyCallCommInSymbolUses(mlir::Operation *op,
1201 SymbolTableCollection &symbolTable) {
1202 auto fnAttr =
1203 op->getAttrOfType<FlatSymbolRefAttr>(CIRDialect::getCalleeAttrName());
1204 if (!fnAttr) {
1205 // This is an indirect call, thus we don't have to check the symbol uses.
1206 return mlir::success();
1207 }
1208
1209 auto fn = symbolTable.lookupNearestSymbolFrom<cir::FuncOp>(op, fnAttr);
1210 if (!fn)
1211 return op->emitOpError() << "'" << fnAttr.getValue()
1212 << "' does not reference a valid function";
1213
1214 auto callIf = dyn_cast<cir::CIRCallOpInterface>(op);
1215 assert(callIf && "expected CIR call interface to be always available");
1216
1217 // Verify that the operand and result types match the callee. Note that
1218 // argument-checking is disabled for functions without a prototype.
1219 auto fnType = fn.getFunctionType();
1220 if (!fn.getNoProto()) {
1221 unsigned numCallOperands = callIf.getNumArgOperands();
1222 unsigned numFnOpOperands = fnType.getNumInputs();
1223
1224 if (!fnType.isVarArg() && numCallOperands != numFnOpOperands)
1225 return op->emitOpError("incorrect number of operands for callee");
1226 if (fnType.isVarArg() && numCallOperands < numFnOpOperands)
1227 return op->emitOpError("too few operands for callee");
1228
1229 for (unsigned i = 0, e = numFnOpOperands; i != e; ++i)
1230 if (callIf.getArgOperand(i).getType() != fnType.getInput(i))
1231 return op->emitOpError("operand type mismatch: expected operand type ")
1232 << fnType.getInput(i) << ", but provided "
1233 << op->getOperand(i).getType() << " for operand number " << i;
1234 }
1235
1237
1238 // Void function must not return any results.
1239 if (fnType.hasVoidReturn() && op->getNumResults() != 0)
1240 return op->emitOpError("callee returns void but call has results");
1241
1242 // Non-void function calls must return exactly one result.
1243 if (!fnType.hasVoidReturn() && op->getNumResults() != 1)
1244 return op->emitOpError("incorrect number of results for callee");
1245
1246 // Parent function and return value types must match.
1247 if (!fnType.hasVoidReturn() &&
1248 op->getResultTypes().front() != fnType.getReturnType()) {
1249 return op->emitOpError("result type mismatch: expected ")
1250 << fnType.getReturnType() << ", but provided "
1251 << op->getResult(0).getType();
1252 }
1253
1254 return mlir::success();
1255}
1256
1257LogicalResult
1258cir::CallOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1259 return verifyCallCommInSymbolUses(*this, symbolTable);
1260}
1261
1262//===----------------------------------------------------------------------===//
1263// TryCallOp
1264//===----------------------------------------------------------------------===//
1265
1266mlir::OperandRange cir::TryCallOp::getArgOperands() {
1267 if (isIndirect())
1268 return getArgs().drop_front(1);
1269 return getArgs();
1270}
1271
1272mlir::MutableOperandRange cir::TryCallOp::getArgOperandsMutable() {
1273 mlir::MutableOperandRange args = getArgsMutable();
1274 if (isIndirect())
1275 return args.slice(1, args.size() - 1);
1276 return args;
1277}
1278
1279mlir::Value cir::TryCallOp::getIndirectCall() {
1280 assert(isIndirect());
1281 return getOperand(0);
1282}
1283
1284/// Return the operand at index 'i'.
1285Value cir::TryCallOp::getArgOperand(unsigned i) {
1286 if (isIndirect())
1287 ++i;
1288 return getOperand(i);
1289}
1290
1291/// Return the number of operands.
1292unsigned cir::TryCallOp::getNumArgOperands() {
1293 if (isIndirect())
1294 return this->getOperation()->getNumOperands() - 1;
1295 return this->getOperation()->getNumOperands();
1296}
1297
1298LogicalResult
1299cir::TryCallOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1300 return verifyCallCommInSymbolUses(*this, symbolTable);
1301}
1302
1303mlir::ParseResult cir::TryCallOp::parse(mlir::OpAsmParser &parser,
1304 mlir::OperationState &result) {
1305 return parseCallCommon(parser, result, /*hasDestinationBlocks=*/true);
1306}
1307
1308void cir::TryCallOp::print(::mlir::OpAsmPrinter &p) {
1309 mlir::Value indirectCallee = isIndirect() ? getIndirectCall() : nullptr;
1310 cir::SideEffect sideEffect = getSideEffect();
1311 printCallCommon(*this, getCalleeAttr(), indirectCallee, p, getNothrow(),
1312 sideEffect, getArgAttrsAttr(), getResAttrsAttr(),
1313 getNormalDest(), getUnwindDest());
1314}
1315
1316//===----------------------------------------------------------------------===//
1317// ReturnOp
1318//===----------------------------------------------------------------------===//
1319
1320static mlir::LogicalResult checkReturnAndFunction(cir::ReturnOp op,
1321 cir::FuncOp function) {
1322 // ReturnOps currently only have a single optional operand.
1323 if (op.getNumOperands() > 1)
1324 return op.emitOpError() << "expects at most 1 return operand";
1325
1326 // Ensure returned type matches the function signature.
1327 auto expectedTy = function.getFunctionType().getReturnType();
1328 auto actualTy =
1329 (op.getNumOperands() == 0 ? cir::VoidType::get(op.getContext())
1330 : op.getOperand(0).getType());
1331 if (actualTy != expectedTy)
1332 return op.emitOpError() << "returns " << actualTy
1333 << " but enclosing function returns " << expectedTy;
1334
1335 return mlir::success();
1336}
1337
1338mlir::LogicalResult cir::ReturnOp::verify() {
1339 // Returns can be present in multiple different scopes, get the
1340 // wrapping function and start from there.
1341 auto *fnOp = getOperation()->getParentOp();
1342 while (!isa<cir::FuncOp>(fnOp))
1343 fnOp = fnOp->getParentOp();
1344
1345 // Make sure return types match function return type.
1346 if (checkReturnAndFunction(*this, cast<cir::FuncOp>(fnOp)).failed())
1347 return failure();
1348
1349 return success();
1350}
1351
1352//===----------------------------------------------------------------------===//
1353// IfOp
1354//===----------------------------------------------------------------------===//
1355
1356ParseResult cir::IfOp::parse(OpAsmParser &parser, OperationState &result) {
1357 // create the regions for 'then'.
1358 result.regions.reserve(2);
1359 Region *thenRegion = result.addRegion();
1360 Region *elseRegion = result.addRegion();
1361
1362 mlir::Builder &builder = parser.getBuilder();
1363 OpAsmParser::UnresolvedOperand cond;
1364 Type boolType = cir::BoolType::get(builder.getContext());
1365
1366 if (parser.parseOperand(cond) ||
1367 parser.resolveOperand(cond, boolType, result.operands))
1368 return failure();
1369
1370 // Parse 'then' region.
1371 mlir::SMLoc parseThenLoc = parser.getCurrentLocation();
1372 if (parser.parseRegion(*thenRegion, /*arguments=*/{}, /*argTypes=*/{}))
1373 return failure();
1374
1375 if (ensureRegionTerm(parser, *thenRegion, parseThenLoc).failed())
1376 return failure();
1377
1378 // If we find an 'else' keyword, parse the 'else' region.
1379 if (!parser.parseOptionalKeyword("else")) {
1380 mlir::SMLoc parseElseLoc = parser.getCurrentLocation();
1381 if (parser.parseRegion(*elseRegion, /*arguments=*/{}, /*argTypes=*/{}))
1382 return failure();
1383 if (ensureRegionTerm(parser, *elseRegion, parseElseLoc).failed())
1384 return failure();
1385 }
1386
1387 // Parse the optional attribute list.
1388 if (parser.parseOptionalAttrDict(result.attributes))
1389 return failure();
1390 return success();
1391}
1392
1393void cir::IfOp::print(OpAsmPrinter &p) {
1394 p << " " << getCondition() << " ";
1395 mlir::Region &thenRegion = this->getThenRegion();
1396 p.printRegion(thenRegion,
1397 /*printEntryBlockArgs=*/false,
1398 /*printBlockTerminators=*/!omitRegionTerm(thenRegion));
1399
1400 // Print the 'else' regions if it exists and has a block.
1401 mlir::Region &elseRegion = this->getElseRegion();
1402 if (!elseRegion.empty()) {
1403 p << " else ";
1404 p.printRegion(elseRegion,
1405 /*printEntryBlockArgs=*/false,
1406 /*printBlockTerminators=*/!omitRegionTerm(elseRegion));
1407 }
1408
1409 p.printOptionalAttrDict(getOperation()->getAttrs());
1410}
1411
1412/// Default callback for IfOp builders.
1413void cir::buildTerminatedBody(OpBuilder &builder, Location loc) {
1414 // add cir.yield to end of the block
1415 cir::YieldOp::create(builder, loc);
1416}
1417
1418/// Given the region at `index`, or the parent operation if `index` is None,
1419/// return the successor regions. These are the regions that may be selected
1420/// during the flow of control. `operands` is a set of optional attributes that
1421/// correspond to a constant value for each operand, or null if that operand is
1422/// not a constant.
1423void cir::IfOp::getSuccessorRegions(mlir::RegionBranchPoint point,
1424 SmallVectorImpl<RegionSuccessor> &regions) {
1425 // The `then` and the `else` region branch back to the parent operation.
1426 if (!point.isParent()) {
1427 regions.push_back(RegionSuccessor::parent());
1428 return;
1429 }
1430
1431 // Don't consider the else region if it is empty.
1432 Region *elseRegion = &this->getElseRegion();
1433 if (elseRegion->empty())
1434 elseRegion = nullptr;
1435
1436 // If the condition isn't constant, both regions may be executed.
1437 regions.push_back(RegionSuccessor(&getThenRegion()));
1438 if (elseRegion)
1439 regions.push_back(RegionSuccessor(elseRegion));
1440 else
1441 regions.push_back(RegionSuccessor::parent());
1442}
1443
1444mlir::ValueRange cir::IfOp::getSuccessorInputs(RegionSuccessor successor) {
1445 return successor.isParent() ? ValueRange(getOperation()->getResults())
1446 : ValueRange();
1447}
1448
1449void cir::IfOp::build(OpBuilder &builder, OperationState &result, Value cond,
1450 bool withElseRegion, BuilderCallbackRef thenBuilder,
1451 BuilderCallbackRef elseBuilder) {
1452 assert(thenBuilder && "the builder callback for 'then' must be present");
1453 result.addOperands(cond);
1454
1455 OpBuilder::InsertionGuard guard(builder);
1456 Region *thenRegion = result.addRegion();
1457 builder.createBlock(thenRegion);
1458 thenBuilder(builder, result.location);
1459
1460 Region *elseRegion = result.addRegion();
1461 if (!withElseRegion)
1462 return;
1463
1464 builder.createBlock(elseRegion);
1465 elseBuilder(builder, result.location);
1466}
1467
1468//===----------------------------------------------------------------------===//
1469// ScopeOp
1470//===----------------------------------------------------------------------===//
1471
1472/// Given the region at `index`, or the parent operation if `index` is None,
1473/// return the successor regions. These are the regions that may be selected
1474/// during the flow of control. `operands` is a set of optional attributes
1475/// that correspond to a constant value for each operand, or null if that
1476/// operand is not a constant.
1477void cir::ScopeOp::getSuccessorRegions(
1478 mlir::RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
1479 // The only region always branch back to the parent operation.
1480 if (!point.isParent()) {
1481 regions.push_back(RegionSuccessor::parent());
1482 return;
1483 }
1484
1485 // If the condition isn't constant, both regions may be executed.
1486 regions.push_back(RegionSuccessor(&getScopeRegion()));
1487}
1488
1489mlir::ValueRange cir::ScopeOp::getSuccessorInputs(RegionSuccessor successor) {
1490 return successor.isParent() ? ValueRange(getOperation()->getResults())
1491 : ValueRange();
1492}
1493
1494void cir::ScopeOp::build(
1495 OpBuilder &builder, OperationState &result,
1496 function_ref<void(OpBuilder &, Type &, Location)> scopeBuilder) {
1497 assert(scopeBuilder && "the builder callback for 'then' must be present");
1498
1499 OpBuilder::InsertionGuard guard(builder);
1500 Region *scopeRegion = result.addRegion();
1501 builder.createBlock(scopeRegion);
1503
1504 mlir::Type yieldTy;
1505 scopeBuilder(builder, yieldTy, result.location);
1506
1507 if (yieldTy)
1508 result.addTypes(TypeRange{yieldTy});
1509}
1510
1511void cir::ScopeOp::build(
1512 OpBuilder &builder, OperationState &result,
1513 function_ref<void(OpBuilder &, Location)> scopeBuilder) {
1514 assert(scopeBuilder && "the builder callback for 'then' must be present");
1515 OpBuilder::InsertionGuard guard(builder);
1516 Region *scopeRegion = result.addRegion();
1517 builder.createBlock(scopeRegion);
1519 scopeBuilder(builder, result.location);
1520}
1521
1522LogicalResult cir::ScopeOp::verify() {
1523 if (getRegion().empty()) {
1524 return emitOpError() << "cir.scope must not be empty since it should "
1525 "include at least an implicit cir.yield ";
1526 }
1527
1528 mlir::Block &lastBlock = getRegion().back();
1529 if (lastBlock.empty() || !lastBlock.mightHaveTerminator() ||
1530 !lastBlock.getTerminator()->hasTrait<OpTrait::IsTerminator>())
1531 return emitOpError() << "last block of cir.scope must be terminated";
1532 return success();
1533}
1534
1535LogicalResult cir::ScopeOp::fold(FoldAdaptor /*adaptor*/,
1536 SmallVectorImpl<OpFoldResult> &results) {
1537 // Only fold "trivial" scopes: a single block containing only a `cir.yield`.
1538 if (!getRegion().hasOneBlock())
1539 return failure();
1540 Block &block = getRegion().front();
1541 if (block.getOperations().size() != 1)
1542 return failure();
1543
1544 auto yield = dyn_cast<cir::YieldOp>(block.front());
1545 if (!yield)
1546 return failure();
1547
1548 // Only fold when the scope produces a value.
1549 if (getNumResults() != 1 || yield.getNumOperands() != 1)
1550 return failure();
1551
1552 results.push_back(yield.getOperand(0));
1553 return success();
1554}
1555
1556//===----------------------------------------------------------------------===//
1557// CleanupScopeOp
1558//===----------------------------------------------------------------------===//
1559
1560void cir::CleanupScopeOp::getSuccessorRegions(
1561 mlir::RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
1562 if (!point.isParent()) {
1563 regions.push_back(RegionSuccessor::parent());
1564 return;
1565 }
1566
1567 // Execution always proceeds from the body region to the cleanup region.
1568 regions.push_back(RegionSuccessor(&getBodyRegion()));
1569 regions.push_back(RegionSuccessor(&getCleanupRegion()));
1570}
1571
1572mlir::ValueRange
1573cir::CleanupScopeOp::getSuccessorInputs(RegionSuccessor successor) {
1574 return ValueRange();
1575}
1576
1577LogicalResult cir::CleanupScopeOp::canonicalize(CleanupScopeOp op,
1578 PatternRewriter &rewriter) {
1579 auto isRegionTrivial = [](Region &region) {
1580 assert(!region.empty() && "CleanupScopeOp regions must not be empty");
1581 if (!region.hasOneBlock())
1582 return false;
1583 Block &block = llvm::getSingleElement(region);
1584 return llvm::hasSingleElement(block) &&
1585 isa<cir::YieldOp>(llvm::getSingleElement(block));
1586 };
1587
1588 Region &body = op.getBodyRegion();
1589 Region &cleanup = op.getCleanupRegion();
1590
1591 // An EH-only cleanup scope with an empty body can never trigger its cleanup
1592 // region — there are no operations in the body that could throw. Erase it.
1593 if (op.getCleanupKind() == CleanupKind::EH && isRegionTrivial(body)) {
1594 rewriter.eraseOp(op);
1595 return success();
1596 }
1597
1598 // A cleanup scope with a trivial cleanup region has no cleanup to perform.
1599 // Inline the body into the parent block and erase the scope.
1600 if (!isRegionTrivial(cleanup) || !body.hasOneBlock())
1601 return failure();
1602
1603 Block &bodyBlock = body.front();
1604 if (!isa<cir::YieldOp>(bodyBlock.getTerminator()))
1605 return failure();
1606
1607 Operation *yield = bodyBlock.getTerminator();
1608 rewriter.inlineBlockBefore(&bodyBlock, op);
1609 rewriter.eraseOp(yield);
1610 rewriter.eraseOp(op);
1611 return success();
1612}
1613
1614void cir::CleanupScopeOp::build(
1615 OpBuilder &builder, OperationState &result, CleanupKind cleanupKind,
1616 function_ref<void(OpBuilder &, Location)> bodyBuilder,
1617 function_ref<void(OpBuilder &, Location)> cleanupBuilder) {
1618 result.addAttribute(getCleanupKindAttrName(result.name),
1619 CleanupKindAttr::get(builder.getContext(), cleanupKind));
1620
1621 OpBuilder::InsertionGuard guard(builder);
1622
1623 // Build body region.
1624 Region *bodyRegion = result.addRegion();
1625 builder.createBlock(bodyRegion);
1626 if (bodyBuilder)
1627 bodyBuilder(builder, result.location);
1628
1629 // Build cleanup region.
1630 Region *cleanupRegion = result.addRegion();
1631 builder.createBlock(cleanupRegion);
1632 if (cleanupBuilder)
1633 cleanupBuilder(builder, result.location);
1634}
1635
1636//===----------------------------------------------------------------------===//
1637// BrOp
1638//===----------------------------------------------------------------------===//
1639
1640/// Merges blocks connected by a unique unconditional branch.
1641///
1642/// ^bb0: ^bb0:
1643/// ... ...
1644/// cir.br ^bb1 => ...
1645/// ^bb1: cir.return
1646/// ...
1647/// cir.return
1648LogicalResult cir::BrOp::canonicalize(BrOp op, PatternRewriter &rewriter) {
1649 Block *src = op->getBlock();
1650 Block *dst = op.getDest();
1651
1652 // Do not fold self-loops.
1653 if (src == dst)
1654 return failure();
1655
1656 // Only merge when this is the unique edge between the blocks.
1657 if (src->getNumSuccessors() != 1 || dst->getSinglePredecessor() != src)
1658 return failure();
1659
1660 // Don't merge blocks that start with LabelOp or IndirectBrOp.
1661 // This is to avoid merging blocks that have an indirect predecessor.
1662 if (isa<cir::LabelOp, cir::IndirectBrOp>(dst->front()))
1663 return failure();
1664
1665 auto operands = op.getDestOperands();
1666 rewriter.eraseOp(op);
1667 rewriter.mergeBlocks(dst, src, operands);
1668 return success();
1669}
1670
1671mlir::SuccessorOperands cir::BrOp::getSuccessorOperands(unsigned index) {
1672 assert(index == 0 && "invalid successor index");
1673 return mlir::SuccessorOperands(getDestOperandsMutable());
1674}
1675
1676Block *cir::BrOp::getSuccessorForOperands(ArrayRef<Attribute>) {
1677 return getDest();
1678}
1679
1680//===----------------------------------------------------------------------===//
1681// IndirectBrCondOp
1682//===----------------------------------------------------------------------===//
1683
1684mlir::SuccessorOperands
1685cir::IndirectBrOp::getSuccessorOperands(unsigned index) {
1686 assert(index < getNumSuccessors() && "invalid successor index");
1687 return mlir::SuccessorOperands(getSuccOperandsMutable()[index]);
1688}
1689
1691 OpAsmParser &parser, Type &flagType,
1692 SmallVectorImpl<Block *> &succOperandBlocks,
1693 SmallVectorImpl<SmallVector<OpAsmParser::UnresolvedOperand>> &succOperands,
1694 SmallVectorImpl<SmallVector<Type>> &succOperandsTypes) {
1695 if (failed(parser.parseCommaSeparatedList(
1696 OpAsmParser::Delimiter::Square,
1697 [&]() {
1698 Block *destination = nullptr;
1699 SmallVector<OpAsmParser::UnresolvedOperand> operands;
1700 SmallVector<Type> operandTypes;
1701
1702 if (parser.parseSuccessor(destination).failed())
1703 return failure();
1704
1705 if (succeeded(parser.parseOptionalLParen())) {
1706 if (failed(parser.parseOperandList(
1707 operands, OpAsmParser::Delimiter::None)) ||
1708 failed(parser.parseColonTypeList(operandTypes)) ||
1709 failed(parser.parseRParen()))
1710 return failure();
1711 }
1712 succOperandBlocks.push_back(destination);
1713 succOperands.emplace_back(operands);
1714 succOperandsTypes.emplace_back(operandTypes);
1715 return success();
1716 },
1717 "successor blocks")))
1718 return failure();
1719 return success();
1720}
1721
1722void printIndirectBrOpSucessors(OpAsmPrinter &p, cir::IndirectBrOp op,
1723 Type flagType, SuccessorRange succs,
1724 OperandRangeRange succOperands,
1725 const TypeRangeRange &succOperandsTypes) {
1726 p << "[";
1727 llvm::interleave(
1728 llvm::zip(succs, succOperands),
1729 [&](auto i) {
1730 p.printNewline();
1731 p.printSuccessorAndUseList(std::get<0>(i), std::get<1>(i));
1732 },
1733 [&] { p << ','; });
1734 if (!succOperands.empty())
1735 p.printNewline();
1736 p << "]";
1737}
1738
1739//===----------------------------------------------------------------------===//
1740// BrCondOp
1741//===----------------------------------------------------------------------===//
1742
1743mlir::SuccessorOperands cir::BrCondOp::getSuccessorOperands(unsigned index) {
1744 assert(index < getNumSuccessors() && "invalid successor index");
1745 return SuccessorOperands(index == 0 ? getDestOperandsTrueMutable()
1746 : getDestOperandsFalseMutable());
1747}
1748
1749Block *cir::BrCondOp::getSuccessorForOperands(ArrayRef<Attribute> operands) {
1750 if (IntegerAttr condAttr = dyn_cast_if_present<IntegerAttr>(operands.front()))
1751 return condAttr.getValue().isOne() ? getDestTrue() : getDestFalse();
1752 return nullptr;
1753}
1754
1755//===----------------------------------------------------------------------===//
1756// CaseOp
1757//===----------------------------------------------------------------------===//
1758
1759void cir::CaseOp::getSuccessorRegions(
1760 mlir::RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
1761 if (!point.isParent()) {
1762 regions.push_back(RegionSuccessor::parent());
1763 return;
1764 }
1765 regions.push_back(RegionSuccessor(&getCaseRegion()));
1766}
1767
1768mlir::ValueRange cir::CaseOp::getSuccessorInputs(RegionSuccessor successor) {
1769 return successor.isParent() ? ValueRange(getOperation()->getResults())
1770 : ValueRange();
1771}
1772
1773void cir::CaseOp::build(OpBuilder &builder, OperationState &result,
1774 ArrayAttr value, CaseOpKind kind,
1775 OpBuilder::InsertPoint &insertPoint) {
1776 OpBuilder::InsertionGuard guardSwitch(builder);
1777 result.addAttribute("value", value);
1778 result.getOrAddProperties<Properties>().kind =
1779 cir::CaseOpKindAttr::get(builder.getContext(), kind);
1780 Region *caseRegion = result.addRegion();
1781 builder.createBlock(caseRegion);
1782
1783 insertPoint = builder.saveInsertionPoint();
1784}
1785
1786//===----------------------------------------------------------------------===//
1787// SwitchOp
1788//===----------------------------------------------------------------------===//
1789
1790void cir::SwitchOp::getSuccessorRegions(
1791 mlir::RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &region) {
1792 if (!point.isParent()) {
1793 region.push_back(RegionSuccessor::parent());
1794 return;
1795 }
1796
1797 region.push_back(RegionSuccessor(&getBody()));
1798}
1799
1800mlir::ValueRange cir::SwitchOp::getSuccessorInputs(RegionSuccessor successor) {
1801 return successor.isParent() ? ValueRange(getOperation()->getResults())
1802 : ValueRange();
1803}
1804
1805void cir::SwitchOp::build(OpBuilder &builder, OperationState &result,
1806 Value cond, BuilderOpStateCallbackRef switchBuilder) {
1807 assert(switchBuilder && "the builder callback for regions must be present");
1808 OpBuilder::InsertionGuard guardSwitch(builder);
1809 Region *switchRegion = result.addRegion();
1810 builder.createBlock(switchRegion);
1811 result.addOperands({cond});
1812 switchBuilder(builder, result.location, result);
1813}
1814
1815void cir::SwitchOp::collectCases(llvm::SmallVectorImpl<CaseOp> &cases) {
1816 walk<mlir::WalkOrder::PreOrder>([&](mlir::Operation *op) {
1817 // Don't walk in nested switch op.
1818 if (isa<cir::SwitchOp>(op) && op != *this)
1819 return WalkResult::skip();
1820
1821 if (auto caseOp = dyn_cast<cir::CaseOp>(op))
1822 cases.push_back(caseOp);
1823
1824 return WalkResult::advance();
1825 });
1826}
1827
1828bool cir::SwitchOp::isSimpleForm(llvm::SmallVectorImpl<CaseOp> &cases) {
1829 collectCases(cases);
1830
1831 if (getBody().empty())
1832 return false;
1833
1834 if (!isa<YieldOp>(getBody().front().back()))
1835 return false;
1836
1837 if (!llvm::all_of(getBody().front(),
1838 [](Operation &op) { return isa<CaseOp, YieldOp>(op); }))
1839 return false;
1840
1841 return llvm::all_of(cases, [this](CaseOp op) {
1842 return op->getParentOfType<SwitchOp>() == *this;
1843 });
1844}
1845
1846//===----------------------------------------------------------------------===//
1847// SwitchFlatOp
1848//===----------------------------------------------------------------------===//
1849
1850void cir::SwitchFlatOp::build(OpBuilder &builder, OperationState &result,
1851 Value value, Block *defaultDestination,
1852 ValueRange defaultOperands,
1853 ArrayRef<APInt> caseValues,
1854 BlockRange caseDestinations,
1855 ArrayRef<ValueRange> caseOperands) {
1856
1857 std::vector<mlir::Attribute> caseValuesAttrs;
1858 for (const APInt &val : caseValues)
1859 caseValuesAttrs.push_back(cir::IntAttr::get(value.getType(), val));
1860 mlir::ArrayAttr attrs = ArrayAttr::get(builder.getContext(), caseValuesAttrs);
1861
1862 build(builder, result, value, defaultOperands, caseOperands, attrs,
1863 defaultDestination, caseDestinations);
1864}
1865
1866/// <cases> ::= `[` (case (`,` case )* )? `]`
1867/// <case> ::= integer `:` bb-id (`(` ssa-use-and-type-list `)`)?
1868static ParseResult parseSwitchFlatOpCases(
1869 OpAsmParser &parser, Type flagType, mlir::ArrayAttr &caseValues,
1870 SmallVectorImpl<Block *> &caseDestinations,
1872 &caseOperands,
1873 SmallVectorImpl<llvm::SmallVector<Type>> &caseOperandTypes) {
1874 if (failed(parser.parseLSquare()))
1875 return failure();
1876 if (succeeded(parser.parseOptionalRSquare()))
1877 return success();
1879
1880 auto parseCase = [&]() {
1881 int64_t value = 0;
1882 if (failed(parser.parseInteger(value)))
1883 return failure();
1884
1885 values.push_back(cir::IntAttr::get(flagType, value));
1886
1887 Block *destination;
1889 llvm::SmallVector<Type> operandTypes;
1890 if (parser.parseColon() || parser.parseSuccessor(destination))
1891 return failure();
1892 if (!parser.parseOptionalLParen()) {
1893 if (parser.parseOperandList(operands, OpAsmParser::Delimiter::None,
1894 /*allowResultNumber=*/false) ||
1895 parser.parseColonTypeList(operandTypes) || parser.parseRParen())
1896 return failure();
1897 }
1898 caseDestinations.push_back(destination);
1899 caseOperands.emplace_back(operands);
1900 caseOperandTypes.emplace_back(operandTypes);
1901 return success();
1902 };
1903 if (failed(parser.parseCommaSeparatedList(parseCase)))
1904 return failure();
1905
1906 caseValues = ArrayAttr::get(flagType.getContext(), values);
1907
1908 return parser.parseRSquare();
1909}
1910
1911static void printSwitchFlatOpCases(OpAsmPrinter &p, cir::SwitchFlatOp op,
1912 Type flagType, mlir::ArrayAttr caseValues,
1913 SuccessorRange caseDestinations,
1914 OperandRangeRange caseOperands,
1915 const TypeRangeRange &caseOperandTypes) {
1916 p << '[';
1917 p.printNewline();
1918 if (!caseValues) {
1919 p << ']';
1920 return;
1921 }
1922
1923 size_t index = 0;
1924 llvm::interleave(
1925 llvm::zip(caseValues, caseDestinations),
1926 [&](auto i) {
1927 p << " ";
1928 mlir::Attribute a = std::get<0>(i);
1929 p << mlir::cast<cir::IntAttr>(a).getValue();
1930 p << ": ";
1931 p.printSuccessorAndUseList(std::get<1>(i), caseOperands[index++]);
1932 },
1933 [&] {
1934 p << ',';
1935 p.printNewline();
1936 });
1937 p.printNewline();
1938 p << ']';
1939}
1940
1941//===----------------------------------------------------------------------===//
1942// GlobalOp
1943//===----------------------------------------------------------------------===//
1944
1945static ParseResult parseConstantValue(OpAsmParser &parser,
1946 mlir::Attribute &valueAttr) {
1947 NamedAttrList attr;
1948 return parser.parseAttribute(valueAttr, "value", attr);
1949}
1950
1951static void printConstant(OpAsmPrinter &p, Attribute value) {
1952 p.printAttribute(value);
1953}
1954
1955mlir::LogicalResult cir::GlobalOp::verify() {
1956 // Verify that the initial value, if present, is either a unit attribute or
1957 // an attribute CIR supports.
1958 if (getInitialValue().has_value()) {
1959 if (checkConstantTypes(getOperation(), getSymType(), *getInitialValue())
1960 .failed())
1961 return failure();
1962 }
1963
1964 if ((getStaticLocalGuard().has_value()) &&
1965 (!getCtorRegion().empty() || !getDtorRegion().empty()))
1966 return emitOpError(
1967 "Cannot have a static-local global-op with a constructor or "
1968 "destructor, they require in-function initialization via LocalInitOp");
1969
1970 if (getDynTlsRefs()) {
1971 if (getStaticLocalGuard().has_value())
1972 return emitOpError(
1973 "cannot have both static local and dynamic tls references");
1974 if (!getTlsModel() || getTlsModel() != TLS_Model::GeneralDynamic)
1975 return emitOpError("'dyn_tls_refs' only valid for dynamic tls");
1976 }
1977
1978 if (getAliasee().has_value()) {
1979 if (getInitialValue().has_value() || !getCtorRegion().empty() ||
1980 !getDtorRegion().empty())
1981 return emitOpError("global alias shall not have an initializer or "
1982 "constructor/destructor regions");
1983 }
1984
1985 // TODO(CIR): Many other checks for properties that haven't been upstreamed
1986 // yet.
1987
1988 return success();
1989}
1990
1991void cir::GlobalOp::build(
1992 OpBuilder &odsBuilder, OperationState &odsState, llvm::StringRef sym_name,
1993 mlir::Type sym_type, bool isConstant,
1994 mlir::ptr::MemorySpaceAttrInterface addrSpace,
1995 cir::GlobalLinkageKind linkage,
1996 function_ref<void(OpBuilder &, Location)> ctorBuilder,
1997 function_ref<void(OpBuilder &, Location)> dtorBuilder) {
1998 odsState.addAttribute(getSymNameAttrName(odsState.name),
1999 odsBuilder.getStringAttr(sym_name));
2000 odsState.addAttribute(getSymTypeAttrName(odsState.name),
2001 mlir::TypeAttr::get(sym_type));
2002 auto &properties = odsState.getOrAddProperties<cir::GlobalOp::Properties>();
2003 properties.setConstant(isConstant);
2004
2005 addrSpace = normalizeDefaultAddressSpace(addrSpace);
2006 if (addrSpace)
2007 odsState.addAttribute(getAddrSpaceAttrName(odsState.name), addrSpace);
2008
2009 cir::GlobalLinkageKindAttr linkageAttr =
2010 cir::GlobalLinkageKindAttr::get(odsBuilder.getContext(), linkage);
2011 odsState.addAttribute(getLinkageAttrName(odsState.name), linkageAttr);
2012
2013 Region *ctorRegion = odsState.addRegion();
2014 if (ctorBuilder) {
2015 odsBuilder.createBlock(ctorRegion);
2016 ctorBuilder(odsBuilder, odsState.location);
2017 }
2018
2019 Region *dtorRegion = odsState.addRegion();
2020 if (dtorBuilder) {
2021 odsBuilder.createBlock(dtorRegion);
2022 dtorBuilder(odsBuilder, odsState.location);
2023 }
2024}
2025
2026/// Given the region at `index`, or the parent operation if `index` is None,
2027/// return the successor regions. These are the regions that may be selected
2028/// during the flow of control. `operands` is a set of optional attributes that
2029/// correspond to a constant value for each operand, or null if that operand is
2030/// not a constant.
2031void cir::GlobalOp::getSuccessorRegions(
2032 mlir::RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
2033 // The `ctor` and `dtor` regions always branch back to the parent operation.
2034 if (!point.isParent()) {
2035 regions.push_back(RegionSuccessor::parent());
2036 return;
2037 }
2038
2039 // Don't consider the ctor region if it is empty.
2040 Region *ctorRegion = &this->getCtorRegion();
2041 if (ctorRegion->empty())
2042 ctorRegion = nullptr;
2043
2044 // Don't consider the dtor region if it is empty.
2045 Region *dtorRegion = &this->getDtorRegion();
2046 if (dtorRegion->empty())
2047 dtorRegion = nullptr;
2048
2049 // If the condition isn't constant, both regions may be executed.
2050 if (ctorRegion)
2051 regions.push_back(RegionSuccessor(ctorRegion));
2052 if (dtorRegion)
2053 regions.push_back(RegionSuccessor(dtorRegion));
2054}
2055
2056mlir::ValueRange cir::GlobalOp::getSuccessorInputs(RegionSuccessor successor) {
2057 return successor.isParent() ? ValueRange(getOperation()->getResults())
2058 : ValueRange();
2059}
2060
2061static void printGlobalOpTypeAndInitialValue(OpAsmPrinter &p, cir::GlobalOp op,
2062 TypeAttr type, Attribute initAttr,
2063 mlir::Region &ctorRegion,
2064 mlir::Region &dtorRegion) {
2065 auto printType = [&]() { p << ": " << type; };
2066 // Aliases are definitions but they have no initial value or ctor/dtor; the
2067 // assembly prints them like declarations (`: type`).
2068 if (op.isDeclaration() || op.getAliasee()) {
2069 printType();
2070 return;
2071 }
2072
2073 p << "= ";
2074 if (!ctorRegion.empty()) {
2075 p << "ctor ";
2076 printType();
2077 p << " ";
2078 p.printRegion(ctorRegion,
2079 /*printEntryBlockArgs=*/false,
2080 /*printBlockTerminators=*/false);
2081 } else {
2082 // This also prints the type...
2083 if (initAttr)
2084 printConstant(p, initAttr);
2085 }
2086
2087 if (!dtorRegion.empty()) {
2088 p << " dtor ";
2089 p.printRegion(dtorRegion,
2090 /*printEntryBlockArgs=*/false,
2091 /*printBlockTerminators=*/false);
2092 }
2093}
2094
2095static ParseResult parseGlobalOpTypeAndInitialValue(OpAsmParser &parser,
2096 TypeAttr &typeAttr,
2097 Attribute &initialValueAttr,
2098 mlir::Region &ctorRegion,
2099 mlir::Region &dtorRegion) {
2100 mlir::Type opTy;
2101 if (parser.parseOptionalEqual().failed()) {
2102 // Absence of equal means a declaration, so we need to parse the type.
2103 // cir.global @a : !cir.int<s, 32>
2104 if (parser.parseColonType(opTy))
2105 return failure();
2106 } else {
2107 // Parse contructor, example:
2108 // cir.global @rgb = ctor : type { ... }
2109 if (!parser.parseOptionalKeyword("ctor")) {
2110 if (parser.parseColonType(opTy))
2111 return failure();
2112 auto parseLoc = parser.getCurrentLocation();
2113 if (parser.parseRegion(ctorRegion, /*arguments=*/{}, /*argTypes=*/{}))
2114 return failure();
2115 if (ensureRegionTerm(parser, ctorRegion, parseLoc).failed())
2116 return failure();
2117 } else {
2118 // Parse constant with initializer, examples:
2119 // cir.global @y = 3.400000e+00 : f32
2120 // cir.global @rgb = #cir.const_array<[...] : !cir.array<i8 x 3>>
2121 if (parseConstantValue(parser, initialValueAttr).failed())
2122 return failure();
2123
2124 assert(mlir::isa<mlir::TypedAttr>(initialValueAttr) &&
2125 "Non-typed attrs shouldn't appear here.");
2126 auto typedAttr = mlir::cast<mlir::TypedAttr>(initialValueAttr);
2127 opTy = typedAttr.getType();
2128 }
2129
2130 // Parse destructor, example:
2131 // dtor { ... }
2132 if (!parser.parseOptionalKeyword("dtor")) {
2133 auto parseLoc = parser.getCurrentLocation();
2134 if (parser.parseRegion(dtorRegion, /*arguments=*/{}, /*argTypes=*/{}))
2135 return failure();
2136 if (ensureRegionTerm(parser, dtorRegion, parseLoc).failed())
2137 return failure();
2138 }
2139 }
2140
2141 typeAttr = TypeAttr::get(opTy);
2142 return success();
2143}
2144
2145//===----------------------------------------------------------------------===//
2146// GetGlobalOp
2147//===----------------------------------------------------------------------===//
2148
2149LogicalResult
2150cir::GetGlobalOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
2151 // Verify that the result type underlying pointer type matches the type of
2152 // the referenced cir.global or cir.func op.
2153 mlir::Operation *op =
2154 symbolTable.lookupNearestSymbolFrom(*this, getNameAttr());
2155 if (op == nullptr || !(isa<GlobalOp>(op) || isa<FuncOp>(op)))
2156 return emitOpError("'")
2157 << getName()
2158 << "' does not reference a valid cir.global or cir.func";
2159
2160 mlir::Type symTy;
2161 mlir::ptr::MemorySpaceAttrInterface symAddrSpaceAttr{};
2162 if (auto g = dyn_cast<GlobalOp>(op)) {
2163 symTy = g.getSymType();
2164 symAddrSpaceAttr = g.getAddrSpaceAttr();
2165 // Verify that for thread local global access, the global needs to
2166 // be marked with tls bits.
2167 if (getTls() && !g.getTlsModel())
2168 return emitOpError("access to global not marked thread local");
2169
2170 // Verify that the static_local attribute on GetGlobalOp matches the
2171 // static_local_guard attribute on GlobalOp. GetGlobalOp uses a UnitAttr,
2172 // GlobalOp uses StaticLocalGuardAttr. Both should be present, or neither.
2173 bool getGlobalIsStaticLocal = getStaticLocal();
2174 bool globalIsStaticLocal = g.getStaticLocalGuard().has_value();
2175 if (getGlobalIsStaticLocal != globalIsStaticLocal &&
2176 !getOperation()->getParentOfType<cir::GlobalOp>())
2177 return emitOpError("static_local attribute mismatch");
2178 } else if (auto f = dyn_cast<FuncOp>(op)) {
2179 symTy = f.getFunctionType();
2180 } else {
2181 llvm_unreachable("Unexpected operation for GetGlobalOp");
2182 }
2183
2184 auto resultType = dyn_cast<PointerType>(getAddr().getType());
2185 if (!resultType || symTy != resultType.getPointee())
2186 return emitOpError("result type pointee type '")
2187 << resultType.getPointee() << "' does not match type " << symTy
2188 << " of the global @" << getName();
2189
2190 if (symAddrSpaceAttr != resultType.getAddrSpace()) {
2191 return emitOpError()
2192 << "result type address space does not match the address "
2193 "space of the global @"
2194 << getName();
2195 }
2196
2197 return success();
2198}
2199
2200//===----------------------------------------------------------------------===//
2201// VTableAddrPointOp
2202//===----------------------------------------------------------------------===//
2203
2204LogicalResult
2205cir::VTableAddrPointOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
2206 StringRef name = getName();
2207
2208 // Verify that the result type underlying pointer type matches the type of
2209 // the referenced cir.global.
2210 auto op =
2211 symbolTable.lookupNearestSymbolFrom<cir::GlobalOp>(*this, getNameAttr());
2212 if (!op)
2213 return emitOpError("'")
2214 << name << "' does not reference a valid cir.global";
2215 std::optional<mlir::Attribute> init = op.getInitialValue();
2216 if (!init)
2217 return success();
2218 if (!isa<cir::VTableAttr>(*init))
2219 return emitOpError("Expected #cir.vtable in initializer for global '")
2220 << name << "'";
2221 return success();
2222}
2223
2224//===----------------------------------------------------------------------===//
2225// VTTAddrPointOp
2226//===----------------------------------------------------------------------===//
2227
2228LogicalResult
2229cir::VTTAddrPointOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
2230 // VTT ptr is not coming from a symbol.
2231 if (!getName())
2232 return success();
2233 StringRef name = *getName();
2234
2235 // Verify that the result type underlying pointer type matches the type of
2236 // the referenced cir.global op.
2237 auto op =
2238 symbolTable.lookupNearestSymbolFrom<cir::GlobalOp>(*this, getNameAttr());
2239 if (!op)
2240 return emitOpError("'")
2241 << name << "' does not reference a valid cir.global";
2242 std::optional<mlir::Attribute> init = op.getInitialValue();
2243 if (!init)
2244 return success();
2245 if (!isa<cir::ConstArrayAttr>(*init))
2246 return emitOpError(
2247 "Expected constant array in initializer for global VTT '")
2248 << name << "'";
2249 return success();
2250}
2251
2252LogicalResult cir::VTTAddrPointOp::verify() {
2253 // The operation uses either a symbol or a value to operate, but not both
2254 if (getName() && getSymAddr())
2255 return emitOpError("should use either a symbol or value, but not both");
2256
2257 // If not a symbol, stick with the concrete type used for getSymAddr.
2258 if (getSymAddr())
2259 return success();
2260
2261 mlir::Type resultType = getAddr().getType();
2262 mlir::Type resTy = cir::PointerType::get(
2263 cir::PointerType::get(cir::VoidType::get(getContext())));
2264
2265 if (resultType != resTy)
2266 return emitOpError("result type must be ")
2267 << resTy << ", but provided result type is " << resultType;
2268 return success();
2269}
2270
2271//===----------------------------------------------------------------------===//
2272// FuncOp
2273//===----------------------------------------------------------------------===//
2274
2275/// Returns the name used for the linkage attribute. This *must* correspond to
2276/// the name of the attribute in ODS.
2277static llvm::StringRef getLinkageAttrNameString() { return "linkage"; }
2278
2279void cir::FuncOp::build(OpBuilder &builder, OperationState &result,
2280 StringRef name, FuncType type,
2281 GlobalLinkageKind linkage, CallingConv callingConv) {
2282 result.addRegion();
2283 result.addAttribute(SymbolTable::getSymbolAttrName(),
2284 builder.getStringAttr(name));
2285 result.addAttribute(getFunctionTypeAttrName(result.name),
2286 TypeAttr::get(type));
2287 result.addAttribute(
2289 GlobalLinkageKindAttr::get(builder.getContext(), linkage));
2290 result.addAttribute(getCallingConvAttrName(result.name),
2291 CallingConvAttr::get(builder.getContext(), callingConv));
2292}
2293
2294//===----------------------------------------------------------------------===//
2295// AnnotationAttr
2296//===----------------------------------------------------------------------===//
2297
2298LogicalResult
2299cir::AnnotationAttr::verify(function_ref<InFlightDiagnostic()> emitError,
2300 mlir::StringAttr name, mlir::ArrayAttr args) {
2301 if (!args)
2302 return success();
2303 for (mlir::Attribute arg : args) {
2304 if (!isa<mlir::StringAttr, mlir::IntegerAttr>(arg))
2305 return emitError() << "annotation args must be StringAttr or IntegerAttr,"
2306 << " got " << arg;
2307 }
2308 return success();
2309}
2310
2311ParseResult cir::FuncOp::parse(OpAsmParser &parser, OperationState &state) {
2312 llvm::SMLoc loc = parser.getCurrentLocation();
2313 mlir::Builder &builder = parser.getBuilder();
2314
2315 mlir::StringAttr builtinNameAttr = getBuiltinAttrName(state.name);
2316 mlir::StringAttr coroutineNameAttr = getCoroutineAttrName(state.name);
2317 mlir::StringAttr inlineKindNameAttr = getInlineKindAttrName(state.name);
2318 mlir::StringAttr lambdaNameAttr = getLambdaAttrName(state.name);
2319 mlir::StringAttr noProtoNameAttr = getNoProtoAttrName(state.name);
2320 mlir::StringAttr comdatNameAttr = getComdatAttrName(state.name);
2321 mlir::StringAttr visNameAttr = getSymVisibilityAttrName(state.name);
2322 mlir::StringAttr dsoLocalNameAttr = getDsoLocalAttrName(state.name);
2323 mlir::StringAttr specialMemberAttr = getCxxSpecialMemberAttrName(state.name);
2324
2325 if (::mlir::succeeded(parser.parseOptionalKeyword(builtinNameAttr.strref())))
2326 state.addAttribute(builtinNameAttr, parser.getBuilder().getUnitAttr());
2327 if (::mlir::succeeded(
2328 parser.parseOptionalKeyword(coroutineNameAttr.strref())))
2329 state.addAttribute(coroutineNameAttr, parser.getBuilder().getUnitAttr());
2330
2331 // Parse optional inline kind attribute
2332 cir::InlineKindAttr inlineKindAttr;
2333 if (failed(parseInlineKindAttr(parser, inlineKindAttr)))
2334 return failure();
2335 if (inlineKindAttr)
2336 state.addAttribute(inlineKindNameAttr, inlineKindAttr);
2337
2338 if (::mlir::succeeded(parser.parseOptionalKeyword(lambdaNameAttr.strref())))
2339 state.addAttribute(lambdaNameAttr, parser.getBuilder().getUnitAttr());
2340 if (parser.parseOptionalKeyword(noProtoNameAttr).succeeded())
2341 state.addAttribute(noProtoNameAttr, parser.getBuilder().getUnitAttr());
2342
2343 if (parser.parseOptionalKeyword(comdatNameAttr).succeeded())
2344 state.addAttribute(comdatNameAttr, parser.getBuilder().getUnitAttr());
2345
2346 // Default to external linkage if no keyword is provided.
2347 state.addAttribute(getLinkageAttrNameString(),
2348 GlobalLinkageKindAttr::get(
2349 parser.getContext(),
2351 parser, GlobalLinkageKind::ExternalLinkage)));
2352
2353 ::llvm::StringRef visAttrStr;
2354 if (parser.parseOptionalKeyword(&visAttrStr, {"private", "public", "nested"})
2355 .succeeded()) {
2356 state.addAttribute(visNameAttr,
2357 parser.getBuilder().getStringAttr(visAttrStr));
2358 }
2359
2360 state.getOrAddProperties<cir::FuncOp::Properties>().global_visibility =
2361 parseOptionalCIRKeyword(parser, cir::VisibilityKind::Default);
2362
2363 if (parser.parseOptionalKeyword(dsoLocalNameAttr).succeeded())
2364 state.addAttribute(dsoLocalNameAttr, parser.getBuilder().getUnitAttr());
2365
2366 StringAttr nameAttr;
2367 if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
2368 state.attributes))
2369 return failure();
2373 bool isVariadic = false;
2374 if (function_interface_impl::parseFunctionSignatureWithArguments(
2375 parser, /*allowVariadic=*/true, arguments, isVariadic, resultTypes,
2376 resultAttrs))
2377 return failure();
2380 bool argAttrsEmpty = true;
2381 for (OpAsmParser::Argument &arg : arguments) {
2382 argTypes.push_back(arg.type);
2383 // Add the 'empty' attribute anyway to make sure the arity matches, but we
2384 // only want to 'set' the attribute at the top level if there is SOME data
2385 // along the way.
2386 argAttrs.push_back(arg.attrs);
2387 if (arg.attrs)
2388 argAttrsEmpty = false;
2389 }
2390
2391 // These should be in sync anyway, but test both of them anyway.
2392 if (resultTypes.size() > 1 || resultAttrs.size() > 1)
2393 return parser.emitError(
2394 loc, "functions with multiple return types are not supported");
2395
2396 mlir::Type returnType =
2397 (resultTypes.empty() ? cir::VoidType::get(builder.getContext())
2398 : resultTypes.front());
2399
2400 cir::FuncType fnType = cir::FuncType::get(argTypes, returnType, isVariadic);
2401 if (!fnType)
2402 return failure();
2403
2404 state.addAttribute(getFunctionTypeAttrName(state.name),
2405 TypeAttr::get(fnType));
2406
2407 if (!resultAttrs.empty() && resultAttrs[0])
2408 state.addAttribute(
2409 getResAttrsAttrName(state.name),
2410 mlir::ArrayAttr::get(parser.getContext(), {resultAttrs[0]}));
2411
2412 if (!argAttrsEmpty)
2413 state.addAttribute(getArgAttrsAttrName(state.name),
2414 mlir::ArrayAttr::get(parser.getContext(), argAttrs));
2415
2416 bool hasAlias = false;
2417 mlir::StringAttr aliaseeNameAttr = getAliaseeAttrName(state.name);
2418 if (parser.parseOptionalKeyword("alias").succeeded()) {
2419 if (parser.parseLParen().failed())
2420 return failure();
2421 mlir::StringAttr aliaseeAttr;
2422 if (parser.parseOptionalSymbolName(aliaseeAttr).failed())
2423 return failure();
2424 state.addAttribute(aliaseeNameAttr, FlatSymbolRefAttr::get(aliaseeAttr));
2425 if (parser.parseRParen().failed())
2426 return failure();
2427 hasAlias = true;
2428 }
2429
2430 mlir::StringAttr personalityNameAttr = getPersonalityAttrName(state.name);
2431 if (parser.parseOptionalKeyword("personality").succeeded()) {
2432 if (parser.parseLParen().failed())
2433 return failure();
2434 mlir::StringAttr personalityAttr;
2435 if (parser.parseOptionalSymbolName(personalityAttr).failed())
2436 return failure();
2437 state.addAttribute(personalityNameAttr,
2438 FlatSymbolRefAttr::get(personalityAttr));
2439 if (parser.parseRParen().failed())
2440 return failure();
2441 }
2442
2443 // Default to C calling convention if no keyword is provided.
2444 mlir::StringAttr callConvNameAttr = getCallingConvAttrName(state.name);
2445 cir::CallingConv callConv = cir::CallingConv::C;
2446 if (parser.parseOptionalKeyword("cc").succeeded()) {
2447 if (parser.parseLParen().failed())
2448 return failure();
2449 if (parseCIRKeyword<cir::CallingConv>(parser, callConv).failed())
2450 return parser.emitError(loc) << "unknown calling convention";
2451 if (parser.parseRParen().failed())
2452 return failure();
2453 }
2454 state.addAttribute(callConvNameAttr,
2455 cir::CallingConvAttr::get(parser.getContext(), callConv));
2456
2457 auto parseGlobalDtorCtor =
2458 [&](StringRef keyword,
2459 llvm::function_ref<void(std::optional<int> prio)> createAttr)
2460 -> mlir::LogicalResult {
2461 if (mlir::succeeded(parser.parseOptionalKeyword(keyword))) {
2462 std::optional<int> priority;
2463 if (mlir::succeeded(parser.parseOptionalLParen())) {
2464 auto parsedPriority = mlir::FieldParser<int>::parse(parser);
2465 if (mlir::failed(parsedPriority))
2466 return parser.emitError(parser.getCurrentLocation(),
2467 "failed to parse 'priority', of type 'int'");
2468 priority = parsedPriority.value_or(int());
2469 // Parse literal ')'
2470 if (parser.parseRParen())
2471 return failure();
2472 }
2473 createAttr(priority);
2474 }
2475 return success();
2476 };
2477
2478 // Parse CXXSpecialMember attribute
2479 if (parser.parseOptionalKeyword("special_member").succeeded()) {
2480 if (parser.parseLess().failed())
2481 return failure();
2482
2483 mlir::Attribute attr;
2484 if (parser.parseAttribute(attr).failed())
2485 return failure();
2486 if (!mlir::isa<cir::CXXCtorAttr, cir::CXXDtorAttr, cir::CXXAssignAttr>(
2487 attr))
2488 return parser.emitError(parser.getCurrentLocation(),
2489 "expected a C++ special member attribute");
2490 state.addAttribute(specialMemberAttr, attr);
2491
2492 if (parser.parseGreater().failed())
2493 return failure();
2494 }
2495
2496 if (parseGlobalDtorCtor("global_ctor", [&](std::optional<int> priority) {
2497 mlir::IntegerAttr globalCtorPriorityAttr =
2498 builder.getI32IntegerAttr(priority.value_or(65535));
2499 state.addAttribute(getGlobalCtorPriorityAttrName(state.name),
2500 globalCtorPriorityAttr);
2501 }).failed())
2502 return failure();
2503
2504 if (parseGlobalDtorCtor("global_dtor", [&](std::optional<int> priority) {
2505 mlir::IntegerAttr globalDtorPriorityAttr =
2506 builder.getI32IntegerAttr(priority.value_or(65535));
2507 state.addAttribute(getGlobalDtorPriorityAttrName(state.name),
2508 globalDtorPriorityAttr);
2509 }).failed())
2510 return failure();
2511
2512 if (parser.parseOptionalKeyword("side_effect").succeeded()) {
2513 cir::SideEffect sideEffect;
2514
2515 if (parser.parseLParen().failed() ||
2516 parseCIRKeyword<cir::SideEffect>(parser, sideEffect).failed() ||
2517 parser.parseRParen().failed())
2518 return failure();
2519
2520 auto attr = cir::SideEffectAttr::get(parser.getContext(), sideEffect);
2521 state.addAttribute(CIRDialect::getSideEffectAttrName(), attr);
2522 }
2523
2524 // Parse optional annotations attribute (an ArrayAttr of AnnotationAttr).
2525 mlir::StringAttr annotationsNameAttr = getAnnotationsAttrName(state.name);
2526 mlir::ArrayAttr annotationsAttr;
2527 if (parser.parseOptionalAttribute(annotationsAttr).has_value() &&
2528 annotationsAttr)
2529 state.addAttribute(annotationsNameAttr, annotationsAttr);
2530
2531 // Parse the rest of the attributes.
2532 NamedAttrList parsedAttrs;
2533 if (parser.parseOptionalAttrDictWithKeyword(parsedAttrs))
2534 return failure();
2535
2536 for (StringRef disallowed : cir::FuncOp::getAttributeNames()) {
2537 if (parsedAttrs.get(disallowed))
2538 return parser.emitError(loc, "attribute '")
2539 << disallowed
2540 << "' should not be specified in the explicit attribute list";
2541 }
2542
2543 state.attributes.append(parsedAttrs);
2544
2545 // Parse the optional function body.
2546 auto *body = state.addRegion();
2547 OptionalParseResult parseResult = parser.parseOptionalRegion(
2548 *body, arguments, /*enableNameShadowing=*/false);
2549 if (parseResult.has_value()) {
2550 if (hasAlias)
2551 return parser.emitError(loc, "function alias shall not have a body");
2552 if (failed(*parseResult))
2553 return failure();
2554 // Function body was parsed, make sure its not empty.
2555 if (body->empty())
2556 return parser.emitError(loc, "expected non-empty function body");
2557 }
2558
2559 return success();
2560}
2561
2562// This function corresponds to `llvm::GlobalValue::isDeclaration` and should
2563// have a similar implementation. We don't currently ifuncs or materializable
2564// functions, but those should be handled here as they are implemented.
2565bool cir::FuncOp::isDeclaration() {
2567
2568 std::optional<StringRef> aliasee = getAliasee();
2569 if (!aliasee)
2570 return getFunctionBody().empty();
2571
2572 // Aliases are always definitions.
2573 return false;
2574}
2575
2576bool cir::FuncOp::isCXXSpecialMemberFunction() {
2577 return getCxxSpecialMemberAttr() != nullptr;
2578}
2579
2580bool cir::FuncOp::isCxxConstructor() {
2581 auto attr = getCxxSpecialMemberAttr();
2582 return attr && dyn_cast<CXXCtorAttr>(attr);
2583}
2584
2585bool cir::FuncOp::isCxxDestructor() {
2586 auto attr = getCxxSpecialMemberAttr();
2587 return attr && dyn_cast<CXXDtorAttr>(attr);
2588}
2589
2590bool cir::FuncOp::isCxxSpecialAssignment() {
2591 auto attr = getCxxSpecialMemberAttr();
2592 return attr && dyn_cast<CXXAssignAttr>(attr);
2593}
2594
2595std::optional<CtorKind> cir::FuncOp::getCxxConstructorKind() {
2596 mlir::Attribute attr = getCxxSpecialMemberAttr();
2597 if (attr) {
2598 if (auto ctor = dyn_cast<CXXCtorAttr>(attr))
2599 return ctor.getCtorKind();
2600 }
2601 return std::nullopt;
2602}
2603
2604std::optional<AssignKind> cir::FuncOp::getCxxSpecialAssignKind() {
2605 mlir::Attribute attr = getCxxSpecialMemberAttr();
2606 if (attr) {
2607 if (auto assign = dyn_cast<CXXAssignAttr>(attr))
2608 return assign.getAssignKind();
2609 }
2610 return std::nullopt;
2611}
2612
2613bool cir::FuncOp::isCxxTrivialMemberFunction() {
2614 mlir::Attribute attr = getCxxSpecialMemberAttr();
2615 if (attr) {
2616 if (auto ctor = dyn_cast<CXXCtorAttr>(attr))
2617 return ctor.getIsTrivial();
2618 if (auto dtor = dyn_cast<CXXDtorAttr>(attr))
2619 return dtor.getIsTrivial();
2620 if (auto assign = dyn_cast<CXXAssignAttr>(attr))
2621 return assign.getIsTrivial();
2622 }
2623 return false;
2624}
2625
2626mlir::Region *cir::FuncOp::getCallableRegion() {
2627 // TODO(CIR): This function will have special handling for aliases and a
2628 // check for an external function, once those features have been upstreamed.
2629 return &getBody();
2630}
2631
/// Print a cir.func in its custom assembly form.
///
/// Optional keywords and clauses are printed only when they differ from their
/// default, in the fixed order below. NOTE(review): the textual form must
/// round-trip through cir::FuncOp::parse, so keep the order and spellings in
/// sync with the parser when adding clauses.
void cir::FuncOp::print(OpAsmPrinter &p) {
  // Boolean flag keywords, printed only when set.
  if (getBuiltin())
    p << " builtin";

  if (getCoroutine())
    p << " coroutine";

  printInlineKindAttr(p, getInlineKindAttr());

  if (getLambda())
    p << " lambda";

  if (getNoProto())
    p << " no_proto";

  if (getComdat())
    p << " comdat";

  // External linkage is implied and therefore omitted.
  if (getLinkage() != GlobalLinkageKind::ExternalLinkage)
    p << ' ' << stringifyGlobalLinkageKind(getLinkage());

  // MLIR symbol visibility; public is implied and omitted.
  mlir::SymbolTable::Visibility vis = getVisibility();
  if (vis != mlir::SymbolTable::Visibility::Public)
    p << ' ' << vis;

  // CIR global visibility; the default kind is omitted.
  if (getGlobalVisibility() != cir::VisibilityKind::Default)
    p << ' ' << stringifyVisibilityKind(getGlobalVisibility());

  if (getDsoLocal())
    p << " dso_local";

  // Symbol name followed by the signature taken from the function type.
  p << ' ';
  p.printSymbolName(getSymName());
  cir::FuncType fnType = getFunctionType();
  function_interface_impl::printFunctionSignature(
      p, *this, fnType.getInputs(), fnType.isVarArg(), fnType.getReturnTypes());

  // Trailing clauses, each printed only when present or non-default.
  if (std::optional<StringRef> aliaseeName = getAliasee()) {
    p << " alias(";
    p.printSymbolName(*aliaseeName);
    p << ")";
  }

  if (getCallingConv() != cir::CallingConv::C) {
    p << " cc(";
    p << stringifyCallingConv(getCallingConv());
    p << ")";
  }

  if (std::optional<StringRef> personalityName = getPersonality()) {
    p << " personality(";
    p.printSymbolName(*personalityName);
    p << ")";
  }

  if (auto specialMemberAttr = getCxxSpecialMember()) {
    p << " special_member<";
    p.printAttribute(*specialMemberAttr);
    p << '>';
  }

  // A priority of 65535 is treated as implied and only the keyword is printed.
  if (auto globalCtorPriority = getGlobalCtorPriority()) {
    p << " global_ctor";
    if (globalCtorPriority.value() != 65535)
      p << "(" << globalCtorPriority.value() << ")";
  }

  if (auto globalDtorPriority = getGlobalDtorPriority()) {
    p << " global_dtor";
    if (globalDtorPriority.value() != 65535)
      p << "(" << globalDtorPriority.value() << ")";
  }

  // SideEffect::All is implied and omitted.
  if (std::optional<cir::SideEffect> sideEffect = getSideEffect();
      sideEffect && *sideEffect != cir::SideEffect::All) {
    p << " side_effect(";
    p << stringifySideEffect(*sideEffect);
    p << ")";
  }

  if (mlir::ArrayAttr annotations = getAnnotationsAttr()) {
    p << ' ';
    p.printAttribute(annotations);
  }

  // Print the remaining attributes; the op's declared attribute names (already
  // printed above in custom form) are passed as the set to elide.
  function_interface_impl::printFunctionAttributes(
      p, *this, cir::FuncOp::getAttributeNames());

  // Print the body if this is not an external function.
  Region &body = getOperation()->getRegion(0);
  if (!body.empty()) {
    p << ' ';
    p.printRegion(body, /*printEntryBlockArgs=*/false,
                  /*printBlockTerminators=*/true);
  }
}
2728
2729mlir::LogicalResult cir::FuncOp::verify() {
2730
2731 if (!isDeclaration() && getCoroutine()) {
2732 bool foundAwait = false;
2733 int coroBodyCount = 0;
2734 this->walk([&](Operation *op) {
2735 if (auto await = dyn_cast<AwaitOp>(op)) {
2736 foundAwait = true;
2737 } else if (isa<CoroBodyOp>(op)) {
2738 coroBodyCount++;
2739 if (coroBodyCount > 1) {
2740 return mlir::WalkResult::interrupt();
2741 }
2742 }
2743 return mlir::WalkResult::advance();
2744 });
2745 if (!foundAwait)
2746 return emitOpError()
2747 << "coroutine body must use at least one cir.await op";
2748 if (coroBodyCount != 1)
2749 return emitOpError()
2750 << "coroutine function must have exactly one cir.body op";
2751 }
2752
2753 llvm::SmallSet<llvm::StringRef, 16> labels;
2754 llvm::SmallSet<llvm::StringRef, 16> gotos;
2755 llvm::SmallSet<llvm::StringRef, 16> blockAddresses;
2756 bool invalidBlockAddress = false;
2757 getOperation()->walk([&](mlir::Operation *op) {
2758 if (auto lab = dyn_cast<cir::LabelOp>(op)) {
2759 labels.insert(lab.getLabel());
2760 } else if (auto goTo = dyn_cast<cir::GotoOp>(op)) {
2761 gotos.insert(goTo.getLabel());
2762 } else if (auto blkAdd = dyn_cast<cir::BlockAddressOp>(op)) {
2763 if (blkAdd.getBlockAddrInfoAttr().getFunc().getAttr() != getSymName()) {
2764 // Stop the walk early, no need to continue
2765 invalidBlockAddress = true;
2766 return mlir::WalkResult::interrupt();
2767 }
2768 blockAddresses.insert(blkAdd.getBlockAddrInfoAttr().getLabel());
2769 }
2770 return mlir::WalkResult::advance();
2771 });
2772
2773 if (invalidBlockAddress)
2774 return emitOpError() << "blockaddress references a different function";
2775
2776 llvm::SmallSet<llvm::StringRef, 16> mismatched;
2777 if (!labels.empty() || !gotos.empty()) {
2778 mismatched = llvm::set_difference(gotos, labels);
2779
2780 if (!mismatched.empty())
2781 return emitOpError() << "goto/label mismatch";
2782 }
2783
2784 mismatched.clear();
2785
2786 if (!labels.empty() || !blockAddresses.empty()) {
2787 mismatched = llvm::set_difference(blockAddresses, labels);
2788
2789 if (!mismatched.empty())
2790 return emitOpError()
2791 << "expects an existing label target in the referenced function";
2792 }
2793
2794 return success();
2795}
2796
2797//===----------------------------------------------------------------------===//
2798// AddOp / SubOp / MulOp
2799//===----------------------------------------------------------------------===//
2800
2801static LogicalResult verifyBinaryOverflowOp(mlir::Operation *op,
2802 bool noSignedWrap,
2803 bool noUnsignedWrap, bool saturated,
2804 bool hasSat) {
2805 bool noWrap = noSignedWrap || noUnsignedWrap;
2806 if (!isa<cir::IntType>(op->getResultTypes()[0]) && noWrap)
2807 return op->emitError()
2808 << "only operations on integer values may have nsw/nuw flags";
2809 if (hasSat && saturated && !isa<cir::IntType>(op->getResultTypes()[0]))
2810 return op->emitError()
2811 << "only operations on integer values may have sat flag";
2812 if (hasSat && noWrap && saturated)
2813 return op->emitError()
2814 << "the nsw/nuw flags and the saturated flag are mutually exclusive";
2815 return mlir::success();
2816}
2817
2818LogicalResult cir::AddOp::verify() {
2819 return verifyBinaryOverflowOp(getOperation(), getNoSignedWrap(),
2820 getNoUnsignedWrap(), getSaturated(),
2821 /*hasSat=*/true);
2822}
2823
2824LogicalResult cir::SubOp::verify() {
2825 return verifyBinaryOverflowOp(getOperation(), getNoSignedWrap(),
2826 getNoUnsignedWrap(), getSaturated(),
2827 /*hasSat=*/true);
2828}
2829
2830LogicalResult cir::MulOp::verify() {
2831 return verifyBinaryOverflowOp(getOperation(), getNoSignedWrap(),
2832 getNoUnsignedWrap(), /*saturated=*/false,
2833 /*hasSat=*/false);
2834}
2835
2836//===----------------------------------------------------------------------===//
2837// TernaryOp
2838//===----------------------------------------------------------------------===//
2839
2840/// Given the region at `point`, or the parent operation if `point` is None,
2841/// return the successor regions. These are the regions that may be selected
2842/// during the flow of control. `operands` is a set of optional attributes that
2843/// correspond to a constant value for each operand, or null if that operand is
2844/// not a constant.
2845void cir::TernaryOp::getSuccessorRegions(
2846 mlir::RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
2847 // The `true` and the `false` region branch back to the parent operation.
2848 if (!point.isParent()) {
2849 regions.push_back(RegionSuccessor::parent());
2850 return;
2851 }
2852
2853 // When branching from the parent operation, both the true and false
2854 // regions are considered possible successors
2855 regions.push_back(RegionSuccessor(&getTrueRegion()));
2856 regions.push_back(RegionSuccessor(&getFalseRegion()));
2857}
2858
2859mlir::ValueRange cir::TernaryOp::getSuccessorInputs(RegionSuccessor successor) {
2860 return successor.isParent() ? ValueRange(getOperation()->getResults())
2861 : ValueRange();
2862}
2863
2864void cir::TernaryOp::build(
2865 OpBuilder &builder, OperationState &result, Value cond,
2866 function_ref<void(OpBuilder &, Location)> trueBuilder,
2867 function_ref<void(OpBuilder &, Location)> falseBuilder) {
2868 result.addOperands(cond);
2869 OpBuilder::InsertionGuard guard(builder);
2870 Region *trueRegion = result.addRegion();
2871 builder.createBlock(trueRegion);
2872 trueBuilder(builder, result.location);
2873 Region *falseRegion = result.addRegion();
2874 builder.createBlock(falseRegion);
2875 falseBuilder(builder, result.location);
2876
2877 // Get result type from whichever branch has a yield (the other may have
2878 // unreachable from a throw expression)
2879 cir::YieldOp yield;
2880 if (trueRegion->back().mightHaveTerminator())
2881 yield = dyn_cast_or_null<cir::YieldOp>(trueRegion->back().getTerminator());
2882 if (!yield && falseRegion->back().mightHaveTerminator())
2883 yield = dyn_cast_or_null<cir::YieldOp>(falseRegion->back().getTerminator());
2884
2885 assert((!yield || yield.getNumOperands() <= 1) &&
2886 "expected zero or one result type");
2887 if (yield && yield.getNumOperands() == 1)
2888 result.addTypes(TypeRange{yield.getOperandTypes().front()});
2889}
2890
2891//===----------------------------------------------------------------------===//
2892// SelectOp
2893//===----------------------------------------------------------------------===//
2894
2895OpFoldResult cir::SelectOp::fold(FoldAdaptor adaptor) {
2896 mlir::Attribute condition = adaptor.getCondition();
2897 if (condition) {
2898 bool conditionValue = mlir::cast<cir::BoolAttr>(condition).getValue();
2899 return conditionValue ? getTrueValue() : getFalseValue();
2900 }
2901
2902 // cir.select if %0 then x else x -> x
2903 mlir::Attribute trueValue = adaptor.getTrueValue();
2904 mlir::Attribute falseValue = adaptor.getFalseValue();
2905 if (trueValue == falseValue)
2906 return trueValue;
2907 if (getTrueValue() == getFalseValue())
2908 return getTrueValue();
2909
2910 return {};
2911}
2912
2913LogicalResult cir::SelectOp::verify() {
2914 // AllTypesMatch already guarantees trueVal and falseVal have matching types.
2915 auto condTy = dyn_cast<cir::VectorType>(getCondition().getType());
2916
2917 // If condition is not a vector, no further checks are needed.
2918 if (!condTy)
2919 return success();
2920
2921 // When condition is a vector, both other operands must also be vectors.
2922 if (!isa<cir::VectorType>(getTrueValue().getType()) ||
2923 !isa<cir::VectorType>(getFalseValue().getType())) {
2924 return emitOpError()
2925 << "expected both true and false operands to be vector types "
2926 "when the condition is a vector boolean type";
2927 }
2928
2929 return success();
2930}
2931
2932//===----------------------------------------------------------------------===//
2933// ShiftOp
2934//===----------------------------------------------------------------------===//
2935LogicalResult cir::ShiftOp::verify() {
2936 mlir::Operation *op = getOperation();
2937 auto op0VecTy = mlir::dyn_cast<cir::VectorType>(op->getOperand(0).getType());
2938 auto op1VecTy = mlir::dyn_cast<cir::VectorType>(op->getOperand(1).getType());
2939 if (!op0VecTy ^ !op1VecTy)
2940 return emitOpError() << "input types cannot be one vector and one scalar";
2941
2942 if (op0VecTy) {
2943 if (op0VecTy.getSize() != op1VecTy.getSize())
2944 return emitOpError() << "input vector types must have the same size";
2945
2946 auto opResultTy = mlir::dyn_cast<cir::VectorType>(getType());
2947 if (!opResultTy)
2948 return emitOpError() << "the type of the result must be a vector "
2949 << "if it is vector shift";
2950
2951 auto op0VecEleTy = mlir::cast<cir::IntType>(op0VecTy.getElementType());
2952 auto op1VecEleTy = mlir::cast<cir::IntType>(op1VecTy.getElementType());
2953 if (op0VecEleTy.getWidth() != op1VecEleTy.getWidth())
2954 return emitOpError()
2955 << "vector operands do not have the same elements sizes";
2956
2957 auto resVecEleTy = mlir::cast<cir::IntType>(opResultTy.getElementType());
2958 if (op0VecEleTy.getWidth() != resVecEleTy.getWidth())
2959 return emitOpError() << "vector operands and result type do not have the "
2960 "same elements sizes";
2961 }
2962
2963 return mlir::success();
2964}
2965
2966//===----------------------------------------------------------------------===//
2967// LabelOp Definitions
2968//===----------------------------------------------------------------------===//
2969
2970LogicalResult cir::LabelOp::verify() {
2971 mlir::Operation *op = getOperation();
2972 mlir::Block *blk = op->getBlock();
2973 if (&blk->front() != op)
2974 return emitError() << "must be the first operation in a block";
2975
2976 return mlir::success();
2977}
2978
2979//===----------------------------------------------------------------------===//
2980// IncOp
2981//===----------------------------------------------------------------------===//
2982
2983OpFoldResult cir::IncOp::fold(FoldAdaptor adaptor) {
2984 if (mlir::isa_and_present<cir::PoisonAttr>(adaptor.getInput()))
2985 return adaptor.getInput();
2986 return {};
2987}
2988
2989//===----------------------------------------------------------------------===//
2990// DecOp
2991//===----------------------------------------------------------------------===//
2992
2993OpFoldResult cir::DecOp::fold(FoldAdaptor adaptor) {
2994 if (mlir::isa_and_present<cir::PoisonAttr>(adaptor.getInput()))
2995 return adaptor.getInput();
2996 return {};
2997}
2998
2999//===----------------------------------------------------------------------===//
3000// MinusOp
3001//===----------------------------------------------------------------------===//
3002
3003OpFoldResult cir::MinusOp::fold(FoldAdaptor adaptor) {
3004 if (mlir::isa_and_present<cir::PoisonAttr>(adaptor.getInput()))
3005 return adaptor.getInput();
3006
3007 // Avoid materializing a duplicate constant for bool minus (identity).
3008 if (auto srcConst = getInput().getDefiningOp<cir::ConstantOp>())
3009 if (mlir::isa<cir::BoolType>(srcConst.getType()))
3010 return srcConst.getResult();
3011
3012 // Fold with constant inputs.
3013 if (mlir::Attribute attr = adaptor.getInput()) {
3014 if (auto intAttr = mlir::dyn_cast<cir::IntAttr>(attr)) {
3015 APInt val = intAttr.getValue();
3016 val.negate();
3017 return cir::IntAttr::get(getType(), val);
3018 }
3019 if (auto fpAttr = mlir::dyn_cast<cir::FPAttr>(attr)) {
3020 APFloat val = fpAttr.getValue();
3021 val.changeSign();
3022 return cir::FPAttr::get(getType(), val);
3023 }
3024 }
3025
3026 return {};
3027}
3028
3029//===----------------------------------------------------------------------===//
3030// NotOp
3031//===----------------------------------------------------------------------===//
3032
3033OpFoldResult cir::NotOp::fold(FoldAdaptor adaptor) {
3034 if (mlir::isa_and_present<cir::PoisonAttr>(adaptor.getInput()))
3035 return adaptor.getInput();
3036
3037 // not(not(x)) -> x is handled by the Involution trait.
3038
3039 // Fold with constant inputs.
3040 if (mlir::Attribute attr = adaptor.getInput()) {
3041 if (auto intAttr = mlir::dyn_cast<cir::IntAttr>(attr)) {
3042 APInt val = intAttr.getValue();
3043 val.flipAllBits();
3044 return cir::IntAttr::get(getType(), val);
3045 }
3046 if (auto boolAttr = mlir::dyn_cast<cir::BoolAttr>(attr))
3047 return cir::BoolAttr::get(getContext(), !boolAttr.getValue());
3048 }
3049
3050 return {};
3051}
3052
3053//===----------------------------------------------------------------------===//
3054// BaseDataMemberOp & DerivedDataMemberOp
3055//===----------------------------------------------------------------------===//
3056
3057static LogicalResult verifyMemberPtrCast(Operation *op, mlir::Value src,
3058 mlir::Type resultTy) {
3059 // Let the operand type be T1 C1::*, let the result type be T2 C2::*.
3060 // Verify that T1 and T2 are the same type.
3061 mlir::Type inputMemberTy;
3062 mlir::Type resultMemberTy;
3063 if (mlir::isa<cir::DataMemberType>(src.getType())) {
3064 inputMemberTy =
3065 mlir::cast<cir::DataMemberType>(src.getType()).getMemberTy();
3066 resultMemberTy = mlir::cast<cir::DataMemberType>(resultTy).getMemberTy();
3067 }
3069 if (inputMemberTy != resultMemberTy)
3070 return op->emitOpError()
3071 << "member types of the operand and the result do not match";
3072
3073 return mlir::success();
3074}
3075
3076LogicalResult cir::BaseDataMemberOp::verify() {
3077 return verifyMemberPtrCast(getOperation(), getSrc(), getType());
3078}
3079
3080LogicalResult cir::DerivedDataMemberOp::verify() {
3081 return verifyMemberPtrCast(getOperation(), getSrc(), getType());
3082}
3083
3084//===----------------------------------------------------------------------===//
3085// BaseMethodOp & DerivedMethodOp
3086//===----------------------------------------------------------------------===//
3087
3088LogicalResult cir::BaseMethodOp::verify() {
3089 return verifyMemberPtrCast(getOperation(), getSrc(), getType());
3090}
3091
3092LogicalResult cir::DerivedMethodOp::verify() {
3093 return verifyMemberPtrCast(getOperation(), getSrc(), getType());
3094}
3095
3096//===----------------------------------------------------------------------===//
3097// AwaitOp
3098//===----------------------------------------------------------------------===//
3099
3100void cir::AwaitOp::build(OpBuilder &builder, OperationState &result,
3101 cir::AwaitKind kind, BuilderCallbackRef readyBuilder,
3102 BuilderCallbackRef suspendBuilder,
3103 BuilderCallbackRef resumeBuilder) {
3104 result.addAttribute(getKindAttrName(result.name),
3105 cir::AwaitKindAttr::get(builder.getContext(), kind));
3106 {
3107 OpBuilder::InsertionGuard guard(builder);
3108 Region *readyRegion = result.addRegion();
3109 builder.createBlock(readyRegion);
3110 readyBuilder(builder, result.location);
3111 }
3112
3113 {
3114 OpBuilder::InsertionGuard guard(builder);
3115 Region *suspendRegion = result.addRegion();
3116 builder.createBlock(suspendRegion);
3117 suspendBuilder(builder, result.location);
3118 }
3119
3120 {
3121 OpBuilder::InsertionGuard guard(builder);
3122 Region *resumeRegion = result.addRegion();
3123 builder.createBlock(resumeRegion);
3124 resumeBuilder(builder, result.location);
3125 }
3126}
3127
3128void cir::AwaitOp::getSuccessorRegions(
3129 mlir::RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
3130 // If any index all the underlying regions branch back to the parent
3131 // operation.
3132 if (!point.isParent()) {
3133 regions.push_back(RegionSuccessor::parent());
3134 return;
3135 }
3136
3137 // TODO: retrieve information from the promise and only push the
3138 // necessary ones. Example: `std::suspend_never` on initial or final
3139 // await's might allow suspend region to be skipped.
3140 regions.push_back(RegionSuccessor(&this->getReady()));
3141 regions.push_back(RegionSuccessor(&this->getSuspend()));
3142 regions.push_back(RegionSuccessor(&this->getResume()));
3143}
3144
3145mlir::ValueRange cir::AwaitOp::getSuccessorInputs(RegionSuccessor successor) {
3146 if (successor.isParent())
3147 return getOperation()->getResults();
3148 if (successor == &getReady())
3149 return getReady().getArguments();
3150 if (successor == &getSuspend())
3151 return getSuspend().getArguments();
3152 if (successor == &getResume())
3153 return getResume().getArguments();
3154 llvm_unreachable("invalid region successor");
3155}
3156
3157LogicalResult cir::AwaitOp::verify() {
3158 if (!isa<ConditionOp>(this->getReady().back().getTerminator()))
3159 return emitOpError("ready region must end with cir.condition");
3160 return success();
3161}
3162
3163//===----------------------------------------------------------------------===//
3164// CoroBody
3165//===----------------------------------------------------------------------===//
3166
3167void cir::CoroBodyOp::getSuccessorRegions(
3168 mlir::RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
3169 if (!point.isParent()) {
3170 regions.push_back(RegionSuccessor::parent());
3171 return;
3172 }
3173
3174 regions.push_back(RegionSuccessor(&getBody()));
3175}
3176
3177mlir::ValueRange
3178cir::CoroBodyOp::getSuccessorInputs(RegionSuccessor successor) {
3179 return ValueRange();
3180}
3181
3182LogicalResult cir::CoroBodyOp::verify() {
3183 if (!getOperation()->getParentOfType<FuncOp>().getCoroutine())
3184 return emitOpError("enclosing function must be a coroutine");
3185 return success();
3186}
3187
3188void cir::CoroBodyOp::build(OpBuilder &builder, OperationState &result,
3189 BuilderCallbackRef bodyBuilder) {
3190 assert(bodyBuilder &&
3191 "the builder callback for 'CoroBodyOp' must be present");
3192 OpBuilder::InsertionGuard guard(builder);
3193
3194 Region *bodyRegion = result.addRegion();
3195 builder.createBlock(bodyRegion);
3196 bodyBuilder(builder, result.location);
3197}
3198
3199//===----------------------------------------------------------------------===//
3200// CopyOp Definitions
3201//===----------------------------------------------------------------------===//
3202
3203LogicalResult cir::CopyOp::verify() {
3204 // A data layout is required for us to know the number of bytes to be copied.
3205 if (!getType().getPointee().hasTrait<DataLayoutTypeInterface::Trait>())
3206 return emitError() << "missing data layout for pointee type";
3207
3208 if (getSkipTailPadding() &&
3209 !mlir::isa<cir::RecordType>(getType().getPointee()))
3210 return emitError()
3211 << "skip_tail_padding is only valid for record pointee types";
3212
3213 return mlir::success();
3214}
3215
3216//===----------------------------------------------------------------------===//
3217// GetRuntimeMemberOp Definitions
3218//===----------------------------------------------------------------------===//
3219
3220LogicalResult cir::GetRuntimeMemberOp::verify() {
3221 auto recordTy = mlir::cast<RecordType>(getAddr().getType().getPointee());
3222 cir::DataMemberType memberPtrTy = getMember().getType();
3223
3224 if (recordTy != memberPtrTy.getClassTy())
3225 return emitError() << "record type does not match the member pointer type";
3226 if (getType().getPointee() != memberPtrTy.getMemberTy())
3227 return emitError() << "result type does not match the member pointer type";
3228 return mlir::success();
3229}
3230
3231//===----------------------------------------------------------------------===//
3232// GetMethodOp Definitions
3233//===----------------------------------------------------------------------===//
3234
3235LogicalResult cir::GetMethodOp::verify() {
3236 cir::MethodType methodTy = getMethod().getType();
3237
3238 // Assume objectTy is !cir.ptr<!T>
3239 cir::PointerType objectPtrTy = getObject().getType();
3240 mlir::Type objectTy = objectPtrTy.getPointee();
3241
3242 if (methodTy.getClassTy() != objectTy)
3243 return emitError() << "method class type and object type do not match";
3244
3245 // Assume methodFuncTy is !cir.func<!Ret (!Args)>
3246 auto calleeTy = mlir::cast<cir::FuncType>(getCallee().getType().getPointee());
3247 cir::FuncType methodFuncTy = methodTy.getMemberFuncTy();
3248
3249 // We verify at here that calleeTy is !cir.func<!Ret (!cir.ptr<!void>, !Args)>
3250 // Note that the first parameter type of the callee is !cir.ptr<!void> instead
3251 // of !cir.ptr<!T> because the "this" pointer may be adjusted before calling
3252 // the callee.
3253
3254 if (methodFuncTy.getReturnType() != calleeTy.getReturnType())
3255 return emitError()
3256 << "method return type and callee return type do not match";
3257
3258 llvm::ArrayRef<mlir::Type> calleeArgsTy = calleeTy.getInputs();
3259 llvm::ArrayRef<mlir::Type> methodFuncArgsTy = methodFuncTy.getInputs();
3260
3261 if (calleeArgsTy.empty())
3262 return emitError() << "callee parameter list lacks receiver object ptr";
3263
3264 auto calleeThisArgPtrTy = mlir::dyn_cast<cir::PointerType>(calleeArgsTy[0]);
3265 if (!calleeThisArgPtrTy ||
3266 !mlir::isa<cir::VoidType>(calleeThisArgPtrTy.getPointee())) {
3267 return emitError()
3268 << "the first parameter of callee must be a void pointer";
3269 }
3270
3271 if (calleeArgsTy.size() != methodFuncArgsTy.size())
3272 return emitError() << "callee and method parameter counts do not match";
3273
3274 if (calleeArgsTy.size() > 1 &&
3275 calleeArgsTy.slice(1) != methodFuncArgsTy.slice(1))
3276 return emitError()
3277 << "callee parameters and method parameters do not match";
3278
3279 return mlir::success();
3280}
3281
3282//===----------------------------------------------------------------------===//
3283// GetMemberOp Definitions
3284//===----------------------------------------------------------------------===//
3285
3286LogicalResult cir::GetMemberOp::verify() {
3287 const auto recordTy = dyn_cast<RecordType>(getAddrTy().getPointee());
3288 if (!recordTy)
3289 return emitError() << "expected pointer to a record type";
3290
3291 if (recordTy.getMembers().size() <= getIndex())
3292 return emitError() << "member index out of bounds";
3293
3294 if (recordTy.getMembers()[getIndex()] != getType().getPointee())
3295 return emitError() << "member type mismatch";
3296
3297 return mlir::success();
3298}
3299
3300//===----------------------------------------------------------------------===//
3301// ExtractMemberOp Definitions
3302//===----------------------------------------------------------------------===//
3303
3304LogicalResult cir::ExtractMemberOp::verify() {
3305 auto recordTy = mlir::cast<cir::RecordType>(getRecord().getType());
3306 if (recordTy.getKind() == cir::RecordType::Union)
3307 return emitError()
3308 << "cir.extract_member currently does not support unions";
3309 if (recordTy.getMembers().size() <= getIndex())
3310 return emitError() << "member index out of bounds";
3311 if (recordTy.getMembers()[getIndex()] != getType())
3312 return emitError() << "member type mismatch";
3313 return mlir::success();
3314}
3315
3316//===----------------------------------------------------------------------===//
3317// InsertMemberOp Definitions
3318//===----------------------------------------------------------------------===//
3319
3320LogicalResult cir::InsertMemberOp::verify() {
3321 auto recordTy = mlir::cast<cir::RecordType>(getRecord().getType());
3322 if (recordTy.getKind() == cir::RecordType::Union)
3323 return emitError() << "cir.insert_member currently does not support unions";
3324 if (recordTy.getMembers().size() <= getIndex())
3325 return emitError() << "member index out of bounds";
3326 if (recordTy.getMembers()[getIndex()] != getValue().getType())
3327 return emitError() << "member type mismatch";
3328 // The op trait already checks that the types of $result and $record match.
3329 return mlir::success();
3330}
3331
3332//===----------------------------------------------------------------------===//
3333// VecCreateOp
3334//===----------------------------------------------------------------------===//
3335
3336OpFoldResult cir::VecCreateOp::fold(FoldAdaptor adaptor) {
3337 if (llvm::any_of(getElements(), [](mlir::Value value) {
3338 return !value.getDefiningOp<cir::ConstantOp>();
3339 }))
3340 return {};
3341
3342 return cir::ConstVectorAttr::get(
3343 getType(), mlir::ArrayAttr::get(getContext(), adaptor.getElements()));
3344}
3345
3346LogicalResult cir::VecCreateOp::verify() {
3347 // Verify that the number of arguments matches the number of elements in the
3348 // vector, and that the type of all the arguments matches the type of the
3349 // elements in the vector.
3350 const cir::VectorType vecTy = getType();
3351 if (getElements().size() != vecTy.getSize()) {
3352 return emitOpError() << "operand count of " << getElements().size()
3353 << " doesn't match vector type " << vecTy
3354 << " element count of " << vecTy.getSize();
3355 }
3356
3357 const mlir::Type elementType = vecTy.getElementType();
3358 for (const mlir::Value element : getElements()) {
3359 if (element.getType() != elementType) {
3360 return emitOpError() << "operand type " << element.getType()
3361 << " doesn't match vector element type "
3362 << elementType;
3363 }
3364 }
3365
3366 return success();
3367}
3368
3369//===----------------------------------------------------------------------===//
3370// VecExtractOp
3371//===----------------------------------------------------------------------===//
3372
3373OpFoldResult cir::VecExtractOp::fold(FoldAdaptor adaptor) {
3374 const auto vectorAttr =
3375 llvm::dyn_cast_if_present<cir::ConstVectorAttr>(adaptor.getVec());
3376 if (!vectorAttr)
3377 return {};
3378
3379 const auto indexAttr =
3380 llvm::dyn_cast_if_present<cir::IntAttr>(adaptor.getIndex());
3381 if (!indexAttr)
3382 return {};
3383
3384 const mlir::ArrayAttr elements = vectorAttr.getElts();
3385 const uint64_t index = indexAttr.getUInt();
3386 if (index >= elements.size())
3387 return {};
3388
3389 return elements[index];
3390}
3391
3392//===----------------------------------------------------------------------===//
3393// VecCmpOp
3394//===----------------------------------------------------------------------===//
3395
3396OpFoldResult cir::VecCmpOp::fold(FoldAdaptor adaptor) {
3397 auto lhsVecAttr =
3398 mlir::dyn_cast_if_present<cir::ConstVectorAttr>(adaptor.getLhs());
3399 auto rhsVecAttr =
3400 mlir::dyn_cast_if_present<cir::ConstVectorAttr>(adaptor.getRhs());
3401 if (!lhsVecAttr || !rhsVecAttr)
3402 return {};
3403
3404 mlir::Type inputElemTy =
3405 mlir::cast<cir::VectorType>(lhsVecAttr.getType()).getElementType();
3406 if (!isAnyIntegerOrFloatingPointType(inputElemTy))
3407 return {};
3408
3409 cir::CmpOpKind opKind = adaptor.getKind();
3410 mlir::ArrayAttr lhsVecElhs = lhsVecAttr.getElts();
3411 mlir::ArrayAttr rhsVecElhs = rhsVecAttr.getElts();
3412 uint64_t vecSize = lhsVecElhs.size();
3413
3414 SmallVector<mlir::Attribute, 16> elements(vecSize);
3415 bool isIntAttr = vecSize && mlir::isa<cir::IntAttr>(lhsVecElhs[0]);
3416 for (uint64_t i = 0; i < vecSize; i++) {
3417 mlir::Attribute lhsAttr = lhsVecElhs[i];
3418 mlir::Attribute rhsAttr = rhsVecElhs[i];
3419 int cmpResult = 0;
3420 switch (opKind) {
3421 case cir::CmpOpKind::lt: {
3422 if (isIntAttr) {
3423 cmpResult = mlir::cast<cir::IntAttr>(lhsAttr).getSInt() <
3424 mlir::cast<cir::IntAttr>(rhsAttr).getSInt();
3425 } else {
3426 cmpResult = mlir::cast<cir::FPAttr>(lhsAttr).getValue() <
3427 mlir::cast<cir::FPAttr>(rhsAttr).getValue();
3428 }
3429 break;
3430 }
3431 case cir::CmpOpKind::le: {
3432 if (isIntAttr) {
3433 cmpResult = mlir::cast<cir::IntAttr>(lhsAttr).getSInt() <=
3434 mlir::cast<cir::IntAttr>(rhsAttr).getSInt();
3435 } else {
3436 cmpResult = mlir::cast<cir::FPAttr>(lhsAttr).getValue() <=
3437 mlir::cast<cir::FPAttr>(rhsAttr).getValue();
3438 }
3439 break;
3440 }
3441 case cir::CmpOpKind::gt: {
3442 if (isIntAttr) {
3443 cmpResult = mlir::cast<cir::IntAttr>(lhsAttr).getSInt() >
3444 mlir::cast<cir::IntAttr>(rhsAttr).getSInt();
3445 } else {
3446 cmpResult = mlir::cast<cir::FPAttr>(lhsAttr).getValue() >
3447 mlir::cast<cir::FPAttr>(rhsAttr).getValue();
3448 }
3449 break;
3450 }
3451 case cir::CmpOpKind::ge: {
3452 if (isIntAttr) {
3453 cmpResult = mlir::cast<cir::IntAttr>(lhsAttr).getSInt() >=
3454 mlir::cast<cir::IntAttr>(rhsAttr).getSInt();
3455 } else {
3456 cmpResult = mlir::cast<cir::FPAttr>(lhsAttr).getValue() >=
3457 mlir::cast<cir::FPAttr>(rhsAttr).getValue();
3458 }
3459 break;
3460 }
3461 case cir::CmpOpKind::eq: {
3462 if (isIntAttr) {
3463 cmpResult = mlir::cast<cir::IntAttr>(lhsAttr).getSInt() ==
3464 mlir::cast<cir::IntAttr>(rhsAttr).getSInt();
3465 } else {
3466 cmpResult = mlir::cast<cir::FPAttr>(lhsAttr).getValue() ==
3467 mlir::cast<cir::FPAttr>(rhsAttr).getValue();
3468 }
3469 break;
3470 }
3471 case cir::CmpOpKind::ne: {
3472 if (isIntAttr) {
3473 cmpResult = mlir::cast<cir::IntAttr>(lhsAttr).getSInt() !=
3474 mlir::cast<cir::IntAttr>(rhsAttr).getSInt();
3475 } else {
3476 cmpResult = mlir::cast<cir::FPAttr>(lhsAttr).getValue() !=
3477 mlir::cast<cir::FPAttr>(rhsAttr).getValue();
3478 }
3479 break;
3480 }
3481 case cir::CmpOpKind::one: {
3482 llvm::APFloat::cmpResult cr =
3483 mlir::cast<cir::FPAttr>(lhsAttr).getValue().compare(
3484 mlir::cast<cir::FPAttr>(rhsAttr).getValue());
3485 cmpResult =
3486 cr != llvm::APFloat::cmpUnordered && cr != llvm::APFloat::cmpEqual;
3487 break;
3488 }
3489 case cir::CmpOpKind::uno: {
3490 cmpResult = mlir::cast<cir::FPAttr>(lhsAttr).getValue().compare(
3491 mlir::cast<cir::FPAttr>(rhsAttr).getValue()) ==
3492 llvm::APFloat::cmpUnordered;
3493 break;
3494 }
3495 }
3496
3497 elements[i] = cir::IntAttr::get(getType().getElementType(), cmpResult);
3498 }
3499
3500 return cir::ConstVectorAttr::get(
3501 getType(), mlir::ArrayAttr::get(getContext(), elements));
3502}
3503
3504//===----------------------------------------------------------------------===//
3505// VecShuffleOp
3506//===----------------------------------------------------------------------===//
3507
3508OpFoldResult cir::VecShuffleOp::fold(FoldAdaptor adaptor) {
3509 auto vec1Attr =
3510 mlir::dyn_cast_if_present<cir::ConstVectorAttr>(adaptor.getVec1());
3511 auto vec2Attr =
3512 mlir::dyn_cast_if_present<cir::ConstVectorAttr>(adaptor.getVec2());
3513 if (!vec1Attr || !vec2Attr)
3514 return {};
3515
3516 mlir::Type vec1ElemTy =
3517 mlir::cast<cir::VectorType>(vec1Attr.getType()).getElementType();
3518
3519 mlir::ArrayAttr vec1Elts = vec1Attr.getElts();
3520 mlir::ArrayAttr vec2Elts = vec2Attr.getElts();
3521 mlir::ArrayAttr indicesElts = adaptor.getIndices();
3522
3524 elements.reserve(indicesElts.size());
3525
3526 uint64_t vec1Size = vec1Elts.size();
3527 for (const auto &idxAttr : indicesElts.getAsRange<cir::IntAttr>()) {
3528 if (idxAttr.getSInt() == -1) {
3529 elements.push_back(cir::UndefAttr::get(vec1ElemTy));
3530 continue;
3531 }
3532
3533 uint64_t idxValue = idxAttr.getUInt();
3534 elements.push_back(idxValue < vec1Size ? vec1Elts[idxValue]
3535 : vec2Elts[idxValue - vec1Size]);
3536 }
3537
3538 return cir::ConstVectorAttr::get(
3539 getType(), mlir::ArrayAttr::get(getContext(), elements));
3540}
3541
3542LogicalResult cir::VecShuffleOp::verify() {
3543 // The number of elements in the indices array must match the number of
3544 // elements in the result type.
3545 if (getIndices().size() != getResult().getType().getSize()) {
3546 return emitOpError() << ": the number of elements in " << getIndices()
3547 << " and " << getResult().getType() << " don't match";
3548 }
3549
3550 // The element types of the two input vectors and of the result type must
3551 // match.
3552 if (getVec1().getType().getElementType() !=
3553 getResult().getType().getElementType()) {
3554 return emitOpError() << ": element types of " << getVec1().getType()
3555 << " and " << getResult().getType() << " don't match";
3556 }
3557
3558 const uint64_t maxValidIndex =
3559 getVec1().getType().getSize() + getVec2().getType().getSize() - 1;
3560 if (llvm::any_of(
3561 getIndices().getAsRange<cir::IntAttr>(), [&](cir::IntAttr idxAttr) {
3562 return idxAttr.getSInt() != -1 && idxAttr.getUInt() > maxValidIndex;
3563 })) {
3564 return emitOpError() << ": index for __builtin_shufflevector must be "
3565 "less than the total number of vector elements";
3566 }
3567 return success();
3568}
3569
3570//===----------------------------------------------------------------------===//
3571// VecShuffleDynamicOp
3572//===----------------------------------------------------------------------===//
3573
3574OpFoldResult cir::VecShuffleDynamicOp::fold(FoldAdaptor adaptor) {
3575 mlir::Attribute vec = adaptor.getVec();
3576 mlir::Attribute indices = adaptor.getIndices();
3577 if (mlir::isa_and_nonnull<cir::ConstVectorAttr>(vec) &&
3578 mlir::isa_and_nonnull<cir::ConstVectorAttr>(indices)) {
3579 auto vecAttr = mlir::cast<cir::ConstVectorAttr>(vec);
3580 auto indicesAttr = mlir::cast<cir::ConstVectorAttr>(indices);
3581
3582 mlir::ArrayAttr vecElts = vecAttr.getElts();
3583 mlir::ArrayAttr indicesElts = indicesAttr.getElts();
3584
3585 const uint64_t numElements = vecElts.size();
3586
3588 elements.reserve(numElements);
3589
3590 const uint64_t maskBits = llvm::NextPowerOf2(numElements - 1) - 1;
3591 for (const auto &idxAttr : indicesElts.getAsRange<cir::IntAttr>()) {
3592 uint64_t idxValue = idxAttr.getUInt();
3593 uint64_t newIdx = idxValue & maskBits;
3594 elements.push_back(vecElts[newIdx]);
3595 }
3596
3597 return cir::ConstVectorAttr::get(
3598 getType(), mlir::ArrayAttr::get(getContext(), elements));
3599 }
3600
3601 return {};
3602}
3603
3604LogicalResult cir::VecShuffleDynamicOp::verify() {
3605 // The number of elements in the two input vectors must match.
3606 if (getVec().getType().getSize() !=
3607 mlir::cast<cir::VectorType>(getIndices().getType()).getSize()) {
3608 return emitOpError() << ": the number of elements in " << getVec().getType()
3609 << " and " << getIndices().getType() << " don't match";
3610 }
3611 return success();
3612}
3613
3614//===----------------------------------------------------------------------===//
3615// VecTernaryOp
3616//===----------------------------------------------------------------------===//
3617
3618LogicalResult cir::VecTernaryOp::verify() {
3619 // Verify that the condition operand has the same number of elements as the
3620 // other operands. (The automatic verification already checked that all
3621 // operands are vector types and that the second and third operands are the
3622 // same type.)
3623 if (getCond().getType().getSize() != getLhs().getType().getSize()) {
3624 return emitOpError() << ": the number of elements in "
3625 << getCond().getType() << " and " << getLhs().getType()
3626 << " don't match";
3627 }
3628 return success();
3629}
3630
3631OpFoldResult cir::VecTernaryOp::fold(FoldAdaptor adaptor) {
3632 mlir::Attribute cond = adaptor.getCond();
3633 mlir::Attribute lhs = adaptor.getLhs();
3634 mlir::Attribute rhs = adaptor.getRhs();
3635
3636 if (!mlir::isa_and_nonnull<cir::ConstVectorAttr>(cond) ||
3637 !mlir::isa_and_nonnull<cir::ConstVectorAttr>(lhs) ||
3638 !mlir::isa_and_nonnull<cir::ConstVectorAttr>(rhs))
3639 return {};
3640 auto condVec = mlir::cast<cir::ConstVectorAttr>(cond);
3641 auto lhsVec = mlir::cast<cir::ConstVectorAttr>(lhs);
3642 auto rhsVec = mlir::cast<cir::ConstVectorAttr>(rhs);
3643
3644 mlir::ArrayAttr condElts = condVec.getElts();
3645
3647 elements.reserve(condElts.size());
3648
3649 for (const auto &[idx, condAttr] :
3650 llvm::enumerate(condElts.getAsRange<cir::IntAttr>())) {
3651 if (condAttr.getSInt()) {
3652 elements.push_back(lhsVec.getElts()[idx]);
3653 } else {
3654 elements.push_back(rhsVec.getElts()[idx]);
3655 }
3656 }
3657
3658 cir::VectorType vecTy = getLhs().getType();
3659 return cir::ConstVectorAttr::get(
3660 vecTy, mlir::ArrayAttr::get(getContext(), elements));
3661}
3662
3663//===----------------------------------------------------------------------===//
3664// ComplexCreateOp
3665//===----------------------------------------------------------------------===//
3666
3667LogicalResult cir::ComplexCreateOp::verify() {
3668 if (getType().getElementType() != getReal().getType()) {
3669 emitOpError()
3670 << "operand type of cir.complex.create does not match its result type";
3671 return failure();
3672 }
3673
3674 return success();
3675}
3676
3677OpFoldResult cir::ComplexCreateOp::fold(FoldAdaptor adaptor) {
3678 mlir::Attribute real = adaptor.getReal();
3679 mlir::Attribute imag = adaptor.getImag();
3680 if (!real || !imag)
3681 return {};
3682
3683 // When both of real and imag are constants, we can fold the operation into an
3684 // `#cir.const_complex` operation.
3685 auto realAttr = mlir::cast<mlir::TypedAttr>(real);
3686 auto imagAttr = mlir::cast<mlir::TypedAttr>(imag);
3687 return cir::ConstComplexAttr::get(realAttr, imagAttr);
3688}
3689
3690//===----------------------------------------------------------------------===//
3691// ComplexRealOp
3692//===----------------------------------------------------------------------===//
3693
3694LogicalResult cir::ComplexRealOp::verify() {
3695 mlir::Type operandTy = getOperand().getType();
3696 if (auto complexOperandTy = mlir::dyn_cast<cir::ComplexType>(operandTy))
3697 operandTy = complexOperandTy.getElementType();
3698
3699 if (getType() != operandTy) {
3700 emitOpError() << ": result type does not match operand type";
3701 return failure();
3702 }
3703
3704 return success();
3705}
3706
3707OpFoldResult cir::ComplexRealOp::fold(FoldAdaptor adaptor) {
3708 if (!mlir::isa<cir::ComplexType>(getOperand().getType()))
3709 return nullptr;
3710
3711 if (auto complexCreateOp = getOperand().getDefiningOp<cir::ComplexCreateOp>())
3712 return complexCreateOp.getOperand(0);
3713
3714 auto complex =
3715 mlir::cast_if_present<cir::ConstComplexAttr>(adaptor.getOperand());
3716 return complex ? complex.getReal() : nullptr;
3717}
3718
3719//===----------------------------------------------------------------------===//
3720// ComplexImagOp
3721//===----------------------------------------------------------------------===//
3722
3723LogicalResult cir::ComplexImagOp::verify() {
3724 mlir::Type operandTy = getOperand().getType();
3725 if (auto complexOperandTy = mlir::dyn_cast<cir::ComplexType>(operandTy))
3726 operandTy = complexOperandTy.getElementType();
3727
3728 if (getType() != operandTy) {
3729 emitOpError() << ": result type does not match operand type";
3730 return failure();
3731 }
3732
3733 return success();
3734}
3735
3736OpFoldResult cir::ComplexImagOp::fold(FoldAdaptor adaptor) {
3737 if (!mlir::isa<cir::ComplexType>(getOperand().getType()))
3738 return nullptr;
3739
3740 if (auto complexCreateOp = getOperand().getDefiningOp<cir::ComplexCreateOp>())
3741 return complexCreateOp.getOperand(1);
3742
3743 auto complex =
3744 mlir::cast_if_present<cir::ConstComplexAttr>(adaptor.getOperand());
3745 return complex ? complex.getImag() : nullptr;
3746}
3747
3748//===----------------------------------------------------------------------===//
3749// ComplexRealPtrOp
3750//===----------------------------------------------------------------------===//
3751
3752LogicalResult cir::ComplexRealPtrOp::verify() {
3753 mlir::Type resultPointeeTy = getType().getPointee();
3754 cir::PointerType operandPtrTy = getOperand().getType();
3755 auto operandPointeeTy =
3756 mlir::cast<cir::ComplexType>(operandPtrTy.getPointee());
3757
3758 if (resultPointeeTy != operandPointeeTy.getElementType()) {
3759 return emitOpError() << ": result type does not match operand type";
3760 }
3761
3762 return success();
3763}
3764
3765//===----------------------------------------------------------------------===//
3766// ComplexImagPtrOp
3767//===----------------------------------------------------------------------===//
3768
3769LogicalResult cir::ComplexImagPtrOp::verify() {
3770 mlir::Type resultPointeeTy = getType().getPointee();
3771 cir::PointerType operandPtrTy = getOperand().getType();
3772 auto operandPointeeTy =
3773 mlir::cast<cir::ComplexType>(operandPtrTy.getPointee());
3774
3775 if (resultPointeeTy != operandPointeeTy.getElementType()) {
3776 return emitOpError()
3777 << "cir.complex.imag_ptr result type does not match operand type";
3778 }
3779 return success();
3780}
3781
3782//===----------------------------------------------------------------------===//
3783// Bit manipulation operations
3784//===----------------------------------------------------------------------===//
3785
3786static OpFoldResult
3787foldUnaryBitOp(mlir::Attribute inputAttr,
3788 llvm::function_ref<llvm::APInt(const llvm::APInt &)> func,
3789 bool poisonZero = false) {
3790 if (mlir::isa_and_present<cir::PoisonAttr>(inputAttr)) {
3791 // Propagate poison value
3792 return inputAttr;
3793 }
3794
3795 auto input = mlir::dyn_cast_if_present<IntAttr>(inputAttr);
3796 if (!input)
3797 return nullptr;
3798
3799 llvm::APInt inputValue = input.getValue();
3800 if (poisonZero && inputValue.isZero())
3801 return cir::PoisonAttr::get(input.getType());
3802
3803 llvm::APInt resultValue = func(inputValue);
3804 return IntAttr::get(input.getType(), resultValue);
3805}
3806
3807OpFoldResult BitClrsbOp::fold(FoldAdaptor adaptor) {
3808 return foldUnaryBitOp(adaptor.getInput(), [](const llvm::APInt &inputValue) {
3809 unsigned resultValue =
3810 inputValue.getBitWidth() - inputValue.getSignificantBits();
3811 return llvm::APInt(inputValue.getBitWidth(), resultValue);
3812 });
3813}
3814
3815OpFoldResult BitClzOp::fold(FoldAdaptor adaptor) {
3816 return foldUnaryBitOp(
3817 adaptor.getInput(),
3818 [](const llvm::APInt &inputValue) {
3819 unsigned resultValue = inputValue.countLeadingZeros();
3820 return llvm::APInt(inputValue.getBitWidth(), resultValue);
3821 },
3822 getPoisonZero());
3823}
3824
3825OpFoldResult BitCtzOp::fold(FoldAdaptor adaptor) {
3826 return foldUnaryBitOp(
3827 adaptor.getInput(),
3828 [](const llvm::APInt &inputValue) {
3829 return llvm::APInt(inputValue.getBitWidth(),
3830 inputValue.countTrailingZeros());
3831 },
3832 getPoisonZero());
3833}
3834
3835OpFoldResult BitFfsOp::fold(FoldAdaptor adaptor) {
3836 return foldUnaryBitOp(adaptor.getInput(), [](const llvm::APInt &inputValue) {
3837 unsigned trailingZeros = inputValue.countTrailingZeros();
3838 unsigned result =
3839 trailingZeros == inputValue.getBitWidth() ? 0 : trailingZeros + 1;
3840 return llvm::APInt(inputValue.getBitWidth(), result);
3841 });
3842}
3843
3844OpFoldResult BitParityOp::fold(FoldAdaptor adaptor) {
3845 return foldUnaryBitOp(adaptor.getInput(), [](const llvm::APInt &inputValue) {
3846 return llvm::APInt(inputValue.getBitWidth(), inputValue.popcount() % 2);
3847 });
3848}
3849
3850OpFoldResult BitPopcountOp::fold(FoldAdaptor adaptor) {
3851 return foldUnaryBitOp(adaptor.getInput(), [](const llvm::APInt &inputValue) {
3852 return llvm::APInt(inputValue.getBitWidth(), inputValue.popcount());
3853 });
3854}
3855
3856OpFoldResult BitReverseOp::fold(FoldAdaptor adaptor) {
3857 return foldUnaryBitOp(adaptor.getInput(), [](const llvm::APInt &inputValue) {
3858 return inputValue.reverseBits();
3859 });
3860}
3861
3862OpFoldResult ByteSwapOp::fold(FoldAdaptor adaptor) {
3863 return foldUnaryBitOp(adaptor.getInput(), [](const llvm::APInt &inputValue) {
3864 return inputValue.byteSwap();
3865 });
3866}
3867
3868OpFoldResult RotateOp::fold(FoldAdaptor adaptor) {
3869 if (mlir::isa_and_present<cir::PoisonAttr>(adaptor.getInput()) ||
3870 mlir::isa_and_present<cir::PoisonAttr>(adaptor.getAmount())) {
3871 // Propagate poison values
3872 return cir::PoisonAttr::get(getType());
3873 }
3874
3875 auto input = mlir::dyn_cast_if_present<IntAttr>(adaptor.getInput());
3876 auto amount = mlir::dyn_cast_if_present<IntAttr>(adaptor.getAmount());
3877 if (!input && !amount)
3878 return nullptr;
3879
3880 // We could fold cir.rotate even if one of its two operands is not a constant:
3881 // - `cir.rotate left/right %0, 0` could be folded into just %0 even if %0
3882 // is not a constant.
3883 // - `cir.rotate left/right 0/0b111...111, %0` could be folded into 0 or
3884 // 0b111...111 even if %0 is not a constant.
3885
3886 llvm::APInt inputValue;
3887 if (input) {
3888 inputValue = input.getValue();
3889 if (inputValue.isZero() || inputValue.isAllOnes()) {
3890 // An input value of all 0s or all 1s will not change after rotation
3891 return input;
3892 }
3893 }
3894
3895 uint64_t amountValue;
3896 if (amount) {
3897 amountValue = amount.getValue().urem(getInput().getType().getWidth());
3898 if (amountValue == 0) {
3899 // A shift amount of 0 will not change the input value
3900 return getInput();
3901 }
3902 }
3903
3904 if (!input || !amount)
3905 return nullptr;
3906
3907 assert(inputValue.getBitWidth() == getInput().getType().getWidth() &&
3908 "input value must have the same bit width as the input type");
3909
3910 llvm::APInt resultValue;
3911 if (isRotateLeft())
3912 resultValue = inputValue.rotl(amountValue);
3913 else
3914 resultValue = inputValue.rotr(amountValue);
3915
3916 return IntAttr::get(input.getContext(), input.getType(), resultValue);
3917}
3918
3919//===----------------------------------------------------------------------===//
3920// InlineAsmOp
3921//===----------------------------------------------------------------------===//
3922
3923void cir::InlineAsmOp::print(OpAsmPrinter &p) {
3924 p << '(' << getAsmFlavor() << ", ";
3925 p.increaseIndent();
3926 p.printNewline();
3927
3928 llvm::SmallVector<std::string, 3> names{"out", "in", "in_out"};
3929 auto *nameIt = names.begin();
3930 auto *attrIt = getOperandAttrs().begin();
3931
3932 for (mlir::OperandRange ops : getAsmOperands()) {
3933 p << *nameIt << " = ";
3934
3935 p << '[';
3936 llvm::interleaveComma(llvm::make_range(ops.begin(), ops.end()), p,
3937 [&](Value value) {
3938 p.printOperand(value);
3939 p << " : " << value.getType();
3940 if (mlir::isa<mlir::UnitAttr>(*attrIt))
3941 p << " (maybe_memory)";
3942 attrIt++;
3943 });
3944 p << "],";
3945 p.printNewline();
3946 ++nameIt;
3947 }
3948
3949 p << "{";
3950 p.printString(getAsmString());
3951 p << " ";
3952 p.printString(getConstraints());
3953 p << "}";
3954 p.decreaseIndent();
3955 p << ')';
3956 if (getSideEffects())
3957 p << " side_effects";
3958
3959 std::array elidedAttrs{
3960 llvm::StringRef("asm_flavor"), llvm::StringRef("asm_string"),
3961 llvm::StringRef("constraints"), llvm::StringRef("operand_attrs"),
3962 llvm::StringRef("operands_segments"), llvm::StringRef("side_effects")};
3963 p.printOptionalAttrDict(getOperation()->getAttrs(), elidedAttrs);
3964
3965 if (auto v = getRes())
3966 p << " -> " << v.getType();
3967}
3968
3969void cir::InlineAsmOp::build(OpBuilder &odsBuilder, OperationState &odsState,
3970 ArrayRef<ValueRange> asmOperands,
3971 StringRef asmString, StringRef constraints,
3972 bool sideEffects, cir::AsmFlavor asmFlavor,
3973 ArrayRef<Attribute> operandAttrs) {
3974 // Set up the operands_segments for VariadicOfVariadic
3975 SmallVector<int32_t> segments;
3976 for (auto operandRange : asmOperands) {
3977 segments.push_back(operandRange.size());
3978 odsState.addOperands(operandRange);
3979 }
3980
3981 odsState.addAttribute(
3982 "operands_segments",
3983 DenseI32ArrayAttr::get(odsBuilder.getContext(), segments));
3984 odsState.addAttribute("asm_string", odsBuilder.getStringAttr(asmString));
3985 odsState.addAttribute("constraints", odsBuilder.getStringAttr(constraints));
3986 odsState.addAttribute("asm_flavor",
3987 AsmFlavorAttr::get(odsBuilder.getContext(), asmFlavor));
3988
3989 if (sideEffects)
3990 odsState.addAttribute("side_effects", odsBuilder.getUnitAttr());
3991
3992 odsState.addAttribute("operand_attrs", odsBuilder.getArrayAttr(operandAttrs));
3993}
3994
3995ParseResult cir::InlineAsmOp::parse(OpAsmParser &parser,
3996 OperationState &result) {
3998 llvm::SmallVector<int32_t> operandsGroupSizes;
3999 std::string asmString, constraints;
4000 Type resType;
4001 MLIRContext *ctxt = parser.getBuilder().getContext();
4002
4003 auto error = [&](const Twine &msg) -> LogicalResult {
4004 return parser.emitError(parser.getCurrentLocation(), msg);
4005 };
4006
4007 auto expected = [&](const std::string &c) {
4008 return error("expected '" + c + "'");
4009 };
4010
4011 if (parser.parseLParen().failed())
4012 return expected("(");
4013
4014 auto flavor = FieldParser<AsmFlavor, AsmFlavor>::parse(parser);
4015 if (failed(flavor))
4016 return error("Unknown AsmFlavor");
4017
4018 if (parser.parseComma().failed())
4019 return expected(",");
4020
4021 auto parseValue = [&](Value &v) {
4022 OpAsmParser::UnresolvedOperand op;
4023
4024 if (parser.parseOperand(op) || parser.parseColon())
4025 return error("can't parse operand");
4026
4027 Type typ;
4028 if (parser.parseType(typ).failed())
4029 return error("can't parse operand type");
4031 if (parser.resolveOperand(op, typ, tmp))
4032 return error("can't resolve operand");
4033 v = tmp[0];
4034 return mlir::success();
4035 };
4036
4037 auto parseOperands = [&](llvm::StringRef name) {
4038 if (parser.parseKeyword(name).failed())
4039 return error("expected " + name + " operands here");
4040 if (parser.parseEqual().failed())
4041 return expected("=");
4042 if (parser.parseLSquare().failed())
4043 return expected("[");
4044
4045 int size = 0;
4046 if (parser.parseOptionalRSquare().succeeded()) {
4047 operandsGroupSizes.push_back(size);
4048 if (parser.parseComma())
4049 return expected(",");
4050 return mlir::success();
4051 }
4052
4053 auto parseOperand = [&]() {
4054 Value val;
4055 if (parseValue(val).succeeded()) {
4056 result.operands.push_back(val);
4057 size++;
4058
4059 if (parser.parseOptionalLParen().failed()) {
4060 operandAttrs.push_back(mlir::DictionaryAttr::get(ctxt));
4061 return mlir::success();
4062 }
4063
4064 if (parser.parseKeyword("maybe_memory").succeeded()) {
4065 operandAttrs.push_back(mlir::UnitAttr::get(ctxt));
4066 if (parser.parseRParen())
4067 return expected(")");
4068 return mlir::success();
4069 } else {
4070 return expected("maybe_memory");
4071 }
4072 }
4073 return mlir::failure();
4074 };
4075
4076 if (parser.parseCommaSeparatedList(parseOperand).failed())
4077 return mlir::failure();
4078
4079 if (parser.parseRSquare().failed() || parser.parseComma().failed())
4080 return expected("]");
4081 operandsGroupSizes.push_back(size);
4082 return mlir::success();
4083 };
4084
4085 if (parseOperands("out").failed() || parseOperands("in").failed() ||
4086 parseOperands("in_out").failed())
4087 return error("failed to parse operands");
4088
4089 if (parser.parseLBrace())
4090 return expected("{");
4091 if (parser.parseString(&asmString))
4092 return error("asm string parsing failed");
4093 if (parser.parseString(&constraints))
4094 return error("constraints string parsing failed");
4095 if (parser.parseRBrace())
4096 return expected("}");
4097 if (parser.parseRParen())
4098 return expected(")");
4099
4100 if (parser.parseOptionalKeyword("side_effects").succeeded())
4101 result.attributes.set("side_effects", UnitAttr::get(ctxt));
4102
4103 if (parser.parseOptionalAttrDict(result.attributes).failed())
4104 return mlir::failure();
4105
4106 if (parser.parseOptionalArrow().succeeded() &&
4107 parser.parseType(resType).failed())
4108 return mlir::failure();
4109
4110 result.attributes.set("asm_flavor", AsmFlavorAttr::get(ctxt, *flavor));
4111 result.attributes.set("asm_string", StringAttr::get(ctxt, asmString));
4112 result.attributes.set("constraints", StringAttr::get(ctxt, constraints));
4113 result.attributes.set("operand_attrs", ArrayAttr::get(ctxt, operandAttrs));
4114 result.getOrAddProperties<InlineAsmOp::Properties>().operands_segments =
4115 parser.getBuilder().getDenseI32ArrayAttr(operandsGroupSizes);
4116 if (resType)
4117 result.addTypes(TypeRange{resType});
4118
4119 return mlir::success();
4120}
4121
4122//===----------------------------------------------------------------------===//
4123// ThrowOp / TryThrowOp
4124//===----------------------------------------------------------------------===//
4125
4126template <typename ThrowOpTy>
4127static mlir::LogicalResult verifyThrowOpImpl(ThrowOpTy op) {
4128 if (op.rethrows())
4129 return mlir::success();
4130
4131 if (op.getNumOperands() != 0) {
4132 if (op.getTypeInfo())
4133 return mlir::success();
4134 return op.emitOpError() << "'type_info' symbol attribute missing";
4135 }
4136
4137 return mlir::failure();
4138}
4139
4140mlir::LogicalResult cir::ThrowOp::verify() { return verifyThrowOpImpl(*this); }
4141
4142mlir::LogicalResult cir::TryThrowOp::verify() {
4143 return verifyThrowOpImpl(*this);
4144}
4145
4146//===----------------------------------------------------------------------===//
4147// AtomicFetchOp
4148//===----------------------------------------------------------------------===//
4149
4150LogicalResult cir::AtomicFetchOp::verify() {
4151 if (getBinop() != cir::AtomicFetchKind::Add &&
4152 getBinop() != cir::AtomicFetchKind::Sub &&
4153 getBinop() != cir::AtomicFetchKind::Max &&
4154 getBinop() != cir::AtomicFetchKind::Min &&
4155 !mlir::isa<cir::IntType>(getVal().getType()))
4156 return emitError("only atomic add, sub, max, and min operation could "
4157 "operate on floating-point values");
4158 return success();
4159}
4160
4161//===----------------------------------------------------------------------===//
4162// TypeInfoAttr
4163//===----------------------------------------------------------------------===//
4164
4165LogicalResult cir::TypeInfoAttr::verify(
4166 ::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError,
4167 ::mlir::Type type, ::mlir::ArrayAttr typeInfoData) {
4168
4169 if (cir::ConstRecordAttr::verify(emitError, type, typeInfoData).failed())
4170 return failure();
4171
4172 return success();
4173}
4174
4175//===----------------------------------------------------------------------===//
4176// TryOp
4177//===----------------------------------------------------------------------===//
4178
4179void cir::TryOp::getSuccessorRegions(
4180 mlir::RegionBranchPoint point,
4182 // The `try` and the `catchers` region branch back to the parent operation.
4183 if (!point.isParent()) {
4184 regions.push_back(RegionSuccessor::parent());
4185 return;
4186 }
4187
4188 regions.push_back(mlir::RegionSuccessor(&getTryRegion()));
4189
4190 // TODO(CIR): If we know a target function never throws a specific type, we
4191 // can remove the catch handler.
4192 for (mlir::Region &handlerRegion : this->getHandlerRegions())
4193 regions.push_back(mlir::RegionSuccessor(&handlerRegion));
4194}
4195
4196mlir::ValueRange cir::TryOp::getSuccessorInputs(RegionSuccessor successor) {
4197 return successor.isParent() ? ValueRange(getOperation()->getResults())
4198 : ValueRange();
4199}
4200
4201LogicalResult cir::TryOp::verify() {
4202 mlir::ArrayAttr handlerTypes = getHandlerTypes();
4203 if (!handlerTypes) {
4204 if (!getHandlerRegions().empty())
4205 return emitOpError(
4206 "handler regions must be empty when no handler types are present");
4207 return success();
4208 }
4209
4210 mlir::MutableArrayRef<mlir::Region> handlerRegions = getHandlerRegions();
4211
4212 // The parser and builder won't allow this to happen, but the loop below
4213 // relies on the sizes being the same, so we check it here.
4214 if (handlerRegions.size() != handlerTypes.size())
4215 return emitOpError(
4216 "number of handler regions and handler types must match");
4217
4218 for (const auto &[typeAttr, handlerRegion] :
4219 llvm::zip(handlerTypes, handlerRegions)) {
4220 // Verify that handler regions have a !cir.eh_token block argument.
4221 mlir::Block &entryBlock = handlerRegion.front();
4222 if (entryBlock.getNumArguments() != 1 ||
4223 !mlir::isa<cir::EhTokenType>(entryBlock.getArgument(0).getType()))
4224 return emitOpError(
4225 "handler region must have a single '!cir.eh_token' argument");
4226
4227 // The unwind region does not require a cir.begin_catch.
4228 if (mlir::isa<cir::UnwindAttr>(typeAttr))
4229 continue;
4230
4231 // A catch handler region must start with cir.begin_catch, optionally
4232 // preceded by a single cir.construct_catch_param that performs any
4233 // pre-begin_catch initialization for the catch parameter.
4234 if (entryBlock.empty())
4235 return emitOpError("catch handler region must not be empty");
4236 mlir::Operation *firstOp = &entryBlock.front();
4237 if (mlir::isa_and_present<cir::ConstructCatchParamOp>(firstOp))
4238 firstOp = firstOp->getNextNode();
4239 if (!firstOp || !mlir::isa<cir::BeginCatchOp>(firstOp))
4240 return emitOpError(
4241 "catch handler region must start with 'cir.begin_catch'");
4242 }
4243
4244 return success();
4245}
4246
4247static void
4248printTryHandlerRegions(mlir::OpAsmPrinter &printer, cir::TryOp op,
4249 mlir::MutableArrayRef<mlir::Region> handlerRegions,
4250 mlir::ArrayAttr handlerTypes) {
4251 if (!handlerTypes)
4252 return;
4253
4254 for (const auto [typeIdx, typeAttr] : llvm::enumerate(handlerTypes)) {
4255 if (typeIdx)
4256 printer << " ";
4257
4258 if (mlir::isa<cir::CatchAllAttr>(typeAttr)) {
4259 printer << "catch all ";
4260 } else if (mlir::isa<cir::UnwindAttr>(typeAttr)) {
4261 printer << "unwind ";
4262 } else {
4263 printer << "catch [type ";
4264 printer.printAttribute(typeAttr);
4265 printer << "] ";
4266 }
4267
4268 // Print the handler region's !cir.eh_token block argument.
4269 mlir::Region &region = handlerRegions[typeIdx];
4270 if (!region.empty() && region.front().getNumArguments() > 0) {
4271 printer << "(";
4272 printer.printRegionArgument(region.front().getArgument(0));
4273 printer << ") ";
4274 }
4275
4276 printer.printRegion(region,
4277 /*printEntryBLockArgs=*/false,
4278 /*printBlockTerminators=*/true);
4279 }
4280}
4281
4282static mlir::ParseResult parseTryHandlerRegions(
4283 mlir::OpAsmParser &parser,
4284 llvm::SmallVectorImpl<std::unique_ptr<mlir::Region>> &handlerRegions,
4285 mlir::ArrayAttr &handlerTypes) {
4286
4287 auto parseCheckedCatcherRegion = [&]() -> mlir::ParseResult {
4288 handlerRegions.emplace_back(new mlir::Region);
4289
4290 mlir::Region &currRegion = *handlerRegions.back();
4291
4292 // Parse the required region argument: (%eh_token : !cir.eh_token)
4294 if (parser.parseLParen())
4295 return failure();
4296 mlir::OpAsmParser::Argument arg;
4297 if (parser.parseArgument(arg, /*allowType=*/true))
4298 return failure();
4299 regionArgs.push_back(arg);
4300 if (parser.parseRParen())
4301 return failure();
4302
4303 mlir::SMLoc regionLoc = parser.getCurrentLocation();
4304 if (parser.parseRegion(currRegion, regionArgs)) {
4305 handlerRegions.clear();
4306 return failure();
4307 }
4308
4309 if (currRegion.empty())
4310 return parser.emitError(regionLoc, "handler region shall not be empty");
4311
4312 if (!(currRegion.back().mightHaveTerminator() &&
4313 currRegion.back().getTerminator()))
4314 return parser.emitError(
4315 regionLoc, "blocks are expected to be explicitly terminated");
4316
4317 return success();
4318 };
4319
4320 bool hasCatchAll = false;
4322 while (parser.parseOptionalKeyword("catch").succeeded()) {
4323 bool hasLSquare = parser.parseOptionalLSquare().succeeded();
4324
4325 llvm::StringRef attrStr;
4326 if (parser.parseOptionalKeyword(&attrStr, {"all", "type"}).failed())
4327 return parser.emitError(parser.getCurrentLocation(),
4328 "expected 'all' or 'type' keyword");
4329
4330 bool isCatchAll = attrStr == "all";
4331 if (isCatchAll) {
4332 if (hasCatchAll)
4333 return parser.emitError(parser.getCurrentLocation(),
4334 "can't have more than one catch all");
4335 hasCatchAll = true;
4336 }
4337
4338 mlir::Attribute exceptionRTTIAttr;
4339 if (!isCatchAll && parser.parseAttribute(exceptionRTTIAttr).failed())
4340 return parser.emitError(parser.getCurrentLocation(),
4341 "expected valid RTTI info attribute");
4342
4343 catcherAttrs.push_back(isCatchAll
4344 ? cir::CatchAllAttr::get(parser.getContext())
4345 : exceptionRTTIAttr);
4346
4347 if (hasLSquare && isCatchAll)
4348 return parser.emitError(parser.getCurrentLocation(),
4349 "catch all dosen't need RTTI info attribute");
4350
4351 if (hasLSquare && parser.parseRSquare().failed())
4352 return parser.emitError(parser.getCurrentLocation(),
4353 "expected `]` after RTTI info attribute");
4354
4355 if (parseCheckedCatcherRegion().failed())
4356 return mlir::failure();
4357 }
4358
4359 if (parser.parseOptionalKeyword("unwind").succeeded()) {
4360 if (hasCatchAll)
4361 return parser.emitError(parser.getCurrentLocation(),
4362 "unwind can't be used with catch all");
4363
4364 catcherAttrs.push_back(cir::UnwindAttr::get(parser.getContext()));
4365 if (parseCheckedCatcherRegion().failed())
4366 return mlir::failure();
4367 }
4368
4369 handlerTypes = parser.getBuilder().getArrayAttr(catcherAttrs);
4370 return mlir::success();
4371}
4372
4373//===----------------------------------------------------------------------===//
4374// EhTypeIdOp
4375//===----------------------------------------------------------------------===//
4376
4377LogicalResult
4378cir::EhTypeIdOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
4379 Operation *op = symbolTable.lookupNearestSymbolFrom(*this, getTypeSymAttr());
4380 if (!isa_and_nonnull<GlobalOp>(op))
4381 return emitOpError("'")
4382 << getTypeSym() << "' does not reference a valid cir.global";
4383 return success();
4384}
4385
4386//===----------------------------------------------------------------------===//
4387// ConstructCatchParamOp
4388//===----------------------------------------------------------------------===//
4389
4390LogicalResult cir::ConstructCatchParamOp::verifySymbolUses(
4391 SymbolTableCollection &symbolTable) {
4392 auto copyFnAttr = getCopyFnAttr();
4393 if (!copyFnAttr)
4394 return success();
4395 auto fn =
4396 symbolTable.lookupNearestSymbolFrom<cir::FuncOp>(*this, getCopyFnAttr());
4397 if (!fn)
4398 return emitOpError("'")
4399 << *getCopyFn() << "' does not reference a valid cir.func";
4400
4401 if (!fn->hasAttr(cir::CIRDialect::getCatchCopyThunkAttrName()))
4402 return emitOpError("catch-init copy_fn must be tagged with the ")
4403 << cir::CIRDialect::getCatchCopyThunkAttrName() << " attribute";
4404
4405 cir::FuncType fnType = fn.getFunctionType();
4406 if (fnType.getNumInputs() != 2 || !fnType.hasVoidReturn())
4407 return emitOpError("catch-init copy_fn must take two pointer arguments and "
4408 "return void");
4409
4410 if (fnType.getInput(0) != getParamAddr().getType())
4411 return emitOpError("first argument of catch-init copy_fn must match the "
4412 "type of 'param_addr'");
4413
4414 if (fnType.getInput(1) != getParamAddr().getType())
4415 return emitOpError(
4416 "second argument of catch-init copy_fn must be a pointer "
4417 "to the catch type");
4418
4419 return success();
4420}
4421
4422//===----------------------------------------------------------------------===//
4423// EhDispatchOp
4424//===----------------------------------------------------------------------===//
4425
4426static ParseResult
4427parseEhDispatchDestinations(OpAsmParser &parser, mlir::ArrayAttr &catchTypes,
4428 SmallVectorImpl<Block *> &catchDestinations,
4429 Block *&defaultDestination,
4430 mlir::UnitAttr &defaultIsCatchAll) {
4431 // Parse: [ ... ]
4432 if (parser.parseLSquare())
4433 return failure();
4434
4435 SmallVector<Attribute> handlerTypes;
4436 bool hasCatchAll = false;
4437 bool hasUnwind = false;
4438
4439 // Parse handler list.
4440 auto parseHandler = [&]() -> ParseResult {
4441 // Check for 'catch_all' or 'unwind' keywords.
4442 if (succeeded(parser.parseOptionalKeyword("catch_all"))) {
4443 if (hasCatchAll)
4444 return parser.emitError(parser.getCurrentLocation(),
4445 "duplicate 'catch_all' handler");
4446 if (hasUnwind)
4447 return parser.emitError(parser.getCurrentLocation(),
4448 "cannot have both 'catch_all' and 'unwind'");
4449 hasCatchAll = true;
4450
4451 if (parser.parseColon().failed())
4452 return failure();
4453
4454 if (parser.parseSuccessor(defaultDestination).failed())
4455 return failure();
4456
4457 return success();
4458 }
4459
4460 if (succeeded(parser.parseOptionalKeyword("unwind"))) {
4461 if (hasUnwind)
4462 return parser.emitError(parser.getCurrentLocation(),
4463 "duplicate 'unwind' handler");
4464 if (hasCatchAll)
4465 return parser.emitError(parser.getCurrentLocation(),
4466 "cannot have both 'catch_all' and 'unwind'");
4467 hasUnwind = true;
4468
4469 if (parser.parseColon().failed())
4470 return failure();
4471
4472 if (parser.parseSuccessor(defaultDestination).failed())
4473 return failure();
4474 return success();
4475 }
4476
4477 // Otherwise, expect 'catch(<attr> : <type>) : ^block'.
4478 // The 'catch(...)' wrapper allows the attribute to include its type
4479 // without conflicting with the ':' used for the block destination.
4480 if (parser.parseKeyword("catch").failed())
4481 return failure();
4482
4483 if (parser.parseLParen().failed())
4484 return failure();
4485
4486 mlir::Attribute catchTypeAttr;
4487 if (parser.parseAttribute(catchTypeAttr).failed())
4488 return failure();
4489 handlerTypes.push_back(catchTypeAttr);
4490
4491 if (parser.parseRParen().failed())
4492 return failure();
4493
4494 if (parser.parseColon().failed())
4495 return failure();
4496
4497 Block *dest;
4498 if (parser.parseSuccessor(dest).failed())
4499 return failure();
4500 catchDestinations.push_back(dest);
4501 return success();
4502 };
4503
4504 if (parser.parseCommaSeparatedList(parseHandler).failed())
4505 return failure();
4506
4507 if (parser.parseRSquare().failed())
4508 return failure();
4509
4510 // Verify we have catch_all or unwind.
4511 if (!hasCatchAll && !hasUnwind)
4512 return parser.emitError(parser.getCurrentLocation(),
4513 "must have either 'catch_all' or 'unwind' handler");
4514
4515 // Add attributes and successors.
4516 if (!handlerTypes.empty())
4517 catchTypes = parser.getBuilder().getArrayAttr(handlerTypes);
4518
4519 if (hasCatchAll)
4520 defaultIsCatchAll = parser.getBuilder().getUnitAttr();
4521
4522 return success();
4523}
4524
4525static void printEhDispatchDestinations(OpAsmPrinter &p, cir::EhDispatchOp op,
4526 mlir::ArrayAttr catchTypes,
4527 SuccessorRange catchDestinations,
4528 Block *defaultDestination,
4529 mlir::UnitAttr defaultIsCatchAll) {
4530 p << " [";
4531 p.printNewline();
4532
4533 // If we have at least one catch type, print them.
4534 if (catchTypes) {
4535 // Print type handlers using 'catch(<attr>) : ^block' syntax.
4536 llvm::interleave(
4537 llvm::zip(catchTypes, catchDestinations),
4538 [&](auto i) {
4539 p << " catch(";
4540 p.printAttribute(std::get<0>(i));
4541 p << ") : ";
4542 p.printSuccessor(std::get<1>(i));
4543 },
4544 [&] {
4545 p << ',';
4546 p.printNewline();
4547 });
4548
4549 p << ", ";
4550 p.printNewline();
4551 }
4552
4553 // Print catch_all or unwind handler.
4554 if (defaultIsCatchAll)
4555 p << " catch_all : ";
4556 else
4557 p << " unwind : ";
4558 p.printSuccessor(defaultDestination);
4559 p.printNewline();
4560
4561 p << "]";
4562}
4563
4564//===----------------------------------------------------------------------===//
4565// TableGen'd op method definitions
4566//===----------------------------------------------------------------------===//
4567
4568#define GET_OP_CLASSES
4569#include "clang/CIR/Dialect/IR/CIROps.cpp.inc"
static const MemRegion * getRegion(const CallEvent &Call, const MutexDescriptor &Descriptor, bool IsLock)
static void printEhDispatchDestinations(OpAsmPrinter &p, cir::EhDispatchOp op, mlir::ArrayAttr catchTypes, SuccessorRange catchDestinations, Block *defaultDestination, mlir::UnitAttr defaultIsCatchAll)
static mlir::LogicalResult checkReturnAndFunction(cir::ReturnOp op, cir::FuncOp function)
static bool isCirFunctionPointerType(mlir::Type ty)
static LogicalResult verifyMemberPtrCast(Operation *op, mlir::Value src, mlir::Type resultTy)
static mlir::ParseResult parseCallCommon(mlir::OpAsmParser &parser, mlir::OperationState &result, bool hasDestinationBlocks=false)
static bool isIntOrBoolCast(cir::CastOp op)
static ParseResult parseAssumeBundle(OpAsmParser &p, cir::AssumeBundleKindAttr &bundleKindAttr, llvm::SmallVector< mlir::OpAsmParser::UnresolvedOperand, 4 > &bundleArgs, llvm::SmallVector< mlir::Type, 1 > &bundleArgTypes)
static ParseResult parseEhDispatchDestinations(OpAsmParser &parser, mlir::ArrayAttr &catchTypes, SmallVectorImpl< Block * > &catchDestinations, Block *&defaultDestination, mlir::UnitAttr &defaultIsCatchAll)
static void printConstant(OpAsmPrinter &p, Attribute value)
static mlir::ParseResult parseOmittedTerminatorRegion(mlir::OpAsmParser &parser, mlir::Region &region)
static void printAssumeBundle(OpAsmPrinter &p, cir::AssumeOp op, cir::AssumeBundleKindAttr kindAttr, OperandRange bundleArgs, TypeRange bundleArgTypes)
ParseResult parseInlineKindAttr(OpAsmParser &parser, cir::InlineKindAttr &inlineKindAttr)
void printInlineKindAttr(OpAsmPrinter &p, cir::InlineKindAttr inlineKindAttr)
static ParseResult parseSwitchFlatOpCases(OpAsmParser &parser, Type flagType, mlir::ArrayAttr &caseValues, SmallVectorImpl< Block * > &caseDestinations, SmallVectorImpl< llvm::SmallVector< OpAsmParser::UnresolvedOperand > > &caseOperands, SmallVectorImpl< llvm::SmallVector< Type > > &caseOperandTypes)
<cases> ::= [ (case (, case )* )?
void printGlobalAddressSpaceValue(mlir::AsmPrinter &printer, cir::GlobalOp op, mlir::ptr::MemorySpaceAttrInterface attr)
static void printCallCommon(mlir::Operation *op, mlir::FlatSymbolRefAttr calleeSym, mlir::Value indirectCallee, mlir::OpAsmPrinter &printer, bool isNothrow, cir::SideEffect sideEffect, ArrayAttr argAttrs, ArrayAttr resAttrs, mlir::Block *normalDest=nullptr, mlir::Block *unwindDest=nullptr)
static LogicalResult verifyCallCommInSymbolUses(mlir::Operation *op, SymbolTableCollection &symbolTable)
static LogicalResult ensureRegionTerm(OpAsmParser &parser, Region &region, SMLoc errLoc)
static ParseResult parseGlobalOpTypeAndInitialValue(OpAsmParser &parser, TypeAttr &typeAttr, Attribute &initialValueAttr, mlir::Region &ctorRegion, mlir::Region &dtorRegion)
void printIndirectBrOpSucessors(OpAsmPrinter &p, cir::IndirectBrOp op, Type flagType, SuccessorRange succs, OperandRangeRange succOperands, const TypeRangeRange &succOperandsTypes)
static OpFoldResult foldUnaryBitOp(mlir::Attribute inputAttr, llvm::function_ref< llvm::APInt(const llvm::APInt &)> func, bool poisonZero=false)
static llvm::StringRef getLinkageAttrNameString()
Returns the name used for the linkage attribute.
static RetTy parseOptionalCIRKeyword(AsmParser &parser, EnumTy defaultValue)
Parse an enum from the keyword, or default to the provided default value.
mlir::OptionalParseResult parseGlobalAddressSpaceValue(mlir::AsmParser &p, mlir::ptr::MemorySpaceAttrInterface &attr)
static void printSwitchFlatOpCases(OpAsmPrinter &p, cir::SwitchFlatOp op, Type flagType, mlir::ArrayAttr caseValues, SuccessorRange caseDestinations, OperandRangeRange caseOperands, const TypeRangeRange &caseOperandTypes)
static mlir::ParseResult parseTryCallDestinations(mlir::OpAsmParser &parser, mlir::OperationState &result)
static void printGlobalOpTypeAndInitialValue(OpAsmPrinter &p, cir::GlobalOp op, TypeAttr type, Attribute initAttr, mlir::Region &ctorRegion, mlir::Region &dtorRegion)
static ParseResult parseCIRKeyword(AsmParser &parser, RetTy &result)
Parse an enum from the keyword, return failure if the keyword is not found.
static Value tryFoldCastChain(cir::CastOp op)
static void printTryHandlerRegions(mlir::OpAsmPrinter &printer, cir::TryOp op, mlir::MutableArrayRef< mlir::Region > handlerRegions, mlir::ArrayAttr handlerTypes)
ParseResult parseIndirectBrOpSucessors(OpAsmParser &parser, Type &flagType, SmallVectorImpl< Block * > &succOperandBlocks, SmallVectorImpl< SmallVector< OpAsmParser::UnresolvedOperand > > &succOperands, SmallVectorImpl< SmallVector< Type > > &succOperandsTypes)
static bool omitRegionTerm(mlir::Region &r)
static LogicalResult verifyBinaryOverflowOp(mlir::Operation *op, bool noSignedWrap, bool noUnsignedWrap, bool saturated, bool hasSat)
static void printOmittedTerminatorRegion(mlir::OpAsmPrinter &printer, cir::ScopeOp &op, mlir::Region &region)
static ParseResult parseConstantValue(OpAsmParser &parser, mlir::Attribute &valueAttr)
static LogicalResult verifyArrayCtorDtor(Op op)
static mlir::LogicalResult verifyThrowOpImpl(ThrowOpTy op)
static LogicalResult checkConstantTypes(mlir::Operation *op, mlir::Type opType, mlir::Attribute attrType)
static mlir::ParseResult parseTryHandlerRegions(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< std::unique_ptr< mlir::Region > > &handlerRegions, mlir::ArrayAttr &handlerTypes)
#define REGISTER_ENUM_TYPE(Ty)
static int parseOptionalKeywordAlternative(AsmParser &parser, ArrayRef< llvm::StringRef > keywords)
llvm::function_ref< void(mlir::OpBuilder &, mlir::Location)> BuilderCallbackRef
Definition CIRDialect.h:37
llvm::function_ref< void( mlir::OpBuilder &, mlir::Location, mlir::OperationState &)> BuilderOpStateCallbackRef
Definition CIRDialect.h:39
static std::optional< NonLoc > getIndex(ProgramStateRef State, const ElementRegion *ER, CharKind CK)
static Decl::Kind getKind(const Decl *D)
TokenType getType() const
Returns the token's type, e.g.
tooling::Replacements cleanup(const FormatStyle &Style, StringRef Code, ArrayRef< tooling::Range > Ranges, StringRef FileName="<stdin>")
Clean up any erroneous/redundant code in the given Ranges in Code.
*collection of selector each with an associated kind and an ordered *collection of selectors A selector has a an optional score condition
*collection of selector each with an associated kind and an ordered *collection of selectors A selector has a kind
__device__ __2f16 float c
void buildTerminatedBody(mlir::OpBuilder &builder, mlir::Location loc)
mlir::ptr::MemorySpaceAttrInterface normalizeDefaultAddressSpace(mlir::ptr::MemorySpaceAttrInterface addrSpace)
Normalize LangAddressSpace::Default to null (empty attribute).
const internal::VariadicAllOfMatcher< Attr > attr
const AstTypeMatcher< RecordType > recordType
StringRef getName(const HeaderType T)
Definition HeaderFile.h:38
RangeSelector name(std::string ID)
Given a node with a "name", (like NamedDecl, DeclRefExpr, CxxCtorInitializer, and TypeLoc) selects th...
nullptr
This class represents a compute construct, representing a 'Kind' of ‘parallel’, 'serial',...
__DEVICE__ _Tp arg(const std::complex< _Tp > &__c)
static bool memberFuncPtrCast()
static bool opCallCallConv()
static bool opScopeCleanupRegion()
static bool supportIFuncAttr()