From: vishwakftw Date: Thu, 10 Jan 2019 04:04:18 +0000 (-0800) Subject: Fix log_prob for Gumbel distribution (#15878) X-Git-Tag: accepted/tizen/6.5/unified/20211028.231830~1934 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=5f1dd9e7436366d2ea4580ebbaaba1a70b22b4c9;p=platform%2Fupstream%2Fpytorch.git Fix log_prob for Gumbel distribution (#15878) Summary: Fixes https://github.com/pytorch/pytorch/issues/15681 Changelog: - Add hard-coded implementation of log_prob Pull Request resolved: https://github.com/pytorch/pytorch/pull/15878 Differential Revision: D13613716 Pulled By: soumith fbshipit-source-id: 2ba74e52748b6213098b167940dcc068f0c056f4 --- diff --git a/torch/distributions/gumbel.py b/torch/distributions/gumbel.py index 68aa29a..5bd3a2d 100644 --- a/torch/distributions/gumbel.py +++ b/torch/distributions/gumbel.py @@ -45,6 +45,13 @@ class Gumbel(TransformedDistribution): new.scale = self.scale.expand(batch_shape) return super(Gumbel, self).expand(batch_shape, _instance=new) + # Explicitly defining the log probability function for Gumbel due to precision issues + def log_prob(self, value): + if self._validate_args: + self._validate_sample(value) + y = (self.loc - value) / self.scale + return (y - y.exp()) - self.scale.log() + @property def mean(self): return self.loc + self.scale * euler_constant