[mlir][tensor] Remove incorrect parallel_insert_slice folder
authorThomas Raoux <thomasraoux@google.com>
Thu, 25 Aug 2022 19:35:26 +0000 (19:35 +0000)
committerThomas Raoux <thomasraoux@google.com>
Fri, 26 Aug 2022 15:27:54 +0000 (15:27 +0000)
parallel_insert_slice doesn't return a value therefore we shouldn't try
to fold the result. The insert folding don't apply to this op.
The current folding would cause pattern rewrite to not be able to
converge.

Differential Revision: https://reviews.llvm.org/D132668

mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
mlir/test/Dialect/Tensor/canonicalize.mlir

index 4095d4e..9a0bbf6 100644 (file)
@@ -1207,7 +1207,6 @@ def Tensor_ParallelInsertSliceOp : Tensor_Op<"parallel_insert_slice", [
   ];
 
   let hasCanonicalizer = 1;
-  let hasFolder = 1;
   let hasVerifier = 1;
 }
 
index 060e2fd..cd4c4b9 100644 (file)
@@ -1552,7 +1552,6 @@ LogicalResult InsertSliceOp::verify() {
 
 /// If we have two consecutive InsertSliceOp writing to the same slice, we
 /// can mutate the second InsertSliceOp's destination to the first one's.
-/// This works similarly when the second op is a ParallelInsertSliceOp.
 ///
 /// Example:
 ///
@@ -1568,9 +1567,8 @@ LogicalResult InsertSliceOp::verify() {
 /// ```
 ///
 /// This pattern works with both InsertSliceOp and ParallelInsertSliceOp.
-template <typename InsertOpTy>
-static LogicalResult foldInsertAfterInsertSlice(InsertOpTy insertOp) {
-  auto prevInsertOp = insertOp.getDest().template getDefiningOp<InsertOpTy>();
+static LogicalResult foldInsertAfterInsertSlice(InsertSliceOp insertOp) {
+  auto prevInsertOp = insertOp.getDest().getDefiningOp<InsertSliceOp>();
 
   auto isSame = [](OpFoldResult a, OpFoldResult b) { return a == b; };
   if (!prevInsertOp ||
@@ -1582,32 +1580,14 @@ static LogicalResult foldInsertAfterInsertSlice(InsertOpTy insertOp) {
   return success();
 }
 
-/// Same logic for folding InsertSliceOp and ParallelInsertSliceOp, the return
-/// type varies though so we wrap it in a FailureOr.
-///
-/// This pattern works with both InsertSliceOp and ParallelInsertSliceOp.
-template <typename InsertOpTy>
-FailureOr<OpFoldResult> foldInsertOp(InsertOpTy insertOp, ArrayRef<Attribute>) {
-  if (insertOp.getSourceType().hasStaticShape() &&
-      insertOp.getDestType().hasStaticShape() &&
-      insertOp.getSourceType() == insertOp.getDestType() &&
-      succeeded(foldIdentityOffsetSizeAndStrideOpInterface(
-          insertOp, insertOp.getDestType())))
-    return static_cast<OpFoldResult>(insertOp.getSource());
-  if (succeeded(foldInsertAfterInsertSlice(insertOp))) {
-    // InsertSliceOp has 1 result but ParallelInsertSliceOp has none and should
-    // return OpFoldResult().
-    if (std::is_same<InsertOpTy, InsertSliceOp>::value)
-      return static_cast<OpFoldResult>(insertOp->getResult(0));
-    else
-      return OpFoldResult();
-  }
-  return failure();
-}
-
-OpFoldResult InsertSliceOp::fold(ArrayRef<Attribute> operands) {
-  auto maybeOpFoldResult = foldInsertOp(*this, operands);
-  return failed(maybeOpFoldResult) ? OpFoldResult() : *maybeOpFoldResult;
+OpFoldResult InsertSliceOp::fold(ArrayRef<Attribute>) {
+  if (getSourceType().hasStaticShape() && getType().hasStaticShape() &&
+      getSourceType() == getType() &&
+      succeeded(foldIdentityOffsetSizeAndStrideOpInterface(*this, getType())))
+    return this->getSource();
+  if (succeeded(foldInsertAfterInsertSlice(*this)))
+    return getResult();
+  return OpFoldResult();
 }
 
 LogicalResult InsertSliceOp::reifyResultShapes(
@@ -2319,58 +2299,6 @@ LogicalResult ParallelInsertSliceOp::verify() {
   return produceSliceErrorMsg(result, *this, expectedType);
 }
 
-namespace {
-/// Pattern to rewrite a parallel_insert_slice op with constant arguments.
-class ParallelInsertSliceOpConstantArgumentFolder final
-    : public OpRewritePattern<ParallelInsertSliceOp> {
-public:
-  using OpRewritePattern<ParallelInsertSliceOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(ParallelInsertSliceOp insertSliceOp,
-                                PatternRewriter &rewriter) const override {
-    // No constant operand, just return.
-    if (llvm::none_of(insertSliceOp.getOperands(), [](Value operand) {
-          return matchPattern(operand, matchConstantIndex());
-        }))
-      return failure();
-
-    // At least one of offsets/sizes/strides is a new constant.
-    // Form the new list of operands and constant attributes from the
-    // existing.
-    SmallVector<OpFoldResult> mixedOffsets(insertSliceOp.getMixedOffsets());
-    SmallVector<OpFoldResult> mixedSizes(insertSliceOp.getMixedSizes());
-    SmallVector<OpFoldResult> mixedStrides(insertSliceOp.getMixedStrides());
-    canonicalizeSubViewPart(mixedOffsets, ShapedType::isDynamicStrideOrOffset);
-    canonicalizeSubViewPart(mixedSizes, ShapedType::isDynamic);
-    canonicalizeSubViewPart(mixedStrides, ShapedType::isDynamicStrideOrOffset);
-
-    // Create the new op in canonical form.
-    auto sourceType =
-        tensor::ExtractSliceOp::inferCanonicalRankReducedResultType(
-            insertSliceOp.getSourceType().getRank(),
-            insertSliceOp.getDestType(), mixedOffsets, mixedSizes,
-            mixedStrides);
-    Value toInsert = insertSliceOp.getSource();
-    if (sourceType != insertSliceOp.getSourceType()) {
-      OpBuilder::InsertionGuard g(rewriter);
-      rewriter.setInsertionPoint(insertSliceOp->getParentOp());
-      toInsert = rewriter.create<tensor::CastOp>(insertSliceOp.getLoc(),
-                                                 sourceType, toInsert);
-    }
-    rewriter.replaceOpWithNewOp<ParallelInsertSliceOp>(
-        insertSliceOp, toInsert, insertSliceOp.getDest(), mixedOffsets,
-        mixedSizes, mixedStrides);
-    return success();
-  }
-};
-} // namespace
-
-LogicalResult
-ParallelInsertSliceOp::fold(ArrayRef<Attribute> operands,
-                            SmallVectorImpl<OpFoldResult> &results) {
-  return foldInsertOp(*this, operands);
-}
-
 void ParallelInsertSliceOp::getCanonicalizationPatterns(
     RewritePatternSet &results, MLIRContext *context) {
   results.add<InsertSliceOpConstantArgumentFolder<ParallelInsertSliceOp>,
index 1eb1a5d..ad50ecb 100644 (file)
@@ -1466,3 +1466,24 @@ func.func @canonicalize_parallel_insert_slice_indices(
   }
   return %2 : tensor<?x?xf32>
 }
+
+// -----
+
+// CHECK-LABEL: func.func @dont_fold_parallel_insert_slice(
+//  CHECK-SAME:     %[[arg0:[0-9a-z]*]]: tensor<1x5xf32>, 
+//  CHECK-SAME:     %[[arg1:[0-9a-z]*]]: tensor<1x5xf32>)
+func.func @dont_fold_parallel_insert_slice(
+    %arg0 : tensor<1x5xf32>, %arg1: tensor<1x5xf32>) -> tensor<1x5xf32>
+{
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  //      CHECK: scf.foreach_thread () in () -> (tensor<1x5xf32>) {
+  // CHECK-NEXT:   scf.foreach_thread.perform_concurrently {
+  // CHECK-NEXT:     tensor.parallel_insert_slice %[[arg0]] into %[[arg1]][0, 0] [1, 5] [1, 1] : tensor<1x5xf32> into tensor<1x5xf32>
+  %2 = scf.foreach_thread () in ()  -> (tensor<1x5xf32>) {
+    scf.foreach_thread.perform_concurrently {
+      tensor.parallel_insert_slice %arg0 into %arg1[%c0, %c0] [1, 5] [%c1, %c1] : tensor<1x5xf32> into tensor<1x5xf32>
+    }
+  }
+  return %2 : tensor<1x5xf32>
+}