From 8502691b5b7ca152da9eb626529070db53d479c8 Mon Sep 17 00:00:00 2001 From: maheshambule <15611578+maheshambule@users.noreply.github.com> Date: Tue, 3 Mar 2020 02:27:40 +0530 Subject: [PATCH] [Frontend] [Tensorflow] ReadVariableOp operator support (#4952) * tf frontend read variable op * pylint fix * tf frontend freezed graph pruned ops --- python/tvm/relay/frontend/tensorflow.py | 11 +++++ tests/python/frontend/tensorflow/test_forward.py | 60 ++++++++++++++++++++++++ 2 files changed, 71 insertions(+) diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py index 6f27d73..14d2418 100644 --- a/python/tvm/relay/frontend/tensorflow.py +++ b/python/tvm/relay/frontend/tensorflow.py @@ -1500,6 +1500,12 @@ def _add_n(): # compatible operators that do NOT require any conversion. _identity_list = [] +# Operators that get pruned away when the complete graph is frozen. +# These operators are not needed for inference. +_freezed_graph_pruned_op_list = ['ReadVariableOp', 'ResourceGather', 'Variable', + 'VariableV2', 'VarHandleOp', 'Assign', 'AssignVariableOp'] + + # _convert_map defines maps of name to converter functor(callable) # for 1 to 1 mapping, use Renamer if nothing but name is different # use AttrCvt if attributes need to be converted @@ -2187,6 +2193,11 @@ class GraphProto(object): missing_operators = self._parse_import_prerequisites(graph) if missing_operators: + freezed_ops = [op for op in missing_operators if op in _freezed_graph_pruned_op_list] + if freezed_ops: + raise Exception("Graph is not frozen. Provide a frozen graph. " + "Found operators {}".format(freezed_ops)) + raise NotImplementedError( \ "The following operators are not implemented: {}".format(missing_operators)) diff --git a/tests/python/frontend/tensorflow/test_forward.py b/tests/python/frontend/tensorflow/test_forward.py index 9cd978e..42408b7 100644 --- a/tests/python/frontend/tensorflow/test_forward.py +++ b/tests/python/frontend/tensorflow/test_forward.py @@ -22,6 +22,7 @@ This article is a test script to test tensorflow operator with Relay. """ from __future__ import print_function import numpy as np +import pytest import tensorflow as tf from tensorflow.python.framework import constant_op from tensorflow.python.framework import graph_util @@ -1061,6 +1062,62 @@ def test_forward_variable(): _test_variable(np.random.uniform(size=(32, 100)).astype('float32')) +def test_read_variable_op(): + """ Read Variable op test """ + + tf.reset_default_graph() + data = np.random.uniform(size=(32, 100)).astype('float32') + input_tensor = array_ops.placeholder(shape=data.shape, dtype=data.dtype) + + size = input_tensor.shape.dims[1] + var_data = np.random.uniform(-5, 5, size=[size, size]).astype(np.float32) + input_var = tf.Variable(var_data, name='var1', use_resource=True) + math_ops.matmul(input_tensor, input_var) + + out_name = ['MatMul:0'] + out_node = ['MatMul'] + in_name = ['Placeholder:0'] + in_node = ['Placeholder'] + in_data = [data] + + with tf.Session() as sess: + sess.run(variables.global_variables_initializer()) + + final_graph_def = sess.graph.as_graph_def(add_shapes=True) + tf_output = run_tf_graph(sess, in_data, in_name, out_name) + + shape_dict = {e: i.shape for e, i in zip(in_name, in_data)} + with pytest.raises(Exception) as exexcinfo: + mod, params = relay.frontend.from_tensorflow(final_graph_def, + layout=None, + shape=shape_dict, + outputs=None) + + assert exexcinfo.value.args[0].startswith("Graph is not frozen. Provide a frozen graph.") + + # Now convert the variables to constant and run inference on the converted graph + final_graph_def = tf.graph_util.convert_variables_to_constants( + sess, + sess.graph.as_graph_def(add_shapes=True), + out_node, + ) + + for device in ["llvm", "cuda"]: + ctx = tvm.context(device, 0) + if not ctx.exist: + print("Skip because %s is not enabled" % device) + continue + + tvm_output = run_tvm_graph(final_graph_def, in_data, in_node, + target=device, out_names=out_name, + num_output=len(out_name)) + for i in range(len(tf_output)): + tvm.testing.assert_allclose( + tf_output[i], tvm_output[i], atol=1e-5, rtol=1e-5) + + sess.close() + + ####################################################################### # MatMul, BatchMatMul, BatchMatMulV2 # ---------------------------------- @@ -3038,3 +3095,6 @@ if __name__ == '__main__': test_forward_where() test_forward_matmul() test_forward_batch_matmul() + + # Internal misc. ops + test_read_variable_op() -- 2.7.4