clang-tools 22.0.0git
UseConstraintsCheck.cpp
Go to the documentation of this file.
1//===----------------------------------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
10#include "clang/AST/ASTContext.h"
11#include "clang/AST/DeclTemplate.h"
12#include "clang/ASTMatchers/ASTMatchFinder.h"
13#include "clang/Lex/Lexer.h"
14
15#include "../utils/LexerUtils.h"
16
17#include <optional>
18#include <utility>
19
20using namespace clang::ast_matchers;
21
22namespace clang::tidy::modernize {
23
25 TemplateSpecializationTypeLoc Loc;
26 TypeLoc Outer;
27};
28
29namespace {
30AST_MATCHER(FunctionDecl, hasOtherDeclarations) {
31 auto It = Node.redecls_begin();
32 auto EndIt = Node.redecls_end();
33
34 if (It == EndIt)
35 return false;
36
37 ++It;
38 return It != EndIt;
39}
40} // namespace
41
42void UseConstraintsCheck::registerMatchers(MatchFinder *Finder) {
43 Finder->addMatcher(
44 functionTemplateDecl(
45 // Skip external libraries included as system headers
46 unless(isExpansionInSystemHeader()),
47 has(functionDecl(unless(hasOtherDeclarations()), isDefinition(),
48 hasReturnTypeLoc(typeLoc().bind("return")))
49 .bind("function")))
50 .bind("functionTemplate"),
51 this);
52}
53
54static std::optional<TemplateSpecializationTypeLoc>
56 if (const auto Dep = TheType.getAs<DependentNameTypeLoc>()) {
57 const IdentifierInfo *Identifier = Dep.getTypePtr()->getIdentifier();
58 const ElaboratedTypeKeyword Keyword = Dep.getTypePtr()->getKeyword();
59 if (!Identifier || Identifier->getName() != "type" ||
60 (Keyword != ElaboratedTypeKeyword::Typename &&
61 Keyword != ElaboratedTypeKeyword::None)) {
62 return std::nullopt;
63 }
64 TheType = Dep.getQualifierLoc().getAsTypeLoc();
65 if (TheType.isNull())
66 return std::nullopt;
67 } else {
68 return std::nullopt;
69 }
70
71 if (const auto SpecializationLoc =
72 TheType.getAs<TemplateSpecializationTypeLoc>()) {
73 const auto *Specialization =
74 dyn_cast<TemplateSpecializationType>(SpecializationLoc.getTypePtr());
75 if (!Specialization)
76 return std::nullopt;
77
78 const TemplateDecl *TD =
79 Specialization->getTemplateName().getAsTemplateDecl();
80 if (!TD || TD->getName() != "enable_if")
81 return std::nullopt;
82
83 assert(!TD->getTemplateParameters()->empty() &&
84 "found template with no template parameters?");
85 const auto *FirstParam = dyn_cast<NonTypeTemplateParmDecl>(
86 TD->getTemplateParameters()->getParam(0));
87 if (!FirstParam || !FirstParam->getType()->isBooleanType())
88 return std::nullopt;
89
90 const int NumArgs = SpecializationLoc.getNumArgs();
91 if (NumArgs != 1 && NumArgs != 2)
92 return std::nullopt;
93
94 return SpecializationLoc;
95 }
96 return std::nullopt;
97}
98
99static std::optional<TemplateSpecializationTypeLoc>
101 if (const auto SpecializationLoc =
102 TheType.getAs<TemplateSpecializationTypeLoc>()) {
103 const auto *Specialization =
104 dyn_cast<TemplateSpecializationType>(SpecializationLoc.getTypePtr());
105 if (!Specialization)
106 return std::nullopt;
107
108 const TemplateDecl *TD =
109 Specialization->getTemplateName().getAsTemplateDecl();
110 if (!TD || TD->getName() != "enable_if_t")
111 return std::nullopt;
112
113 if (!Specialization->isTypeAlias())
114 return std::nullopt;
115
116 assert(!TD->getTemplateParameters()->empty() &&
117 "found template with no template parameters?");
118 const auto *FirstParam = dyn_cast<NonTypeTemplateParmDecl>(
119 TD->getTemplateParameters()->getParam(0));
120 if (!FirstParam || !FirstParam->getType()->isBooleanType())
121 return std::nullopt;
122
123 if (const auto *AliasedType =
124 dyn_cast<DependentNameType>(Specialization->getAliasedType())) {
125 const ElaboratedTypeKeyword Keyword = AliasedType->getKeyword();
126 if (AliasedType->getIdentifier()->getName() != "type" ||
127 (Keyword != ElaboratedTypeKeyword::Typename &&
128 Keyword != ElaboratedTypeKeyword::None)) {
129 return std::nullopt;
130 }
131 } else {
132 return std::nullopt;
133 }
134 const int NumArgs = SpecializationLoc.getNumArgs();
135 if (NumArgs != 1 && NumArgs != 2)
136 return std::nullopt;
137
138 return SpecializationLoc;
139 }
140 return std::nullopt;
141}
142
143static std::optional<TemplateSpecializationTypeLoc>
145 if (auto EnableIf = matchEnableIfSpecializationImplTypename(TheType))
146 return EnableIf;
148}
149
150static std::optional<EnableIfData>
152 if (const auto Pointer = TheType.getAs<PointerTypeLoc>())
153 TheType = Pointer.getPointeeLoc();
154 else if (const auto Reference = TheType.getAs<ReferenceTypeLoc>())
155 TheType = Reference.getPointeeLoc();
156 if (const auto Qualified = TheType.getAs<QualifiedTypeLoc>())
157 TheType = Qualified.getUnqualifiedLoc();
158
159 if (auto EnableIf = matchEnableIfSpecializationImpl(TheType))
160 return EnableIfData{std::move(*EnableIf), TheType};
161 return std::nullopt;
162}
163
164static std::pair<std::optional<EnableIfData>, const Decl *>
165matchTrailingTemplateParam(const FunctionTemplateDecl *FunctionTemplate) {
166 // For non-type trailing param, match very specifically
167 // 'template <..., enable_if_type<Condition, Type> = Default>' where
168 // enable_if_type is 'enable_if' or 'enable_if_t'. E.g., 'template <typename
169 // T, enable_if_t<is_same_v<T, bool>, int*> = nullptr>
170 //
171 // Otherwise, match a trailing default type arg.
172 // E.g., 'template <typename T, typename = enable_if_t<is_same_v<T, bool>>>'
173
174 const TemplateParameterList *TemplateParams =
175 FunctionTemplate->getTemplateParameters();
176 if (TemplateParams->empty())
177 return {};
178
179 const NamedDecl *LastParam =
180 TemplateParams->getParam(TemplateParams->size() - 1);
181 if (const auto *LastTemplateParam =
182 dyn_cast<NonTypeTemplateParmDecl>(LastParam)) {
183 if (!LastTemplateParam->hasDefaultArgument() ||
184 !LastTemplateParam->getName().empty())
185 return {};
186
188 LastTemplateParam->getTypeSourceInfo()->getTypeLoc()),
189 LastTemplateParam};
190 }
191 if (const auto *LastTemplateParam =
192 dyn_cast<TemplateTypeParmDecl>(LastParam)) {
193 if (LastTemplateParam->hasDefaultArgument() &&
194 LastTemplateParam->getIdentifier() == nullptr) {
195 return {
196 matchEnableIfSpecialization(LastTemplateParam->getDefaultArgument()
197 .getTypeSourceInfo()
198 ->getTypeLoc()),
199 LastTemplateParam};
200 }
201 }
202 return {};
203}
204
205template <typename T>
206static SourceLocation getRAngleFileLoc(const SourceManager &SM,
207 const T &Element) {
208 // getFileLoc handles the case where the RAngle loc is part of a synthesized
209 // '>>', which ends up allocating a 'scratch space' buffer in the source
210 // manager.
211 return SM.getFileLoc(Element.getRAngleLoc());
212}
213
214static SourceRange
215getConditionRange(ASTContext &Context,
216 const TemplateSpecializationTypeLoc &EnableIf) {
217 // TemplateArgumentLoc's SourceRange End is the location of the last token
218 // (per UnqualifiedId docs). E.g., in `enable_if<AAA && BBB>`, the End
219 // location will be the first 'B' in 'BBB'.
220 const LangOptions &LangOpts = Context.getLangOpts();
221 const SourceManager &SM = Context.getSourceManager();
222 if (EnableIf.getNumArgs() > 1) {
223 const TemplateArgumentLoc NextArg = EnableIf.getArgLoc(1);
224 return {EnableIf.getLAngleLoc().getLocWithOffset(1),
226 NextArg.getSourceRange().getBegin(), SM, LangOpts, tok::comma)};
227 }
228
229 return {EnableIf.getLAngleLoc().getLocWithOffset(1),
230 getRAngleFileLoc(SM, EnableIf)};
231}
232
233static SourceRange getTypeRange(ASTContext &Context,
234 const TemplateSpecializationTypeLoc &EnableIf) {
235 const TemplateArgumentLoc Arg = EnableIf.getArgLoc(1);
236 const LangOptions &LangOpts = Context.getLangOpts();
237 const SourceManager &SM = Context.getSourceManager();
238 return {utils::lexer::findPreviousTokenKind(Arg.getSourceRange().getBegin(),
239 SM, LangOpts, tok::comma)
240 .getLocWithOffset(1),
241 getRAngleFileLoc(SM, EnableIf)};
242}
243
244// Returns the original source text of the second argument of a call to
245// enable_if_t. E.g., in enable_if_t<Condition, TheType>, this function
246// returns 'TheType'.
247static std::optional<StringRef>
248getTypeText(ASTContext &Context,
249 const TemplateSpecializationTypeLoc &EnableIf) {
250 if (EnableIf.getNumArgs() > 1) {
251 const LangOptions &LangOpts = Context.getLangOpts();
252 const SourceManager &SM = Context.getSourceManager();
253 bool Invalid = false;
254 StringRef Text = Lexer::getSourceText(CharSourceRange::getCharRange(
255 getTypeRange(Context, EnableIf)),
256 SM, LangOpts, &Invalid)
257 .trim();
258 if (Invalid)
259 return std::nullopt;
260
261 return Text;
262 }
263
264 return "void";
265}
266
267static std::optional<SourceLocation>
268findInsertionForConstraint(const FunctionDecl *Function, ASTContext &Context) {
269 const SourceManager &SM = Context.getSourceManager();
270 const LangOptions &LangOpts = Context.getLangOpts();
271
272 if (const auto *Constructor = dyn_cast<CXXConstructorDecl>(Function)) {
273 for (const CXXCtorInitializer *Init : Constructor->inits())
274 if (Init->getSourceOrder() == 0)
275 return utils::lexer::findPreviousTokenKind(Init->getSourceLocation(),
276 SM, LangOpts, tok::colon);
277 if (!Constructor->inits().empty())
278 return std::nullopt;
279 }
280 if (Function->isDeleted()) {
281 const SourceLocation FunctionEnd = Function->getSourceRange().getEnd();
282 return utils::lexer::findNextAnyTokenKind(FunctionEnd, SM, LangOpts,
283 tok::equal, tok::equal);
284 }
285 const Stmt *Body = Function->getBody();
286 if (!Body)
287 return std::nullopt;
288
289 return Body->getBeginLoc();
290}
291
292static bool isPrimaryExpression(const Expr *Expression) {
293 // This function is an incomplete approximation of checking whether
294 // an Expr is a primary expression. In particular, if this function
295 // returns true, the expression is a primary expression. The converse
296 // is not necessarily true.
297
298 if (const auto *Cast = dyn_cast<ImplicitCastExpr>(Expression))
299 Expression = Cast->getSubExprAsWritten();
300 if (isa<ParenExpr, DependentScopeDeclRefExpr>(Expression))
301 return true;
302
303 return false;
304}
305
306// Return the original source text of an enable_if_t condition, i.e., the
307// first template argument). For example, in
308// 'enable_if_t<FirstCondition || SecondCondition, AType>', the text
309// the text 'FirstCondition || SecondCondition' is returned.
310static std::optional<std::string> getConditionText(const Expr *ConditionExpr,
311 SourceRange ConditionRange,
312 ASTContext &Context) {
313 const SourceManager &SM = Context.getSourceManager();
314 const LangOptions &LangOpts = Context.getLangOpts();
315
316 SourceLocation PrevTokenLoc = ConditionRange.getEnd();
317 if (PrevTokenLoc.isInvalid())
318 return std::nullopt;
319
320 const bool SkipComments = false;
321 Token PrevToken;
322 std::tie(PrevToken, PrevTokenLoc) = utils::lexer::getPreviousTokenAndStart(
323 PrevTokenLoc, SM, LangOpts, SkipComments);
324 const bool EndsWithDoubleSlash =
325 PrevToken.is(tok::comment) &&
326 Lexer::getSourceText(CharSourceRange::getCharRange(
327 PrevTokenLoc, PrevTokenLoc.getLocWithOffset(2)),
328 SM, LangOpts) == "//";
329
330 bool Invalid = false;
331 const llvm::StringRef ConditionText = Lexer::getSourceText(
332 CharSourceRange::getCharRange(ConditionRange), SM, LangOpts, &Invalid);
333 if (Invalid)
334 return std::nullopt;
335
336 auto AddParens = [&](llvm::StringRef Text) -> std::string {
337 if (isPrimaryExpression(ConditionExpr))
338 return Text.str();
339 return "(" + Text.str() + ")";
340 };
341
342 if (EndsWithDoubleSlash)
343 return AddParens(ConditionText);
344 return AddParens(ConditionText.trim());
345}
346
347// Handle functions that return enable_if_t, e.g.,
348// template <...>
349// enable_if_t<Condition, ReturnType> function();
350//
351// Return a vector of FixItHints if the code can be replaced with
352// a C++20 requires clause. In the example above, returns FixItHints
353// to result in
354// template <...>
355// ReturnType function() requires Condition {}
356static std::vector<FixItHint> handleReturnType(const FunctionDecl *Function,
357 const TypeLoc &ReturnType,
358 const EnableIfData &EnableIf,
359 ASTContext &Context) {
360 const TemplateArgumentLoc EnableCondition = EnableIf.Loc.getArgLoc(0);
361
362 const SourceRange ConditionRange = getConditionRange(Context, EnableIf.Loc);
363
364 std::optional<std::string> ConditionText = getConditionText(
365 EnableCondition.getSourceExpression(), ConditionRange, Context);
366 if (!ConditionText)
367 return {};
368
369 std::optional<StringRef> TypeText = getTypeText(Context, EnableIf.Loc);
370 if (!TypeText)
371 return {};
372
373 SmallVector<AssociatedConstraint, 3> ExistingConstraints;
374 Function->getAssociatedConstraints(ExistingConstraints);
375 if (!ExistingConstraints.empty()) {
376 // FIXME - Support adding new constraints to existing ones. Do we need to
377 // consider subsumption?
378 return {};
379 }
380
381 std::optional<SourceLocation> ConstraintInsertionLoc =
382 findInsertionForConstraint(Function, Context);
383 if (!ConstraintInsertionLoc)
384 return {};
385
386 std::vector<FixItHint> FixIts;
387 FixIts.push_back(FixItHint::CreateReplacement(
388 CharSourceRange::getTokenRange(EnableIf.Outer.getSourceRange()),
389 *TypeText));
390 FixIts.push_back(FixItHint::CreateInsertion(
391 *ConstraintInsertionLoc, "requires " + *ConditionText + " "));
392 return FixIts;
393}
394
395// Handle enable_if_t in a trailing template parameter, e.g.,
396// template <..., enable_if_t<Condition, Type> = Type{}>
397// ReturnType function();
398//
399// Return a vector of FixItHints if the code can be replaced with
400// a C++20 requires clause. In the example above, returns FixItHints
401// to result in
402// template <...>
403// ReturnType function() requires Condition {}
404static std::vector<FixItHint>
405handleTrailingTemplateType(const FunctionTemplateDecl *FunctionTemplate,
406 const FunctionDecl *Function,
407 const Decl *LastTemplateParam,
408 const EnableIfData &EnableIf, ASTContext &Context) {
409 const SourceManager &SM = Context.getSourceManager();
410 const LangOptions &LangOpts = Context.getLangOpts();
411
412 const TemplateArgumentLoc EnableCondition = EnableIf.Loc.getArgLoc(0);
413
414 const SourceRange ConditionRange = getConditionRange(Context, EnableIf.Loc);
415
416 std::optional<std::string> ConditionText = getConditionText(
417 EnableCondition.getSourceExpression(), ConditionRange, Context);
418 if (!ConditionText)
419 return {};
420
421 SmallVector<AssociatedConstraint, 3> ExistingConstraints;
422 Function->getAssociatedConstraints(ExistingConstraints);
423 if (!ExistingConstraints.empty()) {
424 // FIXME - Support adding new constraints to existing ones. Do we need to
425 // consider subsumption?
426 return {};
427 }
428
429 SourceRange RemovalRange;
430 const TemplateParameterList *TemplateParams =
431 FunctionTemplate->getTemplateParameters();
432 if (!TemplateParams || TemplateParams->empty())
433 return {};
434
435 if (TemplateParams->size() == 1) {
436 RemovalRange =
437 SourceRange(TemplateParams->getTemplateLoc(),
438 getRAngleFileLoc(SM, *TemplateParams).getLocWithOffset(1));
439 } else {
440 RemovalRange =
442 LastTemplateParam->getSourceRange().getBegin(), SM,
443 LangOpts, tok::comma),
444 getRAngleFileLoc(SM, *TemplateParams));
445 }
446
447 std::optional<SourceLocation> ConstraintInsertionLoc =
448 findInsertionForConstraint(Function, Context);
449 if (!ConstraintInsertionLoc)
450 return {};
451
452 std::vector<FixItHint> FixIts;
453 FixIts.push_back(
454 FixItHint::CreateRemoval(CharSourceRange::getCharRange(RemovalRange)));
455 FixIts.push_back(FixItHint::CreateInsertion(
456 *ConstraintInsertionLoc, "requires " + *ConditionText + " "));
457 return FixIts;
458}
459
460void UseConstraintsCheck::check(const MatchFinder::MatchResult &Result) {
461 const auto *FunctionTemplate =
462 Result.Nodes.getNodeAs<FunctionTemplateDecl>("functionTemplate");
463 const auto *Function = Result.Nodes.getNodeAs<FunctionDecl>("function");
464 const auto *ReturnType = Result.Nodes.getNodeAs<TypeLoc>("return");
465 if (!FunctionTemplate || !Function || !ReturnType)
466 return;
467
468 // Check for
469 //
470 // Case 1. Return type of function
471 //
472 // template <...>
473 // enable_if_t<Condition, ReturnType>::type function() {}
474 //
475 // Case 2. Trailing template parameter
476 //
477 // template <..., enable_if_t<Condition, Type> = Type{}>
478 // ReturnType function() {}
479 //
480 // or
481 //
482 // template <..., typename = enable_if_t<Condition, void>>
483 // ReturnType function() {}
484 //
485
486 // Case 1. Return type of function
487 if (auto EnableIf = matchEnableIfSpecialization(*ReturnType)) {
488 diag(ReturnType->getBeginLoc(),
489 "use C++20 requires constraints instead of enable_if")
490 << handleReturnType(Function, *ReturnType, *EnableIf, *Result.Context);
491 return;
492 }
493
494 // Case 2. Trailing template parameter
495 if (auto [EnableIf, LastTemplateParam] =
496 matchTrailingTemplateParam(FunctionTemplate);
497 EnableIf && LastTemplateParam) {
498 diag(LastTemplateParam->getSourceRange().getBegin(),
499 "use C++20 requires constraints instead of enable_if")
500 << handleTrailingTemplateType(FunctionTemplate, Function,
501 LastTemplateParam, *EnableIf,
502 *Result.Context);
503 return;
504 }
505}
506
507} // namespace clang::tidy::modernize
void registerMatchers(ast_matchers::MatchFinder *Finder) override
void check(const ast_matchers::MatchFinder::MatchResult &Result) override
static std::pair< std::optional< EnableIfData >, const Decl * > matchTrailingTemplateParam(const FunctionTemplateDecl *FunctionTemplate)
static std::vector< FixItHint > handleReturnType(const FunctionDecl *Function, const TypeLoc &ReturnType, const EnableIfData &EnableIf, ASTContext &Context)
static SourceRange getTypeRange(ASTContext &Context, const TemplateSpecializationTypeLoc &EnableIf)
static std::optional< SourceLocation > findInsertionForConstraint(const FunctionDecl *Function, ASTContext &Context)
static std::optional< TemplateSpecializationTypeLoc > matchEnableIfSpecializationImplTypename(TypeLoc TheType)
static std::optional< std::string > getConditionText(const Expr *ConditionExpr, SourceRange ConditionRange, ASTContext &Context)
static std::vector< FixItHint > handleTrailingTemplateType(const FunctionTemplateDecl *FunctionTemplate, const FunctionDecl *Function, const Decl *LastTemplateParam, const EnableIfData &EnableIf, ASTContext &Context)
static std::optional< TemplateSpecializationTypeLoc > matchEnableIfSpecializationImplTrait(TypeLoc TheType)
static std::optional< StringRef > getTypeText(ASTContext &Context, const TemplateSpecializationTypeLoc &EnableIf)
static SourceRange getConditionRange(ASTContext &Context, const TemplateSpecializationTypeLoc &EnableIf)
static SourceLocation getRAngleFileLoc(const SourceManager &SM, const T &Element)
static std::optional< TemplateSpecializationTypeLoc > matchEnableIfSpecializationImpl(TypeLoc TheType)
static bool isPrimaryExpression(const Expr *Expression)
static std::optional< EnableIfData > matchEnableIfSpecialization(TypeLoc TheType)
std::pair< Token, SourceLocation > getPreviousTokenAndStart(SourceLocation Location, const SourceManager &SM, const LangOptions &LangOpts, bool SkipComments)
SourceLocation findNextAnyTokenKind(SourceLocation Start, const SourceManager &SM, const LangOptions &LangOpts, TokenKind TK, TokenKinds... TKs)
Definition LexerUtils.h:68
SourceLocation findPreviousTokenKind(SourceLocation Start, const SourceManager &SM, const LangOptions &LangOpts, tok::TokenKind TK)