Add a subclassed Model's attribute-assigned variables to Model.weights et al
authorAllen Lavoie <allenl@google.com>
Thu, 31 May 2018 02:01:58 +0000 (19:01 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Thu, 31 May 2018 02:04:42 +0000 (19:04 -0700)
Makes the Variable.trainable property public, which is sensible if we're discouraging use of the global collection (currently eager execution is using ResourceVariable._trainable in a bunch of places anyway). I'm leaving it read-only for now, since we should toggle in and out of the global collection when it changes.

Same change for checkpointable data structures with respect to gathering extra variables. They'll behave like subclassed Models.

I think this makes more sense than trying to have a distinction between "variables" and "weights". It's also more sensible than collecting everything that would get checkpointed, since that will include Optimizer slot variables and metrics. Collecting those is generally pointless, and accidentally adding them to gradient tapes would be horribly confusing.

PiperOrigin-RevId: 198656079

15 files changed:
tensorflow/core/framework/variable.proto
tensorflow/python/eager/function.py
tensorflow/python/eager/graph_callable.py
tensorflow/python/eager/pywrap_tfe_src.cc
tensorflow/python/keras/engine/network.py
tensorflow/python/keras/model_subclassing_test.py
tensorflow/python/keras/utils/layer_utils.py
tensorflow/python/kernel_tests/resource_variable_ops_test.py
tensorflow/python/kernel_tests/variables_test.py
tensorflow/python/ops/resource_variable_ops.py
tensorflow/python/ops/variable_scope.py
tensorflow/python/ops/variables.py
tensorflow/python/training/checkpointable/data_structures.py
tensorflow/python/training/checkpointable/data_structures_test.py
tensorflow/tools/api/golden/tensorflow.-variable.pbtxt

index 93ae423..66ba4cb 100644 (file)
@@ -26,6 +26,9 @@ message VariableDef {
 
   // Whether to represent this as a ResourceVariable.
   bool is_resource = 5;
+
+  // Whether this variable should be trained.
+  bool trainable = 7;
 }
 
 message SaveSliceInfoDef {
index 23d87fb..559063d 100644 (file)
@@ -494,7 +494,7 @@ class GraphModeFunction(object):
   def __call__(self, *args):
     """Executes the passed function in eager mode."""
     for v in self._variables:
-      if v._trainable:  # pylint: disable=protected-access
+      if v.trainable:
         tape.watch_variable(v)
 
     tensor_inputs = [x for x in nest.flatten(args) if isinstance(x, ops.Tensor)]
index d9ffcbd..760a148 100644 (file)
@@ -202,7 +202,7 @@ class _InitializingFunctionObject(object):
         v.handle).numpy() for v in self._call_fn.variables]
     if all(x for x in initialized):
       for v in self._call_fn.variables:
-        if v._trainable:  # pylint: disable=protected-access
+        if v.trainable:
           tape.watch_variable(v)
       return self._call_fn(*args)
     elif all(not x for x in initialized):
index 52b9050..e3ce0ef 100644 (file)
@@ -1874,10 +1874,10 @@ PyObject* RecordGradient(PyObject* op_name, PyObject* inputs, PyObject* attrs,
 
 void MaybeWatchVariable(PyObject* input) {
   DCHECK(CheckResourceVariable(input));
-  DCHECK(PyObject_HasAttrString(input, "_trainable"));
+  DCHECK(PyObject_HasAttrString(input, "trainable"));
 
   tensorflow::Safe_PyObjectPtr trainable(
-      PyObject_GetAttrString(input, "_trainable"));
+      PyObject_GetAttrString(input, "trainable"));
   if (trainable.get() == Py_False) return;
   TFE_Py_TapeSetWatchVariable(input);
 }
index 6db4147..f63ca1a 100644 (file)
@@ -36,9 +36,10 @@ from tensorflow.python.keras import backend
 from tensorflow.python.keras.engine import base_layer
 from tensorflow.python.keras.engine import saving
 from tensorflow.python.keras.utils import generic_utils
+from tensorflow.python.keras.utils import layer_utils
 from tensorflow.python.keras.utils import tf_utils
 from tensorflow.python.keras.utils.io_utils import ask_to_proceed_with_overwrite
-from tensorflow.python.keras.utils.layer_utils import print_summary as print_layer_summary
+from tensorflow.python.ops import variables
 from tensorflow.python.platform import tf_logging as logging
 from tensorflow.python.training.checkpointable import base as checkpointable
 from tensorflow.python.training.checkpointable import data_structures_base
