'flatten', 'log', 'log_softmax', 'max', 'min', 'negative',
'ones_like', 'relu', 'sigmoid', 'slice_like', 'softmax',
'sum', 'tanh', 'transpose', 'zeros_like', 'gather_nd',
- 'reshape_like']
+ 'reshape_like', 'where']
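+# every op in _identity_list is converted by calling the NNVM operator
+# of the same name, so 'where' needs no explicit _convert_map entry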
_convert_map = {
'_copy' : _rename('copy'),
ones = mx.sym.ones(shape=(2, 3, 4), dtype='float32')
mx_sym = mx.sym.elemwise_add(data, ones)
verify_mxnet_frontend_impl(mx_sym, (2, 3, 4), (2, 3, 4))
-
+
def test_forward_zeros():
data = mx.sym.var('data')
zeros = mx.sym.zeros(shape=(2, 3, 4), dtype='float32')
data = mx.sym.var('data')
mx_sym = mx.sym.argmin(data, axis=0)
verify_mxnet_frontend_impl(mx_sym, (5, 4), (4,))
-
+
+def test_forward_where():
+    cond = mx.sym.var('cond')
+    x = mx.sym.var('x')
+    y = mx.sym.var('y')
+    dshape = (2, 2)
+    dtype = 'float32'
+    mx_sym = mx.sym.where(cond, x, y)
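+    # where() takes x where cond is nonzero and y elsewhere; the -1 entry
+    # checks that any nonzero value, not just 1, selects from x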
+    np_cond = np.array([[0, 1], [-1, 0]]).astype(dtype)
+    np_x = np.random.uniform(size=dshape).astype(dtype)
+    np_y = np.random.uniform(size=dshape).astype(dtype)
+    mx_cond = mx.nd.array(np_cond)
+    mx_x = mx.nd.array(np_x)
+    mx_y = mx.nd.array(np_y)
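+    # bind the symbol in a throwaway Module only to obtain its (here
+    # empty) arg/aux params for the frontend converter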
+    mod = mx.mod.Module(mx_sym, label_names=None, data_names=['cond', 'x', 'y'])
+    mod.bind(data_shapes=[('cond', dshape), ('x', dshape), ('y', dshape)], for_training=False)
+    mod.init_params()
+    args, auxs = mod.get_params()
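+    # MXNet's own result serves as the reference output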
+    mx_out = mx.nd.where(mx_cond, mx_x, mx_y).asnumpy()
+    out_shape = dshape
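+    # convert the MXNet symbol into an NNVM symbol plus parameter dict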
+    new_sym, params = frontend.from_mxnet(mx_sym, args, auxs)
+    shape_dict = {'cond': dshape, 'x': dshape, 'y': dshape}
+    for target, ctx in ctx_list():
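+        # opt_level=3 requests the full NNVM optimization pipeline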
+        with nnvm.compiler.build_config(opt_level=3):
+            graph, lib, params = nnvm.compiler.build(new_sym, target, shape_dict, params=params)
+        m = graph_runtime.create(graph, lib, ctx)
+        # set inputs
+        m.set_input("cond", tvm.nd.array(np_cond))
+        m.set_input("x", tvm.nd.array(np_x))
+        m.set_input("y", tvm.nd.array(np_y))
+        m.set_input(**params)
+        m.run()
+        # get outputs
+        tvm_out = m.get_output(0, tvm.nd.empty(out_shape, dtype)).asnumpy()
+        tvm.testing.assert_allclose(mx_out, tvm_out, rtol=1e-5, atol=1e-5)
+
if __name__ == '__main__':
test_forward_mlp()
test_forward_vgg()
test_forward_zeros_like()
test_forward_argmax()
test_forward_argmin()
-
+    test_forward_where()
+
mx_sym = mx.sym.argmin(data, axis=0)
verify_mxnet_frontend_impl(mx_sym, (5, 4), (4,))
+
+def test_forward_where():
+    cond = mx.sym.var('cond')
+    x = mx.sym.var('x')
+    y = mx.sym.var('y')
+    dshape = (2, 2)
+    dtype = 'float32'
+    mx_sym = mx.sym.where(cond, x, y)
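+    # where() takes x where cond is nonzero and y elsewhere; the -1 entry
+    # checks that any nonzero value, not just 1, selects from x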
+    np_cond = np.array([[0, 1], [-1, 0]]).astype(dtype)
+    np_x = np.random.uniform(size=dshape).astype(dtype)
+    np_y = np.random.uniform(size=dshape).astype(dtype)
+    mx_cond = mx.nd.array(np_cond)
+    mx_x = mx.nd.array(np_x)
+    mx_y = mx.nd.array(np_y)
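+    # bind the symbol in a throwaway Module only to obtain its (here
+    # empty) arg/aux params for the frontend converter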
+    mod = mx.mod.Module(mx_sym, label_names=None, data_names=['cond', 'x', 'y'])
+    mod.bind(data_shapes=[('cond', dshape), ('x', dshape), ('y', dshape)], for_training=False)
+    mod.init_params()
+    args, auxs = mod.get_params()
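+    # MXNet's own result serves as the reference output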
+    mx_out = mx.nd.where(mx_cond, mx_x, mx_y).asnumpy()
+    out_shape = dshape
+    shape_dict = {'cond': dshape, 'x': dshape, 'y': dshape}
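+    # convert the MXNet symbol into a Relay expression plus parameter dict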
+    new_sym, params = relay.frontend.from_mxnet(mx_sym,
+                                                shape_dict,
+                                                arg_params=args,
+                                                aux_params=auxs)
+    for target, ctx in ctx_list():
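+        # opt_level=3 requests the full set of standard Relay optimizations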
+        with relay.build_config(opt_level=3):
+            graph, lib, params = relay.build(new_sym, target, params=params)
+        m = graph_runtime.create(graph, lib, ctx)
+        # set inputs
+        m.set_input("cond", tvm.nd.array(np_cond))
+        m.set_input("x", tvm.nd.array(np_x))
+        m.set_input("y", tvm.nd.array(np_y))
+        m.set_input(**params)
+        m.run()
+        # get outputs
+        tvm_out = m.get_output(0, tvm.nd.empty(out_shape, dtype)).asnumpy()
+        tvm.testing.assert_allclose(mx_out, tvm_out, rtol=1e-5, atol=1e-5)
+
if __name__ == '__main__':
test_forward_mlp()
test_forward_zeros_like()
test_forward_argmax()
test_forward_argmin()
+    test_forward_where()