[mlir][linalg] Fold linalg.pad_tensor if src type == result type
authorMatthias Springer <springerm@google.com>
Tue, 15 Jun 2021 08:16:32 +0000 (17:16 +0900)
committerMatthias Springer <springerm@google.com>
Tue, 15 Jun 2021 08:25:12 +0000 (17:25 +0900)
Fold PadTensorOp to source if source type and result type have static shape and are equal.

Differential Revision: https://reviews.llvm.org/D103778

mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
mlir/test/Dialect/Linalg/canonicalize.mlir

index 9b6120e..1b01776 100644 (file)
@@ -296,6 +296,7 @@ def Linalg_PadTensorOp : Linalg_Op<"pad_tensor",
   ];
 
   let hasCanonicalizer = 1;
+  let hasFolder = 1;
 }
 
 def Linalg_RangeOp :
index 2b3ae89..985a9f7 100644 (file)
@@ -1164,6 +1164,12 @@ Value PadTensorOp::getConstantPaddingValue() {
   return padValue;
 }
 
+OpFoldResult PadTensorOp::fold(ArrayRef<Attribute>) {
+  if (getResultType().hasStaticShape() && getResultType() == getSourceType())
+    return source();
+  return {};
+}
+
 //===----------------------------------------------------------------------===//
 // ReshapeOp
 //===----------------------------------------------------------------------===//
index 6bd2895..029ac62 100644 (file)
@@ -893,6 +893,22 @@ func @dead_linalg_tensor(%arg0 : tensor<7x7xi32>, %arg1 : tensor<7x7xf32>,
 
 // -----
 
+// CHECK-LABEL: func @pad_tensor_same_static_shape(
+//  CHECK-SAME:   %[[ARG0:.*]]: tensor<5x6xf32>
+//   CHECK-NOT:   linalg.pad_tensor
+//       CHECK:   return %[[ARG0]]
+func @pad_tensor_same_static_shape(%arg0: tensor<5x6xf32>, %a: index)
+    -> tensor<5x6xf32> {
+  %cst = constant 0.000000e+00 : f32
+  %0 = linalg.pad_tensor %arg0 low[%a, 0] high[0, %a] {
+        ^bb0(%arg1: index, %arg2: index):
+          linalg.yield %cst : f32
+  } : tensor<5x6xf32> to tensor<5x6xf32>
+  return %0 : tensor<5x6xf32>
+}
+
+// -----
+
 func @dim_reshape_expansion(%arg0 : tensor<6x5x?xf32>) -> (index, index, index)
 {
   %c1 = constant 1 : index