#include "nncc/core/ADT/tensor/Shape.h"
+#include <algorithm>
+
namespace nncc
{
namespace core
uint32_t &Shape::dim(uint32_t axis) { return _dims.at(axis); }
uint32_t Shape::dim(uint32_t axis) const { return _dims.at(axis); }
+Shape &Shape::squeeze(void)
+{
+ _dims.erase(std::remove(_dims.begin(), _dims.end(), 0), _dims.end());
+ return *this;
+}
+
+Shape squeeze(const Shape &shape)
+{
+ Shape res{shape};
+ res.squeeze();
+ return res;
+}
+
bool operator==(const Shape &lhs, const Shape &rhs)
{
if (lhs.rank() != rhs.rank())
}
}
+TEST(ADT_TENSOR_SHAPE, squeeze_neg)
+{
+ using nncc::core::ADT::tensor::Shape;
+ using nncc::core::ADT::tensor::squeeze;
+
+ auto squeezed = squeeze(Shape{3, 5, 2});
+
+ ASSERT_EQ(squeezed.rank(), 3);
+ ASSERT_EQ(squeezed.dim(0), 3);
+ ASSERT_EQ(squeezed.dim(1), 5);
+ ASSERT_EQ(squeezed.dim(2), 2);
+}
+
+TEST(ADT_TENSOR_SHAPE, squeeze_pos)
+{
+ using nncc::core::ADT::tensor::Shape;
+ using nncc::core::ADT::tensor::squeeze;
+
+ auto squeezed = squeeze(Shape{3, 0, 2});
+
+ ASSERT_EQ(squeezed.rank(), 2);
+ ASSERT_EQ(squeezed.dim(0), 3);
+ ASSERT_EQ(squeezed.dim(1), 2);
+}
+
+TEST(ADT_TENSOR_SHAPE, squeeze_nested)
+{
+ using nncc::core::ADT::tensor::Shape;
+ using nncc::core::ADT::tensor::squeeze;
+
+ Shape shape{3, 0, 2};
+
+ shape.squeeze().squeeze();
+
+ ASSERT_EQ(shape.rank(), 2);
+ ASSERT_EQ(shape.dim(0), 3);
+ ASSERT_EQ(shape.dim(1), 2);
+}
+
TEST(ADT_TENSOR_SHAPE, eq_negative_on_unmatched_rank)
{
const nncc::core::ADT::tensor::Shape left{1, 1, 1};