fix behavior of ConcatDataset w/ negative indices (#15756)
authorjayleverett <jay@aireverie.com>
Thu, 14 Feb 2019 19:46:55 +0000 (11:46 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Thu, 14 Feb 2019 21:02:54 +0000 (13:02 -0800)
Summary:
Currently, when you pass a negative index to a `Dataset` created with `ConcatDataset`, it simply passes that index to the first dataset in the list. So if, for example, we took `concatenated_dataset[-1]`, this will give us the last entry of the *first* dataset, rather than the last entry of the *last* dataset, as we would expect.

This is a simple fix to support the expected behavior for negative indices.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/15756

Reviewed By: ezyang

Differential Revision: D14081811

Pulled By: fmassa

fbshipit-source-id: a7783fd3fd9e1a8c00fd076c4978ca39ad5a8a2a

torch/utils/data/dataset.py

index 0202ca2..bb688ce 100644 (file)
@@ -73,6 +73,10 @@ class ConcatDataset(Dataset):
         return self.cumulative_sizes[-1]
 
     def __getitem__(self, idx):
+        if idx < 0:
+            if -idx > len(self):
+                raise ValueError("absolute value of index should not exceed dataset length")
+            idx = len(self) + idx
         dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
         if dataset_idx == 0:
             sample_idx = idx