clang-tools 22.0.0git
UseNewMLIROpBuilderCheck.cpp
Go to the documentation of this file.
1//===--- UseNewMLIROpBuilderCheck.cpp - clang-tidy ------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
10#include "clang/ASTMatchers/ASTMatchers.h"
11#include "clang/Basic/LLVM.h"
12#include "clang/Lex/Lexer.h"
13#include "clang/Tooling/Transformer/RangeSelector.h"
14#include "clang/Tooling/Transformer/RewriteRule.h"
15#include "clang/Tooling/Transformer/SourceCode.h"
16#include "clang/Tooling/Transformer/Stencil.h"
17#include "llvm/Support/Error.h"
18#include "llvm/Support/FormatVariadic.h"
19
21namespace {
22
23using namespace ::clang::ast_matchers;
24using namespace ::clang::transformer;
25
26EditGenerator rewrite(RangeSelector Call, RangeSelector Builder,
27 RangeSelector CallArgs) {
28 // This is using an EditGenerator rather than ASTEdit as we want to warn even
29 // if in macro.
30 return [Call = std::move(Call), Builder = std::move(Builder),
31 CallArgs =
32 std::move(CallArgs)](const MatchFinder::MatchResult &Result)
33 -> Expected<SmallVector<transformer::Edit, 1>> {
34 Expected<CharSourceRange> CallRange = Call(Result);
35 if (!CallRange)
36 return CallRange.takeError();
37 SourceManager &SM = *Result.SourceManager;
38 const LangOptions &LangOpts = Result.Context->getLangOpts();
39 SourceLocation Begin = CallRange->getBegin();
40
41 // This will result in just a warning and no edit.
42 bool InMacro = CallRange->getBegin().isMacroID();
43 if (InMacro) {
44 while (SM.isMacroArgExpansion(Begin))
45 Begin = SM.getImmediateExpansionRange(Begin).getBegin();
46 Edit WarnOnly;
47 WarnOnly.Kind = EditKind::Range;
48 WarnOnly.Range = CharSourceRange::getCharRange(Begin, Begin);
49 return SmallVector<Edit, 1>({WarnOnly});
50 }
51
52 // This will try to extract the template argument as written so that the
53 // rewritten code looks closest to original.
54 auto NextToken = [&](std::optional<Token> CurrentToken) {
55 if (!CurrentToken)
56 return CurrentToken;
57 if (CurrentToken->getEndLoc() >= CallRange->getEnd())
58 return std::optional<Token>();
59 return clang::Lexer::findNextToken(CurrentToken->getLocation(), SM,
60 LangOpts);
61 };
62 std::optional<Token> LessToken =
63 clang::Lexer::findNextToken(Begin, SM, LangOpts);
64 while (LessToken && LessToken->getKind() != clang::tok::less) {
65 LessToken = NextToken(LessToken);
66 }
67 if (!LessToken) {
68 return llvm::make_error<llvm::StringError>(llvm::errc::invalid_argument,
69 "missing '<' token");
70 }
71 std::optional<Token> EndToken = NextToken(LessToken);
72 for (std::optional<Token> GreaterToken = NextToken(EndToken);
73 GreaterToken && GreaterToken->getKind() != clang::tok::greater;
74 GreaterToken = NextToken(GreaterToken)) {
75 EndToken = GreaterToken;
76 }
77 if (!EndToken) {
78 return llvm::make_error<llvm::StringError>(llvm::errc::invalid_argument,
79 "missing '>' token");
80 }
81
82 Expected<CharSourceRange> BuilderRange = Builder(Result);
83 if (!BuilderRange)
84 return BuilderRange.takeError();
85 Expected<CharSourceRange> CallArgsRange = CallArgs(Result);
86 if (!CallArgsRange)
87 return CallArgsRange.takeError();
88
89 // Helper for concatting below.
90 auto GetText = [&](const CharSourceRange &Range) {
91 return clang::Lexer::getSourceText(Range, SM, LangOpts);
92 };
93
94 Edit Replace;
95 Replace.Kind = EditKind::Range;
96 Replace.Range = *CallRange;
97 std::string CallArgsStr;
98 // Only emit args if there are any.
99 if (auto CallArgsText = GetText(*CallArgsRange).ltrim();
100 !CallArgsText.rtrim().empty()) {
101 CallArgsStr = llvm::formatv(", {}", CallArgsText);
102 }
103 Replace.Replacement =
104 llvm::formatv("{}::create({}{})",
105 GetText(CharSourceRange::getTokenRange(
106 LessToken->getEndLoc(), EndToken->getLastLoc())),
107 GetText(*BuilderRange), CallArgsStr);
108
109 return SmallVector<Edit, 1>({Replace});
110 };
111}
112
113RewriteRuleWith<std::string> useNewMlirOpBuilderCheckRule() {
114 Stencil message = cat("use 'OpType::create(builder, ...)' instead of "
115 "'builder.create<OpType>(...)'");
116 // Match a create call on an OpBuilder.
117 ast_matchers::internal::Matcher<Stmt> base =
118 cxxMemberCallExpr(
119 on(expr(hasType(
120 cxxRecordDecl(isSameOrDerivedFrom("::mlir::OpBuilder"))))
121 .bind("builder")),
122 callee(cxxMethodDecl(hasTemplateArgument(0, templateArgument()))),
123 callee(cxxMethodDecl(hasName("create"))))
124 .bind("call");
125 return applyFirst(
126 // Attempt rewrite given an lvalue builder, else just warn.
127 {makeRule(cxxMemberCallExpr(unless(on(cxxTemporaryObjectExpr())), base),
128 rewrite(node("call"), node("builder"), callArgs("call")),
129 message),
130 makeRule(base, noopEdit(node("call")), message)});
131}
132} // namespace
133
135 ClangTidyContext *Context)
136 : TransformerClangTidyCheck(useNewMlirOpBuilderCheckRule(), Name, Context) {
137}
138
139} // namespace clang::tidy::llvm_check
Every ClangTidyCheck reports errors through a DiagnosticsEngine provided by this context.
UseNewMlirOpBuilderCheck(StringRef Name, ClangTidyContext *Context)
TransformerClangTidyCheck(StringRef Name, ClangTidyContext *Context)
node(n, label, next_label)