@@ -94,6 +95,11 @@ class Network(base_layer.Layer):
     self.trainable = True
     self._is_compiled = False
     self._expects_training_arg = False
+    # A list of "extra" variables assigned to attributes of this class, included
+    # in self.weights and self.variables. Always empty for graph networks (but
+    # included in base_init to avoid excessive special casing when retrieving
+    # the value).
+    self._extra_variables = []
 
     self.supports_masking = False
     if not hasattr(self, 'optimizer'):
@@ -347,11 +353,22 @@ class Network(base_layer.Layer):
       # layers). Therefore Model tracks Checkpointable objects itself.
       self._track_checkpointable(
           checkpointable=value, name=name, overwrite=True)
+      if (  # For subclassed models only, users may add extra weights/variables
+            # simply by assigning them to attributes.
+          not self._is_graph_network
+          and isinstance(value, variables.Variable)):
+        self._extra_variables.append(value)
     super(Network, self).__setattr__(name, value)
 
   def add_variable(self, name, shape, dtype=None, initializer=None,
                    regularizer=None, trainable=True, constraint=None):
-    raise NotImplementedError('`add_variable` is not supported on Networks.')
+    if self._is_graph_network:
+      raise NotImplementedError('`add_variable` is not supported on Networks.')
+    else:
+      raise NotImplementedError(
+          '`add_variable` is not supported on Networks. However, you may '
+          'assign variables to attributes and they will show up in the weights '
+          'and variables properties.')
 
   def add_loss(self, *args, **kwargs):
     if context.executing_eagerly():
@@ -589,24 +606,17 @@ class Network(base_layer.Layer):
 
   @property
   def trainable_weights(self):
-    if not self.trainable:
-      return []
-    weights = []
-    for layer in self.layers:
-      weights += layer.trainable_weights
-    return weights
+    return layer_utils.gather_trainable_weights(
+        trainable=self.trainable,
+        sub_layers=self.layers,
+        extra_variables=self._extra_variables)
 
   @property
   def non_trainable_weights(self):
-    weights = []
-    for layer in self.layers:
-      weights += layer.non_trainable_weights
-    if not self.trainable:
-      trainable_weights = []
-      for layer in self.layers:
-        trainable_weights += layer.trainable_weights
-      return trainable_weights + weights
-    return weights
+    return layer_utils.gather_non_trainable_weights(
+        trainable=self.trainable,
+        sub_layers=self.layers,
+        extra_variables=self._extra_variables)
 
   @property
   def input_spec(self):
@@ -1437,10 +1447,10 @@ class Network(base_layer.Layer):
                        'have not yet been created, so no summary can be '
                        'displayed. Build the model first '
                        '(e.g. by calling it on some data).')
-    print_layer_summary(self,
-                        line_length=line_length,
-                        positions=positions,
-                        print_fn=print_fn)
+    layer_utils.print_summary(self,
+                              line_length=line_length,
+                              positions=positions,
+                              print_fn=print_fn)
 
 
 def get_source_inputs(tensor, layer=None, node_index=None):
index 558854a..86f7e20 100644 (file)
@@ -622,6 +622,51 @@ class ModelSubclassingTest(test.TestCase):
     self.assertIs(m.isdep, m._checkpoint_dependencies[0].ref)
     self.assertEqual('notdep_var:0', m.notdep_var.name)
 
+  def test_extra_variable(self):
+
+    class ExtraVar(keras.Model):
+
+      def __init__(self):
+        super(ExtraVar, self).__init__()
+        self.dense = keras.layers.Dense(1)
+        self.var = resource_variable_ops.ResourceVariable(1.)
+        self.not_trainable_var = resource_variable_ops.ResourceVariable(
+            2., trainable=False)
+
+      def call(self, inputs):
+        return self.dense(inputs + self.var)
+
+    m = ExtraVar()
+    self.assertTrue(m.trainable)
+    self.assertEqual([m.dense], m.layers)
+    self.assertEqual([m.var, m.not_trainable_var], m.variables)
+    self.assertEqual([m.var], m.trainable_variables)
+    self.assertEqual([m.not_trainable_var], m.non_trainable_variables)
+    m.trainable = False
+    self.assertEqual([m.var, m.not_trainable_var], m.variables)
+    self.assertEqual([], m.trainable_variables)
+    self.assertEqual([m.var, m.not_trainable_var], m.non_trainable_variables)
+    m.trainable = True
+
+    m(array_ops.ones([1, 1]))
+
+    self.assertEqual([m.dense.kernel, m.dense.bias], m.dense.variables)
+    self.assertEqual([m.dense.kernel, m.dense.bias], m.dense.weights)
+
+    self.assertEqual([m.dense.kernel, m.dense.bias, m.var, m.not_trainable_var],
+                     m.variables)
+    self.assertEqual([m.dense.kernel, m.dense.bias, m.var],
+                     m.trainable_variables)
+    self.assertEqual([m.not_trainable_var], m.non_trainable_variables)
+
+    m.dense.trainable = False
+    self.assertEqual(
+        [m.var, m.dense.kernel, m.dense.bias, m.not_trainable_var],
+        m.variables)
+    self.assertEqual([m.var], m.trainable_variables)
+    self.assertEqual([m.dense.kernel, m.dense.bias, m.not_trainable_var],
+                     m.non_trainable_variables)
+
 
 class CustomCallModel(keras.Model):
 
