[mlir][tosa] Swap reshape at end of reduce op with expand_shape
authorRamiro Leal-Cavazos <ramiroleal050@gmail.com>
Tue, 14 Mar 2023 18:34:06 +0000 (18:34 +0000)
committerRobert Suderman <suderman@google.com>
Tue, 14 Mar 2023 18:51:39 +0000 (18:51 +0000)
This commit swaps back the `tosa.reshape` op used at the end of the
lowering for reduce ops with the op `tensor.expand_shape`. This is
needed to properly support dynamically-sized tensors. In such cases,
lowering directly to `tensor.expand_shape` allows us to control which
dimension gets expanded at the end using the knowledge of the
reduction. This would not be possible when using `tosa.reshape`, since
the op does not have a way of knowing that we are only unsqueezing a
single dimension.

Note: this change had previously been performed in
https://reviews.llvm.org/D133877.

Reviewed By: rsuderman

Differential Revision: https://reviews.llvm.org/D145986

mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir

index f6ca019..271a095 100644 (file)
@@ -807,9 +807,28 @@ static LogicalResult reduceMatchAndRewriteHelper(Operation *op, uint64_t axis,
     return rewriter.notifyMatchFailure(
         op, "unable to create linalg.generic body for reduce op");
 
-  rewriter.replaceOpWithNewOp<tosa::ReshapeOp>(
-      op, resultTy, linalgOp.getResults()[0],
-      rewriter.getDenseI64ArrayAttr(resultTy.getShape()));
+  SmallVector<ReassociationExprs, 4> reassociationMap;
+  uint64_t expandInputRank =
+      linalgOp.getResults()[0].getType().cast<ShapedType>().getRank();
+  reassociationMap.resize(expandInputRank);
+
+  for (uint64_t i = 0; i < expandInputRank; i++) {
+    int32_t dimToPush = i > axis ? i + 1 : i;
+    reassociationMap[i].push_back(rewriter.getAffineDimExpr(dimToPush));
+  }
+
+  if (expandInputRank != 0) {
+    int32_t expandedDim = axis < expandInputRank ? axis : expandInputRank - 1;
+    reassociationMap[expandedDim].push_back(
+        rewriter.getAffineDimExpr(expandedDim + 1));
+  }
+
+  // Lower directly to `tensor::ExpandShapeOp` instead of `tosa::ReshapeOp`,
+  // since here we know which dimension to expand, and `tosa::ReshapeOp` would
+  // not have access to such information. This matters when handling dynamically
+  // sized tensors.
+  rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
+      op, resultTy, linalgOp.getResults()[0], reassociationMap);
   return success();
 }
 
