[Relay] fix exponential blowup in interpreter (#3559)
author    雾雨魔理沙 <lolisa@marisa.moe>
Wed, 11 Sep 2019 03:30:46 +0000 (20:30 -0700)
committer Wuwei Lin <wuwei@apache.org>
Wed, 11 Sep 2019 03:30:46 +0000 (23:30 -0400)
include/tvm/relay/feature.h
python/tvm/relay/backend/interpreter.py
src/relay/backend/interpreter.cc
src/relay/ir/alpha_equal.cc
src/relay/pass/feature.cc
src/relay/pass/type_infer.cc
src/relay/pass/type_solver.cc
tests/python/relay/test_feature.py
tests/python/relay/test_pass_to_cps.py
tests/python/relay/test_type_infer.py
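
The gist of the fix: the interpreter evaluated expression trees, so a subexpression shared in the DAG was re-evaluated once per reference. The patch runs ToANormalForm in the interpreter's optimize pipeline, let-binding every intermediate node so it is evaluated exactly once, and guards the C++ entry point with a feature check that rejects graph-form input. A minimal repro sketch of the blowup, assuming the create_executor debug backend and numpy argument conversion of this TVM version (none of these lines are part of the patch):

    import numpy as np
    import tvm
    from tvm import relay

    x = relay.var("x", shape=(), dtype="float32")
    body = x
    for _ in range(30):
        body = body + body        # each level references `body` twice: a DAG, not a tree
    f = relay.Function([x], body)

    # Before this patch the tree reading of this 30-level graph cost
    # ~2^30 evaluations; with ToANormalForm it is linear in the node count.
    intrp = relay.create_executor("debug", ctx=tvm.cpu(), target="llvm")
    print(intrp.evaluate(f)(np.array(1.0, dtype="float32")))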

include/tvm/relay/feature.h
index a8b60e7..d7b3b39 100644
@@ -81,13 +81,13 @@ class FeatureSet {
     return ret;
   }
   /*! \brief A set that contains all the Features. */
-  static FeatureSet AllFeature() {
+  static FeatureSet All() {
     FeatureSet fs;
     fs.bs_.flip();
     return fs;
   }
   /*! \brief The empty set. Contains no Feature. */
-  static FeatureSet NoFeature() {
+  static FeatureSet No() {
     FeatureSet fs;
     return fs;
   }
python/tvm/relay/backend/interpreter.py
index cf7516c..ae60b7a 100644
@@ -280,6 +280,7 @@ class Interpreter(Executor):
         """
         seq = transform.Sequential([transform.SimplifyInference(),
                                     transform.FuseOps(0),
+                                    transform.ToANormalForm(),
                                     transform.InferType()])
         return seq(self.mod)
 
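ToANormalForm runs before InferType so every shared node is let-bound exactly once. Roughly (printed form approximate for this version), the shared graph

    add(add(%x, %x), add(%x, %x))   # one add node, referenced twice

becomes

    let %v0 = add(%x, %x);
    let %v1 = add(%v0, %v0);
    %v1
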
src/relay/backend/interpreter.cc
index e77d6a8..86a4ebb 100644
@@ -29,6 +29,7 @@
 #include <tvm/relay/interpreter.h>
 #include <tvm/relay/analysis.h>
 #include <tvm/relay/attrs/debug.h>
+#include <tvm/relay/feature.h>
 #include "compile_engine.h"
 
 namespace tvm {
@@ -761,6 +762,8 @@ CreateInterpreter(
     Target target) {
   auto intrp = std::make_shared<Interpreter>(mod, context, target);
   auto packed = [intrp](Expr expr) {
+    auto f = DetectFeature(expr);
+    CHECK(f.is_subset_of(FeatureSet::All() - fGraph));
     return intrp->Eval(expr);
   };
   return TypedPackedFunc<Value(Expr)>(packed);
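
The C++ entry point now refuses graph-form input up front: DetectFeature(expr) must be a subset of FeatureSet::All() - fGraph, i.e. no non-atomic node may be referenced more than once. Callers that build DAGs directly must run ToANormalForm first, as the Python Interpreter.optimize pipeline above now does.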
src/relay/ir/alpha_equal.cc
index f83e588..878795d 100644
@@ -120,7 +120,7 @@ class AlphaEqualHandler:
    * \return the comparison result.
    */
   bool TypeEqual(const Type& lhs, const Type& rhs) {
-    auto compute = [&](){
+    auto compute = [&]() {
       if (lhs.same_as(rhs)) return true;
       if (!lhs.defined() || !rhs.defined()) return false;
       return this->VisitType(lhs, rhs);
src/relay/pass/feature.cc
index df3a5d7..2c5e7ab 100644
@@ -34,13 +34,15 @@ namespace relay {
 
 FeatureSet DetectFeature(const Expr& expr) {
   if (!expr.defined()) {
-    return FeatureSet::NoFeature();
+    return FeatureSet::No();
   }
   struct FeatureDetector : ExprVisitor {
     std::unordered_set<Expr, NodeHash, NodeEqual> visited_;
-    FeatureSet fs = FeatureSet::NoFeature();
+    FeatureSet fs = FeatureSet::No();
+
     void VisitExpr(const Expr& expr) final {
       if (visited_.count(expr) == 0) {
+        visited_.insert(expr);
         ExprVisitor::VisitExpr(expr);
       } else {
         if (!IsAtomic(expr)) {
@@ -52,15 +54,20 @@ FeatureSet DetectFeature(const Expr& expr) {
   void VisitExpr_(const CONSTRUCT_NAME##Node* op) final { \
     STMT                                                  \
     fs += f##CONSTRUCT_NAME;                              \
-    ExprVisitor::VisitExpr_(op);                          \
   }
-#define DETECT_DEFAULT_CONSTRUCT(CONSTRUCT_NAME) DETECT_CONSTRUCT(CONSTRUCT_NAME, {})
+#define DETECT_DEFAULT_CONSTRUCT(CONSTRUCT_NAME) DETECT_CONSTRUCT(CONSTRUCT_NAME, { \
+    ExprVisitor::VisitExpr_(op);                                                    \
+  })
     DETECT_DEFAULT_CONSTRUCT(Var)
     DETECT_DEFAULT_CONSTRUCT(GlobalVar)
     DETECT_DEFAULT_CONSTRUCT(Constant)
     DETECT_DEFAULT_CONSTRUCT(Tuple)
     DETECT_DEFAULT_CONSTRUCT(TupleGetItem)
-    DETECT_DEFAULT_CONSTRUCT(Function)
+    DETECT_CONSTRUCT(Function, {
+        if (!op->IsPrimitive()) {
+          ExprVisitor::VisitExpr_(op);
+        }
+      })
     DETECT_DEFAULT_CONSTRUCT(Op)
     DETECT_DEFAULT_CONSTRUCT(Call)
     DETECT_CONSTRUCT(Let, {
@@ -69,6 +76,7 @@ FeatureSet DetectFeature(const Expr& expr) {
             fs += fLetRec;
           }
         }
+        ExprVisitor::VisitExpr_(op);
       })
     DETECT_DEFAULT_CONSTRUCT(If)
     DETECT_DEFAULT_CONSTRUCT(RefCreate)
@@ -83,7 +91,7 @@ FeatureSet DetectFeature(const Expr& expr) {
 }
 
 FeatureSet DetectFeature(const Module& mod) {
-  FeatureSet fs = FeatureSet::NoFeature();
+  FeatureSet fs = FeatureSet::No();
   if (mod.defined()) {
     for (const auto& f : mod->functions) {
       fs += DetectFeature(f.second);
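
Two fixes to the detector itself: visited_ is now actually populated (previously visited_.count(expr) was always zero, so sharing was never flagged and shared subtrees were re-traversed, giving the detector the very exponential blowup it is meant to report), and the recursive ExprVisitor::VisitExpr_ call moves into each construct's STMT so a case can opt out, which lets Function skip the bodies of primitive functions. A small sketch of the fixed behavior, assuming the detect_feature/Feature Python bindings of this TVM version:

    from tvm import relay
    from tvm.relay.analysis import detect_feature
    from tvm.relay.feature import Feature

    x = relay.var("x", shape=())
    y = x + x                        # vars are atomic and may repeat freely
    f = relay.Function([x], y + y)   # the add node `y` is referenced twice
    assert Feature.fGraph in detect_feature(f)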
src/relay/pass/type_infer.cc
index f7de2a9..e8bdc09 100644
@@ -139,19 +139,8 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)>,
   // Perform unification on two types and report the error at the expression
   // or the span of the expression.
   Type Unify(const Type& t1, const Type& t2, const NodeRef& expr) {
-    // TODO(tqchen, jroesch): propagate span to solver
     try {
-      // instantiate higher-order func types when unifying because
-      // we only allow polymorphism at the top level
-      Type first = t1;
-      Type second = t2;
-      if (auto* ft1 = t1.as<FuncTypeNode>()) {
-        first = InstantiateFuncType(ft1);
-      }
-      if (auto* ft2 = t2.as<FuncTypeNode>()) {
-        second = InstantiateFuncType(ft2);
-      }
-      return solver_.Unify(first, second, expr);
+      return solver_.Unify(t1, t2, expr);
     } catch (const dmlc::Error &e) {
       this->ReportFatalError(
         expr,
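
Unify no longer instantiates higher-order function types eagerly; the solver's FuncType case (next file) now absorbs that job by padding the shorter type-parameter list with fresh IncompleteTypes, which is what test_let_polymorphism below exercises.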
src/relay/pass/type_solver.cc
index 3887076..743a4c7 100644
@@ -289,30 +289,44 @@ class TypeSolver::Unifier : public TypeFunctor<Type(const Type&, const Type&)> {
     const auto* ftn = tn.as<FuncTypeNode>();
     if (!ftn
         || op->arg_types.size() != ftn->arg_types.size()
-        || op->type_params.size() != ftn->type_params.size()
         || op->type_constraints.size() != ftn->type_constraints.size()) {
       return Type(nullptr);
     }
 
+    // without loss of generality, suppose op->type_params.size() >= ftn->type_params.size().
+    if (op->type_params.size() < ftn->type_params.size()) {
+      return VisitType_(ftn, GetRef<FuncType>(op));
+    }
+
     // remap type vars so they match
     Map<TypeVar, Type> subst_map;
-    for (size_t i = 0; i < op->type_params.size(); i++) {
-      subst_map.Set(ftn->type_params[i], op->type_params[i]);
+    tvm::Array<TypeVar> ft_type_params;
+    for (size_t i = 0; i < ftn->type_params.size(); ++i) {
+      subst_map.Set(op->type_params[i], ftn->type_params[i]);
+      ft_type_params.push_back(op->type_params[i]);
+    }
+
+    for (size_t i = ftn->type_params.size(); i < op->type_params.size(); ++i) {
+      subst_map.Set(op->type_params[i], IncompleteTypeNode::make(kType));
     }
 
-    auto ft1 = GetRef<FuncType>(op);
-    auto ft2 = Downcast<FuncType>(Bind(GetRef<FuncType>(ftn), subst_map));
+    FuncType ft = FuncTypeNode::make(op->arg_types,
+                                     op->ret_type,
+                                     ft_type_params,
+                                     op->type_constraints);
+    auto ft1 = Downcast<FuncType>(Bind(ft, subst_map));
+    auto ft2 = GetRef<FuncType>(ftn);
 
     Type ret_type = Unify(ft1->ret_type, ft2->ret_type);
 
     std::vector<Type> arg_types;
-    for (size_t i = 0; i < ft1->arg_types.size(); i++) {
+    for (size_t i = 0; i < ft2->arg_types.size(); ++i) {
       Type arg_type = Unify(ft1->arg_types[i], ft2->arg_types[i]);
       arg_types.push_back(arg_type);
     }
 
     std::vector<TypeConstraint> type_constraints;
-    for (size_t i = 0; i < ft1->type_constraints.size(); i++) {
+    for (size_t i = 0; i < ft1->type_constraints.size(); ++i) {
       Type unified_constraint = Unify(ft1->type_constraints[i],
                                       ft2->type_constraints[i]);
       const auto* tcn = unified_constraint.as<TypeConstraintNode>();
@@ -321,7 +335,7 @@ class TypeSolver::Unifier : public TypeFunctor<Type(const Type&, const Type&)> {
       type_constraints.push_back(GetRef<TypeConstraint>(tcn));
     }
 
-    return FuncTypeNode::make(arg_types, ret_type, ft1->type_params, type_constraints);
+    return FuncTypeNode::make(arg_types, ret_type, ft2->type_params, type_constraints);
   }
 
   Type VisitType_(const RefTypeNode* op, const Type& tn) final {
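
Worked example of the new unifier case: unifying fn<a, b>(a, b) -> b with fn<c>(c, int32) -> int32 maps a -> c, pads b with a fresh IncompleteType, unifies arguments and result (resolving the hole to int32), and returns fn<c>(c, int32) -> int32; the swap at the top reduces the case where the right-hand side has more type parameters to this one.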
tests/python/relay/test_feature.py
index 3e9e6a3..8f0e90d 100644
@@ -63,7 +63,8 @@ def test_ad():
         Feature.fLet,
         Feature.fRefCreate,
         Feature.fRefRead,
-        Feature.fRefWrite
+        Feature.fRefWrite,
+        Feature.fGraph
     ])
 
 
tests/python/relay/test_pass_to_cps.py
index 7bed13f..045c92c 100644
@@ -30,6 +30,20 @@ def rand(dtype='float32', *shape):
     return tvm.nd.array(np.random.rand(*shape).astype(dtype))
 
 
+def test_id():
+    x = relay.var("x", shape=[])
+    id = run_infer_type(relay.Function([x], x))
+    id_cps = run_infer_type(to_cps(id))
+
+
+def test_double():
+    t = relay.TypeVar("t")
+    x = relay.var("x", t)
+    f = relay.var("f", relay.FuncType([t], t))
+    double = run_infer_type(relay.Function([f, x], f(f(x)), t, [t]))
+    double_cps = run_infer_type(to_cps(double))
+
+
 # make sure cps works for recursion.
 def test_recursion():
     mod = relay.Module()
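
test_id and test_double are deliberately assertion-free: they are regression tests for type inference on CPS-converted polymorphic functions, and passing simply means run_infer_type terminates without a unification error.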
tests/python/relay/test_type_infer.py
index e8dff7a..3f6b0d2 100644
@@ -19,6 +19,7 @@
 """
 from tvm import relay
 from tvm.relay import op, transform, analysis
+from tvm.relay.analysis import assert_alpha_equal
 
 
 def run_infer_type(expr, mod=None):
@@ -349,6 +350,17 @@ def test_adt_match_type_annotations():
     assert ft.checked_type == relay.FuncType([tt], relay.TupleType([]))
 
 
+def test_let_polymorphism():
+    id = relay.Var("id")
+    xt = relay.TypeVar("xt")
+    x = relay.Var("x", xt)
+    body = relay.Tuple([id(relay.const(1)), id(relay.Tuple([]))])
+    body = relay.Let(id, relay.Function([x], x, xt, [xt]), body)
+    body = run_infer_type(body)
+    int32 = relay.TensorType((), "int32")
+    assert_alpha_equal(body.checked_type, relay.TupleType([int32, relay.TupleType([])]))
+
+
 if __name__ == "__main__":
     test_free_expr()
     test_dual_op()
@@ -366,3 +378,4 @@ if __name__ == "__main__":
     test_constructor_type()
     test_constructor_call()
     test_adt_match()
+    test_let_polymorphism()