return rewriter.notifyMatchFailure(
op, "unable to create linalg.generic body for reduce op");
- rewriter.replaceOpWithNewOp<tosa::ReshapeOp>(
- op, resultTy, linalgOp.getResults()[0],
- rewriter.getDenseI64ArrayAttr(resultTy.getShape()));
+ SmallVector<ReassociationExprs, 4> reassociationMap;
+ uint64_t expandInputRank =
+ linalgOp.getResults()[0].getType().cast<ShapedType>().getRank();
+ reassociationMap.resize(expandInputRank);
+
+ for (uint64_t i = 0; i < expandInputRank; i++) {
+ int32_t dimToPush = i > axis ? i + 1 : i;
+ reassociationMap[i].push_back(rewriter.getAffineDimExpr(dimToPush));
+ }
+
+ if (expandInputRank != 0) {
+ int32_t expandedDim = axis < expandInputRank ? axis : expandInputRank - 1;
+ reassociationMap[expandedDim].push_back(
+ rewriter.getAffineDimExpr(expandedDim + 1));
+ }
+
+ // Lower directly to `tensor::ExpandShapeOp` instead of `tosa::ReshapeOp`,
+ // since here we know which dimension to expand, and `tosa::ReshapeOp` would
+ // not have access to such information. This matters when handling dynamically
+ // sized tensors.
+ rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
+ op, resultTy, linalgOp.getResults()[0], reassociationMap);
return success();
}
// CHECK: ^bb0(%[[ARG1:.*]]: f32, %[[ARG2:.*]]: f32)
// CHECK: [[RES:%.+]] = arith.addf %[[ARG1]], %[[ARG2]] : f32
// CHECK: linalg.yield [[RES]] : f32
- // CHECK: "tosa.reshape"([[GENERIC]]) {new_shape = array<i64: 1, 4>}
+ // CHECK: tensor.expand_shape [[GENERIC]] {{\[}}[0, 1]] : tensor<4xf32> into tensor<1x4xf32>
%0 = "tosa.reduce_sum"(%arg0) {axis = 0 : i64} : (tensor<5x4xf32>) -> tensor<1x4xf32>
// CHECK: [[INIT:%.+]] = tensor.empty() : tensor<5xf32>
// CHECK: ^bb0(%[[ARG1:.*]]: f32, %[[ARG2:.*]]: f32)
// CHECK: [[RES:%.+]] = arith.addf %[[ARG1]], %[[ARG2]] : f32
// CHECK: linalg.yield [[RES]] : f32
- // CHECK: "tosa.reshape"([[GENERIC]]) {new_shape = array<i64: 5, 1>}
+ // CHECK: tensor.expand_shape [[GENERIC]] {{\[}}[0, 1]] : tensor<5xf32> into tensor<5x1xf32>
%1 = "tosa.reduce_sum"(%arg0) {axis = 1 : i64} : (tensor<5x4xf32>) -> tensor<5x1xf32>
// CHECK: arith.constant 1.0
// CHECK: ^bb0(%[[ARG1:.*]]: f32, %[[ARG2:.*]]: f32)
// CHECK: %[[RES:.+]] = arith.addf %[[ARG1]], %[[ARG2]] : f32
// CHECK: linalg.yield %[[RES]] : f32
- // CHECK: "tosa.reshape"(%[[GENERIC]]) {new_shape = array<i64: -9223372036854775808, 1, 4>}
+ // CHECK: tensor.expand_shape %[[GENERIC]] {{\[}}[0], [1, 2]] : tensor<?x4xf32> into tensor<?x1x4xf32>
%0 = "tosa.reduce_sum"(%arg0) {axis = 1 : i64} : (tensor<?x5x4xf32>) -> tensor<?x1x4xf32>
return
}
// CHECK: ^bb0(%[[ARG1:.*]]: f32, %[[ARG2:.*]]: f32)
// CHECK: %[[RES:.+]] = arith.addf %[[ARG1]], %[[ARG2]] : f32
// CHECK: linalg.yield %[[RES]] : f32
- // CHECK: "tosa.reshape"(%[[GENERIC]]) {new_shape = array<i64: 1>} : (tensor<f32>) -> tensor<1xf32>
+ // CHECK: tensor.expand_shape %[[GENERIC]] {{\[}}] : tensor<f32> into tensor<1xf32>
%0 = "tosa.reduce_sum"(%arg0) {axis = 0 : i64} : (tensor<?xf32>) -> tensor<1xf32>
return
}
// CHECK: ^bb0(%[[ARG1:.*]]: f32, %[[ARG2:.*]]: f32)
// CHECK: %[[RES:.+]] = arith.mulf %[[ARG1]], %[[ARG2]] : f32
// CHECK: linalg.yield %[[RES]] : f32
- // CHECK: "tosa.reshape"(%[[GENERIC]]) {new_shape = array<i64: 5, -9223372036854775808, 1>}
+ // CHECK: tensor.expand_shape %[[GENERIC]] {{\[}}[0], [1, 2]] : tensor<5x?xf32> into tensor<5x?x1xf32>
%0 = "tosa.reduce_prod"(%arg0) {axis = 2 : i64} : (tensor<5x?x4xf32>) -> tensor<5x?x1xf32>
return
}
// CHECK: ^bb0(%[[ARG1:.*]]: f32, %[[ARG2:.*]]: f32)
// CHECK: %[[MAX:.+]] = arith.maxf %[[ARG1]], %[[ARG2]] : f32
// CHECK: linalg.yield %[[MAX]] : f32
- // CHECK: "tosa.reshape"(%[[GENERIC]]) {new_shape = array<i64: -9223372036854775808, 1>}
+ // CHECK: tensor.expand_shape %[[GENERIC]] {{\[}}[0, 1]] : tensor<?xf32> into tensor<?x1xf32>
%0 = "tosa.reduce_max"(%arg0) {axis = 1 : i64} : (tensor<?x?xf32>) -> tensor<?x1xf32>
return
}
// CHECK: ^bb0(%[[ARG1:.*]]: i32, %[[ARG2:.*]]: i32)
// CHECK: [[RES:%.+]] = arith.addi %[[ARG1]], %[[ARG2]] : i32
// CHECK: linalg.yield [[RES]] : i32
- // CHECK: "tosa.reshape"([[GENERIC]]) {new_shape = array<i64: 1, 4>}
+ // CHECK: tensor.expand_shape [[GENERIC]] {{\[}}[0, 1]] : tensor<4xi32> into tensor<1x4xi32>
%0 = "tosa.reduce_sum"(%arg0) {axis = 0 : i64} : (tensor<5x4xi32>) -> tensor<1x4xi32>
// CHECK: [[INIT:%.+]] = tensor.empty()
// CHECK: ^bb0(%[[ARG1:.*]]: i32, %[[ARG2:.*]]: i32)
// CHECK: [[RES:%.+]] = arith.addi %[[ARG1]], %[[ARG2]] : i32
// CHECK: linalg.yield [[RES]] : i32
- // CHECK: "tosa.reshape"([[GENERIC]]) {new_shape = array<i64: 5, 1>}
+ // CHECK: tensor.expand_shape [[GENERIC]] {{\[}}[0, 1]] : tensor<5xi32> into tensor<5x1xi32>
%1 = "tosa.reduce_sum"(%arg0) {axis = 1 : i64} : (tensor<5x4xi32>) -> tensor<5x1xi32>
// CHECK: arith.constant 1
// CHECK: ^bb0(%[[ARG1:[0-9a-zA-Z_]+]]: i1, %[[ARG2:[0-9a-zA-Z_]+]]: i1)
// CHECK: [[RES:%.+]] = arith.andi %[[ARG1]], %[[ARG2]] : i1
// CHECK: linalg.yield [[RES]] : i1
- // CHECK: "tosa.reshape"([[GENERIC]]) {new_shape = array<i64: 1, 4>}
+ // CHECK: tensor.expand_shape [[GENERIC]] {{\[}}[0, 1]] : tensor<4xi1> into tensor<1x4xi1>
%0 = "tosa.reduce_all"(%arg0) {axis = 0 : i64} : (tensor<5x4xi1>) -> tensor<1x4xi1>
// CHECK: arith.constant false