clang 23.0.0git
CIRDialect.cpp
Go to the documentation of this file.
1//===- CIRDialect.cpp - MLIR CIR ops implementation -----------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements the CIR dialect and its operations.
10//
11//===----------------------------------------------------------------------===//
12
14
18
19#include "mlir/IR/Attributes.h"
20#include "mlir/IR/DialectImplementation.h"
21#include "mlir/IR/PatternMatch.h"
22#include "mlir/Interfaces/ControlFlowInterfaces.h"
23#include "mlir/Interfaces/FunctionImplementation.h"
24#include "mlir/Support/LLVM.h"
25
26#include "clang/CIR/Dialect/IR/CIROpsDialect.cpp.inc"
27#include "clang/CIR/Dialect/IR/CIROpsEnums.cpp.inc"
29#include "llvm/ADT/SetOperations.h"
30#include "llvm/ADT/SmallSet.h"
31#include "llvm/ADT/TypeSwitch.h"
32#include "llvm/Support/LogicalResult.h"
33
34using namespace mlir;
35using namespace cir;
36
37//===----------------------------------------------------------------------===//
38// CIR Dialect
39//===----------------------------------------------------------------------===//
40namespace {
41struct CIROpAsmDialectInterface : public OpAsmDialectInterface {
42 using OpAsmDialectInterface::OpAsmDialectInterface;
43
44 AliasResult getAlias(Type type, raw_ostream &os) const final {
45 if (auto recordType = dyn_cast<cir::RecordType>(type)) {
46 StringAttr nameAttr = recordType.getName();
47 if (!nameAttr)
48 os << "rec_anon_" << recordType.getKindAsStr();
49 else
50 os << "rec_" << nameAttr.getValue();
51 return AliasResult::OverridableAlias;
52 }
53 if (auto intType = dyn_cast<cir::IntType>(type)) {
54 // We only provide alias for standard integer types (i.e. integer types
55 // whose width is a power of 2 and at least 8).
56 unsigned width = intType.getWidth();
57 if (width < 8 || !llvm::isPowerOf2_32(width))
58 return AliasResult::NoAlias;
59 os << intType.getAlias();
60 return AliasResult::OverridableAlias;
61 }
62 if (auto voidType = dyn_cast<cir::VoidType>(type)) {
63 os << voidType.getAlias();
64 return AliasResult::OverridableAlias;
65 }
66
67 return AliasResult::NoAlias;
68 }
69
70 AliasResult getAlias(Attribute attr, raw_ostream &os) const final {
71 if (auto boolAttr = mlir::dyn_cast<cir::BoolAttr>(attr)) {
72 os << (boolAttr.getValue() ? "true" : "false");
73 return AliasResult::FinalAlias;
74 }
75 if (auto bitfield = mlir::dyn_cast<cir::BitfieldInfoAttr>(attr)) {
76 os << "bfi_" << bitfield.getName().str();
77 return AliasResult::FinalAlias;
78 }
79 if (auto dynCastInfoAttr = mlir::dyn_cast<cir::DynamicCastInfoAttr>(attr)) {
80 os << dynCastInfoAttr.getAlias();
81 return AliasResult::FinalAlias;
82 }
83 return AliasResult::NoAlias;
84 }
85};
86} // namespace
87
/// Sets up the CIR dialect: registers its types and attributes, adds the
/// operation list, and attaches the assembly-format alias interface.
void cir::CIRDialect::initialize() {
  registerTypes();
  registerAttributes();
  // With GET_OP_LIST defined, the included .inc file supplies the template
  // argument list for addOperations<>.
  addOperations<
#define GET_OP_LIST
#include "clang/CIR/Dialect/IR/CIROps.cpp.inc"
      >();
  addInterfaces<CIROpAsmDialectInterface>();
}
97
98Operation *cir::CIRDialect::materializeConstant(mlir::OpBuilder &builder,
99 mlir::Attribute value,
100 mlir::Type type,
101 mlir::Location loc) {
102 return cir::ConstantOp::create(builder, loc, type,
103 mlir::cast<mlir::TypedAttr>(value));
104}
105
106//===----------------------------------------------------------------------===//
107// Helpers
108//===----------------------------------------------------------------------===//
109
110// Parses one of the keywords provided in the list `keywords` and returns the
111// position of the parsed keyword in the list. If none of the keywords from the
112// list is parsed, returns -1.
113static int parseOptionalKeywordAlternative(AsmParser &parser,
114 ArrayRef<llvm::StringRef> keywords) {
115 for (auto en : llvm::enumerate(keywords)) {
116 if (succeeded(parser.parseOptionalKeyword(en.value())))
117 return en.index();
118 }
119 return -1;
120}
121
122namespace {
123template <typename Ty> struct EnumTraits {};
124
125#define REGISTER_ENUM_TYPE(Ty) \
126 template <> struct EnumTraits<cir::Ty> { \
127 static llvm::StringRef stringify(cir::Ty value) { \
128 return stringify##Ty(value); \
129 } \
130 static unsigned getMaxEnumVal() { return cir::getMaxEnumValFor##Ty(); } \
131 }
132
133REGISTER_ENUM_TYPE(GlobalLinkageKind);
134REGISTER_ENUM_TYPE(VisibilityKind);
135REGISTER_ENUM_TYPE(SideEffect);
136} // namespace
137
138/// Parse an enum from the keyword, or default to the provided default value.
139/// The return type is the enum type by default, unless overriden with the
140/// second template argument.
141template <typename EnumTy, typename RetTy = EnumTy>
142static RetTy parseOptionalCIRKeyword(AsmParser &parser, EnumTy defaultValue) {
144 for (unsigned i = 0, e = EnumTraits<EnumTy>::getMaxEnumVal(); i <= e; ++i)
145 names.push_back(EnumTraits<EnumTy>::stringify(static_cast<EnumTy>(i)));
146
147 int index = parseOptionalKeywordAlternative(parser, names);
148 if (index == -1)
149 return static_cast<RetTy>(defaultValue);
150 return static_cast<RetTy>(index);
151}
152
153/// Parse an enum from the keyword, return failure if the keyword is not found.
154template <typename EnumTy, typename RetTy = EnumTy>
155static ParseResult parseCIRKeyword(AsmParser &parser, RetTy &result) {
157 for (unsigned i = 0, e = EnumTraits<EnumTy>::getMaxEnumVal(); i <= e; ++i)
158 names.push_back(EnumTraits<EnumTy>::stringify(static_cast<EnumTy>(i)));
159
160 int index = parseOptionalKeywordAlternative(parser, names);
161 if (index == -1)
162 return failure();
163 result = static_cast<RetTy>(index);
164 return success();
165}
166
167// Check if a region's termination omission is valid and, if so, creates and
168// inserts the omitted terminator into the region.
169static LogicalResult ensureRegionTerm(OpAsmParser &parser, Region &region,
170 SMLoc errLoc) {
171 Location eLoc = parser.getEncodedSourceLoc(parser.getCurrentLocation());
172 OpBuilder builder(parser.getBuilder().getContext());
173
174 // Insert empty block in case the region is empty to ensure the terminator
175 // will be inserted
176 if (region.empty())
177 builder.createBlock(&region);
178
179 Block &block = region.back();
180 // Region is properly terminated: nothing to do.
181 if (!block.empty() && block.back().hasTrait<OpTrait::IsTerminator>())
182 return success();
183
184 // Check for invalid terminator omissions.
185 if (!region.hasOneBlock())
186 return parser.emitError(errLoc,
187 "multi-block region must not omit terminator");
188
189 // Terminator was omitted correctly: recreate it.
190 builder.setInsertionPointToEnd(&block);
191 cir::YieldOp::create(builder, eLoc);
192 return success();
193}
194
195// True if the region's terminator should be omitted.
196static bool omitRegionTerm(mlir::Region &r) {
197 const auto singleNonEmptyBlock = r.hasOneBlock() && !r.back().empty();
198 const auto yieldsNothing = [&r]() {
199 auto y = dyn_cast<cir::YieldOp>(r.back().getTerminator());
200 return y && y.getArgs().empty();
201 };
202 return singleNonEmptyBlock && yieldsNothing();
203}
204
205void printVisibilityAttr(OpAsmPrinter &printer,
206 cir::VisibilityAttr &visibility) {
207 switch (visibility.getValue()) {
208 case cir::VisibilityKind::Hidden:
209 printer << "hidden";
210 break;
211 case cir::VisibilityKind::Protected:
212 printer << "protected";
213 break;
214 case cir::VisibilityKind::Default:
215 break;
216 }
217}
218
219void parseVisibilityAttr(OpAsmParser &parser, cir::VisibilityAttr &visibility) {
220 cir::VisibilityKind visibilityKind =
221 parseOptionalCIRKeyword(parser, cir::VisibilityKind::Default);
222 visibility = cir::VisibilityAttr::get(parser.getContext(), visibilityKind);
223}
224
225//===----------------------------------------------------------------------===//
226// InlineKindAttr (FIXME: remove once FuncOp uses assembly format)
227//===----------------------------------------------------------------------===//
228
229ParseResult parseInlineKindAttr(OpAsmParser &parser,
230 cir::InlineKindAttr &inlineKindAttr) {
231 // Static list of possible inline kind keywords
232 static constexpr llvm::StringRef keywords[] = {"no_inline", "always_inline",
233 "inline_hint"};
234
235 // Parse the inline kind keyword (optional)
236 llvm::StringRef keyword;
237 if (parser.parseOptionalKeyword(&keyword, keywords).failed()) {
238 // Not an inline kind keyword, leave inlineKindAttr empty
239 return success();
240 }
241
242 // Parse the enum value from the keyword
243 auto inlineKindResult = ::cir::symbolizeEnum<::cir::InlineKind>(keyword);
244 if (!inlineKindResult) {
245 return parser.emitError(parser.getCurrentLocation(), "expected one of [")
246 << llvm::join(llvm::ArrayRef(keywords), ", ")
247 << "] for inlineKind, got: " << keyword;
248 }
249
250 inlineKindAttr =
251 ::cir::InlineKindAttr::get(parser.getContext(), *inlineKindResult);
252 return success();
253}
254
255void printInlineKindAttr(OpAsmPrinter &p, cir::InlineKindAttr inlineKindAttr) {
256 if (inlineKindAttr) {
257 p << " " << stringifyInlineKind(inlineKindAttr.getValue());
258 }
259}
260//===----------------------------------------------------------------------===//
261// CIR Custom Parsers/Printers
262//===----------------------------------------------------------------------===//
263
264static mlir::ParseResult parseOmittedTerminatorRegion(mlir::OpAsmParser &parser,
265 mlir::Region &region) {
266 auto regionLoc = parser.getCurrentLocation();
267 if (parser.parseRegion(region))
268 return failure();
269 if (ensureRegionTerm(parser, region, regionLoc).failed())
270 return failure();
271 return success();
272}
273
274static void printOmittedTerminatorRegion(mlir::OpAsmPrinter &printer,
275 cir::ScopeOp &op,
276 mlir::Region &region) {
277 printer.printRegion(region,
278 /*printEntryBlockArgs=*/false,
279 /*printBlockTerminators=*/!omitRegionTerm(region));
280}
281
282//===----------------------------------------------------------------------===//
283// AllocaOp
284//===----------------------------------------------------------------------===//
285
286void cir::AllocaOp::build(mlir::OpBuilder &odsBuilder,
287 mlir::OperationState &odsState, mlir::Type addr,
288 mlir::Type allocaType, llvm::StringRef name,
289 mlir::IntegerAttr alignment) {
290 odsState.addAttribute(getAllocaTypeAttrName(odsState.name),
291 mlir::TypeAttr::get(allocaType));
292 odsState.addAttribute(getNameAttrName(odsState.name),
293 odsBuilder.getStringAttr(name));
294 if (alignment) {
295 odsState.addAttribute(getAlignmentAttrName(odsState.name), alignment);
296 }
297 odsState.addTypes(addr);
298}
299
300//===----------------------------------------------------------------------===//
301// BreakOp
302//===----------------------------------------------------------------------===//
303
304LogicalResult cir::BreakOp::verify() {
306 if (!getOperation()->getParentOfType<LoopOpInterface>() &&
307 !getOperation()->getParentOfType<SwitchOp>())
308 return emitOpError("must be within a loop");
309 return success();
310}
311
312//===----------------------------------------------------------------------===//
313// ConditionOp
314//===----------------------------------------------------------------------===//
315
316//===----------------------------------
317// BranchOpTerminatorInterface Methods
318//===----------------------------------
319
320void cir::ConditionOp::getSuccessorRegions(
321 ArrayRef<Attribute> operands, SmallVectorImpl<RegionSuccessor> &regions) {
322 // TODO(cir): The condition value may be folded to a constant, narrowing
323 // down its list of possible successors.
324
325 // Parent is a loop: condition may branch to the body or to the parent op.
326 if (auto loopOp = dyn_cast<LoopOpInterface>(getOperation()->getParentOp())) {
327 regions.emplace_back(&loopOp.getBody());
328 regions.push_back(RegionSuccessor::parent());
329 }
330
331 // Parent is an await: condition may branch to resume or suspend regions.
332 auto await = cast<AwaitOp>(getOperation()->getParentOp());
333 regions.emplace_back(&await.getResume());
334 regions.emplace_back(&await.getSuspend());
335}
336
337MutableOperandRange
338cir::ConditionOp::getMutableSuccessorOperands(RegionSuccessor point) {
339 // No values are yielded to the successor region.
340 return MutableOperandRange(getOperation(), 0, 0);
341}
342
343LogicalResult cir::ConditionOp::verify() {
344 if (!isa<LoopOpInterface, AwaitOp>(getOperation()->getParentOp()))
345 return emitOpError("condition must be within a conditional region");
346 return success();
347}
348
349//===----------------------------------------------------------------------===//
350// ConstantOp
351//===----------------------------------------------------------------------===//
352
353static LogicalResult checkConstantTypes(mlir::Operation *op, mlir::Type opType,
354 mlir::Attribute attrType) {
355 if (isa<cir::ConstPtrAttr>(attrType)) {
356 if (!mlir::isa<cir::PointerType>(opType))
357 return op->emitOpError(
358 "pointer constant initializing a non-pointer type");
359 return success();
360 }
361
362 if (isa<cir::DataMemberAttr, cir::MethodAttr>(attrType)) {
363 // More detailed type verifications are already done in
364 // DataMemberAttr::verify or MethodAttr::verify. Don't need to repeat here.
365 return success();
366 }
367
368 if (isa<cir::ZeroAttr>(attrType)) {
369 if (isa<cir::RecordType, cir::ArrayType, cir::VectorType, cir::ComplexType>(
370 opType))
371 return success();
372 return op->emitOpError(
373 "zero expects struct, array, vector, or complex type");
374 }
375
376 if (mlir::isa<cir::UndefAttr>(attrType)) {
377 if (!mlir::isa<cir::VoidType>(opType))
378 return success();
379 return op->emitOpError("undef expects non-void type");
380 }
381
382 if (mlir::isa<cir::BoolAttr>(attrType)) {
383 if (!mlir::isa<cir::BoolType>(opType))
384 return op->emitOpError("result type (")
385 << opType << ") must be '!cir.bool' for '" << attrType << "'";
386 return success();
387 }
388
389 if (mlir::isa<cir::IntAttr, cir::FPAttr>(attrType)) {
390 auto at = cast<TypedAttr>(attrType);
391 if (at.getType() != opType) {
392 return op->emitOpError("result type (")
393 << opType << ") does not match value type (" << at.getType()
394 << ")";
395 }
396 return success();
397 }
398
399 if (mlir::isa<cir::ConstArrayAttr, cir::ConstVectorAttr,
400 cir::ConstComplexAttr, cir::ConstRecordAttr,
401 cir::GlobalViewAttr, cir::PoisonAttr, cir::TypeInfoAttr,
402 cir::VTableAttr>(attrType))
403 return success();
404
405 assert(isa<TypedAttr>(attrType) && "What else could we be looking at here?");
406 return op->emitOpError("global with type ")
407 << cast<TypedAttr>(attrType).getType() << " not yet supported";
408}
409
410LogicalResult cir::ConstantOp::verify() {
411 // ODS already generates checks to make sure the result type is valid. We just
412 // need to additionally check that the value's attribute type is consistent
413 // with the result type.
414 return checkConstantTypes(getOperation(), getType(), getValue());
415}
416
417OpFoldResult cir::ConstantOp::fold(FoldAdaptor /*adaptor*/) {
418 return getValue();
419}
420
421//===----------------------------------------------------------------------===//
422// ContinueOp
423//===----------------------------------------------------------------------===//
424
425LogicalResult cir::ContinueOp::verify() {
426 if (!getOperation()->getParentOfType<LoopOpInterface>())
427 return emitOpError("must be within a loop");
428 return success();
429}
430
431//===----------------------------------------------------------------------===//
432// CastOp
433//===----------------------------------------------------------------------===//
434
435LogicalResult cir::CastOp::verify() {
436 mlir::Type resType = getType();
437 mlir::Type srcType = getSrc().getType();
438
439 // Verify address space casts for pointer types. given that
440 // casts for within a different address space are illegal.
441 auto srcPtrTy = mlir::dyn_cast<cir::PointerType>(srcType);
442 auto resPtrTy = mlir::dyn_cast<cir::PointerType>(resType);
443 if (srcPtrTy && resPtrTy && (getKind() != cir::CastKind::address_space))
444 if (srcPtrTy.getAddrSpace() != resPtrTy.getAddrSpace()) {
445 return emitOpError() << "result type address space does not match the "
446 "address space of the operand";
447 }
448
449 if (mlir::isa<cir::VectorType>(srcType) &&
450 mlir::isa<cir::VectorType>(resType)) {
451 // Use the element type of the vector to verify the cast kind. (Except for
452 // bitcast, see below.)
453 srcType = mlir::dyn_cast<cir::VectorType>(srcType).getElementType();
454 resType = mlir::dyn_cast<cir::VectorType>(resType).getElementType();
455 }
456
457 switch (getKind()) {
458 case cir::CastKind::int_to_bool: {
459 if (!mlir::isa<cir::BoolType>(resType))
460 return emitOpError() << "requires !cir.bool type for result";
461 if (!mlir::isa<cir::IntType>(srcType))
462 return emitOpError() << "requires !cir.int type for source";
463 return success();
464 }
465 case cir::CastKind::ptr_to_bool: {
466 if (!mlir::isa<cir::BoolType>(resType))
467 return emitOpError() << "requires !cir.bool type for result";
468 if (!mlir::isa<cir::PointerType>(srcType))
469 return emitOpError() << "requires !cir.ptr type for source";
470 return success();
471 }
472 case cir::CastKind::integral: {
473 if (!mlir::isa<cir::IntType>(resType))
474 return emitOpError() << "requires !cir.int type for result";
475 if (!mlir::isa<cir::IntType>(srcType))
476 return emitOpError() << "requires !cir.int type for source";
477 return success();
478 }
479 case cir::CastKind::array_to_ptrdecay: {
480 const auto arrayPtrTy = mlir::dyn_cast<cir::PointerType>(srcType);
481 const auto flatPtrTy = mlir::dyn_cast<cir::PointerType>(resType);
482 if (!arrayPtrTy || !flatPtrTy)
483 return emitOpError() << "requires !cir.ptr type for source and result";
484
485 // TODO(CIR): Make sure the AddrSpace of both types are equals
486 return success();
487 }
488 case cir::CastKind::bitcast: {
489 // Handle the pointer types first.
490 auto srcPtrTy = mlir::dyn_cast<cir::PointerType>(srcType);
491 auto resPtrTy = mlir::dyn_cast<cir::PointerType>(resType);
492
493 if (srcPtrTy && resPtrTy) {
494 return success();
495 }
496
497 return success();
498 }
499 case cir::CastKind::floating: {
500 if (!mlir::isa<cir::FPTypeInterface>(srcType) ||
501 !mlir::isa<cir::FPTypeInterface>(resType))
502 return emitOpError() << "requires !cir.float type for source and result";
503 return success();
504 }
505 case cir::CastKind::float_to_int: {
506 if (!mlir::isa<cir::FPTypeInterface>(srcType))
507 return emitOpError() << "requires !cir.float type for source";
508 if (!mlir::dyn_cast<cir::IntType>(resType))
509 return emitOpError() << "requires !cir.int type for result";
510 return success();
511 }
512 case cir::CastKind::int_to_ptr: {
513 if (!mlir::dyn_cast<cir::IntType>(srcType))
514 return emitOpError() << "requires !cir.int type for source";
515 if (!mlir::dyn_cast<cir::PointerType>(resType))
516 return emitOpError() << "requires !cir.ptr type for result";
517 return success();
518 }
519 case cir::CastKind::ptr_to_int: {
520 if (!mlir::dyn_cast<cir::PointerType>(srcType))
521 return emitOpError() << "requires !cir.ptr type for source";
522 if (!mlir::dyn_cast<cir::IntType>(resType))
523 return emitOpError() << "requires !cir.int type for result";
524 return success();
525 }
526 case cir::CastKind::float_to_bool: {
527 if (!mlir::isa<cir::FPTypeInterface>(srcType))
528 return emitOpError() << "requires !cir.float type for source";
529 if (!mlir::isa<cir::BoolType>(resType))
530 return emitOpError() << "requires !cir.bool type for result";
531 return success();
532 }
533 case cir::CastKind::bool_to_int: {
534 if (!mlir::isa<cir::BoolType>(srcType))
535 return emitOpError() << "requires !cir.bool type for source";
536 if (!mlir::isa<cir::IntType>(resType))
537 return emitOpError() << "requires !cir.int type for result";
538 return success();
539 }
540 case cir::CastKind::int_to_float: {
541 if (!mlir::isa<cir::IntType>(srcType))
542 return emitOpError() << "requires !cir.int type for source";
543 if (!mlir::isa<cir::FPTypeInterface>(resType))
544 return emitOpError() << "requires !cir.float type for result";
545 return success();
546 }
547 case cir::CastKind::bool_to_float: {
548 if (!mlir::isa<cir::BoolType>(srcType))
549 return emitOpError() << "requires !cir.bool type for source";
550 if (!mlir::isa<cir::FPTypeInterface>(resType))
551 return emitOpError() << "requires !cir.float type for result";
552 return success();
553 }
554 case cir::CastKind::address_space: {
555 auto srcPtrTy = mlir::dyn_cast<cir::PointerType>(srcType);
556 auto resPtrTy = mlir::dyn_cast<cir::PointerType>(resType);
557 if (!srcPtrTy || !resPtrTy)
558 return emitOpError() << "requires !cir.ptr type for source and result";
559 if (srcPtrTy.getPointee() != resPtrTy.getPointee())
560 return emitOpError() << "requires two types differ in addrspace only";
561 return success();
562 }
563 case cir::CastKind::float_to_complex: {
564 if (!mlir::isa<cir::FPTypeInterface>(srcType))
565 return emitOpError() << "requires !cir.float type for source";
566 auto resComplexTy = mlir::dyn_cast<cir::ComplexType>(resType);
567 if (!resComplexTy)
568 return emitOpError() << "requires !cir.complex type for result";
569 if (srcType != resComplexTy.getElementType())
570 return emitOpError() << "requires source type match result element type";
571 return success();
572 }
573 case cir::CastKind::int_to_complex: {
574 if (!mlir::isa<cir::IntType>(srcType))
575 return emitOpError() << "requires !cir.int type for source";
576 auto resComplexTy = mlir::dyn_cast<cir::ComplexType>(resType);
577 if (!resComplexTy)
578 return emitOpError() << "requires !cir.complex type for result";
579 if (srcType != resComplexTy.getElementType())
580 return emitOpError() << "requires source type match result element type";
581 return success();
582 }
583 case cir::CastKind::float_complex_to_real: {
584 auto srcComplexTy = mlir::dyn_cast<cir::ComplexType>(srcType);
585 if (!srcComplexTy)
586 return emitOpError() << "requires !cir.complex type for source";
587 if (!mlir::isa<cir::FPTypeInterface>(resType))
588 return emitOpError() << "requires !cir.float type for result";
589 if (srcComplexTy.getElementType() != resType)
590 return emitOpError() << "requires source element type match result type";
591 return success();
592 }
593 case cir::CastKind::int_complex_to_real: {
594 auto srcComplexTy = mlir::dyn_cast<cir::ComplexType>(srcType);
595 if (!srcComplexTy)
596 return emitOpError() << "requires !cir.complex type for source";
597 if (!mlir::isa<cir::IntType>(resType))
598 return emitOpError() << "requires !cir.int type for result";
599 if (srcComplexTy.getElementType() != resType)
600 return emitOpError() << "requires source element type match result type";
601 return success();
602 }
603 case cir::CastKind::float_complex_to_bool: {
604 auto srcComplexTy = mlir::dyn_cast<cir::ComplexType>(srcType);
605 if (!srcComplexTy || !srcComplexTy.isFloatingPointComplex())
606 return emitOpError()
607 << "requires floating point !cir.complex type for source";
608 if (!mlir::isa<cir::BoolType>(resType))
609 return emitOpError() << "requires !cir.bool type for result";
610 return success();
611 }
612 case cir::CastKind::int_complex_to_bool: {
613 auto srcComplexTy = mlir::dyn_cast<cir::ComplexType>(srcType);
614 if (!srcComplexTy || !srcComplexTy.isIntegerComplex())
615 return emitOpError()
616 << "requires floating point !cir.complex type for source";
617 if (!mlir::isa<cir::BoolType>(resType))
618 return emitOpError() << "requires !cir.bool type for result";
619 return success();
620 }
621 case cir::CastKind::float_complex: {
622 auto srcComplexTy = mlir::dyn_cast<cir::ComplexType>(srcType);
623 if (!srcComplexTy || !srcComplexTy.isFloatingPointComplex())
624 return emitOpError()
625 << "requires floating point !cir.complex type for source";
626 auto resComplexTy = mlir::dyn_cast<cir::ComplexType>(resType);
627 if (!resComplexTy || !resComplexTy.isFloatingPointComplex())
628 return emitOpError()
629 << "requires floating point !cir.complex type for result";
630 return success();
631 }
632 case cir::CastKind::float_complex_to_int_complex: {
633 auto srcComplexTy = mlir::dyn_cast<cir::ComplexType>(srcType);
634 if (!srcComplexTy || !srcComplexTy.isFloatingPointComplex())
635 return emitOpError()
636 << "requires floating point !cir.complex type for source";
637 auto resComplexTy = mlir::dyn_cast<cir::ComplexType>(resType);
638 if (!resComplexTy || !resComplexTy.isIntegerComplex())
639 return emitOpError() << "requires integer !cir.complex type for result";
640 return success();
641 }
642 case cir::CastKind::int_complex: {
643 auto srcComplexTy = mlir::dyn_cast<cir::ComplexType>(srcType);
644 if (!srcComplexTy || !srcComplexTy.isIntegerComplex())
645 return emitOpError() << "requires integer !cir.complex type for source";
646 auto resComplexTy = mlir::dyn_cast<cir::ComplexType>(resType);
647 if (!resComplexTy || !resComplexTy.isIntegerComplex())
648 return emitOpError() << "requires integer !cir.complex type for result";
649 return success();
650 }
651 case cir::CastKind::int_complex_to_float_complex: {
652 auto srcComplexTy = mlir::dyn_cast<cir::ComplexType>(srcType);
653 if (!srcComplexTy || !srcComplexTy.isIntegerComplex())
654 return emitOpError() << "requires integer !cir.complex type for source";
655 auto resComplexTy = mlir::dyn_cast<cir::ComplexType>(resType);
656 if (!resComplexTy || !resComplexTy.isFloatingPointComplex())
657 return emitOpError()
658 << "requires floating point !cir.complex type for result";
659 return success();
660 }
661 case cir::CastKind::member_ptr_to_bool: {
662 if (!mlir::isa<cir::DataMemberType, cir::MethodType>(srcType))
663 return emitOpError()
664 << "requires !cir.data_member or !cir.method type for source";
665 if (!mlir::isa<cir::BoolType>(resType))
666 return emitOpError() << "requires !cir.bool type for result";
667 return success();
668 }
669 }
670 llvm_unreachable("Unknown CastOp kind?");
671}
672
673static bool isIntOrBoolCast(cir::CastOp op) {
674 auto kind = op.getKind();
675 return kind == cir::CastKind::bool_to_int ||
676 kind == cir::CastKind::int_to_bool || kind == cir::CastKind::integral;
677}
678
679static Value tryFoldCastChain(cir::CastOp op) {
680 cir::CastOp head = op, tail = op;
681
682 while (op) {
683 if (!isIntOrBoolCast(op))
684 break;
685 head = op;
686 op = head.getSrc().getDefiningOp<cir::CastOp>();
687 }
688
689 if (head == tail)
690 return {};
691
692 // if bool_to_int -> ... -> int_to_bool: take the bool
693 // as we had it was before all casts
694 if (head.getKind() == cir::CastKind::bool_to_int &&
695 tail.getKind() == cir::CastKind::int_to_bool)
696 return head.getSrc();
697
698 // if int_to_bool -> ... -> int_to_bool: take the result
699 // of the first one, as no other casts (and ext casts as well)
700 // don't change the first result
701 if (head.getKind() == cir::CastKind::int_to_bool &&
702 tail.getKind() == cir::CastKind::int_to_bool)
703 return head.getResult();
704
705 return {};
706}
707
708OpFoldResult cir::CastOp::fold(FoldAdaptor adaptor) {
709 if (mlir::isa_and_present<cir::PoisonAttr>(adaptor.getSrc())) {
710 // Propagate poison value
711 return cir::PoisonAttr::get(getContext(), getType());
712 }
713
714 if (getSrc().getType() == getType()) {
715 switch (getKind()) {
716 case cir::CastKind::integral: {
718 auto foldOrder = getSrc().getDefiningOp()->fold(foldResults);
719 if (foldOrder.succeeded() && mlir::isa<mlir::Attribute>(foldResults[0]))
720 return mlir::cast<mlir::Attribute>(foldResults[0]);
721 return {};
722 }
723 case cir::CastKind::bitcast:
724 case cir::CastKind::address_space:
725 case cir::CastKind::float_complex:
726 case cir::CastKind::int_complex: {
727 return getSrc();
728 }
729 default:
730 return {};
731 }
732 }
733
734 // Handle cases where a chain of casts cancel out.
735 Value result = tryFoldCastChain(*this);
736 if (result)
737 return result;
738
739 // Handle simple constant casts.
740 if (auto srcConst = getSrc().getDefiningOp<cir::ConstantOp>()) {
741 switch (getKind()) {
742 case cir::CastKind::integral: {
743 mlir::Type srcTy = getSrc().getType();
744 // Don't try to fold vector casts for now.
745 assert(mlir::isa<cir::VectorType>(srcTy) ==
746 mlir::isa<cir::VectorType>(getType()));
747 if (mlir::isa<cir::VectorType>(srcTy))
748 break;
749
750 auto srcIntTy = mlir::cast<cir::IntType>(srcTy);
751 auto dstIntTy = mlir::cast<cir::IntType>(getType());
752 APInt newVal =
753 srcIntTy.isSigned()
754 ? srcConst.getIntValue().sextOrTrunc(dstIntTy.getWidth())
755 : srcConst.getIntValue().zextOrTrunc(dstIntTy.getWidth());
756 return cir::IntAttr::get(dstIntTy, newVal);
757 }
758 default:
759 break;
760 }
761 }
762 return {};
763}
764
765//===----------------------------------------------------------------------===//
766// CallOp
767//===----------------------------------------------------------------------===//
768
769mlir::OperandRange cir::CallOp::getArgOperands() {
770 if (isIndirect())
771 return getArgs().drop_front(1);
772 return getArgs();
773}
774
775mlir::MutableOperandRange cir::CallOp::getArgOperandsMutable() {
776 mlir::MutableOperandRange args = getArgsMutable();
777 if (isIndirect())
778 return args.slice(1, args.size() - 1);
779 return args;
780}
781
782mlir::Value cir::CallOp::getIndirectCall() {
783 assert(isIndirect());
784 return getOperand(0);
785}
786
787/// Return the operand at index 'i'.
788Value cir::CallOp::getArgOperand(unsigned i) {
789 if (isIndirect())
790 ++i;
791 return getOperand(i);
792}
793
794/// Return the number of operands.
795unsigned cir::CallOp::getNumArgOperands() {
796 if (isIndirect())
797 return this->getOperation()->getNumOperands() - 1;
798 return this->getOperation()->getNumOperands();
799}
800
801static mlir::ParseResult
802parseTryCallDestinations(mlir::OpAsmParser &parser,
803 mlir::OperationState &result) {
804 mlir::Block *normalDestSuccessor;
805 if (parser.parseSuccessor(normalDestSuccessor))
806 return mlir::failure();
807
808 if (parser.parseComma())
809 return mlir::failure();
810
811 mlir::Block *unwindDestSuccessor;
812 if (parser.parseSuccessor(unwindDestSuccessor))
813 return mlir::failure();
814
815 result.addSuccessors(normalDestSuccessor);
816 result.addSuccessors(unwindDestSuccessor);
817 return mlir::success();
818}
819
820static mlir::ParseResult parseCallCommon(mlir::OpAsmParser &parser,
821 mlir::OperationState &result,
822 bool hasDestinationBlocks = false) {
824 llvm::SMLoc opsLoc;
825 mlir::FlatSymbolRefAttr calleeAttr;
826 llvm::ArrayRef<mlir::Type> allResultTypes;
827
828 // If we cannot parse a string callee, it means this is an indirect call.
829 if (!parser
830 .parseOptionalAttribute(calleeAttr, CIRDialect::getCalleeAttrName(),
831 result.attributes)
832 .has_value()) {
833 OpAsmParser::UnresolvedOperand indirectVal;
834 // Do not resolve right now, since we need to figure out the type
835 if (parser.parseOperand(indirectVal).failed())
836 return failure();
837 ops.push_back(indirectVal);
838 }
839
840 if (parser.parseLParen())
841 return mlir::failure();
842
843 opsLoc = parser.getCurrentLocation();
844 if (parser.parseOperandList(ops))
845 return mlir::failure();
846 if (parser.parseRParen())
847 return mlir::failure();
848
849 if (hasDestinationBlocks &&
850 parseTryCallDestinations(parser, result).failed()) {
851 return ::mlir::failure();
852 }
853
854 if (parser.parseOptionalKeyword("nothrow").succeeded())
855 result.addAttribute(CIRDialect::getNoThrowAttrName(),
856 mlir::UnitAttr::get(parser.getContext()));
857
858 if (parser.parseOptionalKeyword("side_effect").succeeded()) {
859 if (parser.parseLParen().failed())
860 return failure();
861 cir::SideEffect sideEffect;
862 if (parseCIRKeyword<cir::SideEffect>(parser, sideEffect).failed())
863 return failure();
864 if (parser.parseRParen().failed())
865 return failure();
866 auto attr = cir::SideEffectAttr::get(parser.getContext(), sideEffect);
867 result.addAttribute(CIRDialect::getSideEffectAttrName(), attr);
868 }
869
870 if (parser.parseOptionalAttrDict(result.attributes))
871 return ::mlir::failure();
872
873 if (parser.parseColon())
874 return ::mlir::failure();
875
876 mlir::FunctionType opsFnTy;
877 if (parser.parseType(opsFnTy))
878 return mlir::failure();
879
880 allResultTypes = opsFnTy.getResults();
881 result.addTypes(allResultTypes);
882
883 if (parser.resolveOperands(ops, opsFnTy.getInputs(), opsLoc, result.operands))
884 return mlir::failure();
885
886 return mlir::success();
887}
888
889static void printCallCommon(mlir::Operation *op,
890 mlir::FlatSymbolRefAttr calleeSym,
891 mlir::Value indirectCallee,
892 mlir::OpAsmPrinter &printer, bool isNothrow,
893 cir::SideEffect sideEffect,
894 mlir::Block *normalDest = nullptr,
895 mlir::Block *unwindDest = nullptr) {
896 printer << ' ';
897
898 auto callLikeOp = mlir::cast<cir::CIRCallOpInterface>(op);
899 auto ops = callLikeOp.getArgOperands();
900
901 if (calleeSym) {
902 // Direct calls
903 printer.printAttributeWithoutType(calleeSym);
904 } else {
905 // Indirect calls
906 assert(indirectCallee);
907 printer << indirectCallee;
908 }
909
910 printer << "(" << ops << ")";
911
912 if (normalDest) {
913 assert(unwindDest && "expected two successors");
914 auto tryCall = cast<cir::TryCallOp>(op);
915 printer << ' ' << tryCall.getNormalDest();
916 printer << ",";
917 printer << ' ';
918 printer << tryCall.getUnwindDest();
919 }
920
921 if (isNothrow)
922 printer << " nothrow";
923
924 if (sideEffect != cir::SideEffect::All) {
925 printer << " side_effect(";
926 printer << stringifySideEffect(sideEffect);
927 printer << ")";
928 }
929
931 CIRDialect::getCalleeAttrName(), CIRDialect::getNoThrowAttrName(),
932 CIRDialect::getSideEffectAttrName(),
933 CIRDialect::getOperandSegmentSizesAttrName()};
934 printer.printOptionalAttrDict(op->getAttrs(), elidedAttrs);
935 printer << " : ";
936 printer.printFunctionalType(op->getOperands().getTypes(),
937 op->getResultTypes());
938}
939
940mlir::ParseResult cir::CallOp::parse(mlir::OpAsmParser &parser,
941 mlir::OperationState &result) {
942 return parseCallCommon(parser, result);
943}
944
945void cir::CallOp::print(mlir::OpAsmPrinter &p) {
946 mlir::Value indirectCallee = isIndirect() ? getIndirectCall() : nullptr;
947 cir::SideEffect sideEffect = getSideEffect();
948 printCallCommon(*this, getCalleeAttr(), indirectCallee, p, getNothrow(),
949 sideEffect);
950}
951
952static LogicalResult
953verifyCallCommInSymbolUses(mlir::Operation *op,
954 SymbolTableCollection &symbolTable) {
955 auto fnAttr =
956 op->getAttrOfType<FlatSymbolRefAttr>(CIRDialect::getCalleeAttrName());
957 if (!fnAttr) {
958 // This is an indirect call, thus we don't have to check the symbol uses.
959 return mlir::success();
960 }
961
962 auto fn = symbolTable.lookupNearestSymbolFrom<cir::FuncOp>(op, fnAttr);
963 if (!fn)
964 return op->emitOpError() << "'" << fnAttr.getValue()
965 << "' does not reference a valid function";
966
967 auto callIf = dyn_cast<cir::CIRCallOpInterface>(op);
968 assert(callIf && "expected CIR call interface to be always available");
969
970 // Verify that the operand and result types match the callee. Note that
971 // argument-checking is disabled for functions without a prototype.
972 auto fnType = fn.getFunctionType();
973 if (!fn.getNoProto()) {
974 unsigned numCallOperands = callIf.getNumArgOperands();
975 unsigned numFnOpOperands = fnType.getNumInputs();
976
977 if (!fnType.isVarArg() && numCallOperands != numFnOpOperands)
978 return op->emitOpError("incorrect number of operands for callee");
979 if (fnType.isVarArg() && numCallOperands < numFnOpOperands)
980 return op->emitOpError("too few operands for callee");
981
982 for (unsigned i = 0, e = numFnOpOperands; i != e; ++i)
983 if (callIf.getArgOperand(i).getType() != fnType.getInput(i))
984 return op->emitOpError("operand type mismatch: expected operand type ")
985 << fnType.getInput(i) << ", but provided "
986 << op->getOperand(i).getType() << " for operand number " << i;
987 }
988
990
991 // Void function must not return any results.
992 if (fnType.hasVoidReturn() && op->getNumResults() != 0)
993 return op->emitOpError("callee returns void but call has results");
994
995 // Non-void function calls must return exactly one result.
996 if (!fnType.hasVoidReturn() && op->getNumResults() != 1)
997 return op->emitOpError("incorrect number of results for callee");
998
999 // Parent function and return value types must match.
1000 if (!fnType.hasVoidReturn() &&
1001 op->getResultTypes().front() != fnType.getReturnType()) {
1002 return op->emitOpError("result type mismatch: expected ")
1003 << fnType.getReturnType() << ", but provided "
1004 << op->getResult(0).getType();
1005 }
1006
1007 return mlir::success();
1008}
1009
1010LogicalResult
1011cir::CallOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1012 return verifyCallCommInSymbolUses(*this, symbolTable);
1013}
1014
1015//===----------------------------------------------------------------------===//
1016// TryCallOp
1017//===----------------------------------------------------------------------===//
1018
1019mlir::OperandRange cir::TryCallOp::getArgOperands() {
1020 if (isIndirect())
1021 return getArgs().drop_front(1);
1022 return getArgs();
1023}
1024
1025mlir::MutableOperandRange cir::TryCallOp::getArgOperandsMutable() {
1026 mlir::MutableOperandRange args = getArgsMutable();
1027 if (isIndirect())
1028 return args.slice(1, args.size() - 1);
1029 return args;
1030}
1031
1032mlir::Value cir::TryCallOp::getIndirectCall() {
1033 assert(isIndirect());
1034 return getOperand(0);
1035}
1036
1037/// Return the operand at index 'i'.
1038Value cir::TryCallOp::getArgOperand(unsigned i) {
1039 if (isIndirect())
1040 ++i;
1041 return getOperand(i);
1042}
1043
1044/// Return the number of operands.
1045unsigned cir::TryCallOp::getNumArgOperands() {
1046 if (isIndirect())
1047 return this->getOperation()->getNumOperands() - 1;
1048 return this->getOperation()->getNumOperands();
1049}
1050
1051LogicalResult
1052cir::TryCallOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1053 return verifyCallCommInSymbolUses(*this, symbolTable);
1054}
1055
1056mlir::ParseResult cir::TryCallOp::parse(mlir::OpAsmParser &parser,
1057 mlir::OperationState &result) {
1058 return parseCallCommon(parser, result, /*hasDestinationBlocks=*/true);
1059}
1060
1061void cir::TryCallOp::print(::mlir::OpAsmPrinter &p) {
1062 mlir::Value indirectCallee = isIndirect() ? getIndirectCall() : nullptr;
1063 cir::SideEffect sideEffect = getSideEffect();
1064 printCallCommon(*this, getCalleeAttr(), indirectCallee, p, getNothrow(),
1065 sideEffect, getNormalDest(), getUnwindDest());
1066}
1067
1068//===----------------------------------------------------------------------===//
1069// ReturnOp
1070//===----------------------------------------------------------------------===//
1071
1072static mlir::LogicalResult checkReturnAndFunction(cir::ReturnOp op,
1073 cir::FuncOp function) {
1074 // ReturnOps currently only have a single optional operand.
1075 if (op.getNumOperands() > 1)
1076 return op.emitOpError() << "expects at most 1 return operand";
1077
1078 // Ensure returned type matches the function signature.
1079 auto expectedTy = function.getFunctionType().getReturnType();
1080 auto actualTy =
1081 (op.getNumOperands() == 0 ? cir::VoidType::get(op.getContext())
1082 : op.getOperand(0).getType());
1083 if (actualTy != expectedTy)
1084 return op.emitOpError() << "returns " << actualTy
1085 << " but enclosing function returns " << expectedTy;
1086
1087 return mlir::success();
1088}
1089
1090mlir::LogicalResult cir::ReturnOp::verify() {
1091 // Returns can be present in multiple different scopes, get the
1092 // wrapping function and start from there.
1093 auto *fnOp = getOperation()->getParentOp();
1094 while (!isa<cir::FuncOp>(fnOp))
1095 fnOp = fnOp->getParentOp();
1096
1097 // Make sure return types match function return type.
1098 if (checkReturnAndFunction(*this, cast<cir::FuncOp>(fnOp)).failed())
1099 return failure();
1100
1101 return success();
1102}
1103
1104//===----------------------------------------------------------------------===//
1105// IfOp
1106//===----------------------------------------------------------------------===//
1107
1108ParseResult cir::IfOp::parse(OpAsmParser &parser, OperationState &result) {
1109 // create the regions for 'then'.
1110 result.regions.reserve(2);
1111 Region *thenRegion = result.addRegion();
1112 Region *elseRegion = result.addRegion();
1113
1114 mlir::Builder &builder = parser.getBuilder();
1115 OpAsmParser::UnresolvedOperand cond;
1116 Type boolType = cir::BoolType::get(builder.getContext());
1117
1118 if (parser.parseOperand(cond) ||
1119 parser.resolveOperand(cond, boolType, result.operands))
1120 return failure();
1121
1122 // Parse 'then' region.
1123 mlir::SMLoc parseThenLoc = parser.getCurrentLocation();
1124 if (parser.parseRegion(*thenRegion, /*arguments=*/{}, /*argTypes=*/{}))
1125 return failure();
1126
1127 if (ensureRegionTerm(parser, *thenRegion, parseThenLoc).failed())
1128 return failure();
1129
1130 // If we find an 'else' keyword, parse the 'else' region.
1131 if (!parser.parseOptionalKeyword("else")) {
1132 mlir::SMLoc parseElseLoc = parser.getCurrentLocation();
1133 if (parser.parseRegion(*elseRegion, /*arguments=*/{}, /*argTypes=*/{}))
1134 return failure();
1135 if (ensureRegionTerm(parser, *elseRegion, parseElseLoc).failed())
1136 return failure();
1137 }
1138
1139 // Parse the optional attribute list.
1140 if (parser.parseOptionalAttrDict(result.attributes))
1141 return failure();
1142 return success();
1143}
1144
1145void cir::IfOp::print(OpAsmPrinter &p) {
1146 p << " " << getCondition() << " ";
1147 mlir::Region &thenRegion = this->getThenRegion();
1148 p.printRegion(thenRegion,
1149 /*printEntryBlockArgs=*/false,
1150 /*printBlockTerminators=*/!omitRegionTerm(thenRegion));
1151
1152 // Print the 'else' regions if it exists and has a block.
1153 mlir::Region &elseRegion = this->getElseRegion();
1154 if (!elseRegion.empty()) {
1155 p << " else ";
1156 p.printRegion(elseRegion,
1157 /*printEntryBlockArgs=*/false,
1158 /*printBlockTerminators=*/!omitRegionTerm(elseRegion));
1159 }
1160
1161 p.printOptionalAttrDict(getOperation()->getAttrs());
1162}
1163
1164/// Default callback for IfOp builders.
1165void cir::buildTerminatedBody(OpBuilder &builder, Location loc) {
1166 // add cir.yield to end of the block
1167 cir::YieldOp::create(builder, loc);
1168}
1169
1170/// Given the region at `index`, or the parent operation if `index` is None,
1171/// return the successor regions. These are the regions that may be selected
1172/// during the flow of control. `operands` is a set of optional attributes that
1173/// correspond to a constant value for each operand, or null if that operand is
1174/// not a constant.
1175void cir::IfOp::getSuccessorRegions(mlir::RegionBranchPoint point,
1176 SmallVectorImpl<RegionSuccessor> &regions) {
1177 // The `then` and the `else` region branch back to the parent operation.
1178 if (!point.isParent()) {
1179 regions.push_back(RegionSuccessor::parent());
1180 return;
1181 }
1182
1183 // Don't consider the else region if it is empty.
1184 Region *elseRegion = &this->getElseRegion();
1185 if (elseRegion->empty())
1186 elseRegion = nullptr;
1187
1188 // If the condition isn't constant, both regions may be executed.
1189 regions.push_back(RegionSuccessor(&getThenRegion()));
1190 // If the else region does not exist, it is not a viable successor.
1191 if (elseRegion)
1192 regions.push_back(RegionSuccessor(elseRegion));
1193
1194 return;
1195}
1196
1197mlir::ValueRange cir::IfOp::getSuccessorInputs(RegionSuccessor successor) {
1198 return successor.isParent() ? ValueRange(getOperation()->getResults())
1199 : ValueRange();
1200}
1201
1202void cir::IfOp::build(OpBuilder &builder, OperationState &result, Value cond,
1203 bool withElseRegion, BuilderCallbackRef thenBuilder,
1204 BuilderCallbackRef elseBuilder) {
1205 assert(thenBuilder && "the builder callback for 'then' must be present");
1206 result.addOperands(cond);
1207
1208 OpBuilder::InsertionGuard guard(builder);
1209 Region *thenRegion = result.addRegion();
1210 builder.createBlock(thenRegion);
1211 thenBuilder(builder, result.location);
1212
1213 Region *elseRegion = result.addRegion();
1214 if (!withElseRegion)
1215 return;
1216
1217 builder.createBlock(elseRegion);
1218 elseBuilder(builder, result.location);
1219}
1220
1221//===----------------------------------------------------------------------===//
1222// ScopeOp
1223//===----------------------------------------------------------------------===//
1224
1225/// Given the region at `index`, or the parent operation if `index` is None,
1226/// return the successor regions. These are the regions that may be selected
1227/// during the flow of control. `operands` is a set of optional attributes
1228/// that correspond to a constant value for each operand, or null if that
1229/// operand is not a constant.
1230void cir::ScopeOp::getSuccessorRegions(
1231 mlir::RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
1232 // The only region always branch back to the parent operation.
1233 if (!point.isParent()) {
1234 regions.push_back(RegionSuccessor::parent());
1235 return;
1236 }
1237
1238 // If the condition isn't constant, both regions may be executed.
1239 regions.push_back(RegionSuccessor(&getScopeRegion()));
1240}
1241
1242mlir::ValueRange cir::ScopeOp::getSuccessorInputs(RegionSuccessor successor) {
1243 return successor.isParent() ? ValueRange(getOperation()->getResults())
1244 : ValueRange();
1245}
1246
1247void cir::ScopeOp::build(
1248 OpBuilder &builder, OperationState &result,
1249 function_ref<void(OpBuilder &, Type &, Location)> scopeBuilder) {
1250 assert(scopeBuilder && "the builder callback for 'then' must be present");
1251
1252 OpBuilder::InsertionGuard guard(builder);
1253 Region *scopeRegion = result.addRegion();
1254 builder.createBlock(scopeRegion);
1256
1257 mlir::Type yieldTy;
1258 scopeBuilder(builder, yieldTy, result.location);
1259
1260 if (yieldTy)
1261 result.addTypes(TypeRange{yieldTy});
1262}
1263
1264void cir::ScopeOp::build(
1265 OpBuilder &builder, OperationState &result,
1266 function_ref<void(OpBuilder &, Location)> scopeBuilder) {
1267 assert(scopeBuilder && "the builder callback for 'then' must be present");
1268 OpBuilder::InsertionGuard guard(builder);
1269 Region *scopeRegion = result.addRegion();
1270 builder.createBlock(scopeRegion);
1272 scopeBuilder(builder, result.location);
1273}
1274
1275LogicalResult cir::ScopeOp::verify() {
1276 if (getRegion().empty()) {
1277 return emitOpError() << "cir.scope must not be empty since it should "
1278 "include at least an implicit cir.yield ";
1279 }
1280
1281 mlir::Block &lastBlock = getRegion().back();
1282 if (lastBlock.empty() || !lastBlock.mightHaveTerminator() ||
1283 !lastBlock.getTerminator()->hasTrait<OpTrait::IsTerminator>())
1284 return emitOpError() << "last block of cir.scope must be terminated";
1285 return success();
1286}
1287
1288LogicalResult cir::ScopeOp::fold(FoldAdaptor /*adaptor*/,
1289 SmallVectorImpl<OpFoldResult> &results) {
1290 // Only fold "trivial" scopes: a single block containing only a `cir.yield`.
1291 if (!getRegion().hasOneBlock())
1292 return failure();
1293 Block &block = getRegion().front();
1294 if (block.getOperations().size() != 1)
1295 return failure();
1296
1297 auto yield = dyn_cast<cir::YieldOp>(block.front());
1298 if (!yield)
1299 return failure();
1300
1301 // Only fold when the scope produces a value.
1302 if (getNumResults() != 1 || yield.getNumOperands() != 1)
1303 return failure();
1304
1305 results.push_back(yield.getOperand(0));
1306 return success();
1307}
1308
1309//===----------------------------------------------------------------------===//
1310// BrOp
1311//===----------------------------------------------------------------------===//
1312
1313/// Merges blocks connected by a unique unconditional branch.
1314///
1315/// ^bb0: ^bb0:
1316/// ... ...
1317/// cir.br ^bb1 => ...
1318/// ^bb1: cir.return
1319/// ...
1320/// cir.return
1321LogicalResult cir::BrOp::canonicalize(BrOp op, PatternRewriter &rewriter) {
1322 Block *src = op->getBlock();
1323 Block *dst = op.getDest();
1324
1325 // Do not fold self-loops.
1326 if (src == dst)
1327 return failure();
1328
1329 // Only merge when this is the unique edge between the blocks.
1330 if (src->getNumSuccessors() != 1 || dst->getSinglePredecessor() != src)
1331 return failure();
1332
1333 // Don't merge blocks that start with LabelOp or IndirectBrOp.
1334 // This is to avoid merging blocks that have an indirect predecessor.
1335 if (isa<cir::LabelOp, cir::IndirectBrOp>(dst->front()))
1336 return failure();
1337
1338 auto operands = op.getDestOperands();
1339 rewriter.eraseOp(op);
1340 rewriter.mergeBlocks(dst, src, operands);
1341 return success();
1342}
1343
1344mlir::SuccessorOperands cir::BrOp::getSuccessorOperands(unsigned index) {
1345 assert(index == 0 && "invalid successor index");
1346 return mlir::SuccessorOperands(getDestOperandsMutable());
1347}
1348
1349Block *cir::BrOp::getSuccessorForOperands(ArrayRef<Attribute>) {
1350 return getDest();
1351}
1352
1353//===----------------------------------------------------------------------===//
// IndirectBrOp
1355//===----------------------------------------------------------------------===//
1356
1357mlir::SuccessorOperands
1358cir::IndirectBrOp::getSuccessorOperands(unsigned index) {
1359 assert(index < getNumSuccessors() && "invalid successor index");
1360 return mlir::SuccessorOperands(getSuccOperandsMutable()[index]);
1361}
1362
1364 OpAsmParser &parser, Type &flagType,
1365 SmallVectorImpl<Block *> &succOperandBlocks,
1366 SmallVectorImpl<SmallVector<OpAsmParser::UnresolvedOperand>> &succOperands,
1367 SmallVectorImpl<SmallVector<Type>> &succOperandsTypes) {
1368 if (failed(parser.parseCommaSeparatedList(
1369 OpAsmParser::Delimiter::Square,
1370 [&]() {
1371 Block *destination = nullptr;
1372 SmallVector<OpAsmParser::UnresolvedOperand> operands;
1373 SmallVector<Type> operandTypes;
1374
1375 if (parser.parseSuccessor(destination).failed())
1376 return failure();
1377
1378 if (succeeded(parser.parseOptionalLParen())) {
1379 if (failed(parser.parseOperandList(
1380 operands, OpAsmParser::Delimiter::None)) ||
1381 failed(parser.parseColonTypeList(operandTypes)) ||
1382 failed(parser.parseRParen()))
1383 return failure();
1384 }
1385 succOperandBlocks.push_back(destination);
1386 succOperands.emplace_back(operands);
1387 succOperandsTypes.emplace_back(operandTypes);
1388 return success();
1389 },
1390 "successor blocks")))
1391 return failure();
1392 return success();
1393}
1394
1395void printIndirectBrOpSucessors(OpAsmPrinter &p, cir::IndirectBrOp op,
1396 Type flagType, SuccessorRange succs,
1397 OperandRangeRange succOperands,
1398 const TypeRangeRange &succOperandsTypes) {
1399 p << "[";
1400 llvm::interleave(
1401 llvm::zip(succs, succOperands),
1402 [&](auto i) {
1403 p.printNewline();
1404 p.printSuccessorAndUseList(std::get<0>(i), std::get<1>(i));
1405 },
1406 [&] { p << ','; });
1407 if (!succOperands.empty())
1408 p.printNewline();
1409 p << "]";
1410}
1411
1412//===----------------------------------------------------------------------===//
1413// BrCondOp
1414//===----------------------------------------------------------------------===//
1415
1416mlir::SuccessorOperands cir::BrCondOp::getSuccessorOperands(unsigned index) {
1417 assert(index < getNumSuccessors() && "invalid successor index");
1418 return SuccessorOperands(index == 0 ? getDestOperandsTrueMutable()
1419 : getDestOperandsFalseMutable());
1420}
1421
1422Block *cir::BrCondOp::getSuccessorForOperands(ArrayRef<Attribute> operands) {
1423 if (IntegerAttr condAttr = dyn_cast_if_present<IntegerAttr>(operands.front()))
1424 return condAttr.getValue().isOne() ? getDestTrue() : getDestFalse();
1425 return nullptr;
1426}
1427
1428//===----------------------------------------------------------------------===//
1429// CaseOp
1430//===----------------------------------------------------------------------===//
1431
1432void cir::CaseOp::getSuccessorRegions(
1433 mlir::RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
1434 if (!point.isParent()) {
1435 regions.push_back(RegionSuccessor::parent());
1436 return;
1437 }
1438 regions.push_back(RegionSuccessor(&getCaseRegion()));
1439}
1440
1441mlir::ValueRange cir::CaseOp::getSuccessorInputs(RegionSuccessor successor) {
1442 return successor.isParent() ? ValueRange(getOperation()->getResults())
1443 : ValueRange();
1444}
1445
1446void cir::CaseOp::build(OpBuilder &builder, OperationState &result,
1447 ArrayAttr value, CaseOpKind kind,
1448 OpBuilder::InsertPoint &insertPoint) {
1449 OpBuilder::InsertionGuard guardSwitch(builder);
1450 result.addAttribute("value", value);
1451 result.getOrAddProperties<Properties>().kind =
1452 cir::CaseOpKindAttr::get(builder.getContext(), kind);
1453 Region *caseRegion = result.addRegion();
1454 builder.createBlock(caseRegion);
1455
1456 insertPoint = builder.saveInsertionPoint();
1457}
1458
1459//===----------------------------------------------------------------------===//
1460// SwitchOp
1461//===----------------------------------------------------------------------===//
1462
1463void cir::SwitchOp::getSuccessorRegions(
1464 mlir::RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &region) {
1465 if (!point.isParent()) {
1466 region.push_back(RegionSuccessor::parent());
1467 return;
1468 }
1469
1470 region.push_back(RegionSuccessor(&getBody()));
1471}
1472
1473mlir::ValueRange cir::SwitchOp::getSuccessorInputs(RegionSuccessor successor) {
1474 return successor.isParent() ? ValueRange(getOperation()->getResults())
1475 : ValueRange();
1476}
1477
1478void cir::SwitchOp::build(OpBuilder &builder, OperationState &result,
1479 Value cond, BuilderOpStateCallbackRef switchBuilder) {
1480 assert(switchBuilder && "the builder callback for regions must be present");
1481 OpBuilder::InsertionGuard guardSwitch(builder);
1482 Region *switchRegion = result.addRegion();
1483 builder.createBlock(switchRegion);
1484 result.addOperands({cond});
1485 switchBuilder(builder, result.location, result);
1486}
1487
1488void cir::SwitchOp::collectCases(llvm::SmallVectorImpl<CaseOp> &cases) {
1489 walk<mlir::WalkOrder::PreOrder>([&](mlir::Operation *op) {
1490 // Don't walk in nested switch op.
1491 if (isa<cir::SwitchOp>(op) && op != *this)
1492 return WalkResult::skip();
1493
1494 if (auto caseOp = dyn_cast<cir::CaseOp>(op))
1495 cases.push_back(caseOp);
1496
1497 return WalkResult::advance();
1498 });
1499}
1500
1501bool cir::SwitchOp::isSimpleForm(llvm::SmallVectorImpl<CaseOp> &cases) {
1502 collectCases(cases);
1503
1504 if (getBody().empty())
1505 return false;
1506
1507 if (!isa<YieldOp>(getBody().front().back()))
1508 return false;
1509
1510 if (!llvm::all_of(getBody().front(),
1511 [](Operation &op) { return isa<CaseOp, YieldOp>(op); }))
1512 return false;
1513
1514 return llvm::all_of(cases, [this](CaseOp op) {
1515 return op->getParentOfType<SwitchOp>() == *this;
1516 });
1517}
1518
1519//===----------------------------------------------------------------------===//
1520// SwitchFlatOp
1521//===----------------------------------------------------------------------===//
1522
1523void cir::SwitchFlatOp::build(OpBuilder &builder, OperationState &result,
1524 Value value, Block *defaultDestination,
1525 ValueRange defaultOperands,
1526 ArrayRef<APInt> caseValues,
1527 BlockRange caseDestinations,
1528 ArrayRef<ValueRange> caseOperands) {
1529
1530 std::vector<mlir::Attribute> caseValuesAttrs;
1531 for (const APInt &val : caseValues)
1532 caseValuesAttrs.push_back(cir::IntAttr::get(value.getType(), val));
1533 mlir::ArrayAttr attrs = ArrayAttr::get(builder.getContext(), caseValuesAttrs);
1534
1535 build(builder, result, value, defaultOperands, caseOperands, attrs,
1536 defaultDestination, caseDestinations);
1537}
1538
1539/// <cases> ::= `[` (case (`,` case )* )? `]`
1540/// <case> ::= integer `:` bb-id (`(` ssa-use-and-type-list `)`)?
1541static ParseResult parseSwitchFlatOpCases(
1542 OpAsmParser &parser, Type flagType, mlir::ArrayAttr &caseValues,
1543 SmallVectorImpl<Block *> &caseDestinations,
1545 &caseOperands,
1546 SmallVectorImpl<llvm::SmallVector<Type>> &caseOperandTypes) {
1547 if (failed(parser.parseLSquare()))
1548 return failure();
1549 if (succeeded(parser.parseOptionalRSquare()))
1550 return success();
1552
1553 auto parseCase = [&]() {
1554 int64_t value = 0;
1555 if (failed(parser.parseInteger(value)))
1556 return failure();
1557
1558 values.push_back(cir::IntAttr::get(flagType, value));
1559
1560 Block *destination;
1562 llvm::SmallVector<Type> operandTypes;
1563 if (parser.parseColon() || parser.parseSuccessor(destination))
1564 return failure();
1565 if (!parser.parseOptionalLParen()) {
1566 if (parser.parseOperandList(operands, OpAsmParser::Delimiter::None,
1567 /*allowResultNumber=*/false) ||
1568 parser.parseColonTypeList(operandTypes) || parser.parseRParen())
1569 return failure();
1570 }
1571 caseDestinations.push_back(destination);
1572 caseOperands.emplace_back(operands);
1573 caseOperandTypes.emplace_back(operandTypes);
1574 return success();
1575 };
1576 if (failed(parser.parseCommaSeparatedList(parseCase)))
1577 return failure();
1578
1579 caseValues = ArrayAttr::get(flagType.getContext(), values);
1580
1581 return parser.parseRSquare();
1582}
1583
1584static void printSwitchFlatOpCases(OpAsmPrinter &p, cir::SwitchFlatOp op,
1585 Type flagType, mlir::ArrayAttr caseValues,
1586 SuccessorRange caseDestinations,
1587 OperandRangeRange caseOperands,
1588 const TypeRangeRange &caseOperandTypes) {
1589 p << '[';
1590 p.printNewline();
1591 if (!caseValues) {
1592 p << ']';
1593 return;
1594 }
1595
1596 size_t index = 0;
1597 llvm::interleave(
1598 llvm::zip(caseValues, caseDestinations),
1599 [&](auto i) {
1600 p << " ";
1601 mlir::Attribute a = std::get<0>(i);
1602 p << mlir::cast<cir::IntAttr>(a).getValue();
1603 p << ": ";
1604 p.printSuccessorAndUseList(std::get<1>(i), caseOperands[index++]);
1605 },
1606 [&] {
1607 p << ',';
1608 p.printNewline();
1609 });
1610 p.printNewline();
1611 p << ']';
1612}
1613
1614//===----------------------------------------------------------------------===//
1615// GlobalOp
1616//===----------------------------------------------------------------------===//
1617
1618static ParseResult parseConstantValue(OpAsmParser &parser,
1619 mlir::Attribute &valueAttr) {
1620 NamedAttrList attr;
1621 return parser.parseAttribute(valueAttr, "value", attr);
1622}
1623
1624static void printConstant(OpAsmPrinter &p, Attribute value) {
1625 p.printAttribute(value);
1626}
1627
1628mlir::LogicalResult cir::GlobalOp::verify() {
1629 // Verify that the initial value, if present, is either a unit attribute or
1630 // an attribute CIR supports.
1631 if (getInitialValue().has_value()) {
1632 if (checkConstantTypes(getOperation(), getSymType(), *getInitialValue())
1633 .failed())
1634 return failure();
1635 }
1636
1637 // TODO(CIR): Many other checks for properties that haven't been upstreamed
1638 // yet.
1639
1640 return success();
1641}
1642
1643void cir::GlobalOp::build(
1644 OpBuilder &odsBuilder, OperationState &odsState, llvm::StringRef sym_name,
1645 mlir::Type sym_type, bool isConstant, cir::GlobalLinkageKind linkage,
1646 function_ref<void(OpBuilder &, Location)> ctorBuilder,
1647 function_ref<void(OpBuilder &, Location)> dtorBuilder) {
1648 odsState.addAttribute(getSymNameAttrName(odsState.name),
1649 odsBuilder.getStringAttr(sym_name));
1650 odsState.addAttribute(getSymTypeAttrName(odsState.name),
1651 mlir::TypeAttr::get(sym_type));
1652 if (isConstant)
1653 odsState.addAttribute(getConstantAttrName(odsState.name),
1654 odsBuilder.getUnitAttr());
1655
1656 cir::GlobalLinkageKindAttr linkageAttr =
1657 cir::GlobalLinkageKindAttr::get(odsBuilder.getContext(), linkage);
1658 odsState.addAttribute(getLinkageAttrName(odsState.name), linkageAttr);
1659
1660 Region *ctorRegion = odsState.addRegion();
1661 if (ctorBuilder) {
1662 odsBuilder.createBlock(ctorRegion);
1663 ctorBuilder(odsBuilder, odsState.location);
1664 }
1665
1666 Region *dtorRegion = odsState.addRegion();
1667 if (dtorBuilder) {
1668 odsBuilder.createBlock(dtorRegion);
1669 dtorBuilder(odsBuilder, odsState.location);
1670 }
1671
1672 odsState.addAttribute(getGlobalVisibilityAttrName(odsState.name),
1673 cir::VisibilityAttr::get(odsBuilder.getContext()));
1674}
1675
1676/// Given the region at `index`, or the parent operation if `index` is None,
1677/// return the successor regions. These are the regions that may be selected
1678/// during the flow of control. `operands` is a set of optional attributes that
1679/// correspond to a constant value for each operand, or null if that operand is
1680/// not a constant.
1681void cir::GlobalOp::getSuccessorRegions(
1682 mlir::RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
1683 // The `ctor` and `dtor` regions always branch back to the parent operation.
1684 if (!point.isParent()) {
1685 regions.push_back(RegionSuccessor::parent());
1686 return;
1687 }
1688
1689 // Don't consider the ctor region if it is empty.
1690 Region *ctorRegion = &this->getCtorRegion();
1691 if (ctorRegion->empty())
1692 ctorRegion = nullptr;
1693
1694 // Don't consider the dtor region if it is empty.
1695 Region *dtorRegion = &this->getCtorRegion();
1696 if (dtorRegion->empty())
1697 dtorRegion = nullptr;
1698
1699 // If the condition isn't constant, both regions may be executed.
1700 if (ctorRegion)
1701 regions.push_back(RegionSuccessor(ctorRegion));
1702 if (dtorRegion)
1703 regions.push_back(RegionSuccessor(dtorRegion));
1704}
1705
1706mlir::ValueRange cir::GlobalOp::getSuccessorInputs(RegionSuccessor successor) {
1707 return successor.isParent() ? ValueRange(getOperation()->getResults())
1708 : ValueRange();
1709}
1710
1711static void printGlobalOpTypeAndInitialValue(OpAsmPrinter &p, cir::GlobalOp op,
1712 TypeAttr type, Attribute initAttr,
1713 mlir::Region &ctorRegion,
1714 mlir::Region &dtorRegion) {
1715 auto printType = [&]() { p << ": " << type; };
1716 if (!op.isDeclaration()) {
1717 p << "= ";
1718 if (!ctorRegion.empty()) {
1719 p << "ctor ";
1720 printType();
1721 p << " ";
1722 p.printRegion(ctorRegion,
1723 /*printEntryBlockArgs=*/false,
1724 /*printBlockTerminators=*/false);
1725 } else {
1726 // This also prints the type...
1727 if (initAttr)
1728 printConstant(p, initAttr);
1729 }
1730
1731 if (!dtorRegion.empty()) {
1732 p << " dtor ";
1733 p.printRegion(dtorRegion,
1734 /*printEntryBlockArgs=*/false,
1735 /*printBlockTerminators=*/false);
1736 }
1737 } else {
1738 printType();
1739 }
1740}
1741
1742static ParseResult parseGlobalOpTypeAndInitialValue(OpAsmParser &parser,
1743 TypeAttr &typeAttr,
1744 Attribute &initialValueAttr,
1745 mlir::Region &ctorRegion,
1746 mlir::Region &dtorRegion) {
1747 mlir::Type opTy;
1748 if (parser.parseOptionalEqual().failed()) {
1749 // Absence of equal means a declaration, so we need to parse the type.
1750 // cir.global @a : !cir.int<s, 32>
1751 if (parser.parseColonType(opTy))
1752 return failure();
1753 } else {
1754 // Parse contructor, example:
1755 // cir.global @rgb = ctor : type { ... }
1756 if (!parser.parseOptionalKeyword("ctor")) {
1757 if (parser.parseColonType(opTy))
1758 return failure();
1759 auto parseLoc = parser.getCurrentLocation();
1760 if (parser.parseRegion(ctorRegion, /*arguments=*/{}, /*argTypes=*/{}))
1761 return failure();
1762 if (ensureRegionTerm(parser, ctorRegion, parseLoc).failed())
1763 return failure();
1764 } else {
1765 // Parse constant with initializer, examples:
1766 // cir.global @y = 3.400000e+00 : f32
1767 // cir.global @rgb = #cir.const_array<[...] : !cir.array<i8 x 3>>
1768 if (parseConstantValue(parser, initialValueAttr).failed())
1769 return failure();
1770
1771 assert(mlir::isa<mlir::TypedAttr>(initialValueAttr) &&
1772 "Non-typed attrs shouldn't appear here.");
1773 auto typedAttr = mlir::cast<mlir::TypedAttr>(initialValueAttr);
1774 opTy = typedAttr.getType();
1775 }
1776
1777 // Parse destructor, example:
1778 // dtor { ... }
1779 if (!parser.parseOptionalKeyword("dtor")) {
1780 auto parseLoc = parser.getCurrentLocation();
1781 if (parser.parseRegion(dtorRegion, /*arguments=*/{}, /*argTypes=*/{}))
1782 return failure();
1783 if (ensureRegionTerm(parser, dtorRegion, parseLoc).failed())
1784 return failure();
1785 }
1786 }
1787
1788 typeAttr = TypeAttr::get(opTy);
1789 return success();
1790}
1791
1792//===----------------------------------------------------------------------===//
1793// GetGlobalOp
1794//===----------------------------------------------------------------------===//
1795
1796LogicalResult
1797cir::GetGlobalOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1798 // Verify that the result type underlying pointer type matches the type of
1799 // the referenced cir.global or cir.func op.
1800 mlir::Operation *op =
1801 symbolTable.lookupNearestSymbolFrom(*this, getNameAttr());
1802 if (op == nullptr || !(isa<GlobalOp>(op) || isa<FuncOp>(op)))
1803 return emitOpError("'")
1804 << getName()
1805 << "' does not reference a valid cir.global or cir.func";
1806
1807 mlir::Type symTy;
1808 if (auto g = dyn_cast<GlobalOp>(op)) {
1809 symTy = g.getSymType();
1811 // Verify that for thread local global access, the global needs to
1812 // be marked with tls bits.
1813 if (getTls() && !g.getTlsModel())
1814 return emitOpError("access to global not marked thread local");
1815 } else if (auto f = dyn_cast<FuncOp>(op)) {
1816 symTy = f.getFunctionType();
1817 } else {
1818 llvm_unreachable("Unexpected operation for GetGlobalOp");
1819 }
1820
1821 auto resultType = dyn_cast<PointerType>(getAddr().getType());
1822 if (!resultType || symTy != resultType.getPointee())
1823 return emitOpError("result type pointee type '")
1824 << resultType.getPointee() << "' does not match type " << symTy
1825 << " of the global @" << getName();
1826
1827 return success();
1828}
1829
1830//===----------------------------------------------------------------------===//
1831// VTableAddrPointOp
1832//===----------------------------------------------------------------------===//
1833
1834LogicalResult
1835cir::VTableAddrPointOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1836 StringRef name = getName();
1837
1838 // Verify that the result type underlying pointer type matches the type of
1839 // the referenced cir.global.
1840 auto op =
1841 symbolTable.lookupNearestSymbolFrom<cir::GlobalOp>(*this, getNameAttr());
1842 if (!op)
1843 return emitOpError("'")
1844 << name << "' does not reference a valid cir.global";
1845 std::optional<mlir::Attribute> init = op.getInitialValue();
1846 if (!init)
1847 return success();
1848 if (!isa<cir::VTableAttr>(*init))
1849 return emitOpError("Expected #cir.vtable in initializer for global '")
1850 << name << "'";
1851 return success();
1852}
1853
1854//===----------------------------------------------------------------------===//
1855// VTTAddrPointOp
1856//===----------------------------------------------------------------------===//
1857
1858LogicalResult
1859cir::VTTAddrPointOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1860 // VTT ptr is not coming from a symbol.
1861 if (!getName())
1862 return success();
1863 StringRef name = *getName();
1864
1865 // Verify that the result type underlying pointer type matches the type of
1866 // the referenced cir.global op.
1867 auto op =
1868 symbolTable.lookupNearestSymbolFrom<cir::GlobalOp>(*this, getNameAttr());
1869 if (!op)
1870 return emitOpError("'")
1871 << name << "' does not reference a valid cir.global";
1872 std::optional<mlir::Attribute> init = op.getInitialValue();
1873 if (!init)
1874 return success();
1875 if (!isa<cir::ConstArrayAttr>(*init))
1876 return emitOpError(
1877 "Expected constant array in initializer for global VTT '")
1878 << name << "'";
1879 return success();
1880}
1881
1882LogicalResult cir::VTTAddrPointOp::verify() {
1883 // The operation uses either a symbol or a value to operate, but not both
1884 if (getName() && getSymAddr())
1885 return emitOpError("should use either a symbol or value, but not both");
1886
1887 // If not a symbol, stick with the concrete type used for getSymAddr.
1888 if (getSymAddr())
1889 return success();
1890
1891 mlir::Type resultType = getAddr().getType();
1892 mlir::Type resTy = cir::PointerType::get(
1893 cir::PointerType::get(cir::VoidType::get(getContext())));
1894
1895 if (resultType != resTy)
1896 return emitOpError("result type must be ")
1897 << resTy << ", but provided result type is " << resultType;
1898 return success();
1899}
1900
1901//===----------------------------------------------------------------------===//
1902// FuncOp
1903//===----------------------------------------------------------------------===//
1904
1905/// Returns the name used for the linkage attribute. This *must* correspond to
1906/// the name of the attribute in ODS.
1907static llvm::StringRef getLinkageAttrNameString() { return "linkage"; }
1908
1909void cir::FuncOp::build(OpBuilder &builder, OperationState &result,
1910 StringRef name, FuncType type,
1911 GlobalLinkageKind linkage) {
1912 result.addRegion();
1913 result.addAttribute(SymbolTable::getSymbolAttrName(),
1914 builder.getStringAttr(name));
1915 result.addAttribute(getFunctionTypeAttrName(result.name),
1916 TypeAttr::get(type));
1917 result.addAttribute(
1919 GlobalLinkageKindAttr::get(builder.getContext(), linkage));
1920 result.addAttribute(getGlobalVisibilityAttrName(result.name),
1921 cir::VisibilityAttr::get(builder.getContext()));
1922}
1923
// Parses the custom assembly form of cir.func into `state`.
//
// The pieces are consumed in a fixed order: optional `builtin`, `coroutine`,
// inline-kind, `lambda` and `no_proto` keywords; an optional linkage keyword
// (external linkage is the default); an optional MLIR symbol visibility
// (`private`/`public`/`nested`); the CIR visibility; optional `dso_local`;
// the symbol name and signature; optional `alias(@sym)`, `personality(@sym)`,
// `special_member<...>`, `global_ctor[(prio)]`, `global_dtor[(prio)]` and
// `side_effect(...)` clauses; an optional `attributes {...}` dictionary; and
// finally an optional body region.
//
// NOTE(review): in this listing the declarations of `arguments`,
// `resultTypes`, `resultAttrs` and `argTypes`, and the call that wraps
// `parser, GlobalLinkageKind::ExternalLinkage`, are not visible (the listing
// numbering skips lines there) - confirm against the real source.
// NOTE(review): FuncOp::print emits a ` comdat` keyword, but no matching
// keyword is parsed in the visible code - confirm whether that is intended.
1924ParseResult cir::FuncOp::parse(OpAsmParser &parser, OperationState &state) {
1925 llvm::SMLoc loc = parser.getCurrentLocation();
1926 mlir::Builder &builder = parser.getBuilder();
1927
 // Attribute names for everything that has a dedicated keyword below.
1928 mlir::StringAttr builtinNameAttr = getBuiltinAttrName(state.name);
1929 mlir::StringAttr coroutineNameAttr = getCoroutineAttrName(state.name);
1930 mlir::StringAttr inlineKindNameAttr = getInlineKindAttrName(state.name);
1931 mlir::StringAttr lambdaNameAttr = getLambdaAttrName(state.name);
1932 mlir::StringAttr noProtoNameAttr = getNoProtoAttrName(state.name);
1933 mlir::StringAttr visNameAttr = getSymVisibilityAttrName(state.name);
1934 mlir::StringAttr visibilityNameAttr = getGlobalVisibilityAttrName(state.name);
1935 mlir::StringAttr dsoLocalNameAttr = getDsoLocalAttrName(state.name);
1936 mlir::StringAttr specialMemberAttr = getCxxSpecialMemberAttrName(state.name);
1937
1938 if (::mlir::succeeded(parser.parseOptionalKeyword(builtinNameAttr.strref())))
1939 state.addAttribute(builtinNameAttr, parser.getBuilder().getUnitAttr());
1940 if (::mlir::succeeded(
1941 parser.parseOptionalKeyword(coroutineNameAttr.strref())))
1942 state.addAttribute(coroutineNameAttr, parser.getBuilder().getUnitAttr());
1943
1944 // Parse optional inline kind attribute
1945 cir::InlineKindAttr inlineKindAttr;
1946 if (failed(parseInlineKindAttr(parser, inlineKindAttr)))
1947 return failure();
1948 if (inlineKindAttr)
1949 state.addAttribute(inlineKindNameAttr, inlineKindAttr);
1950
1951 if (::mlir::succeeded(parser.parseOptionalKeyword(lambdaNameAttr.strref())))
1952 state.addAttribute(lambdaNameAttr, parser.getBuilder().getUnitAttr());
1953 if (parser.parseOptionalKeyword(noProtoNameAttr).succeeded())
1954 state.addAttribute(noProtoNameAttr, parser.getBuilder().getUnitAttr());
1955
1956 // Default to external linkage if no keyword is provided.
1957 state.addAttribute(getLinkageAttrNameString(),
1958 GlobalLinkageKindAttr::get(
1959 parser.getContext(),
1961 parser, GlobalLinkageKind::ExternalLinkage)));
1962
 // Optional MLIR symbol visibility keyword.
