[Relay] Fix PE (#3482)
author雾雨魔理沙 <lolisa@marisa.moe>
Thu, 4 Jul 2019 04:58:51 +0000 (21:58 -0700)
committerTianqi Chen <tqchen@users.noreply.github.com>
Thu, 4 Jul 2019 04:58:51 +0000 (21:58 -0700)
include/tvm/relay/module.h
src/relay/ir/expr_functor.cc
src/relay/ir/type_functor.cc
src/relay/ir/type_functor.h
src/relay/pass/let_list.h
src/relay/pass/partial_eval.cc
src/relay/pass/type_infer.cc
src/relay/pass/util.cc
tests/python/relay/test_pass_partial_eval.py

index 638f759..4a3ff0b 100644 (file)
@@ -55,7 +55,7 @@ struct Module;
  *  The functional style allows users to construct custom
  *  environments easily, for example each thread can store
  *  a Module while auto-tuning.
- * */
+ */
 
 class ModuleNode : public RelayNode {
  public:
index 36692c5..0434e2a 100644 (file)
@@ -6,9 +6,9 @@
  * to you under the Apache License, Version 2.0 (the
  * "License"); you may not use this file except in compliance
  * with the License.  You may obtain a copy of the License at
- * 
+ *
  *   http://www.apache.org/licenses/LICENSE-2.0
- * 
+ *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -18,7 +18,7 @@
  */
 
 /*!
- *  Copyright (c) 2018 by Contributors
+ *  Copyright (c) 2019 by Contributors
  * \file src/tvm/relay/expr_mutator.cc
  * \brief A wrapper around ExprFunctor which functionally updates the AST.
  *
@@ -26,6 +26,7 @@
  * the cost of using functional updates.
  */
 #include <tvm/relay/expr_functor.h>
+#include <tvm/relay/pattern_functor.h>
 #include "type_functor.h"
 
 namespace tvm {
@@ -353,7 +354,7 @@ TVM_REGISTER_API("relay._analysis.post_order_visit")
   });
 
 // Implement bind.
-class ExprBinder : public ExprMutator {
+class ExprBinder : public ExprMutator, PatternMutator {
  public:
   explicit ExprBinder(const tvm::Map<Var, Expr>& args_map)
     : args_map_(args_map) {
@@ -383,13 +384,26 @@ class ExprBinder : public ExprMutator {
     }
   }
 
+  Pattern VisitPattern(const Pattern& p) final {
+    return PatternMutator::VisitPattern(p);
+  }
+
+  Clause VisitClause(const Clause& c) final {
+    Pattern pat = VisitPattern(c->lhs);
+    return ClauseNode::make(pat, VisitExpr(c->rhs));
+  }
+
+  Var VisitVar(const Var& v) final {
+    return Downcast<Var>(VisitExpr(v));
+  }
+
  private:
   const tvm::Map<Var, Expr>& args_map_;
 };
 
 Expr Bind(const Expr& expr, const tvm::Map<Var, Expr>& args_map) {
   if (const FunctionNode* func = expr.as<FunctionNode>()) {
-    Expr new_body = ExprBinder(args_map).Mutate(func->body);
+    Expr new_body = ExprBinder(args_map).VisitExpr(func->body);
     Array<Var> new_params;
     for (Var param : func->params) {
       if (!args_map.count(param)) {
@@ -406,7 +420,7 @@ Expr Bind(const Expr& expr, const tvm::Map<Var, Expr>& args_map) {
                               func->type_params,
                               func->attrs);
   } else {
-    return ExprBinder(args_map).Mutate(expr);
+    return ExprBinder(args_map).VisitExpr(expr);
   }
 }
 
index 516f4c8..cde68c5 100644 (file)
@@ -92,6 +92,10 @@ void TypeVisitor::VisitType_(const TypeDataNode* op) {
   }
 }
 
+Type TypeMutator::VisitType(const Type& t) {
+  return t.defined() ? TypeFunctor<Type(const Type&)>::VisitType(t) : t;
+}
+
 // Type Mutator.
 Array<Type> TypeMutator::MutateArray(Array<Type> arr) {
   // The array will do copy on write
@@ -221,7 +225,7 @@ class TypeBinder : public TypeMutator {
 };
 
 Type Bind(const Type& type, const tvm::Map<TypeVar, Type>& args_map) {
-  return type.defined() ? TypeBinder(args_map).VisitType(type) : type;
+  return TypeBinder(args_map).VisitType(type);
 }
 
 }  // namespace relay
index 27ac288..c3ee14e 100644 (file)
@@ -139,6 +139,7 @@ class TypeVisitor : public TypeFunctor<void(const Type& n)> {
 // Mutator that transform a type to another one.
 class TypeMutator : public TypeFunctor<Type(const Type& n)> {
  public:
+  Type VisitType(const Type& t) override;
   Type VisitType_(const TypeVarNode* op) override;
   Type VisitType_(const TensorTypeNode* op) override;
   Type VisitType_(const IncompleteTypeNode* op) override;
index 1b422d2..73c5fe3 100644 (file)
@@ -48,7 +48,7 @@ class LetList {
  public:
   ~LetList() {
     if (lets_.size() > 0 && !used_) {
-      std::cout << "Warning: letlist not used" << std::endl;
+      LOG(WARNING) << "letlist not used";
     }
   }
   /*!
index 6887c7a..b7f12b6 100644 (file)
@@ -64,7 +64,7 @@
  * 3: The generated code reuses bindings (although they are not shadowed),
  * so we have to deduplicate them.
  *
- * 4: In the generated code, multiple VarNode might have same Id.
+ * 4: In the generated code, as it call TypeSubst, multiple VarNode might have same Id.
  * While it is permitted, most pass use NodeHash for Var,
  * and having multiple VarNode for same Id break them.
  * Thus we remap them to a single Id for now.
@@ -216,9 +216,9 @@ Static MkSRef() {
 }
 
 using Func = std::function<PStatic(const std::vector<PStatic>&,
-                                      const Attrs&,
-                                      const Array<Type>&,
-                                      LetList*)>;
+                                   const Attrs&,
+                                   const Array<Type>&,
+                                   LetList*)>;
 
 struct SFuncNode : StaticNode {
   Func func;
@@ -256,6 +256,7 @@ class Environment {
 
   void Insert(const Var& v, const PStatic& ps) {
     CHECK(ps.defined());
+    CHECK_EQ(env_.back().locals.count(v), 0);
     env_.back().locals[v] = ps;
   }
 
@@ -287,12 +288,17 @@ class Environment {
 
 /*!
  * \brief As our store require rollback, we implement it as a frame.
- * every time we need to copy the store, a new frame is insert.
- * every time we roll back, a frame is popped.
+ *
+ * Every time we need to copy the store, a new frame is insert.
+ * Every time we roll back, a frame is popped.
  */
 struct StoreFrame {
   std::unordered_map<const SRefNode*, PStatic> store;
-  /*! \brief on unknown effect, history_valid is set to true to signal above frame is outdated */
+  /*!
+   * \brief On unknown effect, history_valid is set to true to signal above frame is outdated.
+   *
+   * It only outdate the frame above it, but not the current frame.
+   */
   bool history_valid = true;
   explicit StoreFrame(const std::unordered_map<const SRefNode*, PStatic>& store) : store(store) { }
   StoreFrame() = default;
@@ -310,6 +316,7 @@ class Store {
   }
 
   void Insert(const SRefNode* r, const PStatic& ps) {
+    CHECK(r);
     store_.back().store[r] = ps;
   }
 
@@ -317,19 +324,21 @@ class Store {
   PStatic Lookup(const SRefNode* r) {
     auto rit = store_.rbegin();
     while (rit != store_.rend()) {
-      if (!rit->history_valid) {
-        return PStatic();
-      }
       if (rit->store.find(r) != rit->store.end()) {
         return rit->store.find(r)->second;
       }
+      if (!rit->history_valid) {
+        return PStatic();
+      }
       ++rit;
     }
     return PStatic();
   }
 
   void Invalidate() {
-    store_.back().history_valid = false;
+    StoreFrame sf;
+    sf.history_valid = false;
+    store_.push_back(sf);
   }
 
  private:
@@ -341,6 +350,10 @@ class Store {
       store_->store_.push_back(StoreFrame());
     }
     ~StoreFrameContext() {
+      // push one history valid frame off.
+      while (!store_->store_.back().history_valid) {
+        store_->store_.pop_back();
+      }
       store_->store_.pop_back();
     }
   };
@@ -442,13 +455,7 @@ Function AsFunc(const Expr& e) {
 class PartialEvaluator : public ExprFunctor<PStatic(const Expr& e, LetList* ll)>,
                          public PatternFunctor<MatchStatus(const Pattern&, const PStatic&)> {
  public:
-  PartialEvaluator(const tvm::Array<Var>& free_vars,
-                   const Module& mod) :
-    mod_(mod) {
-    for (const Var& v : free_vars) {
-      env_.Insert(v, NoStatic(v));
-    }
-  }
+  PartialEvaluator(const Module& mod) : mod_(mod) { }
 
   PStatic VisitExpr(const Expr& e, LetList* ll) final {
     PStatic ret = ExprFunctor<PStatic(const Expr&, LetList*)>::VisitExpr(e, ll);
@@ -484,23 +491,23 @@ class PartialEvaluator : public ExprFunctor<PStatic(const Expr& e, LetList* ll)>
     return env_.Lookup(GetRef<Var>(op));
   }
 
-  PStatic VisitExpr_(const GlobalVarNode* op, LetList* ll) final {
-    GlobalVar gv = GetRef<GlobalVar>(op);
+  PStatic VisitGlobalVar(const GlobalVar& gv) {
+    CHECK(mod_.defined());
     if (gv_map_.count(gv) == 0) {
-      if (mod_.defined()) {
-        Function func = mod_->Lookup(gv);
-        InitializeFuncId(func);
-        Func f = VisitFuncStatic(func, gv);
-        gv_map_.insert({gv, HasStatic(MkSFunc(f), gv)});
-        func = AsFunc(PostProcess(VisitFuncDynamic(func, f)));
-        mod_->Update(gv, func);
-      } else {
-        gv_map_.insert({gv, NoStatic(gv)});
-      }
+      Function func = mod_->Lookup(gv);
+      InitializeFuncId(func);
+      Func f = VisitFuncStatic(func, gv);
+      gv_map_.insert({gv, HasStatic(MkSFunc(f), gv)});
+      func = AsFunc(PostProcess(VisitFuncDynamic(func, f)));
+      mod_->Update(gv, func);
     }
     return gv_map_.at(gv);
   }
 
+  PStatic VisitExpr_(const GlobalVarNode* op, LetList* ll) final {
+    return VisitGlobalVar(GetRef<GlobalVar>(op));
+  }
+
   PStatic VisitExpr_(const LetNode* op, LetList* ll) final {
     env_.Insert(op->var, VisitExpr(op->value, ll));
     return VisitExpr(op->body, ll);
@@ -629,7 +636,7 @@ class PartialEvaluator : public ExprFunctor<PStatic(const Expr& e, LetList* ll)>
             subst.Set(func->type_params[i], type_args[i]);
           }
           for (size_t i = type_args.size(); i < func->type_params.size(); ++i) {
-            subst.Set(func->type_params[i], Type());
+            subst.Set(func->type_params[i], IncompleteTypeNode::make(kType));
           }
           std::vector<Time> args_time;
           for (const auto& v : pv) {
@@ -672,22 +679,22 @@ class PartialEvaluator : public ExprFunctor<PStatic(const Expr& e, LetList* ll)>
     };
   }
 
-
   Expr VisitFuncDynamic(const Function& func, const Func& f) {
     return store_.Extend<Expr>([&]() {
-        store_.Invalidate();
-        return FunctionNode::make(func->params, LetList::With([&](LetList* ll) {
-              std::vector<PStatic> pv;
-              for (const auto& v : func->params) {
-                pv.push_back(NoStatic(v));
-              }
-              tvm::Array<Type> type_args;
-              for (const auto& tp : func->type_params) {
-                type_args.push_back(tp);
-              }
-              return f(pv, Attrs(), type_args, ll)->dynamic;
-            }), func->ret_type, func->type_params, func->attrs);
-      });
+      store_.Invalidate();
+      return FunctionNode::make(func->params,
+                                LetList::With([&](LetList* ll) {
+        std::vector<PStatic> pv;
+        for (const auto& v : func->params) {
+          pv.push_back(NoStatic(v));
+        }
+        tvm::Array<Type> type_args;
+        for (const auto& tp : func->type_params) {
+          type_args.push_back(tp);
+        }
+        return f(pv, Attrs(), type_args, ll)->dynamic;
+      }), func->ret_type, func->type_params, func->attrs);
+    });
   }
 
   PStatic VisitFunc(const Function& func, LetList* ll) {
@@ -1012,17 +1019,14 @@ Expr PostProcess(const Expr& e) {
 
 Module PartialEval(const Module& m) {
   CHECK(m->entry_func.defined());
-  auto func = m->Lookup(m->entry_func);
-  Expr ret =
-    TransformF([&](const Expr& e) {
-      return LetList::With([&](LetList* ll) {
-        relay::partial_eval::PartialEvaluator pe(FreeVars(e), m);
-        pe.InitializeFuncId(e);
-        return relay::partial_eval::PostProcess(pe.VisitExpr(e, ll)->dynamic);
-      });
-    }, func);
-  CHECK(ret->is_type<FunctionNode>());
-  m->Update(m->entry_func, Downcast<Function>(ret));
+  relay::partial_eval::PartialEvaluator pe(m);
+  std::vector<GlobalVar> gvs;
+  for (const auto& p : m->functions) {
+    gvs.push_back(p.first);
+  }
+  for (const auto& gv : gvs) {
+    pe.VisitGlobalVar(gv);
+  }
   return m;
 }
 
index aa3cc02..5ae3908 100644 (file)
@@ -172,6 +172,7 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)>,
       return it->second.checked_type;
     }
     Type ret = this->VisitExpr(expr);
+    CHECK(ret.defined());
     KindCheck(ret, mod_);
     ResolvedTypeInfo& rti = type_map_[expr];
     rti.checked_type = ret;
index 2497197..e2b7157 100644 (file)
@@ -425,6 +425,16 @@ Expr TypeSubst(const Expr& expr, const tvm::Map<TypeVar, Type>& subst_map) {
     Var VisitVar(const Var& v) final {
       return Downcast<Var>(VisitExpr(v));
     }
+
+    Pattern VisitPattern(const Pattern& p) final {
+      return PatternMutator::VisitPattern(p);
+    }
+
+    Clause VisitClause(const Clause& c) final {
+      Pattern pat = VisitPattern(c->lhs);
+      return ClauseNode::make(pat, VisitExpr(c->rhs));
+    }
+
    private:
     const tvm::Map<TypeVar, Type>& subst_map_;
   };
index 6a7f59c..8855b08 100644 (file)
@@ -307,10 +307,10 @@ def test_double():
 
 
 if __name__ == '__main__':
-    test_empty_ad()
+    test_ref()
     test_tuple()
+    test_empty_ad()
     test_const_inline()
-    test_ref()
     test_ad()
     test_if_ref()
     test_function_invalidate()