@register_gradient("nn.bias_add")
def bias_add_grad(orig, grad):
    """Returns gradient of bias_add.

    The gradient w.r.t. ``data`` is the incoming gradient collapsed to
    ``data``'s shape. The gradient w.r.t. ``bias`` is the incoming gradient
    summed over every axis except ``orig.attrs.axis`` (the axis the bias was
    broadcast along), which is what ``exclude=True`` expresses — this works
    for any bias axis, not just the trailing one.
    """
    data = orig.args[0]
    return [collapse_sum_like(grad, data),
            _sum(grad, orig.attrs.axis, keepdims=False, exclude=True)]
@register_gradient("nn.dense")
check_grad(fwd_func, scale=1)
def verify_bias_add(d_shape, b_shape, axis=1):
    """Numerically check the gradient of ``nn.bias_add``.

    Builds a forward function ``bias_add(data, bias, axis=axis)`` over fresh
    relay variables of the given shapes and delegates to ``check_grad``,
    which compares analytic against numeric gradients.

    Parameters
    ----------
    d_shape : tuple of int
        Shape of the ``data`` input.
    b_shape : tuple of int
        Shape of the 1-D ``bias`` input; its length must match
        ``d_shape[axis]``.
    axis : int, optional
        Axis of ``data`` the bias is added along (default 1, the
        conventional channel axis).
    """
    data = relay.var("data", relay.TensorType(d_shape, "float32"))
    bias = relay.var("bias", relay.TensorType(b_shape, "float32"))
    fwd_func = relay.Function([data, bias], relay.nn.bias_add(data, bias, axis=axis))
    check_grad(fwd_func)
def test_bias_add_grad():
    """Exercise bias_add gradients over several shapes and bias axes."""
    verify_bias_add((1, 16), (16,))        # 2-D data, default axis=1
    verify_bias_add((1, 8, 2, 2), (8,))    # NCHW-style 4-D data, channel axis 1
    verify_bias_add((1, 2, 2, 8), (8,), 3) # NHWC-style 4-D data, channel axis 3
    verify_bias_add((4, 8), (8,))          # batch > 1
+
+
# Run this file's tests directly when executed as a script.
if __name__ == "__main__":
    pytest.main([__file__])