[angkor] Overload operator== (#8233)
author채성우/On-Device Lab(SR)/Engineer/삼성전자 <sw4670.chae@samsung.com>
Wed, 16 Oct 2019 10:40:57 +0000 (19:40 +0900)
committer박종현/On-Device Lab(SR)/Staff Engineer/삼성전자 <jh1302.park@samsung.com>
Wed, 16 Oct 2019 10:40:57 +0000 (19:40 +0900)
* [angkor] Overload operator==

This commit overloads operator== to index class.

Signed-off-by: seongwoo <sw4670.chae@samsung.com>
* apply commnet.

compiler/angkor/include/nncc/core/ADT/tensor/Index.h
compiler/angkor/src/ADT/tensor/Index.cpp
compiler/angkor/src/ADT/tensor/Index.test.cpp

index 56d7391..19beafa 100644 (file)
@@ -55,6 +55,7 @@ private:
 
 // It throws an exception when rank of inputs does not match.
 Index operator+(const Index &lhs, const Index &rhs);
+bool operator==(const Index &lhs, const Index &rhs);
 
 } // namespace tensor
 } // namespace ADT
index e61fafb..61f0a71 100644 (file)
@@ -63,6 +63,18 @@ Index operator+(const Index &lhs, const Index &rhs)
   return ret;
 }
 
+bool operator==(const Index &lhs, const Index &rhs)
+{
+  if (lhs.rank() != rhs.rank())
+    return false;
+  for (uint32_t axis = 0; axis < lhs.rank(); axis++)
+  {
+    if (lhs.at(axis) != rhs.at(axis))
+      return false;
+  }
+  return true;
+}
+
 } // namespace tensor
 } // namespace ADT
 } // namespace core
index 412c85e..2306028 100644 (file)
@@ -49,6 +49,18 @@ TEST(ADT_TENSOR_INDEX, operator_add)
   ASSERT_EQ(result.at(3), 12);
 }
 
+TEST(ADT_TENSOR_INDEX, operator_eqaul)
+{
+  nncc::core::ADT::tensor::Index index1{1, 2, 3, 4};
+  nncc::core::ADT::tensor::Index index2{1, 2, 3, 4};
+  nncc::core::ADT::tensor::Index index3{5, 6, 7, 8};
+  nncc::core::ADT::tensor::Index index4{1, 2};
+
+  ASSERT_TRUE(index1 == index2);
+  ASSERT_FALSE(index1 == index3);
+  ASSERT_FALSE(index1 == index4);
+}
+
 TEST(ADT_TENSOR_INDEX, operator_add_different_size)
 {
   nncc::core::ADT::tensor::Index index1{1, 2, 3, 4};