[mlir][linalg] Finer-grained padding control.
authorTobias Gysi <gysit@google.com>
Mon, 27 Sep 2021 19:20:56 +0000 (19:20 +0000)
committerTobias Gysi <gysit@google.com>
Mon, 27 Sep 2021 19:21:37 +0000 (19:21 +0000)
Adapt the signature of the PaddingValueComputationFunction callback to either return the padding value or failure to signal padding is not desired.

Reviewed By: nicolasvasilache

Differential Revision: https://reviews.llvm.org/D110572

mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
mlir/test/Dialect/Linalg/tile-and-pad-tensors.mlir
mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp

index 8530b89..03843bd 100644 (file)
@@ -438,10 +438,11 @@ private:
 using TileSizeComputationFunction =
     std::function<SmallVector<Value, 4>(OpBuilder &, Operation *)>;
 
-/// Specify the padding value for an OpOperand. This should be a function of
-/// both the operation and the operand type.
+/// Callback returning the padding value to use for a given OpOperand or failure
+/// for no padding. This should be a function of both the operation and the
+/// operand type.
 using PaddingValueComputationFunction =
-    std::function<Value(OpBuilder &, OpOperand &)>;
+    std::function<FailureOr<Value>(OpBuilder &, OpOperand &)>;
 
 struct LinalgTilingOptions {
   /// Computation function that returns the tile sizes for each operation.
@@ -504,10 +505,11 @@ struct LinalgTilingOptions {
     return *this;
   }
 
-  /// Computation function that returns a padding value to use when padding to
-  /// force static sizes. When `paddingValueComputationFunction` is set, padding
-  /// operations are introduced, that guarantee the underlying op is statically
-  /// shaped and can thus be vectorized.
+  /// Callback returning the padding value to use for a given OpOperand or
+  /// failure for no padding. Padding operations are introduced if
+  /// `paddingValueComputationFunction` is set and does not return failure.
+  /// Padding all operands guarantees the operation is statically shaped and
+  /// thus can be vectorized.
   PaddingValueComputationFunction paddingValueComputationFunction = nullptr;
 
   LinalgTilingOptions &
index cef9e5a..1d28451 100644 (file)
@@ -145,16 +145,21 @@ LinalgTilingOptions &mlir::linalg::LinalgTilingOptions::scalarizeDynamicDims() {
   return *this;
 }
 
-/// Try to compute a static bounding box for `operand`. The padding happens
-/// even if the operand already has static shape. `result` is the result of a
-/// freshly created PadTensorOp. Return failure if the operand cannot be padded
-/// to a static shape.
+/// Helper function that tries to pad `opOperand`. Exit early and return success
+/// for scalar operands or if `paddingFunc` returns failure. Otherwise, try to
+/// pad the operand even if it already has a static shape. Set `result` to the
+/// result of the created PadTensorOp or return failure if the operand cannot be
+/// padded to a static shape.
 static LogicalResult padOperandToSmallestStaticBoundingBox(
     PatternRewriter &rewriter, linalg::LinalgOp opToPad, OpOperand *opOperand,
     const PaddingValueComputationFunction &paddingFunc, Value &result) {
   // Can't pad scalars.
   if (opToPad.getShape(opOperand).empty())
     return success();
+  // Can't pad if no padding value is known.
+  FailureOr<Value> paddingValue = paddingFunc(rewriter, *opOperand);
+  if (failed(paddingValue))
+    return success();
   auto sliceOp = opOperand->get().getDefiningOp<tensor::ExtractSliceOp>();
   // Not a slice op, cannot construct a static bounding box.
   if (!sliceOp)
@@ -173,12 +178,11 @@ static LogicalResult padOperandToSmallestStaticBoundingBox(
           opToPad, "No constant bounding box can be found for padding");
     staticSizes.push_back(indexAttr.getInt());
   }
-  Value pad = paddingFunc(rewriter, *opOperand);
   auto staticTensorType = RankedTensorType::get(
       staticSizes, getElementTypeOrSelf(opOperand->get()));
   result = linalg::PadTensorOp::createPadHighOp(
-      staticTensorType, opOperand->get(), pad, /*packing=*/true,
-      opToPad->getLoc(), rewriter);
+      staticTensorType, opOperand->get(), paddingValue.getValue(),
+      /*packing=*/true, opToPad->getLoc(), rewriter);
   return success();
 }
 
index fb28381..7f30762 100644 (file)
@@ -1,5 +1,5 @@
-// RUN: mlir-opt %s -test-linalg-transform-patterns="test-tile-pattern pad-tiles tile-sizes=2,3,4" -canonicalize | FileCheck %s
-// RUN: mlir-opt %s -test-linalg-transform-patterns="test-tile-pattern pad-tiles tile-sizes=2,3" -canonicalize | FileCheck %s -check-prefix=CHECK-1DIM-TILE
+// RUN: mlir-opt %s -test-linalg-transform-patterns="test-tile-pattern padded-operands=0,1,2 tile-sizes=2,3,4" -canonicalize | FileCheck %s
+// RUN: mlir-opt %s -test-linalg-transform-patterns="test-tile-pattern padded-operands=0,1 tile-sizes=2,3" -canonicalize | FileCheck %s -check-prefix=CHECK-1DIM-TILE
 
 // CHECK-LABEL: func @matmul_tensors(
 // CHECK-SAME:    %[[TA:[0-9a-z]+]]: tensor<?x?xi8>
