From b1154cc7741fa7ad4f075272347ff587ebf168f7 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Wed, 25 Aug 2021 22:04:44 -0700 Subject: [PATCH] enable equal_nan for complex values in isclose (#63571) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/63571 Test Plan: Imported from OSS Reviewed By: malfet, ngimel Differential Revision: D30560127 Pulled By: mruberry fbshipit-source-id: 8958121ca24e7c139d869607903aebbe87bc0740 --- aten/src/ATen/native/TensorCompare.cpp | 6 ++---- test/test_testing.py | 30 ++++++++++++++++++++++++------ 2 files changed, 26 insertions(+), 10 deletions(-) diff --git a/aten/src/ATen/native/TensorCompare.cpp b/aten/src/ATen/native/TensorCompare.cpp index 90a57d1..3f69cab 100644 --- a/aten/src/ATen/native/TensorCompare.cpp +++ b/aten/src/ATen/native/TensorCompare.cpp @@ -108,8 +108,6 @@ bool allclose(const Tensor& self, const Tensor& other, double rtol, double atol, // https://github.com/numpy/numpy/issues/15959 is resolved Tensor isclose(const Tensor& self, const Tensor& other, double rtol, double atol, bool equal_nan) { TORCH_CHECK(self.scalar_type() == other.scalar_type(), self.scalar_type(), " did not match ", other.scalar_type()); - TORCH_CHECK(!(self.is_complex() && equal_nan), - "isclose with equal_nan=True is not supported for complex inputs."); TORCH_CHECK(!(self.is_quantized() || other.is_quantized()), "isclose is not supported for quantized inputs."); @@ -121,8 +119,8 @@ Tensor isclose(const Tensor& self, const Tensor& other, double rtol, double atol // Computes equality closeness Tensor close = self == other; - if (equal_nan && self.is_floating_point()) { - close.__ior__((self != self).__iand__(other != other)); + if (equal_nan && (self.is_floating_point() || self.is_complex())) { + close.__ior__(self.isnan().__iand__(other.isnan())); } // In case of zero tolerances the closeness inequality degenerates to an equality check. diff --git a/test/test_testing.py b/test/test_testing.py index d59290b..7e67569 100644 --- a/test/test_testing.py +++ b/test/test_testing.py @@ -335,8 +335,6 @@ class TestTesting(TestCase): self._comparetensors_helper(tests, device, dtype, True) - # torch.close with equal_nan=True is not implemented for complex inputs - # see https://github.com/numpy/numpy/issues/15959 # Note: compareTensor will compare the real and imaginary parts of a # complex tensors separately, unlike isclose. @dtypes(torch.complex64, torch.complex128) @@ -416,13 +414,20 @@ class TestTesting(TestCase): # equal_nan = True tests tests = ( (complex(1, 1), complex(1, float('nan')), False), - (complex(float('nan'), 1), complex(1, float('nan')), False), + (complex(1, 1), complex(float('nan'), 1), False), (complex(float('nan'), 1), complex(float('nan'), 1), True), + (complex(float('nan'), 1), complex(1, float('nan')), True), + (complex(float('nan'), float('nan')), complex(float('nan'), float('nan')), True), ) + self._isclose_helper(tests, device, dtype, True) - with self.assertRaises(RuntimeError): - self._isclose_helper(tests, device, dtype, True) - + tests = ( + (complex(1, 1), complex(1, float('nan')), False), + (complex(1, 1), complex(float('nan'), 1), False), + (complex(float('nan'), 1), complex(float('nan'), 1), True), + (complex(float('nan'), 1), complex(1, float('nan')), False), + (complex(float('nan'), float('nan')), complex(float('nan'), float('nan')), True), + ) self._comparetensors_helper(tests, device, dtype, True) # Tests that isclose with rtol or atol values less than zero throws a @@ -449,6 +454,19 @@ class TestTesting(TestCase): self.assertFalse(torch.isclose(a, b, rtol=0, atol=0)) + @dtypes(torch.float16, torch.float32, torch.float64, torch.complex64, torch.complex128) + def test_isclose_nan_equality_shortcut(self, device, dtype): + if dtype.is_floating_point: + a = b = torch.nan + else: + a = complex(torch.nan, 0) + b = complex(0, torch.nan) + + expected = True + tests = [(a, b, expected)] + + self._isclose_helper(tests, device, dtype, equal_nan=True, rtol=0, atol=0) + @dtypes(torch.bool, torch.long, torch.float, torch.cfloat) def test_make_tensor(self, device, dtype): def check(size, low, high, requires_grad, noncontiguous): -- 2.7.4