[Frontend] [Tensorflow] ReadVariableOp operator support (#4952)
author    maheshambule <15611578+maheshambule@users.noreply.github.com>
Mon, 2 Mar 2020 20:57:40 +0000 (02:27 +0530)
committer    GitHub <noreply@github.com>
Mon, 2 Mar 2020 20:57:40 +0000 (12:57 -0800)
* tf frontend read variable op

* pylint fix

* tf frontend frozen graph pruned ops

python/tvm/relay/frontend/tensorflow.py
tests/python/frontend/tensorflow/test_forward.py

diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py
index 6f27d73..14d2418 100644
--- a/python/tvm/relay/frontend/tensorflow.py
+++ b/python/tvm/relay/frontend/tensorflow.py
@@ -1500,6 +1500,12 @@ def _add_n():
 # compatible operators that do NOT require any conversion.
 _identity_list = []
 
+# Operators that get pruned away when the complete graph is frozen.
+# These operators are not needed for inference.
+_freezed_graph_pruned_op_list = ['ReadVariableOp', 'ResourceGather', 'Variable',
+                                 'VariableV2', 'VarHandleOp', 'Assign', 'AssignVariableOp']
+
+
 # _convert_map defines maps of name to converter functor(callable)
 # for 1 to 1 mapping, use Renamer if nothing but name is different
 # use AttrCvt if attributes need to be converted
@@ -2187,6 +2193,11 @@ class GraphProto(object):
         missing_operators = self._parse_import_prerequisites(graph)
 
         if missing_operators:
+            freezed_ops = [op for op in missing_operators if op in _freezed_graph_pruned_op_list]
+            if freezed_ops:
+                raise Exception("Graph is not frozen. Provide a frozen graph. "
+                                "Found operators {}".format(freezed_ops))
+
             raise NotImplementedError( \
                 "The following operators are not implemented: {}".format(missing_operators))
 
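For context, here is a minimal sketch, not part of this commit, of the behavior the new check gives users. It assumes the TF 1.x APIs already used by the test below; the graph and the names inp/w/out are illustrative. Importing the unfrozen graph now fails fast with the new message, and freezing it with convert_variables_to_constants prunes the listed ops so the import succeeds:

    import numpy as np
    import tensorflow as tf
    from tvm import relay

    data = np.ones((1, 4), dtype='float32')
    inp = tf.placeholder(shape=data.shape, dtype=data.dtype, name='inp')
    # A resource variable produces VarHandleOp/ReadVariableOp nodes in the graph.
    var = tf.Variable(np.ones((4, 4), dtype='float32'), name='w', use_resource=True)
    tf.matmul(inp, var, name='out')

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        unfrozen = sess.graph.as_graph_def(add_shapes=True)
        try:
            relay.frontend.from_tensorflow(unfrozen, shape={'inp:0': data.shape})
        except Exception as err:
            print(err)  # Graph is not frozen. Provide a frozen graph. Found operators [...]
        # Freezing folds the variable into a constant and removes the pruned ops.
        frozen = tf.graph_util.convert_variables_to_constants(sess, unfrozen, ['out'])
        mod, params = relay.frontend.from_tensorflow(frozen, shape={'inp:0': data.shape})
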
diff --git a/tests/python/frontend/tensorflow/test_forward.py b/tests/python/frontend/tensorflow/test_forward.py
index 9cd978e..42408b7 100644
--- a/tests/python/frontend/tensorflow/test_forward.py
+++ b/tests/python/frontend/tensorflow/test_forward.py
@@ -22,6 +22,7 @@ This article is a test script to test tensorflow operator with Relay.
 """
 from __future__ import print_function
 import numpy as np
+import pytest
 import tensorflow as tf
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import graph_util
@@ -1061,6 +1062,60 @@ def test_forward_variable():
     _test_variable(np.random.uniform(size=(32, 100)).astype('float32'))
 
 
+def test_read_variable_op():
+    """ Read Variable op test """
+
+    tf.reset_default_graph()
+    data = np.random.uniform(size=(32, 100)).astype('float32')
+    input_tensor = array_ops.placeholder(shape=data.shape, dtype=data.dtype)
+
+    size = input_tensor.shape.dims[1]
+    var_data = np.random.uniform(-5, 5, size=[size, size]).astype(np.float32)
+    input_var = tf.Variable(var_data, name='var1', use_resource=True)
+    math_ops.matmul(input_tensor, input_var)
+
+    out_name = ['MatMul:0']
+    out_node = ['MatMul']
+    in_name = ['Placeholder:0']
+    in_node = ['Placeholder']
+    in_data = [data]
+
+    with tf.Session() as sess:
+        sess.run(variables.global_variables_initializer())
+
+        final_graph_def = sess.graph.as_graph_def(add_shapes=True)
+        tf_output = run_tf_graph(sess, in_data, in_name, out_name)
+
+        shape_dict = {e: i.shape for e, i in zip(in_name, in_data)}
+        with pytest.raises(Exception) as excinfo:
+            mod, params = relay.frontend.from_tensorflow(final_graph_def,
+                                                         layout=None,
+                                                         shape=shape_dict,
+                                                         outputs=None)
+
+        assert excinfo.value.args[0].startswith("Graph is not frozen. Provide a frozen graph.")
+
+        # Now convert the variables to constants and run inference on the frozen graph
+        final_graph_def = tf.graph_util.convert_variables_to_constants(
+            sess,
+            sess.graph.as_graph_def(add_shapes=True),
+            out_node,
+        )
+
+        for device in ["llvm", "cuda"]:
+            ctx = tvm.context(device, 0)
+            if not ctx.exist:
+                print("Skip because %s is not enabled" % device)
+                continue
+
+            tvm_output = run_tvm_graph(final_graph_def, in_data, in_node,
+                                       target=device, out_names=out_name,
+                                       num_output=len(out_name))
+            for i in range(len(tf_output)):
+                tvm.testing.assert_allclose(
+                    tf_output[i], tvm_output[i], atol=1e-5, rtol=1e-5)
+
+
 #######################################################################
 # MatMul, BatchMatMul, BatchMatMulV2
 # ----------------------------------
@@ -3038,3 +3095,6 @@ if __name__ == '__main__':
     test_forward_where()
     test_forward_matmul()
     test_forward_batch_matmul()
+
+    # Internal misc. ops
+    test_read_variable_op()
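
For reference, a rough sketch of the explicit Relay flow that the test helper run_tvm_graph wraps; this assumes the pre-0.7 graph_runtime API of the TVM tree this commit targets, and reuses frozen and data from the sketch after the first file's diff:

    import tvm
    from tvm import relay
    from tvm.contrib import graph_runtime

    mod, params = relay.frontend.from_tensorflow(frozen, shape={'inp:0': data.shape})
    with relay.build_config(opt_level=3):
        # relay.build returns the serialized graph, the compiled library, and params.
        graph, lib, params = relay.build(mod, target='llvm', params=params)

    m = graph_runtime.create(graph, lib, tvm.cpu(0))
    m.set_input('inp', tvm.nd.array(data))
    m.set_input(**params)
    m.run()
    print(m.get_output(0).asnumpy())  # matches the TF session output within 1e-5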