From 52fde8f773a9e13b905c014ed34abf9efc10fbc3 Mon Sep 17 00:00:00 2001
From: =?utf8?q?=E9=9B=BE=E9=9B=A8=E9=AD=94=E7=90=86=E6=B2=99?=
Date: Fri, 9 Aug 2019 12:40:16 -0700
Subject: [PATCH] [Relay] [Training] Fix ad for concatenate (#3729)

* reproduce error

* fix

* lint

* lint
---
 python/tvm/relay/op/_tensor_grad.py      | 13 ++++-
 src/relay/ir/alpha_equal.cc              |  9 ++-
 src/relay/pass/gradient.cc               | 96 +++++++++++++++++++++++++++-----
 tests/python/relay/test_pass_gradient.py | 17 +++++-
 4 files changed, 115 insertions(+), 20 deletions(-)

diff --git a/python/tvm/relay/op/_tensor_grad.py b/python/tvm/relay/op/_tensor_grad.py
index 3e64a97..4370863 100644
--- a/python/tvm/relay/op/_tensor_grad.py
+++ b/python/tvm/relay/op/_tensor_grad.py
@@ -17,7 +17,7 @@
 #pylint: disable=invalid-name, unused-argument
 """Backend compiler related feature registration"""
 from __future__ import absolute_import
-from ..expr import const
+from ..expr import const, Tuple, TupleGetItem
 from .op import register_gradient
 from .transform import collapse_sum_like, broadcast_to_like, where
 from .tensor import exp, negative, power, less, cos, sin
@@ -176,3 +176,14 @@ def avg_pool2d_grad(orig, grad):
                          layout=attrs.layout, ceil_mode=attrs.ceil_mode,
                          count_include_pad=attrs.count_include_pad)
     return [pool_grad]
+
+# Not implemented; this gradient is only a placeholder for testing.
+@register_gradient("concatenate")
+def concatenate_grad(orig, grad):
+    assert len(orig.args) == 1
+    t = orig.args[0]
+    x = TupleGetItem(t, 0)
+    y = TupleGetItem(t, 1)
+    # Assume the tuple has exactly two elements for now.
+    # A real implementation of concatenate_grad will probably need a dedicated operator.
+    return [Tuple([zeros_like(x), zeros_like(y)])]
diff --git a/src/relay/ir/alpha_equal.cc b/src/relay/ir/alpha_equal.cc
index ea27027..2c23f0f 100644
--- a/src/relay/ir/alpha_equal.cc
+++ b/src/relay/ir/alpha_equal.cc
@@ -117,9 +117,12 @@ class AlphaEqualHandler:
    * \return the comparison result.
    */
   bool TypeEqual(const Type& lhs, const Type& rhs) {
-    if (lhs.same_as(rhs)) return true;
-    if (!lhs.defined() || !rhs.defined()) return false;
-    return this->VisitType(lhs, rhs);
+    auto compute = [&]() {
+      if (lhs.same_as(rhs)) return true;
+      if (!lhs.defined() || !rhs.defined()) return false;
+      return this->VisitType(lhs, rhs);
+    };
+    return Compare(compute(), lhs, rhs);
   }
 
   bool Compare(bool result, const NodeRef& lhs, const NodeRef& rhs) {
diff --git a/src/relay/pass/gradient.cc b/src/relay/pass/gradient.cc
index 12cf4a1..dbef374 100644
--- a/src/relay/pass/gradient.cc
+++ b/src/relay/pass/gradient.cc
@@ -29,6 +29,7 @@
 #include
 #include
 #include "pattern_util.h"
+#include "pass_util.h"
 #include "let_list.h"
 #include "../ir/type_functor.h"

@@ -257,11 +258,79 @@ struct ReverseADType : TypeMutator {
   }
 };
 
+Type ReverseType(const Type& t) {
+  return ReverseADType()(t);
+}
+
+/*! \brief Lift a function that transforms Tensors into a function that also transforms richer
+ *  types by doing a structure-preserving map.
+ */
+Expr LiftTensor(const std::function<Expr(const Expr&)>& f,
+                const Type& t,
+                const Expr& e,
+                LetList* ll) {
+  CHECK(IsAtomic(e)) << e;
+  if (t.as<TensorTypeNode>()) {
+    return f(e);
+  } else if (auto* tt = t.as<TupleTypeNode>()) {
+    tvm::Array<Expr> fields;
+    for (size_t i = 0; i < tt->fields.size(); ++i) {
+      fields.push_back(LiftTensor(f,
+                                  tt->fields[i],
+                                  ll->Push(GetField(e, i)),
+                                  ll));
+    }
+    return TupleNode::make(fields);
+  } else {
+    LOG(FATAL) << "unsupported input/output type: " << t;
+    throw;
+  }
+}
+
+/*! \brief t -> ReverseType(t). Transform to Reverse Mode Value. */
+Expr GetRev(const Type& t, const Expr& e, LetList* ll) {
+  auto rev = [&](const Expr& e) {
+    return Pair(e, ll->Push(RefCreateNode::make(ZerosLike(e))));
+  };
+  return LiftTensor(rev, t, e, ll);
+}
+
+/*! \brief ReverseType(t) -> t. Get the original value. */
+Expr GetValue(const Type& t, const Expr& e, LetList* ll) {
+  return LiftTensor([&](const Expr& e) { return GetField(e, 0); }, t, e, ll);
+}
+
+/*! \brief ReverseType(t) -> t. Get the gradient. */
+Expr GetGrad(const Type& t, const Expr& e, LetList* ll) {
+  auto grad = [&](const Expr& e) {
+    return ll->Push(RefReadNode::make(GetField(e, 1)));
+  };
+  return LiftTensor(grad, t, e, ll);
+}
+
+void UpdateGrad(const Type& t, const Expr& arg, const Expr& grad, LetList* ll) {
+  if (t.as<TensorTypeNode>()) {
+    ll->Push(RefWriteNode::make(GetField(arg, 1),
+                                Add(ll->Push(RefReadNode::make(GetField(arg, 1))),
+                                    grad)));
+  } else if (auto* tt = t.as<TupleTypeNode>()) {
+    for (size_t i = 0; i < tt->fields.size(); ++i) {
+      UpdateGrad(tt->fields[i],
+                 ll->Push(GetField(arg, i)),
+                 ll->Push(GetField(grad, i)),
+                 ll);
+    }
+  } else {
+    LOG(FATAL) << "unsupported arg type of operator: " << t;
+    throw;
+  }
+}
+
 struct ReverseAD : ExprMutator {
   Var bp;
   const OpMap<FPrimalGradient> rev_map = Op::GetAttr<FPrimalGradient>("FPrimalGradient");
 
-  ReverseAD(const Var& bp) : bp(bp) { } /// NOLINT(*)
+  explicit ReverseAD(const Var& bp) : bp(bp) { }
 
   Expr VisitExpr_(const OpNode* op) final {
     LOG(FATAL) << "op should only be inside call";
@@ -279,29 +348,26 @@ struct ReverseAD : ExprMutator {
          args.push_back(ll->Push(VisitExpr(arg)));
        }
        std::vector<Expr> orig_args;
-       for (const auto& arg : args) {
-         orig_args.push_back(GetField(arg, 0));
+       for (size_t i = 0; i < args.size(); ++i) {
+         orig_args.push_back(GetValue(op->args[i]->checked_type(), args[i], ll));
        }
        Expr orig = CallNode::make(op->op, orig_args, op->attrs, op->type_args);
-       Var orig_var = ll->Push(orig);
-       auto ref = ll->Push(RefCreateNode::make(ZerosLike(orig_var)));
+       auto ret = ll->Push(GetRev(op->checked_type(), ll->Push(orig), ll));
        auto bpv = ll->Push(RefReadNode::make(bp));
        Expr nbp = FunctionNode::make(
            {}, LetList::With([&](LetList* ll) {
-             tvm::Array<Expr> rev = rev_map[op_ref](orig, ll->Push(RefReadNode::make(ref)));
-             CHECK(args.size() == rev.size());
-             for (size_t i = 0; i < args.size(); ++i) {
-               ll->Push(RefWriteNode::make(GetField(args[i], 1),
-                                           Add(ll->Push(RefReadNode::make(GetField(args[i], 1))),
-                                               rev[i])));
-             }
+             tvm::Array<Expr> rev = rev_map[op_ref](orig, GetGrad(op->checked_type(), ret, ll));
+             CHECK(args.size() == rev.size());
+             for (size_t i = 0; i < args.size(); ++i) {
+               UpdateGrad(op->args[i]->checked_type(), args[i], rev[i], ll);
+             }
              return CallNode::make(bpv, {});
-          }),
+           }),
            TupleTypeNode::make({}), {});
        ll->Push(RefWriteNode::make(bp, nbp));
-       return Pair(orig_var, ref);
+       return ret;
      });
    }
    return ExprMutator::VisitExpr_(op);
  }
 
@@ -319,7 +385,7 @@ struct ReverseAD : ExprMutator {
   }
 
   Type VisitType(const Type& t) final {
-    return t.defined() ? ReverseADType()(t) : t;
+    return t.defined() ? ReverseType(t) : t;
   }
 };
diff --git a/tests/python/relay/test_pass_gradient.py b/tests/python/relay/test_pass_gradient.py
index 8e901e7..8e4b701 100644
--- a/tests/python/relay/test_pass_gradient.py
+++ b/tests/python/relay/test_pass_gradient.py
@@ -18,11 +18,12 @@
 import numpy as np
 import tvm
 from tvm import relay
-from tvm.relay.analysis import free_vars, free_type_vars
+from tvm.relay.analysis import free_vars, free_type_vars, assert_alpha_equal
 from tvm.relay import create_executor, transform
 from tvm.relay.transform import gradient
 from tvm.relay.prelude import Prelude
 from tvm.relay.testing import add_nat_definitions, make_nat_expr, run_infer_type, check_grad, rand
+import tvm.relay.op as op
 
 
 def test_id():
@@ -280,6 +281,20 @@ def test_grad_tuple():
     tvm.testing.assert_allclose(grad.asnumpy(), 4 * np.ones_like(x.asnumpy()))
 
 
+def test_concat():
+    shape = (10, 10)
+    dtype = 'float32'
+    t = relay.TensorType(shape, dtype)
+    rt = relay.TensorType((10, 20), dtype)
+    x = relay.var("x", t)
+    y = op.concatenate([x, x], axis=1)
+    func = relay.Function([x], y)
+    func = run_infer_type(func)
+    back_func = run_infer_type(gradient(func))
+    assert_alpha_equal(back_func.checked_type, relay.FuncType([t], relay.TupleType([rt, relay.TupleType([t])])))
+    # No value validation, since concatenate only has a dummy gradient right now.
+
+
 if __name__ == "__main__":
     test_id()
     test_add()
-- 
2.7.4
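
Note: for readers who want to exercise the new code path outside the test suite, the sketch below mirrors test_concat above. It assumes a TVM build with this patch applied, and only inspects the type of the generated backward function, since the registered concatenate gradient is still a zero-filled placeholder.

import tvm
from tvm import relay
from tvm.relay.transform import gradient
from tvm.relay.testing import run_infer_type
import tvm.relay.op as op

# A function that concatenates its (10, 10) input with itself along axis 1.
x = relay.var("x", relay.TensorType((10, 10), "float32"))
func = run_infer_type(relay.Function([x], op.concatenate([x, x], axis=1)))

# Reverse-mode AD now type checks on tuple-typed operator arguments:
# the result pairs the (10, 20) forward value with a one-element tuple
# holding the gradient for the single (10, 10) input.
back_func = run_infer_type(gradient(func))
print(back_func.checked_type)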