/// If we have two consecutive InsertSliceOp writing to the same slice, we
/// can mutate the second InsertSliceOp's destination to the first one's.
-/// This works similarly when the second op is a ParallelInsertSliceOp.
///
/// Example:
///
/// ```
///
/// This pattern works with both InsertSliceOp and ParallelInsertSliceOp.
-template <typename InsertOpTy>
-static LogicalResult foldInsertAfterInsertSlice(InsertOpTy insertOp) {
- auto prevInsertOp = insertOp.getDest().template getDefiningOp<InsertOpTy>();
+static LogicalResult foldInsertAfterInsertSlice(InsertSliceOp insertOp) {
+ auto prevInsertOp = insertOp.getDest().getDefiningOp<InsertSliceOp>();
auto isSame = [](OpFoldResult a, OpFoldResult b) { return a == b; };
if (!prevInsertOp ||
return success();
}
-/// Same logic for folding InsertSliceOp and ParallelInsertSliceOp, the return
-/// type varies though so we wrap it in a FailureOr.
-///
-/// This pattern works with both InsertSliceOp and ParallelInsertSliceOp.
-template <typename InsertOpTy>
-FailureOr<OpFoldResult> foldInsertOp(InsertOpTy insertOp, ArrayRef<Attribute>) {
- if (insertOp.getSourceType().hasStaticShape() &&
- insertOp.getDestType().hasStaticShape() &&
- insertOp.getSourceType() == insertOp.getDestType() &&
- succeeded(foldIdentityOffsetSizeAndStrideOpInterface(
- insertOp, insertOp.getDestType())))
- return static_cast<OpFoldResult>(insertOp.getSource());
- if (succeeded(foldInsertAfterInsertSlice(insertOp))) {
- // InsertSliceOp has 1 result but ParallelInsertSliceOp has none and should
- // return OpFoldResult().
- if (std::is_same<InsertOpTy, InsertSliceOp>::value)
- return static_cast<OpFoldResult>(insertOp->getResult(0));
- else
- return OpFoldResult();
- }
- return failure();
-}
-
-OpFoldResult InsertSliceOp::fold(ArrayRef<Attribute> operands) {
- auto maybeOpFoldResult = foldInsertOp(*this, operands);
- return failed(maybeOpFoldResult) ? OpFoldResult() : *maybeOpFoldResult;
+OpFoldResult InsertSliceOp::fold(ArrayRef<Attribute>) {
+ if (getSourceType().hasStaticShape() && getType().hasStaticShape() &&
+ getSourceType() == getType() &&
+ succeeded(foldIdentityOffsetSizeAndStrideOpInterface(*this, getType())))
+ return this->getSource();
+ if (succeeded(foldInsertAfterInsertSlice(*this)))
+ return getResult();
+ return OpFoldResult();
}
LogicalResult InsertSliceOp::reifyResultShapes(
return produceSliceErrorMsg(result, *this, expectedType);
}
-namespace {
-/// Pattern to rewrite a parallel_insert_slice op with constant arguments.
-class ParallelInsertSliceOpConstantArgumentFolder final
- : public OpRewritePattern<ParallelInsertSliceOp> {
-public:
- using OpRewritePattern<ParallelInsertSliceOp>::OpRewritePattern;
-
- LogicalResult matchAndRewrite(ParallelInsertSliceOp insertSliceOp,
- PatternRewriter &rewriter) const override {
- // No constant operand, just return.
- if (llvm::none_of(insertSliceOp.getOperands(), [](Value operand) {
- return matchPattern(operand, matchConstantIndex());
- }))
- return failure();
-
- // At least one of offsets/sizes/strides is a new constant.
- // Form the new list of operands and constant attributes from the
- // existing.
- SmallVector<OpFoldResult> mixedOffsets(insertSliceOp.getMixedOffsets());
- SmallVector<OpFoldResult> mixedSizes(insertSliceOp.getMixedSizes());
- SmallVector<OpFoldResult> mixedStrides(insertSliceOp.getMixedStrides());
- canonicalizeSubViewPart(mixedOffsets, ShapedType::isDynamicStrideOrOffset);
- canonicalizeSubViewPart(mixedSizes, ShapedType::isDynamic);
- canonicalizeSubViewPart(mixedStrides, ShapedType::isDynamicStrideOrOffset);
-
- // Create the new op in canonical form.
- auto sourceType =
- tensor::ExtractSliceOp::inferCanonicalRankReducedResultType(
- insertSliceOp.getSourceType().getRank(),
- insertSliceOp.getDestType(), mixedOffsets, mixedSizes,
- mixedStrides);
- Value toInsert = insertSliceOp.getSource();
- if (sourceType != insertSliceOp.getSourceType()) {
- OpBuilder::InsertionGuard g(rewriter);
- rewriter.setInsertionPoint(insertSliceOp->getParentOp());
- toInsert = rewriter.create<tensor::CastOp>(insertSliceOp.getLoc(),
- sourceType, toInsert);
- }
- rewriter.replaceOpWithNewOp<ParallelInsertSliceOp>(
- insertSliceOp, toInsert, insertSliceOp.getDest(), mixedOffsets,
- mixedSizes, mixedStrides);
- return success();
- }
-};
-} // namespace
-
-LogicalResult
-ParallelInsertSliceOp::fold(ArrayRef<Attribute> operands,
- SmallVectorImpl<OpFoldResult> &results) {
- return foldInsertOp(*this, operands);
-}
-
void ParallelInsertSliceOp::getCanonicalizationPatterns(
RewritePatternSet &results, MLIRContext *context) {
results.add<InsertSliceOpConstantArgumentFolder<ParallelInsertSliceOp>,
}
return %2 : tensor<?x?xf32>
}
+
+// -----
+
+// CHECK-LABEL: func.func @dont_fold_parallel_insert_slice(
+// CHECK-SAME: %[[arg0:[0-9a-z]*]]: tensor<1x5xf32>,
+// CHECK-SAME: %[[arg1:[0-9a-z]*]]: tensor<1x5xf32>)
+func.func @dont_fold_parallel_insert_slice(
+ %arg0 : tensor<1x5xf32>, %arg1: tensor<1x5xf32>) -> tensor<1x5xf32>
+{
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ // CHECK: scf.foreach_thread () in () -> (tensor<1x5xf32>) {
+ // CHECK-NEXT: scf.foreach_thread.perform_concurrently {
+ // CHECK-NEXT: tensor.parallel_insert_slice %[[arg0]] into %[[arg1]][0, 0] [1, 5] [1, 1] : tensor<1x5xf32> into tensor<1x5xf32>
+ %2 = scf.foreach_thread () in () -> (tensor<1x5xf32>) {
+ scf.foreach_thread.perform_concurrently {
+ tensor.parallel_insert_slice %arg0 into %arg1[%c0, %c0] [1, 5] [%c1, %c1] : tensor<1x5xf32> into tensor<1x5xf32>
+ }
+ }
+ return %2 : tensor<1x5xf32>
+}