[Relay][Pass] Only allow Module -> Module for opts managed by pass infra (#3430)
authorZhi <5145158+zhiics@users.noreply.github.com>
Mon, 1 Jul 2019 19:50:39 +0000 (12:50 -0700)
committerJared Roesch <roeschinc@gmail.com>
Mon, 1 Jul 2019 19:50:39 +0000 (12:50 -0700)
* [Relay][Pass] Only allow Module -> Module for opts managed by pass infra

* revert gradient pass

13 files changed:
include/tvm/relay/pass.h
include/tvm/relay/transform.h
python/tvm/relay/ir_pass.py
python/tvm/relay/transform.py
src/relay/pass/dead_code.cc
src/relay/pass/partial_eval.cc
src/relay/pass/pass_manager.cc
src/relay/pass/to_a_normal_form.cc
src/relay/pass/to_graph_normal_form.cc
tests/python/relay/test_pass_dead_code_elimination.py
tests/python/relay/test_pass_partial_eval.py
tests/python/relay/test_pass_to_a_normal_form.py
tests/python/relay/test_pass_to_graph_normal_form.py

index 294d22b..79172c3 100644 (file)
@@ -141,23 +141,6 @@ TVM_DLL bool AlphaEqual(const Type& t1, const Type& t2);
 TVM_DLL bool AlphaEqual(const Pattern& t1, const Pattern& t2);
 
 /*!
- * \brief Add abstraction over a function
- *
- * For example: `square` is transformed to
- * `fun x -> square x`.
- *
- * See https://en.wikipedia.org/wiki/Lambda_calculus#%CE%B7-conversion
- * for more details.
- *
- * \param e The original function.
- * \param mod The module used for referencing global functions, can be
- * None.
- *
- * \return the new function with abstraction
- */
-TVM_DLL Expr EtaExpand(const Expr& e, const Module& mod);
-
-/*!
  * \brief Check that each Var is only bound once.
  *
  * For example, the expression `let x = 1 in let x = 2 in 3` bound x twice.
@@ -288,24 +271,6 @@ TVM_DLL tvm::Array<TypeVar> AllTypeVars(const Expr& expr, const Module& mod);
  */
 TVM_DLL tvm::Array<TypeVar> AllTypeVars(const Type& t, const Module& mod);
 
-/*! \brief Remove expressions which does not effect the program result.
- *
- * It will remove let bindings which are not referenced,
- * and inline let bindings that are only used once.
- *
- * For example, this pass should turn `let a = 1 in 2` into `2`,
- * as the value of the expression does not depend on a.
- *
- * As another example, `let a = 1 in a` will be optimized into 1,
- * if the flag is turned on.
- *
- * \param e the expression to optimize.
- * \param inline_once whether or not to inline binding used one.
- *
- * \return the optimized expression.
- */
-TVM_DLL Expr DeadCodeElimination(const Expr& e, bool inline_once = false);
-
 /*!
  * \brief Fold constant expressions.
  *
@@ -388,38 +353,6 @@ TVM_DLL Map<Expr, Integer> CollectDeviceInfo(const Expr& expr);
 TVM_DLL Map<Expr, Integer> CollectDeviceAnnotationOps(const Expr& expr);
 
 /*!
- * \brief turn a dataflow graph into Administrative Normal Form, or A-Normal Form (ANF).
- *
- * It will turn an expression that is in a graph form (with sharing implicit),
- * to an expression with explicit sharing (A-Normal Form).
- *
- * The scope of the root expression is the global scope.
- *
- * The scope of any non root expression is the least common ancestor of all it's scope.
- *
- * Values are ordered by post-DFS order in each scope.
- *
- * \param e the expression to observably share.
- * \param mod The module used for referencing global functions, can be
- * None.
- *
- * \return expression in A-Normal Form.
- */
-TVM_DLL Expr ToANormalForm(const Expr& e, const Module& mod);
-
-/*!
- * \brief Remove let binding and directly share via pointer instead.
- *
- * It will remove all let binding,
- * and turn all of the variable bound by let into direct pointer reference.
- *
- * \param e the expression.
- *
- * \return the expression in graph normal form.
- */
-TVM_DLL Expr ToGraphNormalForm(const Expr& e);
-
-/*!
  * \brief Finds cases that the given match expression does not catch, if any.
  *
  * \param match the match expression to test
@@ -432,28 +365,17 @@ TVM_DLL Expr ToGraphNormalForm(const Expr& e);
 TVM_DLL Array<Pattern> UnmatchedCases(const Match& match, const Module& mod);
 
 /*!
- * \brief Aggressive constant propagation/constant folding/inlining.
- * It will do as much computation in compile time as possible.
- * It has two benefit: remove runtime overhead, and allow more optimization (typically fusion).
- * As a side effect, code size will explode.
- *
- * \param e the expression
- * \param mod the module
- *
- * \return the optimized expression.
- */
-TVM_DLL Expr PartialEval(const Expr& e, const Module& mod);
-
-/*
- * \brief Bind function parameters or free variables.
+ * \brief Bind the free variables to a Relay expression.
  *
  * Parameter binding can only happen if expr is a Function.
  * binds cannot change internal arguments of internal functions.
  *
  * \param expr The function to be binded.
  * \param binds The map of arguments to
+ *
+ * \return The expression with all free vars bound.
  */
-TVM_DLL Expr Bind(const Expr& expr, const tvm::Map<Var, Expr>& bind_map);
+TVM_DLL Expr Bind(const Expr& expr, const tvm::Map<Var, Expr>& binds);
 
 /*! \brief A hashing structure in the style of std::hash. */
 struct StructuralHash {
index 04b4e64..9ae71d8 100644 (file)
@@ -541,6 +541,33 @@ TVM_DLL Pass AlterOpLayout();
  */
 TVM_DLL Pass CanonicalizeCast();
 
+/*!
+ * \brief Add abstraction over a function
+ *
+ * For example: `square` is transformed to
+ * `fun x -> square x`.
+ *
+ * See https://en.wikipedia.org/wiki/Lambda_calculus#%CE%B7-conversion
+ * for more details.
+ *
+ * \return The pass.
+ */
+TVM_DLL Pass EtaExpand();
+
+/*!
+ * \brief This is a helper function that runs a some optimization passes on
+ * a certain expression and returns the optimized version. With the help of this
+ * function, users don't need to manually construct a module, then perform
+ * passes, and finally and extract the target function/expression from the
+ * returned module frequently.
+ *
+ * \param expr The expression to be optimized.
+ * \param passes The passses that will be applied on the given expression.
+ *
+ * \return The optimized expression.
+ */
+TVM_DLL Expr OptimizeOnExpr(const Expr& expr, const Array<Pass>& passes);
+
 }  // namespace transform
 }  // namespace relay
 }  // namespace tvm
