Split out SaveableObjects into their own file
authorAllen Lavoie <allenl@google.com>
Thu, 26 Apr 2018 23:40:16 +0000 (16:40 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Thu, 26 Apr 2018 23:42:50 +0000 (16:42 -0700)
Pulls a couple build rules out of tensorflow/python:training. I'd like to use a SaveableObject in :checkpointable (for saving some Python state by default), which means the file with SaveableObject has to be essientially dependency-free.

PiperOrigin-RevId: 194473987

tensorflow/python/BUILD
tensorflow/python/training/saveable_object.py [new file with mode: 0644]
tensorflow/python/training/saver.py

index e2d86fa..105fcba 100644 (file)
@@ -2967,7 +2967,11 @@ py_library(
         ["training/**/*.py"],
         exclude = [
             "**/*test*",
-            "training/training_util.py",  # See :training_util
+            # The following targets have their own build rules (same name as the
+            # file):
+            "training/checkpointable.py",
+            "training/saveable_object.py",
+            "training/training_util.py",
         ],
     ),
     srcs_version = "PY2AND3",
@@ -2975,6 +2979,7 @@ py_library(
         ":array_ops",
         ":array_ops_gen",
         ":checkpoint_ops_gen",
+        ":checkpointable",
         ":client",
         ":control_flow_ops",
         ":data_flow_ops",
@@ -2998,6 +3003,7 @@ py_library(
         ":random_ops",
         ":resource_variable_ops",
         ":resources",
+        ":saveable_object",
         ":sdca_ops",
         ":sparse_ops",
         ":state_ops",
@@ -3044,6 +3050,12 @@ py_test(
 )
 
 py_library(
+    name = "saveable_object",
+    srcs = ["training/saveable_object.py"],
+    srcs_version = "PY2AND3",
+)
+
+py_library(
     name = "device_util",
     srcs = ["training/device_util.py"],
     srcs_version = "PY2AND3",
diff --git a/tensorflow/python/training/saveable_object.py b/tensorflow/python/training/saveable_object.py
new file mode 100644 (file)
index 0000000..4b19294
--- /dev/null
@@ -0,0 +1,99 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Types for specifying saving and loading behavior."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+
+class SaveSpec(object):
+  """Class used to describe tensor slices that need to be saved."""
+
+  def __init__(self, tensor, slice_spec, name, dtype=None):
+    """Creates a `SaveSpec` object.
+
+    Args:
+      tensor: the tensor to save or callable that produces a tensor to save.
+      slice_spec: the slice to be saved. See `Variable.SaveSliceInfo`.
+      name: the name to save the tensor under.
+      dtype: The data type of the Tensor. Required if `tensor` is callable.
+        Used for error checking in the restore op.
+    """
+    self._tensor = tensor
+    self.slice_spec = slice_spec
+    self.name = name
+    if callable(self._tensor):
+      if dtype is None:
+        raise AssertionError(
+            "When passing a callable `tensor` to a SaveSpec, an explicit "
+            "dtype must be provided.")
+      self.dtype = dtype
+    else:
+      self.dtype = tensor.dtype
+
+  @property
+  def tensor(self):
+    return self._tensor() if callable(self._tensor) else self._tensor
+
+
+class SaveableObject(object):
+  """Base class for saving and restoring saveable objects."""
+
+  def __init__(self, op, specs, name):
+    """Creates a `SaveableObject` object.
+
+    Args:
+      op: the "producer" object that this class wraps; it produces a list of
+        tensors to save.  E.g., a "Variable" object saving its backing tensor.
+      specs: a list of SaveSpec, each element of which describes one tensor to
+        save under this object. All Tensors must be on the same device.
+      name: the name to save the object under.
+    """
+    self.op = op
+    self.specs = specs
+    self.name = name
+    self._device = None
+
+  @property
+  def device(self):
+    """The device for SaveSpec Tensors."""
+    # Note that SaveSpec.tensor runs Tensor-gathering ops when executing
+    # eagerly, making this call potentially very expensive.
+    #
+    # TODO(allenl): Consider another way to gather device information. Lower
+    # priority since this property isn't part of the normal save()/restore()
+    # workflow, but does come up when some alternative builders are passed to
+    # the Saver.
+    if self._device is None:
+      self._device = self.specs[0].tensor.device
+    return self._device
+
+  def restore(self, restored_tensors, restored_shapes):
+    """Restores this object from 'restored_tensors'.
+
+    Args:
+      restored_tensors: the tensors that were loaded from a checkpoint
+      restored_shapes: the shapes this object should conform to after
+        restore, or None.
+
+    Returns:
+      An operation that restores the state of the object.
+
+    Raises:
+      ValueError: If the object cannot be restored using the provided
+        parameters.
+    """
+    # pylint: disable=unused-argument
+    raise ValueError("Calling an abstract method.")
index a74d629..53e821c 100644 (file)
@@ -54,6 +54,7 @@ from tensorflow.python.ops import variables
 from tensorflow.python.platform import gfile
 from tensorflow.python.platform import tf_logging as logging
 from tensorflow.python.training import checkpointable
+from tensorflow.python.training import saveable_object
 from tensorflow.python.training import training_util
 from tensorflow.python.training.checkpoint_state_pb2 import CheckpointState
 from tensorflow.python.util import compat
@@ -91,84 +92,8 @@ class BaseSaverBuilder(object):
   Can be extended to create different Ops.
   """
 
-  class SaveSpec(object):
-    """Class used to describe tensor slices that need to be saved."""
-
-    def __init__(self, tensor, slice_spec, name, dtype=None):
-      """Creates a `SaveSpec` object.
-
-      Args:
-        tensor: the tensor to save or callable that produces a tensor to save.
-        slice_spec: the slice to be saved. See `Variable.SaveSliceInfo`.
-        name: the name to save the tensor under.
-        dtype: The data type of the Tensor. Required if `tensor` is callable.
-          Used for error checking in the restore op.
-      """
-      self._tensor = tensor
-      self.slice_spec = slice_spec
-      self.name = name
-      if callable(self._tensor):
-        if dtype is None:
-          raise AssertionError(
-              "When passing a callable `tensor` to a SaveSpec, an explicit "
-              "dtype must be provided.")
-        self.dtype = dtype
-      else:
-        self.dtype = tensor.dtype
-
-    @property
-    def tensor(self):
-      return self._tensor() if callable(self._tensor) else self._tensor
-
-  class SaveableObject(object):
-    """Base class for saving and restoring saveable objects."""
-
-    def __init__(self, op, specs, name):
-      """Creates a `SaveableObject` object.
-
-      Args:
-        op: the "producer" object that this class wraps; it produces a list of
-          tensors to save.  E.g., a "Variable" object saving its backing tensor.
-        specs: a list of SaveSpec, each element of which describes one tensor to
-          save under this object. All Tensors must be on the same device.
-        name: the name to save the object under.
-      """
-      self.op = op
-      self.specs = specs
-      self.name = name
-      self._device = None
-
-    @property
-    def device(self):
-      """The device for SaveSpec Tensors."""
-      # Note that SaveSpec.tensor runs Tensor-gathering ops when executing
-      # eagerly, making this call potentially very expensive.
-      #
-      # TODO(allenl): Consider another way to gather device information. Lower
-      # priority since this property isn't part of the normal save()/restore()
-      # workflow, but does come up when some alternative builders are passed to
-      # the Saver.
-      if self._device is None:
-        self._device = self.specs[0].tensor.device
-      return self._device
-
-    def restore(self, restored_tensors, restored_shapes):
-      """Restores this object from 'restored_tensors'.
-
-      Args:
-        restored_tensors: the tensors that were loaded from a checkpoint
-        restored_shapes: the shapes this object should conform to after
-          restore, or None.
-
-      Returns:
-        An operation that restores the state of the object.
-
-      Raises:
-        ValueError: If the object cannot be restored using the provided
-          parameters.
-      """
-      # pylint: disable=unused-argument
-      raise ValueError("Calling an abstract method.")
+  SaveSpec = saveable_object.SaveSpec
+  SaveableObject = saveable_object.SaveableObject
 
   class VariableSaveable(SaveableObject):
     """SaveableObject implementation that handles Variables."""