[rs4gc] track the original value in the state use for base pointer rewriting
authorPhilip Reames <listmail@philipreames.com>
Sat, 6 Mar 2021 02:05:21 +0000 (18:05 -0800)
committerPhilip Reames <listmail@philipreames.com>
Sat, 6 Mar 2021 16:46:15 +0000 (08:46 -0800)
I'd originally intended to build on this for another purpose and have decided not to, but at a minimum, the stronger asserts are useful.

llvm/lib/Transforms/Scalar/RewriteStatepointsForGC.cpp

index ad1c61c..fd92695 100644 (file)
@@ -676,22 +676,30 @@ namespace {
 /// the base of this BDV.
 class BDVState {
 public:
-  enum Status {
+  enum StatusTy {
      // Starting state of lattice
      Unknown,
-     // Some specific base value
+     // Some specific base value -- does *not* mean that instruction
+     // propagates the base of the object
+     // ex: gep %arg, 16 -> %arg is the base value
      Base,
      // Need to insert a node to represent a merge.
      Conflict
   };
 
-  BDVState() {}
-  explicit BDVState(Status Status, Value *BaseValue = nullptr)
-      : Status(Status), BaseValue(BaseValue) {
+  BDVState() {
+    llvm_unreachable("missing state in map");
+  }
+
+  explicit BDVState(Value *OriginalValue)
+    : OriginalValue(OriginalValue) {}
+  explicit BDVState(Value *OriginalValue, StatusTy Status, Value *BaseValue = nullptr)
+    : OriginalValue(OriginalValue), Status(Status), BaseValue(BaseValue) {
     assert(Status != Base || BaseValue);
   }
 
-  Status getStatus() const { return Status; }
+  StatusTy getStatus() const { return Status; }
+  Value *getOriginalValue() const { return OriginalValue; }
   Value *getBaseValue() const { return BaseValue; }
 
   bool isBase() const { return getStatus() == Base; }
@@ -699,7 +707,8 @@ public:
   bool isConflict() const { return getStatus() == Conflict; }
 
   bool operator==(const BDVState &Other) const {
-    return BaseValue == Other.BaseValue && Status == Other.Status;
+    return OriginalValue == OriginalValue && BaseValue == Other.BaseValue &&
+      Status == Other.Status;
   }
 
   bool operator!=(const BDVState &other) const { return !(*this == other); }
@@ -722,12 +731,14 @@ public:
       OS << "C";
       break;
     }
-    OS << " (" << getBaseValue() << " - "
-       << (getBaseValue() ? getBaseValue()->getName() : "nullptr") << "): ";
+    OS << " (base " << getBaseValue() << " - "
+       << (getBaseValue() ? getBaseValue()->getName() : "nullptr") << ")"
+       << " for  "  << OriginalValue->getName() << ":";
   }
 
 private:
-  Status Status = Unknown;
+  AssertingVH<Value> OriginalValue; // instruction this state corresponds to
+  StatusTy Status = Unknown;
   AssertingVH<Value> BaseValue = nullptr; // Non-null only if Status == Base.
 };
 
@@ -740,39 +751,40 @@ static raw_ostream &operator<<(raw_ostream &OS, const BDVState &State) {
 }
 #endif
 
-static BDVState meetBDVStateImpl(const BDVState &LHS, const BDVState &RHS) {
-  switch (LHS.getStatus()) {
+static BDVState::StatusTy meet(const BDVState::StatusTy &LHS,
+                               const BDVState::StatusTy &RHS) {
+  switch (LHS) {
   case BDVState::Unknown:
     return RHS;
-
   case BDVState::Base:
-    assert(LHS.getBaseValue() && "can't be null");
-    if (RHS.isUnknown())
-      return LHS;
-
-    if (RHS.isBase()) {
-      if (LHS.getBaseValue() == RHS.getBaseValue()) {
-        assert(LHS == RHS && "equality broken!");
-        return LHS;
-      }
-      return BDVState(BDVState::Conflict);
-    }
-    assert(RHS.isConflict() && "only three states!");
-    return BDVState(BDVState::Conflict);
-
+    switch (RHS) {
+    case BDVState::Unknown:
+    case BDVState::Base:
+      return BDVState::Base;
+    case BDVState::Conflict:
+      return BDVState::Conflict;
+    };
+    llvm_unreachable("covered switch");
   case BDVState::Conflict:
-    return LHS;
+    return BDVState::Conflict;
   }
-  llvm_unreachable("only three states!");
+  llvm_unreachable("covered switch");
 }
 
 // Values of type BDVState form a lattice, and this function implements the meet
 // operation.
 static BDVState meetBDVState(const BDVState &LHS, const BDVState &RHS) {
-  BDVState Result = meetBDVStateImpl(LHS, RHS);
-  assert(Result == meetBDVStateImpl(RHS, LHS) &&
-         "Math is wrong: meet does not commute!");
-  return Result;
+  auto NewStatus = meet(LHS.getStatus(), RHS.getStatus());
+  assert(NewStatus == meet(RHS.getStatus(), LHS.getStatus()));
+
+  Value *BaseValue = LHS.getStatus() == BDVState::Base ?
+    LHS.getBaseValue() : RHS.getBaseValue();
+  if (LHS.getStatus() == BDVState::Base && RHS.getStatus() == BDVState::Base &&
+      LHS.getBaseValue() != RHS.getBaseValue()) {
+    NewStatus = BDVState::Conflict;
+    BaseValue = nullptr;
+  }
+  return BDVState(LHS.getOriginalValue(), NewStatus, BaseValue);
 }
 
 /// For a given value or instruction, figure out what base ptr its derived from.
@@ -822,12 +834,18 @@ static Value *findBasePointer(Value *I, DefiningValueMapTy &Cache) {
   // below.  This is important for deterministic compilation.
   MapVector<Value *, BDVState> States;
 
+  auto VerifyStates = [&]() {
+    for (auto &Entry : States) {
+      assert(Entry.first == Entry.second.getOriginalValue());
+    }
+  };
+
   // Recursively fill in all base defining values reachable from the initial
   // one for which we don't already know a definite base value for
   /* scope */ {
     SmallVector<Value*, 16> Worklist;
     Worklist.push_back(Def);
-    States.insert({Def, BDVState()});
+    States.insert({Def, BDVState(Def)});
     while (!Worklist.empty()) {
       Value *Current = Worklist.pop_back_val();
       assert(!isOriginalBaseResult(Current) && "why did it get added?");
@@ -843,7 +861,7 @@ static Value *findBasePointer(Value *I, DefiningValueMapTy &Cache) {
           return;
         assert(isExpectedBDVType(Base) && "the only non-base values "
                "we see should be base defining values");
-        if (States.insert(std::make_pair(Base, BDVState())).second)
+        if (States.insert(std::make_pair(Base, BDVState(Base))).second)
           Worklist.push_back(Base);
       };
       if (PHINode *PN = dyn_cast<PHINode>(Current)) {
@@ -868,6 +886,7 @@ static Value *findBasePointer(Value *I, DefiningValueMapTy &Cache) {
   }
 
 #ifndef NDEBUG
+  VerifyStates();
   LLVM_DEBUG(dbgs() << "States after initialization:\n");
   for (auto Pair : States) {
     LLVM_DEBUG(dbgs() << " " << Pair.second << " for " << *Pair.first << "\n");
@@ -878,7 +897,7 @@ static Value *findBasePointer(Value *I, DefiningValueMapTy &Cache) {
   // base state for known bases and expect to find a cached state otherwise.
   auto GetStateForBDV = [&](Value *BaseValue, Value *Input) {
     if (isKnownBaseResult(BaseValue) && areBothVectorOrScalar(BaseValue, Input))
-      return BDVState(BDVState::Base, BaseValue);
+      return BDVState(BaseValue, BDVState::Base, BaseValue);
     auto I = States.find(BaseValue);
     assert(I != States.end() && "lookup failed!");
     return I->second;
@@ -910,7 +929,7 @@ static Value *findBasePointer(Value *I, DefiningValueMapTy &Cache) {
         return GetStateForBDV(BDV, V);
       };
 
-      BDVState NewState;
+      BDVState NewState(BDV);
       if (SelectInst *SI = dyn_cast<SelectInst>(BDV)) {
         NewState = meetBDVState(NewState, getStateForInput(SI->getTrueValue()));
         NewState =
@@ -948,6 +967,7 @@ static Value *findBasePointer(Value *I, DefiningValueMapTy &Cache) {
   }
 
 #ifndef NDEBUG
+  VerifyStates();
   LLVM_DEBUG(dbgs() << "States after meet iteration:\n");
   for (auto Pair : States) {
     LLVM_DEBUG(dbgs() << " " << Pair.second << " for " << *Pair.first << "\n");
@@ -981,17 +1001,21 @@ static Value *findBasePointer(Value *I, DefiningValueMapTy &Cache) {
       auto *BaseInst = ExtractElementInst::Create(
           State.getBaseValue(), EE->getIndexOperand(), "base_ee", EE);
       BaseInst->setMetadata("is_base_value", MDNode::get(I->getContext(), {}));
-      States[I] = BDVState(BDVState::Base, BaseInst);
+      States[I] = BDVState(I, BDVState::Base, BaseInst);
     } else if (!isa<VectorType>(I->getType())) {
       // We need to handle cases that have a vector base but the instruction is
       // a scalar type (these could be phis or selects or any instruction that
       // are of scalar type, but the base can be a vector type).  We
       // conservatively set this as conflict.  Setting the base value for these
       // conflicts is handled in the next loop which traverses States.
-      States[I] = BDVState(BDVState::Conflict);
+      States[I] = BDVState(I, BDVState::Conflict);
     }
   }
 
+#ifndef NDEBUG
+  VerifyStates();
+#endif
+
   // Insert Phis for all conflicts
   // TODO: adjust naming patterns to avoid this order of iteration dependency
   for (auto Pair : States) {
@@ -1048,9 +1072,13 @@ static Value *findBasePointer(Value *I, DefiningValueMapTy &Cache) {
     Instruction *BaseInst = MakeBaseInstPlaceholder(I);
     // Add metadata marking this as a base value
     BaseInst->setMetadata("is_base_value", MDNode::get(I->getContext(), {}));
-    States[I] = BDVState(BDVState::Conflict, BaseInst);
+    States[I] = BDVState(I, BDVState::Conflict, BaseInst);
   }
 
+#ifndef NDEBUG
+  VerifyStates();
+#endif
+
   // Returns a instruction which produces the base pointer for a given
   // instruction.  The instruction is assumed to be an input to one of the BDVs
   // seen in the inference algorithm above.  As such, we must either already
@@ -1171,6 +1199,10 @@ static Value *findBasePointer(Value *I, DefiningValueMapTy &Cache) {
     }
   }
 
+#ifndef NDEBUG
+  VerifyStates();
+#endif
+
   // Cache all of our results so we can cheaply reuse them
   // NOTE: This is actually two caches: one of the base defining value
   // relation and one of the base pointer relation!  FIXME