if axes is None:
axes = list(range(inputs[0].shape[0].value))
axes.reverse()
+ axes = list(axes)
for i, axis in enumerate(axes):
if axis < 0:
- axes[i] = inputs[0].shape[0] - axis
+ axes[i] = inputs[0].shape[0] + axis
return [_transpose_shape_func(inputs[0], convert(axes))]
@script
verify_any_transpose(any_dims(3), (1, 0, 2), (10, 3, 2))
verify_any_transpose(any_dims(3), None, (2, 3, 4))
verify_any_transpose(any_dims(6), (0, 1, 3, 2, 5, 4), (11, 12, 2, 1, 9, 17))
+ verify_any_transpose(any_dims(2), (-1, 0), (3, 2))
def verify_any_squeeze(data_shape, axis, static_data_shape):
mod = tvm.IRModule()