// https://github.com/numpy/numpy/issues/15959 is resolved
Tensor isclose(const Tensor& self, const Tensor& other, double rtol, double atol, bool equal_nan) {
TORCH_CHECK(self.scalar_type() == other.scalar_type(), self.scalar_type(), " did not match ", other.scalar_type());
- TORCH_CHECK(!(self.is_complex() && equal_nan),
- "isclose with equal_nan=True is not supported for complex inputs.");
TORCH_CHECK(!(self.is_quantized() || other.is_quantized()),
"isclose is not supported for quantized inputs.");
// Computes equality closeness
Tensor close = self == other;
- if (equal_nan && self.is_floating_point()) {
- close.__ior__((self != self).__iand__(other != other));
+ if (equal_nan && (self.is_floating_point() || self.is_complex())) {
+ close.__ior__(self.isnan().__iand__(other.isnan()));
}
// In case of zero tolerances the closeness inequality degenerates to an equality check.
self._comparetensors_helper(tests, device, dtype, True)
- # torch.close with equal_nan=True is not implemented for complex inputs
- # see https://github.com/numpy/numpy/issues/15959
# Note: compareTensor will compare the real and imaginary parts of a
# complex tensors separately, unlike isclose.
@dtypes(torch.complex64, torch.complex128)
# equal_nan = True tests
tests = (
(complex(1, 1), complex(1, float('nan')), False),
- (complex(float('nan'), 1), complex(1, float('nan')), False),
+ (complex(1, 1), complex(float('nan'), 1), False),
(complex(float('nan'), 1), complex(float('nan'), 1), True),
+ (complex(float('nan'), 1), complex(1, float('nan')), True),
+ (complex(float('nan'), float('nan')), complex(float('nan'), float('nan')), True),
)
+ self._isclose_helper(tests, device, dtype, True)
- with self.assertRaises(RuntimeError):
- self._isclose_helper(tests, device, dtype, True)
-
+ tests = (
+ (complex(1, 1), complex(1, float('nan')), False),
+ (complex(1, 1), complex(float('nan'), 1), False),
+ (complex(float('nan'), 1), complex(float('nan'), 1), True),
+ (complex(float('nan'), 1), complex(1, float('nan')), False),
+ (complex(float('nan'), float('nan')), complex(float('nan'), float('nan')), True),
+ )
self._comparetensors_helper(tests, device, dtype, True)
# Tests that isclose with rtol or atol values less than zero throws a
self.assertFalse(torch.isclose(a, b, rtol=0, atol=0))
+ @dtypes(torch.float16, torch.float32, torch.float64, torch.complex64, torch.complex128)
+ def test_isclose_nan_equality_shortcut(self, device, dtype):
+ if dtype.is_floating_point:
+ a = b = torch.nan
+ else:
+ a = complex(torch.nan, 0)
+ b = complex(0, torch.nan)
+
+ expected = True
+ tests = [(a, b, expected)]
+
+ self._isclose_helper(tests, device, dtype, equal_nan=True, rtol=0, atol=0)
+
@dtypes(torch.bool, torch.long, torch.float, torch.cfloat)
def test_make_tensor(self, device, dtype):
def check(size, low, high, requires_grad, noncontiguous):