clang-tools 23.0.0git
UseConstraintsCheck.cpp
Go to the documentation of this file.
1//===----------------------------------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
10#include "clang/AST/ASTContext.h"
11#include "clang/AST/DeclTemplate.h"
12#include "clang/ASTMatchers/ASTMatchFinder.h"
13#include "clang/Lex/Lexer.h"
14
15#include "../utils/LexerUtils.h"
16
17#include <optional>
18#include <utility>
19
20using namespace clang::ast_matchers;
21
22namespace clang::tidy::modernize {
23
24namespace {
25struct EnableIfData {
26 TemplateSpecializationTypeLoc Loc;
27 TypeLoc Outer;
28};
29
30AST_MATCHER(FunctionDecl, hasOtherDeclarations) {
31 auto It = Node.redecls_begin();
32 auto EndIt = Node.redecls_end();
33
34 if (It == EndIt)
35 return false;
36
37 ++It;
38 return It != EndIt;
39}
40} // namespace
41
42void UseConstraintsCheck::registerMatchers(MatchFinder *Finder) {
43 Finder->addMatcher(
44 functionTemplateDecl(
45 has(functionDecl(unless(hasOtherDeclarations()), isDefinition(),
46 hasReturnTypeLoc(typeLoc().bind("return")))
47 .bind("function")))
48 .bind("functionTemplate"),
49 this);
50}
51
52static std::optional<TemplateSpecializationTypeLoc>
54 if (const auto Dep = TheType.getAs<DependentNameTypeLoc>()) {
55 const IdentifierInfo *Identifier = Dep.getTypePtr()->getIdentifier();
56 const ElaboratedTypeKeyword Keyword = Dep.getTypePtr()->getKeyword();
57 if (!Identifier || Identifier->getName() != "type" ||
58 (Keyword != ElaboratedTypeKeyword::Typename &&
59 Keyword != ElaboratedTypeKeyword::None)) {
60 return std::nullopt;
61 }
62 TheType = Dep.getQualifierLoc().getAsTypeLoc();
63 if (TheType.isNull())
64 return std::nullopt;
65 } else {
66 return std::nullopt;
67 }
68
69 if (const auto SpecializationLoc =
70 TheType.getAs<TemplateSpecializationTypeLoc>()) {
71 const auto *Specialization =
72 dyn_cast<TemplateSpecializationType>(SpecializationLoc.getTypePtr());
73 if (!Specialization)
74 return std::nullopt;
75
76 const TemplateDecl *TD =
77 Specialization->getTemplateName().getAsTemplateDecl();
78 if (!TD || TD->getName() != "enable_if")
79 return std::nullopt;
80
81 assert(!TD->getTemplateParameters()->empty() &&
82 "found template with no template parameters?");
83 const auto *FirstParam = dyn_cast<NonTypeTemplateParmDecl>(
84 TD->getTemplateParameters()->getParam(0));
85 if (!FirstParam || !FirstParam->getType()->isBooleanType())
86 return std::nullopt;
87
88 const int NumArgs = SpecializationLoc.getNumArgs();
89 if (NumArgs != 1 && NumArgs != 2)
90 return std::nullopt;
91
92 return SpecializationLoc;
93 }
94 return std::nullopt;
95}
96
97static std::optional<TemplateSpecializationTypeLoc>
99 if (const auto SpecializationLoc =
100 TheType.getAs<TemplateSpecializationTypeLoc>()) {
101 const auto *Specialization =
102 dyn_cast<TemplateSpecializationType>(SpecializationLoc.getTypePtr());
103 if (!Specialization)
104 return std::nullopt;
105
106 const TemplateDecl *TD =
107 Specialization->getTemplateName().getAsTemplateDecl();
108 if (!TD || TD->getName() != "enable_if_t")
109 return std::nullopt;
110
111 if (!Specialization->isTypeAlias())
112 return std::nullopt;
113
114 assert(!TD->getTemplateParameters()->empty() &&
115 "found template with no template parameters?");
116 const auto *FirstParam = dyn_cast<NonTypeTemplateParmDecl>(
117 TD->getTemplateParameters()->getParam(0));
118 if (!FirstParam || !FirstParam->getType()->isBooleanType())
119 return std::nullopt;
120
121 if (const auto *AliasedType =
122 dyn_cast<DependentNameType>(Specialization->getAliasedType())) {
123 const ElaboratedTypeKeyword Keyword = AliasedType->getKeyword();
124 if (AliasedType->getIdentifier()->getName() != "type" ||
125 (Keyword != ElaboratedTypeKeyword::Typename &&
126 Keyword != ElaboratedTypeKeyword::None)) {
127 return std::nullopt;
128 }
129 } else {
130 return std::nullopt;
131 }
132 const int NumArgs = SpecializationLoc.getNumArgs();
133 if (NumArgs != 1 && NumArgs != 2)
134 return std::nullopt;
135
136 return SpecializationLoc;
137 }
138 return std::nullopt;
139}
140
141static std::optional<TemplateSpecializationTypeLoc>
143 if (auto EnableIf = matchEnableIfSpecializationImplTypename(TheType))
144 return EnableIf;
146}
147
148static std::optional<EnableIfData>
150 if (const auto Pointer = TheType.getAs<PointerTypeLoc>())
151 TheType = Pointer.getPointeeLoc();
152 else if (const auto Reference = TheType.getAs<ReferenceTypeLoc>())
153 TheType = Reference.getPointeeLoc();
154 if (const auto Qualified = TheType.getAs<QualifiedTypeLoc>())
155 TheType = Qualified.getUnqualifiedLoc();
156
157 if (auto EnableIf = matchEnableIfSpecializationImpl(TheType))
158 return EnableIfData{std::move(*EnableIf), TheType};
159 return std::nullopt;
160}
161
162static std::pair<std::optional<EnableIfData>, const Decl *>
163matchTrailingTemplateParam(const FunctionTemplateDecl *FunctionTemplate) {
164 // For non-type trailing param, match very specifically
165 // 'template <..., enable_if_type<Condition, Type> = Default>' where
166 // enable_if_type is 'enable_if' or 'enable_if_t'. E.g., 'template <typename
167 // T, enable_if_t<is_same_v<T, bool>, int*> = nullptr>
168 //
169 // Otherwise, match a trailing default type arg.
170 // E.g., 'template <typename T, typename = enable_if_t<is_same_v<T, bool>>>'
171
172 const TemplateParameterList *TemplateParams =
173 FunctionTemplate->getTemplateParameters();
174 if (TemplateParams->empty())
175 return {};
176
177 const NamedDecl *LastParam =
178 TemplateParams->getParam(TemplateParams->size() - 1);
179 if (const auto *LastTemplateParam =
180 dyn_cast<NonTypeTemplateParmDecl>(LastParam)) {
181 if (!LastTemplateParam->hasDefaultArgument() ||
182 !LastTemplateParam->getName().empty())
183 return {};
184
186 LastTemplateParam->getTypeSourceInfo()->getTypeLoc()),
187 LastTemplateParam};
188 }
189 if (const auto *LastTemplateParam =
190 dyn_cast<TemplateTypeParmDecl>(LastParam)) {
191 if (LastTemplateParam->hasDefaultArgument() &&
192 LastTemplateParam->getIdentifier() == nullptr) {
193 return {
194 matchEnableIfSpecialization(LastTemplateParam->getDefaultArgument()
195 .getTypeSourceInfo()
196 ->getTypeLoc()),
197 LastTemplateParam};
198 }
199 }
200 return {};
201}
202
203template <typename T>
204static SourceLocation getRAngleFileLoc(const SourceManager &SM,
205 const T &Element) {
206 // getFileLoc handles the case where the RAngle loc is part of a synthesized
207 // '>>', which ends up allocating a 'scratch space' buffer in the source
208 // manager.
209 return SM.getFileLoc(Element.getRAngleLoc());
210}
211
212static SourceRange
213getConditionRange(ASTContext &Context,
214 const TemplateSpecializationTypeLoc &EnableIf) {
215 // TemplateArgumentLoc's SourceRange End is the location of the last token
216 // (per UnqualifiedId docs). E.g., in `enable_if<AAA && BBB>`, the End
217 // location will be the first 'B' in 'BBB'.
218 const LangOptions &LangOpts = Context.getLangOpts();
219 const SourceManager &SM = Context.getSourceManager();
220 if (EnableIf.getNumArgs() > 1) {
221 const TemplateArgumentLoc NextArg = EnableIf.getArgLoc(1);
222 return {EnableIf.getLAngleLoc().getLocWithOffset(1),
224 NextArg.getSourceRange().getBegin(), SM, LangOpts, tok::comma)};
225 }
226
227 return {EnableIf.getLAngleLoc().getLocWithOffset(1),
228 getRAngleFileLoc(SM, EnableIf)};
229}
230
231static SourceRange getTypeRange(ASTContext &Context,
232 const TemplateSpecializationTypeLoc &EnableIf) {
233 const TemplateArgumentLoc Arg = EnableIf.getArgLoc(1);
234 const LangOptions &LangOpts = Context.getLangOpts();
235 const SourceManager &SM = Context.getSourceManager();
236 return {utils::lexer::findPreviousTokenKind(Arg.getSourceRange().getBegin(),
237 SM, LangOpts, tok::comma)
238 .getLocWithOffset(1),
239 getRAngleFileLoc(SM, EnableIf)};
240}
241
242// Returns the original source text of the second argument of a call to
243// enable_if_t. E.g., in enable_if_t<Condition, TheType>, this function
244// returns 'TheType'.
245static std::optional<StringRef>
246getTypeText(ASTContext &Context,
247 const TemplateSpecializationTypeLoc &EnableIf) {
248 if (EnableIf.getNumArgs() > 1) {
249 const LangOptions &LangOpts = Context.getLangOpts();
250 const SourceManager &SM = Context.getSourceManager();
251 bool Invalid = false;
252 StringRef Text = Lexer::getSourceText(CharSourceRange::getCharRange(
253 getTypeRange(Context, EnableIf)),
254 SM, LangOpts, &Invalid)
255 .trim();
256 if (Invalid)
257 return std::nullopt;
258
259 return Text;
260 }
261
262 return "void";
263}
264
265static std::optional<SourceLocation>
266findInsertionForConstraint(const FunctionDecl *Function, ASTContext &Context) {
267 const SourceManager &SM = Context.getSourceManager();
268 const LangOptions &LangOpts = Context.getLangOpts();
269
270 if (const auto *Constructor = dyn_cast<CXXConstructorDecl>(Function)) {
271 for (const CXXCtorInitializer *Init : Constructor->inits())
272 if (Init->getSourceOrder() == 0)
273 return utils::lexer::findPreviousTokenKind(Init->getSourceLocation(),
274 SM, LangOpts, tok::colon);
275 if (!Constructor->inits().empty())
276 return std::nullopt;
277 }
278 if (Function->isDeleted()) {
279 const SourceLocation FunctionEnd = Function->getSourceRange().getEnd();
280 return utils::lexer::findNextAnyTokenKind(FunctionEnd, SM, LangOpts,
281 tok::equal, tok::equal);
282 }
283 const Stmt *Body = Function->getBody();
284 if (!Body)
285 return std::nullopt;
286
287 return Body->getBeginLoc();
288}
289
290static bool isPrimaryExpression(const Expr *Expression) {
291 // This function is an incomplete approximation of checking whether
292 // an Expr is a primary expression. In particular, if this function
293 // returns true, the expression is a primary expression. The converse
294 // is not necessarily true.
295
296 if (const auto *Cast = dyn_cast<ImplicitCastExpr>(Expression))
297 Expression = Cast->getSubExprAsWritten();
298 if (isa<ParenExpr, DependentScopeDeclRefExpr>(Expression))
299 return true;
300
301 return false;
302}
303
304// Return the original source text of an enable_if_t condition, i.e., the
305// first template argument). For example, in
306// 'enable_if_t<FirstCondition || SecondCondition, AType>', the text
307// the text 'FirstCondition || SecondCondition' is returned.
308static std::optional<std::string> getConditionText(const Expr *ConditionExpr,
309 SourceRange ConditionRange,
310 ASTContext &Context) {
311 const SourceManager &SM = Context.getSourceManager();
312 const LangOptions &LangOpts = Context.getLangOpts();
313
314 SourceLocation PrevTokenLoc = ConditionRange.getEnd();
315 if (PrevTokenLoc.isInvalid())
316 return std::nullopt;
317
318 const bool SkipComments = false;
319 std::optional<Token> PrevToken;
320 std::tie(PrevToken, PrevTokenLoc) = utils::lexer::getPreviousTokenAndStart(
321 PrevTokenLoc, SM, LangOpts, SkipComments);
322 const bool EndsWithDoubleSlash =
323 PrevToken && PrevToken->is(tok::comment) &&
324 Lexer::getSourceText(CharSourceRange::getCharRange(
325 PrevTokenLoc, PrevTokenLoc.getLocWithOffset(2)),
326 SM, LangOpts) == "//";
327
328 bool Invalid = false;
329 const StringRef ConditionText = Lexer::getSourceText(
330 CharSourceRange::getCharRange(ConditionRange), SM, LangOpts, &Invalid);
331 if (Invalid)
332 return std::nullopt;
333
334 auto AddParens = [&](StringRef Text) -> std::string {
335 if (isPrimaryExpression(ConditionExpr))
336 return Text.str();
337 return "(" + Text.str() + ")";
338 };
339
340 if (EndsWithDoubleSlash)
341 return AddParens(ConditionText);
342 return AddParens(ConditionText.trim());
343}
344
345// Handle functions that return enable_if_t, e.g.,
346// template <...>
347// enable_if_t<Condition, ReturnType> function();
348//
349// Return a vector of FixItHints if the code can be replaced with
350// a C++20 requires clause. In the example above, returns FixItHints
351// to result in
352// template <...>
353// ReturnType function() requires Condition {}
354static std::vector<FixItHint> handleReturnType(const FunctionDecl *Function,
355 const TypeLoc &ReturnType,
356 const EnableIfData &EnableIf,
357 ASTContext &Context) {
358 const TemplateArgumentLoc EnableCondition = EnableIf.Loc.getArgLoc(0);
359
360 const SourceRange ConditionRange = getConditionRange(Context, EnableIf.Loc);
361
362 std::optional<std::string> ConditionText = getConditionText(
363 EnableCondition.getSourceExpression(), ConditionRange, Context);
364 if (!ConditionText)
365 return {};
366
367 std::optional<StringRef> TypeText = getTypeText(Context, EnableIf.Loc);
368 if (!TypeText)
369 return {};
370
371 SmallVector<AssociatedConstraint, 3> ExistingConstraints;
372 Function->getAssociatedConstraints(ExistingConstraints);
373 if (!ExistingConstraints.empty()) {
374 // FIXME - Support adding new constraints to existing ones. Do we need to
375 // consider subsumption?
376 return {};
377 }
378
379 std::optional<SourceLocation> ConstraintInsertionLoc =
380 findInsertionForConstraint(Function, Context);
381 if (!ConstraintInsertionLoc)
382 return {};
383
384 std::vector<FixItHint> FixIts;
385 FixIts.push_back(FixItHint::CreateReplacement(
386 CharSourceRange::getTokenRange(EnableIf.Outer.getSourceRange()),
387 *TypeText));
388 FixIts.push_back(FixItHint::CreateInsertion(
389 *ConstraintInsertionLoc, "requires " + *ConditionText + " "));
390 return FixIts;
391}
392
393// Handle enable_if_t in a trailing template parameter, e.g.,
394// template <..., enable_if_t<Condition, Type> = Type{}>
395// ReturnType function();
396//
397// Return a vector of FixItHints if the code can be replaced with
398// a C++20 requires clause. In the example above, returns FixItHints
399// to result in
400// template <...>
401// ReturnType function() requires Condition {}
402static std::vector<FixItHint>
403handleTrailingTemplateType(const FunctionTemplateDecl *FunctionTemplate,
404 const FunctionDecl *Function,
405 const Decl *LastTemplateParam,
406 const EnableIfData &EnableIf, ASTContext &Context) {
407 const SourceManager &SM = Context.getSourceManager();
408 const LangOptions &LangOpts = Context.getLangOpts();
409
410 const TemplateArgumentLoc EnableCondition = EnableIf.Loc.getArgLoc(0);
411
412 const SourceRange ConditionRange = getConditionRange(Context, EnableIf.Loc);
413
414 std::optional<std::string> ConditionText = getConditionText(
415 EnableCondition.getSourceExpression(), ConditionRange, Context);
416 if (!ConditionText)
417 return {};
418
419 SmallVector<AssociatedConstraint, 3> ExistingConstraints;
420 Function->getAssociatedConstraints(ExistingConstraints);
421 if (!ExistingConstraints.empty()) {
422 // FIXME - Support adding new constraints to existing ones. Do we need to
423 // consider subsumption?
424 return {};
425 }
426
427 SourceRange RemovalRange;
428 const TemplateParameterList *TemplateParams =
429 FunctionTemplate->getTemplateParameters();
430 if (!TemplateParams || TemplateParams->empty())
431 return {};
432
433 if (TemplateParams->size() == 1) {
434 RemovalRange =
435 SourceRange(TemplateParams->getTemplateLoc(),
436 getRAngleFileLoc(SM, *TemplateParams).getLocWithOffset(1));
437 } else {
438 RemovalRange =
440 LastTemplateParam->getSourceRange().getBegin(), SM,
441 LangOpts, tok::comma),
442 getRAngleFileLoc(SM, *TemplateParams));
443 }
444
445 std::optional<SourceLocation> ConstraintInsertionLoc =
446 findInsertionForConstraint(Function, Context);
447 if (!ConstraintInsertionLoc)
448 return {};
449
450 std::vector<FixItHint> FixIts;
451 FixIts.push_back(
452 FixItHint::CreateRemoval(CharSourceRange::getCharRange(RemovalRange)));
453 FixIts.push_back(FixItHint::CreateInsertion(
454 *ConstraintInsertionLoc, "requires " + *ConditionText + " "));
455 return FixIts;
456}
457
458void UseConstraintsCheck::check(const MatchFinder::MatchResult &Result) {
459 const auto *FunctionTemplate =
460 Result.Nodes.getNodeAs<FunctionTemplateDecl>("functionTemplate");
461 const auto *Function = Result.Nodes.getNodeAs<FunctionDecl>("function");
462 const auto *ReturnType = Result.Nodes.getNodeAs<TypeLoc>("return");
463 if (!FunctionTemplate || !Function || !ReturnType)
464 return;
465
466 // Check for
467 //
468 // Case 1. Return type of function
469 //
470 // template <...>
471 // enable_if_t<Condition, ReturnType>::type function() {}
472 //
473 // Case 2. Trailing template parameter
474 //
475 // template <..., enable_if_t<Condition, Type> = Type{}>
476 // ReturnType function() {}
477 //
478 // or
479 //
480 // template <..., typename = enable_if_t<Condition, void>>
481 // ReturnType function() {}
482 //
483
484 // Case 1. Return type of function
485 if (auto EnableIf = matchEnableIfSpecialization(*ReturnType)) {
486 diag(ReturnType->getBeginLoc(),
487 "use C++20 requires constraints instead of enable_if")
488 << handleReturnType(Function, *ReturnType, *EnableIf, *Result.Context);
489 return;
490 }
491
492 // Case 2. Trailing template parameter
493 if (auto [EnableIf, LastTemplateParam] =
494 matchTrailingTemplateParam(FunctionTemplate);
495 EnableIf && LastTemplateParam) {
496 diag(LastTemplateParam->getSourceRange().getBegin(),
497 "use C++20 requires constraints instead of enable_if")
498 << handleTrailingTemplateType(FunctionTemplate, Function,
499 LastTemplateParam, *EnableIf,
500 *Result.Context);
501 return;
502 }
503}
504
505} // namespace clang::tidy::modernize
void registerMatchers(ast_matchers::MatchFinder *Finder) override
void check(const ast_matchers::MatchFinder::MatchResult &Result) override
AST_MATCHER(BinaryOperator, isRelationalOperator)
static std::pair< std::optional< EnableIfData >, const Decl * > matchTrailingTemplateParam(const FunctionTemplateDecl *FunctionTemplate)
static std::vector< FixItHint > handleReturnType(const FunctionDecl *Function, const TypeLoc &ReturnType, const EnableIfData &EnableIf, ASTContext &Context)
static SourceRange getTypeRange(ASTContext &Context, const TemplateSpecializationTypeLoc &EnableIf)
static std::optional< SourceLocation > findInsertionForConstraint(const FunctionDecl *Function, ASTContext &Context)
static std::optional< TemplateSpecializationTypeLoc > matchEnableIfSpecializationImplTypename(TypeLoc TheType)
static std::optional< std::string > getConditionText(const Expr *ConditionExpr, SourceRange ConditionRange, ASTContext &Context)
static std::vector< FixItHint > handleTrailingTemplateType(const FunctionTemplateDecl *FunctionTemplate, const FunctionDecl *Function, const Decl *LastTemplateParam, const EnableIfData &EnableIf, ASTContext &Context)
static std::optional< TemplateSpecializationTypeLoc > matchEnableIfSpecializationImplTrait(TypeLoc TheType)
static std::optional< StringRef > getTypeText(ASTContext &Context, const TemplateSpecializationTypeLoc &EnableIf)
static SourceRange getConditionRange(ASTContext &Context, const TemplateSpecializationTypeLoc &EnableIf)
static SourceLocation getRAngleFileLoc(const SourceManager &SM, const T &Element)
static std::optional< TemplateSpecializationTypeLoc > matchEnableIfSpecializationImpl(TypeLoc TheType)
static bool isPrimaryExpression(const Expr *Expression)
static std::optional< EnableIfData > matchEnableIfSpecialization(TypeLoc TheType)
std::pair< std::optional< Token >, SourceLocation > getPreviousTokenAndStart(SourceLocation Location, const SourceManager &SM, const LangOptions &LangOpts, bool SkipComments)
SourceLocation findNextAnyTokenKind(SourceLocation Start, const SourceManager &SM, const LangOptions &LangOpts, TokenKind TK, TokenKinds... TKs)
Definition LexerUtils.h:73
SourceLocation findPreviousTokenKind(SourceLocation Start, const SourceManager &SM, const LangOptions &LangOpts, tok::TokenKind TK)