[mlir][linalg] Remove duplicate tensor.pad lowering pattern
authorMatthias Springer <me@m-sp.org>
Thu, 22 Jun 2023 09:30:57 +0000 (11:30 +0200)
committerMatthias Springer <me@m-sp.org>
Thu, 22 Jun 2023 09:35:27 +0000 (11:35 +0200)
There is another transform that lowers tensor.pad to tensor.empty + linalg.fill + tensor.insert_slice: `transform.structured.rewrite_in_destination_passing_style`. Delete the other transform.

Differential Revision: https://reviews.llvm.org/D153429

mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
mlir/test/Dialect/Linalg/lower-pad-tensor.mlir [deleted file]
mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp

index 22137ed..9591f0b 100644 (file)
@@ -1151,15 +1151,6 @@ struct CopyVectorizationPattern : public OpRewritePattern<memref::CopyOp> {
                                 PatternRewriter &rewriter) const override;
 };
 
-/// tensor::PadOp is not canonicalized away yet, so we provide a
-/// transformation to `linalg.generic`.
-struct PadOpTransformationPattern : public OpRewritePattern<tensor::PadOp> {
-  using OpRewritePattern<tensor::PadOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(tensor::PadOp padOp,
-                                PatternRewriter &rewriter) const override;
-};
-
 using OptimizeCopyFn =
     std::function<LogicalResult(RewriterBase &, tensor::PadOp, Value)>;
 
index 5fd9228..9044fea 100644 (file)
@@ -139,12 +139,6 @@ static FailureOr<Value> padOperandToSmallestStaticBoundingBox(
                                opOperand->get(), paddingValue, nofold);
 }
 
-static SmallVector<utils::IteratorType>
-getNParallelLoopsAttrs(unsigned nParallelLoops) {
-  return SmallVector<utils::IteratorType>(nParallelLoops,
-                                          utils::IteratorType::parallel);
-}
-
 //===----------------------------------------------------------------------===//
 // Transformations exposed as functional-style API calls.
 //===----------------------------------------------------------------------===//
@@ -1028,71 +1022,6 @@ LogicalResult mlir::linalg::CopyVectorizationPattern::matchAndRewrite(
   return vectorizeCopy(rewriter, copyOp);
 }
 
