From f12639d0d674327876388cfde6b4d226359284ac Mon Sep 17 00:00:00 2001 From: Mahesh Ravishankar Date: Tue, 9 May 2023 06:20:53 +0000 Subject: [PATCH] [mlir][Linalg] Avoid collapsing dimensions of linalg op that arent foldable. The collapsing dimensions transformation is limited to only those cases where the sequence of dimensions are contiguous in all the ranges of the indexing maps of the operation. Add this check before applying the transformation. Differential Revision: https://reviews.llvm.org/D150176 --- .../mlir/Dialect/Linalg/Transforms/Transforms.h | 17 +++++++++++++++-- .../Linalg/Transforms/ElementwiseOpFusion.cpp | 20 ++++++++++++++++++-- mlir/test/Dialect/Linalg/collapse-dim.mlir | 17 +++++++++++++++++ 3 files changed, 50 insertions(+), 4 deletions(-) diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h index 4a3a86f..b538d0f 100644 --- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h @@ -901,8 +901,21 @@ splitReductionByScaling(RewriterBase &b, LinalgOp op, const ControlSplitReductionFn &controlSplitReductionFn, bool useAlloc = false); -/// Collapses dimensions of linalg.generic operation. It also collapses inputs -/// before the op and expands outputs after the op. +/// Return `true` if a given sequence of dimensions are contiguous in the +/// range of the specified indexing map. +bool isDimSequencePreserved(AffineMap map, ReassociationIndicesRef dimSequence); +/// Return `true` if all sequences of dimensions specified in `dimSequences` are +/// contiguous in all the ranges of the `maps`. +bool areDimSequencesPreserved(ArrayRef maps, + ArrayRef dimSequences); + +/// Collapses dimensions of linalg.generic operation. A precondition to +/// calling this method is that for each list in `foldedIterationDim`, the +/// sequence of dimensions is contiguous in domains of all `indexing_maps` of +/// the `genericOp`. This can be checked using `areDimSequencePreserved` method. +/// When valid, the method also collapses the operands of the op. Returns +/// replacement values of the results of the original `genericOp` by inserting +/// reshapes to get back values of compatible types. FailureOr> collapseGenericOpIterationDims( GenericOp genericOp, ArrayRef foldedIterationDims, RewriterBase &rewriter); diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp index 57e6e2a..bf728a6 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp @@ -1004,8 +1004,8 @@ getDomainReassociation(AffineMap indexingMap, /// For a given `dimSequence`, check if the sequence is conserved in the /// `indexingMap`. `indexingMap` is expected to be a projected permutation. /// Non-existence of the sequence returns true as well. -static bool isDimSequencePreserved(AffineMap indexingMap, - ReassociationIndicesRef dimSequence) { +bool mlir::linalg::isDimSequencePreserved(AffineMap indexingMap, + ReassociationIndicesRef dimSequence) { assert(!dimSequence.empty() && "expected non-empty list for dimension sequence"); assert(indexingMap.isProjectedPermutation() && @@ -1045,6 +1045,15 @@ static bool isDimSequencePreserved(AffineMap indexingMap, return true; } +bool mlir::linalg::areDimSequencesPreserved( + ArrayRef maps, ArrayRef dimSequences) { + return llvm::all_of(maps, [&](AffineMap map) { + return llvm::all_of(dimSequences, [&](ReassociationIndicesRef dimSequence) { + return isDimSequencePreserved(map, dimSequence); + }); + }); +} + // Return the list of dimensions of the iteration domain that can be // collapsed to allow for fusion with the a producer that is an expand_shape // operation. If all dimensions created by expansion can be collapsed in the @@ -1592,6 +1601,13 @@ public: if (collapsableIterationDims.empty()) return failure(); + // Check if the specified list of dimensions to collapse is a valid list. + if (!areDimSequencesPreserved(genericOp.getIndexingMapsArray(), + collapsableIterationDims)) { + return rewriter.notifyMatchFailure( + genericOp, "specified dimensions cannot be collapsed"); + } + std::optional> replacements = collapseGenericOpIterationDims(genericOp, collapsableIterationDims, rewriter); diff --git a/mlir/test/Dialect/Linalg/collapse-dim.mlir b/mlir/test/Dialect/Linalg/collapse-dim.mlir index 0bd2bc1..6737a6e 100644 --- a/mlir/test/Dialect/Linalg/collapse-dim.mlir +++ b/mlir/test/Dialect/Linalg/collapse-dim.mlir @@ -53,3 +53,20 @@ func.func @collapse_parallel( // CHECK-SAME: ins(%[[S]] : tensor<32x2x40960xf32>) outs(%[[D]] : tensor<2x32x40960xf32>) { // CHECK: } -> tensor<2x32x40960xf32> // CHECK: tensor.expand_shape %[[R]] {{\[}}[0], [1], [2, 3]] : tensor<2x32x40960xf32> into tensor<2x32x10x4096xf32> + +// ----- + +#map = affine_map<(d0, d1, d2, d3) -> (d3, d0, d1, d2)> +#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> +func.func @uncollapsable(%arg0 : tensor<41x3x1x57xf32>, %arg1 : tensor<3x1x57x41xf32>) -> tensor<3x1x57x41xf32> { + %0 = linalg.generic { + indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} + ins(%arg0 : tensor<41x3x1x57xf32>) outs(%arg1 : tensor<3x1x57x41xf32>) { + ^bb0(%in: f32, %out: f32): + linalg.yield %in : f32 + } -> tensor<3x1x57x41xf32> + return %0 : tensor<3x1x57x41xf32> +} +// CHECK-LABEL: func @uncollapsable( +// CHECK: linalg.generic +// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel"] -- 2.7.4