Fixing user inputs for low, high in `make_tensor` (#61108)
authorKushashwa Ravi Shrimali <kushashwaravishrimali@gmail.com>
Fri, 13 Aug 2021 17:12:01 +0000 (10:12 -0700)
committerFacebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Fri, 13 Aug 2021 17:13:12 +0000 (10:13 -0700)
Summary:
**TODOs:**

* [x] Do not clamp user-supplied `low` and `high` when they are valid.
* [x] Devise rules for modifying `low` and `high` when extremal/invalid values are passed (see the sketch below).
* [x] Test with `test_reference_numerics_hard` against the revised changes. _(I've tested locally; those changes will land in a separate PR after an offline discussion with mruberry.)_
* [x] Revise comments/documentation for `make_tensor`.

See https://github.com/pytorch/pytorch/issues/61758 for the tracker issue.
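
To illustrate the revised semantics, here is a minimal sketch (not part of the patch; the shapes, device, and dtype choices are arbitrary):

```python
import torch
from torch.testing._internal.common_utils import make_tensor

# Valid user-supplied bounds are used as given (no longer clamped to [-9, 9]):
t = make_tensor((4,), device='cpu', dtype=torch.float32, low=100, high=200)
assert t.min() >= 100 and t.max() < 200

# Infinite bounds fall back to the dtype's representable limits:
t = make_tensor((4,), device='cpu', dtype=torch.float32,
                low=float('-inf'), high=float('inf'))
assert torch.isfinite(t).all()

# NaN bounds raise a ValueError:
try:
    make_tensor((4,), device='cpu', dtype=torch.float32, low=float('nan'))
except ValueError as e:
    print(e)  # make_tensor: one of low or high was NaN!
```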

cc: mruberry pmeier

Pull Request resolved: https://github.com/pytorch/pytorch/pull/61108

Reviewed By: VitalyFedyunin

Differential Revision: D30296167

Pulled By: mruberry

fbshipit-source-id: 67e8d15b173209a9c97ca013231494a5fa99f8c7

torch/testing/_internal/common_methods_invocations.py
torch/testing/_internal/common_utils.py

torch/testing/_internal/common_methods_invocations.py
index 7e06e67..36e9b9a 100644
@@ -5223,7 +5223,7 @@ op_db: List[OpInfo] = [
     UnaryUfuncInfo('acosh',
                    aliases=('arccosh', ),
                    ref=np.arccosh,
-                   domain=(1, float('inf')),
+                   domain=(1, None),
                    dtypes=all_types_and_complex_and(torch.bool),
                    dtypesIfCPU=all_types_and_complex_and(torch.bool, torch.bfloat16),
                    dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
@@ -5326,7 +5326,7 @@ op_db: List[OpInfo] = [
            supports_forward_ad=True,
            decorators=[
                DecorateInfo(
-                   toleranceOverride({torch.float32: tol(atol=1e-05, rtol=1.3e-05),
+                   toleranceOverride({torch.float32: tol(atol=1.3e-05, rtol=1.3e-05),
                                       torch.complex64: tol(atol=1e-05, rtol=1.2e-03)}),
                    'TestCommon', 'test_reference_testing')],
            skips=(
@@ -5446,11 +5446,16 @@ op_db: List[OpInfo] = [
                    domain=(-1, 1),
                    supports_sparse=True,
                    supports_forward_ad=True,
-                   decorators=(precisionOverride({torch.bfloat16: 1e-2}),),
                    safe_casts_outputs=True,
                    dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16),
                    dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
                    assert_autodiffed=True,
+                   decorators=[
+                       DecorateInfo(
+                           toleranceOverride({torch.float16: tol(atol=1e-05, rtol=1e-03)}),
+                           'TestUnaryUfuncs', device_type='cuda'),
+                       precisionOverride({torch.bfloat16: 1e-2}),
+                   ],
                    skips=(
                        SkipInfo('TestUnaryUfuncs', 'test_reference_numerics_extremal',
                                 device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]),
@@ -6227,6 +6232,9 @@ op_db: List[OpInfo] = [
            decorators=[skipCUDAIfNoMagma, skipCPUIfNoLapack, skipCUDAIfRocm],
            supports_inplace_autograd=False,
            skips=(
+               # Will be removed once https://github.com/pytorch/pytorch/issues/62328 is fixed
+               # Probable fix (open PR): https://github.com/pytorch/pytorch/pull/62570
+               SkipInfo('TestGradients', 'test_fn_grad', device_type='cuda', dtypes=(torch.complex128,)),
                SkipInfo('TestCommon', 'test_dtypes'),
                SkipInfo('TestGradients', 'test_fn_gradgrad'),
                # This test fails because singular inputs cannot be reliably
@@ -6384,7 +6392,7 @@ op_db: List[OpInfo] = [
            )),
     UnaryUfuncInfo('log',
                    ref=np.log,
-                   domain=(0, float('inf')),
+                   domain=(0, None),
                    dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16),
                    dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
                    assert_autodiffed=True,
@@ -6398,7 +6406,7 @@ op_db: List[OpInfo] = [
                    )),
     UnaryUfuncInfo('log10',
                    ref=np.log10,
-                   domain=(0, float('inf')),
+                   domain=(0, None),
                    decorators=(precisionOverride({torch.bfloat16: 5e-2}),),
                    dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16),
                    assert_autodiffed=True,
@@ -6413,7 +6421,7 @@ op_db: List[OpInfo] = [
     UnaryUfuncInfo('log1p',
                    ref=np.log1p,
                    aliases=('special.log1p',),
-                   domain=(-1, float('inf')),
+                   domain=(-1, None),
                    dtypes=all_types_and(torch.bool, torch.bfloat16),
                    dtypesIfCUDA=all_types_and(torch.bool, torch.half, torch.bfloat16),
                    decorators=(precisionOverride({torch.bfloat16: 1e-1}),),
@@ -6422,7 +6430,7 @@ op_db: List[OpInfo] = [
                    assert_autodiffed=True),
     UnaryUfuncInfo('log2',
                    ref=np.log2,
-                   domain=(0, float('inf')),
+                   domain=(0, None),
                    dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16),
                    dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
                    assert_autodiffed=True,
@@ -6876,17 +6884,17 @@ op_db: List[OpInfo] = [
            supports_forward_ad=True,
            sample_inputs_func=sample_inputs_mode,),
     MvlGammaInfo(variant_test_name='mvlgamma_p_1',
-                 domain=(1, float('inf')),
+                 domain=(1, None),
                  skips=skips_mvlgamma(),
                  sample_kwargs=lambda device, dtype, input: ({'p': 1}, {'d': 1})),
     MvlGammaInfo(variant_test_name='mvlgamma_p_3',
-                 domain=(2, float('inf')),
+                 domain=(2, None),
                  skips=skips_mvlgamma(skip_redundant=True) + (
                      SkipInfo('TestUnaryUfuncs', 'test_reference_numerics_hard', dtypes=(torch.float16,)),
                  ),
                  sample_kwargs=lambda device, dtype, input: ({'p': 3}, {'d': 3})),
     MvlGammaInfo(variant_test_name='mvlgamma_p_5',
-                 domain=(3, float('inf')),
+                 domain=(3, None),
                  skips=skips_mvlgamma(skip_redundant=True) + (
                      SkipInfo('TestUnaryUfuncs', 'test_reference_numerics_hard', dtypes=(torch.float16,)),
                  ),
@@ -7389,7 +7397,7 @@ op_db: List[OpInfo] = [
                    )),
     UnaryUfuncInfo('rsqrt',
                    ref=lambda x: np.reciprocal(np.sqrt(x)),
-                   domain=(0, float('inf')),
+                   domain=(0, None),
                    dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16),
                    dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
                    decorators=(precisionOverride({torch.half: 5e-2}),),
@@ -7400,7 +7408,7 @@ op_db: List[OpInfo] = [
     UnaryUfuncInfo('sqrt',
                    ref=np.sqrt,
                    supports_sparse=True,
-                   domain=(0, float('inf')),
+                   domain=(0, None),
                    dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16),
                    dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
                    assert_autodiffed=True,
torch/testing/_internal/common_utils.py
index fb2482b..fed9a00 100644
@@ -43,7 +43,7 @@ from unittest.mock import MagicMock
 
 import numpy as np
 
-from torch.testing import floating_types_and, integral_types, complex_types
+from torch.testing import floating_types_and, integral_types, complex_types, get_all_dtypes
 import expecttest
 from .._core import \
     (_compare_tensors_internal, _compare_scalars_internal, _compare_return_type)
@@ -1944,51 +1944,72 @@ def make_tensor(size, device: torch.device, dtype: torch.dtype, *, low=None, hig
                 exclude_zero: bool = False) -> torch.Tensor:
     """ Creates a random tensor with the given size, device and dtype.
 
-        By default, the tensor's values are in the range [-9, 9] for most dtypes. If low
-        and/or high are specified then the values will be in the range [max(-9, low), min(9, high)].
-
-        For unsigned types the values are in the range[0, 9] and for complex types the real and imaginary
-        parts are each in the range [-9, 9].
+        Default values for low and high:
+            * boolean type: low = 0, high = 2
+            * uint8 type: low = 0, high = 9
+            * floating and integral types: low = -9 and high = 9
+            * complex types, for each of the real and imaginary parts: low = -9, high = 9
+        If low/high are specified and within the dtype's limits, they are used as given.
+        If low/high are specified but exceed the dtype's limits, they are clamped to those limits.
+        If low is -inf and/or high is inf, they are replaced with the dtype's minimum/maximum.
+        A ValueError is raised if low is inf or NaN, or if high is -inf or NaN.
 
        If noncontiguous=True, a noncontiguous tensor with the given size will be returned unless the size
        specifies a tensor with 1 or 0 elements, in which case the noncontiguous parameter is ignored because
        it is not possible to create a noncontiguous Tensor with a single element.
 
        If exclude_zero is True (default is False), all zero values in the
-        created tensor are replaced with an epsilon value if floating type, [`eps + `eps`.j] if
-        complex type and 1 if integer/boolean type.
+        created tensor are replaced with the dtype's smallest positive normal value (`torch.finfo(dtype).tiny`)
+        if a floating point type, with `tiny + tiny.j` if a complex type, and with 1 if an integer/boolean type.
     """
+    def _modify_low_high(low, high, lowest, highest, default_low, default_high, dtype):
+        """
+        Modifies (and raises ValueError when appropriate) low and high values given by the user (input_low, input_high) if required.
+        """
+        def clamp(a, l, h):
+            return min(max(a, l), h)
+
+        low = low if low is not None else default_low
+        high = high if high is not None else default_high
+
+        # Checks for error cases
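+        # (x != x) is True only for NaN, so this catches NaN in either bound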
+        if low != low or high != high:
+            raise ValueError("make_tensor: one of low or high was NaN!")
+        if low > high:
+            raise ValueError("make_tensor: low must be weakly less than high!")
 
-    assert low is None or low < 9, "low value too high!"
-    assert high is None or high > -9, "high value too low!"
+        low = clamp(low, lowest, highest)
+        high = clamp(high, lowest, highest)
+
+        if dtype in integral_types():
+            return math.floor(low), math.ceil(high)
+
+        return low, high
 
     if dtype is torch.bool:
         result = torch.randint(0, 2, size, device=device, dtype=dtype)
     elif dtype is torch.uint8:
-        low = math.floor(0 if low is None else max(low, 0))
-        high = math.ceil(10 if high is None else min(high, 10))
+        ranges = (torch.iinfo(dtype).min, torch.iinfo(dtype).max)
+        low, high = _modify_low_high(low, high, ranges[0], ranges[1], 0, 9, dtype)
         result = torch.randint(low, high, size, device=device, dtype=dtype)
     elif dtype in integral_types():
-        low = math.floor(-9 if low is None else max(low, -9))
-        high = math.ceil(10 if high is None else min(high, 10))
+        ranges = (torch.iinfo(dtype).min, torch.iinfo(dtype).max)
+        low, high = _modify_low_high(low, high, ranges[0], ranges[1], -9, 9, dtype)
         result = torch.randint(low, high, size, device=device, dtype=dtype)
     elif dtype in floating_types_and(torch.half, torch.bfloat16):
-        low = -9 if low is None else max(low, -9)
-        high = 9 if high is None else min(high, 10)
-        span = high - low
-        # Windows doesn't support torch.rand(bfloat16) on CUDA
-        if IS_WINDOWS and torch.device(device).type == 'cuda' and dtype is torch.bfloat16:
-            result = (torch.rand(size, device=device, dtype=torch.float32) * span + low).to(torch.bfloat16)
-        else:
-            result = torch.rand(size, device=device, dtype=dtype) * span + low
+        ranges_floats = (torch.finfo(dtype).min, torch.finfo(dtype).max)
+        low, high = _modify_low_high(low, high, ranges_floats[0], ranges_floats[1], -9, 9, dtype)
+        rand_val = torch.rand(size, device=device, dtype=dtype)
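+        # Computing high * t + low * (1 - t) avoids forming (high - low), which overflows for extreme finite bounds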
+        result = high * rand_val + low * (1 - rand_val)
     else:
         assert dtype in complex_types()
-        low = -9 if low is None else max(low, -9)
-        high = 9 if high is None else min(high, 10)
-        span = high - low
         float_dtype = torch.float if dtype is torch.cfloat else torch.double
-        real = torch.rand(size, device=device, dtype=float_dtype) * span + low
-        imag = torch.rand(size, device=device, dtype=float_dtype) * span + low
+        ranges_floats = (torch.finfo(float_dtype).min, torch.finfo(float_dtype).max)
+        low, high = _modify_low_high(low, high, ranges_floats[0], ranges_floats[1], -9, 9, dtype)
+        real_rand_val = torch.rand(size, device=device, dtype=float_dtype)
+        imag_rand_val = torch.rand(size, device=device, dtype=float_dtype)
+        real = high * real_rand_val + low * (1 - real_rand_val)
+        imag = high * imag_rand_val + low * (1 - imag_rand_val)
         result = torch.complex(real, imag)
 
     if noncontiguous and result.numel() > 1:
@@ -1999,12 +2020,13 @@ def make_tensor(size, device: torch.device, dtype: torch.dtype, *, low=None, hig
         if dtype in integral_types() or dtype is torch.bool:
             replace_with = torch.tensor(1, device=device, dtype=dtype)
         elif dtype in floating_types_and(torch.half, torch.bfloat16):
-            replace_with = torch.tensor(torch.finfo(dtype).eps, device=device, dtype=dtype)
-        else:
-            assert dtype in complex_types()
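+            # Note: finfo(dtype).tiny is the smallest positive *normal* number representable by the dtype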
+            replace_with = torch.tensor(torch.finfo(dtype).tiny, device=device, dtype=dtype)
+        elif dtype in complex_types():
             float_dtype = torch.float if dtype is torch.cfloat else torch.double
-            float_eps = torch.tensor(torch.finfo(float_dtype).eps, device=device, dtype=float_dtype)
+            float_eps = torch.tensor(torch.finfo(float_dtype).tiny, device=device, dtype=float_dtype)
             replace_with = torch.complex(float_eps, float_eps)
+        else:
+            raise ValueError(f"Invalid dtype passed, supported dtypes are: {get_all_dtypes()}")
         result[result == 0] = replace_with
 
     if dtype in floating_types_and(torch.half, torch.bfloat16) or\