f = Var("f", FuncType([a], a))
x = Var("x", self.nat())
y = Var("y", self.nat())
- z = Var("z")
- z_case = Clause(PatternConstructor(self.z), Function([z], z))
- # todo: fix typechecker so Function([z], z) can be replaced by self.id
+ z_case = Clause(PatternConstructor(self.z), self.id)
s_case = Clause(PatternConstructor(self.s, [PatternVar(y)]),
self.compose(f, self.iterate(f, y)))
self.mod[self.iterate] = Function([f, x],
Type Unify(const Type& t1, const Type& t2, const NodeRef& expr) {
// TODO(tqchen, jroesch): propagate span to solver
try {
- return solver_.Unify(t1, t2, expr);
+ // instantiate higher-order func types when unifying because
+ // we only allow polymorphism at the top level
+ Type first = t1;
+ Type second = t2;
+ if (auto* ft1 = t1.as<FuncTypeNode>()) {
+ first = InstantiateFuncType(ft1);
+ }
+ if (auto* ft2 = t2.as<FuncTypeNode>()) {
+ second = InstantiateFuncType(ft2);
+ }
+ return solver_.Unify(first, second, expr);
} catch (const dmlc::Error &e) {
this->ReportFatalError(
expr,
return Downcast<FuncType>(inst_ty);
}
+ // instantiates starting from incompletes
+ FuncType InstantiateFuncType(const FuncTypeNode* fn_ty) {
+ if (fn_ty->type_params.size() == 0) {
+ return GetRef<FuncType>(fn_ty);
+ }
+
+ Array<Type> type_args;
+ for (size_t i = 0; i < fn_ty->type_params.size(); i++) {
+ type_args.push_back(IncompleteTypeNode::make(Kind::kType));
+ }
+ return InstantiateFuncType(fn_ty, type_args);
+ }
+
+
void AddTypeArgs(const Expr& expr, Array<Type> type_args) {
auto type_info = type_map_.find(expr);
if (type_info == type_map_.end()) {
arg_types.push_back(GetType(param));
}
Type rtype = GetType(f->body);
+ if (auto* ft = rtype.as<FuncTypeNode>()) {
+ rtype = InstantiateFuncType(ft);
+ }
if (f->ret_type.defined()) {
rtype = this->Unify(f->ret_type, rtype, GetRef<Function>(f));
}
assert ft.checked_type == relay.FuncType([tt, f_type], tt)
+def test_higher_order_argument():
+ a = relay.TypeVar('a')
+ x = relay.Var('x', a)
+ id_func = relay.Function([x], x, a, [a])
+
+ b = relay.TypeVar('b')
+ f = relay.Var('f', relay.FuncType([b], b))
+ y = relay.Var('y', b)
+ ho_func = relay.Function([f, y], f(y), b, [b])
+
+ # id func should be an acceptable argument to the higher-order
+ # function even though id_func takes a type parameter
+ ho_call = ho_func(id_func, relay.const(0, 'int32'))
+
+ hc = relay.ir_pass.infer_type(ho_call)
+ expected = relay.scalar_type('int32')
+ assert hc.checked_type == expected
+
+
+def test_higher_order_return():
+ a = relay.TypeVar('a')
+ x = relay.Var('x', a)
+ id_func = relay.Function([x], x, a, [a])
+
+ b = relay.TypeVar('b')
+ nested_id = relay.Function([], id_func, relay.FuncType([b], b), [b])
+
+ ft = relay.ir_pass.infer_type(nested_id)
+ assert ft.checked_type == relay.FuncType([], relay.FuncType([b], b), [b])
+
+
+def test_higher_order_nested():
+ a = relay.TypeVar('a')
+ x = relay.Var('x', a)
+ id_func = relay.Function([x], x, a, [a])
+
+ choice_t = relay.FuncType([], relay.scalar_type('bool'))
+ f = relay.Var('f', choice_t)
+
+ b = relay.TypeVar('b')
+ z = relay.Var('z')
+ top = relay.Function(
+ [f],
+ relay.If(f(), id_func, relay.Function([z], z)),
+ relay.FuncType([b], b),
+ [b])
+
+ expected = relay.FuncType([choice_t], relay.FuncType([b], b), [b])
+ ft = relay.ir_pass.infer_type(top)
+ assert ft.checked_type == expected
+
+
def test_tuple():
tp = relay.TensorType((10,))
x = relay.var("x", tp)