[ARITH] Revamp IntSet (#3272)
authorTianqi Chen <tqchen@users.noreply.github.com>
Thu, 13 Jun 2019 20:09:58 +0000 (13:09 -0700)
committerGitHub <noreply@github.com>
Thu, 13 Jun 2019 20:09:58 +0000 (13:09 -0700)
17 files changed:
include/tvm/arithmetic.h
python/tvm/arith.py
src/api/api_arith.cc
src/arithmetic/analyzer.cc
src/arithmetic/bound_deducer.cc
src/arithmetic/canonical_simplify.cc
src/arithmetic/compute_expr.h
src/arithmetic/const_fold.h
src/arithmetic/detect_linear_equation.cc
src/arithmetic/int_op_overflow.h
src/arithmetic/int_set.cc
src/arithmetic/int_set.h [new file with mode: 0644]
src/arithmetic/int_set_internal.h [deleted file]
src/lang/expr_operator.cc
src/pass/loop_partition.cc
tests/python/unittest/test_arith_deduce_bound.py [new file with mode: 0644]
tests/python/unittest/test_arith_intset.py

index 600e3c5..c506268 100644 (file)
@@ -328,71 +328,14 @@ class ConstraintContext {
   std::function<void()> exit_;
 };
 
-/*!
- * \brief Analyzer that contains bunch of sub-analyzers.
- *
- * Each sub-analyzer can make use of another sub-analyzer
- * by weak reference of this.
- *
- * NOTE for sub-analyzer developers:
- * If the analyzer uses memoization, we need to clear the internal
- * cache when information about a Var has been overrideen.
- */
-class Analyzer {
- public:
-  /*! \brief sub-analyzer: const integer bound */
-  ConstIntBoundAnalyzer const_int_bound;
-  /*! \brief sub-analyzer: modular set */
-  ModularSetAnalyzer modular_set;
-  /*! \brief sub-analyzer rewrite simplify */
-  RewriteSimplifier rewrite_simplify;
-  /*! \brief sub-analyzer canonical simplify */
-  CanonicalSimplifier canonical_simplify;
-  /*! \brief constructor */
-  Analyzer();
-  /*!
-   * \brief Notify all the sub-analyzers that var
-   *        is created and binded to expr.
-   *
-   *  Each var can only be binded once.
-   *
-   * \param var The variable.
-   * \param expr The expression we bind to.
-   */
-  void Bind(const VarExpr& var, const Expr& expr);
-  /*!
-   * \brief Notify all the sub-analyzers that var
-   *        is created and binded to a range.
-   *
-   *  Each var can only be binded once.
-   *
-   * \param var The variable.
-   * \param range The range we bind to.
-   */
-  void Bind(const VarExpr& var, const Range& range);
-  /*!
-   * \brief Whether can we proof expr >= val.
-
-   *  Non-negative proof is very useful in integer analysis
-   *  to lower divisions and mods given difference in trunc and ceil mode.
-   *
-   * \param expr The expression.
-   * \param lower_bound The lower bound.
-   * \return Whether we can proof it.
-   *
-   * \note Analyzer will call into sub-analyzers to get the result.
-   */
-  bool CanProveGreaterEqual(const Expr& expr, int64_t lower_bound);
-};
-
 //-----------------------------------------------
-// Integer set abstraction API.
+// Integer set data structure.
 //
 // This is a API build on top of the base
 // integer analysis API to provide set analysis.
 //------------------------------------------------
 /*!
- * \brief Sign of an expression or set.
+ * \brief Sign type of an integer expression.
  */
 enum SignType {
   kPositive,
@@ -401,8 +344,13 @@ enum SignType {
   kUnknown
 };
 
-// internal node container of int set.
-struct IntSetNode;
+/*!
+ * \brief Base class of all IntSet containers.
+ */
+struct IntSetNode : public Node {
+  static constexpr const char* _type_key = "IntSet";
+  TVM_DECLARE_BASE_NODE_INFO(IntSetNode, Node);
+};
 
 /*!
  * \brief Integer set class, represent a set of integers in one dimension.
@@ -424,11 +372,6 @@ class IntSet : public NodeRef {
    * \return The covering range.
    */
   Range cover_range(Range max_range) const;
-  /*!
-   * \brief find an interval that covers the set.
-   * \return The covering interval set.
-   */
-  IntSet cover_interval() const;
   /*! \return Lower bound of the set */
   Expr min() const;
   /*! \return upper bound of the set */
@@ -493,33 +436,91 @@ class IntSet : public NodeRef {
 };
 
 /*!
- * \brief Base class of all IntSet containers.
+ * \brief Integer set analyzer.
  */
-struct IntSetNode : public Node {
-  static constexpr const char* _type_key = "IntSet";
-  TVM_DECLARE_BASE_NODE_INFO(IntSetNode, Node);
+class IntSetAnalyzer {
+ public:
+  /*!
+   * \brief Find a symbolic integer set that contains all possible values of
+   *        expr given the domain of each variables.
+   *
+   * \param expr The expression of interest.
+   * \param dom_map The domain map to indicate which variable to relax.
+   * \return the result of the analysis.
+   */
+  IntSet operator()(const Expr& expr, const Map<Var, IntSet>& dom_map);
+
+ private:
+  friend class Analyzer;
+  explicit IntSetAnalyzer(Analyzer* parent);
+  ~IntSetAnalyzer();
+  class Impl;
+  /*! \brief Internal impl */
+  Impl* impl_;
 };
 
 /*!
- * \brief Detect if e can be rewritten as e = sum_{i=0}^{n-1} var[i] * coeff[i] + coeff[n]
- *  Where coeff[i] and base are invariant of var[j] for all i and j.
+ * \brief Analyzer that contains bunch of sub-analyzers.
  *
- * \param e The expression to be detected.
- * \param vars List of variables to be used in detection.
- * \return [coeff[i]] if it is possible, empty array if it is not.
- */
-Array<Expr> DetectLinearEquation(const Expr& e, const Array<Var>& vars);
-
-/*!
- * \brief Detect if expression corresponds to clip bound of the vars
+ * Each sub-analyzer can make use of another sub-analyzer
+ * by weak reference of this.
  *
- * \param e The expression to be detected.
- * \param vars List of variables to be used in detection.
- * \return concat([min_value[i], max_value[i]]), None is returned if there is no min or max value
- *          return empty if the e does not match the pattern.
+ * NOTE for sub-analyzer developers:
+ * If the analyzer uses memoization, we need to clear the internal
+ * cache when information about a Var has been overridden.
  */
-Array<Expr> DetectClipBound(const Expr& e, const Array<Var>& vars);
+class Analyzer {
+ public:
+  /*! \brief sub-analyzer: const integer bound */
+  ConstIntBoundAnalyzer const_int_bound;
+  /*! \brief sub-analyzer: modular set */
+  ModularSetAnalyzer modular_set;
+  /*! \brief sub-analyzer rewrite simplify */
+  RewriteSimplifier rewrite_simplify;
+  /*! \brief sub-analyzer canonical simplify */
+  CanonicalSimplifier canonical_simplify;
+  /*! \brief sub-analyzer: int set */
+  IntSetAnalyzer int_set;
+  /*! \brief constructor */
+  Analyzer();
+  /*!
+   * \brief Notify all the sub-analyzers that var
+   *        is created and binded to expr.
+   *
+   *  Each var can only be binded once.
+   *
+   * \param var The variable.
+   * \param expr The expression we bind to.
+   */
+  void Bind(const VarExpr& var, const Expr& expr);
+  /*!
+   * \brief Notify all the sub-analyzers that var
+   *        is created and binded to a range.
+   *
+   *  Each var can only be binded once.
+   *
+   * \param var The variable.
+   * \param range The range we bind to.
+   */
+  void Bind(const VarExpr& var, const Range& range);
+  /*!
+   * \brief Whether can we prove expr >= val.
+
+   *  Non-negative proof is very useful in integer analysis
+   *  to lower divisions and mods given difference in trunc and ceil mode.
+   *
+   * \param expr The expression.
+   * \param lower_bound The lower bound.
+   * \return Whether we can prove it.
+   *
+   * \note Analyzer will call into sub-analyzers to get the result.
+   */
+  bool CanProveGreaterEqual(const Expr& expr, int64_t lower_bound);
+};
 
+//-----------------------------------------------
+// Integer set legacy API.
+//------------------------------------------------
 /*!
  * \brief Find an symbolic integer set that contains all possible values of
  *  e given the domain of each iteration variables.
@@ -638,6 +639,29 @@ IntSet DeduceBound(Expr v, Expr cond,
  */
 Domain DomainTouched(Stmt body, const Tensor &tensor, bool consider_calls, bool consider_provides);
 
+// Expression pattern detector.
+/*!
+ * \brief Detect if e can be rewritten as e = sum_{i=0}^{n-1} var[i] * coeff[i] + coeff[n]
+ *  Where coeff[i] and base are invariant of var[j] for all i and j.
+ *
+ * \param e The expression to be detected.
+ * \param vars List of variables to be used in detection.
+ * \return [coeff[i]] if it is possible, empty array if it is not.
+ */
+Array<Expr> DetectLinearEquation(const Expr& e,
+                                 const Array<Var>& vars);
+
+/*!
+ * \brief Detect if expression corresponds to clip bound of the vars
+ *
+ * \param e The expression to be detected.
+ * \param vars List of variables to be used in detection.
+ * \return concat([min_value[i], max_value[i]]), None is returned if there is no min or max value
+ *          return empty if the e does not match the pattern.
+ */
+Array<Expr> DetectClipBound(const Expr& e,
+                            const Array<Var>& vars);
+
 // implementation
 inline const IntSetNode* IntSet::operator->() const {
   return static_cast<const IntSetNode*>(node_.get());
index eda5cb8..4c3c05f 100644 (file)
@@ -32,21 +32,21 @@ class IntSet(NodeBase):
         return _api_internal._IntSetIsEverything(self)
 
 
-@register_node
+@register_node("arith.IntervalSet")
 class IntervalSet(IntSet):
-    """Represent set of continuous interval"""
-    def min(self):
-        """get the minimum value"""
-        return _api_internal._IntervalSetGetMin(self)
-
-    def max(self):
-        """get the maximum value"""
-        return _api_internal._IntervalSetGetMax(self)
+    """Represent set of continuous interval [min_value, max_value]
 
+    Parameters
+    ----------
+    min_value : Expr
+        The minimum value in the interval.
 
-@register_node
-class StrideSet(IntSet):
-    """Represent set of strided integers"""
+    max_value : Expr
+        The maximum value in the interval.
+    """
+    def __init__(self, min_value, max_value):
+        self.__init_handle_by_constructor__(
+            _make_IntervalSet, min_value, max_value)
 
 
 @register_node("arith.ModularSet")
@@ -114,6 +114,7 @@ class Analyzer:
         self._modular_set = _mod("modular_set")
         self._rewrite_simplify = _mod("rewrite_simplify")
         self._canonical_simplify = _mod("canonical_simplify")
+        self._int_set = _mod("int_set")
         self._enter_constraint_context = _mod("enter_constraint_context")
 
     def const_int_bound(self, expr):
@@ -176,6 +177,24 @@ class Analyzer:
         """
         return self._canonical_simplify(expr)
 
+    def int_set(self, expr, dom_map):
+        """Compute a symbolic IntSet that covers expr for all values in dom_map.
+
+        Parameters
+        ----------
+        expr : tvm.Expr
+            The expression.
+
+        dom_map : Dict[Var, tvm.arith.IntSet]
+            The domain for variables to be relaxed.
+
+        Returns
+        -------
+        result : IntSet
+            The result.
+        """
+        return self._int_set(expr, dom_map)
+
     def bind(self, var, expr):
         """Bind a variable to the expression.
 
index 4d5d8bd..f31f02b 100644 (file)
@@ -39,6 +39,7 @@ TVM_REGISTER_API("arith.intset_vector")
 TVM_REGISTER_API("arith.intset_interval")
 .set_body_typed(IntSet::interval);
 
+
 TVM_REGISTER_API("arith.DetectLinearEquation")
 .set_body_typed(DetectLinearEquation);
 
@@ -110,6 +111,10 @@ TVM_REGISTER_API("arith._CreateAnalyzer")
         return PackedFunc([self](TVMArgs args, TVMRetValue *ret) {
             *ret = self->canonical_simplify(args[0]);
         });
+      } else if (name == "int_set") {
+        return PackedFunc([self](TVMArgs args, TVMRetValue *ret) {
+            *ret = self->int_set(args[0], args[1]);
+        });
       } else if (name == "bind") {
         return PackedFunc([self](TVMArgs args, TVMRetValue *ret) {
             auto& sptr = args[1].node_sptr();
index bd8c700..10a1c7f 100644 (file)
@@ -31,7 +31,8 @@ Analyzer::Analyzer()
     : const_int_bound(this),
       modular_set(this),
       rewrite_simplify(this),
-      canonical_simplify(this) {
+      canonical_simplify(this),
+      int_set(this) {
 }
 
 void Analyzer::Bind(const VarExpr& v, const Expr& expr) {
@@ -74,7 +75,7 @@ void ConstraintContext::ExitWithScope() {
 
 bool Analyzer::CanProveGreaterEqual(const Expr& expr, int64_t lower_bound) {
   if (const auto* ptr = expr.as<ir::IntImm>()) {
-    return ptr->value > lower_bound;
+    return ptr->value >= lower_bound;
   }
   auto bd = this->const_int_bound(this->rewrite_simplify(expr));
   if (bd->min_value >= lower_bound) return true;
index 89e556c..395a371 100644 (file)
@@ -6,9 +6,9 @@
  * to you under the Apache License, Version 2.0 (the
  * "License"); you may not use this file except in compliance
  * with the License.  You may obtain a copy of the License at
- * 
+ *
  *   http://www.apache.org/licenses/LICENSE-2.0
- * 
+ *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 
 #include <unordered_set>
 #include <unordered_map>
+#include "int_set.h"
 
 namespace tvm {
 namespace arith {
 
 using namespace ir;
-using HalideIR::Internal::Interval;
 
 // a visitor to find the path to the target variable
 // from a expression.
@@ -293,7 +293,7 @@ IntSet DeduceBound(Expr v, Expr e,
   BoundDeducer d(v, e, hint_map, relax_map);
   d.Deduce();
   if (!d.success) return IntSet::nothing();
-  Expr min = Interval::neg_inf, max = Interval::pos_inf;
+  Expr min = neg_inf(), max = pos_inf();
   if (d.is_greater) {
     min = d.result;
   } else {
index 1bf1f84..a50cbfb 100644 (file)
@@ -18,7 +18,6 @@
  */
 
 /*!
- *  Copyright (c) 2019 by Contributors
  * \file canonical_simplify.cc
  * \brief Canonical form based simplification.
  */
@@ -763,7 +762,10 @@ Mutate_(const Mod* op, const Expr& self) {
           if (TryCompare(temp, cval) == kLT) {
             return temp;
           } else {
-            return SplitModConst(ToSplitExpr(temp), cval);
+            // contonue to use logic below.
+            a = extra;
+            psum = a.as<SumExprNode>();
+            CHECK(psum != nullptr);
           }
         }
       }
index ff2fb8d..cc54bff 100644 (file)
@@ -6,9 +6,9 @@
  * to you under the Apache License, Version 2.0 (the
  * "License"); you may not use this file except in compliance
  * with the License.  You may obtain a copy of the License at
- * 
+ *
  *   http://www.apache.org/licenses/LICENSE-2.0
- * 
+ *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -27,8 +27,8 @@
 #define TVM_ARITHMETIC_COMPUTE_EXPR_H_
 
 #include <tvm/ir.h>
-#include <arithmetic/Interval.h>
 #include <limits>
+#include <algorithm>
 
 namespace tvm {
 namespace arith {
@@ -105,12 +105,12 @@ inline Expr ComputeExpr<ir::Mod>(Expr a, Expr b) {
 
 template<>
 inline Expr ComputeExpr<ir::Max>(Expr a, Expr b) {
-  return HalideIR::Internal::Interval::make_max(a, b);
+  return max(a, b);
 }
 
 template<>
 inline Expr ComputeExpr<ir::Min>(Expr a, Expr b) {
-  return HalideIR::Internal::Interval::make_min(a, b);
+  return min(a, b);
 }
 
 template<typename Op>
index fbf8fe7..ec50aef 100644 (file)
@@ -6,9 +6,9 @@
  * to you under the Apache License, Version 2.0 (the
  * "License"); you may not use this file except in compliance
  * with the License.  You may obtain a copy of the License at
- * 
+ *
  *   http://www.apache.org/licenses/LICENSE-2.0
- * 
+ *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -206,6 +206,7 @@ inline Expr TryConstFold<ir::Min>(Expr a, Expr b) {
       if (pa && pb) return IntImm::make(rtype, std::min(pa->value, pb->value));
       if (fa && fb) return FloatImm::make(rtype, std::min(fa->value, fb->value));
     });
+  if (a.same_as(b)) return a;
   return Expr();
 }
 
@@ -216,6 +217,7 @@ inline Expr TryConstFold<ir::Max>(Expr a, Expr b) {
       if (pa && pb) return IntImm::make(rtype, std::max(pa->value, pb->value));
       if (fa && fb) return FloatImm::make(rtype, std::max(fa->value, fb->value));
     });
+  if (a.same_as(b)) return a;
   return Expr();
 }
 
@@ -307,6 +309,58 @@ inline Expr TryConstFold<ir::Not>(Expr a) {
   return Expr();
 }
 
+/*! \brief Helper namespace for symbolic value limits */
+struct SymbolicLimits {
+  /*! \brief positive infinity */
+  static Expr pos_inf_;
+  /*! \brief negative infinity */
+  static Expr neg_inf_;
+};
+
+/*!
+ * \brief Opaque expression representing positive infinity.
+ *
+ *  It can can only be used as parameter of by min/max
+ *  for integer analysis and cannot be used in normal expressions.
+ *
+ * \return positive infinity.
+ */
+inline Expr pos_inf() {
+  return SymbolicLimits::pos_inf_;
+}
+
+/*!
+ * \brief Check if value is positive infinity.
+ * \param value The value to be checked.
+ *
+ * \return The check result.
+ */
+inline bool is_pos_inf(const Expr& value) {
+  return value.same_as(SymbolicLimits::pos_inf_);
+}
+
+/*!
+ * \brief Opaque expression representing negative infinity.
+ *
+ *  It can can only be used as parameter of by min/max
+ *  for integer analysis and cannot be used in normal expressions.
+ *
+ * \return negative infinity.
+ */
+inline Expr neg_inf() {
+  return SymbolicLimits::neg_inf_;
+}
+
+/*!
+ * \brief Check if value is negative infinity.
+ * \param value The value to be checked.
+ *
+ * \return The check result.
+ */
+inline bool is_neg_inf(const Expr& value) {
+  return value.same_as(SymbolicLimits::neg_inf_);
+}
+
 }  // namespace arith
 }  // namespace tvm
 #endif  // TVM_ARITHMETIC_CONST_FOLD_H_
index 2fe21fe..e584c8b 100644 (file)
@@ -6,9 +6,9 @@
  * to you under the Apache License, Version 2.0 (the
  * "License"); you may not use this file except in compliance
  * with the License.  You may obtain a copy of the License at
- * 
+ *
  *   http://www.apache.org/licenses/LICENSE-2.0
- * 
+ *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -19,8 +19,8 @@
 
 /*!
  *  Copyright (c) 2017 by Contributors
- * \file bound_deducer.cc
- * \brief Utility to deduce bound of expression
+ * \file detect_linear_equation.cc
+ * \brief Utility to detect patterns in the expression.
  */
 #include <tvm/expr.h>
 #include <tvm/ir_pass.h>
index 87f4f05..b78f21c 100644 (file)
@@ -6,9 +6,9 @@
  * to you under the Apache License, Version 2.0 (the
  * "License"); you may not use this file except in compliance
  * with the License.  You may obtain a copy of the License at
- * 
+ *
  *   http://www.apache.org/licenses/LICENSE-2.0
- * 
+ *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
index abbb7cd..75a4aaf 100644 (file)
@@ -6,9 +6,9 @@
  * to you under the Apache License, Version 2.0 (the
  * "License"); you may not use this file except in compliance
  * with the License.  You may obtain a copy of the License at
- * 
+ *
  *   http://www.apache.org/licenses/LICENSE-2.0
- * 
+ *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
  */
 
 /*!
- *  Copyright (c) 2017 by Contributors
  * \file int_set.cc
  * \brief The integer set functions
  */
 #include <tvm/ir.h>
-#include <tvm/ir_pass.h>
-#include <tvm/arithmetic.h>
 #include <tvm/ir_functor_ext.h>
-#include <arithmetic/Interval.h>
+#include <tvm/api_registry.h>
+#include <utility>
+#include <algorithm>
 #include <unordered_map>
-#include "compute_expr.h"
-#include "int_set_internal.h"
+#include "int_set.h"
+#include "pattern_match.h"
 
 namespace tvm {
 namespace arith {
 
-using HalideIR::Internal::Interval;
-using namespace ir;
-
-inline IntSet IntSet::cover_interval() const {
-  if ((*this).as<IntervalSet>()) return *this;
-  const StrideSet* s =  (*this).as<StrideSet>();
-  if (s) {
-    CHECK_NE(s->extents.size(), 0U);
-    Expr max = s->base.max;
-    for (size_t i = 0; i < s->extents.size(); ++i) {
-      max = max + s->extents[i] * s->strides[i] - s->strides[i];
-    }
-    return IntervalSet::make(s->base.min, Simplify(max));
-  }
-  LOG(FATAL) << "cannot convert set " << (*this)->type_key() << " to interval";
-  return IntSet::everything();
-}
-
-Range IntSet::cover_range(Range max_range) const {
-  IntSet temp;
-  const IntervalSet* s_int = (*this).as<IntervalSet>();
-  if (s_int == nullptr) {
-    temp = this->cover_interval();
-    s_int = temp.as<IntervalSet>();
-  }
-  if (s_int->i.is_bounded()) {
-    return Range::make_by_min_extent(
-        s_int->i.min, Simplify(s_int->i.max + 1 - s_int->i.min));
-  }
-  return max_range;
-}
-
-Expr IntSet::min() const {
-  const IntervalSet* s_int = (*this).as<IntervalSet>();
-  CHECK(s_int);
-  return s_int->i.min;
-}
-
-Expr IntSet::max() const {
-  const IntervalSet* s_int = (*this).as<IntervalSet>();
-  CHECK(s_int);
-  return s_int->i.max;
-}
-
-bool IntSet::is_nothing() const {
-  const IntervalSet* s_int = (*this).as<IntervalSet>();
-  return (s_int && s_int->i.is_empty());
-}
-
-bool IntSet::is_everything() const {
-  const IntervalSet* s_int = (*this).as<IntervalSet>();
-  return (s_int && s_int->i.is_everything());
-}
+Expr SymbolicLimits::pos_inf_ = Var("pos_inf", Handle());
+Expr SymbolicLimits::neg_inf_ = Var("neg_inf", Handle());
 
-bool IntSet::is_single_point() const {
-  const IntervalSet* s_int = (*this).as<IntervalSet>();
-  return (s_int && s_int->i.is_single_point());
+IntervalSet::IntervalSet(Expr min_value, Expr max_value) {
+  auto node = make_node<IntervalSetNode>();
+  node->min_value = std::move(min_value);
+  node->max_value = std::move(max_value);
+  node_ = std::move(node);
 }
 
-bool IntSet::can_prove_positive() const {
-  const IntervalSet* s_int = (*this).as<IntervalSet>();
-  return (s_int && is_positive_const(ir::Simplify(s_int->i.min)));
+IntervalSet MakeIntervalSet(Expr min_value, Expr max_value) {
+  return IntervalSet(min_value, max_value);
 }
 
-bool IntSet::can_prove_negative() const {
-  const IntervalSet* s_int = (*this).as<IntervalSet>();
-  return (s_int && is_negative_const(ir::Simplify(s_int->i.max)));
-}
+TVM_REGISTER_API("arith._make_IntervalSet")
+.set_body_typed(MakeIntervalSet);
 
-bool IntSet::can_prove_non_positive() const {
-  if (const IntervalSet* s_int = (*this).as<IntervalSet>()) {
-    auto max = ir::Simplify(s_int->i.max);
-    return is_zero(max) || is_negative_const(max);
-  }
-  return false;
-}
 
-bool IntSet::can_prove_non_negative() const {
-  if (const IntervalSet* s_int = (*this).as<IntervalSet>()) {
-    // Any reason why we should or should not use can_prove() to implement
-    // these functions?
-    auto min = ir::Simplify(s_int->i.min);
-    return is_zero(min) || is_positive_const(min);
-  }
-  return false;
-}
-
-
-SignType IntSet::sign_type() const {
-  if (can_prove_positive()) {
-    return kPositive;
-  } else if (can_prove_negative()) {
-    return kNegative;
-  } else if (is_single_point() && is_zero(point_value())) {
-    return kZero;
+IntervalSet Intersect(Analyzer* analyzer, IntervalSet a, IntervalSet b) {
+  Expr max_value = min(a->max_value, b->max_value);
+  Expr min_value = max(a->min_value, b->min_value);
+  if ((max_value.type().is_int() || max_value.type().is_uint()) &&
+      (min_value.type().is_int() || min_value.type().is_uint()) &&
+      analyzer->CanProveGreaterEqual(min_value - max_value, 1)) {
+    return IntervalSet::Empty();
   } else {
-    return kUnknown;
-  }
-}
-Expr IntSet::point_value() const {
-  const IntervalSet* s_int = (*this).as<IntervalSet>();
-  CHECK(s_int && s_int->i.is_single_point());
-  return s_int->i.min;
-}
-
-IntSet IntSet::nothing() {
-  return IntervalSet::make(Interval::nothing());
-}
-
-IntSet IntSet::everything() {
-  return IntervalSet::make(Interval::everything());
-}
-
-IntSet IntSet::single_point(Expr x) {
-  return IntervalSet::make(Interval::single_point(x));
-}
-
-IntSet IntSet::range(Range r) {
-  // must make sure it can be matched back by MatchRange.
-  if (is_one(r->extent)) {
-    return IntSet::single_point(r->min);
-  }
-  if (is_positive_const(r->extent) && is_const(r->min)) {
-    return IntervalSet::make(
-        r->min, ComputeExpr<Sub>(ComputeExpr<Add>(r->extent, r->min), 1));
-  }
-  return IntervalSet::make(r->min, (r->extent + r->min) - 1);
-}
-
-IntSet IntSet::interval(Expr min, Expr max) {
-  if (min.same_as(max)) {
-    return IntSet::single_point(min);
-  }
-  return IntervalSet::make(min, max);
-}
-
-inline bool prove_equal(Expr lhs, Expr rhs) {
-  return is_zero(ir::Simplify(lhs - rhs));
-}
-
-// Check if a is created from b.
-bool IntSet::match_range(const Range& b) const {
-  const IntSet& a = *this;
-  const IntervalSet* a_int = a.as<IntervalSet>();
-  if (!a_int) return false;
-  const Interval& i = a_int->i;
-  return prove_equal(i.min, b->min) &&
-      prove_equal(i.max, ComputeExpr<Sub>(ComputeExpr<Add>(b->extent, b->min), 1));
-}
-
-inline bool MatchPoint(const IntSet& a,
-                       const Expr& b) {
-  const IntervalSet* a_int = a.as<IntervalSet>();
-  if (!a_int) return false;
-  const Interval& i = a_int->i;
-  return i.is_single_point() && i.min.same_as(b);
-}
-
-IntSet Union(const Array<IntSet>& sets) {
-  if (sets.size() == 0) return IntSet::nothing();
-  if (sets.size() == 1) return sets[0];
-  Interval x = sets[0].cover_interval().as<IntervalSet>()->i;
-  for (size_t i = 1; i < sets.size(); ++i) {
-    IntSet s = sets[i].cover_interval();
-    const Interval& y = s.as<IntervalSet>()->i;
-    x.include(y);
+    return IntervalSet(min_value, max_value);
   }
-  x.max = ir::Simplify(x.max);
-  x.min = ir::Simplify(x.min);
-  return IntervalSet::make(x);
 }
 
-IntSet Intersect(const Array<IntSet>& sets) {
-  Interval x = sets[0].cover_interval().as<IntervalSet>()->i;
-  for (size_t i = 1; i < sets.size(); ++i) {
-    Interval y = sets[i].cover_interval().as<IntervalSet>()->i;
-    x = Interval::make_intersection(x, y);
-  }
-  return IntervalSet::make(x);
+IntervalSet Union(Analyzer* analyzer, IntervalSet a, IntervalSet b) {
+  Expr max_value = max(a->max_value, b->max_value);
+  Expr min_value = min(a->min_value, b->min_value);
+  return IntervalSet(min_value, max_value);
 }
 
 // type traits
@@ -227,407 +81,623 @@ struct is_logical_op {
     static const bool value = true;             \
   };
 
-// interval related.
-template<typename OP>
-inline IntSet CombineInterval(Interval a, Interval b) {
-  if (a.is_single_point() && b.is_single_point()) {
-    return IntSet::single_point(ComputeExpr<OP>(a.min, b.min));
-  }
-  LOG(WARNING) << "Return Everything in CombineInterval " << OP::_type_key;
-  return IntSet::everything();
+TVM_DECLARE_LOGICAL_OP(And);
+TVM_DECLARE_LOGICAL_OP(Or);
+TVM_DECLARE_LOGICAL_OP(EQ);
+TVM_DECLARE_LOGICAL_OP(NE);
+TVM_DECLARE_LOGICAL_OP(GE);
+TVM_DECLARE_LOGICAL_OP(GT);
+TVM_DECLARE_LOGICAL_OP(LE);
+TVM_DECLARE_LOGICAL_OP(LT);
+TVM_DECLARE_LOGICAL_OP(Not);
+
+/*!
+ * \brief Combine two interval set under arithmetic operations.
+ * \note this can possibly relax the set.
+ */
+template<typename Op>
+inline IntervalSet Combine(Analyzer* analyzer,
+                           IntervalSet a,
+                           IntervalSet b) {
+  if (a->IsSinglePoint() && b->IsSinglePoint()) {
+    Expr res = TryConstFold<Op>(a->min_value, b->min_value);
+    if (!res.defined()) res = Op::make(a->min_value, b->min_value);
+    return IntervalSet::SinglePoint(res);
+  }
+  if (is_logical_op<Op>::value) {
+    return IntervalSet(make_const(a->min_value.type(), 0),
+                       make_const(a->min_value.type(), 1));
+  }
+  if (a->IsEmpty()) return a;
+  if (b->IsEmpty()) return b;
+  if (a->IsEverything()) return a;
+  if (b->IsEverything()) return b;
+  return IntervalSet::Everything();
 }
 
 template<>
-inline IntSet CombineInterval<Add>(Interval a, Interval b) {
-  if (a.is_single_point() && b.is_single_point()) {
-    return IntSet::single_point(ComputeExpr<Add>(a.min, b.min));
-  }
-  Interval r = Interval::everything();
-  if (a.has_lower_bound() && b.has_lower_bound()) {
-    r.min = ComputeExpr<Add>(a.min, b.min);
-  }
-  if (a.has_upper_bound() && b.has_upper_bound()) {
-    r.max = ComputeExpr<Add>(a.max, b.max);
-  }
-  return IntervalSet::make(r);
+inline IntervalSet Combine<ir::Add>(Analyzer* analyer,
+                                    IntervalSet a,
+                                    IntervalSet b) {
+  if (a->IsSinglePoint() && b->IsSinglePoint()) {
+    return IntervalSet::SinglePoint(a->min_value + b->min_value);
+  }
+  if (a->IsEmpty()) return a;
+  if (b->IsEmpty()) return b;
+  Expr min_value =
+      a->HasLowerBound() && b->HasLowerBound() ?
+      a->min_value + b->min_value : neg_inf();
+  Expr max_value =
+      a->HasUpperBound() && b->HasUpperBound() ?
+      a->max_value + b->max_value : pos_inf();
+  return IntervalSet(min_value, max_value);
 }
 
 template<>
-inline IntSet CombineInterval<Sub>(Interval a, Interval b) {
-  if (a.is_single_point() && b.is_single_point()) {
-    return IntSet::single_point(ComputeExpr<Sub>(a.min, b.min));
+inline IntervalSet Combine<ir::Sub>(Analyzer* analyer,
+                                    IntervalSet a,
+                                    IntervalSet b) {
+  if (a->IsSinglePoint() && b->IsSinglePoint()) {
+    return IntervalSet::SinglePoint(a->min_value - b->min_value);
   }
-  Interval r = Interval::everything();
-  if (a.has_lower_bound() && b.has_upper_bound()) {
-    r.min = ComputeExpr<Sub>(a.min, b.max);
-  }
-  if (a.has_upper_bound() && b.has_lower_bound()) {
-    r.max = ComputeExpr<Sub>(a.max, b.min);
-  }
-  return IntervalSet::make(r);
+  if (a->IsEmpty()) return a;
+  if (b->IsEmpty()) return b;
+  Expr min_value =
+      a->HasLowerBound() && b->HasUpperBound() ?
+      a->min_value - b->max_value : neg_inf();
+  Expr max_value =
+      a->HasUpperBound() && b->HasLowerBound() ?
+      a->max_value - b->min_value : pos_inf();
+  return IntervalSet(min_value, max_value);
 }
 
+
 template<>
-inline IntSet CombineInterval<Mul>(Interval a, Interval b) {
-  if (a.is_single_point() && b.is_single_point()) {
-    return IntSet::single_point(ComputeExpr<Mul>(a.min, b.min));
-  }
-  if (a.is_single_point() && !b.is_single_point()) {
+inline IntervalSet Combine<ir::Mul>(Analyzer* analyzer,
+                                    IntervalSet a,
+                                    IntervalSet b) {
+  if (a->IsSinglePoint() && b->IsSinglePoint()) {
+    return IntervalSet::SinglePoint(a->min_value * b->min_value);
+  }
+  if (a->IsEmpty()) return a;
+  if (b->IsEmpty()) return b;
+  if (a->IsSinglePoint()) {
     std::swap(a, b);
   }
-  if (b.is_single_point()) {
-    if (is_zero(b.min)) return IntSet::single_point(0);
-    if (is_one(b.min)) return IntervalSet::make(a);
-    Expr e1 = a.has_lower_bound() ? ComputeExpr<Mul>(a.min, b.min) : a.min;
-    Expr e2 = a.has_upper_bound() ? ComputeExpr<Mul>(a.max, b.min) : a.max;
-    // no relaxation is needed in here due to set is inclusive
-    // TODO(tqchen): consider convert to StrideSet.
-    if (is_positive_const(b.min)) {
-      return IntervalSet::make(e1, e2);
-    } else if (is_negative_const(b.min)) {
-      return IntervalSet::make(e2, e1);
-    } else if (a.is_bounded()) {
+  if (b->IsSinglePoint()) {
+    if (is_zero(b->min_value)) return b;
+    if (is_one(b->min_value)) return a;
+    if (analyzer->CanProveGreaterEqual(b->min_value, 0)) {
+      Expr min_value = a->HasLowerBound() ? a->min_value * b->min_value : neg_inf();
+      Expr max_value = a->HasUpperBound() ? a->max_value * b->min_value : pos_inf();
+      return IntervalSet(min_value, max_value);
+    } else if (analyzer->CanProveGreaterEqual(-b->min_value, 1)) {
+      Expr min_value = a->HasUpperBound() ? a->max_value * b->min_value : neg_inf();
+      Expr max_value = a->HasLowerBound() ? a->min_value * b->min_value : pos_inf();
+      return IntervalSet(min_value, max_value);
+    } else if (a->HasUpperBound() && a->HasLowerBound()) {
       using ir::Select;
-      Expr cmp = b.min >= make_zero(b.min.type().element_of());
-      return IntervalSet::make(Select::make(cmp, e1, e2), Select::make(cmp, e2, e1));
+      Expr sign = b->min_value >= make_zero(b->min_value.type().element_of());
+      Expr e1 = a->min_value * b->min_value;
+      Expr e2 = a->max_value * b->min_value;
+      return IntervalSet(Select::make(sign, e1, e2), Select::make(sign, e2, e1));
     }
   }
-  LOG(WARNING) << "Return Everything in CombineInterval Mul";
-  return IntSet::everything();
+  DLOG(WARNING) << "Return Everything in CombineInterval Mul";
+  return IntervalSet::Everything();
 }
 
 template<>
-inline IntSet CombineInterval<Div>(Interval a, Interval b) {
-  if (a.is_single_point() && b.is_single_point()) {
-    return IntSet::single_point(ComputeExpr<Div>(a.min, b.min));
-  }
-  if (b.is_single_point()) {
-    if (is_zero(b.min)) {
+inline IntervalSet Combine<ir::Div>(Analyzer* analyzer,
+                                    IntervalSet a,
+                                    IntervalSet b) {
+  if (a->IsSinglePoint() && b->IsSinglePoint()) {
+    return IntervalSet::SinglePoint(a->min_value / b->min_value);
+  }
+  if (a->IsEmpty()) return a;
+  if (b->IsEmpty()) return b;
+  if (b->IsSinglePoint()) {
+    if (is_zero(b->min_value)) {
       LOG(FATAL) << "Divide by zero in CombineInterval Div";
     }
-    if (is_one(b.min)) return IntervalSet::make(a);
-    Expr e1 = a.has_lower_bound() ? ComputeExpr<Div>(a.min, b.min) : a.min;
-    Expr e2 = a.has_upper_bound() ? ComputeExpr<Div>(a.max, b.min) : a.max;
+    if (is_one(b->min_value)) return a;
     // no relaxation is needed in here due to set is inclusive
-    if (is_positive_const(b.min)) {
-      return IntervalSet::make(e1, e2);
-    } else if (is_negative_const(b.min)) {
-      return IntervalSet::make(e2, e1);
-    } else if (a.is_bounded()) {
+    if (analyzer->CanProveGreaterEqual(b->min_value, 0)) {
+      Expr min_value = a->HasLowerBound() ? a->min_value / b->min_value : neg_inf();
+      Expr max_value = a->HasUpperBound() ? a->max_value / b->min_value : pos_inf();
+      return IntervalSet(min_value, max_value);
+    } else if (analyzer->CanProveGreaterEqual(-b->min_value, 1)) {
+      Expr min_value = a->HasUpperBound() ? a->max_value / b->min_value : neg_inf();
+      Expr max_value = a->HasLowerBound() ? a->min_value / b->min_value : pos_inf();
+      return IntervalSet(min_value, max_value);
+    } else if (a->HasUpperBound() && a->HasLowerBound()) {
       using ir::Select;
-      Expr cmp = b.min >= make_zero(b.min.type().element_of());
-      return IntervalSet::make(Select::make(cmp, e1, e2), Select::make(cmp, e2, e1));
+      Expr sign = b->min_value >= make_zero(b->min_value.type().element_of());
+      Expr e1 = a->min_value / b->min_value;
+      Expr e2 = a->max_value / b->min_value;
+      return IntervalSet(Select::make(sign, e1, e2), Select::make(sign, e2, e1));
     }
   }
-  LOG(WARNING) << "Return Everything in CombineInterval Div";
-  return IntSet::everything();
+  DLOG(WARNING) << "Return Everything in CombineInterval Div";
+  return IntervalSet::Everything();
 }
 
 template<>
-inline IntSet CombineInterval<Mod>(Interval a, Interval b) {
-  if (a.is_single_point() && b.is_single_point()) {
-    return IntSet::single_point(ComputeExpr<Mod>(a.min, b.min));
+inline IntervalSet Combine<ir::Mod>(Analyzer* analyzer,
+                                    IntervalSet a,
+                                    IntervalSet b) {
+  if (a->IsSinglePoint() && b->IsSinglePoint()) {
+    return IntervalSet::SinglePoint(a->min_value % b->min_value);
   }
-  if (b.is_single_point()) {
-    Expr divisor = b.min;
+  if (a->IsEmpty()) return a;
+  if (b->IsEmpty()) return b;
+
+  if (b->IsSinglePoint()) {
+    const Expr& divisor = b->min_value;
     if (is_zero(divisor)) {
       LOG(FATAL) << "Modular by zero in CombineInterval Mod";
     }
-    return IntervalSet::make(make_zero(divisor.type()), divisor - 1);
+    // We need to add more bound constraints throughout the code.
+    // The logic below assumes a is non-negative, which usually
+    // is the case of our application.
+    // TODO(tqchen): add bound constraints for a.
+    if (analyzer->CanProveGreaterEqual(divisor, 0)) {
+      return IntervalSet(make_zero(divisor.type()), divisor - 1);
+    } else {
+      Expr bound = abs(divisor) - 1;
+      return IntervalSet(-bound, bound);
+    }
   }
-
-  LOG(WARNING) << "Return Everything in CombineInterval Mod";
-  return IntSet::everything();
+  DLOG(WARNING) << "Return Everything in CombineInterval Mod";
+  return IntervalSet::Everything();
 }
 
 template<>
-inline IntSet CombineInterval<Max>(Interval a, Interval b) {
-  if (a.is_single_point() && b.is_single_point()) {
-    return IntSet::single_point(ComputeExpr<Max>(a.min, b.min));
+inline IntervalSet Combine<ir::Max>(Analyzer* analzyer,
+                                    IntervalSet a,
+                                    IntervalSet b) {
+  if (a->IsSinglePoint() && b->IsSinglePoint()) {
+    return IntervalSet::SinglePoint(max(a->min_value,  b->min_value));
   }
-  return IntervalSet::make(Interval::make_max(a.min, b.min),
-                           Interval::make_max(a.max, b.max));
+  if (a->IsEmpty()) return a;
+  if (b->IsEmpty()) return b;
+  return IntervalSet(max(a->min_value, b->min_value),
+                     max(a->max_value, b->max_value));
 }
 
 template<>
-inline IntSet CombineInterval<Min>(Interval a, Interval b) {
-  if (a.is_single_point() && b.is_single_point()) {
-    return IntSet::single_point(ComputeExpr<Min>(a.min, b.min));
+inline IntervalSet Combine<ir::Min>(Analyzer* analzyer,
+                                    IntervalSet a,
+                                    IntervalSet b) {
+  if (a->IsSinglePoint() && b->IsSinglePoint()) {
+    return IntervalSet::SinglePoint(min(a->min_value, b->min_value));
   }
-  return IntervalSet::make(Interval::make_min(a.min, b.min),
-                           Interval::make_min(a.max, b.max));
-}
-
-template<typename OP>
-inline IntSet CombineInterval_(IntSet a, IntSet b) {
-  return CombineInterval<OP>(
-      a.as<IntervalSet>()->i, b.as<IntervalSet>()->i);
-}
-
-// stride related
-inline IntSet AsStrideSet(IntSet a) {
-  if (a.as<StrideSet>()) return a;
-  const IntervalSet* s = a.as<IntervalSet>();
-  CHECK(s->i.is_bounded());
-  NodePtr<StrideSet> n = make_node<StrideSet>();
-  n->base = s->i;
-  return IntSet(n);
-}
-template<typename OP>
-inline IntSet CombineSets(IntSet a, IntSet b) {
-  return CombineInterval_<OP>(a.cover_interval(), b.cover_interval());
+  if (a->IsEmpty()) return a;
+  if (b->IsEmpty()) return b;
+  return IntervalSet(min(a->min_value, b->min_value),
+                     min(a->max_value, b->max_value));
 }
 
-template<>
-inline IntSet CombineSets<Add>(IntSet a, IntSet b) {
-  const IntervalSet* a_int = a.as<IntervalSet>();
-  const IntervalSet* b_int = b.as<IntervalSet>();
-  if (a_int && is_zero(a_int->i.min)) return b;
-  if (b_int && is_zero(b_int->i.min)) return a;
-  a = AsStrideSet(a);
-  b = AsStrideSet(b);
-  const StrideSet* a_stride = a.as<StrideSet>();
-  const StrideSet* b_stride = b.as<StrideSet>();
-  auto n = make_node<StrideSet>(*a_stride);
-  for (size_t i = 0; i < b_stride->extents.size(); ++i) {
-    n->extents.push_back(b_stride->extents[i]);
-    n->strides.push_back(b_stride->strides[i]);
-  }
-  n->base = CombineInterval<Add>(
-      a_stride->base, b_stride->base).as<IntervalSet>()->i;
-  return IntSet(n);
-}
-
-inline IntSet NegateSet(IntSet a) {
-  const IntervalSet* a_int = a.as<IntervalSet>();
-  if (a_int) {
-    if (a_int->i.is_single_point()) {
-      return IntSet::single_point(-a_int->i.min);
-    } else {
-      Interval r = Interval::everything();
-      if (a_int->i.has_upper_bound()) {
-        r.min = -(a_int->i.max);
-      }
-      if (a_int->i.has_lower_bound()) {
-        r.max = -(a_int->i.min);
-      }
-      return IntervalSet::make(r);
-    }
-  } else {
-    return NegateSet(a.cover_interval());
+// internal helper function to get an interval set
+IntervalSet ToIntervalSet(IntSet set) {
+  if (auto* node = set.as<IntervalSetNode>()) {
+    return GetRef<IntervalSet>(node);
   }
+  DLOG(INFO) << "cannot resolve int set " << set;
+  return IntervalSet::Everything();
 }
 
-template<>
-inline IntSet CombineSets<Sub>(IntSet a, IntSet b) {
-  return CombineSets<Add>(a, NegateSet(b));
-}
-
-TVM_DECLARE_LOGICAL_OP(And);
-TVM_DECLARE_LOGICAL_OP(Or);
-TVM_DECLARE_LOGICAL_OP(EQ);
-TVM_DECLARE_LOGICAL_OP(NE);
-TVM_DECLARE_LOGICAL_OP(GE);
-TVM_DECLARE_LOGICAL_OP(GT);
-TVM_DECLARE_LOGICAL_OP(LE);
-TVM_DECLARE_LOGICAL_OP(LT);
-TVM_DECLARE_LOGICAL_OP(Not);
+using namespace ir;
 
-// generic combine operations of two sets
-template<typename OP>
-inline IntSet Combine(const IntSet& a, const IntSet &b) {
-  if (is_logical_op<OP>::value) {
-    return IntervalSet::make(0, 1);
+// Simplified version of int set evaluator that operates on IntervalSet
+// We might use better set analysis in the future to replace the intervalset.
+class IntervalSetEvaluator :
+      public ExprFunctor<IntervalSet(const Expr&)> {
+ public:
+  IntervalSetEvaluator(Analyzer* analyzer,
+                       const Map<Var, IntSet>& dom_map,
+                       bool eval_vec = false)
+      : analyzer_(analyzer),
+        dom_map_(dom_map),
+        eval_vec_(eval_vec) {
   }
-  const IntervalSet* a_int = a.as<IntervalSet>();
-  const IntervalSet* b_int = b.as<IntervalSet>();
-  if (a_int && a_int->i.is_everything()) return a;
-  if (b_int && b_int->i.is_everything()) return b;
-  if (a_int && b_int) {
-    return CombineInterval<OP>(a_int->i, b_int->i);
+
+  IntervalSet Eval(const Expr& val) {
+    return this->VisitExpr(val);
   }
-  if (a_int && !(a_int->i.is_bounded())) {
-    return CombineInterval_<OP>(a, b.cover_interval());
+
+  IntervalSet VisitExpr_(const IntImm* op) final {
+    return IntervalSet::SinglePoint(GetRef<Expr>(op));
   }
-  if (b_int && !(b_int->i.is_bounded())) {
-    return CombineInterval_<OP>(a.cover_interval(), b);
+
+  IntervalSet VisitExpr_(const UIntImm* op) final {
+    return IntervalSet::SinglePoint(GetRef<Expr>(op));
   }
-  return CombineSets<OP>(a, b);
-}
 
-class IntSetEvaluator :
-      public ExprFunctor<IntSet(const Expr&, const Expr&)> {
- public:
-  explicit IntSetEvaluator(
-      const std::unordered_map<const Variable*, IntSet>& dom_map,
-      bool eval_vec = false)
-      : dom_map_(dom_map), eval_vec_(eval_vec) {}
-  // Evaluate.
-  IntSet Eval(const Expr& e) {
-    return this->VisitExpr(e, e);
-  }
-  IntSet VisitExpr_(const IntImm* op, const Expr& e) final {
-    return IntSet::single_point(e);
-  }
-  IntSet VisitExpr_(const UIntImm* op, const Expr& e) final {
-    return IntSet::single_point(e);
-  }
-  IntSet VisitExpr_(const Variable* op, const Expr& e) final {
-    auto it = dom_map_.find(op);
+  IntervalSet VisitExpr_(const Variable* op) final {
+    Var var = GetRef<Var>(op);
+    auto it = dom_map_.find(var);
     if (it != dom_map_.end()) {
-      return it->second;
+      return ToIntervalSet((*it).second);
     } else {
-      return IntSet::single_point(e);
+      return IntervalSet::SinglePoint(var);
     }
   }
-  IntSet VisitExpr_(const Add* op, const Expr& e) final {
-    return Binary(op, e);
+
+  IntervalSet VisitExpr_(const Add* op) final {
+    return VisitBinaryExpr_(op);
   }
-  IntSet VisitExpr_(const Sub* op, const Expr& e) final {
-    return Binary(op, e);
+
+  IntervalSet VisitExpr_(const Sub* op) final {
+    return VisitBinaryExpr_(op);
   }
-  IntSet VisitExpr_(const Mul* op, const Expr& e) final {
-    return Binary(op, e);
+
+  IntervalSet VisitExpr_(const Mul* op) final {
+    return VisitBinaryExpr_(op);
   }
-  IntSet VisitExpr_(const Div* op, const Expr& e) final {
-    return Binary(op, e);
+
+  IntervalSet VisitExpr_(const Div* op) final {
+    return VisitBinaryExpr_(op);
   }
-  IntSet VisitExpr_(const Mod* op, const Expr& e) final {
-    return Binary(op, e);
+
+  IntervalSet VisitExpr_(const Mod* op) final {
+    return VisitBinaryExpr_(op);
   }
-  IntSet VisitExpr_(const Min* op, const Expr& e) final {
-    return Binary(op, e);
+
+  IntervalSet VisitExpr_(const Min* op) final {
+    return VisitBinaryExpr_(op);
   }
-  IntSet VisitExpr_(const Max* op, const Expr& e) final {
-    return Binary(op, e);
+
+  IntervalSet VisitExpr_(const Max* op) final {
+    return VisitBinaryExpr_(op);
   }
-  IntSet VisitExpr_(const EQ* op, const Expr& e) final {
-    return Binary(op, e);
+
+  IntervalSet VisitExpr_(const EQ* op) final {
+    return VisitBinaryExpr_(op);
   }
-  IntSet VisitExpr_(const NE* op, const Expr& e) final {
-    return Binary(op, e);
+
+  IntervalSet VisitExpr_(const NE* op) final {
+    return VisitBinaryExpr_(op);
   }
-  IntSet VisitExpr_(const LT* op, const Expr& e) final {
-    return Binary(op, e);
+
+  IntervalSet VisitExpr_(const LT* op) final {
+    return VisitBinaryExpr_(op);
   }
-  IntSet VisitExpr_(const LE* op, const Expr& e) final {
-    return Binary(op, e);
+
+  IntervalSet VisitExpr_(const LE* op) final {
+    return VisitBinaryExpr_(op);
   }
-  IntSet VisitExpr_(const GT* op, const Expr& e) final {
-    return Binary(op, e);
+
+  IntervalSet VisitExpr_(const GT* op) final {
+    return VisitBinaryExpr_(op);
   }
-  IntSet VisitExpr_(const GE* op, const Expr& e) final {
-    return Binary(op, e);
+
+  IntervalSet VisitExpr_(const GE* op) final {
+    return VisitBinaryExpr_(op);
   }
-  IntSet VisitExpr_(const And* op, const Expr& e) final {
-    return Binary(op, e);
+
+  IntervalSet VisitExpr_(const And* op) final {
+    return VisitBinaryExpr_(op);
   }
-  IntSet VisitExpr_(const Or* op, const Expr& e) final {
-    return Binary(op, e);
+
+  IntervalSet VisitExpr_(const Or* op) final {
+    return VisitBinaryExpr_(op);
   }
-  IntSet VisitExpr_(const Ramp* op, const Expr& e) final {
+
+  IntervalSet VisitExpr_(const Ramp* op) final {
     CHECK(eval_vec_);
-    IntSet base = Eval(op->base);
-    int vstride;
-    if (GetConstInt(op->stride, &vstride)) {
+    IntervalSet base = Eval(op->base);
+    PVar<Integer> stride;
+    if (stride.Match(op->stride)) {
       Type t = op->base.type();
-      if (vstride > 0) {
+      int64_t vstride = stride.Eval()->value;
+      if (vstride> 0) {
         return Combine<Add>(
+            analyzer_,
             base,
-            IntSet::interval(make_zero(t),
-                             make_const(t, vstride * op->lanes -1)));
+            IntervalSet(make_zero(t), make_const(t, vstride * op->lanes - 1)));
       } else {
         return Combine<Add>(
+            analyzer_,
             base,
-            IntSet::interval(make_const(t, vstride * op->lanes + 1),
-                             make_zero(t)));
+            IntervalSet(make_const(t, vstride * op->lanes + 1), make_zero(t)));
       }
     }
-    LOG(WARNING) << "cannot evaluate set on expression " << e;
-    return IntSet::everything();
+    DLOG(WARNING) << "cannot evaluate set on expression " << GetRef<Expr>(op);
+    return IntervalSet::Everything();
   }
-  IntSet VisitExpr_(const Broadcast* op, const Expr& e) final {
+
+  IntervalSet VisitExpr_(const Broadcast* op) final {
     CHECK(eval_vec_);
-    return Eval(op->value);
+    return VisitExpr(op->value);
   }
-  IntSet VisitExpr_(const Select* op, const Expr& e) final {
-    IntSet true_set = this->Eval(op->true_value);
-    IntSet false_set = this->Eval(op->false_value);
-    return Union({false_set, true_set});
+
+  IntervalSet VisitExpr_(const Select* op) final {
+    IntervalSet true_set = this->Eval(op->true_value);
+    IntervalSet false_set = this->Eval(op->false_value);
+    return Union(analyzer_, false_set, true_set);
   }
-  IntSet VisitExprDefault_(const Node* op, const Expr& e) final {
-    LOG(WARNING) << "cannot evaluate set type " << e->type_key();
-    return IntSet::everything();
+
+  IntervalSet VisitExprDefault_(const Node* op) final {
+    DLOG(WARNING) << "cannot evaluate set type " << op->type_key();
+    return IntervalSet::Everything();
   }
 
  private:
+  // whether set is exactly single point that equals value.
+  bool MatchPoint(const IntervalSet& set,
+                  const Expr& value) const {
+    return set->min_value.same_as(value) && set->max_value.same_as(value);
+  }
+
   template<typename T>
-  inline IntSet Binary(const T* op, const Expr& e) {
-    IntSet a = this->Eval(op->a);
-    IntSet b = this->Eval(op->b);
+  inline IntervalSet VisitBinaryExpr_(const T* op) {
+    IntervalSet a = this->Eval(op->a);
+    IntervalSet b = this->Eval(op->b);
     if (MatchPoint(a, op->a) && MatchPoint(b, op->b)) {
-      return IntSet::single_point(e);
+      return IntervalSet::SinglePoint(GetRef<Expr>(op));
     }
-    return Combine<T>(a, b);
+    return Combine<T>(analyzer_, a, b);
   }
 
-  const std::unordered_map<const Variable*, IntSet>& dom_map_;
+  Analyzer* analyzer_;
+  const Map<Var, IntSet>& dom_map_;
   bool eval_vec_{false};
 };
 
+class IntSetAnalyzer::Impl {
+ public:
+  explicit Impl(Analyzer* analyzer)
+      : analyzer_(analyzer) {
+  }
+
+  IntSet Eval(const Expr& expr, const Map<Var, IntSet>& dom_map) const {
+    return IntervalSetEvaluator(analyzer_, dom_map).Eval(expr);
+  }
+
+ private:
+  Analyzer* analyzer_;
+};
+
+IntSetAnalyzer::IntSetAnalyzer(Analyzer* parent)
+    : impl_(new Impl(parent)) {
+}
+
+IntSetAnalyzer::~IntSetAnalyzer() {
+  delete impl_;
+}
+
+IntSet IntSetAnalyzer::operator()(const Expr& expr,
+                                  const Map<Var, IntSet>& dom_map) {
+  return impl_->Eval(expr, dom_map);
+}
+
+// Quickly adapt to IntSet interface
+// TODO(tqchen): revisit IntSet interface as well.
+Range IntSet::cover_range(Range max_range) const {
+  IntSet temp;
+  const IntervalSetNode* s_int = (*this).as<IntervalSetNode>();
+  CHECK(s_int != nullptr);
+  if (s_int->HasUpperBound() && s_int->HasLowerBound()) {
+    return Range::make_by_min_extent(
+        s_int->min_value, Simplify(s_int->max_value + 1 - s_int->min_value));
+  }
+  return max_range;
+}
+
+Expr IntSet::min() const {
+  const IntervalSetNode* s_int = (*this).as<IntervalSetNode>();
+  CHECK(s_int);
+  return s_int->min_value;
+}
+
+Expr IntSet::max() const {
+  const IntervalSetNode* s_int = (*this).as<IntervalSetNode>();
+  CHECK(s_int);
+  return s_int->max_value;
+}
+
+bool IntSet::is_nothing() const {
+  const IntervalSetNode* s_int = (*this).as<IntervalSetNode>();
+  return (s_int && s_int->IsEmpty());
+}
+
+bool IntSet::is_everything() const {
+  const IntervalSetNode* s_int = (*this).as<IntervalSetNode>();
+  return (s_int && s_int->IsEverything());
+}
+
+bool IntSet::is_single_point() const {
+  const IntervalSetNode* s_int = (*this).as<IntervalSetNode>();
+  return (s_int && s_int->IsSinglePoint());
+}
+
+bool IntSet::can_prove_positive() const {
+  const IntervalSetNode* s_int = (*this).as<IntervalSetNode>();
+  return (s_int && is_positive_const(ir::Simplify(s_int->min_value)));
+}
+
+bool IntSet::can_prove_negative() const {
+  const IntervalSetNode* s_int = (*this).as<IntervalSetNode>();
+  return (s_int && is_negative_const(ir::Simplify(s_int->max_value)));
+}
+
+bool IntSet::can_prove_non_positive() const {
+  if (const auto* s_int = (*this).as<IntervalSetNode>()) {
+    auto max = ir::Simplify(s_int->max_value);
+    return is_zero(max) || is_negative_const(max);
+  }
+  return false;
+}
+
+bool IntSet::can_prove_non_negative() const {
+  if (const IntervalSetNode* s_int = (*this).as<IntervalSetNode>()) {
+    auto min = ir::Simplify(s_int->min_value);
+    return is_zero(min) || is_positive_const(min);
+  }
+  return false;
+}
+
+SignType IntSet::sign_type() const {
+  if (can_prove_positive()) {
+    return kPositive;
+  } else if (can_prove_negative()) {
+    return kNegative;
+  } else if (is_single_point() && is_zero(point_value())) {
+    return kZero;
+  } else {
+    return kUnknown;
+  }
+}
+Expr IntSet::point_value() const {
+  const IntervalSetNode* s_int = (*this).as<IntervalSetNode>();
+  CHECK(s_int && s_int->IsSinglePoint());
+  return s_int->min_value;
+}
+
+IntSet IntSet::nothing() {
+  return IntervalSet::Empty();
+}
+
+IntSet IntSet::everything() {
+  return IntervalSet::Everything();
+}
+
+IntSet IntSet::single_point(Expr x) {
+  return IntervalSet::SinglePoint(x);
+}
+
+IntSet IntSet::interval(Expr min, Expr max) {
+  if (min.same_as(max)) {
+    return IntSet::single_point(min);
+  }
+  return IntervalSet(min, max);
+}
+
+// Range related code
+inline bool ProveEqual(Expr lhs, Expr rhs) {
+  return is_zero(ir::Simplify(lhs - rhs));
+}
+
+IntSet IntSet::range(Range r) {
+  // must make sure it can be matched back by MatchRange.
+  if (is_one(r->extent)) {
+    return IntSet::single_point(r->min);
+  }
+  return IntervalSet(r->min, r->extent + r->min - 1);
+}
+
+bool IntSet::match_range(const Range& b) const {
+  const IntSet& a = *this;
+  const IntervalSetNode* a_int = a.as<IntervalSetNode>();
+  if (!a_int) return false;
+  return ProveEqual(a_int->min_value, b->min) &&
+      ProveEqual(a_int->max_value, b->extent + b->min - 1);
+}
+
+IntSet Union(const Array<IntSet>& sets) {
+  if (sets.size() == 0) return IntSet::nothing();
+  if (sets.size() == 1) return sets[0];
+  Analyzer ana;
+  IntervalSet x = ToIntervalSet(sets[0]);
+  for (size_t i = 1; i < sets.size(); ++i) {
+    x = Union(&ana, x, ToIntervalSet(sets[i]));
+  }
+  return IntervalSet(ir::Simplify(x->min_value),
+                     ir::Simplify(x->max_value));
+}
+
+IntSet Intersect(const Array<IntSet>& sets) {
+  if (sets.size() == 0) return IntSet::nothing();
+  if (sets.size() == 1) return sets[0];
+  Analyzer ana;
+  IntervalSet x = ToIntervalSet(sets[0]);
+  for (size_t i = 1; i < sets.size(); ++i) {
+    x = Intersect(&ana, x, ToIntervalSet(sets[i]));
+  }
+  return IntervalSet(ir::Simplify(x->min_value),
+                     ir::Simplify(x->max_value));
+}
+
+Map<Var, IntSet> ConvertDomMap(const Map<IterVar, IntSet>& dom_map) {
+  Map<Var, IntSet> dmap;
+  for (auto kv : dom_map) {
+    dmap.Set(kv.first->var, kv.second);
+  }
+  return dmap;
+}
+
+Map<Var, IntSet> ConvertDomMap(
+    const std::unordered_map<const Variable*, IntSet>& dom_map) {
+  Map<Var, IntSet> dmap;
+  for (auto kv : dom_map) {
+    dmap.Set(GetRef<Var>(kv.first), kv.second);
+  }
+  return dmap;
+}
+
 IntSet EvalSet(Expr e,
-               const std::unordered_map<const Variable*, IntSet>& dom_map) {
-  return IntSetEvaluator(dom_map, false).Eval(e);
+               const Map<Var, IntSet>& dom_map) {
+  Analyzer ana;
+  return IntervalSetEvaluator(&ana, dom_map, false).Eval(e);
 }
 
 IntSet IntSet::vector(Expr x) {
-  std::unordered_map<const Variable*, IntSet> dmap;
-  return IntSetEvaluator(dmap, true).Eval(x);
+  Analyzer ana;
+  Map<Var, IntSet> dmap;
+  return IntervalSetEvaluator(&ana, dmap, true).Eval(x);
 }
 
 IntSet EvalSet(Expr e,
                const Map<IterVar, IntSet>& dom_map) {
-  std::unordered_map<const Variable*, IntSet> dmap;
-  for (auto kv : dom_map) {
-    dmap[kv.first->var.as<Variable>()] = kv.second;
-  }
-  return EvalSet(e, dmap);
+  return EvalSet(e, ConvertDomMap(dom_map));
 }
 
-IntSet EvalSet(Range r,
+IntSet EvalSet(Expr e,
                const std::unordered_map<const Variable*, IntSet>& dom_map) {
-  IntSetEvaluator m(dom_map);
-  IntSet min_set = m.Eval(r->min).cover_interval();
+  return EvalSet(e, ConvertDomMap(dom_map));
+}
+
+IntSet EvalSet(Range r,
+               const Map<Var, IntSet>& dom_map) {
+  Analyzer ana;
+  IntervalSetEvaluator m(&ana, dom_map);
+  IntervalSet min_set = m.Eval(r->min);
   // Simplifying first can give tighter bounds if r->min and r->extent share variables
-  Expr sum = ComputeExpr<Sub>(ComputeExpr<Add>(r->min, r->extent), 1);
-  IntSet max_set = m.Eval(Simplify(sum)).cover_interval();
-  const Interval& ni = min_set.as<IntervalSet>()->i;
-  const Interval& xi = max_set.as<IntervalSet>()->i;
-  if (!ni.has_lower_bound()) return IntSet::everything();
-  if (!xi.has_upper_bound()) return IntSet::everything();
-  return IntervalSet::make(ni.min, xi.max);
+  Expr sum = r->min + r->extent - 1;
+  IntervalSet max_set = m.Eval(Simplify(sum));
+  if (!min_set->HasLowerBound()) return IntSet::everything();
+  if (!max_set->HasUpperBound()) return IntSet::everything();
+  return IntervalSet(min_set->min_value, max_set->max_value);
 }
 
-IntSet EvalSet(IntSet s,
+IntSet EvalSet(Range r,
                const std::unordered_map<const Variable*, IntSet>& dom_map) {
-  IntSetEvaluator m(dom_map);
-  s = s.cover_interval();
-  const IntervalSet* s_int = s.as<IntervalSet>();
-  Expr vmax = s_int->i.has_upper_bound() ?
-      m.Eval(s_int->i.max).cover_interval().max() : s_int->i.max;
-  Expr vmin = s_int->i.has_lower_bound() ?
-      m.Eval(s_int->i.min).cover_interval().min() : s_int->i.min;
-  return IntervalSet::make(vmin, vmax);
+  return EvalSet(r, ConvertDomMap(dom_map));
 }
 
-class SubExprIntSetEvaluator : public IntSetEvaluator {
+IntSet EvalSet(IntSet s,
+               const std::unordered_map<const Variable*, IntSet>& dom_map) {
+  Analyzer ana;
+  auto dmap = ConvertDomMap(dom_map);
+  IntervalSetEvaluator m(&ana, dmap);
+  const IntervalSetNode* s_int = s.as<IntervalSetNode>();
+  Expr vmax = s_int->HasUpperBound() ?
+      m.Eval(s_int->max_value).max() : s_int->max_value;
+  Expr vmin = s_int->HasLowerBound() ?
+      m.Eval(s_int->min_value).min() : s_int->min_value;
+  return IntervalSet(vmin, vmax);
+}
+
+class SubExprIntervalSetEvaluator : public IntervalSetEvaluator {
  public:
-  explicit SubExprIntSetEvaluator(
-      const std::unordered_map<const Variable*, IntSet>& dom_map)
-      : IntSetEvaluator(dom_map) {}
+  explicit SubExprIntervalSetEvaluator(
+      Analyzer* analyzer,
+      const Map<Var, IntSet>& dom_map)
+      : IntervalSetEvaluator(analyzer, dom_map) {}
 
-  IntSet VisitExpr(const Expr& n, const Expr& e) final {
-    IntSet ret = IntSetEvaluator::VisitExpr(n, e);
+  IntervalSet VisitExpr(const Expr& n) final {
+    IntervalSet ret = IntervalSetEvaluator::VisitExpr(n);
     expr_map[n] = ret;
     return ret;
   }
@@ -635,28 +705,26 @@ class SubExprIntSetEvaluator : public IntSetEvaluator {
   ExprIntSetMap expr_map;
 };
 
-ExprIntSetMap EvalSetForEachSubExpr(Expr e,
+ExprIntSetMap EvalSetForEachSubExpr(
+    Expr e,
     const std::unordered_map<const Variable*, IntSet>& dom_map) {
-  SubExprIntSetEvaluator m(dom_map);
+  Analyzer ana;
+  auto dmap = ConvertDomMap(dom_map);
+  SubExprIntervalSetEvaluator m(&ana, dmap);
   m.Eval(e);
   return m.expr_map;
 }
 
 IntSet EvalSet(Range r,
                const Map<IterVar, IntSet>& dom_map) {
-  std::unordered_map<const Variable*, IntSet> dmap;
-  for (auto kv : dom_map) {
-    dmap[kv.first->var.as<Variable>()] = kv.second;
-  }
-  return EvalSet(r, dmap);
+  return EvalSet(r, ConvertDomMap(dom_map));
 }
 
 TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
-.set_dispatch<IntervalSet>([](const IntervalSet *op, IRPrinter *p) {
-    p->stream << "interval-set"
-              << "[" << op->i.min << ", "
-              << op->i.max << ']';
+.set_dispatch<IntervalSetNode>([](const IntervalSetNode *op, IRPrinter *p) {
+    p->stream << "IntervalSet"
+              << "[" << op->min_value << ", "
+              << op->max_value << ']';
   });
-
 }  // namespace arith
 }  // namespace tvm
diff --git a/src/arithmetic/int_set.h b/src/arithmetic/int_set.h
new file mode 100644 (file)
index 0000000..bf7fec2
--- /dev/null
@@ -0,0 +1,143 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file int_set.h
+ * \brief Internal data structure for integer set.
+ */
+#ifndef TVM_ARITHMETIC_INT_SET_H_
+#define TVM_ARITHMETIC_INT_SET_H_
+
+#include <tvm/arithmetic.h>
+#include <tvm/expr_operator.h>
+#include <limits>
+#include "const_fold.h"
+
+namespace tvm {
+namespace arith {
+
+/*!
+ * \brief Symbolic interval set.
+ *
+ * \note We intentionally keep the internal of IntSet private,
+         as we might change it later.
+ */
+class IntervalSetNode : public IntSetNode {
+ public:
+  /*! \brief Minimum value in the interval. */
+  Expr min_value;
+  /*! \brief Maximum value in the interval. */
+  Expr max_value;
+
+  // visitor overload.
+  void VisitAttrs(tvm::AttrVisitor* v) final {
+    v->Visit("min_value", &min_value);
+    v->Visit("max_value", &max_value);
+  }
+
+  /*! \return Whether the interval has upper bound. */
+  bool HasUpperBound() const {
+    return !is_pos_inf(max_value) && !IsEmpty();
+  }
+  /*! \return Whether the interval has lower bound. */
+  bool HasLowerBound() const {
+    return !is_neg_inf(min_value) && !IsEmpty();
+  }
+  /*! \return Whether the interval is a single point. */
+  bool IsSinglePoint() const {
+    return min_value.same_as(max_value);
+  }
+  /*! \return whether interval represent nothing */
+  bool IsEmpty() const {
+    // during computations, either extreme could occur.
+    return is_pos_inf(min_value) || is_neg_inf(max_value);
+  }
+  /*! \return whether interval represent everything */
+  bool IsEverything() const {
+    return is_neg_inf(min_value) && is_pos_inf(max_value);
+  }
+
+  static constexpr const char* _type_key = "arith.IntervalSet";
+  TVM_DECLARE_NODE_TYPE_INFO(IntervalSetNode, IntSetNode);
+};
+
+/*!
+ * \brief Interval set used for symbolic integer analysis.
+ * \sa IntervalSetNode
+ */
+class IntervalSet : public IntSet {
+ public:
+  /*!
+   * \brief Make a new instance of interval set.
+   * \param min_value The minimum value in the interval.
+   * \param max_value The maximum value in the interval.
+   * \return The created set.
+   */
+  TVM_DLL IntervalSet(Expr min_value, Expr max_value);
+
+  /*!
+   * \brief Create an IntervalSet that represents a single point.
+   * \param value The value to be represented.
+   * \return The result set.
+   */
+  static IntervalSet SinglePoint(Expr value) {
+    return IntervalSet(value, value);
+  }
+  /*!
+   * \brief Create an IntervalSet that represents everything.
+   * \param value The value to be represented.
+   * \return The result set.
+   */
+  static IntervalSet Everything() {
+    return IntervalSet(neg_inf(), pos_inf());
+  }
+  /*!
+   * \brief Create an empty eet.
+   * \return The result set.
+   */
+  static IntervalSet Empty() {
+    return IntervalSet(pos_inf(), neg_inf());
+  }
+
+  TVM_DEFINE_NODE_REF_COW(IntervalSetNode);
+  TVM_DEFINE_NODE_REF_METHODS(IntervalSet, IntSet, IntervalSetNode);
+};
+
+/*!
+ * \brief Create union of two IntervalSets.
+ * \param analyzer The analyzer for simplification analysis.
+ * \param a The first set.
+ * \param b The second set.
+ * \return The result set.
+ */
+TVM_DLL IntervalSet Union(Analyzer* analyzer, IntervalSet a, IntervalSet b);
+
+/*!
+ * \brief Create insersection of two IntervalSets.
+ * \param analzyer The analyzer for simplification analysis.
+ * \param a The first set.
+ * \param b The second set.
+ * \return The result set.
+ */
+TVM_DLL IntervalSet Intersect(Analyzer *analzyer, IntervalSet a, IntervalSet b);
+
+}  // namespace arith
+}  // namespace tvm
+
+#endif  // TVM_ARITHMETIC_INT_SET_H_
diff --git a/src/arithmetic/int_set_internal.h b/src/arithmetic/int_set_internal.h
deleted file mode 100644 (file)
index 8b675cf..0000000
+++ /dev/null
@@ -1,79 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- * 
- *   http://www.apache.org/licenses/LICENSE-2.0
- * 
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-/*!
- *  Copyright (c) 2017 by Contributors
- * \file int_set_internal.h
- * \brief Implementations of integer set
- */
-#ifndef TVM_ARITHMETIC_INT_SET_INTERNAL_H_
-#define TVM_ARITHMETIC_INT_SET_INTERNAL_H_
-
-#include <tvm/ir.h>
-#include <tvm/ir_pass.h>
-#include <tvm/arithmetic.h>
-
-namespace tvm {
-namespace arith {
-
-using HalideIR::Internal::Interval;
-
-/*! \brief Set of continuous interval */
-struct IntervalSet : public IntSetNode {
-  /*! \brief the internal interval*/
-  Interval i;
-
-  static IntSet make(Interval i) {
-    NodePtr<IntervalSet> n =
-        make_node<IntervalSet>();
-    n->i = i;
-    return IntSet(n);
-  }
-  static IntSet make(Expr min, Expr max) {
-    NodePtr<IntervalSet> n =
-        make_node<IntervalSet>();
-    n->i.min = min;
-    n->i.max = max;
-    return IntSet(n);
-  }
-
-  static constexpr const char* _type_key = "IntervalSet";
-  TVM_DECLARE_NODE_TYPE_INFO(IntervalSet, IntSetNode);
-};
-
-/*!
- * \brief set represented by strided integers
- *  Reserved for cases where strided access is supported.
- */
-struct StrideSet : public IntSetNode {
-  /*! \brief the base inetrval */
-  Interval base;
-  /*! \brief additional extents in positive number */
-  Array<Expr> extents;
-  /*! \brief additional strides in positive number */
-  Array<Expr> strides;
-
-  static constexpr const char* _type_key = "StrideSet";
-  TVM_DECLARE_NODE_TYPE_INFO(StrideSet, IntSetNode);
-};
-
-}  // namespace arith
-}  // namespace tvm
-
-#endif  // TVM_ARITHMETIC_INT_SET_INTERNAL_H_
index 8537f17..3f52540 100644 (file)
@@ -6,9 +6,9 @@
  * to you under the Apache License, Version 2.0 (the
  * "License"); you may not use this file except in compliance
  * with the License.  You may obtain a copy of the License at
- * 
+ *
  *   http://www.apache.org/licenses/LICENSE-2.0
- * 
+ *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -188,7 +188,15 @@ Expr operator%(Expr a, Expr b) {
   return ir::Mod::make(a, b);
 }
 
+
 Expr min(Expr a, Expr b) {
+  // inf-aware simplificaiton
+  using arith::is_pos_inf;
+  using arith::is_neg_inf;
+  if (is_pos_inf(a)) return b;
+  if (is_neg_inf(a)) return a;
+  if (is_pos_inf(b)) return a;
+  if (is_neg_inf(b)) return b;
   BinaryOpMatchTypes(a, b);
   Expr ret = arith::TryConstFold<ir::Min>(a, b);
   if (ret.defined()) return ret;
@@ -196,6 +204,13 @@ Expr min(Expr a, Expr b) {
 }
 
 Expr max(Expr a, Expr b) {
+  // inf-aware simplificaiton
+  using arith::is_pos_inf;
+  using arith::is_neg_inf;
+  if (is_pos_inf(a)) return a;
+  if (is_neg_inf(a)) return b;
+  if (is_pos_inf(b)) return b;
+  if (is_neg_inf(b)) return a;
   BinaryOpMatchTypes(a, b);
   Expr ret = arith::TryConstFold<ir::Max>(a, b);
   if (ret.defined()) return ret;
index bcb2608..0a5b741 100644 (file)
@@ -6,9 +6,9 @@
  * to you under the Apache License, Version 2.0 (the
  * "License"); you may not use this file except in compliance
  * with the License.  You may obtain a copy of the License at
- * 
+ *
  *   http://www.apache.org/licenses/LICENSE-2.0
- * 
+ *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -28,7 +28,7 @@
 #include <tvm/arithmetic.h>
 #include <unordered_map>
 #include <unordered_set>
-#include "../arithmetic/int_set_internal.h"
+#include "../arithmetic/int_set.h"
 #include "../runtime/thread_storage_scope.h"
 
 namespace tvm {
@@ -366,7 +366,7 @@ class LoopPartitioner : public IRMutator {
 
   std::pair<IntSet, std::unordered_set<const Node*>>
   GetIntervalAndCondset(const Partition &partitions,
-                        const arith::Interval &for_interval,
+                        const arith::IntervalSet &for_interval,
                         bool cond_value);
 
   inline Stmt MakeFor(const Node* op, Expr extent, Stmt body);
@@ -374,6 +374,7 @@ class LoopPartitioner : public IRMutator {
   /* Candidate IRs that may be partitioned potentially */
   std::unordered_map<const Variable*, IntSet> hint_map_;
   std::unordered_map<const Variable*, IntSet> relax_map_;
+  arith::Analyzer analyzer_;
   CandidateSelector selector;
 };
 
@@ -381,16 +382,17 @@ class LoopPartitioner : public IRMutator {
 // given in the second component provably have value given by cond_value
 std::pair<IntSet, std::unordered_set<const Node*>>
 LoopPartitioner::GetIntervalAndCondset(const Partition &partitions,
-                                       const arith::Interval &for_interval,
+                                       const arith::IntervalSet &for_interval,
                                        bool cond_value) {
   Array<IntSet> sets;
   std::unordered_set<const Node*> cond_set;
 
   for (const auto &kv : partitions) {
     if (kv.first.second == cond_value) {
-      arith::Interval interval = kv.second.as<arith::IntervalSet>()->i;
-      arith::Interval intersection = arith::Interval::make_intersection(interval, for_interval);
-      if (!intersection.is_empty()) {
+      arith::IntervalSet interval = Downcast<arith::IntervalSet>(kv.second);
+      arith::IntervalSet intersection = arith::Intersect(
+          &analyzer_, interval, for_interval);
+      if (!intersection->IsEmpty()) {
         sets.push_back(kv.second);
         cond_set.insert(kv.first.first);
       }
@@ -463,11 +465,12 @@ Stmt LoopPartitioner::TryPartition(const Node* node,
                                    Expr max,
                                    Stmt body,
                                    bool partition_thread_scope) {
+  using namespace arith;
   PartitionFinder finder(var, hint_map_, relax_map_);
   finder.Visit(body);
   if (finder.partitions.empty()) return Stmt();
 
-  arith::Interval for_interval(min, max);
+  arith::IntervalSet for_interval(min, max);
   bool cond_value;
   IntSet middle_interval;
   std::unordered_set<const Node*> cond_set;
@@ -478,7 +481,7 @@ Stmt LoopPartitioner::TryPartition(const Node* node,
     // if such interval doesn't exist, find an interval in which all
     // conditions on var are false
     std::tie(middle_interval, cond_set) =
-            GetIntervalAndCondset(finder.partitions, for_interval, false);
+        GetIntervalAndCondset(finder.partitions, for_interval, false);
     if (middle_interval.is_nothing())
       // we couldn't find an interval in which the condintions are provably true or false
       // Therefore, we can't partition the loop based on those conds
@@ -488,7 +491,7 @@ Stmt LoopPartitioner::TryPartition(const Node* node,
     cond_value = true;
   }
 
-  arith::Interval middle_interval_i = middle_interval.as<arith::IntervalSet>()->i;
+  IntervalSet middle_interval_i = Downcast<IntervalSet>(middle_interval);
   // middle_interval is the subrange of the loop variable range for which a
   // set of conditions are true (or false resp.)
   // The part of the loop variable range that is before (after resp.) that
@@ -499,7 +502,7 @@ Stmt LoopPartitioner::TryPartition(const Node* node,
   Expr body_begin;
   Stmt pre_stmt;
   bool pre_stmt_recurse = true;
-  if (middle_interval_i.has_lower_bound()) {
+  if (middle_interval_i->HasLowerBound()) {
     body_begin = ir::Simplify(middle_interval.min());
     if (!can_prove(body_begin == min)) {
       Expr cond = (body_begin - min >= 0);
@@ -524,7 +527,7 @@ Stmt LoopPartitioner::TryPartition(const Node* node,
   Expr post_doubt_begin;
   Stmt post_stmt;
   bool post_stmt_recurse = true;
-  if (middle_interval_i.has_upper_bound()) {
+  if (middle_interval_i->HasUpperBound()) {
     post_doubt_begin = ir::Simplify(middle_interval.max() + 1);
     if (!can_prove(middle_interval.max() == max)) {
       // require the extent to be non-negative
diff --git a/tests/python/unittest/test_arith_deduce_bound.py b/tests/python/unittest/test_arith_deduce_bound.py
new file mode 100644 (file)
index 0000000..7fe6f56
--- /dev/null
@@ -0,0 +1,168 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import tvm
+
+def test_deduce():
+    a = tvm.var('a')
+    b = tvm.var('b')
+    c = tvm.var('c')
+    d = tvm.var('d')
+
+    b_s = tvm.arith.IntervalSet(2, 3)
+    c_s = tvm.arith.IntervalSet(10, 15)
+    d_s = tvm.arith.IntervalSet(-3, -1)
+    zero = tvm.const(0, "int32")
+
+    e0 = (-b)*a+c-d
+    res0 = tvm.arith.DeduceBound(a, e0>=0, {b: b_s, c: c_s, d: d_s}, {})
+    ans0 = ((d - c) /(b*-1))
+    assert str(tvm.ir_pass.Simplify(res0.max_value)) == str(ans0)
+
+    # expression containing variable a is on rhs
+    res0 = tvm.arith.DeduceBound(a, zero <= e0, {b: b_s, c: c_s, d: d_s}, {})
+    assert str(tvm.ir_pass.Simplify(res0.max_value)) == str(ans0)
+
+    e0 = d*a+c-d
+    res0 = tvm.arith.DeduceBound(a, e0>=0, {b: b_s, c: c_s, d: d_s}, {})
+    ans0 = ((0-c)/d + 1)
+    assert str(tvm.ir_pass.Simplify(res0.max_value)) == str(ans0)
+
+    # expression containing variable a is on rhs
+    res0 = tvm.arith.DeduceBound(a, zero <= e0, {b: b_s, c: c_s, d: d_s}, {})
+    assert str(tvm.ir_pass.Simplify(res0.max_value)) == str(ans0)
+
+    e1 = (a*4+b < c)
+    res1 = tvm.arith.DeduceBound(a, e1, {b: b_s, c: c_s, d: d_s}, {})
+    ans1 = (((c - b) + -1)/4)
+    assert str(tvm.ir_pass.Simplify(res1.max_value)) == str(ans1)
+
+    # expression containing variable a is on rhs
+    e1 = (c > a*4+b)
+    res1 = tvm.arith.DeduceBound(a, e1, {b: b_s, c: c_s, d: d_s}, {})
+    assert str(tvm.ir_pass.Simplify(res1.max_value)) == str(ans1)
+
+    e2 = (tvm.max(5, a * 4) < 0)
+    res2 = tvm.arith.DeduceBound(a, e2, {b: b_s, c: c_s, d: d_s}, {})
+    assert str(res2.max_value) == "neg_inf"
+    assert str(res2.min_value) == "pos_inf"
+
+    # expression containing variable a is on rhs
+    e2 = (zero < tvm.max(5, a * 4))
+    res2 = tvm.arith.DeduceBound(a, e2, {b: b_s, c: c_s, d: d_s}, {})
+    assert str(res2.max_value) == "neg_inf"
+    assert str(res2.min_value) == "pos_inf"
+
+
+    e3 = (-b)+a*c-d
+    res3 = tvm.arith.DeduceBound(a, e3>=0, {b: b_s, c: c_s, d: d_s}, {b: b_s, d: d_s})
+    ans3 = 2/c+1
+    assert str(tvm.ir_pass.Simplify(res3.min_value)) == str(ans3)
+
+    res3 = tvm.arith.DeduceBound(a, zero <= e3, {b: b_s, c: c_s, d: d_s}, {b: b_s, d: d_s})
+    assert str(tvm.ir_pass.Simplify(res3.min_value)) == str(ans3)
+
+def test_check():
+    a = tvm.var('a')
+    b = tvm.var('b')
+    c = tvm.var('c')
+    d = tvm.var('d')
+
+    b_s = tvm.arith.IntervalSet(2, 3)
+    c_s = tvm.arith.IntervalSet(5, 7)
+    d_s = tvm.arith.IntervalSet(-3, -1)
+
+    # no compare operator
+    res1 = tvm.arith.DeduceBound(a, a+b, {b: b_s}, {})
+    assert res1.is_nothing()
+
+    # multiple compare operators
+    res2 = tvm.arith.DeduceBound(a, (a+b>3).astype(c.dtype)>c , {b: b_s, c: c_s}, {})
+    assert res2.is_nothing()
+
+    # multiple target variable
+    res2 = tvm.arith.DeduceBound(a, a*2-a>b, {b: b_s}, {})
+    assert res2.is_nothing()
+
+def test_deduce_basic():
+    def test_basic(a1, a2, coff):
+        a = tvm.var('a')
+        b = tvm.var('b')
+        b_s = tvm.arith.IntervalSet(a1, a2)
+        e0 = b + a*coff + 3
+
+        res1 = tvm.arith.DeduceBound(a, e0<17, {b: b_s}, {b: b_s})
+        [x, y] = [res1.max_value, b_s.max_value] if coff > 0 else [res1.min_value, b_s.min_value]
+        assert (tvm.ir_pass.Simplify((x * coff + 3 + y) < 17)).value == 1
+
+        # expression containing variable a is on rhs
+        res1 = tvm.arith.DeduceBound(a, tvm.const(17, "int32") < e0, {b: b_s}, {b: b_s})
+        [x, y] = [res1.max_value, b_s.max_value] if coff < 0 else [res1.min_value, b_s.min_value]
+        assert (tvm.ir_pass.Simplify((x * coff + 3 + y) > 17)).value == 1
+
+        # expression containing variable a is on rhs
+        res1 = tvm.arith.DeduceBound(a, tvm.const(17, "int32")>= e0, {b: b_s}, {b: b_s})
+        [x, y] = [res1.max_value, b_s.max_value] if coff > 0 else [res1.min_value, b_s.min_value]
+        assert (tvm.ir_pass.Simplify((x * coff + 3 + y) <= 17)).value == 1
+
+        res1 = tvm.arith.DeduceBound(a, e0>=17, {b: b_s}, {b: b_s})
+        [x, y] = [res1.max_value, b_s.max_value] if coff < 0 else [res1.min_value, b_s.min_value]
+        assert (tvm.ir_pass.Simplify((x * coff + 3 + y) >= 17)).value == 1
+
+    test_basic(0, 4, 4)
+    test_basic(1, 5, 4)
+    test_basic(2, 6, 4)
+    test_basic(0, 4, -4)
+    test_basic(1, 5, -4)
+    test_basic(2, 6, -4)
+
+def test_deduce_complex():
+    def test_complex(a1, a2, coff):
+        a = tvm.var('a')
+        b = tvm.var('b')
+        b_s = tvm.arith.IntervalSet(a1, a2)
+        e0 = (b*3 + a* coff) * 4
+
+        res1 = tvm.arith.DeduceBound(a, e0<63, {b: b_s}, {b: b_s})
+        [t, x] = [res1.max_value, b_s.max_value] if coff > 0 else [res1.min_value, b_s.min_value]
+        assert (tvm.ir_pass.Simplify(((x*3 + t* coff) * 4) < 63)).value == 1
+
+        # expression containing variable a is on rhs
+        res1 = tvm.arith.DeduceBound(a, tvm.const(63, "int32")>= e0, {b: b_s}, {b: b_s})
+        [t, x] = [res1.max_value, b_s.max_value] if coff > 0 else [res1.min_value, b_s.min_value]
+        assert (tvm.ir_pass.Simplify(((x*3 + t* coff) * 4) <= 63)).value == 1
+
+        res1 = tvm.arith.DeduceBound(a, e0>63, {b: b_s}, {b: b_s})
+        [t, x] = [res1.max_value, b_s.max_value] if coff < 0 else [res1.min_value, b_s.min_value]
+        assert (tvm.ir_pass.Simplify(((x*3 + t* coff) * 4) > 63)).value == 1
+
+        # expression containing variable a is on rhs
+        res1 = tvm.arith.DeduceBound(a, tvm.const(63, "int32") <= e0, {b: b_s}, {b: b_s})
+        [t, x] = [res1.max_value, b_s.max_value] if coff < 0 else [res1.min_value, b_s.min_value]
+        assert (tvm.ir_pass.Simplify(((x*3 + t* coff) * 4) >= 63)).value == 1
+
+    test_complex(0, 4, 4)
+    test_complex(0, 4, -4)
+    test_complex(2, 6, 4)
+    test_complex(0, 4, -4)
+    test_complex(1, 5, -4)
+    test_complex(2, 6, -4)
+
+
+if __name__ == "__main__":
+    test_check()
+    test_deduce_basic()
+    test_deduce_complex()
index a74162e..fa14bcf 100644 (file)
 # under the License.
 import tvm
 
+
+class IntSetChecker:
+    def __init__(self):
+        self.analyzer = tvm.arith.Analyzer()
+
+    def verify(self, data, dmap, expected):
+        res = self.analyzer.int_set(data, dmap)
+        def err_msg():
+            return "\ndata={}\ndmap={}\nres={}\nexpected={}".format(data, dmap, res, expected)
+        def equal(x, y):
+            res = self.analyzer.canonical_simplify(x - y)
+            return tvm.ir_pass.Equal(res, 0)
+        assert equal(res.min_value, expected[0]), err_msg()
+        assert equal(res.max_value, expected[1]), err_msg()
+
 def test_basic():
-    s = tvm.arith.intset_interval(2, 3)
-    assert s.min().value == 2
-    assert s.max().value == 3
+    s = tvm.arith.IntervalSet(2, 3)
+    assert s.min_value.value == 2
+    assert s.max_value.value == 3
+
 
 def test_vector():
     base = 10
     stride = 3
     lanes = 2
     s = tvm.arith.intset_vector(tvm.make.Ramp(base, stride, lanes))
-    assert s.min().value == base
-    assert s.max().value == base + stride * lanes - 1
-
-def test_deduce():
-    a = tvm.var('a')
-    b = tvm.var('b')
-    c = tvm.var('c')
-    d = tvm.var('d')
-
-    b_s = tvm.arith.intset_interval(2, 3)
-    c_s = tvm.arith.intset_interval(10, 15)
-    d_s = tvm.arith.intset_interval(-3, -1)
-    zero = tvm.const(0, "int32")
-
-    e0 = (-b)*a+c-d
-    res0 = tvm.arith.DeduceBound(a, e0>=0, {b: b_s, c: c_s, d: d_s}, {})
-    ans0 = ((d - c) /(b*-1))
-    assert str(tvm.ir_pass.Simplify(res0.max())) == str(ans0)
-
-    # expression containing variable a is on rhs
-    res0 = tvm.arith.DeduceBound(a, zero <= e0, {b: b_s, c: c_s, d: d_s}, {})
-    assert str(tvm.ir_pass.Simplify(res0.max())) == str(ans0)
-
-    e0 = d*a+c-d
-    res0 = tvm.arith.DeduceBound(a, e0>=0, {b: b_s, c: c_s, d: d_s}, {})
-    ans0 = ((0-c)/d + 1)
-    assert str(tvm.ir_pass.Simplify(res0.max())) == str(ans0)
-
-    # expression containing variable a is on rhs
-    res0 = tvm.arith.DeduceBound(a, zero <= e0, {b: b_s, c: c_s, d: d_s}, {})
-    assert str(tvm.ir_pass.Simplify(res0.max())) == str(ans0)
-
-    e1 = (a*4+b < c)
-    res1 = tvm.arith.DeduceBound(a, e1, {b: b_s, c: c_s, d: d_s}, {})
-    ans1 = (((c - b) + -1)/4)
-    assert str(tvm.ir_pass.Simplify(res1.max())) == str(ans1)
-
-    # expression containing variable a is on rhs
-    e1 = (c > a*4+b)
-    res1 = tvm.arith.DeduceBound(a, e1, {b: b_s, c: c_s, d: d_s}, {})
-    assert str(tvm.ir_pass.Simplify(res1.max())) == str(ans1)
-
-    e2 = (tvm.max(5, a * 4) < 0)
-    res2 = tvm.arith.DeduceBound(a, e2, {b: b_s, c: c_s, d: d_s}, {})
-    assert str(res2.max()) == "neg_inf"
-    assert str(res2.min()) == "pos_inf"
-
-    # expression containing variable a is on rhs
-    e2 = (zero < tvm.max(5, a * 4))
-    res2 = tvm.arith.DeduceBound(a, e2, {b: b_s, c: c_s, d: d_s}, {})
-    assert str(res2.max()) == "neg_inf"
-    assert str(res2.min()) == "pos_inf"
-
-
-    e3 = (-b)+a*c-d
-    res3 = tvm.arith.DeduceBound(a, e3>=0, {b: b_s, c: c_s, d: d_s}, {b: b_s, d: d_s})
-    ans3 = 2/c+1
-    assert str(tvm.ir_pass.Simplify(res3.min())) == str(ans3)
-
-    res3 = tvm.arith.DeduceBound(a, zero <= e3, {b: b_s, c: c_s, d: d_s}, {b: b_s, d: d_s})
-    assert str(tvm.ir_pass.Simplify(res3.min())) == str(ans3)
-
-def test_check():
-    a = tvm.var('a')
-    b = tvm.var('b')
-    c = tvm.var('c')
-    d = tvm.var('d')
-
-    b_s = tvm.arith.intset_interval(2, 3)
-    c_s = tvm.arith.intset_interval(5, 7)
-    d_s = tvm.arith.intset_interval(-3, -1)
-
-    # no compare operator
-    res1 = tvm.arith.DeduceBound(a, a+b, {b: b_s}, {})
-    assert res1.is_nothing()
-
-    # multiple compare operators
-    res2 = tvm.arith.DeduceBound(a, (a+b>3).astype(c.dtype)>c , {b: b_s, c: c_s}, {})
-    assert res2.is_nothing()
-
-    # multiple target variable
-    res2 = tvm.arith.DeduceBound(a, a*2-a>b, {b: b_s}, {})
-    assert res2.is_nothing()
-
-def test_deduce_basic():
-    def test_basic(a1, a2, coff):
-        a = tvm.var('a')
-        b = tvm.var('b')
-        b_s = tvm.arith.intset_interval(a1, a2)
-        e0 = b + a*coff + 3
-
-        res1 = tvm.arith.DeduceBound(a, e0<17, {b: b_s}, {b: b_s})
-        [x, y] = [res1.max(), b_s.max()] if coff > 0 else [res1.min(), b_s.min()]
-        assert (tvm.ir_pass.Simplify((x * coff + 3 + y) < 17)).value == 1
-
-        # expression containing variable a is on rhs
-        res1 = tvm.arith.DeduceBound(a, tvm.const(17, "int32") < e0, {b: b_s}, {b: b_s})
-        [x, y] = [res1.max(), b_s.max()] if coff < 0 else [res1.min(), b_s.min()]
-        assert (tvm.ir_pass.Simplify((x * coff + 3 + y) > 17)).value == 1
-
-        # expression containing variable a is on rhs
-        res1 = tvm.arith.DeduceBound(a, tvm.const(17, "int32")>= e0, {b: b_s}, {b: b_s})
-        [x, y] = [res1.max(), b_s.max()] if coff > 0 else [res1.min(), b_s.min()]
-        assert (tvm.ir_pass.Simplify((x * coff + 3 + y) <= 17)).value == 1
-
-        res1 = tvm.arith.DeduceBound(a, e0>=17, {b: b_s}, {b: b_s})
-        [x, y] = [res1.max(), b_s.max()] if coff < 0 else [res1.min(), b_s.min()]
-        assert (tvm.ir_pass.Simplify((x * coff + 3 + y) >= 17)).value == 1
-
-    test_basic(0, 4, 4)
-    test_basic(1, 5, 4)
-    test_basic(2, 6, 4)
-    test_basic(0, 4, -4)
-    test_basic(1, 5, -4)
-    test_basic(2, 6, -4)
-
-def test_deduce_complex():
-    def test_complex(a1, a2, coff):
-        a = tvm.var('a')
-        b = tvm.var('b')
-        b_s = tvm.arith.intset_interval(a1, a2)
-        e0 = (b*3 + a* coff) * 4
-
-        res1 = tvm.arith.DeduceBound(a, e0<63, {b: b_s}, {b: b_s})
-        [t, x] = [res1.max(), b_s.max()] if coff > 0 else [res1.min(), b_s.min()]
-        assert (tvm.ir_pass.Simplify(((x*3 + t* coff) * 4) < 63)).value == 1
-
-        # expression containing variable a is on rhs
-        res1 = tvm.arith.DeduceBound(a, tvm.const(63, "int32")>= e0, {b: b_s}, {b: b_s})
-        [t, x] = [res1.max(), b_s.max()] if coff > 0 else [res1.min(), b_s.min()]
-        assert (tvm.ir_pass.Simplify(((x*3 + t* coff) * 4) <= 63)).value == 1
-
-        res1 = tvm.arith.DeduceBound(a, e0>63, {b: b_s}, {b: b_s})
-        [t, x] = [res1.max(), b_s.max()] if coff < 0 else [res1.min(), b_s.min()]
-        assert (tvm.ir_pass.Simplify(((x*3 + t* coff) * 4) > 63)).value == 1
-
-        # expression containing variable a is on rhs
-        res1 = tvm.arith.DeduceBound(a, tvm.const(63, "int32") <= e0, {b: b_s}, {b: b_s})
-        [t, x] = [res1.max(), b_s.max()] if coff < 0 else [res1.min(), b_s.min()]
-        assert (tvm.ir_pass.Simplify(((x*3 + t* coff) * 4) >= 63)).value == 1
-
-    test_complex(0, 4, 4)
-    test_complex(0, 4, -4)
-    test_complex(2, 6, 4)
-    test_complex(0, 4, -4)
-    test_complex(1, 5, -4)
-    test_complex(2, 6, -4)
+    assert s.min_value.value == base
+    assert s.max_value.value == base + stride * lanes - 1
+
+
+def test_add_sub():
+    ck = IntSetChecker()
+    x, y = tvm.var("x"), tvm.var("y")
+    ck.verify(x + y, {x : tvm.arith.IntervalSet(0, 10)}, (y, 10 + y))
+    ck.verify(x + y,
+              {x : tvm.arith.IntervalSet(0, 10), y : tvm.arith.IntervalSet(1, 11)},
+              (1, 21))
+    ck.verify(x - y,
+              {x : tvm.arith.IntervalSet(0, 10), y : tvm.arith.IntervalSet(1, 11)},
+              (-11, 9))
+
+def test_mul_div():
+    ck = IntSetChecker()
+    x, y = tvm.var("x"), tvm.var("y")
+    ck.analyzer.update(y, tvm.arith.ConstIntBound(1, 100), override=True)
+    ck.verify(x * y, {x : tvm.arith.IntervalSet(0, 10)}, (0, 10 * y))
+    ck.verify(x * 2, {x : tvm.arith.IntervalSet(1, 10)}, (2, 20))
+    ck.verify(x * -2, {x : tvm.arith.IntervalSet(1, 10)}, (-20, -2))
+    ck.verify(x / y, {x : tvm.arith.IntervalSet(0, 10)}, (0, 10 / y))
+    ck.verify(x / 2, {x : tvm.arith.IntervalSet(1, 10)}, (0, 5))
+
+
+def test_mod():
+    ck = IntSetChecker()
+    x, y = tvm.var("x"), tvm.var("y")
+    ck.analyzer.update(y, tvm.arith.ConstIntBound(1, 100), override=True)
+    ck.verify(x % y, {x : tvm.arith.IntervalSet(0, 10)}, (0, y - 1))
+    ck.verify(x % 10, {x : tvm.arith.IntervalSet(1, 10)}, (0, 9))
+
+def test_max_min():
+    ck = IntSetChecker()
+    x, y = tvm.var("x"), tvm.var("y")
+    ck.verify(tvm.max(x, x + 1), {x : tvm.arith.IntervalSet(0, 10)}, (1, 11))
+    ck.verify(tvm.min(x - 1, x + 1), {x : tvm.arith.IntervalSet(0, 10)}, (-1, 9))
+    ck.verify(tvm.min(x, y), {}, (tvm.min(x, y), tvm.min(x, y)))
+    ck.verify(tvm.max(x, y), {}, (tvm.max(x, y), tvm.max(x, y)))
+
+
+def test_select():
+    ck = IntSetChecker()
+    x, y = tvm.var("x"), tvm.var("y")
+    ck.verify(tvm.expr.Select(x > 0, x - 1, x + 1),
+              {x : tvm.arith.IntervalSet(0, 10)}, (-1, 11))
+
 
 if __name__ == "__main__":
     test_basic()
     test_vector()
-    test_deduce()
-    test_check()
-    test_deduce_basic()
-    test_deduce_complex()
+    test_add_sub()
+    test_mul_div()
+    test_max_min()
+    test_select()
+    test_mod()
+