From: lixiaoquan Date: Fri, 31 Jul 2020 15:06:37 +0000 (+0800) Subject: [Relay] Fix bug in transpose_shape_func (#6180) X-Git-Tag: upstream/0.7.0~327 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=33a3d7aee7534f38f4a5533f0dab03a0bcfec7a0;p=platform%2Fupstream%2Ftvm.git [Relay] Fix bug in transpose_shape_func (#6180) --- diff --git a/python/tvm/relay/op/_transform.py b/python/tvm/relay/op/_transform.py index a2c374d..4e113f7 100644 --- a/python/tvm/relay/op/_transform.py +++ b/python/tvm/relay/op/_transform.py @@ -540,9 +540,10 @@ def transpose_shape_func(attrs, inputs, _): if axes is None: axes = list(range(inputs[0].shape[0].value)) axes.reverse() + axes = list(axes) for i, axis in enumerate(axes): if axis < 0: - axes[i] = inputs[0].shape[0] - axis + axes[i] = inputs[0].shape[0] + axis return [_transpose_shape_func(inputs[0], convert(axes))] @script diff --git a/tests/python/relay/test_any.py b/tests/python/relay/test_any.py index 0e8a328..f9d9c93 100644 --- a/tests/python/relay/test_any.py +++ b/tests/python/relay/test_any.py @@ -363,6 +363,7 @@ def test_any_transpose(): verify_any_transpose(any_dims(3), (1, 0, 2), (10, 3, 2)) verify_any_transpose(any_dims(3), None, (2, 3, 4)) verify_any_transpose(any_dims(6), (0, 1, 3, 2, 5, 4), (11, 12, 2, 1, 9, 17)) + verify_any_transpose(any_dims(2), (-1, 0), (3, 2)) def verify_any_squeeze(data_shape, axis, static_data_shape): mod = tvm.IRModule()