[Relay][Training] Small refactoring (#3893)
author雾雨魔理沙 <lolisa@marisa.moe>
Thu, 5 Sep 2019 18:13:07 +0000 (11:13 -0700)
committerJared Roesch <roeschinc@gmail.com>
Thu, 5 Sep 2019 18:13:07 +0000 (11:13 -0700)
* init

* fix

python/tvm/relay/op/_tensor_grad.py

index b55f517..89f4ca8 100644 (file)
@@ -44,6 +44,7 @@ def log_grad(orig, grad):
     x = orig.args[0]
     return [grad * ones_like(x) / x]
 
+
 @register_gradient("cos")
 def cos_grad(orig, grad):
     """Returns [grad * (-sin(x))]"""
@@ -51,12 +52,14 @@ def cos_grad(orig, grad):
     ones = ones_like(x)
     return [grad * (-ones * sin(x))]
 
+
 @register_gradient("sin")
 def sin_grad(orig, grad):
     """Returns [grad * cos(x)]"""
     x = orig.args[0]
     return [grad * cos(x)]
 
+
 @register_gradient("exp")
 def exp_grad(orig, grad):
     """Returns [grad * exp(x)]"""
@@ -173,6 +176,7 @@ def clip_grad(orig, grad):
     ones = ones_like(x)
     return [where(less(x, a_mins), zeros, where(less(a_maxs, x), zeros, ones * grad))]
 
+
 @register_gradient("nn.max_pool2d")
 def max_pool2d_grad(orig, grad):
     attrs = orig.attrs
@@ -181,6 +185,7 @@ def max_pool2d_grad(orig, grad):
                                     layout=attrs.layout, ceil_mode=attrs.ceil_mode)
     return [pool_grad]
 
+
 @register_gradient("nn.avg_pool2d")
 def avg_pool2d_grad(orig, grad):
     attrs = orig.attrs
@@ -190,6 +195,7 @@ def avg_pool2d_grad(orig, grad):
                                     count_include_pad=attrs.count_include_pad)
     return [pool_grad]
 
+
 # not implemented, this is only for testing.
 @register_gradient("concatenate")
 def concatenate_grad(orig, grad):
@@ -201,6 +207,7 @@ def concatenate_grad(orig, grad):
     # In the real implementation, concatenate_grad probably need to be implemented by an operator.
     return [Tuple([zeros_like(x), zeros_like(y)])]
 
+
 @register_gradient("nn.conv2d")
 def conv2d_grad(orig, grad):
     """Gradient of conv2d"""
@@ -268,8 +275,8 @@ def softmax_grad(orig, grad):
 
 
 @register_gradient("nn.bias_add")
-def bias_grad(orig, grad):
-    """Returns grad"""
+def bias_add_grad(orig, grad):
+    """Returns gradient of bias_add"""
     data, bias = orig.args
     return [collapse_sum_like(grad, data),
             collapse_sum_like(grad, bias)]