29 #include "llvm/ADT/DenseMap.h"
30 #include "llvm/ADT/Optional.h"
31 #include "llvm/ADT/STLExtras.h"
32 #include "llvm/ADT/StringRef.h"
33 #include "llvm/Support/Casting.h"
34 #include "llvm/Support/ErrorHandling.h"
55 using namespace clang;
56 using namespace consumed;
64 for (
const auto &B : *
Block)
66 return CS->getStmt()->getBeginLoc();
70 if (
Block->succ_size() == 1 && *
Block->succ_begin())
79 if (
const Stmt *StmtNode =
Block->getTerminatorStmt()) {
80 return StmtNode->getBeginLoc();
83 BE =
Block->rend(); BI != BE; ++BI) {
85 return CS->getStmt()->getBeginLoc();
91 if (
Block->succ_size() == 1 && *
Block->succ_begin())
97 if (
Block->pred_size() == 1 && *
Block->pred_begin())
114 llvm_unreachable(
"invalid enum");
119 for (
const auto &S : CWAttr->callableStates()) {
127 case CallableWhenAttr::Unconsumed:
131 case CallableWhenAttr::Consumed:
136 if (MappedAttrState ==
State)
148 return RD->hasAttr<ConsumableAttr>();
158 return RD->hasAttr<ConsumableAutoCastAttr>();
165 return RD->hasAttr<ConsumableSetOnReadAttr>();
178 llvm_unreachable(
"invalid enum");
186 return FunDecl->
hasAttr<TestTypestateAttr>();
196 const ConsumableAttr *CAttr =
199 switch (CAttr->getDefaultState()) {
202 case ConsumableAttr::Unconsumed:
204 case ConsumableAttr::Consumed:
207 llvm_unreachable(
"invalid enum");
212 switch (PTAttr->getParamState()) {
215 case ParamTypestateAttr::Unconsumed:
217 case ParamTypestateAttr::Consumed:
220 llvm_unreachable(
"invalid_enum");
225 switch (RTSAttr->getState()) {
228 case ReturnTypestateAttr::Unconsumed:
230 case ReturnTypestateAttr::Consumed:
233 llvm_unreachable(
"invalid enum");
237 switch (STAttr->getNewState()) {
240 case SetTypestateAttr::Unconsumed:
242 case SetTypestateAttr::Consumed:
245 llvm_unreachable(
"invalid_enum");
262 llvm_unreachable(
"invalid enum");
267 switch (FunDecl->
getAttr<TestTypestateAttr>()->getTestState()) {
268 case TestTypestateAttr::Unconsumed:
270 case TestTypestateAttr::Consumed:
273 llvm_unreachable(
"invalid enum");
278 struct VarTestResult {
301 } InfoType = IT_None;
321 : InfoType(IT_VarTest), VarTest(VarTest) {}
324 : InfoType(IT_VarTest) {
326 VarTest.TestsFor = TestsFor;
330 const VarTestResult <est,
const VarTestResult &RTest)
331 : InfoType(IT_BinTest) {
332 BinTest.Source = Source;
334 BinTest.LTest = LTest;
335 BinTest.RTest = RTest;
341 : InfoType(IT_BinTest) {
342 BinTest.Source = Source;
344 BinTest.LTest.Var = LVar;
345 BinTest.LTest.TestsFor = LTestsFor;
346 BinTest.RTest.Var = RVar;
347 BinTest.RTest.TestsFor = RTestsFor;
354 : InfoType(IT_Tmp), Tmp(Tmp) {}
357 assert(InfoType == IT_State);
362 assert(InfoType == IT_VarTest);
367 assert(InfoType == IT_BinTest);
368 return BinTest.LTest;
372 assert(InfoType == IT_BinTest);
373 return BinTest.RTest;
377 assert(InfoType == IT_Var);
382 assert(InfoType == IT_Tmp);
387 assert(isVar() || isTmp() || isState());
400 assert(InfoType == IT_BinTest);
405 assert(InfoType == IT_BinTest);
406 return BinTest.Source;
409 bool isValid()
const {
return InfoType != IT_None; }
410 bool isState()
const {
return InfoType == IT_State; }
411 bool isVarTest()
const {
return InfoType == IT_VarTest; }
412 bool isBinTest()
const {
return InfoType == IT_BinTest; }
413 bool isVar()
const {
return InfoType == IT_Var; }
414 bool isTmp()
const {
return InfoType == IT_Tmp; }
417 return InfoType == IT_VarTest || InfoType == IT_BinTest;
421 return InfoType == IT_Var || InfoType == IT_Tmp;
425 assert(InfoType == IT_VarTest || InfoType == IT_BinTest);
427 if (InfoType == IT_VarTest) {
431 }
else if (InfoType == IT_BinTest) {
460 using MapType = llvm::DenseMap<const Stmt *, PropagationInfo>;
461 using PairType= std::pair<const Stmt *, PropagationInfo>;
462 using InfoEntry = MapType::iterator;
463 using ConstInfoEntry = MapType::const_iterator;
467 MapType PropagationMap;
469 InfoEntry findInfo(
const Expr *E) {
470 if (
const auto Cleanups = dyn_cast<ExprWithCleanups>(E))
471 if (!Cleanups->cleanupsHaveSideEffects())
472 E = Cleanups->getSubExpr();
476 ConstInfoEntry findInfo(
const Expr *E)
const {
477 if (
const auto Cleanups = dyn_cast<ExprWithCleanups>(E))
478 if (!Cleanups->cleanupsHaveSideEffects())
479 E = Cleanups->getSubExpr();
487 void forwardInfo(
const Expr *From,
const Expr *To);
497 bool handleCall(
const CallExpr *Call,
const Expr *ObjArg,
501 void VisitCallExpr(
const CallExpr *Call);
508 void VisitDeclStmt(
const DeclStmt *DelcS);
510 void VisitMemberExpr(
const MemberExpr *MExpr);
514 void VisitVarDecl(
const VarDecl *Var);
517 : Analyzer(Analyzer), StateMap(StateMap) {}
520 ConstInfoEntry Entry = findInfo(StmtNode);
522 if (Entry != PropagationMap.end())
523 return Entry->second;
529 StateMap = NewStateMap;
536 void ConsumedStmtVisitor::forwardInfo(
const Expr *From,
const Expr *To) {
537 InfoEntry Entry = findInfo(From);
538 if (Entry != PropagationMap.end())
539 insertInfo(To, Entry->second);
544 void ConsumedStmtVisitor::copyInfo(
const Expr *From,
const Expr *To,
546 InfoEntry Entry = findInfo(From);
547 if (Entry != PropagationMap.end()) {
559 InfoEntry Entry = findInfo(From);
560 if (Entry != PropagationMap.end()) {
569 InfoEntry Entry = findInfo(To);
570 if (Entry != PropagationMap.end()) {
584 const CallableWhenAttr *CWAttr = FunDecl->
getAttr<CallableWhenAttr>();
614 if (isa<CXXOperatorCallExpr>(Call) && isa<CXXMethodDecl>(FunD))
618 for (
unsigned Index =
Offset; Index < Call->getNumArgs(); ++Index) {
626 InfoEntry Entry = findInfo(Call->getArg(Index));
628 if (Entry == PropagationMap.end() || Entry->second.isTest())
633 if (ParamTypestateAttr *PTA = Param->
getAttr<ParamTypestateAttr>()) {
637 if (ParamState != ExpectedState)
639 Call->getArg(Index)->getExprLoc(),
643 if (!(Entry->second.isVar() || Entry->second.isTmp()))
647 if (ReturnTypestateAttr *RT = Param->
getAttr<ReturnTypestateAttr>())
661 InfoEntry Entry = findInfo(ObjArg);
662 if (Entry != PropagationMap.end()) {
666 if (SetTypestateAttr *STA = FunD->
getAttr<SetTypestateAttr>()) {
671 else if (PInfo.
isTmp()) {
677 PropagationMap.insert(PairType(Call,
684 void ConsumedStmtVisitor::propagateReturnType(
const Expr *Call,
692 if (ReturnTypestateAttr *RTA = Fun->
getAttr<ReturnTypestateAttr>())
705 InfoEntry LEntry = findInfo(BinOp->
getLHS()),
706 REntry = findInfo(BinOp->
getRHS());
708 VarTestResult LTest, RTest;
710 if (LEntry != PropagationMap.end() && LEntry->second.isVarTest()) {
711 LTest = LEntry->second.getVarTest();
717 if (REntry != PropagationMap.end() && REntry->second.isVarTest()) {
718 RTest = REntry->second.getVarTest();
724 if (!(LTest.Var ==
nullptr && RTest.Var ==
nullptr))
732 forwardInfo(BinOp->
getLHS(), BinOp);
747 if (Call->isCallToStdMove()) {
753 propagateReturnType(Call, FunDecl);
757 forwardInfo(
Cast->getSubExpr(),
Cast);
763 InfoEntry Entry = findInfo(Temp->
getSubExpr());
765 if (Entry != PropagationMap.end() && !Entry->second.isTest()) {
766 StateMap->
setState(Temp, Entry->second.getAsState(StateMap));
774 QualType ThisType = Constructor->getThisType()->getPointeeType();
780 if (ReturnTypestateAttr *RTA = Constructor->getAttr<ReturnTypestateAttr>()) {
784 }
else if (Constructor->isDefaultConstructor()) {
785 PropagationMap.insert(PairType(Call,
787 }
else if (Constructor->isMoveConstructor()) {
789 }
else if (Constructor->isCopyConstructor()) {
794 copyInfo(Call->getArg(0), Call, NS);
808 handleCall(Call, Call->getImplicitObjectArgument(), MD);
809 propagateReturnType(Call, MD);
814 const auto *FunDecl = dyn_cast_or_null<FunctionDecl>(Call->getDirectCallee());
815 if (!FunDecl)
return;
817 if (Call->getOperator() == OO_Equal) {
819 if (!
handleCall(Call, Call->getArg(0), FunDecl))
820 setInfo(Call->getArg(0), CS);
824 if (
const auto *MCall = dyn_cast<CXXMemberCallExpr>(Call))
825 handleCall(MCall, MCall->getImplicitObjectArgument(), FunDecl);
829 propagateReturnType(Call, FunDecl);
833 if (
const auto *Var = dyn_cast_or_null<VarDecl>(DeclRef->
getDecl()))
839 for (
const auto *DI : DeclS->
decls())
840 if (isa<VarDecl>(DI))
844 if (
const auto *Var = dyn_cast_or_null<VarDecl>(DeclS->
getSingleDecl()))
854 forwardInfo(MExpr->
getBase(), MExpr);
861 if (
const ParamTypestateAttr *PTA = Param->
getAttr<ParamTypestateAttr>())
873 StateMap->
setState(Param, ParamState);
879 if (ExpectedState !=
CS_None) {
880 InfoEntry Entry = findInfo(
Ret->getRetValue());
882 if (Entry != PropagationMap.end()) {
885 if (RetState != ExpectedState)
897 InfoEntry Entry = findInfo(UOp->
getSubExpr());
898 if (Entry == PropagationMap.end())
return;
902 PropagationMap.insert(PairType(UOp, Entry->second));
906 if (Entry->second.isTest())
907 PropagationMap.insert(PairType(UOp, Entry->second.invertTest()));
920 if (VIT != PropagationMap.end()) {
941 ThenStates->
setState(Test.Var, Test.TestsFor);
945 }
else if (VarState == Test.TestsFor) {
953 const VarTestResult <est = PInfo.
getLTest(),
962 ThenStates->
setState(LTest.Var, LTest.TestsFor);
965 }
else if (LState == LTest.TestsFor &&
isKnownState(RState)) {
966 if (RState == RTest.TestsFor)
975 }
else if (LState == LTest.TestsFor) {
979 if (RState == RTest.TestsFor)
990 ThenStates->
setState(RTest.Var, RTest.TestsFor);
997 else if (RState == RTest.TestsFor)
1005 assert(CurrBlock &&
"Block pointer must not be NULL");
1006 assert(TargetBlock &&
"TargetBlock pointer must not be NULL");
1008 unsigned int CurrBlockOrder = VisitOrder[CurrBlock->
getBlockID()];
1010 PE = TargetBlock->
pred_end(); PI != PE; ++PI) {
1011 if (*PI && CurrBlockOrder < VisitOrder[(*PI)->getBlockID()] )
1019 std::unique_ptr<ConsumedStateMap> &OwnedStateMap) {
1020 assert(
Block &&
"Block pointer must not be NULL");
1022 auto &Entry = StateMapsArray[
Block->getBlockID()];
1025 Entry->intersect(*StateMap);
1026 }
else if (OwnedStateMap)
1027 Entry = std::move(OwnedStateMap);
1029 Entry = std::make_unique<ConsumedStateMap>(*StateMap);
1033 std::unique_ptr<ConsumedStateMap> StateMap) {
1034 assert(
Block &&
"Block pointer must not be NULL");
1036 auto &Entry = StateMapsArray[
Block->getBlockID()];
1039 Entry->intersect(*StateMap);
1041 Entry = std::move(StateMap);
1046 assert(
Block &&
"Block pointer must not be NULL");
1047 assert(StateMapsArray[
Block->getBlockID()] &&
"Block has no block info");
1049 return StateMapsArray[
Block->getBlockID()].get();
1053 StateMapsArray[
Block->getBlockID()] =
nullptr;
1056 std::unique_ptr<ConsumedStateMap>
1058 assert(
Block &&
"Block pointer must not be NULL");
1060 auto &Entry = StateMapsArray[
Block->getBlockID()];
1066 assert(From &&
"From block must not be NULL");
1067 assert(To &&
"From block must not be NULL");
1073 assert(
Block &&
"Block pointer must not be NULL");
1077 if (
Block->pred_size() < 2)
1080 unsigned int BlockVisitOrder = VisitOrder[
Block->getBlockID()];
1082 PE =
Block->pred_end(); PI != PE; ++PI) {
1083 if (*PI && BlockVisitOrder < VisitOrder[(*PI)->getBlockID()])
1092 for (
const auto &DM :
VarMap) {
1093 if (isa<ParmVarDecl>(DM.first)) {
1094 const auto *Param = cast<ParmVarDecl>(DM.first);
1095 const ReturnTypestateAttr *RTA = Param->getAttr<ReturnTypestateAttr>();
1101 if (DM.second != ExpectedState)
1114 VarMapType::const_iterator Entry =
VarMap.find(Var);
1116 if (Entry !=
VarMap.end())
1117 return Entry->second;
1124 TmpMapType::const_iterator Entry =
TmpMap.find(Tmp);
1126 if (Entry !=
TmpMap.end())
1127 return Entry->second;
1135 if (this->From && this->From == Other.
From && !Other.
Reachable) {
1140 for (
const auto &DM : Other.
VarMap) {
1141 LocalState = this->
getState(DM.first);
1146 if (LocalState != DM.second)
1158 for (
const auto &DM : LoopBackStates->
VarMap) {
1159 LocalState = this->
getState(DM.first);
1164 if (LocalState != DM.second) {
1167 DM.first->getNameAsString());
1192 for (
const auto &DM : Other->
VarMap)
1193 if (this->
getState(DM.first) != DM.second)
1201 if (
const auto *Constructor = dyn_cast<CXXConstructorDecl>(D)) {
1202 ReturnType = Constructor->getThisType()->getPointeeType();
1206 if (
const ReturnTypestateAttr *RTSAttr = D->
getAttr<ReturnTypestateAttr>()) {
1208 if (!RD || !RD->
hasAttr<ConsumableAttr>()) {
1214 RTSAttr->getLocation(), ReturnType.
getAsString());
1215 ExpectedReturnState =
CS_None;
1220 ExpectedReturnState =
CS_None;
1225 ExpectedReturnState =
CS_None;
1228 bool ConsumedAnalyzer::splitState(
const CFGBlock *CurrBlock,
1230 std::unique_ptr<ConsumedStateMap> FalseStates(
1234 if (
const auto *IfNode =
1236 const Expr *Cond = IfNode->getCond();
1238 PInfo = Visitor.getInfo(Cond);
1239 if (!PInfo.
isValid() && isa<BinaryOperator>(Cond))
1240 PInfo = Visitor.getInfo(cast<BinaryOperator>(Cond)->getRHS());
1243 CurrStates->setSource(Cond);
1244 FalseStates->setSource(Cond);
1254 }
else if (
const auto *BinOp =
1256 PInfo = Visitor.getInfo(BinOp->getLHS());
1258 if ((BinOp = dyn_cast_or_null<BinaryOperator>(BinOp->getLHS()))) {
1259 PInfo = Visitor.getInfo(BinOp->getRHS());
1268 CurrStates->setSource(BinOp);
1269 FalseStates->setSource(BinOp);
1271 const VarTestResult &Test = PInfo.
getVarTest();
1274 if (BinOp->getOpcode() == BO_LAnd) {
1276 CurrStates->setState(Test.Var, Test.TestsFor);
1278 CurrStates->markUnreachable();
1280 }
else if (BinOp->getOpcode() == BO_LOr) {
1282 FalseStates->setState(Test.Var,
1284 else if (VarState == Test.TestsFor)
1285 FalseStates->markUnreachable();
1294 BlockInfo.
addInfo(*SI, std::move(CurrStates));
1296 CurrStates =
nullptr;
1299 BlockInfo.
addInfo(*SI, std::move(FalseStates));
1305 const auto *D = dyn_cast_or_null<FunctionDecl>(AC.
getDecl());
1313 determineExpectedReturnState(AC, D);
1320 CurrStates = std::make_unique<ConsumedStateMap>();
1328 for (
const auto *CurrBlock : *SortedGraph) {
1330 CurrStates = BlockInfo.getInfo(CurrBlock);
1334 }
else if (!CurrStates->isReachable()) {
1335 CurrStates =
nullptr;
1339 Visitor.
reset(CurrStates.get());
1342 for (
const auto &B : *CurrBlock) {
1343 switch (B.getKind()) {
1355 CurrStates->remove(BTE);
1377 if (!splitState(CurrBlock, Visitor)) {
1378 CurrStates->setSource(
nullptr);
1380 if (CurrBlock->succ_size() > 1 ||
1381 (CurrBlock->succ_size() == 1 &&
1382 (*CurrBlock->succ_begin())->pred_size() > 1)) {
1384 auto *RawState = CurrStates.get();
1387 SE = CurrBlock->succ_end(); SI != SE; ++SI) {
1388 if (*SI ==
nullptr)
continue;
1390 if (BlockInfo.isBackEdge(CurrBlock, *SI)) {
1391 BlockInfo.borrowInfo(*SI)->intersectAtLoopHead(
1394 if (BlockInfo.allBackEdgesVisited(CurrBlock, *SI))
1395 BlockInfo.discardInfo(*SI);
1397 BlockInfo.addInfo(*SI, RawState, CurrStates);
1401 CurrStates =
nullptr;
1407 CurrStates->checkParamsForReturnTypestate(D->
getLocation(),
1412 CurrStates =
nullptr;