Add a pattern that collapses dimensions in a `linalg.generic` op.
Differential Revision: https://reviews.llvm.org/D135503
RewritePatternSet &patterns,
const ControlFusionFn &controlElementwiseOpFusion);
+/// Callback controlling which loop dimensions of a `linalg.generic` op are
+/// collapsed. It returns an array of `ReassociationIndices`, each entry
+/// listing a group of iteration dimensions to be merged into a single one.
+/// Returning an empty vector leaves the op unchanged.
+using GetCollapsableDimensionsFn =
+ std::function<SmallVector<ReassociationIndices>(linalg::GenericOp)>;
+
+/// Pattern to collapse dimensions in a linalg.generic op. This will collapse
+/// tensor operands when needed, and expand the result tensors back to their
+/// original shape afterwards.
+void populateCollapseDimensions(
+ RewritePatternSet &patterns,
+ const GetCollapsableDimensionsFn &controlCollapseDimensions);
+
/// Patterns to fold an expanding (collapsing) tensor_reshape operation with its
/// producer (consumer) generic operation by expanding the dimensionality of the
/// loop in the generic op.
/// Implementation of fusion with reshape operation by collapsing dimensions.
static FailureOr<SmallVector<Value>> collapseGenericOpIterationDims(
GenericOp genericOp, ArrayRef<ReassociationIndices> foldedIterationDims,
- OpOperand *fusableOpOperand, PatternRewriter &rewriter) {
+ PatternRewriter &rewriter) {
// Bail on trivial no-op cases.
if (genericOp.getNumLoops() <= 1 || foldedIterationDims.empty() ||
llvm::all_of(foldedIterationDims, [](ReassociationIndicesRef foldedDims) {
Optional<SmallVector<Value>> replacements =
collapseGenericOpIterationDims(genericOp, collapsableIterationDims,
- opOperand, rewriter);
+ rewriter);
if (!replacements) {
return rewriter.notifyMatchFailure(
genericOp, "failed to do the fusion by collapsing transformation");
private:
ControlFusionFn controlFoldingReshapes;
};
+
+/// Pattern to collapse dimensions.
+/// Pattern to collapse iteration dimensions of a linalg.generic op, as
+/// directed by a user-provided control function.
+class CollapseLinalgDimensions : public OpRewritePattern<GenericOp> {
+public:
+  CollapseLinalgDimensions(MLIRContext *context,
+                           GetCollapsableDimensionsFn collapseDimensions,
+                           PatternBenefit benefit = 1)
+      : OpRewritePattern<GenericOp>(context, benefit),
+        controlCollapseDimension(std::move(collapseDimensions)) {}
+
+  LogicalResult matchAndRewrite(GenericOp genericOp,
+                                PatternRewriter &rewriter) const override {
+    // Ask the control function which groups of iteration dimensions to
+    // merge; an empty answer means this op is not a candidate.
+    SmallVector<ReassociationIndices> collapsableIterationDims =
+        controlCollapseDimension(genericOp);
+    if (collapsableIterationDims.empty())
+      return failure();
+
+    // Match the callee's declared FailureOr<SmallVector<Value>> return type
+    // instead of Optional, and check it with failed().
+    FailureOr<SmallVector<Value>> replacements = collapseGenericOpIterationDims(
+        genericOp, collapsableIterationDims, rewriter);
+    if (failed(replacements)) {
+      return rewriter.notifyMatchFailure(genericOp,
+                                         "failed to collapse dimensions");
+    }
+    rewriter.replaceOp(genericOp, *replacements);
+    return success();
+  }
+
+private:
+  GetCollapsableDimensionsFn controlCollapseDimension;
+};
+
} // namespace
//===---------------------------------------------------------------------===//
RemoveOutsDependency>(context);
}
+void mlir::linalg::populateCollapseDimensions(
+ RewritePatternSet &patterns,
+ const GetCollapsableDimensionsFn &controlCollapseDimensions) {
+ // Register CollapseLinalgDimensions; the control callback decides per-op
+ // which iteration dimensions (if any) to merge.
+ patterns.add<CollapseLinalgDimensions>(patterns.getContext(),
+ controlCollapseDimensions);
+}
+
//===---------------------------------------------------------------------===//
// Passes
//===---------------------------------------------------------------------===//
--- /dev/null
+// RUN: mlir-opt %s -test-linalg-elementwise-fusion-patterns=collapse-dimensions-control=2,3 -split-input-file | FileCheck %s
+
+// Sum-reduces a 2x32x10x4096 input into 2x32. With the control set to
+// dimensions 2,3 (see the RUN line), the two reduction loops are expected to
+// be merged into a single reduction over 40960 elements.
+func.func @collapse_reduction(
+ %arg0: tensor<2x32x10x4096xf32>, %arg1: tensor<2x32xf32>) -> tensor<2x32xf32> {
+ %0 = linalg.generic {
+ indexing_maps = [
+ affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
+ affine_map<(d0, d1, d2, d3) -> (d0, d1)>],
+ iterator_types = ["parallel", "parallel", "reduction", "reduction"]}
+ ins(%arg0 : tensor<2x32x10x4096xf32>) outs(%arg1 : tensor<2x32xf32>) {
+ ^bb0(%arg3: f32, %arg4: f32):
+ %1 = arith.addf %arg3, %arg4 : f32
+ linalg.yield %1 : f32
+ } -> tensor<2x32xf32>
+ return %0 : tensor<2x32xf32>
+}
+
+// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+
+// CHECK-LABEL: func @collapse_reduction
+// CHECK: %[[T:.*]] = tensor.collapse_shape %{{.*}} {{\[}}[0], [1], [2, 3]] : tensor<2x32x10x4096xf32> into tensor<2x32x40960xf32>
+// CHECK: linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]],
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"]}
+// CHECK-SAME: ins(%[[T]] : tensor<2x32x40960xf32>) outs(%{{.*}} : tensor<2x32xf32>) {
+// CHECK: } -> tensor<2x32xf32>
+
+// -----
+
+// Elementwise add with a transposed input map on all-parallel loops. With the
+// control set to dimensions 2,3, both operands are expected to be collapsed
+// (10x4096 -> 40960) and the result expanded back to the original 4-D shape.
+func.func @collapse_parallel(
+ %arg0: tensor<32x2x10x4096xf32>, %arg1: tensor<2x32x10x4096xf32>) -> tensor<2x32x10x4096xf32> {
+ %0 = linalg.generic {
+ indexing_maps = [
+ affine_map<(d0, d1, d2, d3) -> (d1, d0, d2, d3)>,
+ affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>],
+ iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
+ ins(%arg0 : tensor<32x2x10x4096xf32>) outs(%arg1 : tensor<2x32x10x4096xf32>) {
+ ^bb0(%arg3: f32, %arg4: f32):
+ %1 = arith.addf %arg3, %arg4 : f32
+ linalg.yield %1 : f32
+ } -> tensor<2x32x10x4096xf32>
+ return %0 : tensor<2x32x10x4096xf32>
+}
+
+// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1, d2) -> (d1, d0, d2)>
+// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+
+// CHECK-LABEL: func @collapse_parallel
+// CHECK-DAG: %[[S:.*]] = tensor.collapse_shape %{{.*}} {{\[}}[0], [1], [2, 3]] : tensor<32x2x10x4096xf32> into tensor<32x2x40960xf32>
+// CHECK-DAG: %[[D:.*]] = tensor.collapse_shape %{{.*}} {{\[}}[0], [1], [2, 3]] : tensor<2x32x10x4096xf32> into tensor<2x32x40960xf32>
+// CHECK: %[[R:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]],
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"]}
+// CHECK-SAME: ins(%[[S]] : tensor<32x2x40960xf32>) outs(%[[D]] : tensor<2x32x40960xf32>) {
+// CHECK: } -> tensor<2x32x40960xf32>
+// CHECK: tensor.expand_shape %[[R]] {{\[}}[0], [1], [2, 3]] : tensor<2x32x40960xf32> into tensor<2x32x10x4096xf32>
"fusion patterns that "
"collapse the iteration space of the consumer"),
llvm::cl::init(false)};
+ ListOption<int64_t> collapseDimensions{
+ *this, "collapse-dimensions-control",
+ llvm::cl::desc("Test controlling dimension collapse pattern")};
void runOnOperation() override {
MLIRContext *context = &this->getContext();
linalg::populateFoldReshapeOpsByCollapsingPatterns(patterns, controlFn);
(void)applyPatternsAndFoldGreedily(funcOp.getBody(), std::move(patterns));
}
+
+ if (!collapseDimensions.empty()) {
+ SmallVector<int64_t, 2> dims(collapseDimensions.begin(),
+ collapseDimensions.end());
+ linalg::GetCollapsableDimensionsFn collapseFn =
+ [&dims](linalg::GenericOp op) {
+ SmallVector<ReassociationIndices> reassociations;
+ reassociations.emplace_back(dims);
+ return reassociations;
+ };
+ RewritePatternSet patterns(context);
+ linalg::populateCollapseDimensions(patterns, collapseFn);
+ (void)applyPatternsAndFoldGreedily(funcOp.getBody(), std::move(patterns));
+ }
}
};