clang 23.0.0git
CIRDialect.cpp
Go to the documentation of this file.
1//===- CIRDialect.cpp - MLIR CIR ops implementation -----------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements the CIR dialect and its operations.
10//
11//===----------------------------------------------------------------------===//
12
14
18
19#include "mlir/IR/Attributes.h"
20#include "mlir/IR/DialectImplementation.h"
21#include "mlir/IR/PatternMatch.h"
22#include "mlir/IR/Value.h"
23#include "mlir/Interfaces/ControlFlowInterfaces.h"
24#include "mlir/Interfaces/FunctionImplementation.h"
25#include "mlir/Support/LLVM.h"
26
27#include "clang/CIR/Dialect/IR/CIROpsDialect.cpp.inc"
28#include "clang/CIR/Dialect/IR/CIROpsEnums.cpp.inc"
30#include "llvm/ADT/SetOperations.h"
31#include "llvm/ADT/SmallSet.h"
32#include "llvm/ADT/TypeSwitch.h"
33#include "llvm/Support/LogicalResult.h"
34
35using namespace mlir;
36using namespace cir;
37
38//===----------------------------------------------------------------------===//
39// CIR Dialect
40//===----------------------------------------------------------------------===//
41namespace {
42struct CIROpAsmDialectInterface : public OpAsmDialectInterface {
43 using OpAsmDialectInterface::OpAsmDialectInterface;
44
45 AliasResult getAlias(Type type, raw_ostream &os) const final {
46 if (auto recordType = dyn_cast<cir::RecordType>(type)) {
47 StringAttr nameAttr = recordType.getName();
48 if (!nameAttr)
49 os << "rec_anon_" << recordType.getKindAsStr();
50 else
51 os << "rec_" << nameAttr.getValue();
52 return AliasResult::OverridableAlias;
53 }
54 if (auto intType = dyn_cast<cir::IntType>(type)) {
55 // We only provide alias for standard integer types (i.e. integer types
56 // whose width is a power of 2 and at least 8).
57 unsigned width = intType.getWidth();
58 if (width < 8 || !llvm::isPowerOf2_32(width))
59 return AliasResult::NoAlias;
60 os << intType.getAlias();
61 return AliasResult::OverridableAlias;
62 }
63 if (auto voidType = dyn_cast<cir::VoidType>(type)) {
64 os << voidType.getAlias();
65 return AliasResult::OverridableAlias;
66 }
67
68 return AliasResult::NoAlias;
69 }
70
71 AliasResult getAlias(Attribute attr, raw_ostream &os) const final {
72 if (auto boolAttr = mlir::dyn_cast<cir::BoolAttr>(attr)) {
73 os << (boolAttr.getValue() ? "true" : "false");
74 return AliasResult::FinalAlias;
75 }
76 if (auto bitfield = mlir::dyn_cast<cir::BitfieldInfoAttr>(attr)) {
77 os << "bfi_" << bitfield.getName().str();
78 return AliasResult::FinalAlias;
79 }
80 if (auto dynCastInfoAttr = mlir::dyn_cast<cir::DynamicCastInfoAttr>(attr)) {
81 os << dynCastInfoAttr.getAlias();
82 return AliasResult::FinalAlias;
83 }
84 if (auto cmpThreeWayInfoAttr =
85 mlir::dyn_cast<cir::CmpThreeWayInfoAttr>(attr)) {
86 os << cmpThreeWayInfoAttr.getAlias();
87 return AliasResult::FinalAlias;
88 }
89 return AliasResult::NoAlias;
90 }
91};
92} // namespace
93
94void cir::CIRDialect::initialize() {
95 registerTypes();
96 registerAttributes();
97 addOperations<
98#define GET_OP_LIST
99#include "clang/CIR/Dialect/IR/CIROps.cpp.inc"
100 >();
101 addInterfaces<CIROpAsmDialectInterface>();
102}
103
104Operation *cir::CIRDialect::materializeConstant(mlir::OpBuilder &builder,
105 mlir::Attribute value,
106 mlir::Type type,
107 mlir::Location loc) {
108 return cir::ConstantOp::create(builder, loc, type,
109 mlir::cast<mlir::TypedAttr>(value));
110}
111
112//===----------------------------------------------------------------------===//
113// Helpers
114//===----------------------------------------------------------------------===//
115
116// Parses one of the keywords provided in the list `keywords` and returns the
117// position of the parsed keyword in the list. If none of the keywords from the
118// list is parsed, returns -1.
119static int parseOptionalKeywordAlternative(AsmParser &parser,
120 ArrayRef<llvm::StringRef> keywords) {
121 for (auto en : llvm::enumerate(keywords)) {
122 if (succeeded(parser.parseOptionalKeyword(en.value())))
123 return en.index();
124 }
125 return -1;
126}
127
128namespace {
129template <typename Ty> struct EnumTraits {};
130
131#define REGISTER_ENUM_TYPE(Ty) \
132 template <> struct EnumTraits<cir::Ty> { \
133 static llvm::StringRef stringify(cir::Ty value) { \
134 return stringify##Ty(value); \
135 } \
136 static unsigned getMaxEnumVal() { return cir::getMaxEnumValFor##Ty(); } \
137 }
138
139REGISTER_ENUM_TYPE(GlobalLinkageKind);
140REGISTER_ENUM_TYPE(VisibilityKind);
141REGISTER_ENUM_TYPE(SideEffect);
142REGISTER_ENUM_TYPE(CallingConv);
143} // namespace
144
145/// Parse an enum from the keyword, or default to the provided default value.
146/// The return type is the enum type by default, unless overriden with the
147/// second template argument.
148template <typename EnumTy, typename RetTy = EnumTy>
149static RetTy parseOptionalCIRKeyword(AsmParser &parser, EnumTy defaultValue) {
151 for (unsigned i = 0, e = EnumTraits<EnumTy>::getMaxEnumVal(); i <= e; ++i)
152 names.push_back(EnumTraits<EnumTy>::stringify(static_cast<EnumTy>(i)));
153
154 int index = parseOptionalKeywordAlternative(parser, names);
155 if (index == -1)
156 return static_cast<RetTy>(defaultValue);
157 return static_cast<RetTy>(index);
158}
159
160/// Parse an enum from the keyword, return failure if the keyword is not found.
161template <typename EnumTy, typename RetTy = EnumTy>
162static ParseResult parseCIRKeyword(AsmParser &parser, RetTy &result) {
164 for (unsigned i = 0, e = EnumTraits<EnumTy>::getMaxEnumVal(); i <= e; ++i)
165 names.push_back(EnumTraits<EnumTy>::stringify(static_cast<EnumTy>(i)));
166
167 int index = parseOptionalKeywordAlternative(parser, names);
168 if (index == -1)
169 return failure();
170 result = static_cast<RetTy>(index);
171 return success();
172}
173
174// Check if a region's termination omission is valid and, if so, creates and
175// inserts the omitted terminator into the region.
176static LogicalResult ensureRegionTerm(OpAsmParser &parser, Region &region,
177 SMLoc errLoc) {
178 Location eLoc = parser.getEncodedSourceLoc(parser.getCurrentLocation());
179 OpBuilder builder(parser.getBuilder().getContext());
180
181 // Insert empty block in case the region is empty to ensure the terminator
182 // will be inserted
183 if (region.empty())
184 builder.createBlock(&region);
185
186 Block &block = region.back();
187 // Region is properly terminated: nothing to do.
188 if (!block.empty() && block.back().hasTrait<OpTrait::IsTerminator>())
189 return success();
190
191 // Check for invalid terminator omissions.
192 if (!region.hasOneBlock())
193 return parser.emitError(errLoc,
194 "multi-block region must not omit terminator");
195
196 // Terminator was omitted correctly: recreate it.
197 builder.setInsertionPointToEnd(&block);
198 cir::YieldOp::create(builder, eLoc);
199 return success();
200}
201
202// True if the region's terminator should be omitted.
203static bool omitRegionTerm(mlir::Region &r) {
204 const auto singleNonEmptyBlock = r.hasOneBlock() && !r.back().empty();
205 const auto yieldsNothing = [&r]() {
206 auto y = dyn_cast<cir::YieldOp>(r.back().getTerminator());
207 return y && y.getArgs().empty();
208 };
209 return singleNonEmptyBlock && yieldsNothing();
210}
211
212// Verifies that the given operand is produced by an operation of type
213// ExpectedProducerOp.
214template <typename ExpectedProducerOp>
215static LogicalResult verifyProducedBy(Operation *op, Value operand,
216 StringRef operandName) {
217 Operation *producer = operand.getDefiningOp();
218 if (!producer || !isa<ExpectedProducerOp>(producer))
219 return op->emitOpError()
220 << "operand '" << operandName << "' must be produced by '"
221 << ExpectedProducerOp::getOperationName() << "'";
222 return success();
223}
224
225//===----------------------------------------------------------------------===//
226// InlineKindAttr (FIXME: remove once FuncOp uses assembly format)
227//===----------------------------------------------------------------------===//
228
229ParseResult parseInlineKindAttr(OpAsmParser &parser,
230 cir::InlineKindAttr &inlineKindAttr) {
231 // Static list of possible inline kind keywords
232 static constexpr llvm::StringRef keywords[] = {"no_inline", "always_inline",
233 "inline_hint"};
234
235 // Parse the inline kind keyword (optional)
236 llvm::StringRef keyword;
237 if (parser.parseOptionalKeyword(&keyword, keywords).failed()) {
238 // Not an inline kind keyword, leave inlineKindAttr empty
239 return success();
240 }
241
242 // Parse the enum value from the keyword
243 auto inlineKindResult = ::cir::symbolizeEnum<::cir::InlineKind>(keyword);
244 if (!inlineKindResult) {
245 return parser.emitError(parser.getCurrentLocation(), "expected one of [")
246 << llvm::join(llvm::ArrayRef(keywords), ", ")
247 << "] for inlineKind, got: " << keyword;
248 }
249
250 inlineKindAttr =
251 ::cir::InlineKindAttr::get(parser.getContext(), *inlineKindResult);
252 return success();
253}
254
255void printInlineKindAttr(OpAsmPrinter &p, cir::InlineKindAttr inlineKindAttr) {
256 if (inlineKindAttr) {
257 p << " " << stringifyInlineKind(inlineKindAttr.getValue());
258 }
259}
260//===----------------------------------------------------------------------===//
261// CIR Custom Parsers/Printers
262//===----------------------------------------------------------------------===//
263
264static mlir::ParseResult parseOmittedTerminatorRegion(mlir::OpAsmParser &parser,
265 mlir::Region &region) {
266 auto regionLoc = parser.getCurrentLocation();
267 if (parser.parseRegion(region))
268 return failure();
269 if (ensureRegionTerm(parser, region, regionLoc).failed())
270 return failure();
271 return success();
272}
273
274static void printOmittedTerminatorRegion(mlir::OpAsmPrinter &printer,
275 cir::ScopeOp &op,
276 mlir::Region &region) {
277 printer.printRegion(region,
278 /*printEntryBlockArgs=*/false,
279 /*printBlockTerminators=*/!omitRegionTerm(region));
280}
281
282mlir::OptionalParseResult
283parseGlobalAddressSpaceValue(mlir::AsmParser &p,
284 mlir::ptr::MemorySpaceAttrInterface &attr);
285
286void printGlobalAddressSpaceValue(mlir::AsmPrinter &printer, cir::GlobalOp op,
287 mlir::ptr::MemorySpaceAttrInterface attr);
288
289//===----------------------------------------------------------------------===//
290// AllocaOp
291//===----------------------------------------------------------------------===//
292
293void cir::AllocaOp::build(mlir::OpBuilder &odsBuilder,
294 mlir::OperationState &odsState, mlir::Type addr,
295 llvm::StringRef name, mlir::IntegerAttr alignment) {
296 odsState.addAttribute(getNameAttrName(odsState.name),
297 odsBuilder.getStringAttr(name));
298 if (alignment) {
299 odsState.addAttribute(getAlignmentAttrName(odsState.name), alignment);
300 }
301 odsState.addTypes(addr);
302}
303
304//===----------------------------------------------------------------------===//
305// ArrayCtor & ArrayDtor
306//===----------------------------------------------------------------------===//
307
308template <typename Op> static LogicalResult verifyArrayCtorDtor(Op op) {
309 auto ptrTy = mlir::cast<cir::PointerType>(op.getAddr().getType());
310 mlir::Type pointeeTy = ptrTy.getPointee();
311
312 mlir::Block &body = op.getBody().front();
313 if (body.getNumArguments() != 1)
314 return op.emitOpError("body must have exactly one block argument");
315
316 auto expectedEltPtrTy =
317 mlir::dyn_cast<cir::PointerType>(body.getArgument(0).getType());
318 if (!expectedEltPtrTy)
319 return op.emitOpError("block argument must be a !cir.ptr type");
320
321 if (op.getNumElements()) {
322 auto recTy = mlir::dyn_cast<cir::RecordType>(pointeeTy);
323 if (!recTy)
324 return op.emitOpError(
325 "when 'num_elements' is present, 'addr' must be a pointer to a "
326 "!cir.struct or !cir.union type");
327
328 if (expectedEltPtrTy != ptrTy)
329 return op.emitOpError("when 'num_elements' is present, 'addr' type must "
330 "match the block argument type");
331 } else {
332 auto arrayTy = mlir::dyn_cast<cir::ArrayType>(pointeeTy);
333 if (!arrayTy)
334 return op.emitOpError(
335 "when 'num_elements' is absent, 'addr' must be a pointer to a "
336 "!cir.array type");
337
338 mlir::Type innerEltTy = arrayTy.getElementType();
339 while (auto nested = mlir::dyn_cast<cir::ArrayType>(innerEltTy))
340 innerEltTy = nested.getElementType();
341
342 auto recTy = mlir::dyn_cast<cir::RecordType>(innerEltTy);
343 if (!recTy)
344 return op.emitOpError("the block argument type must be a pointer to a "
345 "!cir.struct or !cir.union type");
346
347 if (expectedEltPtrTy.getPointee() != innerEltTy)
348 return op.emitOpError(
349 "block argument pointee type must match the innermost array "
350 "element type");
351 }
352
353 return success();
354}
355
356LogicalResult cir::ArrayCtor::verify() {
357 if (failed(verifyArrayCtorDtor(*this)))
358 return failure();
359
360 mlir::Region &partialDtor = getPartialDtor();
361 if (!partialDtor.empty()) {
362 mlir::Block &dtorBlock = partialDtor.front();
363 if (dtorBlock.getNumArguments() != 1)
364 return emitOpError("partial_dtor must have exactly one block argument");
365
366 auto bodyArgTy = getBody().front().getArgument(0).getType();
367 if (dtorBlock.getArgument(0).getType() != bodyArgTy)
368 return emitOpError("partial_dtor block argument type must match "
369 "the body block argument type");
370 }
371 return success();
372}
373LogicalResult cir::ArrayDtor::verify() { return verifyArrayCtorDtor(*this); }
374
375//===----------------------------------------------------------------------===//
376// DeleteArrayOp
377//===----------------------------------------------------------------------===//
378
379LogicalResult cir::DeleteArrayOp::verify() {
380 if (getDtorMayThrow() && !getElementDtorAttr())
381 return emitOpError(
382 "'dtor_may_throw' requires an 'element_dtor' to be present");
383 return success();
384}
385
386//===----------------------------------------------------------------------===//
387// AssumeOp
388//===----------------------------------------------------------------------===//
389
390static void printAssumeBundle(OpAsmPrinter &p, cir::AssumeOp op,
391 cir::AssumeBundleKindAttr kindAttr,
392 OperandRange bundleArgs,
393 TypeRange bundleArgTypes) {
394 cir::AssumeBundleKind kind = kindAttr.getValue();
395 if (kind == cir::AssumeBundleKind::None)
396 return;
397
398 p << " " << cir::stringifyAssumeBundleKind(kind);
399 if (bundleArgs.empty())
400 return;
401
402 p << "(";
403 p.printOperands(bundleArgs);
404 p << " : ";
405 llvm::interleaveComma(bundleArgTypes, p);
406 p << ")";
407}
408
409static ParseResult parseAssumeBundle(
410 OpAsmParser &p, cir::AssumeBundleKindAttr &bundleKindAttr,
412 llvm::SmallVector<mlir::Type, 1> &bundleArgTypes) {
413 StringRef keyword;
414 auto loc = p.getCurrentLocation();
415 if (failed(p.parseOptionalKeyword(&keyword))) {
416 bundleKindAttr = cir::AssumeBundleKindAttr::get(
417 p.getContext(), cir::AssumeBundleKind::None);
418 return success();
419 }
420
421 std::optional<cir::AssumeBundleKind> parsedKind =
422 cir::symbolizeAssumeBundleKind(keyword);
423 if (!parsedKind)
424 return p.emitError(loc, "unknown assume bundle kind '") << keyword << "'";
425
426 bundleKindAttr = cir::AssumeBundleKindAttr::get(p.getContext(), *parsedKind);
427
428 if (p.parseOptionalLParen())
429 return success();
430
431 if (p.parseOperandList(bundleArgs) || p.parseColon() ||
432 p.parseTypeList(bundleArgTypes) || p.parseRParen())
433 return failure();
434
435 return success();
436}
437
438LogicalResult cir::AssumeOp::verify() {
439 cir::AssumeBundleKind kind = getBundleKind();
440 size_t numArgs = getBundleArgs().size();
441
442 if (kind == cir::AssumeBundleKind::None) {
443 if (numArgs != 0)
444 return emitOpError("unexpected bundle operands for kind 'none'");
445 return success();
446 }
447
448 if (numArgs == 0)
449 return emitOpError("expected bundle operands for kind '")
450 << cir::stringifyAssumeBundleKind(kind) << "'";
451
452 switch (kind) {
453 case cir::AssumeBundleKind::Align:
454 if (numArgs != 2 && numArgs != 3)
455 return emitOpError("align bundle expects 2 or 3 operands");
456 break;
457 case cir::AssumeBundleKind::SeparateStorage:
458 if (numArgs != 2)
459 return emitOpError("separate_storage bundle expects 2 operands");
460 break;
461 case cir::AssumeBundleKind::Dereferenceable:
462 if (numArgs != 2)
463 return emitOpError("dereferenceable bundle expects 2 operands");
464 break;
465 default:
466 break;
467 }
468 return success();
469}
470
471//===----------------------------------------------------------------------===//
472// LocalInitOp
473//===----------------------------------------------------------------------===//
474
475LogicalResult
476cir::LocalInitOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
477 cir::GlobalOp global = getReferencedGlobal(symbolTable);
478 if (!global)
479 return emitOpError("'")
480 << getGlobalName() << "' does not reference a valid cir.global";
481
482 if (getTls() && !global.getTlsModel())
483 return emitOpError("access to global not marked thread local");
484
485 if (!global.getStaticLocalGuard().has_value())
486 return emitOpError("static_local attribute mismatch");
487
488 return success();
489}
490
491//===----------------------------------------------------------------------===//
492// ConditionOp
493//===----------------------------------------------------------------------===//
494
495//===----------------------------------
496// BranchOpTerminatorInterface Methods
497//===----------------------------------
498
499void cir::ConditionOp::getSuccessorRegions(
500 ArrayRef<Attribute> operands, SmallVectorImpl<RegionSuccessor> &regions) {
501 // TODO(cir): The condition value may be folded to a constant, narrowing
502 // down its list of possible successors.
503
504 // Parent is a loop: condition may branch to the body or to the parent op.
505 if (auto loopOp = dyn_cast<LoopOpInterface>(getOperation()->getParentOp())) {
506 regions.emplace_back(&loopOp.getBody());
507 regions.emplace_back(getOperation());
508 return;
509 }
510
511 // Parent is an await: condition may branch to resume or suspend regions.
512 auto await = cast<AwaitOp>(getOperation()->getParentOp());
513 regions.emplace_back(&await.getResume());
514 regions.emplace_back(&await.getSuspend());
515}
516
517MutableOperandRange
518cir::ConditionOp::getMutableSuccessorOperands(RegionSuccessor point) {
519 // No values are yielded to the successor region.
520 return MutableOperandRange(getOperation(), 0, 0);
521}
522
523MutableOperandRange
524cir::ResumeOp::getMutableSuccessorOperands(RegionSuccessor point) {
525 // The eh_token operand is not forwarded to the parent region.
526 return MutableOperandRange(getOperation(), 0, 0);
527}
528
529LogicalResult cir::ConditionOp::verify() {
530 if (!isa<LoopOpInterface, AwaitOp>(getOperation()->getParentOp()))
531 return emitOpError("condition must be within a conditional region");
532 return success();
533}
534
535//===----------------------------------------------------------------------===//
536// ConstantOp
537//===----------------------------------------------------------------------===//
538
539static LogicalResult checkConstantTypes(mlir::Operation *op, mlir::Type opType,
540 mlir::Attribute attrType) {
541 if (isa<cir::ConstPtrAttr>(attrType)) {
542 if (!mlir::isa<cir::PointerType>(opType))
543 return op->emitOpError(
544 "pointer constant initializing a non-pointer type");
545 return success();
546 }
547
548 if (isa<cir::DataMemberAttr, cir::MethodAttr>(attrType)) {
549 // More detailed type verifications are already done in
550 // DataMemberAttr::verify or MethodAttr::verify. Don't need to repeat here.
551 return success();
552 }
553
554 if (isa<cir::ZeroAttr>(attrType)) {
555 if (isa<cir::RecordType, cir::ArrayType, cir::VectorType, cir::ComplexType>(
556 opType))
557 return success();
558 return op->emitOpError(
559 "zero expects struct, array, vector, or complex type");
560 }
561
562 if (mlir::isa<cir::UndefAttr>(attrType)) {
563 if (!mlir::isa<cir::VoidType>(opType))
564 return success();
565 return op->emitOpError("undef expects non-void type");
566 }
567
568 if (mlir::isa<cir::BoolAttr>(attrType)) {
569 if (!mlir::isa<cir::BoolType>(opType))
570 return op->emitOpError("result type (")
571 << opType << ") must be '!cir.bool' for '" << attrType << "'";
572 return success();
573 }
574
575 if (mlir::isa<cir::IntAttr, cir::FPAttr>(attrType)) {
576 auto at = cast<TypedAttr>(attrType);
577 if (at.getType() != opType) {
578 return op->emitOpError("result type (")
579 << opType << ") does not match value type (" << at.getType()
580 << ")";
581 }
582 return success();
583 }
584
585 if (mlir::isa<cir::ConstArrayAttr, cir::ConstVectorAttr,
586 cir::ConstComplexAttr, cir::ConstRecordAttr,
587 cir::GlobalViewAttr, cir::PoisonAttr, cir::TypeInfoAttr,
588 cir::VTableAttr>(attrType))
589 return success();
590
591 assert(isa<TypedAttr>(attrType) && "What else could we be looking at here?");
592 return op->emitOpError("global with type ")
593 << cast<TypedAttr>(attrType).getType() << " not yet supported";
594}
595
596LogicalResult cir::ConstantOp::verify() {
597 // ODS already generates checks to make sure the result type is valid. We just
598 // need to additionally check that the value's attribute type is consistent
599 // with the result type.
600 return checkConstantTypes(getOperation(), getType(), getValue());
601}
602
603OpFoldResult cir::ConstantOp::fold(FoldAdaptor /*adaptor*/) {
604 return getValue();
605}
606
607//===----------------------------------------------------------------------===//
608// CastOp
609//===----------------------------------------------------------------------===//
610
611LogicalResult cir::CastOp::verify() {
612 mlir::Type resType = getType();
613 mlir::Type srcType = getSrc().getType();
614
615 // Verify address space casts for pointer types. given that
616 // casts for within a different address space are illegal.
617 auto srcPtrTy = mlir::dyn_cast<cir::PointerType>(srcType);
618 auto resPtrTy = mlir::dyn_cast<cir::PointerType>(resType);
619 if (srcPtrTy && resPtrTy && (getKind() != cir::CastKind::address_space))
620 if (srcPtrTy.getAddrSpace() != resPtrTy.getAddrSpace()) {
621 return emitOpError() << "result type address space does not match the "
622 "address space of the operand";
623 }
624
625 cir::CastKind kind = getKind();
626 auto srcVTy = mlir::dyn_cast<cir::VectorType>(srcType);
627 auto resVTy = mlir::dyn_cast<cir::VectorType>(resType);
628 if (srcVTy && resVTy) {
629 if ((kind == cir::CastKind::int_to_float ||
630 kind == cir::CastKind::float_to_int) &&
631 srcVTy.getSize() != resVTy.getSize()) {
632 return emitOpError()
633 << "vector float-to-int and int-to-float casts require "
634 "source and destination vectors to have the same number of "
635 "elements";
636 }
637 // Use the element type of the vector to verify the cast kind. (Except for
638 // bitcast, see below.)
639 srcType = srcVTy.getElementType();
640 resType = resVTy.getElementType();
641 }
642
643 switch (getKind()) {
644 case cir::CastKind::int_to_bool: {
645 if (!mlir::isa<cir::BoolType>(resType))
646 return emitOpError() << "requires !cir.bool type for result";
647 if (!mlir::isa<cir::IntType>(srcType))
648 return emitOpError() << "requires !cir.int type for source";
649 return success();
650 }
651 case cir::CastKind::ptr_to_bool: {
652 if (!mlir::isa<cir::BoolType>(resType))
653 return emitOpError() << "requires !cir.bool type for result";
654 if (!mlir::isa<cir::PointerType>(srcType))
655 return emitOpError() << "requires !cir.ptr type for source";
656 return success();
657 }
658 case cir::CastKind::integral: {
659 if (!mlir::isa<cir::IntType>(resType))
660 return emitOpError() << "requires !cir.int type for result";
661 if (!mlir::isa<cir::IntType>(srcType))
662 return emitOpError() << "requires !cir.int type for source";
663 return success();
664 }
665 case cir::CastKind::array_to_ptrdecay: {
666 const auto arrayPtrTy = mlir::dyn_cast<cir::PointerType>(srcType);
667 const auto flatPtrTy = mlir::dyn_cast<cir::PointerType>(resType);
668 if (!arrayPtrTy || !flatPtrTy)
669 return emitOpError() << "requires !cir.ptr type for source and result";
670
671 // TODO(CIR): Make sure the AddrSpace of both types are equals
672 return success();
673 }
674 case cir::CastKind::bitcast: {
675 // Handle the pointer types first.
676 auto srcPtrTy = mlir::dyn_cast<cir::PointerType>(srcType);
677 auto resPtrTy = mlir::dyn_cast<cir::PointerType>(resType);
678
679 if (srcPtrTy && resPtrTy) {
680 return success();
681 }
682
683 return success();
684 }
685 case cir::CastKind::floating: {
686 if (!mlir::isa<cir::FPTypeInterface>(srcType) ||
687 !mlir::isa<cir::FPTypeInterface>(resType))
688 return emitOpError() << "requires !cir.float type for source and result";
689 return success();
690 }
691 case cir::CastKind::float_to_int: {
692 if (!mlir::isa<cir::FPTypeInterface>(srcType))
693 return emitOpError() << "requires !cir.float type for source";
694 if (!mlir::dyn_cast<cir::IntType>(resType))
695 return emitOpError() << "requires !cir.int type for result";
696 return success();
697 }
698 case cir::CastKind::int_to_ptr: {
699 if (!mlir::dyn_cast<cir::IntType>(srcType))
700 return emitOpError() << "requires !cir.int type for source";
701 if (!mlir::dyn_cast<cir::PointerType>(resType))
702 return emitOpError() << "requires !cir.ptr type for result";
703 return success();
704 }
705 case cir::CastKind::ptr_to_int: {
706 if (!mlir::dyn_cast<cir::PointerType>(srcType))
707 return emitOpError() << "requires !cir.ptr type for source";
708 if (!mlir::dyn_cast<cir::IntType>(resType))
709 return emitOpError() << "requires !cir.int type for result";
710 return success();
711 }
712 case cir::CastKind::float_to_bool: {
713 if (!mlir::isa<cir::FPTypeInterface>(srcType))
714 return emitOpError() << "requires !cir.float type for source";
715 if (!mlir::isa<cir::BoolType>(resType))
716 return emitOpError() << "requires !cir.bool type for result";
717 return success();
718 }
719 case cir::CastKind::bool_to_int: {
720 if (!mlir::isa<cir::BoolType>(srcType))
721 return emitOpError() << "requires !cir.bool type for source";
722 if (!mlir::isa<cir::IntType>(resType))
723 return emitOpError() << "requires !cir.int type for result";
724 return success();
725 }
726 case cir::CastKind::int_to_float: {
727 if (!mlir::isa<cir::IntType>(srcType))
728 return emitOpError() << "requires !cir.int type for source";
729 if (!mlir::isa<cir::FPTypeInterface>(resType))
730 return emitOpError() << "requires !cir.float type for result";
731 return success();
732 }
733 case cir::CastKind::bool_to_float: {
734 if (!mlir::isa<cir::BoolType>(srcType))
735 return emitOpError() << "requires !cir.bool type for source";
736 if (!mlir::isa<cir::FPTypeInterface>(resType))
737 return emitOpError() << "requires !cir.float type for result";
738 return success();
739 }
740 case cir::CastKind::address_space: {
741 auto srcPtrTy = mlir::dyn_cast<cir::PointerType>(srcType);
742 auto resPtrTy = mlir::dyn_cast<cir::PointerType>(resType);
743 if (!srcPtrTy || !resPtrTy)
744 return emitOpError() << "requires !cir.ptr type for source and result";
745 if (srcPtrTy.getPointee() != resPtrTy.getPointee())
746 return emitOpError() << "requires two types differ in addrspace only";
747 return success();
748 }
749 case cir::CastKind::float_to_complex: {
750 if (!mlir::isa<cir::FPTypeInterface>(srcType))
751 return emitOpError() << "requires !cir.float type for source";
752 auto resComplexTy = mlir::dyn_cast<cir::ComplexType>(resType);
753 if (!resComplexTy)
754 return emitOpError() << "requires !cir.complex type for result";
755 if (srcType != resComplexTy.getElementType())
756 return emitOpError() << "requires source type match result element type";
757 return success();
758 }
759 case cir::CastKind::int_to_complex: {
760 if (!mlir::isa<cir::IntType>(srcType))
761 return emitOpError() << "requires !cir.int type for source";
762 auto resComplexTy = mlir::dyn_cast<cir::ComplexType>(resType);
763 if (!resComplexTy)
764 return emitOpError() << "requires !cir.complex type for result";
765 if (srcType != resComplexTy.getElementType())
766 return emitOpError() << "requires source type match result element type";
767 return success();
768 }
769 case cir::CastKind::float_complex_to_real: {
770 auto srcComplexTy = mlir::dyn_cast<cir::ComplexType>(srcType);
771 if (!srcComplexTy)
772 return emitOpError() << "requires !cir.complex type for source";
773 if (!mlir::isa<cir::FPTypeInterface>(resType))
774 return emitOpError() << "requires !cir.float type for result";
775 if (srcComplexTy.getElementType() != resType)
776 return emitOpError() << "requires source element type match result type";
777 return success();
778 }
779 case cir::CastKind::int_complex_to_real: {
780 auto srcComplexTy = mlir::dyn_cast<cir::ComplexType>(srcType);
781 if (!srcComplexTy)
782 return emitOpError() << "requires !cir.complex type for source";
783 if (!mlir::isa<cir::IntType>(resType))
784 return emitOpError() << "requires !cir.int type for result";
785 if (srcComplexTy.getElementType() != resType)
786 return emitOpError() << "requires source element type match result type";
787 return success();
788 }
789 case cir::CastKind::float_complex_to_bool: {
790 auto srcComplexTy = mlir::dyn_cast<cir::ComplexType>(srcType);
791 if (!srcComplexTy || !srcComplexTy.isFloatingPointComplex())
792 return emitOpError()
793 << "requires floating point !cir.complex type for source";
794 if (!mlir::isa<cir::BoolType>(resType))
795 return emitOpError() << "requires !cir.bool type for result";
796 return success();
797 }
798 case cir::CastKind::int_complex_to_bool: {
799 auto srcComplexTy = mlir::dyn_cast<cir::ComplexType>(srcType);
800 if (!srcComplexTy || !srcComplexTy.isIntegerComplex())
801 return emitOpError()
802 << "requires floating point !cir.complex type for source";
803 if (!mlir::isa<cir::BoolType>(resType))
804 return emitOpError() << "requires !cir.bool type for result";
805 return success();
806 }
807 case cir::CastKind::float_complex: {
808 auto srcComplexTy = mlir::dyn_cast<cir::ComplexType>(srcType);
809 if (!srcComplexTy || !srcComplexTy.isFloatingPointComplex())
810 return emitOpError()
811 << "requires floating point !cir.complex type for source";
812 auto resComplexTy = mlir::dyn_cast<cir::ComplexType>(resType);
813 if (!resComplexTy || !resComplexTy.isFloatingPointComplex())
814 return emitOpError()
815 << "requires floating point !cir.complex type for result";
816 return success();
817 }
818 case cir::CastKind::float_complex_to_int_complex: {
819 auto srcComplexTy = mlir::dyn_cast<cir::ComplexType>(srcType);
820 if (!srcComplexTy || !srcComplexTy.isFloatingPointComplex())
821 return emitOpError()
822 << "requires floating point !cir.complex type for source";
823 auto resComplexTy = mlir::dyn_cast<cir::ComplexType>(resType);
824 if (!resComplexTy || !resComplexTy.isIntegerComplex())
825 return emitOpError() << "requires integer !cir.complex type for result";
826 return success();
827 }
828 case cir::CastKind::int_complex: {
829 auto srcComplexTy = mlir::dyn_cast<cir::ComplexType>(srcType);
830 if (!srcComplexTy || !srcComplexTy.isIntegerComplex())
831 return emitOpError() << "requires integer !cir.complex type for source";
832 auto resComplexTy = mlir::dyn_cast<cir::ComplexType>(resType);
833 if (!resComplexTy || !resComplexTy.isIntegerComplex())
834 return emitOpError() << "requires integer !cir.complex type for result";
835 return success();
836 }
837 case cir::CastKind::int_complex_to_float_complex: {
838 auto srcComplexTy = mlir::dyn_cast<cir::ComplexType>(srcType);
839 if (!srcComplexTy || !srcComplexTy.isIntegerComplex())
840 return emitOpError() << "requires integer !cir.complex type for source";
841 auto resComplexTy = mlir::dyn_cast<cir::ComplexType>(resType);
842 if (!resComplexTy || !resComplexTy.isFloatingPointComplex())
843 return emitOpError()
844 << "requires floating point !cir.complex type for result";
845 return success();
846 }
847 case cir::CastKind::member_ptr_to_bool: {
848 if (!mlir::isa<cir::DataMemberType, cir::MethodType>(srcType))
849 return emitOpError()
850 << "requires !cir.data_member or !cir.method type for source";
851 if (!mlir::isa<cir::BoolType>(resType))
852 return emitOpError() << "requires !cir.bool type for result";
853 return success();
854 }
855 }
856 llvm_unreachable("Unknown CastOp kind?");
857}
858
859static bool isIntOrBoolCast(cir::CastOp op) {
860 auto kind = op.getKind();
861 return kind == cir::CastKind::bool_to_int ||
862 kind == cir::CastKind::int_to_bool || kind == cir::CastKind::integral;
863}
864
865static bool isCirFunctionPointerType(mlir::Type ty) {
866 const auto ptrTy = mlir::dyn_cast<cir::PointerType>(ty);
867 return ptrTy && mlir::isa<cir::FuncType>(ptrTy.getPointee());
868}
869
870static Value tryFoldCastChain(cir::CastOp op) {
871 cir::CastOp head = op, tail = op;
872
873 while (op) {
874 if (!isIntOrBoolCast(op))
875 break;
876 head = op;
877 op = head.getSrc().getDefiningOp<cir::CastOp>();
878 }
879
880 if (head != tail) {
881 // if bool_to_int -> ... -> int_to_bool: take the bool
882 // as we had it was before all casts
883 if (head.getKind() == cir::CastKind::bool_to_int &&
884 tail.getKind() == cir::CastKind::int_to_bool)
885 return head.getSrc();
886
887 // if int_to_bool -> ... -> int_to_bool: take the result
888 // of the first one, as no other casts (and ext casts as well)
889 // don't change the first result
890 if (head.getKind() == cir::CastKind::int_to_bool &&
891 tail.getKind() == cir::CastKind::int_to_bool)
892 return head.getResult();
893
894 return {};
895 }
896
897 // Bitcast round-trip on function pointers: T0 -> T1 -> T0 (e.g. no-proto
898 // redeclaration vs. actual prototype). Restrict to function pointers so
899 // other pointer bitcast chains are unchanged.
900 if (tail.getKind() == cir::CastKind::bitcast) {
901 auto *inner = tail.getSrc().getDefiningOp();
902 if (inner && isCirFunctionPointerType(tail.getType())) {
903 auto innerCast = mlir::dyn_cast<cir::CastOp>(inner);
904 if (innerCast && innerCast.getKind() == cir::CastKind::bitcast &&
905 innerCast.getSrc().getType() == tail.getType() &&
906 innerCast.getType() == tail.getSrc().getType()) {
907 return innerCast.getSrc();
908 }
909 }
910 }
911
912 return {};
913}
914
915OpFoldResult cir::CastOp::fold(FoldAdaptor adaptor) {
916 if (mlir::isa_and_present<cir::PoisonAttr>(adaptor.getSrc())) {
917 // Propagate poison value
918 return cir::PoisonAttr::get(getContext(), getType());
919 }
920
921 if (getSrc().getType() == getType()) {
922 switch (getKind()) {
923 case cir::CastKind::integral: {
925 auto foldOrder = getSrc().getDefiningOp()->fold(foldResults);
926 if (foldOrder.succeeded() && mlir::isa<mlir::Attribute>(foldResults[0]))
927 return mlir::cast<mlir::Attribute>(foldResults[0]);
928 return {};
929 }
930 case cir::CastKind::bitcast:
931 case cir::CastKind::address_space:
932 case cir::CastKind::float_complex:
933 case cir::CastKind::int_complex: {
934 return getSrc();
935 }
936 default:
937 return {};
938 }
939 }
940
941 // Handle cases where a chain of casts cancel out.
942 Value result = tryFoldCastChain(*this);
943 if (result)
944 return result;
945
946 // Handle simple constant casts.
947 if (auto srcConst = getSrc().getDefiningOp<cir::ConstantOp>()) {
948 switch (getKind()) {
949 case cir::CastKind::integral: {
950 mlir::Type srcTy = getSrc().getType();
951 // Don't try to fold vector casts for now.
952 assert(mlir::isa<cir::VectorType>(srcTy) ==
953 mlir::isa<cir::VectorType>(getType()));
954 if (mlir::isa<cir::VectorType>(srcTy))
955 break;
956
957 auto srcIntTy = mlir::cast<cir::IntType>(srcTy);
958 auto dstIntTy = mlir::cast<cir::IntType>(getType());
959 APInt newVal =
960 srcIntTy.isSigned()
961 ? srcConst.getIntValue().sextOrTrunc(dstIntTy.getWidth())
962 : srcConst.getIntValue().zextOrTrunc(dstIntTy.getWidth());
963 return cir::IntAttr::get(dstIntTy, newVal);
964 }
965 default:
966 break;
967 }
968 }
969 return {};
970}
971
972//===----------------------------------------------------------------------===//
973// CallOp
974//===----------------------------------------------------------------------===//
975
976mlir::OperandRange cir::CallOp::getArgOperands() {
977 if (isIndirect())
978 return getArgs().drop_front(1);
979 return getArgs();
980}
981
982mlir::MutableOperandRange cir::CallOp::getArgOperandsMutable() {
983 mlir::MutableOperandRange args = getArgsMutable();
984 if (isIndirect())
985 return args.slice(1, args.size() - 1);
986 return args;
987}
988
989mlir::Value cir::CallOp::getIndirectCall() {
990 assert(isIndirect());
991 return getOperand(0);
992}
993
994/// Return the operand at index 'i'.
995Value cir::CallOp::getArgOperand(unsigned i) {
996 if (isIndirect())
997 ++i;
998 return getOperand(i);
999}
1000
1001/// Return the number of operands.
1002unsigned cir::CallOp::getNumArgOperands() {
1003 if (isIndirect())
1004 return this->getOperation()->getNumOperands() - 1;
1005 return this->getOperation()->getNumOperands();
1006}
1007
1008static mlir::ParseResult
1009parseTryCallDestinations(mlir::OpAsmParser &parser,
1010 mlir::OperationState &result) {
1011 mlir::Block *normalDestSuccessor;
1012 if (parser.parseSuccessor(normalDestSuccessor))
1013 return mlir::failure();
1014
1015 if (parser.parseComma())
1016 return mlir::failure();
1017
1018 mlir::Block *unwindDestSuccessor;
1019 if (parser.parseSuccessor(unwindDestSuccessor))
1020 return mlir::failure();
1021
1022 result.addSuccessors(normalDestSuccessor);
1023 result.addSuccessors(unwindDestSuccessor);
1024 return mlir::success();
1025}
1026
1027static mlir::ParseResult parseCallCommon(mlir::OpAsmParser &parser,
1028 mlir::OperationState &result,
1029 bool hasDestinationBlocks = false) {
1031 llvm::SMLoc opsLoc;
1032 mlir::FlatSymbolRefAttr calleeAttr;
1033
1034 // If we cannot parse a string callee, it means this is an indirect call.
1035 if (!parser
1036 .parseOptionalAttribute(calleeAttr, CIRDialect::getCalleeAttrName(),
1037 result.attributes)
1038 .has_value()) {
1039 OpAsmParser::UnresolvedOperand indirectVal;
1040 // Do not resolve right now, since we need to figure out the type
1041 if (parser.parseOperand(indirectVal).failed())
1042 return failure();
1043 ops.push_back(indirectVal);
1044 }
1045
1046 if (parser.parseLParen())
1047 return mlir::failure();
1048
1049 opsLoc = parser.getCurrentLocation();
1050 if (parser.parseOperandList(ops))
1051 return mlir::failure();
1052 if (parser.parseRParen())
1053 return mlir::failure();
1054
1055 if (hasDestinationBlocks &&
1056 parseTryCallDestinations(parser, result).failed()) {
1057 return ::mlir::failure();
1058 }
1059
1060 if (parser.parseOptionalKeyword("musttail").succeeded())
1061 result.addAttribute(CIRDialect::getMustTailAttrName(),
1062 mlir::UnitAttr::get(parser.getContext()));
1063
1064 if (parser.parseOptionalKeyword("nothrow").succeeded())
1065 result.addAttribute(CIRDialect::getNoThrowAttrName(),
1066 mlir::UnitAttr::get(parser.getContext()));
1067
1068 if (parser.parseOptionalKeyword("side_effect").succeeded()) {
1069 if (parser.parseLParen().failed())
1070 return failure();
1071 cir::SideEffect sideEffect;
1072 if (parseCIRKeyword<cir::SideEffect>(parser, sideEffect).failed())
1073 return failure();
1074 if (parser.parseRParen().failed())
1075 return failure();
1076 auto attr = cir::SideEffectAttr::get(parser.getContext(), sideEffect);
1077 result.addAttribute(CIRDialect::getSideEffectAttrName(), attr);
1078 }
1079
1080 if (parser.parseOptionalAttrDict(result.attributes))
1081 return ::mlir::failure();
1082
1083 if (parser.parseColon())
1084 return ::mlir::failure();
1085
1086 SmallVector<Type> argTypes;
1088 SmallVector<Type> resultTypes;
1089 SmallVector<DictionaryAttr> resultAttrs;
1090 if (call_interface_impl::parseFunctionSignature(parser, argTypes, argAttrs,
1091 resultTypes, resultAttrs))
1092 return mlir::failure();
1093
1094 if (resultTypes.size() > 1 || resultAttrs.size() > 1)
1095 return parser.emitError(
1096 parser.getCurrentLocation(),
1097 "functions with multiple return types are not supported");
1098
1099 result.addTypes(resultTypes);
1100
1101 if (parser.resolveOperands(ops, argTypes, opsLoc, result.operands))
1102 return mlir::failure();
1103
1104 if (!resultAttrs.empty() && resultAttrs[0])
1105 result.addAttribute(
1106 CIRDialect::getResAttrsAttrName(),
1107 mlir::ArrayAttr::get(parser.getContext(), {resultAttrs[0]}));
1108
1109 // ArrayAttr requires a vector of 'Attribute', so we have to do the conversion
1110 // here into a separate collection.
1111 llvm::SmallVector<Attribute> convertedArgAttrs;
1112 bool argAttrsEmpty = true;
1113
1114 llvm::transform(argAttrs, std::back_inserter(convertedArgAttrs),
1115 [&](DictionaryAttr da) -> mlir::Attribute {
1116 if (da)
1117 argAttrsEmpty = false;
1118 return da;
1119 });
1120
1121 if (!argAttrsEmpty) {
1122 llvm::ArrayRef argAttrsRef = convertedArgAttrs;
1123 if (!calleeAttr) {
1124 // Fixup for indirect calls, which get an extra entry in the 'args' for
1125 // the indirect type, which doesn't get attributes.
1126 argAttrsRef = argAttrsRef.drop_front();
1127 }
1128 result.addAttribute(CIRDialect::getArgAttrsAttrName(),
1129 mlir::ArrayAttr::get(parser.getContext(), argAttrsRef));
1130 }
1131
1132 return mlir::success();
1133}
1134
1135static void
1136printCallCommon(mlir::Operation *op, mlir::FlatSymbolRefAttr calleeSym,
1137 mlir::Value indirectCallee, mlir::OpAsmPrinter &printer,
1138 bool isNothrow, cir::SideEffect sideEffect, ArrayAttr argAttrs,
1139 ArrayAttr resAttrs, mlir::Block *normalDest = nullptr,
1140 mlir::Block *unwindDest = nullptr) {
1141 printer << ' ';
1142
1143 auto callLikeOp = mlir::cast<cir::CIRCallOpInterface>(op);
1144 auto ops = callLikeOp.getArgOperands();
1145
1146 if (calleeSym) {
1147 // Direct calls
1148 printer.printAttributeWithoutType(calleeSym);
1149 } else {
1150 // Indirect calls
1151 assert(indirectCallee);
1152 printer << indirectCallee;
1153 }
1154
1155 printer << "(" << ops << ")";
1156
1157 if (normalDest) {
1158 assert(unwindDest && "expected two successors");
1159 auto tryCall = cast<cir::TryCallOp>(op);
1160 printer << ' ' << tryCall.getNormalDest();
1161 printer << ",";
1162 printer << ' ';
1163 printer << tryCall.getUnwindDest();
1164 }
1165
1166 if (op->hasAttr(CIRDialect::getMustTailAttrName()))
1167 printer << " musttail";
1168
1169 if (isNothrow)
1170 printer << " nothrow";
1171
1172 if (sideEffect != cir::SideEffect::All) {
1173 printer << " side_effect(";
1174 printer << stringifySideEffect(sideEffect);
1175 printer << ")";
1176 }
1177
1179 CIRDialect::getCalleeAttrName(),
1180 CIRDialect::getMustTailAttrName(),
1181 CIRDialect::getNoThrowAttrName(),
1182 CIRDialect::getSideEffectAttrName(),
1183 CIRDialect::getOperandSegmentSizesAttrName(),
1184 llvm::StringRef("res_attrs"),
1185 llvm::StringRef("arg_attrs")};
1186 printer.printOptionalAttrDict(op->getAttrs(), elidedAttrs);
1187 printer << " : ";
1188 if (calleeSym || !argAttrs) {
1189 call_interface_impl::printFunctionSignature(
1190 printer, op->getOperands().getTypes(), argAttrs,
1191 /*isVariadic=*/false, op->getResultTypes(), resAttrs);
1192 } else {
1193 // indirect function calls use an 'arg' type for the type of its indirect
1194 // argument. However, we don't store a similar attribute collection. In
1195 // order to make `printFunctionSignature` have the attributes line up, we
1196 // have to make a 'shimmed' copy of the attributes that have a blank set of
1197 // attributes for the indirect argument.
1198 llvm::SmallVector<Attribute> shimmedArgAttrs;
1199 shimmedArgAttrs.push_back(mlir::DictionaryAttr::get(op->getContext(), {}));
1200 shimmedArgAttrs.append(argAttrs.begin(), argAttrs.end());
1201 call_interface_impl::printFunctionSignature(
1202 printer, op->getOperands().getTypes(),
1203 mlir::ArrayAttr::get(op->getContext(), shimmedArgAttrs),
1204 /*isVariadic=*/false, op->getResultTypes(), resAttrs);
1205 }
1206}
1207
1208mlir::ParseResult cir::CallOp::parse(mlir::OpAsmParser &parser,
1209 mlir::OperationState &result) {
1210 return parseCallCommon(parser, result);
1211}
1212
1213void cir::CallOp::print(mlir::OpAsmPrinter &p) {
1214 mlir::Value indirectCallee = isIndirect() ? getIndirectCall() : nullptr;
1215 cir::SideEffect sideEffect = getSideEffect();
1216 printCallCommon(*this, getCalleeAttr(), indirectCallee, p, getNothrow(),
1217 sideEffect, getArgAttrsAttr(), getResAttrsAttr());
1218}
1219
1220static LogicalResult
1221verifyCallCommInSymbolUses(mlir::Operation *op,
1222 SymbolTableCollection &symbolTable) {
1223 auto fnAttr =
1224 op->getAttrOfType<FlatSymbolRefAttr>(CIRDialect::getCalleeAttrName());
1225 if (!fnAttr) {
1226 // This is an indirect call, thus we don't have to check the symbol uses.
1227 return mlir::success();
1228 }
1229
1230 auto fn = symbolTable.lookupNearestSymbolFrom<cir::FuncOp>(op, fnAttr);
1231 if (!fn)
1232 return op->emitOpError() << "'" << fnAttr.getValue()
1233 << "' does not reference a valid function";
1234
1235 auto callIf = dyn_cast<cir::CIRCallOpInterface>(op);
1236 assert(callIf && "expected CIR call interface to be always available");
1237
1238 // Verify that the operand and result types match the callee. Note that
1239 // argument-checking is disabled for functions without a prototype.
1240 auto fnType = fn.getFunctionType();
1241 if (!fn.getNoProto()) {
1242 unsigned numCallOperands = callIf.getNumArgOperands();
1243 unsigned numFnOpOperands = fnType.getNumInputs();
1244
1245 if (!fnType.isVarArg() && numCallOperands != numFnOpOperands)
1246 return op->emitOpError("incorrect number of operands for callee");
1247 if (fnType.isVarArg() && numCallOperands < numFnOpOperands)
1248 return op->emitOpError("too few operands for callee");
1249
1250 for (unsigned i = 0, e = numFnOpOperands; i != e; ++i)
1251 if (callIf.getArgOperand(i).getType() != fnType.getInput(i))
1252 return op->emitOpError("operand type mismatch: expected operand type ")
1253 << fnType.getInput(i) << ", but provided "
1254 << op->getOperand(i).getType() << " for operand number " << i;
1255 }
1256
1258
1259 // Void function must not return any results.
1260 if (fnType.hasVoidReturn() && op->getNumResults() != 0)
1261 return op->emitOpError("callee returns void but call has results");
1262
1263 // Non-void function calls must return exactly one result.
1264 if (!fnType.hasVoidReturn() && op->getNumResults() != 1)
1265 return op->emitOpError("incorrect number of results for callee");
1266
1267 // Parent function and return value types must match.
1268 if (!fnType.hasVoidReturn() &&
1269 op->getResultTypes().front() != fnType.getReturnType()) {
1270 return op->emitOpError("result type mismatch: expected ")
1271 << fnType.getReturnType() << ", but provided "
1272 << op->getResult(0).getType();
1273 }
1274
1275 return mlir::success();
1276}
1277
1278LogicalResult
1279cir::CallOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1280 return verifyCallCommInSymbolUses(*this, symbolTable);
1281}
1282
1283//===----------------------------------------------------------------------===//
1284// TryCallOp
1285//===----------------------------------------------------------------------===//
1286
1287mlir::OperandRange cir::TryCallOp::getArgOperands() {
1288 if (isIndirect())
1289 return getArgs().drop_front(1);
1290 return getArgs();
1291}
1292
1293mlir::MutableOperandRange cir::TryCallOp::getArgOperandsMutable() {
1294 mlir::MutableOperandRange args = getArgsMutable();
1295 if (isIndirect())
1296 return args.slice(1, args.size() - 1);
1297 return args;
1298}
1299
1300mlir::Value cir::TryCallOp::getIndirectCall() {
1301 assert(isIndirect());
1302 return getOperand(0);
1303}
1304
1305/// Return the operand at index 'i'.
1306Value cir::TryCallOp::getArgOperand(unsigned i) {
1307 if (isIndirect())
1308 ++i;
1309 return getOperand(i);
1310}
1311
1312/// Return the number of operands.
1313unsigned cir::TryCallOp::getNumArgOperands() {
1314 if (isIndirect())
1315 return this->getOperation()->getNumOperands() - 1;
1316 return this->getOperation()->getNumOperands();
1317}
1318
1319LogicalResult
1320cir::TryCallOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1321 return verifyCallCommInSymbolUses(*this, symbolTable);
1322}
1323
1324mlir::ParseResult cir::TryCallOp::parse(mlir::OpAsmParser &parser,
1325 mlir::OperationState &result) {
1326 return parseCallCommon(parser, result, /*hasDestinationBlocks=*/true);
1327}
1328
1329void cir::TryCallOp::print(::mlir::OpAsmPrinter &p) {
1330 mlir::Value indirectCallee = isIndirect() ? getIndirectCall() : nullptr;
1331 cir::SideEffect sideEffect = getSideEffect();
1332 printCallCommon(*this, getCalleeAttr(), indirectCallee, p, getNothrow(),
1333 sideEffect, getArgAttrsAttr(), getResAttrsAttr(),
1334 getNormalDest(), getUnwindDest());
1335}
1336
1337//===----------------------------------------------------------------------===//
1338// ReturnOp
1339//===----------------------------------------------------------------------===//
1340
1341static mlir::LogicalResult checkReturnAndFunction(cir::ReturnOp op,
1342 cir::FuncOp function) {
1343 // ReturnOps currently only have a single optional operand.
1344 if (op.getNumOperands() > 1)
1345 return op.emitOpError() << "expects at most 1 return operand";
1346
1347 // Ensure returned type matches the function signature.
1348 auto expectedTy = function.getFunctionType().getReturnType();
1349 auto actualTy =
1350 (op.getNumOperands() == 0 ? cir::VoidType::get(op.getContext())
1351 : op.getOperand(0).getType());
1352 if (actualTy != expectedTy)
1353 return op.emitOpError() << "returns " << actualTy
1354 << " but enclosing function returns " << expectedTy;
1355
1356 return mlir::success();
1357}
1358
1359mlir::LogicalResult cir::ReturnOp::verify() {
1360 // Returns can be present in multiple different scopes, get the
1361 // wrapping function and start from there.
1362 auto *fnOp = getOperation()->getParentOp();
1363 while (!isa<cir::FuncOp>(fnOp))
1364 fnOp = fnOp->getParentOp();
1365
1366 // Make sure return types match function return type.
1367 if (checkReturnAndFunction(*this, cast<cir::FuncOp>(fnOp)).failed())
1368 return failure();
1369
1370 return success();
1371}
1372
1373//===----------------------------------------------------------------------===//
1374// IfOp
1375//===----------------------------------------------------------------------===//
1376
1377ParseResult cir::IfOp::parse(OpAsmParser &parser, OperationState &result) {
1378 // create the regions for 'then'.
1379 result.regions.reserve(2);
1380 Region *thenRegion = result.addRegion();
1381 Region *elseRegion = result.addRegion();
1382
1383 mlir::Builder &builder = parser.getBuilder();
1384 OpAsmParser::UnresolvedOperand cond;
1385 Type boolType = cir::BoolType::get(builder.getContext());
1386
1387 if (parser.parseOperand(cond) ||
1388 parser.resolveOperand(cond, boolType, result.operands))
1389 return failure();
1390
1391 // Parse 'then' region.
1392 mlir::SMLoc parseThenLoc = parser.getCurrentLocation();
1393 if (parser.parseRegion(*thenRegion, /*arguments=*/{}, /*argTypes=*/{}))
1394 return failure();
1395
1396 if (ensureRegionTerm(parser, *thenRegion, parseThenLoc).failed())
1397 return failure();
1398
1399 // If we find an 'else' keyword, parse the 'else' region.
1400 if (!parser.parseOptionalKeyword("else")) {
1401 mlir::SMLoc parseElseLoc = parser.getCurrentLocation();
1402 if (parser.parseRegion(*elseRegion, /*arguments=*/{}, /*argTypes=*/{}))
1403 return failure();
1404 if (ensureRegionTerm(parser, *elseRegion, parseElseLoc).failed())
1405 return failure();
1406 }
1407
1408 // Parse the optional attribute list.
1409 if (parser.parseOptionalAttrDict(result.attributes))
1410 return failure();
1411 return success();
1412}
1413
1414void cir::IfOp::print(OpAsmPrinter &p) {
1415 p << " " << getCondition() << " ";
1416 mlir::Region &thenRegion = this->getThenRegion();
1417 p.printRegion(thenRegion,
1418 /*printEntryBlockArgs=*/false,
1419 /*printBlockTerminators=*/!omitRegionTerm(thenRegion));
1420
1421 // Print the 'else' regions if it exists and has a block.
1422 mlir::Region &elseRegion = this->getElseRegion();
1423 if (!elseRegion.empty()) {
1424 p << " else ";
1425 p.printRegion(elseRegion,
1426 /*printEntryBlockArgs=*/false,
1427 /*printBlockTerminators=*/!omitRegionTerm(elseRegion));
1428 }
1429
1430 p.printOptionalAttrDict(getOperation()->getAttrs());
1431}
1432
1433/// Default callback for IfOp builders.
1434void cir::buildTerminatedBody(OpBuilder &builder, Location loc) {
1435 // add cir.yield to end of the block
1436 cir::YieldOp::create(builder, loc);
1437}
1438
1439/// Given the region at `index`, or the parent operation if `index` is None,
1440/// return the successor regions. These are the regions that may be selected
1441/// during the flow of control. `operands` is a set of optional attributes that
1442/// correspond to a constant value for each operand, or null if that operand is
1443/// not a constant.
1444void cir::IfOp::getSuccessorRegions(mlir::RegionBranchPoint point,
1445 SmallVectorImpl<RegionSuccessor> &regions) {
1446 // The `then` and the `else` region branch back to the parent operation.
1447 if (!point.isParent()) {
1448 regions.emplace_back(getOperation());
1449 return;
1450 }
1451
1452 // Don't consider the else region if it is empty.
1453 Region *elseRegion = &this->getElseRegion();
1454 if (elseRegion->empty())
1455 elseRegion = nullptr;
1456
1457 // If the condition isn't constant, both regions may be executed.
1458 regions.push_back(RegionSuccessor(&getThenRegion()));
1459 if (elseRegion)
1460 regions.push_back(RegionSuccessor(elseRegion));
1461 else
1462 regions.emplace_back(getOperation());
1463}
1464
1465mlir::ValueRange cir::IfOp::getSuccessorInputs(RegionSuccessor successor) {
1466 return successor.isOperation() ? ValueRange(getOperation()->getResults())
1467 : ValueRange();
1468}
1469
1470void cir::IfOp::build(OpBuilder &builder, OperationState &result, Value cond,
1471 bool withElseRegion, BuilderCallbackRef thenBuilder,
1472 BuilderCallbackRef elseBuilder) {
1473 assert(thenBuilder && "the builder callback for 'then' must be present");
1474 result.addOperands(cond);
1475
1476 OpBuilder::InsertionGuard guard(builder);
1477 Region *thenRegion = result.addRegion();
1478 builder.createBlock(thenRegion);
1479 thenBuilder(builder, result.location);
1480
1481 Region *elseRegion = result.addRegion();
1482 if (!withElseRegion)
1483 return;
1484
1485 builder.createBlock(elseRegion);
1486 elseBuilder(builder, result.location);
1487}
1488
1489//===----------------------------------------------------------------------===//
1490// ScopeOp
1491//===----------------------------------------------------------------------===//
1492
1493/// Given the region at `index`, or the parent operation if `index` is None,
1494/// return the successor regions. These are the regions that may be selected
1495/// during the flow of control. `operands` is a set of optional attributes
1496/// that correspond to a constant value for each operand, or null if that
1497/// operand is not a constant.
1498void cir::ScopeOp::getSuccessorRegions(
1499 mlir::RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
1500 // The only region always branch back to the parent operation.
1501 if (!point.isParent()) {
1502 regions.emplace_back(getOperation());
1503 return;
1504 }
1505
1506 // If the condition isn't constant, both regions may be executed.
1507 regions.push_back(RegionSuccessor(&getScopeRegion()));
1508}
1509
1510mlir::ValueRange cir::ScopeOp::getSuccessorInputs(RegionSuccessor successor) {
1511 return successor.isOperation() ? ValueRange(getOperation()->getResults())
1512 : ValueRange();
1513}
1514
1515void cir::ScopeOp::build(
1516 OpBuilder &builder, OperationState &result,
1517 function_ref<void(OpBuilder &, Type &, Location)> scopeBuilder) {
1518 assert(scopeBuilder && "the builder callback for 'then' must be present");
1519
1520 OpBuilder::InsertionGuard guard(builder);
1521 Region *scopeRegion = result.addRegion();
1522 builder.createBlock(scopeRegion);
1524
1525 mlir::Type yieldTy;
1526 scopeBuilder(builder, yieldTy, result.location);
1527
1528 if (yieldTy)
1529 result.addTypes(TypeRange{yieldTy});
1530}
1531
1532void cir::ScopeOp::build(
1533 OpBuilder &builder, OperationState &result,
1534 function_ref<void(OpBuilder &, Location)> scopeBuilder) {
1535 assert(scopeBuilder && "the builder callback for 'then' must be present");
1536 OpBuilder::InsertionGuard guard(builder);
1537 Region *scopeRegion = result.addRegion();
1538 builder.createBlock(scopeRegion);
1540 scopeBuilder(builder, result.location);
1541}
1542
1543LogicalResult cir::ScopeOp::verify() {
1544 if (getRegion().empty()) {
1545 return emitOpError() << "cir.scope must not be empty since it should "
1546 "include at least an implicit cir.yield ";
1547 }
1548
1549 mlir::Block &lastBlock = getRegion().back();
1550 if (lastBlock.empty() || !lastBlock.mightHaveTerminator() ||
1551 !lastBlock.getTerminator()->hasTrait<OpTrait::IsTerminator>())
1552 return emitOpError() << "last block of cir.scope must be terminated";
1553 return success();
1554}
1555
1556LogicalResult cir::ScopeOp::fold(FoldAdaptor /*adaptor*/,
1557 SmallVectorImpl<OpFoldResult> &results) {
1558 // Only fold "trivial" scopes: a single block containing only a `cir.yield`.
1559 if (!getRegion().hasOneBlock())
1560 return failure();
1561 Block &block = getRegion().front();
1562 if (block.getOperations().size() != 1)
1563 return failure();
1564
1565 auto yield = dyn_cast<cir::YieldOp>(block.front());
1566 if (!yield)
1567 return failure();
1568
1569 // Only fold when the scope produces a value.
1570 if (getNumResults() != 1 || yield.getNumOperands() != 1)
1571 return failure();
1572
1573 results.push_back(yield.getOperand(0));
1574 return success();
1575}
1576
1577//===----------------------------------------------------------------------===//
1578// CleanupScopeOp
1579//===----------------------------------------------------------------------===//
1580
1581void cir::CleanupScopeOp::getSuccessorRegions(
1582 mlir::RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
1583 if (!point.isParent()) {
1584 regions.emplace_back(getOperation());
1585 return;
1586 }
1587
1588 // Execution always proceeds from the body region to the cleanup region.
1589 regions.push_back(RegionSuccessor(&getBodyRegion()));
1590 regions.push_back(RegionSuccessor(&getCleanupRegion()));
1591}
1592
1593mlir::ValueRange
1594cir::CleanupScopeOp::getSuccessorInputs(RegionSuccessor successor) {
1595 return ValueRange();
1596}
1597
1598LogicalResult cir::CleanupScopeOp::canonicalize(CleanupScopeOp op,
1599 PatternRewriter &rewriter) {
1600 auto isRegionTrivial = [](Region &region) {
1601 assert(!region.empty() && "CleanupScopeOp regions must not be empty");
1602 if (!region.hasOneBlock())
1603 return false;
1604 Block &block = llvm::getSingleElement(region);
1605 return llvm::hasSingleElement(block) &&
1606 isa<cir::YieldOp>(llvm::getSingleElement(block));
1607 };
1608
1609 Region &body = op.getBodyRegion();
1610 Region &cleanup = op.getCleanupRegion();
1611
1612 // An EH-only cleanup scope with an empty body can never trigger its cleanup
1613 // region — there are no operations in the body that could throw. Erase it.
1614 if (op.getCleanupKind() == CleanupKind::EH && isRegionTrivial(body)) {
1615 rewriter.eraseOp(op);
1616 return success();
1617 }
1618
1619 // A cleanup scope with a trivial cleanup region has no cleanup to perform.
1620 // Inline the body into the parent block and erase the scope.
1621 if (!isRegionTrivial(cleanup) || !body.hasOneBlock())
1622 return failure();
1623
1624 Block &bodyBlock = body.front();
1625 if (!isa<cir::YieldOp>(bodyBlock.getTerminator()))
1626 return failure();
1627
1628 Operation *yield = bodyBlock.getTerminator();
1629 rewriter.inlineBlockBefore(&bodyBlock, op);
1630 rewriter.eraseOp(yield);
1631 rewriter.eraseOp(op);
1632 return success();
1633}
1634
1635void cir::CleanupScopeOp::build(
1636 OpBuilder &builder, OperationState &result, CleanupKind cleanupKind,
1637 function_ref<void(OpBuilder &, Location)> bodyBuilder,
1638 function_ref<void(OpBuilder &, Location)> cleanupBuilder) {
1639 result.addAttribute(getCleanupKindAttrName(result.name),
1640 CleanupKindAttr::get(builder.getContext(), cleanupKind));
1641
1642 OpBuilder::InsertionGuard guard(builder);
1643
1644 // Build body region.
1645 Region *bodyRegion = result.addRegion();
1646 builder.createBlock(bodyRegion);
1647 if (bodyBuilder)
1648 bodyBuilder(builder, result.location);
1649
1650 // Build cleanup region.
1651 Region *cleanupRegion = result.addRegion();
1652 builder.createBlock(cleanupRegion);
1653 if (cleanupBuilder)
1654 cleanupBuilder(builder, result.location);
1655}
1656
1657//===----------------------------------------------------------------------===//
1658// BrOp
1659//===----------------------------------------------------------------------===//
1660
1661/// Merges blocks connected by a unique unconditional branch.
1662///
1663/// ^bb0: ^bb0:
1664/// ... ...
1665/// cir.br ^bb1 => ...
1666/// ^bb1: cir.return
1667/// ...
1668/// cir.return
1669LogicalResult cir::BrOp::canonicalize(BrOp op, PatternRewriter &rewriter) {
1670 Block *src = op->getBlock();
1671 Block *dst = op.getDest();
1672
1673 // Do not fold self-loops.
1674 if (src == dst)
1675 return failure();
1676
1677 // Only merge when this is the unique edge between the blocks.
1678 if (src->getNumSuccessors() != 1 || dst->getSinglePredecessor() != src)
1679 return failure();
1680
1681 // Don't merge blocks that start with LabelOp or IndirectBrOp.
1682 // This is to avoid merging blocks that have an indirect predecessor.
1683 if (isa<cir::LabelOp, cir::IndirectBrOp>(dst->front()))
1684 return failure();
1685
1686 auto operands = op.getDestOperands();
1687 rewriter.eraseOp(op);
1688 rewriter.mergeBlocks(dst, src, operands);
1689 return success();
1690}
1691
1692mlir::SuccessorOperands cir::BrOp::getSuccessorOperands(unsigned index) {
1693 assert(index == 0 && "invalid successor index");
1694 return mlir::SuccessorOperands(getDestOperandsMutable());
1695}
1696
1697Block *cir::BrOp::getSuccessorForOperands(ArrayRef<Attribute>) {
1698 return getDest();
1699}
1700
1701//===----------------------------------------------------------------------===//
// IndirectBrOp
1703//===----------------------------------------------------------------------===//
1704
1705mlir::SuccessorOperands
1706cir::IndirectBrOp::getSuccessorOperands(unsigned index) {
1707 assert(index < getNumSuccessors() && "invalid successor index");
1708 return mlir::SuccessorOperands(getSuccOperandsMutable()[index]);
1709}
1710
1712 OpAsmParser &parser, Type &flagType,
1713 SmallVectorImpl<Block *> &succOperandBlocks,
1714 SmallVectorImpl<SmallVector<OpAsmParser::UnresolvedOperand>> &succOperands,
1715 SmallVectorImpl<SmallVector<Type>> &succOperandsTypes) {
     // Parses a square-bracket-delimited, comma-separated list of successors.
     // Each entry is a successor block optionally followed by
     // `(operands : types)`. For every entry one element is appended to each
     // of the three output vectors, so they stay index-aligned; an entry
     // without a parenthesized list contributes empty operand/type vectors.
     // `flagType` is not read in the visible body.
1716 if (failed(parser.parseCommaSeparatedList(
1717 OpAsmParser::Delimiter::Square,
1718 [&]() {
1719 Block *destination = nullptr;
1720 SmallVector<OpAsmParser::UnresolvedOperand> operands;
1721 SmallVector<Type> operandTypes;
1722
1723 if (parser.parseSuccessor(destination).failed())
1724 return failure();
1725
     // The operand list is optional and only present after a `(`.
1726 if (succeeded(parser.parseOptionalLParen())) {
1727 if (failed(parser.parseOperandList(
1728 operands, OpAsmParser::Delimiter::None)) ||
1729 failed(parser.parseColonTypeList(operandTypes)) ||
1730 failed(parser.parseRParen()))
1731 return failure();
1732 }
     // Record the entry only after it has parsed completely.
1733 succOperandBlocks.push_back(destination);
1734 succOperands.emplace_back(operands);
1735 succOperandsTypes.emplace_back(operandTypes);
1736 return success();
1737 },
1738 "successor blocks")))
1739 return failure();
1740 return success();
1741}
1742
1743void printIndirectBrOpSucessors(OpAsmPrinter &p, cir::IndirectBrOp op,
1744 Type flagType, SuccessorRange succs,
1745 OperandRangeRange succOperands,
1746 const TypeRangeRange &succOperandsTypes) {
1747 p << "[";
1748 llvm::interleave(
1749 llvm::zip(succs, succOperands),
1750 [&](auto i) {
1751 p.printNewline();
1752 p.printSuccessorAndUseList(std::get<0>(i), std::get<1>(i));
1753 },
1754 [&] { p << ','; });
1755 if (!succOperands.empty())
1756 p.printNewline();
1757 p << "]";
1758}
1759
1760//===----------------------------------------------------------------------===//
1761// BrCondOp
1762//===----------------------------------------------------------------------===//
1763
1764mlir::SuccessorOperands cir::BrCondOp::getSuccessorOperands(unsigned index) {
1765 assert(index < getNumSuccessors() && "invalid successor index");
1766 return SuccessorOperands(index == 0 ? getDestOperandsTrueMutable()
1767 : getDestOperandsFalseMutable());
1768}
1769
1770Block *cir::BrCondOp::getSuccessorForOperands(ArrayRef<Attribute> operands) {
1771 if (IntegerAttr condAttr = dyn_cast_if_present<IntegerAttr>(operands.front()))
1772 return condAttr.getValue().isOne() ? getDestTrue() : getDestFalse();
1773 return nullptr;
1774}
1775
1776//===----------------------------------------------------------------------===//
1777// CaseOp
1778//===----------------------------------------------------------------------===//
1779
1780void cir::CaseOp::getSuccessorRegions(
1781 mlir::RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
1782 if (!point.isParent()) {
1783 regions.emplace_back(getOperation());
1784 return;
1785 }
1786 regions.push_back(RegionSuccessor(&getCaseRegion()));
1787}
1788
1789mlir::ValueRange cir::CaseOp::getSuccessorInputs(RegionSuccessor successor) {
1790 return successor.isOperation() ? ValueRange(getOperation()->getResults())
1791 : ValueRange();
1792}
1793
1794void cir::CaseOp::build(OpBuilder &builder, OperationState &result,
1795 ArrayAttr value, CaseOpKind kind,
1796 OpBuilder::InsertPoint &insertPoint) {
1797 OpBuilder::InsertionGuard guardSwitch(builder);
1798 result.addAttribute("value", value);
1799 result.getOrAddProperties<Properties>().kind =
1800 cir::CaseOpKindAttr::get(builder.getContext(), kind);
1801 Region *caseRegion = result.addRegion();
1802 builder.createBlock(caseRegion);
1803
1804 insertPoint = builder.saveInsertionPoint();
1805}
1806
1807//===----------------------------------------------------------------------===//
1808// SwitchOp
1809//===----------------------------------------------------------------------===//
1810
1811void cir::SwitchOp::getSuccessorRegions(
1812 mlir::RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &region) {
1813 if (!point.isParent()) {
1814 region.emplace_back(getOperation());
1815 return;
1816 }
1817
1818 region.push_back(RegionSuccessor(&getBody()));
1819}
1820
1821mlir::ValueRange cir::SwitchOp::getSuccessorInputs(RegionSuccessor successor) {
1822 return successor.isOperation() ? ValueRange(getOperation()->getResults())
1823 : ValueRange();
1824}
1825
1826void cir::SwitchOp::build(OpBuilder &builder, OperationState &result,
1827 Value cond, BuilderOpStateCallbackRef switchBuilder) {
1828 assert(switchBuilder && "the builder callback for regions must be present");
1829 OpBuilder::InsertionGuard guardSwitch(builder);
1830 Region *switchRegion = result.addRegion();
1831 builder.createBlock(switchRegion);
1832 result.addOperands({cond});
1833 switchBuilder(builder, result.location, result);
1834}
1835
1836void cir::SwitchOp::collectCases(llvm::SmallVectorImpl<CaseOp> &cases) {
1837 walk<mlir::WalkOrder::PreOrder>([&](mlir::Operation *op) {
1838 // Don't walk in nested switch op.
1839 if (isa<cir::SwitchOp>(op) && op != *this)
1840 return WalkResult::skip();
1841
1842 if (auto caseOp = dyn_cast<cir::CaseOp>(op))
1843 cases.push_back(caseOp);
1844
1845 return WalkResult::advance();
1846 });
1847}
1848
1849bool cir::SwitchOp::isSimpleForm(llvm::SmallVectorImpl<CaseOp> &cases) {
1850 collectCases(cases);
1851
1852 if (getBody().empty())
1853 return false;
1854
1855 if (!isa<YieldOp>(getBody().front().back()))
1856 return false;
1857
1858 if (!llvm::all_of(getBody().front(),
1859 [](Operation &op) { return isa<CaseOp, YieldOp>(op); }))
1860 return false;
1861
1862 return llvm::all_of(cases, [this](CaseOp op) {
1863 return op->getParentOfType<SwitchOp>() == *this;
1864 });
1865}
1866
1867//===----------------------------------------------------------------------===//
1868// SwitchFlatOp
1869//===----------------------------------------------------------------------===//
1870
1871void cir::SwitchFlatOp::build(OpBuilder &builder, OperationState &result,
1872 Value value, Block *defaultDestination,
1873 ValueRange defaultOperands,
1874 ArrayRef<APInt> caseValues,
1875 BlockRange caseDestinations,
1876 ArrayRef<ValueRange> caseOperands) {
1877
1878 std::vector<mlir::Attribute> caseValuesAttrs;
1879 for (const APInt &val : caseValues)
1880 caseValuesAttrs.push_back(cir::IntAttr::get(value.getType(), val));
1881 mlir::ArrayAttr attrs = ArrayAttr::get(builder.getContext(), caseValuesAttrs);
1882
1883 build(builder, result, value, defaultOperands, caseOperands, attrs,
1884 defaultDestination, caseDestinations);
1885}
1886
1887/// <cases> ::= `[` (case (`,` case )* )? `]`
1888/// <case> ::= integer `:` bb-id (`(` ssa-use-and-type-list `)`)?
///
/// Parses the case list of a flat switch. For every case one element is
/// appended to `caseDestinations`, `caseOperands` and `caseOperandTypes`, so
/// the three stay index-aligned. `caseValues` is set only when at least one
/// case is present; for an empty list (`[` immediately followed by `]`) it is
/// left untouched.
1889static ParseResult parseSwitchFlatOpCases(
1890 OpAsmParser &parser, Type flagType, mlir::ArrayAttr &caseValues,
1891 SmallVectorImpl<Block *> &caseDestinations,
1893 &caseOperands,
1894 SmallVectorImpl<llvm::SmallVector<Type>> &caseOperandTypes) {
1895 if (failed(parser.parseLSquare()))
1896 return failure();
     // An empty case list is valid: `[]`.
1897 if (succeeded(parser.parseOptionalRSquare()))
1898 return success();
1900
     // NOTE(review): the declarations of `values` and `operands` used below
     // are not visible in this listing; confirm against the full file.
1901 auto parseCase = [&]() {
1902 int64_t value = 0;
1903 if (failed(parser.parseInteger(value)))
1904 return failure();
1905
     // The case value is wrapped in an integer attribute of the flag type.
1906 values.push_back(cir::IntAttr::get(flagType, value));
1907
1908 Block *destination;
1910 llvm::SmallVector<Type> operandTypes;
1911 if (parser.parseColon() || parser.parseSuccessor(destination))
1912 return failure();
     // The forwarded operand list is optional and only present after a `(`.
1913 if (!parser.parseOptionalLParen()) {
1914 if (parser.parseOperandList(operands, OpAsmParser::Delimiter::None,
1915 /*allowResultNumber=*/false) ||
1916 parser.parseColonTypeList(operandTypes) || parser.parseRParen())
1917 return failure();
1918 }
1919 caseDestinations.push_back(destination);
1920 caseOperands.emplace_back(operands);
1921 caseOperandTypes.emplace_back(operandTypes);
1922 return success();
1923 };
1924 if (failed(parser.parseCommaSeparatedList(parseCase)))
1925 return failure();
1926
1927 caseValues = ArrayAttr::get(flagType.getContext(), values);
1928
1929 return parser.parseRSquare();
1930}
1931
1932static void printSwitchFlatOpCases(OpAsmPrinter &p, cir::SwitchFlatOp op,
1933 Type flagType, mlir::ArrayAttr caseValues,
1934 SuccessorRange caseDestinations,
1935 OperandRangeRange caseOperands,
1936 const TypeRangeRange &caseOperandTypes) {
1937 p << '[';
1938 p.printNewline();
1939 if (!caseValues) {
1940 p << ']';
1941 return;
1942 }
1943
1944 size_t index = 0;
1945 llvm::interleave(
1946 llvm::zip(caseValues, caseDestinations),
1947 [&](auto i) {
1948 p << " ";
1949 mlir::Attribute a = std::get<0>(i);
1950 p << mlir::cast<cir::IntAttr>(a).getValue();
1951 p << ": ";
1952 p.printSuccessorAndUseList(std::get<1>(i), caseOperands[index++]);
1953 },
1954 [&] {
1955 p << ',';
1956 p.printNewline();
1957 });
1958 p.printNewline();
1959 p << ']';
1960}
1961
1962//===----------------------------------------------------------------------===//
1963// GlobalOp
1964//===----------------------------------------------------------------------===//
1965
1966static ParseResult parseConstantValue(OpAsmParser &parser,
1967 mlir::Attribute &valueAttr) {
1968 NamedAttrList attr;
1969 return parser.parseAttribute(valueAttr, "value", attr);
1970}
1971
1972static void printConstant(OpAsmPrinter &p, Attribute value) {
1973 p.printAttribute(value);
1974}
1975
1976mlir::LogicalResult cir::GlobalOp::verify() {
1977 // Verify that the initial value, if present, is either a unit attribute or
1978 // an attribute CIR supports.
1979 if (getInitialValue().has_value()) {
1980 if (checkConstantTypes(getOperation(), getSymType(), *getInitialValue())
1981 .failed())
1982 return failure();
1983 }
1984
1985 if ((getStaticLocalGuard().has_value()) &&
1986 (!getCtorRegion().empty() || !getDtorRegion().empty()))
1987 return emitOpError(
1988 "Cannot have a static-local global-op with a constructor or "
1989 "destructor, they require in-function initialization via LocalInitOp");
1990
1991 if (getDynTlsRefs()) {
1992 if (getStaticLocalGuard().has_value())
1993 return emitOpError(
1994 "cannot have both static local and dynamic tls references");
1995 if (!getTlsModel() || getTlsModel() != TLS_Model::GeneralDynamic)
1996 return emitOpError("'dyn_tls_refs' only valid for dynamic tls");
1997 }
1998
1999 if (getAliasee().has_value()) {
2000 if (getInitialValue().has_value() || !getCtorRegion().empty() ||
2001 !getDtorRegion().empty())
2002 return emitOpError("global alias shall not have an initializer or "
2003 "constructor/destructor regions");
2004 }
2005
2006 // TODO(CIR): Many other checks for properties that haven't been upstreamed
2007 // yet.
2008
2009 return success();
2010}
2011
2012void cir::GlobalOp::build(
2013 OpBuilder &odsBuilder, OperationState &odsState, llvm::StringRef sym_name,
2014 mlir::Type sym_type, bool isConstant,
2015 mlir::ptr::MemorySpaceAttrInterface addrSpace,
2016 cir::GlobalLinkageKind linkage,
2017 function_ref<void(OpBuilder &, Location)> ctorBuilder,
2018 function_ref<void(OpBuilder &, Location)> dtorBuilder) {
2019 odsState.addAttribute(getSymNameAttrName(odsState.name),
2020 odsBuilder.getStringAttr(sym_name));
2021 odsState.addAttribute(getSymTypeAttrName(odsState.name),
2022 mlir::TypeAttr::get(sym_type));
2023 auto &properties = odsState.getOrAddProperties<cir::GlobalOp::Properties>();
2024 properties.setConstant(isConstant);
2025
2026 addrSpace = normalizeDefaultAddressSpace(addrSpace);
2027 if (addrSpace)
2028 odsState.addAttribute(getAddrSpaceAttrName(odsState.name), addrSpace);
2029
2030 cir::GlobalLinkageKindAttr linkageAttr =
2031 cir::GlobalLinkageKindAttr::get(odsBuilder.getContext(), linkage);
2032 odsState.addAttribute(getLinkageAttrName(odsState.name), linkageAttr);
2033
2034 Region *ctorRegion = odsState.addRegion();
2035 if (ctorBuilder) {
2036 odsBuilder.createBlock(ctorRegion);
2037 ctorBuilder(odsBuilder, odsState.location);
2038 }
2039
2040 Region *dtorRegion = odsState.addRegion();
2041 if (dtorBuilder) {
2042 odsBuilder.createBlock(dtorRegion);
2043 dtorBuilder(odsBuilder, odsState.location);
2044 }
2045}
2046
2047/// Given the region at `index`, or the parent operation if `index` is None,
2048/// return the successor regions. These are the regions that may be selected
2049/// during the flow of control. `operands` is a set of optional attributes that
2050/// correspond to a constant value for each operand, or null if that operand is
2051/// not a constant.
2052void cir::GlobalOp::getSuccessorRegions(
2053 mlir::RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
2054 // The `ctor` and `dtor` regions always branch back to the parent operation.
2055 if (!point.isParent()) {
2056 regions.emplace_back(getOperation());
2057 return;
2058 }
2059
2060 // Don't consider the ctor region if it is empty.
2061 Region *ctorRegion = &this->getCtorRegion();
2062 if (ctorRegion->empty())
2063 ctorRegion = nullptr;
2064
2065 // Don't consider the dtor region if it is empty.
2066 Region *dtorRegion = &this->getDtorRegion();
2067 if (dtorRegion->empty())
2068 dtorRegion = nullptr;
2069
2070 // If the condition isn't constant, both regions may be executed.
2071 if (ctorRegion)
2072 regions.push_back(RegionSuccessor(ctorRegion));
2073 if (dtorRegion)
2074 regions.push_back(RegionSuccessor(dtorRegion));
2075}
2076
2077mlir::ValueRange cir::GlobalOp::getSuccessorInputs(RegionSuccessor successor) {
2078 return successor.isOperation() ? ValueRange(getOperation()->getResults())
2079 : ValueRange();
2080}
2081
2082static void printGlobalOpTypeAndInitialValue(OpAsmPrinter &p, cir::GlobalOp op,
2083 TypeAttr type, Attribute initAttr,
2084 mlir::Region &ctorRegion,
2085 mlir::Region &dtorRegion) {
2086 auto printType = [&]() { p << ": " << type; };
2087 // Aliases are definitions but they have no initial value or ctor/dtor; the
2088 // assembly prints them like declarations (`: type`).
2089 if (op.isDeclaration() || op.getAliasee()) {
2090 printType();
2091 return;
2092 }
2093
2094 p << "= ";
2095 if (!ctorRegion.empty()) {
2096 p << "ctor ";
2097 printType();
2098 p << " ";
2099 p.printRegion(ctorRegion,
2100 /*printEntryBlockArgs=*/false,
2101 /*printBlockTerminators=*/false);
2102 } else {
2103 // This also prints the type...
2104 if (initAttr)
2105 printConstant(p, initAttr);
2106 }
2107
2108 if (!dtorRegion.empty()) {
2109 p << " dtor ";
2110 p.printRegion(dtorRegion,
2111 /*printEntryBlockArgs=*/false,
2112 /*printBlockTerminators=*/false);
2113 }
2114}
2115
2116static ParseResult parseGlobalOpTypeAndInitialValue(OpAsmParser &parser,
2117 TypeAttr &typeAttr,
2118 Attribute &initialValueAttr,
2119 mlir::Region &ctorRegion,
2120 mlir::Region &dtorRegion) {
2121 mlir::Type opTy;
2122 if (parser.parseOptionalEqual().failed()) {
2123 // Absence of equal means a declaration, so we need to parse the type.
2124 // cir.global @a : !cir.int<s, 32>
2125 if (parser.parseColonType(opTy))
2126 return failure();
2127 } else {
2128 // Parse contructor, example:
2129 // cir.global @rgb = ctor : type { ... }
2130 if (!parser.parseOptionalKeyword("ctor")) {
2131 if (parser.parseColonType(opTy))
2132 return failure();
2133 auto parseLoc = parser.getCurrentLocation();
2134 if (parser.parseRegion(ctorRegion, /*arguments=*/{}, /*argTypes=*/{}))
2135 return failure();
2136 if (ensureRegionTerm(parser, ctorRegion, parseLoc).failed())
2137 return failure();
2138 } else {
2139 // Parse constant with initializer, examples:
2140 // cir.global @y = 3.400000e+00 : f32
2141 // cir.global @rgb = #cir.const_array<[...] : !cir.array<i8 x 3>>
2142 if (parseConstantValue(parser, initialValueAttr).failed())
2143 return failure();
2144
2145 assert(mlir::isa<mlir::TypedAttr>(initialValueAttr) &&
2146 "Non-typed attrs shouldn't appear here.");
2147 auto typedAttr = mlir::cast<mlir::TypedAttr>(initialValueAttr);
2148 opTy = typedAttr.getType();
2149 }
2150
2151 // Parse destructor, example:
2152 // dtor { ... }
2153 if (!parser.parseOptionalKeyword("dtor")) {
2154 auto parseLoc = parser.getCurrentLocation();
2155 if (parser.parseRegion(dtorRegion, /*arguments=*/{}, /*argTypes=*/{}))
2156 return failure();
2157 if (ensureRegionTerm(parser, dtorRegion, parseLoc).failed())
2158 return failure();
2159 }
2160 }
2161
2162 typeAttr = TypeAttr::get(opTy);
2163 return success();
2164}
2165
2166//===----------------------------------------------------------------------===//
2167// GetGlobalOp
2168//===----------------------------------------------------------------------===//
2169
2170LogicalResult
2171cir::GetGlobalOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
2172 // Verify that the result type underlying pointer type matches the type of
2173 // the referenced cir.global or cir.func op.
2174 mlir::Operation *op =
2175 symbolTable.lookupNearestSymbolFrom(*this, getNameAttr());
2176 if (op == nullptr || !(isa<GlobalOp>(op) || isa<FuncOp>(op)))
2177 return emitOpError("'")
2178 << getName()
2179 << "' does not reference a valid cir.global or cir.func";
2180
2181 mlir::Type symTy;
2182 mlir::ptr::MemorySpaceAttrInterface symAddrSpaceAttr{};
2183 if (auto g = dyn_cast<GlobalOp>(op)) {
2184 symTy = g.getSymType();
2185 symAddrSpaceAttr = g.getAddrSpaceAttr();
2186 // Verify that for thread local global access, the global needs to
2187 // be marked with tls bits.
2188 if (getTls() && !g.getTlsModel())
2189 return emitOpError("access to global not marked thread local");
2190
2191 // Verify that the static_local attribute on GetGlobalOp matches the
2192 // static_local_guard attribute on GlobalOp. GetGlobalOp uses a UnitAttr,
2193 // GlobalOp uses StaticLocalGuardAttr. Both should be present, or neither.
2194 bool getGlobalIsStaticLocal = getStaticLocal();
2195 bool globalIsStaticLocal = g.getStaticLocalGuard().has_value();
2196 if (getGlobalIsStaticLocal != globalIsStaticLocal &&
2197 !getOperation()->getParentOfType<cir::GlobalOp>())
2198 return emitOpError("static_local attribute mismatch");
2199 } else if (auto f = dyn_cast<FuncOp>(op)) {
2200 symTy = f.getFunctionType();
2201 } else {
2202 llvm_unreachable("Unexpected operation for GetGlobalOp");
2203 }
2204
2205 auto resultType = dyn_cast<PointerType>(getAddr().getType());
2206 if (!resultType || symTy != resultType.getPointee())
2207 return emitOpError("result type pointee type '")
2208 << resultType.getPointee() << "' does not match type " << symTy
2209 << " of the global @" << getName();
2210
2211 if (symAddrSpaceAttr != resultType.getAddrSpace()) {
2212 return emitOpError()
2213 << "result type address space does not match the address "
2214 "space of the global @"
2215 << getName();
2216 }
2217
2218 return success();
2219}
2220
2221//===----------------------------------------------------------------------===//
2222// VTableAddrPointOp
2223//===----------------------------------------------------------------------===//
2224
2225LogicalResult
2226cir::VTableAddrPointOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
2227 StringRef name = getName();
2228
2229 // Verify that the result type underlying pointer type matches the type of
2230 // the referenced cir.global.
2231 auto op =
2232 symbolTable.lookupNearestSymbolFrom<cir::GlobalOp>(*this, getNameAttr());
2233 if (!op)
2234 return emitOpError("'")
2235 << name << "' does not reference a valid cir.global";
2236 std::optional<mlir::Attribute> init = op.getInitialValue();
2237 if (!init)
2238 return success();
2239 if (!isa<cir::VTableAttr>(*init))
2240 return emitOpError("Expected #cir.vtable in initializer for global '")
2241 << name << "'";
2242 return success();
2243}
2244
2245//===----------------------------------------------------------------------===//
2246// VTTAddrPointOp
2247//===----------------------------------------------------------------------===//
2248
2249LogicalResult
2250cir::VTTAddrPointOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
2251 // VTT ptr is not coming from a symbol.
2252 if (!getName())
2253 return success();
2254 StringRef name = *getName();
2255
2256 // Verify that the result type underlying pointer type matches the type of
2257 // the referenced cir.global op.
2258 auto op =
2259 symbolTable.lookupNearestSymbolFrom<cir::GlobalOp>(*this, getNameAttr());
2260 if (!op)
2261 return emitOpError("'")
2262 << name << "' does not reference a valid cir.global";
2263 std::optional<mlir::Attribute> init = op.getInitialValue();
2264 if (!init)
2265 return success();
2266 if (!isa<cir::ConstArrayAttr>(*init))
2267 return emitOpError(
2268 "Expected constant array in initializer for global VTT '")
2269 << name << "'";
2270 return success();
2271}
2272
2273LogicalResult cir::VTTAddrPointOp::verify() {
2274 // The operation uses either a symbol or a value to operate, but not both
2275 if (getName() && getSymAddr())
2276 return emitOpError("should use either a symbol or value, but not both");
2277
2278 // If not a symbol, stick with the concrete type used for getSymAddr.
2279 if (getSymAddr())
2280 return success();
2281
2282 mlir::Type resultType = getAddr().getType();
2283 mlir::Type resTy = cir::PointerType::get(
2284 cir::PointerType::get(cir::VoidType::get(getContext())));
2285
2286 if (resultType != resTy)
2287 return emitOpError("result type must be ")
2288 << resTy << ", but provided result type is " << resultType;
2289 return success();
2290}
2291
2292//===----------------------------------------------------------------------===//
2293// FuncOp
2294//===----------------------------------------------------------------------===//
2295
2296/// Returns the name used for the linkage attribute. This *must* correspond to
2297/// the name of the attribute in ODS.
2298static llvm::StringRef getLinkageAttrNameString() { return "linkage"; }
2299
/// Populates `result` for a `cir.func`: adds the (initially empty) body
/// region and sets the symbol name, function type, linkage and calling
/// convention attributes.
2300void cir::FuncOp::build(OpBuilder &builder, OperationState &result,
2301 StringRef name, FuncType type,
2302 GlobalLinkageKind linkage, CallingConv callingConv) {
2303 result.addRegion();
2304 result.addAttribute(SymbolTable::getSymbolAttrName(),
2305 builder.getStringAttr(name));
2306 result.addAttribute(getFunctionTypeAttrName(result.name),
2307 TypeAttr::get(type));
     // NOTE(review): the attribute-name argument of the next call is on a
     // line that is not visible in this listing; presumably it is the linkage
     // attribute name -- confirm against the full file.
2308 result.addAttribute(
2310 GlobalLinkageKindAttr::get(builder.getContext(), linkage));
2311 result.addAttribute(getCallingConvAttrName(result.name),
2312 CallingConvAttr::get(builder.getContext(), callingConv));
2313}
2314
2315//===----------------------------------------------------------------------===//
2316// AnnotationAttr
2317//===----------------------------------------------------------------------===//
2318
2319LogicalResult
2320cir::AnnotationAttr::verify(function_ref<InFlightDiagnostic()> emitError,
2321 mlir::StringAttr name, mlir::ArrayAttr args) {
2322 if (!args)
2323 return success();
2324 for (mlir::Attribute arg : args) {
2325 if (!isa<mlir::StringAttr, mlir::IntegerAttr>(arg))
2326 return emitError() << "annotation args must be StringAttr or IntegerAttr,"
2327 << " got " << arg;
2328 }
2329 return success();
2330}
2331
/// Custom assembly parser for `cir.func`.
///
/// Parses, in this order: the optional leading keywords (builtin, coroutine,
/// inline kind, lambda, no-proto, comdat), linkage, MLIR symbol visibility,
/// CIR global visibility, dso-local, the symbol name and signature, then the
/// optional `alias(...)`, `personality(...)`, `cc(...)`, `special_member<...>`,
/// `global_ctor`, `global_dtor`, `side_effect(...)` clauses, an optional
/// annotations array, a trailing attribute dictionary, and finally the
/// optional function body.
2332ParseResult cir::FuncOp::parse(OpAsmParser &parser, OperationState &state) {
2333 llvm::SMLoc loc = parser.getCurrentLocation();
2334 mlir::Builder &builder = parser.getBuilder();
2335
2336 mlir::StringAttr builtinNameAttr = getBuiltinAttrName(state.name);
2337 mlir::StringAttr coroutineNameAttr = getCoroutineAttrName(state.name);
2338 mlir::StringAttr inlineKindNameAttr = getInlineKindAttrName(state.name);
2339 mlir::StringAttr lambdaNameAttr = getLambdaAttrName(state.name);
2340 mlir::StringAttr noProtoNameAttr = getNoProtoAttrName(state.name);
2341 mlir::StringAttr comdatNameAttr = getComdatAttrName(state.name);
2342 mlir::StringAttr visNameAttr = getSymVisibilityAttrName(state.name);
2343 mlir::StringAttr dsoLocalNameAttr = getDsoLocalAttrName(state.name);
2344 mlir::StringAttr specialMemberAttr = getCxxSpecialMemberAttrName(state.name);
2345
     // Each optional keyword below maps to a unit attribute of the same name.
2346 if (::mlir::succeeded(parser.parseOptionalKeyword(builtinNameAttr.strref())))
2347 state.addAttribute(builtinNameAttr, parser.getBuilder().getUnitAttr());
2348 if (::mlir::succeeded(
2349 parser.parseOptionalKeyword(coroutineNameAttr.strref())))
2350 state.addAttribute(coroutineNameAttr, parser.getBuilder().getUnitAttr());
2351
2352 // Parse optional inline kind attribute
2353 cir::InlineKindAttr inlineKindAttr;
2354 if (failed(parseInlineKindAttr(parser, inlineKindAttr)))
2355 return failure();
2356 if (inlineKindAttr)
2357 state.addAttribute(inlineKindNameAttr, inlineKindAttr);
2358
2359 if (::mlir::succeeded(parser.parseOptionalKeyword(lambdaNameAttr.strref())))
2360 state.addAttribute(lambdaNameAttr, parser.getBuilder().getUnitAttr());
2361 if (parser.parseOptionalKeyword(noProtoNameAttr).succeeded())
2362 state.addAttribute(noProtoNameAttr, parser.getBuilder().getUnitAttr());
2363
2364 if (parser.parseOptionalKeyword(comdatNameAttr).succeeded())
2365 state.addAttribute(comdatNameAttr, parser.getBuilder().getUnitAttr());
2366
2367 // Default to external linkage if no keyword is provided.
     // NOTE(review): the callee that consumes the optional linkage keyword is
     // on a line that is not visible in this listing.
2368 state.addAttribute(getLinkageAttrNameString(),
2369 GlobalLinkageKindAttr::get(
2370 parser.getContext(),
2372 parser, GlobalLinkageKind::ExternalLinkage)));
2373
     // Optional MLIR symbol visibility keyword.
