self.assertTrue(sample['another_dict']['a_number'].is_pinned())
+class NamedTupleDataset(Dataset):
+ from collections import namedtuple
+ Batch = namedtuple('Batch', ['data', 'label'])
+ Data = namedtuple('Data', ['positive', 'negative'])
+
+ def __len__(self):
+ return 4
+
+ def __getitem__(self, ndx):
+ return self.Batch(data=self.Data(positive=ndx, negative=-ndx),
+ label=str(ndx))
+
+
+class TestNamedTupleDataLoader(TestCase):
+ def setUp(self):
+ self.dataset = NamedTupleDataset()
+
+ @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
+ def test_collate_and_pin_memory_with_namedtuple(self):
+ loader = DataLoader(self.dataset, batch_size=2, pin_memory=True)
+ for batch in loader:
+ self.assertIsInstance(batch, NamedTupleDataset.Batch)
+ self.assertIsInstance(batch.data, NamedTupleDataset.Data)
+
+
class SimpleCustomBatch:
def __init__(self, data):
transposed_data = list(zip(*data))
return batch
elif isinstance(batch[0], container_abcs.Mapping):
return {key: default_collate([d[key] for d in batch]) for key in batch[0]}
+ elif isinstance(batch[0], tuple) and hasattr(batch[0], '_fields'): # namedtuple
+ return type(batch[0])(*(default_collate(samples) for samples in zip(*batch)))
elif isinstance(batch[0], container_abcs.Sequence):
transposed = zip(*batch)
return [default_collate(samples) for samples in transposed]
return batch
elif isinstance(batch, container_abcs.Mapping):
return {k: pin_memory_batch(sample) for k, sample in batch.items()}
+ elif isinstance(batch, tuple) and hasattr(batch, '_fields'): # namedtuple
+ return type(batch)(*(pin_memory_batch(sample) for sample in batch))
elif isinstance(batch, container_abcs.Sequence):
return [pin_memory_batch(sample) for sample in batch]
elif hasattr(batch, "pin_memory"):