From 76ea62a2735a760545bfa98524e7a658a15268ac Mon Sep 17 00:00:00 2001
From: Alexander Belyaev
Date: Thu, 30 Mar 2023 11:26:09 +0200
Subject: [PATCH] [mlir] Fix folding into tensor.pad op.

When low/high padding is folded into a pad op, a tensor.cast back to the
original result type should be inserted. Right now, a no-op tensor.cast
from the new type to the same new type is created instead.

Differential Revision: https://reviews.llvm.org/D147210
---
 mlir/lib/Dialect/Tensor/IR/TensorOps.cpp   | 10 ++++++----
 mlir/test/Dialect/Tensor/canonicalize.mlir | 10 +++++-----
 2 files changed, 11 insertions(+), 9 deletions(-)

diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index e7fb287..7ee9325 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -2879,10 +2879,12 @@ struct FoldStaticPadding : public OpRewritePattern<PadOp> {
     auto inputDims = input.getType().cast<RankedTensorType>().getShape();
     auto inputRank = inputDims.size();
 
-    if (!padTensorOp.getResult().getType().isa<RankedTensorType>())
+    auto oldResultType =
+        dyn_cast<RankedTensorType>(padTensorOp.getResult().getType());
+    if (!oldResultType)
       return failure();
-    auto outputDims =
-        padTensorOp.getResult().getType().cast<RankedTensorType>().getShape();
+
+    auto outputDims = oldResultType.getShape();
 
     // Extract the static info from the high and low operands.
     SmallVector<int64_t> constOperandsLow;
@@ -2955,7 +2957,7 @@ struct FoldStaticPadding : public OpRewritePattern<PadOp> {
     IRMapping mapper;
     padTensorOp.getRegion().cloneInto(&newOp.getRegion(), mapper);
 
-    rewriter.replaceOpWithNewOp<tensor::CastOp>(padTensorOp, newResultType,
+    rewriter.replaceOpWithNewOp<tensor::CastOp>(padTensorOp, oldResultType,
                                                 newOp);
 
     return success();
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index 8a5e047..0a42e2b 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -1140,7 +1140,7 @@ func.func @pad_same_static_shape(%arg0: tensor<5x6xf32>, %a: index)
// -----

 // CHECK-LABEL: func @pad_fold_static(
-// CHECK-SAME: %[[INPUT:.*]]: tensor<?x64x?x?xf32>) -> tensor<?xf32> {
+// CHECK-SAME: %[[INPUT:.*]]: tensor<?x64x?x?xf32>) -> tensor<?x?x?x?xf32> {
 // CHECK: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
 // CHECK: %[[PADDING:.*]] = arith.constant 4 : index
 // CHECK: %[[PADDED:.*]] = tensor.pad %[[INPUT]]
@@ -1148,16 +1148,16 @@ func.func @pad_same_static_shape(%arg0: tensor<5x6xf32>, %a: index)
 // CHECK: ^bb0(%[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[ARG3:.*]]: index, %[[ARG4:.*]]: index):
 // CHECK: tensor.yield %[[CST]] : f32
 // CHECK: } : tensor<?x64x?x?xf32> to tensor<?x72x?x?xf32>
-func.func @pad_fold_static(%arg0: tensor<?x64x?x?xf32>)
-    -> tensor<?xf32> {
+// CHECK: tensor.cast
+func.func @pad_fold_static(%arg0: tensor<?x64x?x?xf32>) -> tensor<?x?x?x?xf32> {
+  %c0 = arith.constant 0 : index
   %cst = arith.constant 0.000000e+00 : f32
   %padding = arith.constant 4 : index
   %padded = tensor.pad %arg0 low[0, %padding, 1, 1] high[0, %padding, 1, 1] {
   ^bb0(%arg1: index, %arg2: index, %arg3: index, %arg4: index):
     tensor.yield %cst: f32
   } : tensor<?x64x?x?xf32> to tensor<?x?x?x?xf32>
-  %result = tensor.collapse_shape %padded [[0, 1, 2, 3]] : tensor<?x?x?x?xf32> into tensor<?xf32>
-  return %result : tensor<?xf32>
+  return %padded : tensor<?x?x?x?xf32>
 }

 // -----
-- 
2.7.4