From dee11b4198574455bc5ec2b7dd959cb182cd9e86 Mon Sep 17 00:00:00 2001
From: =?utf8?q?=E9=9B=BE=E9=9B=A8=E9=AD=94=E7=90=86=E6=B2=99?=
Date: Thu, 5 Sep 2019 11:13:07 -0700
Subject: [PATCH] [Relay][Training] Small refactoring (#3893)

* init

* fix
---
 python/tvm/relay/op/_tensor_grad.py | 11 +++++++++--
 1 file changed, 9 insertions(+), 2 deletions(-)

diff --git a/python/tvm/relay/op/_tensor_grad.py b/python/tvm/relay/op/_tensor_grad.py
index b55f517..89f4ca8 100644
--- a/python/tvm/relay/op/_tensor_grad.py
+++ b/python/tvm/relay/op/_tensor_grad.py
@@ -44,6 +44,7 @@ def log_grad(orig, grad):
     x = orig.args[0]
     return [grad * ones_like(x) / x]
 
+
 @register_gradient("cos")
 def cos_grad(orig, grad):
     """Returns [grad * (-sin(x))]"""
@@ -51,12 +52,14 @@ def cos_grad(orig, grad):
     ones = ones_like(x)
     return [grad * (-ones * sin(x))]
 
+
 @register_gradient("sin")
 def sin_grad(orig, grad):
     """Returns [grad * cos(x)]"""
     x = orig.args[0]
     return [grad * cos(x)]
 
+
 @register_gradient("exp")
 def exp_grad(orig, grad):
     """Returns [grad * exp(x)]"""
@@ -173,6 +176,7 @@ def clip_grad(orig, grad):
     ones = ones_like(x)
     return [where(less(x, a_mins), zeros, where(less(a_maxs, x), zeros, ones * grad))]
 
+
 @register_gradient("nn.max_pool2d")
 def max_pool2d_grad(orig, grad):
     attrs = orig.attrs
@@ -181,6 +185,7 @@ def max_pool2d_grad(orig, grad):
                                     layout=attrs.layout, ceil_mode=attrs.ceil_mode)
     return [pool_grad]
 
+
 @register_gradient("nn.avg_pool2d")
 def avg_pool2d_grad(orig, grad):
     attrs = orig.attrs
@@ -190,6 +195,7 @@ def avg_pool2d_grad(orig, grad):
                                     count_include_pad=attrs.count_include_pad)
     return [pool_grad]
 
+
 # not implemented, this is only for testing.
 @register_gradient("concatenate")
 def concatenate_grad(orig, grad):
@@ -201,6 +207,7 @@ def concatenate_grad(orig, grad):
     # In the real implementation, concatenate_grad probably need to be implemented by an operator.
     return [Tuple([zeros_like(x), zeros_like(y)])]
 
+
 @register_gradient("nn.conv2d")
 def conv2d_grad(orig, grad):
     """Gradient of conv2d"""
@@ -268,8 +275,8 @@ def softmax_grad(orig, grad):
 
 
 @register_gradient("nn.bias_add")
-def bias_grad(orig, grad):
-    """Returns grad"""
+def bias_add_grad(orig, grad):
+    """Returns gradient of bias_add"""
     data, bias = orig.args
     return [collapse_sum_like(grad, data),
             collapse_sum_like(grad, bias)]
-- 
2.7.4
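
For context, every gradient in _tensor_grad.py follows the same shape the patch touches: a function decorated with @register_gradient("<op name>") receives the original call expression (orig) and the adjoint of its output (grad), and returns a list with one adjoint per input of orig. Below is a minimal illustrative sketch of that pattern, not part of this patch; it assumes it would live inside _tensor_grad.py so the module's existing helpers (register_gradient, ones_like) are in scope, and "negative" is used purely as an example op that may already have a gradient registered upstream.

    # Illustrative sketch only (not part of this patch). Mirrors the pattern used
    # throughout python/tvm/relay/op/_tensor_grad.py; assumes the module's own
    # imports (register_gradient, ones_like) are in scope.
    @register_gradient("negative")
    def negative_grad(orig, grad):
        """Returns [-grad], since d(-x)/dx = -1."""
        x = orig.args[0]          # the single input of the original call
        ones = ones_like(x)       # tensor of ones with x's shape and dtype
        return [grad * (-ones)]   # one adjoint per input, in input order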