clang 23.0.0git
CIRAttrs.cpp
Go to the documentation of this file.
1//===- CIRAttrs.cpp - MLIR CIR Attributes ---------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file defines the attributes in the CIR dialect.
10//
11//===----------------------------------------------------------------------===//
12
14
15#include "mlir/IR/DialectImplementation.h"
16#include "llvm/ADT/TypeSwitch.h"
17
//===-----------------------------------------------------------------===//
// RecordMembers
//===-----------------------------------------------------------------===//

// The custom parse/print hooks in this section are declared before the
// generated attribute definitions are included (at the end of this section),
// so generated code can refer to them. NOTE(review): presumably they back
// `custom<...>` directives in the attribute assembly formats — confirm
// against the attribute TableGen definitions.

static void printRecordMembers(mlir::AsmPrinter &p, mlir::ArrayAttr members);
static mlir::ParseResult parseRecordMembers(mlir::AsmParser &parser,
                                            mlir::ArrayAttr &members);

//===-----------------------------------------------------------------===//
// IntLiteral
//===-----------------------------------------------------------------===//

static void printIntLiteral(mlir::AsmPrinter &p, llvm::APInt value,
                            cir::IntTypeInterface ty);
static mlir::ParseResult parseIntLiteral(mlir::AsmParser &parser,
                                         llvm::APInt &value,
                                         cir::IntTypeInterface ty);
//===-----------------------------------------------------------------===//
// FloatLiteral
//===-----------------------------------------------------------------===//

static void printFloatLiteral(mlir::AsmPrinter &p, llvm::APFloat value,
                              mlir::Type ty);
static mlir::ParseResult
parseFloatLiteral(mlir::AsmParser &parser,
                  mlir::FailureOr<llvm::APFloat> &value,
                  cir::FPTypeInterface fpType);

//===----------------------------------------------------------------------===//
// AddressSpaceAttr
//===----------------------------------------------------------------------===//

// These two hooks have external linkage (no `static`); their definitions live
// in CIRTypes.cpp, not in this file.
mlir::ParseResult parseTargetAddressSpace(mlir::AsmParser &p,
                                          cir::TargetAddressSpaceAttr &attr);

void printTargetAddressSpace(mlir::AsmPrinter &p,
                             cir::TargetAddressSpaceAttr attr);

// ConstPtrAttr hooks: the `null` keyword is accepted and printed in place of
// a zero integer value.
static mlir::ParseResult parseConstPtr(mlir::AsmParser &parser,
                                       mlir::IntegerAttr &value);

static void printConstPtr(mlir::AsmPrinter &p, mlir::IntegerAttr value);

