From 2939af7253339963d0c631e46468bdc26930897a Mon Sep 17 00:00:00 2001
From: Allen Lavoie
Date: Thu, 8 Feb 2018 15:18:31 -0800
Subject: [PATCH] Object-based saving prototype: create ResourceVariables
 directly by default.

This avoids variable reuse errors when building a graph. Where necessary for
compatibility, we can still use get_variable.

PiperOrigin-RevId: 185060891
---
 tensorflow/contrib/eager/python/BUILD             |   6 ++
 tensorflow/contrib/eager/python/checkpointable.py |  62 +++++++++--
 .../contrib/eager/python/checkpointable_test.py   | 113 ++++++++++++++++-----
 3 files changed, 147 insertions(+), 34 deletions(-)

diff --git a/tensorflow/contrib/eager/python/BUILD b/tensorflow/contrib/eager/python/BUILD
index 46406e3..cfb38a1 100644
--- a/tensorflow/contrib/eager/python/BUILD
+++ b/tensorflow/contrib/eager/python/BUILD
@@ -226,9 +226,13 @@ py_library(
     visibility = ["//tensorflow:internal"],
     deps = [
         "//tensorflow/contrib/eager/proto:checkpointable_object_graph_proto_py",
+        "//tensorflow/python:dtypes",
         "//tensorflow/python:framework_ops",
+        "//tensorflow/python:init_ops",
         "//tensorflow/python:io_ops",
+        "//tensorflow/python:resource_variable_ops",
         "//tensorflow/python:state_ops",
+        "//tensorflow/python:tensor_shape",
         "//tensorflow/python:training",
         "//tensorflow/python:variable_scope",
         "//tensorflow/python/eager:context",
@@ -246,7 +250,9 @@ py_test(
         "//tensorflow/python:dtypes",
         "//tensorflow/python:framework_ops",
         "//tensorflow/python:framework_test_lib",
+        "//tensorflow/python:init_ops",
         "//tensorflow/python:layers",
+        "//tensorflow/python:layers_base",
         "//tensorflow/python:resource_variable_ops",
         "//tensorflow/python:state_ops",
         "//tensorflow/python:training",
diff --git a/tensorflow/contrib/eager/python/checkpointable.py b/tensorflow/contrib/eager/python/checkpointable.py
index 09d4054..896b38a 100644
--- a/tensorflow/contrib/eager/python/checkpointable.py
+++ b/tensorflow/contrib/eager/python/checkpointable.py
@@ -23,8 +23,12 @@ import weakref
 
 from tensorflow.contrib.eager.proto import checkpointable_object_graph_pb2
 from tensorflow.python.eager import context
+from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_shape
+from tensorflow.python.ops import init_ops
 from tensorflow.python.ops import io_ops
+from tensorflow.python.ops import resource_variable_ops
 from tensorflow.python.ops import state_ops
 from tensorflow.python.ops import variable_scope
 from tensorflow.python.training import optimizer as optimizer_lib
@@ -75,6 +79,40 @@ def _assign_existing_variable(variable_to_restore, value_pointer):
     value_pointer.session.run(initializer_op)
 
 
+def _default_getter(name, shape, dtype, initializer=None,
+                    partition_info=None, **kwargs):
+  """A pared-down version of get_variable which does not reuse variables."""
+  dtype = dtypes.as_dtype(dtype)
+  shape_object = tensor_shape.as_shape(shape)
+  with ops.init_scope():
+    if initializer is None:
+      initializer, initializing_from_value = (
+          variable_scope._get_default_variable_store()._get_default_initializer(  # pylint: disable=protected-access
+              name=name, shape=shape_object, dtype=dtype))
+    else:
+      initializing_from_value = not callable(initializer)
+    # Same logic as get_variable
+    if initializing_from_value:
+      if shape is not None:
+        raise ValueError("If initializer is a constant, do not specify shape.")
+      initial_value = initializer
+      variable_dtype = None
+    else:
+      # Instantiate initializer if provided initializer is a type object.
+      if isinstance(initializer, type(init_ops.Initializer)):
+        initializer = initializer(dtype=dtype)
+      def initial_value():
+        return initializer(
+            shape_object.as_list(), dtype=dtype, partition_info=partition_info)
+      variable_dtype = dtype.base_dtype
+    return resource_variable_ops.ResourceVariable(
+        initial_value=initial_value,
+        name=name,
+        dtype=variable_dtype,
+        **kwargs
+    )
+
+
 class Checkpointable(object):
   """Manages variables and dependencies on other objects.
 
@@ -117,7 +155,8 @@ class Checkpointable(object):
         and value not in self._already_tracked):
       self.track_checkpointable(value, name=name)
 
-  def add_variable(self, name, shape, dtype=None, initializer=None, **kwargs):
+  def add_variable(self, name, shape=None, dtype=dtypes.float32,
+                   initializer=None, **kwargs):
     """Create a new variable object to be saved with this `Checkpointable`.
 
     If the user has requested that this object or another `Checkpointable` which
@@ -131,14 +170,18 @@ class Checkpointable(object):
       dtype: The data type of the variable.
       initializer: The initializer to use. Ignored if deferred loading has been
         requested.
-      **kwargs: Passed to get_variable.
+      **kwargs: Passed to the ResourceVariable constructor.
 
     Returns:
       The new variable object.
 
     Raises:
       ValueError: If the variable name is not unique.
+      RuntimeError: If __init__ has not been called.
     """
+    if not hasattr(self, "_owned_variables"):
+      raise RuntimeError("Need to call Checkpointable.__init__ before adding "
+                         "variables.")
     if name in self._owned_variables:
       raise ValueError(
           ("A variable named '%s' already exists in this Checkpointable, but "
@@ -151,18 +194,19 @@ class Checkpointable(object):
       # be relatively uncommon in user code.
       getter = kwargs.pop("getter")
     else:
-      getter = variable_scope.get_variable
+      getter = _default_getter
     deferred_restoration = self._deferred_restorations.pop(name, None)
     if deferred_restoration is not None:
      dtype = deferred_restoration.value_pointer.dtype
       base_type = dtype.base_dtype
       # TODO(allenl): Handle partitioned variables here too
-      initializer, = io_ops.restore_v2(
-          prefix=deferred_restoration.value_pointer.save_path,
-          tensor_names=[deferred_restoration.value_pointer.checkpoint_key],
-          shape_and_slices=[""],
-          dtypes=[base_type],
-          name="checkpoint_initializer")
+      with ops.init_scope():
+        initializer, = io_ops.restore_v2(
+            prefix=deferred_restoration.value_pointer.save_path,
+            tensor_names=[deferred_restoration.value_pointer.checkpoint_key],
+            shape_and_slices=[""],
+            dtypes=[base_type],
+            name="checkpoint_initializer")
       # We need to un-set the shape so get_variable doesn't complain, but we
       # also need to set the static shape information on the initializer if
       # possible so we don't get a variable with an unknown shape.
diff --git a/tensorflow/contrib/eager/python/checkpointable_test.py b/tensorflow/contrib/eager/python/checkpointable_test.py
index f0c3df5..f7bc155 100644
--- a/tensorflow/contrib/eager/python/checkpointable_test.py
+++ b/tensorflow/contrib/eager/python/checkpointable_test.py
@@ -31,6 +31,7 @@ from tensorflow.python.framework import ops
 from tensorflow.python.framework import test_util
 from tensorflow.python.layers import base
 from tensorflow.python.layers import core
+from tensorflow.python.ops import init_ops
 from tensorflow.python.ops import resource_variable_ops
 from tensorflow.python.ops import state_ops
 from tensorflow.python.ops import variable_scope
@@ -119,15 +120,7 @@ class NonLayerCheckpointable(checkpointable.Checkpointable):
 
   def __init__(self):
     super(NonLayerCheckpointable, self).__init__()
-    with variable_scope.variable_scope(None, default_name="non_layer"):
-      # Unfortunately using tf.get_variable to implement self.add_variable
-      # (necessary for backwards compatibile naming with Layers) we can still
-      # run into duplicate variable errors (when building a graph, not when
-      # executing eagerly), thus the variable scope.
-      #
-      # TODO(allenl): Consider creating a ResourceVariable directly by
-      # default so that variable reuse isn't an issue.
-      self._a_variable = self.add_variable("a_variable", shape=[])
+    self.a_variable = self.add_variable(name="a_variable", shape=[])
 
 
 class MyNetwork(CheckpointableNetwork):
@@ -158,17 +151,92 @@ class Root(checkpointable.Checkpointable):
   def global_step(self):
     if self._global_step is None:
       # Get the default create_global_step utility to actually call
-      # self.add_variable, by setting a custom getter.
-      def _owned_variable_as_custom_getter(getter, *args, **kwargs):
-        return self.add_variable(*args, getter=getter, **kwargs)
-
-      with variable_scope.variable_scope(
-          "", custom_getter=_owned_variable_as_custom_getter):
+      # self.add_variable, by setting a custom creator.
+      def _owned_variable_as_creator(
+          next_creator, initial_value, **kwargs):
+        def _creator_as_getter(initializer, **kwargs):
+          return next_creator(initial_value=initializer, **kwargs)
+        return self.add_variable(
+            getter=_creator_as_getter, initializer=initial_value, shape=[],
+            **kwargs)
+
+      with variable_scope.variable_creator_scope(
+          _owned_variable_as_creator):
         self._global_step = training_util.create_global_step()
     return self._global_step
 
 
-class CheckpointNamingTests(test.TestCase):
+class InterfaceTests(test.TestCase):
+
+  @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True)
+  def testAddVariable(self):
+    obj = NonLayerCheckpointable()
+    with self.assertRaisesRegexp(ValueError, "do not specify shape"):
+      obj.add_variable(
+          name="shape_specified_twice", shape=[], initializer=1)
+    constant_initializer = obj.add_variable(
+        name="constant_initializer", initializer=1)
+    with variable_scope.variable_scope("some_variable_scope"):
+      ones_initializer = obj.add_variable(
+          name="ones_initializer",
+          shape=[2],
+          initializer=init_ops.ones_initializer(dtype=dtypes.float32))
+    bare_initializer = obj.add_variable(
+        name="bare_initializer",
+        shape=[2, 2],
+        dtype=dtypes.float64,
+        initializer=init_ops.zeros_initializer)
+
+    # Even in graph mode, there are no naming conflicts between objects, only
+    # naming conflicts within an object.
+    other_duplicate = resource_variable_ops.ResourceVariable(
+        name="duplicate", initial_value=1.)
+    duplicate = obj.add_variable(name="duplicate", shape=[])
+    with self.assertRaisesRegexp(ValueError, "'duplicate' already exists"):
+      obj.add_variable(name="duplicate", shape=[])
+
+    if context.in_graph_mode():
+      self.evaluate(variables.global_variables_initializer())
+    self.assertEqual("constant_initializer:0", constant_initializer.name)
+    self.assertEqual(1, self.evaluate(constant_initializer))
+    self.assertEqual("some_variable_scope/ones_initializer:0",
+                     ones_initializer.name)
+    self.assertAllEqual([1, 1], self.evaluate(ones_initializer))
+    self.assertAllEqual([[0., 0.],
+                         [0., 0.]], self.evaluate(bare_initializer))
+    self.assertEqual("a_variable:0", obj.a_variable.name)
+    self.assertEqual("duplicate:0", other_duplicate.name)
+    if context.in_graph_mode():
+      # The .name attribute may be globally influenced, but the checkpoint name
+      # won't be (tested below).
+      self.assertEqual("duplicate_1:0", duplicate.name)
+    else:
+      # When executing eagerly, there's no uniquification of variable names. The
+      # checkpoint name will be the same.
+      self.assertEqual("duplicate:0", duplicate.name)
+    named_variables, _ = checkpointable._serialize_object_graph(obj)
+    expected_checkpoint_names = (
+        "a_variable",
+        "bare_initializer",
+        "constant_initializer",
+        "duplicate",
+        "ones_initializer",
+    )
+    six.assertCountEqual(
+        self, expected_checkpoint_names, named_variables.keys())
+
+  def testInitNotCalled(self):
+
+    class NoInit(checkpointable.Checkpointable):
+
+      def __init__(self):
+        pass
+
+    with self.assertRaisesRegexp(RuntimeError, "__init__"):
+      NoInit().add_variable("var", shape=[])
+
+
+class CheckpointingTests(test.TestCase):
 
   @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True)
   def testNamingWithOptimizer(self):
@@ -391,12 +459,7 @@ class CheckpointNamingTests(test.TestCase):
 
   def _get_checkpoint_name(self, name):
     root = checkpointable.Checkpointable()
-    with variable_scope.variable_scope("get_checkpoint_name"):
-      # Create the variable in a variable scope so that we get more relaxed
-      # naming rules (variables outside a scope may not start with "_", "/" or
-      # "-"). Since we don't use the scope part of the name, these cases are
-      # somewhat annoying.
-      root.add_variable(name=name, shape=[1, 2], dtype=dtypes.float64)
+    root.add_variable(name=name, shape=[1, 2], dtype=dtypes.float64)
     named_variables, _ = checkpointable._serialize_object_graph(root)
     checkpoint_name, = named_variables.keys()
     with ops.name_scope("root/" + checkpoint_name):
@@ -406,9 +469,9 @@ class CheckpointNamingTests(test.TestCase):
   @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True)
   def testVariableNameEscaping(self):
     self.assertEqual(r"a_S__b_S__c", self._get_checkpoint_name(r"a/b/c"))
-    self.assertEqual(r"", self._get_checkpoint_name(r""))
-    self.assertEqual(r"_S__", self._get_checkpoint_name(r"/"))
-    self.assertEqual(r"_S___S_._", self._get_checkpoint_name(r"/_S__"))
+    self.assertEqual(r"b", self._get_checkpoint_name(r"b"))
+    self.assertEqual(r"c_S__", self._get_checkpoint_name(r"c/"))
+    self.assertEqual(r"d_S___S_._", self._get_checkpoint_name(r"d/_S__"))
 
   @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True)
   def testNumberedPath(self):
-- 
2.7.4
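
Note on the motivation: the "variable reuse errors" mentioned in the commit
message only occur when building a graph. A minimal sketch of the difference
(assuming the TF API at this revision; the graph setup and the names "v" and
"w" are illustrative, not part of the patch):

    from tensorflow.python.framework import ops
    from tensorflow.python.ops import resource_variable_ops
    from tensorflow.python.ops import variable_scope

    with ops.Graph().as_default():
      variable_scope.get_variable("v", shape=[])
      # A second get_variable("v", shape=[]) here would raise
      # "ValueError: Variable v already exists, disallowed." because
      # get_variable consults a graph-wide variable store keyed by name.
      # Directly constructed ResourceVariables bypass that store entirely;
      # a duplicate name is uniquified instead of rejected.
      first = resource_variable_ops.ResourceVariable(1., name="w")
      second = resource_variable_ops.ResourceVariable(1., name="w")
      assert first.name == "w:0"
      assert second.name == "w_1:0"

This is the behavior the new testAddVariable asserts for its two "duplicate"
variables: uniquification only affects the .name attribute, while the
per-object checkpoint name stays "duplicate".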
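
Usage sketch for the new add_variable default path (the Thing class and the
variable name "w" are hypothetical; the calls mirror NonLayerCheckpointable
and testAddVariable above):

    from tensorflow.contrib.eager.python import checkpointable

    class Thing(checkpointable.Checkpointable):

      def __init__(self):
        # __init__ must run first, or add_variable now raises RuntimeError.
        super(Thing, self).__init__()
        # Routed through _default_getter: a ResourceVariable is created
        # directly, with no variable-scope reuse bookkeeping involved.
        self.w = self.add_variable(name="w", shape=[2])

    first = Thing()
    second = Thing()  # No "already exists" error, even in graph mode.
    # Only a second add_variable(name="w") on the *same* object raises
    # ValueError; checkpoint names are scoped to the object graph.
    named_variables, _ = checkpointable._serialize_object_graph(first)
    assert "w" in named_variables.keys()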
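
Design note on the Root.global_step change: the test helper now intercepts
variable creation with a variable creator rather than a scope-level custom
getter. The adapter exists because the two protocols disagree on one keyword:
a creator receives the value as initial_value=, while add_variable's getter
protocol passes it as initializer=. A condensed sketch of the same pattern
(the TrackedStep class is hypothetical; the body mirrors the test code above):

    from tensorflow.contrib.eager.python import checkpointable
    from tensorflow.python.ops import variable_scope
    from tensorflow.python.training import training_util

    class TrackedStep(checkpointable.Checkpointable):

      def make_global_step(self):
        def creator(next_creator, initial_value, **kwargs):
          def getter(initializer, **kwargs):
            # Translate the getter protocol back into the creator protocol.
            return next_creator(initial_value=initializer, **kwargs)
          # add_variable records the variable for checkpointing, then calls
          # our getter, which defers to the original creator.
          return self.add_variable(
              getter=getter, initializer=initial_value, shape=[], **kwargs)

        with variable_scope.variable_creator_scope(creator):
          # create_global_step's variable creation is routed through the
          # creator stack at this revision, so the step variable ends up
          # tracked by this object.
          return training_util.create_global_step()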