std::is_permutation(minor_to_major.begin(), minor_to_major.end(),
dims.begin()));
}
- Shape stripped_shape =
- shape.has_layout() ? MakeShapeWithLayout(shape.element_type(),
- dimension_sizes, minor_to_major)
- : MakeShape(shape.element_type(), dimension_sizes);
+ Shape stripped_shape;
+ if (LayoutUtil::IsDenseArray(shape)) {
+ stripped_shape = MakeShapeWithLayout(shape.element_type(), dimension_sizes,
+ minor_to_major);
+ } else if (LayoutUtil::IsSparseArray(shape)) {
+ stripped_shape =
+ MakeShapeWithSparseLayout(shape.element_type(), dimension_sizes,
+ shape.layout().max_sparse_elements());
+ } else {
+ stripped_shape = MakeShape(shape.element_type(), dimension_sizes);
+ }
VLOG(10) << "Original_shape: " << HumanStringWithLayout(shape);
VLOG(10) << "Stripped_shape: " << HumanStringWithLayout(stripped_shape);
ShapeUtil::MakeShapeWithLayout(F32, {6, 2}, {0, 1})));
}
+TEST(ShapeUtilTest, StripDegenerateDimensions) {
+ EXPECT_TRUE(ShapeUtil::Equal(ShapeUtil::StripDegenerateDimensions(
+ ShapeUtil::MakeShape(F32, {3, 1, 2})),
+ ShapeUtil::MakeShape(F32, {3, 2})));
+ EXPECT_TRUE(ShapeUtil::Equal(
+ ShapeUtil::StripDegenerateDimensions(
+ ShapeUtil::MakeShapeWithSparseLayout(F32, {3, 1, 2}, 10)),
+ ShapeUtil::MakeShapeWithSparseLayout(F32, {3, 2}, 10)));
+}
+
TEST(AlgebraicSimplifierTest, ReshapeIsBitcast_3x2x2_6x2_Dim0IsMostMinor) {
EXPECT_FALSE(ShapeUtil::ReshapeIsBitcast(
ShapeUtil::MakeShapeWithLayout(F32, {3, 2, 2}, {0, 1, 2}),