From 04475bc1b226dcf0bc219400ac42dd731a5787fe Mon Sep 17 00:00:00 2001 From: Ilya Churaev Date: Tue, 18 Aug 2020 13:49:30 +0300 Subject: [PATCH] Add arithmetic operators for bfloat16 (#1831) --- ngraph/core/include/ngraph/type/bfloat16.hpp | 122 +++++++++++++++++++++++++-- ngraph/core/src/type/bfloat16.cpp | 32 ------- ngraph/test/bfloat16.cpp | 12 +++ 3 files changed, 128 insertions(+), 38 deletions(-) diff --git a/ngraph/core/include/ngraph/type/bfloat16.hpp b/ngraph/core/include/ngraph/type/bfloat16.hpp index dbb73a0..72d47f5 100644 --- a/ngraph/core/include/ngraph/type/bfloat16.hpp +++ b/ngraph/core/include/ngraph/type/bfloat16.hpp @@ -55,12 +55,37 @@ namespace ngraph std::string to_string() const; size_t size() const; - bool operator==(const bfloat16& other) const; - bool operator!=(const bfloat16& other) const { return !(*this == other); } - bool operator<(const bfloat16& other) const; - bool operator<=(const bfloat16& other) const; - bool operator>(const bfloat16& other) const; - bool operator>=(const bfloat16& other) const; + template + bool operator==(const T& other) const; + template + bool operator!=(const T& other) const + { + return !(*this == other); + } + template + bool operator<(const T& other) const; + template + bool operator<=(const T& other) const; + template + bool operator>(const T& other) const; + template + bool operator>=(const T& other) const; + template + bfloat16 operator+(const T& other) const; + template + bfloat16 operator+=(const T& other); + template + bfloat16 operator-(const T& other) const; + template + bfloat16 operator-=(const T& other); + template + bfloat16 operator*(const T& other) const; + template + bfloat16 operator*=(const T& other); + template + bfloat16 operator/(const T& other) const; + template + bfloat16 operator/=(const T& other); operator float() const; static std::vector to_float_vector(const std::vector&); @@ -106,6 +131,91 @@ namespace ngraph uint16_t m_value; }; + + template + bool bfloat16::operator==(const T& other) const + { +#if defined(__GNUC__) +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wfloat-equal" +#endif + return (static_cast(*this) == static_cast(other)); +#if defined(__GNUC__) +#pragma GCC diagnostic pop +#endif + } + + template + bool bfloat16::operator<(const T& other) const + { + return (static_cast(*this) < static_cast(other)); + } + + template + bool bfloat16::operator<=(const T& other) const + { + return (static_cast(*this) <= static_cast(other)); + } + + template + bool bfloat16::operator>(const T& other) const + { + return (static_cast(*this) > static_cast(other)); + } + + template + bool bfloat16::operator>=(const T& other) const + { + return (static_cast(*this) >= static_cast(other)); + } + + template + bfloat16 bfloat16::operator+(const T& other) const + { + return {static_cast(*this) + static_cast(other)}; + } + + template + bfloat16 bfloat16::operator+=(const T& other) + { + return *this = *this + other; + } + + template + bfloat16 bfloat16::operator-(const T& other) const + { + return {static_cast(*this) - static_cast(other)}; + } + + template + bfloat16 bfloat16::operator-=(const T& other) + { + return *this = *this - other; + } + + template + bfloat16 bfloat16::operator*(const T& other) const + { + return {static_cast(*this) * static_cast(other)}; + } + + template + bfloat16 bfloat16::operator*=(const T& other) + { + return *this = *this * other; + } + + template + bfloat16 bfloat16::operator/(const T& other) const + { + return {static_cast(*this) / static_cast(other)}; + } + + template + bfloat16 bfloat16::operator/=(const T& other) + { + return *this = *this / other; + } } namespace std diff --git a/ngraph/core/src/type/bfloat16.cpp b/ngraph/core/src/type/bfloat16.cpp index 8400417..0ebcf5e 100644 --- a/ngraph/core/src/type/bfloat16.cpp +++ b/ngraph/core/src/type/bfloat16.cpp @@ -74,38 +74,6 @@ size_t bfloat16::size() const return sizeof(m_value); } -bool bfloat16::operator==(const bfloat16& other) const -{ -#if defined(__GNUC__) -#pragma GCC diagnostic push -#pragma GCC diagnostic ignored "-Wfloat-equal" -#endif - return (static_cast(*this) == static_cast(other)); -#if defined(__GNUC__) -#pragma GCC diagnostic pop -#endif -} - -bool bfloat16::operator<(const bfloat16& other) const -{ - return (static_cast(*this) < static_cast(other)); -} - -bool bfloat16::operator<=(const bfloat16& other) const -{ - return (static_cast(*this) <= static_cast(other)); -} - -bool bfloat16::operator>(const bfloat16& other) const -{ - return (static_cast(*this) > static_cast(other)); -} - -bool bfloat16::operator>=(const bfloat16& other) const -{ - return (static_cast(*this) >= static_cast(other)); -} - bfloat16::operator float() const { uint32_t tmp = (static_cast(m_value) << 16); diff --git a/ngraph/test/bfloat16.cpp b/ngraph/test/bfloat16.cpp index 477e3da..580fd46 100644 --- a/ngraph/test/bfloat16.cpp +++ b/ngraph/test/bfloat16.cpp @@ -259,3 +259,15 @@ TEST(bfloat16, assigns) EXPECT_EQ(f32arr[i], bf16arr[i]); } } + +TEST(bfloat16, operators) +{ + bfloat16 a(2.0); + bfloat16 b(3.5); + bfloat16 c(5.5); + bfloat16 d(7.0); + ASSERT_TRUE(a + b == c); + ASSERT_TRUE(a == c - b); + ASSERT_TRUE(a * b == d); + ASSERT_TRUE(a == d / b); +} -- 2.7.4