multinomial: fix detection of zero probability (#16075)
authorThomas Viehmann <tv.code@beamnet.de>
Wed, 16 Jan 2019 20:15:12 +0000 (12:15 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Wed, 16 Jan 2019 20:50:49 +0000 (12:50 -0800)
Summary:
The cumsum over the probabilities can be not monotonically
non-decreasing. Thus it is hard to detect zero probability
classes using just the cumsum.
This changes the binary search postprocessing to use the
(non-cumulated) distribution instead.

Thank you, jcjohnson, for the bug report with
reproducing case.

Fixes: #13867
Pull Request resolved: https://github.com/pytorch/pytorch/pull/16075

Differential Revision: D13695565

Pulled By: soumith

fbshipit-source-id: 02c4d6f868f0050c1ae7d333f4317c5610e49cd9

aten/src/THC/THCTensorRandom.cuh
aten/src/THC/generic/THCTensorRandom.cu
test/test_cuda.py

index 4830340..2a8f7d0 100644 (file)
@@ -121,18 +121,19 @@ __global__ void renormRowsL1(T* dist, long rows, long cols) {
 }
 
 template <typename T>
-__device__ int binarySearchForMultinomial(T* dist,
+__device__ int binarySearchForMultinomial(T* cumdist,
+                                          T* dist,
                                           int size,
                                           T val) {
   int start = 0;
   int end = size;
-  // dist[size - 1] = 0 => all zero prob dist
-  assert(THCNumerics<T>::gt(dist[size - 1], 0));
+  // cumdist[size - 1] = 0 => all zero prob dist
+  assert(THCNumerics<T>::gt(cumdist[size - 1], 0));
 
   while (end - start > 0) {
     int mid = start + (end - start) / 2;
 
-    T midVal = dist[mid];
+    T midVal = cumdist[mid];
     if (THCNumerics<T>::lt(midVal, val)) {
       start = mid + 1;
     } else {
@@ -149,8 +150,8 @@ __device__ int binarySearchForMultinomial(T* dist,
     start = size - 1;
   }
 
-  T curVal = dist[start];
-  while(start >= 1 && THCNumerics<T>::eq(dist[start - 1], curVal)) start--;
+  T curVal = cumdist[start];
+  while(start >= 1 && THCNumerics<T>::eq(dist[start], 0)) start--;
 
   return start;
 }
@@ -299,7 +300,8 @@ sampleMultinomialWithReplacement(curandStateMtgp32* state,
                                  int64_t* dest,
                                  int64_t distributions,
                                  int categories,
-                                 T* normDistPrefixSum) {
+                                 T* normDistPrefixSum,
+                                 T* normDist) {
   // At the moment, each warp computes one sample value in the binary
   // search due to divergence. It seems possible to compute multiple
   // values and limit divergence though later on. However, no matter
@@ -322,6 +324,7 @@ sampleMultinomialWithReplacement(curandStateMtgp32* state,
         // Find the bucket that a uniform sample lies in
         int choice = binarySearchForMultinomial<T>(
           normDistPrefixSum + curDist * categories,
+          normDist + curDist * categories,
           categories,
           r);
 
@@ -363,6 +366,7 @@ sampleMultinomialWithoutReplacement(curandStateMtgp32* state,
       // Find the bucket that a uniform sample lies in
       int choice = binarySearchForMultinomial<T>(
         normDistPrefixSum + curDist * categories,
+        origDist + curDist * categories,
         categories,
         r);
 
index 9348188..c18bbfe 100644 (file)
@@ -248,7 +248,8 @@ void THCTensor_(multinomial)(struct THCState *state,
           n_sample,
           THCudaLongTensor_data(state, self),
           numDist, numCategories,
-          THCTensor_(data)(state, prefixSum));
+          THCTensor_(data)(state, prefixSum),
+         THCTensor_(data)(state, normDist));
     } else {
       // Sample without replacement
 
index 03bd195..be53595 100644 (file)
@@ -1813,6 +1813,12 @@ class TestCuda(TestCase):
         r = torch.multinomial(p, 1)
         self.assertNotEqual(r.min().item(), 0)
 
+        # test corner case from Issue #13867
+        torch.cuda.manual_seed(33)
+        probs = torch.randn(1000000, device='cuda').clamp(min=0) * 3e-5
+        samples = probs.multinomial(1000000, replacement=True)
+        self.assertGreater(probs[samples].min().item(), 0)
+
     @staticmethod
     def mute():
         os.dup2(os.open(os.devnull, os.O_WRONLY), sys.stderr.fileno())