*/
void Bind(const Var& var, const Range& range);
/*!
+ * \brief Bind all the vars in the Map
+ *
+ * \param variables The {variable -> range} map.
+ */
+ void Bind(const Map<Var, Range>& variables);
+ /*!
* \brief Whether can we prove expr >= val.
* Non-negative proof is very useful in integer analysis
--- /dev/null
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tvm/arith/int_solver.h
+ * \brief integer constraints data structures and solvers
+ */
+#ifndef TVM_ARITH_INT_SOLVER_H_
+#define TVM_ARITH_INT_SOLVER_H_
+
+#include <tvm/ir/expr.h>
+#include <tvm/tir/expr.h>
+#include <unordered_map>
+#include <vector>
+
+namespace tvm {
+namespace arith {
+
+using tir::Var;
+using tir::VarNode;
+using tir::IterVar;
+
+/*!
+ * \brief Represent integer constrains including (integer) variables, their ranges and
+ * the relations between them (either equations or inequalities).
+ * \sa LinearSystem
+ */
+class IntConstraintsNode : public Object {
+ public:
+ // e.g., \alpha, \beta, must be integers
+ Array<Var> variables;
+ // e.g., 1 <= \alpha <= N, etc.
+ // it is absolutely ok to include ranges for parameters
+ // (variables that are not in this->variables) in this map
+ Map<Var, Range> ranges;
+ // linear equalities or inequalities
+ // e.g., A \alpha = \beta or A \alpha <= \beta
+ Array<PrimExpr> relations;
+
+ void VisitAttrs(tvm::AttrVisitor* v) {
+ v->Visit("variables", &variables);
+ v->Visit("ranges", &ranges);
+ v->Visit("relations", &relations);
+ }
+
+ bool SEqualReduce(const IntConstraintsNode* other, SEqualReducer equal) const {
+ return
+ equal(variables, other->variables) &&
+ equal(ranges, other->ranges) &&
+ equal(relations, other->relations);
+ }
+
+ void SHashReduce(SHashReducer hash_reduce) const {
+ hash_reduce(variables);
+ hash_reduce(ranges);
+ hash_reduce(relations);
+ }
+
+ static constexpr const bool _type_has_method_sequal_reduce = true;
+ static constexpr const char* _type_key = "arith.IntConstraints";
+ TVM_DECLARE_FINAL_OBJECT_INFO(IntConstraintsNode, Object);
+};
+
+/*!
+ * \brief Managed reference to IntConstraintsNode.
+ * \sa IntConstraintsNode
+ */
+class IntConstraints : public ObjectRef {
+ public:
+ /*!
+ * \brief Constructor by fields
+ * \param variables The variables in the constraints, must be integers.
+ * \param ranges The ranges of the variables.
+ * \param relations The linear relations between the variables
+ * (either equations or inequalities)
+ */
+ TVM_DLL IntConstraints(Array<Var> variables,
+ Map<Var, Range> ranges,
+ Array<PrimExpr> relations);
+
+ TVM_DEFINE_OBJECT_REF_METHODS(IntConstraints, ObjectRef, IntConstraintsNode);
+};
+
+/*!
+ * \brief We can have different set of variables to represent the same constraints.
+ * For example, the following two systems are equivalent,
+ * {a + b = 0 | a >= 0, b >= 0} and
+ * {m - n = 0 | m >= 0, n <= 0}
+ * This data structure represents the transformation
+ * between two equivalent linear systems.
+ * In the above example,
+ * src : {a + b = 0 | a >= 0, b >= 0}
+ * dst : {m - n = 0 | m >= 0, n <= 0}
+ * src_to_dst : {a -> m, b -> -n}
+ * dst_to_src : {m -> a, n -> -b}
+ * \sa IntConstraintsTransform
+ */
+class IntConstraintsTransformNode : public Object {
+ public:
+ IntConstraints src;
+ IntConstraints dst;
+ Map<Var, PrimExpr> src_to_dst;
+ Map<Var, PrimExpr> dst_to_src;
+
+ void VisitAttrs(tvm::AttrVisitor* v) {
+ v->Visit("src", &src);
+ v->Visit("dst", &dst);
+ v->Visit("src_to_dst", &src_to_dst);
+ v->Visit("dst_to_src", &dst_to_src);
+ }
+
+ bool SEqualReduce(const IntConstraintsTransformNode* other, SEqualReducer equal) const {
+ return
+ equal(src, other->src) &&
+ equal(dst, other->dst) &&
+ equal(src_to_dst, other->src_to_dst) &&
+ equal(dst_to_src, other->dst_to_src);
+ }
+
+ void SHashReduce(SHashReducer hash_reduce) const {
+ hash_reduce(src);
+ hash_reduce(dst);
+ hash_reduce(src_to_dst);
+ hash_reduce(dst_to_src);
+ }
+
+ static constexpr const bool _type_has_method_sequal_reduce = true;
+ static constexpr const char* _type_key = "arith.IntConstraintsTransform";
+ TVM_DECLARE_FINAL_OBJECT_INFO(IntConstraintsTransformNode, Object);
+};
+
+/*!
+ * \brief Managed reference to IntConstraintsTransformNode.
+ * \sa IntConstraintsTransformNode
+ */
+class IntConstraintsTransform : public ObjectRef {
+ public:
+ /*!
+ * \brief Constructor by fields
+ * \param src source integer constraints, e.g., {a + b = 0 | a >= 0, b >= 0}
+ * \param dst integer constraints equivalent to the source,
+ * e.g., {m - n = 0 | m >= 0, n <= 0}
+ * \param src_to_dst mapping from variables in the \p src to the variables in the \p dst,
+ * e.g., {a -> m, b -> -n}
+ * \param dst_to_src mapping from variables in the \p dst to the variables in the \p src,
+ * e.g., {m -> a, n -> -b}
+ */
+ TVM_DLL IntConstraintsTransform(IntConstraints src,
+ IntConstraints dst,
+ Map<Var, PrimExpr> src_to_dst,
+ Map<Var, PrimExpr> dst_to_src);
+
+ TVM_DEFINE_OBJECT_REF_METHODS(IntConstraintsTransform, ObjectRef, IntConstraintsTransformNode);
+};
+
+/*!
+ * \brief Obtain Smith Normal Form of linear equation A x = y.
+ * Smith Normal Form of matrix A_{mxn} is S_{mxn} = U_{mxm} A_{mxn} V_{nxn},
+ * in which S_{mxn} is diag(s1, s2, ..., sr, 0, ..., 0) and r is the rank of A.
+ * NOTE: Although in standard Smith Normal Form the diagonal elements satisfy
+ * s_i | s_{i+1} (| means divides), the implement here does not guarantee it.
+ * TODO(yzhliu): From sergei-grechanik:
+ * computing the proper Smith normal form may improve stability of automatic differentiation
+ * (generating the same gradient code for slightly different but equivalent input code
+ * U_{mxm} and V_{nxn} are invertible matrices.
+ * This function modifies \p S to be S_{mxn}, \p V to be V_{nxn},
+ * \p y to be U_{mxm} y_{mx1} and \p x to be V^{-1} x.
+ * \param S the original A_{mxn}, it will be modified to S_{mxn}
+ * \param V an identity matrix, it will be modified to V_{nxn}
+ * \param x the x in A x = y. it will be modified to V^{-1}_{nxn} x_{nx1}
+ * \param y the y in A x = y. it will be modified to U_{mxm} y_{mx1}
+ */
+void SmithNormalFormDiag(std::vector<std::vector<int64_t>> *S,
+ std::vector<std::vector<int64_t>> *V,
+ std::vector<PrimExpr>* x,
+ std::vector<PrimExpr> *y);
+
+/*!
+ * \brief Solve linear equations.
+ * \param system_to_solve the variables to solve, their ranges, and a list of equations.
+ * \return A new linear system, with less variables (if \p system_to_solve is NOT of full rank),
+ * or no variable (if \p system_to_solve is of full rank),
+ * or an empty linear system (if \p system_to_solve is unsolvable).
+ * It also provides the ranges of the variables in the new system,
+ * as well as inequalities inferred from the \p system_to_solve.
+ * You can get the mapping from the original variables to the solution via ret->src_to_dst.
+ */
+IntConstraintsTransform SolveLinearEquations(const IntConstraints &system_to_solve);
+
+} // namespace arith
+} // namespace tvm
+#endif // TVM_ARITH_INT_SOLVER_H_
--- /dev/null
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tvm/arith/util.h
+ * \brief Utils for arithmetic analysis.
+ */
+#ifndef TVM_ARITH_UTIL_H_
+#define TVM_ARITH_UTIL_H_
+
+#include <cstdint>
+#include <tuple>
+
+namespace tvm {
+/*! \brief namespace of arithmetic analysis. */
+namespace arith {
+
+/*!
+ * \brief Calculate the extended greatest common divisor for two values.
+ * See https://en.wikipedia.org/wiki/Extended_Euclidean_algorithm.
+ * \param a an integer number
+ * \param b an integer number
+ * \return 3 integers (div, m, n) where div = gcd(a, b) and a*m + b*n = div
+ */
+std::tuple<int64_t, int64_t, int64_t> xgcd(int64_t a, int64_t b);
+
+} // namespace arith
+} // namespace tvm
+#endif // TVM_ARITH_UTIL_H_
from .analyzer import ModularSet, ConstIntBound, Analyzer
from .bound import deduce_bound
from .pattern import detect_linear_equation, detect_clip_bound
+from .int_solver import solve_linear_equations
--- /dev/null
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""integer constraints data structures and solvers"""
+import tvm._ffi
+from tvm.runtime import Object
+from . import _ffi_api
+
+
+@tvm._ffi.register_object("arith.IntConstraints")
+class IntConstraints(Object):
+ """Represent a set of integer constraints including variables, their ranges and
+ the relations between them (either equations or inequalities)
+
+ Parameters
+ ----------
+ variables : List[tvm.tir.Var]
+ The variables in the constraints. Must be integers
+ ranges : Map[tvm.tir.Var, tvm.ir.Range]
+ The ranges of the variables.
+ relations : List[tvm.ir.PrimExpr]
+ The relations between the variables (either equations or inequalities)
+ """
+ def __init__(self, variables, ranges, relations):
+ self.__init_handle_by_constructor__(
+ _ffi_api.IntConstraints, variables, ranges, relations)
+
+
+@tvm._ffi.register_object("arith.IntConstraintsTransform")
+class IntConstraintsTransform(Object):
+ """We can have different set of variables to represent the same integer constraints.
+ For example, the following two constrains are equivalent,
+ {a + b = 0 | a >= 0, b >= 0} and
+ {m - n = 0 | m >= 0, n <= 0}
+ This data structure represents the transformation
+ between two equivalent integer constraints.
+ In the above example,
+ src : {a + b = 0 | a >= 0, b >= 0}
+ dst : {m - n = 0 | m >= 0, n <= 0}
+ src_to_dst : {a -> m, b -> -n}
+ dst_to_src : {m -> a, n -> -b}
+
+ Parameters
+ ----------
+ src : arith.IntConstraints
+ source integer constraints, e.g., {a + b = 0 | a >= 0, b >= 0}
+ dst : arith.IntConstraints
+ integer constraints equivalent to the source, e.g., {m - n = 0 | m >= 0, n <= 0}
+ src_to_dst : Map[tvm.tir.Var, tvm.ir.PrimExpr]
+ mapping from variables in the src to the variables in the dst,
+ e.g., {a -> m, b -> -n}
+ dst_to_src : Map[tvm.tir.Var, tvm.ir.PrimExpr]
+ mapping from variables in the dst to the variables in the src,
+ e.g., {m -> a, n -> -b}
+ """
+ def __init__(self, src, dst, src_to_dst, dst_to_src):
+ self.__init_handle_by_constructor__(
+ _ffi_api.IntConstraintsTransform, src, dst, src_to_dst, dst_to_src)
+
+
+def solve_linear_equations(equations, variables=None, ranges=None):
+ """Solve linear equations.
+
+ Parameters
+ ----------
+ equations: List[tvm.ir.PrimExpr] or IntConstraints
+ The equations of the variables
+ variables : Optional[List[tvm.tir.Var]]
+ The variables in the system.
+ ranges : Optional[Map[tvm.tir.Var, tvm.ir.Range]]
+ The ranges of the variables.
+
+ Returns
+ -------
+ int_constraints_transform : IntConstraintsTransform
+ New integer constraints, with less variables (if the problem is NOT of full rank),
+ or no variable (if the problem is of full rank),
+ or an empty integer constraints (if the problem is unsolvable).
+ It also provides the ranges of the variables in the new system,
+ as well as inequalities inferred from the problem.
+ You can get the mapping from the original variables to the solution via
+ int_constraints_transform.src_to_dst.
+ """
+ if isinstance(equations, IntConstraints):
+ return _ffi_api.SolveLinearEquations(equations)
+ return _ffi_api.SolveLinearEquations(variables, ranges, equations)
// skip rewrite simplify
}
+void Analyzer::Bind(const Map<Var, Range>& variables) {
+ for (const auto& iter : variables) {
+ this->Bind(iter.first, iter.second);
+ }
+}
void ConstraintContext::EnterWithScope() {
CHECK(exit_ == nullptr);
--- /dev/null
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file int_constraints.cc
+ * \brief The integer constraints data structures.
+ */
+#include <tvm/arith/int_solver.h>
+#include <tvm/tir/expr.h>
+#include <tvm/tir/expr_functor.h>
+#include <tvm/runtime/registry.h>
+
+#include <utility>
+#include <algorithm>
+#include <unordered_map>
+
+namespace tvm {
+namespace arith {
+
+IntConstraints::IntConstraints(Array<Var> variables,
+ Map<Var, Range> ranges,
+ Array<PrimExpr> relations) {
+ ObjectPtr<IntConstraintsNode> node = make_object<IntConstraintsNode>();
+ if (!variables.defined()) {
+ variables = Array<Var>();
+ }
+ if (!ranges.defined()) {
+ ranges = Map<Var, Range>();
+ }
+ CHECK(relations.defined());
+ for (const auto& var : variables) {
+ CHECK(var.dtype().is_int() || var.dtype().is_uint())
+ << "Variables in IntConstraints must be integers";
+ }
+ node->variables = std::move(variables);
+ node->ranges = std::move(ranges);
+ node->relations = std::move(relations);
+ data_ = std::move(node);
+}
+
+TVM_REGISTER_NODE_TYPE(IntConstraintsNode);
+
+TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
+.set_dispatch<IntConstraintsNode>([](const ObjectRef& node, ReprPrinter* p) {
+ auto* op = static_cast<const IntConstraintsNode*>(node.get());
+ p->stream << "IntConstraints("
+ << op->variables
+ << ", " << op->ranges
+ << ", " << op->relations
+ << ")";
+ });
+
+
+IntConstraintsTransform::IntConstraintsTransform(IntConstraints src,
+ IntConstraints dst,
+ Map<Var, PrimExpr> src_to_dst,
+ Map<Var, PrimExpr> dst_to_src) {
+ ObjectPtr<IntConstraintsTransformNode> node = make_object<IntConstraintsTransformNode>();
+ node->src = std::move(src);
+ node->dst = std::move(dst);
+ node->src_to_dst = std::move(src_to_dst);
+ node->dst_to_src = std::move(dst_to_src);
+ data_ = std::move(node);
+}
+
+TVM_REGISTER_NODE_TYPE(IntConstraintsTransformNode);
+
+TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
+.set_dispatch<IntConstraintsTransformNode>([](const ObjectRef& node, ReprPrinter* p) {
+ auto* op = static_cast<const IntConstraintsTransformNode*>(node.get());
+ p->stream << "IntConstraintsTransform("
+ << "\n\t" << op->src
+ << "\n\t" << op->dst
+ << "\n\t" << op->src_to_dst
+ << "\n\t" << op->dst_to_src
+ << "\n)";
+ });
+
+} // namespace arith
+} // namespace tvm
--- /dev/null
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tvm/arith/solve_linear_equation.cc
+ * \brief Solve linear equations.
+ */
+#include <tvm/runtime/registry.h>
+#include <tvm/tir/expr.h>
+#include <tvm/arith/analyzer.h>
+#include <tvm/arith/int_solver.h>
+#include <tvm/arith/util.h>
+#include <tvm/tir/op.h>
+#include <tvm/arith/pattern.h>
+#include <tvm/tir/ir_pass.h>
+#include <tvm/runtime/data_type.h>
+
+namespace tvm {
+namespace arith {
+
+using namespace tvm::runtime;
+
+void SmithNormalFormDiag(std::vector<std::vector<int64_t> >* S,
+ std::vector<std::vector<int64_t> >* V,
+ std::vector<PrimExpr>* x,
+ std::vector<PrimExpr>* y) {
+ if (S->empty() || V->empty()) return;
+ size_t m = S->size();
+ size_t n = (*S)[0].size(); // n is # of variables
+ CHECK_EQ(V->size(), n);
+ CHECK_EQ((*V)[0].size(), n);
+
+ for (size_t index = 0; index < std::min(m, n); ++index) {
+ // Here A is partially diagonalized, that is A[i, j] is zero for all i, j
+ // such that (i < index) or (j < index), unless (i == j).
+ // That is, now we are diagonalizing the submatrix with i >= index and j >= index
+
+ // Find a row with a nonzero element in the index-th column
+ // (We also prefer rows where this element has minimal abs value)
+ size_t best_i = index;
+ for (size_t i = best_i; i < m; ++i) {
+ int64_t s_old = (*S)[best_i][index];
+ int64_t s_new = (*S)[i][index];
+ if (s_new != 0) {
+ if (s_old == 0 || std::abs(s_new) < std::abs(s_old)) {
+ best_i = i;
+ }
+ }
+ }
+ // Move the row we found to the index-th position
+ std::swap((*S)[index], (*S)[best_i]);
+ std::swap((*y)[index], (*y)[best_i]);
+
+ // If the index-th diagonal element is still zero, try to find a column with nonzero index-th
+ // element and move it to the index-th position
+ if ((*S)[index][index] == 0) {
+ for (size_t j = index + 1; j < n; ++j) {
+ if ((*S)[index][j] != 0) {
+ for (size_t i = index; i < m; ++i) {
+ std::swap((*S)[i][index], (*S)[i][j]);
+ }
+ // swapping columns corresponds to swapping the corresponding x
+ std::swap((*x)[index], (*x)[j]);
+ for (size_t i = 0; i < n; ++i) {
+ std::swap((*V)[i][index], (*V)[i][j]);
+ }
+ break;
+ }
+ }
+ }
+
+ // If the index-th diagonal element is still zero, then both the index-th row and the index-th
+ // column are completely zero, and we don't need to do anything; just go to the next index
+ if ((*S)[index][index] == 0) {
+ continue;
+ }
+
+ // Now the index-th diagonal element is non-zero and we can zero all the index-th column
+ // below it by subtracting rows from each other
+ for (auto i = index + 1; i < m; ++i) {
+ if ((*S)[i][index] != 0) {
+ int64_t g, a, b;
+ // g = a*matrix[index][index] + b*matrix[i][index]
+ if ((*S)[i][index] % (*S)[index][index] != 0) {
+ std::tie(g, a, b) = xgcd((*S)[index][index], (*S)[i][index]);
+ } else {
+ // Explicitly avoid changing the index-th row. This is important to avoid infinite loop.
+ g = (*S)[index][index];
+ a = 1;
+ b = 0;
+ }
+
+ // Let m = S[index][index], n = S[i][index], then the following is true:
+ //
+ // [ a n/g ][ m/g n/g ] = [ 1 0 ]
+ // [ b -m/g ][ b -a ] = [ 0 1 ]
+ //
+ // Note that the two matrices are integer (since g = gcd(m, n)).
+ // We will essentially multiply our matrix on the left by a dilated and transposed version
+ // of the first of these two matrices. The second matrix is not needed here, however we will
+ // use it while zeroing the index-th row.
+
+ int64_t m_g = (*S)[index][index] / g;
+ int64_t n_g = (*S)[i][index] / g;
+
+ // Note that j is the index of the column, not the row
+ for (size_t j = index; j < (*S)[i].size(); ++j) {
+ // Multiply index-th row by a and add the i-th row multiplied by b
+ // This will make the index-th diagonal element equal to the gcd
+ int64_t new_index_j = a*(*S)[index][j] + b*(*S)[i][j];
+ // This transformation performs zeroing of matrix[i][index]
+ int64_t new_i_j = n_g*(*S)[index][j] - m_g*(*S)[i][j];
+ (*S)[index][j] = new_index_j;
+ (*S)[i][j] = new_i_j;
+ }
+ // We have to do the same with rhs
+ PrimExpr ea = te::make_const((*y)[index].dtype(), a);
+ PrimExpr eb = te::make_const((*y)[i].dtype(), b);
+ PrimExpr e_m_g = te::make_const((*y)[i].dtype(), m_g);
+ PrimExpr e_n_g = te::make_const((*y)[index].dtype(), n_g);
+ PrimExpr new_index_rhs = ea*(*y)[index] + eb*(*y)[i];
+ PrimExpr new_i_rhs = e_n_g*(*y)[index] - e_m_g*(*y)[i];
+ (*y)[index] = new_index_rhs;
+ (*y)[i] = new_i_rhs;
+ }
+ }
+
+ bool changed = false;
+
+ // Now we have to zero the elements of the index-th row by manipulating columns.
+ // This is more difficult because column manipulation corresponds to variable manipulation,
+ // but the algorithm is essentially the same as before.
+ for (size_t j = index + 1; j < n; ++j) {
+ if ((*S)[index][j] != 0) {
+ int64_t g, a, b;
+ // g = a*matrix[index][index] + b*matrix[index][j]
+ if ((*S)[index][j] % (*S)[index][index] != 0) {
+ std::tie(g, a, b) = xgcd((*S)[index][index], (*S)[index][j]);
+ // During this phase we may disrupt the zeroness of the index-th column, so we will
+ // have to take some action if this might have happened.
+ changed = true;
+ } else {
+ // Explicitly avoid changing the index-th column. This is important to avoid infinite
+ // loop. Note that here we don't have to set `changed` to true since we don't change the
+ // index-th column.
+ g = (*S)[index][index];
+ a = 1;
+ b = 0;
+ }
+
+ // Let m = S[index][index], n = S[index][j], then the following is true:
+ //
+ // [ a n/g ][ m/g n/g ] = [ 1 0 ]
+ // [ b -m/g ][ b -a ] = [ 0 1 ]
+ //
+ // Now we are going to multiply our matrix on the right (to manipulate columns instead of
+ // rows), we will also transform the old_to_new matrix the same way, and we will use the
+ // second matrix to transform new_to_old.
+
+ int64_t m_g = (*S)[index][index] / g;
+ int64_t n_g = (*S)[index][j] / g;
+
+ for (size_t i = index; i < m; ++i) {
+ int64_t new_i_index = a*(*S)[i][index] + b*(*S)[i][j];
+ int64_t new_i_j = n_g*(*S)[i][index] - m_g*(*S)[i][j];
+ (*S)[i][index] = new_i_index;
+ (*S)[i][j] = new_i_j;
+ }
+ // We do exactly the same transformations with V
+ for (size_t i = 0; i < n; ++i) {
+ int64_t new_i_index = a*(*V)[i][index] + b*(*V)[i][j];
+ int64_t new_i_j = n_g*(*V)[i][index] - m_g*(*V)[i][j];
+ (*V)[i][index] = new_i_index;
+ (*V)[i][j] = new_i_j;
+ }
+ // And apply reverse transformations to new_to_old.
+ PrimExpr ea = te::make_const((*x)[j].dtype(), a);
+ PrimExpr eb = te::make_const((*x)[index].dtype(), b);
+ PrimExpr e_m_g = te::make_const((*x)[index].dtype(), m_g);
+ PrimExpr e_n_g = te::make_const((*x)[j].dtype(), n_g);
+ PrimExpr new_index = e_m_g*(*x)[index] + e_n_g*(*x)[j];
+ PrimExpr new_j = eb*(*x)[index] - ea*(*x)[j];
+ (*x)[index] = new_index;
+ (*x)[j] = new_j;
+ }
+ }
+
+ if (changed) {
+ // We might have changed the first column, so we have to zero it once more
+ // (or at least check if it's zero), so just perform this iteration once more.
+ index -= 1;
+ }
+ }
+}
+
+Map<Var, Range> InferRange(const Map<Var, PrimExpr>& vars_to_infer,
+ const Array<Var>& ori_vars,
+ const Map<Var, Range>& ori_ranges) {
+ // The resulting ranges
+ Map<Var, Range> new_ranges;
+
+ std::unordered_set<const VarNode*> ori_vset;
+ for (const Var& v : ori_vars) {
+ ori_vset.insert(v.get());
+ }
+
+ std::unordered_map<const VarNode*, IntSet> var_intsets;
+ for (const auto& p : ori_ranges) {
+ if (!ori_vset.count(p.first.get())) {
+ // First of all, fill the new ranges with outer variable ranges
+ new_ranges.Set(p.first, p.second);
+ }
+ // Convert original ranges to IntSets
+ var_intsets[p.first.get()] = IntSet::range(p.second);
+ }
+
+ // Infer ranges for the new variables and add them to the resulting ranges
+ for (const auto& p : vars_to_infer) {
+ const auto& var = p.first;
+ const auto& expr = p.second;
+ Range range = EvalSet(expr, var_intsets).cover_range(Range());
+ if (range.defined()) {
+ new_ranges.Set(var, range);
+ }
+ }
+ return new_ranges;
+}
+
+// pretty print matrix equation
+void DebugPrint(const std::vector<std::vector<int64_t>>& S,
+ const std::vector<std::vector<int64_t>>& V,
+ const std::vector<PrimExpr>& V_inv_x,
+ const std::vector<PrimExpr>& rhs) {
+ std::cout << "S:\n";
+ for (size_t i = 0; i < S.size(); ++i) {
+ for (auto e : S[i]) {
+ std::cout << e << "\t";
+ }
+ std::cout << "\t->\t" << rhs[i];
+ std::cout << "\n";
+ }
+ std::cout << "V:\n";
+ for (const auto& r : V) {
+ for (auto e : r) {
+ std::cout << e << "\t";
+ }
+ std::cout << "\n";
+ }
+ std::cout << "V_inv x:\n" << Array<PrimExpr>(V_inv_x);
+ std::cout << "\n" << std::endl;
+}
+
+IntConstraintsTransform SolveLinearEquations(const IntConstraints &system_to_solve) {
+ // m: # of equations
+ // n: # of variables
+ // we first construct A_{mxn} x_{nx1} = y_{mx1}
+ // then get Smith normal form of matrix A,
+ // S_{mxn} = U_{mxm} A_{mxn} V_{nxn}
+ // => U^{-1} S V^{-1} x = y
+ // S V^{-1} x = U y
+ std::vector<PrimExpr> Uy; // mx1
+ std::vector<std::vector<int64_t>> S; // mxn
+ std::vector<std::vector<int64_t>> V; // nxn
+ std::vector<PrimExpr> V_inv_x; // V^{-1} x, nx1
+ // Conditions we don't know what to do with
+ std::vector<PrimExpr> rest;
+
+ Analyzer analyzer_problem;
+ analyzer_problem.Bind(system_to_solve->ranges);
+
+ size_t num_vars = system_to_solve->variables.size();
+
+ // initialize V_{nxn} with identity matrix,
+ // initialize V^{-1} x as x
+ for (size_t i = 0; i < num_vars; ++i) {
+ V.emplace_back(num_vars);
+ V.back()[i] = 1;
+ V_inv_x.push_back(system_to_solve->variables[i]);
+ }
+
+ // Transform formulas into rows of the matrix
+ // S_{mxn} V^{-1}_{nxn} x_{nx1} = U y, in which n is # of variables
+ // here we initialize S_{mxn} to be A, U to be identity matrix.
+ for (const PrimExpr& equation : system_to_solve->relations) {
+ if (const tir::EQNode* eq = equation.as<tir::EQNode>()) {
+ // a-b = sum_{i=0}^{n-1} variables[i] * coeff[i] + coeff[n]
+ Array<PrimExpr> coeffs = arith::DetectLinearEquation(
+ analyzer_problem.Simplify(eq->a - eq->b),
+ system_to_solve->variables);
+ if (!coeffs.empty()) {
+ std::vector<int64_t> row;
+ for (size_t j = 0; j < coeffs.size() - 1; ++j) {
+ PrimExpr c = coeffs[j];
+ if (const IntImmNode* ic = c.as<IntImmNode>()) {
+ row.push_back(ic->value);
+ } else {
+ // elements in matrix S V must be integers
+ // ignore equations that we cannot deal with.
+ LOG(WARNING) << "Cannot deal with non-integer coefficients, ignore equation "
+ << equation;
+ row.clear();
+ break;
+ }
+ }
+
+ if (!row.empty()) {
+ // S V^{-1} (a-b) = Uy
+ // V is identity for now
+ S.push_back(row);
+ Uy.push_back(-coeffs[coeffs.size() - 1]);
+ continue;
+ }
+ }
+ }
+
+ // otherwise
+ rest.push_back(equation);
+ }
+
+ // After diagonalizing, we have
+ // S_{mxn} is the Smith normal form (diagonal matrix)
+ // V_{nxn} is invertible
+ // V_inv_x is V^{-1} \times x
+ // Uy is U \times y
+ SmithNormalFormDiag(&S, &V, &V_inv_x, &Uy);
+
+ Array<Var> new_vars;
+ Array<PrimExpr> new_relations;
+ Map<Var, PrimExpr> new_to_old_map;
+ Map<Var, PrimExpr> old_to_new_map;
+
+ // Simplify right hand sides
+ for (PrimExpr r : Uy) {
+ r = analyzer_problem.Simplify(r);
+ }
+
+ // Create the relations of the existence of a solution
+ for (size_t j = 0; j < S.size(); ++j) {
+ PrimExpr new_relation;
+ if (j >= num_vars || S[j][j] == 0) {
+ // The row of matrix is zero. A solution exists only if the Ub[j] is also zero
+ new_relation = (Uy[j] == 0);
+ } else {
+ // The diagonal element is non-zero. A solution exists only if the diagonal element
+ // is a divisor of the Ub[j]
+ new_relation = (floormod(Uy[j], std::abs(S[j][j])) == 0);
+ }
+ new_relation = analyzer_problem.Simplify(new_relation);
+ if (tir::is_const_int(new_relation, 0)) {
+ // unable to solve the system.
+ return IntConstraintsTransform(
+ system_to_solve,
+ IntConstraints(
+ /*variables=*/{},
+ /*ranges=*/{},
+ /*relations=*/{te::make_zero(DataType::Bool())}),
+ {}, {});
+ } else if (!tir::is_const_int(new_relation, 1)) {
+ new_relations.push_back(new_relation);
+ }
+ }
+
+ Array<PrimExpr> solution_for_V_inv_x;
+ // Now create new variables or directly solve the equations
+ // suppose the rank of A is r, aka r = # of non-zeros in S
+ // the solution of S_{mxn} V^{-1}_{nxn} x_{nx1} = U b
+ // is
+ // x = (pseudo-inverse of A) b + K_{(n)x(n-r)} z_{n-r}
+ // = V_{nxn} S^{-1}_{nxm} (Ub)_{mxn} + K_{(n)x(n-r)} z_{n-r}
+ // in which K is the right n-r columns of V, z is variable vector
+ // thus,
+ // V^{-1} x = S^{-1}_{nxm} (Ub)_{mxn} +
+ // [[0, ... 0]_{n-r}, ... [0, ..., 0], diag(1, ..., 1)_{(n-r)x(n-r)}] z_{n-r}
+ for (size_t j = 0; j < num_vars; ++j) {
+ if (j >= S.size() || S[j][j] == 0) {
+ // The j-th variable can take any integer value, create a tvm variable for it
+ PrimExpr to_old = analyzer_problem.Simplify(V_inv_x[j]);
+ std::string name_hint = "n" + std::to_string(new_vars.size());
+ if (const VarNode* v_old = to_old.as<VarNode>()) {
+ name_hint += "_" + v_old->name_hint;
+ }
+ Var v = Var(name_hint, V_inv_x[j].dtype());
+ solution_for_V_inv_x.push_back(v);
+ new_vars.push_back(v);
+ new_to_old_map.Set(v, to_old);
+ } else {
+ // The j-th variable is just a single value, don't create a tvm variable
+ // S^{-1}_{nxm} Uy_{mxn}
+ if (S[j][j] >= 0) {
+ PrimExpr a = te::make_const(Uy[j].dtype(), S[j][j]);
+ solution_for_V_inv_x.push_back(
+ analyzer_problem.Simplify(floordiv(Uy[j], a)));
+ } else {
+ // This is required because some simplifiers
+ // have problems with dividing by negative numbers
+ PrimExpr a = te::make_const(Uy[j].dtype(), -S[j][j]);
+ solution_for_V_inv_x.push_back(
+ analyzer_problem.Simplify(floordiv(-Uy[j], a)));
+ }
+ }
+ }
+
+ // V V^{-1} x = x
+ for (size_t i = 0; i < num_vars; ++i) {
+ PrimExpr e = te::make_zero(system_to_solve->variables[i].dtype());
+ for (size_t j = 0; j < num_vars; ++j) {
+ e = e + te::make_const(e.dtype(), V[i][j])*solution_for_V_inv_x[j];
+ }
+ e = analyzer_problem.Simplify(e);
+ old_to_new_map.Set(system_to_solve->variables[i], e);
+ }
+
+ // The resulting ranges
+ Map<Var, Range> new_ranges = InferRange(
+ new_to_old_map, system_to_solve->variables, system_to_solve->ranges);
+ Analyzer analyzer_solution;
+ analyzer_solution.Bind(new_ranges);
+
+ // We have to transform ranges of the old variables into relations over new variables because
+ // new ranges are not enough usually.
+ for (const auto& p : system_to_solve->ranges) {
+ const Var& old_var = p.first;
+ const Range& old_range = p.second;
+ if (old_to_new_map.count(old_var)) {
+ PrimExpr express_by_new_vars = old_to_new_map[old_var];
+ PrimExpr lower_cond = analyzer_solution.Simplify(
+ old_range->min <= express_by_new_vars);
+ PrimExpr upper_cond = analyzer_solution.Simplify(
+ express_by_new_vars < old_range->min + old_range->extent);
+ if (!tir::is_const_int(lower_cond, 1)) {
+ new_relations.push_back(lower_cond);
+ }
+ if (!tir::is_const_int(upper_cond, 1)) {
+ new_relations.push_back(upper_cond);
+ }
+ }
+ }
+
+ // Add the rest conditions
+ for (const PrimExpr& cond : rest) {
+ new_relations.push_back(Substitute(cond, old_to_new_map));
+ }
+
+ IntConstraints solution(new_vars, new_ranges, new_relations);
+ IntConstraintsTransform transform(
+ system_to_solve, solution, old_to_new_map, new_to_old_map);
+
+ return transform;
+}
+
+TVM_REGISTER_GLOBAL("arith.SolveLinearEquations")
+.set_body([](TVMArgs args, TVMRetValue *ret) {
+ if (args.size() == 1) {
+ *ret = SolveLinearEquations(args[0]);
+ } else if (args.size() == 3) {
+ IntConstraints problem(args[0], args[1], args[2]);
+ *ret = SolveLinearEquations(problem);
+ } else {
+ LOG(FATAL) << "arith.SolveLinearEquations expects 1 or 3 arguments, gets " << args.size();
+ }
+ });
+
+} // namespace arith
+} // namespace tvm
--- /dev/null
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file util.cc
+ * \brief The utils for arithmetic analysis.
+ */
+#include <tvm/arith/util.h>
+#include <dmlc/logging.h>
+
+namespace tvm {
+namespace arith {
+
+std::tuple<int64_t, int64_t, int64_t> xgcd(int64_t a, int64_t b) {
+ int64_t s = 0, old_s = 1;
+ int64_t t = 1, old_t = 0;
+ int64_t r = b, old_r = a;
+
+ while (r != 0) {
+ int64_t q = old_r / r;
+ std::swap(r, old_r);
+ r -= q * old_r;
+ std::swap(s, old_s);
+ s -= q * old_s;
+ std::swap(t, old_t);
+ t -= q * old_t;
+ }
+
+ CHECK_EQ(a % old_r, 0);
+ CHECK_EQ(b % old_r, 0);
+ CHECK(old_r == old_s*a + old_t*b);
+
+ return std::make_tuple(old_r, old_s, old_t);
+}
+
+} // namespace arith
+} // namespace tvm
--- /dev/null
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import random
+import numpy as np
+import sys
+import pytest
+import tvm
+from tvm import te, arith, ir, tir
+
+
+def run_expr(expr, vranges):
+ """ Evaluate expr for every value of free variables
+ given by vranges and return the tensor of results.
+ TODO(yzhliu): move to utils
+ """
+ def _compute_body(*us):
+ vmap = {v: u + r.min for (v, r), u in zip(vranges.items(), us)}
+ return tir.ir_pass.Substitute(expr, vmap)
+
+ A = te.compute([r.extent.value for v, r in vranges.items()], _compute_body)
+ args = [tvm.nd.empty(A.shape, A.dtype)]
+ sch = te.create_schedule(A.op)
+ mod = tvm.build(sch, [A])
+ mod(*args)
+ return args[0].asnumpy()
+
+
+def check_bruteforce(bool_expr, vranges, cond=None):
+ """ Check that bool_expr holds given the condition cond
+ for every value of free variables from vranges.
+ TODO(yzhliu): move to utils
+ """
+ if cond is not None:
+ bool_expr = te.any(tir.Not(cond), bool_expr)
+
+ res = run_expr(bool_expr, vranges)
+ if not np.all(res):
+ indices = list(np.argwhere(res == 0)[0])
+ counterex = [(str(v), i + r.min) for (v, r), i in zip(vranges.items(), indices)]
+ counterex = sorted(counterex, key=lambda x: x[0])
+ counterex = ", ".join([v + " = " + str(i) for v, i in counterex])
+ raise AssertionError("Expression {}\nis not true on {}\n"
+ "Counterexample: {}"
+ .format(tir.ir_pass.CanonicalSimplify(bool_expr), vranges, counterex))
+
+
+def check_solution(solution, vranges={}):
+ """Check that solution is a bijective transformation"""
+ def _check_forward(constraints1, constraints2, varmap, backvarmap):
+ all_vranges = vranges.copy()
+ all_vranges.update({v: r for v, r in constraints1.ranges.items()})
+
+ # Check that the transformation is injective
+ cond_on_vars = tir.const(1, 'bool')
+ for v in constraints1.variables:
+ # variable mapping is consistent
+ v_back = tir.ir_pass.Simplify(tir.ir_pass.Substitute(varmap[v], backvarmap))
+ cond_on_vars = te.all(cond_on_vars, v == v_back)
+ # Also we have to check that the new relations are true when old relations are true
+ cond_subst = tir.ir_pass.Substitute(
+ te.all(tir.const(1, 'bool'), *constraints2.relations), backvarmap)
+ # We have to include relations from vranges too
+ for v in constraints2.variables:
+ if v in constraints2.ranges:
+ r = constraints2.ranges[v]
+ range_cond = te.all(v >= r.min, v < r.min + r.extent)
+ range_cond = tir.ir_pass.Substitute(range_cond, backvarmap)
+ cond_subst = te.all(cond_subst, range_cond)
+ cond_subst = tir.ir_pass.Simplify(cond_subst)
+ check_bruteforce(te.all(cond_subst, cond_on_vars), all_vranges,
+ cond=te.all(tir.const(1, 'bool'), *constraints1.relations))
+
+ rels = solution.dst.relations
+ if len(rels) == 1 and ir.structural_equal(rels[0], False):
+ # not solvable, skip
+ return
+ _check_forward(solution.src, solution.dst,
+ solution.src_to_dst, solution.dst_to_src)
+ _check_forward(solution.dst, solution.src,
+ solution.dst_to_src, solution.src_to_dst)
+
+
+def test_solution_consistency():
+ seed = random.randrange(sys.maxsize)
+ print("\nThis test is intentionally non-deterministic, "
+ "if it fails please report it in github issue together with this seed {}\n".format(seed))
+ random.seed(seed)
+
+ def _check(num_vars, num_formulas, coef=(-5, 5), bounds=(-20, 20)):
+ variables = [te.var("x" + str(i)) for i in range(num_vars)]
+
+ relations = []
+ for i in range(num_formulas):
+ s1 = sum([v*random.randint(coef[0], coef[1]) for v in variables])
+ s1 += random.randint(coef[0], coef[1])
+ s2 = sum([v*random.randint(coef[0], coef[1]) for v in variables])
+ s2 += random.randint(coef[0], coef[1])
+ if random.random() < 0.7:
+ op = tvm.tir.EQ
+ else:
+ # we also make sure it can correctly handle inequalities
+ op = random.choice([tvm.tir.LE, tvm.tir.LT, tvm.tir.GE, tvm.tir.GT])
+ relations.append(op(s1, s2))
+
+ vranges = {v: tvm.ir.expr.Range(bounds[0], bounds[1] + 1) for v in variables}
+ solution = arith.solve_linear_equations(relations, variables, vranges)
+
+ check_solution(solution)
+
+ # leaving some variables as parameters should also be ok
+ for k in [1, 2]:
+ if len(variables) > k:
+ solution = arith.solve_linear_equations(relations, variables[:-k], vranges)
+ param_ranges = {v: vranges[v] for v in variables[-k:]}
+ check_solution(solution, param_ranges)
+
+ for i in range(2):
+ _check(num_vars=1, num_formulas=1)
+ for i in range(2):
+ _check(num_vars=1, num_formulas=2)
+
+ for i in range(2):
+ _check(num_vars=2, num_formulas=1)
+ for i in range(2):
+ _check(num_vars=2, num_formulas=2)
+ for i in range(2):
+ _check(num_vars=2, num_formulas=3)
+
+ for i in range(3):
+ _check(num_vars=3, num_formulas=3, coef=(-2, 2))
+ for i in range(3):
+ _check(num_vars=3, num_formulas=4, coef=(-2, 2))
+
+ for i in range(3):
+ _check(num_vars=4, num_formulas=3, coef=(-1, 1))
+
+ for i in range(3):
+ _check(num_vars=10, num_formulas=2, coef=(-1, 1), bounds=(0, 4))
+ for i in range(3):
+ _check(num_vars=10, num_formulas=3, coef=(0, 1), bounds=(0, 4))
+
+
+def test_empty_var_to_solve():
+ x, y = te.var("x"), te.var("y")
+ equations = [
+ tvm.tir.EQ(x + y, 20),
+ tvm.tir.EQ(x - y, 10),
+ ]
+ solution = arith.solve_linear_equations(equations)
+ assert len(solution.src_to_dst) == 0
+ assert len(solution.dst_to_src) == 0
+ assert len(solution.src.variables) == 0
+ assert len(solution.src.ranges) == 0
+ assert ir.structural_equal(solution.src.relations, equations)
+ assert ir.structural_equal(solution.src, solution.dst)
+
+
+def test_unique_solution():
+ x, y = te.var("x"), te.var("y")
+
+ solution = arith.solve_linear_equations([
+ tvm.tir.EQ(x + y, 20),
+ tvm.tir.EQ(x - y, 10),
+ ], [x, y])
+ assert list(solution.dst.variables) == []
+ assert ir.structural_equal(solution.src_to_dst[x], 15)
+ assert ir.structural_equal(solution.src_to_dst[y], 5)
+
+
+def test_low_rank():
+ x, y, z = te.var("x"), te.var("y"), te.var("z")
+ ranges = {}
+
+ solution = arith.solve_linear_equations([
+ tvm.tir.EQ(x + y + z, 15),
+ tvm.tir.EQ(x + y, 10),
+ ], [x, y, z], ranges)
+ [n0] = solution.dst.variables
+ assert ir.structural_equal(solution.src_to_dst[x], n0 + 10)
+ assert ir.structural_equal(solution.src_to_dst[y], -n0)
+ assert ir.structural_equal(solution.src_to_dst[z], 5)
+
+
+def test_infer_range():
+ x, y = te.var("x"), te.var("y")
+ ranges = {
+ x: tvm.ir.Range.make_by_min_extent(-5, 10),
+ y: tvm.ir.Range.make_by_min_extent(0, 10),
+ }
+
+ solution = arith.solve_linear_equations([
+ tvm.tir.EQ(x + y, 0),
+ ], [x, y], ranges)
+ [n0] = solution.dst.variables
+ assert ir.structural_equal(solution.src_to_dst[x], n0)
+ assert ir.structural_equal(solution.src_to_dst[y], -n0)
+ # inferred from y's range
+ assert ir.structural_equal(solution.dst.ranges[n0].min, -9)
+ assert ir.structural_equal(solution.dst.ranges[n0].extent, 10)
+ # additional inequality is added into the system for x
+ [ineq] = solution.dst.relations
+ assert isinstance(ineq, tvm.tir.LE)
+ assert ir.structural_equal(ineq.a, -5)
+ assert ir.structural_equal(ineq.b, n0)
+
+
+def test_ill_formed():
+ x, y = te.var("x"), te.var("y")
+
+ solution = arith.solve_linear_equations([
+ tvm.tir.EQ(x + y, 0),
+ tvm.tir.EQ(x - y, 0),
+ tvm.tir.EQ(x, 5),
+ ], [x, y], {})
+ assert list(solution.dst.variables) == []
+ [rel] = solution.dst.relations
+ assert ir.structural_equal(rel, False)
+ assert len(solution.src_to_dst) == 0
+ assert len(solution.dst_to_src) == 0
+
+
+if __name__ == "__main__":
+ pytest.main([__file__])