[tf.data] Adding an experimental `group_by_reducer` transformation which groups eleme...
authorJiri Simsa <jsimsa@google.com>
Tue, 1 May 2018 00:38:38 +0000 (17:38 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Tue, 1 May 2018 00:40:46 +0000 (17:40 -0700)
PiperOrigin-RevId: 194874087

tensorflow/contrib/data/python/kernel_tests/bucketing_test.py
tensorflow/contrib/data/python/kernel_tests/scan_dataset_op_test.py
tensorflow/contrib/data/python/ops/grouping.py
tensorflow/core/api_def/base_api/api_def_GroupByReducerDataset.pbtxt [new file with mode: 0644]
tensorflow/core/kernels/data/BUILD
tensorflow/core/kernels/data/captured_function.cc
tensorflow/core/kernels/data/captured_function.h
tensorflow/core/kernels/data/group_by_reducer_dataset_op.cc [new file with mode: 0644]
tensorflow/core/kernels/data/group_by_window_dataset_op.cc
tensorflow/core/ops/dataset_ops.cc

index 55a56b8..bd3e034 100644 (file)
@@ -28,6 +28,7 @@ from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import errors
 from tensorflow.python.framework import ops
+from tensorflow.python.framework import sparse_tensor
 from tensorflow.python.framework import tensor_shape
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import math_ops
@@ -35,6 +36,179 @@ from tensorflow.python.ops import string_ops
 from tensorflow.python.platform import test
 
 
+class GroupByReducerTest(test.TestCase):
+
+  def checkResults(self, dataset, shapes, values):
+    self.assertEqual(shapes, dataset.output_shapes)
+    get_next = dataset.make_one_shot_iterator().get_next()
+    with self.test_session() as sess:
+      for expected in values:
+        got = sess.run(get_next)
+        self.assertEqual(got, expected)
+      with self.assertRaises(errors.OutOfRangeError):
+        sess.run(get_next)
+
+  def testSum(self):
+    reducer = grouping.Reducer(
+        init_func=lambda _: np.int64(0),
+        reduce_func=lambda x, y: x + y,
+        finalize_func=lambda x: x)
+    for i in range(1, 11):
+      dataset = dataset_ops.Dataset.range(2 * i).apply(
+          grouping.group_by_reducer(lambda x: x % 2, reducer))
+      self.checkResults(
+          dataset, shapes=tensor_shape.scalar(), values=[(i - 1) * i, i * i])
+
+  def testAverage(self):
+
+    def reduce_fn(x, y):
+      return (x[0] * x[1] + math_ops.cast(y, dtypes.float32)) / (
+          x[1] + 1), x[1] + 1
+
+    reducer = grouping.Reducer(
+        init_func=lambda _: (0.0, 0.0),
+        reduce_func=reduce_fn,
+        finalize_func=lambda x: x[0])
+    for i in range(1, 11):
+      dataset = dataset_ops.Dataset.range(2 * i).apply(
+          grouping.group_by_reducer(
+              lambda x: math_ops.cast(x, dtypes.int64) % 2, reducer))
+      self.checkResults(
+          dataset, shapes=tensor_shape.scalar(), values=[i - 1, i])
+
+  def testConcat(self):
+    components = np.array(list("abcdefghijklmnopqrst")).view(np.chararray)
+    reducer = grouping.Reducer(
+        init_func=lambda x: "",
+        reduce_func=lambda x, y: x + y[0],
+        finalize_func=lambda x: x)
+    for i in range(1, 11):
+      dataset = dataset_ops.Dataset.zip(
+          (dataset_ops.Dataset.from_tensor_slices(components),
+           dataset_ops.Dataset.range(2 * i))).apply(
+               grouping.group_by_reducer(lambda x, y: y % 2, reducer))
+      self.checkResults(
+          dataset,
+          shapes=tensor_shape.scalar(),
+          values=[b"acegikmoqs" [:i], b"bdfhjlnprt" [:i]])
+
+  def testSparseSum(self):
+    def _sparse(i):
+      return sparse_tensor.SparseTensorValue(
+          indices=np.array([[0, 0]]),
+          values=(i * np.array([1], dtype=np.int64)),
+          dense_shape=np.array([1, 1]))
+
+    reducer = grouping.Reducer(
+        init_func=lambda _: _sparse(np.int64(0)),
+        reduce_func=lambda x, y: _sparse(x.values[0] + y.values[0]),
+        finalize_func=lambda x: x.values[0])
+    for i in range(1, 11):
+      dataset = dataset_ops.Dataset.range(2 * i).map(_sparse).apply(
+          grouping.group_by_reducer(lambda x: x.values[0] % 2, reducer))
+      self.checkResults(
+          dataset, shapes=tensor_shape.scalar(), values=[(i - 1) * i, i * i])
+
+  def testChangingStateShape(self):
+
+    def reduce_fn(x, _):
+      # Statically known rank, but dynamic length.
+      larger_dim = array_ops.concat([x[0], x[0]], 0)
+      # Statically unknown rank.
+      larger_rank = array_ops.expand_dims(x[1], 0)
+      return larger_dim, larger_rank
+
+    reducer = grouping.Reducer(
+        init_func=lambda x: ([0], 1),
+        reduce_func=reduce_fn,
+        finalize_func=lambda x: x)
+
+    for i in range(1, 11):
+      dataset = dataset_ops.Dataset.from_tensors(np.int64(0)).repeat(i).apply(
+          grouping.group_by_reducer(lambda x: x, reducer))
+      self.assertEqual([None], dataset.output_shapes[0].as_list())
+      self.assertIs(None, dataset.output_shapes[1].ndims)
+      iterator = dataset.make_one_shot_iterator()
+      get_next = iterator.get_next()
+      with self.test_session() as sess:
+        x, y = sess.run(get_next)
+        self.assertAllEqual([0] * (2**i), x)
+        self.assertAllEqual(np.array(1, ndmin=i), y)
+        with self.assertRaises(errors.OutOfRangeError):
+          sess.run(get_next)
+
+  def testTypeMismatch(self):
+    reducer = grouping.Reducer(
+        init_func=lambda x: constant_op.constant(1, dtype=dtypes.int32),
+        reduce_func=lambda x, y: constant_op.constant(1, dtype=dtypes.int64),
+        finalize_func=lambda x: x)
+
+    dataset = dataset_ops.Dataset.range(10)
+    with self.assertRaisesRegexp(
+        TypeError,
+        "The element types for the new state must match the initial state."):
+      dataset.apply(
+          grouping.group_by_reducer(lambda _: np.int64(0), reducer))
+
+  # TODO(b/78665031): Remove once non-scalar keys are supported.
+  def testInvalidKeyShape(self):
+    reducer = grouping.Reducer(
+        init_func=lambda x: np.int64(0),
+        reduce_func=lambda x, y: x + y,
+        finalize_func=lambda x: x)
+
+    dataset = dataset_ops.Dataset.range(10)
+    with self.assertRaisesRegexp(
+        ValueError, "`key_func` must return a single tf.int64 tensor."):
+      dataset.apply(
+          grouping.group_by_reducer(lambda _: np.int64((0, 0)), reducer))
+
+  # TODO(b/78665031): Remove once non-int64 keys are supported.
+  def testInvalidKeyType(self):
+    reducer = grouping.Reducer(
+        init_func=lambda x: np.int64(0),
+        reduce_func=lambda x, y: x + y,
+        finalize_func=lambda x: x)
+
+    dataset = dataset_ops.Dataset.range(10)
+    with self.assertRaisesRegexp(
+        ValueError, "`key_func` must return a single tf.int64 tensor."):
+      dataset.apply(
+          grouping.group_by_reducer(lambda _: "wrong", reducer))
+
+
+class GroupByReducerSerializationTest(
+    dataset_serialization_test_base.DatasetSerializationTestBase):
+
+  def _build_dataset(self, components):
+    reducer = grouping.Reducer(
+        init_func=lambda _: np.int64(0),
+        reduce_func=lambda x, y: x + y,
+        finalize_func=lambda x: x)
+
+    return dataset_ops.Dataset.from_tensor_slices(components).apply(
+        grouping.group_by_reducer(lambda x: x % 5, reducer))
+
+  def testCoreGroupByReducer(self):
+    components = np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=np.int64)
+    self.verify_unused_iterator(
+        lambda: self._build_dataset(components), 5, verify_exhausted=True)
+    self.verify_init_before_restore(
+        lambda: self._build_dataset(components), 5, verify_exhausted=True)
+    self.verify_multiple_breaks(
+        lambda: self._build_dataset(components), 5, verify_exhausted=True)
+    self.verify_reset_restored_iterator(
+        lambda: self._build_dataset(components), 5, verify_exhausted=True)
+    self.verify_restore_in_empty_graph(
+        lambda: self._build_dataset(components), 5, verify_exhausted=True)
+    diff_components = np.array([5, 4, 3, 2, 1, 0], dtype=np.int64)
+    self.verify_restore_in_modified_graph(
+        lambda: self._build_dataset(components),
+        lambda: self._build_dataset(diff_components),
+        5,
+        verify_exhausted=True)
+
+
 class GroupByWindowTest(test.TestCase):
 
   def testSimple(self):
