[Relay] Fix Partial Evaluator, Add stricter checking for CheckWellFormed (#3749)
author雾雨魔理沙 <lolisa@marisa.moe>
Sun, 11 Aug 2019 01:23:23 +0000 (18:23 -0700)
committerThierry Moreau <moreau@uw.edu>
Sun, 11 Aug 2019 01:23:23 +0000 (18:23 -0700)
* aot

* save

* save

* fix test

* remove vta changes

* lint

src/relay/ir/expr_functor.cc
src/relay/pass/de_duplicate.cc
src/relay/pass/let_list.h
src/relay/pass/partial_eval.cc
src/relay/pass/util.cc
src/relay/pass/well_formed.cc
tests/python/relay/test_error_reporting.py
tests/python/relay/test_pass_partial_eval.py

index 1a1fc36..da9f7b8 100644 (file)
@@ -216,7 +216,8 @@ Expr ExprMutator::VisitExpr_(const MatchNode* m) {
 }
 
 Clause ExprMutator::VisitClause(const Clause& c) {
-  return ClauseNode::make(VisitPattern(c->lhs), VisitExpr(c->rhs));
+  Pattern p = VisitPattern(c->lhs);
+  return ClauseNode::make(p, VisitExpr(c->rhs));
 }
 
 Pattern ExprMutator::VisitPattern(const Pattern& p) { return p; }
@@ -395,7 +396,9 @@ class ExprBinder : public ExprMutator, PatternMutator {
   }
 
   Var VisitVar(const Var& v) final {
-    return Downcast<Var>(VisitExpr(v));
+    CHECK(!args_map_.count(v))
+      << "Cannnot bind an internal pattern variable";
+    return v;
   }
 
  private:
index d5d4f69..332803c 100644 (file)
@@ -44,6 +44,8 @@ Expr DeDup(const Expr& e) {
     }
 
     Var Fresh(const Var& v) {
+      CHECK_EQ(rename_.count(v), 0);
+      CHECK_EQ(memo_.count(v), 0) << v.as<VarNode>();
       Var ret = VarNode::make(v->name_hint(), VisitType(v->type_annotation));
       rename_[v] = ret;
       return ret;
@@ -84,18 +86,13 @@ Expr DeDup(const Expr& e) {
     }
 
     Pattern VisitPattern(const Pattern& p) final {
-      return PatternMutator::VisitPattern(p);
+      return PatternFunctor::VisitPattern(p);
     }
 
     Pattern VisitPattern_(const PatternVarNode* op) final {
       return PatternVarNode::make(Fresh(op->var));
     }
 
-    Clause VisitClause(const Clause& c) final {
-      Pattern pat = VisitPattern(c->lhs);
-      return ClauseNode::make(pat, VisitExpr(c->rhs));
-    }
-
     Type VisitType_(const TypeVarNode* op) final {
       TypeVar v = GetRef<TypeVar>(op);
       return type_rename_.count(v) != 0 ? type_rename_.at(v) : v;
@@ -109,9 +106,10 @@ Expr DeDup(const Expr& e) {
     std::unordered_map<Var, Var, NodeHash, NodeEqual> rename_;
     std::unordered_map<TypeVar, TypeVar, NodeHash, NodeEqual> type_rename_;
   };
-
+  CHECK(WellFormed(e)) << AsText(e, false);
   Expr ret = DeDupMutator().VisitExpr(e);
-  CHECK_EQ(FreeVars(ret).size(), FreeVars(e).size());
+  CHECK(WellFormed(ret));
+  CHECK_EQ(FreeVars(e).size(), FreeVars(ret).size());
   return ret;
 }
 
index e90ab12..94b5ea3 100644 (file)
@@ -31,6 +31,7 @@
 #define TVM_RELAY_PASS_LET_LIST_H_
 
 #include <tvm/relay/expr.h>
+#include <tvm/relay/analysis.h>
 #include <utility>
 #include <vector>
 #include <tuple>
@@ -63,6 +64,7 @@ class LetList {
    */
   Var Push(Var pv, Expr expr) {
     CHECK(!used_);
+    CHECK(WellFormed(expr));
     lets_.emplace_back(std::make_pair(pv, expr));
     return pv;
   }
index 1ea63e8..3f92d7a 100644 (file)
@@ -396,6 +396,7 @@ class Environment {
 
   void Insert(const Var& v, const PStatic& ps) {
     CHECK(ps.defined());
+    CHECK_GT(env_.size(), 0);
     CHECK_EQ(env_.back().locals.count(v), 0);
     env_.back().locals[v] = ps;
   }
@@ -604,10 +605,10 @@ class PartialEvaluator : public ExprFunctor<PStatic(const Expr& e, LetList* ll)>
   }
 
   PStatic VisitExpr(const Expr& e, LetList* ll, const Var& name) {
-    if (auto* op = e.as<CallNode>()) {
-      if (op->op.same_as(WithFuncIdOp())) {
-        CHECK_EQ(op->args.size(), 1);
-        return VisitExpr(op->args[0], ll, name);
+    if (const CallNode* c = e.as<CallNode>()) {
+      if (c->op.same_as(WithFuncIdOp())) {
+        CHECK_EQ(c->args.size(), 1);
+        return VisitExpr(c->args[0], ll, name);
       }
     }
     PStatic ret = e.as<FunctionNode>() ?
@@ -801,34 +802,36 @@ class PartialEvaluator : public ExprFunctor<PStatic(const Expr& e, LetList* ll)>
                LetList* ll) {
       return env_.Extend<PStatic>([&]() {
           CHECK_EQ(pv.size(), func->params.size());
-          if (var.as<VarNode>()) {
-            env_.Insert(Downcast<Var>(var), self);
-          }
-          for (size_t i = 0; i < pv.size(); ++i) {
-            env_.Insert(func->params[i], pv[i]);
-          }
-          for (const auto& p : free_vars) {
-            env_.Insert(p.first, p.second);
-          }
-          tvm::Map<TypeVar, Type> subst;
-          for (size_t i = 0; i < type_args.size(); ++i) {
-            subst.Set(func->type_params[i], type_args[i]);
-          }
-          for (size_t i = type_args.size(); i < func->type_params.size(); ++i) {
-            subst.Set(func->type_params[i], IncompleteTypeNode::make(kType));
-          }
-          std::vector<Fuel> args_fuel;
-          for (const auto& v : pv) {
-            args_fuel.push_back(GetFuel(v));
-          }
           CHECK_GT(func_map_.count(func), 0);
           FuncId fid = func_map_.at(func);
           if (fuel_map_.count(fid) == 0) {
             fuel_map_.insert({fid, MkFTop()});
           }
+          std::vector<Fuel> args_fuel;
+          for (const auto& v : pv) {
+            args_fuel.push_back(GetFuel(v));
+          }
           auto meet_res = fuel_map_[fid]->Meet(MkFSeq(args_fuel));
           if (std::get<1>(meet_res)) {
             FuelFrame tf(this, fid, std::get<0>(meet_res));
+            Expr dedup_func = RegisterFuncId(DeDup(AnnotateFuncId(func)));
+            Function func = AsFunc(dedup_func);
+            if (var.as<VarNode>()) {
+              env_.Insert(Downcast<Var>(var), self);
+            }
+            for (size_t i = 0; i < pv.size(); ++i) {
+              env_.Insert(func->params[i], pv[i]);
+            }
+            for (const auto& p : free_vars) {
+              env_.Insert(p.first, p.second);
+            }
+            tvm::Map<TypeVar, Type> subst;
+            for (size_t i = 0; i < type_args.size(); ++i) {
+              subst.Set(func->type_params[i], type_args[i]);
+            }
+            for (size_t i = type_args.size(); i < func->type_params.size(); ++i) {
+              subst.Set(func->type_params[i], IncompleteTypeNode::make(kType));
+            }
             return VisitExpr(RegisterFuncId(TypeSubst(AnnotateFuncId(func->body), subst)), ll);
           } else {
             std::vector<Expr> dyn;
@@ -979,32 +982,37 @@ class PartialEvaluator : public ExprFunctor<PStatic(const Expr& e, LetList* ll)>
   PStatic VisitExpr_(const MatchNode* op, LetList* ll) final {
     PStatic ps = VisitExpr(op->data, ll);
     return env_.Extend<PStatic>([&]() {
-        for (const Clause& c : op->clauses) {
-          switch (VisitPattern(c->lhs, ps)) {
-          case MatchStatus::Match:
-            return VisitExpr(c->rhs, ll);
-          case MatchStatus::NoMatch:
-            continue;
-          case MatchStatus::Unknown:
+      for (const Clause& c : op->clauses) {
+        switch (VisitPattern(c->lhs, ps)) {
+        case MatchStatus::Match:
+          return VisitExpr(c->rhs, ll);
+        case MatchStatus::NoMatch:
+          continue;
+        case MatchStatus::Unknown:
+          return [&]() {
             tvm::Array<Clause> clauses;
             for (const Clause& c : op->clauses) {
               Expr expr = store_.Extend<Expr>([&]() {
-                  return LetList::With([&](LetList* ll) {
-                      for (const Var& v : BoundVars(c->lhs)) {
-                        env_.Insert(v, NoStatic(v));
-                      }
-                      return VisitExpr(c->rhs, ll)->dynamic;
-                    });
+                return LetList::With([&](LetList* ll) {
+                  for (const Var& v : BoundVars(c->lhs)) {
+                    env_.Insert(v, NoStatic(v));
+                  }
+                  return VisitExpr(c->rhs, ll)->dynamic;
                 });
+              });
               clauses.push_back(ClauseNode::make(c->lhs, expr));
             }
             store_.Invalidate();
             return NoStatic(ll->Push(MatchNode::make(ps->dynamic, clauses, op->complete)));
-          }
+          }();
+        default:
+          LOG(FATAL) << "Unknown MatchStatus";
+          throw;
         }
-        LOG(FATAL) << "No case Match";
-        throw;
-      });
+      }
+      LOG(FATAL) << "No case Match";
+      throw;
+    });
   }
 
   MatchStatus VisitPattern_(const PatternWildcardNode* op, const PStatic& ps) final {
index e2b7157..90c3de8 100644 (file)
@@ -438,7 +438,11 @@ Expr TypeSubst(const Expr& expr, const tvm::Map<TypeVar, Type>& subst_map) {
    private:
     const tvm::Map<TypeVar, Type>& subst_map_;
   };
-  return TypeSubstMutator(subst_map).VisitExpr(expr);
+  CHECK(WellFormed(expr));
+  auto ret = TypeSubstMutator(subst_map).VisitExpr(expr);
+  CHECK_EQ(FreeVars(expr).size(), FreeVars(ret).size());
+  CHECK(WellFormed(ret));
+  return ret;
 }
 
 }  // namespace relay
index bfe8865..27b31de 100644 (file)
@@ -35,36 +35,84 @@ namespace relay {
 class WellFormedChecker : private ExprVisitor, PatternVisitor {
   bool well_formed = true;
 
-  std::unordered_set<Var, NodeHash, NodeEqual> s;
+  std::vector<std::unordered_set<Var, NodeHash, NodeEqual>> scope;
+  std::unordered_set<Var, NodeHash, NodeEqual> current_bound;
+  std::unordered_set<Var, NodeHash, NodeEqual> total_bound;
+  std::unordered_set<Var, NodeHash, NodeEqual> free;
 
-  void Check(const Var& v) {
-    if (s.count(v) != 0) {
+  struct Scope {
+    WellFormedChecker* wfc;
+    explicit Scope(WellFormedChecker* wfc) : wfc(wfc) {
+      wfc->scope.push_back({});
+    }
+    ~Scope() {
+      CHECK_GE(wfc->scope.size(), 0);
+      for (const Var& v : wfc->scope.back()) {
+        CHECK_GE(wfc->current_bound.count(v), 0);
+        wfc->current_bound.erase(v);
+      }
+      wfc->scope.pop_back();
+    }
+  };
+
+  void Bound(const Var& v) {
+    if (current_bound.count(v) != 0 || total_bound.count(v) != 0 || free.count(v) != 0) {
       well_formed = false;
     }
-    s.insert(v);
+    CHECK_GE(scope.size(), 0);
+    scope.back().insert(v);
+    current_bound.insert(v);
+    total_bound.insert(v);
+  }
+
+  void VisitExpr_(const VarNode* op) final {
+    Var v = GetRef<Var>(op);
+    if (current_bound.count(v) == 0) {
+      if (total_bound.count(v) != 0) {
+        well_formed = false;
+      } else {
+        free.insert(v);
+      }
+    }
   }
 
   void VisitExpr_(const LetNode* l) final {
+    Scope s(this);
     // we do letrec only for FunctionNode,
     // but shadowing let in let binding is likely programming error, and we should forbidden it.
-    Check(l->var);
+    Bound(l->var);
     CheckWellFormed(l->value);
     CheckWellFormed(l->body);
   }
 
   void VisitExpr_(const FunctionNode* f) final {
+    Scope s(this);
     for (const Var& param : f->params) {
-      Check(param);
+      Bound(param);
     }
     CheckWellFormed(f->body);
   }
 
+  void VisitClause(const Clause& c) final {
+    Scope s(this);
+    VisitPattern(c->lhs);
+    VisitExpr(c->rhs);
+  }
+
   void VisitPattern(const Pattern& p) final {
     PatternVisitor::VisitPattern(p);
   }
 
   void VisitVar(const Var& v) final {
-    Check(v);
+    Bound(v);
+  }
+
+  void VisitExpr(const Expr& e) final {
+    if (auto v = e.as<VarNode>()) {
+      VisitExpr_(v);
+    } else {
+      ExprVisitor::VisitExpr(e);
+    }
   }
 
  public:
index c446f36..74e6518 100644 (file)
@@ -27,27 +27,36 @@ def check_type_err(expr, msg):
     except tvm.TVMError as err:
         assert msg in str(err)
 
+def test_wellformed():
+    x = relay.var('x', shape=(10, 10))
+    f = relay.Function([x], x)
+    check_type_err(
+        f(x),
+        "Check failed: WellFormed")
+
 def test_too_many_args():
     x = relay.var('x', shape=(10, 10))
     f = relay.Function([x], x)
     y = relay.var('y', shape=(10, 10))
     check_type_err(
-        f(x, y),
+        f(y, y),
         "the function is provided too many arguments expected 1, found 2;")
 
 def test_too_few_args():
     x = relay.var('x', shape=(10, 10))
     y = relay.var('y', shape=(10, 10))
+    z = relay.var('z', shape=(10, 10))
     f = relay.Function([x, y], x)
-    check_type_err(f(x), "the function is provided too few arguments expected 2, found 1;")
+    check_type_err(f(z), "the function is provided too few arguments expected 2, found 1;")
 
 def test_rel_fail():
     x = relay.var('x', shape=(10, 10))
     y = relay.var('y', shape=(11, 10))
     f = relay.Function([x, y], x + y)
-    check_type_err(f(x, y), "Incompatible broadcast type TensorType([10, 10], float32) and TensorType([11, 10], float32);")
+    check_type_err(f, "Incompatible broadcast type TensorType([10, 10], float32) and TensorType([11, 10], float32);")
 
 if __name__ == "__main__":
+    test_wellformed()
     test_too_many_args()
     test_too_few_args()
     test_rel_fail()
index 6349362..f914f18 100644 (file)
@@ -323,7 +323,16 @@ def test_triangle_number():
     assert_alpha_equal(dcpe(orig), const(55))
 
 
+def test_nat_update():
+    m = Module()
+    p = Prelude(m)
+    add_nat_definitions(p)
+    m = transform.ToANormalForm()(m)
+    transform.PartialEvaluate()(m)
+
+
 if __name__ == '__main__':
+    test_nat_update()
     test_ref()
     test_tuple()
     test_empty_ad()