Shape &Shape::squeeze(void)
{
- _dims.erase(std::remove(_dims.begin(), _dims.end(), 0), _dims.end());
+ _dims.erase(std::remove(_dims.begin(), _dims.end(), 1), _dims.end());
return *this;
}
ASSERT_EQ(squeezed.dim(2), 2);
}
-TEST(ADT_TENSOR_SHAPE, squeeze_pos)
+TEST(ADT_TENSOR_SHAPE, squeeze_neg_0)
{
using nncc::core::ADT::tensor::Shape;
using nncc::core::ADT::tensor::squeeze;
auto squeezed = squeeze(Shape{3, 0, 2});
+ ASSERT_EQ(squeezed.rank(), 3);
+ ASSERT_EQ(squeezed.dim(0), 3);
+ ASSERT_EQ(squeezed.dim(1), 0);
+ ASSERT_EQ(squeezed.dim(2), 2);
+}
+
+TEST(ADT_TENSOR_SHAPE, squeeze_pos)
+{
+ using nncc::core::ADT::tensor::Shape;
+ using nncc::core::ADT::tensor::squeeze;
+
+ auto squeezed = squeeze(Shape{3, 1, 2});
+
ASSERT_EQ(squeezed.rank(), 2);
ASSERT_EQ(squeezed.dim(0), 3);
ASSERT_EQ(squeezed.dim(1), 2);
using nncc::core::ADT::tensor::Shape;
using nncc::core::ADT::tensor::squeeze;
- Shape shape{3, 0, 2};
+ Shape shape{3, 1, 2};
shape.squeeze().squeeze();