[DataPipe] Improve Mapper to accept input/output index when apply fn (#64951)
authorErjia Guan <erjia@fb.com>
Tue, 14 Sep 2021 22:44:57 +0000 (15:44 -0700)
committerFacebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Tue, 14 Sep 2021 22:46:42 +0000 (15:46 -0700)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/64951

Test Plan: Imported from OSS

Reviewed By: VitalyFedyunin

Differential Revision: D30910035

Pulled By: ejguan

fbshipit-source-id: d687fe10939920a3617a60552fe743e8526438a0

test/test_datapipe.py
torch/utils/data/datapipes/iter/callable.py

index 4d5bd7c..f19a2e0 100644 (file)
@@ -38,7 +38,7 @@ import torch.utils.data.backward_compatibility
 import torch.utils.data.datapipes as dp
 import torch.utils.data.graph
 import torch.utils.data.sharding
-from torch.testing._internal.common_utils import TestCase, run_tests
+from torch.testing._internal.common_utils import TestCase, run_tests, suppress_warnings
 from torch.utils.data import (
     DataLoader,
     DataChunk,
@@ -902,7 +902,7 @@ class TestFunctionalIterDataPipe(TestCase):
         with self.assertRaises(TypeError):
             len(dp2)
 
-
+    @suppress_warnings  # Suppress warning for lambda fn
     def test_map_datapipe(self):
         input_dp = IDP(range(10))
 
@@ -927,12 +927,137 @@ class TestFunctionalIterDataPipe(TestCase):
             self.assertEqual(x, torch.tensor(y, dtype=torch.int).sum())
 
         input_dp_nl = IDP_NoLen(range(10))
-        map_dp_nl = input_dp_nl.map()
+        map_dp_nl = input_dp_nl.map(lambda x: x)
         with self.assertRaisesRegex(TypeError, r"instance doesn't have valid length$"):
             len(map_dp_nl)
         for x, y in zip(map_dp_nl, input_dp_nl):
             self.assertEqual(x, torch.tensor(y, dtype=torch.float))
 
+    @suppress_warnings  # Suppress warning for lambda fn
+    def test_map_tuple_list_with_col_datapipe(self):
+        def fn_11(d):
+            return -d
+
+        def fn_1n(d):
+            return -d, d
+
+        def fn_n1(d0, d1):
+            return d0 + d1
+
+        def fn_nn(d0, d1):
+            return -d0, -d1, d0 + d1
+
+        def _helper(ref_fn, fn, input_col=None, output_col=None):
+            for constr in (list, tuple):
+                datapipe = IDP([constr((0, 1, 2)), constr((3, 4, 5)), constr((6, 7, 8))])
+                res_dp = datapipe.map(fn, input_col, output_col)
+                ref_dp = datapipe.map(ref_fn)
+                self.assertEqual(list(res_dp), list(ref_dp))
+                # Reset
+                self.assertEqual(list(res_dp), list(ref_dp))
+
+        # Replacing with one input column and default output column
+        _helper(lambda data: (data[0], -data[1], data[2]), fn_11, 1)
+        _helper(lambda data: (data[0], (-data[1], data[1]), data[2]), fn_1n, 1)
+        # The index of input column is out of range
+        with self.assertRaises(IndexError):
+            _helper(None, fn_1n, 3)
+        # Unmatched input columns with fn arguments
+        with self.assertRaises(TypeError):
+            _helper(None, fn_n1, 1)
+        # Replacing with multiple input columns and default output column (the left-most input column)
+        _helper(lambda data: (data[1], data[2] + data[0]), fn_n1, [2, 0])
+        _helper(lambda data: (data[0], (-data[2], -data[1], data[2] + data[1])), fn_nn, [2, 1])
+
+        # output_col can only be specified when input_col is not None
+        with self.assertRaises(ValueError):
+            _helper(None, fn_n1, None, 1)
+        # output_col can only be single-element list or tuple
+        with self.assertRaises(ValueError):
+            _helper(None, fn_n1, None, [0, 1])
+        # Single-element list as output_col
+        _helper(lambda data: (-data[1], data[1], data[2]), fn_11, 1, [0])
+        # Replacing with one input column and single specified output column
+        _helper(lambda data: (-data[1], data[1], data[2]), fn_11, 1, 0)
+        _helper(lambda data: (data[0], data[1], (-data[1], data[1])), fn_1n, 1, 2)
+        # The index of output column is out of range
+        with self.assertRaises(IndexError):
+            _helper(None, fn_1n, 1, 3)
+        _helper(lambda data: (data[0], data[0] + data[2], data[2]), fn_n1, [0, 2], 1)
+        _helper(lambda data: ((-data[1], -data[2], data[1] + data[2]), data[1], data[2]), fn_nn, [1, 2], 0)
+
+        # Appending the output at the end
+        _helper(lambda data: (*data, -data[1]), fn_11, 1, -1)
+        _helper(lambda data: (*data, (-data[1], data[1])), fn_1n, 1, -1)
+        _helper(lambda data: (*data, data[0] + data[2]), fn_n1, [0, 2], -1)
+        _helper(lambda data: (*data, (-data[1], -data[2], data[1] + data[2])), fn_nn, [1, 2], -1)
+
+    @suppress_warnings  # Suppress warning for lambda fn
+    def test_map_dict_with_col_datapipe(self):
+        def fn_11(d):
+            return -d
+
+        def fn_1n(d):
+            return -d, d
+
+        def fn_n1(d0, d1):
+            return d0 + d1
+
+        def fn_nn(d0, d1):
+            return -d0, -d1, d0 + d1
+
+        # Prevent modification in-place to support resetting
+        def _dict_update(data, newdata, remove_idx=None):
+            _data = dict(data)
+            _data.update(newdata)
+            if remove_idx:
+                for idx in remove_idx:
+                    del _data[idx]
+            return _data
+
+        def _helper(ref_fn, fn, input_col=None, output_col=None):
+            datapipe = IDP([{"x": 0, "y": 1, "z": 2},
+                            {"x": 3, "y": 4, "z": 5},
+                            {"x": 6, "y": 7, "z": 8}])
+            res_dp = datapipe.map(fn, input_col, output_col)
+            ref_dp = datapipe.map(ref_fn)
+            self.assertEqual(list(res_dp), list(ref_dp))
+            # Reset
+            self.assertEqual(list(res_dp), list(ref_dp))
+
+        # Replacing with one input column and default output column
+        _helper(lambda data: _dict_update(data, {"y": -data["y"]}), fn_11, "y")
+        _helper(lambda data: _dict_update(data, {"y": (-data["y"], data["y"])}), fn_1n, "y")
+        # The key of input column is not in dict
+        with self.assertRaises(KeyError):
+            _helper(None, fn_1n, "a")
+        # Unmatched input columns with fn arguments
+        with self.assertRaises(TypeError):
+            _helper(None, fn_n1, "y")
+        # Replacing with multiple input columns and default output column (the left-most input column)
+        _helper(lambda data: _dict_update(data, {"z": data["x"] + data["z"]}, ["x"]), fn_n1, ["z", "x"])
+        _helper(lambda data: _dict_update(data, {"z": (-data["z"], -data["y"], data["y"] + data["z"])}, ["y"]), fn_nn, ["z", "y"])
+
+        # output_col can only be specified when input_col is not None
+        with self.assertRaises(ValueError):
+            _helper(None, fn_n1, None, "x")
+        # output_col can only be single-element list or tuple
+        with self.assertRaises(ValueError):
+            _helper(None, fn_n1, None, ["x", "y"])
+        # Single-element list as output_col
+        _helper(lambda data: _dict_update(data, {"x": -data["y"]}), fn_11, "y", ["x"])
+        # Replacing with one input column and single specified output column
+        _helper(lambda data: _dict_update(data, {"x": -data["y"]}), fn_11, "y", "x")
+        _helper(lambda data: _dict_update(data, {"z": (-data["y"], data["y"])}), fn_1n, "y", "z")
+        _helper(lambda data: _dict_update(data, {"y": data["x"] + data["z"]}), fn_n1, ["x", "z"], "y")
+        _helper(lambda data: _dict_update(data, {"x": (-data["y"], -data["z"], data["y"] + data["z"])}), fn_nn, ["y", "z"], "x")
+
+        # Adding new key to dict for the output
+        _helper(lambda data: _dict_update(data, {"a": -data["y"]}), fn_11, "y", "a")
+        _helper(lambda data: _dict_update(data, {"a": (-data["y"], data["y"])}), fn_1n, "y", "a")
+        _helper(lambda data: _dict_update(data, {"a": data["x"] + data["z"]}), fn_n1, ["x", "z"], "a")
+        _helper(lambda data: _dict_update(data, {"a": (-data["y"], -data["z"], data["y"] + data["z"])}), fn_nn, ["y", "z"], "a")
+
     # TODO(VitalyFedyunin): If dill installed this test fails
     def _test_map_datapipe_nested_level(self):
 
index 2c5ca3d..e7f8184 100644 (file)
@@ -1,3 +1,4 @@
+import copy
 import warnings
 from torch.utils.data import IterDataPipe, _utils, functional_datapipe, DataChunk
 from typing import Callable, Dict, Iterator, Optional, Sized, Tuple, TypeVar
@@ -14,19 +15,12 @@ try:
 except ImportError:
     DILL_AVAILABLE = False
 
-T_co = TypeVar('T_co', covariant=True)
+T_co = TypeVar("T_co", covariant=True)
 
 
-# Default function to return each item directly
-# In order to keep datapipe picklable, eliminates the usage
-# of python lambda function
-def default_fn(data):
-    return data
-
-
-@functional_datapipe('map')
+@functional_datapipe("map")
 class MapperIterDataPipe(IterDataPipe[T_co]):
-    r""" :class:`MapperIterDataPipe`.
+    r""":class:`MapperIterDataPipe`.
 
     Iterable DataPipe to run a function over each item from the source DataPipe.
     The function can be any regular python function or partial object. Lambda
@@ -35,6 +29,15 @@ class MapperIterDataPipe(IterDataPipe[T_co]):
     Args:
         datapipe: Source Iterable DataPipe
         fn: Function called over each item
+        input_col: Index or indices of data which `fn` is applied
+            - None as default to apply `fn` to the data directly.
+            - Integer(s) is used for list/tuple.
+            - Key(s) is used for dict.
+        output_col: Index of data where result of `fn` is placed. `output_col` can be specified only when `input_col` is not None
+            - None as default to replace the index that `input_col` specified;
+              For `input_col` with multiple indices, the left-most one is used, and other indices will be removed.
+            - Integer is used for list/tuple. -1 represents to append result at the end.
+            - Key is used for dict. New key is acceptable.
         fn_args: Positional arguments for `fn`
         fn_kwargs: Keyword arguments for `fn`
         nesting_level: Determines which level the fn gets applied to, by default it applies to the top level (= 0).
@@ -44,43 +47,100 @@ class MapperIterDataPipe(IterDataPipe[T_co]):
     datapipe: IterDataPipe
     fn: Callable
 
-    def __init__(self,
-                 datapipe: IterDataPipe,
-                 fn: Callable = default_fn,
-                 fn_args: Optional[Tuple] = None,
-                 fn_kwargs: Optional[Dict] = None,
-                 nesting_level: int = 0,
-                 ) -> None:
+    def __init__(
+        self,
+        datapipe: IterDataPipe,
+        fn: Callable,
+        input_col=None,
+        output_col=None,
+        *,
+        fn_args: Optional[Tuple] = None,
+        fn_kwargs: Optional[Dict] = None,
+        nesting_level: int = 0,
+    ) -> None:
         super().__init__()
         self.datapipe = datapipe
         # Partial object has no attribute '__name__', but can be pickled
-        if hasattr(fn, '__name__') and fn.__name__ == '<lambda>' and not DILL_AVAILABLE:
-            warnings.warn("Lambda function is not supported for pickle, please use "
-                          "regular python function or functools.partial instead.")
+        if hasattr(fn, "__name__") and fn.__name__ == "<lambda>" and not DILL_AVAILABLE:
+            warnings.warn(
+                "Lambda function is not supported for pickle, please use "
+                "regular python function or functools.partial instead."
+            )
         self.fn = fn  # type: ignore[assignment]
+        self.input_col = input_col
+        if input_col is None and output_col is not None:
+            raise ValueError("`output_col` must be None when `input_col` is None.")
+        if isinstance(output_col, (list, tuple)):
+            if len(output_col) > 1:
+                raise ValueError("`output_col` must be a single-element list or tuple")
+            output_col = output_col[0]
+        self.output_col = output_col
         self.args = () if fn_args is None else fn_args
         self.kwargs = {} if fn_kwargs is None else fn_kwargs
         if nesting_level < -1:
             raise ValueError("nesting_level must be -1 or >= 0")
         self.nesting_level = nesting_level
 
+    def _apply_fn(self, data):
+        if self.input_col is None and self.output_col is None:
+            return self.fn(data, *self.args, **self.kwargs)
+
+        if self.input_col is None:
+            res = self.fn(data, *self.args, **self.kwargs)
+        elif isinstance(self.input_col, (list, tuple)):
+            args = tuple(data[col] for col in self.input_col)
+            res = self.fn(*args, *self.args, **self.kwargs)
+        else:
+            res = self.fn(data[self.input_col], *self.args, **self.kwargs)
+
+        # Copy tuple to list and run in-place modification because tuple is immutable.
+        if isinstance(data, tuple):
+            t_flag = True
+            data = list(data)
+        else:
+            t_flag = False
+            # Deepcopy data to prevent the original data modified. E.g. list, dict
+            data = copy.deepcopy(data)
+
+        if self.output_col is None:
+            if isinstance(self.input_col, (list, tuple)):
+                data[self.input_col[0]] = res
+                for idx in sorted(self.input_col[1:], reverse=True):
+                    del data[idx]
+            else:
+                data[self.input_col] = res
+        else:
+            if self.output_col == -1:
+                data.append(res)
+            else:
+                data[self.output_col] = res
+
+        # Convert list back to tuple
+        return tuple(data) if t_flag else data
+
     def _apply(self, data, nesting_level):
         if nesting_level == 0:
-            return self.fn(data, *self.args, **self.kwargs)
+            return self._apply_fn(data)
         elif nesting_level > 0:
             if isinstance(data, DataChunk):
-                return type(data)([self._apply(i, nesting_level - 1) for i in data.raw_iterator()])
+                return type(data)(
+                    [self._apply(i, nesting_level - 1) for i in data.raw_iterator()]
+                )
             elif isinstance(data, list):
                 return [self._apply(i, nesting_level - 1) for i in data]
             else:
-                raise IndexError(f"nesting_level {self.nesting_level} out of range (exceeds data pipe depth)")
+                raise IndexError(
+                    f"nesting_level {self.nesting_level} out of range (exceeds data pipe depth)"
+                )
         else:
             if isinstance(data, DataChunk):
-                return type(data)([self._apply(i, nesting_level) for i in data.raw_iterator()])
+                return type(data)(
+                    [self._apply(i, nesting_level) for i in data.raw_iterator()]
+                )
             elif isinstance(data, list):
                 return [self._apply(i, nesting_level) for i in data]
             else:
-                return self.fn(data, *self.args, **self.kwargs)
+                return self._apply_fn(data)
 
     def __iter__(self) -> Iterator[T_co]:
         for data in self.datapipe:
@@ -89,27 +149,45 @@ class MapperIterDataPipe(IterDataPipe[T_co]):
     def __len__(self) -> int:
         if isinstance(self.datapipe, Sized):
             return len(self.datapipe)
-        raise TypeError("{} instance doesn't have valid length".format(type(self).__name__))
+        raise TypeError(
+            "{} instance doesn't have valid length".format(type(self).__name__)
+        )
 
     def __getstate__(self):
         if DILL_AVAILABLE:
             dill_function = dill.dumps(self.fn)
         else:
             dill_function = self.fn
-        state = (self.datapipe, dill_function, self.args, self.kwargs, self.nesting_level)
+        state = (
+            self.datapipe,
+            dill_function,
+            self.input_col,
+            self.output_col,
+            self.args,
+            self.kwargs,
+            self.nesting_level,
+        )
         return state
 
     def __setstate__(self, state):
-        (self.datapipe, dill_function, self.args, self.kwargs, self.nesting_level) = state
+        (
+            self.datapipe,
+            dill_function,
+            self.input_col,
+            self.output_col,
+            self.args,
+            self.kwargs,
+            self.nesting_level,
+        ) = state
         if DILL_AVAILABLE:
             self.fn = dill.loads(dill_function)  # type: ignore[assignment]
         else:
             self.fn = dill_function  # type: ignore[assignment]
 
 
-@functional_datapipe('collate')
+@functional_datapipe("collate")
 class CollatorIterDataPipe(MapperIterDataPipe):
-    r""" :class:`CollatorIterDataPipe`.
+    r""":class:`CollatorIterDataPipe`.
 
     Iterable DataPipe to collate samples from datapipe to Tensor(s) by `util_.collate.default_collate`,
     or customized Data Structure by collate_fn.
@@ -147,10 +225,11 @@ class CollatorIterDataPipe(MapperIterDataPipe):
         [tensor(3.), tensor(4.), tensor(5.), tensor(6.)]
     """
 
-    def __init__(self,
-                 datapipe: IterDataPipe,
-                 collate_fn: Callable = _utils.collate.default_collate,
-                 fn_args: Optional[Tuple] = None,
-                 fn_kwargs: Optional[Dict] = None,
-                 ) -> None:
+    def __init__(
+        self,
+        datapipe: IterDataPipe,
+        collate_fn: Callable = _utils.collate.default_collate,
+        fn_args: Optional[Tuple] = None,
+        fn_kwargs: Optional[Dict] = None,
+    ) -> None:
         super().__init__(datapipe, fn=collate_fn, fn_args=fn_args, fn_kwargs=fn_kwargs)