[Relay] [Training] Fix ad for concatenate (#3729)
author雾雨魔理沙 <lolisa@marisa.moe>
Fri, 9 Aug 2019 19:40:16 +0000 (12:40 -0700)
committerThierry Moreau <moreau@uw.edu>
Fri, 9 Aug 2019 19:40:16 +0000 (12:40 -0700)
* reproduce error

* fix

* lint

* lint

python/tvm/relay/op/_tensor_grad.py
src/relay/ir/alpha_equal.cc
src/relay/pass/gradient.cc
tests/python/relay/test_pass_gradient.py

index 3e64a97..4370863 100644 (file)
@@ -17,7 +17,7 @@
 #pylint: disable=invalid-name, unused-argument
 """Backend compiler related feature registration"""
 from __future__ import absolute_import
-from ..expr import const
+from ..expr import const, Tuple, TupleGetItem
 from .op import register_gradient
 from .transform import collapse_sum_like, broadcast_to_like, where
 from .tensor import exp, negative, power, less, cos, sin
@@ -176,3 +176,14 @@ def avg_pool2d_grad(orig, grad):
                                     layout=attrs.layout, ceil_mode=attrs.ceil_mode,
                                     count_include_pad=attrs.count_include_pad)
     return [pool_grad]
+
+# not implemented, this is only for testing.
+@register_gradient("concatenate")
+def concatenate_grad(orig, grad):
+    assert len(orig.args) == 1
+    t = orig.args[0]
+    x = TupleGetItem(t, 0)
+    y = TupleGetItem(t, 1)
+    # Assume only two element in tuple rn.
+    # In the real implementation, concatenate_grad probably need to be implemented by an operator.
+    return [Tuple([zeros_like(x), zeros_like(y)])]
index ea27027..2c23f0f 100644 (file)
@@ -117,9 +117,12 @@ class AlphaEqualHandler:
    * \return the comparison result.
    */
   bool TypeEqual(const Type& lhs, const Type& rhs) {
-    if (lhs.same_as(rhs)) return true;
-    if (!lhs.defined() || !rhs.defined()) return false;
-    return this->VisitType(lhs, rhs);
+    auto compute = [&](){
+      if (lhs.same_as(rhs)) return true;
+      if (!lhs.defined() || !rhs.defined()) return false;
+      return this->VisitType(lhs, rhs);
+    };
+    return Compare(compute(), lhs, rhs);
   }
 
   bool Compare(bool result, const NodeRef& lhs, const NodeRef& rhs) {
index 12cf4a1..dbef374 100644 (file)
@@ -29,6 +29,7 @@
 #include <tvm/relay/analysis.h>
 #include <tvm/relay/transform.h>
 #include "pattern_util.h"
+#include "pass_util.h"
 #include "let_list.h"
 #include "../ir/type_functor.h"
 
@@ -257,11 +258,79 @@ struct ReverseADType : TypeMutator {
   }
 };
 
