from tensorflow.python.ops import math_ops
from tensorflow.python.ops import rnn
from tensorflow.python.ops import rnn_cell
+from tensorflow.python.ops import state_ops
from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables as variables_lib
self.saved_state[name] = state
return array_ops.identity(state)
+ @property
+ def batch_size(self):
+ return self._batch_size
+
+ @property
+ def state_size(self):
+ return self._state_size
+
+
+class TestStateSaverWithCounters(TestStateSaver):
+ """Class wrapper around TestStateSaver.
+
+ A dummy class used for testing of static_state_saving_rnn. It helps test if
+ save_state and state functions got called same number of time when we
+ evaluate output of rnn cell and state or either of them separately. It
+ inherits from the TestStateSaver and adds the counters for calls of functions.
+ """
+
+ def __init__(self, batch_size, state_size):
+ super(TestStateSaverWithCounters, self).__init__(batch_size, state_size)
+ self._num_state_calls = variables_lib.Variable(0)
+ self._num_save_state_calls = variables_lib.Variable(0)
+
+ def state(self, name):
+ with ops_lib.control_dependencies(
+ [state_ops.assign_add(self._num_state_calls, 1)]):
+ return super(TestStateSaverWithCounters, self).state(name)
+
+ def save_state(self, name, state):
+ with ops_lib.control_dependencies([state_ops.assign_add(
+ self._num_save_state_calls, 1)]):
+ return super(TestStateSaverWithCounters, self).save_state(name, state)
+
+ @property
+ def num_state_calls(self):
+ return self._num_state_calls
+
+ @property
+ def num_save_state_calls(self):
+ return self._num_save_state_calls
+
class RNNTest(test.TestCase):
self._seed = 23489
np.random.seed(self._seed)
- def _testScope(self, factory, prefix="prefix", use_outer_scope=True):
+ def _factory(self, scope, state_saver):
+ num_units = state_saver.state_size // 2
+ batch_size = state_saver.batch_size
+ input_size = 5
+ max_length = 8
+ initializer = init_ops.random_uniform_initializer(
+ -0.01, 0.01, seed=self._seed)
+ cell = rnn_cell.LSTMCell(
+ num_units,
+ use_peepholes=False,
+ initializer=initializer,
+ state_is_tuple=False)
+ inputs = max_length * [
+ array_ops.zeros(dtype=dtypes.float32, shape=(batch_size, input_size))
+ ]
+ out, state = rnn.static_state_saving_rnn(
+ cell,
+ inputs,
+ state_saver=state_saver,
+ state_name="save_lstm",
+ scope=scope)
+ return out, state, state_saver
+
+ def _testScope(self, prefix="prefix", use_outer_scope=True):
+ num_units = 3
+ batch_size = 2
+ state_saver = TestStateSaver(batch_size, 2 * num_units)
+
with self.test_session(use_gpu=True, graph=ops_lib.Graph()):
if use_outer_scope:
with variable_scope.variable_scope(prefix) as scope:
- factory(scope)
+ self._factory(scope=scope, state_saver=state_saver)
else:
- factory(prefix)
+ self._factory(scope=prefix, state_saver=state_saver)
variables_lib.global_variables_initializer()
# check that all the variables names starts
self.assertEqual(len(scope_vars), len(all_vars))
def testStateSaverRNNScope(self):
- num_units = 3
- input_size = 5
- batch_size = 2
- max_length = 8
+ self._testScope(use_outer_scope=True)
+ self._testScope(use_outer_scope=False)
+ self._testScope(prefix=None, use_outer_scope=False)
- def factory(scope):
- initializer = init_ops.random_uniform_initializer(
- -0.01, 0.01, seed=self._seed)
- state_saver = TestStateSaver(batch_size, 2 * num_units)
- cell = rnn_cell.LSTMCell(
- num_units,
- use_peepholes=False,
- initializer=initializer,
- state_is_tuple=False)
- inputs = max_length * [
- array_ops.placeholder(dtypes.float32, shape=(batch_size, input_size))
- ]
- return rnn.static_state_saving_rnn(
- cell,
- inputs,
- state_saver=state_saver,
- state_name="save_lstm",
- scope=scope)
+ def testStateSaverCallsSaveState(self):
+ """Test that number of calls to state and save_state is equal.
- self._testScope(factory, use_outer_scope=True)
- self._testScope(factory, use_outer_scope=False)
- self._testScope(factory, prefix=None, use_outer_scope=False)
+ Test if the order of actual evaluating or skipping evaluation of out,
+ state tensors, which are the output tensors from static_state_saving_rnn,
+ have influence on number of calls to save_state and state methods of
+ state_saver object (the number of calls should be same.)
+ """
+ num_units = 3
+ batch_size = 2
+ state_saver = TestStateSaverWithCounters(batch_size, 2 * num_units)
+ out, state, state_saver = self._factory(scope=None, state_saver=state_saver)
+
+ with self.test_session() as sess:
+ sess.run(variables_lib.global_variables_initializer())
+ sess.run(variables_lib.local_variables_initializer())
+
+ _, _, num_state_calls, num_save_state_calls = sess.run([
+ out,
+ state,
+ state_saver.num_state_calls,
+ state_saver.num_save_state_calls])
+ self.assertEqual(num_state_calls, num_save_state_calls)
+
+ _, num_state_calls, num_save_state_calls = sess.run([
+ out,
+ state_saver.num_state_calls,
+ state_saver.num_save_state_calls])
+ self.assertEqual(num_state_calls, num_save_state_calls)
+
+ _, num_state_calls, num_save_state_calls = sess.run([
+ state,
+ state_saver.num_state_calls,
+ state_saver.num_save_state_calls])
+ self.assertEqual(num_state_calls, num_save_state_calls)
class GRUTest(test.TestCase):