[Relay] Ensure nested higher-order functions are treated correctly (#2676)
authorSteven S. Lyubomirsky <slyubomirsky@gmail.com>
Wed, 27 Feb 2019 04:38:50 +0000 (20:38 -0800)
committerTianqi Chen <tqchen@users.noreply.github.com>
Wed, 27 Feb 2019 04:38:50 +0000 (20:38 -0800)
python/tvm/relay/prelude.py
src/relay/pass/type_infer.cc
tests/python/relay/test_type_infer.py

index 034b58ef1c7e7e0010e59a68698494a9e393e10f..41d1be284f8ef3eeb47b58fd56faa4d6873527bf 100644 (file)
@@ -394,9 +394,7 @@ class Prelude:
         f = Var("f", FuncType([a], a))
         x = Var("x", self.nat())
         y = Var("y", self.nat())
-        z = Var("z")
-        z_case = Clause(PatternConstructor(self.z), Function([z], z))
-        # todo: fix typechecker so Function([z], z) can be replaced by self.id
+        z_case = Clause(PatternConstructor(self.z), self.id)
         s_case = Clause(PatternConstructor(self.s, [PatternVar(y)]),
                         self.compose(f, self.iterate(f, y)))
         self.mod[self.iterate] = Function([f, x],
index b6bdedc044739dbd74173c6b21d1e6f0d8bb21f2..8dd02f39adcee986792af9eb9b4e720dfa1b5bc5 100644 (file)
@@ -121,7 +121,17 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)>,
   Type Unify(const Type& t1, const Type& t2, const NodeRef& expr) {
     // TODO(tqchen, jroesch): propagate span to solver
     try {
-      return solver_.Unify(t1, t2, expr);
+      // instantiate higher-order func types when unifying because
+      // we only allow polymorphism at the top level
+      Type first = t1;
+      Type second = t2;
+      if (auto* ft1 = t1.as<FuncTypeNode>()) {
+        first = InstantiateFuncType(ft1);
+      }
+      if (auto* ft2 = t2.as<FuncTypeNode>()) {
+        second = InstantiateFuncType(ft2);
+      }
+      return solver_.Unify(first, second, expr);
     } catch (const dmlc::Error &e) {
       this->ReportFatalError(
         expr,
@@ -351,6 +361,20 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)>,
     return Downcast<FuncType>(inst_ty);
   }
 
+  // instantiates starting from incompletes
+  FuncType InstantiateFuncType(const FuncTypeNode* fn_ty) {
+    if (fn_ty->type_params.size() == 0) {
+      return GetRef<FuncType>(fn_ty);
+    }
+
+    Array<Type> type_args;
+    for (size_t i = 0; i < fn_ty->type_params.size(); i++) {
+      type_args.push_back(IncompleteTypeNode::make(Kind::kType));
+    }
+    return InstantiateFuncType(fn_ty, type_args);
+  }
+
+
   void AddTypeArgs(const Expr& expr, Array<Type> type_args) {
     auto type_info = type_map_.find(expr);
     if (type_info == type_map_.end()) {
@@ -464,6 +488,9 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)>,
       arg_types.push_back(GetType(param));
     }
     Type rtype = GetType(f->body);
+    if (auto* ft = rtype.as<FuncTypeNode>()) {
+      rtype = InstantiateFuncType(ft);
+    }
     if (f->ret_type.defined()) {
       rtype = this->Unify(f->ret_type, rtype, GetRef<Function>(f));
     }
index 05f8b8fd22f98cae3cb086ca15c45cadcd53a5a8..8c8e7dfd1fcc0d4d49d99e7c9ee0c76900565c1a 100644 (file)
@@ -133,6 +133,58 @@ def test_incomplete_call():
     assert ft.checked_type == relay.FuncType([tt, f_type], tt)
 
 
+def test_higher_order_argument():
+    a = relay.TypeVar('a')
+    x = relay.Var('x', a)
+    id_func = relay.Function([x], x, a, [a])
+
+    b = relay.TypeVar('b')
+    f = relay.Var('f', relay.FuncType([b], b))
+    y = relay.Var('y', b)
+    ho_func = relay.Function([f, y], f(y), b, [b])
+
+    # id func should be an acceptable argument to the higher-order
+    # function even though id_func takes a type parameter
+    ho_call = ho_func(id_func, relay.const(0, 'int32'))
+
+    hc = relay.ir_pass.infer_type(ho_call)
+    expected = relay.scalar_type('int32')
+    assert hc.checked_type == expected
+
+
+def test_higher_order_return():
+    a = relay.TypeVar('a')
+    x = relay.Var('x', a)
+    id_func = relay.Function([x], x, a, [a])
+
+    b = relay.TypeVar('b')
+    nested_id = relay.Function([], id_func, relay.FuncType([b], b), [b])
+
+    ft = relay.ir_pass.infer_type(nested_id)
+    assert ft.checked_type == relay.FuncType([], relay.FuncType([b], b), [b])
+
+
+def test_higher_order_nested():
+    a = relay.TypeVar('a')
+    x = relay.Var('x', a)
+    id_func = relay.Function([x], x, a, [a])
+
+    choice_t = relay.FuncType([], relay.scalar_type('bool'))
+    f = relay.Var('f', choice_t)
+
+    b = relay.TypeVar('b')
+    z = relay.Var('z')
+    top = relay.Function(
+        [f],
+        relay.If(f(), id_func, relay.Function([z], z)),
+        relay.FuncType([b], b),
+        [b])
+
+    expected = relay.FuncType([choice_t], relay.FuncType([b], b), [b])
+    ft = relay.ir_pass.infer_type(top)
+    assert ft.checked_type == expected
+
+
 def test_tuple():
     tp = relay.TensorType((10,))
     x = relay.var("x", tp)