/// Generate the region of the fused tensor operation. The region of the fused
/// op must be empty.
-static void
-generateFusedElementwiseOpRegion(RewriterBase &rewriter, GenericOp fusedOp,
- AffineMap consumerToProducerLoopsMap,
- OpOperand *fusedOperand, unsigned nloops) {
+static void generateFusedElementwiseOpRegion(
+ RewriterBase &rewriter, GenericOp fusedOp,
+ AffineMap consumerToProducerLoopsMap, OpOperand *fusedOperand,
+ unsigned nloops, llvm::SmallDenseSet<int> &preservedProducerResults) {
auto producer = cast<GenericOp>(fusedOperand->get().getDefiningOp());
auto consumer = cast<GenericOp>(fusedOperand->getOwner());
// Build the region of the fused op.
mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType(), bbArg.getLoc()));
// 6. All of the producer's output operands
- for (BlockArgument bbArg :
- producerBlock.getArguments().take_back(producer.getNumDpsInits()))
- mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType(), bbArg.getLoc()));
+ for (auto bbArg : llvm::enumerate(
+ producerBlock.getArguments().take_back(producer.getNumDpsInits()))) {
+ if (!preservedProducerResults.count(bbArg.index()))
+ continue;
+ mapper.map(bbArg.value(), fusedBlock->addArgument(bbArg.value().getType(),
+ bbArg.value().getLoc()));
+ }
// 7. All of consumer's output operands.
for (BlockArgument bbArg :
SmallVector<Value> fusedYieldValues;
fusedYieldValues.reserve(producerYieldOp.getNumOperands() +
consumerYieldOp.getNumOperands());
- for (auto producerYieldVal : producerYieldOp.getOperands())
- fusedYieldValues.push_back(mapper.lookupOrDefault(producerYieldVal));
+ for (auto producerYieldVal : llvm::enumerate(producerYieldOp.getOperands())) {
+ if (preservedProducerResults.count(producerYieldVal.index()))
+ fusedYieldValues.push_back(
+ mapper.lookupOrDefault(producerYieldVal.value()));
+ }
for (auto consumerYieldVal : consumerYieldOp.getOperands())
fusedYieldValues.push_back(mapper.lookupOrDefault(consumerYieldVal));
rewriter.create<YieldOp>(fusedOp.getLoc(), fusedYieldValues);
// TODO: allow fusing the producer of an output operand.
assert(consumer.isDpsInput(fusedOperand) &&
"expected producer of input operand");
+ /// Find the results of the producer that have uses outside of the consumer.
+ llvm::SmallDenseSet<int> preservedProducerResults;
+ for (auto producerResult : llvm::enumerate(producer->getResults())) {
+ auto outputOperand = producer.getDpsInitOperand(producerResult.index());
+ if (producer.payloadUsesValueFromOperand(outputOperand) ||
+ !producer.canOpOperandsBeDropped(outputOperand) ||
+ llvm::any_of(producerResult.value().getUsers(), [&](Operation *user) {
+ return user != consumer.getOperation();
+ })) {
+ preservedProducerResults.insert(producerResult.index());
+ }
+ }
// Compute the fused operands list and indexing maps.
SmallVector<Value> fusedInputOperands, fusedOutputOperands;
SmallVector<AffineMap> fusedIndexMaps;
fusedInputOperands.reserve(producer.getNumDpsInputs() +
consumer.getNumDpsInputs());
- fusedOutputOperands.reserve(producer.getNumDpsInits() +
+ fusedOutputOperands.reserve(preservedProducerResults.size() +
consumer.getNumDpsInits());
- fusedResultTypes.reserve(producer.getNumDpsInits() +
+ fusedResultTypes.reserve(preservedProducerResults.size() +
consumer.getNumDpsInits());
fusedIndexMaps.reserve(producer->getNumOperands() +
consumer->getNumOperands());
}
// 6. Collect all of the producer outputs.
- for (OpOperand *opOperand : producer.getDpsInitOperands()) {
- fusedOutputOperands.push_back(opOperand->get());
+ for (auto opOperand : llvm::enumerate(producer.getDpsInitOperands())) {
+ if (!preservedProducerResults.count(opOperand.index()))
+ continue;
+
+ fusedOutputOperands.push_back(opOperand.value()->get());
AffineMap map = getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
- opOperand, producerResultIndexMap,
+ opOperand.value(), producerResultIndexMap,
consumer.getMatchingIndexingMap(fusedOperand));
fusedIndexMaps.push_back(map);
- fusedResultTypes.push_back(opOperand->get().getType());
+ fusedResultTypes.push_back(opOperand.value()->get().getType());
}
// 7. All of consumer's output operands (skip operands: added by the builder).
AffineMap consumerToProducerLoopsMap =
invProducerResultIndexMap.compose(consumerResultIndexMap);
- generateFusedElementwiseOpRegion(rewriter, fusedOp,
- consumerToProducerLoopsMap, fusedOperand,
- consumer.getNumLoops());
+ generateFusedElementwiseOpRegion(
+ rewriter, fusedOp, consumerToProducerLoopsMap, fusedOperand,
+ consumer.getNumLoops(), preservedProducerResults);
return fusedOp.getOperation();
}
--- /dev/null
+// RUN: mlir-opt %s -test-linalg-elementwise-fusion-patterns=fuse-generic-ops-control -split-input-file | FileCheck %s
+
+#map = affine_map<(d0, d1) -> (d0, d1)>
+func.func @drop_unused_producer_result(%arg0 : tensor<?x?xf32>,
+ %arg1 : tensor<?x?xf32>) -> tensor<?x?xf32> {
+ %0:2 = linalg.generic {
+ indexing_maps = [#map, #map, #map],
+ iterator_types = ["parallel", "parallel"]}
+ ins(%arg0 : tensor<?x?xf32>) outs(%arg0, %arg0 : tensor<?x?xf32>, tensor<?x?xf32>) {
+ ^bb0(%b0: f32, %b1: f32, %b2: f32):
+ %1 = arith.addf %b0, %b0 : f32
+ %2 = arith.mulf %b0, %b0 : f32
+ linalg.yield %1, %2 : f32, f32
+ } -> (tensor<?x?xf32>, tensor<?x?xf32>)
+ %3 = linalg.generic {
+ indexing_maps = [#map, #map, #map],
+ iterator_types = ["parallel", "parallel"]}
+ ins(%0#0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%arg0 : tensor<?x?xf32>) {
+ ^bb0(%b0: f32, %b1: f32, %b2: f32):
+ %4 = arith.subf %b0, %b1 : f32
+ linalg.yield %4 : f32
+ } -> tensor<?x?xf32>
+ return %3 : tensor<?x?xf32>
+}
+// CHECK-LABEL: func @drop_unused_producer_result
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<?x?xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<?x?xf32>
+// CHECK: %[[FUSED_OP:.+]] = linalg.generic
+// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] :
+// CHECK: return %[[FUSED_OP]]