[core.ADT.tensor] Introduce 'squeeze' operation (#207)
author박종현/동작제어Lab(SR)/Senior Engineer/삼성전자 <jh1302.park@samsung.com>
Wed, 9 May 2018 01:47:44 +0000 (10:47 +0900)
committerGitHub Enterprise <noreply-CODE@samsung.com>
Wed, 9 May 2018 01:47:44 +0000 (10:47 +0900)
* [core.ADT.tensor] Introduce 'squeeze' operation

This commit introduces 'squeeze' operation on tensor shape which eliminates
axies with null dimensionality.

This commit also introduces several related unittests which show how to
use this operation and describes the expected operation.

Signed-off-by: Jonghyun Park <jh1302.park@samsung.com>
* Return reference correctly

libs/core/include/nncc/core/ADT/tensor/Shape.h
libs/core/src/nncc/core/ADT/tensor/Shape.cpp
libs/core/src/nncc/core/ADT/tensor/Shape.test.cpp

index 230a8a6..75cf47e 100644 (file)
@@ -30,10 +30,15 @@ public:
   uint32_t &dim(uint32_t axis);
   uint32_t dim(uint32_t axis) const;
 
+public:
+  Shape &squeeze(void);
+
 private:
   std::vector<uint32_t> _dims;
 };
 
+Shape squeeze(const Shape &);
+
 bool operator==(const Shape &, const Shape &);
 
 } // namespace tensor
index d48eca1..c8d1caa 100644 (file)
@@ -1,5 +1,7 @@
 #include "nncc/core/ADT/tensor/Shape.h"
 
+#include <algorithm>
+
 namespace nncc
 {
 namespace core
@@ -20,6 +22,19 @@ Shape &Shape::resize(uint32_t size) { _dims.resize(size); }
 uint32_t &Shape::dim(uint32_t axis) { return _dims.at(axis); }
 uint32_t Shape::dim(uint32_t axis) const { return _dims.at(axis); }
 
+Shape &Shape::squeeze(void)
+{
+  _dims.erase(std::remove(_dims.begin(), _dims.end(), 0), _dims.end());
+  return *this;
+}
+
+Shape squeeze(const Shape &shape)
+{
+  Shape res{shape};
+  res.squeeze();
+  return res;
+}
+
 bool operator==(const Shape &lhs, const Shape &rhs)
 {
   if (lhs.rank() != rhs.rank())
index 3fd5590..7f17441 100644 (file)
@@ -58,6 +58,45 @@ TEST(ADT_TENSOR_SHAPE, copy)
   }
 }
 
+TEST(ADT_TENSOR_SHAPE, squeeze_neg)
+{
+  using nncc::core::ADT::tensor::Shape;
+  using nncc::core::ADT::tensor::squeeze;
+
+  auto squeezed = squeeze(Shape{3, 5, 2});
+
+  ASSERT_EQ(squeezed.rank(), 3);
+  ASSERT_EQ(squeezed.dim(0), 3);
+  ASSERT_EQ(squeezed.dim(1), 5);
+  ASSERT_EQ(squeezed.dim(2), 2);
+}
+
+TEST(ADT_TENSOR_SHAPE, squeeze_pos)
+{
+  using nncc::core::ADT::tensor::Shape;
+  using nncc::core::ADT::tensor::squeeze;
+
+  auto squeezed = squeeze(Shape{3, 0, 2});
+
+  ASSERT_EQ(squeezed.rank(), 2);
+  ASSERT_EQ(squeezed.dim(0), 3);
+  ASSERT_EQ(squeezed.dim(1), 2);
+}
+
+TEST(ADT_TENSOR_SHAPE, squeeze_nested)
+{
+  using nncc::core::ADT::tensor::Shape;
+  using nncc::core::ADT::tensor::squeeze;
+
+  Shape shape{3, 0, 2};
+
+  shape.squeeze().squeeze();
+
+  ASSERT_EQ(shape.rank(), 2);
+  ASSERT_EQ(shape.dim(0), 3);
+  ASSERT_EQ(shape.dim(1), 2);
+}
+
 TEST(ADT_TENSOR_SHAPE, eq_negative_on_unmatched_rank)
 {
   const nncc::core::ADT::tensor::Shape left{1, 1, 1};