[mlir][Linalg] Replace SimplePad with PadTensor in tile-and-pad
authorHanhan Wang <hanchung@google.com>
Thu, 28 Jan 2021 14:49:48 +0000 (06:49 -0800)
committerHanhan Wang <hanchung@google.com>
Thu, 28 Jan 2021 14:50:26 +0000 (06:50 -0800)
This revision creates a build method of PadTensorOp which can be mapped to
SimplePad op. The verifier is updated to accept a static custom result type,
which has the same semantic as SimplePadOp.

Reviewed By: nicolasvasilache

Differential Revision: https://reviews.llvm.org/D95555

mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
mlir/test/Dialect/Linalg/roundtrip.mlir
mlir/test/Dialect/Linalg/tile-and-pad-tensors.mlir

index 9ea1bc5..67c0615 100644 (file)
@@ -199,6 +199,13 @@ def Linalg_PadTensorOp : Linalg_Op<"pad_tensor",
     static RankedTensorType inferResultType(RankedTensorType sourceType,
                                 ArrayRef<int64_t> staticLow,
                                 ArrayRef<int64_t> staticHigh);
+
+    // Return a PadTensorOp that pads `source` to `type` size where the static
+    // sizes are assumed to be greater than the dynamic sizes. The op performs
+    // "high" padding (i.e. it adds trailing padding values until the desired
+    // size is met).
+    static linalg::PadTensorOp createPadHighOp(
+        Type type, Value source, Value pad, Location loc, OpBuilder & builder);
   }];
 
   let builders = [
@@ -208,6 +215,11 @@ def Linalg_PadTensorOp : Linalg_Op<"pad_tensor",
       CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>,
     // Build a PadTensorOp with all dynamic entries.
     OpBuilderDAG<(ins "Value":$source, "ValueRange":$low, "ValueRange":$high,
+      CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>,
+    // Build a PadTensorOp with with mixed static and dynamic entries and custom
+    // result type. If the type passed is nullptr, it is inferred.
+    OpBuilderDAG<(ins "Type":$resultType, "Value":$source,
+      "ArrayRef<OpFoldResult>":$low, "ArrayRef<OpFoldResult>":$high,
       CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>
   ];
 }
index 89a0b44..21c9108 100644 (file)
@@ -970,7 +970,11 @@ static LogicalResult verify(PadTensorOp op) {
   auto expectedType = PadTensorOp::inferResultType(
       sourceType, extractFromI64ArrayAttr(op.static_low()),
       extractFromI64ArrayAttr(op.static_high()));
-  if (resultType != expectedType) {
+  for (int i = 0, e = sourceType.getRank(); i < e; ++i) {
+    if (resultType.getDimSize(i) == expectedType.getDimSize(i))
+      continue;
+    if (expectedType.isDynamicDim(i))
+      continue;
     return op.emitError("specified type ")
            << resultType << " does not match the inferred type "
            << expectedType;
@@ -1077,6 +1081,24 @@ static void print(OpAsmPrinter &p, PadTensorOp op) {
   p << " : " << op.source().getType() << " to " << op.getType();
 }
 
+/// Helper function to dispatch an OpFoldResult into either the `dynamicVec` if
+/// it is a Value or into `staticVec` if it is an IntegerAttr.
+/// In the case of a Value, a copy of the `sentinel` value is also pushed to
+/// `staticVec`. This is useful to extract mixed static and dynamic entries that
+/// come from an AttrSizedOperandSegments trait.
+static void dispatchIndexOpFoldResult(OpFoldResult ofr,
+                                      SmallVectorImpl<Value> &dynamicVec,
+                                      SmallVectorImpl<int64_t> &staticVec,
+                                      int64_t sentinel) {
+  if (auto v = ofr.dyn_cast<Value>()) {
+    dynamicVec.push_back(v);
+    staticVec.push_back(sentinel);
+    return;
+  }
+  APInt apInt = ofr.dyn_cast<Attribute>().cast<IntegerAttr>().getValue();
+  staticVec.push_back(apInt.getSExtValue());
+}
+
 void PadTensorOp::build(OpBuilder &b, OperationState &result, Value source,
                         ArrayRef<int64_t> staticLow,
                         ArrayRef<int64_t> staticHigh, ValueRange low,
@@ -1097,6 +1119,60 @@ void PadTensorOp::build(OpBuilder &b, OperationState &result, Value source,
   build(b, result, source, staticVector, staticVector, low, high, attrs);
 }
 
+void PadTensorOp::build(OpBuilder &b, OperationState &result, Type resultType,
+                        Value source, ArrayRef<OpFoldResult> low,
+                        ArrayRef<OpFoldResult> high,
+                        ArrayRef<NamedAttribute> attrs) {
+  assert(resultType.isa<RankedTensorType>());
+  auto sourceType = source.getType().cast<RankedTensorType>();
+  unsigned rank = sourceType.getRank();
+  SmallVector<Value, 4> dynamicLow, dynamicHigh;
+  SmallVector<int64_t, 4> staticLow, staticHigh;
+  for (unsigned i = 0; i < rank; ++i) {
+    // staticLow and staticHigh have full information of the padding config.
+    // This will grow staticLow and staticHigh with 1 value. If the config is
+    // dynamic (ie not a constant), dynamicLow and dynamicHigh will grow with 1
+    // value as well.
+    dispatchIndexOpFoldResult(low[i], dynamicLow, staticLow,
+                              ShapedType::kDynamicSize);
+    dispatchIndexOpFoldResult(high[i], dynamicHigh, staticHigh,
+                              ShapedType::kDynamicSize);
+  }
+  if (!resultType) {
+    resultType =
+        PadTensorOp::inferResultType(sourceType, staticLow, staticHigh);
+  }
+  build(b, result, resultType, source, dynamicLow, dynamicHigh,
+        b.getI64ArrayAttr(staticLow), b.getI64ArrayAttr(staticHigh));
+}
+
+PadTensorOp PadTensorOp::createPadHighOp(Type type, Value source, Value pad,
+                                         Location loc, OpBuilder &builder) {
+  SmallVector<OpFoldResult, 4> low, high;
+  auto rankedTensorType = type.cast<RankedTensorType>();
+  assert(rankedTensorType.hasStaticShape());
+  int rank = rankedTensorType.getRank();
+  for (int i = 0; i < rank; ++i) {
+    auto dimOp = builder.createOrFold<DimOp>(loc, source, i);
+    auto resultDimSize = builder.createOrFold<ConstantIndexOp>(
+        loc, rankedTensorType.getDimSize(i));
+    auto highValue = builder.createOrFold<SubIOp>(loc, resultDimSize, dimOp);
+    high.push_back(highValue);
+    low.push_back(builder.createOrFold<ConstantIndexOp>(loc, 0));
+  }
+  auto padTensorOp =
+      builder.create<linalg::PadTensorOp>(loc, type, source, low, high);
+  SmallVector<Type, 4> blockArgTypes;
+  blockArgTypes.assign(rank, builder.getIndexType());
+  auto &region = padTensorOp.region();
+  // `builder.createBlock` changes the insertion point within the block. Create
+  // a guard to reset the insertion point of the builder after it is destroyed.
+  OpBuilder::InsertionGuard guard(builder);
+  builder.createBlock(&region, region.end(), blockArgTypes);
+  builder.create<linalg::YieldOp>(loc, pad);
+  return padTensorOp;
+}
+
 //===----------------------------------------------------------------------===//
 // ReshapeOp
 //===----------------------------------------------------------------------===//
index b4c94ae..7260bb4 100644 (file)
@@ -120,7 +120,7 @@ mlir::linalg::LinalgTilingOptions::setTileSizes(ArrayRef<int64_t> ts) {
 /// Return success if either:
 ///   1. The operand is already statically shaped, `result` is left unchanged.
 ///   2. The operand is (partially) dynamic, `result` is the result of a freshly
-///      created SimplePadOp.
+///      created PadTensorOp.
 /// Return failure if the operand cannot be padded to a static shape.
 static LogicalResult padOperandToSmallestStaticBoundingBox(
     PatternRewriter &rewriter, linalg::LinalgOp opToPad, Value operand,
@@ -151,8 +151,8 @@ static LogicalResult padOperandToSmallestStaticBoundingBox(
   Value pad = options.paddingValueComputationFunction(rewriter, opToPad);
   auto staticTensorType =
       RankedTensorType::get(staticSizes, tensorType.getElementType());
-  result = rewriter.create<linalg::SimplePadOp>(opToPad->getLoc(),
-                                                staticTensorType, operand, pad);
+  result = linalg::PadTensorOp::createPadHighOp(staticTensorType, operand, pad,
+                                                opToPad->getLoc(), rewriter);
   return success();
 }
 
index 6dc0768..879dfa3 100644 (file)
@@ -57,6 +57,25 @@ func @pad_asymmetrical(%arg0: tensor<2x3xf32>, %ub0: index, %ub1: index,
 
 // -----
 
+func @pad_to_static_size(%arg0: tensor<?x?xf32>, %ub0: index, %ub1: index,
+                         %pad_value: f32) -> tensor<2x3xf32> {
+  %0 = linalg.pad_tensor %arg0 low[0, 0] high[%ub0, %ub1] {
+    ^bb0(%arg1: index, %arg2: index):
+      linalg.yield %pad_value : f32
+    } : tensor<?x?xf32> to tensor<2x3xf32>
+  return %0 : tensor<2x3xf32>
+}
+// CHECK-LABEL: func @pad_to_static_size
+//  CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]
+//  CHECK-SAME: %[[UB0:[a-zA-Z0-9_]*]]
+//  CHECK-SAME: %[[UB1:[a-zA-Z0-9_]*]]
+//       CHECK:   linalg.pad_tensor %[[ARG0]]
+//  CHECK-SAME:     low[0, 0]
+//  CHECK-SAME:     high[%[[UB0]], %[[UB1]]]
+//       CHECK:    : tensor<?x?xf32> to tensor<2x3xf32>
+
+// -----
+
 func @range(%arg0: index, %arg1: index, %arg2: index) {
   %0 = linalg.range %arg0:%arg1:%arg2 : !linalg.range
   return
index e412108..4eaf830 100644 (file)
@@ -18,9 +18,12 @@ func @matmul_tensors(
 //  CHECK-NOT:       linalg.matmul {{.*}} tensor<?x?xf32>
 
 // Padding injects static information.
-//      CHECK:       %[[pA:.*]] = linalg.simple_pad %[[sTA]] pad %{{.*}} : tensor<?x?xf32> to tensor<2x4xf32> pad f32
-//      CHECK:       %[[pB:.*]] = linalg.simple_pad %[[sTB]] pad %{{.*}} : tensor<?x?xf32> to tensor<4x3xf32> pad f32
-//      CHECK:       %[[pC:.*]] = linalg.simple_pad %[[sTC]] pad %{{.*}} : tensor<?x?xf32> to tensor<2x3xf32> pad f32
+//      CHECK:       %[[pA:.*]] = linalg.pad_tensor %[[sTA]] low[0, 0] high[%{{.*}}, %{{.*}}]
+//      CHECK:         : tensor<?x?xf32> to tensor<2x4xf32>
+//      CHECK:       %[[pB:.*]] = linalg.pad_tensor %[[sTB]] low[0, 0] high[%{{.*}}, %{{.*}}]
+//      CHECK:         : tensor<?x?xf32> to tensor<2x4xf32>
+//      CHECK:       %[[pC:.*]] = linalg.pad_tensor %[[sTC]] low[0, 0] high[%{{.*}}, %{{.*}}]
+//      CHECK:         : tensor<?x?xf32> to tensor<2x4xf32>
 //      CHECK:       %[[pD:.*]] = linalg.matmul ins(%[[pA]], %[[pB]] : tensor<2x4xf32>, tensor<4x3xf32>)
 // CHECK-SAME:                                  outs(%[[pC]] : tensor<2x3xf32>)  -> tensor<2x3xf32>
 //      CHECK:       %[[sTD:.*]] = subtensor %[[pD]][0, 0] [%{{.*}}, %{{.*}}] [1, 1] : tensor<2x3xf32> to tensor<?x?xf32>