[Relay] Fix bug in transpose_shape_func (#6180)
authorlixiaoquan <radioheads@163.com>
Fri, 31 Jul 2020 15:06:37 +0000 (23:06 +0800)
committerGitHub <noreply@github.com>
Fri, 31 Jul 2020 15:06:37 +0000 (08:06 -0700)
python/tvm/relay/op/_transform.py
tests/python/relay/test_any.py

index a2c374d..4e113f7 100644 (file)
@@ -540,9 +540,10 @@ def transpose_shape_func(attrs, inputs, _):
     if axes is None:
         axes = list(range(inputs[0].shape[0].value))
         axes.reverse()
+    axes = list(axes)
     for i, axis in enumerate(axes):
         if axis < 0:
-            axes[i] = inputs[0].shape[0] - axis
+            axes[i] = inputs[0].shape[0] + axis
     return [_transpose_shape_func(inputs[0], convert(axes))]
 
 @script
index 0e8a328..f9d9c93 100644 (file)
@@ -363,6 +363,7 @@ def test_any_transpose():
     verify_any_transpose(any_dims(3), (1, 0, 2), (10, 3, 2))
     verify_any_transpose(any_dims(3), None, (2, 3, 4))
     verify_any_transpose(any_dims(6), (0, 1, 3, 2, 5, 4), (11, 12, 2, 1, 9, 17))
+    verify_any_transpose(any_dims(2), (-1, 0), (3, 2))
 
 def verify_any_squeeze(data_shape, axis, static_data_shape):
     mod = tvm.IRModule()