[RS4GC] Fix algorithm to avoid setting vector BDV for scalar derived pointer
authorAnna Thomas <anna@azul.com>
Thu, 14 May 2020 13:15:57 +0000 (09:15 -0400)
committerAnna Thomas <anna@azul.com>
Thu, 14 May 2020 14:03:30 +0000 (10:03 -0400)
Summary:
This is a more general fix to 59029b9eef23 (D75704).
This patch does the following:
1. updates isKnownBaseValue to account for base pointer and
derived pointer having differing types.
2. This inturn allows us to populate the
lattice (States) for such derived pointers.
3. It also updates all states where the base and derived pointers have
differing types (vector versus scalar) and conservatively marks these
states as conflictcs.
Note that in 59029b9eef23, we were just fixing existing lattice values
and that too, only for uses of extractelement.

Reviewers: reames, skatkov, dantrushin

Reviewed By: skatkov

Subscribers: hiraditya, llvm-commits

Tags: #llvm

Differential Revision: https://reviews.llvm.org/D76305

llvm/lib/Transforms/Scalar/RewriteStatepointsForGC.cpp
llvm/test/Transforms/RewriteStatepointsForGC/scalar-base-vector.ll

index 468c9b8..0efa6f4 100644 (file)
@@ -387,8 +387,13 @@ static void analyzeParsePointLiveness(
   Result.LiveSet = LiveSet;
 }
 
+// Returns true is V is a knownBaseResult.
 static bool isKnownBaseResult(Value *V);
 
