From: thomasraoux
Date: Fri, 17 Sep 2021 17:09:57 +0000 (-0700)
Subject: [mlir][linalg] Extend drop unit dim pattern to all cases of reduction
X-Git-Tag: upstream/15.0.7~31218
X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=36aac53b36983c5ae8b7dcb0519c34e8c41dc4e5;p=platform%2Fupstream%2Fllvm.git

[mlir][linalg] Extend drop unit dim pattern to all cases of reduction

Even with all parallel loops, reading the output value is still allowed,
so we don't have to handle reduction loops differently.

Differential Revision: https://reviews.llvm.org/D109851
---

diff --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
index e23a58e..8315de4 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
@@ -187,40 +187,13 @@ struct FoldUnitDimLoops : public OpRewritePattern<GenericOp> {
       return failure();
     SmallVector<int64_t> dims = genericOp.getStaticShape();
 
-    // Find all the reduction iterators. Those need some special consideration
-    // (see below).
-    auto getLoopDimsOfType =
-        [&](StringRef iteratorTypeName) -> SmallVector<unsigned, 4> {
-      SmallVector<AffineExpr> dimExprs;
-      getDimsOfType(genericOp, iteratorTypeName, dimExprs);
-      return llvm::to_vector<4>(llvm::map_range(dimExprs, [](AffineExpr expr) {
-        return expr.cast<AffineDimExpr>().getPosition();
-      }));
-    };
-    auto reductionDims = getLoopDimsOfType(getReductionIteratorTypeName());
-
     DenseSet<unsigned> unitDims;
     SmallVector<unsigned> unitDimsReductionLoops;
     ArrayAttr iteratorTypes = genericOp.iterator_types();
     for (auto expr : enumerate(invertedMap.getResults())) {
       if (AffineDimExpr dimExpr = expr.value().dyn_cast<AffineDimExpr>())
-        if (dims[dimExpr.getPosition()] == 1) {
-          if (isParallelIterator(iteratorTypes[expr.index()]))
-            unitDims.insert(expr.index());
-          else if (isReductionIterator(iteratorTypes[expr.index()]))
-            unitDimsReductionLoops.push_back(expr.index());
-        }
-    }
-
-    // Reduction loops can be dropped if there is at least one other reduction
-    // loop that is not dropped. This accounts for the initial value read in the
-    // reduction loop.
-    if (!unitDimsReductionLoops.empty() && reductionDims.size() > 1) {
-      if (unitDimsReductionLoops.size() == reductionDims.size())
-        unitDims.insert(reductionDims.begin(), std::prev(reductionDims.end()));
-      else
-        unitDims.insert(unitDimsReductionLoops.begin(),
-                        unitDimsReductionLoops.end());
+        if (dims[dimExpr.getPosition()] == 1)
+          unitDims.insert(expr.index());
     }
 
     if (unitDims.empty())
diff --git a/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
index 0271638..60ad723 100644
--- a/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
+++ b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
@@ -361,7 +361,7 @@ func @unit_dim_for_reduction(%arg0: tensor<1x?x1x?xf32>) -> tensor<1x?xf32> {
 
 // -----
 
-func @unit_dim_for_reduction_keep_one(%arg0: tensor<1x?x1x1xf32>) -> tensor<1x1xf32> {
+func @unit_dim_for_both_reduction(%arg0: tensor<1x?x1x1xf32>) -> tensor<1x1xf32> {
   %cst = constant 1.000000e+00 : f32
   %c3 = constant 3 : index
   %1 = linalg.init_tensor [1, 1] : tensor<1x1xf32>
@@ -378,17 +378,16 @@ func @unit_dim_for_reduction_keep_one(%arg0: tensor<1x?x1x1xf32>) -> tensor<1x1x
   } -> tensor<1x1xf32>
   return %3 : tensor<1x1xf32>
 }
-// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1) -> (d0, d1)>
-// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1) -> (d0)>
-// CHECK: func @unit_dim_for_reduction_keep_one
+// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0) -> (d0)>
+// CHECK: func @unit_dim_for_both_reduction
 // CHECK-SAME: %[[ARG0:.+]]: tensor<1x?x1x1xf32>
-// CHECK-DAG: %[[RESHAPE:.+]] = linalg.tensor_collapse_shape %[[ARG0]] {{\[}}[0, 1, 2], [3]]
+// CHECK-DAG: %[[RESHAPE:.+]] = linalg.tensor_collapse_shape %[[ARG0]] {{\[}}[0, 1, 2, 3]
 // CHECK: %[[INIT:.+]] = linalg.init_tensor [1] : tensor<1xf32>
 // CHECK: %[[FILL:.+]] = linalg.fill(%{{.+}}, %[[INIT]])
 // CHECK: %[[RESULT:.+]] = linalg.generic
-// CHECK-SAME: indexing_maps = [#[[MAP2]], #[[MAP3]]]
-// CHECK-SAME: iterator_types = ["parallel", "reduction"]
-// CHECK-SAME: ins(%[[RESHAPE]] : tensor<?x1xf32>)
+// CHECK-SAME: indexing_maps = [#[[MAP2]], #[[MAP2]]]
+// CHECK-SAME: iterator_types = ["parallel"]
+// CHECK-SAME: ins(%[[RESHAPE]] : tensor<?xf32>)
 // CHECK-SAME: outs(%[[FILL]] : tensor<1xf32>)
 // CHECK: %[[RESULT_RESHAPE:.+]] = linalg.tensor_expand_shape %[[RESULT]] {{\[}}[0, 1]]
 // CHECK: return %[[RESULT_RESHAPE]]
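
For illustration only (not part of the commit; function names, shapes, and maps
below are made up): the semantics the change relies on is that a linalg.generic
body may read its output operand even when every iterator is "parallel", since
the outs tensor supplies the incoming value. A unit-extent reduction loop can
therefore be dropped exactly like a unit-extent parallel loop:

  #map_in  = affine_map<(d0, d1) -> (d0, d1)>
  #map_out = affine_map<(d0, d1) -> (d0)>
  // Before folding: d1 is a reduction loop of extent 1.
  func @sum_unit_reduction(%arg0: tensor<4x1xf32>, %init: tensor<4xf32>) -> tensor<4xf32> {
    %0 = linalg.generic {indexing_maps = [#map_in, #map_out],
                         iterator_types = ["parallel", "reduction"]}
        ins(%arg0 : tensor<4x1xf32>) outs(%init : tensor<4xf32>) {
    ^bb0(%in: f32, %acc: f32):
      %1 = addf %in, %acc : f32
      linalg.yield %1 : f32
    } -> tensor<4xf32>
    return %0 : tensor<4xf32>
  }

  #map_id = affine_map<(d0) -> (d0)>
  // After folding the unit dim (with the operand collapsed by the companion
  // reshape pattern): only a parallel loop is left, yet the body still reads
  // %acc, the value coming in through outs, so no reduction special case is
  // needed.
  func @sum_unit_reduction_folded(%collapsed: tensor<4xf32>, %init: tensor<4xf32>) -> tensor<4xf32> {
    %0 = linalg.generic {indexing_maps = [#map_id, #map_id],
                         iterator_types = ["parallel"]}
        ins(%collapsed : tensor<4xf32>) outs(%init : tensor<4xf32>) {
    ^bb0(%in: f32, %acc: f32):
      %1 = addf %in, %acc : f32
      linalg.yield %1 : f32
    } -> tensor<4xf32>
    return %0 : tensor<4xf32>
  }

This is the same shape of result the updated test checks for: after all unit
dims are dropped, only the dynamic parallel loop survives, and the fill value
still flows in through the outs operand.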