save (#3033)
author雾雨魔理沙 <lolisa@marisa.moe>
Sat, 15 Jun 2019 22:08:46 +0000 (15:08 -0700)
committerJared Roesch <roeschinc@gmail.com>
Sat, 15 Jun 2019 22:08:46 +0000 (15:08 -0700)
save

save

save

upstream

lint

remove bad changes

fix build

save

save

please the ci god

Update src/relay/pass/partial_eval.cc

Co-Authored-By: Wei Chen <ipondering.weic@gmail.com>
save

fix test

ci is ANGRY

fix rebase problem

fix rebase

add test

save

save

comment

include/tvm/relay/pass.h
include/tvm/relay/transform.h
python/tvm/relay/ir_pass.py
src/relay/ir/expr.cc
src/relay/pass/dead_code.cc
src/relay/pass/partial_eval.cc
tests/python/relay/test_pass_partial_eval.py

index 977bb67..fff630f 100644 (file)
@@ -296,13 +296,15 @@ TVM_DLL tvm::Array<TypeVar> AllTypeVars(const Type& t, const Module& mod);
  * For example, this pass should turn `let a = 1 in 2` into `2`,
  * as the value of the expression does not depend on a.
  *
- * As another example, `let a = 1 in a` will be optimized into 1.
+ * As another example, `let a = 1 in a` will be optimized into 1,
+ * if the flag is turned on.
  *
  * \param e the expression to optimize.
+ * \param inline_once whether or not to inline binding used one.
  *
  * \return the optimized expression.
  */
-TVM_DLL Expr DeadCodeElimination(const Expr& e);
+TVM_DLL Expr DeadCodeElimination(const Expr& e, bool inline_once = false);
 
 /*!
  * \brief Fold constant expressions.
@@ -435,11 +437,12 @@ TVM_DLL Array<Pattern> UnmatchedCases(const Match& match, const Module& mod);
  * It has two benefit: remove runtime overhead, and allow more optimization (typically fusion).
  * As a side effect, code size will explode.
  *
- * \param e the expression,
+ * \param e the expression
+ * \param mod the module
  *
  * \return the optimized expression.
  */
-TVM_DLL Expr PartialEval(const Expr& e);
+TVM_DLL Expr PartialEval(const Expr& e, const Module& mod);
 
 /*!
  * \brief Bind the free variables to a Relay expression.
index f579f1c..fb8ebbf 100644 (file)
@@ -356,9 +356,11 @@ TVM_DLL Pass CreateFunctionPass(const runtime::TypedPackedFunc<
  *
  * As another example, `let a = 1 in a` will be optimized into 1.
  *
+ * \param inline_once whether or not to inline binding used one.
+ *
  * \return the pass.
  */
-TVM_DLL Pass DeadCodeElimination();
+TVM_DLL Pass DeadCodeElimination(bool inline_once = false);
 
 /*!
  * \brief Fold constant expressions.
index 8f1cede..dd0f54c 100644 (file)
@@ -129,7 +129,7 @@ def well_formed(expr):
 
     Parameters
     ----------
-    expr: tvm.relay.Expr
+    expr : tvm.relay.Expr
         The input expression
 
     Returns
@@ -175,7 +175,7 @@ def free_vars(expr):
 
     Parameters
     ----------
-    expr: tvm.relay.Expr
+    expr : tvm.relay.Expr
         The input expression
 
     Returns
@@ -197,7 +197,7 @@ def bound_vars(expr):
 
     Parameters
     ----------
-    expr: tvm.relay.Expr
+    expr : tvm.relay.Expr
         The input expression
 
     Returns
@@ -213,7 +213,7 @@ def all_vars(expr):
 
     Parameters
     ----------
-    expr: tvm.relay.Expr
+    expr : tvm.relay.Expr
         The input expression
 
     Returns
@@ -229,9 +229,10 @@ def free_type_vars(expr, mod=None):
 
     Parameters
     ----------
-    expr: Union[tvm.relay.Expr,tvm.relay.Type]
+    expr : Union[tvm.relay.Expr,tvm.relay.Type]
         The input expression/type
-    mod: tvm.relay.Module, optional
+
+    mod : Optional[tvm.relay.Module]
         The global module
 
     Returns
@@ -248,9 +249,10 @@ def bound_type_vars(expr, mod=None):
 
     Parameters
     ----------
-    expr: Union[tvm.relay.Expr,tvm.relay.Type]
+    expr : Union[tvm.relay.Expr,tvm.relay.Type]
         The input expression/type
-    mod: tvm.relay.Module, optional
+
+    mod : Optional[tvm.relay.Module]
         The global module
 
     Returns
@@ -267,9 +269,9 @@ def all_type_vars(expr, mod=None):
 
     Parameters
     ----------
-    expr: Union[tvm.relay.Expr,tvm.relay.Type]
+    expr : Union[tvm.relay.Expr,tvm.relay.Type]
         The input expression/type
-    mod: tvm.relay.Module, optional
+    mod : Optional[tvm.relay.Module]
         The global module
 
     Returns
@@ -286,12 +288,12 @@ def simplify_inference(expr):
 
     Parameters
     ----------
-    e: tvm.relay.Expr
+    expr : tvm.relay.Expr
         The input Expression
 
     Returns
     -------
-    result: tvm.relay.Expr
+    result : tvm.relay.Expr
         An expression which is semantically equal to the input expression,
         but with some simplification
     """
@@ -304,32 +306,34 @@ def canonicalize_ops(expr):
 
     Parameters
     ----------
-    e: tvm.relay.Expr
+    expr : tvm.relay.Expr
         The input Expression
 
     Returns
     -------
-    result: tvm.relay.Expr
+    result : tvm.relay.Expr
         An expression without bias_add
     """
     return _ir_pass.canonicalize_ops(expr)
 
 
-def dead_code_elimination(expr):
+def dead_code_elimination(expr, inline_once=False):
     """ Remove expressions which does not effect the program result (dead code).
 
     Parameters
     ----------
-    e: tvm.relay.Expr
+    expr : tvm.relay.Expr
         The input Expression
 
+    inline_once : Optional[Bool]
+        Whether to inline binding that occur only once.
     Returns
     -------
-    result: tvm.relay.Expr
+    result : tvm.relay.Expr
         An expression which is semantically equal to the input expression,
         but with dead code removed.
     """
-    return _ir_pass.dead_code_elimination(expr)
+    return _ir_pass.dead_code_elimination(expr, inline_once)
 
 
 def alpha_equal(lhs, rhs):
@@ -337,15 +341,15 @@ def alpha_equal(lhs, rhs):
 
     Parameters
     ----------
-    lhs: tvm.relay.Expr
+    lhs : tvm.relay.Expr
         One of the input Expression.
 
-    rhs: tvm.relay.Expr
+    rhs : tvm.relay.Expr
         One of the input Expression.
 
     Returns
     -------
-    result: bool
+    result : bool
         True iff lhs is alpha equal to rhs.
     """
     return bool(_make._alpha_equal(lhs, rhs))
@@ -359,15 +363,15 @@ def graph_equal(lhs, rhs):
 
     Parameters
     ----------
-    lhs: tvm.relay.Expr
+    lhs : tvm.relay.Expr
       One of the input Expression.
 
-    rhs: tvm.relay.Expr
+    rhs : tvm.relay.Expr
       One of the input Expression.
 
     Returns
     -------
-    result: bool
+    result : bool
       True iff lhs is data-flow equivalent to rhs.
     """
     return bool(_make._graph_equal(lhs, rhs))
@@ -378,12 +382,12 @@ def structural_hash(value):
 
     Parameters
     ----------
-    expr: tvm.relay.Expr or tvm.relay.Type
+    expr : Union[tvm.relay.Expr, tvm.relay.Type]
       The expression to hash.
 
     Returns
     -------
-    result: int
+    result : int
       The hash value
     """
     if isinstance(value, Expr):
@@ -544,12 +548,12 @@ def to_a_normal_form(expr, mod=None):
     expr : tvm.relay.Expr
         The input expression.
 
-    mod: Optional[tvm.relay.Module]
+    mod : Optional[tvm.relay.Module]
         The global module.
 
     Returns
     -------
-    expr: tvm.relay.Expr
+    result : tvm.relay.Expr
       The output expression.
     """
     return _ir_pass.to_a_normal_form(expr, mod)
@@ -563,7 +567,7 @@ def to_graph_normal_form(expr):
         The input expression
     Returns
     -------
-    expr : tvm.relay.Expr
+    result : tvm.relay.Expr
       The output expression
     """
     return _ir_pass.to_graph_normal_form(expr)
@@ -612,7 +616,7 @@ def get_total_mac_number(expr):
 
     Returns
     -------
-    ret : int64
+    result : int64
       The number of MACs (multiply-accumulate) of a model
     """
     return _ir_pass.GetTotalMacNumber(expr)
@@ -627,17 +631,17 @@ def eliminate_common_subexpr(expr, fskip=None):
     expr : tvm.relay.Expr
         The input expression.
 
-    fskip: function
+    fskip : function
         The callback function that decides whether an expression should be skipped.
 
     Returns
     -------
-    expr : tvm.relay.Expr
+    result : tvm.relay.Expr
       The output expression.
     """
     return _ir_pass.eliminate_common_subexpr(expr, fskip)
 
-def partial_evaluate(expr):
+def partial_evaluate(expr, mod=None):
     """
     Evaluate the static fragment of the code.
 
@@ -646,12 +650,15 @@ def partial_evaluate(expr):
     expr : tvm.relay.Expr
         The input expression.
 
+    mod : Optional[tvm.relay.Module]
+        The global module
+
     Returns
     -------
-    expr : tvm.relay.Expr
+    result : tvm.relay.Expr
       The output expression.
     """
-    return _ir_pass.partial_evaluate(expr)
+    return _ir_pass.partial_evaluate(expr, mod)
 
 def unmatched_cases(match, mod=None):
     """
index 6470693..e0ec10a 100644 (file)
@@ -220,8 +220,8 @@ TVM_REGISTER_API("relay._make.Call")
 
 TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
 .set_dispatch<CallNode>([](const CallNode* node, tvm::IRPrinter* p) {
-  p->stream << "CallNode(" << node->op << ", " << node->args << ", "
-    << node->attrs << ", " << node->type_args << ")";
+    p->stream << "CallNode(" << node->op << ", " << node->args << ", "
+              << node->attrs << ", " << node->type_args << ")";
 });
 
 Let LetNode::make(Var var, Expr value, Expr body) {
@@ -324,7 +324,7 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
 
 TVM_REGISTER_API("relay._expr.TempExprRealize")
 .set_body_typed<Expr(TempExpr)>([](TempExpr temp) {
-  return temp->Realize();
+    return temp->Realize();
 });
 
 }  // namespace relay
index be67745..7e186f8 100644 (file)
@@ -38,10 +38,10 @@ namespace relay {
 // calculate the dependency graph from expression
 class CalcDep : private ExprVisitor {
  public:
-  static Expr Eliminate(const Expr& e) {
+  static Expr Eliminate(const Expr& e, bool inline_once) {
     CalcDep cd;
     cd.Calculate(e);
-    Eliminator el(cd.expr_map_, cd.use_map_, cd.letrec_set_);
+    Eliminator el(cd.expr_map_, cd.use_map_, cd.letrec_set_, inline_once);
     return el(e);
   }
 
@@ -117,15 +117,23 @@ class CalcDep : private ExprVisitor {
     VarMap<Expr> expr_map_;
     VarMap<size_t> use_map_;
     VarSet letrec_set_;
+    bool inline_once_;
     explicit Eliminator(const VarMap<Expr>& expr_map,
                         const VarMap<size_t>& use_map,
-                        const VarSet& letrec_set) :
-      expr_map_(expr_map), use_map_(use_map), letrec_set_(letrec_set) { }
+                        const VarSet& letrec_set,
+                        bool inline_once) :
+      expr_map_(expr_map), use_map_(use_map), letrec_set_(letrec_set), inline_once_(inline_once) { }
     friend CalcDep;
 
     bool HasLet(const Var& v) {
-      // TODO(@jroesch): MK fix me
-      return (use_map_[v] > 0 || (use_map_[v] != 0 && letrec_set_.count(v) != 0));
+      switch (use_map_[v]) {
+      case 0:
+        return false;
+      case 1:
+        return letrec_set_.count(v) > 0 || !inline_once_;
+      default:
+        return true;
+      }
     }
 
     Expr VisitExpr_(const VarNode* op) final {
@@ -144,8 +152,8 @@ class CalcDep : private ExprVisitor {
   };
 };
 
-Expr DeadCodeElimination(const Expr& e) {
-  return CalcDep::Eliminate(e);
+Expr DeadCodeElimination(const Expr& e, bool inline_once) {
+  return CalcDep::Eliminate(e, inline_once);
 }
 
 TVM_REGISTER_API("relay._ir_pass.dead_code_elimination")
@@ -153,10 +161,10 @@ TVM_REGISTER_API("relay._ir_pass.dead_code_elimination")
 
 namespace transform {
 
-Pass DeadCodeElimination() {
+Pass DeadCodeElimination(bool inline_once) {
   runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func =
     [=](Function f, Module m, PassContext pc) {
-    return Downcast<Function>(DeadCodeElimination(f));
+    return Downcast<Function>(DeadCodeElimination(f, inline_once));
   };
   return CreateFunctionPass(pass_func, 1, "DeadCodeElimination", {});
 }
index 71ba7cd..07ec1b0 100644 (file)
  *
  * The partial evaluator makes several assumptions, so there is room for improvement:
  *
- * 0: The partial evaluator treats global variables as opaque.
- * Doing PartialEval on a module level will solve this.
- *
- * 1: The partial evaluator assume all functions as terminating.
- * We need to has a max_expand parameter that shrink on every compile time evaluation,
- * to make sure PE does not infinite loop.
- * Additionally, we might add a termination analysis pass that lift this requirement
- * for function that analysis found terminating.
- *
- * 2: Every time an unknown effect happened, we clear the whole store.
+ * 0: Every time an unknown effect happened, we clear the whole store.
  * It is too conservative: if a local reference is created (and do not get passed outside),
  * An unknown global function call/global reference write can not modify it.
  * We can pair PE with escape analysis/alias analysis.
  *
- * 3: We assume all unknown code has effect. Doing effect analysis can make the store more precise.
+ * 1: We assume all unknown code has effect. Doing effect analysis can make the store more precise.
  *
- * 4: When doing pattern matching, we can simplify the match even for dynamic case.
+ * 2: When doing pattern matching, we can simplify the match even for dynamic case.
  * Right now it is all or nothing: either a complete match, or the original dynamic code.
  * Instead, we can get a match tree, pair it with the data and evaluate it to a normal form.
  * We then can reify the result.
  *
- * 5: Every time a function is called, it's code will get expanded and partially evaluated.
+ * 3: Every time a function is called, its code will get expanded and partially evaluated.
  * We can do a binding time analysis to cache the result and avoid re-partial evaluation.
  *
  * These assumptions do not affect the correctness of the algorithm, however.
 #include <tvm/relay/expr_functor.h>
 #include <tvm/relay/pattern_functor.h>
 #include <tvm/relay/interpreter.h>
+#include "../ir/type_functor.h"
 #include "pass_util.h"
 #include "let_list.h"
 
@@ -132,6 +124,8 @@ struct VarEqual {
   }
 };
 
+Expr PostProcess(const Expr&);
+
 /*! \brief The base container type of Relay values. */
 class StaticNode : public RelayNode {
  public:
@@ -150,10 +144,20 @@ class Static : public NodeRef {
   using ContainerType = StaticNode;
 };
 
+using Time = size_t;
+
 struct PStaticNode : Node {
+  static Time time() {
+    static Time time_ = 0;
+    Time ret = time_;
+    time_++;
+    return ret;
+  }
   Static pstatic;  // may be null
   Expr dynamic;
-  PStaticNode(const Static& pstatic, const Expr& dynamic) : pstatic(pstatic), dynamic(dynamic) { }
+  Time created_time;
+  PStaticNode(const Static& pstatic, const Expr& dynamic) :
+    pstatic(pstatic), dynamic(dynamic), created_time(time()) { }
   explicit PStaticNode(const Expr& dynamic) : PStaticNode(Static(), dynamic) { }
   TVM_DECLARE_NODE_TYPE_INFO(PStaticNode, Node);
 };
@@ -341,6 +345,7 @@ class Store {
 };
 
 PStatic HasStatic(const Static& stat, const Expr& dynamic) {
+  CHECK(stat.defined());
   return PStatic(make_node<PStaticNode>(stat, dynamic));
 }
 
@@ -383,15 +388,78 @@ FInterpreter CPUInterpreter() {
   return CreateInterpreter(Module(nullptr), CPUContext(), target);
 }
 
+bool IsAtomic(const Expr& e) {
+  return e.as<VarNode>() || e.as<OpNode>() || e.as<ConstructorNode>() || e.as<GlobalVarNode>();
+}
+
+using FuncId = int;
+
+/*!
+ * \brief Annotate a function with a FuncId.
+ */
+struct WithFuncIdAttrs : public tvm::AttrsNode<WithFuncIdAttrs> {
+  FuncId fid;
+
+  TVM_DECLARE_ATTRS(WithFuncIdAttrs, "relay.attrs.WithFuncIdAttrs") {
+    TVM_ATTR_FIELD(fid)
+      .describe("The FuncId that an function is annotated with.")
+      .set_default(-1);
+  }
+};
+
+TVM_REGISTER_NODE_TYPE(WithFuncIdAttrs);
+
+Op WithFuncIdOp() {
+  static const Op& op = Op::Get("annotation.with_funcid");
+  return op;
+}
+
+Expr MkWithFuncId(const Expr& expr, FuncId fid) {
+  auto attrs = make_node<WithFuncIdAttrs>();
+  attrs->fid = fid;
+  return CallNode::make(WithFuncIdOp(), {expr}, Attrs(attrs), {});
+}
+
+RELAY_REGISTER_OP("annotation.with_funcid")
+.describe(R"code(Annotate a function with a funcid.)code"
+TVM_ADD_FILELINE)
+.set_num_inputs(1)
+.add_argument("func", "Function", "The input data.");
+
+Expr StripWithFuncId(const Expr& e);
+
+Expr DeDup(const Expr& e);
+
+Function AsFunc(const Expr& e) {
+  if (e.as<FunctionNode>()) {
+    return Downcast<Function>(e);
+  } else if (const CallNode* c = e.as<CallNode>()) {
+    CHECK(c->op.same_as(WithFuncIdOp()));
+    CHECK_EQ(c->args.size(), 1);
+    return AsFunc(c->args[0]);
+  } else {
+    LOG(FATAL) << "Unknown case";
+    throw;
+  }
+}
+
 class PartialEvaluator : public ExprFunctor<PStatic(const Expr& e, LetList* ll)>,
                          public PatternFunctor<MatchStatus(const Pattern&, const PStatic&)> {
  public:
-  PartialEvaluator(const tvm::Array<Var>& free_vars) {
+  PartialEvaluator(const tvm::Array<Var>& free_vars,
+                   const Module& mod) :
+    mod_(mod) {
     for (const Var& v : free_vars) {
       env_.Insert(v, NoStatic(v));
     }
   }
 
+  PStatic VisitExpr(const Expr& e, LetList* ll) final {
+    PStatic ret = ExprFunctor<PStatic(const Expr&, LetList*)>::VisitExpr(e, ll);
+    CHECK(IsAtomic(ret->dynamic)) << ret->dynamic;
+    return ret;
+  }
+
   PStatic VisitExpr_(const ConstantNode* op, LetList* ll) final {
     return HasStatic(MkSTensor(op->data.CopyTo(context_)), ll->Push(GetRef<Expr>(op)));
   }
@@ -421,7 +489,20 @@ class PartialEvaluator : public ExprFunctor<PStatic(const Expr& e, LetList* ll)>
   }
 
   PStatic VisitExpr_(const GlobalVarNode* op, LetList* ll) final {
-    return NoStatic(GetRef<Expr>(op));
+    GlobalVar gv = GetRef<GlobalVar>(op);
+    if (gv_map_.count(gv) == 0) {
+      if (mod_.defined()) {
+        Function func = mod_->Lookup(gv);
+        InitializeFuncId(func);
+        Func f = VisitFuncStatic(func, gv);
+        gv_map_.insert({gv, HasStatic(MkSFunc(f), gv)});
+        func = AsFunc(PostProcess(VisitFuncDynamic(func, f)));
+        mod_->Update(gv, func);
+      } else {
+        gv_map_.insert({gv, NoStatic(gv)});
+      }
+    }
+    return gv_map_.at(gv);
   }
 
   PStatic VisitExpr_(const LetNode* op, LetList* ll) final {
@@ -485,6 +566,10 @@ class PartialEvaluator : public ExprFunctor<PStatic(const Expr& e, LetList* ll)>
   }
 
   PStatic VisitExpr_(const CallNode* op, LetList* ll) final {
+    if (op->op.same_as(WithFuncIdOp())) {
+      CHECK_EQ(op->args.size(), 1);
+      return VisitExpr(op->args[0], ll);
+    }
     PStatic f = VisitExpr(op->op, ll);
     std::vector<PStatic> x;
     tvm::Array<Expr> x_dyn;
@@ -501,19 +586,40 @@ class PartialEvaluator : public ExprFunctor<PStatic(const Expr& e, LetList* ll)>
     }
   }
 
-  PStatic VisitExpr_(const FunctionNode* op, LetList* ll) final {
-    Function func = GetRef<Function>(op);
+  struct TimeFrame {
+    PartialEvaluator* pe_;
+    FuncId fid_;
+    std::vector<Time> old_time;
+    bool has_old_time;
+    TimeFrame(PartialEvaluator* pe,
+              FuncId fid,
+              const std::vector<Time>& args_time) : pe_(pe), fid_(fid) {
+      has_old_time = pe_->time_map_.count(fid_) > 0;
+      old_time = pe_->time_map_[fid_];
+      pe_->time_map_[fid_] = args_time;
+    }
+    ~TimeFrame() {
+      if (has_old_time) {
+        pe_->time_map_[fid_] = old_time;
+      } else {
+        pe_->time_map_.erase(fid_);
+      }
+    }
+  };
+
+  Func VisitFuncStatic(const Function& func, const Expr& var) {
+    CHECK(IsAtomic(var));
     if (func->IsPrimitive()) {
-      return HasStatic(MkSFunc(ConstEvaluateFunc(func, ll)), func);
+      return ConstEvaluateFunc(func);
     }
     std::vector<std::pair<Var, PStatic> > free_vars;
-    for (const auto& v : FreeVars(GetRef<Expr>(op))) {
+    for (const auto& v : FreeVars(func)) {
       free_vars.push_back(std::pair<Var, PStatic>(v, env_.Lookup(v)));
     }
-    Func f = [=](const std::vector<PStatic>& pv,
-                 const Attrs& attrs,
-                 const tvm::Array<Type>& type_args,
-                 LetList* ll) {
+    return [=](const std::vector<PStatic>& pv,
+               const Attrs& attrs,
+               const tvm::Array<Type>& type_args,
+               LetList* ll) {
       return env_.Extend<PStatic>([&]() {
           CHECK_EQ(pv.size(), func->params.size());
           for (size_t i = 0; i < pv.size(); ++i) {
@@ -529,10 +635,50 @@ class PartialEvaluator : public ExprFunctor<PStatic(const Expr& e, LetList* ll)>
           for (size_t i = type_args.size(); i < func->type_params.size(); ++i) {
             subst.Set(func->type_params[i], Type());
           }
-          return VisitExpr(TypeSubst(func->body, subst), ll);
+          std::vector<Time> args_time;
+          for (const auto& v : pv) {
+            args_time.push_back(v->created_time);
+          }
+          CHECK_GT(func_map_.count(func), 0);
+          FuncId fid = func_map_.at(func);
+          auto recurse = [&]() {
+            TimeFrame tf(this, fid, args_time);
+            return VisitExpr(RegisterFuncId(TypeSubst(AnnotateFuncId(func->body), subst)), ll);
+          };
+          if (time_map_.count(fid) == 0) {
+            return recurse();
+          } else {
+            /* We check to see that at least one argument decrease
+             * with respect to all previous invocation.
+             * The depth of the recursion is bounded by
+             * the sum of the time of all argument at the first call.
+             */
+            bool can_recurse = false;
+            std::vector<Time>& min_time = time_map_.at(fid);
+            CHECK_EQ(args_time.size(), min_time.size());
+            for (size_t i = 0; i < args_time.size(); ++i) {
+              if (args_time[i] < min_time[i]) {
+                can_recurse = true;
+              }
+              args_time[i] = std::min(args_time[i], min_time[i]);
+            }
+            if (can_recurse) {
+              return recurse();
+            } else {
+              std::vector<Expr> dyn;
+              for (const auto& v : pv) {
+                dyn.push_back(v->dynamic);
+              }
+              return NoStatic(ll->Push(CallNode::make(var, dyn, attrs, type_args)));
+            }
+          }
         });
     };
-    Expr dyn = store_.Extend<Expr>([&]() {
+  }
+
+
+  Expr VisitFuncDynamic(const Function& func, const Func& f) {
+    return store_.Extend<Expr>([&]() {
         store_.Invalidate();
         return FunctionNode::make(func->params, LetList::With([&](LetList* ll) {
               std::vector<PStatic> pv;
@@ -546,7 +692,20 @@ class PartialEvaluator : public ExprFunctor<PStatic(const Expr& e, LetList* ll)>
               return f(pv, Attrs(), type_args, ll)->dynamic;
             }), func->ret_type, func->type_params, func->attrs);
       });
-    return HasStatic(MkSFunc(f), ll->Push(dyn));
+  }
+
+  PStatic VisitFunc(const Function& func, LetList* ll) {
+    Var v = VarNode::make("x", Type());
+    Func f = VisitFuncStatic(func, v);
+    Function u_func = AsFunc(RegisterFuncId(DeDup(AnnotateFuncId(func))));
+    // TODO(@M.K.): we seems to reduce landin knot into letrec.
+    // restore letrec support across whole relay.
+    return HasStatic(MkSFunc(f),
+                     ll->Push(v, VisitFuncDynamic(u_func, f)));
+  }
+
+  PStatic VisitExpr_(const FunctionNode* op, LetList* ll) final {
+    return VisitFunc(GetRef<Function>(op), ll);
   }
 
   Expr Reflect(const PStatic& st) {
@@ -590,7 +749,8 @@ class PartialEvaluator : public ExprFunctor<PStatic(const Expr& e, LetList* ll)>
     return Reify(executor_(fused_infered), ll);
   }
 
-  Func ConstEvaluateFunc(const Expr& expr, LetList* ll) {
+  Func ConstEvaluateFunc(const Expr& expr) {
+    CHECK_EQ(FreeVars(expr).size(), 0);
     return [=](const std::vector<PStatic>& pv,
                const Attrs& attrs,
                const tvm::Array<Type>& type_args,
@@ -599,7 +759,7 @@ class PartialEvaluator : public ExprFunctor<PStatic(const Expr& e, LetList* ll)>
       for (const PStatic& ps : pv) {
         ns_args.push_back(ps->dynamic);
       }
-      PStatic ns = NoStatic(CallNode::make(expr, ns_args, attrs, type_args));
+      PStatic ns = NoStatic(ll->Push(CallNode::make(expr, ns_args, attrs, type_args)));
       if (StatefulOp(expr)) {
         return ns;
       }
@@ -616,7 +776,7 @@ class PartialEvaluator : public ExprFunctor<PStatic(const Expr& e, LetList* ll)>
   }
 
   PStatic VisitExpr_(const OpNode* op, LetList* ll) final {
-    return HasStatic(MkSFunc(ConstEvaluateFunc(GetRef<Expr>(op), ll)), GetRef<Expr>(op));
+    return HasStatic(MkSFunc(ConstEvaluateFunc(GetRef<Expr>(op))), GetRef<Expr>(op));
   }
 
   PStatic VisitExpr_(const ConstructorNode* op, LetList* ll) final {
@@ -680,7 +840,6 @@ class PartialEvaluator : public ExprFunctor<PStatic(const Expr& e, LetList* ll)>
       CHECK_NE(op->constructor->tag, -1);
       CHECK_NE(scn->constructor->tag, -1);
       if (op->constructor->tag == scn->constructor->tag) {
-        // todo(M.K.): should use ptr equality but it is broken
         CHECK_EQ(op->patterns.size(), scn->fields.size());
         MatchStatus current_match_status = MatchStatus::Match;
         for (size_t i = 0; i < op->patterns.size(); ++i) {
@@ -702,27 +861,119 @@ class PartialEvaluator : public ExprFunctor<PStatic(const Expr& e, LetList* ll)>
     }
   }
 
+  void InitializeFuncId(const Expr& e) {
+    struct InitializeFuncIdVisitor : ExprVisitor, PatternVisitor {
+      PartialEvaluator* pe;
+      explicit InitializeFuncIdVisitor(PartialEvaluator* pe) : pe(pe) { }
+
+      void VisitExpr_(const FunctionNode* op) final {
+        Function f = GetRef<Function>(op);
+        CHECK_EQ(pe->func_map_.count(f), 0);
+        pe->func_map_.insert({f, pe->func_map_.size()});
+        VisitExpr(f->body);
+      }
+
+      void VisitPattern(const Pattern& p) final {
+        PatternVisitor::VisitPattern(p);
+      }
+    };
+    InitializeFuncIdVisitor(this).VisitExpr(e);
+  }
+
+  Expr RegisterFuncId(const Expr& e) {
+    struct RegisterFuncIdVisitor : ExprVisitor, PatternVisitor {
+      PartialEvaluator* pe;
+      explicit RegisterFuncIdVisitor(PartialEvaluator* pe) : pe(pe) { }
+
+      void VisitExpr_(const CallNode* op) final {
+        if (op->op.same_as(WithFuncIdOp())) {
+          CHECK_EQ(op->args.size(), 1);
+          CHECK(op->attrs.defined());
+          CHECK(op->attrs.as<WithFuncIdAttrs>());
+          Function f = AsFunc(op->args[0]);
+          FuncId fid = op->attrs.as<WithFuncIdAttrs>()->fid;
+          if (pe->func_map_.count(f) != 0) {
+            CHECK_EQ(pe->func_map_.at(f), fid);
+          }
+          pe->func_map_.insert({f, fid});
+        }
+        ExprVisitor::VisitExpr_(op);
+      }
+
+      void VisitExpr_(const FunctionNode* op) final {
+        Function f = GetRef<Function>(op);
+        CHECK_GT(pe->func_map_.count(f), 0);
+        ExprVisitor::VisitExpr_(op);
+      }
+
+      void VisitPattern(const Pattern& p) final {
+        PatternVisitor::VisitPattern(p);
+      }
+    };
+    RegisterFuncIdVisitor(this).VisitExpr(e);
+    return e;
+  }
+
+  Expr AnnotateFuncId(const Expr& e) {
+    struct AnnotateFuncIdMutator : ExprMutator, PatternMutator {
+      PartialEvaluator* pe;
+      explicit AnnotateFuncIdMutator(PartialEvaluator* pe) : pe(pe) { }
+
+      Expr VisitExpr_(const FunctionNode* op) final {
+        Function f = GetRef<Function>(op);
+        CHECK_GT(pe->func_map_.count(f), 0);
+        return MkWithFuncId(ExprMutator::VisitExpr_(op), pe->func_map_.at(f));
+      }
+
+      Pattern VisitPattern(const Pattern& p) final {
+        return PatternMutator::VisitPattern(p);
+      }
+
+      Var VisitVar(const Var& v) final {
+        return v;
+      }
+    };
+    return AnnotateFuncIdMutator(this).VisitExpr(e);
+  }
+
  private:
   Environment env_;
+  Module mod_;
+  std::unordered_map<GlobalVar, PStatic, NodeHash, NodeEqual> gv_map_;
+  /*! Termination checking is done as follows:
+   *  We have finitely many FunctionIds.
+   *  Each FunctionId maps to a class of semantically equivalent function (ignoring type),
+   *  as both TypeSubst and DeDup create semantically equivalent function.
+   *  We partially map each FunctionId to a std::vector<Time>,
+   *  denoting the minimal TimeFrame of each argument of the function.
+   *  Every time we try to inline a Function,
+   *  we make sure it either does not have a vector<Time>, which means this is the initial call,
+   *  or some argument has a lesser time, which means some earlier argument is passed in.
+   *  In any case, we remap the mapping to a minimal vector<Time> across all previous invocations
+   *  when we PE inside the Function body.
+   *  Termination is guaranteed because the creation time of at least one argument will decrease every call.
+   */
+  std::unordered_map<Function, FuncId, NodeHash, NodeEqual> func_map_;
+  std::unordered_map<FuncId, std::vector<Time> > time_map_;
   Store store_;
   DLContext context_ = CPUContext();
   FInterpreter executor_ = CPUInterpreter();
 };
 
-Var DeDupVar(const Var& v) {
-  return VarNode::make(v->name_hint(), v->type_annotation);
-}
-
-TypeVar DeDupTypeVar(const TypeVar& tv) {
-  return TypeVarNode::make(tv->var->name_hint, tv->kind);
-}
-
 /*! \brief Use a fresh Id for every Var to make the result well-formed. */
 Expr DeDup(const Expr& e) {
-  class DeDupMutator : public ExprMutator, public PatternMutator {
+  class DeDupMutator : public TypeMutator,
+                       public ExprMutator,
+                       public PatternMutator {
    public:
+    TypeVar Fresh(const TypeVar& tv) {
+      TypeVar ret = TypeVarNode::make(tv->var->name_hint, tv->kind);
+      type_rename_[tv] = ret;
+      return ret;
+    }
+
     Var Fresh(const Var& v) {
-      Var ret = DeDupVar(v);
+      Var ret = VarNode::make(v->name_hint(), VisitType(v->type_annotation));
       rename_[v] = ret;
       return ret;
     }
@@ -737,18 +988,27 @@ Expr DeDup(const Expr& e) {
     }
 
     Expr VisitExpr_(const LetNode* op) final {
-      return LetNode::make(Fresh(op->var), VisitExpr(op->value), VisitExpr(op->body));
+      Var v = Fresh(op->var);
+      return LetNode::make(v, VisitExpr(op->value), VisitExpr(op->body));
+    }
+
+    Type VisitType(const Type& t) final {
+      return t.defined() ? TypeMutator::VisitType(t) : t;
     }
 
     Expr VisitExpr_(const FunctionNode* op) final {
+      tvm::Array<TypeVar> type_params;
+      for (const TypeVar& type_param : op->type_params) {
+        type_params.push_back(Fresh(type_param));
+      }
       tvm::Array<Var> params;
       for (const Var& param : op->params) {
         params.push_back(Fresh(param));
       }
       return FunctionNode::make(params,
                                 VisitExpr(op->body),
-                                op->ret_type,
-                                op->type_params,
+                                VisitType(op->ret_type),
+                                type_params,
                                 op->attrs);
     }
 
@@ -756,14 +1016,28 @@ Expr DeDup(const Expr& e) {
       return PatternMutator::VisitPattern(p);
     }
 
+    Clause VisitClause(const Clause& c) final {
+      Pattern pat = VisitPattern(c->lhs);
+      return ClauseNode::make(pat, VisitExpr(c->rhs));
+    }
+
+    Type VisitType_(const TypeVarNode* op) final {
+      TypeVar v = GetRef<TypeVar>(op);
+      return type_rename_.count(v) != 0 ? type_rename_.at(v) : v;
+    }
+
     Var VisitVar(const Var& v) final {
       return Fresh(v);
     }
 
    private:
     std::unordered_map<Var, Var, NodeHash, NodeEqual> rename_;
+    std::unordered_map<TypeVar, TypeVar, NodeHash, NodeEqual> type_rename_;
   };
-  return DeDupMutator().VisitExpr(e);
+
+  Expr ret = DeDupMutator().VisitExpr(e);
+  CHECK_EQ(FreeVars(ret).size(), FreeVars(e).size());
+  return ret;
 }
 
 /*! \brief Remap multiple Var sharing the same Id into the same Var. */
@@ -787,11 +1061,38 @@ Expr Remap(const Expr& e) {
   return RemapMutator().VisitExpr(e);
 }
 
-Expr PartialEval(const Expr& e) {
+Expr StripWithFuncId(const Expr& e) {
+  struct StripWithFuncIdMutator : ExprMutator, PatternMutator {
+    Expr VisitExpr_(const CallNode* op) final {
+      if (op->op.same_as(WithFuncIdOp())) {
+        CHECK_EQ(op->args.size(), 1);
+        return VisitExpr(op->args[0]);
+      } else {
+        return ExprMutator::VisitExpr_(op);
+      }
+    }
+
+    Pattern VisitPattern(const Pattern& p) final {
+      return PatternMutator::VisitPattern(p);
+    }
+
+    Var VisitVar(const Var& v) final {
+      return v;
+    }
+  };
+  return StripWithFuncIdMutator().VisitExpr(e);
+}
+
+Expr PostProcess(const Expr& e) {
+  return StripWithFuncId(DeDup(Remap(e)));
+}
+
+Expr PartialEval(const Expr& e, const Module& m) {
   return TransformF([&](const Expr& e) {
       return LetList::With([&](LetList* ll) {
-          PartialEvaluator pe(FreeVars(e));
-          return Remap(DeDup(pe.VisitExpr(e, ll)->dynamic));
+          PartialEvaluator pe(FreeVars(e), m);
+          pe.InitializeFuncId(e);
+          return PostProcess(pe.VisitExpr(e, ll)->dynamic);
         });
     }, e);
 }
@@ -804,7 +1105,7 @@ namespace transform {
 Pass PartialEval() {
   runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func =
     [=](Function f, Module m, PassContext pc) {
-    return Downcast<Function>(PartialEval(f));
+    return Downcast<Function>(PartialEval(f, m));
   };
   return CreateFunctionPass(pass_func, 1, "PartialEvaluate", {});
 }
index 78fa63b..b3c0c28 100644 (file)
 import numpy as np
 import tvm
 from tvm import relay
-from tvm.relay.ir_pass import partial_evaluate, dead_code_elimination
-from tvm.relay.ir_pass import gradient, alpha_equal, infer_type
+from tvm.relay.ir_pass import partial_evaluate, alpha_equal, infer_type, dead_code_elimination
+from tvm.relay.ir_pass import gradient
 from tvm.relay import op, create_executor
 from tvm.relay.backend.interpreter import Value, TupleValue, ConstructorValue
 from tvm.relay.prelude import Prelude
 from tvm.relay import create_executor
-
 from nose.tools import nottest
+from tvm.relay import Var, TypeVar, TupleGetItem, Let, Function, const, RefRead, RefWrite, RefCreate
+from tvm.relay import TensorType, Tuple, If, Module, Clause, PatternConstructor, PatternVar, Match
+from tvm.relay import GlobalVar, Call, Type
+from tvm.relay.testing import add_nat_definitions, count, make_nat_value, make_nat_expr
 
 def check_eval(expr, expected_result, mod=None, rtol=1e-07):
     ctx = tvm.context("llvm", 0)
@@ -35,24 +38,22 @@ def check_eval(expr, expected_result, mod=None, rtol=1e-07):
     np.testing.assert_allclose(result.asnumpy(), expected_result, rtol=rtol)
 
 
-def dcpe(expr):
-    return dead_code_elimination(partial_evaluate(expr))
+def dcpe(expr, mod=None):
+    return dead_code_elimination(partial_evaluate(expr, mod=mod), inline_once=True)
 
 
 def test_tuple():
-    t = relay.TypeVar("t")
-    x = relay.Var("x", t)
-    body = relay.TupleGetItem(relay.Tuple([relay.const(4.0), x]), 1)
-    f = relay.Function([x], body, None, [t])
+    t = TypeVar("t")
+    x = Var("x", t)
+    body = TupleGetItem(relay.Tuple([relay.const(4.0), x]), 1)
+    f = Function([x], body, None, [t])
     assert alpha_equal(dcpe(f), relay.Function([x], x, None, [t]))
 
-@nottest
 def test_const_inline():
-    # TODO(MK): fix me
-    d = relay.Var("d")
-    double = relay.Function([d], d + d)
-    orig = double(relay.const(4.0))
-    assert alpha_equal(dcpe(double(relay.const(4.0))), relay.const(8.0))
+    d = Var("d")
+    double = Function([d], d + d)
+    orig = double(const(4.0))
+    assert alpha_equal(dcpe(orig), const(8.0))
 
 
 def test_ref():
@@ -60,44 +61,57 @@ def test_ref():
     r = relay.Var("r")
     x = relay.Var("x")
     body = relay.RefRead(r)
-    body = relay.Let(x, relay.RefWrite(r, relay.RefRead(r) * relay.RefRead(r)), body)
-    body = relay.Let(r, relay.RefCreate(d), body)
-    square = relay.Function([d], body)
-    assert alpha_equal(dcpe(square), relay.Function([d], d * d))
+    body = Let(x, RefWrite(r, RefRead(r) * RefRead(r)), body)
+    body = Let(r, RefCreate(d), body)
+    square = Function([d], body)
+    assert alpha_equal(dcpe(square), Function([d], d * d))
+
+
+def test_empty_ad():
+    shape = (10, 10)
+    dtype = "float32"
+    t = TensorType(shape, dtype)
+    d = Var("d", t)
+    f = Function([d], d)
+    g = dcpe(gradient(f))
+    expected = Function([d], Tuple([d, Tuple([op.ones_like(d)])]))
+    assert alpha_equal(g, expected)
 
-@nottest
 def test_ad():
-    # TODO(MK): fix me
     shape = (10, 10)
     dtype = "float32"
-    t = relay.TensorType(shape, dtype)
-    d = relay.Var("d", t)
-    f = relay.Function([d], d * d)
+    t = TensorType(shape, dtype)
+    d = Var("d", t)
+    f = Function([d], d * d)
     g = dcpe(gradient(f))
     m = d * d
-    o = relay.op.ones_like(m)
-    grad = relay.op.zeros_like(d) + relay.op.collapse_sum_like(o * d, d) + relay.op.collapse_sum_like(o * d, d)
-    expected = relay.Function([d], relay.Tuple([m, relay.Tuple([grad])]))
+    x = relay.Var("x")
+    o = op.ones_like(x)
+    x1 = relay.Var("x1")
+    grad = op.zeros_like(d) + op.collapse_sum_like(x1 * d, d) + op.collapse_sum_like(x1 * d, d)
+    body = Tuple([x, Tuple([grad])])
+    body = relay.Let(x1, o, body)
+    expected = Function([d], relay.Let(x, m, body))
     assert alpha_equal(g, expected)
 
 
 def test_if_ref():
     shape = ()
     dtype = "bool"
-    t = relay.TensorType(shape, dtype)
-    d = relay.Var("d", t)
-    r = relay.Var("r")
-    update = relay.Function([], relay.RefWrite(r, relay.RefRead(r) + relay.RefRead(r)))
-    u = relay.Var("u")
-    body = relay.If(d, u(), u())
-    eff = relay.Var("eff")
-    body = relay.Let(eff, body, relay.RefRead(r))
-    f = relay.Function([d], relay.Let(r, relay.RefCreate(relay.const(1)), relay.Let(u, update, body)))
+    t = TensorType(shape, dtype)
+    d = Var("d", t)
+    r = Var("r")
+    update = Function([], RefWrite(r, RefRead(r) + RefRead(r)))
+    u = Var("u")
+    body = If(d, u(), u())
+    eff = Var("eff")
+    body = Let(eff, body, RefRead(r))
+    f = Function([d], Let(r, RefCreate(const(1)), Let(u, update, body)))
     f = infer_type(f)
     pe_f = infer_type(partial_evaluate(f))
     ex = create_executor()
-    f_res = ex.evaluate(f)(relay.const(True))
-    pe_f_res = ex.evaluate(pe_f)(relay.const(True))
+    f_res = ex.evaluate(f)(const(True))
+    pe_f_res = ex.evaluate(pe_f)(const(True))
     np.testing.assert_allclose(f_res.asnumpy(), 2 * np.ones_like(f_res.asnumpy()))
     np.testing.assert_allclose(pe_f_res.asnumpy(), 2 * np.ones_like(pe_f_res.asnumpy()))
 
@@ -105,52 +119,162 @@ def test_if_ref():
 def test_function_invalidate():
     shape = ()
     dtype = "bool"
-    t = relay.TensorType(shape, dtype)
-    d = relay.Var("d", t)
-    r = relay.Var("r")
-    fetch = relay.Function([], relay.RefRead(r))
-    fet = relay.Var("fetch")
-    fet_obscured = relay.Var("fetch_obscured")
-    u = relay.Var("u")
-    body = relay.If(d, fet_obscured(), fet_obscured())
-    body = relay.Let(u, relay.RefWrite(r, relay.const(1)), body)
-    body = relay.Let(fet_obscured, relay.If(d, fet, fet), body)
-    body = relay.Let(fet, fetch, body)
-    body = relay.Let(r, relay.RefCreate(relay.const(0)), body)
-    f = relay.Function([d], body)
+    t = TensorType(shape, dtype)
+    d = Var("d", t)
+    r = Var("r")
+    fetch = Function([], RefRead(r))
+    fet = Var("fetch")
+    fet_obscured = Var("fetch_obscured")
+    u = Var("u")
+    body = If(d, fet_obscured(), fet_obscured())
+    body = Let(u, RefWrite(r, const(1)), body)
+    body = Let(fet_obscured, If(d, fet, fet), body)
+    body = Let(fet, fetch, body)
+    body = Let(r, RefCreate(const(0)), body)
+    f = Function([d], body)
     f = infer_type(f)
     pe_f = infer_type(partial_evaluate(f))
     ex = create_executor()
-    f_res = ex.evaluate(f)(relay.const(True))
-    pe_f_res = ex.evaluate(pe_f)(relay.const(True))
+    f_res = ex.evaluate(f)(const(True))
+    pe_f_res = ex.evaluate(pe_f)(const(True))
     np.testing.assert_allclose(f_res.asnumpy(), np.ones_like(f_res.asnumpy()))
     np.testing.assert_allclose(pe_f_res.asnumpy(), np.ones_like(pe_f_res.asnumpy()))
 
 
 def test_head_cons():
-    mod = relay.Module()
+    mod = Module()
     p = Prelude(mod)
     def hd_impl():
-        a = relay.TypeVar("a")
-        x = relay.Var("x", p.l(a))
-        y = relay.Var("y")
-        z = relay.Var("z")
-        cons_case = relay.Clause(relay.PatternConstructor(p.cons,
-                                                          [relay.PatternVar(y),
-                                                           relay.PatternVar(z)]),
-                                 y)
-        return relay.Function([x], relay.Match(x, [cons_case]), a, [a])
-    t = relay.TypeVar("t")
-    x = relay.Var("x", t)
-    hd = relay.Var("hd")
-    body = relay.Let(hd, hd_impl(), hd(p.cons(x, p.nil())))
-    f = relay.Function([x], body, None, [t])
+        a = TypeVar("a")
+        x = Var("x", p.l(a))
+        y = Var("y")
+        z = Var("z")
+        cons_case = Clause(PatternConstructor(p.cons,
+                                              [PatternVar(y),
+                                               PatternVar(z)]),
+                           y)
+        y = Var("y")
+        z = Var("z")
+        return Function([x], Match(x, [cons_case]), a, [a])
+    t = TypeVar("t")
+    x = Var("x", t)
+    hd = Var("hd")
+    body = Let(hd, hd_impl(), hd(p.cons(x, p.nil())))
+    f = Function([x], body, None, [t])
     f = infer_type(f, mod=mod)
     res = dcpe(f)
-    assert alpha_equal(res, relay.Function([x], x, t, [t]))
+    assert alpha_equal(res, Function([x], x, t, [t]))
+
+
+def test_map():
+    mod = Module()
+    p = Prelude(mod)
+    f = Var("f")
+    orig = p.map(f, p.cons(const(1), p.cons(const(2), p.cons(const(3), p.nil()))))
+    expected = p.cons(f(const(1)), p.cons(f(const(2)), p.cons(f(const(3)), p.nil())))
+    assert alpha_equal(dcpe(orig, mod=mod), expected)
+
+
+def test_loop():
+    mod = Module()
+    t = TypeVar("t")
+    x = Var("x", t)
+    loop = GlobalVar("loop")
+    mod[loop] = Function([x], loop(x), t, [t])
+    res = dcpe(loop(const(1)), mod=mod)
+    expected = Call(loop, [const(1)], None, [None])
+    assert alpha_equal(res, expected)
+
+
+def test_swap_loop():
+    mod = Module()
+    p = Prelude(mod)
+    add_nat_definitions(p)
+    nat = p.nat()
+    x = Var("x", nat)
+    y = Var("y", nat)
+    loop = GlobalVar("loop")
+    mod[loop] = Function([x, y], loop(y, x), nat)
+    prog = loop(make_nat_expr(p, 1), make_nat_expr(p, 2))
+    res = dcpe(prog, mod=mod)
+    assert alpha_equal(prog, res)
+
+
+def test_abs_diff():
+    # TODO(@M.K.): refactor using tuple pattern (not yet implemented)
+    mod = Module()
+    p = Prelude(mod)
+    add_nat_definitions(p)
+    nat = p.nat()
+    x = Var("x", nat)
+    y = Var("y", nat)
+    xp = Var("x'", nat)
+    yp = Var("y'", nat)
+    diff = GlobalVar("diff")
+    y_z_case = Clause(PatternConstructor(p.z, []), x)
+    y_s_case = Clause(PatternConstructor(p.s, [PatternVar(yp)]), diff(yp, xp))
+    x_z_case = Clause(PatternConstructor(p.z, []), y)
+    x_s_case = Clause(PatternConstructor(p.s, [PatternVar(xp)]), Match(y, [y_z_case, y_s_case]))
+    mod[diff] = Function([x, y], Match(x, [x_z_case, x_s_case]))
+    orig = diff(make_nat_expr(p, 7), make_nat_expr(p, 3))
+    res = dcpe(orig, mod=mod)
+    assert alpha_equal(res, make_nat_expr(p, 4))
+
+
+def test_match_nat_id():
+    mod = Module()
+    p = Prelude(mod)
+    add_nat_definitions(p)
+    nat = p.nat()
+    x = Var("x", nat)
+    y = Var("y", nat)
+    nat_id = GlobalVar("nat_id")
+    z_case = Clause(PatternConstructor(p.z, []), p.z())
+    s_case = Clause(PatternConstructor(p.s, [PatternVar(y)]), p.s(y))
+    mod[nat_id] = Function([x], Match(x, [z_case, s_case]))
+    orig = nat_id(make_nat_expr(p, 3))
+    res = dcpe(orig, mod=mod)
+    assert alpha_equal(res, make_nat_expr(p, 3))
+
+
+def test_nat_id():
+    mod = Module()
+    p = Prelude(mod)
+    add_nat_definitions(p)
+    nat = p.nat()
+    x = Var("x", nat)
+    y = Var("y", nat)
+    nat_id = GlobalVar("nat_id")
+    mod[nat_id] = Function([x], x)
+    orig = nat_id(make_nat_expr(p, 3))
+    res = dcpe(orig, mod=mod)
+    assert alpha_equal(res, make_nat_expr(p, 3))
+
+
+def test_global_match_nat_id():
+    mod = Module()
+    p = Prelude(mod)
+    add_nat_definitions(p)
+    nat = p.nat()
+    x = Var("x", nat)
+    z_case = Clause(PatternConstructor(p.z, []), p.z())
+    s_case = Clause(PatternConstructor(p.s, [PatternVar(x)]), p.s(x))
+    orig = Match(make_nat_expr(p, 3), [z_case, s_case])
+    res = dcpe(orig, mod=mod)
+    assert alpha_equal(res, make_nat_expr(p, 3))
+
+
+def test_double():
+    mod = Module()
+    p = Prelude(mod)
+    add_nat_definitions(p)
+    orig = p.double(make_nat_expr(p, 3))
+    res = dcpe(orig, mod=mod)
+    assert alpha_equal(res, make_nat_expr(p, 6))
 
 
 if __name__ == '__main__':
+    test_empty_ad()
     test_tuple()
     test_const_inline()
     test_ref()
@@ -158,3 +282,11 @@ if __name__ == '__main__':
     test_if_ref()
     test_function_invalidate()
     test_head_cons()
+    test_map()
+    test_loop()
+    test_swap_loop()
+    test_abs_diff()
+    test_double()
+    test_nat_id()
+    test_global_match_nat_id()
+    test_match_nat_id()