[NewGVN] Track simplification dependencies for phi-of-ops.
authorFlorian Hahn <flo@fhahn.com>
Fri, 23 Apr 2021 08:27:06 +0000 (09:27 +0100)
committerFlorian Hahn <flo@fhahn.com>
Fri, 23 Apr 2021 08:48:38 +0000 (09:48 +0100)
If we are using a simplified value, we need to add an extra
dependency this value , because changes to the class of the
simplified value may require us to invalidate any decision based on
that value.

This is done by adding such values as additional users, however the
current code does not excludes temporary instructions.

At the moment, this means that we miss those dependencies for
phi-of-ops, because they are temporary instructions at this point. We
instead need to add the extra dependencies to the root instruction of
the phi-of-ops.

This patch pushes the responsibility of adding extra users to the
callers of createExpression & performSymbolicEvaluation. At those
points, it is clearer which real instruction to pick.

Alternatively we could either pass the 'real' instruction as additional
argument or use another map, but I think the approach in the patch makes
things a bit easier to follow.

Fixes PR35074.

Reviewed By: asbirlea

Differential Revision: https://reviews.llvm.org/D99987

llvm/lib/Transforms/Scalar/NewGVN.cpp
llvm/test/Transforms/NewGVN/phi-of-ops-simplification-dependencies.ll [new file with mode: 0644]

index 98c537c..46254df 100644 (file)
@@ -668,8 +668,23 @@ public:
   bool runGVN();
 
 private:
+  /// Helper struct return a Expression with an optional extra dependency.
+  struct ExprResult {
+    const Expression *Expr;
+    Value *ExtraDep;
+
+    ~ExprResult() { assert(!ExtraDep && "unhandled ExtraDep"); }
+
+    operator bool() const { return Expr; }
+
+    static ExprResult none() { return {nullptr, nullptr}; }
+    static ExprResult some(const Expression *Expr, Value *ExtraDep = nullptr) {
+      return {Expr, ExtraDep};
+    }
+  };
+
   // Expression handling.
-  const Expression *createExpression(Instruction *) const;
+  ExprResult createExpression(Instruction *) const;
   const Expression *createBinaryExpression(unsigned, Type *, Value *, Value *,
                                            Instruction *) const;
 
@@ -742,10 +757,9 @@ private:
   void valueNumberInstruction(Instruction *);
 
   // Symbolic evaluation.
-  const Expression *checkSimplificationResults(Expression *, Instruction *,
-                                               Value *) const;
-  const Expression *performSymbolicEvaluation(Value *,
-                                              SmallPtrSetImpl<Value *> &) const;
+  ExprResult checkExprResults(Expression *, Instruction *, Value *) const;
+  ExprResult performSymbolicEvaluation(Value *,
+                                       SmallPtrSetImpl<Value *> &) const;
   const Expression *performSymbolicLoadCoercion(Type *, Value *, LoadInst *,
                                                 Instruction *,
                                                 MemoryAccess *) const;
@@ -757,7 +771,7 @@ private:
                                                  Instruction *I,
                                                  BasicBlock *PHIBlock) const;
   const Expression *performSymbolicAggrValueEvaluation(Instruction *) const;
-  const Expression *performSymbolicCmpEvaluation(Instruction *) const;
+  ExprResult performSymbolicCmpEvaluation(Instruction *) const;
   const Expression *performSymbolicPredicateInfoEvaluation(Instruction *) const;
 
   // Congruence finding.
@@ -814,6 +828,7 @@ private:
   void addPredicateUsers(const PredicateBase *, Instruction *) const;
   void addMemoryUsers(const MemoryAccess *To, MemoryAccess *U) const;
   void addAdditionalUsers(Value *To, Value *User) const;
+  void addAdditionalUsers(ExprResult &Res, Value *User) const;
 
   // Main loop of value numbering
   void iterateTouchedInstructions();