+// Returns true if V is a BaseResult that already exists in the IR, i.e. it is
+// not created by the findBasePointers algorithm.
+static bool isOriginalBaseResult(Value *V);
+
 namespace {
 
 /// A single base defining value - An immediate base defining value for an
@@ -633,15 +638,20 @@ static Value *findBaseOrBDV(Value *I, DefiningValueMapTy &Cache) {
   return Def;
 }
 
+/// This value is a base pointer that is not generated by RS4GC, i.e. it already
+/// exists in the code.
+static bool isOriginalBaseResult(Value *V) {
+  // no recursion possible
+  return !isa<PHINode>(V) && !isa<SelectInst>(V) &&
+         !isa<ExtractElementInst>(V) && !isa<InsertElementInst>(V) &&
+         !isa<ShuffleVectorInst>(V);
+}
+
 /// Given the result of a call to findBaseDefiningValue, or findBaseOrBDV,
 /// is it known to be a base pointer?  Or do we need to continue searching.
 static bool isKnownBaseResult(Value *V) {
-  if (!isa<PHINode>(V) && !isa<SelectInst>(V) &&
-      !isa<ExtractElementInst>(V) && !isa<InsertElementInst>(V) &&
-      !isa<ShuffleVectorInst>(V)) {
-    // no recursion possible
+  if (isOriginalBaseResult(V))
     return true;
-  }
   if (isa<Instruction>(V) &&
       cast<Instruction>(V)->getMetadata("is_base_value")) {
     // This is a previously inserted base phi or select.  We know
@@ -653,6 +663,12 @@ static bool isKnownBaseResult(Value *V) {
   return false;
 }
 
+// Returns true if First and Second values are both scalar or both vector.
+static bool areBothVectorOrScalar(Value *First, Value *Second) {
+  return isa<VectorType>(First->getType()) ==
+         isa<VectorType>(Second->getType());
+}
+
 namespace {
 
 /// Models the state of a single base defining value in the findBasePointer
@@ -762,7 +778,7 @@ static BDVState meetBDVState(const BDVState &LHS, const BDVState &RHS) {
 static Value *findBasePointer(Value *I, DefiningValueMapTy &Cache) {
   Value *Def = findBaseOrBDV(I, Cache);
 
-  if (isKnownBaseResult(Def))
+  if (isKnownBaseResult(Def) && areBothVectorOrScalar(Def, I))
     return Def;
 
   // Here's the rough algorithm:
@@ -810,13 +826,16 @@ static Value *findBasePointer(Value *I, DefiningValueMapTy &Cache) {
     States.insert({Def, BDVState()});
     while (!Worklist.empty()) {
       Value *Current = Worklist.pop_back_val();
-      assert(!isKnownBaseResult(Current) && "why did it get added?");
+      assert(!isOriginalBaseResult(Current) && "why did it get added?");
 
       auto visitIncomingValue = [&](Value *InVal) {
         Value *Base = findBaseOrBDV(InVal, Cache);
-        if (isKnownBaseResult(Base))
+        if (isKnownBaseResult(Base) && areBothVectorOrScalar(Base, InVal))
           // Known bases won't need new instructions introduced and can be
-          // ignored safely
+          // ignored safely. However, this can only be done when InVal and Base
+          // are both scalar or both vector. Otherwise, we need to find a
+          // correct BDV for InVal, by creating an entry in the lattice
+          // (States).
           return;
         assert(isExpectedBDVType(Base) && "the only non-base values "
                "we see should be base defining values");
@@ -853,10 +872,10 @@ static Value *findBasePointer(Value *I, DefiningValueMapTy &Cache) {
 
   // Return a phi state for a base defining value.  We'll generate a new
   // base state for known bases and expect to find a cached state otherwise.
-  auto getStateForBDV = [&](Value *baseValue) {
-    if (isKnownBaseResult(baseValue))
-      return BDVState(baseValue);
-    auto I = States.find(baseValue);
+  auto GetStateForBDV = [&](Value *BaseValue, Value *Input) {
+    if (isKnownBaseResult(BaseValue) && areBothVectorOrScalar(BaseValue, Input))
+      return BDVState(BaseValue);
+    auto I = States.find(BaseValue);
     assert(I != States.end() && "lookup failed!");
     return I->second;
   };
@@ -873,13 +892,18 @@ static Value *findBasePointer(Value *I, DefiningValueMapTy &Cache) {
     // much faster.
     for (auto Pair : States) {
       Value *BDV = Pair.first;
-      assert(!isKnownBaseResult(BDV) && "why did it get added?");
+      // Only values that do not have known bases or those that have differing
+      // type (scalar versus vector) from a possible known base should be in the
+      // lattice.
+      assert((!isKnownBaseResult(BDV) ||
+             !areBothVectorOrScalar(BDV, Pair.second.getBaseValue())) &&
+                 "why did it get added?");
 
       // Given an input value for the current instruction, return a BDVState
       // instance which represents the BDV of that value.
       auto getStateForInput = [&](Value *V) mutable {
         Value *BDV = findBaseOrBDV(V, Cache);
-        return getStateForBDV(BDV);
+        return GetStateForBDV(BDV, V);
       };
 
       BDVState NewState;
@@ -926,41 +950,41 @@ static Value *findBasePointer(Value *I, DefiningValueMapTy &Cache) {
   }
 #endif
 
-  // Handle extractelement instructions and their uses.
+  // Handle all instructions that have a vector BDV, but the instruction itself
+  // is of scalar type.
   for (auto Pair : States) {
     Instruction *I = cast<Instruction>(Pair.first);
     BDVState State = Pair.second;
-    assert(!isKnownBaseResult(I) && "why did it get added?");
+    auto *BaseValue = State.getBaseValue();
+    // Only values that do not have known bases or those that have differing
+    // type (scalar versus vector) from a possible known base should be in the
+    // lattice.
+    assert((!isKnownBaseResult(I) || !areBothVectorOrScalar(I, BaseValue)) &&
+           "why did it get added?");
     assert(!State.isUnknown() && "Optimistic algorithm didn't complete!");
 
+    if (!State.isBase() || !isa<VectorType>(BaseValue->getType()))
+      continue;
     // extractelement instructions are a bit special in that we may need to
     // insert an extract even when we know an exact base for the instruction.
     // The problem is that we need to convert from a vector base to a scalar
     // base for the particular indice we're interested in.
-    if (!State.isBase() || !isa<ExtractElementInst>(I) ||
-        !isa<VectorType>(State.getBaseValue()->getType()))
-      continue;
-    auto *EE = cast<ExtractElementInst>(I);
-    // TODO: In many cases, the new instruction is just EE itself.  We should
-    // exploit this, but can't do it here since it would break the invariant
-    // about the BDV not being known to be a base.
-    auto *BaseInst = ExtractElementInst::Create(
-        State.getBaseValue(), EE->getIndexOperand(), "base_ee", EE);
-    BaseInst->setMetadata("is_base_value", MDNode::get(I->getContext(), {}));
-    States[I] = BDVState(BDVState::Base, BaseInst);
-
-    // We need to handle uses of the extractelement that have the same vector
-    // base as well but the use is a scalar type. Since we cannot reuse the
-    // same BaseInst above (may not satisfy property that base pointer should
-    // always dominate derived pointer), we conservatively set this as conflict.
-    // Setting the base value for these conflicts is handled in the next loop
-    // which traverses States.
-    for (User *U : I->users()) {
-      auto *UseI = dyn_cast<Instruction>(U);
-      if (!UseI || !States.count(UseI))
-        continue;
-      if (!isa<VectorType>(UseI->getType()) && States[UseI] == State)
-        States[UseI] = BDVState(BDVState::Conflict);
+    if (isa<ExtractElementInst>(I)) {
+      auto *EE = cast<ExtractElementInst>(I);
+      // TODO: In many cases, the new instruction is just EE itself.  We should
+      // exploit this, but can't do it here since it would break the invariant
+      // about the BDV not being known to be a base.
+      auto *BaseInst = ExtractElementInst::Create(
+          State.getBaseValue(), EE->getIndexOperand(), "base_ee", EE);
+      BaseInst->setMetadata("is_base_value", MDNode::get(I->getContext(), {}));
+      States[I] = BDVState(BDVState::Base, BaseInst);
+    } else if (!isa<VectorType>(I->getType())) {
+      // We need to handle cases that have a vector base but the instruction is
+      // a scalar type (these could be phis or selects or any instruction that
+      // are of scalar type, but the base can be a vector type).  We
+      // conservatively set this as conflict.  Setting the base value for these
+      // conflicts is handled in the next loop which traverses States.
+      States[I] = BDVState(BDVState::Conflict);
     }
   }
 
@@ -969,7 +993,11 @@ static Value *findBasePointer(Value *I, DefiningValueMapTy &Cache) {
   for (auto Pair : States) {
     Instruction *I = cast<Instruction>(Pair.first);
     BDVState State = Pair.second;
-    assert(!isKnownBaseResult(I) && "why did it get added?");
+    // Only values that do not have known bases or those that have differing
+    // type (scalar versus vector) from a possible known base should be in the
+    // lattice.
+    assert((!isKnownBaseResult(I) || !areBothVectorOrScalar(I, State.getBaseValue())) &&
+           "why did it get added?");
     assert(!State.isUnknown() && "Optimistic algorithm didn't complete!");
 
     // Since we're joining a vector and scalar base, they can never be the
@@ -1030,7 +1058,7 @@ static Value *findBasePointer(Value *I, DefiningValueMapTy &Cache) {
   auto getBaseForInput = [&](Value *Input, Instruction *InsertPt) {
     Value *BDV = findBaseOrBDV(Input, Cache);
     Value *Base = nullptr;
-    if (isKnownBaseResult(BDV)) {
+    if (isKnownBaseResult(BDV) && areBothVectorOrScalar(BDV, Input)) {
       Base = BDV;
     } else {
       // Either conflict or base.
@@ -1051,7 +1079,12 @@ static Value *findBasePointer(Value *I, DefiningValueMapTy &Cache) {
     Instruction *BDV = cast<Instruction>(Pair.first);
     BDVState State = Pair.second;
 
-    assert(!isKnownBaseResult(BDV) && "why did it get added?");
+    // Only values that do not have known bases or those that have differing
+    // type (scalar versus vector) from a possible known base should be in the
+    // lattice.
+    assert((!isKnownBaseResult(BDV) ||
+            !areBothVectorOrScalar(BDV, State.getBaseValue())) &&
+           "why did it get added?");
     assert(!State.isUnknown() && "Optimistic algorithm didn't complete!");
     if (!State.isConflict())
       continue;
@@ -1141,7 +1174,11 @@ static Value *findBasePointer(Value *I, DefiningValueMapTy &Cache) {
     auto *BDV = Pair.first;
     Value *Base = Pair.second.getBaseValue();
     assert(BDV && Base);
-    assert(!isKnownBaseResult(BDV) && "why did it get added?");
+    // Only values that do not have known bases or those that have differing
+    // type (scalar versus vector) from a possible known base should be in the
+    // lattice.
+    assert((!isKnownBaseResult(BDV) || !areBothVectorOrScalar(BDV, Base)) &&
+           "why did it get added?");
 
     LLVM_DEBUG(
         dbgs() << "Updating base value cache"
index 34af81c..a4290ef 100644 (file)
@@ -192,5 +192,75 @@ latch:                                              ; preds = %bb25, %bb7
   br label %header
 }
 
+; Uses of extractelement that are of scalar type should not have the BDV
+; incorrectly identified as a vector type.
+define void @widget() gc "statepoint-example" {
+; CHECK-LABEL: @widget(
+; CHECK-NEXT:  bb6:
+; CHECK-NEXT:    [[BASE_EE:%.*]] = extractelement <2 x i8 addrspace(1)*> zeroinitializer, i32 1, !is_base_value !0
+; CHECK-NEXT:    [[TMP:%.*]] = extractelement <2 x i8 addrspace(1)*> undef, i32 1
+; CHECK-NEXT:    br i1 undef, label [[BB7:%.*]], label [[BB9:%.*]]
+; CHECK:       bb7:
+; CHECK-NEXT:    [[TMP8:%.*]] = getelementptr inbounds i8, i8 addrspace(1)* [[TMP]], i64 12
+; CHECK-NEXT:    br label [[BB11:%.*]]
+; CHECK:       bb9:
+; CHECK-NEXT:    [[TMP10:%.*]] = getelementptr inbounds i8, i8 addrspace(1)* [[TMP]], i64 12
+; CHECK-NEXT:    br i1 undef, label [[BB11]], label [[BB15:%.*]]
+; CHECK:       bb11:
+; CHECK-NEXT:    [[TMP12_BASE:%.*]] = phi i8 addrspace(1)* [ [[BASE_EE]], [[BB7]] ], [ [[BASE_EE]], [[BB9]] ], !is_base_value !0
+; CHECK-NEXT:    [[TMP12:%.*]] = phi i8 addrspace(1)* [ [[TMP8]], [[BB7]] ], [ [[TMP10]], [[BB9]] ]
+; CHECK-NEXT:    [[STATEPOINT_TOKEN:%.*]] = call token (i64, i32, void ()*, i32, i32, ...) @llvm.experimental.gc.statepoint.p0f_isVoidf(i64 2882400000, i32 0, void ()* @snork, i32 0, i32 0, i32 0, i32 1, i32 undef, i8 addrspace(1)* [[TMP12_BASE]], i8 addrspace(1)* [[TMP12]])
+; CHECK-NEXT:    [[TMP12_BASE_RELOCATED:%.*]] = call coldcc i8 addrspace(1)* @llvm.experimental.gc.relocate.p1i8(token [[STATEPOINT_TOKEN]], i32 8, i32 8)
+; CHECK-NEXT:    [[TMP12_RELOCATED:%.*]] = call coldcc i8 addrspace(1)* @llvm.experimental.gc.relocate.p1i8(token [[STATEPOINT_TOKEN]], i32 8, i32 9)
+; CHECK-NEXT:    br label [[BB15]]
+; CHECK:       bb15:
+; CHECK-NEXT:    [[TMP16_BASE:%.*]] = phi i8 addrspace(1)* [ [[BASE_EE]], [[BB9]] ], [ [[TMP12_BASE_RELOCATED]], [[BB11]] ], !is_base_value !0
+; CHECK-NEXT:    [[TMP16:%.*]] = phi i8 addrspace(1)* [ [[TMP10]], [[BB9]] ], [ [[TMP12_RELOCATED]], [[BB11]] ]
+; CHECK-NEXT:    br i1 undef, label [[BB17:%.*]], label [[BB20:%.*]]
+; CHECK:       bb17:
+; CHECK-NEXT:    [[STATEPOINT_TOKEN1:%.*]] = call token (i64, i32, void ()*, i32, i32, ...) @llvm.experimental.gc.statepoint.p0f_isVoidf(i64 2882400000, i32 0, void ()* @snork, i32 0, i32 0, i32 0, i32 1, i32 undef, i8 addrspace(1)* [[TMP16_BASE]], i8 addrspace(1)* [[TMP16]])
+; CHECK-NEXT:    [[TMP16_BASE_RELOCATED:%.*]] = call coldcc i8 addrspace(1)* @llvm.experimental.gc.relocate.p1i8(token [[STATEPOINT_TOKEN1]], i32 8, i32 8)
+; CHECK-NEXT:    [[TMP16_RELOCATED:%.*]] = call coldcc i8 addrspace(1)* @llvm.experimental.gc.relocate.p1i8(token [[STATEPOINT_TOKEN1]], i32 8, i32 9)
+; CHECK-NEXT:    br label [[BB20]]
+; CHECK:       bb20:
+; CHECK-NEXT:    [[DOT05:%.*]] = phi i8 addrspace(1)* [ [[TMP16_BASE_RELOCATED]], [[BB17]] ], [ [[TMP16_BASE]], [[BB15]] ]
+; CHECK-NEXT:    [[DOT0:%.*]] = phi i8 addrspace(1)* [ [[TMP16_RELOCATED]], [[BB17]] ], [ [[TMP16]], [[BB15]] ]
+; CHECK-NEXT:    [[STATEPOINT_TOKEN2:%.*]] = call token (i64, i32, void (i8 addrspace(1)*)*, i32, i32, ...) @llvm.experimental.gc.statepoint.p0f_isVoidp1i8f(i64 2882400000, i32 0, void (i8 addrspace(1)*)* @foo, i32 1, i32 0, i8 addrspace(1)* [[DOT0]], i32 0, i32 0, i8 addrspace(1)* [[DOT05]], i8 addrspace(1)* [[DOT0]])
+; CHECK-NEXT:    [[TMP16_BASE_RELOCATED3:%.*]] = call coldcc i8 addrspace(1)* @llvm.experimental.gc.relocate.p1i8(token [[STATEPOINT_TOKEN2]], i32 8, i32 8)
+; CHECK-NEXT:    [[TMP16_RELOCATED4:%.*]] = call coldcc i8 addrspace(1)* @llvm.experimental.gc.relocate.p1i8(token [[STATEPOINT_TOKEN2]], i32 8, i32 9)
+; CHECK-NEXT:    ret void
+;
+bb6:                                              ; preds = %bb3
+  %tmp = extractelement <2 x i8 addrspace(1)*> undef, i32 1
+  br i1 undef, label %bb7, label %bb9
+
+bb7:                                              ; preds = %bb6
+  %tmp8 = getelementptr inbounds i8, i8 addrspace(1)* %tmp, i64 12
+  br label %bb11
+
+bb9:                                              ; preds = %bb6, %bb6
+  %tmp10 = getelementptr inbounds i8, i8 addrspace(1)* %tmp, i64 12
+  br i1 undef, label %bb11, label %bb15
+
+bb11:                                             ; preds = %bb9, %bb7
+  %tmp12 = phi i8 addrspace(1)* [ %tmp8, %bb7 ], [ %tmp10, %bb9 ]
+  call void @snork() [ "deopt"(i32 undef) ]
+  br label %bb15
+
+bb15:                                             ; preds = %bb11, %bb9, %bb9
+  %tmp16 = phi i8 addrspace(1)* [ %tmp10, %bb9 ], [ %tmp12, %bb11 ]
+  br i1 undef, label %bb17, label %bb20
+
+bb17:                                             ; preds = %bb15
+  call void @snork() [ "deopt"(i32 undef) ]
+  br label %bb20
+
+bb20:                                             ; preds = %bb17, %bb15, %bb15
+  call void @foo(i8 addrspace(1)* %tmp16)
+  ret void
+}
+
+declare void @snork()
+declare void @foo(i8 addrspace(1)*)
 declare void @spam()
 declare <2 x i8 addrspace(1)*> @baz()