remove componentwise comparison of complex values in TestCase.assertEqual (#63572)
authorPhilip Meier <github.pmeier@posteo.de>
Mon, 30 Aug 2021 19:28:39 +0000 (12:28 -0700)
committerFacebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Mon, 30 Aug 2021 19:36:45 +0000 (12:36 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/63572

Addresses #61906. Issue will be fixed later in the stack when `torch.testing.assert_close` got the same treatment.

cc ezyang gchanan

Test Plan: Imported from OSS

Reviewed By: ezyang

Differential Revision: D30633527

Pulled By: mruberry

fbshipit-source-id: c2002a4998a7a75cb2ab83f87190bde43a9d4f7c

test/test_tensor_creation_ops.py
test/test_testing.py
test/test_torch.py
test/test_unary_ufuncs.py
torch/testing/_core.py

index 9ef3742..dcb4938 100644 (file)
@@ -3258,7 +3258,7 @@ class TestRandomTensorCreation(TestCase):
             self.assertTrue((res1 >= 0).all().item())
 
     @dtypes(torch.half, torch.float, torch.bfloat16, torch.double,
-            torch.complex32, torch.complex64, torch.complex128)
+            torch.complex64, torch.complex128)
     def test_randn(self, device, dtype):
         SIZE = 100
         for size in [0, SIZE]:
index f38183d..fdc3463 100644 (file)
@@ -88,25 +88,19 @@ class TestTesting(TestCase):
                         "atol=1e-05 is only 1.9100000000000003e-05!")
         self.assertEqual(debug_msg, expected_msg)
 
-        # complex x complex, real difference
+        # complex x complex
         result, debug_msg = self._compareScalars(complex(1, 3), complex(3, 1))
-        expected_msg = ("Comparing the real part 1.0 and 3.0 gives a difference "
-                        "of 2.0, but the allowed difference with rtol=1.3e-06 "
-                        "and atol=1e-05 is only 1.39e-05!")
-        self.assertEqual(debug_msg, expected_msg)
-
-        # complex x complex, imaginary difference
-        result, debug_msg = self._compareScalars(complex(1, 3), complex(1, 5.5))
-        expected_msg = ("Comparing the imaginary part 3.0 and 5.5 gives a "
-                        "difference of 2.5, but the allowed difference with "
-                        "rtol=1.3e-06 and atol=1e-05 is only 1.715e-05!")
+        expected_msg = ("Comparing (1+3j) and (3+1j) gives a difference "
+                        "of 2.8284271247461903, but the allowed difference "
+                        "with rtol=1.3e-06 and atol=1e-05 is only "
+                        "1.4110960958218895e-05!")
         self.assertEqual(debug_msg, expected_msg)
 
         # complex x int
         result, debug_msg = self._compareScalars(complex(1, -2), 1)
-        expected_msg = ("Comparing the imaginary part -2.0 and 0.0 gives a "
-                        "difference of 2.0, but the allowed difference with "
-                        "rtol=1.3e-06 and atol=1e-05 is only 1e-05!")
+        expected_msg = ("Comparing (1-2j) and 1 gives a difference of 2.0, "
+                        "but the allowed difference with rtol=1.3e-06 and "
+                        "atol=1e-05 is only 1.13e-05!")
         self.assertEqual(debug_msg, expected_msg)
 
         # NaN x NaN, equal_nan=False
@@ -170,28 +164,6 @@ class TestTesting(TestCase):
                         "occuring at index 0.")
         self.assertEqual(debug_msg, expected_msg)
 
-        # Checks complex tensor comparisons (real part)
-        a = torch.tensor((1 - 1j, 4 + 3j), device=device)
-        b = torch.tensor((1 - 1j, 1 + 3j), device=device)
-        result, debug_msg = self._compareTensors(a, b)
-        expected_msg = ("Real parts failed to compare as equal! "
-                        "With rtol=1.3e-06 and atol={0}, "
-                        "found 1 element(s) (out of 2) whose difference(s) exceeded the "
-                        "margin of error (including 0 nan comparisons). The greatest difference was "
-                        "3.0 (4.0 vs. 1.0), which occurred at index 1.").format(atol)
-        self.assertEqual(debug_msg, expected_msg)
-
-        # Checks complex tensor comparisons (imaginary part)
-        a = torch.tensor((1 - 1j, 4 + 3j), device=device)
-        b = torch.tensor((1 - 1j, 4 - 21j), device=device)
-        result, debug_msg = self._compareTensors(a, b)
-        expected_msg = ("Imaginary parts failed to compare as equal! "
-                        "With rtol=1.3e-06 and atol={0}, "
-                        "found 1 element(s) (out of 2) whose difference(s) exceeded the "
-                        "margin of error (including 0 nan comparisons). The greatest difference was "
-                        "24.0 (3.0 vs. -21.0), which occurred at index 1.").format(atol)
-        self.assertEqual(debug_msg, expected_msg)
-
         # Checks size mismatch
         a = torch.tensor((1, 2), device=device)
         b = torch.tensor((3), device=device)
