Passing indices as a list to Subset instead of Tensor (#17649)
authorbhushan <bhushan.s.94@gmail.com>
Sun, 10 Mar 2019 16:20:30 +0000 (09:20 -0700)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Sun, 10 Mar 2019 16:23:53 +0000 (09:23 -0700)
Summary:
Indices in Subset were stored as tensors earlier
passing as list in random_split to ensure integer indexing

fixes: #17466
Pull Request resolved: https://github.com/pytorch/pytorch/pull/17649

Differential Revision: D14400250

Pulled By: soumith

fbshipit-source-id: cd20a959f33773c4babf8e861ea37ec61c2713a0

test/test_dataloader.py
torch/utils/data/dataset.py

index c5dfb4c..02a7eb2 100644 (file)
@@ -73,6 +73,29 @@ class TestDatasetRandomSplit(TestCase):
         all_values.sort()
         self.assertListEqual(data, all_values)
 
+    def test_splits_indexing_type(self):
+        r"""Indices generated by random_split
+          should be of integer type
+        """
+        class CustomDataset():
+            def __init__(self, test_object, custom_list):
+                self.data = custom_list
+                self.test_object = test_object
+
+            def __getitem__(self, key):
+                self.test_object.assertEqual(type(key), type(0))
+                return self.data[key]
+
+            def __len__(self):
+                return len(self.data)
+
+        x = [1, 2, 3, 4, 5]
+        dataset = CustomDataset(self, x)
+        dataset = random_split(dataset, [5])[0]
+        data_loader = DataLoader(dataset)
+        for batch in data_loader:
+            pass
+
 
 class TestTensorDataset(TestCase):
 
index bb688ce..cd8b455 100644 (file)
@@ -121,5 +121,5 @@ def random_split(dataset, lengths):
     if sum(lengths) != len(dataset):
         raise ValueError("Sum of input lengths does not equal the length of the input dataset!")
 
-    indices = randperm(sum(lengths))
+    indices = randperm(sum(lengths)).tolist()
     return [Subset(dataset, indices[offset - length:offset]) for offset, length in zip(_accumulate(lengths), lengths)]