From: vishwakftw
Date: Thu, 13 Dec 2018 17:38:40 +0000 (-0800)
Subject: Fix bincount for non-contiguous inputs on CPU (#15109)
X-Git-Tag: accepted/tizen/6.5/unified/20211028.231830~2270
X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=214f46faf5c5dddb073e09205d047aead814b0c6;p=platform%2Fupstream%2Fpytorch.git

Fix bincount for non-contiguous inputs on CPU (#15109)

Summary:
Fixes #15058.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/15109

Differential Revision: D13447448

Pulled By: soumith

fbshipit-source-id: 56e8d42934538fb00465105a2c5ccfeb7c18a651
---

diff --git a/aten/src/ATen/native/SummaryOps.cpp b/aten/src/ATen/native/SummaryOps.cpp
index 79bb4d2..6a1d6bd 100644
--- a/aten/src/ATen/native/SummaryOps.cpp
+++ b/aten/src/ATen/native/SummaryOps.cpp
@@ -34,11 +34,11 @@ Tensor _bincount_cpu_template(
   int64_t nbins = static_cast<int64_t>(*self.max().data<input_t>()) + 1L;
   nbins = std::max(nbins, minlength); // at least minlength # of bins
 
-  const input_t* self_p = self.contiguous().data<input_t>();
+  const input_t* self_p = self.data<input_t>();
   if (has_weights) {
     output = native::zeros({nbins}, weights.options());
     weights_t* output_p = output.data<weights_t>();
-    const weights_t* weights_p = weights.contiguous().data<weights_t>();
+    const weights_t* weights_p = weights.data<weights_t>();
     for (int64_t i = 0; i < self.size(0); i++) {
       output_p[self_p[i]] += weights_p[i];
     }
@@ -58,9 +58,9 @@ _bincount_cpu(const Tensor& self, const Tensor& weights, int64_t minlength) {
   return AT_DISPATCH_INTEGRAL_TYPES(self.type(), "bincount", [&] {
     const auto scalar = weights.type().scalarType();
     if (scalar == ScalarType::Undefined || scalar == ScalarType::Float)
-      return _bincount_cpu_template<scalar_t, float>(self, weights, minlength);
+      return _bincount_cpu_template<scalar_t, float>(self.contiguous(), weights.contiguous(), minlength);
     return _bincount_cpu_template<scalar_t, double>(
-        self, weights.toType(CPU(kDouble)), minlength);
+        self.contiguous(), weights.contiguous().toType(CPU(kDouble)), minlength);
   });
 }
 
diff --git a/test/test_torch.py b/test/test_torch.py
index 47dcc34..c9c41ff 100644
--- a/test/test_torch.py
+++ b/test/test_torch.py
@@ -9447,6 +9447,22 @@ tensor([[[1., 1., 1., ..., 1., 1., 1.],
             torch.tensor([1, 2, 3, 4, 5], dtype=torch.int8, device=device))
         self.assertEqual(
             torch.tensor([1, 9, 0, 0, 5], device=device), byte_counts)
+        # test non-contiguous inputs and weights
+        inputs = torch.tensor([[0, 0], [3, 1], [2, 1], [1, 1], [3, 4]], device=device)
+        weights = torch.tensor([[.1, 1], [.2, 2], [.3, 3], [.4, 4], [.5, 5]], device=device)
+        print(inputs[:, 1])
+        print(weights[:, 1])
+        for i in [0, 1]:
+            assert not inputs[:, i].is_contiguous(), "Inputs are supposed to be non-contiguous"
+            assert not weights[:, i].is_contiguous(), "Weights are supposed to be non-contiguous"
+        # inputs are non-contiguous but weights are contiguous
+        self.assertEqual(inputs[:, 0].bincount(), torch.tensor([1, 1, 1, 2]))
+        # inputs and weights are non-contiguous
+        print(inputs[:, 1].bincount(weights[:, 1]))
+        self.assertEqual(inputs[:, 1].bincount(weights[:, 1]), torch.tensor([1, 9, 0, 0, 5]))
+        # weights are non-contiguous but inputs are contiguous
+        self.assertEqual(inputs[:, 1].contiguous().bincount(weights[:, 1]),
+                         torch.tensor([1, 9, 0, 0, 5]))
         # test large number of bins - global memory use
         big_exp = torch.zeros(10000000, device=device)
         big_exp[-1] = 50.0
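
For context only (not part of the commit): a minimal Python sketch of the case the new test covers. The tensor values mirror the added test case; column slices of a 2-D tensor are non-contiguous 1-D views, which issue #15058 reported as giving wrong bincount results on CPU before this change.

    import torch

    # Column views are non-contiguous; values mirror the new test case.
    inputs = torch.tensor([[0, 0], [3, 1], [2, 1], [1, 1], [3, 4]])
    weights = torch.tensor([[.1, 1.], [.2, 2.], [.3, 3.], [.4, 4.], [.5, 5.]])

    col, w = inputs[:, 1], weights[:, 1]
    assert not col.is_contiguous() and not w.is_contiguous()

    # With this fix, the non-contiguous views give the same result as
    # their contiguous copies: tensor([1., 9., 0., 0., 5.])
    print(col.bincount(w))
    print(col.contiguous().bincount(w.contiguous()))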