clang 22.0.0git
CIRDialect.cpp
Go to the documentation of this file.
1//===- CIRDialect.cpp - MLIR CIR ops implementation -----------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements the CIR dialect and its operations.
10//
11//===----------------------------------------------------------------------===//
12
14
18
19#include "mlir/IR/Attributes.h"
20#include "mlir/IR/DialectImplementation.h"
21#include "mlir/Interfaces/ControlFlowInterfaces.h"
22#include "mlir/Interfaces/FunctionImplementation.h"
23#include "mlir/Support/LLVM.h"
24
25#include "clang/CIR/Dialect/IR/CIROpsDialect.cpp.inc"
26#include "clang/CIR/Dialect/IR/CIROpsEnums.cpp.inc"
28#include "llvm/ADT/SetOperations.h"
29#include "llvm/ADT/SmallSet.h"
30#include "llvm/Support/LogicalResult.h"
31
32using namespace mlir;
33using namespace cir;
34
35//===----------------------------------------------------------------------===//
36// CIR Dialect
37//===----------------------------------------------------------------------===//
38namespace {
39struct CIROpAsmDialectInterface : public OpAsmDialectInterface {
40 using OpAsmDialectInterface::OpAsmDialectInterface;
41
42 AliasResult getAlias(Type type, raw_ostream &os) const final {
43 if (auto recordType = dyn_cast<cir::RecordType>(type)) {
44 StringAttr nameAttr = recordType.getName();
45 if (!nameAttr)
46 os << "rec_anon_" << recordType.getKindAsStr();
47 else
48 os << "rec_" << nameAttr.getValue();
49 return AliasResult::OverridableAlias;
50 }
51 if (auto intType = dyn_cast<cir::IntType>(type)) {
52 // We only provide alias for standard integer types (i.e. integer types
53 // whose width is a power of 2 and at least 8).
54 unsigned width = intType.getWidth();
55 if (width < 8 || !llvm::isPowerOf2_32(width))
56 return AliasResult::NoAlias;
57 os << intType.getAlias();
58 return AliasResult::OverridableAlias;
59 }
60 if (auto voidType = dyn_cast<cir::VoidType>(type)) {
61 os << voidType.getAlias();
62 return AliasResult::OverridableAlias;
63 }
64
65 return AliasResult::NoAlias;
66 }
67
68 AliasResult getAlias(Attribute attr, raw_ostream &os) const final {
69 if (auto boolAttr = mlir::dyn_cast<cir::BoolAttr>(attr)) {
70 os << (boolAttr.getValue() ? "true" : "false");
71 return AliasResult::FinalAlias;
72 }
73 if (auto bitfield = mlir::dyn_cast<cir::BitfieldInfoAttr>(attr)) {
74 os << "bfi_" << bitfield.getName().str();
75 return AliasResult::FinalAlias;
76 }
77 if (auto dynCastInfoAttr = mlir::dyn_cast<cir::DynamicCastInfoAttr>(attr)) {
78 os << dynCastInfoAttr.getAlias();
79 return AliasResult::FinalAlias;
80 }
81 return AliasResult::NoAlias;
82 }
83};
84} // namespace
85
/// Dialect initialization hook: registers the CIR types, attributes,
/// ODS-generated operations and the asm alias interface.
void cir::CIRDialect::initialize() {
  registerTypes();
  registerAttributes();
  // The operation list is the GET_OP_LIST section generated by ODS.
  addOperations<
#define GET_OP_LIST
#include "clang/CIR/Dialect/IR/CIROps.cpp.inc"
      >();
  // Alias printing for common types/attributes (CIROpAsmDialectInterface).
  addInterfaces<CIROpAsmDialectInterface>();
}
95
96Operation *cir::CIRDialect::materializeConstant(mlir::OpBuilder &builder,
97 mlir::Attribute value,
98 mlir::Type type,
99 mlir::Location loc) {
100 return cir::ConstantOp::create(builder, loc, type,
101 mlir::cast<mlir::TypedAttr>(value));
102}
103
104//===----------------------------------------------------------------------===//
105// Helpers
106//===----------------------------------------------------------------------===//
107
108// Parses one of the keywords provided in the list `keywords` and returns the
109// position of the parsed keyword in the list. If none of the keywords from the
110// list is parsed, returns -1.
111static int parseOptionalKeywordAlternative(AsmParser &parser,
112 ArrayRef<llvm::StringRef> keywords) {
113 for (auto en : llvm::enumerate(keywords)) {
114 if (succeeded(parser.parseOptionalKeyword(en.value())))
115 return en.index();
116 }
117 return -1;
118}
119
120namespace {
121template <typename Ty> struct EnumTraits {};
122
123#define REGISTER_ENUM_TYPE(Ty) \
124 template <> struct EnumTraits<cir::Ty> { \
125 static llvm::StringRef stringify(cir::Ty value) { \
126 return stringify##Ty(value); \
127 } \
128 static unsigned getMaxEnumVal() { return cir::getMaxEnumValFor##Ty(); } \
129 }
130
131REGISTER_ENUM_TYPE(GlobalLinkageKind);
132REGISTER_ENUM_TYPE(VisibilityKind);
133REGISTER_ENUM_TYPE(SideEffect);
134} // namespace
135
136/// Parse an enum from the keyword, or default to the provided default value.
137/// The return type is the enum type by default, unless overriden with the
138/// second template argument.
139template <typename EnumTy, typename RetTy = EnumTy>
140static RetTy parseOptionalCIRKeyword(AsmParser &parser, EnumTy defaultValue) {
142 for (unsigned i = 0, e = EnumTraits<EnumTy>::getMaxEnumVal(); i <= e; ++i)
143 names.push_back(EnumTraits<EnumTy>::stringify(static_cast<EnumTy>(i)));
144
145 int index = parseOptionalKeywordAlternative(parser, names);
146 if (index == -1)
147 return static_cast<RetTy>(defaultValue);
148 return static_cast<RetTy>(index);
149}
150
151/// Parse an enum from the keyword, return failure if the keyword is not found.
152template <typename EnumTy, typename RetTy = EnumTy>
153static ParseResult parseCIRKeyword(AsmParser &parser, RetTy &result) {
155 for (unsigned i = 0, e = EnumTraits<EnumTy>::getMaxEnumVal(); i <= e; ++i)
156 names.push_back(EnumTraits<EnumTy>::stringify(static_cast<EnumTy>(i)));
157
158 int index = parseOptionalKeywordAlternative(parser, names);
159 if (index == -1)
160 return failure();
161 result = static_cast<RetTy>(index);
162 return success();
163}
164
165// Check if a region's termination omission is valid and, if so, creates and
166// inserts the omitted terminator into the region.
167static LogicalResult ensureRegionTerm(OpAsmParser &parser, Region &region,
168 SMLoc errLoc) {
169 Location eLoc = parser.getEncodedSourceLoc(parser.getCurrentLocation());
170 OpBuilder builder(parser.getBuilder().getContext());
171
172 // Insert empty block in case the region is empty to ensure the terminator
173 // will be inserted
174 if (region.empty())
175 builder.createBlock(&region);
176
177 Block &block = region.back();
178 // Region is properly terminated: nothing to do.
179 if (!block.empty() && block.back().hasTrait<OpTrait::IsTerminator>())
180 return success();
181
182 // Check for invalid terminator omissions.
183 if (!region.hasOneBlock())
184 return parser.emitError(errLoc,
185 "multi-block region must not omit terminator");
186
187 // Terminator was omitted correctly: recreate it.
188 builder.setInsertionPointToEnd(&block);
189 cir::YieldOp::create(builder, eLoc);
190 return success();
191}
192
193// True if the region's terminator should be omitted.
194static bool omitRegionTerm(mlir::Region &r) {
195 const auto singleNonEmptyBlock = r.hasOneBlock() && !r.back().empty();
196 const auto yieldsNothing = [&r]() {
197 auto y = dyn_cast<cir::YieldOp>(r.back().getTerminator());
198 return y && y.getArgs().empty();
199 };
200 return singleNonEmptyBlock && yieldsNothing();
201}
202
203void printVisibilityAttr(OpAsmPrinter &printer,
204 cir::VisibilityAttr &visibility) {
205 switch (visibility.getValue()) {
206 case cir::VisibilityKind::Hidden:
207 printer << "hidden";
208 break;
209 case cir::VisibilityKind::Protected:
210 printer << "protected";
211 break;
212 case cir::VisibilityKind::Default:
213 break;
214 }
215}
216
217void parseVisibilityAttr(OpAsmParser &parser, cir::VisibilityAttr &visibility) {
218 cir::VisibilityKind visibilityKind =
219 parseOptionalCIRKeyword(parser, cir::VisibilityKind::Default);
220 visibility = cir::VisibilityAttr::get(parser.getContext(), visibilityKind);
221}
222
223//===----------------------------------------------------------------------===//
224// InlineKindAttr (FIXME: remove once FuncOp uses assembly format)
225//===----------------------------------------------------------------------===//
226
227ParseResult parseInlineKindAttr(OpAsmParser &parser,
228 cir::InlineKindAttr &inlineKindAttr) {
229 // Static list of possible inline kind keywords
230 static constexpr llvm::StringRef keywords[] = {"no_inline", "always_inline",
231 "inline_hint"};
232
233 // Parse the inline kind keyword (optional)
234 llvm::StringRef keyword;
235 if (parser.parseOptionalKeyword(&keyword, keywords).failed()) {
236 // Not an inline kind keyword, leave inlineKindAttr empty
237 return success();
238 }
239
240 // Parse the enum value from the keyword
241 auto inlineKindResult = ::cir::symbolizeEnum<::cir::InlineKind>(keyword);
242 if (!inlineKindResult) {
243 return parser.emitError(parser.getCurrentLocation(), "expected one of [")
244 << llvm::join(llvm::ArrayRef(keywords), ", ")
245 << "] for inlineKind, got: " << keyword;
246 }
247
248 inlineKindAttr =
249 ::cir::InlineKindAttr::get(parser.getContext(), *inlineKindResult);
250 return success();
251}
252
253void printInlineKindAttr(OpAsmPrinter &p, cir::InlineKindAttr inlineKindAttr) {
254 if (inlineKindAttr) {
255 p << " " << stringifyInlineKind(inlineKindAttr.getValue());
256 }
257}
258//===----------------------------------------------------------------------===//
259// CIR Custom Parsers/Printers
260//===----------------------------------------------------------------------===//
261
262static mlir::ParseResult parseOmittedTerminatorRegion(mlir::OpAsmParser &parser,
263 mlir::Region &region) {
264 auto regionLoc = parser.getCurrentLocation();
265 if (parser.parseRegion(region))
266 return failure();
267 if (ensureRegionTerm(parser, region, regionLoc).failed())
268 return failure();
269 return success();
270}
271
272static void printOmittedTerminatorRegion(mlir::OpAsmPrinter &printer,
273 cir::ScopeOp &op,
274 mlir::Region &region) {
275 printer.printRegion(region,
276 /*printEntryBlockArgs=*/false,
277 /*printBlockTerminators=*/!omitRegionTerm(region));
278}
279
280//===----------------------------------------------------------------------===//
281// AllocaOp
282//===----------------------------------------------------------------------===//
283
284void cir::AllocaOp::build(mlir::OpBuilder &odsBuilder,
285 mlir::OperationState &odsState, mlir::Type addr,
286 mlir::Type allocaType, llvm::StringRef name,
287 mlir::IntegerAttr alignment) {
288 odsState.addAttribute(getAllocaTypeAttrName(odsState.name),
289 mlir::TypeAttr::get(allocaType));
290 odsState.addAttribute(getNameAttrName(odsState.name),
291 odsBuilder.getStringAttr(name));
292 if (alignment) {
293 odsState.addAttribute(getAlignmentAttrName(odsState.name), alignment);
294 }
295 odsState.addTypes(addr);
296}
297
298//===----------------------------------------------------------------------===//
299// BreakOp
300//===----------------------------------------------------------------------===//
301
302LogicalResult cir::BreakOp::verify() {
304 if (!getOperation()->getParentOfType<LoopOpInterface>() &&
305 !getOperation()->getParentOfType<SwitchOp>())
306 return emitOpError("must be within a loop");
307 return success();
308}
309
310//===----------------------------------------------------------------------===//
311// ConditionOp
312//===----------------------------------------------------------------------===//
313
314//===----------------------------------
315// BranchOpTerminatorInterface Methods
316//===----------------------------------
317
318void cir::ConditionOp::getSuccessorRegions(
319 ArrayRef<Attribute> operands, SmallVectorImpl<RegionSuccessor> &regions) {
320 // TODO(cir): The condition value may be folded to a constant, narrowing
321 // down its list of possible successors.
322
323 // Parent is a loop: condition may branch to the body or to the parent op.
324 if (auto loopOp = dyn_cast<LoopOpInterface>(getOperation()->getParentOp())) {
325 regions.emplace_back(&loopOp.getBody(), loopOp.getBody().getArguments());
326 regions.emplace_back(getOperation(), loopOp->getResults());
327 }
328
329 // Parent is an await: condition may branch to resume or suspend regions.
330 auto await = cast<AwaitOp>(getOperation()->getParentOp());
331 regions.emplace_back(&await.getResume(), await.getResume().getArguments());
332 regions.emplace_back(&await.getSuspend(), await.getSuspend().getArguments());
333}
334
335MutableOperandRange
336cir::ConditionOp::getMutableSuccessorOperands(RegionSuccessor point) {
337 // No values are yielded to the successor region.
338 return MutableOperandRange(getOperation(), 0, 0);
339}
340
341LogicalResult cir::ConditionOp::verify() {
342 if (!isa<LoopOpInterface, AwaitOp>(getOperation()->getParentOp()))
343 return emitOpError("condition must be within a conditional region");
344 return success();
345}
346
347//===----------------------------------------------------------------------===//
348// ConstantOp
349//===----------------------------------------------------------------------===//
350
351static LogicalResult checkConstantTypes(mlir::Operation *op, mlir::Type opType,
352 mlir::Attribute attrType) {
353 if (isa<cir::ConstPtrAttr>(attrType)) {
354 if (!mlir::isa<cir::PointerType>(opType))
355 return op->emitOpError(
356 "pointer constant initializing a non-pointer type");
357 return success();
358 }
359
360 if (isa<cir::DataMemberAttr>(attrType)) {
361 // More detailed type verifications are already done in
362 // DataMemberAttr::verify. Don't need to repeat here.
363 return success();
364 }
365
366 if (isa<cir::ZeroAttr>(attrType)) {
367 if (isa<cir::RecordType, cir::ArrayType, cir::VectorType, cir::ComplexType>(
368 opType))
369 return success();
370 return op->emitOpError(
371 "zero expects struct, array, vector, or complex type");
372 }
373
374 if (mlir::isa<cir::UndefAttr>(attrType)) {
375 if (!mlir::isa<cir::VoidType>(opType))
376 return success();
377 return op->emitOpError("undef expects non-void type");
378 }
379
380 if (mlir::isa<cir::BoolAttr>(attrType)) {
381 if (!mlir::isa<cir::BoolType>(opType))
382 return op->emitOpError("result type (")
383 << opType << ") must be '!cir.bool' for '" << attrType << "'";
384 return success();
385 }
386
387 if (mlir::isa<cir::IntAttr, cir::FPAttr>(attrType)) {
388 auto at = cast<TypedAttr>(attrType);
389 if (at.getType() != opType) {
390 return op->emitOpError("result type (")
391 << opType << ") does not match value type (" << at.getType()
392 << ")";
393 }
394 return success();
395 }
396
397 if (mlir::isa<cir::ConstArrayAttr, cir::ConstVectorAttr,
398 cir::ConstComplexAttr, cir::ConstRecordAttr,
399 cir::GlobalViewAttr, cir::PoisonAttr, cir::TypeInfoAttr,
400 cir::VTableAttr>(attrType))
401 return success();
402
403 assert(isa<TypedAttr>(attrType) && "What else could we be looking at here?");
404 return op->emitOpError("global with type ")
405 << cast<TypedAttr>(attrType).getType() << " not yet supported";
406}
407
408LogicalResult cir::ConstantOp::verify() {
409 // ODS already generates checks to make sure the result type is valid. We just
410 // need to additionally check that the value's attribute type is consistent
411 // with the result type.
412 return checkConstantTypes(getOperation(), getType(), getValue());
413}
414
415OpFoldResult cir::ConstantOp::fold(FoldAdaptor /*adaptor*/) {
416 return getValue();
417}
418
419//===----------------------------------------------------------------------===//
420// ContinueOp
421//===----------------------------------------------------------------------===//
422
423LogicalResult cir::ContinueOp::verify() {
424 if (!getOperation()->getParentOfType<LoopOpInterface>())
425 return emitOpError("must be within a loop");
426 return success();
427}
428
429//===----------------------------------------------------------------------===//
430// CastOp
431//===----------------------------------------------------------------------===//
432
433LogicalResult cir::CastOp::verify() {
434 mlir::Type resType = getType();
435 mlir::Type srcType = getSrc().getType();
436
437 // Verify address space casts for pointer types. given that
438 // casts for within a different address space are illegal.
439 auto srcPtrTy = mlir::dyn_cast<cir::PointerType>(srcType);
440 auto resPtrTy = mlir::dyn_cast<cir::PointerType>(resType);
441 if (srcPtrTy && resPtrTy && (getKind() != cir::CastKind::address_space))
442 if (srcPtrTy.getAddrSpace() != resPtrTy.getAddrSpace()) {
443 return emitOpError() << "result type address space does not match the "
444 "address space of the operand";
445 }
446
447 if (mlir::isa<cir::VectorType>(srcType) &&
448 mlir::isa<cir::VectorType>(resType)) {
449 // Use the element type of the vector to verify the cast kind. (Except for
450 // bitcast, see below.)
451 srcType = mlir::dyn_cast<cir::VectorType>(srcType).getElementType();
452 resType = mlir::dyn_cast<cir::VectorType>(resType).getElementType();
453 }
454
455 switch (getKind()) {
456 case cir::CastKind::int_to_bool: {
457 if (!mlir::isa<cir::BoolType>(resType))
458 return emitOpError() << "requires !cir.bool type for result";
459 if (!mlir::isa<cir::IntType>(srcType))
460 return emitOpError() << "requires !cir.int type for source";
461 return success();
462 }
463 case cir::CastKind::ptr_to_bool: {
464 if (!mlir::isa<cir::BoolType>(resType))
465 return emitOpError() << "requires !cir.bool type for result";
466 if (!mlir::isa<cir::PointerType>(srcType))
467 return emitOpError() << "requires !cir.ptr type for source";
468 return success();
469 }
470 case cir::CastKind::integral: {
471 if (!mlir::isa<cir::IntType>(resType))
472 return emitOpError() << "requires !cir.int type for result";
473 if (!mlir::isa<cir::IntType>(srcType))
474 return emitOpError() << "requires !cir.int type for source";
475 return success();
476 }
477 case cir::CastKind::array_to_ptrdecay: {
478 const auto arrayPtrTy = mlir::dyn_cast<cir::PointerType>(srcType);
479 const auto flatPtrTy = mlir::dyn_cast<cir::PointerType>(resType);
480 if (!arrayPtrTy || !flatPtrTy)
481 return emitOpError() << "requires !cir.ptr type for source and result";
482
483 // TODO(CIR): Make sure the AddrSpace of both types are equals
484 return success();
485 }
486 case cir::CastKind::bitcast: {
487 // Handle the pointer types first.
488 auto srcPtrTy = mlir::dyn_cast<cir::PointerType>(srcType);
489 auto resPtrTy = mlir::dyn_cast<cir::PointerType>(resType);
490
491 if (srcPtrTy && resPtrTy) {
492 return success();
493 }
494
495 return success();
496 }
497 case cir::CastKind::floating: {
498 if (!mlir::isa<cir::FPTypeInterface>(srcType) ||
499 !mlir::isa<cir::FPTypeInterface>(resType))
500 return emitOpError() << "requires !cir.float type for source and result";
501 return success();
502 }
503 case cir::CastKind::float_to_int: {
504 if (!mlir::isa<cir::FPTypeInterface>(srcType))
505 return emitOpError() << "requires !cir.float type for source";
506 if (!mlir::dyn_cast<cir::IntType>(resType))
507 return emitOpError() << "requires !cir.int type for result";
508 return success();
509 }
510 case cir::CastKind::int_to_ptr: {
511 if (!mlir::dyn_cast<cir::IntType>(srcType))
512 return emitOpError() << "requires !cir.int type for source";
513 if (!mlir::dyn_cast<cir::PointerType>(resType))
514 return emitOpError() << "requires !cir.ptr type for result";
515 return success();
516 }
517 case cir::CastKind::ptr_to_int: {
518 if (!mlir::dyn_cast<cir::PointerType>(srcType))
519 return emitOpError() << "requires !cir.ptr type for source";
520 if (!mlir::dyn_cast<cir::IntType>(resType))
521 return emitOpError() << "requires !cir.int type for result";
522 return success();
523 }
524 case cir::CastKind::float_to_bool: {
525 if (!mlir::isa<cir::FPTypeInterface>(srcType))
526 return emitOpError() << "requires !cir.float type for source";
527 if (!mlir::isa<cir::BoolType>(resType))
528 return emitOpError() << "requires !cir.bool type for result";
529 return success();
530 }
531 case cir::CastKind::bool_to_int: {
532 if (!mlir::isa<cir::BoolType>(srcType))
533 return emitOpError() << "requires !cir.bool type for source";
534 if (!mlir::isa<cir::IntType>(resType))
535 return emitOpError() << "requires !cir.int type for result";
536 return success();
537 }
538 case cir::CastKind::int_to_float: {
539 if (!mlir::isa<cir::IntType>(srcType))
540 return emitOpError() << "requires !cir.int type for source";
541 if (!mlir::isa<cir::FPTypeInterface>(resType))
542 return emitOpError() << "requires !cir.float type for result";
543 return success();
544 }
545 case cir::CastKind::bool_to_float: {
546 if (!mlir::isa<cir::BoolType>(srcType))
547 return emitOpError() << "requires !cir.bool type for source";
548 if (!mlir::isa<cir::FPTypeInterface>(resType))
549 return emitOpError() << "requires !cir.float type for result";
550 return success();
551 }
552 case cir::CastKind::address_space: {
553 auto srcPtrTy = mlir::dyn_cast<cir::PointerType>(srcType);
554 auto resPtrTy = mlir::dyn_cast<cir::PointerType>(resType);
555 if (!srcPtrTy || !resPtrTy)
556 return emitOpError() << "requires !cir.ptr type for source and result";
557 if (srcPtrTy.getPointee() != resPtrTy.getPointee())
558 return emitOpError() << "requires two types differ in addrspace only";
559 return success();
560 }
561 case cir::CastKind::float_to_complex: {
562 if (!mlir::isa<cir::FPTypeInterface>(srcType))
563 return emitOpError() << "requires !cir.float type for source";
564 auto resComplexTy = mlir::dyn_cast<cir::ComplexType>(resType);
565 if (!resComplexTy)
566 return emitOpError() << "requires !cir.complex type for result";
567 if (srcType != resComplexTy.getElementType())
568 return emitOpError() << "requires source type match result element type";
569 return success();
570 }
571 case cir::CastKind::int_to_complex: {
572 if (!mlir::isa<cir::IntType>(srcType))
573 return emitOpError() << "requires !cir.int type for source";
574 auto resComplexTy = mlir::dyn_cast<cir::ComplexType>(resType);
575 if (!resComplexTy)
576 return emitOpError() << "requires !cir.complex type for result";
577 if (srcType != resComplexTy.getElementType())
578 return emitOpError() << "requires source type match result element type";
579 return success();
580 }
581 case cir::CastKind::float_complex_to_real: {
582 auto srcComplexTy = mlir::dyn_cast<cir::ComplexType>(srcType);
583 if (!srcComplexTy)
584 return emitOpError() << "requires !cir.complex type for source";
585 if (!mlir::isa<cir::FPTypeInterface>(resType))
586 return emitOpError() << "requires !cir.float type for result";
587 if (srcComplexTy.getElementType() != resType)
588 return emitOpError() << "requires source element type match result type";
589 return success();
590 }
591 case cir::CastKind::int_complex_to_real: {
592 auto srcComplexTy = mlir::dyn_cast<cir::ComplexType>(srcType);
593 if (!srcComplexTy)
594 return emitOpError() << "requires !cir.complex type for source";
595 if (!mlir::isa<cir::IntType>(resType))
596 return emitOpError() << "requires !cir.int type for result";
597 if (srcComplexTy.getElementType() != resType)
598 return emitOpError() << "requires source element type match result type";
599 return success();
600 }
601 case cir::CastKind::float_complex_to_bool: {
602 auto srcComplexTy = mlir::dyn_cast<cir::ComplexType>(srcType);
603 if (!srcComplexTy || !srcComplexTy.isFloatingPointComplex())
604 return emitOpError()
605 << "requires floating point !cir.complex type for source";
606 if (!mlir::isa<cir::BoolType>(resType))
607 return emitOpError() << "requires !cir.bool type for result";
608 return success();
609 }
610 case cir::CastKind::int_complex_to_bool: {
611 auto srcComplexTy = mlir::dyn_cast<cir::ComplexType>(srcType);
612 if (!srcComplexTy || !srcComplexTy.isIntegerComplex())
613 return emitOpError()
614 << "requires floating point !cir.complex type for source";
615 if (!mlir::isa<cir::BoolType>(resType))
616 return emitOpError() << "requires !cir.bool type for result";
617 return success();
618 }
619 case cir::CastKind::float_complex: {
620 auto srcComplexTy = mlir::dyn_cast<cir::ComplexType>(srcType);
621 if (!srcComplexTy || !srcComplexTy.isFloatingPointComplex())
622 return emitOpError()
623 << "requires floating point !cir.complex type for source";
624 auto resComplexTy = mlir::dyn_cast<cir::ComplexType>(resType);
625 if (!resComplexTy || !resComplexTy.isFloatingPointComplex())
626 return emitOpError()
627 << "requires floating point !cir.complex type for result";
628 return success();
629 }
630 case cir::CastKind::float_complex_to_int_complex: {
631 auto srcComplexTy = mlir::dyn_cast<cir::ComplexType>(srcType);
632 if (!srcComplexTy || !srcComplexTy.isFloatingPointComplex())
633 return emitOpError()
634 << "requires floating point !cir.complex type for source";
635 auto resComplexTy = mlir::dyn_cast<cir::ComplexType>(resType);
636 if (!resComplexTy || !resComplexTy.isIntegerComplex())
637 return emitOpError() << "requires integer !cir.complex type for result";
638 return success();
639 }
640 case cir::CastKind::int_complex: {
641 auto srcComplexTy = mlir::dyn_cast<cir::ComplexType>(srcType);
642 if (!srcComplexTy || !srcComplexTy.isIntegerComplex())
643 return emitOpError() << "requires integer !cir.complex type for source";
644 auto resComplexTy = mlir::dyn_cast<cir::ComplexType>(resType);
645 if (!resComplexTy || !resComplexTy.isIntegerComplex())
646 return emitOpError() << "requires integer !cir.complex type for result";
647 return success();
648 }
649 case cir::CastKind::int_complex_to_float_complex: {
650 auto srcComplexTy = mlir::dyn_cast<cir::ComplexType>(srcType);
651 if (!srcComplexTy || !srcComplexTy.isIntegerComplex())
652 return emitOpError() << "requires integer !cir.complex type for source";
653 auto resComplexTy = mlir::dyn_cast<cir::ComplexType>(resType);
654 if (!resComplexTy || !resComplexTy.isFloatingPointComplex())
655 return emitOpError()
656 << "requires floating point !cir.complex type for result";
657 return success();
658 }
659 default:
660 llvm_unreachable("Unknown CastOp kind?");
661 }
662}
663
664static bool isIntOrBoolCast(cir::CastOp op) {
665 auto kind = op.getKind();
666 return kind == cir::CastKind::bool_to_int ||
667 kind == cir::CastKind::int_to_bool || kind == cir::CastKind::integral;
668}
669
670static Value tryFoldCastChain(cir::CastOp op) {
671 cir::CastOp head = op, tail = op;
672
673 while (op) {
674 if (!isIntOrBoolCast(op))
675 break;
676 head = op;
677 op = head.getSrc().getDefiningOp<cir::CastOp>();
678 }
679
680 if (head == tail)
681 return {};
682
683 // if bool_to_int -> ... -> int_to_bool: take the bool
684 // as we had it was before all casts
685 if (head.getKind() == cir::CastKind::bool_to_int &&
686 tail.getKind() == cir::CastKind::int_to_bool)
687 return head.getSrc();
688
689 // if int_to_bool -> ... -> int_to_bool: take the result
690 // of the first one, as no other casts (and ext casts as well)
691 // don't change the first result
692 if (head.getKind() == cir::CastKind::int_to_bool &&
693 tail.getKind() == cir::CastKind::int_to_bool)
694 return head.getResult();
695
696 return {};
697}
698
699OpFoldResult cir::CastOp::fold(FoldAdaptor adaptor) {
700 if (mlir::isa_and_present<cir::PoisonAttr>(adaptor.getSrc())) {
701 // Propagate poison value
702 return cir::PoisonAttr::get(getContext(), getType());
703 }
704
705 if (getSrc().getType() == getType()) {
706 switch (getKind()) {
707 case cir::CastKind::integral: {
708 // TODO: for sign differences, it's possible in certain conditions to
709 // create a new attribute that's capable of representing the source.
711 auto foldOrder = getSrc().getDefiningOp()->fold(foldResults);
712 if (foldOrder.succeeded() && mlir::isa<mlir::Attribute>(foldResults[0]))
713 return mlir::cast<mlir::Attribute>(foldResults[0]);
714 return {};
715 }
716 case cir::CastKind::bitcast:
717 case cir::CastKind::address_space:
718 case cir::CastKind::float_complex:
719 case cir::CastKind::int_complex: {
720 return getSrc();
721 }
722 default:
723 return {};
724 }
725 }
726 return tryFoldCastChain(*this);
727}
728
729//===----------------------------------------------------------------------===//
730// CallOp
731//===----------------------------------------------------------------------===//
732
733mlir::OperandRange cir::CallOp::getArgOperands() {
734 if (isIndirect())
735 return getArgs().drop_front(1);
736 return getArgs();
737}
738
739mlir::MutableOperandRange cir::CallOp::getArgOperandsMutable() {
740 mlir::MutableOperandRange args = getArgsMutable();
741 if (isIndirect())
742 return args.slice(1, args.size() - 1);
743 return args;
744}
745
746mlir::Value cir::CallOp::getIndirectCall() {
747 assert(isIndirect());
748 return getOperand(0);
749}
750
751/// Return the operand at index 'i'.
752Value cir::CallOp::getArgOperand(unsigned i) {
753 if (isIndirect())
754 ++i;
755 return getOperand(i);
756}
757
758/// Return the number of operands.
759unsigned cir::CallOp::getNumArgOperands() {
760 if (isIndirect())
761 return this->getOperation()->getNumOperands() - 1;
762 return this->getOperation()->getNumOperands();
763}
764
765static mlir::ParseResult
766parseTryCallDestinations(mlir::OpAsmParser &parser,
767 mlir::OperationState &result) {
768 mlir::Block *normalDestSuccessor;
769 if (parser.parseSuccessor(normalDestSuccessor))
770 return mlir::failure();
771
772 if (parser.parseComma())
773 return mlir::failure();
774
775 mlir::Block *unwindDestSuccessor;
776 if (parser.parseSuccessor(unwindDestSuccessor))
777 return mlir::failure();
778
779 result.addSuccessors(normalDestSuccessor);
780 result.addSuccessors(unwindDestSuccessor);
781 return mlir::success();
782}
783
784static mlir::ParseResult parseCallCommon(mlir::OpAsmParser &parser,
785 mlir::OperationState &result,
786 bool hasDestinationBlocks = false) {
788 llvm::SMLoc opsLoc;
789 mlir::FlatSymbolRefAttr calleeAttr;
790 llvm::ArrayRef<mlir::Type> allResultTypes;
791
792 // If we cannot parse a string callee, it means this is an indirect call.
793 if (!parser
794 .parseOptionalAttribute(calleeAttr, CIRDialect::getCalleeAttrName(),
795 result.attributes)
796 .has_value()) {
797 OpAsmParser::UnresolvedOperand indirectVal;
798 // Do not resolve right now, since we need to figure out the type
799 if (parser.parseOperand(indirectVal).failed())
800 return failure();
801 ops.push_back(indirectVal);
802 }
803
804 if (parser.parseLParen())
805 return mlir::failure();
806
807 opsLoc = parser.getCurrentLocation();
808 if (parser.parseOperandList(ops))
809 return mlir::failure();
810 if (parser.parseRParen())
811 return mlir::failure();
812
813 if (hasDestinationBlocks &&
814 parseTryCallDestinations(parser, result).failed()) {
815 return ::mlir::failure();
816 }
817
818 if (parser.parseOptionalKeyword("nothrow").succeeded())
819 result.addAttribute(CIRDialect::getNoThrowAttrName(),
820 mlir::UnitAttr::get(parser.getContext()));
821
822 if (parser.parseOptionalKeyword("side_effect").succeeded()) {
823 if (parser.parseLParen().failed())
824 return failure();
825 cir::SideEffect sideEffect;
826 if (parseCIRKeyword<cir::SideEffect>(parser, sideEffect).failed())
827 return failure();
828 if (parser.parseRParen().failed())
829 return failure();
830 auto attr = cir::SideEffectAttr::get(parser.getContext(), sideEffect);
831 result.addAttribute(CIRDialect::getSideEffectAttrName(), attr);
832 }
833
834 if (parser.parseOptionalAttrDict(result.attributes))
835 return ::mlir::failure();
836
837 if (parser.parseColon())
838 return ::mlir::failure();
839
840 mlir::FunctionType opsFnTy;
841 if (parser.parseType(opsFnTy))
842 return mlir::failure();
843
844 allResultTypes = opsFnTy.getResults();
845 result.addTypes(allResultTypes);
846
847 if (parser.resolveOperands(ops, opsFnTy.getInputs(), opsLoc, result.operands))
848 return mlir::failure();
849
850 return mlir::success();
851}
852
853static void printCallCommon(mlir::Operation *op,
854 mlir::FlatSymbolRefAttr calleeSym,
855 mlir::Value indirectCallee,
856 mlir::OpAsmPrinter &printer, bool isNothrow,
857 cir::SideEffect sideEffect,
858 mlir::Block *normalDest = nullptr,
859 mlir::Block *unwindDest = nullptr) {
860 printer << ' ';
861
862 auto callLikeOp = mlir::cast<cir::CIRCallOpInterface>(op);
863 auto ops = callLikeOp.getArgOperands();
864
865 if (calleeSym) {
866 // Direct calls
867 printer.printAttributeWithoutType(calleeSym);
868 } else {
869 // Indirect calls
870 assert(indirectCallee);
871 printer << indirectCallee;
872 }
873
874 printer << "(" << ops << ")";
875
876 if (normalDest) {
877 assert(unwindDest && "expected two successors");
878 auto tryCall = cast<cir::TryCallOp>(op);
879 printer << ' ' << tryCall.getNormalDest();
880 printer << ",";
881 printer << ' ';
882 printer << tryCall.getUnwindDest();
883 }
884
885 if (isNothrow)
886 printer << " nothrow";
887
888 if (sideEffect != cir::SideEffect::All) {
889 printer << " side_effect(";
890 printer << stringifySideEffect(sideEffect);
891 printer << ")";
892 }
893
895 CIRDialect::getCalleeAttrName(), CIRDialect::getNoThrowAttrName(),
896 CIRDialect::getSideEffectAttrName(),
897 CIRDialect::getOperandSegmentSizesAttrName()};
898 printer.printOptionalAttrDict(op->getAttrs(), elidedAttrs);
899 printer << " : ";
900 printer.printFunctionalType(op->getOperands().getTypes(),
901 op->getResultTypes());
902}
903
904mlir::ParseResult cir::CallOp::parse(mlir::OpAsmParser &parser,
905 mlir::OperationState &result) {
906 return parseCallCommon(parser, result);
907}
908
909void cir::CallOp::print(mlir::OpAsmPrinter &p) {
910 mlir::Value indirectCallee = isIndirect() ? getIndirectCall() : nullptr;
911 cir::SideEffect sideEffect = getSideEffect();
912 printCallCommon(*this, getCalleeAttr(), indirectCallee, p, getNothrow(),
913 sideEffect);
914}
915
916static LogicalResult
917verifyCallCommInSymbolUses(mlir::Operation *op,
918 SymbolTableCollection &symbolTable) {
919 auto fnAttr =
920 op->getAttrOfType<FlatSymbolRefAttr>(CIRDialect::getCalleeAttrName());
921 if (!fnAttr) {
922 // This is an indirect call, thus we don't have to check the symbol uses.
923 return mlir::success();
924 }
925
926 auto fn = symbolTable.lookupNearestSymbolFrom<cir::FuncOp>(op, fnAttr);
927 if (!fn)
928 return op->emitOpError() << "'" << fnAttr.getValue()
929 << "' does not reference a valid function";
930
931 auto callIf = dyn_cast<cir::CIRCallOpInterface>(op);
932 assert(callIf && "expected CIR call interface to be always available");
933
934 // Verify that the operand and result types match the callee. Note that
935 // argument-checking is disabled for functions without a prototype.
936 auto fnType = fn.getFunctionType();
937 if (!fn.getNoProto()) {
938 unsigned numCallOperands = callIf.getNumArgOperands();
939 unsigned numFnOpOperands = fnType.getNumInputs();
940
941 if (!fnType.isVarArg() && numCallOperands != numFnOpOperands)
942 return op->emitOpError("incorrect number of operands for callee");
943 if (fnType.isVarArg() && numCallOperands < numFnOpOperands)
944 return op->emitOpError("too few operands for callee");
945
946 for (unsigned i = 0, e = numFnOpOperands; i != e; ++i)
947 if (callIf.getArgOperand(i).getType() != fnType.getInput(i))
948 return op->emitOpError("operand type mismatch: expected operand type ")
949 << fnType.getInput(i) << ", but provided "
950 << op->getOperand(i).getType() << " for operand number " << i;
951 }
952
954
955 // Void function must not return any results.
956 if (fnType.hasVoidReturn() && op->getNumResults() != 0)
957 return op->emitOpError("callee returns void but call has results");
958
959 // Non-void function calls must return exactly one result.
960 if (!fnType.hasVoidReturn() && op->getNumResults() != 1)
961 return op->emitOpError("incorrect number of results for callee");
962
963 // Parent function and return value types must match.
964 if (!fnType.hasVoidReturn() &&
965 op->getResultTypes().front() != fnType.getReturnType()) {
966 return op->emitOpError("result type mismatch: expected ")
967 << fnType.getReturnType() << ", but provided "
968 << op->getResult(0).getType();
969 }
970
971 return mlir::success();
972}
973
974LogicalResult
975cir::CallOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
976 return verifyCallCommInSymbolUses(*this, symbolTable);
977}
978
979//===----------------------------------------------------------------------===//
980// TryCallOp
981//===----------------------------------------------------------------------===//
982
983mlir::OperandRange cir::TryCallOp::getArgOperands() {
984 if (isIndirect())
985 return getArgs().drop_front(1);
986 return getArgs();
987}
988
989mlir::MutableOperandRange cir::TryCallOp::getArgOperandsMutable() {
990 mlir::MutableOperandRange args = getArgsMutable();
991 if (isIndirect())
992 return args.slice(1, args.size() - 1);
993 return args;
994}
995
996mlir::Value cir::TryCallOp::getIndirectCall() {
997 assert(isIndirect());
998 return getOperand(0);
999}
1000
1001/// Return the operand at index 'i'.
1002Value cir::TryCallOp::getArgOperand(unsigned i) {
1003 if (isIndirect())
1004 ++i;
1005 return getOperand(i);
1006}
1007
1008/// Return the number of operands.
1009unsigned cir::TryCallOp::getNumArgOperands() {
1010 if (isIndirect())
1011 return this->getOperation()->getNumOperands() - 1;
1012 return this->getOperation()->getNumOperands();
1013}
1014
1015LogicalResult
1016cir::TryCallOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1017 return verifyCallCommInSymbolUses(*this, symbolTable);
1018}
1019
1020mlir::ParseResult cir::TryCallOp::parse(mlir::OpAsmParser &parser,
1021 mlir::OperationState &result) {
1022 return parseCallCommon(parser, result, /*hasDestinationBlocks=*/true);
1023}
1024
1025void cir::TryCallOp::print(::mlir::OpAsmPrinter &p) {
1026 mlir::Value indirectCallee = isIndirect() ? getIndirectCall() : nullptr;
1027 cir::SideEffect sideEffect = getSideEffect();
1028 printCallCommon(*this, getCalleeAttr(), indirectCallee, p, getNothrow(),
1029 sideEffect, getNormalDest(), getUnwindDest());
1030}
1031
1032//===----------------------------------------------------------------------===//
1033// ReturnOp
1034//===----------------------------------------------------------------------===//
1035
1036static mlir::LogicalResult checkReturnAndFunction(cir::ReturnOp op,
1037 cir::FuncOp function) {
1038 // ReturnOps currently only have a single optional operand.
1039 if (op.getNumOperands() > 1)
1040 return op.emitOpError() << "expects at most 1 return operand";
1041
1042 // Ensure returned type matches the function signature.
1043 auto expectedTy = function.getFunctionType().getReturnType();
1044 auto actualTy =
1045 (op.getNumOperands() == 0 ? cir::VoidType::get(op.getContext())
1046 : op.getOperand(0).getType());
1047 if (actualTy != expectedTy)
1048 return op.emitOpError() << "returns " << actualTy
1049 << " but enclosing function returns " << expectedTy;
1050
1051 return mlir::success();
1052}
1053
1054mlir::LogicalResult cir::ReturnOp::verify() {
1055 // Returns can be present in multiple different scopes, get the
1056 // wrapping function and start from there.
1057 auto *fnOp = getOperation()->getParentOp();
1058 while (!isa<cir::FuncOp>(fnOp))
1059 fnOp = fnOp->getParentOp();
1060
1061 // Make sure return types match function return type.
1062 if (checkReturnAndFunction(*this, cast<cir::FuncOp>(fnOp)).failed())
1063 return failure();
1064
1065 return success();
1066}
1067
1068//===----------------------------------------------------------------------===//
1069// IfOp
1070//===----------------------------------------------------------------------===//
1071
1072ParseResult cir::IfOp::parse(OpAsmParser &parser, OperationState &result) {
1073 // create the regions for 'then'.
1074 result.regions.reserve(2);
1075 Region *thenRegion = result.addRegion();
1076 Region *elseRegion = result.addRegion();
1077
1078 mlir::Builder &builder = parser.getBuilder();
1079 OpAsmParser::UnresolvedOperand cond;
1080 Type boolType = cir::BoolType::get(builder.getContext());
1081
1082 if (parser.parseOperand(cond) ||
1083 parser.resolveOperand(cond, boolType, result.operands))
1084 return failure();
1085
1086 // Parse 'then' region.
1087 mlir::SMLoc parseThenLoc = parser.getCurrentLocation();
1088 if (parser.parseRegion(*thenRegion, /*arguments=*/{}, /*argTypes=*/{}))
1089 return failure();
1090
1091 if (ensureRegionTerm(parser, *thenRegion, parseThenLoc).failed())
1092 return failure();
1093
1094 // If we find an 'else' keyword, parse the 'else' region.
1095 if (!parser.parseOptionalKeyword("else")) {
1096 mlir::SMLoc parseElseLoc = parser.getCurrentLocation();
1097 if (parser.parseRegion(*elseRegion, /*arguments=*/{}, /*argTypes=*/{}))
1098 return failure();
1099 if (ensureRegionTerm(parser, *elseRegion, parseElseLoc).failed())
1100 return failure();
1101 }
1102
1103 // Parse the optional attribute list.
1104 if (parser.parseOptionalAttrDict(result.attributes))
1105 return failure();
1106 return success();
1107}
1108
1109void cir::IfOp::print(OpAsmPrinter &p) {
1110 p << " " << getCondition() << " ";
1111 mlir::Region &thenRegion = this->getThenRegion();
1112 p.printRegion(thenRegion,
1113 /*printEntryBlockArgs=*/false,
1114 /*printBlockTerminators=*/!omitRegionTerm(thenRegion));
1115
1116 // Print the 'else' regions if it exists and has a block.
1117 mlir::Region &elseRegion = this->getElseRegion();
1118 if (!elseRegion.empty()) {
1119 p << " else ";
1120 p.printRegion(elseRegion,
1121 /*printEntryBlockArgs=*/false,
1122 /*printBlockTerminators=*/!omitRegionTerm(elseRegion));
1123 }
1124
1125 p.printOptionalAttrDict(getOperation()->getAttrs());
1126}
1127
1128/// Default callback for IfOp builders.
1129void cir::buildTerminatedBody(OpBuilder &builder, Location loc) {
1130 // add cir.yield to end of the block
1131 cir::YieldOp::create(builder, loc);
1132}
1133
1134/// Given the region at `index`, or the parent operation if `index` is None,
1135/// return the successor regions. These are the regions that may be selected
1136/// during the flow of control. `operands` is a set of optional attributes that
1137/// correspond to a constant value for each operand, or null if that operand is
1138/// not a constant.
1139void cir::IfOp::getSuccessorRegions(mlir::RegionBranchPoint point,
1140 SmallVectorImpl<RegionSuccessor> &regions) {
1141 // The `then` and the `else` region branch back to the parent operation.
1142 if (!point.isParent()) {
1143 regions.push_back(
1144 RegionSuccessor(getOperation(), getOperation()->getResults()));
1145 return;
1146 }
1147
1148 // Don't consider the else region if it is empty.
1149 Region *elseRegion = &this->getElseRegion();
1150 if (elseRegion->empty())
1151 elseRegion = nullptr;
1152
1153 // If the condition isn't constant, both regions may be executed.
1154 regions.push_back(RegionSuccessor(&getThenRegion()));
1155 // If the else region does not exist, it is not a viable successor.
1156 if (elseRegion)
1157 regions.push_back(RegionSuccessor(elseRegion));
1158
1159 return;
1160}
1161
1162void cir::IfOp::build(OpBuilder &builder, OperationState &result, Value cond,
1163 bool withElseRegion, BuilderCallbackRef thenBuilder,
1164 BuilderCallbackRef elseBuilder) {
1165 assert(thenBuilder && "the builder callback for 'then' must be present");
1166 result.addOperands(cond);
1167
1168 OpBuilder::InsertionGuard guard(builder);
1169 Region *thenRegion = result.addRegion();
1170 builder.createBlock(thenRegion);
1171 thenBuilder(builder, result.location);
1172
1173 Region *elseRegion = result.addRegion();
1174 if (!withElseRegion)
1175 return;
1176
1177 builder.createBlock(elseRegion);
1178 elseBuilder(builder, result.location);
1179}
1180
1181//===----------------------------------------------------------------------===//
1182// ScopeOp
1183//===----------------------------------------------------------------------===//
1184
1185/// Given the region at `index`, or the parent operation if `index` is None,
1186/// return the successor regions. These are the regions that may be selected
1187/// during the flow of control. `operands` is a set of optional attributes
1188/// that correspond to a constant value for each operand, or null if that
1189/// operand is not a constant.
1190void cir::ScopeOp::getSuccessorRegions(
1191 mlir::RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
1192 // The only region always branch back to the parent operation.
1193 if (!point.isParent()) {
1194 regions.push_back(RegionSuccessor(getOperation(), getODSResults(0)));
1195 return;
1196 }
1197
1198 // If the condition isn't constant, both regions may be executed.
1199 regions.push_back(RegionSuccessor(&getScopeRegion()));
1200}
1201
1202void cir::ScopeOp::build(
1203 OpBuilder &builder, OperationState &result,
1204 function_ref<void(OpBuilder &, Type &, Location)> scopeBuilder) {
1205 assert(scopeBuilder && "the builder callback for 'then' must be present");
1206
1207 OpBuilder::InsertionGuard guard(builder);
1208 Region *scopeRegion = result.addRegion();
1209 builder.createBlock(scopeRegion);
1211
1212 mlir::Type yieldTy;
1213 scopeBuilder(builder, yieldTy, result.location);
1214
1215 if (yieldTy)
1216 result.addTypes(TypeRange{yieldTy});
1217}
1218
1219void cir::ScopeOp::build(
1220 OpBuilder &builder, OperationState &result,
1221 function_ref<void(OpBuilder &, Location)> scopeBuilder) {
1222 assert(scopeBuilder && "the builder callback for 'then' must be present");
1223 OpBuilder::InsertionGuard guard(builder);
1224 Region *scopeRegion = result.addRegion();
1225 builder.createBlock(scopeRegion);
1227 scopeBuilder(builder, result.location);
1228}
1229
1230LogicalResult cir::ScopeOp::verify() {
1231 if (getRegion().empty()) {
1232 return emitOpError() << "cir.scope must not be empty since it should "
1233 "include at least an implicit cir.yield ";
1234 }
1235
1236 mlir::Block &lastBlock = getRegion().back();
1237 if (lastBlock.empty() || !lastBlock.mightHaveTerminator() ||
1238 !lastBlock.getTerminator()->hasTrait<OpTrait::IsTerminator>())
1239 return emitOpError() << "last block of cir.scope must be terminated";
1240 return success();
1241}
1242
1243//===----------------------------------------------------------------------===//
1244// BrOp
1245//===----------------------------------------------------------------------===//
1246
1247mlir::SuccessorOperands cir::BrOp::getSuccessorOperands(unsigned index) {
1248 assert(index == 0 && "invalid successor index");
1249 return mlir::SuccessorOperands(getDestOperandsMutable());
1250}
1251
1252Block *cir::BrOp::getSuccessorForOperands(ArrayRef<Attribute>) {
1253 return getDest();
1254}
1255
1256//===----------------------------------------------------------------------===//
// IndirectBrOp
1258//===----------------------------------------------------------------------===//
1259
1260mlir::SuccessorOperands
1261cir::IndirectBrOp::getSuccessorOperands(unsigned index) {
1262 assert(index < getNumSuccessors() && "invalid successor index");
1263 return mlir::SuccessorOperands(getSuccOperandsMutable()[index]);
1264}
1265
1267 OpAsmParser &parser, Type &flagType,
1268 SmallVectorImpl<Block *> &succOperandBlocks,
1269 SmallVectorImpl<SmallVector<OpAsmParser::UnresolvedOperand>> &succOperands,
1270 SmallVectorImpl<SmallVector<Type>> &succOperandsTypes) {
1271 if (failed(parser.parseCommaSeparatedList(
1272 OpAsmParser::Delimiter::Square,
1273 [&]() {
1274 Block *destination = nullptr;
1275 SmallVector<OpAsmParser::UnresolvedOperand> operands;
1276 SmallVector<Type> operandTypes;
1277
1278 if (parser.parseSuccessor(destination).failed())
1279 return failure();
1280
1281 if (succeeded(parser.parseOptionalLParen())) {
1282 if (failed(parser.parseOperandList(
1283 operands, OpAsmParser::Delimiter::None)) ||
1284 failed(parser.parseColonTypeList(operandTypes)) ||
1285 failed(parser.parseRParen()))
1286 return failure();
1287 }
1288 succOperandBlocks.push_back(destination);
1289 succOperands.emplace_back(operands);
1290 succOperandsTypes.emplace_back(operandTypes);
1291 return success();
1292 },
1293 "successor blocks")))
1294 return failure();
1295 return success();
1296}
1297
1298void printIndirectBrOpSucessors(OpAsmPrinter &p, cir::IndirectBrOp op,
1299 Type flagType, SuccessorRange succs,
1300 OperandRangeRange succOperands,
1301 const TypeRangeRange &succOperandsTypes) {
1302 p << "[";
1303 llvm::interleave(
1304 llvm::zip(succs, succOperands),
1305 [&](auto i) {
1306 p.printNewline();
1307 p.printSuccessorAndUseList(std::get<0>(i), std::get<1>(i));
1308 },
1309 [&] { p << ','; });
1310 if (!succOperands.empty())
1311 p.printNewline();
1312 p << "]";
1313}
1314
1315//===----------------------------------------------------------------------===//
1316// BrCondOp
1317//===----------------------------------------------------------------------===//
1318
1319mlir::SuccessorOperands cir::BrCondOp::getSuccessorOperands(unsigned index) {
1320 assert(index < getNumSuccessors() && "invalid successor index");
1321 return SuccessorOperands(index == 0 ? getDestOperandsTrueMutable()
1322 : getDestOperandsFalseMutable());
1323}
1324
1325Block *cir::BrCondOp::getSuccessorForOperands(ArrayRef<Attribute> operands) {
1326 if (IntegerAttr condAttr = dyn_cast_if_present<IntegerAttr>(operands.front()))
1327 return condAttr.getValue().isOne() ? getDestTrue() : getDestFalse();
1328 return nullptr;
1329}
1330
1331//===----------------------------------------------------------------------===//
1332// CaseOp
1333//===----------------------------------------------------------------------===//
1334
1335void cir::CaseOp::getSuccessorRegions(
1336 mlir::RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
1337 if (!point.isParent()) {
1338 regions.push_back(
1339 RegionSuccessor(getOperation(), getOperation()->getResults()));
1340 return;
1341 }
1342 regions.push_back(RegionSuccessor(&getCaseRegion()));
1343}
1344
1345void cir::CaseOp::build(OpBuilder &builder, OperationState &result,
1346 ArrayAttr value, CaseOpKind kind,
1347 OpBuilder::InsertPoint &insertPoint) {
1348 OpBuilder::InsertionGuard guardSwitch(builder);
1349 result.addAttribute("value", value);
1350 result.getOrAddProperties<Properties>().kind =
1351 cir::CaseOpKindAttr::get(builder.getContext(), kind);
1352 Region *caseRegion = result.addRegion();
1353 builder.createBlock(caseRegion);
1354
1355 insertPoint = builder.saveInsertionPoint();
1356}
1357
1358//===----------------------------------------------------------------------===//
1359// SwitchOp
1360//===----------------------------------------------------------------------===//
1361
1362void cir::SwitchOp::getSuccessorRegions(
1363 mlir::RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &region) {
1364 if (!point.isParent()) {
1365 region.push_back(
1366 RegionSuccessor(getOperation(), getOperation()->getResults()));
1367 return;
1368 }
1369
1370 region.push_back(RegionSuccessor(&getBody()));
1371}
1372
1373void cir::SwitchOp::build(OpBuilder &builder, OperationState &result,
1374 Value cond, BuilderOpStateCallbackRef switchBuilder) {
1375 assert(switchBuilder && "the builder callback for regions must be present");
1376 OpBuilder::InsertionGuard guardSwitch(builder);
1377 Region *switchRegion = result.addRegion();
1378 builder.createBlock(switchRegion);
1379 result.addOperands({cond});
1380 switchBuilder(builder, result.location, result);
1381}
1382
1383void cir::SwitchOp::collectCases(llvm::SmallVectorImpl<CaseOp> &cases) {
1384 walk<mlir::WalkOrder::PreOrder>([&](mlir::Operation *op) {
1385 // Don't walk in nested switch op.
1386 if (isa<cir::SwitchOp>(op) && op != *this)
1387 return WalkResult::skip();
1388
1389 if (auto caseOp = dyn_cast<cir::CaseOp>(op))
1390 cases.push_back(caseOp);
1391
1392 return WalkResult::advance();
1393 });
1394}
1395
1396bool cir::SwitchOp::isSimpleForm(llvm::SmallVectorImpl<CaseOp> &cases) {
1397 collectCases(cases);
1398
1399 if (getBody().empty())
1400 return false;
1401
1402 if (!isa<YieldOp>(getBody().front().back()))
1403 return false;
1404
1405 if (!llvm::all_of(getBody().front(),
1406 [](Operation &op) { return isa<CaseOp, YieldOp>(op); }))
1407 return false;
1408
1409 return llvm::all_of(cases, [this](CaseOp op) {
1410 return op->getParentOfType<SwitchOp>() == *this;
1411 });
1412}
1413
1414//===----------------------------------------------------------------------===//
1415// SwitchFlatOp
1416//===----------------------------------------------------------------------===//
1417
1418void cir::SwitchFlatOp::build(OpBuilder &builder, OperationState &result,
1419 Value value, Block *defaultDestination,
1420 ValueRange defaultOperands,
1421 ArrayRef<APInt> caseValues,
1422 BlockRange caseDestinations,
1423 ArrayRef<ValueRange> caseOperands) {
1424
1425 std::vector<mlir::Attribute> caseValuesAttrs;
1426 for (const APInt &val : caseValues)
1427 caseValuesAttrs.push_back(cir::IntAttr::get(value.getType(), val));
1428 mlir::ArrayAttr attrs = ArrayAttr::get(builder.getContext(), caseValuesAttrs);
1429
1430 build(builder, result, value, defaultOperands, caseOperands, attrs,
1431 defaultDestination, caseDestinations);
1432}
1433
1434/// <cases> ::= `[` (case (`,` case )* )? `]`
1435/// <case> ::= integer `:` bb-id (`(` ssa-use-and-type-list `)`)?
1436static ParseResult parseSwitchFlatOpCases(
1437 OpAsmParser &parser, Type flagType, mlir::ArrayAttr &caseValues,
1438 SmallVectorImpl<Block *> &caseDestinations,
1440 &caseOperands,
1441 SmallVectorImpl<llvm::SmallVector<Type>> &caseOperandTypes) {
1442 if (failed(parser.parseLSquare()))
1443 return failure();
1444 if (succeeded(parser.parseOptionalRSquare()))
1445 return success();
1447
1448 auto parseCase = [&]() {
1449 int64_t value = 0;
1450 if (failed(parser.parseInteger(value)))
1451 return failure();
1452
1453 values.push_back(cir::IntAttr::get(flagType, value));
1454
1455 Block *destination;
1457 llvm::SmallVector<Type> operandTypes;
1458 if (parser.parseColon() || parser.parseSuccessor(destination))
1459 return failure();
1460 if (!parser.parseOptionalLParen()) {
1461 if (parser.parseOperandList(operands, OpAsmParser::Delimiter::None,
1462 /*allowResultNumber=*/false) ||
1463 parser.parseColonTypeList(operandTypes) || parser.parseRParen())
1464 return failure();
1465 }
1466 caseDestinations.push_back(destination);
1467 caseOperands.emplace_back(operands);
1468 caseOperandTypes.emplace_back(operandTypes);
1469 return success();
1470 };
1471 if (failed(parser.parseCommaSeparatedList(parseCase)))
1472 return failure();
1473
1474 caseValues = ArrayAttr::get(flagType.getContext(), values);
1475
1476 return parser.parseRSquare();
1477}
1478
1479static void printSwitchFlatOpCases(OpAsmPrinter &p, cir::SwitchFlatOp op,
1480 Type flagType, mlir::ArrayAttr caseValues,
1481 SuccessorRange caseDestinations,
1482 OperandRangeRange caseOperands,
1483 const TypeRangeRange &caseOperandTypes) {
1484 p << '[';
1485 p.printNewline();
1486 if (!caseValues) {
1487 p << ']';
1488 return;
1489 }
1490
1491 size_t index = 0;
1492 llvm::interleave(
1493 llvm::zip(caseValues, caseDestinations),
1494 [&](auto i) {
1495 p << " ";
1496 mlir::Attribute a = std::get<0>(i);
1497 p << mlir::cast<cir::IntAttr>(a).getValue();
1498 p << ": ";
1499 p.printSuccessorAndUseList(std::get<1>(i), caseOperands[index++]);
1500 },
1501 [&] {
1502 p << ',';
1503 p.printNewline();
1504 });
1505 p.printNewline();
1506 p << ']';
1507}
1508
1509//===----------------------------------------------------------------------===//
1510// GlobalOp
1511//===----------------------------------------------------------------------===//
1512
1513static ParseResult parseConstantValue(OpAsmParser &parser,
1514 mlir::Attribute &valueAttr) {
1515 NamedAttrList attr;
1516 return parser.parseAttribute(valueAttr, "value", attr);
1517}
1518
1519static void printConstant(OpAsmPrinter &p, Attribute value) {
1520 p.printAttribute(value);
1521}
1522
1523mlir::LogicalResult cir::GlobalOp::verify() {
1524 // Verify that the initial value, if present, is either a unit attribute or
1525 // an attribute CIR supports.
1526 if (getInitialValue().has_value()) {
1527 if (checkConstantTypes(getOperation(), getSymType(), *getInitialValue())
1528 .failed())
1529 return failure();
1530 }
1531
1532 // TODO(CIR): Many other checks for properties that haven't been upstreamed
1533 // yet.
1534
1535 return success();
1536}
1537
1538void cir::GlobalOp::build(
1539 OpBuilder &odsBuilder, OperationState &odsState, llvm::StringRef sym_name,
1540 mlir::Type sym_type, bool isConstant, cir::GlobalLinkageKind linkage,
1541 function_ref<void(OpBuilder &, Location)> ctorBuilder,
1542 function_ref<void(OpBuilder &, Location)> dtorBuilder) {
1543 odsState.addAttribute(getSymNameAttrName(odsState.name),
1544 odsBuilder.getStringAttr(sym_name));
1545 odsState.addAttribute(getSymTypeAttrName(odsState.name),
1546 mlir::TypeAttr::get(sym_type));
1547 if (isConstant)
1548 odsState.addAttribute(getConstantAttrName(odsState.name),
1549 odsBuilder.getUnitAttr());
1550
1551 cir::GlobalLinkageKindAttr linkageAttr =
1552 cir::GlobalLinkageKindAttr::get(odsBuilder.getContext(), linkage);
1553 odsState.addAttribute(getLinkageAttrName(odsState.name), linkageAttr);
1554
1555 Region *ctorRegion = odsState.addRegion();
1556 if (ctorBuilder) {
1557 odsBuilder.createBlock(ctorRegion);
1558 ctorBuilder(odsBuilder, odsState.location);
1559 }
1560
1561 Region *dtorRegion = odsState.addRegion();
1562 if (dtorBuilder) {
1563 odsBuilder.createBlock(dtorRegion);
1564 dtorBuilder(odsBuilder, odsState.location);
1565 }
1566
1567 odsState.addAttribute(getGlobalVisibilityAttrName(odsState.name),
1568 cir::VisibilityAttr::get(odsBuilder.getContext()));
1569}
1570
1571/// Given the region at `index`, or the parent operation if `index` is None,
1572/// return the successor regions. These are the regions that may be selected
1573/// during the flow of control. `operands` is a set of optional attributes that
1574/// correspond to a constant value for each operand, or null if that operand is
1575/// not a constant.
1576void cir::GlobalOp::getSuccessorRegions(
1577 mlir::RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
1578 // The `ctor` and `dtor` regions always branch back to the parent operation.
1579 if (!point.isParent()) {
1580 regions.push_back(
1581 RegionSuccessor(getOperation(), getOperation()->getResults()));
1582 return;
1583 }
1584
1585 // Don't consider the ctor region if it is empty.
1586 Region *ctorRegion = &this->getCtorRegion();
1587 if (ctorRegion->empty())
1588 ctorRegion = nullptr;
1589
1590 // Don't consider the dtor region if it is empty.
1591 Region *dtorRegion = &this->getCtorRegion();
1592 if (dtorRegion->empty())
1593 dtorRegion = nullptr;
1594
1595 // If the condition isn't constant, both regions may be executed.
1596 if (ctorRegion)
1597 regions.push_back(RegionSuccessor(ctorRegion));
1598 if (dtorRegion)
1599 regions.push_back(RegionSuccessor(dtorRegion));
1600}
1601
1602static void printGlobalOpTypeAndInitialValue(OpAsmPrinter &p, cir::GlobalOp op,
1603 TypeAttr type, Attribute initAttr,
1604 mlir::Region &ctorRegion,
1605 mlir::Region &dtorRegion) {
1606 auto printType = [&]() { p << ": " << type; };
1607 if (!op.isDeclaration()) {
1608 p << "= ";
1609 if (!ctorRegion.empty()) {
1610 p << "ctor ";
1611 printType();
1612 p << " ";
1613 p.printRegion(ctorRegion,
1614 /*printEntryBlockArgs=*/false,
1615 /*printBlockTerminators=*/false);
1616 } else {
1617 // This also prints the type...
1618 if (initAttr)
1619 printConstant(p, initAttr);
1620 }
1621
1622 if (!dtorRegion.empty()) {
1623 p << " dtor ";
1624 p.printRegion(dtorRegion,
1625 /*printEntryBlockArgs=*/false,
1626 /*printBlockTerminators=*/false);
1627 }
1628 } else {
1629 printType();
1630 }
1631}
1632
1633static ParseResult parseGlobalOpTypeAndInitialValue(OpAsmParser &parser,
1634 TypeAttr &typeAttr,
1635 Attribute &initialValueAttr,
1636 mlir::Region &ctorRegion,
1637 mlir::Region &dtorRegion) {
1638 mlir::Type opTy;
1639 if (parser.parseOptionalEqual().failed()) {
1640 // Absence of equal means a declaration, so we need to parse the type.
1641 // cir.global @a : !cir.int<s, 32>
1642 if (parser.parseColonType(opTy))
1643 return failure();
1644 } else {
1645 // Parse contructor, example:
1646 // cir.global @rgb = ctor : type { ... }
1647 if (!parser.parseOptionalKeyword("ctor")) {
1648 if (parser.parseColonType(opTy))
1649 return failure();
1650 auto parseLoc = parser.getCurrentLocation();
1651 if (parser.parseRegion(ctorRegion, /*arguments=*/{}, /*argTypes=*/{}))
1652 return failure();
1653 if (ensureRegionTerm(parser, ctorRegion, parseLoc).failed())
1654 return failure();
1655 } else {
1656 // Parse constant with initializer, examples:
1657 // cir.global @y = 3.400000e+00 : f32
1658 // cir.global @rgb = #cir.const_array<[...] : !cir.array<i8 x 3>>
1659 if (parseConstantValue(parser, initialValueAttr).failed())
1660 return failure();
1661
1662 assert(mlir::isa<mlir::TypedAttr>(initialValueAttr) &&
1663 "Non-typed attrs shouldn't appear here.");
1664 auto typedAttr = mlir::cast<mlir::TypedAttr>(initialValueAttr);
1665 opTy = typedAttr.getType();
1666 }
1667
1668 // Parse destructor, example:
1669 // dtor { ... }
1670 if (!parser.parseOptionalKeyword("dtor")) {
1671 auto parseLoc = parser.getCurrentLocation();
1672 if (parser.parseRegion(dtorRegion, /*arguments=*/{}, /*argTypes=*/{}))
1673 return failure();
1674 if (ensureRegionTerm(parser, dtorRegion, parseLoc).failed())
1675 return failure();
1676 }
1677 }
1678
1679 typeAttr = TypeAttr::get(opTy);
1680 return success();
1681}
1682
1683//===----------------------------------------------------------------------===//
1684// GetGlobalOp
1685//===----------------------------------------------------------------------===//
1686
1687LogicalResult
1688cir::GetGlobalOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1689 // Verify that the result type underlying pointer type matches the type of
1690 // the referenced cir.global or cir.func op.
1691 mlir::Operation *op =
1692 symbolTable.lookupNearestSymbolFrom(*this, getNameAttr());
1693 if (op == nullptr || !(isa<GlobalOp>(op) || isa<FuncOp>(op)))
1694 return emitOpError("'")
1695 << getName()
1696 << "' does not reference a valid cir.global or cir.func";
1697
1698 mlir::Type symTy;
1699 if (auto g = dyn_cast<GlobalOp>(op)) {
1700 symTy = g.getSymType();
1702 // Verify that for thread local global access, the global needs to
1703 // be marked with tls bits.
1704 if (getTls() && !g.getTlsModel())
1705 return emitOpError("access to global not marked thread local");
1706 } else if (auto f = dyn_cast<FuncOp>(op)) {
1707 symTy = f.getFunctionType();
1708 } else {
1709 llvm_unreachable("Unexpected operation for GetGlobalOp");
1710 }
1711
1712 auto resultType = dyn_cast<PointerType>(getAddr().getType());
1713 if (!resultType || symTy != resultType.getPointee())
1714 return emitOpError("result type pointee type '")
1715 << resultType.getPointee() << "' does not match type " << symTy
1716 << " of the global @" << getName();
1717
1718 return success();
1719}
1720
1721//===----------------------------------------------------------------------===//
1722// VTableAddrPointOp
1723//===----------------------------------------------------------------------===//
1724
1725LogicalResult
1726cir::VTableAddrPointOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1727 StringRef name = getName();
1728
1729 // Verify that the result type underlying pointer type matches the type of
1730 // the referenced cir.global.
1731 auto op =
1732 symbolTable.lookupNearestSymbolFrom<cir::GlobalOp>(*this, getNameAttr());
1733 if (!op)
1734 return emitOpError("'")
1735 << name << "' does not reference a valid cir.global";
1736 std::optional<mlir::Attribute> init = op.getInitialValue();
1737 if (!init)
1738 return success();
1739 if (!isa<cir::VTableAttr>(*init))
1740 return emitOpError("Expected #cir.vtable in initializer for global '")
1741 << name << "'";
1742 return success();
1743}
1744
1745//===----------------------------------------------------------------------===//
1746// VTTAddrPointOp
1747//===----------------------------------------------------------------------===//
1748
1749LogicalResult
1750cir::VTTAddrPointOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1751 // VTT ptr is not coming from a symbol.
1752 if (!getName())
1753 return success();
1754 StringRef name = *getName();
1755
1756 // Verify that the result type underlying pointer type matches the type of
1757 // the referenced cir.global op.
1758 auto op =
1759 symbolTable.lookupNearestSymbolFrom<cir::GlobalOp>(*this, getNameAttr());
1760 if (!op)
1761 return emitOpError("'")
1762 << name << "' does not reference a valid cir.global";
1763 std::optional<mlir::Attribute> init = op.getInitialValue();
1764 if (!init)
1765 return success();
1766 if (!isa<cir::ConstArrayAttr>(*init))
1767 return emitOpError(
1768 "Expected constant array in initializer for global VTT '")
1769 << name << "'";
1770 return success();
1771}
1772
1773LogicalResult cir::VTTAddrPointOp::verify() {
1774 // The operation uses either a symbol or a value to operate, but not both
1775 if (getName() && getSymAddr())
1776 return emitOpError("should use either a symbol or value, but not both");
1777
1778 // If not a symbol, stick with the concrete type used for getSymAddr.
1779 if (getSymAddr())
1780 return success();
1781
1782 mlir::Type resultType = getAddr().getType();
1783 mlir::Type resTy = cir::PointerType::get(
1784 cir::PointerType::get(cir::VoidType::get(getContext())));
1785
1786 if (resultType != resTy)
1787 return emitOpError("result type must be ")
1788 << resTy << ", but provided result type is " << resultType;
1789 return success();
1790}
1791
1792//===----------------------------------------------------------------------===//
1793// FuncOp
1794//===----------------------------------------------------------------------===//
1795
1796/// Returns the name used for the linkage attribute. This *must* correspond to
1797/// the name of the attribute in ODS.
1798static llvm::StringRef getLinkageAttrNameString() { return "linkage"; }
1799
1800void cir::FuncOp::build(OpBuilder &builder, OperationState &result,
1801 StringRef name, FuncType type,
1802 GlobalLinkageKind linkage) {
1803 result.addRegion();
1804 result.addAttribute(SymbolTable::getSymbolAttrName(),
1805 builder.getStringAttr(name));
1806 result.addAttribute(getFunctionTypeAttrName(result.name),
1807 TypeAttr::get(type));
1808 result.addAttribute(
1810 GlobalLinkageKindAttr::get(builder.getContext(), linkage));
1811 result.addAttribute(getGlobalVisibilityAttrName(result.name),
1812 cir::VisibilityAttr::get(builder.getContext()));
1813}
1814
1815ParseResult cir::FuncOp::parse(OpAsmParser &parser, OperationState &state) {
1816 llvm::SMLoc loc = parser.getCurrentLocation();
1817 mlir::Builder &builder = parser.getBuilder();
1818
1819 mlir::StringAttr builtinNameAttr = getBuiltinAttrName(state.name);
1820 mlir::StringAttr coroutineNameAttr = getCoroutineAttrName(state.name);
1821 mlir::StringAttr inlineKindNameAttr = getInlineKindAttrName(state.name);
1822 mlir::StringAttr lambdaNameAttr = getLambdaAttrName(state.name);
1823 mlir::StringAttr noProtoNameAttr = getNoProtoAttrName(state.name);
1824 mlir::StringAttr visNameAttr = getSymVisibilityAttrName(state.name);
1825 mlir::StringAttr visibilityNameAttr = getGlobalVisibilityAttrName(state.name);
1826 mlir::StringAttr dsoLocalNameAttr = getDsoLocalAttrName(state.name);
1827 mlir::StringAttr specialMemberAttr = getCxxSpecialMemberAttrName(state.name);
1828
1829 if (::mlir::succeeded(parser.parseOptionalKeyword(builtinNameAttr.strref())))
1830 state.addAttribute(builtinNameAttr, parser.getBuilder().getUnitAttr());
1831 if (::mlir::succeeded(
1832 parser.parseOptionalKeyword(coroutineNameAttr.strref())))
1833 state.addAttribute(coroutineNameAttr, parser.getBuilder().getUnitAttr());
1834
1835 // Parse optional inline kind attribute
1836 cir::InlineKindAttr inlineKindAttr;
1837 if (failed(parseInlineKindAttr(parser, inlineKindAttr)))
1838 return failure();
1839 if (inlineKindAttr)
1840 state.addAttribute(inlineKindNameAttr, inlineKindAttr);
1841
1842 if (::mlir::succeeded(parser.parseOptionalKeyword(lambdaNameAttr.strref())))
1843 state.addAttribute(lambdaNameAttr, parser.getBuilder().getUnitAttr());
1844 if (parser.parseOptionalKeyword(noProtoNameAttr).succeeded())
1845 state.addAttribute(noProtoNameAttr, parser.getBuilder().getUnitAttr());
1846
1847 // Default to external linkage if no keyword is provided.
1848 state.addAttribute(getLinkageAttrNameString(),
1849 GlobalLinkageKindAttr::get(
1850 parser.getContext(),
1852 parser, GlobalLinkageKind::ExternalLinkage)));
1853
1854 ::llvm::StringRef visAttrStr;
1855 if (parser.parseOptionalKeyword(&visAttrStr, {"private", "public", "nested"})
1856 .succeeded()) {
1857 state.addAttribute(visNameAttr,
1858 parser.getBuilder().getStringAttr(visAttrStr));
1859 }
1860
1861 cir::VisibilityAttr cirVisibilityAttr;
1862 parseVisibilityAttr(parser, cirVisibilityAttr);
1863 state.addAttribute(visibilityNameAttr, cirVisibilityAttr);
1864
1865 if (parser.parseOptionalKeyword(dsoLocalNameAttr).succeeded())
1866 state.addAttribute(dsoLocalNameAttr, parser.getBuilder().getUnitAttr());
1867
1868 StringAttr nameAttr;
1869 if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
1870 state.attributes))
1871 return failure();
1875 bool isVariadic = false;
1876 if (function_interface_impl::parseFunctionSignatureWithArguments(
1877 parser, /*allowVariadic=*/true, arguments, isVariadic, resultTypes,
1878 resultAttrs))
1879 return failure();
1881 for (OpAsmParser::Argument &arg : arguments)
1882 argTypes.push_back(arg.type);
1883
1884 if (resultTypes.size() > 1) {
1885 return parser.emitError(
1886 loc, "functions with multiple return types are not supported");
1887 }
1888
1889 mlir::Type returnType =
1890 (resultTypes.empty() ? cir::VoidType::get(builder.getContext())
1891 : resultTypes.front());
1892
1893 cir::FuncType fnType = cir::FuncType::get(argTypes, returnType, isVariadic);
1894 if (!fnType)
1895 return failure();
1896 state.addAttribute(getFunctionTypeAttrName(state.name),
1897 TypeAttr::get(fnType));
1898
1899 bool hasAlias = false;
1900 mlir::StringAttr aliaseeNameAttr = getAliaseeAttrName(state.name);
1901 if (parser.parseOptionalKeyword("alias").succeeded()) {
1902 if (parser.parseLParen().failed())
1903 return failure();
1904 mlir::StringAttr aliaseeAttr;
1905 if (parser.parseOptionalSymbolName(aliaseeAttr).failed())
1906 return failure();
1907 state.addAttribute(aliaseeNameAttr, FlatSymbolRefAttr::get(aliaseeAttr));
1908 if (parser.parseRParen().failed())
1909 return failure();
1910 hasAlias = true;
1911 }
1912
1913 mlir::StringAttr personalityNameAttr = getPersonalityAttrName(state.name);
1914 if (parser.parseOptionalKeyword("personality").succeeded()) {
1915 if (parser.parseLParen().failed())
1916 return failure();
1917 mlir::StringAttr personalityAttr;
1918 if (parser.parseOptionalSymbolName(personalityAttr).failed())
1919 return failure();
1920 state.addAttribute(personalityNameAttr,
1921 FlatSymbolRefAttr::get(personalityAttr));
1922 if (parser.parseRParen().failed())
1923 return failure();
1924 }
1925
1926 auto parseGlobalDtorCtor =
1927 [&](StringRef keyword,
1928 llvm::function_ref<void(std::optional<int> prio)> createAttr)
1929 -> mlir::LogicalResult {
1930 if (mlir::succeeded(parser.parseOptionalKeyword(keyword))) {
1931 std::optional<int> priority;
1932 if (mlir::succeeded(parser.parseOptionalLParen())) {
1933 auto parsedPriority = mlir::FieldParser<int>::parse(parser);
1934 if (mlir::failed(parsedPriority))
1935 return parser.emitError(parser.getCurrentLocation(),
1936 "failed to parse 'priority', of type 'int'");
1937 priority = parsedPriority.value_or(int());
1938 // Parse literal ')'
1939 if (parser.parseRParen())
1940 return failure();
1941 }
1942 createAttr(priority);
1943 }
1944 return success();
1945 };
1946
1947 // Parse CXXSpecialMember attribute
1948 if (parser.parseOptionalKeyword("special_member").succeeded()) {
1949 cir::CXXCtorAttr ctorAttr;
1950 cir::CXXDtorAttr dtorAttr;
1951 cir::CXXAssignAttr assignAttr;
1952 if (parser.parseLess().failed())
1953 return failure();
1954 if (parser.parseOptionalAttribute(ctorAttr).has_value())
1955 state.addAttribute(specialMemberAttr, ctorAttr);
1956 else if (parser.parseOptionalAttribute(dtorAttr).has_value())
1957 state.addAttribute(specialMemberAttr, dtorAttr);
1958 else if (parser.parseOptionalAttribute(assignAttr).has_value())
1959 state.addAttribute(specialMemberAttr, assignAttr);
1960 if (parser.parseGreater().failed())
1961 return failure();
1962 }
1963
1964 if (parseGlobalDtorCtor("global_ctor", [&](std::optional<int> priority) {
1965 mlir::IntegerAttr globalCtorPriorityAttr =
1966 builder.getI32IntegerAttr(priority.value_or(65535));
1967 state.addAttribute(getGlobalCtorPriorityAttrName(state.name),
1968 globalCtorPriorityAttr);
1969 }).failed())
1970 return failure();
1971
1972 if (parseGlobalDtorCtor("global_dtor", [&](std::optional<int> priority) {
1973 mlir::IntegerAttr globalDtorPriorityAttr =
1974 builder.getI32IntegerAttr(priority.value_or(65535));
1975 state.addAttribute(getGlobalDtorPriorityAttrName(state.name),
1976 globalDtorPriorityAttr);
1977 }).failed())
1978 return failure();
1979
1980 // Parse the rest of the attributes.
1981 NamedAttrList parsedAttrs;
1982 if (parser.parseOptionalAttrDictWithKeyword(parsedAttrs))
1983 return failure();
1984
1985 for (StringRef disallowed : cir::FuncOp::getAttributeNames()) {
1986 if (parsedAttrs.get(disallowed))
1987 return parser.emitError(loc, "attribute '")
1988 << disallowed
1989 << "' should not be specified in the explicit attribute list";
1990 }
1991
1992 state.attributes.append(parsedAttrs);
1993
1994 // Parse the optional function body.
1995 auto *body = state.addRegion();
1996 OptionalParseResult parseResult = parser.parseOptionalRegion(
1997 *body, arguments, /*enableNameShadowing=*/false);
1998 if (parseResult.has_value()) {
1999 if (hasAlias)
2000 return parser.emitError(loc, "function alias shall not have a body");
2001 if (failed(*parseResult))
2002 return failure();
2003 // Function body was parsed, make sure its not empty.
2004 if (body->empty())
2005 return parser.emitError(loc, "expected non-empty function body");
2006 }
2007
2008 return success();
2009}
2010
2011// This function corresponds to `llvm::GlobalValue::isDeclaration` and should
2012// have a similar implementation. We don't currently ifuncs or materializable
2013// functions, but those should be handled here as they are implemented.
2014bool cir::FuncOp::isDeclaration() {
2016
2017 std::optional<StringRef> aliasee = getAliasee();
2018 if (!aliasee)
2019 return getFunctionBody().empty();
2020
2021 // Aliases are always definitions.
2022 return false;
2023}
2024
2025bool cir::FuncOp::isCXXSpecialMemberFunction() {
2026 return getCxxSpecialMemberAttr() != nullptr;
2027}
2028
2029bool cir::FuncOp::isCxxConstructor() {
2030 auto attr = getCxxSpecialMemberAttr();
2031 return attr && dyn_cast<CXXCtorAttr>(attr);
2032}
2033
2034bool cir::FuncOp::isCxxDestructor() {
2035 auto attr = getCxxSpecialMemberAttr();
2036 return attr && dyn_cast<CXXDtorAttr>(attr);
2037}
2038
2039bool cir::FuncOp::isCxxSpecialAssignment() {
2040 auto attr = getCxxSpecialMemberAttr();
2041 return attr && dyn_cast<CXXAssignAttr>(attr);
2042}
2043
2044std::optional<CtorKind> cir::FuncOp::getCxxConstructorKind() {
2045 mlir::Attribute attr = getCxxSpecialMemberAttr();
2046 if (attr) {
2047 if (auto ctor = dyn_cast<CXXCtorAttr>(attr))
2048 return ctor.getCtorKind();
2049 }
2050 return std::nullopt;
2051}
2052
2053std::optional<AssignKind> cir::FuncOp::getCxxSpecialAssignKind() {
2054 mlir::Attribute attr = getCxxSpecialMemberAttr();
2055 if (attr) {
2056 if (auto assign = dyn_cast<CXXAssignAttr>(attr))
2057 return assign.getAssignKind();
2058 }
2059 return std::nullopt;
2060}
2061
2062bool cir::FuncOp::isCxxTrivialMemberFunction() {
2063 mlir::Attribute attr = getCxxSpecialMemberAttr();
2064 if (attr) {
2065 if (auto ctor = dyn_cast<CXXCtorAttr>(attr))
2066 return ctor.getIsTrivial();
2067 if (auto dtor = dyn_cast<CXXDtorAttr>(attr))
2068 return dtor.getIsTrivial();
2069 if (auto assign = dyn_cast<CXXAssignAttr>(attr))
2070 return assign.getIsTrivial();
2071 }
2072 return false;
2073}
2074
2075mlir::Region *cir::FuncOp::getCallableRegion() {
2076 // TODO(CIR): This function will have special handling for aliases and a
2077 // check for an external function, once those features have been upstreamed.
2078 return &getBody();
2079}
2080
2081void cir::FuncOp::print(OpAsmPrinter &p) {
2082 if (getBuiltin())
2083 p << " builtin";
2084
2085 if (getCoroutine())
2086 p << " coroutine";
2087
2088 printInlineKindAttr(p, getInlineKindAttr());
2089
2090 if (getLambda())
2091 p << " lambda";
2092
2093 if (getNoProto())
2094 p << " no_proto";
2095
2096 if (getComdat())
2097 p << " comdat";
2098
2099 if (getLinkage() != GlobalLinkageKind::ExternalLinkage)
2100 p << ' ' << stringifyGlobalLinkageKind(getLinkage());
2101
2102 mlir::SymbolTable::Visibility vis = getVisibility();
2103 if (vis != mlir::SymbolTable::Visibility::Public)
2104 p << ' ' << vis;
2105
2106 cir::VisibilityAttr cirVisibilityAttr = getGlobalVisibilityAttr();
2107 if (!cirVisibilityAttr.isDefault()) {
2108 p << ' ';
2109 printVisibilityAttr(p, cirVisibilityAttr);
2110 }
2111
2112 if (getDsoLocal())
2113 p << " dso_local";
2114
2115 p << ' ';
2116 p.printSymbolName(getSymName());
2117 cir::FuncType fnType = getFunctionType();
2118 function_interface_impl::printFunctionSignature(
2119 p, *this, fnType.getInputs(), fnType.isVarArg(), fnType.getReturnTypes());
2120
2121 if (std::optional<StringRef> aliaseeName = getAliasee()) {
2122 p << " alias(";
2123 p.printSymbolName(*aliaseeName);
2124 p << ")";
2125 }
2126
2127 if (std::optional<StringRef> personalityName = getPersonality()) {
2128 p << " personality(";
2129 p.printSymbolName(*personalityName);
2130 p << ")";
2131 }
2132
2133 if (auto specialMemberAttr = getCxxSpecialMember()) {
2134 p << " special_member<";
2135 p.printAttribute(*specialMemberAttr);
2136 p << '>';
2137 }
2138
2139 if (auto globalCtorPriority = getGlobalCtorPriority()) {
2140 p << " global_ctor";
2141 if (globalCtorPriority.value() != 65535)
2142 p << "(" << globalCtorPriority.value() << ")";
2143 }
2144
2145 if (auto globalDtorPriority = getGlobalDtorPriority()) {
2146 p << " global_dtor";
2147 if (globalDtorPriority.value() != 65535)
2148 p << "(" << globalDtorPriority.value() << ")";
2149 }
2150
2151 function_interface_impl::printFunctionAttributes(
2152 p, *this, cir::FuncOp::getAttributeNames());
2153
2154 // Print the body if this is not an external function.
2155 Region &body = getOperation()->getRegion(0);
2156 if (!body.empty()) {
2157 p << ' ';
2158 p.printRegion(body, /*printEntryBlockArgs=*/false,
2159 /*printBlockTerminators=*/true);
2160 }
2161}
2162
2163mlir::LogicalResult cir::FuncOp::verify() {
2164
2165 if (!isDeclaration() && getCoroutine()) {
2166 bool foundAwait = false;
2167 this->walk([&](Operation *op) {
2168 if (auto await = dyn_cast<AwaitOp>(op)) {
2169 foundAwait = true;
2170 return;
2171 }
2172 });
2173 if (!foundAwait)
2174 return emitOpError()
2175 << "coroutine body must use at least one cir.await op";
2176 }
2177
2178 llvm::SmallSet<llvm::StringRef, 16> labels;
2179 llvm::SmallSet<llvm::StringRef, 16> gotos;
2180 llvm::SmallSet<llvm::StringRef, 16> blockAddresses;
2181 bool invalidBlockAddress = false;
2182 getOperation()->walk([&](mlir::Operation *op) {
2183 if (auto lab = dyn_cast<cir::LabelOp>(op)) {
2184 labels.insert(lab.getLabel());
2185 } else if (auto goTo = dyn_cast<cir::GotoOp>(op)) {
2186 gotos.insert(goTo.getLabel());
2187 } else if (auto blkAdd = dyn_cast<cir::BlockAddressOp>(op)) {
2188 if (blkAdd.getBlockAddrInfoAttr().getFunc().getAttr() != getSymName()) {
2189 // Stop the walk early, no need to continue
2190 invalidBlockAddress = true;
2191 return mlir::WalkResult::interrupt();
2192 }
2193 blockAddresses.insert(blkAdd.getBlockAddrInfoAttr().getLabel());
2194 }
2195 return mlir::WalkResult::advance();
2196 });
2197
2198 if (invalidBlockAddress)
2199 return emitOpError() << "blockaddress references a different function";
2200
2201 llvm::SmallSet<llvm::StringRef, 16> mismatched;
2202 if (!labels.empty() || !gotos.empty()) {
2203 mismatched = llvm::set_difference(gotos, labels);
2204
2205 if (!mismatched.empty())
2206 return emitOpError() << "goto/label mismatch";
2207 }
2208
2209 mismatched.clear();
2210
2211 if (!labels.empty() || !blockAddresses.empty()) {
2212 mismatched = llvm::set_difference(blockAddresses, labels);
2213
2214 if (!mismatched.empty())
2215 return emitOpError()
2216 << "expects an existing label target in the referenced function";
2217 }
2218
2219 return success();
2220}
2221
2222//===----------------------------------------------------------------------===//
2223// BinOp
2224//===----------------------------------------------------------------------===//
2225LogicalResult cir::BinOp::verify() {
2226 bool noWrap = getNoUnsignedWrap() || getNoSignedWrap();
2227 bool saturated = getSaturated();
2228
2229 if (!isa<cir::IntType>(getType()) && noWrap)
2230 return emitError()
2231 << "only operations on integer values may have nsw/nuw flags";
2232
2233 bool noWrapOps = getKind() == cir::BinOpKind::Add ||
2234 getKind() == cir::BinOpKind::Sub ||
2235 getKind() == cir::BinOpKind::Mul;
2236
2237 bool saturatedOps =
2238 getKind() == cir::BinOpKind::Add || getKind() == cir::BinOpKind::Sub;
2239
2240 if (noWrap && !noWrapOps)
2241 return emitError() << "The nsw/nuw flags are applicable to opcodes: 'add', "
2242 "'sub' and 'mul'";
2243 if (saturated && !saturatedOps)
2244 return emitError() << "The saturated flag is applicable to opcodes: 'add' "
2245 "and 'sub'";
2246 if (noWrap && saturated)
2247 return emitError() << "The nsw/nuw flags and the saturated flag are "
2248 "mutually exclusive";
2249
2250 return mlir::success();
2251}
2252
2253//===----------------------------------------------------------------------===//
2254// TernaryOp
2255//===----------------------------------------------------------------------===//
2256
2257/// Given the region at `point`, or the parent operation if `point` is None,
2258/// return the successor regions. These are the regions that may be selected
2259/// during the flow of control. `operands` is a set of optional attributes that
2260/// correspond to a constant value for each operand, or null if that operand is
2261/// not a constant.
2262void cir::TernaryOp::getSuccessorRegions(
2263 mlir::RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
2264 // The `true` and the `false` region branch back to the parent operation.
2265 if (!point.isParent()) {
2266 regions.push_back(RegionSuccessor(getOperation(), this->getODSResults(0)));
2267 return;
2268 }
2269
2270 // When branching from the parent operation, both the true and false
2271 // regions are considered possible successors
2272 regions.push_back(RegionSuccessor(&getTrueRegion()));
2273 regions.push_back(RegionSuccessor(&getFalseRegion()));
2274}
2275
2276void cir::TernaryOp::build(
2277 OpBuilder &builder, OperationState &result, Value cond,
2278 function_ref<void(OpBuilder &, Location)> trueBuilder,
2279 function_ref<void(OpBuilder &, Location)> falseBuilder) {
2280 result.addOperands(cond);
2281 OpBuilder::InsertionGuard guard(builder);
2282 Region *trueRegion = result.addRegion();
2283 builder.createBlock(trueRegion);
2284 trueBuilder(builder, result.location);
2285 Region *falseRegion = result.addRegion();
2286 builder.createBlock(falseRegion);
2287 falseBuilder(builder, result.location);
2288
2289 // Get result type from whichever branch has a yield (the other may have
2290 // unreachable from a throw expression)
2291 auto yield =
2292 dyn_cast_or_null<cir::YieldOp>(trueRegion->back().getTerminator());
2293 if (!yield)
2294 yield = dyn_cast_or_null<cir::YieldOp>(falseRegion->back().getTerminator());
2295
2296 assert((yield && yield.getNumOperands() <= 1) &&
2297 "expected zero or one result type");
2298 if (yield.getNumOperands() == 1)
2299 result.addTypes(TypeRange{yield.getOperandTypes().front()});
2300}
2301
2302//===----------------------------------------------------------------------===//
2303// SelectOp
2304//===----------------------------------------------------------------------===//
2305
2306OpFoldResult cir::SelectOp::fold(FoldAdaptor adaptor) {
2307 mlir::Attribute condition = adaptor.getCondition();
2308 if (condition) {
2309 bool conditionValue = mlir::cast<cir::BoolAttr>(condition).getValue();
2310 return conditionValue ? getTrueValue() : getFalseValue();
2311 }
2312
2313 // cir.select if %0 then x else x -> x
2314 mlir::Attribute trueValue = adaptor.getTrueValue();
2315 mlir::Attribute falseValue = adaptor.getFalseValue();
2316 if (trueValue == falseValue)
2317 return trueValue;
2318 if (getTrueValue() == getFalseValue())
2319 return getTrueValue();
2320
2321 return {};
2322}
2323
2324LogicalResult cir::SelectOp::verify() {
2325 // AllTypesMatch already guarantees trueVal and falseVal have matching types.
2326 auto condTy = dyn_cast<cir::VectorType>(getCondition().getType());
2327
2328 // If condition is not a vector, no further checks are needed.
2329 if (!condTy)
2330 return success();
2331
2332 // When condition is a vector, both other operands must also be vectors.
2333 if (!isa<cir::VectorType>(getTrueValue().getType()) ||
2334 !isa<cir::VectorType>(getFalseValue().getType())) {
2335 return emitOpError()
2336 << "expected both true and false operands to be vector types "
2337 "when the condition is a vector boolean type";
2338 }
2339
2340 return success();
2341}
2342
2343//===----------------------------------------------------------------------===//
2344// ShiftOp
2345//===----------------------------------------------------------------------===//
2346LogicalResult cir::ShiftOp::verify() {
2347 mlir::Operation *op = getOperation();
2348 auto op0VecTy = mlir::dyn_cast<cir::VectorType>(op->getOperand(0).getType());
2349 auto op1VecTy = mlir::dyn_cast<cir::VectorType>(op->getOperand(1).getType());
2350 if (!op0VecTy ^ !op1VecTy)
2351 return emitOpError() << "input types cannot be one vector and one scalar";
2352
2353 if (op0VecTy) {
2354 if (op0VecTy.getSize() != op1VecTy.getSize())
2355 return emitOpError() << "input vector types must have the same size";
2356
2357 auto opResultTy = mlir::dyn_cast<cir::VectorType>(getType());
2358 if (!opResultTy)
2359 return emitOpError() << "the type of the result must be a vector "
2360 << "if it is vector shift";
2361
2362 auto op0VecEleTy = mlir::cast<cir::IntType>(op0VecTy.getElementType());
2363 auto op1VecEleTy = mlir::cast<cir::IntType>(op1VecTy.getElementType());
2364 if (op0VecEleTy.getWidth() != op1VecEleTy.getWidth())
2365 return emitOpError()
2366 << "vector operands do not have the same elements sizes";
2367
2368 auto resVecEleTy = mlir::cast<cir::IntType>(opResultTy.getElementType());
2369 if (op0VecEleTy.getWidth() != resVecEleTy.getWidth())
2370 return emitOpError() << "vector operands and result type do not have the "
2371 "same elements sizes";
2372 }
2373
2374 return mlir::success();
2375}
2376
2377//===----------------------------------------------------------------------===//
2378// LabelOp Definitions
2379//===----------------------------------------------------------------------===//
2380
2381LogicalResult cir::LabelOp::verify() {
2382 mlir::Operation *op = getOperation();
2383 mlir::Block *blk = op->getBlock();
2384 if (&blk->front() != op)
2385 return emitError() << "must be the first operation in a block";
2386
2387 return mlir::success();
2388}
2389
2390//===----------------------------------------------------------------------===//
2391// UnaryOp
2392//===----------------------------------------------------------------------===//
2393
2394LogicalResult cir::UnaryOp::verify() {
2395 switch (getKind()) {
2396 case cir::UnaryOpKind::Inc:
2397 case cir::UnaryOpKind::Dec:
2398 case cir::UnaryOpKind::Plus:
2399 case cir::UnaryOpKind::Minus:
2400 case cir::UnaryOpKind::Not:
2401 // Nothing to verify.
2402 return success();
2403 }
2404
2405 llvm_unreachable("Unknown UnaryOp kind?");
2406}
2407
2408static bool isBoolNot(cir::UnaryOp op) {
2409 return isa<cir::BoolType>(op.getInput().getType()) &&
2410 op.getKind() == cir::UnaryOpKind::Not;
2411}
2412
2413// This folder simplifies the sequential boolean not operations.
2414// For instance, the next two unary operations will be eliminated:
2415//
2416// ```mlir
2417// %1 = cir.unary(not, %0) : !cir.bool, !cir.bool
2418// %2 = cir.unary(not, %1) : !cir.bool, !cir.bool
2419// ```
2420//
2421// and the argument of the first one (%0) will be used instead.
2422OpFoldResult cir::UnaryOp::fold(FoldAdaptor adaptor) {
2423 if (auto poison =
2424 mlir::dyn_cast_if_present<cir::PoisonAttr>(adaptor.getInput())) {
2425 // Propagate poison values
2426 return poison;
2427 }
2428
2429 if (isBoolNot(*this))
2430 if (auto previous = getInput().getDefiningOp<cir::UnaryOp>())
2431 if (isBoolNot(previous))
2432 return previous.getInput();
2433
2434 return {};
2435}
2436//===----------------------------------------------------------------------===//
2437// AwaitOp
2438//===----------------------------------------------------------------------===//
2439
2440void cir::AwaitOp::build(OpBuilder &builder, OperationState &result,
2441 cir::AwaitKind kind, BuilderCallbackRef readyBuilder,
2442 BuilderCallbackRef suspendBuilder,
2443 BuilderCallbackRef resumeBuilder) {
2444 result.addAttribute(getKindAttrName(result.name),
2445 cir::AwaitKindAttr::get(builder.getContext(), kind));
2446 {
2447 OpBuilder::InsertionGuard guard(builder);
2448 Region *readyRegion = result.addRegion();
2449 builder.createBlock(readyRegion);
2450 readyBuilder(builder, result.location);
2451 }
2452
2453 {
2454 OpBuilder::InsertionGuard guard(builder);
2455 Region *suspendRegion = result.addRegion();
2456 builder.createBlock(suspendRegion);
2457 suspendBuilder(builder, result.location);
2458 }
2459
2460 {
2461 OpBuilder::InsertionGuard guard(builder);
2462 Region *resumeRegion = result.addRegion();
2463 builder.createBlock(resumeRegion);
2464 resumeBuilder(builder, result.location);
2465 }
2466}
2467
2468void cir::AwaitOp::getSuccessorRegions(
2469 mlir::RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
2470 // If any index all the underlying regions branch back to the parent
2471 // operation.
2472 if (!point.isParent()) {
2473 regions.push_back(
2474 RegionSuccessor(getOperation(), getOperation()->getResults()));
2475 return;
2476 }
2477
2478 // TODO: retrieve information from the promise and only push the
2479 // necessary ones. Example: `std::suspend_never` on initial or final
2480 // await's might allow suspend region to be skipped.
2481 regions.push_back(RegionSuccessor(&this->getReady()));
2482 regions.push_back(RegionSuccessor(&this->getSuspend()));
2483 regions.push_back(RegionSuccessor(&this->getResume()));
2484}
2485
2486LogicalResult cir::AwaitOp::verify() {
2487 if (!isa<ConditionOp>(this->getReady().back().getTerminator()))
2488 return emitOpError("ready region must end with cir.condition");
2489 return success();
2490}
2491
2492//===----------------------------------------------------------------------===//
2493// CopyOp Definitions
2494//===----------------------------------------------------------------------===//
2495
2496LogicalResult cir::CopyOp::verify() {
2497 // A data layout is required for us to know the number of bytes to be copied.
2498 if (!getType().getPointee().hasTrait<DataLayoutTypeInterface::Trait>())
2499 return emitError() << "missing data layout for pointee type";
2500
2501 if (getSrc() == getDst())
2502 return emitError() << "source and destination are the same";
2503
2504 return mlir::success();
2505}
2506
2507//===----------------------------------------------------------------------===//
2508// GetRuntimeMemberOp Definitions
2509//===----------------------------------------------------------------------===//
2510
2511LogicalResult cir::GetRuntimeMemberOp::verify() {
2512 auto recordTy = mlir::cast<RecordType>(getAddr().getType().getPointee());
2513 cir::DataMemberType memberPtrTy = getMember().getType();
2514
2515 if (recordTy != memberPtrTy.getClassTy())
2516 return emitError() << "record type does not match the member pointer type";
2517 if (getType().getPointee() != memberPtrTy.getMemberTy())
2518 return emitError() << "result type does not match the member pointer type";
2519 return mlir::success();
2520}
2521
2522//===----------------------------------------------------------------------===//
2523// GetMemberOp Definitions
2524//===----------------------------------------------------------------------===//
2525
2526LogicalResult cir::GetMemberOp::verify() {
2527 const auto recordTy = dyn_cast<RecordType>(getAddrTy().getPointee());
2528 if (!recordTy)
2529 return emitError() << "expected pointer to a record type";
2530
2531 if (recordTy.getMembers().size() <= getIndex())
2532 return emitError() << "member index out of bounds";
2533
2534 if (recordTy.getMembers()[getIndex()] != getType().getPointee())
2535 return emitError() << "member type mismatch";
2536
2537 return mlir::success();
2538}
2539
2540//===----------------------------------------------------------------------===//
2541// VecCreateOp
2542//===----------------------------------------------------------------------===//
2543
2544OpFoldResult cir::VecCreateOp::fold(FoldAdaptor adaptor) {
2545 if (llvm::any_of(getElements(), [](mlir::Value value) {
2546 return !value.getDefiningOp<cir::ConstantOp>();
2547 }))
2548 return {};
2549
2550 return cir::ConstVectorAttr::get(
2551 getType(), mlir::ArrayAttr::get(getContext(), adaptor.getElements()));
2552}
2553
2554LogicalResult cir::VecCreateOp::verify() {
2555 // Verify that the number of arguments matches the number of elements in the
2556 // vector, and that the type of all the arguments matches the type of the
2557 // elements in the vector.
2558 const cir::VectorType vecTy = getType();
2559 if (getElements().size() != vecTy.getSize()) {
2560 return emitOpError() << "operand count of " << getElements().size()
2561 << " doesn't match vector type " << vecTy
2562 << " element count of " << vecTy.getSize();
2563 }
2564
2565 const mlir::Type elementType = vecTy.getElementType();
2566 for (const mlir::Value element : getElements()) {
2567 if (element.getType() != elementType) {
2568 return emitOpError() << "operand type " << element.getType()
2569 << " doesn't match vector element type "
2570 << elementType;
2571 }
2572 }
2573
2574 return success();
2575}
2576
2577//===----------------------------------------------------------------------===//
2578// VecExtractOp
2579//===----------------------------------------------------------------------===//
2580
2581OpFoldResult cir::VecExtractOp::fold(FoldAdaptor adaptor) {
2582 const auto vectorAttr =
2583 llvm::dyn_cast_if_present<cir::ConstVectorAttr>(adaptor.getVec());
2584 if (!vectorAttr)
2585 return {};
2586
2587 const auto indexAttr =
2588 llvm::dyn_cast_if_present<cir::IntAttr>(adaptor.getIndex());
2589 if (!indexAttr)
2590 return {};
2591
2592 const mlir::ArrayAttr elements = vectorAttr.getElts();
2593 const uint64_t index = indexAttr.getUInt();
2594 if (index >= elements.size())
2595 return {};
2596
2597 return elements[index];
2598}
2599
2600//===----------------------------------------------------------------------===//
2601// VecCmpOp
2602//===----------------------------------------------------------------------===//
2603
2604OpFoldResult cir::VecCmpOp::fold(FoldAdaptor adaptor) {
2605 auto lhsVecAttr =
2606 mlir::dyn_cast_if_present<cir::ConstVectorAttr>(adaptor.getLhs());
2607 auto rhsVecAttr =
2608 mlir::dyn_cast_if_present<cir::ConstVectorAttr>(adaptor.getRhs());
2609 if (!lhsVecAttr || !rhsVecAttr)
2610 return {};
2611
2612 mlir::Type inputElemTy =
2613 mlir::cast<cir::VectorType>(lhsVecAttr.getType()).getElementType();
2614 if (!isAnyIntegerOrFloatingPointType(inputElemTy))
2615 return {};
2616
2617 cir::CmpOpKind opKind = adaptor.getKind();
2618 mlir::ArrayAttr lhsVecElhs = lhsVecAttr.getElts();
2619 mlir::ArrayAttr rhsVecElhs = rhsVecAttr.getElts();
2620 uint64_t vecSize = lhsVecElhs.size();
2621
2622 SmallVector<mlir::Attribute, 16> elements(vecSize);
2623 bool isIntAttr = vecSize && mlir::isa<cir::IntAttr>(lhsVecElhs[0]);
2624 for (uint64_t i = 0; i < vecSize; i++) {
2625 mlir::Attribute lhsAttr = lhsVecElhs[i];
2626 mlir::Attribute rhsAttr = rhsVecElhs[i];
2627 int cmpResult = 0;
2628 switch (opKind) {
2629 case cir::CmpOpKind::lt: {
2630 if (isIntAttr) {
2631 cmpResult = mlir::cast<cir::IntAttr>(lhsAttr).getSInt() <
2632 mlir::cast<cir::IntAttr>(rhsAttr).getSInt();
2633 } else {
2634 cmpResult = mlir::cast<cir::FPAttr>(lhsAttr).getValue() <
2635 mlir::cast<cir::FPAttr>(rhsAttr).getValue();
2636 }
2637 break;
2638 }
2639 case cir::CmpOpKind::le: {
2640 if (isIntAttr) {
2641 cmpResult = mlir::cast<cir::IntAttr>(lhsAttr).getSInt() <=
2642 mlir::cast<cir::IntAttr>(rhsAttr).getSInt();
2643 } else {
2644 cmpResult = mlir::cast<cir::FPAttr>(lhsAttr).getValue() <=
2645 mlir::cast<cir::FPAttr>(rhsAttr).getValue();
2646 }
2647 break;
2648 }
2649 case cir::CmpOpKind::gt: {
2650 if (isIntAttr) {
2651 cmpResult = mlir::cast<cir::IntAttr>(lhsAttr).getSInt() >
2652 mlir::cast<cir::IntAttr>(rhsAttr).getSInt();
2653 } else {
2654 cmpResult = mlir::cast<cir::FPAttr>(lhsAttr).getValue() >
2655 mlir::cast<cir::FPAttr>(rhsAttr).getValue();
2656 }
2657 break;
2658 }
2659 case cir::CmpOpKind::ge: {
2660 if (isIntAttr) {
2661 cmpResult = mlir::cast<cir::IntAttr>(lhsAttr).getSInt() >=
2662 mlir::cast<cir::IntAttr>(rhsAttr).getSInt();
2663 } else {
2664 cmpResult = mlir::cast<cir::FPAttr>(lhsAttr).getValue() >=
2665 mlir::cast<cir::FPAttr>(rhsAttr).getValue();
2666 }
2667 break;
2668 }
2669 case cir::CmpOpKind::eq: {
2670 if (isIntAttr) {
2671 cmpResult = mlir::cast<cir::IntAttr>(lhsAttr).getSInt() ==
2672 mlir::cast<cir::IntAttr>(rhsAttr).getSInt();
2673 } else {
2674 cmpResult = mlir::cast<cir::FPAttr>(lhsAttr).getValue() ==
2675 mlir::cast<cir::FPAttr>(rhsAttr).getValue();
2676 }
2677 break;
2678 }
2679 case cir::CmpOpKind::ne: {
2680 if (isIntAttr) {
2681 cmpResult = mlir::cast<cir::IntAttr>(lhsAttr).getSInt() !=
2682 mlir::cast<cir::IntAttr>(rhsAttr).getSInt();
2683 } else {
2684 cmpResult = mlir::cast<cir::FPAttr>(lhsAttr).getValue() !=
2685 mlir::cast<cir::FPAttr>(rhsAttr).getValue();
2686 }
2687 break;
2688 }
2689 }
2690
2691 elements[i] = cir::IntAttr::get(getType().getElementType(), cmpResult);
2692 }
2693
2694 return cir::ConstVectorAttr::get(
2695 getType(), mlir::ArrayAttr::get(getContext(), elements));
2696}
2697
2698//===----------------------------------------------------------------------===//
2699// VecShuffleOp
2700//===----------------------------------------------------------------------===//
2701
2702OpFoldResult cir::VecShuffleOp::fold(FoldAdaptor adaptor) {
2703 auto vec1Attr =
2704 mlir::dyn_cast_if_present<cir::ConstVectorAttr>(adaptor.getVec1());
2705 auto vec2Attr =
2706 mlir::dyn_cast_if_present<cir::ConstVectorAttr>(adaptor.getVec2());
2707 if (!vec1Attr || !vec2Attr)
2708 return {};
2709
2710 mlir::Type vec1ElemTy =
2711 mlir::cast<cir::VectorType>(vec1Attr.getType()).getElementType();
2712
2713 mlir::ArrayAttr vec1Elts = vec1Attr.getElts();
2714 mlir::ArrayAttr vec2Elts = vec2Attr.getElts();
2715 mlir::ArrayAttr indicesElts = adaptor.getIndices();
2716
2718 elements.reserve(indicesElts.size());
2719
2720 uint64_t vec1Size = vec1Elts.size();
2721 for (const auto &idxAttr : indicesElts.getAsRange<cir::IntAttr>()) {
2722 if (idxAttr.getSInt() == -1) {
2723 elements.push_back(cir::UndefAttr::get(vec1ElemTy));
2724 continue;
2725 }
2726
2727 uint64_t idxValue = idxAttr.getUInt();
2728 elements.push_back(idxValue < vec1Size ? vec1Elts[idxValue]
2729 : vec2Elts[idxValue - vec1Size]);
2730 }
2731
2732 return cir::ConstVectorAttr::get(
2733 getType(), mlir::ArrayAttr::get(getContext(), elements));
2734}
2735
2736LogicalResult cir::VecShuffleOp::verify() {
2737 // The number of elements in the indices array must match the number of
2738 // elements in the result type.
2739 if (getIndices().size() != getResult().getType().getSize()) {
2740 return emitOpError() << ": the number of elements in " << getIndices()
2741 << " and " << getResult().getType() << " don't match";
2742 }
2743
2744 // The element types of the two input vectors and of the result type must
2745 // match.
2746 if (getVec1().getType().getElementType() !=
2747 getResult().getType().getElementType()) {
2748 return emitOpError() << ": element types of " << getVec1().getType()
2749 << " and " << getResult().getType() << " don't match";
2750 }
2751
2752 const uint64_t maxValidIndex =
2753 getVec1().getType().getSize() + getVec2().getType().getSize() - 1;
2754 if (llvm::any_of(
2755 getIndices().getAsRange<cir::IntAttr>(), [&](cir::IntAttr idxAttr) {
2756 return idxAttr.getSInt() != -1 && idxAttr.getUInt() > maxValidIndex;
2757 })) {
2758 return emitOpError() << ": index for __builtin_shufflevector must be "
2759 "less than the total number of vector elements";
2760 }
2761 return success();
2762}
2763
2764//===----------------------------------------------------------------------===//
2765// VecShuffleDynamicOp
2766//===----------------------------------------------------------------------===//
2767
2768OpFoldResult cir::VecShuffleDynamicOp::fold(FoldAdaptor adaptor) {
2769 mlir::Attribute vec = adaptor.getVec();
2770 mlir::Attribute indices = adaptor.getIndices();
2771 if (mlir::isa_and_nonnull<cir::ConstVectorAttr>(vec) &&
2772 mlir::isa_and_nonnull<cir::ConstVectorAttr>(indices)) {
2773 auto vecAttr = mlir::cast<cir::ConstVectorAttr>(vec);
2774 auto indicesAttr = mlir::cast<cir::ConstVectorAttr>(indices);
2775
2776 mlir::ArrayAttr vecElts = vecAttr.getElts();
2777 mlir::ArrayAttr indicesElts = indicesAttr.getElts();
2778
2779 const uint64_t numElements = vecElts.size();
2780
2782 elements.reserve(numElements);
2783
2784 const uint64_t maskBits = llvm::NextPowerOf2(numElements - 1) - 1;
2785 for (const auto &idxAttr : indicesElts.getAsRange<cir::IntAttr>()) {
2786 uint64_t idxValue = idxAttr.getUInt();
2787 uint64_t newIdx = idxValue & maskBits;
2788 elements.push_back(vecElts[newIdx]);
2789 }
2790
2791 return cir::ConstVectorAttr::get(
2792 getType(), mlir::ArrayAttr::get(getContext(), elements));
2793 }
2794
2795 return {};
2796}
2797
2798LogicalResult cir::VecShuffleDynamicOp::verify() {
2799 // The number of elements in the two input vectors must match.
2800 if (getVec().getType().getSize() !=
2801 mlir::cast<cir::VectorType>(getIndices().getType()).getSize()) {
2802 return emitOpError() << ": the number of elements in " << getVec().getType()
2803 << " and " << getIndices().getType() << " don't match";
2804 }
2805 return success();
2806}
2807
2808//===----------------------------------------------------------------------===//
2809// VecTernaryOp
2810//===----------------------------------------------------------------------===//
2811
2812LogicalResult cir::VecTernaryOp::verify() {
2813 // Verify that the condition operand has the same number of elements as the
2814 // other operands. (The automatic verification already checked that all
2815 // operands are vector types and that the second and third operands are the
2816 // same type.)
2817 if (getCond().getType().getSize() != getLhs().getType().getSize()) {
2818 return emitOpError() << ": the number of elements in "
2819 << getCond().getType() << " and " << getLhs().getType()
2820 << " don't match";
2821 }
2822 return success();
2823}
2824
2825OpFoldResult cir::VecTernaryOp::fold(FoldAdaptor adaptor) {
2826 mlir::Attribute cond = adaptor.getCond();
2827 mlir::Attribute lhs = adaptor.getLhs();
2828 mlir::Attribute rhs = adaptor.getRhs();
2829
2830 if (!mlir::isa_and_nonnull<cir::ConstVectorAttr>(cond) ||
2831 !mlir::isa_and_nonnull<cir::ConstVectorAttr>(lhs) ||
2832 !mlir::isa_and_nonnull<cir::ConstVectorAttr>(rhs))
2833 return {};
2834 auto condVec = mlir::cast<cir::ConstVectorAttr>(cond);
2835 auto lhsVec = mlir::cast<cir::ConstVectorAttr>(lhs);
2836 auto rhsVec = mlir::cast<cir::ConstVectorAttr>(rhs);
2837
2838 mlir::ArrayAttr condElts = condVec.getElts();
2839
2841 elements.reserve(condElts.size());
2842
2843 for (const auto &[idx, condAttr] :
2844 llvm::enumerate(condElts.getAsRange<cir::IntAttr>())) {
2845 if (condAttr.getSInt()) {
2846 elements.push_back(lhsVec.getElts()[idx]);
2847 } else {
2848 elements.push_back(rhsVec.getElts()[idx]);
2849 }
2850 }
2851
2852 cir::VectorType vecTy = getLhs().getType();
2853 return cir::ConstVectorAttr::get(
2854 vecTy, mlir::ArrayAttr::get(getContext(), elements));
2855}
2856
2857//===----------------------------------------------------------------------===//
2858// ComplexCreateOp
2859//===----------------------------------------------------------------------===//
2860
2861LogicalResult cir::ComplexCreateOp::verify() {
2862 if (getType().getElementType() != getReal().getType()) {
2863 emitOpError()
2864 << "operand type of cir.complex.create does not match its result type";
2865 return failure();
2866 }
2867
2868 return success();
2869}
2870
2871OpFoldResult cir::ComplexCreateOp::fold(FoldAdaptor adaptor) {
2872 mlir::Attribute real = adaptor.getReal();
2873 mlir::Attribute imag = adaptor.getImag();
2874 if (!real || !imag)
2875 return {};
2876
2877 // When both of real and imag are constants, we can fold the operation into an
2878 // `#cir.const_complex` operation.
2879 auto realAttr = mlir::cast<mlir::TypedAttr>(real);
2880 auto imagAttr = mlir::cast<mlir::TypedAttr>(imag);
2881 return cir::ConstComplexAttr::get(realAttr, imagAttr);
2882}
2883
2884//===----------------------------------------------------------------------===//
2885// ComplexRealOp
2886//===----------------------------------------------------------------------===//
2887
2888LogicalResult cir::ComplexRealOp::verify() {
2889 mlir::Type operandTy = getOperand().getType();
2890 if (auto complexOperandTy = mlir::dyn_cast<cir::ComplexType>(operandTy))
2891 operandTy = complexOperandTy.getElementType();
2892
2893 if (getType() != operandTy) {
2894 emitOpError() << ": result type does not match operand type";
2895 return failure();
2896 }
2897
2898 return success();
2899}
2900
2901OpFoldResult cir::ComplexRealOp::fold(FoldAdaptor adaptor) {
2902 if (!mlir::isa<cir::ComplexType>(getOperand().getType()))
2903 return nullptr;
2904
2905 if (auto complexCreateOp = getOperand().getDefiningOp<cir::ComplexCreateOp>())
2906 return complexCreateOp.getOperand(0);
2907
2908 auto complex =
2909 mlir::cast_if_present<cir::ConstComplexAttr>(adaptor.getOperand());
2910 return complex ? complex.getReal() : nullptr;
2911}
2912
2913//===----------------------------------------------------------------------===//
2914// ComplexImagOp
2915//===----------------------------------------------------------------------===//
2916
2917LogicalResult cir::ComplexImagOp::verify() {
2918 mlir::Type operandTy = getOperand().getType();
2919 if (auto complexOperandTy = mlir::dyn_cast<cir::ComplexType>(operandTy))
2920 operandTy = complexOperandTy.getElementType();
2921
2922 if (getType() != operandTy) {
2923 emitOpError() << ": result type does not match operand type";
2924 return failure();
2925 }
2926
2927 return success();
2928}
2929
2930OpFoldResult cir::ComplexImagOp::fold(FoldAdaptor adaptor) {
2931 if (!mlir::isa<cir::ComplexType>(getOperand().getType()))
2932 return nullptr;
2933
2934 if (auto complexCreateOp = getOperand().getDefiningOp<cir::ComplexCreateOp>())
2935 return complexCreateOp.getOperand(1);
2936
2937 auto complex =
2938 mlir::cast_if_present<cir::ConstComplexAttr>(adaptor.getOperand());
2939 return complex ? complex.getImag() : nullptr;
2940}
2941
2942//===----------------------------------------------------------------------===//
2943// ComplexRealPtrOp
2944//===----------------------------------------------------------------------===//
2945
2946LogicalResult cir::ComplexRealPtrOp::verify() {
2947 mlir::Type resultPointeeTy = getType().getPointee();
2948 cir::PointerType operandPtrTy = getOperand().getType();
2949 auto operandPointeeTy =
2950 mlir::cast<cir::ComplexType>(operandPtrTy.getPointee());
2951
2952 if (resultPointeeTy != operandPointeeTy.getElementType()) {
2953 return emitOpError() << ": result type does not match operand type";
2954 }
2955
2956 return success();
2957}
2958
2959//===----------------------------------------------------------------------===//
2960// ComplexImagPtrOp
2961//===----------------------------------------------------------------------===//
2962
2963LogicalResult cir::ComplexImagPtrOp::verify() {
2964 mlir::Type resultPointeeTy = getType().getPointee();
2965 cir::PointerType operandPtrTy = getOperand().getType();
2966 auto operandPointeeTy =
2967 mlir::cast<cir::ComplexType>(operandPtrTy.getPointee());
2968
2969 if (resultPointeeTy != operandPointeeTy.getElementType()) {
2970 return emitOpError()
2971 << "cir.complex.imag_ptr result type does not match operand type";
2972 }
2973 return success();
2974}
2975
2976//===----------------------------------------------------------------------===//
2977// Bit manipulation operations
2978//===----------------------------------------------------------------------===//
2979
2980static OpFoldResult
2981foldUnaryBitOp(mlir::Attribute inputAttr,
2982 llvm::function_ref<llvm::APInt(const llvm::APInt &)> func,
2983 bool poisonZero = false) {
2984 if (mlir::isa_and_present<cir::PoisonAttr>(inputAttr)) {
2985 // Propagate poison value
2986 return inputAttr;
2987 }
2988
2989 auto input = mlir::dyn_cast_if_present<IntAttr>(inputAttr);
2990 if (!input)
2991 return nullptr;
2992
2993 llvm::APInt inputValue = input.getValue();
2994 if (poisonZero && inputValue.isZero())
2995 return cir::PoisonAttr::get(input.getType());
2996
2997 llvm::APInt resultValue = func(inputValue);
2998 return IntAttr::get(input.getType(), resultValue);
2999}
3000
3001OpFoldResult BitClrsbOp::fold(FoldAdaptor adaptor) {
3002 return foldUnaryBitOp(adaptor.getInput(), [](const llvm::APInt &inputValue) {
3003 unsigned resultValue =
3004 inputValue.getBitWidth() - inputValue.getSignificantBits();
3005 return llvm::APInt(inputValue.getBitWidth(), resultValue);
3006 });
3007}
3008
3009OpFoldResult BitClzOp::fold(FoldAdaptor adaptor) {
3010 return foldUnaryBitOp(
3011 adaptor.getInput(),
3012 [](const llvm::APInt &inputValue) {
3013 unsigned resultValue = inputValue.countLeadingZeros();
3014 return llvm::APInt(inputValue.getBitWidth(), resultValue);
3015 },
3016 getPoisonZero());
3017}
3018
3019OpFoldResult BitCtzOp::fold(FoldAdaptor adaptor) {
3020 return foldUnaryBitOp(
3021 adaptor.getInput(),
3022 [](const llvm::APInt &inputValue) {
3023 return llvm::APInt(inputValue.getBitWidth(),
3024 inputValue.countTrailingZeros());
3025 },
3026 getPoisonZero());
3027}
3028
3029OpFoldResult BitFfsOp::fold(FoldAdaptor adaptor) {
3030 return foldUnaryBitOp(adaptor.getInput(), [](const llvm::APInt &inputValue) {
3031 unsigned trailingZeros = inputValue.countTrailingZeros();
3032 unsigned result =
3033 trailingZeros == inputValue.getBitWidth() ? 0 : trailingZeros + 1;
3034 return llvm::APInt(inputValue.getBitWidth(), result);
3035 });
3036}
3037
3038OpFoldResult BitParityOp::fold(FoldAdaptor adaptor) {
3039 return foldUnaryBitOp(adaptor.getInput(), [](const llvm::APInt &inputValue) {
3040 return llvm::APInt(inputValue.getBitWidth(), inputValue.popcount() % 2);
3041 });
3042}
3043
3044OpFoldResult BitPopcountOp::fold(FoldAdaptor adaptor) {
3045 return foldUnaryBitOp(adaptor.getInput(), [](const llvm::APInt &inputValue) {
3046 return llvm::APInt(inputValue.getBitWidth(), inputValue.popcount());
3047 });
3048}
3049
3050OpFoldResult BitReverseOp::fold(FoldAdaptor adaptor) {
3051 return foldUnaryBitOp(adaptor.getInput(), [](const llvm::APInt &inputValue) {
3052 return inputValue.reverseBits();
3053 });
3054}
3055
3056OpFoldResult ByteSwapOp::fold(FoldAdaptor adaptor) {
3057 return foldUnaryBitOp(adaptor.getInput(), [](const llvm::APInt &inputValue) {
3058 return inputValue.byteSwap();
3059 });
3060}
3061
3062OpFoldResult RotateOp::fold(FoldAdaptor adaptor) {
3063 if (mlir::isa_and_present<cir::PoisonAttr>(adaptor.getInput()) ||
3064 mlir::isa_and_present<cir::PoisonAttr>(adaptor.getAmount())) {
3065 // Propagate poison values
3066 return cir::PoisonAttr::get(getType());
3067 }
3068
3069 auto input = mlir::dyn_cast_if_present<IntAttr>(adaptor.getInput());
3070 auto amount = mlir::dyn_cast_if_present<IntAttr>(adaptor.getAmount());
3071 if (!input && !amount)
3072 return nullptr;
3073
3074 // We could fold cir.rotate even if one of its two operands is not a constant:
3075 // - `cir.rotate left/right %0, 0` could be folded into just %0 even if %0
3076 // is not a constant.
3077 // - `cir.rotate left/right 0/0b111...111, %0` could be folded into 0 or
3078 // 0b111...111 even if %0 is not a constant.
3079
3080 llvm::APInt inputValue;
3081 if (input) {
3082 inputValue = input.getValue();
3083 if (inputValue.isZero() || inputValue.isAllOnes()) {
3084 // An input value of all 0s or all 1s will not change after rotation
3085 return input;
3086 }
3087 }
3088
3089 uint64_t amountValue;
3090 if (amount) {
3091 amountValue = amount.getValue().urem(getInput().getType().getWidth());
3092 if (amountValue == 0) {
3093 // A shift amount of 0 will not change the input value
3094 return getInput();
3095 }
3096 }
3097
3098 if (!input || !amount)
3099 return nullptr;
3100
3101 assert(inputValue.getBitWidth() == getInput().getType().getWidth() &&
3102 "input value must have the same bit width as the input type");
3103
3104 llvm::APInt resultValue;
3105 if (isRotateLeft())
3106 resultValue = inputValue.rotl(amountValue);
3107 else
3108 resultValue = inputValue.rotr(amountValue);
3109
3110 return IntAttr::get(input.getContext(), input.getType(), resultValue);
3111}
3112
3113//===----------------------------------------------------------------------===//
3114// InlineAsmOp
3115//===----------------------------------------------------------------------===//
3116
3117void cir::InlineAsmOp::print(OpAsmPrinter &p) {
3118 p << '(' << getAsmFlavor() << ", ";
3119 p.increaseIndent();
3120 p.printNewline();
3121
3122 llvm::SmallVector<std::string, 3> names{"out", "in", "in_out"};
3123 auto *nameIt = names.begin();
3124 auto *attrIt = getOperandAttrs().begin();
3125
3126 for (mlir::OperandRange ops : getAsmOperands()) {
3127 p << *nameIt << " = ";
3128
3129 p << '[';
3130 llvm::interleaveComma(llvm::make_range(ops.begin(), ops.end()), p,
3131 [&](Value value) {
3132 p.printOperand(value);
3133 p << " : " << value.getType();
3134 if (*attrIt)
3135 p << " (maybe_memory)";
3136 attrIt++;
3137 });
3138 p << "],";
3139 p.printNewline();
3140 ++nameIt;
3141 }
3142
3143 p << "{";
3144 p.printString(getAsmString());
3145 p << " ";
3146 p.printString(getConstraints());
3147 p << "}";
3148 p.decreaseIndent();
3149 p << ')';
3150 if (getSideEffects())
3151 p << " side_effects";
3152
3153 std::array elidedAttrs{
3154 llvm::StringRef("asm_flavor"), llvm::StringRef("asm_string"),
3155 llvm::StringRef("constraints"), llvm::StringRef("operand_attrs"),
3156 llvm::StringRef("operands_segments"), llvm::StringRef("side_effects")};
3157 p.printOptionalAttrDict(getOperation()->getAttrs(), elidedAttrs);
3158
3159 if (auto v = getRes())
3160 p << " -> " << v.getType();
3161}
3162
3163void cir::InlineAsmOp::build(OpBuilder &odsBuilder, OperationState &odsState,
3164 ArrayRef<ValueRange> asmOperands,
3165 StringRef asmString, StringRef constraints,
3166 bool sideEffects, cir::AsmFlavor asmFlavor,
3167 ArrayRef<Attribute> operandAttrs) {
3168 // Set up the operands_segments for VariadicOfVariadic
3169 SmallVector<int32_t> segments;
3170 for (auto operandRange : asmOperands) {
3171 segments.push_back(operandRange.size());
3172 odsState.addOperands(operandRange);
3173 }
3174
3175 odsState.addAttribute(
3176 "operands_segments",
3177 DenseI32ArrayAttr::get(odsBuilder.getContext(), segments));
3178 odsState.addAttribute("asm_string", odsBuilder.getStringAttr(asmString));
3179 odsState.addAttribute("constraints", odsBuilder.getStringAttr(constraints));
3180 odsState.addAttribute("asm_flavor",
3181 AsmFlavorAttr::get(odsBuilder.getContext(), asmFlavor));
3182
3183 if (sideEffects)
3184 odsState.addAttribute("side_effects", odsBuilder.getUnitAttr());
3185
3186 odsState.addAttribute("operand_attrs", odsBuilder.getArrayAttr(operandAttrs));
3187}
3188
3189ParseResult cir::InlineAsmOp::parse(OpAsmParser &parser,
3190 OperationState &result) {
3192 llvm::SmallVector<int32_t> operandsGroupSizes;
3193 std::string asmString, constraints;
3194 Type resType;
3195 MLIRContext *ctxt = parser.getBuilder().getContext();
3196
3197 auto error = [&](const Twine &msg) -> LogicalResult {
3198 return parser.emitError(parser.getCurrentLocation(), msg);
3199 };
3200
3201 auto expected = [&](const std::string &c) {
3202 return error("expected '" + c + "'");
3203 };
3204
3205 if (parser.parseLParen().failed())
3206 return expected("(");
3207
3208 auto flavor = FieldParser<AsmFlavor, AsmFlavor>::parse(parser);
3209 if (failed(flavor))
3210 return error("Unknown AsmFlavor");
3211
3212 if (parser.parseComma().failed())
3213 return expected(",");
3214
3215 auto parseValue = [&](Value &v) {
3216 OpAsmParser::UnresolvedOperand op;
3217
3218 if (parser.parseOperand(op) || parser.parseColon())
3219 return error("can't parse operand");
3220
3221 Type typ;
3222 if (parser.parseType(typ).failed())
3223 return error("can't parse operand type");
3225 if (parser.resolveOperand(op, typ, tmp))
3226 return error("can't resolve operand");
3227 v = tmp[0];
3228 return mlir::success();
3229 };
3230
3231 auto parseOperands = [&](llvm::StringRef name) {
3232 if (parser.parseKeyword(name).failed())
3233 return error("expected " + name + " operands here");
3234 if (parser.parseEqual().failed())
3235 return expected("=");
3236 if (parser.parseLSquare().failed())
3237 return expected("[");
3238
3239 int size = 0;
3240 if (parser.parseOptionalRSquare().succeeded()) {
3241 operandsGroupSizes.push_back(size);
3242 if (parser.parseComma())
3243 return expected(",");
3244 return mlir::success();
3245 }
3246
3247 auto parseOperand = [&]() {
3248 Value val;
3249 if (parseValue(val).succeeded()) {
3250 result.operands.push_back(val);
3251 size++;
3252
3253 if (parser.parseOptionalLParen().failed()) {
3254 operandAttrs.push_back(mlir::Attribute());
3255 return mlir::success();
3256 }
3257
3258 if (parser.parseKeyword("maybe_memory").succeeded()) {
3259 operandAttrs.push_back(mlir::UnitAttr::get(ctxt));
3260 if (parser.parseRParen())
3261 return expected(")");
3262 return mlir::success();
3263 } else {
3264 return expected("maybe_memory");
3265 }
3266 }
3267 return mlir::failure();
3268 };
3269
3270 if (parser.parseCommaSeparatedList(parseOperand).failed())
3271 return mlir::failure();
3272
3273 if (parser.parseRSquare().failed() || parser.parseComma().failed())
3274 return expected("]");
3275 operandsGroupSizes.push_back(size);
3276 return mlir::success();
3277 };
3278
3279 if (parseOperands("out").failed() || parseOperands("in").failed() ||
3280 parseOperands("in_out").failed())
3281 return error("failed to parse operands");
3282
3283 if (parser.parseLBrace())
3284 return expected("{");
3285 if (parser.parseString(&asmString))
3286 return error("asm string parsing failed");
3287 if (parser.parseString(&constraints))
3288 return error("constraints string parsing failed");
3289 if (parser.parseRBrace())
3290 return expected("}");
3291 if (parser.parseRParen())
3292 return expected(")");
3293
3294 if (parser.parseOptionalKeyword("side_effects").succeeded())
3295 result.attributes.set("side_effects", UnitAttr::get(ctxt));
3296
3297 if (parser.parseOptionalArrow().succeeded() &&
3298 parser.parseType(resType).failed())
3299 return mlir::failure();
3300
3301 if (parser.parseOptionalAttrDict(result.attributes).failed())
3302 return mlir::failure();
3303
3304 result.attributes.set("asm_flavor", AsmFlavorAttr::get(ctxt, *flavor));
3305 result.attributes.set("asm_string", StringAttr::get(ctxt, asmString));
3306 result.attributes.set("constraints", StringAttr::get(ctxt, constraints));
3307 result.attributes.set("operand_attrs", ArrayAttr::get(ctxt, operandAttrs));
3308 result.getOrAddProperties<InlineAsmOp::Properties>().operands_segments =
3309 parser.getBuilder().getDenseI32ArrayAttr(operandsGroupSizes);
3310 if (resType)
3311 result.addTypes(TypeRange{resType});
3312
3313 return mlir::success();
3314}
3315
3316//===----------------------------------------------------------------------===//
3317// ThrowOp
3318//===----------------------------------------------------------------------===//
3319
3320mlir::LogicalResult cir::ThrowOp::verify() {
3321 // For the no-rethrow version, it must have at least the exception pointer.
3322 if (rethrows())
3323 return success();
3324
3325 if (getNumOperands() != 0) {
3326 if (getTypeInfo())
3327 return success();
3328 return emitOpError() << "'type_info' symbol attribute missing";
3329 }
3330
3331 return failure();
3332}
3333
3334//===----------------------------------------------------------------------===//
3335// AtomicFetchOp
3336//===----------------------------------------------------------------------===//
3337
3338LogicalResult cir::AtomicFetchOp::verify() {
3339 if (getBinop() != cir::AtomicFetchKind::Add &&
3340 getBinop() != cir::AtomicFetchKind::Sub &&
3341 getBinop() != cir::AtomicFetchKind::Max &&
3342 getBinop() != cir::AtomicFetchKind::Min &&
3343 !mlir::isa<cir::IntType>(getVal().getType()))
3344 return emitError("only atomic add, sub, max, and min operation could "
3345 "operate on floating-point values");
3346 return success();
3347}
3348
3349//===----------------------------------------------------------------------===//
3350// TypeInfoAttr
3351//===----------------------------------------------------------------------===//
3352
3353LogicalResult cir::TypeInfoAttr::verify(
3354 ::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError,
3355 ::mlir::Type type, ::mlir::ArrayAttr typeInfoData) {
3356
3357 if (cir::ConstRecordAttr::verify(emitError, type, typeInfoData).failed())
3358 return failure();
3359
3360 return success();
3361}
3362
3363//===----------------------------------------------------------------------===//
3364// TryOp
3365//===----------------------------------------------------------------------===//
3366
3367void cir::TryOp::getSuccessorRegions(
3368 mlir::RegionBranchPoint point,
3370 // The `try` and the `catchers` region branch back to the parent operation.
3371 if (!point.isParent()) {
3372 regions.push_back(
3373 RegionSuccessor(getOperation(), getOperation()->getResults()));
3374 return;
3375 }
3376
3377 regions.push_back(mlir::RegionSuccessor(&getTryRegion()));
3378
3379 // TODO(CIR): If we know a target function never throws a specific type, we
3380 // can remove the catch handler.
3381 for (mlir::Region &handlerRegion : this->getHandlerRegions())
3382 regions.push_back(mlir::RegionSuccessor(&handlerRegion));
3383}
3384
3385static void
3386printTryHandlerRegions(mlir::OpAsmPrinter &printer, cir::TryOp op,
3387 mlir::MutableArrayRef<mlir::Region> handlerRegions,
3388 mlir::ArrayAttr handlerTypes) {
3389 if (!handlerTypes)
3390 return;
3391
3392 for (const auto [typeIdx, typeAttr] : llvm::enumerate(handlerTypes)) {
3393 if (typeIdx)
3394 printer << " ";
3395
3396 if (mlir::isa<cir::CatchAllAttr>(typeAttr)) {
3397 printer << "catch all ";
3398 } else if (mlir::isa<cir::UnwindAttr>(typeAttr)) {
3399 printer << "unwind ";
3400 } else {
3401 printer << "catch [type ";
3402 printer.printAttribute(typeAttr);
3403 printer << "] ";
3404 }
3405
3406 printer.printRegion(handlerRegions[typeIdx],
3407 /*printEntryBLockArgs=*/false,
3408 /*printBlockTerminators=*/true);
3409 }
3410}
3411
3412static mlir::ParseResult parseTryHandlerRegions(
3413 mlir::OpAsmParser &parser,
3414 llvm::SmallVectorImpl<std::unique_ptr<mlir::Region>> &handlerRegions,
3415 mlir::ArrayAttr &handlerTypes) {
3416
3417 auto parseCheckedCatcherRegion = [&]() -> mlir::ParseResult {
3418 handlerRegions.emplace_back(new mlir::Region);
3419
3420 mlir::Region &currRegion = *handlerRegions.back();
3421 mlir::SMLoc regionLoc = parser.getCurrentLocation();
3422 if (parser.parseRegion(currRegion)) {
3423 handlerRegions.clear();
3424 return failure();
3425 }
3426
3427 if (currRegion.empty())
3428 return parser.emitError(regionLoc, "handler region shall not be empty");
3429
3430 if (!(currRegion.back().mightHaveTerminator() &&
3431 currRegion.back().getTerminator()))
3432 return parser.emitError(
3433 regionLoc, "blocks are expected to be explicitly terminated");
3434
3435 return success();
3436 };
3437
3438 bool hasCatchAll = false;
3440 while (parser.parseOptionalKeyword("catch").succeeded()) {
3441 bool hasLSquare = parser.parseOptionalLSquare().succeeded();
3442
3443 llvm::StringRef attrStr;
3444 if (parser.parseOptionalKeyword(&attrStr, {"all", "type"}).failed())
3445 return parser.emitError(parser.getCurrentLocation(),
3446 "expected 'all' or 'type' keyword");
3447
3448 bool isCatchAll = attrStr == "all";
3449 if (isCatchAll) {
3450 if (hasCatchAll)
3451 return parser.emitError(parser.getCurrentLocation(),
3452 "can't have more than one catch all");
3453 hasCatchAll = true;
3454 }
3455
3456 mlir::Attribute exceptionRTTIAttr;
3457 if (!isCatchAll && parser.parseAttribute(exceptionRTTIAttr).failed())
3458 return parser.emitError(parser.getCurrentLocation(),
3459 "expected valid RTTI info attribute");
3460
3461 catcherAttrs.push_back(isCatchAll
3462 ? cir::CatchAllAttr::get(parser.getContext())
3463 : exceptionRTTIAttr);
3464
3465 if (hasLSquare && isCatchAll)
3466 return parser.emitError(parser.getCurrentLocation(),
3467 "catch all dosen't need RTTI info attribute");
3468
3469 if (hasLSquare && parser.parseRSquare().failed())
3470 return parser.emitError(parser.getCurrentLocation(),
3471 "expected `]` after RTTI info attribute");
3472
3473 if (parseCheckedCatcherRegion().failed())
3474 return mlir::failure();
3475 }
3476
3477 if (parser.parseOptionalKeyword("unwind").succeeded()) {
3478 if (hasCatchAll)
3479 return parser.emitError(parser.getCurrentLocation(),
3480 "unwind can't be used with catch all");
3481
3482 catcherAttrs.push_back(cir::UnwindAttr::get(parser.getContext()));
3483 if (parseCheckedCatcherRegion().failed())
3484 return mlir::failure();
3485 }
3486
3487 handlerTypes = parser.getBuilder().getArrayAttr(catcherAttrs);
3488 return mlir::success();
3489}
3490
3491//===----------------------------------------------------------------------===//
3492// TableGen'd op method definitions
3493//===----------------------------------------------------------------------===//
3494
3495#define GET_OP_CLASSES
3496#include "clang/CIR/Dialect/IR/CIROps.cpp.inc"
static const MemRegion * getRegion(const CallEvent &Call, const MutexDescriptor &Descriptor, bool IsLock)
static mlir::LogicalResult checkReturnAndFunction(cir::ReturnOp op, cir::FuncOp function)
static bool isBoolNot(cir::UnaryOp op)
static mlir::ParseResult parseCallCommon(mlir::OpAsmParser &parser, mlir::OperationState &result, bool hasDestinationBlocks=false)
static bool isIntOrBoolCast(cir::CastOp op)
static void printConstant(OpAsmPrinter &p, Attribute value)
static mlir::ParseResult parseOmittedTerminatorRegion(mlir::OpAsmParser &parser, mlir::Region &region)
ParseResult parseInlineKindAttr(OpAsmParser &parser, cir::InlineKindAttr &inlineKindAttr)
void printInlineKindAttr(OpAsmPrinter &p, cir::InlineKindAttr inlineKindAttr)
void printVisibilityAttr(OpAsmPrinter &printer, cir::VisibilityAttr &visibility)
static ParseResult parseSwitchFlatOpCases(OpAsmParser &parser, Type flagType, mlir::ArrayAttr &caseValues, SmallVectorImpl< Block * > &caseDestinations, SmallVectorImpl< llvm::SmallVector< OpAsmParser::UnresolvedOperand > > &caseOperands, SmallVectorImpl< llvm::SmallVector< Type > > &caseOperandTypes)
<cases> ::= [ (case (, case )* )?
static LogicalResult verifyCallCommInSymbolUses(mlir::Operation *op, SymbolTableCollection &symbolTable)
static LogicalResult ensureRegionTerm(OpAsmParser &parser, Region &region, SMLoc errLoc)
static ParseResult parseGlobalOpTypeAndInitialValue(OpAsmParser &parser, TypeAttr &typeAttr, Attribute &initialValueAttr, mlir::Region &ctorRegion, mlir::Region &dtorRegion)
void printIndirectBrOpSucessors(OpAsmPrinter &p, cir::IndirectBrOp op, Type flagType, SuccessorRange succs, OperandRangeRange succOperands, const TypeRangeRange &succOperandsTypes)
static OpFoldResult foldUnaryBitOp(mlir::Attribute inputAttr, llvm::function_ref< llvm::APInt(const llvm::APInt &)> func, bool poisonZero=false)
static llvm::StringRef getLinkageAttrNameString()
Returns the name used for the linkage attribute.
static RetTy parseOptionalCIRKeyword(AsmParser &parser, EnumTy defaultValue)
Parse an enum from the keyword, or default to the provided default value.
static void printSwitchFlatOpCases(OpAsmPrinter &p, cir::SwitchFlatOp op, Type flagType, mlir::ArrayAttr caseValues, SuccessorRange caseDestinations, OperandRangeRange caseOperands, const TypeRangeRange &caseOperandTypes)
static mlir::ParseResult parseTryCallDestinations(mlir::OpAsmParser &parser, mlir::OperationState &result)
static void printGlobalOpTypeAndInitialValue(OpAsmPrinter &p, cir::GlobalOp op, TypeAttr type, Attribute initAttr, mlir::Region &ctorRegion, mlir::Region &dtorRegion)
static ParseResult parseCIRKeyword(AsmParser &parser, RetTy &result)
Parse an enum from the keyword, return failure if the keyword is not found.
static Value tryFoldCastChain(cir::CastOp op)
void parseVisibilityAttr(OpAsmParser &parser, cir::VisibilityAttr &visibility)
static void printTryHandlerRegions(mlir::OpAsmPrinter &printer, cir::TryOp op, mlir::MutableArrayRef< mlir::Region > handlerRegions, mlir::ArrayAttr handlerTypes)
ParseResult parseIndirectBrOpSucessors(OpAsmParser &parser, Type &flagType, SmallVectorImpl< Block * > &succOperandBlocks, SmallVectorImpl< SmallVector< OpAsmParser::UnresolvedOperand > > &succOperands, SmallVectorImpl< SmallVector< Type > > &succOperandsTypes)
static bool omitRegionTerm(mlir::Region &r)
static void printCallCommon(mlir::Operation *op, mlir::FlatSymbolRefAttr calleeSym, mlir::Value indirectCallee, mlir::OpAsmPrinter &printer, bool isNothrow, cir::SideEffect sideEffect, mlir::Block *normalDest=nullptr, mlir::Block *unwindDest=nullptr)
static void printOmittedTerminatorRegion(mlir::OpAsmPrinter &printer, cir::ScopeOp &op, mlir::Region &region)
static ParseResult parseConstantValue(OpAsmParser &parser, mlir::Attribute &valueAttr)
static LogicalResult checkConstantTypes(mlir::Operation *op, mlir::Type opType, mlir::Attribute attrType)
static mlir::ParseResult parseTryHandlerRegions(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< std::unique_ptr< mlir::Region > > &handlerRegions, mlir::ArrayAttr &handlerTypes)
#define REGISTER_ENUM_TYPE(Ty)
static int parseOptionalKeywordAlternative(AsmParser &parser, ArrayRef< llvm::StringRef > keywords)
llvm::function_ref< void(mlir::OpBuilder &, mlir::Location)> BuilderCallbackRef
Definition CIRDialect.h:37
llvm::function_ref< void( mlir::OpBuilder &, mlir::Location, mlir::OperationState &)> BuilderOpStateCallbackRef
Definition CIRDialect.h:39
static std::optional< NonLoc > getIndex(ProgramStateRef State, const ElementRegion *ER, CharKind CK)
static Decl::Kind getKind(const Decl *D)
TokenType getType() const
Returns the token's type, e.g.
__device__ __2f16 float c
void buildTerminatedBody(mlir::OpBuilder &builder, mlir::Location loc)
const internal::VariadicAllOfMatcher< Attr > attr
const AstTypeMatcher< RecordType > recordType
StringRef getName(const HeaderType T)
Definition HeaderFile.h:38
RangeSelector name(std::string ID)
Given a node with a "name", (like NamedDecl, DeclRefExpr, CxxCtorInitializer, and TypeLoc) selects th...
nullptr
This class represents a compute construct, representing a 'Kind' of ‘parallel’, 'serial',...
__DEVICE__ _Tp arg(const std::complex< _Tp > &__c)
static bool addressSpace()
static bool opCallCallConv()
static bool opScopeCleanupRegion()
static bool supportIFuncAttr()