def _impl_v1(cls, inputs, attr, params):
return _op.erf(inputs[0])
+class Where(OnnxOpConverter):
+ """Operator converter for Where
+ """
+ @classmethod
+ def _impl_v9(cls, inputs, attr, params):
+ return _op.where(inputs[0], inputs[1], inputs[2])
+
# compatible operators that do NOT require any conversion.
_identity_list = []
'Not': Not.get_converter(opset),
'And': And.get_converter(opset),
'Tile': Tile.get_converter(opset),
- 'Erf': Erf.get_converter(opset)
+ 'Erf': Erf.get_converter(opset),
+ 'Where': Where.get_converter(opset)
}
z = scipy.special.erf(x)
verify_erf(x, z)
+def verify_where(condition, x, y, dtype, outdata):
+ node = helper.make_node('Where', inputs=['condition', 'x', 'y'], outputs=['out'])
+ graph = helper.make_graph([node],
+ 'where_test',
+ inputs=[helper.make_tensor_value_info('condition', TensorProto.BOOL, list(condition.shape)),
+ helper.make_tensor_value_info('x', dtype, list(x.shape)),
+ helper.make_tensor_value_info('y', dtype, list(y.shape))],
+ outputs=[helper.make_tensor_value_info('out', dtype, list(outdata.shape))])
+ model = helper.make_model(graph, producer_name='where_test')
+
+ for target, ctx in ctx_list():
+ tvm_out = get_tvm_output(model, [condition, x, y], target, ctx, outdata.shape)
+ tvm.testing.assert_allclose(outdata, tvm_out)
+
+def test_where():
+ condition = np.array([[1, 0], [1, 1]], dtype=np.bool)
+ x = np.array([[1, 2], [3, 4]], dtype=np.int64)
+ y = np.array([[9, 8], [7, 6]], dtype=np.int64)
+ outdata = np.where(condition, x, y)
+ verify_where(condition, x, y, TensorProto.INT64, outdata)
+
+ x = np.array([[1, 2], [3, 4]], dtype=np.float32)
+ y = np.array([[9, 8], [7, 6]], dtype=np.float32)
+ outdata = np.where(condition, x, y)
+ verify_where(condition, x, y, TensorProto.FLOAT, outdata)
+
if __name__ == '__main__':
test_flatten()
test_and()
test_tile()
test_erf()
+ test_where()