#include "clang/AST/APValue.h"
#include "clang/AST/ASTContext.h"
#include "clang/AST/ASTTypeTraits.h"
#include "clang/AST/OperationKinds.h"
#include "clang/AST/ParentMapContext.h"
#include "clang/ASTMatchers/ASTMatchFinder.h"
#include <cmath>
24 MaxLoopIterations(Options.get(
"MaxLoopIterations", 100U)) {}
27 const auto HasLoopBound = hasDescendant(
28 varDecl(matchesName(
"__end*"),
29 hasDescendant(integerLiteral().bind(
"cxx_loop_bound"))));
30 const auto CXXForRangeLoop =
31 cxxForRangeStmt(anyOf(HasLoopBound, unless(HasLoopBound)));
32 const auto AnyLoop = anyOf(forStmt(), whileStmt(), doStmt(), CXXForRangeLoop);
34 stmt(AnyLoop, unless(hasDescendant(stmt(AnyLoop)))).bind(
"loop"),
this);
38 const auto *Loop = Result.Nodes.getNodeAs<Stmt>(
"loop");
39 const auto *CXXLoopBound =
40 Result.Nodes.getNodeAs<IntegerLiteral>(
"cxx_loop_bound");
41 const ASTContext *Context = Result.Context;
42 switch (unrollType(Loop, Result.Context)) {
44 diag(Loop->getBeginLoc(),
45 "kernel performance could be improved by unrolling this loop with a "
46 "'#pragma unroll' directive");
48 case PartiallyUnrolled:
52 if (hasKnownBounds(Loop, CXXLoopBound, Context)) {
53 if (hasLargeNumIterations(Loop, CXXLoopBound, Context)) {
54 diag(Loop->getBeginLoc(),
55 "loop likely has a large number of iterations and thus "
56 "cannot be fully unrolled; to partially unroll this loop, use "
57 "the '#pragma unroll <num>' directive");
62 if (isa<WhileStmt, DoStmt>(Loop)) {
63 diag(Loop->getBeginLoc(),
64 "full unrolling requested, but loop bounds may not be known; to "
65 "partially unroll this loop, use the '#pragma unroll <num>' "
70 diag(Loop->getBeginLoc(),
71 "full unrolling requested, but loop bounds are not known; to "
72 "partially unroll this loop, use the '#pragma unroll <num>' "
78enum UnrollLoopsCheck::UnrollType
79UnrollLoopsCheck::unrollType(
const Stmt *Statement, ASTContext *Context) {
80 const DynTypedNodeList Parents = Context->getParents<Stmt>(*Statement);
81 for (
const DynTypedNode &
Parent : Parents) {
82 const auto *ParentStmt =
Parent.get<AttributedStmt>();
85 for (
const Attr *Attribute : ParentStmt->getAttrs()) {
86 const auto *LoopHint = dyn_cast<LoopHintAttr>(Attribute);
89 switch (LoopHint->getState()) {
90 case LoopHintAttr::Numeric:
91 return PartiallyUnrolled;
92 case LoopHintAttr::Disable:
94 case LoopHintAttr::Full:
96 case LoopHintAttr::Enable:
98 case LoopHintAttr::AssumeSafety:
100 case LoopHintAttr::FixedWidth:
102 case LoopHintAttr::ScalableWidth:
110bool UnrollLoopsCheck::hasKnownBounds(
const Stmt *Statement,
111 const IntegerLiteral *CXXLoopBound,
112 const ASTContext *Context) {
113 if (isa<CXXForRangeStmt>(Statement))
114 return CXXLoopBound !=
nullptr;
117 if (isa<WhileStmt, DoStmt>(Statement))
120 const auto *ForLoop = cast<ForStmt>(Statement);
121 const Stmt *Initializer = ForLoop->getInit();
122 const Expr *Conditional = ForLoop->getCond();
123 const Expr *Increment = ForLoop->getInc();
124 if (!Initializer || !Conditional || !Increment)
127 if (
const auto *InitDeclStatement = dyn_cast<DeclStmt>(Initializer)) {
128 if (
const auto *VariableDecl =
129 dyn_cast<VarDecl>(InitDeclStatement->getSingleDecl())) {
130 APValue *Evaluation = VariableDecl->evaluateValue();
131 if (!Evaluation || !Evaluation->hasValue())
136 if (
const auto *Op = dyn_cast<UnaryOperator>(Increment))
137 if (!Op->isIncrementDecrementOp())
140 if (
const auto *BinaryOp = dyn_cast<BinaryOperator>(Conditional)) {
141 const Expr *LHS = BinaryOp->getLHS();
142 const Expr *RHS = BinaryOp->getRHS();
144 return LHS->isEvaluatable(*Context) != RHS->isEvaluatable(*Context);
149const Expr *UnrollLoopsCheck::getCondExpr(
const Stmt *Statement) {
150 if (
const auto *ForLoop = dyn_cast<ForStmt>(Statement))
151 return ForLoop->getCond();
152 if (
const auto *WhileLoop = dyn_cast<WhileStmt>(Statement))
153 return WhileLoop->getCond();
154 if (
const auto *DoWhileLoop = dyn_cast<DoStmt>(Statement))
155 return DoWhileLoop->getCond();
156 if (
const auto *CXXRangeLoop = dyn_cast<CXXForRangeStmt>(Statement))
157 return CXXRangeLoop->getCond();
158 llvm_unreachable(
"Unknown loop");
161bool UnrollLoopsCheck::hasLargeNumIterations(
const Stmt *Statement,
162 const IntegerLiteral *CXXLoopBound,
163 const ASTContext *Context) {
166 if (isa<CXXForRangeStmt>(Statement)) {
167 assert(CXXLoopBound &&
"CXX ranged for loop has no loop bound");
168 return exprHasLargeNumIterations(CXXLoopBound, Context);
170 const auto *ForLoop = cast<ForStmt>(Statement);
171 const Stmt *Initializer = ForLoop->getInit();
172 const Expr *Conditional = ForLoop->getCond();
173 const Expr *Increment = ForLoop->getInc();
176 if (
const auto *InitDeclStatement = dyn_cast<DeclStmt>(Initializer)) {
177 if (
const auto *VariableDecl =
178 dyn_cast<VarDecl>(InitDeclStatement->getSingleDecl())) {
179 APValue *Evaluation = VariableDecl->evaluateValue();
180 if (!Evaluation || !Evaluation->isInt())
182 InitValue = Evaluation->getInt().getExtValue();
187 const auto *BinaryOp = cast<BinaryOperator>(Conditional);
188 if (!extractValue(EndValue, BinaryOp, Context))
191 double Iterations = 0.0;
194 if (
const auto *Op = dyn_cast<UnaryOperator>(Increment)) {
195 if (Op->isIncrementOp())
196 Iterations = EndValue - InitValue;
197 else if (Op->isDecrementOp())
198 Iterations = InitValue - EndValue;
200 llvm_unreachable(
"Unary operator neither increment nor decrement");
205 if (
const auto *Op = dyn_cast<BinaryOperator>(Increment)) {
206 int ConstantValue = 0;
207 if (!extractValue(ConstantValue, Op, Context))
209 switch (Op->getOpcode()) {
211 Iterations = ceil(
float(EndValue - InitValue) / ConstantValue);
214 Iterations = ceil(
float(InitValue - EndValue) / ConstantValue);
217 Iterations = 1 + (
log((
double)EndValue) -
log((
double)InitValue)) /
218 log((
double)ConstantValue);
221 Iterations = 1 + (
log((
double)InitValue) -
log((
double)EndValue)) /
222 log((
double)ConstantValue);
229 return Iterations > MaxLoopIterations;
232bool UnrollLoopsCheck::extractValue(
int &Value,
const BinaryOperator *Op,
233 const ASTContext *Context) {
234 const Expr *LHS = Op->getLHS();
235 const Expr *RHS = Op->getRHS();
236 Expr::EvalResult Result;
237 if (LHS->isEvaluatable(*Context))
238 LHS->EvaluateAsRValue(Result, *Context);
239 else if (RHS->isEvaluatable(*Context))
240 RHS->EvaluateAsRValue(Result, *Context);
243 if (!Result.Val.isInt())
246 Value = Result.Val.getInt().getExtValue();
250bool UnrollLoopsCheck::exprHasLargeNumIterations(
const Expr *Expression,
251 const ASTContext *Context)
const {
252 Expr::EvalResult Result;
253 if (Expression->EvaluateAsRValue(Result, *Context)) {
254 if (!Result.Val.isInt())
258 return Result.Val.getInt() > MaxLoopIterations;
266 Options.
store(Opts,
"MaxLoopIterations", MaxLoopIterations);
// Doxygen cross-reference tooltips captured with this listing (not source
// code); preserved here as comments:
//   llvm::SmallString< 256U > Name
//   void store(ClangTidyOptions::OptionMap &Options, StringRef LocalName, StringRef Value) const
//     Stores an option with the check-local name LocalName with string value Value to Options.
//   Base class for all clang-tidy checks.
//   DiagnosticBuilder diag(SourceLocation Loc, StringRef Description, DiagnosticIDs::Level Level=DiagnosticIDs::Warning)
//     Add a diagnostic with the check's name.
//   Every ClangTidyCheck reports errors through a DiagnosticsEngine provided by this context.
//   void check(const ast_matchers::MatchFinder::MatchResult &Result) override
//     ClangTidyChecks that register ASTMatchers should do the actual work in here.
//   UnrollLoopsCheck(StringRef Name, ClangTidyContext *Context)
//   void registerMatchers(ast_matchers::MatchFinder *Finder) override
//     Override this to register AST matchers with Finder.
//   void log(Logger::Level L, const char *Fmt, Ts &&... Vals)
//   llvm::StringMap< ClangTidyValue > OptionMap