/// Populate patterns for splitting a `LinalgOp` with multiple statements within
/// its payload into multiple `GenericOp` that have a single statement.
-void populateDecomposeLinalgOpsPattern(RewritePatternSet &patterns);
+/// The option `removeDeadArgsAndResults` adds patterns to remove dead arguments
+/// and results from the generated decomposed ops. This defaults to `true` since
+/// the core decomposition patterns rely on these cleanup patterns. It is set
+/// to `false` only for testing purposes.
+void populateDecomposeLinalgOpsPattern(RewritePatternSet &patterns,
+ bool removeDeadArgsAndResults = true);
/// Populate patterns for vectorizing low-D convolution ops. This is a step in
/// progressive lowering for convolution ops, it assume high-D convolution ops
RewritePatternSet &patterns,
const ControlFusionFn &controlElementwiseOpFusion);
+/// Pattern to remove dead operands and results of `linalg.generic` operations.
+/// This is effectively DCE for a linalg op.
+void populateEraseUnusedOperandsAndResultsPatterns(RewritePatternSet &patterns);
+
/// Function type to control generic op dimension collapsing. It is expected
/// to return an array of `ReassociationIndices` representing dimensions that
/// should be merged.
getDpsInputOperands(), getDpsInitOperands());
}
-static bool isResultValueDead(linalg::GenericOp genericOp, OpResult result) {
- if (!result.use_empty())
- return false;
- // If out operand not used in payload, we can drop it.
- OpOperand *outputOpOperand =
- genericOp.getDpsInitOperand(result.getResultNumber());
- if (!genericOp.payloadUsesValueFromOperand(outputOpOperand))
- return true;
-
- // The out operand that is part of a payload can be dropped if
- // these conditions are met:
- // - Result from out operand is dead.
- // - User of arg is yield.
- // - outArg data is not being used by other outArgs.
-
- // Check block arg and cycle from out operand has a single use.
- BlockArgument outputArg =
- genericOp.getRegionOutputArgs()[result.getResultNumber()];
- if (!outputArg.hasOneUse())
- return false;
- Operation *argUserOp = *outputArg.user_begin();
-
- // Check argUser has no other use.
- if (!argUserOp->use_empty())
- return false;
-
- // Check that argUser is a yield.
- auto yieldOp = dyn_cast<linalg::YieldOp>(argUserOp);
- if (!yieldOp)
- return false;
-
- // Check outArg data is not being used by other outArgs.
- if (yieldOp.getOperand(result.getResultNumber()) != outputArg)
- return false;
-
- return true;
-}
-
LogicalResult GenericOp::verify() { return success(); }
namespace {
-struct DeduplicateAndRemoveDeadOperandsAndResults
- : public OpRewritePattern<GenericOp> {
- using OpRewritePattern<GenericOp>::OpRewritePattern;
-
- LogicalResult matchAndRewrite(GenericOp genericOp,
- PatternRewriter &rewriter) const override {
- // Create a map from argument position in the original op to the argument
- // position in the new op. If the argument is dropped it wont have an entry.
- SmallVector<OpOperand *> droppedOpOperands;
-
- // Information needed to build the new op.
- SmallVector<Value> newInputOperands, newOutputOperands;
- SmallVector<AffineMap> newIndexingMaps;
-
- // Gather information about duplicate input operands.
- llvm::SmallDenseMap<unsigned, unsigned> origInsToNewInsPos =
- deduplicateInputOperands(genericOp, droppedOpOperands, newInputOperands,
- newIndexingMaps);
-
- // Gather information about the dropped outputs.
- llvm::SmallDenseMap<unsigned, unsigned> origOutsToNewOutsPos =
- deduplicateOutputOperands(genericOp, droppedOpOperands,
- newOutputOperands, newIndexingMaps);
-
- // Check if there is any change to operands.
- if (newInputOperands.size() + newOutputOperands.size() ==
- genericOp->getNumOperands())
- return failure();
-
- // Create the new op with the body being empty.
- Location loc = genericOp.getLoc();
- SmallVector<Type> newResultTypes;
- for (Value v : newOutputOperands)
- if (v.getType().isa<TensorType>())
- newResultTypes.push_back(v.getType());
- auto newOp = rewriter.create<GenericOp>(
- loc, newResultTypes, newInputOperands, newOutputOperands,
- rewriter.getAffineMapArrayAttr(newIndexingMaps),
- genericOp.getIteratorTypes(), genericOp.getDocAttr(),
- genericOp.getLibraryCallAttr(),
- [](OpBuilder & /*builder*/, Location /*loc*/, ValueRange /*args*/) {
- return;
- });
- // Copy over unknown attributes. They might be load bearing for some flow.
- ArrayRef<StringRef> odsAttrs = genericOp.getAttributeNames();
- for (NamedAttribute kv : genericOp->getAttrs())
- if (!llvm::is_contained(odsAttrs, kv.getName().getValue()))
- newOp->setAttr(kv.getName(), kv.getValue());
-
- // Fix up the payload of the canonicalized operation.
- populateOpPayload(genericOp, newOp, origInsToNewInsPos,
- origOutsToNewOutsPos, rewriter);
-
- // Replace all live uses of the op.
- SmallVector<Value> replacementsVals(genericOp->getNumResults(), nullptr);
- for (const auto &result : llvm::enumerate(genericOp.getResults())) {
- auto it = origOutsToNewOutsPos.find(result.index());
- if (it == origOutsToNewOutsPos.end())
- continue;
- replacementsVals[result.index()] = newOp.getResult(it->second);
- }
- rewriter.replaceOp(genericOp, replacementsVals);
- return success();
- }
-
-private:
- // Deduplicate input operands, and return the
- // - Mapping from operand position in the original op, to operand position in
- // the canonicalized op.
- // - The preserved input operands list (by reference).
- llvm::SmallDenseMap<unsigned, unsigned>
- deduplicateInputOperands(GenericOp genericOp,
- SmallVector<OpOperand *> &droppedOpOperands,
- SmallVector<Value> &newInputOperands,
- SmallVector<AffineMap> &newIndexingMaps) const {
- llvm::SmallDenseMap<unsigned, unsigned> origToNewPos;
- llvm::SmallDenseMap<std::pair<Value, AffineMap>, unsigned> dedupedInputs;
- for (const auto &en : llvm::enumerate(genericOp.getDpsInputOperands())) {
- OpOperand *inputOpOperand = en.value();
- // Check if operand is dead and if dropping the indexing map makes the
- // loops to shape computation invalid.
- if (!genericOp.payloadUsesValueFromOperand(inputOpOperand)) {
- // Add the current operands to the list of potentially droppable
- // operands. If it cannot be dropped, this needs to be popped back.
- droppedOpOperands.push_back(inputOpOperand);
- if (genericOp.canOpOperandsBeDropped(droppedOpOperands))
- continue;
- droppedOpOperands.pop_back();
- }
-
- // Check if this operand is a duplicate.
- AffineMap indexingMap = genericOp.getMatchingIndexingMap(inputOpOperand);
- auto it = dedupedInputs.find(
- std::make_pair(inputOpOperand->get(), indexingMap));
- if (it != dedupedInputs.end()) {
- origToNewPos[en.index()] = it->second;
- droppedOpOperands.push_back(inputOpOperand);
- continue;
- }
-
- // This is a preserved argument.
- origToNewPos[en.index()] = newInputOperands.size();
- dedupedInputs[{inputOpOperand->get(), indexingMap}] =
- newInputOperands.size();
- newInputOperands.push_back(inputOpOperand->get());
- newIndexingMaps.push_back(indexingMap);
- }
- return origToNewPos;
- }
-
- // Deduplicate output operands, and return the
- // - Mapping from operand position in the original op, to operand position in
- // the canonicalized op.
- // - The preserved output operands list (by reference).
- llvm::SmallDenseMap<unsigned, unsigned>
- deduplicateOutputOperands(GenericOp genericOp,
- SmallVector<OpOperand *> &droppedOpOperands,
- SmallVector<Value> &newOutputOperands,
- SmallVector<AffineMap> &newIndexingMaps) const {
- llvm::SmallDenseMap<unsigned, unsigned> origToNewPos;
- llvm::SmallDenseMap<std::tuple<Value, AffineMap, Value>, unsigned>
- dedupedOutpts;
- // If the op doesnt have tensor semantics, keep all the outputs as
- // preserved.
- if (!genericOp.hasTensorSemantics()) {
- for (const auto &en : llvm::enumerate(genericOp.getDpsInitOperands())) {
- origToNewPos[en.index()] = newOutputOperands.size();
- newOutputOperands.push_back(en.value()->get());
- newIndexingMaps.push_back(genericOp.getMatchingIndexingMap(en.value()));
- }
- return origToNewPos;
- }
- // Output argument can be dropped if the result has
- // - no users, and
- // - it is not used in the payload, and
- // - the corresponding indexing maps are not needed for loop bound
- // computation.
- auto yieldOp = cast<YieldOp>(genericOp.getBody()->getTerminator());
- for (const auto &outputOpOperand :
- llvm::enumerate(genericOp.getDpsInitOperands())) {
- OpResult result = genericOp.getTiedOpResult(outputOpOperand.value());
- AffineMap indexingMap =
- genericOp.getMatchingIndexingMap(outputOpOperand.value());
- auto key = std::make_tuple(outputOpOperand.value()->get(), indexingMap,
- yieldOp->getOperand(outputOpOperand.index()));
- if (isResultValueDead(genericOp, result)) {
- // Check if the opoperand can be dropped without affecting loop
- // bound computation. Add the operand to the list of dropped op
- // operand for checking. If it cannot be dropped, need to pop the
- // value back.
- droppedOpOperands.push_back(outputOpOperand.value());
- if (genericOp.canOpOperandsBeDropped(droppedOpOperands)) {
- continue;
- }
- droppedOpOperands.pop_back();
- }
-
- if (!genericOp.payloadUsesValueFromOperand(outputOpOperand.value())) {
- // The out operand can also be dropped if it is computed redundantly
- // by another result, the conditions for that are
- // - The same operand is used as the out operand
- // - The same indexing map is used
- // - The same yield value is used.
- auto it = dedupedOutpts.find(key);
- if (it != dedupedOutpts.end()) {
- origToNewPos[outputOpOperand.index()] = it->second;
- droppedOpOperands.push_back(outputOpOperand.value());
- continue;
- }
- }
-
- origToNewPos[outputOpOperand.index()] = newOutputOperands.size();
- dedupedOutpts[key] = newOutputOperands.size();
- newOutputOperands.push_back(outputOpOperand.value()->get());
- newIndexingMaps.push_back(
- genericOp.getMatchingIndexingMap(outputOpOperand.value()));
- }
- return origToNewPos;
- }
-
- // Populate the body of the canonicalized operation.
- void populateOpPayload(
- GenericOp genericOp, GenericOp newOp,
- const llvm::SmallDenseMap<unsigned, unsigned> &origInsToNewInsPos,
- const llvm::SmallDenseMap<unsigned, unsigned> &origOutsToNewOutsPos,
- PatternRewriter &rewriter) const {
- // Merge the body of the original op with the new op.
- Block *newOpBlock = &newOp.getRegion().front();
- assert(newOpBlock->empty() && "expected new op to have an empty payload");
- Block *origOpBlock = &genericOp.getRegion().front();
- SmallVector<Value> replacements(origOpBlock->getNumArguments(), nullptr);
-
- // Replace all arguments in the original op, with arguments from the
- // canonicalized op.
- auto updateReplacements =
- [&](OpOperandVector &origOperands, OpOperandVector &newOperands,
- const llvm::SmallDenseMap<unsigned, unsigned> &map) {
- for (const auto &origOperand : llvm::enumerate(origOperands)) {
- auto it = map.find(origOperand.index());
- if (it == map.end())
- continue;
- OpOperand *newOperand = newOperands[it->second];
- replacements[origOperand.value()->getOperandNumber()] =
- newOpBlock->getArgument(newOperand->getOperandNumber());
- }
- };
-
- OpOperandVector origInputOperands = genericOp.getDpsInputOperands();
- OpOperandVector newInputOperands = newOp.getDpsInputOperands();
- updateReplacements(origInputOperands, newInputOperands, origInsToNewInsPos);
-
- OpOperandVector origOutputOperands = genericOp.getDpsInitOperands();
- OpOperandVector newOutputOperands = newOp.getDpsInitOperands();
- updateReplacements(origOutputOperands, newOutputOperands,
- origOutsToNewOutsPos);
-
- // Drop the unused yield args.
- if (newOp.getNumDpsInits() != genericOp.getNumDpsInits()) {
- OpBuilder::InsertionGuard g(rewriter);
- YieldOp origYieldOp = cast<YieldOp>(origOpBlock->getTerminator());
- rewriter.setInsertionPoint(origYieldOp);
-
- SmallVector<Value> newYieldVals(newOp.getNumDpsInits(), nullptr);
- for (const auto &yieldOpOperands :
- llvm::enumerate(origYieldOp.getValues())) {
- auto it = origOutsToNewOutsPos.find(yieldOpOperands.index());
- if (it == origOutsToNewOutsPos.end())
- continue;
- newYieldVals[it->second] = yieldOpOperands.value();
- }
- rewriter.replaceOpWithNewOp<YieldOp>(origYieldOp, newYieldVals);
- }
-
- rewriter.mergeBlocks(origOpBlock, newOpBlock, replacements);
- }
-};
-
/// Remove generic operations (on tensors) that are just copying
/// the values from inputs to the results. Requirements are
/// 1) All iterator types are parallel
}
};
-/// Remove unused cycles.
-/// We can remove unused cycle within a payload of generic region
-/// if these conditions are met:
-/// - Result from out operand is dead.
-/// - Block arg from out operand has a single use in the %cycle
-/// instruction.
-/// - Cycle has a single use and it is in yield.
-struct RemoveUnusedCycleInGenericOp : public OpRewritePattern<GenericOp> {
- using OpRewritePattern<GenericOp>::OpRewritePattern;
-
- LogicalResult matchAndRewrite(GenericOp genericOp,
- PatternRewriter &rewriter) const override {
-
- // If the op doesnt have tensor semantics, preserve the outputs as is.
- if (!genericOp.hasTensorSemantics())
- return failure();
-
- bool hasRemovedCycles = false;
- // Iterate over output operands and remove any unused cycles.
- for (const auto &outputOpOperand :
- llvm::enumerate(genericOp.getDpsInitOperands())) {
-
- // Check that result from out operand is dead.
- Value result = genericOp.getResult(outputOpOperand.index());
- if (!result.use_empty())
- continue;
-
- // Check that outputArg has one use in cycle.
- BlockArgument outputArg =
- genericOp.getRegionOutputArgs()[outputOpOperand.index()];
- if (!outputArg.hasOneUse())
- continue;
-
- // Check cycle has at most one use.
- Operation *cycleOp = *outputArg.user_begin();
- if (!cycleOp->hasOneUse())
- continue;
-
- // Check that the cycleUser is a yield.
- Operation *cycleUserOp = *cycleOp->user_begin();
- if (!isa<linalg::YieldOp>(cycleUserOp))
- continue;
-
- // Check that argIndex matches yieldIndex, else data is being used.
- if (cycleUserOp->getOperand(outputOpOperand.index()) !=
- cycleOp->getResult(0))
- continue;
-
- // Directly replace the cycle with the blockArg such that
- // Deduplicate pattern can eliminate it along with unused yield.
- rewriter.replaceOp(cycleOp, outputArg);
- rewriter.updateRootInPlace(genericOp, [] {});
- hasRemovedCycles = true;
- }
-
- if (hasRemovedCycles) {
- return success();
- }
-
- return failure();
- }
-};
} // namespace
void GenericOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
- results.add<DeduplicateAndRemoveDeadOperandsAndResults,
- EraseIdentityGenericOp, RemoveUnusedCycleInGenericOp>(context);
+ results.add<EraseIdentityGenericOp>(context);
}
LogicalResult GenericOp::fold(ArrayRef<Attribute>,
DropUnitDims.cpp
ElementwiseOpFusion.cpp
ElementwiseToLinalg.cpp
+ EraseUnusedOperandsAndResults.cpp
FusePadOpWithLinalgProducer.cpp
Fusion.cpp
FusionOnTensors.cpp
}
void mlir::linalg::populateDecomposeLinalgOpsPattern(
- RewritePatternSet &patterns) {
+ RewritePatternSet &patterns, bool removeDeadArgsAndResults) {
patterns.insert<DecomposeLinalgOp>(patterns.getContext());
+ // Add the patterns to clean up the dead operands and results.
+ if (removeDeadArgsAndResults)
+ populateEraseUnusedOperandsAndResultsPatterns(patterns);
}
patterns.add<FuseElementwiseOps>(context, controlElementwiseOpsFusion);
patterns.add<FoldFillWithGenericOp, FoldScalarOrSplatConstant,
RemoveOutsDependency>(context);
+ // Add the patterns that clean up dead operands and results.
+ populateEraseUnusedOperandsAndResultsPatterns(patterns);
}
void mlir::linalg::populateCollapseDimensions(
--- /dev/null
+//===- EraseUnusedOperandsAndResults.cpp ----------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+
+using namespace mlir;
+using namespace mlir::linalg;
+
+/// Return `true` if the `result` of an operation `genericOp` is dead, i.e.
+/// neither the result value nor the data flowing through the tied out
+/// operand is observable, so the operand/result pair can be dropped.
+static bool isResultValueDead(linalg::GenericOp genericOp, OpResult result) {
+ if (!result.use_empty())
+ return false;
+ // If out operand not used in payload, we can drop it.
+ OpOperand *outputOpOperand =
+ genericOp.getDpsInitOperand(result.getResultNumber());
+ if (!genericOp.payloadUsesValueFromOperand(outputOpOperand))
+ return true;
+
+ // The out operand that is part of a payload can be dropped if
+ // these conditions are met:
+ // - Result from out operand is dead.
+ // - User of arg is yield.
+ // - outArg data is not being used by other outArgs.
+
+ // Check block arg and cycle from out operand has a single use.
+ BlockArgument outputArg =
+ genericOp.getRegionOutputArgs()[result.getResultNumber()];
+ if (!outputArg.hasOneUse())
+ return false;
+ Operation *argUserOp = *outputArg.user_begin();
+
+ // Check argUser has no other use.
+ if (!argUserOp->use_empty())
+ return false;
+
+ // Check that argUser is a yield.
+ auto yieldOp = dyn_cast<linalg::YieldOp>(argUserOp);
+ if (!yieldOp)
+ return false;
+
+ // Check outArg data is not being used by other outArgs, i.e. the block
+ // argument is only yielded back into its own result position.
+ if (yieldOp.getOperand(result.getResultNumber()) != outputArg)
+ return false;
+
+ return true;
+}
+
+namespace {
+
+/// Pattern that canonicalizes a `linalg.generic` by deduplicating equivalent
+/// input operands and erasing dead output operands/results: a new generic op
+/// is built with the pruned operand list and the payload is remapped onto it.
+struct DeduplicateAndRemoveDeadOperandsAndResults
+ : public OpRewritePattern<GenericOp> {
+ using OpRewritePattern<GenericOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(GenericOp genericOp,
+ PatternRewriter &rewriter) const override {
+ // Create a map from argument position in the original op to the argument
+ // position in the new op. If the argument is dropped it won't have an entry.
+ SmallVector<OpOperand *> droppedOpOperands;
+
+ // Information needed to build the new op.
+ SmallVector<Value> newInputOperands, newOutputOperands;
+ SmallVector<AffineMap> newIndexingMaps;
+
+ // Gather information about duplicate input operands.
+ llvm::SmallDenseMap<unsigned, unsigned> origInsToNewInsPos =
+ deduplicateInputOperands(genericOp, droppedOpOperands, newInputOperands,
+ newIndexingMaps);
+
+ // Gather information about the dropped outputs.
+ llvm::SmallDenseMap<unsigned, unsigned> origOutsToNewOutsPos =
+ deduplicateOutputOperands(genericOp, droppedOpOperands,
+ newOutputOperands, newIndexingMaps);
+
+ // Check if there is any change to operands; if nothing was dropped or
+ // deduplicated the pattern does not apply.
+ if (newInputOperands.size() + newOutputOperands.size() ==
+ genericOp->getNumOperands())
+ return failure();
+
+ // Create the new op with the body being empty.
+ Location loc = genericOp.getLoc();
+ SmallVector<Type> newResultTypes;
+ for (Value v : newOutputOperands)
+ if (v.getType().isa<TensorType>())
+ newResultTypes.push_back(v.getType());
+ auto newOp = rewriter.create<GenericOp>(
+ loc, newResultTypes, newInputOperands, newOutputOperands,
+ rewriter.getAffineMapArrayAttr(newIndexingMaps),
+ genericOp.getIteratorTypes(), genericOp.getDocAttr(),
+ genericOp.getLibraryCallAttr(),
+ [](OpBuilder & /*builder*/, Location /*loc*/, ValueRange /*args*/) {
+ return;
+ });
+ // Copy over unknown attributes. They might be load bearing for some flow.
+ ArrayRef<StringRef> odsAttrs = genericOp.getAttributeNames();
+ for (NamedAttribute kv : genericOp->getAttrs())
+ if (!llvm::is_contained(odsAttrs, kv.getName().getValue()))
+ newOp->setAttr(kv.getName(), kv.getValue());
+
+ // Fix up the payload of the canonicalized operation.
+ populateOpPayload(genericOp, newOp, origInsToNewInsPos,
+ origOutsToNewOutsPos, rewriter);
+
+ // Replace all live uses of the op. Dropped results keep a null
+ // replacement, which `replaceOp` interprets as "no uses to replace".
+ SmallVector<Value> replacementsVals(genericOp->getNumResults(), nullptr);
+ for (const auto &result : llvm::enumerate(genericOp.getResults())) {
+ auto it = origOutsToNewOutsPos.find(result.index());
+ if (it == origOutsToNewOutsPos.end())
+ continue;
+ replacementsVals[result.index()] = newOp.getResult(it->second);
+ }
+ rewriter.replaceOp(genericOp, replacementsVals);
+ return success();
+ }
+
+private:
+ // Deduplicate input operands, and return the
+ // - Mapping from operand position in the original op, to operand position in
+ // the canonicalized op.
+ // - The preserved input operands list (by reference).
+ llvm::SmallDenseMap<unsigned, unsigned>
+ deduplicateInputOperands(GenericOp genericOp,
+ SmallVector<OpOperand *> &droppedOpOperands,
+ SmallVector<Value> &newInputOperands,
+ SmallVector<AffineMap> &newIndexingMaps) const {
+ llvm::SmallDenseMap<unsigned, unsigned> origToNewPos;
+ llvm::SmallDenseMap<std::pair<Value, AffineMap>, unsigned> dedupedInputs;
+ for (const auto &en : llvm::enumerate(genericOp.getDpsInputOperands())) {
+ OpOperand *inputOpOperand = en.value();
+ // Check if operand is dead and if dropping the indexing map makes the
+ // loops to shape computation invalid.
+ if (!genericOp.payloadUsesValueFromOperand(inputOpOperand)) {
+ // Add the current operands to the list of potentially droppable
+ // operands. If it cannot be dropped, this needs to be popped back.
+ droppedOpOperands.push_back(inputOpOperand);
+ if (genericOp.canOpOperandsBeDropped(droppedOpOperands))
+ continue;
+ droppedOpOperands.pop_back();
+ }
+
+ // Check if this operand is a duplicate, i.e. same value accessed with
+ // the same indexing map as an already-preserved input.
+ AffineMap indexingMap = genericOp.getMatchingIndexingMap(inputOpOperand);
+ auto it = dedupedInputs.find(
+ std::make_pair(inputOpOperand->get(), indexingMap));
+ if (it != dedupedInputs.end()) {
+ origToNewPos[en.index()] = it->second;
+ droppedOpOperands.push_back(inputOpOperand);
+ continue;
+ }
+
+ // This is a preserved argument.
+ origToNewPos[en.index()] = newInputOperands.size();
+ dedupedInputs[{inputOpOperand->get(), indexingMap}] =
+ newInputOperands.size();
+ newInputOperands.push_back(inputOpOperand->get());
+ newIndexingMaps.push_back(indexingMap);
+ }
+ return origToNewPos;
+ }
+
+ // Deduplicate output operands, and return the
+ // - Mapping from operand position in the original op, to operand position in
+ // the canonicalized op.
+ // - The preserved output operands list (by reference).
+ llvm::SmallDenseMap<unsigned, unsigned>
+ deduplicateOutputOperands(GenericOp genericOp,
+ SmallVector<OpOperand *> &droppedOpOperands,
+ SmallVector<Value> &newOutputOperands,
+ SmallVector<AffineMap> &newIndexingMaps) const {
+ llvm::SmallDenseMap<unsigned, unsigned> origToNewPos;
+ llvm::SmallDenseMap<std::tuple<Value, AffineMap, Value>, unsigned>
+ dedupedOutpts;
+ // If the op doesn't have tensor semantics, keep all the outputs as
+ // preserved.
+ if (!genericOp.hasTensorSemantics()) {
+ for (const auto &en : llvm::enumerate(genericOp.getDpsInitOperands())) {
+ origToNewPos[en.index()] = newOutputOperands.size();
+ newOutputOperands.push_back(en.value()->get());
+ newIndexingMaps.push_back(genericOp.getMatchingIndexingMap(en.value()));
+ }
+ return origToNewPos;
+ }
+ // Output argument can be dropped if the result has
+ // - no users, and
+ // - it is not used in the payload, and
+ // - the corresponding indexing maps are not needed for loop bound
+ // computation.
+ auto yieldOp = cast<YieldOp>(genericOp.getBody()->getTerminator());
+ for (const auto &outputOpOperand :
+ llvm::enumerate(genericOp.getDpsInitOperands())) {
+ OpResult result = genericOp.getTiedOpResult(outputOpOperand.value());
+ AffineMap indexingMap =
+ genericOp.getMatchingIndexingMap(outputOpOperand.value());
+ auto key = std::make_tuple(outputOpOperand.value()->get(), indexingMap,
+ yieldOp->getOperand(outputOpOperand.index()));
+ if (isResultValueDead(genericOp, result)) {
+ // Check if the op operand can be dropped without affecting loop
+ // bound computation. Add the operand to the list of dropped op
+ // operand for checking. If it cannot be dropped, need to pop the
+ // value back.
+ droppedOpOperands.push_back(outputOpOperand.value());
+ if (genericOp.canOpOperandsBeDropped(droppedOpOperands)) {
+ continue;
+ }
+ droppedOpOperands.pop_back();
+ }
+
+ if (!genericOp.payloadUsesValueFromOperand(outputOpOperand.value())) {
+ // The out operand can also be dropped if it is computed redundantly
+ // by another result, the conditions for that are
+ // - The same operand is used as the out operand
+ // - The same indexing map is used
+ // - The same yield value is used.
+ auto it = dedupedOutpts.find(key);
+ if (it != dedupedOutpts.end()) {
+ origToNewPos[outputOpOperand.index()] = it->second;
+ droppedOpOperands.push_back(outputOpOperand.value());
+ continue;
+ }
+ }
+
+ origToNewPos[outputOpOperand.index()] = newOutputOperands.size();
+ dedupedOutpts[key] = newOutputOperands.size();
+ newOutputOperands.push_back(outputOpOperand.value()->get());
+ newIndexingMaps.push_back(
+ genericOp.getMatchingIndexingMap(outputOpOperand.value()));
+ }
+ return origToNewPos;
+ }
+
+ // Populate the body of the canonicalized operation by merging the original
+ // payload into the (empty) new op, remapping block arguments and trimming
+ // the yield to the preserved outputs.
+ void populateOpPayload(
+ GenericOp genericOp, GenericOp newOp,
+ const llvm::SmallDenseMap<unsigned, unsigned> &origInsToNewInsPos,
+ const llvm::SmallDenseMap<unsigned, unsigned> &origOutsToNewOutsPos,
+ PatternRewriter &rewriter) const {
+ // Merge the body of the original op with the new op.
+ Block *newOpBlock = &newOp.getRegion().front();
+ assert(newOpBlock->empty() && "expected new op to have an empty payload");
+ Block *origOpBlock = &genericOp.getRegion().front();
+ SmallVector<Value> replacements(origOpBlock->getNumArguments(), nullptr);
+
+ // Replace all arguments in the original op, with arguments from the
+ // canonicalized op.
+ auto updateReplacements =
+ [&](OpOperandVector &origOperands, OpOperandVector &newOperands,
+ const llvm::SmallDenseMap<unsigned, unsigned> &map) {
+ for (const auto &origOperand : llvm::enumerate(origOperands)) {
+ auto it = map.find(origOperand.index());
+ if (it == map.end())
+ continue;
+ OpOperand *newOperand = newOperands[it->second];
+ replacements[origOperand.value()->getOperandNumber()] =
+ newOpBlock->getArgument(newOperand->getOperandNumber());
+ }
+ };
+
+ OpOperandVector origInputOperands = genericOp.getDpsInputOperands();
+ OpOperandVector newInputOperands = newOp.getDpsInputOperands();
+ updateReplacements(origInputOperands, newInputOperands, origInsToNewInsPos);
+
+ OpOperandVector origOutputOperands = genericOp.getDpsInitOperands();
+ OpOperandVector newOutputOperands = newOp.getDpsInitOperands();
+ updateReplacements(origOutputOperands, newOutputOperands,
+ origOutsToNewOutsPos);
+
+ // Drop the unused yield args.
+ if (newOp.getNumDpsInits() != genericOp.getNumDpsInits()) {
+ OpBuilder::InsertionGuard g(rewriter);
+ YieldOp origYieldOp = cast<YieldOp>(origOpBlock->getTerminator());
+ rewriter.setInsertionPoint(origYieldOp);
+
+ SmallVector<Value> newYieldVals(newOp.getNumDpsInits(), nullptr);
+ for (const auto &yieldOpOperands :
+ llvm::enumerate(origYieldOp.getValues())) {
+ auto it = origOutsToNewOutsPos.find(yieldOpOperands.index());
+ if (it == origOutsToNewOutsPos.end())
+ continue;
+ newYieldVals[it->second] = yieldOpOperands.value();
+ }
+ rewriter.replaceOpWithNewOp<YieldOp>(origYieldOp, newYieldVals);
+ }
+
+ rewriter.mergeBlocks(origOpBlock, newOpBlock, replacements);
+ }
+};
+
+/// Remove unused cycles.
+/// We can remove unused cycle within a payload of generic region
+/// if these conditions are met:
+/// - Result from out operand is dead.
+/// - Block arg from out operand has a single use in the %cycle
+/// instruction.
+/// - Cycle has a single use and it is in yield.
+struct RemoveUnusedCycleInGenericOp : public OpRewritePattern<GenericOp> {
+ using OpRewritePattern<GenericOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(GenericOp genericOp,
+ PatternRewriter &rewriter) const override {
+
+ // If the op doesn't have tensor semantics, preserve the outputs as is.
+ if (!genericOp.hasTensorSemantics())
+ return failure();
+
+ bool hasRemovedCycles = false;
+ // Iterate over output operands and remove any unused cycles.
+ for (const auto &outputOpOperand :
+ llvm::enumerate(genericOp.getDpsInitOperands())) {
+
+ // Check that result from out operand is dead.
+ Value result = genericOp.getResult(outputOpOperand.index());
+ if (!result.use_empty())
+ continue;
+
+ // Check that outputArg has one use in cycle.
+ BlockArgument outputArg =
+ genericOp.getRegionOutputArgs()[outputOpOperand.index()];
+ if (!outputArg.hasOneUse())
+ continue;
+
+ // Check cycle has at most one use.
+ Operation *cycleOp = *outputArg.user_begin();
+ if (!cycleOp->hasOneUse())
+ continue;
+
+ // Check that the cycleUser is a yield.
+ Operation *cycleUserOp = *cycleOp->user_begin();
+ if (!isa<linalg::YieldOp>(cycleUserOp))
+ continue;
+
+ // Check that argIndex matches yieldIndex, else data is being used.
+ if (cycleUserOp->getOperand(outputOpOperand.index()) !=
+ cycleOp->getResult(0))
+ continue;
+
+ // Directly replace the cycle with the blockArg such that
+ // Deduplicate pattern can eliminate it along with unused yield.
+ rewriter.replaceOp(cycleOp, outputArg);
+ // Notify the rewriter that the op changed in place.
+ rewriter.updateRootInPlace(genericOp, [] {});
+ hasRemovedCycles = true;
+ }
+
+ if (hasRemovedCycles) {
+ return success();
+ }
+
+ return failure();
+ }
+};
+} // namespace
+
+/// Populate `patterns` with the dedup/DCE patterns defined above so that
+/// other transforms can clean up dead operands and results of generic ops.
+void mlir::linalg::populateEraseUnusedOperandsAndResultsPatterns(
+ RewritePatternSet &patterns) {
+ patterns.insert<DeduplicateAndRemoveDeadOperandsAndResults,
+ RemoveUnusedCycleInGenericOp>(patterns.getContext());
+}
}
// -----
-
-// CHECK-LABEL: func @remove_deadargs_generic_basic
-// CHECK-SAME: (%[[ARG0:.*]]: tensor<?xf32>) -> tensor<?xf32> {
-// CHECK: %[[GENERIC_OP:.*]] = linalg.generic
-// CHECK-SAME: ins(%[[ARG0]] : tensor<?xf32>)
-// CHECK-SAME: outs({{.*}} : tensor<?xf32>) {
-#map0 = affine_map<(d0) -> (d0)>
-func.func @remove_deadargs_generic_basic(%arg0: tensor<?xf32>) -> (tensor<?xf32>) {
- %c0 = arith.constant 0 : index
- %cst = arith.constant 7.0 : f32
- %0 = tensor.dim %arg0, %c0 : tensor<?xf32>
- %1 = tensor.empty(%0) : tensor<?xf32>
- %2 = tensor.empty(%0) : tensor<?xf32>
- %3 = linalg.generic {indexing_maps = [#map0, #map0, #map0], iterator_types=["parallel"]} ins(%arg0, %1 : tensor<?xf32>, tensor<?xf32>) outs (%2:tensor<?xf32>) {
- ^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
- %4 = arith.addf %arg1, %cst : f32
- linalg.yield %4 : f32
- } -> tensor<?xf32>
- return %3 : tensor<?xf32>
-}
-
-// -----
-
-// CHECK-LABEL: func @remove_deadargs_generic_mixedaccess
-// CHECK: %[[GENERIC_OP:.*]] = linalg.generic
-// CHECK-NOT: ins
-// CHECK-SAME: outs({{.*}} : tensor<?x?xf32>) {
-#map0 = affine_map<(d0, d1) -> (d0, d1)>
-#map1 = affine_map<(d0, d1) -> (d1, d0)>
-func.func @remove_deadargs_generic_mixedaccess(%arg0: tensor<?x?xf32>) -> (tensor<?x?xf32>) {
- %c0 = arith.constant 0 : index
- %c1 = arith.constant 0 : index
- %cst1 = arith.constant 7.0 : f32
- %cst2 = arith.constant 6.0 : f32
- %0 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
- %1 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
- %2 = tensor.empty(%0, %1) : tensor<?x?xf32>
- %3 = tensor.empty(%1, %0) : tensor<?x?xf32>
- %4 = tensor.empty(%0, %1) : tensor<?x?xf32>
- %5 = linalg.generic {indexing_maps = [#map0, #map1, #map0], iterator_types=["parallel","parallel"]} ins(%2, %3 : tensor<?x?xf32>, tensor<?x?xf32>) outs (%4:tensor<?x?xf32>) {
- ^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
- %6 = arith.divf %cst1, %cst2 : f32
- linalg.yield %6 : f32
- } -> tensor<?x?xf32>
- return %5 : tensor<?x?xf32>
-}
-
-// -----
// CHECK-LABEL: func @fold_fill_reshape()
func.func @fold_fill_reshape() -> tensor<6x4xf32> {
%zero = arith.constant 0.0 : f32
// RUN: mlir-opt -test-linalg-decompose-ops -cse -split-input-file %s | FileCheck %s
-// RUN: mlir-opt -test-linalg-decompose-ops -cse -canonicalize -split-input-file %s | FileCheck %s --check-prefix=CANONICALIZECHECK
+// RUN: mlir-opt -test-linalg-decompose-ops=remove-dead-args-and-results -cse -split-input-file %s | FileCheck %s --check-prefix=CANONICALIZECHECK
func.func @simple_op(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?xf32>, %arg2 : tensor<?xf32>)
-> (tensor<?x?xf32>, tensor<?x?xf32>) {
-// RUN: mlir-opt %s -split-input-file -canonicalize | FileCheck %s
+// RUN: mlir-opt %s -split-input-file -test-linalg-transform-patterns=test-erase-unused-operands-and-results | FileCheck %s
+
+// The payload block only reads %arg1 (fed by %arg0); %arg2 (fed by the empty
+// tensor %1) is dead, so the pattern should erase that input operand while
+// keeping the used one -- checked by the single-operand `ins` below.
+// CHECK-LABEL: func @remove_deadargs_generic_basic
+// CHECK-SAME: (%[[ARG0:.*]]: tensor<?xf32>) -> tensor<?xf32> {
+// CHECK: %[[GENERIC_OP:.*]] = linalg.generic
+// CHECK-SAME: ins(%[[ARG0]] : tensor<?xf32>)
+// CHECK-SAME: outs({{.*}} : tensor<?xf32>) {
+#map0 = affine_map<(d0) -> (d0)>
+func.func @remove_deadargs_generic_basic(%arg0: tensor<?xf32>) -> (tensor<?xf32>) {
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant 7.0 : f32
+  %0 = tensor.dim %arg0, %c0 : tensor<?xf32>
+  %1 = tensor.empty(%0) : tensor<?xf32>
+  %2 = tensor.empty(%0) : tensor<?xf32>
+  %3 = linalg.generic {indexing_maps = [#map0, #map0, #map0], iterator_types=["parallel"]} ins(%arg0, %1 : tensor<?xf32>, tensor<?xf32>) outs (%2:tensor<?xf32>) {
+  ^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
+    %4 = arith.addf %arg1, %cst : f32
+    linalg.yield %4 : f32
+  } -> tensor<?xf32>
+  return %3 : tensor<?xf32>
+}
+
+// -----
+
+// The yielded value is computed purely from the constants %cst1/%cst2, so
+// neither input operand is used by the payload; both should be erased
+// entirely -- checked by CHECK-NOT on `ins`.
+// CHECK-LABEL: func @remove_deadargs_generic_mixedaccess
+// CHECK: %[[GENERIC_OP:.*]] = linalg.generic
+// CHECK-NOT: ins
+// CHECK-SAME: outs({{.*}} : tensor<?x?xf32>) {
+#map0 = affine_map<(d0, d1) -> (d0, d1)>
+#map1 = affine_map<(d0, d1) -> (d1, d0)>
+func.func @remove_deadargs_generic_mixedaccess(%arg0: tensor<?x?xf32>) -> (tensor<?x?xf32>) {
+  %c0 = arith.constant 0 : index
+  // Use 1 (not 0) so %c1 matches its name and the two tensor.dim ops below
+  // query distinct dimensions of %arg0.
+  %c1 = arith.constant 1 : index
+  %cst1 = arith.constant 7.0 : f32
+  %cst2 = arith.constant 6.0 : f32
+  %0 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
+  %1 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
+  %2 = tensor.empty(%0, %1) : tensor<?x?xf32>
+  %3 = tensor.empty(%1, %0) : tensor<?x?xf32>
+  %4 = tensor.empty(%0, %1) : tensor<?x?xf32>
+  %5 = linalg.generic {indexing_maps = [#map0, #map1, #map0], iterator_types=["parallel","parallel"]} ins(%2, %3 : tensor<?x?xf32>, tensor<?x?xf32>) outs (%4:tensor<?x?xf32>) {
+  ^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
+    %6 = arith.divf %cst1, %cst2 : f32
+    linalg.yield %6 : f32
+  } -> tensor<?x?xf32>
+  return %5 : tensor<?x?xf32>
+}
+
+// -----
// Test case: Most basic case. Adding a vector to itself.
// RUN: mlir-opt %s -test-linalg-elementwise-fusion-patterns=fuse-generic-ops -split-input-file | FileCheck %s
-// RUN: mlir-opt %s -test-linalg-elementwise-fusion-patterns=fuse-generic-ops -split-input-file -canonicalize | FileCheck %s --check-prefix=CANONICALIZE
#map0 = affine_map<(d0, d1) -> (d0, d1)>
#binary2Dpointwise = {
} -> tensor<?x?xf32>
return %6 : tensor<?x?xf32>
}
+
// CHECK-LABEL: func @test_fusion_limit
// CHECK-SAME: %[[ARG0:[a-zA-z0-9_]+]]: tensor<?x?xf32>
// CHECK-SAME: %[[ARG1:[a-zA-z0-9_]+]]: tensor<?x?xf32>
// CHECK-SAME: %[[ARG5:[a-zA-z0-9_]+]]: tensor<?x?xf32>
// CHECK: %[[OP1:.+]] = linalg.generic {{.+}} ins(%[[ARG2]], %[[ARG3]]
// CHECK: %[[OP2:.+]] = linalg.generic {{.+}} ins(%[[ARG4]], %[[ARG5]]
-// CHECK: %[[OP3:.+]]:2 = linalg.generic {{.+}} ins(%[[ARG0]], %[[ARG1]], %[[OP1]], %[[OP2]]
-// CHECK: return %[[OP3]]#1
-
-// CANONICALIZE-LABEL: func @test_fusion_limit
-// CANONICALIZE-SAME: %[[ARG0:[a-zA-z0-9_]+]]: tensor<?x?xf32>
-// CANONICALIZE-SAME: %[[ARG1:[a-zA-z0-9_]+]]: tensor<?x?xf32>
-// CANONICALIZE-SAME: %[[ARG2:[a-zA-z0-9_]+]]: tensor<?x?xf32>
-// CANONICALIZE-SAME: %[[ARG3:[a-zA-z0-9_]+]]: tensor<?x?xf32>
-// CANONICALIZE-SAME: %[[ARG4:[a-zA-z0-9_]+]]: tensor<?x?xf32>
-// CANONICALIZE-SAME: %[[ARG5:[a-zA-z0-9_]+]]: tensor<?x?xf32>
-// CANONICALIZE: %[[OP1:.+]] = linalg.generic {{.+}} ins(%[[ARG2]], %[[ARG3]]
-// CANONICALIZE: %[[OP2:.+]] = linalg.generic {{.+}} ins(%[[ARG4]], %[[ARG5]]
-// CANONICALIZE: %[[OP3:.+]] = linalg.generic {{.+}} ins(%[[ARG0]], %[[ARG1]], %[[OP1]], %[[OP2]]
-// CANONICALIZE: return %[[OP3]]
+// CHECK: %[[OP3:.+]] = linalg.generic {{.+}} ins(%[[ARG0]], %[[ARG1]], %[[OP1]], %[[OP2]]
+// CHECK: return %[[OP3]]
: public PassWrapper<TestLinalgDecomposeOps, OperationPass<>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestLinalgDecomposeOps)
- TestLinalgDecomposeOps() = default;
- TestLinalgDecomposeOps(const TestLinalgDecomposeOps &pass) = default;
+  TestLinalgDecomposeOps() = default;
+  // Copy constructor is user-provided (presumably because the `Option`
+  // members below are not copyable -- TODO confirm); copying deliberately
+  // leaves the options default-initialized.
+  TestLinalgDecomposeOps(const TestLinalgDecomposeOps &pass) {}
void getDependentDialects(DialectRegistry ®istry) const override {
registry.insert<AffineDialect, linalg::LinalgDialect>();
}
return "Test Linalg decomposition patterns";
}
+  // Forwarded to populateDecomposeLinalgOpsPattern; when set, the dead
+  // argument/result cleanup patterns are added alongside the decomposition
+  // patterns. (Description was previously copy-pasted from the
+  // erase-unused-operands-and-results test option.)
+  Option<bool> removeDeadArgsAndResults{
+      *this, "remove-dead-args-and-results",
+      llvm::cl::desc(
+          "Remove dead arguments and results of the decomposed operations"),
+      llvm::cl::init(false)};
+
void runOnOperation() override {
MLIRContext *context = &this->getContext();
RewritePatternSet decompositionPatterns(context);
- linalg::populateDecomposeLinalgOpsPattern(decompositionPatterns);
+ linalg::populateDecomposeLinalgOpsPattern(decompositionPatterns,
+ removeDeadArgsAndResults);
if (failed(applyPatternsAndFoldGreedily(
getOperation(), std::move(decompositionPatterns)))) {
return signalPassFailure();
llvm::cl::desc(
"Test patterns to swap tensor.extract_slice(linalg.fill())"),
llvm::cl::init(false)};
+ Option<bool> testEraseUnusedOperandsAndResults{
+ *this, "test-erase-unused-operands-and-results",
+ llvm::cl::desc("Test patterns to erase unused operands and results"),
+ llvm::cl::init(false)};
};
} // namespace
(void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
}
+/// Greedily applies the patterns that erase unused operands and results of
+/// linalg ops to `funcOp`. The result of pattern application is deliberately
+/// ignored (best-effort), matching the other apply* helpers in this file.
+static void applyEraseUnusedOperandsAndResultsPatterns(func::FuncOp funcOp) {
+  RewritePatternSet patterns(funcOp.getContext());
+  populateEraseUnusedOperandsAndResultsPatterns(patterns);
+  (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
+}
+
/// Apply transformations specified as patterns.
void TestLinalgTransforms::runOnOperation() {
if (testPatterns)
return applyBubbleUpExtractSliceOpPattern(getOperation());
if (testSwapExtractSliceWithFill)
return applySwapExtractSliceWithFillPattern(getOperation());
+ if (testEraseUnusedOperandsAndResults)
+ return applyEraseUnusedOperandsAndResultsPatterns(getOperation());
}
namespace mlir {