From 1ee0d60a9be5dcbe3234b81a1c93e6a206a88154 Mon Sep 17 00:00:00 2001 From: Thomas Raoux Date: Thu, 25 Aug 2022 19:35:26 +0000 Subject: [PATCH] [mlir][tensor] Remove incorrect parallel_insert_slice folder parallel_insert_slice doesn't return a value; therefore we shouldn't try to fold the result. The insert folding doesn't apply to this op. The current folding would cause pattern rewrite to not be able to converge. Differential Revision: https://reviews.llvm.org/D132668 --- mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td | 1 - mlir/lib/Dialect/Tensor/IR/TensorOps.cpp | 92 +++--------------------- mlir/test/Dialect/Tensor/canonicalize.mlir | 21 ++++++ 3 files changed, 31 insertions(+), 83 deletions(-) diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td index 4095d4e..9a0bbf6 100644 --- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td +++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td @@ -1207,7 +1207,6 @@ def Tensor_ParallelInsertSliceOp : Tensor_Op<"parallel_insert_slice", [ ]; let hasCanonicalizer = 1; - let hasFolder = 1; let hasVerifier = 1; } diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp index 060e2fd..cd4c4b9 100644 --- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp +++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp @@ -1552,7 +1552,6 @@ LogicalResult InsertSliceOp::verify() { /// If we have two consecutive InsertSliceOp writing to the same slice, we /// can mutate the second InsertSliceOp's destination to the first one's. -/// This works similarly when the second op is a ParallelInsertSliceOp. /// /// Example: /// @@ -1568,9 +1567,8 @@ LogicalResult InsertSliceOp::verify() { /// ``` /// /// This pattern works with both InsertSliceOp and ParallelInsertSliceOp. 
-template -static LogicalResult foldInsertAfterInsertSlice(InsertOpTy insertOp) { - auto prevInsertOp = insertOp.getDest().template getDefiningOp(); +static LogicalResult foldInsertAfterInsertSlice(InsertSliceOp insertOp) { + auto prevInsertOp = insertOp.getDest().getDefiningOp(); auto isSame = [](OpFoldResult a, OpFoldResult b) { return a == b; }; if (!prevInsertOp || @@ -1582,32 +1580,14 @@ static LogicalResult foldInsertAfterInsertSlice(InsertOpTy insertOp) { return success(); } -/// Same logic for folding InsertSliceOp and ParallelInsertSliceOp, the return -/// type varies though so we wrap it in a FailureOr. -/// -/// This pattern works with both InsertSliceOp and ParallelInsertSliceOp. -template -FailureOr foldInsertOp(InsertOpTy insertOp, ArrayRef) { - if (insertOp.getSourceType().hasStaticShape() && - insertOp.getDestType().hasStaticShape() && - insertOp.getSourceType() == insertOp.getDestType() && - succeeded(foldIdentityOffsetSizeAndStrideOpInterface( - insertOp, insertOp.getDestType()))) - return static_cast(insertOp.getSource()); - if (succeeded(foldInsertAfterInsertSlice(insertOp))) { - // InsertSliceOp has 1 result but ParallelInsertSliceOp has none and should - // return OpFoldResult(). - if (std::is_same::value) - return static_cast(insertOp->getResult(0)); - else - return OpFoldResult(); - } - return failure(); -} - -OpFoldResult InsertSliceOp::fold(ArrayRef operands) { - auto maybeOpFoldResult = foldInsertOp(*this, operands); - return failed(maybeOpFoldResult) ? 
OpFoldResult() : *maybeOpFoldResult; +OpFoldResult InsertSliceOp::fold(ArrayRef) { + if (getSourceType().hasStaticShape() && getType().hasStaticShape() && + getSourceType() == getType() && + succeeded(foldIdentityOffsetSizeAndStrideOpInterface(*this, getType()))) + return this->getSource(); + if (succeeded(foldInsertAfterInsertSlice(*this))) + return getResult(); + return OpFoldResult(); } LogicalResult InsertSliceOp::reifyResultShapes( @@ -2319,58 +2299,6 @@ LogicalResult ParallelInsertSliceOp::verify() { return produceSliceErrorMsg(result, *this, expectedType); } -namespace { -/// Pattern to rewrite a parallel_insert_slice op with constant arguments. -class ParallelInsertSliceOpConstantArgumentFolder final - : public OpRewritePattern { -public: - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(ParallelInsertSliceOp insertSliceOp, - PatternRewriter &rewriter) const override { - // No constant operand, just return. - if (llvm::none_of(insertSliceOp.getOperands(), [](Value operand) { - return matchPattern(operand, matchConstantIndex()); - })) - return failure(); - - // At least one of offsets/sizes/strides is a new constant. - // Form the new list of operands and constant attributes from the - // existing. - SmallVector mixedOffsets(insertSliceOp.getMixedOffsets()); - SmallVector mixedSizes(insertSliceOp.getMixedSizes()); - SmallVector mixedStrides(insertSliceOp.getMixedStrides()); - canonicalizeSubViewPart(mixedOffsets, ShapedType::isDynamicStrideOrOffset); - canonicalizeSubViewPart(mixedSizes, ShapedType::isDynamic); - canonicalizeSubViewPart(mixedStrides, ShapedType::isDynamicStrideOrOffset); - - // Create the new op in canonical form. 
- auto sourceType = - tensor::ExtractSliceOp::inferCanonicalRankReducedResultType( - insertSliceOp.getSourceType().getRank(), - insertSliceOp.getDestType(), mixedOffsets, mixedSizes, - mixedStrides); - Value toInsert = insertSliceOp.getSource(); - if (sourceType != insertSliceOp.getSourceType()) { - OpBuilder::InsertionGuard g(rewriter); - rewriter.setInsertionPoint(insertSliceOp->getParentOp()); - toInsert = rewriter.create(insertSliceOp.getLoc(), - sourceType, toInsert); - } - rewriter.replaceOpWithNewOp( - insertSliceOp, toInsert, insertSliceOp.getDest(), mixedOffsets, - mixedSizes, mixedStrides); - return success(); - } -}; -} // namespace - -LogicalResult -ParallelInsertSliceOp::fold(ArrayRef operands, - SmallVectorImpl &results) { - return foldInsertOp(*this, operands); -} - void ParallelInsertSliceOp::getCanonicalizationPatterns( RewritePatternSet &results, MLIRContext *context) { results.add, diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir index 1eb1a5d..ad50ecb 100644 --- a/mlir/test/Dialect/Tensor/canonicalize.mlir +++ b/mlir/test/Dialect/Tensor/canonicalize.mlir @@ -1466,3 +1466,24 @@ func.func @canonicalize_parallel_insert_slice_indices( } return %2 : tensor } + +// ----- + +// CHECK-LABEL: func.func @dont_fold_parallel_insert_slice( +// CHECK-SAME: %[[arg0:[0-9a-z]*]]: tensor<1x5xf32>, +// CHECK-SAME: %[[arg1:[0-9a-z]*]]: tensor<1x5xf32>) +func.func @dont_fold_parallel_insert_slice( + %arg0 : tensor<1x5xf32>, %arg1: tensor<1x5xf32>) -> tensor<1x5xf32> +{ + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + // CHECK: scf.foreach_thread () in () -> (tensor<1x5xf32>) { + // CHECK-NEXT: scf.foreach_thread.perform_concurrently { + // CHECK-NEXT: tensor.parallel_insert_slice %[[arg0]] into %[[arg1]][0, 0] [1, 5] [1, 1] : tensor<1x5xf32> into tensor<1x5xf32> + %2 = scf.foreach_thread () in () -> (tensor<1x5xf32>) { + scf.foreach_thread.perform_concurrently { + tensor.parallel_insert_slice 
%arg0 into %arg1[%c0, %c0] [1, 5] [%c1, %c1] : tensor<1x5xf32> into tensor<1x5xf32> + } + } + return %2 : tensor<1x5xf32> +} -- 2.7.4