From 620bcf01283abc434b1971106863269168cb8a5a Mon Sep 17 00:00:00 2001 From: Allen Lavoie Date: Thu, 17 May 2018 14:32:47 -0700 Subject: [PATCH] Basic usability fixes for RNNCell wrappers They weren't calling their parent constructors (for the Keras base Layer), so a bunch of their methods threw odd errors. There may still be issues, but hopefully not so blatent. Fixes #19208. For real this time. PiperOrigin-RevId: 197052962 --- .../rnn/python/kernel_tests/core_rnn_cell_test.py | 26 ++++++++++++++++++++++ tensorflow/python/ops/rnn_cell_impl.py | 3 +++ 2 files changed, 29 insertions(+) diff --git a/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py b/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py index e512e8d..b8840a8 100644 --- a/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py +++ b/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py @@ -19,6 +19,7 @@ from __future__ import division from __future__ import print_function import functools +import os import numpy as np @@ -30,6 +31,7 @@ from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import random_seed +from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import init_ops from tensorflow.python.ops import math_ops @@ -39,6 +41,7 @@ from tensorflow.python.ops import rnn_cell_impl from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables as variables_lib from tensorflow.python.platform import test +from tensorflow.python.training.checkpointable import util as checkpointable_utils # pylint: enable=protected-access Linear = core_rnn_cell._Linear # pylint: disable=invalid-name @@ -189,6 +192,7 @@ class RNNCellTest(test.TestCase): self.assertEqual(cell.dtype, None) self.assertEqual("cell-0", cell._checkpoint_dependencies[0].name) self.assertEqual("cell-1", cell._checkpoint_dependencies[1].name) + cell.get_config() # Should not throw an error g, out_m = cell(x, m) # Layer infers the input type. self.assertEqual(cell.dtype, dtype.name) @@ -439,6 +443,26 @@ class RNNCellTest(test.TestCase): self.assertTrue( float(np.linalg.norm((res[1][0, :] - res[1][i, :]))) < 1e-6) + @test_util.run_in_graph_and_eager_modes() + def testWrapperCheckpointing(self): + for wrapper_type in [ + rnn_cell_impl.DropoutWrapper, + rnn_cell_impl.ResidualWrapper, + lambda cell: rnn_cell_impl.MultiRNNCell([cell])]: + with self.test_session(): + cell = rnn_cell_impl.BasicRNNCell(1) + wrapper = wrapper_type(cell) + wrapper(array_ops.ones([1, 1]), + state=wrapper.zero_state(batch_size=1, dtype=dtypes.float32)) + self.evaluate([v.initializer for v in cell.variables]) + checkpoint = checkpointable_utils.Checkpoint(wrapper=wrapper) + prefix = os.path.join(self.get_temp_dir(), "ckpt") + self.evaluate(cell._bias.assign([40.])) + save_path = checkpoint.save(prefix) + self.evaluate(cell._bias.assign([0.])) + checkpoint.restore(save_path).assert_consumed().run_restore_ops() + self.assertAllEqual([40.], self.evaluate(cell._bias)) + def testOutputProjectionWrapper(self): with self.test_session() as sess: with variable_scope.variable_scope( @@ -485,6 +509,7 @@ class RNNCellTest(test.TestCase): variable_scope.get_variable_scope().reuse_variables() wrapper_object = rnn_cell_impl.ResidualWrapper(base_cell) (name, dep), = wrapper_object._checkpoint_dependencies + wrapper_object.get_config() # Should not throw an error self.assertIs(dep, base_cell) self.assertEqual("cell", name) @@ -534,6 +559,7 @@ class RNNCellTest(test.TestCase): wrapped = rnn_cell_impl.GRUCell(3) cell = rnn_cell_impl.DeviceWrapper(wrapped, "/cpu:14159") (name, dep), = cell._checkpoint_dependencies + cell.get_config() # Should not throw an error self.assertIs(dep, wrapped) self.assertEqual("cell", name) diff --git a/tensorflow/python/ops/rnn_cell_impl.py b/tensorflow/python/ops/rnn_cell_impl.py index e9a2d2d..05723c6 100644 --- a/tensorflow/python/ops/rnn_cell_impl.py +++ b/tensorflow/python/ops/rnn_cell_impl.py @@ -979,6 +979,7 @@ class DropoutWrapper(RNNCell): but not `callable`. ValueError: if any of the keep_probs are not between 0 and 1. """ + super(DropoutWrapper, self).__init__() assert_like_rnncell("cell", cell) if (dropout_state_filter_visitor is not None @@ -1153,6 +1154,7 @@ class ResidualWrapper(RNNCell): Defaults to calling nest.map_structure on (lambda i, o: i + o), inputs and outputs. """ + super(ResidualWrapper, self).__init__() self._cell = cell if isinstance(cell, checkpointable.CheckpointableBase): self._track_checkpointable(self._cell, name="cell") @@ -1210,6 +1212,7 @@ class DeviceWrapper(RNNCell): cell: An instance of `RNNCell`. device: A device string or function, for passing to `tf.device`. """ + super(DeviceWrapper, self).__init__() self._cell = cell if isinstance(cell, checkpointable.CheckpointableBase): self._track_checkpointable(self._cell, name="cell") -- 2.7.4