# test non-contiguous inputs and weights
inputs = torch.tensor([[0, 0], [3, 1], [2, 1], [1, 1], [3, 4]], device=device)
weights = torch.tensor([[.1, 1], [.2, 2], [.3, 3], [.4, 4], [.5, 5]], device=device)
- print(inputs[:, 1])
- print(weights[:, 1])
for i in [0, 1]:
assert not inputs[:, i].is_contiguous(), "Inputs are supposed to be non-contiguous"
assert not weights[:, i].is_contiguous(), "Weights are supposed to be non-contiguous"
# inputs are non-contiguous but weights are contiguous
self.assertEqual(inputs[:, 0].bincount(), torch.tensor([1, 1, 1, 2]))
# inputs and weights are non-contiguous
- print(inputs[:, 1].bincount(weights[:, 1]))
self.assertEqual(inputs[:, 1].bincount(weights[:, 1]), torch.tensor([1, 9, 0, 0, 5]))
# weights are non-contiguous but inputs are contiguous
self.assertEqual(inputs[:, 1].contiguous().bincount(weights[:, 1]),