enable equal_nan for complex values in isclose (#63571)
Author: Philip Meier <github.pmeier@posteo.de>
Thu, 26 Aug 2021 05:04:44 +0000 (22:04 -0700)
Committer: Facebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Thu, 26 Aug 2021 05:05:49 +0000 (22:05 -0700)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/63571

Test Plan: Imported from OSS

Reviewed By: malfet, ngimel

Differential Revision: D30560127

Pulled By: mruberry

fbshipit-source-id: 8958121ca24e7c139d869607903aebbe87bc0740

aten/src/ATen/native/TensorCompare.cpp
test/test_testing.py

index 90a57d1..3f69cab 100644 (file)
@@ -108,8 +108,6 @@ bool allclose(const Tensor& self, const Tensor& other, double rtol, double atol,
 //  https://github.com/numpy/numpy/issues/15959 is resolved
 Tensor isclose(const Tensor& self, const Tensor& other, double rtol, double atol, bool equal_nan) {
   TORCH_CHECK(self.scalar_type() == other.scalar_type(), self.scalar_type(), " did not match ", other.scalar_type());
-  TORCH_CHECK(!(self.is_complex() && equal_nan),
-    "isclose with equal_nan=True is not supported for complex inputs.");
   TORCH_CHECK(!(self.is_quantized() || other.is_quantized()),
     "isclose is not supported for quantized inputs.");
 
@@ -121,8 +119,8 @@ Tensor isclose(const Tensor& self, const Tensor& other, double rtol, double atol
 
   // Computes equality closeness
   Tensor close = self == other;
-  if (equal_nan && self.is_floating_point()) {
-      close.__ior__((self != self).__iand__(other != other));
+  if (equal_nan && (self.is_floating_point() || self.is_complex())) {
+      close.__ior__(self.isnan().__iand__(other.isnan()));
   }
 
   // In case of zero tolerances the closeness inequality degenerates to an equality check.
index d59290b..7e67569 100644 (file)
@@ -335,8 +335,6 @@ class TestTesting(TestCase):
 
         self._comparetensors_helper(tests, device, dtype, True)
 
-    # torch.close with equal_nan=True is not implemented for complex inputs
-    # see https://github.com/numpy/numpy/issues/15959
     # Note: compareTensor will compare the real and imaginary parts of a
     # complex tensors separately, unlike isclose.
     @dtypes(torch.complex64, torch.complex128)
@@ -416,13 +414,20 @@ class TestTesting(TestCase):
         # equal_nan = True tests
         tests = (
             (complex(1, 1), complex(1, float('nan')), False),
-            (complex(float('nan'), 1), complex(1, float('nan')), False),
+            (complex(1, 1), complex(float('nan'), 1), False),
             (complex(float('nan'), 1), complex(float('nan'), 1), True),
+            (complex(float('nan'), 1), complex(1, float('nan')), True),
+            (complex(float('nan'), float('nan')), complex(float('nan'), float('nan')), True),
         )
+        self._isclose_helper(tests, device, dtype, True)
 
-        with self.assertRaises(RuntimeError):
-            self._isclose_helper(tests, device, dtype, True)
-
+        tests = (
+            (complex(1, 1), complex(1, float('nan')), False),
+            (complex(1, 1), complex(float('nan'), 1), False),
+            (complex(float('nan'), 1), complex(float('nan'), 1), True),
+            (complex(float('nan'), 1), complex(1, float('nan')), False),
+            (complex(float('nan'), float('nan')), complex(float('nan'), float('nan')), True),
+        )
         self._comparetensors_helper(tests, device, dtype, True)
 
     # Tests that isclose with rtol or atol values less than zero throws a
@@ -449,6 +454,19 @@ class TestTesting(TestCase):
 
         self.assertFalse(torch.isclose(a, b, rtol=0, atol=0))
 
+    @dtypes(torch.float16, torch.float32, torch.float64, torch.complex64, torch.complex128)
+    def test_isclose_nan_equality_shortcut(self, device, dtype):
+        if dtype.is_floating_point:
+            a = b = torch.nan
+        else:
+            a = complex(torch.nan, 0)
+            b = complex(0, torch.nan)
+
+        expected = True
+        tests = [(a, b, expected)]
+
+        self._isclose_helper(tests, device, dtype, equal_nan=True, rtol=0, atol=0)
+
     @dtypes(torch.bool, torch.long, torch.float, torch.cfloat)
     def test_make_tensor(self, device, dtype):
         def check(size, low, high, requires_grad, noncontiguous):