From d33e7d12367db4a7ec8ed920563411b5584dd106 Mon Sep 17 00:00:00 2001 From: Thomas Viehmann Date: Wed, 16 Jan 2019 12:15:12 -0800 Subject: [PATCH] multinomial: fix detection of zero probability (#16075) Summary: The cumsum over the probabilities can be not monotonically non-decreasing. Thus it is hard to detect zero probability classes using just the cumsum. This changes the binary search postprocessing to use the (non-cumulated) distribution instead. Thank you, jcjohnson, for the bug report with reproducing case. Fixes: #13867 Pull Request resolved: https://github.com/pytorch/pytorch/pull/16075 Differential Revision: D13695565 Pulled By: soumith fbshipit-source-id: 02c4d6f868f0050c1ae7d333f4317c5610e49cd9 --- aten/src/THC/THCTensorRandom.cuh | 18 +++++++++++------- aten/src/THC/generic/THCTensorRandom.cu | 3 ++- test/test_cuda.py | 6 ++++++ 3 files changed, 19 insertions(+), 8 deletions(-) diff --git a/aten/src/THC/THCTensorRandom.cuh b/aten/src/THC/THCTensorRandom.cuh index 4830340..2a8f7d0 100644 --- a/aten/src/THC/THCTensorRandom.cuh +++ b/aten/src/THC/THCTensorRandom.cuh @@ -121,18 +121,19 @@ __global__ void renormRowsL1(T* dist, long rows, long cols) { } template -__device__ int binarySearchForMultinomial(T* dist, +__device__ int binarySearchForMultinomial(T* cumdist, + T* dist, int size, T val) { int start = 0; int end = size; - // dist[size - 1] = 0 => all zero prob dist - assert(THCNumerics::gt(dist[size - 1], 0)); + // cumdist[size - 1] = 0 => all zero prob dist + assert(THCNumerics::gt(cumdist[size - 1], 0)); while (end - start > 0) { int mid = start + (end - start) / 2; - T midVal = dist[mid]; + T midVal = cumdist[mid]; if (THCNumerics::lt(midVal, val)) { start = mid + 1; } else { @@ -149,8 +150,8 @@ __device__ int binarySearchForMultinomial(T* dist, start = size - 1; } - T curVal = dist[start]; - while(start >= 1 && THCNumerics::eq(dist[start - 1], curVal)) start--; + T curVal = cumdist[start]; + while(start >= 1 && THCNumerics::eq(dist[start], 0)) start--; return start; } @@ -299,7 +300,8 @@ sampleMultinomialWithReplacement(curandStateMtgp32* state, int64_t* dest, int64_t distributions, int categories, - T* normDistPrefixSum) { + T* normDistPrefixSum, + T* normDist) { // At the moment, each warp computes one sample value in the binary // search due to divergence. It seems possible to compute multiple // values and limit divergence though later on. However, no matter @@ -322,6 +324,7 @@ sampleMultinomialWithReplacement(curandStateMtgp32* state, // Find the bucket that a uniform sample lies in int choice = binarySearchForMultinomial( normDistPrefixSum + curDist * categories, + normDist + curDist * categories, categories, r); @@ -363,6 +366,7 @@ sampleMultinomialWithoutReplacement(curandStateMtgp32* state, // Find the bucket that a uniform sample lies in int choice = binarySearchForMultinomial( normDistPrefixSum + curDist * categories, + origDist + curDist * categories, categories, r); diff --git a/aten/src/THC/generic/THCTensorRandom.cu b/aten/src/THC/generic/THCTensorRandom.cu index 9348188b..c18bbfe 100644 --- a/aten/src/THC/generic/THCTensorRandom.cu +++ b/aten/src/THC/generic/THCTensorRandom.cu @@ -248,7 +248,8 @@ void THCTensor_(multinomial)(struct THCState *state, n_sample, THCudaLongTensor_data(state, self), numDist, numCategories, - THCTensor_(data)(state, prefixSum)); + THCTensor_(data)(state, prefixSum), + THCTensor_(data)(state, normDist)); } else { // Sample without replacement diff --git a/test/test_cuda.py b/test/test_cuda.py index 03bd195..be53595 100644 --- a/test/test_cuda.py +++ b/test/test_cuda.py @@ -1813,6 +1813,12 @@ class TestCuda(TestCase): r = torch.multinomial(p, 1) self.assertNotEqual(r.min().item(), 0) + # test corner case from Issue #13867 + torch.cuda.manual_seed(33) + probs = torch.randn(1000000, device='cuda').clamp(min=0) * 3e-5 + samples = probs.multinomial(1000000, replacement=True) + self.assertGreater(probs[samples].min().item(), 0) + @staticmethod def mute(): os.dup2(os.open(os.devnull, os.O_WRONLY), sys.stderr.fileno()) -- 2.7.4