[Relay][Frontend] Add slice axis op in mxnet converter (#2706)
authorHaichen Shen <shenhaichen@gmail.com>
Mon, 4 Mar 2019 05:14:14 +0000 (21:14 -0800)
committerGitHub <noreply@github.com>
Mon, 4 Mar 2019 05:14:14 +0000 (21:14 -0800)
* Add slice axis op in mxnet converter

* Fix lint

python/tvm/relay/frontend/mxnet.py
tests/python/frontend/mxnet/test_forward.py

index 1f1d18e240cd1795e1447b177a7c35faa78683d7..4d341c76043a83d748c18ba3c35ccff386affe5e 100644 (file)
@@ -194,6 +194,34 @@ def _mx_slice(inputs, attrs):
     return _op.strided_slice(inputs[0], **new_attrs)
 
 
+def _mx_slice_axis(inputs, attrs):
+    assert len(inputs) == 1
+    shape = ir_pass.infer_type(inputs[0]).checked_type.shape
+    axis = attrs.get_int("axis")
+    ax_beg = attrs.get_int("begin")
+    ax_end = attrs.get_str("end")
+    if ax_end == "None":
+        ax_end = int(shape[axis])
+    else:
+        ax_end = int(ax_end)
+    if ax_beg < 0:
+        ax_beg += int(shape[axis])
+    if ax_end < 0:
+        ax_end += int(shape[axis])
+    assert ax_beg >= 0 and ax_beg < int(shape[axis])
+    assert ax_end > ax_beg and ax_end <= int(shape[axis])
+    begin = []
+    end = []
+    for i, dim in enumerate(shape):
+        if i != axis:
+            begin.append(0)
+            end.append(dim)
+        else:
+            begin.append(ax_beg)
+            end.append(ax_end)
+    return _op.strided_slice(inputs[0], begin, end)
+
+
 def _mx_split(inputs, attrs):
     axis = attrs.get_int("axis", 1)
     new_attrs = {}
@@ -423,6 +451,7 @@ _convert_map = {
     "BatchNorm_v1"  : _mx_batch_norm,
     "LRN"           : _mx_lrn,
     "slice"         : _mx_slice,
+    "slice_axis"    : _mx_slice_axis,
     "SliceChannel"  : _mx_split,
     "split"         : _mx_split,
     "expand_dims"   : _mx_expand_dims,
index ee47d72046ed4792f1e5f9df9b84a74bda01ea59..7f53aa8a015573922c8d4e4c7a1de9b5466599ba 100644 (file)
@@ -337,6 +337,23 @@ def test_forward_scalar_ops():
                 tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy())
 
 
+def test_forward_slice_axis():
+    def verify(shape, axis, begin, end):
+        data_np = np.random.uniform(size=shape).astype("float32")
+        ref_res = mx.nd.slice_axis(mx.nd.array(data_np), axis, begin, end)
+        mx_sym = mx.sym.slice_axis(mx.sym.var("data"), axis, begin, end)
+        new_sym, _ = relay.frontend.from_mxnet(mx_sym, {"data": shape})
+        for target, ctx in ctx_list():
+            for kind in ["graph", "debug"]:
+                intrp = relay.create_executor(kind, ctx=ctx, target=target)
+                op_res = intrp.evaluate(new_sym)(data_np)
+                tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy())
+    verify((3, 4), 0, 1, 2)
+    verify((3, 4), 0, 1, None)
+    verify((3, 4), 1, 0, 2)
+    verify((3, 4), 1, -3, -1)
+
+
 if __name__ == '__main__':
     test_forward_mlp()
     test_forward_vgg()
@@ -363,3 +380,4 @@ if __name__ == '__main__':
     test_forward_broadcast_ops()
     test_forward_elemwise_ops()
     test_forward_scalar_ops()
+    test_forward_slice_axis()
\ No newline at end of file