clang-tools 22.0.0git
UseConstraintsCheck.cpp
Go to the documentation of this file.
1//===----------------------------------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
10#include "clang/AST/ASTContext.h"
11#include "clang/AST/DeclTemplate.h"
12#include "clang/ASTMatchers/ASTMatchFinder.h"
13#include "clang/Lex/Lexer.h"
14
15#include "../utils/LexerUtils.h"
16
17#include <optional>
18#include <utility>
19
20using namespace clang::ast_matchers;
21
22namespace clang::tidy::modernize {
23
25 TemplateSpecializationTypeLoc Loc;
26 TypeLoc Outer;
27};
28
29namespace {
30AST_MATCHER(FunctionDecl, hasOtherDeclarations) {
31 auto It = Node.redecls_begin();
32 auto EndIt = Node.redecls_end();
33
34 if (It == EndIt)
35 return false;
36
37 ++It;
38 return It != EndIt;
39}
40} // namespace
41
42void UseConstraintsCheck::registerMatchers(MatchFinder *Finder) {
43 Finder->addMatcher(
44 functionTemplateDecl(
45 // Skip external libraries included as system headers
46 unless(isExpansionInSystemHeader()),
47 has(functionDecl(unless(hasOtherDeclarations()), isDefinition(),
48 hasReturnTypeLoc(typeLoc().bind("return")))
49 .bind("function")))
50 .bind("functionTemplate"),
51 this);
52}
53
54static std::optional<TemplateSpecializationTypeLoc>
56 if (const auto Dep = TheType.getAs<DependentNameTypeLoc>()) {
57 const IdentifierInfo *Identifier = Dep.getTypePtr()->getIdentifier();
58 ElaboratedTypeKeyword Keyword = Dep.getTypePtr()->getKeyword();
59 if (!Identifier || Identifier->getName() != "type" ||
60 (Keyword != ElaboratedTypeKeyword::Typename &&
61 Keyword != ElaboratedTypeKeyword::None)) {
62 return std::nullopt;
63 }
64 TheType = Dep.getQualifierLoc().getAsTypeLoc();
65 if (TheType.isNull())
66 return std::nullopt;
67 } else {
68 return std::nullopt;
69 }
70
71 if (const auto SpecializationLoc =
72 TheType.getAs<TemplateSpecializationTypeLoc>()) {
73
74 const auto *Specialization =
75 dyn_cast<TemplateSpecializationType>(SpecializationLoc.getTypePtr());
76 if (!Specialization)
77 return std::nullopt;
78
79 const TemplateDecl *TD =
80 Specialization->getTemplateName().getAsTemplateDecl();
81 if (!TD || TD->getName() != "enable_if")
82 return std::nullopt;
83
84 assert(!TD->getTemplateParameters()->empty() &&
85 "found template with no template parameters?");
86 const auto *FirstParam = dyn_cast<NonTypeTemplateParmDecl>(
87 TD->getTemplateParameters()->getParam(0));
88 if (!FirstParam || !FirstParam->getType()->isBooleanType())
89 return std::nullopt;
90
91 int NumArgs = SpecializationLoc.getNumArgs();
92 if (NumArgs != 1 && NumArgs != 2)
93 return std::nullopt;
94
95 return SpecializationLoc;
96 }
97 return std::nullopt;
98}
99
100static std::optional<TemplateSpecializationTypeLoc>
102 if (const auto SpecializationLoc =
103 TheType.getAs<TemplateSpecializationTypeLoc>()) {
104
105 const auto *Specialization =
106 dyn_cast<TemplateSpecializationType>(SpecializationLoc.getTypePtr());
107 if (!Specialization)
108 return std::nullopt;
109
110 const TemplateDecl *TD =
111 Specialization->getTemplateName().getAsTemplateDecl();
112 if (!TD || TD->getName() != "enable_if_t")
113 return std::nullopt;
114
115 if (!Specialization->isTypeAlias())
116 return std::nullopt;
117
118 assert(!TD->getTemplateParameters()->empty() &&
119 "found template with no template parameters?");
120 const auto *FirstParam = dyn_cast<NonTypeTemplateParmDecl>(
121 TD->getTemplateParameters()->getParam(0));
122 if (!FirstParam || !FirstParam->getType()->isBooleanType())
123 return std::nullopt;
124
125 if (const auto *AliasedType =
126 dyn_cast<DependentNameType>(Specialization->getAliasedType())) {
127 ElaboratedTypeKeyword Keyword = AliasedType->getKeyword();
128 if (AliasedType->getIdentifier()->getName() != "type" ||
129 (Keyword != ElaboratedTypeKeyword::Typename &&
130 Keyword != ElaboratedTypeKeyword::None)) {
131 return std::nullopt;
132 }
133 } else {
134 return std::nullopt;
135 }
136 int NumArgs = SpecializationLoc.getNumArgs();
137 if (NumArgs != 1 && NumArgs != 2)
138 return std::nullopt;
139
140 return SpecializationLoc;
141 }
142 return std::nullopt;
143}
144
145static std::optional<TemplateSpecializationTypeLoc>
147 if (auto EnableIf = matchEnableIfSpecializationImplTypename(TheType))
148 return EnableIf;
150}
151
152static std::optional<EnableIfData>
154 if (const auto Pointer = TheType.getAs<PointerTypeLoc>())
155 TheType = Pointer.getPointeeLoc();
156 else if (const auto Reference = TheType.getAs<ReferenceTypeLoc>())
157 TheType = Reference.getPointeeLoc();
158 if (const auto Qualified = TheType.getAs<QualifiedTypeLoc>())
159 TheType = Qualified.getUnqualifiedLoc();
160
161 if (auto EnableIf = matchEnableIfSpecializationImpl(TheType))
162 return EnableIfData{std::move(*EnableIf), TheType};
163 return std::nullopt;
164}
165
166static std::pair<std::optional<EnableIfData>, const Decl *>
167matchTrailingTemplateParam(const FunctionTemplateDecl *FunctionTemplate) {
168 // For non-type trailing param, match very specifically
169 // 'template <..., enable_if_type<Condition, Type> = Default>' where
170 // enable_if_type is 'enable_if' or 'enable_if_t'. E.g., 'template <typename
171 // T, enable_if_t<is_same_v<T, bool>, int*> = nullptr>
172 //
173 // Otherwise, match a trailing default type arg.
174 // E.g., 'template <typename T, typename = enable_if_t<is_same_v<T, bool>>>'
175
176 const TemplateParameterList *TemplateParams =
177 FunctionTemplate->getTemplateParameters();
178 if (TemplateParams->empty())
179 return {};
180
181 const NamedDecl *LastParam =
182 TemplateParams->getParam(TemplateParams->size() - 1);
183 if (const auto *LastTemplateParam =
184 dyn_cast<NonTypeTemplateParmDecl>(LastParam)) {
185
186 if (!LastTemplateParam->hasDefaultArgument() ||
187 !LastTemplateParam->getName().empty())
188 return {};
189
191 LastTemplateParam->getTypeSourceInfo()->getTypeLoc()),
192 LastTemplateParam};
193 }
194 if (const auto *LastTemplateParam =
195 dyn_cast<TemplateTypeParmDecl>(LastParam)) {
196 if (LastTemplateParam->hasDefaultArgument() &&
197 LastTemplateParam->getIdentifier() == nullptr) {
198 return {
199 matchEnableIfSpecialization(LastTemplateParam->getDefaultArgument()
200 .getTypeSourceInfo()
201 ->getTypeLoc()),
202 LastTemplateParam};
203 }
204 }
205 return {};
206}
207
208template <typename T>
209static SourceLocation getRAngleFileLoc(const SourceManager &SM,
210 const T &Element) {
211 // getFileLoc handles the case where the RAngle loc is part of a synthesized
212 // '>>', which ends up allocating a 'scratch space' buffer in the source
213 // manager.
214 return SM.getFileLoc(Element.getRAngleLoc());
215}
216
217static SourceRange
218getConditionRange(ASTContext &Context,
219 const TemplateSpecializationTypeLoc &EnableIf) {
220 // TemplateArgumentLoc's SourceRange End is the location of the last token
221 // (per UnqualifiedId docs). E.g., in `enable_if<AAA && BBB>`, the End
222 // location will be the first 'B' in 'BBB'.
223 const LangOptions &LangOpts = Context.getLangOpts();
224 const SourceManager &SM = Context.getSourceManager();
225 if (EnableIf.getNumArgs() > 1) {
226 TemplateArgumentLoc NextArg = EnableIf.getArgLoc(1);
227 return {EnableIf.getLAngleLoc().getLocWithOffset(1),
229 NextArg.getSourceRange().getBegin(), SM, LangOpts, tok::comma)};
230 }
231
232 return {EnableIf.getLAngleLoc().getLocWithOffset(1),
233 getRAngleFileLoc(SM, EnableIf)};
234}
235
236static SourceRange getTypeRange(ASTContext &Context,
237 const TemplateSpecializationTypeLoc &EnableIf) {
238 TemplateArgumentLoc Arg = EnableIf.getArgLoc(1);
239 const LangOptions &LangOpts = Context.getLangOpts();
240 const SourceManager &SM = Context.getSourceManager();
241 return {utils::lexer::findPreviousTokenKind(Arg.getSourceRange().getBegin(),
242 SM, LangOpts, tok::comma)
243 .getLocWithOffset(1),
244 getRAngleFileLoc(SM, EnableIf)};
245}
246
247// Returns the original source text of the second argument of a call to
248// enable_if_t. E.g., in enable_if_t<Condition, TheType>, this function
249// returns 'TheType'.
250static std::optional<StringRef>
251getTypeText(ASTContext &Context,
252 const TemplateSpecializationTypeLoc &EnableIf) {
253 if (EnableIf.getNumArgs() > 1) {
254 const LangOptions &LangOpts = Context.getLangOpts();
255 const SourceManager &SM = Context.getSourceManager();
256 bool Invalid = false;
257 StringRef Text = Lexer::getSourceText(CharSourceRange::getCharRange(
258 getTypeRange(Context, EnableIf)),
259 SM, LangOpts, &Invalid)
260 .trim();
261 if (Invalid)
262 return std::nullopt;
263
264 return Text;
265 }
266
267 return "void";
268}
269
270static std::optional<SourceLocation>
271findInsertionForConstraint(const FunctionDecl *Function, ASTContext &Context) {
272 SourceManager &SM = Context.getSourceManager();
273 const LangOptions &LangOpts = Context.getLangOpts();
274
275 if (const auto *Constructor = dyn_cast<CXXConstructorDecl>(Function)) {
276 for (const CXXCtorInitializer *Init : Constructor->inits()) {
277 if (Init->getSourceOrder() == 0)
278 return utils::lexer::findPreviousTokenKind(Init->getSourceLocation(),
279 SM, LangOpts, tok::colon);
280 }
281 if (!Constructor->inits().empty())
282 return std::nullopt;
283 }
284 if (Function->isDeleted()) {
285 SourceLocation FunctionEnd = Function->getSourceRange().getEnd();
286 return utils::lexer::findNextAnyTokenKind(FunctionEnd, SM, LangOpts,
287 tok::equal, tok::equal);
288 }
289 const Stmt *Body = Function->getBody();
290 if (!Body)
291 return std::nullopt;
292
293 return Body->getBeginLoc();
294}
295
296static bool isPrimaryExpression(const Expr *Expression) {
297 // This function is an incomplete approximation of checking whether
298 // an Expr is a primary expression. In particular, if this function
299 // returns true, the expression is a primary expression. The converse
300 // is not necessarily true.
301
302 if (const auto *Cast = dyn_cast<ImplicitCastExpr>(Expression))
303 Expression = Cast->getSubExprAsWritten();
304 if (isa<ParenExpr, DependentScopeDeclRefExpr>(Expression))
305 return true;
306
307 return false;
308}
309
310// Return the original source text of an enable_if_t condition, i.e., the
311// first template argument). For example, in
312// 'enable_if_t<FirstCondition || SecondCondition, AType>', the text
313// the text 'FirstCondition || SecondCondition' is returned.
314static std::optional<std::string> getConditionText(const Expr *ConditionExpr,
315 SourceRange ConditionRange,
316 ASTContext &Context) {
317 SourceManager &SM = Context.getSourceManager();
318 const LangOptions &LangOpts = Context.getLangOpts();
319
320 SourceLocation PrevTokenLoc = ConditionRange.getEnd();
321 if (PrevTokenLoc.isInvalid())
322 return std::nullopt;
323
324 const bool SkipComments = false;
325 Token PrevToken;
326 std::tie(PrevToken, PrevTokenLoc) = utils::lexer::getPreviousTokenAndStart(
327 PrevTokenLoc, SM, LangOpts, SkipComments);
328 bool EndsWithDoubleSlash =
329 PrevToken.is(tok::comment) &&
330 Lexer::getSourceText(CharSourceRange::getCharRange(
331 PrevTokenLoc, PrevTokenLoc.getLocWithOffset(2)),
332 SM, LangOpts) == "//";
333
334 bool Invalid = false;
335 llvm::StringRef ConditionText = Lexer::getSourceText(
336 CharSourceRange::getCharRange(ConditionRange), SM, LangOpts, &Invalid);
337 if (Invalid)
338 return std::nullopt;
339
340 auto AddParens = [&](llvm::StringRef Text) -> std::string {
341 if (isPrimaryExpression(ConditionExpr))
342 return Text.str();
343 return "(" + Text.str() + ")";
344 };
345
346 if (EndsWithDoubleSlash)
347 return AddParens(ConditionText);
348 return AddParens(ConditionText.trim());
349}
350
351// Handle functions that return enable_if_t, e.g.,
352// template <...>
353// enable_if_t<Condition, ReturnType> function();
354//
355// Return a vector of FixItHints if the code can be replaced with
356// a C++20 requires clause. In the example above, returns FixItHints
357// to result in
358// template <...>
359// ReturnType function() requires Condition {}
360static std::vector<FixItHint> handleReturnType(const FunctionDecl *Function,
361 const TypeLoc &ReturnType,
362 const EnableIfData &EnableIf,
363 ASTContext &Context) {
364 TemplateArgumentLoc EnableCondition = EnableIf.Loc.getArgLoc(0);
365
366 SourceRange ConditionRange = getConditionRange(Context, EnableIf.Loc);
367
368 std::optional<std::string> ConditionText = getConditionText(
369 EnableCondition.getSourceExpression(), ConditionRange, Context);
370 if (!ConditionText)
371 return {};
372
373 std::optional<StringRef> TypeText = getTypeText(Context, EnableIf.Loc);
374 if (!TypeText)
375 return {};
376
377 SmallVector<AssociatedConstraint, 3> ExistingConstraints;
378 Function->getAssociatedConstraints(ExistingConstraints);
379 if (!ExistingConstraints.empty()) {
380 // FIXME - Support adding new constraints to existing ones. Do we need to
381 // consider subsumption?
382 return {};
383 }
384
385 std::optional<SourceLocation> ConstraintInsertionLoc =
386 findInsertionForConstraint(Function, Context);
387 if (!ConstraintInsertionLoc)
388 return {};
389
390 std::vector<FixItHint> FixIts;
391 FixIts.push_back(FixItHint::CreateReplacement(
392 CharSourceRange::getTokenRange(EnableIf.Outer.getSourceRange()),
393 *TypeText));
394 FixIts.push_back(FixItHint::CreateInsertion(
395 *ConstraintInsertionLoc, "requires " + *ConditionText + " "));
396 return FixIts;
397}
398
399// Handle enable_if_t in a trailing template parameter, e.g.,
400// template <..., enable_if_t<Condition, Type> = Type{}>
401// ReturnType function();
402//
403// Return a vector of FixItHints if the code can be replaced with
404// a C++20 requires clause. In the example above, returns FixItHints
405// to result in
406// template <...>
407// ReturnType function() requires Condition {}
408static std::vector<FixItHint>
409handleTrailingTemplateType(const FunctionTemplateDecl *FunctionTemplate,
410 const FunctionDecl *Function,
411 const Decl *LastTemplateParam,
412 const EnableIfData &EnableIf, ASTContext &Context) {
413 SourceManager &SM = Context.getSourceManager();
414 const LangOptions &LangOpts = Context.getLangOpts();
415
416 TemplateArgumentLoc EnableCondition = EnableIf.Loc.getArgLoc(0);
417
418 SourceRange ConditionRange = getConditionRange(Context, EnableIf.Loc);
419
420 std::optional<std::string> ConditionText = getConditionText(
421 EnableCondition.getSourceExpression(), ConditionRange, Context);
422 if (!ConditionText)
423 return {};
424
425 SmallVector<AssociatedConstraint, 3> ExistingConstraints;
426 Function->getAssociatedConstraints(ExistingConstraints);
427 if (!ExistingConstraints.empty()) {
428 // FIXME - Support adding new constraints to existing ones. Do we need to
429 // consider subsumption?
430 return {};
431 }
432
433 SourceRange RemovalRange;
434 const TemplateParameterList *TemplateParams =
435 FunctionTemplate->getTemplateParameters();
436 if (!TemplateParams || TemplateParams->empty())
437 return {};
438
439 if (TemplateParams->size() == 1) {
440 RemovalRange =
441 SourceRange(TemplateParams->getTemplateLoc(),
442 getRAngleFileLoc(SM, *TemplateParams).getLocWithOffset(1));
443 } else {
444 RemovalRange =
446 LastTemplateParam->getSourceRange().getBegin(), SM,
447 LangOpts, tok::comma),
448 getRAngleFileLoc(SM, *TemplateParams));
449 }
450
451 std::optional<SourceLocation> ConstraintInsertionLoc =
452 findInsertionForConstraint(Function, Context);
453 if (!ConstraintInsertionLoc)
454 return {};
455
456 std::vector<FixItHint> FixIts;
457 FixIts.push_back(
458 FixItHint::CreateRemoval(CharSourceRange::getCharRange(RemovalRange)));
459 FixIts.push_back(FixItHint::CreateInsertion(
460 *ConstraintInsertionLoc, "requires " + *ConditionText + " "));
461 return FixIts;
462}
463
464void UseConstraintsCheck::check(const MatchFinder::MatchResult &Result) {
465 const auto *FunctionTemplate =
466 Result.Nodes.getNodeAs<FunctionTemplateDecl>("functionTemplate");
467 const auto *Function = Result.Nodes.getNodeAs<FunctionDecl>("function");
468 const auto *ReturnType = Result.Nodes.getNodeAs<TypeLoc>("return");
469 if (!FunctionTemplate || !Function || !ReturnType)
470 return;
471
472 // Check for
473 //
474 // Case 1. Return type of function
475 //
476 // template <...>
477 // enable_if_t<Condition, ReturnType>::type function() {}
478 //
479 // Case 2. Trailing template parameter
480 //
481 // template <..., enable_if_t<Condition, Type> = Type{}>
482 // ReturnType function() {}
483 //
484 // or
485 //
486 // template <..., typename = enable_if_t<Condition, void>>
487 // ReturnType function() {}
488 //
489
490 // Case 1. Return type of function
491 if (auto EnableIf = matchEnableIfSpecialization(*ReturnType)) {
492 diag(ReturnType->getBeginLoc(),
493 "use C++20 requires constraints instead of enable_if")
494 << handleReturnType(Function, *ReturnType, *EnableIf, *Result.Context);
495 return;
496 }
497
498 // Case 2. Trailing template parameter
499 if (auto [EnableIf, LastTemplateParam] =
500 matchTrailingTemplateParam(FunctionTemplate);
501 EnableIf && LastTemplateParam) {
502 diag(LastTemplateParam->getSourceRange().getBegin(),
503 "use C++20 requires constraints instead of enable_if")
504 << handleTrailingTemplateType(FunctionTemplate, Function,
505 LastTemplateParam, *EnableIf,
506 *Result.Context);
507 return;
508 }
509}
510
511} // namespace clang::tidy::modernize
void registerMatchers(ast_matchers::MatchFinder *Finder) override
void check(const ast_matchers::MatchFinder::MatchResult &Result) override
static std::pair< std::optional< EnableIfData >, const Decl * > matchTrailingTemplateParam(const FunctionTemplateDecl *FunctionTemplate)
static std::vector< FixItHint > handleReturnType(const FunctionDecl *Function, const TypeLoc &ReturnType, const EnableIfData &EnableIf, ASTContext &Context)
static SourceRange getTypeRange(ASTContext &Context, const TemplateSpecializationTypeLoc &EnableIf)
static std::optional< SourceLocation > findInsertionForConstraint(const FunctionDecl *Function, ASTContext &Context)
static std::optional< TemplateSpecializationTypeLoc > matchEnableIfSpecializationImplTypename(TypeLoc TheType)
static std::optional< std::string > getConditionText(const Expr *ConditionExpr, SourceRange ConditionRange, ASTContext &Context)
static std::vector< FixItHint > handleTrailingTemplateType(const FunctionTemplateDecl *FunctionTemplate, const FunctionDecl *Function, const Decl *LastTemplateParam, const EnableIfData &EnableIf, ASTContext &Context)
static std::optional< TemplateSpecializationTypeLoc > matchEnableIfSpecializationImplTrait(TypeLoc TheType)
static std::optional< StringRef > getTypeText(ASTContext &Context, const TemplateSpecializationTypeLoc &EnableIf)
static SourceRange getConditionRange(ASTContext &Context, const TemplateSpecializationTypeLoc &EnableIf)
static SourceLocation getRAngleFileLoc(const SourceManager &SM, const T &Element)
static std::optional< TemplateSpecializationTypeLoc > matchEnableIfSpecializationImpl(TypeLoc TheType)
static bool isPrimaryExpression(const Expr *Expression)
static std::optional< EnableIfData > matchEnableIfSpecialization(TypeLoc TheType)
std::pair< Token, SourceLocation > getPreviousTokenAndStart(SourceLocation Location, const SourceManager &SM, const LangOptions &LangOpts, bool SkipComments)
SourceLocation findNextAnyTokenKind(SourceLocation Start, const SourceManager &SM, const LangOptions &LangOpts, TokenKind TK, TokenKinds... TKs)
Definition LexerUtils.h:68
SourceLocation findPreviousTokenKind(SourceLocation Start, const SourceManager &SM, const LangOptions &LangOpts, tok::TokenKind TK)