index bd61f8e..88daff0 100644 (file)
@@ -201,6 +201,61 @@ def print_summary(model, line_length=None, positions=None, print_fn=None):
   print_fn('_' * line_length)
 
 
+def gather_trainable_weights(trainable, sub_layers, extra_variables):
+  """Lists the trainable weights for an object with sub-layers.
+
+  Args:
+    trainable: Whether the object collecting the variables is trainable.
+    sub_layers: A flat list of Layer objects owned by this object, to collect
+      variables from.
+    extra_variables: Any extra variables to include. Their `.trainable` property
+      is used to categorize them.
+
+  Returns:
+    A list of collected trainable weights/variables.
+  """
+  if not trainable:
+    return []
+  weights = []
+  for layer in sub_layers:
+    weights += layer.trainable_weights
+  trainable_extra_variables = [
+      v for v in extra_variables if v.trainable]
+  return weights + trainable_extra_variables
+
+
+def gather_non_trainable_weights(trainable, sub_layers, extra_variables):
+  """Lists the non-trainable weights for an object with sub-layers.
+
+  Args:
+    trainable: Whether the object collecting the variables is trainable.
+    sub_layers: A flat list of Layer objects owned by this object, to collect
+      variables from.
+    extra_variables: Any extra variables to include. Their `.trainable` property
+      is used to categorize them.
+
+  Returns:
+    A list of collected non-trainable weights/variables.
+  """
+  trainable_extra_variables = []
+  non_trainable_extra_variables = []
+  for v in extra_variables:
+    if v.trainable:
+      trainable_extra_variables.append(v)
+    else:
+      non_trainable_extra_variables.append(v)
+  weights = []
+  for layer in sub_layers:
+    weights += layer.non_trainable_weights
+  if not trainable:
+    trainable_weights = []
+    for layer in sub_layers:
+      trainable_weights += layer.trainable_weights
+    return (trainable_weights + trainable_extra_variables
+            + weights + non_trainable_extra_variables)
+  return weights + non_trainable_extra_variables
+
+
 @tf_export('keras.utils.convert_all_kernels_in_model')
 def convert_all_kernels_in_model(model):
   """Converts all convolution kernels in a model from Theano to TensorFlow.
index 972fbdb..00d517e 100644 (file)
@@ -538,6 +538,25 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase):
       with self.assertRaises(ValueError):
         sess.run(v.initialized_value())
 
+  def testTrainableInProto(self):
+    with ops.Graph().as_default():
+      non_trainable_variable = resource_variable_ops.ResourceVariable(
+          trainable=False,
+          initial_value=constant_op.constant(10.0))
+      self.assertEqual(
+          False,
+          resource_variable_ops.ResourceVariable(
+              variable_def=non_trainable_variable.to_proto())
+          .trainable)
+      trainable_variable = resource_variable_ops.ResourceVariable(
+          trainable=True,
+          initial_value=constant_op.constant(10.0))
+      self.assertEqual(
+          True,
+          resource_variable_ops.ResourceVariable(
+              variable_def=trainable_variable.to_proto())
+          .trainable)
+
   @test_util.run_in_graph_and_eager_modes()
   def testSparseRead(self):
     with self.test_session():
index 2759986..62d596d 100644 (file)
@@ -496,6 +496,23 @@ class VariablesTestCase(test.TestCase):
       with self.assertRaises(ValueError):
         sess.run(v.initialized_value())
 
+  def testTrainableInProto(self):
+    with ops.Graph().as_default():
+      non_trainable_variable = variables.Variable(
+          trainable=False,
+          initial_value=constant_op.constant(10.0))
+      self.assertEqual(
+          False,
+          variables.Variable(variable_def=non_trainable_variable.to_proto())
+          .trainable)
+      trainable_variable = variables.Variable(
+          trainable=True,
+          initial_value=constant_op.constant(10.0))
+      self.assertEqual(
+          True,
+          variables.Variable(variable_def=trainable_variable.to_proto())
+          .trainable)
+
   def testLoad(self):
     with self.test_session():
       var = variables.Variable(np.zeros((5, 5), np.float32))
index e37e93e..7061b32 100644 (file)
@@ -551,6 +551,7 @@ class ResourceVariable(variables.Variable):
                                  import_scope=import_scope))
     else:
       self._initial_value = None
+    self._trainable = getattr(variable_def, "trainable", True)
     if variable_def.snapshot_name:
       snapshot = g.as_graph_element(
           ops.prepend_name_scope(
@@ -735,7 +736,7 @@ class ResourceVariable(variables.Variable):
     return self._save_slice_info
 
   def _read_variable_op(self):
-    if hasattr(self, "_trainable") and self._trainable:
+    if self.trainable:
       tape.watch_variable(self)
     return gen_resource_variable_ops.read_variable_op(self._handle,
                                                       self._dtype)
@@ -760,7 +761,7 @@ class ResourceVariable(variables.Variable):
   def sparse_read(self, indices, name=None):
     """Reads the value of this variable sparsely, using `gather`."""
     with ops.name_scope("Gather" if name is None else name) as name:
-      if self._trainable:
+      if self.trainable:
         tape.watch_variable(self)
       value = gen_resource_variable_ops.resource_gather(
           self._handle, indices, dtype=self._dtype, name=name)
@@ -801,6 +802,7 @@ class ResourceVariable(variables.Variable):
         var_def.snapshot_name = ops.strip_name_scope(self._graph_element.name,
                                                      export_scope)
       var_def.is_resource = True
+      var_def.trainable = self.trainable
       if self._save_slice_info:
         var_def.save_slice_info_def.MergeFrom(
             self._save_slice_info.to_proto(export_scope=export_scope))
@@ -913,7 +915,7 @@ class ResourceVariable(variables.Variable):
     return assign_add_op
 
   def _lazy_read(self, op):
-    if hasattr(self, "_trainable") and self._trainable:
+    if self.trainable:
       tape.watch_variable(self)
     return _UnreadVariable(
         self._handle, self.dtype, self._shape, self._in_graph_mode,
index 8d93d24..fa34774 100644 (file)
@@ -1261,13 +1261,13 @@ class EagerVariableStore(object):
 
   def trainable_variables(self):
     # pylint: disable=protected-access
-    return sorted([x for x in self._store._vars.values() if x._trainable],
+    return sorted([x for x in self._store._vars.values() if x.trainable],
                   key=lambda x: x.name)
     # pylint: enable=protected-access
 
   def non_trainable_variables(self):
     # pylint: disable=protected-access
-    return sorted([x for x in self._store._vars.values() if not x._trainable],
+    return sorted([x for x in self._store._vars.values() if not x.trainable],
                   key=lambda x: x.name)
     # pylint: enable=protected-access
 
@@ -1296,7 +1296,7 @@ class EagerVariableStore(object):
       new_var = resource_variable_ops.ResourceVariable(
           var.read_value(),
           name=stripped_var_name,
-          trainable=var._trainable)
+          trainable=var.trainable)
       new_store._store._vars[key] = new_var
     return new_store
     # pylint: enable=protected-access
index d88fd83..4be9f5e 100644 (file)
@@ -341,6 +341,7 @@ class Variable(checkpointable.CheckpointableBase):
       self._update_uid = initial_value.checkpoint_position.restore_uid
       initial_value = initial_value.wrapped_value
 
+    self._trainable = trainable
     if trainable and ops.GraphKeys.TRAINABLE_VARIABLES not in collections:
       collections = list(collections) + [ops.GraphKeys.TRAINABLE_VARIABLES]
     with ops.init_scope():
@@ -450,6 +451,7 @@ class Variable(checkpointable.CheckpointableBase):
                                  import_scope=import_scope))
     else:
       self._initial_value = None
+    self._trainable = getattr(variable_def, "trainable", True)
     self._snapshot = g.as_graph_element(
         ops.prepend_name_scope(variable_def.snapshot_name,
                                import_scope=import_scope))
@@ -543,6 +545,10 @@ class Variable(checkpointable.CheckpointableBase):
     self._ref().set_shape(shape)
     self.value().set_shape(shape)
 
+  @property
+  def trainable(self):
+    return self._trainable
+
   def eval(self, session=None):
     """In a session, computes and returns the value of this variable.
 
