56#include "clang/AST/ASTContext.h"
57#include "clang/AST/Decl.h"
58#include "clang/AST/DeclBase.h"
59#include "clang/AST/NestedNameSpecifier.h"
60#include "clang/AST/RecursiveASTVisitor.h"
61#include "clang/AST/Stmt.h"
62#include "clang/Basic/LangOptions.h"
63#include "clang/Basic/SourceLocation.h"
64#include "clang/Basic/SourceManager.h"
65#include "clang/Tooling/Core/Replacement.h"
66#include "clang/Tooling/Refactoring/Extract/SourceExtraction.h"
67#include "llvm/ADT/STLExtras.h"
68#include "llvm/ADT/SmallSet.h"
69#include "llvm/ADT/SmallVector.h"
70#include "llvm/ADT/StringRef.h"
71#include "llvm/Support/Casting.h"
72#include "llvm/Support/Error.h"
73#include "llvm/Support/raw_os_ostream.h"
80using Node = SelectionTree::Node;
85enum class ZoneRelative {
92enum FunctionDeclKind {
101bool isRootStmt(
const Node *N) {
102 if (!N->ASTNode.get<Stmt>())
124const Node *getParentOfRootStmts(
const Node *CommonAnc) {
128 switch (CommonAnc->Selected) {
140 Parent = CommonAnc->Parent;
145 if (
Parent->ASTNode.get<DeclStmt>())
150 return llvm::all_of(
Parent->Children, isRootStmt) ?
Parent :
nullptr;
154struct ExtractionZone {
166 SourceLocation getInsertionPoint()
const {
169 bool isRootStmt(
const Stmt *S)
const;
172 const Node *getLastRootStmt()
const {
return Parent->Children.back(); }
178 bool requiresHoisting(
const SourceManager &SM,
179 const HeuristicResolver *Resolver)
const {
181 llvm::SmallSet<const Decl *, 1> DeclsInExtZone;
182 for (
auto *RootStmt : RootStmts) {
185 [&DeclsInExtZone](
const ReferenceLoc &
Loc) {
188 DeclsInExtZone.insert(
Loc.Targets.front());
193 if (DeclsInExtZone.empty())
197 if (SM.isBeforeInTranslationUnit(S->getSourceRange().getEnd(),
200 bool HasPostUse =
false;
203 [&](
const ReferenceLoc &
Loc) {
205 SM.isBeforeInTranslationUnit(
Loc.NameLoc,
ZoneRange.getEnd()))
207 HasPostUse = llvm::any_of(
Loc.Targets,
208 [&DeclsInExtZone](
const Decl *Target) {
209 return DeclsInExtZone.contains(Target);
224bool alwaysReturns(
const ExtractionZone &EZ) {
225 const Stmt *Last = EZ.getLastRootStmt()->ASTNode.get<Stmt>();
227 while (
const auto *CS = llvm::dyn_cast<CompoundStmt>(Last)) {
228 if (CS->body_empty())
230 Last = CS->body_back();
232 return llvm::isa<ReturnStmt>(Last);
235bool ExtractionZone::isRootStmt(
const Stmt *S)
const {
240const FunctionDecl *findEnclosingFunction(
const Node *CommonAnc) {
242 for (
const Node *CurNode = CommonAnc; CurNode; CurNode = CurNode->Parent) {
244 if (CurNode->ASTNode.get<LambdaExpr>())
246 if (
const FunctionDecl *Func = CurNode->ASTNode.get<FunctionDecl>()) {
248 if (Func->isTemplated())
250 if (!Func->getBody())
252 for (
const auto *S : Func->getBody()->children()) {
267std::optional<SourceRange> findZoneRange(
const Node *
Parent,
268 const SourceManager &SM,
269 const LangOptions &LangOpts) {
272 SM, LangOpts,
Parent->Children.front()->ASTNode.getSourceRange()))
273 SR.setBegin(BeginFileRange->getBegin());
277 SM, LangOpts,
Parent->Children.back()->ASTNode.getSourceRange()))
278 SR.setEnd(EndFileRange->getEnd());
288std::optional<SourceRange>
290 const SourceManager &SM,
291 const LangOptions &LangOpts) {
297bool validSingleChild(
const Node *Child,
const FunctionDecl *EnclosingFunc) {
301 if (Child->ASTNode.get<Expr>())
304 assert(EnclosingFunc->hasBody() &&
305 "We should always be extracting from a function body.");
306 if (Child->ASTNode.get<Stmt>() == EnclosingFunc->getBody())
313std::optional<ExtractionZone> findExtractionZone(
const Node *CommonAnc,
314 const SourceManager &SM,
315 const LangOptions &LangOpts) {
316 ExtractionZone ExtZone;
317 ExtZone.Parent = getParentOfRootStmts(CommonAnc);
318 if (!ExtZone.Parent || ExtZone.Parent->Children.empty())
320 ExtZone.EnclosingFunction = findEnclosingFunction(ExtZone.Parent);
321 if (!ExtZone.EnclosingFunction)
325 if (ExtZone.Parent->Children.size() == 1 &&
326 !validSingleChild(ExtZone.getLastRootStmt(), ExtZone.EnclosingFunction))
329 computeEnclosingFuncRange(ExtZone.EnclosingFunction, SM, LangOpts))
330 ExtZone.EnclosingFuncRange = *FuncRange;
331 if (
auto ZoneRange = findZoneRange(ExtZone.Parent, SM, LangOpts))
333 if (ExtZone.EnclosingFuncRange.isInvalid() || ExtZone.ZoneRange.isInvalid())
336 for (
const Node *Child : ExtZone.Parent->Children)
337 ExtZone.RootStmts.insert(Child->ASTNode.get<Stmt>());
350 std::string render(
const DeclContext *Context)
const;
351 bool operator<(
const Parameter &Other)
const {
355 std::string
Name =
"extracted";
368 ConstexprSpecKind
Constexpr = ConstexprSpecKind::Unspecified;
374 const LangOptions *LangOpts;
375 NewFunction(tooling::ExtractionSemicolonPolicy SemicolonPolicy,
376 const LangOptions *LangOpts)
379 std::string renderCall()
const;
381 std::string renderDeclaration(FunctionDeclKind
K,
382 const DeclContext &SemanticDC,
383 const DeclContext &SyntacticDC,
384 const SourceManager &SM)
const;
388 renderParametersForDeclaration(
const DeclContext &Enclosing)
const;
389 std::string renderParametersForCall()
const;
390 std::string renderSpecifiers(FunctionDeclKind
K)
const;
391 std::string renderQualifiers()
const;
392 std::string renderDeclarationName(FunctionDeclKind
K)
const;
394 std::string getFuncBody(
const SourceManager &SM)
const;
397std::string NewFunction::renderParametersForDeclaration(
398 const DeclContext &Enclosing)
const {
400 bool NeedCommaBefore =
false;
404 NeedCommaBefore =
true;
405 Result += P.render(&Enclosing);
410std::string NewFunction::renderParametersForCall()
const {
412 bool NeedCommaBefore =
false;
416 NeedCommaBefore =
true;
422std::string NewFunction::renderSpecifiers(FunctionDeclKind
K)
const {
425 if (
Static &&
K != FunctionDeclKind::OutOfLineDefinition) {
430 case ConstexprSpecKind::Unspecified:
431 case ConstexprSpecKind::Constinit:
433 case ConstexprSpecKind::Constexpr:
436 case ConstexprSpecKind::Consteval:
444std::string NewFunction::renderQualifiers()
const {
454std::string NewFunction::renderDeclarationName(FunctionDeclKind
K)
const {
459 std::string QualifierName;
460 llvm::raw_string_ostream Oss(QualifierName);
462 return llvm::formatv(
"{0}{1}", QualifierName,
Name);
465std::string NewFunction::renderCall()
const {
468 renderParametersForCall(),
472std::string NewFunction::renderDeclaration(FunctionDeclKind
K,
475 const SourceManager &SM)
const {
476 std::string
Declaration = std::string(llvm::formatv(
477 "{0}{1} {2}({3}){4}", renderSpecifiers(
K),
479 renderParametersForDeclaration(
SemanticDC), renderQualifiers()));
482 case ForwardDeclaration:
483 return std::string(llvm::formatv(
"{0};\n",
Declaration));
484 case OutOfLineDefinition:
485 case InlineDefinition:
487 llvm::formatv(
"{0} {\n{1}\n}\n",
Declaration, getFuncBody(SM)));
490 llvm_unreachable(
"Unsupported FunctionDeclKind enum");
493std::string NewFunction::getFuncBody(
const SourceManager &SM)
const {
502std::string NewFunction::Parameter::render(
const DeclContext *Context)
const {
507struct CapturedZoneInfo {
508 struct DeclInformation {
516 DeclInformation(
const Decl *TheDecl, ZoneRelative DeclaredIn,
520 void markOccurence(ZoneRelative ReferenceLoc);
531 DeclInformation *createDeclInfo(
const Decl *D, ZoneRelative RelativeLoc);
532 DeclInformation *getDeclInfoFor(
const Decl *D);
535CapturedZoneInfo::DeclInformation *
536CapturedZoneInfo::createDeclInfo(
const Decl *D, ZoneRelative RelativeLoc) {
539 {D, DeclInformation(D, RelativeLoc,
DeclInfoMap.size())});
541 return &InsertionResult.first->second;
544CapturedZoneInfo::DeclInformation *
545CapturedZoneInfo::getDeclInfoFor(
const Decl *D) {
550 return &Iter->second;
553void CapturedZoneInfo::DeclInformation::markOccurence(
554 ZoneRelative ReferenceLoc) {
555 switch (ReferenceLoc) {
556 case ZoneRelative::Inside:
559 case ZoneRelative::After:
567bool isLoop(
const Stmt *S) {
568 return isa<ForStmt>(S) || isa<DoStmt>(S) || isa<WhileStmt>(S) ||
569 isa<CXXForRangeStmt>(S);
573CapturedZoneInfo captureZoneInfo(
const ExtractionZone &ExtZone) {
577 class ExtractionZoneVisitor
578 :
public clang::RecursiveASTVisitor<ExtractionZoneVisitor> {
580 ExtractionZoneVisitor(
const ExtractionZone &ExtZone) : ExtZone(ExtZone) {
581 TraverseDecl(
const_cast<FunctionDecl *
>(ExtZone.EnclosingFunction));
584 bool TraverseStmt(Stmt *S) {
587 bool IsRootStmt = ExtZone.isRootStmt(
const_cast<const Stmt *
>(S));
591 CurrentLocation = ZoneRelative::Inside;
592 addToLoopSwitchCounters(S, 1);
594 RecursiveASTVisitor::TraverseStmt(S);
595 addToLoopSwitchCounters(S, -1);
599 CurrentLocation = ZoneRelative::After;
605 void addToLoopSwitchCounters(Stmt *S,
int Increment) {
606 if (CurrentLocation != ZoneRelative::Inside)
609 CurNumberOfNestedLoops += Increment;
610 else if (isa<SwitchStmt>(S))
611 CurNumberOfSwitch += Increment;
614 bool VisitDecl(
Decl *D) {
615 Info.createDeclInfo(D, CurrentLocation);
619 bool VisitDeclRefExpr(DeclRefExpr *DRE) {
621 const Decl *D = DRE->getDecl();
622 auto *DeclInfo =
Info.getDeclInfoFor(D);
625 DeclInfo =
Info.createDeclInfo(D, ZoneRelative::OutsideFunc);
626 DeclInfo->markOccurence(CurrentLocation);
631 bool VisitReturnStmt(ReturnStmt *Return) {
632 if (CurrentLocation == ZoneRelative::Inside)
633 Info.HasReturnStmt =
true;
637 bool VisitBreakStmt(BreakStmt *Break) {
640 if (CurrentLocation == ZoneRelative::Inside &&
641 !(CurNumberOfNestedLoops || CurNumberOfSwitch))
642 Info.BrokenControlFlow =
true;
646 bool VisitContinueStmt(ContinueStmt *Continue) {
649 if (CurrentLocation == ZoneRelative::Inside && !CurNumberOfNestedLoops)
650 Info.BrokenControlFlow =
true;
653 CapturedZoneInfo
Info;
654 const ExtractionZone &ExtZone;
655 ZoneRelative CurrentLocation = ZoneRelative::Before;
658 unsigned CurNumberOfNestedLoops = 0;
659 unsigned CurNumberOfSwitch = 0;
661 ExtractionZoneVisitor Visitor(ExtZone);
662 CapturedZoneInfo Result = std::move(Visitor.Info);
663 Result.AlwaysReturns = alwaysReturns(ExtZone);
671bool createParameters(NewFunction &ExtractedFunc,
672 const CapturedZoneInfo &CapturedInfo) {
673 for (
const auto &KeyVal : CapturedInfo.DeclInfoMap) {
674 const auto &DeclInfo = KeyVal.second;
678 if (DeclInfo.DeclaredIn == ZoneRelative::Inside &&
679 DeclInfo.IsReferencedInPostZone)
681 if (!DeclInfo.IsReferencedInZone)
683 if (DeclInfo.DeclaredIn == ZoneRelative::Inside ||
684 DeclInfo.DeclaredIn == ZoneRelative::OutsideFunc)
687 const ValueDecl *VD = dyn_cast_or_null<ValueDecl>(DeclInfo.TheDecl);
690 if (!VD || isa<FunctionDecl>(DeclInfo.TheDecl))
693 QualType
TypeInfo = VD->getType().getNonReferenceType();
699 bool IsPassedByReference =
true;
701 ExtractedFunc.Parameters.push_back({std::string(VD->getName()),
TypeInfo,
703 DeclInfo.DeclIndex});
705 llvm::sort(ExtractedFunc.Parameters);
712tooling::ExtractionSemicolonPolicy
713getSemicolonPolicy(ExtractionZone &ExtZone,
const SourceManager &SM,
714 const LangOptions &LangOpts) {
716 SourceRange FuncBodyRange = {ExtZone.ZoneRange.getBegin(),
717 ExtZone.ZoneRange.getEnd().getLocWithOffset(-1)};
719 ExtZone.getLastRootStmt()->ASTNode.get<Stmt>(), FuncBodyRange, SM,
722 ExtZone.ZoneRange.setEnd(FuncBodyRange.getEnd().getLocWithOffset(1));
727bool generateReturnProperties(NewFunction &ExtractedFunc,
728 const FunctionDecl &EnclosingFunc,
729 const CapturedZoneInfo &CapturedInfo) {
733 if (CapturedInfo.HasReturnStmt) {
736 if (!CapturedInfo.AlwaysReturns)
738 QualType Ret = EnclosingFunc.getReturnType();
741 if (Ret->isDependentType())
743 ExtractedFunc.ReturnType = Ret;
747 ExtractedFunc.ReturnType = EnclosingFunc.getParentASTContext().VoidTy;
751void captureMethodInfo(NewFunction &ExtractedFunc,
752 const CXXMethodDecl *
Method) {
753 ExtractedFunc.Static =
Method->isStatic();
754 ExtractedFunc.Const =
Method->isConst();
755 ExtractedFunc.EnclosingClass =
Method->getParent();
760llvm::Expected<NewFunction> getExtractedFunction(ExtractionZone &ExtZone,
761 const SourceManager &SM,
762 const LangOptions &LangOpts) {
763 CapturedZoneInfo CapturedInfo = captureZoneInfo(ExtZone);
765 if (CapturedInfo.BrokenControlFlow)
766 return error(
"Cannot extract break/continue without corresponding "
767 "loop/switch statement.");
768 NewFunction ExtractedFunc(getSemicolonPolicy(ExtZone, SM, LangOpts),
771 ExtractedFunc.SyntacticDC =
772 ExtZone.EnclosingFunction->getLexicalDeclContext();
773 ExtractedFunc.SemanticDC = ExtZone.EnclosingFunction->getDeclContext();
774 ExtractedFunc.DefinitionQualifier = ExtZone.EnclosingFunction->getQualifier();
775 ExtractedFunc.Constexpr = ExtZone.EnclosingFunction->getConstexprKind();
778 llvm::dyn_cast<CXXMethodDecl>(ExtZone.EnclosingFunction))
779 captureMethodInfo(ExtractedFunc,
Method);
781 if (ExtZone.EnclosingFunction->isOutOfLine()) {
784 const auto *FirstOriginalDecl =
785 ExtZone.EnclosingFunction->getCanonicalDecl();
789 return error(
"Declaration is inside a macro");
790 ExtractedFunc.ForwardDeclarationPoint = DeclPos->getBegin();
791 ExtractedFunc.ForwardDeclarationSyntacticDC = ExtractedFunc.SemanticDC;
794 ExtractedFunc.BodyRange = ExtZone.ZoneRange;
795 ExtractedFunc.DefinitionPoint = ExtZone.getInsertionPoint();
797 ExtractedFunc.CallerReturnsValue = CapturedInfo.AlwaysReturns;
798 if (!createParameters(ExtractedFunc, CapturedInfo) ||
799 !generateReturnProperties(ExtractedFunc, *ExtZone.EnclosingFunction,
801 return error(
"Too complex to extract.");
802 return ExtractedFunc;
805class ExtractFunction :
public Tweak {
807 const char *id() const final;
808 bool prepare(const Selection &Inputs) override;
809 Expected<Effect> apply(const Selection &Inputs) override;
810 std::
string title()
const override {
return "Extract to function"; }
811 llvm::StringLiteral kind()
const override {
816 ExtractionZone ExtZone;
820tooling::Replacement replaceWithFuncCall(
const NewFunction &ExtractedFunc,
821 const SourceManager &SM,
822 const LangOptions &LangOpts) {
823 std::string FuncCall = ExtractedFunc.renderCall();
824 return tooling::Replacement(
825 SM, CharSourceRange(ExtractedFunc.BodyRange,
false), FuncCall, LangOpts);
828tooling::Replacement createFunctionDefinition(
const NewFunction &ExtractedFunc,
829 const SourceManager &SM) {
830 FunctionDeclKind DeclKind = InlineDefinition;
831 if (ExtractedFunc.ForwardDeclarationPoint)
832 DeclKind = OutOfLineDefinition;
833 std::string FunctionDef = ExtractedFunc.renderDeclaration(
834 DeclKind, *ExtractedFunc.SemanticDC, *ExtractedFunc.SyntacticDC, SM);
836 return tooling::Replacement(SM, ExtractedFunc.DefinitionPoint, 0,
840tooling::Replacement createForwardDeclaration(
const NewFunction &ExtractedFunc,
841 const SourceManager &SM) {
842 std::string FunctionDecl = ExtractedFunc.renderDeclaration(
843 ForwardDeclaration, *ExtractedFunc.SemanticDC,
844 *ExtractedFunc.ForwardDeclarationSyntacticDC, SM);
845 SourceLocation DeclPoint = *ExtractedFunc.ForwardDeclarationPoint;
847 return tooling::Replacement(SM, DeclPoint, 0, FunctionDecl);
851bool hasReturnStmt(
const ExtractionZone &ExtZone) {
852 class ReturnStmtVisitor
853 :
public clang::RecursiveASTVisitor<ReturnStmtVisitor> {
855 bool VisitReturnStmt(ReturnStmt *Return) {
863 for (
const Stmt *RootStmt : ExtZone.RootStmts) {
864 V.TraverseStmt(
const_cast<Stmt *
>(RootStmt));
871bool ExtractFunction::prepare(
const Selection &Inputs) {
872 const LangOptions &LangOpts = Inputs.AST->getLangOpts();
873 if (!LangOpts.CPlusPlus)
875 const Node *CommonAnc = Inputs.ASTSelection.commonAncestor();
876 const SourceManager &SM = Inputs.AST->getSourceManager();
877 auto MaybeExtZone = findExtractionZone(CommonAnc, SM, LangOpts);
879 (hasReturnStmt(*MaybeExtZone) && !alwaysReturns(*MaybeExtZone)))
883 if (MaybeExtZone->requiresHoisting(SM, Inputs.AST->getHeuristicResolver()))
886 ExtZone = std::move(*MaybeExtZone);
890Expected<Tweak::Effect> ExtractFunction::apply(
const Selection &Inputs) {
891 const SourceManager &SM = Inputs.AST->getSourceManager();
892 const LangOptions &LangOpts = Inputs.AST->getLangOpts();
893 auto ExtractedFunc = getExtractedFunction(ExtZone, SM, LangOpts);
896 return ExtractedFunc.takeError();
897 tooling::Replacements Edit;
898 if (
auto Err = Edit.add(createFunctionDefinition(*ExtractedFunc, SM)))
899 return std::move(Err);
900 if (
auto Err = Edit.add(replaceWithFuncCall(*ExtractedFunc, SM, LangOpts)))
901 return std::move(Err);
903 if (
auto FwdLoc = ExtractedFunc->ForwardDeclarationPoint) {
906 if (SM.isWrittenInSameFile(ExtractedFunc->DefinitionPoint, *FwdLoc)) {
907 if (
auto Err = Edit.add(createForwardDeclaration(*ExtractedFunc, SM)))
908 return std::move(Err);
910 auto MultiFileEffect = Effect::mainFileEdit(SM, std::move(Edit));
911 if (!MultiFileEffect)
912 return MultiFileEffect.takeError();
914 tooling::Replacements OtherEdit(
915 createForwardDeclaration(*ExtractedFunc, SM));
918 MultiFileEffect->ApplyEdits.try_emplace(PathAndEdit->first,
919 PathAndEdit->second);
921 return PathAndEdit.takeError();
922 return MultiFileEffect;
925 return Effect::mainFileEdit(SM, std::move(Edit));
ArrayRef< const ParmVarDecl * > Parameters
const FunctionDecl * Decl
llvm::SmallString< 256U > Name
std::vector< std::pair< std::string, std::string > > Attributes
std::vector< const char * > Expected
::clang::DynTypedNode Node
#define REGISTER_TWEAK(Subclass)
std::optional< SourceRange > toHalfOpenFileRange(const SourceManager &SM, const LangOptions &LangOpts, SourceRange R)
Turns a token range into a half-open range and checks its correctness.
std::string printType(const QualType QT, const DeclContext &CurContext, const llvm::StringRef Placeholder)
Returns a QualType as string.
void findExplicitReferences(const Stmt *S, llvm::function_ref< void(ReferenceLoc)> Out, const HeuristicResolver *Resolver)
Recursively traverse S and report all references explicitly written in the code.
llvm::Error error(std::error_code EC, const char *Fmt, Ts &&... Vals)
llvm::StringRef toSourceCode(const SourceManager &SM, SourceRange R)
Returns the source code covered by the source range.
bool operator<(const Ref &L, const Ref &R)
@ Parameter
An inlay hint that is for a parameter.
===– Representation.cpp - ClangDoc Representation --------—*- C++ -*-===//
static const llvm::StringLiteral REFACTOR_KIND
static llvm::Expected< std::pair< Path, Edit > > fileEdit(const SourceManager &SM, FileID FID, tooling::Replacements Replacements)
Path is the absolute, symlink-resolved path for the file pointed by FID in SM.