self.session = session
-class Checkpointable(object):
- """Manages dependencies on other objects.
+class CheckpointableBase(object):
+ """Base class for `Checkpointable` objects without automatic dependencies.
- `Checkpointable` objects may have dependencies: other `Checkpointable` objects
- which should be saved if the object declaring the dependency is saved. A
- correctly saveable program has a dependency graph such that if changing a
- global variable affects an object (e.g. changes the behavior of any of its
- methods) then there is a chain of dependencies from the influenced object to
- the variable.
-
- Dependency edges have names, and are created implicitly when a
- `Checkpointable` object is assigned to an attribute of another
- `Checkpointable` object. For example:
-
- ```
- obj = Checkpointable()
- obj.v = ResourceVariable(0.)
- ```
-
- The `Checkpointable` object `obj` now has a dependency named "v" on a
- variable.
-
- `Checkpointable` objects may specify `Tensor`s to be saved and restored
- directly (e.g. a `Variable` indicating how to save itself) rather than through
- dependencies on other objects. See
- `Checkpointable._scatter_tensors_from_checkpoint` and
- `Checkpointable._gather_tensors_for_checkpoint` for details.
+ This class has no __setattr__ override for performance reasons. Dependencies
+ must be added explicitly. Unless attribute assignment is performance-critical,
+ use `Checkpointable` instead. Use `CheckpointableBase` for `isinstance`
+ checks.
"""
def _maybe_initialize_checkpointable(self):
"initialization code was run.")
self._update_uid = -1
- def __setattr__(self, name, value):
- """Support self.foo = checkpointable syntax."""
- # Perform the attribute assignment, and potentially call other __setattr__
- # overrides such as that for tf.keras.Model.
- super(Checkpointable, self).__setattr__(name, value)
- if isinstance(value, Checkpointable):
- self._track_checkpointable(
- value, name=name,
- # Allow the user to switch the Checkpointable which is tracked by this
- # name, since assigning a new variable to an attribute has
- # historically been fine (e.g. Adam did this).
- # TODO(allenl): Should this be a warning once Checkpointable save/load
- # is usable?
- overwrite=True)
-
def _add_variable_with_custom_getter(
self, name, shape=None, dtype=dtypes.float32,
initializer=None, getter=None, **kwargs_for_getter):
ValueError: If another object is already tracked by this name.
"""
self._maybe_initialize_checkpointable()
- if not isinstance(checkpointable, Checkpointable):
+ if not isinstance(checkpointable, CheckpointableBase):
raise TypeError(
("Checkpointable._track_checkpointable() passed type %s, not a "
"Checkpointable.") % (type(checkpointable),))
def _gather_tensors_for_checkpoint(self):
"""Returns a dictionary of Tensors to save with this object."""
return {}
+
+
+class Checkpointable(CheckpointableBase):
+ """Manages dependencies on other objects.
+
+ `Checkpointable` objects may have dependencies: other `Checkpointable` objects
+ which should be saved if the object declaring the dependency is saved. A
+ correctly saveable program has a dependency graph such that if changing a
+ global variable affects an object (e.g. changes the behavior of any of its
+ methods) then there is a chain of dependencies from the influenced object to
+ the variable.
+
+ Dependency edges have names, and are created implicitly when a
+ `Checkpointable` object is assigned to an attribute of another
+ `Checkpointable` object. For example:
+
+ ```
+ obj = Checkpointable()
+ obj.v = ResourceVariable(0.)
+ ```
+
+ The `Checkpointable` object `obj` now has a dependency named "v" on a
+ variable.
+
+ `Checkpointable` objects may specify `Tensor`s to be saved and restored
+ directly (e.g. a `Variable` indicating how to save itself) rather than through
+ dependencies on other objects. See
+ `Checkpointable._scatter_tensors_from_checkpoint` and
+ `Checkpointable._gather_tensors_for_checkpoint` for details.
+ """
+
+ def __setattr__(self, name, value):
+ """Support self.foo = checkpointable syntax."""
+ # Perform the attribute assignment, and potentially call other __setattr__
+ # overrides such as that for tf.keras.Model.
+ super(Checkpointable, self).__setattr__(name, value)
+ if isinstance(value, CheckpointableBase):
+ self._track_checkpointable(
+ value, name=name,
+ # Allow the user to switch the Checkpointable which is tracked by this
+ # name, since assigning a new variable to an attribute has
+ # historically been fine (e.g. Adam did this).
+ # TODO(allenl): Should this be a warning once Checkpointable save/load
+ # is usable?
+ overwrite=True)