# weights are non-contiguous but inputs are contiguous
self.assertEqual(inputs[:, 1].contiguous().bincount(weights[:, 1]),
torch.tensor([1, 9, 0, 0, 5]))
+
+ # test bincount on non-contiguous slices
+ all0s = torch.zeros((32, 2), dtype=torch.int64, device=device)
+ self.assertEqual(all0s[:, 0].bincount(), torch.tensor([32]))
+
+ all1s = torch.ones((32, 2), dtype=torch.int64, device=device)
+ self.assertEqual(all1s[:, 0].bincount(), torch.tensor([0, 32]))
+
# test large number of bins - global memory use
big_exp = torch.zeros(10000000, device=device)
big_exp[-1] = 50.0