From 401bbb2aa0a183ddfb309740c020fb4962367ac9 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Mon, 30 Aug 2021 12:28:39 -0700 Subject: [PATCH] remove componentwise comparison of complex values in TestCase.assertEqual (#63572) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/63572 Addresses #61906. Issue will be fixed later in the stack when `torch.testing.assert_close` got the same treatment. cc ezyang gchanan Test Plan: Imported from OSS Reviewed By: ezyang Differential Revision: D30633527 Pulled By: mruberry fbshipit-source-id: c2002a4998a7a75cb2ab83f87190bde43a9d4f7c --- test/test_tensor_creation_ops.py | 2 +- test/test_testing.py | 54 +++++------------------------ test/test_torch.py | 4 +-- test/test_unary_ufuncs.py | 5 +-- torch/testing/_core.py | 75 +++++----------------------------------- 5 files changed, 21 insertions(+), 119 deletions(-) diff --git a/test/test_tensor_creation_ops.py b/test/test_tensor_creation_ops.py index 9ef3742..dcb4938 100644 --- a/test/test_tensor_creation_ops.py +++ b/test/test_tensor_creation_ops.py @@ -3258,7 +3258,7 @@ class TestRandomTensorCreation(TestCase): self.assertTrue((res1 >= 0).all().item()) @dtypes(torch.half, torch.float, torch.bfloat16, torch.double, - torch.complex32, torch.complex64, torch.complex128) + torch.complex64, torch.complex128) def test_randn(self, device, dtype): SIZE = 100 for size in [0, SIZE]: diff --git a/test/test_testing.py b/test/test_testing.py index f38183d..fdc3463 100644 --- a/test/test_testing.py +++ b/test/test_testing.py @@ -88,25 +88,19 @@ class TestTesting(TestCase): "atol=1e-05 is only 1.9100000000000003e-05!") self.assertEqual(debug_msg, expected_msg) - # complex x complex, real difference + # complex x complex result, debug_msg = self._compareScalars(complex(1, 3), complex(3, 1)) - expected_msg = ("Comparing the real part 1.0 and 3.0 gives a difference " - "of 2.0, but the allowed difference with rtol=1.3e-06 " - "and atol=1e-05 is only 1.39e-05!") - self.assertEqual(debug_msg, expected_msg) - - # complex x complex, imaginary difference - result, debug_msg = self._compareScalars(complex(1, 3), complex(1, 5.5)) - expected_msg = ("Comparing the imaginary part 3.0 and 5.5 gives a " - "difference of 2.5, but the allowed difference with " - "rtol=1.3e-06 and atol=1e-05 is only 1.715e-05!") + expected_msg = ("Comparing (1+3j) and (3+1j) gives a difference " + "of 2.8284271247461903, but the allowed difference " + "with rtol=1.3e-06 and atol=1e-05 is only " + "1.4110960958218895e-05!") self.assertEqual(debug_msg, expected_msg) # complex x int result, debug_msg = self._compareScalars(complex(1, -2), 1) - expected_msg = ("Comparing the imaginary part -2.0 and 0.0 gives a " - "difference of 2.0, but the allowed difference with " - "rtol=1.3e-06 and atol=1e-05 is only 1e-05!") + expected_msg = ("Comparing (1-2j) and 1 gives a difference of 2.0, " + "but the allowed difference with rtol=1.3e-06 and " + "atol=1e-05 is only 1.13e-05!") self.assertEqual(debug_msg, expected_msg) # NaN x NaN, equal_nan=False @@ -170,28 +164,6 @@ class TestTesting(TestCase): "occuring at index 0.") self.assertEqual(debug_msg, expected_msg) - # Checks complex tensor comparisons (real part) - a = torch.tensor((1 - 1j, 4 + 3j), device=device) - b = torch.tensor((1 - 1j, 1 + 3j), device=device) - result, debug_msg = self._compareTensors(a, b) - expected_msg = ("Real parts failed to compare as equal! " - "With rtol=1.3e-06 and atol={0}, " - "found 1 element(s) (out of 2) whose difference(s) exceeded the " - "margin of error (including 0 nan comparisons). The greatest difference was " - "3.0 (4.0 vs. 1.0), which occurred at index 1.").format(atol) - self.assertEqual(debug_msg, expected_msg) - - # Checks complex tensor comparisons (imaginary part) - a = torch.tensor((1 - 1j, 4 + 3j), device=device) - b = torch.tensor((1 - 1j, 4 - 21j), device=device) - result, debug_msg = self._compareTensors(a, b) - expected_msg = ("Imaginary parts failed to compare as equal! " - "With rtol=1.3e-06 and atol={0}, " - "found 1 element(s) (out of 2) whose difference(s) exceeded the " - "margin of error (including 0 nan comparisons). The greatest difference was " - "24.0 (3.0 vs. -21.0), which occurred at index 1.").format(atol) - self.assertEqual(debug_msg, expected_msg) - # Checks size mismatch a = torch.tensor((1, 2), device=device) b = torch.tensor((3), device=device) @@ -407,7 +379,7 @@ class TestTesting(TestCase): tests = ( (complex(1, -1), complex(-1, 1), False), (complex(1, -1), complex(2, -2), True), - (complex(1, 99), complex(4, 100), False), + (complex(1, 99), complex(4, 100), True), ) self._comparetensors_helper(tests, device, dtype, False, atol=.5, rtol=.5) @@ -421,14 +393,6 @@ class TestTesting(TestCase): (complex(float('nan'), float('nan')), complex(float('nan'), float('nan')), True), ) self._isclose_helper(tests, device, dtype, True) - - tests = ( - (complex(1, 1), complex(1, float('nan')), False), - (complex(1, 1), complex(float('nan'), 1), False), - (complex(float('nan'), 1), complex(float('nan'), 1), True), - (complex(float('nan'), 1), complex(1, float('nan')), False), - (complex(float('nan'), float('nan')), complex(float('nan'), float('nan')), True), - ) self._comparetensors_helper(tests, device, dtype, True) # Tests that isclose with rtol or atol values less than zero throws a diff --git a/test/test_torch.py b/test/test_torch.py index c50b7ca..b267b9c 100644 --- a/test/test_torch.py +++ b/test/test_torch.py @@ -5121,7 +5121,7 @@ else: spacing = [space.cpu().detach().numpy() for space in spacing] expected = np.gradient(t_np, *self._wrap_to_list(spacing), axis=dims, edge_order=edge_order) actual, expected = self._inf_nan_preprocess(list(actual), self._wrap_to_list(expected)) - self.assertEqual(actual, expected, equal_nan="relaxed", atol=1e-4, rtol=0, exact_dtype=False) + self.assertEqual(actual, expected, equal_nan=True, atol=1e-4, rtol=0, exact_dtype=False) @onlyOnCPUAndCUDA @dtypes(torch.long, torch.float32, torch.complex64) @@ -5188,7 +5188,7 @@ else: self.assertEqual(expected[i].imag, torch.zeros(actual[i].shape), exact_dtype=False) else: actual, expected = self._inf_nan_preprocess(list(actual), expected) - self.assertEqual(actual, expected, equal_nan="relaxed", exact_dtype=False) + self.assertEqual(actual, expected, equal_nan=True, exact_dtype=False) @onlyOnCPUAndCUDA @dtypes(torch.long, torch.float32, torch.complex64) diff --git a/test/test_unary_ufuncs.py b/test/test_unary_ufuncs.py index 22f6151..526b67a 100644 --- a/test/test_unary_ufuncs.py +++ b/test/test_unary_ufuncs.py @@ -359,10 +359,7 @@ class TestUnaryUfuncs(TestCase): tensors = generate_numeric_tensors_extremal(device, dtype, domain=op.domain) - # https://github.com/pytorch/pytorch/issues/50749 - equal_nan = "relaxed" if device.startswith('cuda') else True - - self._test_reference_numerics(dtype, op, tensors, equal_nan) + self._test_reference_numerics(dtype, op, tensors) # Tests for testing (non)contiguity consistency diff --git a/torch/testing/_core.py b/torch/testing/_core.py index d980615..66060f8 100644 --- a/torch/testing/_core.py +++ b/torch/testing/_core.py @@ -6,7 +6,7 @@ import torch import random import math import cmath -from typing import cast, List, Optional, Tuple, Union +from typing import List, Optional, Tuple, Union import operator FileCheck = torch._C.FileCheck @@ -78,27 +78,12 @@ _compare_return_type = Tuple[bool, Optional[str]] # Two tensors are "equal" if they are "close", in the sense of torch.allclose. # The only exceptions are complex tensors and bool tensors. # -# Complex tensors are "equal" if both the -# real and complex parts (separately) are close. This is divergent from -# torch.allclose's behavior, which compares the absolute values of the -# complex numbers instead. -# -# Using torch.allclose would be a less strict -# comparison that would allow large complex values with -# significant real or imaginary differences to be considered "equal," -# and would make setting rtol and atol for complex tensors distinct from -# other tensor types. -# # Bool tensors are equal only if they are identical, regardless of # the rtol and atol values. # # The `equal_nan` can be True or False, which maps to the True or False -# in `torch.allclose`. `equal_nan` can also be "relaxed", which means -# the complex will be compared in the relaxed mode: -# 2 + nan j == 3 + nan j ---> False when equal_nan=True -# True when equal_nan="relaxed" -def _compare_tensors_internal(a: torch.Tensor, b: torch.Tensor, *, rtol, atol, equal_nan: Union[str, bool]) -> _compare_return_type: - assert equal_nan in {True, False, "relaxed"} +# in `torch.allclose`. +def _compare_tensors_internal(a: torch.Tensor, b: torch.Tensor, *, rtol, atol, equal_nan) -> _compare_return_type: debug_msg : Optional[str] # Integer (including bool) comparisons are identity comparisons # when rtol is zero and atol is less than one @@ -129,48 +114,19 @@ def _compare_tensors_internal(a: torch.Tensor, b: torch.Tensor, *, rtol, atol, e _unravel_index(greatest_diff_index, a.shape))) return (False, debug_msg) - # Compares complex tensors' real and imaginary parts separately. - # (see NOTE Test Framework Tensor "Equality") - if a.is_complex(): - if equal_nan == "relaxed": - a = a.clone() - b = b.clone() - a.real[a.imag.isnan()] = math.nan - a.imag[a.real.isnan()] = math.nan - b.real[b.imag.isnan()] = math.nan - b.imag[b.real.isnan()] = math.nan - - real_result, debug_msg = _compare_tensors_internal(a.real, b.real, - rtol=rtol, atol=atol, - equal_nan=equal_nan) - - if not real_result: - debug_msg = "Real parts failed to compare as equal! " + cast(str, debug_msg) - return (real_result, debug_msg) - - imag_result, debug_msg = _compare_tensors_internal(a.imag, b.imag, - rtol=rtol, atol=atol, - equal_nan=equal_nan) - - if not imag_result: - debug_msg = "Imaginary parts failed to compare as equal! " + cast(str, debug_msg) - return (imag_result, debug_msg) - - return (True, None) - # All other comparisons use torch.allclose directly - if torch.allclose(a, b, rtol=rtol, atol=atol, equal_nan=(equal_nan in {"relaxed", True})): + if torch.allclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan): return (True, None) # Gathers debug info for failed float tensor comparison # NOTE: converts to float64 to best represent differences - a_flat = a.to(torch.float64).flatten() - b_flat = b.to(torch.float64).flatten() + a_flat = a.to(torch.float64 if not a.dtype.is_complex else torch.complex128).flatten() + b_flat = b.to(torch.float64 if not a.dtype.is_complex else torch.complex128).flatten() diff = torch.abs(a_flat - b_flat) # Masks close values # NOTE: this avoids (inf - inf) oddities when computing the difference - close = torch.isclose(a_flat, b_flat, rtol, atol, (equal_nan in {"relaxed", True})) + close = torch.isclose(a_flat, b_flat, rtol, atol, equal_nan) diff[close] = 0 nans = torch.isnan(diff) num_nans = nans.sum() @@ -212,7 +168,7 @@ def _compare_scalars_internal(a, b, *, rtol: float, atol: float, equal_nan: Unio # Special-case for infinity comparisons # NOTE: if b is inf then allowed_diff will be inf when rtol is not 0 - if ((math.isinf(a) or math.isinf(b)) and a != b): + if ((cmath.isinf(a) or cmath.isinf(b)) and a != b): result = False msg = None @@ -228,21 +184,6 @@ def _compare_scalars_internal(a, b, *, rtol: float, atol: float, equal_nan: Unio ) return result, msg - if isinstance(a, complex) or isinstance(b, complex): - a = complex(a) - b = complex(b) - - if equal_nan == "relaxed": - if cmath.isnan(a) and cmath.isnan(b): - return (True, None) - - result, msg = _helper(a.real, b.real, " the real part ") - - if not result: - return (False, msg) - - return _helper(a.imag, b.imag, " the imaginary part ") - return _helper(a, b, " ") -- 2.7.4