2374 ::llvm::StringRef visAttrStr;
2375 if (parser.parseOptionalKeyword(&visAttrStr, {"private", "public", "nested"})
2376 .succeeded()) {
2377 state.addAttribute(visNameAttr,
2378 parser.getBuilder().getStringAttr(visAttrStr));
2379 }
2380
     // Optional CIR visibility keyword; `Default` is passed as the fallback.
2381 state.getOrAddProperties<cir::FuncOp::Properties>().global_visibility =
2382 parseOptionalCIRKeyword(parser, cir::VisibilityKind::Default);
2383
2384 if (parser.parseOptionalKeyword(dsoLocalNameAttr).succeeded())
2385 state.addAttribute(dsoLocalNameAttr, parser.getBuilder().getUnitAttr());
2386
2387 StringAttr nameAttr;
2388 if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
2389 state.attributes))
2390 return failure();
     // NOTE(review): the declarations of `arguments`, `resultTypes` and
     // `resultAttrs` are on lines that are not visible in this listing.
2394 bool isVariadic = false;
2395 if (function_interface_impl::parseFunctionSignatureWithArguments(
2396 parser, /*allowVariadic=*/true, arguments, isVariadic, resultTypes,
2397 resultAttrs))
2398 return failure();
     // NOTE(review): the declarations of `argTypes` and `argAttrs` are on
     // lines that are not visible in this listing.
