From 73dc5ac379a248f1dd6da8847024ed4e97ed7c06 Mon Sep 17 00:00:00 2001 From: Neo Chien Date: Sun, 1 Sep 2019 08:50:36 +0800 Subject: [PATCH] Add not operator for the frontend/onnx.py (#3836) --- python/tvm/relay/frontend/onnx.py | 12 +++++++++++- tests/python/frontend/onnx/test_forward.py | 29 +++++++++++++++++++++++++++++ 2 files changed, 40 insertions(+), 1 deletion(-) diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index 07cda16..4f6dd74 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -868,6 +868,15 @@ class Equal(Elemwise): """ name = 'equal' + +class Not(Elemwise): + """ Operator converter for Not. + """ + @classmethod + def _impl_v1(cls, inputs, attr, params): + return _op.logical_not(inputs[0]) + + # compatible operators that do NOT require any conversion. _identity_list = [] @@ -983,7 +992,8 @@ def _get_convert_map(opset): 'Pad': Pad.get_converter(opset), 'Shape': Shape.get_converter(opset), 'Sign': Sign.get_converter(opset), - 'Equal': Equal.get_converter(opset) + 'Equal': Equal.get_converter(opset), + 'Not': Not.get_converter(opset) } diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index 6173362..e4c161d 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -1130,6 +1130,34 @@ def test_sign(): 'Sign', {}) + +def verify_not(indata, dtype): + x = indata.astype(dtype) + outdata = np.logical_not(x) + + node = helper.make_node('Not', inputs=['in'], outputs=['out'],) + + graph = helper.make_graph([node], + 'not_test', + inputs=[helper.make_tensor_value_info("in", TensorProto.BOOL, list(x.shape))], + outputs=[helper.make_tensor_value_info("out", TensorProto.BOOL, list(outdata.shape))]) + + model = helper.make_model(graph, producer_name='not_test') + + for target, ctx in ctx_list(): + tvm_out = get_tvm_output(model, [x], target, ctx, outdata.shape) + tvm.testing.assert_allclose(outdata, tvm_out) + + +def test_not(): + # 2d + verify_not(indata=(np.random.randn(3, 4) > 0), dtype=bool) + # 3d + verify_not(indata=(np.random.randn(3, 4, 5) > 0), dtype=bool) + # 4d + verify_not(indata=(np.random.randn(3, 4, 5, 6) > 0), dtype=bool) + + if __name__ == '__main__': test_flatten() test_reshape() @@ -1173,3 +1201,4 @@ if __name__ == '__main__': test_inception() test_densenet() test_sign() + test_not() -- 2.7.4