From 2d61628c1f49963921b9ac1995218191dc5e3091 Mon Sep 17 00:00:00 2001
From: Ramiro Leal-Cavazos
Date: Tue, 14 Mar 2023 18:34:06 +0000
Subject: [PATCH] [mlir][tosa] Swap reshape at end of reduce op with
 expand_shape

This commit swaps back the `tosa.reshape` op used at the end of the
lowering for reduce ops with the op `tensor.expand_shape`. This is
needed to properly support dynamically-sized tensors. In such cases,
lowering directly to `tensor.expand_shape` allows us to control which
dimension gets expanded at the end using the knowledge of the
reduction. This would not be possible when using `tosa.reshape`, since
the op does not have a way of knowing that we are only unsqueezing a
single dimension.

Note: this change had previously been performed in
https://reviews.llvm.org/D133877.

Reviewed By: rsuderman

Differential Revision: https://reviews.llvm.org/D145986
---
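Note (not part of the commit message): the following is an illustrative
sketch of the new lowering; the value names are invented, and the
shapes mirror the updated `reduce_float_dyn` test. Reducing axis 1 of a
dynamically-sized input

  %0 = "tosa.reduce_sum"(%arg0) {axis = 1 : i64}
         : (tensor<?x5x4xf32>) -> tensor<?x1x4xf32>

now lowers to roughly

  %red = linalg.generic ... outs(%fill : tensor<?x4xf32>) ...
  %exp = tensor.expand_shape %red [[0], [1, 2]]
         : tensor<?x4xf32> into tensor<?x1x4xf32>

The reassociation indices [[0], [1, 2]] pin the new unit dimension to
the reduced axis, so the dynamic dimension keeps its position. A
`tosa.reshape` producing the same result type carries no such
per-dimension mapping, which is why it could not handle this case
reliably.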
 mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp | 25 +++++++++++++++++++---
 .../Conversion/TosaToLinalg/tosa-to-linalg.mlir   | 18 ++++++++--------
 2 files changed, 31 insertions(+), 12 deletions(-)

diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index f6ca019..271a095 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -807,9 +807,28 @@ static LogicalResult reduceMatchAndRewriteHelper(Operation *op, uint64_t axis,
     return rewriter.notifyMatchFailure(
         op, "unable to create linalg.generic body for reduce op");

-  rewriter.replaceOpWithNewOp<tosa::ReshapeOp>(
-      op, resultTy, linalgOp.getResults()[0],
-      rewriter.getDenseI64ArrayAttr(resultTy.getShape()));
+  SmallVector<ReassociationExprs, 4> reassociationMap;
+  uint64_t expandInputRank =
+      linalgOp.getResults()[0].getType().cast<ShapedType>().getRank();
+  reassociationMap.resize(expandInputRank);
+
+  for (uint64_t i = 0; i < expandInputRank; i++) {
+    int32_t dimToPush = i > axis ? i + 1 : i;
+    reassociationMap[i].push_back(rewriter.getAffineDimExpr(dimToPush));
+  }
+
+  if (expandInputRank != 0) {
+    int32_t expandedDim = axis < expandInputRank ? axis : expandInputRank - 1;
+    reassociationMap[expandedDim].push_back(
+        rewriter.getAffineDimExpr(expandedDim + 1));
+  }
+
+  // Lower directly to `tensor::ExpandShapeOp` instead of `tosa::ReshapeOp`,
+  // since here we know which dimension to expand, and `tosa::ReshapeOp` would
+  // not have access to such information. This matters when handling dynamically
+  // sized tensors.
+  rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
+      op, resultTy, linalgOp.getResults()[0], reassociationMap);

   return success();
 }
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index 427fe6b..133999e 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -626,7 +626,7 @@ func.func @reduce_float(%arg0: tensor<5x4xf32>) -> () {
   // CHECK: ^bb0(%[[ARG1:.*]]: f32, %[[ARG2:.*]]: f32)
   // CHECK: [[RES:%.+]] = arith.addf %[[ARG1]], %[[ARG2]] : f32
   // CHECK: linalg.yield [[RES]] : f32
-  // CHECK: "tosa.reshape"([[GENERIC]]) {new_shape = array<i64: 1, 4>}
+  // CHECK: tensor.expand_shape [[GENERIC]] {{\[}}[0, 1]] : tensor<4xf32> into tensor<1x4xf32>
   %0 = "tosa.reduce_sum"(%arg0) {axis = 0 : i64} : (tensor<5x4xf32>) -> tensor<1x4xf32>

   // CHECK: [[INIT:%.+]] = tensor.empty() : tensor<5xf32>
@@ -636,7 +636,7 @@ func.func @reduce_float(%arg0: tensor<5x4xf32>) -> () {
   // CHECK: ^bb0(%[[ARG1:.*]]: f32, %[[ARG2:.*]]: f32)
   // CHECK: [[RES:%.+]] = arith.addf %[[ARG1]], %[[ARG2]] : f32
   // CHECK: linalg.yield [[RES]] : f32
-  // CHECK: "tosa.reshape"([[GENERIC]]) {new_shape = array<i64: 5, 1>}
+  // CHECK: tensor.expand_shape [[GENERIC]] {{\[}}[0, 1]] : tensor<5xf32> into tensor<5x1xf32>
   %1 = "tosa.reduce_sum"(%arg0) {axis = 1 : i64} : (tensor<5x4xf32>) -> tensor<5x1xf32>

   // CHECK: arith.constant 1.0
@@ -676,7 +676,7 @@ func.func @reduce_float_dyn(%arg0: tensor<?x5x4xf32>) -> () {
   // CHECK: ^bb0(%[[ARG1:.*]]: f32, %[[ARG2:.*]]: f32)
   // CHECK: %[[RES:.+]] = arith.addf %[[ARG1]], %[[ARG2]] : f32
   // CHECK: linalg.yield %[[RES]] : f32
-  // CHECK: "tosa.reshape"(%[[GENERIC]]) {new_shape = array<i64: -1, 1, 4>}
+  // CHECK: tensor.expand_shape %[[GENERIC]] {{\[}}[0], [1, 2]] : tensor<?x4xf32> into tensor<?x1x4xf32>
   %0 = "tosa.reduce_sum"(%arg0) {axis = 1 : i64} : (tensor<?x5x4xf32>) -> tensor<?x1x4xf32>
   return
 }
@@ -696,7 +696,7 @@ func.func @reduce_float_dyn_rank_1(%arg0: tensor<?xf32>) -> () {
   // CHECK: ^bb0(%[[ARG1:.*]]: f32, %[[ARG2:.*]]: f32)
   // CHECK: %[[RES:.+]] = arith.addf %[[ARG1]], %[[ARG2]] : f32
   // CHECK: linalg.yield %[[RES]] : f32
-  // CHECK: "tosa.reshape"(%[[GENERIC]]) {new_shape = array<i64: 1>} : (tensor<f32>) -> tensor<1xf32>
+  // CHECK: tensor.expand_shape %[[GENERIC]] {{\[}}] : tensor<f32> into tensor<1xf32>
   %0 = "tosa.reduce_sum"(%arg0) {axis = 0 : i64} : (tensor<?xf32>) -> tensor<1xf32>
   return
 }
@@ -718,7 +718,7 @@ func.func @reduce_float_dyn_nonzero_batch(%arg0: tensor<5x?x4xf32>) -> () {
   // CHECK: ^bb0(%[[ARG1:.*]]: f32, %[[ARG2:.*]]: f32)
   // CHECK: %[[RES:.+]] = arith.mulf %[[ARG1]], %[[ARG2]] : f32
   // CHECK: linalg.yield %[[RES]] : f32
-  // CHECK: "tosa.reshape"(%[[GENERIC]]) {new_shape = array<i64: 5, -1, 1>}
+  // CHECK: tensor.expand_shape %[[GENERIC]] {{\[}}[0], [1, 2]] : tensor<5x?xf32> into tensor<5x?x1xf32>
   %0 = "tosa.reduce_prod"(%arg0) {axis = 2 : i64} : (tensor<5x?x4xf32>) -> tensor<5x?x1xf32>
   return
 }
@@ -740,7 +740,7 @@ func.func @reduce_float_dyn_multiple(%arg0: tensor<?x?xf32>) -> () {
   // CHECK: ^bb0(%[[ARG1:.*]]: f32, %[[ARG2:.*]]: f32)
   // CHECK: %[[MAX:.+]] = arith.maxf %[[ARG1]], %[[ARG2]] : f32
   // CHECK: linalg.yield %[[MAX]] : f32
-  // CHECK: "tosa.reshape"(%[[GENERIC]]) {new_shape = array<i64: -1, 1>}
+  // CHECK: tensor.expand_shape %[[GENERIC]] {{\[}}[0, 1]] : tensor<?xf32> into tensor<?x1xf32>
   %0 = "tosa.reduce_max"(%arg0) {axis = 1 : i64} : (tensor<?x?xf32>) -> tensor<?x1xf32>
   return
 }
@@ -761,7 +761,7 @@ func.func @reduce_int(%arg0: tensor<5x4xi32>) -> () {
   // CHECK: ^bb0(%[[ARG1:.*]]: i32, %[[ARG2:.*]]: i32)
   // CHECK: [[RES:%.+]] = arith.addi %[[ARG1]], %[[ARG2]] : i32
   // CHECK: linalg.yield [[RES]] : i32
"tosa.reshape"([[GENERIC]]) {new_shape = array} + // CHECK: tensor.expand_shape [[GENERIC]] {{\[}}[0, 1]] : tensor<4xi32> into tensor<1x4xi32> %0 = "tosa.reduce_sum"(%arg0) {axis = 0 : i64} : (tensor<5x4xi32>) -> tensor<1x4xi32> // CHECK: [[INIT:%.+]] = tensor.empty() @@ -771,7 +771,7 @@ func.func @reduce_int(%arg0: tensor<5x4xi32>) -> () { // CHECK: ^bb0(%[[ARG1:.*]]: i32, %[[ARG2:.*]]: i32) // CHECK: [[RES:%.+]] = arith.addi %[[ARG1]], %[[ARG2]] : i32 // CHECK: linalg.yield [[RES]] : i32 - // CHECK: "tosa.reshape"([[GENERIC]]) {new_shape = array} + // CHECK: tensor.expand_shape [[GENERIC]] {{\[}}[0, 1]] : tensor<5xi32> into tensor<5x1xi32> %1 = "tosa.reduce_sum"(%arg0) {axis = 1 : i64} : (tensor<5x4xi32>) -> tensor<5x1xi32> // CHECK: arith.constant 1 @@ -811,7 +811,7 @@ func.func @reduce_bool(%arg0: tensor<5x4xi1>) -> () { // CHECK: ^bb0(%[[ARG1:[0-9a-zA-Z_]+]]: i1, %[[ARG2:[0-9a-zA-Z_]+]]: i1) // CHECK: [[RES:%.+]] = arith.andi %[[ARG1]], %[[ARG2]] : i1 // CHECK: linalg.yield [[RES]] : i1 - // CHECK: "tosa.reshape"([[GENERIC]]) {new_shape = array} + // CHECK: tensor.expand_shape [[GENERIC]] {{\[}}[0, 1]] : tensor<4xi1> into tensor<1x4xi1> %0 = "tosa.reduce_all"(%arg0) {axis = 0 : i64} : (tensor<5x4xi1>) -> tensor<1x4xi1> // CHECK: arith.constant false -- 2.7.4