}
template <typename T>
-__device__ int binarySearchForMultinomial(T* dist,
+__device__ int binarySearchForMultinomial(T* cumdist,
+ T* dist,
int size,
T val) {
int start = 0;
int end = size;
- // dist[size - 1] = 0 => all zero prob dist
- assert(THCNumerics<T>::gt(dist[size - 1], 0));
+ // cumdist[size - 1] = 0 => all zero prob dist
+ assert(THCNumerics<T>::gt(cumdist[size - 1], 0));
while (end - start > 0) {
int mid = start + (end - start) / 2;
- T midVal = dist[mid];
+ T midVal = cumdist[mid];
if (THCNumerics<T>::lt(midVal, val)) {
start = mid + 1;
} else {
start = size - 1;
}
- T curVal = dist[start];
- while(start >= 1 && THCNumerics<T>::eq(dist[start - 1], curVal)) start--;
+ T curVal = cumdist[start];
+ while(start >= 1 && THCNumerics<T>::eq(dist[start], 0)) start--;
return start;
}
int64_t* dest,
int64_t distributions,
int categories,
- T* normDistPrefixSum) {
+ T* normDistPrefixSum,
+ T* normDist) {
// At the moment, each warp computes one sample value in the binary
// search due to divergence. It seems possible to compute multiple
// values and limit divergence though later on. However, no matter
// Find the bucket that a uniform sample lies in
int choice = binarySearchForMultinomial<T>(
normDistPrefixSum + curDist * categories,
+ normDist + curDist * categories,
categories,
r);
// Find the bucket that a uniform sample lies in
int choice = binarySearchForMultinomial<T>(
normDistPrefixSum + curDist * categories,
+ origDist + curDist * categories,
categories,
r);
n_sample,
THCudaLongTensor_data(state, self),
numDist, numCategories,
- THCTensor_(data)(state, prefixSum));
+ THCTensor_(data)(state, prefixSum),
+ THCTensor_(data)(state, normDist));
} else {
// Sample without replacement
r = torch.multinomial(p, 1)
self.assertNotEqual(r.min().item(), 0)
+ # test corner case from Issue #13867
+ torch.cuda.manual_seed(33)
+ probs = torch.randn(1000000, device='cuda').clamp(min=0) * 3e-5
+ samples = probs.multinomial(1000000, replacement=True)
+ self.assertGreater(probs[samples].min().item(), 0)
+
@staticmethod
def mute():
os.dup2(os.open(os.devnull, os.O_WRONLY), sys.stderr.fileno())