10 #include "clang/AST/ASTContext.h"
11 #include "clang/ASTMatchers/ASTMatchFinder.h"
12 #include "clang/Analysis/CallGraph.h"
13 #include "llvm/ADT/DenseMapInfo.h"
14 #include "llvm/ADT/SCCIterator.h"
// ImmutableSmallSet: a read-only membership set that is built once from an
// ArrayRef and then only queried through count().
// Inputs larger than SmallSize are copied into the DenseSet `Set` (see the
// constructor); how smaller inputs are stored is not visible in this excerpt.
// NOTE(review): count() reads a member named `Vector` that is not declared in
// the lines shown here (presumably an ArrayRef<T> view of the constructor
// argument) — confirm the declaration exists.
28 template <
typename T,
unsigned SmallSize>
class ImmutableSmallSet {
// Large-mode storage; stays empty while the set is in small mode (isSmall()).
30 llvm::DenseSet<T> Set;
// Caps the small-mode threshold; presumably so the linear search stays cheap.
32 static_assert(SmallSize <= 32,
"N should be small");
34 bool isSmall()
const {
return Set.empty(); }
37 using size_type = size_t;
39 ImmutableSmallSet() =
delete;
40 ImmutableSmallSet(
const ImmutableSmallSet &) =
delete;
41 ImmutableSmallSet(ImmutableSmallSet &&) =
delete;
42 T &operator=(
const ImmutableSmallSet &) =
delete;
43 T &operator=(ImmutableSmallSet &&) =
delete;
// Builds the set from Storage, which is split by size against SmallSize.
// NOTE(review): the body of the small-size branch and the brace closing the
// `if` are not visible in this excerpt. If the small path keeps a view of
// Storage instead of a copy, Storage must outlive this object — confirm.
46 ImmutableSmallSet(ArrayRef<T> Storage) {
// Small input: handled without the hash-set (branch body not shown here).
48 if (Storage.size() <= SmallSize) {
// Large input (presumably, once the `if` above is closed): reserve once, then
// bulk-insert every element of Storage into the hash-set.
55 Set.reserve(Storage.size());
56 Set.insert(Storage.begin(), Storage.end());
// count - Returns 1 if V is a member of the set, 0 otherwise.
// NOTE(review): only a linear search over `Vector` is visible here. Lines that
// presumably dispatch on isSmall() and fall back to Set.count(V) in large
// mode are missing from this excerpt — confirm.
60 size_type count(
const T &V)
const {
// Linear scan: 1 when V occurs in Vector, 0 otherwise.
63 return llvm::find(Vector, V) == Vector.end() ? 0 : 1;
// SmartSmallSetVector: an insertion-ordered sequence of unique elements.
// All elements live in `Vector`. Uniqueness is checked by a linear search over
// Vector while small, and through the hash-set `Set` once the elements have
// been copied into it.
// NOTE(review): no access specifiers are visible in this excerpt, yet insert()
// and takeVector() are called from pathfindSomeCycle(). Confirm that a
// `public:` section exists.
75 template <
typename T,
unsigned SmallSize>
class SmartSmallSetVector {
77 using size_type = size_t;
// Insertion-ordered storage of every element; this is what takeVector() yields.
80 SmallVector<T, SmallSize> Vector;
// Uniqueness index, filled once Vector has reached SmallSize elements.
81 llvm::DenseSet<T> Set;
// Caps the small-mode threshold; presumably so the linear search stays cheap.
83 static_assert(SmallSize <= 32,
"N should be small");
86 bool isSmall()
const {
return Set.empty(); }
// Whether Vector already holds exactly SmallSize elements, i.e. one more
// element would exceed the small capacity. It may only be asked while still
// in small mode (asserted).
89 bool entiretyOfVectorSmallSizeIsOccupied()
const {
90 assert(isSmall() && Vector.size() <= SmallSize &&
91 "Shouldn't ask if we have already [should have] migrated into Set.");
92 return Vector.size() == SmallSize;
// Migration from small to large mode. The enclosing function header is not
// visible in this excerpt. It pre-reserves room for 4x the current element
// count in both containers, then copies every element of Vector into Set.
// NOTE(review): the factor 4 is a growth heuristic; its rationale is not
// stated here.
96 assert(Set.empty() &&
"Should not have already utilized the Set.");
99 const size_t NewMaxElts = 4 * Vector.size();
100 Vector.reserve(NewMaxElts);
101 Set.reserve(NewMaxElts);
102 Set.insert(Vector.begin(), Vector.end());
// count - Returns 1 if V is present, 0 otherwise.
// NOTE(review): only the linear-search path over Vector is visible. A
// Set-based path for large mode is presumably in the missing lines — confirm.
106 size_type count(
const T &V)
const {
109 return llvm::find(Vector, V) == Vector.end() ? 0 : 1;
// Records V in the uniqueness tracking and reports whether V was new.
// NOTE(review): most of this body is missing from this excerpt, including the
// duplicate check and the statement controlled by the `if` below.
115 bool setInsert(
const T &V) {
121 if (!entiretyOfVectorSmallSizeIsOccupied())
// Large mode: the hash-set must not already contain V (asserted; the result
// is only used by the assert, hence the void cast for NDEBUG builds).
128 bool SetInsertionSucceeded = Set.insert(V).second;
129 (void)SetInsertionSucceeded;
130 assert(SetInsertionSucceeded &&
"We did check that no such value existed");
// insert - Adds X if it is not already present and returns setInsert(X)'s
// verdict (true if X was new).
// NOTE(review): the step appending X to Vector is not visible here.
137 bool insert(
const T &
X) {
138 bool Result = setInsert(
X);
// Hands the element sequence to the caller by moving Vector out.
// NOTE(review): whether Set is also cleared here is not visible.
145 decltype(Vector) takeVector() {
147 return std::move(Vector);
// Inline (small-size) capacity of the containers that track the call stack:
// CallStackTy and the SmartSmallSetVector used while searching for a cycle.
constexpr unsigned SmallCallStackSize{16};
// Small-size threshold of the ImmutableSmallSet caching SCC membership.
constexpr unsigned SmallSCCSize{32};
// Type of the call stack produced by pathfindSomeCycle(): call records with
// SmallCallStackSize inline capacity.
// NOTE(review): the head of this alias declaration (presumably
// `using CallStackTy =`) is not visible in this excerpt; this line is its
// right-hand side.
155 llvm::SmallVector<CallGraphNode::CallRecord, SmallCallStackSize>;
// Finds *some* cyclic call stack within the given SCC. Starting from
// SCC.front(), it repeatedly follows the first callee that belongs to the SCC
// until a call record is seen again. The result is not necessarily the
// shortest cycle, nor the only one. The returned stack ends with the repeated
// record, so its last element also occurs earlier in the stack.
// NOTE(review): the loop head, closing braces and the final `return` are not
// visible in this excerpt.
160 CallStackTy pathfindSomeCycle(ArrayRef<CallGraphNode *> SCC) {
// Cache the SCC's nodes so membership can be tested cheaply.
163 const ImmutableSmallSet<CallGraphNode *, SmallSCCSize> SCCElts(SCC);
// Predicate: is N one of the nodes of this SCC?
166 auto NodeIsPartOfSCC = [&SCCElts](CallGraphNode *N) {
167 return SCCElts.count(N) != 0;
// Visited call records, in visiting order, with duplicates rejected. The
// variable name (presumably CallStackSet) is on a line not shown here.
171 SmartSmallSetVector<CallGraphNode::CallRecord, SmallCallStackSize>
// Arbitrarily use the first SCC node as the entry point; it has no call
// expression (nullptr).
175 CallGraphNode::CallRecord EntryNode(SCC.front(),
nullptr);
178 CallGraphNode::CallRecord *Node = &EntryNode;
// A record that was already visited closes the cycle. Note that insert()
// rejects it, so it is NOT part of the tracked stack at that point.
181 if (!CallStackSet.insert(*Node))
// Depth-first step: take the first callee of the current node that is in this
// SCC. This presumably relies on every node of a cyclic SCC having at least
// one callee inside the SCC, so find_if never yields the end iterator.
185 Node = llvm::find_if(Node->Callee->callees(), NodeIsPartOfSCC);
// Re-append the repeated record that insert() rejected, so the stack ends on
// the node that closes the cycle.
190 CallStackTy CallStack = CallStackSet.takeVector();
191 CallStack.emplace_back(*Node);
// Registers a single matcher for the whole translation unit, bound as
// "TUDecl", so the match callback can retrieve it under that name and build a
// call graph from it.
// NOTE(review): the closing brace is not visible in this excerpt.
198 void NoRecursionCheck::registerMatchers(MatchFinder *Finder) {
199 Finder->addMatcher(translationUnitDecl().bind(
"TUDecl"),
this);
// Reports one strongly connected component of the call graph. It emits a
// warning on every function in the SCC, then a series of notes walking
// through one example cycle.
// NOTE(review): some lines of this body are not visible in this excerpt,
// e.g. the trailing arguments of the "example recursive call chain"
// diagnostic and the closing braces.
202 void NoRecursionCheck::handleSCC(ArrayRef<CallGraphNode *> SCC) {
203 assert(!SCC.empty() &&
"Empty SCC does not make sense.");
// Warn on every function that participates in the recursion.
// NOTE(review): D is dereferenced without a null check. This assumes
// getDefinition() never returns null for nodes of an SCC — confirm.
206 for (CallGraphNode *N : SCC) {
207 FunctionDecl *
D = N->getDefinition();
208 diag(
D->getLocation(),
"function %0 is within a recursive call chain") <<
D;
// Build a call stack whose last record repeats an earlier one.
215 const CallStackTy EventuallyCyclicCallStack = pathfindSomeCycle(SCC);
216 assert(!EventuallyCyclicCallStack.empty() &&
"We should've found the cycle");
// The cycle need not start at the stack's first record. Drop leading records
// until reaching the one equal to the last record, which leaves exactly the
// cycle: it starts and ends on the same node.
221 const auto CyclicCallStack =
222 ArrayRef<CallGraphNode::CallRecord>(EventuallyCyclicCallStack)
223 .drop_until([LastNode = EventuallyCyclicCallStack.back()](
224 CallGraphNode::CallRecord FrontNode) {
225 return FrontNode == LastNode;
227 assert(CyclicCallStack.size() >= 2 &&
"Cycle requires at least 2 frames");
// Function at which the reported example cycle starts.
230 FunctionDecl *CycleEntryFn = CyclicCallStack.front().Callee->getDefinition();
233 diag(CycleEntryFn->getLocation(),
234 "example recursive call chain, starting from function %0",
// One note per call edge of the cycle. For each consecutive pair (Prev, Curr),
// point at Curr's call expression and name both functions.
// NOTE(review): NumFrames is an int initialized from ArrayRef::size() (a
// size_t); this is fine for realistic cycle lengths.
237 for (
int CurFrame = 1, NumFrames = CyclicCallStack.size();
238 CurFrame != NumFrames; ++CurFrame) {
239 CallGraphNode::CallRecord PrevNode = CyclicCallStack[CurFrame - 1];
240 CallGraphNode::CallRecord CurrNode = CyclicCallStack[CurFrame];
242 Decl *PrevDecl = PrevNode.Callee->getDecl();
243 Decl *CurrDecl = CurrNode.Callee->getDecl();
245 diag(CurrNode.CallExpr->getBeginLoc(),
246 "Frame #%0: function %1 calls function %2 here:", DiagnosticIDs::Note)
247 << CurFrame << cast<NamedDecl>(PrevDecl) << cast<NamedDecl>(CurrDecl);
// Closing note, placed at the call that returns to the cycle's starting point.
250 diag(CyclicCallStack.back().CallExpr->getBeginLoc(),
251 "... which was the starting point of the recursive call chain; there "
252 "may be other cycles",
253 DiagnosticIDs::Note);
258 const auto *TU = Result.Nodes.getNodeAs<TranslationUnitDecl>(
"TUDecl");
260 CG.addToCallGraph(
const_cast<TranslationUnitDecl *
>(TU));
264 for (llvm::scc_iterator<CallGraph *> SCCI = llvm::scc_begin(&CG),
265 SCCE = llvm::scc_end(&CG);
266 SCCI != SCCE; ++SCCI) {
267 if (!SCCI.hasCycle())