[Autodiff] Optimize and eliminate the Jacobian tensor for te.autodiff (#6078)
authorYizhi Liu <liuyizhi@apache.org>
Tue, 18 Aug 2020 23:19:39 +0000 (16:19 -0700)
committerGitHub <noreply@github.com>
Tue, 18 Aug 2020 23:19:39 +0000 (16:19 -0700)
* [Autodiff] Optimize and eliminate the Jacobian tensor for te.autodiff

Co-authored-by: Sergei Grechanik <sergei.grechanik+h@gmail.com>
* fix lint

* fix clang-format

* add comments and magic number

* clang-lint

* address some comments

* remove FreeVarsVisitor

* fix constexpr lint

* fix lint

* fix lint

* add Map.Merge

* lint

* change Array::Concat & Map::Merge to global functions

* fix lint

* move functions to global

* static -> inline

Co-authored-by: Sergei Grechanik <sergei.grechanik+h@gmail.com>
16 files changed:
include/tvm/arith/analyzer.h
include/tvm/arith/int_solver.h
include/tvm/node/container.h
include/tvm/runtime/container.h
include/tvm/tir/analysis.h
src/arith/canonical_simplify.cc
src/arith/int_constraints.cc
src/arith/solve_linear_inequality.cc
src/relay/op/tensor/transform.cc
src/te/autodiff/ad_simplify.cc [new file with mode: 0644]
src/te/autodiff/ad_util.cc
src/te/autodiff/ad_util.h
src/te/autodiff/adjoint.cc
src/te/autodiff/jacobian.cc
src/tir/transforms/split_host_device.cc
tests/python/unittest/test_te_autodiff.py

index cc64294..a9a0bed 100644 (file)
@@ -52,6 +52,13 @@ class Analyzer;
 
 using tir::Var;
 
+enum DivMode {
+  /*! \brief Truncated division. */
+  kTruncDiv,
+  /*! \brief Floor division. */
+  kFloorDiv
+};
+
 /*!
  * \brief Constant integer up and lower bound(inclusive).
  *  Useful for value bound analysis.
index ad044b2..0ef74ce 100644 (file)
@@ -41,6 +41,11 @@ using tir::IterVar;
 using tir::Var;
 using tir::VarNode;
 
+// According to experiments two best simplifications orders were can->rw and rw->can->rw,
+// but rw->can->rw is better for a couple of cases.
+// Also we should end with rw because it factors multipliers out.
+constexpr int kSimplifyRewriteCanonicalRewrite = 3;
+
 /*!
  * \brief Represent integer grouped bounds which are classified into
  *        lower bounds (inclusive), upper bounds (inclusive) and equalities.
@@ -251,6 +256,15 @@ class IntConstraintsTransform : public ObjectRef {
   TVM_DLL IntConstraintsTransform(IntConstraints src, IntConstraints dst,
                                   Map<Var, PrimExpr> src_to_dst, Map<Var, PrimExpr> dst_to_src);
 
+  /*!
+   * \brief Chain-compose two IntConstraintsTransform together.
+   *        this->dst must be the same as other->src.
+   * @param other another IntConstraintsTransform whose src is same as this->dst.
+   * @return composed IntConstraintsTransform(this->src, other->dst)
+   *         with its variables and ranges are properly modified.
+   */
+  IntConstraintsTransform operator+(const IntConstraintsTransform& other) const;
+
   TVM_DEFINE_OBJECT_REF_METHODS(IntConstraintsTransform, ObjectRef, IntConstraintsTransformNode);
 };
 
@@ -307,6 +321,16 @@ IntConstraintsTransform SolveLinearEquations(const IntConstraints& system_to_sol
 PartialSolvedInequalities SolveLinearInequalities(const IntConstraints& system_to_solve);
 
 /*!
+ * \brief Combine the information into an array of (in)equalities.
+ * \param variables The variables in \p bounds.
+ *        It is used to determine the iteration order to avoid indeterministic results.
+ * \param bounds grouped boundary of the variables.
+ * \param relations other relations.
+ */
+Array<PrimExpr> AsConditions(const Array<Var>& variables, const Map<Var, IntGroupBounds>& bounds,
+                             const Array<PrimExpr>& relations);
+
+/*!
  * \brief Solve linear inequalities and infer the range of each variable.
  * \param system_to_solve the variables to solve, their ranges, and a list of inequalities.
  * \return The result ranges for each variables.
index 365eb60..c2a198d 100644 (file)
@@ -1427,6 +1427,22 @@ class Map : public ObjectRef {
   MapNode* GetMapNode() const { return static_cast<MapNode*>(data_.get()); }
 };
 
+/*!
+ * \brief Merge two Maps.
+ * \param lhs the first Map to merge.
+ * \param rhs the second Map to merge.
+ * @return The merged Array. Original Maps are kept unchanged.
+ */
+template <typename K, typename V,
+          typename = typename std::enable_if<std::is_base_of<ObjectRef, K>::value>::type,
+          typename = typename std::enable_if<std::is_base_of<ObjectRef, V>::value>::type>
+inline Map<K, V> Merge(Map<K, V> lhs, const Map<K, V>& rhs) {
+  for (const auto& p : rhs) {
+    lhs.Set(p.first, p.second);
+  }
+  return std::move(lhs);
+}
+
 }  // namespace tvm
 
 namespace tvm {
index 423ea89..9972785 100644 (file)
@@ -996,6 +996,21 @@ class Array : public ObjectRef {
   }
 };
 
