From ec0ef29835563b762ec9443a3c194c5c904fd6be Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 9 May 2018 13:55:20 -0700 Subject: [PATCH] When using static_state_saving_rnn(..) in the following manner _, state = tf.nn.static_state_saving_rnn(..) the runtime will be blocked after some time, because the save_state method of the state_saver object won't be executed as a part of the graph (that part depends only on output node in the current implementation). Now it should depend on state as well, so the above implementation won't be blocked. PiperOrigin-RevId: 196024050 --- .../rnn/python/kernel_tests/core_rnn_test.py | 137 ++++++++++++++++----- tensorflow/python/ops/rnn.py | 7 ++ 2 files changed, 116 insertions(+), 28 deletions(-) diff --git a/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_test.py b/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_test.py index ba4933d..c75593e 100644 --- a/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_test.py +++ b/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_test.py @@ -38,6 +38,7 @@ from tensorflow.python.ops import init_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import rnn from tensorflow.python.ops import rnn_cell +from tensorflow.python.ops import state_ops from tensorflow.python.ops import tensor_array_ops from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables as variables_lib @@ -142,6 +143,47 @@ class TestStateSaver(object): self.saved_state[name] = state return array_ops.identity(state) + @property + def batch_size(self): + return self._batch_size + + @property + def state_size(self): + return self._state_size + + +class TestStateSaverWithCounters(TestStateSaver): + """Class wrapper around TestStateSaver. + + A dummy class used for testing of static_state_saving_rnn. It helps test if + save_state and state functions got called same number of time when we + evaluate output of rnn cell and state or either of them separately. It + inherits from the TestStateSaver and adds the counters for calls of functions. + """ + + def __init__(self, batch_size, state_size): + super(TestStateSaverWithCounters, self).__init__(batch_size, state_size) + self._num_state_calls = variables_lib.Variable(0) + self._num_save_state_calls = variables_lib.Variable(0) + + def state(self, name): + with ops_lib.control_dependencies( + [state_ops.assign_add(self._num_state_calls, 1)]): + return super(TestStateSaverWithCounters, self).state(name) + + def save_state(self, name, state): + with ops_lib.control_dependencies([state_ops.assign_add( + self._num_save_state_calls, 1)]): + return super(TestStateSaverWithCounters, self).save_state(name, state) + + @property + def num_state_calls(self): + return self._num_state_calls + + @property + def num_save_state_calls(self): + return self._num_save_state_calls + class RNNTest(test.TestCase): @@ -1792,13 +1834,40 @@ class StateSaverRNNTest(test.TestCase): self._seed = 23489 np.random.seed(self._seed) - def _testScope(self, factory, prefix="prefix", use_outer_scope=True): + def _factory(self, scope, state_saver): + num_units = state_saver.state_size // 2 + batch_size = state_saver.batch_size + input_size = 5 + max_length = 8 + initializer = init_ops.random_uniform_initializer( + -0.01, 0.01, seed=self._seed) + cell = rnn_cell.LSTMCell( + num_units, + use_peepholes=False, + initializer=initializer, + state_is_tuple=False) + inputs = max_length * [ + array_ops.zeros(dtype=dtypes.float32, shape=(batch_size, input_size)) + ] + out, state = rnn.static_state_saving_rnn( + cell, + inputs, + state_saver=state_saver, + state_name="save_lstm", + scope=scope) + return out, state, state_saver + + def _testScope(self, prefix="prefix", use_outer_scope=True): + num_units = 3 + batch_size = 2 + state_saver = TestStateSaver(batch_size, 2 * num_units) + with self.test_session(use_gpu=True, graph=ops_lib.Graph()): if use_outer_scope: with variable_scope.variable_scope(prefix) as scope: - factory(scope) + self._factory(scope=scope, state_saver=state_saver) else: - factory(prefix) + self._factory(scope=prefix, state_saver=state_saver) variables_lib.global_variables_initializer() # check that all the variables names starts @@ -1813,34 +1882,46 @@ class StateSaverRNNTest(test.TestCase): self.assertEqual(len(scope_vars), len(all_vars)) def testStateSaverRNNScope(self): - num_units = 3 - input_size = 5 - batch_size = 2 - max_length = 8 + self._testScope(use_outer_scope=True) + self._testScope(use_outer_scope=False) + self._testScope(prefix=None, use_outer_scope=False) - def factory(scope): - initializer = init_ops.random_uniform_initializer( - -0.01, 0.01, seed=self._seed) - state_saver = TestStateSaver(batch_size, 2 * num_units) - cell = rnn_cell.LSTMCell( - num_units, - use_peepholes=False, - initializer=initializer, - state_is_tuple=False) - inputs = max_length * [ - array_ops.placeholder(dtypes.float32, shape=(batch_size, input_size)) - ] - return rnn.static_state_saving_rnn( - cell, - inputs, - state_saver=state_saver, - state_name="save_lstm", - scope=scope) + def testStateSaverCallsSaveState(self): + """Test that number of calls to state and save_state is equal. - self._testScope(factory, use_outer_scope=True) - self._testScope(factory, use_outer_scope=False) - self._testScope(factory, prefix=None, use_outer_scope=False) + Test if the order of actual evaluating or skipping evaluation of out, + state tensors, which are the output tensors from static_state_saving_rnn, + have influence on number of calls to save_state and state methods of + state_saver object (the number of calls should be same.) + """ + num_units = 3 + batch_size = 2 + state_saver = TestStateSaverWithCounters(batch_size, 2 * num_units) + out, state, state_saver = self._factory(scope=None, state_saver=state_saver) + + with self.test_session() as sess: + sess.run(variables_lib.global_variables_initializer()) + sess.run(variables_lib.local_variables_initializer()) + + _, _, num_state_calls, num_save_state_calls = sess.run([ + out, + state, + state_saver.num_state_calls, + state_saver.num_save_state_calls]) + self.assertEqual(num_state_calls, num_save_state_calls) + + _, num_state_calls, num_save_state_calls = sess.run([ + out, + state_saver.num_state_calls, + state_saver.num_save_state_calls]) + self.assertEqual(num_state_calls, num_save_state_calls) + + _, num_state_calls, num_save_state_calls = sess.run([ + state, + state_saver.num_state_calls, + state_saver.num_save_state_calls]) + self.assertEqual(num_state_calls, num_save_state_calls) class GRUTest(test.TestCase): diff --git a/tensorflow/python/ops/rnn.py b/tensorflow/python/ops/rnn.py index e94ad90..c77a18d 100644 --- a/tensorflow/python/ops/rnn.py +++ b/tensorflow/python/ops/rnn.py @@ -1401,6 +1401,13 @@ def static_state_saving_rnn(cell, outputs[-1] = nest.pack_sequence_as( structure=last_output, flat_sequence=flat_last_output) + if state_is_tuple: + state = nest.pack_sequence_as( + structure=state, + flat_sequence=[array_ops.identity(s) for s in flat_state]) + else: + state = array_ops.identity(state) + return (outputs, state) -- 2.7.4