[mlir][linalg] Extend drop unit dim pattern to all cases of reduction
author thomasraoux <thomasraoux@google.com>
Fri, 17 Sep 2021 17:09:57 +0000 (10:09 -0700)
committer thomasraoux <thomasraoux@google.com>
Fri, 17 Sep 2021 17:09:57 +0000 (10:09 -0700)
Even with all-parallel loops, reading the output value is still allowed, so we
don't have to handle reduction loops differently.
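
For illustration, a minimal sketch of the effect, mirroring the updated test
below (operand names and shapes here are illustrative, not taken from the
patch): a unit-extent reduction dimension such as d1 can now be dropped
outright, leaving an all-parallel op that still reads its output operand in
the body.

    // Before: d1 is a reduction dimension of extent 1.
    %r = linalg.generic {
        indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
                         affine_map<(d0, d1) -> (d0)>],
        iterator_types = ["parallel", "reduction"]}
        ins(%in : tensor<?x1xf32>) outs(%acc : tensor<?xf32>) {
      ^bb0(%a: f32, %b: f32):
        %s = addf %a, %b : f32   // %b reads the current output value.
        linalg.yield %s : f32
      } -> tensor<?xf32>

    // After: the unit reduction loop is gone; reading %b remains legal
    // even though every remaining iterator is "parallel".
    %c = linalg.tensor_collapse_shape %in [[0, 1]]
        : tensor<?x1xf32> into tensor<?xf32>
    %r = linalg.generic {
        indexing_maps = [affine_map<(d0) -> (d0)>,
                         affine_map<(d0) -> (d0)>],
        iterator_types = ["parallel"]}
        ins(%c : tensor<?xf32>) outs(%acc : tensor<?xf32>) {
      ^bb0(%a: f32, %b: f32):
        %s = addf %a, %b : f32
        linalg.yield %s : f32
      } -> tensor<?xf32>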

Differential Revision: https://reviews.llvm.org/D109851

mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir

index e23a58e..8315de4 100644
@@ -187,40 +187,13 @@ struct FoldUnitDimLoops : public OpRewritePattern<GenericOp> {
       return failure();
     SmallVector<int64_t> dims = genericOp.getStaticShape();
 
-    // Find all the reduction iterators. Those need some special consideration
-    // (see below).
-    auto getLoopDimsOfType =
-        [&](StringRef iteratorTypeName) -> SmallVector<unsigned, 4> {
-      SmallVector<AffineExpr> dimExprs;
-      getDimsOfType(genericOp, iteratorTypeName, dimExprs);
-      return llvm::to_vector<4>(llvm::map_range(dimExprs, [](AffineExpr expr) {
-        return expr.cast<AffineDimExpr>().getPosition();
-      }));
-    };
-    auto reductionDims = getLoopDimsOfType(getReductionIteratorTypeName());
-
     DenseSet<unsigned> unitDims;
     SmallVector<unsigned, 4> unitDimsReductionLoops;
     ArrayAttr iteratorTypes = genericOp.iterator_types();
     for (auto expr : enumerate(invertedMap.getResults())) {
       if (AffineDimExpr dimExpr = expr.value().dyn_cast<AffineDimExpr>())
-        if (dims[dimExpr.getPosition()] == 1) {
-          if (isParallelIterator(iteratorTypes[expr.index()]))
-            unitDims.insert(expr.index());
-          else if (isReductionIterator(iteratorTypes[expr.index()]))
-            unitDimsReductionLoops.push_back(expr.index());
-        }
-    }
-
-    // Reduction loops can be dropped if there is at least one other reduction
-    // loop that is not dropped. This accounts for the initial value read in the
-    // reduction loop.
-    if (!unitDimsReductionLoops.empty() && reductionDims.size() > 1) {
-      if (unitDimsReductionLoops.size() == reductionDims.size())
-        unitDims.insert(reductionDims.begin(), std::prev(reductionDims.end()));
-      else
-        unitDims.insert(unitDimsReductionLoops.begin(),
-                        unitDimsReductionLoops.end());
+        if (dims[dimExpr.getPosition()] == 1)
+          unitDims.insert(expr.index());
     }
 
     if (unitDims.empty())
index 0271638..60ad723 100644
@@ -361,7 +361,7 @@ func @unit_dim_for_reduction(%arg0: tensor<1x?x1x?xf32>) -> tensor<1x?xf32> {
 
 // -----
 
-func @unit_dim_for_reduction_keep_one(%arg0: tensor<1x?x1x1xf32>) -> tensor<1x1xf32> {
+func @unit_dim_for_both_reduction(%arg0: tensor<1x?x1x1xf32>) -> tensor<1x1xf32> {
   %cst = constant 1.000000e+00 : f32
   %c3 = constant 3 : index
   %1 = linalg.init_tensor [1, 1] : tensor<1x1xf32>
@@ -378,17 +378,16 @@ func @unit_dim_for_reduction_keep_one(%arg0: tensor<1x?x1x1xf32>) -> tensor<1x1x
   } -> tensor<1x1xf32>
   return %3 : tensor<1x1xf32>
 }
-//  CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1) -> (d0, d1)>
-//  CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1) -> (d0)>
-//      CHECK: func @unit_dim_for_reduction_keep_one
+//  CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0) -> (d0)>
+//      CHECK: func @unit_dim_for_both_reduction
 // CHECK-SAME:   %[[ARG0:.+]]: tensor<1x?x1x1xf32>
-//  CHECK-DAG:   %[[RESHAPE:.+]] = linalg.tensor_collapse_shape %[[ARG0]] {{\[}}[0, 1, 2], [3]]
+//  CHECK-DAG:   %[[RESHAPE:.+]] = linalg.tensor_collapse_shape %[[ARG0]] {{\[}}[0, 1, 2, 3]
 //      CHECK:   %[[INIT:.+]] = linalg.init_tensor [1] : tensor<1xf32>
 //      CHECK:   %[[FILL:.+]] = linalg.fill(%{{.+}}, %[[INIT]])
 //      CHECK:   %[[RESULT:.+]] = linalg.generic
-// CHECK-SAME:     indexing_maps = [#[[MAP2]], #[[MAP3]]]
-// CHECK-SAME:     iterator_types = ["parallel", "reduction"]
-// CHECK-SAME:     ins(%[[RESHAPE]] : tensor<?x1xf32>)
+// CHECK-SAME:     indexing_maps = [#[[MAP2]], #[[MAP2]]]
+// CHECK-SAME:     iterator_types = ["parallel"]
+// CHECK-SAME:     ins(%[[RESHAPE]] : tensor<?xf32>)
 // CHECK-SAME:     outs(%[[FILL]] : tensor<1xf32>)
 //      CHECK:   %[[RESULT_RESHAPE:.+]] = linalg.tensor_expand_shape %[[RESULT]] {{\[}}[0, 1]]
 //      CHECK:   return %[[RESULT_RESHAPE]]