Fix division range estimation error in simplifier (#6244)
authorKrzysztof Parzyszek <kparzysz@quicinc.com>
Tue, 11 Aug 2020 15:35:06 +0000 (10:35 -0500)
committerGitHub <noreply@github.com>
Tue, 11 Aug 2020 15:35:06 +0000 (08:35 -0700)
Division a/b assumes maximum values when b is close to 0. Account
for that when estimating the range for a/b when 0 belongs to the
estimated range for b.

Assume that a division by zero cannot happen in a valid program,
so in such cases treat the range for b as a union
  [b.min_value, -1] u [1, b.max_value]

src/arith/const_int_bound.cc
tests/python/unittest/test_arith_const_int_bound.py

index be830d3..fbb52a9 100644 (file)
@@ -205,17 +205,14 @@ class ConstIntBoundAnalyzer::Impl
   Entry VisitExpr_(const MulNode* op) final {
     Entry a = VisitExpr(op->a);
     Entry b = VisitExpr(op->b);
-    return BinaryOpBoundry(a, b, InfAwareMul);
+    return BinaryOpBoundary(a, b, InfAwareMul);
   }
 
   Entry VisitExpr_(const DivNode* op) final {
     Entry a = VisitExpr(op->a);
     Entry b = VisitExpr(op->b);
     CHECK(!b.is_const(0)) << "divide by zero";
-    // assume no division by 0
-    if (b.min_value == 0) b.min_value = 1;
-    if (b.max_value == 0) b.max_value = -1;
-    return BinaryOpBoundry(a, b, InfAwareDiv);
+    return HandleDivision(a, b, op->dtype, InfAwareDiv);
   }
 
   Entry VisitExpr_(const ModNode* op) final {
@@ -244,10 +241,7 @@ class ConstIntBoundAnalyzer::Impl
     Entry a = VisitExpr(op->a);
     Entry b = VisitExpr(op->b);
     CHECK(!b.is_const(0)) << "floordiv by zero";
-    // assume no division by 0
-    if (b.min_value == 0) b.min_value = 1;
-    if (b.max_value == 0) b.max_value = -1;
-    return BinaryOpBoundry(a, b, InfAwareFloorDiv);
+    return HandleDivision(a, b, op->dtype, InfAwareFloorDiv);
   }
 
   Entry VisitExpr_(const FloorModNode* op) final {
@@ -331,7 +325,7 @@ class ConstIntBoundAnalyzer::Impl
   Entry VisitRightShift(const CallNode* op) {
     Entry a = VisitExpr(op->args[0]);
     Entry b = VisitExpr(op->args[1]);
-    return BinaryOpBoundry(a, b, InfAwareRightShift);
+    return BinaryOpBoundary(a, b, InfAwareRightShift);
   }
 
   Entry VisitBitwiseAnd(const CallNode* op) {
@@ -380,14 +374,14 @@ class ConstIntBoundAnalyzer::Impl
   // internal helper functions
   /*!
    * \brief Get boundary of binary op who are monotonic wrt to one argument.
-   * \param param a The entry of the left operand.
-   * \param param a The entry of the right operand.
+   * \param a The entry of the left operand.
+   * \param b The entry of the right operand.
    * \param op The operator.
    * \tparam F the operator function type.
    * \return The result.
    */
   template <typename F>
-  static Entry BinaryOpBoundry(Entry a, Entry b, const F& op) {
+  static Entry BinaryOpBoundary(Entry a, Entry b, const F& op) {
     Entry ret;
     // The boundary point must be shihft of the original boundary.
     int64_t v1 = op(a.min_value, b.min_value);
@@ -399,6 +393,38 @@ class ConstIntBoundAnalyzer::Impl
     return ret;
   }
   /*!
+   * \brief Get value boundaries of division (e.g. Div or FloorDiv).
+   * \param a The entry of the left operand.
+   * \param b The entry of the right operand.
+   * \param dt The data type of the division operator.
+   * \param op The division operator.
+   * \tparam F the operator function type.
+   * \return The result.
+   */
+  template <typename F>
+  static Entry HandleDivision(Entry a, Entry b, DataType dt, const F& op) {
+    // Here we have a / b.
+    // The largest value of the division will be for the smallest (with
+    // respect to the absolute value) value of b. If the range of b starts
+    // at a negative value and ends at a positive one, narrow it down to
+    // be closer to 0, because BinaryOpBoundary only checks end-points of
+    // the domain ranges.
+
+    // If the range of b contains 0, then some infinity will be involved
+    if (b.min_value <= 0 && 0 <= b.max_value) {
+      Entry b_neg = b.min_value < 0 ? MakeBound(b.min_value, -1) : Everything(dt);
+      Entry b_pos = b.max_value > 0 ? MakeBound(1, b.max_value) : Everything(dt);
+
+      Entry e_neg = BinaryOpBoundary(a, b_neg, op);
+      Entry e_pos = BinaryOpBoundary(a, b_pos, op);
+
+      return MakeBound(std::min(e_neg.min_value, e_pos.min_value),
+                       std::max(e_neg.max_value, e_pos.max_value));
+    }
+    // If the range of b does not have 0, use BinaryOpBoundary.
+    return BinaryOpBoundary(a, b, op);
+  }
+  /*!
    * \brief Compute x + y, aware of inf.
    * \param x The left operand.
    * \param y The right operand.
index c5794cd..9ead0d4 100644 (file)
@@ -122,6 +122,12 @@ def test_truncdiv_bound():
     assert bd.min_value == bd.NEG_INF
     assert bd.max_value == bd.POS_INF
 
+    analyzer.update(x, tvm.arith.ConstIntBound(-9, 4), override=True)
+    analyzer.update(y, tvm.arith.ConstIntBound(-4, 12), override=True)
+    bd = analyzer.const_int_bound(tdiv(x, y))
+    assert bd.min_value == -9
+    assert bd.max_value == 9
+
 
 def test_truncmod_bound():
     analyzer = tvm.arith.Analyzer()
@@ -169,6 +175,12 @@ def test_floordiv_bound():
     assert bd.min_value == bd.NEG_INF
     assert bd.max_value == bd.POS_INF
 
+    analyzer.update(x, tvm.arith.ConstIntBound(-9, 4), override=True)
+    analyzer.update(y, tvm.arith.ConstIntBound(-4, 12), override=True)
+    bd = analyzer.const_int_bound(fld(x, y))
+    assert bd.min_value == -9
+    assert bd.max_value == 9
+
 
 def test_floormod_bound():
     analyzer = tvm.arith.Analyzer()