[Arith] linear system and equation solver (#5171)
authorYizhi Liu <liuyizhi@apache.org>
Fri, 10 Apr 2020 15:11:21 +0000 (08:11 -0700)
committerGitHub <noreply@github.com>
Fri, 10 Apr 2020 15:11:21 +0000 (08:11 -0700)
* [arith] linear system and equation solver

Co-authored-by: Sergei Grechanik <sergei.grechanik+h@gmail.com>
* avoid constructing analyzer every time

* generate random test cases and address comments

Co-authored-by: Sergei Grechanik <sergei.grechanik@gmail.com>
* rename linear_system to int_constraints

* add comments and use random seed

* message for reporting failure with seed

* add SEqualReduce to IntConstraints; allow variables & ranges to be None

Co-authored-by: Sergei Grechanik <sergei.grechanik+h@gmail.com>
Co-authored-by: Sergei Grechanik <sergei.grechanik@gmail.com>
include/tvm/arith/analyzer.h
include/tvm/arith/int_solver.h [new file with mode: 0644]
include/tvm/arith/util.h [new file with mode: 0644]
python/tvm/arith/__init__.py
python/tvm/arith/int_solver.py [new file with mode: 0644]
src/arith/analyzer.cc
src/arith/int_constraints.cc [new file with mode: 0644]
src/arith/solve_linear_equation.cc [new file with mode: 0644]
src/arith/util.cc [new file with mode: 0644]
tests/python/unittest/test_arith_solve_linear_system.py [new file with mode: 0644]

index 1889e16..3a71e5e 100644 (file)
@@ -424,6 +424,12 @@ class Analyzer {
    */
   void Bind(const Var& var, const Range& range);
   /*!
+   * \brief Bind all the vars in the Map
+   *
+   * \param variables The {variable -> range} map.
+   */
+  void Bind(const Map<Var, Range>& variables);
+  /*!
    * \brief Whether can we prove expr >= val.
 
    *  Non-negative proof is very useful in integer analysis
diff --git a/include/tvm/arith/int_solver.h b/include/tvm/arith/int_solver.h
new file mode 100644 (file)
index 0000000..57f3af4
--- /dev/null
@@ -0,0 +1,208 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tvm/arith/int_solver.h
+ * \brief integer constraints data structures and solvers
+ */
+#ifndef TVM_ARITH_INT_SOLVER_H_
+#define TVM_ARITH_INT_SOLVER_H_
+
+#include <tvm/ir/expr.h>
+#include <tvm/tir/expr.h>
+#include <unordered_map>
+#include <vector>
+
+namespace tvm {
+namespace arith {
+
+using tir::Var;
+using tir::VarNode;
+using tir::IterVar;
+
+/*!
+ * \brief Represent integer constrains including (integer) variables, their ranges and
+ *        the relations between them (either equations or inequalities).
+ * \sa LinearSystem
+ */
+class IntConstraintsNode : public Object {
+ public:
+  // e.g., \alpha, \beta, must be integers
+  Array<Var> variables;
+  // e.g., 1 <= \alpha <= N, etc.
+  // it is absolutely ok to include ranges for parameters
+  // (variables that are not in this->variables) in this map
+  Map<Var, Range> ranges;
+  // linear equalities or inequalities
+  // e.g., A \alpha = \beta or A \alpha <= \beta
+  Array<PrimExpr> relations;
+
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    v->Visit("variables", &variables);
+    v->Visit("ranges", &ranges);
+    v->Visit("relations", &relations);
+  }
+
+  bool SEqualReduce(const IntConstraintsNode* other, SEqualReducer equal) const {
+    return
+        equal(variables, other->variables) &&
+        equal(ranges, other->ranges) &&
+        equal(relations, other->relations);
+  }
+
+  void SHashReduce(SHashReducer hash_reduce) const {
+    hash_reduce(variables);
+    hash_reduce(ranges);
+    hash_reduce(relations);
+  }
+
+  static constexpr const bool _type_has_method_sequal_reduce = true;
+  static constexpr const char* _type_key = "arith.IntConstraints";
+  TVM_DECLARE_FINAL_OBJECT_INFO(IntConstraintsNode, Object);
+};
+
+/*!
+ * \brief Managed reference to IntConstraintsNode.
+ * \sa IntConstraintsNode
+ */
+class IntConstraints : public ObjectRef {
+ public:
+  /*!
+   * \brief Constructor by fields
+   * \param variables The variables in the constraints, must be integers.
+   * \param ranges    The ranges of the variables.
+   * \param relations The linear relations between the variables
+   *                  (either equations or inequalities)
+   */
+  TVM_DLL IntConstraints(Array<Var> variables,
+                         Map<Var, Range> ranges,
+                         Array<PrimExpr> relations);
+
+  TVM_DEFINE_OBJECT_REF_METHODS(IntConstraints, ObjectRef, IntConstraintsNode);
+};
+
+/*!
+ * \brief We can have different set of variables to represent the same constraints.
+ *        For example, the following two systems are equivalent,
+ *        {a + b = 0 | a >= 0, b >= 0} and
+ *        {m - n = 0 | m >= 0, n <= 0}
+ *        This data structure represents the transformation
+ *        between two equivalent linear systems.
+ *        In the above example,
+ *        src        : {a + b = 0 | a >= 0, b >= 0}
+ *        dst        : {m - n = 0 | m >= 0, n <= 0}
+ *        src_to_dst : {a -> m, b -> -n}
+ *        dst_to_src : {m -> a, n -> -b}
+ * \sa IntConstraintsTransform
+ */
+class IntConstraintsTransformNode : public Object {
+ public:
+  IntConstraints src;
+  IntConstraints dst;
+  Map<Var, PrimExpr> src_to_dst;
+  Map<Var, PrimExpr> dst_to_src;
+
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    v->Visit("src", &src);
+    v->Visit("dst", &dst);
+    v->Visit("src_to_dst", &src_to_dst);
+    v->Visit("dst_to_src", &dst_to_src);
+  }
+
+  bool SEqualReduce(const IntConstraintsTransformNode* other, SEqualReducer equal) const {
+    return
+        equal(src, other->src) &&
+        equal(dst, other->dst) &&
+        equal(src_to_dst, other->src_to_dst) &&
+        equal(dst_to_src, other->dst_to_src);
+  }
+
+  void SHashReduce(SHashReducer hash_reduce) const {
+    hash_reduce(src);
+    hash_reduce(dst);
+    hash_reduce(src_to_dst);
+    hash_reduce(dst_to_src);
+  }
+
+  static constexpr const bool _type_has_method_sequal_reduce = true;
+  static constexpr const char* _type_key = "arith.IntConstraintsTransform";
+  TVM_DECLARE_FINAL_OBJECT_INFO(IntConstraintsTransformNode, Object);
+};
+
+/*!
+ * \brief Managed reference to IntConstraintsTransformNode.
+ * \sa IntConstraintsTransformNode
+ */
+class IntConstraintsTransform : public ObjectRef {
+ public:
+  /*!
+   * \brief Constructor by fields
+   * \param src        source integer constraints, e.g., {a + b = 0 | a >= 0, b >= 0}
+   * \param dst        integer constraints equivalent to the source,
+   *                   e.g., {m - n = 0 | m >= 0, n <= 0}
+   * \param src_to_dst mapping from variables in the \p src to the variables in the \p dst,
+   *                   e.g., {a -> m, b -> -n}
+   * \param dst_to_src mapping from variables in the \p dst to the variables in the \p src,
+   *                   e.g., {m -> a, n -> -b}
+   */
+  TVM_DLL IntConstraintsTransform(IntConstraints src,
+                                  IntConstraints dst,
+                                  Map<Var, PrimExpr> src_to_dst,
+                                  Map<Var, PrimExpr> dst_to_src);
+
+  TVM_DEFINE_OBJECT_REF_METHODS(IntConstraintsTransform, ObjectRef, IntConstraintsTransformNode);
+};
+
+/*!
+ * \brief Obtain Smith Normal Form of linear equation A x = y.
+ *        Smith Normal Form of matrix A_{mxn} is S_{mxn} = U_{mxm} A_{mxn} V_{nxn},
+ *        in which S_{mxn} is diag(s1, s2, ..., sr, 0, ..., 0) and r is the rank of A.
+ *        NOTE: Although in standard Smith Normal Form the diagonal elements satisfy
+ *              s_i | s_{i+1} (| means divides), the implement here does not guarantee it.
+ *        TODO(yzhliu): From sergei-grechanik:
+ *          computing the proper Smith normal form may improve stability of automatic differentiation
+ *          (generating the same gradient code for slightly different but equivalent input code
+ *        U_{mxm} and V_{nxn} are invertible matrices.
+ *        This function modifies \p S to be S_{mxn}, \p V to be V_{nxn},
+ *        \p y to be U_{mxm} y_{mx1} and \p x to be V^{-1} x.
+ * \param S  the original A_{mxn}, it will be modified to S_{mxn}
+ * \param V  an identity matrix, it will be modified to V_{nxn}
+ * \param x  the x in A x = y. it will be modified to V^{-1}_{nxn} x_{nx1}
+ * \param y  the y in A x = y. it will be modified to U_{mxm} y_{mx1}
+ */
+void SmithNormalFormDiag(std::vector<std::vector<int64_t>> *S,
+                         std::vector<std::vector<int64_t>> *V,
+                         std::vector<PrimExpr>* x,
+                         std::vector<PrimExpr> *y);
+
+/*!
+ * \brief Solve linear equations.
+ * \param system_to_solve the variables to solve, their ranges, and a list of equations.
+ * \return  A new linear system, with less variables (if \p system_to_solve is NOT of full rank),
+ *          or no variable (if \p system_to_solve is of full rank),
+ *          or an empty linear system (if \p system_to_solve is unsolvable).
+ *          It also provides the ranges of the variables in the new system,
+ *          as well as inequalities inferred from the \p system_to_solve.
+ *          You can get the mapping from the original variables to the solution via ret->src_to_dst.
+ */
+IntConstraintsTransform SolveLinearEquations(const IntConstraints &system_to_solve);
+
+}  // namespace arith
+}  // namespace tvm
+#endif  // TVM_ARITH_INT_SOLVER_H_
diff --git a/include/tvm/arith/util.h b/include/tvm/arith/util.h
new file mode 100644 (file)
index 0000000..adfcefc
--- /dev/null
@@ -0,0 +1,45 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tvm/arith/util.h
+ * \brief Utils for arithmetic analysis.
+ */
+#ifndef TVM_ARITH_UTIL_H_
+#define TVM_ARITH_UTIL_H_
+
+#include <cstdint>
+#include <tuple>
+
+namespace tvm {
+/*! \brief namespace of arithmetic analysis. */
+namespace arith {
+
+/*!
+ * \brief Calculate the extended greatest common divisor for two values.
+ *        See https://en.wikipedia.org/wiki/Extended_Euclidean_algorithm.
+ * \param a an integer number
+ * \param b an integer number
+ * \return 3 integers (div, m, n) where div = gcd(a, b) and a*m + b*n = div
+ */
+std::tuple<int64_t, int64_t, int64_t> xgcd(int64_t a, int64_t b);
+
+}  // namespace arith
+}  // namespace tvm
+#endif  // TVM_ARITH_UTIL_H_
index 40e977e..017934a 100644 (file)
@@ -20,3 +20,4 @@ from .int_set import IntSet, IntervalSet
 from .analyzer import ModularSet, ConstIntBound, Analyzer
 from .bound import deduce_bound
 from .pattern import detect_linear_equation, detect_clip_bound