1963 ::llvm::StringRef visAttrStr;
1964 if (parser.parseOptionalKeyword(&visAttrStr, {"private", "public", "nested"})
1965 .succeeded()) {
1966 state.addAttribute(visNameAttr,
1967 parser.getBuilder().getStringAttr(visAttrStr));
1968 }
1969
 // The CIR visibility attribute is always attached.
1970 cir::VisibilityAttr cirVisibilityAttr;
1971 parseVisibilityAttr(parser, cirVisibilityAttr);
1972 state.addAttribute(visibilityNameAttr, cirVisibilityAttr);
1973
1974 if (parser.parseOptionalKeyword(dsoLocalNameAttr).succeeded())
1975 state.addAttribute(dsoLocalNameAttr, parser.getBuilder().getUnitAttr());
1976
 // Symbol name followed by the (possibly variadic) signature.
1977 StringAttr nameAttr;
1978 if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
1979 state.attributes))
1980 return failure();
1984 bool isVariadic = false;
1985 if (function_interface_impl::parseFunctionSignatureWithArguments(
1986 parser, /*allowVariadic=*/true, arguments, isVariadic, resultTypes,
1987 resultAttrs))
1988 return failure();
1990 for (OpAsmParser::Argument &arg : arguments)
1991 argTypes.push_back(arg.type);
1992
1993 if (resultTypes.size() > 1) {
1994 return parser.emitError(
1995 loc, "functions with multiple return types are not supported");
1996 }
1997
 // A signature without a result type is given a CIR void return type.
