From d201978415f7e2a13565c12bdd4fd38d2557708f Mon Sep 17 00:00:00 2001
From: Wuwei Lin
Date: Thu, 29 Aug 2019 13:36:06 -0400
Subject: [PATCH] [Relay] Conv2d grad (#3636)

* [Relay] Conv2d grad

* Fix test

* Fix first order gradient
---
 python/tvm/relay/op/_tensor_grad.py       | 65 ++++++++++++++++++++++++++++++-
 python/tvm/relay/op/nn/nn.py              |  6 ++-
 src/relay/op/nn/convolution.cc            |  2 +
 src/relay/pass/gradient.cc                | 36 +++++++++++------
 tests/python/relay/test_op_grad_level2.py | 49 ++++++++++++++++++++++-
 5 files changed, 144 insertions(+), 14 deletions(-)

diff --git a/python/tvm/relay/op/_tensor_grad.py b/python/tvm/relay/op/_tensor_grad.py
index 4370863..41808b5 100644
--- a/python/tvm/relay/op/_tensor_grad.py
+++ b/python/tvm/relay/op/_tensor_grad.py
@@ -17,9 +17,13 @@
 #pylint: disable=invalid-name, unused-argument
 """Backend compiler related feature registration"""
 from __future__ import absolute_import
+from topi.util import get_const_tuple
+from topi.nn.util import get_pad_tuple
 from ..expr import const, Tuple, TupleGetItem
 from .op import register_gradient
-from .transform import collapse_sum_like, broadcast_to_like, where
+from .reduce import sum as _sum
+from .transform import collapse_sum_like, broadcast_to_like, where, transpose, reshape, tile, \
+    strided_slice
 from .tensor import exp, negative, power, less, cos, sin
 from .tensor import zeros_like, ones_like
 from . import nn as _nn
@@ -187,3 +191,62 @@ def concatenate_grad(orig, grad):
     # Assume only two element in tuple rn.
     # In the real implementation, concatenate_grad probably need to be implemented by an operator.
     return [Tuple([zeros_like(x), zeros_like(y)])]
+
+@register_gradient("nn.conv2d")
+def conv2d_grad(orig, grad):
+    """Gradient of conv2d"""
+    attrs = orig.attrs
+    data, weight = orig.args
+    data_shape = get_const_tuple(data.checked_type.shape)
+    weight_shape = get_const_tuple(weight.checked_type.shape)
+    _, _, grad_h, grad_w = get_const_tuple(orig.checked_type.shape)
+    batch, in_channel, in_h, in_w = data_shape
+    out_channel, _, filter_h, filter_w = weight_shape
+
+    # infer output_padding
+    fpad_top, fpad_left, fpad_bottom, fpad_right = get_pad_tuple(get_const_tuple(attrs.padding),
+                                                                 (filter_h, filter_w))
+    stride_h, stride_w = get_const_tuple(attrs.strides)
+    dilation_h, dilation_w = get_const_tuple(attrs.dilation)
+    out_h = (grad_h - 1) * stride_h - fpad_top - fpad_bottom + filter_h
+    out_w = (grad_w - 1) * stride_w - fpad_left - fpad_right + filter_w
+    output_padding = (in_h - out_h, in_w - out_w)
+
+    assert attrs.data_layout == 'NCHW', 'only support NCHW data layout'
+    assert attrs.kernel_layout == 'OIHW', 'only support OIHW kernel layout'
+    assert attrs.out_layout in ['', 'NCHW'], 'only support NCHW output layout'
+
+
+    backward_data = _nn.conv2d_transpose(grad, weight,
+                                         strides=attrs.strides,
+                                         padding=attrs.padding,
+                                         dilation=attrs.dilation,
+                                         groups=attrs.groups,
+                                         output_padding=output_padding)
+    grad = tile(grad, [1, in_channel // attrs.groups, 1, 1])
+    grad = reshape(grad, [-1, 1, 0, 0])  # batch * oc * ic // groups, 1, oh, ow
+    data = reshape(data, [1, -1, 0, 0])  # 1, batch * ic, ih, iw
+
+    backward_weight = _nn.conv2d(data, grad,
+                                 strides=attrs.dilation,
+                                 padding=attrs.padding,
+                                 dilation=attrs.strides,
+                                 groups=in_channel * batch)
+    # infer shape of backward_weight
+    padded_weight_grad_h = (in_h - (grad_h - 1) * stride_h - 1 + fpad_top + fpad_bottom) \
+                           // dilation_h + 1
+    padded_weight_grad_w = (in_w - (grad_w - 1) * stride_w - 1 + fpad_left + fpad_right) \
+                           // dilation_w + 1
+    backward_weight = reshape(backward_weight,
+                              [batch, in_channel // attrs.groups, out_channel,
+                               padded_weight_grad_h, padded_weight_grad_w])
+    backward_weight = _sum(backward_weight, axis=0)
+    backward_weight = transpose(backward_weight, [1, 0, 2, 3])
+
+    assert padded_weight_grad_h >= filter_h
+    assert padded_weight_grad_w >= filter_w
+    if padded_weight_grad_h > filter_h or padded_weight_grad_w > filter_w:
+        backward_weight = strided_slice(backward_weight, begin=[0, 0, 0, 0],
+                                        end=[None, None, filter_h, filter_w])
+
+    return [backward_data, backward_weight]
diff --git a/python/tvm/relay/op/nn/nn.py b/python/tvm/relay/op/nn/nn.py
index 4b7f52e..946ea33 100644
--- a/python/tvm/relay/op/nn/nn.py
+++ b/python/tvm/relay/op/nn/nn.py
@@ -116,6 +116,7 @@ def conv2d_transpose(data,
                      kernel_size=None,
                      data_layout="NCHW",
                      kernel_layout="OIHW",
+                     out_layout="",
                      output_padding=(0, 0),
                      out_dtype=""):
     """Two dimensional transposed convolution operator.
@@ -152,6 +153,9 @@ def conv2d_transpose(data,
     kernel_layout : str, optional
         Layout of the weight.
 
+    out_layout : Optional[str]
+        Layout of the output, by default, out_layout is the same as data_layout
+
     output_padding : Tuple[int], optional
         Additional zero-padding to be added to one side of the output.
 
@@ -165,7 +169,7 @@ def conv2d_transpose(data,
     """
     return _make.conv2d_transpose(data, weight, strides, padding, dilation,
                                   groups, channels, kernel_size, data_layout,
-                                  kernel_layout, output_padding, out_dtype)
+                                  kernel_layout, out_layout, output_padding, out_dtype)
 
 
 def softmax(data, axis=-1):
diff --git a/src/relay/op/nn/convolution.cc b/src/relay/op/nn/convolution.cc
index 79203d8..5eb54a1 100644
--- a/src/relay/op/nn/convolution.cc
+++ b/src/relay/op/nn/convolution.cc
@@ -320,6 +320,7 @@ Expr MakeConv2DTranspose(Expr data,
                          Array<IndexExpr> kernel_size,
                          std::string data_layout,
                          std::string kernel_layout,
+                         std::string out_layout,
                          Array<IndexExpr> output_padding,
                          DataType out_dtype) {
   auto attrs = make_node<Conv2DTransposeAttrs>();
@@ -332,6 +333,7 @@ Expr MakeConv2DTranspose(Expr data,
   attrs->groups = groups;
   attrs->data_layout = std::move(data_layout);
   attrs->kernel_layout = std::move(kernel_layout);
+  attrs->out_layout = std::move(out_layout);
   attrs->out_dtype = std::move(out_dtype);
   static const Op& op = Op::Get("nn.conv2d_transpose");
   return CallNode::make(op, {data, weight}, Attrs(attrs), {});
diff --git a/src/relay/pass/gradient.cc b/src/relay/pass/gradient.cc
index dbef374..2606910 100644
--- a/src/relay/pass/gradient.cc
+++ b/src/relay/pass/gradient.cc
@@ -109,7 +109,9 @@ struct ADTensor : ADValueNode {
   Expr forward;
   mutable Expr reverse;  // must be a variable to avoid duplication
   ADTensor(LetList* ll, const Expr& forward) :
-    forward(ll->Push(forward)), reverse(ll->Push(ZerosLike(this->forward))) { }
+    forward(ll->Push(forward)), reverse(ll->Push(ZerosLike(this->forward))) {
+    this->forward->checked_type_ = forward->checked_type();
+  }
 };
 
 /*! \brief A staged representation of the program, we reflect
@@ -117,10 +119,12 @@ struct ADTensor : ADValueNode {
  * can compute away this function to obtain a reverse mode program.
  */
 struct ADFunction : ADValueNode {
-  std::function<ADValue(const std::vector<ADValue>&,
+  std::function<ADValue(const Type&,
+                        const std::vector<ADValue>&,
                         const Attrs&,
                         const tvm::Array<Type>&)> func;
-  explicit ADFunction(const std::function<ADValue(const std::vector<ADValue>&,
+  explicit ADFunction(const std::function<ADValue(const Type&,
+                                                  const std::vector<ADValue>&,
                                                   const Attrs&,
                                                   const tvm::Array<Type>&)>& func) :
     func(func) { }
@@ -139,7 +143,8 @@ struct FirstOrderReverseAD : ExprFunctor<ADValue(const Expr&)> {
     Op op_ref = GetRef<Op>(op);
     CHECK(rev_map.count(op_ref))
       << op->name << " does not have reverse mode defined";
-    return std::make_shared<ADFunction>([this, op_ref](const std::vector<ADValue>& args,
+    return std::make_shared<ADFunction>([this, op_ref](const Type& orig_type,
+                                                       const std::vector<ADValue>& args,
                                                        const Attrs& attrs,
                                                        const tvm::Array<Type>& type_args) {
       std::vector<Expr> call_args;
@@ -147,6 +152,7 @@ struct FirstOrderReverseAD : ExprFunctor<ADValue(const Expr&)> {
         call_args.push_back(adval->get<ADTensor>().forward);
       }
       auto orig = CallNode::make(op_ref, call_args, attrs, type_args);
+      orig->checked_type_ = orig_type;
       auto ret = std::make_shared<ADTensor>(ll, orig);
       backprop_actions.push_back([this, args, orig, ret, op_ref](LetList* ll) {
         tvm::Array<Expr> rev = rev_map[op_ref](orig, ret->reverse);
@@ -171,13 +177,14 @@ struct FirstOrderReverseAD : ExprFunctor<ADValue(const Expr&)> {
     for (const auto& arg : op->args) {
       args.push_back(VisitExpr(arg));
     }
-    return f->get<ADFunction>().func(args, op->attrs, op->type_args);
+    return f->get<ADFunction>().func(op->checked_type(), args, op->attrs, op->type_args);
   }
 
   ADValue VisitExpr_(const FunctionNode* op) final {
     Function f = GetRef<Function>(op);
     // todo: assert no closure
-    return std::make_shared<ADFunction>([this, f](const std::vector<ADValue>& args,
+    return std::make_shared<ADFunction>([this, f](const Type& orig_type,
+                                                  const std::vector<ADValue>& args,
                                                   const Attrs& attrs,
                                                   const tvm::Array<Type>& type_args) {
       CHECK_EQ(f->params.size(), args.size());
@@ -227,7 +234,7 @@ Expr FirstOrderGradient(const Expr& re, const Module& mod) {
     for (const auto& p : f->params) {
       args.push_back(std::make_shared<ADTensor>(ll, p));
     }
-    auto c = rev->get<ADFunction>().func(args, Attrs(), {});
+    auto c = rev->get<ADFunction>().func(f->checked_type(), args, Attrs(), {});
     const auto& res = c->get<ADTensor>();
     Expr grad = LetList::With([&](LetList* ll) {
       res.reverse = OnesLike(res.forward);
@@ -271,7 +278,9 @@ Expr LiftTensor(const std::function<Expr(const Expr&)>& f,
                 LetList* ll) {
   CHECK(IsAtomic(e)) << e;
   if (t.as<TensorTypeNode>()) {
-    return f(e);
+    auto ret = f(e);
+    ret->checked_type_ = t;
+    return ret;
   } else if (auto* tt = t.as<TupleTypeNode>()) {
     tvm::Array<Expr> fields;
     for (size_t i = 0; i < tt->fields.size(); ++i) {
@@ -280,7 +289,9 @@ Expr LiftTensor(const std::function<Expr(const Expr&)>& f,
                                   ll->Push(GetField(e, i)),
                                   ll));
     }
-    return TupleNode::make(fields);
+    auto ret = TupleNode::make(fields);
+    ret->checked_type_ = t;
+    return std::move(ret);
   } else {
     LOG(FATAL) << "unsupported input/output type: " << tt;
     throw;
@@ -348,11 +359,14 @@ struct ReverseAD : ExprMutator {
       args.push_back(ll->Push(VisitExpr(arg)));
     }
     std::vector<Expr> orig_args;
-    for (size_t i = 0; i < args.size(); ++i) {
+    for (size_t i = 0; i < args.size(); i++) {
       orig_args.push_back(GetValue(op->args[i]->checked_type(), args[i], ll));
     }
     Expr orig = CallNode::make(op->op, orig_args, op->attrs, op->type_args);
-    auto ret = ll->Push(GetRev(op->checked_type(), ll->Push(orig), ll));
+    orig->checked_type_ = op->checked_type();
+    Var orig_var = ll->Push(orig);
+    orig_var->checked_type_ = op->checked_type();
+    auto ret = ll->Push(GetRev(op->checked_type(), orig_var, ll));
     auto bpv = ll->Push(RefReadNode::make(bp));
     Expr nbp = FunctionNode::make(
       {},
diff --git a/tests/python/relay/test_op_grad_level2.py b/tests/python/relay/test_op_grad_level2.py
index a0b52e5..33d7ecf 100644
--- a/tests/python/relay/test_op_grad_level2.py
+++ b/tests/python/relay/test_op_grad_level2.py
@@ -20,7 +20,7 @@ import topi
 import topi.testing
 from tvm import relay
 from tvm.relay.transform import gradient
-from tvm.relay.testing import ctx_list
+from tvm.relay.testing import ctx_list, check_grad
 from tvm.relay.testing import run_infer_type
 
 
@@ -83,6 +83,53 @@ def test_avg_pool2d_grad():
                            ceil_mode=False, count_include_pad=False)
 
 
+def verify_conv2d_grad(dshape, wshape, strides, padding, dilation, groups=1, mode='higher_order'):
+    try:
+        import torch
+        import torch.nn.functional as F
+    except ImportError:
+        print('Skip because pytorch is not installed')
+        return
+
+    dtype = 'float32'
+    data = relay.var('data', shape=dshape, dtype=dtype)
+    weight = relay.var('weight', shape=wshape, dtype=dtype)
+    conv = relay.nn.conv2d(data, weight, strides=strides, padding=padding, dilation=dilation,
+                           groups=groups)
+    fwd_func = relay.Function([data, weight], conv)
+    fwd_func = run_infer_type(fwd_func)
+    bwd_func = run_infer_type(gradient(fwd_func, mode=mode))
+
+    data_pt = torch.randn(*dshape, dtype=torch.float32, requires_grad=True)
+    weight_pt = torch.randn(*wshape, dtype=torch.float32, requires_grad=True)
+    out_pt = F.conv2d(data_pt, weight_pt, stride=strides, padding=padding, dilation=dilation,
+                      groups=groups)
+    grad_output_pt = torch.ones(out_pt.shape)
+    grad_input_pt = F.grad.conv2d_input(dshape, weight_pt, grad_output_pt, stride=strides,
+                                        padding=padding, dilation=dilation, groups=groups) \
+                          .detach().numpy()
+    grad_weight_pt = F.grad.conv2d_weight(data_pt, wshape, grad_output_pt, stride=strides,
+                                          padding=padding, dilation=dilation, groups=groups) \
+                           .detach().numpy()
+
+
+    for target, ctx in ctx_list():
+        data = tvm.nd.array(data_pt.detach().numpy(), ctx)
+        weight = tvm.nd.array(weight_pt.detach().numpy(), ctx)
+        intrp = relay.create_executor(ctx=ctx, target=target)
+        op_res, (grad_input, grad_weight) = intrp.evaluate(bwd_func)(data, weight)
+        np.testing.assert_allclose(grad_input.asnumpy(), grad_input_pt, rtol=1e-4, atol=1e-4)
+        np.testing.assert_allclose(grad_weight.asnumpy(), grad_weight_pt, rtol=1e-4, atol=1e-4)
+
+
+def test_conv2d_grad():
+    verify_conv2d_grad((1, 4, 16, 16), (16, 4, 3, 3), [1, 1], [1, 1], [1, 1])
+    verify_conv2d_grad((1, 4, 16, 16), (16, 4, 1, 1), [1, 1], [0, 0], [1, 1])
+    verify_conv2d_grad((1, 4, 16, 16), (16, 4, 1, 1), [2, 2], [0, 0], [1, 1])
+    verify_conv2d_grad((1, 4, 16, 16), (16, 4, 3, 3), [1, 1], [1, 1], [1, 1], mode='first_order')
+
+
 if __name__ == "__main__":
     test_max_pool2d_grad()
     test_avg_pool2d_grad()
+    test_conv2d_grad()
-- 
2.7.4
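The output_padding and weight-gradient extents that conv2d_grad infers in the patch can be checked with plain arithmetic. The sketch below is not part of the patch; the helper name and example values are illustrative only. It mirrors the patch's formulas for one spatial dimension, assuming symmetric padding and NCHW layout (both examples use dilation 1, where the patch's output_padding formula is exact):

def backward_shapes(in_size, filter_size, stride, pad, dilation):
    # Forward conv2d output extent for one spatial dimension.
    dilated_filter = (filter_size - 1) * dilation + 1
    grad_size = (in_size + 2 * pad - dilated_filter) // stride + 1
    # output_padding handed to conv2d_transpose for the data gradient:
    # how far the transposed conv falls short of the input extent.
    out_size = (grad_size - 1) * stride - 2 * pad + filter_size
    output_padding = in_size - out_size
    # Extent of the weight gradient before slicing, as computed in the patch;
    # when it exceeds filter_size, the strided_slice trims the surplus.
    padded_weight_grad = (in_size - (grad_size - 1) * stride - 1 + 2 * pad) // dilation + 1
    return grad_size, output_padding, padded_weight_grad

# 16-wide input, 3-wide kernel, stride 1, padding 1: shapes already match.
print(backward_shapes(16, 3, 1, 1, 1))   # (16, 0, 3)
# 16-wide input, 1-wide kernel, stride 2, no padding: output_padding of 1 and
# a 2-wide weight gradient that strided_slice cuts back to the 1-wide filter.
print(backward_shapes(16, 1, 2, 0, 1))   # (8, 1, 2)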