From 0f4c151f2a7ef7f026cb23d7d681fe0342fe13b6 Mon Sep 17 00:00:00 2001
From: 雾雨魔理沙
Date: Mon, 9 Sep 2019 12:48:04 -0700
Subject: [PATCH] [Relay][Training] Add gradient for max. (#3915)

* Register a gradient for the max operator (axis=0 only), masking the
  upstream gradient with an indicator of where the input attains its maximum.

* Add test_max_grad, verifying the new gradient numerically with check_grad.
---
 python/tvm/relay/op/_tensor_grad.py       | 14 +++++++++++++-
 tests/python/relay/test_op_grad_level4.py | 13 ++++++++++++-
 2 files changed, 25 insertions(+), 2 deletions(-)

diff --git a/python/tvm/relay/op/_tensor_grad.py b/python/tvm/relay/op/_tensor_grad.py
index 0cd2efb..d3d707b 100644
--- a/python/tvm/relay/op/_tensor_grad.py
+++ b/python/tvm/relay/op/_tensor_grad.py
@@ -25,7 +25,7 @@ from ..expr import Tuple, TupleGetItem, const
 from . import nn as _nn
 from .op import register_gradient
 from .reduce import sum as _sum
-from .tensor import cos, exp, less, negative, ones_like, power, sin, zeros_like
+from .tensor import cos, equal, exp, less, negative, ones_like, power, sin, zeros_like
 from .transform import (
     broadcast_to_like,
     collapse_sum_like,
@@ -269,6 +269,18 @@ def conv2d_grad(orig, grad):
     return [backward_data, backward_weight]
 
 
+@register_gradient("max")
+def max_grad(orig, grad):
+    """Returns the gradient of max."""
+    # Only axis=0 is supported, since broadcasting orig to x behaves incorrectly for other axes.
+    x, axis = orig.args[0], orig.attrs.axis
+    assert axis is not None and len(axis) == 1 and int(axis[0]) == 0
+    orig = broadcast_to_like(orig, x)
+    grad = broadcast_to_like(grad, x)
+    indicators = cast_like(equal(orig, x), grad)
+    return [indicators * grad]
+
+
 @register_gradient("nn.softmax")
 def softmax_grad(orig, grad):
     """Gradient of softmax"""
diff --git a/tests/python/relay/test_op_grad_level4.py b/tests/python/relay/test_op_grad_level4.py
index 5db1d93..3c799b8 100644
--- a/tests/python/relay/test_op_grad_level4.py
+++ b/tests/python/relay/test_op_grad_level4.py
@@ -14,6 +14,7 @@
 # KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations
 # under the License.
+import pytest
 from tvm import relay
 from tvm.relay.testing import check_grad
 
@@ -30,6 +31,16 @@ def test_sum_grad():
     verify_sum_grad((4, 2, 1), axis=(1, 2), exclude=True)
 
 
+def test_max_grad():
+    s = (5, 10)
+    t = relay.TensorType(s)
+    x = relay.var("x", t)
+    axis = 0
+    z = relay.max(x, axis)
+
+    fwd_func = relay.Function([x], z)
+    check_grad(fwd_func, eps=1e-7, rtol=1)
+
 if __name__ == "__main__":
-    test_sum_grad()
+    pytest.main()
-- 
2.7.4
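
A note on the technique for readers of this patch: max_grad implements the
usual indicator-mask rule for the gradient of a reduction max. Below is a
minimal NumPy sketch of the same rule for the axis=0 case the patch supports;
it is illustrative only, not part of the patch, and the name
max_grad_reference is hypothetical.

    import numpy as np

    def max_grad_reference(x, grad):
        """Sketch of the indicator-mask gradient of max over axis 0."""
        # Forward result, kept broadcastable to x; this plays the role of
        # broadcast_to_like(orig, x) in the Relay code above.
        m = x.max(axis=0, keepdims=True)
        # 1.0 where an element attains the maximum, 0.0 elsewhere,
        # mirroring cast_like(equal(orig, x), grad).
        indicators = (x == m).astype(grad.dtype)
        # grad has the reduced shape x.shape[1:]; broadcasting restores axis 0.
        return indicators * grad

    x = np.array([[1.0, 5.0],
                  [3.0, 2.0]])
    g = np.ones(2)  # upstream gradient of ones w.r.t. max(x, axis=0)
    print(max_grad_reference(x, g))  # [[0. 1.]
                                     #  [1. 0.]]

When several elements tie for the maximum, each receives the full upstream
gradient, since equal(orig, x) marks every tying position. This over-counts
relative to conventions that split the gradient among ties, but it agrees with
the true gradient whenever the maximum is unique, which is the typical case
for the random inputs check_grad generates.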