index f544b1c..eb2ceff 100644 (file)
@@ -168,7 +168,7 @@ class ScanDatasetTest(test.TestCase):
           scan_ops.scan(constant_op.constant(1, dtype=dtypes.int32), _scan_fn))
 
 
-class ScanDatasetSerialzationTest(
+class ScanDatasetSerializationTest(
     dataset_serialization_test_base.DatasetSerializationTestBase):
 
   def _build_dataset(self, num_elements):
index 0531f9c..ea229b5 100644 (file)
@@ -26,6 +26,7 @@ from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import function
 from tensorflow.python.framework import ops
+from tensorflow.python.framework import sparse_tensor
 from tensorflow.python.framework import tensor_shape
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import check_ops
@@ -33,6 +34,35 @@ from tensorflow.python.ops import gen_dataset_ops
 from tensorflow.python.ops import math_ops
 
 
+def group_by_reducer(key_func, reducer):
+  """A transformation that groups elements and performs a reduction.
+
+  This transformation maps element of a dataset to a key using `key_func` and
+  groups the elements by key. The `reducer` is used to process each group; its
+  `init_func` is used to initialize state for each group when it is created, the
+  `reduce_func` is used to update the state every time an element is mapped to
+  the matching group, and the `finalize_func` is used to map the final state to
+  an output value.
+
+  Args:
+    key_func: A function mapping a nested structure of tensors
+      (having shapes and types defined by `self.output_shapes` and
+      `self.output_types`) to a scalar `tf.int64` tensor.
+    reducer: An instance of `Reducer`, which captures the reduction logic using
+      the `init_func`, `reduce_func`, and `finalize_func` functions.
+
+  Returns:
+    A `Dataset` transformation function, which can be passed to
+    @{tf.data.Dataset.apply}.
+  """
+
+  def _apply_fn(dataset):
+    """Function from `Dataset` to `Dataset` that applies the transformation."""
+    return GroupByReducerDataset(dataset, key_func, reducer)
+
+  return _apply_fn
+
+
 def group_by_window(key_func,
                     reduce_func,
                     window_size=None,
@@ -227,6 +257,250 @@ class _VariantDataset(dataset_ops.Dataset):
     return self._output_types
 
 
+class GroupByReducerDataset(dataset_ops.Dataset):
+  """A `Dataset` that groups its input and performs a reduction."""
+
+  def __init__(self, input_dataset, key_func, reducer):
+    """See `group_by_reducer()` for details."""
+    super(GroupByReducerDataset, self).__init__()
+
+    self._input_dataset = input_dataset
+
+    self._make_key_func(key_func, input_dataset)
+    self._make_init_func(reducer.init_func)
+    self._make_reduce_func(reducer.reduce_func, input_dataset)
+    self._make_finalize_func(reducer.finalize_func)
+
+  def _make_key_func(self, key_func, input_dataset):
+    """Make wrapping Defun for key_func."""
+
+    @function.Defun(*nest.flatten(
+        sparse.as_dense_types(input_dataset.output_types,
+                              input_dataset.output_classes)))
+    def tf_key_func(*args):
+      """A wrapper for Defun that facilitates shape inference."""
+      # Pass in shape information from the input_dataset.
+      dense_shapes = sparse.as_dense_shapes(input_dataset.output_shapes,
+                                            input_dataset.output_classes)
+      for arg, shape in zip(args, nest.flatten(dense_shapes)):
+        arg.set_shape(shape)
+
+      nested_args = nest.pack_sequence_as(input_dataset.output_types, args)
+      nested_args = sparse.deserialize_sparse_tensors(
+          nested_args, input_dataset.output_types, input_dataset.output_shapes,
+          input_dataset.output_classes)
+      # pylint: disable=protected-access
+      if dataset_ops._should_unpack_args(nested_args):
+        ret = key_func(*nested_args)
+      # pylint: enable=protected-access
+      else:
+        ret = key_func(nested_args)
+      ret = ops.convert_to_tensor(ret)
+      if ret.dtype != dtypes.int64 or ret.get_shape() != tensor_shape.scalar():
+        raise ValueError(
+            "`key_func` must return a single tf.int64 tensor. "
+            "Got type=%s and shape=%s" % (ret.dtype, ret.get_shape()))
+      return ret
+
+    self._key_func = tf_key_func
+    self._key_func.add_to_graph(ops.get_default_graph())
+
+  def _make_init_func(self, init_func):
+    """Make wrapping Defun for init_func."""
+
+    @function.Defun(dtypes.int64)
+    def tf_init_func(key):
+      """A wrapper for Defun that facilitates shape inference."""
+      key.set_shape([])
+      ret = init_func(key)
+      # Convert any `SparseTensorValue`s to `SparseTensor`s and all other
+      # values to tensors.
+      ret = nest.pack_sequence_as(ret, [
+          sparse_tensor.SparseTensor.from_value(t)
+          if sparse_tensor.is_sparse(t) else ops.convert_to_tensor(t)
+          for t in nest.flatten(ret)
+      ])
+
+      self._state_classes = sparse.get_classes(ret)
+      self._state_shapes = nest.pack_sequence_as(
+          ret, [t.get_shape() for t in nest.flatten(ret)])
+      self._state_types = nest.pack_sequence_as(
+          ret, [t.dtype for t in nest.flatten(ret)])
+
+      # Serialize any sparse tensors.
+      ret = nest.pack_sequence_as(
+          ret, [t for t in nest.flatten(sparse.serialize_sparse_tensors(ret))])
+      return nest.flatten(ret)
+
+    self._init_func = tf_init_func
+    self._init_func.add_to_graph(ops.get_default_graph())
+
+  def _make_reduce_func(self, reduce_func, input_dataset):
+    """Make wrapping Defun for reduce_func."""
+
+    # Iteratively rerun the reduce function until reaching a fixed point on
+    # `self._state_shapes`.
+    need_to_rerun = True
+    while need_to_rerun:
+
+      # Create a list in which `tf_reduce_func` will store the new shapes.
+      flat_new_state_shapes = []
+
+      @function.Defun(*(nest.flatten(
+          sparse.as_dense_types(
+              self._state_types, self._state_classes)) + nest.flatten(
+                  sparse.as_dense_types(input_dataset.output_types,
+                                        input_dataset.output_classes))))
+      def tf_reduce_func(*args):
+        """A wrapper for Defun that facilitates shape inference."""
+        for arg, shape in zip(
+            args,
+            nest.flatten(
+                sparse.as_dense_shapes(self._state_shapes, self._state_classes))
+            + nest.flatten(
+                sparse.as_dense_shapes(input_dataset.output_shapes,
+                                       input_dataset.output_classes))):
+          arg.set_shape(shape)
+
+        pivot = len(nest.flatten(self._state_shapes))
+        nested_state_args = nest.pack_sequence_as(self._state_types,
+                                                  args[:pivot])
+        nested_state_args = sparse.deserialize_sparse_tensors(
+            nested_state_args, self._state_types, self._state_shapes,
+            self._state_classes)
+        nested_input_args = nest.pack_sequence_as(input_dataset.output_types,
+                                                  args[pivot:])
+        nested_input_args = sparse.deserialize_sparse_tensors(
+            nested_input_args, input_dataset.output_types,
+            input_dataset.output_shapes, input_dataset.output_classes)
+
+        ret = reduce_func(nested_state_args, nested_input_args)
+
+        # Convert any `SparseTensorValue`s to `SparseTensor`s and all other
+        # values to tensors.
+        ret = nest.pack_sequence_as(ret, [
+            sparse_tensor.SparseTensor.from_value(t)
+            if sparse_tensor.is_sparse(t) else ops.convert_to_tensor(t)
+            for t in nest.flatten(ret)
+        ])
+
+        # Extract shape information from the returned values.
+        flat_new_state = nest.flatten(ret)
+        flat_new_state_shapes.extend([t.get_shape() for t in flat_new_state])
+
+        # Extract and validate type information from the returned values.
+        for t, dtype in zip(flat_new_state, nest.flatten(self._state_types)):
+          if t.dtype != dtype:
+            raise TypeError(
+                "The element types for the new state must match the initial "
+                "state. Expected %s; got %s." %
+                (self._state_types,
+                 nest.pack_sequence_as(self._state_types,
+                                       [t.dtype for t in flat_new_state])))
+
+        # Serialize any sparse tensors.
+        ret = nest.pack_sequence_as(
+            ret,
+            [t for t in nest.flatten(sparse.serialize_sparse_tensors(ret))])
+        return nest.flatten(ret)
+
+      # Use the private method that will execute `tf_reduce_func` but delay
+      # adding it to the graph in case we need to rerun the function.
+      tf_reduce_func._create_definition_if_needed()  # pylint: disable=protected-access
+
+      flat_state_shapes = nest.flatten(self._state_shapes)
+      weakened_state_shapes = [
+          old.most_specific_compatible_shape(new)
+          for old, new in zip(flat_state_shapes, flat_new_state_shapes)
+      ]
+
+      need_to_rerun = False
+      for old_shape, weakened_shape in zip(flat_state_shapes,
+                                           weakened_state_shapes):
+        if old_shape.ndims is not None and (
+            weakened_shape.ndims is None or
+            old_shape.as_list() != weakened_shape.as_list()):
+          need_to_rerun = True
+          break
+
+      if need_to_rerun:
+        self._state_shapes = nest.pack_sequence_as(self._state_shapes,
+                                                   weakened_state_shapes)
+
+    self._reduce_func = tf_reduce_func
+    self._reduce_func.add_to_graph(ops.get_default_graph())
+
+  def _make_finalize_func(self, finalize_func):
+    """Make wrapping Defun for finalize_func."""
+
+    @function.Defun(*(nest.flatten(
+        sparse.as_dense_types(self._state_types, self._state_classes))))
+    def tf_finalize_func(*args):
+      """A wrapper for Defun that facilitates shape inference."""
+      for arg, shape in zip(
+          args,
+          nest.flatten(
+              sparse.as_dense_shapes(self._state_shapes, self._state_classes))):
+        arg.set_shape(shape)
+
+      nested_args = nest.pack_sequence_as(self._state_types, args)
+      nested_args = sparse.deserialize_sparse_tensors(
+          nested_args, self._state_types, self._state_shapes,
+          self._state_classes)
+
+      ret = finalize_func(nested_args)
+
+      # Convert any `SparseTensorValue`s to `SparseTensor`s and all other
+      # values to tensors.
+      ret = nest.pack_sequence_as(ret, [
+          sparse_tensor.SparseTensor.from_value(t)
+          if sparse_tensor.is_sparse(t) else ops.convert_to_tensor(t)
+          for t in nest.flatten(ret)
+      ])
+
+      self._output_classes = sparse.get_classes(ret)
+      self._output_shapes = nest.pack_sequence_as(
+          ret, [t.get_shape() for t in nest.flatten(ret)])
+      self._output_types = nest.pack_sequence_as(
+          ret, [t.dtype for t in nest.flatten(ret)])
+
+      # Serialize any sparse tensors.
+      ret = nest.pack_sequence_as(
+          ret, [t for t in nest.flatten(sparse.serialize_sparse_tensors(ret))])
+      return nest.flatten(ret)
+
+    self._finalize_func = tf_finalize_func
+    self._finalize_func.add_to_graph(ops.get_default_graph())
+
+  @property
+  def output_classes(self):
+    return self._output_classes
+
+  @property
+  def output_shapes(self):
+    return self._output_shapes
+
+  @property
+  def output_types(self):
+    return self._output_types
+
+  def _as_variant_tensor(self):
+    return gen_dataset_ops.group_by_reducer_dataset(
+        self._input_dataset._as_variant_tensor(),  # pylint: disable=protected-access
+        self._key_func.captured_inputs,
+        self._init_func.captured_inputs,
+        self._reduce_func.captured_inputs,
+        self._finalize_func.captured_inputs,
+        key_func=self._key_func,
+        init_func=self._init_func,
+        reduce_func=self._reduce_func,
+        finalize_func=self._finalize_func,
+        output_types=nest.flatten(
+            sparse.as_dense_types(self.output_types, self.output_classes)),
+        output_shapes=nest.flatten(
+            sparse.as_dense_shapes(self.output_shapes, self.output_classes)))
+
+
 class GroupByWindowDataset(dataset_ops.Dataset):
   """A `Dataset` that groups its input and performs a windowed reduction."""
 
@@ -336,3 +610,30 @@ class GroupByWindowDataset(dataset_ops.Dataset):
             sparse.as_dense_types(self.output_types, self.output_classes)),
         output_shapes=nest.flatten(
             sparse.as_dense_shapes(self.output_shapes, self.output_classes)))
