[Frontend|MXNet] SwapAxis operator support (#5246)
authorMahesh Ambule <15611578+maheshambule@users.noreply.github.com>
Tue, 14 Apr 2020 06:09:21 +0000 (11:39 +0530)
committerGitHub <noreply@github.com>
Tue, 14 Apr 2020 06:09:21 +0000 (23:09 -0700)
* MXNet swap axis

* MXNet swap axis

* swap axis review comment

* swap axis review comment

python/tvm/relay/frontend/mxnet.py
tests/python/frontend/mxnet/test_forward.py

index 5c8e726..e6aa5f1 100644 (file)
@@ -127,6 +127,17 @@ def _mx_unravel_index(inputs, attrs):
     return _op.unravel_index(inputs[0], shape_expr)
 
 
+def _mx_swap_axis(inputs, attrs):
+    assert len(inputs) == 1
+    dim1 = attrs.get_int('dim1')
+    dim2 = attrs.get_int('dim2')
+    shape = _infer_type(inputs[0]).checked_type.shape
+    axes = list(range(len(shape)))
+    axes[dim1] = dim2
+    axes[dim2] = dim1
+    return _op.transpose(inputs[0], axes=axes)
+
+
 def _mx_zeros(inputs, attrs):
     assert len(inputs) == 0
     shape = attrs.get_int_tuple("shape")
@@ -1813,6 +1824,7 @@ _convert_map = {
     "slice_axis"    : _mx_slice_axis,
     "SliceChannel"  : _mx_split,
     "split"         : _mx_split,
+    "SwapAxis"      : _mx_swap_axis,
     "expand_dims"   : _mx_expand_dims,
     "Concat"        : _mx_concat,
     "concat"        : _mx_concat,
index eb308c5..4a9848e 100644 (file)
@@ -983,6 +983,18 @@ def test_forward_unravel_index():
     # verify([0, 1, 2, 5], [2, 2], dtype)
 
 
+def test_forward_swap_axis():
+    def _verify_swap_axis(in_shape, out_shape, dim1, dim2):
+        data = mx.sym.var('data')
+        mx_sym = mx.sym.swapaxes(data, dim1, dim2)
+        verify_mxnet_frontend_impl(mx_sym, in_shape, out_shape)
+
+    _verify_swap_axis((4, 5), (5, 4), 0, 1)
+    _verify_swap_axis((2, 4, 4, 5), (2, 5, 4, 4), 1, 3)
+    # MXNet errors out when dim1 == dim2
+    # _verify_swap_axis((4, 5), (5, 4), 0, 0)
+
+
 if __name__ == '__main__':
     test_forward_mlp()
     test_forward_vgg()
@@ -1040,3 +1052,4 @@ if __name__ == '__main__':
     test_forward_cond()
     test_forward_make_loss()
     test_forward_unravel_index()
+    test_forward_swap_axis()
\ No newline at end of file