clang 22.0.0git
CIRDialect.cpp
Go to the documentation of this file.
1//===- CIRDialect.cpp - MLIR CIR ops implementation -----------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements the CIR dialect and its operations.
10//
11//===----------------------------------------------------------------------===//
12
14
17
18#include "mlir/Interfaces/ControlFlowInterfaces.h"
19#include "mlir/Interfaces/FunctionImplementation.h"
20#include "mlir/Support/LLVM.h"
21
22#include "clang/CIR/Dialect/IR/CIROpsDialect.cpp.inc"
23#include "clang/CIR/Dialect/IR/CIROpsEnums.cpp.inc"
25#include "llvm/ADT/SetOperations.h"
26#include "llvm/ADT/SmallSet.h"
27#include "llvm/Support/LogicalResult.h"
28
29using namespace mlir;
30using namespace cir;
31
32//===----------------------------------------------------------------------===//
33// CIR Dialect
34//===----------------------------------------------------------------------===//
35namespace {
36struct CIROpAsmDialectInterface : public OpAsmDialectInterface {
37 using OpAsmDialectInterface::OpAsmDialectInterface;
38
39 AliasResult getAlias(Type type, raw_ostream &os) const final {
40 if (auto recordType = dyn_cast<cir::RecordType>(type)) {
41 StringAttr nameAttr = recordType.getName();
42 if (!nameAttr)
43 os << "rec_anon_" << recordType.getKindAsStr();
44 else
45 os << "rec_" << nameAttr.getValue();
46 return AliasResult::OverridableAlias;
47 }
48 if (auto intType = dyn_cast<cir::IntType>(type)) {
49 // We only provide alias for standard integer types (i.e. integer types
50 // whose width is a power of 2 and at least 8).
51 unsigned width = intType.getWidth();
52 if (width < 8 || !llvm::isPowerOf2_32(width))
53 return AliasResult::NoAlias;
54 os << intType.getAlias();
55 return AliasResult::OverridableAlias;
56 }
57 if (auto voidType = dyn_cast<cir::VoidType>(type)) {
58 os << voidType.getAlias();
59 return AliasResult::OverridableAlias;
60 }
61
62 return AliasResult::NoAlias;
63 }
64
65 AliasResult getAlias(Attribute attr, raw_ostream &os) const final {
66 if (auto boolAttr = mlir::dyn_cast<cir::BoolAttr>(attr)) {
67 os << (boolAttr.getValue() ? "true" : "false");
68 return AliasResult::FinalAlias;
69 }
70 if (auto bitfield = mlir::dyn_cast<cir::BitfieldInfoAttr>(attr)) {
71 os << "bfi_" << bitfield.getName().str();
72 return AliasResult::FinalAlias;
73 }
74 return AliasResult::NoAlias;
75 }
76};
77} // namespace
78
/// Dialect initialization hook: registers the dialect's types, attributes and
/// operations, then attaches the asm interface that provides type/attribute
/// aliases (CIROpAsmDialectInterface above).
void cir::CIRDialect::initialize() {
  registerTypes();
  registerAttributes();
  // The op list is expanded from the generated CIROps.cpp.inc via
  // GET_OP_LIST.
  addOperations<
#define GET_OP_LIST
#include "clang/CIR/Dialect/IR/CIROps.cpp.inc"
      >();
  addInterfaces<CIROpAsmDialectInterface>();
}
88
89Operation *cir::CIRDialect::materializeConstant(mlir::OpBuilder &builder,
90 mlir::Attribute value,
91 mlir::Type type,
92 mlir::Location loc) {
93 return builder.create<cir::ConstantOp>(loc, type,
94 mlir::cast<mlir::TypedAttr>(value));
95}
96
97//===----------------------------------------------------------------------===//
98// Helpers
99//===----------------------------------------------------------------------===//
100
101// Parses one of the keywords provided in the list `keywords` and returns the
102// position of the parsed keyword in the list. If none of the keywords from the
103// list is parsed, returns -1.
104static int parseOptionalKeywordAlternative(AsmParser &parser,
105 ArrayRef<llvm::StringRef> keywords) {
106 for (auto en : llvm::enumerate(keywords)) {
107 if (succeeded(parser.parseOptionalKeyword(en.value())))
108 return en.index();
109 }
110 return -1;
111}
112
namespace {
/// Per-enum traits used by the generic keyword parsers below: a `stringify`
/// function and the largest enumerator value. The primary template is
/// intentionally empty; only the specializations generated by
/// REGISTER_ENUM_TYPE are usable.
template <typename Ty> struct EnumTraits {};

// Generates EnumTraits<cir::Ty>, forwarding to `stringify<Ty>` and
// `cir::getMaxEnumValFor<Ty>` (presumably the ODS-generated helpers from
// CIROpsEnums.cpp.inc -- confirm if the enum set changes).
#define REGISTER_ENUM_TYPE(Ty) \
  template <> struct EnumTraits<cir::Ty> { \
    static llvm::StringRef stringify(cir::Ty value) { \
      return stringify##Ty(value); \
    } \
    static unsigned getMaxEnumVal() { return cir::getMaxEnumValFor##Ty(); } \
  }

REGISTER_ENUM_TYPE(GlobalLinkageKind);
REGISTER_ENUM_TYPE(VisibilityKind);
REGISTER_ENUM_TYPE(SideEffect);
} // namespace
128
129/// Parse an enum from the keyword, or default to the provided default value.
130/// The return type is the enum type by default, unless overriden with the
131/// second template argument.
132template <typename EnumTy, typename RetTy = EnumTy>
133static RetTy parseOptionalCIRKeyword(AsmParser &parser, EnumTy defaultValue) {
135 for (unsigned i = 0, e = EnumTraits<EnumTy>::getMaxEnumVal(); i <= e; ++i)
136 names.push_back(EnumTraits<EnumTy>::stringify(static_cast<EnumTy>(i)));
137
138 int index = parseOptionalKeywordAlternative(parser, names);
139 if (index == -1)
140 return static_cast<RetTy>(defaultValue);
141 return static_cast<RetTy>(index);
142}
143
144/// Parse an enum from the keyword, return failure if the keyword is not found.
145template <typename EnumTy, typename RetTy = EnumTy>
146static ParseResult parseCIRKeyword(AsmParser &parser, RetTy &result) {
148 for (unsigned i = 0, e = EnumTraits<EnumTy>::getMaxEnumVal(); i <= e; ++i)
149 names.push_back(EnumTraits<EnumTy>::stringify(static_cast<EnumTy>(i)));
150
151 int index = parseOptionalKeywordAlternative(parser, names);
152 if (index == -1)
153 return failure();
154 result = static_cast<RetTy>(index);
155 return success();
156}
157
158// Check if a region's termination omission is valid and, if so, creates and
159// inserts the omitted terminator into the region.
160static LogicalResult ensureRegionTerm(OpAsmParser &parser, Region &region,
161 SMLoc errLoc) {
162 Location eLoc = parser.getEncodedSourceLoc(parser.getCurrentLocation());
163 OpBuilder builder(parser.getBuilder().getContext());
164
165 // Insert empty block in case the region is empty to ensure the terminator
166 // will be inserted
167 if (region.empty())
168 builder.createBlock(&region);
169
170 Block &block = region.back();
171 // Region is properly terminated: nothing to do.
172 if (!block.empty() && block.back().hasTrait<OpTrait::IsTerminator>())
173 return success();
174
175 // Check for invalid terminator omissions.
176 if (!region.hasOneBlock())
177 return parser.emitError(errLoc,
178 "multi-block region must not omit terminator");
179
180 // Terminator was omitted correctly: recreate it.
181 builder.setInsertionPointToEnd(&block);
182 builder.create<cir::YieldOp>(eLoc);
183 return success();
184}
185
186// True if the region's terminator should be omitted.
187static bool omitRegionTerm(mlir::Region &r) {
188 const auto singleNonEmptyBlock = r.hasOneBlock() && !r.back().empty();
189 const auto yieldsNothing = [&r]() {
190 auto y = dyn_cast<cir::YieldOp>(r.back().getTerminator());
191 return y && y.getArgs().empty();
192 };
193 return singleNonEmptyBlock && yieldsNothing();
194}
195
196void printVisibilityAttr(OpAsmPrinter &printer,
197 cir::VisibilityAttr &visibility) {
198 switch (visibility.getValue()) {
199 case cir::VisibilityKind::Hidden:
200 printer << "hidden";
201 break;
202 case cir::VisibilityKind::Protected:
203 printer << "protected";
204 break;
205 case cir::VisibilityKind::Default:
206 break;
207 }
208}
209
210void parseVisibilityAttr(OpAsmParser &parser, cir::VisibilityAttr &visibility) {
211 cir::VisibilityKind visibilityKind =
212 parseOptionalCIRKeyword(parser, cir::VisibilityKind::Default);
213 visibility = cir::VisibilityAttr::get(parser.getContext(), visibilityKind);
214}
215
216//===----------------------------------------------------------------------===//
217// CIR Custom Parsers/Printers
218//===----------------------------------------------------------------------===//
219
220static mlir::ParseResult parseOmittedTerminatorRegion(mlir::OpAsmParser &parser,
221 mlir::Region &region) {
222 auto regionLoc = parser.getCurrentLocation();
223 if (parser.parseRegion(region))
224 return failure();
225 if (ensureRegionTerm(parser, region, regionLoc).failed())
226 return failure();
227 return success();
228}
229
230static void printOmittedTerminatorRegion(mlir::OpAsmPrinter &printer,
231 cir::ScopeOp &op,
232 mlir::Region &region) {
233 printer.printRegion(region,
234 /*printEntryBlockArgs=*/false,
235 /*printBlockTerminators=*/!omitRegionTerm(region));
236}
237
238//===----------------------------------------------------------------------===//
239// AllocaOp
240//===----------------------------------------------------------------------===//
241
242void cir::AllocaOp::build(mlir::OpBuilder &odsBuilder,
243 mlir::OperationState &odsState, mlir::Type addr,
244 mlir::Type allocaType, llvm::StringRef name,
245 mlir::IntegerAttr alignment) {
246 odsState.addAttribute(getAllocaTypeAttrName(odsState.name),
247 mlir::TypeAttr::get(allocaType));
248 odsState.addAttribute(getNameAttrName(odsState.name),
249 odsBuilder.getStringAttr(name));
250 if (alignment) {
251 odsState.addAttribute(getAlignmentAttrName(odsState.name), alignment);
252 }
253 odsState.addTypes(addr);
254}
255
256//===----------------------------------------------------------------------===//
257// BreakOp
258//===----------------------------------------------------------------------===//
259
260LogicalResult cir::BreakOp::verify() {
262 if (!getOperation()->getParentOfType<LoopOpInterface>() &&
263 !getOperation()->getParentOfType<SwitchOp>())
264 return emitOpError("must be within a loop");
265 return success();
266}
267
268//===----------------------------------------------------------------------===//
269// ConditionOp
270//===----------------------------------------------------------------------===//
271
272//===----------------------------------
273// BranchOpTerminatorInterface Methods
274//===----------------------------------
275
276void cir::ConditionOp::getSuccessorRegions(
277 ArrayRef<Attribute> operands, SmallVectorImpl<RegionSuccessor> &regions) {
278 // TODO(cir): The condition value may be folded to a constant, narrowing
279 // down its list of possible successors.
280
281 // Parent is a loop: condition may branch to the body or to the parent op.
282 if (auto loopOp = dyn_cast<LoopOpInterface>(getOperation()->getParentOp())) {
283 regions.emplace_back(&loopOp.getBody(), loopOp.getBody().getArguments());
284 regions.emplace_back(loopOp->getResults());
285 }
286
288}
289
290MutableOperandRange
291cir::ConditionOp::getMutableSuccessorOperands(RegionBranchPoint point) {
292 // No values are yielded to the successor region.
293 return MutableOperandRange(getOperation(), 0, 0);
294}
295
296LogicalResult cir::ConditionOp::verify() {
298 if (!isa<LoopOpInterface>(getOperation()->getParentOp()))
299 return emitOpError("condition must be within a conditional region");
300 return success();
301}
302
303//===----------------------------------------------------------------------===//
304// ConstantOp
305//===----------------------------------------------------------------------===//
306
307static LogicalResult checkConstantTypes(mlir::Operation *op, mlir::Type opType,
308 mlir::Attribute attrType) {
309 if (isa<cir::ConstPtrAttr>(attrType)) {
310 if (!mlir::isa<cir::PointerType>(opType))
311 return op->emitOpError(
312 "pointer constant initializing a non-pointer type");
313 return success();
314 }
315
316 if (isa<cir::ZeroAttr>(attrType)) {
317 if (isa<cir::RecordType, cir::ArrayType, cir::VectorType, cir::ComplexType>(
318 opType))
319 return success();
320 return op->emitOpError(
321 "zero expects struct, array, vector, or complex type");
322 }
323
324 if (mlir::isa<cir::BoolAttr>(attrType)) {
325 if (!mlir::isa<cir::BoolType>(opType))
326 return op->emitOpError("result type (")
327 << opType << ") must be '!cir.bool' for '" << attrType << "'";
328 return success();
329 }
330
331 if (mlir::isa<cir::IntAttr, cir::FPAttr>(attrType)) {
332 auto at = cast<TypedAttr>(attrType);
333 if (at.getType() != opType) {
334 return op->emitOpError("result type (")
335 << opType << ") does not match value type (" << at.getType()
336 << ")";
337 }
338 return success();
339 }
340
341 if (mlir::isa<cir::ConstArrayAttr, cir::ConstVectorAttr,
342 cir::ConstComplexAttr, cir::ConstRecordAttr,
343 cir::GlobalViewAttr, cir::PoisonAttr, cir::TypeInfoAttr,
344 cir::VTableAttr>(attrType))
345 return success();
346
347 assert(isa<TypedAttr>(attrType) && "What else could we be looking at here?");
348 return op->emitOpError("global with type ")
349 << cast<TypedAttr>(attrType).getType() << " not yet supported";
350}
351
352LogicalResult cir::ConstantOp::verify() {
353 // ODS already generates checks to make sure the result type is valid. We just
354 // need to additionally check that the value's attribute type is consistent
355 // with the result type.
356 return checkConstantTypes(getOperation(), getType(), getValue());
357}
358
359OpFoldResult cir::ConstantOp::fold(FoldAdaptor /*adaptor*/) {
360 return getValue();
361}
362
363//===----------------------------------------------------------------------===//
364// ContinueOp
365//===----------------------------------------------------------------------===//
366
367LogicalResult cir::ContinueOp::verify() {
368 if (!getOperation()->getParentOfType<LoopOpInterface>())
369 return emitOpError("must be within a loop");
370 return success();
371}
372
373//===----------------------------------------------------------------------===//
374// CastOp
375//===----------------------------------------------------------------------===//
376
377LogicalResult cir::CastOp::verify() {
378 mlir::Type resType = getType();
379 mlir::Type srcType = getSrc().getType();
380
381 if (mlir::isa<cir::VectorType>(srcType) &&
382 mlir::isa<cir::VectorType>(resType)) {
383 // Use the element type of the vector to verify the cast kind. (Except for
384 // bitcast, see below.)
385 srcType = mlir::dyn_cast<cir::VectorType>(srcType).getElementType();
386 resType = mlir::dyn_cast<cir::VectorType>(resType).getElementType();
387 }
388
389 switch (getKind()) {
390 case cir::CastKind::int_to_bool: {
391 if (!mlir::isa<cir::BoolType>(resType))
392 return emitOpError() << "requires !cir.bool type for result";
393 if (!mlir::isa<cir::IntType>(srcType))
394 return emitOpError() << "requires !cir.int type for source";
395 return success();
396 }
397 case cir::CastKind::ptr_to_bool: {
398 if (!mlir::isa<cir::BoolType>(resType))
399 return emitOpError() << "requires !cir.bool type for result";
400 if (!mlir::isa<cir::PointerType>(srcType))
401 return emitOpError() << "requires !cir.ptr type for source";
402 return success();
403 }
404 case cir::CastKind::integral: {
405 if (!mlir::isa<cir::IntType>(resType))
406 return emitOpError() << "requires !cir.int type for result";
407 if (!mlir::isa<cir::IntType>(srcType))
408 return emitOpError() << "requires !cir.int type for source";
409 return success();
410 }
411 case cir::CastKind::array_to_ptrdecay: {
412 const auto arrayPtrTy = mlir::dyn_cast<cir::PointerType>(srcType);
413 const auto flatPtrTy = mlir::dyn_cast<cir::PointerType>(resType);
414 if (!arrayPtrTy || !flatPtrTy)
415 return emitOpError() << "requires !cir.ptr type for source and result";
416
417 // TODO(CIR): Make sure the AddrSpace of both types are equals
418 return success();
419 }
420 case cir::CastKind::bitcast: {
421 // Handle the pointer types first.
422 auto srcPtrTy = mlir::dyn_cast<cir::PointerType>(srcType);
423 auto resPtrTy = mlir::dyn_cast<cir::PointerType>(resType);
424
425 if (srcPtrTy && resPtrTy) {
426 return success();
427 }
428
429 return success();
430 }
431 case cir::CastKind::floating: {
432 if (!mlir::isa<cir::FPTypeInterface>(srcType) ||
433 !mlir::isa<cir::FPTypeInterface>(resType))
434 return emitOpError() << "requires !cir.float type for source and result";
435 return success();
436 }
437 case cir::CastKind::float_to_int: {
438 if (!mlir::isa<cir::FPTypeInterface>(srcType))
439 return emitOpError() << "requires !cir.float type for source";
440 if (!mlir::dyn_cast<cir::IntType>(resType))
441 return emitOpError() << "requires !cir.int type for result";
442 return success();
443 }
444 case cir::CastKind::int_to_ptr: {
445 if (!mlir::dyn_cast<cir::IntType>(srcType))
446 return emitOpError() << "requires !cir.int type for source";
447 if (!mlir::dyn_cast<cir::PointerType>(resType))
448 return emitOpError() << "requires !cir.ptr type for result";
449 return success();
450 }
451 case cir::CastKind::ptr_to_int: {
452 if (!mlir::dyn_cast<cir::PointerType>(srcType))
453 return emitOpError() << "requires !cir.ptr type for source";
454 if (!mlir::dyn_cast<cir::IntType>(resType))
455 return emitOpError() << "requires !cir.int type for result";
456 return success();
457 }
458 case cir::CastKind::float_to_bool: {
459 if (!mlir::isa<cir::FPTypeInterface>(srcType))
460 return emitOpError() << "requires !cir.float type for source";
461 if (!mlir::isa<cir::BoolType>(resType))
462 return emitOpError() << "requires !cir.bool type for result";
463 return success();
464 }
465 case cir::CastKind::bool_to_int: {
466 if (!mlir::isa<cir::BoolType>(srcType))
467 return emitOpError() << "requires !cir.bool type for source";
468 if (!mlir::isa<cir::IntType>(resType))
469 return emitOpError() << "requires !cir.int type for result";
470 return success();
471 }
472 case cir::CastKind::int_to_float: {
473 if (!mlir::isa<cir::IntType>(srcType))
474 return emitOpError() << "requires !cir.int type for source";
475 if (!mlir::isa<cir::FPTypeInterface>(resType))
476 return emitOpError() << "requires !cir.float type for result";
477 return success();
478 }
479 case cir::CastKind::bool_to_float: {
480 if (!mlir::isa<cir::BoolType>(srcType))
481 return emitOpError() << "requires !cir.bool type for source";
482 if (!mlir::isa<cir::FPTypeInterface>(resType))
483 return emitOpError() << "requires !cir.float type for result";
484 return success();
485 }
486 case cir::CastKind::address_space: {
487 auto srcPtrTy = mlir::dyn_cast<cir::PointerType>(srcType);
488 auto resPtrTy = mlir::dyn_cast<cir::PointerType>(resType);
489 if (!srcPtrTy || !resPtrTy)
490 return emitOpError() << "requires !cir.ptr type for source and result";
491 if (srcPtrTy.getPointee() != resPtrTy.getPointee())
492 return emitOpError() << "requires two types differ in addrspace only";
493 return success();
494 }
495 case cir::CastKind::float_to_complex: {
496 if (!mlir::isa<cir::FPTypeInterface>(srcType))
497 return emitOpError() << "requires !cir.float type for source";
498 auto resComplexTy = mlir::dyn_cast<cir::ComplexType>(resType);
499 if (!resComplexTy)
500 return emitOpError() << "requires !cir.complex type for result";
501 if (srcType != resComplexTy.getElementType())
502 return emitOpError() << "requires source type match result element type";
503 return success();
504 }
505 case cir::CastKind::int_to_complex: {
506 if (!mlir::isa<cir::IntType>(srcType))
507 return emitOpError() << "requires !cir.int type for source";
508 auto resComplexTy = mlir::dyn_cast<cir::ComplexType>(resType);
509 if (!resComplexTy)
510 return emitOpError() << "requires !cir.complex type for result";
511 if (srcType != resComplexTy.getElementType())
512 return emitOpError() << "requires source type match result element type";
513 return success();
514 }
515 case cir::CastKind::float_complex_to_real: {
516 auto srcComplexTy = mlir::dyn_cast<cir::ComplexType>(srcType);
517 if (!srcComplexTy)
518 return emitOpError() << "requires !cir.complex type for source";
519 if (!mlir::isa<cir::FPTypeInterface>(resType))
520 return emitOpError() << "requires !cir.float type for result";
521 if (srcComplexTy.getElementType() != resType)
522 return emitOpError() << "requires source element type match result type";
523 return success();
524 }
525 case cir::CastKind::int_complex_to_real: {
526 auto srcComplexTy = mlir::dyn_cast<cir::ComplexType>(srcType);
527 if (!srcComplexTy)
528 return emitOpError() << "requires !cir.complex type for source";
529 if (!mlir::isa<cir::IntType>(resType))
530 return emitOpError() << "requires !cir.int type for result";
531 if (srcComplexTy.getElementType() != resType)
532 return emitOpError() << "requires source element type match result type";
533 return success();
534 }
535 case cir::CastKind::float_complex_to_bool: {
536 auto srcComplexTy = mlir::dyn_cast<cir::ComplexType>(srcType);
537 if (!srcComplexTy || !srcComplexTy.isFloatingPointComplex())
538 return emitOpError()
539 << "requires floating point !cir.complex type for source";
540 if (!mlir::isa<cir::BoolType>(resType))
541 return emitOpError() << "requires !cir.bool type for result";
542 return success();
543 }
544 case cir::CastKind::int_complex_to_bool: {
545 auto srcComplexTy = mlir::dyn_cast<cir::ComplexType>(srcType);
546 if (!srcComplexTy || !srcComplexTy.isIntegerComplex())
547 return emitOpError()
548 << "requires floating point !cir.complex type for source";
549 if (!mlir::isa<cir::BoolType>(resType))
550 return emitOpError() << "requires !cir.bool type for result";
551 return success();
552 }
553 case cir::CastKind::float_complex: {
554 auto srcComplexTy = mlir::dyn_cast<cir::ComplexType>(srcType);
555 if (!srcComplexTy || !srcComplexTy.isFloatingPointComplex())
556 return emitOpError()
557 << "requires floating point !cir.complex type for source";
558 auto resComplexTy = mlir::dyn_cast<cir::ComplexType>(resType);
559 if (!resComplexTy || !resComplexTy.isFloatingPointComplex())
560 return emitOpError()
561 << "requires floating point !cir.complex type for result";
562 return success();
563 }
564 case cir::CastKind::float_complex_to_int_complex: {
565 auto srcComplexTy = mlir::dyn_cast<cir::ComplexType>(srcType);
566 if (!srcComplexTy || !srcComplexTy.isFloatingPointComplex())
567 return emitOpError()
568 << "requires floating point !cir.complex type for source";
569 auto resComplexTy = mlir::dyn_cast<cir::ComplexType>(resType);
570 if (!resComplexTy || !resComplexTy.isIntegerComplex())
571 return emitOpError() << "requires integer !cir.complex type for result";
572 return success();
573 }
574 case cir::CastKind::int_complex: {
575 auto srcComplexTy = mlir::dyn_cast<cir::ComplexType>(srcType);
576 if (!srcComplexTy || !srcComplexTy.isIntegerComplex())
577 return emitOpError() << "requires integer !cir.complex type for source";
578 auto resComplexTy = mlir::dyn_cast<cir::ComplexType>(resType);
579 if (!resComplexTy || !resComplexTy.isIntegerComplex())
580 return emitOpError() << "requires integer !cir.complex type for result";
581 return success();
582 }
583 case cir::CastKind::int_complex_to_float_complex: {
584 auto srcComplexTy = mlir::dyn_cast<cir::ComplexType>(srcType);
585 if (!srcComplexTy || !srcComplexTy.isIntegerComplex())
586 return emitOpError() << "requires integer !cir.complex type for source";
587 auto resComplexTy = mlir::dyn_cast<cir::ComplexType>(resType);
588 if (!resComplexTy || !resComplexTy.isFloatingPointComplex())
589 return emitOpError()
590 << "requires floating point !cir.complex type for result";
591 return success();
592 }
593 default:
594 llvm_unreachable("Unknown CastOp kind?");
595 }
596}
597
598static bool isIntOrBoolCast(cir::CastOp op) {
599 auto kind = op.getKind();
600 return kind == cir::CastKind::bool_to_int ||
601 kind == cir::CastKind::int_to_bool || kind == cir::CastKind::integral;
602}
603
604static Value tryFoldCastChain(cir::CastOp op) {
605 cir::CastOp head = op, tail = op;
606
607 while (op) {
608 if (!isIntOrBoolCast(op))
609 break;
610 head = op;
611 op = head.getSrc().getDefiningOp<cir::CastOp>();
612 }
613
614 if (head == tail)
615 return {};
616
617 // if bool_to_int -> ... -> int_to_bool: take the bool
618 // as we had it was before all casts
619 if (head.getKind() == cir::CastKind::bool_to_int &&
620 tail.getKind() == cir::CastKind::int_to_bool)
621 return head.getSrc();
622
623 // if int_to_bool -> ... -> int_to_bool: take the result
624 // of the first one, as no other casts (and ext casts as well)
625 // don't change the first result
626 if (head.getKind() == cir::CastKind::int_to_bool &&
627 tail.getKind() == cir::CastKind::int_to_bool)
628 return head.getResult();
629
630 return {};
631}
632
633OpFoldResult cir::CastOp::fold(FoldAdaptor adaptor) {
634 if (mlir::isa_and_present<cir::PoisonAttr>(adaptor.getSrc())) {
635 // Propagate poison value
636 return cir::PoisonAttr::get(getContext(), getType());
637 }
638
639 if (getSrc().getType() == getType()) {
640 switch (getKind()) {
641 case cir::CastKind::integral: {
642 // TODO: for sign differences, it's possible in certain conditions to
643 // create a new attribute that's capable of representing the source.
645 auto foldOrder = getSrc().getDefiningOp()->fold(foldResults);
646 if (foldOrder.succeeded() && mlir::isa<mlir::Attribute>(foldResults[0]))
647 return mlir::cast<mlir::Attribute>(foldResults[0]);
648 return {};
649 }
650 case cir::CastKind::bitcast:
651 case cir::CastKind::address_space:
652 case cir::CastKind::float_complex:
653 case cir::CastKind::int_complex: {
654 return getSrc();
655 }
656 default:
657 return {};
658 }
659 }
660 return tryFoldCastChain(*this);
661}
662
663//===----------------------------------------------------------------------===//
664// CallOp
665//===----------------------------------------------------------------------===//
666
667mlir::OperandRange cir::CallOp::getArgOperands() {
668 if (isIndirect())
669 return getArgs().drop_front(1);
670 return getArgs();
671}
672
673mlir::MutableOperandRange cir::CallOp::getArgOperandsMutable() {
674 mlir::MutableOperandRange args = getArgsMutable();
675 if (isIndirect())
676 return args.slice(1, args.size() - 1);
677 return args;
678}
679
680mlir::Value cir::CallOp::getIndirectCall() {
681 assert(isIndirect());
682 return getOperand(0);
683}
684
685/// Return the operand at index 'i'.
686Value cir::CallOp::getArgOperand(unsigned i) {
687 if (isIndirect())
688 ++i;
689 return getOperand(i);
690}
691
692/// Return the number of operands.
693unsigned cir::CallOp::getNumArgOperands() {
694 if (isIndirect())
695 return this->getOperation()->getNumOperands() - 1;
696 return this->getOperation()->getNumOperands();
697}
698
699static mlir::ParseResult parseCallCommon(mlir::OpAsmParser &parser,
700 mlir::OperationState &result) {
702 llvm::SMLoc opsLoc;
703 mlir::FlatSymbolRefAttr calleeAttr;
704 llvm::ArrayRef<mlir::Type> allResultTypes;
705
706 // If we cannot parse a string callee, it means this is an indirect call.
707 if (!parser
708 .parseOptionalAttribute(calleeAttr, CIRDialect::getCalleeAttrName(),
709 result.attributes)
710 .has_value()) {
711 OpAsmParser::UnresolvedOperand indirectVal;
712 // Do not resolve right now, since we need to figure out the type
713 if (parser.parseOperand(indirectVal).failed())
714 return failure();
715 ops.push_back(indirectVal);
716 }
717
718 if (parser.parseLParen())
719 return mlir::failure();
720
721 opsLoc = parser.getCurrentLocation();
722 if (parser.parseOperandList(ops))
723 return mlir::failure();
724 if (parser.parseRParen())
725 return mlir::failure();
726
727 if (parser.parseOptionalKeyword("nothrow").succeeded())
728 result.addAttribute(CIRDialect::getNoThrowAttrName(),
729 mlir::UnitAttr::get(parser.getContext()));
730
731 if (parser.parseOptionalKeyword("side_effect").succeeded()) {
732 if (parser.parseLParen().failed())
733 return failure();
734 cir::SideEffect sideEffect;
735 if (parseCIRKeyword<cir::SideEffect>(parser, sideEffect).failed())
736 return failure();
737 if (parser.parseRParen().failed())
738 return failure();
739 auto attr = cir::SideEffectAttr::get(parser.getContext(), sideEffect);
740 result.addAttribute(CIRDialect::getSideEffectAttrName(), attr);
741 }
742
743 if (parser.parseOptionalAttrDict(result.attributes))
744 return ::mlir::failure();
745
746 if (parser.parseColon())
747 return ::mlir::failure();
748
749 mlir::FunctionType opsFnTy;
750 if (parser.parseType(opsFnTy))
751 return mlir::failure();
752
753 allResultTypes = opsFnTy.getResults();
754 result.addTypes(allResultTypes);
755
756 if (parser.resolveOperands(ops, opsFnTy.getInputs(), opsLoc, result.operands))
757 return mlir::failure();
758
759 return mlir::success();
760}
761
762static void printCallCommon(mlir::Operation *op,
763 mlir::FlatSymbolRefAttr calleeSym,
764 mlir::Value indirectCallee,
765 mlir::OpAsmPrinter &printer, bool isNothrow,
766 cir::SideEffect sideEffect) {
767 printer << ' ';
768
769 auto callLikeOp = mlir::cast<cir::CIRCallOpInterface>(op);
770 auto ops = callLikeOp.getArgOperands();
771
772 if (calleeSym) {
773 // Direct calls
774 printer.printAttributeWithoutType(calleeSym);
775 } else {
776 // Indirect calls
777 assert(indirectCallee);
778 printer << indirectCallee;
779 }
780 printer << "(" << ops << ")";
781
782 if (isNothrow)
783 printer << " nothrow";
784
785 if (sideEffect != cir::SideEffect::All) {
786 printer << " side_effect(";
787 printer << stringifySideEffect(sideEffect);
788 printer << ")";
789 }
790
791 printer.printOptionalAttrDict(op->getAttrs(),
792 {CIRDialect::getCalleeAttrName(),
793 CIRDialect::getNoThrowAttrName(),
794 CIRDialect::getSideEffectAttrName()});
795
796 printer << " : ";
797 printer.printFunctionalType(op->getOperands().getTypes(),
798 op->getResultTypes());
799}
800
801mlir::ParseResult cir::CallOp::parse(mlir::OpAsmParser &parser,
802 mlir::OperationState &result) {
803 return parseCallCommon(parser, result);
804}
805
806void cir::CallOp::print(mlir::OpAsmPrinter &p) {
807 mlir::Value indirectCallee = isIndirect() ? getIndirectCall() : nullptr;
808 cir::SideEffect sideEffect = getSideEffect();
809 printCallCommon(*this, getCalleeAttr(), indirectCallee, p, getNothrow(),
810 sideEffect);
811}
812
813static LogicalResult
814verifyCallCommInSymbolUses(mlir::Operation *op,
815 SymbolTableCollection &symbolTable) {
816 auto fnAttr =
817 op->getAttrOfType<FlatSymbolRefAttr>(CIRDialect::getCalleeAttrName());
818 if (!fnAttr) {
819 // This is an indirect call, thus we don't have to check the symbol uses.
820 return mlir::success();
821 }
822
823 auto fn = symbolTable.lookupNearestSymbolFrom<cir::FuncOp>(op, fnAttr);
824 if (!fn)
825 return op->emitOpError() << "'" << fnAttr.getValue()
826 << "' does not reference a valid function";
827
828 auto callIf = dyn_cast<cir::CIRCallOpInterface>(op);
829 assert(callIf && "expected CIR call interface to be always available");
830
831 // Verify that the operand and result types match the callee. Note that
832 // argument-checking is disabled for functions without a prototype.
833 auto fnType = fn.getFunctionType();
834 if (!fn.getNoProto()) {
835 unsigned numCallOperands = callIf.getNumArgOperands();
836 unsigned numFnOpOperands = fnType.getNumInputs();
837
838 if (!fnType.isVarArg() && numCallOperands != numFnOpOperands)
839 return op->emitOpError("incorrect number of operands for callee");
840 if (fnType.isVarArg() && numCallOperands < numFnOpOperands)
841 return op->emitOpError("too few operands for callee");
842
843 for (unsigned i = 0, e = numFnOpOperands; i != e; ++i)
844 if (callIf.getArgOperand(i).getType() != fnType.getInput(i))
845 return op->emitOpError("operand type mismatch: expected operand type ")
846 << fnType.getInput(i) << ", but provided "
847 << op->getOperand(i).getType() << " for operand number " << i;
848 }
849
851
852 // Void function must not return any results.
853 if (fnType.hasVoidReturn() && op->getNumResults() != 0)
854 return op->emitOpError("callee returns void but call has results");
855
856 // Non-void function calls must return exactly one result.
857 if (!fnType.hasVoidReturn() && op->getNumResults() != 1)
858 return op->emitOpError("incorrect number of results for callee");
859
860 // Parent function and return value types must match.
861 if (!fnType.hasVoidReturn() &&
862 op->getResultTypes().front() != fnType.getReturnType()) {
863 return op->emitOpError("result type mismatch: expected ")
864 << fnType.getReturnType() << ", but provided "
865 << op->getResult(0).getType();
866 }
867
868 return mlir::success();
869}
870
871LogicalResult
872cir::CallOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
873 return verifyCallCommInSymbolUses(*this, symbolTable);
874}
875
876//===----------------------------------------------------------------------===//
877// ReturnOp
878//===----------------------------------------------------------------------===//
879
880static mlir::LogicalResult checkReturnAndFunction(cir::ReturnOp op,
881 cir::FuncOp function) {
882 // ReturnOps currently only have a single optional operand.
883 if (op.getNumOperands() > 1)
884 return op.emitOpError() << "expects at most 1 return operand";
885
886 // Ensure returned type matches the function signature.
887 auto expectedTy = function.getFunctionType().getReturnType();
888 auto actualTy =
889 (op.getNumOperands() == 0 ? cir::VoidType::get(op.getContext())
890 : op.getOperand(0).getType());
891 if (actualTy != expectedTy)
892 return op.emitOpError() << "returns " << actualTy
893 << " but enclosing function returns " << expectedTy;
894
895 return mlir::success();
896}
897
898mlir::LogicalResult cir::ReturnOp::verify() {
899 // Returns can be present in multiple different scopes, get the
900 // wrapping function and start from there.
901 auto *fnOp = getOperation()->getParentOp();
902 while (!isa<cir::FuncOp>(fnOp))
903 fnOp = fnOp->getParentOp();
904
905 // Make sure return types match function return type.
906 if (checkReturnAndFunction(*this, cast<cir::FuncOp>(fnOp)).failed())
907 return failure();
908
909 return success();
910}
911
912//===----------------------------------------------------------------------===//
913// IfOp
914//===----------------------------------------------------------------------===//
915
916ParseResult cir::IfOp::parse(OpAsmParser &parser, OperationState &result) {
917 // create the regions for 'then'.
918 result.regions.reserve(2);
919 Region *thenRegion = result.addRegion();
920 Region *elseRegion = result.addRegion();
921
922 mlir::Builder &builder = parser.getBuilder();
923 OpAsmParser::UnresolvedOperand cond;
924 Type boolType = cir::BoolType::get(builder.getContext());
925
926 if (parser.parseOperand(cond) ||
927 parser.resolveOperand(cond, boolType, result.operands))
928 return failure();
929
930 // Parse 'then' region.
931 mlir::SMLoc parseThenLoc = parser.getCurrentLocation();
932 if (parser.parseRegion(*thenRegion, /*arguments=*/{}, /*argTypes=*/{}))
933 return failure();
934
935 if (ensureRegionTerm(parser, *thenRegion, parseThenLoc).failed())
936 return failure();
937
938 // If we find an 'else' keyword, parse the 'else' region.
939 if (!parser.parseOptionalKeyword("else")) {
940 mlir::SMLoc parseElseLoc = parser.getCurrentLocation();
941 if (parser.parseRegion(*elseRegion, /*arguments=*/{}, /*argTypes=*/{}))
942 return failure();
943 if (ensureRegionTerm(parser, *elseRegion, parseElseLoc).failed())
944 return failure();
945 }
946
947 // Parse the optional attribute list.
948 if (parser.parseOptionalAttrDict(result.attributes))
949 return failure();
950 return success();
951}
952
953void cir::IfOp::print(OpAsmPrinter &p) {
954 p << " " << getCondition() << " ";
955 mlir::Region &thenRegion = this->getThenRegion();
956 p.printRegion(thenRegion,
957 /*printEntryBlockArgs=*/false,
958 /*printBlockTerminators=*/!omitRegionTerm(thenRegion));
959
960 // Print the 'else' regions if it exists and has a block.
961 mlir::Region &elseRegion = this->getElseRegion();
962 if (!elseRegion.empty()) {
963 p << " else ";
964 p.printRegion(elseRegion,
965 /*printEntryBlockArgs=*/false,
966 /*printBlockTerminators=*/!omitRegionTerm(elseRegion));
967 }
968
969 p.printOptionalAttrDict(getOperation()->getAttrs());
970}
971
972/// Default callback for IfOp builders.
973void cir::buildTerminatedBody(OpBuilder &builder, Location loc) {
974 // add cir.yield to end of the block
975 builder.create<cir::YieldOp>(loc);
976}
977
978/// Given the region at `index`, or the parent operation if `index` is None,
979/// return the successor regions. These are the regions that may be selected
980/// during the flow of control. `operands` is a set of optional attributes that
981/// correspond to a constant value for each operand, or null if that operand is
982/// not a constant.
983void cir::IfOp::getSuccessorRegions(mlir::RegionBranchPoint point,
984 SmallVectorImpl<RegionSuccessor> &regions) {
985 // The `then` and the `else` region branch back to the parent operation.
986 if (!point.isParent()) {
987 regions.push_back(RegionSuccessor());
988 return;
989 }
990
991 // Don't consider the else region if it is empty.
992 Region *elseRegion = &this->getElseRegion();
993 if (elseRegion->empty())
994 elseRegion = nullptr;
995
996 // If the condition isn't constant, both regions may be executed.
997 regions.push_back(RegionSuccessor(&getThenRegion()));
998 // If the else region does not exist, it is not a viable successor.
999 if (elseRegion)
1000 regions.push_back(RegionSuccessor(elseRegion));
1001
1002 return;
1003}
1004
1005void cir::IfOp::build(OpBuilder &builder, OperationState &result, Value cond,
1006 bool withElseRegion, BuilderCallbackRef thenBuilder,
1007 BuilderCallbackRef elseBuilder) {
1008 assert(thenBuilder && "the builder callback for 'then' must be present");
1009 result.addOperands(cond);
1010
1011 OpBuilder::InsertionGuard guard(builder);
1012 Region *thenRegion = result.addRegion();
1013 builder.createBlock(thenRegion);
1014 thenBuilder(builder, result.location);
1015
1016 Region *elseRegion = result.addRegion();
1017 if (!withElseRegion)
1018 return;
1019
1020 builder.createBlock(elseRegion);
1021 elseBuilder(builder, result.location);
1022}
1023
1024//===----------------------------------------------------------------------===//
1025// ScopeOp
1026//===----------------------------------------------------------------------===//
1027
1028/// Given the region at `index`, or the parent operation if `index` is None,
1029/// return the successor regions. These are the regions that may be selected
1030/// during the flow of control. `operands` is a set of optional attributes
1031/// that correspond to a constant value for each operand, or null if that
1032/// operand is not a constant.
1033void cir::ScopeOp::getSuccessorRegions(
1034 mlir::RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
1035 // The only region always branch back to the parent operation.
1036 if (!point.isParent()) {
1037 regions.push_back(RegionSuccessor(getODSResults(0)));
1038 return;
1039 }
1040
1041 // If the condition isn't constant, both regions may be executed.
1042 regions.push_back(RegionSuccessor(&getScopeRegion()));
1043}
1044
1045void cir::ScopeOp::build(
1046 OpBuilder &builder, OperationState &result,
1047 function_ref<void(OpBuilder &, Type &, Location)> scopeBuilder) {
1048 assert(scopeBuilder && "the builder callback for 'then' must be present");
1049
1050 OpBuilder::InsertionGuard guard(builder);
1051 Region *scopeRegion = result.addRegion();
1052 builder.createBlock(scopeRegion);
1054
1055 mlir::Type yieldTy;
1056 scopeBuilder(builder, yieldTy, result.location);
1057
1058 if (yieldTy)
1059 result.addTypes(TypeRange{yieldTy});
1060}
1061
1062void cir::ScopeOp::build(
1063 OpBuilder &builder, OperationState &result,
1064 function_ref<void(OpBuilder &, Location)> scopeBuilder) {
1065 assert(scopeBuilder && "the builder callback for 'then' must be present");
1066 OpBuilder::InsertionGuard guard(builder);
1067 Region *scopeRegion = result.addRegion();
1068 builder.createBlock(scopeRegion);
1070 scopeBuilder(builder, result.location);
1071}
1072
1073LogicalResult cir::ScopeOp::verify() {
1074 if (getRegion().empty()) {
1075 return emitOpError() << "cir.scope must not be empty since it should "
1076 "include at least an implicit cir.yield ";
1077 }
1078
1079 mlir::Block &lastBlock = getRegion().back();
1080 if (lastBlock.empty() || !lastBlock.mightHaveTerminator() ||
1081 !lastBlock.getTerminator()->hasTrait<OpTrait::IsTerminator>())
1082 return emitOpError() << "last block of cir.scope must be terminated";
1083 return success();
1084}
1085
1086//===----------------------------------------------------------------------===//
1087// BrOp
1088//===----------------------------------------------------------------------===//
1089
1090mlir::SuccessorOperands cir::BrOp::getSuccessorOperands(unsigned index) {
1091 assert(index == 0 && "invalid successor index");
1092 return mlir::SuccessorOperands(getDestOperandsMutable());
1093}
1094
1095Block *cir::BrOp::getSuccessorForOperands(ArrayRef<Attribute>) {
1096 return getDest();
1097}
1098
1099//===----------------------------------------------------------------------===//
1100// BrCondOp
1101//===----------------------------------------------------------------------===//
1102
1103mlir::SuccessorOperands cir::BrCondOp::getSuccessorOperands(unsigned index) {
1104 assert(index < getNumSuccessors() && "invalid successor index");
1105 return SuccessorOperands(index == 0 ? getDestOperandsTrueMutable()
1106 : getDestOperandsFalseMutable());
1107}
1108
1109Block *cir::BrCondOp::getSuccessorForOperands(ArrayRef<Attribute> operands) {
1110 if (IntegerAttr condAttr = dyn_cast_if_present<IntegerAttr>(operands.front()))
1111 return condAttr.getValue().isOne() ? getDestTrue() : getDestFalse();
1112 return nullptr;
1113}
1114
1115//===----------------------------------------------------------------------===//
1116// CaseOp
1117//===----------------------------------------------------------------------===//
1118
1119void cir::CaseOp::getSuccessorRegions(
1120 mlir::RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
1121 if (!point.isParent()) {
1122 regions.push_back(RegionSuccessor());
1123 return;
1124 }
1125 regions.push_back(RegionSuccessor(&getCaseRegion()));
1126}
1127
1128void cir::CaseOp::build(OpBuilder &builder, OperationState &result,
1129 ArrayAttr value, CaseOpKind kind,
1130 OpBuilder::InsertPoint &insertPoint) {
1131 OpBuilder::InsertionGuard guardSwitch(builder);
1132 result.addAttribute("value", value);
1133 result.getOrAddProperties<Properties>().kind =
1134 cir::CaseOpKindAttr::get(builder.getContext(), kind);
1135 Region *caseRegion = result.addRegion();
1136 builder.createBlock(caseRegion);
1137
1138 insertPoint = builder.saveInsertionPoint();
1139}
1140
1141//===----------------------------------------------------------------------===//
1142// SwitchOp
1143//===----------------------------------------------------------------------===//
1144
1145static ParseResult parseSwitchOp(OpAsmParser &parser, mlir::Region &regions,
1146 mlir::OpAsmParser::UnresolvedOperand &cond,
1147 mlir::Type &condType) {
1148 cir::IntType intCondType;
1149
1150 if (parser.parseLParen())
1151 return mlir::failure();
1152
1153 if (parser.parseOperand(cond))
1154 return mlir::failure();
1155 if (parser.parseColon())
1156 return mlir::failure();
1157 if (parser.parseCustomTypeWithFallback(intCondType))
1158 return mlir::failure();
1159 condType = intCondType;
1160
1161 if (parser.parseRParen())
1162 return mlir::failure();
1163 if (parser.parseRegion(regions, /*arguments=*/{}, /*argTypes=*/{}))
1164 return failure();
1165
1166 return mlir::success();
1167}
1168
1169static void printSwitchOp(OpAsmPrinter &p, cir::SwitchOp op,
1170 mlir::Region &bodyRegion, mlir::Value condition,
1171 mlir::Type condType) {
1172 p << "(";
1173 p << condition;
1174 p << " : ";
1175 p.printStrippedAttrOrType(condType);
1176 p << ")";
1177
1178 p << ' ';
1179 p.printRegion(bodyRegion, /*printEntryBlockArgs=*/false,
1180 /*printBlockTerminators=*/true);
1181}
1182
1183void cir::SwitchOp::getSuccessorRegions(
1184 mlir::RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &region) {
1185 if (!point.isParent()) {
1186 region.push_back(RegionSuccessor());
1187 return;
1188 }
1189
1190 region.push_back(RegionSuccessor(&getBody()));
1191}
1192
1193void cir::SwitchOp::build(OpBuilder &builder, OperationState &result,
1194 Value cond, BuilderOpStateCallbackRef switchBuilder) {
1195 assert(switchBuilder && "the builder callback for regions must be present");
1196 OpBuilder::InsertionGuard guardSwitch(builder);
1197 Region *switchRegion = result.addRegion();
1198 builder.createBlock(switchRegion);
1199 result.addOperands({cond});
1200 switchBuilder(builder, result.location, result);
1201}
1202
1203void cir::SwitchOp::collectCases(llvm::SmallVectorImpl<CaseOp> &cases) {
1204 walk<mlir::WalkOrder::PreOrder>([&](mlir::Operation *op) {
1205 // Don't walk in nested switch op.
1206 if (isa<cir::SwitchOp>(op) && op != *this)
1207 return WalkResult::skip();
1208
1209 if (auto caseOp = dyn_cast<cir::CaseOp>(op))
1210 cases.push_back(caseOp);
1211
1212 return WalkResult::advance();
1213 });
1214}
1215
1216bool cir::SwitchOp::isSimpleForm(llvm::SmallVectorImpl<CaseOp> &cases) {
1217 collectCases(cases);
1218
1219 if (getBody().empty())
1220 return false;
1221
1222 if (!isa<YieldOp>(getBody().front().back()))
1223 return false;
1224
1225 if (!llvm::all_of(getBody().front(),
1226 [](Operation &op) { return isa<CaseOp, YieldOp>(op); }))
1227 return false;
1228
1229 return llvm::all_of(cases, [this](CaseOp op) {
1230 return op->getParentOfType<SwitchOp>() == *this;
1231 });
1232}
1233
1234//===----------------------------------------------------------------------===//
1235// SwitchFlatOp
1236//===----------------------------------------------------------------------===//
1237
1238void cir::SwitchFlatOp::build(OpBuilder &builder, OperationState &result,
1239 Value value, Block *defaultDestination,
1240 ValueRange defaultOperands,
1241 ArrayRef<APInt> caseValues,
1242 BlockRange caseDestinations,
1243 ArrayRef<ValueRange> caseOperands) {
1244
1245 std::vector<mlir::Attribute> caseValuesAttrs;
1246 for (const APInt &val : caseValues)
1247 caseValuesAttrs.push_back(cir::IntAttr::get(value.getType(), val));
1248 mlir::ArrayAttr attrs = ArrayAttr::get(builder.getContext(), caseValuesAttrs);
1249
1250 build(builder, result, value, defaultOperands, caseOperands, attrs,
1251 defaultDestination, caseDestinations);
1252}
1253
1254/// <cases> ::= `[` (case (`,` case )* )? `]`
1255/// <case> ::= integer `:` bb-id (`(` ssa-use-and-type-list `)`)?
1256static ParseResult parseSwitchFlatOpCases(
1257 OpAsmParser &parser, Type flagType, mlir::ArrayAttr &caseValues,
1258 SmallVectorImpl<Block *> &caseDestinations,
1260 &caseOperands,
1261 SmallVectorImpl<llvm::SmallVector<Type>> &caseOperandTypes) {
1262 if (failed(parser.parseLSquare()))
1263 return failure();
1264 if (succeeded(parser.parseOptionalRSquare()))
1265 return success();
1267
1268 auto parseCase = [&]() {
1269 int64_t value = 0;
1270 if (failed(parser.parseInteger(value)))
1271 return failure();
1272
1273 values.push_back(cir::IntAttr::get(flagType, value));
1274
1275 Block *destination;
1277 llvm::SmallVector<Type> operandTypes;
1278 if (parser.parseColon() || parser.parseSuccessor(destination))
1279 return failure();
1280 if (!parser.parseOptionalLParen()) {
1281 if (parser.parseOperandList(operands, OpAsmParser::Delimiter::None,
1282 /*allowResultNumber=*/false) ||
1283 parser.parseColonTypeList(operandTypes) || parser.parseRParen())
1284 return failure();
1285 }
1286 caseDestinations.push_back(destination);
1287 caseOperands.emplace_back(operands);
1288 caseOperandTypes.emplace_back(operandTypes);
1289 return success();
1290 };
1291 if (failed(parser.parseCommaSeparatedList(parseCase)))
1292 return failure();
1293
1294 caseValues = ArrayAttr::get(flagType.getContext(), values);
1295
1296 return parser.parseRSquare();
1297}
1298
1299static void printSwitchFlatOpCases(OpAsmPrinter &p, cir::SwitchFlatOp op,
1300 Type flagType, mlir::ArrayAttr caseValues,
1301 SuccessorRange caseDestinations,
1302 OperandRangeRange caseOperands,
1303 const TypeRangeRange &caseOperandTypes) {
1304 p << '[';
1305 p.printNewline();
1306 if (!caseValues) {
1307 p << ']';
1308 return;
1309 }
1310
1311 size_t index = 0;
1312 llvm::interleave(
1313 llvm::zip(caseValues, caseDestinations),
1314 [&](auto i) {
1315 p << " ";
1316 mlir::Attribute a = std::get<0>(i);
1317 p << mlir::cast<cir::IntAttr>(a).getValue();
1318 p << ": ";
1319 p.printSuccessorAndUseList(std::get<1>(i), caseOperands[index++]);
1320 },
1321 [&] {
1322 p << ',';
1323 p.printNewline();
1324 });
1325 p.printNewline();
1326 p << ']';
1327}
1328
1329//===----------------------------------------------------------------------===//
1330// GlobalOp
1331//===----------------------------------------------------------------------===//
1332
1333static ParseResult parseConstantValue(OpAsmParser &parser,
1334 mlir::Attribute &valueAttr) {
1335 NamedAttrList attr;
1336 return parser.parseAttribute(valueAttr, "value", attr);
1337}
1338
1339static void printConstant(OpAsmPrinter &p, Attribute value) {
1340 p.printAttribute(value);
1341}
1342
1343mlir::LogicalResult cir::GlobalOp::verify() {
1344 // Verify that the initial value, if present, is either a unit attribute or
1345 // an attribute CIR supports.
1346 if (getInitialValue().has_value()) {
1347 if (checkConstantTypes(getOperation(), getSymType(), *getInitialValue())
1348 .failed())
1349 return failure();
1350 }
1351
1352 // TODO(CIR): Many other checks for properties that haven't been upstreamed
1353 // yet.
1354
1355 return success();
1356}
1357
1358void cir::GlobalOp::build(OpBuilder &odsBuilder, OperationState &odsState,
1359 llvm::StringRef sym_name, mlir::Type sym_type,
1360 bool isConstant, cir::GlobalLinkageKind linkage) {
1361 odsState.addAttribute(getSymNameAttrName(odsState.name),
1362 odsBuilder.getStringAttr(sym_name));
1363 odsState.addAttribute(getSymTypeAttrName(odsState.name),
1364 mlir::TypeAttr::get(sym_type));
1365 if (isConstant)
1366 odsState.addAttribute(getConstantAttrName(odsState.name),
1367 odsBuilder.getUnitAttr());
1368
1369 cir::GlobalLinkageKindAttr linkageAttr =
1370 cir::GlobalLinkageKindAttr::get(odsBuilder.getContext(), linkage);
1371 odsState.addAttribute(getLinkageAttrName(odsState.name), linkageAttr);
1372
1373 odsState.addAttribute(getGlobalVisibilityAttrName(odsState.name),
1374 cir::VisibilityAttr::get(odsBuilder.getContext()));
1375}
1376
1377static void printGlobalOpTypeAndInitialValue(OpAsmPrinter &p, cir::GlobalOp op,
1378 TypeAttr type,
1379 Attribute initAttr) {
1380 if (!op.isDeclaration()) {
1381 p << "= ";
1382 // This also prints the type...
1383 if (initAttr)
1384 printConstant(p, initAttr);
1385 } else {
1386 p << ": " << type;
1387 }
1388}
1389
1390static ParseResult
1391parseGlobalOpTypeAndInitialValue(OpAsmParser &parser, TypeAttr &typeAttr,
1392 Attribute &initialValueAttr) {
1393 mlir::Type opTy;
1394 if (parser.parseOptionalEqual().failed()) {
1395 // Absence of equal means a declaration, so we need to parse the type.
1396 // cir.global @a : !cir.int<s, 32>
1397 if (parser.parseColonType(opTy))
1398 return failure();
1399 } else {
1400 // Parse constant with initializer, examples:
1401 // cir.global @y = #cir.fp<1.250000e+00> : !cir.double
1402 // cir.global @rgb = #cir.const_array<[...] : !cir.array<i8 x 3>>
1403 if (parseConstantValue(parser, initialValueAttr).failed())
1404 return failure();
1405
1406 assert(mlir::isa<mlir::TypedAttr>(initialValueAttr) &&
1407 "Non-typed attrs shouldn't appear here.");
1408 auto typedAttr = mlir::cast<mlir::TypedAttr>(initialValueAttr);
1409 opTy = typedAttr.getType();
1410 }
1411
1412 typeAttr = TypeAttr::get(opTy);
1413 return success();
1414}
1415
1416//===----------------------------------------------------------------------===//
1417// GetGlobalOp
1418//===----------------------------------------------------------------------===//
1419
1420LogicalResult
1421cir::GetGlobalOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1422 // Verify that the result type underlying pointer type matches the type of
1423 // the referenced cir.global or cir.func op.
1424 mlir::Operation *op =
1425 symbolTable.lookupNearestSymbolFrom(*this, getNameAttr());
1426 if (op == nullptr || !(isa<GlobalOp>(op) || isa<FuncOp>(op)))
1427 return emitOpError("'")
1428 << getName()
1429 << "' does not reference a valid cir.global or cir.func";
1430
1431 mlir::Type symTy;
1432 if (auto g = dyn_cast<GlobalOp>(op)) {
1433 symTy = g.getSymType();
1436 } else if (auto f = dyn_cast<FuncOp>(op)) {
1437 symTy = f.getFunctionType();
1438 } else {
1439 llvm_unreachable("Unexpected operation for GetGlobalOp");
1440 }
1441
1442 auto resultType = dyn_cast<PointerType>(getAddr().getType());
1443 if (!resultType || symTy != resultType.getPointee())
1444 return emitOpError("result type pointee type '")
1445 << resultType.getPointee() << "' does not match type " << symTy
1446 << " of the global @" << getName();
1447
1448 return success();
1449}
1450
1451//===----------------------------------------------------------------------===//
1452// VTableAddrPointOp
1453//===----------------------------------------------------------------------===//
1454
1455LogicalResult
1456cir::VTableAddrPointOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1457 StringRef name = getName();
1458
1459 // Verify that the result type underlying pointer type matches the type of
1460 // the referenced cir.global.
1461 auto op =
1462 symbolTable.lookupNearestSymbolFrom<cir::GlobalOp>(*this, getNameAttr());
1463 if (!op)
1464 return emitOpError("'")
1465 << name << "' does not reference a valid cir.global";
1466 std::optional<mlir::Attribute> init = op.getInitialValue();
1467 if (!init)
1468 return success();
1469 if (!isa<cir::VTableAttr>(*init))
1470 return emitOpError("Expected #cir.vtable in initializer for global '")
1471 << name << "'";
1472 return success();
1473}
1474
1475//===----------------------------------------------------------------------===//
1476// VTTAddrPointOp
1477//===----------------------------------------------------------------------===//
1478
1479LogicalResult
1480cir::VTTAddrPointOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1481 // VTT ptr is not coming from a symbol.
1482 if (!getName())
1483 return success();
1484 StringRef name = *getName();
1485
1486 // Verify that the result type underlying pointer type matches the type of
1487 // the referenced cir.global op.
1488 auto op =
1489 symbolTable.lookupNearestSymbolFrom<cir::GlobalOp>(*this, getNameAttr());
1490 if (!op)
1491 return emitOpError("'")
1492 << name << "' does not reference a valid cir.global";
1493 std::optional<mlir::Attribute> init = op.getInitialValue();
1494 if (!init)
1495 return success();
1496 if (!isa<cir::ConstArrayAttr>(*init))
1497 return emitOpError(
1498 "Expected constant array in initializer for global VTT '")
1499 << name << "'";
1500 return success();
1501}
1502
1503LogicalResult cir::VTTAddrPointOp::verify() {
1504 // The operation uses either a symbol or a value to operate, but not both
1505 if (getName() && getSymAddr())
1506 return emitOpError("should use either a symbol or value, but not both");
1507
1508 // If not a symbol, stick with the concrete type used for getSymAddr.
1509 if (getSymAddr())
1510 return success();
1511
1512 mlir::Type resultType = getAddr().getType();
1513 mlir::Type resTy = cir::PointerType::get(
1514 cir::PointerType::get(cir::VoidType::get(getContext())));
1515
1516 if (resultType != resTy)
1517 return emitOpError("result type must be ")
1518 << resTy << ", but provided result type is " << resultType;
1519 return success();
1520}
1521
1522//===----------------------------------------------------------------------===//
1523// FuncOp
1524//===----------------------------------------------------------------------===//
1525
1526/// Returns the name used for the linkage attribute. This *must* correspond to
1527/// the name of the attribute in ODS.
1528static llvm::StringRef getLinkageAttrNameString() { return "linkage"; }
1529
1530void cir::FuncOp::build(OpBuilder &builder, OperationState &result,
1531 StringRef name, FuncType type,
1532 GlobalLinkageKind linkage) {
1533 result.addRegion();
1534 result.addAttribute(SymbolTable::getSymbolAttrName(),
1535 builder.getStringAttr(name));
1536 result.addAttribute(getFunctionTypeAttrName(result.name),
1537 TypeAttr::get(type));
1538 result.addAttribute(
1540 GlobalLinkageKindAttr::get(builder.getContext(), linkage));
1541 result.addAttribute(getGlobalVisibilityAttrName(result.name),
1542 cir::VisibilityAttr::get(builder.getContext()));
1543}
1544
1545ParseResult cir::FuncOp::parse(OpAsmParser &parser, OperationState &state) {
1546 llvm::SMLoc loc = parser.getCurrentLocation();
1547 mlir::Builder &builder = parser.getBuilder();
1548
1549 mlir::StringAttr lambdaNameAttr = getLambdaAttrName(state.name);
1550 mlir::StringAttr noProtoNameAttr = getNoProtoAttrName(state.name);
1551 mlir::StringAttr visNameAttr = getSymVisibilityAttrName(state.name);
1552 mlir::StringAttr visibilityNameAttr = getGlobalVisibilityAttrName(state.name);
1553 mlir::StringAttr dsoLocalNameAttr = getDsoLocalAttrName(state.name);
1554
1555 if (::mlir::succeeded(parser.parseOptionalKeyword(lambdaNameAttr.strref())))
1556 state.addAttribute(lambdaNameAttr, parser.getBuilder().getUnitAttr());
1557 if (parser.parseOptionalKeyword(noProtoNameAttr).succeeded())
1558 state.addAttribute(noProtoNameAttr, parser.getBuilder().getUnitAttr());
1559
1560 // Default to external linkage if no keyword is provided.
1561 state.addAttribute(getLinkageAttrNameString(),
1562 GlobalLinkageKindAttr::get(
1563 parser.getContext(),
1565 parser, GlobalLinkageKind::ExternalLinkage)));
1566
1567 ::llvm::StringRef visAttrStr;
1568 if (parser.parseOptionalKeyword(&visAttrStr, {"private", "public", "nested"})
1569 .succeeded()) {
1570 state.addAttribute(visNameAttr,
1571 parser.getBuilder().getStringAttr(visAttrStr));
1572 }
1573
1574 cir::VisibilityAttr cirVisibilityAttr;
1575 parseVisibilityAttr(parser, cirVisibilityAttr);
1576 state.addAttribute(visibilityNameAttr, cirVisibilityAttr);
1577
1578 if (parser.parseOptionalKeyword(dsoLocalNameAttr).succeeded())
1579 state.addAttribute(dsoLocalNameAttr, parser.getBuilder().getUnitAttr());
1580
1581 StringAttr nameAttr;
1582 if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
1583 state.attributes))
1584 return failure();
1588 bool isVariadic = false;
1589 if (function_interface_impl::parseFunctionSignatureWithArguments(
1590 parser, /*allowVariadic=*/true, arguments, isVariadic, resultTypes,
1591 resultAttrs))
1592 return failure();
1594 for (OpAsmParser::Argument &arg : arguments)
1595 argTypes.push_back(arg.type);
1596
1597 if (resultTypes.size() > 1) {
1598 return parser.emitError(
1599 loc, "functions with multiple return types are not supported");
1600 }
1601
1602 mlir::Type returnType =
1603 (resultTypes.empty() ? cir::VoidType::get(builder.getContext())
1604 : resultTypes.front());
1605
1606 cir::FuncType fnType = cir::FuncType::get(argTypes, returnType, isVariadic);
1607 if (!fnType)
1608 return failure();
1609 state.addAttribute(getFunctionTypeAttrName(state.name),
1610 TypeAttr::get(fnType));
1611
1612 bool hasAlias = false;
1613 mlir::StringAttr aliaseeNameAttr = getAliaseeAttrName(state.name);
1614 if (parser.parseOptionalKeyword("alias").succeeded()) {
1615 if (parser.parseLParen().failed())
1616 return failure();
1617 mlir::StringAttr aliaseeAttr;
1618 if (parser.parseOptionalSymbolName(aliaseeAttr).failed())
1619 return failure();
1620 state.addAttribute(aliaseeNameAttr, FlatSymbolRefAttr::get(aliaseeAttr));
1621 if (parser.parseRParen().failed())
1622 return failure();
1623 hasAlias = true;
1624 }
1625
1626 // Parse the optional function body.
1627 auto *body = state.addRegion();
1628 OptionalParseResult parseResult = parser.parseOptionalRegion(
1629 *body, arguments, /*enableNameShadowing=*/false);
1630 if (parseResult.has_value()) {
1631 if (hasAlias)
1632 return parser.emitError(loc, "function alias shall not have a body");
1633 if (failed(*parseResult))
1634 return failure();
1635 // Function body was parsed, make sure its not empty.
1636 if (body->empty())
1637 return parser.emitError(loc, "expected non-empty function body");
1638 }
1639
1640 return success();
1641}
1642
1643// This function corresponds to `llvm::GlobalValue::isDeclaration` and should
1644// have a similar implementation. We don't currently ifuncs or materializable
1645// functions, but those should be handled here as they are implemented.
1646bool cir::FuncOp::isDeclaration() {
1648
1649 std::optional<StringRef> aliasee = getAliasee();
1650 if (!aliasee)
1651 return getFunctionBody().empty();
1652
1653 // Aliases are always definitions.
1654 return false;
1655}
1656
1657mlir::Region *cir::FuncOp::getCallableRegion() {
1658 // TODO(CIR): This function will have special handling for aliases and a
1659 // check for an external function, once those features have been upstreamed.
1660 return &getBody();
1661}
1662
1663void cir::FuncOp::print(OpAsmPrinter &p) {
1664 if (getLambda())
1665 p << " lambda";
1666
1667 if (getNoProto())
1668 p << " no_proto";
1669
1670 if (getComdat())
1671 p << " comdat";
1672
1673 if (getLinkage() != GlobalLinkageKind::ExternalLinkage)
1674 p << ' ' << stringifyGlobalLinkageKind(getLinkage());
1675
1676 mlir::SymbolTable::Visibility vis = getVisibility();
1677 if (vis != mlir::SymbolTable::Visibility::Public)
1678 p << ' ' << vis;
1679
1680 cir::VisibilityAttr cirVisibilityAttr = getGlobalVisibilityAttr();
1681 if (!cirVisibilityAttr.isDefault()) {
1682 p << ' ';
1683 printVisibilityAttr(p, cirVisibilityAttr);
1684 }
1685
1686 if (getDsoLocal())
1687 p << " dso_local";
1688
1689 p << ' ';
1690 p.printSymbolName(getSymName());
1691 cir::FuncType fnType = getFunctionType();
1692 function_interface_impl::printFunctionSignature(
1693 p, *this, fnType.getInputs(), fnType.isVarArg(), fnType.getReturnTypes());
1694
1695 if (std::optional<StringRef> aliaseeName = getAliasee()) {
1696 p << " alias(";
1697 p.printSymbolName(*aliaseeName);
1698 p << ")";
1699 }
1700
1701 // Print the body if this is not an external function.
1702 Region &body = getOperation()->getRegion(0);
1703 if (!body.empty()) {
1704 p << ' ';
1705 p.printRegion(body, /*printEntryBlockArgs=*/false,
1706 /*printBlockTerminators=*/true);
1707 }
1708}
1709
1710mlir::LogicalResult cir::FuncOp::verify() {
1711
1712 llvm::SmallSet<llvm::StringRef, 16> labels;
1713 llvm::SmallSet<llvm::StringRef, 16> gotos;
1714
1715 getOperation()->walk([&](mlir::Operation *op) {
1716 if (auto lab = dyn_cast<cir::LabelOp>(op)) {
1717 labels.insert(lab.getLabel());
1718 } else if (auto goTo = dyn_cast<cir::GotoOp>(op)) {
1719 gotos.insert(goTo.getLabel());
1720 }
1721 });
1722
1723 if (!labels.empty() || !gotos.empty()) {
1724 llvm::SmallSet<llvm::StringRef, 16> mismatched =
1725 llvm::set_difference(gotos, labels);
1726
1727 if (!mismatched.empty())
1728 return emitOpError() << "goto/label mismatch";
1729 }
1730 return success();
1731}
1732
1733//===----------------------------------------------------------------------===//
1734// BinOp
1735//===----------------------------------------------------------------------===//
1736LogicalResult cir::BinOp::verify() {
1737 bool noWrap = getNoUnsignedWrap() || getNoSignedWrap();
1738 bool saturated = getSaturated();
1739
1740 if (!isa<cir::IntType>(getType()) && noWrap)
1741 return emitError()
1742 << "only operations on integer values may have nsw/nuw flags";
1743
1744 bool noWrapOps = getKind() == cir::BinOpKind::Add ||
1745 getKind() == cir::BinOpKind::Sub ||
1746 getKind() == cir::BinOpKind::Mul;
1747
1748 bool saturatedOps =
1749 getKind() == cir::BinOpKind::Add || getKind() == cir::BinOpKind::Sub;
1750
1751 if (noWrap && !noWrapOps)
1752 return emitError() << "The nsw/nuw flags are applicable to opcodes: 'add', "
1753 "'sub' and 'mul'";
1754 if (saturated && !saturatedOps)
1755 return emitError() << "The saturated flag is applicable to opcodes: 'add' "
1756 "and 'sub'";
1757 if (noWrap && saturated)
1758 return emitError() << "The nsw/nuw flags and the saturated flag are "
1759 "mutually exclusive";
1760
1761 return mlir::success();
1762}
1763
1764//===----------------------------------------------------------------------===//
1765// TernaryOp
1766//===----------------------------------------------------------------------===//
1767
1768/// Given the region at `point`, or the parent operation if `point` is None,
1769/// return the successor regions. These are the regions that may be selected
1770/// during the flow of control. `operands` is a set of optional attributes that
1771/// correspond to a constant value for each operand, or null if that operand is
1772/// not a constant.
1773void cir::TernaryOp::getSuccessorRegions(
1774 mlir::RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
1775 // The `true` and the `false` region branch back to the parent operation.
1776 if (!point.isParent()) {
1777 regions.push_back(RegionSuccessor(this->getODSResults(0)));
1778 return;
1779 }
1780
1781 // When branching from the parent operation, both the true and false
1782 // regions are considered possible successors
1783 regions.push_back(RegionSuccessor(&getTrueRegion()));
1784 regions.push_back(RegionSuccessor(&getFalseRegion()));
1785}
1786
1787void cir::TernaryOp::build(
1788 OpBuilder &builder, OperationState &result, Value cond,
1789 function_ref<void(OpBuilder &, Location)> trueBuilder,
1790 function_ref<void(OpBuilder &, Location)> falseBuilder) {
1791 result.addOperands(cond);
1792 OpBuilder::InsertionGuard guard(builder);
1793 Region *trueRegion = result.addRegion();
1794 Block *block = builder.createBlock(trueRegion);
1795 trueBuilder(builder, result.location);
1796 Region *falseRegion = result.addRegion();
1797 builder.createBlock(falseRegion);
1798 falseBuilder(builder, result.location);
1799
1800 auto yield = dyn_cast<YieldOp>(block->getTerminator());
1801 assert((yield && yield.getNumOperands() <= 1) &&
1802 "expected zero or one result type");
1803 if (yield.getNumOperands() == 1)
1804 result.addTypes(TypeRange{yield.getOperandTypes().front()});
1805}
1806
1807//===----------------------------------------------------------------------===//
1808// SelectOp
1809//===----------------------------------------------------------------------===//
1810
1811OpFoldResult cir::SelectOp::fold(FoldAdaptor adaptor) {
1812 mlir::Attribute condition = adaptor.getCondition();
1813 if (condition) {
1814 bool conditionValue = mlir::cast<cir::BoolAttr>(condition).getValue();
1815 return conditionValue ? getTrueValue() : getFalseValue();
1816 }
1817
1818 // cir.select if %0 then x else x -> x
1819 mlir::Attribute trueValue = adaptor.getTrueValue();
1820 mlir::Attribute falseValue = adaptor.getFalseValue();
1821 if (trueValue == falseValue)
1822 return trueValue;
1823 if (getTrueValue() == getFalseValue())
1824 return getTrueValue();
1825
1826 return {};
1827}
1828
1829//===----------------------------------------------------------------------===//
1830// ShiftOp
1831//===----------------------------------------------------------------------===//
1832LogicalResult cir::ShiftOp::verify() {
1833 mlir::Operation *op = getOperation();
1834 auto op0VecTy = mlir::dyn_cast<cir::VectorType>(op->getOperand(0).getType());
1835 auto op1VecTy = mlir::dyn_cast<cir::VectorType>(op->getOperand(1).getType());
1836 if (!op0VecTy ^ !op1VecTy)
1837 return emitOpError() << "input types cannot be one vector and one scalar";
1838
1839 if (op0VecTy) {
1840 if (op0VecTy.getSize() != op1VecTy.getSize())
1841 return emitOpError() << "input vector types must have the same size";
1842
1843 auto opResultTy = mlir::dyn_cast<cir::VectorType>(getType());
1844 if (!opResultTy)
1845 return emitOpError() << "the type of the result must be a vector "
1846 << "if it is vector shift";
1847
1848 auto op0VecEleTy = mlir::cast<cir::IntType>(op0VecTy.getElementType());
1849 auto op1VecEleTy = mlir::cast<cir::IntType>(op1VecTy.getElementType());
1850 if (op0VecEleTy.getWidth() != op1VecEleTy.getWidth())
1851 return emitOpError()
1852 << "vector operands do not have the same elements sizes";
1853
1854 auto resVecEleTy = mlir::cast<cir::IntType>(opResultTy.getElementType());
1855 if (op0VecEleTy.getWidth() != resVecEleTy.getWidth())
1856 return emitOpError() << "vector operands and result type do not have the "
1857 "same elements sizes";
1858 }
1859
1860 return mlir::success();
1861}
1862
1863//===----------------------------------------------------------------------===//
1864// LabelOp Definitions
1865//===----------------------------------------------------------------------===//
1866
1867LogicalResult cir::LabelOp::verify() {
1868 mlir::Operation *op = getOperation();
1869 mlir::Block *blk = op->getBlock();
1870 if (&blk->front() != op)
1871 return emitError() << "must be the first operation in a block";
1872
1873 return mlir::success();
1874}
1875
1876//===----------------------------------------------------------------------===//
1877// UnaryOp
1878//===----------------------------------------------------------------------===//
1879
1880LogicalResult cir::UnaryOp::verify() {
1881 switch (getKind()) {
1882 case cir::UnaryOpKind::Inc:
1883 case cir::UnaryOpKind::Dec:
1884 case cir::UnaryOpKind::Plus:
1885 case cir::UnaryOpKind::Minus:
1886 case cir::UnaryOpKind::Not:
1887 // Nothing to verify.
1888 return success();
1889 }
1890
1891 llvm_unreachable("Unknown UnaryOp kind?");
1892}
1893
1894static bool isBoolNot(cir::UnaryOp op) {
1895 return isa<cir::BoolType>(op.getInput().getType()) &&
1896 op.getKind() == cir::UnaryOpKind::Not;
1897}
1898
1899// This folder simplifies the sequential boolean not operations.
1900// For instance, the next two unary operations will be eliminated:
1901//
1902// ```mlir
1903// %1 = cir.unary(not, %0) : !cir.bool, !cir.bool
1904// %2 = cir.unary(not, %1) : !cir.bool, !cir.bool
1905// ```
1906//
1907// and the argument of the first one (%0) will be used instead.
1908OpFoldResult cir::UnaryOp::fold(FoldAdaptor adaptor) {
1909 if (auto poison =
1910 mlir::dyn_cast_if_present<cir::PoisonAttr>(adaptor.getInput())) {
1911 // Propagate poison values
1912 return poison;
1913 }
1914
1915 if (isBoolNot(*this))
1916 if (auto previous = getInput().getDefiningOp<cir::UnaryOp>())
1917 if (isBoolNot(previous))
1918 return previous.getInput();
1919
1920 return {};
1921}
1922
1923//===----------------------------------------------------------------------===//
1924// CopyOp Definitions
1925//===----------------------------------------------------------------------===//
1926
1927LogicalResult cir::CopyOp::verify() {
1928 // A data layout is required for us to know the number of bytes to be copied.
1929 if (!getType().getPointee().hasTrait<DataLayoutTypeInterface::Trait>())
1930 return emitError() << "missing data layout for pointee type";
1931
1932 if (getSrc() == getDst())
1933 return emitError() << "source and destination are the same";
1934
1935 return mlir::success();
1936}
1937
1938//===----------------------------------------------------------------------===//
1939// GetMemberOp Definitions
1940//===----------------------------------------------------------------------===//
1941
1942LogicalResult cir::GetMemberOp::verify() {
1943 const auto recordTy = dyn_cast<RecordType>(getAddrTy().getPointee());
1944 if (!recordTy)
1945 return emitError() << "expected pointer to a record type";
1946
1947 if (recordTy.getMembers().size() <= getIndex())
1948 return emitError() << "member index out of bounds";
1949
1950 if (recordTy.getMembers()[getIndex()] != getType().getPointee())
1951 return emitError() << "member type mismatch";
1952
1953 return mlir::success();
1954}
1955
1956//===----------------------------------------------------------------------===//
1957// VecCreateOp
1958//===----------------------------------------------------------------------===//
1959
1960OpFoldResult cir::VecCreateOp::fold(FoldAdaptor adaptor) {
1961 if (llvm::any_of(getElements(), [](mlir::Value value) {
1962 return !value.getDefiningOp<cir::ConstantOp>();
1963 }))
1964 return {};
1965
1966 return cir::ConstVectorAttr::get(
1967 getType(), mlir::ArrayAttr::get(getContext(), adaptor.getElements()));
1968}
1969
1970LogicalResult cir::VecCreateOp::verify() {
1971 // Verify that the number of arguments matches the number of elements in the
1972 // vector, and that the type of all the arguments matches the type of the
1973 // elements in the vector.
1974 const cir::VectorType vecTy = getType();
1975 if (getElements().size() != vecTy.getSize()) {
1976 return emitOpError() << "operand count of " << getElements().size()
1977 << " doesn't match vector type " << vecTy
1978 << " element count of " << vecTy.getSize();
1979 }
1980
1981 const mlir::Type elementType = vecTy.getElementType();
1982 for (const mlir::Value element : getElements()) {
1983 if (element.getType() != elementType) {
1984 return emitOpError() << "operand type " << element.getType()
1985 << " doesn't match vector element type "
1986 << elementType;
1987 }
1988 }
1989
1990 return success();
1991}
1992
1993//===----------------------------------------------------------------------===//
1994// VecExtractOp
1995//===----------------------------------------------------------------------===//
1996
1997OpFoldResult cir::VecExtractOp::fold(FoldAdaptor adaptor) {
1998 const auto vectorAttr =
1999 llvm::dyn_cast_if_present<cir::ConstVectorAttr>(adaptor.getVec());
2000 if (!vectorAttr)
2001 return {};
2002
2003 const auto indexAttr =
2004 llvm::dyn_cast_if_present<cir::IntAttr>(adaptor.getIndex());
2005 if (!indexAttr)
2006 return {};
2007
2008 const mlir::ArrayAttr elements = vectorAttr.getElts();
2009 const uint64_t index = indexAttr.getUInt();
2010 if (index >= elements.size())
2011 return {};
2012
2013 return elements[index];
2014}
2015
2016//===----------------------------------------------------------------------===//
2017// VecCmpOp
2018//===----------------------------------------------------------------------===//
2019
2020OpFoldResult cir::VecCmpOp::fold(FoldAdaptor adaptor) {
2021 auto lhsVecAttr =
2022 mlir::dyn_cast_if_present<cir::ConstVectorAttr>(adaptor.getLhs());
2023 auto rhsVecAttr =
2024 mlir::dyn_cast_if_present<cir::ConstVectorAttr>(adaptor.getRhs());
2025 if (!lhsVecAttr || !rhsVecAttr)
2026 return {};
2027
2028 mlir::Type inputElemTy =
2029 mlir::cast<cir::VectorType>(lhsVecAttr.getType()).getElementType();
2030 if (!isAnyIntegerOrFloatingPointType(inputElemTy))
2031 return {};
2032
2033 cir::CmpOpKind opKind = adaptor.getKind();
2034 mlir::ArrayAttr lhsVecElhs = lhsVecAttr.getElts();
2035 mlir::ArrayAttr rhsVecElhs = rhsVecAttr.getElts();
2036 uint64_t vecSize = lhsVecElhs.size();
2037
2038 SmallVector<mlir::Attribute, 16> elements(vecSize);
2039 bool isIntAttr = vecSize && mlir::isa<cir::IntAttr>(lhsVecElhs[0]);
2040 for (uint64_t i = 0; i < vecSize; i++) {
2041 mlir::Attribute lhsAttr = lhsVecElhs[i];
2042 mlir::Attribute rhsAttr = rhsVecElhs[i];
2043 int cmpResult = 0;
2044 switch (opKind) {
2045 case cir::CmpOpKind::lt: {
2046 if (isIntAttr) {
2047 cmpResult = mlir::cast<cir::IntAttr>(lhsAttr).getSInt() <
2048 mlir::cast<cir::IntAttr>(rhsAttr).getSInt();
2049 } else {
2050 cmpResult = mlir::cast<cir::FPAttr>(lhsAttr).getValue() <
2051 mlir::cast<cir::FPAttr>(rhsAttr).getValue();
2052 }
2053 break;
2054 }
2055 case cir::CmpOpKind::le: {
2056 if (isIntAttr) {
2057 cmpResult = mlir::cast<cir::IntAttr>(lhsAttr).getSInt() <=
2058 mlir::cast<cir::IntAttr>(rhsAttr).getSInt();
2059 } else {
2060 cmpResult = mlir::cast<cir::FPAttr>(lhsAttr).getValue() <=
2061 mlir::cast<cir::FPAttr>(rhsAttr).getValue();
2062 }
2063 break;
2064 }
2065 case cir::CmpOpKind::gt: {
2066 if (isIntAttr) {
2067 cmpResult = mlir::cast<cir::IntAttr>(lhsAttr).getSInt() >
2068 mlir::cast<cir::IntAttr>(rhsAttr).getSInt();
2069 } else {
2070 cmpResult = mlir::cast<cir::FPAttr>(lhsAttr).getValue() >
2071 mlir::cast<cir::FPAttr>(rhsAttr).getValue();
2072 }
2073 break;
2074 }
2075 case cir::CmpOpKind::ge: {
2076 if (isIntAttr) {
2077 cmpResult = mlir::cast<cir::IntAttr>(lhsAttr).getSInt() >=
2078 mlir::cast<cir::IntAttr>(rhsAttr).getSInt();
2079 } else {
2080 cmpResult = mlir::cast<cir::FPAttr>(lhsAttr).getValue() >=
2081 mlir::cast<cir::FPAttr>(rhsAttr).getValue();
2082 }
2083 break;
2084 }
2085 case cir::CmpOpKind::eq: {
2086 if (isIntAttr) {
2087 cmpResult = mlir::cast<cir::IntAttr>(lhsAttr).getSInt() ==
2088 mlir::cast<cir::IntAttr>(rhsAttr).getSInt();
2089 } else {
2090 cmpResult = mlir::cast<cir::FPAttr>(lhsAttr).getValue() ==
2091 mlir::cast<cir::FPAttr>(rhsAttr).getValue();
2092 }
2093 break;
2094 }
2095 case cir::CmpOpKind::ne: {
2096 if (isIntAttr) {
2097 cmpResult = mlir::cast<cir::IntAttr>(lhsAttr).getSInt() !=
2098 mlir::cast<cir::IntAttr>(rhsAttr).getSInt();
2099 } else {
2100 cmpResult = mlir::cast<cir::FPAttr>(lhsAttr).getValue() !=
2101 mlir::cast<cir::FPAttr>(rhsAttr).getValue();
2102 }
2103 break;
2104 }
2105 }
2106
2107 elements[i] = cir::IntAttr::get(getType().getElementType(), cmpResult);
2108 }
2109
2110 return cir::ConstVectorAttr::get(
2111 getType(), mlir::ArrayAttr::get(getContext(), elements));
2112}
2113
2114//===----------------------------------------------------------------------===//
2115// VecShuffleOp
2116//===----------------------------------------------------------------------===//
2117
2118OpFoldResult cir::VecShuffleOp::fold(FoldAdaptor adaptor) {
2119 auto vec1Attr =
2120 mlir::dyn_cast_if_present<cir::ConstVectorAttr>(adaptor.getVec1());
2121 auto vec2Attr =
2122 mlir::dyn_cast_if_present<cir::ConstVectorAttr>(adaptor.getVec2());
2123 if (!vec1Attr || !vec2Attr)
2124 return {};
2125
2126 mlir::Type vec1ElemTy =
2127 mlir::cast<cir::VectorType>(vec1Attr.getType()).getElementType();
2128
2129 mlir::ArrayAttr vec1Elts = vec1Attr.getElts();
2130 mlir::ArrayAttr vec2Elts = vec2Attr.getElts();
2131 mlir::ArrayAttr indicesElts = adaptor.getIndices();
2132
2134 elements.reserve(indicesElts.size());
2135
2136 uint64_t vec1Size = vec1Elts.size();
2137 for (const auto &idxAttr : indicesElts.getAsRange<cir::IntAttr>()) {
2138 if (idxAttr.getSInt() == -1) {
2139 elements.push_back(cir::UndefAttr::get(vec1ElemTy));
2140 continue;
2141 }
2142
2143 uint64_t idxValue = idxAttr.getUInt();
2144 elements.push_back(idxValue < vec1Size ? vec1Elts[idxValue]
2145 : vec2Elts[idxValue - vec1Size]);
2146 }
2147
2148 return cir::ConstVectorAttr::get(
2149 getType(), mlir::ArrayAttr::get(getContext(), elements));
2150}
2151
2152LogicalResult cir::VecShuffleOp::verify() {
2153 // The number of elements in the indices array must match the number of
2154 // elements in the result type.
2155 if (getIndices().size() != getResult().getType().getSize()) {
2156 return emitOpError() << ": the number of elements in " << getIndices()
2157 << " and " << getResult().getType() << " don't match";
2158 }
2159
2160 // The element types of the two input vectors and of the result type must
2161 // match.
2162 if (getVec1().getType().getElementType() !=
2163 getResult().getType().getElementType()) {
2164 return emitOpError() << ": element types of " << getVec1().getType()
2165 << " and " << getResult().getType() << " don't match";
2166 }
2167
2168 const uint64_t maxValidIndex =
2169 getVec1().getType().getSize() + getVec2().getType().getSize() - 1;
2170 if (llvm::any_of(
2171 getIndices().getAsRange<cir::IntAttr>(), [&](cir::IntAttr idxAttr) {
2172 return idxAttr.getSInt() != -1 && idxAttr.getUInt() > maxValidIndex;
2173 })) {
2174 return emitOpError() << ": index for __builtin_shufflevector must be "
2175 "less than the total number of vector elements";
2176 }
2177 return success();
2178}
2179
2180//===----------------------------------------------------------------------===//
2181// VecShuffleDynamicOp
2182//===----------------------------------------------------------------------===//
2183
2184OpFoldResult cir::VecShuffleDynamicOp::fold(FoldAdaptor adaptor) {
2185 mlir::Attribute vec = adaptor.getVec();
2186 mlir::Attribute indices = adaptor.getIndices();
2187 if (mlir::isa_and_nonnull<cir::ConstVectorAttr>(vec) &&
2188 mlir::isa_and_nonnull<cir::ConstVectorAttr>(indices)) {
2189 auto vecAttr = mlir::cast<cir::ConstVectorAttr>(vec);
2190 auto indicesAttr = mlir::cast<cir::ConstVectorAttr>(indices);
2191
2192 mlir::ArrayAttr vecElts = vecAttr.getElts();
2193 mlir::ArrayAttr indicesElts = indicesAttr.getElts();
2194
2195 const uint64_t numElements = vecElts.size();
2196
2198 elements.reserve(numElements);
2199
2200 const uint64_t maskBits = llvm::NextPowerOf2(numElements - 1) - 1;
2201 for (const auto &idxAttr : indicesElts.getAsRange<cir::IntAttr>()) {
2202 uint64_t idxValue = idxAttr.getUInt();
2203 uint64_t newIdx = idxValue & maskBits;
2204 elements.push_back(vecElts[newIdx]);
2205 }
2206
2207 return cir::ConstVectorAttr::get(
2208 getType(), mlir::ArrayAttr::get(getContext(), elements));
2209 }
2210
2211 return {};
2212}
2213
2214LogicalResult cir::VecShuffleDynamicOp::verify() {
2215 // The number of elements in the two input vectors must match.
2216 if (getVec().getType().getSize() !=
2217 mlir::cast<cir::VectorType>(getIndices().getType()).getSize()) {
2218 return emitOpError() << ": the number of elements in " << getVec().getType()
2219 << " and " << getIndices().getType() << " don't match";
2220 }
2221 return success();
2222}
2223
2224//===----------------------------------------------------------------------===//
2225// VecTernaryOp
2226//===----------------------------------------------------------------------===//
2227
2228LogicalResult cir::VecTernaryOp::verify() {
2229 // Verify that the condition operand has the same number of elements as the
2230 // other operands. (The automatic verification already checked that all
2231 // operands are vector types and that the second and third operands are the
2232 // same type.)
2233 if (getCond().getType().getSize() != getLhs().getType().getSize()) {
2234 return emitOpError() << ": the number of elements in "
2235 << getCond().getType() << " and " << getLhs().getType()
2236 << " don't match";
2237 }
2238 return success();
2239}
2240
2241OpFoldResult cir::VecTernaryOp::fold(FoldAdaptor adaptor) {
2242 mlir::Attribute cond = adaptor.getCond();
2243 mlir::Attribute lhs = adaptor.getLhs();
2244 mlir::Attribute rhs = adaptor.getRhs();
2245
2246 if (!mlir::isa_and_nonnull<cir::ConstVectorAttr>(cond) ||
2247 !mlir::isa_and_nonnull<cir::ConstVectorAttr>(lhs) ||
2248 !mlir::isa_and_nonnull<cir::ConstVectorAttr>(rhs))
2249 return {};
2250 auto condVec = mlir::cast<cir::ConstVectorAttr>(cond);
2251 auto lhsVec = mlir::cast<cir::ConstVectorAttr>(lhs);
2252 auto rhsVec = mlir::cast<cir::ConstVectorAttr>(rhs);
2253
2254 mlir::ArrayAttr condElts = condVec.getElts();
2255
2257 elements.reserve(condElts.size());
2258
2259 for (const auto &[idx, condAttr] :
2260 llvm::enumerate(condElts.getAsRange<cir::IntAttr>())) {
2261 if (condAttr.getSInt()) {
2262 elements.push_back(lhsVec.getElts()[idx]);
2263 } else {
2264 elements.push_back(rhsVec.getElts()[idx]);
2265 }
2266 }
2267
2268 cir::VectorType vecTy = getLhs().getType();
2269 return cir::ConstVectorAttr::get(
2270 vecTy, mlir::ArrayAttr::get(getContext(), elements));
2271}
2272
2273//===----------------------------------------------------------------------===//
2274// ComplexCreateOp
2275//===----------------------------------------------------------------------===//
2276
2277LogicalResult cir::ComplexCreateOp::verify() {
2278 if (getType().getElementType() != getReal().getType()) {
2279 emitOpError()
2280 << "operand type of cir.complex.create does not match its result type";
2281 return failure();
2282 }
2283
2284 return success();
2285}
2286
2287OpFoldResult cir::ComplexCreateOp::fold(FoldAdaptor adaptor) {
2288 mlir::Attribute real = adaptor.getReal();
2289 mlir::Attribute imag = adaptor.getImag();
2290 if (!real || !imag)
2291 return {};
2292
2293 // When both of real and imag are constants, we can fold the operation into an
2294 // `#cir.const_complex` operation.
2295 auto realAttr = mlir::cast<mlir::TypedAttr>(real);
2296 auto imagAttr = mlir::cast<mlir::TypedAttr>(imag);
2297 return cir::ConstComplexAttr::get(realAttr, imagAttr);
2298}
2299
2300//===----------------------------------------------------------------------===//
2301// ComplexRealOp
2302//===----------------------------------------------------------------------===//
2303
2304LogicalResult cir::ComplexRealOp::verify() {
2305 if (getType() != getOperand().getType().getElementType()) {
2306 emitOpError() << ": result type does not match operand type";
2307 return failure();
2308 }
2309 return success();
2310}
2311
2312OpFoldResult cir::ComplexRealOp::fold(FoldAdaptor adaptor) {
2313 if (auto complexCreateOp = getOperand().getDefiningOp<cir::ComplexCreateOp>())
2314 return complexCreateOp.getOperand(0);
2315
2316 auto complex =
2317 mlir::cast_if_present<cir::ConstComplexAttr>(adaptor.getOperand());
2318 return complex ? complex.getReal() : nullptr;
2319}
2320
2321//===----------------------------------------------------------------------===//
2322// ComplexImagOp
2323//===----------------------------------------------------------------------===//
2324
2325LogicalResult cir::ComplexImagOp::verify() {
2326 if (getType() != getOperand().getType().getElementType()) {
2327 emitOpError() << ": result type does not match operand type";
2328 return failure();
2329 }
2330 return success();
2331}
2332
2333OpFoldResult cir::ComplexImagOp::fold(FoldAdaptor adaptor) {
2334 if (auto complexCreateOp = getOperand().getDefiningOp<cir::ComplexCreateOp>())
2335 return complexCreateOp.getOperand(1);
2336
2337 auto complex =
2338 mlir::cast_if_present<cir::ConstComplexAttr>(adaptor.getOperand());
2339 return complex ? complex.getImag() : nullptr;
2340}
2341
2342//===----------------------------------------------------------------------===//
2343// ComplexRealPtrOp
2344//===----------------------------------------------------------------------===//
2345
2346LogicalResult cir::ComplexRealPtrOp::verify() {
2347 mlir::Type resultPointeeTy = getType().getPointee();
2348 cir::PointerType operandPtrTy = getOperand().getType();
2349 auto operandPointeeTy =
2350 mlir::cast<cir::ComplexType>(operandPtrTy.getPointee());
2351
2352 if (resultPointeeTy != operandPointeeTy.getElementType()) {
2353 return emitOpError() << ": result type does not match operand type";
2354 }
2355
2356 return success();
2357}
2358
2359//===----------------------------------------------------------------------===//
2360// ComplexImagPtrOp
2361//===----------------------------------------------------------------------===//
2362
2363LogicalResult cir::ComplexImagPtrOp::verify() {
2364 mlir::Type resultPointeeTy = getType().getPointee();
2365 cir::PointerType operandPtrTy = getOperand().getType();
2366 auto operandPointeeTy =
2367 mlir::cast<cir::ComplexType>(operandPtrTy.getPointee());
2368
2369 if (resultPointeeTy != operandPointeeTy.getElementType()) {
2370 return emitOpError()
2371 << "cir.complex.imag_ptr result type does not match operand type";
2372 }
2373 return success();
2374}
2375
2376//===----------------------------------------------------------------------===//
2377// Bit manipulation operations
2378//===----------------------------------------------------------------------===//
2379
2380static OpFoldResult
2381foldUnaryBitOp(mlir::Attribute inputAttr,
2382 llvm::function_ref<llvm::APInt(const llvm::APInt &)> func,
2383 bool poisonZero = false) {
2384 if (mlir::isa_and_present<cir::PoisonAttr>(inputAttr)) {
2385 // Propagate poison value
2386 return inputAttr;
2387 }
2388
2389 auto input = mlir::dyn_cast_if_present<IntAttr>(inputAttr);
2390 if (!input)
2391 return nullptr;
2392
2393 llvm::APInt inputValue = input.getValue();
2394 if (poisonZero && inputValue.isZero())
2395 return cir::PoisonAttr::get(input.getType());
2396
2397 llvm::APInt resultValue = func(inputValue);
2398 return IntAttr::get(input.getType(), resultValue);
2399}
2400
2401OpFoldResult BitClrsbOp::fold(FoldAdaptor adaptor) {
2402 return foldUnaryBitOp(adaptor.getInput(), [](const llvm::APInt &inputValue) {
2403 unsigned resultValue =
2404 inputValue.getBitWidth() - inputValue.getSignificantBits();
2405 return llvm::APInt(inputValue.getBitWidth(), resultValue);
2406 });
2407}
2408
2409OpFoldResult BitClzOp::fold(FoldAdaptor adaptor) {
2410 return foldUnaryBitOp(
2411 adaptor.getInput(),
2412 [](const llvm::APInt &inputValue) {
2413 unsigned resultValue = inputValue.countLeadingZeros();
2414 return llvm::APInt(inputValue.getBitWidth(), resultValue);
2415 },
2416 getPoisonZero());
2417}
2418
2419OpFoldResult BitCtzOp::fold(FoldAdaptor adaptor) {
2420 return foldUnaryBitOp(
2421 adaptor.getInput(),
2422 [](const llvm::APInt &inputValue) {
2423 return llvm::APInt(inputValue.getBitWidth(),
2424 inputValue.countTrailingZeros());
2425 },
2426 getPoisonZero());
2427}
2428
2429OpFoldResult BitFfsOp::fold(FoldAdaptor adaptor) {
2430 return foldUnaryBitOp(adaptor.getInput(), [](const llvm::APInt &inputValue) {
2431 unsigned trailingZeros = inputValue.countTrailingZeros();
2432 unsigned result =
2433 trailingZeros == inputValue.getBitWidth() ? 0 : trailingZeros + 1;
2434 return llvm::APInt(inputValue.getBitWidth(), result);
2435 });
2436}
2437
2438OpFoldResult BitParityOp::fold(FoldAdaptor adaptor) {
2439 return foldUnaryBitOp(adaptor.getInput(), [](const llvm::APInt &inputValue) {
2440 return llvm::APInt(inputValue.getBitWidth(), inputValue.popcount() % 2);
2441 });
2442}
2443
2444OpFoldResult BitPopcountOp::fold(FoldAdaptor adaptor) {
2445 return foldUnaryBitOp(adaptor.getInput(), [](const llvm::APInt &inputValue) {
2446 return llvm::APInt(inputValue.getBitWidth(), inputValue.popcount());
2447 });
2448}
2449
2450OpFoldResult BitReverseOp::fold(FoldAdaptor adaptor) {
2451 return foldUnaryBitOp(adaptor.getInput(), [](const llvm::APInt &inputValue) {
2452 return inputValue.reverseBits();
2453 });
2454}
2455
2456OpFoldResult ByteSwapOp::fold(FoldAdaptor adaptor) {
2457 return foldUnaryBitOp(adaptor.getInput(), [](const llvm::APInt &inputValue) {
2458 return inputValue.byteSwap();
2459 });
2460}
2461
2462OpFoldResult RotateOp::fold(FoldAdaptor adaptor) {
2463 if (mlir::isa_and_present<cir::PoisonAttr>(adaptor.getInput()) ||
2464 mlir::isa_and_present<cir::PoisonAttr>(adaptor.getAmount())) {
2465 // Propagate poison values
2466 return cir::PoisonAttr::get(getType());
2467 }
2468
2469 auto input = mlir::dyn_cast_if_present<IntAttr>(adaptor.getInput());
2470 auto amount = mlir::dyn_cast_if_present<IntAttr>(adaptor.getAmount());
2471 if (!input && !amount)
2472 return nullptr;
2473
2474 // We could fold cir.rotate even if one of its two operands is not a constant:
2475 // - `cir.rotate left/right %0, 0` could be folded into just %0 even if %0
2476 // is not a constant.
2477 // - `cir.rotate left/right 0/0b111...111, %0` could be folded into 0 or
2478 // 0b111...111 even if %0 is not a constant.
2479
2480 llvm::APInt inputValue;
2481 if (input) {
2482 inputValue = input.getValue();
2483 if (inputValue.isZero() || inputValue.isAllOnes()) {
2484 // An input value of all 0s or all 1s will not change after rotation
2485 return input;
2486 }
2487 }
2488
2489 uint64_t amountValue;
2490 if (amount) {
2491 amountValue = amount.getValue().urem(getInput().getType().getWidth());
2492 if (amountValue == 0) {
2493 // A shift amount of 0 will not change the input value
2494 return getInput();
2495 }
2496 }
2497
2498 if (!input || !amount)
2499 return nullptr;
2500
2501 assert(inputValue.getBitWidth() == getInput().getType().getWidth() &&
2502 "input value must have the same bit width as the input type");
2503
2504 llvm::APInt resultValue;
2505 if (isRotateLeft())
2506 resultValue = inputValue.rotl(amountValue);
2507 else
2508 resultValue = inputValue.rotr(amountValue);
2509
2510 return IntAttr::get(input.getContext(), input.getType(), resultValue);
2511}
2512
2513//===----------------------------------------------------------------------===//
2514// InlineAsmOp
2515//===----------------------------------------------------------------------===//
2516
2517void cir::InlineAsmOp::print(OpAsmPrinter &p) {
2518 p << '(' << getAsmFlavor() << ", ";
2519 p.increaseIndent();
2520 p.printNewline();
2521
2522 llvm::SmallVector<std::string, 3> names{"out", "in", "in_out"};
2523 auto *nameIt = names.begin();
2524 auto *attrIt = getOperandAttrs().begin();
2525
2526 for (mlir::OperandRange ops : getAsmOperands()) {
2527 p << *nameIt << " = ";
2528
2529 p << '[';
2530 llvm::interleaveComma(llvm::make_range(ops.begin(), ops.end()), p,
2531 [&](Value value) {
2532 p.printOperand(value);
2533 p << " : " << value.getType();
2534 if (*attrIt)
2535 p << " (maybe_memory)";
2536 attrIt++;
2537 });
2538 p << "],";
2539 p.printNewline();
2540 ++nameIt;
2541 }
2542
2543 p << "{";
2544 p.printString(getAsmString());
2545 p << " ";
2546 p.printString(getConstraints());
2547 p << "}";
2548 p.decreaseIndent();
2549 p << ')';
2550 if (getSideEffects())
2551 p << " side_effects";
2552
2553 std::array elidedAttrs{
2554 llvm::StringRef("asm_flavor"), llvm::StringRef("asm_string"),
2555 llvm::StringRef("constraints"), llvm::StringRef("operand_attrs"),
2556 llvm::StringRef("operands_segments"), llvm::StringRef("side_effects")};
2557 p.printOptionalAttrDict(getOperation()->getAttrs(), elidedAttrs);
2558
2559 if (auto v = getRes())
2560 p << " -> " << v.getType();
2561}
2562
2563void cir::InlineAsmOp::build(OpBuilder &odsBuilder, OperationState &odsState,
2564 ArrayRef<ValueRange> asmOperands,
2565 StringRef asmString, StringRef constraints,
2566 bool sideEffects, cir::AsmFlavor asmFlavor,
2567 ArrayRef<Attribute> operandAttrs) {
2568 // Set up the operands_segments for VariadicOfVariadic
2569 SmallVector<int32_t> segments;
2570 for (auto operandRange : asmOperands) {
2571 segments.push_back(operandRange.size());
2572 odsState.addOperands(operandRange);
2573 }
2574
2575 odsState.addAttribute(
2576 "operands_segments",
2577 DenseI32ArrayAttr::get(odsBuilder.getContext(), segments));
2578 odsState.addAttribute("asm_string", odsBuilder.getStringAttr(asmString));
2579 odsState.addAttribute("constraints", odsBuilder.getStringAttr(constraints));
2580 odsState.addAttribute("asm_flavor",
2581 AsmFlavorAttr::get(odsBuilder.getContext(), asmFlavor));
2582
2583 if (sideEffects)
2584 odsState.addAttribute("side_effects", odsBuilder.getUnitAttr());
2585
2586 odsState.addAttribute("operand_attrs", odsBuilder.getArrayAttr(operandAttrs));
2587}
2588
2589ParseResult cir::InlineAsmOp::parse(OpAsmParser &parser,
2590 OperationState &result) {
2592 llvm::SmallVector<int32_t> operandsGroupSizes;
2593 std::string asmString, constraints;
2594 Type resType;
2595 MLIRContext *ctxt = parser.getBuilder().getContext();
2596
2597 auto error = [&](const Twine &msg) -> LogicalResult {
2598 return parser.emitError(parser.getCurrentLocation(), msg);
2599 };
2600
2601 auto expected = [&](const std::string &c) {
2602 return error("expected '" + c + "'");
2603 };
2604
2605 if (parser.parseLParen().failed())
2606 return expected("(");
2607
2608 auto flavor = FieldParser<AsmFlavor, AsmFlavor>::parse(parser);
2609 if (failed(flavor))
2610 return error("Unknown AsmFlavor");
2611
2612 if (parser.parseComma().failed())
2613 return expected(",");
2614
2615 auto parseValue = [&](Value &v) {
2616 OpAsmParser::UnresolvedOperand op;
2617
2618 if (parser.parseOperand(op) || parser.parseColon())
2619 return error("can't parse operand");
2620
2621 Type typ;
2622 if (parser.parseType(typ).failed())
2623 return error("can't parse operand type");
2625 if (parser.resolveOperand(op, typ, tmp))
2626 return error("can't resolve operand");
2627 v = tmp[0];
2628 return mlir::success();
2629 };
2630
2631 auto parseOperands = [&](llvm::StringRef name) {
2632 if (parser.parseKeyword(name).failed())
2633 return error("expected " + name + " operands here");
2634 if (parser.parseEqual().failed())
2635 return expected("=");
2636 if (parser.parseLSquare().failed())
2637 return expected("[");
2638
2639 int size = 0;
2640 if (parser.parseOptionalRSquare().succeeded()) {
2641 operandsGroupSizes.push_back(size);
2642 if (parser.parseComma())
2643 return expected(",");
2644 return mlir::success();
2645 }
2646
2647 auto parseOperand = [&]() {
2648 Value val;
2649 if (parseValue(val).succeeded()) {
2650 result.operands.push_back(val);
2651 size++;
2652
2653 if (parser.parseOptionalLParen().failed()) {
2654 operandAttrs.push_back(mlir::Attribute());
2655 return mlir::success();
2656 }
2657
2658 if (parser.parseKeyword("maybe_memory").succeeded()) {
2659 operandAttrs.push_back(mlir::UnitAttr::get(ctxt));
2660 if (parser.parseRParen())
2661 return expected(")");
2662 return mlir::success();
2663 } else {
2664 return expected("maybe_memory");
2665 }
2666 }
2667 return mlir::failure();
2668 };
2669
2670 if (parser.parseCommaSeparatedList(parseOperand).failed())
2671 return mlir::failure();
2672
2673 if (parser.parseRSquare().failed() || parser.parseComma().failed())
2674 return expected("]");
2675 operandsGroupSizes.push_back(size);
2676 return mlir::success();
2677 };
2678
2679 if (parseOperands("out").failed() || parseOperands("in").failed() ||
2680 parseOperands("in_out").failed())
2681 return error("failed to parse operands");
2682
2683 if (parser.parseLBrace())
2684 return expected("{");
2685 if (parser.parseString(&asmString))
2686 return error("asm string parsing failed");
2687 if (parser.parseString(&constraints))
2688 return error("constraints string parsing failed");
2689 if (parser.parseRBrace())
2690 return expected("}");
2691 if (parser.parseRParen())
2692 return expected(")");
2693
2694 if (parser.parseOptionalKeyword("side_effects").succeeded())
2695 result.attributes.set("side_effects", UnitAttr::get(ctxt));
2696
2697 if (parser.parseOptionalArrow().succeeded() &&
2698 parser.parseType(resType).failed())
2699 return mlir::failure();
2700
2701 if (parser.parseOptionalAttrDict(result.attributes).failed())
2702 return mlir::failure();
2703
2704 result.attributes.set("asm_flavor", AsmFlavorAttr::get(ctxt, *flavor));
2705 result.attributes.set("asm_string", StringAttr::get(ctxt, asmString));
2706 result.attributes.set("constraints", StringAttr::get(ctxt, constraints));
2707 result.attributes.set("operand_attrs", ArrayAttr::get(ctxt, operandAttrs));
2708 result.getOrAddProperties<InlineAsmOp::Properties>().operands_segments =
2709 parser.getBuilder().getDenseI32ArrayAttr(operandsGroupSizes);
2710 if (resType)
2711 result.addTypes(TypeRange{resType});
2712
2713 return mlir::success();
2714}
2715
2716//===----------------------------------------------------------------------===//
2717// ThrowOp
2718//===----------------------------------------------------------------------===//
2719
2720mlir::LogicalResult cir::ThrowOp::verify() {
2721 // For the no-rethrow version, it must have at least the exception pointer.
2722 if (rethrows())
2723 return success();
2724
2725 if (getNumOperands() != 0) {
2726 if (getTypeInfo())
2727 return success();
2728 return emitOpError() << "'type_info' symbol attribute missing";
2729 }
2730
2731 return failure();
2732}
2733
2734//===----------------------------------------------------------------------===//
2735// AtomicCmpXchg
2736//===----------------------------------------------------------------------===//
2737
2738LogicalResult cir::AtomicCmpXchg::verify() {
2739 mlir::Type pointeeType = getPtr().getType().getPointee();
2740
2741 if (pointeeType != getExpected().getType() ||
2742 pointeeType != getDesired().getType())
2743 return emitOpError("ptr, expected and desired types must match");
2744
2745 return success();
2746}
2747
2748//===----------------------------------------------------------------------===//
2749// TypeInfoAttr
2750//===----------------------------------------------------------------------===//
2751
2752LogicalResult cir::TypeInfoAttr::verify(
2753 ::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError,
2754 ::mlir::Type type, ::mlir::ArrayAttr typeInfoData) {
2755
2756 if (cir::ConstRecordAttr::verify(emitError, type, typeInfoData).failed())
2757 return failure();
2758
2759 return success();
2760}
2761
2762//===----------------------------------------------------------------------===//
2763// TableGen'd op method definitions
2764//===----------------------------------------------------------------------===//
2765
2766#define GET_OP_CLASSES
2767#include "clang/CIR/Dialect/IR/CIROps.cpp.inc"
static const MemRegion * getRegion(const CallEvent &Call, const MutexDescriptor &Descriptor, bool IsLock)
static void printGlobalOpTypeAndInitialValue(OpAsmPrinter &p, cir::GlobalOp op, TypeAttr type, Attribute initAttr)
static mlir::LogicalResult checkReturnAndFunction(cir::ReturnOp op, cir::FuncOp function)
static bool isBoolNot(cir::UnaryOp op)
static bool isIntOrBoolCast(cir::CastOp op)
static ParseResult parseGlobalOpTypeAndInitialValue(OpAsmParser &parser, TypeAttr &typeAttr, Attribute &initialValueAttr)
static void printConstant(OpAsmPrinter &p, Attribute value)
static mlir::ParseResult parseOmittedTerminatorRegion(mlir::OpAsmParser &parser, mlir::Region &region)
void printVisibilityAttr(OpAsmPrinter &printer, cir::VisibilityAttr &visibility)
static ParseResult parseSwitchFlatOpCases(OpAsmParser &parser, Type flagType, mlir::ArrayAttr &caseValues, SmallVectorImpl< Block * > &caseDestinations, SmallVectorImpl< llvm::SmallVector< OpAsmParser::UnresolvedOperand > > &caseOperands, SmallVectorImpl< llvm::SmallVector< Type > > &caseOperandTypes)
<cases> ::= [ (case (, case )* )? ]
static LogicalResult verifyCallCommInSymbolUses(mlir::Operation *op, SymbolTableCollection &symbolTable)
static LogicalResult ensureRegionTerm(OpAsmParser &parser, Region &region, SMLoc errLoc)
static OpFoldResult foldUnaryBitOp(mlir::Attribute inputAttr, llvm::function_ref< llvm::APInt(const llvm::APInt &)> func, bool poisonZero=false)
static llvm::StringRef getLinkageAttrNameString()
Returns the name used for the linkage attribute.
static RetTy parseOptionalCIRKeyword(AsmParser &parser, EnumTy defaultValue)
Parse an enum from the keyword, or default to the provided default value.
static void printSwitchFlatOpCases(OpAsmPrinter &p, cir::SwitchFlatOp op, Type flagType, mlir::ArrayAttr caseValues, SuccessorRange caseDestinations, OperandRangeRange caseOperands, const TypeRangeRange &caseOperandTypes)
static void printSwitchOp(OpAsmPrinter &p, cir::SwitchOp op, mlir::Region &bodyRegion, mlir::Value condition, mlir::Type condType)
static ParseResult parseCIRKeyword(AsmParser &parser, RetTy &result)
Parse an enum from the keyword, return failure if the keyword is not found.
static Value tryFoldCastChain(cir::CastOp op)
void parseVisibilityAttr(OpAsmParser &parser, cir::VisibilityAttr &visibility)
static bool omitRegionTerm(mlir::Region &r)
static void printOmittedTerminatorRegion(mlir::OpAsmPrinter &printer, cir::ScopeOp &op, mlir::Region &region)
static ParseResult parseConstantValue(OpAsmParser &parser, mlir::Attribute &valueAttr)
static void printCallCommon(mlir::Operation *op, mlir::FlatSymbolRefAttr calleeSym, mlir::Value indirectCallee, mlir::OpAsmPrinter &printer, bool isNothrow, cir::SideEffect sideEffect)
static LogicalResult checkConstantTypes(mlir::Operation *op, mlir::Type opType, mlir::Attribute attrType)
static ParseResult parseSwitchOp(OpAsmParser &parser, mlir::Region &regions, mlir::OpAsmParser::UnresolvedOperand &cond, mlir::Type &condType)
static mlir::ParseResult parseCallCommon(mlir::OpAsmParser &parser, mlir::OperationState &result)
#define REGISTER_ENUM_TYPE(Ty)
static int parseOptionalKeywordAlternative(AsmParser &parser, ArrayRef< llvm::StringRef > keywords)
llvm::function_ref< void(mlir::OpBuilder &, mlir::Location)> BuilderCallbackRef
Definition CIRDialect.h:37
llvm::function_ref< void( mlir::OpBuilder &, mlir::Location, mlir::OperationState &)> BuilderOpStateCallbackRef
Definition CIRDialect.h:39
static std::optional< NonLoc > getIndex(ProgramStateRef State, const ElementRegion *ER, CharKind CK)
static Decl::Kind getKind(const Decl *D)
TokenType getType() const
Returns the token's type, e.g.
__device__ __2f16 float c
void buildTerminatedBody(mlir::OpBuilder &builder, mlir::Location loc)
const AstTypeMatcher< RecordType > recordType
StringRef getName(const HeaderType T)
Definition HeaderFile.h:38
RangeSelector name(std::string ID)
Given a node with a "name", (like NamedDecl, DeclRefExpr, CxxCtorInitializer, and TypeLoc) selects th...
nullptr
This class represents a compute construct, representing a 'Kind' of ‘parallel’, 'serial',...
__DEVICE__ _Tp arg(const std::complex< _Tp > &__c)
static bool addressSpace()
static bool opGlobalThreadLocal()
static bool opCallCallConv()
static bool opScopeCleanupRegion()
static bool supportIFuncAttr()