[mlir][Linalg] Move patterns to remove dead arguments and results out of canonicaliza...
authorMahesh Ravishankar <ravishankarm@google.com>
Wed, 16 Nov 2022 02:51:53 +0000 (02:51 +0000)
committerMahesh Ravishankar <ravishankarm@google.com>
Wed, 16 Nov 2022 16:00:43 +0000 (16:00 +0000)
The patterns to remove dead arguments and results of `linalg.generic`
operations are not necessarily canonicalizations. Instead a new entry
point `populateEraseUnusedOperandsAndResults` is added to allow using
these patterns when needed. The transformations that rely on this
pattern for cleanup now include these patterns explicitly.

Differential Revision: https://reviews.llvm.org/D138085

12 files changed:
mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
mlir/lib/Dialect/Linalg/Transforms/DecomposeLinalgOps.cpp
mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
mlir/lib/Dialect/Linalg/Transforms/EraseUnusedOperandsAndResults.cpp [new file with mode: 0644]
mlir/test/Dialect/Linalg/canonicalize.mlir
mlir/test/Dialect/Linalg/decompose-ops.mlir
mlir/test/Dialect/Linalg/erase-unused-operands-and-results.mlir [moved from mlir/test/Dialect/Linalg/canonicalize-duplicate-inputs.mlir with 90% similarity]
mlir/test/Dialect/Linalg/fusion-elementwise-options.mlir
mlir/test/lib/Dialect/Linalg/TestLinalgDecomposeOps.cpp
mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp

index dcdc532..40bf31c 100644 (file)
@@ -48,7 +48,12 @@ void populatePadTensorTilingPatterns(RewritePatternSet &patterns,
 
 /// Populate patterns for splitting a `LinalgOp` with multiple statements within
 /// its payload into multiple `GenericOp` that have a single statement.
-void populateDecomposeLinalgOpsPattern(RewritePatternSet &patterns);
+/// The option `removeDeadArgsAndResults` adds patterns to remove dead arguments
+/// and results from the generated decomposed ops. This is default `true` since
+/// the core decomposition patterns relies on these clean up patterns. It is set
+/// to false only for testing purposes.
+void populateDecomposeLinalgOpsPattern(RewritePatternSet &patterns,
+                                       bool removeDeadArgsAndResults = true);
 
 /// Populate patterns for vectorizing low-D convolution ops. This is a step in
 /// progressive lowering for convolution ops, it assume high-D convolution ops
@@ -76,6 +81,10 @@ void populateElementwiseOpsFusionPatterns(
     RewritePatternSet &patterns,
     const ControlFusionFn &controlElementwiseOpFusion);
 
+/// Pattern to remove dead operands and results of `linalg.generic` operations.
+/// This is effectively DCE for a linalg op.
+void populateEraseUnusedOperandsAndResultsPatterns(RewritePatternSet &patterns);
+
 /// Function type to control generic op dimension collapsing. It is expected
 /// to return an array of `ReassociationIndices` representing dimensions that
 /// should be merged.
index 63e2ef3..ec1c603 100644 (file)
@@ -871,285 +871,10 @@ void GenericOp::getEffects(
                         getDpsInputOperands(), getDpsInitOperands());
 }
 
