[Relay][Frontend][TF] Fix Size operator (#4175)
authorJon Soifer <soiferj@gmail.com>
Tue, 22 Oct 2019 23:37:53 +0000 (16:37 -0700)
committerZhi <5145158+zhiics@users.noreply.github.com>
Tue, 22 Oct 2019 23:37:53 +0000 (16:37 -0700)
* [Relay][Frontend][TF] Fix Size operator

* Uncomment tests

python/tvm/relay/frontend/common.py
python/tvm/relay/frontend/tensorflow.py
tests/python/frontend/tensorflow/test_forward.py

index 637e1f0..d4b9162 100644 (file)
@@ -259,7 +259,7 @@ def get_relay_op(op_name):
             op = None
     else:
         # try search op in various modules
-        for candidate in (_op, _op.nn, _op.image, _op.vision):
+        for candidate in (_op, _op.nn, _op.image, _op.vision, _op.contrib):
             op = getattr(candidate, op_name, None)
             if op is not None:
                 break
index eb67cf2..95e0008 100644 (file)
@@ -1305,6 +1305,13 @@ def _squared_difference():
         return _op.multiply(difference, difference)
     return _impl
 
+def _size():
+    def _impl(inputs, attr, params):
+        new_attr = attr
+        new_attr['out_type'] = attr['out_type'].name
+        return AttrCvt('ndarray_size', transforms={'out_type' : 'dtype'})(inputs, new_attr)
+    return _impl
+
 # compatible operators that do NOT require any conversion.
 _identity_list = []
 
@@ -1410,7 +1417,7 @@ _convert_map = {
     'Shape'                             : _shape(),
     'Sigmoid'                           : AttrCvt('sigmoid'),
     'Sign'                              : AttrCvt('sign'),
-    'Size'                              : AttrCvt('ndarray_size'),
+    'Size'                              : _size(),
     'Slice'                             : _slice(),
     'Softmax'                           : _softmax(),
     'Softplus'                          : _softplus(),
index 420bcb7..11c6a7b 100644 (file)
@@ -2184,15 +2184,18 @@ def test_forward_mean():
 def test_forward_size():
     def check_size(ishape):
         np_input = np.random.uniform(size=ishape).astype(np.float32)
+
+        # if all dimensions are constant, TF will optimize away size operator into constant
+        tf_input_shape = list(np_input.shape)
+        tf_input_shape[0] = None
+
         with tf.Graph().as_default():
-            input = tf.placeholder(shape=np_input.shape, dtype=np_input.dtype, name='input')
+            input = tf.placeholder(shape=tf_input_shape, dtype=np_input.dtype, name='input')
             tf.size(input, name='size')
             compare_tf_with_tvm([np_input], ['input:0'], 'size:0')
 
-    if tf.__version__ < LooseVersion('1.1'):
-        check_size((10, 8, 16, 32))
-        check_size((10,))
-        check_size(())
+    check_size((10, 8, 16, 32))
+    check_size((10,))
 
 #######################################################################
 # All, Max, Min