[Relay] Fix Type Arguments not Attached (#6385)
authorAndrew Liu <andrewlliu@gmail.com>
Thu, 3 Sep 2020 17:31:19 +0000 (10:31 -0700)
committerGitHub <noreply@github.com>
Thu, 3 Sep 2020 17:31:19 +0000 (10:31 -0700)
src/relay/transforms/type_infer.cc
tests/python/relay/test_type_infer.py

index 7182f0e..e110737 100644 (file)
@@ -369,14 +369,12 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)>,
 
     // Build a subsitituion map up from the function type and type arguments.
     // Eventually allow the type vars to be passed in.
+    CHECK(fn_ty->type_params.size() == ty_args.size())
+        << "number of type parameters does not match expected";
     for (size_t i = 0; i < ty_args.size(); ++i) {
       subst_map.Set(fn_ty->type_params[i], ty_args[i]);
     }
 
-    for (size_t i = ty_args.size(); i < fn_ty->type_params.size(); ++i) {
-      subst_map.Set(fn_ty->type_params[i], IncompleteType(Kind::kType));
-    }
-
     Type ret_type = fn_ty->ret_type;
 
     // If the function type is incomplete, place a new IncompleteType
@@ -445,6 +443,9 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)>,
                                  << "Expected " << fn_ty_node->type_params.size() << "but got "
                                  << type_args.size());
     }
+    for (size_t i = type_args.size(); i < fn_ty_node->type_params.size(); i++) {
+      type_args.push_back(IncompleteType(TypeKind::kType));
+    }
 
     FuncType fn_ty = InstantiateFuncType(fn_ty_node, type_args);
 
index cc4748c..70e0c3f 100644 (file)
@@ -372,5 +372,19 @@ def test_if():
     ft = run_infer_type(top)
     tvm.ir.assert_structural_equal(ft.ret_type, relay.TensorType([Any(), 1], dtype='float32'))
 
+def test_type_arg_infer():
+  code = """
+#[version = "0.0.5"]
+def @id[A](%x: A) -> A {
+  %x
+}
+def @main(%f: float32) -> float32 {
+  @id(%f)
+}
+"""
+  mod = tvm.parser.fromtext(code)
+  mod = transform.InferType()(mod)
+  tvm.ir.assert_structural_equal(mod['main'].body.type_args, [relay.TensorType((), 'float32')])
+
 if __name__ == "__main__":
     pytest.main([__file__])