Cast sequence_length to an integer.
authorA. Unique TensorFlower <gardener@tensorflow.org>
Thu, 1 Mar 2018 21:00:40 +0000 (13:00 -0800)
committerTensorFlower Gardener <gardener@tensorflow.org>
Thu, 1 Mar 2018 21:04:20 +0000 (13:04 -0800)
PiperOrigin-RevId: 187520920

tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column.py
tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_test.py

index e99033b..e446043 100644 (file)
@@ -295,7 +295,7 @@ def _sequence_length_from_sparse_tensor(sp_tensor, num_elements=1):
     row_ids = sp_tensor.indices[:, 0]
     column_ids = sp_tensor.indices[:, 1]
     column_ids += array_ops.ones_like(column_ids)
-    seq_length = (
+    seq_length = math_ops.to_int64(
         math_ops.segment_max(column_ids, segment_ids=row_ids) / num_elements)
     # If the last n rows do not have ids, seq_length will have shape
     # [batch_size - n]. Pad the remaining values with zeros.
index 8c37ccf..1052136 100644 (file)
@@ -221,8 +221,9 @@ class SequenceCategoricalColumnWithIdentityTest(test.TestCase):
     sequence_length = column._sequence_length(_LazyBuilder({'aaa': inputs}))
 
     with monitored_session.MonitoredSession() as sess:
-      self.assertAllEqual(
-          expected_sequence_length, sequence_length.eval(session=sess))
+      sequence_length = sess.run(sequence_length)
+      self.assertAllEqual(expected_sequence_length, sequence_length)
+      self.assertEqual(np.int64, sequence_length.dtype)
 
   def test_sequence_length_with_zeros(self):
     column = sfc.sequence_categorical_column_with_identity(
@@ -311,8 +312,9 @@ class SequenceEmbeddingColumnTest(test.TestCase):
         _LazyBuilder({'aaa': sparse_input}))
 
     with monitored_session.MonitoredSession() as sess:
-      self.assertAllEqual(
-          expected_sequence_length, sequence_length.eval(session=sess))
+      sequence_length = sess.run(sequence_length)
+      self.assertAllEqual(expected_sequence_length, sequence_length)
+      self.assertEqual(np.int64, sequence_length.dtype)
 
   def test_sequence_length_with_empty_rows(self):
     """Tests _sequence_length when some examples do not have ids."""
@@ -423,8 +425,9 @@ class SequenceNumericColumnTest(test.TestCase):
         _LazyBuilder({'aaa': sparse_input}))
 
     with monitored_session.MonitoredSession() as sess:
-      self.assertAllEqual(
-          expected_sequence_length, sequence_length.eval(session=sess))
+      sequence_length = sess.run(sequence_length)
+      self.assertAllEqual(expected_sequence_length, sequence_length)
+      self.assertEqual(np.int64, sequence_length.dtype)
 
   def test_sequence_length_with_shape(self):
     """Tests _sequence_length with shape !=(1,)."""