+from .int_solver import solve_linear_equations
diff --git a/python/tvm/arith/int_solver.py b/python/tvm/arith/int_solver.py
new file mode 100644 (file)
index 0000000..e35435c
--- /dev/null
@@ -0,0 +1,99 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""integer constraints data structures and solvers"""
+import tvm._ffi
+from tvm.runtime import Object
+from . import _ffi_api
+
+
+@tvm._ffi.register_object("arith.IntConstraints")
+class IntConstraints(Object):
+    """Represent a set of integer constraints including variables, their ranges and
+       the relations between them (either equations or inequalities)
+
+    Parameters
+    ----------
+    variables : List[tvm.tir.Var]
+        The variables in the constraints. Must be integers
+    ranges    : Map[tvm.tir.Var, tvm.ir.Range]
+        The ranges of the variables.
+    relations : List[tvm.ir.PrimExpr]
+        The relations between the variables (either equations or inequalities)
+    """
+    def __init__(self, variables, ranges, relations):
+        self.__init_handle_by_constructor__(
+            _ffi_api.IntConstraints, variables, ranges, relations)
+
+
+@tvm._ffi.register_object("arith.IntConstraintsTransform")
+class IntConstraintsTransform(Object):
+    """We can have different set of variables to represent the same integer constraints.
+       For example, the following two constrains are equivalent,
+       {a + b = 0 | a >= 0, b >= 0} and
+       {m - n = 0 | m >= 0, n <= 0}
+       This data structure represents the transformation
+       between two equivalent integer constraints.
+       In the above example,
+       src        : {a + b = 0 | a >= 0, b >= 0}
+       dst        : {m - n = 0 | m >= 0, n <= 0}
+       src_to_dst : {a -> m, b -> -n}
+       dst_to_src : {m -> a, n -> -b}
+
+    Parameters
+    ----------
+    src : arith.IntConstraints
+        source integer constraints, e.g., {a + b = 0 | a >= 0, b >= 0}
+    dst : arith.IntConstraints
+        integer constraints equivalent to the source, e.g., {m - n = 0 | m >= 0, n <= 0}
+    src_to_dst : Map[tvm.tir.Var, tvm.ir.PrimExpr]
+        mapping from variables in the src to the variables in the dst,
+                e.g., {a -> m, b -> -n}
+    dst_to_src : Map[tvm.tir.Var, tvm.ir.PrimExpr]
+        mapping from variables in the dst to the variables in the src,
+        e.g., {m -> a, n -> -b}
+    """
+    def __init__(self, src, dst, src_to_dst, dst_to_src):
+        self.__init_handle_by_constructor__(
+            _ffi_api.IntConstraintsTransform, src, dst, src_to_dst, dst_to_src)
+
+
+def solve_linear_equations(equations, variables=None, ranges=None):
+    """Solve linear equations.
+
+    Parameters
+    ----------
+    equations: List[tvm.ir.PrimExpr] or IntConstraints
+        The equations of the variables
+    variables : Optional[List[tvm.tir.Var]]
+        The variables in the system.
+    ranges    : Optional[Map[tvm.tir.Var, tvm.ir.Range]]
+        The ranges of the variables.
+
+    Returns
+    -------
+    int_constraints_transform : IntConstraintsTransform
+        New integer constraints, with less variables (if the problem is NOT of full rank),
+        or no variable (if the problem is of full rank),
+        or an empty integer constraints (if the problem is unsolvable).
+        It also provides the ranges of the variables in the new system,
+        as well as inequalities inferred from the problem.
+        You can get the mapping from the original variables to the solution via
+        int_constraints_transform.src_to_dst.
+    """
+    if isinstance(equations, IntConstraints):
+        return _ffi_api.SolveLinearEquations(equations)
+    return _ffi_api.SolveLinearEquations(variables, ranges, equations)
index 9df5aa2..83dfc64 100644 (file)
@@ -58,6 +58,11 @@ void Analyzer::Bind(const Var& var, const Range& range) {
   // skip rewrite simplify
 }
 
+void Analyzer::Bind(const Map<Var, Range>& variables) {
+  for (const auto& iter : variables) {
+    this->Bind(iter.first, iter.second);
+  }
+}
 
 void ConstraintContext::EnterWithScope() {
   CHECK(exit_ == nullptr);
diff --git a/src/arith/int_constraints.cc b/src/arith/int_constraints.cc
new file mode 100644 (file)
index 0000000..34efa98
--- /dev/null
@@ -0,0 +1,96 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file int_constraints.cc
+ * \brief The integer constraints data structures.
+ */
+#include <tvm/arith/int_solver.h>
+#include <tvm/tir/expr.h>
+#include <tvm/tir/expr_functor.h>
+#include <tvm/runtime/registry.h>
+
+#include <utility>
+#include <algorithm>
+#include <unordered_map>
+
+namespace tvm {
+namespace arith {
+
+IntConstraints::IntConstraints(Array<Var> variables,
+                               Map<Var, Range> ranges,
+                               Array<PrimExpr> relations) {
+  ObjectPtr<IntConstraintsNode> node = make_object<IntConstraintsNode>();
+  if (!variables.defined()) {
+    variables = Array<Var>();
+  }
+  if (!ranges.defined()) {
+    ranges = Map<Var, Range>();
+  }
+  CHECK(relations.defined());
+  for (const auto& var : variables) {
+    CHECK(var.dtype().is_int() || var.dtype().is_uint())
+      << "Variables in IntConstraints must be integers";
+  }
+  node->variables = std::move(variables);
+  node->ranges = std::move(ranges);
+  node->relations = std::move(relations);
+  data_ = std::move(node);
+}
+
+TVM_REGISTER_NODE_TYPE(IntConstraintsNode);
+
+TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
+.set_dispatch<IntConstraintsNode>([](const ObjectRef& node, ReprPrinter* p) {
+    auto* op = static_cast<const IntConstraintsNode*>(node.get());
+    p->stream << "IntConstraints("
+              << op->variables
+              << ", " << op->ranges
+              << ", " << op->relations
+              << ")";
+  });
+
+
+IntConstraintsTransform::IntConstraintsTransform(IntConstraints src,
+                                                 IntConstraints dst,
+                                                 Map<Var, PrimExpr> src_to_dst,
+                                                 Map<Var, PrimExpr> dst_to_src) {
+  ObjectPtr<IntConstraintsTransformNode> node = make_object<IntConstraintsTransformNode>();
+  node->src = std::move(src);
+  node->dst = std::move(dst);
+  node->src_to_dst = std::move(src_to_dst);
+  node->dst_to_src = std::move(dst_to_src);
+  data_ = std::move(node);
+}
+
+TVM_REGISTER_NODE_TYPE(IntConstraintsTransformNode);
+
+TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
+.set_dispatch<IntConstraintsTransformNode>([](const ObjectRef& node, ReprPrinter* p) {
+    auto* op = static_cast<const IntConstraintsTransformNode*>(node.get());
+    p->stream << "IntConstraintsTransform("
+              << "\n\t" << op->src
+              << "\n\t" << op->dst
+              << "\n\t" << op->src_to_dst
+              << "\n\t" << op->dst_to_src
+              << "\n)";
+  });
+
+}  // namespace arith
+}  // namespace tvm
diff --git a/src/arith/solve_linear_equation.cc b/src/arith/solve_linear_equation.cc
new file mode 100644 (file)
index 0000000..8142a03
--- /dev/null
@@ -0,0 +1,480 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tvm/arith/solve_linear_equation.cc
+ * \brief Solve linear equations.
+ */
+#include <tvm/runtime/registry.h>
+#include <tvm/tir/expr.h>
+#include <tvm/arith/analyzer.h>
+#include <tvm/arith/int_solver.h>
+#include <tvm/arith/util.h>
+#include <tvm/tir/op.h>
+#include <tvm/arith/pattern.h>
+#include <tvm/tir/ir_pass.h>
+#include <tvm/runtime/data_type.h>
+
+namespace tvm {
+namespace arith {
+
+using namespace tvm::runtime;
+
+void SmithNormalFormDiag(std::vector<std::vector<int64_t> >* S,
+                         std::vector<std::vector<int64_t> >* V,
+                         std::vector<PrimExpr>* x,
+                         std::vector<PrimExpr>* y) {
+  if (S->empty() || V->empty()) return;
+  size_t m = S->size();
+  size_t n = (*S)[0].size();  // n is # of variables
+  CHECK_EQ(V->size(), n);
+  CHECK_EQ((*V)[0].size(), n);
+
+  for (size_t index = 0; index < std::min(m, n); ++index) {
+    // Here A is partially diagonalized, that is A[i, j] is zero for all i, j
+    // such that (i < index) or (j < index), unless (i == j).
+    // That is, now we are diagonalizing the submatrix with i >= index and j >= index
+
+    // Find a row with a nonzero element in the index-th column
+    // (We also prefer rows where this element has minimal abs value)
+    size_t best_i = index;
+    for (size_t i = best_i; i < m; ++i) {
+      int64_t s_old = (*S)[best_i][index];
+      int64_t s_new = (*S)[i][index];
+      if (s_new != 0) {
+        if (s_old == 0 || std::abs(s_new) < std::abs(s_old)) {
+          best_i = i;
+        }
+      }
+    }
+    // Move the row we found to the index-th position
+    std::swap((*S)[index], (*S)[best_i]);
+    std::swap((*y)[index], (*y)[best_i]);
+
+    // If the index-th diagonal element is still zero, try to find a column with nonzero index-th
+    // element and move it to the index-th position
+    if ((*S)[index][index] == 0) {
+      for (size_t j = index + 1; j < n; ++j) {
+        if ((*S)[index][j] != 0) {
+          for (size_t i = index; i < m; ++i) {
+            std::swap((*S)[i][index], (*S)[i][j]);
+          }
+          // swapping columns corresponds to swapping the corresponding x
+          std::swap((*x)[index], (*x)[j]);
+          for (size_t i = 0; i < n; ++i) {
+            std::swap((*V)[i][index], (*V)[i][j]);
+          }
+          break;
+        }
+      }
+    }
+
+    // If the index-th diagonal element is still zero, then both the index-th row and the index-th
+    // column are completely zero, and we don't need to do anything; just go to the next index
+    if ((*S)[index][index] == 0) {
+      continue;
+    }
+
+    // Now the index-th diagonal element is non-zero and we can zero all the index-th column
+    // below it by subtracting rows from each other
+    for (auto i = index + 1; i < m; ++i) {
+      if ((*S)[i][index] != 0) {
+        int64_t g, a, b;
+        // g = a*matrix[index][index] + b*matrix[i][index]
+        if ((*S)[i][index] % (*S)[index][index] != 0) {
+          std::tie(g, a, b) = xgcd((*S)[index][index], (*S)[i][index]);
+        } else {
+          // Explicitly avoid changing the index-th row. This is important to avoid infinite loop.
+          g = (*S)[index][index];
+          a = 1;
+          b = 0;
+        }
+
+        // Let m = S[index][index], n = S[i][index], then the following is true:
+        //
+        // [ a   n/g ][ m/g  n/g ] = [ 1  0 ]
+        // [ b  -m/g ][ b    -a  ] = [ 0  1 ]
+        //
+        // Note that the two matrices are integer (since g = gcd(m, n)).
+        // We will essentially multiply our matrix on the left by a dilated and transposed version
+        // of the first of these two matrices. The second matrix is not needed here, however we will
+        // use it while zeroing the index-th row.
+
+        int64_t m_g = (*S)[index][index] / g;
+        int64_t n_g = (*S)[i][index] / g;
+
+        // Note that j is the index of the column, not the row
+        for (size_t j = index; j < (*S)[i].size(); ++j) {
+          // Multiply index-th row by a and add the i-th row multiplied by b
+          // This will make the index-th diagonal element equal to the gcd
+          int64_t new_index_j = a*(*S)[index][j] + b*(*S)[i][j];
+          // This transformation performs zeroing of matrix[i][index]
+          int64_t new_i_j = n_g*(*S)[index][j] - m_g*(*S)[i][j];
+          (*S)[index][j] = new_index_j;
+          (*S)[i][j] = new_i_j;
+        }
+        // We have to do the same with rhs
+        PrimExpr ea = te::make_const((*y)[index].dtype(), a);
+        PrimExpr eb = te::make_const((*y)[i].dtype(), b);
+        PrimExpr e_m_g = te::make_const((*y)[i].dtype(), m_g);
+        PrimExpr e_n_g = te::make_const((*y)[index].dtype(), n_g);
+        PrimExpr new_index_rhs = ea*(*y)[index] + eb*(*y)[i];
+        PrimExpr new_i_rhs = e_n_g*(*y)[index] - e_m_g*(*y)[i];
+        (*y)[index] = new_index_rhs;
+        (*y)[i] = new_i_rhs;
+      }
+    }
+
+    bool changed = false;
+
+    // Now we have to zero the elements of the index-th row by manipulating columns.
+    // This is more difficult because column manipulation corresponds to variable manipulation,
+    // but the algorithm is essentially the same as before.
+    for (size_t j = index + 1; j < n; ++j) {
+      if ((*S)[index][j] != 0) {
+        int64_t g, a, b;
+        // g = a*matrix[index][index] + b*matrix[index][j]
+        if ((*S)[index][j] % (*S)[index][index] != 0) {
+          std::tie(g, a, b) = xgcd((*S)[index][index], (*S)[index][j]);
+          // During this phase we may disrupt the zeroness of the index-th column, so we will
+          // have to take some action if this might have happened.
+          changed = true;
+        } else {
+          // Explicitly avoid changing the index-th column. This is important to avoid infinite
+          // loop. Note that here we don't have to set `changed` to true since we don't change the
+          // index-th column.
+          g = (*S)[index][index];
+          a = 1;
+          b = 0;
+        }
+
+        // Let m = S[index][index], n = S[index][j], then the following is true:
+        //
+        // [ a   n/g ][ m/g  n/g ] = [ 1  0 ]
+        // [ b  -m/g ][ b    -a  ] = [ 0  1 ]
+        //
+        // Now we are going to multiply our matrix on the right (to manipulate columns instead of
+        // rows), we will also transform the old_to_new matrix the same way, and we will use the
+        // second matrix to transform new_to_old.
+
+        int64_t m_g = (*S)[index][index] / g;
+        int64_t n_g = (*S)[index][j] / g;
+
+        for (size_t i = index; i < m; ++i) {
+          int64_t new_i_index = a*(*S)[i][index] + b*(*S)[i][j];
+          int64_t new_i_j = n_g*(*S)[i][index] - m_g*(*S)[i][j];
+          (*S)[i][index] = new_i_index;
+          (*S)[i][j] = new_i_j;
+        }
+        // We do exactly the same transformations with V
+        for (size_t i = 0; i < n; ++i) {
+          int64_t new_i_index = a*(*V)[i][index] + b*(*V)[i][j];
+          int64_t new_i_j = n_g*(*V)[i][index] - m_g*(*V)[i][j];
+          (*V)[i][index] = new_i_index;
+          (*V)[i][j] = new_i_j;
+        }
+        // And apply reverse transformations to new_to_old.
+        PrimExpr ea = te::make_const((*x)[j].dtype(), a);
+        PrimExpr eb = te::make_const((*x)[index].dtype(), b);
+        PrimExpr e_m_g = te::make_const((*x)[index].dtype(), m_g);
+        PrimExpr e_n_g = te::make_const((*x)[j].dtype(), n_g);
+        PrimExpr new_index = e_m_g*(*x)[index] + e_n_g*(*x)[j];
+        PrimExpr new_j = eb*(*x)[index] - ea*(*x)[j];
+        (*x)[index] = new_index;
+        (*x)[j] = new_j;
+      }
+    }
+
+    if (changed) {
+      // We might have changed the first column, so we have to zero it once more
+      // (or at least check if it's zero), so just perform this iteration once more.
+      index -= 1;
+    }
+  }
+}
+
+Map<Var, Range> InferRange(const Map<Var, PrimExpr>& vars_to_infer,
+                           const Array<Var>& ori_vars,
+                           const Map<Var, Range>& ori_ranges) {
+  // The resulting ranges
+  Map<Var, Range> new_ranges;
+
+  std::unordered_set<const VarNode*> ori_vset;
+  for (const Var& v : ori_vars) {
+    ori_vset.insert(v.get());
+  }
+
+  std::unordered_map<const VarNode*, IntSet> var_intsets;
+  for (const auto& p : ori_ranges) {
+    if (!ori_vset.count(p.first.get())) {
+      // First of all, fill the new ranges with outer variable ranges
+      new_ranges.Set(p.first, p.second);
+    }
+    // Convert original ranges to IntSets
+    var_intsets[p.first.get()] = IntSet::range(p.second);
+  }
+
+  // Infer ranges for the new variables and add them to the resulting ranges
+  for (const auto& p : vars_to_infer) {
+    const auto& var = p.first;
+    const auto& expr = p.second;
+    Range range = EvalSet(expr, var_intsets).cover_range(Range());
+    if (range.defined()) {
+      new_ranges.Set(var, range);
+    }
+  }
+  return new_ranges;
+}
+
+// pretty print matrix equation
+void DebugPrint(const std::vector<std::vector<int64_t>>& S,
+                const std::vector<std::vector<int64_t>>& V,
+                const std::vector<PrimExpr>& V_inv_x,
+                const std::vector<PrimExpr>& rhs) {
+  std::cout << "S:\n";
+  for (size_t i = 0; i < S.size(); ++i) {
+    for (auto e : S[i]) {
+      std::cout << e << "\t";
+    }
+    std::cout << "\t->\t" << rhs[i];
+    std::cout << "\n";
+  }
+  std::cout << "V:\n";
+  for (const auto& r : V) {
+    for (auto e : r) {
+      std::cout << e << "\t";
+    }
+    std::cout << "\n";
+  }
+  std::cout << "V_inv x:\n" << Array<PrimExpr>(V_inv_x);
+  std::cout << "\n" << std::endl;
+}
+
+IntConstraintsTransform SolveLinearEquations(const IntConstraints &system_to_solve) {
+  // m: # of equations
+  // n: # of variables
+  // we first construct A_{mxn} x_{nx1} = y_{mx1}
+  // then get Smith normal form of matrix A,
+  // S_{mxn} = U_{mxm} A_{mxn} V_{nxn}
+  // => U^{-1} S V^{-1} x = y
+  // S V^{-1} x = U y
+  std::vector<PrimExpr> Uy;  // mx1
+  std::vector<std::vector<int64_t>> S;  // mxn
+  std::vector<std::vector<int64_t>> V;  // nxn
+  std::vector<PrimExpr> V_inv_x;  // V^{-1} x, nx1
+  // Conditions we don't know what to do with
+  std::vector<PrimExpr> rest;
+
+  Analyzer analyzer_problem;
+  analyzer_problem.Bind(system_to_solve->ranges);
+
+  size_t num_vars = system_to_solve->variables.size();
+
+  // initialize V_{nxn} with identity matrix,
+  // initialize V^{-1} x as x
+  for (size_t i = 0; i < num_vars; ++i) {
+    V.emplace_back(num_vars);
+    V.back()[i] = 1;
+    V_inv_x.push_back(system_to_solve->variables[i]);
+  }
+
+  // Transform formulas into rows of the matrix
+  // S_{mxn} V^{-1}_{nxn} x_{nx1} = U y, in which n is # of variables
+  // here we initialize S_{mxn} to be A, U to be identity matrix.
+  for (const PrimExpr& equation : system_to_solve->relations) {
+    if (const tir::EQNode* eq = equation.as<tir::EQNode>()) {
+      // a-b = sum_{i=0}^{n-1} variables[i] * coeff[i] + coeff[n]
+      Array<PrimExpr> coeffs = arith::DetectLinearEquation(
+          analyzer_problem.Simplify(eq->a - eq->b),
+          system_to_solve->variables);
+      if (!coeffs.empty()) {
+        std::vector<int64_t> row;
+        for (size_t j = 0; j < coeffs.size() - 1; ++j) {
+          PrimExpr c = coeffs[j];
+          if (const IntImmNode* ic = c.as<IntImmNode>()) {
+            row.push_back(ic->value);
+          } else {
+            // elements in matrix S V must be integers
+            // ignore equations that we cannot deal with.
+            LOG(WARNING) << "Cannot deal with non-integer coefficients, ignore equation "
+                         << equation;
+            row.clear();
+            break;
+          }
+        }
+
+        if (!row.empty()) {
+          // S V^{-1} (a-b) = Uy
+          // V is identity for now
+          S.push_back(row);
+          Uy.push_back(-coeffs[coeffs.size() - 1]);
+          continue;
+        }
+      }
+    }
+
+    // otherwise
+    rest.push_back(equation);
+  }
+
+  // After diagonalizing, we have
+  // S_{mxn} is the Smith normal form (diagonal matrix)
+  // V_{nxn} is invertible
+  // V_inv_x is V^{-1} \times x
+  // Uy is U \times y
+  SmithNormalFormDiag(&S, &V, &V_inv_x, &Uy);
+
+  Array<Var> new_vars;
+  Array<PrimExpr> new_relations;
+  Map<Var, PrimExpr> new_to_old_map;
+  Map<Var, PrimExpr> old_to_new_map;
+
+  // Simplify right hand sides
+  for (PrimExpr r : Uy) {
+    r = analyzer_problem.Simplify(r);
+  }
+
+  // Create the relations of the existence of a solution
+  for (size_t j = 0; j < S.size(); ++j) {
+    PrimExpr new_relation;
+    if (j >= num_vars || S[j][j] == 0) {
+      // The row of matrix is zero. A solution exists only if the Ub[j] is also zero
+      new_relation = (Uy[j] == 0);
+    } else {
+      // The diagonal element is non-zero. A solution exists only if the diagonal element
+      // is a divisor of the Ub[j]
+      new_relation = (floormod(Uy[j], std::abs(S[j][j])) == 0);
+    }
+    new_relation = analyzer_problem.Simplify(new_relation);
+    if (tir::is_const_int(new_relation, 0)) {
+      // unable to solve the system.
+      return IntConstraintsTransform(
+          system_to_solve,
+          IntConstraints(
+              /*variables=*/{},
+              /*ranges=*/{},
+              /*relations=*/{te::make_zero(DataType::Bool())}),
+          {}, {});
+    } else if (!tir::is_const_int(new_relation, 1)) {
+      new_relations.push_back(new_relation);
+    }
+  }
+
+  Array<PrimExpr> solution_for_V_inv_x;
+  // Now create new variables or directly solve the equations
+  // suppose the rank of A is r, aka r = # of non-zeros in S
+  // the solution of S_{mxn} V^{-1}_{nxn} x_{nx1} = U b
+  // is
+  // x = (pseudo-inverse of A) b + K_{(n)x(n-r)} z_{n-r}
+  //   = V_{nxn} S^{-1}_{nxm} (Ub)_{mxn} + K_{(n)x(n-r)} z_{n-r}
+  // in which K is the right n-r columns of V, z is variable vector
+  // thus,
+  // V^{-1} x = S^{-1}_{nxm} (Ub)_{mxn} +
+  //            [[0, ... 0]_{n-r}, ... [0, ..., 0], diag(1, ..., 1)_{(n-r)x(n-r)}] z_{n-r}
+  for (size_t j = 0; j < num_vars; ++j) {
+    if (j >= S.size() || S[j][j] == 0) {
+      // The j-th variable can take any integer value, create a tvm variable for it
+      PrimExpr to_old = analyzer_problem.Simplify(V_inv_x[j]);
+      std::string name_hint = "n" + std::to_string(new_vars.size());
+      if (const VarNode* v_old = to_old.as<VarNode>()) {
+        name_hint += "_" + v_old->name_hint;
+      }
+      Var v = Var(name_hint, V_inv_x[j].dtype());
+      solution_for_V_inv_x.push_back(v);
+      new_vars.push_back(v);
+      new_to_old_map.Set(v, to_old);
+    } else {
+      // The j-th variable is just a single value, don't create a tvm variable
+      // S^{-1}_{nxm} Uy_{mxn}
+      if (S[j][j] >= 0) {
+        PrimExpr a = te::make_const(Uy[j].dtype(), S[j][j]);
+        solution_for_V_inv_x.push_back(
+            analyzer_problem.Simplify(floordiv(Uy[j], a)));
+      } else {
+        // This is required because some simplifiers
+        // have problems with dividing by negative numbers
+        PrimExpr a = te::make_const(Uy[j].dtype(), -S[j][j]);
+        solution_for_V_inv_x.push_back(
+            analyzer_problem.Simplify(floordiv(-Uy[j], a)));
+      }
+    }
+  }
+
+  // V V^{-1} x = x
+  for (size_t i = 0; i < num_vars; ++i) {
+    PrimExpr e = te::make_zero(system_to_solve->variables[i].dtype());
+    for (size_t j = 0; j < num_vars; ++j) {
+      e = e + te::make_const(e.dtype(), V[i][j])*solution_for_V_inv_x[j];
+    }
+    e = analyzer_problem.Simplify(e);
+    old_to_new_map.Set(system_to_solve->variables[i], e);
+  }
+
+  // The resulting ranges
+  Map<Var, Range> new_ranges = InferRange(
+      new_to_old_map, system_to_solve->variables, system_to_solve->ranges);
+  Analyzer analyzer_solution;
+  analyzer_solution.Bind(new_ranges);
+
+  // We have to transform ranges of the old variables into relations over new variables because
+  // new ranges are not enough usually.
+  for (const auto& p : system_to_solve->ranges) {
+    const Var& old_var = p.first;
+    const Range& old_range = p.second;
+    if (old_to_new_map.count(old_var)) {
+      PrimExpr express_by_new_vars = old_to_new_map[old_var];
+      PrimExpr lower_cond = analyzer_solution.Simplify(
+          old_range->min <= express_by_new_vars);
+      PrimExpr upper_cond = analyzer_solution.Simplify(
+          express_by_new_vars < old_range->min + old_range->extent);
+      if (!tir::is_const_int(lower_cond, 1)) {
+        new_relations.push_back(lower_cond);
+      }
+      if (!tir::is_const_int(upper_cond, 1)) {
+        new_relations.push_back(upper_cond);
+      }
+    }
+  }
+
+  // Add the rest conditions
+  for (const PrimExpr& cond : rest) {
+    new_relations.push_back(Substitute(cond, old_to_new_map));
+  }
+
+  IntConstraints solution(new_vars, new_ranges, new_relations);
+  IntConstraintsTransform transform(
+      system_to_solve, solution, old_to_new_map, new_to_old_map);
+
+  return transform;
+}
+
+TVM_REGISTER_GLOBAL("arith.SolveLinearEquations")
+.set_body([](TVMArgs args, TVMRetValue *ret) {
+    if (args.size() == 1) {
+      *ret = SolveLinearEquations(args[0]);
+    } else if (args.size() == 3) {
+      IntConstraints problem(args[0], args[1], args[2]);
+      *ret = SolveLinearEquations(problem);
+    } else {
+      LOG(FATAL) << "arith.SolveLinearEquations expects 1 or 3 arguments, gets " << args.size();
+    }
+  });
+
+}  // namespace arith
+}  // namespace tvm
diff --git a/src/arith/util.cc b/src/arith/util.cc
new file mode 100644 (file)
index 0000000..058c3e9
--- /dev/null
@@ -0,0 +1,53 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file util.cc
+ * \brief The utils for arithmetic analysis.
+ */
+#include <tvm/arith/util.h>
+#include <dmlc/logging.h>
+
+namespace tvm {
+namespace arith {
+
+std::tuple<int64_t, int64_t, int64_t> xgcd(int64_t a, int64_t b) {
+  int64_t s = 0, old_s = 1;
+  int64_t t = 1, old_t = 0;
+  int64_t r = b, old_r = a;
+
+  while (r != 0) {
+    int64_t q = old_r / r;
+    std::swap(r, old_r);
+    r -= q * old_r;
+    std::swap(s, old_s);
+    s -= q * old_s;
+    std::swap(t, old_t);
+    t -= q * old_t;
+  }
+
+  CHECK_EQ(a % old_r, 0);
+  CHECK_EQ(b % old_r, 0);
+  CHECK(old_r == old_s*a + old_t*b);
+
+  return std::make_tuple(old_r, old_s, old_t);
+}
+
+}  // namespace arith
+}  // namespace tvm
diff --git a/tests/python/unittest/test_arith_solve_linear_system.py b/tests/python/unittest/test_arith_solve_linear_system.py
new file mode 100644 (file)
index 0000000..45f8fc1
--- /dev/null
@@ -0,0 +1,237 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import random
+import numpy as np
+import sys
+import pytest
+import tvm
+from tvm import te, arith, ir, tir
+
+
+def run_expr(expr, vranges):
+    """ Evaluate expr for every value of free variables
+    given by vranges and return the tensor of results.
+    TODO(yzhliu): move to utils
+    """
+    def _compute_body(*us):
+        vmap = {v: u + r.min for (v, r), u in zip(vranges.items(), us)}
+        return tir.ir_pass.Substitute(expr, vmap)
+
+    A = te.compute([r.extent.value for v, r in vranges.items()], _compute_body)
+    args = [tvm.nd.empty(A.shape, A.dtype)]
+    sch = te.create_schedule(A.op)
+    mod = tvm.build(sch, [A])
+    mod(*args)
+    return args[0].asnumpy()
+
+
+def check_bruteforce(bool_expr, vranges, cond=None):
+    """ Check that bool_expr holds given the condition cond
+    for every value of free variables from vranges.
+    TODO(yzhliu): move to utils
+    """
+    if cond is not None:
+        bool_expr = te.any(tir.Not(cond), bool_expr)
+
+    res = run_expr(bool_expr, vranges)
+    if not np.all(res):
+        indices = list(np.argwhere(res == 0)[0])
+        counterex = [(str(v), i + r.min) for (v, r), i in zip(vranges.items(), indices)]
+        counterex = sorted(counterex, key=lambda x: x[0])
+        counterex = ", ".join([v + " = " + str(i) for v, i in counterex])
+        raise AssertionError("Expression {}\nis not true on {}\n"
+                             "Counterexample: {}"
+                             .format(tir.ir_pass.CanonicalSimplify(bool_expr), vranges, counterex))
+
+
+def check_solution(solution, vranges={}):
+    """Check that solution is a bijective transformation"""
+    def _check_forward(constraints1, constraints2, varmap, backvarmap):
+        all_vranges = vranges.copy()
+        all_vranges.update({v: r for v, r in constraints1.ranges.items()})
+
+        # Check that the transformation is injective
+        cond_on_vars = tir.const(1, 'bool')
+        for v in constraints1.variables:
+            # variable mapping is consistent
+            v_back = tir.ir_pass.Simplify(tir.ir_pass.Substitute(varmap[v], backvarmap))
+            cond_on_vars = te.all(cond_on_vars, v == v_back)
+        # Also we have to check that the new relations are true when old relations are true
+        cond_subst = tir.ir_pass.Substitute(
+            te.all(tir.const(1, 'bool'), *constraints2.relations), backvarmap)
+        # We have to include relations from vranges too
+        for v in constraints2.variables:
+            if v in constraints2.ranges:
+                r = constraints2.ranges[v]
+                range_cond = te.all(v >= r.min, v < r.min + r.extent)
+                range_cond = tir.ir_pass.Substitute(range_cond, backvarmap)
+                cond_subst = te.all(cond_subst, range_cond)
+        cond_subst = tir.ir_pass.Simplify(cond_subst)
+        check_bruteforce(te.all(cond_subst, cond_on_vars), all_vranges,
+                         cond=te.all(tir.const(1, 'bool'), *constraints1.relations))
+
+    rels = solution.dst.relations
+    if len(rels) == 1 and ir.structural_equal(rels[0], False):
+        # not solvable, skip
+        return
+    _check_forward(solution.src, solution.dst,
+                   solution.src_to_dst, solution.dst_to_src)
+    _check_forward(solution.dst, solution.src,
+                   solution.dst_to_src, solution.src_to_dst)
+
+
+def test_solution_consistency():
+    seed = random.randrange(sys.maxsize)
+    print("\nThis test is intentionally non-deterministic, "
+          "if it fails please report it in github issue together with this seed {}\n".format(seed))
+    random.seed(seed)
+
+    def _check(num_vars, num_formulas, coef=(-5, 5), bounds=(-20, 20)):
+        variables = [te.var("x" + str(i)) for i in range(num_vars)]
+
+        relations = []
+        for i in range(num_formulas):
+            s1 = sum([v*random.randint(coef[0], coef[1]) for v in variables])
+            s1 += random.randint(coef[0], coef[1])
+            s2 = sum([v*random.randint(coef[0], coef[1]) for v in variables])
+            s2 += random.randint(coef[0], coef[1])
+            if random.random() < 0.7:
+                op = tvm.tir.EQ
+            else:
+                # we also make sure it can correctly handle inequalities
+                op = random.choice([tvm.tir.LE, tvm.tir.LT, tvm.tir.GE, tvm.tir.GT])
+            relations.append(op(s1, s2))
+
+        vranges = {v: tvm.ir.expr.Range(bounds[0], bounds[1] + 1) for v in variables}
+        solution = arith.solve_linear_equations(relations, variables, vranges)
+
+        check_solution(solution)
+
+        # leaving some variables as parameters should also be ok
+        for k in [1, 2]:
+            if len(variables) > k:
+                solution = arith.solve_linear_equations(relations, variables[:-k], vranges)
+                param_ranges = {v: vranges[v] for v in variables[-k:]}
+                check_solution(solution, param_ranges)
+
+    for i in range(2):
+        _check(num_vars=1, num_formulas=1)
+    for i in range(2):
+        _check(num_vars=1, num_formulas=2)
+
+    for i in range(2):
+        _check(num_vars=2, num_formulas=1)
+    for i in range(2):
+        _check(num_vars=2, num_formulas=2)
+    for i in range(2):
+        _check(num_vars=2, num_formulas=3)
+
+    for i in range(3):
+        _check(num_vars=3, num_formulas=3, coef=(-2, 2))
+    for i in range(3):
+        _check(num_vars=3, num_formulas=4, coef=(-2, 2))
+
+    for i in range(3):
+        _check(num_vars=4, num_formulas=3, coef=(-1, 1))
+
+    for i in range(3):
+        _check(num_vars=10, num_formulas=2, coef=(-1, 1), bounds=(0, 4))
+    for i in range(3):
+        _check(num_vars=10, num_formulas=3, coef=(0, 1), bounds=(0, 4))
+
+
+def test_empty_var_to_solve():
+    x, y = te.var("x"), te.var("y")
+    equations = [
+        tvm.tir.EQ(x + y, 20),
+        tvm.tir.EQ(x - y, 10),
+    ]
+    solution = arith.solve_linear_equations(equations)
+    assert len(solution.src_to_dst) == 0
+    assert len(solution.dst_to_src) == 0
+    assert len(solution.src.variables) == 0
+    assert len(solution.src.ranges) == 0
+    assert ir.structural_equal(solution.src.relations, equations)
+    assert ir.structural_equal(solution.src, solution.dst)
+
+
+def test_unique_solution():
+    x, y = te.var("x"), te.var("y")
+
+    solution = arith.solve_linear_equations([
+        tvm.tir.EQ(x + y, 20),
+        tvm.tir.EQ(x - y, 10),
+    ], [x, y])
+    assert list(solution.dst.variables) == []
+    assert ir.structural_equal(solution.src_to_dst[x], 15)
+    assert ir.structural_equal(solution.src_to_dst[y], 5)
+
+
+def test_low_rank():
+    x, y, z = te.var("x"), te.var("y"), te.var("z")
+    ranges = {}
+
+    solution = arith.solve_linear_equations([
+        tvm.tir.EQ(x + y + z, 15),
+        tvm.tir.EQ(x + y, 10),
+    ], [x, y, z], ranges)
+    [n0] = solution.dst.variables
+    assert ir.structural_equal(solution.src_to_dst[x], n0 + 10)
+    assert ir.structural_equal(solution.src_to_dst[y], -n0)
+    assert ir.structural_equal(solution.src_to_dst[z], 5)
+
+
+def test_infer_range():
+    x, y = te.var("x"), te.var("y")
+    ranges = {
+        x: tvm.ir.Range.make_by_min_extent(-5, 10),
+        y: tvm.ir.Range.make_by_min_extent(0, 10),
+    }
+
+    solution = arith.solve_linear_equations([
+        tvm.tir.EQ(x + y, 0),
+    ], [x, y], ranges)
+    [n0] = solution.dst.variables
+    assert ir.structural_equal(solution.src_to_dst[x], n0)
+    assert ir.structural_equal(solution.src_to_dst[y], -n0)
+    # inferred from y's range
+    assert ir.structural_equal(solution.dst.ranges[n0].min, -9)
+    assert ir.structural_equal(solution.dst.ranges[n0].extent, 10)
+    # additional inequality is added into the system for x
+    [ineq] = solution.dst.relations
+    assert isinstance(ineq, tvm.tir.LE)
+    assert ir.structural_equal(ineq.a, -5)
+    assert ir.structural_equal(ineq.b, n0)
+
+
+def test_ill_formed():
+    x, y = te.var("x"), te.var("y")
+
+    solution = arith.solve_linear_equations([
+        tvm.tir.EQ(x + y, 0),
+        tvm.tir.EQ(x - y, 0),
+        tvm.tir.EQ(x, 5),
+    ], [x, y], {})
+    assert list(solution.dst.variables) == []
+    [rel] = solution.dst.relations
+    assert ir.structural_equal(rel, False)
+    assert len(solution.src_to_dst) == 0
+    assert len(solution.dst_to_src) == 0
+
+
+if __name__ == "__main__":
+    pytest.main([__file__])