[Relay] Keep fixed dim when unifying dynamic shape (#5795)
author: lixiaoquan <radioheads@163.com>
Fri, 24 Jul 2020 14:42:13 +0000 (22:42 +0800)
committer: GitHub <noreply@github.com>
Fri, 24 Jul 2020 14:42:13 +0000 (07:42 -0700)
src/relay/analysis/type_solver.cc
tests/python/relay/test_any.py
tests/python/relay/test_type_infer.py

index a192002..a674265 100644 (file)
@@ -175,6 +175,17 @@ class TypeSolver::Unifier : public TypeFunctor<Type(const Type&, const Type&)> {
     if (ulhs.same_as(urhs)) {
       return ulhs;
     }
+
+    if (ulhs.as<AnyNode>() && urhs.as<tvm::IntImmNode>()) {
+      solver_->shape_uf_.Set(urhs, ulhs);
+      return urhs;
+    }
+
+    if (ulhs.as<tvm::IntImmNode>() && urhs.as<AnyNode>()) {
+      solver_->shape_uf_.Set(ulhs, urhs);
+      return ulhs;
+    }
+
     if (ulhs.as<AnyNode>() || urhs.as<AnyNode>()) {
       return Any();
     }
index 6810d0b..bf28ee1 100644 (file)
@@ -721,9 +721,7 @@ def test_recursive_concat():
     mod["main"] = func
     data = np.array(0.0, dtype='int32')
     ref = np.array([0] + list(range(10))).reshape((11, 1)).astype("int32")
-    # TODO(@jroesch): After LambdaLift pass, TypeInfer pass will fail
-    # so currently we cannot run this test case on VM
-    for kind in ["debug"]:
+    for kind in ["debug", "vm"]:
         ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm")
         result = ex.evaluate()(data)
         np.testing.assert_allclose(result.asnumpy(), ref)
index 4591618..e5082db 100644 (file)
@@ -21,6 +21,7 @@ import tvm
 from tvm import te
 from tvm import relay
 from tvm.relay import op, transform, analysis
+from tvm.relay import Any
 
 
 def run_infer_type(expr, mod=None):
@@ -362,6 +363,15 @@ def test_let_polymorphism():
     tvm.ir.assert_structural_equal(body.checked_type, relay.TupleType([int32, relay.TupleType([])]))
 
 
+def test_if():
+    choice_t = relay.FuncType([], relay.scalar_type('bool'))
+    f = relay.Var('f', choice_t)
+    true_branch = relay.Var('True', relay.TensorType([Any(), 1], dtype='float32'))
+    false_branch = relay.Var('False', relay.TensorType([Any(), Any()], dtype='float32'))
+    top = relay.Function([true_branch, false_branch], relay.If(f(), true_branch, false_branch))
+    ft = run_infer_type(top)
+    tvm.ir.assert_structural_equal(ft.ret_type, relay.TensorType([Any(), 1], dtype='float32'))
+
 if __name__ == "__main__":
     test_free_expr()
     test_dual_op()
@@ -380,3 +390,4 @@ if __name__ == "__main__":
     test_constructor_call()
     test_adt_match()
     test_let_polymorphism()
+    test_if()