Adding a _create_state method to FeatureColumn so that we can decouple variable ...
author    Rohan Jain <rohanj@google.com>
Wed, 4 Apr 2018 21:29:03 +0000 (14:29 -0700)
committer TensorFlower Gardener <gardener@tensorflow.org>
Wed, 4 Apr 2018 21:34:38 +0000 (14:34 -0700)
PiperOrigin-RevId: 191647517

tensorflow/python/feature_column/feature_column.py
tensorflow/python/feature_column/feature_column_test.py
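
This change lets a caller create a column's variables up front via
`_create_state` and then reuse them across `_get_dense_tensor` calls,
decoupling variable creation from variable usage. The pattern the new API
enables, as a minimal sketch based on the tests below (the column and
feature names are illustrative):

    # Build the column as usual.
    categorical_column = fc.categorical_column_with_identity(
        key='aaa', num_buckets=3)
    embedding_column = fc.embedding_column(categorical_column, dimension=2)

    # Create the embedding variables once, up front...
    state = embedding_column._create_state(
        weight_collections=[ops.GraphKeys.GLOBAL_VARIABLES])

    # ...then reuse them for any number of lookups. `sparse_input` is a
    # SparseTensorValue of ids, as in the tests below.
    embedding_lookup = embedding_column._get_dense_tensor(
        _LazyBuilder({'aaa': sparse_input}), state=state)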

diff --git a/tensorflow/python/feature_column/feature_column.py b/tensorflow/python/feature_column/feature_column.py
index 92c6ff2..e116739 100644
--- a/tensorflow/python/feature_column/feature_column.py
+++ b/tensorflow/python/feature_column/feature_column.py
@@ -1643,6 +1643,19 @@ class _FeatureColumn(object):
     """
     pass
 
+  def _create_state(self, weight_collections=None, creator=None):
+    """Returns an object that captures the state of the column.
+
+    Args:
+      weight_collections: Collections to add the state variables to.
+      creator: Variable creator called instead of the default, if provided.
+
+    Returns:
+      An object that encapsulates the state of the column; may be None.
+    """
+    del weight_collections, creator  # Unused
+    return None
+
 
 class _DenseColumn(_FeatureColumn):
   """Represents a column which can be represented as `Tensor`.
@@ -1662,7 +1675,11 @@ class _DenseColumn(_FeatureColumn):
     pass
 
   @abc.abstractmethod
-  def _get_dense_tensor(self, inputs, weight_collections=None, trainable=None):
+  def _get_dense_tensor(self,
+                        inputs,
+                        weight_collections=None,
+                        trainable=None,
+                        state=None):
     """Returns a `Tensor`.
 
     The output of this function will be used by model-builder-functions. For
@@ -1680,6 +1697,9 @@ class _DenseColumn(_FeatureColumn):
         will be created) are added.
       trainable: If `True` also add variables to the graph collection
         `GraphKeys.TRAINABLE_VARIABLES` (see @{tf.Variable}).
+      state: An object encapsulating the state of the column. Columns that
+        create state via the _create_state method will have that state
+        passed in to this method.
 
     Returns:
       `Tensor` of shape [batch_size] + `_variable_shape`.
@@ -1687,13 +1707,13 @@ class _DenseColumn(_FeatureColumn):
     pass
 
 
-def _create_weighted_sum(
-    column,
-    builder,
-    units,
-    sparse_combiner,
-    weight_collections,
-    trainable):
+def _create_weighted_sum(column,
+                         builder,
+                         units,
+                         sparse_combiner,
+                         weight_collections,
+                         trainable,
+                         state=None):
   """Creates a weighted sum for a dense or sparse column for linear_model."""
   if isinstance(column, _CategoricalColumn):
     return _create_categorical_column_weighted_sum(
@@ -1709,16 +1729,28 @@ def _create_weighted_sum(
         builder=builder,
         units=units,
         weight_collections=weight_collections,
-        trainable=trainable)
+        trainable=trainable,
+        state=state)
 
 
-def _create_dense_column_weighted_sum(
-    column, builder, units, weight_collections, trainable):
+def _create_dense_column_weighted_sum(column,
+                                      builder,
+                                      units,
+                                      weight_collections,
+                                      trainable,
+                                      state=None):
   """Create a weighted sum of a dense column for linear_model."""
-  tensor = column._get_dense_tensor(  # pylint: disable=protected-access
-      builder,
-      weight_collections=weight_collections,
-      trainable=trainable)
+  if state is not None:
+    tensor = column._get_dense_tensor(  # pylint: disable=protected-access
+        builder,
+        weight_collections=weight_collections,
+        trainable=trainable,
+        state=state)
+  else:
+    tensor = column._get_dense_tensor(  # pylint: disable=protected-access
+        builder,
+        weight_collections=weight_collections,
+        trainable=trainable)
   num_elements = column._variable_shape.num_elements()  # pylint: disable=protected-access
   batch_size = array_ops.shape(tensor)[0]
   tensor = array_ops.reshape(tensor, shape=(batch_size, num_elements))
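
Note that `state` is forwarded only when it was actually provided, so
third-party columns whose `_get_dense_tensor` predates the new signature
keep working through this path. Calling such a column with the new keyword
would fail, roughly (illustrative, not from this change):

    # A column still implementing the old signature:
    #   def _get_dense_tensor(self, inputs, weight_collections=None,
    #                         trainable=None): ...
    # would raise on:
    #   old_column._get_dense_tensor(builder, state=state)
    # TypeError: _get_dense_tensor() got an unexpected keyword argument 'state'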
@@ -2195,8 +2227,33 @@ class _EmbeddingColumn(
       self._shape = tensor_shape.vector(self.dimension)
     return self._shape
 
-  def _get_dense_tensor_internal(
-      self, inputs, weight_collections=None, trainable=None):
+  def _create_state(self, weight_collections=None, creator=None):
+    variables_map = {}
+    embedding_shape = (self.categorical_column._num_buckets, self.dimension)  # pylint: disable=protected-access
+    if creator is not None:
+      embedding_weights = creator(
+          name='embedding_weights',
+          shape=embedding_shape,
+          dtype=dtypes.float32,
+          initializer=self.initializer,
+          trainable=self.trainable)
+      ops.add_to_collections(weight_collections, embedding_weights)
+    else:
+      embedding_weights = variable_scope.get_variable(
+          name='embedding_weights',
+          shape=embedding_shape,
+          dtype=dtypes.float32,
+          initializer=self.initializer,
+          trainable=self.trainable,
+          collections=weight_collections)
+    variables_map['embedding_weights'] = embedding_weights
+    return variables_map
+
+  def _get_dense_tensor_internal(self,
+                                 inputs,
+                                 weight_collections=None,
+                                 trainable=None,
+                                 state=None):
     """Private method that follows the signature of _get_dense_tensor."""
     # Get sparse IDs and weights.
     sparse_tensors = self.categorical_column._get_sparse_tensors(  # pylint: disable=protected-access
@@ -2204,14 +2261,10 @@ class _EmbeddingColumn(
     sparse_ids = sparse_tensors.id_tensor
     sparse_weights = sparse_tensors.weight_tensor
 
-    embedding_shape = (self.categorical_column._num_buckets, self.dimension)  # pylint: disable=protected-access
-    embedding_weights = variable_scope.get_variable(
-        name='embedding_weights',
-        shape=embedding_shape,
-        dtype=dtypes.float32,
-        initializer=self.initializer,
-        trainable=self.trainable and trainable,
-        collections=weight_collections)
+    if state is None:
+      state = self._create_state(weight_collections)
+    embedding_weights = state['embedding_weights']
+
     if self.ckpt_to_load_from is not None:
       to_restore = embedding_weights
       if isinstance(to_restore, variables.PartitionedVariable):
@@ -2229,7 +2282,11 @@ class _EmbeddingColumn(
         name='%s_weights' % self.name,
         max_norm=self.max_norm)
 
-  def _get_dense_tensor(self, inputs, weight_collections=None, trainable=None):
+  def _get_dense_tensor(self,
+                        inputs,
+                        weight_collections=None,
+                        trainable=None,
+                        state=None):
     if isinstance(self.categorical_column, _SequenceCategoricalColumn):
       raise ValueError(
           'In embedding_column: {}. '
@@ -2242,8 +2299,10 @@ class _EmbeddingColumn(
               self.name, type(self.categorical_column),
               self.categorical_column))
     return self._get_dense_tensor_internal(
-        inputs=inputs, weight_collections=weight_collections,
-        trainable=trainable)
+        inputs=inputs,
+        weight_collections=weight_collections,
+        trainable=trainable,
+        state=state)
 
   def _get_sequence_dense_tensor(
       self, inputs, weight_collections=None, trainable=None):
@@ -2299,7 +2358,39 @@ class _SharedEmbeddingColumn(
       self._shape = tensor_shape.vector(self.dimension)
     return self._shape
 
-  def _get_dense_tensor(self, inputs, weight_collections=None, trainable=None):
+  def _create_state(self, weight_collections=None, creator=None):
+    variables_map = {}
+    shared_embedding_collection = ops.get_collection(
+        self.shared_embedding_collection_name)
+    if not shared_embedding_collection:
+      embedding_shape = (self.categorical_column._num_buckets, self.dimension)  # pylint: disable=protected-access
+      if creator is not None:
+        embedding_weights = creator(
+            name='embedding_weights',
+            shape=embedding_shape,
+            dtype=dtypes.float32,
+            initializer=self.initializer,
+            trainable=self.trainable)
+        ops.add_to_collections(weight_collections, embedding_weights)
+      else:
+        embedding_weights = variable_scope.get_variable(
+            name='embedding_weights',
+            shape=embedding_shape,
+            dtype=dtypes.float32,
+            initializer=self.initializer,
+            trainable=self.trainable,
+            collections=weight_collections)
+      ops.add_to_collection(self.shared_embedding_collection_name,
+                            embedding_weights)
+      variables_map['embedding_weights'] = embedding_weights
+
+    return variables_map
+
+  def _get_dense_tensor(self,
+                        inputs,
+                        weight_collections=None,
+                        trainable=None,
+                        state=None):
     # This method is called from a variable_scope with name _var_scope_name,
     # which is shared among all shared embeddings. Open a name_scope here, so
     # that the ops for different columns have distinct names.
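
From the `_create_state` overrides above, a `creator` is expected to accept
the keyword arguments `name`, `shape`, `dtype`, `initializer`, and
`trainable`; `_create_state` then takes care of collections via
`ops.add_to_collections`. (`_SharedEmbeddingColumn._create_state`
additionally registers the new variable in the shared embedding collection
so that sibling columns can find it.) A minimal sketch of a conforming
creator (hypothetical; not part of this change):

    def my_creator(name, shape, dtype, initializer, trainable):
      # Any variable source works here; get_variable under a custom scope
      # is just a stand-in.
      with variable_scope.variable_scope('my_scope'):
        return variable_scope.get_variable(
            name=name,
            shape=shape,
            dtype=dtype,
            initializer=initializer,
            trainable=trainable)

    state = embedding_column._create_state(creator=my_creator)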
diff --git a/tensorflow/python/feature_column/feature_column_test.py b/tensorflow/python/feature_column/feature_column_test.py
index 6f366e7..4006a76 100644
--- a/tensorflow/python/feature_column/feature_column_test.py
+++ b/tensorflow/python/feature_column/feature_column_test.py
@@ -3733,6 +3733,70 @@ class EmbeddingColumnTest(test.TestCase):
       self.assertAllEqual(embedding_values, global_vars[0].eval())
       self.assertAllEqual(expected_lookups, embedding_lookup.eval())
 
+  def test_get_dense_tensor_with_state(self):
+    # Inputs.
+    vocabulary_size = 3
+    sparse_input = sparse_tensor.SparseTensorValue(
+        # example 0, ids [2]
+        # example 1, ids [0, 1]
+        # example 2, ids []
+        # example 3, ids [1]
+        indices=((0, 0), (1, 0), (1, 4), (3, 0)),
+        values=(2, 0, 1, 1),
+        dense_shape=(4, 5))
+
+    # Embedding variable.
+    embedding_dimension = 2
+    embedding_values = (
+        (1., 2.),  # id 0
+        (3., 5.),  # id 1
+        (7., 11.)  # id 2
+    )
+    def _initializer(shape, dtype, partition_info):
+      self.assertAllEqual((vocabulary_size, embedding_dimension), shape)
+      self.assertEqual(dtypes.float32, dtype)
+      self.assertIsNone(partition_info)
+      return embedding_values
+
+    # Expected lookup result, using combiner='mean'.
+    expected_lookups = (
+        # example 0, ids [2], embedding = [7, 11]
+        (7., 11.),
+        # example 1, ids [0, 1], embedding = mean([1, 2] + [3, 5]) = [2, 3.5]
+        (2., 3.5),
+        # example 2, ids [], embedding = [0, 0]
+        (0., 0.),
+        # example 3, ids [1], embedding = [3, 5]
+        (3., 5.),
+    )
+
+    # Build columns.
+    categorical_column = fc.categorical_column_with_identity(
+        key='aaa', num_buckets=vocabulary_size)
+    embedding_column = fc.embedding_column(
+        categorical_column, dimension=embedding_dimension,
+        initializer=_initializer)
+
+    # Create embedding_weights variable.
+    weight_collections = [ops.GraphKeys.GLOBAL_VARIABLES,
+                          ops.GraphKeys.MODEL_VARIABLES]
+    state = embedding_column._create_state(weight_collections)
+
+    # Provide sparse input and get dense result.
+    embedding_lookup = embedding_column._get_dense_tensor(
+        _LazyBuilder({
+            'aaa': sparse_input
+        }),
+        state=state)
+
+    # Assert expected embedding variable and lookups.
+    global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
+    self.assertItemsEqual(
+        ('embedding_weights:0',), tuple([v.name for v in global_vars]))
+    with _initialized_session():
+      self.assertAllEqual(embedding_values, global_vars[0].eval())
+      self.assertAllEqual(expected_lookups, embedding_lookup.eval())
+
   def test_get_dense_tensor_3d(self):
     # Inputs.
     vocabulary_size = 4
@@ -4453,6 +4517,78 @@ class SharedEmbeddingColumnTest(test.TestCase):
       self.assertAllEqual(expected_lookups_a, embedding_lookup_a.eval())
       self.assertAllEqual(expected_lookups_b, embedding_lookup_b.eval())
 
+  def test_get_dense_tensor_with_state(self):
+    # Inputs.
+    vocabulary_size = 3
+    # -1 values are ignored.
+    input_a = np.array(
+        [[2, -1, -1],  # example 0, ids [2]
+         [0, 1, -1]])  # example 1, ids [0, 1]
+    input_b = np.array(
+        [[0, -1, -1],  # example 0, ids [0]
+         [-1, -1, -1]])  # example 1, ids []
+    input_features = {
+        'aaa': input_a,
+        'bbb': input_b
+    }
+
+    # Embedding variable.
+    embedding_dimension = 2
+    embedding_values = (
+        (1., 2.),  # id 0
+        (3., 5.),  # id 1
+        (7., 11.)  # id 2
+    )
+    def _initializer(shape, dtype, partition_info):
+      self.assertAllEqual((vocabulary_size, embedding_dimension), shape)
+      self.assertEqual(dtypes.float32, dtype)
+      self.assertIsNone(partition_info)
+      return embedding_values
+
+    # Expected lookup result, using combiner='mean'.
+    expected_lookups_a = (
+        # example 0:
+        (7., 11.),  # ids [2], embedding = [7, 11]
+        # example 1:
+        (2., 3.5),  # ids [0, 1], embedding = mean([1, 2] + [3, 5]) = [2, 3.5]
+    )
+    expected_lookups_b = (
+        # example 0:
+        (1., 2.),  # ids [0], embedding = [1, 2]
+        # example 1:
+        (0., 0.),  # ids [], embedding = [0, 0]
+    )
+
+    # Build columns.
+    categorical_column_a = fc.categorical_column_with_identity(
+        key='aaa', num_buckets=vocabulary_size)
+    categorical_column_b = fc.categorical_column_with_identity(
+        key='bbb', num_buckets=vocabulary_size)
+    embedding_column_a, embedding_column_b = fc.shared_embedding_columns(
+        [categorical_column_a, categorical_column_b],
+        dimension=embedding_dimension, initializer=_initializer)
+
+    # Create state.
+    weight_collections = [ops.GraphKeys.GLOBAL_VARIABLES,
+                          ops.GraphKeys.MODEL_VARIABLES]
+    state = embedding_column_a._create_state(weight_collections)
+
+    # Provide sparse input and get dense result.
+    embedding_lookup_a = embedding_column_a._get_dense_tensor(
+        _LazyBuilder(input_features), state=state)
+    embedding_lookup_b = embedding_column_b._get_dense_tensor(
+        _LazyBuilder(input_features), state=state)
+
+    # Assert expected embedding variable and lookups.
+    global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
+    self.assertItemsEqual(
+        ('embedding_weights:0',), tuple([v.name for v in global_vars]))
+    embedding_var = global_vars[0]
+    with _initialized_session():
+      self.assertAllEqual(embedding_values, embedding_var.eval())
+      self.assertAllEqual(expected_lookups_a, embedding_lookup_a.eval())
+      self.assertAllEqual(expected_lookups_b, embedding_lookup_b.eval())
+
   def test_get_dense_tensor_placeholder_inputs(self):
     # Inputs.
     vocabulary_size = 3