@@ -1052,19 +1067,21 @@ const Expression *NewGVN::createBinaryExpression(unsigned Opcode, Type *T,
   E->op_push_back(lookupOperandLeader(Arg2));
 
   Value *V = SimplifyBinOp(Opcode, E->getOperand(0), E->getOperand(1), SQ);
-  if (const Expression *SimplifiedE = checkSimplificationResults(E, I, V))
-    return SimplifiedE;
+  if (auto Simplified = checkExprResults(E, I, V)) {
+    addAdditionalUsers(Simplified, I);
+    return Simplified.Expr;
+  }
   return E;
 }
 
 // Take a Value returned by simplification of Expression E/Instruction
 // I, and see if it resulted in a simpler expression. If so, return
 // that expression.
-const Expression *NewGVN::checkSimplificationResults(Expression *E,
-                                                     Instruction *I,
-                                                     Value *V) const {
+NewGVN::ExprResult NewGVN::checkExprResults(Expression *E, Instruction *I,
+                                            Value *V) const {
   if (!V)
-    return nullptr;
+    return ExprResult::none();
+
   if (auto *C = dyn_cast<Constant>(V)) {
     if (I)
       LLVM_DEBUG(dbgs() << "Simplified " << *I << " to "
@@ -1073,52 +1090,37 @@ const Expression *NewGVN::checkSimplificationResults(Expression *E,
     assert(isa<BasicExpression>(E) &&
            "We should always have had a basic expression here");
     deleteExpression(E);
-    return createConstantExpression(C);
+    return ExprResult::some(createConstantExpression(C));
   } else if (isa<Argument>(V) || isa<GlobalVariable>(V)) {
     if (I)
       LLVM_DEBUG(dbgs() << "Simplified " << *I << " to "
                         << " variable " << *V << "\n");
     deleteExpression(E);
-    return createVariableExpression(V);
+    return ExprResult::some(createVariableExpression(V));
   }
 
   CongruenceClass *CC = ValueToClass.lookup(V);
   if (CC) {
     if (CC->getLeader() && CC->getLeader() != I) {
-      // If we simplified to something else, we need to communicate
-      // that we're users of the value we simplified to.
-      if (I != V) {
-        // Don't add temporary instructions to the user lists.
-        if (!AllTempInstructions.count(I))
-          addAdditionalUsers(V, I);
-      }
-      return createVariableOrConstant(CC->getLeader());
+      return ExprResult::some(createVariableOrConstant(CC->getLeader()), V);
     }
     if (CC->getDefiningExpr()) {
-      // If we simplified to something else, we need to communicate
-      // that we're users of the value we simplified to.
-      if (I != V) {
-        // Don't add temporary instructions to the user lists.
-        if (!AllTempInstructions.count(I))
-          addAdditionalUsers(V, I);
-      }
-
       if (I)
         LLVM_DEBUG(dbgs() << "Simplified " << *I << " to "
                           << " expression " << *CC->getDefiningExpr() << "\n");
       NumGVNOpsSimplified++;
       deleteExpression(E);
-      return CC->getDefiningExpr();
+      return ExprResult::some(CC->getDefiningExpr(), V);
     }
   }
 
-  return nullptr;
+  return ExprResult::none();
 }
 
 // Create a value expression from the instruction I, replacing operands with
 // their leaders.
 
-const Expression *NewGVN::createExpression(Instruction *I) const {
+NewGVN::ExprResult NewGVN::createExpression(Instruction *I) const {
   auto *E = new (ExpressionAllocator) BasicExpression(I->getNumOperands());
 
   bool AllConstant = setBasicExpressionInfo(I, E);
@@ -1149,8 +1151,8 @@ const Expression *NewGVN::createExpression(Instruction *I) const {
             E->getOperand(1)->getType() == I->getOperand(1)->getType()));
     Value *V =
         SimplifyCmpInst(Predicate, E->getOperand(0), E->getOperand(1), SQ);
-    if (const Expression *SimplifiedE = checkSimplificationResults(E, I, V))
-      return SimplifiedE;
+    if (auto Simplified = checkExprResults(E, I, V))
+      return Simplified;
   } else if (isa<SelectInst>(I)) {
     if (isa<Constant>(E->getOperand(0)) ||
         E->getOperand(1) == E->getOperand(2)) {
@@ -1158,24 +1160,24 @@ const Expression *NewGVN::createExpression(Instruction *I) const {
              E->getOperand(2)->getType() == I->getOperand(2)->getType());
       Value *V = SimplifySelectInst(E->getOperand(0), E->getOperand(1),
                                     E->getOperand(2), SQ);
-      if (const Expression *SimplifiedE = checkSimplificationResults(E, I, V))
-        return SimplifiedE;
+      if (auto Simplified = checkExprResults(E, I, V))
+        return Simplified;
     }
   } else if (I->isBinaryOp()) {
     Value *V =
         SimplifyBinOp(E->getOpcode(), E->getOperand(0), E->getOperand(1), SQ);
-    if (const Expression *SimplifiedE = checkSimplificationResults(E, I, V))
-      return SimplifiedE;
+    if (auto Simplified = checkExprResults(E, I, V))
+      return Simplified;
   } else if (auto *CI = dyn_cast<CastInst>(I)) {
     Value *V =
         SimplifyCastInst(CI->getOpcode(), E->getOperand(0), CI->getType(), SQ);
-    if (const Expression *SimplifiedE = checkSimplificationResults(E, I, V))
-      return SimplifiedE;
+    if (auto Simplified = checkExprResults(E, I, V))
+      return Simplified;
   } else if (isa<GetElementPtrInst>(I)) {
     Value *V = SimplifyGEPInst(
         E->getType(), ArrayRef<Value *>(E->op_begin(), E->op_end()), SQ);
-    if (const Expression *SimplifiedE = checkSimplificationResults(E, I, V))
-      return SimplifiedE;
+    if (auto Simplified = checkExprResults(E, I, V))
+      return Simplified;
   } else if (AllConstant) {
     // We don't bother trying to simplify unless all of the operands
     // were constant.
@@ -1189,10 +1191,10 @@ const Expression *NewGVN::createExpression(Instruction *I) const {
       C.emplace_back(cast<Constant>(Arg));
 
     if (Value *V = ConstantFoldInstOperands(I, C, DL, TLI))
-      if (const Expression *SimplifiedE = checkSimplificationResults(E, I, V))
-        return SimplifiedE;
+      if (auto Simplified = checkExprResults(E, I, V))
+        return Simplified;
   }
-  return E;
+  return ExprResult::some(E);
 }
 
 const AggregateValueExpression *
@@ -1778,7 +1780,7 @@ NewGVN::performSymbolicAggrValueEvaluation(Instruction *I) const {
   return createAggregateValueExpression(I);
 }
 
-const Expression *NewGVN::performSymbolicCmpEvaluation(Instruction *I) const {
+NewGVN::ExprResult NewGVN::performSymbolicCmpEvaluation(Instruction *I) const {
   assert(isa<CmpInst>(I) && "Expected a cmp instruction.");
 
   auto *CI = cast<CmpInst>(I);
@@ -1798,14 +1800,17 @@ const Expression *NewGVN::performSymbolicCmpEvaluation(Instruction *I) const {
   // of an assume.
   auto *CmpPI = PredInfo->getPredicateInfoFor(I);
   if (dyn_cast_or_null<PredicateAssume>(CmpPI))
-    return createConstantExpression(ConstantInt::getTrue(CI->getType()));
+    return ExprResult::some(
+        createConstantExpression(ConstantInt::getTrue(CI->getType())));
 
   if (Op0 == Op1) {
     // This condition does not depend on predicates, no need to add users
     if (CI->isTrueWhenEqual())
-      return createConstantExpression(ConstantInt::getTrue(CI->getType()));
+      return ExprResult::some(
+          createConstantExpression(ConstantInt::getTrue(CI->getType())));
     else if (CI->isFalseWhenEqual())
-      return createConstantExpression(ConstantInt::getFalse(CI->getType()));
+      return ExprResult::some(
+          createConstantExpression(ConstantInt::getFalse(CI->getType())));
   }
 
   // NOTE: Because we are comparing both operands here and below, and using
@@ -1865,15 +1870,15 @@ const Expression *NewGVN::performSymbolicCmpEvaluation(Instruction *I) const {
           if (CmpInst::isImpliedTrueByMatchingCmp(BranchPredicate,
                                                   OurPredicate)) {
             addPredicateUsers(PI, I);
-            return createConstantExpression(
-                ConstantInt::getTrue(CI->getType()));
+            return ExprResult::some(
+                createConstantExpression(ConstantInt::getTrue(CI->getType())));
           }
 
           if (CmpInst::isImpliedFalseByMatchingCmp(BranchPredicate,
                                                    OurPredicate)) {
             addPredicateUsers(PI, I);
-            return createConstantExpression(
-                ConstantInt::getFalse(CI->getType()));
+            return ExprResult::some(
+                createConstantExpression(ConstantInt::getFalse(CI->getType())));
           }
         } else {
           // Just handle the ne and eq cases, where if we have the same
@@ -1881,14 +1886,14 @@ const Expression *NewGVN::performSymbolicCmpEvaluation(Instruction *I) const {
           if (BranchPredicate == OurPredicate) {
             addPredicateUsers(PI, I);
             // Same predicate, same ops,we know it was false, so this is false.
-            return createConstantExpression(
-                ConstantInt::getFalse(CI->getType()));
+            return ExprResult::some(
+                createConstantExpression(ConstantInt::getFalse(CI->getType())));
           } else if (BranchPredicate ==
                      CmpInst::getInversePredicate(OurPredicate)) {
             addPredicateUsers(PI, I);
             // Inverse predicate, we know the other was false, so this is true.
-            return createConstantExpression(
-                ConstantInt::getTrue(CI->getType()));
+            return ExprResult::some(
+                createConstantExpression(ConstantInt::getTrue(CI->getType())));
           }
         }
       }
@@ -1899,9 +1904,10 @@ const Expression *NewGVN::performSymbolicCmpEvaluation(Instruction *I) const {
 }
 
 // Substitute and symbolize the value before value numbering.
-const Expression *
+NewGVN::ExprResult
 NewGVN::performSymbolicEvaluation(Value *V,
                                   SmallPtrSetImpl<Value *> &Visited) const {
+
   const Expression *E = nullptr;
   if (auto *C = dyn_cast<Constant>(V))
     E = createConstantExpression(C);
@@ -1937,11 +1943,11 @@ NewGVN::performSymbolicEvaluation(Value *V,
       break;
     case Instruction::BitCast:
     case Instruction::AddrSpaceCast:
-      E = createExpression(I);
+      return createExpression(I);
       break;
     case Instruction::ICmp:
     case Instruction::FCmp:
-      E = performSymbolicCmpEvaluation(I);
+      return performSymbolicCmpEvaluation(I);
       break;
     case Instruction::FNeg:
     case Instruction::Add:
@@ -1977,16 +1983,16 @@ NewGVN::performSymbolicEvaluation(Value *V,
     case Instruction::ExtractElement:
     case Instruction::InsertElement:
     case Instruction::GetElementPtr:
-      E = createExpression(I);
+      return createExpression(I);
       break;
     case Instruction::ShuffleVector:
       // FIXME: Add support for shufflevector to createExpression.
-      return nullptr;
+      return ExprResult::none();
     default:
-      return nullptr;
+      return ExprResult::none();
     }
   }
-  return E;
+  return ExprResult::some(E);
 }
 
 // Look up a container of values/instructions in a map, and touch all the
@@ -2007,6 +2013,12 @@ void NewGVN::addAdditionalUsers(Value *To, Value *User) const {
     AdditionalUsers[To].insert(User);
 }
 
+void NewGVN::addAdditionalUsers(ExprResult &Res, Value *User) const {
+  if (Res.ExtraDep && Res.ExtraDep != User)
+    addAdditionalUsers(Res.ExtraDep, User);
+  Res.ExtraDep = nullptr;
+}
+
 void NewGVN::markUsersTouched(Value *V) {
   // Now mark the users as touched.
   for (auto *User : V->users()) {
@@ -2414,9 +2426,14 @@ void NewGVN::processOutgoingEdges(Instruction *TI, BasicBlock *B) {
     Value *CondEvaluated = findConditionEquivalence(Cond);
     if (!CondEvaluated) {
       if (auto *I = dyn_cast<Instruction>(Cond)) {
-        const Expression *E = createExpression(I);
-        if (const auto *CE = dyn_cast<ConstantExpression>(E)) {
+        auto Res = createExpression(I);
+        if (const auto *CE = dyn_cast<ConstantExpression>(Res.Expr)) {
           CondEvaluated = CE->getConstantValue();
+          addAdditionalUsers(Res, I);
+        } else {
+          // Did not use simplification result, no need to add the extra
+          // dependency.
+          Res.ExtraDep = nullptr;
         }
       } else if (isa<ConstantInt>(Cond)) {
         CondEvaluated = Cond;
@@ -2600,7 +2617,9 @@ Value *NewGVN::findLeaderForInst(Instruction *TransInst,
   TempToBlock.insert({TransInst, PredBB});
   InstrDFS.insert({TransInst, IDFSNum});
 
-  const Expression *E = performSymbolicEvaluation(TransInst, Visited);
+  auto Res = performSymbolicEvaluation(TransInst, Visited);
+  const Expression *E = Res.Expr;
+  addAdditionalUsers(Res, OrigInst);
   InstrDFS.erase(TransInst);
   AllTempInstructions.erase(TransInst);
   TempToBlock.erase(TransInst);
@@ -3027,7 +3046,10 @@ void NewGVN::valueNumberInstruction(Instruction *I) {
     const Expression *Symbolized = nullptr;
     SmallPtrSet<Value *, 2> Visited;
     if (DebugCounter::shouldExecute(VNCounter)) {
-      Symbolized = performSymbolicEvaluation(I, Visited);
+      auto Res = performSymbolicEvaluation(I, Visited);
+      Symbolized = Res.Expr;
+      addAdditionalUsers(Res, I);
+
       // Make a phi of ops if necessary
       if (Symbolized && !isa<ConstantExpression>(Symbolized) &&
           !isa<VariableExpression>(Symbolized) && PHINodeUses.count(I)) {
diff --git a/llvm/test/Transforms/NewGVN/phi-of-ops-simplification-dependencies.ll b/llvm/test/Transforms/NewGVN/phi-of-ops-simplification-dependencies.ll
new file mode 100644 (file)
index 0000000..c340930
--- /dev/null
@@ -0,0 +1,118 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
+; RUN: opt -newgvn -S %s | FileCheck %s
+
+declare void @use.i16(i16*)
+declare void @use.i32(i32)
+
+; Test cases from PR35074, where the simplification dependencies need to be
+; tracked for phi-of-ops root instructions.
+
+define void @test1() {
+; CHECK-LABEL: @test1(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    br label [[FOR_COND:%.*]]
+; CHECK:       for.cond:
+; CHECK-NEXT:    [[PHIOFOPS:%.*]] = phi i32 [ 0, [[ENTRY:%.*]] ], [ [[Y_0:%.*]], [[FOR_INC6:%.*]] ]
+; CHECK-NEXT:    [[Y_0]] = phi i32 [ 1, [[ENTRY]] ], [ [[INC7:%.*]], [[FOR_INC6]] ]
+; CHECK-NEXT:    br i1 undef, label [[FOR_INC6]], label [[FOR_BODY_LR_PH:%.*]]
+; CHECK:       for.body.lr.ph:
+; CHECK-NEXT:    br label [[FOR_BODY4:%.*]]
+; CHECK:       for.body4:
+; CHECK-NEXT:    [[CMP:%.*]] = icmp ugt i32 [[PHIOFOPS]], [[Y_0]]
+; CHECK-NEXT:    br i1 [[CMP]], label [[FOR_END:%.*]], label [[FOR_BODY4_1:%.*]]
+; CHECK:       for.end:
+; CHECK-NEXT:    ret void
+; CHECK:       for.inc6:
+; CHECK-NEXT:    [[INC7]] = add nuw nsw i32 [[Y_0]], 1
+; CHECK-NEXT:    br label [[FOR_COND]]
+; CHECK:       for.body4.1:
+; CHECK-NEXT:    [[INC_1:%.*]] = add nuw nsw i32 [[Y_0]], 1
+; CHECK-NEXT:    tail call void @use.i32(i32 [[INC_1]])
+; CHECK-NEXT:    br label [[FOR_END]]
+;
+entry:
+  br label %for.cond
+
+for.cond:                                         ; preds = %for.inc6, %entry
+  %y.0 = phi i32 [ 1, %entry ], [ %inc7, %for.inc6 ]
+  br i1 undef, label %for.inc6, label %for.body.lr.ph
+
+for.body.lr.ph:                                   ; preds = %for.cond
+  %sub = add nsw i32 %y.0, -1
+  br label %for.body4
+
+for.body4:                                        ; preds = %for.body.lr.ph
+  %cmp = icmp ugt i32 %sub, %y.0
+  br i1 %cmp, label %for.end, label %for.body4.1
+
+for.end:                                          ; preds = %for.body4.1, %for.body4
+  ret void
+
+for.inc6:                                         ; preds = %for.cond
+  %inc7 = add nuw nsw i32 %y.0, 1
+  br label %for.cond
+
+for.body4.1:                                      ; preds = %for.body4
+  %inc.1 = add nuw nsw i32 %y.0, 1
+  tail call void @use.i32(i32 %inc.1)
+  br label %for.end
+}
+
+define void @test2(i1 %c, i16* %ptr, i64 %N) {
+; CHECK-LABEL: @test2(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    br label [[HEADER:%.*]]
+; CHECK:       header:
+; CHECK-NEXT:    [[PHIOFOPS:%.*]] = phi i64 [ -1, [[ENTRY:%.*]] ], [ [[IV:%.*]], [[LATCH:%.*]] ]
+; CHECK-NEXT:    [[IV]] = phi i64 [ [[IV_NEXT:%.*]], [[LATCH]] ], [ 0, [[ENTRY]] ]
+; CHECK-NEXT:    br i1 [[C:%.*]], label [[IF_THEN:%.*]], label [[IF_ELSE:%.*]]
+; CHECK:       if.then:
+; CHECK-NEXT:    [[CMP1:%.*]] = icmp eq i64 [[IV]], 0
+; CHECK-NEXT:    br i1 [[CMP1]], label [[LATCH]], label [[LOR_RHS:%.*]]
+; CHECK:       lor.rhs:
+; CHECK-NEXT:    [[IV_ADD_1:%.*]] = add i64 [[IV]], 1
+; CHECK-NEXT:    [[IDX_1:%.*]] = getelementptr inbounds i16, i16* [[PTR:%.*]], i64 [[IV_ADD_1]]
+; CHECK-NEXT:    call void @use.i16(i16* [[IDX_1]])
+; CHECK-NEXT:    ret void
+; CHECK:       if.else:
+; CHECK-NEXT:    [[IDX_2:%.*]] = getelementptr inbounds i16, i16* [[PTR]], i64 [[PHIOFOPS]]
+; CHECK-NEXT:    call void @use.i16(i16* [[IDX_2]])
+; CHECK-NEXT:    br label [[LATCH]]
+; CHECK:       latch:
+; CHECK-NEXT:    [[IV_NEXT]] = add i64 [[IV]], 1
+; CHECK-NEXT:    [[EC:%.*]] = icmp ugt i64 [[IV_NEXT]], [[N:%.*]]
+; CHECK-NEXT:    br i1 [[EC]], label [[HEADER]], label [[EXIT:%.*]]
+; CHECK:       exit:
+; CHECK-NEXT:    ret void
+;
+entry:
+  br label %header
+
+header:                                         ; preds = %for.inc, %entry
+  %iv = phi i64 [ %iv.next, %latch ], [ 0, %entry ]
+  br i1 %c, label %if.then, label %if.else
+
+if.then:
+  %cmp1 = icmp eq i64 %iv, 0
+  br i1 %cmp1, label %latch, label %lor.rhs
+
+lor.rhs:                                          ; preds = %if.then
+  %iv.add.1 = add i64 %iv, 1
+  %idx.1 = getelementptr inbounds i16, i16* %ptr, i64 %iv.add.1
+  call void @use.i16(i16* %idx.1)
+  ret void
+
+if.else:
+  %iv.sub.1 = add i64 %iv, -1
+  %idx.2 = getelementptr inbounds i16, i16* %ptr, i64 %iv.sub.1
+  call void @use.i16(i16* %idx.2)
+  br label %latch
+
+latch:
+  %iv.next = add i64 %iv, 1
+  %ec = icmp ugt i64 %iv.next, %N
+  br i1 %ec, label %header, label %exit
+
+exit:
+  ret void
+}