clang-tools 22.0.0git
UseNewMLIROpBuilderCheck.cpp
Go to the documentation of this file.
1//===----------------------------------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
10#include "clang/ASTMatchers/ASTMatchers.h"
11#include "clang/Basic/LLVM.h"
12#include "clang/Lex/Lexer.h"
13#include "clang/Tooling/Transformer/RangeSelector.h"
14#include "clang/Tooling/Transformer/RewriteRule.h"
15#include "clang/Tooling/Transformer/SourceCode.h"
16#include "clang/Tooling/Transformer/Stencil.h"
17#include "llvm/Support/Error.h"
18#include "llvm/Support/FormatVariadic.h"
19
21
22using namespace ::clang::ast_matchers;
23using namespace ::clang::transformer;
24
25static EditGenerator rewrite(RangeSelector Call, RangeSelector Builder) {
26 // This is using an EditGenerator rather than ASTEdit as we want to warn even
27 // if in macro.
28 return [Call = std::move(Call),
29 Builder = std::move(Builder)](const MatchFinder::MatchResult &Result)
30 -> Expected<SmallVector<transformer::Edit, 1>> {
31 Expected<CharSourceRange> CallRange = Call(Result);
32 if (!CallRange)
33 return CallRange.takeError();
34 SourceManager &SM = *Result.SourceManager;
35 const LangOptions &LangOpts = Result.Context->getLangOpts();
36 SourceLocation Begin = CallRange->getBegin();
37
38 // This will result in just a warning and no edit.
39 const bool InMacro = CallRange->getBegin().isMacroID();
40 if (InMacro) {
41 while (SM.isMacroArgExpansion(Begin))
42 Begin = SM.getImmediateExpansionRange(Begin).getBegin();
43 Edit WarnOnly;
44 WarnOnly.Kind = EditKind::Range;
45 WarnOnly.Range = CharSourceRange::getCharRange(Begin, Begin);
46 return SmallVector<Edit, 1>({WarnOnly});
47 }
48
49 // This will try to extract the template argument as written so that the
50 // rewritten code looks closest to original.
51 auto NextToken = [&](std::optional<Token> CurrentToken) {
52 if (!CurrentToken)
53 return CurrentToken;
54 if (CurrentToken->is(clang::tok::eof))
55 return std::optional<Token>();
56 return clang::Lexer::findNextToken(CurrentToken->getLocation(), SM,
57 LangOpts);
58 };
59 std::optional<Token> LessToken =
60 clang::Lexer::findNextToken(Begin, SM, LangOpts);
61 while (LessToken && LessToken->getKind() != clang::tok::less) {
62 LessToken = NextToken(LessToken);
63 }
64 if (!LessToken) {
65 return llvm::make_error<llvm::StringError>(llvm::errc::invalid_argument,
66 "missing '<' token");
67 }
68
69 std::optional<Token> EndToken = NextToken(LessToken);
70 std::optional<Token> GreaterToken = NextToken(EndToken);
71 for (; GreaterToken && GreaterToken->getKind() != clang::tok::greater;
72 GreaterToken = NextToken(GreaterToken)) {
73 EndToken = GreaterToken;
74 }
75 if (!EndToken) {
76 return llvm::make_error<llvm::StringError>(llvm::errc::invalid_argument,
77 "missing '>' token");
78 }
79
80 std::optional<Token> ArgStart = NextToken(GreaterToken);
81 if (!ArgStart || ArgStart->getKind() != clang::tok::l_paren) {
82 return llvm::make_error<llvm::StringError>(llvm::errc::invalid_argument,
83 "missing '(' token");
84 }
85 std::optional<Token> Arg = NextToken(ArgStart);
86 if (!Arg) {
87 return llvm::make_error<llvm::StringError>(llvm::errc::invalid_argument,
88 "unexpected end of file");
89 }
90 const bool HasArgs = Arg->getKind() != clang::tok::r_paren;
91
92 Expected<CharSourceRange> BuilderRange = Builder(Result);
93 if (!BuilderRange)
94 return BuilderRange.takeError();
95
96 // Helper for concatting below.
97 auto GetText = [&](const CharSourceRange &Range) {
98 return clang::Lexer::getSourceText(Range, SM, LangOpts);
99 };
100
101 Edit Replace;
102 Replace.Kind = EditKind::Range;
103 Replace.Range.setBegin(CallRange->getBegin());
104 Replace.Range.setEnd(ArgStart->getEndLoc());
105 const Expr *BuilderExpr = Result.Nodes.getNodeAs<Expr>("builder");
106 std::string BuilderText = GetText(*BuilderRange).str();
107 if (BuilderExpr->getType()->isPointerType()) {
108 BuilderText = BuilderExpr->isImplicitCXXThis()
109 ? "*this"
110 : llvm::formatv("*{}", BuilderText).str();
111 }
112 const StringRef OpType = GetText(CharSourceRange::getTokenRange(
113 LessToken->getEndLoc(), EndToken->getLastLoc()));
114 Replace.Replacement = llvm::formatv("{}::create({}{}", OpType, BuilderText,
115 HasArgs ? ", " : "");
116
117 return SmallVector<Edit, 1>({Replace});
118 };
119}
120
121static RewriteRuleWith<std::string> useNewMlirOpBuilderCheckRule() {
122 const Stencil Message = cat("use 'OpType::create(builder, ...)' instead of "
123 "'builder.create<OpType>(...)'");
124 // Match a create call on an OpBuilder.
125 auto BuilderType = cxxRecordDecl(isSameOrDerivedFrom("::mlir::OpBuilder"));
126 const ast_matchers::internal::Matcher<Stmt> Base =
127 cxxMemberCallExpr(
128 on(expr(anyOf(hasType(BuilderType), hasType(pointsTo(BuilderType))))
129 .bind("builder")),
130 callee(cxxMethodDecl(hasTemplateArgument(0, templateArgument()),
131 hasName("create"))))
132 .bind("call");
133 return applyFirst(
134 // Attempt rewrite given an lvalue builder, else just warn.
135 {makeRule(cxxMemberCallExpr(unless(on(cxxTemporaryObjectExpr())), Base),
136 rewrite(node("call"), node("builder")), Message),
137 makeRule(Base, noopEdit(node("call")), Message)});
138}
139
144
145} // namespace clang::tidy::llvm_check
Every ClangTidyCheck reports errors through a DiagnosticsEngine provided by this context.
UseNewMlirOpBuilderCheck(StringRef Name, ClangTidyContext *Context)
TransformerClangTidyCheck(StringRef Name, ClangTidyContext *Context)
static RewriteRuleWith< std::string > useNewMlirOpBuilderCheckRule()
static EditGenerator rewrite(RangeSelector Call, RangeSelector Builder)