[Relay][Training] Add and fix gradients (#4126)
authorAltan Haan <altanh@cs.washington.edu>
Wed, 16 Oct 2019 14:32:29 +0000 (07:32 -0700)
committerWuwei Lin <wuwei@apache.org>
Wed, 16 Oct 2019 14:32:29 +0000 (10:32 -0400)
* add and fix gradients

* fix linter issues

python/tvm/relay/op/_tensor_grad.py
tests/python/relay/test_op_grad_level2.py
tests/python/relay/test_op_grad_level4.py

index 3a82e46..1c94162 100644 (file)
@@ -48,6 +48,9 @@ from .transform import (
     tile,
     transpose,
     where,
+    repeat,
+    expand_dims,
+    full_like
 )
 
 
@@ -198,6 +201,7 @@ def clip_grad(orig, grad):
 
 @register_gradient("nn.max_pool2d")
 def max_pool2d_grad(orig, grad):
+    """Returns the gradient of max_pool2d."""
     attrs = orig.attrs
     pool_grad = _nn.max_pool2d_grad(grad, orig.args[0], pool_size=attrs.pool_size,
                                     strides=attrs.strides, padding=attrs.padding,
@@ -207,6 +211,7 @@ def max_pool2d_grad(orig, grad):
 
 @register_gradient("nn.avg_pool2d")
 def avg_pool2d_grad(orig, grad):
+    """Returns the gradient of avg_pool2d."""
     attrs = orig.attrs
     pool_grad = _nn.avg_pool2d_grad(grad, orig.args[0], pool_size=attrs.pool_size,
                                     strides=attrs.strides, padding=attrs.padding,
@@ -215,6 +220,26 @@ def avg_pool2d_grad(orig, grad):
     return [pool_grad]
 
 
+@register_gradient("nn.global_avg_pool2d")
+def global_avg_pool2d_grad(orig, grad):
+    """Returns the gradient of global_avg_pool2d."""
+    data = orig.args[0]
+    shape = data.checked_type.shape
+    layout = orig.attrs.layout
+
+    # we assume NCHW or NHWC layout for now, but easy to add more
+    assert layout in ["NCHW", "NHWC"]
+    if layout == "NCHW":
+        pool_size = shape[2], shape[3]
+    elif layout == "NHWC":
+        pool_size = shape[1], shape[2]
+
+    pool_grad = _nn.avg_pool2d_grad(grad, data, pool_size=pool_size,
+                                    strides=(1, 1), padding=(0, 0),
+                                    layout=layout)
+    return [pool_grad]
+
+
 # not implemented, this is only for testing.
 @register_gradient("concatenate")
 def concatenate_grad(orig, grad):
@@ -287,16 +312,53 @@ def conv2d_grad(orig, grad):
     return [backward_data, backward_weight]
 
 
+def _get_reduce_axis(call):
+    """Helper function that returns the reduce axis of the call as plain python ints."""
+    x, axis = call.args[0], call.attrs.axis
+    shape = x.checked_type.concrete_shape
+
+    # should never exclude when axis is None
+    assert not (axis is None and call.attrs.exclude)
+
+    if axis is None:
+        return None
+
+    # convert to nonnegative integers and sort
+    axis = sorted([ax if ax >= 0 else len(shape) + ax for ax in map(int, axis)])
+    if call.attrs.exclude:
+        axis = [ax for ax in range(len(shape)) if ax not in axis]
+    return axis
+
+
+def _unreduce_expand(x, axis):
+    """Helper function that returns x expanded on the reduced dimensions in axis."""
+    # assume axis is sorted nonnegative ints
+    for ax in axis:
+        x = expand_dims(x, ax)
+    return x
+
+
 @register_gradient("max")
 def max_grad(orig, grad):
     """Returns the gradient of max"""
-    # Only support axis=0, since broadcasting orig to x behaves incorrectly
-    x, axis = orig.args[0], orig.attrs.axis
-    assert(axis is not None and len(axis) == 1 and int(axis[0]) == 0)
-    orig = broadcast_to_like(orig, x)
-    grad = broadcast_to_like(grad, x)
-    indicators = cast_like(equal(orig, x), grad)
-    return [indicators * grad]
+    x, axis = orig.args[0], _get_reduce_axis(orig)
+    shape = x.checked_type.concrete_shape
+
+    repeated = orig
+    if axis is None:
+        repeated = full_like(x, repeated)
+    else:
+        # expand dims (if necessary) and repeat along each axis
+        if not orig.attrs.keepdims:
+            repeated = _unreduce_expand(repeated, axis)
+            grad = _unreduce_expand(grad, axis)
+        for ax in axis:
+            repeated = repeat(repeated, shape[ax], ax)
+
+    indicators = cast_like(equal(repeated, x), grad)
+    num_selected = _sum(indicators, axis, keepdims=True)
+    # spread error across all max weights
+    return [indicators * grad / num_selected]
 
 
 @register_gradient("nn.softmax")
@@ -372,7 +434,11 @@ def negative_grad(orig, grad):
 @register_gradient("sum")
 def sum_grad(orig, grad):
     """Returns grad broadcasted to data dims"""
-    data = orig.args[0]
+    data, axis = orig.args[0], _get_reduce_axis(orig)
+    if not orig.attrs.keepdims:
+        if axis is None:
+            axis = list(range(len(data.checked_type.concrete_shape)))
+        grad = _unreduce_expand(grad, axis)
     return [broadcast_to_like(grad, data)]
 
 
index 8e80925..57b1e2c 100644 (file)
@@ -48,8 +48,7 @@ def verify_max_pool2d_grad(x_shape, pool_size, strides, padding, ceil_mode):
 
 
 def test_max_pool2d_grad():
-    verify_max_pool2d_grad((1, 4, 16, 16), pool_size=(2, 2), strides=(2, 2), padding=(0, 0),
-                           ceil_mode=False)
+    verify_max_pool2d_grad((1, 4, 16, 16), pool_size=(2, 2), strides=(2, 2), padding=(0, 0), ceil_mode=False)
     verify_max_pool2d_grad((1, 4, 16, 16), pool_size=(1, 1), strides=(1, 1), padding=(1, 1), ceil_mode=False)
 
 
@@ -75,7 +74,6 @@ def verify_avg_pool2d_grad(x_shape, pool_size, strides, padding, ceil_mode, coun
         op_res, (op_grad, ) = intrp.evaluate(bwd_func)(data)
         np.testing.assert_allclose(op_grad.asnumpy(), ref_grad, rtol=0.01)
 
-
 def test_avg_pool2d_grad():
     verify_avg_pool2d_grad((1, 4, 16, 16), pool_size=(2, 2), strides=(2, 2), padding=(0, 0),
                            ceil_mode=False, count_include_pad=True)
@@ -83,6 +81,30 @@ def test_avg_pool2d_grad():
                            ceil_mode=False, count_include_pad=False)
 
 
+def verify_global_avg_pool2d_grad(x_shape):
+    x = relay.var("x", relay.TensorType(x_shape, "float32"))
+    y = tvm.relay.nn.global_avg_pool2d(x)
+
+    fwd_func = relay.Function([x], y)
+    fwd_func = run_infer_type(fwd_func)
+    bwd_func = run_infer_type(gradient(fwd_func))
+
+    data = np.random.rand(*x_shape).astype("float32")
+    y_shape = topi.util.get_const_tuple(fwd_func.ret_type.shape)
+    out_grad = np.ones(shape=y_shape)
+    ref_grad = topi.testing.pool_grad_nchw(data, out_grad, pool_size=(x_shape[2], x_shape[3]), 
+                                            strides=(1, 1), padding=[0, 0, 0, 0], pool_type='avg', 
+                                            ceil_mode=False)
+
+    for target, ctx in ctx_list():
+        intrp = relay.create_executor(ctx=ctx, target=target)
+        op_res, (op_grad, ) = intrp.evaluate(bwd_func)(data)
+        np.testing.assert_allclose(op_grad.asnumpy(), ref_grad, rtol=0.01)
+
+def test_global_avg_pool2d_grad():
+    verify_global_avg_pool2d_grad((1, 4, 16, 16))
+    verify_global_avg_pool2d_grad((1, 8, 8, 24))
+
 def verify_conv2d_grad(dshape, wshape, strides, padding, dilation, groups=1, mode='higher_order'):
     try:
         import torch
@@ -155,6 +177,7 @@ def test_batch_flatten_grad():
 if __name__ == "__main__":
     test_max_pool2d_grad()
     test_avg_pool2d_grad()
+    test_global_avg_pool2d_grad()
     test_conv2d_grad()
     test_dense_grad()
     test_batch_flatten_grad()
index f8d6c3a..f690a18 100644 (file)
@@ -29,18 +29,21 @@ def test_sum_grad():
     verify_sum_grad((4, 2))
     verify_sum_grad((4, 2), axis=-1, keepdims=True)
     verify_sum_grad((4, 2, 1), axis=(1, 2), exclude=True)
+    verify_sum_grad((4, 2, 1), axis=1)
 
 
-def test_max_grad():
-    s = (10, 10)
-    t = relay.TensorType(s)
-    x = relay.var("x", t)
-    axis = 0
-    z = relay.max(x, axis)
-
-    fwd_func = relay.Function([x], z)
+def verify_max_grad(d_shape, axis=None, keepdims=False, exclude=False):
+    data = relay.var("data", relay.TensorType(d_shape, "float32"))
+    fwd_func = relay.Function([data], relay.max(data, axis=axis, keepdims=keepdims, exclude=exclude))
     check_grad(fwd_func, scale=1e-3)
 
 
+def test_max_grad():
+    verify_max_grad((10, 10), axis=None)
+    verify_max_grad((10, 10), axis=-1)
+    verify_max_grad((6, 3, 2), axis=(1, 2), keepdims=True)
+    verify_max_grad((5, 4, 3), axis=(0, 2), exclude=True)
+
+
 if __name__ == "__main__":
     pytest.main()