-///
-/// Pattern to rewrite a tensor::PadOp into a sequence of EmptyOp, FillOp (to
-/// initialize with pad_val) and GenericOp (to copy contents).
-///
-LogicalResult
-PadOpTransformationPattern::matchAndRewrite(tensor::PadOp padOp,
-                                            PatternRewriter &rewriter) const {
-
-  auto inputShapedType = cast<ShapedType>(padOp.getSource().getType());
-  auto resultShapedType = cast<ShapedType>(padOp.getResult().getType());
-
-  // Bail on non-static shapes.
-  if (!inputShapedType.hasStaticShape())
-    return failure();
-  if (!resultShapedType.hasStaticShape())
-    return failure();
-
-  // Only support padding with a constant for now, i.e. either:
-  //   1. A BBarg from a different block.
-  //   2. A value defined outside of the current block.
-  Block &block = padOp.getRegion().front();
-  auto yieldOp = cast<tensor::YieldOp>(block.getTerminator());
-  Value padValue = yieldOp.getValue();
-  Operation *definingOp = padValue.getDefiningOp();
-  if (definingOp && definingOp->getBlock() == &block)
-    return failure();
-  if (!definingOp && cast<BlockArgument>(padValue).getOwner() == &block)
-    return failure();
-
-  // Create tensor with the padded shape
-  Location loc = padOp.getLoc();
-  SmallVector<Value> indices(resultShapedType.getRank(),
-                             rewriter.create<arith::ConstantIndexOp>(loc, 0));
-  Value emptyTensor = rewriter.create<tensor::EmptyOp>(
-      loc, resultShapedType.getShape(), resultShapedType.getElementType());
-
-  // Initialize tensor with the pad value
-  Value tmpTensor = rewriter
-                        .create<linalg::FillOp>(loc, ValueRange{padValue},
-                                                ValueRange{emptyTensor})
-                        .result();
-
-  // Copy original contents into new tensor
-  // Uses linalg.generic, but could be done with tensor.insert_slice
-  SmallVector<AffineExpr, 4> outputExprs;
-  for (unsigned i = 0; i < resultShapedType.getRank(); ++i) {
-    outputExprs.push_back(getAffineDimExpr(i, rewriter.getContext()) +
-                          padOp.getStaticLow()[i]);
-  }
-
-  SmallVector<AffineMap, 2> transferMaps = {
-      rewriter.getMultiDimIdentityMap(inputShapedType.getRank()),
-      AffineMap::get(resultShapedType.getRank(),
-                     /*symbolCount=*/0, outputExprs, rewriter.getContext())};
-
-  rewriter.replaceOpWithNewOp<linalg::GenericOp>(
-      padOp, resultShapedType, padOp.getSource(), tmpTensor, transferMaps,
-      getNParallelLoopsAttrs(resultShapedType.getRank()),
-      [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
-        nestedBuilder.create<linalg::YieldOp>(nestedLoc, args[0]);
-      });
-
-  return success();
-}
-
 /// Filling `dest` using FillOp constant padding value if possible.
 /// Otherwise, generate a tensor::GenerateOp.
 Value GeneralizePadOpPattern::createFillOrGenerateOp(
diff --git a/mlir/test/Dialect/Linalg/lower-pad-tensor.mlir b/mlir/test/Dialect/Linalg/lower-pad-tensor.mlir
deleted file mode 100644 (file)
index 6df26b9..0000000
+++ /dev/null
@@ -1,63 +0,0 @@
-// RUN: mlir-opt -split-input-file --test-linalg-transform-patterns="test-transform-pad-tensor"  %s | FileCheck --check-prefix=CHECK %s
-
-// CHECK-DAG:   #[[$MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
-// CHECK-DAG:   #[[$MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d0 + 1, d1 + 1, d2 + 1, d3 + 2)>
-// CHECK-LABEL: func @pad_tensor_with_memrefs
-func.func @pad_tensor_with_memrefs(%arg0: memref<1x28x28x1xf32>) -> memref<2x31x31x3xf32> {
-  %cst = arith.constant 0.000000e+00 : f32
-  %0 = bufferization.to_tensor %arg0 : memref<1x28x28x1xf32>
-  %1 = tensor.pad %0 low[1, 1, 1, 2] high[0, 2, 2, 0]  {
-  ^bb0(%arg1: index, %arg2: index, %arg3: index, %arg4: index):
-    tensor.yield %cst : f32
-  } : tensor<1x28x28x1xf32> to tensor<2x31x31x3xf32>
-  %2 = bufferization.to_memref %1 : memref<2x31x31x3xf32>
-  return %2 : memref<2x31x31x3xf32>
-}
-
-// CHECK:       linalg.fill
-// CHECK:       linalg.generic
-// CHECK-SAME:     indexing_maps = [#[[$MAP0]], #[[$MAP1]]]
-
-// -----
-
-// CHECK-DAG:   #[[$MAP2:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
-// CHECK-DAG:   #[[$MAP3:.*]] = affine_map<(d0, d1, d2) -> (d0 + 1, d1 + 2, d2 + 2)>
-// CHECK-LABEL: func @pad_tensor_no_memrefs
-func.func @pad_tensor_no_memrefs(%arg0: tensor<1x28x28xf32>) -> tensor<2x32x32xf32> {
-  %cst = arith.constant 0.000000e+00 : f32
-  %0 = tensor.pad %arg0 low[1, 2, 2] high[0, 2, 2]  {
-  ^bb0(%arg1: index, %arg2: index, %arg3: index):
-    tensor.yield %cst : f32
-  } : tensor<1x28x28xf32> to tensor<2x32x32xf32>
-  return %0 : tensor<2x32x32xf32>
-}
-
-// CHECK:       linalg.fill
-// CHECK:       linalg.generic
-// CHECK-SAME:      indexing_maps = [#[[$MAP2]], #[[$MAP3]]]
-
-// -----
-
-// CHECK-DAG:   #[[$MAP4:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
-// CHECK-DAG:   #[[$MAP5:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1 + 2, d2 + 2, d3)>
-// CHECK-LABEL: func @pad_tensor_detailed
-func.func @pad_tensor_detailed(%arg0: tensor<1x28x28x1xf32>) -> tensor<1x32x32x1xf32> {
-  %cst = arith.constant 0.000000e+00 : f32
-  %0 = tensor.pad %arg0 low[0, 2, 2, 0] high[0, 2, 2, 0]  {
-  ^bb0(%arg1: index, %arg2: index, %arg3: index, %arg4: index):
-    tensor.yield %cst : f32
-  } : tensor<1x28x28x1xf32> to tensor<1x32x32x1xf32>
-  return %0 : tensor<1x32x32x1xf32>
-}
-
-// CHECK:      %[[ARG0:[a-zA-Z0-9_]+]]: tensor<1x28x28x1xf32>) -> tensor<1x32x32x1xf32>
-// CHECK:      %[[CTE:.+]] = arith.constant 0.000000e+00 : f32
-// CHECK:      %[[TMP:.+]] = tensor.empty() : tensor<1x32x32x1xf32>
-// CHECK:      %[[R1c:.+]] = linalg.fill
-// CHECK:      %[[R2c:.+]] = linalg.generic
-// CHECK-SAME:   indexing_maps = [#[[$MAP4]], #[[$MAP5]]]
-// CHECK-SAME:   iterator_types = ["parallel", "parallel", "parallel", "parallel"]
-// CHECK:        ins(%{{.*}} : tensor<1x28x28x1xf32>) outs(%{{.*}} : tensor<1x32x32x1xf32>)
-// CHECK:      ^bb0(%[[VAL:.+]]: f32, %{{.*}}: f32)
-// CHECK:        linalg.yield %[[VAL]] : f32
-// CHECK:      return %[[R2c:.+]]
index c1d01ac..4892fa2 100644 (file)
@@ -70,10 +70,6 @@ struct TestLinalgTransforms
       llvm::cl::desc("Test a set of patterns that rewrite a linalg contraction "
                      "in vector.contract form"),
       llvm::cl::init(false)};