index 1748571..52dc34d 100644 (file)
@@ -84,23 +84,6 @@ def backward_fold_scale_axis(expr):
     """
     return _ir_pass.backward_fold_scale_axis(expr)
 
-def eta_expand(expr, mod):
-    """Add abstraction over a function.
-
-    Parameters
-    ----------
-    expr : tvm.relay.Expr
-        The input expression, we expect that expr's types
-        should be fully inferred by infer_type.
-    mod : tvm.relay.Module
-         The global module.
-
-    Returns
-    -------
-    expanded_expr : tvm.relay.Expr
-        The expression after eta expansion.
-    """
-    return _ir_pass.eta_expand(expr, mod)
 
 def forward_fold_scale_axis(expr):
     """Fold the scaling of axis into weights of conv2d/dense.
@@ -318,25 +301,6 @@ def canonicalize_ops(expr):
     return _ir_pass.canonicalize_ops(expr)
 
 
-def dead_code_elimination(expr, inline_once=False):
-    """ Remove expressions which does not effect the program result (dead code).
-
-    Parameters
-    ----------
-    expr : tvm.relay.Expr
-        The input Expression
-
-    inline_once : Optional[Bool]
-        Whether to inline binding that occur only once.
-    Returns
-    -------
-    result : tvm.relay.Expr
-        An expression which is semantically equal to the input expression,
-        but with dead code removed.
-    """
-    return _ir_pass.dead_code_elimination(expr, inline_once)
-
-
 def alpha_equal(lhs, rhs):
     """Compare two Relay expr for structural equivalence (alpha equivalence).
 
@@ -534,46 +498,6 @@ def collect_device_annotation_ops(expr):
     return _ir_pass.CollectDeviceAnnotationOps(expr)
 
 
-def to_a_normal_form(expr, mod=None):
-    """
-    Turn Graph Normal Form expression into A Normal Form Expression.
-
-    The scope of the root expression is the global scope.
-
-    The scope of any non root expression is the least common ancestor of all it's scope.
-
-    Values are ordered by post-DFS order in each scope.
-
-    Parameters
-    ----------
-    expr : tvm.relay.Expr
-        The input expression.
-
-    mod : Optional[tvm.relay.Module]
-        The global module.
-
-    Returns
-    -------
-    result : tvm.relay.Expr
-      The output expression.
-    """
-    return _ir_pass.to_a_normal_form(expr, mod)
-
-
-def to_graph_normal_form(expr):
-    """Turn A Normal Form expression into Graph Normal Form expression
-    Parameters
-    ----------
-    expr : tvm.relay.Expr
-        The input expression
-    Returns
-    -------
-    result : tvm.relay.Expr
-      The output expression
-    """
-    return _ir_pass.to_graph_normal_form(expr)
-
-
 def gradient(expr, mod=None, mode='higher_order'):
     """
     Transform the input function,
@@ -642,26 +566,6 @@ def eliminate_common_subexpr(expr, fskip=None):
     return _ir_pass.eliminate_common_subexpr(expr, fskip)
 
 
-def partial_evaluate(expr, mod=None):
-    """
-    Evaluate the static fragment of the code.
-
-    Parameters
-    ----------
-    expr : tvm.relay.Expr
-        The input expression.
-
-    mod : Optional[tvm.relay.Module]
-        The global module
-
-    Returns
-    -------
-    result : tvm.relay.Expr
-      The output expression.
-    """
-    return _ir_pass.partial_evaluate(expr, mod)
-
-
 def unmatched_cases(match, mod=None):
     """
     Finds cases that the match expression does not catch, if any.
index 5f47e5b..ba4857d 100644 (file)
@@ -302,15 +302,20 @@ def CanonicalizeOps():
     return _transform.CanonicalizeOps()
 
 
-def DeadCodeElimination():
-    """ Remove expressions which does not effect the program result (dead code).
+def DeadCodeElimination(inline_once=False):
+    """Remove expressions which does not effect the program result (dead code).
+
+    Parameters
+    ----------
+    inline_once: Optional[Bool]
+        Whether to inline binding that occurs only once.
 
     Returns
     -------
     ret: tvm.relay.Pass
         The registered pass that eliminates the dead code in a Relay program.
     """
-    return _transform.DeadCodeElimination()
+    return _transform.DeadCodeElimination(inline_once)
 
 
 def FoldConstant():
@@ -406,6 +411,7 @@ def ToANormalForm():
     """
     return _transform.ToANormalForm()
 
+
 def EtaExpand():
     """Add abstraction over a function
 
@@ -416,6 +422,7 @@ def EtaExpand():
     """
     return _transform.EtaExpand()
 
+
 def ToGraphNormalForm():
     """Turn A Normal Form expression into Graph Normal Form expression
 
@@ -449,7 +456,7 @@ def PartialEvaluate():
 
     Returns
     -------
-    ret : tvm.relay.Pass
+    ret: tvm.relay.Pass
         The registered pass that performs partial evaluation on an expression.
     """
     return _transform.PartialEvaluate()
@@ -465,6 +472,31 @@ def CanonicalizeCast():
     """
     return _transform.CanonicalizeCast()
 
