Summary:
I have experienced that sometimes both were in `__dict__`, but it chose to copy `probs` which loses precision over `logits`. This is especially important when training (bayesian) neural networks or doing other type of optimization, since the loss is heavily affected.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18614
Differential Revision:
D14793486
Pulled By: ezyang
fbshipit-source-id:
d4ff5e34fbb4021ea9de9f58af09a7de00d80a63
if 'probs' in self.__dict__:
new.probs = self.probs.expand(batch_shape)
new._param = new.probs
- else:
+ if 'logits' in self.__dict__:
new.logits = self.logits.expand(batch_shape)
new._param = new.logits
super(Bernoulli, new).__init__(batch_shape, validate_args=False)
if 'probs' in self.__dict__:
new.probs = self.probs.expand(batch_shape)
new._param = new.probs
- else:
+ if 'logits' in self.__dict__:
new.logits = self.logits.expand(batch_shape)
new._param = new.logits
super(Binomial, new).__init__(batch_shape, validate_args=False)
if 'probs' in self.__dict__:
new.probs = self.probs.expand(param_shape)
new._param = new.probs
- else:
+ if 'logits' in self.__dict__:
new.logits = self.logits.expand(param_shape)
new._param = new.logits
new._num_events = self._num_events
batch_shape = torch.Size(batch_shape)
if 'probs' in self.__dict__:
new.probs = self.probs.expand(batch_shape)
- else:
+ if 'logits' in self.__dict__:
new.logits = self.logits.expand(batch_shape)
super(Geometric, new).__init__(batch_shape, validate_args=False)
new._validate_args = self._validate_args
if 'probs' in self.__dict__:
new.probs = self.probs.expand(batch_shape)
new._param = new.probs
- else:
+ if 'logits' in self.__dict__:
new.logits = self.logits.expand(batch_shape)
new._param = new.logits
super(NegativeBinomial, new).__init__(batch_shape, validate_args=False)
if 'probs' in self.__dict__:
new.probs = self.probs.expand(batch_shape)
new._param = new.probs
- else:
+ if 'logits' in self.__dict__:
new.logits = self.logits.expand(batch_shape)
new._param = new.logits
super(LogitRelaxedBernoulli, new).__init__(batch_shape, validate_args=False)