[Relay] Make check stricter by using Feature. Fixed multiple bugs. (#6326)
author雾雨魔理沙 <lolisa@marisa.moe>
Tue, 25 Aug 2020 18:40:34 +0000 (11:40 -0700)
committerGitHub <noreply@github.com>
Tue, 25 Aug 2020 18:40:34 +0000 (11:40 -0700)
* save

lint

lint

lint

fix lint

lint

update

lint

save

save

save

lint

format

format

save

save

fix

use a form more suitable for numeric check

save

* save

* save

* lint

* save

* lint

* fix

* fix

17 files changed:
include/tvm/relay/feature.h
include/tvm/relay/transform.h
python/tvm/relay/prelude.py
python/tvm/relay/transform/transform.py
src/relay/analysis/feature.cc
src/relay/transforms/gradient.cc
src/relay/transforms/lazy_gradient_init.cc
src/relay/transforms/partial_eval.cc
src/relay/transforms/pass_util.h
src/relay/transforms/to_a_normal_form.cc
src/relay/transforms/to_cps.cc
tests/python/relay/test_analysis_feature.py
tests/python/relay/test_op_grad_level10.py
tests/python/relay/test_pass_gradient.py
tests/python/relay/test_pass_lazy_gradient_init.py
tests/python/relay/test_pass_merge_composite.py
tests/python/relay/test_pass_partial_eval.py

index 3783e32..7df8819 100644 (file)
@@ -29,6 +29,7 @@
 #include <tvm/relay/expr.h>
 
 #include <bitset>
+#include <string>
 
 namespace tvm {
 namespace relay {
@@ -124,6 +125,11 @@ class FeatureSet {
    */
   bool is_subset_of(const FeatureSet& rhs) const { return ((*this) - rhs).bs_.none(); }
 
+  /*!
+   * \brief return a string representation.
+   */
+  std::string ToString() const;
+
  private:
   std::bitset<feature_count> bs_;
   FeatureSet() = default;
@@ -160,6 +166,34 @@ inline FeatureSet DetectFeature(const Expr& expr, const IRModule& mod) {
   return DetectFeature(expr) + DetectFeature(mod);
 }
 
+/*!
+ * \brief Check the feature of the program.
+ *
+ * \param expr The expression.
+ * \param fs The feature set of the program.
+ */
+void CheckFeature(const RelayExpr& expr, const FeatureSet& fs);
+
+/*!
+ * \brief Check the feature of the program.
+ *
+ * \param mod The module.
+ * \param fs The feature set of the program.
+ */
+void CheckFeature(const IRModule& mod, const FeatureSet& fs);
+
+/*!
+ * \brief Check the feature of the program.
+ *
+ * \param expr The expression.
+ * \param mod The module.
+ * \param fs The feature set of the program.
+ */
+inline void CheckFeature(const RelayExpr& expr, const IRModule& mod, const FeatureSet& fs) {
+  CheckFeature(expr, fs);
+  CheckFeature(mod, fs);
+}
+
 }  // namespace relay
 }  // namespace tvm
 
index d322710..de2bcc4 100644 (file)
@@ -148,6 +148,15 @@ TVM_DLL Pass ToBasicBlockNormalForm();
 TVM_DLL Pass ToANormalForm();
 
 /*!
+ * \brief ToANormalForm but on incomplete graph.
+ *
+ * \param expr the graph.
+ *
+ * \return The transformed program.
+ */
+TVM_DLL Expr ToANormalForm(const Expr& expr);
+
+/*!
  * \brief Turn an expression into continuation passing style(CPS).
  *
  * CPS mean that every function will, instead of returning the result directly,
index 1b7ed77..893c855 100644 (file)
@@ -17,6 +17,7 @@
 # pylint: disable=no-else-return, unidiomatic-typecheck, invalid-name
 """A prelude containing useful global functions and ADT definitions."""
 from tvm.ir import IRModule, TypeCall
+from tvm.relay.transform import ToANormalFormExpr
 
 from .ty import GlobalTypeVar, TensorType, Any, scalar_type
 from .expr import Var, GlobalVar, If, const
@@ -204,7 +205,6 @@ class StaticTensorArrayOps(object):
         self.prelude.mod[concat_var] = \
             Function([x, y], Match(x, [case], False), tensor_type_var(), [])
 
-
     def define_tensor_expand_dims(self):
         """Defines a function to grow a tensor_t's rank by adding one dimension in front
         of the original tensor_t.
@@ -511,8 +511,9 @@ class StaticTensorArrayOps(object):
                                      self.prelude.hd(tensor_array_expand_dims),
                                      self.prelude.tl(tensor_array_expand_dims))
         output_tensor_type_var, _ = self._get_adt_by_shape(output_shape)
-        self.prelude.mod[stack_var] = Function([tensor_array], tensors,
-                                               output_tensor_type_var(), [])
+        self.prelude.mod[stack_var] = \
+            Function([tensor_array], tensors,
+                     output_tensor_type_var(), [])
 
     def define_tensor_array_gather(self):
         """Defines a function to return the selected values in a tensor array as tensor_t.
@@ -809,7 +810,7 @@ class TensorArrayOps(object):
                                                tensor4_var(op.concatenate([t41, t42], axis=0)))],
                                     False))
         # op.concatenate does not support tensor with rank higher than 4
-        self.prelude.mod[concat_var] =\
+        self.prelude.mod[concat_var] = \
             Function([x, y], Match(x, [tensor1_case,
                                        tensor2_case,
                                        tensor3_case,
@@ -1167,7 +1168,7 @@ class TensorArrayOps(object):
         current = Var("current", scalar_type('int32'))
         limit = Var("limit", scalar_type('int32'))
         indices_ = Var('indices_', TensorType([Any()], 'int32'))
-        helper_body =\
+        helper_body = \
             If(equal(current, const(0)),
                stack_var(accu),
                helper_var(
@@ -1187,7 +1188,7 @@ class TensorArrayOps(object):
         indices_shape = op.shape_of(indices)
         limit = op.take(indices_shape, const(0))
         body = helper_var(tensor_array, self.prelude.nil(), limit, limit, indices)
-        self.prelude.mod[gather_var] =\
+        self.prelude.mod[gather_var] = \
             Function([tensor_array, indices], body, tensor_type_var(), [])
 
     def define_tensor_array_stack(self):
@@ -1205,7 +1206,8 @@ class TensorArrayOps(object):
         tensors = self.prelude.foldl(concat_var,
                                      self.prelude.hd(tensor_array_expand_dims),
                                      self.prelude.tl(tensor_array_expand_dims))
-        self.prelude.mod[stack_var] = Function([tensor_array], tensors, tensor_type_var(), [])
+        self.prelude.mod[stack_var] = \
+            Function([tensor_array], ToANormalFormExpr(tensors), tensor_type_var(), [])
 
     def register(self):
         """Register all tensor array ops in Prelude"""
index cc92141..de3f986 100644 (file)
@@ -510,11 +510,26 @@ def ToANormalForm():
 
     Returns
     -------
-    ret: Union[tvm.transform.Pass, tvm.relay.Expr]
+    ret : Union[tvm.transform.Pass, tvm.relay.Expr]
         The registered pass that transforms an expression into A Normal Form.
     """
     return _ffi_api.ToANormalForm()
 
+def ToANormalFormExpr(e):
+    """ToANormalForm, but on expression level.
+
+    Parameters
+    ----------
+    e : Expr
+        The graph expression.
+
+    Returns
+    -------
+    ret : Expr
+        The transformed expresion.
+    """
+    return _ffi_api.ToANormalFormExpr(e)
+
 def ToBasicBlockNormalForm():
     """Turn an expression to Basic Block Normal Form.
     We define a block as a group of expressions implied by the scope structure.
index a145b28..63f5e71 100644 (file)
@@ -86,12 +86,43 @@ FeatureSet DetectFeature(const Expr& expr) {
   return fd.fs;
 }
 
+std::string FeatureSet::ToString() const {
+  std::string ret;
+  ret += "[";
+  size_t detected = 0;
+#define DETECT_FEATURE(FEATURE_NAME) \
+  ++detected;                        \
+  if (bs_[FEATURE_NAME]) {           \
+    ret += #FEATURE_NAME;            \
+    ret += ", ";                     \
+  }
+  DETECT_FEATURE(fVar);
+  DETECT_FEATURE(fGlobalVar);
+  DETECT_FEATURE(fConstant);
+  DETECT_FEATURE(fTuple);
+  DETECT_FEATURE(fTupleGetItem);
+  DETECT_FEATURE(fFunction);
+  DETECT_FEATURE(fOp);
+  DETECT_FEATURE(fCall);
+  DETECT_FEATURE(fLet);
+  DETECT_FEATURE(fIf);
+  DETECT_FEATURE(fRefCreate);
+  DETECT_FEATURE(fRefRead);
+  DETECT_FEATURE(fRefWrite);
+  DETECT_FEATURE(fConstructor);
+  DETECT_FEATURE(fMatch);
+  DETECT_FEATURE(fGraph);
+  DETECT_FEATURE(fLetRec);
+#undef DETECT_FEATURE
+  CHECK(detected == feature_count) << "some feature not printed";
+  ret += "]";
+  return ret;
+}
+
 FeatureSet DetectFeature(const IRModule& mod) {
   FeatureSet fs = FeatureSet::No();
-  if (mod.defined()) {
-    for (const auto& f : mod->functions) {
-      fs += DetectFeature(f.second);
-    }
+  for (const auto& f : mod->functions) {
+    fs += DetectFeature(f.second);
   }
   return fs;
 }
@@ -106,5 +137,17 @@ Array<Integer> PyDetectFeature(const Expr& expr, const Optional<IRModule>& mod)
 
 TVM_REGISTER_GLOBAL("relay.analysis.detect_feature").set_body_typed(PyDetectFeature);
 
+void CheckFeature(const Expr& expr, const FeatureSet& fs) {
+  auto dfs = DetectFeature(expr);
+  CHECK(dfs.is_subset_of(fs)) << AsText(expr, false)
+                              << "\nhas unsupported feature: " << (dfs - fs).ToString();
+}
+
+void CheckFeature(const IRModule& mod, const FeatureSet& fs) {
+  for (const auto& f : mod->functions) {
+    CheckFeature(f.second, fs);
+  }
+}
+
 }  // namespace relay
 }  // namespace tvm
index 0cebba7..7894c34 100644 (file)
@@ -24,6 +24,7 @@
 #include <tvm/ir/type_functor.h>
 #include <tvm/relay/analysis.h>
 #include <tvm/relay/expr_functor.h>
+#include <tvm/relay/feature.h>
 #include <tvm/relay/transform.h>
 #include <tvm/te/operation.h>
 
@@ -81,7 +82,7 @@ Type WithGradientType(const Type& t) {
 Expr DeGlobal(const Optional<IRModule>& mod, const Expr& e) {
   const auto* x = e.as<GlobalVarNode>();
 
-  if (mod.defined() && (x)) {
+  if (mod.defined() && x) {
     BaseFunc base_func = mod.value()->Lookup(GetRef<GlobalVar>(x));
     if (auto* n = base_func.as<FunctionNode>()) {
       return n->body;
@@ -354,9 +355,9 @@ Expr LiftTensor(const std::function<Expr(const Expr& t)>& f,
                 LetList* ll) {
   CHECK(IsAtomic(e)) << e;
   if (forward_type.as<TensorTypeNode>()) {
-    auto ret = f(e);
+    auto ret = ll->Push(f(e));
     ret->checked_type_ = tf(forward_type);
-    return ret;
+    return std::move(ret);
   } else if (auto* tt = forward_type.as<TupleTypeNode>()) {
     tvm::Array<Expr> fields;
     tvm::Array<Type> types;
@@ -365,7 +366,7 @@ Expr LiftTensor(const std::function<Expr(const Expr& t)>& f,
       fields.push_back(field);
       types.push_back(field->checked_type_);
     }
-    auto ret = Tuple(fields);
+    auto ret = ll->Push(Tuple(fields));
     ret->checked_type_ = TupleType(types);
     return std::move(ret);
   } else {
@@ -395,9 +396,10 @@ void TransferGrads(const Type& forward_type, const Expr& from, const Expr& to, L
   }
 }
 
+// TODO(@M.K.): why take Expr?
 /*! \brief t -> ReverseType(t). Transform to Reverse Mode Value. */
 Expr GetRev(const Type& forward_type, const Expr& e, LetList* ll) {
-  auto rev = [&](const Expr& e) { return Pair(e, ll->Push(RefCreate(ZerosLike(e)))); };
+  auto rev = [&](const Expr& e) { return Pair(e, RefCreate(ZerosLike(e))); };
   auto rev_type = [&](const Type& forward_type) { return ReverseType(forward_type); };
   return LiftTensor(rev, rev_type, forward_type, e, ll);
 }
@@ -411,14 +413,14 @@ Expr GetValue(const Type& forward_type, const Expr& e, LetList* ll) {
 
 /*! \brief ReverseType(t) -> t. Get the gradient. */
 Expr GetGrad(const Type& forward_type, const Expr& e, LetList* ll) {
-  auto grad = [&](const Expr& e) { return ll->Push(RefRead(GetField(e, 1))); };
+  auto grad = [&](const Expr& e) { return RefRead(GetField(e, 1)); };
   auto grad_type = [&](const Type& forward_type) { return forward_type; };
   return LiftTensor(grad, grad_type, forward_type, e, ll);
 }
 
 void UpdateGrad(const Type& t, const Expr& arg, const Expr& grad, LetList* ll) {
   if (t.as<TensorTypeNode>()) {
-    ll->Push(RefWrite(GetField(arg, 1), Add(ll->Push(RefRead(GetField(arg, 1))), grad)));
+    ll->Push(RefWrite(GetField(arg, 1), Add(RefRead(GetField(arg, 1)), grad)));
   } else if (auto* tt = t.as<TupleTypeNode>()) {
     for (size_t i = 0; i < tt->fields.size(); ++i) {
       UpdateGrad(tt->fields[i], ll->Push(GetField(arg, i)), ll->Push(GetField(grad, i)), ll);
@@ -448,6 +450,24 @@ struct ReverseAD : ExprMutator {
     throw;
   }
 
+  Expr Remap(const Expr& e) {
+    struct Remapper : ExprMutator {
+      std::shared_ptr<ADVarMap> ad_vars;
+      LetList* ll;
+      Remapper(const std::shared_ptr<ADVarMap>& ad_vars, LetList* ll) : ad_vars(ad_vars), ll(ll) {}
+      Expr VisitExpr_(const VarNode* var) final {
+        // memoize Var -> ADVar so we don't end up with free Vars when checkpointing
+        auto var_ref = GetRef<Var>(var);
+        if (ad_vars->count(var_ref) == 0) {
+          return std::move(var_ref);
+        } else {
+          return GetValue(var_ref->checked_type(), ad_vars->at(var_ref), ll);
+        }
+      }
+    };
+    return LetList::With([&](LetList* ll) { return Remapper(ad_vars, ll)(e); });
+  }
+
   Expr VisitCheckpoint(const CallNode* call) {
     const OpNode* op_node = call->op.as<OpNode>();
     CHECK(op_node) << "expected op in call";
@@ -455,7 +475,7 @@ struct ReverseAD : ExprMutator {
     CHECK(op_ref->name == "annotation.checkpoint") << "expected checkpoint annotation";
     auto x = call->args[0];
     return LetList::With([&](LetList* ll) {
-      auto x_var = ll->Push(x);
+      auto x_var = ll->Push(Remap(x));
       auto ret = ll->Push(GetRev(call->checked_type(), x_var, ll));
       auto bpv = ll->Push(RefRead(bp));
       Expr nbp = Function({}, LetList::With([&](LetList* ll) {
@@ -508,7 +528,8 @@ struct ReverseAD : ExprMutator {
                               return Call(bpv, {});
                             }),
                             TupleType::Empty(), {});
-        ll->Push(RefWrite(bp, nbp));
+        ll->Push(RefWrite(bp, transform::ToANormalForm(nbp)));
+        // TODO(@M.K.): ToANF should be called on rev. Enhance ToANF for that.
         return ret;
       });
     }
@@ -516,8 +537,10 @@ struct ReverseAD : ExprMutator {
   }
 
   Expr VisitExpr_(const ConstantNode* op) final {
-    Expr e = GetRef<Expr>(op);
-    return Pair(e, RefCreate(ZerosLike(e)));
+    return LetList::With([&](LetList* ll) {
+      Expr e = ll->Push(GetRef<Expr>(op));
+      return Pair(e, RefCreate(ZerosLike(e)));
+    });
   }
 
   Expr VisitExpr_(const IfNode* op) final {
@@ -528,7 +551,7 @@ struct ReverseAD : ExprMutator {
   Expr VisitExpr_(const VarNode* var) final {
     // memoize Var -> ADVar so we don't end up with free Vars when checkpointing
     auto var_ref = GetRef<Var>(var);
-    if (!ad_vars->count(var_ref)) {
+    if (ad_vars->count(var_ref) == 0) {
       auto res = Downcast<Var>(ExprMutator::VisitExpr_(var));
       (*ad_vars)[var_ref] = res;
     }
@@ -568,6 +591,10 @@ bool MissingGrad(const Expr& e) {
 }
 
 Expr Gradient(const Expr& re, const Optional<IRModule>& mod) {
+  CheckFeature(re, FeatureSet::All() - fGraph);
+  if (mod.defined()) {
+    CheckFeature(mod.value(), FeatureSet::All() - fGraph);
+  }
   auto e = DeGlobal(mod, re);
   auto f = e.as<FunctionNode>();
   CHECK(f) << "input need to be a function";
@@ -619,7 +646,9 @@ Expr Gradient(const Expr& re, const Optional<IRModule>& mod) {
     };
     return Pair(get_final_result(c, f->body->checked_type()), Tuple(ret));
   });
-  return Function(f->params, body, GradRetType(GetRef<Function>(f)), {});
+  auto ret = Function(f->params, body, GradRetType(GetRef<Function>(f)), {});
+  CheckFeature(ret, FeatureSet::All() - fGraph);
+  return std::move(ret);
 }
 
 TVM_REGISTER_GLOBAL("relay._transform.gradient").set_body_typed(Gradient);
index f062466..de9406e 100644 (file)
@@ -63,6 +63,7 @@
 #include <tvm/node/structural_equal.h>
 #include <tvm/relay/analysis.h>
 #include <tvm/relay/expr_functor.h>
+#include <tvm/relay/feature.h>
 #include <tvm/relay/transform.h>
 
 #include "let_list.h"
 namespace tvm {
 namespace relay {
 
-/*!
- * \brief Visitor appropriately wraps tensors with Raw constructor
- *
- * Recursively looks at the type of the expression (TensorType or TupleType are only supported for
- * now) and either call the GradCell constructor if TensorType or unfold and recursively visit if
- * TupleType
- */
-class InputVisitor : public ExprFunctor<Expr(const Expr&, const Type&)> {
+class LazyGradientInitializer : public ExprMutator, public TypeMutator {
  public:
-  explicit InputVisitor(IRModule module) : module_(module) {}
-
-  Expr VisitExpr_(const VarNode* op, const Type& t) final {
-    std::cout << op->type_annotation << std::endl;
-    return WrapExpr(GetRef<Var>(op), op->type_annotation);
-  }
-
-  Expr VisitExpr_(const TupleGetItemNode* op, const Type& t) final {
-    return WrapExpr(GetRef<TupleGetItem>(op), t);
+  explicit LazyGradientInitializer(IRModule module) : module_(module) {
+    module_->ImportFromStd("gradient.rly");
   }
 
- private:
-  IRModule module_;
-
-  Expr WrapExpr(const Expr expr, const Type& type) {
+  Expr WrapExpr(const Var& var, const Type& type, LetList* ll) {
     if (type.as<TensorTypeNode>()) {
-      return Call(module_->GetConstructor("GradCell", "Raw"), {expr}, Attrs(), {type});
+      return Call(module_->GetConstructor("GradCell", "Raw"), {var}, Attrs(), {type});
     } else if (auto* type_anno = type.as<TupleTypeNode>()) {
       tvm::Array<Expr> fields;
       for (size_t i = 0; i < type_anno->fields.size(); i++) {
         const Type& t = type_anno->fields[i];
-        fields.push_back(this->VisitExpr(TupleGetItem(expr, i), t));
+        fields.push_back(WrapExpr(ll->Push(TupleGetItem(var, i)), t, ll));
       }
       Expr tuple = Tuple(fields);
       return tuple;
     }
 
-    return expr;
-  }
-};
-
-/*!
- * \brief Visitor appropriately unwraps expressions with GradCell type into Tensors
- *
- * Recursively looks at the type of the expression
- * and either use the FromGradCell function if TypeCall to GradCell
- * or unfold and recursively visit if TupleType
- */
-class OutputVisitor : public ExprFunctor<Expr(const Expr&, const Type&)> {
- public:
-  explicit OutputVisitor(IRModule module) : module_(module) {}
-
-  Expr VisitExpr_(const CallNode* op, const Type& t) final {
-    return UnwrapExpr(GetRef<Call>(op), t);
+    return var;
   }
 
-  Expr VisitExpr_(const TupleGetItemNode* op, const Type& t) final {
-    return UnwrapExpr(GetRef<TupleGetItem>(op), t);
-  }
-
- private:
-  IRModule module_;
-
-  Expr UnwrapExpr(const Expr expr, const Type& type) {
+  Expr UnwrapExpr(const Var& var, const Type& type, LetList* ll) {
     if (auto* type_call = type.as<TypeCallNode>()) {
       if (type_call->func.same_as(module_->GetGlobalTypeVar("GradCell"))) {
-        return Call(module_->GetGlobalVar("FromGradCell"), {expr});
+        return Call(module_->GetGlobalVar("FromGradCell"), {var});
       }
-      return expr;
+      return var;
     } else if (auto* type_anno = type.as<TupleTypeNode>()) {
       tvm::Array<Expr> fields;
       for (size_t i = 0; i < type_anno->fields.size(); i++) {
         const Type& t = type_anno->fields[i];
-        fields.push_back(this->VisitExpr(TupleGetItem(expr, i), t));
+        fields.push_back(UnwrapExpr(ll->Push(TupleGetItem(var, i)), t, ll));
       }
       Expr tuple = Tuple(fields);
       return tuple;
     }
 
-    return expr;
+    return var;
   }
-};
 
-class LazyGradientInitializer : public ExprMutator, public TypeMutator {
- public:
-  explicit LazyGradientInitializer(IRModule module) : module_(module) {
-    module_->ImportFromStd("gradient.rly");
+  // Turn off memo for constant node.
+  Expr VisitExpr(const Expr& e) final {
+    if (e.as<ConstantNode>()) {
+      return ExprFunctor::VisitExpr(e);
+    } else {
+      return ExprMutator::VisitExpr(e);
+    }
   }
 
   /*!
@@ -165,23 +128,26 @@ class LazyGradientInitializer : public ExprMutator, public TypeMutator {
    * input/output types should only be a combination of TupleTypes and TensorTypes
    */
   Expr Transform(const Expr& e) {
-    auto* f = (e).as<FunctionNode>();
+    auto* f = e.as<FunctionNode>();
     auto* transformed = this->Mutate(e).as<FunctionNode>();
 
+    CHECK(f);
+    CHECK(transformed);
+
     if (e.same_as(GetRef<Function>(transformed))) {
       return GetRef<Function>(transformed);
     }
 
-    // wrap inputs of Tensor type using InputVisitor class
-    tvm::Array<Expr> args;
-    for (Var var : f->params) {
-      Expr wrappedInput = InputVisitor(module_).VisitExpr(var, var->checked_type());
-      args.push_back(wrappedInput);
-    }
-    Expr transformedExpr = Call(GetRef<Function>(transformed), args);
-
-    // unwrap outputs of GradCell type into Tensor type using OutputVisitor class
-    Expr tensorOutput = OutputVisitor(module_).VisitExpr(transformedExpr, transformed->ret_type);
+    auto tensorOutput = LetList::With([&](LetList* ll) {
+      // wrap inputs of Tensor type using InputVisitor class
+      tvm::Array<Expr> args;
+      for (const Var& var : f->params) {
+        args.push_back(WrapExpr(var, var->checked_type(), ll));
+      }
+      Expr transformedExpr = Call(GetRef<Function>(transformed), args);
+      // unwrap outputs of GradCell type into Tensor type using OutputVisitor class
+      return UnwrapExpr(ll->Push(transformedExpr), transformed->ret_type, ll);
+    });
     return Function(f->params, tensorOutput, f->ret_type, Array<TypeVar>());
   }
 
@@ -293,7 +259,10 @@ class LazyGradientInitializer : public ExprMutator, public TypeMutator {
 };
 
 Expr LazyGradientInit(const Expr& e, IRModule mod) {
-  return LazyGradientInitializer(mod).Transform(e);
+  CheckFeature(e, mod, FeatureSet::All() - fGraph);
+  auto ret = LazyGradientInitializer(mod).Transform(e);
+  CheckFeature(ret, mod, FeatureSet::All() - fGraph);
+  return ret;
 }
 
 namespace transform {
index 63bd04d..e07dbea 100644 (file)
@@ -92,6 +92,7 @@
 #include <tvm/ir/type_functor.h>
 #include <tvm/relay/analysis.h>
 #include <tvm/relay/expr_functor.h>
+#include <tvm/relay/feature.h>
 #include <tvm/relay/interpreter.h>
 #include <tvm/relay/pattern_functor.h>
 #include <tvm/relay/transform.h>
@@ -1181,6 +1182,7 @@ Expr PostProcess(const Expr& e) { return StripWithFuncId(DeDup(Remap(e))); }
 }  // namespace partial_eval
 
 IRModule PartialEval(const IRModule& m) {
+  CheckFeature(m, FeatureSet::All() - fGraph);
   relay::partial_eval::PartialEvaluator pe(m);
   std::vector<GlobalVar> gvs;
   for (const auto& p : m->functions) {
@@ -1189,6 +1191,7 @@ IRModule PartialEval(const IRModule& m) {
   for (const auto& gv : gvs) {
     pe.VisitGlobalVar(gv);
   }
+  CheckFeature(m, FeatureSet::All() - fGraph);
   return m;
 }
 
@@ -1197,7 +1200,7 @@ namespace transform {
 Pass PartialEval() {
   runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func =
       [=](IRModule m, PassContext pc) { return relay::PartialEval(m); };
-  return CreateModulePass(pass_func, 1, "PartialEvaluate", {});
+  return CreateModulePass(pass_func, 1, "PartialEval", {});
 }
 
 TVM_REGISTER_GLOBAL("relay._transform.PartialEvaluate").set_body_typed(PartialEval);
index 50d0fbb..63708c4 100644 (file)
@@ -117,7 +117,8 @@ inline Expr TransformF(const std::function<Expr(const Expr&)>& func, const Expr&
  *   if so, the compute cost of the expression is bounded so it can be copy without graph mode.
  */
 inline bool IsAtomic(const Expr& e) {
-  return e.as<VarNode>() || e.as<OpNode>() || e.as<ConstructorNode>() || e.as<GlobalVarNode>();
+  return e.as<VarNode>() || e.as<OpNode>() || e.as<ConstructorNode>() || e.as<GlobalVarNode>() ||
+         e.as<ConstantNode>();  // Constant is always by reference.
 }
 
 /*!
index 06e0d56..adb757b 100644 (file)
@@ -252,32 +252,6 @@ Expr Fill::VisitExpr_(const MatchNode* m, const Var& v) {
   return Compound(e, Match(data, clauses, m->complete), v);
 }
 
-Expr ToANormalFormAux(const Expr& e) {
-  /* When you lift a lambda, what is inside is also being lift.
-   *
-   * So we must determine the scope of the lambda before determining the scope of it's body.
-   *
-   * To make this more principled,
-   * we always determine the scope of parent before determining the scope of children.
-   *
-   * So we calculate all the dependency between nodes.
-   */
-  support::Arena arena;
-  DependencyGraph dg = DependencyGraph::Create(&arena, e);
-  /* In order to model new subscopes created by lambda, if else and pattern matching,
-   * we also assign scope to edge as well.
-   * The scope of an edge is either the parent's scope, or a new subscope of the parent's scope.
-   *
-   * So, the scope of the whole expr is global.
-   * The scope of any subexpr, is the lowest common ancestor of all incoming edge.
-   *
-   * Every scope additionally contain a LetList which collect all value of that scope.
-   * We do an additional pass to fill all the LetList and we are done.
-   */
-  std::pair<NodeScopeMap, ExprSet> scopes = CalcScope(dg);
-  return Fill::ToANormalForm(e, dg, &scopes.first);
-}
-
 IRModule ToANormalForm(const IRModule& m) {
   DLOG(INFO) << "ToANF:" << std::endl << m;
 
@@ -288,7 +262,7 @@ IRModule ToANormalForm(const IRModule& m) {
     if (const auto* n = it.second.as<FunctionNode>()) {
       if (n->GetAttr<String>(attr::kCompiler).defined()) continue;
     }
-    Expr ret = TransformF([&](const Expr& e) { return ToANormalFormAux(e); }, it.second);
+    Expr ret = TransformF([&](const Expr& e) { return transform::ToANormalForm(e); }, it.second);
     CHECK_EQ(FreeVars(ret).size(), 0)
         << AsText(ret) << "should not has free vars: " << FreeVars(ret);
     updates.Set(it.first, Downcast<Function>(ret));
@@ -305,13 +279,45 @@ IRModule ToANormalForm(const IRModule& m) {
 
 namespace transform {
 
+Expr ToANormalForm(const Expr& e) {
+  /* When you lift a lambda, what is inside is also being lift.
+   *
+   * So we must determine the scope of the lambda before determining the scope of it's body.
+   *
+   * To make this more principled,
+   * we always determine the scope of parent before determining the scope of children.
+   *
+   * So we calculate all the dependency between nodes.
+   */
+  support::Arena arena;
+  DependencyGraph dg = DependencyGraph::Create(&arena, e);
+  /* In order to model new subscopes created by lambda, if else and pattern matching,
+   * we also assign scope to edge as well.
+   * The scope of an edge is either the parent's scope, or a new subscope of the parent's scope.
+   *
+   * So, the scope of the whole expr is global.
+   * The scope of any subexpr, is the lowest common ancestor of all incoming edge.
+   *
+   * Every scope additionally contain a LetList which collect all value of that scope.
+   * We do an additional pass to fill all the LetList and we are done.
+   */
+  std::pair<NodeScopeMap, ExprSet> scopes = CalcScope(dg);
+  return Fill::ToANormalForm(e, dg, &scopes.first);
+}
+
 Pass ToANormalForm() {
   runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func =
       [=](IRModule m, PassContext pc) { return relay::ToANormalForm(m); };
   return CreateModulePass(pass_func, 1, "ToANormalForm", {});
 }
 
-TVM_REGISTER_GLOBAL("relay._transform.ToANormalForm").set_body_typed(ToANormalForm);
+TVM_REGISTER_GLOBAL("relay._transform.ToANormalForm").set_body_typed([]() {
+  return ToANormalForm();
+});
+
+TVM_REGISTER_GLOBAL("relay._transform.ToANormalFormExpr").set_body_typed([](const Expr& e) {
+  return ToANormalForm(e);
+});
 
 }  // namespace transform
 
index 6972d5a..7c11ce5 100644 (file)
@@ -52,6 +52,7 @@
  */
 #include <tvm/ir/type_functor.h>
 #include <tvm/relay/expr_functor.h>
+#include <tvm/relay/feature.h>
 #include <tvm/relay/pattern_functor.h>
 #include <tvm/relay/transform.h>
 
@@ -301,11 +302,13 @@ Function ToCPS(const Function& f, const IRModule& m, CPSMap* cm) {
 }
 
 Function ToCPS(const Function& f, const IRModule& m) {
+  CheckFeature(f, m, FeatureSet::All() - fGraph);
   CPSMap cps;
   return ToCPS(f, m, &cps);
 }
 
 Function UnCPS(const Function& f) {
+  CheckFeature(f, FeatureSet::All() - fGraph);
   CHECK_GT(f->params.size(), 0);
   std::vector<Var> new_params;
   for (const auto& p : f->params) {
index ec5deb3..2b32376 100644 (file)
@@ -39,7 +39,6 @@ def test_prelude():
         Feature.fIf,
         Feature.fConstructor,
         Feature.fMatch,
-        Feature.fGraph
     ])
 
 
@@ -65,7 +64,6 @@ def test_ad():
         Feature.fRefCreate,
         Feature.fRefRead,
         Feature.fRefWrite,
-        Feature.fGraph
     ])
 
 
index 2c749c9..b8624b4 100644 (file)
@@ -32,17 +32,21 @@ def test_cross_entropy_with_logits_grad():
         x = relay.var("x", shape=(2, 5), dtype=dtype)
         y = relay.var("y", shape=(2, 5), dtype=dtype)
         check_grad(relay.Function([x, y], relay.op.nn.cross_entropy_with_logits(x, y)), eps=0.01, scale=0.1, mean=1)
-    
+
+
 def test_checkpoint():
     inputs = [relay.var("x{}".format(i), shape=(1,)) for i in range(4)]
     output = relay.multiply(relay.add(inputs[0], inputs[1]),
                             relay.add(inputs[2], inputs[3]))
     check_grad(relay.Function(inputs, relay.annotation.checkpoint(output)))
 
-    out_tuple = relay.Tuple([relay.add(inputs[0], inputs[1]),
-                             relay.multiply(inputs[2], inputs[3])])
-    out_single = relay.subtract(relay.TupleGetItem(relay.annotation.checkpoint(out_tuple), 0),
-                                relay.TupleGetItem(out_tuple, 1))
+    scope = relay.ScopeBuilder()
+    out_tuple = scope.let("out_tuple",
+                          relay.Tuple([relay.add(inputs[0], inputs[1]),
+                                       relay.multiply(inputs[2], inputs[3])]))
+    scope.ret(relay.subtract(relay.annotation.checkpoint(relay.TupleGetItem(out_tuple, 0)),
+                                relay.TupleGetItem(out_tuple, 1)))
+    out_single = scope.get()
     check_grad(relay.Function(inputs, out_single))
 
 
index 4838c6a..296d3e5 100644 (file)
@@ -45,6 +45,18 @@ def test_id():
     tvm.testing.assert_allclose(grad.asnumpy(), np.ones_like(x.asnumpy()))
 
 
+def test_relu():
+    shape = (10, 10)
+    dtype = 'float32'
+    t = relay.TensorType(shape, dtype)
+    x = relay.var("x", t)
+    func = relay.Function([x], op.nn.relu(x))
+    func = run_infer_type(func)
+    back_func = run_infer_type(gradient(func))
+    assert back_func.checked_type == relay.FuncType([t], relay.TupleType([t, relay.TupleType([t])]))
+    # gradient will implicitly check that no graph appear in result
+
+
 def test_add():
     shape = (10, 10)
     dtype = 'float32'
@@ -72,12 +84,14 @@ def test_check_grad():
 
 
 def test_temp_add():
+    scope = relay.ScopeBuilder()
     shape = (10, 10)
     dtype = 'float32'
     t = relay.TensorType(shape, dtype)
     x = relay.var("x", t)
-    y = x + x
-    func = relay.Function([x], y + y)
+    y = scope.let("y", x + x)
+    scope.ret(y + y)
+    func = relay.Function([x], scope.get())
     func = run_infer_type(func)
     back_func = run_infer_type(gradient(func))
     assert back_func.checked_type == relay.FuncType([t], relay.TupleType([t, relay.TupleType([t])]))
@@ -280,12 +294,14 @@ def test_if():
 
 
 def test_grad_tuple():
+    scope = relay.ScopeBuilder()
     shape = (10, 10)
     dtype = 'float32'
     t = relay.TensorType(shape, dtype)
     x = relay.var("x", t)
-    y = x + x
-    func = relay.Function([x], relay.Tuple([y + y, y]))
+    y = scope.let("y", x + x)
+    scope.ret(relay.Tuple([y + y, y]))
+    func = relay.Function([x], scope.get())
     func = run_infer_type(func)
     back_func = run_infer_type(gradient(func))
     assert back_func.checked_type == relay.FuncType([t], relay.TupleType([relay.TupleType([t, t]), relay.TupleType([t])]))
index 4149268..377164e 100644 (file)
@@ -229,6 +229,24 @@ def test_multivar_reverse_ad():
   assert_allclose(grad_x.asnumpy(), y.asnumpy())
   assert_allclose(grad_y.asnumpy(), x.asnumpy())
 
+def test_partial_eval():
+  """Test transformation following reverse mode ad and PartialEval"""
+  mod = tvm.IRModule()
+
+  shape = (10, 10)
+  dtype = 'float32'
+  t = relay.TensorType(shape, dtype)
+
+  func = relay.Function([], relay.const(np.ones(shape, dtype)))
+  func = run_infer_type(func)
+  back_func = transform.gradient(func)
+  back_func = run_infer_type(back_func)
+
+  mod["main"] = back_func
+  back_func = mod["main"]
+
+  transform.PartialEvaluate()(mod)
+
 def test_after_partial_eval():
   """Test transformation following reverse mode ad and PartialEval"""
   mod = tvm.IRModule()
index ddb5b5d..aef6ab5 100644 (file)
@@ -15,6 +15,7 @@
 # specific language governing permissions and limitations
 # under the License.
 """Unit tests for merge composite."""
+import pytest
 import tvm
 from tvm import relay, tir
 from tvm.relay.dataflow_pattern import TupleGetItemPattern, is_op, wildcard
@@ -213,7 +214,7 @@ def test_simple_merge():
         r = relay.Call(add_relu, [a, b])
         return relay.Function([a, b], r)
 
-    check_result(pattern_table, before(), expected(), import_prelude=True)
+    check_result(pattern_table, before(), expected())
 
 
 def test_branch_merge():
@@ -998,15 +999,4 @@ def test_type_check():
 
 
 if __name__ == "__main__":
-    test_simple_merge()
-    test_branch_merge()
-    test_multiple_patterns()
-    test_optional_pattern()
-    test_merge_order()
-    test_parallel_merge()
-    test_multiple_input_subgraphs()
-    test_reuse_call_merge()
-    test_tuple_get_item_merge()
-    test_pattern_with_check()
-    test_diamond_not_merge()
-    test_type_check()
+    pytest.main([__file__])
index 45593b4..95805d2 100644 (file)
@@ -15,6 +15,7 @@
 # specific language governing permissions and limitations
 # under the License.
 
+import pytest
 import numpy as np
 import tvm
 from tvm import te
@@ -173,10 +174,9 @@ def test_function_invalidate():
 def test_head_cons():
     mod = tvm.IRModule()
     p = Prelude(mod)
-    hd = p.hd
     t = TypeVar("t")
     x = Var("x", t)
-    body = hd(p.cons(x, p.nil()))
+    body = p.hd(p.cons(x, p.nil()))
     f = Function([x], body, None, [t])
     res = dcpe(f, mod)
     assert tvm.ir.structural_equal(res, Function([x], x, t, [t]))
@@ -340,23 +340,4 @@ def test_tuple_match():
 
 
 if __name__ == '__main__':
-    test_nat_update()
-    test_ref()
-    test_tuple()
-    test_empty_ad()
-    test_const_inline()
-    test_ad()
-    test_if_ref()
-    test_function_invalidate()
-    test_head_cons()
-    test_map()
-    test_loop()
-    test_swap_loop()
-    test_abs_diff()
-    test_double()
-    test_nat_id()
-    test_global_match_nat_id()
-    test_match_nat_id()
-    test_concat()
-    test_triangle_number()
-    test_tuple_match()
+    pytest.main([__file__])