@@ -97,6 +97,7 @@ func @matmul_partially_padded_tensors(
 // CHECK: linalg.matmul_i8_i8_i32 ins({{.*}}, {{.*}} : tensor<2x4xi8>, tensor<4x3xi8>) outs({{.*}} : tensor<2x3xi32>) -> tensor<2x3xi32>
 
 
+// Check only the the input operands are padded.
 // CHECK-1DIM-TILE: func @matmul_partially_padded_tensors(
 // CHECK-1DIM-TILE-SAME:    %[[TA:[0-9a-z]+]]: tensor<?x8xi8>
 // CHECK-1DIM-TILE-SAME:    %[[TB:[0-9a-z]+]]: tensor<8x?xi8>
@@ -111,10 +112,8 @@ func @matmul_partially_padded_tensors(
 //      CHECK-1DIM-TILE:                   : tensor<?x8xi8> to tensor<2x8xi8>
 //      CHECK-1DIM-TILE:                %[[pB:.*]] = linalg.pad_tensor %[[sTB]] packing low[%[[C0]], %[[C0]]] high[%{{.*}}, %{{.*}}]
 //      CHECK-1DIM-TILE:                   : tensor<8x?xi8> to tensor<8x3xi8>
-//      CHECK-1DIM-TILE:                %[[pC:.*]] = linalg.pad_tensor %[[sTC]] packing low[%[[C0]], %[[C0]]] high[%{{.*}}, %{{.*}}]
-//      CHECK-1DIM-TILE:                   : tensor<?x?xi32> to tensor<2x3xi32>
-//      CHECK-1DIM-TILE:               %[[pD:.*]] = linalg.matmul_i8_i8_i32 ins(%[[pA]], %[[pB]] : tensor<2x8xi8>, tensor<8x3xi8>)
-//      CHECK-1DIM-TILE:                                           outs(%[[pC]] : tensor<2x3xi32>)  -> tensor<2x3xi32>
+//      CHECK-1DIM-TILE:                %[[pD:.*]] = linalg.matmul_i8_i8_i32 ins(%[[pA]], %[[pB]] : tensor<2x8xi8>, tensor<8x3xi8>)
+//      CHECK-1DIM-TILE:                                           outs(%[[sTC]] : tensor<?x?xi32>)  -> tensor<?x?xi32>
 
 // Check that the tile-and-pad transformation actually introduces the padding
 // as requested, even if original operation already operates on static
index 4b4b6ac..74ff41c 100644 (file)
@@ -108,9 +108,10 @@ struct TestLinalgTransforms
       llvm::cl::desc("Test rewrite of subtensor(pad_tensor) into "
                      "pad_tensor(subtensor)"),
       llvm::cl::init(false)};
-  Option<bool> padTiles{*this, "pad-tiles",
-                        llvm::cl::desc("Pad tiles when test-tile-pattern"),
-                        llvm::cl::init(false)};
+  ListOption<int64_t> paddedOperands{
+      *this, "padded-operands",
+      llvm::cl::desc("Operands to pad when test-tile-pattern"),
+      llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated};
   ListOption<int64_t> peeledLoops{
       *this, "peeled-loops",
       llvm::cl::desc("Loops to be peeled when test-tile-pattern"),
@@ -576,7 +577,8 @@ static Value getNeutralOfLinalgOp(OpBuilder &b, OpOperand &op) {
 }
 
 static void applyTilePattern(FuncOp funcOp, std::string loopType,
-                             ArrayRef<int64_t> tileSizes, bool padTiles,
+                             ArrayRef<int64_t> tileSizes,
+                             ArrayRef<int64_t> paddedOperands,
                              ArrayRef<int64_t> peeledLoops,
                              bool scalarizeDynamicDims) {
   MLIRContext *context = funcOp.getContext();
@@ -597,10 +599,15 @@ static void applyTilePattern(FuncOp funcOp, std::string loopType,
   } else {
     linalgTilingOptions.setTileSizes(tileSizes);
   }
-  if (padTiles)
-    linalgTilingOptions.setPaddingValueComputationFunction(
-        getNeutralOfLinalgOp);
-
+  if (!paddedOperands.empty()) {
+    auto paddingFunc = [&](OpBuilder &b,
+                           OpOperand &opOperand) -> FailureOr<Value> {
+      if (llvm::count(paddedOperands, opOperand.getOperandNumber()) == 0)
+        return failure();
+      return getNeutralOfLinalgOp(b, opOperand);
+    };
+    linalgTilingOptions.setPaddingValueComputationFunction(paddingFunc);
+  }
   tilingPattern.add<linalg::LinalgTilingPattern<linalg::MatmulOp>,
                     linalg::LinalgTilingPattern<linalg::MatmulI8I8I32Op>,
                     linalg::LinalgTilingPattern<linalg::GenericOp>>(
@@ -734,10 +741,10 @@ void TestLinalgTransforms::runOnFunction() {
     return applyTiledLoopPeelingPattern(getFunction(), testTiledLoopPeeling,
                                         skipPartial);
   if (testTilePattern)
-    return applyTilePattern(getFunction(), loopType, tileSizes, padTiles,
+    return applyTilePattern(getFunction(), loopType, tileSizes, paddedOperands,
                             peeledLoops, /*scalarizeDynamicDims=*/false);
   if (testTileScalarizeDynamicDims)
-    return applyTilePattern(getFunction(), loopType, tileSizes, padTiles,
+    return applyTilePattern(getFunction(), loopType, tileSizes, paddedOperands,
                             /*peeledLoops=*/{}, /*scalarizeDynamicDims=*/true);
   if (testHoistPadding) {
     getFunction().walk([&](linalg::PadTensorOp padTensorOp) {