From: Andrew Liu Date: Thu, 3 Sep 2020 17:31:19 +0000 (-0700) Subject: [Relay] Fix Type Arguments not Attached (#6385) X-Git-Tag: upstream/0.7.0~164 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=17d39fbd2a99c4308bd089f10eb15699659c5052;p=platform%2Fupstream%2Ftvm.git [Relay] Fix Type Arguments not Attached (#6385) --- diff --git a/src/relay/transforms/type_infer.cc b/src/relay/transforms/type_infer.cc index 7182f0e..e110737 100644 --- a/src/relay/transforms/type_infer.cc +++ b/src/relay/transforms/type_infer.cc @@ -369,14 +369,12 @@ class TypeInferencer : private ExprFunctor, // Build a subsitituion map up from the function type and type arguments. // Eventually allow the type vars to be passed in. + CHECK(fn_ty->type_params.size() == ty_args.size()) + << "number of type parameters does not match expected"; for (size_t i = 0; i < ty_args.size(); ++i) { subst_map.Set(fn_ty->type_params[i], ty_args[i]); } - for (size_t i = ty_args.size(); i < fn_ty->type_params.size(); ++i) { - subst_map.Set(fn_ty->type_params[i], IncompleteType(Kind::kType)); - } - Type ret_type = fn_ty->ret_type; // If the function type is incomplete, place a new IncompleteType @@ -445,6 +443,9 @@ class TypeInferencer : private ExprFunctor, << "Expected " << fn_ty_node->type_params.size() << "but got " << type_args.size()); } + for (size_t i = type_args.size(); i < fn_ty_node->type_params.size(); i++) { + type_args.push_back(IncompleteType(TypeKind::kType)); + } FuncType fn_ty = InstantiateFuncType(fn_ty_node, type_args); diff --git a/tests/python/relay/test_type_infer.py b/tests/python/relay/test_type_infer.py index cc4748c..70e0c3f 100644 --- a/tests/python/relay/test_type_infer.py +++ b/tests/python/relay/test_type_infer.py @@ -372,5 +372,19 @@ def test_if(): ft = run_infer_type(top) tvm.ir.assert_structural_equal(ft.ret_type, relay.TensorType([Any(), 1], dtype='float32')) +def test_type_arg_infer(): + code = """ +#[version = "0.0.5"] +def @id[A](%x: A) -> A { + %x +} +def @main(%f: float32) -> float32 { + @id(%f) +} +""" + mod = tvm.parser.fromtext(code) + mod = transform.InferType()(mod) + tvm.ir.assert_structural_equal(mod['main'].body.type_args, [relay.TensorType((), 'float32')]) + if __name__ == "__main__": pytest.main([__file__])