1st version of sequential feature columns.
authorA. Unique TensorFlower <gardener@tensorflow.org>
Mon, 26 Feb 2018 22:25:30 +0000 (14:25 -0800)
committerGunhan Gulsoy <gunan@google.com>
Tue, 27 Feb 2018 22:33:33 +0000 (14:33 -0800)
PiperOrigin-RevId: 187080635

tensorflow/contrib/feature_column/BUILD
tensorflow/contrib/feature_column/python/feature_column/sequential_feature_column.py
tensorflow/contrib/feature_column/python/feature_column/sequential_feature_column_test.py [new file with mode: 0644]

index 6fc0537..a53e36c 100644 (file)
@@ -33,5 +33,34 @@ py_library(
     name = "sequential_feature_column",
     srcs = ["python/feature_column/sequential_feature_column.py"],
     srcs_version = "PY2AND3",
-    deps = [],
+    deps = [
+        "//tensorflow/python:array_ops",
+        "//tensorflow/python:check_ops",
+        "//tensorflow/python:dtypes",
+        "//tensorflow/python:framework_ops",
+        "//tensorflow/python:math_ops",
+        "//tensorflow/python:parsing_ops",
+        "//tensorflow/python:sparse_ops",
+        "//tensorflow/python:tensor_shape",
+        "//tensorflow/python:variable_scope",
+        "//tensorflow/python/feature_column",
+    ],
+)
+
+py_test(
+    name = "sequential_feature_column_test",
+    srcs = ["python/feature_column/sequential_feature_column_test.py"],
+    srcs_version = "PY2AND3",
+    tags = ["no_pip"],
+    deps = [
+        ":sequential_feature_column",
+        "//tensorflow/python:client_testlib",
+        "//tensorflow/python:dtypes",
+        "//tensorflow/python:errors",
+        "//tensorflow/python:framework_ops",
+        "//tensorflow/python:sparse_tensor",
+        "//tensorflow/python:training",
+        "//tensorflow/python/feature_column",
+        "//third_party/py/numpy",
+    ],
 )
index 690a44f..4ed7268 100644 (file)
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-"""Experimental methods for tf.feature_column sequential input."""
+"""Experimental methods for tf.feature_column sequence input."""
 
 from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
