From d3e3b246ea02e5d5ad76cb8a66a95e544a08a40c Mon Sep 17 00:00:00 2001 From: Edward Yang Date: Thu, 14 Mar 2019 15:25:35 -0700 Subject: [PATCH] Use std::isnan instead of self-comparison. (#18021) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/18021 ghimport-source-id: 03423ba47ba5900c2b400c4457b148147ce8b35e Stack: * **#18021 Use std::isnan instead of self-comparison.** Signed-off-by: Edward Z. Yang Reviewed By: soumith Differential Revision: D14460699 fbshipit-source-id: d8feb7f3f0e93996bd1b4f4aea163548b1d12437 --- aten/src/ATen/NumericUtils.h | 22 ++++++++++++++++++++++ aten/src/ATen/cpu/vec256/vec256_base.h | 9 +++++---- aten/src/ATen/native/Sorting.cpp | 3 ++- aten/src/ATen/native/cpu/TensorCompareKernel.cpp | 16 +--------------- 4 files changed, 30 insertions(+), 20 deletions(-) create mode 100644 aten/src/ATen/NumericUtils.h diff --git a/aten/src/ATen/NumericUtils.h b/aten/src/ATen/NumericUtils.h new file mode 100644 index 0000000..95274f7 --- /dev/null +++ b/aten/src/ATen/NumericUtils.h @@ -0,0 +1,22 @@ +#include +#include + +namespace at { + +// std::isnan isn't performant to use on integral types; it will +// (uselessly) convert to floating point and then do the test. +// This function is. + +template ::value, int>::type = 0> +inline bool _isnan(T val) { + return false; +} + +template ::value, int>::type = 0> +inline bool _isnan(T val) { + return std::isnan(val); +} + +} // namespace at diff --git a/aten/src/ATen/cpu/vec256/vec256_base.h b/aten/src/ATen/cpu/vec256/vec256_base.h index dffbbaa..b7a0df2 100644 --- a/aten/src/ATen/cpu/vec256/vec256_base.h +++ b/aten/src/ATen/cpu/vec256/vec256_base.h @@ -8,6 +8,7 @@ #include #include +#include #include #if defined(__GNUC__) @@ -323,7 +324,7 @@ template Vec256 inline maximum(const Vec256 &a, const Vec256 Vec256 c = Vec256(); for (int i = 0; i != Vec256::size(); i++) { c[i] = (a[i] > b[i]) ? a[i] : b[i]; - if (std::is_floating_point::value && std::isnan(a[i])) { + if (_isnan(a[i])) { // If either input is NaN, propagate a NaN. // NOTE: The case where b[i] was NaN is handled correctly by the naive // ternary operator above. @@ -336,7 +337,7 @@ template Vec256 inline maximum(const Vec256 &a, const Vec256 template inline T maximum(const T& a, const T& b) { T c = (a > b) ? a : b; - if (std::is_floating_point::value && std::isnan(a)) { + if (_isnan(a)) { c = a; } return c; @@ -348,7 +349,7 @@ template Vec256 inline minimum(const Vec256 &a, const Vec256 Vec256 c = Vec256(); for (int i = 0; i != Vec256::size(); i++) { c[i] = (a[i] < b[i]) ? a[i] : b[i]; - if (std::is_floating_point::value && std::isnan(a[i])) { + if (_isnan(a[i])) { // If either input is NaN, propagate a NaN. // NOTE: The case where b[i] was NaN is handled correctly by the naive // ternary operator above. @@ -361,7 +362,7 @@ template Vec256 inline minimum(const Vec256 &a, const Vec256 template inline T minimum(const T& a, const T& b) { T c = (a < b) ? a : b; - if (std::is_floating_point::value && std::isnan(a)) { + if (_isnan(a)) { c = a; } return c; diff --git a/aten/src/ATen/native/Sorting.cpp b/aten/src/ATen/native/Sorting.cpp index 01adda7..2a8af89 100644 --- a/aten/src/ATen/native/Sorting.cpp +++ b/aten/src/ATen/native/Sorting.cpp @@ -2,6 +2,7 @@ #include #include #include +#include namespace at { namespace native { @@ -171,7 +172,7 @@ std::tuple kthvalue_out_cpu( tmp_values, k - 1, [](scalar_t x, scalar_t y) -> bool { - return ((x != x && y == y) || (x > y)); + return ((_isnan(x) && !_isnan(y)) || (x > y)); }, [&](int64_t i, int64_t j) { std::swap(tmp_values[i], tmp_values[j]); diff --git a/aten/src/ATen/native/cpu/TensorCompareKernel.cpp b/aten/src/ATen/native/cpu/TensorCompareKernel.cpp index d118e59..e988b69 100644 --- a/aten/src/ATen/native/cpu/TensorCompareKernel.cpp +++ b/aten/src/ATen/native/cpu/TensorCompareKernel.cpp @@ -6,25 +6,11 @@ #include #include +#include #include namespace at { namespace native { namespace { -template -bool _isnan(scalar_t val) { - return false; -} - -template <> -bool _isnan(float val) { - return std::isnan(val); -} - -template <> -bool _isnan(double val) { - return std::isnan(val); -} - template struct Reduction { static void apply( -- 2.7.4