clang-tools 22.0.0git
UseConstraintsCheck.cpp
Go to the documentation of this file.
1//===----------------------------------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
10#include "clang/AST/ASTContext.h"
11#include "clang/AST/DeclTemplate.h"
12#include "clang/ASTMatchers/ASTMatchFinder.h"
13#include "clang/Lex/Lexer.h"
14
15#include "../utils/LexerUtils.h"
16
17#include <optional>
18#include <utility>
19
20using namespace clang::ast_matchers;
21
22namespace clang::tidy::modernize {
23
25 TemplateSpecializationTypeLoc Loc;
26 TypeLoc Outer;
27};
28
29namespace {
30AST_MATCHER(FunctionDecl, hasOtherDeclarations) {
31 auto It = Node.redecls_begin();
32 auto EndIt = Node.redecls_end();
33
34 if (It == EndIt)
35 return false;
36
37 ++It;
38 return It != EndIt;
39}
40} // namespace
41
42void UseConstraintsCheck::registerMatchers(MatchFinder *Finder) {
43 Finder->addMatcher(
44 functionTemplateDecl(
45 // Skip external libraries included as system headers
46 unless(isExpansionInSystemHeader()),
47 has(functionDecl(unless(hasOtherDeclarations()), isDefinition(),
48 hasReturnTypeLoc(typeLoc().bind("return")))
49 .bind("function")))
50 .bind("functionTemplate"),
51 this);
52}
53
54static std::optional<TemplateSpecializationTypeLoc>
56 if (const auto Dep = TheType.getAs<DependentNameTypeLoc>()) {
57 const IdentifierInfo *Identifier = Dep.getTypePtr()->getIdentifier();
58 const ElaboratedTypeKeyword Keyword = Dep.getTypePtr()->getKeyword();
59 if (!Identifier || Identifier->getName() != "type" ||
60 (Keyword != ElaboratedTypeKeyword::Typename &&
61 Keyword != ElaboratedTypeKeyword::None)) {
62 return std::nullopt;
63 }
64 TheType = Dep.getQualifierLoc().getAsTypeLoc();
65 if (TheType.isNull())
66 return std::nullopt;
67 } else {
68 return std::nullopt;
69 }
70
71 if (const auto SpecializationLoc =
72 TheType.getAs<TemplateSpecializationTypeLoc>()) {
73 const auto *Specialization =
74 dyn_cast<TemplateSpecializationType>(SpecializationLoc.getTypePtr());
75 if (!Specialization)
76 return std::nullopt;
77
78 const TemplateDecl *TD =
79 Specialization->getTemplateName().getAsTemplateDecl();
80 if (!TD || TD->getName() != "enable_if")
81 return std::nullopt;
82
83 assert(!TD->getTemplateParameters()->empty() &&
84 "found template with no template parameters?");
85 const auto *FirstParam = dyn_cast<NonTypeTemplateParmDecl>(
86 TD->getTemplateParameters()->getParam(0));
87 if (!FirstParam || !FirstParam->getType()->isBooleanType())
88 return std::nullopt;
89
90 const int NumArgs = SpecializationLoc.getNumArgs();
91 if (NumArgs != 1 && NumArgs != 2)
92 return std::nullopt;
93
94 return SpecializationLoc;
95 }
96 return std::nullopt;
97}
98
99static std::optional<TemplateSpecializationTypeLoc>
101 if (const auto SpecializationLoc =
102 TheType.getAs<TemplateSpecializationTypeLoc>()) {
103 const auto *Specialization =
104 dyn_cast<TemplateSpecializationType>(SpecializationLoc.getTypePtr());
105 if (!Specialization)
106 return std::nullopt;
107
108 const TemplateDecl *TD =
109 Specialization->getTemplateName().getAsTemplateDecl();
110 if (!TD || TD->getName() != "enable_if_t")
111 return std::nullopt;
112
113 if (!Specialization->isTypeAlias())
114 return std::nullopt;
115
116 assert(!TD->getTemplateParameters()->empty() &&
117 "found template with no template parameters?");
118 const auto *FirstParam = dyn_cast<NonTypeTemplateParmDecl>(
119 TD->getTemplateParameters()->getParam(0));
120 if (!FirstParam || !FirstParam->getType()->isBooleanType())
121 return std::nullopt;
122
123 if (const auto *AliasedType =
124 dyn_cast<DependentNameType>(Specialization->getAliasedType())) {
125 const ElaboratedTypeKeyword Keyword = AliasedType->getKeyword();
126 if (AliasedType->getIdentifier()->getName() != "type" ||
127 (Keyword != ElaboratedTypeKeyword::Typename &&
128 Keyword != ElaboratedTypeKeyword::None)) {
129 return std::nullopt;
130 }
131 } else {
132 return std::nullopt;
133 }
134 const int NumArgs = SpecializationLoc.getNumArgs();
135 if (NumArgs != 1 && NumArgs != 2)
136 return std::nullopt;
137
138 return SpecializationLoc;
139 }
140 return std::nullopt;
141}
142
143static std::optional<TemplateSpecializationTypeLoc>
145 if (auto EnableIf = matchEnableIfSpecializationImplTypename(TheType))
146 return EnableIf;
148}
149
150static std::optional<EnableIfData>
152 if (const auto Pointer = TheType.getAs<PointerTypeLoc>())
153 TheType = Pointer.getPointeeLoc();
154 else if (const auto Reference = TheType.getAs<ReferenceTypeLoc>())
155 TheType = Reference.getPointeeLoc();
156 if (const auto Qualified = TheType.getAs<QualifiedTypeLoc>())
157 TheType = Qualified.getUnqualifiedLoc();
158
159 if (auto EnableIf = matchEnableIfSpecializationImpl(TheType))
160 return EnableIfData{std::move(*EnableIf), TheType};
161 return std::nullopt;
162}
163
164static std::pair<std::optional<EnableIfData>, const Decl *>
165matchTrailingTemplateParam(const FunctionTemplateDecl *FunctionTemplate) {
166 // For non-type trailing param, match very specifically
167 // 'template <..., enable_if_type<Condition, Type> = Default>' where
168 // enable_if_type is 'enable_if' or 'enable_if_t'. E.g., 'template <typename
169 // T, enable_if_t<is_same_v<T, bool>, int*> = nullptr>
170 //
171 // Otherwise, match a trailing default type arg.
172 // E.g., 'template <typename T, typename = enable_if_t<is_same_v<T, bool>>>'
173
174 const TemplateParameterList *TemplateParams =
175 FunctionTemplate->getTemplateParameters();
176 if (TemplateParams->empty())
177 return {};
178
179 const NamedDecl *LastParam =
180 TemplateParams->getParam(TemplateParams->size() - 1);
181 if (const auto *LastTemplateParam =
182 dyn_cast<NonTypeTemplateParmDecl>(LastParam)) {
183 if (!LastTemplateParam->hasDefaultArgument() ||
184 !LastTemplateParam->getName().empty())
185 return {};
186
188 LastTemplateParam->getTypeSourceInfo()->getTypeLoc()),
189 LastTemplateParam};
190 }
191 if (const auto *LastTemplateParam =
192 dyn_cast<TemplateTypeParmDecl>(LastParam)) {
193 if (LastTemplateParam->hasDefaultArgument() &&
194 LastTemplateParam->getIdentifier() == nullptr) {
195 return {
196 matchEnableIfSpecialization(LastTemplateParam->getDefaultArgument()
197 .getTypeSourceInfo()
198 ->getTypeLoc()),
199 LastTemplateParam};
200 }
201 }
202 return {};
203}
204
205template <typename T>
206static SourceLocation getRAngleFileLoc(const SourceManager &SM,
207 const T &Element) {
208 // getFileLoc handles the case where the RAngle loc is part of a synthesized
209 // '>>', which ends up allocating a 'scratch space' buffer in the source
210 // manager.
211 return SM.getFileLoc(Element.getRAngleLoc());
212}
213
214static SourceRange
215getConditionRange(ASTContext &Context,
216 const TemplateSpecializationTypeLoc &EnableIf) {
217 // TemplateArgumentLoc's SourceRange End is the location of the last token
218 // (per UnqualifiedId docs). E.g., in `enable_if<AAA && BBB>`, the End
219 // location will be the first 'B' in 'BBB'.
220 const LangOptions &LangOpts = Context.getLangOpts();
221 const SourceManager &SM = Context.getSourceManager();
222 if (EnableIf.getNumArgs() > 1) {
223 const TemplateArgumentLoc NextArg = EnableIf.getArgLoc(1);
224 return {EnableIf.getLAngleLoc().getLocWithOffset(1),
226 NextArg.getSourceRange().getBegin(), SM, LangOpts, tok::comma)};
227 }
228
229 return {EnableIf.getLAngleLoc().getLocWithOffset(1),
230 getRAngleFileLoc(SM, EnableIf)};
231}
232
233static SourceRange getTypeRange(ASTContext &Context,
234 const TemplateSpecializationTypeLoc &EnableIf) {
235 const TemplateArgumentLoc Arg = EnableIf.getArgLoc(1);
236 const LangOptions &LangOpts = Context.getLangOpts();
237 const SourceManager &SM = Context.getSourceManager();
238 return {utils::lexer::findPreviousTokenKind(Arg.getSourceRange().getBegin(),
239 SM, LangOpts, tok::comma)
240 .getLocWithOffset(1),
241 getRAngleFileLoc(SM, EnableIf)};
242}
243
244// Returns the original source text of the second argument of a call to
245// enable_if_t. E.g., in enable_if_t<Condition, TheType>, this function
246// returns 'TheType'.
247static std::optional<StringRef>
248getTypeText(ASTContext &Context,
249 const TemplateSpecializationTypeLoc &EnableIf) {
250 if (EnableIf.getNumArgs() > 1) {
251 const LangOptions &LangOpts = Context.getLangOpts();
252 const SourceManager &SM = Context.getSourceManager();
253 bool Invalid = false;
254 StringRef Text = Lexer::getSourceText(CharSourceRange::getCharRange(
255 getTypeRange(Context, EnableIf)),
256 SM, LangOpts, &Invalid)
257 .trim();
258 if (Invalid)
259 return std::nullopt;
260
261 return Text;
262 }
263
264 return "void";
265}
266
267static std::optional<SourceLocation>
268findInsertionForConstraint(const FunctionDecl *Function, ASTContext &Context) {
269 const SourceManager &SM = Context.getSourceManager();
270 const LangOptions &LangOpts = Context.getLangOpts();
271
272 if (const auto *Constructor = dyn_cast<CXXConstructorDecl>(Function)) {
273 for (const CXXCtorInitializer *Init : Constructor->inits()) {
274 if (Init->getSourceOrder() == 0)
275 return utils::lexer::findPreviousTokenKind(Init->getSourceLocation(),
276 SM, LangOpts, tok::colon);
277 }
278 if (!Constructor->inits().empty())
279 return std::nullopt;
280 }
281 if (Function->isDeleted()) {
282 const SourceLocation FunctionEnd = Function->getSourceRange().getEnd();
283 return utils::lexer::findNextAnyTokenKind(FunctionEnd, SM, LangOpts,
284 tok::equal, tok::equal);
285 }
286 const Stmt *Body = Function->getBody();
287 if (!Body)
288 return std::nullopt;
289
290 return Body->getBeginLoc();
291}
292
293static bool isPrimaryExpression(const Expr *Expression) {
294 // This function is an incomplete approximation of checking whether
295 // an Expr is a primary expression. In particular, if this function
296 // returns true, the expression is a primary expression. The converse
297 // is not necessarily true.
298
299 if (const auto *Cast = dyn_cast<ImplicitCastExpr>(Expression))
300 Expression = Cast->getSubExprAsWritten();
301 if (isa<ParenExpr, DependentScopeDeclRefExpr>(Expression))
302 return true;
303
304 return false;
305}
306
307// Return the original source text of an enable_if_t condition, i.e., the
308// first template argument). For example, in
309// 'enable_if_t<FirstCondition || SecondCondition, AType>', the text
310// the text 'FirstCondition || SecondCondition' is returned.
311static std::optional<std::string> getConditionText(const Expr *ConditionExpr,
312 SourceRange ConditionRange,
313 ASTContext &Context) {
314 const SourceManager &SM = Context.getSourceManager();
315 const LangOptions &LangOpts = Context.getLangOpts();
316
317 SourceLocation PrevTokenLoc = ConditionRange.getEnd();
318 if (PrevTokenLoc.isInvalid())
319 return std::nullopt;
320
321 const bool SkipComments = false;
322 Token PrevToken;
323 std::tie(PrevToken, PrevTokenLoc) = utils::lexer::getPreviousTokenAndStart(
324 PrevTokenLoc, SM, LangOpts, SkipComments);
325 const bool EndsWithDoubleSlash =
326 PrevToken.is(tok::comment) &&
327 Lexer::getSourceText(CharSourceRange::getCharRange(
328 PrevTokenLoc, PrevTokenLoc.getLocWithOffset(2)),
329 SM, LangOpts) == "//";
330
331 bool Invalid = false;
332 const llvm::StringRef ConditionText = Lexer::getSourceText(
333 CharSourceRange::getCharRange(ConditionRange), SM, LangOpts, &Invalid);
334 if (Invalid)
335 return std::nullopt;
336
337 auto AddParens = [&](llvm::StringRef Text) -> std::string {
338 if (isPrimaryExpression(ConditionExpr))
339 return Text.str();
340 return "(" + Text.str() + ")";
341 };
342
343 if (EndsWithDoubleSlash)
344 return AddParens(ConditionText);
345 return AddParens(ConditionText.trim());
346}
347
348// Handle functions that return enable_if_t, e.g.,
349// template <...>
350// enable_if_t<Condition, ReturnType> function();
351//
352// Return a vector of FixItHints if the code can be replaced with
353// a C++20 requires clause. In the example above, returns FixItHints
354// to result in
355// template <...>
356// ReturnType function() requires Condition {}
357static std::vector<FixItHint> handleReturnType(const FunctionDecl *Function,
358 const TypeLoc &ReturnType,
359 const EnableIfData &EnableIf,
360 ASTContext &Context) {
361 const TemplateArgumentLoc EnableCondition = EnableIf.Loc.getArgLoc(0);
362
363 const SourceRange ConditionRange = getConditionRange(Context, EnableIf.Loc);
364
365 std::optional<std::string> ConditionText = getConditionText(
366 EnableCondition.getSourceExpression(), ConditionRange, Context);
367 if (!ConditionText)
368 return {};
369
370 std::optional<StringRef> TypeText = getTypeText(Context, EnableIf.Loc);
371 if (!TypeText)
372 return {};
373
374 SmallVector<AssociatedConstraint, 3> ExistingConstraints;
375 Function->getAssociatedConstraints(ExistingConstraints);
376 if (!ExistingConstraints.empty()) {
377 // FIXME - Support adding new constraints to existing ones. Do we need to
378 // consider subsumption?
379 return {};
380 }
381
382 std::optional<SourceLocation> ConstraintInsertionLoc =
383 findInsertionForConstraint(Function, Context);
384 if (!ConstraintInsertionLoc)
385 return {};
386
387 std::vector<FixItHint> FixIts;
388 FixIts.push_back(FixItHint::CreateReplacement(
389 CharSourceRange::getTokenRange(EnableIf.Outer.getSourceRange()),
390 *TypeText));
391 FixIts.push_back(FixItHint::CreateInsertion(
392 *ConstraintInsertionLoc, "requires " + *ConditionText + " "));
393 return FixIts;
394}
395
396// Handle enable_if_t in a trailing template parameter, e.g.,
397// template <..., enable_if_t<Condition, Type> = Type{}>
398// ReturnType function();
399//
400// Return a vector of FixItHints if the code can be replaced with
401// a C++20 requires clause. In the example above, returns FixItHints
402// to result in
403// template <...>
404// ReturnType function() requires Condition {}
405static std::vector<FixItHint>
406handleTrailingTemplateType(const FunctionTemplateDecl *FunctionTemplate,
407 const FunctionDecl *Function,
408 const Decl *LastTemplateParam,
409 const EnableIfData &EnableIf, ASTContext &Context) {
410 const SourceManager &SM = Context.getSourceManager();
411 const LangOptions &LangOpts = Context.getLangOpts();
412
413 const TemplateArgumentLoc EnableCondition = EnableIf.Loc.getArgLoc(0);
414
415 const SourceRange ConditionRange = getConditionRange(Context, EnableIf.Loc);
416
417 std::optional<std::string> ConditionText = getConditionText(
418 EnableCondition.getSourceExpression(), ConditionRange, Context);
419 if (!ConditionText)
420 return {};
421
422 SmallVector<AssociatedConstraint, 3> ExistingConstraints;
423 Function->getAssociatedConstraints(ExistingConstraints);
424 if (!ExistingConstraints.empty()) {
425 // FIXME - Support adding new constraints to existing ones. Do we need to
426 // consider subsumption?
427 return {};
428 }
429
430 SourceRange RemovalRange;
431 const TemplateParameterList *TemplateParams =
432 FunctionTemplate->getTemplateParameters();
433 if (!TemplateParams || TemplateParams->empty())
434 return {};
435
436 if (TemplateParams->size() == 1) {
437 RemovalRange =
438 SourceRange(TemplateParams->getTemplateLoc(),
439 getRAngleFileLoc(SM, *TemplateParams).getLocWithOffset(1));
440 } else {
441 RemovalRange =
443 LastTemplateParam->getSourceRange().getBegin(), SM,
444 LangOpts, tok::comma),
445 getRAngleFileLoc(SM, *TemplateParams));
446 }
447
448 std::optional<SourceLocation> ConstraintInsertionLoc =
449 findInsertionForConstraint(Function, Context);
450 if (!ConstraintInsertionLoc)
451 return {};
452
453 std::vector<FixItHint> FixIts;
454 FixIts.push_back(
455 FixItHint::CreateRemoval(CharSourceRange::getCharRange(RemovalRange)));
456 FixIts.push_back(FixItHint::CreateInsertion(
457 *ConstraintInsertionLoc, "requires " + *ConditionText + " "));
458 return FixIts;
459}
460
461void UseConstraintsCheck::check(const MatchFinder::MatchResult &Result) {
462 const auto *FunctionTemplate =
463 Result.Nodes.getNodeAs<FunctionTemplateDecl>("functionTemplate");
464 const auto *Function = Result.Nodes.getNodeAs<FunctionDecl>("function");
465 const auto *ReturnType = Result.Nodes.getNodeAs<TypeLoc>("return");
466 if (!FunctionTemplate || !Function || !ReturnType)
467 return;
468
469 // Check for
470 //
471 // Case 1. Return type of function
472 //
473 // template <...>
474 // enable_if_t<Condition, ReturnType>::type function() {}
475 //
476 // Case 2. Trailing template parameter
477 //
478 // template <..., enable_if_t<Condition, Type> = Type{}>
479 // ReturnType function() {}
480 //
481 // or
482 //
483 // template <..., typename = enable_if_t<Condition, void>>
484 // ReturnType function() {}
485 //
486
487 // Case 1. Return type of function
488 if (auto EnableIf = matchEnableIfSpecialization(*ReturnType)) {
489 diag(ReturnType->getBeginLoc(),
490 "use C++20 requires constraints instead of enable_if")
491 << handleReturnType(Function, *ReturnType, *EnableIf, *Result.Context);
492 return;
493 }
494
495 // Case 2. Trailing template parameter
496 if (auto [EnableIf, LastTemplateParam] =
497 matchTrailingTemplateParam(FunctionTemplate);
498 EnableIf && LastTemplateParam) {
499 diag(LastTemplateParam->getSourceRange().getBegin(),
500 "use C++20 requires constraints instead of enable_if")
501 << handleTrailingTemplateType(FunctionTemplate, Function,
502 LastTemplateParam, *EnableIf,
503 *Result.Context);
504 return;
505 }
506}
507
508} // namespace clang::tidy::modernize
void registerMatchers(ast_matchers::MatchFinder *Finder) override
void check(const ast_matchers::MatchFinder::MatchResult &Result) override
static std::pair< std::optional< EnableIfData >, const Decl * > matchTrailingTemplateParam(const FunctionTemplateDecl *FunctionTemplate)
static std::vector< FixItHint > handleReturnType(const FunctionDecl *Function, const TypeLoc &ReturnType, const EnableIfData &EnableIf, ASTContext &Context)
static SourceRange getTypeRange(ASTContext &Context, const TemplateSpecializationTypeLoc &EnableIf)
static std::optional< SourceLocation > findInsertionForConstraint(const FunctionDecl *Function, ASTContext &Context)
static std::optional< TemplateSpecializationTypeLoc > matchEnableIfSpecializationImplTypename(TypeLoc TheType)
static std::optional< std::string > getConditionText(const Expr *ConditionExpr, SourceRange ConditionRange, ASTContext &Context)
static std::vector< FixItHint > handleTrailingTemplateType(const FunctionTemplateDecl *FunctionTemplate, const FunctionDecl *Function, const Decl *LastTemplateParam, const EnableIfData &EnableIf, ASTContext &Context)
static std::optional< TemplateSpecializationTypeLoc > matchEnableIfSpecializationImplTrait(TypeLoc TheType)
static std::optional< StringRef > getTypeText(ASTContext &Context, const TemplateSpecializationTypeLoc &EnableIf)
static SourceRange getConditionRange(ASTContext &Context, const TemplateSpecializationTypeLoc &EnableIf)
static SourceLocation getRAngleFileLoc(const SourceManager &SM, const T &Element)
static std::optional< TemplateSpecializationTypeLoc > matchEnableIfSpecializationImpl(TypeLoc TheType)
static bool isPrimaryExpression(const Expr *Expression)
static std::optional< EnableIfData > matchEnableIfSpecialization(TypeLoc TheType)
std::pair< Token, SourceLocation > getPreviousTokenAndStart(SourceLocation Location, const SourceManager &SM, const LangOptions &LangOpts, bool SkipComments)
SourceLocation findNextAnyTokenKind(SourceLocation Start, const SourceManager &SM, const LangOptions &LangOpts, TokenKind TK, TokenKinds... TKs)
Definition LexerUtils.h:68
SourceLocation findPreviousTokenKind(SourceLocation Start, const SourceManager &SM, const LangOptions &LangOpts, tok::TokenKind TK)