import itertools

import numpy as np

from tensorflow.contrib.data.python.ops import scan_ops
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
+from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test


class ScanDatasetTest(test.TestCase):
- def _count(self, start, step):
- return dataset_ops.Dataset.from_tensors(0).repeat(None).apply(
- scan_ops.scan(start, lambda state, _: (state + step, state)))
+ def _counting_dataset(self, start, scan_fn):
+ return dataset_ops.Dataset.from_tensors(0).repeat().apply(
+ scan_ops.scan(start, scan_fn))

  def testCount(self):
+ def make_scan_fn(step):
+ return lambda state, _: (state + step, state)
+
start = array_ops.placeholder(dtypes.int32, shape=[])
step = array_ops.placeholder(dtypes.int32, shape=[])
take = array_ops.placeholder(dtypes.int64, shape=[])
- iterator = self._count(start, step).take(take).make_initializable_iterator()
+ iterator = self._counting_dataset(
+ start, make_scan_fn(step)).take(take).make_initializable_iterator()
next_element = iterator.get_next()
    with self.test_session() as sess:
      for start_val, step_val, take_val in [(0, 1, 10), (0, 1, 0), (10, 1, 10),
                                            (10, 2, 10), (10, -1, 10),
                                            (10, -2, 10)]:
        sess.run(iterator.initializer,
                 feed_dict={start: start_val, step: step_val, take: take_val})
        for expected, _ in zip(
            itertools.count(start_val, step_val), range(take_val)):
          self.assertEqual(expected, sess.run(next_element))
        with self.assertRaises(errors.OutOfRangeError):
          sess.run(next_element)

+ def testSparseCount(self):
+ def _sparse(i):
+ return sparse_tensor.SparseTensorValue(
+ indices=np.array([[0, 0]]),
+ values=(i * np.array([1])),
+ dense_shape=np.array([1, 1]))
+
+ def make_scan_fn(step):
+ return lambda state, _: (_sparse(state.values[0] + step), state)
+
+ start = array_ops.placeholder(dtypes.int32, shape=[])
+ step = array_ops.placeholder(dtypes.int32, shape=[])
+ take = array_ops.placeholder(dtypes.int64, shape=[])
+ iterator = self._counting_dataset(
+ _sparse(start),
+ make_scan_fn(step)).take(take).make_initializable_iterator()
+ next_element = iterator.get_next()
+
+ with self.test_session() as sess:
+ for start_val, step_val, take_val in [(0, 1, 10), (0, 1, 0), (10, 1, 10),
+ (10, 2, 10), (10, -1, 10),
+ (10, -2, 10)]:
+ sess.run(iterator.initializer,
+ feed_dict={start: start_val, step: step_val, take: take_val})
+ for expected, _ in zip(
+ itertools.count(start_val, step_val), range(take_val)):
+ self.assertEqual(expected, sess.run(next_element).values[0])
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(next_element)
+
  def testChangingStateShape(self):
    # Test the fixed-point shape invariant calculations: start with
    # initial values with known shapes, and use a scan function that
    # changes the size of the state on each element.
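
For context, the contract exercised above and implemented below: `scan_ops.scan(initial_state, scan_fn)` threads a state through the dataset, where `scan_fn` maps `(state, element)` to `(new_state, output)` and only the outputs are emitted. A minimal dense sketch (assuming graph-mode TF 1.x and the imports already used in this test file):

    # Emits the running sums 0, 0, 1, 3, 6 for the elements 0..4.
    dataset = dataset_ops.Dataset.range(5).apply(
        scan_ops.scan(constant_op.constant(0, dtype=dtypes.int64),
                      lambda state, x: (state + x, state)))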
import collections

from tensorflow.python.data.util import nest
from tensorflow.python.data.util import sparse
from tensorflow.python.framework import function
from tensorflow.python.framework import ops
+from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import gen_dataset_ops
self._input_dataset = input_dataset
with ops.name_scope("initial_state"):
+ # Convert any `SparseTensorValue`s to `SparseTensor`s and all other
+ # values to tensors.
self._initial_state = nest.pack_sequence_as(initial_state, [
- ops.convert_to_tensor(t, name="component_%d" % i)
+ sparse_tensor.SparseTensor.from_value(t)
+ if sparse_tensor.is_sparse(t) else ops.convert_to_tensor(
+ t, name="component_%d" % i)
for i, t in enumerate(nest.flatten(initial_state))
])
- # Compute initial values for the state shapes and types based on
- # the initial state. These will be refined by running
- # `tf_scan_func` one or more times below.
- # TODO(b/68937811): Allow the initial state to be a tf.SparseTensor.
+ # Compute initial values for the state classes, shapes and types based on
+ # the initial state. The shapes may be refined by running `tf_scan_func` one
+ # or more times below.
+ self._state_classes = sparse.get_classes(self._initial_state)
self._state_shapes = nest.pack_sequence_as(
self._initial_state,
- [t.shape for t in nest.flatten(self._initial_state)])
+ [t.get_shape() for t in nest.flatten(self._initial_state)])
self._state_types = nest.pack_sequence_as(
self._initial_state,
[t.dtype for t in nest.flatten(self._initial_state)])
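+      # `tf_scan_func` is traced repeatedly until the state shapes reach a
+      # fixed point: if a trace produces a different state shape, the shapes
+      # are weakened to the most specific compatible shape and the function
+      # is traced again.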
need_to_rerun = True
while need_to_rerun:
- flat_state_shapes = nest.flatten(self._state_shapes)
- flat_state_types = nest.flatten(self._state_types)
-
      # Create a list in which `tf_scan_func` will store the new shapes.
flat_new_state_shapes = []
- @function.Defun(*(flat_state_types + nest.flatten(
- sparse.as_dense_types(input_dataset.output_types,
- input_dataset.output_classes))))
+ @function.Defun(*(nest.flatten(
+ sparse.as_dense_types(
+ self._state_types, self._state_classes)) + nest.flatten(
+ sparse.as_dense_types(input_dataset.output_types,
+ input_dataset.output_classes))))
def tf_scan_func(*args):
"""A wrapper for Defun that facilitates shape inference."""
# Pass in shape information from the state and input_dataset.
- # TODO(b/69424092): Check that neither inputs nor outputs are sparse.
- dense_shapes = sparse.as_dense_shapes(input_dataset.output_shapes,
- input_dataset.output_classes)
- for arg, shape in zip(args,
- flat_state_shapes + nest.flatten(dense_shapes)):
+ for arg, shape in zip(
+ args,
+ nest.flatten(
+ sparse.as_dense_shapes(self._state_shapes, self._state_classes))
+ + nest.flatten(
+ sparse.as_dense_shapes(input_dataset.output_shapes,
+ input_dataset.output_classes))):
arg.set_shape(shape)
- pivot = len(flat_state_shapes)
- old_state = nest.pack_sequence_as(self._initial_state, args[:pivot])
- input_value = nest.pack_sequence_as(input_dataset.output_types,
- args[pivot:])
-
- ret = scan_func(old_state, input_value)
+ pivot = len(nest.flatten(self._state_shapes))
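+      # Repack the flat state arguments into the caller's structure,
+      # decoding any serialized sparse tensors back into `SparseTensor`s.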
+ nested_state_args = nest.pack_sequence_as(self._state_types,
+ args[:pivot])
+ nested_state_args = sparse.deserialize_sparse_tensors(
+ nested_state_args, self._state_types, self._state_shapes,
+ self._state_classes)
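+      # Likewise repack the input element, decoding any sparse components.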
+ nested_input_args = nest.pack_sequence_as(input_dataset.output_types,
+ args[pivot:])
+ nested_input_args = sparse.deserialize_sparse_tensors(
+ nested_input_args, input_dataset.output_types,
+ input_dataset.output_shapes, input_dataset.output_classes)
+
+ ret = scan_func(nested_state_args, nested_input_args)
if not isinstance(ret, collections.Sequence) or len(ret) != 2:
raise TypeError("The scan function must return a pair comprising the "
"new state and the output value.")
+
+ # Convert any `SparseTensorValue`s to `SparseTensor`s and all other
+ # values to tensors.
+ ret = nest.pack_sequence_as(ret, [
+ sparse_tensor.SparseTensor.from_value(t)
+ if sparse_tensor.is_sparse(t) else ops.convert_to_tensor(t)
+ for t in nest.flatten(ret)
+ ])
new_state, output_value = ret
- flat_new_state = [
- ops.convert_to_tensor(t) for t in nest.flatten(new_state)
- ]
- flat_output_value = [
- ops.convert_to_tensor(t) for t in nest.flatten(output_value)
- ]
+ # Extract and validate class information from the returned values.
+ for t, clazz in zip(
+ nest.flatten(new_state), nest.flatten(self._state_classes)):
+ if not isinstance(t, clazz):
+ raise TypeError(
+ "The element classes for the new state must match the initial "
+ "state. Expected %s; got %s." %
+ (self._state_classes,
+ nest.pack_sequence_as(
+ self._state_types,
+ [type(t) for t in nest.flatten(new_state)])))
+ self._output_classes = sparse.get_classes(output_value)
# Extract shape information from the returned values.
- flat_new_state_shapes.extend([t.shape for t in flat_new_state])
+ flat_new_state_shapes.extend(
+ [t.get_shape() for t in nest.flatten(new_state)])
self._output_shapes = nest.pack_sequence_as(
- output_value, [t.shape for t in flat_output_value])
+ output_value, [t.get_shape() for t in nest.flatten(output_value)])
# Extract and validate type information from the returned values.
- for t, dtype in zip(flat_new_state, flat_state_types):
+ for t, dtype in zip(
+ nest.flatten(new_state), nest.flatten(self._state_types)):
if t.dtype != dtype:
raise TypeError(
"The element types for the new state must match the initial "
"state. Expected %s; got %s." %
- (self._state_types, nest.pack_sequence_as(
- self._state_types, [t.dtype for t in flat_new_state])))
- self._output_classes = nest.pack_sequence_as(
- output_value, [ops.Tensor for _ in flat_output_value])
+ (self._state_types,
+ nest.pack_sequence_as(
+ self._state_types,
+ [t.dtype for t in nest.flatten(new_state)])))
self._output_types = nest.pack_sequence_as(
- output_value, [t.dtype for t in flat_output_value])
-
- return flat_new_state + flat_output_value
+ output_value, [t.dtype for t in nest.flatten(output_value)])
+
+ # Serialize any sparse tensors.
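+      # (A `Defun` can only pass and return dense tensors, so sparse values
+      # are round-tripped through their serialized dense representation.)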
+ new_state = nest.pack_sequence_as(new_state, [
+ t for t in nest.flatten(sparse.serialize_sparse_tensors(new_state))
+ ])
+ output_value = nest.pack_sequence_as(output_value, [
+ t for t in nest.flatten(
+ sparse.serialize_sparse_tensors(output_value))
+ ])
+ return nest.flatten(new_state) + nest.flatten(output_value)
# Use the private method that will execute `tf_scan_func` but delay
# adding it to the graph in case we need to rerun the function.
tf_scan_func._create_definition_if_needed() # pylint: disable=protected-access
+ flat_state_shapes = nest.flatten(self._state_shapes)
      weakened_state_shapes = [
          original.most_specific_compatible_shape(new)
          for original, new in zip(flat_state_shapes, flat_new_state_shapes)
      ]
input_t = self._input_dataset._as_variant_tensor() # pylint: disable=protected-access
return gen_dataset_ops.scan_dataset(
input_t,
- nest.flatten(self._initial_state),
+ nest.flatten(sparse.serialize_sparse_tensors(self._initial_state)),
self._scan_func.captured_inputs,
f=self._scan_func,
output_types=nest.flatten(