@@ -407,7 +379,7 @@ class TestTesting(TestCase):
         tests = (
             (complex(1, -1), complex(-1, 1), False),
             (complex(1, -1), complex(2, -2), True),
-            (complex(1, 99), complex(4, 100), False),
+            (complex(1, 99), complex(4, 100), True),
         )
 
         self._comparetensors_helper(tests, device, dtype, False, atol=.5, rtol=.5)
@@ -421,14 +393,6 @@ class TestTesting(TestCase):
             (complex(float('nan'), float('nan')), complex(float('nan'), float('nan')), True),
         )
         self._isclose_helper(tests, device, dtype, True)
-
-        tests = (
-            (complex(1, 1), complex(1, float('nan')), False),
-            (complex(1, 1), complex(float('nan'), 1), False),
-            (complex(float('nan'), 1), complex(float('nan'), 1), True),
-            (complex(float('nan'), 1), complex(1, float('nan')), False),
-            (complex(float('nan'), float('nan')), complex(float('nan'), float('nan')), True),
-        )
         self._comparetensors_helper(tests, device, dtype, True)
 
     # Tests that isclose with rtol or atol values less than zero throws a
index c50b7ca..b267b9c 100644 (file)
@@ -5121,7 +5121,7 @@ else:
                 spacing = [space.cpu().detach().numpy() for space in spacing]
             expected = np.gradient(t_np, *self._wrap_to_list(spacing), axis=dims, edge_order=edge_order)
             actual, expected = self._inf_nan_preprocess(list(actual), self._wrap_to_list(expected))
-            self.assertEqual(actual, expected, equal_nan="relaxed", atol=1e-4, rtol=0, exact_dtype=False)
+            self.assertEqual(actual, expected, equal_nan=True, atol=1e-4, rtol=0, exact_dtype=False)
 
     @onlyOnCPUAndCUDA
     @dtypes(torch.long, torch.float32, torch.complex64)
@@ -5188,7 +5188,7 @@ else:
                     self.assertEqual(expected[i].imag, torch.zeros(actual[i].shape), exact_dtype=False)
             else:
                 actual, expected = self._inf_nan_preprocess(list(actual), expected)
-                self.assertEqual(actual, expected, equal_nan="relaxed", exact_dtype=False)
+                self.assertEqual(actual, expected, equal_nan=True, exact_dtype=False)
 
     @onlyOnCPUAndCUDA
     @dtypes(torch.long, torch.float32, torch.complex64)
index 22f6151..526b67a 100644 (file)
@@ -359,10 +359,7 @@ class TestUnaryUfuncs(TestCase):
         tensors = generate_numeric_tensors_extremal(device, dtype,
                                                     domain=op.domain)
 
-        # https://github.com/pytorch/pytorch/issues/50749
-        equal_nan = "relaxed" if device.startswith('cuda') else True
-
-        self._test_reference_numerics(dtype, op, tensors, equal_nan)
+        self._test_reference_numerics(dtype, op, tensors)
 
     # Tests for testing (non)contiguity consistency
 
index d980615..66060f8 100644 (file)
@@ -6,7 +6,7 @@ import torch
 import random
 import math
 import cmath
-from typing import cast, List, Optional, Tuple, Union
+from typing import List, Optional, Tuple, Union
 import operator
 
 FileCheck = torch._C.FileCheck
@@ -78,27 +78,12 @@ _compare_return_type = Tuple[bool, Optional[str]]
 #   Two tensors are "equal" if they are "close", in the sense of torch.allclose.
 #   The only exceptions are complex tensors and bool tensors.
 #
-#   Complex tensors are "equal" if both the
-#   real and complex parts (separately) are close. This is divergent from
-#   torch.allclose's behavior, which compares the absolute values of the
-#   complex numbers instead.
-#
-#   Using torch.allclose would be a less strict
-#   comparison that would allow large complex values with
-#   significant real or imaginary differences to be considered "equal,"
-#   and would make setting rtol and atol for complex tensors distinct from
-#   other tensor types.
-#
 #   Bool tensors are equal only if they are identical, regardless of
 #   the rtol and atol values.
 #
 #   The `equal_nan` can be True or False, which maps to the True or False
