[RELAY][Fix] i64 indices (#5235)
authorHaozheng Fan <fanhaozh@amazon.com>
Thu, 23 Jul 2020 21:04:10 +0000 (05:04 +0800)
committerGitHub <noreply@github.com>
Thu, 23 Jul 2020 21:04:10 +0000 (14:04 -0700)
* fix

* resolve comments

src/te/schedule/operation_inline.cc
tests/python/relay/test_pass_fuse_ops.py

index aab30ed..7399af9 100644 (file)
@@ -63,7 +63,8 @@ class OperationInliner final : public StmtExprMutator {
       } else {
         Map<Var, PrimExpr> vmap;
         for (size_t i = 0; i < args_.size(); ++i) {
-          vmap.Set(args_[i], op->indices[i]);
+          // cast indices to the type of the original indexing variable
+          vmap.Set(args_[i], cast(args_[i].dtype(), op->indices[i]));
         }
         expr = Substitute(Evaluate(expr), vmap).as<EvaluateNode>()->value;
       }
index 6b7d297..f4369c1 100644 (file)
@@ -621,6 +621,79 @@ def test_fuse_max():
     after = run_opt_pass(expected(), transform.InferType())
     assert tvm.ir.structural_equal(zz, after)
 
+
+def test_fuse_take():
+    """Test fusion case involving concat and take"""
+
+    def before():
+        shape = (tvm.tir.const(10, "int64"),
+                 tvm.tir.const(1, "int64"))
+        x = relay.var("x", shape=shape)
+        concat = relay.concatenate([x,x], axis=-1)
+        out = relay.op.take(concat, indices=relay.const([0], dtype="int64"))
+        return relay.Function(relay.analysis.free_vars(out), out)
+
+    def expected():
+        shape1 = (tvm.tir.const(10, "int64"),
+                  tvm.tir.const(1, "int64"))
+        shape2 = (tvm.tir.const(1, "int64"),)
+        x = relay.var("x", shape=shape1)
+        p0 = relay.var("p0", shape=shape1)
+        p1 = relay.var("p1", shape=shape2,
+                             dtype="int64")
+        c = relay.const([0], dtype="int64")
+        concat = relay.concatenate([p0,p0], axis=-1)
+        out = relay.op.take(concat, indices=p1)
+
+        f0 = relay.Function([p0, p1], out)
+        f0 = f0.with_attr("Primitive", tvm.tir.IntImm("int32", 1))
+
+        y = relay.Call(f0, [x, c])
+        return relay.Function([x], y)
+
+    orig = before()
+    m = fuse2(tvm.IRModule.from_expr(orig))
+    relay.build(m, 'llvm')
+    after = run_opt_pass(expected(), transform.InferType())
+    assert tvm.ir.structural_equal(m["main"], after)
+
+
+def test_fuse_gather_nd():
+    """Test fusion case involving concat and gather_nd"""
+
+    def before():
+        shape = (tvm.tir.const(10, "int64"),
+                 tvm.tir.const(1, "int64"))
+        x = relay.var("x", shape=shape)
+        concat = relay.concatenate([x,x], axis=-1)
+        out = relay.gather_nd(concat, indices=relay.expr.const([[0,1],[1,0]], dtype="int64"))
+        return relay.Function(relay.analysis.free_vars(out), out)
+
+    def expected():
+        shape1 = (tvm.tir.const(10, "int64"),
+                  tvm.tir.const(1, "int64"))
+        shape2 = (tvm.tir.const(2, "int64"),
+                  tvm.tir.const(2, "int64"))
+        x = relay.var("x", shape=shape1)
+        p0 = relay.var("p0", shape=shape1)
+        p1 = relay.var("p1", shape=shape2, dtype="int64")
+        c = relay.const([[0,1],[1,0]], dtype="int64")
+        concat = relay.concatenate([p0,p0], axis=-1)
+        out = relay.gather_nd(concat, indices=p1)
+
+        f0 = relay.Function([p0, p1], out)
+        f0 = f0.with_attr("Primitive", tvm.tir.IntImm("int32", 1))
+
+        y = relay.Call(f0, [x, c])
+        return relay.Function([x], y)
+
+    orig = before()
+    m = fuse2(tvm.IRModule.from_expr(orig))
+    relay.build(m, 'llvm')
+    after = run_opt_pass(expected(), transform.InferType())
+    assert tvm.ir.structural_equal(m["main"], after)
+
+
 if __name__ == "__main__":
     test_fuse_simple()
     test_conv2d_fuse()
@@ -637,3 +710,5 @@ if __name__ == "__main__":
     test_immutable()
     test_split()
     test_fuse_max()
+    test_fuse_take()
+    test_fuse_gather_nd()