[mlir][Tensor] Use helper function for `getDroppedDims`
authorMatthias Springer <me@m-sp.org>
Tue, 28 Mar 2023 14:37:03 +0000 (16:37 +0200)
committerMatthias Springer <me@m-sp.org>
Wed, 29 Mar 2023 07:17:28 +0000 (09:17 +0200)
This helper function is used for both ExtractSliceOp and InsertSliceOp. Also fixes a bug in the implementation of `InsertSliceOp::getDroppedDims`.

Differential Revision: https://reviews.llvm.org/D147048

mlir/lib/Dialect/Tensor/IR/TensorOps.cpp

index 93db7da..e7fb287 100644 (file)
@@ -110,6 +110,48 @@ LogicalResult tensor::getOrCreateDestinations(OpBuilder &b, Location loc,
   return success();
 }
 
+/// Compute the dropped dimensions of a rank-reducing tensor.extract_slice op or
+/// rank-extending tensor.insert_slice op.
+static llvm::SmallBitVector getDroppedDims(ArrayRef<int64_t> reducedShape,
+                                           ArrayRef<OpFoldResult> mixedSizes) {
+  llvm::SmallBitVector droppedDims(mixedSizes.size());
+  int64_t shapePos = 0;
+
+  for (const auto &size : enumerate(mixedSizes)) {
+    // Rank-reduced dims must have a static unit dimension.
+    bool isStaticUnitSize =
+        size.value().is<Attribute>() &&
+        size.value().get<Attribute>().cast<IntegerAttr>().getInt() == 1;
+
+    if (shapePos == static_cast<int64_t>(reducedShape.size())) {
+      // There are no more dims in the reduced shape. All remaining sizes must
+      // be rank-reduced dims.
+      assert(isStaticUnitSize && "expected unit dim");
+      droppedDims.set(size.index());
+      continue;
+    }
+
+    // Dim is preserved if the size is not a static 1.
+    if (!isStaticUnitSize) {
+      ++shapePos;
+      continue;
+    }
+
+    // Dim is preserved if the reduced shape dim is also 1.
+    if (reducedShape[shapePos] == 1) {
+      ++shapePos;
+      continue;
+    }
+
+    // Otherwise: Dim is dropped.
+    droppedDims.set(size.index());
+  }
+
+  assert(shapePos == static_cast<int64_t>(reducedShape.size()) &&
+         "dimension mismatch");
+  return droppedDims;
+}
+
 //===----------------------------------------------------------------------===//
 // CastOp
 //===----------------------------------------------------------------------===//
@@ -1740,23 +1782,7 @@ LogicalResult ExtractSliceOp::verify() {
 }
 
 llvm::SmallBitVector ExtractSliceOp::getDroppedDims() {
-  ArrayRef<int64_t> resultShape = getType().getShape();
-  SmallVector<OpFoldResult> mixedSizes = getMixedSizes();
-  llvm::SmallBitVector droppedDims(mixedSizes.size());
-  unsigned shapePos = 0;
-  for (const auto &size : enumerate(mixedSizes)) {
-    std::optional<int64_t> sizeVal = getConstantIntValue(size.value());
-    // If the size is not 1, or if the current matched dimension of the result
-    // is the same static shape as the size value (which is 1), then the
-    // dimension is preserved.
-    if (!sizeVal || *sizeVal != 1 ||
-        (shapePos < resultShape.size() && resultShape[shapePos] == 1)) {
-      shapePos++;
-      continue;
-    }
-    droppedDims.set(size.index());
-  }
-  return droppedDims;
+  return ::getDroppedDims(getType().getShape(), getMixedSizes());
 }
 
 FailureOr<Value>
@@ -2397,23 +2423,7 @@ struct InsertSliceOpSourceCastInserter final
 } // namespace
 
 llvm::SmallBitVector InsertSliceOp::getDroppedDims() {
-  ArrayRef<int64_t> resultShape = getType().getShape();
-  SmallVector<OpFoldResult> mixedSizes = getMixedSizes();
-  llvm::SmallBitVector droppedDims(mixedSizes.size());
-  unsigned shapePos = 0;
-  for (const auto &size : enumerate(mixedSizes)) {
-    std::optional<int64_t> sizeVal = getConstantIntValue(size.value());
-    // If the size is not 1, or if the current matched dimension of the result
-    // is the same static shape as the size value (which is 1), then the
-    // dimension is preserved.
-    if (!sizeVal || *sizeVal != 1 ||
-        (shapePos < resultShape.size() && resultShape[shapePos] == 1)) {
-      shapePos++;
-      continue;
-    }
-    droppedDims.set(size.index());
-  }
-  return droppedDims;
+  return ::getDroppedDims(getSourceType().getShape(), getMixedSizes());
 }
 
 void InsertSliceOp::getCanonicalizationPatterns(RewritePatternSet &results,