From f4ea67a123157b7611e8bb1639fde521fc3f3d52 Mon Sep 17 00:00:00 2001 From: =?utf8?q?=EB=B0=95=EC=A2=85=ED=98=84/=EB=8F=99=EC=9E=91=EC=A0=9C?= =?utf8?q?=EC=96=B4Lab=28SR=29/Senior=20Engineer/=EC=82=BC=EC=84=B1?= =?utf8?q?=EC=A0=84=EC=9E=90?= Date: Wed, 9 May 2018 10:47:44 +0900 Subject: [PATCH] [core.ADT.tensor] Introduce 'squeeze' operation (#207) * [core.ADT.tensor] Introduce 'squeeze' operation This commit introduces 'squeeze' operation on tensor shape which eliminates axies with null dimensionality. This commit also introduces several related unittests which show how to use this operation and describes the expected operation. Signed-off-by: Jonghyun Park * Return reference correctly --- libs/core/include/nncc/core/ADT/tensor/Shape.h | 5 +++ libs/core/src/nncc/core/ADT/tensor/Shape.cpp | 15 +++++++++ libs/core/src/nncc/core/ADT/tensor/Shape.test.cpp | 39 +++++++++++++++++++++++ 3 files changed, 59 insertions(+) diff --git a/libs/core/include/nncc/core/ADT/tensor/Shape.h b/libs/core/include/nncc/core/ADT/tensor/Shape.h index 230a8a6..75cf47e 100644 --- a/libs/core/include/nncc/core/ADT/tensor/Shape.h +++ b/libs/core/include/nncc/core/ADT/tensor/Shape.h @@ -30,10 +30,15 @@ public: uint32_t &dim(uint32_t axis); uint32_t dim(uint32_t axis) const; +public: + Shape &squeeze(void); + private: std::vector _dims; }; +Shape squeeze(const Shape &); + bool operator==(const Shape &, const Shape &); } // namespace tensor diff --git a/libs/core/src/nncc/core/ADT/tensor/Shape.cpp b/libs/core/src/nncc/core/ADT/tensor/Shape.cpp index d48eca1..c8d1caa 100644 --- a/libs/core/src/nncc/core/ADT/tensor/Shape.cpp +++ b/libs/core/src/nncc/core/ADT/tensor/Shape.cpp @@ -1,5 +1,7 @@ #include "nncc/core/ADT/tensor/Shape.h" +#include + namespace nncc { namespace core @@ -20,6 +22,19 @@ Shape &Shape::resize(uint32_t size) { _dims.resize(size); } uint32_t &Shape::dim(uint32_t axis) { return _dims.at(axis); } uint32_t Shape::dim(uint32_t axis) const { return _dims.at(axis); } +Shape &Shape::squeeze(void) +{ + _dims.erase(std::remove(_dims.begin(), _dims.end(), 0), _dims.end()); + return *this; +} + +Shape squeeze(const Shape &shape) +{ + Shape res{shape}; + res.squeeze(); + return res; +} + bool operator==(const Shape &lhs, const Shape &rhs) { if (lhs.rank() != rhs.rank()) diff --git a/libs/core/src/nncc/core/ADT/tensor/Shape.test.cpp b/libs/core/src/nncc/core/ADT/tensor/Shape.test.cpp index 3fd5590..7f17441 100644 --- a/libs/core/src/nncc/core/ADT/tensor/Shape.test.cpp +++ b/libs/core/src/nncc/core/ADT/tensor/Shape.test.cpp @@ -58,6 +58,45 @@ TEST(ADT_TENSOR_SHAPE, copy) } } +TEST(ADT_TENSOR_SHAPE, squeeze_neg) +{ + using nncc::core::ADT::tensor::Shape; + using nncc::core::ADT::tensor::squeeze; + + auto squeezed = squeeze(Shape{3, 5, 2}); + + ASSERT_EQ(squeezed.rank(), 3); + ASSERT_EQ(squeezed.dim(0), 3); + ASSERT_EQ(squeezed.dim(1), 5); + ASSERT_EQ(squeezed.dim(2), 2); +} + +TEST(ADT_TENSOR_SHAPE, squeeze_pos) +{ + using nncc::core::ADT::tensor::Shape; + using nncc::core::ADT::tensor::squeeze; + + auto squeezed = squeeze(Shape{3, 0, 2}); + + ASSERT_EQ(squeezed.rank(), 2); + ASSERT_EQ(squeezed.dim(0), 3); + ASSERT_EQ(squeezed.dim(1), 2); +} + +TEST(ADT_TENSOR_SHAPE, squeeze_nested) +{ + using nncc::core::ADT::tensor::Shape; + using nncc::core::ADT::tensor::squeeze; + + Shape shape{3, 0, 2}; + + shape.squeeze().squeeze(); + + ASSERT_EQ(shape.rank(), 2); + ASSERT_EQ(shape.dim(0), 3); + ASSERT_EQ(shape.dim(1), 2); +} + TEST(ADT_TENSOR_SHAPE, eq_negative_on_unmatched_rank) { const nncc::core::ADT::tensor::Shape left{1, 1, 1}; -- 2.7.4