tile,
transpose,
where,
+ repeat,
+ expand_dims,
+ full_like
)
@register_gradient("nn.max_pool2d")
def max_pool2d_grad(orig, grad):
+ """Returns the gradient of max_pool2d."""
attrs = orig.attrs
    pool_grad = _nn.max_pool2d_grad(grad, orig.args[0], pool_size=attrs.pool_size,
                                    strides=attrs.strides, padding=attrs.padding,
                                    layout=attrs.layout, ceil_mode=attrs.ceil_mode)
    return [pool_grad]


@register_gradient("nn.avg_pool2d")
def avg_pool2d_grad(orig, grad):
+ """Returns the gradient of avg_pool2d."""
attrs = orig.attrs
    pool_grad = _nn.avg_pool2d_grad(grad, orig.args[0], pool_size=attrs.pool_size,
                                    strides=attrs.strides, padding=attrs.padding,
                                    layout=attrs.layout, ceil_mode=attrs.ceil_mode,
                                    count_include_pad=attrs.count_include_pad)
return [pool_grad]
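

# A minimal NumPy sketch (illustration only, not part of the patch) of the rule
# avg_pool2d_grad implements: every input cell covered by a pooling window
# receives an equal 1/(kh*kw) share of that window's output gradient.
import numpy as np

def avg_pool2d_grad_ref(out_grad, in_shape, pool=(2, 2), stride=(2, 2)):
    kh, kw = pool
    sh, sw = stride
    in_grad = np.zeros(in_shape, dtype=out_grad.dtype)
    _, _, oh, ow = out_grad.shape
    for i in range(oh):
        for j in range(ow):
            # each (1, 1) output gradient is spread over its kh x kw window
            in_grad[:, :, i*sh:i*sh+kh, j*sw:j*sw+kw] += \
                out_grad[:, :, i:i+1, j:j+1] / (kh * kw)
    return in_grad
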
+@register_gradient("nn.global_avg_pool2d")
+def global_avg_pool2d_grad(orig, grad):
+ """Returns the gradient of global_avg_pool2d."""
+ data = orig.args[0]
+ shape = data.checked_type.shape
+ layout = orig.attrs.layout
+
+ # we assume NCHW or NHWC layout for now, but easy to add more
+ assert layout in ["NCHW", "NHWC"]
+ if layout == "NCHW":
+ pool_size = shape[2], shape[3]
+ elif layout == "NHWC":
+ pool_size = shape[1], shape[2]
+
+ pool_grad = _nn.avg_pool2d_grad(grad, data, pool_size=pool_size,
+ strides=(1, 1), padding=(0, 0),
+ layout=layout)
+ return [pool_grad]
+
+
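# Sketch (illustration only, not part of the patch) of why pool_size = (H, W)
# with stride (1, 1) and zero padding recovers the global case: the single
# output gradient is spread uniformly over all H*W input positions (NCHW assumed).
import numpy as np

def global_avg_pool2d_grad_ref(out_grad, in_shape):
    n, c, h, w = in_shape
    # out_grad has shape (n, c, 1, 1); every input cell gets grad / (h * w)
    return np.broadcast_to(out_grad / (h * w), in_shape).copy()

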
# not implemented, this is only for testing.
@register_gradient("concatenate")
def concatenate_grad(orig, grad):
    assert len(orig.args) == 1
    t = orig.args[0].checked_type
    # placeholder: a tuple of zero tensors with the input's types, not the real
    # gradient (assumes `zeros` and `Tuple` are in scope from relay)
    return [Tuple([zeros(f.shape, f.dtype) for f in t.fields])]
+def _get_reduce_axis(call):
+ """Helper function that returns the reduce axis of the call as plain python ints."""
+ x, axis = call.args[0], call.attrs.axis
+ shape = x.checked_type.concrete_shape
+
+ # should never exclude when axis is None
+ assert not (axis is None and call.attrs.exclude)
+
+ if axis is None:
+ return None
+
+ # convert to nonnegative integers and sort
+ axis = sorted([ax if ax >= 0 else len(shape) + ax for ax in map(int, axis)])
+ if call.attrs.exclude:
+ axis = [ax for ax in range(len(shape)) if ax not in axis]
+ return axis
+
+
+def _unreduce_expand(x, axis):
+ """Helper function that returns x expanded on the reduced dimensions in axis."""
+ # assume axis is sorted nonnegative ints
+ for ax in axis:
+ x = expand_dims(x, ax)
+ return x
+
+
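# Quick NumPy illustration (not part of the patch) of the two helpers above:
# reducing a (4, 2, 3) array over axis=(-1, 0) normalizes to axes [0, 2], and
# _unreduce_expand re-inserts those axes so the result broadcasts against the input.
import numpy as np

x = np.ones((4, 2, 3))
axis = sorted(ax if ax >= 0 else x.ndim + ax for ax in (-1, 0))  # -> [0, 2]
r = x.sum(axis=tuple(axis))                # shape (2,), like keepdims=False
for ax in axis:                            # mirrors _unreduce_expand
    r = np.expand_dims(r, ax)
assert r.shape == (1, 2, 1)                # now broadcastable against x

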
@register_gradient("max")
def max_grad(orig, grad):
"""Returns the gradient of max"""
- # Only support axis=0, since broadcasting orig to x behaves incorrectly
- x, axis = orig.args[0], orig.attrs.axis
- assert(axis is not None and len(axis) == 1 and int(axis[0]) == 0)
- orig = broadcast_to_like(orig, x)
- grad = broadcast_to_like(grad, x)
- indicators = cast_like(equal(orig, x), grad)
- return [indicators * grad]
+ x, axis = orig.args[0], _get_reduce_axis(orig)
+ shape = x.checked_type.concrete_shape
+
+ repeated = orig
+ if axis is None:
+ repeated = full_like(x, repeated)
+ else:
+ # expand dims (if necessary) and repeat along each axis
+ if not orig.attrs.keepdims:
+ repeated = _unreduce_expand(repeated, axis)
+ grad = _unreduce_expand(grad, axis)
+ for ax in axis:
+ repeated = repeat(repeated, shape[ax], ax)
+
+ indicators = cast_like(equal(repeated, x), grad)
+ num_selected = _sum(indicators, axis, keepdims=True)
+ # spread error across all max weights
+ return [indicators * grad / num_selected]
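

# NumPy sketch (illustration only, not part of the patch) of the tie-splitting
# rule above: positions tied for the max each get an equal share of the gradient.
import numpy as np

x = np.array([[1.0, 3.0, 3.0],
              [2.0, 2.0, 0.0]])
grad = np.ones((2,))                              # d(loss)/d(max(x, axis=1))
m = x.max(axis=1, keepdims=True)                  # plays the role of `repeated`
indicators = (x == m).astype(x.dtype)
num_selected = indicators.sum(axis=1, keepdims=True)
dx = indicators * grad[:, None] / num_selected
# row 0: the tied 3.0 entries get 0.5 each; row 1: the tied 2.0 entries get 0.5 each

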
@register_gradient("nn.softmax")
@register_gradient("sum")
def sum_grad(orig, grad):
"""Returns grad broadcasted to data dims"""
- data = orig.args[0]
+ data, axis = orig.args[0], _get_reduce_axis(orig)
+ if not orig.attrs.keepdims:
+ if axis is None:
+ axis = list(range(len(data.checked_type.concrete_shape)))
+ grad = _unreduce_expand(grad, axis)
return [broadcast_to_like(grad, data)]
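

# NumPy sketch (illustration only, not part of the patch): the gradient of a
# reduced sum is the output gradient broadcast back to the input shape, after
# re-inserting any axes removed by keepdims=False.
import numpy as np

data = np.random.rand(4, 2, 3)
grad = np.random.rand(4, 3)                 # grad of data.sum(axis=1)
grad = np.expand_dims(grad, 1)              # undo keepdims=False -> (4, 1, 3)
dx = np.broadcast_to(grad, data.shape)      # every summand gets the same grad
assert dx.shape == data.shape

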
def test_max_pool2d_grad():
    verify_max_pool2d_grad((1, 4, 16, 16), pool_size=(2, 2), strides=(2, 2), padding=(0, 0),
                           ceil_mode=False)
    verify_max_pool2d_grad((1, 4, 16, 16), pool_size=(1, 1), strides=(1, 1), padding=(1, 1),
                           ceil_mode=False)
op_res, (op_grad, ) = intrp.evaluate(bwd_func)(data)
np.testing.assert_allclose(op_grad.asnumpy(), ref_grad, rtol=0.01)
-
def test_avg_pool2d_grad():
verify_avg_pool2d_grad((1, 4, 16, 16), pool_size=(2, 2), strides=(2, 2), padding=(0, 0),
ceil_mode=False, count_include_pad=True)
    verify_avg_pool2d_grad((1, 4, 16, 16), pool_size=(1, 1), strides=(1, 1), padding=(1, 1),
                           ceil_mode=False, count_include_pad=False)
+def verify_global_avg_pool2d_grad(x_shape):
+ x = relay.var("x", relay.TensorType(x_shape, "float32"))
+ y = tvm.relay.nn.global_avg_pool2d(x)
+
+ fwd_func = relay.Function([x], y)
+ fwd_func = run_infer_type(fwd_func)
+ bwd_func = run_infer_type(gradient(fwd_func))
+
+ data = np.random.rand(*x_shape).astype("float32")
+ y_shape = topi.util.get_const_tuple(fwd_func.ret_type.shape)
+ out_grad = np.ones(shape=y_shape)
+ ref_grad = topi.testing.pool_grad_nchw(data, out_grad, pool_size=(x_shape[2], x_shape[3]),
+ strides=(1, 1), padding=[0, 0, 0, 0], pool_type='avg',
+ ceil_mode=False)
+
+ for target, ctx in ctx_list():
+ intrp = relay.create_executor(ctx=ctx, target=target)
+ op_res, (op_grad, ) = intrp.evaluate(bwd_func)(data)
+ np.testing.assert_allclose(op_grad.asnumpy(), ref_grad, rtol=0.01)
+
+def test_global_avg_pool2d_grad():
+ verify_global_avg_pool2d_grad((1, 4, 16, 16))
+ verify_global_avg_pool2d_grad((1, 8, 8, 24))
+
def verify_conv2d_grad(dshape, wshape, strides, padding, dilation, groups=1, mode='higher_order'):
try:
import torch
if __name__ == "__main__":
test_max_pool2d_grad()
test_avg_pool2d_grad()
+ test_global_avg_pool2d_grad()
test_conv2d_grad()
test_dense_grad()
test_batch_flatten_grad()
verify_sum_grad((4, 2))
verify_sum_grad((4, 2), axis=-1, keepdims=True)
verify_sum_grad((4, 2, 1), axis=(1, 2), exclude=True)
+ verify_sum_grad((4, 2, 1), axis=1)
-def test_max_grad():
- s = (10, 10)
- t = relay.TensorType(s)
- x = relay.var("x", t)
- axis = 0
- z = relay.max(x, axis)
-
- fwd_func = relay.Function([x], z)
+def verify_max_grad(d_shape, axis=None, keepdims=False, exclude=False):
+ data = relay.var("data", relay.TensorType(d_shape, "float32"))
+ fwd_func = relay.Function([data], relay.max(data, axis=axis, keepdims=keepdims, exclude=exclude))
check_grad(fwd_func, scale=1e-3)
+def test_max_grad():
+ verify_max_grad((10, 10), axis=None)
+ verify_max_grad((10, 10), axis=-1)
+ verify_max_grad((6, 3, 2), axis=(1, 2), keepdims=True)
+ verify_max_grad((5, 4, 3), axis=(0, 2), exclude=True)
+
+
if __name__ == "__main__":
pytest.main()