From eafe33c995d47d45fceaf42801717f3120d799b9 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Mon, 30 Aug 2021 12:28:39 -0700 Subject: [PATCH] remove componentwise comparison of complex values in torch.testing.assert_close (#63841) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/63841 Closes #61906. cc ezyang gchanan Test Plan: Imported from OSS Reviewed By: ezyang Differential Revision: D30633526 Pulled By: mruberry fbshipit-source-id: ddb5d61838cd1e12d19d0093799e827344382cdc --- test/test_testing.py | 65 ++++++++++++++++++----------------- torch/testing/_asserts.py | 86 ++++++++++------------------------------------- 2 files changed, 50 insertions(+), 101 deletions(-) diff --git a/test/test_testing.py b/test/test_testing.py index fdc3463..a5ea232 100644 --- a/test/test_testing.py +++ b/test/test_testing.py @@ -863,20 +863,43 @@ class TestAssertClose(TestCase): for fn in assert_close_with_inputs(actual, expected): fn(rtol=0.0, atol=eps * 2) - def test_matching_nan(self): - actual = torch.tensor(float("NaN")) - expected = actual.clone() + # TODO: the code that this test was designed for was removed in https://github.com/pytorch/pytorch/pull/56058 + # We need to check if this test is still needed or if this behavior is now enabled by default. + def test_matching_conjugate_bit(self): + actual = torch.tensor(complex(1, 1)).conj() + expected = torch.tensor(complex(1, -1)) for fn in assert_close_with_inputs(actual, expected): - with self.assertRaises(AssertionError): - fn() + fn() + + def test_matching_nan(self): + nan = float("NaN") + + tests = ( + (nan, nan), + (complex(nan, 0), complex(0, nan)), + (complex(nan, nan), complex(nan, 0)), + (complex(nan, nan), complex(nan, nan)), + ) + + for actual, expected in tests: + for fn in assert_close_with_inputs(actual, expected): + with self.assertRaises(AssertionError): + fn() def test_matching_nan_with_equal_nan(self): - actual = torch.tensor(float("NaN")) - expected = actual.clone() + nan = float("NaN") - for fn in assert_close_with_inputs(actual, expected): - fn(equal_nan=True) + tests = ( + (nan, nan), + (complex(nan, 0), complex(0, nan)), + (complex(nan, nan), complex(nan, 0)), + (complex(nan, nan), complex(nan, nan)), + ) + + for actual, expected in tests: + for fn in assert_close_with_inputs(actual, expected): + fn(equal_nan=True) def test_numpy(self): tensor = torch.rand(2, 2, dtype=torch.float32) @@ -1181,30 +1204,6 @@ class TestAssertCloseContainer(TestCase): torch.testing.assert_close(actual, expected) -class TestAssertCloseComplex(TestCase): - def test_mismatching_nan_with_equal_nan(self): - actual = torch.tensor(complex(1, float("NaN"))) - expected = torch.tensor(complex(float("NaN"), 1)) - - for fn in assert_close_with_inputs(actual, expected): - with self.assertRaises(AssertionError): - fn(equal_nan=True) - - def test_mismatching_nan_with_equal_nan_relaxed(self): - actual = torch.tensor(complex(1, float("NaN"))) - expected = torch.tensor(complex(float("NaN"), 1)) - - for fn in assert_close_with_inputs(actual, expected): - fn(equal_nan="relaxed") - - def test_matching_conjugate_bit(self): - actual = torch.tensor(complex(1, 1)).conj() - expected = torch.tensor(complex(1, -1)) - - for fn in assert_close_with_inputs(actual, expected): - fn() - - class TestAssertCloseSparseCOO(TestCase): def test_matching_coalesced(self): indices = ( diff --git a/torch/testing/_asserts.py b/torch/testing/_asserts.py index 2de2cc0..073e2e2 100644 --- a/torch/testing/_asserts.py +++ b/torch/testing/_asserts.py @@ -44,52 +44,6 @@ def _get_default_rtol_and_atol(actual: Tensor, expected: Tensor) -> Tuple[float, return max(actual_rtol, expected_rtol), max(actual_atol, expected_atol) -def _check_complex_components_individually( - check_tensors: Callable[..., Optional[_TestingErrorMeta]] -) -> Callable[..., Optional[_TestingErrorMeta]]: - """Decorates real-valued tensor check functions to handle complex components individually. - - If the inputs are not complex, this decorator is a no-op. - - Args: - check_tensors (Callable[[Tensor, Tensor], Optional[_TestingErrorMeta]]): Tensor check function for real-valued - tensors. - """ - - @functools.wraps(check_tensors) - def wrapper( - actual: Tensor, expected: Tensor, *, equal_nan: Union[str, bool], **kwargs: Any - ) -> Optional[_TestingErrorMeta]: - if equal_nan == "relaxed": - relaxed_complex_nan = True - equal_nan = True - else: - relaxed_complex_nan = False - - if actual.dtype not in (torch.complex32, torch.complex64, torch.complex128): - return check_tensors(actual, expected, equal_nan=equal_nan, **kwargs) - - if relaxed_complex_nan: - actual, expected = [ - t.clone().masked_fill( - t.real.isnan() | t.imag.isnan(), complex(float("NaN"), float("NaN")) # type: ignore[call-overload] - ) - for t in (actual, expected) - ] - - error_meta = check_tensors(actual.real, expected.real, equal_nan=equal_nan, **kwargs) - if error_meta: - return error_meta - - error_meta = check_tensors(actual.imag, expected.imag, equal_nan=equal_nan, **kwargs) - if error_meta: - return error_meta - - return None - - return wrapper - - def _check_sparse_coo_members_individually( check_tensors: Callable[..., Optional[_TestingErrorMeta]] ) -> Callable[..., Optional[_TestingErrorMeta]]: @@ -430,10 +384,24 @@ def _make_mismatch_msg( return msg.strip() +def _get_comparison_dtype(dtype: torch.dtype) -> torch.dtype: + """Selects the comparison dtype based on the input dtype. + + Returns: + Highest precision dtype of the same dtype category as the input. :class:`torch.bool` is treated as integral + dtype. + """ + if dtype.is_complex: + return torch.complex128 + elif dtype.is_floating_point: + return torch.float64 + else: + return torch.int64 + + @_check_quantized @_check_sparse_coo_members_individually @_check_sparse_csr_members_individually -@_check_complex_components_individually def _check_values_close( actual: Tensor, expected: Tensor, @@ -457,7 +425,7 @@ def _check_values_close( Returns: (Optional[AssertionError]): If check did not pass. """ - dtype = torch.float64 if actual.dtype.is_floating_point else torch.int64 + dtype = _get_comparison_dtype(actual.dtype) actual = actual.to(dtype) expected = expected.to(dtype) mismatches = ~torch.isclose(actual, expected, rtol=rtol, atol=atol, equal_nan=equal_nan) @@ -740,7 +708,7 @@ def assert_close( allow_subclasses: bool = True, rtol: Optional[float] = None, atol: Optional[float] = None, - equal_nan: Union[bool, str] = False, + equal_nan: bool = False, check_device: bool = True, check_dtype: bool = True, check_stride: bool = False, @@ -761,9 +729,6 @@ def assert_close( (``-inf`` and ``inf``) are only considered close if and only if they are equal. ``NaN``'s are only considered equal to each other if :attr:`equal_nan` is ``True``. - If :attr:`actual` and :attr:`expected` are complex-valued, they are considered close if both their real and - imaginary components are considered close according to the definition above. - If :attr:`actual` and :attr:`expected` are sparse (either having COO or CSR layout), their strided members are checked individually. Indices, namely ``indices`` for COO or ``crow_indices`` and ``col_indices`` for CSR layout, are always checked for equality whereas the values are checked for closeness according to the definition above. @@ -795,8 +760,7 @@ def assert_close( default values based on the :attr:`~torch.Tensor.dtype` are selected with the below table. atol (Optional[float]): Absolute tolerance. If specified :attr:`rtol` must also be specified. If omitted, default values based on the :attr:`~torch.Tensor.dtype` are selected with the below table. - equal_nan (Union[bool, str]): If ``True``, two ``NaN`` values will be considered equal. If ``"relaxed"``, - complex values are considered as ``NaN`` if either the real **or** imaginary component is ``NaN``. + equal_nan (Union[bool, str]): If ``True``, two ``NaN`` values will be considered equal. check_device (bool): If ``True`` (default), asserts that corresponding tensors are on the same :attr:`~torch.Tensor.device`. If this check is disabled, tensors on different :attr:`~torch.Tensor.device`'s are moved to the CPU before being compared. @@ -956,20 +920,6 @@ def assert_close( Relative difference: nan (up to 1.3e-06 allowed) >>> torch.testing.assert_close(actual, expected, equal_nan=True) - >>> # If equal_nan=True, the real and imaginary NaN's of complex inputs have to match. - >>> expected = torch.tensor(complex(float("NaN"), 0)) - >>> actual = torch.tensor(complex(0, float("NaN"))) - >>> torch.testing.assert_close(actual, expected, equal_nan=True) - Traceback (most recent call last): - ... - AssertionError: Scalars are not close! - - Absolute difference: nan (up to 1e-05 allowed) - Relative difference: nan (up to 1.3e-06 allowed) - >>> # If equal_nan="relaxed", however, then complex numbers are treated as NaN if any - >>> # of the real or imaginary components is NaN. - >>> torch.testing.assert_close(actual, expected, equal_nan="relaxed") - >>> expected = torch.tensor([1.0, 2.0, 3.0]) >>> actual = torch.tensor([1.0, 4.0, 5.0]) >>> # The default mismatch message can be overwritten. -- 2.7.4