// Build a subsitituion map up from the function type and type arguments.
// Eventually allow the type vars to be passed in.
+ CHECK(fn_ty->type_params.size() == ty_args.size())
+ << "number of type parameters does not match expected";
for (size_t i = 0; i < ty_args.size(); ++i) {
subst_map.Set(fn_ty->type_params[i], ty_args[i]);
}
- for (size_t i = ty_args.size(); i < fn_ty->type_params.size(); ++i) {
- subst_map.Set(fn_ty->type_params[i], IncompleteType(Kind::kType));
- }
-
Type ret_type = fn_ty->ret_type;
// If the function type is incomplete, place a new IncompleteType
<< "Expected " << fn_ty_node->type_params.size() << "but got "
<< type_args.size());
}
+ for (size_t i = type_args.size(); i < fn_ty_node->type_params.size(); i++) {
+ type_args.push_back(IncompleteType(TypeKind::kType));
+ }
FuncType fn_ty = InstantiateFuncType(fn_ty_node, type_args);
ft = run_infer_type(top)
tvm.ir.assert_structural_equal(ft.ret_type, relay.TensorType([Any(), 1], dtype='float32'))
+def test_type_arg_infer():
+ code = """
+#[version = "0.0.5"]
+def @id[A](%x: A) -> A {
+ %x
+}
+def @main(%f: float32) -> float32 {
+ @id(%f)
+}
+"""
+ mod = tvm.parser.fromtext(code)
+ mod = transform.InferType()(mod)
+ tvm.ir.assert_structural_equal(mod['main'].body.type_args, [relay.TensorType((), 'float32')])
+
if __name__ == "__main__":
pytest.main([__file__])