add a few gradients (#5899)
authorThomas Viehmann <tv.code@beamnet.de>
Tue, 23 Jun 2020 21:01:46 +0000 (23:01 +0200)
committerGitHub <noreply@github.com>
Tue, 23 Jun 2020 21:01:46 +0000 (14:01 -0700)
python/tvm/relay/op/_tensor_grad.py
tests/python/relay/test_op_grad_level1.py
tests/python/relay/test_op_grad_level10.py
tests/python/relay/test_op_grad_level3.py
tests/python/relay/test_op_grad_level4.py

index 0deb87a..00ea097 100644 (file)
@@ -270,6 +270,14 @@ def abs_grad(orig, grad):
     return [where(less(x, zeros), -ones * grad, ones * grad)]
 
 
+@register_gradient("erf")
+def erf_grad(orig, grad):
+    # c_2_div_sqrt_pi = 2.0 / math.sqrt(math.pi)
+    inp, = orig.args
+    c_2_div_sqrt_pi = const(1.1283791670955126, dtype=inp.checked_type.dtype)
+    return [c_2_div_sqrt_pi * exp(- inp * inp) * grad]
+
+
 @register_gradient("clip")
 def clip_grad(orig, grad):
     """Returns grad * (select(x < min || max < x , 0, 1))."""
@@ -479,6 +487,19 @@ def dense_grad(orig, grad):
             collapse_sum_like(_nn.dense(transpose(grad), transpose(data),
                                         units=data.checked_type.shape[1]), weight)]
 
+
+@register_gradient("nn.batch_matmul")
+def batch_matmul_grad(orig, grad):
+    """gradient for nn.batch_matmul: in einsum LHS_bik,RHS_bjk->RES_bij
+       grads: GRAD_OUT_bij,RHS_bjk->GRAD_IN_LHS_bik
+              GRAD_OUT_bij,LHS_bik->GRAD_IN_RHS_bjk
+    """
+    lhs, rhs = orig.args
+    return [collapse_sum_like(_nn.batch_matmul(grad, transpose(rhs, [0, 2, 1])), lhs),
+            collapse_sum_like(_nn.batch_matmul(transpose(grad, [0, 2, 1]),
+                                               transpose(lhs, [0, 2, 1])), rhs)]
+
+
 @register_gradient("reshape")
 def reshape_grad(orig, grad):
     """Gradient of reshape"""
@@ -529,6 +550,42 @@ def sum_grad(orig, grad):
     return [broadcast_to_like(grad, data)]
 
 
+@register_gradient("mean")
+def mean_grad(orig, grad):
+    """Returns grad broadcasted to data dims"""
+    data, axis = orig.args[0], _get_reduce_axis(orig)
+    shape = data.checked_type.concrete_shape
+    if axis is None:
+        axis = list(range(len(data.checked_type.concrete_shape)))
+    if not orig.attrs.keepdims:
+        grad = _unreduce_expand(grad, axis)
+    mult = 1.0
+    for a in axis:
+        mult /= shape[a]
+    return [broadcast_to_like(grad * const(mult, dtype=data.checked_type.dtype), data)]
+
+
+@register_gradient("variance")
+def variance_grad(orig, grad):
+    """Note that we take mean as an argument in the variance node"""
+    data, data_mean, axis = orig.args[0], orig.args[1], _get_reduce_axis(orig)
+    shape = data.checked_type.concrete_shape
+    if axis is None:
+        axis = list(range(len(data.checked_type.concrete_shape)))
+    if not orig.attrs.keepdims:
+        grad = _unreduce_expand(grad, axis)
+    mult = 2.0
+    for a in axis:
+        mult /= shape[a]
+    return [(grad * const(mult, dtype=data.checked_type.dtype)) * data,
+            const(-2, dtype=data.checked_type.dtype) * grad * data_mean]
+
+
+@register_gradient("copy")
+def copy_grad(orig, grad):
+    return [grad]
+
+
 @register_gradient("nn.cross_entropy")
 def cross_entropy_grad(orig, grad):
     x, y = orig.args
index 9faf6d9..85506e0 100644 (file)
@@ -62,6 +62,7 @@ def test_unary_op():
                         (tvm.relay.sqrt, lambda x: 0.5 * np.power(x, -0.5)),
                         (tvm.relay.abs, lambda x: np.where(x < 0, -np.ones_like(x), np.ones_like(x))),
                         (relay.nn.relu, lambda x: np.where(x < 0, np.zeros_like(x), np.ones_like(x))),
+                        (tvm.relay.erf, lambda x: 2.0 / (np.pi**(0.5)) * np.exp(-x * x)),
                         (tvm.relay.cos, lambda x: -1.0 * np.sin(x)),
                         (tvm.relay.sin, lambda x: np.cos(x)),
                         (tvm.relay.tan, lambda x: 1.0 / (np.cos(x) ** 2)),
index acf3b75..6e64999 100644 (file)
@@ -44,5 +44,11 @@ def test_checkpoint():
     check_grad(relay.Function(inputs, out_single))
 
 
+def test_batch_matmul_grad():
+    x = relay.var("x", shape=(2, 3, 5), dtype="float64")
+    y = relay.var("y", shape=(2, 4, 5), dtype="float64")
+    check_grad(relay.Function([x, y], relay.op.nn.batch_matmul(x, y)))
+
+
 if __name__ == "__main__":
     pytest.main([__file__])
index d13687f..b1d0e25 100644 (file)
@@ -64,5 +64,12 @@ def test_cast_grad():
     fwd_func = relay.Function([data], relay.cast(data, "float64"))
     check_grad(fwd_func)
 
+
+def test_copy_grad():
+    data = relay.var("data", relay.TensorType((10, 4), "float64"))
+    fwd_func = relay.Function([data], relay.copy(data))
+    check_grad(fwd_func)
+
+
 if __name__ == "__main__":
     pytest.main()
index f690a18..956c6af 100644 (file)
@@ -19,17 +19,18 @@ from tvm import relay
 from tvm.relay.testing import check_grad
 
 
-def verify_sum_grad(d_shape, axis=None, keepdims=False, exclude=False):
+def verify_reduction_grad(red_fn, d_shape, axis=None, keepdims=False, exclude=False):
     data = relay.var("data", relay.TensorType(d_shape, "float32"))
-    fwd_func = relay.Function([data], relay.sum(data, axis=axis, keepdims=keepdims, exclude=exclude))
+    fwd_func = relay.Function([data], red_fn(data, axis=axis, keepdims=keepdims, exclude=exclude))
     check_grad(fwd_func)
 
 
-def test_sum_grad():
-    verify_sum_grad((4, 2))
-    verify_sum_grad((4, 2), axis=-1, keepdims=True)
-    verify_sum_grad((4, 2, 1), axis=(1, 2), exclude=True)
-    verify_sum_grad((4, 2, 1), axis=1)
+def test_reduction_grad():
+    for op in (relay.sum, relay.variance, relay.mean):
+        verify_reduction_grad(op, (4, 2))
+        verify_reduction_grad(op, (4, 2), axis=-1, keepdims=True)
+        verify_reduction_grad(op, (4, 2, 1), axis=(1, 2), exclude=True)
+        verify_reduction_grad(op, (4, 2, 1), axis=1)
 
 
 def verify_max_grad(d_shape, axis=None, keepdims=False, exclude=False):