[mlir][tensor] TrackingListener: Find replacement ops through cast-like ExtractSliceOps
authorMatthias Springer <me@m-sp.org>
Thu, 1 Jun 2023 07:00:08 +0000 (09:00 +0200)
committerMatthias Springer <me@m-sp.org>
Thu, 1 Jun 2023 07:00:56 +0000 (09:00 +0200)
Certain ExtractSliceOps, that do extract all elements from the destination, are treated like casts when looking for replacement ops. Such ExtractSliceOps are typically rank expansions.

Differential Revision: https://reviews.llvm.org/D151804

mlir/include/mlir/Dialect/Tensor/Utils/Utils.h
mlir/lib/Dialect/Tensor/TransformOps/TensorTransformOps.cpp
mlir/lib/Dialect/Tensor/Utils/Utils.cpp
mlir/test/Dialect/Tensor/tracking-listener.mlir

index a037d40..c610b5d 100644 (file)
@@ -47,6 +47,10 @@ computeTransposedType(RankedTensorType rankedTensorType,
 /// the same shape.
 bool isCastLikeInsertSliceOp(InsertSliceOp op);
 
+/// A tensor.extract_slice is a cast-like operation if it merely rank-reduces
+/// the source tensor or extracts the entire source tensor.
+bool isCastLikeExtractSliceOp(ExtractSliceOp op);
+
 } // namespace tensor
 } // namespace mlir
 
index 9b609a2..09a6b50 100644 (file)
@@ -38,8 +38,6 @@ tensor::TrackingListener::findReplacementOp(Operation *op,
       return nullptr;
 
     // Skip cast-like operations.
-    // TODO: CastOpInterface could be used if CollapseShapeOp and ExpandShapeOp
-    // implement that interface
     values.clear();
     llvm::TypeSwitch<Operation *>(defOp)
         .Case<CastOp>([&](CastOp op) { values.push_back(op.getSource()); })
@@ -53,6 +51,10 @@ tensor::TrackingListener::findReplacementOp(Operation *op,
           if (isCastLikeInsertSliceOp(op))
             values.push_back(op.getSource());
         })
+        .Case<ExtractSliceOp>([&](ExtractSliceOp op) {
+          if (isCastLikeExtractSliceOp(op))
+            values.push_back(op.getSource());
+        })
         .Default([](Operation *op) {});
   } while (!values.empty());
 
index 165cf9b..4d5404a 100644 (file)
@@ -123,3 +123,22 @@ bool mlir::tensor::isCastLikeInsertSliceOp(InsertSliceOp op) {
 
   return true;
 }
+
+bool mlir::tensor::isCastLikeExtractSliceOp(ExtractSliceOp op) {
+  llvm::SmallBitVector droppedDims = op.getDroppedDims();
+  int64_t resultDim = 0;
+  // Source dims and result dims (apart from dropped dims) must have the same
+  // size.
+  for (int64_t dim = 0; dim < op.getSourceType().getRank(); ++dim) {
+    if (droppedDims.test(dim)) {
+      continue;
+    }
+    FailureOr<bool> equalDimSize = ValueBoundsConstraintSet::areEqual(
+        op.getSource(), op.getResult(), dim, resultDim);
+    if (failed(equalDimSize) || !*equalDimSize)
+      return false;
+    ++resultDim;
+  }
+
+  return true;
+}
index 369dcec..6341b7a 100644 (file)
@@ -105,3 +105,38 @@ func.func @cast_like_insert_slice_dynamic(
       {replacement_0 = 0} : tensor<?xf32> into tensor<1x?x1xf32>
   return
 }
+
+// -----
+
+func.func @cast_like_extract_slice() {
+  %0 = "test.foo"() {replaced} : () -> (tensor<5xf32>)
+  // expected-remark @below {{replacement found}}
+  %1 = "test.foo"() : () -> (tensor<1x5x1x1xf32>)
+  %2 = tensor.extract_slice %1[0, 0, 0, 0][1, 5, 1, 1][1, 1, 1, 1]
+      {replacement_0 = 0} : tensor<1x5x1x1xf32> to tensor<5xf32>
+  return
+}
+
+// -----
+
+func.func @cast_like_extract_slice_dynamic() {
+  %0 = "test.foo"() {replaced} : () -> (tensor<?xf32>)
+  // expected-remark @below {{replacement found}}
+  %1 = "test.foo"() : () -> (tensor<1x?x1x1xf32>)
+  %c1 = arith.constant 1 : index
+  %dim = tensor.dim %1, %c1 : tensor<1x?x1x1xf32>
+  %2 = tensor.extract_slice %1[0, 0, 0, 0][1, %dim, 1, 1][1, 1, 1, 1]
+      {replacement_0 = 0} : tensor<1x?x1x1xf32> to tensor<?xf32>
+  return
+}
+
+// -----
+
+func.func @non_cast_like_extract_slice() {
+  // expected-error @below {{listener could not find replacement op}}
+  %0 = "test.foo"() {replaced} : () -> (tensor<5xf32>)
+  %1 = "test.foo"() : () -> (tensor<1x5x1x1xf32>)
+  %2 = tensor.extract_slice %1[0, 0, 0, 0][1, 3, 1, 1][1, 1, 1, 1]
+      {replacement_0 = 0} : tensor<1x5x1x1xf32> to tensor<3xf32>
+  return
+}