+/*!
+ * \brief Concat two Arrays.
+ * \param lhs first Array to be concatenated.
+ * \param rhs second Array to be concatenated.
+ * \return The concatenated Array. Original Arrays are kept unchanged.
+ */
+template <typename T,
+          typename = typename std::enable_if<std::is_base_of<ObjectRef, T>::value>::type>
+inline Array<T> Concat(Array<T> lhs, const Array<T>& rhs) {
+  for (const auto& x : rhs) {
+    lhs.push_back(x);
+  }
+  return std::move(lhs);
+}
+
 // Specialize make_object<ArrayNode> to make sure it is correct.
 template <>
 inline ObjectPtr<ArrayNode> make_object() {
index cbc7a51..e5b2c2b 100644 (file)
@@ -57,7 +57,7 @@ struct ExprDeepEqual {
 };
 
 /*!
- * \brief Find undefined vars in the statment.
+ * \brief Find undefined vars in the statement.
  * \param stmt The function to be checked.
  * \param defs The vars that is defined.
  * \return Array of undefined vars.
@@ -65,6 +65,13 @@ struct ExprDeepEqual {
 TVM_DLL Array<Var> UndefinedVars(const Stmt& stmt, const Array<Var>& defs);
 
 /*!
+ * \brief Find undefined vars in the expression.
+ * \param expr The expression to be checked.
+ * \return Array of undefined vars.
+ */
+TVM_DLL Array<Var> UndefinedVars(const PrimExpr& expr);
+
+/*!
  * \brief Analyze the side effect
  * \param expr The expression to be checked.
  *
index 726289c..a8ef6a1 100644 (file)
@@ -59,13 +59,6 @@ class CanonicalExprNode : public PrimExprNode {
   TVM_DECLARE_BASE_OBJECT_INFO(CanonicalExprNode, PrimExprNode);
 };
 
-enum DivMode {
-  /*! \brief Truncated division. */
-  kTruncDiv,
-  /*! \brief Floor division. */
-  kFloorDiv
-};
-
 inline PrimExpr ModImpl(PrimExpr a, PrimExpr b, DivMode mode) {
   if (mode == kTruncDiv) {
     return truncmod(a, b);
index c95f7f8..189869b 100644 (file)
 namespace tvm {
 namespace arith {
 
+Array<PrimExpr> AsConditions(const Array<Var>& variables, const Map<Var, IntGroupBounds>& bounds,
+                             const Array<PrimExpr>& relations) {
+  Array<PrimExpr> res;
+  // use variables to keep the order of iteration
+  // so as to get rid of any non-determinism.
+  CHECK_EQ(variables.size(), bounds.size());
+  for (const auto v : variables) {
+    CHECK(bounds.count(v));
+    const auto& bnds = bounds[v];
+    PrimExpr lhs = bnds->coef * v;
+    for (const PrimExpr& rhs : bnds->equal) {
+      res.push_back(tir::EQ(lhs, rhs));
+    }
+    for (const PrimExpr& rhs : bnds->lower) {
+      res.push_back(tir::GE(lhs, rhs));
+    }
+    for (const PrimExpr& rhs : bnds->upper) {
+      res.push_back(tir::LE(lhs, rhs));
+    }
+  }
+  for (const PrimExpr& e : relations) {
+    res.push_back(e);
+  }
+  return res;
+}
+
 IntGroupBounds::IntGroupBounds(PrimExpr coef, Array<PrimExpr> lower, Array<PrimExpr> equal,
                                Array<PrimExpr> upper) {
   CHECK(coef.dtype().is_int() || coef.dtype().is_uint())
@@ -231,6 +257,26 @@ IntConstraintsTransform::IntConstraintsTransform(IntConstraints src, IntConstrai
   data_ = std::move(node);
 }
 
+IntConstraintsTransform IntConstraintsTransform::operator+(
+    const IntConstraintsTransform& other) const {
+  CHECK(other->src.same_as(operator->()->dst));
+  Map<Var, PrimExpr> dst_to_src;
+  Map<Var, PrimExpr> src_to_dst;
+
+  Analyzer ana_first;
+  ana_first.Bind(operator->()->src->ranges);
+  for (auto p : other->dst_to_src) {
+    dst_to_src.Set(p.first, ana_first.Simplify(Substitute(p.second, operator->()->dst_to_src)));
+  }
+
+  Analyzer ana_second;
+  ana_second.Bind(other->dst->ranges);
+  for (auto p : operator->()->src_to_dst) {
+    src_to_dst.Set(p.first, ana_second.Simplify(Substitute(p.second, other->src_to_dst)));
+  }
+  return IntConstraintsTransform(operator->()->src, other->dst, src_to_dst, dst_to_src);
+}
+
 TVM_REGISTER_NODE_TYPE(IntConstraintsTransformNode);
 
 TVM_REGISTER_GLOBAL("arith.IntConstraintsTransform")
index f489d04..5744cfd 100644 (file)
@@ -94,35 +94,6 @@ struct ExprLess {
   }
 };
 
-/*!
- * \brief Combine the information into an array of (in)equalities.
- */
-Array<PrimExpr> as_conditions(const Array<Var>& variables, const Map<Var, IntGroupBounds>& bounds,
-                              const Array<PrimExpr>& relations) {
-  Array<PrimExpr> res;
-  // use variables to keep the order of iteration
-  // so as to get rid of any non-determinism.
-  CHECK_EQ(variables.size(), bounds.size());
-  for (const auto v : variables) {
-    CHECK(bounds.count(v));
-    const auto& bnds = bounds[v];
-    PrimExpr lhs = bnds->coef * v;
-    for (const PrimExpr& rhs : bnds->equal) {
-      res.push_back(tir::EQ(lhs, rhs));
-    }
-    for (const PrimExpr& rhs : bnds->lower) {
-      res.push_back(tir::GE(lhs, rhs));
-    }
-    for (const PrimExpr& rhs : bnds->upper) {
-      res.push_back(tir::LE(lhs, rhs));
-    }
-  }
-  for (const PrimExpr& e : relations) {
-    res.push_back(e);
-  }
-  return res;
-}
-
 void DebugPrint(
     const std::unordered_set<PrimExpr, StructuralHash, StructuralEqual>& current_ineq_set,
     const std::unordered_set<PrimExpr, StructuralHash, StructuralEqual>& next_ineq_set,
@@ -290,7 +261,8 @@ PartialSolvedInequalities SolveLinearInequalities(const IntConstraints& system_t
 
   // Simplify each inequality into the form `expr <= 0` and add to current formulas
   for (const PrimExpr& ineq : system_to_solve->relations) {
-    AddInequality(&current_ineq_set_to_solve, NormalizeComparisons()(analyzer.Simplify(ineq, 3)),
+    AddInequality(&current_ineq_set_to_solve,
+                  NormalizeComparisons()(analyzer.Simplify(ineq, kSimplifyRewriteCanonicalRewrite)),
                   &analyzer);
   }
 
@@ -307,8 +279,9 @@ PartialSolvedInequalities SolveLinearInequalities(const IntConstraints& system_t
     // Add bounds from vranges
     if (system_to_solve->ranges.count(v)) {
       const Range& range = system_to_solve->ranges[v];
-      PrimExpr range_lbound = analyzer.Simplify(range->min, 3);
-      PrimExpr range_ubound = analyzer.Simplify(range->min + range->extent - 1, 3);
+      PrimExpr range_lbound = analyzer.Simplify(range->min, kSimplifyRewriteCanonicalRewrite);
+      PrimExpr range_ubound =
+          analyzer.Simplify(range->min + range->extent - 1, kSimplifyRewriteCanonicalRewrite);
       coef_neg.push_back({-1, range_lbound});
       coef_pos.push_back({1, -range_ubound});
     }
@@ -329,7 +302,8 @@ PartialSolvedInequalities SolveLinearInequalities(const IntConstraints& system_t
         // we need rewrite_simplify -> canonical_simplify -> rewrite_simplify
         // to help simplify things like (((y + 10) - (-1*(y - 20))) <= 0) => y - 5 <= 0
         // with steps = 2 it's (y*2) - 10 <= 0
-        new_ineq = NormalizeComparisons()(analyzer.Simplify(new_ineq, 3));
+        new_ineq =
+            NormalizeComparisons()(analyzer.Simplify(new_ineq, kSimplifyRewriteCanonicalRewrite));
         AddInequality(&next_ineq_set_to_solve, new_ineq, &analyzer);
       }
     }
@@ -354,7 +328,7 @@ PartialSolvedInequalities SolveLinearInequalities(const IntConstraints& system_t
 
     for (const auto& pos : coef_pos) {
       PrimExpr bound = make_const(v.dtype(), -coef_lcm / pos.first) * pos.second;
-      bound = analyzer.Simplify(bound, 3);
+      bound = analyzer.Simplify(bound, kSimplifyRewriteCanonicalRewrite);
       // Don't add if any of the existing bounds is better
       if (std::any_of(upper_bounds.begin(), upper_bounds.end(),
                       [&bound, &analyzer](const PrimExpr& o) {
@@ -375,7 +349,7 @@ PartialSolvedInequalities SolveLinearInequalities(const IntConstraints& system_t
     }
     for (const auto& neg : coef_neg) {
       PrimExpr bound = make_const(v.dtype(), -coef_lcm / neg.first) * neg.second;
-      bound = analyzer.Simplify(bound, 3);
+      bound = analyzer.Simplify(bound, kSimplifyRewriteCanonicalRewrite);
       // Don't add if any of the existing bounds is better
       if (std::any_of(lower_bounds.begin(), lower_bounds.end(),
                       [&bound, &analyzer](const PrimExpr& o) {
@@ -414,7 +388,7 @@ PartialSolvedInequalities SolveLinearInequalities(const IntConstraints& system_t
   // Everything that is left goes to res.relations
   Array<PrimExpr> other_conditions;
   for (const PrimExpr& e : current_ineq_set_to_solve) {
-    PrimExpr e_simp = analyzer.Simplify(e, 3);
+    PrimExpr e_simp = analyzer.Simplify(e, kSimplifyRewriteCanonicalRewrite);
     if (is_const_int(e_simp, 0)) {
       // contradiction detected
       other_conditions = {const_false()};
@@ -465,7 +439,8 @@ IntConstraints SolveInequalitiesToRange(const IntConstraints& inequalities) {
       // There is an equation of the form `v == expr`, so this variable can be completely removed.
       // Note that we use the 0-th expression because they are ordered by complexity,
       // so it must be the simplest one.
-      Range best_range(bnd->equal[0], analyzer.Simplify(bnd->equal[0] + 1, 3));
+      Range best_range(bnd->equal[0],
+                       analyzer.Simplify(bnd->equal[0] + 1, kSimplifyRewriteCanonicalRewrite));
       res_ranges.Set(var, best_range);
       vranges.Set(var, best_range);
     } else {
@@ -491,7 +466,7 @@ IntConstraints SolveInequalitiesToRange(const IntConstraints& inequalities) {
   arith::Analyzer analyzer;
   analyzer.Bind(vranges);
   for (const PrimExpr& old_cond :
-       as_conditions(inequalities->variables, solved_bounds, solved_other_relations)) {
+       AsConditions(inequalities->variables, solved_bounds, solved_other_relations)) {
     if (!analyzer.CanProve(old_cond)) {
       // those not represented in vranges (res_ranges)
       res_relations.push_back(old_cond);
@@ -584,7 +559,7 @@ IntConstraintsTransform SolveInequalitiesDeskewRange(const IntConstraints& inequ
 
   // Add the original conditions (with variables substituted) to the resulting conditions
   for (const PrimExpr& old_cond :
-       as_conditions(inequalities->variables, solved_bounds, solved_other_relations)) {
+       AsConditions(inequalities->variables, solved_bounds, solved_other_relations)) {
     PrimExpr new_cond = analyzer.Simplify(Substitute(old_cond, res_src_to_dst));
     if (!is_const_int(new_cond, 1)) {
       // those not represented in vranges (res_ranges)
@@ -615,7 +590,7 @@ TVM_REGISTER_GLOBAL("arith.SolveInequalitiesAsCondition")
         LOG(FATAL) << "arith.SolveInequalitiesAsCondition expects 1 or 3 arguments, gets "
                    << args.size();
       }
-      *ret = as_conditions(problem->variables, ret_ineq.first, ret_ineq.second);
+      *ret = AsConditions(problem->variables, ret_ineq.first, ret_ineq.second);
     });
 
 TVM_REGISTER_GLOBAL("arith.SolveInequalitiesToRange").set_body([](TVMArgs args, TVMRetValue* ret) {
index be7cab1..f1d5b7a 100644 (file)
@@ -886,7 +886,6 @@ RELAY_REGISTER_OP("scatter_add")
     .set_attr<TOpPattern>("TOpPattern", kOpaque)
     .set_support_level(10);
 
-////
 
 // Take
 TVM_REGISTER_NODE_TYPE(TakeAttrs);
diff --git a/src/te/autodiff/ad_simplify.cc b/src/te/autodiff/ad_simplify.cc
new file mode 100644 (file)
index 0000000..3860c00
--- /dev/null
@@ -0,0 +1,1231 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file ad_simplify.cc
+ * \brief Simplify tensor compute generated by tensor-level autodiff.
+ *
+ * The major simplification we do in this file is to eliminate
+ * the Jacobian tensor created by autodiff.
+ *
+ * Jacobian tensor is sparse because one output element usually relates
+ * to a small portion of the inputs. For example, element-wise function has a one-to-one mapping
+ * between input tensor and output tensor, thus the Jacobian is diagonal.
+ *
+ * Generally, we have Out_{\beta} = f( In_{A \alpha} ) in which A is a matrix,
+ * \alpha and \beta are vectors represent the indices of In and Out respectively.
+ * i.e., the non-zero Jacobian indices is a linear combination of the input indices.
+ * Thereby we solve linear equations of \beta = A \alpha,
+ * as well as linear inequalities of their domain ranges.
+ *
+ * Refer to Urban S, van der Smagt P. Automatic differentiation for tensor algebras[J].
+ * arXiv preprint arXiv:1711.01348, 2017. for more details.
+ *
+ * Implement-wise, we extract the equations in the compute definition via NonzeronessCondition,
+ * replace the compute expression with solved new axes, and create a selection node
+ * (non-zero-condition ? new_compute_expression : 0).
+ *
+ * Due to TVM's restriction, we also lift the reduction to the top of the compute stage.
+ *
+ */
+#include <dmlc/optional.h>
+#include <tvm/arith/analyzer.h>
+#include <tvm/arith/int_solver.h>
+#include <tvm/runtime/registry.h>
+#include <tvm/te/autodiff.h>
+#include <tvm/tir/analysis.h>
+#include <tvm/tir/stmt_functor.h>
+
+#include <memory>
+#include <utility>
+
+#include "ad_util.h"
+
+namespace tvm {
+namespace te {
+
+using arith::DivMode;
+using arith::kFloorDiv;
+using arith::kSimplifyRewriteCanonicalRewrite;
+using arith::kTruncDiv;
+
+// Combine all expressions from the container using &&.
+template <class container>
+PrimExpr All(const container& c) {
+  PrimExpr res;
+  for (const auto& e : c) {
+    if (res.get()) {
+      res = res && e;
+    } else {
+      res = e;
+    }
+  }
+  if (res.get()) {
+    return res;
+  } else {
+    return const_true();
+  }
+}
+
+Map<Var, Range> IterVarsToMap(const Array<IterVar>& itervars) {
+  Map<Var, Range> res;
+  for (const IterVar& v : itervars) {
+    res.Set(v->var, v->dom);
+  }
+  return res;
+}
+
+// Given a map from vars to ranges create an array of itervars
+Array<IterVar> IterVarsFromMap(const Array<Var>& vars, const Map<Var, Range>& vranges,
+                               IterVarType iter_type = kDataPar, std::string thread_tag = "") {
+  Array<IterVar> res;
+  for (const Var& v : vars) {
+    CHECK(vranges.count(v)) << "A range for the variable " << v << " was not provided in map "
+                            << vranges;
+    res.push_back(IterVar(vranges[v], v, iter_type, thread_tag));
+  }
+  return res;
+}
+
+Array<Var> IterVarsToVars(const Array<IterVar>& itervars) {
+  Array<Var> res;
+  for (const IterVar& v : itervars) {
+    res.push_back(v->var);
+  }
+  return res;
+}
+
+template <typename ValueType>
+bool is_const_value(const PrimExpr& e, ValueType value) {
+  static_assert(std::is_integral<ValueType>::value,
+                "Comparison to non-integer values is forbidden.");
+  if (const tir::IntImmNode* i = e.as<tir::IntImmNode>()) {
+    return i->value == value;
+  } else if (const tir::FloatImmNode* i = e.as<tir::FloatImmNode>()) {
+    return i->value == value;
+  } else if (const tir::CastNode* c = e.as<tir::CastNode>()) {
+    return is_const_value(c->value, value);
+  } else if (const tir::BroadcastNode* b = e.as<tir::BroadcastNode>()) {
+    return is_const_value(b->value, value);
+  } else {
+    return false;
+  }
+}
+
+// Return true if this combiner is just a sum.
+bool IsSumCombiner(const CommReducer& combiner, const Map<Var, Range>& vranges) {
+  arith::Analyzer analyzer;
+  analyzer.Bind(vranges);
+  if (combiner->result.size() != 1) {
+    return false;
+  }
+
+  if (!is_const_value(
+          analyzer.Simplify(combiner->identity_element[0], kSimplifyRewriteCanonicalRewrite), 0)) {
+    return false;
+  }
+
+  PrimExpr combiner_result =
+      analyzer.Simplify(combiner->result[0], kSimplifyRewriteCanonicalRewrite);
+
+  return tir::ExprDeepEqual()(combiner_result, combiner->lhs[0] + combiner->rhs[0]) ||
+         tir::ExprDeepEqual()(combiner_result, combiner->rhs[0] + combiner->lhs[0]);
+}
+
+bool CanFactorZeroFromCombiner(const CommReducer& combiner, int value_index,
+                               const Map<Var, Range>& vranges) {
+  arith::Analyzer analyzer;
+  analyzer.Bind(vranges);
+  if (!is_const_value(analyzer.Simplify(combiner->identity_element[value_index],
+                                        kSimplifyRewriteCanonicalRewrite),
+                      0)) {
+    return false;
+  }
+
+  PrimExpr zero = make_zero(combiner->result[value_index].dtype());
+  PrimExpr in = Substitute(combiner->result[value_index], {{combiner->lhs[value_index], zero},
+                                                           {combiner->rhs[value_index], zero}});
+  in = analyzer.Simplify(in, kSimplifyRewriteCanonicalRewrite);
+
+  return is_const_value(in, 0);
+}
+
+struct NonzeroConditionResult {
+  PrimExpr cond;
+  PrimExpr value;
+
+  PrimExpr to_expr() const { return Select(cond, value, make_zero(value.dtype())); }
+
+  friend std::ostream& operator<<(std::ostream& os, const NonzeroConditionResult& r) {
+    return os << r.to_expr();
+  }
+};
+
+// The implementation of NonzeroCondition
+// transform expression to cond ? value : 0
+class NonzeroConditionFunctor : public ExprFunctor<NonzeroConditionResult(const PrimExpr&)> {
+ public:
+  NonzeroConditionResult NonzeroCondition(const PrimExpr& e) {
+    if (e.dtype().is_bool()) {
+      // Boolean expressions are non-zero whenever they are true themselves
+      return {e, const_true()};
+    } else {
+      return VisitExpr(e);
+    }
+  }
+
+  // Most of the cases are implemented using helpers below
+  result_type VisitExpr_(const VarNode* op) final { return Default_(GetRef<PrimExpr>(op)); }
+  result_type VisitExpr_(const IntImmNode* op) final { return Const_(GetRef<IntImm>(op)); }
+  result_type VisitExpr_(const FloatImmNode* op) final { return Const_(GetRef<FloatImm>(op)); }
+  result_type VisitExpr_(const StringImmNode* op) final { return Default_(GetRef<PrimExpr>(op)); }
+  result_type VisitExpr_(const AddNode* op) final { return BinOpAddLike_(GetRef<Add>(op)); }
+  result_type VisitExpr_(const SubNode* op) final { return BinOpAddLike_(GetRef<Sub>(op)); }
+  result_type VisitExpr_(const MulNode* op) final { return BinOpMulLike_(GetRef<Mul>(op)); }
+  result_type VisitExpr_(const DivNode* op) final { return BinOpDivLike_(GetRef<Div>(op)); }
+  result_type VisitExpr_(const ModNode* op) final { return BinOpDivLike_(GetRef<Mod>(op)); }
+  result_type VisitExpr_(const FloorDivNode* op) final {
+    return BinOpDivLike_(GetRef<FloorDiv>(op));
+  }
+  result_type VisitExpr_(const FloorModNode* op) final {
+    return BinOpDivLike_(GetRef<FloorMod>(op));
+  }
+  result_type VisitExpr_(const MinNode* op) final { return BinOpAddLike_(GetRef<Min>(op)); }
+  result_type VisitExpr_(const MaxNode* op) final { return BinOpAddLike_(GetRef<Max>(op)); }
+
+  result_type VisitExpr_(const CastNode* op) final {
+    auto nz_a = NonzeroCondition(op->value);
+    return {nz_a.cond, Cast(op->dtype, nz_a.value)};
+  }
+
+  result_type VisitExpr_(const SelectNode* op) final {
+    PrimExpr cond = op->condition, true_val = op->true_value, false_val = op->false_value;
+    auto nz_a = NonzeroCondition(true_val);
+    auto nz_b = NonzeroCondition(false_val);
+
+    // If the false part is zero, we can get rid of the select
+    if (is_const_value(nz_b.value, 0)) {
+      PrimExpr new_cond = analyzer_.Simplify(nz_a.cond && cond, kSimplifyRewriteCanonicalRewrite);
+      return {new_cond, nz_a.value};
+    }
+
+    // If the true part is zero, we can also get rid of the select
+    if (is_const_value(nz_a.value, 0)) {
+      PrimExpr new_cond = analyzer_.Simplify(nz_b.cond && !cond, kSimplifyRewriteCanonicalRewrite);
+      return {new_cond, nz_b.value};
+    }
+
+    // Otherwise we retain the select and combine the conditions into this
+    PrimExpr new_cond = analyzer_.Simplify((cond && nz_a.cond) || (!cond && nz_b.cond),
+                                           kSimplifyRewriteCanonicalRewrite);
+    if (nz_a.value.same_as(true_val) && nz_b.value.same_as(false_val)) {
+      return {new_cond, GetRef<PrimExpr>(op)};
+    } else {
+      return {new_cond, Select(cond, nz_a.value, nz_b.value)};
+    }
+  }
+
+  result_type VisitExpr_(const CallNode* op) final {
+    if (op->op.same_as(op_if_then_else_)) {
+      PrimExpr cond = op->args[0], true_val = op->args[1], false_val = op->args[2];
+      auto nz_a = NonzeroCondition(true_val);
+      auto nz_b = NonzeroCondition(false_val);
+
+      // We don't have as much freedom here as in the select case
+      // since the `if` must be preserved in any case
+      PrimExpr new_cond = analyzer_.Simplify((cond && nz_a.cond) || (!cond && nz_b.cond),
+                                             kSimplifyRewriteCanonicalRewrite);
+      if (nz_a.value.same_as(true_val) && nz_b.value.same_as(false_val)) {
+        return {new_cond, GetRef<PrimExpr>(op)};
+      } else {
+        return {new_cond, if_then_else(cond, nz_a.value, nz_b.value)};
+      }
+    } else {
+      return Default_(GetRef<PrimExpr>(op));
+    }
+  }
+
+  result_type VisitExpr_(const ProducerLoadNode* op) final {
+    return Default_(GetRef<PrimExpr>(op));
+  }
+
+  NonzeroConditionResult Default_(const PrimExpr& e) {
+    // This is always correct, so it's the default
+    return {const_true(), e};
+  }
+
+  template <class T>
+  NonzeroConditionResult Const_(const T& op) {
+    if (op->value == 0) {
+      return {const_false(), op};
+    } else {
+      return {const_true(), op};
+    }
+  }
+
+  template <class T>
+  NonzeroConditionResult BinOpAddLike_(const T& op) {
+    auto nz_a = NonzeroCondition(op->a);
+    auto nz_b = NonzeroCondition(op->b);
+
+    // For addition and similar ops the result may be nonzero if either of the arguments is
+    // nonzero, so we combine the conditions with Or.
+    if (tir::ExprDeepEqual()(nz_a.cond, nz_b.cond)) {
+      // If the conditions are the same, we don't need Or
+      if (nz_a.value.same_as(op->a) && nz_b.value.same_as(op->b)) {
+        return {nz_a.cond, op};
+      } else {
+        return {nz_a.cond, T(nz_a.value, nz_b.value)};
+      }
+    } else {
+      // Otherwise use Or
+      PrimExpr new_cond =
+          analyzer_.Simplify(nz_a.cond || nz_b.cond, kSimplifyRewriteCanonicalRewrite);
+      // A little optimization: if the combined condition is the same as one of the inner
+      // conditions, we don't need to guard the inner value with a select, otherwise
+      // we create a select in the `to_expr` call.
+      PrimExpr new_a = tir::ExprDeepEqual()(nz_a.cond, new_cond) ? nz_a.value : nz_a.to_expr();
+      PrimExpr new_b = tir::ExprDeepEqual()(nz_b.cond, new_cond) ? nz_b.value : nz_b.to_expr();
+      PrimExpr new_expr = T(new_a, new_b);
+      return {new_cond, new_expr};
+    }
+  }
+
+  template <class T>
+  NonzeroConditionResult BinOpMulLike_(const T& op) {
+    auto nz_a = NonzeroCondition(op->a);
+    auto nz_b = NonzeroCondition(op->b);
+
+    // For multiplication and similar ops the result may be nonzero if
+    // both the arguments are nonzero, so we combine with And.
+    PrimExpr new_cond =
+        analyzer_.Simplify(nz_a.cond && nz_b.cond, kSimplifyRewriteCanonicalRewrite);
+
+    if (nz_a.value.same_as(op->a) && nz_b.value.same_as(op->b)) {
+      return {new_cond, op};
+    } else {
+      return {new_cond, T(nz_a.value, nz_b.value)};
+    }
+  }
+
+  template <class T>
+  NonzeroConditionResult BinOpDivLike_(const T& op) {
+    auto nz_a = NonzeroCondition(op->a);
+
+    // For Div we simply use the condition of the numerator.
+
+    if (nz_a.value.same_as(op->a)) {
+      return {nz_a.cond, op};
+    } else {
+      return {nz_a.cond, T(nz_a.value, op->b)};
+    }
+  }
+
+ private:
+  arith::Analyzer analyzer_;
+  const Op& op_if_then_else_ = Op::Get("tir.if_then_else");
+};
+
+inline NonzeroConditionResult NonzeronessCondition(const PrimExpr& expr) {
+  return NonzeroConditionFunctor().NonzeroCondition(expr);
+}
+
+struct FactorOutAtomicFormulasResult {
+  std::vector<PrimExpr> atomic_formulas;
+  PrimExpr rest;
+
+  PrimExpr to_expr() const {
+    PrimExpr res = rest;
+    for (const PrimExpr& e : atomic_formulas) {
+      res = And(e, res);
+    }
+    return res;
+  }
+
+  Array<PrimExpr> to_array() const {
+    Array<PrimExpr> res = atomic_formulas;
+    res.push_back(rest);
+    return res;
+  }
+};
+
+// The implementation of FactorOutAtomicFormulas
+class FactorOutAtomicFormulasFunctor
+    : public ExprFunctor<FactorOutAtomicFormulasResult(const PrimExpr&)> {
+ public:
+  result_type Atomic_(const PrimExpr& e) {
+    // For atomic expressions the result is the expr itself with True as the residual
+    return {{e}, make_const(e.dtype(), 1)};
+  }
+
+  // This is basically the list of expression kinds that are considered atomic
+  result_type VisitExpr_(const VarNode* op) final { return Atomic_(GetRef<PrimExpr>(op)); }
+  result_type VisitExpr_(const CallNode* op) final { return Atomic_(GetRef<PrimExpr>(op)); }
+  result_type VisitExpr_(const IntImmNode* op) final { return Atomic_(GetRef<PrimExpr>(op)); }
+  result_type VisitExpr_(const EQNode* op) final { return Atomic_(GetRef<PrimExpr>(op)); }
+  result_type VisitExpr_(const NENode* op) final { return Atomic_(GetRef<PrimExpr>(op)); }
+  result_type VisitExpr_(const LENode* op) final { return Atomic_(GetRef<PrimExpr>(op)); }
+  result_type VisitExpr_(const LTNode* op) final { return Atomic_(GetRef<PrimExpr>(op)); }
+  result_type VisitExpr_(const GENode* op) final { return Atomic_(GetRef<PrimExpr>(op)); }
+  result_type VisitExpr_(const GTNode* op) final { return Atomic_(GetRef<PrimExpr>(op)); }
+
+  result_type VisitExpr_(const SelectNode* op) final {
+    // Select can be rewritten through other logical ops
+    PrimExpr expr = (op->condition && op->true_value) || (!op->condition && op->false_value);
+    return VisitExpr(expr);
+  }
+
+  result_type VisitExpr_(const NotNode* op) final {
+    // Not should be moved down
+    if (const OrNode* or_expr = op->a.as<OrNode>()) {
+      PrimExpr expr = !or_expr->a && !or_expr->b;
+      return VisitExpr(expr);
+    } else if (const AndNode* and_expr = op->a.as<AndNode>()) {
+      PrimExpr expr = !and_expr->a || !and_expr->b;
+      return VisitExpr(expr);
+    } else if (const SelectNode* sel_expr = op->a.as<SelectNode>()) {
+      PrimExpr expr = ((!sel_expr->condition || !sel_expr->true_value) &&
+                       (sel_expr->condition || !sel_expr->false_value));
+      return VisitExpr(expr);
+    }
+    return Atomic_(GetRef<PrimExpr>(op));
+  }
+
+  result_type VisitExpr_(const AndNode* op) final {
+    auto res_a = VisitExpr(op->a);
+    auto res_b = VisitExpr(op->b);
+
+    // For the And case we return the union of the sets of atomic formulas
+    std::unordered_set<PrimExpr, StructuralHash, StructuralEqual> res_set;
+    res_set.reserve(res_a.atomic_formulas.size() + res_b.atomic_formulas.size());
+    std::copy(res_a.atomic_formulas.begin(), res_a.atomic_formulas.end(),
+              std::inserter(res_set, res_set.end()));
+    std::copy(res_b.atomic_formulas.begin(), res_b.atomic_formulas.end(),
+              std::inserter(res_set, res_set.end()));
+
+    std::vector<PrimExpr> res{res_set.begin(), res_set.end()};
+
+    // And the residuals are combined with &&
+    return {res, res_a.rest && res_b.rest};
+  }
+
+  result_type VisitExpr_(const MulNode* op) final {
+    // Since we work with bools, for multiplication we do the same thing as for And
+    PrimExpr e_and = op->a && op->b;
+    return VisitExpr(e_and);
+  }
+
+  result_type VisitExpr_(const OrNode* op) final {
+    auto res_a = VisitExpr(op->a);
+    auto res_b = VisitExpr(op->b);
+
+    std::unordered_set<PrimExpr, StructuralHash, StructuralEqual> res_a_set{
+        res_a.atomic_formulas.begin(), res_a.atomic_formulas.end()};
+    std::unordered_set<PrimExpr, StructuralHash, StructuralEqual> res_b_set{
+        res_b.atomic_formulas.begin(), res_b.atomic_formulas.end()};
+
+    // For the Or case we intersect the sets of atomic formulas
+    std::unordered_set<PrimExpr, StructuralHash, StructuralEqual> res_set;
+    res_set.reserve(std::min(res_a.atomic_formulas.size(), res_b.atomic_formulas.size()));
+    for (const auto& res_b_formula : res_b_set) {
+      if (res_a_set.count(res_b_formula)) {
+        res_set.insert(res_b_formula);
+      }
+    }
+
+    // Computing the residual is more complex: we have to compute the sets of atomic formulas
+    // which are left behind, and then combine them with the residuals into the new residual.
+    std::vector<PrimExpr> new_cond_a;
+    new_cond_a.reserve(res_a.atomic_formulas.size() - res_set.size());
+    for (const auto& formula : res_a_set) {
+      if (!res_set.count(formula)) new_cond_a.emplace_back(formula);
+    }
+
+    std::vector<PrimExpr> new_cond_b;
+    new_cond_b.reserve(res_b.atomic_formulas.size() - res_set.size());
+    for (const auto& formula : res_b_set) {
+      if (!res_set.count(formula)) new_cond_b.emplace_back(formula);
+    }
+
+    res_a.atomic_formulas = std::move(new_cond_a);
+    res_b.atomic_formulas = std::move(new_cond_b);
+
+    PrimExpr new_rest = res_a.to_expr() || res_b.to_expr();
+    std::vector<PrimExpr> res{res_set.begin(), res_set.end()};
+
+    return {res, new_rest};
+  }
+};
+
+// Transform the given formula into a conjunction of atomic formulas (represented as an array)
+// and a non-atomic residual. Atomic formulas are consts, calls, variables and comparisons (a <= b,
+// etc), i.e. formulas which are not logical operators (||, &&, !) on the top level.
+FactorOutAtomicFormulasResult FactorOutAtomicFormulas(const PrimExpr& e) {
+  CHECK(e.dtype().is_bool());
+  return FactorOutAtomicFormulasFunctor().VisitExpr(e);
+}
+
+struct EliminateDivModResult {
+  PrimExpr expr;
+  Map<Var, PrimExpr> substitution;
+  Array<Var> new_variables;
+  Array<PrimExpr> conditions;
+  Map<Var, Range> ranges;
+};
+
+inline PrimExpr ModImpl(PrimExpr a, PrimExpr b, DivMode mode) {
+  if (mode == kTruncDiv) {
+    return truncmod(a, b);
+  } else {
+    CHECK_EQ(mode, kFloorDiv);
+    return floormod(a, b);
+  }
+}
+
+inline PrimExpr DivImpl(PrimExpr a, PrimExpr b, DivMode mode) {
+  if (mode == kTruncDiv) {
+    return truncdiv(a, b);
+  } else {
+    CHECK_EQ(mode, kFloorDiv);
+    return floordiv(a, b);
+  }
+}
+
+class EliminateDivModMutator : public ExprMutator {
+ public:
+  Map<Var, PrimExpr> substitution;
+  Array<Var> new_variables;
+  Array<PrimExpr> conditions;
+  Map<Var, Range> ranges;
+
+  explicit EliminateDivModMutator(Map<Var, Range> ranges) : ranges(std::move(ranges)) {}
+
+  virtual PrimExpr VisitExpr_(const DivNode* op) {
+    const IntImmNode* imm = op->b.as<IntImmNode>();
+    if (imm && imm->value != 0) {
+      if (imm->value < 0) {
+        // x / -c == -(x/c) for truncated division
+        return make_zero(op->dtype) -
+               VisitExpr(truncdiv(op->a, make_const(op->dtype, -imm->value)));
+      }
+
+      // Try to find the already existing variables for this expression
+      auto it = expr_to_vars_.find(std::make_tuple(kTruncDiv, op->a, imm->value));
+      if (it != expr_to_vars_.end()) {
+        return it->second.first;
+      }
+
+      // Otherwise recursively mutate the left hand side, and create new variables
+      PrimExpr mutated_a = VisitExpr(op->a);
+      if (auto var_pair_opt = AddNewVarPair(op->a, mutated_a, imm->value, kTruncDiv)) {
+        return var_pair_opt.value().first;
+      } else {
+        return truncdiv(mutated_a, op->b);
+      }
+    }
+
+    return truncdiv(VisitExpr(op->a), VisitExpr(op->b));
+  }
+
+  virtual PrimExpr VisitExpr_(const ModNode* op) {
+    const IntImmNode* imm = op->b.as<IntImmNode>();
+    if (imm && imm->value != 0) {
+      if (imm->value < 0) {
+        // x % -c == x % c for truncated division
+        return VisitExpr(truncmod(op->a, make_const(op->dtype, -imm->value)));
+      }
+
+      // Try to find the already existing variables for this expression
+      auto it = expr_to_vars_.find(std::make_tuple(kTruncDiv, op->a, imm->value));
+      if (it != expr_to_vars_.end()) {
+        return it->second.second;
+      }
+
+      // Otherwise recursively mutate the left hand side, and create new variables
+      PrimExpr mutated_a = VisitExpr(op->a);
+      if (auto var_pair_opt = AddNewVarPair(op->a, mutated_a, imm->value, kTruncDiv)) {
+        return var_pair_opt.value().second;
+      } else {
+        return truncmod(mutated_a, op->b);
+      }
+    }
+
+    return truncmod(VisitExpr(op->a), VisitExpr(op->b));
+  }
+
+  virtual PrimExpr VisitExpr_(const FloorDivNode* op) {
+    const IntImmNode* imm = op->b.as<IntImmNode>();
+    if (imm && imm->value != 0) {
+      if (imm->value < 0) {
+        // x / -c == (-x) / c for flooring division
+        return VisitExpr(
+            floordiv(make_zero(op->dtype) - op->a, make_const(op->dtype, -imm->value)));
+      }
+
+      // Try to find the already existing variables for this expression
+      auto it = expr_to_vars_.find(std::make_tuple(kFloorDiv, op->a, imm->value));
+      if (it != expr_to_vars_.end()) {
+        return it->second.first;
+      }
+
+      // Otherwise recursively mutate the left hand side, and create new variables
+      PrimExpr mutated_a = VisitExpr(op->a);
+      if (auto var_pair_opt = AddNewVarPair(op->a, mutated_a, imm->value, kFloorDiv)) {
+        return var_pair_opt.value().first;
+      } else {
+        return floordiv(mutated_a, op->b);
+      }
+    }
+
+    return floordiv(VisitExpr(op->a), VisitExpr(op->b));
+  }
+
+  virtual PrimExpr VisitExpr_(const FloorModNode* op) {
+    const IntImmNode* imm = op->b.as<IntImmNode>();
+    if (imm && imm->value != 0) {
+      if (imm->value < 0) {
+        // x % -c == -(-x % c) for flooring division
+        return VisitExpr(make_zero(op->dtype) - floormod(make_zero(op->dtype) - op->a,
+                                                         make_const(op->dtype, -imm->value)));
+      }
+
+      // Try to find the already existing variables for this expression
+      auto it = expr_to_vars_.find(std::make_tuple(kFloorDiv, op->a, imm->value));
+      if (it != expr_to_vars_.end()) {
+        return it->second.second;
+      }
+
+      // Otherwise recursively mutate the left hand side, and create new variables
+      PrimExpr mutated_a = VisitExpr(op->a);
+      if (auto var_pair_opt = AddNewVarPair(op->a, mutated_a, imm->value, kFloorDiv)) {
+        return var_pair_opt.value().second;
+      } else {
+        return floormod(mutated_a, op->b);
+      }
+    }
+
+    return floormod(VisitExpr(op->a), VisitExpr(op->b));
+  }
+
+ private:
+  dmlc::optional<std::pair<Var, Var>> AddNewVarPair(const PrimExpr& e, const PrimExpr& mut,
+                                                    int64_t val, DivMode mode) {
+    using tresult = dmlc::optional<std::pair<Var, Var>>;
+
+    // Try to find the variables using the mutated expressions
+    if (!e.same_as(mut)) {
+      auto it = expr_to_vars_.find(std::make_tuple(mode, mut, val));
+      if (it != expr_to_vars_.end()) {
+        return tresult(it->second);
+      }
+    }
+
+    PrimExpr val_e = make_const(e.dtype(), val);
+    idx_ += 1;
+
+    // Convert `ranges` to IntSets
+    std::unordered_map<const VarNode*, IntSet> var_intsets;
+    for (const auto& p : ranges) {
+      var_intsets[p.first.get()] = IntSet::FromRange(p.second);
+    }
+
+    // Infer ranges for the expressions we want to replace with variables
+    Range div_range = EvalSet(DivImpl(mut, val_e, mode), var_intsets).CoverRange(Range());
+    Range mod_range = EvalSet(ModImpl(mut, val_e, mode), var_intsets).CoverRange(Range());
+
+    // We don't want to add unbounded variables
+    if (!div_range.get() || !mod_range.get()) {
+      LOG(WARNING) << "EliminateDivMod: won't eliminate " << DivImpl(e, val_e, mode)
+                   << "  because its bounds cannot be inferred";
+      return tresult();
+    }
+    if (!mod_range.get()) {
+      LOG(WARNING) << "EliminateDivMod: won't eliminate " << ModImpl(e, val_e, mode)
+                   << "  because its bounds cannot be inferred";
+      return tresult();
+    }
+
+    // Create new variables for the expressions
+    auto div = Var((mode == kTruncDiv ? "tdiv" : "fdiv") + std::to_string(idx_), e.dtype());
+    auto mod = Var((mode == kTruncDiv ? "tmod" : "fmod") + std::to_string(idx_), e.dtype());
+
+    new_variables.push_back(div);
+    new_variables.push_back(mod);
+
+    // Note that we have to perform substitution to mut because mut may contain new variables
+    substitution.Set(div, DivImpl(Substitute(mut, substitution), val_e, mode));
+    substitution.Set(mod, ModImpl(Substitute(mut, substitution), val_e, mode));
+
+    ranges.Set(div, div_range);
+    ranges.Set(mod, mod_range);
+
+    // This additional condition works as a definition for the new variables
+    conditions.push_back(mut == div * val_e + mod);
+
+    if (!analyzer_.CanProve(mod_range->extent <= val_e)) {
+      // If we use the C/C++ definition of mod, there may be multiple values of `mod`
+      // satisfying the added condition if the expr `e` may change its sign, so we
+      // have to add another condition.
+      LOG(WARNING) << "EliminateDivMod: cannot fully eliminate div or mod because "
+                   << ModImpl(e, val_e, mode) << "  probably may change its sign";
+      conditions.push_back(Select(e >= 0, mod >= 0, mod <= 0));
+    }
+
+    auto p = std::make_pair(div, mod);
+    expr_to_vars_[std::make_tuple(mode, e, val)] = p;
+    if (!e.same_as(mut)) {
+      expr_to_vars_[std::make_tuple(mode, mut, val)] = p;
+    }
+    return tresult(p);
+  }
+
+  class TupleEqual_ {
+   public:
+    bool operator()(const std::tuple<DivMode, PrimExpr, int64_t>& lhs,
+                    const std::tuple<DivMode, PrimExpr, int64_t>& rhs) const {
+      return std::get<0>(lhs) == std::get<0>(rhs) &&
+             tir::ExprDeepEqual()(std::get<1>(lhs), std::get<1>(rhs)) &&
+             std::get<2>(lhs) == std::get<2>(rhs);
+    }
+  };
+
+  class TupleHasher_ {
+   public:
+    size_t operator()(const std::tuple<DivMode, PrimExpr, int64_t>& key) const {
+      return ((std::hash<int>()(std::get<0>(key)) ^ (StructuralHash()(std::get<1>(key)) << 1)) >>
+              1) ^
+             (std::hash<int64_t>()(std::get<2>(key)) << 1);
+    }
+  };
+
+  // A counter for naming new variables
+  int idx_{0};
+  // A map from pairs of exprs and numbers (e, n) to pairs of new vars (div, mod)
+  // such that `div = e / n` and `mod = e % n`
+  std::unordered_map<std::tuple<DivMode, PrimExpr, int64_t>, std::pair<Var, Var>, TupleHasher_,
+                     TupleEqual_>
+      expr_to_vars_;
+  arith::Analyzer analyzer_;
+};
+
+// Replace every subexpr of the form e/const and e % const with a new variable.
+// Syntactically equal expressions will be mapped to the same variable.
+EliminateDivModResult EliminateDivMod(const PrimExpr& expr, Map<Var, Range> ranges) {
+  EliminateDivModResult res;
+  EliminateDivModMutator mutator(ranges);
+  res.expr = mutator(expr);
+  res.conditions = std::move(mutator.conditions);
+  res.new_variables = std::move(mutator.new_variables);
+  res.substitution = std::move(mutator.substitution);
+  res.ranges = std::move(mutator.ranges);
+  return res;
+}
+
+arith::IntConstraintsTransform EliminateDivModFromDomainConditions(
+    const arith::IntConstraints& domain) {
+  auto elim_res = EliminateDivMod(All(domain->relations), domain->ranges);
+
+  Map<Var, Range> new_vranges = elim_res.ranges;
+  Array<Var> new_axis = Concat(domain->variables, elim_res.new_variables);
+  PrimExpr new_cond = elim_res.expr && All(elim_res.conditions);
+
+  arith::IntConstraints new_domain(new_axis, new_vranges,
+                                   FactorOutAtomicFormulas(new_cond).to_array());
+
+  Map<Var, PrimExpr> src_to_dst;
+  Map<Var, PrimExpr> dst_to_src = elim_res.substitution;
+  for (const Var& v : domain->variables) {
+    src_to_dst.Set(v, v);
+    dst_to_src.Set(v, v);
+  }
+
+  return arith::IntConstraintsTransform(domain, new_domain, src_to_dst, dst_to_src);
+}
+
+inline arith::IntConstraintsTransform IdentityTransformation(const arith::IntConstraints& domain) {
+  Map<Var, PrimExpr> identity_map;
+  for (const Var& v : domain->variables) {
+    identity_map.Set(v, v);
+  }
+  return arith::IntConstraintsTransform(domain, domain, identity_map, identity_map);
+}
+
+// Simplify an iteration domain.
+arith::IntConstraintsTransform SimplifyDomain(const arith::IntConstraints& iter_domains,
+                                              bool eliminate_div_mod) {
+  arith::IntConstraintsTransform transf = IdentityTransformation(iter_domains);
+
+  if (eliminate_div_mod) {
+    transf = transf + EliminateDivModFromDomainConditions(transf->dst);
+  }
+
+  // TODO(sgrechanik-h): Repeating the following steps has a positive effect, however we probably
+  // should find a better terminating criterion (like stop when the domain volume stops decreasing)
+  // Also 2 steps seems to be slightly better than 3
+  for (size_t i = 0; i < 2; ++i) {
+    transf = transf + arith::SolveLinearEquations(transf->dst);
+    transf = transf + arith::SolveInequalitiesDeskewRange(transf->dst);
+  }
+
+  return transf;
+}
+
+// Use the condition of a reduction op to simplify its domain (axis)
+PrimExpr SimplifyReductionDomain(const PrimExpr& expr, const Map<Var, Range>& outer_vranges) {
+  if (const ReduceNode* red = expr.as<ReduceNode>()) {
+    Array<Var> vars = IterVarsToVars(red->axis);
+    Map<Var, Range> vranges = Merge(outer_vranges, IterVarsToMap(red->axis));
+    Array<PrimExpr> relations = FactorOutAtomicFormulas(red->condition).to_array();
+
+    arith::IntConstraints domain(vars, vranges, relations);
+    auto res = SimplifyDomain(domain);
+
+    Array<PrimExpr> new_source;
+    for (const PrimExpr& src : red->source) {
+      new_source.push_back(Substitute(src, res->src_to_dst));
+    }
+
+    Array<IterVar> new_axis = IterVarsFromMap(res->dst->variables, res->dst->ranges, kCommReduce);
+
+    // Perform simplification mainly to remove a possibly empty reduction.
+    arith::Analyzer analyzer;
+    return analyzer.Simplify(
+        Reduce(red->combiner, new_source, new_axis, All(res->dst->relations), red->value_index),
+        kSimplifyRewriteCanonicalRewrite);
+  } else {
+    return expr;
+  }
+}
+
+// Extract from cond an implication of cond not containing vars
+std::pair<PrimExpr, PrimExpr> ImplicationNotContainingVars(
+    const PrimExpr& cond, const std::unordered_set<const VarNode*>& vars) {
+  CHECK(cond.dtype().is_bool()) << "The type of cond must be bool";
+  // TODO(sgrechanik-h): NOTs could be pushed down using De Morgan laws
+  // before running this function but this case didn't seem to be important enough.
+  if (const AndNode* op = cond.as<AndNode>()) {
+    auto pair_a = ImplicationNotContainingVars(op->a, vars);
+    auto pair_b = ImplicationNotContainingVars(op->b, vars);
+    return {pair_a.first && pair_b.first, pair_a.second && pair_b.second};
+  } else if (const OrNode* op = cond.as<OrNode>()) {
+    auto pair_a = ImplicationNotContainingVars(op->a, vars);
+    auto pair_b = ImplicationNotContainingVars(op->b, vars);
+    return {pair_a.first || pair_b.first, (pair_a.first || pair_b.second) &&
+                                              (pair_b.first || pair_a.second) &&
+                                              (pair_a.second || pair_b.second)};
+  } else if (!tir::ExprUseVar(cond, [&vars](const VarNode* var) { return vars.count(var); })) {
+    return {cond, const_true()};
+  } else {
+    return {const_true(), cond};
+  }
+}
+
+// Factor conditions out of a reduction by applying Fourier-Motzkin elimination and moving out
+// (in)equalities which do not depend on the reduction variables.
+std::pair<PrimExpr, PrimExpr> LiftConditionsThroughReduction(const PrimExpr& cond,
+                                                             const Array<IterVar>& red_axis,
+                                                             const Array<IterVar>& outer_axis) {
+  // Factor out atomics so that we can consider this as a system of inequalities
+  auto factor_atomic_res = FactorOutAtomicFormulas(cond);
+  Array<PrimExpr> atomics = factor_atomic_res.atomic_formulas;
+  const PrimExpr& rest = factor_atomic_res.rest;
+
+  Array<Var> allvars;
+  for (const IterVar& v : red_axis) {
+    allvars.push_back(v->var);
+  }
+  for (const IterVar& v : outer_axis) {
+    allvars.push_back(v->var);
+  }
+
+  auto vranges = Merge(IterVarsToMap(red_axis), IterVarsToMap(outer_axis));
+  // start from reduction vars, so that input vars don't depend on them
+  arith::IntConstraints ineq_to_solve(allvars, vranges, atomics);
+  auto res_ineq = arith::SolveLinearInequalities(ineq_to_solve);
+  atomics = arith::AsConditions(allvars, res_ineq.first, res_ineq.second);
+
+  // Append the rest part
+  PrimExpr rewritten_cond = All(atomics) && rest;
+
+  std::unordered_set<const VarNode*> vset;
+  for (const IterVar& v : red_axis) {
+    vset.insert(v->var.get());
+  }
+
+  // The outer (first) condition does not contain reduction vars,
+  // the inner (second) condition is everything else
+  auto res = ImplicationNotContainingVars(rewritten_cond, vset);
+  return res;
+}
+
+// Convert an array of itervars to an array of inequalities
+Array<PrimExpr> IterVarsToInequalities(const Array<IterVar>& itervars) {
+  Array<PrimExpr> res;
+  for (const IterVar& v : itervars) {
+    res.push_back(GE(v->var, v->dom->min));
+    res.push_back(LT(v->var, v->dom->min + v->dom->extent));
+  }
+  return res;
+}
+
+class RemoveRedundantInequalitiesMutator : public ExprMutator {
+ public:
+  explicit RemoveRedundantInequalitiesMutator(Array<PrimExpr> known) {
+    for (const PrimExpr& cond : known) {
+      known_.push_back(analyzer_.Simplify(cond, kSimplifyRewriteCanonicalRewrite));
+    }
+  }
+
+  virtual PrimExpr VisitExpr_(const SelectNode* op) {
+    bool has_side_effect = (SideEffect(GetRef<PrimExpr>(op)) > CallEffectKind::kReadState);
+    PrimExpr new_cond =
+        analyzer_.Simplify(VisitExpr(op->condition), kSimplifyRewriteCanonicalRewrite);
+    if (is_one(new_cond) && !has_side_effect) {
+      return VisitExpr(op->true_value);
+    } else if (is_zero(new_cond) && !has_side_effect) {
+      return VisitExpr(op->false_value);
+    } else {
+      Array<PrimExpr> new_known = known_;
+      for (const PrimExpr& atomic : FactorOutAtomicFormulas(new_cond).atomic_formulas) {
+        new_known.push_back(atomic);
+      }
+      RemoveRedundantInequalitiesMutator new_mutator(new_known);
+      // Note that we mutate only the true value with the new mutator
+      // TODO(sgrechanik-h): Update known conditions for the false value as well
+      return Select(new_cond, new_mutator(op->true_value), VisitExpr(op->false_value));
+    }
+  }
+
+  virtual PrimExpr VisitExpr_(const CallNode* op) {
+    if (op->op.same_as(op_if_then_else_)) {
+      PrimExpr new_cond =
+          analyzer_.Simplify(VisitExpr(op->args[0]), kSimplifyRewriteCanonicalRewrite);
+      if (is_one(new_cond)) {
+        return VisitExpr(op->args[1]);
+      } else if (is_zero(new_cond)) {
+        return VisitExpr(op->args[2]);
+      } else {
+        Array<PrimExpr> new_known = known_;
+        for (const PrimExpr& atomic : FactorOutAtomicFormulas(new_cond).atomic_formulas) {
+          new_known.push_back(atomic);
+        }
+        RemoveRedundantInequalitiesMutator new_mutator(new_known);
+        // Note that we mutate only the true value with the new mutator
+        // TODO(sgrechanik-h): Update known conditions for the false value as well
+        return if_then_else(new_cond, new_mutator(op->args[1]), VisitExpr(op->args[2]));
+      }
+    } else {
+      return ExprMutator::VisitExpr_(op);
+    }
+  }
+
+  virtual PrimExpr VisitExpr_(const ReduceNode* op) {
+    Array<PrimExpr> known_with_axes = known_;
+    for (const PrimExpr& axis_cond : IterVarsToInequalities(op->axis)) {
+      known_with_axes.push_back(axis_cond);
+    }
+    RemoveRedundantInequalitiesMutator mutator_with_axes(known_with_axes);
+
+    PrimExpr new_cond = mutator_with_axes(op->condition);
+
+    Array<PrimExpr> new_known = known_with_axes;
+    for (const PrimExpr& atomic : FactorOutAtomicFormulas(new_cond).atomic_formulas) {
+      new_known.push_back(atomic);
+    }
+    RemoveRedundantInequalitiesMutator new_mutator(new_known);
+
+    Array<PrimExpr> new_source;
+    for (const PrimExpr& src : op->source) {
+      new_source.push_back(new_mutator(src));
+    }
+
+    return Reduce(op->combiner, new_source, op->axis, new_cond, op->value_index);
+  }
+
+  virtual PrimExpr VisitExpr_(const EQNode* op) { return MutateAtomic_(GetRef<PrimExpr>(op)); }
+  virtual PrimExpr VisitExpr_(const NENode* op) { return MutateAtomic_(GetRef<PrimExpr>(op)); }
+  virtual PrimExpr VisitExpr_(const LTNode* op) { return MutateAtomic_(GetRef<PrimExpr>(op)); }
+  virtual PrimExpr VisitExpr_(const LENode* op) { return MutateAtomic_(GetRef<PrimExpr>(op)); }
+  virtual PrimExpr VisitExpr_(const GTNode* op) { return MutateAtomic_(GetRef<PrimExpr>(op)); }
+  virtual PrimExpr VisitExpr_(const GENode* op) { return MutateAtomic_(GetRef<PrimExpr>(op)); }
+
+  virtual PrimExpr VisitExpr_(const AndNode* op) { return VisitExpr(op->a) && VisitExpr(op->b); }
+
+ private:
+  PrimExpr MutateAtomic_(const PrimExpr& e) {
+    PrimExpr simplified = analyzer_.Simplify(e, kSimplifyRewriteCanonicalRewrite);
+    for (const PrimExpr& other : known_) {
+      if (ExprDeepEqual()(simplified, other)) {
+        return const_true();
+      }
+    }
+    return simplified;
+  }
+
+  Array<PrimExpr> known_;
+  arith::Analyzer analyzer_;
+  const Op& op_if_then_else_ = Op::Get("tir.if_then_else");
+};
+
+// Propagate information from conditions and remove redundant inequalities
+inline PrimExpr RemoveRedundantInequalities(const PrimExpr& expr, const Array<PrimExpr>& known) {
+  return RemoveRedundantInequalitiesMutator(known)(expr);
+}
+
+// Extract the given expr under the given condition as a separate tensor if the volume of the
+// extracted tensor will be less than the volume of the outer_axis
+PrimExpr TrySimplifyCompute(const PrimExpr& expr, const PrimExpr& cond,
+                            const Array<Var>& outer_axis, const Map<Var, Range>& vranges) {
+  // solve cond, e.g., (jac_i0 == i) && (jac_i1 == j)
+  arith::IntConstraints domain_to_solve(outer_axis, vranges,
+                                        FactorOutAtomicFormulas(cond).to_array());
+  auto res = SimplifyDomain(domain_to_solve);
+
+  arith::Analyzer analyzer;
+  analyzer.Bind(res->dst->ranges);
+  PrimExpr new_expr =
+      analyzer.Simplify(Substitute(expr, res->src_to_dst), kSimplifyRewriteCanonicalRewrite);
+  // TODO(yzhliu): This is mostly done to simplify if_then_else
+  // which is not realized by the canonical simplifier
+  new_expr = RemoveRedundantInequalities(new_expr, res->dst->relations);
+
+  // Keep only those variables of the new vars which are used in the new_expr
+  Array<Var> used_res_variables;
+  for (const Var& var : res->dst->variables) {
+    if (ExprUseVar(new_expr, var)) {
+      CHECK(res->dst->ranges.count(var)) << "Range of " << var << " cannot be inferred.";
+      used_res_variables.push_back(var);
+    }
+  }
+
+  // If the expression does not use vars then it is probably better to keep it inlined
+  if (used_res_variables.empty()) {
+    // We can return the new_expr here instead of the old expr because it doesn't use variables
+    // otherwise we would need to replace the new vars or create a let-expression
+    return new_expr;
+  }
+
+  // If it's already tensor[...] then it will probably be useless to further simplify it.
+  if (new_expr.as<ProducerLoadNode>()) {
+    return expr;
+  }
+
+  // Compute volumes before and after
+  PrimExpr old_volume = make_const(DataType::Int(64), 1);
+  for (const Var& var : outer_axis) {
+    CHECK(vranges.count(var)) << "Range of " << var << " was not provided.";
+    old_volume = old_volume * vranges[var]->extent;
+  }
+
+  PrimExpr new_volume = make_const(DataType::Int(64), 1);
+  for (const Var& var : used_res_variables) {
+    new_volume = new_volume * res->dst->ranges[var]->extent;
+  }
+
+  // if we can prove that the old volume is not greater than the new volume then
+  // prefer the old expression.
+  arith::Analyzer ana_vranges;
+  ana_vranges.Bind(vranges);
+  if (ana_vranges.CanProve(old_volume <= new_volume)) {
+    return expr;
+  }
+
+  Tensor tensor = TensorFromExpr(new_expr, IterVarsFromMap(used_res_variables, res->dst->ranges),
+                                 "extracted_tensor");
+
+  Array<PrimExpr> args;
+  for (const Var& var : used_res_variables) {
+    args.push_back(res->dst_to_src[var]);
+  }
+
+  return ProducerLoad(tensor, args);
+}
+
+class ReductionAsTensorAccessMutator : public ExprMutator {
+ public:
+  explicit ReductionAsTensorAccessMutator(const Array<Var>& outer_axis, Map<Var, Range> vranges,
+                                          std::string name = "extracted_reduction")
+      : outer_axis_(outer_axis), vranges_(std::move(vranges)), name_(std::move(name)) {}
+
+  PrimExpr VisitExpr_(const ReduceNode* op) final {
+    ReductionAsTensorAccessMutator new_mutator(Concat(IterVarsToVars(op->axis), outer_axis_),
+                                               Merge(vranges_, IterVarsToMap(op->axis)), name_);
+
+    Array<PrimExpr> new_source;
+    for (const PrimExpr& src : op->source) {
+      new_source.push_back(new_mutator(src));
+    }
+
+    PrimExpr new_reduce =
+        Reduce(op->combiner, new_source, op->axis, op->condition, op->value_index);
+
+    Array<Var> undefined_vars = UndefinedVars(new_reduce);
+    std::unordered_set<const VarNode*> undefined_var_set;
+    for (const Var& var : undefined_vars) {
+      undefined_var_set.insert(var.get());
+    }
+
+    // Vars of the tensor we are going to create for this reduction
+    Array<Var> vars;
+    for (const Var& v : outer_axis_) {
+      // We take variables from the outer_axis_ which are also present in the new reduction
+      if (undefined_var_set.count(v.get())) {
+        vars.push_back(v);
+      }
+    }
+
+    auto new_axis_vmap_pair = CloneIterVars(IterVarsFromMap(vars, vranges_));
+    Array<IterVar> new_axis = new_axis_vmap_pair.first;
+    arith::Analyzer analyzer;
+    analyzer.Bind(IterVarsToMap(new_axis));
+    new_reduce = analyzer.Simplify(Substitute(new_reduce, new_axis_vmap_pair.second),
+                                   kSimplifyRewriteCanonicalRewrite);
+
+    Tensor tensor = TensorFromExpr(new_reduce, new_axis, name_, tag_, attrs_);
+
+    Array<PrimExpr> args;
+    for (const Var& v : vars) {
+      args.push_back(v);
+    }
+
+    return ProducerLoad(tensor, args);
+  }
+
+ private:
+  Array<Var> outer_axis_;
+  Map<Var, Range> vranges_;
+  std::string name_;
+  std::string tag_;
+  Map<String, ObjectRef> attrs_;
+};
+
+// Extract reductions as separate tensors.
+inline PrimExpr ReductionAsTensorAccess(const PrimExpr& expr, const Array<Var>& outer_axis,
+                                        const Map<Var, Range>& vranges) {
+  return ReductionAsTensorAccessMutator(outer_axis, vranges)(expr);
+}
+
+PrimExpr LiftReductions(const PrimExpr& expr, const Array<Var>& outer_axis,
+                        const Map<Var, Range>& vranges) {
+  if (const ReduceNode* red = expr.as<ReduceNode>()) {
+    Array<Var> new_outer_axis = Concat(IterVarsToVars(red->axis), outer_axis);
+    Map<Var, Range> new_vranges = Merge(vranges, IterVarsToMap(red->axis));
+    Array<PrimExpr> new_source;
+    for (const PrimExpr& src : red->source) {
+      new_source.push_back(ReductionAsTensorAccess(src, new_outer_axis, new_vranges));
+    }
+    PrimExpr new_condition = ReductionAsTensorAccess(red->condition, new_outer_axis, new_vranges);
+
+    return Reduce(red->combiner, new_source, red->axis, new_condition, red->value_index);
+  } else {
+    return ReductionAsTensorAccess(expr, outer_axis, vranges);
+  }
+}
+
+PrimExpr RemoveJacobianAndLiftNonzeroCondImpl(const PrimExpr& expr_orig, const Array<IterVar>& axis,
+                                              const Map<Var, Range>& vranges) {
+  PrimExpr result;
+  Map<Var, Range> combined_vranges = Merge(vranges, IterVarsToMap(axis));
+  arith::Analyzer analyzer;
+  analyzer.Bind(combined_vranges);
+
+  // Simplify the original expression first, mostly to simplify combiners
+  PrimExpr expr = analyzer.Simplify(expr_orig, kSimplifyRewriteCanonicalRewrite);
+
+  if (const ReduceNode* red = expr.as<ReduceNode>()) {
+    // TODO(sgrechanik-h): There are some other operations which behave like sum
+    bool is_sum = IsSumCombiner(red->combiner, vranges);
+    if (is_sum || CanFactorZeroFromCombiner(red->combiner, red->value_index, vranges)) {
+      PrimExpr new_red = expr;
+
+      // Here we simplify the reduction
+      PrimExpr cond = red->condition;
+      Array<PrimExpr> source = red->source;
+
+      // If it is a summation then we can lift nonzeroness conditions from the source
+      // and add them to the reduction conditions
+      if (is_sum) {
+        auto nz = NonzeronessCondition(red->source[red->value_index]);
+        cond = nz.cond && cond;
+        source.Set(0, nz.value);
+      }
+
+      new_red = Reduce(red->combiner, source, red->axis, cond, red->value_index);
+      new_red = SimplifyReductionDomain(new_red, combined_vranges);
+      // If the reduction disappears completely then transform the result as a non-reduction
+      if (!new_red.as<ReduceNode>()) {
+        return RemoveJacobianAndLiftNonzeroCondImpl(new_red, axis, vranges);
+      }
+
+      PrimExpr new_outer_cond, new_reduce_cond;
+      Array<PrimExpr> new_source = red->source;
+
+      // Partially lift conditions from the reduce condition
+      std::tie(new_outer_cond, new_reduce_cond) =
+          LiftConditionsThroughReduction(red->condition, red->axis, axis);
+
+      // If it's not sum then we haven't yet lifted nonzeroness cond from the source
+      if (!is_sum) {
+        PrimExpr outer_nz_cond, nz_cond, nz_source;
+        auto nz = NonzeronessCondition(red->source[red->value_index]);
+        // Append conditions from the reduction
+        nz_cond = new_reduce_cond && nz.cond;
+        nz_source = nz.value;
+        std::tie(outer_nz_cond, nz_cond) = LiftConditionsThroughReduction(nz_cond, red->axis, axis);
+        new_outer_cond = new_outer_cond && outer_nz_cond;
+        new_source.Set(red->value_index, Select(nz_cond, nz_source, make_zero(nz_source.dtype())));
+      }
+
+      PrimExpr new_reduce =
+          Reduce(red->combiner, new_source, red->axis, new_reduce_cond, red->value_index);
+      new_reduce =
+          TrySimplifyCompute(new_reduce, new_outer_cond, IterVarsToVars(axis), combined_vranges);
+      result = Select(new_outer_cond, new_reduce, make_zero(new_reduce.dtype()));
+    } else {
+      return SimplifyReductionDomain(expr, combined_vranges);
+    }
+  } else {
+    auto nz = NonzeronessCondition(expr);
+    PrimExpr new_expr =
+        TrySimplifyCompute(nz.value, nz.cond, IterVarsToVars(axis), combined_vranges);
+    result = Select(nz.cond, new_expr, make_zero(new_expr.dtype()));
+  }
+
+  // Note that RemoveRedundantInequalities can sometimes propagate equalities which
+  // other simplifiers cannot, like (i % 3) == 0.
+  Array<PrimExpr> axis_conds = IterVarsToInequalities(axis);
+  result = RemoveRedundantInequalities(result, axis_conds);
+
+  // Currently in TVM reductions are only allowed at the top level of compute,
+  // we need to extract intermediate inlined reduction as a separate stage (tensor).
+  // Sometimes TrySimplifyCompute doesn't perform lift / extraction,
+  // so there may be some non-top reductions left, take care of them.
+  result = LiftReductions(result, IterVarsToVars(axis), combined_vranges);
+  return analyzer.Simplify(result, kSimplifyRewriteCanonicalRewrite);
+}
+
+Tensor RemoveJacobianAndLiftNonzeroCond(const Tensor& tensor, const Map<Var, Range>& vranges) {
+  auto transform_func = [&vranges](const PrimExpr& expr, const Array<IterVar>& axis) {
+    return RemoveJacobianAndLiftNonzeroCondImpl(expr, axis, vranges);
+  };
+  return TransformTensorBody(tensor, transform_func);
+}
+
+}  // namespace te
+}  // namespace tvm
index 89ff96d..995c8e0 100644 (file)
 #include <tvm/tir/expr.h>
 #include <tvm/tir/stmt_functor.h>
 
+#include <set>
 #include <string>
 
+#include "../schedule/operation_inline.h"
+
 namespace tvm {
 namespace te {
 
@@ -60,5 +63,134 @@ PrimExpr CloneReduction(const PrimExpr& expr) {
   }
 }
 
+Operation ComputeOpFromExprs(const Array<PrimExpr>& exprs, const Array<IterVar>& axis,
+                             const std::string& name, const std::string& tag,
+                             const Map<String, ObjectRef>& attrs, bool clone_axis) {
+  if (clone_axis) {
+    Array<IterVar> new_axis = axis;
+    Map<Var, PrimExpr> vmap;
+    std::tie(new_axis, vmap) = CloneIterVars(axis);
+    Array<PrimExpr> new_exprs;
+    for (const PrimExpr& e : exprs) {
+      new_exprs.push_back(Substitute(CloneReduction(e), vmap));
+    }
+    return ComputeOpFromExprs(new_exprs, new_axis, name, tag, attrs, false);
+  }
+
+  Array<PrimExpr> new_exprs;
+
+  // If this is a reduction then we have to replicate it
+  if (const ReduceNode* red = exprs[0].as<ReduceNode>()) {
+    for (size_t i = 0; i < red->source.size(); ++i) {
+      PrimExpr ith_red = Reduce(red->combiner, red->source, red->axis, red->condition, i);
+      new_exprs.push_back(ith_red);
+    }
+  } else {
+    new_exprs = exprs;
+  }
+
+  return ComputeOp(name, tag, attrs, axis, new_exprs);
+}
+
+Tensor TensorFromExpr(const PrimExpr& expr, const Array<IterVar>& axis, const std::string& name,
+                      const std::string& tag, const Map<String, ObjectRef>& attrs,
+                      bool clone_axis) {
+  int new_value_index = 0;
+  if (const ReduceNode* red = expr.as<ReduceNode>()) {
+    new_value_index = red->value_index;
+  }
+  return ComputeOpFromExprs({expr}, axis, name, tag, attrs, clone_axis).output(new_value_index);
+}
+
+Tensor TransformTensorBody(
+    const Tensor& tensor,
+    const std::function<PrimExpr(const PrimExpr&, const Array<IterVar>&)>& func) {
+  if (const ComputeOpNode* op = tensor->op.as<ComputeOpNode>()) {
+    // Transform only one body
+    PrimExpr new_body = func(op->body[tensor->value_index], op->axis);
+
+    // If the body didn't change then we can return the same tensor
+    if (new_body.same_as(op->body[tensor->value_index])) {
+      return tensor;
+    }
+
+    return TensorFromExpr(new_body, op->axis, op->name, op->tag, op->attrs);
+  } else {
+    return tensor;
+  }
+}
+
+Tensor TransformTensorBody(const Tensor& tensor,
+                           const std::function<PrimExpr(const PrimExpr&)>& func) {
+  return TransformTensorBody(tensor,
+                             [func](const PrimExpr& e, const Array<IterVar>&) { return func(e); });
+}
+
+// If expr is a Tensor Access node, perform inlining, otherwise do nothing
+PrimExpr InlineImmediateTensorAccess(const PrimExpr& expr) {
+  if (const ProducerLoadNode* op = expr.as<ProducerLoadNode>()) {
+    auto tensor = Downcast<te::Tensor>(op->producer);
+    if (const ComputeOpNode* op_comp = tensor->op.as<ComputeOpNode>()) {
+      Array<Var> tensor_axes;
+      for (const auto& var : op_comp->axis) {
+        tensor_axes.push_back(var->var);
+      }
+
+      Stmt inlined =
+          Inline(Evaluate(expr), tensor->op, tensor_axes, op_comp->body[tensor->value_index]);
+      if (const EvaluateNode* ev = inlined.as<EvaluateNode>()) {
+        // If it is a reduction, clone it
+        return CloneReduction(ev->value);
+      }
+    }
+  }
+  return expr;
+}
+
+// Implements InlineTensors by trying to inline every Call of the given Expr
+class InlineTensorsMutator : public ExprMutator {
+ public:
+  explicit InlineTensorsMutator(const Array<Tensor>& inlineable, bool inline_reductions = false)
+      : inline_reductions_(inline_reductions) {
+    for (const Tensor& tensor : inlineable) {
+      inlineable_.emplace(tensor->op.operator->(), tensor->value_index);
+    }
+  }
+
+  PrimExpr VisitExpr_(const ProducerLoadNode* op) final {
+    auto tensor = Downcast<te::Tensor>(op->producer);
+    if (const ComputeOpNode* op_comp = tensor->op.as<ComputeOpNode>()) {
+      // Inline only if the array of inlineable tensors is empty or contains this tensor
+      if (inlineable_.empty() || inlineable_.count({op_comp, tensor->value_index})) {
+        // Inline only compute nodes that are not reductions (unless inline reductions is allowed)
+        if (inline_reductions_ || !op_comp->body[0].as<ReduceNode>()) {
+          PrimExpr expr = GetRef<PrimExpr>(op);
+          // Inline this tensor access and then try to perform further inlining
+          return VisitExpr(InlineImmediateTensorAccess(expr));
+        }
+      }
+    }
+    // If we cannot inline this call, we should try to do inlining in its arguments
+    return ExprMutator::VisitExpr_(op);
+  }
+
+ private:
+  // Tensors which are allowed to be inlined, represented as pairs (op_node, value_index)
+  std::set<std::pair<const OperationNode*, int>> inlineable_;
+  bool inline_reductions_;
+};
+
+Tensor InlineTensorAccess(const Tensor& tensor, const Array<Tensor>& inlineable,
+                          bool inline_reductions) {
+  auto transformation = [inlineable, inline_reductions](const PrimExpr& e) {
+    return InlineTensorsMutator(inlineable, inline_reductions)(e);
+  };
+  return TransformTensorBody(tensor, transformation);
+}
+
+Tensor InlineTailTensorAccess(const Tensor& tensor) {
+  return TransformTensorBody(tensor, InlineImmediateTensorAccess);
+}
+
 }  // namespace te
 }  // namespace tvm
index 56ab6c1..21de61c 100644 (file)
 #ifndef TVM_TE_AUTODIFF_AD_UTIL_H_
 #define TVM_TE_AUTODIFF_AD_UTIL_H_
 
+#include <tvm/arith/int_solver.h>
 #include <tvm/te/operation.h>
 #include <tvm/tir/expr.h>
 
+#include <string>
 #include <unordered_map>
 #include <utility>
 #include <vector>
@@ -48,6 +50,86 @@ std::pair<Array<IterVar>, Map<Var, PrimExpr>> CloneIterVars(const Array<IterVar>
  */
 PrimExpr CloneReduction(const PrimExpr& expr);
 
+/*!
+ * \brief Create a tensor from an expression. The expression may be a reduction, in which
+ *  case its body will be correctly duplicated if it is a multi-valued reduction.
+ *
+ * \param expr The expr which will be the tensor's body.
+ * \param axis The input variables with ranges.
+ * \param name The tensor's name.
+ * \param tag The tensor's tag.
+ * \param attrs The tensor's attrs.
+ * \param clone_axis Whether to clone the given axis and perform substitution.
+ * \return A tensor.
+ */
+Tensor TensorFromExpr(const PrimExpr& expr, const Array<IterVar>& axis,
+                      const std::string& name = "tensor", const std::string& tag = "",
+                      const Map<String, ObjectRef>& attrs = {}, bool clone_axis = true);
+
+Tensor TransformTensorBody(
+    const Tensor& tensor,
+    const std::function<PrimExpr(const PrimExpr&, const Array<IterVar>&)>& func);
+
+Tensor TransformTensorBody(const Tensor& tensor,
+                           const std::function<PrimExpr(const PrimExpr&)>& func);
+
+/*!
+ * \brief Inline tensors access recursively.
+ *
+ *  This function will inline tensors recursively until it reaches a tensor which is impossible to
+ *  inline (a reduction if \p inline_reductions is false, a non-compute tensor, a tensor which is
+ *  not from \p inlineable). It won't descend into non-inlinable tensors' bodies.
+ *
+ * \param tensor The tensor whose body to transform.
+ * \param inlineable A list of tensors which are allowed to be inlined. If empty, try
+ *  to inline all tensors.
+ * \param inline_reductions Whether to inline reductions (this may result in top-level reduction
+ *  nodes).
+ *
+ * \return An inlined tensor
+ */
+TVM_DLL Tensor InlineTensorAccess(const Tensor& tensor,
+                                  const Array<Tensor>& inlineable = Array<Tensor>(),
+                                  bool inline_reductions = false);
+
+/*!
+ * \brief Inline tensors access at the tail.
+ * \param tensor The tensor whose body to transform.
+ * \return An inlined tensor
+ */
+TVM_DLL Tensor InlineTailTensorAccess(const Tensor& tensor);
+
+/*!
+ * \brief Simplify an iteration domain.
+ *
+ *  An iteration domain is basically an array of variables and a condition. The function will do the
+ *  following:
+ *  - Replace div and mod operations with new variables (optional).
+ *  - Extract (in)equalities from the condition.
+ *  - Perform Fourier-Motzkin elimination.
+ *  - Shear the domain of iteration (e.g. if `y <= x <= y + 2` then x will be replaced with `y + d`
+ *    where `d` is a new variable such that `0 <= d <= 2`).
+ *  - Remove redundant variables.
+ *  - Infer new variable ranges (hopefully more precise).
+ *
+ * \param iter_domains The original domain.
+ * \param eliminate_div_mod Whether to eliminate div and mod by introducing new variables.
+ */
+TVM_DLL arith::IntConstraintsTransform SimplifyDomain(const arith::IntConstraints& iter_domains,
+                                                      bool eliminate_div_mod = true);
+
+/*!
+ * \brief Perform lifting of conditions of being possible to be non-zero together with
+ *  applying some transformations like simplifying the reduction domain. Works only with
+ *  this particular tensor's body, i.e. doesn't perform inlining.
+ *
+ * \param tensor The original tensor;
+ * \param vranges Optional map from free variables to their value ranges.
+ * \return An optimized tensor.
+ */
+TVM_DLL Tensor RemoveJacobianAndLiftNonzeroCond(const Tensor& tensor,
+                                                const Map<Var, Range>& vranges = Map<Var, Range>());
+
 }  // namespace te
 }  // namespace tvm
 #endif  // TVM_TE_AUTODIFF_AD_UTIL_H_
index 8b7c428..d027b39 100644 (file)
@@ -39,6 +39,8 @@
 #include <memory>
 #include <vector>
 
+#include "ad_util.h"
+
 namespace tvm {
 namespace te {
 
@@ -63,6 +65,10 @@ Tensor VectorJacobianProduct(const Tensor& output, const Tensor& input, const Te
   Tensor jac = Jacobian(output, input);
   Tensor result = topi::tensordot(head, jac, /*axes=*/output->shape.size(),
                                   output->op->name + "." + input->op->name + ".grad");
+  result = InlineTensorAccess(result, {jac}, false);
+  result = RemoveJacobianAndLiftNonzeroCond(result);
+  // inline tail call
+  result = InlineTailTensorAccess(result);
   return result;
 }
 
index e2479d8..e769e54 100644 (file)
@@ -356,7 +356,9 @@ Tensor Jacobian(const Tensor& output, const Tensor& input) {
     new_shape.push_back(e);
   }
 
-  return Tensor(new_shape, output->dtype, new_op, value_index);
+  Tensor ret = Tensor(new_shape, output->dtype, new_op, value_index);
+  ret = RemoveJacobianAndLiftNonzeroCond(ret);
+  return ret;
 }
 
 }  // namespace te
index d5b51cb..5298add 100644 (file)
@@ -134,6 +134,13 @@ class VarUseDefAnalysis : public StmtExprMutator {
     return StmtExprMutator::VisitExpr_(op);
   }
 
+  PrimExpr VisitExpr_(const ReduceNode* op) final {
+    for (const auto& iv : op->axis) {
+      this->HandleDef(iv->var.get());
+    }
+    return StmtExprMutator::VisitExpr_(op);
+  }
+
   PrimExpr VisitExpr_(const LoadNode* op) final {
     this->HandleUse(op->buffer_var);
     return StmtExprMutator::VisitExpr_(op);
@@ -187,6 +194,13 @@ Array<Var> UndefinedVars(const Stmt& stmt, const Array<Var>& args) {
   return m.undefined_;
 }
 
+Array<Var> UndefinedVars(const PrimExpr& expr) {
+  VarUseDefAnalysis m;
+  m.simplify_let_ = false;
+  m(expr);
+  return m.undefined_;
+}
+
 class HostDeviceSplitter : public StmtMutator {
  public:
   explicit HostDeviceSplitter(IRModule* device_mod, Target device_target, std::string name_prefix)
index 5415874..25accde 100644 (file)
@@ -24,7 +24,7 @@ from tvm.topi.util import get_const_tuple
 import numpy as np
 
 
-def check_grad(out, inputs, data_range=(-10, 10), desired_grads=None):
+def check_grad(out, inputs, args=[], data_range=(-10, 10), desired_grads=None, assert_no_jacobian=True):
     inputs = inputs if isinstance(inputs, list) else [inputs]
 
     def check_device(device, host="llvm"):
@@ -36,26 +36,32 @@ def check_grad(out, inputs, data_range=(-10, 10), desired_grads=None):
             return
 
         sout = te.create_schedule(out.op)
-        mout = tvm.build(sout, [out] + inputs)
+        mout = tvm.build(sout, [out] + inputs + args)
         out_shape = get_const_tuple(out.shape)
 
         l, h = data_range
         input_data = [tvm.nd.array(
             np.random.uniform(l, h, size=get_const_tuple(input.shape)).astype(input.dtype))
             for input in inputs]
+        arg_vals = [tvm.nd.array(
+            np.random.uniform(l, h, size=get_const_tuple(arg.shape)).astype(arg.dtype))
+            for arg in args]
 
         ones = topi.full_like(out, 1.0)
         # we provide head to sum and reduce the output dimension,
         # which equals to grad(out.sum(), inputs)
         grads = te.gradient(out, inputs, head=ones)
         grad_sched = te.create_schedule([grad.op for grad in grads])
-        mgrad = tvm.build(grad_sched, list(grads) + inputs)
-        # print(tvm.lower(grad_sched, list(grads) + inputs, simple_mode=True))
+        mgrad = tvm.build(grad_sched, list(grads) + inputs + args)
+        if assert_no_jacobian:
+            # TODO(yzhliu): it is better to visit the expression and do assertion
+            lowered_ir = str(tvm.lower(grad_sched, list(grads) + inputs + args, simple_mode=True))
+            assert "jacobian" not in lowered_ir, lowered_ir
 
         grad_data = [tvm.nd.empty(get_const_tuple(i.shape), g.dtype)
                      for i, g in zip(inputs, grads)]
 
-        mgrad(*grad_data, *input_data)
+        mgrad(*grad_data, *input_data, *arg_vals)
         g_res = [g.asnumpy() for g in grad_data]
 
         if desired_grads:
@@ -67,7 +73,7 @@ def check_grad(out, inputs, data_range=(-10, 10), desired_grads=None):
                 out_data = tvm.nd.empty(out_shape, out.dtype)
                 mout(out_data, *[tvm.nd.array(d) for d in list(in_data)])
                 return out_data.asnumpy().sum()
-            check_numerical_grads(forward, [d.asnumpy() for d in input_data], g_res)
+            check_numerical_grads(forward, [d.asnumpy() for d in input_data + arg_vals], g_res)
 
     check_device("cpu")
 
@@ -158,15 +164,168 @@ def test_basic_operation():
     check_grad(Y, X)
 
 
-def test_conv2d():
-    np.random.seed(0)
+def test_topi():
     X = te.placeholder((1, 2, 4, 4), name='X')
     W = te.placeholder((5, 2, 3, 3), name='W')
+    W1 = te.placeholder((2, 5, 3, 3), name='W1')
+    W2 = te.placeholder((1,), name='W2')
 
     R = topi.nn.conv2d(X, W, 1, 1, 1)
     check_grad(R, [X, W])
 
+    R1 = topi.nn.conv2d(topi.nn.relu(R), W1, 1, 0, 1)
+    check_grad(R1, [X, W, W1])
+
+    R = topi.broadcast_to(W2, (5, 2, 3, 3))
+    check_grad(R, [W2])
+
+    R = topi.nn.conv2d(X, topi.broadcast_to(W2, (5, 2, 3, 3)), 1, 1, 1)
+    check_grad(R, [X, W2])
+
+    R = topi.nn.pool(X, [2, 2], [2, 2], [0, 0, 0, 0], 'avg')
+    check_grad(R, X)
+
+    R = topi.nn.pool(X, [2, 2], [2, 2], [0, 0, 0, 0], 'max')
+    check_grad(R, X)
+
+    X = te.placeholder((1, 2, 5, 5), name='X')
+    R = topi.reshape(X, (1, 32))
+    check_grad(R, [X])
+
+    X = te.placeholder((1, 2, 5, 5), name='X')
+    W = te.placeholder((2, 2, 3, 3), name='W')
+
+    S = topi.reshape(X, (1, 50))
+    check_grad(S, [X])
+
+    R = X + topi.nn.conv2d(X + topi.nn.conv2d(X, W, 1, 1, 1), W, 1, 1, 1)
+    check_grad(R, [X, W])
+
+    S = topi.nn.softmax(topi.reshape(R, (1, 50)))
+    check_grad(S, [X, W])
+
+    S = topi.sigmoid(topi.reshape(R, (1, 50)))
+    check_grad(S, [X, W])
+
+    S = topi.tanh(topi.reshape(R, (1, 50)))
+    check_grad(S, [X, W])
+
+    S = topi.nn.log_softmax(topi.reshape(R, (1, 50)))
+    check_grad(S, [X, W])
+    check_grad(S, [W], [X])
+
+    X = te.placeholder((1, 2, 3, 5), name='X')
+    Y = te.placeholder((1, 2, 7, 5), name='Y')
+    S = topi.concatenate((X, Y), 2)
+    check_grad(S, [X, Y])
+
+    X = te.placeholder((1, 2, 6, 5), name='X')
+    (S, R) = topi.split(X, 2, 2)
+    check_grad(S, [X])
+    check_grad(R, [X])
+    R1 = topi.concatenate((S, R), 2)
+    check_grad(R1, [X])
+    R2 = topi.concatenate((R, S), 2)
+    check_grad(R2, [X])
+
+    X = te.placeholder((4, 5), name='X')
+    I = te.placeholder((100,), name='I', dtype='int32')
+    R = topi.take(X, topi.abs(I))
+    check_grad(R, [X], [I])
+
+    W = te.placeholder((5, 5), name='W')
+    exps = topi.exp(topi.nn.dense(X, W))
+    sumexps = topi.sum(exps, axis=-1, keepdims=True)
+    R = exps/sumexps
+    check_grad(R, [X, W], data_range=(-1, 1))
+
+
+def test_stride_dilation():
+    X = te.placeholder((1, 2, 10, 10), name='X')
+    W = te.placeholder((2, 2, 1, 1), name='W')
+
+    Y = topi.nn.conv2d(X, W, 1, 0, 1)
+    check_grad(Y, [X, W])
+    Y = topi.nn.conv2d(X, W, 2, 0, 1)
+    check_grad(Y, [X, W])
+    Y = topi.nn.conv2d(X, W, 3, 0, 1)
+    check_grad(Y, [X, W])
+    Y = topi.nn.conv2d(X, W, 1, 0, 2)
+    check_grad(Y, [X, W])
+    Y = topi.nn.conv2d(X, W, 2, 0, 2)
+    check_grad(Y, [X, W])
+    Y = topi.nn.conv2d(X, W, 3, 0, 2)
+    check_grad(Y, [X, W])
+    Y = topi.nn.conv2d(X, W, 1, 0, 3)
+    check_grad(Y, [X, W])
+    Y = topi.nn.conv2d(X, W, 2, 0, 3)
+    check_grad(Y, [X, W])
+    Y = topi.nn.conv2d(X, W, 3, 0, 3)
+    check_grad(Y, [X, W])
+
+    W = te.placeholder((2, 2, 2, 2), name='W')
+
+    Y = topi.nn.conv2d(X, W, 1, 0, 1)
+    check_grad(Y, [X, W])
+    Y = topi.nn.conv2d(X, W, 2, 0, 1)
+    check_grad(Y, [X, W])
+    Y = topi.nn.conv2d(X, W, 3, 0, 1)
+    check_grad(Y, [X, W])
+    Y = topi.nn.conv2d(X, W, 1, 0, 2)
+    check_grad(Y, [X, W])
+    Y = topi.nn.conv2d(X, W, 2, 0, 2)
+    check_grad(Y, [X, W])
+    Y = topi.nn.conv2d(X, W, 3, 0, 2)
+    check_grad(Y, [X, W])
+    Y = topi.nn.conv2d(X, W, 1, 0, 3)
+    check_grad(Y, [X, W])
+    Y = topi.nn.conv2d(X, W, 2, 0, 3)
+    check_grad(Y, [X, W])
+    Y = topi.nn.conv2d(X, W, 3, 0, 3)
+    check_grad(Y, [X, W])
+
+    W = te.placeholder((2, 2, 3, 3), name='W')
+
+    Y = topi.nn.conv2d(X, W, 1, 0, 1)
+    check_grad(Y, [X, W])
+    Y = topi.nn.conv2d(X, W, 2, 0, 1)
+    check_grad(Y, [X, W])
+    Y = topi.nn.conv2d(X, W, 3, 0, 1)
+    check_grad(Y, [X, W])
+    Y = topi.nn.conv2d(X, W, 1, 0, 2)
+    check_grad(Y, [X, W])
+    Y = topi.nn.conv2d(X, W, 2, 0, 2)
+    check_grad(Y, [X, W])
+    Y = topi.nn.conv2d(X, W, 3, 0, 2)
+    check_grad(Y, [X, W])
+    Y = topi.nn.conv2d(X, W, 1, 0, 3)
+    check_grad(Y, [X, W])
+    Y = topi.nn.conv2d(X, W, 2, 0, 3)
+    check_grad(Y, [X, W])
+    Y = topi.nn.conv2d(X, W, 3, 0, 3)
+    check_grad(Y, [X, W])
+
+    Y = topi.nn.pool(X, [1, 1], [1, 1], [0, 0, 0, 0], 'max')
+    check_grad(Y, [X])
+    Y = topi.nn.pool(X, [1, 1], [2, 2], [0, 0, 0, 0], 'max')
+    check_grad(Y, [X])
+    Y = topi.nn.pool(X, [1, 1], [3, 3], [0, 0, 0, 0], 'max')
+    check_grad(Y, [X])
+    Y = topi.nn.pool(X, [2, 2], [1, 1], [0, 0, 0, 0], 'max')
+    check_grad(Y, [X])
+    Y = topi.nn.pool(X, [2, 2], [2, 2], [0, 0, 0, 0], 'max')
+    check_grad(Y, [X])
+    Y = topi.nn.pool(X, [2, 2], [3, 3], [0, 0, 0, 0], 'max')
+    check_grad(Y, [X])
+    Y = topi.nn.pool(X, [3, 3], [1, 1], [0, 0, 0, 0], 'max')
+    check_grad(Y, [X])
+    Y = topi.nn.pool(X, [3, 3], [2, 2], [0, 0, 0, 0], 'max')
+    check_grad(Y, [X])
+    Y = topi.nn.pool(X, [3, 3], [3, 3], [0, 0, 0, 0], 'max')
+    check_grad(Y, [X])
+
 
 if __name__ == "__main__":
     test_basic_operation()
-    test_conv2d()
+    test_topi()
+    test_stride_dilation()