-  Option<bool> testTransformPadTensor{
-      *this, "test-transform-pad-tensor",
-      llvm::cl::desc("Test transform pad tensor by copying with generic ops"),
-      llvm::cl::init(false)};
   Option<bool> testGeneralizePadTensor{
       *this, "test-generalize-pad-tensor",
       llvm::cl::desc("Test transform pad tensor by copying with generic ops"),
@@ -163,12 +159,6 @@ static void applyLinalgToVectorPatterns(func::FuncOp funcOp) {
   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
 }
 
-static void applyPadTensorToGenericPatterns(func::FuncOp funcOp) {
-  RewritePatternSet patterns(funcOp.getContext());
-  patterns.add<PadOpTransformationPattern>(funcOp.getContext());
-  (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
-}
-
 static void applyGeneralizePadTensorPatterns(func::FuncOp funcOp) {
   RewritePatternSet patterns(funcOp.getContext());
   patterns.add<GeneralizePadOpPattern>(funcOp.getContext());
@@ -225,8 +215,6 @@ void TestLinalgTransforms::runOnOperation() {
     return applyVectorTransferForwardingPatterns(getOperation());
   if (testGenericToVectorPattern)
     return applyLinalgToVectorPatterns(getOperation());
-  if (testTransformPadTensor)
-    return applyPadTensorToGenericPatterns(getOperation());
   if (testGeneralizePadTensor)
     return applyGeneralizePadTensorPatterns(getOperation());
   if (testGeneralizeTensorPackOp)