[Relay][TensorFlow] Support tf.math.reduce_prod (#3166)
authorlixiaoquan <radioheads@163.com>
Sat, 11 May 2019 04:56:01 +0000 (12:56 +0800)
committerSiva <sivar.b@huawei.com>
Sat, 11 May 2019 04:56:01 +0000 (10:26 +0530)
python/tvm/relay/frontend/tensorflow.py
tests/python/frontend/tensorflow/test_forward.py

index bbbb0a2..48f7883 100644 (file)
@@ -1080,6 +1080,15 @@ def _batch_to_space_nd():
 
     return _impl
 
+
+def _prod():
+    def _impl(inputs, attr, params):
+        axis = params.pop(inputs[1].name_hint).asnumpy()[0]
+        keepdims = attr['keep_dims']
+        return _op.prod(inputs[0], int(axis), keepdims=keepdims)
+    return _impl
+
+
 # compatible operators that do NOT require any conversion.
 _identity_list = []
 
@@ -1136,6 +1145,7 @@ _convert_map = {
     'Pad'                               : _pad('Pad'),
     'PadV2'                             : _pad('PadV2'),
     'Pow'                               : _elemwise('power'),
+    'Prod'                              : _prod(),
     'Range'                             : _range(),
     'Rank'                              : _rank(),
     'RealDiv'                           : _elemwise('div'),
index 1579769..58bbdab 100644 (file)
@@ -161,7 +161,6 @@ def is_gpu_available():
     else:
         return False
 
-
 #######################################################################
 # Pooling
 # -------
@@ -1509,6 +1508,25 @@ def test_forward_expand_dims():
     _test_forward_expand_dims(np.array([[1], [2]]), 1)
     _test_forward_expand_dims(np.array([[1], [2]]), -1)
 
+
+#######################################################################
+# Prod
+# ----
+def _test_forward_reduce_prod(shape, axis, keepdims):
+    inp_array1 = np.random.uniform(-5, 5, size=shape).astype(np.float32)
+    with tf.Graph().as_default():
+        in1 = tf.placeholder(shape=inp_array1.shape, dtype=inp_array1.dtype)
+        out = tf.math.reduce_prod(in1, axis, keepdims)
+        compare_tf_with_tvm(inp_array1, in1.name, out.name)
+
+def test_forward_reduce_prod():
+    _test_forward_reduce_prod((5,), 0, False)
+    _test_forward_reduce_prod((5, 5), 0, False)
+    _test_forward_reduce_prod((5, 5), 1, False)
+    _test_forward_reduce_prod((5,), 0, True)
+    _test_forward_reduce_prod((5, 5), 0, True)
+    _test_forward_reduce_prod((5, 5), 1, True)
+
 #######################################################################
 # Main
 # ----
@@ -1550,6 +1568,7 @@ if __name__ == '__main__':
     test_forward_argminmax()
     test_forward_reduce()
     test_forward_mean()
+    test_forward_reduce_prod()
 
     # General
     test_forward_multi_input()