[Relay][Frontend][ONNX] Add support for op Where (#4184)
authorJon Soifer <soiferj@gmail.com>
Sun, 27 Oct 2019 00:05:22 +0000 (17:05 -0700)
committerJared Roesch <roeschinc@gmail.com>
Sun, 27 Oct 2019 00:05:22 +0000 (17:05 -0700)
* Add support for op Where

* Update impl version

python/tvm/relay/frontend/onnx.py
tests/python/frontend/onnx/test_forward.py

index a7f7874..b007b41 100644 (file)
@@ -922,6 +922,13 @@ class Erf(OnnxOpConverter):
     def _impl_v1(cls, inputs, attr, params):
         return _op.erf(inputs[0])
 
+class Where(OnnxOpConverter):
+    """Operator converter for Where
+    """
+    @classmethod
+    def _impl_v9(cls, inputs, attr, params):
+        return _op.where(inputs[0], inputs[1], inputs[2])
+
 
 # compatible operators that do NOT require any conversion.
 _identity_list = []
@@ -1042,7 +1049,8 @@ def _get_convert_map(opset):
         'Not': Not.get_converter(opset),
         'And': And.get_converter(opset),
         'Tile': Tile.get_converter(opset),
-        'Erf': Erf.get_converter(opset)
+        'Erf': Erf.get_converter(opset),
+        'Where': Where.get_converter(opset)
     }
 
 
index 16e7174..3d1262f 100644 (file)
@@ -1299,6 +1299,32 @@ def test_erf():
     z = scipy.special.erf(x)
     verify_erf(x, z)
 
+def verify_where(condition, x, y, dtype, outdata):
+    node = helper.make_node('Where', inputs=['condition', 'x', 'y'], outputs=['out'])
+    graph = helper.make_graph([node],
+                              'where_test',
+                              inputs=[helper.make_tensor_value_info('condition', TensorProto.BOOL, list(condition.shape)),
+                                      helper.make_tensor_value_info('x', dtype, list(x.shape)),
+                                      helper.make_tensor_value_info('y', dtype, list(y.shape))],
+                              outputs=[helper.make_tensor_value_info('out', dtype, list(outdata.shape))])
+    model = helper.make_model(graph, producer_name='where_test')
+
+    for target, ctx in ctx_list():
+        tvm_out = get_tvm_output(model, [condition, x, y], target, ctx, outdata.shape)
+        tvm.testing.assert_allclose(outdata, tvm_out)
+
+def test_where():
+    condition = np.array([[1, 0], [1, 1]], dtype=np.bool)
+    x = np.array([[1, 2], [3, 4]], dtype=np.int64)
+    y = np.array([[9, 8], [7, 6]], dtype=np.int64)
+    outdata = np.where(condition, x, y)
+    verify_where(condition, x, y, TensorProto.INT64, outdata)
+
+    x = np.array([[1, 2], [3, 4]], dtype=np.float32)
+    y = np.array([[9, 8], [7, 6]], dtype=np.float32)
+    outdata = np.where(condition, x, y)
+    verify_where(condition, x, y, TensorProto.FLOAT, outdata)
+
 
 if __name__ == '__main__':
     test_flatten()
@@ -1347,3 +1373,4 @@ if __name__ == '__main__':
     test_and()
     test_tile()
     test_erf()
+    test_where()