clang 20.0.0git
LoopUnrolling.cpp
Go to the documentation of this file.
1//===--- LoopUnrolling.cpp - Unroll loops -----------------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8///
/// This file contains functions that decide whether a loop is worth
/// unrolling. These functions also manage the stack of loops that is
/// tracked by the ProgramState.
12///
13//===----------------------------------------------------------------------===//
14
#include <algorithm>
#include <limits>
#include <optional>
21
22using namespace clang;
23using namespace ento;
24using namespace clang::ast_matchers;
25
// Upper bound on the step count of a loop that is still unrolled completely.
// The count compared against it is the loop's own step estimate multiplied by
// the step count recorded for the enclosing loop on the loop stack.
static constexpr int MAXIMUM_STEP_UNROLLED = 128;
27
28namespace {
29struct LoopState {
30private:
31 enum Kind { Normal, Unrolled } K;
32 const Stmt *LoopStmt;
33 const LocationContext *LCtx;
34 unsigned maxStep;
35 LoopState(Kind InK, const Stmt *S, const LocationContext *L, unsigned N)
36 : K(InK), LoopStmt(S), LCtx(L), maxStep(N) {}
37
38public:
39 static LoopState getNormal(const Stmt *S, const LocationContext *L,
40 unsigned N) {
41 return LoopState(Normal, S, L, N);
42 }
43 static LoopState getUnrolled(const Stmt *S, const LocationContext *L,
44 unsigned N) {
45 return LoopState(Unrolled, S, L, N);
46 }
47 bool isUnrolled() const { return K == Unrolled; }
48 unsigned getMaxStep() const { return maxStep; }
49 const Stmt *getLoopStmt() const { return LoopStmt; }
50 const LocationContext *getLocationContext() const { return LCtx; }
51 bool operator==(const LoopState &X) const {
52 return K == X.K && LoopStmt == X.LoopStmt;
53 }
54 void Profile(llvm::FoldingSetNodeID &ID) const {
55 ID.AddInteger(K);
56 ID.AddPointer(LoopStmt);
57 ID.AddPointer(LCtx);
58 ID.AddInteger(maxStep);
59 }
60};
61} // namespace
62
// The tracked stack of loops. The stack records which loops enclose the
// element currently being simulated; each entry is marked according to
// whether we decided to unroll that loop.
66// TODO: The loop stack should not need to be in the program state since it is
67// lexical in nature. Instead, the stack of loops should be tracked in the
68// LocationContext.
69REGISTER_LIST_WITH_PROGRAMSTATE(LoopStack, LoopState)
70
71namespace clang {
72namespace ento {
73
74static bool isLoopStmt(const Stmt *S) {
75 return isa_and_nonnull<ForStmt, WhileStmt, DoStmt>(S);
76}
77
79 auto LS = State->get<LoopStack>();
80 if (!LS.isEmpty() && LS.getHead().getLoopStmt() == LoopStmt)
81 State = State->set<LoopStack>(LS.getTail());
82 return State;
83}
84
85static internal::Matcher<Stmt> simpleCondition(StringRef BindName,
86 StringRef RefName) {
87 return binaryOperator(
88 anyOf(hasOperatorName("<"), hasOperatorName(">"),
89 hasOperatorName("<="), hasOperatorName(">="),
90 hasOperatorName("!=")),
91 hasEitherOperand(ignoringParenImpCasts(
92 declRefExpr(to(varDecl(hasType(isInteger())).bind(BindName)))
93 .bind(RefName))),
94 hasEitherOperand(
95 ignoringParenImpCasts(integerLiteral().bind("boundNum"))))
96 .bind("conditionOperator");
97}
98
99static internal::Matcher<Stmt>
100changeIntBoundNode(internal::Matcher<Decl> VarNodeMatcher) {
101 return anyOf(
102 unaryOperator(anyOf(hasOperatorName("--"), hasOperatorName("++")),
103 hasUnaryOperand(ignoringParenImpCasts(
104 declRefExpr(to(varDecl(VarNodeMatcher)))))),
105 binaryOperator(isAssignmentOperator(),
106 hasLHS(ignoringParenImpCasts(
107 declRefExpr(to(varDecl(VarNodeMatcher)))))));
108}
109
110static internal::Matcher<Stmt>
111callByRef(internal::Matcher<Decl> VarNodeMatcher) {
112 return callExpr(forEachArgumentWithParam(
113 declRefExpr(to(varDecl(VarNodeMatcher))),
114 parmVarDecl(hasType(references(qualType(unless(isConstQualified())))))));
115}
116
117static internal::Matcher<Stmt>
118assignedToRef(internal::Matcher<Decl> VarNodeMatcher) {
120 allOf(hasType(referenceType()),
121 hasInitializer(anyOf(
122 initListExpr(has(declRefExpr(to(varDecl(VarNodeMatcher))))),
123 declRefExpr(to(varDecl(VarNodeMatcher)))))))));
124}
125
126static internal::Matcher<Stmt>
127getAddrTo(internal::Matcher<Decl> VarNodeMatcher) {
128 return unaryOperator(
129 hasOperatorName("&"),
130 hasUnaryOperand(declRefExpr(hasDeclaration(VarNodeMatcher))));
131}
132
133static internal::Matcher<Stmt> hasSuspiciousStmt(StringRef NodeName) {
134 return hasDescendant(stmt(
136 // Escaping and not known mutation of the loop counter is handled
137 // by exclusion of assigning and address-of operators and
138 // pass-by-ref function calls on the loop counter from the body.
139 changeIntBoundNode(equalsBoundNode(std::string(NodeName))),
140 callByRef(equalsBoundNode(std::string(NodeName))),
141 getAddrTo(equalsBoundNode(std::string(NodeName))),
142 assignedToRef(equalsBoundNode(std::string(NodeName))))));
143}
144
145static internal::Matcher<Stmt> forLoopMatcher() {
146 return forStmt(
147 hasCondition(simpleCondition("initVarName", "initVarRef")),
148 // Initialization should match the form: 'int i = 6' or 'i = 42'.
149 hasLoopInit(
150 anyOf(declStmt(hasSingleDecl(
151 varDecl(allOf(hasInitializer(ignoringParenImpCasts(
152 integerLiteral().bind("initNum"))),
153 equalsBoundNode("initVarName"))))),
155 equalsBoundNode("initVarName"))))),
156 hasRHS(ignoringParenImpCasts(
157 integerLiteral().bind("initNum")))))),
158 // Incrementation should be a simple increment or decrement
159 // operator call.
160 hasIncrement(unaryOperator(
161 anyOf(hasOperatorName("++"), hasOperatorName("--")),
162 hasUnaryOperand(declRefExpr(
163 to(varDecl(allOf(equalsBoundNode("initVarName"),
164 hasType(isInteger())))))))),
165 unless(hasBody(hasSuspiciousStmt("initVarName"))))
166 .bind("forLoop");
167}
168
170
171 // Get the lambda CXXRecordDecl
173 const LocationContext *LocCtxt = N->getLocationContext();
174 const Decl *D = LocCtxt->getDecl();
175 const auto *MD = cast<CXXMethodDecl>(D);
176 assert(MD && MD->getParent()->isLambda() &&
177 "Captured variable should only be seen while evaluating a lambda");
178 const CXXRecordDecl *LambdaCXXRec = MD->getParent();
179
180 // Lookup the fields of the lambda
181 llvm::DenseMap<const ValueDecl *, FieldDecl *> LambdaCaptureFields;
182 FieldDecl *LambdaThisCaptureField;
183 LambdaCXXRec->getCaptureFields(LambdaCaptureFields, LambdaThisCaptureField);
184
185 // Check if the counter is captured by reference
186 const VarDecl *VD = cast<VarDecl>(DR->getDecl()->getCanonicalDecl());
187 assert(VD);
188 const FieldDecl *FD = LambdaCaptureFields[VD];
189 assert(FD && "Captured variable without a corresponding field");
190 return FD->getType()->isReferenceType();
191}
192
193static bool isFoundInStmt(const Stmt *S, const VarDecl *VD) {
194 if (const DeclStmt *DS = dyn_cast<DeclStmt>(S)) {
195 for (const Decl *D : DS->decls()) {
196 // Once we reach the declaration of the VD we can return.
197 if (D->getCanonicalDecl() == VD)
198 return true;
199 }
200 }
201 return false;
202}
203
204// A loop counter is considered escaped if:
205// case 1: It is a global variable.
206// case 2: It is a reference parameter or a reference capture.
207// case 3: It is assigned to a non-const reference variable or parameter.
208// case 4: Has its address taken.
209static bool isPossiblyEscaped(ExplodedNode *N, const DeclRefExpr *DR) {
210 const VarDecl *VD = cast<VarDecl>(DR->getDecl()->getCanonicalDecl());
211 assert(VD);
212 // Case 1:
213 if (VD->hasGlobalStorage())
214 return true;
215
216 const bool IsRefParamOrCapture =
217 isa<ParmVarDecl>(VD) || DR->refersToEnclosingVariableOrCapture();
218 // Case 2:
220 isCapturedByReference(N, DR)) ||
221 (IsRefParamOrCapture && VD->getType()->isReferenceType()))
222 return true;
223
224 while (!N->pred_empty()) {
225 // FIXME: getStmtForDiagnostics() does nasty things in order to provide
226 // a valid statement for body farms, do we need this behavior here?
227 const Stmt *S = N->getStmtForDiagnostics();
228 if (!S) {
229 N = N->getFirstPred();
230 continue;
231 }
232
233 if (isFoundInStmt(S, VD)) {
234 return false;
235 }
236
237 if (const auto *SS = dyn_cast<SwitchStmt>(S)) {
238 if (const auto *CST = dyn_cast<CompoundStmt>(SS->getBody())) {
239 for (const Stmt *CB : CST->body()) {
240 if (isFoundInStmt(CB, VD))
241 return false;
242 }
243 }
244 }
245
    // Check for pass-by-reference calls, address-of operators, and
    // reference initializations that involve VD.
248 ASTContext &ASTCtx =
250 // Case 3 and 4:
251 auto Match =
252 match(stmt(anyOf(callByRef(equalsNode(VD)), getAddrTo(equalsNode(VD)),
253 assignedToRef(equalsNode(VD)))),
254 *S, ASTCtx);
255 if (!Match.empty())
256 return true;
257
258 N = N->getFirstPred();
259 }
260
261 // Reference parameter and reference capture will not be found.
262 if (IsRefParamOrCapture)
263 return false;
264
265 llvm_unreachable("Reached root without finding the declaration of VD");
266}
267
268bool shouldCompletelyUnroll(const Stmt *LoopStmt, ASTContext &ASTCtx,
269 ExplodedNode *Pred, unsigned &maxStep) {
270
271 if (!isLoopStmt(LoopStmt))
272 return false;
273
274 // TODO: Match the cases where the bound is not a concrete literal but an
275 // integer with known value
276 auto Matches = match(forLoopMatcher(), *LoopStmt, ASTCtx);
277 if (Matches.empty())
278 return false;
279
280 const auto *CounterVarRef = Matches[0].getNodeAs<DeclRefExpr>("initVarRef");
281 llvm::APInt BoundNum =
282 Matches[0].getNodeAs<IntegerLiteral>("boundNum")->getValue();
283 llvm::APInt InitNum =
284 Matches[0].getNodeAs<IntegerLiteral>("initNum")->getValue();
285 auto CondOp = Matches[0].getNodeAs<BinaryOperator>("conditionOperator");
286 if (InitNum.getBitWidth() != BoundNum.getBitWidth()) {
287 InitNum = InitNum.zext(BoundNum.getBitWidth());
288 BoundNum = BoundNum.zext(InitNum.getBitWidth());
289 }
290
291 if (CondOp->getOpcode() == BO_GE || CondOp->getOpcode() == BO_LE)
292 maxStep = (BoundNum - InitNum + 1).abs().getZExtValue();
293 else
294 maxStep = (BoundNum - InitNum).abs().getZExtValue();
295
296 // Check if the counter of the loop is not escaped before.
297 return !isPossiblyEscaped(Pred, CounterVarRef);
298}
299
300bool madeNewBranch(ExplodedNode *N, const Stmt *LoopStmt) {
301 const Stmt *S = nullptr;
302 while (!N->pred_empty()) {
303 if (N->succ_size() > 1)
304 return true;
305
307 if (std::optional<BlockEntrance> BE = P.getAs<BlockEntrance>())
308 S = BE->getBlock()->getTerminatorStmt();
309
310 if (S == LoopStmt)
311 return false;
312
313 N = N->getFirstPred();
314 }
315
316 llvm_unreachable("Reached root without encountering the previous step");
317}
318
319// updateLoopStack is called on every basic block, therefore it needs to be fast
321 ExplodedNode *Pred, unsigned maxVisitOnPath) {
322 auto State = Pred->getState();
323 auto LCtx = Pred->getLocationContext();
324
325 if (!isLoopStmt(LoopStmt))
326 return State;
327
328 auto LS = State->get<LoopStack>();
329 if (!LS.isEmpty() && LoopStmt == LS.getHead().getLoopStmt() &&
330 LCtx == LS.getHead().getLocationContext()) {
331 if (LS.getHead().isUnrolled() && madeNewBranch(Pred, LoopStmt)) {
332 State = State->set<LoopStack>(LS.getTail());
333 State = State->add<LoopStack>(
334 LoopState::getNormal(LoopStmt, LCtx, maxVisitOnPath));
335 }
336 return State;
337 }
338 unsigned maxStep;
339 if (!shouldCompletelyUnroll(LoopStmt, ASTCtx, Pred, maxStep)) {
340 State = State->add<LoopStack>(
341 LoopState::getNormal(LoopStmt, LCtx, maxVisitOnPath));
342 return State;
343 }
344
345 unsigned outerStep = (LS.isEmpty() ? 1 : LS.getHead().getMaxStep());
346
347 unsigned innerMaxStep = maxStep * outerStep;
348 if (innerMaxStep > MAXIMUM_STEP_UNROLLED)
349 State = State->add<LoopStack>(
350 LoopState::getNormal(LoopStmt, LCtx, maxVisitOnPath));
351 else
352 State = State->add<LoopStack>(
353 LoopState::getUnrolled(LoopStmt, LCtx, innerMaxStep));
354 return State;
355}
356
358 auto LS = State->get<LoopStack>();
359 if (LS.isEmpty() || !LS.getHead().isUnrolled())
360 return false;
361 return true;
362}
363}
364}
StringRef P
const Decl * D
#define X(type, name)
Definition: Value.h:143
static const int MAXIMUM_STEP_UNROLLED
This header contains the declarations of functions which are used to decide which loops should be com...
#define REGISTER_LIST_WITH_PROGRAMSTATE(Name, Elem)
Declares an immutable list type NameTy, suitable for placement into the ProgramState.
__DEVICE__ long long abs(long long __n)
Holds long-lived AST nodes (such as types and decls) that can be referred to throughout the semantic ...
Definition: ASTContext.h:187
ASTContext & getASTContext() const
A builtin binary operation expression such as "x + y" or "x <= y".
Definition: Expr.h:3860
Represents a C++ struct/union/class.
Definition: DeclCXX.h:258
void getCaptureFields(llvm::DenseMap< const ValueDecl *, FieldDecl * > &Captures, FieldDecl *&ThisCapture) const
For a closure type, retrieve the mapping from captured variables and this to the non-static data memb...
Definition: DeclCXX.cpp:1680
DeclContext * getParent()
getParent - Returns the containing DeclContext.
Definition: DeclBase.h:2090
A reference to a declared variable, function, enum, etc.
Definition: Expr.h:1265
bool refersToEnclosingVariableOrCapture() const
Does this DeclRefExpr refer to an enclosing local or a captured variable?
Definition: Expr.h:1463
ValueDecl * getDecl()
Definition: Expr.h:1333
DeclStmt - Adaptor class for mixing declarations with statements and expressions.
Definition: Stmt.h:1502
Decl - This represents one declaration (or definition), e.g.
Definition: DeclBase.h:86
virtual Decl * getCanonicalDecl()
Retrieves the "canonical" declaration of the given declaration.
Definition: DeclBase.h:968
Represents a member of a struct/union/class.
Definition: Decl.h:3030
It wraps the AnalysisDeclContext to represent both the call stack with the help of StackFrameContext ...
const Decl * getDecl() const
LLVM_ATTRIBUTE_RETURNS_NONNULL AnalysisDeclContext * getAnalysisDeclContext() const
Stmt - This represents one statement.
Definition: Stmt.h:84
bool isReferenceType() const
Definition: Type.h:8021
QualType getType() const
Definition: Decl.h:678
Represents a variable declaration or definition.
Definition: Decl.h:879
bool hasGlobalStorage() const
Returns true for all variables that do not have local storage.
Definition: Decl.h:1174
const ProgramStateRef & getState() const
const Stmt * getStmtForDiagnostics() const
If the node's program point corresponds to a statement, retrieve that statement.
ProgramPoint getLocation() const
getLocation - Returns the edge associated with the given node.
const LocationContext * getLocationContext() const
ExplodedNode * getFirstPred()
unsigned succ_size() const
const internal::VariadicDynCastAllOfMatcher< Decl, VarDecl > varDecl
Matches variable declarations.
const internal::VariadicDynCastAllOfMatcher< Stmt, DeclRefExpr > declRefExpr
Matches expressions that refer to declarations.
const internal::VariadicOperatorMatcherFunc< 1, 1 > unless
Matches if the provided matcher does not match.
const internal::ArgumentAdaptingMatcherFunc< internal::HasDescendantMatcher > hasDescendant
Matches AST nodes that have descendant AST nodes that match the provided matcher.
const internal::VariadicDynCastAllOfMatcher< Decl, ParmVarDecl > parmVarDecl
Matches parameter variable declarations.
const internal::VariadicDynCastAllOfMatcher< Stmt, ReturnStmt > returnStmt
Matches return statements.
const internal::VariadicDynCastAllOfMatcher< Stmt, CallExpr > callExpr
Matches call expressions.
SmallVector< BoundNodes, 1 > match(MatcherT Matcher, const NodeT &Node, ASTContext &Context)
Returns the results of matching Matcher on Node.
const internal::VariadicDynCastAllOfMatcher< Stmt, UnaryOperator > unaryOperator
Matches unary operator expressions.
const internal::VariadicDynCastAllOfMatcher< Stmt, InitListExpr > initListExpr
Matches init list expressions.
const internal::VariadicDynCastAllOfMatcher< Stmt, ForStmt > forStmt
Matches for statements.
const internal::VariadicDynCastAllOfMatcher< Stmt, GotoStmt > gotoStmt
Matches goto statements.
const internal::VariadicDynCastAllOfMatcher< Stmt, BinaryOperator > binaryOperator
Matches binary operator expressions.
const internal::ArgumentAdaptingMatcherFunc< internal::HasMatcher > has
Matches AST nodes that have child AST nodes that match the provided matcher.
const internal::VariadicOperatorMatcherFunc< 2, std::numeric_limits< unsigned >::max()> allOf
Matches if all given matchers match.
const internal::VariadicDynCastAllOfMatcher< Stmt, SwitchStmt > switchStmt
Matches switch statements.
const AstTypeMatcher< ReferenceType > referenceType
Matches both lvalue and rvalue reference types.
const internal::VariadicDynCastAllOfMatcher< Stmt, IntegerLiteral > integerLiteral
Matches integer literals of all sizes / encodings, e.g.
internal::PolymorphicMatcher< internal::HasDeclarationMatcher, void(internal::HasDeclarationSupportedTypes), internal::Matcher< Decl > > hasDeclaration(const internal::Matcher< Decl > &InnerMatcher)
Matches a node if the declaration associated with that node matches the given matcher.
Definition: ASTMatchers.h:3653
const internal::VariadicDynCastAllOfMatcher< Stmt, DeclStmt > declStmt
Matches declaration statements.
const internal::VariadicAllOfMatcher< Stmt > stmt
Matches statements.
const internal::VariadicOperatorMatcherFunc< 2, std::numeric_limits< unsigned >::max()> anyOf
Matches if any of the given matchers matches.
const internal::VariadicAllOfMatcher< QualType > qualType
Matches QualTypes in the clang AST.
static internal::Matcher< Stmt > simpleCondition(StringRef BindName, StringRef RefName)
static internal::Matcher< Stmt > hasSuspiciousStmt(StringRef NodeName)
static internal::Matcher< Stmt > callByRef(internal::Matcher< Decl > VarNodeMatcher)
static internal::Matcher< Stmt > forLoopMatcher()
static bool isPossiblyEscaped(ExplodedNode *N, const DeclRefExpr *DR)
static internal::Matcher< Stmt > getAddrTo(internal::Matcher< Decl > VarNodeMatcher)
static bool isLoopStmt(const Stmt *S)
ProgramStateRef processLoopEnd(const Stmt *LoopStmt, ProgramStateRef State)
Updates the given ProgramState.
bool shouldCompletelyUnroll(const Stmt *LoopStmt, ASTContext &ASTCtx, ExplodedNode *Pred, unsigned &maxStep)
static bool isFoundInStmt(const Stmt *S, const VarDecl *VD)
static internal::Matcher< Stmt > assignedToRef(internal::Matcher< Decl > VarNodeMatcher)
bool isUnrolledState(ProgramStateRef State)
Returns if the given State indicates that is inside a completely unrolled loop.
static bool isCapturedByReference(ExplodedNode *N, const DeclRefExpr *DR)
ProgramStateRef updateLoopStack(const Stmt *LoopStmt, ASTContext &ASTCtx, ExplodedNode *Pred, unsigned maxVisitOnPath)
Updates the stack of loops contained by the ProgramState.
bool madeNewBranch(ExplodedNode *N, const Stmt *LoopStmt)
static internal::Matcher< Stmt > changeIntBoundNode(internal::Matcher< Decl > VarNodeMatcher)
The JSON file list parser is used to communicate input to InstallAPI.
bool operator==(const CallGraphNode::CallRecord &LHS, const CallGraphNode::CallRecord &RHS)
Definition: CallGraph.h:207