clang 23.0.0git
CIRAttrs.cpp
Go to the documentation of this file.
1//===- CIRAttrs.cpp - MLIR CIR Attributes ---------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file defines the attributes in the CIR dialect.
10//
11//===----------------------------------------------------------------------===//
12
14
15#include "mlir/IR/DialectImplementation.h"
16#include "llvm/ADT/TypeSwitch.h"
17
18//===-----------------------------------------------------------------===//
19// RecordMembers
20//===-----------------------------------------------------------------===//
21
22static void printRecordMembers(mlir::AsmPrinter &p, mlir::ArrayAttr members);
23static mlir::ParseResult parseRecordMembers(mlir::AsmParser &parser,
24 mlir::ArrayAttr &members);
25
26//===-----------------------------------------------------------------===//
27// IntLiteral
28//===-----------------------------------------------------------------===//
29
30static void printIntLiteral(mlir::AsmPrinter &p, llvm::APInt value,
31 cir::IntTypeInterface ty);
32static mlir::ParseResult parseIntLiteral(mlir::AsmParser &parser,
33 llvm::APInt &value,
34 cir::IntTypeInterface ty);
35//===-----------------------------------------------------------------===//
36// FloatLiteral
37//===-----------------------------------------------------------------===//
38
39static void printFloatLiteral(mlir::AsmPrinter &p, llvm::APFloat value,
40 mlir::Type ty);
41static mlir::ParseResult
42parseFloatLiteral(mlir::AsmParser &parser,
43 mlir::FailureOr<llvm::APFloat> &value,
44 cir::FPTypeInterface fpType);
45
46//===----------------------------------------------------------------------===//
47// AddressSpaceAttr
48//===----------------------------------------------------------------------===//
49
50mlir::ParseResult parseTargetAddressSpace(mlir::AsmParser &p,
51 cir::TargetAddressSpaceAttr &attr);
52
53void printTargetAddressSpace(mlir::AsmPrinter &p,
54 cir::TargetAddressSpaceAttr attr);
55
56static mlir::ParseResult parseConstPtr(mlir::AsmParser &parser,
57 mlir::IntegerAttr &value);
58
59static void printConstPtr(mlir::AsmPrinter &p, mlir::IntegerAttr value);
60
61#define GET_ATTRDEF_CLASSES
62#include "clang/CIR/Dialect/IR/CIROpsAttributes.cpp.inc"
63
64using namespace mlir;
65using namespace cir;
66
67//===----------------------------------------------------------------------===//
68// General CIR parsing / printing
69//===----------------------------------------------------------------------===//
70
71static void printRecordMembers(mlir::AsmPrinter &printer,
72 mlir::ArrayAttr members) {
73 printer << '{';
74 llvm::interleaveComma(members, printer);
75 printer << '}';
76}
77
78static ParseResult parseRecordMembers(mlir::AsmParser &parser,
79 mlir::ArrayAttr &members) {
81
82 auto delimiter = AsmParser::Delimiter::Braces;
83 auto result = parser.parseCommaSeparatedList(delimiter, [&]() {
84 mlir::TypedAttr attr;
85 if (parser.parseAttribute(attr).failed())
86 return mlir::failure();
87 elts.push_back(attr);
88 return mlir::success();
89 });
90
91 if (result.failed())
92 return mlir::failure();
93
94 members = mlir::ArrayAttr::get(parser.getContext(), elts);
95 return mlir::success();
96}
97
98//===----------------------------------------------------------------------===//
99// ConstRecordAttr definitions
100//===----------------------------------------------------------------------===//
101
102LogicalResult
103ConstRecordAttr::verify(function_ref<InFlightDiagnostic()> emitError,
104 mlir::Type type, ArrayAttr members) {
105 auto sTy = mlir::dyn_cast_if_present<cir::RecordType>(type);
106 if (!sTy)
107 return emitError() << "expected !cir.record type";
108
109 if (sTy.getMembers().size() != members.size())
110 return emitError() << "number of elements must match";
111
112 unsigned attrIdx = 0;
113 for (auto &member : sTy.getMembers()) {
114 auto m = mlir::cast<mlir::TypedAttr>(members[attrIdx]);
115 if (member != m.getType())
116 return emitError() << "element at index " << attrIdx << " has type "
117 << m.getType()
118 << " but the expected type for this element is "
119 << member;
120 attrIdx++;
121 }
122
123 return success();
124}
125
126//===----------------------------------------------------------------------===//
127// OptInfoAttr definitions
128//===----------------------------------------------------------------------===//
129
130LogicalResult OptInfoAttr::verify(function_ref<InFlightDiagnostic()> emitError,
131 unsigned level, unsigned size) {
132 if (level > 3)
133 return emitError()
134 << "optimization level must be between 0 and 3 inclusive";
135 if (size > 2)
136 return emitError()
137 << "size optimization level must be between 0 and 2 inclusive";
138 return success();
139}
140
141//===----------------------------------------------------------------------===//
142// ConstPtrAttr definitions
143//===----------------------------------------------------------------------===//
144
145// TODO(CIR): Consider encoding the null value differently and use conditional
146// assembly format instead of custom parsing/printing.
147static ParseResult parseConstPtr(AsmParser &parser, mlir::IntegerAttr &value) {
148
149 if (parser.parseOptionalKeyword("null").succeeded()) {
150 value = parser.getBuilder().getI64IntegerAttr(0);
151 return success();
152 }
153
154 return parser.parseAttribute(value);
155}
156
157static void printConstPtr(AsmPrinter &p, mlir::IntegerAttr value) {
158 if (!value.getInt())
159 p << "null";
160 else
161 p << value;
162}
163
164//===----------------------------------------------------------------------===//
165// IntAttr definitions
166//===----------------------------------------------------------------------===//
167
168template <typename IntT>
169static bool isTooLargeForType(const mlir::APInt &value, IntT expectedValue) {
170 if constexpr (std::is_signed_v<IntT>) {
171 return value.getSExtValue() != expectedValue;
172 } else {
173 return value.getZExtValue() != expectedValue;
174 }
175}
176
177template <typename IntT>
178static mlir::ParseResult parseIntLiteralImpl(mlir::AsmParser &p,
179 llvm::APInt &value,
180 cir::IntTypeInterface ty) {
181 IntT ivalue;
182 const bool isSigned = ty.isSigned();
183 if (p.parseInteger(ivalue))
184 return p.emitError(p.getCurrentLocation(), "expected integer value");
185
186 value = mlir::APInt(ty.getWidth(), ivalue, isSigned, /*implicitTrunc=*/true);
187 if (isTooLargeForType(value, ivalue))
188 return p.emitError(p.getCurrentLocation(),
189 "integer value too large for the given type");
190
191 return success();
192}
193
194mlir::ParseResult parseIntLiteral(mlir::AsmParser &parser, llvm::APInt &value,
195 cir::IntTypeInterface ty) {
196 if (ty.isSigned())
197 return parseIntLiteralImpl<int64_t>(parser, value, ty);
198 return parseIntLiteralImpl<uint64_t>(parser, value, ty);
199}
200
201void printIntLiteral(mlir::AsmPrinter &p, llvm::APInt value,
202 cir::IntTypeInterface ty) {
203 if (ty.isSigned())
204 p << value.getSExtValue();
205 else
206 p << value.getZExtValue();
207}
208
209LogicalResult IntAttr::verify(function_ref<InFlightDiagnostic()> emitError,
210 cir::IntTypeInterface type, llvm::APInt value) {
211 if (value.getBitWidth() != type.getWidth())
212 return emitError() << "type and value bitwidth mismatch: "
213 << type.getWidth() << " != " << value.getBitWidth();
214 return success();
215}
216
217//===----------------------------------------------------------------------===//
218// FPAttr definitions
219//===----------------------------------------------------------------------===//
220
221static void printFloatLiteral(AsmPrinter &p, APFloat value, Type ty) {
222 p << value;
223}
224
225static ParseResult parseFloatLiteral(AsmParser &parser,
226 FailureOr<APFloat> &value,
227 cir::FPTypeInterface fpType) {
228
229 APFloat parsedValue(0.0);
230 if (parser.parseFloat(fpType.getFloatSemantics(), parsedValue))
231 return failure();
232
233 value.emplace(parsedValue);
234 return success();
235}
236
237FPAttr FPAttr::getZero(Type type) {
238 return get(type,
239 APFloat::getZero(
240 mlir::cast<cir::FPTypeInterface>(type).getFloatSemantics()));
241}
242
243LogicalResult FPAttr::verify(function_ref<InFlightDiagnostic()> emitError,
244 cir::FPTypeInterface fpType, APFloat value) {
245 if (APFloat::SemanticsToEnum(fpType.getFloatSemantics()) !=
246 APFloat::SemanticsToEnum(value.getSemantics()))
247 return emitError() << "floating-point semantics mismatch";
248
249 return success();
250}
251
252//===----------------------------------------------------------------------===//
253// ConstComplexAttr definitions
254//===----------------------------------------------------------------------===//
255
256LogicalResult
257ConstComplexAttr::verify(function_ref<InFlightDiagnostic()> emitError,
258 cir::ComplexType type, mlir::TypedAttr real,
259 mlir::TypedAttr imag) {
260 mlir::Type elemType = type.getElementType();
261 if (real.getType() != elemType)
262 return emitError()
263 << "type of the real part does not match the complex type";
264
265 if (imag.getType() != elemType)
266 return emitError()
267 << "type of the imaginary part does not match the complex type";
268
269 return success();
270}
271
272//===----------------------------------------------------------------------===//
273// DataMemberAttr definitions
274//===----------------------------------------------------------------------===//
275
276LogicalResult
277DataMemberAttr::verify(function_ref<InFlightDiagnostic()> emitError,
278 cir::DataMemberType ty,
279 std::optional<unsigned> memberIndex) {
280 // DataMemberAttr without a given index represents a null value.
281 if (!memberIndex.has_value())
282 return success();
283
284 cir::RecordType recTy = ty.getClassTy();
285 if (recTy.isIncomplete())
286 return emitError()
287 << "incomplete 'cir.record' cannot be used to build a non-null "
288 "data member pointer";
289
290 unsigned memberIndexValue = memberIndex.value();
291 if (memberIndexValue >= recTy.getNumElements())
292 return emitError()
293 << "member index of a #cir.data_member attribute is out of range";
294
295 mlir::Type memberTy = recTy.getMembers()[memberIndexValue];
296 if (memberTy != ty.getMemberTy())
297 return emitError()
298 << "member type of a #cir.data_member attribute must match the "
299 "attribute type";
300
301 return success();
302}
303
304//===----------------------------------------------------------------------===//
305// MethodAttr definitions
306//===----------------------------------------------------------------------===//
307
308LogicalResult MethodAttr::verify(function_ref<InFlightDiagnostic()> emitError,
309 cir::MethodType type,
310 std::optional<FlatSymbolRefAttr> symbol,
311 std::optional<uint64_t> vtable_offset) {
312 if (symbol.has_value() && vtable_offset.has_value())
313 return emitError()
314 << "at most one of symbol and vtable_offset can be present "
315 "in #cir.method";
316
317 return success();
318}
319
320Attribute MethodAttr::parse(AsmParser &parser, Type odsType) {
321 auto ty = mlir::cast<cir::MethodType>(odsType);
322
323 if (parser.parseLess().failed())
324 return {};
325
326 // Try to parse the null pointer constant.
327 if (parser.parseOptionalKeyword("null").succeeded()) {
328 if (parser.parseGreater().failed())
329 return {};
330 return get(ty);
331 }
332
333 // Try to parse a flat symbol ref for a pointer to non-virtual member
334 // function.
335 FlatSymbolRefAttr symbol;
336 mlir::OptionalParseResult parseSymbolRefResult =
337 parser.parseOptionalAttribute(symbol);
338 if (parseSymbolRefResult.has_value()) {
339 if (parseSymbolRefResult.value().failed())
340 return {};
341 if (parser.parseGreater().failed())
342 return {};
343 return get(ty, symbol);
344 }
345
346 // Parse a uint64 that represents the vtable offset.
347 std::uint64_t vtableOffset = 0;
348 if (parser.parseKeyword("vtable_offset"))
349 return {};
350 if (parser.parseEqual())
351 return {};
352 if (parser.parseInteger(vtableOffset))
353 return {};
354
355 if (parser.parseGreater())
356 return {};
357
358 return get(ty, vtableOffset);
359}
360
361void MethodAttr::print(AsmPrinter &printer) const {
362 auto symbol = getSymbol();
363 auto vtableOffset = getVtableOffset();
364
365 printer << '<';
366 if (symbol.has_value()) {
367 printer << *symbol;
368 } else if (vtableOffset.has_value()) {
369 printer << "vtable_offset = " << *vtableOffset;
370 } else {
371 printer << "null";
372 }
373 printer << '>';
374}
375
376//===----------------------------------------------------------------------===//
377// CIR ConstArrayAttr
378//===----------------------------------------------------------------------===//
379
380LogicalResult
381ConstArrayAttr::verify(function_ref<InFlightDiagnostic()> emitError, Type type,
382 Attribute elts, int trailingZerosNum) {
383
384 if (!(mlir::isa<ArrayAttr, StringAttr>(elts)))
385 return emitError() << "constant array expects ArrayAttr or StringAttr";
386
387 if (auto strAttr = mlir::dyn_cast<StringAttr>(elts)) {
388 const auto arrayTy = mlir::cast<ArrayType>(type);
389 const auto intTy = mlir::dyn_cast<IntType>(arrayTy.getElementType());
390
391 // TODO: add CIR type for char.
392 if (!intTy || intTy.getWidth() != 8)
393 return emitError()
394 << "constant array element for string literals expects "
395 "!cir.int<u, 8> element type";
396 return success();
397 }
398
399 assert(mlir::isa<ArrayAttr>(elts));
400 const auto arrayAttr = mlir::cast<mlir::ArrayAttr>(elts);
401 const auto arrayTy = mlir::cast<ArrayType>(type);
402
403 // Make sure both number of elements and subelement types match type.
404 if (arrayTy.getSize() != arrayAttr.size() + trailingZerosNum)
405 return emitError() << "constant array size should match type size";
406 return success();
407}
408
409Attribute ConstArrayAttr::parse(AsmParser &parser, Type type) {
410 mlir::FailureOr<Type> resultTy;
411 mlir::FailureOr<Attribute> resultVal;
412
413 // Parse literal '<'
414 if (parser.parseLess())
415 return {};
416
417 // Parse variable 'value'
418 resultVal = FieldParser<Attribute>::parse(parser);
419 if (failed(resultVal)) {
420 parser.emitError(
421 parser.getCurrentLocation(),
422 "failed to parse ConstArrayAttr parameter 'value' which is "
423 "to be a `Attribute`");
424 return {};
425 }
426
427 // ArrayAttrrs have per-element type, not the type of the array...
428 if (mlir::isa<ArrayAttr>(*resultVal)) {
429 // Array has implicit type: infer from const array type.
430 if (parser.parseOptionalColon().failed()) {
431 resultTy = type;
432 } else { // Array has explicit type: parse it.
433 resultTy = FieldParser<Type>::parse(parser);
434 if (failed(resultTy)) {
435 parser.emitError(
436 parser.getCurrentLocation(),
437 "failed to parse ConstArrayAttr parameter 'type' which is "
438 "to be a `::mlir::Type`");
439 return {};
440 }
441 }
442 } else {
443 auto ta = mlir::cast<TypedAttr>(*resultVal);
444 resultTy = ta.getType();
445 if (mlir::isa<mlir::NoneType>(*resultTy)) {
446 parser.emitError(parser.getCurrentLocation(),
447 "expected type declaration for string literal");
448 return {};
449 }
450 }
451
452 unsigned zeros = 0;
453 if (parser.parseOptionalComma().succeeded()) {
454 if (parser.parseOptionalKeyword("trailing_zeros").succeeded()) {
455 unsigned typeSize =
456 mlir::cast<cir::ArrayType>(resultTy.value()).getSize();
457 mlir::Attribute elts = resultVal.value();
458 if (auto str = mlir::dyn_cast<mlir::StringAttr>(elts))
459 zeros = typeSize - str.size();
460 else
461 zeros = typeSize - mlir::cast<mlir::ArrayAttr>(elts).size();
462 } else {
463 return {};
464 }
465 }
466
467 // Parse literal '>'
468 if (parser.parseGreater())
469 return {};
470
471 return parser.getChecked<ConstArrayAttr>(
472 parser.getCurrentLocation(), parser.getContext(), resultTy.value(),
473 resultVal.value(), zeros);
474}
475
476void ConstArrayAttr::print(AsmPrinter &printer) const {
477 printer << "<";
478 printer.printStrippedAttrOrType(getElts());
479 if (getTrailingZerosNum())
480 printer << ", trailing_zeros";
481 printer << ">";
482}
483
484//===----------------------------------------------------------------------===//
485// CIR ConstVectorAttr
486//===----------------------------------------------------------------------===//
487
488LogicalResult
489cir::ConstVectorAttr::verify(function_ref<InFlightDiagnostic()> emitError,
490 Type type, ArrayAttr elts) {
491
492 if (!mlir::isa<cir::VectorType>(type))
493 return emitError() << "type of cir::ConstVectorAttr is not a "
494 "cir::VectorType: "
495 << type;
496
497 const auto vecType = mlir::cast<cir::VectorType>(type);
498
499 if (vecType.getSize() != elts.size())
500 return emitError()
501 << "number of constant elements should match vector size";
502
503 // Check if the types of the elements match
504 LogicalResult elementTypeCheck = success();
505 elts.walkImmediateSubElements(
506 [&](Attribute element) {
507 if (elementTypeCheck.failed()) {
508 // An earlier element didn't match
509 return;
510 }
511 auto typedElement = mlir::dyn_cast<TypedAttr>(element);
512 if (!typedElement ||
513 typedElement.getType() != vecType.getElementType()) {
514 elementTypeCheck = failure();
515 emitError() << "constant type should match vector element type";
516 }
517 },
518 [&](Type) {});
519
520 return elementTypeCheck;
521}
522
523//===----------------------------------------------------------------------===//
524// CIR VTableAttr
525//===----------------------------------------------------------------------===//
526
527LogicalResult cir::VTableAttr::verify(
528 llvm::function_ref<mlir::InFlightDiagnostic()> emitError, mlir::Type type,
529 mlir::ArrayAttr data) {
530 auto sTy = mlir::dyn_cast_if_present<cir::RecordType>(type);
531 if (!sTy)
532 return emitError() << "expected !cir.record type result";
533 if (sTy.getMembers().empty() || data.empty())
534 return emitError() << "expected record type with one or more subtype";
535
536 if (cir::ConstRecordAttr::verify(emitError, type, data).failed())
537 return failure();
538
539 for (const auto &element : data.getAsRange<mlir::Attribute>()) {
540 const auto &constArrayAttr = mlir::dyn_cast<cir::ConstArrayAttr>(element);
541 if (!constArrayAttr)
542 return emitError() << "expected constant array subtype";
543
544 LogicalResult eltTypeCheck = success();
545 auto arrayElts = mlir::cast<ArrayAttr>(constArrayAttr.getElts());
546 arrayElts.walkImmediateSubElements(
547 [&](mlir::Attribute attr) {
548 if (mlir::isa<ConstPtrAttr, GlobalViewAttr>(attr))
549 return;
550
551 eltTypeCheck = emitError()
552 << "expected GlobalViewAttr or ConstPtrAttr";
553 },
554 [&](mlir::Type type) {});
555 if (eltTypeCheck.failed())
556 return eltTypeCheck;
557 }
558 return success();
559}
560
561//===----------------------------------------------------------------------===//
562// DynamicCastInfoAttr definitions
563//===----------------------------------------------------------------------===//
564
565std::string DynamicCastInfoAttr::getAlias() const {
566 // The alias looks like: `dyn_cast_info_<src>_<dest>`
567
568 std::string alias = "dyn_cast_info_";
569
570 alias.append(getSrcRtti().getSymbol().getValue());
571 alias.push_back('_');
572 alias.append(getDestRtti().getSymbol().getValue());
573
574 return alias;
575}
576
577LogicalResult DynamicCastInfoAttr::verify(
578 function_ref<InFlightDiagnostic()> emitError, cir::GlobalViewAttr srcRtti,
579 cir::GlobalViewAttr destRtti, mlir::FlatSymbolRefAttr runtimeFunc,
580 mlir::FlatSymbolRefAttr badCastFunc, cir::IntAttr offsetHint) {
581 auto isRttiPtr = [](mlir::Type ty) {
582 // RTTI pointers are !cir.ptr<!u8i>.
583
584 auto ptrTy = mlir::dyn_cast<cir::PointerType>(ty);
585 if (!ptrTy)
586 return false;
587
588 auto pointeeIntTy = mlir::dyn_cast<cir::IntType>(ptrTy.getPointee());
589 if (!pointeeIntTy)
590 return false;
591
592 return pointeeIntTy.isUnsigned() && pointeeIntTy.getWidth() == 8;
593 };
594
595 if (!isRttiPtr(srcRtti.getType()))
596 return emitError() << "srcRtti must be an RTTI pointer";
597
598 if (!isRttiPtr(destRtti.getType()))
599 return emitError() << "destRtti must be an RTTI pointer";
600
601 return success();
602}
603
604//===----------------------------------------------------------------------===//
605// CIR Dialect
606//===----------------------------------------------------------------------===//
607
608void CIRDialect::registerAttributes() {
609 addAttributes<
610#define GET_ATTRDEF_LIST
611#include "clang/CIR/Dialect/IR/CIROpsAttributes.cpp.inc"
612 >();
613}
static mlir::ParseResult parseFloatLiteral(mlir::AsmParser &parser, mlir::FailureOr< llvm::APFloat > &value, cir::FPTypeInterface fpType)
static mlir::ParseResult parseIntLiteralImpl(mlir::AsmParser &p, llvm::APInt &value, cir::IntTypeInterface ty)
Definition CIRAttrs.cpp:178
static void printConstPtr(mlir::AsmPrinter &p, mlir::IntegerAttr value)
static void printRecordMembers(mlir::AsmPrinter &p, mlir::ArrayAttr members)
Definition CIRAttrs.cpp:71
static mlir::ParseResult parseIntLiteral(mlir::AsmParser &parser, llvm::APInt &value, cir::IntTypeInterface ty)
Definition CIRAttrs.cpp:194
mlir::ParseResult parseTargetAddressSpace(mlir::AsmParser &p, cir::TargetAddressSpaceAttr &attr)
Definition CIRTypes.cpp:957
static void printIntLiteral(mlir::AsmPrinter &p, llvm::APInt value, cir::IntTypeInterface ty)
Definition CIRAttrs.cpp:201
static mlir::ParseResult parseConstPtr(mlir::AsmParser &parser, mlir::IntegerAttr &value)
static bool isTooLargeForType(const mlir::APInt &value, IntT expectedValue)
Definition CIRAttrs.cpp:169
static void printFloatLiteral(mlir::AsmPrinter &p, llvm::APFloat value, mlir::Type ty)
static mlir::ParseResult parseRecordMembers(mlir::AsmParser &parser, mlir::ArrayAttr &members)
Definition CIRAttrs.cpp:78
void printTargetAddressSpace(mlir::AsmPrinter &p, cir::TargetAddressSpaceAttr attr)
Definition CIRTypes.cpp:982
RangeSelector member(std::string ID)
Given a MemberExpr, selects the member token. ID is the node's binding in the match result.