-static bool isResultValueDead(linalg::GenericOp genericOp, OpResult result) {
-  if (!result.use_empty())
-    return false;
-  // If out operand not used in payload, we can drop it.
-  OpOperand *outputOpOperand =
-      genericOp.getDpsInitOperand(result.getResultNumber());
-  if (!genericOp.payloadUsesValueFromOperand(outputOpOperand))
-    return true;
-
-  // The out operand that is part of a payload can be dropped if
-  // these conditions are met:
-  // - Result from out operand is dead.
-  // - User of arg is yield.
-  // - outArg data is not being used by other outArgs.
-
-  // Check block arg and cycle from out operand has a single use.
-  BlockArgument outputArg =
-      genericOp.getRegionOutputArgs()[result.getResultNumber()];
-  if (!outputArg.hasOneUse())
-    return false;
-  Operation *argUserOp = *outputArg.user_begin();
-
-  // Check argUser has no other use.
-  if (!argUserOp->use_empty())
-    return false;
-
-  // Check that argUser is a yield.
-  auto yieldOp = dyn_cast<linalg::YieldOp>(argUserOp);
-  if (!yieldOp)
-    return false;
-
-  // Check outArg data is not being used by other outArgs.
-  if (yieldOp.getOperand(result.getResultNumber()) != outputArg)
-    return false;
-
-  return true;
-}
-
 LogicalResult GenericOp::verify() { return success(); }
 
 namespace {
 
-struct DeduplicateAndRemoveDeadOperandsAndResults
-    : public OpRewritePattern<GenericOp> {
-  using OpRewritePattern<GenericOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(GenericOp genericOp,
-                                PatternRewriter &rewriter) const override {
-    // Create a map from argument position in the original op to the argument
-    // position in the new op. If the argument is dropped it wont have an entry.
-    SmallVector<OpOperand *> droppedOpOperands;
-
-    // Information needed to build the new op.
-    SmallVector<Value> newInputOperands, newOutputOperands;
-    SmallVector<AffineMap> newIndexingMaps;
-
-    // Gather information about duplicate input operands.
-    llvm::SmallDenseMap<unsigned, unsigned> origInsToNewInsPos =
-        deduplicateInputOperands(genericOp, droppedOpOperands, newInputOperands,
-                                 newIndexingMaps);
-
-    // Gather information about the dropped outputs.
-    llvm::SmallDenseMap<unsigned, unsigned> origOutsToNewOutsPos =
-        deduplicateOutputOperands(genericOp, droppedOpOperands,
-                                  newOutputOperands, newIndexingMaps);
-
-    // Check if there is any change to operands.
-    if (newInputOperands.size() + newOutputOperands.size() ==
-        genericOp->getNumOperands())
-      return failure();
-
-    // Create the new op with the body being empty.
-    Location loc = genericOp.getLoc();
-    SmallVector<Type> newResultTypes;
-    for (Value v : newOutputOperands)
-      if (v.getType().isa<TensorType>())
-        newResultTypes.push_back(v.getType());
-    auto newOp = rewriter.create<GenericOp>(
-        loc, newResultTypes, newInputOperands, newOutputOperands,
-        rewriter.getAffineMapArrayAttr(newIndexingMaps),
-        genericOp.getIteratorTypes(), genericOp.getDocAttr(),
-        genericOp.getLibraryCallAttr(),
-        [](OpBuilder & /*builder*/, Location /*loc*/, ValueRange /*args*/) {
-          return;
-        });
-    // Copy over unknown attributes. They might be load bearing for some flow.
-    ArrayRef<StringRef> odsAttrs = genericOp.getAttributeNames();
-    for (NamedAttribute kv : genericOp->getAttrs())
-      if (!llvm::is_contained(odsAttrs, kv.getName().getValue()))
-        newOp->setAttr(kv.getName(), kv.getValue());
-
-    // Fix up the payload of the canonicalized operation.
-    populateOpPayload(genericOp, newOp, origInsToNewInsPos,
-                      origOutsToNewOutsPos, rewriter);
-
-    // Replace all live uses of the op.
-    SmallVector<Value> replacementsVals(genericOp->getNumResults(), nullptr);
-    for (const auto &result : llvm::enumerate(genericOp.getResults())) {
-      auto it = origOutsToNewOutsPos.find(result.index());
-      if (it == origOutsToNewOutsPos.end())
-        continue;
-      replacementsVals[result.index()] = newOp.getResult(it->second);
-    }
-    rewriter.replaceOp(genericOp, replacementsVals);
-    return success();
-  }
-
-private:
-  // Deduplicate input operands, and return the
-  // - Mapping from operand position in the original op, to operand position in
-  // the canonicalized op.
-  // - The preserved input operands list (by reference).
-  llvm::SmallDenseMap<unsigned, unsigned>
-  deduplicateInputOperands(GenericOp genericOp,
-                           SmallVector<OpOperand *> &droppedOpOperands,
-                           SmallVector<Value> &newInputOperands,
-                           SmallVector<AffineMap> &newIndexingMaps) const {
-    llvm::SmallDenseMap<unsigned, unsigned> origToNewPos;
-    llvm::SmallDenseMap<std::pair<Value, AffineMap>, unsigned> dedupedInputs;
-    for (const auto &en : llvm::enumerate(genericOp.getDpsInputOperands())) {
-      OpOperand *inputOpOperand = en.value();
-      // Check if operand is dead and if dropping the indexing map makes the
-      // loops to shape computation invalid.
-      if (!genericOp.payloadUsesValueFromOperand(inputOpOperand)) {
-        // Add the current operands to the list of potentially droppable
-        // operands. If it cannot be dropped, this needs to be popped back.
-        droppedOpOperands.push_back(inputOpOperand);
-        if (genericOp.canOpOperandsBeDropped(droppedOpOperands))
-          continue;
-        droppedOpOperands.pop_back();
-      }
-
-      // Check if this operand is a duplicate.
-      AffineMap indexingMap = genericOp.getMatchingIndexingMap(inputOpOperand);
-      auto it = dedupedInputs.find(
-          std::make_pair(inputOpOperand->get(), indexingMap));
-      if (it != dedupedInputs.end()) {
-        origToNewPos[en.index()] = it->second;
-        droppedOpOperands.push_back(inputOpOperand);
-        continue;
-      }
-
-      // This is a preserved argument.
-      origToNewPos[en.index()] = newInputOperands.size();
-      dedupedInputs[{inputOpOperand->get(), indexingMap}] =
-          newInputOperands.size();
-      newInputOperands.push_back(inputOpOperand->get());
-      newIndexingMaps.push_back(indexingMap);
-    }
-    return origToNewPos;
-  }
-
-  // Deduplicate output operands, and return the
-  // - Mapping from operand position in the original op, to operand position in
-  // the canonicalized op.
-  // - The preserved output operands list (by reference).
-  llvm::SmallDenseMap<unsigned, unsigned>
-  deduplicateOutputOperands(GenericOp genericOp,
-                            SmallVector<OpOperand *> &droppedOpOperands,
-                            SmallVector<Value> &newOutputOperands,
-                            SmallVector<AffineMap> &newIndexingMaps) const {
-    llvm::SmallDenseMap<unsigned, unsigned> origToNewPos;
-    llvm::SmallDenseMap<std::tuple<Value, AffineMap, Value>, unsigned>
-        dedupedOutpts;
-    // If the op doesnt have tensor semantics, keep all the outputs as
-    // preserved.
-    if (!genericOp.hasTensorSemantics()) {
-      for (const auto &en : llvm::enumerate(genericOp.getDpsInitOperands())) {
-        origToNewPos[en.index()] = newOutputOperands.size();
-        newOutputOperands.push_back(en.value()->get());
-        newIndexingMaps.push_back(genericOp.getMatchingIndexingMap(en.value()));
-      }
-      return origToNewPos;
-    }
-    // Output argument can be dropped if the result has
-    // - no users, and
-    // - it is not used in the payload, and
-    // - the corresponding indexing maps are not needed for loop bound
-    //   computation.
-    auto yieldOp = cast<YieldOp>(genericOp.getBody()->getTerminator());
-    for (const auto &outputOpOperand :
-         llvm::enumerate(genericOp.getDpsInitOperands())) {
-      OpResult result = genericOp.getTiedOpResult(outputOpOperand.value());
-      AffineMap indexingMap =
-          genericOp.getMatchingIndexingMap(outputOpOperand.value());
-      auto key = std::make_tuple(outputOpOperand.value()->get(), indexingMap,
-                                 yieldOp->getOperand(outputOpOperand.index()));
-      if (isResultValueDead(genericOp, result)) {
-        // Check if the opoperand can be dropped without affecting loop
-        // bound computation. Add the operand to the list of dropped op
-        // operand for checking. If it cannot be dropped, need to pop the
-        // value back.
-        droppedOpOperands.push_back(outputOpOperand.value());
-        if (genericOp.canOpOperandsBeDropped(droppedOpOperands)) {
-          continue;
-        }
-        droppedOpOperands.pop_back();
-      }
-
-      if (!genericOp.payloadUsesValueFromOperand(outputOpOperand.value())) {
-        // The out operand can also be dropped if it is computed redundantly
-        // by another result, the conditions for that are
-        // - The same operand is used as the out operand
-        // - The same indexing map is used
-        // - The same yield value is used.
-        auto it = dedupedOutpts.find(key);
-        if (it != dedupedOutpts.end()) {
-          origToNewPos[outputOpOperand.index()] = it->second;
-          droppedOpOperands.push_back(outputOpOperand.value());
-          continue;
-        }
-      }
-
-      origToNewPos[outputOpOperand.index()] = newOutputOperands.size();
-      dedupedOutpts[key] = newOutputOperands.size();
-      newOutputOperands.push_back(outputOpOperand.value()->get());
-      newIndexingMaps.push_back(
-          genericOp.getMatchingIndexingMap(outputOpOperand.value()));
-    }
-    return origToNewPos;
-  }
-
-  // Populate the body of the canonicalized operation.
-  void populateOpPayload(
-      GenericOp genericOp, GenericOp newOp,
-      const llvm::SmallDenseMap<unsigned, unsigned> &origInsToNewInsPos,
-      const llvm::SmallDenseMap<unsigned, unsigned> &origOutsToNewOutsPos,
-      PatternRewriter &rewriter) const {
-    // Merge the body of the original op with the new op.
-    Block *newOpBlock = &newOp.getRegion().front();
-    assert(newOpBlock->empty() && "expected new op to have an empty payload");
-    Block *origOpBlock = &genericOp.getRegion().front();
-    SmallVector<Value> replacements(origOpBlock->getNumArguments(), nullptr);
-
-    // Replace all arguments in the original op, with arguments from the
-    // canonicalized op.
-    auto updateReplacements =
-        [&](OpOperandVector &origOperands, OpOperandVector &newOperands,
-            const llvm::SmallDenseMap<unsigned, unsigned> &map) {
-          for (const auto &origOperand : llvm::enumerate(origOperands)) {
-            auto it = map.find(origOperand.index());
-            if (it == map.end())
-              continue;
-            OpOperand *newOperand = newOperands[it->second];
-            replacements[origOperand.value()->getOperandNumber()] =
-                newOpBlock->getArgument(newOperand->getOperandNumber());
-          }
-        };
-
-    OpOperandVector origInputOperands = genericOp.getDpsInputOperands();
-    OpOperandVector newInputOperands = newOp.getDpsInputOperands();
-    updateReplacements(origInputOperands, newInputOperands, origInsToNewInsPos);
-
-    OpOperandVector origOutputOperands = genericOp.getDpsInitOperands();
-    OpOperandVector newOutputOperands = newOp.getDpsInitOperands();
-    updateReplacements(origOutputOperands, newOutputOperands,
-                       origOutsToNewOutsPos);
-
-    // Drop the unused yield args.
-    if (newOp.getNumDpsInits() != genericOp.getNumDpsInits()) {
-      OpBuilder::InsertionGuard g(rewriter);
-      YieldOp origYieldOp = cast<YieldOp>(origOpBlock->getTerminator());
-      rewriter.setInsertionPoint(origYieldOp);
-
-      SmallVector<Value> newYieldVals(newOp.getNumDpsInits(), nullptr);
-      for (const auto &yieldOpOperands :
-           llvm::enumerate(origYieldOp.getValues())) {
-        auto it = origOutsToNewOutsPos.find(yieldOpOperands.index());
-        if (it == origOutsToNewOutsPos.end())
-          continue;
-        newYieldVals[it->second] = yieldOpOperands.value();
-      }
-      rewriter.replaceOpWithNewOp<YieldOp>(origYieldOp, newYieldVals);
-    }
-
-    rewriter.mergeBlocks(origOpBlock, newOpBlock, replacements);
-  }
-};
-
 /// Remove generic operations (on tensors) that are just copying
 /// the values from inputs to the results. Requirements are
 /// 1) All iterator types are parallel
@@ -1227,74 +952,11 @@ struct EraseIdentityGenericOp : public OpRewritePattern<GenericOp> {
   }
 };
 
