From 02ddb5a9c340da5a3a20d1d2cc6bf90b9f63a143 Mon Sep 17 00:00:00 2001 From: =?utf8?q?=E9=9B=BE=E9=9B=A8=E9=AD=94=E7=90=86=E6=B2=99?= Date: Fri, 6 Sep 2019 15:17:37 -0700 Subject: [PATCH] save (#3901) --- python/tvm/relay/op/_tensor_grad.py | 6 ++++++ tests/python/relay/test_op_grad_level3.py | 5 ++--- tests/python/relay/test_op_level3.py | 3 ++- 3 files changed, 10 insertions(+), 4 deletions(-) diff --git a/python/tvm/relay/op/_tensor_grad.py b/python/tvm/relay/op/_tensor_grad.py index 89f4ca8..08624e1 100644 --- a/python/tvm/relay/op/_tensor_grad.py +++ b/python/tvm/relay/op/_tensor_grad.py @@ -290,6 +290,12 @@ def dense_grad(orig, grad): collapse_sum_like(data * transpose(grad), weight)] +@register_gradient("reshape") +def reshape_grad(orig, grad): + """Gradient of reshape""" + return [reshape_like(grad, orig.args[0])] + + @register_gradient("nn.batch_flatten") def batch_flatten_grad(orig, grad): """Returns grad reshaped to data dims""" diff --git a/tests/python/relay/test_op_grad_level3.py b/tests/python/relay/test_op_grad_level3.py index 9324555..cc57361 100644 --- a/tests/python/relay/test_op_grad_level3.py +++ b/tests/python/relay/test_op_grad_level3.py @@ -15,6 +15,7 @@ # specific language governing permissions and limitations # under the License. import numpy as np +import pytest import tvm from tvm import relay @@ -58,6 +59,4 @@ def test_negative_grad(): if __name__ == "__main__": - test_clip() - test_transpose_grad() - test_negative_grad() + pytest.main() diff --git a/tests/python/relay/test_op_level3.py b/tests/python/relay/test_op_level3.py index f1d91a2..03fe7f7 100644 --- a/tests/python/relay/test_op_level3.py +++ b/tests/python/relay/test_op_level3.py @@ -21,7 +21,7 @@ from nose.tools import raises import tvm from tvm import relay from tvm.relay import create_executor, transform -from tvm.relay.testing import ctx_list +from tvm.relay.testing import ctx_list, check_grad def run_infer_type(expr): mod = relay.Module.from_expr(expr) @@ -247,6 +247,7 @@ def test_reshape(): assert zz.checked_type == relay.ty.TensorType(oshape, "float32") func = relay.Function([x], z) + check_grad(func) x_data = np.random.uniform(low=-1, high=1, size=shape).astype("float32") ref_res = np.reshape(x_data, oshape) for target, ctx in ctx_list(): -- 2.7.4