Sequence versions of remaining categorical columns
authorA. Unique TensorFlower <gardener@tensorflow.org>
Tue, 6 Mar 2018 19:23:41 +0000 (11:23 -0800)
committerTensorFlower Gardener <gardener@tensorflow.org>
Tue, 6 Mar 2018 19:30:53 +0000 (11:30 -0800)
PiperOrigin-RevId: 188051821

tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column.py
tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_test.py

index b25d7e5..f57557c 100644 (file)
@@ -132,7 +132,6 @@ def sequence_input_layer(
     return array_ops.concat(output_tensors, -1), sequence_length
 
 
-# TODO(b/73160931): Add remaining categorical columns.
 def sequence_categorical_column_with_identity(
     key, num_buckets, default_value=None):
   """Returns a feature column that represents sequences of integers.
@@ -143,7 +142,7 @@ def sequence_categorical_column_with_identity(
   watches = sequence_categorical_column_with_identity(
       'watches', num_buckets=1000)
   watches_embedding = embedding_column(watches, dimension=10)
-  columns = [watches]
+  columns = [watches_embedding]
 
   features = tf.parse_example(..., features=make_parse_example_spec(columns))
   input_layer, sequence_length = sequence_input_layer(features, columns)
@@ -171,6 +170,141 @@ def sequence_categorical_column_with_identity(
           default_value=default_value))
 
 
+def sequence_categorical_column_with_hash_bucket(
+    key, hash_bucket_size, dtype=dtypes.string):
+  """A sequence of categorical terms where ids are set by hashing.
+
+  Example:
+
+  ```python
+  tokens = sequence_categorical_column_with_hash_bucket(
+      'tokens', hash_bucket_size=1000)
+  tokens_embedding = embedding_column(tokens, dimension=10)
+  columns = [tokens_embedding]
+
+  features = tf.parse_example(..., features=make_parse_example_spec(columns))
+  input_layer, sequence_length = sequence_input_layer(features, columns)
+
+  rnn_cell = tf.nn.rnn_cell.BasicRNNCell(hidden_size)
+  outputs, state = tf.nn.dynamic_rnn(
+      rnn_cell, inputs=input_layer, sequence_length=sequence_length)
+  ```
+
+  Args:
+    key: A unique string identifying the input feature.
+    hash_bucket_size: An int > 1. The number of buckets.
+    dtype: The type of features. Only string and integer types are supported.
+
+  Returns:
+    A `_SequenceCategoricalColumn`.
+  """
+  return _SequenceCategoricalColumn(
+      fc.categorical_column_with_hash_bucket(
+          key=key,
+          hash_bucket_size=hash_bucket_size,
+          dtype=dtype))
+
+
+def sequence_categorical_column_with_vocabulary_file(
+    key, vocabulary_file, vocabulary_size=None, num_oov_buckets=0,
+    default_value=None, dtype=dtypes.string):
+  """A sequence of categorical terms where ids use a vocabulary file.
+
+  Example:
+
+  ```python
+  states = sequence_categorical_column_with_vocabulary_file(
+      key='states', vocabulary_file='/us/states.txt', vocabulary_size=50,
+      num_oov_buckets=5)
+  states_embedding = embedding_column(states, dimension=10)
+  columns = [states_embedding]
+
+  features = tf.parse_example(..., features=make_parse_example_spec(columns))
+  input_layer, sequence_length = sequence_input_layer(features, columns)
+
+  rnn_cell = tf.nn.rnn_cell.BasicRNNCell(hidden_size)
+  outputs, state = tf.nn.dynamic_rnn(
+      rnn_cell, inputs=input_layer, sequence_length=sequence_length)
+  ```
+
+  Args:
+    key: A unique string identifying the input feature.
+    vocabulary_file: The vocabulary file name.
+    vocabulary_size: Number of the elements in the vocabulary. This must be no
+      greater than length of `vocabulary_file`, if less than length, later
+      values are ignored. If None, it is set to the length of `vocabulary_file`.
+    num_oov_buckets: Non-negative integer, the number of out-of-vocabulary
+      buckets. All out-of-vocabulary inputs will be assigned IDs in the range
+      `[vocabulary_size, vocabulary_size+num_oov_buckets)` based on a hash of
+      the input value. A positive `num_oov_buckets` can not be specified with
+      `default_value`.
+    default_value: The integer ID value to return for out-of-vocabulary feature
+      values, defaults to `-1`. This can not be specified with a positive
+      `num_oov_buckets`.
+    dtype: The type of features. Only string and integer types are supported.
+
+  Returns:
+    A `_SequenceCategoricalColumn`.
+  """
+  return _SequenceCategoricalColumn(
+      fc.categorical_column_with_vocabulary_file(
+          key=key,
+          vocabulary_file=vocabulary_file,
+          vocabulary_size=vocabulary_size,
+          num_oov_buckets=num_oov_buckets,
+          default_value=default_value,
+          dtype=dtype))
+
+
+def sequence_categorical_column_with_vocabulary_list(
+    key, vocabulary_list, dtype=None, default_value=-1, num_oov_buckets=0):
+  """A sequence of categorical terms where ids use an in-memory list.
+
+  Example:
+
+  ```python
+  colors = sequence_categorical_column_with_vocabulary_list(
+      key='colors', vocabulary_list=('R', 'G', 'B', 'Y'),
+      num_oov_buckets=2)
+  colors_embedding = embedding_column(colors, dimension=3)
+  columns = [colors_embedding]
+
+  features = tf.parse_example(..., features=make_parse_example_spec(columns))
+  input_layer, sequence_length = sequence_input_layer(features, columns)
+
+  rnn_cell = tf.nn.rnn_cell.BasicRNNCell(hidden_size)
+  outputs, state = tf.nn.dynamic_rnn(
+      rnn_cell, inputs=input_layer, sequence_length=sequence_length)
+  ```
+
+  Args:
+    key: A unique string identifying the input feature.
+    vocabulary_list: An ordered iterable defining the vocabulary. Each feature
+      is mapped to the index of its value (if present) in `vocabulary_list`.
+      Must be castable to `dtype`.
+    dtype: The type of features. Only string and integer types are supported.
+      If `None`, it will be inferred from `vocabulary_list`.
+    default_value: The integer ID value to return for out-of-vocabulary feature
+      values, defaults to `-1`. This can not be specified with a positive
+      `num_oov_buckets`.
+    num_oov_buckets: Non-negative integer, the number of out-of-vocabulary
+      buckets. All out-of-vocabulary inputs will be assigned IDs in the range
+      `[len(vocabulary_list), len(vocabulary_list)+num_oov_buckets)` based on a
+      hash of the input value. A positive `num_oov_buckets` can not be specified
+      with `default_value`.
+
+  Returns:
+    A `_SequenceCategoricalColumn`.
+  """
+  return _SequenceCategoricalColumn(
+      fc.categorical_column_with_vocabulary_list(
+          key=key,
+          vocabulary_list=vocabulary_list,
+          dtype=dtype,
+          default_value=default_value,
+          num_oov_buckets=num_oov_buckets))
+
+
 # TODO(b/73160931): Merge with embedding_column
 def _sequence_embedding_column(
     categorical_column, dimension, initializer=None, ckpt_to_load_from=None,
index 5c1e76f..c077f03 100644 (file)
@@ -18,6 +18,7 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+import os
 import numpy as np
 
 from tensorflow.contrib.feature_column.python.feature_column import sequence_feature_column as sfc
@@ -230,13 +231,17 @@ class SequenceInputLayerTest(test.TestCase):
 
 
 def _assert_sparse_tensor_value(test_case, expected, actual):
-  test_case.assertEqual(np.int64, np.array(actual.indices).dtype)
-  test_case.assertAllEqual(expected.indices, actual.indices)
+  _assert_sparse_tensor_indices_shape(test_case, expected, actual)
 
   test_case.assertEqual(
       np.array(expected.values).dtype, np.array(actual.values).dtype)
   test_case.assertAllEqual(expected.values, actual.values)
 
+
+def _assert_sparse_tensor_indices_shape(test_case, expected, actual):
+  test_case.assertEqual(np.int64, np.array(actual.indices).dtype)
+  test_case.assertAllEqual(expected.indices, actual.indices)
+
   test_case.assertEqual(np.int64, np.array(actual.dense_shape).dtype)
   test_case.assertAllEqual(expected.dense_shape, actual.dense_shape)
 
@@ -314,6 +319,145 @@ class SequenceCategoricalColumnWithIdentityTest(test.TestCase):
           expected_sequence_length, sequence_length.eval(session=sess))
 
 
+class SequenceCategoricalColumnWithHashBucketTest(test.TestCase):
+
+  def test_get_sparse_tensors(self):
+    column = sfc.sequence_categorical_column_with_hash_bucket(
+        'aaa', hash_bucket_size=10)
+    inputs = sparse_tensor.SparseTensorValue(
+        indices=((0, 0), (1, 0), (1, 1)),
+        values=('omar', 'stringer', 'marlo'),
+        dense_shape=(2, 2))
+
+    expected_sparse_ids = sparse_tensor.SparseTensorValue(
+        indices=((0, 0, 0), (1, 0, 0), (1, 1, 0)),
+        # Ignored to avoid hash dependence in test.
+        values=np.array((0, 0, 0), dtype=np.int64),
+        dense_shape=(2, 2, 1))
+
+    id_weight_pair = column._get_sparse_tensors(_LazyBuilder({'aaa': inputs}))
+
+    self.assertIsNone(id_weight_pair.weight_tensor)
+    with monitored_session.MonitoredSession() as sess:
+      _assert_sparse_tensor_indices_shape(
+          self,
+          expected_sparse_ids,
+          id_weight_pair.id_tensor.eval(session=sess))
+
+  def test_sequence_length(self):
+    column = sfc.sequence_categorical_column_with_hash_bucket(
+        'aaa', hash_bucket_size=10)
+    inputs = sparse_tensor.SparseTensorValue(
+        indices=((0, 0), (1, 0), (1, 1)),
+        values=('omar', 'stringer', 'marlo'),
+        dense_shape=(2, 2))
+    expected_sequence_length = [1, 2]
+
+    sequence_length = column._sequence_length(_LazyBuilder({'aaa': inputs}))
+
+    with monitored_session.MonitoredSession() as sess:
+      self.assertAllEqual(
+          expected_sequence_length, sequence_length.eval(session=sess))
+
+
+class SequenceCategoricalColumnWithVocabularyFileTest(test.TestCase):
+
+  def _write_vocab(self, vocab_strings, file_name):
+    vocab_file = os.path.join(self.get_temp_dir(), file_name)
+    with open(vocab_file, 'w') as f:
+      f.write('\n'.join(vocab_strings))
+    return vocab_file
+
+  def setUp(self):
+    super(SequenceCategoricalColumnWithVocabularyFileTest, self).setUp()
+
+    vocab_strings = ['omar', 'stringer', 'marlo']
+    self._wire_vocabulary_file_name = self._write_vocab(vocab_strings,
+                                                        'wire_vocabulary.txt')
+    self._wire_vocabulary_size = 3
+
+  def test_get_sparse_tensors(self):
+    column = sfc.sequence_categorical_column_with_vocabulary_file(
+        key='aaa',
+        vocabulary_file=self._wire_vocabulary_file_name,
+        vocabulary_size=self._wire_vocabulary_size)
+    inputs = sparse_tensor.SparseTensorValue(
+        indices=((0, 0), (1, 0), (1, 1)),
+        values=('marlo', 'skywalker', 'omar'),
+        dense_shape=(2, 2))
+    expected_sparse_ids = sparse_tensor.SparseTensorValue(
+        indices=((0, 0, 0), (1, 0, 0), (1, 1, 0)),
+        values=np.array((2, -1, 0), dtype=np.int64),
+        dense_shape=(2, 2, 1))
+
+    id_weight_pair = column._get_sparse_tensors(_LazyBuilder({'aaa': inputs}))
+
+    self.assertIsNone(id_weight_pair.weight_tensor)
+    with monitored_session.MonitoredSession() as sess:
+      _assert_sparse_tensor_value(
+          self,
+          expected_sparse_ids,
+          id_weight_pair.id_tensor.eval(session=sess))
+
+  def test_sequence_length(self):
+    column = sfc.sequence_categorical_column_with_vocabulary_file(
+        key='aaa',
+        vocabulary_file=self._wire_vocabulary_file_name,
+        vocabulary_size=self._wire_vocabulary_size)
+    inputs = sparse_tensor.SparseTensorValue(
+        indices=((0, 0), (1, 0), (1, 1)),
+        values=('marlo', 'skywalker', 'omar'),
+        dense_shape=(2, 2))
+    expected_sequence_length = [1, 2]
+
+    sequence_length = column._sequence_length(_LazyBuilder({'aaa': inputs}))
+
+    with monitored_session.MonitoredSession() as sess:
+      self.assertAllEqual(
+          expected_sequence_length, sequence_length.eval(session=sess))
+
+
+class SequenceCategoricalColumnWithVocabularyListTest(test.TestCase):
+
+  def test_get_sparse_tensors(self):
+    column = sfc.sequence_categorical_column_with_vocabulary_list(
+        key='aaa',
+        vocabulary_list=('omar', 'stringer', 'marlo'))
+    inputs = sparse_tensor.SparseTensorValue(
+        indices=((0, 0), (1, 0), (1, 1)),
+        values=('marlo', 'skywalker', 'omar'),
+        dense_shape=(2, 2))
+    expected_sparse_ids = sparse_tensor.SparseTensorValue(
+        indices=((0, 0, 0), (1, 0, 0), (1, 1, 0)),
+        values=np.array((2, -1, 0), dtype=np.int64),
+        dense_shape=(2, 2, 1))
+
+    id_weight_pair = column._get_sparse_tensors(_LazyBuilder({'aaa': inputs}))
+
+    self.assertIsNone(id_weight_pair.weight_tensor)
+    with monitored_session.MonitoredSession() as sess:
+      _assert_sparse_tensor_value(
+          self,
+          expected_sparse_ids,
+          id_weight_pair.id_tensor.eval(session=sess))
+
+  def test_sequence_length(self):
+    column = sfc.sequence_categorical_column_with_vocabulary_list(
+        key='aaa',
+        vocabulary_list=('omar', 'stringer', 'marlo'))
+    inputs = sparse_tensor.SparseTensorValue(
+        indices=((0, 0), (1, 0), (1, 1)),
+        values=('marlo', 'skywalker', 'omar'),
+        dense_shape=(2, 2))
+    expected_sequence_length = [1, 2]
+
+    sequence_length = column._sequence_length(_LazyBuilder({'aaa': inputs}))
+
+    with monitored_session.MonitoredSession() as sess:
+      self.assertAllEqual(
+          expected_sequence_length, sequence_length.eval(session=sess))
+
+
 class SequenceEmbeddingColumnTest(test.TestCase):
 
   def test_get_sequence_dense_tensor(self):