+
+def OptimizeOnExpr(expr, passes):
+    """Perform optimization passes on an expressioin.
+
+    Parameters
+    ----------
+    expr: tvm.relay.Expr
+        The expression for optimization.
+
+    passes: Union[Pass, List[Pass]]
+        The list of optimizations to be applied.
+
+    Returns
+    -------
+    ret: tvm.relay.Expr
+        The optimized expression.
+    """
+    if isinstance(passes, Pass):
+        passes = [passes]
+    if not isinstance(passes, (list, tuple)):
+        raise TypeError("passes must be a pass or a list of pass objects.")
+
+    return _transform.OptimizeOnExpr(expr, passes)
+
+
 def _wrap_class_module_pass(pass_cls, pass_info):
     """Wrap a python class as function pass"""
     class PyModulePass(ModulePass):
index 7e186f8..8799bf4 100644 (file)
@@ -156,9 +156,6 @@ Expr DeadCodeElimination(const Expr& e, bool inline_once) {
   return CalcDep::Eliminate(e, inline_once);
 }
 
-TVM_REGISTER_API("relay._ir_pass.dead_code_elimination")
-.set_body_typed(DeadCodeElimination);
-
 namespace transform {
 
 Pass DeadCodeElimination(bool inline_once) {
index b95c584..e7edbb3 100644 (file)
@@ -1086,27 +1086,30 @@ Expr PostProcess(const Expr& e) {
 
 }  // namespace partial_eval
 
-Expr PartialEval(const Expr& e, const Module& m) {
-  return TransformF([&](const Expr& e) {
+Module PartialEval(const Module& m) {
+  CHECK(m->entry_func.defined());
+  auto func = m->Lookup(m->entry_func);
+  Expr ret =
+    TransformF([&](const Expr& e) {
       return LetList::With([&](LetList* ll) {
-          relay::partial_eval::PartialEvaluator pe(FreeVars(e), m);
-          pe.InitializeFuncId(e);
-          return relay::partial_eval::PostProcess(pe.VisitExpr(e, ll)->dynamic);
-        });
-    }, e);
+        relay::partial_eval::PartialEvaluator pe(FreeVars(e), m);
+        pe.InitializeFuncId(e);
+        return relay::partial_eval::PostProcess(pe.VisitExpr(e, ll)->dynamic);
+      });
+    }, func);
+  CHECK(ret->is_type<FunctionNode>());
+  m->Update(m->entry_func, Downcast<Function>(ret));
+  return m;
 }
 
-TVM_REGISTER_API("relay._ir_pass.partial_evaluate")
-.set_body_typed(PartialEval);
-
 namespace transform {
 
 Pass PartialEval() {
-  runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func =
-    [=](Function f, Module m, PassContext pc) {
-    return Downcast<Function>(PartialEval(f, m));
+  runtime::TypedPackedFunc<Module(Module, PassContext)> pass_func =
+    [=](Module m, PassContext pc) {
+    return PartialEval(m);
   };
-  return CreateFunctionPass(pass_func, 1, "PartialEvaluate", {});
+  return CreateModulePass(pass_func, 1, "PartialEvaluate", {});
 }
 
 TVM_REGISTER_API("relay._transform.PartialEvaluate")
index d63d912..a620316 100644 (file)
@@ -573,6 +573,18 @@ class PassContext::Internal {
   }
 };
 
+Expr OptimizeOnExpr(const Expr& expr, const Array<Pass>& passes) {
+  auto mod = ModuleNode::FromExpr(expr);
+  Sequential seq(passes);
+  auto pass_ctx = PassContext::Create();
+  pass_ctx->opt_level = 3;
+  tvm::With<PassContext> ctx_scope(pass_ctx);
+  mod = seq(mod);
+  CHECK(mod.defined());
+  auto entry_func = mod->Lookup(mod->entry_func);
+  return expr.as<FunctionNode>() == nullptr ? entry_func->body : entry_func;
+}
+
 TVM_REGISTER_API("relay._transform.GetCurrentPassContext")
 .set_body_typed(PassContext::Current);
 
@@ -582,6 +594,9 @@ TVM_REGISTER_API("relay._transform.EnterPassContext")
 TVM_REGISTER_API("relay._transform.ExitPassContext")
 .set_body_typed(PassContext::Internal::ExitScope);
 
+TVM_REGISTER_API("relay._transform.OptimizeOnExpr")
+.set_body_typed(OptimizeOnExpr);
+
 }  // namespace transform
 }  // namespace relay
 }  // namespace tvm
index 324eddd..b5a3f85 100644 (file)
@@ -26,6 +26,8 @@
  */
 #include <tvm/relay/pass.h>
 #include <tvm/relay/expr_functor.h>
