all_values.sort()
self.assertListEqual(data, all_values)
+ def test_splits_indexing_type(self):
+ r"""Indices generated by random_split
+ should be of integer type
+ """
+ class CustomDataset():
+ def __init__(self, test_object, custom_list):
+ self.data = custom_list
+ self.test_object = test_object
+
+ def __getitem__(self, key):
+ self.test_object.assertEqual(type(key), type(0))
+ return self.data[key]
+
+ def __len__(self):
+ return len(self.data)
+
+ x = [1, 2, 3, 4, 5]
+ dataset = CustomDataset(self, x)
+ dataset = random_split(dataset, [5])[0]
+ data_loader = DataLoader(dataset)
+ for batch in data_loader:
+ pass
+
class TestTensorDataset(TestCase):
if sum(lengths) != len(dataset):
raise ValueError("Sum of input lengths does not equal the length of the input dataset!")
- indices = randperm(sum(lengths))
+ indices = randperm(sum(lengths)).tolist()
return [Subset(dataset, indices[offset - length:offset]) for offset, length in zip(_accumulate(lengths), lengths)]