Summary:
Currently, when you pass a negative index to a `Dataset` created with `ConcatDataset`, the index is forwarded unchanged to the first dataset in the list. For example, `concatenated_dataset[-1]` returns the last entry of the *first* dataset rather than the last entry of the *last* dataset, as one would expect.
This is a simple fix to support the expected behavior for negative indices.
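As a rough illustration of the intended behavior, here is a minimal sketch that does not depend on torch. `MiniConcat` is a hypothetical stand-in for `ConcatDataset`, written only for this example, and the `else` branch that subtracts the previous cumulative size is an assumption about the part of `__getitem__` not shown in this change.

```python
import bisect
from itertools import accumulate


class MiniConcat:
    """Hypothetical stand-in for ConcatDataset, for illustration only."""

    def __init__(self, datasets):
        self.datasets = list(datasets)
        # Running totals of the lengths, e.g. [3, 5] for lengths 3 and 2.
        self.cumulative_sizes = list(accumulate(len(d) for d in self.datasets))

    def __len__(self):
        return self.cumulative_sizes[-1]

    def __getitem__(self, idx):
        if idx < 0:
            if -idx > len(self):
                raise ValueError("absolute value of index should not exceed dataset length")
            # Map the negative index onto the concatenated range.
            idx = len(self) + idx
        dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
        if dataset_idx == 0:
            sample_idx = idx
        else:
            # Assumed: offset into the selected dataset.
            sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
        return self.datasets[dataset_idx][sample_idx]


combined = MiniConcat([["a", "b", "c"], ["d", "e"]])
last = combined[-1]    # last entry of the last dataset
first = combined[-5]   # same element as combined[0]
```

Without the `idx < 0` branch, `bisect_right` returns 0 for any negative index, so `combined[-1]` would resolve to the last entry of the first list instead.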
Pull Request resolved: https://github.com/pytorch/pytorch/pull/15756
Reviewed By: ezyang
Differential Revision: D14081811
Pulled By: fmassa
fbshipit-source-id: a7783fd3fd9e1a8c00fd076c4978ca39ad5a8a2a
         return self.cumulative_sizes[-1]
     def __getitem__(self, idx):
+        if idx < 0:
+            if -idx > len(self):
+                raise ValueError("absolute value of index should not exceed dataset length")
+            idx = len(self) + idx
         dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
         if dataset_idx == 0:
             sample_idx = idx