Summary:
Fixes https://github.com/pytorch/pytorch/issues/15681
Changelog:
- Add hard-coded implementation of log_prob
Pull Request resolved: https://github.com/pytorch/pytorch/pull/15878
Differential Revision:
D13613716
Pulled By: soumith
fbshipit-source-id:
2ba74e52748b6213098b167940dcc068f0c056f4
new.scale = self.scale.expand(batch_shape)
return super(Gumbel, self).expand(batch_shape, _instance=new)
+ # Explicitly defining the log probability function for Gumbel due to precision issues
+ def log_prob(self, value):
+ if self._validate_args:
+ self._validate_sample(value)
+ y = (self.loc - value) / self.scale
+ return (y - y.exp()) - self.scale.log()
+
@property
def mean(self):
return self.loc + self.scale * euler_constant