row_ids = sp_tensor.indices[:, 0]
column_ids = sp_tensor.indices[:, 1]
column_ids += array_ops.ones_like(column_ids)
- seq_length = (
+ seq_length = math_ops.to_int64(
math_ops.segment_max(column_ids, segment_ids=row_ids) / num_elements)
# If the last n rows do not have ids, seq_length will have shape
# [batch_size - n]. Pad the remaining values with zeros.
sequence_length = column._sequence_length(_LazyBuilder({'aaa': inputs}))
with monitored_session.MonitoredSession() as sess:
- self.assertAllEqual(
- expected_sequence_length, sequence_length.eval(session=sess))
+ sequence_length = sess.run(sequence_length)
+ self.assertAllEqual(expected_sequence_length, sequence_length)
+ self.assertEqual(np.int64, sequence_length.dtype)
def test_sequence_length_with_zeros(self):
column = sfc.sequence_categorical_column_with_identity(
_LazyBuilder({'aaa': sparse_input}))
with monitored_session.MonitoredSession() as sess:
- self.assertAllEqual(
- expected_sequence_length, sequence_length.eval(session=sess))
+ sequence_length = sess.run(sequence_length)
+ self.assertAllEqual(expected_sequence_length, sequence_length)
+ self.assertEqual(np.int64, sequence_length.dtype)
def test_sequence_length_with_empty_rows(self):
"""Tests _sequence_length when some examples do not have ids."""
_LazyBuilder({'aaa': sparse_input}))
with monitored_session.MonitoredSession() as sess:
- self.assertAllEqual(
- expected_sequence_length, sequence_length.eval(session=sess))
+ sequence_length = sess.run(sequence_length)
+ self.assertAllEqual(expected_sequence_length, sequence_length)
+ self.assertEqual(np.int64, sequence_length.dtype)
def test_sequence_length_with_shape(self):
"""Tests _sequence_length with shape !=(1,)."""