[Relay][Frontend] Add ops in mxnet converter (#2844)
authorHaichen Shen <shenhaichen@gmail.com>
Wed, 20 Mar 2019 08:09:24 +0000 (01:09 -0700)
committerLeyuan Wang <laurawly@gmail.com>
Wed, 20 Mar 2019 08:09:24 +0000 (01:09 -0700)
* Add ops in mxnet converter

* trigger ci

python/tvm/relay/frontend/mxnet.py
tests/python/frontend/mxnet/test_forward.py

index dae9729597757fc454222d5266dfc9a2e8833b00..758793c980d68175b46ee1ca658a6c53b7c27781 100644 (file)
@@ -213,7 +213,7 @@ def _mx_slice_axis(inputs, attrs):
     ax_end = attrs.get_str("end")
     if axis < 0:
         axis += len(shape)
-    assert axis >= 0 and axis < len(shape)
+    assert 0 <= axis < len(shape)
     if ax_end == "None":
         ax_end = int(shape[axis])
     else:
@@ -222,8 +222,8 @@ def _mx_slice_axis(inputs, attrs):
         ax_beg += int(shape[axis])
     if ax_end < 0:
         ax_end += int(shape[axis])
-    assert ax_beg >= 0 and ax_beg < int(shape[axis])
-    assert ax_end > ax_beg and ax_end <= int(shape[axis])
+    assert 0 <= ax_beg < int(shape[axis])
+    assert ax_beg < ax_end <= int(shape[axis])
     begin = []
     end = []
     for i, dim in enumerate(shape):
@@ -527,11 +527,53 @@ def _mx_shape_array(inputs, attrs):
     return _op.shape_of(inputs[0], dtype='int64')
 
 
+def _mx_full(inputs, attrs):
+    assert len(inputs) == 0
+    val = attrs.get_float("value")
+    shape = attrs.get_int_tuple("shape")
+    dtype = attrs.get_str("dtype", "float32")
+    return _op.full(_expr.const(val, dtype), shape, dtype)
+
+
+def _mx_squeeze(inputs, attrs):
+    assert len(inputs) == 1
+    axis = attrs.get_int_tuple("axis", None)
+    return _op.squeeze(inputs[0], axis)
+
+
+def _mx_broadcast_axis(inputs, attrs):
+    assert len(inputs) == 1
+    axis = attrs.get_int_tuple("axis", [])
+    size = attrs.get_int_tuple("size", [])
+    assert len(axis) == len(size)
+    if len(axis) == 0:
+        return inputs[0]
+    src_shape = ir_pass.infer_type(inputs[0])._checked_type_.shape
+    tgt_shape = []
+    for i, dim in enumerate(src_shape):
+        if i not in axis:
+            tgt_shape.append(dim)
+        else:
+            assert int(dim) == 1
+            idx = axis.index(i)
+            tgt_shape.append(size[idx])
+    return _op.broadcast_to(inputs[0], tgt_shape)
+
+
+def _mx_embedding(inputs, _):
+    assert len(inputs) == 2
+    indices, weight = inputs
+    return _op.take(weight, indices.astype('int32'), axis=0)
+
+
 # Note: due to attribute conversion constraint
 # ops in the identity set must be attribute free
 _identity_list = [
     "log",
     "exp",
+    "sqrt",
+    "floor",
+    "ceil",
     "sigmoid",
     "tanh",
     "negative",
@@ -567,7 +609,6 @@ _convert_map = {
     "Flatten"                : _rename(_op.nn.batch_flatten),
     # scalar power
     "square"                 : _mx_make_power(2),
-    "sqrt"                   : _mx_make_power(1/2),
     "rsqrt"                  : _mx_make_power(-1/2),
     "cbrt"                   : _mx_make_power(1/3),
     "rcbrt"                  : _mx_make_power(-1/3),
@@ -649,11 +690,15 @@ _convert_map = {
     "batch_dot"     : _mx_batch_dot,
     "LeakyReLU"     : _mx_leaky_relu,
     "_arange"       : _mx_arange,
+    "_full"         : _mx_full,
     "repeat"        : _mx_repeat,
     "tile"          : _mx_tile,
     "reverse"       : _mx_reverse,
+    "squeeze"       : _mx_squeeze,
+    "broadcast_axis": _mx_broadcast_axis,
     "BlockGrad"     : _mx_BlockGrad,
     "shape_array"   : _mx_shape_array,
+    "Embedding"     : _mx_embedding,
     "SoftmaxOutput" : _mx_softmax_output,
     "SoftmaxActivation" : _mx_softmax_activation,
     # vision
index e83f1e569545ee29499541bd2722bf0829327251..aad666ca75b4afd7fc56625e26eb26868add0160 100644 (file)
@@ -379,7 +379,6 @@ def test_forward_l2_normalize():
     mx_sym = mx.sym.L2Normalization(data, mode="channel")
     verify_mxnet_frontend_impl(mx_sym, (2, 3, 4, 5), (2, 3, 4, 5))
 
-
 def test_forward_shape_array():
     def verify(shape):
         x_np = np.random.uniform(size=shape).astype("float32")
@@ -395,6 +394,75 @@ def test_forward_shape_array():
     verify((3, 4, 5))
     verify((3, 4, 5, 6))
 
+def test_forward_squeeze():
+    def verify(shape, axis):
+        x_np = np.random.uniform(size=shape).astype("float32")
+        if axis is None:
+            ref_res = mx.nd.squeeze(mx.nd.array(x_np))
+            mx_sym = mx.sym.squeeze(mx.sym.var("x"))
+        else:
+            ref_res = mx.nd.squeeze(mx.nd.array(x_np), axis=axis)
+            mx_sym = mx.sym.squeeze(mx.sym.var("x"), axis=axis)
+        new_sym, _ = relay.frontend.from_mxnet(mx_sym, {"x": shape})
+        for target, ctx in ctx_list():
+            for kind in ["graph", "debug"]:
+                intrp = relay.create_executor(kind, ctx=ctx, target=target)
+                op_res = intrp.evaluate(new_sym)(x_np)
+                tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy())
+    verify((1, 3, 1), None)
+    verify((1, 3, 1), 0)
+    verify((1, 3, 1), 2)
+    verify((1, 3, 1), (0, 2))
+
+def test_forward_broadcast_axis():
+    def verify(shape, axis, size):
+        x_np = np.random.uniform(size=shape).astype("float32")
+        ref_res = mx.nd.broadcast_axis(mx.nd.array(x_np), axis=axis, size=size)
+        mx_sym = mx.sym.broadcast_axis(mx.sym.var("x"), axis=axis, size=size)
+        new_sym, _ = relay.frontend.from_mxnet(mx_sym, {"x": shape})
+        for target, ctx in ctx_list():
+            for kind in ["graph", "debug"]:
+                intrp = relay.create_executor(kind, ctx=ctx, target=target)
+                op_res = intrp.evaluate(new_sym)(x_np)
+                tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy())
+    verify((1, 2, 1), 2, 3)
+    verify((1, 2, 1), (0, 2), (2, 3))
+
+def test_forward_full():
+    def verify(val, shape, dtype):
+        ctx = mx.cpu()
+        ref_res = mx.nd.full(shape, val, dtype=dtype)
+        mx_sym = mx.sym.full(shape, val, dtype=dtype)
+        new_sym, _ = relay.frontend.from_mxnet(mx_sym, {})
+        for target, ctx in ctx_list():
+            # Skip testing graph runtime because this op will be optimized out
+            # by constant folding.
+            for kind in ["debug"]:
+                intrp = relay.create_executor(kind, ctx=ctx, target=target)
+                op_res = intrp.evaluate(new_sym)()
+                tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy())
+    verify(2, (3, 4), "float32")
+    verify(2, (3, 4), "int32")
+    verify(3.5, (1, 3, 4), "float32")
+
+def test_forward_embedding():
+    def verify(data_shape, weight_shape):
+        in_dim, out_dim = weight_shape
+        x_np = np.random.randint(0, weight_shape[0], size=data_shape).astype("float32")
+        w_np = np.random.uniform(size=weight_shape).astype("float32")
+        ref_res = mx.nd.Embedding(mx.nd.array(x_np), mx.nd.array(w_np),
+                                  input_dim=in_dim, output_dim=out_dim)
+        mx_sym = mx.sym.Embedding(mx.sym.var("x"), mx.sym.var("w"),
+                                  input_dim=in_dim, output_dim=out_dim)
+        new_sym, _ = relay.frontend.from_mxnet(
+            mx_sym, {"x": data_shape, "w": weight_shape})
+        for target, ctx in ctx_list():
+            for kind in ["graph", "debug"]:
+                intrp = relay.create_executor(kind, ctx=ctx, target=target)
+                op_res = intrp.evaluate(new_sym)(x=x_np, w=w_np)
+                tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy())
+    verify((2, 2), (4, 5))
+    verify((2, 3, 4), (4, 5))
 
 if __name__ == '__main__':
     test_forward_mlp()
@@ -426,3 +494,7 @@ if __name__ == '__main__':
     test_forward_slice_axis()
     test_forward_l2_normalize()
     test_forward_shape_array()
+    test_forward_squeeze()
+    test_forward_broadcast_axis()
+    test_forward_full()
+    test_forward_embedding()