2401 bool argAttrsEmpty = true;
2402 for (OpAsmParser::Argument &arg : arguments) {
2403 argTypes.push_back(arg.type);
2404 // Add the 'empty' attribute anyway to make sure the arity matches, but we
2405 // only want to 'set' the attribute at the top level if there is SOME data
2406 // along the way.
2407 argAttrs.push_back(arg.attrs);
2408 if (arg.attrs)
2409 argAttrsEmpty = false;
2410 }
2411
2412 // These should be in sync anyway, but test both of them anyway.
2413 if (resultTypes.size() > 1 || resultAttrs.size() > 1)
2414 return parser.emitError(
2415 loc, "functions with multiple return types are not supported");
2416
     // A function without a declared result returns `!cir.void`.
2417 mlir::Type returnType =
2418 (resultTypes.empty() ? cir::VoidType::get(builder.getContext())
2419 : resultTypes.front());
2420
2421 cir::FuncType fnType = cir::FuncType::get(argTypes, returnType, isVariadic);
2422 if (!fnType)
2423 return failure();
2424
2425 state.addAttribute(getFunctionTypeAttrName(state.name),
2426 TypeAttr::get(fnType));
2427
2428 if (!resultAttrs.empty() && resultAttrs[0])
2429 state.addAttribute(
2430 getResAttrsAttrName(state.name),
2431 mlir::ArrayAttr::get(parser.getContext(), {resultAttrs[0]}));
2432
2433 if (!argAttrsEmpty)
2434 state.addAttribute(getArgAttrsAttrName(state.name),
2435 mlir::ArrayAttr::get(parser.getContext(), argAttrs));
2436
     // Optional `alias(@symbol)`; remembered so a body can be rejected later.
