save (#3901)
author雾雨魔理沙 <lolisa@marisa.moe>
Fri, 6 Sep 2019 22:17:37 +0000 (15:17 -0700)
committerJared Roesch <roeschinc@gmail.com>
Fri, 6 Sep 2019 22:17:37 +0000 (15:17 -0700)
python/tvm/relay/op/_tensor_grad.py
tests/python/relay/test_op_grad_level3.py
tests/python/relay/test_op_level3.py

index 89f4ca8..08624e1 100644 (file)
@@ -290,6 +290,12 @@ def dense_grad(orig, grad):
             collapse_sum_like(data * transpose(grad), weight)]
 
 
+@register_gradient("reshape")
+def reshape_grad(orig, grad):
+    """Gradient of reshape"""
+    return [reshape_like(grad, orig.args[0])]
+
+
 @register_gradient("nn.batch_flatten")
 def batch_flatten_grad(orig, grad):
     """Returns grad reshaped to data dims"""
index 9324555..cc57361 100644 (file)
@@ -15,6 +15,7 @@
 # specific language governing permissions and limitations
 # under the License.
 import numpy as np
+import pytest
 
 import tvm
 from tvm import relay
@@ -58,6 +59,4 @@ def test_negative_grad():
 
 
 if __name__ == "__main__":
-    test_clip()
-    test_transpose_grad()
-    test_negative_grad()
+    pytest.main()
index f1d91a2..03fe7f7 100644 (file)
@@ -21,7 +21,7 @@ from nose.tools import raises
 import tvm
 from tvm import relay
 from tvm.relay import create_executor, transform
-from tvm.relay.testing import ctx_list
+from tvm.relay.testing import ctx_list, check_grad
 
 def run_infer_type(expr):
     mod = relay.Module.from_expr(expr)
@@ -247,6 +247,7 @@ def test_reshape():
         assert zz.checked_type == relay.ty.TensorType(oshape, "float32")
 
         func = relay.Function([x], z)
+        check_grad(func)
         x_data = np.random.uniform(low=-1, high=1, size=shape).astype("float32")
         ref_res = np.reshape(x_data, oshape)
         for target, ctx in ctx_list():