int64_t nbins = static_cast<int64_t>(*self.max().data<input_t>()) + 1L;
nbins = std::max(nbins, minlength); // at least minlength # of bins
- const input_t* self_p = self.contiguous().data<input_t>();
+ const input_t* self_p = self.data<input_t>();
if (has_weights) {
output = native::zeros({nbins}, weights.options());
weights_t* output_p = output.data<weights_t>();
- const weights_t* weights_p = weights.contiguous().data<weights_t>();
+ const weights_t* weights_p = weights.data<weights_t>();
for (int64_t i = 0; i < self.size(0); i++) {
output_p[self_p[i]] += weights_p[i];
}
return AT_DISPATCH_INTEGRAL_TYPES(self.type(), "bincount", [&] {
const auto scalar = weights.type().scalarType();
if (scalar == ScalarType::Undefined || scalar == ScalarType::Float)
- return _bincount_cpu_template<scalar_t, float>(self, weights, minlength);
+ return _bincount_cpu_template<scalar_t, float>(self.contiguous(), weights.contiguous(), minlength);
return _bincount_cpu_template<scalar_t, double>(
- self, weights.toType(CPU(kDouble)), minlength);
+ self.contiguous(), weights.contiguous().toType(CPU(kDouble)), minlength);
});
}
torch.tensor([1, 2, 3, 4, 5], dtype=torch.int8, device=device))
self.assertEqual(
torch.tensor([1, 9, 0, 0, 5], device=device), byte_counts)
+ # test non-contiguous inputs and weights
+ # (columns of a 2-D tensor are non-contiguous views; exercises the
+ # .contiguous() calls added on the C++ side of this patch)
+ inputs = torch.tensor([[0, 0], [3, 1], [2, 1], [1, 1], [3, 4]], device=device)
+ weights = torch.tensor([[.1, 1], [.2, 2], [.3, 3], [.4, 4], [.5, 5]], device=device)
+ for i in [0, 1]:
+     assert not inputs[:, i].is_contiguous(), "Inputs are supposed to be non-contiguous"
+     assert not weights[:, i].is_contiguous(), "Weights are supposed to be non-contiguous"
+ # inputs are non-contiguous but weights are contiguous
+ self.assertEqual(inputs[:, 0].bincount(), torch.tensor([1, 1, 1, 2], device=device))
+ # inputs and weights are non-contiguous
+ self.assertEqual(inputs[:, 1].bincount(weights[:, 1]),
+                  torch.tensor([1, 9, 0, 0, 5], device=device))
+ # weights are non-contiguous but inputs are contiguous
+ self.assertEqual(inputs[:, 1].contiguous().bincount(weights[:, 1]),
+                  torch.tensor([1, 9, 0, 0, 5], device=device))
# test large number of bins - global memory use
big_exp = torch.zeros(10000000, device=device)
big_exp[-1] = 50.0