Fix log_prob for Gumbel distribution (#15878)
authorvishwakftw <cs15btech11043@iith.ac.in>
Thu, 10 Jan 2019 04:04:18 +0000 (20:04 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Thu, 10 Jan 2019 04:09:34 +0000 (20:09 -0800)
Summary:
Fixes https://github.com/pytorch/pytorch/issues/15681

Changelog:
- Add hard-coded implementation of log_prob
Pull Request resolved: https://github.com/pytorch/pytorch/pull/15878

Differential Revision: D13613716

Pulled By: soumith

fbshipit-source-id: 2ba74e52748b6213098b167940dcc068f0c056f4

torch/distributions/gumbel.py

index 68aa29a..5bd3a2d 100644 (file)
@@ -45,6 +45,13 @@ class Gumbel(TransformedDistribution):
         new.scale = self.scale.expand(batch_shape)
         return super(Gumbel, self).expand(batch_shape, _instance=new)
 
+    # Explicitly defining the log probability function for Gumbel due to precision issues
+    def log_prob(self, value):
+        if self._validate_args:
+            self._validate_sample(value)
+        y = (self.loc - value) / self.scale
+        return (y - y.exp()) - self.scale.log()
+
     @property
     def mean(self):
         return self.loc + self.scale * euler_constant