// Pull in the generated attribute class definitions.
#define GET_ATTRDEF_CLASSES
#include "clang/CIR/Dialect/IR/CIROpsAttributes.cpp.inc"
63
64using namespace mlir;
65using namespace cir;
66
67//===----------------------------------------------------------------------===//
68// General CIR parsing / printing
69//===----------------------------------------------------------------------===//
70
71static void printRecordMembers(mlir::AsmPrinter &printer,
72 mlir::ArrayAttr members) {
73 printer << '{';
74 llvm::interleaveComma(members, printer);
75 printer << '}';
76}
77
78static ParseResult parseRecordMembers(mlir::AsmParser &parser,
79 mlir::ArrayAttr &members) {
81
82 auto delimiter = AsmParser::Delimiter::Braces;
83 auto result = parser.parseCommaSeparatedList(delimiter, [&]() {
84 mlir::TypedAttr attr;
85 if (parser.parseAttribute(attr).failed())
86 return mlir::failure();
87 elts.push_back(attr);
88 return mlir::success();
89 });
90
91 if (result.failed())
92 return mlir::failure();
93
94 members = mlir::ArrayAttr::get(parser.getContext(), elts);
95 return mlir::success();
96}
97
98//===----------------------------------------------------------------------===//
99// ConstRecordAttr definitions
100//===----------------------------------------------------------------------===//
101
102LogicalResult
103ConstRecordAttr::verify(function_ref<InFlightDiagnostic()> emitError,
104 mlir::Type type, ArrayAttr members) {
105 auto sTy = mlir::dyn_cast_if_present<cir::RecordType>(type);
106 if (!sTy)
107 return emitError() << "expected !cir.record type";
108
109 if (sTy.getMembers().size() != members.size())
110 return emitError() << "number of elements must match";
111
112 unsigned attrIdx = 0;
113 for (auto &member : sTy.getMembers()) {
114 auto m = mlir::cast<mlir::TypedAttr>(members[attrIdx]);
115 if (member != m.getType())
116 return emitError() << "element at index " << attrIdx << " has type "
117 << m.getType()
118 << " but the expected type for this element is "
119 << member;
120 attrIdx++;
121 }
122
123 return success();
124}
125
126//===----------------------------------------------------------------------===//
127// OptInfoAttr definitions
128//===----------------------------------------------------------------------===//
129
130LogicalResult OptInfoAttr::verify(function_ref<InFlightDiagnostic()> emitError,
131 unsigned level, unsigned size) {
132 if (level > 3)
133 return emitError()
134 << "optimization level must be between 0 and 3 inclusive";
135 if (size > 2)
136 return emitError()
137 << "size optimization level must be between 0 and 2 inclusive";
138 return success();
139}
140
141//===----------------------------------------------------------------------===//
142// ConstPtrAttr definitions
143//===----------------------------------------------------------------------===//
144
145// TODO(CIR): Consider encoding the null value differently and use conditional
146// assembly format instead of custom parsing/printing.
147static ParseResult parseConstPtr(AsmParser &parser, mlir::IntegerAttr &value) {
148
149 if (parser.parseOptionalKeyword("null").succeeded()) {
150 value = parser.getBuilder().getI64IntegerAttr(0);
151 return success();
152 }
153
154 return parser.parseAttribute(value);
155}
156
157static void printConstPtr(AsmPrinter &p, mlir::IntegerAttr value) {
158 if (!value.getInt())
159 p << "null";
160 else
161 p << value;
162}
163
164//===----------------------------------------------------------------------===//
165// IntAttr definitions
166//===----------------------------------------------------------------------===//
167
168template <typename IntT>
169static bool isTooLargeForType(const mlir::APInt &value, IntT expectedValue) {
170 if constexpr (std::is_signed_v<IntT>) {
171 return value.getSExtValue() != expectedValue;
172 } else {
173 return value.getZExtValue() != expectedValue;
174 }
175}
176
177template <typename IntT>
178static mlir::ParseResult parseIntLiteralImpl(mlir::AsmParser &p,
179 llvm::APInt &value,
180 cir::IntTypeInterface ty) {
181 IntT ivalue;
182 const bool isSigned = ty.isSigned();
183 if (p.parseInteger(ivalue))
184 return p.emitError(p.getCurrentLocation(), "expected integer value");
185
186 value = mlir::APInt(ty.getWidth(), ivalue, isSigned, /*implicitTrunc=*/true);
187 if (isTooLargeForType(value, ivalue))
188 return p.emitError(p.getCurrentLocation(),
189 "integer value too large for the given type");
190
191 return success();
192}
193
194mlir::ParseResult parseIntLiteral(mlir::AsmParser &parser, llvm::APInt &value,
195 cir::IntTypeInterface ty) {
196 if (ty.isSigned())
197 return parseIntLiteralImpl<int64_t>(parser, value, ty);
198 return parseIntLiteralImpl<uint64_t>(parser, value, ty);
199}
200
201void printIntLiteral(mlir::AsmPrinter &p, llvm::APInt value,
202 cir::IntTypeInterface ty) {
203 if (ty.isSigned())
204 p << value.getSExtValue();
205 else
206 p << value.getZExtValue();
207}
208
209LogicalResult IntAttr::verify(function_ref<InFlightDiagnostic()> emitError,
210 cir::IntTypeInterface type, llvm::APInt value) {
211 if (value.getBitWidth() != type.getWidth())
212 return emitError() << "type and value bitwidth mismatch: "
213 << type.getWidth() << " != " << value.getBitWidth();
214 return success();
215}
216
217//===----------------------------------------------------------------------===//
218// FPAttr definitions
219//===----------------------------------------------------------------------===//
220
221static void printFloatLiteral(AsmPrinter &p, APFloat value, Type ty) {
222 p << value;
223}
224
225static ParseResult parseFloatLiteral(AsmParser &parser,
226 FailureOr<APFloat> &value,
227 cir::FPTypeInterface fpType) {
228
229 APFloat parsedValue(0.0);
230 if (parser.parseFloat(fpType.getFloatSemantics(), parsedValue))
231 return failure();
232
233 value.emplace(parsedValue);
234 return success();
235}
236
237FPAttr FPAttr::getZero(Type type) {
238 return get(type,
239 APFloat::getZero(
240 mlir::cast<cir::FPTypeInterface>(type).getFloatSemantics()));
241}
242
243LogicalResult FPAttr::verify(function_ref<InFlightDiagnostic()> emitError,
244 cir::FPTypeInterface fpType, APFloat value) {
245 if (APFloat::SemanticsToEnum(fpType.getFloatSemantics()) !=
246 APFloat::SemanticsToEnum(value.getSemantics()))
247 return emitError() << "floating-point semantics mismatch";
248
249 return success();
250}
251
252//===----------------------------------------------------------------------===//
253// ConstComplexAttr definitions
254//===----------------------------------------------------------------------===//
255
256LogicalResult
257ConstComplexAttr::verify(function_ref<InFlightDiagnostic()> emitError,
258 cir::ComplexType type, mlir::TypedAttr real,
259 mlir::TypedAttr imag) {
260 mlir::Type elemType = type.getElementType();
261 if (real.getType() != elemType)
262 return emitError()
263 << "type of the real part does not match the complex type";
264
265 if (imag.getType() != elemType)
266 return emitError()
267 << "type of the imaginary part does not match the complex type";
268
269 return success();
270}
271
272//===----------------------------------------------------------------------===//
273// DataMemberAttr definitions
274//===----------------------------------------------------------------------===//
275
276LogicalResult
277DataMemberAttr::verify(function_ref<InFlightDiagnostic()> emitError,
278 cir::DataMemberType ty,
279 std::optional<unsigned> memberIndex) {
280 // DataMemberAttr without a given index represents a null value.
281 if (!memberIndex.has_value())
282 return success();
283
284 cir::RecordType recTy = ty.getClassTy();
285 if (recTy.isIncomplete())
286 return emitError()
287 << "incomplete 'cir.record' cannot be used to build a non-null "
288 "data member pointer";
289
290 unsigned memberIndexValue = memberIndex.value();
291 if (memberIndexValue >= recTy.getNumElements())
292 return emitError()
293 << "member index of a #cir.data_member attribute is out of range";
294
295 mlir::Type memberTy = recTy.getMembers()[memberIndexValue];
296 if (memberTy != ty.getMemberTy())
297 return emitError()
298 << "member type of a #cir.data_member attribute must match the "
299 "attribute type";
300
301 return success();
302}
303
304//===----------------------------------------------------------------------===//
305// MethodAttr definitions
306//===----------------------------------------------------------------------===//
307
308Attribute MethodAttr::parse(AsmParser &parser, Type odsType) {
309 auto ty = mlir::cast<cir::MethodType>(odsType);
310
311 if (parser.parseLess().failed())
312 return {};
313
314 // Try to parse the null pointer constant.
315 if (parser.parseOptionalKeyword("null").succeeded()) {
316 if (parser.parseGreater().failed())
317 return {};
318 return get(ty);
319 }
320
321 // Try to parse a flat symbol ref for a pointer to non-virtual member
322 // function.
323 FlatSymbolRefAttr symbol;
324 mlir::OptionalParseResult parseSymbolRefResult =
325 parser.parseOptionalAttribute(symbol);
326 if (parseSymbolRefResult.has_value()) {
327 if (parseSymbolRefResult.value().failed())
328 return {};
329 if (parser.parseGreater().failed())
330 return {};
331 return get(ty, symbol);
332 }
333
334 return {};
335}
336
337void MethodAttr::print(AsmPrinter &printer) const {
338 auto symbol = getSymbol();
339
340 printer << '<';
341 if (symbol.has_value()) {
342 printer << *symbol;
343 } else {
344 printer << "null";
345 }
346 printer << '>';
347}
348
349//===----------------------------------------------------------------------===//
350// CIR ConstArrayAttr
351//===----------------------------------------------------------------------===//
352
353LogicalResult
354ConstArrayAttr::verify(function_ref<InFlightDiagnostic()> emitError, Type type,
355 Attribute elts, int trailingZerosNum) {
356
357 if (!(mlir::isa<ArrayAttr, StringAttr>(elts)))
358 return emitError() << "constant array expects ArrayAttr or StringAttr";
359
360 if (auto strAttr = mlir::dyn_cast<StringAttr>(elts)) {
361 const auto arrayTy = mlir::cast<ArrayType>(type);
362 const auto intTy = mlir::dyn_cast<IntType>(arrayTy.getElementType());
363
364 // TODO: add CIR type for char.
365 if (!intTy || intTy.getWidth() != 8)
366 return emitError()
367 << "constant array element for string literals expects "
368 "!cir.int<u, 8> element type";
369 return success();
370 }
371
372 assert(mlir::isa<ArrayAttr>(elts));
373 const auto arrayAttr = mlir::cast<mlir::ArrayAttr>(elts);
374 const auto arrayTy = mlir::cast<ArrayType>(type);
375
376 // Make sure both number of elements and subelement types match type.
377 if (arrayTy.getSize() != arrayAttr.size() + trailingZerosNum)
378 return emitError() << "constant array size should match type size";
379 return success();
380}
381
382Attribute ConstArrayAttr::parse(AsmParser &parser, Type type) {
383 mlir::FailureOr<Type> resultTy;
384 mlir::FailureOr<Attribute> resultVal;
385
386 // Parse literal '<'
387 if (parser.parseLess())
388 return {};
389
390 // Parse variable 'value'
391 resultVal = FieldParser<Attribute>::parse(parser);
392 if (failed(resultVal)) {
393 parser.emitError(
394 parser.getCurrentLocation(),
395 "failed to parse ConstArrayAttr parameter 'value' which is "
396 "to be a `Attribute`");
397 return {};
398 }
399
400 // ArrayAttrrs have per-element type, not the type of the array...
401 if (mlir::isa<ArrayAttr>(*resultVal)) {
402 // Array has implicit type: infer from const array type.
403 if (parser.parseOptionalColon().failed()) {
404 resultTy = type;
405 } else { // Array has explicit type: parse it.
406 resultTy = FieldParser<Type>::parse(parser);
407 if (failed(resultTy)) {
408 parser.emitError(
409 parser.getCurrentLocation(),
410 "failed to parse ConstArrayAttr parameter 'type' which is "
411 "to be a `::mlir::Type`");
412 return {};
413 }
414 }
415 } else {
416 auto ta = mlir::cast<TypedAttr>(*resultVal);
417 resultTy = ta.getType();
418 if (mlir::isa<mlir::NoneType>(*resultTy)) {
419 parser.emitError(parser.getCurrentLocation(),
420 "expected type declaration for string literal");
421 return {};
422 }
423 }
424
425 unsigned zeros = 0;
426 if (parser.parseOptionalComma().succeeded()) {
427 if (parser.parseOptionalKeyword("trailing_zeros").succeeded()) {
428 unsigned typeSize =
429 mlir::cast<cir::ArrayType>(resultTy.value()).getSize();
430 mlir::Attribute elts = resultVal.value();
431 if (auto str = mlir::dyn_cast<mlir::StringAttr>(elts))
432 zeros = typeSize - str.size();
433 else
434 zeros = typeSize - mlir::cast<mlir::ArrayAttr>(elts).size();
435 } else {
436 return {};
437 }
438 }
439
440 // Parse literal '>'
441 if (parser.parseGreater())
442 return {};
443
444 return parser.getChecked<ConstArrayAttr>(
445 parser.getCurrentLocation(), parser.getContext(), resultTy.value(),
446 resultVal.value(), zeros);
447}
448
449void ConstArrayAttr::print(AsmPrinter &printer) const {
450 printer << "<";
451 printer.printStrippedAttrOrType(getElts());
452 if (getTrailingZerosNum())
453 printer << ", trailing_zeros";
454 printer << ">";
455}
456
457//===----------------------------------------------------------------------===//
458// CIR ConstVectorAttr
459//===----------------------------------------------------------------------===//
460
461LogicalResult
462cir::ConstVectorAttr::verify(function_ref<InFlightDiagnostic()> emitError,
463 Type type, ArrayAttr elts) {
464
465 if (!mlir::isa<cir::VectorType>(type))
466 return emitError() << "type of cir::ConstVectorAttr is not a "
467 "cir::VectorType: "
468 << type;
469
470 const auto vecType = mlir::cast<cir::VectorType>(type);
471
472 if (vecType.getSize() != elts.size())
473 return emitError()
474 << "number of constant elements should match vector size";
475
476 // Check if the types of the elements match
477 LogicalResult elementTypeCheck = success();
478 elts.walkImmediateSubElements(
479 [&](Attribute element) {
480 if (elementTypeCheck.failed()) {
481 // An earlier element didn't match
482 return;
483 }
484 auto typedElement = mlir::dyn_cast<TypedAttr>(element);
485 if (!typedElement ||
486 typedElement.getType() != vecType.getElementType()) {
487 elementTypeCheck = failure();
488 emitError() << "constant type should match vector element type";
489 }
490 },
491 [&](Type) {});
492
493 return elementTypeCheck;
494}
495
496//===----------------------------------------------------------------------===//
497// CIR VTableAttr
498//===----------------------------------------------------------------------===//
499
500LogicalResult cir::VTableAttr::verify(
501 llvm::function_ref<mlir::InFlightDiagnostic()> emitError, mlir::Type type,
502 mlir::ArrayAttr data) {
503 auto sTy = mlir::dyn_cast_if_present<cir::RecordType>(type);
504 if (!sTy)
505 return emitError() << "expected !cir.record type result";
506 if (sTy.getMembers().empty() || data.empty())
507 return emitError() << "expected record type with one or more subtype";
508
509 if (cir::ConstRecordAttr::verify(emitError, type, data).failed())
510 return failure();
511
512 for (const auto &element : data.getAsRange<mlir::Attribute>()) {
513 const auto &constArrayAttr = mlir::dyn_cast<cir::ConstArrayAttr>(element);
514 if (!constArrayAttr)
515 return emitError() << "expected constant array subtype";
516
517 LogicalResult eltTypeCheck = success();
518 auto arrayElts = mlir::cast<ArrayAttr>(constArrayAttr.getElts());
519 arrayElts.walkImmediateSubElements(
520 [&](mlir::Attribute attr) {
521 if (mlir::isa<ConstPtrAttr, GlobalViewAttr>(attr))
522 return;
523
524 eltTypeCheck = emitError()
525 << "expected GlobalViewAttr or ConstPtrAttr";
526 },
527 [&](mlir::Type type) {});
528 if (eltTypeCheck.failed())
529 return eltTypeCheck;
530 }
531 return success();
532}
533
534//===----------------------------------------------------------------------===//
535// DynamicCastInfoAtttr definitions
536//===----------------------------------------------------------------------===//
537
538std::string DynamicCastInfoAttr::getAlias() const {
539 // The alias looks like: `dyn_cast_info_<src>_<dest>`
540
541 std::string alias = "dyn_cast_info_";
542
543 alias.append(getSrcRtti().getSymbol().getValue());
544 alias.push_back('_');
545 alias.append(getDestRtti().getSymbol().getValue());
546
547 return alias;
548}
549
550LogicalResult DynamicCastInfoAttr::verify(
551 function_ref<InFlightDiagnostic()> emitError, cir::GlobalViewAttr srcRtti,
552 cir::GlobalViewAttr destRtti, mlir::FlatSymbolRefAttr runtimeFunc,
553 mlir::FlatSymbolRefAttr badCastFunc, cir::IntAttr offsetHint) {
554 auto isRttiPtr = [](mlir::Type ty) {
555 // RTTI pointers are !cir.ptr<!u8i>.
556
557 auto ptrTy = mlir::dyn_cast<cir::PointerType>(ty);
558 if (!ptrTy)
559 return false;
560
561 auto pointeeIntTy = mlir::dyn_cast<cir::IntType>(ptrTy.getPointee());
562 if (!pointeeIntTy)
563 return false;
564
565 return pointeeIntTy.isUnsigned() && pointeeIntTy.getWidth() == 8;
566 };
567
568 if (!isRttiPtr(srcRtti.getType()))
569 return emitError() << "srcRtti must be an RTTI pointer";
570
571 if (!isRttiPtr(destRtti.getType()))
572 return emitError() << "destRtti must be an RTTI pointer";
573
574 return success();
575}
576
577//===----------------------------------------------------------------------===//
578// CIR Dialect
579//===----------------------------------------------------------------------===//
580
/// Registers the CIR attribute classes with the dialect. The template
/// argument list is expanded from the GET_ATTRDEF_LIST section of the
/// generated CIROpsAttributes.cpp.inc, so it stays in sync with the
/// generated definitions included at the top of this file.
void CIRDialect::registerAttributes() {
  addAttributes<
#define GET_ATTRDEF_LIST
#include "clang/CIR/Dialect/IR/CIROpsAttributes.cpp.inc"
      >();
}
static mlir::ParseResult parseFloatLiteral(mlir::AsmParser &parser, mlir::FailureOr< llvm::APFloat > &value, cir::FPTypeInterface fpType)
static mlir::ParseResult parseIntLiteralImpl(mlir::AsmParser &p, llvm::APInt &value, cir::IntTypeInterface ty)
Definition CIRAttrs.cpp:178
static void printConstPtr(mlir::AsmPrinter &p, mlir::IntegerAttr value)
static void printRecordMembers(mlir::AsmPrinter &p, mlir::ArrayAttr members)
Definition CIRAttrs.cpp:71
static mlir::ParseResult parseIntLiteral(mlir::AsmParser &parser, llvm::APInt &value, cir::IntTypeInterface ty)
Definition CIRAttrs.cpp:194
mlir::ParseResult parseTargetAddressSpace(mlir::AsmParser &p, cir::TargetAddressSpaceAttr &attr)
Definition CIRTypes.cpp:957
static void printIntLiteral(mlir::AsmPrinter &p, llvm::APInt value, cir::IntTypeInterface ty)
Definition CIRAttrs.cpp:201
static mlir::ParseResult parseConstPtr(mlir::AsmParser &parser, mlir::IntegerAttr &value)
static bool isTooLargeForType(const mlir::APInt &value, IntT expectedValue)
Definition CIRAttrs.cpp:169
static void printFloatLiteral(mlir::AsmPrinter &p, llvm::APFloat value, mlir::Type ty)
static mlir::ParseResult parseRecordMembers(mlir::AsmParser &parser, mlir::ArrayAttr &members)
Definition CIRAttrs.cpp:78
void printTargetAddressSpace(mlir::AsmPrinter &p, cir::TargetAddressSpaceAttr attr)
Definition CIRTypes.cpp:982
RangeSelector member(std::string ID)
Given a MemberExpr, selects the member token. ID is the node's binding in the match result.