10#include "clang/AST/ASTContext.h"
11#include "clang/AST/DeclTemplate.h"
12#include "clang/ASTMatchers/ASTMatchFinder.h"
13#include "clang/Lex/Lexer.h"
25 TemplateSpecializationTypeLoc
Loc;
30AST_MATCHER(FunctionDecl, hasOtherDeclarations) {
31 auto It = Node.redecls_begin();
32 auto EndIt = Node.redecls_end();
46 unless(isExpansionInSystemHeader()),
47 has(functionDecl(unless(hasOtherDeclarations()), isDefinition(),
48 hasReturnTypeLoc(typeLoc().bind(
"return")))
50 .bind(
"functionTemplate"),
54static std::optional<TemplateSpecializationTypeLoc>
56 if (
const auto Dep = TheType.getAs<DependentNameTypeLoc>()) {
57 const IdentifierInfo *Identifier = Dep.getTypePtr()->getIdentifier();
58 const ElaboratedTypeKeyword Keyword = Dep.getTypePtr()->getKeyword();
59 if (!Identifier || Identifier->getName() !=
"type" ||
60 (Keyword != ElaboratedTypeKeyword::Typename &&
61 Keyword != ElaboratedTypeKeyword::None)) {
64 TheType = Dep.getQualifierLoc().getAsTypeLoc();
71 if (
const auto SpecializationLoc =
72 TheType.getAs<TemplateSpecializationTypeLoc>()) {
73 const auto *Specialization =
74 dyn_cast<TemplateSpecializationType>(SpecializationLoc.getTypePtr());
78 const TemplateDecl *TD =
79 Specialization->getTemplateName().getAsTemplateDecl();
80 if (!TD || TD->getName() !=
"enable_if")
83 assert(!TD->getTemplateParameters()->empty() &&
84 "found template with no template parameters?");
85 const auto *FirstParam = dyn_cast<NonTypeTemplateParmDecl>(
86 TD->getTemplateParameters()->getParam(0));
87 if (!FirstParam || !FirstParam->getType()->isBooleanType())
90 const int NumArgs = SpecializationLoc.getNumArgs();
91 if (NumArgs != 1 && NumArgs != 2)
94 return SpecializationLoc;
99static std::optional<TemplateSpecializationTypeLoc>
101 if (
const auto SpecializationLoc =
102 TheType.getAs<TemplateSpecializationTypeLoc>()) {
103 const auto *Specialization =
104 dyn_cast<TemplateSpecializationType>(SpecializationLoc.getTypePtr());
108 const TemplateDecl *TD =
109 Specialization->getTemplateName().getAsTemplateDecl();
110 if (!TD || TD->getName() !=
"enable_if_t")
113 if (!Specialization->isTypeAlias())
116 assert(!TD->getTemplateParameters()->empty() &&
117 "found template with no template parameters?");
118 const auto *FirstParam = dyn_cast<NonTypeTemplateParmDecl>(
119 TD->getTemplateParameters()->getParam(0));
120 if (!FirstParam || !FirstParam->getType()->isBooleanType())
123 if (
const auto *AliasedType =
124 dyn_cast<DependentNameType>(Specialization->getAliasedType())) {
125 const ElaboratedTypeKeyword Keyword = AliasedType->getKeyword();
126 if (AliasedType->getIdentifier()->getName() !=
"type" ||
127 (Keyword != ElaboratedTypeKeyword::Typename &&
128 Keyword != ElaboratedTypeKeyword::None)) {
134 const int NumArgs = SpecializationLoc.getNumArgs();
135 if (NumArgs != 1 && NumArgs != 2)
138 return SpecializationLoc;
143static std::optional<TemplateSpecializationTypeLoc>
150static std::optional<EnableIfData>
152 if (
const auto Pointer = TheType.getAs<PointerTypeLoc>())
153 TheType = Pointer.getPointeeLoc();
154 else if (
const auto Reference = TheType.getAs<ReferenceTypeLoc>())
155 TheType = Reference.getPointeeLoc();
156 if (
const auto Qualified = TheType.getAs<QualifiedTypeLoc>())
157 TheType = Qualified.getUnqualifiedLoc();
164static std::pair<std::optional<EnableIfData>,
const Decl *>
174 const TemplateParameterList *TemplateParams =
175 FunctionTemplate->getTemplateParameters();
176 if (TemplateParams->empty())
179 const NamedDecl *LastParam =
180 TemplateParams->getParam(TemplateParams->size() - 1);
181 if (
const auto *LastTemplateParam =
182 dyn_cast<NonTypeTemplateParmDecl>(LastParam)) {
183 if (!LastTemplateParam->hasDefaultArgument() ||
184 !LastTemplateParam->getName().empty())
188 LastTemplateParam->getTypeSourceInfo()->getTypeLoc()),
191 if (
const auto *LastTemplateParam =
192 dyn_cast<TemplateTypeParmDecl>(LastParam)) {
193 if (LastTemplateParam->hasDefaultArgument() &&
194 LastTemplateParam->getIdentifier() ==
nullptr) {
211 return SM.getFileLoc(Element.getRAngleLoc());
216 const TemplateSpecializationTypeLoc &EnableIf) {
220 const LangOptions &LangOpts = Context.getLangOpts();
221 const SourceManager &SM = Context.getSourceManager();
222 if (EnableIf.getNumArgs() > 1) {
223 const TemplateArgumentLoc NextArg = EnableIf.getArgLoc(1);
224 return {EnableIf.getLAngleLoc().getLocWithOffset(1),
226 NextArg.getSourceRange().getBegin(), SM, LangOpts, tok::comma)};
229 return {EnableIf.getLAngleLoc().getLocWithOffset(1),
234 const TemplateSpecializationTypeLoc &EnableIf) {
235 const TemplateArgumentLoc Arg = EnableIf.getArgLoc(1);
236 const LangOptions &LangOpts = Context.getLangOpts();
237 const SourceManager &SM = Context.getSourceManager();
239 SM, LangOpts, tok::comma)
240 .getLocWithOffset(1),
247static std::optional<StringRef>
249 const TemplateSpecializationTypeLoc &EnableIf) {
250 if (EnableIf.getNumArgs() > 1) {
251 const LangOptions &LangOpts = Context.getLangOpts();
252 const SourceManager &SM = Context.getSourceManager();
253 bool Invalid =
false;
254 StringRef Text = Lexer::getSourceText(CharSourceRange::getCharRange(
256 SM, LangOpts, &Invalid)
267static std::optional<SourceLocation>
269 const SourceManager &SM = Context.getSourceManager();
270 const LangOptions &LangOpts = Context.getLangOpts();
272 if (
const auto *Constructor = dyn_cast<CXXConstructorDecl>(Function)) {
273 for (
const CXXCtorInitializer *Init : Constructor->inits()) {
274 if (Init->getSourceOrder() == 0)
276 SM, LangOpts, tok::colon);
278 if (!Constructor->inits().empty())
281 if (Function->isDeleted()) {
282 const SourceLocation FunctionEnd = Function->getSourceRange().getEnd();
284 tok::equal, tok::equal);
286 const Stmt *Body = Function->getBody();
290 return Body->getBeginLoc();
299 if (
const auto *Cast = dyn_cast<ImplicitCastExpr>(Expression))
300 Expression = Cast->getSubExprAsWritten();
301 if (isa<ParenExpr, DependentScopeDeclRefExpr>(Expression))
312 SourceRange ConditionRange,
313 ASTContext &Context) {
314 const SourceManager &SM = Context.getSourceManager();
315 const LangOptions &LangOpts = Context.getLangOpts();
317 SourceLocation PrevTokenLoc = ConditionRange.getEnd();
318 if (PrevTokenLoc.isInvalid())
321 const bool SkipComments =
false;
324 PrevTokenLoc, SM, LangOpts, SkipComments);
325 const bool EndsWithDoubleSlash =
326 PrevToken.is(tok::comment) &&
327 Lexer::getSourceText(CharSourceRange::getCharRange(
328 PrevTokenLoc, PrevTokenLoc.getLocWithOffset(2)),
329 SM, LangOpts) ==
"//";
331 bool Invalid =
false;
332 const llvm::StringRef ConditionText = Lexer::getSourceText(
333 CharSourceRange::getCharRange(ConditionRange), SM, LangOpts, &Invalid);
337 auto AddParens = [&](llvm::StringRef Text) -> std::string {
340 return "(" + Text.str() +
")";
343 if (EndsWithDoubleSlash)
344 return AddParens(ConditionText);
345 return AddParens(ConditionText.trim());
358 const TypeLoc &ReturnType,
360 ASTContext &Context) {
361 const TemplateArgumentLoc EnableCondition = EnableIf.
Loc.getArgLoc(0);
366 EnableCondition.getSourceExpression(), ConditionRange, Context);
370 std::optional<StringRef> TypeText =
getTypeText(Context, EnableIf.
Loc);
374 SmallVector<AssociatedConstraint, 3> ExistingConstraints;
375 Function->getAssociatedConstraints(ExistingConstraints);
376 if (!ExistingConstraints.empty()) {
382 std::optional<SourceLocation> ConstraintInsertionLoc =
384 if (!ConstraintInsertionLoc)
387 std::vector<FixItHint> FixIts;
388 FixIts.push_back(FixItHint::CreateReplacement(
389 CharSourceRange::getTokenRange(EnableIf.
Outer.getSourceRange()),
391 FixIts.push_back(FixItHint::CreateInsertion(
392 *ConstraintInsertionLoc,
"requires " + *ConditionText +
" "));
405static std::vector<FixItHint>
407 const FunctionDecl *Function,
408 const Decl *LastTemplateParam,
410 const SourceManager &SM = Context.getSourceManager();
411 const LangOptions &LangOpts = Context.getLangOpts();
413 const TemplateArgumentLoc EnableCondition = EnableIf.
Loc.getArgLoc(0);
418 EnableCondition.getSourceExpression(), ConditionRange, Context);
422 SmallVector<AssociatedConstraint, 3> ExistingConstraints;
423 Function->getAssociatedConstraints(ExistingConstraints);
424 if (!ExistingConstraints.empty()) {
430 SourceRange RemovalRange;
431 const TemplateParameterList *TemplateParams =
432 FunctionTemplate->getTemplateParameters();
433 if (!TemplateParams || TemplateParams->empty())
436 if (TemplateParams->size() == 1) {
438 SourceRange(TemplateParams->getTemplateLoc(),
443 LastTemplateParam->getSourceRange().getBegin(), SM,
444 LangOpts, tok::comma),
448 std::optional<SourceLocation> ConstraintInsertionLoc =
450 if (!ConstraintInsertionLoc)
453 std::vector<FixItHint> FixIts;
455 FixItHint::CreateRemoval(CharSourceRange::getCharRange(RemovalRange)));
456 FixIts.push_back(FixItHint::CreateInsertion(
457 *ConstraintInsertionLoc,
"requires " + *ConditionText +
" "));
462 const auto *FunctionTemplate =
463 Result.Nodes.getNodeAs<FunctionTemplateDecl>(
"functionTemplate");
464 const auto *Function = Result.Nodes.getNodeAs<FunctionDecl>(
"function");
465 const auto *ReturnType = Result.Nodes.getNodeAs<TypeLoc>(
"return");
466 if (!FunctionTemplate || !Function || !ReturnType)
489 diag(ReturnType->getBeginLoc(),
490 "use C++20 requires constraints instead of enable_if")
496 if (
auto [EnableIf, LastTemplateParam] =
498 EnableIf && LastTemplateParam) {
499 diag(LastTemplateParam->getSourceRange().getBegin(),
500 "use C++20 requires constraints instead of enable_if")
502 LastTemplateParam, *EnableIf,
void registerMatchers(ast_matchers::MatchFinder *Finder) override
void check(const ast_matchers::MatchFinder::MatchResult &Result) override
static std::pair< std::optional< EnableIfData >, const Decl * > matchTrailingTemplateParam(const FunctionTemplateDecl *FunctionTemplate)
static std::vector< FixItHint > handleReturnType(const FunctionDecl *Function, const TypeLoc &ReturnType, const EnableIfData &EnableIf, ASTContext &Context)
static SourceRange getTypeRange(ASTContext &Context, const TemplateSpecializationTypeLoc &EnableIf)
static std::optional< SourceLocation > findInsertionForConstraint(const FunctionDecl *Function, ASTContext &Context)
static std::optional< TemplateSpecializationTypeLoc > matchEnableIfSpecializationImplTypename(TypeLoc TheType)
static std::optional< std::string > getConditionText(const Expr *ConditionExpr, SourceRange ConditionRange, ASTContext &Context)
static std::vector< FixItHint > handleTrailingTemplateType(const FunctionTemplateDecl *FunctionTemplate, const FunctionDecl *Function, const Decl *LastTemplateParam, const EnableIfData &EnableIf, ASTContext &Context)
static std::optional< TemplateSpecializationTypeLoc > matchEnableIfSpecializationImplTrait(TypeLoc TheType)
static std::optional< StringRef > getTypeText(ASTContext &Context, const TemplateSpecializationTypeLoc &EnableIf)
static SourceRange getConditionRange(ASTContext &Context, const TemplateSpecializationTypeLoc &EnableIf)
static SourceLocation getRAngleFileLoc(const SourceManager &SM, const T &Element)
static std::optional< TemplateSpecializationTypeLoc > matchEnableIfSpecializationImpl(TypeLoc TheType)
static bool isPrimaryExpression(const Expr *Expression)
static std::optional< EnableIfData > matchEnableIfSpecialization(TypeLoc TheType)
std::pair< Token, SourceLocation > getPreviousTokenAndStart(SourceLocation Location, const SourceManager &SM, const LangOptions &LangOpts, bool SkipComments)
SourceLocation findNextAnyTokenKind(SourceLocation Start, const SourceManager &SM, const LangOptions &LangOpts, TokenKind TK, TokenKinds... TKs)
SourceLocation findPreviousTokenKind(SourceLocation Start, const SourceManager &SM, const LangOptions &LangOpts, tok::TokenKind TK)
TemplateSpecializationTypeLoc Loc