10#include "clang/AST/ASTContext.h"
11#include "clang/AST/DeclTemplate.h"
12#include "clang/ASTMatchers/ASTMatchFinder.h"
13#include "clang/Lex/Lexer.h"
25 TemplateSpecializationTypeLoc
Loc;
// AST matcher on FunctionDecl. The visible part obtains the begin/end
// iterators of the node's redeclaration chain; the rest of the body is not
// part of this listing.
// NOTE(review): presumably reports whether the chain holds more than one
// declaration -- confirm against the full source.
30AST_MATCHER(FunctionDecl, hasOtherDeclarations) {
31 auto It = Node.redecls_begin();
32 auto EndIt = Node.redecls_end();
// Fragment of the matcher expression. Visible constraints: the match must not
// be in an expansion inside a system header, and it must contain a function
// declaration that is a definition with no other declarations.
46 unless(isExpansionInSystemHeader()),
47 has(functionDecl(unless(hasOtherDeclarations()), isDefinition(),
// The function's return TypeLoc is bound under the id "return".
48 hasReturnTypeLoc(typeLoc().bind(
"return")))
// The node this .bind() applies to is not visible here; the id is
// "functionTemplate".
50 .bind(
"functionTemplate"),
// Fragment of a helper returning an optional TemplateSpecializationTypeLoc.
// The signature line and the bodies of the rejection branches are missing
// from this listing.
// NOTE(review): presumably this is the `typename enable_if<...>::type`
// matcher -- confirm the function name against the full source.
54static std::optional<TemplateSpecializationTypeLoc>
// If the type is a dependent name, inspect its identifier and keyword.
56 if (
const auto Dep = TheType.getAs<DependentNameTypeLoc>()) {
57 const IdentifierInfo *Identifier = Dep.getTypePtr()->getIdentifier();
58 ElaboratedTypeKeyword Keyword = Dep.getTypePtr()->getKeyword();
// Condition is true when the identifier is missing or not "type", or when the
// keyword is neither `typename` nor absent.
59 if (!Identifier || Identifier->getName() !=
"type" ||
60 (Keyword != ElaboratedTypeKeyword::Typename &&
61 Keyword != ElaboratedTypeKeyword::None)) {
// Continue with the type named by the dependent name's qualifier.
64 TheType = Dep.getQualifierLoc().getAsTypeLoc();
// The type examined next must be a template specialization.
71 if (
const auto SpecializationLoc =
72 TheType.getAs<TemplateSpecializationTypeLoc>()) {
74 const auto *Specialization =
75 dyn_cast<TemplateSpecializationType>(SpecializationLoc.getTypePtr());
79 const TemplateDecl *TD =
80 Specialization->getTemplateName().getAsTemplateDecl();
// Tests for a missing template declaration or a name other than "enable_if".
81 if (!TD || TD->getName() !=
"enable_if")
84 assert(!TD->getTemplateParameters()->empty() &&
85 "found template with no template parameters?");
// The first template parameter is checked to be a non-type parameter of
// boolean type.
86 const auto *FirstParam = dyn_cast<NonTypeTemplateParmDecl>(
87 TD->getTemplateParameters()->getParam(0));
88 if (!FirstParam || !FirstParam->getType()->isBooleanType())
// Tests for an argument count other than one or two; the action taken is not
// visible here.
91 int NumArgs = SpecializationLoc.getNumArgs();
92 if (NumArgs != 1 && NumArgs != 2)
95 return SpecializationLoc;
// Fragment of a helper returning an optional TemplateSpecializationTypeLoc.
// The signature line and the bodies of the rejection branches are missing
// from this listing.
// NOTE(review): presumably this is the `enable_if_t<...>` alias matcher --
// confirm the function name against the full source.
100static std::optional<TemplateSpecializationTypeLoc>
102 if (
const auto SpecializationLoc =
103 TheType.getAs<TemplateSpecializationTypeLoc>()) {
105 const auto *Specialization =
106 dyn_cast<TemplateSpecializationType>(SpecializationLoc.getTypePtr());
110 const TemplateDecl *TD =
111 Specialization->getTemplateName().getAsTemplateDecl();
// Tests for a missing template declaration or a name other than
// "enable_if_t".
112 if (!TD || TD->getName() !=
"enable_if_t")
// Tests whether the specialization is not a type alias.
115 if (!Specialization->isTypeAlias())
118 assert(!TD->getTemplateParameters()->empty() &&
119 "found template with no template parameters?");
// The first template parameter is checked to be a non-type parameter of
// boolean type.
120 const auto *FirstParam = dyn_cast<NonTypeTemplateParmDecl>(
121 TD->getTemplateParameters()->getParam(0));
122 if (!FirstParam || !FirstParam->getType()->isBooleanType())
// When the aliased type is a dependent name, its identifier is compared with
// "type" and its keyword with `typename` / none.
// NOTE(review): getIdentifier() is dereferenced without a null check here,
// unlike the sibling fragment above -- confirm that is intentional.
125 if (
const auto *AliasedType =
126 dyn_cast<DependentNameType>(Specialization->getAliasedType())) {
127 ElaboratedTypeKeyword Keyword = AliasedType->getKeyword();
128 if (AliasedType->getIdentifier()->getName() !=
"type" ||
129 (Keyword != ElaboratedTypeKeyword::Typename &&
130 Keyword != ElaboratedTypeKeyword::None)) {
// Tests for an argument count other than one or two; the action taken is not
// visible here.
136 int NumArgs = SpecializationLoc.getNumArgs();
137 if (NumArgs != 1 && NumArgs != 2)
140 return SpecializationLoc;
145static std::optional<TemplateSpecializationTypeLoc>
// Fragment of a helper returning an optional EnableIfData. The visible part
// unwraps the incoming type before matching; the matching itself and the
// signature line are not part of this listing.
152static std::optional<EnableIfData>
// Replace a pointer or reference type with its pointee (one or the other,
// not both).
154 if (
const auto Pointer = TheType.getAs<PointerTypeLoc>())
155 TheType = Pointer.getPointeeLoc();
156 else if (
const auto Reference = TheType.getAs<ReferenceTypeLoc>())
157 TheType = Reference.getPointeeLoc();
// Then replace a qualified type with its unqualified form.
158 if (
const auto Qualified = TheType.getAs<QualifiedTypeLoc>())
159 TheType = Qualified.getUnqualifiedLoc();
// Fragment of a helper returning a pair of (optional EnableIfData, Decl
// pointer). It inspects the last template parameter of FunctionTemplate; the
// return statements are missing from this listing.
166static std::pair<std::optional<EnableIfData>,
const Decl *>
176 const TemplateParameterList *TemplateParams =
177 FunctionTemplate->getTemplateParameters();
// An empty parameter list is tested first, so the size() - 1 index below is
// presumably never reached with size() == 0 (the branch body is not shown).
178 if (TemplateParams->empty())
181 const NamedDecl *LastParam =
182 TemplateParams->getParam(TemplateParams->size() - 1);
// Case 1: the last parameter is a non-type template parameter.
183 if (
const auto *LastTemplateParam =
184 dyn_cast<NonTypeTemplateParmDecl>(LastParam)) {
// Condition is true when the parameter has no default argument or has a
// non-empty name.
186 if (!LastTemplateParam->hasDefaultArgument() ||
187 !LastTemplateParam->getName().empty())
// The TypeLoc of the parameter's declared type is passed on (callee not
// visible here).
191 LastTemplateParam->getTypeSourceInfo()->getTypeLoc()),
// Case 2: the last parameter is a template type parameter that has a default
// argument and no identifier.
194 if (
const auto *LastTemplateParam =
195 dyn_cast<TemplateTypeParmDecl>(LastParam)) {
196 if (LastTemplateParam->hasDefaultArgument() &&
197 LastTemplateParam->getIdentifier() ==
nullptr) {
214 return SM.getFileLoc(Element.getRAngleLoc());
// Fragment of a helper that builds a range over the first template argument
// of EnableIf. The first line of the signature and the call that computes the
// range end are not part of this listing.
219 const TemplateSpecializationTypeLoc &EnableIf) {
223 const LangOptions &LangOpts = Context.getLangOpts();
224 const SourceManager &SM = Context.getSourceManager();
// With more than one argument, the end of the range is derived from a
// comma-token lookup anchored at the start of the second argument.
225 if (EnableIf.getNumArgs() > 1) {
226 TemplateArgumentLoc NextArg = EnableIf.getArgLoc(1);
// In both branches the range begins one character past the location returned
// by getLAngleLoc().
227 return {EnableIf.getLAngleLoc().getLocWithOffset(1),
229 NextArg.getSourceRange().getBegin(), SM, LangOpts, tok::comma)};
232 return {EnableIf.getLAngleLoc().getLocWithOffset(1),
// Fragment of a helper that builds a range related to the second template
// argument of EnableIf. The first line of the signature and most of the
// return expression are not part of this listing.
237 const TemplateSpecializationTypeLoc &EnableIf) {
// NOTE(review): getArgLoc(1) is called unconditionally in the visible code --
// presumably callers guarantee at least two arguments; confirm.
238 TemplateArgumentLoc Arg = EnableIf.getArgLoc(1);
239 const LangOptions &LangOpts = Context.getLangOpts();
240 const SourceManager &SM = Context.getSourceManager();
// A comma-token lookup result is advanced by one character.
242 SM, LangOpts, tok::comma)
243 .getLocWithOffset(1),
// Fragment of a helper returning an optional StringRef for EnableIf. Source
// text is only fetched when the specialization has more than one argument;
// what is returned otherwise is not visible in this listing.
250static std::optional<StringRef>
252 const TemplateSpecializationTypeLoc &EnableIf) {
253 if (EnableIf.getNumArgs() > 1) {
254 const LangOptions &LangOpts = Context.getLangOpts();
255 const SourceManager &SM = Context.getSourceManager();
// Invalid is handed to Lexer::getSourceText as its out-parameter.
256 bool Invalid =
false;
257 StringRef Text = Lexer::getSourceText(CharSourceRange::getCharRange(
259 SM, LangOpts, &Invalid)
// Fragment of a helper returning an optional SourceLocation for Function.
// Three cases are visible: constructors, deleted functions, and functions
// with a body. The signature line and several statements are missing.
270static std::optional<SourceLocation>
272 SourceManager &SM = Context.getSourceManager();
273 const LangOptions &LangOpts = Context.getLangOpts();
// Constructors: look for the initializer whose source order is 0 and perform
// a tok::colon lookup (the call and its anchor location are not shown).
275 if (
const auto *Constructor = dyn_cast<CXXConstructorDecl>(Function)) {
276 for (
const CXXCtorInitializer *Init : Constructor->inits()) {
277 if (Init->getSourceOrder() == 0)
279 SM, LangOpts, tok::colon);
// Reached after the loop when initializers exist; the branch body is not
// visible here.
281 if (!Constructor->inits().empty())
// Deleted functions: a token lookup involving tok::equal is made, with the
// end of the function's source range computed just before it.
284 if (Function->isDeleted()) {
285 SourceLocation FunctionEnd = Function->getSourceRange().getEnd();
287 tok::equal, tok::equal);
// Otherwise the start of the function body is returned.
// NOTE(review): a null check on Body presumably sits in the omitted lines --
// confirm.
289 const Stmt *Body = Function->getBody();
293 return Body->getBeginLoc();
// Fragment: an implicit cast is replaced by its sub-expression as written,
// then the expression is tested for being a ParenExpr or a
// DependentScopeDeclRefExpr. The result of that test and any further checks
// are not part of this listing.
302 if (
const auto *Cast = dyn_cast<ImplicitCastExpr>(Expression))
303 Expression = Cast->getSubExprAsWritten();
304 if (isa<ParenExpr, DependentScopeDeclRefExpr>(Expression))
// Fragment of a helper that extracts the source text of ConditionRange and
// returns it through the AddParens lambda. The first signature line and
// several statements are missing from this listing.
315 SourceRange ConditionRange,
316 ASTContext &Context) {
317 SourceManager &SM = Context.getSourceManager();
318 const LangOptions &LangOpts = Context.getLangOpts();
// The end of the condition range is the anchor for the previous-token lookup.
320 SourceLocation PrevTokenLoc = ConditionRange.getEnd();
321 if (PrevTokenLoc.isInvalid())
// SkipComments is false, so the token found may itself be a comment.
324 const bool SkipComments =
false;
327 PrevTokenLoc, SM, LangOpts, SkipComments);
// True when the previous token is a comment whose first two characters are
// "//".
328 bool EndsWithDoubleSlash =
329 PrevToken.is(tok::comment) &&
330 Lexer::getSourceText(CharSourceRange::getCharRange(
331 PrevTokenLoc, PrevTokenLoc.getLocWithOffset(2)),
332 SM, LangOpts) ==
"//";
334 bool Invalid =
false;
335 llvm::StringRef ConditionText = Lexer::getSourceText(
336 CharSourceRange::getCharRange(ConditionRange), SM, LangOpts, &Invalid);
// AddParens wraps Text in parentheses on the visible return path; the lines
// before that return are omitted.
// NOTE(review): presumably those lines skip the parentheses for primary
// expressions -- confirm.
340 auto AddParens = [&](llvm::StringRef Text) -> std::string {
343 return "(" + Text.str() +
")";
// The text is left untrimmed when it ends in a `//` comment and trimmed
// otherwise.
// NOTE(review): presumably this keeps the trailing newline so the closing
// parenthesis is not absorbed into the comment -- confirm.
346 if (EndsWithDoubleSlash)
347 return AddParens(ConditionText);
348 return AddParens(ConditionText.trim());
// Fragment of the fix-it builder for an enable_if found in a return type.
// Visible steps: read the condition (argument 0), read the type text, check
// for existing constraints, find the insertion point, then emit one
// replacement and one insertion. Early-exit bodies are missing.
361 const TypeLoc &ReturnType,
363 ASTContext &Context) {
// Argument 0 of the matched specialization is the enable_if condition.
364 TemplateArgumentLoc EnableCondition = EnableIf.
Loc.getArgLoc(0);
369 EnableCondition.getSourceExpression(), ConditionRange, Context);
373 std::optional<StringRef> TypeText =
getTypeText(Context, EnableIf.
Loc);
// Collect any constraints already attached to the function; what happens
// when some exist is not visible here.
377 SmallVector<AssociatedConstraint, 3> ExistingConstraints;
378 Function->getAssociatedConstraints(ExistingConstraints);
379 if (!ExistingConstraints.empty()) {
385 std::optional<SourceLocation> ConstraintInsertionLoc =
387 if (!ConstraintInsertionLoc)
390 std::vector<FixItHint> FixIts;
// Fix-it 1: replace the token range of EnableIf.Outer (replacement text not
// shown).
391 FixIts.push_back(FixItHint::CreateReplacement(
392 CharSourceRange::getTokenRange(EnableIf.
Outer.getSourceRange()),
// Fix-it 2: insert "requires <condition> " at the computed location.
394 FixIts.push_back(FixItHint::CreateInsertion(
395 *ConstraintInsertionLoc,
"requires " + *ConditionText +
" "));
// Fragment of the fix-it builder for an enable_if found in the last template
// parameter. Visible steps: read the condition (argument 0), check for
// existing constraints, compute a removal range, find the insertion point,
// then emit one removal and one insertion. Early-exit bodies are missing.
408static std::vector<FixItHint>
410 const FunctionDecl *Function,
411 const Decl *LastTemplateParam,
413 SourceManager &SM = Context.getSourceManager();
414 const LangOptions &LangOpts = Context.getLangOpts();
// Argument 0 of the matched specialization is the enable_if condition.
416 TemplateArgumentLoc EnableCondition = EnableIf.
Loc.getArgLoc(0);
421 EnableCondition.getSourceExpression(), ConditionRange, Context);
// Collect any constraints already attached to the function; what happens
// when some exist is not visible here.
425 SmallVector<AssociatedConstraint, 3> ExistingConstraints;
426 Function->getAssociatedConstraints(ExistingConstraints);
427 if (!ExistingConstraints.empty()) {
433 SourceRange RemovalRange;
434 const TemplateParameterList *TemplateParams =
435 FunctionTemplate->getTemplateParameters();
436 if (!TemplateParams || TemplateParams->empty())
// With a single template parameter the range starts at the template keyword
// location (the end of the range is not shown).
// NOTE(review): presumably this removes the whole `template <...>` header --
// confirm.
439 if (TemplateParams->size() == 1) {
441 SourceRange(TemplateParams->getTemplateLoc(),
// Otherwise a comma-token lookup anchored at the start of the last parameter
// is used.
446 LastTemplateParam->getSourceRange().getBegin(), SM,
447 LangOpts, tok::comma),
451 std::optional<SourceLocation> ConstraintInsertionLoc =
453 if (!ConstraintInsertionLoc)
456 std::vector<FixItHint> FixIts;
// Fix-it 1: remove RemovalRange, interpreted as a character range.
458 FixItHint::CreateRemoval(CharSourceRange::getCharRange(RemovalRange)));
// Fix-it 2: insert "requires <condition> " at the computed location.
459 FixIts.push_back(FixItHint::CreateInsertion(
460 *ConstraintInsertionLoc,
"requires " + *ConditionText +
" "));
// Fragment of the match callback. It fetches the three bound nodes and emits
// the same diagnostic text at one of two locations: the start of the return
// type, or the start of the last template parameter. Signature and several
// statements are missing from this listing.
465 const auto *FunctionTemplate =
466 Result.Nodes.getNodeAs<FunctionTemplateDecl>(
"functionTemplate");
467 const auto *Function = Result.Nodes.getNodeAs<FunctionDecl>(
"function");
468 const auto *ReturnType = Result.Nodes.getNodeAs<TypeLoc>(
"return");
// All three nodes are tested for null; the branch body is not visible here.
469 if (!FunctionTemplate || !Function || !ReturnType)
// Diagnostic anchored at the beginning of the return type.
492 diag(ReturnType->getBeginLoc(),
493 "use C++20 requires constraints instead of enable_if")
// Second path: entered only when both the enable_if data and the last
// template parameter are present.
499 if (
auto [EnableIf, LastTemplateParam] =
501 EnableIf && LastTemplateParam) {
// Diagnostic anchored at the beginning of the last template parameter.
502 diag(LastTemplateParam->getSourceRange().getBegin(),
503 "use C++20 requires constraints instead of enable_if")
505 LastTemplateParam, *EnableIf,
void registerMatchers(ast_matchers::MatchFinder *Finder) override
void check(const ast_matchers::MatchFinder::MatchResult &Result) override
static std::pair< std::optional< EnableIfData >, const Decl * > matchTrailingTemplateParam(const FunctionTemplateDecl *FunctionTemplate)
static std::vector< FixItHint > handleReturnType(const FunctionDecl *Function, const TypeLoc &ReturnType, const EnableIfData &EnableIf, ASTContext &Context)
static SourceRange getTypeRange(ASTContext &Context, const TemplateSpecializationTypeLoc &EnableIf)
static std::optional< SourceLocation > findInsertionForConstraint(const FunctionDecl *Function, ASTContext &Context)
static std::optional< TemplateSpecializationTypeLoc > matchEnableIfSpecializationImplTypename(TypeLoc TheType)
static std::optional< std::string > getConditionText(const Expr *ConditionExpr, SourceRange ConditionRange, ASTContext &Context)
static std::vector< FixItHint > handleTrailingTemplateType(const FunctionTemplateDecl *FunctionTemplate, const FunctionDecl *Function, const Decl *LastTemplateParam, const EnableIfData &EnableIf, ASTContext &Context)
static std::optional< TemplateSpecializationTypeLoc > matchEnableIfSpecializationImplTrait(TypeLoc TheType)
static std::optional< StringRef > getTypeText(ASTContext &Context, const TemplateSpecializationTypeLoc &EnableIf)
static SourceRange getConditionRange(ASTContext &Context, const TemplateSpecializationTypeLoc &EnableIf)
static SourceLocation getRAngleFileLoc(const SourceManager &SM, const T &Element)
static std::optional< TemplateSpecializationTypeLoc > matchEnableIfSpecializationImpl(TypeLoc TheType)
static bool isPrimaryExpression(const Expr *Expression)
static std::optional< EnableIfData > matchEnableIfSpecialization(TypeLoc TheType)
std::pair< Token, SourceLocation > getPreviousTokenAndStart(SourceLocation Location, const SourceManager &SM, const LangOptions &LangOpts, bool SkipComments)
SourceLocation findNextAnyTokenKind(SourceLocation Start, const SourceManager &SM, const LangOptions &LangOpts, TokenKind TK, TokenKinds... TKs)
SourceLocation findPreviousTokenKind(SourceLocation Start, const SourceManager &SM, const LangOptions &LangOpts, tok::TokenKind TK)
TemplateSpecializationTypeLoc Loc