[mlir] Canonicalize dynamic tensor.pad ops with constant inputs
authorGeorge Petterson <gpetters@protonmail.com>
Fri, 3 Feb 2023 21:43:45 +0000 (16:43 -0500)
committerGeorge Petterson <gpetters@protonmail.com>
Fri, 3 Feb 2023 21:43:45 +0000 (16:43 -0500)
This commit adds a canonicalization pattern for tensor.pad which changes the output type to static at each dimension where the input shape is static and the high and low operands are constants. This corrects an issue arising in Torch-MLIR where pad ops would sometimes introduce dynamic shapes unnecessarily.

Reviewed By: raikonenfnu

Differential Revision: https://reviews.llvm.org/D143135

mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
mlir/test/Dialect/Tensor/canonicalize.mlir

index 7960b64..d35895a 100644 (file)
@@ -2846,12 +2846,105 @@ struct FoldOrthogonalPaddings : public OpRewritePattern<PadOp> {
   }
 };
 
+struct FoldStaticPadding : public OpRewritePattern<PadOp> {
+  using OpRewritePattern<PadOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(PadOp padTensorOp,
+                                PatternRewriter &rewriter) const override {
+    Value input = padTensorOp.getSource();
+    if (!input.getType().isa<RankedTensorType>())
+      return failure();
+    auto inputDims = input.getType().cast<RankedTensorType>().getShape();
+    auto inputRank = inputDims.size();
+
+    if (!padTensorOp.getResult().getType().isa<RankedTensorType>())
+      return failure();
+    auto outputDims =
+        padTensorOp.getResult().getType().cast<RankedTensorType>().getShape();
+
+    // Extract the static info from the high and low operands.
+    SmallVector<int64_t> constOperandsLow;
+    for (auto operand : padTensorOp.getLow()) {
+      APSInt intOp;
+      if (!matchPattern(operand, m_ConstantInt(&intOp))) {
+        constOperandsLow.push_back(ShapedType::kDynamic);
+        continue;
+      }
+      constOperandsLow.push_back(intOp.getExtValue());
+    }
+    SmallVector<int64_t> constOperandsHigh;
+    for (auto operand : padTensorOp.getHigh()) {
+      APSInt intOp;
+      if (!matchPattern(operand, m_ConstantInt(&intOp))) {
+        constOperandsHigh.push_back(ShapedType::kDynamic);
+        continue;
+      }
+      constOperandsHigh.push_back(intOp.getExtValue());
+    }
+
+    SmallVector<int64_t> constLow(padTensorOp.getStaticLow());
+    SmallVector<int64_t> constHigh(padTensorOp.getStaticHigh());
+
+    // Verify the op is well-formed.
+    if (inputDims.size() != outputDims.size() ||
+        inputDims.size() != constLow.size() ||
+        inputDims.size() != constHigh.size())
+      return failure();
+
+    auto lowCount = 0;
+    auto highCount = 0;
+    for (size_t i = 0; i < inputRank; i++) {
+      if (constLow[i] == ShapedType::kDynamic)
+        constLow[i] = constOperandsLow[lowCount++];
+      if (constHigh[i] == ShapedType::kDynamic)
+        constHigh[i] = constOperandsHigh[highCount++];
+    }
+
+    auto staticLow = ArrayRef<int64_t>(constLow);
+    auto staticHigh = ArrayRef<int64_t>(constHigh);
+
+    // Calculate the output sizes with the static information.
+    SmallVector<int64_t> newOutDims;
+    for (size_t i = 0; i < inputRank; i++) {
+      if (outputDims[i] == ShapedType::kDynamic) {
+        newOutDims.push_back(
+            (staticLow[i] == ShapedType::kDynamic ||
+                     staticHigh[i] == ShapedType::kDynamic ||
+                     inputDims[i] == ShapedType::kDynamic
+                 ? ShapedType::kDynamic
+                 : inputDims[i] + staticLow[i] + staticHigh[i]));
+      } else {
+        newOutDims.push_back(outputDims[i]);
+      }
+    }
+
+    if (SmallVector<int64_t>(outputDims) == newOutDims ||
+        llvm::all_of(newOutDims,
+                     [&](int64_t x) { return x == ShapedType::kDynamic; }))
+      return failure();
+
+    // Rewrite the op using the new static type.
+    auto newResultType = RankedTensorType::get(
+        newOutDims, padTensorOp.getType().getElementType());
+    auto newOp = rewriter.create<PadOp>(
+        padTensorOp->getLoc(), newResultType, input, padTensorOp.getLow(),
+        padTensorOp.getHigh(), staticLow, staticHigh, padTensorOp.getNofold());
+
+    IRMapping mapper;
+    padTensorOp.getRegion().cloneInto(&newOp.getRegion(), mapper);
+    rewriter.replaceOpWithNewOp<tensor::CastOp>(padTensorOp, newResultType,
+                                                newOp);
+
+    return success();
+  }
+};
+
 } // namespace
 
 void PadOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
   results.add<FoldStaticZeroPadding, FoldSourceTensorCast, FoldTargetTensorCast,
-              FoldOrthogonalPaddings>(context);
+              FoldOrthogonalPaddings, FoldStaticPadding>(context);
 }
 
 /// Return the padding value of the PadOp if it constant. In this context,
index f4706fc..6b72804 100644 (file)
@@ -1111,6 +1111,29 @@ func.func @pad_same_static_shape(%arg0: tensor<5x6xf32>, %a: index)
 
 // -----
 
+// CHECK-LABEL:   func @pad_fold_static(
+// CHECK-SAME:      %[[INPUT:.*]]: tensor<?x64x?x?xf32>) -> tensor<?xf32> {
+// CHECK:           %[[CST:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK:           %[[PADDING:.*]] = arith.constant 4 : index
+// CHECK:           %[[PADDED:.*]] = tensor.pad %[[INPUT]]
+// CHECK-SAME:        low[0, 4, 1, 1] high[0, 4, 1, 1]  {
+// CHECK:           ^bb0(%[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[ARG3:.*]]: index, %[[ARG4:.*]]: index):
+// CHECK:             tensor.yield %[[CST]] : f32
+// CHECK:           } : tensor<?x64x?x?xf32> to tensor<?x72x?x?xf32>
+func.func @pad_fold_static(%arg0: tensor<?x64x?x?xf32>)
+    -> tensor<?xf32> {
+  %cst = arith.constant 0.000000e+00 : f32
+  %padding = arith.constant 4 : index
+  %padded = tensor.pad %arg0 low[0, %padding, 1, 1] high[0, %padding, 1, 1]  {
+    ^bb0(%arg1: index, %arg2: index, %arg3: index, %arg4: index):
+    tensor.yield %cst: f32
+  } : tensor<?x64x?x?xf32> to tensor<?x?x?x?xf32>
+  %result = tensor.collapse_shape %padded [[0, 1, 2, 3]] : tensor<?x?x?x?xf32> into tensor<?xf32>
+  return %result : tensor<?xf32>
+}
+
+// -----
+
 // CHECK-LABEL: func @pad_nofold_same_static_shape(
 //  CHECK-SAME:   %[[ARG0:.*]]: tensor<5x6xf32>
 //       CHECK:   %[[PAD:.*]] = tensor.pad