2437 bool hasAlias = false;
2438 mlir::StringAttr aliaseeNameAttr = getAliaseeAttrName(state.name);
2439 if (parser.parseOptionalKeyword("alias").succeeded()) {
2440 if (parser.parseLParen().failed())
2441 return failure();
2442 mlir::StringAttr aliaseeAttr;
2443 if (parser.parseOptionalSymbolName(aliaseeAttr).failed())
2444 return failure();
2445 state.addAttribute(aliaseeNameAttr, FlatSymbolRefAttr::get(aliaseeAttr));
2446 if (parser.parseRParen().failed())
2447 return failure();
2448 hasAlias = true;
2449 }
2450
     // Optional `personality(@symbol)`.
2451 mlir::StringAttr personalityNameAttr = getPersonalityAttrName(state.name);
2452 if (parser.parseOptionalKeyword("personality").succeeded()) {
2453 if (parser.parseLParen().failed())
2454 return failure();
2455 mlir::StringAttr personalityAttr;
2456 if (parser.parseOptionalSymbolName(personalityAttr).failed())
2457 return failure();
2458 state.addAttribute(personalityNameAttr,
2459 FlatSymbolRefAttr::get(personalityAttr));
2460 if (parser.parseRParen().failed())
2461 return failure();
2462 }
2463
2464 // Default to C calling convention if no keyword is provided.
2465 mlir::StringAttr callConvNameAttr = getCallingConvAttrName(state.name);
2466 cir::CallingConv callConv = cir::CallingConv::C;
2467 if (parser.parseOptionalKeyword("cc").succeeded()) {
2468 if (parser.parseLParen().failed())
2469 return failure();
2470 if (parseCIRKeyword<cir::CallingConv>(parser, callConv).failed())
2471 return parser.emitError(loc) << "unknown calling convention";
2472 if (parser.parseRParen().failed())
2473 return failure();
2474 }
2475 state.addAttribute(callConvNameAttr,
2476 cir::CallingConvAttr::get(parser.getContext(), callConv));
2477
     // Shared helper for `global_ctor` / `global_dtor`: parses the keyword and
     // an optional `(priority)`, then hands the optional priority to
     // `createAttr`. Succeeds without calling `createAttr` when the keyword is
     // absent.
