From 967d7318a38f671509c6dc02fd298efa7f38b099 Mon Sep 17 00:00:00 2001 From: Samuel Date: Fri, 1 May 2020 18:33:45 +0530 Subject: [PATCH] [MXNET]broadcast and logical op support (#5461) * [MXNET]broadcast and logical op support * Review comment fixed --- python/tvm/relay/frontend/mxnet.py | 37 ++++++++++++++- tests/python/frontend/mxnet/test_forward.py | 71 +++++++++++++++++++++++++---- 2 files changed, 97 insertions(+), 11 deletions(-) diff --git a/python/tvm/relay/frontend/mxnet.py b/python/tvm/relay/frontend/mxnet.py index 775eb53..7dbc788 100644 --- a/python/tvm/relay/frontend/mxnet.py +++ b/python/tvm/relay/frontend/mxnet.py @@ -1712,6 +1712,33 @@ def _qnn_fully_connected(inputs, attrs, subgraphs, params): res = _op.nn.relu(res) return res + +def _mx_broadcast_to(inputs, attrs): + data = inputs[0] + tgt_shape = attrs.get_int_tuple("shape", []) + + return _op.broadcast_to(data, tgt_shape) + + +def _mx_logical_not(inputs, input_types): + data = inputs[0] + dtype = _infer_type(data).checked_type.dtype + data = _op.cast(data, "bool") if dtype != "bool" else data + + return _op.cast(_op.logical_not(data), dtype) + + +def _mx_broadcast_logical(logical_op): + def impl(inputs, input_types): + lhs_type = _infer_type(inputs[0]).checked_type.dtype + rhs_type = _infer_type(inputs[1]).checked_type.dtype + lhs = _op.cast(inputs[0], "bool") if lhs_type != "bool" else inputs[0] + rhs = _op.cast(inputs[1], "bool") if rhs_type != "bool" else inputs[1] + + return _op.cast(logical_op(lhs, rhs), lhs_type) + return impl + + # Note: due to attribute conversion constraint # ops in the identity set must be attribute free _identity_list = [ @@ -1738,12 +1765,15 @@ _convert_map = { "_copy" : _rename(_op.copy), "relu" : _rename(_op.nn.relu), "broadcast_add" : _rename(_op.add), + "broadcast_plus" : _rename(_op.add), "broadcast_sub" : _rename(_op.subtract), + "broadcast_minus" : _rename(_op.subtract), "broadcast_mul" : _rename(_op.multiply), "broadcast_div" : _rename(_op.divide), "broadcast_mod" : _rename(_op.mod), "broadcast_maximum" : _rename(_op.maximum), "broadcast_minimum" : _rename(_op.minimum), + "broadcast_power" : _rename(_op.power), "arctan" : _rename(_op.atan), "broadcast_equal" : _mx_compare(_op.equal, _rename), "broadcast_not_equal" : _mx_compare(_op.not_equal, _rename), @@ -1751,6 +1781,11 @@ _convert_map = { "broadcast_greater_equal": _mx_compare(_op.greater_equal, _rename), "broadcast_lesser" : _mx_compare(_op.less, _rename), "broadcast_lesser_equal" : _mx_compare(_op.less_equal, _rename), + "broadcast_logical_or" : _mx_broadcast_logical(_op.logical_or), + "broadcast_logical_and" : _mx_broadcast_logical(_op.logical_and), + "broadcast_logical_xor" : _mx_broadcast_logical(_op.logical_xor), + "broadcast_to" : _mx_broadcast_to, + "logical_not" : _mx_logical_not, "_equal" : _mx_compare(_op.equal, _rename), "_not_equal" : _mx_compare(_op.not_equal, _rename), "_greater" : _mx_compare(_op.greater, _rename), @@ -1860,6 +1895,7 @@ _convert_map = { "reverse" : _mx_reverse, "squeeze" : _mx_squeeze, "broadcast_axis": _mx_broadcast_axis, + "broadcast_axes": _mx_broadcast_axis, "BlockGrad" : _mx_BlockGrad, "shape_array" : _mx_shape_array, "Embedding" : _mx_embedding, @@ -1897,7 +1933,6 @@ _convert_map = { # List of missing operators that are present in NNVMv1 # TODO(tvm-tvm): support all operators. # - # "broadcast_to", # "contrib_fifo_buffer": _mx_contrib_fifo_buffer, "ring_buffer": _mx_contrib_fifo_buffer, # Qnn ops diff --git a/tests/python/frontend/mxnet/test_forward.py b/tests/python/frontend/mxnet/test_forward.py index 5e4c137..84c8acf 100644 --- a/tests/python/frontend/mxnet/test_forward.py +++ b/tests/python/frontend/mxnet/test_forward.py @@ -301,11 +301,25 @@ def _mx_symbol(F, op_name, inputs): return op(*inputs) def test_forward_broadcast_ops(): - for op in ["broadcast_add", "broadcast_sub", "broadcast_mul", - "broadcast_div", "broadcast_mod", "broadcast_maximum", - "broadcast_minimum", "broadcast_equal", "broadcast_not_equal", - "broadcast_greater", "broadcast_greater_equal", - "broadcast_lesser", "broadcast_lesser_equal"]: + for op in ["broadcast_add", + "broadcast_plus", + "broadcast_sub", + "broadcast_minus", + "broadcast_mul", + "broadcast_div", + "broadcast_mod", + "broadcast_maximum", + "broadcast_minimum", + "broadcast_equal", + "broadcast_not_equal", + "broadcast_greater", + "broadcast_greater_equal", + "broadcast_lesser", + "broadcast_lesser_equal", + "broadcast_power", + "broadcast_logical_or", + "broadcast_logical_and", + "broadcast_logical_xor"]: a_shape = (3, 4, 5) b_shape = (4, 5) if op == "broadcast_mod": @@ -462,16 +476,51 @@ def test_forward_squeeze(): def test_forward_broadcast_axis(): def verify(shape, axis, size): x_np = np.random.uniform(size=shape).astype("float32") - ref_res = mx.nd.broadcast_axis(mx.nd.array(x_np), axis=axis, size=size) - mx_sym = mx.sym.broadcast_axis(mx.sym.var("x"), axis=axis, size=size) - mod, _ = relay.frontend.from_mxnet(mx_sym, {"x": shape}) + for op in ["broadcast_axis", + "broadcast_axes"]: + mx_sym = _mx_symbol(mx.sym, op, [mx.sym.var('x'),axis,size]) + ref_res = _mx_symbol(mx.nd, op, [mx.nd.array(x_np),axis,size]) + mod, _ = relay.frontend.from_mxnet(mx_sym, {"x": shape}) + for target, ctx in ctx_list(): + for kind in ["graph", "debug"]: + intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target) + op_res = intrp.evaluate()(x_np) + tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy()) + + verify((1, 2, 1), 2, 3) + verify((1, 2, 1), (0, 2), (2, 3)) + + +def test_forward_broadcast_to(): + def verify(input_shape, shape): + x_np = np.random.uniform(size=input_shape).astype("float32") + ref_res = mx.nd.broadcast_to(mx.nd.array(x_np), shape=shape) + mx_sym = mx.sym.broadcast_to(mx.sym.var("x"), shape=shape) + mod, _ = relay.frontend.from_mxnet(mx_sym, {"x": input_shape}) for target, ctx in ctx_list(): for kind in ["graph", "debug"]: intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target) op_res = intrp.evaluate()(x_np) tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy()) - verify((1, 2, 1), 2, 3) - verify((1, 2, 1), (0, 2), (2, 3)) + + verify((1, 2, 3), (3, 2, 3)) + verify((4, 1, 32, 32), (4, 8, 32, 32)) + + +def test_forward_logical_not(): + a_shape = (3, 4, 5) + dtype = 'float32' + a_np = np.random.uniform(size=a_shape).astype(dtype) + mx_sym = mx.sym.logical_not(mx.sym.var('a')) + ref_res = mx.nd.logical_not(mx.nd.array(a_np)) + shapes = {'a': a_shape} + mod, _ = relay.frontend.from_mxnet(mx_sym, shapes, dtype) + for target, ctx in ctx_list(): + for kind in ["graph", "debug"]: + intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target) + op_res = intrp.evaluate()(a_np) + tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy()) + def test_forward_full(): def verify(val, shape, dtype): @@ -1061,6 +1110,8 @@ if __name__ == '__main__': test_forward_where() test_forward_arange() test_forward_broadcast_ops() + test_forward_broadcast_to() + test_forward_logical_not() test_forward_elemwise_ops() test_forward_scalar_ops() test_forward_slice_like() -- 2.7.4