From: Mahesh Ambule <15611578+maheshambule@users.noreply.github.com> Date: Tue, 14 Apr 2020 06:09:21 +0000 (+0530) Subject: [Frontend|MXNet] SwapAxis operator support (#5246) X-Git-Tag: upstream/0.7.0~916 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=b7545eb5ca87507ea04ccbe96c1a02040bef26be;p=platform%2Fupstream%2Ftvm.git [Frontend|MXNet] SwapAxis operator support (#5246) * MXNet swap axis * MXNet swap axis * swap axis review comment * swap axis review comment --- diff --git a/python/tvm/relay/frontend/mxnet.py b/python/tvm/relay/frontend/mxnet.py index 5c8e726..e6aa5f1 100644 --- a/python/tvm/relay/frontend/mxnet.py +++ b/python/tvm/relay/frontend/mxnet.py @@ -127,6 +127,17 @@ def _mx_unravel_index(inputs, attrs): return _op.unravel_index(inputs[0], shape_expr) +def _mx_swap_axis(inputs, attrs): + assert len(inputs) == 1 + dim1 = attrs.get_int('dim1') + dim2 = attrs.get_int('dim2') + shape = _infer_type(inputs[0]).checked_type.shape + axes = list(range(len(shape))) + axes[dim1] = dim2 + axes[dim2] = dim1 + return _op.transpose(inputs[0], axes=axes) + + def _mx_zeros(inputs, attrs): assert len(inputs) == 0 shape = attrs.get_int_tuple("shape") @@ -1813,6 +1824,7 @@ _convert_map = { "slice_axis" : _mx_slice_axis, "SliceChannel" : _mx_split, "split" : _mx_split, + "SwapAxis" : _mx_swap_axis, "expand_dims" : _mx_expand_dims, "Concat" : _mx_concat, "concat" : _mx_concat, diff --git a/tests/python/frontend/mxnet/test_forward.py b/tests/python/frontend/mxnet/test_forward.py index eb308c5..4a9848e 100644 --- a/tests/python/frontend/mxnet/test_forward.py +++ b/tests/python/frontend/mxnet/test_forward.py @@ -983,6 +983,18 @@ def test_forward_unravel_index(): # verify([0, 1, 2, 5], [2, 2], dtype) +def test_forward_swap_axis(): + def _verify_swap_axis(in_shape, out_shape, dim1, dim2): + data = mx.sym.var('data') + mx_sym = mx.sym.swapaxes(data, dim1, dim2) + verify_mxnet_frontend_impl(mx_sym, in_shape, out_shape) + + _verify_swap_axis((4, 5), (5, 4), 0, 1) + _verify_swap_axis((2, 4, 4, 5), (2, 5, 4, 4), 1, 3) + # MXNet errors out when dim1 == dim2 + # _verify_swap_axis((4, 5), (5, 4), 0, 0) + + if __name__ == '__main__': test_forward_mlp() test_forward_vgg() @@ -1040,3 +1052,4 @@ if __name__ == '__main__': test_forward_cond() test_forward_make_loss() test_forward_unravel_index() + test_forward_swap_axis() \ No newline at end of file