clang 23.0.0git
CIRDialect.cpp
Go to the documentation of this file.
//===- CIRDialect.cpp - MLIR CIR ops implementation -----------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the CIR dialect and its operations.
//
//===----------------------------------------------------------------------===//
12
14
18
19#include "mlir/IR/Attributes.h"
20#include "mlir/IR/DialectImplementation.h"
21#include "mlir/IR/PatternMatch.h"
22#include "mlir/Interfaces/ControlFlowInterfaces.h"
23#include "mlir/Interfaces/FunctionImplementation.h"
24#include "mlir/Support/LLVM.h"
25
26#include "clang/CIR/Dialect/IR/CIROpsDialect.cpp.inc"
27#include "clang/CIR/Dialect/IR/CIROpsEnums.cpp.inc"
29#include "llvm/ADT/SetOperations.h"
30#include "llvm/ADT/SmallSet.h"
31#include "llvm/ADT/TypeSwitch.h"
32#include "llvm/Support/LogicalResult.h"
33
34using namespace mlir;
35using namespace cir;
36
37//===----------------------------------------------------------------------===//
38// CIR Dialect
39//===----------------------------------------------------------------------===//
40namespace {
41struct CIROpAsmDialectInterface : public OpAsmDialectInterface {
42 using OpAsmDialectInterface::OpAsmDialectInterface;
43
44 AliasResult getAlias(Type type, raw_ostream &os) const final {
45 if (auto recordType = dyn_cast<cir::RecordType>(type)) {
46 StringAttr nameAttr = recordType.getName();
47 if (!nameAttr)
48 os << "rec_anon_" << recordType.getKindAsStr();
49 else
50 os << "rec_" << nameAttr.getValue();
51 return AliasResult::OverridableAlias;
52 }
53 if (auto intType = dyn_cast<cir::IntType>(type)) {
54 // We only provide alias for standard integer types (i.e. integer types
55 // whose width is a power of 2 and at least 8).
56 unsigned width = intType.getWidth();
57 if (width < 8 || !llvm::isPowerOf2_32(width))
58 return AliasResult::NoAlias;
59 os << intType.getAlias();
60 return AliasResult::OverridableAlias;
61 }
62 if (auto voidType = dyn_cast<cir::VoidType>(type)) {
63 os << voidType.getAlias();
64 return AliasResult::OverridableAlias;
65 }
66
67 return AliasResult::NoAlias;
68 }
69
70 AliasResult getAlias(Attribute attr, raw_ostream &os) const final {
71 if (auto boolAttr = mlir::dyn_cast<cir::BoolAttr>(attr)) {
72 os << (boolAttr.getValue() ? "true" : "false");
73 return AliasResult::FinalAlias;
74 }
75 if (auto bitfield = mlir::dyn_cast<cir::BitfieldInfoAttr>(attr)) {
76 os << "bfi_" << bitfield.getName().str();
77 return AliasResult::FinalAlias;
78 }
79 if (auto dynCastInfoAttr = mlir::dyn_cast<cir::DynamicCastInfoAttr>(attr)) {
80 os << dynCastInfoAttr.getAlias();
81 return AliasResult::FinalAlias;
82 }
83 if (auto cmpThreeWayInfoAttr =
84 mlir::dyn_cast<cir::CmpThreeWayInfoAttr>(attr)) {
85 os << cmpThreeWayInfoAttr.getAlias();
86 return AliasResult::FinalAlias;
87 }
88 return AliasResult::NoAlias;
89 }
90};
91} // namespace
92
/// Registers everything the CIR dialect provides with its MLIR context:
/// types, attributes, every op listed in the generated CIROps.cpp.inc, and
/// the asm-alias interface defined above.
void cir::CIRDialect::initialize() {
  registerTypes();
  registerAttributes();
  // GET_OP_LIST makes the generated file expand to a comma-separated list of
  // op classes, used here as template arguments.
  addOperations<
#define GET_OP_LIST
#include "clang/CIR/Dialect/IR/CIROps.cpp.inc"
      >();
  addInterfaces<CIROpAsmDialectInterface>();
}
102
103Operation *cir::CIRDialect::materializeConstant(mlir::OpBuilder &builder,
104 mlir::Attribute value,
105 mlir::Type type,
106 mlir::Location loc) {
107 return cir::ConstantOp::create(builder, loc, type,
108 mlir::cast<mlir::TypedAttr>(value));
109}
110
111//===----------------------------------------------------------------------===//
112// Helpers
113//===----------------------------------------------------------------------===//
114
115// Parses one of the keywords provided in the list `keywords` and returns the
116// position of the parsed keyword in the list. If none of the keywords from the
117// list is parsed, returns -1.
118static int parseOptionalKeywordAlternative(AsmParser &parser,
119 ArrayRef<llvm::StringRef> keywords) {
120 for (auto en : llvm::enumerate(keywords)) {
121 if (succeeded(parser.parseOptionalKeyword(en.value())))
122 return en.index();
123 }
124 return -1;
125}
126
127namespace {
128template <typename Ty> struct EnumTraits {};
129
130#define REGISTER_ENUM_TYPE(Ty) \
131 template <> struct EnumTraits<cir::Ty> { \
132 static llvm::StringRef stringify(cir::Ty value) { \
133 return stringify##Ty(value); \
134 } \
135 static unsigned getMaxEnumVal() { return cir::getMaxEnumValFor##Ty(); } \
136 }
137
138REGISTER_ENUM_TYPE(GlobalLinkageKind);
139REGISTER_ENUM_TYPE(VisibilityKind);
140REGISTER_ENUM_TYPE(SideEffect);
141} // namespace
142
143/// Parse an enum from the keyword, or default to the provided default value.
144/// The return type is the enum type by default, unless overriden with the
145/// second template argument.
146template <typename EnumTy, typename RetTy = EnumTy>
147static RetTy parseOptionalCIRKeyword(AsmParser &parser, EnumTy defaultValue) {
149 for (unsigned i = 0, e = EnumTraits<EnumTy>::getMaxEnumVal(); i <= e; ++i)
150 names.push_back(EnumTraits<EnumTy>::stringify(static_cast<EnumTy>(i)));
151
152 int index = parseOptionalKeywordAlternative(parser, names);
153 if (index == -1)
154 return static_cast<RetTy>(defaultValue);
155 return static_cast<RetTy>(index);
156}
157
158/// Parse an enum from the keyword, return failure if the keyword is not found.
159template <typename EnumTy, typename RetTy = EnumTy>
160static ParseResult parseCIRKeyword(AsmParser &parser, RetTy &result) {
162 for (unsigned i = 0, e = EnumTraits<EnumTy>::getMaxEnumVal(); i <= e; ++i)
163 names.push_back(EnumTraits<EnumTy>::stringify(static_cast<EnumTy>(i)));
164
165 int index = parseOptionalKeywordAlternative(parser, names);
166 if (index == -1)
167 return failure();
168 result = static_cast<RetTy>(index);
169 return success();
170}
171
172// Check if a region's termination omission is valid and, if so, creates and
173// inserts the omitted terminator into the region.
174static LogicalResult ensureRegionTerm(OpAsmParser &parser, Region &region,
175 SMLoc errLoc) {
176 Location eLoc = parser.getEncodedSourceLoc(parser.getCurrentLocation());
177 OpBuilder builder(parser.getBuilder().getContext());
178
179 // Insert empty block in case the region is empty to ensure the terminator
180 // will be inserted
181 if (region.empty())
182 builder.createBlock(&region);
183
184 Block &block = region.back();
185 // Region is properly terminated: nothing to do.
186 if (!block.empty() && block.back().hasTrait<OpTrait::IsTerminator>())
187 return success();
188
189 // Check for invalid terminator omissions.
190 if (!region.hasOneBlock())
191 return parser.emitError(errLoc,
192 "multi-block region must not omit terminator");
193
194 // Terminator was omitted correctly: recreate it.
195 builder.setInsertionPointToEnd(&block);
196 cir::YieldOp::create(builder, eLoc);
197 return success();
198}
199
200// True if the region's terminator should be omitted.
201static bool omitRegionTerm(mlir::Region &r) {
202 const auto singleNonEmptyBlock = r.hasOneBlock() && !r.back().empty();
203 const auto yieldsNothing = [&r]() {
204 auto y = dyn_cast<cir::YieldOp>(r.back().getTerminator());
205 return y && y.getArgs().empty();
206 };
207 return singleNonEmptyBlock && yieldsNothing();
208}
209
210//===----------------------------------------------------------------------===//
211// InlineKindAttr (FIXME: remove once FuncOp uses assembly format)
212//===----------------------------------------------------------------------===//
213
214ParseResult parseInlineKindAttr(OpAsmParser &parser,
215 cir::InlineKindAttr &inlineKindAttr) {
216 // Static list of possible inline kind keywords
217 static constexpr llvm::StringRef keywords[] = {"no_inline", "always_inline",
218 "inline_hint"};
219
220 // Parse the inline kind keyword (optional)
221 llvm::StringRef keyword;
222 if (parser.parseOptionalKeyword(&keyword, keywords).failed()) {
223 // Not an inline kind keyword, leave inlineKindAttr empty
224 return success();
225 }
226
227 // Parse the enum value from the keyword
228 auto inlineKindResult = ::cir::symbolizeEnum<::cir::InlineKind>(keyword);
229 if (!inlineKindResult) {
230 return parser.emitError(parser.getCurrentLocation(), "expected one of [")
231 << llvm::join(llvm::ArrayRef(keywords), ", ")
232 << "] for inlineKind, got: " << keyword;
233 }
234
235 inlineKindAttr =
236 ::cir::InlineKindAttr::get(parser.getContext(), *inlineKindResult);
237 return success();
238}
239
240void printInlineKindAttr(OpAsmPrinter &p, cir::InlineKindAttr inlineKindAttr) {
241 if (inlineKindAttr) {
242 p << " " << stringifyInlineKind(inlineKindAttr.getValue());
243 }
244}
245//===----------------------------------------------------------------------===//
246// CIR Custom Parsers/Printers
247//===----------------------------------------------------------------------===//
248
249static mlir::ParseResult parseOmittedTerminatorRegion(mlir::OpAsmParser &parser,
250 mlir::Region &region) {
251 auto regionLoc = parser.getCurrentLocation();
252 if (parser.parseRegion(region))
253 return failure();
254 if (ensureRegionTerm(parser, region, regionLoc).failed())
255 return failure();
256 return success();
257}
258
259static void printOmittedTerminatorRegion(mlir::OpAsmPrinter &printer,
260 cir::ScopeOp &op,
261 mlir::Region &region) {
262 printer.printRegion(region,
263 /*printEntryBlockArgs=*/false,
264 /*printBlockTerminators=*/!omitRegionTerm(region));
265}
266
// Forward declarations of the custom parser/printer for a global's address
// space value. Their definitions are not in this part of the file.
// NOTE(review): presumably referenced from the generated GlobalOp assembly
// format; confirm against CIROps.td.
mlir::OptionalParseResult
parseGlobalAddressSpaceValue(mlir::AsmParser &p,
                             mlir::ptr::MemorySpaceAttrInterface &attr);

void printGlobalAddressSpaceValue(mlir::AsmPrinter &printer, cir::GlobalOp op,
                                  mlir::ptr::MemorySpaceAttrInterface attr);
273
274//===----------------------------------------------------------------------===//
275// AllocaOp
276//===----------------------------------------------------------------------===//
277
278void cir::AllocaOp::build(mlir::OpBuilder &odsBuilder,
279 mlir::OperationState &odsState, mlir::Type addr,
280 mlir::Type allocaType, llvm::StringRef name,
281 mlir::IntegerAttr alignment) {
282 odsState.addAttribute(getAllocaTypeAttrName(odsState.name),
283 mlir::TypeAttr::get(allocaType));
284 odsState.addAttribute(getNameAttrName(odsState.name),
285 odsBuilder.getStringAttr(name));
286 if (alignment) {
287 odsState.addAttribute(getAlignmentAttrName(odsState.name), alignment);
288 }
289 odsState.addTypes(addr);
290}
291
292//===----------------------------------------------------------------------===//
293// ArrayCtor & ArrayDtor
294//===----------------------------------------------------------------------===//
295
296template <typename Op> static LogicalResult verifyArrayCtorDtor(Op op) {
297 auto ptrTy = mlir::cast<cir::PointerType>(op.getAddr().getType());
298 mlir::Type pointeeTy = ptrTy.getPointee();
299
300 mlir::Block &body = op.getBody().front();
301 if (body.getNumArguments() != 1)
302 return op.emitOpError("body must have exactly one block argument");
303
304 auto expectedEltPtrTy =
305 mlir::dyn_cast<cir::PointerType>(body.getArgument(0).getType());
306 if (!expectedEltPtrTy)
307 return op.emitOpError("block argument must be a !cir.ptr type");
308
309 if (op.getNumElements()) {
310 auto recTy = mlir::dyn_cast<cir::RecordType>(pointeeTy);
311 if (!recTy)
312 return op.emitOpError(
313 "when 'num_elements' is present, 'addr' must be a pointer to a "
314 "!cir.record type");
315
316 if (expectedEltPtrTy != ptrTy)
317 return op.emitOpError("when 'num_elements' is present, 'addr' type must "
318 "match the block argument type");
319 } else {
320 auto arrayTy = mlir::dyn_cast<cir::ArrayType>(pointeeTy);
321 if (!arrayTy)
322 return op.emitOpError(
323 "when 'num_elements' is absent, 'addr' must be a pointer to a "
324 "!cir.array type");
325
326 mlir::Type innerEltTy = arrayTy.getElementType();
327 while (auto nested = mlir::dyn_cast<cir::ArrayType>(innerEltTy))
328 innerEltTy = nested.getElementType();
329
330 auto recTy = mlir::dyn_cast<cir::RecordType>(innerEltTy);
331 if (!recTy)
332 return op.emitOpError(
333 "the block argument type must be a pointer to a !cir.record type");
334
335 if (expectedEltPtrTy.getPointee() != innerEltTy)
336 return op.emitOpError(
337 "block argument pointee type must match the innermost array "
338 "element type");
339 }
340
341 return success();
342}
343
344LogicalResult cir::ArrayCtor::verify() {
345 if (failed(verifyArrayCtorDtor(*this)))
346 return failure();
347
348 mlir::Region &partialDtor = getPartialDtor();
349 if (!partialDtor.empty()) {
350 mlir::Block &dtorBlock = partialDtor.front();
351 if (dtorBlock.getNumArguments() != 1)
352 return emitOpError("partial_dtor must have exactly one block argument");
353
354 auto bodyArgTy = getBody().front().getArgument(0).getType();
355 if (dtorBlock.getArgument(0).getType() != bodyArgTy)
356 return emitOpError("partial_dtor block argument type must match "
357 "the body block argument type");
358 }
359 return success();
360}
361LogicalResult cir::ArrayDtor::verify() { return verifyArrayCtorDtor(*this); }
362
363//===----------------------------------------------------------------------===//
364// BreakOp
365//===----------------------------------------------------------------------===//
366
367LogicalResult cir::BreakOp::verify() {
368 if (!getOperation()->getParentOfType<LoopOpInterface>() &&
369 !getOperation()->getParentOfType<SwitchOp>())
370 return emitOpError("must be within a loop");
371 return success();
372}
373
374//===----------------------------------------------------------------------===//
375// ConditionOp
376//===----------------------------------------------------------------------===//
377
378//===----------------------------------
379// BranchOpTerminatorInterface Methods
380//===----------------------------------
381
382void cir::ConditionOp::getSuccessorRegions(
383 ArrayRef<Attribute> operands, SmallVectorImpl<RegionSuccessor> &regions) {
384 // TODO(cir): The condition value may be folded to a constant, narrowing
385 // down its list of possible successors.
386
387 // Parent is a loop: condition may branch to the body or to the parent op.
388 if (auto loopOp = dyn_cast<LoopOpInterface>(getOperation()->getParentOp())) {
389 regions.emplace_back(&loopOp.getBody());
390 regions.push_back(RegionSuccessor::parent());
391 }
392
393 // Parent is an await: condition may branch to resume or suspend regions.
394 auto await = cast<AwaitOp>(getOperation()->getParentOp());
395 regions.emplace_back(&await.getResume());
396 regions.emplace_back(&await.getSuspend());
397}
398
399MutableOperandRange
400cir::ConditionOp::getMutableSuccessorOperands(RegionSuccessor point) {
401 // No values are yielded to the successor region.
402 return MutableOperandRange(getOperation(), 0, 0);
403}
404
405MutableOperandRange
406cir::ResumeOp::getMutableSuccessorOperands(RegionSuccessor point) {
407 // The eh_token operand is not forwarded to the parent region.
408 return MutableOperandRange(getOperation(), 0, 0);
409}
410
411LogicalResult cir::ConditionOp::verify() {
412 if (!isa<LoopOpInterface, AwaitOp>(getOperation()->getParentOp()))
413 return emitOpError("condition must be within a conditional region");
414 return success();
415}
416
417//===----------------------------------------------------------------------===//
418// ConstantOp
419//===----------------------------------------------------------------------===//
420
421static LogicalResult checkConstantTypes(mlir::Operation *op, mlir::Type opType,
422 mlir::Attribute attrType) {
423 if (isa<cir::ConstPtrAttr>(attrType)) {
424 if (!mlir::isa<cir::PointerType>(opType))
425 return op->emitOpError(
426 "pointer constant initializing a non-pointer type");
427 return success();
428 }
429
430 if (isa<cir::DataMemberAttr, cir::MethodAttr>(attrType)) {
431 // More detailed type verifications are already done in
432 // DataMemberAttr::verify or MethodAttr::verify. Don't need to repeat here.
433 return success();
434 }
435
436 if (isa<cir::ZeroAttr>(attrType)) {
437 if (isa<cir::RecordType, cir::ArrayType, cir::VectorType, cir::ComplexType>(
438 opType))
439 return success();
440 return op->emitOpError(
441 "zero expects struct, array, vector, or complex type");
442 }
443
444 if (mlir::isa<cir::UndefAttr>(attrType)) {
445 if (!mlir::isa<cir::VoidType>(opType))
446 return success();
447 return op->emitOpError("undef expects non-void type");
448 }
449
450 if (mlir::isa<cir::BoolAttr>(attrType)) {
451 if (!mlir::isa<cir::BoolType>(opType))
452 return op->emitOpError("result type (")
453 << opType << ") must be '!cir.bool' for '" << attrType << "'";
454 return success();
455 }
456
457 if (mlir::isa<cir::IntAttr, cir::FPAttr>(attrType)) {
458 auto at = cast<TypedAttr>(attrType);
459 if (at.getType() != opType) {
460 return op->emitOpError("result type (")
461 << opType << ") does not match value type (" << at.getType()
462 << ")";
463 }
464 return success();
465 }
466
467 if (mlir::isa<cir::ConstArrayAttr, cir::ConstVectorAttr,
468 cir::ConstComplexAttr, cir::ConstRecordAttr,
469 cir::GlobalViewAttr, cir::PoisonAttr, cir::TypeInfoAttr,
470 cir::VTableAttr>(attrType))
471 return success();
472
473 assert(isa<TypedAttr>(attrType) && "What else could we be looking at here?");
474 return op->emitOpError("global with type ")
475 << cast<TypedAttr>(attrType).getType() << " not yet supported";
476}
477
478LogicalResult cir::ConstantOp::verify() {
479 // ODS already generates checks to make sure the result type is valid. We just
480 // need to additionally check that the value's attribute type is consistent
481 // with the result type.
482 return checkConstantTypes(getOperation(), getType(), getValue());
483}
484
485OpFoldResult cir::ConstantOp::fold(FoldAdaptor /*adaptor*/) {
486 return getValue();
487}
488
489//===----------------------------------------------------------------------===//
490// ContinueOp
491//===----------------------------------------------------------------------===//
492
493LogicalResult cir::ContinueOp::verify() {
494 if (!getOperation()->getParentOfType<LoopOpInterface>())
495 return emitOpError("must be within a loop");
496 return success();
497}
498
499//===----------------------------------------------------------------------===//
500// CastOp
501//===----------------------------------------------------------------------===//
502
503LogicalResult cir::CastOp::verify() {
504 mlir::Type resType = getType();
505 mlir::Type srcType = getSrc().getType();
506
507 // Verify address space casts for pointer types. given that
508 // casts for within a different address space are illegal.
509 auto srcPtrTy = mlir::dyn_cast<cir::PointerType>(srcType);
510 auto resPtrTy = mlir::dyn_cast<cir::PointerType>(resType);
511 if (srcPtrTy && resPtrTy && (getKind() != cir::CastKind::address_space))
512 if (srcPtrTy.getAddrSpace() != resPtrTy.getAddrSpace()) {
513 return emitOpError() << "result type address space does not match the "
514 "address space of the operand";
515 }
516
517 if (mlir::isa<cir::VectorType>(srcType) &&
518 mlir::isa<cir::VectorType>(resType)) {
519 // Use the element type of the vector to verify the cast kind. (Except for
520 // bitcast, see below.)
521 srcType = mlir::dyn_cast<cir::VectorType>(srcType).getElementType();
522 resType = mlir::dyn_cast<cir::VectorType>(resType).getElementType();
523 }
524
525 switch (getKind()) {
526 case cir::CastKind::int_to_bool: {
527 if (!mlir::isa<cir::BoolType>(resType))
528 return emitOpError() << "requires !cir.bool type for result";
529 if (!mlir::isa<cir::IntType>(srcType))
530 return emitOpError() << "requires !cir.int type for source";
531 return success();
532 }
533 case cir::CastKind::ptr_to_bool: {
534 if (!mlir::isa<cir::BoolType>(resType))
535 return emitOpError() << "requires !cir.bool type for result";
536 if (!mlir::isa<cir::PointerType>(srcType))
537 return emitOpError() << "requires !cir.ptr type for source";
538 return success();
539 }
540 case cir::CastKind::integral: {
541 if (!mlir::isa<cir::IntType>(resType))
542 return emitOpError() << "requires !cir.int type for result";
543 if (!mlir::isa<cir::IntType>(srcType))
544 return emitOpError() << "requires !cir.int type for source";
545 return success();
546 }
547 case cir::CastKind::array_to_ptrdecay: {
548 const auto arrayPtrTy = mlir::dyn_cast<cir::PointerType>(srcType);
549 const auto flatPtrTy = mlir::dyn_cast<cir::PointerType>(resType);
550 if (!arrayPtrTy || !flatPtrTy)
551 return emitOpError() << "requires !cir.ptr type for source and result";
552
553 // TODO(CIR): Make sure the AddrSpace of both types are equals
554 return success();
555 }
556 case cir::CastKind::bitcast: {
557 // Handle the pointer types first.
558 auto srcPtrTy = mlir::dyn_cast<cir::PointerType>(srcType);
559 auto resPtrTy = mlir::dyn_cast<cir::PointerType>(resType);
560
561 if (srcPtrTy && resPtrTy) {
562 return success();
563 }
564
565 return success();
566 }
567 case cir::CastKind::floating: {
568 if (!mlir::isa<cir::FPTypeInterface>(srcType) ||
569 !mlir::isa<cir::FPTypeInterface>(resType))
570 return emitOpError() << "requires !cir.float type for source and result";
571 return success();
572 }
573 case cir::CastKind::float_to_int: {
574 if (!mlir::isa<cir::FPTypeInterface>(srcType))
575 return emitOpError() << "requires !cir.float type for source";
576 if (!mlir::dyn_cast<cir::IntType>(resType))
577 return emitOpError() << "requires !cir.int type for result";
578 return success();
579 }
580 case cir::CastKind::int_to_ptr: {
581 if (!mlir::dyn_cast<cir::IntType>(srcType))
582 return emitOpError() << "requires !cir.int type for source";
583 if (!mlir::dyn_cast<cir::PointerType>(resType))
584 return emitOpError() << "requires !cir.ptr type for result";
585 return success();
586 }
587 case cir::CastKind::ptr_to_int: {
588 if (!mlir::dyn_cast<cir::PointerType>(srcType))
589 return emitOpError() << "requires !cir.ptr type for source";
590 if (!mlir::dyn_cast<cir::IntType>(resType))
591 return emitOpError() << "requires !cir.int type for result";
592 return success();
593 }
594 case cir::CastKind::float_to_bool: {
595 if (!mlir::isa<cir::FPTypeInterface>(srcType))
596 return emitOpError() << "requires !cir.float type for source";
597 if (!mlir::isa<cir::BoolType>(resType))
598 return emitOpError() << "requires !cir.bool type for result";
599 return success();
600 }
601 case cir::CastKind::bool_to_int: {
602 if (!mlir::isa<cir::BoolType>(srcType))
603 return emitOpError() << "requires !cir.bool type for source";
604 if (!mlir::isa<cir::IntType>(resType))
605 return emitOpError() << "requires !cir.int type for result";
606 return success();
607 }
608 case cir::CastKind::int_to_float: {
609 if (!mlir::isa<cir::IntType>(srcType))
610 return emitOpError() << "requires !cir.int type for source";
611 if (!mlir::isa<cir::FPTypeInterface>(resType))
612 return emitOpError() << "requires !cir.float type for result";
613 return success();
614 }
615 case cir::CastKind::bool_to_float: {
616 if (!mlir::isa<cir::BoolType>(srcType))
617 return emitOpError() << "requires !cir.bool type for source";
618 if (!mlir::isa<cir::FPTypeInterface>(resType))
619 return emitOpError() << "requires !cir.float type for result";
620 return success();
621 }
622 case cir::CastKind::address_space: {
623 auto srcPtrTy = mlir::dyn_cast<cir::PointerType>(srcType);
624 auto resPtrTy = mlir::dyn_cast<cir::PointerType>(resType);
625 if (!srcPtrTy || !resPtrTy)
626 return emitOpError() << "requires !cir.ptr type for source and result";
627 if (srcPtrTy.getPointee() != resPtrTy.getPointee())
628 return emitOpError() << "requires two types differ in addrspace only";
629 return success();
630 }
631 case cir::CastKind::float_to_complex: {
632 if (!mlir::isa<cir::FPTypeInterface>(srcType))
633 return emitOpError() << "requires !cir.float type for source";
634 auto resComplexTy = mlir::dyn_cast<cir::ComplexType>(resType);
635 if (!resComplexTy)
636 return emitOpError() << "requires !cir.complex type for result";
637 if (srcType != resComplexTy.getElementType())
638 return emitOpError() << "requires source type match result element type";
639 return success();
640 }
641 case cir::CastKind::int_to_complex: {
642 if (!mlir::isa<cir::IntType>(srcType))
643 return emitOpError() << "requires !cir.int type for source";
644 auto resComplexTy = mlir::dyn_cast<cir::ComplexType>(resType);
645 if (!resComplexTy)
646 return emitOpError() << "requires !cir.complex type for result";
647 if (srcType != resComplexTy.getElementType())
648 return emitOpError() << "requires source type match result element type";
649 return success();
650 }
651 case cir::CastKind::float_complex_to_real: {
652 auto srcComplexTy = mlir::dyn_cast<cir::ComplexType>(srcType);
653 if (!srcComplexTy)
654 return emitOpError() << "requires !cir.complex type for source";
655 if (!mlir::isa<cir::FPTypeInterface>(resType))
656 return emitOpError() << "requires !cir.float type for result";
657 if (srcComplexTy.getElementType() != resType)
658 return emitOpError() << "requires source element type match result type";
659 return success();
660 }
661 case cir::CastKind::int_complex_to_real: {
662 auto srcComplexTy = mlir::dyn_cast<cir::ComplexType>(srcType);
663 if (!srcComplexTy)
664 return emitOpError() << "requires !cir.complex type for source";
665 if (!mlir::isa<cir::IntType>(resType))
666 return emitOpError() << "requires !cir.int type for result";
667 if (srcComplexTy.getElementType() != resType)
668 return emitOpError() << "requires source element type match result type";
669 return success();
670 }
671 case cir::CastKind::float_complex_to_bool: {
672 auto srcComplexTy = mlir::dyn_cast<cir::ComplexType>(srcType);
673 if (!srcComplexTy || !srcComplexTy.isFloatingPointComplex())
674 return emitOpError()
675 << "requires floating point !cir.complex type for source";
676 if (!mlir::isa<cir::BoolType>(resType))
677 return emitOpError() << "requires !cir.bool type for result";
678 return success();
679 }
680 case cir::CastKind::int_complex_to_bool: {
681 auto srcComplexTy = mlir::dyn_cast<cir::ComplexType>(srcType);
682 if (!srcComplexTy || !srcComplexTy.isIntegerComplex())
683 return emitOpError()
684 << "requires floating point !cir.complex type for source";
685 if (!mlir::isa<cir::BoolType>(resType))
686 return emitOpError() << "requires !cir.bool type for result";
687 return success();
688 }
689 case cir::CastKind::float_complex: {
690 auto srcComplexTy = mlir::dyn_cast<cir::ComplexType>(srcType);
691 if (!srcComplexTy || !srcComplexTy.isFloatingPointComplex())
692 return emitOpError()
693 << "requires floating point !cir.complex type for source";
694 auto resComplexTy = mlir::dyn_cast<cir::ComplexType>(resType);
695 if (!resComplexTy || !resComplexTy.isFloatingPointComplex())
696 return emitOpError()
697 << "requires floating point !cir.complex type for result";
698 return success();
699 }
700 case cir::CastKind::float_complex_to_int_complex: {
701 auto srcComplexTy = mlir::dyn_cast<cir::ComplexType>(srcType);
702 if (!srcComplexTy || !srcComplexTy.isFloatingPointComplex())
703 return emitOpError()
704 << "requires floating point !cir.complex type for source";
705 auto resComplexTy = mlir::dyn_cast<cir::ComplexType>(resType);
706 if (!resComplexTy || !resComplexTy.isIntegerComplex())
707 return emitOpError() << "requires integer !cir.complex type for result";
708 return success();
709 }
710 case cir::CastKind::int_complex: {
711 auto srcComplexTy = mlir::dyn_cast<cir::ComplexType>(srcType);
712 if (!srcComplexTy || !srcComplexTy.isIntegerComplex())
713 return emitOpError() << "requires integer !cir.complex type for source";
714 auto resComplexTy = mlir::dyn_cast<cir::ComplexType>(resType);
715 if (!resComplexTy || !resComplexTy.isIntegerComplex())
716 return emitOpError() << "requires integer !cir.complex type for result";
717 return success();
718 }
719 case cir::CastKind::int_complex_to_float_complex: {
720 auto srcComplexTy = mlir::dyn_cast<cir::ComplexType>(srcType);
721 if (!srcComplexTy || !srcComplexTy.isIntegerComplex())
722 return emitOpError() << "requires integer !cir.complex type for source";
723 auto resComplexTy = mlir::dyn_cast<cir::ComplexType>(resType);
724 if (!resComplexTy || !resComplexTy.isFloatingPointComplex())
725 return emitOpError()
726 << "requires floating point !cir.complex type for result";
727 return success();
728 }
729 case cir::CastKind::member_ptr_to_bool: {
730 if (!mlir::isa<cir::DataMemberType, cir::MethodType>(srcType))
731 return emitOpError()
732 << "requires !cir.data_member or !cir.method type for source";
733 if (!mlir::isa<cir::BoolType>(resType))
734 return emitOpError() << "requires !cir.bool type for result";
735 return success();
736 }
737 }
738 llvm_unreachable("Unknown CastOp kind?");
739}
740
741static bool isIntOrBoolCast(cir::CastOp op) {
742 auto kind = op.getKind();
743 return kind == cir::CastKind::bool_to_int ||
744 kind == cir::CastKind::int_to_bool || kind == cir::CastKind::integral;
745}
746
747static Value tryFoldCastChain(cir::CastOp op) {
748 cir::CastOp head = op, tail = op;
749
750 while (op) {
751 if (!isIntOrBoolCast(op))
752 break;
753 head = op;
754 op = head.getSrc().getDefiningOp<cir::CastOp>();
755 }
756
757 if (head == tail)
758 return {};
759
760 // if bool_to_int -> ... -> int_to_bool: take the bool
761 // as we had it was before all casts
762 if (head.getKind() == cir::CastKind::bool_to_int &&
763 tail.getKind() == cir::CastKind::int_to_bool)
764 return head.getSrc();
765
766 // if int_to_bool -> ... -> int_to_bool: take the result
767 // of the first one, as no other casts (and ext casts as well)
768 // don't change the first result
769 if (head.getKind() == cir::CastKind::int_to_bool &&
770 tail.getKind() == cir::CastKind::int_to_bool)
771 return head.getResult();
772
773 return {};
774}
775
776OpFoldResult cir::CastOp::fold(FoldAdaptor adaptor) {
777 if (mlir::isa_and_present<cir::PoisonAttr>(adaptor.getSrc())) {
778 // Propagate poison value
779 return cir::PoisonAttr::get(getContext(), getType());
780 }
781
782 if (getSrc().getType() == getType()) {
783 switch (getKind()) {
784 case cir::CastKind::integral: {
786 auto foldOrder = getSrc().getDefiningOp()->fold(foldResults);
787 if (foldOrder.succeeded() && mlir::isa<mlir::Attribute>(foldResults[0]))
788 return mlir::cast<mlir::Attribute>(foldResults[0]);
789 return {};
790 }
791 case cir::CastKind::bitcast:
792 case cir::CastKind::address_space:
793 case cir::CastKind::float_complex:
794 case cir::CastKind::int_complex: {
795 return getSrc();
796 }
797 default:
798 return {};
799 }
800 }
801
802 // Handle cases where a chain of casts cancel out.
803 Value result = tryFoldCastChain(*this);
804 if (result)
805 return result;
806
807 // Handle simple constant casts.
808 if (auto srcConst = getSrc().getDefiningOp<cir::ConstantOp>()) {
809 switch (getKind()) {
810 case cir::CastKind::integral: {
811 mlir::Type srcTy = getSrc().getType();
812 // Don't try to fold vector casts for now.
813 assert(mlir::isa<cir::VectorType>(srcTy) ==
814 mlir::isa<cir::VectorType>(getType()));
815 if (mlir::isa<cir::VectorType>(srcTy))
816 break;
817
818 auto srcIntTy = mlir::cast<cir::IntType>(srcTy);
819 auto dstIntTy = mlir::cast<cir::IntType>(getType());
820 APInt newVal =
821 srcIntTy.isSigned()
822 ? srcConst.getIntValue().sextOrTrunc(dstIntTy.getWidth())
823 : srcConst.getIntValue().zextOrTrunc(dstIntTy.getWidth());
824 return cir::IntAttr::get(dstIntTy, newVal);
825 }
826 default:
827 break;
828 }
829 }
830 return {};
831}
832
833//===----------------------------------------------------------------------===//
834// CallOp
835//===----------------------------------------------------------------------===//
836
837mlir::OperandRange cir::CallOp::getArgOperands() {
838 if (isIndirect())
839 return getArgs().drop_front(1);
840 return getArgs();
841}
842
843mlir::MutableOperandRange cir::CallOp::getArgOperandsMutable() {
844 mlir::MutableOperandRange args = getArgsMutable();
845 if (isIndirect())
846 return args.slice(1, args.size() - 1);
847 return args;
848}
849
850mlir::Value cir::CallOp::getIndirectCall() {
851 assert(isIndirect());
852 return getOperand(0);
853}
854
855/// Return the operand at index 'i'.
856Value cir::CallOp::getArgOperand(unsigned i) {
857 if (isIndirect())
858 ++i;
859 return getOperand(i);
860}
861
862/// Return the number of operands.
863unsigned cir::CallOp::getNumArgOperands() {
864 if (isIndirect())
865 return this->getOperation()->getNumOperands() - 1;
866 return this->getOperation()->getNumOperands();
867}
868
869static mlir::ParseResult
870parseTryCallDestinations(mlir::OpAsmParser &parser,
871 mlir::OperationState &result) {
872 mlir::Block *normalDestSuccessor;
873 if (parser.parseSuccessor(normalDestSuccessor))
874 return mlir::failure();
875
876 if (parser.parseComma())
877 return mlir::failure();
878
879 mlir::Block *unwindDestSuccessor;
880 if (parser.parseSuccessor(unwindDestSuccessor))
881 return mlir::failure();
882
883 result.addSuccessors(normalDestSuccessor);
884 result.addSuccessors(unwindDestSuccessor);
885 return mlir::success();
886}
887
888static mlir::ParseResult parseCallCommon(mlir::OpAsmParser &parser,
889 mlir::OperationState &result,
890 bool hasDestinationBlocks = false) {
892 llvm::SMLoc opsLoc;
893 mlir::FlatSymbolRefAttr calleeAttr;
894
895 // If we cannot parse a string callee, it means this is an indirect call.
896 if (!parser
897 .parseOptionalAttribute(calleeAttr, CIRDialect::getCalleeAttrName(),
898 result.attributes)
899 .has_value()) {
900 OpAsmParser::UnresolvedOperand indirectVal;
901 // Do not resolve right now, since we need to figure out the type
902 if (parser.parseOperand(indirectVal).failed())
903 return failure();
904 ops.push_back(indirectVal);
905 }
906
907 if (parser.parseLParen())
908 return mlir::failure();
909
910 opsLoc = parser.getCurrentLocation();
911 if (parser.parseOperandList(ops))
912 return mlir::failure();
913 if (parser.parseRParen())
914 return mlir::failure();
915
916 if (hasDestinationBlocks &&
917 parseTryCallDestinations(parser, result).failed()) {
918 return ::mlir::failure();
919 }
920
921 if (parser.parseOptionalKeyword("musttail").succeeded())
922 result.addAttribute(CIRDialect::getMustTailAttrName(),
923 mlir::UnitAttr::get(parser.getContext()));
924
925 if (parser.parseOptionalKeyword("nothrow").succeeded())
926 result.addAttribute(CIRDialect::getNoThrowAttrName(),
927 mlir::UnitAttr::get(parser.getContext()));
928
929 if (parser.parseOptionalKeyword("side_effect").succeeded()) {
930 if (parser.parseLParen().failed())
931 return failure();
932 cir::SideEffect sideEffect;
933 if (parseCIRKeyword<cir::SideEffect>(parser, sideEffect).failed())
934 return failure();
935 if (parser.parseRParen().failed())
936 return failure();
937 auto attr = cir::SideEffectAttr::get(parser.getContext(), sideEffect);
938 result.addAttribute(CIRDialect::getSideEffectAttrName(), attr);
939 }
940
941 if (parser.parseOptionalAttrDict(result.attributes))
942 return ::mlir::failure();
943
944 if (parser.parseColon())
945 return ::mlir::failure();
946
947 SmallVector<Type> argTypes;
949 SmallVector<Type> resultTypes;
950 SmallVector<DictionaryAttr> resultAttrs;
951 if (call_interface_impl::parseFunctionSignature(parser, argTypes, argAttrs,
952 resultTypes, resultAttrs))
953 return mlir::failure();
954
955 if (resultTypes.size() > 1 || resultAttrs.size() > 1)
956 return parser.emitError(
957 parser.getCurrentLocation(),
958 "functions with multiple return types are not supported");
959
960 result.addTypes(resultTypes);
961
962 if (parser.resolveOperands(ops, argTypes, opsLoc, result.operands))
963 return mlir::failure();
964
965 if (!resultAttrs.empty() && resultAttrs[0])
966 result.addAttribute(
967 CIRDialect::getResAttrsAttrName(),
968 mlir::ArrayAttr::get(parser.getContext(), {resultAttrs[0]}));
969
970 // ArrayAttr requires a vector of 'Attribute', so we have to do the conversion
971 // here into a separate collection.
972 llvm::SmallVector<Attribute> convertedArgAttrs;
973 bool argAttrsEmpty = true;
974
975 llvm::transform(argAttrs, std::back_inserter(convertedArgAttrs),
976 [&](DictionaryAttr da) -> mlir::Attribute {
977 if (da)
978 argAttrsEmpty = false;
979 return da;
980 });
981
982 if (!argAttrsEmpty) {
983 llvm::ArrayRef argAttrsRef = convertedArgAttrs;
984 if (!calleeAttr) {
985 // Fixup for indirect calls, which get an extra entry in the 'args' for
986 // the indirect type, which doesn't get attributes.
987 argAttrsRef = argAttrsRef.drop_front();
988 }
989 result.addAttribute(CIRDialect::getArgAttrsAttrName(),
990 mlir::ArrayAttr::get(parser.getContext(), argAttrsRef));
991 }
992
993 return mlir::success();
994}
995
996static void
997printCallCommon(mlir::Operation *op, mlir::FlatSymbolRefAttr calleeSym,
998 mlir::Value indirectCallee, mlir::OpAsmPrinter &printer,
999 bool isNothrow, cir::SideEffect sideEffect, ArrayAttr argAttrs,
1000 ArrayAttr resAttrs, mlir::Block *normalDest = nullptr,
1001 mlir::Block *unwindDest = nullptr) {
1002 printer << ' ';
1003
1004 auto callLikeOp = mlir::cast<cir::CIRCallOpInterface>(op);
1005 auto ops = callLikeOp.getArgOperands();
1006
1007 if (calleeSym) {
1008 // Direct calls
1009 printer.printAttributeWithoutType(calleeSym);
1010 } else {
1011 // Indirect calls
1012 assert(indirectCallee);
1013 printer << indirectCallee;
1014 }
1015
1016 printer << "(" << ops << ")";
1017
1018 if (normalDest) {
1019 assert(unwindDest && "expected two successors");
1020 auto tryCall = cast<cir::TryCallOp>(op);
1021 printer << ' ' << tryCall.getNormalDest();
1022 printer << ",";
1023 printer << ' ';
1024 printer << tryCall.getUnwindDest();
1025 }
1026
1027 if (op->hasAttr(CIRDialect::getMustTailAttrName()))
1028 printer << " musttail";
1029
1030 if (isNothrow)
1031 printer << " nothrow";
1032
1033 if (sideEffect != cir::SideEffect::All) {
1034 printer << " side_effect(";
1035 printer << stringifySideEffect(sideEffect);
1036 printer << ")";
1037 }
1038
1040 CIRDialect::getCalleeAttrName(),
1041 CIRDialect::getMustTailAttrName(),
1042 CIRDialect::getNoThrowAttrName(),
1043 CIRDialect::getSideEffectAttrName(),
1044 CIRDialect::getOperandSegmentSizesAttrName(),
1045 llvm::StringRef("res_attrs"),
1046 llvm::StringRef("arg_attrs")};
1047 printer.printOptionalAttrDict(op->getAttrs(), elidedAttrs);
1048 printer << " : ";
1049 if (calleeSym || !argAttrs) {
1050 call_interface_impl::printFunctionSignature(
1051 printer, op->getOperands().getTypes(), argAttrs,
1052 /*isVariadic=*/false, op->getResultTypes(), resAttrs);
1053 } else {
1054 // indirect function calls use an 'arg' type for the type of its indirect
1055 // argument. However, we don't store a similar attribute collection. In
1056 // order to make `printFunctionSignature` have the attributes line up, we
1057 // have to make a 'shimmed' copy of the attributes that have a blank set of
1058 // attributes for the indirect argument.
1059 llvm::SmallVector<Attribute> shimmedArgAttrs;
1060 shimmedArgAttrs.push_back(mlir::DictionaryAttr::get(op->getContext(), {}));
1061 shimmedArgAttrs.append(argAttrs.begin(), argAttrs.end());
1062 call_interface_impl::printFunctionSignature(
1063 printer, op->getOperands().getTypes(),
1064 mlir::ArrayAttr::get(op->getContext(), shimmedArgAttrs),
1065 /*isVariadic=*/false, op->getResultTypes(), resAttrs);
1066 }
1067}
1068
1069mlir::ParseResult cir::CallOp::parse(mlir::OpAsmParser &parser,
1070 mlir::OperationState &result) {
1071 return parseCallCommon(parser, result);
1072}
1073
1074void cir::CallOp::print(mlir::OpAsmPrinter &p) {
1075 mlir::Value indirectCallee = isIndirect() ? getIndirectCall() : nullptr;
1076 cir::SideEffect sideEffect = getSideEffect();
1077 printCallCommon(*this, getCalleeAttr(), indirectCallee, p, getNothrow(),
1078 sideEffect, getArgAttrsAttr(), getResAttrsAttr());
1079}
1080
1081static LogicalResult
1082verifyCallCommInSymbolUses(mlir::Operation *op,
1083 SymbolTableCollection &symbolTable) {
1084 auto fnAttr =
1085 op->getAttrOfType<FlatSymbolRefAttr>(CIRDialect::getCalleeAttrName());
1086 if (!fnAttr) {
1087 // This is an indirect call, thus we don't have to check the symbol uses.
1088 return mlir::success();
1089 }
1090
1091 auto fn = symbolTable.lookupNearestSymbolFrom<cir::FuncOp>(op, fnAttr);
1092 if (!fn)
1093 return op->emitOpError() << "'" << fnAttr.getValue()
1094 << "' does not reference a valid function";
1095
1096 auto callIf = dyn_cast<cir::CIRCallOpInterface>(op);
1097 assert(callIf && "expected CIR call interface to be always available");
1098
1099 // Verify that the operand and result types match the callee. Note that
1100 // argument-checking is disabled for functions without a prototype.
1101 auto fnType = fn.getFunctionType();
1102 if (!fn.getNoProto()) {
1103 unsigned numCallOperands = callIf.getNumArgOperands();
1104 unsigned numFnOpOperands = fnType.getNumInputs();
1105
1106 if (!fnType.isVarArg() && numCallOperands != numFnOpOperands)
1107 return op->emitOpError("incorrect number of operands for callee");
1108 if (fnType.isVarArg() && numCallOperands < numFnOpOperands)
1109 return op->emitOpError("too few operands for callee");
1110
1111 for (unsigned i = 0, e = numFnOpOperands; i != e; ++i)
1112 if (callIf.getArgOperand(i).getType() != fnType.getInput(i))
1113 return op->emitOpError("operand type mismatch: expected operand type ")
1114 << fnType.getInput(i) << ", but provided "
1115 << op->getOperand(i).getType() << " for operand number " << i;
1116 }
1117
1119
1120 // Void function must not return any results.
1121 if (fnType.hasVoidReturn() && op->getNumResults() != 0)
1122 return op->emitOpError("callee returns void but call has results");
1123
1124 // Non-void function calls must return exactly one result.
1125 if (!fnType.hasVoidReturn() && op->getNumResults() != 1)
1126 return op->emitOpError("incorrect number of results for callee");
1127
1128 // Parent function and return value types must match.
1129 if (!fnType.hasVoidReturn() &&
1130 op->getResultTypes().front() != fnType.getReturnType()) {
1131 return op->emitOpError("result type mismatch: expected ")
1132 << fnType.getReturnType() << ", but provided "
1133 << op->getResult(0).getType();
1134 }
1135
1136 return mlir::success();
1137}
1138
1139LogicalResult
1140cir::CallOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1141 return verifyCallCommInSymbolUses(*this, symbolTable);
1142}
1143
1144//===----------------------------------------------------------------------===//
1145// TryCallOp
1146//===----------------------------------------------------------------------===//
1147
1148mlir::OperandRange cir::TryCallOp::getArgOperands() {
1149 if (isIndirect())
1150 return getArgs().drop_front(1);
1151 return getArgs();
1152}
1153
1154mlir::MutableOperandRange cir::TryCallOp::getArgOperandsMutable() {
1155 mlir::MutableOperandRange args = getArgsMutable();
1156 if (isIndirect())
1157 return args.slice(1, args.size() - 1);
1158 return args;
1159}
1160
1161mlir::Value cir::TryCallOp::getIndirectCall() {
1162 assert(isIndirect());
1163 return getOperand(0);
1164}
1165
1166/// Return the operand at index 'i'.
1167Value cir::TryCallOp::getArgOperand(unsigned i) {
1168 if (isIndirect())
1169 ++i;
1170 return getOperand(i);
1171}
1172
1173/// Return the number of operands.
1174unsigned cir::TryCallOp::getNumArgOperands() {
1175 if (isIndirect())
1176 return this->getOperation()->getNumOperands() - 1;
1177 return this->getOperation()->getNumOperands();
1178}
1179
1180LogicalResult
1181cir::TryCallOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1182 return verifyCallCommInSymbolUses(*this, symbolTable);
1183}
1184
1185mlir::ParseResult cir::TryCallOp::parse(mlir::OpAsmParser &parser,
1186 mlir::OperationState &result) {
1187 return parseCallCommon(parser, result, /*hasDestinationBlocks=*/true);
1188}
1189
1190void cir::TryCallOp::print(::mlir::OpAsmPrinter &p) {
1191 mlir::Value indirectCallee = isIndirect() ? getIndirectCall() : nullptr;
1192 cir::SideEffect sideEffect = getSideEffect();
1193 printCallCommon(*this, getCalleeAttr(), indirectCallee, p, getNothrow(),
1194 sideEffect, getArgAttrsAttr(), getResAttrsAttr(),
1195 getNormalDest(), getUnwindDest());
1196}
1197
1198//===----------------------------------------------------------------------===//
1199// ReturnOp
1200//===----------------------------------------------------------------------===//
1201
1202static mlir::LogicalResult checkReturnAndFunction(cir::ReturnOp op,
1203 cir::FuncOp function) {
1204 // ReturnOps currently only have a single optional operand.
1205 if (op.getNumOperands() > 1)
1206 return op.emitOpError() << "expects at most 1 return operand";
1207
1208 // Ensure returned type matches the function signature.
1209 auto expectedTy = function.getFunctionType().getReturnType();
1210 auto actualTy =
1211 (op.getNumOperands() == 0 ? cir::VoidType::get(op.getContext())
1212 : op.getOperand(0).getType());
1213 if (actualTy != expectedTy)
1214 return op.emitOpError() << "returns " << actualTy
1215 << " but enclosing function returns " << expectedTy;
1216
1217 return mlir::success();
1218}
1219
1220mlir::LogicalResult cir::ReturnOp::verify() {
1221 // Returns can be present in multiple different scopes, get the
1222 // wrapping function and start from there.
1223 auto *fnOp = getOperation()->getParentOp();
1224 while (!isa<cir::FuncOp>(fnOp))
1225 fnOp = fnOp->getParentOp();
1226
1227 // Make sure return types match function return type.
1228 if (checkReturnAndFunction(*this, cast<cir::FuncOp>(fnOp)).failed())
1229 return failure();
1230
1231 return success();
1232}
1233
1234//===----------------------------------------------------------------------===//
1235// IfOp
1236//===----------------------------------------------------------------------===//
1237
1238ParseResult cir::IfOp::parse(OpAsmParser &parser, OperationState &result) {
1239 // create the regions for 'then'.
1240 result.regions.reserve(2);
1241 Region *thenRegion = result.addRegion();
1242 Region *elseRegion = result.addRegion();
1243
1244 mlir::Builder &builder = parser.getBuilder();
1245 OpAsmParser::UnresolvedOperand cond;
1246 Type boolType = cir::BoolType::get(builder.getContext());
1247
1248 if (parser.parseOperand(cond) ||
1249 parser.resolveOperand(cond, boolType, result.operands))
1250 return failure();
1251
1252 // Parse 'then' region.
1253 mlir::SMLoc parseThenLoc = parser.getCurrentLocation();
1254 if (parser.parseRegion(*thenRegion, /*arguments=*/{}, /*argTypes=*/{}))
1255 return failure();
1256
1257 if (ensureRegionTerm(parser, *thenRegion, parseThenLoc).failed())
1258 return failure();
1259
1260 // If we find an 'else' keyword, parse the 'else' region.
1261 if (!parser.parseOptionalKeyword("else")) {
1262 mlir::SMLoc parseElseLoc = parser.getCurrentLocation();
1263 if (parser.parseRegion(*elseRegion, /*arguments=*/{}, /*argTypes=*/{}))
1264 return failure();
1265 if (ensureRegionTerm(parser, *elseRegion, parseElseLoc).failed())
1266 return failure();
1267 }
1268
1269 // Parse the optional attribute list.
1270 if (parser.parseOptionalAttrDict(result.attributes))
1271 return failure();
1272 return success();
1273}
1274
1275void cir::IfOp::print(OpAsmPrinter &p) {
1276 p << " " << getCondition() << " ";
1277 mlir::Region &thenRegion = this->getThenRegion();
1278 p.printRegion(thenRegion,
1279 /*printEntryBlockArgs=*/false,
1280 /*printBlockTerminators=*/!omitRegionTerm(thenRegion));
1281
1282 // Print the 'else' regions if it exists and has a block.
1283 mlir::Region &elseRegion = this->getElseRegion();
1284 if (!elseRegion.empty()) {
1285 p << " else ";
1286 p.printRegion(elseRegion,
1287 /*printEntryBlockArgs=*/false,
1288 /*printBlockTerminators=*/!omitRegionTerm(elseRegion));
1289 }
1290
1291 p.printOptionalAttrDict(getOperation()->getAttrs());
1292}
1293
1294/// Default callback for IfOp builders.
1295void cir::buildTerminatedBody(OpBuilder &builder, Location loc) {
1296 // add cir.yield to end of the block
1297 cir::YieldOp::create(builder, loc);
1298}
1299
1300/// Given the region at `index`, or the parent operation if `index` is None,
1301/// return the successor regions. These are the regions that may be selected
1302/// during the flow of control. `operands` is a set of optional attributes that
1303/// correspond to a constant value for each operand, or null if that operand is
1304/// not a constant.
1305void cir::IfOp::getSuccessorRegions(mlir::RegionBranchPoint point,
1306 SmallVectorImpl<RegionSuccessor> &regions) {
1307 // The `then` and the `else` region branch back to the parent operation.
1308 if (!point.isParent()) {
1309 regions.push_back(RegionSuccessor::parent());
1310 return;
1311 }
1312
1313 // Don't consider the else region if it is empty.
1314 Region *elseRegion = &this->getElseRegion();
1315 if (elseRegion->empty())
1316 elseRegion = nullptr;
1317
1318 // If the condition isn't constant, both regions may be executed.
1319 regions.push_back(RegionSuccessor(&getThenRegion()));
1320 // If the else region does not exist, it is not a viable successor.
1321 if (elseRegion)
1322 regions.push_back(RegionSuccessor(elseRegion));
1323
1324 return;
1325}
1326
1327mlir::ValueRange cir::IfOp::getSuccessorInputs(RegionSuccessor successor) {
1328 return successor.isParent() ? ValueRange(getOperation()->getResults())
1329 : ValueRange();
1330}
1331
1332void cir::IfOp::build(OpBuilder &builder, OperationState &result, Value cond,
1333 bool withElseRegion, BuilderCallbackRef thenBuilder,
1334 BuilderCallbackRef elseBuilder) {
1335 assert(thenBuilder && "the builder callback for 'then' must be present");
1336 result.addOperands(cond);
1337
1338 OpBuilder::InsertionGuard guard(builder);
1339 Region *thenRegion = result.addRegion();
1340 builder.createBlock(thenRegion);
1341 thenBuilder(builder, result.location);
1342
1343 Region *elseRegion = result.addRegion();
1344 if (!withElseRegion)
1345 return;
1346
1347 builder.createBlock(elseRegion);
1348 elseBuilder(builder, result.location);
1349}
1350
1351//===----------------------------------------------------------------------===//
1352// ScopeOp
1353//===----------------------------------------------------------------------===//
1354
1355/// Given the region at `index`, or the parent operation if `index` is None,
1356/// return the successor regions. These are the regions that may be selected
1357/// during the flow of control. `operands` is a set of optional attributes
1358/// that correspond to a constant value for each operand, or null if that
1359/// operand is not a constant.
1360void cir::ScopeOp::getSuccessorRegions(
1361 mlir::RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
1362 // The only region always branch back to the parent operation.
1363 if (!point.isParent()) {
1364 regions.push_back(RegionSuccessor::parent());
1365 return;
1366 }
1367
1368 // If the condition isn't constant, both regions may be executed.
1369 regions.push_back(RegionSuccessor(&getScopeRegion()));
1370}
1371
1372mlir::ValueRange cir::ScopeOp::getSuccessorInputs(RegionSuccessor successor) {
1373 return successor.isParent() ? ValueRange(getOperation()->getResults())
1374 : ValueRange();
1375}
1376
1377void cir::ScopeOp::build(
1378 OpBuilder &builder, OperationState &result,
1379 function_ref<void(OpBuilder &, Type &, Location)> scopeBuilder) {
1380 assert(scopeBuilder && "the builder callback for 'then' must be present");
1381
1382 OpBuilder::InsertionGuard guard(builder);
1383 Region *scopeRegion = result.addRegion();
1384 builder.createBlock(scopeRegion);
1386
1387 mlir::Type yieldTy;
1388 scopeBuilder(builder, yieldTy, result.location);
1389
1390 if (yieldTy)
1391 result.addTypes(TypeRange{yieldTy});
1392}
1393
1394void cir::ScopeOp::build(
1395 OpBuilder &builder, OperationState &result,
1396 function_ref<void(OpBuilder &, Location)> scopeBuilder) {
1397 assert(scopeBuilder && "the builder callback for 'then' must be present");
1398 OpBuilder::InsertionGuard guard(builder);
1399 Region *scopeRegion = result.addRegion();
1400 builder.createBlock(scopeRegion);
1402 scopeBuilder(builder, result.location);
1403}
1404
1405LogicalResult cir::ScopeOp::verify() {
1406 if (getRegion().empty()) {
1407 return emitOpError() << "cir.scope must not be empty since it should "
1408 "include at least an implicit cir.yield ";
1409 }
1410
1411 mlir::Block &lastBlock = getRegion().back();
1412 if (lastBlock.empty() || !lastBlock.mightHaveTerminator() ||
1413 !lastBlock.getTerminator()->hasTrait<OpTrait::IsTerminator>())
1414 return emitOpError() << "last block of cir.scope must be terminated";
1415 return success();
1416}
1417
1418LogicalResult cir::ScopeOp::fold(FoldAdaptor /*adaptor*/,
1419 SmallVectorImpl<OpFoldResult> &results) {
1420 // Only fold "trivial" scopes: a single block containing only a `cir.yield`.
1421 if (!getRegion().hasOneBlock())
1422 return failure();
1423 Block &block = getRegion().front();
1424 if (block.getOperations().size() != 1)
1425 return failure();
1426
1427 auto yield = dyn_cast<cir::YieldOp>(block.front());
1428 if (!yield)
1429 return failure();
1430
1431 // Only fold when the scope produces a value.
1432 if (getNumResults() != 1 || yield.getNumOperands() != 1)
1433 return failure();
1434
1435 results.push_back(yield.getOperand(0));
1436 return success();
1437}
1438
1439//===----------------------------------------------------------------------===//
1440// CleanupScopeOp
1441//===----------------------------------------------------------------------===//
1442
1443void cir::CleanupScopeOp::getSuccessorRegions(
1444 mlir::RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
1445 if (!point.isParent()) {
1446 regions.push_back(RegionSuccessor::parent());
1447 return;
1448 }
1449
1450 // Execution always proceeds from the body region to the cleanup region.
1451 regions.push_back(RegionSuccessor(&getBodyRegion()));
1452 regions.push_back(RegionSuccessor(&getCleanupRegion()));
1453}
1454
1455mlir::ValueRange
1456cir::CleanupScopeOp::getSuccessorInputs(RegionSuccessor successor) {
1457 return ValueRange();
1458}
1459
1460LogicalResult cir::CleanupScopeOp::canonicalize(CleanupScopeOp op,
1461 PatternRewriter &rewriter) {
1462 auto isRegionTrivial = [](Region &region) {
1463 assert(!region.empty() && "CleanupScopeOp regions must not be empty");
1464 if (!region.hasOneBlock())
1465 return false;
1466 Block &block = llvm::getSingleElement(region);
1467 return llvm::hasSingleElement(block) &&
1468 isa<cir::YieldOp>(llvm::getSingleElement(block));
1469 };
1470
1471 Region &body = op.getBodyRegion();
1472 Region &cleanup = op.getCleanupRegion();
1473
1474 // An EH-only cleanup scope with an empty body can never trigger its cleanup
1475 // region — there are no operations in the body that could throw. Erase it.
1476 if (op.getCleanupKind() == CleanupKind::EH && isRegionTrivial(body)) {
1477 rewriter.eraseOp(op);
1478 return success();
1479 }
1480
1481 // A cleanup scope with a trivial cleanup region has no cleanup to perform.
1482 // Inline the body into the parent block and erase the scope.
1483 if (!isRegionTrivial(cleanup) || !body.hasOneBlock())
1484 return failure();
1485
1486 Block &bodyBlock = body.front();
1487 if (!isa<cir::YieldOp>(bodyBlock.getTerminator()))
1488 return failure();
1489
1490 Operation *yield = bodyBlock.getTerminator();
1491 rewriter.inlineBlockBefore(&bodyBlock, op);
1492 rewriter.eraseOp(yield);
1493 rewriter.eraseOp(op);
1494 return success();
1495}
1496
1497void cir::CleanupScopeOp::build(
1498 OpBuilder &builder, OperationState &result, CleanupKind cleanupKind,
1499 function_ref<void(OpBuilder &, Location)> bodyBuilder,
1500 function_ref<void(OpBuilder &, Location)> cleanupBuilder) {
1501 result.addAttribute(getCleanupKindAttrName(result.name),
1502 CleanupKindAttr::get(builder.getContext(), cleanupKind));
1503
1504 OpBuilder::InsertionGuard guard(builder);
1505
1506 // Build body region.
1507 Region *bodyRegion = result.addRegion();
1508 builder.createBlock(bodyRegion);
1509 if (bodyBuilder)
1510 bodyBuilder(builder, result.location);
1511
1512 // Build cleanup region.
1513 Region *cleanupRegion = result.addRegion();
1514 builder.createBlock(cleanupRegion);
1515 if (cleanupBuilder)
1516 cleanupBuilder(builder, result.location);
1517}
1518
1519//===----------------------------------------------------------------------===//
1520// BrOp
1521//===----------------------------------------------------------------------===//
1522
1523/// Merges blocks connected by a unique unconditional branch.
1524///
1525/// ^bb0: ^bb0:
1526/// ... ...
1527/// cir.br ^bb1 => ...
1528/// ^bb1: cir.return
1529/// ...
1530/// cir.return
1531LogicalResult cir::BrOp::canonicalize(BrOp op, PatternRewriter &rewriter) {
1532 Block *src = op->getBlock();
1533 Block *dst = op.getDest();
1534
1535 // Do not fold self-loops.
1536 if (src == dst)
1537 return failure();
1538
1539 // Only merge when this is the unique edge between the blocks.
1540 if (src->getNumSuccessors() != 1 || dst->getSinglePredecessor() != src)
1541 return failure();
1542
1543 // Don't merge blocks that start with LabelOp or IndirectBrOp.
1544 // This is to avoid merging blocks that have an indirect predecessor.
1545 if (isa<cir::LabelOp, cir::IndirectBrOp>(dst->front()))
1546 return failure();
1547
1548 auto operands = op.getDestOperands();
1549 rewriter.eraseOp(op);
1550 rewriter.mergeBlocks(dst, src, operands);
1551 return success();
1552}
1553
1554mlir::SuccessorOperands cir::BrOp::getSuccessorOperands(unsigned index) {
1555 assert(index == 0 && "invalid successor index");
1556 return mlir::SuccessorOperands(getDestOperandsMutable());
1557}
1558
1559Block *cir::BrOp::getSuccessorForOperands(ArrayRef<Attribute>) {
1560 return getDest();
1561}
1562
1563//===----------------------------------------------------------------------===//
// IndirectBrOp
1565//===----------------------------------------------------------------------===//
1566
1567mlir::SuccessorOperands
1568cir::IndirectBrOp::getSuccessorOperands(unsigned index) {
1569 assert(index < getNumSuccessors() && "invalid successor index");
1570 return mlir::SuccessorOperands(getSuccOperandsMutable()[index]);
1571}
1572
1574 OpAsmParser &parser, Type &flagType,
1575 SmallVectorImpl<Block *> &succOperandBlocks,
1576 SmallVectorImpl<SmallVector<OpAsmParser::UnresolvedOperand>> &succOperands,
1577 SmallVectorImpl<SmallVector<Type>> &succOperandsTypes) {
1578 if (failed(parser.parseCommaSeparatedList(
1579 OpAsmParser::Delimiter::Square,
1580 [&]() {
1581 Block *destination = nullptr;
1582 SmallVector<OpAsmParser::UnresolvedOperand> operands;
1583 SmallVector<Type> operandTypes;
1584
1585 if (parser.parseSuccessor(destination).failed())
1586 return failure();
1587
1588 if (succeeded(parser.parseOptionalLParen())) {
1589 if (failed(parser.parseOperandList(
1590 operands, OpAsmParser::Delimiter::None)) ||
1591 failed(parser.parseColonTypeList(operandTypes)) ||
1592 failed(parser.parseRParen()))
1593 return failure();
1594 }
1595 succOperandBlocks.push_back(destination);
1596 succOperands.emplace_back(operands);
1597 succOperandsTypes.emplace_back(operandTypes);
1598 return success();
1599 },
1600 "successor blocks")))
1601 return failure();
1602 return success();
1603}
1604
1605void printIndirectBrOpSucessors(OpAsmPrinter &p, cir::IndirectBrOp op,
1606 Type flagType, SuccessorRange succs,
1607 OperandRangeRange succOperands,
1608 const TypeRangeRange &succOperandsTypes) {
1609 p << "[";
1610 llvm::interleave(
1611 llvm::zip(succs, succOperands),
1612 [&](auto i) {
1613 p.printNewline();
1614 p.printSuccessorAndUseList(std::get<0>(i), std::get<1>(i));
1615 },
1616 [&] { p << ','; });
1617 if (!succOperands.empty())
1618 p.printNewline();
1619 p << "]";
1620}
1621
1622//===----------------------------------------------------------------------===//
1623// BrCondOp
1624//===----------------------------------------------------------------------===//
1625
1626mlir::SuccessorOperands cir::BrCondOp::getSuccessorOperands(unsigned index) {
1627 assert(index < getNumSuccessors() && "invalid successor index");
1628 return SuccessorOperands(index == 0 ? getDestOperandsTrueMutable()
1629 : getDestOperandsFalseMutable());
1630}
1631
1632Block *cir::BrCondOp::getSuccessorForOperands(ArrayRef<Attribute> operands) {
1633 if (IntegerAttr condAttr = dyn_cast_if_present<IntegerAttr>(operands.front()))
1634 return condAttr.getValue().isOne() ? getDestTrue() : getDestFalse();
1635 return nullptr;
1636}
1637
1638//===----------------------------------------------------------------------===//
1639// CaseOp
1640//===----------------------------------------------------------------------===//
1641
1642void cir::CaseOp::getSuccessorRegions(
1643 mlir::RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
1644 if (!point.isParent()) {
1645 regions.push_back(RegionSuccessor::parent());
1646 return;
1647 }
1648 regions.push_back(RegionSuccessor(&getCaseRegion()));
1649}
1650
1651mlir::ValueRange cir::CaseOp::getSuccessorInputs(RegionSuccessor successor) {
1652 return successor.isParent() ? ValueRange(getOperation()->getResults())
1653 : ValueRange();
1654}
1655
1656void cir::CaseOp::build(OpBuilder &builder, OperationState &result,
1657 ArrayAttr value, CaseOpKind kind,
1658 OpBuilder::InsertPoint &insertPoint) {
1659 OpBuilder::InsertionGuard guardSwitch(builder);
1660 result.addAttribute("value", value);
1661 result.getOrAddProperties<Properties>().kind =
1662 cir::CaseOpKindAttr::get(builder.getContext(), kind);
1663 Region *caseRegion = result.addRegion();
1664 builder.createBlock(caseRegion);
1665
1666 insertPoint = builder.saveInsertionPoint();
1667}
1668
1669//===----------------------------------------------------------------------===//
1670// SwitchOp
1671//===----------------------------------------------------------------------===//
1672
1673void cir::SwitchOp::getSuccessorRegions(
1674 mlir::RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &region) {
1675 if (!point.isParent()) {
1676 region.push_back(RegionSuccessor::parent());
1677 return;
1678 }
1679
1680 region.push_back(RegionSuccessor(&getBody()));
1681}
1682
1683mlir::ValueRange cir::SwitchOp::getSuccessorInputs(RegionSuccessor successor) {
1684 return successor.isParent() ? ValueRange(getOperation()->getResults())
1685 : ValueRange();
1686}
1687
1688void cir::SwitchOp::build(OpBuilder &builder, OperationState &result,
1689 Value cond, BuilderOpStateCallbackRef switchBuilder) {
1690 assert(switchBuilder && "the builder callback for regions must be present");
1691 OpBuilder::InsertionGuard guardSwitch(builder);
1692 Region *switchRegion = result.addRegion();
1693 builder.createBlock(switchRegion);
1694 result.addOperands({cond});
1695 switchBuilder(builder, result.location, result);
1696}
1697
1698void cir::SwitchOp::collectCases(llvm::SmallVectorImpl<CaseOp> &cases) {
1699 walk<mlir::WalkOrder::PreOrder>([&](mlir::Operation *op) {
1700 // Don't walk in nested switch op.
1701 if (isa<cir::SwitchOp>(op) && op != *this)
1702 return WalkResult::skip();
1703
1704 if (auto caseOp = dyn_cast<cir::CaseOp>(op))
1705 cases.push_back(caseOp);
1706
1707 return WalkResult::advance();
1708 });
1709}
1710
1711bool cir::SwitchOp::isSimpleForm(llvm::SmallVectorImpl<CaseOp> &cases) {
1712 collectCases(cases);
1713
1714 if (getBody().empty())
1715 return false;
1716
1717 if (!isa<YieldOp>(getBody().front().back()))
1718 return false;
1719
1720 if (!llvm::all_of(getBody().front(),
1721 [](Operation &op) { return isa<CaseOp, YieldOp>(op); }))
1722 return false;
1723
1724 return llvm::all_of(cases, [this](CaseOp op) {
1725 return op->getParentOfType<SwitchOp>() == *this;
1726 });
1727}
1728
1729//===----------------------------------------------------------------------===//
1730// SwitchFlatOp
1731//===----------------------------------------------------------------------===//
1732
1733void cir::SwitchFlatOp::build(OpBuilder &builder, OperationState &result,
1734 Value value, Block *defaultDestination,
1735 ValueRange defaultOperands,
1736 ArrayRef<APInt> caseValues,
1737 BlockRange caseDestinations,
1738 ArrayRef<ValueRange> caseOperands) {
1739
1740 std::vector<mlir::Attribute> caseValuesAttrs;
1741 for (const APInt &val : caseValues)
1742 caseValuesAttrs.push_back(cir::IntAttr::get(value.getType(), val));
1743 mlir::ArrayAttr attrs = ArrayAttr::get(builder.getContext(), caseValuesAttrs);
1744
1745 build(builder, result, value, defaultOperands, caseOperands, attrs,
1746 defaultDestination, caseDestinations);
1747}
1748
1749/// <cases> ::= `[` (case (`,` case )* )? `]`
1750/// <case> ::= integer `:` bb-id (`(` ssa-use-and-type-list `)`)?
///
/// Parses the case list of a flat switch. Each case value is read as a
/// signed 64-bit integer and wrapped in a #cir.int of type \p flagType. For
/// every case, the destination block, its (possibly empty) operand list and
/// the operand types are appended to the matching output vectors. An empty
/// list `[]` returns success immediately and leaves \p caseValues unassigned.
1751static ParseResult parseSwitchFlatOpCases(
1752 OpAsmParser &parser, Type flagType, mlir::ArrayAttr &caseValues,
1753 SmallVectorImpl<Block *> &caseDestinations,
1755 &caseOperands,
1756 SmallVectorImpl<llvm::SmallVector<Type>> &caseOperandTypes) {
1757 if (failed(parser.parseLSquare()))
1758 return failure();
// Empty case list: nothing else to parse.
1759 if (succeeded(parser.parseOptionalRSquare()))
1760 return success();
1762
// Parses one `value : ^dest(operands : types)` entry.
1763 auto parseCase = [&]() {
1764 int64_t value = 0;
1765 if (failed(parser.parseInteger(value)))
1766 return failure();
1767
1768 values.push_back(cir::IntAttr::get(flagType, value));
1769
1770 Block *destination;
1772 llvm::SmallVector<Type> operandTypes;
1773 if (parser.parseColon() || parser.parseSuccessor(destination))
1774 return failure();
// Successor operands are optional: `(%a, %b : t1, t2)`.
1775 if (!parser.parseOptionalLParen()) {
1776 if (parser.parseOperandList(operands, OpAsmParser::Delimiter::None,
1777 /*allowResultNumber=*/false) ||
1778 parser.parseColonTypeList(operandTypes) || parser.parseRParen())
1779 return failure();
1780 }
1781 caseDestinations.push_back(destination);
1782 caseOperands.emplace_back(operands);
1783 caseOperandTypes.emplace_back(operandTypes);
1784 return success();
1785 };
1786 if (failed(parser.parseCommaSeparatedList(parseCase)))
1787 return failure();
1788
1789 caseValues = ArrayAttr::get(flagType.getContext(), values);
1790
1791 return parser.parseRSquare();
1792}
1793
1794static void printSwitchFlatOpCases(OpAsmPrinter &p, cir::SwitchFlatOp op,
1795 Type flagType, mlir::ArrayAttr caseValues,
1796 SuccessorRange caseDestinations,
1797 OperandRangeRange caseOperands,
1798 const TypeRangeRange &caseOperandTypes) {
1799 p << '[';
1800 p.printNewline();
1801 if (!caseValues) {
1802 p << ']';
1803 return;
1804 }
1805
1806 size_t index = 0;
1807 llvm::interleave(
1808 llvm::zip(caseValues, caseDestinations),
1809 [&](auto i) {
1810 p << " ";
1811 mlir::Attribute a = std::get<0>(i);
1812 p << mlir::cast<cir::IntAttr>(a).getValue();
1813 p << ": ";
1814 p.printSuccessorAndUseList(std::get<1>(i), caseOperands[index++]);
1815 },
1816 [&] {
1817 p << ',';
1818 p.printNewline();
1819 });
1820 p.printNewline();
1821 p << ']';
1822}
1823
1824//===----------------------------------------------------------------------===//
1825// GlobalOp
1826//===----------------------------------------------------------------------===//
1827
1828static ParseResult parseConstantValue(OpAsmParser &parser,
1829 mlir::Attribute &valueAttr) {
1830 NamedAttrList attr;
1831 return parser.parseAttribute(valueAttr, "value", attr);
1832}
1833
1834static void printConstant(OpAsmPrinter &p, Attribute value) {
1835 p.printAttribute(value);
1836}
1837
1838mlir::LogicalResult cir::GlobalOp::verify() {
1839 // Verify that the initial value, if present, is either a unit attribute or
1840 // an attribute CIR supports.
1841 if (getInitialValue().has_value()) {
1842 if (checkConstantTypes(getOperation(), getSymType(), *getInitialValue())
1843 .failed())
1844 return failure();
1845 }
1846
1847 // TODO(CIR): Many other checks for properties that haven't been upstreamed
1848 // yet.
1849
1850 return success();
1851}
1852
1853void cir::GlobalOp::build(
1854 OpBuilder &odsBuilder, OperationState &odsState, llvm::StringRef sym_name,
1855 mlir::Type sym_type, bool isConstant,
1856 mlir::ptr::MemorySpaceAttrInterface addrSpace,
1857 cir::GlobalLinkageKind linkage,
1858 function_ref<void(OpBuilder &, Location)> ctorBuilder,
1859 function_ref<void(OpBuilder &, Location)> dtorBuilder) {
1860 odsState.addAttribute(getSymNameAttrName(odsState.name),
1861 odsBuilder.getStringAttr(sym_name));
1862 odsState.addAttribute(getSymTypeAttrName(odsState.name),
1863 mlir::TypeAttr::get(sym_type));
1864 auto &properties = odsState.getOrAddProperties<cir::GlobalOp::Properties>();
1865 properties.setConstant(isConstant);
1866
1867 addrSpace = normalizeDefaultAddressSpace(addrSpace);
1868 if (addrSpace)
1869 odsState.addAttribute(getAddrSpaceAttrName(odsState.name), addrSpace);
1870
1871 cir::GlobalLinkageKindAttr linkageAttr =
1872 cir::GlobalLinkageKindAttr::get(odsBuilder.getContext(), linkage);
1873 odsState.addAttribute(getLinkageAttrName(odsState.name), linkageAttr);
1874
1875 Region *ctorRegion = odsState.addRegion();
1876 if (ctorBuilder) {
1877 odsBuilder.createBlock(ctorRegion);
1878 ctorBuilder(odsBuilder, odsState.location);
1879 }
1880
1881 Region *dtorRegion = odsState.addRegion();
1882 if (dtorBuilder) {
1883 odsBuilder.createBlock(dtorRegion);
1884 dtorBuilder(odsBuilder, odsState.location);
1885 }
1886}
1887
1888/// Given the region at `index`, or the parent operation if `index` is None,
1889/// return the successor regions. These are the regions that may be selected
1890/// during the flow of control. `operands` is a set of optional attributes that
1891/// correspond to a constant value for each operand, or null if that operand is
1892/// not a constant.
1893void cir::GlobalOp::getSuccessorRegions(
1894 mlir::RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
1895 // The `ctor` and `dtor` regions always branch back to the parent operation.
1896 if (!point.isParent()) {
1897 regions.push_back(RegionSuccessor::parent());
1898 return;
1899 }
1900
1901 // Don't consider the ctor region if it is empty.
1902 Region *ctorRegion = &this->getCtorRegion();
1903 if (ctorRegion->empty())
1904 ctorRegion = nullptr;
1905
1906 // Don't consider the dtor region if it is empty.
1907 Region *dtorRegion = &this->getDtorRegion();
1908 if (dtorRegion->empty())
1909 dtorRegion = nullptr;
1910
1911 // If the condition isn't constant, both regions may be executed.
1912 if (ctorRegion)
1913 regions.push_back(RegionSuccessor(ctorRegion));
1914 if (dtorRegion)
1915 regions.push_back(RegionSuccessor(dtorRegion));
1916}
1917
1918mlir::ValueRange cir::GlobalOp::getSuccessorInputs(RegionSuccessor successor) {
1919 return successor.isParent() ? ValueRange(getOperation()->getResults())
1920 : ValueRange();
1921}
1922
1923static void printGlobalOpTypeAndInitialValue(OpAsmPrinter &p, cir::GlobalOp op,
1924 TypeAttr type, Attribute initAttr,
1925 mlir::Region &ctorRegion,
1926 mlir::Region &dtorRegion) {
1927 auto printType = [&]() { p << ": " << type; };
1928 if (!op.isDeclaration()) {
1929 p << "= ";
1930 if (!ctorRegion.empty()) {
1931 p << "ctor ";
1932 printType();
1933 p << " ";
1934 p.printRegion(ctorRegion,
1935 /*printEntryBlockArgs=*/false,
1936 /*printBlockTerminators=*/false);
1937 } else {
1938 // This also prints the type...
1939 if (initAttr)
1940 printConstant(p, initAttr);
1941 }
1942
1943 if (!dtorRegion.empty()) {
1944 p << " dtor ";
1945 p.printRegion(dtorRegion,
1946 /*printEntryBlockArgs=*/false,
1947 /*printBlockTerminators=*/false);
1948 }
1949 } else {
1950 printType();
1951 }
1952}
1953
1954static ParseResult parseGlobalOpTypeAndInitialValue(OpAsmParser &parser,
1955 TypeAttr &typeAttr,
1956 Attribute &initialValueAttr,
1957 mlir::Region &ctorRegion,
1958 mlir::Region &dtorRegion) {
1959 mlir::Type opTy;
1960 if (parser.parseOptionalEqual().failed()) {
1961 // Absence of equal means a declaration, so we need to parse the type.
1962 // cir.global @a : !cir.int<s, 32>
1963 if (parser.parseColonType(opTy))
1964 return failure();
1965 } else {
1966 // Parse contructor, example:
1967 // cir.global @rgb = ctor : type { ... }
1968 if (!parser.parseOptionalKeyword("ctor")) {
1969 if (parser.parseColonType(opTy))
1970 return failure();
1971 auto parseLoc = parser.getCurrentLocation();
1972 if (parser.parseRegion(ctorRegion, /*arguments=*/{}, /*argTypes=*/{}))
1973 return failure();
1974 if (ensureRegionTerm(parser, ctorRegion, parseLoc).failed())
1975 return failure();
1976 } else {
1977 // Parse constant with initializer, examples:
1978 // cir.global @y = 3.400000e+00 : f32
1979 // cir.global @rgb = #cir.const_array<[...] : !cir.array<i8 x 3>>
1980 if (parseConstantValue(parser, initialValueAttr).failed())
1981 return failure();
1982
1983 assert(mlir::isa<mlir::TypedAttr>(initialValueAttr) &&
1984 "Non-typed attrs shouldn't appear here.");
1985 auto typedAttr = mlir::cast<mlir::TypedAttr>(initialValueAttr);
1986 opTy = typedAttr.getType();
1987 }
1988
1989 // Parse destructor, example:
1990 // dtor { ... }
1991 if (!parser.parseOptionalKeyword("dtor")) {
1992 auto parseLoc = parser.getCurrentLocation();
1993 if (parser.parseRegion(dtorRegion, /*arguments=*/{}, /*argTypes=*/{}))
1994 return failure();
1995 if (ensureRegionTerm(parser, dtorRegion, parseLoc).failed())
1996 return failure();
1997 }
1998 }
1999
2000 typeAttr = TypeAttr::get(opTy);
2001 return success();
2002}
2003
2004//===----------------------------------------------------------------------===//
2005// GetGlobalOp
2006//===----------------------------------------------------------------------===//
2007
2008LogicalResult
2009cir::GetGlobalOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
2010 // Verify that the result type underlying pointer type matches the type of
2011 // the referenced cir.global or cir.func op.
2012 mlir::Operation *op =
2013 symbolTable.lookupNearestSymbolFrom(*this, getNameAttr());
2014 if (op == nullptr || !(isa<GlobalOp>(op) || isa<FuncOp>(op)))
2015 return emitOpError("'")
2016 << getName()
2017 << "' does not reference a valid cir.global or cir.func";
2018
2019 mlir::Type symTy;
2020 mlir::ptr::MemorySpaceAttrInterface symAddrSpaceAttr{};
2021 if (auto g = dyn_cast<GlobalOp>(op)) {
2022 symTy = g.getSymType();
2023 symAddrSpaceAttr = g.getAddrSpaceAttr();
2024 // Verify that for thread local global access, the global needs to
2025 // be marked with tls bits.
2026 if (getTls() && !g.getTlsModel())
2027 return emitOpError("access to global not marked thread local");
2028
2029 // Verify that the static_local attribute on GetGlobalOp matches the
2030 // static_local_guard attribute on GlobalOp. GetGlobalOp uses a UnitAttr,
2031 // GlobalOp uses StaticLocalGuardAttr. Both should be present, or neither.
2032 bool getGlobalIsStaticLocal = getStaticLocal();
2033 bool globalIsStaticLocal = g.getStaticLocalGuard().has_value();
2034 if (getGlobalIsStaticLocal != globalIsStaticLocal &&
2035 !getOperation()->getParentOfType<cir::GlobalOp>())
2036 return emitOpError("static_local attribute mismatch");
2037 } else if (auto f = dyn_cast<FuncOp>(op)) {
2038 symTy = f.getFunctionType();
2039 } else {
2040 llvm_unreachable("Unexpected operation for GetGlobalOp");
2041 }
2042
2043 auto resultType = dyn_cast<PointerType>(getAddr().getType());
2044 if (!resultType || symTy != resultType.getPointee())
2045 return emitOpError("result type pointee type '")
2046 << resultType.getPointee() << "' does not match type " << symTy
2047 << " of the global @" << getName();
2048
2049 if (symAddrSpaceAttr != resultType.getAddrSpace()) {
2050 return emitOpError()
2051 << "result type address space does not match the address "
2052 "space of the global @"
2053 << getName();
2054 }
2055
2056 return success();
2057}
2058
2059//===----------------------------------------------------------------------===//
2060// VTableAddrPointOp
2061//===----------------------------------------------------------------------===//
2062
2063LogicalResult
2064cir::VTableAddrPointOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
2065 StringRef name = getName();
2066
2067 // Verify that the result type underlying pointer type matches the type of
2068 // the referenced cir.global.
2069 auto op =
2070 symbolTable.lookupNearestSymbolFrom<cir::GlobalOp>(*this, getNameAttr());
2071 if (!op)
2072 return emitOpError("'")
2073 << name << "' does not reference a valid cir.global";
2074 std::optional<mlir::Attribute> init = op.getInitialValue();
2075 if (!init)
2076 return success();
2077 if (!isa<cir::VTableAttr>(*init))
2078 return emitOpError("Expected #cir.vtable in initializer for global '")
2079 << name << "'";
2080 return success();
2081}
2082
2083//===----------------------------------------------------------------------===//
2084// VTTAddrPointOp
2085//===----------------------------------------------------------------------===//
2086
2087LogicalResult
2088cir::VTTAddrPointOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
2089 // VTT ptr is not coming from a symbol.
2090 if (!getName())
2091 return success();
2092 StringRef name = *getName();
2093
2094 // Verify that the result type underlying pointer type matches the type of
2095 // the referenced cir.global op.
2096 auto op =
2097 symbolTable.lookupNearestSymbolFrom<cir::GlobalOp>(*this, getNameAttr());
2098 if (!op)
2099 return emitOpError("'")
2100 << name << "' does not reference a valid cir.global";
2101 std::optional<mlir::Attribute> init = op.getInitialValue();
2102 if (!init)
2103 return success();
2104 if (!isa<cir::ConstArrayAttr>(*init))
2105 return emitOpError(
2106 "Expected constant array in initializer for global VTT '")
2107 << name << "'";
2108 return success();
2109}
2110
2111LogicalResult cir::VTTAddrPointOp::verify() {
2112 // The operation uses either a symbol or a value to operate, but not both
2113 if (getName() && getSymAddr())
2114 return emitOpError("should use either a symbol or value, but not both");
2115
2116 // If not a symbol, stick with the concrete type used for getSymAddr.
2117 if (getSymAddr())
2118 return success();
2119
2120 mlir::Type resultType = getAddr().getType();
2121 mlir::Type resTy = cir::PointerType::get(
2122 cir::PointerType::get(cir::VoidType::get(getContext())));
2123
2124 if (resultType != resTy)
2125 return emitOpError("result type must be ")
2126 << resTy << ", but provided result type is " << resultType;
2127 return success();
2128}
2129
2130//===----------------------------------------------------------------------===//
2131// FuncOp
2132//===----------------------------------------------------------------------===//
2133
2134/// Returns the name used for the linkage attribute. This *must* correspond to
2135/// the name of the attribute in ODS.
2136static llvm::StringRef getLinkageAttrNameString() { return "linkage"; }
2137
/// Build a `cir.func` with an empty body region, the symbol name \p name, the
/// function type \p type, and a GlobalLinkageKindAttr built from \p linkage.
2138void cir::FuncOp::build(OpBuilder &builder, OperationState &result,
2139 StringRef name, FuncType type,
2140 GlobalLinkageKind linkage) {
// The body region starts empty (i.e. the function is a declaration).
2141 result.addRegion();
2142 result.addAttribute(SymbolTable::getSymbolAttrName(),
2143 builder.getStringAttr(name));
2144 result.addAttribute(getFunctionTypeAttrName(result.name),
2145 TypeAttr::get(type));
2146 result.addAttribute(
2148 GlobalLinkageKindAttr::get(builder.getContext(), linkage));
2149}
2150
/// Custom parser for `cir.func`. Elements are consumed in this order:
///   optional `builtin` / `coroutine` keywords, an optional inline-kind
///   attribute, optional `lambda` / `no_proto` / `comdat` keywords, a linkage
///   keyword (external when absent), an optional symbol visibility
///   (`private` / `public` / `nested`), a global visibility (default when
///   absent), optional `dso_local`, the symbol name, the signature, then
///   optional `alias(@sym)`, `personality(@sym)`, `special_member<...>`,
///   `global_ctor[(prio)]`, `global_dtor[(prio)]`, `side_effect(...)`, an
///   `attributes {...}` dictionary, and finally an optional body region.
/// At most one result type is accepted.
2151ParseResult cir::FuncOp::parse(OpAsmParser &parser, OperationState &state) {
2152 llvm::SMLoc loc = parser.getCurrentLocation();
2153 mlir::Builder &builder = parser.getBuilder();
2154
2155 mlir::StringAttr builtinNameAttr = getBuiltinAttrName(state.name);
2156 mlir::StringAttr coroutineNameAttr = getCoroutineAttrName(state.name);
2157 mlir::StringAttr inlineKindNameAttr = getInlineKindAttrName(state.name);
2158 mlir::StringAttr lambdaNameAttr = getLambdaAttrName(state.name);
2159 mlir::StringAttr noProtoNameAttr = getNoProtoAttrName(state.name);
2160 mlir::StringAttr comdatNameAttr = getComdatAttrName(state.name);
2161 mlir::StringAttr visNameAttr = getSymVisibilityAttrName(state.name);
2162 mlir::StringAttr dsoLocalNameAttr = getDsoLocalAttrName(state.name);
2163 mlir::StringAttr specialMemberAttr = getCxxSpecialMemberAttrName(state.name);
2164
2165 if (::mlir::succeeded(parser.parseOptionalKeyword(builtinNameAttr.strref())))
2166 state.addAttribute(builtinNameAttr, parser.getBuilder().getUnitAttr());
2167 if (::mlir::succeeded(
2168 parser.parseOptionalKeyword(coroutineNameAttr.strref())))
2169 state.addAttribute(coroutineNameAttr, parser.getBuilder().getUnitAttr());
2170
2171 // Parse optional inline kind attribute
2172 cir::InlineKindAttr inlineKindAttr;
2173 if (failed(parseInlineKindAttr(parser, inlineKindAttr)))
2174 return failure();
2175 if (inlineKindAttr)
2176 state.addAttribute(inlineKindNameAttr, inlineKindAttr);
2177
2178 if (::mlir::succeeded(parser.parseOptionalKeyword(lambdaNameAttr.strref())))
2179 state.addAttribute(lambdaNameAttr, parser.getBuilder().getUnitAttr());
2180 if (parser.parseOptionalKeyword(noProtoNameAttr).succeeded())
2181 state.addAttribute(noProtoNameAttr, parser.getBuilder().getUnitAttr());
2182
2183 if (parser.parseOptionalKeyword(comdatNameAttr).succeeded())
2184 state.addAttribute(comdatNameAttr, parser.getBuilder().getUnitAttr());
2185
2186 // Default to external linkage if no keyword is provided.
2187 state.addAttribute(getLinkageAttrNameString(),
2188 GlobalLinkageKindAttr::get(
2189 parser.getContext(),
2191 parser, GlobalLinkageKind::ExternalLinkage)));
2192
// Optional MLIR symbol visibility keyword, stored as a string attribute.
2193 ::llvm::StringRef visAttrStr;
2194 if (parser.parseOptionalKeyword(&visAttrStr, {"private", "public", "nested"})
2195 .succeeded()) {
2196 state.addAttribute(visNameAttr,
2197 parser.getBuilder().getStringAttr(visAttrStr));
2198 }
2199
// Global (linker) visibility is a property, defaulting to `default`.
2200 state.getOrAddProperties<cir::FuncOp::Properties>().global_visibility =
2201 parseOptionalCIRKeyword(parser, cir::VisibilityKind::Default);
2202
2203 if (parser.parseOptionalKeyword(dsoLocalNameAttr).succeeded())
2204 state.addAttribute(dsoLocalNameAttr, parser.getBuilder().getUnitAttr());
2205
2206 StringAttr nameAttr;
2207 if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
2208 state.attributes))
2209 return failure();
// Parse the argument list and result types; a trailing `...` is allowed.
2213 bool isVariadic = false;
2214 if (function_interface_impl::parseFunctionSignatureWithArguments(
2215 parser, /*allowVariadic=*/true, arguments, isVariadic, resultTypes,
2216 resultAttrs))
2217 return failure();
2220 bool argAttrsEmpty = true;
2221 for (OpAsmParser::Argument &arg : arguments) {
2222 argTypes.push_back(arg.type);
2223 // Add the 'empty' attribute anyway to make sure the arity matches, but we
2224 // only want to 'set' the attribute at the top level if there is SOME data
2225 // along the way.
2226 argAttrs.push_back(arg.attrs);
2227 if (arg.attrs)
2228 argAttrsEmpty = false;
2229 }
2230
2231 // These should be in sync anyway, but test both of them anyway.
2232 if (resultTypes.size() > 1 || resultAttrs.size() > 1)
2233 return parser.emitError(
2234 loc, "functions with multiple return types are not supported");
2235
// No declared result means the function returns !cir.void.
2236 mlir::Type returnType =
2237 (resultTypes.empty() ? cir::VoidType::get(builder.getContext())
2238 : resultTypes.front());
2239
2240 cir::FuncType fnType = cir::FuncType::get(argTypes, returnType, isVariadic);
2241 if (!fnType)
2242 return failure();
2243
2244 state.addAttribute(getFunctionTypeAttrName(state.name),
2245 TypeAttr::get(fnType));
2246
// At most one result exists, so at most one result-attribute dictionary.
2247 if (!resultAttrs.empty() && resultAttrs[0])
2248 state.addAttribute(
2249 getResAttrsAttrName(state.name),
2250 mlir::ArrayAttr::get(parser.getContext(), {resultAttrs[0]}));
2251
2252 if (!argAttrsEmpty)
2253 state.addAttribute(getArgAttrsAttrName(state.name),
2254 mlir::ArrayAttr::get(parser.getContext(), argAttrs));
2255
// Optional `alias(@aliasee)`; an alias must not have a body (checked below).
2256 bool hasAlias = false;
2257 mlir::StringAttr aliaseeNameAttr = getAliaseeAttrName(state.name);
2258 if (parser.parseOptionalKeyword("alias").succeeded()) {
2259 if (parser.parseLParen().failed())
2260 return failure();
2261 mlir::StringAttr aliaseeAttr;
2262 if (parser.parseOptionalSymbolName(aliaseeAttr).failed())
2263 return failure();
2264 state.addAttribute(aliaseeNameAttr, FlatSymbolRefAttr::get(aliaseeAttr));
2265 if (parser.parseRParen().failed())
2266 return failure();
2267 hasAlias = true;
2268 }
2269
// Optional `personality(@fn)`.
2270 mlir::StringAttr personalityNameAttr = getPersonalityAttrName(state.name);
2271 if (parser.parseOptionalKeyword("personality").succeeded()) {
2272 if (parser.parseLParen().failed())
2273 return failure();
2274 mlir::StringAttr personalityAttr;
2275 if (parser.parseOptionalSymbolName(personalityAttr).failed())
2276 return failure();
2277 state.addAttribute(personalityNameAttr,
2278 FlatSymbolRefAttr::get(personalityAttr));
2279 if (parser.parseRParen().failed())
2280 return failure();
2281 }
2282
// Parses `keyword` optionally followed by `(int)`. When the keyword is
// present, createAttr receives the priority (std::nullopt if no parens).
2283 auto parseGlobalDtorCtor =
2284 [&](StringRef keyword,
2285 llvm::function_ref<void(std::optional<int> prio)> createAttr)
2286 -> mlir::LogicalResult {
2287 if (mlir::succeeded(parser.parseOptionalKeyword(keyword))) {
2288 std::optional<int> priority;
2289 if (mlir::succeeded(parser.parseOptionalLParen())) {
2290 auto parsedPriority = mlir::FieldParser<int>::parse(parser);
2291 if (mlir::failed(parsedPriority))
2292 return parser.emitError(parser.getCurrentLocation(),
2293 "failed to parse 'priority', of type 'int'");
2294 priority = parsedPriority.value_or(int());
2295 // Parse literal ')'
2296 if (parser.parseRParen())
2297 return failure();
2298 }
2299 createAttr(priority);
2300 }
2301 return success();
2302 };
2303
2304 // Parse CXXSpecialMember attribute
2305 if (parser.parseOptionalKeyword("special_member").succeeded()) {
2306 if (parser.parseLess().failed())
2307 return failure();
2308
2309 mlir::Attribute attr;
2310 if (parser.parseAttribute(attr).failed())
2311 return failure();
2312 if (!mlir::isa<cir::CXXCtorAttr, cir::CXXDtorAttr, cir::CXXAssignAttr>(
2313 attr))
2314 return parser.emitError(parser.getCurrentLocation(),
2315 "expected a C++ special member attribute");
2316 state.addAttribute(specialMemberAttr, attr);
2317
2318 if (parser.parseGreater().failed())
2319 return failure();
2320 }
2321
// A missing priority defaults to 65535 for both global_ctor and global_dtor.
2322 if (parseGlobalDtorCtor("global_ctor", [&](std::optional<int> priority) {
2323 mlir::IntegerAttr globalCtorPriorityAttr =
2324 builder.getI32IntegerAttr(priority.value_or(65535));
2325 state.addAttribute(getGlobalCtorPriorityAttrName(state.name),
2326 globalCtorPriorityAttr);
2327 }).failed())
2328 return failure();
2329
2330 if (parseGlobalDtorCtor("global_dtor", [&](std::optional<int> priority) {
2331 mlir::IntegerAttr globalDtorPriorityAttr =
2332 builder.getI32IntegerAttr(priority.value_or(65535));
2333 state.addAttribute(getGlobalDtorPriorityAttrName(state.name),
2334 globalDtorPriorityAttr);
2335 }).failed())
2336 return failure();
2337
// Optional `side_effect(<kind>)`.
2338 if (parser.parseOptionalKeyword("side_effect").succeeded()) {
2339 cir::SideEffect sideEffect;
2340
2341 if (parser.parseLParen().failed() ||
2342 parseCIRKeyword<cir::SideEffect>(parser, sideEffect).failed() ||
2343 parser.parseRParen().failed())
2344 return failure();
2345
2346 auto attr = cir::SideEffectAttr::get(parser.getContext(), sideEffect);
2347 state.addAttribute(CIRDialect::getSideEffectAttrName(), attr);
2348 }
2349
2350 // Parse the rest of the attributes.
2351 NamedAttrList parsedAttrs;
2352 if (parser.parseOptionalAttrDictWithKeyword(parsedAttrs))
2353 return failure();
2354
// Attributes with dedicated syntax may not also appear in the dictionary.
2355 for (StringRef disallowed : cir::FuncOp::getAttributeNames()) {
2356 if (parsedAttrs.get(disallowed))
2357 return parser.emitError(loc, "attribute '")
2358 << disallowed
2359 << "' should not be specified in the explicit attribute list";
2360 }
2361
2362 state.attributes.append(parsedAttrs);
2363
2364 // Parse the optional function body.
2365 auto *body = state.addRegion();
2366 OptionalParseResult parseResult = parser.parseOptionalRegion(
2367 *body, arguments, /*enableNameShadowing=*/false);
2368 if (parseResult.has_value()) {
2369 if (hasAlias)
2370 return parser.emitError(loc, "function alias shall not have a body");
2371 if (failed(*parseResult))
2372 return failure();
2373 // Function body was parsed, make sure it's not empty.
2374 if (body->empty())
2375 return parser.emitError(loc, "expected non-empty function body");
2376 }
2377
2378 return success();
2379}
2380
2381// This function corresponds to `llvm::GlobalValue::isDeclaration` and should
2382// have a similar implementation. We don't currently support ifuncs or materializable
2383// functions, but those should be handled here as they are implemented.
2384bool cir::FuncOp::isDeclaration() {
2386
2387 std::optional<StringRef> aliasee = getAliasee();
// A non-alias function is a declaration exactly when its body is empty.
2388 if (!aliasee)
2389 return getFunctionBody().empty();
2390
2391 // Aliases are always definitions.
2392 return false;
2393}
2394
2395bool cir::FuncOp::isCXXSpecialMemberFunction() {
2396 return getCxxSpecialMemberAttr() != nullptr;
2397}
2398
2399bool cir::FuncOp::isCxxConstructor() {
2400 auto attr = getCxxSpecialMemberAttr();
2401 return attr && dyn_cast<CXXCtorAttr>(attr);
2402}
2403
2404bool cir::FuncOp::isCxxDestructor() {
2405 auto attr = getCxxSpecialMemberAttr();
2406 return attr && dyn_cast<CXXDtorAttr>(attr);
2407}
2408
2409bool cir::FuncOp::isCxxSpecialAssignment() {
2410 auto attr = getCxxSpecialMemberAttr();
2411 return attr && dyn_cast<CXXAssignAttr>(attr);
2412}
2413
2414std::optional<CtorKind> cir::FuncOp::getCxxConstructorKind() {
2415 mlir::Attribute attr = getCxxSpecialMemberAttr();
2416 if (attr) {
2417 if (auto ctor = dyn_cast<CXXCtorAttr>(attr))
2418 return ctor.getCtorKind();
2419 }
2420 return std::nullopt;
2421}
2422
2423std::optional<AssignKind> cir::FuncOp::getCxxSpecialAssignKind() {
2424 mlir::Attribute attr = getCxxSpecialMemberAttr();
2425 if (attr) {
2426 if (auto assign = dyn_cast<CXXAssignAttr>(attr))
2427 return assign.getAssignKind();
2428 }
2429 return std::nullopt;
2430}
2431
2432bool cir::FuncOp::isCxxTrivialMemberFunction() {
2433 mlir::Attribute attr = getCxxSpecialMemberAttr();
2434 if (attr) {
2435 if (auto ctor = dyn_cast<CXXCtorAttr>(attr))
2436 return ctor.getIsTrivial();
2437 if (auto dtor = dyn_cast<CXXDtorAttr>(attr))
2438 return dtor.getIsTrivial();
2439 if (auto assign = dyn_cast<CXXAssignAttr>(attr))
2440 return assign.getIsTrivial();
2441 }
2442 return false;
2443}
2444
2445mlir::Region *cir::FuncOp::getCallableRegion() {
2446 // TODO(CIR): This function will have special handling for aliases and a
2447 // check for an external function, once those features have been upstreamed.
2448 return &getBody();
2449}
2450
2451void cir::FuncOp::print(OpAsmPrinter &p) {
2452 if (getBuiltin())
2453 p << " builtin";
2454
2455 if (getCoroutine())
2456 p << " coroutine";
2457
2458 printInlineKindAttr(p, getInlineKindAttr());
2459
2460 if (getLambda())
2461 p << " lambda";
2462
2463 if (getNoProto())
2464 p << " no_proto";
2465
2466 if (getComdat())
2467 p << " comdat";
2468
2469 if (getLinkage() != GlobalLinkageKind::ExternalLinkage)
2470 p << ' ' << stringifyGlobalLinkageKind(getLinkage());
2471
2472 mlir::SymbolTable::Visibility vis = getVisibility();
2473 if (vis != mlir::SymbolTable::Visibility::Public)
2474 p << ' ' << vis;
2475
2476 if (getGlobalVisibility() != cir::VisibilityKind::Default)
2477 p << ' ' << stringifyVisibilityKind(getGlobalVisibility());
2478
2479 if (getDsoLocal())
2480 p << " dso_local";
2481
2482 p << ' ';
2483 p.printSymbolName(getSymName());
2484 cir::FuncType fnType = getFunctionType();
2485 function_interface_impl::printFunctionSignature(
2486 p, *this, fnType.getInputs(), fnType.isVarArg(), fnType.getReturnTypes());
2487
2488 if (std::optional<StringRef> aliaseeName = getAliasee()) {
2489 p << " alias(";
2490 p.printSymbolName(*aliaseeName);
2491 p << ")";
2492 }
2493
2494 if (std::optional<StringRef> personalityName = getPersonality()) {
2495 p << " personality(";
2496 p.printSymbolName(*personalityName);
2497 p << ")";
2498 }
2499
2500 if (auto specialMemberAttr = getCxxSpecialMember()) {
2501 p << " special_member<";
2502 p.printAttribute(*specialMemberAttr);
2503 p << '>';
2504 }
2505
2506 if (auto globalCtorPriority = getGlobalCtorPriority()) {
2507 p << " global_ctor";
2508 if (globalCtorPriority.value() != 65535)
2509 p << "(" << globalCtorPriority.value() << ")";
2510 }
2511
2512 if (auto globalDtorPriority = getGlobalDtorPriority()) {
2513 p << " global_dtor";
2514 if (globalDtorPriority.value() != 65535)
2515 p << "(" << globalDtorPriority.value() << ")";
2516 }
2517
2518 if (std::optional<cir::SideEffect> sideEffect = getSideEffect();
2519 sideEffect && *sideEffect != cir::SideEffect::All) {
2520 p << " side_effect(";
2521 p << stringifySideEffect(*sideEffect);
2522 p << ")";
2523 }
2524
2525 function_interface_impl::printFunctionAttributes(
2526 p, *this, cir::FuncOp::getAttributeNames());
2527
2528 // Print the body if this is not an external function.
2529 Region &body = getOperation()->getRegion(0);
2530 if (!body.empty()) {
2531 p << ' ';
2532 p.printRegion(body, /*printEntryBlockArgs=*/false,
2533 /*printBlockTerminators=*/true);
2534 }
2535}
2536
2537mlir::LogicalResult cir::FuncOp::verify() {
2538
2539 if (!isDeclaration() && getCoroutine()) {
2540 bool foundAwait = false;
2541 this->walk([&](Operation *op) {
2542 if (auto await = dyn_cast<AwaitOp>(op)) {
2543 foundAwait = true;
2544 return;
2545 }
2546 });
2547 if (!foundAwait)
2548 return emitOpError()
2549 << "coroutine body must use at least one cir.await op";
2550 }
2551
2552 llvm::SmallSet<llvm::StringRef, 16> labels;
2553 llvm::SmallSet<llvm::StringRef, 16> gotos;
2554 llvm::SmallSet<llvm::StringRef, 16> blockAddresses;
2555 bool invalidBlockAddress = false;
2556 getOperation()->walk([&](mlir::Operation *op) {
2557 if (auto lab = dyn_cast<cir::LabelOp>(op)) {
2558 labels.insert(lab.getLabel());
2559 } else if (auto goTo = dyn_cast<cir::GotoOp>(op)) {
2560 gotos.insert(goTo.getLabel());
2561 } else if (auto blkAdd = dyn_cast<cir::BlockAddressOp>(op)) {
2562 if (blkAdd.getBlockAddrInfoAttr().getFunc().getAttr() != getSymName()) {
2563 // Stop the walk early, no need to continue
2564 invalidBlockAddress = true;
2565 return mlir::WalkResult::interrupt();
2566 }
2567 blockAddresses.insert(blkAdd.getBlockAddrInfoAttr().getLabel());
2568 }
2569 return mlir::WalkResult::advance();
2570 });
2571
2572 if (invalidBlockAddress)
2573 return emitOpError() << "blockaddress references a different function";
2574
2575 llvm::SmallSet<llvm::StringRef, 16> mismatched;
2576 if (!labels.empty() || !gotos.empty()) {
2577 mismatched = llvm::set_difference(gotos, labels);
2578
2579 if (!mismatched.empty())
2580 return emitOpError() << "goto/label mismatch";
2581 }
2582
2583 mismatched.clear();
2584
2585 if (!labels.empty() || !blockAddresses.empty()) {
2586 mismatched = llvm::set_difference(blockAddresses, labels);
2587
2588 if (!mismatched.empty())
2589 return emitOpError()
2590 << "expects an existing label target in the referenced function";
2591 }
2592
2593 return success();
2594}
2595
2596//===----------------------------------------------------------------------===//
2597// AddOp / SubOp / MulOp
2598//===----------------------------------------------------------------------===//
2599
2600static LogicalResult verifyBinaryOverflowOp(mlir::Operation *op,
2601 bool noSignedWrap,
2602 bool noUnsignedWrap, bool saturated,
2603 bool hasSat) {
2604 bool noWrap = noSignedWrap || noUnsignedWrap;
2605 if (!isa<cir::IntType>(op->getResultTypes()[0]) && noWrap)
2606 return op->emitError()
2607 << "only operations on integer values may have nsw/nuw flags";
2608 if (hasSat && saturated && !isa<cir::IntType>(op->getResultTypes()[0]))
2609 return op->emitError()
2610 << "only operations on integer values may have sat flag";
2611 if (hasSat && noWrap && saturated)
2612 return op->emitError()
2613 << "the nsw/nuw flags and the saturated flag are mutually exclusive";
2614 return mlir::success();
2615}
2616
2617LogicalResult cir::AddOp::verify() {
2618 return verifyBinaryOverflowOp(getOperation(), getNoSignedWrap(),
2619 getNoUnsignedWrap(), getSaturated(),
2620 /*hasSat=*/true);
2621}
2622
2623LogicalResult cir::SubOp::verify() {
2624 return verifyBinaryOverflowOp(getOperation(), getNoSignedWrap(),
2625 getNoUnsignedWrap(), getSaturated(),
2626 /*hasSat=*/true);
2627}
2628
2629LogicalResult cir::MulOp::verify() {
2630 return verifyBinaryOverflowOp(getOperation(), getNoSignedWrap(),
2631 getNoUnsignedWrap(), /*saturated=*/false,
2632 /*hasSat=*/false);
2633}
2634
2635//===----------------------------------------------------------------------===//
2636// TernaryOp
2637//===----------------------------------------------------------------------===//
2638
2639/// Given the region at `point`, or the parent operation if `point` is None,
2640/// return the successor regions. These are the regions that may be selected
2641/// during the flow of control. `operands` is a set of optional attributes that
2642/// correspond to a constant value for each operand, or null if that operand is
2643/// not a constant.
2644void cir::TernaryOp::getSuccessorRegions(
2645 mlir::RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
2646 // The `true` and the `false` region branch back to the parent operation.
2647 if (!point.isParent()) {
2648 regions.push_back(RegionSuccessor::parent());
2649 return;
2650 }
2651
2652 // When branching from the parent operation, both the true and false
2653 // regions are considered possible successors
2654 regions.push_back(RegionSuccessor(&getTrueRegion()));
2655 regions.push_back(RegionSuccessor(&getFalseRegion()));
2656}
2657
2658mlir::ValueRange cir::TernaryOp::getSuccessorInputs(RegionSuccessor successor) {
2659 return successor.isParent() ? ValueRange(getOperation()->getResults())
2660 : ValueRange();
2661}
2662
2663void cir::TernaryOp::build(
2664 OpBuilder &builder, OperationState &result, Value cond,
2665 function_ref<void(OpBuilder &, Location)> trueBuilder,
2666 function_ref<void(OpBuilder &, Location)> falseBuilder) {
2667 result.addOperands(cond);
2668 OpBuilder::InsertionGuard guard(builder);
2669 Region *trueRegion = result.addRegion();
2670 builder.createBlock(trueRegion);
2671 trueBuilder(builder, result.location);
2672 Region *falseRegion = result.addRegion();
2673 builder.createBlock(falseRegion);
2674 falseBuilder(builder, result.location);
2675
2676 // Get result type from whichever branch has a yield (the other may have
2677 // unreachable from a throw expression)
2678 cir::YieldOp yield;
2679 if (trueRegion->back().mightHaveTerminator())
2680 yield = dyn_cast_or_null<cir::YieldOp>(trueRegion->back().getTerminator());
2681 if (!yield && falseRegion->back().mightHaveTerminator())
2682 yield = dyn_cast_or_null<cir::YieldOp>(falseRegion->back().getTerminator());
2683
2684 assert((!yield || yield.getNumOperands() <= 1) &&
2685 "expected zero or one result type");
2686 if (yield && yield.getNumOperands() == 1)
2687 result.addTypes(TypeRange{yield.getOperandTypes().front()});
2688}
2689
2690//===----------------------------------------------------------------------===//
2691// SelectOp
2692//===----------------------------------------------------------------------===//
2693
2694OpFoldResult cir::SelectOp::fold(FoldAdaptor adaptor) {
2695 mlir::Attribute condition = adaptor.getCondition();
2696 if (condition) {
2697 bool conditionValue = mlir::cast<cir::BoolAttr>(condition).getValue();
2698 return conditionValue ? getTrueValue() : getFalseValue();
2699 }
2700
2701 // cir.select if %0 then x else x -> x
2702 mlir::Attribute trueValue = adaptor.getTrueValue();
2703 mlir::Attribute falseValue = adaptor.getFalseValue();
2704 if (trueValue == falseValue)
2705 return trueValue;
2706 if (getTrueValue() == getFalseValue())
2707 return getTrueValue();
2708
2709 return {};
2710}
2711
2712LogicalResult cir::SelectOp::verify() {
2713 // AllTypesMatch already guarantees trueVal and falseVal have matching types.
2714 auto condTy = dyn_cast<cir::VectorType>(getCondition().getType());
2715
2716 // If condition is not a vector, no further checks are needed.
2717 if (!condTy)
2718 return success();
2719
2720 // When condition is a vector, both other operands must also be vectors.
2721 if (!isa<cir::VectorType>(getTrueValue().getType()) ||
2722 !isa<cir::VectorType>(getFalseValue().getType())) {
2723 return emitOpError()
2724 << "expected both true and false operands to be vector types "
2725 "when the condition is a vector boolean type";
2726 }
2727
2728 return success();
2729}
2730
2731//===----------------------------------------------------------------------===//
2732// ShiftOp
2733//===----------------------------------------------------------------------===//
2734LogicalResult cir::ShiftOp::verify() {
2735 mlir::Operation *op = getOperation();
2736 auto op0VecTy = mlir::dyn_cast<cir::VectorType>(op->getOperand(0).getType());
2737 auto op1VecTy = mlir::dyn_cast<cir::VectorType>(op->getOperand(1).getType());
2738 if (!op0VecTy ^ !op1VecTy)
2739 return emitOpError() << "input types cannot be one vector and one scalar";
2740
2741 if (op0VecTy) {
2742 if (op0VecTy.getSize() != op1VecTy.getSize())
2743 return emitOpError() << "input vector types must have the same size";
2744
2745 auto opResultTy = mlir::dyn_cast<cir::VectorType>(getType());
2746 if (!opResultTy)
2747 return emitOpError() << "the type of the result must be a vector "
2748 << "if it is vector shift";
2749
2750 auto op0VecEleTy = mlir::cast<cir::IntType>(op0VecTy.getElementType());
2751 auto op1VecEleTy = mlir::cast<cir::IntType>(op1VecTy.getElementType());
2752 if (op0VecEleTy.getWidth() != op1VecEleTy.getWidth())
2753 return emitOpError()
2754 << "vector operands do not have the same elements sizes";
2755
2756 auto resVecEleTy = mlir::cast<cir::IntType>(opResultTy.getElementType());
2757 if (op0VecEleTy.getWidth() != resVecEleTy.getWidth())
2758 return emitOpError() << "vector operands and result type do not have the "
2759 "same elements sizes";
2760 }
2761
2762 return mlir::success();
2763}
2764
2765//===----------------------------------------------------------------------===//
2766// LabelOp Definitions
2767//===----------------------------------------------------------------------===//
2768
2769LogicalResult cir::LabelOp::verify() {
2770 mlir::Operation *op = getOperation();
2771 mlir::Block *blk = op->getBlock();
2772 if (&blk->front() != op)
2773 return emitError() << "must be the first operation in a block";
2774
2775 return mlir::success();
2776}
2777
2778//===----------------------------------------------------------------------===//
2779// IncOp
2780//===----------------------------------------------------------------------===//
2781
2782OpFoldResult cir::IncOp::fold(FoldAdaptor adaptor) {
2783 if (mlir::isa_and_present<cir::PoisonAttr>(adaptor.getInput()))
2784 return adaptor.getInput();
2785 return {};
2786}
2787
2788//===----------------------------------------------------------------------===//
2789// DecOp
2790//===----------------------------------------------------------------------===//
2791
2792OpFoldResult cir::DecOp::fold(FoldAdaptor adaptor) {
2793 if (mlir::isa_and_present<cir::PoisonAttr>(adaptor.getInput()))
2794 return adaptor.getInput();
2795 return {};
2796}
2797
2798//===----------------------------------------------------------------------===//
2799// MinusOp
2800//===----------------------------------------------------------------------===//
2801
2802OpFoldResult cir::MinusOp::fold(FoldAdaptor adaptor) {
2803 if (mlir::isa_and_present<cir::PoisonAttr>(adaptor.getInput()))
2804 return adaptor.getInput();
2805
2806 // Avoid materializing a duplicate constant for bool minus (identity).
2807 if (auto srcConst = getInput().getDefiningOp<cir::ConstantOp>())
2808 if (mlir::isa<cir::BoolType>(srcConst.getType()))
2809 return srcConst.getResult();
2810
2811 // Fold with constant inputs.
2812 if (mlir::Attribute attr = adaptor.getInput()) {
2813 if (auto intAttr = mlir::dyn_cast<cir::IntAttr>(attr)) {
2814 APInt val = intAttr.getValue();
2815 val.negate();
2816 return cir::IntAttr::get(getType(), val);
2817 }
2818 if (auto fpAttr = mlir::dyn_cast<cir::FPAttr>(attr)) {
2819 APFloat val = fpAttr.getValue();
2820 val.changeSign();
2821 return cir::FPAttr::get(getType(), val);
2822 }
2823 }
2824
2825 return {};
2826}
2827
2828//===----------------------------------------------------------------------===//
2829// NotOp
2830//===----------------------------------------------------------------------===//
2831
2832OpFoldResult cir::NotOp::fold(FoldAdaptor adaptor) {
2833 if (mlir::isa_and_present<cir::PoisonAttr>(adaptor.getInput()))
2834 return adaptor.getInput();
2835
2836 // not(not(x)) -> x is handled by the Involution trait.
2837
2838 // Fold with constant inputs.
2839 if (mlir::Attribute attr = adaptor.getInput()) {
2840 if (auto intAttr = mlir::dyn_cast<cir::IntAttr>(attr)) {
2841 APInt val = intAttr.getValue();
2842 val.flipAllBits();
2843 return cir::IntAttr::get(getType(), val);
2844 }
2845 if (auto boolAttr = mlir::dyn_cast<cir::BoolAttr>(attr))
2846 return cir::BoolAttr::get(getContext(), !boolAttr.getValue());
2847 }
2848
2849 return {};
2850}
2851
2852//===----------------------------------------------------------------------===//
2853// BaseDataMemberOp & DerivedDataMemberOp
2854//===----------------------------------------------------------------------===//
2855
2856static LogicalResult verifyMemberPtrCast(Operation *op, mlir::Value src,
2857 mlir::Type resultTy) {
2858 // Let the operand type be T1 C1::*, let the result type be T2 C2::*.
2859 // Verify that T1 and T2 are the same type.
2860 mlir::Type inputMemberTy;
2861 mlir::Type resultMemberTy;
2862 if (mlir::isa<cir::DataMemberType>(src.getType())) {
2863 inputMemberTy =
2864 mlir::cast<cir::DataMemberType>(src.getType()).getMemberTy();
2865 resultMemberTy = mlir::cast<cir::DataMemberType>(resultTy).getMemberTy();
2866 }
2868 if (inputMemberTy != resultMemberTy)
2869 return op->emitOpError()
2870 << "member types of the operand and the result do not match";
2871
2872 return mlir::success();
2873}
2874
2875LogicalResult cir::BaseDataMemberOp::verify() {
2876 return verifyMemberPtrCast(getOperation(), getSrc(), getType());
2877}
2878
2879LogicalResult cir::DerivedDataMemberOp::verify() {
2880 return verifyMemberPtrCast(getOperation(), getSrc(), getType());
2881}
2882
2883//===----------------------------------------------------------------------===//
2884// BaseMethodOp & DerivedMethodOp
2885//===----------------------------------------------------------------------===//
2886
2887LogicalResult cir::BaseMethodOp::verify() {
2888 return verifyMemberPtrCast(getOperation(), getSrc(), getType());
2889}
2890
2891LogicalResult cir::DerivedMethodOp::verify() {
2892 return verifyMemberPtrCast(getOperation(), getSrc(), getType());
2893}
2894
2895//===----------------------------------------------------------------------===//
2896// AwaitOp
2897//===----------------------------------------------------------------------===//
2898
2899void cir::AwaitOp::build(OpBuilder &builder, OperationState &result,
2900 cir::AwaitKind kind, BuilderCallbackRef readyBuilder,
2901 BuilderCallbackRef suspendBuilder,
2902 BuilderCallbackRef resumeBuilder) {
2903 result.addAttribute(getKindAttrName(result.name),
2904 cir::AwaitKindAttr::get(builder.getContext(), kind));
2905 {
2906 OpBuilder::InsertionGuard guard(builder);
2907 Region *readyRegion = result.addRegion();
2908 builder.createBlock(readyRegion);
2909 readyBuilder(builder, result.location);
2910 }
2911
2912 {
2913 OpBuilder::InsertionGuard guard(builder);
2914 Region *suspendRegion = result.addRegion();
2915 builder.createBlock(suspendRegion);
2916 suspendBuilder(builder, result.location);
2917 }
2918
2919 {
2920 OpBuilder::InsertionGuard guard(builder);
2921 Region *resumeRegion = result.addRegion();
2922 builder.createBlock(resumeRegion);
2923 resumeBuilder(builder, result.location);
2924 }
2925}
2926
2927void cir::AwaitOp::getSuccessorRegions(
2928 mlir::RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
2929 // If any index all the underlying regions branch back to the parent
2930 // operation.
2931 if (!point.isParent()) {
2932 regions.push_back(RegionSuccessor::parent());
2933 return;
2934 }
2935
2936 // TODO: retrieve information from the promise and only push the
2937 // necessary ones. Example: `std::suspend_never` on initial or final
2938 // await's might allow suspend region to be skipped.
2939 regions.push_back(RegionSuccessor(&this->getReady()));
2940 regions.push_back(RegionSuccessor(&this->getSuspend()));
2941 regions.push_back(RegionSuccessor(&this->getResume()));
2942}
2943
2944mlir::ValueRange cir::AwaitOp::getSuccessorInputs(RegionSuccessor successor) {
2945 if (successor.isParent())
2946 return getOperation()->getResults();
2947 if (successor == &getReady())
2948 return getReady().getArguments();
2949 if (successor == &getSuspend())
2950 return getSuspend().getArguments();
2951 if (successor == &getResume())
2952 return getResume().getArguments();
2953 llvm_unreachable("invalid region successor");
2954}
2955
2956LogicalResult cir::AwaitOp::verify() {
2957 if (!isa<ConditionOp>(this->getReady().back().getTerminator()))
2958 return emitOpError("ready region must end with cir.condition");
2959 return success();
2960}
2961
2962//===----------------------------------------------------------------------===//
2963// CopyOp Definitions
2964//===----------------------------------------------------------------------===//
2965
2966LogicalResult cir::CopyOp::verify() {
2967 // A data layout is required for us to know the number of bytes to be copied.
2968 if (!getType().getPointee().hasTrait<DataLayoutTypeInterface::Trait>())
2969 return emitError() << "missing data layout for pointee type";
2970
2971 if (getSrc() == getDst())
2972 return emitError() << "source and destination are the same";
2973
2974 if (getSkipTailPadding() &&
2975 !mlir::isa<cir::RecordType>(getType().getPointee()))
2976 return emitError()
2977 << "skip_tail_padding is only valid for record pointee types";
2978
2979 return mlir::success();
2980}
2981
2982//===----------------------------------------------------------------------===//
2983// GetRuntimeMemberOp Definitions
2984//===----------------------------------------------------------------------===//
2985
2986LogicalResult cir::GetRuntimeMemberOp::verify() {
2987 auto recordTy = mlir::cast<RecordType>(getAddr().getType().getPointee());
2988 cir::DataMemberType memberPtrTy = getMember().getType();
2989
2990 if (recordTy != memberPtrTy.getClassTy())
2991 return emitError() << "record type does not match the member pointer type";
2992 if (getType().getPointee() != memberPtrTy.getMemberTy())
2993 return emitError() << "result type does not match the member pointer type";
2994 return mlir::success();
2995}
2996
2997//===----------------------------------------------------------------------===//
2998// GetMethodOp Definitions
2999//===----------------------------------------------------------------------===//
3000
3001LogicalResult cir::GetMethodOp::verify() {
3002 cir::MethodType methodTy = getMethod().getType();
3003
3004 // Assume objectTy is !cir.ptr<!T>
3005 cir::PointerType objectPtrTy = getObject().getType();
3006 mlir::Type objectTy = objectPtrTy.getPointee();
3007
3008 if (methodTy.getClassTy() != objectTy)
3009 return emitError() << "method class type and object type do not match";
3010
3011 // Assume methodFuncTy is !cir.func<!Ret (!Args)>
3012 auto calleeTy = mlir::cast<cir::FuncType>(getCallee().getType().getPointee());
3013 cir::FuncType methodFuncTy = methodTy.getMemberFuncTy();
3014
3015 // We verify at here that calleeTy is !cir.func<!Ret (!cir.ptr<!void>, !Args)>
3016 // Note that the first parameter type of the callee is !cir.ptr<!void> instead
3017 // of !cir.ptr<!T> because the "this" pointer may be adjusted before calling
3018 // the callee.
3019
3020 if (methodFuncTy.getReturnType() != calleeTy.getReturnType())
3021 return emitError()
3022 << "method return type and callee return type do not match";
3023
3024 llvm::ArrayRef<mlir::Type> calleeArgsTy = calleeTy.getInputs();
3025 llvm::ArrayRef<mlir::Type> methodFuncArgsTy = methodFuncTy.getInputs();
3026
3027 if (calleeArgsTy.empty())
3028 return emitError() << "callee parameter list lacks receiver object ptr";
3029
3030 auto calleeThisArgPtrTy = mlir::dyn_cast<cir::PointerType>(calleeArgsTy[0]);
3031 if (!calleeThisArgPtrTy ||
3032 !mlir::isa<cir::VoidType>(calleeThisArgPtrTy.getPointee())) {
3033 return emitError()
3034 << "the first parameter of callee must be a void pointer";
3035 }
3036
3037 if (calleeArgsTy.slice(1) != methodFuncArgsTy)
3038 return emitError()
3039 << "callee parameters and method parameters do not match";
3040
3041 return mlir::success();
3042}
3043
3044//===----------------------------------------------------------------------===//
3045// GetMemberOp Definitions
3046//===----------------------------------------------------------------------===//
3047
3048LogicalResult cir::GetMemberOp::verify() {
3049 const auto recordTy = dyn_cast<RecordType>(getAddrTy().getPointee());
3050 if (!recordTy)
3051 return emitError() << "expected pointer to a record type";
3052
3053 if (recordTy.getMembers().size() <= getIndex())
3054 return emitError() << "member index out of bounds";
3055
3056 if (recordTy.getMembers()[getIndex()] != getType().getPointee())
3057 return emitError() << "member type mismatch";
3058
3059 return mlir::success();
3060}
3061
3062//===----------------------------------------------------------------------===//
3063// ExtractMemberOp Definitions
3064//===----------------------------------------------------------------------===//
3065
3066LogicalResult cir::ExtractMemberOp::verify() {
3067 auto recordTy = mlir::cast<cir::RecordType>(getRecord().getType());
3068 if (recordTy.getKind() == cir::RecordType::Union)
3069 return emitError()
3070 << "cir.extract_member currently does not support unions";
3071 if (recordTy.getMembers().size() <= getIndex())
3072 return emitError() << "member index out of bounds";
3073 if (recordTy.getMembers()[getIndex()] != getType())
3074 return emitError() << "member type mismatch";
3075 return mlir::success();
3076}
3077
3078//===----------------------------------------------------------------------===//
3079// InsertMemberOp Definitions
3080//===----------------------------------------------------------------------===//
3081
3082LogicalResult cir::InsertMemberOp::verify() {
3083 auto recordTy = mlir::cast<cir::RecordType>(getRecord().getType());
3084 if (recordTy.getKind() == cir::RecordType::Union)
3085 return emitError() << "cir.insert_member currently does not support unions";
3086 if (recordTy.getMembers().size() <= getIndex())
3087 return emitError() << "member index out of bounds";
3088 if (recordTy.getMembers()[getIndex()] != getValue().getType())
3089 return emitError() << "member type mismatch";
3090 // The op trait already checks that the types of $result and $record match.
3091 return mlir::success();
3092}
3093
3094//===----------------------------------------------------------------------===//
3095// VecCreateOp
3096//===----------------------------------------------------------------------===//
3097
3098OpFoldResult cir::VecCreateOp::fold(FoldAdaptor adaptor) {
3099 if (llvm::any_of(getElements(), [](mlir::Value value) {
3100 return !value.getDefiningOp<cir::ConstantOp>();
3101 }))
3102 return {};
3103
3104 return cir::ConstVectorAttr::get(
3105 getType(), mlir::ArrayAttr::get(getContext(), adaptor.getElements()));
3106}
3107
3108LogicalResult cir::VecCreateOp::verify() {
3109 // Verify that the number of arguments matches the number of elements in the
3110 // vector, and that the type of all the arguments matches the type of the
3111 // elements in the vector.
3112 const cir::VectorType vecTy = getType();
3113 if (getElements().size() != vecTy.getSize()) {
3114 return emitOpError() << "operand count of " << getElements().size()
3115 << " doesn't match vector type " << vecTy
3116 << " element count of " << vecTy.getSize();
3117 }
3118
3119 const mlir::Type elementType = vecTy.getElementType();
3120 for (const mlir::Value element : getElements()) {
3121 if (element.getType() != elementType) {
3122 return emitOpError() << "operand type " << element.getType()
3123 << " doesn't match vector element type "
3124 << elementType;
3125 }
3126 }
3127
3128 return success();
3129}
3130
3131//===----------------------------------------------------------------------===//
3132// VecExtractOp
3133//===----------------------------------------------------------------------===//
3134
3135OpFoldResult cir::VecExtractOp::fold(FoldAdaptor adaptor) {
3136 const auto vectorAttr =
3137 llvm::dyn_cast_if_present<cir::ConstVectorAttr>(adaptor.getVec());
3138 if (!vectorAttr)
3139 return {};
3140
3141 const auto indexAttr =
3142 llvm::dyn_cast_if_present<cir::IntAttr>(adaptor.getIndex());
3143 if (!indexAttr)
3144 return {};
3145
3146 const mlir::ArrayAttr elements = vectorAttr.getElts();
3147 const uint64_t index = indexAttr.getUInt();
3148 if (index >= elements.size())
3149 return {};
3150
3151 return elements[index];
3152}
3153
3154//===----------------------------------------------------------------------===//
3155// VecCmpOp
3156//===----------------------------------------------------------------------===//
3157
3158OpFoldResult cir::VecCmpOp::fold(FoldAdaptor adaptor) {
3159 auto lhsVecAttr =
3160 mlir::dyn_cast_if_present<cir::ConstVectorAttr>(adaptor.getLhs());
3161 auto rhsVecAttr =
3162 mlir::dyn_cast_if_present<cir::ConstVectorAttr>(adaptor.getRhs());
3163 if (!lhsVecAttr || !rhsVecAttr)
3164 return {};
3165
3166 mlir::Type inputElemTy =
3167 mlir::cast<cir::VectorType>(lhsVecAttr.getType()).getElementType();
3168 if (!isAnyIntegerOrFloatingPointType(inputElemTy))
3169 return {};
3170
3171 cir::CmpOpKind opKind = adaptor.getKind();
3172 mlir::ArrayAttr lhsVecElhs = lhsVecAttr.getElts();
3173 mlir::ArrayAttr rhsVecElhs = rhsVecAttr.getElts();
3174 uint64_t vecSize = lhsVecElhs.size();
3175
3176 SmallVector<mlir::Attribute, 16> elements(vecSize);
3177 bool isIntAttr = vecSize && mlir::isa<cir::IntAttr>(lhsVecElhs[0]);
3178 for (uint64_t i = 0; i < vecSize; i++) {
3179 mlir::Attribute lhsAttr = lhsVecElhs[i];
3180 mlir::Attribute rhsAttr = rhsVecElhs[i];
3181 int cmpResult = 0;
3182 switch (opKind) {
3183 case cir::CmpOpKind::lt: {
3184 if (isIntAttr) {
3185 cmpResult = mlir::cast<cir::IntAttr>(lhsAttr).getSInt() <
3186 mlir::cast<cir::IntAttr>(rhsAttr).getSInt();
3187 } else {
3188 cmpResult = mlir::cast<cir::FPAttr>(lhsAttr).getValue() <
3189 mlir::cast<cir::FPAttr>(rhsAttr).getValue();
3190 }
3191 break;
3192 }
3193 case cir::CmpOpKind::le: {
3194 if (isIntAttr) {
3195 cmpResult = mlir::cast<cir::IntAttr>(lhsAttr).getSInt() <=
3196 mlir::cast<cir::IntAttr>(rhsAttr).getSInt();
3197 } else {
3198 cmpResult = mlir::cast<cir::FPAttr>(lhsAttr).getValue() <=
3199 mlir::cast<cir::FPAttr>(rhsAttr).getValue();
3200 }
3201 break;
3202 }
3203 case cir::CmpOpKind::gt: {
3204 if (isIntAttr) {
3205 cmpResult = mlir::cast<cir::IntAttr>(lhsAttr).getSInt() >
3206 mlir::cast<cir::IntAttr>(rhsAttr).getSInt();
3207 } else {
3208 cmpResult = mlir::cast<cir::FPAttr>(lhsAttr).getValue() >
3209 mlir::cast<cir::FPAttr>(rhsAttr).getValue();
3210 }
3211 break;
3212 }
3213 case cir::CmpOpKind::ge: {
3214 if (isIntAttr) {
3215 cmpResult = mlir::cast<cir::IntAttr>(lhsAttr).getSInt() >=
3216 mlir::cast<cir::IntAttr>(rhsAttr).getSInt();
3217 } else {
3218 cmpResult = mlir::cast<cir::FPAttr>(lhsAttr).getValue() >=
3219 mlir::cast<cir::FPAttr>(rhsAttr).getValue();
3220 }
3221 break;
3222 }
3223 case cir::CmpOpKind::eq: {
3224 if (isIntAttr) {
3225 cmpResult = mlir::cast<cir::IntAttr>(lhsAttr).getSInt() ==
3226 mlir::cast<cir::IntAttr>(rhsAttr).getSInt();
3227 } else {
3228 cmpResult = mlir::cast<cir::FPAttr>(lhsAttr).getValue() ==
3229 mlir::cast<cir::FPAttr>(rhsAttr).getValue();
3230 }
3231 break;
3232 }
3233 case cir::CmpOpKind::ne: {
3234 if (isIntAttr) {
3235 cmpResult = mlir::cast<cir::IntAttr>(lhsAttr).getSInt() !=
3236 mlir::cast<cir::IntAttr>(rhsAttr).getSInt();
3237 } else {
3238 cmpResult = mlir::cast<cir::FPAttr>(lhsAttr).getValue() !=
3239 mlir::cast<cir::FPAttr>(rhsAttr).getValue();
3240 }
3241 break;
3242 }
3243 }
3244
3245 elements[i] = cir::IntAttr::get(getType().getElementType(), cmpResult);
3246 }
3247
3248 return cir::ConstVectorAttr::get(
3249 getType(), mlir::ArrayAttr::get(getContext(), elements));
3250}
3251
3252//===----------------------------------------------------------------------===//
3253// VecShuffleOp
3254//===----------------------------------------------------------------------===//
3255
3256OpFoldResult cir::VecShuffleOp::fold(FoldAdaptor adaptor) {
3257 auto vec1Attr =
3258 mlir::dyn_cast_if_present<cir::ConstVectorAttr>(adaptor.getVec1());
3259 auto vec2Attr =
3260 mlir::dyn_cast_if_present<cir::ConstVectorAttr>(adaptor.getVec2());
3261 if (!vec1Attr || !vec2Attr)
3262 return {};
3263
3264 mlir::Type vec1ElemTy =
3265 mlir::cast<cir::VectorType>(vec1Attr.getType()).getElementType();
3266
3267 mlir::ArrayAttr vec1Elts = vec1Attr.getElts();
3268 mlir::ArrayAttr vec2Elts = vec2Attr.getElts();
3269 mlir::ArrayAttr indicesElts = adaptor.getIndices();
3270
3272 elements.reserve(indicesElts.size());
3273
3274 uint64_t vec1Size = vec1Elts.size();
3275 for (const auto &idxAttr : indicesElts.getAsRange<cir::IntAttr>()) {
3276 if (idxAttr.getSInt() == -1) {
3277 elements.push_back(cir::UndefAttr::get(vec1ElemTy));
3278 continue;
3279 }
3280
3281 uint64_t idxValue = idxAttr.getUInt();
3282 elements.push_back(idxValue < vec1Size ? vec1Elts[idxValue]
3283 : vec2Elts[idxValue - vec1Size]);
3284 }
3285
3286 return cir::ConstVectorAttr::get(
3287 getType(), mlir::ArrayAttr::get(getContext(), elements));
3288}
3289
3290LogicalResult cir::VecShuffleOp::verify() {
3291 // The number of elements in the indices array must match the number of
3292 // elements in the result type.
3293 if (getIndices().size() != getResult().getType().getSize()) {
3294 return emitOpError() << ": the number of elements in " << getIndices()
3295 << " and " << getResult().getType() << " don't match";
3296 }
3297
3298 // The element types of the two input vectors and of the result type must
3299 // match.
3300 if (getVec1().getType().getElementType() !=
3301 getResult().getType().getElementType()) {
3302 return emitOpError() << ": element types of " << getVec1().getType()
3303 << " and " << getResult().getType() << " don't match";
3304 }
3305
3306 const uint64_t maxValidIndex =
3307 getVec1().getType().getSize() + getVec2().getType().getSize() - 1;
3308 if (llvm::any_of(
3309 getIndices().getAsRange<cir::IntAttr>(), [&](cir::IntAttr idxAttr) {
3310 return idxAttr.getSInt() != -1 && idxAttr.getUInt() > maxValidIndex;
3311 })) {
3312 return emitOpError() << ": index for __builtin_shufflevector must be "
3313 "less than the total number of vector elements";
3314 }
3315 return success();
3316}
3317
3318//===----------------------------------------------------------------------===//
3319// VecShuffleDynamicOp
3320//===----------------------------------------------------------------------===//
3321
3322OpFoldResult cir::VecShuffleDynamicOp::fold(FoldAdaptor adaptor) {
3323 mlir::Attribute vec = adaptor.getVec();
3324 mlir::Attribute indices = adaptor.getIndices();
3325 if (mlir::isa_and_nonnull<cir::ConstVectorAttr>(vec) &&
3326 mlir::isa_and_nonnull<cir::ConstVectorAttr>(indices)) {
3327 auto vecAttr = mlir::cast<cir::ConstVectorAttr>(vec);
3328 auto indicesAttr = mlir::cast<cir::ConstVectorAttr>(indices);
3329
3330 mlir::ArrayAttr vecElts = vecAttr.getElts();
3331 mlir::ArrayAttr indicesElts = indicesAttr.getElts();
3332
3333 const uint64_t numElements = vecElts.size();
3334
3336 elements.reserve(numElements);
3337
3338 const uint64_t maskBits = llvm::NextPowerOf2(numElements - 1) - 1;
3339 for (const auto &idxAttr : indicesElts.getAsRange<cir::IntAttr>()) {
3340 uint64_t idxValue = idxAttr.getUInt();
3341 uint64_t newIdx = idxValue & maskBits;
3342 elements.push_back(vecElts[newIdx]);
3343 }
3344
3345 return cir::ConstVectorAttr::get(
3346 getType(), mlir::ArrayAttr::get(getContext(), elements));
3347 }
3348
3349 return {};
3350}
3351
3352LogicalResult cir::VecShuffleDynamicOp::verify() {
3353 // The number of elements in the two input vectors must match.
3354 if (getVec().getType().getSize() !=
3355 mlir::cast<cir::VectorType>(getIndices().getType()).getSize()) {
3356 return emitOpError() << ": the number of elements in " << getVec().getType()
3357 << " and " << getIndices().getType() << " don't match";
3358 }
3359 return success();
3360}
3361
3362//===----------------------------------------------------------------------===//
3363// VecTernaryOp
3364//===----------------------------------------------------------------------===//
3365
3366LogicalResult cir::VecTernaryOp::verify() {
3367 // Verify that the condition operand has the same number of elements as the
3368 // other operands. (The automatic verification already checked that all
3369 // operands are vector types and that the second and third operands are the
3370 // same type.)
3371 if (getCond().getType().getSize() != getLhs().getType().getSize()) {
3372 return emitOpError() << ": the number of elements in "
3373 << getCond().getType() << " and " << getLhs().getType()
3374 << " don't match";
3375 }
3376 return success();
3377}
3378
3379OpFoldResult cir::VecTernaryOp::fold(FoldAdaptor adaptor) {
3380 mlir::Attribute cond = adaptor.getCond();
3381 mlir::Attribute lhs = adaptor.getLhs();
3382 mlir::Attribute rhs = adaptor.getRhs();
3383
3384 if (!mlir::isa_and_nonnull<cir::ConstVectorAttr>(cond) ||
3385 !mlir::isa_and_nonnull<cir::ConstVectorAttr>(lhs) ||
3386 !mlir::isa_and_nonnull<cir::ConstVectorAttr>(rhs))
3387 return {};
3388 auto condVec = mlir::cast<cir::ConstVectorAttr>(cond);
3389 auto lhsVec = mlir::cast<cir::ConstVectorAttr>(lhs);
3390 auto rhsVec = mlir::cast<cir::ConstVectorAttr>(rhs);
3391
3392 mlir::ArrayAttr condElts = condVec.getElts();
3393
3395 elements.reserve(condElts.size());
3396
3397 for (const auto &[idx, condAttr] :
3398 llvm::enumerate(condElts.getAsRange<cir::IntAttr>())) {
3399 if (condAttr.getSInt()) {
3400 elements.push_back(lhsVec.getElts()[idx]);
3401 } else {
3402 elements.push_back(rhsVec.getElts()[idx]);
3403 }
3404 }
3405
3406 cir::VectorType vecTy = getLhs().getType();
3407 return cir::ConstVectorAttr::get(
3408 vecTy, mlir::ArrayAttr::get(getContext(), elements));
3409}
3410
3411//===----------------------------------------------------------------------===//
3412// ComplexCreateOp
3413//===----------------------------------------------------------------------===//
3414
3415LogicalResult cir::ComplexCreateOp::verify() {
3416 if (getType().getElementType() != getReal().getType()) {
3417 emitOpError()
3418 << "operand type of cir.complex.create does not match its result type";
3419 return failure();
3420 }
3421
3422 return success();
3423}
3424
3425OpFoldResult cir::ComplexCreateOp::fold(FoldAdaptor adaptor) {
3426 mlir::Attribute real = adaptor.getReal();
3427 mlir::Attribute imag = adaptor.getImag();
3428 if (!real || !imag)
3429 return {};
3430
3431 // When both of real and imag are constants, we can fold the operation into an
3432 // `#cir.const_complex` operation.
3433 auto realAttr = mlir::cast<mlir::TypedAttr>(real);
3434 auto imagAttr = mlir::cast<mlir::TypedAttr>(imag);
3435 return cir::ConstComplexAttr::get(realAttr, imagAttr);
3436}
3437
3438//===----------------------------------------------------------------------===//
3439// ComplexRealOp
3440//===----------------------------------------------------------------------===//
3441
3442LogicalResult cir::ComplexRealOp::verify() {
3443 mlir::Type operandTy = getOperand().getType();
3444 if (auto complexOperandTy = mlir::dyn_cast<cir::ComplexType>(operandTy))
3445 operandTy = complexOperandTy.getElementType();
3446
3447 if (getType() != operandTy) {
3448 emitOpError() << ": result type does not match operand type";
3449 return failure();
3450 }
3451
3452 return success();
3453}
3454
3455OpFoldResult cir::ComplexRealOp::fold(FoldAdaptor adaptor) {
3456 if (!mlir::isa<cir::ComplexType>(getOperand().getType()))
3457 return nullptr;
3458
3459 if (auto complexCreateOp = getOperand().getDefiningOp<cir::ComplexCreateOp>())
3460 return complexCreateOp.getOperand(0);
3461
3462 auto complex =
3463 mlir::cast_if_present<cir::ConstComplexAttr>(adaptor.getOperand());
3464 return complex ? complex.getReal() : nullptr;
3465}
3466
3467//===----------------------------------------------------------------------===//
3468// ComplexImagOp
3469//===----------------------------------------------------------------------===//
3470
3471LogicalResult cir::ComplexImagOp::verify() {
3472 mlir::Type operandTy = getOperand().getType();
3473 if (auto complexOperandTy = mlir::dyn_cast<cir::ComplexType>(operandTy))
3474 operandTy = complexOperandTy.getElementType();
3475
3476 if (getType() != operandTy) {
3477 emitOpError() << ": result type does not match operand type";
3478 return failure();
3479 }
3480
3481 return success();
3482}
3483
3484OpFoldResult cir::ComplexImagOp::fold(FoldAdaptor adaptor) {
3485 if (!mlir::isa<cir::ComplexType>(getOperand().getType()))
3486 return nullptr;
3487
3488 if (auto complexCreateOp = getOperand().getDefiningOp<cir::ComplexCreateOp>())
3489 return complexCreateOp.getOperand(1);
3490
3491 auto complex =
3492 mlir::cast_if_present<cir::ConstComplexAttr>(adaptor.getOperand());
3493 return complex ? complex.getImag() : nullptr;
3494}
3495
3496//===----------------------------------------------------------------------===//
3497// ComplexRealPtrOp
3498//===----------------------------------------------------------------------===//
3499
3500LogicalResult cir::ComplexRealPtrOp::verify() {
3501 mlir::Type resultPointeeTy = getType().getPointee();
3502 cir::PointerType operandPtrTy = getOperand().getType();
3503 auto operandPointeeTy =
3504 mlir::cast<cir::ComplexType>(operandPtrTy.getPointee());
3505
3506 if (resultPointeeTy != operandPointeeTy.getElementType()) {
3507 return emitOpError() << ": result type does not match operand type";
3508 }
3509
3510 return success();
3511}
3512
3513//===----------------------------------------------------------------------===//
3514// ComplexImagPtrOp
3515//===----------------------------------------------------------------------===//
3516
3517LogicalResult cir::ComplexImagPtrOp::verify() {
3518 mlir::Type resultPointeeTy = getType().getPointee();
3519 cir::PointerType operandPtrTy = getOperand().getType();
3520 auto operandPointeeTy =
3521 mlir::cast<cir::ComplexType>(operandPtrTy.getPointee());
3522
3523 if (resultPointeeTy != operandPointeeTy.getElementType()) {
3524 return emitOpError()
3525 << "cir.complex.imag_ptr result type does not match operand type";
3526 }
3527 return success();
3528}
3529
3530//===----------------------------------------------------------------------===//
3531// Bit manipulation operations
3532//===----------------------------------------------------------------------===//
3533
3534static OpFoldResult
3535foldUnaryBitOp(mlir::Attribute inputAttr,
3536 llvm::function_ref<llvm::APInt(const llvm::APInt &)> func,
3537 bool poisonZero = false) {
3538 if (mlir::isa_and_present<cir::PoisonAttr>(inputAttr)) {
3539 // Propagate poison value
3540 return inputAttr;
3541 }
3542
3543 auto input = mlir::dyn_cast_if_present<IntAttr>(inputAttr);
3544 if (!input)
3545 return nullptr;
3546
3547 llvm::APInt inputValue = input.getValue();
3548 if (poisonZero && inputValue.isZero())
3549 return cir::PoisonAttr::get(input.getType());
3550
3551 llvm::APInt resultValue = func(inputValue);
3552 return IntAttr::get(input.getType(), resultValue);
3553}
3554
3555OpFoldResult BitClrsbOp::fold(FoldAdaptor adaptor) {
3556 return foldUnaryBitOp(adaptor.getInput(), [](const llvm::APInt &inputValue) {
3557 unsigned resultValue =
3558 inputValue.getBitWidth() - inputValue.getSignificantBits();
3559 return llvm::APInt(inputValue.getBitWidth(), resultValue);
3560 });
3561}
3562
3563OpFoldResult BitClzOp::fold(FoldAdaptor adaptor) {
3564 return foldUnaryBitOp(
3565 adaptor.getInput(),
3566 [](const llvm::APInt &inputValue) {
3567 unsigned resultValue = inputValue.countLeadingZeros();
3568 return llvm::APInt(inputValue.getBitWidth(), resultValue);
3569 },
3570 getPoisonZero());
3571}
3572
3573OpFoldResult BitCtzOp::fold(FoldAdaptor adaptor) {
3574 return foldUnaryBitOp(
3575 adaptor.getInput(),
3576 [](const llvm::APInt &inputValue) {
3577 return llvm::APInt(inputValue.getBitWidth(),
3578 inputValue.countTrailingZeros());
3579 },
3580 getPoisonZero());
3581}
3582
3583OpFoldResult BitFfsOp::fold(FoldAdaptor adaptor) {
3584 return foldUnaryBitOp(adaptor.getInput(), [](const llvm::APInt &inputValue) {
3585 unsigned trailingZeros = inputValue.countTrailingZeros();
3586 unsigned result =
3587 trailingZeros == inputValue.getBitWidth() ? 0 : trailingZeros + 1;
3588 return llvm::APInt(inputValue.getBitWidth(), result);
3589 });
3590}
3591
3592OpFoldResult BitParityOp::fold(FoldAdaptor adaptor) {
3593 return foldUnaryBitOp(adaptor.getInput(), [](const llvm::APInt &inputValue) {
3594 return llvm::APInt(inputValue.getBitWidth(), inputValue.popcount() % 2);
3595 });
3596}
3597
3598OpFoldResult BitPopcountOp::fold(FoldAdaptor adaptor) {
3599 return foldUnaryBitOp(adaptor.getInput(), [](const llvm::APInt &inputValue) {
3600 return llvm::APInt(inputValue.getBitWidth(), inputValue.popcount());
3601 });
3602}
3603
3604OpFoldResult BitReverseOp::fold(FoldAdaptor adaptor) {
3605 return foldUnaryBitOp(adaptor.getInput(), [](const llvm::APInt &inputValue) {
3606 return inputValue.reverseBits();
3607 });
3608}
3609
3610OpFoldResult ByteSwapOp::fold(FoldAdaptor adaptor) {
3611 return foldUnaryBitOp(adaptor.getInput(), [](const llvm::APInt &inputValue) {
3612 return inputValue.byteSwap();
3613 });
3614}
3615
3616OpFoldResult RotateOp::fold(FoldAdaptor adaptor) {
3617 if (mlir::isa_and_present<cir::PoisonAttr>(adaptor.getInput()) ||
3618 mlir::isa_and_present<cir::PoisonAttr>(adaptor.getAmount())) {
3619 // Propagate poison values
3620 return cir::PoisonAttr::get(getType());
3621 }
3622
3623 auto input = mlir::dyn_cast_if_present<IntAttr>(adaptor.getInput());
3624 auto amount = mlir::dyn_cast_if_present<IntAttr>(adaptor.getAmount());
3625 if (!input && !amount)
3626 return nullptr;
3627
3628 // We could fold cir.rotate even if one of its two operands is not a constant:
3629 // - `cir.rotate left/right %0, 0` could be folded into just %0 even if %0
3630 // is not a constant.
3631 // - `cir.rotate left/right 0/0b111...111, %0` could be folded into 0 or
3632 // 0b111...111 even if %0 is not a constant.
3633
3634 llvm::APInt inputValue;
3635 if (input) {
3636 inputValue = input.getValue();
3637 if (inputValue.isZero() || inputValue.isAllOnes()) {
3638 // An input value of all 0s or all 1s will not change after rotation
3639 return input;
3640 }
3641 }
3642
3643 uint64_t amountValue;
3644 if (amount) {
3645 amountValue = amount.getValue().urem(getInput().getType().getWidth());
3646 if (amountValue == 0) {
3647 // A shift amount of 0 will not change the input value
3648 return getInput();
3649 }
3650 }
3651
3652 if (!input || !amount)
3653 return nullptr;
3654
3655 assert(inputValue.getBitWidth() == getInput().getType().getWidth() &&
3656 "input value must have the same bit width as the input type");
3657
3658 llvm::APInt resultValue;
3659 if (isRotateLeft())
3660 resultValue = inputValue.rotl(amountValue);
3661 else
3662 resultValue = inputValue.rotr(amountValue);
3663
3664 return IntAttr::get(input.getContext(), input.getType(), resultValue);
3665}
3666
3667//===----------------------------------------------------------------------===//
3668// InlineAsmOp
3669//===----------------------------------------------------------------------===//
3670
3671void cir::InlineAsmOp::print(OpAsmPrinter &p) {
3672 p << '(' << getAsmFlavor() << ", ";
3673 p.increaseIndent();
3674 p.printNewline();
3675
3676 llvm::SmallVector<std::string, 3> names{"out", "in", "in_out"};
3677 auto *nameIt = names.begin();
3678 auto *attrIt = getOperandAttrs().begin();
3679
3680 for (mlir::OperandRange ops : getAsmOperands()) {
3681 p << *nameIt << " = ";
3682
3683 p << '[';
3684 llvm::interleaveComma(llvm::make_range(ops.begin(), ops.end()), p,
3685 [&](Value value) {
3686 p.printOperand(value);
3687 p << " : " << value.getType();
3688 if (mlir::isa<mlir::UnitAttr>(*attrIt))
3689 p << " (maybe_memory)";
3690 attrIt++;
3691 });
3692 p << "],";
3693 p.printNewline();
3694 ++nameIt;
3695 }
3696
3697 p << "{";
3698 p.printString(getAsmString());
3699 p << " ";
3700 p.printString(getConstraints());
3701 p << "}";
3702 p.decreaseIndent();
3703 p << ')';
3704 if (getSideEffects())
3705 p << " side_effects";
3706
3707 std::array elidedAttrs{
3708 llvm::StringRef("asm_flavor"), llvm::StringRef("asm_string"),
3709 llvm::StringRef("constraints"), llvm::StringRef("operand_attrs"),
3710 llvm::StringRef("operands_segments"), llvm::StringRef("side_effects")};
3711 p.printOptionalAttrDict(getOperation()->getAttrs(), elidedAttrs);
3712
3713 if (auto v = getRes())
3714 p << " -> " << v.getType();
3715}
3716
3717void cir::InlineAsmOp::build(OpBuilder &odsBuilder, OperationState &odsState,
3718 ArrayRef<ValueRange> asmOperands,
3719 StringRef asmString, StringRef constraints,
3720 bool sideEffects, cir::AsmFlavor asmFlavor,
3721 ArrayRef<Attribute> operandAttrs) {
3722 // Set up the operands_segments for VariadicOfVariadic
3723 SmallVector<int32_t> segments;
3724 for (auto operandRange : asmOperands) {
3725 segments.push_back(operandRange.size());
3726 odsState.addOperands(operandRange);
3727 }
3728
3729 odsState.addAttribute(
3730 "operands_segments",
3731 DenseI32ArrayAttr::get(odsBuilder.getContext(), segments));
3732 odsState.addAttribute("asm_string", odsBuilder.getStringAttr(asmString));
3733 odsState.addAttribute("constraints", odsBuilder.getStringAttr(constraints));
3734 odsState.addAttribute("asm_flavor",
3735 AsmFlavorAttr::get(odsBuilder.getContext(), asmFlavor));
3736
3737 if (sideEffects)
3738 odsState.addAttribute("side_effects", odsBuilder.getUnitAttr());
3739
3740 odsState.addAttribute("operand_attrs", odsBuilder.getArrayAttr(operandAttrs));
3741}
3742
3743ParseResult cir::InlineAsmOp::parse(OpAsmParser &parser,
3744 OperationState &result) {
3746 llvm::SmallVector<int32_t> operandsGroupSizes;
3747 std::string asmString, constraints;
3748 Type resType;
3749 MLIRContext *ctxt = parser.getBuilder().getContext();
3750
3751 auto error = [&](const Twine &msg) -> LogicalResult {
3752 return parser.emitError(parser.getCurrentLocation(), msg);
3753 };
3754
3755 auto expected = [&](const std::string &c) {
3756 return error("expected '" + c + "'");
3757 };
3758
3759 if (parser.parseLParen().failed())
3760 return expected("(");
3761
3762 auto flavor = FieldParser<AsmFlavor, AsmFlavor>::parse(parser);
3763 if (failed(flavor))
3764 return error("Unknown AsmFlavor");
3765
3766 if (parser.parseComma().failed())
3767 return expected(",");
3768
3769 auto parseValue = [&](Value &v) {
3770 OpAsmParser::UnresolvedOperand op;
3771
3772 if (parser.parseOperand(op) || parser.parseColon())
3773 return error("can't parse operand");
3774
3775 Type typ;
3776 if (parser.parseType(typ).failed())
3777 return error("can't parse operand type");
3779 if (parser.resolveOperand(op, typ, tmp))
3780 return error("can't resolve operand");
3781 v = tmp[0];
3782 return mlir::success();
3783 };
3784
3785 auto parseOperands = [&](llvm::StringRef name) {
3786 if (parser.parseKeyword(name).failed())
3787 return error("expected " + name + " operands here");
3788 if (parser.parseEqual().failed())
3789 return expected("=");
3790 if (parser.parseLSquare().failed())
3791 return expected("[");
3792
3793 int size = 0;
3794 if (parser.parseOptionalRSquare().succeeded()) {
3795 operandsGroupSizes.push_back(size);
3796 if (parser.parseComma())
3797 return expected(",");
3798 return mlir::success();
3799 }
3800
3801 auto parseOperand = [&]() {
3802 Value val;
3803 if (parseValue(val).succeeded()) {
3804 result.operands.push_back(val);
3805 size++;
3806
3807 if (parser.parseOptionalLParen().failed()) {
3808 operandAttrs.push_back(mlir::DictionaryAttr::get(ctxt));
3809 return mlir::success();
3810 }
3811
3812 if (parser.parseKeyword("maybe_memory").succeeded()) {
3813 operandAttrs.push_back(mlir::UnitAttr::get(ctxt));
3814 if (parser.parseRParen())
3815 return expected(")");
3816 return mlir::success();
3817 } else {
3818 return expected("maybe_memory");
3819 }
3820 }
3821 return mlir::failure();
3822 };
3823
3824 if (parser.parseCommaSeparatedList(parseOperand).failed())
3825 return mlir::failure();
3826
3827 if (parser.parseRSquare().failed() || parser.parseComma().failed())
3828 return expected("]");
3829 operandsGroupSizes.push_back(size);
3830 return mlir::success();
3831 };
3832
3833 if (parseOperands("out").failed() || parseOperands("in").failed() ||
3834 parseOperands("in_out").failed())
3835 return error("failed to parse operands");
3836
3837 if (parser.parseLBrace())
3838 return expected("{");
3839 if (parser.parseString(&asmString))
3840 return error("asm string parsing failed");
3841 if (parser.parseString(&constraints))
3842 return error("constraints string parsing failed");
3843 if (parser.parseRBrace())
3844 return expected("}");
3845 if (parser.parseRParen())
3846 return expected(")");
3847
3848 if (parser.parseOptionalKeyword("side_effects").succeeded())
3849 result.attributes.set("side_effects", UnitAttr::get(ctxt));
3850
3851 if (parser.parseOptionalAttrDict(result.attributes).failed())
3852 return mlir::failure();
3853
3854 if (parser.parseOptionalArrow().succeeded() &&
3855 parser.parseType(resType).failed())
3856 return mlir::failure();
3857
3858 result.attributes.set("asm_flavor", AsmFlavorAttr::get(ctxt, *flavor));
3859 result.attributes.set("asm_string", StringAttr::get(ctxt, asmString));
3860 result.attributes.set("constraints", StringAttr::get(ctxt, constraints));
3861 result.attributes.set("operand_attrs", ArrayAttr::get(ctxt, operandAttrs));
3862 result.getOrAddProperties<InlineAsmOp::Properties>().operands_segments =
3863 parser.getBuilder().getDenseI32ArrayAttr(operandsGroupSizes);
3864 if (resType)
3865 result.addTypes(TypeRange{resType});
3866
3867 return mlir::success();
3868}
3869
3870//===----------------------------------------------------------------------===//
3871// ThrowOp
3872//===----------------------------------------------------------------------===//
3873
3874mlir::LogicalResult cir::ThrowOp::verify() {
3875 // For the no-rethrow version, it must have at least the exception pointer.
3876 if (rethrows())
3877 return success();
3878
3879 if (getNumOperands() != 0) {
3880 if (getTypeInfo())
3881 return success();
3882 return emitOpError() << "'type_info' symbol attribute missing";
3883 }
3884
3885 return failure();
3886}
3887
3888//===----------------------------------------------------------------------===//
3889// AtomicFetchOp
3890//===----------------------------------------------------------------------===//
3891
3892LogicalResult cir::AtomicFetchOp::verify() {
3893 if (getBinop() != cir::AtomicFetchKind::Add &&
3894 getBinop() != cir::AtomicFetchKind::Sub &&
3895 getBinop() != cir::AtomicFetchKind::Max &&
3896 getBinop() != cir::AtomicFetchKind::Min &&
3897 !mlir::isa<cir::IntType>(getVal().getType()))
3898 return emitError("only atomic add, sub, max, and min operation could "
3899 "operate on floating-point values");
3900 return success();
3901}
3902
3903//===----------------------------------------------------------------------===//
3904// TypeInfoAttr
3905//===----------------------------------------------------------------------===//
3906
3907LogicalResult cir::TypeInfoAttr::verify(
3908 ::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError,
3909 ::mlir::Type type, ::mlir::ArrayAttr typeInfoData) {
3910
3911 if (cir::ConstRecordAttr::verify(emitError, type, typeInfoData).failed())
3912 return failure();
3913
3914 return success();
3915}
3916
3917//===----------------------------------------------------------------------===//
3918// TryOp
3919//===----------------------------------------------------------------------===//
3920
3921void cir::TryOp::getSuccessorRegions(
3922 mlir::RegionBranchPoint point,
3924 // The `try` and the `catchers` region branch back to the parent operation.
3925 if (!point.isParent()) {
3926 regions.push_back(RegionSuccessor::parent());
3927 return;
3928 }
3929
3930 regions.push_back(mlir::RegionSuccessor(&getTryRegion()));
3931
3932 // TODO(CIR): If we know a target function never throws a specific type, we
3933 // can remove the catch handler.
3934 for (mlir::Region &handlerRegion : this->getHandlerRegions())
3935 regions.push_back(mlir::RegionSuccessor(&handlerRegion));
3936}
3937
3938mlir::ValueRange cir::TryOp::getSuccessorInputs(RegionSuccessor successor) {
3939 return successor.isParent() ? ValueRange(getOperation()->getResults())
3940 : ValueRange();
3941}
3942
3943LogicalResult cir::TryOp::verify() {
3944 mlir::ArrayAttr handlerTypes = getHandlerTypes();
3945 if (!handlerTypes) {
3946 if (!getHandlerRegions().empty())
3947 return emitOpError(
3948 "handler regions must be empty when no handler types are present");
3949 return success();
3950 }
3951
3952 mlir::MutableArrayRef<mlir::Region> handlerRegions = getHandlerRegions();
3953
3954 // The parser and builder won't allow this to happen, but the loop below
3955 // relies on the sizes being the same, so we check it here.
3956 if (handlerRegions.size() != handlerTypes.size())
3957 return emitOpError(
3958 "number of handler regions and handler types must match");
3959
3960 for (const auto &[typeAttr, handlerRegion] :
3961 llvm::zip(handlerTypes, handlerRegions)) {
3962 // Verify that handler regions have a !cir.eh_token block argument.
3963 mlir::Block &entryBlock = handlerRegion.front();
3964 if (entryBlock.getNumArguments() != 1 ||
3965 !mlir::isa<cir::EhTokenType>(entryBlock.getArgument(0).getType()))
3966 return emitOpError(
3967 "handler region must have a single '!cir.eh_token' argument");
3968
3969 // The unwind region does not require a cir.begin_catch.
3970 if (mlir::isa<cir::UnwindAttr>(typeAttr))
3971 continue;
3972
3973 if (entryBlock.empty() || !mlir::isa<cir::BeginCatchOp>(entryBlock.front()))
3974 return emitOpError(
3975 "catch handler region must start with 'cir.begin_catch'");
3976 }
3977
3978 return success();
3979}
3980
3981static void
3982printTryHandlerRegions(mlir::OpAsmPrinter &printer, cir::TryOp op,
3983 mlir::MutableArrayRef<mlir::Region> handlerRegions,
3984 mlir::ArrayAttr handlerTypes) {
3985 if (!handlerTypes)
3986 return;
3987
3988 for (const auto [typeIdx, typeAttr] : llvm::enumerate(handlerTypes)) {
3989 if (typeIdx)
3990 printer << " ";
3991
3992 if (mlir::isa<cir::CatchAllAttr>(typeAttr)) {
3993 printer << "catch all ";
3994 } else if (mlir::isa<cir::UnwindAttr>(typeAttr)) {
3995 printer << "unwind ";
3996 } else {
3997 printer << "catch [type ";
3998 printer.printAttribute(typeAttr);
3999 printer << "] ";
4000 }
4001
4002 // Print the handler region's !cir.eh_token block argument.
4003 mlir::Region &region = handlerRegions[typeIdx];
4004 if (!region.empty() && region.front().getNumArguments() > 0) {
4005 printer << "(";
4006 printer.printRegionArgument(region.front().getArgument(0));
4007 printer << ") ";
4008 }
4009
4010 printer.printRegion(region,
4011 /*printEntryBLockArgs=*/false,
4012 /*printBlockTerminators=*/true);
4013 }
4014}
4015
4016static mlir::ParseResult parseTryHandlerRegions(
4017 mlir::OpAsmParser &parser,
4018 llvm::SmallVectorImpl<std::unique_ptr<mlir::Region>> &handlerRegions,
4019 mlir::ArrayAttr &handlerTypes) {
4020
4021 auto parseCheckedCatcherRegion = [&]() -> mlir::ParseResult {
4022 handlerRegions.emplace_back(new mlir::Region);
4023
4024 mlir::Region &currRegion = *handlerRegions.back();
4025
4026 // Parse the required region argument: (%eh_token : !cir.eh_token)
4028 if (parser.parseLParen())
4029 return failure();
4030 mlir::OpAsmParser::Argument arg;
4031 if (parser.parseArgument(arg, /*allowType=*/true))
4032 return failure();
4033 regionArgs.push_back(arg);
4034 if (parser.parseRParen())
4035 return failure();
4036
4037 mlir::SMLoc regionLoc = parser.getCurrentLocation();
4038 if (parser.parseRegion(currRegion, regionArgs)) {
4039 handlerRegions.clear();
4040 return failure();
4041 }
4042
4043 if (currRegion.empty())
4044 return parser.emitError(regionLoc, "handler region shall not be empty");
4045
4046 if (!(currRegion.back().mightHaveTerminator() &&
4047 currRegion.back().getTerminator()))
4048 return parser.emitError(
4049 regionLoc, "blocks are expected to be explicitly terminated");
4050
4051 return success();
4052 };
4053
4054 bool hasCatchAll = false;
4056 while (parser.parseOptionalKeyword("catch").succeeded()) {
4057 bool hasLSquare = parser.parseOptionalLSquare().succeeded();
4058
4059 llvm::StringRef attrStr;
4060 if (parser.parseOptionalKeyword(&attrStr, {"all", "type"}).failed())
4061 return parser.emitError(parser.getCurrentLocation(),
4062 "expected 'all' or 'type' keyword");
4063
4064 bool isCatchAll = attrStr == "all";
4065 if (isCatchAll) {
4066 if (hasCatchAll)
4067 return parser.emitError(parser.getCurrentLocation(),
4068 "can't have more than one catch all");
4069 hasCatchAll = true;
4070 }
4071
4072 mlir::Attribute exceptionRTTIAttr;
4073 if (!isCatchAll && parser.parseAttribute(exceptionRTTIAttr).failed())
4074 return parser.emitError(parser.getCurrentLocation(),
4075 "expected valid RTTI info attribute");
4076
4077 catcherAttrs.push_back(isCatchAll
4078 ? cir::CatchAllAttr::get(parser.getContext())
4079 : exceptionRTTIAttr);
4080
4081 if (hasLSquare && isCatchAll)
4082 return parser.emitError(parser.getCurrentLocation(),
4083 "catch all dosen't need RTTI info attribute");
4084
4085 if (hasLSquare && parser.parseRSquare().failed())
4086 return parser.emitError(parser.getCurrentLocation(),
4087 "expected `]` after RTTI info attribute");
4088
4089 if (parseCheckedCatcherRegion().failed())
4090 return mlir::failure();
4091 }
4092
4093 if (parser.parseOptionalKeyword("unwind").succeeded()) {
4094 if (hasCatchAll)
4095 return parser.emitError(parser.getCurrentLocation(),
4096 "unwind can't be used with catch all");
4097
4098 catcherAttrs.push_back(cir::UnwindAttr::get(parser.getContext()));
4099 if (parseCheckedCatcherRegion().failed())
4100 return mlir::failure();
4101 }
4102
4103 handlerTypes = parser.getBuilder().getArrayAttr(catcherAttrs);
4104 return mlir::success();
4105}
4106
4107//===----------------------------------------------------------------------===//
4108// EhTypeIdOp
4109//===----------------------------------------------------------------------===//
4110
4111LogicalResult
4112cir::EhTypeIdOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
4113 Operation *op = symbolTable.lookupNearestSymbolFrom(*this, getTypeSymAttr());
4114 if (!isa_and_nonnull<GlobalOp>(op))
4115 return emitOpError("'")
4116 << getTypeSym() << "' does not reference a valid cir.global";
4117 return success();
4118}
4119
4120//===----------------------------------------------------------------------===//
4121// EhDispatchOp
4122//===----------------------------------------------------------------------===//
4123
4124static ParseResult
4125parseEhDispatchDestinations(OpAsmParser &parser, mlir::ArrayAttr &catchTypes,
4126 SmallVectorImpl<Block *> &catchDestinations,
4127 Block *&defaultDestination,
4128 mlir::UnitAttr &defaultIsCatchAll) {
4129 // Parse: [ ... ]
4130 if (parser.parseLSquare())
4131 return failure();
4132
4133 SmallVector<Attribute> handlerTypes;
4134 bool hasCatchAll = false;
4135 bool hasUnwind = false;
4136
4137 // Parse handler list.
4138 auto parseHandler = [&]() -> ParseResult {
4139 // Check for 'catch_all' or 'unwind' keywords.
4140 if (succeeded(parser.parseOptionalKeyword("catch_all"))) {
4141 if (hasCatchAll)
4142 return parser.emitError(parser.getCurrentLocation(),
4143 "duplicate 'catch_all' handler");
4144 if (hasUnwind)
4145 return parser.emitError(parser.getCurrentLocation(),
4146 "cannot have both 'catch_all' and 'unwind'");
4147 hasCatchAll = true;
4148
4149 if (parser.parseColon().failed())
4150 return failure();
4151
4152 if (parser.parseSuccessor(defaultDestination).failed())
4153 return failure();
4154
4155 return success();
4156 }
4157
4158 if (succeeded(parser.parseOptionalKeyword("unwind"))) {
4159 if (hasUnwind)
4160 return parser.emitError(parser.getCurrentLocation(),
4161 "duplicate 'unwind' handler");
4162 if (hasCatchAll)
4163 return parser.emitError(parser.getCurrentLocation(),
4164 "cannot have both 'catch_all' and 'unwind'");
4165 hasUnwind = true;
4166
4167 if (parser.parseColon().failed())
4168 return failure();
4169
4170 if (parser.parseSuccessor(defaultDestination).failed())
4171 return failure();
4172 return success();
4173 }
4174
4175 // Otherwise, expect 'catch(<attr> : <type>) : ^block'.
4176 // The 'catch(...)' wrapper allows the attribute to include its type
4177 // without conflicting with the ':' used for the block destination.
4178 if (parser.parseKeyword("catch").failed())
4179 return failure();
4180
4181 if (parser.parseLParen().failed())
4182 return failure();
4183
4184 mlir::Attribute catchTypeAttr;
4185 if (parser.parseAttribute(catchTypeAttr).failed())
4186 return failure();
4187 handlerTypes.push_back(catchTypeAttr);
4188
4189 if (parser.parseRParen().failed())
4190 return failure();
4191
4192 if (parser.parseColon().failed())
4193 return failure();
4194
4195 Block *dest;
4196 if (parser.parseSuccessor(dest).failed())
4197 return failure();
4198 catchDestinations.push_back(dest);
4199 return success();
4200 };
4201
4202 if (parser.parseCommaSeparatedList(parseHandler).failed())
4203 return failure();
4204
4205 if (parser.parseRSquare().failed())
4206 return failure();
4207
4208 // Verify we have catch_all or unwind.
4209 if (!hasCatchAll && !hasUnwind)
4210 return parser.emitError(parser.getCurrentLocation(),
4211 "must have either 'catch_all' or 'unwind' handler");
4212
4213 // Add attributes and successors.
4214 if (!handlerTypes.empty())
4215 catchTypes = parser.getBuilder().getArrayAttr(handlerTypes);
4216
4217 if (hasCatchAll)
4218 defaultIsCatchAll = parser.getBuilder().getUnitAttr();
4219
4220 return success();
4221}
4222
4223static void printEhDispatchDestinations(OpAsmPrinter &p, cir::EhDispatchOp op,
4224 mlir::ArrayAttr catchTypes,
4225 SuccessorRange catchDestinations,
4226 Block *defaultDestination,
4227 mlir::UnitAttr defaultIsCatchAll) {
4228 p << " [";
4229 p.printNewline();
4230
4231 // If we have at least one catch type, print them.
4232 if (catchTypes) {
4233 // Print type handlers using 'catch(<attr>) : ^block' syntax.
4234 llvm::interleave(
4235 llvm::zip(catchTypes, catchDestinations),
4236 [&](auto i) {
4237 p << " catch(";
4238 p.printAttribute(std::get<0>(i));
4239 p << ") : ";
4240 p.printSuccessor(std::get<1>(i));
4241 },
4242 [&] {
4243 p << ',';
4244 p.printNewline();
4245 });
4246
4247 p << ", ";
4248 p.printNewline();
4249 }
4250
4251 // Print catch_all or unwind handler.
4252 if (defaultIsCatchAll)
4253 p << " catch_all : ";
4254 else
4255 p << " unwind : ";
4256 p.printSuccessor(defaultDestination);
4257 p.printNewline();
4258
4259 p << "]";
4260}
4261
4262//===----------------------------------------------------------------------===//
4263// TableGen'd op method definitions
4264//===----------------------------------------------------------------------===//
4265
4266#define GET_OP_CLASSES
4267#include "clang/CIR/Dialect/IR/CIROps.cpp.inc"
static const MemRegion * getRegion(const CallEvent &Call, const MutexDescriptor &Descriptor, bool IsLock)
static void printEhDispatchDestinations(OpAsmPrinter &p, cir::EhDispatchOp op, mlir::ArrayAttr catchTypes, SuccessorRange catchDestinations, Block *defaultDestination, mlir::UnitAttr defaultIsCatchAll)
static mlir::LogicalResult checkReturnAndFunction(cir::ReturnOp op, cir::FuncOp function)
static LogicalResult verifyMemberPtrCast(Operation *op, mlir::Value src, mlir::Type resultTy)
static mlir::ParseResult parseCallCommon(mlir::OpAsmParser &parser, mlir::OperationState &result, bool hasDestinationBlocks=false)
static bool isIntOrBoolCast(cir::CastOp op)
static ParseResult parseEhDispatchDestinations(OpAsmParser &parser, mlir::ArrayAttr &catchTypes, SmallVectorImpl< Block * > &catchDestinations, Block *&defaultDestination, mlir::UnitAttr &defaultIsCatchAll)
static void printConstant(OpAsmPrinter &p, Attribute value)
static mlir::ParseResult parseOmittedTerminatorRegion(mlir::OpAsmParser &parser, mlir::Region &region)
ParseResult parseInlineKindAttr(OpAsmParser &parser, cir::InlineKindAttr &inlineKindAttr)
void printInlineKindAttr(OpAsmPrinter &p, cir::InlineKindAttr inlineKindAttr)
static ParseResult parseSwitchFlatOpCases(OpAsmParser &parser, Type flagType, mlir::ArrayAttr &caseValues, SmallVectorImpl< Block * > &caseDestinations, SmallVectorImpl< llvm::SmallVector< OpAsmParser::UnresolvedOperand > > &caseOperands, SmallVectorImpl< llvm::SmallVector< Type > > &caseOperandTypes)
<cases> ::= [ (case (, case )* )?
void printGlobalAddressSpaceValue(mlir::AsmPrinter &printer, cir::GlobalOp op, mlir::ptr::MemorySpaceAttrInterface attr)
static void printCallCommon(mlir::Operation *op, mlir::FlatSymbolRefAttr calleeSym, mlir::Value indirectCallee, mlir::OpAsmPrinter &printer, bool isNothrow, cir::SideEffect sideEffect, ArrayAttr argAttrs, ArrayAttr resAttrs, mlir::Block *normalDest=nullptr, mlir::Block *unwindDest=nullptr)
static LogicalResult verifyCallCommInSymbolUses(mlir::Operation *op, SymbolTableCollection &symbolTable)
static LogicalResult ensureRegionTerm(OpAsmParser &parser, Region &region, SMLoc errLoc)
static ParseResult parseGlobalOpTypeAndInitialValue(OpAsmParser &parser, TypeAttr &typeAttr, Attribute &initialValueAttr, mlir::Region &ctorRegion, mlir::Region &dtorRegion)
void printIndirectBrOpSucessors(OpAsmPrinter &p, cir::IndirectBrOp op, Type flagType, SuccessorRange succs, OperandRangeRange succOperands, const TypeRangeRange &succOperandsTypes)
static OpFoldResult foldUnaryBitOp(mlir::Attribute inputAttr, llvm::function_ref< llvm::APInt(const llvm::APInt &)> func, bool poisonZero=false)
static llvm::StringRef getLinkageAttrNameString()
Returns the name used for the linkage attribute.
static RetTy parseOptionalCIRKeyword(AsmParser &parser, EnumTy defaultValue)
Parse an enum from the keyword, or default to the provided default value.
mlir::OptionalParseResult parseGlobalAddressSpaceValue(mlir::AsmParser &p, mlir::ptr::MemorySpaceAttrInterface &attr)
static void printSwitchFlatOpCases(OpAsmPrinter &p, cir::SwitchFlatOp op, Type flagType, mlir::ArrayAttr caseValues, SuccessorRange caseDestinations, OperandRangeRange caseOperands, const TypeRangeRange &caseOperandTypes)
static mlir::ParseResult parseTryCallDestinations(mlir::OpAsmParser &parser, mlir::OperationState &result)
static void printGlobalOpTypeAndInitialValue(OpAsmPrinter &p, cir::GlobalOp op, TypeAttr type, Attribute initAttr, mlir::Region &ctorRegion, mlir::Region &dtorRegion)
static ParseResult parseCIRKeyword(AsmParser &parser, RetTy &result)
Parse an enum from the keyword, return failure if the keyword is not found.
static Value tryFoldCastChain(cir::CastOp op)
static void printTryHandlerRegions(mlir::OpAsmPrinter &printer, cir::TryOp op, mlir::MutableArrayRef< mlir::Region > handlerRegions, mlir::ArrayAttr handlerTypes)
ParseResult parseIndirectBrOpSucessors(OpAsmParser &parser, Type &flagType, SmallVectorImpl< Block * > &succOperandBlocks, SmallVectorImpl< SmallVector< OpAsmParser::UnresolvedOperand > > &succOperands, SmallVectorImpl< SmallVector< Type > > &succOperandsTypes)
static bool omitRegionTerm(mlir::Region &r)
static LogicalResult verifyBinaryOverflowOp(mlir::Operation *op, bool noSignedWrap, bool noUnsignedWrap, bool saturated, bool hasSat)
static void printOmittedTerminatorRegion(mlir::OpAsmPrinter &printer, cir::ScopeOp &op, mlir::Region &region)
static ParseResult parseConstantValue(OpAsmParser &parser, mlir::Attribute &valueAttr)
static LogicalResult verifyArrayCtorDtor(Op op)
static LogicalResult checkConstantTypes(mlir::Operation *op, mlir::Type opType, mlir::Attribute attrType)
static mlir::ParseResult parseTryHandlerRegions(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< std::unique_ptr< mlir::Region > > &handlerRegions, mlir::ArrayAttr &handlerTypes)
#define REGISTER_ENUM_TYPE(Ty)
static int parseOptionalKeywordAlternative(AsmParser &parser, ArrayRef< llvm::StringRef > keywords)
llvm::function_ref< void(mlir::OpBuilder &, mlir::Location)> BuilderCallbackRef
Definition CIRDialect.h:37
llvm::function_ref< void( mlir::OpBuilder &, mlir::Location, mlir::OperationState &)> BuilderOpStateCallbackRef
Definition CIRDialect.h:39
static std::optional< NonLoc > getIndex(ProgramStateRef State, const ElementRegion *ER, CharKind CK)
static Decl::Kind getKind(const Decl *D)
TokenType getType() const
Returns the token's type, e.g.
tooling::Replacements cleanup(const FormatStyle &Style, StringRef Code, ArrayRef< tooling::Range > Ranges, StringRef FileName="<stdin>")
Clean up any erroneous/redundant code in the given Ranges in Code.
*collection of selector each with an associated kind and an ordered *collection of selectors A selector has a an optional score condition
*collection of selector each with an associated kind and an ordered *collection of selectors A selector has a kind
__device__ __2f16 float c
void buildTerminatedBody(mlir::OpBuilder &builder, mlir::Location loc)
mlir::ptr::MemorySpaceAttrInterface normalizeDefaultAddressSpace(mlir::ptr::MemorySpaceAttrInterface addrSpace)
Normalize LangAddressSpace::Default to null (empty attribute).
const internal::VariadicAllOfMatcher< Attr > attr
const AstTypeMatcher< RecordType > recordType
StringRef getName(const HeaderType T)
Definition HeaderFile.h:38
RangeSelector name(std::string ID)
Given a node with a "name", (like NamedDecl, DeclRefExpr, CxxCtorInitializer, and TypeLoc) selects th...
nullptr
This class represents a compute construct, representing a 'Kind' of ‘parallel’, 'serial',...
__DEVICE__ _Tp arg(const std::complex< _Tp > &__c)
static bool memberFuncPtrCast()
static bool opCallCallConv()
static bool opScopeCleanupRegion()
static bool supportIFuncAttr()