56#include "clang/AST/ASTContext.h"
57#include "clang/AST/Decl.h"
58#include "clang/AST/DeclBase.h"
59#include "clang/AST/ExprCXX.h"
60#include "clang/AST/NestedNameSpecifier.h"
61#include "clang/AST/RecursiveASTVisitor.h"
62#include "clang/AST/Stmt.h"
63#include "clang/Basic/LangOptions.h"
64#include "clang/Basic/SourceLocation.h"
65#include "clang/Basic/SourceManager.h"
66#include "clang/Tooling/Core/Replacement.h"
67#include "clang/Tooling/Refactoring/Extract/SourceExtraction.h"
68#include "llvm/ADT/STLExtras.h"
69#include "llvm/ADT/SmallSet.h"
70#include "llvm/ADT/SmallVector.h"
71#include "llvm/ADT/StringRef.h"
72#include "llvm/Support/Casting.h"
73#include "llvm/Support/Error.h"
80using Node = SelectionTree::Node;
85enum class ZoneRelative {
92enum FunctionDeclKind {
101bool isRootStmt(
const Node *N) {
102 if (!N->ASTNode.get<Stmt>())
112 !N->ASTNode.get<CXXOperatorCallExpr>())
127const Node *getParentOfRootStmts(
const Node *CommonAnc) {
131 switch (CommonAnc->Selected) {
143 Parent = CommonAnc->Parent;
148 if (
Parent->ASTNode.get<DeclStmt>())
153 return llvm::all_of(
Parent->Children, isRootStmt) ?
Parent :
nullptr;
157struct ExtractionZone {
169 SourceLocation getInsertionPoint()
const {
172 bool isRootStmt(
const Stmt *S)
const;
175 const Node *getLastRootStmt()
const {
return Parent->Children.back(); }
181 bool requiresHoisting(
const SourceManager &SM,
182 const HeuristicResolver *Resolver)
const {
184 llvm::SmallSet<const Decl *, 1> DeclsInExtZone;
185 for (
auto *RootStmt : RootStmts) {
188 [&DeclsInExtZone](
const ReferenceLoc &
Loc) {
191 DeclsInExtZone.insert(
Loc.Targets.front());
196 if (DeclsInExtZone.empty())
200 if (SM.isBeforeInTranslationUnit(S->getSourceRange().getEnd(),
203 bool HasPostUse =
false;
206 [&](
const ReferenceLoc &
Loc) {
208 SM.isBeforeInTranslationUnit(
Loc.NameLoc,
ZoneRange.getEnd()))
210 HasPostUse = llvm::any_of(
Loc.Targets,
211 [&DeclsInExtZone](
const Decl *Target) {
212 return DeclsInExtZone.contains(Target);
227bool alwaysReturns(
const ExtractionZone &EZ) {
228 const Stmt *Last = EZ.getLastRootStmt()->ASTNode.get<Stmt>();
230 while (
const auto *CS = llvm::dyn_cast<CompoundStmt>(Last)) {
231 if (CS->body_empty())
233 Last = CS->body_back();
235 return llvm::isa<ReturnStmt>(Last);
238bool ExtractionZone::isRootStmt(
const Stmt *S)
const {
243const FunctionDecl *findEnclosingFunction(
const Node *CommonAnc) {
245 for (
const Node *CurNode = CommonAnc; CurNode; CurNode = CurNode->Parent) {
247 if (CurNode->ASTNode.get<LambdaExpr>())
249 if (
const FunctionDecl *Func = CurNode->ASTNode.get<FunctionDecl>()) {
251 if (Func->isTemplated())
253 if (!Func->getBody())
255 for (
const auto *S : Func->getBody()->children()) {
270std::optional<SourceRange> findZoneRange(
const Node *
Parent,
271 const SourceManager &SM,
272 const LangOptions &LangOpts) {
275 SM, LangOpts,
Parent->Children.front()->ASTNode.getSourceRange()))
276 SR.setBegin(BeginFileRange->getBegin());
280 SM, LangOpts,
Parent->Children.back()->ASTNode.getSourceRange()))
281 SR.setEnd(EndFileRange->getEnd());
291std::optional<SourceRange>
293 const SourceManager &SM,
294 const LangOptions &LangOpts) {
300bool validSingleChild(
const Node *Child,
const FunctionDecl *EnclosingFunc) {
304 if (Child->ASTNode.get<Expr>())
307 assert(EnclosingFunc->hasBody() &&
308 "We should always be extracting from a function body.");
309 if (Child->ASTNode.get<Stmt>() == EnclosingFunc->getBody())
316std::optional<ExtractionZone> findExtractionZone(
const Node *CommonAnc,
317 const SourceManager &SM,
318 const LangOptions &LangOpts) {
319 ExtractionZone ExtZone;
320 ExtZone.Parent = getParentOfRootStmts(CommonAnc);
321 if (!ExtZone.Parent || ExtZone.Parent->Children.empty())
323 ExtZone.EnclosingFunction = findEnclosingFunction(ExtZone.Parent);
324 if (!ExtZone.EnclosingFunction)
328 if (ExtZone.Parent->Children.size() == 1 &&
329 !validSingleChild(ExtZone.getLastRootStmt(), ExtZone.EnclosingFunction))
332 computeEnclosingFuncRange(ExtZone.EnclosingFunction, SM, LangOpts))
333 ExtZone.EnclosingFuncRange = *FuncRange;
334 if (
auto ZoneRange = findZoneRange(ExtZone.Parent, SM, LangOpts))
336 if (ExtZone.EnclosingFuncRange.isInvalid() || ExtZone.ZoneRange.isInvalid())
339 for (
const Node *Child : ExtZone.Parent->Children)
340 ExtZone.RootStmts.insert(Child->ASTNode.get<Stmt>());
353 std::string render(
const DeclContext *Context)
const;
354 bool operator<(
const Parameter &Other)
const {
358 std::string
Name =
"extracted";
371 ConstexprSpecKind
Constexpr = ConstexprSpecKind::Unspecified;
377 const LangOptions *LangOpts;
378 NewFunction(tooling::ExtractionSemicolonPolicy SemicolonPolicy,
379 const LangOptions *LangOpts)
382 std::string renderCall()
const;
384 std::string renderDeclaration(FunctionDeclKind
K,
385 const DeclContext &SemanticDC,
386 const DeclContext &SyntacticDC,
387 const SourceManager &SM)
const;
391 renderParametersForDeclaration(
const DeclContext &Enclosing)
const;
392 std::string renderParametersForCall()
const;
393 std::string renderSpecifiers(FunctionDeclKind
K)
const;
394 std::string renderQualifiers()
const;
395 std::string renderDeclarationName(FunctionDeclKind
K)
const;
397 std::string getFuncBody(
const SourceManager &SM)
const;
400std::string NewFunction::renderParametersForDeclaration(
401 const DeclContext &Enclosing)
const {
403 bool NeedCommaBefore =
false;
407 NeedCommaBefore =
true;
408 Result += P.render(&Enclosing);
413std::string NewFunction::renderParametersForCall()
const {
415 bool NeedCommaBefore =
false;
419 NeedCommaBefore =
true;
425std::string NewFunction::renderSpecifiers(FunctionDeclKind
K)
const {
428 if (
Static &&
K != FunctionDeclKind::OutOfLineDefinition) {
433 case ConstexprSpecKind::Unspecified:
434 case ConstexprSpecKind::Constinit:
436 case ConstexprSpecKind::Constexpr:
439 case ConstexprSpecKind::Consteval:
447std::string NewFunction::renderQualifiers()
const {
457std::string NewFunction::renderDeclarationName(FunctionDeclKind
K)
const {
462 std::string QualifierName;
463 llvm::raw_string_ostream Oss(QualifierName);
465 return llvm::formatv(
"{0}{1}", QualifierName,
Name);
468std::string NewFunction::renderCall()
const {
471 renderParametersForCall(),
475std::string NewFunction::renderDeclaration(FunctionDeclKind
K,
478 const SourceManager &SM)
const {
479 std::string
Declaration = std::string(llvm::formatv(
480 "{0}{1} {2}({3}){4}", renderSpecifiers(
K),
482 renderParametersForDeclaration(
SemanticDC), renderQualifiers()));
485 case ForwardDeclaration:
486 return std::string(llvm::formatv(
"{0};\n",
Declaration));
487 case OutOfLineDefinition:
488 case InlineDefinition:
490 llvm::formatv(
"{0} {\n{1}\n}\n",
Declaration, getFuncBody(SM)));
493 llvm_unreachable(
"Unsupported FunctionDeclKind enum");
496std::string NewFunction::getFuncBody(
const SourceManager &SM)
const {
505std::string NewFunction::Parameter::render(
const DeclContext *Context)
const {
510struct CapturedZoneInfo {
511 struct DeclInformation {
519 DeclInformation(
const Decl *TheDecl, ZoneRelative DeclaredIn,
523 void markOccurence(ZoneRelative ReferenceLoc);
534 DeclInformation *createDeclInfo(
const Decl *D, ZoneRelative RelativeLoc);
535 DeclInformation *getDeclInfoFor(
const Decl *D);
538CapturedZoneInfo::DeclInformation *
539CapturedZoneInfo::createDeclInfo(
const Decl *D, ZoneRelative RelativeLoc) {
542 {D, DeclInformation(D, RelativeLoc,
DeclInfoMap.size())});
544 return &InsertionResult.first->second;
547CapturedZoneInfo::DeclInformation *
548CapturedZoneInfo::getDeclInfoFor(
const Decl *D) {
553 return &Iter->second;
556void CapturedZoneInfo::DeclInformation::markOccurence(
557 ZoneRelative ReferenceLoc) {
558 switch (ReferenceLoc) {
559 case ZoneRelative::Inside:
562 case ZoneRelative::After:
570bool isLoop(
const Stmt *S) {
571 return isa<ForStmt>(S) || isa<DoStmt>(S) || isa<WhileStmt>(S) ||
572 isa<CXXForRangeStmt>(S);
576CapturedZoneInfo captureZoneInfo(
const ExtractionZone &ExtZone) {
580 class ExtractionZoneVisitor
581 :
public clang::RecursiveASTVisitor<ExtractionZoneVisitor> {
583 ExtractionZoneVisitor(
const ExtractionZone &ExtZone) : ExtZone(ExtZone) {
584 TraverseDecl(
const_cast<FunctionDecl *
>(ExtZone.EnclosingFunction));
587 bool TraverseStmt(Stmt *S) {
590 bool IsRootStmt = ExtZone.isRootStmt(
const_cast<const Stmt *
>(S));
594 CurrentLocation = ZoneRelative::Inside;
595 addToLoopSwitchCounters(S, 1);
597 RecursiveASTVisitor::TraverseStmt(S);
598 addToLoopSwitchCounters(S, -1);
602 CurrentLocation = ZoneRelative::After;
608 void addToLoopSwitchCounters(Stmt *S,
int Increment) {
609 if (CurrentLocation != ZoneRelative::Inside)
612 CurNumberOfNestedLoops += Increment;
613 else if (isa<SwitchStmt>(S))
614 CurNumberOfSwitch += Increment;
617 bool VisitDecl(
Decl *D) {
618 Info.createDeclInfo(D, CurrentLocation);
622 bool VisitDeclRefExpr(DeclRefExpr *DRE) {
624 const Decl *D = DRE->getDecl();
625 auto *DeclInfo =
Info.getDeclInfoFor(D);
628 DeclInfo =
Info.createDeclInfo(D, ZoneRelative::OutsideFunc);
629 DeclInfo->markOccurence(CurrentLocation);
634 bool VisitReturnStmt(ReturnStmt *Return) {
635 if (CurrentLocation == ZoneRelative::Inside)
636 Info.HasReturnStmt =
true;
640 bool VisitBreakStmt(BreakStmt *Break) {
643 if (CurrentLocation == ZoneRelative::Inside &&
644 !(CurNumberOfNestedLoops || CurNumberOfSwitch))
645 Info.BrokenControlFlow =
true;
649 bool VisitContinueStmt(ContinueStmt *Continue) {
652 if (CurrentLocation == ZoneRelative::Inside && !CurNumberOfNestedLoops)
653 Info.BrokenControlFlow =
true;
656 CapturedZoneInfo
Info;
657 const ExtractionZone &ExtZone;
658 ZoneRelative CurrentLocation = ZoneRelative::Before;
661 unsigned CurNumberOfNestedLoops = 0;
662 unsigned CurNumberOfSwitch = 0;
664 ExtractionZoneVisitor Visitor(ExtZone);
665 CapturedZoneInfo Result = std::move(Visitor.Info);
666 Result.AlwaysReturns = alwaysReturns(ExtZone);
674bool createParameters(NewFunction &ExtractedFunc,
675 const CapturedZoneInfo &CapturedInfo) {
676 for (
const auto &KeyVal : CapturedInfo.DeclInfoMap) {
677 const auto &DeclInfo = KeyVal.second;
681 if (DeclInfo.DeclaredIn == ZoneRelative::Inside &&
682 DeclInfo.IsReferencedInPostZone)
684 if (!DeclInfo.IsReferencedInZone)
686 if (DeclInfo.DeclaredIn == ZoneRelative::Inside ||
687 DeclInfo.DeclaredIn == ZoneRelative::OutsideFunc)
690 const ValueDecl *VD = dyn_cast_or_null<ValueDecl>(DeclInfo.TheDecl);
693 if (!VD || isa<FunctionDecl>(DeclInfo.TheDecl))
696 QualType
TypeInfo = VD->getType().getNonReferenceType();
702 bool IsPassedByReference =
true;
704 ExtractedFunc.Parameters.push_back({std::string(VD->getName()),
TypeInfo,
706 DeclInfo.DeclIndex});
708 llvm::sort(ExtractedFunc.Parameters);
715tooling::ExtractionSemicolonPolicy
716getSemicolonPolicy(ExtractionZone &ExtZone,
const SourceManager &SM,
717 const LangOptions &LangOpts) {
719 SourceRange FuncBodyRange = {ExtZone.ZoneRange.getBegin(),
720 ExtZone.ZoneRange.getEnd().getLocWithOffset(-1)};
722 ExtZone.getLastRootStmt()->ASTNode.get<Stmt>(), FuncBodyRange, SM,
725 ExtZone.ZoneRange.setEnd(FuncBodyRange.getEnd().getLocWithOffset(1));
730bool generateReturnProperties(NewFunction &ExtractedFunc,
731 const FunctionDecl &EnclosingFunc,
732 const CapturedZoneInfo &CapturedInfo) {
736 if (CapturedInfo.HasReturnStmt) {
739 if (!CapturedInfo.AlwaysReturns)
741 QualType Ret = EnclosingFunc.getReturnType();
744 if (Ret->isDependentType())
746 ExtractedFunc.ReturnType = Ret;
750 ExtractedFunc.ReturnType = EnclosingFunc.getParentASTContext().VoidTy;
754void captureMethodInfo(NewFunction &ExtractedFunc,
755 const CXXMethodDecl *
Method) {
756 ExtractedFunc.Static =
Method->isStatic();
757 ExtractedFunc.Const =
Method->isConst();
758 ExtractedFunc.EnclosingClass =
Method->getParent();
763llvm::Expected<NewFunction> getExtractedFunction(ExtractionZone &ExtZone,
764 const SourceManager &SM,
765 const LangOptions &LangOpts) {
766 CapturedZoneInfo CapturedInfo = captureZoneInfo(ExtZone);
768 if (CapturedInfo.BrokenControlFlow)
769 return error(
"Cannot extract break/continue without corresponding "
770 "loop/switch statement.");
771 NewFunction ExtractedFunc(getSemicolonPolicy(ExtZone, SM, LangOpts),
774 ExtractedFunc.SyntacticDC =
775 ExtZone.EnclosingFunction->getLexicalDeclContext();
776 ExtractedFunc.SemanticDC = ExtZone.EnclosingFunction->getDeclContext();
777 ExtractedFunc.DefinitionQualifier = ExtZone.EnclosingFunction->getQualifier();
778 ExtractedFunc.Constexpr = ExtZone.EnclosingFunction->getConstexprKind();
781 llvm::dyn_cast<CXXMethodDecl>(ExtZone.EnclosingFunction))
782 captureMethodInfo(ExtractedFunc,
Method);
784 if (ExtZone.EnclosingFunction->isOutOfLine()) {
787 const auto *FirstOriginalDecl =
788 ExtZone.EnclosingFunction->getCanonicalDecl();
792 return error(
"Declaration is inside a macro");
793 ExtractedFunc.ForwardDeclarationPoint = DeclPos->getBegin();
794 ExtractedFunc.ForwardDeclarationSyntacticDC = ExtractedFunc.SemanticDC;
797 ExtractedFunc.BodyRange = ExtZone.ZoneRange;
798 ExtractedFunc.DefinitionPoint = ExtZone.getInsertionPoint();
800 ExtractedFunc.CallerReturnsValue = CapturedInfo.AlwaysReturns;
801 if (!createParameters(ExtractedFunc, CapturedInfo) ||
802 !generateReturnProperties(ExtractedFunc, *ExtZone.EnclosingFunction,
804 return error(
"Too complex to extract.");
805 return ExtractedFunc;
808class ExtractFunction :
public Tweak {
810 const char *id() const final;
811 bool prepare(const Selection &Inputs) override;
812 Expected<Effect> apply(const Selection &Inputs) override;
813 std::
string title()
const override {
return "Extract to function"; }
814 llvm::StringLiteral kind()
const override {
819 ExtractionZone ExtZone;
823tooling::Replacement replaceWithFuncCall(
const NewFunction &ExtractedFunc,
824 const SourceManager &SM,
825 const LangOptions &LangOpts) {
826 std::string FuncCall = ExtractedFunc.renderCall();
827 return tooling::Replacement(
828 SM, CharSourceRange(ExtractedFunc.BodyRange,
false), FuncCall, LangOpts);
831tooling::Replacement createFunctionDefinition(
const NewFunction &ExtractedFunc,
832 const SourceManager &SM) {
833 FunctionDeclKind DeclKind = InlineDefinition;
834 if (ExtractedFunc.ForwardDeclarationPoint)
835 DeclKind = OutOfLineDefinition;
836 std::string FunctionDef = ExtractedFunc.renderDeclaration(
837 DeclKind, *ExtractedFunc.SemanticDC, *ExtractedFunc.SyntacticDC, SM);
839 return tooling::Replacement(SM, ExtractedFunc.DefinitionPoint, 0,
843tooling::Replacement createForwardDeclaration(
const NewFunction &ExtractedFunc,
844 const SourceManager &SM) {
845 std::string FunctionDecl = ExtractedFunc.renderDeclaration(
846 ForwardDeclaration, *ExtractedFunc.SemanticDC,
847 *ExtractedFunc.ForwardDeclarationSyntacticDC, SM);
848 SourceLocation DeclPoint = *ExtractedFunc.ForwardDeclarationPoint;
850 return tooling::Replacement(SM, DeclPoint, 0, FunctionDecl);
854bool hasReturnStmt(
const ExtractionZone &ExtZone) {
855 class ReturnStmtVisitor
856 :
public clang::RecursiveASTVisitor<ReturnStmtVisitor> {
858 bool VisitReturnStmt(ReturnStmt *Return) {
866 for (
const Stmt *RootStmt : ExtZone.RootStmts) {
867 V.TraverseStmt(
const_cast<Stmt *
>(RootStmt));
874bool ExtractFunction::prepare(
const Selection &Inputs) {
875 const LangOptions &LangOpts = Inputs.AST->getLangOpts();
876 if (!LangOpts.CPlusPlus)
878 const Node *CommonAnc = Inputs.ASTSelection.commonAncestor();
879 const SourceManager &SM = Inputs.AST->getSourceManager();
880 auto MaybeExtZone = findExtractionZone(CommonAnc, SM, LangOpts);
882 (hasReturnStmt(*MaybeExtZone) && !alwaysReturns(*MaybeExtZone)))
886 if (MaybeExtZone->requiresHoisting(SM, Inputs.AST->getHeuristicResolver()))
889 ExtZone = std::move(*MaybeExtZone);
893Expected<Tweak::Effect> ExtractFunction::apply(
const Selection &Inputs) {
894 const SourceManager &SM = Inputs.AST->getSourceManager();
895 const LangOptions &LangOpts = Inputs.AST->getLangOpts();
896 auto ExtractedFunc = getExtractedFunction(ExtZone, SM, LangOpts);
899 return ExtractedFunc.takeError();
900 tooling::Replacements Edit;
901 if (
auto Err = Edit.add(createFunctionDefinition(*ExtractedFunc, SM)))
902 return std::move(Err);
903 if (
auto Err = Edit.add(replaceWithFuncCall(*ExtractedFunc, SM, LangOpts)))
904 return std::move(Err);
906 if (
auto FwdLoc = ExtractedFunc->ForwardDeclarationPoint) {
909 if (SM.isWrittenInSameFile(ExtractedFunc->DefinitionPoint, *FwdLoc)) {
910 if (
auto Err = Edit.add(createForwardDeclaration(*ExtractedFunc, SM)))
911 return std::move(Err);
913 auto MultiFileEffect = Effect::mainFileEdit(SM, std::move(Edit));
914 if (!MultiFileEffect)
915 return MultiFileEffect.takeError();
917 tooling::Replacements OtherEdit(
918 createForwardDeclaration(*ExtractedFunc, SM));
919 if (
auto PathAndEdit =
921 MultiFileEffect->ApplyEdits.try_emplace(PathAndEdit->first,
922 PathAndEdit->second);
924 return PathAndEdit.takeError();
925 return MultiFileEffect;
928 return Effect::mainFileEdit(SM, std::move(Edit));
ArrayRef< const ParmVarDecl * > Parameters
const FunctionDecl * Decl
llvm::SmallString< 256U > Name
std::vector< std::pair< std::string, std::string > > Attributes
std::vector< const char * > Expected
::clang::DynTypedNode Node
#define REGISTER_TWEAK(Subclass)
std::optional< SourceRange > toHalfOpenFileRange(const SourceManager &SM, const LangOptions &LangOpts, SourceRange R)
Turns a token range into a half-open range and checks its correctness.
std::string printType(const QualType QT, const DeclContext &CurContext, const llvm::StringRef Placeholder)
Returns a QualType as string.
void findExplicitReferences(const Stmt *S, llvm::function_ref< void(ReferenceLoc)> Out, const HeuristicResolver *Resolver)
Recursively traverse S and report all references explicitly written in the code.
llvm::Error error(std::error_code EC, const char *Fmt, Ts &&... Vals)
llvm::StringRef toSourceCode(const SourceManager &SM, SourceRange R)
Returns the source code covered by the source range.
bool operator<(const Ref &L, const Ref &R)
@ Parameter
An inlay hint that is for a parameter.
===– Representation.cpp - ClangDoc Representation --------—*- C++ -*-===//
static const llvm::StringLiteral REFACTOR_KIND
static llvm::Expected< std::pair< Path, Edit > > fileEdit(const SourceManager &SM, FileID FID, tooling::Replacements Replacements)
Path is the absolute, symlink-resolved path for the file pointed by FID in SM.