Rely on numel() == 1 to check if distribution parameters are scalar. (#17503)
author    Morgan Funtowicz <morgan.funtowicz@naverlabs.com>
          Thu, 28 Feb 2019 21:27:27 +0000 (13:27 -0800)
committer Facebook Github Bot <facebook-github-bot@users.noreply.github.com>
          Thu, 28 Feb 2019 21:36:17 +0000 (13:36 -0800)
Summary:
As discussed in #16952, this PR improves the __repr__ of distributions whose parameters are torch.Tensor objects with only one element.

Currently, __repr__() relies on dim() == 0, leading to the following behaviour:

```
>>> torch.distributions.Normal(torch.tensor([1.0]), torch.tensor([0.1]))
Normal(loc: torch.Size([1]), scale: torch.Size([1]))
```
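For context (not part of the patch), a minimal sketch of why the dim() == 0 check misses one-element tensors: a 1-d tensor of length one has dim() == 1 but numel() == 1, so checking numel() covers both the 0-dim scalar and the single-element vector:

```
import torch

scalar = torch.tensor(1.0)    # 0-dim tensor
vector = torch.tensor([1.0])  # 1-dim tensor with a single element

scalar.dim()    # 0 -> treated as scalar by the old check
vector.dim()    # 1 -> old check falls back to printing the size

scalar.numel()  # 1 -> treated as scalar by the new check
vector.numel()  # 1 -> now also treated as scalar
```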

With this PR, the output looks like the following:
```
>>> torch.distributions.Normal(torch.tensor([1.0]), torch.tensor([0.1]))
Normal(loc: 1.0, scale: 0.10000000149011612)
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/17503

Differential Revision: D14245439

Pulled By: soumith

fbshipit-source-id: a440998905fd60cf2ac9a94f75706021dd9ce5bf

torch/distributions/distribution.py

index 2c2733a..d1e3a39 100644
```
@@ -262,6 +262,6 @@ class Distribution(object):
     def __repr__(self):
         param_names = [k for k, _ in self.arg_constraints.items() if k in self.__dict__]
         args_string = ', '.join(['{}: {}'.format(p, self.__dict__[p]
-                                if self.__dict__[p].dim() == 0
+                                if self.__dict__[p].numel() == 1
                                 else self.__dict__[p].size()) for p in param_names])
         return self.__class__.__name__ + '(' + args_string + ')'
```
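For completeness, a short usage sketch of the patched behaviour. The single-element output follows the PR summary above; the multi-element fallback follows directly from the else branch, which prints the parameter's size:

```
import torch

# Single-element parameters (numel() == 1) are shown as scalars:
torch.distributions.Normal(torch.tensor([1.0]), torch.tensor([0.1]))
# Normal(loc: 1.0, scale: 0.10000000149011612)

# Parameters with more than one element still fall back to their size:
torch.distributions.Normal(torch.tensor([1.0, 2.0]), torch.tensor([0.1, 0.2]))
# Normal(loc: torch.Size([2]), scale: torch.Size([2]))
```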