[Relay][Training] Add gradient for max. (#3915)
author雾雨魔理沙 <lolisa@marisa.moe>
Mon, 9 Sep 2019 19:48:04 +0000 (12:48 -0700)
committerThierry Moreau <moreau@uw.edu>
Mon, 9 Sep 2019 19:48:04 +0000 (12:48 -0700)
* save

* save

python/tvm/relay/op/_tensor_grad.py
tests/python/relay/test_op_grad_level4.py

index 0cd2efb..d3d707b 100644 (file)
@@ -25,7 +25,7 @@ from ..expr import Tuple, TupleGetItem, const
 from . import nn as _nn
 from .op import register_gradient
 from .reduce import sum as _sum
-from .tensor import cos, exp, less, negative, ones_like, power, sin, zeros_like
+from .tensor import cos, exp, less, negative, ones_like, power, sin, zeros_like, equal
 from .transform import (
     broadcast_to_like,
     collapse_sum_like,
@@ -269,6 +269,18 @@ def conv2d_grad(orig, grad):
     return [backward_data, backward_weight]
 
 
+@register_gradient("max")
+def max_grad(orig, grad):
+    """Returns the gradient of max"""
+    # Only support axis=0, since broadcasting orig to x behaves incorrectly
+    x, axis = orig.args[0], orig.attrs.axis
+    assert(axis is not None and len(axis) == 1 and int(axis[0]) == 0)
+    orig = broadcast_to_like(orig, x)
+    grad = broadcast_to_like(grad, x)
+    indicators = cast_like(equal(orig, x), grad)
+    return [indicators * grad]
+
+
 @register_gradient("nn.softmax")
 def softmax_grad(orig, grad):
     """Gradient of softmax"""
index 5db1d93..3c799b8 100644 (file)
@@ -14,6 +14,7 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
+import pytest
 from tvm import relay
 from tvm.relay.testing import check_grad
 
@@ -30,6 +31,16 @@ def test_sum_grad():
     verify_sum_grad((4, 2, 1), axis=(1, 2), exclude=True)
 
 
+def test_max_grad():
+    s = (5, 10)
+    t = relay.TensorType(s)
+    x = relay.var("x", t)
+    axis = 0
+    z = relay.max(x, axis)
+
+    fwd_func = relay.Function([x], z)
+    check_grad(fwd_func, eps=1e-7, rtol=1)
+
 
 if __name__ == "__main__":
-    test_sum_grad()
+    pytest.main()