From fcc1f87b6aacdfba9193691568bc405be42ac77e Mon Sep 17 00:00:00 2001 From: Kushashwa Ravi Shrimali Date: Fri, 13 Aug 2021 10:12:01 -0700 Subject: [PATCH] Fixing user inputs for low, high in `make_tensor` (#61108) Summary: **TODOs:** * [x] Do not clamp inputs for low and high when given and valid. * [x] Devise rules for modifying `low` and `high` when extremals/invalid values passed. * [x] Testing with `test_references_numerics_hard` with the revised changes. _(I've tested locally, the changes will take place in a separate PR though after offline discussion with mruberry)_ * [x] Revise comments/documentation for `make_tensor` See https://github.com/pytorch/pytorch/issues/61758 for tracker issue. cc: mruberry pmeier Pull Request resolved: https://github.com/pytorch/pytorch/pull/61108 Reviewed By: VitalyFedyunin Differential Revision: D30296167 Pulled By: mruberry fbshipit-source-id: 67e8d15b173209a9c97ca013231494a5fa99f8c7 --- .../_internal/common_methods_invocations.py | 32 +++++---- torch/testing/_internal/common_utils.py | 84 ++++++++++++++-------- 2 files changed, 73 insertions(+), 43 deletions(-) diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py index 7e06e67..36e9b9a 100644 --- a/torch/testing/_internal/common_methods_invocations.py +++ b/torch/testing/_internal/common_methods_invocations.py @@ -5223,7 +5223,7 @@ op_db: List[OpInfo] = [ UnaryUfuncInfo('acosh', aliases=('arccosh', ), ref=np.arccosh, - domain=(1, float('inf')), + domain=(1, None), dtypes=all_types_and_complex_and(torch.bool), dtypesIfCPU=all_types_and_complex_and(torch.bool, torch.bfloat16), dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), @@ -5326,7 +5326,7 @@ op_db: List[OpInfo] = [ supports_forward_ad=True, decorators=[ DecorateInfo( - toleranceOverride({torch.float32: tol(atol=1e-05, rtol=1.3e-05), + toleranceOverride({torch.float32: tol(atol=1.3e-05, rtol=1.3e-05), torch.complex64: tol(atol=1e-05, rtol=1.2e-03)}), 'TestCommon', 'test_reference_testing')], skips=( @@ -5446,11 +5446,16 @@ op_db: List[OpInfo] = [ domain=(-1, 1), supports_sparse=True, supports_forward_ad=True, - decorators=(precisionOverride({torch.bfloat16: 1e-2}),), safe_casts_outputs=True, dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16), dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), assert_autodiffed=True, + decorators=[ + DecorateInfo( + toleranceOverride({torch.float16: tol(atol=1e-05, rtol=1e-03)}), + 'TestUnaryUfuncs', device_type='cuda'), + precisionOverride({torch.bfloat16: 1e-2}), + ], skips=( SkipInfo('TestUnaryUfuncs', 'test_reference_numerics_extremal', device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]), @@ -6227,6 +6232,9 @@ op_db: List[OpInfo] = [ decorators=[skipCUDAIfNoMagma, skipCPUIfNoLapack, skipCUDAIfRocm], supports_inplace_autograd=False, skips=( + # Will be removed once https://github.com/pytorch/pytorch/issues/62328 is fixed + # Probable fix (open PR): https://github.com/pytorch/pytorch/pull/62570 + SkipInfo('TestGradients', 'test_fn_grad', device_type='cuda', dtypes=(torch.complex128,)), SkipInfo('TestCommon', 'test_dtypes'), SkipInfo('TestGradients', 'test_fn_gradgrad'), # This test fails because singular inputs cannot be reliably @@ -6384,7 +6392,7 @@ op_db: List[OpInfo] = [ )), UnaryUfuncInfo('log', ref=np.log, - domain=(0, float('inf')), + domain=(0, None), dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16), dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), assert_autodiffed=True, @@ -6398,7 +6406,7 @@ op_db: List[OpInfo] = [ )), UnaryUfuncInfo('log10', ref=np.log10, - domain=(0, float('inf')), + domain=(0, None), decorators=(precisionOverride({torch.bfloat16: 5e-2}),), dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16), assert_autodiffed=True, @@ -6413,7 +6421,7 @@ op_db: List[OpInfo] = [ UnaryUfuncInfo('log1p', ref=np.log1p, aliases=('special.log1p',), - domain=(-1, float('inf')), + domain=(-1, None), dtypes=all_types_and(torch.bool, torch.bfloat16), dtypesIfCUDA=all_types_and(torch.bool, torch.half, torch.bfloat16), decorators=(precisionOverride({torch.bfloat16: 1e-1}),), @@ -6422,7 +6430,7 @@ op_db: List[OpInfo] = [ assert_autodiffed=True), UnaryUfuncInfo('log2', ref=np.log2, - domain=(0, float('inf')), + domain=(0, None), dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16), dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), assert_autodiffed=True, @@ -6876,17 +6884,17 @@ op_db: List[OpInfo] = [ supports_forward_ad=True, sample_inputs_func=sample_inputs_mode,), MvlGammaInfo(variant_test_name='mvlgamma_p_1', - domain=(1, float('inf')), + domain=(1, None), skips=skips_mvlgamma(), sample_kwargs=lambda device, dtype, input: ({'p': 1}, {'d': 1})), MvlGammaInfo(variant_test_name='mvlgamma_p_3', - domain=(2, float('inf')), + domain=(2, None), skips=skips_mvlgamma(skip_redundant=True) + ( SkipInfo('TestUnaryUfuncs', 'test_reference_numerics_hard', dtypes=(torch.float16,)), ), sample_kwargs=lambda device, dtype, input: ({'p': 3}, {'d': 3})), MvlGammaInfo(variant_test_name='mvlgamma_p_5', - domain=(3, float('inf')), + domain=(3, None), skips=skips_mvlgamma(skip_redundant=True) + ( SkipInfo('TestUnaryUfuncs', 'test_reference_numerics_hard', dtypes=(torch.float16,)), ), @@ -7389,7 +7397,7 @@ op_db: List[OpInfo] = [ )), UnaryUfuncInfo('rsqrt', ref=lambda x: np.reciprocal(np.sqrt(x)), - domain=(0, float('inf')), + domain=(0, None), dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16), dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), decorators=(precisionOverride({torch.half: 5e-2}),), @@ -7400,7 +7408,7 @@ op_db: List[OpInfo] = [ UnaryUfuncInfo('sqrt', ref=np.sqrt, supports_sparse=True, - domain=(0, float('inf')), + domain=(0, None), dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16), dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), assert_autodiffed=True, diff --git a/torch/testing/_internal/common_utils.py b/torch/testing/_internal/common_utils.py index fb2482b..fed9a00 100644 --- a/torch/testing/_internal/common_utils.py +++ b/torch/testing/_internal/common_utils.py @@ -43,7 +43,7 @@ from unittest.mock import MagicMock import numpy as np -from torch.testing import floating_types_and, integral_types, complex_types +from torch.testing import floating_types_and, integral_types, complex_types, get_all_dtypes import expecttest from .._core import \ (_compare_tensors_internal, _compare_scalars_internal, _compare_return_type) @@ -1944,51 +1944,72 @@ def make_tensor(size, device: torch.device, dtype: torch.dtype, *, low=None, hig exclude_zero: bool = False) -> torch.Tensor: """ Creates a random tensor with the given size, device and dtype. - By default, the tensor's values are in the range [-9, 9] for most dtypes. If low - and/or high are specified then the values will be in the range [max(-9, low), min(9, high)]. - - For unsigned types the values are in the range[0, 9] and for complex types the real and imaginary - parts are each in the range [-9, 9]. + Default values for low and high: + * boolean type: low = 0, high = 2 + * uint8 type: low = 0, high = 9 + * floating and integral types: low = -9 and high = 9 + * complex types, for each real and imaginary part: low = -9, high = 9 + If low/high are specified and within dtype limits: low = low, high = high + If low/high are specified but exceed the limits: low = dtype_min, high = dtype_max + If low is -inf and/or high is inf: low = dtype_min, high = dtype_max + If low is inf or nan and/or high is -inf or nan: ValueError raised If noncontiguous=True, a noncontiguous tensor with the given size will be returned unless the size specifies a tensor with a 1 or 0 elements in which case the noncontiguous parameter is ignored because it is not possible to create a noncontiguous Tensor with a single element. If exclude_zero is passed with True (default is False), all the matching values (with zero) in - created tensor are replaced with an epsilon value if floating type, [`eps + `eps`.j] if - complex type and 1 if integer/boolean type. + created tensor are replaced with a tiny (smallest positive representable number) value if floating type, + [`tiny` + `tiny`.j] if complex type and 1 if integer/boolean type. """ + def _modify_low_high(low, high, lowest, highest, default_low, default_high, dtype): + """ + Modifies (and raises ValueError when appropriate) low and high values given by the user (input_low, input_high) if required. + """ + def clamp(a, l, h): + return min(max(a, l), h) + + low = low if low is not None else default_low + high = high if high is not None else default_high + + # Checks for error cases + if low != low or high != high: + raise ValueError("make_tensor: one of low or high was NaN!") + if low > high: + raise ValueError("make_tensor: low must be weakly less than high!") - assert low is None or low < 9, "low value too high!" - assert high is None or high > -9, "high value too low!" + low = clamp(low, lowest, highest) + high = clamp(high, lowest, highest) + + if dtype in integral_types(): + return math.floor(low), math.ceil(high) + + return low, high if dtype is torch.bool: result = torch.randint(0, 2, size, device=device, dtype=dtype) elif dtype is torch.uint8: - low = math.floor(0 if low is None else max(low, 0)) - high = math.ceil(10 if high is None else min(high, 10)) + ranges = (torch.iinfo(dtype).min, torch.iinfo(dtype).max) + low, high = _modify_low_high(low, high, ranges[0], ranges[1], 0, 9, dtype) result = torch.randint(low, high, size, device=device, dtype=dtype) elif dtype in integral_types(): - low = math.floor(-9 if low is None else max(low, -9)) - high = math.ceil(10 if high is None else min(high, 10)) + ranges = (torch.iinfo(dtype).min, torch.iinfo(dtype).max) + low, high = _modify_low_high(low, high, ranges[0], ranges[1], -9, 9, dtype) result = torch.randint(low, high, size, device=device, dtype=dtype) elif dtype in floating_types_and(torch.half, torch.bfloat16): - low = -9 if low is None else max(low, -9) - high = 9 if high is None else min(high, 10) - span = high - low - # Windows doesn't support torch.rand(bfloat16) on CUDA - if IS_WINDOWS and torch.device(device).type == 'cuda' and dtype is torch.bfloat16: - result = (torch.rand(size, device=device, dtype=torch.float32) * span + low).to(torch.bfloat16) - else: - result = torch.rand(size, device=device, dtype=dtype) * span + low + ranges_floats = (torch.finfo(dtype).min, torch.finfo(dtype).max) + low, high = _modify_low_high(low, high, ranges_floats[0], ranges_floats[1], -9, 9, dtype) + rand_val = torch.rand(size, device=device, dtype=dtype) + result = high * rand_val + low * (1 - rand_val) else: assert dtype in complex_types() - low = -9 if low is None else max(low, -9) - high = 9 if high is None else min(high, 10) - span = high - low float_dtype = torch.float if dtype is torch.cfloat else torch.double - real = torch.rand(size, device=device, dtype=float_dtype) * span + low - imag = torch.rand(size, device=device, dtype=float_dtype) * span + low + ranges_floats = (torch.finfo(float_dtype).min, torch.finfo(float_dtype).max) + low, high = _modify_low_high(low, high, ranges_floats[0], ranges_floats[1], -9, 9, dtype) + real_rand_val = torch.rand(size, device=device, dtype=float_dtype) + imag_rand_val = torch.rand(size, device=device, dtype=float_dtype) + real = high * real_rand_val + low * (1 - real_rand_val) + imag = high * imag_rand_val + low * (1 - imag_rand_val) result = torch.complex(real, imag) if noncontiguous and result.numel() > 1: @@ -1999,12 +2020,13 @@ def make_tensor(size, device: torch.device, dtype: torch.dtype, *, low=None, hig if dtype in integral_types() or dtype is torch.bool: replace_with = torch.tensor(1, device=device, dtype=dtype) elif dtype in floating_types_and(torch.half, torch.bfloat16): - replace_with = torch.tensor(torch.finfo(dtype).eps, device=device, dtype=dtype) - else: - assert dtype in complex_types() + replace_with = torch.tensor(torch.finfo(dtype).tiny, device=device, dtype=dtype) + elif dtype in complex_types(): float_dtype = torch.float if dtype is torch.cfloat else torch.double - float_eps = torch.tensor(torch.finfo(float_dtype).eps, device=device, dtype=float_dtype) + float_eps = torch.tensor(torch.finfo(float_dtype).tiny, device=device, dtype=float_dtype) replace_with = torch.complex(float_eps, float_eps) + else: + raise ValueError(f"Invalid dtype passed, supported dtypes are: {get_all_dtypes()}") result[result == 0] = replace_with if dtype in floating_types_and(torch.half, torch.bfloat16) or\ -- 2.7.4