1998 mlir::Type returnType =
1999 (resultTypes.empty() ? cir::VoidType::get(builder.getContext())
2000 : resultTypes.front());
2001
2002 cir::FuncType fnType = cir::FuncType::get(argTypes, returnType, isVariadic);
2003 if (!fnType)
2004 return failure();
2005 state.addAttribute(getFunctionTypeAttrName(state.name),
2006 TypeAttr::get(fnType));
2007
 // Optional `alias(@aliasee)`; remembered so a body can be rejected later.
2008 bool hasAlias = false;
2009 mlir::StringAttr aliaseeNameAttr = getAliaseeAttrName(state.name);
2010 if (parser.parseOptionalKeyword("alias").succeeded()) {
2011 if (parser.parseLParen().failed())
2012 return failure();
2013 mlir::StringAttr aliaseeAttr;
2014 if (parser.parseOptionalSymbolName(aliaseeAttr).failed())
2015 return failure();
2016 state.addAttribute(aliaseeNameAttr, FlatSymbolRefAttr::get(aliaseeAttr));
2017 if (parser.parseRParen().failed())
2018 return failure();
2019 hasAlias = true;
2020 }
2021
 // Optional `personality(@fn)`.
2022 mlir::StringAttr personalityNameAttr = getPersonalityAttrName(state.name);
2023 if (parser.parseOptionalKeyword("personality").succeeded()) {
2024 if (parser.parseLParen().failed())
2025 return failure();
2026 mlir::StringAttr personalityAttr;
2027 if (parser.parseOptionalSymbolName(personalityAttr).failed())
2028 return failure();
2029 state.addAttribute(personalityNameAttr,
2030 FlatSymbolRefAttr::get(personalityAttr));
2031 if (parser.parseRParen().failed())
2032 return failure();
2033 }
2034
 // Shared helper for `global_ctor` / `global_dtor`: parses the keyword and an
 // optional parenthesized integer priority, then calls `createAttr` with the
 // priority (std::nullopt when no parentheses were given).
