add MXNet converter for where operator for both NNVM and Relay (#2647)
authorHao Jin <haojin2@users.noreply.github.com>
Fri, 22 Feb 2019 06:14:08 +0000 (22:14 -0800)
committerTianqi Chen <tqchen@users.noreply.github.com>
Fri, 22 Feb 2019 06:14:08 +0000 (22:14 -0800)
nnvm/python/nnvm/frontend/mxnet.py
nnvm/tests/python/frontend/mxnet/test_forward.py
python/tvm/relay/frontend/mxnet.py
tests/python/frontend/mxnet/test_forward.py

index 0194886f53891be784e300ed43a9d0a7d5328738..85268de858d1789fde5bd557fe2499f7270e090a 100644 (file)
@@ -305,7 +305,7 @@ _identity_list = ['__add_scalar__', '__add_symbol__', '__div_scalar__',
                   'flatten', 'log', 'log_softmax', 'max', 'min', 'negative',
                   'ones_like', 'relu', 'sigmoid', 'slice_like', 'softmax',
                   'sum', 'tanh', 'transpose', 'zeros_like', 'gather_nd',
-                  'reshape_like']
+                  'reshape_like', 'where']
 
 _convert_map = {
     '_copy'         : _rename('copy'),
index 66ae9d6e9de4690fefdc2bf11cd16f29b1c70b51..e9225a4c7c50f0b864b5d55931452b52108216ba 100644 (file)
@@ -158,7 +158,7 @@ def test_forward_ones():
     ones = mx.sym.ones(shape=(2, 3, 4), dtype='float32')
     mx_sym = mx.sym.elemwise_add(data, ones)
     verify_mxnet_frontend_impl(mx_sym, (2, 3, 4), (2, 3, 4))
-    
+
 def test_forward_zeros():
     data = mx.sym.var('data')
     zeros = mx.sym.zeros(shape=(2, 3, 4), dtype='float32')
@@ -184,7 +184,42 @@ def test_forward_argmin():
     data = mx.sym.var('data')
     mx_sym = mx.sym.argmin(data, axis=0)
     verify_mxnet_frontend_impl(mx_sym, (5, 4), (4,))
-    
+
+def test_forward_where():
+    cond = mx.sym.var('cond')
+    x = mx.sym.var('x')
+    y = mx.sym.var('y')
+    dshape = (2, 2)
+    dtype = 'float32'
+    mx_sym = mx.sym.where(cond, x, y)
+    np_cond = np.array([[0, 1], [-1, 0]]).astype(dtype)
+    np_x = np.random.uniform(size=dshape).astype(dtype)
+    np_y = np.random.uniform(size=dshape).astype(dtype)
+    mx_cond = mx.nd.array(np_cond)
+    mx_x = mx.nd.array(np_x)
+    mx_y = mx.nd.array(np_y)
+    mod = mx.mod.Module(mx_sym, label_names=None, data_names=['cond', 'x', 'y'])
+    mod.bind(data_shapes=[('cond', dshape), ('x', dshape), ('y', dshape)], for_training=False)
+    mod.init_params()
+    args, auxs = mod.get_params()
+    mx_out = mx.nd.where(mx_cond, mx_x, mx_y).asnumpy()
+    out_shape = dshape
+    new_sym, params = frontend.from_mxnet(mx_sym, args, auxs)
+    shape_dict = {'cond': dshape, 'x': dshape, 'y': dshape}
+    for target, ctx in ctx_list():
+        with nnvm.compiler.build_config(opt_level=3):
+            graph, lib, params = nnvm.compiler.build(new_sym, target, shape_dict, params=params)
+        m = graph_runtime.create(graph, lib, ctx)
+        # set inputs
+        m.set_input("cond", tvm.nd.array(np_cond))
+        m.set_input("x", tvm.nd.array(np_x))
+        m.set_input("y", tvm.nd.array(np_y))
+        m.set_input(**params)
+        m.run()
+        # get outputs
+        tvm_out = m.get_output(0, tvm.nd.empty(out_shape, dtype)).asnumpy()
+        tvm.testing.assert_allclose(mx_out, tvm_out, rtol=1e-5, atol=1e-5)
+
 if __name__ == '__main__':
     test_forward_mlp()
     test_forward_vgg()
@@ -206,4 +241,5 @@ if __name__ == '__main__':
     test_forward_zeros_like()
     test_forward_argmax()
     test_forward_argmin()
-    
+    test_forward_where()
+
index 540e139ff49591eae78ca269c5d52112f37baac1..3a0885a3fcdfc6cd02adeb04135c8fc2140be841 100644 (file)
@@ -290,6 +290,7 @@ _identity_list = [
     "slice_like",
     "zeros_like",
     "ones_like",
+    "where",
 ]
 
 _convert_map = {
index 81a12b041ed7ded2a9c2915c0e2f0927838e5919..e1f7e55092303d3a7c09562a52895baa524c1c0d 100644 (file)
@@ -190,6 +190,44 @@ def test_forward_argmin():
     mx_sym = mx.sym.argmin(data, axis=0)
     verify_mxnet_frontend_impl(mx_sym, (5, 4), (4,))
 
+def test_forward_where():
+    cond = mx.sym.var('cond')
+    x = mx.sym.var('x')
+    y = mx.sym.var('y')
+    dshape = (2, 2)
+    dtype = 'float32'
+    mx_sym = mx.sym.where(cond, x, y)
+    np_cond = np.array([[0, 1], [-1, 0]]).astype(dtype)
+    np_x = np.random.uniform(size=dshape).astype(dtype)
+    np_y = np.random.uniform(size=dshape).astype(dtype)
+    mx_cond = mx.nd.array(np_cond)
+    mx_x = mx.nd.array(np_x)
+    mx_y = mx.nd.array(np_y)
+    mod = mx.mod.Module(mx_sym, label_names=None, data_names=['cond', 'x', 'y'])
+    mod.bind(data_shapes=[('cond', dshape), ('x', dshape), ('y', dshape)], for_training=False)
+    mod.init_params()
+    args, auxs = mod.get_params()
+    mx_out = mx.nd.where(mx_cond, mx_x, mx_y).asnumpy()
+    out_shape = dshape
+    shape_dict = {'cond': dshape, 'x': dshape, 'y': dshape}
+    new_sym, params = relay.frontend.from_mxnet(mx_sym,
+                                                shape_dict,
+                                                arg_params=args,
+                                                aux_params=auxs)
+    for target, ctx in ctx_list():
+        with relay.build_config(opt_level=3):
+            graph, lib, params = relay.build(new_sym, target, params=params)
+        m = graph_runtime.create(graph, lib, ctx)
+        # set inputs
+        m.set_input("cond", tvm.nd.array(np_cond))
+        m.set_input("x", tvm.nd.array(np_x))
+        m.set_input("y", tvm.nd.array(np_y))
+        m.set_input(**params)
+        m.run()
+        # get outputs
+        tvm_out = m.get_output(0, tvm.nd.empty(out_shape, dtype)).asnumpy()
+        tvm.testing.assert_allclose(mx_out, tvm_out, rtol=1e-5, atol=1e-5)
+
 
 if __name__ == '__main__':
     test_forward_mlp()
@@ -212,3 +250,4 @@ if __name__ == '__main__':
     test_forward_zeros_like()
     test_forward_argmax()
     test_forward_argmin()
+    test_forward_where()