-#   in `torch.allclose`. `equal_nan` can also be "relaxed", which means
-#   the complex will be compared in the relaxed mode:
-#       2 + nan j == 3 + nan j ---> False when equal_nan=True
-#                                   True when equal_nan="relaxed"
-def _compare_tensors_internal(a: torch.Tensor, b: torch.Tensor, *, rtol, atol, equal_nan: Union[str, bool]) -> _compare_return_type:
-    assert equal_nan in {True, False, "relaxed"}
+#   in `torch.allclose`.
+def _compare_tensors_internal(a: torch.Tensor, b: torch.Tensor, *, rtol, atol, equal_nan) -> _compare_return_type:
     debug_msg : Optional[str]
     # Integer (including bool) comparisons are identity comparisons
     # when rtol is zero and atol is less than one
@@ -129,48 +114,19 @@ def _compare_tensors_internal(a: torch.Tensor, b: torch.Tensor, *, rtol, atol, e
                                    _unravel_index(greatest_diff_index, a.shape)))
         return (False, debug_msg)
 
-    # Compares complex tensors' real and imaginary parts separately.
-    # (see NOTE Test Framework Tensor "Equality")
-    if a.is_complex():
-        if equal_nan == "relaxed":
-            a = a.clone()
-            b = b.clone()
-            a.real[a.imag.isnan()] = math.nan
-            a.imag[a.real.isnan()] = math.nan
-            b.real[b.imag.isnan()] = math.nan
-            b.imag[b.real.isnan()] = math.nan
-
-        real_result, debug_msg = _compare_tensors_internal(a.real, b.real,
-                                                           rtol=rtol, atol=atol,
-                                                           equal_nan=equal_nan)
-
-        if not real_result:
-            debug_msg = "Real parts failed to compare as equal! " + cast(str, debug_msg)
-            return (real_result, debug_msg)
-
-        imag_result, debug_msg = _compare_tensors_internal(a.imag, b.imag,
-                                                           rtol=rtol, atol=atol,
-                                                           equal_nan=equal_nan)
-
-        if not imag_result:
-            debug_msg = "Imaginary parts failed to compare as equal! " + cast(str, debug_msg)
-            return (imag_result, debug_msg)
-
-        return (True, None)
-
     # All other comparisons use torch.allclose directly
-    if torch.allclose(a, b, rtol=rtol, atol=atol, equal_nan=(equal_nan in {"relaxed", True})):
+    if torch.allclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan):
         return (True, None)
 
     # Gathers debug info for failed float tensor comparison
     # NOTE: converts to float64 to best represent differences
-    a_flat = a.to(torch.float64).flatten()
-    b_flat = b.to(torch.float64).flatten()
+    a_flat = a.to(torch.float64 if not a.dtype.is_complex else torch.complex128).flatten()
+    b_flat = b.to(torch.float64 if not a.dtype.is_complex else torch.complex128).flatten()
     diff = torch.abs(a_flat - b_flat)
 
     # Masks close values
     # NOTE: this avoids (inf - inf) oddities when computing the difference
-    close = torch.isclose(a_flat, b_flat, rtol, atol, (equal_nan in {"relaxed", True}))
+    close = torch.isclose(a_flat, b_flat, rtol, atol, equal_nan)
     diff[close] = 0
     nans = torch.isnan(diff)
     num_nans = nans.sum()
@@ -212,7 +168,7 @@ def _compare_scalars_internal(a, b, *, rtol: float, atol: float, equal_nan: Unio
 
         # Special-case for infinity comparisons
         # NOTE: if b is inf then allowed_diff will be inf when rtol is not 0
-        if ((math.isinf(a) or math.isinf(b)) and a != b):
+        if ((cmath.isinf(a) or cmath.isinf(b)) and a != b):
             result = False
 
         msg = None
@@ -228,21 +184,6 @@ def _compare_scalars_internal(a, b, *, rtol: float, atol: float, equal_nan: Unio
                 )
         return result, msg
 
-    if isinstance(a, complex) or isinstance(b, complex):
-        a = complex(a)
-        b = complex(b)
-
-        if equal_nan == "relaxed":
-            if cmath.isnan(a) and cmath.isnan(b):
-                return (True, None)
-
-        result, msg = _helper(a.real, b.real, " the real part ")
-
-        if not result:
-            return (False, msg)
-
-        return _helper(a.imag, b.imag, " the imaginary part ")
-
     return _helper(a, b, " ")