10#include "clang/AST/ASTContext.h"
11#include "clang/AST/DeclTemplate.h"
12#include "clang/ASTMatchers/ASTMatchFinder.h"
13#include "clang/Lex/Lexer.h"
25 TemplateSpecializationTypeLoc
Loc;
30AST_MATCHER(FunctionDecl, hasOtherDeclarations) {
31 auto It = Node.redecls_begin();
32 auto EndIt = Node.redecls_end();
46 unless(isExpansionInSystemHeader()),
47 has(functionDecl(unless(hasOtherDeclarations()), isDefinition(),
48 hasReturnTypeLoc(typeLoc().bind(
"return")))
50 .bind(
"functionTemplate"),
54static std::optional<TemplateSpecializationTypeLoc>
56 if (
const auto Dep = TheType.getAs<DependentNameTypeLoc>()) {
57 const IdentifierInfo *Identifier = Dep.getTypePtr()->getIdentifier();
58 const ElaboratedTypeKeyword Keyword = Dep.getTypePtr()->getKeyword();
59 if (!Identifier || Identifier->getName() !=
"type" ||
60 (Keyword != ElaboratedTypeKeyword::Typename &&
61 Keyword != ElaboratedTypeKeyword::None)) {
64 TheType = Dep.getQualifierLoc().getAsTypeLoc();
71 if (
const auto SpecializationLoc =
72 TheType.getAs<TemplateSpecializationTypeLoc>()) {
73 const auto *Specialization =
74 dyn_cast<TemplateSpecializationType>(SpecializationLoc.getTypePtr());
78 const TemplateDecl *TD =
79 Specialization->getTemplateName().getAsTemplateDecl();
80 if (!TD || TD->getName() !=
"enable_if")
83 assert(!TD->getTemplateParameters()->empty() &&
84 "found template with no template parameters?");
85 const auto *FirstParam = dyn_cast<NonTypeTemplateParmDecl>(
86 TD->getTemplateParameters()->getParam(0));
87 if (!FirstParam || !FirstParam->getType()->isBooleanType())
90 const int NumArgs = SpecializationLoc.getNumArgs();
91 if (NumArgs != 1 && NumArgs != 2)
94 return SpecializationLoc;
99static std::optional<TemplateSpecializationTypeLoc>
101 if (
const auto SpecializationLoc =
102 TheType.getAs<TemplateSpecializationTypeLoc>()) {
103 const auto *Specialization =
104 dyn_cast<TemplateSpecializationType>(SpecializationLoc.getTypePtr());
108 const TemplateDecl *TD =
109 Specialization->getTemplateName().getAsTemplateDecl();
110 if (!TD || TD->getName() !=
"enable_if_t")
113 if (!Specialization->isTypeAlias())
116 assert(!TD->getTemplateParameters()->empty() &&
117 "found template with no template parameters?");
118 const auto *FirstParam = dyn_cast<NonTypeTemplateParmDecl>(
119 TD->getTemplateParameters()->getParam(0));
120 if (!FirstParam || !FirstParam->getType()->isBooleanType())
123 if (
const auto *AliasedType =
124 dyn_cast<DependentNameType>(Specialization->getAliasedType())) {
125 const ElaboratedTypeKeyword Keyword = AliasedType->getKeyword();
126 if (AliasedType->getIdentifier()->getName() !=
"type" ||
127 (Keyword != ElaboratedTypeKeyword::Typename &&
128 Keyword != ElaboratedTypeKeyword::None)) {
134 const int NumArgs = SpecializationLoc.getNumArgs();
135 if (NumArgs != 1 && NumArgs != 2)
138 return SpecializationLoc;
143static std::optional<TemplateSpecializationTypeLoc>
150static std::optional<EnableIfData>
152 if (
const auto Pointer = TheType.getAs<PointerTypeLoc>())
153 TheType = Pointer.getPointeeLoc();
154 else if (
const auto Reference = TheType.getAs<ReferenceTypeLoc>())
155 TheType = Reference.getPointeeLoc();
156 if (
const auto Qualified = TheType.getAs<QualifiedTypeLoc>())
157 TheType = Qualified.getUnqualifiedLoc();
164static std::pair<std::optional<EnableIfData>,
const Decl *>
174 const TemplateParameterList *TemplateParams =
175 FunctionTemplate->getTemplateParameters();
176 if (TemplateParams->empty())
179 const NamedDecl *LastParam =
180 TemplateParams->getParam(TemplateParams->size() - 1);
181 if (
const auto *LastTemplateParam =
182 dyn_cast<NonTypeTemplateParmDecl>(LastParam)) {
183 if (!LastTemplateParam->hasDefaultArgument() ||
184 !LastTemplateParam->getName().empty())
188 LastTemplateParam->getTypeSourceInfo()->getTypeLoc()),
191 if (
const auto *LastTemplateParam =
192 dyn_cast<TemplateTypeParmDecl>(LastParam)) {
193 if (LastTemplateParam->hasDefaultArgument() &&
194 LastTemplateParam->getIdentifier() ==
nullptr) {
211 return SM.getFileLoc(Element.getRAngleLoc());
216 const TemplateSpecializationTypeLoc &EnableIf) {
220 const LangOptions &LangOpts = Context.getLangOpts();
221 const SourceManager &SM = Context.getSourceManager();
222 if (EnableIf.getNumArgs() > 1) {
223 const TemplateArgumentLoc NextArg = EnableIf.getArgLoc(1);
224 return {EnableIf.getLAngleLoc().getLocWithOffset(1),
226 NextArg.getSourceRange().getBegin(), SM, LangOpts, tok::comma)};
229 return {EnableIf.getLAngleLoc().getLocWithOffset(1),
234 const TemplateSpecializationTypeLoc &EnableIf) {
235 const TemplateArgumentLoc Arg = EnableIf.getArgLoc(1);
236 const LangOptions &LangOpts = Context.getLangOpts();
237 const SourceManager &SM = Context.getSourceManager();
239 SM, LangOpts, tok::comma)
240 .getLocWithOffset(1),
247static std::optional<StringRef>
249 const TemplateSpecializationTypeLoc &EnableIf) {
250 if (EnableIf.getNumArgs() > 1) {
251 const LangOptions &LangOpts = Context.getLangOpts();
252 const SourceManager &SM = Context.getSourceManager();
253 bool Invalid =
false;
254 StringRef Text = Lexer::getSourceText(CharSourceRange::getCharRange(
256 SM, LangOpts, &Invalid)
267static std::optional<SourceLocation>
269 const SourceManager &SM = Context.getSourceManager();
270 const LangOptions &LangOpts = Context.getLangOpts();
272 if (
const auto *Constructor = dyn_cast<CXXConstructorDecl>(Function)) {
273 for (
const CXXCtorInitializer *Init : Constructor->inits())
274 if (Init->getSourceOrder() == 0)
276 SM, LangOpts, tok::colon);
277 if (!Constructor->inits().empty())
280 if (Function->isDeleted()) {
281 const SourceLocation FunctionEnd = Function->getSourceRange().getEnd();
283 tok::equal, tok::equal);
285 const Stmt *Body = Function->getBody();
289 return Body->getBeginLoc();
298 if (
const auto *Cast = dyn_cast<ImplicitCastExpr>(Expression))
299 Expression = Cast->getSubExprAsWritten();
300 if (isa<ParenExpr, DependentScopeDeclRefExpr>(Expression))
311 SourceRange ConditionRange,
312 ASTContext &Context) {
313 const SourceManager &SM = Context.getSourceManager();
314 const LangOptions &LangOpts = Context.getLangOpts();
316 SourceLocation PrevTokenLoc = ConditionRange.getEnd();
317 if (PrevTokenLoc.isInvalid())
320 const bool SkipComments =
false;
323 PrevTokenLoc, SM, LangOpts, SkipComments);
324 const bool EndsWithDoubleSlash =
325 PrevToken.is(tok::comment) &&
326 Lexer::getSourceText(CharSourceRange::getCharRange(
327 PrevTokenLoc, PrevTokenLoc.getLocWithOffset(2)),
328 SM, LangOpts) ==
"//";
330 bool Invalid =
false;
331 const llvm::StringRef ConditionText = Lexer::getSourceText(
332 CharSourceRange::getCharRange(ConditionRange), SM, LangOpts, &Invalid);
336 auto AddParens = [&](llvm::StringRef Text) -> std::string {
339 return "(" + Text.str() +
")";
342 if (EndsWithDoubleSlash)
343 return AddParens(ConditionText);
344 return AddParens(ConditionText.trim());
357 const TypeLoc &ReturnType,
359 ASTContext &Context) {
360 const TemplateArgumentLoc EnableCondition = EnableIf.
Loc.getArgLoc(0);
365 EnableCondition.getSourceExpression(), ConditionRange, Context);
369 std::optional<StringRef> TypeText =
getTypeText(Context, EnableIf.
Loc);
373 SmallVector<AssociatedConstraint, 3> ExistingConstraints;
374 Function->getAssociatedConstraints(ExistingConstraints);
375 if (!ExistingConstraints.empty()) {
381 std::optional<SourceLocation> ConstraintInsertionLoc =
383 if (!ConstraintInsertionLoc)
386 std::vector<FixItHint> FixIts;
387 FixIts.push_back(FixItHint::CreateReplacement(
388 CharSourceRange::getTokenRange(EnableIf.
Outer.getSourceRange()),
390 FixIts.push_back(FixItHint::CreateInsertion(
391 *ConstraintInsertionLoc,
"requires " + *ConditionText +
" "));
404static std::vector<FixItHint>
406 const FunctionDecl *Function,
407 const Decl *LastTemplateParam,
409 const SourceManager &SM = Context.getSourceManager();
410 const LangOptions &LangOpts = Context.getLangOpts();
412 const TemplateArgumentLoc EnableCondition = EnableIf.
Loc.getArgLoc(0);
417 EnableCondition.getSourceExpression(), ConditionRange, Context);
421 SmallVector<AssociatedConstraint, 3> ExistingConstraints;
422 Function->getAssociatedConstraints(ExistingConstraints);
423 if (!ExistingConstraints.empty()) {
429 SourceRange RemovalRange;
430 const TemplateParameterList *TemplateParams =
431 FunctionTemplate->getTemplateParameters();
432 if (!TemplateParams || TemplateParams->empty())
435 if (TemplateParams->size() == 1) {
437 SourceRange(TemplateParams->getTemplateLoc(),
442 LastTemplateParam->getSourceRange().getBegin(), SM,
443 LangOpts, tok::comma),
447 std::optional<SourceLocation> ConstraintInsertionLoc =
449 if (!ConstraintInsertionLoc)
452 std::vector<FixItHint> FixIts;
454 FixItHint::CreateRemoval(CharSourceRange::getCharRange(RemovalRange)));
455 FixIts.push_back(FixItHint::CreateInsertion(
456 *ConstraintInsertionLoc,
"requires " + *ConditionText +
" "));
461 const auto *FunctionTemplate =
462 Result.Nodes.getNodeAs<FunctionTemplateDecl>(
"functionTemplate");
463 const auto *Function = Result.Nodes.getNodeAs<FunctionDecl>(
"function");
464 const auto *ReturnType = Result.Nodes.getNodeAs<TypeLoc>(
"return");
465 if (!FunctionTemplate || !Function || !ReturnType)
488 diag(ReturnType->getBeginLoc(),
489 "use C++20 requires constraints instead of enable_if")
495 if (
auto [EnableIf, LastTemplateParam] =
497 EnableIf && LastTemplateParam) {
498 diag(LastTemplateParam->getSourceRange().getBegin(),
499 "use C++20 requires constraints instead of enable_if")
501 LastTemplateParam, *EnableIf,
void registerMatchers(ast_matchers::MatchFinder *Finder) override
void check(const ast_matchers::MatchFinder::MatchResult &Result) override
static std::pair< std::optional< EnableIfData >, const Decl * > matchTrailingTemplateParam(const FunctionTemplateDecl *FunctionTemplate)
static std::vector< FixItHint > handleReturnType(const FunctionDecl *Function, const TypeLoc &ReturnType, const EnableIfData &EnableIf, ASTContext &Context)
static SourceRange getTypeRange(ASTContext &Context, const TemplateSpecializationTypeLoc &EnableIf)
static std::optional< SourceLocation > findInsertionForConstraint(const FunctionDecl *Function, ASTContext &Context)
static std::optional< TemplateSpecializationTypeLoc > matchEnableIfSpecializationImplTypename(TypeLoc TheType)
static std::optional< std::string > getConditionText(const Expr *ConditionExpr, SourceRange ConditionRange, ASTContext &Context)
static std::vector< FixItHint > handleTrailingTemplateType(const FunctionTemplateDecl *FunctionTemplate, const FunctionDecl *Function, const Decl *LastTemplateParam, const EnableIfData &EnableIf, ASTContext &Context)
static std::optional< TemplateSpecializationTypeLoc > matchEnableIfSpecializationImplTrait(TypeLoc TheType)
static std::optional< StringRef > getTypeText(ASTContext &Context, const TemplateSpecializationTypeLoc &EnableIf)
static SourceRange getConditionRange(ASTContext &Context, const TemplateSpecializationTypeLoc &EnableIf)
static SourceLocation getRAngleFileLoc(const SourceManager &SM, const T &Element)
static std::optional< TemplateSpecializationTypeLoc > matchEnableIfSpecializationImpl(TypeLoc TheType)
static bool isPrimaryExpression(const Expr *Expression)
static std::optional< EnableIfData > matchEnableIfSpecialization(TypeLoc TheType)
std::pair< Token, SourceLocation > getPreviousTokenAndStart(SourceLocation Location, const SourceManager &SM, const LangOptions &LangOpts, bool SkipComments)
SourceLocation findNextAnyTokenKind(SourceLocation Start, const SourceManager &SM, const LangOptions &LangOpts, TokenKind TK, TokenKinds... TKs)
SourceLocation findPreviousTokenKind(SourceLocation Start, const SourceManager &SM, const LangOptions &LangOpts, tok::TokenKind TK)
TemplateSpecializationTypeLoc Loc