Fix bincount for non-contiguous inputs on CPU (#15109)
author vishwakftw <cs15btech11043@iith.ac.in>
Thu, 13 Dec 2018 17:38:40 +0000 (09:38 -0800)
committer Facebook Github Bot <facebook-github-bot@users.noreply.github.com>
Thu, 13 Dec 2018 17:44:20 +0000 (09:44 -0800)
Summary:
Fixes #15058.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/15109

Differential Revision: D13447448

Pulled By: soumith

fbshipit-source-id: 56e8d42934538fb00465105a2c5ccfeb7c18a651

aten/src/ATen/native/SummaryOps.cpp
test/test_torch.py

index 79bb4d2..6a1d6bd 100644 (file)
@@ -34,11 +34,11 @@ Tensor _bincount_cpu_template(
   int64_t nbins = static_cast<int64_t>(*self.max().data<input_t>()) + 1L;
   nbins = std::max(nbins, minlength); // at least minlength # of bins
 
-  const input_t* self_p = self.contiguous().data<input_t>();
+  const input_t* self_p = self.data<input_t>();
   if (has_weights) {
     output = native::zeros({nbins}, weights.options());
     weights_t* output_p = output.data<weights_t>();
-    const weights_t* weights_p = weights.contiguous().data<weights_t>();
+    const weights_t* weights_p = weights.data<weights_t>();
     for (int64_t i = 0; i < self.size(0); i++) {
       output_p[self_p[i]] += weights_p[i];
     }
@@ -58,9 +58,9 @@ _bincount_cpu(const Tensor& self, const Tensor& weights, int64_t minlength) {
   return AT_DISPATCH_INTEGRAL_TYPES(self.type(), "bincount", [&] {
     const auto scalar = weights.type().scalarType();
     if (scalar == ScalarType::Undefined || scalar == ScalarType::Float)
-      return _bincount_cpu_template<scalar_t, float>(self, weights, minlength);
+      return _bincount_cpu_template<scalar_t, float>(self.contiguous(), weights.contiguous(), minlength);
     return _bincount_cpu_template<scalar_t, double>(
-        self, weights.toType(CPU(kDouble)), minlength);
+        self.contiguous(), weights.contiguous().toType(CPU(kDouble)), minlength);
   });
 }
 
index 47dcc34..c9c41ff 100644 (file)
@@ -9447,6 +9447,19 @@ tensor([[[1., 1., 1.,  ..., 1., 1., 1.],
             torch.tensor([1, 2, 3, 4, 5], dtype=torch.int8, device=device))
         self.assertEqual(
             torch.tensor([1, 9, 0, 0, 5], device=device), byte_counts)
+        # test non-contiguous inputs and weights
+        inputs = torch.tensor([[0, 0], [3, 1], [2, 1], [1, 1], [3, 4]], device=device)
+        weights = torch.tensor([[.1, 1], [.2, 2], [.3, 3], [.4, 4], [.5, 5]], device=device)
+        for i in [0, 1]:
+            assert not inputs[:, i].is_contiguous(), "Inputs are supposed to be non-contiguous"
+            assert not weights[:, i].is_contiguous(), "Weights are supposed to be non-contiguous"
+        # inputs are non-contiguous but weights are contiguous
+        self.assertEqual(inputs[:, 0].bincount(), torch.tensor([1, 1, 1, 2]))
+        # inputs and weights are non-contiguous
+        self.assertEqual(inputs[:, 1].bincount(weights[:, 1]), torch.tensor([1, 9, 0, 0, 5]))
+        # weights are non-contiguous but inputs are contiguous
+        self.assertEqual(inputs[:, 1].contiguous().bincount(weights[:, 1]),
+                         torch.tensor([1, 9, 0, 0, 5]))
         # test large number of bins - global memory use
         big_exp = torch.zeros(10000000, device=device)
         big_exp[-1] = 50.0