Consistent result of DetectLinearEquation() when an empy vars is passed (#2860)
authorYutetsu TAKATSUKASA <y.takatsukasa@gmail.com>
Sat, 30 Mar 2019 18:13:54 +0000 (03:13 +0900)
committerTianqi Chen <tqchen@users.noreply.github.com>
Sat, 30 Mar 2019 18:13:54 +0000 (14:13 -0400)
src/arithmetic/detect_linear_equation.cc
src/pass/inject_copy_intrin.cc
tests/python/unittest/test_arith_detect_linear_equation.py

index 6f4d3cf..e7bc7e7 100644 (file)
@@ -127,25 +127,21 @@ Array<Expr> DetectLinearEquation(const Expr& e, const Array<Var>& vars) {
   Expr base = e;
   Array<Expr> coeff;
 
-  if (0 == vars.size()) {
-    coeff.push_back(make_const(Int(32), 1));
-  } else {
-    for (Var v : vars) {
-      LinearEqEntry ret;
-      if (!LinearEqDetector(v).Detect(base, &ret)) {
-        return Array<Expr>();
-      }
-      coeff.push_back(ret.coeff);
-      base = std::move(ret.base);
+  for (Var v : vars) {
+    LinearEqEntry ret;
+    if (!LinearEqDetector(v).Detect(base, &ret)) {
+      return Array<Expr>();
     }
+    coeff.push_back(ret.coeff);
+    base = std::move(ret.base);
+  }
 
-    std::unordered_set<const Variable*> vset;
-    for (size_t i = vars.size(); i != 1; --i) {
-      vset.insert(vars[i - 1].get());
-      // The previous coeff contains the variable
-      if (ExprUseVar(coeff[i - 2], vset)) {
-        return Array<Expr>();
-      }
+  std::unordered_set<const Variable*> vset;
+  for (size_t i = vars.size(); i > 1; --i) {
+    vset.insert(vars[i - 1].get());
+    // The previous coeff contains the variable
+    if (ExprUseVar(coeff[i - 2], vset)) {
+      return Array<Expr>();
     }
   }
   coeff.push_back(base);
index 7ca1d13..7dcfcfd 100644 (file)
@@ -39,7 +39,6 @@ class CopyIntrinInjector : public IRMutator {
   bool MatchCopyPattern(Stmt stmt, Stmt *out) {
     using namespace arith;
     Stmt body = stmt;
-    bool is_single_point_copy = false;
 
     // strip the loops
     std::vector<const For*> loops;
@@ -60,7 +59,6 @@ class CopyIntrinInjector : public IRMutator {
     const Cast* cast = store->value.as<Cast>();
     const Load* load = store->value.as<Load>();
     if (0 == loops.size()) {
-      is_single_point_copy = true;
       CHECK(!has_cond);
     }
     // for now only support true condition matching
@@ -83,9 +81,8 @@ class CopyIntrinInjector : public IRMutator {
         arith::DetectLinearEquation(load->index, loop_vars);
     if (load_strides.size()  == 0 || store_strides.size() == 0) return false;
     Array<Expr> dst_shape;
-    auto loop_var_size = loop_vars.size();
-    if (is_single_point_copy) {
-      loop_var_size = 1;
+    const size_t loop_var_size = loop_vars.size();
+    if (loop_var_size == 0) {
       dst_shape.push_back(make_const(Int(32), 1));
     } else {
       for (const For* op : loops) {
@@ -132,6 +129,10 @@ class CopyIntrinInjector : public IRMutator {
     CHECK_EQ(load_strides.size(), loop_var_size + 1);
     Array<Expr> src_strides(load_strides.begin(), load_strides.begin() + loop_var_size);
     Array<Expr> dst_strides(store_strides.begin(), store_strides.begin() + loop_var_size);
+    if (loop_var_size == 0) {
+        src_strides.push_back(make_const(Int(32), 1));
+        dst_strides.push_back(make_const(Int(32), 1));
+    }
     Buffer dst = BufferNode::make(
         Var(store->buffer_var.node_),
         store->value.type(),
index 2b0f327..33e2666 100644 (file)
@@ -20,6 +20,10 @@ def test_basic():
     m = tvm.arith.DetectLinearEquation(b * 7, [a])
     assert m[0].value == 0
 
+    m = tvm.arith.DetectLinearEquation(b * 7, [])
+    assert len(m) == 1
+    assert tvm.ir_pass.Simplify(m[0] - b * 7).value == 0
+
 def test_multivariate():
     v = [tvm.var("v%d" % i) for i in range(4)]
     b = tvm.var("b")
@@ -42,6 +46,10 @@ def test_multivariate():
     assert(m[0].value == 0)
     assert(tvm.ir_pass.Simplify(m[1] - (v[0] - v[1])).value == 0)
 
+    m = tvm.arith.DetectLinearEquation((v[0] - v[1]), [])
+    assert(len(m) == 1)
+    assert(tvm.ir_pass.Simplify(m[0] - (v[0] - v[1])).value == 0)
+
 if __name__ == "__main__":
     test_basic()
     test_multivariate()