ONNX frontend operator support: And (#3878)
authorNeo Chien <cchung100m@cs.ccu.edu.tw>
Tue, 3 Sep 2019 04:02:52 +0000 (12:02 +0800)
committerZhi <5145158+zhiics@users.noreply.github.com>
Tue, 3 Sep 2019 04:02:52 +0000 (21:02 -0700)
python/tvm/relay/frontend/onnx.py
tests/python/frontend/onnx/test_forward.py

index 4f6dd74..b7fe2cf 100644 (file)
@@ -877,6 +877,14 @@ class Not(Elemwise):
         return _op.logical_not(inputs[0])
 
 
+class And(Elemwise):
+    """ Operator converter for And.
+    """
+    @classmethod
+    def _impl_v1(cls, inputs, attr, params):
+        return _op.logical_and(inputs[0], inputs[1])
+
+
 # compatible operators that do NOT require any conversion.
 _identity_list = []
 
@@ -993,7 +1001,8 @@ def _get_convert_map(opset):
         'Shape': Shape.get_converter(opset),
         'Sign': Sign.get_converter(opset),
         'Equal': Equal.get_converter(opset),
-        'Not': Not.get_converter(opset)
+        'Not': Not.get_converter(opset),
+        'And': And.get_converter(opset)
     }
 
 
index e4c161d..7e0e11f 100644 (file)
@@ -1158,6 +1158,53 @@ def test_not():
     verify_not(indata=(np.random.randn(3, 4, 5, 6) > 0), dtype=bool)
 
 
+def verify_and(indata, dtype):
+    x = indata[0].astype(dtype)
+    y = indata[1].astype(dtype)
+    outdata = np.logical_and(x, y)
+
+    node = helper.make_node('And', inputs=['in1', 'in2'], outputs=['out'], )
+
+    graph = helper.make_graph([node],
+                              'and_test',
+                              inputs=[helper.make_tensor_value_info("in1", TensorProto.BOOL, list(x.shape)),
+                                      helper.make_tensor_value_info("in2", TensorProto.BOOL, list(y.shape))],
+                              outputs=[helper.make_tensor_value_info("out", TensorProto.BOOL, list(outdata.shape))])
+
+    model = helper.make_model(graph, producer_name='and_test')
+
+    for target, ctx in ctx_list():
+        tvm_out = get_tvm_output(model, [x, y], target, ctx, outdata.shape)
+        tvm.testing.assert_allclose(outdata, tvm_out)
+
+
+def test_and():
+    # 2d
+    x = (np.random.randn(3, 4) > 0)
+    y = (np.random.randn(3, 4) > 0)
+    verify_and(indata=[x, y], dtype=bool)
+
+    # 3d
+    x = (np.random.randn(3, 4, 5) > 0)
+    y = (np.random.randn(3, 4, 5) > 0)
+    verify_and(indata=[x, y], dtype=bool)
+
+    # 4d
+    x = (np.random.randn(3, 4, 5, 6) > 0)
+    y = (np.random.randn(3, 4, 5, 6) > 0)
+    verify_and(indata=[x, y], dtype=bool)
+
+    # 3d vs 1d
+    x = (np.random.randn(3, 4, 5) > 0)
+    y = (np.random.randn(5) > 0)
+    verify_and(indata=[x, y], dtype=bool)
+
+    # 3d vs 2d
+    x = (np.random.randn(3, 4, 5) > 0)
+    y = (np.random.randn(4, 5) > 0)
+    verify_and(indata=[x, y], dtype=bool)
+
+
 if __name__ == '__main__':
     test_flatten()
     test_reshape()
@@ -1202,3 +1249,4 @@ if __name__ == '__main__':
     test_densenet()
     test_sign()
     test_not()
+    test_and()