Make pin_memory and default_collate preserve namedtuples (#16440)
author Eskil Jörgensen <eskil.jorgensen@gmail.com>
Mon, 11 Feb 2019 16:22:15 +0000 (08:22 -0800)
committer Facebook Github Bot <facebook-github-bot@users.noreply.github.com>
Mon, 11 Feb 2019 16:47:33 +0000 (08:47 -0800)
Summary:
Open issue: https://github.com/pytorch/pytorch/issues/3281
Corresponding PR (conflict): https://github.com/pytorch/pytorch/pull/4577

Another open issue: https://github.com/pytorch/pytorch/issues/14613
Pull Request resolved: https://github.com/pytorch/pytorch/pull/16440

Differential Revision: D14020901

Pulled By: ezyang

fbshipit-source-id: 4abe817fc43c281a510715d311bad544511995d3

test/test_dataloader.py
torch/utils/data/_utils/collate.py
torch/utils/data/_utils/pin_memory.py

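diff --git a/test/test_dataloader.py b/test/test_dataloader.py
index ad7cfdc..50f5e60 100644
--- a/test/test_dataloader.py
+++ b/test/test_dataloader.py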
@@ -956,6 +956,31 @@ class TestDictDataLoader(TestCase):
             self.assertTrue(sample['another_dict']['a_number'].is_pinned())
 
 
+class NamedTupleDataset(Dataset):
+    from collections import namedtuple
+    Batch = namedtuple('Batch', ['data', 'label'])
+    Data = namedtuple('Data', ['positive', 'negative'])
+
+    def __len__(self):
+        return 4
+
+    def __getitem__(self, ndx):
+        return self.Batch(data=self.Data(positive=ndx, negative=-ndx),
+                          label=str(ndx))
+
+
+class TestNamedTupleDataLoader(TestCase):
+    def setUp(self):
+        self.dataset = NamedTupleDataset()
+
+    @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
+    def test_collate_and_pin_memory_with_namedtuple(self):
+        loader = DataLoader(self.dataset, batch_size=2, pin_memory=True)
+        for batch in loader:
+            self.assertIsInstance(batch, NamedTupleDataset.Batch)
+            self.assertIsInstance(batch.data, NamedTupleDataset.Data)
+
+
 class SimpleCustomBatch:
     def __init__(self, data):
         transposed_data = list(zip(*data))
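diff --git a/torch/utils/data/_utils/collate.py b/torch/utils/data/_utils/collate.py
index e46b2c9..bffaa1a 100644
--- a/torch/utils/data/_utils/collate.py
+++ b/torch/utils/data/_utils/collate.py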
@@ -61,6 +61,8 @@ def default_collate(batch):
         return batch
     elif isinstance(batch[0], container_abcs.Mapping):
         return {key: default_collate([d[key] for d in batch]) for key in batch[0]}
+    elif isinstance(batch[0], tuple) and hasattr(batch[0], '_fields'):  # namedtuple
+        return type(batch[0])(*(default_collate(samples) for samples in zip(*batch)))
     elif isinstance(batch[0], container_abcs.Sequence):
         transposed = zip(*batch)
         return [default_collate(samples) for samples in transposed]
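diff --git a/torch/utils/data/_utils/pin_memory.py b/torch/utils/data/_utils/pin_memory.py
index 07022d1..f762aff 100644
--- a/torch/utils/data/_utils/pin_memory.py
+++ b/torch/utils/data/_utils/pin_memory.py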
@@ -51,6 +51,8 @@ def pin_memory_batch(batch):
         return batch
     elif isinstance(batch, container_abcs.Mapping):
         return {k: pin_memory_batch(sample) for k, sample in batch.items()}
+    elif isinstance(batch, tuple) and hasattr(batch, '_fields'):  # namedtuple
+        return type(batch)(*(pin_memory_batch(sample) for sample in batch))
     elif isinstance(batch, container_abcs.Sequence):
         return [pin_memory_batch(sample) for sample in batch]
     elif hasattr(batch, "pin_memory"):