From edc30df260fc5623efe8ac8e986226d348bedf4c Mon Sep 17 00:00:00 2001 From: =?utf8?q?=EB=B0=95=EC=A2=85=ED=98=84/=EB=8F=99=EC=9E=91=EC=A0=9C?= =?utf8?q?=EC=96=B4Lab=28SR=29/Senior=20Engineer/=EC=82=BC=EC=84=B1?= =?utf8?q?=EC=A0=84=EC=9E=90?= Date: Thu, 26 Apr 2018 14:00:37 +0900 Subject: [PATCH] [nncc.core] Equality over tensor shape (#153) This commit introduces == operator between two tensor::Shape class. Signed-off-by: Jonghyun Park --- libs/core/include/nncc/core/ADT/tensor/Shape.h | 2 ++ libs/core/src/nncc/core/ADT/tensor/Shape.cpp | 18 ++++++++++ libs/core/src/nncc/core/ADT/tensor/Shape.test.cpp | 43 +++++++++++++++++++++++ 3 files changed, 63 insertions(+) diff --git a/libs/core/include/nncc/core/ADT/tensor/Shape.h b/libs/core/include/nncc/core/ADT/tensor/Shape.h index 7e707cf..288f046 100644 --- a/libs/core/include/nncc/core/ADT/tensor/Shape.h +++ b/libs/core/include/nncc/core/ADT/tensor/Shape.h @@ -29,6 +29,8 @@ private: std::vector _dims; }; +bool operator==(const Shape &, const Shape &); + } // namespace tensor } // namespace ADT } // namespace core diff --git a/libs/core/src/nncc/core/ADT/tensor/Shape.cpp b/libs/core/src/nncc/core/ADT/tensor/Shape.cpp index fc94d12..38610a8 100644 --- a/libs/core/src/nncc/core/ADT/tensor/Shape.cpp +++ b/libs/core/src/nncc/core/ADT/tensor/Shape.cpp @@ -15,6 +15,24 @@ Shape &Shape::resize(uint32_t size) { _dims.resize(size); } uint32_t &Shape::dim(uint32_t axis) { return _dims.at(axis); } uint32_t Shape::dim(uint32_t axis) const { return _dims.at(axis); } +bool operator==(const Shape &lhs, const Shape &rhs) +{ + if (lhs.rank() != rhs.rank()) + { + return false; + } + + for (uint32_t axis = 0; axis < lhs.rank(); ++axis) + { + if (lhs.dim(axis) != rhs.dim(axis)) + { + return false; + } + } + + return true; +} + } // namespace tensor } // namespace ADT } // namespace core diff --git a/libs/core/src/nncc/core/ADT/tensor/Shape.test.cpp b/libs/core/src/nncc/core/ADT/tensor/Shape.test.cpp index 5a7b055..dfe204c 100644 --- a/libs/core/src/nncc/core/ADT/tensor/Shape.test.cpp +++ b/libs/core/src/nncc/core/ADT/tensor/Shape.test.cpp @@ -55,3 +55,46 @@ TEST(ADT_TENSOR_SHAPE, copy) ASSERT_EQ(original.dim(axis), copied.dim(axis)); } } + +TEST(ADT_TENSOR_SHAPE, eq_negative_on_unmatched_rank) +{ + nncc::core::ADT::tensor::Shape left; + nncc::core::ADT::tensor::Shape right; + + left.resize(3); + right.resize(4); + + ASSERT_FALSE(left == right); +} + +TEST(ADT_TENSOR_SHAPE, eq_negative_on_unmatched_dim) +{ + nncc::core::ADT::tensor::Shape left; + nncc::core::ADT::tensor::Shape right; + + left.resize(2); + left.dim(0) = 2; + left.dim(1) = 3; + + right.resize(2); + right.dim(0) = 2; + right.dim(1) = 4; + + ASSERT_FALSE(left == right); +} + +TEST(ADT_TENSOR_SHAPE, eq_positive) +{ + nncc::core::ADT::tensor::Shape left; + nncc::core::ADT::tensor::Shape right; + + left.resize(2); + left.dim(0) = 2; + left.dim(1) = 3; + + right.resize(2); + right.dim(0) = 2; + right.dim(1) = 3; + + ASSERT_TRUE(left == right); +} -- 2.7.4