Add support and testing for tf.assert (as no-op) and tf.no_op to TF Relay frontend...
authorBjarke Hammersholt Roune <bjarke.roune@gmail.com>
Wed, 23 Oct 2019 20:53:03 +0000 (13:53 -0700)
committerZhi <5145158+zhiics@users.noreply.github.com>
Wed, 23 Oct 2019 20:53:03 +0000 (13:53 -0700)
python/tvm/relay/frontend/tensorflow.py
tests/python/frontend/tensorflow/test_debugging.py [new file with mode: 0644]
tests/python/frontend/tensorflow/test_no_op.py [new file with mode: 0644]

index 95e0008..bfa3431 100644 (file)
@@ -436,6 +436,24 @@ def _check_numerics():
         return AttrCvt(op_name="copy", ignores=['message'])(inputs, attr)
     return _impl
 
+def _assert():
+    # ToDo: In general people want asserts to be gone from TensorFlow graphs
+    # when they are optimizing them, so converting it to a no-op is
+    # reasonable. However, it would be nice to have the option to keep them
+    # once Relay gets a Halt or Assert op.
+    return _no_op()
+
+def _no_op():
+    def _impl(inputs, attr, params):
+        # ToDo: This should really be an op that returns nothing, which could
+        # be represented as an empty tuple. It turns out that TVM
+        # infrastructure doesn't like running functions that return None and
+        # also don't like running functions that return an empty tuple. So it
+        # doesn't work, but it should be made to work and then this could be
+        # improved. In the mean time, it is hard to imagine a case where it
+        # matters in any real way that a no-op is converted to a constant 0.
+        return tvm.relay.const(0)
+    return _impl
 
 def _matmul():
     def _impl(inputs, attr, params):
