From 5f1dd9e7436366d2ea4580ebbaaba1a70b22b4c9 Mon Sep 17 00:00:00 2001 From: vishwakftw Date: Wed, 9 Jan 2019 20:04:18 -0800 Subject: [PATCH] Fix log_prob for Gumbel distribution (#15878) Summary: Fixes https://github.com/pytorch/pytorch/issues/15681 Changelog: - Add hard-coded implementation of log_prob Pull Request resolved: https://github.com/pytorch/pytorch/pull/15878 Differential Revision: D13613716 Pulled By: soumith fbshipit-source-id: 2ba74e52748b6213098b167940dcc068f0c056f4 --- torch/distributions/gumbel.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/torch/distributions/gumbel.py b/torch/distributions/gumbel.py index 68aa29a..5bd3a2d 100644 --- a/torch/distributions/gumbel.py +++ b/torch/distributions/gumbel.py @@ -45,6 +45,13 @@ class Gumbel(TransformedDistribution): new.scale = self.scale.expand(batch_shape) return super(Gumbel, self).expand(batch_shape, _instance=new) + # Explicitly defining the log probability function for Gumbel due to precision issues + def log_prob(self, value): + if self._validate_args: + self._validate_sample(value) + y = (self.loc - value) / self.scale + return (y - y.exp()) - self.scale.log() + @property def mean(self): return self.loc + self.scale * euler_constant -- 2.7.4