2035 auto parseGlobalDtorCtor =
2036 [&](StringRef keyword,
2037 llvm::function_ref<void(std::optional<int> prio)> createAttr)
2038 -> mlir::LogicalResult {
2039 if (mlir::succeeded(parser.parseOptionalKeyword(keyword))) {
2040 std::optional<int> priority;
2041 if (mlir::succeeded(parser.parseOptionalLParen())) {
2042 auto parsedPriority = mlir::FieldParser<int>::parse(parser);
2043 if (mlir::failed(parsedPriority))
2044 return parser.emitError(parser.getCurrentLocation(),
2045 "failed to parse 'priority', of type 'int'");
2046 priority = parsedPriority.value_or(int());
2047 // Parse literal ')'
2048 if (parser.parseRParen())
2049 return failure();
2050 }
2051 createAttr(priority);
2052 }
2053 return success();
2054 };
2055
2056 // Parse CXXSpecialMember attribute
2057 if (parser.parseOptionalKeyword("special_member").succeeded()) {
2058 cir::CXXCtorAttr ctorAttr;
2059 cir::CXXDtorAttr dtorAttr;
2060 cir::CXXAssignAttr assignAttr;
2061 if (parser.parseLess().failed())
2062 return failure();
 // Try each special-member attribute kind in turn; the first one that is
 // present is attached.
2063 if (parser.parseOptionalAttribute(ctorAttr).has_value())
2064 state.addAttribute(specialMemberAttr, ctorAttr);
2065 else if (parser.parseOptionalAttribute(dtorAttr).has_value())
2066 state.addAttribute(specialMemberAttr, dtorAttr);
2067 else if (parser.parseOptionalAttribute(assignAttr).has_value())
2068 state.addAttribute(specialMemberAttr, assignAttr);
2069 if (parser.parseGreater().failed())
2070 return failure();
2071 }
2072
 // A priority that is not spelled out defaults to 65535 for both clauses.
