From: A. Unique TensorFlower Date: Mon, 26 Feb 2018 22:25:30 +0000 (-0800) Subject: 1st version of sequential feature columns. X-Git-Tag: upstream/v1.7.0~101 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=c3ad72500cd714a39af5ab530ab14f477cc717c6;p=platform%2Fupstream%2Ftensorflow.git 1st version of sequential feature columns. PiperOrigin-RevId: 187080635 --- diff --git a/tensorflow/contrib/feature_column/BUILD b/tensorflow/contrib/feature_column/BUILD index 6fc0537..a53e36c 100644 --- a/tensorflow/contrib/feature_column/BUILD +++ b/tensorflow/contrib/feature_column/BUILD @@ -33,5 +33,34 @@ py_library( name = "sequential_feature_column", srcs = ["python/feature_column/sequential_feature_column.py"], srcs_version = "PY2AND3", - deps = [], + deps = [ + "//tensorflow/python:array_ops", + "//tensorflow/python:check_ops", + "//tensorflow/python:dtypes", + "//tensorflow/python:framework_ops", + "//tensorflow/python:math_ops", + "//tensorflow/python:parsing_ops", + "//tensorflow/python:sparse_ops", + "//tensorflow/python:tensor_shape", + "//tensorflow/python:variable_scope", + "//tensorflow/python/feature_column", + ], +) + +py_test( + name = "sequential_feature_column_test", + srcs = ["python/feature_column/sequential_feature_column_test.py"], + srcs_version = "PY2AND3", + tags = ["no_pip"], + deps = [ + ":sequential_feature_column", + "//tensorflow/python:client_testlib", + "//tensorflow/python:dtypes", + "//tensorflow/python:errors", + "//tensorflow/python:framework_ops", + "//tensorflow/python:sparse_tensor", + "//tensorflow/python:training", + "//tensorflow/python/feature_column", + "//third_party/py/numpy", + ], ) diff --git a/tensorflow/contrib/feature_column/python/feature_column/sequential_feature_column.py b/tensorflow/contrib/feature_column/python/feature_column/sequential_feature_column.py index 690a44f..4ed7268 100644 --- a/tensorflow/contrib/feature_column/python/feature_column/sequential_feature_column.py +++ b/tensorflow/contrib/feature_column/python/feature_column/sequential_feature_column.py @@ -12,8 +12,314 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Experimental methods for tf.feature_column sequential input.""" +"""Experimental methods for tf.feature_column sequence input.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function + + +import abc +import collections + + +from tensorflow.python.feature_column import feature_column as fc +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_shape +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import check_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import parsing_ops +from tensorflow.python.ops import sparse_ops +from tensorflow.python.ops import variable_scope + +# TODO(b/73160931): Fix pydoc. +# pylint: disable=g-doc-args,missing-docstring,protected-access +# TODO(b/73827486): Support SequenceExample. + + +def sequence_input_layer( + features, + feature_columns, + weight_collections=None, + trainable=True, + scope=None): + """"Builds input layer for sequence input. + + All `feature_columns` must be sequence dense columns with the same + `sequence_length`. The output of this method can be fed into sequence + networks, such as RNN. + + The output of this method is a 3D `Tensor` of shape `[batch_size, T, D]`. + `T` is the maximum sequence length for this batch, which could differ from + batch to batch. + + If multiple `feature_columns` are given with `Di` `num_elements` each, their + outputs are concatenated. So, the final `Tensor` has shape + `[batch_size, T, D0 + D1 + ... + Dn]`. + + Example: + + ```python + rating = sequence_numeric_column('rating') + watches = sequence_categorical_column_with_identity( + 'watches', num_buckets=1000) + watches_embedding = embedding_column(watches, dimension=10) + columns = [rating, watches] + + features = tf.parse_example(..., features=make_parse_example_spec(columns)) + input_layer, sequence_length = sequence_input_layer(features, columns) + + rnn_cell = tf.nn.rnn_cell.BasicRNNCell(hidden_size) + outputs, state = tf.nn.dynamic_rnn( + rnn_cell, inputs=input_layer, sequence_length=sequence_length) + ``` + + Returns: + An `(input_layer, sequence_length)` tuple where: + - input_layer: A float `Tensor` of shape `[batch_size, T, D]`. + `T` is the maximum sequence length for this batch, which could differ + from batch to batch. `D` is the sum of `num_elements` for all + `feature_columns`. + - sequence_length: An int `Tensor` of shape `[batch_size]`. The sequence + length for each example. + Raises: + ValueError: If any of the `feature_columns` is the wrong type. + """ + feature_columns = fc._clean_feature_columns(feature_columns) + for c in feature_columns: + if not isinstance(c, _SequenceDenseColumn): + raise ValueError( + 'All feature_columns must be of type _SequenceDenseColumn. ' + 'Given (type {}): {}'.format(type(c), c)) + + with variable_scope.variable_scope( + scope, default_name='sequence_input_layer', values=features.values()): + builder = fc._LazyBuilder(features) + output_tensors = [] + sequence_lengths = [] + ordered_columns = [] + for column in sorted(feature_columns, key=lambda x: x.name): + ordered_columns.append(column) + with variable_scope.variable_scope( + None, default_name=column._var_scope_name): + dense_tensor, sequence_length = column._get_sequence_dense_tensor( + builder, + weight_collections=weight_collections, + trainable=trainable) + # Flattens the final dimension to produce a 3D Tensor. + num_elements = column._variable_shape.num_elements() + shape = array_ops.shape(dense_tensor) + output_tensors.append( + array_ops.reshape( + dense_tensor, + shape=array_ops.concat([shape[:2], [num_elements]], axis=0))) + sequence_lengths.append(sequence_length) + fc._verify_static_batch_size_equality(output_tensors, ordered_columns) + # TODO(b/73160931): Verify sequence_length equality. + return array_ops.concat(output_tensors, -1), sequence_lengths[0] + + +# TODO(b/73160931): Add remaining categorical columns. +def sequence_categorical_column_with_identity( + key, num_buckets, default_value=None): + return _SequenceCategoricalColumn( + fc.categorical_column_with_identity( + key=key, + num_buckets=num_buckets, + default_value=default_value)) + + +# TODO(b/73160931): Merge with embedding_column +def _sequence_embedding_column( + categorical_column, dimension, initializer=None, ckpt_to_load_from=None, + tensor_name_in_ckpt=None, max_norm=None, trainable=True): + if not isinstance(categorical_column, _SequenceCategoricalColumn): + raise ValueError( + 'categorical_column must be of type _SequenceCategoricalColumn. ' + 'Given (type {}): {}'.format( + type(categorical_column), categorical_column)) + return _SequenceEmbeddingColumn( + fc.embedding_column( + categorical_column, + dimension=dimension, + initializer=initializer, + ckpt_to_load_from=ckpt_to_load_from, + tensor_name_in_ckpt=tensor_name_in_ckpt, + max_norm=max_norm, + trainable=trainable)) + + +def sequence_numeric_column( + key, + shape=(1,), + default_value=0., + dtype=dtypes.float32): + # TODO(b/73160931): Add validations. + return _SequenceNumericColumn( + key, + shape=shape, + default_value=default_value, + dtype=dtype) + + +class _SequenceDenseColumn(fc._FeatureColumn): + """Represents dense sequence data.""" + + __metaclass__ = abc.ABCMeta + + TensorSequenceLengthPair = collections.namedtuple( # pylint: disable=invalid-name + 'TensorSequenceLengthPair', ['dense_tensor', 'sequence_length']) + + @abc.abstractproperty + def _variable_shape(self): + """`TensorShape` without batch and sequence dimensions.""" + pass + + @abc.abstractmethod + def _get_sequence_dense_tensor( + self, inputs, weight_collections=None, trainable=None): + """Returns a `TensorSequenceLengthPair`.""" + pass + + +def _sequence_length_from_sparse_tensor(sp_tensor, num_elements=1): + with ops.name_scope(None, 'sequence_length') as name_scope: + row_ids = sp_tensor.indices[:, 0] + column_ids = sp_tensor.indices[:, 1] + column_ids += array_ops.ones_like(column_ids) + seq_length = ( + math_ops.segment_max(column_ids, segment_ids=row_ids) / num_elements) + # If the last n rows do not have ids, seq_length will have shape + # [batch_size - n]. Pad the remaining values with zeros. + n_pad = array_ops.shape(sp_tensor)[:1] - array_ops.shape(seq_length)[:1] + padding = array_ops.zeros(n_pad, dtype=seq_length.dtype) + return array_ops.concat([seq_length, padding], axis=0, name=name_scope) + + +class _SequenceCategoricalColumn( + fc._CategoricalColumn, + collections.namedtuple( + '_SequenceCategoricalColumn', ['categorical_column'])): + + @property + def name(self): + return self.categorical_column.name + + @property + def _parse_example_spec(self): + return self.categorical_column._parse_example_spec + + def _transform_feature(self, inputs): + return self.categorical_column._transform_feature(inputs) + + @property + def _num_buckets(self): + return self.categorical_column._num_buckets + + def _get_sparse_tensors(self, inputs, weight_collections=None, + trainable=None): + sparse_tensors = self.categorical_column._get_sparse_tensors(inputs) + id_tensor = sparse_tensors.id_tensor + weight_tensor = sparse_tensors.weight_tensor + # Expands final dimension, so that embeddings are not combined during + # embedding lookup. + check_id_rank = check_ops.assert_equal( + array_ops.rank(id_tensor), 2, + data=[ + 'Column {} expected ID tensor of rank 2. '.format(self.name), + 'id_tensor shape: ', array_ops.shape(id_tensor)]) + with ops.control_dependencies([check_id_rank]): + id_tensor = sparse_ops.sparse_reshape( + id_tensor, + shape=array_ops.concat([id_tensor.dense_shape, [1]], axis=0)) + if weight_tensor is not None: + check_weight_rank = check_ops.assert_equal( + array_ops.rank(weight_tensor), 2, + data=[ + 'Column {} expected weight tensor of rank 2.'.format(self.name), + 'weight_tensor shape:', array_ops.shape(weight_tensor)]) + with ops.control_dependencies([check_weight_rank]): + weight_tensor = sparse_ops.sparse_reshape( + weight_tensor, + shape=array_ops.concat([weight_tensor.dense_shape, [1]], axis=0)) + return fc._CategoricalColumn.IdWeightPair(id_tensor, weight_tensor) + + def _sequence_length(self, inputs): + sparse_tensors = self.categorical_column._get_sparse_tensors(inputs) + return _sequence_length_from_sparse_tensor(sparse_tensors.id_tensor) + + +class _SequenceEmbeddingColumn( + _SequenceDenseColumn, + collections.namedtuple('_SequenceEmbeddingColumn', ['embedding_column'])): + + @property + def name(self): + return self.embedding_column.name + + @property + def _parse_example_spec(self): + return self.embedding_column._parse_example_spec + + def _transform_feature(self, inputs): + return self.embedding_column._transform_feature(inputs) + + @property + def _variable_shape(self): + return self.embedding_column._variable_shape + + def _get_sequence_dense_tensor( + self, inputs, weight_collections=None, trainable=None): + dense_tensor = self.embedding_column._get_dense_tensor( + inputs=inputs, + weight_collections=weight_collections, + trainable=trainable) + sequence_length = self.embedding_column.categorical_column._sequence_length( + inputs) + return _SequenceDenseColumn.TensorSequenceLengthPair( + dense_tensor=dense_tensor, sequence_length=sequence_length) + + +class _SequenceNumericColumn( + _SequenceDenseColumn, + collections.namedtuple( + '_SequenceNumericColumn', + ['key', 'shape', 'default_value', 'dtype'])): + + @property + def name(self): + return self.key + + @property + def _parse_example_spec(self): + return {self.key: parsing_ops.VarLenFeature(self.dtype)} + + def _transform_feature(self, inputs): + return inputs.get(self.key) + + @property + def _variable_shape(self): + return tensor_shape.TensorShape(self.shape) + + def _get_sequence_dense_tensor( + self, inputs, weight_collections=None, trainable=None): + # Do nothing with weight_collections and trainable since no variables are + # created in this function. + del weight_collections + del trainable + sp_tensor = inputs.get(self) + dense_tensor = sparse_ops.sparse_tensor_to_dense( + sp_tensor, default_value=self.default_value) + # Reshape into [batch_size, T, variable_shape]. + dense_shape = array_ops.concat( + [array_ops.shape(dense_tensor)[:1], [-1], self._variable_shape], + axis=0) + dense_tensor = array_ops.reshape(dense_tensor, shape=dense_shape) + sequence_length = _sequence_length_from_sparse_tensor( + sp_tensor, num_elements=self._variable_shape.num_elements()) + return _SequenceDenseColumn.TensorSequenceLengthPair( + dense_tensor=dense_tensor, sequence_length=sequence_length) + +# pylint: enable=g-doc-args,missing-docstring,protected-access diff --git a/tensorflow/contrib/feature_column/python/feature_column/sequential_feature_column_test.py b/tensorflow/contrib/feature_column/python/feature_column/sequential_feature_column_test.py new file mode 100644 index 0000000..5967486 --- /dev/null +++ b/tensorflow/contrib/feature_column/python/feature_column/sequential_feature_column_test.py @@ -0,0 +1,471 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for sequential_feature_column.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.feature_column.python.feature_column import sequential_feature_column as sfc +from tensorflow.python.feature_column.feature_column import _LazyBuilder +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors +from tensorflow.python.framework import ops +from tensorflow.python.framework import sparse_tensor +from tensorflow.python.platform import test +from tensorflow.python.training import monitored_session + + +class SequenceInputLayerTest(test.TestCase): + + def test_embedding_column(self): + vocabulary_size = 3 + sparse_input_a = sparse_tensor.SparseTensorValue( + # example 0, ids [2] + # example 1, ids [0, 1] + indices=((0, 0), (1, 0), (1, 1)), + values=(2, 0, 1), + dense_shape=(2, 2)) + sparse_input_b = sparse_tensor.SparseTensorValue( + # example 0, ids [1] + # example 1, ids [2, 0] + indices=((0, 0), (1, 0), (1, 1)), + values=(1, 2, 0), + dense_shape=(2, 2)) + + embedding_dimension_a = 2 + embedding_values_a = ( + (1., 2.), # id 0 + (3., 4.), # id 1 + (5., 6.) # id 2 + ) + embedding_dimension_b = 3 + embedding_values_b = ( + (11., 12., 13.), # id 0 + (14., 15., 16.), # id 1 + (17., 18., 19.) # id 2 + ) + def _get_initializer(embedding_dimension, embedding_values): + def _initializer(shape, dtype, partition_info): + self.assertAllEqual((vocabulary_size, embedding_dimension), shape) + self.assertEqual(dtypes.float32, dtype) + self.assertIsNone(partition_info) + return embedding_values + return _initializer + + expected_input_layer = [ + # example 0, ids_a [2], ids_b [1] + [[5., 6., 14., 15., 16.], [0., 0., 0., 0., 0.]], + # example 1, ids_a [0, 1], ids_b [2, 0] + [[1., 2., 17., 18., 19.], [3., 4., 11., 12., 13.]], + ] + expected_sequence_length = [1, 2] + + categorical_column_a = sfc.sequence_categorical_column_with_identity( + key='aaa', num_buckets=vocabulary_size) + embedding_column_a = sfc._sequence_embedding_column( + categorical_column_a, dimension=embedding_dimension_a, + initializer=_get_initializer(embedding_dimension_a, embedding_values_a)) + categorical_column_b = sfc.sequence_categorical_column_with_identity( + key='bbb', num_buckets=vocabulary_size) + embedding_column_b = sfc._sequence_embedding_column( + categorical_column_b, dimension=embedding_dimension_b, + initializer=_get_initializer(embedding_dimension_b, embedding_values_b)) + + input_layer, sequence_length = sfc.sequence_input_layer( + features={ + 'aaa': sparse_input_a, + 'bbb': sparse_input_b, + }, + # Test that columns are reordered alphabetically. + feature_columns=[embedding_column_b, embedding_column_a]) + + global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) + self.assertItemsEqual( + ('sequence_input_layer/aaa_embedding/embedding_weights:0', + 'sequence_input_layer/bbb_embedding/embedding_weights:0'), + tuple([v.name for v in global_vars])) + with monitored_session.MonitoredSession() as sess: + self.assertAllEqual(embedding_values_a, global_vars[0].eval(session=sess)) + self.assertAllEqual(embedding_values_b, global_vars[1].eval(session=sess)) + self.assertAllEqual(expected_input_layer, input_layer.eval(session=sess)) + self.assertAllEqual( + expected_sequence_length, sequence_length.eval(session=sess)) + + def test_numeric_column(self): + sparse_input = sparse_tensor.SparseTensorValue( + # example 0, values [[0.], [1]] + # example 1, [[10.]] + indices=((0, 0), (0, 1), (1, 0)), + values=(0., 1., 10.), + dense_shape=(2, 2)) + expected_input_layer = [ + [[0.], [1.]], + [[10.], [0.]], + ] + expected_sequence_length = [2, 1] + numeric_column = sfc.sequence_numeric_column('aaa') + + input_layer, sequence_length = sfc.sequence_input_layer( + features={'aaa': sparse_input}, + feature_columns=[numeric_column]) + + with monitored_session.MonitoredSession() as sess: + self.assertAllEqual(expected_input_layer, input_layer.eval(session=sess)) + self.assertAllEqual( + expected_sequence_length, sequence_length.eval(session=sess)) + + def test_numeric_column_multi_dim(self): + """Tests sequence_input_layer for multi-dimensional numeric_column.""" + sparse_input = sparse_tensor.SparseTensorValue( + # example 0, values [[[0., 1.], [2., 3.]], [[4., 5.], [6., 7.]]] + # example 1, [[[10., 11.], [12., 13.]]] + indices=((0, 0), (0, 1), (0, 2), (0, 3), (0, 4), (0, 5), (0, 6), (0, 7), + (1, 0), (1, 1), (1, 2), (1, 3)), + values=(0., 1., 2., 3., 4., 5., 6., 7., 10., 11., 12., 13.), + dense_shape=(2, 8)) + # The output of numeric_column._get_dense_tensor should be flattened. + expected_input_layer = [ + [[0., 1., 2., 3.], [4., 5., 6., 7.]], + [[10., 11., 12., 13.], [0., 0., 0., 0.]], + ] + expected_sequence_length = [2, 1] + numeric_column = sfc.sequence_numeric_column('aaa', shape=(2, 2)) + + input_layer, sequence_length = sfc.sequence_input_layer( + features={'aaa': sparse_input}, + feature_columns=[numeric_column]) + + with monitored_session.MonitoredSession() as sess: + self.assertAllEqual(expected_input_layer, input_layer.eval(session=sess)) + self.assertAllEqual( + expected_sequence_length, sequence_length.eval(session=sess)) + + +def _assert_sparse_tensor_value(test_case, expected, actual): + test_case.assertEqual(np.int64, np.array(actual.indices).dtype) + test_case.assertAllEqual(expected.indices, actual.indices) + + test_case.assertEqual( + np.array(expected.values).dtype, np.array(actual.values).dtype) + test_case.assertAllEqual(expected.values, actual.values) + + test_case.assertEqual(np.int64, np.array(actual.dense_shape).dtype) + test_case.assertAllEqual(expected.dense_shape, actual.dense_shape) + + +class SequenceCategoricalColumnWithIdentityTest(test.TestCase): + + def test_get_sparse_tensors(self): + column = sfc.sequence_categorical_column_with_identity( + 'aaa', num_buckets=3) + inputs = sparse_tensor.SparseTensorValue( + indices=((0, 0), (1, 0), (1, 1)), + values=(1, 2, 0), + dense_shape=(2, 2)) + expected_sparse_ids = sparse_tensor.SparseTensorValue( + indices=((0, 0, 0), (1, 0, 0), (1, 1, 0)), + values=np.array((1, 2, 0), dtype=np.int64), + dense_shape=(2, 2, 1)) + + id_weight_pair = column._get_sparse_tensors(_LazyBuilder({'aaa': inputs})) + + self.assertIsNone(id_weight_pair.weight_tensor) + with monitored_session.MonitoredSession() as sess: + _assert_sparse_tensor_value( + self, + expected_sparse_ids, + id_weight_pair.id_tensor.eval(session=sess)) + + def test_get_sparse_tensors_inputs3d(self): + """Tests _get_sparse_tensors when the input is already 3D Tensor.""" + column = sfc.sequence_categorical_column_with_identity( + 'aaa', num_buckets=3) + inputs = sparse_tensor.SparseTensorValue( + indices=((0, 0, 0), (1, 0, 0), (1, 1, 0)), + values=(1, 2, 0), + dense_shape=(2, 2, 1)) + + with self.assertRaisesRegexp( + errors.InvalidArgumentError, + r'Column aaa expected ID tensor of rank 2\.\s*' + r'id_tensor shape:\s*\[2 2 1\]'): + id_weight_pair = column._get_sparse_tensors( + _LazyBuilder({'aaa': inputs})) + with monitored_session.MonitoredSession() as sess: + id_weight_pair.id_tensor.eval(session=sess) + + def test_sequence_length(self): + column = sfc.sequence_categorical_column_with_identity( + 'aaa', num_buckets=3) + inputs = sparse_tensor.SparseTensorValue( + indices=((0, 0), (1, 0), (1, 1)), + values=(1, 2, 0), + dense_shape=(2, 2)) + expected_sequence_length = [1, 2] + + sequence_length = column._sequence_length(_LazyBuilder({'aaa': inputs})) + + with monitored_session.MonitoredSession() as sess: + self.assertAllEqual( + expected_sequence_length, sequence_length.eval(session=sess)) + + def test_sequence_length_with_zeros(self): + column = sfc.sequence_categorical_column_with_identity( + 'aaa', num_buckets=3) + inputs = sparse_tensor.SparseTensorValue( + indices=((1, 0), (3, 0), (3, 1)), + values=(1, 2, 0), + dense_shape=(5, 2)) + expected_sequence_length = [0, 1, 0, 2, 0] + + sequence_length = column._sequence_length(_LazyBuilder({'aaa': inputs})) + + with monitored_session.MonitoredSession() as sess: + self.assertAllEqual( + expected_sequence_length, sequence_length.eval(session=sess)) + + +class SequenceEmbeddingColumnTest(test.TestCase): + + def test_get_sequence_dense_tensor(self): + vocabulary_size = 3 + sparse_input = sparse_tensor.SparseTensorValue( + # example 0, ids [2] + # example 1, ids [0, 1] + # example 2, ids [] + # example 3, ids [1] + indices=((0, 0), (1, 0), (1, 1), (3, 0)), + values=(2, 0, 1, 1), + dense_shape=(4, 2)) + + embedding_dimension = 2 + embedding_values = ( + (1., 2.), # id 0 + (3., 5.), # id 1 + (7., 11.) # id 2 + ) + def _initializer(shape, dtype, partition_info): + self.assertAllEqual((vocabulary_size, embedding_dimension), shape) + self.assertEqual(dtypes.float32, dtype) + self.assertIsNone(partition_info) + return embedding_values + + expected_lookups = [ + # example 0, ids [2] + [[7., 11.], [0., 0.]], + # example 1, ids [0, 1] + [[1., 2.], [3., 5.]], + # example 2, ids [] + [[0., 0.], [0., 0.]], + # example 3, ids [1] + [[3., 5.], [0., 0.]], + ] + + categorical_column = sfc.sequence_categorical_column_with_identity( + key='aaa', num_buckets=vocabulary_size) + embedding_column = sfc._sequence_embedding_column( + categorical_column, dimension=embedding_dimension, + initializer=_initializer) + + embedding_lookup, _ = embedding_column._get_sequence_dense_tensor( + _LazyBuilder({'aaa': sparse_input})) + + global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) + self.assertItemsEqual( + ('embedding_weights:0',), tuple([v.name for v in global_vars])) + with monitored_session.MonitoredSession() as sess: + self.assertAllEqual(embedding_values, global_vars[0].eval(session=sess)) + self.assertAllEqual(expected_lookups, embedding_lookup.eval(session=sess)) + + def test_sequence_length(self): + vocabulary_size = 3 + sparse_input = sparse_tensor.SparseTensorValue( + # example 0, ids [2] + # example 1, ids [0, 1] + indices=((0, 0), (1, 0), (1, 1)), + values=(2, 0, 1), + dense_shape=(2, 2)) + expected_sequence_length = [1, 2] + + categorical_column = sfc.sequence_categorical_column_with_identity( + key='aaa', num_buckets=vocabulary_size) + embedding_column = sfc._sequence_embedding_column( + categorical_column, dimension=2) + + _, sequence_length = embedding_column._get_sequence_dense_tensor( + _LazyBuilder({'aaa': sparse_input})) + + with monitored_session.MonitoredSession() as sess: + self.assertAllEqual( + expected_sequence_length, sequence_length.eval(session=sess)) + + def test_sequence_length_with_empty_rows(self): + """Tests _sequence_length when some examples do not have ids.""" + vocabulary_size = 3 + sparse_input = sparse_tensor.SparseTensorValue( + # example 0, ids [] + # example 1, ids [2] + # example 2, ids [0, 1] + # example 3, ids [] + # example 4, ids [1] + # example 5, ids [] + indices=((1, 0), (2, 0), (2, 1), (4, 0)), + values=(2, 0, 1, 1), + dense_shape=(6, 2)) + expected_sequence_length = [0, 1, 2, 0, 1, 0] + + categorical_column = sfc.sequence_categorical_column_with_identity( + key='aaa', num_buckets=vocabulary_size) + embedding_column = sfc._sequence_embedding_column( + categorical_column, dimension=2) + + _, sequence_length = embedding_column._get_sequence_dense_tensor( + _LazyBuilder({'aaa': sparse_input})) + + with monitored_session.MonitoredSession() as sess: + self.assertAllEqual( + expected_sequence_length, sequence_length.eval(session=sess)) + + +class SequenceNumericColumnTest(test.TestCase): + + def test_get_sequence_dense_tensor(self): + sparse_input = sparse_tensor.SparseTensorValue( + # example 0, values [[0.], [1]] + # example 1, [[10.]] + indices=((0, 0), (0, 1), (1, 0)), + values=(0., 1., 10.), + dense_shape=(2, 2)) + expected_dense_tensor = [ + [[0.], [1.]], + [[10.], [0.]], + ] + numeric_column = sfc.sequence_numeric_column('aaa') + + dense_tensor, _ = numeric_column._get_sequence_dense_tensor( + _LazyBuilder({'aaa': sparse_input})) + + with monitored_session.MonitoredSession() as sess: + self.assertAllEqual( + expected_dense_tensor, dense_tensor.eval(session=sess)) + + def test_get_sequence_dense_tensor_with_shape(self): + """Tests get_sequence_dense_tensor with shape !=(1,).""" + sparse_input = sparse_tensor.SparseTensorValue( + # example 0, values [[0., 1., 2.], [3., 4., 5.]] + # example 1, [[10., 11., 12.]] + indices=((0, 0), (0, 1), (0, 2), (0, 3), (0, 4), (0, 5), + (1, 0), (1, 1), (1, 2)), + values=(0., 1., 2., 3., 4., 5., 10., 11., 12.), + dense_shape=(2, 6)) + expected_dense_tensor = [ + [[0., 1., 2.], [3., 4., 5.]], + [[10., 11., 12.], [0., 0., 0.]], + ] + numeric_column = sfc.sequence_numeric_column('aaa', shape=(3,)) + + dense_tensor, _ = numeric_column._get_sequence_dense_tensor( + _LazyBuilder({'aaa': sparse_input})) + + with monitored_session.MonitoredSession() as sess: + self.assertAllEqual( + expected_dense_tensor, dense_tensor.eval(session=sess)) + + def test_get_dense_tensor_multi_dim(self): + """Tests get_sequence_dense_tensor for multi-dim numeric_column.""" + sparse_input = sparse_tensor.SparseTensorValue( + # example 0, values [[[0., 1.], [2., 3.]], [[4., 5.], [6., 7.]]] + # example 1, [[[10., 11.], [12., 13.]]] + indices=((0, 0), (0, 1), (0, 2), (0, 3), (0, 4), (0, 5), (0, 6), (0, 7), + (1, 0), (1, 1), (1, 2), (1, 3)), + values=(0., 1., 2., 3., 4., 5., 6., 7., 10., 11., 12., 13.), + dense_shape=(2, 8)) + expected_dense_tensor = [ + [[[0., 1.], [2., 3.]], [[4., 5.], [6., 7.]]], + [[[10., 11.], [12., 13.]], [[0., 0.], [0., 0.]]], + ] + numeric_column = sfc.sequence_numeric_column('aaa', shape=(2, 2)) + + dense_tensor, _ = numeric_column._get_sequence_dense_tensor( + _LazyBuilder({'aaa': sparse_input})) + + with monitored_session.MonitoredSession() as sess: + self.assertAllEqual( + expected_dense_tensor, dense_tensor.eval(session=sess)) + + def test_sequence_length(self): + sparse_input = sparse_tensor.SparseTensorValue( + # example 0, values [[0., 1., 2.], [3., 4., 5.]] + # example 1, [[10., 11., 12.]] + indices=((0, 0), (0, 1), (0, 2), (0, 3), (0, 4), (0, 5), + (1, 0), (1, 1), (1, 2)), + values=(0., 1., 2., 3., 4., 5., 10., 11., 12.), + dense_shape=(2, 6)) + expected_sequence_length = [2, 1] + numeric_column = sfc.sequence_numeric_column('aaa', shape=(3,)) + + _, sequence_length = numeric_column._get_sequence_dense_tensor( + _LazyBuilder({'aaa': sparse_input})) + + with monitored_session.MonitoredSession() as sess: + self.assertAllEqual( + expected_sequence_length, sequence_length.eval(session=sess)) + + def test_sequence_length_with_shape(self): + """Tests _sequence_length with shape !=(1,).""" + sparse_input = sparse_tensor.SparseTensorValue( + # example 0, values [[0.], [1]] + # example 1, [[10.]] + indices=((0, 0), (0, 1), (1, 0)), + values=(0., 1., 10.), + dense_shape=(2, 2)) + expected_sequence_length = [2, 1] + numeric_column = sfc.sequence_numeric_column('aaa') + + _, sequence_length = numeric_column._get_sequence_dense_tensor( + _LazyBuilder({'aaa': sparse_input})) + + with monitored_session.MonitoredSession() as sess: + self.assertAllEqual( + expected_sequence_length, sequence_length.eval(session=sess)) + + def test_sequence_length_with_empty_rows(self): + """Tests _sequence_length when some examples do not have ids.""" + sparse_input = sparse_tensor.SparseTensorValue( + # example 0, values [] + # example 1, values [[0.], [1.]] + # example 2, [[2.]] + # example 3, values [] + # example 4, [[3.]] + # example 5, values [] + indices=((1, 0), (1, 1), (2, 0), (4, 0)), + values=(0., 1., 2., 3.), + dense_shape=(6, 2)) + expected_sequence_length = [0, 2, 1, 0, 1, 0] + numeric_column = sfc.sequence_numeric_column('aaa') + + _, sequence_length = numeric_column._get_sequence_dense_tensor( + _LazyBuilder({'aaa': sparse_input})) + + with monitored_session.MonitoredSession() as sess: + self.assertAllEqual( + expected_sequence_length, sequence_length.eval(session=sess)) + + +if __name__ == '__main__': + test.main()