-/// Remove unused cycles.
-/// We can remove unused cycle within a payload of generic region
-/// if these conditions are met:
-/// - Result from out operand is dead.
-/// - Block arg from out operand has a single use in the %cycle
-/// instruction.
-/// - Cycle has a single use and it is in yield.
-struct RemoveUnusedCycleInGenericOp : public OpRewritePattern<GenericOp> {
-  using OpRewritePattern<GenericOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(GenericOp genericOp,
-                                PatternRewriter &rewriter) const override {
-
-    // If the op doesnt have tensor semantics, preserve the outputs as is.
-    if (!genericOp.hasTensorSemantics())
-      return failure();
-
-    bool hasRemovedCycles = false;
-    // Iterate over output operands and remove any unused cycles.
-    for (const auto &outputOpOperand :
-         llvm::enumerate(genericOp.getDpsInitOperands())) {
-
-      // Check that result from out operand is dead.
-      Value result = genericOp.getResult(outputOpOperand.index());
-      if (!result.use_empty())
-        continue;
-
-      // Check that outputArg has one use in cycle.
-      BlockArgument outputArg =
-          genericOp.getRegionOutputArgs()[outputOpOperand.index()];
-      if (!outputArg.hasOneUse())
-        continue;
-
-      // Check cycle has at most one use.
-      Operation *cycleOp = *outputArg.user_begin();
-      if (!cycleOp->hasOneUse())
-        continue;
-
-      // Check that the cycleUser is a yield.
-      Operation *cycleUserOp = *cycleOp->user_begin();
-      if (!isa<linalg::YieldOp>(cycleUserOp))
-        continue;
-
-      // Check that argIndex matches yieldIndex, else data is being used.
-      if (cycleUserOp->getOperand(outputOpOperand.index()) !=
-          cycleOp->getResult(0))
-        continue;
-
-      // Directly replace the cycle with the blockArg such that
-      // Deduplicate pattern can eliminate it along with unused yield.
-      rewriter.replaceOp(cycleOp, outputArg);
-      rewriter.updateRootInPlace(genericOp, [] {});
-      hasRemovedCycles = true;
-    }
-
-    if (hasRemovedCycles) {
-      return success();
-    }
-
-    return failure();
-  }
-};
 } // namespace
 
 void GenericOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