@@ -1050,6 +1056,7 @@ class Variable(checkpointable.CheckpointableBase):
         # For backwards compatibility.
         var_def.initial_value_name = ops.strip_name_scope(
             self._initial_value.name, export_scope)
+      var_def.trainable = self.trainable
       var_def.initializer_name = ops.strip_name_scope(
           self.initializer.name, export_scope)
       var_def.snapshot_name = ops.strip_name_scope(
index 62cefa4..69ed253 100644 (file)
@@ -22,6 +22,8 @@ import collections
 import six
 
 from tensorflow.python.keras.engine import base_layer
+from tensorflow.python.keras.utils import layer_utils
+from tensorflow.python.ops import variables
 from tensorflow.python.training.checkpointable import base as checkpointable_lib
 from tensorflow.python.training.checkpointable import data_structures_base
 
@@ -41,11 +43,14 @@ class CheckpointableDataStructure(
   def __init__(self):
     self._layers = []
     self.trainable = True
+    self._extra_variables = []
 
   def _track_value(self, value, name):
     """Add a dependency on `value`."""
     if isinstance(value, checkpointable_lib.CheckpointableBase):
       self._track_checkpointable(value, name=name)
+      if isinstance(value, variables.Variable):
+        self._extra_variables.append(value)
     else:
       raise ValueError(
           ("Only checkpointable objects (such as Layers or Optimizers) may be "
@@ -67,30 +72,31 @@ class CheckpointableDataStructure(
 
   @property
   def trainable_weights(self):
-    if not self.trainable:
-      return []
-    weights = []
-    for layer in self.layers:
-      weights += layer.trainable_weights
-    return weights
+    return layer_utils.gather_trainable_weights(
+        trainable=self.trainable,
+        sub_layers=self.layers,
+        extra_variables=self._extra_variables)
 
   @property
   def non_trainable_weights(self):
-    weights = []
-    for layer in self.layers:
-      weights += layer.non_trainable_weights
-    if not self.trainable:
-      trainable_weights = []
-      for layer in self.layers:
-        trainable_weights += layer.trainable_weights
-      return trainable_weights + weights
-    return weights
+    return layer_utils.gather_non_trainable_weights(
+        trainable=self.trainable,
+        sub_layers=self.layers,
+        extra_variables=self._extra_variables)
 
   @property
   def weights(self):
     return self.trainable_weights + self.non_trainable_weights
 
   @property
+  def trainable_variables(self):
+    return self.trainable_weights
+
+  @property
+  def non_trainable_variables(self):
+    return self.non_trainable_weights
+
+  @property
   def variables(self):
     return self.weights
 
index 31a0e8b..b05b3a8 100644 (file)
@@ -139,6 +139,25 @@ class ListTests(test.TestCase):
           outer.variables[0],
           resource_variable_ops.ResourceVariable)
 
+  def testNonLayerVariables(self):
+    v = resource_variable_ops.ResourceVariable([1.])
+    l = data_structures.List([v])
+    self.assertTrue(l.trainable)
+    self.assertEqual([], l.layers)
+    self.assertEqual([v], l.variables)
+    self.assertEqual([v], l.trainable_weights)
+    self.assertEqual([], l.non_trainable_variables)
+    l.trainable = False
+    self.assertEqual([v], l.variables)
+    self.assertEqual([], l.trainable_variables)
+    self.assertEqual([v], l.non_trainable_variables)
+    l.trainable = True
+    v2 = resource_variable_ops.ResourceVariable(1., trainable=False)
+    l.append(v2)
+    self.assertEqual([v, v2], l.weights)
+    self.assertEqual([v], l.trainable_weights)
+    self.assertEqual([v2], l.non_trainable_weights)
+
   def testHashing(self):
     has_sequences = set([data_structures.List(),
                          data_structures.List()])
index 8c8912d..23b552c 100644 (file)
@@ -43,6 +43,10 @@ tf_class {
     name: "shape"
     mtype: "<type \'property\'>"
   }
+  member {
+    name: "trainable"
+    mtype: "<type \'property\'>"
+  }
   member_method {
     name: "__init__"
     argspec: "args=[\'self\', \'initial_value\', \'trainable\', \'collections\', \'validate_shape\', \'caching_device\', \'name\', \'variable_def\', \'dtype\', \'expected_shape\', \'import_scope\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'True\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], "