From b6ab4f1a8b6546b67dbcc3612f33c26d9b72a5cc Mon Sep 17 00:00:00 2001 From: Matthias Springer Date: Tue, 15 Jun 2021 17:16:32 +0900 Subject: [PATCH] [mlir][linalg] Fold linalg.pad_tensor if src type == result type Fold PadTensorOp to source if source type and result type have static shape and are equal. Differential Revision: https://reviews.llvm.org/D103778 --- mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td | 1 + mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 6 ++++++ mlir/test/Dialect/Linalg/canonicalize.mlir | 16 ++++++++++++++++ 3 files changed, 23 insertions(+) diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td index 9b6120e..1b01776 100644 --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td @@ -296,6 +296,7 @@ def Linalg_PadTensorOp : Linalg_Op<"pad_tensor", ]; let hasCanonicalizer = 1; + let hasFolder = 1; } def Linalg_RangeOp : diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp index 2b3ae89..985a9f7 100644 --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -1164,6 +1164,12 @@ Value PadTensorOp::getConstantPaddingValue() { return padValue; } +OpFoldResult PadTensorOp::fold(ArrayRef) { + if (getResultType().hasStaticShape() && getResultType() == getSourceType()) + return source(); + return {}; +} + //===----------------------------------------------------------------------===// // ReshapeOp //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir index 6bd2895..029ac62 100644 --- a/mlir/test/Dialect/Linalg/canonicalize.mlir +++ b/mlir/test/Dialect/Linalg/canonicalize.mlir @@ -893,6 +893,22 @@ func @dead_linalg_tensor(%arg0 : tensor<7x7xi32>, %arg1 : tensor<7x7xf32>, // ----- +// CHECK-LABEL: func @pad_tensor_same_static_shape( +// CHECK-SAME: %[[ARG0:.*]]: tensor<5x6xf32> +// CHECK-NOT: linalg.pad_tensor +// CHECK: return %[[ARG0]] +func @pad_tensor_same_static_shape(%arg0: tensor<5x6xf32>, %a: index) + -> tensor<5x6xf32> { + %cst = constant 0.000000e+00 : f32 + %0 = linalg.pad_tensor %arg0 low[%a, 0] high[0, %a] { + ^bb0(%arg1: index, %arg2: index): + linalg.yield %cst : f32 + } : tensor<5x6xf32> to tensor<5x6xf32> + return %0 : tensor<5x6xf32> +} + +// ----- + func @dim_reshape_expansion(%arg0 : tensor<6x5x?xf32>) -> (index, index, index) { %c1 = constant 1 : index -- 2.7.4