2073 if (parseGlobalDtorCtor("global_ctor", [&](std::optional<int> priority) {
2074 mlir::IntegerAttr globalCtorPriorityAttr =
2075 builder.getI32IntegerAttr(priority.value_or(65535));
2076 state.addAttribute(getGlobalCtorPriorityAttrName(state.name),
2077 globalCtorPriorityAttr);
2078 }).failed())
2079 return failure();
2080
2081 if (parseGlobalDtorCtor("global_dtor", [&](std::optional<int> priority) {
2082 mlir::IntegerAttr globalDtorPriorityAttr =
2083 builder.getI32IntegerAttr(priority.value_or(65535));
2084 state.addAttribute(getGlobalDtorPriorityAttrName(state.name),
2085 globalDtorPriorityAttr);
2086 }).failed())
2087 return failure();
2088
 // Optional `side_effect(<kind>)`.
2089 if (parser.parseOptionalKeyword("side_effect").succeeded()) {
2090 cir::SideEffect sideEffect;
2091
2092 if (parser.parseLParen().failed() ||
2093 parseCIRKeyword<cir::SideEffect>(parser, sideEffect).failed() ||
2094 parser.parseRParen().failed())
2095 return failure();
2096
2097 auto attr = cir::SideEffectAttr::get(parser.getContext(), sideEffect);
2098 state.addAttribute(CIRDialect::getSideEffectAttrName(), attr);
2099 }
2100
2101 // Parse the rest of the attributes.
2102 NamedAttrList parsedAttrs;
2103 if (parser.parseOptionalAttrDictWithKeyword(parsedAttrs))
2104 return failure();
2105
 // Attributes with a dedicated syntax must not reappear in the dictionary.
