From 8e1e29124de99c01d08a2e2c02455c72335a971d Mon Sep 17 00:00:00 2001 From: Ahmad Salim Al-Sibahi Date: Fri, 5 Apr 2019 12:45:37 -0700 Subject: [PATCH] Fix precision issue with expansion that prefers 'probs' over 'logits' (#18614) Summary: I have experienced that sometimes both were in `__dict__`, but it chose to copy `probs` which loses precision over `logits`. This is especially important when training (bayesian) neural networks or doing other type of optimization, since the loss is heavily affected. Pull Request resolved: https://github.com/pytorch/pytorch/pull/18614 Differential Revision: D14793486 Pulled By: ezyang fbshipit-source-id: d4ff5e34fbb4021ea9de9f58af09a7de00d80a63 --- torch/distributions/bernoulli.py | 2 +- torch/distributions/binomial.py | 2 +- torch/distributions/categorical.py | 2 +- torch/distributions/geometric.py | 2 +- torch/distributions/negative_binomial.py | 2 +- torch/distributions/relaxed_bernoulli.py | 2 +- 6 files changed, 6 insertions(+), 6 deletions(-) diff --git a/torch/distributions/bernoulli.py b/torch/distributions/bernoulli.py index 74fb0d0..b9ec24c 100644 --- a/torch/distributions/bernoulli.py +++ b/torch/distributions/bernoulli.py @@ -53,7 +53,7 @@ class Bernoulli(ExponentialFamily): if 'probs' in self.__dict__: new.probs = self.probs.expand(batch_shape) new._param = new.probs - else: + if 'logits' in self.__dict__: new.logits = self.logits.expand(batch_shape) new._param = new.logits super(Bernoulli, new).__init__(batch_shape, validate_args=False) diff --git a/torch/distributions/binomial.py b/torch/distributions/binomial.py index 26fa54d..2d6f866 100644 --- a/torch/distributions/binomial.py +++ b/torch/distributions/binomial.py @@ -58,7 +58,7 @@ class Binomial(Distribution): if 'probs' in self.__dict__: new.probs = self.probs.expand(batch_shape) new._param = new.probs - else: + if 'logits' in self.__dict__: new.logits = self.logits.expand(batch_shape) new._param = new.logits super(Binomial, new).__init__(batch_shape, validate_args=False) diff --git a/torch/distributions/categorical.py b/torch/distributions/categorical.py index bde112c..5e1d651 100644 --- a/torch/distributions/categorical.py +++ b/torch/distributions/categorical.py @@ -64,7 +64,7 @@ class Categorical(Distribution): if 'probs' in self.__dict__: new.probs = self.probs.expand(param_shape) new._param = new.probs - else: + if 'logits' in self.__dict__: new.logits = self.logits.expand(param_shape) new._param = new.logits new._num_events = self._num_events diff --git a/torch/distributions/geometric.py b/torch/distributions/geometric.py index e409abe..eec2b1d 100644 --- a/torch/distributions/geometric.py +++ b/torch/distributions/geometric.py @@ -51,7 +51,7 @@ class Geometric(Distribution): batch_shape = torch.Size(batch_shape) if 'probs' in self.__dict__: new.probs = self.probs.expand(batch_shape) - else: + if 'logits' in self.__dict__: new.logits = self.logits.expand(batch_shape) super(Geometric, new).__init__(batch_shape, validate_args=False) new._validate_args = self._validate_args diff --git a/torch/distributions/negative_binomial.py b/torch/distributions/negative_binomial.py index 01a009b..9e4410a 100644 --- a/torch/distributions/negative_binomial.py +++ b/torch/distributions/negative_binomial.py @@ -45,7 +45,7 @@ class NegativeBinomial(Distribution): if 'probs' in self.__dict__: new.probs = self.probs.expand(batch_shape) new._param = new.probs - else: + if 'logits' in self.__dict__: new.logits = self.logits.expand(batch_shape) new._param = new.logits super(NegativeBinomial, new).__init__(batch_shape, validate_args=False) diff --git a/torch/distributions/relaxed_bernoulli.py b/torch/distributions/relaxed_bernoulli.py index 4df4b53..cea8a7c 100644 --- a/torch/distributions/relaxed_bernoulli.py +++ b/torch/distributions/relaxed_bernoulli.py @@ -54,7 +54,7 @@ class LogitRelaxedBernoulli(Distribution): if 'probs' in self.__dict__: new.probs = self.probs.expand(batch_shape) new._param = new.probs - else: + if 'logits' in self.__dict__: new.logits = self.logits.expand(batch_shape) new._param = new.logits super(LogitRelaxedBernoulli, new).__init__(batch_shape, validate_args=False) -- 2.7.4