From a6c4ea66dd78b28c9c0aecec64d4cf0b14ba55ff Mon Sep 17 00:00:00 2001
From: bhushan
Date: Sun, 10 Mar 2019 09:20:30 -0700
Subject: [PATCH] Passing indices as a list to Subset instead of Tensor (#17649)

Summary:
Indices in Subset were stored as tensors earlier
passing as list in random_split to ensure integer indexing

fixes: #17466
Pull Request resolved: https://github.com/pytorch/pytorch/pull/17649

Differential Revision: D14400250

Pulled By: soumith

fbshipit-source-id: cd20a959f33773c4babf8e861ea37ec61c2713a0
---
 test/test_dataloader.py     | 23 +++++++++++++++++++++++
 torch/utils/data/dataset.py |  2 +-
 2 files changed, 24 insertions(+), 1 deletion(-)

diff --git a/test/test_dataloader.py b/test/test_dataloader.py
index c5dfb4c..02a7eb2 100644
--- a/test/test_dataloader.py
+++ b/test/test_dataloader.py
@@ -73,6 +73,29 @@ class TestDatasetRandomSplit(TestCase):
         all_values.sort()
         self.assertListEqual(data, all_values)
 
+    def test_splits_indexing_type(self):
+        r"""Indices generated by random_split
+          should be of integer type
+        """
+        class CustomDataset():
+            def __init__(self, test_object, custom_list):
+                self.data = custom_list
+                self.test_object = test_object
+
+            def __getitem__(self, key):
+                self.test_object.assertEqual(type(key), type(0))
+                return self.data[key]
+
+            def __len__(self):
+                return len(self.data)
+
+        x = [1, 2, 3, 4, 5]
+        dataset = CustomDataset(self, x)
+        dataset = random_split(dataset, [5])[0]
+        data_loader = DataLoader(dataset)
+        for batch in data_loader:
+            pass
+
 
 class TestTensorDataset(TestCase):
 
diff --git a/torch/utils/data/dataset.py b/torch/utils/data/dataset.py
index bb688ce..cd8b455 100644
--- a/torch/utils/data/dataset.py
+++ b/torch/utils/data/dataset.py
@@ -121,5 +121,5 @@ def random_split(dataset, lengths):
     if sum(lengths) != len(dataset):
         raise ValueError("Sum of input lengths does not equal the length of the input dataset!")
 
-    indices = randperm(sum(lengths))
+    indices = randperm(sum(lengths)).tolist()
     return [Subset(dataset, indices[offset - length:offset]) for offset, length in zip(_accumulate(lengths), lengths)]
-- 
2.7.4