2106 for (StringRef disallowed : cir::FuncOp::getAttributeNames()) {
2107 if (parsedAttrs.get(disallowed))
2108 return parser.emitError(loc, "attribute '")
2109 << disallowed
2110 << "' should not be specified in the explicit attribute list";
2111 }
2112
2113 state.attributes.append(parsedAttrs);
2114
2115 // Parse the optional function body.
2116 auto *body = state.addRegion();
2117 OptionalParseResult parseResult = parser.parseOptionalRegion(
2118 *body, arguments, /*enableNameShadowing=*/false);
2119 if (parseResult.has_value()) {
2120 if (hasAlias)
2121 return parser.emitError(loc, "function alias shall not have a body");
2122 if (failed(*parseResult))
2123 return failure();
2124 // Function body was parsed, make sure its not empty.
2125 if (body->empty())
2126 return parser.emitError(loc, "expected non-empty function body");
2127 }
2128
2129 return success();
2130}
2131
2132// This function corresponds to `llvm::GlobalValue::isDeclaration` and should
2133// have a similar implementation. We don't currently ifuncs or materializable
2134// functions, but those should be handled here as they are implemented.
2135bool cir::FuncOp::isDeclaration() {
2137
2138 std::optional<StringRef> aliasee = getAliasee();
2139 if (!aliasee)
2140 return getFunctionBody().empty();
2141
2142 // Aliases are always definitions.
2143 return false;
2144}
2145
2146bool cir::FuncOp::isCXXSpecialMemberFunction() {
2147 return getCxxSpecialMemberAttr() != nullptr;
2148}
2149
2150bool cir::FuncOp::isCxxConstructor() {
2151 auto attr = getCxxSpecialMemberAttr();
2152 return attr && dyn_cast<CXXCtorAttr>(attr);
2153}
2154
2155bool cir::FuncOp::isCxxDestructor() {
2156 auto attr = getCxxSpecialMemberAttr();
2157 return attr && dyn_cast<CXXDtorAttr>(attr);
2158}
2159
2160bool cir::FuncOp::isCxxSpecialAssignment() {
2161 auto attr = getCxxSpecialMemberAttr();
2162 return attr && dyn_cast<CXXAssignAttr>(attr);
2163}
2164
2165std::optional<CtorKind> cir::FuncOp::getCxxConstructorKind() {
2166 mlir::Attribute attr = getCxxSpecialMemberAttr();
2167 if (attr) {
2168 if (auto ctor = dyn_cast<CXXCtorAttr>(attr))
2169 return ctor.getCtorKind();
2170 }
2171 return std::nullopt;
2172}
2173
2174std::optional<AssignKind> cir::FuncOp::getCxxSpecialAssignKind() {
2175 mlir::Attribute attr = getCxxSpecialMemberAttr();
2176 if (attr) {
2177 if (auto assign = dyn_cast<CXXAssignAttr>(attr))
2178 return assign.getAssignKind();
2179 }
2180 return std::nullopt;
2181}
2182
2183bool cir::FuncOp::isCxxTrivialMemberFunction() {
2184 mlir::Attribute attr = getCxxSpecialMemberAttr();
2185 if (attr) {
2186 if (auto ctor = dyn_cast<CXXCtorAttr>(attr))
2187 return ctor.getIsTrivial();
2188 if (auto dtor = dyn_cast<CXXDtorAttr>(attr))
2189 return dtor.getIsTrivial();
2190 if (auto assign = dyn_cast<CXXAssignAttr>(attr))
2191 return assign.getIsTrivial();
2192 }
2193 return false;
2194}
2195
2196mlir::Region *cir::FuncOp::getCallableRegion() {
2197 // TODO(CIR): This function will have special handling for aliases and a
2198 // check for an external function, once those features have been upstreamed.
2199 return &getBody();
2200}
2201
2202void cir::FuncOp::print(OpAsmPrinter &p) {
2203 if (getBuiltin())
2204 p << " builtin";
2205
2206 if (getCoroutine())
2207 p << " coroutine";
2208
2209 printInlineKindAttr(p, getInlineKindAttr());
2210
2211 if (getLambda())
2212 p << " lambda";
2213
2214 if (getNoProto())
2215 p << " no_proto";
2216
2217 if (getComdat())
2218 p << " comdat";
2219
2220 if (getLinkage() != GlobalLinkageKind::ExternalLinkage)
2221 p << ' ' << stringifyGlobalLinkageKind(getLinkage());
2222
2223 mlir::SymbolTable::Visibility vis = getVisibility();
2224 if (vis != mlir::SymbolTable::Visibility::Public)
2225 p << ' ' << vis;
2226
2227 cir::VisibilityAttr cirVisibilityAttr = getGlobalVisibilityAttr();
2228 if (!cirVisibilityAttr.isDefault()) {
2229 p << ' ';
2230 printVisibilityAttr(p, cirVisibilityAttr);
2231 }
2232
2233 if (getDsoLocal())
2234 p << " dso_local";
2235
2236 p << ' ';
2237 p.printSymbolName(getSymName());
2238 cir::FuncType fnType = getFunctionType();
2239 function_interface_impl::printFunctionSignature(
2240 p, *this, fnType.getInputs(), fnType.isVarArg(), fnType.getReturnTypes());
2241
2242 if (std::optional<StringRef> aliaseeName = getAliasee()) {
2243 p << " alias(";
2244 p.printSymbolName(*aliaseeName);
2245 p << ")";
2246 }
2247
2248 if (std::optional<StringRef> personalityName = getPersonality()) {
2249 p << " personality(";
2250 p.printSymbolName(*personalityName);
2251 p << ")";
2252 }
2253
2254 if (auto specialMemberAttr = getCxxSpecialMember()) {
2255 p << " special_member<";
2256 p.printAttribute(*specialMemberAttr);
2257 p << '>';
2258 }
2259
2260 if (auto globalCtorPriority = getGlobalCtorPriority()) {
2261 p << " global_ctor";
2262 if (globalCtorPriority.value() != 65535)
2263 p << "(" << globalCtorPriority.value() << ")";
2264 }
2265
2266 if (auto globalDtorPriority = getGlobalDtorPriority()) {
2267 p << " global_dtor";
2268 if (globalDtorPriority.value() != 65535)
2269 p << "(" << globalDtorPriority.value() << ")";
2270 }
2271
2272 if (std::optional<cir::SideEffect> sideEffect = getSideEffect();
2273 sideEffect && *sideEffect != cir::SideEffect::All) {
2274 p << " side_effect(";
2275 p << stringifySideEffect(*sideEffect);
2276 p << ")";
2277 }
2278
2279 function_interface_impl::printFunctionAttributes(
2280 p, *this, cir::FuncOp::getAttributeNames());
2281
2282 // Print the body if this is not an external function.
2283 Region &body = getOperation()->getRegion(0);
2284 if (!body.empty()) {
2285 p << ' ';
2286 p.printRegion(body, /*printEntryBlockArgs=*/false,
2287 /*printBlockTerminators=*/true);
2288 }
2289}
2290
2291mlir::LogicalResult cir::FuncOp::verify() {
2292
2293 if (!isDeclaration() && getCoroutine()) {
2294 bool foundAwait = false;
2295 this->walk([&](Operation *op) {
2296 if (auto await = dyn_cast<AwaitOp>(op)) {
2297 foundAwait = true;
2298 return;
2299 }
2300 });
2301 if (!foundAwait)
2302 return emitOpError()
2303 << "coroutine body must use at least one cir.await op";
2304 }
2305
2306 llvm::SmallSet<llvm::StringRef, 16> labels;
2307 llvm::SmallSet<llvm::StringRef, 16> gotos;
2308 llvm::SmallSet<llvm::StringRef, 16> blockAddresses;
2309 bool invalidBlockAddress = false;
2310 getOperation()->walk([&](mlir::Operation *op) {
2311 if (auto lab = dyn_cast<cir::LabelOp>(op)) {
2312 labels.insert(lab.getLabel());
2313 } else if (auto goTo = dyn_cast<cir::GotoOp>(op)) {
2314 gotos.insert(goTo.getLabel());
2315 } else if (auto blkAdd = dyn_cast<cir::BlockAddressOp>(op)) {
2316 if (blkAdd.getBlockAddrInfoAttr().getFunc().getAttr() != getSymName()) {
2317 // Stop the walk early, no need to continue
2318 invalidBlockAddress = true;
2319 return mlir::WalkResult::interrupt();
2320 }
2321 blockAddresses.insert(blkAdd.getBlockAddrInfoAttr().getLabel());
2322 }
2323 return mlir::WalkResult::advance();
2324 });
2325
2326 if (invalidBlockAddress)
2327 return emitOpError() << "blockaddress references a different function";
2328
2329 llvm::SmallSet<llvm::StringRef, 16> mismatched;
2330 if (!labels.empty() || !gotos.empty()) {
2331 mismatched = llvm::set_difference(gotos, labels);
2332
2333 if (!mismatched.empty())
2334 return emitOpError() << "goto/label mismatch";
2335 }
2336
2337 mismatched.clear();
2338
2339 if (!labels.empty() || !blockAddresses.empty()) {
2340 mismatched = llvm::set_difference(blockAddresses, labels);
2341
2342 if (!mismatched.empty())
2343 return emitOpError()
2344 << "expects an existing label target in the referenced function";
2345 }
2346
2347 return success();
2348}
2349
2350//===----------------------------------------------------------------------===//
2351// BinOp
2352//===----------------------------------------------------------------------===//
2353LogicalResult cir::BinOp::verify() {
2354 bool noWrap = getNoUnsignedWrap() || getNoSignedWrap();
2355 bool saturated = getSaturated();
2356
2357 if (!isa<cir::IntType>(getType()) && noWrap)
2358 return emitError()
2359 << "only operations on integer values may have nsw/nuw flags";
2360
2361 bool noWrapOps = getKind() == cir::BinOpKind::Add ||
2362 getKind() == cir::BinOpKind::Sub ||
2363 getKind() == cir::BinOpKind::Mul;
2364
2365 bool saturatedOps =
2366 getKind() == cir::BinOpKind::Add || getKind() == cir::BinOpKind::Sub;
2367
2368 if (noWrap && !noWrapOps)
2369 return emitError() << "The nsw/nuw flags are applicable to opcodes: 'add', "
2370 "'sub' and 'mul'";
2371 if (saturated && !saturatedOps)
2372 return emitError() << "The saturated flag is applicable to opcodes: 'add' "
2373 "and 'sub'";
2374 if (noWrap && saturated)
2375 return emitError() << "The nsw/nuw flags and the saturated flag are "
2376 "mutually exclusive";
2377
2378 return mlir::success();
2379}
2380
2381//===----------------------------------------------------------------------===//
2382// TernaryOp
2383//===----------------------------------------------------------------------===//
2384
2385/// Given the region at `point`, or the parent operation if `point` is None,
2386/// return the successor regions. These are the regions that may be selected
2387/// during the flow of control. `operands` is a set of optional attributes that
2388/// correspond to a constant value for each operand, or null if that operand is
2389/// not a constant.
2390void cir::TernaryOp::getSuccessorRegions(
2391 mlir::RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
2392 // The `true` and the `false` region branch back to the parent operation.
2393 if (!point.isParent()) {
2394 regions.push_back(RegionSuccessor::parent());
2395 return;
2396 }
2397
2398 // When branching from the parent operation, both the true and false
2399 // regions are considered possible successors
2400 regions.push_back(RegionSuccessor(&getTrueRegion()));
2401 regions.push_back(RegionSuccessor(&getFalseRegion()));
2402}
2403
2404mlir::ValueRange cir::TernaryOp::getSuccessorInputs(RegionSuccessor successor) {
2405 return successor.isParent() ? ValueRange(getOperation()->getResults())
2406 : ValueRange();
2407}
2408
2409void cir::TernaryOp::build(
2410 OpBuilder &builder, OperationState &result, Value cond,
2411 function_ref<void(OpBuilder &, Location)> trueBuilder,
2412 function_ref<void(OpBuilder &, Location)> falseBuilder) {
2413 result.addOperands(cond);
2414 OpBuilder::InsertionGuard guard(builder);
2415 Region *trueRegion = result.addRegion();
2416 builder.createBlock(trueRegion);
2417 trueBuilder(builder, result.location);
2418 Region *falseRegion = result.addRegion();
2419 builder.createBlock(falseRegion);
2420 falseBuilder(builder, result.location);
2421
2422 // Get result type from whichever branch has a yield (the other may have
2423 // unreachable from a throw expression)
2424 auto yield =
2425 dyn_cast_or_null<cir::YieldOp>(trueRegion->back().getTerminator());
2426 if (!yield)
2427 yield = dyn_cast_or_null<cir::YieldOp>(falseRegion->back().getTerminator());
2428
2429 assert((yield && yield.getNumOperands() <= 1) &&
2430 "expected zero or one result type");
2431 if (yield.getNumOperands() == 1)
2432 result.addTypes(TypeRange{yield.getOperandTypes().front()});
2433}
2434
2435//===----------------------------------------------------------------------===//
2436// SelectOp
2437//===----------------------------------------------------------------------===//
2438
2439OpFoldResult cir::SelectOp::fold(FoldAdaptor adaptor) {
2440 mlir::Attribute condition = adaptor.getCondition();
2441 if (condition) {
2442 bool conditionValue = mlir::cast<cir::BoolAttr>(condition).getValue();
2443 return conditionValue ? getTrueValue() : getFalseValue();
2444 }
2445
2446 // cir.select if %0 then x else x -> x
2447 mlir::Attribute trueValue = adaptor.getTrueValue();
2448 mlir::Attribute falseValue = adaptor.getFalseValue();
2449 if (trueValue == falseValue)
2450 return trueValue;
2451 if (getTrueValue() == getFalseValue())
2452 return getTrueValue();
2453
2454 return {};
2455}
2456
2457LogicalResult cir::SelectOp::verify() {
2458 // AllTypesMatch already guarantees trueVal and falseVal have matching types.
2459 auto condTy = dyn_cast<cir::VectorType>(getCondition().getType());
2460
2461 // If condition is not a vector, no further checks are needed.
2462 if (!condTy)
2463 return success();
2464
2465 // When condition is a vector, both other operands must also be vectors.
2466 if (!isa<cir::VectorType>(getTrueValue().getType()) ||
2467 !isa<cir::VectorType>(getFalseValue().getType())) {
2468 return emitOpError()
2469 << "expected both true and false operands to be vector types "
2470 "when the condition is a vector boolean type";
2471 }
2472
2473 return success();
2474}
2475
2476//===----------------------------------------------------------------------===//
2477// ShiftOp
2478//===----------------------------------------------------------------------===//
2479LogicalResult cir::ShiftOp::verify() {
2480 mlir::Operation *op = getOperation();
2481 auto op0VecTy = mlir::dyn_cast<cir::VectorType>(op->getOperand(0).getType());
2482 auto op1VecTy = mlir::dyn_cast<cir::VectorType>(op->getOperand(1).getType());
2483 if (!op0VecTy ^ !op1VecTy)
2484 return emitOpError() << "input types cannot be one vector and one scalar";
2485
2486 if (op0VecTy) {
2487 if (op0VecTy.getSize() != op1VecTy.getSize())
2488 return emitOpError() << "input vector types must have the same size";
2489
2490 auto opResultTy = mlir::dyn_cast<cir::VectorType>(getType());
2491 if (!opResultTy)
2492 return emitOpError() << "the type of the result must be a vector "
2493 << "if it is vector shift";
2494
2495 auto op0VecEleTy = mlir::cast<cir::IntType>(op0VecTy.getElementType());
2496 auto op1VecEleTy = mlir::cast<cir::IntType>(op1VecTy.getElementType());
2497 if (op0VecEleTy.getWidth() != op1VecEleTy.getWidth())
2498 return emitOpError()
2499 << "vector operands do not have the same elements sizes";
2500
2501 auto resVecEleTy = mlir::cast<cir::IntType>(opResultTy.getElementType());
2502 if (op0VecEleTy.getWidth() != resVecEleTy.getWidth())
2503 return emitOpError() << "vector operands and result type do not have the "
2504 "same elements sizes";
2505 }
2506
2507 return mlir::success();
2508}
2509
2510//===----------------------------------------------------------------------===//
2511// LabelOp Definitions
2512//===----------------------------------------------------------------------===//
2513
2514LogicalResult cir::LabelOp::verify() {
2515 mlir::Operation *op = getOperation();
2516 mlir::Block *blk = op->getBlock();
2517 if (&blk->front() != op)
2518 return emitError() << "must be the first operation in a block";
2519
2520 return mlir::success();
2521}
2522
2523//===----------------------------------------------------------------------===//
2524// UnaryOp
2525//===----------------------------------------------------------------------===//
2526
2527LogicalResult cir::UnaryOp::verify() {
2528 switch (getKind()) {
2529 case cir::UnaryOpKind::Inc:
2530 case cir::UnaryOpKind::Dec:
2531 case cir::UnaryOpKind::Plus:
2532 case cir::UnaryOpKind::Minus:
2533 case cir::UnaryOpKind::Not:
2534 // Nothing to verify.
2535 return success();
2536 }
2537
2538 llvm_unreachable("Unknown UnaryOp kind?");
2539}
2540
2541static bool isBoolNot(cir::UnaryOp op) {
2542 return isa<cir::BoolType>(op.getInput().getType()) &&
2543 op.getKind() == cir::UnaryOpKind::Not;
2544}
2545
2546// This folder simplifies the sequential boolean not operations.
2547// For instance, the next two unary operations will be eliminated:
2548//
2549// ```mlir
2550// %1 = cir.unary(not, %0) : !cir.bool, !cir.bool
2551// %2 = cir.unary(not, %1) : !cir.bool, !cir.bool
2552// ```
2553//
2554// and the argument of the first one (%0) will be used instead.
2555OpFoldResult cir::UnaryOp::fold(FoldAdaptor adaptor) {
2556 if (auto poison =
2557 mlir::dyn_cast_if_present<cir::PoisonAttr>(adaptor.getInput())) {
2558 // Propagate poison values
2559 return poison;
2560 }
2561
2562 if (isBoolNot(*this))
2563 if (auto previous = getInput().getDefiningOp<cir::UnaryOp>())
2564 if (isBoolNot(previous))
2565 return previous.getInput();
2566
2567 // Avoid introducing unnecessary duplicate constants in cases where we are
2568 // just folding the operation to its input value. If we return the
2569 // input attribute from the adapter, a new constant is materialized, but
2570 // if we return the input value directly, it avoids that.
2571 if (auto srcConst = getInput().getDefiningOp<cir::ConstantOp>()) {
2572 if (getKind() == cir::UnaryOpKind::Plus ||
2573 (mlir::isa<cir::BoolType>(srcConst.getType()) &&
2574 getKind() == cir::UnaryOpKind::Minus))
2575 return srcConst.getResult();
2576 }
2577
2578 // Fold unary operations with constant inputs. If the input is a ConstantOp,
2579 // it "folds" to its value attribute. If it was some other operation that
2580 // was folded, it will be an mlir::Attribute that hasn't yet been
2581 // materialized. If it was a value that couldn't be folded, it will be null.
2582 if (mlir::Attribute attr = adaptor.getInput()) {
2583 // For now, we only attempt to fold simple scalar values.
2584 OpFoldResult result =
2585 llvm::TypeSwitch<mlir::Attribute, OpFoldResult>(attr)
2586 .Case<cir::IntAttr>([&](cir::IntAttr attrT) {
2587 switch (getKind()) {
2588 case cir::UnaryOpKind::Not: {
2589 APInt val = attrT.getValue();
2590 val.flipAllBits();
2591 return cir::IntAttr::get(getType(), val);
2592 }
2593 case cir::UnaryOpKind::Plus:
2594 return attrT;
2595 case cir::UnaryOpKind::Minus: {
2596 APInt val = attrT.getValue();
2597 val.negate();
2598 return cir::IntAttr::get(getType(), val);
2599 }
2600 default:
2601 return cir::IntAttr{};
2602 }
2603 })
2604 .Case<cir::FPAttr>([&](cir::FPAttr attrT) {
2605 switch (getKind()) {
2606 case cir::UnaryOpKind::Plus:
2607 return attrT;
2608 case cir::UnaryOpKind::Minus: {
2609 APFloat val = attrT.getValue();
2610 val.changeSign();
2611 return cir::FPAttr::get(getType(), val);
2612 }
2613 default:
2614 return cir::FPAttr{};
2615 }
2616 })
2617 .Case<cir::BoolAttr>([&](cir::BoolAttr attrT) {
2618 switch (getKind()) {
2619 case cir::UnaryOpKind::Not:
2620 return cir::BoolAttr::get(getContext(), !attrT.getValue());
2621 case cir::UnaryOpKind::Plus:
2622 case cir::UnaryOpKind::Minus:
2623 return attrT;
2624 default:
2625 return cir::BoolAttr{};
2626 }
2627 })
2628 .Default([&](auto attrT) { return mlir::Attribute{}; });
2629 if (result)
2630 return result;
2631 }
2632
2633 return {};
2634}
2635
2636//===----------------------------------------------------------------------===//
2637// BaseDataMemberOp & DerivedDataMemberOp
2638//===----------------------------------------------------------------------===//
2639
2640static LogicalResult verifyMemberPtrCast(Operation *op, mlir::Value src,
2641 mlir::Type resultTy) {
2642 // Let the operand type be T1 C1::*, let the result type be T2 C2::*.
2643 // Verify that T1 and T2 are the same type.
2644 mlir::Type inputMemberTy;
2645 mlir::Type resultMemberTy;
2646 if (mlir::isa<cir::DataMemberType>(src.getType())) {
2647 inputMemberTy =
2648 mlir::cast<cir::DataMemberType>(src.getType()).getMemberTy();
2649 resultMemberTy = mlir::cast<cir::DataMemberType>(resultTy).getMemberTy();
2650 }
2652 if (inputMemberTy != resultMemberTy)
2653 return op->emitOpError()
2654 << "member types of the operand and the result do not match";
2655
2656 return mlir::success();
2657}
2658
2659LogicalResult cir::BaseDataMemberOp::verify() {
2660 return verifyMemberPtrCast(getOperation(), getSrc(), getType());
2661}
2662
2663LogicalResult cir::DerivedDataMemberOp::verify() {
2664 return verifyMemberPtrCast(getOperation(), getSrc(), getType());
2665}
2666
2667//===----------------------------------------------------------------------===//
2668// BaseMethodOp & DerivedMethodOp
2669//===----------------------------------------------------------------------===//
2670
2671LogicalResult cir::BaseMethodOp::verify() {
2672 return verifyMemberPtrCast(getOperation(), getSrc(), getType());
2673}
2674
2675LogicalResult cir::DerivedMethodOp::verify() {
2676 return verifyMemberPtrCast(getOperation(), getSrc(), getType());
2677}
2678
2679//===----------------------------------------------------------------------===//
2680// AwaitOp
2681//===----------------------------------------------------------------------===//
2682
2683void cir::AwaitOp::build(OpBuilder &builder, OperationState &result,
2684 cir::AwaitKind kind, BuilderCallbackRef readyBuilder,
2685 BuilderCallbackRef suspendBuilder,
2686 BuilderCallbackRef resumeBuilder) {
2687 result.addAttribute(getKindAttrName(result.name),
2688 cir::AwaitKindAttr::get(builder.getContext(), kind));
2689 {
2690 OpBuilder::InsertionGuard guard(builder);
2691 Region *readyRegion = result.addRegion();
2692 builder.createBlock(readyRegion);
2693 readyBuilder(builder, result.location);
2694 }
2695
2696 {
2697 OpBuilder::InsertionGuard guard(builder);
2698 Region *suspendRegion = result.addRegion();
2699 builder.createBlock(suspendRegion);
2700 suspendBuilder(builder, result.location);
2701 }
2702
2703 {
2704 OpBuilder::InsertionGuard guard(builder);
2705 Region *resumeRegion = result.addRegion();
2706 builder.createBlock(resumeRegion);
2707 resumeBuilder(builder, result.location);
2708 }
2709}
2710
2711void cir::AwaitOp::getSuccessorRegions(
2712 mlir::RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
2713 // If any index all the underlying regions branch back to the parent
2714 // operation.
2715 if (!point.isParent()) {
2716 regions.push_back(RegionSuccessor::parent());
2717 return;
2718 }
2719
2720 // TODO: retrieve information from the promise and only push the
2721 // necessary ones. Example: `std::suspend_never` on initial or final
2722 // await's might allow suspend region to be skipped.
2723 regions.push_back(RegionSuccessor(&this->getReady()));
2724 regions.push_back(RegionSuccessor(&this->getSuspend()));
2725 regions.push_back(RegionSuccessor(&this->getResume()));
2726}
2727
2728mlir::ValueRange cir::AwaitOp::getSuccessorInputs(RegionSuccessor successor) {
2729 if (successor.isParent())
2730 return getOperation()->getResults();
2731 if (successor == &getReady())
2732 return getReady().getArguments();
2733 if (successor == &getSuspend())
2734 return getSuspend().getArguments();
2735 if (successor == &getResume())
2736 return getResume().getArguments();
2737 llvm_unreachable("invalid region successor");
2738}
2739
2740LogicalResult cir::AwaitOp::verify() {
2741 if (!isa<ConditionOp>(this->getReady().back().getTerminator()))
2742 return emitOpError("ready region must end with cir.condition");
2743 return success();
2744}
2745
2746//===----------------------------------------------------------------------===//
2747// CopyOp Definitions
2748//===----------------------------------------------------------------------===//
2749
2750LogicalResult cir::CopyOp::verify() {
2751 // A data layout is required for us to know the number of bytes to be copied.
2752 if (!getType().getPointee().hasTrait<DataLayoutTypeInterface::Trait>())
2753 return emitError() << "missing data layout for pointee type";
2754
2755 if (getSrc() == getDst())
2756 return emitError() << "source and destination are the same";
2757
2758 return mlir::success();
2759}
2760
2761//===----------------------------------------------------------------------===//
2762// GetRuntimeMemberOp Definitions
2763//===----------------------------------------------------------------------===//
2764
2765LogicalResult cir::GetRuntimeMemberOp::verify() {
2766 auto recordTy = mlir::cast<RecordType>(getAddr().getType().getPointee());
2767 cir::DataMemberType memberPtrTy = getMember().getType();
2768
2769 if (recordTy != memberPtrTy.getClassTy())
2770 return emitError() << "record type does not match the member pointer type";
2771 if (getType().getPointee() != memberPtrTy.getMemberTy())
2772 return emitError() << "result type does not match the member pointer type";
2773 return mlir::success();
2774}
2775
2776//===----------------------------------------------------------------------===//
2777// GetMethodOp Definitions
2778//===----------------------------------------------------------------------===//
2779
2780LogicalResult cir::GetMethodOp::verify() {
2781 cir::MethodType methodTy = getMethod().getType();
2782
2783 // Assume objectTy is !cir.ptr<!T>
2784 cir::PointerType objectPtrTy = getObject().getType();
2785 mlir::Type objectTy = objectPtrTy.getPointee();
2786
2787 if (methodTy.getClassTy() != objectTy)
2788 return emitError() << "method class type and object type do not match";
2789
2790 // Assume methodFuncTy is !cir.func<!Ret (!Args)>
2791 auto calleeTy = mlir::cast<cir::FuncType>(getCallee().getType().getPointee());
2792 cir::FuncType methodFuncTy = methodTy.getMemberFuncTy();
2793
2794 // We verify at here that calleeTy is !cir.func<!Ret (!cir.ptr<!void>, !Args)>
2795 // Note that the first parameter type of the callee is !cir.ptr<!void> instead
2796 // of !cir.ptr<!T> because the "this" pointer may be adjusted before calling
2797 // the callee.
2798
2799 if (methodFuncTy.getReturnType() != calleeTy.getReturnType())
2800 return emitError()
2801 << "method return type and callee return type do not match";
2802
2803 llvm::ArrayRef<mlir::Type> calleeArgsTy = calleeTy.getInputs();
2804 llvm::ArrayRef<mlir::Type> methodFuncArgsTy = methodFuncTy.getInputs();
2805
2806 if (calleeArgsTy.empty())
2807 return emitError() << "callee parameter list lacks receiver object ptr";
2808
2809 auto calleeThisArgPtrTy = mlir::dyn_cast<cir::PointerType>(calleeArgsTy[0]);
2810 if (!calleeThisArgPtrTy ||
2811 !mlir::isa<cir::VoidType>(calleeThisArgPtrTy.getPointee())) {
2812 return emitError()
2813 << "the first parameter of callee must be a void pointer";
2814 }
2815
2816 if (calleeArgsTy.slice(1) != methodFuncArgsTy)
2817 return emitError()
2818 << "callee parameters and method parameters do not match";
2819
2820 return mlir::success();
2821}
2822
2823//===----------------------------------------------------------------------===//
2824// GetMemberOp Definitions
2825//===----------------------------------------------------------------------===//
2826
2827LogicalResult cir::GetMemberOp::verify() {
2828 const auto recordTy = dyn_cast<RecordType>(getAddrTy().getPointee());
2829 if (!recordTy)
2830 return emitError() << "expected pointer to a record type";
2831
2832 if (recordTy.getMembers().size() <= getIndex())
2833 return emitError() << "member index out of bounds";
2834
2835 if (recordTy.getMembers()[getIndex()] != getType().getPointee())
2836 return emitError() << "member type mismatch";
2837
2838 return mlir::success();
2839}
2840
2841//===----------------------------------------------------------------------===//
2842// ExtractMemberOp Definitions
2843//===----------------------------------------------------------------------===//
2844
2845LogicalResult cir::ExtractMemberOp::verify() {
2846 auto recordTy = mlir::cast<cir::RecordType>(getRecord().getType());
2847 if (recordTy.getKind() == cir::RecordType::Union)
2848 return emitError()
2849 << "cir.extract_member currently does not support unions";
2850 if (recordTy.getMembers().size() <= getIndex())
2851 return emitError() << "member index out of bounds";
2852 if (recordTy.getMembers()[getIndex()] != getType())
2853 return emitError() << "member type mismatch";
2854 return mlir::success();
2855}
2856
2857//===----------------------------------------------------------------------===//
2858// InsertMemberOp Definitions
2859//===----------------------------------------------------------------------===//
2860
2861LogicalResult cir::InsertMemberOp::verify() {
2862 auto recordTy = mlir::cast<cir::RecordType>(getRecord().getType());
2863 if (recordTy.getKind() == cir::RecordType::Union)
2864 return emitError() << "cir.insert_member currently does not support unions";
2865 if (recordTy.getMembers().size() <= getIndex())
2866 return emitError() << "member index out of bounds";
2867 if (recordTy.getMembers()[getIndex()] != getValue().getType())
2868 return emitError() << "member type mismatch";
2869 // The op trait already checks that the types of $result and $record match.
2870 return mlir::success();
2871}
2872
2873//===----------------------------------------------------------------------===//
2874// VecCreateOp
2875//===----------------------------------------------------------------------===//
2876
2877OpFoldResult cir::VecCreateOp::fold(FoldAdaptor adaptor) {
2878 if (llvm::any_of(getElements(), [](mlir::Value value) {
2879 return !value.getDefiningOp<cir::ConstantOp>();
2880 }))
2881 return {};
2882
2883 return cir::ConstVectorAttr::get(
2884 getType(), mlir::ArrayAttr::get(getContext(), adaptor.getElements()));
2885}
2886
2887LogicalResult cir::VecCreateOp::verify() {
2888 // Verify that the number of arguments matches the number of elements in the
2889 // vector, and that the type of all the arguments matches the type of the
2890 // elements in the vector.
2891 const cir::VectorType vecTy = getType();
2892 if (getElements().size() != vecTy.getSize()) {
2893 return emitOpError() << "operand count of " << getElements().size()
2894 << " doesn't match vector type " << vecTy
2895 << " element count of " << vecTy.getSize();
2896 }
2897
2898 const mlir::Type elementType = vecTy.getElementType();
2899 for (const mlir::Value element : getElements()) {
2900 if (element.getType() != elementType) {
2901 return emitOpError() << "operand type " << element.getType()
2902 << " doesn't match vector element type "
2903 << elementType;
2904 }
2905 }
2906
2907 return success();
2908}
2909
2910//===----------------------------------------------------------------------===//
2911// VecExtractOp
2912//===----------------------------------------------------------------------===//
2913
2914OpFoldResult cir::VecExtractOp::fold(FoldAdaptor adaptor) {
2915 const auto vectorAttr =
2916 llvm::dyn_cast_if_present<cir::ConstVectorAttr>(adaptor.getVec());
2917 if (!vectorAttr)
2918 return {};
2919
2920 const auto indexAttr =
2921 llvm::dyn_cast_if_present<cir::IntAttr>(adaptor.getIndex());
2922 if (!indexAttr)
2923 return {};
2924
2925 const mlir::ArrayAttr elements = vectorAttr.getElts();
2926 const uint64_t index = indexAttr.getUInt();
2927 if (index >= elements.size())
2928 return {};
2929
2930 return elements[index];
2931}
2932
2933//===----------------------------------------------------------------------===//
2934// VecCmpOp
2935//===----------------------------------------------------------------------===//
2936
2937OpFoldResult cir::VecCmpOp::fold(FoldAdaptor adaptor) {
2938 auto lhsVecAttr =
2939 mlir::dyn_cast_if_present<cir::ConstVectorAttr>(adaptor.getLhs());
2940 auto rhsVecAttr =
2941 mlir::dyn_cast_if_present<cir::ConstVectorAttr>(adaptor.getRhs());
2942 if (!lhsVecAttr || !rhsVecAttr)
2943 return {};
2944
2945 mlir::Type inputElemTy =
2946 mlir::cast<cir::VectorType>(lhsVecAttr.getType()).getElementType();
2947 if (!isAnyIntegerOrFloatingPointType(inputElemTy))
2948 return {};
2949
2950 cir::CmpOpKind opKind = adaptor.getKind();
2951 mlir::ArrayAttr lhsVecElhs = lhsVecAttr.getElts();
2952 mlir::ArrayAttr rhsVecElhs = rhsVecAttr.getElts();
2953 uint64_t vecSize = lhsVecElhs.size();
2954
2955 SmallVector<mlir::Attribute, 16> elements(vecSize);
2956 bool isIntAttr = vecSize && mlir::isa<cir::IntAttr>(lhsVecElhs[0]);
2957 for (uint64_t i = 0; i < vecSize; i++) {
2958 mlir::Attribute lhsAttr = lhsVecElhs[i];
2959 mlir::Attribute rhsAttr = rhsVecElhs[i];
2960 int cmpResult = 0;
2961 switch (opKind) {
2962 case cir::CmpOpKind::lt: {
2963 if (isIntAttr) {
2964 cmpResult = mlir::cast<cir::IntAttr>(lhsAttr).getSInt() <
2965 mlir::cast<cir::IntAttr>(rhsAttr).getSInt();
2966 } else {
2967 cmpResult = mlir::cast<cir::FPAttr>(lhsAttr).getValue() <
2968 mlir::cast<cir::FPAttr>(rhsAttr).getValue();
2969 }
2970 break;
2971 }
2972 case cir::CmpOpKind::le: {
2973 if (isIntAttr) {
2974 cmpResult = mlir::cast<cir::IntAttr>(lhsAttr).getSInt() <=
2975 mlir::cast<cir::IntAttr>(rhsAttr).getSInt();
2976 } else {
2977 cmpResult = mlir::cast<cir::FPAttr>(lhsAttr).getValue() <=
2978 mlir::cast<cir::FPAttr>(rhsAttr).getValue();
2979 }
2980 break;
2981 }
2982 case cir::CmpOpKind::gt: {
2983 if (isIntAttr) {
2984 cmpResult = mlir::cast<cir::IntAttr>(lhsAttr).getSInt() >
2985 mlir::cast<cir::IntAttr>(rhsAttr).getSInt();
2986 } else {
2987 cmpResult = mlir::cast<cir::FPAttr>(lhsAttr).getValue() >
2988 mlir::cast<cir::FPAttr>(rhsAttr).getValue();
2989 }
2990 break;
2991 }
2992 case cir::CmpOpKind::ge: {
2993 if (isIntAttr) {
2994 cmpResult = mlir::cast<cir::IntAttr>(lhsAttr).getSInt() >=
2995 mlir::cast<cir::IntAttr>(rhsAttr).getSInt();
2996 } else {
2997 cmpResult = mlir::cast<cir::FPAttr>(lhsAttr).getValue() >=
2998 mlir::cast<cir::FPAttr>(rhsAttr).getValue();
2999 }
3000 break;
3001 }
3002 case cir::CmpOpKind::eq: {
3003 if (isIntAttr) {
3004 cmpResult = mlir::cast<cir::IntAttr>(lhsAttr).getSInt() ==
3005 mlir::cast<cir::IntAttr>(rhsAttr).getSInt();
3006 } else {
3007 cmpResult = mlir::cast<cir::FPAttr>(lhsAttr).getValue() ==
3008 mlir::cast<cir::FPAttr>(rhsAttr).getValue();
3009 }
3010 break;
3011 }
3012 case cir::CmpOpKind::ne: {
3013 if (isIntAttr) {
3014 cmpResult = mlir::cast<cir::IntAttr>(lhsAttr).getSInt() !=
3015 mlir::cast<cir::IntAttr>(rhsAttr).getSInt();
3016 } else {
3017 cmpResult = mlir::cast<cir::FPAttr>(lhsAttr).getValue() !=
3018 mlir::cast<cir::FPAttr>(rhsAttr).getValue();
3019 }
3020 break;
3021 }
3022 }
3023
3024 elements[i] = cir::IntAttr::get(getType().getElementType(), cmpResult);
3025 }
3026
3027 return cir::ConstVectorAttr::get(
3028 getType(), mlir::ArrayAttr::get(getContext(), elements));
3029}
3030
3031//===----------------------------------------------------------------------===//
3032// VecShuffleOp
3033//===----------------------------------------------------------------------===//
3034
3035OpFoldResult cir::VecShuffleOp::fold(FoldAdaptor adaptor) {
3036 auto vec1Attr =
3037 mlir::dyn_cast_if_present<cir::ConstVectorAttr>(adaptor.getVec1());
3038 auto vec2Attr =
3039 mlir::dyn_cast_if_present<cir::ConstVectorAttr>(adaptor.getVec2());
3040 if (!vec1Attr || !vec2Attr)
3041 return {};
3042
3043 mlir::Type vec1ElemTy =
3044 mlir::cast<cir::VectorType>(vec1Attr.getType()).getElementType();
3045
3046 mlir::ArrayAttr vec1Elts = vec1Attr.getElts();
3047 mlir::ArrayAttr vec2Elts = vec2Attr.getElts();
3048 mlir::ArrayAttr indicesElts = adaptor.getIndices();
3049
3051 elements.reserve(indicesElts.size());
3052
3053 uint64_t vec1Size = vec1Elts.size();
3054 for (const auto &idxAttr : indicesElts.getAsRange<cir::IntAttr>()) {
3055 if (idxAttr.getSInt() == -1) {
3056 elements.push_back(cir::UndefAttr::get(vec1ElemTy));
3057 continue;
3058 }
3059
3060 uint64_t idxValue = idxAttr.getUInt();
3061 elements.push_back(idxValue < vec1Size ? vec1Elts[idxValue]
3062 : vec2Elts[idxValue - vec1Size]);
3063 }
3064
3065 return cir::ConstVectorAttr::get(
3066 getType(), mlir::ArrayAttr::get(getContext(), elements));
3067}
3068
3069LogicalResult cir::VecShuffleOp::verify() {
3070 // The number of elements in the indices array must match the number of
3071 // elements in the result type.
3072 if (getIndices().size() != getResult().getType().getSize()) {
3073 return emitOpError() << ": the number of elements in " << getIndices()
3074 << " and " << getResult().getType() << " don't match";
3075 }
3076
3077 // The element types of the two input vectors and of the result type must
3078 // match.
3079 if (getVec1().getType().getElementType() !=
3080 getResult().getType().getElementType()) {
3081 return emitOpError() << ": element types of " << getVec1().getType()
3082 << " and " << getResult().getType() << " don't match";
3083 }
3084
3085 const uint64_t maxValidIndex =
3086 getVec1().getType().getSize() + getVec2().getType().getSize() - 1;
3087 if (llvm::any_of(
3088 getIndices().getAsRange<cir::IntAttr>(), [&](cir::IntAttr idxAttr) {
3089 return idxAttr.getSInt() != -1 && idxAttr.getUInt() > maxValidIndex;
3090 })) {
3091 return emitOpError() << ": index for __builtin_shufflevector must be "
3092 "less than the total number of vector elements";
3093 }
3094 return success();
3095}
3096
3097//===----------------------------------------------------------------------===//
3098// VecShuffleDynamicOp
3099//===----------------------------------------------------------------------===//
3100
3101OpFoldResult cir::VecShuffleDynamicOp::fold(FoldAdaptor adaptor) {
3102 mlir::Attribute vec = adaptor.getVec();
3103 mlir::Attribute indices = adaptor.getIndices();
3104 if (mlir::isa_and_nonnull<cir::ConstVectorAttr>(vec) &&
3105 mlir::isa_and_nonnull<cir::ConstVectorAttr>(indices)) {
3106 auto vecAttr = mlir::cast<cir::ConstVectorAttr>(vec);
3107 auto indicesAttr = mlir::cast<cir::ConstVectorAttr>(indices);
3108
3109 mlir::ArrayAttr vecElts = vecAttr.getElts();
3110 mlir::ArrayAttr indicesElts = indicesAttr.getElts();
3111
3112 const uint64_t numElements = vecElts.size();
3113
3115 elements.reserve(numElements);
3116
3117 const uint64_t maskBits = llvm::NextPowerOf2(numElements - 1) - 1;
3118 for (const auto &idxAttr : indicesElts.getAsRange<cir::IntAttr>()) {
3119 uint64_t idxValue = idxAttr.getUInt();
3120 uint64_t newIdx = idxValue & maskBits;
3121 elements.push_back(vecElts[newIdx]);
3122 }
3123
3124 return cir::ConstVectorAttr::get(
3125 getType(), mlir::ArrayAttr::get(getContext(), elements));
3126 }
3127
3128 return {};
3129}
3130
3131LogicalResult cir::VecShuffleDynamicOp::verify() {
3132 // The number of elements in the two input vectors must match.
3133 if (getVec().getType().getSize() !=
3134 mlir::cast<cir::VectorType>(getIndices().getType()).getSize()) {
3135 return emitOpError() << ": the number of elements in " << getVec().getType()
3136 << " and " << getIndices().getType() << " don't match";
3137 }
3138 return success();
3139}
3140
3141//===----------------------------------------------------------------------===//
3142// VecTernaryOp
3143//===----------------------------------------------------------------------===//
3144
3145LogicalResult cir::VecTernaryOp::verify() {
3146 // Verify that the condition operand has the same number of elements as the
3147 // other operands. (The automatic verification already checked that all
3148 // operands are vector types and that the second and third operands are the
3149 // same type.)
3150 if (getCond().getType().getSize() != getLhs().getType().getSize()) {
3151 return emitOpError() << ": the number of elements in "
3152 << getCond().getType() << " and " << getLhs().getType()
3153 << " don't match";
3154 }
3155 return success();
3156}
3157
3158OpFoldResult cir::VecTernaryOp::fold(FoldAdaptor adaptor) {
3159 mlir::Attribute cond = adaptor.getCond();
3160 mlir::Attribute lhs = adaptor.getLhs();
3161 mlir::Attribute rhs = adaptor.getRhs();
3162
3163 if (!mlir::isa_and_nonnull<cir::ConstVectorAttr>(cond) ||
3164 !mlir::isa_and_nonnull<cir::ConstVectorAttr>(lhs) ||
3165 !mlir::isa_and_nonnull<cir::ConstVectorAttr>(rhs))
3166 return {};
3167 auto condVec = mlir::cast<cir::ConstVectorAttr>(cond);
3168 auto lhsVec = mlir::cast<cir::ConstVectorAttr>(lhs);
3169 auto rhsVec = mlir::cast<cir::ConstVectorAttr>(rhs);
3170
3171 mlir::ArrayAttr condElts = condVec.getElts();
3172
3174 elements.reserve(condElts.size());
3175
3176 for (const auto &[idx, condAttr] :
3177 llvm::enumerate(condElts.getAsRange<cir::IntAttr>())) {
3178 if (condAttr.getSInt()) {
3179 elements.push_back(lhsVec.getElts()[idx]);
3180 } else {
3181 elements.push_back(rhsVec.getElts()[idx]);
3182 }
3183 }
3184
3185 cir::VectorType vecTy = getLhs().getType();
3186 return cir::ConstVectorAttr::get(
3187 vecTy, mlir::ArrayAttr::get(getContext(), elements));
3188}
3189
3190//===----------------------------------------------------------------------===//
3191// ComplexCreateOp
3192//===----------------------------------------------------------------------===//
3193
3194LogicalResult cir::ComplexCreateOp::verify() {
3195 if (getType().getElementType() != getReal().getType()) {
3196 emitOpError()
3197 << "operand type of cir.complex.create does not match its result type";
3198 return failure();
3199 }
3200
3201 return success();
3202}
3203
3204OpFoldResult cir::ComplexCreateOp::fold(FoldAdaptor adaptor) {
3205 mlir::Attribute real = adaptor.getReal();
3206 mlir::Attribute imag = adaptor.getImag();
3207 if (!real || !imag)
3208 return {};
3209
3210 // When both of real and imag are constants, we can fold the operation into an
3211 // `#cir.const_complex` operation.
3212 auto realAttr = mlir::cast<mlir::TypedAttr>(real);
3213 auto imagAttr = mlir::cast<mlir::TypedAttr>(imag);
3214 return cir::ConstComplexAttr::get(realAttr, imagAttr);
3215}
3216
3217//===----------------------------------------------------------------------===//
3218// ComplexRealOp
3219//===----------------------------------------------------------------------===//
3220
3221LogicalResult cir::ComplexRealOp::verify() {
3222 mlir::Type operandTy = getOperand().getType();
3223 if (auto complexOperandTy = mlir::dyn_cast<cir::ComplexType>(operandTy))
3224 operandTy = complexOperandTy.getElementType();
3225
3226 if (getType() != operandTy) {
3227 emitOpError() << ": result type does not match operand type";
3228 return failure();
3229 }
3230
3231 return success();
3232}
3233
3234OpFoldResult cir::ComplexRealOp::fold(FoldAdaptor adaptor) {
3235 if (!mlir::isa<cir::ComplexType>(getOperand().getType()))
3236 return nullptr;
3237
3238 if (auto complexCreateOp = getOperand().getDefiningOp<cir::ComplexCreateOp>())
3239 return complexCreateOp.getOperand(0);
3240
3241 auto complex =
3242 mlir::cast_if_present<cir::ConstComplexAttr>(adaptor.getOperand());
3243 return complex ? complex.getReal() : nullptr;
3244}
3245
3246//===----------------------------------------------------------------------===//
3247// ComplexImagOp
3248//===----------------------------------------------------------------------===//
3249
3250LogicalResult cir::ComplexImagOp::verify() {
3251 mlir::Type operandTy = getOperand().getType();
3252 if (auto complexOperandTy = mlir::dyn_cast<cir::ComplexType>(operandTy))
3253 operandTy = complexOperandTy.getElementType();
3254
3255 if (getType() != operandTy) {
3256 emitOpError() << ": result type does not match operand type";
3257 return failure();
3258 }
3259
3260 return success();
3261}
3262
3263OpFoldResult cir::ComplexImagOp::fold(FoldAdaptor adaptor) {
3264 if (!mlir::isa<cir::ComplexType>(getOperand().getType()))
3265 return nullptr;
3266
3267 if (auto complexCreateOp = getOperand().getDefiningOp<cir::ComplexCreateOp>())
3268 return complexCreateOp.getOperand(1);
3269
3270 auto complex =
3271 mlir::cast_if_present<cir::ConstComplexAttr>(adaptor.getOperand());
3272 return complex ? complex.getImag() : nullptr;
3273}
3274
3275//===----------------------------------------------------------------------===//
3276// ComplexRealPtrOp
3277//===----------------------------------------------------------------------===//
3278
3279LogicalResult cir::ComplexRealPtrOp::verify() {
3280 mlir::Type resultPointeeTy = getType().getPointee();
3281 cir::PointerType operandPtrTy = getOperand().getType();
3282 auto operandPointeeTy =
3283 mlir::cast<cir::ComplexType>(operandPtrTy.getPointee());
3284
3285 if (resultPointeeTy != operandPointeeTy.getElementType()) {
3286 return emitOpError() << ": result type does not match operand type";
3287 }
3288
3289 return success();
3290}
3291
3292//===----------------------------------------------------------------------===//
3293// ComplexImagPtrOp
3294//===----------------------------------------------------------------------===//
3295
3296LogicalResult cir::ComplexImagPtrOp::verify() {
3297 mlir::Type resultPointeeTy = getType().getPointee();
3298 cir::PointerType operandPtrTy = getOperand().getType();
3299 auto operandPointeeTy =
3300 mlir::cast<cir::ComplexType>(operandPtrTy.getPointee());
3301
3302 if (resultPointeeTy != operandPointeeTy.getElementType()) {
3303 return emitOpError()
3304 << "cir.complex.imag_ptr result type does not match operand type";
3305 }
3306 return success();
3307}
3308
3309//===----------------------------------------------------------------------===//
3310// Bit manipulation operations
3311//===----------------------------------------------------------------------===//
3312
3313static OpFoldResult
3314foldUnaryBitOp(mlir::Attribute inputAttr,
3315 llvm::function_ref<llvm::APInt(const llvm::APInt &)> func,
3316 bool poisonZero = false) {
3317 if (mlir::isa_and_present<cir::PoisonAttr>(inputAttr)) {
3318 // Propagate poison value
3319 return inputAttr;
3320 }
3321
3322 auto input = mlir::dyn_cast_if_present<IntAttr>(inputAttr);
3323 if (!input)
3324 return nullptr;
3325
3326 llvm::APInt inputValue = input.getValue();
3327 if (poisonZero && inputValue.isZero())
3328 return cir::PoisonAttr::get(input.getType());
3329
3330 llvm::APInt resultValue = func(inputValue);
3331 return IntAttr::get(input.getType(), resultValue);
3332}
3333
3334OpFoldResult BitClrsbOp::fold(FoldAdaptor adaptor) {
3335 return foldUnaryBitOp(adaptor.getInput(), [](const llvm::APInt &inputValue) {
3336 unsigned resultValue =
3337 inputValue.getBitWidth() - inputValue.getSignificantBits();
3338 return llvm::APInt(inputValue.getBitWidth(), resultValue);
3339 });
3340}
3341
3342OpFoldResult BitClzOp::fold(FoldAdaptor adaptor) {
3343 return foldUnaryBitOp(
3344 adaptor.getInput(),
3345 [](const llvm::APInt &inputValue) {
3346 unsigned resultValue = inputValue.countLeadingZeros();
3347 return llvm::APInt(inputValue.getBitWidth(), resultValue);
3348 },
3349 getPoisonZero());
3350}
3351
3352OpFoldResult BitCtzOp::fold(FoldAdaptor adaptor) {
3353 return foldUnaryBitOp(
3354 adaptor.getInput(),
3355 [](const llvm::APInt &inputValue) {
3356 return llvm::APInt(inputValue.getBitWidth(),
3357 inputValue.countTrailingZeros());
3358 },
3359 getPoisonZero());
3360}
3361
3362OpFoldResult BitFfsOp::fold(FoldAdaptor adaptor) {
3363 return foldUnaryBitOp(adaptor.getInput(), [](const llvm::APInt &inputValue) {
3364 unsigned trailingZeros = inputValue.countTrailingZeros();
3365 unsigned result =
3366 trailingZeros == inputValue.getBitWidth() ? 0 : trailingZeros + 1;
3367 return llvm::APInt(inputValue.getBitWidth(), result);
3368 });
3369}
3370
3371OpFoldResult BitParityOp::fold(FoldAdaptor adaptor) {
3372 return foldUnaryBitOp(adaptor.getInput(), [](const llvm::APInt &inputValue) {
3373 return llvm::APInt(inputValue.getBitWidth(), inputValue.popcount() % 2);
3374 });
3375}
3376
3377OpFoldResult BitPopcountOp::fold(FoldAdaptor adaptor) {
3378 return foldUnaryBitOp(adaptor.getInput(), [](const llvm::APInt &inputValue) {
3379 return llvm::APInt(inputValue.getBitWidth(), inputValue.popcount());
3380 });
3381}
3382
3383OpFoldResult BitReverseOp::fold(FoldAdaptor adaptor) {
3384 return foldUnaryBitOp(adaptor.getInput(), [](const llvm::APInt &inputValue) {
3385 return inputValue.reverseBits();
3386 });
3387}
3388
3389OpFoldResult ByteSwapOp::fold(FoldAdaptor adaptor) {
3390 return foldUnaryBitOp(adaptor.getInput(), [](const llvm::APInt &inputValue) {
3391 return inputValue.byteSwap();
3392 });
3393}
3394
3395OpFoldResult RotateOp::fold(FoldAdaptor adaptor) {
3396 if (mlir::isa_and_present<cir::PoisonAttr>(adaptor.getInput()) ||
3397 mlir::isa_and_present<cir::PoisonAttr>(adaptor.getAmount())) {
3398 // Propagate poison values
3399 return cir::PoisonAttr::get(getType());
3400 }
3401
3402 auto input = mlir::dyn_cast_if_present<IntAttr>(adaptor.getInput());
3403 auto amount = mlir::dyn_cast_if_present<IntAttr>(adaptor.getAmount());
3404 if (!input && !amount)
3405 return nullptr;
3406
3407 // We could fold cir.rotate even if one of its two operands is not a constant:
3408 // - `cir.rotate left/right %0, 0` could be folded into just %0 even if %0
3409 // is not a constant.
3410 // - `cir.rotate left/right 0/0b111...111, %0` could be folded into 0 or
3411 // 0b111...111 even if %0 is not a constant.
3412
3413 llvm::APInt inputValue;
3414 if (input) {
3415 inputValue = input.getValue();
3416 if (inputValue.isZero() || inputValue.isAllOnes()) {
3417 // An input value of all 0s or all 1s will not change after rotation
3418 return input;
3419 }
3420 }
3421
3422 uint64_t amountValue;
3423 if (amount) {
3424 amountValue = amount.getValue().urem(getInput().getType().getWidth());
3425 if (amountValue == 0) {
3426 // A shift amount of 0 will not change the input value
3427 return getInput();
3428 }
3429 }
3430
3431 if (!input || !amount)
3432 return nullptr;
3433
3434 assert(inputValue.getBitWidth() == getInput().getType().getWidth() &&
3435 "input value must have the same bit width as the input type");
3436
3437 llvm::APInt resultValue;
3438 if (isRotateLeft())
3439 resultValue = inputValue.rotl(amountValue);
3440 else
3441 resultValue = inputValue.rotr(amountValue);
3442
3443 return IntAttr::get(input.getContext(), input.getType(), resultValue);
3444}
3445
3446//===----------------------------------------------------------------------===//
3447// InlineAsmOp
3448//===----------------------------------------------------------------------===//
3449
3450void cir::InlineAsmOp::print(OpAsmPrinter &p) {
3451 p << '(' << getAsmFlavor() << ", ";
3452 p.increaseIndent();
3453 p.printNewline();
3454
3455 llvm::SmallVector<std::string, 3> names{"out", "in", "in_out"};
3456 auto *nameIt = names.begin();
3457 auto *attrIt = getOperandAttrs().begin();
3458
3459 for (mlir::OperandRange ops : getAsmOperands()) {
3460 p << *nameIt << " = ";
3461
3462 p << '[';
3463 llvm::interleaveComma(llvm::make_range(ops.begin(), ops.end()), p,
3464 [&](Value value) {
3465 p.printOperand(value);
3466 p << " : " << value.getType();
3467 if (*attrIt)
3468 p << " (maybe_memory)";
3469 attrIt++;
3470 });
3471 p << "],";
3472 p.printNewline();
3473 ++nameIt;
3474 }
3475
3476 p << "{";
3477 p.printString(getAsmString());
3478 p << " ";
3479 p.printString(getConstraints());
3480 p << "}";
3481 p.decreaseIndent();
3482 p << ')';
3483 if (getSideEffects())
3484 p << " side_effects";
3485
3486 std::array elidedAttrs{
3487 llvm::StringRef("asm_flavor"), llvm::StringRef("asm_string"),
3488 llvm::StringRef("constraints"), llvm::StringRef("operand_attrs"),
3489 llvm::StringRef("operands_segments"), llvm::StringRef("side_effects")};
3490 p.printOptionalAttrDict(getOperation()->getAttrs(), elidedAttrs);
3491
3492 if (auto v = getRes())
3493 p << " -> " << v.getType();
3494}
3495
3496void cir::InlineAsmOp::build(OpBuilder &odsBuilder, OperationState &odsState,
3497 ArrayRef<ValueRange> asmOperands,
3498 StringRef asmString, StringRef constraints,
3499 bool sideEffects, cir::AsmFlavor asmFlavor,
3500 ArrayRef<Attribute> operandAttrs) {
3501 // Set up the operands_segments for VariadicOfVariadic
3502 SmallVector<int32_t> segments;
3503 for (auto operandRange : asmOperands) {
3504 segments.push_back(operandRange.size());
3505 odsState.addOperands(operandRange);
3506 }
3507
3508 odsState.addAttribute(
3509 "operands_segments",
3510 DenseI32ArrayAttr::get(odsBuilder.getContext(), segments));
3511 odsState.addAttribute("asm_string", odsBuilder.getStringAttr(asmString));
3512 odsState.addAttribute("constraints", odsBuilder.getStringAttr(constraints));
3513 odsState.addAttribute("asm_flavor",
3514 AsmFlavorAttr::get(odsBuilder.getContext(), asmFlavor));
3515
3516 if (sideEffects)
3517 odsState.addAttribute("side_effects", odsBuilder.getUnitAttr());
3518
3519 odsState.addAttribute("operand_attrs", odsBuilder.getArrayAttr(operandAttrs));
3520}
3521
3522ParseResult cir::InlineAsmOp::parse(OpAsmParser &parser,
3523 OperationState &result) {
3525 llvm::SmallVector<int32_t> operandsGroupSizes;
3526 std::string asmString, constraints;
3527 Type resType;
3528 MLIRContext *ctxt = parser.getBuilder().getContext();
3529
3530 auto error = [&](const Twine &msg) -> LogicalResult {
3531 return parser.emitError(parser.getCurrentLocation(), msg);
3532 };
3533
3534 auto expected = [&](const std::string &c) {
3535 return error("expected '" + c + "'");
3536 };
3537
3538 if (parser.parseLParen().failed())
3539 return expected("(");
3540
3541 auto flavor = FieldParser<AsmFlavor, AsmFlavor>::parse(parser);
3542 if (failed(flavor))
3543 return error("Unknown AsmFlavor");
3544
3545 if (parser.parseComma().failed())
3546 return expected(",");
3547
3548 auto parseValue = [&](Value &v) {
3549 OpAsmParser::UnresolvedOperand op;
3550
3551 if (parser.parseOperand(op) || parser.parseColon())
3552 return error("can't parse operand");
3553
3554 Type typ;
3555 if (parser.parseType(typ).failed())
3556 return error("can't parse operand type");
3558 if (parser.resolveOperand(op, typ, tmp))
3559 return error("can't resolve operand");
3560 v = tmp[0];
3561 return mlir::success();
3562 };
3563
3564 auto parseOperands = [&](llvm::StringRef name) {
3565 if (parser.parseKeyword(name).failed())
3566 return error("expected " + name + " operands here");
3567 if (parser.parseEqual().failed())
3568 return expected("=");
3569 if (parser.parseLSquare().failed())
3570 return expected("[");
3571
3572 int size = 0;
3573 if (parser.parseOptionalRSquare().succeeded()) {
3574 operandsGroupSizes.push_back(size);
3575 if (parser.parseComma())
3576 return expected(",");
3577 return mlir::success();
3578 }
3579
3580 auto parseOperand = [&]() {
3581 Value val;
3582 if (parseValue(val).succeeded()) {
3583 result.operands.push_back(val);
3584 size++;
3585
3586 if (parser.parseOptionalLParen().failed()) {
3587 operandAttrs.push_back(mlir::Attribute());
3588 return mlir::success();
3589 }
3590
3591 if (parser.parseKeyword("maybe_memory").succeeded()) {
3592 operandAttrs.push_back(mlir::UnitAttr::get(ctxt));
3593 if (parser.parseRParen())
3594 return expected(")");
3595 return mlir::success();
3596 } else {
3597 return expected("maybe_memory");
3598 }
3599 }
3600 return mlir::failure();
3601 };
3602
3603 if (parser.parseCommaSeparatedList(parseOperand).failed())
3604 return mlir::failure();
3605
3606 if (parser.parseRSquare().failed() || parser.parseComma().failed())
3607 return expected("]");
3608 operandsGroupSizes.push_back(size);
3609 return mlir::success();
3610 };
3611
3612 if (parseOperands("out").failed() || parseOperands("in").failed() ||
3613 parseOperands("in_out").failed())
3614 return error("failed to parse operands");
3615
3616 if (parser.parseLBrace())
3617 return expected("{");
3618 if (parser.parseString(&asmString))
3619 return error("asm string parsing failed");
3620 if (parser.parseString(&constraints))
3621 return error("constraints string parsing failed");
3622 if (parser.parseRBrace())
3623 return expected("}");
3624 if (parser.parseRParen())
3625 return expected(")");
3626
3627 if (parser.parseOptionalKeyword("side_effects").succeeded())
3628 result.attributes.set("side_effects", UnitAttr::get(ctxt));
3629
3630 if (parser.parseOptionalArrow().succeeded() &&
3631 parser.parseType(resType).failed())
3632 return mlir::failure();
3633
3634 if (parser.parseOptionalAttrDict(result.attributes).failed())
3635 return mlir::failure();
3636
3637 result.attributes.set("asm_flavor", AsmFlavorAttr::get(ctxt, *flavor));
3638 result.attributes.set("asm_string", StringAttr::get(ctxt, asmString));
3639 result.attributes.set("constraints", StringAttr::get(ctxt, constraints));
3640 result.attributes.set("operand_attrs", ArrayAttr::get(ctxt, operandAttrs));
3641 result.getOrAddProperties<InlineAsmOp::Properties>().operands_segments =
3642 parser.getBuilder().getDenseI32ArrayAttr(operandsGroupSizes);
3643 if (resType)
3644 result.addTypes(TypeRange{resType});
3645
3646 return mlir::success();
3647}
3648
3649//===----------------------------------------------------------------------===//
3650// ThrowOp
3651//===----------------------------------------------------------------------===//
3652
3653mlir::LogicalResult cir::ThrowOp::verify() {
3654 // For the no-rethrow version, it must have at least the exception pointer.
3655 if (rethrows())
3656 return success();
3657
3658 if (getNumOperands() != 0) {
3659 if (getTypeInfo())
3660 return success();
3661 return emitOpError() << "'type_info' symbol attribute missing";
3662 }
3663
3664 return failure();
3665}
3666
3667//===----------------------------------------------------------------------===//
3668// AtomicFetchOp
3669//===----------------------------------------------------------------------===//
3670
3671LogicalResult cir::AtomicFetchOp::verify() {
3672 if (getBinop() != cir::AtomicFetchKind::Add &&
3673 getBinop() != cir::AtomicFetchKind::Sub &&
3674 getBinop() != cir::AtomicFetchKind::Max &&
3675 getBinop() != cir::AtomicFetchKind::Min &&
3676 !mlir::isa<cir::IntType>(getVal().getType()))
3677 return emitError("only atomic add, sub, max, and min operation could "
3678 "operate on floating-point values");
3679 return success();
3680}
3681
3682//===----------------------------------------------------------------------===//
3683// TypeInfoAttr
3684//===----------------------------------------------------------------------===//
3685
3686LogicalResult cir::TypeInfoAttr::verify(
3687 ::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError,
3688 ::mlir::Type type, ::mlir::ArrayAttr typeInfoData) {
3689
3690 if (cir::ConstRecordAttr::verify(emitError, type, typeInfoData).failed())
3691 return failure();
3692
3693 return success();
3694}
3695
3696//===----------------------------------------------------------------------===//
3697// TryOp
3698//===----------------------------------------------------------------------===//
3699
3700void cir::TryOp::getSuccessorRegions(
3701 mlir::RegionBranchPoint point,
3703 // The `try` and the `catchers` region branch back to the parent operation.
3704 if (!point.isParent()) {
3705 regions.push_back(RegionSuccessor::parent());
3706 return;
3707 }
3708
3709 regions.push_back(mlir::RegionSuccessor(&getTryRegion()));
3710
3711 // TODO(CIR): If we know a target function never throws a specific type, we
3712 // can remove the catch handler.
3713 for (mlir::Region &handlerRegion : this->getHandlerRegions())
3714 regions.push_back(mlir::RegionSuccessor(&handlerRegion));
3715}
3716
3717mlir::ValueRange cir::TryOp::getSuccessorInputs(RegionSuccessor successor) {
3718 return successor.isParent() ? ValueRange(getOperation()->getResults())
3719 : ValueRange();
3720}
3721
3722static void
3723printTryHandlerRegions(mlir::OpAsmPrinter &printer, cir::TryOp op,
3724 mlir::MutableArrayRef<mlir::Region> handlerRegions,
3725 mlir::ArrayAttr handlerTypes) {
3726 if (!handlerTypes)
3727 return;
3728
3729 for (const auto [typeIdx, typeAttr] : llvm::enumerate(handlerTypes)) {
3730 if (typeIdx)
3731 printer << " ";
3732
3733 if (mlir::isa<cir::CatchAllAttr>(typeAttr)) {
3734 printer << "catch all ";
3735 } else if (mlir::isa<cir::UnwindAttr>(typeAttr)) {
3736 printer << "unwind ";
3737 } else {
3738 printer << "catch [type ";
3739 printer.printAttribute(typeAttr);
3740 printer << "] ";
3741 }
3742
3743 printer.printRegion(handlerRegions[typeIdx],
3744 /*printEntryBLockArgs=*/false,
3745 /*printBlockTerminators=*/true);
3746 }
3747}
3748
3749static mlir::ParseResult parseTryHandlerRegions(
3750 mlir::OpAsmParser &parser,
3751 llvm::SmallVectorImpl<std::unique_ptr<mlir::Region>> &handlerRegions,
3752 mlir::ArrayAttr &handlerTypes) {
3753
3754 auto parseCheckedCatcherRegion = [&]() -> mlir::ParseResult {
3755 handlerRegions.emplace_back(new mlir::Region);
3756
3757 mlir::Region &currRegion = *handlerRegions.back();
3758 mlir::SMLoc regionLoc = parser.getCurrentLocation();
3759 if (parser.parseRegion(currRegion)) {
3760 handlerRegions.clear();
3761 return failure();
3762 }
3763
3764 if (currRegion.empty())
3765 return parser.emitError(regionLoc, "handler region shall not be empty");
3766
3767 if (!(currRegion.back().mightHaveTerminator() &&
3768 currRegion.back().getTerminator()))
3769 return parser.emitError(
3770 regionLoc, "blocks are expected to be explicitly terminated");
3771
3772 return success();
3773 };
3774
3775 bool hasCatchAll = false;
3777 while (parser.parseOptionalKeyword("catch").succeeded()) {
3778 bool hasLSquare = parser.parseOptionalLSquare().succeeded();
3779
3780 llvm::StringRef attrStr;
3781 if (parser.parseOptionalKeyword(&attrStr, {"all", "type"}).failed())
3782 return parser.emitError(parser.getCurrentLocation(),
3783 "expected 'all' or 'type' keyword");
3784
3785 bool isCatchAll = attrStr == "all";
3786 if (isCatchAll) {
3787 if (hasCatchAll)
3788 return parser.emitError(parser.getCurrentLocation(),
3789 "can't have more than one catch all");
3790 hasCatchAll = true;
3791 }
3792
3793 mlir::Attribute exceptionRTTIAttr;
3794 if (!isCatchAll && parser.parseAttribute(exceptionRTTIAttr).failed())
3795 return parser.emitError(parser.getCurrentLocation(),
3796 "expected valid RTTI info attribute");
3797
3798 catcherAttrs.push_back(isCatchAll
3799 ? cir::CatchAllAttr::get(parser.getContext())
3800 : exceptionRTTIAttr);
3801
3802 if (hasLSquare && isCatchAll)
3803 return parser.emitError(parser.getCurrentLocation(),
3804 "catch all dosen't need RTTI info attribute");
3805
3806 if (hasLSquare && parser.parseRSquare().failed())
3807 return parser.emitError(parser.getCurrentLocation(),
3808 "expected `]` after RTTI info attribute");
3809
3810 if (parseCheckedCatcherRegion().failed())
3811 return mlir::failure();
3812 }
3813
3814 if (parser.parseOptionalKeyword("unwind").succeeded()) {
3815 if (hasCatchAll)
3816 return parser.emitError(parser.getCurrentLocation(),
3817 "unwind can't be used with catch all");
3818
3819 catcherAttrs.push_back(cir::UnwindAttr::get(parser.getContext()));
3820 if (parseCheckedCatcherRegion().failed())
3821 return mlir::failure();
3822 }
3823
3824 handlerTypes = parser.getBuilder().getArrayAttr(catcherAttrs);
3825 return mlir::success();
3826}
3827
3828//===----------------------------------------------------------------------===//
3829// EhTypeIdOp
3830//===----------------------------------------------------------------------===//
3831
3832LogicalResult
3833cir::EhTypeIdOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
3834 Operation *op = symbolTable.lookupNearestSymbolFrom(*this, getTypeSymAttr());
3835 if (!isa_and_nonnull<GlobalOp>(op))
3836 return emitOpError("'")
3837 << getTypeSym() << "' does not reference a valid cir.global";
3838 return success();
3839}
3840
3841//===----------------------------------------------------------------------===//
3842// TableGen'd op method definitions
3843//===----------------------------------------------------------------------===//
3844
3845#define GET_OP_CLASSES
3846#include "clang/CIR/Dialect/IR/CIROps.cpp.inc"
static const MemRegion * getRegion(const CallEvent &Call, const MutexDescriptor &Descriptor, bool IsLock)
static mlir::LogicalResult checkReturnAndFunction(cir::ReturnOp op, cir::FuncOp function)
static LogicalResult verifyMemberPtrCast(Operation *op, mlir::Value src, mlir::Type resultTy)
static bool isBoolNot(cir::UnaryOp op)
static mlir::ParseResult parseCallCommon(mlir::OpAsmParser &parser, mlir::OperationState &result, bool hasDestinationBlocks=false)
static bool isIntOrBoolCast(cir::CastOp op)
static void printConstant(OpAsmPrinter &p, Attribute value)
static mlir::ParseResult parseOmittedTerminatorRegion(mlir::OpAsmParser &parser, mlir::Region &region)
ParseResult parseInlineKindAttr(OpAsmParser &parser, cir::InlineKindAttr &inlineKindAttr)
void printInlineKindAttr(OpAsmPrinter &p, cir::InlineKindAttr inlineKindAttr)
void printVisibilityAttr(OpAsmPrinter &printer, cir::VisibilityAttr &visibility)
static ParseResult parseSwitchFlatOpCases(OpAsmParser &parser, Type flagType, mlir::ArrayAttr &caseValues, SmallVectorImpl< Block * > &caseDestinations, SmallVectorImpl< llvm::SmallVector< OpAsmParser::UnresolvedOperand > > &caseOperands, SmallVectorImpl< llvm::SmallVector< Type > > &caseOperandTypes)
<cases> ::= [ (case (, case )* )?
static LogicalResult verifyCallCommInSymbolUses(mlir::Operation *op, SymbolTableCollection &symbolTable)
static LogicalResult ensureRegionTerm(OpAsmParser &parser, Region &region, SMLoc errLoc)
static ParseResult parseGlobalOpTypeAndInitialValue(OpAsmParser &parser, TypeAttr &typeAttr, Attribute &initialValueAttr, mlir::Region &ctorRegion, mlir::Region &dtorRegion)
void printIndirectBrOpSucessors(OpAsmPrinter &p, cir::IndirectBrOp op, Type flagType, SuccessorRange succs, OperandRangeRange succOperands, const TypeRangeRange &succOperandsTypes)
static OpFoldResult foldUnaryBitOp(mlir::Attribute inputAttr, llvm::function_ref< llvm::APInt(const llvm::APInt &)> func, bool poisonZero=false)
static llvm::StringRef getLinkageAttrNameString()
Returns the name used for the linkage attribute.
static RetTy parseOptionalCIRKeyword(AsmParser &parser, EnumTy defaultValue)
Parse an enum from the keyword, or default to the provided default value.
static void printSwitchFlatOpCases(OpAsmPrinter &p, cir::SwitchFlatOp op, Type flagType, mlir::ArrayAttr caseValues, SuccessorRange caseDestinations, OperandRangeRange caseOperands, const TypeRangeRange &caseOperandTypes)
static mlir::ParseResult parseTryCallDestinations(mlir::OpAsmParser &parser, mlir::OperationState &result)
static void printGlobalOpTypeAndInitialValue(OpAsmPrinter &p, cir::GlobalOp op, TypeAttr type, Attribute initAttr, mlir::Region &ctorRegion, mlir::Region &dtorRegion)
static ParseResult parseCIRKeyword(AsmParser &parser, RetTy &result)
Parse an enum from the keyword, return failure if the keyword is not found.
static Value tryFoldCastChain(cir::CastOp op)
void parseVisibilityAttr(OpAsmParser &parser, cir::VisibilityAttr &visibility)
static void printTryHandlerRegions(mlir::OpAsmPrinter &printer, cir::TryOp op, mlir::MutableArrayRef< mlir::Region > handlerRegions, mlir::ArrayAttr handlerTypes)
ParseResult parseIndirectBrOpSucessors(OpAsmParser &parser, Type &flagType, SmallVectorImpl< Block * > &succOperandBlocks, SmallVectorImpl< SmallVector< OpAsmParser::UnresolvedOperand > > &succOperands, SmallVectorImpl< SmallVector< Type > > &succOperandsTypes)
static bool omitRegionTerm(mlir::Region &r)
static void printCallCommon(mlir::Operation *op, mlir::FlatSymbolRefAttr calleeSym, mlir::Value indirectCallee, mlir::OpAsmPrinter &printer, bool isNothrow, cir::SideEffect sideEffect, mlir::Block *normalDest=nullptr, mlir::Block *unwindDest=nullptr)
static void printOmittedTerminatorRegion(mlir::OpAsmPrinter &printer, cir::ScopeOp &op, mlir::Region &region)
static ParseResult parseConstantValue(OpAsmParser &parser, mlir::Attribute &valueAttr)
static LogicalResult checkConstantTypes(mlir::Operation *op, mlir::Type opType, mlir::Attribute attrType)
static mlir::ParseResult parseTryHandlerRegions(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< std::unique_ptr< mlir::Region > > &handlerRegions, mlir::ArrayAttr &handlerTypes)
#define REGISTER_ENUM_TYPE(Ty)
static int parseOptionalKeywordAlternative(AsmParser &parser, ArrayRef< llvm::StringRef > keywords)
llvm::function_ref< void(mlir::OpBuilder &, mlir::Location)> BuilderCallbackRef
Definition CIRDialect.h:37
llvm::function_ref< void( mlir::OpBuilder &, mlir::Location, mlir::OperationState &)> BuilderOpStateCallbackRef
Definition CIRDialect.h:39
static std::optional< NonLoc > getIndex(ProgramStateRef State, const ElementRegion *ER, CharKind CK)
static Decl::Kind getKind(const Decl *D)
TokenType getType() const
Returns the token's type, e.g.
*collection of selector each with an associated kind and an ordered *collection of selectors A selector has a an optional score condition
*collection of selector each with an associated kind and an ordered *collection of selectors A selector has a kind
__device__ __2f16 float c
void buildTerminatedBody(mlir::OpBuilder &builder, mlir::Location loc)
const internal::VariadicAllOfMatcher< Attr > attr
const AstTypeMatcher< RecordType > recordType
StringRef getName(const HeaderType T)
Definition HeaderFile.h:38
RangeSelector name(std::string ID)
Given a node with a "name", (like NamedDecl, DeclRefExpr, CxxCtorInitializer, and TypeLoc) selects th...
nullptr
This class represents a compute construct, representing a 'Kind' of ‘parallel’, 'serial',...
const half4 dst(half4 Src0, half4 Src1)
__DEVICE__ _Tp arg(const std::complex< _Tp > &__c)
static bool memberFuncPtrCast()
static bool addressSpace()
static bool opCallCallConv()
static bool opScopeCleanupRegion()
static bool supportIFuncAttr()