From f3f25ffc04c0cbcc9a9bfc1b32b61750e8934ea8 Mon Sep 17 00:00:00 2001 From: Matthias Springer Date: Fri, 24 Sep 2021 16:39:37 +0900 Subject: [PATCH] [mlir][linalg] Fix result type in FoldSourceTensorCast * Do not discard static result type information that cannot be inferred from lower/upper padding. * Add optional argument to `PadTensorOp::inferResultType` for specifying known result dimensions. Differential Revision: https://reviews.llvm.org/D110380 --- mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td | 10 +++++++--- mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 20 ++++++++++++++------ mlir/test/Dialect/Linalg/canonicalize.mlir | 24 ++++++++++++++++++++++-- 3 files changed, 43 insertions(+), 11 deletions(-) diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td index 4c82eaf..dd568ba 100644 --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td @@ -226,10 +226,14 @@ def Linalg_PadTensorOp : Linalg_Op<"pad_tensor", } // Infer the shape of the result tensor given the type of the source tensor - // and paddings. - static RankedTensorType inferResultType(RankedTensorType sourceType, + // and paddings. Known result dimensions that cannot necessarily be inferred + // from low/high padding sizes can be optionally specified. Those will be + // considered when computing the result type. + static RankedTensorType inferResultType( + RankedTensorType sourceType, ArrayRef staticLow, - ArrayRef staticHigh); + ArrayRef staticHigh, + ArrayRef resultShape = {}); // Return a PadTensorOp that pads `source` to `type` size where the static // sizes are assumed to be greater than the dynamic sizes. The op performs diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp index b3eeaab..75e4a1c 100644 --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -1055,24 +1055,31 @@ static LogicalResult verify(PadTensorOp op) { RankedTensorType PadTensorOp::inferResultType(RankedTensorType sourceType, ArrayRef staticLow, - ArrayRef staticHigh) { + ArrayRef staticHigh, + ArrayRef resultShape) { unsigned rank = sourceType.getRank(); assert(staticLow.size() == rank && "unexpected staticLow size mismatch"); assert(staticHigh.size() == rank && "unexpected staticHigh size mismatch"); + assert((resultShape.empty() || resultShape.size() == rank) && + "unexpected resultShape size mismatch"); - SmallVector resultShape; + SmallVector inferredShape; for (auto i : llvm::seq(0, rank)) { if (sourceType.isDynamicDim(i) || staticLow[i] == ShapedType::kDynamicSize || staticHigh[i] == ShapedType::kDynamicSize) { - resultShape.push_back(ShapedType::kDynamicSize); + inferredShape.push_back(resultShape.empty() ? ShapedType::kDynamicSize + : resultShape[i]); } else { int64_t size = sourceType.getDimSize(i) + staticLow[i] + staticHigh[i]; - resultShape.push_back(size); + assert((resultShape.empty() || size == resultShape[i] || + resultShape[i] == ShapedType::kDynamicSize) && + "mismatch between inferred shape and result shape"); + inferredShape.push_back(size); } } - return RankedTensorType::get(resultShape, sourceType.getElementType()); + return RankedTensorType::get(inferredShape, sourceType.getElementType()); } void PadTensorOp::build(OpBuilder &b, OperationState &result, Value source, @@ -1454,7 +1461,8 @@ struct FoldSourceTensorCast : public OpRewritePattern { auto newResultType = PadTensorOp::inferResultType( castOp.source().getType().cast(), extractFromI64ArrayAttr(padTensorOp.static_low()), - extractFromI64ArrayAttr(padTensorOp.static_high())); + extractFromI64ArrayAttr(padTensorOp.static_high()), + padTensorOp.getResultType().getShape()); if (newResultType == padTensorOp.getResultType()) { rewriter.updateRootInPlace(padTensorOp, [&]() { diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir index 3d434c2..fce08a1e0 100644 --- a/mlir/test/Dialect/Linalg/canonicalize.mlir +++ b/mlir/test/Dialect/Linalg/canonicalize.mlir @@ -629,7 +629,8 @@ func @pad_tensor_same_static_shape(%arg0: tensor<5x6xf32>, %a: index) } // ----- -// CHECK-LABEL: func @pad_tensor_after_cast_differnt_shape( + +// CHECK-LABEL: func @pad_tensor_after_cast_different_shape( // CHECK-SAME: %[[INPUT:.*]]: tensor) -> tensor { // CHECK: %[[CST:.*]] = constant 0.000000e+00 : f32 // CHECK: %[[PADDED:.*]] = linalg.pad_tensor %[[INPUT]] @@ -641,7 +642,7 @@ func @pad_tensor_same_static_shape(%arg0: tensor<5x6xf32>, %a: index) // CHECK-SAME: tensor to tensor // CHECK: return %[[DYNAMIC]] : tensor // CHECK: } -func @pad_tensor_after_cast_differnt_shape(%arg0: tensor) +func @pad_tensor_after_cast_different_shape(%arg0: tensor) -> tensor { %cst = constant 0.000000e+00 : f32 %dynamic = tensor.cast %arg0 : tensor to tensor @@ -653,6 +654,7 @@ func @pad_tensor_after_cast_differnt_shape(%arg0: tensor) } // ----- + // CHECK-LABEL: func @pad_tensor_after_cast_same_shape( // CHECK-SAME: %[[INPUT:.*]]: tensor, // CHECK-SAME: %[[PADDING:.*]]: index) -> tensor { @@ -676,6 +678,24 @@ func @pad_tensor_after_cast_same_shape(%arg0: tensor, %padding : i } // ----- + +// CHECK-LABEL: func @pad_tensor_of_cast( +// CHECK-NOT: tensor.cast +// CHECK: linalg.pad_tensor +// CHECK: tensor<8x?xf32> to tensor<8x32xf32> +func @pad_tensor_of_cast(%t: tensor<8x?xf32>, %s: index) -> tensor<8x32xf32> { + %c0 = constant 0 : index + %cst = constant 0.000000e+00 : f32 + %0 = tensor.cast %t : tensor<8x?xf32> to tensor + %1 = linalg.pad_tensor %0 low[%c0, %c0] high[%c0, %s] { + ^bb0(%arg9: index, %arg10: index): // no predecessors + linalg.yield %cst : f32 + } : tensor to tensor<8x32xf32> + return %1 : tensor<8x32xf32> +} + +// ----- + func @propogate_casts(%arg0 : tensor, %arg1 : f32, %arg2 : index, %arg3 : index) -> tensor { %c0 = constant 0 : index -- 2.7.4