2478 auto parseGlobalDtorCtor =
2479 [&](StringRef keyword,
2480 llvm::function_ref<void(std::optional<int> prio)> createAttr)
2481 -> mlir::LogicalResult {
2482 if (mlir::succeeded(parser.parseOptionalKeyword(keyword))) {
2483 std::optional<int> priority;
2484 if (mlir::succeeded(parser.parseOptionalLParen())) {
2485 auto parsedPriority = mlir::FieldParser<int>::parse(parser);
2486 if (mlir::failed(parsedPriority))
2487 return parser.emitError(parser.getCurrentLocation(),
2488 "failed to parse 'priority', of type 'int'");
2489 priority = parsedPriority.value_or(int());
2490 // Parse literal ')'
2491 if (parser.parseRParen())
2492 return failure();
2493 }
2494 createAttr(priority);
2495 }
2496 return success();
2497 };
2498
2499 // Parse CXXSpecialMember attribute
2500 if (parser.parseOptionalKeyword("special_member").succeeded()) {
2501 if (parser.parseLess().failed())
2502 return failure();
2503
2504 mlir::Attribute attr;
2505 if (parser.parseAttribute(attr).failed())
2506 return failure();
2507 if (!mlir::isa<cir::CXXCtorAttr, cir::CXXDtorAttr, cir::CXXAssignAttr>(
2508 attr))
2509 return parser.emitError(parser.getCurrentLocation(),
2510 "expected a C++ special member attribute");
2511 state.addAttribute(specialMemberAttr, attr);
2512
2513 if (parser.parseGreater().failed())
2514 return failure();
2515 }
2516
     // A missing priority falls back to 65535 for both ctor and dtor.
2517 if (parseGlobalDtorCtor("global_ctor", [&](std::optional<int> priority) {
2518 mlir::IntegerAttr globalCtorPriorityAttr =
2519 builder.getI32IntegerAttr(priority.value_or(65535));
2520 state.addAttribute(getGlobalCtorPriorityAttrName(state.name),
2521 globalCtorPriorityAttr);
2522 }).failed())
2523 return failure();
2524
2525 if (parseGlobalDtorCtor("global_dtor", [&](std::optional<int> priority) {
2526 mlir::IntegerAttr globalDtorPriorityAttr =
2527 builder.getI32IntegerAttr(priority.value_or(65535));
2528 state.addAttribute(getGlobalDtorPriorityAttrName(state.name),
2529 globalDtorPriorityAttr);
2530 }).failed())
2531 return failure();
2532
     // Optional `side_effect(<kind>)`.
2533 if (parser.parseOptionalKeyword("side_effect").succeeded()) {
2534 cir::SideEffect sideEffect;
2535
2536 if (parser.parseLParen().failed() ||
2537 parseCIRKeyword<cir::SideEffect>(parser, sideEffect).failed() ||
2538 parser.parseRParen().failed())
2539 return failure();
2540
2541 auto attr = cir::SideEffectAttr::get(parser.getContext(), sideEffect);
2542 state.addAttribute(CIRDialect::getSideEffectAttrName(), attr);
2543 }
2544
2545 // Parse optional annotations attribute (an ArrayAttr of AnnotationAttr).
2546 mlir::StringAttr annotationsNameAttr = getAnnotationsAttrName(state.name);
2547 mlir::ArrayAttr annotationsAttr;
2548 if (parser.parseOptionalAttribute(annotationsAttr).has_value() &&
2549 annotationsAttr)
2550 state.addAttribute(annotationsNameAttr, annotationsAttr);
2551
2552 // Parse the rest of the attributes.
2553 NamedAttrList parsedAttrs;
2554 if (parser.parseOptionalAttrDictWithKeyword(parsedAttrs))
2555 return failure();
2556
     // Attributes that have dedicated syntax must not reappear in the
     // trailing attribute dictionary.
