[Relay] Register abs gradient: grad * (select(x < 0, -1, 1)) (#3447)
author     Amy Wang <kai.ting.wang@huawei.com>
Fri, 28 Jun 2019 03:25:38 +0000 (23:25 -0400)
committer  Wuwei Lin <vincentl13x@gmail.com>
Fri, 28 Jun 2019 03:25:38 +0000 (11:25 +0800)
python/tvm/relay/op/_tensor_grad.py
tests/python/relay/test_op_grad_level1.py

diff --git a/python/tvm/relay/op/_tensor_grad.py b/python/tvm/relay/op/_tensor_grad.py
index 158b2dc..4d0e6f8 100644
@@ -110,3 +110,11 @@ def collapse_sum_like_grad(orig, grad):
     """Returns [broadcast_to_like(grad, x), 0]"""
     x, y = orig.args
     return [broadcast_to_like(grad, x), zeros_like(y)]
+
+@register_gradient("abs")
+def abs_grad(orig, grad):
+    """Returns grad * (select(x < 0, -1, 1))."""
+    x = orig.args[0]
+    zeros = zeros_like(x)
+    ones = ones_like(x)
+    return [where(less(x, zeros), -ones * grad, ones * grad)]
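
The hunk above is the whole backward rule: the incoming gradient is passed through with its sign flipped wherever the input is negative, and the x == 0 case is assigned +1. A minimal standalone NumPy sketch of the same rule, for illustration only and not part of this patch:

    import numpy as np

    def abs_grad_ref(x, grad):
        # d|x|/dx is -1 where x < 0 and +1 elsewhere, mirroring
        # where(less(x, zeros), -ones * grad, ones * grad) above.
        ones = np.ones_like(x)
        return np.where(x < 0, -ones, ones) * grad

    x = np.array([-2.0, -0.5, 0.0, 1.5])
    print(abs_grad_ref(x, np.ones_like(x)))   # -> [-1. -1.  1.  1.]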
diff --git a/tests/python/relay/test_op_grad_level1.py b/tests/python/relay/test_op_grad_level1.py
index 12b3ae3..0722712 100644
@@ -53,6 +53,7 @@ def test_unary_op():
                         (tvm.relay.sigmoid, lambda x: sigmoid(x) * (1 - sigmoid(x))),
                         (tvm.relay.tanh, lambda x: 1 - np.tanh(x) * np.tanh(x)),
                         (tvm.relay.sqrt, lambda x: 0.5 * np.power(x, -0.5)),
+                        (tvm.relay.abs, lambda x: np.where(x < 0, -np.ones_like(x), np.ones_like(x))),
                         (relay.nn.relu, lambda x: np.where(x < 0, np.zeros_like(x), np.ones_like(x)))]:
         check_single_op(opfunc, ref)
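
The new reference lambda treats the derivative of abs as -1 for negative inputs and +1 otherwise; check_single_op uses it as the expected gradient for tvm.relay.abs. As a sanity check (illustrative only, not part of the test suite), the same reference agrees with a central finite difference of |x| away from the non-differentiable point x == 0:

    import numpy as np

    rng = np.random.RandomState(0)
    x = rng.uniform(-3.0, 3.0, size=100)
    x = x[np.abs(x) > 1e-3]        # stay away from x == 0
    eps = 1e-5
    numeric = (np.abs(x + eps) - np.abs(x - eps)) / (2.0 * eps)
    analytic = np.where(x < 0, -np.ones_like(x), np.ones_like(x))
    assert np.allclose(numeric, analytic, atol=1e-4)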