@@ -1326,6 +1344,7 @@ _convert_map = {
     'All'                               : _reduce('all'),
     'ArgMax'                            : _argx(_op.argmax, 'argmax'),
     'ArgMin'                            : _argx(_op.argmin, 'argmin'),
+    'Assert'                            : _assert(),
     'AvgPool'                           : _pooling('avg_pool'),
     'BatchMatMul'                       : _batch_matmul(),
     'BatchMatMulV2'                     : _batch_matmul(),
@@ -1384,6 +1403,7 @@ _convert_map = {
     'Mod'                               : _elemwise('mod'),
     'Mul'                               : _elemwise('multiply'),
     'Neg'                               : AttrCvt('negative'),
+    'NoOp'                              : _no_op(),
     'NotEqual'                          : _broadcast('not_equal'),
     'OneHot'                            : _one_hot(),
     'Pack'                              : _pack(),
@@ -2196,8 +2216,11 @@ class GraphProto(object):
             if np_array.dtype == np.dtype(object):
                 # Object types are generally tensorflow DT_STRING (DecodeJpeg op).
                 # Just leave it as placeholder.
-                self._nodes[name] = [_expr.var(name, shape=shape[name], dtype='uint8')]
-
+                if shape:
+                    var_shape = shape[name]
+                else:
+                    var_shape = tensor_util.TensorShapeProtoToList(value.tensor.tensor_shape)
+                self._nodes[name] = [_expr.var(name, shape=var_shape, dtype='uint8')]
                 return
 
             array_ndim = len(np_array.shape)
diff --git a/tests/python/frontend/tensorflow/test_debugging.py b/tests/python/frontend/tensorflow/test_debugging.py
new file mode 100644 (file)
index 0000000..c7da636
--- /dev/null
@@ -0,0 +1,93 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Unit tests for converting TensorFlow debugging ops to Relay."""
+import tensorflow as tf
+import numpy as np
+from tvm import relay
+from tvm.relay.frontend.tensorflow import from_tensorflow
+
+def run_relay(graph, *vars):
+    mod, params = from_tensorflow(graph.as_graph_def(add_shapes=True))
+    ex = relay.create_executor('debug', mod=mod)
+    return ex.evaluate()(*vars)
+
+def test_assert_true():
+    g = tf.Graph()
+    with g.as_default():
+        x = tf.placeholder(tf.float32, shape=())
+        assert_op = tf.Assert(tf.less_equal(x, x), ["it failed"])
+
+        with tf.Session() as sess:
+            x_value = np.random.rand()
+            assert sess.run(assert_op, feed_dict={x: x_value}) is None
+
+        # In TVM, tf.assert is converted to a no-op which is actually a 0,
+        # though it should probably be none or an empty tuple.
+        #
+        # ToDo: It appears that the frontend converter gets confused here and
+        # entirely eliminates all operands from main(). Likely because x <= x
+        # is always true, so the placeholder can be eliminated. But TF doesn't
+        # do that, it's happening in Relay, and that optimization shouldn't
+        # affect the arity of the main function. We should have to pass in
+        # x_value here.
+        np.testing.assert_allclose(0, run_relay(g).asnumpy())
+
+def test_assert_true_var_capture():
+    g = tf.Graph()
+    with g.as_default():
+        x = tf.placeholder(tf.float32, shape=())
+
+        # It turns out that tf.assert() creates a large and complex subgraph if
+        # you capture a variable as part of the error message. So we need to
+        # test that, too.
+        assert_op = tf.Assert(tf.less_equal(x, x), ["it failed", x])
+
+        with tf.Session() as sess:
+            x_value = np.random.rand()
+            assert sess.run(assert_op, feed_dict={x: x_value}) is None
+
+        # ToDo: The frontend converter gets confused here as well, thinking
+        # that it needs to be told what x is twice. It also notes the output of
+        # the graph as a boolean, which is not correct - as you can see above,
+        # TF believes that the value of this graph is None. In addition, the
+        # arity of the translated function should be 1, not 2.
+        np.testing.assert_allclose(True, run_relay(g, x_value, x_value).asnumpy())
+
+def test_assert_false():
+    g = tf.Graph()
+    with g.as_default():
+        assert_op = tf.Assert(tf.constant(False), ["it failed"])
+
+        with tf.Session() as sess:
+            try:
+                print(sess.run(assert_op))
+                assert False  # TF should have thrown an exception
+            except tf.errors.InvalidArgumentError as e:
+                assert "it failed" in e.message
+
+        # In TVM, tf.assert is converted to a no-op which is actually a 0,
+        # though it should probably be none or an empty tuple. For the same
+        # reason, there should not be an error here, even though the assertion
+        # argument is false.
+        np.testing.assert_allclose(0, run_relay(g).asnumpy())
+
+        
+if __name__ == "__main__":
+    test_assert_true()
+    test_assert_true_var_capture()
+    test_assert_false()
+    
diff --git a/tests/python/frontend/tensorflow/test_no_op.py b/tests/python/frontend/tensorflow/test_no_op.py
new file mode 100644 (file)
index 0000000..0d09cf4
--- /dev/null
@@ -0,0 +1,43 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Unit tests for converting TensorFlow debugging ops to Relay."""
+import tensorflow as tf
+import numpy as np
+from tvm import relay
+from tvm.relay.frontend.tensorflow import from_tensorflow
+
+def run_relay(graph):
+    mod, params = from_tensorflow(graph.as_graph_def(add_shapes=True))
+    ex = relay.create_executor('debug', mod=mod)
+    return ex.evaluate()(**params)
+
+def test_no_op():
+    g = tf.Graph()
+    with g.as_default():
+        no_op = tf.no_op()
+        with tf.Session() as sess:
+            # In TF, the type of a no-op is None.
+            assert sess.run(no_op) is None
+
+        # In TVM, no-op is currently translated to 0, though it should
+        # probably be none or an empty tuple.
+        np.testing.assert_allclose(0, run_relay(g).asnumpy())
+
+
+if __name__ == "__main__":
+    test_no_op()
+