Remove _finfo; replace _finfo usage with torch.finfo (#15165)
author	vishwakftw <cs15btech11043@iith.ac.in>
Thu, 13 Dec 2018 22:28:09 +0000 (14:28 -0800)
committer	Facebook Github Bot <facebook-github-bot@users.noreply.github.com>
Thu, 13 Dec 2018 22:30:27 +0000 (14:30 -0800)
Summary:
This PR removes the _finfo helper defined in torch.distributions.utils and changes its call sites
to use torch.finfo instead.
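
For reference, a minimal sketch of the call-site change (the tensor and variable names here are illustrative, not taken from the patch):

    import torch

    x = torch.rand(3, dtype=torch.float32)

    # before: private helper keyed on the tensor's storage type
    #   from torch.distributions.utils import _finfo
    #   tiny = _finfo(x).tiny

    # after: torch.finfo is keyed on the dtype and follows numpy.finfo semantics
    tiny = torch.finfo(x.dtype).tiny   # smallest positive normal number
    eps = torch.finfo(x.dtype).eps     # gap between 1.0 and the next representable float
    x = x.clamp(min=tiny)              # typical use in the distributions code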

Differential Revision: D13451936

Pulled By: soumith

fbshipit-source-id: 6dbda3a6179d9407bc3396bf1a2baf3e85bc4cf2

test/test_distributions.py
torch/distributions/dirichlet.py
torch/distributions/fishersnedecor.py
torch/distributions/gamma.py
torch/distributions/geometric.py
torch/distributions/gumbel.py
torch/distributions/laplace.py
torch/distributions/utils.py

index 4558f2b..de1c43e 100644 (file)
--- a/test/test_distributions.py
+++ b/test/test_distributions.py
@@ -57,7 +57,7 @@ from torch.distributions.transforms import (AbsTransform, AffineTransform,
                                             SoftmaxTransform,
                                             StickBreakingTransform,
                                             identity_transform)
-from torch.distributions.utils import _finfo, probs_to_logits, lazy_property
+from torch.distributions.utils import probs_to_logits, lazy_property
 from torch.nn.functional import softmax
 
 # load_tests from common_utils is used to automatically filter tests for
@@ -3480,7 +3480,7 @@ class TestNumericalStability(TestCase):
             self._test_pdf_score(dist_class=Bernoulli,
                                  probs=tensor_type([0]),
                                  x=tensor_type([1]),
-                                 expected_value=tensor_type([_finfo(tensor_type([])).eps]).log(),
+                                 expected_value=tensor_type([torch.finfo(tensor_type([]).dtype).eps]).log(),
                                  expected_gradient=tensor_type([0]))
 
             self._test_pdf_score(dist_class=Bernoulli,
index f618628..5189266 100644 (file)
--- a/torch/distributions/dirichlet.py
+++ b/torch/distributions/dirichlet.py
@@ -5,7 +5,7 @@ from torch.autograd import Function
 from torch.autograd.function import once_differentiable
 from torch.distributions import constraints
 from torch.distributions.exp_family import ExponentialFamily
-from torch.distributions.utils import _finfo, broadcast_all, clamp_probs
+from torch.distributions.utils import broadcast_all, clamp_probs
 
 
 def _dirichlet_sample_nograd(concentration):
index 6fe09a7..1071d53 100644 (file)
--- a/torch/distributions/fishersnedecor.py
+++ b/torch/distributions/fishersnedecor.py
@@ -5,7 +5,7 @@ from torch._six import nan
 from torch.distributions import constraints
 from torch.distributions.distribution import Distribution
 from torch.distributions.gamma import Gamma
-from torch.distributions.utils import broadcast_all, _finfo
+from torch.distributions.utils import broadcast_all
 
 
 class FisherSnedecor(Distribution):
@@ -66,9 +66,10 @@ class FisherSnedecor(Distribution):
         #   Y = df2 * df1 * X1 / (df1 * df2 * X2) = X1 / X2 ~ F(df1, df2)
         X1 = self._gamma1.rsample(sample_shape).view(shape)
         X2 = self._gamma2.rsample(sample_shape).view(shape)
-        X2.clamp_(min=_finfo(X2).tiny)
+        tiny = torch.finfo(X2.dtype).tiny
+        X2.clamp_(min=tiny)
         Y = X1 / X2
-        Y.clamp_(min=_finfo(X2).tiny)
+        Y.clamp_(min=tiny)
         return Y
 
     def log_prob(self, value):
index 655ddac..d4641db 100644 (file)
--- a/torch/distributions/gamma.py
+++ b/torch/distributions/gamma.py
@@ -3,7 +3,7 @@ from numbers import Number
 import torch
 from torch.distributions import constraints
 from torch.distributions.exp_family import ExponentialFamily
-from torch.distributions.utils import _finfo, broadcast_all, lazy_property
+from torch.distributions.utils import broadcast_all, lazy_property
 
 
 def _standard_gamma(concentration):
@@ -59,7 +59,7 @@ class Gamma(ExponentialFamily):
     def rsample(self, sample_shape=torch.Size()):
         shape = self._extended_shape(sample_shape)
         value = _standard_gamma(self.concentration.expand(shape)) / self.rate.expand(shape)
-        value.detach().clamp_(min=_finfo(value).tiny)  # do not record in autograd graph
+        value.detach().clamp_(min=torch.finfo(value.dtype).tiny)  # do not record in autograd graph
         return value
 
     def log_prob(self, value):
index 7b9e796..4b8dc4c 100644 (file)
--- a/torch/distributions/geometric.py
+++ b/torch/distributions/geometric.py
@@ -3,7 +3,7 @@ from numbers import Number
 import torch
 from torch.distributions import constraints
 from torch.distributions.distribution import Distribution
-from torch.distributions.utils import broadcast_all, probs_to_logits, logits_to_probs, lazy_property, _finfo
+from torch.distributions.utils import broadcast_all, probs_to_logits, logits_to_probs, lazy_property
 from torch.nn.functional import binary_cross_entropy_with_logits
 
 
@@ -74,13 +74,14 @@ class Geometric(Distribution):
 
     def sample(self, sample_shape=torch.Size()):
         shape = self._extended_shape(sample_shape)
+        tiny = torch.finfo(self.probs.dtype).tiny
         with torch.no_grad():
             if torch._C._get_tracing_state():
                 # [JIT WORKAROUND] lack of support for .uniform_()
                 u = torch.rand(shape, dtype=self.probs.dtype, device=self.probs.device)
-                u = u.clamp(min=_finfo(self.probs).tiny)
+                u = u.clamp(min=tiny)
             else:
-                u = self.probs.new(shape).uniform_(_finfo(self.probs).tiny, 1)
+                u = self.probs.new(shape).uniform_(tiny, 1)
             return (u.log() / (-self.probs).log1p()).floor()
 
     def log_prob(self, value):
index 7b9deaa..68aa29a 100644 (file)
--- a/torch/distributions/gumbel.py
+++ b/torch/distributions/gumbel.py
@@ -5,7 +5,7 @@ from torch.distributions import constraints
 from torch.distributions.uniform import Uniform
 from torch.distributions.transformed_distribution import TransformedDistribution
 from torch.distributions.transforms import AffineTransform, ExpTransform
-from torch.distributions.utils import _finfo, broadcast_all
+from torch.distributions.utils import broadcast_all
 
 euler_constant = 0.57721566490153286060  # Euler Mascheroni Constant
 
@@ -29,7 +29,7 @@ class Gumbel(TransformedDistribution):
 
     def __init__(self, loc, scale, validate_args=None):
         self.loc, self.scale = broadcast_all(loc, scale)
-        finfo = _finfo(self.loc)
+        finfo = torch.finfo(self.loc.dtype)
         if isinstance(loc, Number) and isinstance(scale, Number):
             base_dist = Uniform(finfo.tiny, 1 - finfo.eps)
         else:
index c15a35c..748ab77 100644 (file)
--- a/torch/distributions/laplace.py
+++ b/torch/distributions/laplace.py
@@ -2,7 +2,7 @@ from numbers import Number
 import torch
 from torch.distributions import constraints
 from torch.distributions.distribution import Distribution
-from torch.distributions.utils import _finfo, broadcast_all
+from torch.distributions.utils import broadcast_all
 
 
 class Laplace(Distribution):
@@ -54,11 +54,12 @@ class Laplace(Distribution):
 
     def rsample(self, sample_shape=torch.Size()):
         shape = self._extended_shape(sample_shape)
+        finfo = torch.finfo(self.loc.dtype)
         if torch._C._get_tracing_state():
             # [JIT WORKAROUND] lack of support for .uniform_()
             u = torch.rand(shape, dtype=self.loc.dtype, device=self.loc.device) * 2 - 1
-            return self.loc - self.scale * u.sign() * torch.log1p(-u.abs().clamp(min=_finfo(self.loc).tiny))
-        u = self.loc.new(shape).uniform_(_finfo(self.loc).eps - 1, 1)
+            return self.loc - self.scale * u.sign() * torch.log1p(-u.abs().clamp(min=finfo.tiny))
+        u = self.loc.new(shape).uniform_(finfo.eps - 1, 1)
         # TODO: If we ever implement tensor.nextafter, below is what we want ideally.
         # u = self.loc.new(shape).uniform_(self.loc.nextafter(-.5, 0), .5)
         return self.loc - self.scale * u.sign() * torch.log1p(-u.abs())
index 1fcc727..698c5a2 100644 (file)
--- a/torch/distributions/utils.py
+++ b/torch/distributions/utils.py
@@ -1,36 +1,9 @@
-from collections import namedtuple
 from functools import update_wrapper
 from numbers import Number
 import math
 import torch
 import torch.nn.functional as F
 
-# This follows semantics of numpy.finfo.
-_Finfo = namedtuple('_Finfo', ['eps', 'tiny'])
-_FINFO = {
-    torch.HalfStorage: _Finfo(eps=0.00097656, tiny=6.1035e-05),
-    torch.FloatStorage: _Finfo(eps=1.19209e-07, tiny=1.17549e-38),
-    torch.DoubleStorage: _Finfo(eps=2.22044604925e-16, tiny=2.22507385851e-308),
-    torch.cuda.HalfStorage: _Finfo(eps=0.00097656, tiny=6.1035e-05),
-    torch.cuda.FloatStorage: _Finfo(eps=1.19209e-07, tiny=1.17549e-38),
-    torch.cuda.DoubleStorage: _Finfo(eps=2.22044604925e-16, tiny=2.22507385851e-308),
-}
-
-
-def _finfo(tensor):
-    r"""
-    Return floating point info about a `Tensor`:
-    - `.eps` is the smallest number that can be added to 1 without being lost.
-    - `.tiny` is the smallest positive number greater than zero
-      (much smaller than `.eps`).
-
-    Args:
-        tensor (Tensor): tensor of floating point data.
-    Returns:
-        _Finfo: a `namedtuple` with fields `.eps` and `.tiny`.
-    """
-    return _FINFO[tensor.storage_type()]
-
 
 # promote numbers to tensors of dtype torch.get_default_dtype()
 def _default_promotion(v):
@@ -100,7 +73,7 @@ def logits_to_probs(logits, is_binary=False):
 
 
 def clamp_probs(probs):
-    eps = _finfo(probs).eps
+    eps = torch.finfo(probs.dtype).eps
     return probs.clamp(min=eps, max=1 - eps)
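
As a quick sanity check (illustrative only, run outside this patch), torch.finfo reports the same eps/tiny values that the removed _FINFO table hard-coded:

    import torch

    print(torch.finfo(torch.float16).eps)    # 0.0009765625            (~0.00097656 in the old table)
    print(torch.finfo(torch.float32).eps)    # 1.1920928955078125e-07  (~1.19209e-07)
    print(torch.finfo(torch.float64).tiny)   # 2.2250738585072014e-308 (~2.22507385851e-308)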