From: Mahesh Ravishankar <ravishankarm@google.com>
Date: Wed, 24 Aug 2022 05:56:13 +0000 (+0000)
Subject: [mlir][Linalg] Handle multi-result operations in Elementwise op fusion.
X-Git-Tag: upstream/17.0.6~35601
X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=a7bfdc23ab3ade54da99f0f59dababe4d71ae75b;p=platform%2Fupstream%2Fllvm.git

[mlir][Linalg] Handle multi-result operations in Elementwise op fusion.

This drops the artificial requirement that a producer have a single result
value in order to be fused with a consumer. The default still fuses a
producer with a consumer only when the producer has a single use; this is a
simplifying assumption. There are legitimate use cases where a producer can
be fused with a consumer and the fused op could be used to replace the uses
of the producer as well, but this needs to be done with care to avoid
use-def violations. To allow downstream users to explore more fusion
opportunities, the core transformation method is exposed as a utility
function.

This patch also modifies the control function to take just the fused operand
as its argument. This is enough information for callers to get the producer
and the consumer operations being considered for fusion; it also identifies
which producer result is used.

Differential Revision: https://reviews.llvm.org/D132301
---

diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 8f53d5a..abd0243 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -66,8 +66,7 @@ void populateSparseTensorRewriting(RewritePatternSet &patterns);
 /// Function type which is used to control when to stop fusion. It is expected
 /// that OpOperand is not modified in the callback. The OpOperand is not marked
 /// as const to allow callers to use non-const methods.
-using ControlFusionFn =
-    std::function<bool(const OpResult &producer, OpOperand &consumer)>;
+using ControlFusionFn = std::function<bool(OpOperand *fusedOperand)>;
 
 /// Patterns for fusing linalg operation on tensors.
@@ -111,6 +110,17 @@ void populateInlineConstantOperandsPatterns(RewritePatternSet &patterns);
 /// Patterns that are used to bubble up extract slice op above linalg op.
 void populateBubbleUpExtractSliceOpPatterns(RewritePatternSet &patterns);
 
+/// Return true if two `linalg.generic` operations with producer/consumer
+/// relationship through `fusedOperand` can be fused using elementwise op
+/// fusion.
+bool areElementwiseOpsFusable(OpOperand *fusedOperand);
+
+/// Fuse two `linalg.generic` operations that have a producer-consumer
+/// relationship captured through `fusedOperand`. The method expects
+/// that `areElementwiseOpsFusable` returns true for the given `fusedOperand`.
+FailureOr<Operation *> fuseElementwiseOps(RewriterBase &rewriter,
+                                          OpOperand *fusedOperand);
+
 /// Split the given `op` into two parts along the given iteration space
 /// `dimension` at the specified `splitPoint`, and return the two parts.
 ///
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ConstantFold.cpp b/mlir/lib/Dialect/Linalg/Transforms/ConstantFold.cpp
index 05421bf..e46f7ae 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ConstantFold.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ConstantFold.cpp
@@ -122,10 +122,8 @@ public:
 
     // Identified this as a potential candidate for folding. Now check the
     // policy to see whether we are allowed to proceed.
-    for (int i = 0; i < numInputs; ++i) {
-      OpOperand *consumer = genericOp.getInputOperand(i);
-      OpResult producer = consumer->get().cast<OpResult>();
-      if (!controlFn(producer, *consumer))
+    for (auto operand : genericOp.getInputOperands()) {
+      if (!controlFn(operand))
         return failure();
     }
 
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index d19f926..e3d121c 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -65,8 +65,14 @@ static AffineMap getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
 }
 
 /// Conditions for elementwise fusion of generic operations.
-static bool areElementwiseOpsFusable(GenericOp producer, GenericOp consumer,
-                                     OpOperand *consumerOpOperand) {
+bool mlir::linalg::areElementwiseOpsFusable(OpOperand *fusedOperand) {
+  auto producer = fusedOperand->get().getDefiningOp<GenericOp>();
+  auto consumer = dyn_cast<GenericOp>(fusedOperand->getOwner());
+
+  // Check producer and consumer are generic ops.
+  if (!producer || !consumer)
+    return false;
+
   // Producer and consumer must have tensor semantics.
   if (!producer.hasTensorSemantics() || !consumer.hasTensorSemantics())
     return false;
@@ -78,19 +84,15 @@ static bool areElementwiseOpsFusable(GenericOp producer, GenericOp consumer,
 
   // Only allow fusing the producer of an input operand for now.
   // TODO: allow fusing the producer of an output operand.
-  if (!consumer.isInputTensor(consumerOpOperand))
+  if (!consumer.isInputTensor(fusedOperand))
     return false;
 
   // Get the consumer index map. The number of results of the consumer index
   // map must match the number of loops of the producer.
-  AffineMap consumerIndexMap = consumer.getTiedIndexingMap(consumerOpOperand);
+  AffineMap consumerIndexMap = consumer.getTiedIndexingMap(fusedOperand);
   if (consumerIndexMap.getNumResults() != producer.getNumLoops())
     return false;
 
-  // Currently support only operations with single result.
-  if (producer.getNumOutputs() != 1)
-    return false;
-
   // Finally the index_map for the result must be invertible. For now just
   // verify it is a permutation.
   AffineMap producerResultIndexMap =
@@ -114,7 +116,7 @@ static bool areElementwiseOpsFusable(GenericOp producer, GenericOp consumer,
   for (auto pair :
        llvm::zip(consumer->getOperands(), consumer.getIndexingMapsArray())) {
     Value operand = std::get<0>(pair);
-    if (operand == consumerOpOperand->get())
+    if (operand == fusedOperand->get())
       continue;
     AffineMap operandMap = std::get<1>(pair);
     addToCoveredDims(operandMap);
@@ -136,12 +138,11 @@ static bool areElementwiseOpsFusable(GenericOp producer, GenericOp consumer,
 /// Generate the region of the fused tensor operation. The region of the fused
 /// op must be empty.
 static void
-generateFusedElementwiseOpRegion(PatternRewriter &rewriter, GenericOp fusedOp,
+generateFusedElementwiseOpRegion(RewriterBase &rewriter, GenericOp fusedOp,
                                  AffineMap consumerToProducerLoopsMap,
-                                 OpOperand *consumerOpOperand,
-                                 unsigned nloops) {
-  auto producer = cast<GenericOp>(consumerOpOperand->get().getDefiningOp());
-  auto consumer = cast<GenericOp>(consumerOpOperand->getOwner());
+                                 OpOperand *fusedOperand, unsigned nloops) {
+  auto producer = cast<GenericOp>(fusedOperand->get().getDefiningOp());
+  auto consumer = cast<GenericOp>(fusedOperand->getOwner());
   // Build the region of the fused op.
   Block &producerBlock = producer->getRegion(0).front();
   Block &consumerBlock = consumer->getRegion(0).front();
@@ -172,11 +173,11 @@ generateFusedElementwiseOpRegion(PatternRewriter &rewriter, GenericOp fusedOp,
     }
   }
   // TODO: allow fusing the producer of an output operand.
-  assert(consumer.isInputTensor(consumerOpOperand) &&
+  assert(consumer.isInputTensor(fusedOperand) &&
          "expected producer of input operand");
   // 3. Consumer input operands up to consumerIdx (exclusive).
   for (BlockArgument bbArg : consumerBlock.getArguments().take_front(
-           consumerOpOperand->getOperandNumber())) // input assumption.
+           fusedOperand->getOperandNumber())) // input assumption.
     mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType(), bbArg.getLoc()));
 
   // Replacing consumerIdx requires getting the cloned, yielded, value from
@@ -187,29 +188,22 @@ generateFusedElementwiseOpRegion(PatternRewriter &rewriter, GenericOp fusedOp,
        producerBlock.getArguments().take_front(producer.getNumInputs()))
     mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType(), bbArg.getLoc()));
 
-  // 4.b. Producer output operand/map that is fused needs to be mapped to the
-  // producer bbArg if it is an "initTensor" (i.e. its value is actually read).
-  assert(producer->getNumResults() == 1 && "expected single result producer");
-  if (producer.isInitTensor(producer.getOutputOperand(0))) {
-    BlockArgument bbArg = producerBlock.getArguments()
-                              .drop_front(producer.getNumInputs())
-                              // TODO: bbArg index of
-                              .front();
-    mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType(), bbArg.getLoc()));
-  }
   // 5. Remaining consumer's input operands (drop past index `consumerIdx`).
   for (BlockArgument bbArg :
        consumerBlock.getArguments()
           .take_front(consumer.getNumInputs())
-           .drop_front(consumerOpOperand->getOperandNumber() + 1))
+           .drop_front(fusedOperand->getOperandNumber() + 1))
     mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType(), bbArg.getLoc()));
-  // 6. All of consumer's output operands.
+
+  // 6. All of the producer's output operands.
+  for (BlockArgument bbArg :
+       producerBlock.getArguments().take_back(producer.getNumOutputs()))
+    mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType(), bbArg.getLoc()));
+
+  // 7. All of consumer's output operands.
   for (BlockArgument bbArg :
        consumerBlock.getArguments().take_back(consumer.getNumOutputs()))
     mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType(), bbArg.getLoc()));
-  // 7. All of producer's output operands except the one fused.
-  // TODO: allow fusion of multi-result producers.
-  assert(producer->getNumResults() == 1 && "expected single result producer");
 
   // 8. Clone all producer operations except for the yield and index operations
   // to the fused operation.
@@ -219,15 +213,15 @@ generateFusedElementwiseOpRegion(PatternRewriter &rewriter, GenericOp fusedOp,
   }
   // 9. Now we can map the consumerBlock's `consumerIdx` block argument. Just
   // forward the yield operand.
-  auto yieldOp = cast<YieldOp>(producerBlock.getTerminator());
-  // TODO: allow fusion of multi-result producers.
-  assert(producer->getNumResults() == 1 && "expected single result producer");
-  unsigned producerResultNumber = 0;
+  auto producerYieldOp = cast<YieldOp>(producerBlock.getTerminator());
+  unsigned producerResultNumber =
+      fusedOperand->get().cast<OpResult>().getResultNumber();
   Value replacement =
-      mapper.lookupOrDefault(yieldOp.getOperand(producerResultNumber));
+      mapper.lookupOrDefault(producerYieldOp.getOperand(producerResultNumber));
+
   // Sanity checks, if replacement is not already in the mapper then it must be
   // produced outside.
-  if (replacement == yieldOp.getOperand(producerResultNumber)) {
+  if (replacement == producerYieldOp.getOperand(producerResultNumber)) {
     if (auto bb = replacement.dyn_cast<BlockArgument>())
       assert(bb.getOwner() != &producerBlock &&
              "yielded block argument must have been mapped");
@@ -235,91 +229,101 @@ generateFusedElementwiseOpRegion(PatternRewriter &rewriter, GenericOp fusedOp,
       assert(!producer->isAncestor(replacement.getDefiningOp()) &&
              "yielded value must have been mapped");
   }
-  mapper.map(consumerBlock.getArgument(consumerOpOperand->getOperandNumber()),
+  mapper.map(consumerBlock.getArgument(fusedOperand->getOperandNumber()),
             replacement);
   // 10. Clone operations from the consumer to the fused op.
-  for (auto &op : consumerBlock.getOperations())
+  for (auto &op : consumerBlock.without_terminator())
     rewriter.clone(op, mapper);
 
+  // 11. Include the final yield (which is the remapped values for all the
+  // yields).
+  auto consumerYieldOp = cast<YieldOp>(consumerBlock.getTerminator());
+  SmallVector<Value> fusedYieldValues;
+  fusedYieldValues.reserve(producerYieldOp.getNumOperands() +
+                           consumerYieldOp.getNumOperands());
+  for (auto producerYieldVal : producerYieldOp.getOperands())
+    fusedYieldValues.push_back(mapper.lookupOrDefault(producerYieldVal));
+  for (auto consumerYieldVal : consumerYieldOp.getOperands())
+    fusedYieldValues.push_back(mapper.lookupOrDefault(consumerYieldVal));
+  rewriter.create<YieldOp>(fusedOp.getLoc(), fusedYieldValues);
+
   // Sanity checks.
   assert(fusedBlock->getNumArguments() == fusedOp.getNumOperands() &&
          "Ill-formed GenericOp region");
 }
 
-static Optional<SmallVector<Value>>
-fuseElementwiseOpsImpl(GenericOp producer, OpOperand *consumerOpOperand,
-                       const ControlFusionFn &controlFn,
-                       PatternRewriter &rewriter) {
-  auto consumer = cast<GenericOp>(consumerOpOperand->getOwner());
-  if (!areElementwiseOpsFusable(producer, consumer, consumerOpOperand) ||
-      !controlFn(producer->getResult(0), *consumerOpOperand))
-    return llvm::None;
-
+FailureOr<Operation *>
+mlir::linalg::fuseElementwiseOps(RewriterBase &rewriter,
+                                 OpOperand *fusedOperand) {
+  assert(areElementwiseOpsFusable(fusedOperand) &&
+         "expected elementwise operation pre-conditions to pass");
+  auto producerResult = fusedOperand->get().cast<OpResult>();
+  auto producer = cast<GenericOp>(producerResult.getOwner());
+  auto consumer = cast<GenericOp>(fusedOperand->getOwner());
   // TODO: allow fusing the producer of an output operand.
-  assert(consumer.isInputTensor(consumerOpOperand) &&
+  assert(consumer.isInputTensor(fusedOperand) &&
          "expected producer of input operand");
 
   // Compute the fused operands list and indexing maps.
-  SmallVector<Value> fusedOperands;
+  SmallVector<Value> fusedInputOperands, fusedOutputOperands;
+  SmallVector<Type> fusedResultTypes;
   SmallVector<AffineMap> fusedIndexMaps;
-  fusedOperands.reserve(producer->getNumOperands() +
-                        consumer->getNumOperands());
-  fusedIndexMaps.reserve(producer->getNumOperands() +
-                         consumer->getNumOperands());
+  fusedInputOperands.reserve(producer.getNumInputs() + consumer.getNumInputs());
+  fusedOutputOperands.reserve(producer.getNumOutputs() +
+                              consumer.getNumOutputs());
+  fusedResultTypes.reserve(producer.getNumOutputs() + consumer.getNumOutputs());
+  fusedIndexMaps.reserve(producer.getNumInputsAndOutputs() +
+                         consumer.getNumInputsAndOutputs());
   // In the following, numbering matches that of `generateFusedTensorOpRegion`.
   // 3. Consumer input operands/maps up to consumerIdx (exclusive).
   SmallVector<OpOperand *> consumerInputs = consumer.getInputOperands();
   SmallVector<OpOperand *>::iterator it =
-      llvm::find(consumerInputs, consumerOpOperand);
+      llvm::find(consumerInputs, fusedOperand);
   assert(it != consumerInputs.end() && "expected to find the consumer operand");
   for (OpOperand *opOperand : llvm::make_range(consumerInputs.begin(), it)) {
-    fusedOperands.push_back(opOperand->get());
+    fusedInputOperands.push_back(opOperand->get());
     fusedIndexMaps.push_back(consumer.getTiedIndexingMap(opOperand));
   }
   // 4. Splice in producer's input operands/maps.
-  assert(producer->getNumResults() == 1 && "expected single result producer");
   AffineMap producerResultIndexMap =
-      producer.getTiedIndexingMap(producer.getOutputOperand(0));
+      producer.getTiedIndexingMapForResult(producerResult);
   for (OpOperand *opOperand : producer.getInputOperands()) {
-    fusedOperands.push_back(opOperand->get());
+    fusedInputOperands.push_back(opOperand->get());
     // Compute indexing maps for the producer args in the fused operation.
     AffineMap map = getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
         opOperand, producerResultIndexMap,
-        consumer.getTiedIndexingMap(consumerOpOperand));
-    fusedIndexMaps.push_back(map);
-  }
-  // 4.b. Producer output operand/map that is fused needs to be passed if it is
-  // an "initTensor" (i.e. its value is actually read).
-  assert(producer->getNumResults() == 1 && "expected single result producer");
-  if (producer.isInitTensor(producer.getOutputOperand(0))) {
-    fusedOperands.push_back(producer.getOutputOperand(0)->get());
-    // Compute indexing maps for the producer args in the fused operation.
-    AffineMap map = getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
-        producer.getOutputOperand(0), producerResultIndexMap,
-        consumer.getTiedIndexingMap(consumerOpOperand));
+        consumer.getTiedIndexingMap(fusedOperand));
     fusedIndexMaps.push_back(map);
   }
   // 5. Remaining consumer's input operands/maps (drop past index
   // `consumerIdx`).
   for (OpOperand *opOperand :
        llvm::make_range(std::next(it), consumerInputs.end())) {
-    fusedOperands.push_back(opOperand->get());
+    fusedInputOperands.push_back(opOperand->get());
     fusedIndexMaps.push_back(consumer.getTiedIndexingMap(opOperand));
   }
-  // 6. All of consumer's output operands (skip operands: added by the builder).
-  for (OpOperand *opOperand : consumer.getOutputOperands())
+
+  // 6. Collect all of the producer outputs.
+  for (OpOperand *opOperand : producer.getOutputOperands()) {
+    fusedOutputOperands.push_back(opOperand->get());
+    AffineMap map = getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
+        opOperand, producerResultIndexMap,
+        consumer.getTiedIndexingMap(fusedOperand));
+    fusedIndexMaps.push_back(map);
+    fusedResultTypes.push_back(opOperand->get().getType());
+  }
+
+  // 7. All of consumer's output operands (skip operands: added by the builder).
+  for (OpOperand *opOperand : consumer.getOutputOperands()) {
+    fusedOutputOperands.push_back(opOperand->get());
     fusedIndexMaps.push_back(consumer.getTiedIndexingMap(opOperand));
-  // 7. All of producer's output operands/maps except the one fused.
-  // TODO: allow fusion of multi-result producers.
-  assert(producer->getNumResults() == 1 && "expected single result producer");
+    fusedResultTypes.push_back(opOperand->get().getType());
+  }
 
   // Generate the fused op.
-  SmallVector<OpOperand *> consumerOutputs = consumer.getOutputOperands();
   auto fusedOp = rewriter.create<GenericOp>(
-      consumer.getLoc(), consumer->getResultTypes(),
-      /*inputs=*/fusedOperands,
-      // TODO: handle outputs.
-      consumerOutputs, rewriter.getAffineMapArrayAttr(fusedIndexMaps),
+      consumer.getLoc(), fusedResultTypes, fusedInputOperands,
+      fusedOutputOperands, rewriter.getAffineMapArrayAttr(fusedIndexMaps),
       consumer.getIteratorTypes(),
       /*doc=*/nullptr,
       /*library_call=*/nullptr);
@@ -328,13 +332,13 @@ fuseElementwiseOpsImpl(GenericOp producer, OpOperand *consumerOpOperand,
     // in the input, but going ahead here would result in verification errors.
     // So cleanup and abort.
     rewriter.eraseOp(fusedOp);
-    return llvm::None;
+    return rewriter.notifyMatchFailure(
+        fusedOp, "fused op failed loop bound computation check");
   }
 
   // Construct an AffineMap from consumer loops to producer loops.
   // consumer loop -> tensor index
-  AffineMap consumerResultIndexMap =
-      consumer.getTiedIndexingMap(consumerOpOperand);
+  AffineMap consumerResultIndexMap = consumer.getTiedIndexingMap(fusedOperand);
   // tensor index -> producer loop
   AffineMap invProducerResultIndexMap =
       inversePermutation(producerResultIndexMap);
@@ -345,19 +349,9 @@ fuseElementwiseOpsImpl(GenericOp producer, OpOperand *consumerOpOperand,
       invProducerResultIndexMap.compose(consumerResultIndexMap);
 
   generateFusedElementwiseOpRegion(rewriter, fusedOp,
-                                   consumerToProducerLoopsMap,
-                                   consumerOpOperand, consumer.getNumLoops());
-  return SmallVector<Value>(fusedOp->getResults());
-}
-
-static Optional<SmallVector<Value>>
-fuseElementwiseOps(PatternRewriter &rewriter, OpOperand *consumerOpOperand,
-                   GenericOp producer, const ControlFusionFn &controlFn) {
-  if (producer->getNumResults() != 1)
-    return llvm::None;
-
-  return fuseElementwiseOpsImpl(producer, consumerOpOperand, controlFn,
-                                rewriter);
+                                   consumerToProducerLoopsMap, fusedOperand,
+                                   consumer.getNumLoops());
+  return fusedOp.getOperation();
 }
 
 namespace {
@@ -373,14 +367,16 @@ public:
                                 PatternRewriter &rewriter) const override {
     // Find the first operand that is defined by another generic op on tensors.
     for (OpOperand *opOperand : genericOp.getInputAndOutputOperands()) {
-      auto producer =
-          dyn_cast_or_null<GenericOp>(opOperand->get().getDefiningOp());
-      if (!producer || !producer.hasTensorSemantics())
+      if (!areElementwiseOpsFusable(opOperand))
+        continue;
+      if (!controlFn(opOperand))
         continue;
-      Optional<SmallVector<Value>> fusedOpResults =
-          fuseElementwiseOps(rewriter, opOperand, producer, controlFn);
-      if (fusedOpResults) {
-        rewriter.replaceOp(genericOp, *fusedOpResults);
+
+      FailureOr<Operation *> fusedOp = fuseElementwiseOps(rewriter, opOperand);
+      if (succeeded(fusedOp)) {
+        auto replacements = fusedOp.getValue()->getResults().take_back(
+            genericOp.getNumResults());
+        rewriter.replaceOp(genericOp, replacements);
         return success();
       }
     }
@@ -713,6 +709,10 @@ fuseWithReshapeByExpansion(GenericOp genericOp, Operation *reshapeOp,
         return getIndexingMapInExpandedOp(rewriter, m, expansionInfo);
       }));
 
+  // Set insertion point to the generic op.
+  OpBuilder::InsertionGuard g(rewriter);
+  rewriter.setInsertionPoint(genericOp);
+
   SmallVector<Value> expandedOpOperands;
   expandedOpOperands.reserve(genericOp.getNumInputs());
   for (OpOperand *opOperand : genericOp.getInputOperands()) {
@@ -792,7 +792,7 @@ fuseWithReshapeByExpansion(GenericOp genericOp, Operation *reshapeOp,
   SmallVector<Value> resultVals;
   for (OpResult opResult : genericOp->getOpResults()) {
     int64_t resultNumber = opResult.getResultNumber();
-    if (!isExpanding && resultTypes[resultNumber] != opResult.getType()) {
+    if (resultTypes[resultNumber] != opResult.getType()) {
       SmallVector<ReassociationIndices> reassociation =
           getReassociationForExpansion(
               genericOp.getTiedIndexingMap(
@@ -834,7 +834,7 @@ public:
       //   - The tensor reshape op is folding.
      //   - All constraints of fusing with reshape by expansion are met.
       if (!isFusableWithReshapeByDimExpansion(genericOp, opOperand) ||
-          (!controlFoldingReshapes(reshapeOp->getResult(0), *opOperand)))
+          (!controlFoldingReshapes(opOperand)))
         continue;
 
       Optional<SmallVector<Value>> replacementValues =
@@ -865,18 +865,50 @@ struct FoldReshapeWithGenericOpByExpansion
   LogicalResult matchAndRewrite(tensor::ExpandShapeOp reshapeOp,
                                 PatternRewriter &rewriter) const override {
     // Fold only if all constraints of fusing with reshape by expansion are met.
-    GenericOp producer = reshapeOp.getSrc().getDefiningOp<GenericOp>();
-    if (!producer || producer.getNumOutputs() != 1 ||
-        !isFusableWithReshapeByDimExpansion(producer,
-                                            producer.getOutputOperand(0)) ||
-        !controlFoldingReshapes(producer->getResult(0),
-                                reshapeOp->getOpOperand(0)))
-      return failure();
+    auto producerResult = reshapeOp.getSrc().dyn_cast<OpResult>();
+    if (!producerResult) {
+      return rewriter.notifyMatchFailure(reshapeOp,
+                                         "source not produced by an operation");
+    }
+
+    auto producer = dyn_cast<GenericOp>(producerResult.getOwner());
+    if (!producer) {
+      return rewriter.notifyMatchFailure(reshapeOp, "producer not a generic op");
+    }
+
+    if (!isFusableWithReshapeByDimExpansion(
+            producer,
+            producer.getOutputOperand(producerResult.getResultNumber()))) {
+      return rewriter.notifyMatchFailure(
+          reshapeOp, "failed preconditions of fusion with producer generic op");
+    }
+
+    if (!controlFoldingReshapes(&reshapeOp->getOpOperand(0))) {
+      return rewriter.notifyMatchFailure(reshapeOp,
+                                         "fusion blocked by control function");
+    }
+
     Optional<SmallVector<Value>> replacementValues = fuseWithReshapeByExpansion(
-        producer, reshapeOp, producer.getOutputOperand(0), rewriter);
-    if (!replacementValues)
-      return failure();
-    rewriter.replaceOp(reshapeOp, *replacementValues);
+        producer, reshapeOp,
+        producer.getOutputOperand(producerResult.getResultNumber()), rewriter);
+    if (!replacementValues) {
+      return rewriter.notifyMatchFailure(reshapeOp, "fusion by expansion failed");
+    }
+
+    // Find the replacement for the reshape op. Since the replacements have the
+    // same type as the returns of the original generic op, the consumer reshape
+    // op can be replaced by the source of the collapse_shape op that defines
+    // the replacement.
+    Value reshapeReplacement =
+        (*replacementValues)[reshapeOp.getSrc()
+                                 .cast<OpResult>()
+                                 .getResultNumber()];
+    if (auto collapseOp =
+            reshapeReplacement.getDefiningOp<tensor::CollapseShapeOp>()) {
+      reshapeReplacement = collapseOp.getSrc();
+    }
+    rewriter.replaceOp(reshapeOp, reshapeReplacement);
+    rewriter.replaceOp(producer, *replacementValues);
     return success();
   }
 
@@ -1469,7 +1501,7 @@ public:
           getCollapsableIterationSpaceDims(genericOp, opOperand,
                                            reshapeOp.getReassociationIndices());
       if (collapsableIterationDims.empty() ||
-          !controlFoldingReshapes(reshapeOp->getResult(0), *opOperand)) {
+          !controlFoldingReshapes(opOperand)) {
         continue;
       }
 
@@ -1726,9 +1758,9 @@ struct LinalgElementwiseOpFusionPass
     RewritePatternSet patterns(context);
 
     // Add folding with reshape by expansion patterns.
-    ControlFusionFn defaultControlFn = [](const OpResult &producer,
-                                          const OpOperand &consumer) {
-      return producer.hasOneUse();
+    ControlFusionFn defaultControlFn = [](OpOperand *fusedOperand) {
+      Operation *producer = fusedOperand->get().getDefiningOp();
+      return producer && producer->hasOneUse();
     };
 
     // Add elementwise op fusion patterns.
diff --git a/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir b/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir
index b7c91c0..3002d44 100644
--- a/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir
+++ b/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir
@@ -662,12 +662,12 @@ func.func @generic_index_op2(%arg0: tensor<1x8xf64>, %arg1: tensor<1x8xi32>) ->
     linalg.yield %r : f64
   } -> tensor<1x8xf64>
 
-  // CHECK-NEXT: %[[R:.*]] = linalg.generic
+  // CHECK-NEXT: %[[R:.*]]:2 = linalg.generic
   // CHECK: bb0(%[[BBA:[0-9a-z]*]]: f64, %[[BBB:[0-9a-z]*]]: i32):
   // CHECK-NEXT:   %[[A:.*]] = func.call @compute1(%[[BBA]]) : (f64) -> f64
   // CHECK-NEXT:   %[[B:.*]] = func.call @compute2(%[[A]], %[[BBB]]) : (f64, i32) -> i32
-  // CHECK-NEXT:   linalg.yield %[[B]] : i32
-  // CHECK-NEXT: } -> tensor<1x8xi32>
+  // CHECK-NEXT:   linalg.yield %[[A]], %[[B]] : f64, i32
+  // CHECK-NEXT: } -> (tensor<1x8xf64>, tensor<1x8xi32>)
   %1 = linalg.generic {
     indexing_maps = [affine_map<(i, j) -> (i, j)>, affine_map<(i, j) -> (i, j)>],
     iterator_types = ["parallel", "parallel"]}
@@ -678,7 +678,7 @@ func.func @generic_index_op2(%arg0: tensor<1x8xf64>, %arg1: tensor<1x8xi32>) ->
     linalg.yield %r : i32
   } -> tensor<1x8xi32>
 
-  // CHECK-NEXT: return %[[R]] : tensor<1x8xi32>
+  // CHECK-NEXT: return %[[R]]#1 : tensor<1x8xi32>
   return %1 : tensor<1x8xi32>
 }
 
@@ -948,7 +948,7 @@ func.func @no_fusion_missing_reduction_shape(%arg0: tensor<f32>, %arg1: index) -
 
 // -----
 
-func.func @illegal_fusion(%arg0 : tensor<5000xi64>, %arg1 : tensor<5000xi32>) -> tensor<5000xi32> {
+func.func @fusion_different_axes(%arg0 : tensor<5000xi64>, %arg1 : tensor<5000xi32>) -> tensor<5000xi32> {
   %c1_i32 = arith.constant 1 : i32
   %0 = linalg.generic {
       indexing_maps = [affine_map<(d0) -> (d0)>],
@@ -971,10 +971,25 @@ func.func @illegal_fusion(%arg0 : tensor<5000xi64>, %arg1 : tensor<5000xi32>) ->
   } -> tensor<5000xi32>
   return %2 : tensor<5000xi32>
 }
-// CHECK-LABEL: func @illegal_fusion(
-// CHECK: %[[PRODUCER:.+]] = linalg.generic
-// CHECK: linalg.generic
-// CHECK-SAME: ins(%[[PRODUCER]]
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1) -> (d0)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1) -> (d1)>
+// CHECK: func @fusion_different_axes(
+// CHECK-SAME: %[[ARG0:.+]]: tensor<5000xi64>
+// CHECK-SAME: %[[ARG1:.+]]: tensor<5000xi32>
+// CHECK-DAG: %[[INIT0:.+]] = linalg.init_tensor [5000] : tensor<5000xi64>
+// CHECK-DAG: %[[INIT1:.+]] = linalg.init_tensor [5000] : tensor<5000xi32>
+// CHECK: %[[RESULT:.+]]:2 = linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]]]
+// CHECK-SAME: outs(%[[INIT0]], %[[INIT1]] :
+// CHECK-NEXT: ^bb0(
+// CHECK-SAME: %[[B0:.+]]: i64
+// CHECK-SAME: %[[B1:.+]]: i32
+// CHECK-DAG: %[[T0:.+]] = linalg.index 0
+// CHECK-DAG: %[[CAST1:.+]] = arith.index_cast %[[T0]] : index to i64
+// CHECK-DAG: %[[CAST2:.+]] = arith.index_cast %[[CAST1]] : i64 to index
+// CHECK: %[[EXTRACT:.+]] = tensor.extract %[[ARG1]][%[[CAST2]]]
+// CHECK: linalg.yield %[[CAST1]], %[[EXTRACT]]
+// CHECK: return %[[RESULT]]#1
 
 // -----
 
@@ -995,7 +1010,7 @@ func.func @fold_fill_generic_basic(%arg0: tensor<?xf32>) -> (tensor<?xf32>) {
   %4 = linalg.generic {indexing_maps = [#map0, #map0, #map0], iterator_types=["parallel"]} ins(%arg0, %2 : tensor<?xf32>, tensor<?xf32>) outs (%3:tensor<?xf32>) {
   ^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
     %5 = arith.addf %arg1, %arg2 : f32
-    linalg.yield %5 : f32
+    linalg.yield %5 : f32
   } -> tensor<?xf32>
   return %4 : tensor<?xf32>
 }
@@ -1024,7 +1039,50 @@ func.func @fold_fill_generic_mixedaccess(%arg0: tensor<?x?xf32>) -> (tensor<?x?xf32>) {
   %7 = linalg.generic {indexing_maps = [#map0, #map1, #map0], iterator_types=["parallel","parallel"]} ins(%4, %5 : tensor<?x?xf32>, tensor<?x?xf32>) outs (%6:tensor<?x?xf32>) {
   ^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
     %8 = arith.divf %arg1, %arg2 : f32
-    linalg.yield %8 : f32
+    linalg.yield %8 : f32
   } -> tensor<?x?xf32>
   return %7 : tensor<?x?xf32>
 }
+
+// -----
+
+#map = affine_map<() -> ()>
+module {
+  func.func @fuse_multi_result_producer(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<f32>, %arg3: tensor<f32>, %arg4: tensor<f32>) -> tensor<f32> {
+    %0 = linalg.init_tensor [] : tensor<f32>
+    %1 = linalg.init_tensor [] : tensor<f32>
+    %2:2 = linalg.generic {
+      indexing_maps = [#map, #map, #map, #map, #map], iterator_types = []}
+      ins(%arg0, %arg1, %arg1 : tensor<f32>, tensor<f32>, tensor<f32>) outs(%0, %1 : tensor<f32>, tensor<f32>) {
+    ^bb0(%arg5: f32, %arg6: f32, %arg7: f32, %arg8: f32, %arg9: f32):
+      %4 = arith.addf %arg5, %arg6 : f32
+      %5 = arith.addf %4, %arg7 : f32
+      linalg.yield %4, %5 : f32, f32
+    } -> (tensor<f32>, tensor<f32>)
+    %3 = linalg.generic {
+      indexing_maps = [#map, #map, #map], iterator_types = []}
+      ins(%2#1, %arg1 : tensor<f32>, tensor<f32>) outs(%arg4 : tensor<f32>) {
+    ^bb0(%arg5: f32, %arg6: f32, %arg7: f32):
+      %4 = arith.addf %arg5, %arg6 : f32
+      %5 = arith.addf %4, %arg6 : f32
+      linalg.yield %5 : f32
+    } -> tensor<f32>
+    return %3 : tensor<f32>
+  }
+}
+// CHECK-LABEL: func.func @fuse_multi_result_producer
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<f32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<f32>
+// CHECK: %[[INIT:.+]] = linalg.init_tensor
+// CHECK: %[[GENERIC:.+]] = linalg.generic
+// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] :
+// CHECK-SAME: outs(%[[INIT]] :
+// CHECK-NEXT: ^bb0
+// CHECK-SAME: %[[B0:[a-zA-Z0-9]+]]: f32
+// CHECK-SAME: %[[B1:[a-zA-Z0-9]+]]: f32
+// CHECK-DAG: %[[T0:.+]] = arith.addf %[[B0]], %[[B1]]
+// CHECK-DAG: %[[T1:.+]] = arith.addf %[[T0]], %[[B1]]
+// CHECK-DAG: %[[T2:.+]] = arith.addf %[[T1]], %[[B1]]
+// CHECK-DAG: %[[T3:.+]] = arith.addf %[[T2]], %[[B1]]
+// CHECK: linalg.yield %[[T3]] : f32
+// CHECK: return %[[GENERIC]]
diff --git a/mlir/test/Dialect/Linalg/fusion-elementwise-options.mlir b/mlir/test/Dialect/Linalg/fusion-elementwise-options.mlir
index 1efaaa3..b051728 100644
--- a/mlir/test/Dialect/Linalg/fusion-elementwise-options.mlir
+++ b/mlir/test/Dialect/Linalg/fusion-elementwise-options.mlir
@@ -1,4 +1,5 @@
 // RUN: mlir-opt %s -test-linalg-elementwise-fusion-patterns=fuse-generic-ops -split-input-file | FileCheck %s
+// RUN: mlir-opt %s -test-linalg-elementwise-fusion-patterns=fuse-generic-ops -split-input-file -canonicalize | FileCheck %s --check-prefix=CANONICALIZE
 
 #map0 = affine_map<(d0, d1) -> (d0, d1)>
 #binary2Dpointwise = {
@@ -58,5 +59,17 @@ func.func @test_fusion_limit(
 // CHECK-SAME: %[[ARG5:[a-zA-z0-9_]+]]: tensor<?x?xf32>
 // CHECK: %[[OP1:.+]] = linalg.generic {{.+}} ins(%[[ARG2]], %[[ARG3]]
 // CHECK: %[[OP2:.+]] = linalg.generic {{.+}} ins(%[[ARG4]], %[[ARG5]]
-// CHECK: %[[OP3:.+]] = linalg.generic {{.+}} ins(%[[ARG0]], %[[ARG1]], %[[OP1]], %[[OP2]]
-// CHECK: return %[[OP3]]
+// CHECK: %[[OP3:.+]]:2 = linalg.generic {{.+}} ins(%[[ARG0]], %[[ARG1]], %[[OP1]], %[[OP2]]
+// CHECK: return %[[OP3]]#1
+
+// CANONICALIZE-LABEL: func @test_fusion_limit
+// CANONICALIZE-SAME: %[[ARG0:[a-zA-z0-9_]+]]: tensor<?x?xf32>
+// CANONICALIZE-SAME: %[[ARG1:[a-zA-z0-9_]+]]: tensor<?x?xf32>
+// CANONICALIZE-SAME: %[[ARG2:[a-zA-z0-9_]+]]: tensor<?x?xf32>
+// CANONICALIZE-SAME: %[[ARG3:[a-zA-z0-9_]+]]: tensor<?x?xf32>
+// CANONICALIZE-SAME: %[[ARG4:[a-zA-z0-9_]+]]: tensor<?x?xf32>
+// CANONICALIZE-SAME: %[[ARG5:[a-zA-z0-9_]+]]: tensor<?x?xf32>
+// CANONICALIZE: %[[OP1:.+]] = linalg.generic {{.+}} ins(%[[ARG2]], %[[ARG3]]
+// CANONICALIZE: %[[OP2:.+]] = linalg.generic {{.+}} ins(%[[ARG4]], %[[ARG5]]
+// CANONICALIZE: %[[OP3:.+]] = linalg.generic {{.+}} ins(%[[ARG0]], %[[ARG1]], %[[OP1]], %[[OP2]]
+// CANONICALIZE: return %[[OP3]]
diff --git a/mlir/test/Dialect/Linalg/reshape_fusion.mlir b/mlir/test/Dialect/Linalg/reshape_fusion.mlir
index 45e87212..e151b99 100644
--- a/mlir/test/Dialect/Linalg/reshape_fusion.mlir
+++ b/mlir/test/Dialect/Linalg/reshape_fusion.mlir
@@ -499,3 +499,43 @@ func.func @no_fuse_mismatched_dynamism(%arg0: tensor<2x1xi64>, %arg1: tensor<?x1xi64>)
 // CHECK-SAME: tensor<2x1xi64>, tensor<?x1xi64>)
 // CHECK: return %[[GENERIC]]
+
+// -----
+
+func.func @reshape_as_consumer_permutation_with_multiple_results
+  (%a : tensor<?x?x?xf32>, %b : tensor<?x?xf32>)
+    -> (tensor<?x2x?x3x4x?xf32>, tensor<?x?x2x3x4x?xf32>) {
+  %c:2 = linalg.generic {
+          indexing_maps = [affine_map<(d0, d1, d2) -> (d1, d0, d2)>,
+                           affine_map<(d0, d1, d2) -> (d1, d2)>,
+                           affine_map<(d0, d1, d2) -> (d0, d2, d1)>,
+                           affine_map<(d0, d1, d2) -> (d2, d0, d1)>],
+          iterator_types = ["parallel", "parallel", "parallel"]}
+         ins(%a, %b : tensor<?x?x?xf32>, tensor<?x?xf32>)
+        outs(%a, %a : tensor<?x?x?xf32>, tensor<?x?x?xf32>) {
+       ^bb0(%arg0 : f32, %arg1: f32, %s: f32, %t : f32):
+         %1 = arith.addf %arg0, %arg1 : f32
+         linalg.yield %1, %1 : f32, f32
+       } -> (tensor<?x?x?xf32>, tensor<?x?x?xf32>)
+  %d = tensor.expand_shape %c#0 [[0, 1], [2], [3, 4, 5]]
+      : tensor<?x?x?xf32> into tensor<?x2x?x3x4x?xf32>
+  %e = tensor.expand_shape %c#1 [[0], [1, 2], [3, 4, 5]]
+      : tensor<?x?x?xf32> into tensor<?x?x2x3x4x?xf32>
+  return %d, %e : tensor<?x2x?x3x4x?xf32>, tensor<?x?x2x3x4x?xf32>
+}
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d3, d4, d0, d1, d5)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d3, d4, d5)>
+// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d5, d2, d3, d4)>
+// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d5, d0, d1, d2, d3, d4)>
+// CHECK: func @reshape_as_consumer_permutation_with_multiple_results
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<?x?x?xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<?x?xf32>
+// CHECK-DAG: %[[RESHAPE0:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1, 2], [3, 4], [5]{{\]}}
+// CHECK-DAG: %[[RESHAPE1:.+]] = tensor.expand_shape %[[ARG1]] {{\[}}[0, 1, 2], [3]{{\]}}
+// CHECK-DAG: %[[RESHAPE2:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1], [2], [3, 4, 5]{{\]}}
+// CHECK-DAG: %[[RESHAPE3:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0], [1, 2], [3, 4, 5]{{\]}}
+// CHECK: %[[GENERIC:.+]]:2 = linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]], #[[MAP3]]]
+// CHECK-SAME: ins(%[[RESHAPE0]], %[[RESHAPE1]] :
+// CHECK-SAME: outs(%[[RESHAPE2]], %[[RESHAPE3]] :
+// CHECK: return %[[GENERIC]]#0, %[[GENERIC]]#1
diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgElementwiseFusion.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgElementwiseFusion.cpp
index 1b046b9..41e46d0 100644
--- a/mlir/test/lib/Dialect/Linalg/TestLinalgElementwiseFusion.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestLinalgElementwiseFusion.cpp
@@ -35,14 +35,18 @@ static void addOperands(Operation *op, SetVector<Value> &operandSet) {
 }
 
 template <int limit = 3>
-static bool setFusedOpOperandLimit(const OpResult &producer,
-                                   const OpOperand &consumer) {
+static bool setFusedOpOperandLimit(OpOperand *fusedOperand) {
+  Operation *producer = fusedOperand->get().getDefiningOp();
+  if (!producer)
+    return false;
+
+  Operation *consumer = fusedOperand->getOwner();
   SetVector<Value> fusedOpOperands;
-  if (producer.getOwner()->getNumResults() != 1)
+  if (producer->getNumResults() != 1)
     return false;
-  addOperands(consumer.getOwner(), fusedOpOperands);
-  fusedOpOperands.remove(producer);
-  addOperands(producer.getOwner(), fusedOpOperands);
+  addOperands(consumer, fusedOpOperands);
+  fusedOpOperands.remove(producer->getResult(0));
+  addOperands(producer, fusedOpOperands);
   return fusedOpOperands.size() <= limit;
 }
 
@@ -113,8 +117,7 @@ struct TestLinalgElementwiseFusion
     if (fuseWithReshapeByExpansion) {
       RewritePatternSet fusionPatterns(context);
       linalg::populateFoldReshapeOpsByExpansionPatterns(
-          fusionPatterns, [](const OpResult & /*producer*/,
-                             OpOperand & /*consumer*/) { return true; });
+          fusionPatterns, [](OpOperand * /*fusedOperand*/) { return true; });
       if (failed(applyPatternsAndFoldGreedily(funcOp.getBody(),
                                               std::move(fusionPatterns))))
         return signalPassFailure();
@@ -125,15 +128,19 @@ struct TestLinalgElementwiseFusion
       RewritePatternSet fusionPatterns(context);
 
       linalg::ControlFusionFn controlReshapeFusionFn =
-          [](const OpResult &producer, OpOperand &consumer) {
-            if (auto collapseOp =
-                    producer.getDefiningOp<tensor::CollapseShapeOp>()) {
+          [](OpOperand *fusedOperand) {
+            Operation *producer = fusedOperand->get().getDefiningOp();
+            if (!producer)
+              return false;
+
+            if (auto collapseOp = dyn_cast<tensor::CollapseShapeOp>(producer)) {
              if (!collapseOp.getSrc().getDefiningOp<linalg::LinalgOp>()) {
                 return false;
               }
             }
-            if (auto expandOp =
-                    dyn_cast<tensor::ExpandShapeOp>(consumer.getOwner())) {
+
+            Operation *consumer = fusedOperand->getOwner();
+            if (auto expandOp = dyn_cast<tensor::ExpandShapeOp>(consumer)) {
               if (expandOp->hasOneUse()) {
                 OpOperand &use = *expandOp->getUses().begin();
                 auto linalgOp = dyn_cast<linalg::LinalgOp>(use.getOwner());
@@ -155,18 +162,17 @@ struct TestLinalgElementwiseFusion
     if (fuseWithReshapeByCollapsing) {
       RewritePatternSet patterns(context);
       linalg::populateFoldReshapeOpsByCollapsingPatterns(
-          patterns, [](const OpResult & /*producer*/,
-                       OpOperand & /*consumer*/) { return true; });
+          patterns, [](OpOperand * /*fusedOperand */) { return true; });
       (void)applyPatternsAndFoldGreedily(funcOp.getBody(), std::move(patterns));
     }
 
     if (fuseWithReshapeByCollapsingWithControlFn) {
       RewritePatternSet patterns(context);
-      linalg::ControlFusionFn controlFn = [](const OpResult &producer,
-                                             OpOperand &consumer) -> bool {
-        if (isa<tensor::ExpandShapeOp>(producer.getDefiningOp())) {
+      linalg::ControlFusionFn controlFn = [](OpOperand *fusedOperand) -> bool {
+        Operation *producer = fusedOperand->get().getDefiningOp();
+        if (isa<tensor::ExpandShapeOp>(producer)) {
           // Skip fusing the first operand.
-          return consumer.getOperandNumber();
+          return fusedOperand->getOperandNumber();
         }
         return true;
       };
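
With the new single-argument `ControlFusionFn`, a control callback can recover
the producer, the consumer, and the specific producer result from the fused
operand alone. The sketch below (not part of the patch) shows how a downstream
pass might plug a custom control function into the elementwise fusion
patterns; `populateMyFusionPatterns` and its single-use heuristic are
hypothetical, while the `linalg::` entry points are the ones declared in
Transforms.h above.

    #include "mlir/Dialect/Linalg/Transforms/Transforms.h"

    using namespace mlir;

    static void populateMyFusionPatterns(RewritePatternSet &patterns) {
      linalg::ControlFusionFn controlFn = [](OpOperand *fusedOperand) {
        // The producer is whatever defines the operand's value; it is null for
        // block arguments, in which case there is nothing to fuse.
        Operation *producer = fusedOperand->get().getDefiningOp();
        if (!producer)
          return false;
        // The consumer owns the operand, and the OpResult identifies which of
        // the producer's results is consumed.
        Operation *consumer = fusedOperand->getOwner();
        unsigned resultNumber =
            fusedOperand->get().cast<OpResult>().getResultNumber();
        (void)consumer;
        (void)resultNumber;
        // Hypothetical heuristic: mirror the upstream default and fuse only
        // single-use producers.
        return producer->hasOneUse();
      };
      linalg::populateElementwiseOpsFusionPatterns(patterns, controlFn);
    }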
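
The exposed utility pair can also be driven directly, outside the pattern set.
A minimal sketch, assuming `rewriter` is a `RewriterBase` and `fusedOperand`
connects a `linalg.generic` producer to a `linalg.generic` consumer; as in the
rewritten `FuseElementwiseOps` pattern above, the fused op's trailing results
line up with the consumer's results, while the leading ones correspond to the
producer:

    if (linalg::areElementwiseOpsFusable(fusedOperand)) {
      FailureOr<Operation *> fusedOp =
          linalg::fuseElementwiseOps(rewriter, fusedOperand);
      if (succeeded(fusedOp)) {
        Operation *consumer = fusedOperand->getOwner();
        // Replace only the consumer here; replacing the producer's remaining
        // uses with the leading results is possible too, but must respect
        // use-def dominance, as the commit message cautions.
        rewriter.replaceOp(consumer,
                           fusedOp.getValue()->getResults().take_back(
                               consumer->getNumResults()));
      }
    }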