+
+
+class Reducer(object):
+  """A reducer is used for reducing a set of elements.
+
+  A reducer is represented as a tuple of the three functions:
+    1) initialization function: key => initial state
+    2) reduce function: (old state, input) => new state
+    3) finalization function: state => result
+  """
+
+  def __init__(self, init_func, reduce_func, finalize_func):
+    self._init_func = init_func
+    self._reduce_func = reduce_func
+    self._finalize_func = finalize_func
+
+  @property
+  def init_func(self):
+    return self._init_func
+
+  @property
+  def reduce_func(self):
+    return self._reduce_func
+
+  @property
+  def finalize_func(self):
+    return self._finalize_func
diff --git a/tensorflow/core/api_def/base_api/api_def_GroupByReducerDataset.pbtxt b/tensorflow/core/api_def/base_api/api_def_GroupByReducerDataset.pbtxt
new file mode 100644 (file)
index 0000000..067ad40
--- /dev/null
@@ -0,0 +1,69 @@
+op {
+  graph_op_name: "GroupByReducerDataset"
+  visibility: HIDDEN
+  in_arg {
+    name: "input_dataset"
+    description: <<END
+A variant tensor representing the input dataset.
+END
+  }
+  in_arg {
+    name: "key_func_other_arguments"
+    description: <<END
+A list of tensors, typically values that were captured when
+building a closure for `key_func`.
+END
+  }
+  attr {
+    name: "key_func"
+    description: <<END
+A function mapping an element of `input_dataset`, concatenated
+with `key_func_other_arguments` to a scalar value of type DT_INT64.
+END
+  }
+  in_arg {
+    name: "init_func_other_arguments"
+    description: <<END
+A list of tensors, typically values that were captured when
+building a closure for `init_func`.
+END
+  }
+  attr {
+    name: "init_func"
+    description: <<END
+A function mapping a key of type DT_INT64, concatenated with
+`init_func_other_arguments` to the initial reducer state.
+END
+  }
+  in_arg {
+    name: "reduce_func_other_arguments"
+    description: <<END
+A list of tensors, typically values that were captured when
+building a closure for `reduce_func`.
+END
+  }
+  attr {
+    name: "reduce_func"
+    description: <<END
+A function mapping the current reducer state and an element of `input_dataset`,
+concatenated with `reduce_func_other_arguments` to a new reducer state.
+END
+  }
+  in_arg {
+    name: "finalize_func_other_arguments"
+    description: <<END
+A list of tensors, typically values that were captured when
+building a closure for `finalize_func`.
+END
+  }
+  attr {
+    name: "finalize_func"
+    description: <<END
+A function mapping the final reducer state to an output element.
+END
+  }
+  summary: "Creates a dataset that computes a group-by on `input_dataset`."
+  description: <<END
+Creates a dataset that computes a group-by on `input_dataset`.
+END
+}
index c78e0af..9ded266 100644 (file)
@@ -124,6 +124,20 @@ tf_kernel_library(
 )
 
 tf_kernel_library(
+    name = "group_by_reducer_dataset_op",
+    srcs = ["group_by_reducer_dataset_op.cc"],
+    deps = [
+        ":captured_function",
+        ":dataset",
+        "//tensorflow/core:core_cpu_internal",
+        "//tensorflow/core:dataset_ops_op_lib",
+        "//tensorflow/core:framework",
+        "//tensorflow/core:lib",
+        "//tensorflow/core:lib_internal",
+    ],
+)
+
+tf_kernel_library(
     name = "group_by_window_dataset_op",
     srcs = ["group_by_window_dataset_op.cc"],
     deps = [
@@ -550,6 +564,7 @@ tf_kernel_library(
         ":filter_dataset_op",
         ":flat_map_dataset_op",
         ":generator_dataset_op",
+        ":group_by_reducer_dataset_op",
         ":group_by_window_dataset_op",
         ":interleave_dataset_op",
         ":iterator_ops",
index dd61b7d..ee58341 100644 (file)
@@ -32,6 +32,20 @@ Status CapturedFunction::Create(
   return Status::OK();
 }
 
+/* static */
+Status CapturedFunction::Create(
+    const NameAttrList& func, OpKernelContext* ctx, const string& argument,
+    std::unique_ptr<CapturedFunction>* out_function) {
+  OpInputList argument_inputs;
+  TF_RETURN_IF_ERROR(ctx->input_list(argument, &argument_inputs));
+  std::vector<Tensor> arguments_t;
+  arguments_t.reserve(argument_inputs.size());
+  for (const Tensor& t : argument_inputs) {
+    arguments_t.push_back(t);
+  }
+  return CapturedFunction::Create(func, std::move(arguments_t), out_function);
+}
+
 CapturedFunction::~CapturedFunction() {
   if (lib_ != nullptr && f_handle_ != kInvalidHandle) {
     lib_->ReleaseHandle(f_handle_).IgnoreError();
index 490f5cd..e9ad3e3 100644 (file)
@@ -40,12 +40,20 @@ class ResourceMgr;
 // context.
 class CapturedFunction {
  public:
+  // Creates a new instance from a list of named attributes and captured inputs.
+  //
   // NOTE(mrry): The `captured_inputs` are passed by value. For
   // efficiency, you are recommended to move this argument into the call.
   static Status Create(const NameAttrList& func,
                        std::vector<Tensor> captured_inputs,
                        std::unique_ptr<CapturedFunction>* out_function);
 
+  // Creates a new instance using a list of named attributes, fetching captured
+  // inputs from a context argument.
+  static Status Create(const NameAttrList& func, OpKernelContext* ctx,
+                       const string& argument,
+                       std::unique_ptr<CapturedFunction>* out_function);
+
   ~CapturedFunction();
 
   // Runs the "Captured function" using the given FLR and caches the lib and
@@ -87,6 +95,9 @@ class CapturedFunction {
                 std::vector<Tensor>* rets,
                 FunctionLibraryRuntime::DoneCallback done);
 
+  // Returns the named list of function arguments.
+  const NameAttrList& func() { return func_; }
+
   // Returns that additional captured inputs that will be passed to the function
   // when `Run*()` is called.
   const std::vector<Tensor>& captured_inputs() { return captured_inputs_; }
diff --git a/tensorflow/core/kernels/data/group_by_reducer_dataset_op.cc b/tensorflow/core/kernels/data/group_by_reducer_dataset_op.cc
new file mode 100644 (file)
index 0000000..c8aeaab
--- /dev/null
@@ -0,0 +1,422 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <map>
+
+#include "tensorflow/core/common_runtime/function.h"
+#include "tensorflow/core/framework/partial_tensor_shape.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/kernels/data/captured_function.h"
+#include "tensorflow/core/kernels/data/dataset.h"
+#include "tensorflow/core/lib/random/random.h"
+
+namespace tensorflow {
+namespace {
+
+// See documentation in ../ops/dataset_ops.cc for a high-level
+// description of the following op.
+class GroupByReducerDatasetOp : public UnaryDatasetOpKernel {
+ public:
+  explicit GroupByReducerDatasetOp(OpKernelConstruction* ctx)
+      : UnaryDatasetOpKernel(ctx),
+        graph_def_version_(ctx->graph_def_version()) {
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("key_func", &key_func_));
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("init_func", &init_func_));
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("reduce_func", &reduce_func_));
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("finalize_func", &finalize_func_));
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_));
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_));
+  }
+
+  void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
+                   DatasetBase** output) override {
+    std::unique_ptr<CapturedFunction> captured_key_func;
+    OP_REQUIRES_OK(ctx, CapturedFunction::Create(key_func_, ctx,
+                                                 "key_func_other_arguments",
+                                                 &captured_key_func));
+    std::unique_ptr<CapturedFunction> captured_init_func;
+    OP_REQUIRES_OK(ctx, CapturedFunction::Create(init_func_, ctx,
+                                                 "init_func_other_arguments",
+                                                 &captured_init_func));
+    std::unique_ptr<CapturedFunction> captured_reduce_func;
+    OP_REQUIRES_OK(ctx, CapturedFunction::Create(reduce_func_, ctx,
+                                                 "reduce_func_other_arguments",
+                                                 &captured_reduce_func));
+    std::unique_ptr<CapturedFunction> captured_finalize_func;
+    OP_REQUIRES_OK(ctx,
+                   CapturedFunction::Create(finalize_func_, ctx,
+                                            "finalize_func_other_arguments",
+                                            &captured_finalize_func));
+
+    *output = new Dataset(
+        ctx, input, std::move(captured_key_func), std::move(captured_init_func),
+        std::move(captured_reduce_func), std::move(captured_finalize_func),
+        output_types_, output_shapes_);
+  }
+
+ private:
+  class Dataset : public GraphDatasetBase {
+   public:
+    Dataset(OpKernelContext* ctx, const DatasetBase* input,
+            std::unique_ptr<CapturedFunction> captured_key_func,
+            std::unique_ptr<CapturedFunction> captured_init_func,
+            std::unique_ptr<CapturedFunction> captured_reduce_func,
+            std::unique_ptr<CapturedFunction> captured_finalize_func,
+            const DataTypeVector& output_types,
+            const std::vector<PartialTensorShape>& output_shapes)
+        : GraphDatasetBase(ctx),
+          input_(input),
+          captured_key_func_(std::move(captured_key_func)),
+          captured_init_func_(std::move(captured_init_func)),
+          captured_reduce_func_(std::move(captured_reduce_func)),
+          captured_finalize_func_(std::move(captured_finalize_func)),
+          output_types_(output_types),
+          output_shapes_(output_shapes) {
+      input_->Ref();
+    }
+
+    ~Dataset() override { input_->Unref(); }
+
+    std::unique_ptr<IteratorBase> MakeIterator(
+        const string& prefix) const override {
+      return std::unique_ptr<IteratorBase>(
+          new Iterator({this, strings::StrCat(prefix, "::GroupByReducer")}));
+    }
+
+    const DataTypeVector& output_dtypes() const override {
+      return output_types_;
+    }
+    const std::vector<PartialTensorShape>& output_shapes() const override {
+      return output_shapes_;
+    }
+
+    string DebugString() override { return "GroupByReducerDatasetOp::Dataset"; }
+
+   protected:
+    Status AsGraphDefInternal(OpKernelContext* ctx, DatasetGraphDefBuilder* b,
+                              Node** output) const override {
+      TF_RETURN_IF_ERROR(b->AddFunction(ctx, key_func().name()));
+      TF_RETURN_IF_ERROR(b->AddFunction(ctx, init_func().name()));
+      TF_RETURN_IF_ERROR(b->AddFunction(ctx, reduce_func().name()));
+      TF_RETURN_IF_ERROR(b->AddFunction(ctx, finalize_func().name()));
+      Node* input_graph_node = nullptr;
+      TF_RETURN_IF_ERROR(b->AddParentDataset(ctx, input_, &input_graph_node));
+
+      std::vector<Node*> key_func_other_arguments_node;
+      DataTypeVector key_func_other_arguments_types;
+      TF_RETURN_IF_ERROR(OtherArgumentsNodeAndType(
+          b, captured_key_func_, &key_func_other_arguments_node,
+          &key_func_other_arguments_types));
+
+      std::vector<Node*> init_func_other_arguments_node;
+      DataTypeVector init_func_other_arguments_types;
+      TF_RETURN_IF_ERROR(OtherArgumentsNodeAndType(
+          b, captured_init_func_, &init_func_other_arguments_node,
+          &init_func_other_arguments_types));
+
+      std::vector<Node*> reduce_func_other_arguments_node;
+      DataTypeVector reduce_func_other_arguments_types;
+      TF_RETURN_IF_ERROR(OtherArgumentsNodeAndType(
+          b, captured_reduce_func_, &reduce_func_other_arguments_node,
+          &reduce_func_other_arguments_types));
+
+      std::vector<Node*> finalize_func_other_arguments_node;
+      DataTypeVector finalize_func_other_arguments_types;
+      TF_RETURN_IF_ERROR(OtherArgumentsNodeAndType(
+          b, captured_finalize_func_, &finalize_func_other_arguments_node,
+          &finalize_func_other_arguments_types));
+
+      AttrValue key_func;
+      b->BuildAttrValue(this->key_func(), &key_func);
+      AttrValue init_func;
+      b->BuildAttrValue(this->init_func(), &init_func);
+      AttrValue reduce_func;
+      b->BuildAttrValue(this->reduce_func(), &reduce_func);
+      AttrValue finalize_func;
+      b->BuildAttrValue(this->finalize_func(), &finalize_func);
+
+      AttrValue key_func_other_arguments_types_attr;
+      b->BuildAttrValue(key_func_other_arguments_types,
+                        &key_func_other_arguments_types_attr);
+      AttrValue init_func_other_arguments_types_attr;
+      b->BuildAttrValue(init_func_other_arguments_types,
+                        &init_func_other_arguments_types_attr);
+      AttrValue reduce_func_other_arguments_types_attr;
+      b->BuildAttrValue(reduce_func_other_arguments_types,
+                        &reduce_func_other_arguments_types_attr);
+      AttrValue finalize_func_other_arguments_types_attr;
+      b->BuildAttrValue(finalize_func_other_arguments_types,
+                        &finalize_func_other_arguments_types_attr);
+
+      TF_RETURN_IF_ERROR(b->AddDataset(
+          this, {{0, input_graph_node}},
+          {{1, key_func_other_arguments_node},
+           {2, init_func_other_arguments_node},
+           {3, reduce_func_other_arguments_node},
+           {4, finalize_func_other_arguments_node}},
+          {{"key_func", key_func},
+           {"init_func", init_func},
+           {"reduce_func", reduce_func},
+           {"finalize_func", finalize_func},
+           {"Tkey_func_other_arguments", key_func_other_arguments_types_attr},
+           {"Tinit_func_other_arguments", init_func_other_arguments_types_attr},
+           {"Treduce_func_other_arguments",
+            reduce_func_other_arguments_types_attr},
+           {"Tfinalize_func_other_arguments",
+            finalize_func_other_arguments_types_attr}},
+          output));
+      return Status::OK();
+    }
+
+   private:
+    class Iterator : public DatasetIterator<Dataset> {
+     public:
+      explicit Iterator(const Params& params)
+          : DatasetIterator<Dataset>(params),
+            input_impl_(params.dataset->input_->MakeIterator(params.prefix)) {}
+
+      Status GetNextInternal(IteratorContext* ctx,
+                             std::vector<Tensor>* out_tensors,
+                             bool* end_of_sequence) override {
+        mutex_lock l(mu_);
+
+        // Iterate through the input dataset, keying input elements to reducers.
+        while (!end_of_input_) {
+          std::vector<Tensor> next_input_element;
+          TF_RETURN_IF_ERROR(
+              input_impl_->GetNext(ctx, &next_input_element, &end_of_input_));
+
+          if (!end_of_input_) {
+            // Run the key function on the input element.
+            std::vector<Tensor> key_func_output;
+            TF_RETURN_IF_ERROR(
+                dataset()->captured_key_func_->RunWithBorrowedArgs(
+                    ctx, next_input_element, &key_func_output));
+
+            if (key_func_output.size() != 1 ||
+                key_func_output[0].dtype() != DT_INT64 ||
+                key_func_output[0].NumElements() != 1) {
+              // TODO(b/78665031): Support non-int64 keys.
+              return errors::InvalidArgument(
+                  "`key_func` must return a scalar int64.");
+            }
+            const int64 key = key_func_output[0].scalar<int64>()();
+
+            if (states_.find(key) == states_.end()) {
+              // Run the init function to create the initial state.
+              std::vector<Tensor> init_func_output;
+              TF_RETURN_IF_ERROR(dataset()->captured_init_func_->Run(
+                  ctx, std::move(key_func_output), &init_func_output));
+              states_[key] = init_func_output;
+            }
+
+            // Run the reduce function to update the current state.
+            std::vector<Tensor> args;
+            args.reserve(states_[key].size() + next_input_element.size());
+            std::copy(states_[key].begin(), states_[key].end(),
+                      std::back_inserter(args));
+            std::copy(next_input_element.begin(), next_input_element.end(),
+                      std::back_inserter(args));
+
+            std::vector<Tensor> reduce_func_output;
+            TF_RETURN_IF_ERROR(dataset()->captured_reduce_func_->Run(
+                ctx, std::move(args), &reduce_func_output));
+            states_[key] = reduce_func_output;
+          } else {
+            keys_.resize(states_.size());
+            int idx = 0;
+            for (auto it = states_.begin(); it != states_.end(); ++idx, ++it) {
+              keys_[idx] = it->first;
+            }
+          }
+        }
+
+        if (keys_index_ == keys_.size()) {
+          *end_of_sequence = true;
+          return Status::OK();
+        }
+        TF_RETURN_IF_ERROR(
+            dataset()->captured_finalize_func_->RunWithBorrowedArgs(
+                ctx, states_[keys_[keys_index_++]], out_tensors));
+        return Status::OK();
+      }
+
+     protected:
+      Status SaveInternal(IteratorStateWriter* writer) override {
+        mutex_lock l(mu_);
+        TF_RETURN_IF_ERROR(SaveParent(writer, input_impl_));
+
+        if (end_of_input_) {
+          TF_RETURN_IF_ERROR(
+              writer->WriteScalar(full_name("end_of_input"), ""));
+        }
+
+        // Saving states_.
+        if (!states_.empty()) {
+          TF_RETURN_IF_ERROR(
+              writer->WriteScalar(full_name("states_size"), states_.size()));
+          int idx = 0;
+          for (auto it = states_.begin(); it != states_.end(); ++idx, ++it) {
+            int64 key = it->first;
+            TF_RETURN_IF_ERROR(writer->WriteScalar(
+                full_name(strings::StrCat("states[", idx, "]->key")), key));
+            if (!it->second.empty()) {
+              TF_RETURN_IF_ERROR(writer->WriteScalar(
+                  full_name(strings::StrCat("states[", idx, "]->state_size")),
+                  it->second.size()));
+              for (int j = 0; j < it->second.size(); ++j) {
+                TF_RETURN_IF_ERROR(writer->WriteTensor(
+                    full_name(
+                        strings::StrCat("states[", idx, "]->state[", j, "]")),
+                    it->second[j]));
+              }
+            }
+          }
+        }
+
+        // Saving keys_index_ and keys_.
+        if (end_of_input_) {
+          TF_RETURN_IF_ERROR(
+              writer->WriteScalar(full_name("keys_index"), keys_index_));
+          if (!keys_.empty()) {
+            TF_RETURN_IF_ERROR(
+                writer->WriteScalar(full_name("keys_size"), keys_.size()));
+            for (int idx = 0; idx < keys_.size(); ++idx) {
+              TF_RETURN_IF_ERROR(writer->WriteScalar(
+                  full_name(strings::StrCat("keys[", idx, "]")), keys_[idx]));
+            }
+          }
+        }
+
+        return Status::OK();
+      }
+
+      Status RestoreInternal(IteratorContext* ctx,
+                             IteratorStateReader* reader) override {
+        mutex_lock l(mu_);
+        TF_RETURN_IF_ERROR(RestoreParent(ctx, reader, input_impl_));
+
+        if (reader->Contains(full_name("end_of_input"))) end_of_input_ = true;
+
+        // Restoring states_.
+        if (reader->Contains(full_name("states_size"))) {
+          int64 size;
+          TF_RETURN_IF_ERROR(
+              reader->ReadScalar(full_name("states_size"), &size));
+          for (int idx = 0; idx < size; ++idx) {
+            int64 key;
+            TF_RETURN_IF_ERROR(reader->ReadScalar(
+                full_name(strings::StrCat("states[", idx, "]->key")), &key));
+            std::vector<Tensor> state;
+            if (reader->Contains(full_name(
+                    strings::StrCat("states[", idx, "]->state_size")))) {
+              int64 state_size;
+              TF_RETURN_IF_ERROR(reader->ReadScalar(
+                  full_name(strings::StrCat("states[", idx, "]->state_size")),
+                  &state_size));
+              state.resize(state_size);
+              for (int j = 0; j < state_size; ++j) {
+                TF_RETURN_IF_ERROR(reader->ReadTensor(
+                    full_name(
+                        strings::StrCat("states[", idx, "]->state[", j, "]")),
+                    &state[j]));
+              }
+            }
+            states_[key] = state;
+          }
+        }
+
+        // Restoring keys_index_ and keys_.
+        if (end_of_input_) {
+          TF_RETURN_IF_ERROR(
+              reader->ReadScalar(full_name("keys_index"), &keys_index_));
+          if (reader->Contains(full_name("keys_size"))) {
+            int64 size;
+            TF_RETURN_IF_ERROR(
+                reader->ReadScalar(full_name("keys_size"), &size));
+            keys_.resize(size);
+            for (int idx = 0; idx < size; ++idx) {
+              int64 key;
+              TF_RETURN_IF_ERROR(reader->ReadScalar(
+                  full_name(strings::StrCat("keys[", idx, "]")), &key));
+              keys_[idx] = key;
+            }
+          }
+        }
+
+        return Status::OK();
+      }
+
+     private:
+      mutex mu_;
+      std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(mu_);
+      bool end_of_input_ GUARDED_BY(mu_) = false;
+      std::map<int64, std::vector<Tensor>> states_ GUARDED_BY(mu_);
+      std::vector<int64> keys_ GUARDED_BY(mu_);
+      int64 keys_index_ GUARDED_BY(mu_) = 0;
+    };
+
+    const NameAttrList& key_func() const { return captured_key_func_->func(); }
+
+    const NameAttrList& init_func() const {
+      return captured_init_func_->func();
+    }
+
+    const NameAttrList& reduce_func() const {
+      return captured_reduce_func_->func();
+    }
+
+    const NameAttrList& finalize_func() const {
+      return captured_finalize_func_->func();
+    }
+
+    Status OtherArgumentsNodeAndType(
+        DatasetGraphDefBuilder* b,
+        const std::unique_ptr<CapturedFunction>& captured_func,
+        std::vector<Node*>* other_arguments_node,
+        DataTypeVector* other_arguments_types) const {
+      other_arguments_node->reserve(captured_func->captured_inputs().size());
+      other_arguments_types->reserve(captured_func->captured_inputs().size());
+      for (const Tensor& t : captured_func->captured_inputs()) {
+        Node* node;
+        TF_RETURN_IF_ERROR(b->AddTensor(t, &node));
+        other_arguments_node->emplace_back(node);
+        other_arguments_types->emplace_back(t.dtype());
+      }
+      return Status::OK();
+    }
+
+    const DatasetBase* const input_;
+    const std::unique_ptr<CapturedFunction> captured_key_func_;
+    const std::unique_ptr<CapturedFunction> captured_init_func_;
+    const std::unique_ptr<CapturedFunction> captured_reduce_func_;
+    const std::unique_ptr<CapturedFunction> captured_finalize_func_;
+    const DataTypeVector output_types_;
+    const std::vector<PartialTensorShape> output_shapes_;
+  };
+
+  const int graph_def_version_;
+  DataTypeVector output_types_;
+  std::vector<PartialTensorShape> output_shapes_;
+  NameAttrList key_func_;
+  NameAttrList init_func_;
+  NameAttrList reduce_func_;
+  NameAttrList finalize_func_;
+};
+
+REGISTER_KERNEL_BUILDER(Name("GroupByReducerDataset").Device(DEVICE_CPU),
+                        GroupByReducerDatasetOp);
+
+}  // namespace
+}  // namespace tensorflow
index 46f43dd..03f847c 100644 (file)
@@ -241,7 +241,7 @@ class GroupByWindowDatasetOp : public UnaryDatasetOpKernel {
               if (key_func_output.size() != 1 ||
                   key_func_output[0].dtype() != DT_INT64 ||
                   key_func_output[0].NumElements() != 1) {
-                // TODO(mrry): Support non-int64 keys.
+                // TODO(b/78665031): Support non-int64 keys.
                 return errors::InvalidArgument(
                     "`key_func` must return a scalar int64.");
               }
index 4ba3f15..5f10ad2 100644 (file)
@@ -270,6 +270,26 @@ REGISTER_OP("ParallelInterleaveDataset")
     .Attr("output_shapes: list(shape) >= 1")
     .SetShapeFn(shape_inference::ScalarShape);
 
+REGISTER_OP("GroupByReducerDataset")
+    .Input("input_dataset: variant")
+    .Input("key_func_other_arguments: Tkey_func_other_arguments")
+    .Input("init_func_other_arguments: Tinit_func_other_arguments")
+    .Input("reduce_func_other_arguments: Treduce_func_other_arguments")
+    .Input("finalize_func_other_arguments: Tfinalize_func_other_arguments")
+    .Output("handle: variant")
+    .Attr("key_func: func")
+    .Attr("init_func: func")
+    .Attr("reduce_func: func")
+    .Attr("finalize_func: func")
+    .Attr("Tkey_func_other_arguments: list(type) >= 0")
+    .Attr("Tinit_func_other_arguments: list(type) >= 0")
+    .Attr("Treduce_func_other_arguments: list(type) >= 0")
+    .Attr("Tfinalize_func_other_arguments: list(type) >= 0")
+    .Attr("output_types: list(type) >= 1")
+    .Attr("output_shapes: list(shape) >= 1")
+    .SetIsStateful()
+    .SetShapeFn(shape_inference::ScalarShape);
+
 REGISTER_OP("GroupByWindowDataset")
     .Input("input_dataset: variant")
     .Input("key_func_other_arguments: Tkey_func_other_arguments")