18 #include "llvm/IR/Intrinsics.h"
19 #include "llvm/IR/MDBuilder.h"
20 #include "llvm/Support/CommandLine.h"
21 #include "llvm/Support/Endian.h"
22 #include "llvm/Support/FileSystem.h"
23 #include "llvm/Support/MD5.h"
// Hidden command-line flag described as "Enable value profiling"; it is
// initialized to false.
25 static llvm::cl::opt<bool>
27 llvm::cl::desc(
"Enable value profiling"),
28 llvm::cl::Hidden, llvm::cl::init(
false));
30 using namespace clang;
31 using namespace CodeGen;
// Computes FuncName with llvm::getPGOFuncName. The version argument is the
// loaded profile's version when a PGO reader is available, and
// llvm::IndexedInstrProf::Version otherwise. Then creates FuncNameVar for that
// name in CGM's module with the given Linkage.
33 void CodeGenPGO::setFuncName(StringRef Name,
34 llvm::GlobalValue::LinkageTypes
Linkage) {
35 llvm::IndexedInstrProfReader *PGOReader = CGM.
getPGOReader();
36 FuncName = llvm::getPGOFuncName(
38 PGOReader ? PGOReader->getVersion() : llvm::IndexedInstrProf::Version);
// The name variable is created in the current module with the caller's linkage.
42 FuncNameVar = llvm::createPGOFuncNameVar(CGM.
getModule(),
Linkage, FuncName);
// Overload taking an llvm::Function. It uses Fn's own name and linkage, then
// attaches the resulting PGO function name to Fn as metadata.
45 void CodeGenPGO::setFuncName(llvm::Function *Fn) {
46 setFuncName(Fn->getName(), Fn->getLinkage());
48 llvm::createPGOFuncNameMetadata(*Fn, FuncName);
// Width in bits of one HashType entry. combine() shifts Working left by this
// amount before OR-ing in each new type.
78 static const int NumBitsPerType = 6;
// Number of HashType entries that fit in one uint64_t word (64 / 6 == 10).
79 static const unsigned NumTypesPerWord =
sizeof(uint64_t) * 8 / NumBitsPerType;
// Exclusive upper bound for a HashType value (1 << 6 == 64). combine()
// asserts that every type is below it.
80 static const unsigned TooBig = 1u << NumBitsPerType;
// Statement kinds and structural markers folded into the hash, stored as
// unsigned char.
90 enum HashType :
unsigned char {
// Compile-time guard that the enumeration still fits within TooBig.
131 static_assert(LastHashType <= TooBig,
"Too many types in HashType");
// The working word and the count of combined types both start at zero.
134 : Working(0), Count(0), HashVersion(HashVersion) {}
// Folds one more HashType into the running hash.
135 void combine(HashType
Type);
// Namespace-scope definitions for PGOHash's in-class-initialized static
// constants, which give them storage in case they are odr-used.
139 const int PGOHash::NumBitsPerType;
140 const unsigned PGOHash::NumTypesPerWord;
141 const unsigned PGOHash::TooBig;
// Chooses a PGOHashVersion from PGOReader->getVersion(), with separate cases
// for profile versions <= 4 and <= 5.
// NOTE(review): PGOReader is dereferenced without a null check here, so
// callers presumably must pass a non-null reader. Confirm at call sites.
144 static PGOHashVersion getPGOHashVersion(llvm::IndexedInstrProfReader *PGOReader,
146 if (PGOReader->getVersion() <= 4)
148 if (PGOReader->getVersion() <= 5)
// State of the counter-mapping walker:
// - NextCounter is the next counter index to hand out.
// - CounterMap receives the Stmt -> counter-index assignments.
// - ProfileVersion selects format-dependent counters.
158 unsigned NextCounter;
162 llvm::DenseMap<const Stmt *, unsigned> &CounterMap;
164 uint64_t ProfileVersion;
// Counters start at 0, and Hash is initialized with the requested hash version.
166 MapRegionCounters(
PGOHashVersion HashVersion, uint64_t ProfileVersion,
167 llvm::DenseMap<const Stmt *, unsigned> &CounterMap)
168 : NextCounter(0), Hash(HashVersion), CounterMap(CounterMap),
169 ProfileVersion(ProfileVersion) {}
// Block expressions are not descended into. Returning true skips their
// children while letting the overall traversal continue.
173 bool TraverseBlockExpr(
BlockExpr *BE) {
return true; }
// For lambdas, only the captures are traversed, each paired with its
// initializer expression.
176 for (
auto C : zip(
LE->captures(),
LE->capture_inits()))
177 TraverseLambdaCapture(
LE, &std::get<0>(C), std::get<1>(C));
// Captured statements are not descended into either.
180 bool TraverseCapturedStmt(
CapturedStmt *CS) {
return true; }
// For the function-like declaration kinds handled in this switch (C++
// methods/constructors/destructors/conversions, ObjC methods, ...), D's body
// is mapped to the next counter index.
182 bool VisitDecl(
const Decl *D) {
187 case Decl::CXXMethod:
188 case Decl::CXXConstructor:
189 case Decl::CXXDestructor:
190 case Decl::CXXConversion:
191 case Decl::ObjCMethod:
194 CounterMap[D->
getBody()] = NextCounter++;
// Returns the hash type for S. A statement that receives a counter is mapped
// to NextCounter, which is then incremented.
202 PGOHash::HashType updateCounterMappings(
Stmt *S) {
205 CounterMap[S] = NextCounter++;
// With profile format Version7 or newer, the RHS of a logical operator
// (&& / ||) gets its own counter.
214 if (ProfileVersion >= llvm::IndexedInstrProf::Version7)
215 if (S->isLogicalOp() &&
217 CounterMap[S->getRHS()] = NextCounter++;
218 return Base::VisitBinaryOperator(S);
// Generic statement visitor. It gets S's hash type from
// updateCounterMappings, possibly refined with getHashType for the active
// hash version.
222 bool VisitStmt(
Stmt *S) {
223 auto Type = updateCounterMappings(S);
225 Type = getHashType(Hash.getHashVersion(), S);
// Custom if-statement traversal. It may defer to Base::TraverseIfStmt;
// otherwise it combines IfThenBranch / IfElseBranch markers and a closing
// EndOfScope marker into the hash.
231 bool TraverseIfStmt(
IfStmt *If) {
234 return Base::TraverseIfStmt(If);
242 Hash.combine(PGOHash::IfThenBranch);
244 Hash.combine(PGOHash::IfElseBranch);
247 Hash.combine(PGOHash::EndOfScope);
// Defines a Traverse<N> override that runs the base traversal. For every hash
// version except PGO_HASH_V1, it then appends an EndOfScope marker to the hash.
254 #define DEFINE_NESTABLE_TRAVERSAL(N) \
255 bool Traverse##N(N *S) { \
256 Base::Traverse##N(S); \
257 if (Hash.getHashVersion() != PGO_HASH_V1) \
258 Hash.combine(PGOHash::EndOfScope); \
272 switch (S->getStmtClass()) {
// Statement classes that map to a dedicated hash type.
275 case Stmt::LabelStmtClass:
276 return PGOHash::LabelStmt;
277 case Stmt::WhileStmtClass:
278 return PGOHash::WhileStmt;
279 case Stmt::DoStmtClass:
280 return PGOHash::DoStmt;
281 case Stmt::ForStmtClass:
282 return PGOHash::ForStmt;
283 case Stmt::CXXForRangeStmtClass:
284 return PGOHash::CXXForRangeStmt;
285 case Stmt::ObjCForCollectionStmtClass:
286 return PGOHash::ObjCForCollectionStmt;
287 case Stmt::SwitchStmtClass:
288 return PGOHash::SwitchStmt;
289 case Stmt::CaseStmtClass:
290 return PGOHash::CaseStmt;
291 case Stmt::DefaultStmtClass:
292 return PGOHash::DefaultStmt;
293 case Stmt::IfStmtClass:
294 return PGOHash::IfStmt;
295 case Stmt::CXXTryStmtClass:
296 return PGOHash::CXXTryStmt;
297 case Stmt::CXXCatchStmtClass:
298 return PGOHash::CXXCatchStmt;
299 case Stmt::ConditionalOperatorClass:
300 return PGOHash::ConditionalOperator;
301 case Stmt::BinaryConditionalOperatorClass:
302 return PGOHash::BinaryConditionalOperator;
// Binary operators: logical (&&, ||) and comparison opcodes get their own
// hash types.
303 case Stmt::BinaryOperatorClass: {
306 return PGOHash::BinaryOperatorLAnd;
308 return PGOHash::BinaryOperatorLOr;
314 return PGOHash::BinaryOperatorLT;
316 return PGOHash::BinaryOperatorGT;
318 return PGOHash::BinaryOperatorLE;
320 return PGOHash::BinaryOperatorGE;
322 return PGOHash::BinaryOperatorEQ;
324 return PGOHash::BinaryOperatorNE;
// Second mapping covers control-transfer statements (goto, indirect goto,
// break, continue, return, throw) and logical-not.
// NOTE(review): presumably applied only for some hash versions. Confirm.
332 switch (S->getStmtClass()) {
335 case Stmt::GotoStmtClass:
336 return PGOHash::GotoStmt;
337 case Stmt::IndirectGotoStmtClass:
338 return PGOHash::IndirectGotoStmt;
339 case Stmt::BreakStmtClass:
340 return PGOHash::BreakStmt;
341 case Stmt::ContinueStmtClass:
342 return PGOHash::ContinueStmt;
343 case Stmt::ReturnStmtClass:
344 return PGOHash::ReturnStmt;
345 case Stmt::CXXThrowExprClass:
346 return PGOHash::ThrowExpr;
347 case Stmt::UnaryOperatorClass: {
350 return PGOHash::UnaryOperatorLNot;
// Walker that computes execution counts. CurrentCount tracks the count of the
// code being visited, and CountMap receives per-statement counts.
// RecordNextStmtCount: when set, the next statement visited gets its count
// recorded. It is set after return, goto, break, continue, loops and similar
// statements, and is consumed by RecordStmtCount.
368 bool RecordNextStmtCount;
371 uint64_t CurrentCount;
374 llvm::DenseMap<const Stmt *, uint64_t> &
CountMap;
// Per loop/switch accumulator for the counts that reach break and continue
// statements inside it.
377 struct BreakContinue {
379 uint64_t ContinueCount;
380 BreakContinue() : BreakCount(0), ContinueCount(0) {}
384 ComputeRegionCounts(llvm::DenseMap<const Stmt *, uint64_t> &
CountMap,
388 void RecordStmtCount(
const Stmt *S) {
// If a count was requested, record it for S (presumably CurrentCount) and
// clear the request.
389 if (RecordNextStmtCount) {
391 RecordNextStmtCount =
false;
// Sets CurrentCount; callers use the return value as the new count.
396 uint64_t setCount(uint64_t Count) {
397 CurrentCount = Count;
// Default: visit every child statement.
401 void VisitStmt(
const Stmt *S) {
403 for (
const Stmt *Child : S->children())
434 void VisitBlockDecl(
const BlockDecl *D) {
// return: visit the returned value, then request a fresh count for whatever
// follows.
443 if (S->getRetValue())
444 Visit(S->getRetValue());
446 RecordNextStmtCount =
true;
454 RecordNextStmtCount =
true;
457 void VisitGotoStmt(
const GotoStmt *S) {
460 RecordNextStmtCount =
true;
// Labels drop any pending record request before visiting the labeled
// statement.
463 void VisitLabelStmt(
const LabelStmt *S) {
464 RecordNextStmtCount =
false;
468 Visit(S->getSubStmt());
// break / continue: add CurrentCount to the innermost enclosing accumulator.
471 void VisitBreakStmt(
const BreakStmt *S) {
473 assert(!BreakContinueStack.empty() &&
"break not in a loop or switch!");
474 BreakContinueStack.back().BreakCount += CurrentCount;
476 RecordNextStmtCount =
true;
481 assert(!BreakContinueStack.empty() &&
"continue stmt not in a loop!");
482 BreakContinueStack.back().ContinueCount += CurrentCount;
484 RecordNextStmtCount =
true;
// while: after the body, CurrentCount is reset to
// ParentCount + BackedgeCount + ContinueCount.
// The exit count is BreakCount + CondCount - BodyCount.
487 void VisitWhileStmt(
const WhileStmt *S) {
489 uint64_t ParentCount = CurrentCount;
491 BreakContinueStack.push_back(BreakContinue());
495 CountMap[S->getBody()] = CurrentCount;
497 uint64_t BackedgeCount = CurrentCount;
503 BreakContinue BC = BreakContinueStack.pop_back_val();
505 setCount(ParentCount + BackedgeCount + BC.ContinueCount);
508 setCount(BC.BreakCount + CondCount - BodyCount);
509 RecordNextStmtCount =
true;
// do-while:
// - BodyCount = LoopCount + incoming count.
// - CondCount = BackedgeCount + ContinueCount.
// - Exit count = BreakCount + CondCount - LoopCount.
512 void VisitDoStmt(
const DoStmt *S) {
516 BreakContinueStack.push_back(BreakContinue());
518 uint64_t BodyCount = setCount(LoopCount + CurrentCount);
521 uint64_t BackedgeCount = CurrentCount;
523 BreakContinue BC = BreakContinueStack.pop_back_val();
526 uint64_t CondCount = setCount(BackedgeCount + BC.ContinueCount);
529 setCount(BC.BreakCount + CondCount - LoopCount);
530 RecordNextStmtCount =
true;
// for:
// - IncCount = BackedgeCount + ContinueCount.
// - The count is then reset to ParentCount + BackedgeCount + ContinueCount.
// - Exit count = BreakCount + CondCount - BodyCount.
533 void VisitForStmt(
const ForStmt *S) {
538 uint64_t ParentCount = CurrentCount;
540 BreakContinueStack.push_back(BreakContinue());
546 uint64_t BackedgeCount = CurrentCount;
547 BreakContinue BC = BreakContinueStack.pop_back_val();
552 uint64_t IncCount = setCount(BackedgeCount + BC.ContinueCount);
559 setCount(ParentCount + BackedgeCount + BC.ContinueCount);
564 setCount(BC.BreakCount + CondCount - BodyCount);
565 RecordNextStmtCount =
true;
// Range-based for: visit the loop-variable, range, begin and end statements,
// then apply the same count formulas as a plain for loop.
572 Visit(S->getLoopVarStmt());
573 Visit(S->getRangeStmt());
574 Visit(S->getBeginStmt());
575 Visit(S->getEndStmt());
577 uint64_t ParentCount = CurrentCount;
578 BreakContinueStack.push_back(BreakContinue());
584 uint64_t BackedgeCount = CurrentCount;
585 BreakContinue BC = BreakContinueStack.pop_back_val();
589 uint64_t IncCount = setCount(BackedgeCount + BC.ContinueCount);
595 setCount(ParentCount + BackedgeCount + BC.ContinueCount);
598 setCount(BC.BreakCount + CondCount - BodyCount);
599 RecordNextStmtCount =
true;
// Objective-C collection loop: the exit count is computed from BreakCount,
// ParentCount, BackedgeCount and ContinueCount.
604 Visit(S->getElement());
605 uint64_t ParentCount = CurrentCount;
606 BreakContinueStack.push_back(BreakContinue());
611 uint64_t BackedgeCount = CurrentCount;
612 BreakContinue BC = BreakContinueStack.pop_back_val();
614 setCount(BC.BreakCount + ParentCount + BackedgeCount + BC.ContinueCount -
616 RecordNextStmtCount =
true;
// Presumably the switch visitor. It pushes its own accumulator for breaks.
// Continues seen inside are forwarded to the enclosing loop's accumulator,
// if there is one.
625 BreakContinueStack.push_back(BreakContinue());
628 BreakContinue BC = BreakContinueStack.pop_back_val();
629 if (!BreakContinueStack.empty())
630 BreakContinueStack.back().ContinueCount += BC.ContinueCount;
633 RecordNextStmtCount =
true;
// Switch-case labels (presumably case/default): drop any pending request,
// make CurrentCount the fallthrough count plus this label's CaseCount, then
// request a fresh count before visiting the sub-statement.
637 RecordNextStmtCount =
false;
642 setCount(CurrentCount + CaseCount);
646 RecordNextStmtCount =
true;
647 Visit(S->getSubStmt());
// if: for `if consteval`, a single branch Stm is selected: getThen() when
// negated (`if !consteval`), getElse() otherwise. For ordinary ifs,
// ElseCount = ParentCount - ThenCount, and OutCount sums the counts leaving
// both branches.
650 void VisitIfStmt(
const IfStmt *S) {
653 if (S->isConsteval()) {
654 const Stmt *Stm = S->isNegatedConsteval() ? S->getThen() : S->getElse();
660 uint64_t ParentCount = CurrentCount;
670 uint64_t OutCount = CurrentCount;
672 uint64_t ElseCount = ParentCount - ThenCount;
677 OutCount += CurrentCount;
679 OutCount += ElseCount;
681 RecordNextStmtCount =
true;
// try: visit the try block, then each handler in order.
686 Visit(S->getTryBlock());
687 for (
unsigned I = 0, E = S->getNumHandlers(); I < E; ++I)
688 Visit(S->getHandler(I));
691 RecordNextStmtCount =
true;
// catch handler: drop any pending request, then visit the handler block.
695 RecordNextStmtCount =
false;
699 Visit(S->getHandlerBlock());
// Presumably the conditional operator. FalseCount = ParentCount - TrueCount,
// and OutCount sums the counts leaving both arms.
704 uint64_t ParentCount = CurrentCount;
712 uint64_t OutCount = CurrentCount;
714 uint64_t FalseCount = setCount(ParentCount - TrueCount);
717 OutCount += CurrentCount;
720 RecordNextStmtCount =
true;
// Logical operators (presumably && and ||): the count after the operator is
// ParentCount + RHSCount - CurrentCount.
725 uint64_t ParentCount = CurrentCount;
731 setCount(ParentCount + RHSCount - CurrentCount);
732 RecordNextStmtCount =
true;
737 uint64_t ParentCount = CurrentCount;
743 setCount(ParentCount + RHSCount - CurrentCount);
744 RecordNextStmtCount =
true;
// Folds Type into the running hash. Type 0 and values >= TooBig are rejected
// by assertion. Each type occupies NumBitsPerType bits of Working. Before a
// new type is added, a full word (Count a nonzero multiple of
// NumTypesPerWord) is fed to MD5.
749 void PGOHash::combine(HashType
Type) {
751 assert(
Type &&
"Hash is invalid: unexpected type 0");
752 assert(
unsigned(
Type) < TooBig &&
"Hash is invalid: too many types");
// Flush a full word. byte_swap<uint64_t, little> yields the little-endian
// representation, so the MD5 input does not depend on host endianness.
755 if (Count && Count % NumTypesPerWord == 0) {
756 using namespace llvm::support;
757 uint64_t Swapped = endian::byte_swap<uint64_t, little>(Working);
758 MD5.update(llvm::makeArrayRef((uint8_t *)&Swapped,
sizeof(Swapped)));
// Shift the earlier types up and append this one in the low bits.
764 Working = Working << NumBitsPerType |
Type;
// Tail of the hash computation. Hashes built from at most NumTypesPerWord
// types get special handling. Otherwise Working is fed to MD5 in one of the
// two forms below, and then the digest Result is produced.
// NOTE(review): `(uint8_t)Working` passes only the low byte of Working to MD5,
// unlike the full 8-byte little-endian update that follows. If this is kept on
// purpose (e.g. for compatibility with an older hash version), document why.
// Confirm.
769 if (Count <= NumTypesPerWord)
780 MD5.update({(uint8_t)Working});
782 using namespace llvm::support;
783 uint64_t Swapped = endian::byte_swap<uint64_t, little>(Working);
784 MD5.update(llvm::makeArrayRef((uint8_t *)&Swapped,
sizeof(Swapped)));
789 llvm::MD5::MD5Result Result;
// Region-counter setup for D / Fn.
// - Nothing is set up unless region instrumentation is enabled or a PGO
//   profile reader is loaded (presumably an early return).
// - Constructors (CXXConstructorDecl) and functions carrying the NoProfile
//   attribute are special-cased before any counters are mapped.
// - The region counters are then mapped, and the coverage mapping may be
//   emitted.
805 llvm::IndexedInstrProfReader *PGOReader = CGM.
getPGOReader();
806 if (!InstrumentRegions && !PGOReader)
814 if (
const auto *CCD = dyn_cast<CXXConstructorDecl>(D))
823 if (Fn->hasFnAttribute(llvm::Attribute::NoProfile))
828 mapRegionCounters(D);
830 emitCounterRegionMapping(D);
// With profile data (presumably guarded by PGOReader): load this function's
// counts, passing whether D is in the main file, then compute per-statement
// counts and apply function attributes.
833 loadRegionCounts(PGOReader,
SM.isInMainFile(D->
getLocation()));
834 computeRegionCounts(D);
835 applyFunctionAttributes(PGOReader, Fn);
// Assigns counter indices to D's body by running MapRegionCounters over it.
// It then records the number of counters (NumRegionCounters) and the
// function's structural hash (FunctionHash).
// - ProfileVersion defaults to the current indexed-profile version.
// - Both ProfileVersion and HashVersion are taken from the PGO reader,
//   presumably only when one is loaded. Confirm.
// - The walker must assign at least one counter (asserted below).
839 void CodeGenPGO::mapRegionCounters(
const Decl *D) {
843 uint64_t ProfileVersion = llvm::IndexedInstrProf::Version;
845 HashVersion = getPGOHashVersion(PGOReader, CGM);
846 ProfileVersion = PGOReader->getVersion();
// A fresh Stmt -> counter-index map is built on every call.
849 RegionCounterMap.reset(
new llvm::DenseMap<const Stmt *, unsigned>);
850 MapRegionCounters Walker(HashVersion, ProfileVersion, *RegionCounterMap);
// Dispatch on the kind of declaration; a null D matches none of these.
851 if (
const FunctionDecl *FD = dyn_cast_or_null<FunctionDecl>(D))
853 else if (
const ObjCMethodDecl *MD = dyn_cast_or_null<ObjCMethodDecl>(D))
855 else if (
const BlockDecl *BD = dyn_cast_or_null<BlockDecl>(D))
856 Walker.TraverseDecl(
const_cast<BlockDecl *
>(BD));
857 else if (
const CapturedDecl *CD = dyn_cast_or_null<CapturedDecl>(D))
859 assert(Walker.NextCounter > 0 &&
"no entry counter mapped for decl");
// Publish the counter count and the finalized structural hash.
860 NumRegionCounters = Walker.NextCounter;
861 FunctionHash = Walker.Hash.finalize();
// Decides whether coverage mapping is skipped for D. Part of the decision
// looks at CUDAGlobalAttr on D. The final result is whether Loc lies in a
// system header.
864 bool CodeGenPGO::skipRegionMappingForDecl(
const Decl *D) {
874 !D->
hasAttr<CUDAGlobalAttr>()) ||
876 (D->
hasAttr<CUDAGlobalAttr>() ||
883 return SM.isInSystemHeader(Loc);
// Generates D's coverage mapping into CoverageMapping via MappingGen, then
// registers it together with FuncNameVar, FuncName and FunctionHash.
// Declarations rejected by skipRegionMappingForDecl are skipped. An empty
// mapping is presumably not registered (early return). Confirm.
886 void CodeGenPGO::emitCounterRegionMapping(
const Decl *D) {
887 if (skipRegionMappingForDecl(D))
891 llvm::raw_string_ostream
OS(CoverageMapping);
895 MappingGen.emitCounterMapping(D, OS);
898 if (CoverageMapping.empty())
902 FuncNameVar, FuncName, FunctionHash, CoverageMapping);
907 llvm::GlobalValue::LinkageTypes
Linkage) {
// Coverage-mapping emission for a declaration with an explicit Linkage. It
// follows the same skip / generate / empty-check flow as
// emitCounterRegionMapping, but passes `false` as the final argument when
// registering the mapping.
// NOTE(review): presumably that flag marks the function as unused. Confirm
// against the callee.
908 if (skipRegionMappingForDecl(D))
912 llvm::raw_string_ostream OS(CoverageMapping);
919 if (CoverageMapping.empty())
924 FuncNameVar, FuncName, FunctionHash, CoverageMapping,
false);
// Computes per-statement execution counts for D into a freshly allocated
// StmtCountMap. It dispatches to the ComputeRegionCounts visitor matching D's
// kind: function, ObjC method, block or captured decl. dyn_cast_or_null
// tolerates a null D, in which case no visitor runs.
927 void CodeGenPGO::computeRegionCounts(
const Decl *D) {
928 StmtCountMap.reset(
new llvm::DenseMap<const Stmt *, uint64_t>);
929 ComputeRegionCounts Walker(*StmtCountMap, *
this);
930 if (
const FunctionDecl *FD = dyn_cast_or_null<FunctionDecl>(D))
931 Walker.VisitFunctionDecl(FD);
932 else if (
const ObjCMethodDecl *MD = dyn_cast_or_null<ObjCMethodDecl>(D))
933 Walker.VisitObjCMethodDecl(MD);
934 else if (
const BlockDecl *BD = dyn_cast_or_null<BlockDecl>(D))
935 Walker.VisitBlockDecl(BD);
936 else if (
const CapturedDecl *CD = dyn_cast_or_null<CapturedDecl>(D))
937 Walker.VisitCapturedDecl(
const_cast<CapturedDecl *
>(CD));
941 CodeGenPGO::applyFunctionAttributes(llvm::IndexedInstrProfReader *PGOReader,
942 llvm::Function *Fn) {
// Records the function's entry count on Fn from FunctionCount.
// FunctionCount is presumably the count of the function's entry region.
// Confirm.
947 Fn->setEntryCount(FunctionCount);
951 llvm::Value *StepV) {
// Emits an llvm.instrprof.increment call for S's counter, or
// llvm.instrprof.increment.step (presumably when StepV is provided). The
// insert-block check below presumably returns early when there is nowhere to
// emit code.
// NOTE(review): DenseMap::operator[] default-inserts counter 0 if S was never
// mapped, so an unmapped S silently bumps counter 0.
954 if (!Builder.GetInsertBlock())
957 unsigned Counter = (*RegionCounterMap)[S];
// Intrinsic arguments, in order:
// - the function-name variable bit-cast to I8PtrTy,
// - the function hash,
// - the number of counters,
// - the counter index,
// - StepV.
960 llvm::Value *Args[] = {llvm::ConstantExpr::getBitCast(FuncNameVar, I8PtrTy),
961 Builder.getInt64(FunctionHash),
962 Builder.getInt32(NumRegionCounters),
963 Builder.getInt32(Counter), StepV};
// The plain increment intrinsic receives only the first four arguments.
965 Builder.CreateCall(CGM.
getIntrinsic(llvm::Intrinsic::instrprof_increment),
966 makeArrayRef(Args, 4));
969 CGM.
getIntrinsic(llvm::Intrinsic::instrprof_increment_step),
982 llvm::Instruction *ValueSite, llvm::Value *ValuePtr) {
// Value-profiling hook for ValueSite / ValuePtr of kind ValueKind. Null
// inputs, a missing insertion block, and constant values are checked first
// (presumably early returns). Two paths follow:
// - With value-site instrumentation enabled (and a counter map), an
//   instrprof_value_profile call is inserted before ValueSite.
// - With a profile reader, ValueSite is annotated from the loaded record.
987 if (!ValuePtr || !ValueSite || !Builder.GetInsertBlock())
990 if (isa<llvm::Constant>(ValuePtr))
994 if (InstrumentValueSites && RegionCounterMap) {
// Insert before ValueSite, then restore the builder's original position.
995 auto BuilderInsertPoint = Builder.saveIP();
996 Builder.SetInsertPoint(ValueSite);
// Arguments, in order:
// - the name variable (bit-cast to i8*),
// - the function hash,
// - the profiled value converted to i64,
// - the value kind,
// - this site's per-kind index (post-incremented).
997 llvm::Value *Args[5] = {
998 llvm::ConstantExpr::getBitCast(FuncNameVar, Builder.getInt8PtrTy()),
999 Builder.getInt64(FunctionHash),
1000 Builder.CreatePtrToInt(ValuePtr, Builder.getInt64Ty()),
1001 Builder.getInt32(ValueKind),
1002 Builder.getInt32(NumValueSites[ValueKind]++)
1005 CGM.
getIntrinsic(llvm::Intrinsic::instrprof_value_profile), Args);
1006 Builder.restoreIP(BuilderInsertPoint);
// Profile use: annotate ValueSite only if the record has a site at the
// current per-kind index, then advance that index.
1010 llvm::IndexedInstrProfReader *PGOReader = CGM.
getPGOReader();
1018 if (NumValueSites[ValueKind] >= ProfRecord->getNumValueSites(ValueKind))
1021 llvm::annotateValueSite(CGM.
getModule(), *ValueSite, *ProfRecord,
1022 (llvm::InstrProfValueKind)ValueKind,
1023 NumValueSites[ValueKind]);
1025 NumValueSites[ValueKind]++;
// Loads the profile record for (FuncName, FunctionHash) from PGOReader.
// - On error, distinguishes unknown_function, hash_mismatch and malformed.
//   IsInMainFile is presumably used when reporting these. Confirm.
// - On success, the record is stored in ProfRecord and its counters are
//   copied into RegionCounts.
1029 void CodeGenPGO::loadRegionCounts(llvm::IndexedInstrProfReader *PGOReader,
1030 bool IsInMainFile) {
// Start from an empty RegionCounts.
1032 RegionCounts.clear();
1034 PGOReader->getInstrProfRecord(FuncName, FunctionHash);
1035 if (
auto E = RecordExpected.takeError()) {
1036 auto IPE = llvm::InstrProfError::take(std::move(E));
1037 if (IPE == llvm::instrprof_error::unknown_function)
1039 else if (IPE == llvm::instrprof_error::hash_mismatch)
1041 else if (IPE == llvm::instrprof_error::malformed)
1047 std::make_unique<llvm::InstrProfRecord>(std::move(RecordExpected.get()));
// Copy the record's counters into RegionCounts.
1048 RegionCounts = ProfRecord->Counts;
// Scale factor that keeps weights within 32 bits. It is 1 when MaxWeight is
// already < UINT32_MAX; otherwise it is MaxWeight / UINT32_MAX + 1, which makes
// MaxWeight / Scale strictly less than UINT32_MAX.
1056 return MaxWeight < UINT32_MAX ? 1 : MaxWeight / UINT32_MAX + 1;
// Divides Weight by Scale and adds 1, so every scaled weight is at least 1.
// The assertions require a nonzero Scale and a result that fits in 32 bits.
1069 assert(Scale &&
"scale by 0?");
1070 uint64_t Scaled = Weight / Scale + 1;
1071 assert(Scaled <= UINT32_MAX &&
"overflow 32-bits");
// Branch-weight metadata for a two-way branch from its true/false counts.
// When both counts are zero there is nothing to attach; this presumably
// returns null. Confirm.
1075 llvm::MDNode *CodeGenFunction::createProfileWeights(uint64_t TrueCount,
1076 uint64_t FalseCount)
const {
1078 if (!TrueCount && !FalseCount)
// Multi-way branch weights. Fewer than two weights cannot describe a branch.
// Otherwise each weight is scaled relative to the largest one (presumably via
// calculateWeightScale / scaleBranchWeight, to fit in 32 bits), and the
// result is emitted with MDHelper.createBranchWeights.
1092 if (Weights.size() < 2)
1096 uint64_t MaxWeight = *std::max_element(Weights.begin(), Weights.end());
1104 ScaledWeights.reserve(Weights.size());
1105 for (uint64_t W : Weights)
1109 return MDHelper.createBranchWeights(ScaledWeights);
1113 CodeGenFunction::createProfileWeightsForLoop(
const Stmt *Cond,
1114 uint64_t LoopCount)
const {
// Weights for a loop condition:
// - true-branch weight: LoopCount;
// - false-branch (exit) weight: max(CondCount, LoopCount) - LoopCount, i.e.
//   the condition count in excess of LoopCount, clamped at 0.
// CondCount is presumably the recorded count of Cond. Nothing is returned
// when it is absent or zero. Confirm.
1118 if (!CondCount || *CondCount == 0)
1120 return createProfileWeights(LoopCount,
1121 std::max(*CondCount, LoopCount) - LoopCount);