results.add<ClampClampOptimization>(context);
}
+struct ConcatSliceOptimization : public OpRewritePattern<tosa::SliceOp> {
+ using OpRewritePattern<tosa::SliceOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(tosa::SliceOp sliceOp,
+ PatternRewriter &rewriter) const override {
+ Value sliceInput = sliceOp.getInput();
+ auto concatOp = sliceInput.getDefiningOp<tosa::ConcatOp>();
+ if (!concatOp)
+ return rewriter.notifyMatchFailure(
+ sliceOp, "slice input must be concat operation");
+
+ OperandRange inputs = concatOp.getInput1();
+ auto concatType = dyn_cast<RankedTensorType>(concatOp.getType());
+ if (!concatType || !concatType.hasStaticShape())
+ return rewriter.notifyMatchFailure(
+ sliceOp, "slice input must be a static ranked tensor");
+ int32_t axis = concatOp.getAxis();
+
+ llvm::SmallVector<int64_t> sliceStart(sliceOp.getStart());
+ llvm::ArrayRef<int64_t> sliceSize = sliceOp.getSize();
+
+ // Validate slice on the concatenated axis. Slicing along this
+ // axis should span only one of the inputs to the concatenate
+ // operation.
+ std::optional<Value> replaceWithSlice;
+ for (auto input : inputs) {
+ auto inputType = dyn_cast<RankedTensorType>(input.getType());
+ if (!inputType || !inputType.hasStaticShape())
+ return rewriter.notifyMatchFailure(
+ sliceOp, "concat input must be a static ranked tensor");
+
+ if (sliceStart[axis] >= 0 &&
+ (sliceStart[axis] + sliceSize[axis]) <= inputType.getDimSize(axis)) {
+ replaceWithSlice =
+ rewriter
+ .create<tosa::SliceOp>(
+ sliceOp.getLoc(), sliceOp.getType(), input,
+ rewriter.getDenseI64ArrayAttr(sliceOp.getStart()),
+ rewriter.getDenseI64ArrayAttr(sliceSize))
+ .getResult();
+ break;
+ }
+ sliceStart[axis] -= inputType.getDimSize(axis);
+ }
+
+ if (!replaceWithSlice)
+ return rewriter.notifyMatchFailure(
+ sliceOp, "corresponding concat input not found for slice");
+
+ rewriter.replaceOp(sliceOp, replaceWithSlice.value());
+ return success();
+ }
+};
+
+void SliceOp::getCanonicalizationPatterns(RewritePatternSet &results,
+ MLIRContext *context) {
+ results.add<ConcatSliceOptimization>(context);
+}
+
//===----------------------------------------------------------------------===//
// Operator Folders.
//===----------------------------------------------------------------------===//
%resize = "tosa.resize"(%arg0) {mode = "BILINEAR", scale = array<i64: 2, 2, 1, 1>, offset = array<i64: 0, 0>, border = array<i64: 0, 0>} : (tensor<1x15x13x1xi8>) -> tensor<1x15x13x1xi8>
return %resize : tensor<1x15x13x1xi8>
}
+
+// -----
+
+// CHECK-LABEL: @canonicalize_concat_slice_final_axis
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x12x12x1xf32>, %[[VAL_1:.*]]: tensor<1x12x12x1xf32>
+// CHECK: return %[[VAL_0]], %[[VAL_1]] : tensor<1x12x12x1xf32>, tensor<1x12x12x1xf32>
+func.func @canonicalize_concat_slice_final_axis(%arg0 : tensor<1x12x12x1xf32>, %arg1 : tensor<1x12x12x1xf32>) -> (tensor<1x12x12x1xf32>, tensor<1x12x12x1xf32>) {
+ %0 = "tosa.concat"(%arg0, %arg1) {axis = 3 : i64} : (tensor<1x12x12x1xf32>, tensor<1x12x12x1xf32>) -> tensor<1x12x12x2xf32>
+ %1 = "tosa.slice"(%0) {size = array<i64: 1, 12, 12, 1>, start = array<i64: 0, 0, 0, 0>} : (tensor<1x12x12x2xf32>) -> tensor<1x12x12x1xf32>
+ %2 = "tosa.slice"(%0) {size = array<i64: 1, 12, 12, 1>, start = array<i64: 0, 0, 0, 1>} : (tensor<1x12x12x2xf32>) -> tensor<1x12x12x1xf32>
+ return %1, %2 : tensor<1x12x12x1xf32>, tensor<1x12x12x1xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @canonicalize_concat_slice_middle_axis
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x12x12xf32>, %[[VAL_1:.*]]: tensor<1x12x12xf32>
+// CHECK: return %[[VAL_0]], %[[VAL_1]] : tensor<1x12x12xf32>, tensor<1x12x12xf32>
+func.func @canonicalize_concat_slice_middle_axis(%arg0 : tensor<1x12x12xf32>, %arg1 : tensor<1x12x12xf32>) -> (tensor<1x12x12xf32>, tensor<1x12x12xf32>) {
+ %0 = "tosa.concat"(%arg0, %arg1) {axis = 1 : i64} : (tensor<1x12x12xf32>, tensor<1x12x12xf32>) -> tensor<1x24x12xf32>
+ %1 = "tosa.slice"(%0) {size = array<i64: 1, 12, 12>, start = array<i64: 0, 0, 0>} : (tensor<1x24x12xf32>) -> tensor<1x12x12xf32>
+ %2 = "tosa.slice"(%0) {size = array<i64: 1, 12, 12>, start = array<i64: 0, 12, 0>} : (tensor<1x24x12xf32>) -> tensor<1x12x12xf32>
+ return %1, %2 : tensor<1x12x12xf32>, tensor<1x12x12xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @canonicalize_cross_concat_inputs
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x12x12xf32>, %[[VAL_1:.*]]: tensor<1x12x12xf32>
+// CHECK: %[[VAL_2:.*]] = "tosa.concat"(%[[VAL_0]], %[[VAL_1]]) {axis = 2 : i64} : (tensor<1x12x12xf32>, tensor<1x12x12xf32>) -> tensor<1x12x24xf32>
+// CHECK: %[[VAL_3:.*]] = "tosa.slice"(%[[VAL_2]]) {size = array<i64: 1, 12, 15>, start = array<i64: 0, 0, 0>} : (tensor<1x12x24xf32>) -> tensor<1x12x15xf32>
+// CHECK: %[[VAL_4:.*]] = "tosa.slice"(%[[VAL_2]]) {size = array<i64: 1, 12, 20>, start = array<i64: 0, 0, 4>} : (tensor<1x12x24xf32>) -> tensor<1x12x20xf32>
+// CHECK: return %[[VAL_3]], %[[VAL_4]] : tensor<1x12x15xf32>, tensor<1x12x20xf32>
+func.func @canonicalize_cross_concat_inputs(%arg0 : tensor<1x12x12xf32>, %arg1 : tensor<1x12x12xf32>) -> (tensor<1x12x15xf32>, tensor<1x12x20xf32>) {
+ %0 = "tosa.concat"(%arg0, %arg1) {axis = 2 : i64} : (tensor<1x12x12xf32>, tensor<1x12x12xf32>) -> tensor<1x12x24xf32>
+ %1 = "tosa.slice"(%0) {size = array<i64: 1, 12, 15>, start = array<i64: 0, 0, 0>} : (tensor<1x12x24xf32>) -> tensor<1x12x15xf32>
+ %2 = "tosa.slice"(%0) {size = array<i64: 1, 12, 20>, start = array<i64: 0, 0, 4>} : (tensor<1x12x24xf32>) -> tensor<1x12x20xf32>
+ return %1, %2 : tensor<1x12x15xf32>, tensor<1x12x20xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @canonicalize_concat_slice_on_non_concat_axis
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x12x12xf32>, %[[VAL_1:.*]]: tensor<1x12x12xf32>
+// CHECK: %[[VAL_2:.*]] = "tosa.slice"(%[[VAL_0]]) {size = array<i64: 1, 6, 12>, start = array<i64: 0, 0, 0>} : (tensor<1x12x12xf32>) -> tensor<1x6x12xf32>
+// CHECK: %[[VAL_3:.*]] = "tosa.slice"(%[[VAL_1]]) {size = array<i64: 1, 3, 12>, start = array<i64: 1, 3, 12>} : (tensor<1x12x12xf32>) -> tensor<1x3x12xf32>
+// CHECK: return %[[VAL_2]], %[[VAL_3]] : tensor<1x6x12xf32>, tensor<1x3x12xf32>
+func.func @canonicalize_concat_slice_on_non_concat_axis(%arg0 : tensor<1x12x12xf32>, %arg1 : tensor<1x12x12xf32>) -> (tensor<1x6x12xf32>, tensor<1x3x12xf32>) {
+ %0 = "tosa.concat"(%arg0, %arg1) {axis = 2 : i64} : (tensor<1x12x12xf32>, tensor<1x12x12xf32>) -> tensor<1x12x24xf32>
+ %1 = "tosa.slice"(%0) {size = array<i64: 1, 6, 12>, start = array<i64: 0, 0, 0>} : (tensor<1x12x24xf32>) -> tensor<1x6x12xf32>
+ %2 = "tosa.slice"(%0) {size = array<i64: 1, 3, 12>, start = array<i64: 1, 3, 12>} : (tensor<1x12x24xf32>) -> tensor<1x3x12xf32>
+ return %1, %2 : tensor<1x6x12xf32>, tensor<1x3x12xf32>
+}