[core.ADT.tensor] Squeeze Tensor Shape correctly. (#226)
author박종현/동작제어Lab(SR)/Senior Engineer/삼성전자 <jh1302.park@samsung.com>
Fri, 18 May 2018 00:09:02 +0000 (09:09 +0900)
committerGitHub Enterprise <noreply-CODE@samsung.com>
Fri, 18 May 2018 00:09:02 +0000 (09:09 +0900)
Squeeze operation (in TensorFlow) removess axises whose dimensionality
is 1, but the current squeeze operation eliminates if their dimensionality
is null.

This commit fixes this mismatch on the semantics of squeeze operation.

Signed-off-by: Jonghyun Park <jh1302.park@samsung.com>
libs/core/src/nncc/core/ADT/tensor/Shape.cpp
libs/core/src/nncc/core/ADT/tensor/Shape.test.cpp

index b3dd4ab..97df5d2 100644 (file)
@@ -24,7 +24,7 @@ uint32_t Shape::dim(uint32_t axis) const { return _dims.at(axis); }
 
 Shape &Shape::squeeze(void)
 {
-  _dims.erase(std::remove(_dims.begin(), _dims.end(), 0), _dims.end());
+  _dims.erase(std::remove(_dims.begin(), _dims.end(), 1), _dims.end());
   return *this;
 }
 
index ab2dc87..31398e4 100644 (file)
@@ -95,13 +95,26 @@ TEST(ADT_TENSOR_SHAPE, squeeze_neg)
   ASSERT_EQ(squeezed.dim(2), 2);
 }
 
-TEST(ADT_TENSOR_SHAPE, squeeze_pos)
+TEST(ADT_TENSOR_SHAPE, squeeze_neg_0)
 {
   using nncc::core::ADT::tensor::Shape;
   using nncc::core::ADT::tensor::squeeze;
 
   auto squeezed = squeeze(Shape{3, 0, 2});
 
+  ASSERT_EQ(squeezed.rank(), 3);
+  ASSERT_EQ(squeezed.dim(0), 3);
+  ASSERT_EQ(squeezed.dim(1), 0);
+  ASSERT_EQ(squeezed.dim(2), 2);
+}
+
+TEST(ADT_TENSOR_SHAPE, squeeze_pos)
+{
+  using nncc::core::ADT::tensor::Shape;
+  using nncc::core::ADT::tensor::squeeze;
+
+  auto squeezed = squeeze(Shape{3, 1, 2});
+
   ASSERT_EQ(squeezed.rank(), 2);
   ASSERT_EQ(squeezed.dim(0), 3);
   ASSERT_EQ(squeezed.dim(1), 2);
@@ -112,7 +125,7 @@ TEST(ADT_TENSOR_SHAPE, squeeze_nested)
   using nncc::core::ADT::tensor::Shape;
   using nncc::core::ADT::tensor::squeeze;
 
-  Shape shape{3, 0, 2};
+  Shape shape{3, 1, 2};
 
   shape.squeeze().squeeze();