[mlir][linalg] Expose pattern to collapse generic op dimensions
authorThomas Raoux <thomasraoux@google.com>
Fri, 7 Oct 2022 23:59:06 +0000 (23:59 +0000)
committerThomas Raoux <thomasraoux@google.com>
Mon, 10 Oct 2022 16:44:01 +0000 (16:44 +0000)
Add a pattern to be able to collapse dimensions in a linalg generic op.

Differential Revision: https://reviews.llvm.org/D135503

mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
mlir/test/Dialect/Linalg/collapse-dim.mlir [new file with mode: 0644]
mlir/test/lib/Dialect/Linalg/TestLinalgElementwiseFusion.cpp

index 62dcc8e..fb37c6f 100644 (file)
@@ -76,6 +76,18 @@ void populateElementwiseOpsFusionPatterns(
     RewritePatternSet &patterns,
     const ControlFusionFn &controlElementwiseOpFusion);
 
+/// Function type to control generic op dimension collapsing. It is expected
+/// to return an array of `ReassociationIndices` representing dimensions that
+/// should be merged.
+using GetCollapsableDimensionsFn =
+    std::function<SmallVector<ReassociationIndices>(linalg::GenericOp)>;
+
+/// Pattern to collapse dimensions in a linalg.generic op. This will collapse
+/// tensor operands when needed and expand back the result tensors.
+void populateCollapseDimensions(
+    RewritePatternSet &patterns,
+    const GetCollapsableDimensionsFn &controlCollapseDimensions);
+
 /// Patterns to fold an expanding (collapsing) tensor_reshape operation with its
 /// producer (consumer) generic operation by expanding the dimensionality of the
 /// loop in the generic op.
