10#include "clang/AST/ASTContext.h"
11#include "clang/AST/DeclTemplate.h"
12#include "clang/ASTMatchers/ASTMatchFinder.h"
13#include "clang/Lex/Lexer.h"
26 TemplateSpecializationTypeLoc Loc;
31 auto It = Node.redecls_begin();
32 auto EndIt = Node.redecls_end();
45 has(functionDecl(unless(hasOtherDeclarations()), isDefinition(),
46 hasReturnTypeLoc(typeLoc().bind(
"return")))
48 .bind(
"functionTemplate"),
52static std::optional<TemplateSpecializationTypeLoc>
54 if (
const auto Dep = TheType.getAs<DependentNameTypeLoc>()) {
55 const IdentifierInfo *Identifier = Dep.getTypePtr()->getIdentifier();
56 const ElaboratedTypeKeyword Keyword = Dep.getTypePtr()->getKeyword();
57 if (!Identifier || Identifier->getName() !=
"type" ||
58 (Keyword != ElaboratedTypeKeyword::Typename &&
59 Keyword != ElaboratedTypeKeyword::None)) {
62 TheType = Dep.getQualifierLoc().getAsTypeLoc();
69 if (
const auto SpecializationLoc =
70 TheType.getAs<TemplateSpecializationTypeLoc>()) {
71 const auto *Specialization =
72 dyn_cast<TemplateSpecializationType>(SpecializationLoc.getTypePtr());
76 const TemplateDecl *TD =
77 Specialization->getTemplateName().getAsTemplateDecl();
78 if (!TD || TD->getName() !=
"enable_if")
81 assert(!TD->getTemplateParameters()->empty() &&
82 "found template with no template parameters?");
83 const auto *FirstParam = dyn_cast<NonTypeTemplateParmDecl>(
84 TD->getTemplateParameters()->getParam(0));
85 if (!FirstParam || !FirstParam->getType()->isBooleanType())
88 const int NumArgs = SpecializationLoc.getNumArgs();
89 if (NumArgs != 1 && NumArgs != 2)
92 return SpecializationLoc;
97static std::optional<TemplateSpecializationTypeLoc>
99 if (
const auto SpecializationLoc =
100 TheType.getAs<TemplateSpecializationTypeLoc>()) {
101 const auto *Specialization =
102 dyn_cast<TemplateSpecializationType>(SpecializationLoc.getTypePtr());
106 const TemplateDecl *TD =
107 Specialization->getTemplateName().getAsTemplateDecl();
108 if (!TD || TD->getName() !=
"enable_if_t")
111 if (!Specialization->isTypeAlias())
114 assert(!TD->getTemplateParameters()->empty() &&
115 "found template with no template parameters?");
116 const auto *FirstParam = dyn_cast<NonTypeTemplateParmDecl>(
117 TD->getTemplateParameters()->getParam(0));
118 if (!FirstParam || !FirstParam->getType()->isBooleanType())
121 if (
const auto *AliasedType =
122 dyn_cast<DependentNameType>(Specialization->getAliasedType())) {
123 const ElaboratedTypeKeyword Keyword = AliasedType->getKeyword();
124 if (AliasedType->getIdentifier()->getName() !=
"type" ||
125 (Keyword != ElaboratedTypeKeyword::Typename &&
126 Keyword != ElaboratedTypeKeyword::None)) {
132 const int NumArgs = SpecializationLoc.getNumArgs();
133 if (NumArgs != 1 && NumArgs != 2)
136 return SpecializationLoc;
141static std::optional<TemplateSpecializationTypeLoc>
148static std::optional<EnableIfData>
150 if (
const auto Pointer = TheType.getAs<PointerTypeLoc>())
151 TheType = Pointer.getPointeeLoc();
152 else if (
const auto Reference = TheType.getAs<ReferenceTypeLoc>())
153 TheType = Reference.getPointeeLoc();
154 if (
const auto Qualified = TheType.getAs<QualifiedTypeLoc>())
155 TheType = Qualified.getUnqualifiedLoc();
158 return EnableIfData{std::move(*EnableIf), TheType};
162static std::pair<std::optional<EnableIfData>,
const Decl *>
172 const TemplateParameterList *TemplateParams =
173 FunctionTemplate->getTemplateParameters();
174 if (TemplateParams->empty())
177 const NamedDecl *LastParam =
178 TemplateParams->getParam(TemplateParams->size() - 1);
179 if (
const auto *LastTemplateParam =
180 dyn_cast<NonTypeTemplateParmDecl>(LastParam)) {
181 if (!LastTemplateParam->hasDefaultArgument() ||
182 !LastTemplateParam->getName().empty())
186 LastTemplateParam->getTypeSourceInfo()->getTypeLoc()),
189 if (
const auto *LastTemplateParam =
190 dyn_cast<TemplateTypeParmDecl>(LastParam)) {
191 if (LastTemplateParam->hasDefaultArgument() &&
192 LastTemplateParam->getIdentifier() ==
nullptr) {
209 return SM.getFileLoc(Element.getRAngleLoc());
214 const TemplateSpecializationTypeLoc &EnableIf) {
218 const LangOptions &LangOpts = Context.getLangOpts();
219 const SourceManager &SM = Context.getSourceManager();
220 if (EnableIf.getNumArgs() > 1) {
221 const TemplateArgumentLoc NextArg = EnableIf.getArgLoc(1);
222 return {EnableIf.getLAngleLoc().getLocWithOffset(1),
224 NextArg.getSourceRange().getBegin(), SM, LangOpts, tok::comma)};
227 return {EnableIf.getLAngleLoc().getLocWithOffset(1),
232 const TemplateSpecializationTypeLoc &EnableIf) {
233 const TemplateArgumentLoc Arg = EnableIf.getArgLoc(1);
234 const LangOptions &LangOpts = Context.getLangOpts();
235 const SourceManager &SM = Context.getSourceManager();
237 SM, LangOpts, tok::comma)
238 .getLocWithOffset(1),
245static std::optional<StringRef>
247 const TemplateSpecializationTypeLoc &EnableIf) {
248 if (EnableIf.getNumArgs() > 1) {
249 const LangOptions &LangOpts = Context.getLangOpts();
250 const SourceManager &SM = Context.getSourceManager();
251 bool Invalid =
false;
252 StringRef Text = Lexer::getSourceText(CharSourceRange::getCharRange(
254 SM, LangOpts, &Invalid)
265static std::optional<SourceLocation>
267 const SourceManager &SM = Context.getSourceManager();
268 const LangOptions &LangOpts = Context.getLangOpts();
270 if (
const auto *Constructor = dyn_cast<CXXConstructorDecl>(Function)) {
271 for (
const CXXCtorInitializer *Init : Constructor->inits())
272 if (Init->getSourceOrder() == 0)
274 SM, LangOpts, tok::colon);
275 if (!Constructor->inits().empty())
278 if (Function->isDeleted()) {
279 const SourceLocation FunctionEnd = Function->getSourceRange().getEnd();
281 tok::equal, tok::equal);
283 const Stmt *Body = Function->getBody();
287 return Body->getBeginLoc();
296 if (
const auto *Cast = dyn_cast<ImplicitCastExpr>(Expression))
297 Expression = Cast->getSubExprAsWritten();
298 if (isa<ParenExpr, DependentScopeDeclRefExpr>(Expression))
309 SourceRange ConditionRange,
310 ASTContext &Context) {
311 const SourceManager &SM = Context.getSourceManager();
312 const LangOptions &LangOpts = Context.getLangOpts();
314 SourceLocation PrevTokenLoc = ConditionRange.getEnd();
315 if (PrevTokenLoc.isInvalid())
318 const bool SkipComments =
false;
319 std::optional<Token> PrevToken;
321 PrevTokenLoc, SM, LangOpts, SkipComments);
322 const bool EndsWithDoubleSlash =
323 PrevToken && PrevToken->is(tok::comment) &&
324 Lexer::getSourceText(CharSourceRange::getCharRange(
325 PrevTokenLoc, PrevTokenLoc.getLocWithOffset(2)),
326 SM, LangOpts) ==
"//";
328 bool Invalid =
false;
329 const StringRef ConditionText = Lexer::getSourceText(
330 CharSourceRange::getCharRange(ConditionRange), SM, LangOpts, &Invalid);
334 auto AddParens = [&](StringRef Text) -> std::string {
337 return "(" + Text.str() +
")";
340 if (EndsWithDoubleSlash)
341 return AddParens(ConditionText);
342 return AddParens(ConditionText.trim());
355 const TypeLoc &ReturnType,
356 const EnableIfData &EnableIf,
357 ASTContext &Context) {
358 const TemplateArgumentLoc EnableCondition = EnableIf.Loc.getArgLoc(0);
363 EnableCondition.getSourceExpression(), ConditionRange, Context);
367 std::optional<StringRef> TypeText =
getTypeText(Context, EnableIf.Loc);
372 Function->getAssociatedConstraints(ExistingConstraints);
373 if (!ExistingConstraints.empty()) {
379 std::optional<SourceLocation> ConstraintInsertionLoc =
381 if (!ConstraintInsertionLoc)
384 std::vector<FixItHint> FixIts;
385 FixIts.push_back(FixItHint::CreateReplacement(
386 CharSourceRange::getTokenRange(EnableIf.Outer.getSourceRange()),
388 FixIts.push_back(FixItHint::CreateInsertion(
389 *ConstraintInsertionLoc,
"requires " + *ConditionText +
" "));
402static std::vector<FixItHint>
404 const FunctionDecl *Function,
405 const Decl *LastTemplateParam,
406 const EnableIfData &EnableIf, ASTContext &Context) {
407 const SourceManager &SM = Context.getSourceManager();
408 const LangOptions &LangOpts = Context.getLangOpts();
410 const TemplateArgumentLoc EnableCondition = EnableIf.Loc.getArgLoc(0);
415 EnableCondition.getSourceExpression(), ConditionRange, Context);
420 Function->getAssociatedConstraints(ExistingConstraints);
421 if (!ExistingConstraints.empty()) {
427 SourceRange RemovalRange;
428 const TemplateParameterList *TemplateParams =
429 FunctionTemplate->getTemplateParameters();
430 if (!TemplateParams || TemplateParams->empty())
433 if (TemplateParams->size() == 1) {
435 SourceRange(TemplateParams->getTemplateLoc(),
440 LastTemplateParam->getSourceRange().getBegin(), SM,
441 LangOpts, tok::comma),
445 std::optional<SourceLocation> ConstraintInsertionLoc =
447 if (!ConstraintInsertionLoc)
450 std::vector<FixItHint> FixIts;
452 FixItHint::CreateRemoval(CharSourceRange::getCharRange(RemovalRange)));
453 FixIts.push_back(FixItHint::CreateInsertion(
454 *ConstraintInsertionLoc,
"requires " + *ConditionText +
" "));
459 const auto *FunctionTemplate =
460 Result.Nodes.getNodeAs<FunctionTemplateDecl>(
"functionTemplate");
461 const auto *Function = Result.Nodes.getNodeAs<FunctionDecl>(
"function");
462 const auto *ReturnType = Result.Nodes.getNodeAs<TypeLoc>(
"return");
463 if (!FunctionTemplate || !Function || !ReturnType)
486 diag(ReturnType->getBeginLoc(),
487 "use C++20 requires constraints instead of enable_if")
493 if (
auto [EnableIf, LastTemplateParam] =
495 EnableIf && LastTemplateParam) {
496 diag(LastTemplateParam->getSourceRange().getBegin(),
497 "use C++20 requires constraints instead of enable_if")
499 LastTemplateParam, *EnableIf,
void registerMatchers(ast_matchers::MatchFinder *Finder) override
void check(const ast_matchers::MatchFinder::MatchResult &Result) override
AST_MATCHER(BinaryOperator, isRelationalOperator)
static std::pair< std::optional< EnableIfData >, const Decl * > matchTrailingTemplateParam(const FunctionTemplateDecl *FunctionTemplate)
static std::vector< FixItHint > handleReturnType(const FunctionDecl *Function, const TypeLoc &ReturnType, const EnableIfData &EnableIf, ASTContext &Context)
static SourceRange getTypeRange(ASTContext &Context, const TemplateSpecializationTypeLoc &EnableIf)
static std::optional< SourceLocation > findInsertionForConstraint(const FunctionDecl *Function, ASTContext &Context)
static std::optional< TemplateSpecializationTypeLoc > matchEnableIfSpecializationImplTypename(TypeLoc TheType)
static std::optional< std::string > getConditionText(const Expr *ConditionExpr, SourceRange ConditionRange, ASTContext &Context)
static std::vector< FixItHint > handleTrailingTemplateType(const FunctionTemplateDecl *FunctionTemplate, const FunctionDecl *Function, const Decl *LastTemplateParam, const EnableIfData &EnableIf, ASTContext &Context)
static std::optional< TemplateSpecializationTypeLoc > matchEnableIfSpecializationImplTrait(TypeLoc TheType)
static std::optional< StringRef > getTypeText(ASTContext &Context, const TemplateSpecializationTypeLoc &EnableIf)
static SourceRange getConditionRange(ASTContext &Context, const TemplateSpecializationTypeLoc &EnableIf)
static SourceLocation getRAngleFileLoc(const SourceManager &SM, const T &Element)
static std::optional< TemplateSpecializationTypeLoc > matchEnableIfSpecializationImpl(TypeLoc TheType)
static bool isPrimaryExpression(const Expr *Expression)
static std::optional< EnableIfData > matchEnableIfSpecialization(TypeLoc TheType)
std::pair< std::optional< Token >, SourceLocation > getPreviousTokenAndStart(SourceLocation Location, const SourceManager &SM, const LangOptions &LangOpts, bool SkipComments)
SourceLocation findNextAnyTokenKind(SourceLocation Start, const SourceManager &SM, const LangOptions &LangOpts, TokenKind TK, TokenKinds... TKs)
SourceLocation findPreviousTokenKind(SourceLocation Start, const SourceManager &SM, const LangOptions &LangOpts, tok::TokenKind TK)