"""
pass
+ def _create_state(self, weight_collections=None, creator=None):
+ """Returns an object that captures the state of the column.
+
+ Args:
+ weight_collections: Collections to which the state variables are added.
+ creator: Variable creator function used to create the state, if provided.
+
+ Returns:
+ An object that encapsulates the state of the column. May be None.
+ """
+ del weight_collections, creator # Unused
+ return None
+
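A minimal sketch of the intended state plumbing, assuming TF 1.x graph mode; the column, feature key, and input tensors below are illustrative, and only _create_state and the new state argument of _get_dense_tensor come from this change:

import tensorflow as tf
from tensorflow.python.feature_column import feature_column as fc

# Hypothetical column and input, mirroring the tests further down.
column = fc.embedding_column(
    fc.categorical_column_with_identity(key='aaa', num_buckets=3),
    dimension=2)
builder = fc._LazyBuilder({'aaa': tf.constant([[0], [2]], dtype=tf.int64)})

# Create the column state once, then hand it to the lookup. `state` may be
# None for columns that keep no state (the base-class default above).
state = column._create_state(
    weight_collections=[tf.GraphKeys.GLOBAL_VARIABLES])
dense_tensor = column._get_dense_tensor(builder, state=state)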
class _DenseColumn(_FeatureColumn):
"""Represents a column which can be represented as `Tensor`.
pass
@abc.abstractmethod
- def _get_dense_tensor(self, inputs, weight_collections=None, trainable=None):
+ def _get_dense_tensor(self,
+ inputs,
+ weight_collections=None,
+ trainable=None,
+ state=None):
"""Returns a `Tensor`.
The output of this function will be used by model-builder-functions. For
will be created) are added.
trainable: If `True` also add variables to the graph collection
`GraphKeys.TRAINABLE_VARIABLES` (see @{tf.Variable}).
+ state: An object encapsulating the state of the column. Columns that
+ create state with the _create_state method will have that state
+ passed to this method.
Returns:
`Tensor` of shape [batch_size] + `_variable_shape`.
pass
-def _create_weighted_sum(
- column,
- builder,
- units,
- sparse_combiner,
- weight_collections,
- trainable):
+def _create_weighted_sum(column,
+ builder,
+ units,
+ sparse_combiner,
+ weight_collections,
+ trainable,
+ state=None):
"""Creates a weighted sum for a dense or sparse column for linear_model."""
if isinstance(column, _CategoricalColumn):
return _create_categorical_column_weighted_sum(
builder=builder,
units=units,
weight_collections=weight_collections,
- trainable=trainable)
+ trainable=trainable,
+ state=state)
-def _create_dense_column_weighted_sum(
- column, builder, units, weight_collections, trainable):
+def _create_dense_column_weighted_sum(column,
+ builder,
+ units,
+ weight_collections,
+ trainable,
+ state=None):
"""Create a weighted sum of a dense column for linear_model."""
- tensor = column._get_dense_tensor( # pylint: disable=protected-access
- builder,
- weight_collections=weight_collections,
- trainable=trainable)
+ # Forward the column state only when one was provided; otherwise call
+ # _get_dense_tensor with its original signature.
+ if state is not None:
+ tensor = column._get_dense_tensor( # pylint: disable=protected-access
+ builder,
+ weight_collections=weight_collections,
+ trainable=trainable,
+ state=state)
+ else:
+ tensor = column._get_dense_tensor( # pylint: disable=protected-access
+ builder,
+ weight_collections=weight_collections,
+ trainable=trainable)
num_elements = column._variable_shape.num_elements() # pylint: disable=protected-access
batch_size = array_ops.shape(tensor)[0]
tensor = array_ops.reshape(tensor, shape=(batch_size, num_elements))
self._shape = tensor_shape.vector(self.dimension)
return self._shape
- def _get_dense_tensor_internal(
- self, inputs, weight_collections=None, trainable=None):
+ def _create_state(self, weight_collections=None, creator=None):
+ variables_map = {}
+ embedding_shape = (self.categorical_column._num_buckets, self.dimension) # pylint: disable=protected-access
+ if creator is not None:
+ embedding_weights = creator(
+ name='embedding_weights',
+ shape=embedding_shape,
+ dtype=dtypes.float32,
+ initializer=self.initializer,
+ trainable=self.trainable)
+ ops.add_to_collections(weight_collections, embedding_weights)
+ else:
+ embedding_weights = variable_scope.get_variable(
+ name='embedding_weights',
+ shape=embedding_shape,
+ dtype=dtypes.float32,
+ initializer=self.initializer,
+ trainable=self.trainable,
+ collections=weight_collections)
+ variables_map['embedding_weights'] = embedding_weights
+ return variables_map
+
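For reference, a creator passed to _create_state must accept the keyword arguments used in the call above (name, shape, dtype, initializer, trainable) and return a variable-like object. A hedged sketch with a hypothetical sharded creator, reusing the tf/fc imports from the earlier sketch; none of these names come from this change:

def _sharded_creator(name, shape, dtype, initializer, trainable):
  # Hypothetical creator: shards the embedding matrix into two partitions.
  return tf.get_variable(
      name=name,
      shape=shape,
      dtype=dtype,
      initializer=initializer,
      trainable=trainable,
      partitioner=tf.fixed_size_partitioner(2))

embedding_column = fc.embedding_column(
    fc.categorical_column_with_identity(key='aaa', num_buckets=3),
    dimension=2)
state = embedding_column._create_state(
    weight_collections=[tf.GraphKeys.GLOBAL_VARIABLES],
    creator=_sharded_creator)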
+ def _get_dense_tensor_internal(self,
+ inputs,
+ weight_collections=None,
+ trainable=None,
+ state=None):
"""Private method that follows the signature of _get_dense_tensor."""
# Get sparse IDs and weights.
sparse_tensors = self.categorical_column._get_sparse_tensors( # pylint: disable=protected-access
sparse_ids = sparse_tensors.id_tensor
sparse_weights = sparse_tensors.weight_tensor
- embedding_shape = (self.categorical_column._num_buckets, self.dimension) # pylint: disable=protected-access
- embedding_weights = variable_scope.get_variable(
- name='embedding_weights',
- shape=embedding_shape,
- dtype=dtypes.float32,
- initializer=self.initializer,
- trainable=self.trainable and trainable,
- collections=weight_collections)
+ if state is None:
+ state = self._create_state(weight_collections)
+ embedding_weights = state['embedding_weights']
+
if self.ckpt_to_load_from is not None:
to_restore = embedding_weights
if isinstance(to_restore, variables.PartitionedVariable):
name='%s_weights' % self.name,
max_norm=self.max_norm)
- def _get_dense_tensor(self, inputs, weight_collections=None, trainable=None):
+ def _get_dense_tensor(self,
+ inputs,
+ weight_collections=None,
+ trainable=None,
+ state=None):
if isinstance(self.categorical_column, _SequenceCategoricalColumn):
raise ValueError(
'In embedding_column: {}. '
self.name, type(self.categorical_column),
self.categorical_column))
return self._get_dense_tensor_internal(
- inputs=inputs, weight_collections=weight_collections,
- trainable=trainable)
+ inputs=inputs,
+ weight_collections=weight_collections,
+ trainable=trainable,
+ state=state)
def _get_sequence_dense_tensor(
self, inputs, weight_collections=None, trainable=None):
self._shape = tensor_shape.vector(self.dimension)
return self._shape
- def _get_dense_tensor(self, inputs, weight_collections=None, trainable=None):
+ def _create_state(self, weight_collections=None, creator=None):
+ variables_map = {}
+ shared_embedding_collection = ops.get_collection(
+ self.shared_embedding_collection_name)
+ # Create the shared variable at most once; if the collection already
+ # holds it, it was created earlier for this shared group.
+ if not shared_embedding_collection:
+ embedding_shape = (self.categorical_column._num_buckets, self.dimension) # pylint: disable=protected-access
+ if creator is not None:
+ embedding_weights = creator(
+ name='embedding_weights',
+ shape=embedding_shape,
+ dtype=dtypes.float32,
+ initializer=self.initializer,
+ trainable=self.trainable)
+ ops.add_to_collections(weight_collections, embedding_weights)
+ else:
+ embedding_weights = variable_scope.get_variable(
+ name='embedding_weights',
+ shape=embedding_shape,
+ dtype=dtypes.float32,
+ initializer=self.initializer,
+ trainable=self.trainable,
+ collections=weight_collections)
+ ops.add_to_collection(self.shared_embedding_collection_name,
+ embedding_weights)
+ variables_map['embedding_weights'] = embedding_weights
+
+ return variables_map
+
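As the shared-embedding test further down exercises, state created from one column of a shared group can be passed to the lookup of every column in that group; a compressed sketch with hypothetical feature keys, again reusing the tf/fc imports from the sketches above:

column_a, column_b = fc.shared_embedding_columns(
    [fc.categorical_column_with_identity(key='aaa', num_buckets=3),
     fc.categorical_column_with_identity(key='bbb', num_buckets=3)],
    dimension=2)
builder = fc._LazyBuilder({
    'aaa': tf.constant([[0], [2]], dtype=tf.int64),
    'bbb': tf.constant([[1], [0]], dtype=tf.int64),
})
# One state object serves both columns of the shared group.
state = column_a._create_state(
    weight_collections=[tf.GraphKeys.GLOBAL_VARIABLES])
lookup_a = column_a._get_dense_tensor(builder, state=state)
lookup_b = column_b._get_dense_tensor(builder, state=state)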
+ def _get_dense_tensor(self,
+ inputs,
+ weight_collections=None,
+ trainable=None,
+ state=None):
# This method is called from a variable_scope with name _var_scope_name,
# which is shared among all shared embeddings. Open a name_scope here, so
# that the ops for different columns have distinct names.
self.assertAllEqual(embedding_values, global_vars[0].eval())
self.assertAllEqual(expected_lookups, embedding_lookup.eval())
+ def test_get_dense_tensor_with_state(self):
+ # Inputs.
+ vocabulary_size = 3
+ sparse_input = sparse_tensor.SparseTensorValue(
+ # example 0, ids [2]
+ # example 1, ids [0, 1]
+ # example 2, ids []
+ # example 3, ids [1]
+ indices=((0, 0), (1, 0), (1, 4), (3, 0)),
+ values=(2, 0, 1, 1),
+ dense_shape=(4, 5))
+
+ # Embedding variable.
+ embedding_dimension = 2
+ embedding_values = (
+ (1., 2.), # id 0
+ (3., 5.), # id 1
+ (7., 11.) # id 2
+ )
+ def _initializer(shape, dtype, partition_info):
+ self.assertAllEqual((vocabulary_size, embedding_dimension), shape)
+ self.assertEqual(dtypes.float32, dtype)
+ self.assertIsNone(partition_info)
+ return embedding_values
+
+ # Expected lookup result, using combiner='mean'.
+ expected_lookups = (
+ # example 0, ids [2], embedding = [7, 11]
+ (7., 11.),
+ # example 1, ids [0, 1], embedding = mean([1, 2] + [3, 5]) = [2, 3.5]
+ (2., 3.5),
+ # example 2, ids [], embedding = [0, 0]
+ (0., 0.),
+ # example 3, ids [1], embedding = [3, 5]
+ (3., 5.),
+ )
+
+ # Build columns.
+ categorical_column = fc.categorical_column_with_identity(
+ key='aaa', num_buckets=vocabulary_size)
+ embedding_column = fc.embedding_column(
+ categorical_column, dimension=embedding_dimension,
+ initializer=_initializer)
+
+ # Create embedding_weights variable.
+ weight_collections = [ops.GraphKeys.GLOBAL_VARIABLES,
+ ops.GraphKeys.MODEL_VARIABLES]
+ state = embedding_column._create_state(weight_collections)
+
+ # Provide sparse input and get dense result.
+ embedding_lookup = embedding_column._get_dense_tensor(
+ _LazyBuilder({
+ 'aaa': sparse_input
+ }),
+ state=state)
+
+ # Assert expected embedding variable and lookups.
+ global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
+ self.assertItemsEqual(
+ ('embedding_weights:0',), tuple([v.name for v in global_vars]))
+ with _initialized_session():
+ self.assertAllEqual(embedding_values, global_vars[0].eval())
+ self.assertAllEqual(expected_lookups, embedding_lookup.eval())
+
def test_get_dense_tensor_3d(self):
# Inputs.
vocabulary_size = 4
self.assertAllEqual(expected_lookups_a, embedding_lookup_a.eval())
self.assertAllEqual(expected_lookups_b, embedding_lookup_b.eval())
+ def test_get_dense_tensor_with_state(self):
+ # Inputs.
+ vocabulary_size = 3
+ # -1 values are ignored.
+ input_a = np.array(
+ [[2, -1, -1], # example 0, ids [2]
+ [0, 1, -1]]) # example 1, ids [0, 1]
+ input_b = np.array(
+ [[0, -1, -1], # example 0, ids [0]
+ [-1, -1, -1]]) # example 1, ids []
+ input_features = {
+ 'aaa': input_a,
+ 'bbb': input_b
+ }
+
+ # Embedding variable.
+ embedding_dimension = 2
+ embedding_values = (
+ (1., 2.), # id 0
+ (3., 5.), # id 1
+ (7., 11.) # id 2
+ )
+ def _initializer(shape, dtype, partition_info):
+ self.assertAllEqual((vocabulary_size, embedding_dimension), shape)
+ self.assertEqual(dtypes.float32, dtype)
+ self.assertIsNone(partition_info)
+ return embedding_values
+
+ # Expected lookup result, using combiner='mean'.
+ expected_lookups_a = (
+ # example 0:
+ (7., 11.), # ids [2], embedding = [7, 11]
+ # example 1:
+ (2., 3.5), # ids [0, 1], embedding = mean([1, 2] + [3, 5]) = [2, 3.5]
+ )
+ expected_lookups_b = (
+ # example 0:
+ (1., 2.), # ids [0], embedding = [1, 2]
+ # example 1:
+ (0., 0.), # ids [], embedding = [0, 0]
+ )
+
+ # Build columns.
+ categorical_column_a = fc.categorical_column_with_identity(
+ key='aaa', num_buckets=vocabulary_size)
+ categorical_column_b = fc.categorical_column_with_identity(
+ key='bbb', num_buckets=vocabulary_size)
+ embedding_column_a, embedding_column_b = fc.shared_embedding_columns(
+ [categorical_column_a, categorical_column_b],
+ dimension=embedding_dimension, initializer=_initializer)
+
+ # Create state.
+ weight_collections = [ops.GraphKeys.GLOBAL_VARIABLES,
+ ops.GraphKeys.MODEL_VARIABLES]
+ state = embedding_column_a._create_state(weight_collections)
+
+ # Provide sparse input and get dense result.
+ embedding_lookup_a = embedding_column_a._get_dense_tensor(
+ _LazyBuilder(input_features), state=state)
+ embedding_lookup_b = embedding_column_b._get_dense_tensor(
+ _LazyBuilder(input_features), state=state)
+
+ # Assert expected embedding variable and lookups.
+ global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
+ self.assertItemsEqual(
+ ('embedding_weights:0',), tuple([v.name for v in global_vars]))
+ embedding_var = global_vars[0]
+ with _initialized_session():
+ self.assertAllEqual(embedding_values, embedding_var.eval())
+ self.assertAllEqual(expected_lookups_a, embedding_lookup_a.eval())
+ self.assertAllEqual(expected_lookups_b, embedding_lookup_b.eval())
+
def test_get_dense_tensor_placeholder_inputs(self):
# Inputs.
vocabulary_size = 3