index 45bc4a8..05dce4c 100644 (file)
@@ -1367,7 +1367,7 @@ void generateCollapsedIndexingRegion(Location loc, Block *block,
 /// Implementation of fusion with reshape operation by collapsing dimensions.
 static FailureOr<SmallVector<Value>> collapseGenericOpIterationDims(
     GenericOp genericOp, ArrayRef<ReassociationIndices> foldedIterationDims,
-    OpOperand *fusableOpOperand, PatternRewriter &rewriter) {
+    PatternRewriter &rewriter) {
   // Bail on trivial no-op cases.
   if (genericOp.getNumLoops() <= 1 || foldedIterationDims.empty() ||
       llvm::all_of(foldedIterationDims, [](ReassociationIndicesRef foldedDims) {
@@ -1510,7 +1510,7 @@ public:
 
       Optional<SmallVector<Value>> replacements =
           collapseGenericOpIterationDims(genericOp, collapsableIterationDims,
-                                         opOperand, rewriter);
+                                         rewriter);
       if (!replacements) {
         return rewriter.notifyMatchFailure(
             genericOp, "failed to do the fusion by collapsing transformation");
@@ -1525,6 +1525,37 @@ public:
 private:
   ControlFusionFn controlFoldingReshapes;
 };
+
+/// Pattern to collapse dimensions.
+class CollapseLinalgDimensions : public OpRewritePattern<GenericOp> {
+public:
+  CollapseLinalgDimensions(MLIRContext *context,
+                           GetCollapsableDimensionsFn collapseDimensions,
+                           PatternBenefit benefit = 1)
+      : OpRewritePattern<GenericOp>(context, benefit),
+        controlCollapseDimension(std::move(collapseDimensions)) {}
+
+  LogicalResult matchAndRewrite(GenericOp genericOp,
+                                PatternRewriter &rewriter) const override {
+    SmallVector<ReassociationIndices> collapsableIterationDims =
+        controlCollapseDimension(genericOp);
+    if (collapsableIterationDims.empty())
+      return failure();
+
+    Optional<SmallVector<Value>> replacements = collapseGenericOpIterationDims(
+        genericOp, collapsableIterationDims, rewriter);
+    if (!replacements) {
+      return rewriter.notifyMatchFailure(genericOp,
+                                         "failed to collpase dimensions");
+    }
+    rewriter.replaceOp(genericOp, *replacements);
+    return success();
+  }
+
+private:
+  GetCollapsableDimensionsFn controlCollapseDimension;
+};
+
 } // namespace
 
 //===---------------------------------------------------------------------===//
@@ -1743,6 +1774,13 @@ void mlir::linalg::populateElementwiseOpsFusionPatterns(
                RemoveOutsDependency>(context);
 }
 
+void mlir::linalg::populateCollapseDimensions(
+    RewritePatternSet &patterns,
+    const GetCollapsableDimensionsFn &controlCollapseDimensions) {
+  patterns.add<CollapseLinalgDimensions>(patterns.getContext(),
+                                         controlCollapseDimensions);
+}
+
 //===---------------------------------------------------------------------===//
 // Passes
 //===---------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Linalg/collapse-dim.mlir b/mlir/test/Dialect/Linalg/collapse-dim.mlir
new file mode 100644 (file)
index 0000000..3587557
--- /dev/null
@@ -0,0 +1,55 @@
+// RUN: mlir-opt %s -test-linalg-elementwise-fusion-patterns=collapse-dimensions-control=2,3 -split-input-file | FileCheck %s
+
+func.func @collapse_reduction(
+    %arg0: tensor<2x32x10x4096xf32>, %arg1: tensor<2x32xf32>) -> tensor<2x32xf32> {
+  %0 = linalg.generic {
+    indexing_maps = [
+        affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
+        affine_map<(d0, d1, d2, d3) -> (d0, d1)>],
+  iterator_types = ["parallel", "parallel", "reduction", "reduction"]}
+  ins(%arg0 : tensor<2x32x10x4096xf32>) outs(%arg1 : tensor<2x32xf32>) {
+  ^bb0(%arg3: f32, %arg4: f32):
+    %1 = arith.addf %arg3, %arg4 : f32
+    linalg.yield %1 : f32
+  } -> tensor<2x32xf32>
+  return %0 : tensor<2x32xf32>
+}
+
+// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+
+// CHECK-LABEL: func @collapse_reduction
+//       CHECK:   %[[T:.*]] = tensor.collapse_shape %{{.*}} {{\[}}[0], [1], [2, 3]] : tensor<2x32x10x4096xf32> into tensor<2x32x40960xf32>
+//       CHECK:   linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], 
+//  CHECK-SAME:     iterator_types = ["parallel", "parallel", "reduction"]} 
+//  CHECK-SAME:     ins(%[[T]] : tensor<2x32x40960xf32>) outs(%{{.*}} : tensor<2x32xf32>) {
+//       CHECK:   } -> tensor<2x32xf32>
+
+// -----
+
+func.func @collapse_parallel(
+    %arg0: tensor<32x2x10x4096xf32>, %arg1: tensor<2x32x10x4096xf32>) -> tensor<2x32x10x4096xf32> {
+  %0 = linalg.generic {
+    indexing_maps = [
+        affine_map<(d0, d1, d2, d3) -> (d1, d0, d2, d3)>,
+        affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>],
+  iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
+  ins(%arg0 : tensor<32x2x10x4096xf32>) outs(%arg1 : tensor<2x32x10x4096xf32>) {
+  ^bb0(%arg3: f32, %arg4: f32):
+    %1 = arith.addf %arg3, %arg4 : f32
+    linalg.yield %1 : f32
+  } -> tensor<2x32x10x4096xf32>
+  return %0 : tensor<2x32x10x4096xf32>
+}
+
+// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1, d2) -> (d1, d0, d2)>
+// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+
+// CHECK-LABEL: func @collapse_parallel
+//   CHECK-DAG:  %[[S:.*]] = tensor.collapse_shape %{{.*}} {{\[}}[0], [1], [2, 3]] : tensor<32x2x10x4096xf32> into tensor<32x2x40960xf32>
+//   CHECK-DAG:  %[[D:.*]] = tensor.collapse_shape %{{.*}} {{\[}}[0], [1], [2, 3]] : tensor<2x32x10x4096xf32> into tensor<2x32x40960xf32>
+//       CHECK:  %[[R:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], 
+//  CHECK-SAME:     iterator_types = ["parallel", "parallel", "parallel"]} 
+//  CHECK-SAME:     ins(%[[S]] : tensor<32x2x40960xf32>) outs(%[[D]] : tensor<2x32x40960xf32>) {
+//       CHECK:   } -> tensor<2x32x40960xf32>
+//       CHECK:  tensor.expand_shape %[[R]] {{\[}}[0], [1], [2, 3]] : tensor<2x32x40960xf32> into tensor<2x32x10x4096xf32>
index 41e46d0..0119516 100644 (file)
@@ -99,6 +99,9 @@ struct TestLinalgElementwiseFusion
                      "fusion patterns that "
                      "collapse the iteration space of the consumer"),
       llvm::cl::init(false)};
+  ListOption<int64_t> collapseDimensions{
+      *this, "collapse-dimensions-control",
+      llvm::cl::desc("Test controlling dimension collapse pattern")};
 
   void runOnOperation() override {
     MLIRContext *context = &this->getContext();
@@ -179,6 +182,20 @@ struct TestLinalgElementwiseFusion
       linalg::populateFoldReshapeOpsByCollapsingPatterns(patterns, controlFn);
       (void)applyPatternsAndFoldGreedily(funcOp.getBody(), std::move(patterns));
     }
+
+    if (!collapseDimensions.empty()) {
+      SmallVector<int64_t, 2> dims(collapseDimensions.begin(),
+                                   collapseDimensions.end());
+      linalg::GetCollapsableDimensionsFn collapseFn =
+          [&dims](linalg::GenericOp op) {
+            SmallVector<ReassociationIndices> reassociations;
+            reassociations.emplace_back(dims);
+            return reassociations;
+          };
+      RewritePatternSet patterns(context);
+      linalg::populateCollapseDimensions(patterns, collapseFn);
+      (void)applyPatternsAndFoldGreedily(funcOp.getBody(), std::move(patterns));
+    }
   }
 };