-  results.add<DeduplicateAndRemoveDeadOperandsAndResults,
-              EraseIdentityGenericOp, RemoveUnusedCycleInGenericOp>(context);
+  results.add<EraseIdentityGenericOp>(context);
 }
 
 LogicalResult GenericOp::fold(ArrayRef<Attribute>,
index c69f55c..8809b25 100644 (file)
@@ -8,6 +8,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms
   DropUnitDims.cpp
   ElementwiseOpFusion.cpp
   ElementwiseToLinalg.cpp
+  EraseUnusedOperandsAndResults.cpp
   FusePadOpWithLinalgProducer.cpp
   Fusion.cpp
   FusionOnTensors.cpp
index bbba218..15b3304 100644 (file)
@@ -376,6 +376,9 @@ DecomposeLinalgOp::matchAndRewrite(GenericOp genericOp,
 }
 
 void mlir::linalg::populateDecomposeLinalgOpsPattern(
-    RewritePatternSet &patterns) {
+    RewritePatternSet &patterns, bool removeDeadArgsAndResults) {
   patterns.insert<DecomposeLinalgOp>(patterns.getContext());
+  // Add the patterns to clean up the dead operands and results.
+  if (removeDeadArgsAndResults)
+    populateEraseUnusedOperandsAndResultsPatterns(patterns);
 }
index 5b0d9bf..b486708 100644 (file)
@@ -1780,6 +1780,8 @@ void mlir::linalg::populateElementwiseOpsFusionPatterns(
   patterns.add<FuseElementwiseOps>(context, controlElementwiseOpsFusion);
   patterns.add<FoldFillWithGenericOp, FoldScalarOrSplatConstant,
                RemoveOutsDependency>(context);
+  // Add the patterns that clean up dead operands and results.
+  populateEraseUnusedOperandsAndResultsPatterns(patterns);
 }
 
 void mlir::linalg::populateCollapseDimensions(
diff --git a/mlir/lib/Dialect/Linalg/Transforms/EraseUnusedOperandsAndResults.cpp b/mlir/lib/Dialect/Linalg/Transforms/EraseUnusedOperandsAndResults.cpp
new file mode 100644 (file)
index 0000000..87df83f
--- /dev/null
@@ -0,0 +1,362 @@
+//===- EraseUnusedOperandsAndResults.cpp ----------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+
+using namespace mlir;
+using namespace mlir::linalg;
+
+/// Return `true` if the `result` of an operation `genericOp` is dead.
+static bool isResultValueDead(linalg::GenericOp genericOp, OpResult result) {
+  if (!result.use_empty())
+    return false;
+  // If out operand not used in payload, we can drop it.
+  OpOperand *outputOpOperand =
+      genericOp.getDpsInitOperand(result.getResultNumber());
+  if (!genericOp.payloadUsesValueFromOperand(outputOpOperand))
+    return true;
+
+  // The out operand that is part of a payload can be dropped if
+  // these conditions are met:
+  // - Result from out operand is dead.
+  // - User of arg is yield.
+  // - outArg data is not being used by other outArgs.
+
+  // Check block arg and cycle from out operand has a single use.
+  BlockArgument outputArg =
+      genericOp.getRegionOutputArgs()[result.getResultNumber()];
+  if (!outputArg.hasOneUse())
+    return false;
+  Operation *argUserOp = *outputArg.user_begin();
+
+  // Check argUser has no other use.
+  if (!argUserOp->use_empty())
+    return false;
+
+  // Check that argUser is a yield.
+  auto yieldOp = dyn_cast<linalg::YieldOp>(argUserOp);
+  if (!yieldOp)
+    return false;
+
+  // Check outArg data is not being used by other outArgs.
+  if (yieldOp.getOperand(result.getResultNumber()) != outputArg)
+    return false;
+
+  return true;
+}
+
+namespace {
+
+struct DeduplicateAndRemoveDeadOperandsAndResults
+    : public OpRewritePattern<GenericOp> {
+  using OpRewritePattern<GenericOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(GenericOp genericOp,
+                                PatternRewriter &rewriter) const override {
+    // Create a map from argument position in the original op to the argument
+    // position in the new op. If the argument is dropped it wont have an entry.
+    SmallVector<OpOperand *> droppedOpOperands;
+
+    // Information needed to build the new op.
+    SmallVector<Value> newInputOperands, newOutputOperands;
+    SmallVector<AffineMap> newIndexingMaps;
+
+    // Gather information about duplicate input operands.
+    llvm::SmallDenseMap<unsigned, unsigned> origInsToNewInsPos =
+        deduplicateInputOperands(genericOp, droppedOpOperands, newInputOperands,
+                                 newIndexingMaps);
+
+    // Gather information about the dropped outputs.
+    llvm::SmallDenseMap<unsigned, unsigned> origOutsToNewOutsPos =
+        deduplicateOutputOperands(genericOp, droppedOpOperands,
+                                  newOutputOperands, newIndexingMaps);
+
+    // Check if there is any change to operands.
+    if (newInputOperands.size() + newOutputOperands.size() ==
+        genericOp->getNumOperands())
+      return failure();
+
+    // Create the new op with the body being empty.
+    Location loc = genericOp.getLoc();
+    SmallVector<Type> newResultTypes;
+    for (Value v : newOutputOperands)
+      if (v.getType().isa<TensorType>())
+        newResultTypes.push_back(v.getType());
+    auto newOp = rewriter.create<GenericOp>(
+        loc, newResultTypes, newInputOperands, newOutputOperands,
+        rewriter.getAffineMapArrayAttr(newIndexingMaps),
+        genericOp.getIteratorTypes(), genericOp.getDocAttr(),
+        genericOp.getLibraryCallAttr(),
+        [](OpBuilder & /*builder*/, Location /*loc*/, ValueRange /*args*/) {
+          return;
+        });
+    // Copy over unknown attributes. They might be load bearing for some flow.
+    ArrayRef<StringRef> odsAttrs = genericOp.getAttributeNames();
+    for (NamedAttribute kv : genericOp->getAttrs())
+      if (!llvm::is_contained(odsAttrs, kv.getName().getValue()))
+        newOp->setAttr(kv.getName(), kv.getValue());
+
+    // Fix up the payload of the canonicalized operation.
+    populateOpPayload(genericOp, newOp, origInsToNewInsPos,
+                      origOutsToNewOutsPos, rewriter);
+
+    // Replace all live uses of the op.
+    SmallVector<Value> replacementsVals(genericOp->getNumResults(), nullptr);
+    for (const auto &result : llvm::enumerate(genericOp.getResults())) {
+      auto it = origOutsToNewOutsPos.find(result.index());
+      if (it == origOutsToNewOutsPos.end())
+        continue;
+      replacementsVals[result.index()] = newOp.getResult(it->second);
+    }
+    rewriter.replaceOp(genericOp, replacementsVals);
+    return success();
+  }
+
+private:
+  // Deduplicate input operands, and return the
+  // - Mapping from operand position in the original op, to operand position in
+  // the canonicalized op.
+  // - The preserved input operands list (by reference).
+  llvm::SmallDenseMap<unsigned, unsigned>
+  deduplicateInputOperands(GenericOp genericOp,
+                           SmallVector<OpOperand *> &droppedOpOperands,
+                           SmallVector<Value> &newInputOperands,
+                           SmallVector<AffineMap> &newIndexingMaps) const {
+    llvm::SmallDenseMap<unsigned, unsigned> origToNewPos;
+    llvm::SmallDenseMap<std::pair<Value, AffineMap>, unsigned> dedupedInputs;
+    for (const auto &en : llvm::enumerate(genericOp.getDpsInputOperands())) {
+      OpOperand *inputOpOperand = en.value();
+      // Check if operand is dead and if dropping the indexing map makes the
+      // loops to shape computation invalid.
+      if (!genericOp.payloadUsesValueFromOperand(inputOpOperand)) {
+        // Add the current operands to the list of potentially droppable
+        // operands. If it cannot be dropped, this needs to be popped back.
+        droppedOpOperands.push_back(inputOpOperand);
+        if (genericOp.canOpOperandsBeDropped(droppedOpOperands))
+          continue;
+        droppedOpOperands.pop_back();
+      }
+
+      // Check if this operand is a duplicate.
+      AffineMap indexingMap = genericOp.getMatchingIndexingMap(inputOpOperand);
+      auto it = dedupedInputs.find(
+          std::make_pair(inputOpOperand->get(), indexingMap));
+      if (it != dedupedInputs.end()) {
+        origToNewPos[en.index()] = it->second;
+        droppedOpOperands.push_back(inputOpOperand);
+        continue;
+      }
+
+      // This is a preserved argument.
+      origToNewPos[en.index()] = newInputOperands.size();
+      dedupedInputs[{inputOpOperand->get(), indexingMap}] =
+          newInputOperands.size();
+      newInputOperands.push_back(inputOpOperand->get());
+      newIndexingMaps.push_back(indexingMap);
+    }
+    return origToNewPos;
+  }
+
+  // Deduplicate output operands, and return the
+  // - Mapping from operand position in the original op, to operand position in
+  // the canonicalized op.
+  // - The preserved output operands list (by reference).
+  llvm::SmallDenseMap<unsigned, unsigned>
+  deduplicateOutputOperands(GenericOp genericOp,
+                            SmallVector<OpOperand *> &droppedOpOperands,
+                            SmallVector<Value> &newOutputOperands,
+                            SmallVector<AffineMap> &newIndexingMaps) const {
+    llvm::SmallDenseMap<unsigned, unsigned> origToNewPos;
+    llvm::SmallDenseMap<std::tuple<Value, AffineMap, Value>, unsigned>
+        dedupedOutpts;
+    // If the op doesnt have tensor semantics, keep all the outputs as
+    // preserved.
+    if (!genericOp.hasTensorSemantics()) {
+      for (const auto &en : llvm::enumerate(genericOp.getDpsInitOperands())) {
+        origToNewPos[en.index()] = newOutputOperands.size();
+        newOutputOperands.push_back(en.value()->get());
+        newIndexingMaps.push_back(genericOp.getMatchingIndexingMap(en.value()));
+      }
+      return origToNewPos;
+    }
+    // Output argument can be dropped if the result has
+    // - no users, and
+    // - it is not used in the payload, and
+    // - the corresponding indexing maps are not needed for loop bound
+    //   computation.
+    auto yieldOp = cast<YieldOp>(genericOp.getBody()->getTerminator());
+    for (const auto &outputOpOperand :
+         llvm::enumerate(genericOp.getDpsInitOperands())) {
+      OpResult result = genericOp.getTiedOpResult(outputOpOperand.value());
+      AffineMap indexingMap =
+          genericOp.getMatchingIndexingMap(outputOpOperand.value());
+      auto key = std::make_tuple(outputOpOperand.value()->get(), indexingMap,
+                                 yieldOp->getOperand(outputOpOperand.index()));
+      if (isResultValueDead(genericOp, result)) {
+        // Check if the opoperand can be dropped without affecting loop
+        // bound computation. Add the operand to the list of dropped op
+        // operand for checking. If it cannot be dropped, need to pop the
+        // value back.
+        droppedOpOperands.push_back(outputOpOperand.value());
+        if (genericOp.canOpOperandsBeDropped(droppedOpOperands)) {
+          continue;
+        }
+        droppedOpOperands.pop_back();
+      }
+
+      if (!genericOp.payloadUsesValueFromOperand(outputOpOperand.value())) {
+        // The out operand can also be dropped if it is computed redundantly
+        // by another result, the conditions for that are
+        // - The same operand is used as the out operand
+        // - The same indexing map is used
+        // - The same yield value is used.
+        auto it = dedupedOutpts.find(key);
+        if (it != dedupedOutpts.end()) {
+          origToNewPos[outputOpOperand.index()] = it->second;
+          droppedOpOperands.push_back(outputOpOperand.value());
+          continue;
+        }
+      }
+
+      origToNewPos[outputOpOperand.index()] = newOutputOperands.size();
+      dedupedOutpts[key] = newOutputOperands.size();
+      newOutputOperands.push_back(outputOpOperand.value()->get());
+      newIndexingMaps.push_back(
+          genericOp.getMatchingIndexingMap(outputOpOperand.value()));
+    }
+    return origToNewPos;
+  }
+
+  // Populate the body of the canonicalized operation.
+  void populateOpPayload(
+      GenericOp genericOp, GenericOp newOp,
+      const llvm::SmallDenseMap<unsigned, unsigned> &origInsToNewInsPos,
+      const llvm::SmallDenseMap<unsigned, unsigned> &origOutsToNewOutsPos,
+      PatternRewriter &rewriter) const {
+    // Merge the body of the original op with the new op.
+    Block *newOpBlock = &newOp.getRegion().front();
+    assert(newOpBlock->empty() && "expected new op to have an empty payload");
+    Block *origOpBlock = &genericOp.getRegion().front();
+    SmallVector<Value> replacements(origOpBlock->getNumArguments(), nullptr);
+
+    // Replace all arguments in the original op, with arguments from the
+    // canonicalized op.
+    auto updateReplacements =
+        [&](OpOperandVector &origOperands, OpOperandVector &newOperands,
+            const llvm::SmallDenseMap<unsigned, unsigned> &map) {
+          for (const auto &origOperand : llvm::enumerate(origOperands)) {
+            auto it = map.find(origOperand.index());
+            if (it == map.end())
+              continue;
+            OpOperand *newOperand = newOperands[it->second];
+            replacements[origOperand.value()->getOperandNumber()] =
+                newOpBlock->getArgument(newOperand->getOperandNumber());
+          }
+        };
+
+    OpOperandVector origInputOperands = genericOp.getDpsInputOperands();
+    OpOperandVector newInputOperands = newOp.getDpsInputOperands();
+    updateReplacements(origInputOperands, newInputOperands, origInsToNewInsPos);
+
+    OpOperandVector origOutputOperands = genericOp.getDpsInitOperands();
+    OpOperandVector newOutputOperands = newOp.getDpsInitOperands();
+    updateReplacements(origOutputOperands, newOutputOperands,
+                       origOutsToNewOutsPos);
+
+    // Drop the unused yield args.
+    if (newOp.getNumDpsInits() != genericOp.getNumDpsInits()) {
+      OpBuilder::InsertionGuard g(rewriter);
+      YieldOp origYieldOp = cast<YieldOp>(origOpBlock->getTerminator());
+      rewriter.setInsertionPoint(origYieldOp);
+
+      SmallVector<Value> newYieldVals(newOp.getNumDpsInits(), nullptr);
+      for (const auto &yieldOpOperands :
+           llvm::enumerate(origYieldOp.getValues())) {
+        auto it = origOutsToNewOutsPos.find(yieldOpOperands.index());
+        if (it == origOutsToNewOutsPos.end())
+          continue;
+        newYieldVals[it->second] = yieldOpOperands.value();
+      }
+      rewriter.replaceOpWithNewOp<YieldOp>(origYieldOp, newYieldVals);
+    }
+
+    rewriter.mergeBlocks(origOpBlock, newOpBlock, replacements);
+  }
+};
+
+/// Remove unused cycles.
+/// We can remove unused cycle within a payload of generic region
+/// if these conditions are met:
+/// - Result from out operand is dead.
+/// - Block arg from out operand has a single use in the %cycle
+/// instruction.
+/// - Cycle has a single use and it is in yield.
+struct RemoveUnusedCycleInGenericOp : public OpRewritePattern<GenericOp> {
+  using OpRewritePattern<GenericOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(GenericOp genericOp,
+                                PatternRewriter &rewriter) const override {
+
+    // If the op doesnt have tensor semantics, preserve the outputs as is.
+    if (!genericOp.hasTensorSemantics())
+      return failure();
+
+    bool hasRemovedCycles = false;
+    // Iterate over output operands and remove any unused cycles.
+    for (const auto &outputOpOperand :
+         llvm::enumerate(genericOp.getDpsInitOperands())) {
+
+      // Check that result from out operand is dead.
+      Value result = genericOp.getResult(outputOpOperand.index());
+      if (!result.use_empty())
+        continue;
+
+      // Check that outputArg has one use in cycle.
+      BlockArgument outputArg =
+          genericOp.getRegionOutputArgs()[outputOpOperand.index()];
+      if (!outputArg.hasOneUse())
+        continue;
+
+      // Check cycle has at most one use.
+      Operation *cycleOp = *outputArg.user_begin();
+      if (!cycleOp->hasOneUse())
+        continue;
+
+      // Check that the cycleUser is a yield.
+      Operation *cycleUserOp = *cycleOp->user_begin();
+      if (!isa<linalg::YieldOp>(cycleUserOp))
+        continue;
+
+      // Check that argIndex matches yieldIndex, else data is being used.
+      if (cycleUserOp->getOperand(outputOpOperand.index()) !=
+          cycleOp->getResult(0))
+        continue;
+
+      // Directly replace the cycle with the blockArg such that
+      // Deduplicate pattern can eliminate it along with unused yield.
+      rewriter.replaceOp(cycleOp, outputArg);
+      rewriter.updateRootInPlace(genericOp, [] {});
+      hasRemovedCycles = true;
+    }
+
+    if (hasRemovedCycles) {
+      return success();
+    }
+
+    return failure();
+  }
+};
+} // namespace
+
+void mlir::linalg::populateEraseUnusedOperandsAndResultsPatterns(
+    RewritePatternSet &patterns) {
+  patterns.insert<DeduplicateAndRemoveDeadOperandsAndResults,
+                  RemoveUnusedCycleInGenericOp>(patterns.getContext());
+}
index 1fe5fe5..55013c4 100644 (file)
@@ -296,54 +296,6 @@ func.func @self_copy(%arg0 : memref<2x3x?x4xf32>) {
 }
 
 // -----
-
-// CHECK-LABEL: func @remove_deadargs_generic_basic
-//  CHECK-SAME: (%[[ARG0:.*]]: tensor<?xf32>) -> tensor<?xf32> {
-//       CHECK: %[[GENERIC_OP:.*]] = linalg.generic
-//  CHECK-SAME: ins(%[[ARG0]] : tensor<?xf32>)
-//  CHECK-SAME: outs({{.*}} : tensor<?xf32>) {
-#map0 = affine_map<(d0) -> (d0)>
-func.func @remove_deadargs_generic_basic(%arg0: tensor<?xf32>) -> (tensor<?xf32>) {
-  %c0 = arith.constant 0 : index
-  %cst = arith.constant 7.0 : f32
-  %0 = tensor.dim %arg0, %c0 : tensor<?xf32>
-  %1 = tensor.empty(%0) : tensor<?xf32>
-  %2 = tensor.empty(%0) : tensor<?xf32>
-  %3 = linalg.generic {indexing_maps = [#map0, #map0, #map0], iterator_types=["parallel"]} ins(%arg0, %1 : tensor<?xf32>, tensor<?xf32>) outs (%2:tensor<?xf32>) {
-  ^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
-    %4 = arith.addf  %arg1, %cst : f32
-       linalg.yield %4 : f32
-  } -> tensor<?xf32>
-  return %3 : tensor<?xf32>
-}
-
-// -----
-
-// CHECK-LABEL: func @remove_deadargs_generic_mixedaccess
-//       CHECK: %[[GENERIC_OP:.*]] = linalg.generic
-//   CHECK-NOT: ins
-//  CHECK-SAME: outs({{.*}} : tensor<?x?xf32>) {
-#map0 = affine_map<(d0, d1) -> (d0, d1)>
-#map1 = affine_map<(d0, d1) -> (d1, d0)>
-func.func @remove_deadargs_generic_mixedaccess(%arg0: tensor<?x?xf32>) -> (tensor<?x?xf32>) {
-  %c0 = arith.constant 0 : index
-  %c1 = arith.constant 0 : index
-  %cst1 = arith.constant 7.0 : f32
-  %cst2 = arith.constant 6.0 : f32
-  %0 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
-  %1 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
-  %2 = tensor.empty(%0, %1) : tensor<?x?xf32>
-  %3 = tensor.empty(%1, %0) : tensor<?x?xf32>
-  %4 = tensor.empty(%0, %1) : tensor<?x?xf32>
-  %5 = linalg.generic {indexing_maps = [#map0, #map1, #map0], iterator_types=["parallel","parallel"]} ins(%2, %3 : tensor<?x?xf32>, tensor<?x?xf32>) outs (%4:tensor<?x?xf32>) {
-  ^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
-    %6 = arith.divf  %cst1, %cst2 : f32
-       linalg.yield %6 : f32
-  } -> tensor<?x?xf32>
-  return %5 : tensor<?x?xf32>
-}
-
-// -----
 // CHECK-LABEL: func @fold_fill_reshape()
 func.func @fold_fill_reshape() -> tensor<6x4xf32> {
   %zero = arith.constant 0.0 : f32
index b562715..df29e63 100644 (file)
@@ -1,5 +1,5 @@
 // RUN: mlir-opt -test-linalg-decompose-ops -cse -split-input-file %s | FileCheck %s
-// RUN: mlir-opt -test-linalg-decompose-ops -cse -canonicalize -split-input-file %s | FileCheck %s --check-prefix=CANONICALIZECHECK
+// RUN: mlir-opt -test-linalg-decompose-ops=remove-dead-args-and-results -cse -split-input-file %s | FileCheck %s --check-prefix=CANONICALIZECHECK
 
 func.func @simple_op(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?xf32>, %arg2 : tensor<?xf32>)
     -> (tensor<?x?xf32>, tensor<?x?xf32>) {
@@ -1,4 +1,52 @@
-// RUN: mlir-opt %s -split-input-file -canonicalize | FileCheck %s
+// RUN: mlir-opt %s -split-input-file -test-linalg-transform-patterns=test-erase-unused-operands-and-results | FileCheck %s
+
+// CHECK-LABEL: func @remove_deadargs_generic_basic
+//  CHECK-SAME: (%[[ARG0:.*]]: tensor<?xf32>) -> tensor<?xf32> {
+//       CHECK: %[[GENERIC_OP:.*]] = linalg.generic
+//  CHECK-SAME: ins(%[[ARG0]] : tensor<?xf32>)
+//  CHECK-SAME: outs({{.*}} : tensor<?xf32>) {
+#map0 = affine_map<(d0) -> (d0)>
+func.func @remove_deadargs_generic_basic(%arg0: tensor<?xf32>) -> (tensor<?xf32>) {
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant 7.0 : f32
+  %0 = tensor.dim %arg0, %c0 : tensor<?xf32>
+  %1 = tensor.empty(%0) : tensor<?xf32>
+  %2 = tensor.empty(%0) : tensor<?xf32>
+  %3 = linalg.generic {indexing_maps = [#map0, #map0, #map0], iterator_types=["parallel"]} ins(%arg0, %1 : tensor<?xf32>, tensor<?xf32>) outs (%2:tensor<?xf32>) {
+  ^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
+    %4 = arith.addf  %arg1, %cst : f32
+        linalg.yield %4 : f32
+  } -> tensor<?xf32>
+  return %3 : tensor<?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @remove_deadargs_generic_mixedaccess
+//       CHECK: %[[GENERIC_OP:.*]] = linalg.generic
+//   CHECK-NOT: ins
+//  CHECK-SAME: outs({{.*}} : tensor<?x?xf32>) {
+#map0 = affine_map<(d0, d1) -> (d0, d1)>
+#map1 = affine_map<(d0, d1) -> (d1, d0)>
+func.func @remove_deadargs_generic_mixedaccess(%arg0: tensor<?x?xf32>) -> (tensor<?x?xf32>) {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 0 : index
+  %cst1 = arith.constant 7.0 : f32
+  %cst2 = arith.constant 6.0 : f32
+  %0 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
+  %1 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
+  %2 = tensor.empty(%0, %1) : tensor<?x?xf32>
+  %3 = tensor.empty(%1, %0) : tensor<?x?xf32>
+  %4 = tensor.empty(%0, %1) : tensor<?x?xf32>
+  %5 = linalg.generic {indexing_maps = [#map0, #map1, #map0], iterator_types=["parallel","parallel"]} ins(%2, %3 : tensor<?x?xf32>, tensor<?x?xf32>) outs (%4:tensor<?x?xf32>) {
+  ^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
+    %6 = arith.divf  %cst1, %cst2 : f32
+        linalg.yield %6 : f32
+  } -> tensor<?x?xf32>
+  return %5 : tensor<?x?xf32>
+}
+
+// -----
 
 // Test case: Most basic case. Adding a vector to itself.
 
index 9373bde..3cfb1b4 100644 (file)
@@ -1,5 +1,4 @@
 // RUN: mlir-opt %s -test-linalg-elementwise-fusion-patterns=fuse-generic-ops -split-input-file | FileCheck %s
-// RUN: mlir-opt %s -test-linalg-elementwise-fusion-patterns=fuse-generic-ops -split-input-file -canonicalize | FileCheck %s --check-prefix=CANONICALIZE
 
 #map0 = affine_map<(d0, d1) -> (d0, d1)>
 #binary2Dpointwise = {
@@ -50,6 +49,7 @@ func.func @test_fusion_limit(
     } -> tensor<?x?xf32>
   return %6 : tensor<?x?xf32>
 }
+
 // CHECK-LABEL: func @test_fusion_limit
 //  CHECK-SAME:   %[[ARG0:[a-zA-z0-9_]+]]: tensor<?x?xf32>
 //  CHECK-SAME:   %[[ARG1:[a-zA-z0-9_]+]]: tensor<?x?xf32>
@@ -59,17 +59,5 @@ func.func @test_fusion_limit(
 //  CHECK-SAME:   %[[ARG5:[a-zA-z0-9_]+]]: tensor<?x?xf32>
 //       CHECK:   %[[OP1:.+]] = linalg.generic {{.+}} ins(%[[ARG2]], %[[ARG3]]
 //       CHECK:   %[[OP2:.+]] = linalg.generic {{.+}} ins(%[[ARG4]], %[[ARG5]]
-//       CHECK:   %[[OP3:.+]]:2 = linalg.generic {{.+}} ins(%[[ARG0]], %[[ARG1]], %[[OP1]], %[[OP2]]
-//       CHECK:   return %[[OP3]]#1
-
-// CANONICALIZE-LABEL: func @test_fusion_limit
-//  CANONICALIZE-SAME:   %[[ARG0:[a-zA-z0-9_]+]]: tensor<?x?xf32>
-//  CANONICALIZE-SAME:   %[[ARG1:[a-zA-z0-9_]+]]: tensor<?x?xf32>
-//  CANONICALIZE-SAME:   %[[ARG2:[a-zA-z0-9_]+]]: tensor<?x?xf32>
-//  CANONICALIZE-SAME:   %[[ARG3:[a-zA-z0-9_]+]]: tensor<?x?xf32>
-//  CANONICALIZE-SAME:   %[[ARG4:[a-zA-z0-9_]+]]: tensor<?x?xf32>
-//  CANONICALIZE-SAME:   %[[ARG5:[a-zA-z0-9_]+]]: tensor<?x?xf32>
-//       CANONICALIZE:   %[[OP1:.+]] = linalg.generic {{.+}} ins(%[[ARG2]], %[[ARG3]]
-//       CANONICALIZE:   %[[OP2:.+]] = linalg.generic {{.+}} ins(%[[ARG4]], %[[ARG5]]
-//       CANONICALIZE:   %[[OP3:.+]] = linalg.generic {{.+}} ins(%[[ARG0]], %[[ARG1]], %[[OP1]], %[[OP2]]
-//       CANONICALIZE:   return %[[OP3]]
+//       CHECK:   %[[OP3:.+]] = linalg.generic {{.+}} ins(%[[ARG0]], %[[ARG1]], %[[OP1]], %[[OP2]]
+//       CHECK:   return %[[OP3]]
index 6b66cbe..e0743db 100644 (file)
@@ -22,8 +22,8 @@ struct TestLinalgDecomposeOps
     : public PassWrapper<TestLinalgDecomposeOps, OperationPass<>> {
   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestLinalgDecomposeOps)
 
-  TestLinalgDecomposeOps() = default;
-  TestLinalgDecomposeOps(const TestLinalgDecomposeOps &pass) = default;
+  TestLinalgDecomposeOps(){};
+  TestLinalgDecomposeOps(const TestLinalgDecomposeOps &pass){};
   void getDependentDialects(DialectRegistry &registry) const override {
     registry.insert<AffineDialect, linalg::LinalgDialect>();
   }
@@ -32,10 +32,16 @@ struct TestLinalgDecomposeOps
     return "Test Linalg decomposition patterns";
   }
 
+  Option<bool> removeDeadArgsAndResults{
+      *this, "remove-dead-args-and-results",
+      llvm::cl::desc("Test patterns to erase unused operands and results"),
+      llvm::cl::init(false)};
+
   void runOnOperation() override {
     MLIRContext *context = &this->getContext();
     RewritePatternSet decompositionPatterns(context);
-    linalg::populateDecomposeLinalgOpsPattern(decompositionPatterns);
+    linalg::populateDecomposeLinalgOpsPattern(decompositionPatterns,
+                                              removeDeadArgsAndResults);
     if (failed(applyPatternsAndFoldGreedily(
             getOperation(), std::move(decompositionPatterns)))) {
       return signalPassFailure();
index 7313c30..92ee447 100644 (file)
@@ -109,6 +109,10 @@ struct TestLinalgTransforms
       llvm::cl::desc(
           "Test patterns to swap tensor.extract_slice(linalg.fill())"),
       llvm::cl::init(false)};
+  Option<bool> testEraseUnusedOperandsAndResults{
+      *this, "test-erase-unused-operands-and-results",
+      llvm::cl::desc("Test patterns to erase unused operands and results"),
+      llvm::cl::init(false)};
 };
 } // namespace
 
@@ -175,6 +179,12 @@ static void applySwapExtractSliceWithFillPattern(func::FuncOp funcOp) {
   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
 }
 
+static void applyEraseUnusedOperandsAndResultsPatterns(func::FuncOp funcOp) {
+  RewritePatternSet patterns(funcOp.getContext());
+  populateEraseUnusedOperandsAndResultsPatterns(patterns);
+  (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
+}
+
 /// Apply transformations specified as patterns.
 void TestLinalgTransforms::runOnOperation() {
   if (testPatterns)
@@ -193,6 +203,8 @@ void TestLinalgTransforms::runOnOperation() {
     return applyBubbleUpExtractSliceOpPattern(getOperation());
   if (testSwapExtractSliceWithFill)
     return applySwapExtractSliceWithFillPattern(getOperation());
+  if (testEraseUnusedOperandsAndResults)
+    return applyEraseUnusedOperandsAndResultsPatterns(getOperation());
 }
 
 namespace mlir {