clang 22.0.0git
CIRDialect.cpp
Go to the documentation of this file.
1//===- CIRDialect.cpp - MLIR CIR ops implementation -----------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements the CIR dialect and its operations.
10//
11//===----------------------------------------------------------------------===//
12
14
18
19#include "mlir/IR/Attributes.h"
20#include "mlir/IR/DialectImplementation.h"
21#include "mlir/Interfaces/ControlFlowInterfaces.h"
22#include "mlir/Interfaces/FunctionImplementation.h"
23#include "mlir/Support/LLVM.h"
24
25#include "clang/CIR/Dialect/IR/CIROpsDialect.cpp.inc"
26#include "clang/CIR/Dialect/IR/CIROpsEnums.cpp.inc"
28#include "llvm/ADT/SetOperations.h"
29#include "llvm/ADT/SmallSet.h"
30#include "llvm/Support/LogicalResult.h"
31
32using namespace mlir;
33using namespace cir;
34
35//===----------------------------------------------------------------------===//
36// CIR Dialect
37//===----------------------------------------------------------------------===//
38namespace {
39struct CIROpAsmDialectInterface : public OpAsmDialectInterface {
40 using OpAsmDialectInterface::OpAsmDialectInterface;
41
42 AliasResult getAlias(Type type, raw_ostream &os) const final {
43 if (auto recordType = dyn_cast<cir::RecordType>(type)) {
44 StringAttr nameAttr = recordType.getName();
45 if (!nameAttr)
46 os << "rec_anon_" << recordType.getKindAsStr();
47 else
48 os << "rec_" << nameAttr.getValue();
49 return AliasResult::OverridableAlias;
50 }
51 if (auto intType = dyn_cast<cir::IntType>(type)) {
52 // We only provide alias for standard integer types (i.e. integer types
53 // whose width is a power of 2 and at least 8).
54 unsigned width = intType.getWidth();
55 if (width < 8 || !llvm::isPowerOf2_32(width))
56 return AliasResult::NoAlias;
57 os << intType.getAlias();
58 return AliasResult::OverridableAlias;
59 }
60 if (auto voidType = dyn_cast<cir::VoidType>(type)) {
61 os << voidType.getAlias();
62 return AliasResult::OverridableAlias;
63 }
64
65 return AliasResult::NoAlias;
66 }
67
68 AliasResult getAlias(Attribute attr, raw_ostream &os) const final {
69 if (auto boolAttr = mlir::dyn_cast<cir::BoolAttr>(attr)) {
70 os << (boolAttr.getValue() ? "true" : "false");
71 return AliasResult::FinalAlias;
72 }
73 if (auto bitfield = mlir::dyn_cast<cir::BitfieldInfoAttr>(attr)) {
74 os << "bfi_" << bitfield.getName().str();
75 return AliasResult::FinalAlias;
76 }
77 if (auto dynCastInfoAttr = mlir::dyn_cast<cir::DynamicCastInfoAttr>(attr)) {
78 os << dynCastInfoAttr.getAlias();
79 return AliasResult::FinalAlias;
80 }
81 return AliasResult::NoAlias;
82 }
83};
84} // namespace
85
/// Dialect registration hook: registers the CIR types and attributes, all
/// operations listed by the generated CIROps.cpp.inc, and the asm interface
/// that provides printer aliases.
void cir::CIRDialect::initialize() {
  registerTypes();
  registerAttributes();
  // With GET_OP_LIST defined, the generated include expands to the list of
  // op classes used as the template arguments here.
  addOperations<
#define GET_OP_LIST
#include "clang/CIR/Dialect/IR/CIROps.cpp.inc"
      >();
  addInterfaces<CIROpAsmDialectInterface>();
}
95
96Operation *cir::CIRDialect::materializeConstant(mlir::OpBuilder &builder,
97 mlir::Attribute value,
98 mlir::Type type,
99 mlir::Location loc) {
100 return cir::ConstantOp::create(builder, loc, type,
101 mlir::cast<mlir::TypedAttr>(value));
102}
103
104//===----------------------------------------------------------------------===//
105// Helpers
106//===----------------------------------------------------------------------===//
107
108// Parses one of the keywords provided in the list `keywords` and returns the
109// position of the parsed keyword in the list. If none of the keywords from the
110// list is parsed, returns -1.
111static int parseOptionalKeywordAlternative(AsmParser &parser,
112 ArrayRef<llvm::StringRef> keywords) {
113 for (auto en : llvm::enumerate(keywords)) {
114 if (succeeded(parser.parseOptionalKeyword(en.value())))
115 return en.index();
116 }
117 return -1;
118}
119
namespace {
// Per-enum adapter consumed by the keyword parsers that follow. For each
// registered CIR enum it exposes a way to spell an enumerator and the largest
// enumerator value. The primary template is intentionally empty, so only
// registered enums can be used.
template <typename Ty> struct EnumTraits {};

// Specialize EnumTraits for `cir::Ty` by forwarding to `stringify##Ty` and
// `cir::getMaxEnumValFor##Ty` (presumably the helpers generated into the
// CIROpsEnums .inc files — confirm if adding a new enum).
#define REGISTER_ENUM_TYPE(Ty)                                                 \
  template <> struct EnumTraits<cir::Ty> {                                     \
    static llvm::StringRef stringify(cir::Ty value) {                          \
      return stringify##Ty(value);                                             \
    }                                                                          \
    static unsigned getMaxEnumVal() { return cir::getMaxEnumValFor##Ty(); }    \
  }

REGISTER_ENUM_TYPE(GlobalLinkageKind);
REGISTER_ENUM_TYPE(VisibilityKind);
REGISTER_ENUM_TYPE(SideEffect);
} // namespace
135
136/// Parse an enum from the keyword, or default to the provided default value.
137/// The return type is the enum type by default, unless overriden with the
138/// second template argument.
139template <typename EnumTy, typename RetTy = EnumTy>
140static RetTy parseOptionalCIRKeyword(AsmParser &parser, EnumTy defaultValue) {
142 for (unsigned i = 0, e = EnumTraits<EnumTy>::getMaxEnumVal(); i <= e; ++i)
143 names.push_back(EnumTraits<EnumTy>::stringify(static_cast<EnumTy>(i)));
144
145 int index = parseOptionalKeywordAlternative(parser, names);
146 if (index == -1)
147 return static_cast<RetTy>(defaultValue);
148 return static_cast<RetTy>(index);
149}
150
151/// Parse an enum from the keyword, return failure if the keyword is not found.
152template <typename EnumTy, typename RetTy = EnumTy>
153static ParseResult parseCIRKeyword(AsmParser &parser, RetTy &result) {
155 for (unsigned i = 0, e = EnumTraits<EnumTy>::getMaxEnumVal(); i <= e; ++i)
156 names.push_back(EnumTraits<EnumTy>::stringify(static_cast<EnumTy>(i)));
157
158 int index = parseOptionalKeywordAlternative(parser, names);
159 if (index == -1)
160 return failure();
161 result = static_cast<RetTy>(index);
162 return success();
163}
164
165// Check if a region's termination omission is valid and, if so, creates and
166// inserts the omitted terminator into the region.
167static LogicalResult ensureRegionTerm(OpAsmParser &parser, Region &region,
168 SMLoc errLoc) {
169 Location eLoc = parser.getEncodedSourceLoc(parser.getCurrentLocation());
170 OpBuilder builder(parser.getBuilder().getContext());
171
172 // Insert empty block in case the region is empty to ensure the terminator
173 // will be inserted
174 if (region.empty())
175 builder.createBlock(&region);
176
177 Block &block = region.back();
178 // Region is properly terminated: nothing to do.
179 if (!block.empty() && block.back().hasTrait<OpTrait::IsTerminator>())
180 return success();
181
182 // Check for invalid terminator omissions.
183 if (!region.hasOneBlock())
184 return parser.emitError(errLoc,
185 "multi-block region must not omit terminator");
186
187 // Terminator was omitted correctly: recreate it.
188 builder.setInsertionPointToEnd(&block);
189 cir::YieldOp::create(builder, eLoc);
190 return success();
191}
192
193// True if the region's terminator should be omitted.
194static bool omitRegionTerm(mlir::Region &r) {
195 const auto singleNonEmptyBlock = r.hasOneBlock() && !r.back().empty();
196 const auto yieldsNothing = [&r]() {
197 auto y = dyn_cast<cir::YieldOp>(r.back().getTerminator());
198 return y && y.getArgs().empty();
199 };
200 return singleNonEmptyBlock && yieldsNothing();
201}
202
203void printVisibilityAttr(OpAsmPrinter &printer,
204 cir::VisibilityAttr &visibility) {
205 switch (visibility.getValue()) {
206 case cir::VisibilityKind::Hidden:
207 printer << "hidden";
208 break;
209 case cir::VisibilityKind::Protected:
210 printer << "protected";
211 break;
212 case cir::VisibilityKind::Default:
213 break;
214 }
215}
216
217void parseVisibilityAttr(OpAsmParser &parser, cir::VisibilityAttr &visibility) {
218 cir::VisibilityKind visibilityKind =
219 parseOptionalCIRKeyword(parser, cir::VisibilityKind::Default);
220 visibility = cir::VisibilityAttr::get(parser.getContext(), visibilityKind);
221}
222
223//===----------------------------------------------------------------------===//
224// CIR Custom Parsers/Printers
225//===----------------------------------------------------------------------===//
226
227static mlir::ParseResult parseOmittedTerminatorRegion(mlir::OpAsmParser &parser,
228 mlir::Region &region) {
229 auto regionLoc = parser.getCurrentLocation();
230 if (parser.parseRegion(region))
231 return failure();
232 if (ensureRegionTerm(parser, region, regionLoc).failed())
233 return failure();
234 return success();
235}
236
237static void printOmittedTerminatorRegion(mlir::OpAsmPrinter &printer,
238 cir::ScopeOp &op,
239 mlir::Region &region) {
240 printer.printRegion(region,
241 /*printEntryBlockArgs=*/false,
242 /*printBlockTerminators=*/!omitRegionTerm(region));
243}
244
245//===----------------------------------------------------------------------===//
246// AllocaOp
247//===----------------------------------------------------------------------===//
248
249void cir::AllocaOp::build(mlir::OpBuilder &odsBuilder,
250 mlir::OperationState &odsState, mlir::Type addr,
251 mlir::Type allocaType, llvm::StringRef name,
252 mlir::IntegerAttr alignment) {
253 odsState.addAttribute(getAllocaTypeAttrName(odsState.name),
254 mlir::TypeAttr::get(allocaType));
255 odsState.addAttribute(getNameAttrName(odsState.name),
256 odsBuilder.getStringAttr(name));
257 if (alignment) {
258 odsState.addAttribute(getAlignmentAttrName(odsState.name), alignment);
259 }
260 odsState.addTypes(addr);
261}
262
263//===----------------------------------------------------------------------===//
264// BreakOp
265//===----------------------------------------------------------------------===//
266
267LogicalResult cir::BreakOp::verify() {
269 if (!getOperation()->getParentOfType<LoopOpInterface>() &&
270 !getOperation()->getParentOfType<SwitchOp>())
271 return emitOpError("must be within a loop");
272 return success();
273}
274
275//===----------------------------------------------------------------------===//
276// ConditionOp
277//===----------------------------------------------------------------------===//
278
279//===----------------------------------
280// BranchOpTerminatorInterface Methods
281//===----------------------------------
282
283void cir::ConditionOp::getSuccessorRegions(
284 ArrayRef<Attribute> operands, SmallVectorImpl<RegionSuccessor> &regions) {
285 // TODO(cir): The condition value may be folded to a constant, narrowing
286 // down its list of possible successors.
287
288 // Parent is a loop: condition may branch to the body or to the parent op.
289 if (auto loopOp = dyn_cast<LoopOpInterface>(getOperation()->getParentOp())) {
290 regions.emplace_back(&loopOp.getBody(), loopOp.getBody().getArguments());
291 regions.emplace_back(getOperation(), loopOp->getResults());
292 }
293
294 // Parent is an await: condition may branch to resume or suspend regions.
295 auto await = cast<AwaitOp>(getOperation()->getParentOp());
296 regions.emplace_back(&await.getResume(), await.getResume().getArguments());
297 regions.emplace_back(&await.getSuspend(), await.getSuspend().getArguments());
298}
299
300MutableOperandRange
301cir::ConditionOp::getMutableSuccessorOperands(RegionSuccessor point) {
302 // No values are yielded to the successor region.
303 return MutableOperandRange(getOperation(), 0, 0);
304}
305
306LogicalResult cir::ConditionOp::verify() {
307 if (!isa<LoopOpInterface, AwaitOp>(getOperation()->getParentOp()))
308 return emitOpError("condition must be within a conditional region");
309 return success();
310}
311
312//===----------------------------------------------------------------------===//
313// ConstantOp
314//===----------------------------------------------------------------------===//
315
316static LogicalResult checkConstantTypes(mlir::Operation *op, mlir::Type opType,
317 mlir::Attribute attrType) {
318 if (isa<cir::ConstPtrAttr>(attrType)) {
319 if (!mlir::isa<cir::PointerType>(opType))
320 return op->emitOpError(
321 "pointer constant initializing a non-pointer type");
322 return success();
323 }
324
325 if (isa<cir::ZeroAttr>(attrType)) {
326 if (isa<cir::RecordType, cir::ArrayType, cir::VectorType, cir::ComplexType>(
327 opType))
328 return success();
329 return op->emitOpError(
330 "zero expects struct, array, vector, or complex type");
331 }
332
333 if (mlir::isa<cir::UndefAttr>(attrType)) {
334 if (!mlir::isa<cir::VoidType>(opType))
335 return success();
336 return op->emitOpError("undef expects non-void type");
337 }
338
339 if (mlir::isa<cir::BoolAttr>(attrType)) {
340 if (!mlir::isa<cir::BoolType>(opType))
341 return op->emitOpError("result type (")
342 << opType << ") must be '!cir.bool' for '" << attrType << "'";
343 return success();
344 }
345
346 if (mlir::isa<cir::IntAttr, cir::FPAttr>(attrType)) {
347 auto at = cast<TypedAttr>(attrType);
348 if (at.getType() != opType) {
349 return op->emitOpError("result type (")
350 << opType << ") does not match value type (" << at.getType()
351 << ")";
352 }
353 return success();
354 }
355
356 if (mlir::isa<cir::ConstArrayAttr, cir::ConstVectorAttr,
357 cir::ConstComplexAttr, cir::ConstRecordAttr,
358 cir::GlobalViewAttr, cir::PoisonAttr, cir::TypeInfoAttr,
359 cir::VTableAttr>(attrType))
360 return success();
361
362 assert(isa<TypedAttr>(attrType) && "What else could we be looking at here?");
363 return op->emitOpError("global with type ")
364 << cast<TypedAttr>(attrType).getType() << " not yet supported";
365}
366
367LogicalResult cir::ConstantOp::verify() {
368 // ODS already generates checks to make sure the result type is valid. We just
369 // need to additionally check that the value's attribute type is consistent
370 // with the result type.
371 return checkConstantTypes(getOperation(), getType(), getValue());
372}
373
374OpFoldResult cir::ConstantOp::fold(FoldAdaptor /*adaptor*/) {
375 return getValue();
376}
377
378//===----------------------------------------------------------------------===//
379// ContinueOp
380//===----------------------------------------------------------------------===//
381
382LogicalResult cir::ContinueOp::verify() {
383 if (!getOperation()->getParentOfType<LoopOpInterface>())
384 return emitOpError("must be within a loop");
385 return success();
386}
387
388//===----------------------------------------------------------------------===//
389// CastOp
390//===----------------------------------------------------------------------===//
391
392LogicalResult cir::CastOp::verify() {
393 mlir::Type resType = getType();
394 mlir::Type srcType = getSrc().getType();
395
396 // Verify address space casts for pointer types. given that
397 // casts for within a different address space are illegal.
398 auto srcPtrTy = mlir::dyn_cast<cir::PointerType>(srcType);
399 auto resPtrTy = mlir::dyn_cast<cir::PointerType>(resType);
400 if (srcPtrTy && resPtrTy && (getKind() != cir::CastKind::address_space))
401 if (srcPtrTy.getAddrSpace() != resPtrTy.getAddrSpace()) {
402 return emitOpError() << "result type address space does not match the "
403 "address space of the operand";
404 }
405
406 if (mlir::isa<cir::VectorType>(srcType) &&
407 mlir::isa<cir::VectorType>(resType)) {
408 // Use the element type of the vector to verify the cast kind. (Except for
409 // bitcast, see below.)
410 srcType = mlir::dyn_cast<cir::VectorType>(srcType).getElementType();
411 resType = mlir::dyn_cast<cir::VectorType>(resType).getElementType();
412 }
413
414 switch (getKind()) {
415 case cir::CastKind::int_to_bool: {
416 if (!mlir::isa<cir::BoolType>(resType))
417 return emitOpError() << "requires !cir.bool type for result";
418 if (!mlir::isa<cir::IntType>(srcType))
419 return emitOpError() << "requires !cir.int type for source";
420 return success();
421 }
422 case cir::CastKind::ptr_to_bool: {
423 if (!mlir::isa<cir::BoolType>(resType))
424 return emitOpError() << "requires !cir.bool type for result";
425 if (!mlir::isa<cir::PointerType>(srcType))
426 return emitOpError() << "requires !cir.ptr type for source";
427 return success();
428 }
429 case cir::CastKind::integral: {
430 if (!mlir::isa<cir::IntType>(resType))
431 return emitOpError() << "requires !cir.int type for result";
432 if (!mlir::isa<cir::IntType>(srcType))
433 return emitOpError() << "requires !cir.int type for source";
434 return success();
435 }
436 case cir::CastKind::array_to_ptrdecay: {
437 const auto arrayPtrTy = mlir::dyn_cast<cir::PointerType>(srcType);
438 const auto flatPtrTy = mlir::dyn_cast<cir::PointerType>(resType);
439 if (!arrayPtrTy || !flatPtrTy)
440 return emitOpError() << "requires !cir.ptr type for source and result";
441
442 // TODO(CIR): Make sure the AddrSpace of both types are equals
443 return success();
444 }
445 case cir::CastKind::bitcast: {
446 // Handle the pointer types first.
447 auto srcPtrTy = mlir::dyn_cast<cir::PointerType>(srcType);
448 auto resPtrTy = mlir::dyn_cast<cir::PointerType>(resType);
449
450 if (srcPtrTy && resPtrTy) {
451 return success();
452 }
453
454 return success();
455 }
456 case cir::CastKind::floating: {
457 if (!mlir::isa<cir::FPTypeInterface>(srcType) ||
458 !mlir::isa<cir::FPTypeInterface>(resType))
459 return emitOpError() << "requires !cir.float type for source and result";
460 return success();
461 }
462 case cir::CastKind::float_to_int: {
463 if (!mlir::isa<cir::FPTypeInterface>(srcType))
464 return emitOpError() << "requires !cir.float type for source";
465 if (!mlir::dyn_cast<cir::IntType>(resType))
466 return emitOpError() << "requires !cir.int type for result";
467 return success();
468 }
469 case cir::CastKind::int_to_ptr: {
470 if (!mlir::dyn_cast<cir::IntType>(srcType))
471 return emitOpError() << "requires !cir.int type for source";
472 if (!mlir::dyn_cast<cir::PointerType>(resType))
473 return emitOpError() << "requires !cir.ptr type for result";
474 return success();
475 }
476 case cir::CastKind::ptr_to_int: {
477 if (!mlir::dyn_cast<cir::PointerType>(srcType))
478 return emitOpError() << "requires !cir.ptr type for source";
479 if (!mlir::dyn_cast<cir::IntType>(resType))
480 return emitOpError() << "requires !cir.int type for result";
481 return success();
482 }
483 case cir::CastKind::float_to_bool: {
484 if (!mlir::isa<cir::FPTypeInterface>(srcType))
485 return emitOpError() << "requires !cir.float type for source";
486 if (!mlir::isa<cir::BoolType>(resType))
487 return emitOpError() << "requires !cir.bool type for result";
488 return success();
489 }
490 case cir::CastKind::bool_to_int: {
491 if (!mlir::isa<cir::BoolType>(srcType))
492 return emitOpError() << "requires !cir.bool type for source";
493 if (!mlir::isa<cir::IntType>(resType))
494 return emitOpError() << "requires !cir.int type for result";
495 return success();
496 }
497 case cir::CastKind::int_to_float: {
498 if (!mlir::isa<cir::IntType>(srcType))
499 return emitOpError() << "requires !cir.int type for source";
500 if (!mlir::isa<cir::FPTypeInterface>(resType))
501 return emitOpError() << "requires !cir.float type for result";
502 return success();
503 }
504 case cir::CastKind::bool_to_float: {
505 if (!mlir::isa<cir::BoolType>(srcType))
506 return emitOpError() << "requires !cir.bool type for source";
507 if (!mlir::isa<cir::FPTypeInterface>(resType))
508 return emitOpError() << "requires !cir.float type for result";
509 return success();
510 }
511 case cir::CastKind::address_space: {
512 auto srcPtrTy = mlir::dyn_cast<cir::PointerType>(srcType);
513 auto resPtrTy = mlir::dyn_cast<cir::PointerType>(resType);
514 if (!srcPtrTy || !resPtrTy)
515 return emitOpError() << "requires !cir.ptr type for source and result";
516 if (srcPtrTy.getPointee() != resPtrTy.getPointee())
517 return emitOpError() << "requires two types differ in addrspace only";
518 return success();
519 }
520 case cir::CastKind::float_to_complex: {
521 if (!mlir::isa<cir::FPTypeInterface>(srcType))
522 return emitOpError() << "requires !cir.float type for source";
523 auto resComplexTy = mlir::dyn_cast<cir::ComplexType>(resType);
524 if (!resComplexTy)
525 return emitOpError() << "requires !cir.complex type for result";
526 if (srcType != resComplexTy.getElementType())
527 return emitOpError() << "requires source type match result element type";
528 return success();
529 }
530 case cir::CastKind::int_to_complex: {
531 if (!mlir::isa<cir::IntType>(srcType))
532 return emitOpError() << "requires !cir.int type for source";
533 auto resComplexTy = mlir::dyn_cast<cir::ComplexType>(resType);
534 if (!resComplexTy)
535 return emitOpError() << "requires !cir.complex type for result";
536 if (srcType != resComplexTy.getElementType())
537 return emitOpError() << "requires source type match result element type";
538 return success();
539 }
540 case cir::CastKind::float_complex_to_real: {
541 auto srcComplexTy = mlir::dyn_cast<cir::ComplexType>(srcType);
542 if (!srcComplexTy)
543 return emitOpError() << "requires !cir.complex type for source";
544 if (!mlir::isa<cir::FPTypeInterface>(resType))
545 return emitOpError() << "requires !cir.float type for result";
546 if (srcComplexTy.getElementType() != resType)
547 return emitOpError() << "requires source element type match result type";
548 return success();
549 }
550 case cir::CastKind::int_complex_to_real: {
551 auto srcComplexTy = mlir::dyn_cast<cir::ComplexType>(srcType);
552 if (!srcComplexTy)
553 return emitOpError() << "requires !cir.complex type for source";
554 if (!mlir::isa<cir::IntType>(resType))
555 return emitOpError() << "requires !cir.int type for result";
556 if (srcComplexTy.getElementType() != resType)
557 return emitOpError() << "requires source element type match result type";
558 return success();
559 }
560 case cir::CastKind::float_complex_to_bool: {
561 auto srcComplexTy = mlir::dyn_cast<cir::ComplexType>(srcType);
562 if (!srcComplexTy || !srcComplexTy.isFloatingPointComplex())
563 return emitOpError()
564 << "requires floating point !cir.complex type for source";
565 if (!mlir::isa<cir::BoolType>(resType))
566 return emitOpError() << "requires !cir.bool type for result";
567 return success();
568 }
569 case cir::CastKind::int_complex_to_bool: {
570 auto srcComplexTy = mlir::dyn_cast<cir::ComplexType>(srcType);
571 if (!srcComplexTy || !srcComplexTy.isIntegerComplex())
572 return emitOpError()
573 << "requires floating point !cir.complex type for source";
574 if (!mlir::isa<cir::BoolType>(resType))
575 return emitOpError() << "requires !cir.bool type for result";
576 return success();
577 }
578 case cir::CastKind::float_complex: {
579 auto srcComplexTy = mlir::dyn_cast<cir::ComplexType>(srcType);
580 if (!srcComplexTy || !srcComplexTy.isFloatingPointComplex())
581 return emitOpError()
582 << "requires floating point !cir.complex type for source";
583 auto resComplexTy = mlir::dyn_cast<cir::ComplexType>(resType);
584 if (!resComplexTy || !resComplexTy.isFloatingPointComplex())
585 return emitOpError()
586 << "requires floating point !cir.complex type for result";
587 return success();
588 }
589 case cir::CastKind::float_complex_to_int_complex: {
590 auto srcComplexTy = mlir::dyn_cast<cir::ComplexType>(srcType);
591 if (!srcComplexTy || !srcComplexTy.isFloatingPointComplex())
592 return emitOpError()
593 << "requires floating point !cir.complex type for source";
594 auto resComplexTy = mlir::dyn_cast<cir::ComplexType>(resType);
595 if (!resComplexTy || !resComplexTy.isIntegerComplex())
596 return emitOpError() << "requires integer !cir.complex type for result";
597 return success();
598 }
599 case cir::CastKind::int_complex: {
600 auto srcComplexTy = mlir::dyn_cast<cir::ComplexType>(srcType);
601 if (!srcComplexTy || !srcComplexTy.isIntegerComplex())
602 return emitOpError() << "requires integer !cir.complex type for source";
603 auto resComplexTy = mlir::dyn_cast<cir::ComplexType>(resType);
604 if (!resComplexTy || !resComplexTy.isIntegerComplex())
605 return emitOpError() << "requires integer !cir.complex type for result";
606 return success();
607 }
608 case cir::CastKind::int_complex_to_float_complex: {
609 auto srcComplexTy = mlir::dyn_cast<cir::ComplexType>(srcType);
610 if (!srcComplexTy || !srcComplexTy.isIntegerComplex())
611 return emitOpError() << "requires integer !cir.complex type for source";
612 auto resComplexTy = mlir::dyn_cast<cir::ComplexType>(resType);
613 if (!resComplexTy || !resComplexTy.isFloatingPointComplex())
614 return emitOpError()
615 << "requires floating point !cir.complex type for result";
616 return success();
617 }
618 default:
619 llvm_unreachable("Unknown CastOp kind?");
620 }
621}
622
623static bool isIntOrBoolCast(cir::CastOp op) {
624 auto kind = op.getKind();
625 return kind == cir::CastKind::bool_to_int ||
626 kind == cir::CastKind::int_to_bool || kind == cir::CastKind::integral;
627}
628
629static Value tryFoldCastChain(cir::CastOp op) {
630 cir::CastOp head = op, tail = op;
631
632 while (op) {
633 if (!isIntOrBoolCast(op))
634 break;
635 head = op;
636 op = head.getSrc().getDefiningOp<cir::CastOp>();
637 }
638
639 if (head == tail)
640 return {};
641
642 // if bool_to_int -> ... -> int_to_bool: take the bool
643 // as we had it was before all casts
644 if (head.getKind() == cir::CastKind::bool_to_int &&
645 tail.getKind() == cir::CastKind::int_to_bool)
646 return head.getSrc();
647
648 // if int_to_bool -> ... -> int_to_bool: take the result
649 // of the first one, as no other casts (and ext casts as well)
650 // don't change the first result
651 if (head.getKind() == cir::CastKind::int_to_bool &&
652 tail.getKind() == cir::CastKind::int_to_bool)
653 return head.getResult();
654
655 return {};
656}
657
658OpFoldResult cir::CastOp::fold(FoldAdaptor adaptor) {
659 if (mlir::isa_and_present<cir::PoisonAttr>(adaptor.getSrc())) {
660 // Propagate poison value
661 return cir::PoisonAttr::get(getContext(), getType());
662 }
663
664 if (getSrc().getType() == getType()) {
665 switch (getKind()) {
666 case cir::CastKind::integral: {
667 // TODO: for sign differences, it's possible in certain conditions to
668 // create a new attribute that's capable of representing the source.
670 auto foldOrder = getSrc().getDefiningOp()->fold(foldResults);
671 if (foldOrder.succeeded() && mlir::isa<mlir::Attribute>(foldResults[0]))
672 return mlir::cast<mlir::Attribute>(foldResults[0]);
673 return {};
674 }
675 case cir::CastKind::bitcast:
676 case cir::CastKind::address_space:
677 case cir::CastKind::float_complex:
678 case cir::CastKind::int_complex: {
679 return getSrc();
680 }
681 default:
682 return {};
683 }
684 }
685 return tryFoldCastChain(*this);
686}
687
688//===----------------------------------------------------------------------===//
689// CallOp
690//===----------------------------------------------------------------------===//
691
692mlir::OperandRange cir::CallOp::getArgOperands() {
693 if (isIndirect())
694 return getArgs().drop_front(1);
695 return getArgs();
696}
697
698mlir::MutableOperandRange cir::CallOp::getArgOperandsMutable() {
699 mlir::MutableOperandRange args = getArgsMutable();
700 if (isIndirect())
701 return args.slice(1, args.size() - 1);
702 return args;
703}
704
705mlir::Value cir::CallOp::getIndirectCall() {
706 assert(isIndirect());
707 return getOperand(0);
708}
709
710/// Return the operand at index 'i'.
711Value cir::CallOp::getArgOperand(unsigned i) {
712 if (isIndirect())
713 ++i;
714 return getOperand(i);
715}
716
717/// Return the number of operands.
718unsigned cir::CallOp::getNumArgOperands() {
719 if (isIndirect())
720 return this->getOperation()->getNumOperands() - 1;
721 return this->getOperation()->getNumOperands();
722}
723
724static mlir::ParseResult parseCallCommon(mlir::OpAsmParser &parser,
725 mlir::OperationState &result) {
727 llvm::SMLoc opsLoc;
728 mlir::FlatSymbolRefAttr calleeAttr;
729 llvm::ArrayRef<mlir::Type> allResultTypes;
730
731 // If we cannot parse a string callee, it means this is an indirect call.
732 if (!parser
733 .parseOptionalAttribute(calleeAttr, CIRDialect::getCalleeAttrName(),
734 result.attributes)
735 .has_value()) {
736 OpAsmParser::UnresolvedOperand indirectVal;
737 // Do not resolve right now, since we need to figure out the type
738 if (parser.parseOperand(indirectVal).failed())
739 return failure();
740 ops.push_back(indirectVal);
741 }
742
743 if (parser.parseLParen())
744 return mlir::failure();
745
746 opsLoc = parser.getCurrentLocation();
747 if (parser.parseOperandList(ops))
748 return mlir::failure();
749 if (parser.parseRParen())
750 return mlir::failure();
751
752 if (parser.parseOptionalKeyword("nothrow").succeeded())
753 result.addAttribute(CIRDialect::getNoThrowAttrName(),
754 mlir::UnitAttr::get(parser.getContext()));
755
756 if (parser.parseOptionalKeyword("side_effect").succeeded()) {
757 if (parser.parseLParen().failed())
758 return failure();
759 cir::SideEffect sideEffect;
760 if (parseCIRKeyword<cir::SideEffect>(parser, sideEffect).failed())
761 return failure();
762 if (parser.parseRParen().failed())
763 return failure();
764 auto attr = cir::SideEffectAttr::get(parser.getContext(), sideEffect);
765 result.addAttribute(CIRDialect::getSideEffectAttrName(), attr);
766 }
767
768 if (parser.parseOptionalAttrDict(result.attributes))
769 return ::mlir::failure();
770
771 if (parser.parseColon())
772 return ::mlir::failure();
773
774 mlir::FunctionType opsFnTy;
775 if (parser.parseType(opsFnTy))
776 return mlir::failure();
777
778 allResultTypes = opsFnTy.getResults();
779 result.addTypes(allResultTypes);
780
781 if (parser.resolveOperands(ops, opsFnTy.getInputs(), opsLoc, result.operands))
782 return mlir::failure();
783
784 return mlir::success();
785}
786
787static void printCallCommon(mlir::Operation *op,
788 mlir::FlatSymbolRefAttr calleeSym,
789 mlir::Value indirectCallee,
790 mlir::OpAsmPrinter &printer, bool isNothrow,
791 cir::SideEffect sideEffect) {
792 printer << ' ';
793
794 auto callLikeOp = mlir::cast<cir::CIRCallOpInterface>(op);
795 auto ops = callLikeOp.getArgOperands();
796
797 if (calleeSym) {
798 // Direct calls
799 printer.printAttributeWithoutType(calleeSym);
800 } else {
801 // Indirect calls
802 assert(indirectCallee);
803 printer << indirectCallee;
804 }
805 printer << "(" << ops << ")";
806
807 if (isNothrow)
808 printer << " nothrow";
809
810 if (sideEffect != cir::SideEffect::All) {
811 printer << " side_effect(";
812 printer << stringifySideEffect(sideEffect);
813 printer << ")";
814 }
815
816 printer.printOptionalAttrDict(op->getAttrs(),
817 {CIRDialect::getCalleeAttrName(),
818 CIRDialect::getNoThrowAttrName(),
819 CIRDialect::getSideEffectAttrName()});
820
821 printer << " : ";
822 printer.printFunctionalType(op->getOperands().getTypes(),
823 op->getResultTypes());
824}
825
826mlir::ParseResult cir::CallOp::parse(mlir::OpAsmParser &parser,
827 mlir::OperationState &result) {
828 return parseCallCommon(parser, result);
829}
830
831void cir::CallOp::print(mlir::OpAsmPrinter &p) {
832 mlir::Value indirectCallee = isIndirect() ? getIndirectCall() : nullptr;
833 cir::SideEffect sideEffect = getSideEffect();
834 printCallCommon(*this, getCalleeAttr(), indirectCallee, p, getNothrow(),
835 sideEffect);
836}
837
838static LogicalResult
839verifyCallCommInSymbolUses(mlir::Operation *op,
840 SymbolTableCollection &symbolTable) {
841 auto fnAttr =
842 op->getAttrOfType<FlatSymbolRefAttr>(CIRDialect::getCalleeAttrName());
843 if (!fnAttr) {
844 // This is an indirect call, thus we don't have to check the symbol uses.
845 return mlir::success();
846 }
847
848 auto fn = symbolTable.lookupNearestSymbolFrom<cir::FuncOp>(op, fnAttr);
849 if (!fn)
850 return op->emitOpError() << "'" << fnAttr.getValue()
851 << "' does not reference a valid function";
852
853 auto callIf = dyn_cast<cir::CIRCallOpInterface>(op);
854 assert(callIf && "expected CIR call interface to be always available");
855
856 // Verify that the operand and result types match the callee. Note that
857 // argument-checking is disabled for functions without a prototype.
858 auto fnType = fn.getFunctionType();
859 if (!fn.getNoProto()) {
860 unsigned numCallOperands = callIf.getNumArgOperands();
861 unsigned numFnOpOperands = fnType.getNumInputs();
862
863 if (!fnType.isVarArg() && numCallOperands != numFnOpOperands)
864 return op->emitOpError("incorrect number of operands for callee");
865 if (fnType.isVarArg() && numCallOperands < numFnOpOperands)
866 return op->emitOpError("too few operands for callee");
867
868 for (unsigned i = 0, e = numFnOpOperands; i != e; ++i)
869 if (callIf.getArgOperand(i).getType() != fnType.getInput(i))
870 return op->emitOpError("operand type mismatch: expected operand type ")
871 << fnType.getInput(i) << ", but provided "
872 << op->getOperand(i).getType() << " for operand number " << i;
873 }
874
876
877 // Void function must not return any results.
878 if (fnType.hasVoidReturn() && op->getNumResults() != 0)
879 return op->emitOpError("callee returns void but call has results");
880
881 // Non-void function calls must return exactly one result.
882 if (!fnType.hasVoidReturn() && op->getNumResults() != 1)
883 return op->emitOpError("incorrect number of results for callee");
884
885 // Parent function and return value types must match.
886 if (!fnType.hasVoidReturn() &&
887 op->getResultTypes().front() != fnType.getReturnType()) {
888 return op->emitOpError("result type mismatch: expected ")
889 << fnType.getReturnType() << ", but provided "
890 << op->getResult(0).getType();
891 }
892
893 return mlir::success();
894}
895
896LogicalResult
897cir::CallOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
898 return verifyCallCommInSymbolUses(*this, symbolTable);
899}
900
901//===----------------------------------------------------------------------===//
902// ReturnOp
903//===----------------------------------------------------------------------===//
904
905static mlir::LogicalResult checkReturnAndFunction(cir::ReturnOp op,
906 cir::FuncOp function) {
907 // ReturnOps currently only have a single optional operand.
908 if (op.getNumOperands() > 1)
909 return op.emitOpError() << "expects at most 1 return operand";
910
911 // Ensure returned type matches the function signature.
912 auto expectedTy = function.getFunctionType().getReturnType();
913 auto actualTy =
914 (op.getNumOperands() == 0 ? cir::VoidType::get(op.getContext())
915 : op.getOperand(0).getType());
916 if (actualTy != expectedTy)
917 return op.emitOpError() << "returns " << actualTy
918 << " but enclosing function returns " << expectedTy;
919
920 return mlir::success();
921}
922
923mlir::LogicalResult cir::ReturnOp::verify() {
924 // Returns can be present in multiple different scopes, get the
925 // wrapping function and start from there.
926 auto *fnOp = getOperation()->getParentOp();
927 while (!isa<cir::FuncOp>(fnOp))
928 fnOp = fnOp->getParentOp();
929
930 // Make sure return types match function return type.
931 if (checkReturnAndFunction(*this, cast<cir::FuncOp>(fnOp)).failed())
932 return failure();
933
934 return success();
935}
936
937//===----------------------------------------------------------------------===//
938// IfOp
939//===----------------------------------------------------------------------===//
940
941ParseResult cir::IfOp::parse(OpAsmParser &parser, OperationState &result) {
942 // create the regions for 'then'.
943 result.regions.reserve(2);
944 Region *thenRegion = result.addRegion();
945 Region *elseRegion = result.addRegion();
946
947 mlir::Builder &builder = parser.getBuilder();
948 OpAsmParser::UnresolvedOperand cond;
949 Type boolType = cir::BoolType::get(builder.getContext());
950
951 if (parser.parseOperand(cond) ||
952 parser.resolveOperand(cond, boolType, result.operands))
953 return failure();
954
955 // Parse 'then' region.
956 mlir::SMLoc parseThenLoc = parser.getCurrentLocation();
957 if (parser.parseRegion(*thenRegion, /*arguments=*/{}, /*argTypes=*/{}))
958 return failure();
959
960 if (ensureRegionTerm(parser, *thenRegion, parseThenLoc).failed())
961 return failure();
962
963 // If we find an 'else' keyword, parse the 'else' region.
964 if (!parser.parseOptionalKeyword("else")) {
965 mlir::SMLoc parseElseLoc = parser.getCurrentLocation();
966 if (parser.parseRegion(*elseRegion, /*arguments=*/{}, /*argTypes=*/{}))
967 return failure();
968 if (ensureRegionTerm(parser, *elseRegion, parseElseLoc).failed())
969 return failure();
970 }
971
972 // Parse the optional attribute list.
973 if (parser.parseOptionalAttrDict(result.attributes))
974 return failure();
975 return success();
976}
977
978void cir::IfOp::print(OpAsmPrinter &p) {
979 p << " " << getCondition() << " ";
980 mlir::Region &thenRegion = this->getThenRegion();
981 p.printRegion(thenRegion,
982 /*printEntryBlockArgs=*/false,
983 /*printBlockTerminators=*/!omitRegionTerm(thenRegion));
984
985 // Print the 'else' regions if it exists and has a block.
986 mlir::Region &elseRegion = this->getElseRegion();
987 if (!elseRegion.empty()) {
988 p << " else ";
989 p.printRegion(elseRegion,
990 /*printEntryBlockArgs=*/false,
991 /*printBlockTerminators=*/!omitRegionTerm(elseRegion));
992 }
993
994 p.printOptionalAttrDict(getOperation()->getAttrs());
995}
996
/// Default callback for IfOp builders.
///
/// Creates an operand-less cir.yield at the builder's current insertion
/// point, so a freshly created region block is properly terminated.
/// @param builder Builder positioned inside the block to terminate.
/// @param loc     Location assigned to the created cir.yield.
void cir::buildTerminatedBody(OpBuilder &builder, Location loc) {
  // add cir.yield to end of the block
  cir::YieldOp::create(builder, loc);
}
1002
1003/// Given the region at `index`, or the parent operation if `index` is None,
1004/// return the successor regions. These are the regions that may be selected
1005/// during the flow of control. `operands` is a set of optional attributes that
1006/// correspond to a constant value for each operand, or null if that operand is
1007/// not a constant.
1008void cir::IfOp::getSuccessorRegions(mlir::RegionBranchPoint point,
1009 SmallVectorImpl<RegionSuccessor> &regions) {
1010 // The `then` and the `else` region branch back to the parent operation.
1011 if (!point.isParent()) {
1012 regions.push_back(
1013 RegionSuccessor(getOperation(), getOperation()->getResults()));
1014 return;
1015 }
1016
1017 // Don't consider the else region if it is empty.
1018 Region *elseRegion = &this->getElseRegion();
1019 if (elseRegion->empty())
1020 elseRegion = nullptr;
1021
1022 // If the condition isn't constant, both regions may be executed.
1023 regions.push_back(RegionSuccessor(&getThenRegion()));
1024 // If the else region does not exist, it is not a viable successor.
1025 if (elseRegion)
1026 regions.push_back(RegionSuccessor(elseRegion));
1027
1028 return;
1029}
1030
1031void cir::IfOp::build(OpBuilder &builder, OperationState &result, Value cond,
1032 bool withElseRegion, BuilderCallbackRef thenBuilder,
1033 BuilderCallbackRef elseBuilder) {
1034 assert(thenBuilder && "the builder callback for 'then' must be present");
1035 result.addOperands(cond);
1036
1037 OpBuilder::InsertionGuard guard(builder);
1038 Region *thenRegion = result.addRegion();
1039 builder.createBlock(thenRegion);
1040 thenBuilder(builder, result.location);
1041
1042 Region *elseRegion = result.addRegion();
1043 if (!withElseRegion)
1044 return;
1045
1046 builder.createBlock(elseRegion);
1047 elseBuilder(builder, result.location);
1048}
1049
1050//===----------------------------------------------------------------------===//
1051// ScopeOp
1052//===----------------------------------------------------------------------===//
1053
1054/// Given the region at `index`, or the parent operation if `index` is None,
1055/// return the successor regions. These are the regions that may be selected
1056/// during the flow of control. `operands` is a set of optional attributes
1057/// that correspond to a constant value for each operand, or null if that
1058/// operand is not a constant.
1059void cir::ScopeOp::getSuccessorRegions(
1060 mlir::RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
1061 // The only region always branch back to the parent operation.
1062 if (!point.isParent()) {
1063 regions.push_back(RegionSuccessor(getOperation(), getODSResults(0)));
1064 return;
1065 }
1066
1067 // If the condition isn't constant, both regions may be executed.
1068 regions.push_back(RegionSuccessor(&getScopeRegion()));
1069}
1070
1071void cir::ScopeOp::build(
1072 OpBuilder &builder, OperationState &result,
1073 function_ref<void(OpBuilder &, Type &, Location)> scopeBuilder) {
1074 assert(scopeBuilder && "the builder callback for 'then' must be present");
1075
1076 OpBuilder::InsertionGuard guard(builder);
1077 Region *scopeRegion = result.addRegion();
1078 builder.createBlock(scopeRegion);
1080
1081 mlir::Type yieldTy;
1082 scopeBuilder(builder, yieldTy, result.location);
1083
1084 if (yieldTy)
1085 result.addTypes(TypeRange{yieldTy});
1086}
1087
1088void cir::ScopeOp::build(
1089 OpBuilder &builder, OperationState &result,
1090 function_ref<void(OpBuilder &, Location)> scopeBuilder) {
1091 assert(scopeBuilder && "the builder callback for 'then' must be present");
1092 OpBuilder::InsertionGuard guard(builder);
1093 Region *scopeRegion = result.addRegion();
1094 builder.createBlock(scopeRegion);
1096 scopeBuilder(builder, result.location);
1097}
1098
1099LogicalResult cir::ScopeOp::verify() {
1100 if (getRegion().empty()) {
1101 return emitOpError() << "cir.scope must not be empty since it should "
1102 "include at least an implicit cir.yield ";
1103 }
1104
1105 mlir::Block &lastBlock = getRegion().back();
1106 if (lastBlock.empty() || !lastBlock.mightHaveTerminator() ||
1107 !lastBlock.getTerminator()->hasTrait<OpTrait::IsTerminator>())
1108 return emitOpError() << "last block of cir.scope must be terminated";
1109 return success();
1110}
1111
1112//===----------------------------------------------------------------------===//
1113// BrOp
1114//===----------------------------------------------------------------------===//
1115
1116mlir::SuccessorOperands cir::BrOp::getSuccessorOperands(unsigned index) {
1117 assert(index == 0 && "invalid successor index");
1118 return mlir::SuccessorOperands(getDestOperandsMutable());
1119}
1120
1121Block *cir::BrOp::getSuccessorForOperands(ArrayRef<Attribute>) {
1122 return getDest();
1123}
1124
1125//===----------------------------------------------------------------------===//
1126// BrCondOp
1127//===----------------------------------------------------------------------===//
1128
1129mlir::SuccessorOperands cir::BrCondOp::getSuccessorOperands(unsigned index) {
1130 assert(index < getNumSuccessors() && "invalid successor index");
1131 return SuccessorOperands(index == 0 ? getDestOperandsTrueMutable()
1132 : getDestOperandsFalseMutable());
1133}
1134
1135Block *cir::BrCondOp::getSuccessorForOperands(ArrayRef<Attribute> operands) {
1136 if (IntegerAttr condAttr = dyn_cast_if_present<IntegerAttr>(operands.front()))
1137 return condAttr.getValue().isOne() ? getDestTrue() : getDestFalse();
1138 return nullptr;
1139}
1140
1141//===----------------------------------------------------------------------===//
1142// CaseOp
1143//===----------------------------------------------------------------------===//
1144
1145void cir::CaseOp::getSuccessorRegions(
1146 mlir::RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
1147 if (!point.isParent()) {
1148 regions.push_back(
1149 RegionSuccessor(getOperation(), getOperation()->getResults()));
1150 return;
1151 }
1152 regions.push_back(RegionSuccessor(&getCaseRegion()));
1153}
1154
1155void cir::CaseOp::build(OpBuilder &builder, OperationState &result,
1156 ArrayAttr value, CaseOpKind kind,
1157 OpBuilder::InsertPoint &insertPoint) {
1158 OpBuilder::InsertionGuard guardSwitch(builder);
1159 result.addAttribute("value", value);
1160 result.getOrAddProperties<Properties>().kind =
1161 cir::CaseOpKindAttr::get(builder.getContext(), kind);
1162 Region *caseRegion = result.addRegion();
1163 builder.createBlock(caseRegion);
1164
1165 insertPoint = builder.saveInsertionPoint();
1166}
1167
1168//===----------------------------------------------------------------------===//
1169// SwitchOp
1170//===----------------------------------------------------------------------===//
1171
1172static ParseResult parseSwitchOp(OpAsmParser &parser, mlir::Region &regions,
1173 mlir::OpAsmParser::UnresolvedOperand &cond,
1174 mlir::Type &condType) {
1175 cir::IntType intCondType;
1176
1177 if (parser.parseLParen())
1178 return mlir::failure();
1179
1180 if (parser.parseOperand(cond))
1181 return mlir::failure();
1182 if (parser.parseColon())
1183 return mlir::failure();
1184 if (parser.parseCustomTypeWithFallback(intCondType))
1185 return mlir::failure();
1186 condType = intCondType;
1187
1188 if (parser.parseRParen())
1189 return mlir::failure();
1190 if (parser.parseRegion(regions, /*arguments=*/{}, /*argTypes=*/{}))
1191 return failure();
1192
1193 return mlir::success();
1194}
1195
1196static void printSwitchOp(OpAsmPrinter &p, cir::SwitchOp op,
1197 mlir::Region &bodyRegion, mlir::Value condition,
1198 mlir::Type condType) {
1199 p << "(";
1200 p << condition;
1201 p << " : ";
1202 p.printStrippedAttrOrType(condType);
1203 p << ")";
1204
1205 p << ' ';
1206 p.printRegion(bodyRegion, /*printEntryBlockArgs=*/false,
1207 /*printBlockTerminators=*/true);
1208}
1209
1210void cir::SwitchOp::getSuccessorRegions(
1211 mlir::RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &region) {
1212 if (!point.isParent()) {
1213 region.push_back(
1214 RegionSuccessor(getOperation(), getOperation()->getResults()));
1215 return;
1216 }
1217
1218 region.push_back(RegionSuccessor(&getBody()));
1219}
1220
1221void cir::SwitchOp::build(OpBuilder &builder, OperationState &result,
1222 Value cond, BuilderOpStateCallbackRef switchBuilder) {
1223 assert(switchBuilder && "the builder callback for regions must be present");
1224 OpBuilder::InsertionGuard guardSwitch(builder);
1225 Region *switchRegion = result.addRegion();
1226 builder.createBlock(switchRegion);
1227 result.addOperands({cond});
1228 switchBuilder(builder, result.location, result);
1229}
1230
1231void cir::SwitchOp::collectCases(llvm::SmallVectorImpl<CaseOp> &cases) {
1232 walk<mlir::WalkOrder::PreOrder>([&](mlir::Operation *op) {
1233 // Don't walk in nested switch op.
1234 if (isa<cir::SwitchOp>(op) && op != *this)
1235 return WalkResult::skip();
1236
1237 if (auto caseOp = dyn_cast<cir::CaseOp>(op))
1238 cases.push_back(caseOp);
1239
1240 return WalkResult::advance();
1241 });
1242}
1243
1244bool cir::SwitchOp::isSimpleForm(llvm::SmallVectorImpl<CaseOp> &cases) {
1245 collectCases(cases);
1246
1247 if (getBody().empty())
1248 return false;
1249
1250 if (!isa<YieldOp>(getBody().front().back()))
1251 return false;
1252
1253 if (!llvm::all_of(getBody().front(),
1254 [](Operation &op) { return isa<CaseOp, YieldOp>(op); }))
1255 return false;
1256
1257 return llvm::all_of(cases, [this](CaseOp op) {
1258 return op->getParentOfType<SwitchOp>() == *this;
1259 });
1260}
1261
1262//===----------------------------------------------------------------------===//
1263// SwitchFlatOp
1264//===----------------------------------------------------------------------===//
1265
1266void cir::SwitchFlatOp::build(OpBuilder &builder, OperationState &result,
1267 Value value, Block *defaultDestination,
1268 ValueRange defaultOperands,
1269 ArrayRef<APInt> caseValues,
1270 BlockRange caseDestinations,
1271 ArrayRef<ValueRange> caseOperands) {
1272
1273 std::vector<mlir::Attribute> caseValuesAttrs;
1274 for (const APInt &val : caseValues)
1275 caseValuesAttrs.push_back(cir::IntAttr::get(value.getType(), val));
1276 mlir::ArrayAttr attrs = ArrayAttr::get(builder.getContext(), caseValuesAttrs);
1277
1278 build(builder, result, value, defaultOperands, caseOperands, attrs,
1279 defaultDestination, caseDestinations);
1280}
1281
1282/// <cases> ::= `[` (case (`,` case )* )? `]`
1283/// <case> ::= integer `:` bb-id (`(` ssa-use-and-type-list `)`)?
1284static ParseResult parseSwitchFlatOpCases(
1285 OpAsmParser &parser, Type flagType, mlir::ArrayAttr &caseValues,
1286 SmallVectorImpl<Block *> &caseDestinations,
1288 &caseOperands,
1289 SmallVectorImpl<llvm::SmallVector<Type>> &caseOperandTypes) {
1290 if (failed(parser.parseLSquare()))
1291 return failure();
1292 if (succeeded(parser.parseOptionalRSquare()))
1293 return success();
1295
1296 auto parseCase = [&]() {
1297 int64_t value = 0;
1298 if (failed(parser.parseInteger(value)))
1299 return failure();
1300
1301 values.push_back(cir::IntAttr::get(flagType, value));
1302
1303 Block *destination;
1305 llvm::SmallVector<Type> operandTypes;
1306 if (parser.parseColon() || parser.parseSuccessor(destination))
1307 return failure();
1308 if (!parser.parseOptionalLParen()) {
1309 if (parser.parseOperandList(operands, OpAsmParser::Delimiter::None,
1310 /*allowResultNumber=*/false) ||
1311 parser.parseColonTypeList(operandTypes) || parser.parseRParen())
1312 return failure();
1313 }
1314 caseDestinations.push_back(destination);
1315 caseOperands.emplace_back(operands);
1316 caseOperandTypes.emplace_back(operandTypes);
1317 return success();
1318 };
1319 if (failed(parser.parseCommaSeparatedList(parseCase)))
1320 return failure();
1321
1322 caseValues = ArrayAttr::get(flagType.getContext(), values);
1323
1324 return parser.parseRSquare();
1325}
1326
1327static void printSwitchFlatOpCases(OpAsmPrinter &p, cir::SwitchFlatOp op,
1328 Type flagType, mlir::ArrayAttr caseValues,
1329 SuccessorRange caseDestinations,
1330 OperandRangeRange caseOperands,
1331 const TypeRangeRange &caseOperandTypes) {
1332 p << '[';
1333 p.printNewline();
1334 if (!caseValues) {
1335 p << ']';
1336 return;
1337 }
1338
1339 size_t index = 0;
1340 llvm::interleave(
1341 llvm::zip(caseValues, caseDestinations),
1342 [&](auto i) {
1343 p << " ";
1344 mlir::Attribute a = std::get<0>(i);
1345 p << mlir::cast<cir::IntAttr>(a).getValue();
1346 p << ": ";
1347 p.printSuccessorAndUseList(std::get<1>(i), caseOperands[index++]);
1348 },
1349 [&] {
1350 p << ',';
1351 p.printNewline();
1352 });
1353 p.printNewline();
1354 p << ']';
1355}
1356
1357//===----------------------------------------------------------------------===//
1358// GlobalOp
1359//===----------------------------------------------------------------------===//
1360
1361static ParseResult parseConstantValue(OpAsmParser &parser,
1362 mlir::Attribute &valueAttr) {
1363 NamedAttrList attr;
1364 return parser.parseAttribute(valueAttr, "value", attr);
1365}
1366
1367static void printConstant(OpAsmPrinter &p, Attribute value) {
1368 p.printAttribute(value);
1369}
1370
1371mlir::LogicalResult cir::GlobalOp::verify() {
1372 // Verify that the initial value, if present, is either a unit attribute or
1373 // an attribute CIR supports.
1374 if (getInitialValue().has_value()) {
1375 if (checkConstantTypes(getOperation(), getSymType(), *getInitialValue())
1376 .failed())
1377 return failure();
1378 }
1379
1380 // TODO(CIR): Many other checks for properties that haven't been upstreamed
1381 // yet.
1382
1383 return success();
1384}
1385
1386void cir::GlobalOp::build(
1387 OpBuilder &odsBuilder, OperationState &odsState, llvm::StringRef sym_name,
1388 mlir::Type sym_type, bool isConstant, cir::GlobalLinkageKind linkage,
1389 function_ref<void(OpBuilder &, Location)> ctorBuilder,
1390 function_ref<void(OpBuilder &, Location)> dtorBuilder) {
1391 odsState.addAttribute(getSymNameAttrName(odsState.name),
1392 odsBuilder.getStringAttr(sym_name));
1393 odsState.addAttribute(getSymTypeAttrName(odsState.name),
1394 mlir::TypeAttr::get(sym_type));
1395 if (isConstant)
1396 odsState.addAttribute(getConstantAttrName(odsState.name),
1397 odsBuilder.getUnitAttr());
1398
1399 cir::GlobalLinkageKindAttr linkageAttr =
1400 cir::GlobalLinkageKindAttr::get(odsBuilder.getContext(), linkage);
1401 odsState.addAttribute(getLinkageAttrName(odsState.name), linkageAttr);
1402
1403 Region *ctorRegion = odsState.addRegion();
1404 if (ctorBuilder) {
1405 odsBuilder.createBlock(ctorRegion);
1406 ctorBuilder(odsBuilder, odsState.location);
1407 }
1408
1409 Region *dtorRegion = odsState.addRegion();
1410 if (dtorBuilder) {
1411 odsBuilder.createBlock(dtorRegion);
1412 dtorBuilder(odsBuilder, odsState.location);
1413 }
1414
1415 odsState.addAttribute(getGlobalVisibilityAttrName(odsState.name),
1416 cir::VisibilityAttr::get(odsBuilder.getContext()));
1417}
1418
1419/// Given the region at `index`, or the parent operation if `index` is None,
1420/// return the successor regions. These are the regions that may be selected
1421/// during the flow of control. `operands` is a set of optional attributes that
1422/// correspond to a constant value for each operand, or null if that operand is
1423/// not a constant.
1424void cir::GlobalOp::getSuccessorRegions(
1425 mlir::RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
1426 // The `ctor` and `dtor` regions always branch back to the parent operation.
1427 if (!point.isParent()) {
1428 regions.push_back(
1429 RegionSuccessor(getOperation(), getOperation()->getResults()));
1430 return;
1431 }
1432
1433 // Don't consider the ctor region if it is empty.
1434 Region *ctorRegion = &this->getCtorRegion();
1435 if (ctorRegion->empty())
1436 ctorRegion = nullptr;
1437
1438 // Don't consider the dtor region if it is empty.
1439 Region *dtorRegion = &this->getCtorRegion();
1440 if (dtorRegion->empty())
1441 dtorRegion = nullptr;
1442
1443 // If the condition isn't constant, both regions may be executed.
1444 if (ctorRegion)
1445 regions.push_back(RegionSuccessor(ctorRegion));
1446 if (dtorRegion)
1447 regions.push_back(RegionSuccessor(dtorRegion));
1448}
1449
1450static void printGlobalOpTypeAndInitialValue(OpAsmPrinter &p, cir::GlobalOp op,
1451 TypeAttr type, Attribute initAttr,
1452 mlir::Region &ctorRegion,
1453 mlir::Region &dtorRegion) {
1454 auto printType = [&]() { p << ": " << type; };
1455 if (!op.isDeclaration()) {
1456 p << "= ";
1457 if (!ctorRegion.empty()) {
1458 p << "ctor ";
1459 printType();
1460 p << " ";
1461 p.printRegion(ctorRegion,
1462 /*printEntryBlockArgs=*/false,
1463 /*printBlockTerminators=*/false);
1464 } else {
1465 // This also prints the type...
1466 if (initAttr)
1467 printConstant(p, initAttr);
1468 }
1469
1470 if (!dtorRegion.empty()) {
1471 p << " dtor ";
1472 p.printRegion(dtorRegion,
1473 /*printEntryBlockArgs=*/false,
1474 /*printBlockTerminators=*/false);
1475 }
1476 } else {
1477 printType();
1478 }
1479}
1480
1481static ParseResult parseGlobalOpTypeAndInitialValue(OpAsmParser &parser,
1482 TypeAttr &typeAttr,
1483 Attribute &initialValueAttr,
1484 mlir::Region &ctorRegion,
1485 mlir::Region &dtorRegion) {
1486 mlir::Type opTy;
1487 if (parser.parseOptionalEqual().failed()) {
1488 // Absence of equal means a declaration, so we need to parse the type.
1489 // cir.global @a : !cir.int<s, 32>
1490 if (parser.parseColonType(opTy))
1491 return failure();
1492 } else {
1493 // Parse contructor, example:
1494 // cir.global @rgb = ctor : type { ... }
1495 if (!parser.parseOptionalKeyword("ctor")) {
1496 if (parser.parseColonType(opTy))
1497 return failure();
1498 auto parseLoc = parser.getCurrentLocation();
1499 if (parser.parseRegion(ctorRegion, /*arguments=*/{}, /*argTypes=*/{}))
1500 return failure();
1501 if (ensureRegionTerm(parser, ctorRegion, parseLoc).failed())
1502 return failure();
1503 } else {
1504 // Parse constant with initializer, examples:
1505 // cir.global @y = 3.400000e+00 : f32
1506 // cir.global @rgb = #cir.const_array<[...] : !cir.array<i8 x 3>>
1507 if (parseConstantValue(parser, initialValueAttr).failed())
1508 return failure();
1509
1510 assert(mlir::isa<mlir::TypedAttr>(initialValueAttr) &&
1511 "Non-typed attrs shouldn't appear here.");
1512 auto typedAttr = mlir::cast<mlir::TypedAttr>(initialValueAttr);
1513 opTy = typedAttr.getType();
1514 }
1515
1516 // Parse destructor, example:
1517 // dtor { ... }
1518 if (!parser.parseOptionalKeyword("dtor")) {
1519 auto parseLoc = parser.getCurrentLocation();
1520 if (parser.parseRegion(dtorRegion, /*arguments=*/{}, /*argTypes=*/{}))
1521 return failure();
1522 if (ensureRegionTerm(parser, dtorRegion, parseLoc).failed())
1523 return failure();
1524 }
1525 }
1526
1527 typeAttr = TypeAttr::get(opTy);
1528 return success();
1529}
1530
1531//===----------------------------------------------------------------------===//
1532// GetGlobalOp
1533//===----------------------------------------------------------------------===//
1534
1535LogicalResult
1536cir::GetGlobalOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1537 // Verify that the result type underlying pointer type matches the type of
1538 // the referenced cir.global or cir.func op.
1539 mlir::Operation *op =
1540 symbolTable.lookupNearestSymbolFrom(*this, getNameAttr());
1541 if (op == nullptr || !(isa<GlobalOp>(op) || isa<FuncOp>(op)))
1542 return emitOpError("'")
1543 << getName()
1544 << "' does not reference a valid cir.global or cir.func";
1545
1546 mlir::Type symTy;
1547 if (auto g = dyn_cast<GlobalOp>(op)) {
1548 symTy = g.getSymType();
1551 } else if (auto f = dyn_cast<FuncOp>(op)) {
1552 symTy = f.getFunctionType();
1553 } else {
1554 llvm_unreachable("Unexpected operation for GetGlobalOp");
1555 }
1556
1557 auto resultType = dyn_cast<PointerType>(getAddr().getType());
1558 if (!resultType || symTy != resultType.getPointee())
1559 return emitOpError("result type pointee type '")
1560 << resultType.getPointee() << "' does not match type " << symTy
1561 << " of the global @" << getName();
1562
1563 return success();
1564}
1565
1566//===----------------------------------------------------------------------===//
1567// VTableAddrPointOp
1568//===----------------------------------------------------------------------===//
1569
1570LogicalResult
1571cir::VTableAddrPointOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1572 StringRef name = getName();
1573
1574 // Verify that the result type underlying pointer type matches the type of
1575 // the referenced cir.global.
1576 auto op =
1577 symbolTable.lookupNearestSymbolFrom<cir::GlobalOp>(*this, getNameAttr());
1578 if (!op)
1579 return emitOpError("'")
1580 << name << "' does not reference a valid cir.global";
1581 std::optional<mlir::Attribute> init = op.getInitialValue();
1582 if (!init)
1583 return success();
1584 if (!isa<cir::VTableAttr>(*init))
1585 return emitOpError("Expected #cir.vtable in initializer for global '")
1586 << name << "'";
1587 return success();
1588}
1589
1590//===----------------------------------------------------------------------===//
1591// VTTAddrPointOp
1592//===----------------------------------------------------------------------===//
1593
1594LogicalResult
1595cir::VTTAddrPointOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1596 // VTT ptr is not coming from a symbol.
1597 if (!getName())
1598 return success();
1599 StringRef name = *getName();
1600
1601 // Verify that the result type underlying pointer type matches the type of
1602 // the referenced cir.global op.
1603 auto op =
1604 symbolTable.lookupNearestSymbolFrom<cir::GlobalOp>(*this, getNameAttr());
1605 if (!op)
1606 return emitOpError("'")
1607 << name << "' does not reference a valid cir.global";
1608 std::optional<mlir::Attribute> init = op.getInitialValue();
1609 if (!init)
1610 return success();
1611 if (!isa<cir::ConstArrayAttr>(*init))
1612 return emitOpError(
1613 "Expected constant array in initializer for global VTT '")
1614 << name << "'";
1615 return success();
1616}
1617
1618LogicalResult cir::VTTAddrPointOp::verify() {
1619 // The operation uses either a symbol or a value to operate, but not both
1620 if (getName() && getSymAddr())
1621 return emitOpError("should use either a symbol or value, but not both");
1622
1623 // If not a symbol, stick with the concrete type used for getSymAddr.
1624 if (getSymAddr())
1625 return success();
1626
1627 mlir::Type resultType = getAddr().getType();
1628 mlir::Type resTy = cir::PointerType::get(
1629 cir::PointerType::get(cir::VoidType::get(getContext())));
1630
1631 if (resultType != resTy)
1632 return emitOpError("result type must be ")
1633 << resTy << ", but provided result type is " << resultType;
1634 return success();
1635}
1636
1637//===----------------------------------------------------------------------===//
1638// FuncOp
1639//===----------------------------------------------------------------------===//
1640
/// Returns the name used for the linkage attribute. This *must* correspond to
/// the name of the attribute in ODS.
///
/// Used as the attribute name when FuncOp::parse stores the parsed (or
/// default external) linkage.
static llvm::StringRef getLinkageAttrNameString() { return "linkage"; }
1644
1645void cir::FuncOp::build(OpBuilder &builder, OperationState &result,
1646 StringRef name, FuncType type,
1647 GlobalLinkageKind linkage) {
1648 result.addRegion();
1649 result.addAttribute(SymbolTable::getSymbolAttrName(),
1650 builder.getStringAttr(name));
1651 result.addAttribute(getFunctionTypeAttrName(result.name),
1652 TypeAttr::get(type));
1653 result.addAttribute(
1655 GlobalLinkageKindAttr::get(builder.getContext(), linkage));
1656 result.addAttribute(getGlobalVisibilityAttrName(result.name),
1657 cir::VisibilityAttr::get(builder.getContext()));
1658}
1659
1660ParseResult cir::FuncOp::parse(OpAsmParser &parser, OperationState &state) {
1661 llvm::SMLoc loc = parser.getCurrentLocation();
1662 mlir::Builder &builder = parser.getBuilder();
1663
1664 mlir::StringAttr builtinNameAttr = getBuiltinAttrName(state.name);
1665 mlir::StringAttr coroutineNameAttr = getCoroutineAttrName(state.name);
1666 mlir::StringAttr lambdaNameAttr = getLambdaAttrName(state.name);
1667 mlir::StringAttr noProtoNameAttr = getNoProtoAttrName(state.name);
1668 mlir::StringAttr visNameAttr = getSymVisibilityAttrName(state.name);
1669 mlir::StringAttr visibilityNameAttr = getGlobalVisibilityAttrName(state.name);
1670 mlir::StringAttr dsoLocalNameAttr = getDsoLocalAttrName(state.name);
1671 mlir::StringAttr specialMemberAttr = getCxxSpecialMemberAttrName(state.name);
1672
1673 if (::mlir::succeeded(parser.parseOptionalKeyword(builtinNameAttr.strref())))
1674 state.addAttribute(builtinNameAttr, parser.getBuilder().getUnitAttr());
1675 if (::mlir::succeeded(
1676 parser.parseOptionalKeyword(coroutineNameAttr.strref())))
1677 state.addAttribute(coroutineNameAttr, parser.getBuilder().getUnitAttr());
1678 if (::mlir::succeeded(parser.parseOptionalKeyword(lambdaNameAttr.strref())))
1679 state.addAttribute(lambdaNameAttr, parser.getBuilder().getUnitAttr());
1680 if (parser.parseOptionalKeyword(noProtoNameAttr).succeeded())
1681 state.addAttribute(noProtoNameAttr, parser.getBuilder().getUnitAttr());
1682
1683 // Default to external linkage if no keyword is provided.
1684 state.addAttribute(getLinkageAttrNameString(),
1685 GlobalLinkageKindAttr::get(
1686 parser.getContext(),
1688 parser, GlobalLinkageKind::ExternalLinkage)));
1689
1690 ::llvm::StringRef visAttrStr;
1691 if (parser.parseOptionalKeyword(&visAttrStr, {"private", "public", "nested"})
1692 .succeeded()) {
1693 state.addAttribute(visNameAttr,
1694 parser.getBuilder().getStringAttr(visAttrStr));
1695 }
1696
1697 cir::VisibilityAttr cirVisibilityAttr;
1698 parseVisibilityAttr(parser, cirVisibilityAttr);
1699 state.addAttribute(visibilityNameAttr, cirVisibilityAttr);
1700
1701 if (parser.parseOptionalKeyword(dsoLocalNameAttr).succeeded())
1702 state.addAttribute(dsoLocalNameAttr, parser.getBuilder().getUnitAttr());
1703
1704 StringAttr nameAttr;
1705 if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
1706 state.attributes))
1707 return failure();
1711 bool isVariadic = false;
1712 if (function_interface_impl::parseFunctionSignatureWithArguments(
1713 parser, /*allowVariadic=*/true, arguments, isVariadic, resultTypes,
1714 resultAttrs))
1715 return failure();
1717 for (OpAsmParser::Argument &arg : arguments)
1718 argTypes.push_back(arg.type);
1719
1720 if (resultTypes.size() > 1) {
1721 return parser.emitError(
1722 loc, "functions with multiple return types are not supported");
1723 }
1724
1725 mlir::Type returnType =
1726 (resultTypes.empty() ? cir::VoidType::get(builder.getContext())
1727 : resultTypes.front());
1728
1729 cir::FuncType fnType = cir::FuncType::get(argTypes, returnType, isVariadic);
1730 if (!fnType)
1731 return failure();
1732 state.addAttribute(getFunctionTypeAttrName(state.name),
1733 TypeAttr::get(fnType));
1734
1735 bool hasAlias = false;
1736 mlir::StringAttr aliaseeNameAttr = getAliaseeAttrName(state.name);
1737 if (parser.parseOptionalKeyword("alias").succeeded()) {
1738 if (parser.parseLParen().failed())
1739 return failure();
1740 mlir::StringAttr aliaseeAttr;
1741 if (parser.parseOptionalSymbolName(aliaseeAttr).failed())
1742 return failure();
1743 state.addAttribute(aliaseeNameAttr, FlatSymbolRefAttr::get(aliaseeAttr));
1744 if (parser.parseRParen().failed())
1745 return failure();
1746 hasAlias = true;
1747 }
1748
1749 auto parseGlobalDtorCtor =
1750 [&](StringRef keyword,
1751 llvm::function_ref<void(std::optional<int> prio)> createAttr)
1752 -> mlir::LogicalResult {
1753 if (mlir::succeeded(parser.parseOptionalKeyword(keyword))) {
1754 std::optional<int> priority;
1755 if (mlir::succeeded(parser.parseOptionalLParen())) {
1756 auto parsedPriority = mlir::FieldParser<int>::parse(parser);
1757 if (mlir::failed(parsedPriority))
1758 return parser.emitError(parser.getCurrentLocation(),
1759 "failed to parse 'priority', of type 'int'");
1760 priority = parsedPriority.value_or(int());
1761 // Parse literal ')'
1762 if (parser.parseRParen())
1763 return failure();
1764 }
1765 createAttr(priority);
1766 }
1767 return success();
1768 };
1769
1770 // Parse CXXSpecialMember attribute
1771 if (parser.parseOptionalKeyword("special_member").succeeded()) {
1772 cir::CXXCtorAttr ctorAttr;
1773 cir::CXXDtorAttr dtorAttr;
1774 cir::CXXAssignAttr assignAttr;
1775 if (parser.parseLess().failed())
1776 return failure();
1777 if (parser.parseOptionalAttribute(ctorAttr).has_value())
1778 state.addAttribute(specialMemberAttr, ctorAttr);
1779 else if (parser.parseOptionalAttribute(dtorAttr).has_value())
1780 state.addAttribute(specialMemberAttr, dtorAttr);
1781 else if (parser.parseOptionalAttribute(assignAttr).has_value())
1782 state.addAttribute(specialMemberAttr, assignAttr);
1783 if (parser.parseGreater().failed())
1784 return failure();
1785 }
1786
1787 if (parseGlobalDtorCtor("global_ctor", [&](std::optional<int> priority) {
1788 mlir::IntegerAttr globalCtorPriorityAttr =
1789 builder.getI32IntegerAttr(priority.value_or(65535));
1790 state.addAttribute(getGlobalCtorPriorityAttrName(state.name),
1791 globalCtorPriorityAttr);
1792 }).failed())
1793 return failure();
1794
1795 if (parseGlobalDtorCtor("global_dtor", [&](std::optional<int> priority) {
1796 mlir::IntegerAttr globalDtorPriorityAttr =
1797 builder.getI32IntegerAttr(priority.value_or(65535));
1798 state.addAttribute(getGlobalDtorPriorityAttrName(state.name),
1799 globalDtorPriorityAttr);
1800 }).failed())
1801 return failure();
1802
1803 // Parse optional inline kind: inline(never|always|hint)
1804 if (parser.parseOptionalKeyword("inline").succeeded()) {
1805 if (parser.parseLParen().failed())
1806 return failure();
1807
1808 llvm::StringRef inlineKindStr;
1809 const std::array<llvm::StringRef, cir::getMaxEnumValForInlineKind()>
1810 allowedInlineKindStrs{
1811 cir::stringifyInlineKind(cir::InlineKind::NoInline),
1812 cir::stringifyInlineKind(cir::InlineKind::AlwaysInline),
1813 cir::stringifyInlineKind(cir::InlineKind::InlineHint),
1814 };
1815 if (parser.parseOptionalKeyword(&inlineKindStr, allowedInlineKindStrs)
1816 .failed())
1817 return parser.emitError(parser.getCurrentLocation(),
1818 "expected 'never', 'always', or 'hint'");
1819
1820 std::optional<InlineKind> inlineKind =
1821 cir::symbolizeInlineKind(inlineKindStr);
1822 if (!inlineKind)
1823 return parser.emitError(parser.getCurrentLocation(),
1824 "invalid inline kind");
1825
1826 state.addAttribute(getInlineKindAttrName(state.name),
1827 cir::InlineAttr::get(builder.getContext(), *inlineKind));
1828
1829 if (parser.parseRParen().failed())
1830 return failure();
1831 }
1832
1833 // Parse the optional function body.
1834 auto *body = state.addRegion();
1835 OptionalParseResult parseResult = parser.parseOptionalRegion(
1836 *body, arguments, /*enableNameShadowing=*/false);
1837 if (parseResult.has_value()) {
1838 if (hasAlias)
1839 return parser.emitError(loc, "function alias shall not have a body");
1840 if (failed(*parseResult))
1841 return failure();
1842 // Function body was parsed, make sure its not empty.
1843 if (body->empty())
1844 return parser.emitError(loc, "expected non-empty function body");
1845 }
1846
1847 return success();
1848}
1849
1850// This function corresponds to `llvm::GlobalValue::isDeclaration` and should
1851// have a similar implementation. We don't currently ifuncs or materializable
1852// functions, but those should be handled here as they are implemented.
1853bool cir::FuncOp::isDeclaration() {
1855
1856 std::optional<StringRef> aliasee = getAliasee();
1857 if (!aliasee)
1858 return getFunctionBody().empty();
1859
1860 // Aliases are always definitions.
1861 return false;
1862}
1863
1864bool cir::FuncOp::isCXXSpecialMemberFunction() {
1865 return getCxxSpecialMemberAttr() != nullptr;
1866}
1867
1868bool cir::FuncOp::isCxxConstructor() {
1869 auto attr = getCxxSpecialMemberAttr();
1870 return attr && dyn_cast<CXXCtorAttr>(attr);
1871}
1872
1873bool cir::FuncOp::isCxxDestructor() {
1874 auto attr = getCxxSpecialMemberAttr();
1875 return attr && dyn_cast<CXXDtorAttr>(attr);
1876}
1877
1878bool cir::FuncOp::isCxxSpecialAssignment() {
1879 auto attr = getCxxSpecialMemberAttr();
1880 return attr && dyn_cast<CXXAssignAttr>(attr);
1881}
1882
1883std::optional<CtorKind> cir::FuncOp::getCxxConstructorKind() {
1884 mlir::Attribute attr = getCxxSpecialMemberAttr();
1885 if (attr) {
1886 if (auto ctor = dyn_cast<CXXCtorAttr>(attr))
1887 return ctor.getCtorKind();
1888 }
1889 return std::nullopt;
1890}
1891
1892std::optional<AssignKind> cir::FuncOp::getCxxSpecialAssignKind() {
1893 mlir::Attribute attr = getCxxSpecialMemberAttr();
1894 if (attr) {
1895 if (auto assign = dyn_cast<CXXAssignAttr>(attr))
1896 return assign.getAssignKind();
1897 }
1898 return std::nullopt;
1899}
1900
1901bool cir::FuncOp::isCxxTrivialMemberFunction() {
1902 mlir::Attribute attr = getCxxSpecialMemberAttr();
1903 if (attr) {
1904 if (auto ctor = dyn_cast<CXXCtorAttr>(attr))
1905 return ctor.getIsTrivial();
1906 if (auto dtor = dyn_cast<CXXDtorAttr>(attr))
1907 return dtor.getIsTrivial();
1908 if (auto assign = dyn_cast<CXXAssignAttr>(attr))
1909 return assign.getIsTrivial();
1910 }
1911 return false;
1912}
1913
1914mlir::Region *cir::FuncOp::getCallableRegion() {
1915 // TODO(CIR): This function will have special handling for aliases and a
1916 // check for an external function, once those features have been upstreamed.
1917 return &getBody();
1918}
1919
1920void cir::FuncOp::print(OpAsmPrinter &p) {
1921 if (getBuiltin())
1922 p << " builtin";
1923
1924 if (getCoroutine())
1925 p << " coroutine";
1926
1927 if (getLambda())
1928 p << " lambda";
1929
1930 if (getNoProto())
1931 p << " no_proto";
1932
1933 if (getComdat())
1934 p << " comdat";
1935
1936 if (getLinkage() != GlobalLinkageKind::ExternalLinkage)
1937 p << ' ' << stringifyGlobalLinkageKind(getLinkage());
1938
1939 mlir::SymbolTable::Visibility vis = getVisibility();
1940 if (vis != mlir::SymbolTable::Visibility::Public)
1941 p << ' ' << vis;
1942
1943 cir::VisibilityAttr cirVisibilityAttr = getGlobalVisibilityAttr();
1944 if (!cirVisibilityAttr.isDefault()) {
1945 p << ' ';
1946 printVisibilityAttr(p, cirVisibilityAttr);
1947 }
1948
1949 if (getDsoLocal())
1950 p << " dso_local";
1951
1952 p << ' ';
1953 p.printSymbolName(getSymName());
1954 cir::FuncType fnType = getFunctionType();
1955 function_interface_impl::printFunctionSignature(
1956 p, *this, fnType.getInputs(), fnType.isVarArg(), fnType.getReturnTypes());
1957
1958 if (std::optional<StringRef> aliaseeName = getAliasee()) {
1959 p << " alias(";
1960 p.printSymbolName(*aliaseeName);
1961 p << ")";
1962 }
1963
1964 if (auto specialMemberAttr = getCxxSpecialMember()) {
1965 p << " special_member<";
1966 p.printAttribute(*specialMemberAttr);
1967 p << '>';
1968 }
1969
1970 if (auto globalCtorPriority = getGlobalCtorPriority()) {
1971 p << " global_ctor";
1972 if (globalCtorPriority.value() != 65535)
1973 p << "(" << globalCtorPriority.value() << ")";
1974 }
1975
1976 if (auto globalDtorPriority = getGlobalDtorPriority()) {
1977 p << " global_dtor";
1978 if (globalDtorPriority.value() != 65535)
1979 p << "(" << globalDtorPriority.value() << ")";
1980 }
1981
1982 if (cir::InlineAttr inlineAttr = getInlineKindAttr()) {
1983 p << " inline(" << cir::stringifyInlineKind(inlineAttr.getValue()) << ")";
1984 }
1985
1986 // Print the body if this is not an external function.
1987 Region &body = getOperation()->getRegion(0);
1988 if (!body.empty()) {
1989 p << ' ';
1990 p.printRegion(body, /*printEntryBlockArgs=*/false,
1991 /*printBlockTerminators=*/true);
1992 }
1993}
1994
1995mlir::LogicalResult cir::FuncOp::verify() {
1996
1997 if (!isDeclaration() && getCoroutine()) {
1998 bool foundAwait = false;
1999 this->walk([&](Operation *op) {
2000 if (auto await = dyn_cast<AwaitOp>(op)) {
2001 foundAwait = true;
2002 return;
2003 }
2004 });
2005 if (!foundAwait)
2006 return emitOpError()
2007 << "coroutine body must use at least one cir.await op";
2008 }
2009
2010 llvm::SmallSet<llvm::StringRef, 16> labels;
2011 llvm::SmallSet<llvm::StringRef, 16> gotos;
2012 llvm::SmallSet<llvm::StringRef, 16> blockAddresses;
2013 bool invalidBlockAddress = false;
2014 getOperation()->walk([&](mlir::Operation *op) {
2015 if (auto lab = dyn_cast<cir::LabelOp>(op)) {
2016 labels.insert(lab.getLabel());
2017 } else if (auto goTo = dyn_cast<cir::GotoOp>(op)) {
2018 gotos.insert(goTo.getLabel());
2019 } else if (auto blkAdd = dyn_cast<cir::BlockAddressOp>(op)) {
2020 if (blkAdd.getBlockAddrInfoAttr().getFunc().getAttr() != getSymName()) {
2021 // Stop the walk early, no need to continue
2022 invalidBlockAddress = true;
2023 return mlir::WalkResult::interrupt();
2024 }
2025 blockAddresses.insert(blkAdd.getBlockAddrInfoAttr().getLabel());
2026 }
2027 return mlir::WalkResult::advance();
2028 });
2029
2030 if (invalidBlockAddress)
2031 return emitOpError() << "blockaddress references a different function";
2032
2033 llvm::SmallSet<llvm::StringRef, 16> mismatched;
2034 if (!labels.empty() || !gotos.empty()) {
2035 mismatched = llvm::set_difference(gotos, labels);
2036
2037 if (!mismatched.empty())
2038 return emitOpError() << "goto/label mismatch";
2039 }
2040
2041 mismatched.clear();
2042
2043 if (!labels.empty() || !blockAddresses.empty()) {
2044 mismatched = llvm::set_difference(blockAddresses, labels);
2045
2046 if (!mismatched.empty())
2047 return emitOpError()
2048 << "expects an existing label target in the referenced function";
2049 }
2050
2051 return success();
2052}
2053
2054//===----------------------------------------------------------------------===//
2055// BinOp
2056//===----------------------------------------------------------------------===//
2057LogicalResult cir::BinOp::verify() {
2058 bool noWrap = getNoUnsignedWrap() || getNoSignedWrap();
2059 bool saturated = getSaturated();
2060
2061 if (!isa<cir::IntType>(getType()) && noWrap)
2062 return emitError()
2063 << "only operations on integer values may have nsw/nuw flags";
2064
2065 bool noWrapOps = getKind() == cir::BinOpKind::Add ||
2066 getKind() == cir::BinOpKind::Sub ||
2067 getKind() == cir::BinOpKind::Mul;
2068
2069 bool saturatedOps =
2070 getKind() == cir::BinOpKind::Add || getKind() == cir::BinOpKind::Sub;
2071
2072 if (noWrap && !noWrapOps)
2073 return emitError() << "The nsw/nuw flags are applicable to opcodes: 'add', "
2074 "'sub' and 'mul'";
2075 if (saturated && !saturatedOps)
2076 return emitError() << "The saturated flag is applicable to opcodes: 'add' "
2077 "and 'sub'";
2078 if (noWrap && saturated)
2079 return emitError() << "The nsw/nuw flags and the saturated flag are "
2080 "mutually exclusive";
2081
2082 return mlir::success();
2083}
2084
2085//===----------------------------------------------------------------------===//
2086// TernaryOp
2087//===----------------------------------------------------------------------===//
2088
2089/// Given the region at `point`, or the parent operation if `point` is None,
2090/// return the successor regions. These are the regions that may be selected
2091/// during the flow of control. `operands` is a set of optional attributes that
2092/// correspond to a constant value for each operand, or null if that operand is
2093/// not a constant.
2094void cir::TernaryOp::getSuccessorRegions(
2095 mlir::RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
2096 // The `true` and the `false` region branch back to the parent operation.
2097 if (!point.isParent()) {
2098 regions.push_back(RegionSuccessor(getOperation(), this->getODSResults(0)));
2099 return;
2100 }
2101
2102 // When branching from the parent operation, both the true and false
2103 // regions are considered possible successors
2104 regions.push_back(RegionSuccessor(&getTrueRegion()));
2105 regions.push_back(RegionSuccessor(&getFalseRegion()));
2106}
2107
2108void cir::TernaryOp::build(
2109 OpBuilder &builder, OperationState &result, Value cond,
2110 function_ref<void(OpBuilder &, Location)> trueBuilder,
2111 function_ref<void(OpBuilder &, Location)> falseBuilder) {
2112 result.addOperands(cond);
2113 OpBuilder::InsertionGuard guard(builder);
2114 Region *trueRegion = result.addRegion();
2115 builder.createBlock(trueRegion);
2116 trueBuilder(builder, result.location);
2117 Region *falseRegion = result.addRegion();
2118 builder.createBlock(falseRegion);
2119 falseBuilder(builder, result.location);
2120
2121 // Get result type from whichever branch has a yield (the other may have
2122 // unreachable from a throw expression)
2123 auto yield =
2124 dyn_cast_or_null<cir::YieldOp>(trueRegion->back().getTerminator());
2125 if (!yield)
2126 yield = dyn_cast_or_null<cir::YieldOp>(falseRegion->back().getTerminator());
2127
2128 assert((yield && yield.getNumOperands() <= 1) &&
2129 "expected zero or one result type");
2130 if (yield.getNumOperands() == 1)
2131 result.addTypes(TypeRange{yield.getOperandTypes().front()});
2132}
2133
2134//===----------------------------------------------------------------------===//
2135// SelectOp
2136//===----------------------------------------------------------------------===//
2137
2138OpFoldResult cir::SelectOp::fold(FoldAdaptor adaptor) {
2139 mlir::Attribute condition = adaptor.getCondition();
2140 if (condition) {
2141 bool conditionValue = mlir::cast<cir::BoolAttr>(condition).getValue();
2142 return conditionValue ? getTrueValue() : getFalseValue();
2143 }
2144
2145 // cir.select if %0 then x else x -> x
2146 mlir::Attribute trueValue = adaptor.getTrueValue();
2147 mlir::Attribute falseValue = adaptor.getFalseValue();
2148 if (trueValue == falseValue)
2149 return trueValue;
2150 if (getTrueValue() == getFalseValue())
2151 return getTrueValue();
2152
2153 return {};
2154}
2155
2156//===----------------------------------------------------------------------===//
2157// ShiftOp
2158//===----------------------------------------------------------------------===//
2159LogicalResult cir::ShiftOp::verify() {
2160 mlir::Operation *op = getOperation();
2161 auto op0VecTy = mlir::dyn_cast<cir::VectorType>(op->getOperand(0).getType());
2162 auto op1VecTy = mlir::dyn_cast<cir::VectorType>(op->getOperand(1).getType());
2163 if (!op0VecTy ^ !op1VecTy)
2164 return emitOpError() << "input types cannot be one vector and one scalar";
2165
2166 if (op0VecTy) {
2167 if (op0VecTy.getSize() != op1VecTy.getSize())
2168 return emitOpError() << "input vector types must have the same size";
2169
2170 auto opResultTy = mlir::dyn_cast<cir::VectorType>(getType());
2171 if (!opResultTy)
2172 return emitOpError() << "the type of the result must be a vector "
2173 << "if it is vector shift";
2174
2175 auto op0VecEleTy = mlir::cast<cir::IntType>(op0VecTy.getElementType());
2176 auto op1VecEleTy = mlir::cast<cir::IntType>(op1VecTy.getElementType());
2177 if (op0VecEleTy.getWidth() != op1VecEleTy.getWidth())
2178 return emitOpError()
2179 << "vector operands do not have the same elements sizes";
2180
2181 auto resVecEleTy = mlir::cast<cir::IntType>(opResultTy.getElementType());
2182 if (op0VecEleTy.getWidth() != resVecEleTy.getWidth())
2183 return emitOpError() << "vector operands and result type do not have the "
2184 "same elements sizes";
2185 }
2186
2187 return mlir::success();
2188}
2189
2190//===----------------------------------------------------------------------===//
2191// LabelOp Definitions
2192//===----------------------------------------------------------------------===//
2193
2194LogicalResult cir::LabelOp::verify() {
2195 mlir::Operation *op = getOperation();
2196 mlir::Block *blk = op->getBlock();
2197 if (&blk->front() != op)
2198 return emitError() << "must be the first operation in a block";
2199
2200 return mlir::success();
2201}
2202
2203//===----------------------------------------------------------------------===//
2204// UnaryOp
2205//===----------------------------------------------------------------------===//
2206
2207LogicalResult cir::UnaryOp::verify() {
2208 switch (getKind()) {
2209 case cir::UnaryOpKind::Inc:
2210 case cir::UnaryOpKind::Dec:
2211 case cir::UnaryOpKind::Plus:
2212 case cir::UnaryOpKind::Minus:
2213 case cir::UnaryOpKind::Not:
2214 // Nothing to verify.
2215 return success();
2216 }
2217
2218 llvm_unreachable("Unknown UnaryOp kind?");
2219}
2220
2221static bool isBoolNot(cir::UnaryOp op) {
2222 return isa<cir::BoolType>(op.getInput().getType()) &&
2223 op.getKind() == cir::UnaryOpKind::Not;
2224}
2225
2226// This folder simplifies the sequential boolean not operations.
2227// For instance, the next two unary operations will be eliminated:
2228//
2229// ```mlir
2230// %1 = cir.unary(not, %0) : !cir.bool, !cir.bool
2231// %2 = cir.unary(not, %1) : !cir.bool, !cir.bool
2232// ```
2233//
2234// and the argument of the first one (%0) will be used instead.
2235OpFoldResult cir::UnaryOp::fold(FoldAdaptor adaptor) {
2236 if (auto poison =
2237 mlir::dyn_cast_if_present<cir::PoisonAttr>(adaptor.getInput())) {
2238 // Propagate poison values
2239 return poison;
2240 }
2241
2242 if (isBoolNot(*this))
2243 if (auto previous = getInput().getDefiningOp<cir::UnaryOp>())
2244 if (isBoolNot(previous))
2245 return previous.getInput();
2246
2247 return {};
2248}
2249//===----------------------------------------------------------------------===//
2250// AwaitOp
2251//===----------------------------------------------------------------------===//
2252
2253void cir::AwaitOp::build(OpBuilder &builder, OperationState &result,
2254 cir::AwaitKind kind, BuilderCallbackRef readyBuilder,
2255 BuilderCallbackRef suspendBuilder,
2256 BuilderCallbackRef resumeBuilder) {
2257 result.addAttribute(getKindAttrName(result.name),
2258 cir::AwaitKindAttr::get(builder.getContext(), kind));
2259 {
2260 OpBuilder::InsertionGuard guard(builder);
2261 Region *readyRegion = result.addRegion();
2262 builder.createBlock(readyRegion);
2263 readyBuilder(builder, result.location);
2264 }
2265
2266 {
2267 OpBuilder::InsertionGuard guard(builder);
2268 Region *suspendRegion = result.addRegion();
2269 builder.createBlock(suspendRegion);
2270 suspendBuilder(builder, result.location);
2271 }
2272
2273 {
2274 OpBuilder::InsertionGuard guard(builder);
2275 Region *resumeRegion = result.addRegion();
2276 builder.createBlock(resumeRegion);
2277 resumeBuilder(builder, result.location);
2278 }
2279}
2280
2281void cir::AwaitOp::getSuccessorRegions(
2282 mlir::RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
2283 // If any index all the underlying regions branch back to the parent
2284 // operation.
2285 if (!point.isParent()) {
2286 regions.push_back(
2287 RegionSuccessor(getOperation(), getOperation()->getResults()));
2288 return;
2289 }
2290
2291 // TODO: retrieve information from the promise and only push the
2292 // necessary ones. Example: `std::suspend_never` on initial or final
2293 // await's might allow suspend region to be skipped.
2294 regions.push_back(RegionSuccessor(&this->getReady()));
2295 regions.push_back(RegionSuccessor(&this->getSuspend()));
2296 regions.push_back(RegionSuccessor(&this->getResume()));
2297}
2298
2299LogicalResult cir::AwaitOp::verify() {
2300 if (!isa<ConditionOp>(this->getReady().back().getTerminator()))
2301 return emitOpError("ready region must end with cir.condition");
2302 return success();
2303}
2304
2305//===----------------------------------------------------------------------===//
2306// CopyOp Definitions
2307//===----------------------------------------------------------------------===//
2308
2309LogicalResult cir::CopyOp::verify() {
2310 // A data layout is required for us to know the number of bytes to be copied.
2311 if (!getType().getPointee().hasTrait<DataLayoutTypeInterface::Trait>())
2312 return emitError() << "missing data layout for pointee type";
2313
2314 if (getSrc() == getDst())
2315 return emitError() << "source and destination are the same";
2316
2317 return mlir::success();
2318}
2319
2320//===----------------------------------------------------------------------===//
2321// GetMemberOp Definitions
2322//===----------------------------------------------------------------------===//
2323
2324LogicalResult cir::GetMemberOp::verify() {
2325 const auto recordTy = dyn_cast<RecordType>(getAddrTy().getPointee());
2326 if (!recordTy)
2327 return emitError() << "expected pointer to a record type";
2328
2329 if (recordTy.getMembers().size() <= getIndex())
2330 return emitError() << "member index out of bounds";
2331
2332 if (recordTy.getMembers()[getIndex()] != getType().getPointee())
2333 return emitError() << "member type mismatch";
2334
2335 return mlir::success();
2336}
2337
2338//===----------------------------------------------------------------------===//
2339// VecCreateOp
2340//===----------------------------------------------------------------------===//
2341
2342OpFoldResult cir::VecCreateOp::fold(FoldAdaptor adaptor) {
2343 if (llvm::any_of(getElements(), [](mlir::Value value) {
2344 return !value.getDefiningOp<cir::ConstantOp>();
2345 }))
2346 return {};
2347
2348 return cir::ConstVectorAttr::get(
2349 getType(), mlir::ArrayAttr::get(getContext(), adaptor.getElements()));
2350}
2351
2352LogicalResult cir::VecCreateOp::verify() {
2353 // Verify that the number of arguments matches the number of elements in the
2354 // vector, and that the type of all the arguments matches the type of the
2355 // elements in the vector.
2356 const cir::VectorType vecTy = getType();
2357 if (getElements().size() != vecTy.getSize()) {
2358 return emitOpError() << "operand count of " << getElements().size()
2359 << " doesn't match vector type " << vecTy
2360 << " element count of " << vecTy.getSize();
2361 }
2362
2363 const mlir::Type elementType = vecTy.getElementType();
2364 for (const mlir::Value element : getElements()) {
2365 if (element.getType() != elementType) {
2366 return emitOpError() << "operand type " << element.getType()
2367 << " doesn't match vector element type "
2368 << elementType;
2369 }
2370 }
2371
2372 return success();
2373}
2374
2375//===----------------------------------------------------------------------===//
2376// VecExtractOp
2377//===----------------------------------------------------------------------===//
2378
2379OpFoldResult cir::VecExtractOp::fold(FoldAdaptor adaptor) {
2380 const auto vectorAttr =
2381 llvm::dyn_cast_if_present<cir::ConstVectorAttr>(adaptor.getVec());
2382 if (!vectorAttr)
2383 return {};
2384
2385 const auto indexAttr =
2386 llvm::dyn_cast_if_present<cir::IntAttr>(adaptor.getIndex());
2387 if (!indexAttr)
2388 return {};
2389
2390 const mlir::ArrayAttr elements = vectorAttr.getElts();
2391 const uint64_t index = indexAttr.getUInt();
2392 if (index >= elements.size())
2393 return {};
2394
2395 return elements[index];
2396}
2397
2398//===----------------------------------------------------------------------===//
2399// VecCmpOp
2400//===----------------------------------------------------------------------===//
2401
2402OpFoldResult cir::VecCmpOp::fold(FoldAdaptor adaptor) {
2403 auto lhsVecAttr =
2404 mlir::dyn_cast_if_present<cir::ConstVectorAttr>(adaptor.getLhs());
2405 auto rhsVecAttr =
2406 mlir::dyn_cast_if_present<cir::ConstVectorAttr>(adaptor.getRhs());
2407 if (!lhsVecAttr || !rhsVecAttr)
2408 return {};
2409
2410 mlir::Type inputElemTy =
2411 mlir::cast<cir::VectorType>(lhsVecAttr.getType()).getElementType();
2412 if (!isAnyIntegerOrFloatingPointType(inputElemTy))
2413 return {};
2414
2415 cir::CmpOpKind opKind = adaptor.getKind();
2416 mlir::ArrayAttr lhsVecElhs = lhsVecAttr.getElts();
2417 mlir::ArrayAttr rhsVecElhs = rhsVecAttr.getElts();
2418 uint64_t vecSize = lhsVecElhs.size();
2419
2420 SmallVector<mlir::Attribute, 16> elements(vecSize);
2421 bool isIntAttr = vecSize && mlir::isa<cir::IntAttr>(lhsVecElhs[0]);
2422 for (uint64_t i = 0; i < vecSize; i++) {
2423 mlir::Attribute lhsAttr = lhsVecElhs[i];
2424 mlir::Attribute rhsAttr = rhsVecElhs[i];
2425 int cmpResult = 0;
2426 switch (opKind) {
2427 case cir::CmpOpKind::lt: {
2428 if (isIntAttr) {
2429 cmpResult = mlir::cast<cir::IntAttr>(lhsAttr).getSInt() <
2430 mlir::cast<cir::IntAttr>(rhsAttr).getSInt();
2431 } else {
2432 cmpResult = mlir::cast<cir::FPAttr>(lhsAttr).getValue() <
2433 mlir::cast<cir::FPAttr>(rhsAttr).getValue();
2434 }
2435 break;
2436 }
2437 case cir::CmpOpKind::le: {
2438 if (isIntAttr) {
2439 cmpResult = mlir::cast<cir::IntAttr>(lhsAttr).getSInt() <=
2440 mlir::cast<cir::IntAttr>(rhsAttr).getSInt();
2441 } else {
2442 cmpResult = mlir::cast<cir::FPAttr>(lhsAttr).getValue() <=
2443 mlir::cast<cir::FPAttr>(rhsAttr).getValue();
2444 }
2445 break;
2446 }
2447 case cir::CmpOpKind::gt: {
2448 if (isIntAttr) {
2449 cmpResult = mlir::cast<cir::IntAttr>(lhsAttr).getSInt() >
2450 mlir::cast<cir::IntAttr>(rhsAttr).getSInt();
2451 } else {
2452 cmpResult = mlir::cast<cir::FPAttr>(lhsAttr).getValue() >
2453 mlir::cast<cir::FPAttr>(rhsAttr).getValue();
2454 }
2455 break;
2456 }
2457 case cir::CmpOpKind::ge: {
2458 if (isIntAttr) {
2459 cmpResult = mlir::cast<cir::IntAttr>(lhsAttr).getSInt() >=
2460 mlir::cast<cir::IntAttr>(rhsAttr).getSInt();
2461 } else {
2462 cmpResult = mlir::cast<cir::FPAttr>(lhsAttr).getValue() >=
2463 mlir::cast<cir::FPAttr>(rhsAttr).getValue();
2464 }
2465 break;
2466 }
2467 case cir::CmpOpKind::eq: {
2468 if (isIntAttr) {
2469 cmpResult = mlir::cast<cir::IntAttr>(lhsAttr).getSInt() ==
2470 mlir::cast<cir::IntAttr>(rhsAttr).getSInt();
2471 } else {
2472 cmpResult = mlir::cast<cir::FPAttr>(lhsAttr).getValue() ==
2473 mlir::cast<cir::FPAttr>(rhsAttr).getValue();
2474 }
2475 break;
2476 }
2477 case cir::CmpOpKind::ne: {
2478 if (isIntAttr) {
2479 cmpResult = mlir::cast<cir::IntAttr>(lhsAttr).getSInt() !=
2480 mlir::cast<cir::IntAttr>(rhsAttr).getSInt();
2481 } else {
2482 cmpResult = mlir::cast<cir::FPAttr>(lhsAttr).getValue() !=
2483 mlir::cast<cir::FPAttr>(rhsAttr).getValue();
2484 }
2485 break;
2486 }
2487 }
2488
2489 elements[i] = cir::IntAttr::get(getType().getElementType(), cmpResult);
2490 }
2491
2492 return cir::ConstVectorAttr::get(
2493 getType(), mlir::ArrayAttr::get(getContext(), elements));
2494}
2495
2496//===----------------------------------------------------------------------===//
2497// VecShuffleOp
2498//===----------------------------------------------------------------------===//
2499
2500OpFoldResult cir::VecShuffleOp::fold(FoldAdaptor adaptor) {
2501 auto vec1Attr =
2502 mlir::dyn_cast_if_present<cir::ConstVectorAttr>(adaptor.getVec1());
2503 auto vec2Attr =
2504 mlir::dyn_cast_if_present<cir::ConstVectorAttr>(adaptor.getVec2());
2505 if (!vec1Attr || !vec2Attr)
2506 return {};
2507
2508 mlir::Type vec1ElemTy =
2509 mlir::cast<cir::VectorType>(vec1Attr.getType()).getElementType();
2510
2511 mlir::ArrayAttr vec1Elts = vec1Attr.getElts();
2512 mlir::ArrayAttr vec2Elts = vec2Attr.getElts();
2513 mlir::ArrayAttr indicesElts = adaptor.getIndices();
2514
2516 elements.reserve(indicesElts.size());
2517
2518 uint64_t vec1Size = vec1Elts.size();
2519 for (const auto &idxAttr : indicesElts.getAsRange<cir::IntAttr>()) {
2520 if (idxAttr.getSInt() == -1) {
2521 elements.push_back(cir::UndefAttr::get(vec1ElemTy));
2522 continue;
2523 }
2524
2525 uint64_t idxValue = idxAttr.getUInt();
2526 elements.push_back(idxValue < vec1Size ? vec1Elts[idxValue]
2527 : vec2Elts[idxValue - vec1Size]);
2528 }
2529
2530 return cir::ConstVectorAttr::get(
2531 getType(), mlir::ArrayAttr::get(getContext(), elements));
2532}
2533
2534LogicalResult cir::VecShuffleOp::verify() {
2535 // The number of elements in the indices array must match the number of
2536 // elements in the result type.
2537 if (getIndices().size() != getResult().getType().getSize()) {
2538 return emitOpError() << ": the number of elements in " << getIndices()
2539 << " and " << getResult().getType() << " don't match";
2540 }
2541
2542 // The element types of the two input vectors and of the result type must
2543 // match.
2544 if (getVec1().getType().getElementType() !=
2545 getResult().getType().getElementType()) {
2546 return emitOpError() << ": element types of " << getVec1().getType()
2547 << " and " << getResult().getType() << " don't match";
2548 }
2549
2550 const uint64_t maxValidIndex =
2551 getVec1().getType().getSize() + getVec2().getType().getSize() - 1;
2552 if (llvm::any_of(
2553 getIndices().getAsRange<cir::IntAttr>(), [&](cir::IntAttr idxAttr) {
2554 return idxAttr.getSInt() != -1 && idxAttr.getUInt() > maxValidIndex;
2555 })) {
2556 return emitOpError() << ": index for __builtin_shufflevector must be "
2557 "less than the total number of vector elements";
2558 }
2559 return success();
2560}
2561
2562//===----------------------------------------------------------------------===//
2563// VecShuffleDynamicOp
2564//===----------------------------------------------------------------------===//
2565
2566OpFoldResult cir::VecShuffleDynamicOp::fold(FoldAdaptor adaptor) {
2567 mlir::Attribute vec = adaptor.getVec();
2568 mlir::Attribute indices = adaptor.getIndices();
2569 if (mlir::isa_and_nonnull<cir::ConstVectorAttr>(vec) &&
2570 mlir::isa_and_nonnull<cir::ConstVectorAttr>(indices)) {
2571 auto vecAttr = mlir::cast<cir::ConstVectorAttr>(vec);
2572 auto indicesAttr = mlir::cast<cir::ConstVectorAttr>(indices);
2573
2574 mlir::ArrayAttr vecElts = vecAttr.getElts();
2575 mlir::ArrayAttr indicesElts = indicesAttr.getElts();
2576
2577 const uint64_t numElements = vecElts.size();
2578
2580 elements.reserve(numElements);
2581
2582 const uint64_t maskBits = llvm::NextPowerOf2(numElements - 1) - 1;
2583 for (const auto &idxAttr : indicesElts.getAsRange<cir::IntAttr>()) {
2584 uint64_t idxValue = idxAttr.getUInt();
2585 uint64_t newIdx = idxValue & maskBits;
2586 elements.push_back(vecElts[newIdx]);
2587 }
2588
2589 return cir::ConstVectorAttr::get(
2590 getType(), mlir::ArrayAttr::get(getContext(), elements));
2591 }
2592
2593 return {};
2594}
2595
2596LogicalResult cir::VecShuffleDynamicOp::verify() {
2597 // The number of elements in the two input vectors must match.
2598 if (getVec().getType().getSize() !=
2599 mlir::cast<cir::VectorType>(getIndices().getType()).getSize()) {
2600 return emitOpError() << ": the number of elements in " << getVec().getType()
2601 << " and " << getIndices().getType() << " don't match";
2602 }
2603 return success();
2604}
2605
2606//===----------------------------------------------------------------------===//
2607// VecTernaryOp
2608//===----------------------------------------------------------------------===//
2609
2610LogicalResult cir::VecTernaryOp::verify() {
2611 // Verify that the condition operand has the same number of elements as the
2612 // other operands. (The automatic verification already checked that all
2613 // operands are vector types and that the second and third operands are the
2614 // same type.)
2615 if (getCond().getType().getSize() != getLhs().getType().getSize()) {
2616 return emitOpError() << ": the number of elements in "
2617 << getCond().getType() << " and " << getLhs().getType()
2618 << " don't match";
2619 }
2620 return success();
2621}
2622
2623OpFoldResult cir::VecTernaryOp::fold(FoldAdaptor adaptor) {
2624 mlir::Attribute cond = adaptor.getCond();
2625 mlir::Attribute lhs = adaptor.getLhs();
2626 mlir::Attribute rhs = adaptor.getRhs();
2627
2628 if (!mlir::isa_and_nonnull<cir::ConstVectorAttr>(cond) ||
2629 !mlir::isa_and_nonnull<cir::ConstVectorAttr>(lhs) ||
2630 !mlir::isa_and_nonnull<cir::ConstVectorAttr>(rhs))
2631 return {};
2632 auto condVec = mlir::cast<cir::ConstVectorAttr>(cond);
2633 auto lhsVec = mlir::cast<cir::ConstVectorAttr>(lhs);
2634 auto rhsVec = mlir::cast<cir::ConstVectorAttr>(rhs);
2635
2636 mlir::ArrayAttr condElts = condVec.getElts();
2637
2639 elements.reserve(condElts.size());
2640
2641 for (const auto &[idx, condAttr] :
2642 llvm::enumerate(condElts.getAsRange<cir::IntAttr>())) {
2643 if (condAttr.getSInt()) {
2644 elements.push_back(lhsVec.getElts()[idx]);
2645 } else {
2646 elements.push_back(rhsVec.getElts()[idx]);
2647 }
2648 }
2649
2650 cir::VectorType vecTy = getLhs().getType();
2651 return cir::ConstVectorAttr::get(
2652 vecTy, mlir::ArrayAttr::get(getContext(), elements));
2653}
2654
2655//===----------------------------------------------------------------------===//
2656// ComplexCreateOp
2657//===----------------------------------------------------------------------===//
2658
2659LogicalResult cir::ComplexCreateOp::verify() {
2660 if (getType().getElementType() != getReal().getType()) {
2661 emitOpError()
2662 << "operand type of cir.complex.create does not match its result type";
2663 return failure();
2664 }
2665
2666 return success();
2667}
2668
2669OpFoldResult cir::ComplexCreateOp::fold(FoldAdaptor adaptor) {
2670 mlir::Attribute real = adaptor.getReal();
2671 mlir::Attribute imag = adaptor.getImag();
2672 if (!real || !imag)
2673 return {};
2674
2675 // When both of real and imag are constants, we can fold the operation into an
2676 // `#cir.const_complex` operation.
2677 auto realAttr = mlir::cast<mlir::TypedAttr>(real);
2678 auto imagAttr = mlir::cast<mlir::TypedAttr>(imag);
2679 return cir::ConstComplexAttr::get(realAttr, imagAttr);
2680}
2681
2682//===----------------------------------------------------------------------===//
2683// ComplexRealOp
2684//===----------------------------------------------------------------------===//
2685
2686LogicalResult cir::ComplexRealOp::verify() {
2687 mlir::Type operandTy = getOperand().getType();
2688 if (auto complexOperandTy = mlir::dyn_cast<cir::ComplexType>(operandTy))
2689 operandTy = complexOperandTy.getElementType();
2690
2691 if (getType() != operandTy) {
2692 emitOpError() << ": result type does not match operand type";
2693 return failure();
2694 }
2695
2696 return success();
2697}
2698
2699OpFoldResult cir::ComplexRealOp::fold(FoldAdaptor adaptor) {
2700 if (!mlir::isa<cir::ComplexType>(getOperand().getType()))
2701 return nullptr;
2702
2703 if (auto complexCreateOp = getOperand().getDefiningOp<cir::ComplexCreateOp>())
2704 return complexCreateOp.getOperand(0);
2705
2706 auto complex =
2707 mlir::cast_if_present<cir::ConstComplexAttr>(adaptor.getOperand());
2708 return complex ? complex.getReal() : nullptr;
2709}
2710
2711//===----------------------------------------------------------------------===//
2712// ComplexImagOp
2713//===----------------------------------------------------------------------===//
2714
2715LogicalResult cir::ComplexImagOp::verify() {
2716 mlir::Type operandTy = getOperand().getType();
2717 if (auto complexOperandTy = mlir::dyn_cast<cir::ComplexType>(operandTy))
2718 operandTy = complexOperandTy.getElementType();
2719
2720 if (getType() != operandTy) {
2721 emitOpError() << ": result type does not match operand type";
2722 return failure();
2723 }
2724
2725 return success();
2726}
2727
2728OpFoldResult cir::ComplexImagOp::fold(FoldAdaptor adaptor) {
2729 if (!mlir::isa<cir::ComplexType>(getOperand().getType()))
2730 return nullptr;
2731
2732 if (auto complexCreateOp = getOperand().getDefiningOp<cir::ComplexCreateOp>())
2733 return complexCreateOp.getOperand(1);
2734
2735 auto complex =
2736 mlir::cast_if_present<cir::ConstComplexAttr>(adaptor.getOperand());
2737 return complex ? complex.getImag() : nullptr;
2738}
2739
2740//===----------------------------------------------------------------------===//
2741// ComplexRealPtrOp
2742//===----------------------------------------------------------------------===//
2743
2744LogicalResult cir::ComplexRealPtrOp::verify() {
2745 mlir::Type resultPointeeTy = getType().getPointee();
2746 cir::PointerType operandPtrTy = getOperand().getType();
2747 auto operandPointeeTy =
2748 mlir::cast<cir::ComplexType>(operandPtrTy.getPointee());
2749
2750 if (resultPointeeTy != operandPointeeTy.getElementType()) {
2751 return emitOpError() << ": result type does not match operand type";
2752 }
2753
2754 return success();
2755}
2756
2757//===----------------------------------------------------------------------===//
2758// ComplexImagPtrOp
2759//===----------------------------------------------------------------------===//
2760
2761LogicalResult cir::ComplexImagPtrOp::verify() {
2762 mlir::Type resultPointeeTy = getType().getPointee();
2763 cir::PointerType operandPtrTy = getOperand().getType();
2764 auto operandPointeeTy =
2765 mlir::cast<cir::ComplexType>(operandPtrTy.getPointee());
2766
2767 if (resultPointeeTy != operandPointeeTy.getElementType()) {
2768 return emitOpError()
2769 << "cir.complex.imag_ptr result type does not match operand type";
2770 }
2771 return success();
2772}
2773
2774//===----------------------------------------------------------------------===//
2775// Bit manipulation operations
2776//===----------------------------------------------------------------------===//
2777
2778static OpFoldResult
2779foldUnaryBitOp(mlir::Attribute inputAttr,
2780 llvm::function_ref<llvm::APInt(const llvm::APInt &)> func,
2781 bool poisonZero = false) {
2782 if (mlir::isa_and_present<cir::PoisonAttr>(inputAttr)) {
2783 // Propagate poison value
2784 return inputAttr;
2785 }
2786
2787 auto input = mlir::dyn_cast_if_present<IntAttr>(inputAttr);
2788 if (!input)
2789 return nullptr;
2790
2791 llvm::APInt inputValue = input.getValue();
2792 if (poisonZero && inputValue.isZero())
2793 return cir::PoisonAttr::get(input.getType());
2794
2795 llvm::APInt resultValue = func(inputValue);
2796 return IntAttr::get(input.getType(), resultValue);
2797}
2798
2799OpFoldResult BitClrsbOp::fold(FoldAdaptor adaptor) {
2800 return foldUnaryBitOp(adaptor.getInput(), [](const llvm::APInt &inputValue) {
2801 unsigned resultValue =
2802 inputValue.getBitWidth() - inputValue.getSignificantBits();
2803 return llvm::APInt(inputValue.getBitWidth(), resultValue);
2804 });
2805}
2806
2807OpFoldResult BitClzOp::fold(FoldAdaptor adaptor) {
2808 return foldUnaryBitOp(
2809 adaptor.getInput(),
2810 [](const llvm::APInt &inputValue) {
2811 unsigned resultValue = inputValue.countLeadingZeros();
2812 return llvm::APInt(inputValue.getBitWidth(), resultValue);
2813 },
2814 getPoisonZero());
2815}
2816
2817OpFoldResult BitCtzOp::fold(FoldAdaptor adaptor) {
2818 return foldUnaryBitOp(
2819 adaptor.getInput(),
2820 [](const llvm::APInt &inputValue) {
2821 return llvm::APInt(inputValue.getBitWidth(),
2822 inputValue.countTrailingZeros());
2823 },
2824 getPoisonZero());
2825}
2826
2827OpFoldResult BitFfsOp::fold(FoldAdaptor adaptor) {
2828 return foldUnaryBitOp(adaptor.getInput(), [](const llvm::APInt &inputValue) {
2829 unsigned trailingZeros = inputValue.countTrailingZeros();
2830 unsigned result =
2831 trailingZeros == inputValue.getBitWidth() ? 0 : trailingZeros + 1;
2832 return llvm::APInt(inputValue.getBitWidth(), result);
2833 });
2834}
2835
2836OpFoldResult BitParityOp::fold(FoldAdaptor adaptor) {
2837 return foldUnaryBitOp(adaptor.getInput(), [](const llvm::APInt &inputValue) {
2838 return llvm::APInt(inputValue.getBitWidth(), inputValue.popcount() % 2);
2839 });
2840}
2841
2842OpFoldResult BitPopcountOp::fold(FoldAdaptor adaptor) {
2843 return foldUnaryBitOp(adaptor.getInput(), [](const llvm::APInt &inputValue) {
2844 return llvm::APInt(inputValue.getBitWidth(), inputValue.popcount());
2845 });
2846}
2847
2848OpFoldResult BitReverseOp::fold(FoldAdaptor adaptor) {
2849 return foldUnaryBitOp(adaptor.getInput(), [](const llvm::APInt &inputValue) {
2850 return inputValue.reverseBits();
2851 });
2852}
2853
2854OpFoldResult ByteSwapOp::fold(FoldAdaptor adaptor) {
2855 return foldUnaryBitOp(adaptor.getInput(), [](const llvm::APInt &inputValue) {
2856 return inputValue.byteSwap();
2857 });
2858}
2859
2860OpFoldResult RotateOp::fold(FoldAdaptor adaptor) {
2861 if (mlir::isa_and_present<cir::PoisonAttr>(adaptor.getInput()) ||
2862 mlir::isa_and_present<cir::PoisonAttr>(adaptor.getAmount())) {
2863 // Propagate poison values
2864 return cir::PoisonAttr::get(getType());
2865 }
2866
2867 auto input = mlir::dyn_cast_if_present<IntAttr>(adaptor.getInput());
2868 auto amount = mlir::dyn_cast_if_present<IntAttr>(adaptor.getAmount());
2869 if (!input && !amount)
2870 return nullptr;
2871
2872 // We could fold cir.rotate even if one of its two operands is not a constant:
2873 // - `cir.rotate left/right %0, 0` could be folded into just %0 even if %0
2874 // is not a constant.
2875 // - `cir.rotate left/right 0/0b111...111, %0` could be folded into 0 or
2876 // 0b111...111 even if %0 is not a constant.
2877
2878 llvm::APInt inputValue;
2879 if (input) {
2880 inputValue = input.getValue();
2881 if (inputValue.isZero() || inputValue.isAllOnes()) {
2882 // An input value of all 0s or all 1s will not change after rotation
2883 return input;
2884 }
2885 }
2886
2887 uint64_t amountValue;
2888 if (amount) {
2889 amountValue = amount.getValue().urem(getInput().getType().getWidth());
2890 if (amountValue == 0) {
2891 // A shift amount of 0 will not change the input value
2892 return getInput();
2893 }
2894 }
2895
2896 if (!input || !amount)
2897 return nullptr;
2898
2899 assert(inputValue.getBitWidth() == getInput().getType().getWidth() &&
2900 "input value must have the same bit width as the input type");
2901
2902 llvm::APInt resultValue;
2903 if (isRotateLeft())
2904 resultValue = inputValue.rotl(amountValue);
2905 else
2906 resultValue = inputValue.rotr(amountValue);
2907
2908 return IntAttr::get(input.getContext(), input.getType(), resultValue);
2909}
2910
2911//===----------------------------------------------------------------------===//
2912// InlineAsmOp
2913//===----------------------------------------------------------------------===//
2914
2915void cir::InlineAsmOp::print(OpAsmPrinter &p) {
2916 p << '(' << getAsmFlavor() << ", ";
2917 p.increaseIndent();
2918 p.printNewline();
2919
2920 llvm::SmallVector<std::string, 3> names{"out", "in", "in_out"};
2921 auto *nameIt = names.begin();
2922 auto *attrIt = getOperandAttrs().begin();
2923
2924 for (mlir::OperandRange ops : getAsmOperands()) {
2925 p << *nameIt << " = ";
2926
2927 p << '[';
2928 llvm::interleaveComma(llvm::make_range(ops.begin(), ops.end()), p,
2929 [&](Value value) {
2930 p.printOperand(value);
2931 p << " : " << value.getType();
2932 if (*attrIt)
2933 p << " (maybe_memory)";
2934 attrIt++;
2935 });
2936 p << "],";
2937 p.printNewline();
2938 ++nameIt;
2939 }
2940
2941 p << "{";
2942 p.printString(getAsmString());
2943 p << " ";
2944 p.printString(getConstraints());
2945 p << "}";
2946 p.decreaseIndent();
2947 p << ')';
2948 if (getSideEffects())
2949 p << " side_effects";
2950
2951 std::array elidedAttrs{
2952 llvm::StringRef("asm_flavor"), llvm::StringRef("asm_string"),
2953 llvm::StringRef("constraints"), llvm::StringRef("operand_attrs"),
2954 llvm::StringRef("operands_segments"), llvm::StringRef("side_effects")};
2955 p.printOptionalAttrDict(getOperation()->getAttrs(), elidedAttrs);
2956
2957 if (auto v = getRes())
2958 p << " -> " << v.getType();
2959}
2960
2961void cir::InlineAsmOp::build(OpBuilder &odsBuilder, OperationState &odsState,
2962 ArrayRef<ValueRange> asmOperands,
2963 StringRef asmString, StringRef constraints,
2964 bool sideEffects, cir::AsmFlavor asmFlavor,
2965 ArrayRef<Attribute> operandAttrs) {
2966 // Set up the operands_segments for VariadicOfVariadic
2967 SmallVector<int32_t> segments;
2968 for (auto operandRange : asmOperands) {
2969 segments.push_back(operandRange.size());
2970 odsState.addOperands(operandRange);
2971 }
2972
2973 odsState.addAttribute(
2974 "operands_segments",
2975 DenseI32ArrayAttr::get(odsBuilder.getContext(), segments));
2976 odsState.addAttribute("asm_string", odsBuilder.getStringAttr(asmString));
2977 odsState.addAttribute("constraints", odsBuilder.getStringAttr(constraints));
2978 odsState.addAttribute("asm_flavor",
2979 AsmFlavorAttr::get(odsBuilder.getContext(), asmFlavor));
2980
2981 if (sideEffects)
2982 odsState.addAttribute("side_effects", odsBuilder.getUnitAttr());
2983
2984 odsState.addAttribute("operand_attrs", odsBuilder.getArrayAttr(operandAttrs));
2985}
2986
2987ParseResult cir::InlineAsmOp::parse(OpAsmParser &parser,
2988 OperationState &result) {
2990 llvm::SmallVector<int32_t> operandsGroupSizes;
2991 std::string asmString, constraints;
2992 Type resType;
2993 MLIRContext *ctxt = parser.getBuilder().getContext();
2994
2995 auto error = [&](const Twine &msg) -> LogicalResult {
2996 return parser.emitError(parser.getCurrentLocation(), msg);
2997 };
2998
2999 auto expected = [&](const std::string &c) {
3000 return error("expected '" + c + "'");
3001 };
3002
3003 if (parser.parseLParen().failed())
3004 return expected("(");
3005
3006 auto flavor = FieldParser<AsmFlavor, AsmFlavor>::parse(parser);
3007 if (failed(flavor))
3008 return error("Unknown AsmFlavor");
3009
3010 if (parser.parseComma().failed())
3011 return expected(",");
3012
3013 auto parseValue = [&](Value &v) {
3014 OpAsmParser::UnresolvedOperand op;
3015
3016 if (parser.parseOperand(op) || parser.parseColon())
3017 return error("can't parse operand");
3018
3019 Type typ;
3020 if (parser.parseType(typ).failed())
3021 return error("can't parse operand type");
3023 if (parser.resolveOperand(op, typ, tmp))
3024 return error("can't resolve operand");
3025 v = tmp[0];
3026 return mlir::success();
3027 };
3028
3029 auto parseOperands = [&](llvm::StringRef name) {
3030 if (parser.parseKeyword(name).failed())
3031 return error("expected " + name + " operands here");
3032 if (parser.parseEqual().failed())
3033 return expected("=");
3034 if (parser.parseLSquare().failed())
3035 return expected("[");
3036
3037 int size = 0;
3038 if (parser.parseOptionalRSquare().succeeded()) {
3039 operandsGroupSizes.push_back(size);
3040 if (parser.parseComma())
3041 return expected(",");
3042 return mlir::success();
3043 }
3044
3045 auto parseOperand = [&]() {
3046 Value val;
3047 if (parseValue(val).succeeded()) {
3048 result.operands.push_back(val);
3049 size++;
3050
3051 if (parser.parseOptionalLParen().failed()) {
3052 operandAttrs.push_back(mlir::Attribute());
3053 return mlir::success();
3054 }
3055
3056 if (parser.parseKeyword("maybe_memory").succeeded()) {
3057 operandAttrs.push_back(mlir::UnitAttr::get(ctxt));
3058 if (parser.parseRParen())
3059 return expected(")");
3060 return mlir::success();
3061 } else {
3062 return expected("maybe_memory");
3063 }
3064 }
3065 return mlir::failure();
3066 };
3067
3068 if (parser.parseCommaSeparatedList(parseOperand).failed())
3069 return mlir::failure();
3070
3071 if (parser.parseRSquare().failed() || parser.parseComma().failed())
3072 return expected("]");
3073 operandsGroupSizes.push_back(size);
3074 return mlir::success();
3075 };
3076
3077 if (parseOperands("out").failed() || parseOperands("in").failed() ||
3078 parseOperands("in_out").failed())
3079 return error("failed to parse operands");
3080
3081 if (parser.parseLBrace())
3082 return expected("{");
3083 if (parser.parseString(&asmString))
3084 return error("asm string parsing failed");
3085 if (parser.parseString(&constraints))
3086 return error("constraints string parsing failed");
3087 if (parser.parseRBrace())
3088 return expected("}");
3089 if (parser.parseRParen())
3090 return expected(")");
3091
3092 if (parser.parseOptionalKeyword("side_effects").succeeded())
3093 result.attributes.set("side_effects", UnitAttr::get(ctxt));
3094
3095 if (parser.parseOptionalArrow().succeeded() &&
3096 parser.parseType(resType).failed())
3097 return mlir::failure();
3098
3099 if (parser.parseOptionalAttrDict(result.attributes).failed())
3100 return mlir::failure();
3101
3102 result.attributes.set("asm_flavor", AsmFlavorAttr::get(ctxt, *flavor));
3103 result.attributes.set("asm_string", StringAttr::get(ctxt, asmString));
3104 result.attributes.set("constraints", StringAttr::get(ctxt, constraints));
3105 result.attributes.set("operand_attrs", ArrayAttr::get(ctxt, operandAttrs));
3106 result.getOrAddProperties<InlineAsmOp::Properties>().operands_segments =
3107 parser.getBuilder().getDenseI32ArrayAttr(operandsGroupSizes);
3108 if (resType)
3109 result.addTypes(TypeRange{resType});
3110
3111 return mlir::success();
3112}
3113
3114//===----------------------------------------------------------------------===//
3115// ThrowOp
3116//===----------------------------------------------------------------------===//
3117
3118mlir::LogicalResult cir::ThrowOp::verify() {
3119 // For the no-rethrow version, it must have at least the exception pointer.
3120 if (rethrows())
3121 return success();
3122
3123 if (getNumOperands() != 0) {
3124 if (getTypeInfo())
3125 return success();
3126 return emitOpError() << "'type_info' symbol attribute missing";
3127 }
3128
3129 return failure();
3130}
3131
3132//===----------------------------------------------------------------------===//
3133// AtomicFetchOp
3134//===----------------------------------------------------------------------===//
3135
3136LogicalResult cir::AtomicFetchOp::verify() {
3137 if (getBinop() != cir::AtomicFetchKind::Add &&
3138 getBinop() != cir::AtomicFetchKind::Sub &&
3139 getBinop() != cir::AtomicFetchKind::Max &&
3140 getBinop() != cir::AtomicFetchKind::Min &&
3141 !mlir::isa<cir::IntType>(getVal().getType()))
3142 return emitError("only atomic add, sub, max, and min operation could "
3143 "operate on floating-point values");
3144 return success();
3145}
3146
3147//===----------------------------------------------------------------------===//
3148// TypeInfoAttr
3149//===----------------------------------------------------------------------===//
3150
3151LogicalResult cir::TypeInfoAttr::verify(
3152 ::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError,
3153 ::mlir::Type type, ::mlir::ArrayAttr typeInfoData) {
3154
3155 if (cir::ConstRecordAttr::verify(emitError, type, typeInfoData).failed())
3156 return failure();
3157
3158 return success();
3159}
3160
3161//===----------------------------------------------------------------------===//
3162// TryOp
3163//===----------------------------------------------------------------------===//
3164
3165void cir::TryOp::getSuccessorRegions(
3166 mlir::RegionBranchPoint point,
3168 // The `try` and the `catchers` region branch back to the parent operation.
3169 if (!point.isParent()) {
3170 regions.push_back(
3171 RegionSuccessor(getOperation(), getOperation()->getResults()));
3172 return;
3173 }
3174
3175 regions.push_back(mlir::RegionSuccessor(&getTryRegion()));
3176
3177 // TODO(CIR): If we know a target function never throws a specific type, we
3178 // can remove the catch handler.
3179 for (mlir::Region &handlerRegion : this->getHandlerRegions())
3180 regions.push_back(mlir::RegionSuccessor(&handlerRegion));
3181}
3182
3183static void
3184printTryHandlerRegions(mlir::OpAsmPrinter &printer, cir::TryOp op,
3185 mlir::MutableArrayRef<mlir::Region> handlerRegions,
3186 mlir::ArrayAttr handlerTypes) {
3187 if (!handlerTypes)
3188 return;
3189
3190 for (const auto [typeIdx, typeAttr] : llvm::enumerate(handlerTypes)) {
3191 if (typeIdx)
3192 printer << " ";
3193
3194 if (mlir::isa<cir::CatchAllAttr>(typeAttr)) {
3195 printer << "catch all ";
3196 } else if (mlir::isa<cir::UnwindAttr>(typeAttr)) {
3197 printer << "unwind ";
3198 } else {
3199 printer << "catch [type ";
3200 printer.printAttribute(typeAttr);
3201 printer << "] ";
3202 }
3203
3204 printer.printRegion(handlerRegions[typeIdx],
3205 /*printEntryBLockArgs=*/false,
3206 /*printBlockTerminators=*/true);
3207 }
3208}
3209
3210static mlir::ParseResult parseTryHandlerRegions(
3211 mlir::OpAsmParser &parser,
3212 llvm::SmallVectorImpl<std::unique_ptr<mlir::Region>> &handlerRegions,
3213 mlir::ArrayAttr &handlerTypes) {
3214
3215 auto parseCheckedCatcherRegion = [&]() -> mlir::ParseResult {
3216 handlerRegions.emplace_back(new mlir::Region);
3217
3218 mlir::Region &currRegion = *handlerRegions.back();
3219 mlir::SMLoc regionLoc = parser.getCurrentLocation();
3220 if (parser.parseRegion(currRegion)) {
3221 handlerRegions.clear();
3222 return failure();
3223 }
3224
3225 if (currRegion.empty())
3226 return parser.emitError(regionLoc, "handler region shall not be empty");
3227
3228 if (!(currRegion.back().mightHaveTerminator() &&
3229 currRegion.back().getTerminator()))
3230 return parser.emitError(
3231 regionLoc, "blocks are expected to be explicitly terminated");
3232
3233 return success();
3234 };
3235
3236 bool hasCatchAll = false;
3238 while (parser.parseOptionalKeyword("catch").succeeded()) {
3239 bool hasLSquare = parser.parseOptionalLSquare().succeeded();
3240
3241 llvm::StringRef attrStr;
3242 if (parser.parseOptionalKeyword(&attrStr, {"all", "type"}).failed())
3243 return parser.emitError(parser.getCurrentLocation(),
3244 "expected 'all' or 'type' keyword");
3245
3246 bool isCatchAll = attrStr == "all";
3247 if (isCatchAll) {
3248 if (hasCatchAll)
3249 return parser.emitError(parser.getCurrentLocation(),
3250 "can't have more than one catch all");
3251 hasCatchAll = true;
3252 }
3253
3254 mlir::Attribute exceptionRTTIAttr;
3255 if (!isCatchAll && parser.parseAttribute(exceptionRTTIAttr).failed())
3256 return parser.emitError(parser.getCurrentLocation(),
3257 "expected valid RTTI info attribute");
3258
3259 catcherAttrs.push_back(isCatchAll
3260 ? cir::CatchAllAttr::get(parser.getContext())
3261 : exceptionRTTIAttr);
3262
3263 if (hasLSquare && isCatchAll)
3264 return parser.emitError(parser.getCurrentLocation(),
3265 "catch all dosen't need RTTI info attribute");
3266
3267 if (hasLSquare && parser.parseRSquare().failed())
3268 return parser.emitError(parser.getCurrentLocation(),
3269 "expected `]` after RTTI info attribute");
3270
3271 if (parseCheckedCatcherRegion().failed())
3272 return mlir::failure();
3273 }
3274
3275 if (parser.parseOptionalKeyword("unwind").succeeded()) {
3276 if (hasCatchAll)
3277 return parser.emitError(parser.getCurrentLocation(),
3278 "unwind can't be used with catch all");
3279
3280 catcherAttrs.push_back(cir::UnwindAttr::get(parser.getContext()));
3281 if (parseCheckedCatcherRegion().failed())
3282 return mlir::failure();
3283 }
3284
3285 handlerTypes = parser.getBuilder().getArrayAttr(catcherAttrs);
3286 return mlir::success();
3287}
3288
3289//===----------------------------------------------------------------------===//
3290// TableGen'd op method definitions
3291//===----------------------------------------------------------------------===//
3292
3293#define GET_OP_CLASSES
3294#include "clang/CIR/Dialect/IR/CIROps.cpp.inc"
static const MemRegion * getRegion(const CallEvent &Call, const MutexDescriptor &Descriptor, bool IsLock)
static mlir::LogicalResult checkReturnAndFunction(cir::ReturnOp op, cir::FuncOp function)
static bool isBoolNot(cir::UnaryOp op)
static bool isIntOrBoolCast(cir::CastOp op)
static void printConstant(OpAsmPrinter &p, Attribute value)
static mlir::ParseResult parseOmittedTerminatorRegion(mlir::OpAsmParser &parser, mlir::Region &region)
void printVisibilityAttr(OpAsmPrinter &printer, cir::VisibilityAttr &visibility)
static ParseResult parseSwitchFlatOpCases(OpAsmParser &parser, Type flagType, mlir::ArrayAttr &caseValues, SmallVectorImpl< Block * > &caseDestinations, SmallVectorImpl< llvm::SmallVector< OpAsmParser::UnresolvedOperand > > &caseOperands, SmallVectorImpl< llvm::SmallVector< Type > > &caseOperandTypes)
<cases> ::= [ (case (, case )* )?
static LogicalResult verifyCallCommInSymbolUses(mlir::Operation *op, SymbolTableCollection &symbolTable)
static LogicalResult ensureRegionTerm(OpAsmParser &parser, Region &region, SMLoc errLoc)
static ParseResult parseGlobalOpTypeAndInitialValue(OpAsmParser &parser, TypeAttr &typeAttr, Attribute &initialValueAttr, mlir::Region &ctorRegion, mlir::Region &dtorRegion)
static OpFoldResult foldUnaryBitOp(mlir::Attribute inputAttr, llvm::function_ref< llvm::APInt(const llvm::APInt &)> func, bool poisonZero=false)
static llvm::StringRef getLinkageAttrNameString()
Returns the name used for the linkage attribute.
static RetTy parseOptionalCIRKeyword(AsmParser &parser, EnumTy defaultValue)
Parse an enum from the keyword, or default to the provided default value.
static void printSwitchFlatOpCases(OpAsmPrinter &p, cir::SwitchFlatOp op, Type flagType, mlir::ArrayAttr caseValues, SuccessorRange caseDestinations, OperandRangeRange caseOperands, const TypeRangeRange &caseOperandTypes)
static void printSwitchOp(OpAsmPrinter &p, cir::SwitchOp op, mlir::Region &bodyRegion, mlir::Value condition, mlir::Type condType)
static void printGlobalOpTypeAndInitialValue(OpAsmPrinter &p, cir::GlobalOp op, TypeAttr type, Attribute initAttr, mlir::Region &ctorRegion, mlir::Region &dtorRegion)
static ParseResult parseCIRKeyword(AsmParser &parser, RetTy &result)
Parse an enum from the keyword, return failure if the keyword is not found.
static Value tryFoldCastChain(cir::CastOp op)
void parseVisibilityAttr(OpAsmParser &parser, cir::VisibilityAttr &visibility)
static void printTryHandlerRegions(mlir::OpAsmPrinter &printer, cir::TryOp op, mlir::MutableArrayRef< mlir::Region > handlerRegions, mlir::ArrayAttr handlerTypes)
static bool omitRegionTerm(mlir::Region &r)
static void printOmittedTerminatorRegion(mlir::OpAsmPrinter &printer, cir::ScopeOp &op, mlir::Region &region)
static ParseResult parseConstantValue(OpAsmParser &parser, mlir::Attribute &valueAttr)
static void printCallCommon(mlir::Operation *op, mlir::FlatSymbolRefAttr calleeSym, mlir::Value indirectCallee, mlir::OpAsmPrinter &printer, bool isNothrow, cir::SideEffect sideEffect)
static LogicalResult checkConstantTypes(mlir::Operation *op, mlir::Type opType, mlir::Attribute attrType)
static ParseResult parseSwitchOp(OpAsmParser &parser, mlir::Region &regions, mlir::OpAsmParser::UnresolvedOperand &cond, mlir::Type &condType)
static mlir::ParseResult parseCallCommon(mlir::OpAsmParser &parser, mlir::OperationState &result)
static mlir::ParseResult parseTryHandlerRegions(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< std::unique_ptr< mlir::Region > > &handlerRegions, mlir::ArrayAttr &handlerTypes)
#define REGISTER_ENUM_TYPE(Ty)
static int parseOptionalKeywordAlternative(AsmParser &parser, ArrayRef< llvm::StringRef > keywords)
llvm::function_ref< void(mlir::OpBuilder &, mlir::Location)> BuilderCallbackRef
Definition CIRDialect.h:37
llvm::function_ref< void( mlir::OpBuilder &, mlir::Location, mlir::OperationState &)> BuilderOpStateCallbackRef
Definition CIRDialect.h:39
static std::optional< NonLoc > getIndex(ProgramStateRef State, const ElementRegion *ER, CharKind CK)
static Decl::Kind getKind(const Decl *D)
TokenType getType() const
Returns the token's type, e.g.
__device__ __2f16 float c
void buildTerminatedBody(mlir::OpBuilder &builder, mlir::Location loc)
const internal::VariadicAllOfMatcher< Attr > attr
const AstTypeMatcher< RecordType > recordType
StringRef getName(const HeaderType T)
Definition HeaderFile.h:38
RangeSelector name(std::string ID)
Given a node with a "name", (like NamedDecl, DeclRefExpr, CxxCtorInitializer, and TypeLoc) selects th...
nullptr
This class represents a compute construct, representing a 'Kind' of ‘parallel’, 'serial',...
__DEVICE__ _Tp arg(const std::complex< _Tp > &__c)
static bool addressSpace()
static bool opGlobalThreadLocal()
static bool opCallCallConv()
static bool opScopeCleanupRegion()
static bool supportIFuncAttr()