+#include <tvm/relay/transform.h>
+#include <tvm/relay/expr_functor.h>
 #include <tvm/logging.h>
 #include "let_list.h"
 #include "../../common/arena.h"
 namespace tvm {
 namespace relay {
 
-Expr ToANormalForm(const Expr& e,
-                   const Module& m,
-                   std::unordered_set<GlobalVar, NodeHash, NodeEqual>* gv);
-
 struct ScopeNode;
 using Scope = std::shared_ptr<ScopeNode>;
 
@@ -104,29 +102,21 @@ bool IsPrimitiveFunction(const Expr& e) {
 class Fill : ExprFunctor<Expr(const Expr&, const Var&)> {
  public:
   static Expr ToANormalForm(const Expr& e,
-                            const Module& m,
                             const DependencyGraph& dg,
-                            std::unordered_map<DependencyGraph::Node*, Scope>* node_scope,
-                            std::unordered_set<GlobalVar, NodeHash, NodeEqual>* gv) {
-    Fill fi(m, dg, node_scope, gv);
+                            std::unordered_map<DependencyGraph::Node*, Scope>* node_scope) {
+    Fill fi(dg, node_scope);
     return fi.GetScope(e)->ll->Get(fi.VisitExpr(e));
   }
 
  private:
-  Module mod_;
   const DependencyGraph& dg_;
   std::unordered_map<DependencyGraph::Node*, Scope>* node_scope_;
-  std::unordered_set<GlobalVar, NodeHash, NodeEqual>* visited_;
   std::unordered_map<Expr, Expr, NodeHash, NodeEqual> memo;
 
-  Fill(Module mod,
-       const DependencyGraph& dg,
-       std::unordered_map<DependencyGraph::Node*, Scope>* node_scope,
-       std::unordered_set<GlobalVar, NodeHash, NodeEqual>* visited) :
-    mod_(mod),
+  Fill(const DependencyGraph& dg,
+       std::unordered_map<DependencyGraph::Node*, Scope>* node_scope) :
     dg_(dg),
-    node_scope_(node_scope),
-    visited_(visited) { }
+    node_scope_(node_scope) { }
 
   Scope GetScope(const Expr& e) {
     return node_scope_->at(dg_.expr_node.at(e));
@@ -246,10 +236,6 @@ class Fill : ExprFunctor<Expr(const Expr&, const Var&)> {
 
   Expr VisitExpr_(const GlobalVarNode* gvn, const Var& v) final {
     GlobalVar gv = GetRef<GlobalVar>(gvn);
-    if (visited_->count(gv) == 0) {
-      visited_->insert(gv);
-      mod_->Update(gv, Downcast<Function>(relay::ToANormalForm(mod_->Lookup(gv), mod_, visited_)));
-    }
     return Atomic(gv, gv, v);
   }
 
@@ -276,9 +262,7 @@ class Fill : ExprFunctor<Expr(const Expr&, const Var&)> {
   }
 };
 
-Expr ToANormalFormAux(const Expr& e,
-                      const Module& m,
-                      std::unordered_set<GlobalVar, NodeHash, NodeEqual>* gv) {
+Expr ToANormalFormAux(const Expr& e) {
   /* When you lift a lambda, what is inside is also being lift.
    *
    * So we must determine the scope of the lambda before determining the scope of it's body.
@@ -301,46 +285,40 @@ Expr ToANormalFormAux(const Expr& e,
    * We do an additional pass to fill all the LetList and we are done.
    */
   std::unordered_map<DependencyGraph::Node*, Scope> node_scope = CalcScope(dg);
-  return Fill::ToANormalForm(e, m, dg, &node_scope, gv);
+  return Fill::ToANormalForm(e, dg, &node_scope);
 }
 
-Expr ToANormalForm(const Expr& e,
-                   const Module& m,
-                   std::unordered_set<GlobalVar, NodeHash, NodeEqual>* gv) {
-  DLOG(INFO)
-  << "ToANF:" << std::endl
-  << AsText(e, false);
-
-  Expr ret =
-    TransformF([&](const Expr& e) {
-      return ToANormalFormAux(e, m, gv);
-    }, e);
-
-  CHECK_EQ(FreeVars(ret).size(), 0);
+Module ToANormalForm(const Module& m) {
+  DLOG(INFO) << "ToANF:" << std::endl << m;
+
+  tvm::Map<GlobalVar, Function> updates;
+  auto funcs = m->functions;
+  for (const auto& it : funcs) {
+    Expr ret =
+      TransformF([&](const Expr& e) {
+        return ToANormalFormAux(e);
+      }, it.second);
+    CHECK_EQ(FreeVars(ret).size(), 0);
+    updates.Set(it.first, Downcast<Function>(ret));
+  }
 
-  DLOG(INFO)
-    << "ToANF: transformed" << std::endl
-    << AsText(ret, false);
+  for (auto pair : updates) {
+    m->Add(pair.first, pair.second, true);
+  }
 
-  return ret;
-}
+  DLOG(INFO) << "ToANF: transformed" << std::endl << m;
 
-Expr ToANormalForm(const Expr& e, const Module& m) {
-  std::unordered_set<GlobalVar, NodeHash, NodeEqual> gv;
-  return ToANormalForm(e, m, &gv);
+  return m;
 }
 
-TVM_REGISTER_API("relay._ir_pass.to_a_normal_form")
-.set_body_typed(static_cast<Expr (*)(const Expr&, const Module&)>(ToANormalForm));
-
 namespace transform {
 
 Pass ToANormalForm() {
-  runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func =
-    [=](Function f, Module m, PassContext pc) {
-    return Downcast<Function>(ToANormalForm(f, m));
+  runtime::TypedPackedFunc<Module(Module, PassContext)> pass_func =
+    [=](Module m, PassContext pc) {
+    return ToANormalForm(m);
   };
-  return CreateFunctionPass(pass_func, 1, "ToANormalForm", {});
+  return CreateModulePass(pass_func, 1, "ToANormalForm", {});
 }
 
 TVM_REGISTER_API("relay._transform.ToANormalForm")
index 9c166f9..c1ae19e 100644 (file)
@@ -24,8 +24,8 @@
  *
  * \brief Turn A normal form into graph normal form.
  */
-#include <tvm/relay/pass.h>
 #include <tvm/relay/expr_functor.h>
+#include <tvm/relay/transform.h>
 #include "let_list.h"
 
 namespace tvm {
@@ -76,9 +76,6 @@ Expr ToGraphNormalForm(const Expr& e) {
   return GNF()(e);
 }
 
-TVM_REGISTER_API("relay._ir_pass.to_graph_normal_form")
-.set_body_typed(ToGraphNormalForm);
-
 namespace transform {
 
 Pass ToGraphNormalForm() {
index 9158f07..c3b12fe 100644 (file)
@@ -18,20 +18,13 @@ from nose.tools import nottest
 
 import tvm
 from tvm import relay
-from tvm.relay.ir_pass import dead_code_elimination, alpha_equal
+from tvm.relay import Function, transform
+from tvm.relay.ir_pass import alpha_equal, graph_equal, free_vars
 from tvm.relay.op import log, add, equal, subtract
 
 
 class env:
     def __init__(self):
-        self.a = relay.Var("a")
-        self.b = relay.Var("b")
-        self.c = relay.Var("c")
-        self.d = relay.Var("d")
-        self.e = relay.Var("e")
-        self.x = relay.Var("x")
-        self.y = relay.Var("y")
-        self.z = relay.Var("z")
         self.shape = tvm.convert([1, 2, 3])
         self.tt = relay.TensorType(self.shape, "float32")
         self.int32 = relay.TensorType([], "int32")
@@ -39,6 +32,14 @@ class env:
         self.one = relay.const(1.0)
         self.two = relay.const(2.0)
         self.three = relay.const(3.0)
+        self.a = relay.Var("a", self.float32)
+        self.b = relay.Var("b", self.float32)
+        self.c = relay.Var("c", self.float32)
+        self.d = relay.Var("d", self.float32)
+        self.e = relay.Var("e", self.float32)
+        self.x = relay.Var("x", self.int32)
+        self.y = relay.Var("y", self.int32)
+        self.z = relay.Var("z", self.int32)
 
 
 e = env()
@@ -46,22 +47,27 @@ e = env()
 
 def test_let():
     orig = relay.Let(e.x, e.y, e.z)
-    assert alpha_equal(dead_code_elimination(orig), e.z)
+    orig = transform.OptimizeOnExpr(orig, transform.DeadCodeElimination())
+    assert alpha_equal(Function(free_vars(orig), orig), Function([e.z], e.z))
 
 
 def test_used_let():
     orig = relay.Let(e.c, e.one, e.c + e.c)
-    assert alpha_equal(dead_code_elimination(orig), relay.Let(e.c, e.one, e.c + e.c))
+    orig = transform.OptimizeOnExpr(orig, transform.DeadCodeElimination())
+    expected = relay.Let(e.c, e.one, e.c + e.c)
+    assert alpha_equal(Function([e.c], orig), Function([e.c], expected))
 
 @nottest
 def test_inline():
     orig = relay.Let(e.a, e.b, relay.Let(e.c, e.d, e.c))
-    assert alpha_equal(dead_code_elimination(orig), e.d)
+    orig = transform.OptimizeOnExpr(orig, transform.DeadCodeElimination())
+    assert alpha_equal(Function(free_vars(orig), orig), Function([e.d], e.d))
 
 
 def test_chain_unused_let():
     orig = relay.Let(e.a, e.b, relay.Let(e.c, e.d, e.e))
-    assert alpha_equal(dead_code_elimination(orig), e.e)
+    orig = transform.OptimizeOnExpr(orig, transform.DeadCodeElimination())
+    assert alpha_equal(Function(free_vars(orig), orig), Function([e.e], e.e))
 
 
 # make sure we dont infinite loop
@@ -78,27 +84,39 @@ def test_recursion():
        f(2, 10000);
     """
     f = relay.Var("f")
+    f1 = relay.Var("f1")
     n = relay.Var("n", e.int32)
     data = relay.Var("data", e.float32)
     funcbody = relay.If(equal(n, relay.const(0)),
                         data,
-                        relay.Call(f, [subtract(n, relay.const(1.0)),
+                        relay.Call(f1, [subtract(n, relay.const(1)),
                                        log(data)]))
     value = relay.Function([n, data], funcbody, e.float32, [])
-    orig = relay.Let(f, value, relay.Call(f, [relay.const(2.0), relay.const(10000.0)]))
-    assert alpha_equal(dead_code_elimination(orig), orig)
-    assert alpha_equal(dead_code_elimination(relay.Let(f, value, e.three)), e.three)
+    orig = relay.Let(f, value, relay.Call(f, [relay.const(2), relay.const(10000.0)]))
+    dced = transform.OptimizeOnExpr(orig, transform.DeadCodeElimination())
+    orig = transform.OptimizeOnExpr(orig, transform.InferType())
+    assert graph_equal(dced, orig)
+    dced = transform.OptimizeOnExpr(relay.Let(f, value, e.three),
+                                    transform.DeadCodeElimination())
+    assert alpha_equal(dced, e.three)
 
 
 def test_op_let():
-    assert alpha_equal(dead_code_elimination(add(relay.Let(e.a, e.one, e.three), e.two)), add(e.three, e.two))
+    dced = transform.OptimizeOnExpr(add(relay.Let(e.a, e.one, e.three), e.two),
+                                   transform.DeadCodeElimination())
+    assert alpha_equal(dced, add(e.three, e.two))
 
 
 def test_tuple_get_item():
-    t = relay.Var('t')
+    tt = relay.TupleType([e.float32, e.float32])
+    t = relay.Var('t', tt)
+    a = relay.Var('a')
     g = relay.TupleGetItem(t, 0)
-    assert alpha_equal(dead_code_elimination(g), g)
-    assert alpha_equal(dead_code_elimination(relay.TupleGetItem(relay.Let(e.a, e.one, t), 0)), g)
+    dced = transform.OptimizeOnExpr(g, transform.DeadCodeElimination())
+    assert alpha_equal(Function(free_vars(dced), dced), Function(free_vars(g), g))
+    orig = relay.TupleGetItem(relay.Let(a, e.one, t), 0)
+    dced = transform.OptimizeOnExpr(orig, transform.DeadCodeElimination())
+    assert alpha_equal(Function(free_vars(dced), dced), Function(free_vars(g), g))
 
 
 if __name__ == "__main__":
index b3c0c28..f2aedd1 100644 (file)
 import numpy as np
 import tvm
 from tvm import relay
-from tvm.relay.ir_pass import partial_evaluate, alpha_equal, infer_type, dead_code_elimination
-from tvm.relay.ir_pass import gradient
-from tvm.relay import op, create_executor
-from tvm.relay.backend.interpreter import Value, TupleValue, ConstructorValue
+from tvm.relay.ir_pass import alpha_equal, gradient
 from tvm.relay.prelude import Prelude
-from tvm.relay import create_executor
-from nose.tools import nottest
+from tvm.relay import op, create_executor, transform
 from tvm.relay import Var, TypeVar, TupleGetItem, Let, Function, const, RefRead, RefWrite, RefCreate
 from tvm.relay import TensorType, Tuple, If, Module, Clause, PatternConstructor, PatternVar, Match
-from tvm.relay import GlobalVar, Call, Type
-from tvm.relay.testing import add_nat_definitions, count, make_nat_value, make_nat_expr
+from tvm.relay import GlobalVar, Call
+from tvm.relay.testing import add_nat_definitions, make_nat_expr
 
 def check_eval(expr, expected_result, mod=None, rtol=1e-07):
     ctx = tvm.context("llvm", 0)
@@ -38,8 +34,25 @@ def check_eval(expr, expected_result, mod=None, rtol=1e-07):
     np.testing.assert_allclose(result.asnumpy(), expected_result, rtol=rtol)
 
 
-def dcpe(expr, mod=None):
-    return dead_code_elimination(partial_evaluate(expr, mod=mod), inline_once=True)
+def tipe(expr):
+    return transform.OptimizeOnExpr(expr,
+                                    [transform.InferType(),
+                                     transform.PartialEvaluate(),
+                                     transform.InferType()])
+
+
+def dcpe(expr, mod=None, grad=False):
+    passes = [transform.PartialEvaluate(),
+              transform.DeadCodeElimination(inline_once=True)]
+    if grad:
+        expr = gradient(expr)
+    if mod:
+        assert isinstance(expr, Function)
+        mod[mod.entry_func] = expr
+        seq = transform.Sequential(passes)
+        mod = seq(mod)
+        return mod[mod.entry_func]
+    return transform.OptimizeOnExpr(expr, passes)
 
 
 def test_tuple():
@@ -47,24 +60,31 @@ def test_tuple():
     x = Var("x", t)
     body = TupleGetItem(relay.Tuple([relay.const(4.0), x]), 1)
     f = Function([x], body, None, [t])
-    assert alpha_equal(dcpe(f), relay.Function([x], x, None, [t]))
+    expected = relay.Function([x], x, None, [t])
+    expected = transform.OptimizeOnExpr(expected, transform.InferType())
+    assert alpha_equal(dcpe(f), expected)
+
 
 def test_const_inline():
-    d = Var("d")
+    t = relay.TensorType([], "float32")
+    d = Var("d", t)
     double = Function([d], d + d)
     orig = double(const(4.0))
     assert alpha_equal(dcpe(orig), const(8.0))
 
 
 def test_ref():
-    d = relay.Var("d")
-    r = relay.Var("r")
+    t = relay.TensorType([], "float32")
+    d = relay.Var("d", t)
+    r = relay.Var("r", relay.RefType(t))
     x = relay.Var("x")
     body = relay.RefRead(r)
     body = Let(x, RefWrite(r, RefRead(r) * RefRead(r)), body)
     body = Let(r, RefCreate(d), body)
     square = Function([d], body)
-    assert alpha_equal(dcpe(square), Function([d], d * d))
+    expected = transform.OptimizeOnExpr(Function([d], d * d),
+                                        transform.InferType())
+    assert alpha_equal(dcpe(square), expected)
 
 
 def test_empty_ad():
@@ -73,17 +93,19 @@ def test_empty_ad():
     t = TensorType(shape, dtype)
     d = Var("d", t)
     f = Function([d], d)
-    g = dcpe(gradient(f))
+    g = dcpe(f, grad=True)
     expected = Function([d], Tuple([d, Tuple([op.ones_like(d)])]))
+    expected = transform.OptimizeOnExpr(expected, transform.InferType())
     assert alpha_equal(g, expected)
 
+
 def test_ad():
     shape = (10, 10)
     dtype = "float32"
     t = TensorType(shape, dtype)
     d = Var("d", t)
     f = Function([d], d * d)
-    g = dcpe(gradient(f))
+    g = dcpe(f, grad=True)
     m = d * d
     x = relay.Var("x")
     o = op.ones_like(x)
@@ -92,6 +114,7 @@ def test_ad():
     body = Tuple([x, Tuple([grad])])
     body = relay.Let(x1, o, body)
     expected = Function([d], relay.Let(x, m, body))
+    expected = transform.OptimizeOnExpr(expected, transform.InferType())
     assert alpha_equal(g, expected)
 
 
@@ -107,8 +130,7 @@ def test_if_ref():
     eff = Var("eff")
     body = Let(eff, body, RefRead(r))
     f = Function([d], Let(r, RefCreate(const(1)), Let(u, update, body)))
-    f = infer_type(f)
-    pe_f = infer_type(partial_evaluate(f))
+    pe_f = tipe(f)
     ex = create_executor()
     f_res = ex.evaluate(f)(const(True))
     pe_f_res = ex.evaluate(pe_f)(const(True))
@@ -132,8 +154,7 @@ def test_function_invalidate():
     body = Let(fet, fetch, body)
     body = Let(r, RefCreate(const(0)), body)
     f = Function([d], body)
-    f = infer_type(f)
-    pe_f = infer_type(partial_evaluate(f))
+    pe_f = tipe(f)
     ex = create_executor()
     f_res = ex.evaluate(f)(const(True))
     pe_f_res = ex.evaluate(pe_f)(const(True))
@@ -144,35 +165,30 @@ def test_function_invalidate():
 def test_head_cons():
     mod = Module()
     p = Prelude(mod)
-    def hd_impl():
-        a = TypeVar("a")
-        x = Var("x", p.l(a))
-        y = Var("y")
-        z = Var("z")
-        cons_case = Clause(PatternConstructor(p.cons,
-                                              [PatternVar(y),
-                                               PatternVar(z)]),
-                           y)
-        y = Var("y")
-        z = Var("z")
-        return Function([x], Match(x, [cons_case]), a, [a])
+    hd = p.hd
     t = TypeVar("t")
     x = Var("x", t)
-    hd = Var("hd")
-    body = Let(hd, hd_impl(), hd(p.cons(x, p.nil())))
+    body = hd(p.cons(x, p.nil()))
     f = Function([x], body, None, [t])
-    f = infer_type(f, mod=mod)
-    res = dcpe(f)
+    res = dcpe(f, mod)
     assert alpha_equal(res, Function([x], x, t, [t]))
 
 
 def test_map():
     mod = Module()
     p = Prelude(mod)
-    f = Var("f")
+    f = GlobalVar("f")
+    t = TypeVar("t")
+    a = Var("a", t)
+    mod[f] = Function([a], a, t, [t])
     orig = p.map(f, p.cons(const(1), p.cons(const(2), p.cons(const(3), p.nil()))))
-    expected = p.cons(f(const(1)), p.cons(f(const(2)), p.cons(f(const(3)), p.nil())))
-    assert alpha_equal(dcpe(orig, mod=mod), expected)
+    expected = p.cons((const(1)), p.cons((const(2)), p.cons((const(3)), p.nil())))
+    expected = Function([], expected)
+    mod[mod.entry_func] = expected
+    expected = mod[mod.entry_func]
+    orig = Function([], orig)
+    res = dcpe(orig, mod=mod)
+    assert alpha_equal(res.body, expected.body)
 
 
 def test_loop():
@@ -181,9 +197,12 @@ def test_loop():
     x = Var("x", t)
     loop = GlobalVar("loop")
     mod[loop] = Function([x], loop(x), t, [t])
-    res = dcpe(loop(const(1)), mod=mod)
-    expected = Call(loop, [const(1)], None, [None])
-    assert alpha_equal(res, expected)
+    expected = Call(loop, [const(1)])
+    mod[mod.entry_func] = Function([], expected)
+    expected = mod[mod.entry_func].body
+    call = Function([], loop(const(1)))
+    res = dcpe(call, mod=mod)
+    assert alpha_equal(res.body, expected)
 
 
 def test_swap_loop():
@@ -196,8 +215,9 @@ def test_swap_loop():
     loop = GlobalVar("loop")
     mod[loop] = Function([x, y], loop(y, x), nat)
     prog = loop(make_nat_expr(p, 1), make_nat_expr(p, 2))
-    res = dcpe(prog, mod=mod)
-    assert alpha_equal(prog, res)
+    res = Function([], prog)
+    res = dcpe(res, mod=mod)
+    assert alpha_equal(prog, res.body)
 
 
 def test_abs_diff():
@@ -217,8 +237,9 @@ def test_abs_diff():
     x_s_case = Clause(PatternConstructor(p.s, [PatternVar(xp)]), Match(y, [y_z_case, y_s_case]))
     mod[diff] = Function([x, y], Match(x, [x_z_case, x_s_case]))
     orig = diff(make_nat_expr(p, 7), make_nat_expr(p, 3))
+    orig = Function([], orig)
     res = dcpe(orig, mod=mod)
-    assert alpha_equal(res, make_nat_expr(p, 4))
+    assert alpha_equal(res.body, make_nat_expr(p, 4))
 
 
 def test_match_nat_id():
@@ -233,8 +254,9 @@ def test_match_nat_id():
     s_case = Clause(PatternConstructor(p.s, [PatternVar(y)]), p.s(y))
     mod[nat_id] = Function([x], Match(x, [z_case, s_case]))
     orig = nat_id(make_nat_expr(p, 3))
+    orig = Function([], orig)
     res = dcpe(orig, mod=mod)
-    assert alpha_equal(res, make_nat_expr(p, 3))
+    assert alpha_equal(res.body, make_nat_expr(p, 3))
 
 
 def test_nat_id():
@@ -247,8 +269,9 @@ def test_nat_id():
     nat_id = GlobalVar("nat_id")
     mod[nat_id] = Function([x], x)
     orig = nat_id(make_nat_expr(p, 3))
+    orig = Function([], orig)
     res = dcpe(orig, mod=mod)
-    assert alpha_equal(res, make_nat_expr(p, 3))
+    assert alpha_equal(res.body, make_nat_expr(p, 3))
 
 
 def test_global_match_nat_id():
@@ -260,8 +283,9 @@ def test_global_match_nat_id():
     z_case = Clause(PatternConstructor(p.z, []), p.z())
     s_case = Clause(PatternConstructor(p.s, [PatternVar(x)]), p.s(x))
     orig = Match(make_nat_expr(p, 3), [z_case, s_case])
+    orig = Function([], orig)
     res = dcpe(orig, mod=mod)
-    assert alpha_equal(res, make_nat_expr(p, 3))
+    assert alpha_equal(res.body, make_nat_expr(p, 3))
 
 
 def test_double():
@@ -269,8 +293,9 @@ def test_double():
     p = Prelude(mod)
     add_nat_definitions(p)
     orig = p.double(make_nat_expr(p, 3))
+    orig = Function([], orig)
     res = dcpe(orig, mod=mod)
-    assert alpha_equal(res, make_nat_expr(p, 6))
+    assert alpha_equal(res.body, make_nat_expr(p, 6))
 
 
 if __name__ == '__main__':
index 9a2570e..e741681 100644 (file)
@@ -17,9 +17,8 @@
 import numpy as np
 import tvm
 from tvm import relay
-from tvm.relay.ir_pass import to_a_normal_form, alpha_equal, infer_type, detect_feature
-from tvm.relay import op, create_executor
-from tvm.relay.backend.interpreter import Value, TupleValue, ConstructorValue
+from tvm.relay.ir_pass import alpha_equal, detect_feature
+from tvm.relay import op, create_executor, transform
 from tvm.relay.prelude import Prelude
 from tvm.relay.testing import add_nat_definitions, count
 from tvm.relay.feature import Feature
@@ -39,7 +38,7 @@ def test_explicit_bound():
     z = op.add(y, y)
     f = relay.Function([], op.add(z, z))
     assert not Feature.fLet in detect_feature(f)
-    anf = to_a_normal_form(f)
+    anf = transform.OptimizeOnExpr(f, transform.ToANormalForm())
     assert Feature.fLet in detect_feature(anf)
     check_eval(f(), 8.0)
     check_eval(anf(), 8.0)
@@ -53,7 +52,8 @@ def test_order():
     x = relay.const(1)
     val = x + y * z
     check_eval(val, 7.0)
-    anf = infer_type(to_a_normal_form(val))
+    anf = transform.OptimizeOnExpr(val, [transform.ToANormalForm(),
+                                         transform.InferType()])
     a = relay.Var('a', relay.IncompleteType())
     b = relay.Var('b', relay.IncompleteType())
     c = relay.Var('c', relay.IncompleteType())
@@ -65,14 +65,16 @@ def test_order():
     expected_output = relay.Let(c, z, expected_output)
     expected_output = relay.Let(b, y, expected_output)
     expected_output = relay.Let(a, x, expected_output)
-    expected_output = infer_type(expected_output)
+    expected_output = transform.OptimizeOnExpr(expected_output,
+                                               transform.InferType())
     assert alpha_equal(anf, expected_output)
 
 
 def test_if():
     cond = relay.const(True)
     x = relay.If(cond, relay.const(2), relay.const(3))
-    anf = infer_type(to_a_normal_form(x))
+    anf = transform.OptimizeOnExpr(x, [transform.ToANormalForm(),
+                                       transform.InferType()])
     a = relay.Var('a', relay.IncompleteType())
     b = relay.Var('b', relay.IncompleteType())
     c = relay.Var('c', relay.IncompleteType())
@@ -82,7 +84,8 @@ def test_if():
     expected_output = relay.If(c, true_branch, false_branch)
     expected_output = relay.Let(d, expected_output, d)
     expected_output = relay.Let(c, cond, expected_output)
-    expected_output = infer_type(expected_output)
+    expected_output = transform.OptimizeOnExpr(expected_output,
+                                               transform.InferType())
     assert alpha_equal(anf, expected_output)
 
 
@@ -114,7 +117,8 @@ def test_recursion():
     mod[f] = value
     check_eval(f(relay.const(5, 'int64')), 30.0, mod=mod)
     old_f = mod[f]
-    f = to_a_normal_form(f, mod=mod)
+    mod = transform.ToANormalForm()(mod)
+    f = mod[f]
     check_eval(f(relay.const(5, 'int64')), 30.0, mod=mod)
 
 
@@ -129,7 +133,8 @@ def test_ref():
     body = relay.Let(iv, relay.RefRead(i), body)
     body = relay.Let(i, relay.RefCreate(relay.const(1)), body)
     check_eval(body, 3)
-    check_eval(to_a_normal_form(body), 3)
+    opt_body = transform.OptimizeOnExpr(body, transform.ToANormalForm())
+    check_eval(opt_body, 3)
 
 
 def test_nat_add():
@@ -144,7 +149,12 @@ def test_nat_add():
     intrp = create_executor(mod=mod, ctx=ctx, target="llvm")
     assert mod[add].checked_type == relay.FuncType([nat(), nat()], nat())
     assert count(p, intrp.evaluate(add(s(z()), s(z())))) == 2
-    assert count(p, intrp.evaluate(to_a_normal_form(add(s(z()), s(z())), mod))) == 2
+    expr = add(s(z()), s(z()))
+    f = relay.GlobalVar("f")
+    mod[f] = relay.Function([], expr)
+    mod = transform.ToANormalForm()(mod)
+    expr = mod["f"]
+    assert count(p, intrp.evaluate(expr.body)) == 2
     assert Feature.fLet in detect_feature(mod[add])
 
 
@@ -155,14 +165,16 @@ def test_let():
     body = relay.Let(y, x, x + y)
     body = relay.Let(x, d, body)
     check_eval(body, 8)
-    check_eval(to_a_normal_form(body), 8)
+    opt_body = transform.OptimizeOnExpr(body, transform.ToANormalForm())
+    check_eval(opt_body, 8)
 
 
 def test_function():
-    x = relay.Var("x")
+    t = relay.TensorType((), 'float32')
+    x = relay.Var("x", t)
     f = relay.Function([x], x + x)
     d = relay.const(4.0, 'float32')
-    anf_f = to_a_normal_form(f)
+    anf_f = transform.OptimizeOnExpr(f, transform.ToANormalForm())
     assert isinstance(anf_f, relay.Function)
     check_eval(f(d), 8)
     check_eval(anf_f(d), 8)
index 6d9bd6a..09db48f 100644 (file)
 import numpy as np
 import tvm
 from tvm import relay
-from tvm.relay.ir_pass import to_graph_normal_form, to_a_normal_form, alpha_equal, detect_feature
-from tvm.relay import op, create_executor
+from tvm.relay import op, create_executor, transform
+from tvm.relay.ir_pass import detect_feature
 from tvm.relay.feature import Feature
-from tvm.relay.backend.interpreter import Value, TupleValue
 
 
 def check_eval(expr, args, expected_result, mod=None, rtol=1e-07):
@@ -41,9 +40,9 @@ def test_implicit_share():
     body = relay.Let(z, op.add(y, y), op.add(z, z))
     body = relay.Let(y, op.add(x, x), body)
     f = relay.Function([], relay.Let(x, relay.const(1), body))
-    g = to_graph_normal_form(f)
-    assert "let" in f.astext()
-    assert not "let" in g.astext()
+    g = transform.OptimizeOnExpr(f, transform.ToGraphNormalForm())
+    assert Feature.fLet in detect_feature(f)
+    assert not Feature.fLet in detect_feature(g)
     check_eval(f, [], 8.0)
     check_eval(g, [], 8.0)
 
@@ -55,8 +54,8 @@ def test_round_trip():
     body = relay.Let(z, op.add(y, y), op.add(z, z))
     body = relay.Let(y, op.add(x, x), body)
     f = relay.Function([], relay.Let(x, relay.const(1), body))
-    g = to_graph_normal_form(f)
-    h = to_a_normal_form(g)
+    g = transform.OptimizeOnExpr(f, transform.ToGraphNormalForm())
+    h = transform.OptimizeOnExpr(g, transform.ToANormalForm())
     assert Feature.fLet in detect_feature(f)
     assert not Feature.fLet in detect_feature(g)
     check_eval(f, [], 8.0)