import torch.utils.data.datapipes as dp
import torch.utils.data.graph
import torch.utils.data.sharding
-from torch.testing._internal.common_utils import TestCase, run_tests
+from torch.testing._internal.common_utils import TestCase, run_tests, suppress_warnings
from torch.utils.data import (
DataLoader,
DataChunk,
with self.assertRaises(TypeError):
len(dp2)
-
+ @suppress_warnings # Suppress warning for lambda fn
def test_map_datapipe(self):
input_dp = IDP(range(10))
self.assertEqual(x, torch.tensor(y, dtype=torch.int).sum())
input_dp_nl = IDP_NoLen(range(10))
- map_dp_nl = input_dp_nl.map()
+ map_dp_nl = input_dp_nl.map(lambda x: x)
with self.assertRaisesRegex(TypeError, r"instance doesn't have valid length$"):
len(map_dp_nl)
for x, y in zip(map_dp_nl, input_dp_nl):
self.assertEqual(x, torch.tensor(y, dtype=torch.float))
+ @suppress_warnings # Suppress warning for lambda fn
+ def test_map_tuple_list_with_col_datapipe(self):
+ def fn_11(d):
+ return -d
+
+ def fn_1n(d):
+ return -d, d
+
+ def fn_n1(d0, d1):
+ return d0 + d1
+
+ def fn_nn(d0, d1):
+ return -d0, -d1, d0 + d1
+
+ def _helper(ref_fn, fn, input_col=None, output_col=None):
+ for constr in (list, tuple):
+ datapipe = IDP([constr((0, 1, 2)), constr((3, 4, 5)), constr((6, 7, 8))])
+ res_dp = datapipe.map(fn, input_col, output_col)
+ ref_dp = datapipe.map(ref_fn)
+ self.assertEqual(list(res_dp), list(ref_dp))
+ # Reset
+ self.assertEqual(list(res_dp), list(ref_dp))
+
+ # Replacing with one input column and default output column
+ _helper(lambda data: (data[0], -data[1], data[2]), fn_11, 1)
+ _helper(lambda data: (data[0], (-data[1], data[1]), data[2]), fn_1n, 1)
+ # The index of input column is out of range
+ with self.assertRaises(IndexError):
+ _helper(None, fn_1n, 3)
+ # Unmatched input columns with fn arguments
+ with self.assertRaises(TypeError):
+ _helper(None, fn_n1, 1)
+ # Replacing with multiple input columns and default output column (the left-most input column)
+ _helper(lambda data: (data[1], data[2] + data[0]), fn_n1, [2, 0])
+ _helper(lambda data: (data[0], (-data[2], -data[1], data[2] + data[1])), fn_nn, [2, 1])
+
+ # output_col can only be specified when input_col is not None
+ with self.assertRaises(ValueError):
+ _helper(None, fn_n1, None, 1)
+ # output_col can only be single-element list or tuple
+ with self.assertRaises(ValueError):
+ _helper(None, fn_n1, None, [0, 1])
+ # Single-element list as output_col
+ _helper(lambda data: (-data[1], data[1], data[2]), fn_11, 1, [0])
+ # Replacing with one input column and single specified output column
+ _helper(lambda data: (-data[1], data[1], data[2]), fn_11, 1, 0)
+ _helper(lambda data: (data[0], data[1], (-data[1], data[1])), fn_1n, 1, 2)
+ # The index of output column is out of range
+ with self.assertRaises(IndexError):
+ _helper(None, fn_1n, 1, 3)
+ _helper(lambda data: (data[0], data[0] + data[2], data[2]), fn_n1, [0, 2], 1)
+ _helper(lambda data: ((-data[1], -data[2], data[1] + data[2]), data[1], data[2]), fn_nn, [1, 2], 0)
+
+ # Appending the output at the end
+ _helper(lambda data: (*data, -data[1]), fn_11, 1, -1)
+ _helper(lambda data: (*data, (-data[1], data[1])), fn_1n, 1, -1)
+ _helper(lambda data: (*data, data[0] + data[2]), fn_n1, [0, 2], -1)
+ _helper(lambda data: (*data, (-data[1], -data[2], data[1] + data[2])), fn_nn, [1, 2], -1)
+
+ @suppress_warnings # Suppress warning for lambda fn
+ def test_map_dict_with_col_datapipe(self):
+ def fn_11(d):
+ return -d
+
+ def fn_1n(d):
+ return -d, d
+
+ def fn_n1(d0, d1):
+ return d0 + d1
+
+ def fn_nn(d0, d1):
+ return -d0, -d1, d0 + d1
+
+ # Prevent modification in-place to support resetting
+ def _dict_update(data, newdata, remove_idx=None):
+ _data = dict(data)
+ _data.update(newdata)
+ if remove_idx:
+ for idx in remove_idx:
+ del _data[idx]
+ return _data
+
+ def _helper(ref_fn, fn, input_col=None, output_col=None):
+ datapipe = IDP([{"x": 0, "y": 1, "z": 2},
+ {"x": 3, "y": 4, "z": 5},
+ {"x": 6, "y": 7, "z": 8}])
+ res_dp = datapipe.map(fn, input_col, output_col)
+ ref_dp = datapipe.map(ref_fn)
+ self.assertEqual(list(res_dp), list(ref_dp))
+ # Reset
+ self.assertEqual(list(res_dp), list(ref_dp))
+
+ # Replacing with one input column and default output column
+ _helper(lambda data: _dict_update(data, {"y": -data["y"]}), fn_11, "y")
+ _helper(lambda data: _dict_update(data, {"y": (-data["y"], data["y"])}), fn_1n, "y")
+ # The key of input column is not in dict
+ with self.assertRaises(KeyError):
+ _helper(None, fn_1n, "a")
+ # Unmatched input columns with fn arguments
+ with self.assertRaises(TypeError):
+ _helper(None, fn_n1, "y")
+ # Replacing with multiple input columns and default output column (the left-most input column)
+ _helper(lambda data: _dict_update(data, {"z": data["x"] + data["z"]}, ["x"]), fn_n1, ["z", "x"])
+ _helper(lambda data: _dict_update(data, {"z": (-data["z"], -data["y"], data["y"] + data["z"])}, ["y"]), fn_nn, ["z", "y"])
+
+ # output_col can only be specified when input_col is not None
+ with self.assertRaises(ValueError):
+ _helper(None, fn_n1, None, "x")
+ # output_col can only be single-element list or tuple
+ with self.assertRaises(ValueError):
+ _helper(None, fn_n1, None, ["x", "y"])
+ # Single-element list as output_col
+ _helper(lambda data: _dict_update(data, {"x": -data["y"]}), fn_11, "y", ["x"])
+ # Replacing with one input column and single specified output column
+ _helper(lambda data: _dict_update(data, {"x": -data["y"]}), fn_11, "y", "x")
+ _helper(lambda data: _dict_update(data, {"z": (-data["y"], data["y"])}), fn_1n, "y", "z")
+ _helper(lambda data: _dict_update(data, {"y": data["x"] + data["z"]}), fn_n1, ["x", "z"], "y")
+ _helper(lambda data: _dict_update(data, {"x": (-data["y"], -data["z"], data["y"] + data["z"])}), fn_nn, ["y", "z"], "x")
+
+ # Adding new key to dict for the output
+ _helper(lambda data: _dict_update(data, {"a": -data["y"]}), fn_11, "y", "a")
+ _helper(lambda data: _dict_update(data, {"a": (-data["y"], data["y"])}), fn_1n, "y", "a")
+ _helper(lambda data: _dict_update(data, {"a": data["x"] + data["z"]}), fn_n1, ["x", "z"], "a")
+ _helper(lambda data: _dict_update(data, {"a": (-data["y"], -data["z"], data["y"] + data["z"])}), fn_nn, ["y", "z"], "a")
+
# TODO(VitalyFedyunin): If dill installed this test fails
def _test_map_datapipe_nested_level(self):
+import copy
import warnings
from torch.utils.data import IterDataPipe, _utils, functional_datapipe, DataChunk
from typing import Callable, Dict, Iterator, Optional, Sized, Tuple, TypeVar
except ImportError:
DILL_AVAILABLE = False
-T_co = TypeVar('T_co', covariant=True)
+T_co = TypeVar("T_co", covariant=True)
-# Default function to return each item directly
-# In order to keep datapipe picklable, eliminates the usage
-# of python lambda function
-def default_fn(data):
- return data
-
-
-@functional_datapipe('map')
+@functional_datapipe("map")
class MapperIterDataPipe(IterDataPipe[T_co]):
- r""" :class:`MapperIterDataPipe`.
+ r""":class:`MapperIterDataPipe`.
Iterable DataPipe to run a function over each item from the source DataPipe.
The function can be any regular python function or partial object. Lambda
Args:
datapipe: Source Iterable DataPipe
fn: Function called over each item
+ input_col: Index or indices of data which `fn` is applied
+ - None as default to apply `fn` to the data directly.
+ - Integer(s) is used for list/tuple.
+ - Key(s) is used for dict.
+ output_col: Index of data where result of `fn` is placed. `output_col` can be specified only when `input_col` is not None
+ - None as default to replace the index that `input_col` specified;
+ For `input_col` with multiple indices, the left-most one is used, and other indices will be removed.
+ - Integer is used for list/tuple. -1 represents to append result at the end.
+ - Key is used for dict. New key is acceptable.
fn_args: Positional arguments for `fn`
fn_kwargs: Keyword arguments for `fn`
nesting_level: Determines which level the fn gets applied to, by default it applies to the top level (= 0).
datapipe: IterDataPipe
fn: Callable
- def __init__(self,
- datapipe: IterDataPipe,
- fn: Callable = default_fn,
- fn_args: Optional[Tuple] = None,
- fn_kwargs: Optional[Dict] = None,
- nesting_level: int = 0,
- ) -> None:
+ def __init__(
+ self,
+ datapipe: IterDataPipe,
+ fn: Callable,
+ input_col=None,
+ output_col=None,
+ *,
+ fn_args: Optional[Tuple] = None,
+ fn_kwargs: Optional[Dict] = None,
+ nesting_level: int = 0,
+ ) -> None:
super().__init__()
self.datapipe = datapipe
# Partial object has no attribute '__name__', but can be pickled
- if hasattr(fn, '__name__') and fn.__name__ == '<lambda>' and not DILL_AVAILABLE:
- warnings.warn("Lambda function is not supported for pickle, please use "
- "regular python function or functools.partial instead.")
+ if hasattr(fn, "__name__") and fn.__name__ == "<lambda>" and not DILL_AVAILABLE:
+ warnings.warn(
+ "Lambda function is not supported for pickle, please use "
+ "regular python function or functools.partial instead."
+ )
self.fn = fn # type: ignore[assignment]
+ self.input_col = input_col
+ if input_col is None and output_col is not None:
+ raise ValueError("`output_col` must be None when `input_col` is None.")
+ if isinstance(output_col, (list, tuple)):
+ if len(output_col) > 1:
+ raise ValueError("`output_col` must be a single-element list or tuple")
+ output_col = output_col[0]
+ self.output_col = output_col
self.args = () if fn_args is None else fn_args
self.kwargs = {} if fn_kwargs is None else fn_kwargs
if nesting_level < -1:
raise ValueError("nesting_level must be -1 or >= 0")
self.nesting_level = nesting_level
+ def _apply_fn(self, data):
+ if self.input_col is None and self.output_col is None:
+ return self.fn(data, *self.args, **self.kwargs)
+
+ if self.input_col is None:
+ res = self.fn(data, *self.args, **self.kwargs)
+ elif isinstance(self.input_col, (list, tuple)):
+ args = tuple(data[col] for col in self.input_col)
+ res = self.fn(*args, *self.args, **self.kwargs)
+ else:
+ res = self.fn(data[self.input_col], *self.args, **self.kwargs)
+
+ # Copy tuple to list and run in-place modification because tuple is immutable.
+ if isinstance(data, tuple):
+ t_flag = True
+ data = list(data)
+ else:
+ t_flag = False
+ # Deepcopy data to prevent the original data modified. E.g. list, dict
+ data = copy.deepcopy(data)
+
+ if self.output_col is None:
+ if isinstance(self.input_col, (list, tuple)):
+ data[self.input_col[0]] = res
+ for idx in sorted(self.input_col[1:], reverse=True):
+ del data[idx]
+ else:
+ data[self.input_col] = res
+ else:
+ if self.output_col == -1:
+ data.append(res)
+ else:
+ data[self.output_col] = res
+
+ # Convert list back to tuple
+ return tuple(data) if t_flag else data
+
def _apply(self, data, nesting_level):
if nesting_level == 0:
- return self.fn(data, *self.args, **self.kwargs)
+ return self._apply_fn(data)
elif nesting_level > 0:
if isinstance(data, DataChunk):
- return type(data)([self._apply(i, nesting_level - 1) for i in data.raw_iterator()])
+ return type(data)(
+ [self._apply(i, nesting_level - 1) for i in data.raw_iterator()]
+ )
elif isinstance(data, list):
return [self._apply(i, nesting_level - 1) for i in data]
else:
- raise IndexError(f"nesting_level {self.nesting_level} out of range (exceeds data pipe depth)")
+ raise IndexError(
+ f"nesting_level {self.nesting_level} out of range (exceeds data pipe depth)"
+ )
else:
if isinstance(data, DataChunk):
- return type(data)([self._apply(i, nesting_level) for i in data.raw_iterator()])
+ return type(data)(
+ [self._apply(i, nesting_level) for i in data.raw_iterator()]
+ )
elif isinstance(data, list):
return [self._apply(i, nesting_level) for i in data]
else:
- return self.fn(data, *self.args, **self.kwargs)
+ return self._apply_fn(data)
def __iter__(self) -> Iterator[T_co]:
for data in self.datapipe:
def __len__(self) -> int:
if isinstance(self.datapipe, Sized):
return len(self.datapipe)
- raise TypeError("{} instance doesn't have valid length".format(type(self).__name__))
+ raise TypeError(
+ "{} instance doesn't have valid length".format(type(self).__name__)
+ )
def __getstate__(self):
if DILL_AVAILABLE:
dill_function = dill.dumps(self.fn)
else:
dill_function = self.fn
- state = (self.datapipe, dill_function, self.args, self.kwargs, self.nesting_level)
+ state = (
+ self.datapipe,
+ dill_function,
+ self.input_col,
+ self.output_col,
+ self.args,
+ self.kwargs,
+ self.nesting_level,
+ )
return state
def __setstate__(self, state):
- (self.datapipe, dill_function, self.args, self.kwargs, self.nesting_level) = state
+ (
+ self.datapipe,
+ dill_function,
+ self.input_col,
+ self.output_col,
+ self.args,
+ self.kwargs,
+ self.nesting_level,
+ ) = state
if DILL_AVAILABLE:
self.fn = dill.loads(dill_function) # type: ignore[assignment]
else:
self.fn = dill_function # type: ignore[assignment]
-@functional_datapipe('collate')
+@functional_datapipe("collate")
class CollatorIterDataPipe(MapperIterDataPipe):
- r""" :class:`CollatorIterDataPipe`.
+ r""":class:`CollatorIterDataPipe`.
Iterable DataPipe to collate samples from datapipe to Tensor(s) by `util_.collate.default_collate`,
or customized Data Structure by collate_fn.
[tensor(3.), tensor(4.), tensor(5.), tensor(6.)]
"""
- def __init__(self,
- datapipe: IterDataPipe,
- collate_fn: Callable = _utils.collate.default_collate,
- fn_args: Optional[Tuple] = None,
- fn_kwargs: Optional[Dict] = None,
- ) -> None:
+ def __init__(
+ self,
+ datapipe: IterDataPipe,
+ collate_fn: Callable = _utils.collate.default_collate,
+ fn_args: Optional[Tuple] = None,
+ fn_kwargs: Optional[Dict] = None,
+ ) -> None:
super().__init__(datapipe, fn=collate_fn, fn_args=fn_args, fn_kwargs=fn_kwargs)