index 427fe6b..133999e 100644 (file)
@@ -626,7 +626,7 @@ func.func @reduce_float(%arg0: tensor<5x4xf32>) -> () {
   // CHECK: ^bb0(%[[ARG1:.*]]: f32, %[[ARG2:.*]]: f32)
   // CHECK:   [[RES:%.+]] = arith.addf %[[ARG1]], %[[ARG2]] : f32
   // CHECK:   linalg.yield [[RES]] : f32
-  // CHECK: "tosa.reshape"([[GENERIC]]) {new_shape = array<i64: 1, 4>}
+  // CHECK: tensor.expand_shape [[GENERIC]] {{\[}}[0, 1]] : tensor<4xf32> into tensor<1x4xf32>
   %0 = "tosa.reduce_sum"(%arg0) {axis = 0 : i64} : (tensor<5x4xf32>) -> tensor<1x4xf32>
 
   // CHECK: [[INIT:%.+]] = tensor.empty() : tensor<5xf32>
@@ -636,7 +636,7 @@ func.func @reduce_float(%arg0: tensor<5x4xf32>) -> () {
   // CHECK: ^bb0(%[[ARG1:.*]]: f32, %[[ARG2:.*]]: f32)
   // CHECK:   [[RES:%.+]] = arith.addf %[[ARG1]], %[[ARG2]] : f32
   // CHECK:   linalg.yield [[RES]] : f32
-  // CHECK: "tosa.reshape"([[GENERIC]]) {new_shape = array<i64: 5, 1>}
+  // CHECK: tensor.expand_shape [[GENERIC]] {{\[}}[0, 1]] : tensor<5xf32> into tensor<5x1xf32>
   %1 = "tosa.reduce_sum"(%arg0) {axis = 1 : i64} : (tensor<5x4xf32>) -> tensor<5x1xf32>
 
   // CHECK: arith.constant 1.0
@@ -676,7 +676,7 @@ func.func @reduce_float_dyn(%arg0: tensor<?x5x4xf32>) -> () {
   // CHECK: ^bb0(%[[ARG1:.*]]: f32, %[[ARG2:.*]]: f32)
   // CHECK:   %[[RES:.+]] = arith.addf %[[ARG1]], %[[ARG2]] : f32
   // CHECK:   linalg.yield %[[RES]] : f32
-  // CHECK: "tosa.reshape"(%[[GENERIC]]) {new_shape = array<i64: -9223372036854775808, 1, 4>}
+  // CHECK: tensor.expand_shape %[[GENERIC]] {{\[}}[0], [1, 2]] : tensor<?x4xf32> into tensor<?x1x4xf32>
   %0 = "tosa.reduce_sum"(%arg0) {axis = 1 : i64} : (tensor<?x5x4xf32>) -> tensor<?x1x4xf32>
   return
 }
@@ -696,7 +696,7 @@ func.func @reduce_float_dyn_rank_1(%arg0: tensor<?xf32>) -> () {
   // CHECK: ^bb0(%[[ARG1:.*]]: f32, %[[ARG2:.*]]: f32)
   // CHECK:   %[[RES:.+]] = arith.addf %[[ARG1]], %[[ARG2]] : f32
   // CHECK:   linalg.yield %[[RES]] : f32
-  // CHECK: "tosa.reshape"(%[[GENERIC]]) {new_shape = array<i64: 1>} : (tensor<f32>) -> tensor<1xf32>
+  // CHECK: tensor.expand_shape %[[GENERIC]] {{\[}}] : tensor<f32> into tensor<1xf32>
   %0 = "tosa.reduce_sum"(%arg0) {axis = 0 : i64} : (tensor<?xf32>) -> tensor<1xf32>
   return
 }
@@ -718,7 +718,7 @@ func.func @reduce_float_dyn_nonzero_batch(%arg0: tensor<5x?x4xf32>) -> () {
   // CHECK: ^bb0(%[[ARG1:.*]]: f32, %[[ARG2:.*]]: f32)
   // CHECK:   %[[RES:.+]] = arith.mulf %[[ARG1]], %[[ARG2]] : f32
   // CHECK:   linalg.yield %[[RES]] : f32
-  // CHECK: "tosa.reshape"(%[[GENERIC]]) {new_shape = array<i64: 5, -9223372036854775808, 1>}
+  // CHECK: tensor.expand_shape %[[GENERIC]] {{\[}}[0], [1, 2]] : tensor<5x?xf32> into tensor<5x?x1xf32>
   %0 = "tosa.reduce_prod"(%arg0) {axis = 2 : i64} : (tensor<5x?x4xf32>) -> tensor<5x?x1xf32>
   return
 }
@@ -740,7 +740,7 @@ func.func @reduce_float_dyn_multiple(%arg0: tensor<?x?xf32>) -> () {
   // CHECK: ^bb0(%[[ARG1:.*]]: f32, %[[ARG2:.*]]: f32)
   // CHECK:   %[[MAX:.+]] = arith.maxf %[[ARG1]], %[[ARG2]] : f32
   // CHECK:   linalg.yield %[[MAX]] : f32
-  // CHECK: "tosa.reshape"(%[[GENERIC]]) {new_shape = array<i64: -9223372036854775808, 1>}
+  // CHECK: tensor.expand_shape %[[GENERIC]] {{\[}}[0, 1]] : tensor<?xf32> into tensor<?x1xf32>
   %0 = "tosa.reduce_max"(%arg0) {axis = 1 : i64} : (tensor<?x?xf32>) -> tensor<?x1xf32>
   return
 }
@@ -761,7 +761,7 @@ func.func @reduce_int(%arg0: tensor<5x4xi32>) -> () {
   // CHECK: ^bb0(%[[ARG1:.*]]: i32, %[[ARG2:.*]]: i32)
   // CHECK:   [[RES:%.+]] = arith.addi %[[ARG1]], %[[ARG2]] : i32
   // CHECK:   linalg.yield [[RES]] : i32
-  // CHECK: "tosa.reshape"([[GENERIC]]) {new_shape = array<i64: 1, 4>}
+  // CHECK: tensor.expand_shape [[GENERIC]] {{\[}}[0, 1]] : tensor<4xi32> into tensor<1x4xi32>
   %0 = "tosa.reduce_sum"(%arg0) {axis = 0 : i64} : (tensor<5x4xi32>) -> tensor<1x4xi32>
 
   // CHECK: [[INIT:%.+]] = tensor.empty()
@@ -771,7 +771,7 @@ func.func @reduce_int(%arg0: tensor<5x4xi32>) -> () {
   // CHECK: ^bb0(%[[ARG1:.*]]: i32, %[[ARG2:.*]]: i32)
   // CHECK:   [[RES:%.+]] = arith.addi %[[ARG1]], %[[ARG2]] : i32
   // CHECK:   linalg.yield [[RES]] : i32
-  // CHECK: "tosa.reshape"([[GENERIC]]) {new_shape = array<i64: 5, 1>}
+  // CHECK: tensor.expand_shape [[GENERIC]] {{\[}}[0, 1]] : tensor<5xi32> into tensor<5x1xi32>
   %1 = "tosa.reduce_sum"(%arg0) {axis = 1 : i64} : (tensor<5x4xi32>) -> tensor<5x1xi32>
 
   // CHECK: arith.constant 1
@@ -811,7 +811,7 @@ func.func @reduce_bool(%arg0: tensor<5x4xi1>) -> () {
   // CHECK: ^bb0(%[[ARG1:[0-9a-zA-Z_]+]]: i1, %[[ARG2:[0-9a-zA-Z_]+]]: i1)
   // CHECK:   [[RES:%.+]] = arith.andi %[[ARG1]], %[[ARG2]] : i1
   // CHECK:   linalg.yield [[RES]] : i1
-  // CHECK: "tosa.reshape"([[GENERIC]]) {new_shape = array<i64: 1, 4>}
+  // CHECK: tensor.expand_shape [[GENERIC]] {{\[}}[0, 1]] : tensor<4xi1> into tensor<1x4xi1>
   %0 = "tosa.reduce_all"(%arg0) {axis = 0 : i64} : (tensor<5x4xi1>) -> tensor<1x4xi1>
 
   // CHECK: arith.constant false