From d79f8f9dffa326ec966212030dffe9ada1abe74f Mon Sep 17 00:00:00 2001 From: =?utf8?q?=EC=9C=A4=ED=98=84=EC=8B=9D/On-Device=20Lab=28SR=29/Princip?= =?utf8?q?al=20Engineer/=EC=82=BC=EC=84=B1=EC=A0=84=EC=9E=90?= Date: Wed, 12 Jun 2019 19:52:19 +0900 Subject: [PATCH] [loco] Helper function to check equality of two dims (#3727) * [loco] Helper function to check equality of two dims This adds a helper operator `==` to check equality of two Dimensions. Signed-off-by: Hyun Sik Yoon * modified per comments --- contrib/loco/include/loco/IR/Dimension.h | 7 +++++++ contrib/loco/src/IR/Dimension.cpp | 5 +++++ contrib/loco/src/IR/Dimension.test.cpp | 9 +++++++++ 3 files changed, 21 insertions(+) diff --git a/contrib/loco/include/loco/IR/Dimension.h b/contrib/loco/include/loco/IR/Dimension.h index f489d10..2657f4c 100644 --- a/contrib/loco/include/loco/IR/Dimension.h +++ b/contrib/loco/include/loco/IR/Dimension.h @@ -67,6 +67,13 @@ private: uint32_t _value{0}; }; +/** + * @brief Equality operator between two Dimensions + * + * @note Refer to the definition of equality of dimemsion at + * https://www.tensorflow.org/api_docs/python/tf/Dimension#__eq__ + */ +bool operator==(const Dimension &, const Dimension &); bool operator==(const Dimension &, uint32_t); bool operator==(uint32_t, const Dimension &); diff --git a/contrib/loco/src/IR/Dimension.cpp b/contrib/loco/src/IR/Dimension.cpp index 11c6cfd..00d4ae6 100644 --- a/contrib/loco/src/IR/Dimension.cpp +++ b/contrib/loco/src/IR/Dimension.cpp @@ -19,6 +19,11 @@ namespace loco { +bool operator==(const Dimension &lhs, const Dimension &rhs) +{ + return lhs.known() && rhs.known() && lhs.value() == rhs.value(); +} + bool operator==(const Dimension &lhs, uint32_t rhs) { return lhs.known() && lhs.value() == rhs; } bool operator==(uint32_t lhs, const Dimension &rhs) { return rhs.known() && lhs == rhs.value(); } diff --git a/contrib/loco/src/IR/Dimension.test.cpp b/contrib/loco/src/IR/Dimension.test.cpp index 8c84bd6..e9075f2 100644 --- a/contrib/loco/src/IR/Dimension.test.cpp +++ b/contrib/loco/src/IR/Dimension.test.cpp @@ -81,6 +81,15 @@ TEST_F(DimensionTest, operator_eq) ASSERT_FALSE(known == 4); ASSERT_FALSE(4 == known); + + // Compare two known dimensions + loco::Dimension another_known{3}; + ASSERT_TRUE(known == another_known); + + // Compare two unknown dimensions + loco::Dimension unknown_a, unknown_b; + ASSERT_TRUE(unknown_a.known() == false && unknown_b.known() == false); + ASSERT_FALSE(unknown_a == unknown_b); } TEST_F(DimensionTest, make_unknown_dimension) -- 2.7.4