+
+
+import abc
+import collections
+
+
+from tensorflow.python.feature_column import feature_column as fc
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_shape
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import check_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import parsing_ops
+from tensorflow.python.ops import sparse_ops
+from tensorflow.python.ops import variable_scope
+
+# TODO(b/73160931): Fix pydoc.
+# pylint: disable=g-doc-args,missing-docstring,protected-access
+# TODO(b/73827486): Support SequenceExample.
+
+
+def sequence_input_layer(
+    features,
+    feature_columns,
+    weight_collections=None,
+    trainable=True,
+    scope=None):
+  """"Builds input layer for sequence input.
+
+  All `feature_columns` must be sequence dense columns with the same
+  `sequence_length`. The output of this method can be fed into sequence
+  networks, such as RNN.
+
+  The output of this method is a 3D `Tensor` of shape `[batch_size, T, D]`.
+  `T` is the maximum sequence length for this batch, which could differ from
+  batch to batch.
+
+  If multiple `feature_columns` are given with `Di` `num_elements` each, their
+  outputs are concatenated. So, the final `Tensor` has shape
+  `[batch_size, T, D0 + D1 + ... + Dn]`.
+
+  Example:
+
+  ```python
+  rating = sequence_numeric_column('rating')
+  watches = sequence_categorical_column_with_identity(
+      'watches', num_buckets=1000)
+  watches_embedding = embedding_column(watches, dimension=10)
+  columns = [rating, watches]
+
+  features = tf.parse_example(..., features=make_parse_example_spec(columns))
+  input_layer, sequence_length = sequence_input_layer(features, columns)
+
+  rnn_cell = tf.nn.rnn_cell.BasicRNNCell(hidden_size)
+  outputs, state = tf.nn.dynamic_rnn(
+      rnn_cell, inputs=input_layer, sequence_length=sequence_length)
+  ```
+
+  Returns:
+    An `(input_layer, sequence_length)` tuple where:
+    - input_layer: A float `Tensor` of shape `[batch_size, T, D]`.
+        `T` is the maximum sequence length for this batch, which could differ
+        from batch to batch. `D` is the sum of `num_elements` for all
+        `feature_columns`.
+    - sequence_length: An int `Tensor` of shape `[batch_size]`. The sequence
+        length for each example.
+  Raises:
+    ValueError: If any of the `feature_columns` is the wrong type.
+  """
+  feature_columns = fc._clean_feature_columns(feature_columns)
+  for c in feature_columns:
+    if not isinstance(c, _SequenceDenseColumn):
+      raise ValueError(
+          'All feature_columns must be of type _SequenceDenseColumn. '
+          'Given (type {}): {}'.format(type(c), c))
+
+  with variable_scope.variable_scope(
+      scope, default_name='sequence_input_layer', values=features.values()):
+    builder = fc._LazyBuilder(features)
+    output_tensors = []
+    sequence_lengths = []
+    ordered_columns = []
+    for column in sorted(feature_columns, key=lambda x: x.name):
+      ordered_columns.append(column)
+      with variable_scope.variable_scope(
+          None, default_name=column._var_scope_name):
+        dense_tensor, sequence_length = column._get_sequence_dense_tensor(
+            builder,
+            weight_collections=weight_collections,
+            trainable=trainable)
+        # Flattens the final dimension to produce a 3D Tensor.
+        num_elements = column._variable_shape.num_elements()
+        shape = array_ops.shape(dense_tensor)
+        output_tensors.append(
+            array_ops.reshape(
+                dense_tensor,
+                shape=array_ops.concat([shape[:2], [num_elements]], axis=0)))
+        sequence_lengths.append(sequence_length)
+    fc._verify_static_batch_size_equality(output_tensors, ordered_columns)
+    # TODO(b/73160931): Verify sequence_length equality.
+    return array_ops.concat(output_tensors, -1), sequence_lengths[0]
+
+
+# TODO(b/73160931): Add remaining categorical columns.
+def sequence_categorical_column_with_identity(
+    key, num_buckets, default_value=None):
+  return _SequenceCategoricalColumn(
+      fc.categorical_column_with_identity(
+          key=key,
+          num_buckets=num_buckets,
+          default_value=default_value))
+
+
+# TODO(b/73160931): Merge with embedding_column
+def _sequence_embedding_column(
+    categorical_column, dimension, initializer=None, ckpt_to_load_from=None,
+    tensor_name_in_ckpt=None, max_norm=None, trainable=True):
+  if not isinstance(categorical_column, _SequenceCategoricalColumn):
+    raise ValueError(
+        'categorical_column must be of type _SequenceCategoricalColumn. '
+        'Given (type {}): {}'.format(
+            type(categorical_column), categorical_column))
+  return _SequenceEmbeddingColumn(
+      fc.embedding_column(
+          categorical_column,
+          dimension=dimension,
+          initializer=initializer,
+          ckpt_to_load_from=ckpt_to_load_from,
+          tensor_name_in_ckpt=tensor_name_in_ckpt,
+          max_norm=max_norm,
+          trainable=trainable))
+
+
+def sequence_numeric_column(
+    key,
+    shape=(1,),
+    default_value=0.,
+    dtype=dtypes.float32):
+  # TODO(b/73160931): Add validations.
+  return _SequenceNumericColumn(
+      key,
+      shape=shape,
+      default_value=default_value,
+      dtype=dtype)
+
+
+class _SequenceDenseColumn(fc._FeatureColumn):
+  """Represents dense sequence data."""
+
+  __metaclass__ = abc.ABCMeta
+
+  TensorSequenceLengthPair = collections.namedtuple(  # pylint: disable=invalid-name
+      'TensorSequenceLengthPair', ['dense_tensor', 'sequence_length'])
+
+  @abc.abstractproperty
+  def _variable_shape(self):
+    """`TensorShape` without batch and sequence dimensions."""
+    pass
+
+  @abc.abstractmethod
+  def _get_sequence_dense_tensor(
+      self, inputs, weight_collections=None, trainable=None):
+    """Returns a `TensorSequenceLengthPair`."""
+    pass
+
+
+def _sequence_length_from_sparse_tensor(sp_tensor, num_elements=1):
+  with ops.name_scope(None, 'sequence_length') as name_scope:
+    row_ids = sp_tensor.indices[:, 0]
+    column_ids = sp_tensor.indices[:, 1]
+    column_ids += array_ops.ones_like(column_ids)
+    seq_length = (
+        math_ops.segment_max(column_ids, segment_ids=row_ids) / num_elements)
+    # If the last n rows do not have ids, seq_length will have shape
+    # [batch_size - n]. Pad the remaining values with zeros.
+    n_pad = array_ops.shape(sp_tensor)[:1] - array_ops.shape(seq_length)[:1]
+    padding = array_ops.zeros(n_pad, dtype=seq_length.dtype)
+    return array_ops.concat([seq_length, padding], axis=0, name=name_scope)
+
+
+class _SequenceCategoricalColumn(
+    fc._CategoricalColumn,
+    collections.namedtuple(
+        '_SequenceCategoricalColumn', ['categorical_column'])):
+
+  @property
+  def name(self):
+    return self.categorical_column.name
+
+  @property
+  def _parse_example_spec(self):
+    return self.categorical_column._parse_example_spec
+
+  def _transform_feature(self, inputs):
+    return self.categorical_column._transform_feature(inputs)
+
+  @property
+  def _num_buckets(self):
+    return self.categorical_column._num_buckets
+
+  def _get_sparse_tensors(self, inputs, weight_collections=None,
+                          trainable=None):
+    sparse_tensors = self.categorical_column._get_sparse_tensors(inputs)
+    id_tensor = sparse_tensors.id_tensor
+    weight_tensor = sparse_tensors.weight_tensor
+    # Expands final dimension, so that embeddings are not combined during
+    # embedding lookup.
+    check_id_rank = check_ops.assert_equal(
+        array_ops.rank(id_tensor), 2,
+        data=[
+            'Column {} expected ID tensor of rank 2. '.format(self.name),
+            'id_tensor shape: ', array_ops.shape(id_tensor)])
+    with ops.control_dependencies([check_id_rank]):
+      id_tensor = sparse_ops.sparse_reshape(
+          id_tensor,
+          shape=array_ops.concat([id_tensor.dense_shape, [1]], axis=0))
+    if weight_tensor is not None:
+      check_weight_rank = check_ops.assert_equal(
+          array_ops.rank(weight_tensor), 2,
+          data=[
+              'Column {} expected weight tensor of rank 2.'.format(self.name),
+              'weight_tensor shape:', array_ops.shape(weight_tensor)])
+      with ops.control_dependencies([check_weight_rank]):
+        weight_tensor = sparse_ops.sparse_reshape(
+            weight_tensor,
+            shape=array_ops.concat([weight_tensor.dense_shape, [1]], axis=0))
+    return fc._CategoricalColumn.IdWeightPair(id_tensor, weight_tensor)
+
+  def _sequence_length(self, inputs):
+    sparse_tensors = self.categorical_column._get_sparse_tensors(inputs)
+    return _sequence_length_from_sparse_tensor(sparse_tensors.id_tensor)
+
+
+class _SequenceEmbeddingColumn(
+    _SequenceDenseColumn,
+    collections.namedtuple('_SequenceEmbeddingColumn', ['embedding_column'])):
+
+  @property
+  def name(self):
+    return self.embedding_column.name
+
+  @property
+  def _parse_example_spec(self):
+    return self.embedding_column._parse_example_spec
+
+  def _transform_feature(self, inputs):
+    return self.embedding_column._transform_feature(inputs)
+
+  @property
+  def _variable_shape(self):
+    return self.embedding_column._variable_shape
+
+  def _get_sequence_dense_tensor(
+      self, inputs, weight_collections=None, trainable=None):
+    dense_tensor = self.embedding_column._get_dense_tensor(
+        inputs=inputs,
+        weight_collections=weight_collections,
+        trainable=trainable)
+    sequence_length = self.embedding_column.categorical_column._sequence_length(
+        inputs)
+    return _SequenceDenseColumn.TensorSequenceLengthPair(
+        dense_tensor=dense_tensor, sequence_length=sequence_length)
+
+
+class _SequenceNumericColumn(
+    _SequenceDenseColumn,
+    collections.namedtuple(
+        '_SequenceNumericColumn',
+        ['key', 'shape', 'default_value', 'dtype'])):
+
+  @property
+  def name(self):
+    return self.key
+
+  @property
+  def _parse_example_spec(self):
+    return {self.key: parsing_ops.VarLenFeature(self.dtype)}
+
+  def _transform_feature(self, inputs):
+    return inputs.get(self.key)
+
+  @property
+  def _variable_shape(self):
+    return tensor_shape.TensorShape(self.shape)
+
+  def _get_sequence_dense_tensor(
+      self, inputs, weight_collections=None, trainable=None):
+    # Do nothing with weight_collections and trainable since no variables are
+    # created in this function.
+    del weight_collections
+    del trainable
+    sp_tensor = inputs.get(self)
+    dense_tensor = sparse_ops.sparse_tensor_to_dense(
+        sp_tensor, default_value=self.default_value)
+    # Reshape into [batch_size, T, variable_shape].
+    dense_shape = array_ops.concat(
+        [array_ops.shape(dense_tensor)[:1], [-1], self._variable_shape],
+        axis=0)
+    dense_tensor = array_ops.reshape(dense_tensor, shape=dense_shape)
+    sequence_length = _sequence_length_from_sparse_tensor(
+        sp_tensor, num_elements=self._variable_shape.num_elements())
+    return _SequenceDenseColumn.TensorSequenceLengthPair(
+        dense_tensor=dense_tensor, sequence_length=sequence_length)
+
+# pylint: enable=g-doc-args,missing-docstring,protected-access
diff --git a/tensorflow/contrib/feature_column/python/feature_column/sequential_feature_column_test.py b/tensorflow/contrib/feature_column/python/feature_column/sequential_feature_column_test.py
new file mode 100644 (file)
index 0000000..5967486
--- /dev/null
@@ -0,0 +1,471 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for sequential_feature_column."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.contrib.feature_column.python.feature_column import sequential_feature_column as sfc
+from tensorflow.python.feature_column.feature_column import _LazyBuilder
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import sparse_tensor
+from tensorflow.python.platform import test
+from tensorflow.python.training import monitored_session
+
+
+class SequenceInputLayerTest(test.TestCase):
+
+  def test_embedding_column(self):
+    vocabulary_size = 3
+    sparse_input_a = sparse_tensor.SparseTensorValue(
+        # example 0, ids [2]
+        # example 1, ids [0, 1]
+        indices=((0, 0), (1, 0), (1, 1)),
+        values=(2, 0, 1),
+        dense_shape=(2, 2))
+    sparse_input_b = sparse_tensor.SparseTensorValue(
+        # example 0, ids [1]
+        # example 1, ids [2, 0]
+        indices=((0, 0), (1, 0), (1, 1)),
+        values=(1, 2, 0),
+        dense_shape=(2, 2))
+
+    embedding_dimension_a = 2
+    embedding_values_a = (
+        (1., 2.),  # id 0
+        (3., 4.),  # id 1
+        (5., 6.)  # id 2
+    )
+    embedding_dimension_b = 3
+    embedding_values_b = (
+        (11., 12., 13.),  # id 0
+        (14., 15., 16.),  # id 1
+        (17., 18., 19.)  # id 2
+    )
+    def _get_initializer(embedding_dimension, embedding_values):
+      def _initializer(shape, dtype, partition_info):
+        self.assertAllEqual((vocabulary_size, embedding_dimension), shape)
+        self.assertEqual(dtypes.float32, dtype)
+        self.assertIsNone(partition_info)
+        return embedding_values
+      return _initializer
+
+    expected_input_layer = [
+        # example 0, ids_a [2], ids_b [1]
+        [[5., 6., 14., 15., 16.], [0., 0., 0., 0., 0.]],
+        # example 1, ids_a [0, 1], ids_b [2, 0]
+        [[1., 2., 17., 18., 19.], [3., 4., 11., 12., 13.]],
+    ]
+    expected_sequence_length = [1, 2]
+
+    categorical_column_a = sfc.sequence_categorical_column_with_identity(
+        key='aaa', num_buckets=vocabulary_size)
+    embedding_column_a = sfc._sequence_embedding_column(
+        categorical_column_a, dimension=embedding_dimension_a,
+        initializer=_get_initializer(embedding_dimension_a, embedding_values_a))
+    categorical_column_b = sfc.sequence_categorical_column_with_identity(
+        key='bbb', num_buckets=vocabulary_size)
+    embedding_column_b = sfc._sequence_embedding_column(
+        categorical_column_b, dimension=embedding_dimension_b,
+        initializer=_get_initializer(embedding_dimension_b, embedding_values_b))
+
+    input_layer, sequence_length = sfc.sequence_input_layer(
+        features={
+            'aaa': sparse_input_a,
+            'bbb': sparse_input_b,
+        },
+        # Test that columns are reordered alphabetically.
+        feature_columns=[embedding_column_b, embedding_column_a])
+
+    global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
+    self.assertItemsEqual(
+        ('sequence_input_layer/aaa_embedding/embedding_weights:0',
+         'sequence_input_layer/bbb_embedding/embedding_weights:0'),
+        tuple([v.name for v in global_vars]))
+    with monitored_session.MonitoredSession() as sess:
+      self.assertAllEqual(embedding_values_a, global_vars[0].eval(session=sess))
+      self.assertAllEqual(embedding_values_b, global_vars[1].eval(session=sess))
+      self.assertAllEqual(expected_input_layer, input_layer.eval(session=sess))
+      self.assertAllEqual(
+          expected_sequence_length, sequence_length.eval(session=sess))
+
+  def test_numeric_column(self):
+    sparse_input = sparse_tensor.SparseTensorValue(
+        # example 0, values [[0.], [1]]
+        # example 1, [[10.]]
+        indices=((0, 0), (0, 1), (1, 0)),
+        values=(0., 1., 10.),
+        dense_shape=(2, 2))
+    expected_input_layer = [
+        [[0.], [1.]],
+        [[10.], [0.]],
+    ]
+    expected_sequence_length = [2, 1]
+    numeric_column = sfc.sequence_numeric_column('aaa')
+
+    input_layer, sequence_length = sfc.sequence_input_layer(
+        features={'aaa': sparse_input},
+        feature_columns=[numeric_column])
+
+    with monitored_session.MonitoredSession() as sess:
+      self.assertAllEqual(expected_input_layer, input_layer.eval(session=sess))
+      self.assertAllEqual(
+          expected_sequence_length, sequence_length.eval(session=sess))
+
+  def test_numeric_column_multi_dim(self):
+    """Tests sequence_input_layer for multi-dimensional numeric_column."""
+    sparse_input = sparse_tensor.SparseTensorValue(
+        # example 0, values [[[0., 1.],  [2., 3.]], [[4., 5.],  [6., 7.]]]
+        # example 1, [[[10., 11.],  [12., 13.]]]
+        indices=((0, 0), (0, 1), (0, 2), (0, 3), (0, 4), (0, 5), (0, 6), (0, 7),
+                 (1, 0), (1, 1), (1, 2), (1, 3)),
+        values=(0., 1., 2., 3., 4., 5., 6., 7., 10., 11., 12., 13.),
+        dense_shape=(2, 8))
+    # The output of numeric_column._get_dense_tensor should be flattened.
+    expected_input_layer = [
+        [[0., 1., 2., 3.], [4., 5., 6., 7.]],
+        [[10., 11., 12., 13.], [0., 0., 0., 0.]],
+    ]
+    expected_sequence_length = [2, 1]
+    numeric_column = sfc.sequence_numeric_column('aaa', shape=(2, 2))
+
+    input_layer, sequence_length = sfc.sequence_input_layer(
+        features={'aaa': sparse_input},
+        feature_columns=[numeric_column])
+
+    with monitored_session.MonitoredSession() as sess:
+      self.assertAllEqual(expected_input_layer, input_layer.eval(session=sess))
+      self.assertAllEqual(
+          expected_sequence_length, sequence_length.eval(session=sess))
+
+
+def _assert_sparse_tensor_value(test_case, expected, actual):
+  test_case.assertEqual(np.int64, np.array(actual.indices).dtype)
+  test_case.assertAllEqual(expected.indices, actual.indices)
+
+  test_case.assertEqual(
+      np.array(expected.values).dtype, np.array(actual.values).dtype)
+  test_case.assertAllEqual(expected.values, actual.values)
+
+  test_case.assertEqual(np.int64, np.array(actual.dense_shape).dtype)
+  test_case.assertAllEqual(expected.dense_shape, actual.dense_shape)
+
+
+class SequenceCategoricalColumnWithIdentityTest(test.TestCase):
+
+  def test_get_sparse_tensors(self):
+    column = sfc.sequence_categorical_column_with_identity(
+        'aaa', num_buckets=3)
+    inputs = sparse_tensor.SparseTensorValue(
+        indices=((0, 0), (1, 0), (1, 1)),
+        values=(1, 2, 0),
+        dense_shape=(2, 2))
+    expected_sparse_ids = sparse_tensor.SparseTensorValue(
+        indices=((0, 0, 0), (1, 0, 0), (1, 1, 0)),
+        values=np.array((1, 2, 0), dtype=np.int64),
+        dense_shape=(2, 2, 1))
+
+    id_weight_pair = column._get_sparse_tensors(_LazyBuilder({'aaa': inputs}))
+
+    self.assertIsNone(id_weight_pair.weight_tensor)
+    with monitored_session.MonitoredSession() as sess:
+      _assert_sparse_tensor_value(
+          self,
+          expected_sparse_ids,
+          id_weight_pair.id_tensor.eval(session=sess))
+
+  def test_get_sparse_tensors_inputs3d(self):
+    """Tests _get_sparse_tensors when the input is already 3D Tensor."""
+    column = sfc.sequence_categorical_column_with_identity(
+        'aaa', num_buckets=3)
+    inputs = sparse_tensor.SparseTensorValue(
+        indices=((0, 0, 0), (1, 0, 0), (1, 1, 0)),
+        values=(1, 2, 0),
+        dense_shape=(2, 2, 1))
+
+    with self.assertRaisesRegexp(
+        errors.InvalidArgumentError,
+        r'Column aaa expected ID tensor of rank 2\.\s*'
+        r'id_tensor shape:\s*\[2 2 1\]'):
+      id_weight_pair = column._get_sparse_tensors(
+          _LazyBuilder({'aaa': inputs}))
+      with monitored_session.MonitoredSession() as sess:
+        id_weight_pair.id_tensor.eval(session=sess)
+
+  def test_sequence_length(self):
+    column = sfc.sequence_categorical_column_with_identity(
+        'aaa', num_buckets=3)
+    inputs = sparse_tensor.SparseTensorValue(
+        indices=((0, 0), (1, 0), (1, 1)),
+        values=(1, 2, 0),
+        dense_shape=(2, 2))
+    expected_sequence_length = [1, 2]
+
+    sequence_length = column._sequence_length(_LazyBuilder({'aaa': inputs}))
+
+    with monitored_session.MonitoredSession() as sess:
+      self.assertAllEqual(
+          expected_sequence_length, sequence_length.eval(session=sess))
+
+  def test_sequence_length_with_zeros(self):
+    column = sfc.sequence_categorical_column_with_identity(
+        'aaa', num_buckets=3)
+    inputs = sparse_tensor.SparseTensorValue(
+        indices=((1, 0), (3, 0), (3, 1)),
+        values=(1, 2, 0),
+        dense_shape=(5, 2))
+    expected_sequence_length = [0, 1, 0, 2, 0]
+
+    sequence_length = column._sequence_length(_LazyBuilder({'aaa': inputs}))
+
+    with monitored_session.MonitoredSession() as sess:
+      self.assertAllEqual(
+          expected_sequence_length, sequence_length.eval(session=sess))
+
+
+class SequenceEmbeddingColumnTest(test.TestCase):
+
+  def test_get_sequence_dense_tensor(self):
+    vocabulary_size = 3
+    sparse_input = sparse_tensor.SparseTensorValue(
+        # example 0, ids [2]
+        # example 1, ids [0, 1]
+        # example 2, ids []
+        # example 3, ids [1]
+        indices=((0, 0), (1, 0), (1, 1), (3, 0)),
+        values=(2, 0, 1, 1),
+        dense_shape=(4, 2))
+
+    embedding_dimension = 2
+    embedding_values = (
+        (1., 2.),  # id 0
+        (3., 5.),  # id 1
+        (7., 11.)  # id 2
+    )
+    def _initializer(shape, dtype, partition_info):
+      self.assertAllEqual((vocabulary_size, embedding_dimension), shape)
+      self.assertEqual(dtypes.float32, dtype)
+      self.assertIsNone(partition_info)
+      return embedding_values
+
+    expected_lookups = [
+        # example 0, ids [2]
+        [[7., 11.], [0., 0.]],
+        # example 1, ids [0, 1]
+        [[1., 2.], [3., 5.]],
+        # example 2, ids []
+        [[0., 0.], [0., 0.]],
+        # example 3, ids [1]
+        [[3., 5.], [0., 0.]],
+    ]
+
+    categorical_column = sfc.sequence_categorical_column_with_identity(
+        key='aaa', num_buckets=vocabulary_size)
+    embedding_column = sfc._sequence_embedding_column(
+        categorical_column, dimension=embedding_dimension,
+        initializer=_initializer)
+
+    embedding_lookup, _ = embedding_column._get_sequence_dense_tensor(
+        _LazyBuilder({'aaa': sparse_input}))
+
+    global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
+    self.assertItemsEqual(
+        ('embedding_weights:0',), tuple([v.name for v in global_vars]))
+    with monitored_session.MonitoredSession() as sess:
+      self.assertAllEqual(embedding_values, global_vars[0].eval(session=sess))
+      self.assertAllEqual(expected_lookups, embedding_lookup.eval(session=sess))
+
+  def test_sequence_length(self):
+    vocabulary_size = 3
+    sparse_input = sparse_tensor.SparseTensorValue(
+        # example 0, ids [2]
+        # example 1, ids [0, 1]
+        indices=((0, 0), (1, 0), (1, 1)),
+        values=(2, 0, 1),
+        dense_shape=(2, 2))
+    expected_sequence_length = [1, 2]
+
+    categorical_column = sfc.sequence_categorical_column_with_identity(
+        key='aaa', num_buckets=vocabulary_size)
+    embedding_column = sfc._sequence_embedding_column(
+        categorical_column, dimension=2)
+
+    _, sequence_length = embedding_column._get_sequence_dense_tensor(
+        _LazyBuilder({'aaa': sparse_input}))
+
+    with monitored_session.MonitoredSession() as sess:
+      self.assertAllEqual(
+          expected_sequence_length, sequence_length.eval(session=sess))
+
+  def test_sequence_length_with_empty_rows(self):
+    """Tests _sequence_length when some examples do not have ids."""
+    vocabulary_size = 3
+    sparse_input = sparse_tensor.SparseTensorValue(
+        # example 0, ids []
+        # example 1, ids [2]
+        # example 2, ids [0, 1]
+        # example 3, ids []
+        # example 4, ids [1]
+        # example 5, ids []
+        indices=((1, 0), (2, 0), (2, 1), (4, 0)),
+        values=(2, 0, 1, 1),
+        dense_shape=(6, 2))
+    expected_sequence_length = [0, 1, 2, 0, 1, 0]
+
+    categorical_column = sfc.sequence_categorical_column_with_identity(
+        key='aaa', num_buckets=vocabulary_size)
+    embedding_column = sfc._sequence_embedding_column(
+        categorical_column, dimension=2)
+
+    _, sequence_length = embedding_column._get_sequence_dense_tensor(
+        _LazyBuilder({'aaa': sparse_input}))
+
+    with monitored_session.MonitoredSession() as sess:
+      self.assertAllEqual(
+          expected_sequence_length, sequence_length.eval(session=sess))
+
+
+class SequenceNumericColumnTest(test.TestCase):
+
+  def test_get_sequence_dense_tensor(self):
+    sparse_input = sparse_tensor.SparseTensorValue(
+        # example 0, values [[0.], [1]]
+        # example 1, [[10.]]
+        indices=((0, 0), (0, 1), (1, 0)),
+        values=(0., 1., 10.),
+        dense_shape=(2, 2))
+    expected_dense_tensor = [
+        [[0.], [1.]],
+        [[10.], [0.]],
+    ]
+    numeric_column = sfc.sequence_numeric_column('aaa')
+
+    dense_tensor, _ = numeric_column._get_sequence_dense_tensor(
+        _LazyBuilder({'aaa': sparse_input}))
+
+    with monitored_session.MonitoredSession() as sess:
+      self.assertAllEqual(
+          expected_dense_tensor, dense_tensor.eval(session=sess))
+
+  def test_get_sequence_dense_tensor_with_shape(self):
+    """Tests get_sequence_dense_tensor with shape !=(1,)."""
+    sparse_input = sparse_tensor.SparseTensorValue(
+        # example 0, values [[0., 1., 2.], [3., 4., 5.]]
+        # example 1, [[10., 11., 12.]]
+        indices=((0, 0), (0, 1), (0, 2), (0, 3), (0, 4), (0, 5),
+                 (1, 0), (1, 1), (1, 2)),
+        values=(0., 1., 2., 3., 4., 5., 10., 11., 12.),
+        dense_shape=(2, 6))
+    expected_dense_tensor = [
+        [[0., 1., 2.], [3., 4., 5.]],
+        [[10., 11., 12.], [0., 0., 0.]],
+    ]
+    numeric_column = sfc.sequence_numeric_column('aaa', shape=(3,))
+
+    dense_tensor, _ = numeric_column._get_sequence_dense_tensor(
+        _LazyBuilder({'aaa': sparse_input}))
+
+    with monitored_session.MonitoredSession() as sess:
+      self.assertAllEqual(
+          expected_dense_tensor, dense_tensor.eval(session=sess))
+
+  def test_get_dense_tensor_multi_dim(self):
+    """Tests get_sequence_dense_tensor for multi-dim numeric_column."""
+    sparse_input = sparse_tensor.SparseTensorValue(
+        # example 0, values [[[0., 1.],  [2., 3.]], [[4., 5.],  [6., 7.]]]
+        # example 1, [[[10., 11.],  [12., 13.]]]
+        indices=((0, 0), (0, 1), (0, 2), (0, 3), (0, 4), (0, 5), (0, 6), (0, 7),
+                 (1, 0), (1, 1), (1, 2), (1, 3)),
+        values=(0., 1., 2., 3., 4., 5., 6., 7., 10., 11., 12., 13.),
+        dense_shape=(2, 8))
+    expected_dense_tensor = [
+        [[[0., 1.], [2., 3.]], [[4., 5.], [6., 7.]]],
+        [[[10., 11.], [12., 13.]], [[0., 0.], [0., 0.]]],
+    ]
+    numeric_column = sfc.sequence_numeric_column('aaa', shape=(2, 2))
+
+    dense_tensor, _ = numeric_column._get_sequence_dense_tensor(
+        _LazyBuilder({'aaa': sparse_input}))
+
+    with monitored_session.MonitoredSession() as sess:
+      self.assertAllEqual(
+          expected_dense_tensor, dense_tensor.eval(session=sess))
+
+  def test_sequence_length(self):
+    sparse_input = sparse_tensor.SparseTensorValue(
+        # example 0, values [[0., 1., 2.], [3., 4., 5.]]
+        # example 1, [[10., 11., 12.]]
+        indices=((0, 0), (0, 1), (0, 2), (0, 3), (0, 4), (0, 5),
+                 (1, 0), (1, 1), (1, 2)),
+        values=(0., 1., 2., 3., 4., 5., 10., 11., 12.),
+        dense_shape=(2, 6))
+    expected_sequence_length = [2, 1]
+    numeric_column = sfc.sequence_numeric_column('aaa', shape=(3,))
+
+    _, sequence_length = numeric_column._get_sequence_dense_tensor(
+        _LazyBuilder({'aaa': sparse_input}))
+
+    with monitored_session.MonitoredSession() as sess:
+      self.assertAllEqual(
+          expected_sequence_length, sequence_length.eval(session=sess))
+
+  def test_sequence_length_with_shape(self):
+    """Tests _sequence_length with shape !=(1,)."""
+    sparse_input = sparse_tensor.SparseTensorValue(
+        # example 0, values [[0.], [1]]
+        # example 1, [[10.]]
+        indices=((0, 0), (0, 1), (1, 0)),
+        values=(0., 1., 10.),
+        dense_shape=(2, 2))
+    expected_sequence_length = [2, 1]
+    numeric_column = sfc.sequence_numeric_column('aaa')
+
+    _, sequence_length = numeric_column._get_sequence_dense_tensor(
+        _LazyBuilder({'aaa': sparse_input}))
+
+    with monitored_session.MonitoredSession() as sess:
+      self.assertAllEqual(
+          expected_sequence_length, sequence_length.eval(session=sess))
+
+  def test_sequence_length_with_empty_rows(self):
+    """Tests _sequence_length when some examples do not have ids."""
+    sparse_input = sparse_tensor.SparseTensorValue(
+        # example 0, values []
+        # example 1, values [[0.], [1.]]
+        # example 2, [[2.]]
+        # example 3, values []
+        # example 4, [[3.]]
+        # example 5, values []
+        indices=((1, 0), (1, 1), (2, 0), (4, 0)),
+        values=(0., 1., 2., 3.),
+        dense_shape=(6, 2))
+    expected_sequence_length = [0, 2, 1, 0, 1, 0]
+    numeric_column = sfc.sequence_numeric_column('aaa')
+
+    _, sequence_length = numeric_column._get_sequence_dense_tensor(
+        _LazyBuilder({'aaa': sparse_input}))
+
+    with monitored_session.MonitoredSession() as sess:
+      self.assertAllEqual(
+          expected_sequence_length, sequence_length.eval(session=sess))
+
+
+if __name__ == '__main__':
+  test.main()