From 9bcda47afcb4af2831654d7c31ad2e956202fed0 Mon Sep 17 00:00:00 2001
From: natashaknk
Date: Mon, 3 Oct 2022 10:07:57 -0700
Subject: [PATCH] [mlir][tosa] Swap the reshape at the end of the reduce op
 for an expand_shape in tosa-to-linalg

Reviewed By: rsuderman

Differential Revision: https://reviews.llvm.org/D133877
---
 mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp     | 17 +++++++++++++++--
 mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir |  6 ++----
 2 files changed, 17 insertions(+), 6 deletions(-)

diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index 11c7d06..b54dab8 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -815,8 +815,21 @@ static LogicalResult reduceMatchAndRewriteHelper(Operation *op, uint64_t axis,
   if (!didEncounterError)
     return failure();
 
-  rewriter.replaceOpWithNewOp<tosa::ReshapeOp>(op, resultTy,
-                                               linalgOp.getResults());
+  SmallVector<ReassociationExprs, 4> reassociationMap;
+  uint64_t expandInputRank =
+      linalgOp.getResults()[0].getType().cast<ShapedType>().getRank();
+  reassociationMap.resize(expandInputRank);
+
+  for (uint64_t i = 0; i < expandInputRank; i++) {
+    int32_t dimToPush = i > axis ? i + 1 : i;
+    reassociationMap[i].push_back(rewriter.getAffineDimExpr(dimToPush));
+  }
+  int32_t expandedDim = axis < expandInputRank ? axis : expandInputRank - 1;
+  reassociationMap[expandedDim].push_back(
+      rewriter.getAffineDimExpr(expandedDim + 1));
+
+  rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
+      op, resultTy, linalgOp.getResults()[0], reassociationMap);
 
   return success();
 }
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index 2e4dfbf..685f782 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -746,8 +746,7 @@ func.func @reduce_float_dyn(%arg0: tensor<?x5x4xf32>) -> () {
   // CHECK: ^bb0(%arg1: f32, %arg2: f32)
   // CHECK:   %[[RES:.+]] = arith.addf %arg1, %arg2 : f32
   // CHECK:   linalg.yield %[[RES]] : f32
-  // CHECK: %[[COLLAPSED:.+]] = tensor.collapse_shape %[[GENERIC]] {{\[}}[0, 1]] : tensor<?x4xf32> into tensor<?xf32>
-  // CHECK: tensor.expand_shape %[[COLLAPSED]] {{\[}}[0, 1, 2]] : tensor<?xf32> into tensor<?x1x4xf32>
+  // CHECK: tensor.expand_shape %[[GENERIC]] {{\[}}[0], [1, 2]] : tensor<?x4xf32> into tensor<?x1x4xf32>
   %0 = "tosa.reduce_sum"(%arg0) {axis = 1 : i64} : (tensor<?x5x4xf32>) -> tensor<?x1x4xf32>
   return
 }
@@ -768,8 +767,7 @@ func.func @reduce_float_dyn_nonzero_batch(%arg0: tensor<5x?x4xf32>) -> () {
   // CHECK: ^bb0(%arg1: f32, %arg2: f32)
   // CHECK:   %[[RES:.+]] = arith.mulf %arg1, %arg2 : f32
   // CHECK:   linalg.yield %[[RES]] : f32
-  // CHECK: %[[COLLAPSED:.+]] = tensor.collapse_shape %[[GENERIC]] {{\[}}[0, 1]] : tensor<5x?xf32> into tensor<?xf32>
-  // CHECK: tensor.expand_shape %[[COLLAPSED]] {{\[}}[0, 1, 2]] : tensor<?xf32> into tensor<5x?x1xf32>
+  // CHECK: tensor.expand_shape %[[GENERIC]] {{\[}}[0], [1, 2]] : tensor<5x?xf32> into tensor<5x?x1xf32>
   %0 = "tosa.reduce_prod"(%arg0) {axis = 2 : i64} : (tensor<5x?x4xf32>) -> tensor<5x?x1xf32>
   return
 }
-- 
2.7.4
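
For illustration, a minimal sketch of what changes for the @reduce_float_dyn
test above (the %generic, %collapsed, %old, and %new value names are
placeholders, and the snippet assumes the tensor.expand_shape syntax of this
era, with no explicit output_shape operand). The linalg.generic reduction over
axis 1 of a tensor<?x5x4xf32> yields the rank-reduced tensor<?x4xf32>; the old
lowering rebuilt tensor<?x1x4xf32> by round-tripping through rank 1, while the
new one reattaches the unit dimension in a single step:

    // Old: tosa.reshape lowered to a collapse/expand round trip through rank 1.
    %collapsed = tensor.collapse_shape %generic [[0, 1]]
        : tensor<?x4xf32> into tensor<?xf32>
    %old = tensor.expand_shape %collapsed [[0, 1, 2]]
        : tensor<?xf32> into tensor<?x1x4xf32>

    // New: one expand_shape; result dims 1 and 2 both derive from source
    // dim 1, so only the unit dimension at the reduced axis is materialized.
    %new = tensor.expand_shape %generic [[0], [1, 2]]
        : tensor<?x4xf32> into tensor<?x1x4xf32>

The reassociation map built in reduceMatchAndRewriteHelper encodes exactly
this: each of the expandInputRank source dims gets its matching result dim
(shifted by one past the reduction axis), and the entry at the clamped axis
also absorbs the following result dim, producing the single two-dim group
([1, 2] above) that carries the restored unit dimension.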