2557 for (StringRef disallowed : cir::FuncOp::getAttributeNames()) {
2558 if (parsedAttrs.get(disallowed))
2559 return parser.emitError(loc, "attribute '")
2560 << disallowed
2561 << "' should not be specified in the explicit attribute list";
2562 }
2563
2564 state.attributes.append(parsedAttrs);
2565
2566 // Parse the optional function body.
2567 auto *body = state.addRegion();
2568 OptionalParseResult parseResult = parser.parseOptionalRegion(
2569 *body, arguments, /*enableNameShadowing=*/false);
2570 if (parseResult.has_value()) {
2571 if (hasAlias)
2572 return parser.emitError(loc, "function alias shall not have a body");
2573 if (failed(*parseResult))
2574 return failure();
2575 // Function body was parsed, make sure its not empty.
2576 if (body->empty())
2577 return parser.emitError(loc, "expected non-empty function body");
2578 }
2579
2580 return success();
2581}
2582
2583// This function corresponds to `llvm::GlobalValue::isDeclaration` and should
2584// have a similar implementation. We don't currently support ifuncs or
2585// materializable functions, but those should be handled here as they are
// implemented.
//
// In the visible code: a function without an aliasee is a declaration
// exactly when its body is empty; a function with an aliasee is never a
// declaration.
2586bool cir::FuncOp::isDeclaration() {
2588
2589 std::optional<StringRef> aliasee = getAliasee();
2590 if (!aliasee)
2591 return getFunctionBody().empty();
2592
2593 // Aliases are always definitions.
2594 return false;
2595}
2596
2597bool cir::FuncOp::isCXXSpecialMemberFunction() {
2598 return getCxxSpecialMemberAttr() != nullptr;
2599}
2600
2601bool cir::FuncOp::isCxxConstructor() {
2602 auto attr = getCxxSpecialMemberAttr();
2603 return attr && dyn_cast<CXXCtorAttr>(attr);
2604}
2605
2606bool cir::FuncOp::isCxxDestructor() {
2607 auto attr = getCxxSpecialMemberAttr();
2608 return attr && dyn_cast<CXXDtorAttr>(attr);
2609}
2610
2611bool cir::FuncOp::isCxxSpecialAssignment() {
2612 auto attr = getCxxSpecialMemberAttr();
2613 return attr && dyn_cast<CXXAssignAttr>(attr);
2614}
2615
2616std::optional<CtorKind> cir::FuncOp::getCxxConstructorKind() {
2617 mlir::Attribute attr = getCxxSpecialMemberAttr();
2618 if (attr) {
2619 if (auto ctor = dyn_cast<CXXCtorAttr>(attr))
2620 return ctor.getCtorKind();
2621 }
2622 return std::nullopt;
2623}
2624
2625std::optional<AssignKind> cir::FuncOp::getCxxSpecialAssignKind() {
2626 mlir::Attribute attr = getCxxSpecialMemberAttr();
2627 if (attr) {
2628 if (auto assign = dyn_cast<CXXAssignAttr>(attr))
2629 return assign.getAssignKind();
2630 }
2631 return std::nullopt;
2632}
2633
2634bool cir::FuncOp::isCxxTrivialMemberFunction() {
2635 mlir::Attribute attr = getCxxSpecialMemberAttr();
2636 if (attr) {
2637 if (auto ctor = dyn_cast<CXXCtorAttr>(attr))
2638 return ctor.getIsTrivial();
2639 if (auto dtor = dyn_cast<CXXDtorAttr>(attr))
2640 return dtor.getIsTrivial();
2641 if (auto assign = dyn_cast<CXXAssignAttr>(attr))
2642 return assign.getIsTrivial();
2643 }
2644 return false;
2645}
2646
2647mlir::Region *cir::FuncOp::getCallableRegion() {
2648 // TODO(CIR): This function will have special handling for aliases and a
2649 // check for an external function, once those features have been upstreamed.
2650 return &getBody();
2651}
2652
2653void cir::FuncOp::print(OpAsmPrinter &p) {
2654 if (getBuiltin())
2655 p << " builtin";
2656
2657 if (getCoroutine())
2658 p << " coroutine";
2659
2660 printInlineKindAttr(p, getInlineKindAttr());
2661
2662 if (getLambda())
2663 p << " lambda";
2664
2665 if (getNoProto())
2666 p << " no_proto";
2667
2668 if (getComdat())
2669 p << " comdat";
2670
2671 if (getLinkage() != GlobalLinkageKind::ExternalLinkage)
2672 p << ' ' << stringifyGlobalLinkageKind(getLinkage());
2673
2674 mlir::SymbolTable::Visibility vis = getVisibility();
2675 if (vis != mlir::SymbolTable::Visibility::Public)
2676 p << ' ' << vis;
2677
2678 if (getGlobalVisibility() != cir::VisibilityKind::Default)
2679 p << ' ' << stringifyVisibilityKind(getGlobalVisibility());
2680
2681 if (getDsoLocal())
2682 p << " dso_local";
2683
2684 p << ' ';
2685 p.printSymbolName(getSymName());
2686 cir::FuncType fnType = getFunctionType();
2687 function_interface_impl::printFunctionSignature(
2688 p, *this, fnType.getInputs(), fnType.isVarArg(), fnType.getReturnTypes());
2689
2690 if (std::optional<StringRef> aliaseeName = getAliasee()) {
2691 p << " alias(";
2692 p.printSymbolName(*aliaseeName);
2693 p << ")";
2694 }
2695
2696 if (getCallingConv() != cir::CallingConv::C) {
2697 p << " cc(";
2698 p << stringifyCallingConv(getCallingConv());
2699 p << ")";
2700 }
2701
2702 if (std::optional<StringRef> personalityName = getPersonality()) {
2703 p << " personality(";
2704 p.printSymbolName(*personalityName);
2705 p << ")";
2706 }
2707
2708 if (auto specialMemberAttr = getCxxSpecialMember()) {
2709 p << " special_member<";
2710 p.printAttribute(*specialMemberAttr);
2711 p << '>';
2712 }
2713
2714 if (auto globalCtorPriority = getGlobalCtorPriority()) {
2715 p << " global_ctor";
2716 if (globalCtorPriority.value() != 65535)
2717 p << "(" << globalCtorPriority.value() << ")";
2718 }
2719
2720 if (auto globalDtorPriority = getGlobalDtorPriority()) {
2721 p << " global_dtor";
2722 if (globalDtorPriority.value() != 65535)
2723 p << "(" << globalDtorPriority.value() << ")";
2724 }
2725
2726 if (std::optional<cir::SideEffect> sideEffect = getSideEffect();
2727 sideEffect && *sideEffect != cir::SideEffect::All) {
2728 p << " side_effect(";
2729 p << stringifySideEffect(*sideEffect);
2730 p << ")";
2731 }
2732
2733 if (mlir::ArrayAttr annotations = getAnnotationsAttr()) {
2734 p << ' ';
2735 p.printAttribute(annotations);
2736 }
2737
2738 function_interface_impl::printFunctionAttributes(
2739 p, *this, cir::FuncOp::getAttributeNames());
2740
2741 // Print the body if this is not an external function.
2742 Region &body = getOperation()->getRegion(0);
2743 if (!body.empty()) {
2744 p << ' ';
2745 p.printRegion(body, /*printEntryBlockArgs=*/false,
2746 /*printBlockTerminators=*/true);
2747 }
2748}
2749
2750mlir::LogicalResult cir::FuncOp::verify() {
2751
2752 if (!isDeclaration() && getCoroutine()) {
2753 bool foundAwait = false;
2754 int coroBodyCount = 0;
2755 this->walk([&](Operation *op) {
2756 if (auto await = dyn_cast<AwaitOp>(op)) {
2757 foundAwait = true;
2758 } else if (isa<CoroBodyOp>(op)) {
2759 coroBodyCount++;
2760 if (coroBodyCount > 1) {
2761 return mlir::WalkResult::interrupt();
2762 }
2763 }
2764 return mlir::WalkResult::advance();
2765 });
2766 if (!foundAwait)
2767 return emitOpError()
2768 << "coroutine body must use at least one cir.await op";
2769 if (coroBodyCount != 1)
2770 return emitOpError()
2771 << "coroutine function must have exactly one cir.body op";
2772 }
2773
2774 llvm::SmallSet<llvm::StringRef, 16> labels;
2775 llvm::SmallSet<llvm::StringRef, 16> gotos;
2776 llvm::SmallSet<llvm::StringRef, 16> blockAddresses;
2777 bool invalidBlockAddress = false;
2778 getOperation()->walk([&](mlir::Operation *op) {
2779 if (auto lab = dyn_cast<cir::LabelOp>(op)) {
2780 labels.insert(lab.getLabel());
2781 } else if (auto goTo = dyn_cast<cir::GotoOp>(op)) {
2782 gotos.insert(goTo.getLabel());
2783 } else if (auto blkAdd = dyn_cast<cir::BlockAddressOp>(op)) {
2784 if (blkAdd.getBlockAddrInfoAttr().getFunc().getAttr() != getSymName()) {
2785 // Stop the walk early, no need to continue
2786 invalidBlockAddress = true;
2787 return mlir::WalkResult::interrupt();
2788 }
2789 blockAddresses.insert(blkAdd.getBlockAddrInfoAttr().getLabel());
2790 }
2791 return mlir::WalkResult::advance();
2792 });
2793
2794 if (invalidBlockAddress)
2795 return emitOpError() << "blockaddress references a different function";
2796
2797 llvm::SmallSet<llvm::StringRef, 16> mismatched;
2798 if (!labels.empty() || !gotos.empty()) {
2799 mismatched = llvm::set_difference(gotos, labels);
2800
2801 if (!mismatched.empty())
2802 return emitOpError() << "goto/label mismatch";
2803 }
2804
2805 mismatched.clear();
2806
2807 if (!labels.empty() || !blockAddresses.empty()) {
2808 mismatched = llvm::set_difference(blockAddresses, labels);
2809
2810 if (!mismatched.empty())
2811 return emitOpError()
2812 << "expects an existing label target in the referenced function";
2813 }
2814
2815 return success();
2816}
2817
2818//===----------------------------------------------------------------------===//
2819// AddOp / SubOp
2820//===----------------------------------------------------------------------===//
2821
2822// The integer-only type constraint on these ops makes the nsw/nuw/sat flag
2823// type checks unnecessary. Only the mutual-exclusivity between nsw/nuw and
2824// sat needs to be verified.
2825
2826LogicalResult cir::AddOp::verify() {
2827 if (getSaturated() && (getNoSignedWrap() || getNoUnsignedWrap()))
2828 return emitOpError()
2829 << "the nsw/nuw flags and the saturated flag are mutually exclusive";
2830 return mlir::success();
2831}
2832
2833LogicalResult cir::SubOp::verify() {
2834 if (getSaturated() && (getNoSignedWrap() || getNoUnsignedWrap()))
2835 return emitOpError()
2836 << "the nsw/nuw flags and the saturated flag are mutually exclusive";
2837 return mlir::success();
2838}
2839
2840//===----------------------------------------------------------------------===//
2841// TernaryOp
2842//===----------------------------------------------------------------------===//
2843
2844/// Given the region at `point`, or the parent operation if `point` is None,
2845/// return the successor regions. These are the regions that may be selected
2846/// during the flow of control. `operands` is a set of optional attributes that
2847/// correspond to a constant value for each operand, or null if that operand is
2848/// not a constant.
2849void cir::TernaryOp::getSuccessorRegions(
2850 mlir::RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
2851 // The `true` and the `false` region branch back to the parent operation.
2852 if (!point.isParent()) {
2853 regions.emplace_back(getOperation());
2854 return;
2855 }
2856
2857 // When branching from the parent operation, both the true and false
2858 // regions are considered possible successors
2859 regions.push_back(RegionSuccessor(&getTrueRegion()));
2860 regions.push_back(RegionSuccessor(&getFalseRegion()));
2861}
2862
2863mlir::ValueRange cir::TernaryOp::getSuccessorInputs(RegionSuccessor successor) {
2864 return successor.isOperation() ? ValueRange(getOperation()->getResults())
2865 : ValueRange();
2866}
2867
2868void cir::TernaryOp::build(
2869 OpBuilder &builder, OperationState &result, Value cond,
2870 function_ref<void(OpBuilder &, Location)> trueBuilder,
2871 function_ref<void(OpBuilder &, Location)> falseBuilder) {
2872 result.addOperands(cond);
2873 OpBuilder::InsertionGuard guard(builder);
2874 Region *trueRegion = result.addRegion();
2875 builder.createBlock(trueRegion);
2876 trueBuilder(builder, result.location);
2877 Region *falseRegion = result.addRegion();
2878 builder.createBlock(falseRegion);
2879 falseBuilder(builder, result.location);
2880
2881 // Get result type from whichever branch has a yield (the other may have
2882 // unreachable from a throw expression)
2883 cir::YieldOp yield;
2884 if (trueRegion->back().mightHaveTerminator())
2885 yield = dyn_cast_or_null<cir::YieldOp>(trueRegion->back().getTerminator());
2886 if (!yield && falseRegion->back().mightHaveTerminator())
2887 yield = dyn_cast_or_null<cir::YieldOp>(falseRegion->back().getTerminator());
2888
2889 assert((!yield || yield.getNumOperands() <= 1) &&
2890 "expected zero or one result type");
2891 if (yield && yield.getNumOperands() == 1)
2892 result.addTypes(TypeRange{yield.getOperandTypes().front()});
2893}
2894
2895//===----------------------------------------------------------------------===//
2896// SelectOp
2897//===----------------------------------------------------------------------===//
2898
2899OpFoldResult cir::SelectOp::fold(FoldAdaptor adaptor) {
2900 mlir::Attribute condition = adaptor.getCondition();
2901 if (condition) {
2902 bool conditionValue = mlir::cast<cir::BoolAttr>(condition).getValue();
2903 return conditionValue ? getTrueValue() : getFalseValue();
2904 }
2905
2906 // cir.select if %0 then x else x -> x
2907 mlir::Attribute trueValue = adaptor.getTrueValue();
2908 mlir::Attribute falseValue = adaptor.getFalseValue();
2909 if (trueValue == falseValue)
2910 return trueValue;
2911 if (getTrueValue() == getFalseValue())
2912 return getTrueValue();
2913
2914 return {};
2915}
2916
2917LogicalResult cir::SelectOp::verify() {
2918 // AllTypesMatch already guarantees trueVal and falseVal have matching types.
2919 auto condTy = dyn_cast<cir::VectorType>(getCondition().getType());
2920
2921 // If condition is not a vector, no further checks are needed.
2922 if (!condTy)
2923 return success();
2924
2925 // When condition is a vector, both other operands must also be vectors.
2926 if (!isa<cir::VectorType>(getTrueValue().getType()) ||
2927 !isa<cir::VectorType>(getFalseValue().getType())) {
2928 return emitOpError()
2929 << "expected both true and false operands to be vector types "
2930 "when the condition is a vector boolean type";
2931 }
2932
2933 return success();
2934}
2935
2936//===----------------------------------------------------------------------===//
2937// ShiftOp
2938//===----------------------------------------------------------------------===//
2939LogicalResult cir::ShiftOp::verify() {
2940 mlir::Operation *op = getOperation();
2941 auto op0VecTy = mlir::dyn_cast<cir::VectorType>(op->getOperand(0).getType());
2942 auto op1VecTy = mlir::dyn_cast<cir::VectorType>(op->getOperand(1).getType());
2943 if (!op0VecTy ^ !op1VecTy)
2944 return emitOpError() << "input types cannot be one vector and one scalar";
2945
2946 if (op0VecTy) {
2947 if (op0VecTy.getSize() != op1VecTy.getSize())
2948 return emitOpError() << "input vector types must have the same size";
2949
2950 auto opResultTy = mlir::dyn_cast<cir::VectorType>(getType());
2951 if (!opResultTy)
2952 return emitOpError() << "the type of the result must be a vector "
2953 << "if it is vector shift";
2954
2955 auto op0VecEleTy = mlir::cast<cir::IntType>(op0VecTy.getElementType());
2956 auto op1VecEleTy = mlir::cast<cir::IntType>(op1VecTy.getElementType());
2957 if (op0VecEleTy.getWidth() != op1VecEleTy.getWidth())
2958 return emitOpError()
2959 << "vector operands do not have the same elements sizes";
2960
2961 auto resVecEleTy = mlir::cast<cir::IntType>(opResultTy.getElementType());
2962 if (op0VecEleTy.getWidth() != resVecEleTy.getWidth())
2963 return emitOpError() << "vector operands and result type do not have the "
2964 "same elements sizes";
2965 }
2966
2967 return mlir::success();
2968}
2969
2970//===----------------------------------------------------------------------===//
2971// LabelOp Definitions
2972//===----------------------------------------------------------------------===//
2973
2974LogicalResult cir::LabelOp::verify() {
2975 mlir::Operation *op = getOperation();
2976 mlir::Block *blk = op->getBlock();
2977 if (&blk->front() != op)
2978 return emitError() << "must be the first operation in a block";
2979
2980 return mlir::success();
2981}
2982
2983//===----------------------------------------------------------------------===//
2984// IncOp
2985//===----------------------------------------------------------------------===//
2986
2987OpFoldResult cir::IncOp::fold(FoldAdaptor adaptor) {
2988 if (mlir::isa_and_present<cir::PoisonAttr>(adaptor.getInput()))
2989 return adaptor.getInput();
2990 return {};
2991}
2992
2993//===----------------------------------------------------------------------===//
2994// DecOp
2995//===----------------------------------------------------------------------===//
2996
2997OpFoldResult cir::DecOp::fold(FoldAdaptor adaptor) {
2998 if (mlir::isa_and_present<cir::PoisonAttr>(adaptor.getInput()))
2999 return adaptor.getInput();
3000 return {};
3001}
3002
3003//===----------------------------------------------------------------------===//
3004// MinusOp
3005//===----------------------------------------------------------------------===//
3006
3007OpFoldResult cir::MinusOp::fold(FoldAdaptor adaptor) {
3008 if (mlir::isa_and_present<cir::PoisonAttr>(adaptor.getInput()))
3009 return adaptor.getInput();
3010
3011 // Fold with constant inputs. Floating-point negation is handled by
3012 // cir::FNegOp.
3013 if (auto intAttr =
3014 mlir::dyn_cast_if_present<cir::IntAttr>(adaptor.getInput())) {
3015 APInt val = intAttr.getValue();
3016 val.negate();
3017 return cir::IntAttr::get(getType(), val);
3018 }
3019
3020 return {};
3021}
3022
3023//===----------------------------------------------------------------------===//
3024// FNegOp
3025//===----------------------------------------------------------------------===//
3026
3027OpFoldResult cir::FNegOp::fold(FoldAdaptor adaptor) {
3028 if (mlir::isa_and_present<cir::PoisonAttr>(adaptor.getInput()))
3029 return adaptor.getInput();
3030
3031 // Fold with constant inputs.
3032 if (auto fpAttr =
3033 mlir::dyn_cast_if_present<cir::FPAttr>(adaptor.getInput())) {
3034 APFloat val = fpAttr.getValue();
3035 val.changeSign();
3036 return cir::FPAttr::get(getType(), val);
3037 }
3038
3039 return {};
3040}
3041
3042//===----------------------------------------------------------------------===//
3043// NotOp
3044//===----------------------------------------------------------------------===//
3045
3046OpFoldResult cir::NotOp::fold(FoldAdaptor adaptor) {
3047 if (mlir::isa_and_present<cir::PoisonAttr>(adaptor.getInput()))
3048 return adaptor.getInput();
3049
3050 // not(not(x)) -> x is handled by the Involution trait.
3051
3052 // Fold with constant inputs.
3053 if (mlir::Attribute attr = adaptor.getInput()) {
3054 if (auto intAttr = mlir::dyn_cast<cir::IntAttr>(attr)) {
3055 APInt val = intAttr.getValue();
3056 val.flipAllBits();
3057 return cir::IntAttr::get(getType(), val);
3058 }
3059 if (auto boolAttr = mlir::dyn_cast<cir::BoolAttr>(attr))
3060 return cir::BoolAttr::get(getContext(), !boolAttr.getValue());
3061 }
3062
3063 return {};
3064}
3065
3066//===----------------------------------------------------------------------===//
3067// BaseDataMemberOp & DerivedDataMemberOp
3068//===----------------------------------------------------------------------===//
3069
3070static LogicalResult verifyMemberPtrCast(Operation *op, mlir::Value src,
3071 mlir::Type resultTy) {
3072 // Let the operand type be T1 C1::*, let the result type be T2 C2::*.
3073 // Verify that T1 and T2 are the same type.
3074 mlir::Type inputMemberTy;
3075 mlir::Type resultMemberTy;
3076 if (mlir::isa<cir::DataMemberType>(src.getType())) {
3077 inputMemberTy =
3078 mlir::cast<cir::DataMemberType>(src.getType()).getMemberTy();
3079 resultMemberTy = mlir::cast<cir::DataMemberType>(resultTy).getMemberTy();
3080 }
3082 if (inputMemberTy != resultMemberTy)
3083 return op->emitOpError()
3084 << "member types of the operand and the result do not match";
3085
3086 return mlir::success();
3087}
3088
3089LogicalResult cir::BaseDataMemberOp::verify() {
3090 return verifyMemberPtrCast(getOperation(), getSrc(), getType());
3091}
3092
3093LogicalResult cir::DerivedDataMemberOp::verify() {
3094 return verifyMemberPtrCast(getOperation(), getSrc(), getType());
3095}
3096
3097//===----------------------------------------------------------------------===//
3098// BaseMethodOp & DerivedMethodOp
3099//===----------------------------------------------------------------------===//
3100
3101LogicalResult cir::BaseMethodOp::verify() {
3102 return verifyMemberPtrCast(getOperation(), getSrc(), getType());
3103}
3104
3105LogicalResult cir::DerivedMethodOp::verify() {
3106 return verifyMemberPtrCast(getOperation(), getSrc(), getType());
3107}
3108
3109//===----------------------------------------------------------------------===//
3110// AwaitOp
3111//===----------------------------------------------------------------------===//
3112
3113void cir::AwaitOp::build(OpBuilder &builder, OperationState &result,
3114 cir::AwaitKind kind, BuilderCallbackRef readyBuilder,
3115 BuilderCallbackRef suspendBuilder,
3116 BuilderCallbackRef resumeBuilder) {
3117 result.addAttribute(getKindAttrName(result.name),
3118 cir::AwaitKindAttr::get(builder.getContext(), kind));
3119 {
3120 OpBuilder::InsertionGuard guard(builder);
3121 Region *readyRegion = result.addRegion();
3122 builder.createBlock(readyRegion);
3123 readyBuilder(builder, result.location);
3124 }
3125
3126 {
3127 OpBuilder::InsertionGuard guard(builder);
3128 Region *suspendRegion = result.addRegion();
3129 builder.createBlock(suspendRegion);
3130 suspendBuilder(builder, result.location);
3131 }
3132
3133 {
3134 OpBuilder::InsertionGuard guard(builder);
3135 Region *resumeRegion = result.addRegion();
3136 builder.createBlock(resumeRegion);
3137 resumeBuilder(builder, result.location);
3138 }
3139}
3140
3141void cir::AwaitOp::getSuccessorRegions(
3142 mlir::RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
3143 // If any index all the underlying regions branch back to the parent
3144 // operation.
3145 if (!point.isParent()) {
3146 regions.emplace_back(getOperation());
3147 return;
3148 }
3149
3150 // TODO: retrieve information from the promise and only push the
3151 // necessary ones. Example: `std::suspend_never` on initial or final
3152 // await's might allow suspend region to be skipped.
3153 regions.push_back(RegionSuccessor(&this->getReady()));
3154 regions.push_back(RegionSuccessor(&this->getSuspend()));
3155 regions.push_back(RegionSuccessor(&this->getResume()));
3156}
3157
3158mlir::ValueRange cir::AwaitOp::getSuccessorInputs(RegionSuccessor successor) {
3159 if (successor.isOperation())
3160 return getOperation()->getResults();
3161 if (successor == &getReady())
3162 return getReady().getArguments();
3163 if (successor == &getSuspend())
3164 return getSuspend().getArguments();
3165 if (successor == &getResume())
3166 return getResume().getArguments();
3167 llvm_unreachable("invalid region successor");
3168}
3169
3170LogicalResult cir::AwaitOp::verify() {
3171 if (!isa<ConditionOp>(this->getReady().back().getTerminator()))
3172 return emitOpError("ready region must end with cir.condition");
3173 return success();
3174}
3175
3176//===----------------------------------------------------------------------===//
3177// CoroBody
3178//===----------------------------------------------------------------------===//
3179
3180void cir::CoroBodyOp::getSuccessorRegions(
3181 mlir::RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
3182 if (!point.isParent()) {
3183 regions.emplace_back(getOperation());
3184 return;
3185 }
3186
3187 regions.push_back(RegionSuccessor(&getBody()));
3188}
3189
3190mlir::ValueRange
3191cir::CoroBodyOp::getSuccessorInputs(RegionSuccessor successor) {
3192 return ValueRange();
3193}
3194
3195LogicalResult cir::CoroBodyOp::verify() {
3196 if (!getOperation()->getParentOfType<FuncOp>().getCoroutine())
3197 return emitOpError("enclosing function must be a coroutine");
3198 return success();
3199}
3200
3201void cir::CoroBodyOp::build(OpBuilder &builder, OperationState &result,
3202 BuilderCallbackRef bodyBuilder) {
3203 assert(bodyBuilder &&
3204 "the builder callback for 'CoroBodyOp' must be present");
3205 OpBuilder::InsertionGuard guard(builder);
3206
3207 Region *bodyRegion = result.addRegion();
3208 builder.createBlock(bodyRegion);
3209 bodyBuilder(builder, result.location);
3210}
3211
3212//===----------------------------------------------------------------------===//
3213// CopyOp Definitions
3214//===----------------------------------------------------------------------===//
3215
3216LogicalResult cir::CopyOp::verify() {
3217 // A data layout is required for us to know the number of bytes to be copied.
3218 if (!getType().getPointee().hasTrait<DataLayoutTypeInterface::Trait>())
3219 return emitError() << "missing data layout for pointee type";
3220
3221 if (getSkipTailPadding() &&
3222 !mlir::isa<cir::RecordType>(getType().getPointee()))
3223 return emitError()
3224 << "skip_tail_padding is only valid for record pointee types";
3225
3226 return mlir::success();
3227}
3228
3229//===----------------------------------------------------------------------===//
3230// GetRuntimeMemberOp Definitions
3231//===----------------------------------------------------------------------===//
3232
3233LogicalResult cir::GetRuntimeMemberOp::verify() {
3234 cir::DataMemberType memberPtrTy = getMember().getType();
3235
3236 if (getAddr().getType().getPointee() != memberPtrTy.getClassTy())
3237 return emitError() << "record type does not match the member pointer type";
3238 if (getType().getPointee() != memberPtrTy.getMemberTy())
3239 return emitError() << "result type does not match the member pointer type";
3240 return mlir::success();
3241}
3242
3243//===----------------------------------------------------------------------===//
3244// GetMethodOp Definitions
3245//===----------------------------------------------------------------------===//
3246
3247LogicalResult cir::GetMethodOp::verify() {
3248 cir::MethodType methodTy = getMethod().getType();
3249
3250 // Assume objectTy is !cir.ptr<!T>
3251 cir::PointerType objectPtrTy = getObject().getType();
3252 mlir::Type objectTy = objectPtrTy.getPointee();
3253
3254 if (methodTy.getClassTy() != objectTy)
3255 return emitError() << "method class type and object type do not match";
3256
3257 // Assume methodFuncTy is !cir.func<!Ret (!Args)>
3258 auto calleeTy = mlir::cast<cir::FuncType>(getCallee().getType().getPointee());
3259 cir::FuncType methodFuncTy = methodTy.getMemberFuncTy();
3260
3261 // We verify at here that calleeTy is !cir.func<!Ret (!cir.ptr<!void>, !Args)>
3262 // Note that the first parameter type of the callee is !cir.ptr<!void> instead
3263 // of !cir.ptr<!T> because the "this" pointer may be adjusted before calling
3264 // the callee.
3265
3266 if (methodFuncTy.getReturnType() != calleeTy.getReturnType())
3267 return emitError()
3268 << "method return type and callee return type do not match";
3269
3270 llvm::ArrayRef<mlir::Type> calleeArgsTy = calleeTy.getInputs();
3271 llvm::ArrayRef<mlir::Type> methodFuncArgsTy = methodFuncTy.getInputs();
3272
3273 if (calleeArgsTy.empty())
3274 return emitError() << "callee parameter list lacks receiver object ptr";
3275
3276 auto calleeThisArgPtrTy = mlir::dyn_cast<cir::PointerType>(calleeArgsTy[0]);
3277 if (!calleeThisArgPtrTy ||
3278 !mlir::isa<cir::VoidType>(calleeThisArgPtrTy.getPointee())) {
3279 return emitError()
3280 << "the first parameter of callee must be a void pointer";
3281 }
3282
3283 if (calleeArgsTy.size() != methodFuncArgsTy.size())
3284 return emitError() << "callee and method parameter counts do not match";
3285
3286 if (calleeArgsTy.size() > 1 &&
3287 calleeArgsTy.slice(1) != methodFuncArgsTy.slice(1))
3288 return emitError()
3289 << "callee parameters and method parameters do not match";
3290
3291 return mlir::success();
3292}
3293
3294//===----------------------------------------------------------------------===//
3295// GetMemberOp Definitions
3296//===----------------------------------------------------------------------===//
3297
3298LogicalResult cir::GetMemberOp::verify() {
3299 const auto recordTy = dyn_cast<RecordType>(getAddrTy().getPointee());
3300 if (!recordTy)
3301 return emitError() << "expected pointer to a record type";
3302
3303 if (recordTy.getMembers().size() <= getIndex())
3304 return emitError() << "member index out of bounds";
3305
3306 if (recordTy.getMembers()[getIndex()] != getType().getPointee())
3307 return emitError() << "member type mismatch";
3308
3309 return mlir::success();
3310}
3311
3312//===----------------------------------------------------------------------===//
3313// ExtractMemberOp Definitions
3314//===----------------------------------------------------------------------===//
3315
3316LogicalResult cir::ExtractMemberOp::verify() {
3317 if (mlir::isa<cir::UnionType>(getRecord().getType()))
3318 return emitError()
3319 << "cir.extract_member currently does not support unions";
3320 auto structTy = mlir::cast<cir::StructType>(getRecord().getType());
3321 if (structTy.getMembers().size() <= getIndex())
3322 return emitError() << "member index out of bounds";
3323 if (structTy.getMembers()[getIndex()] != getType())
3324 return emitError() << "member type mismatch";
3325 return mlir::success();
3326}
3327
3328//===----------------------------------------------------------------------===//
3329// InsertMemberOp Definitions
3330//===----------------------------------------------------------------------===//
3331
3332LogicalResult cir::InsertMemberOp::verify() {
3333 if (mlir::isa<cir::UnionType>(getRecord().getType()))
3334 return emitError() << "cir.insert_member currently does not support unions";
3335 auto structTy = mlir::cast<cir::StructType>(getRecord().getType());
3336 if (structTy.getMembers().size() <= getIndex())
3337 return emitError() << "member index out of bounds";
3338 if (structTy.getMembers()[getIndex()] != getValue().getType())
3339 return emitError() << "member type mismatch";
3340 // The op trait already checks that the types of $result and $record match.
3341 return mlir::success();
3342}
3343
3344//===----------------------------------------------------------------------===//
3345// VecCreateOp
3346//===----------------------------------------------------------------------===//
3347
3348OpFoldResult cir::VecCreateOp::fold(FoldAdaptor adaptor) {
3349 if (llvm::any_of(getElements(), [](mlir::Value value) {
3350 return !value.getDefiningOp<cir::ConstantOp>();
3351 }))
3352 return {};
3353
3354 return cir::ConstVectorAttr::get(
3355 getType(), mlir::ArrayAttr::get(getContext(), adaptor.getElements()));
3356}
3357
3358LogicalResult cir::VecCreateOp::verify() {
3359 // Verify that the number of arguments matches the number of elements in the
3360 // vector, and that the type of all the arguments matches the type of the
3361 // elements in the vector.
3362 const cir::VectorType vecTy = getType();
3363 if (getElements().size() != vecTy.getSize()) {
3364 return emitOpError() << "operand count of " << getElements().size()
3365 << " doesn't match vector type " << vecTy
3366 << " element count of " << vecTy.getSize();
3367 }
3368
3369 const mlir::Type elementType = vecTy.getElementType();
3370 for (const mlir::Value element : getElements()) {
3371 if (element.getType() != elementType) {
3372 return emitOpError() << "operand type " << element.getType()
3373 << " doesn't match vector element type "
3374 << elementType;
3375 }
3376 }
3377
3378 return success();
3379}
3380
3381//===----------------------------------------------------------------------===//
3382// VecExtractOp
3383//===----------------------------------------------------------------------===//
3384
3385OpFoldResult cir::VecExtractOp::fold(FoldAdaptor adaptor) {
3386 const auto vectorAttr =
3387 llvm::dyn_cast_if_present<cir::ConstVectorAttr>(adaptor.getVec());
3388 if (!vectorAttr)
3389 return {};
3390
3391 const auto indexAttr =
3392 llvm::dyn_cast_if_present<cir::IntAttr>(adaptor.getIndex());
3393 if (!indexAttr)
3394 return {};
3395
3396 const mlir::ArrayAttr elements = vectorAttr.getElts();
3397 const uint64_t index = indexAttr.getUInt();
3398 if (index >= elements.size())
3399 return {};
3400
3401 return elements[index];
3402}
3403
3404//===----------------------------------------------------------------------===//
3405// VecCmpOp
3406//===----------------------------------------------------------------------===//
3407
3408OpFoldResult cir::VecCmpOp::fold(FoldAdaptor adaptor) {
3409 auto lhsVecAttr =
3410 mlir::dyn_cast_if_present<cir::ConstVectorAttr>(adaptor.getLhs());
3411 auto rhsVecAttr =
3412 mlir::dyn_cast_if_present<cir::ConstVectorAttr>(adaptor.getRhs());
3413 if (!lhsVecAttr || !rhsVecAttr)
3414 return {};
3415
3416 mlir::Type inputElemTy =
3417 mlir::cast<cir::VectorType>(lhsVecAttr.getType()).getElementType();
3418 if (!isAnyIntegerOrFloatingPointType(inputElemTy))
3419 return {};
3420
3421 cir::CmpOpKind opKind = adaptor.getKind();
3422 mlir::ArrayAttr lhsVecElhs = lhsVecAttr.getElts();
3423 mlir::ArrayAttr rhsVecElhs = rhsVecAttr.getElts();
3424 uint64_t vecSize = lhsVecElhs.size();
3425
3426 SmallVector<mlir::Attribute, 16> elements(vecSize);
3427 bool isIntAttr = vecSize && mlir::isa<cir::IntAttr>(lhsVecElhs[0]);
3428 for (uint64_t i = 0; i < vecSize; i++) {
3429 mlir::Attribute lhsAttr = lhsVecElhs[i];
3430 mlir::Attribute rhsAttr = rhsVecElhs[i];
3431 int cmpResult = 0;
3432 switch (opKind) {
3433 case cir::CmpOpKind::lt: {
3434 if (isIntAttr) {
3435 cmpResult = mlir::cast<cir::IntAttr>(lhsAttr).getSInt() <
3436 mlir::cast<cir::IntAttr>(rhsAttr).getSInt();
3437 } else {
3438 cmpResult = mlir::cast<cir::FPAttr>(lhsAttr).getValue() <
3439 mlir::cast<cir::FPAttr>(rhsAttr).getValue();
3440 }
3441 break;
3442 }
3443 case cir::CmpOpKind::le: {
3444 if (isIntAttr) {
3445 cmpResult = mlir::cast<cir::IntAttr>(lhsAttr).getSInt() <=
3446 mlir::cast<cir::IntAttr>(rhsAttr).getSInt();
3447 } else {
3448 cmpResult = mlir::cast<cir::FPAttr>(lhsAttr).getValue() <=
3449 mlir::cast<cir::FPAttr>(rhsAttr).getValue();
3450 }
3451 break;
3452 }
3453 case cir::CmpOpKind::gt: {
3454 if (isIntAttr) {
3455 cmpResult = mlir::cast<cir::IntAttr>(lhsAttr).getSInt() >
3456 mlir::cast<cir::IntAttr>(rhsAttr).getSInt();
3457 } else {
3458 cmpResult = mlir::cast<cir::FPAttr>(lhsAttr).getValue() >
3459 mlir::cast<cir::FPAttr>(rhsAttr).getValue();
3460 }
3461 break;
3462 }
3463 case cir::CmpOpKind::ge: {
3464 if (isIntAttr) {
3465 cmpResult = mlir::cast<cir::IntAttr>(lhsAttr).getSInt() >=
3466 mlir::cast<cir::IntAttr>(rhsAttr).getSInt();
3467 } else {
3468 cmpResult = mlir::cast<cir::FPAttr>(lhsAttr).getValue() >=
3469 mlir::cast<cir::FPAttr>(rhsAttr).getValue();
3470 }
3471 break;
3472 }
3473 case cir::CmpOpKind::eq: {
3474 if (isIntAttr) {
3475 cmpResult = mlir::cast<cir::IntAttr>(lhsAttr).getSInt() ==
3476 mlir::cast<cir::IntAttr>(rhsAttr).getSInt();
3477 } else {
3478 cmpResult = mlir::cast<cir::FPAttr>(lhsAttr).getValue() ==
3479 mlir::cast<cir::FPAttr>(rhsAttr).getValue();
3480 }
3481 break;
3482 }
3483 case cir::CmpOpKind::ne: {
3484 if (isIntAttr) {
3485 cmpResult = mlir::cast<cir::IntAttr>(lhsAttr).getSInt() !=
3486 mlir::cast<cir::IntAttr>(rhsAttr).getSInt();
3487 } else {
3488 cmpResult = mlir::cast<cir::FPAttr>(lhsAttr).getValue() !=
3489 mlir::cast<cir::FPAttr>(rhsAttr).getValue();
3490 }
3491 break;
3492 }
3493 case cir::CmpOpKind::one: {
3494 llvm::APFloat::cmpResult cr =
3495 mlir::cast<cir::FPAttr>(lhsAttr).getValue().compare(
3496 mlir::cast<cir::FPAttr>(rhsAttr).getValue());
3497 cmpResult =
3498 cr != llvm::APFloat::cmpUnordered && cr != llvm::APFloat::cmpEqual;
3499 break;
3500 }
3501 case cir::CmpOpKind::uno: {
3502 cmpResult = mlir::cast<cir::FPAttr>(lhsAttr).getValue().compare(
3503 mlir::cast<cir::FPAttr>(rhsAttr).getValue()) ==
3504 llvm::APFloat::cmpUnordered;
3505 break;
3506 }
3507 }
3508
3509 elements[i] = cir::IntAttr::get(getType().getElementType(), cmpResult);
3510 }
3511
3512 return cir::ConstVectorAttr::get(
3513 getType(), mlir::ArrayAttr::get(getContext(), elements));
3514}
3515
3516//===----------------------------------------------------------------------===//
3517// VecShuffleOp
3518//===----------------------------------------------------------------------===//
3519
3520OpFoldResult cir::VecShuffleOp::fold(FoldAdaptor adaptor) {
3521 auto vec1Attr =
3522 mlir::dyn_cast_if_present<cir::ConstVectorAttr>(adaptor.getVec1());
3523 auto vec2Attr =
3524 mlir::dyn_cast_if_present<cir::ConstVectorAttr>(adaptor.getVec2());
3525 if (!vec1Attr || !vec2Attr)
3526 return {};
3527
3528 mlir::Type vec1ElemTy =
3529 mlir::cast<cir::VectorType>(vec1Attr.getType()).getElementType();
3530
3531 mlir::ArrayAttr vec1Elts = vec1Attr.getElts();
3532 mlir::ArrayAttr vec2Elts = vec2Attr.getElts();
3533 mlir::ArrayAttr indicesElts = adaptor.getIndices();
3534
3536 elements.reserve(indicesElts.size());
3537
3538 uint64_t vec1Size = vec1Elts.size();
3539 for (const auto &idxAttr : indicesElts.getAsRange<cir::IntAttr>()) {
3540 if (idxAttr.getSInt() == -1) {
3541 elements.push_back(cir::UndefAttr::get(vec1ElemTy));
3542 continue;
3543 }
3544
3545 uint64_t idxValue = idxAttr.getUInt();
3546 elements.push_back(idxValue < vec1Size ? vec1Elts[idxValue]
3547 : vec2Elts[idxValue - vec1Size]);
3548 }
3549
3550 return cir::ConstVectorAttr::get(
3551 getType(), mlir::ArrayAttr::get(getContext(), elements));
3552}
3553
3554LogicalResult cir::VecShuffleOp::verify() {
3555 // The number of elements in the indices array must match the number of
3556 // elements in the result type.
3557 if (getIndices().size() != getResult().getType().getSize()) {
3558 return emitOpError() << ": the number of elements in " << getIndices()
3559 << " and " << getResult().getType() << " don't match";
3560 }
3561
3562 // The element types of the two input vectors and of the result type must
3563 // match.
3564 if (getVec1().getType().getElementType() !=
3565 getResult().getType().getElementType()) {
3566 return emitOpError() << ": element types of " << getVec1().getType()
3567 << " and " << getResult().getType() << " don't match";
3568 }
3569
3570 const uint64_t maxValidIndex =
3571 getVec1().getType().getSize() + getVec2().getType().getSize() - 1;
3572 if (llvm::any_of(
3573 getIndices().getAsRange<cir::IntAttr>(), [&](cir::IntAttr idxAttr) {
3574 return idxAttr.getSInt() != -1 && idxAttr.getUInt() > maxValidIndex;
3575 })) {
3576 return emitOpError() << ": index for __builtin_shufflevector must be "
3577 "less than the total number of vector elements";
3578 }
3579 return success();
3580}
3581
3582//===----------------------------------------------------------------------===//
3583// VecShuffleDynamicOp
3584//===----------------------------------------------------------------------===//
3585
3586OpFoldResult cir::VecShuffleDynamicOp::fold(FoldAdaptor adaptor) {
3587 mlir::Attribute vec = adaptor.getVec();
3588 mlir::Attribute indices = adaptor.getIndices();
3589 if (mlir::isa_and_nonnull<cir::ConstVectorAttr>(vec) &&
3590 mlir::isa_and_nonnull<cir::ConstVectorAttr>(indices)) {
3591 auto vecAttr = mlir::cast<cir::ConstVectorAttr>(vec);
3592 auto indicesAttr = mlir::cast<cir::ConstVectorAttr>(indices);
3593
3594 mlir::ArrayAttr vecElts = vecAttr.getElts();
3595 mlir::ArrayAttr indicesElts = indicesAttr.getElts();
3596
3597 const uint64_t numElements = vecElts.size();
3598
3600 elements.reserve(numElements);
3601
3602 const uint64_t maskBits = llvm::NextPowerOf2(numElements - 1) - 1;
3603 for (const auto &idxAttr : indicesElts.getAsRange<cir::IntAttr>()) {
3604 uint64_t idxValue = idxAttr.getUInt();
3605 uint64_t newIdx = idxValue & maskBits;
3606 elements.push_back(vecElts[newIdx]);
3607 }
3608
3609 return cir::ConstVectorAttr::get(
3610 getType(), mlir::ArrayAttr::get(getContext(), elements));
3611 }
3612
3613 return {};
3614}
3615
3616LogicalResult cir::VecShuffleDynamicOp::verify() {
3617 // The number of elements in the two input vectors must match.
3618 if (getVec().getType().getSize() !=
3619 mlir::cast<cir::VectorType>(getIndices().getType()).getSize()) {
3620 return emitOpError() << ": the number of elements in " << getVec().getType()
3621 << " and " << getIndices().getType() << " don't match";
3622 }
3623 return success();
3624}
3625
3626//===----------------------------------------------------------------------===//
3627// VecTernaryOp
3628//===----------------------------------------------------------------------===//
3629
3630LogicalResult cir::VecTernaryOp::verify() {
3631 // Verify that the condition operand has the same number of elements as the
3632 // other operands. (The automatic verification already checked that all
3633 // operands are vector types and that the second and third operands are the
3634 // same type.)
3635 if (getCond().getType().getSize() != getLhs().getType().getSize()) {
3636 return emitOpError() << ": the number of elements in "
3637 << getCond().getType() << " and " << getLhs().getType()
3638 << " don't match";
3639 }
3640 return success();
3641}
3642
3643OpFoldResult cir::VecTernaryOp::fold(FoldAdaptor adaptor) {
3644 mlir::Attribute cond = adaptor.getCond();
3645 mlir::Attribute lhs = adaptor.getLhs();
3646 mlir::Attribute rhs = adaptor.getRhs();
3647
3648 if (!mlir::isa_and_nonnull<cir::ConstVectorAttr>(cond) ||
3649 !mlir::isa_and_nonnull<cir::ConstVectorAttr>(lhs) ||
3650 !mlir::isa_and_nonnull<cir::ConstVectorAttr>(rhs))
3651 return {};
3652 auto condVec = mlir::cast<cir::ConstVectorAttr>(cond);
3653 auto lhsVec = mlir::cast<cir::ConstVectorAttr>(lhs);
3654 auto rhsVec = mlir::cast<cir::ConstVectorAttr>(rhs);
3655
3656 mlir::ArrayAttr condElts = condVec.getElts();
3657
3659 elements.reserve(condElts.size());
3660
3661 for (const auto &[idx, condAttr] :
3662 llvm::enumerate(condElts.getAsRange<cir::IntAttr>())) {
3663 if (condAttr.getSInt()) {
3664 elements.push_back(lhsVec.getElts()[idx]);
3665 } else {
3666 elements.push_back(rhsVec.getElts()[idx]);
3667 }
3668 }
3669
3670 cir::VectorType vecTy = getLhs().getType();
3671 return cir::ConstVectorAttr::get(
3672 vecTy, mlir::ArrayAttr::get(getContext(), elements));
3673}
3674
3675//===----------------------------------------------------------------------===//
3676// ComplexCreateOp
3677//===----------------------------------------------------------------------===//
3678
3679LogicalResult cir::ComplexCreateOp::verify() {
3680 if (getType().getElementType() != getReal().getType()) {
3681 emitOpError()
3682 << "operand type of cir.complex.create does not match its result type";
3683 return failure();
3684 }
3685
3686 return success();
3687}
3688
3689OpFoldResult cir::ComplexCreateOp::fold(FoldAdaptor adaptor) {
3690 mlir::Attribute real = adaptor.getReal();
3691 mlir::Attribute imag = adaptor.getImag();
3692 if (!real || !imag)
3693 return {};
3694
3695 // When both of real and imag are constants, we can fold the operation into an
3696 // `#cir.const_complex` operation.
3697 auto realAttr = mlir::cast<mlir::TypedAttr>(real);
3698 auto imagAttr = mlir::cast<mlir::TypedAttr>(imag);
3699 return cir::ConstComplexAttr::get(realAttr, imagAttr);
3700}
3701
3702//===----------------------------------------------------------------------===//
3703// ComplexRealOp
3704//===----------------------------------------------------------------------===//
3705
3706LogicalResult cir::ComplexRealOp::verify() {
3707 mlir::Type operandTy = getOperand().getType();
3708 if (auto complexOperandTy = mlir::dyn_cast<cir::ComplexType>(operandTy))
3709 operandTy = complexOperandTy.getElementType();
3710
3711 if (getType() != operandTy) {
3712 emitOpError() << ": result type does not match operand type";
3713 return failure();
3714 }
3715
3716 return success();
3717}
3718
3719OpFoldResult cir::ComplexRealOp::fold(FoldAdaptor adaptor) {
3720 if (!mlir::isa<cir::ComplexType>(getOperand().getType()))
3721 return nullptr;
3722
3723 if (auto complexCreateOp = getOperand().getDefiningOp<cir::ComplexCreateOp>())
3724 return complexCreateOp.getOperand(0);
3725
3726 auto complex =
3727 mlir::cast_if_present<cir::ConstComplexAttr>(adaptor.getOperand());
3728 return complex ? complex.getReal() : nullptr;
3729}
3730
3731//===----------------------------------------------------------------------===//
3732// ComplexImagOp
3733//===----------------------------------------------------------------------===//
3734
3735LogicalResult cir::ComplexImagOp::verify() {
3736 mlir::Type operandTy = getOperand().getType();
3737 if (auto complexOperandTy = mlir::dyn_cast<cir::ComplexType>(operandTy))
3738 operandTy = complexOperandTy.getElementType();
3739
3740 if (getType() != operandTy) {
3741 emitOpError() << ": result type does not match operand type";
3742 return failure();
3743 }
3744
3745 return success();
3746}
3747
3748OpFoldResult cir::ComplexImagOp::fold(FoldAdaptor adaptor) {
3749 if (!mlir::isa<cir::ComplexType>(getOperand().getType()))
3750 return nullptr;
3751
3752 if (auto complexCreateOp = getOperand().getDefiningOp<cir::ComplexCreateOp>())
3753 return complexCreateOp.getOperand(1);
3754
3755 auto complex =
3756 mlir::cast_if_present<cir::ConstComplexAttr>(adaptor.getOperand());
3757 return complex ? complex.getImag() : nullptr;
3758}
3759
3760//===----------------------------------------------------------------------===//
3761// ComplexRealPtrOp
3762//===----------------------------------------------------------------------===//
3763
3764LogicalResult cir::ComplexRealPtrOp::verify() {
3765 mlir::Type resultPointeeTy = getType().getPointee();
3766 cir::PointerType operandPtrTy = getOperand().getType();
3767 auto operandPointeeTy =
3768 mlir::cast<cir::ComplexType>(operandPtrTy.getPointee());
3769
3770 if (resultPointeeTy != operandPointeeTy.getElementType()) {
3771 return emitOpError() << ": result type does not match operand type";
3772 }
3773
3774 return success();
3775}
3776
3777//===----------------------------------------------------------------------===//
3778// ComplexImagPtrOp
3779//===----------------------------------------------------------------------===//
3780
3781LogicalResult cir::ComplexImagPtrOp::verify() {
3782 mlir::Type resultPointeeTy = getType().getPointee();
3783 cir::PointerType operandPtrTy = getOperand().getType();
3784 auto operandPointeeTy =
3785 mlir::cast<cir::ComplexType>(operandPtrTy.getPointee());
3786
3787 if (resultPointeeTy != operandPointeeTy.getElementType()) {
3788 return emitOpError()
3789 << "cir.complex.imag_ptr result type does not match operand type";
3790 }
3791 return success();
3792}
3793
3794//===----------------------------------------------------------------------===//
3795// Bit manipulation operations
3796//===----------------------------------------------------------------------===//
3797
3798static OpFoldResult
3799foldUnaryBitOp(mlir::Attribute inputAttr,
3800 llvm::function_ref<llvm::APInt(const llvm::APInt &)> func,
3801 bool poisonZero = false) {
3802 if (mlir::isa_and_present<cir::PoisonAttr>(inputAttr)) {
3803 // Propagate poison value
3804 return inputAttr;
3805 }
3806
3807 auto input = mlir::dyn_cast_if_present<IntAttr>(inputAttr);
3808 if (!input)
3809 return nullptr;
3810
3811 llvm::APInt inputValue = input.getValue();
3812 if (poisonZero && inputValue.isZero())
3813 return cir::PoisonAttr::get(input.getType());
3814
3815 llvm::APInt resultValue = func(inputValue);
3816 return IntAttr::get(input.getType(), resultValue);
3817}
3818
3819OpFoldResult BitClrsbOp::fold(FoldAdaptor adaptor) {
3820 return foldUnaryBitOp(adaptor.getInput(), [](const llvm::APInt &inputValue) {
3821 unsigned resultValue =
3822 inputValue.getBitWidth() - inputValue.getSignificantBits();
3823 return llvm::APInt(inputValue.getBitWidth(), resultValue);
3824 });
3825}
3826
3827OpFoldResult BitClzOp::fold(FoldAdaptor adaptor) {
3828 return foldUnaryBitOp(
3829 adaptor.getInput(),
3830 [](const llvm::APInt &inputValue) {
3831 unsigned resultValue = inputValue.countLeadingZeros();
3832 return llvm::APInt(inputValue.getBitWidth(), resultValue);
3833 },
3834 getPoisonZero());
3835}
3836
3837OpFoldResult BitCtzOp::fold(FoldAdaptor adaptor) {
3838 return foldUnaryBitOp(
3839 adaptor.getInput(),
3840 [](const llvm::APInt &inputValue) {
3841 return llvm::APInt(inputValue.getBitWidth(),
3842 inputValue.countTrailingZeros());
3843 },
3844 getPoisonZero());
3845}
3846
3847OpFoldResult BitFfsOp::fold(FoldAdaptor adaptor) {
3848 return foldUnaryBitOp(adaptor.getInput(), [](const llvm::APInt &inputValue) {
3849 unsigned trailingZeros = inputValue.countTrailingZeros();
3850 unsigned result =
3851 trailingZeros == inputValue.getBitWidth() ? 0 : trailingZeros + 1;
3852 return llvm::APInt(inputValue.getBitWidth(), result);
3853 });
3854}
3855
3856OpFoldResult BitParityOp::fold(FoldAdaptor adaptor) {
3857 return foldUnaryBitOp(adaptor.getInput(), [](const llvm::APInt &inputValue) {
3858 return llvm::APInt(inputValue.getBitWidth(), inputValue.popcount() % 2);
3859 });
3860}
3861
3862OpFoldResult BitPopcountOp::fold(FoldAdaptor adaptor) {
3863 return foldUnaryBitOp(adaptor.getInput(), [](const llvm::APInt &inputValue) {
3864 return llvm::APInt(inputValue.getBitWidth(), inputValue.popcount());
3865 });
3866}
3867
3868OpFoldResult BitReverseOp::fold(FoldAdaptor adaptor) {
3869 return foldUnaryBitOp(adaptor.getInput(), [](const llvm::APInt &inputValue) {
3870 return inputValue.reverseBits();
3871 });
3872}
3873
3874OpFoldResult ByteSwapOp::fold(FoldAdaptor adaptor) {
3875 return foldUnaryBitOp(adaptor.getInput(), [](const llvm::APInt &inputValue) {
3876 return inputValue.byteSwap();
3877 });
3878}
3879
3880OpFoldResult RotateOp::fold(FoldAdaptor adaptor) {
3881 if (mlir::isa_and_present<cir::PoisonAttr>(adaptor.getInput()) ||
3882 mlir::isa_and_present<cir::PoisonAttr>(adaptor.getAmount())) {
3883 // Propagate poison values
3884 return cir::PoisonAttr::get(getType());
3885 }
3886
3887 auto input = mlir::dyn_cast_if_present<IntAttr>(adaptor.getInput());
3888 auto amount = mlir::dyn_cast_if_present<IntAttr>(adaptor.getAmount());
3889 if (!input && !amount)
3890 return nullptr;
3891
3892 // We could fold cir.rotate even if one of its two operands is not a constant:
3893 // - `cir.rotate left/right %0, 0` could be folded into just %0 even if %0
3894 // is not a constant.
3895 // - `cir.rotate left/right 0/0b111...111, %0` could be folded into 0 or
3896 // 0b111...111 even if %0 is not a constant.
3897
3898 llvm::APInt inputValue;
3899 if (input) {
3900 inputValue = input.getValue();
3901 if (inputValue.isZero() || inputValue.isAllOnes()) {
3902 // An input value of all 0s or all 1s will not change after rotation
3903 return input;
3904 }
3905 }
3906
3907 uint64_t amountValue;
3908 if (amount) {
3909 amountValue = amount.getValue().urem(getInput().getType().getWidth());
3910 if (amountValue == 0) {
3911 // A shift amount of 0 will not change the input value
3912 return getInput();
3913 }
3914 }
3915
3916 if (!input || !amount)
3917 return nullptr;
3918
3919 assert(inputValue.getBitWidth() == getInput().getType().getWidth() &&
3920 "input value must have the same bit width as the input type");
3921
3922 llvm::APInt resultValue;
3923 if (isRotateLeft())
3924 resultValue = inputValue.rotl(amountValue);
3925 else
3926 resultValue = inputValue.rotr(amountValue);
3927
3928 return IntAttr::get(input.getContext(), input.getType(), resultValue);
3929}
3930
3931//===----------------------------------------------------------------------===//
3932// InlineAsmOp
3933//===----------------------------------------------------------------------===//
3934
3935void cir::InlineAsmOp::print(OpAsmPrinter &p) {
3936 p << '(' << getAsmFlavor() << ", ";
3937 p.increaseIndent();
3938 p.printNewline();
3939
3940 llvm::SmallVector<std::string, 3> names{"out", "in", "in_out"};
3941 auto *nameIt = names.begin();
3942 auto *attrIt = getOperandAttrs().begin();
3943
3944 for (mlir::OperandRange ops : getAsmOperands()) {
3945 p << *nameIt << " = ";
3946
3947 p << '[';
3948 llvm::interleaveComma(llvm::make_range(ops.begin(), ops.end()), p,
3949 [&](Value value) {
3950 p.printOperand(value);
3951 p << " : " << value.getType();
3952 if (mlir::isa<mlir::UnitAttr>(*attrIt))
3953 p << " (maybe_memory)";
3954 attrIt++;
3955 });
3956 p << "],";
3957 p.printNewline();
3958 ++nameIt;
3959 }
3960
3961 p << "{";
3962 p.printString(getAsmString());
3963 p << " ";
3964 p.printString(getConstraints());
3965 p << "}";
3966 p.decreaseIndent();
3967 p << ')';
3968 if (getSideEffects())
3969 p << " side_effects";
3970
3971 std::array elidedAttrs{
3972 llvm::StringRef("asm_flavor"), llvm::StringRef("asm_string"),
3973 llvm::StringRef("constraints"), llvm::StringRef("operand_attrs"),
3974 llvm::StringRef("operands_segments"), llvm::StringRef("side_effects")};
3975 p.printOptionalAttrDict(getOperation()->getAttrs(), elidedAttrs);
3976
3977 if (auto v = getRes())
3978 p << " -> " << v.getType();
3979}
3980
3981void cir::InlineAsmOp::build(OpBuilder &odsBuilder, OperationState &odsState,
3982 ArrayRef<ValueRange> asmOperands,
3983 StringRef asmString, StringRef constraints,
3984 bool sideEffects, cir::AsmFlavor asmFlavor,
3985 ArrayRef<Attribute> operandAttrs) {
3986 // Set up the operands_segments for VariadicOfVariadic
3987 SmallVector<int32_t> segments;
3988 for (auto operandRange : asmOperands) {
3989 segments.push_back(operandRange.size());
3990 odsState.addOperands(operandRange);
3991 }
3992
3993 odsState.addAttribute(
3994 "operands_segments",
3995 DenseI32ArrayAttr::get(odsBuilder.getContext(), segments));
3996 odsState.addAttribute("asm_string", odsBuilder.getStringAttr(asmString));
3997 odsState.addAttribute("constraints", odsBuilder.getStringAttr(constraints));
3998 odsState.addAttribute("asm_flavor",
3999 AsmFlavorAttr::get(odsBuilder.getContext(), asmFlavor));
4000
4001 if (sideEffects)
4002 odsState.addAttribute("side_effects", odsBuilder.getUnitAttr());
4003
4004 odsState.addAttribute("operand_attrs", odsBuilder.getArrayAttr(operandAttrs));
4005}
4006
4007ParseResult cir::InlineAsmOp::parse(OpAsmParser &parser,
4008 OperationState &result) {
4010 llvm::SmallVector<int32_t> operandsGroupSizes;
4011 std::string asmString, constraints;
4012 Type resType;
4013 MLIRContext *ctxt = parser.getBuilder().getContext();
4014
4015 auto error = [&](const Twine &msg) -> LogicalResult {
4016 return parser.emitError(parser.getCurrentLocation(), msg);
4017 };
4018
4019 auto expected = [&](const std::string &c) {
4020 return error("expected '" + c + "'");
4021 };
4022
4023 if (parser.parseLParen().failed())
4024 return expected("(");
4025
4026 auto flavor = FieldParser<AsmFlavor, AsmFlavor>::parse(parser);
4027 if (failed(flavor))
4028 return error("Unknown AsmFlavor");
4029
4030 if (parser.parseComma().failed())
4031 return expected(",");
4032
4033 auto parseValue = [&](Value &v) {
4034 OpAsmParser::UnresolvedOperand op;
4035
4036 if (parser.parseOperand(op) || parser.parseColon())
4037 return error("can't parse operand");
4038
4039 Type typ;
4040 if (parser.parseType(typ).failed())
4041 return error("can't parse operand type");
4043 if (parser.resolveOperand(op, typ, tmp))
4044 return error("can't resolve operand");
4045 v = tmp[0];
4046 return mlir::success();
4047 };
4048
4049 auto parseOperands = [&](llvm::StringRef name) {
4050 if (parser.parseKeyword(name).failed())
4051 return error("expected " + name + " operands here");
4052 if (parser.parseEqual().failed())
4053 return expected("=");
4054 if (parser.parseLSquare().failed())
4055 return expected("[");
4056
4057 int size = 0;
4058 if (parser.parseOptionalRSquare().succeeded()) {
4059 operandsGroupSizes.push_back(size);
4060 if (parser.parseComma())
4061 return expected(",");
4062 return mlir::success();
4063 }
4064
4065 auto parseOperand = [&]() {
4066 Value val;
4067 if (parseValue(val).succeeded()) {
4068 result.operands.push_back(val);
4069 size++;
4070
4071 if (parser.parseOptionalLParen().failed()) {
4072 operandAttrs.push_back(mlir::DictionaryAttr::get(ctxt));
4073 return mlir::success();
4074 }
4075
4076 if (parser.parseKeyword("maybe_memory").succeeded()) {
4077 operandAttrs.push_back(mlir::UnitAttr::get(ctxt));
4078 if (parser.parseRParen())
4079 return expected(")");
4080 return mlir::success();
4081 } else {
4082 return expected("maybe_memory");
4083 }
4084 }
4085 return mlir::failure();
4086 };
4087
4088 if (parser.parseCommaSeparatedList(parseOperand).failed())
4089 return mlir::failure();
4090
4091 if (parser.parseRSquare().failed() || parser.parseComma().failed())
4092 return expected("]");
4093 operandsGroupSizes.push_back(size);
4094 return mlir::success();
4095 };
4096
4097 if (parseOperands("out").failed() || parseOperands("in").failed() ||
4098 parseOperands("in_out").failed())
4099 return error("failed to parse operands");
4100
4101 if (parser.parseLBrace())
4102 return expected("{");
4103 if (parser.parseString(&asmString))
4104 return error("asm string parsing failed");
4105 if (parser.parseString(&constraints))
4106 return error("constraints string parsing failed");
4107 if (parser.parseRBrace())
4108 return expected("}");
4109 if (parser.parseRParen())
4110 return expected(")");
4111
4112 if (parser.parseOptionalKeyword("side_effects").succeeded())
4113 result.attributes.set("side_effects", UnitAttr::get(ctxt));
4114
4115 if (parser.parseOptionalAttrDict(result.attributes).failed())
4116 return mlir::failure();
4117
4118 if (parser.parseOptionalArrow().succeeded() &&
4119 parser.parseType(resType).failed())
4120 return mlir::failure();
4121
4122 result.attributes.set("asm_flavor", AsmFlavorAttr::get(ctxt, *flavor));
4123 result.attributes.set("asm_string", StringAttr::get(ctxt, asmString));
4124 result.attributes.set("constraints", StringAttr::get(ctxt, constraints));
4125 result.attributes.set("operand_attrs", ArrayAttr::get(ctxt, operandAttrs));
4126 result.getOrAddProperties<InlineAsmOp::Properties>().operands_segments =
4127 parser.getBuilder().getDenseI32ArrayAttr(operandsGroupSizes);
4128 if (resType)
4129 result.addTypes(TypeRange{resType});
4130
4131 return mlir::success();
4132}
4133
4134//===----------------------------------------------------------------------===//
4135// ThrowOp / TryThrowOp
4136//===----------------------------------------------------------------------===//
4137
4138template <typename ThrowOpTy>
4139static mlir::LogicalResult verifyThrowOpImpl(ThrowOpTy op) {
4140 if (op.rethrows())
4141 return mlir::success();
4142
4143 if (op.getNumOperands() != 0) {
4144 if (op.getTypeInfo())
4145 return mlir::success();
4146 return op.emitOpError() << "'type_info' symbol attribute missing";
4147 }
4148
4149 return mlir::failure();
4150}
4151
4152mlir::LogicalResult cir::ThrowOp::verify() { return verifyThrowOpImpl(*this); }
4153
4154mlir::LogicalResult cir::TryThrowOp::verify() {
4155 return verifyThrowOpImpl(*this);
4156}
4157
4158//===----------------------------------------------------------------------===//
4159// AtomicFetchOp
4160//===----------------------------------------------------------------------===//
4161
4162LogicalResult cir::AtomicFetchOp::verify() {
4163 if (getBinop() != cir::AtomicFetchKind::Add &&
4164 getBinop() != cir::AtomicFetchKind::Sub &&
4165 getBinop() != cir::AtomicFetchKind::Max &&
4166 getBinop() != cir::AtomicFetchKind::Min &&
4167 !mlir::isa<cir::IntType>(getVal().getType()))
4168 return emitError("only atomic add, sub, max, and min operation could "
4169 "operate on floating-point values");
4170 return success();
4171}
4172
4173//===----------------------------------------------------------------------===//
4174// TypeInfoAttr
4175//===----------------------------------------------------------------------===//
4176
4177LogicalResult cir::TypeInfoAttr::verify(
4178 ::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError,
4179 ::mlir::Type type, ::mlir::ArrayAttr typeInfoData) {
4180
4181 if (cir::ConstRecordAttr::verify(emitError, type, typeInfoData).failed())
4182 return failure();
4183
4184 return success();
4185}
4186
4187//===----------------------------------------------------------------------===//
4188// TryOp
4189//===----------------------------------------------------------------------===//
4190
4191void cir::TryOp::getSuccessorRegions(
4192 mlir::RegionBranchPoint point,
4194 // The `try` and the `catchers` region branch back to the parent operation.
4195 if (!point.isParent()) {
4196 regions.emplace_back(getOperation());
4197 return;
4198 }
4199
4200 regions.push_back(mlir::RegionSuccessor(&getTryRegion()));
4201
4202 // TODO(CIR): If we know a target function never throws a specific type, we
4203 // can remove the catch handler.
4204 for (mlir::Region &handlerRegion : this->getHandlerRegions())
4205 regions.push_back(mlir::RegionSuccessor(&handlerRegion));
4206}
4207
4208mlir::ValueRange cir::TryOp::getSuccessorInputs(RegionSuccessor successor) {
4209 return successor.isOperation() ? ValueRange(getOperation()->getResults())
4210 : ValueRange();
4211}
4212
4213LogicalResult cir::TryOp::verify() {
4214 mlir::ArrayAttr handlerTypes = getHandlerTypes();
4215 if (!handlerTypes) {
4216 if (!getHandlerRegions().empty())
4217 return emitOpError(
4218 "handler regions must be empty when no handler types are present");
4219 return success();
4220 }
4221
4222 mlir::MutableArrayRef<mlir::Region> handlerRegions = getHandlerRegions();
4223
4224 // The parser and builder won't allow this to happen, but the loop below
4225 // relies on the sizes being the same, so we check it here.
4226 if (handlerRegions.size() != handlerTypes.size())
4227 return emitOpError(
4228 "number of handler regions and handler types must match");
4229
4230 for (const auto &[typeAttr, handlerRegion] :
4231 llvm::zip(handlerTypes, handlerRegions)) {
4232 // Verify that handler regions have a !cir.eh_token block argument.
4233 mlir::Block &entryBlock = handlerRegion.front();
4234 if (entryBlock.getNumArguments() != 1 ||
4235 !mlir::isa<cir::EhTokenType>(entryBlock.getArgument(0).getType()))
4236 return emitOpError(
4237 "handler region must have a single '!cir.eh_token' argument");
4238
4239 // The unwind region does not require a cir.begin_catch.
4240 if (mlir::isa<cir::UnwindAttr>(typeAttr))
4241 continue;
4242
4243 // A catch handler region must start with cir.begin_catch, optionally
4244 // preceded by a single cir.construct_catch_param that performs any
4245 // pre-begin_catch initialization for the catch parameter.
4246 if (entryBlock.empty())
4247 return emitOpError("catch handler region must not be empty");
4248 mlir::Operation *firstOp = &entryBlock.front();
4249 if (mlir::isa_and_present<cir::ConstructCatchParamOp>(firstOp))
4250 firstOp = firstOp->getNextNode();
4251 if (!firstOp || !mlir::isa<cir::BeginCatchOp>(firstOp))
4252 return emitOpError(
4253 "catch handler region must start with 'cir.begin_catch'");
4254 }
4255
4256 return success();
4257}
4258
4259static void
4260printTryHandlerRegions(mlir::OpAsmPrinter &printer, cir::TryOp op,
4261 mlir::MutableArrayRef<mlir::Region> handlerRegions,
4262 mlir::ArrayAttr handlerTypes) {
4263 if (!handlerTypes)
4264 return;
4265
4266 for (const auto [typeIdx, typeAttr] : llvm::enumerate(handlerTypes)) {
4267 if (typeIdx)
4268 printer << " ";
4269
4270 if (mlir::isa<cir::CatchAllAttr>(typeAttr)) {
4271 printer << "catch all ";
4272 } else if (mlir::isa<cir::UnwindAttr>(typeAttr)) {
4273 printer << "unwind ";
4274 } else {
4275 printer << "catch [type ";
4276 printer.printAttribute(typeAttr);
4277 printer << "] ";
4278 }
4279
4280 // Print the handler region's !cir.eh_token block argument.
4281 mlir::Region &region = handlerRegions[typeIdx];
4282 if (!region.empty() && region.front().getNumArguments() > 0) {
4283 printer << "(";
4284 printer.printRegionArgument(region.front().getArgument(0));
4285 printer << ") ";
4286 }
4287
4288 printer.printRegion(region,
4289 /*printEntryBLockArgs=*/false,
4290 /*printBlockTerminators=*/true);
4291 }
4292}
4293
4294static mlir::ParseResult parseTryHandlerRegions(
4295 mlir::OpAsmParser &parser,
4296 llvm::SmallVectorImpl<std::unique_ptr<mlir::Region>> &handlerRegions,
4297 mlir::ArrayAttr &handlerTypes) {
4298
4299 auto parseCheckedCatcherRegion = [&]() -> mlir::ParseResult {
4300 handlerRegions.emplace_back(new mlir::Region);
4301
4302 mlir::Region &currRegion = *handlerRegions.back();
4303
4304 // Parse the required region argument: (%eh_token : !cir.eh_token)
4306 if (parser.parseLParen())
4307 return failure();
4308 mlir::OpAsmParser::Argument arg;
4309 if (parser.parseArgument(arg, /*allowType=*/true))
4310 return failure();
4311 regionArgs.push_back(arg);
4312 if (parser.parseRParen())
4313 return failure();
4314
4315 mlir::SMLoc regionLoc = parser.getCurrentLocation();
4316 if (parser.parseRegion(currRegion, regionArgs)) {
4317 handlerRegions.clear();
4318 return failure();
4319 }
4320
4321 if (currRegion.empty())
4322 return parser.emitError(regionLoc, "handler region shall not be empty");
4323
4324 if (!(currRegion.back().mightHaveTerminator() &&
4325 currRegion.back().getTerminator()))
4326 return parser.emitError(
4327 regionLoc, "blocks are expected to be explicitly terminated");
4328
4329 return success();
4330 };
4331
4332 bool hasCatchAll = false;
4334 while (parser.parseOptionalKeyword("catch").succeeded()) {
4335 bool hasLSquare = parser.parseOptionalLSquare().succeeded();
4336
4337 llvm::StringRef attrStr;
4338 if (parser.parseOptionalKeyword(&attrStr, {"all", "type"}).failed())
4339 return parser.emitError(parser.getCurrentLocation(),
4340 "expected 'all' or 'type' keyword");
4341
4342 bool isCatchAll = attrStr == "all";
4343 if (isCatchAll) {
4344 if (hasCatchAll)
4345 return parser.emitError(parser.getCurrentLocation(),
4346 "can't have more than one catch all");
4347 hasCatchAll = true;
4348 }
4349
4350 mlir::Attribute exceptionRTTIAttr;
4351 if (!isCatchAll && parser.parseAttribute(exceptionRTTIAttr).failed())
4352 return parser.emitError(parser.getCurrentLocation(),
4353 "expected valid RTTI info attribute");
4354
4355 catcherAttrs.push_back(isCatchAll
4356 ? cir::CatchAllAttr::get(parser.getContext())
4357 : exceptionRTTIAttr);
4358
4359 if (hasLSquare && isCatchAll)
4360 return parser.emitError(parser.getCurrentLocation(),
4361 "catch all dosen't need RTTI info attribute");
4362
4363 if (hasLSquare && parser.parseRSquare().failed())
4364 return parser.emitError(parser.getCurrentLocation(),
4365 "expected `]` after RTTI info attribute");
4366
4367 if (parseCheckedCatcherRegion().failed())
4368 return mlir::failure();
4369 }
4370
4371 if (parser.parseOptionalKeyword("unwind").succeeded()) {
4372 if (hasCatchAll)
4373 return parser.emitError(parser.getCurrentLocation(),
4374 "unwind can't be used with catch all");
4375
4376 catcherAttrs.push_back(cir::UnwindAttr::get(parser.getContext()));
4377 if (parseCheckedCatcherRegion().failed())
4378 return mlir::failure();
4379 }
4380
4381 handlerTypes = parser.getBuilder().getArrayAttr(catcherAttrs);
4382 return mlir::success();
4383}
4384
4385//===----------------------------------------------------------------------===//
4386// EhTypeIdOp
4387//===----------------------------------------------------------------------===//
4388
4389LogicalResult
4390cir::EhTypeIdOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
4391 Operation *op = symbolTable.lookupNearestSymbolFrom(*this, getTypeSymAttr());
4392 if (!isa_and_nonnull<GlobalOp>(op))
4393 return emitOpError("'")
4394 << getTypeSym() << "' does not reference a valid cir.global";
4395 return success();
4396}
4397
4398//===----------------------------------------------------------------------===//
4399// LifetimeStartOp & LifetimeEndOp
4400//===----------------------------------------------------------------------===//
4401
4402LogicalResult cir::LifetimeStartOp::verify() {
4403 return verifyProducedBy<cir::AllocaOp>(*this, getPtr(), "ptr");
4404}
4405
4406LogicalResult cir::LifetimeEndOp::verify() {
4407 return verifyProducedBy<cir::AllocaOp>(*this, getPtr(), "ptr");
4408}
4409
4410//===----------------------------------------------------------------------===//
4411// ConstructCatchParamOp
4412//===----------------------------------------------------------------------===//
4413
4414LogicalResult cir::ConstructCatchParamOp::verifySymbolUses(
4415 SymbolTableCollection &symbolTable) {
4416 auto copyFnAttr = getCopyFnAttr();
4417 if (!copyFnAttr)
4418 return success();
4419 auto fn =
4420 symbolTable.lookupNearestSymbolFrom<cir::FuncOp>(*this, getCopyFnAttr());
4421 if (!fn)
4422 return emitOpError("'")
4423 << *getCopyFn() << "' does not reference a valid cir.func";
4424
4425 if (!fn->hasAttr(cir::CIRDialect::getCatchCopyThunkAttrName()))
4426 return emitOpError("catch-init copy_fn must be tagged with the ")
4427 << cir::CIRDialect::getCatchCopyThunkAttrName() << " attribute";
4428
4429 cir::FuncType fnType = fn.getFunctionType();
4430 if (fnType.getNumInputs() != 2 || !fnType.hasVoidReturn())
4431 return emitOpError("catch-init copy_fn must take two pointer arguments and "
4432 "return void");
4433
4434 if (fnType.getInput(0) != getParamAddr().getType())
4435 return emitOpError("first argument of catch-init copy_fn must match the "
4436 "type of 'param_addr'");
4437
4438 if (fnType.getInput(1) != getParamAddr().getType())
4439 return emitOpError(
4440 "second argument of catch-init copy_fn must be a pointer "
4441 "to the catch type");
4442
4443 return success();
4444}
4445
4446//===----------------------------------------------------------------------===//
4447// EhDispatchOp
4448//===----------------------------------------------------------------------===//
4449
4450static ParseResult
4451parseEhDispatchDestinations(OpAsmParser &parser, mlir::ArrayAttr &catchTypes,
4452 SmallVectorImpl<Block *> &catchDestinations,
4453 Block *&defaultDestination,
4454 mlir::UnitAttr &defaultIsCatchAll) {
4455 // Parse: [ ... ]
4456 if (parser.parseLSquare())
4457 return failure();
4458
4459 SmallVector<Attribute> handlerTypes;
4460 bool hasCatchAll = false;
4461 bool hasUnwind = false;
4462
4463 // Parse handler list.
4464 auto parseHandler = [&]() -> ParseResult {
4465 // Check for 'catch_all' or 'unwind' keywords.
4466 if (succeeded(parser.parseOptionalKeyword("catch_all"))) {
4467 if (hasCatchAll)
4468 return parser.emitError(parser.getCurrentLocation(),
4469 "duplicate 'catch_all' handler");
4470 if (hasUnwind)
4471 return parser.emitError(parser.getCurrentLocation(),
4472 "cannot have both 'catch_all' and 'unwind'");
4473 hasCatchAll = true;
4474
4475 if (parser.parseColon().failed())
4476 return failure();
4477
4478 if (parser.parseSuccessor(defaultDestination).failed())
4479 return failure();
4480
4481 return success();
4482 }
4483
4484 if (succeeded(parser.parseOptionalKeyword("unwind"))) {
4485 if (hasUnwind)
4486 return parser.emitError(parser.getCurrentLocation(),
4487 "duplicate 'unwind' handler");
4488 if (hasCatchAll)
4489 return parser.emitError(parser.getCurrentLocation(),
4490 "cannot have both 'catch_all' and 'unwind'");
4491 hasUnwind = true;
4492
4493 if (parser.parseColon().failed())
4494 return failure();
4495
4496 if (parser.parseSuccessor(defaultDestination).failed())
4497 return failure();
4498 return success();
4499 }
4500
4501 // Otherwise, expect 'catch(<attr> : <type>) : ^block'.
4502 // The 'catch(...)' wrapper allows the attribute to include its type
4503 // without conflicting with the ':' used for the block destination.
4504 if (parser.parseKeyword("catch").failed())
4505 return failure();
4506
4507 if (parser.parseLParen().failed())
4508 return failure();
4509
4510 mlir::Attribute catchTypeAttr;
4511 if (parser.parseAttribute(catchTypeAttr).failed())
4512 return failure();
4513 handlerTypes.push_back(catchTypeAttr);
4514
4515 if (parser.parseRParen().failed())
4516 return failure();
4517
4518 if (parser.parseColon().failed())
4519 return failure();
4520
4521 Block *dest;
4522 if (parser.parseSuccessor(dest).failed())
4523 return failure();
4524 catchDestinations.push_back(dest);
4525 return success();
4526 };
4527
4528 if (parser.parseCommaSeparatedList(parseHandler).failed())
4529 return failure();
4530
4531 if (parser.parseRSquare().failed())
4532 return failure();
4533
4534 // Verify we have catch_all or unwind.
4535 if (!hasCatchAll && !hasUnwind)
4536 return parser.emitError(parser.getCurrentLocation(),
4537 "must have either 'catch_all' or 'unwind' handler");
4538
4539 // Add attributes and successors.
4540 if (!handlerTypes.empty())
4541 catchTypes = parser.getBuilder().getArrayAttr(handlerTypes);
4542
4543 if (hasCatchAll)
4544 defaultIsCatchAll = parser.getBuilder().getUnitAttr();
4545
4546 return success();
4547}
4548
4549static void printEhDispatchDestinations(OpAsmPrinter &p, cir::EhDispatchOp op,
4550 mlir::ArrayAttr catchTypes,
4551 SuccessorRange catchDestinations,
4552 Block *defaultDestination,
4553 mlir::UnitAttr defaultIsCatchAll) {
4554 p << " [";
4555 p.printNewline();
4556
4557 // If we have at least one catch type, print them.
4558 if (catchTypes) {
4559 // Print type handlers using 'catch(<attr>) : ^block' syntax.
4560 llvm::interleave(
4561 llvm::zip(catchTypes, catchDestinations),
4562 [&](auto i) {
4563 p << " catch(";
4564 p.printAttribute(std::get<0>(i));
4565 p << ") : ";
4566 p.printSuccessor(std::get<1>(i));
4567 },
4568 [&] {
4569 p << ',';
4570 p.printNewline();
4571 });
4572
4573 p << ", ";
4574 p.printNewline();
4575 }
4576
4577 // Print catch_all or unwind handler.
4578 if (defaultIsCatchAll)
4579 p << " catch_all : ";
4580 else
4581 p << " unwind : ";
4582 p.printSuccessor(defaultDestination);
4583 p.printNewline();
4584
4585 p << "]";
4586}
4587
4588//===----------------------------------------------------------------------===//
4589// TableGen'd op method definitions
4590//===----------------------------------------------------------------------===//
4591
4592#define GET_OP_CLASSES
4593#include "clang/CIR/Dialect/IR/CIROps.cpp.inc"
static const MemRegion * getRegion(const CallEvent &Call, const MutexDescriptor &Descriptor, bool IsLock)
static void printEhDispatchDestinations(OpAsmPrinter &p, cir::EhDispatchOp op, mlir::ArrayAttr catchTypes, SuccessorRange catchDestinations, Block *defaultDestination, mlir::UnitAttr defaultIsCatchAll)
static mlir::LogicalResult checkReturnAndFunction(cir::ReturnOp op, cir::FuncOp function)
static bool isCirFunctionPointerType(mlir::Type ty)
static LogicalResult verifyMemberPtrCast(Operation *op, mlir::Value src, mlir::Type resultTy)
static mlir::ParseResult parseCallCommon(mlir::OpAsmParser &parser, mlir::OperationState &result, bool hasDestinationBlocks=false)
static bool isIntOrBoolCast(cir::CastOp op)
static ParseResult parseAssumeBundle(OpAsmParser &p, cir::AssumeBundleKindAttr &bundleKindAttr, llvm::SmallVector< mlir::OpAsmParser::UnresolvedOperand, 4 > &bundleArgs, llvm::SmallVector< mlir::Type, 1 > &bundleArgTypes)
static ParseResult parseEhDispatchDestinations(OpAsmParser &parser, mlir::ArrayAttr &catchTypes, SmallVectorImpl< Block * > &catchDestinations, Block *&defaultDestination, mlir::UnitAttr &defaultIsCatchAll)
static void printConstant(OpAsmPrinter &p, Attribute value)
static mlir::ParseResult parseOmittedTerminatorRegion(mlir::OpAsmParser &parser, mlir::Region &region)
static void printAssumeBundle(OpAsmPrinter &p, cir::AssumeOp op, cir::AssumeBundleKindAttr kindAttr, OperandRange bundleArgs, TypeRange bundleArgTypes)
ParseResult parseInlineKindAttr(OpAsmParser &parser, cir::InlineKindAttr &inlineKindAttr)
void printInlineKindAttr(OpAsmPrinter &p, cir::InlineKindAttr inlineKindAttr)
static ParseResult parseSwitchFlatOpCases(OpAsmParser &parser, Type flagType, mlir::ArrayAttr &caseValues, SmallVectorImpl< Block * > &caseDestinations, SmallVectorImpl< llvm::SmallVector< OpAsmParser::UnresolvedOperand > > &caseOperands, SmallVectorImpl< llvm::SmallVector< Type > > &caseOperandTypes)
<cases> ::= [ (case (, case )* )?
void printGlobalAddressSpaceValue(mlir::AsmPrinter &printer, cir::GlobalOp op, mlir::ptr::MemorySpaceAttrInterface attr)
static void printCallCommon(mlir::Operation *op, mlir::FlatSymbolRefAttr calleeSym, mlir::Value indirectCallee, mlir::OpAsmPrinter &printer, bool isNothrow, cir::SideEffect sideEffect, ArrayAttr argAttrs, ArrayAttr resAttrs, mlir::Block *normalDest=nullptr, mlir::Block *unwindDest=nullptr)
static LogicalResult verifyCallCommInSymbolUses(mlir::Operation *op, SymbolTableCollection &symbolTable)
static LogicalResult ensureRegionTerm(OpAsmParser &parser, Region &region, SMLoc errLoc)
static ParseResult parseGlobalOpTypeAndInitialValue(OpAsmParser &parser, TypeAttr &typeAttr, Attribute &initialValueAttr, mlir::Region &ctorRegion, mlir::Region &dtorRegion)
void printIndirectBrOpSucessors(OpAsmPrinter &p, cir::IndirectBrOp op, Type flagType, SuccessorRange succs, OperandRangeRange succOperands, const TypeRangeRange &succOperandsTypes)
static OpFoldResult foldUnaryBitOp(mlir::Attribute inputAttr, llvm::function_ref< llvm::APInt(const llvm::APInt &)> func, bool poisonZero=false)
static llvm::StringRef getLinkageAttrNameString()
Returns the name used for the linkage attribute.
static RetTy parseOptionalCIRKeyword(AsmParser &parser, EnumTy defaultValue)
Parse an enum from the keyword, or default to the provided default value.
mlir::OptionalParseResult parseGlobalAddressSpaceValue(mlir::AsmParser &p, mlir::ptr::MemorySpaceAttrInterface &attr)
static void printSwitchFlatOpCases(OpAsmPrinter &p, cir::SwitchFlatOp op, Type flagType, mlir::ArrayAttr caseValues, SuccessorRange caseDestinations, OperandRangeRange caseOperands, const TypeRangeRange &caseOperandTypes)
static LogicalResult verifyProducedBy(Operation *op, Value operand, StringRef operandName)
static mlir::ParseResult parseTryCallDestinations(mlir::OpAsmParser &parser, mlir::OperationState &result)
static void printGlobalOpTypeAndInitialValue(OpAsmPrinter &p, cir::GlobalOp op, TypeAttr type, Attribute initAttr, mlir::Region &ctorRegion, mlir::Region &dtorRegion)
static ParseResult parseCIRKeyword(AsmParser &parser, RetTy &result)
Parse an enum from the keyword, return failure if the keyword is not found.
static Value tryFoldCastChain(cir::CastOp op)
static void printTryHandlerRegions(mlir::OpAsmPrinter &printer, cir::TryOp op, mlir::MutableArrayRef< mlir::Region > handlerRegions, mlir::ArrayAttr handlerTypes)
ParseResult parseIndirectBrOpSucessors(OpAsmParser &parser, Type &flagType, SmallVectorImpl< Block * > &succOperandBlocks, SmallVectorImpl< SmallVector< OpAsmParser::UnresolvedOperand > > &succOperands, SmallVectorImpl< SmallVector< Type > > &succOperandsTypes)
static bool omitRegionTerm(mlir::Region &r)
static void printOmittedTerminatorRegion(mlir::OpAsmPrinter &printer, cir::ScopeOp &op, mlir::Region &region)
static ParseResult parseConstantValue(OpAsmParser &parser, mlir::Attribute &valueAttr)
static LogicalResult verifyArrayCtorDtor(Op op)
static mlir::LogicalResult verifyThrowOpImpl(ThrowOpTy op)
static LogicalResult checkConstantTypes(mlir::Operation *op, mlir::Type opType, mlir::Attribute attrType)
static mlir::ParseResult parseTryHandlerRegions(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< std::unique_ptr< mlir::Region > > &handlerRegions, mlir::ArrayAttr &handlerTypes)
#define REGISTER_ENUM_TYPE(Ty)
static int parseOptionalKeywordAlternative(AsmParser &parser, ArrayRef< llvm::StringRef > keywords)
llvm::function_ref< void(mlir::OpBuilder &, mlir::Location)> BuilderCallbackRef
Definition CIRDialect.h:37
llvm::function_ref< void( mlir::OpBuilder &, mlir::Location, mlir::OperationState &)> BuilderOpStateCallbackRef
Definition CIRDialect.h:39
static std::optional< NonLoc > getIndex(ProgramStateRef State, const ElementRegion *ER, CharKind CK)
static Decl::Kind getKind(const Decl *D)
TokenType getType() const
Returns the token's type, e.g.
tooling::Replacements cleanup(const FormatStyle &Style, StringRef Code, ArrayRef< tooling::Range > Ranges, StringRef FileName="<stdin>")
Clean up any erroneous/redundant code in the given Ranges in Code.
*collection of selector each with an associated kind and an ordered *collection of selectors A selector has an optional score condition
*collection of selector each with an associated kind and an ordered *collection of selectors A selector has a kind
void buildTerminatedBody(mlir::OpBuilder &builder, mlir::Location loc)
mlir::ptr::MemorySpaceAttrInterface normalizeDefaultAddressSpace(mlir::ptr::MemorySpaceAttrInterface addrSpace)
Normalize LangAddressSpace::Default to null (empty attribute).
const internal::VariadicAllOfMatcher< Attr > attr
const AstTypeMatcher< RecordType > recordType
StringRef getName(const HeaderType T)
Definition HeaderFile.h:38
RangeSelector name(std::string ID)
Given a node with a "name", (like NamedDecl, DeclRefExpr, CxxCtorInitializer, and TypeLoc) selects th...
nullptr
This class represents a compute construct, representing a 'Kind' of 'parallel', 'serial',...
__DEVICE__ _Tp arg(const std::complex< _Tp > &__c)
static bool memberFuncPtrCast()
static bool opCallCallConv()
static bool opScopeCleanupRegion()
static bool supportIFuncAttr()