+Type ReverseType(const Type& t) {
+  return ReverseADType()(t);
+}
+
+/*! \brief Lift a function that transform Tensor to a function that also transform more type
+ * by doing a structure preserving map.
+ */
+Expr LiftTensor(const std::function<Expr(const Expr& t)>& f,
+                const Type& t,
+                const Expr& e,
+                LetList* ll) {
+  CHECK(IsAtomic(e)) << e;
+  if (t.as<TensorTypeNode>()) {
+    return f(e);
+  } else if (auto* tt = t.as<TupleTypeNode>()) {
+    tvm::Array<Expr> fields;
+    for (size_t i = 0; i < tt->fields.size(); ++i) {
+      fields.push_back(LiftTensor(f,
+                                  tt->fields[i],
+                                  ll->Push(GetField(e, i)),
+                                  ll));
+    }
+    return TupleNode::make(fields);
+  } else {
+    LOG(FATAL) << "unsupported input/output type: " << tt;
+    throw;
+  }
+}
+
+/*! \brief t -> ReverseType(t). Transform to Reverse Mode Value. */
+Expr GetRev(const Type& t, const Expr& e, LetList* ll) {
+  auto rev = [&](const Expr& e) {
+    return Pair(e, ll->Push(RefCreateNode::make(ZerosLike(e))));
+  };
+  return LiftTensor(rev, t, e, ll);
+}
+
+/*! \brief ReverseType(t) -> t. Get the original value. */
+Expr GetValue(const Type& t, const Expr& e, LetList* ll) {
+  return LiftTensor([&](const Expr& e) { return GetField(e, 0); }, t, e, ll);
+}
+
+/*! \brief ReverseType(t) -> t. Get the gradient. */
+Expr GetGrad(const Type& t, const Expr& e, LetList* ll) {
+  auto grad = [&](const Expr& e) {
+    return ll->Push(RefReadNode::make(GetField(e, 1)));
+  };
+  return LiftTensor(grad, t, e, ll);
+}
+
+void UpdateGrad(const Type& t, const Expr& arg, const Expr& grad, LetList* ll) {
+  if (t.as<TensorTypeNode>()) {
+    ll->Push(RefWriteNode::make(GetField(arg, 1),
+                                Add(ll->Push(RefReadNode::make(GetField(arg, 1))),
+                                    grad)));
+  } else if (auto* tt = t.as<TupleTypeNode>()) {
+    for (size_t i = 0; i < tt->fields.size(); ++i) {
+      UpdateGrad(tt->fields[i],
+                 ll->Push(GetField(arg, i)),
+                 ll->Push(GetField(grad, i)),
+                 ll);
+    }
+  } else {
+    LOG(FATAL) << "unsupported arg type of operator: " << t;
+    throw;
+  }
+}
+
 struct ReverseAD : ExprMutator {
   Var bp;
   const OpMap<FPrimalGradient> rev_map = Op::GetAttr<FPrimalGradient>("FPrimalGradient");
 
-  ReverseAD(const Var& bp) : bp(bp) { } /// NOLINT(*)
+  explicit ReverseAD(const Var& bp) : bp(bp) { }
 
   Expr VisitExpr_(const OpNode* op) final {
     LOG(FATAL) << "op should only be inside call";
@@ -279,29 +348,26 @@ struct ReverseAD : ExprMutator {
           args.push_back(ll->Push(VisitExpr(arg)));
         }
         std::vector<Expr> orig_args;
-        for (const auto& arg : args) {
-          orig_args.push_back(GetField(arg, 0));
+        for (size_t i = 0; i < args.size(); ++i) {
+          orig_args.push_back(GetValue(op->args[i]->checked_type(), args[i], ll));
         }
         Expr orig = CallNode::make(op->op, orig_args, op->attrs, op->type_args);
-        Var orig_var = ll->Push(orig);
-        auto ref = ll->Push(RefCreateNode::make(ZerosLike(orig_var)));
+        auto ret = ll->Push(GetRev(op->checked_type(), ll->Push(orig), ll));
         auto bpv = ll->Push(RefReadNode::make(bp));
         Expr nbp = FunctionNode::make(
           {},
           LetList::With([&](LetList* ll) {
-              tvm::Array<Expr> rev = rev_map[op_ref](orig, ll->Push(RefReadNode::make(ref)));
-              CHECK(args.size() == rev.size());
-              for (size_t i = 0; i < args.size(); ++i) {
-                ll->Push(RefWriteNode::make(GetField(args[i], 1),
-                                            Add(ll->Push(RefReadNode::make(GetField(args[i], 1))),
-                                                rev[i])));
-              }
+            tvm::Array<Expr> rev = rev_map[op_ref](orig, GetGrad(op->checked_type(), ret, ll));
+            CHECK(args.size() == rev.size());
+            for (size_t i = 0; i < args.size(); ++i) {
+              UpdateGrad(op->args[i]->checked_type(), args[i], rev[i], ll);
+            }
             return CallNode::make(bpv, {});
-            }),
+          }),
           TupleTypeNode::make({}),
           {});
         ll->Push(RefWriteNode::make(bp, nbp));
-        return Pair(orig_var, ref);
+        return ret;
       });
     }
     return ExprMutator::VisitExpr_(op);
@@ -319,7 +385,7 @@ struct ReverseAD : ExprMutator {
   }
 
   Type VisitType(const Type& t) final {
-    return t.defined() ? ReverseADType()(t) : t;
+    return t.defined() ? ReverseType(t) : t;
   }
 };
 
index 8e901e7..8e4b701 100644 (file)
@@ -18,11 +18,12 @@ import numpy as np
 
 import tvm
 from tvm import relay
-from tvm.relay.analysis import free_vars, free_type_vars
+from tvm.relay.analysis import free_vars, free_type_vars, assert_alpha_equal
 from tvm.relay import create_executor, transform
 from tvm.relay.transform import gradient
 from tvm.relay.prelude import Prelude
 from tvm.relay.testing import add_nat_definitions, make_nat_expr, run_infer_type, check_grad, rand
+import tvm.relay.op as op
 
 
 def test_id():
@@ -280,6 +281,20 @@ def test_grad_tuple():
     tvm.testing.assert_allclose(grad.asnumpy(), 4 * np.ones_like(x.asnumpy()))
 
 
+def test_concat():
+    shape = (10, 10)
+    dtype = 'float32'
+    t = relay.TensorType(shape, dtype)
+    rt = relay.TensorType((10, 20), dtype)
+    x = relay.var("x", t)
+    y = op.concatenate([x, x], axis=1)
+    func = relay.Function([x], y)
+    func = run_infer_type(func)
+    back_func = run_infer_type(gradient(func))
+    assert_alpha_equal(back_func.checked_type, relay.FuncType([t], relay.TupleType([rt, relay.TupleType([t])])))
+    # no value validation as concatenate has dummy gradient right now.
+
+
 if __name__ == "__main__":
     test_id()
     test_add()