[XLA] Fix bug in ShapeUtil::StripDegenerateDimensions
authorA. Unique TensorFlower <gardener@tensorflow.org>
Sat, 28 Apr 2018 01:24:57 +0000 (18:24 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Sat, 28 Apr 2018 01:27:28 +0000 (18:27 -0700)
PiperOrigin-RevId: 194621163

tensorflow/compiler/xla/shape_util.cc
tensorflow/compiler/xla/shape_util_test.cc

index ac7e201..d58baa3 100644 (file)
@@ -905,10 +905,17 @@ bool ShapeUtil::IsLeafIndex(const Shape& shape, const ShapeIndex& index) {
            std::is_permutation(minor_to_major.begin(), minor_to_major.end(),
                                dims.begin()));
   }
-  Shape stripped_shape =
-      shape.has_layout() ? MakeShapeWithLayout(shape.element_type(),
-                                               dimension_sizes, minor_to_major)
-                         : MakeShape(shape.element_type(), dimension_sizes);
+  Shape stripped_shape;
+  if (LayoutUtil::IsDenseArray(shape)) {
+    stripped_shape = MakeShapeWithLayout(shape.element_type(), dimension_sizes,
+                                         minor_to_major);
+  } else if (LayoutUtil::IsSparseArray(shape)) {
+    stripped_shape =
+        MakeShapeWithSparseLayout(shape.element_type(), dimension_sizes,
+                                  shape.layout().max_sparse_elements());
+  } else {
+    stripped_shape = MakeShape(shape.element_type(), dimension_sizes);
+  }
 
   VLOG(10) << "Original_shape: " << HumanStringWithLayout(shape);
   VLOG(10) << "Stripped_shape: " << HumanStringWithLayout(stripped_shape);
index 13582a2..f7675e9 100644 (file)
@@ -713,6 +713,16 @@ TEST(ShapeUtilTest, ReshapeIsBitcast_3x2x2_6x2_Dim1IsMostMinor) {
       ShapeUtil::MakeShapeWithLayout(F32, {6, 2}, {0, 1})));
 }
 
+TEST(ShapeUtilTest, StripDegenerateDimensions) {
+  EXPECT_TRUE(ShapeUtil::Equal(ShapeUtil::StripDegenerateDimensions(
+                                   ShapeUtil::MakeShape(F32, {3, 1, 2})),
+                               ShapeUtil::MakeShape(F32, {3, 2})));
+  EXPECT_TRUE(ShapeUtil::Equal(
+      ShapeUtil::StripDegenerateDimensions(
+          ShapeUtil::MakeShapeWithSparseLayout(F32, {3, 1, 2}, 10)),
+      ShapeUtil::MakeShapeWithSparseLayout(F32, {3, 2}, 10)));
+}
+
 TEST(AlgebraicSimplifierTest, ReshapeIsBitcast_3x2x2_6x2_Dim0IsMostMinor) {
   EXPECT_FALSE(ShapeUtil::ReshapeIsBitcast(
       ShapeUtil::MakeShapeWithLayout(F32, {3, 2, 2}, {0, 1, 2}),