From: 雾雨魔理沙 Date: Fri, 6 Sep 2019 22:17:37 +0000 (-0700) Subject: save (#3901) X-Git-Tag: upstream/0.7.0~1942 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=02ddb5a9c340da5a3a20d1d2cc6bf90b9f63a143;p=platform%2Fupstream%2Ftvm.git save (#3901) --- diff --git a/python/tvm/relay/op/_tensor_grad.py b/python/tvm/relay/op/_tensor_grad.py index 89f4ca8..08624e1 100644 --- a/python/tvm/relay/op/_tensor_grad.py +++ b/python/tvm/relay/op/_tensor_grad.py @@ -290,6 +290,12 @@ def dense_grad(orig, grad): collapse_sum_like(data * transpose(grad), weight)] +@register_gradient("reshape") +def reshape_grad(orig, grad): + """Gradient of reshape""" + return [reshape_like(grad, orig.args[0])] + + @register_gradient("nn.batch_flatten") def batch_flatten_grad(orig, grad): """Returns grad reshaped to data dims""" diff --git a/tests/python/relay/test_op_grad_level3.py b/tests/python/relay/test_op_grad_level3.py index 9324555..cc57361 100644 --- a/tests/python/relay/test_op_grad_level3.py +++ b/tests/python/relay/test_op_grad_level3.py @@ -15,6 +15,7 @@ # specific language governing permissions and limitations # under the License. import numpy as np +import pytest import tvm from tvm import relay @@ -58,6 +59,4 @@ def test_negative_grad(): if __name__ == "__main__": - test_clip() - test_transpose_grad() - test_negative_grad() + pytest.main() diff --git a/tests/python/relay/test_op_level3.py b/tests/python/relay/test_op_level3.py index f1d91a2..03fe7f7 100644 --- a/tests/python/relay/test_op_level3.py +++ b/tests/python/relay/test_op_level3.py @@ -21,7 +21,7 @@ from nose.tools import raises import tvm from tvm import relay from tvm.relay import create_executor, transform -from tvm.relay.testing import ctx_list +from tvm.relay.testing import ctx_list, check_grad def run_infer_type(expr): mod = relay.Module.from_expr(expr) @@ -247,6 +247,7 @@ def test_reshape(): assert zz.checked_type == relay.ty.TensorType(oshape, "float32") func = relay.Function([x], z) + check_grad(func) x_data = np.random.uniform(low=-1, high=1, size=shape).astype("float32") ref_res = np.reshape(x_data, oshape) for target, ctx in ctx_list():