From 9e3ca7987a4dc33cdf847b79a6304b117651d21f Mon Sep 17 00:00:00 2001
From: Luke Hutton
Date: Wed, 22 Mar 2023 00:54:15 +0000
Subject: [PATCH] [mlir][tosa] Canonicalize concatenate->slice sequence

Adds a canonicalizer for the concatenate->slice sequence where
an output of slice can be replaced with an input of concatenate.

This is useful in the context of operations with complex inputs
and outputs that are legalized from a framework such as TFL.
For example, a TFL graph (FFT->FFT) will be legalized to the
following TOSA graph:

           /     \
        slice   slice
           \     /
             FFT
            /   \                 -+
         concatenate               |
           /     \                 |  Redundant
        slice   slice              |
           \     /                -+
             FFT
            /   \
         concatenate
              |

Concatenate and slice operations at the boundaries of the graph are
useful as they maintain the correct correspondence of input/output
tensors to the original TFL graph. However, consecutive
complex operations will result in redundant concatenate->slice
sequences which should be removed from the final TOSA graph.

The canonicalization does not currently handle dynamic types.
Signed-off-by: Luke Hutton Reviewed By: rsuderman Differential Revision: https://reviews.llvm.org/D144545 --- mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td | 1 + mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp | 59 ++++++++++++++++++++++ mlir/test/Dialect/Tosa/canonicalize.mlir | 53 +++++++++++++++++++ 3 files changed, 113 insertions(+) diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td index 7c8018a..b6127f1 100644 --- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td +++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td @@ -1556,6 +1556,7 @@ def Tosa_SliceOp: Tosa_Op<"slice", [ Tosa_Tensor1Dto6D:$output ); + let hasCanonicalizer = 1; let hasFolder = 1; } diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp index 1a8a578..16f23e4 100644 --- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp @@ -519,6 +519,65 @@ void ClampOp::getCanonicalizationPatterns(RewritePatternSet &results, results.add(context); } +struct ConcatSliceOptimization : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(tosa::SliceOp sliceOp, + PatternRewriter &rewriter) const override { + Value sliceInput = sliceOp.getInput(); + auto concatOp = sliceInput.getDefiningOp(); + if (!concatOp) + return rewriter.notifyMatchFailure( + sliceOp, "slice input must be concat operation"); + + OperandRange inputs = concatOp.getInput1(); + auto concatType = dyn_cast(concatOp.getType()); + if (!concatType || !concatType.hasStaticShape()) + return rewriter.notifyMatchFailure( + sliceOp, "slice input must be a static ranked tensor"); + int32_t axis = concatOp.getAxis(); + + llvm::SmallVector sliceStart(sliceOp.getStart()); + llvm::ArrayRef sliceSize = sliceOp.getSize(); + + // Validate slice on the concatenated axis. 
Slicing along this + // axis should span only one of the inputs to the concatenate + // operation. + std::optional replaceWithSlice; + for (auto input : inputs) { + auto inputType = dyn_cast(input.getType()); + if (!inputType || !inputType.hasStaticShape()) + return rewriter.notifyMatchFailure( + sliceOp, "concat input must be a static ranked tensor"); + + if (sliceStart[axis] >= 0 && + (sliceStart[axis] + sliceSize[axis]) <= inputType.getDimSize(axis)) { + replaceWithSlice = + rewriter + .create( + sliceOp.getLoc(), sliceOp.getType(), input, + rewriter.getDenseI64ArrayAttr(sliceOp.getStart()), + rewriter.getDenseI64ArrayAttr(sliceSize)) + .getResult(); + break; + } + sliceStart[axis] -= inputType.getDimSize(axis); + } + + if (!replaceWithSlice) + return rewriter.notifyMatchFailure( + sliceOp, "corresponding concat input not found for slice"); + + rewriter.replaceOp(sliceOp, replaceWithSlice.value()); + return success(); + } +}; + +void SliceOp::getCanonicalizationPatterns(RewritePatternSet &results, + MLIRContext *context) { + results.add(context); +} + //===----------------------------------------------------------------------===// // Operator Folders. 
//===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir index e16a614..77627d8 100644 --- a/mlir/test/Dialect/Tosa/canonicalize.mlir +++ b/mlir/test/Dialect/Tosa/canonicalize.mlir @@ -434,3 +434,56 @@ func.func @fold_resize_bilinear(%arg0 : tensor<1x15x13x1xi8>) -> tensor<1x15x13x %resize = "tosa.resize"(%arg0) {mode = "BILINEAR", scale = array, offset = array, border = array} : (tensor<1x15x13x1xi8>) -> tensor<1x15x13x1xi8> return %resize : tensor<1x15x13x1xi8> } + +// ----- + +// CHECK-LABEL: @canonicalize_concat_slice_final_axis +// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x12x12x1xf32>, %[[VAL_1:.*]]: tensor<1x12x12x1xf32> +// CHECK: return %[[VAL_0]], %[[VAL_1]] : tensor<1x12x12x1xf32>, tensor<1x12x12x1xf32> +func.func @canonicalize_concat_slice_final_axis(%arg0 : tensor<1x12x12x1xf32>, %arg1 : tensor<1x12x12x1xf32>) -> (tensor<1x12x12x1xf32>, tensor<1x12x12x1xf32>) { + %0 = "tosa.concat"(%arg0, %arg1) {axis = 3 : i64} : (tensor<1x12x12x1xf32>, tensor<1x12x12x1xf32>) -> tensor<1x12x12x2xf32> + %1 = "tosa.slice"(%0) {size = array, start = array} : (tensor<1x12x12x2xf32>) -> tensor<1x12x12x1xf32> + %2 = "tosa.slice"(%0) {size = array, start = array} : (tensor<1x12x12x2xf32>) -> tensor<1x12x12x1xf32> + return %1, %2 : tensor<1x12x12x1xf32>, tensor<1x12x12x1xf32> +} + +// ----- + +// CHECK-LABEL: @canonicalize_concat_slice_middle_axis +// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x12x12xf32>, %[[VAL_1:.*]]: tensor<1x12x12xf32> +// CHECK: return %[[VAL_0]], %[[VAL_1]] : tensor<1x12x12xf32>, tensor<1x12x12xf32> +func.func @canonicalize_concat_slice_middle_axis(%arg0 : tensor<1x12x12xf32>, %arg1 : tensor<1x12x12xf32>) -> (tensor<1x12x12xf32>, tensor<1x12x12xf32>) { + %0 = "tosa.concat"(%arg0, %arg1) {axis = 1 : i64} : (tensor<1x12x12xf32>, tensor<1x12x12xf32>) -> tensor<1x24x12xf32> + %1 = "tosa.slice"(%0) {size = array, start = array} : (tensor<1x24x12xf32>) 
-> tensor<1x12x12xf32> + %2 = "tosa.slice"(%0) {size = array, start = array} : (tensor<1x24x12xf32>) -> tensor<1x12x12xf32> + return %1, %2 : tensor<1x12x12xf32>, tensor<1x12x12xf32> +} + +// ----- + +// CHECK-LABEL: @canonicalize_cross_concat_inputs +// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x12x12xf32>, %[[VAL_1:.*]]: tensor<1x12x12xf32> +// CHECK: %[[VAL_2:.*]] = "tosa.concat"(%[[VAL_0]], %[[VAL_1]]) {axis = 2 : i64} : (tensor<1x12x12xf32>, tensor<1x12x12xf32>) -> tensor<1x12x24xf32> +// CHECK: %[[VAL_3:.*]] = "tosa.slice"(%[[VAL_2]]) {size = array, start = array} : (tensor<1x12x24xf32>) -> tensor<1x12x15xf32> +// CHECK: %[[VAL_4:.*]] = "tosa.slice"(%[[VAL_2]]) {size = array, start = array} : (tensor<1x12x24xf32>) -> tensor<1x12x20xf32> +// CHECK: return %[[VAL_3]], %[[VAL_4]] : tensor<1x12x15xf32>, tensor<1x12x20xf32> +func.func @canonicalize_cross_concat_inputs(%arg0 : tensor<1x12x12xf32>, %arg1 : tensor<1x12x12xf32>) -> (tensor<1x12x15xf32>, tensor<1x12x20xf32>) { + %0 = "tosa.concat"(%arg0, %arg1) {axis = 2 : i64} : (tensor<1x12x12xf32>, tensor<1x12x12xf32>) -> tensor<1x12x24xf32> + %1 = "tosa.slice"(%0) {size = array, start = array} : (tensor<1x12x24xf32>) -> tensor<1x12x15xf32> + %2 = "tosa.slice"(%0) {size = array, start = array} : (tensor<1x12x24xf32>) -> tensor<1x12x20xf32> + return %1, %2 : tensor<1x12x15xf32>, tensor<1x12x20xf32> +} + +// ----- + +// CHECK-LABEL: @canonicalize_concat_slice_on_non_concat_axis +// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x12x12xf32>, %[[VAL_1:.*]]: tensor<1x12x12xf32> +// CHECK: %[[VAL_2:.*]] = "tosa.slice"(%[[VAL_0]]) {size = array, start = array} : (tensor<1x12x12xf32>) -> tensor<1x6x12xf32> +// CHECK: %[[VAL_3:.*]] = "tosa.slice"(%[[VAL_1]]) {size = array, start = array} : (tensor<1x12x12xf32>) -> tensor<1x3x12xf32> +// CHECK: return %[[VAL_2]], %[[VAL_3]] : tensor<1x6x12xf32>, tensor<1x3x12xf32> +func.func @canonicalize_concat_slice_on_non_concat_axis(%arg0 : tensor<1x12x12xf32>, %arg1 : tensor<1x12x12xf32>) -> 
(tensor<1x6x12xf32>, tensor<1x3x12xf32>) { + %0 = "tosa.concat"(%arg0, %arg1) {axis = 2 : i64} : (tensor<1x12x12xf32>, tensor<1x12x12xf32>) -> tensor<1x12x24xf32> + %1 = "tosa.slice"(%0) {size = array, start = array} : (tensor<1x12x24xf32>) -> tensor<1x6x12xf32> + %2 = "tosa.slice"(%0) {size = array, start = array} : (tensor<1x12x24xf32>) -> tensor<1x3x12xf32> + return %1, %2 : tensor<1x6x12xf32>, tensor<1x3x12xf32> +} -- 2.7.4