From: Eskil Jörgensen Date: Mon, 11 Feb 2019 16:22:15 +0000 (-0800) Subject: Make pin_memory and default_collate preserve namedtuples (#16440) X-Git-Tag: accepted/tizen/6.5/unified/20211028.231830~1365 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=8042edcdb10e75b49a15ed7bf8c807ecde41d12b;p=platform%2Fupstream%2Fpytorch.git Make pin_memory and default_collate preserve namedtuples (#16440) Summary: Open issue: https://github.com/pytorch/pytorch/issues/3281 Corresponding PR (conflict): https://github.com/pytorch/pytorch/pull/4577 Another open issue: https://github.com/pytorch/pytorch/issues/14613 Pull Request resolved: https://github.com/pytorch/pytorch/pull/16440 Differential Revision: D14020901 Pulled By: ezyang fbshipit-source-id: 4abe817fc43c281a510715d311bad544511995d3 --- diff --git a/test/test_dataloader.py b/test/test_dataloader.py index ad7cfdc..50f5e60 100644 --- a/test/test_dataloader.py +++ b/test/test_dataloader.py @@ -956,6 +956,31 @@ class TestDictDataLoader(TestCase): self.assertTrue(sample['another_dict']['a_number'].is_pinned()) +class NamedTupleDataset(Dataset): + from collections import namedtuple + Batch = namedtuple('Batch', ['data', 'label']) + Data = namedtuple('Data', ['positive', 'negative']) + + def __len__(self): + return 4 + + def __getitem__(self, ndx): + return self.Batch(data=self.Data(positive=ndx, negative=-ndx), + label=str(ndx)) + + +class TestNamedTupleDataLoader(TestCase): + def setUp(self): + self.dataset = NamedTupleDataset() + + @unittest.skipIf(not TEST_CUDA, "CUDA unavailable") + def test_collate_and_pin_memory_with_namedtuple(self): + loader = DataLoader(self.dataset, batch_size=2, pin_memory=True) + for batch in loader: + self.assertIsInstance(batch, NamedTupleDataset.Batch) + self.assertIsInstance(batch.data, NamedTupleDataset.Data) + + class SimpleCustomBatch: def __init__(self, data): transposed_data = list(zip(*data)) diff --git a/torch/utils/data/_utils/collate.py b/torch/utils/data/_utils/collate.py index e46b2c9..bffaa1a 100644 --- a/torch/utils/data/_utils/collate.py +++ b/torch/utils/data/_utils/collate.py @@ -61,6 +61,8 @@ def default_collate(batch): return batch elif isinstance(batch[0], container_abcs.Mapping): return {key: default_collate([d[key] for d in batch]) for key in batch[0]} + elif isinstance(batch[0], tuple) and hasattr(batch[0], '_fields'): # namedtuple + return type(batch[0])(*(default_collate(samples) for samples in zip(*batch))) elif isinstance(batch[0], container_abcs.Sequence): transposed = zip(*batch) return [default_collate(samples) for samples in transposed] diff --git a/torch/utils/data/_utils/pin_memory.py b/torch/utils/data/_utils/pin_memory.py index 07022d1..f762aff 100644 --- a/torch/utils/data/_utils/pin_memory.py +++ b/torch/utils/data/_utils/pin_memory.py @@ -51,6 +51,8 @@ def pin_memory_batch(batch): return batch elif isinstance(batch, container_abcs.Mapping): return {k: pin_memory_batch(sample) for k, sample in batch.items()} + elif isinstance(batch, tuple) and hasattr(batch, '_fields'): # namedtuple + return type(batch)(*(pin_memory_batch(sample) for sample in batch)) elif isinstance(batch, container_abcs.Sequence): return [pin_memory_batch(sample) for sample in batch] elif hasattr(batch, "pin_memory"):