From 2d4b998697fda9a0a213e5fb29f8af45a4828dc7 Mon Sep 17 00:00:00 2001 From: Mahesh Ravishankar Date: Wed, 16 Nov 2022 07:52:34 +0000 Subject: [PATCH] [mlir][Linalg] Avoid unnecessary propagating producer result to fused op result. Elementwise op fusion conserves the result of the producer in the fused op, relying on later clean up patterns to drop unused results of the fused op. Instead, if the producer result has no other use apart from the consumer op, avoid making the producer result available in the fused node. This saves some unnecessary IR manipulations. Differential Revision: https://reviews.llvm.org/D138096 --- .../Dialect/Linalg/IR/LinalgInterfaces.cpp | 5 ++ .../Linalg/Transforms/ElementwiseOpFusion.cpp | 58 +++++++++++++------ .../Linalg/fusion-elementwise-options.mlir | 2 +- .../Dialect/Linalg/fusion-elementwise.mlir | 30 ++++++++++ .../Linalg/TestLinalgElementwiseFusion.cpp | 15 +++++ 5 files changed, 91 insertions(+), 19 deletions(-) create mode 100644 mlir/test/Dialect/Linalg/fusion-elementwise.mlir diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp index f8126219ccad..b5088717972a 100644 --- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp @@ -36,6 +36,11 @@ bool linalg::detail::canOpOperandsBeDroppedImpl( continue; indexingMaps.push_back(linalgOp.getMatchingIndexingMap(&opOperand)); } + if (indexingMaps.empty()) { + // If there are no indexing maps, the operand can only be dropped + // if the op has no loops. + return linalgOp.getNumLoops() == 0; + } return inversePermutation(concatAffineMaps(indexingMaps)) != AffineMap(); } diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp index bdc7bac99897..11eed11bbabf 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp @@ -143,10 +143,10 @@ bool mlir::linalg::areElementwiseOpsFusable(OpOperand *fusedOperand) { /// Generate the region of the fused tensor operation. The region of the fused /// op must be empty. -static void -generateFusedElementwiseOpRegion(RewriterBase &rewriter, GenericOp fusedOp, - AffineMap consumerToProducerLoopsMap, - OpOperand *fusedOperand, unsigned nloops) { +static void generateFusedElementwiseOpRegion( + RewriterBase &rewriter, GenericOp fusedOp, + AffineMap consumerToProducerLoopsMap, OpOperand *fusedOperand, + unsigned nloops, llvm::SmallDenseSet &preservedProducerResults) { auto producer = cast(fusedOperand->get().getDefiningOp()); auto consumer = cast(fusedOperand->getOwner()); // Build the region of the fused op. @@ -202,9 +202,13 @@ generateFusedElementwiseOpRegion(RewriterBase &rewriter, GenericOp fusedOp, mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType(), bbArg.getLoc())); // 6. All of the producer's output operands - for (BlockArgument bbArg : - producerBlock.getArguments().take_back(producer.getNumDpsInits())) - mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType(), bbArg.getLoc())); + for (auto bbArg : llvm::enumerate( + producerBlock.getArguments().take_back(producer.getNumDpsInits()))) { + if (!preservedProducerResults.count(bbArg.index())) + continue; + mapper.map(bbArg.value(), fusedBlock->addArgument(bbArg.value().getType(), + bbArg.value().getLoc())); + } // 7. All of consumer's output operands. for (BlockArgument bbArg : @@ -247,8 +251,11 @@ generateFusedElementwiseOpRegion(RewriterBase &rewriter, GenericOp fusedOp, SmallVector fusedYieldValues; fusedYieldValues.reserve(producerYieldOp.getNumOperands() + consumerYieldOp.getNumOperands()); - for (auto producerYieldVal : producerYieldOp.getOperands()) - fusedYieldValues.push_back(mapper.lookupOrDefault(producerYieldVal)); + for (auto producerYieldVal : llvm::enumerate(producerYieldOp.getOperands())) { + if (preservedProducerResults.count(producerYieldVal.index())) + fusedYieldValues.push_back( + mapper.lookupOrDefault(producerYieldVal.value())); + } for (auto consumerYieldVal : consumerYieldOp.getOperands()) fusedYieldValues.push_back(mapper.lookupOrDefault(consumerYieldVal)); rewriter.create(fusedOp.getLoc(), fusedYieldValues); @@ -269,6 +276,18 @@ mlir::linalg::fuseElementwiseOps(RewriterBase &rewriter, // TODO: allow fusing the producer of an output operand. assert(consumer.isDpsInput(fusedOperand) && "expected producer of input operand"); + /// Find the results of the producer that have uses outside of the consumer. + llvm::SmallDenseSet preservedProducerResults; + for (auto producerResult : llvm::enumerate(producer->getResults())) { + auto outputOperand = producer.getDpsInitOperand(producerResult.index()); + if (producer.payloadUsesValueFromOperand(outputOperand) || + !producer.canOpOperandsBeDropped(outputOperand) || + llvm::any_of(producerResult.value().getUsers(), [&](Operation *user) { + return user != consumer.getOperation(); + })) { + preservedProducerResults.insert(producerResult.index()); + } + } // Compute the fused operands list and indexing maps. SmallVector fusedInputOperands, fusedOutputOperands; @@ -276,9 +295,9 @@ mlir::linalg::fuseElementwiseOps(RewriterBase &rewriter, SmallVector fusedIndexMaps; fusedInputOperands.reserve(producer.getNumDpsInputs() + consumer.getNumDpsInputs()); - fusedOutputOperands.reserve(producer.getNumDpsInits() + + fusedOutputOperands.reserve(preservedProducerResults.size() + consumer.getNumDpsInits()); - fusedResultTypes.reserve(producer.getNumDpsInits() + + fusedResultTypes.reserve(preservedProducerResults.size() + consumer.getNumDpsInits()); fusedIndexMaps.reserve(producer->getNumOperands() + consumer->getNumOperands()); @@ -313,13 +332,16 @@ mlir::linalg::fuseElementwiseOps(RewriterBase &rewriter, } // 6. Collect all of the producer outputs. - for (OpOperand *opOperand : producer.getDpsInitOperands()) { - fusedOutputOperands.push_back(opOperand->get()); + for (auto opOperand : llvm::enumerate(producer.getDpsInitOperands())) { + if (!preservedProducerResults.count(opOperand.index())) + continue; + + fusedOutputOperands.push_back(opOperand.value()->get()); AffineMap map = getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp( - opOperand, producerResultIndexMap, + opOperand.value(), producerResultIndexMap, consumer.getMatchingIndexingMap(fusedOperand)); fusedIndexMaps.push_back(map); - fusedResultTypes.push_back(opOperand->get().getType()); + fusedResultTypes.push_back(opOperand.value()->get().getType()); } // 7. All of consumer's output operands (skip operands: added by the builder). @@ -358,9 +380,9 @@ mlir::linalg::fuseElementwiseOps(RewriterBase &rewriter, AffineMap consumerToProducerLoopsMap = invProducerResultIndexMap.compose(consumerResultIndexMap); - generateFusedElementwiseOpRegion(rewriter, fusedOp, - consumerToProducerLoopsMap, fusedOperand, - consumer.getNumLoops()); + generateFusedElementwiseOpRegion( + rewriter, fusedOp, consumerToProducerLoopsMap, fusedOperand, + consumer.getNumLoops(), preservedProducerResults); return fusedOp.getOperation(); } diff --git a/mlir/test/Dialect/Linalg/fusion-elementwise-options.mlir b/mlir/test/Dialect/Linalg/fusion-elementwise-options.mlir index 3cfb1b41c131..f11c58bb8ba8 100644 --- a/mlir/test/Dialect/Linalg/fusion-elementwise-options.mlir +++ b/mlir/test/Dialect/Linalg/fusion-elementwise-options.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt %s -test-linalg-elementwise-fusion-patterns=fuse-generic-ops -split-input-file | FileCheck %s +// RUN: mlir-opt %s -test-linalg-elementwise-fusion-patterns=fuse-generic-ops-control -split-input-file | FileCheck %s #map0 = affine_map<(d0, d1) -> (d0, d1)> #binary2Dpointwise = { diff --git a/mlir/test/Dialect/Linalg/fusion-elementwise.mlir b/mlir/test/Dialect/Linalg/fusion-elementwise.mlir new file mode 100644 index 000000000000..8131e4054cc6 --- /dev/null +++ b/mlir/test/Dialect/Linalg/fusion-elementwise.mlir @@ -0,0 +1,30 @@ +// RUN: mlir-opt %s -test-linalg-elementwise-fusion-patterns=fuse-generic-ops-control -split-input-file | FileCheck %s + +#map = affine_map<(d0, d1) -> (d0, d1)> +func.func @drop_unused_producer_result(%arg0 : tensor, + %arg1 : tensor) -> tensor { + %0:2 = linalg.generic { + indexing_maps = [#map, #map, #map], + iterator_types = ["parallel", "parallel"]} + ins(%arg0 : tensor) outs(%arg0, %arg0 : tensor, tensor) { + ^bb0(%b0: f32, %b1: f32, %b2: f32): + %1 = arith.addf %b0, %b0 : f32 + %2 = arith.mulf %b0, %b0 : f32 + linalg.yield %1, %2 : f32, f32 + } -> (tensor, tensor) + %3 = linalg.generic { + indexing_maps = [#map, #map, #map], + iterator_types = ["parallel", "parallel"]} + ins(%0#0, %arg1 : tensor, tensor) outs(%arg0 : tensor) { + ^bb0(%b0: f32, %b1: f32, %b2: f32): + %4 = arith.subf %b0, %b1 : f32 + linalg.yield %4 : f32 + } -> tensor + return %3 : tensor +} +// CHECK-LABEL: func @drop_unused_producer_result +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor +// CHECK: %[[FUSED_OP:.+]] = linalg.generic +// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : +// CHECK: return %[[FUSED_OP]] diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgElementwiseFusion.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgElementwiseFusion.cpp index 4bd6d43f0f6e..e2f61a9611b0 100644 --- a/mlir/test/lib/Dialect/Linalg/TestLinalgElementwiseFusion.cpp +++ b/mlir/test/lib/Dialect/Linalg/TestLinalgElementwiseFusion.cpp @@ -75,6 +75,12 @@ struct TestLinalgElementwiseFusion llvm::cl::desc("Test fusion of generic operations."), llvm::cl::init(false)}; + Option fuseGenericOpsControl{ + *this, "fuse-generic-ops-control", + llvm::cl::desc( + "Test fusion of generic operations with a control function."), + llvm::cl::init(false)}; + Option fuseWithReshapeByExpansion{ *this, "fuse-with-reshape-by-expansion", llvm::cl::desc( @@ -108,6 +114,15 @@ struct TestLinalgElementwiseFusion func::FuncOp funcOp = this->getOperation(); if (fuseGenericOps) { + RewritePatternSet fusionPatterns(context); + auto controlFn = [](OpOperand *operand) { return true; }; + linalg::populateElementwiseOpsFusionPatterns(fusionPatterns, controlFn); + (void)applyPatternsAndFoldGreedily(funcOp.getBody(), + std::move(fusionPatterns)); + return; + } + + if (fuseGenericOpsControl) { RewritePatternSet fusionPatterns(context); linalg::populateElementwiseOpsFusionPatterns(fusionPatterns, setFusedOpOperandLimit<4>); -- 2.34.1