Refactor the dialect conversion framework to support multi-level conversions. Multi...
authorRiver Riddle <riverriddle@google.com>
Mon, 3 Jun 2019 19:49:55 +0000 (12:49 -0700)
committerMehdi Amini <joker.eph@gmail.com>
Tue, 4 Jun 2019 02:27:02 +0000 (19:27 -0700)
To accomplish this, moving forward users will need to provide a legalization target that defines what operations are legal for the conversion. A target can mark an operation as legal by providing a specific legalization action. The initial actions are:
* Legal
  - This action signals that every instance of the given operation is legal,
    i.e. any combination of attributes, operands, types, etc. is valid.
* Dynamic
  - This action signals that only some instances of a given operation are legal. This
    allows for defining fine-tune constraints, like say std.add is only legal when
    operating on 32-bit integers.

An example target is shown below:
struct MyTarget : public ConversionTarget {
  MyTarget(MLIRContext &ctx) : ConversionTarget(ctx) {
    // All operations in the LLVM dialect are legal.
    addLegalDialect<LLVMDialect>();

    // std.constant op is always legal on this target.
    addLegalOp<ConstantOp>();

    // std.return op has dynamic legality constraints.
    addDynamicallyLegalOp<ReturnOp>();
  }

  /// Implement the custom legalization handler to handle
  /// std.return.
  bool isLegal(Operation *op) override {
    // Process the dynamic handling for a std.return op (and any others that were
    // marked "dynamic").
    ...
  }
};

PiperOrigin-RevId: 251289374

mlir/examples/Linalg/Linalg1/lib/ConvertToLLVMDialect.cpp
mlir/examples/Linalg/Linalg3/lib/ConvertToLLVMDialect.cpp
mlir/examples/toy/Ch5/mlir/EarlyLowering.cpp
mlir/examples/toy/Ch5/mlir/LateLowering.cpp
mlir/include/mlir/Transforms/DialectConversion.h
mlir/lib/LLVMIR/Transforms/ConvertToLLVMDialect.cpp
mlir/lib/Linalg/Transforms/LowerToLLVMDialect.cpp
mlir/lib/Transforms/DialectConversion.cpp

index 84234c3..8cd970c 100644 (file)
@@ -418,7 +418,10 @@ void linalg::convertToLLVM(mlir::Module &module) {
   populateStdToLLVMConversionPatterns(converter, patterns);
   populateLinalg1ToLLVMConversionPatterns(patterns, module.getContext());
 
-  auto r = applyConversionPatterns(module, converter, std::move(patterns));
+  ConversionTarget target(*module.getContext());
+  target.addLegalDialects<LLVM::LLVMDialect>();
+  auto r =
+      applyConversionPatterns(module, target, converter, std::move(patterns));
   (void)r;
   assert(succeeded(r) && "conversion failed");
 }
index db9f496..60fdf60 100644 (file)
@@ -29,6 +29,7 @@
 #include "mlir/LLVMIR/LLVMLowering.h"
 #include "mlir/LLVMIR/Transforms.h"
 #include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/LowerAffine.h"
 
 #include "linalg1/ConvertToLLVMDialect.h"
 #include "linalg1/LLVMIntrinsics.h"
@@ -145,12 +146,12 @@ static void populateLinalg3ToLLVMConversionPatterns(
 }
 
 void linalg::convertLinalg3ToLLVM(Module &module) {
-  // Remove affine constructs if any by using an existing pass.
-  PassManager pm;
-  pm.addPass(createLowerAffinePass());
-  auto rr = pm.run(&module);
-  (void)rr;
-  assert(succeeded(rr) && "affine loop lowering failed");
+  // Remove affine constructs.
+  for (auto &func : module) {
+    auto rr = lowerAffineConstructs(func);
+    (void)rr;
+    assert(succeeded(rr) && "affine loop lowering failed");
+  }
 
   // Convert Linalg ops to the LLVM IR dialect using the converter defined
   // above.
@@ -160,7 +161,10 @@ void linalg::convertLinalg3ToLLVM(Module &module) {
   populateLinalg1ToLLVMConversionPatterns(patterns, module.getContext());
   populateLinalg3ToLLVMConversionPatterns(patterns, module.getContext());
 
-  auto r = applyConversionPatterns(module, converter, std::move(patterns));
+  ConversionTarget target(*module.getContext());
+  target.addLegalDialects<LLVM::LLVMDialect>();
+  auto r =
+      applyConversionPatterns(module, target, converter, std::move(patterns));
   (void)r;
   assert(succeeded(r) && "conversion failed");
 }
index f4ac522..45d608d 100644 (file)
@@ -27,6 +27,7 @@
 
 #include "toy/Dialect.h"
 
+#include "linalg1/Dialect.h"
 #include "linalg1/Intrinsics.h"
 #include "linalg1/ViewOp.h"
 #include "linalg3/TensorOps.h"
@@ -124,9 +125,14 @@ public:
 /// dialect.
 struct EarlyLoweringPass : public FunctionPass<EarlyLoweringPass> {
   void runOnFunction() override {
+    ConversionTarget target(getContext());
+    target.addLegalDialects<linalg::LinalgDialect, StandardOpsDialect>();
+    target.addLegalOp<toy::AllocOp, toy::TypeCastOp>();
+
     OwningRewritePatternList patterns;
     RewriteListBuilder<MulOpConversion>::build(patterns, &getContext());
-    if (failed(applyConversionPatterns(getFunction(), std::move(patterns)))) {
+    if (failed(applyConversionPatterns(getFunction(), target,
+                                       std::move(patterns)))) {
       getContext().emitError(mlir::UnknownLoc::get(&getContext()),
                              "Error lowering Toy\n");
       signalPassFailure();
index 611d716..d682d12 100644 (file)
@@ -24,6 +24,7 @@
 
 #include "toy/Dialect.h"
 
+#include "linalg1/Dialect.h"
 #include "linalg1/Intrinsics.h"
 #include "linalg1/ViewOp.h"
 #include "linalg3/ConvertToLLVMDialect.h"
@@ -338,7 +339,11 @@ struct LateLoweringPass : public ModulePass<LateLoweringPass> {
                        ReturnOpConversion>::build(toyPatterns, &getContext());
 
     // Perform Toy specific lowering.
-    if (failed(applyConversionPatterns(getModule(), typeConverter,
+    ConversionTarget target(getContext());
+    target.addLegalDialects<AffineOpsDialect, linalg::LinalgDialect,
+                            LLVM::LLVMDialect, StandardOpsDialect>();
+    target.addLegalOp<toy::AllocOp, toy::TypeCastOp>();
+    if (failed(applyConversionPatterns(getModule(), target, typeConverter,
                                        std::move(toyPatterns)))) {
       getModule().getContext()->emitError(
           UnknownLoc::get(getModule().getContext()), "Error lowering Toy\n");
index 165e065..ac24252 100644 (file)
@@ -25,6 +25,7 @@
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Support/LLVM.h"
 #include "mlir/Support/LogicalResult.h"
+#include "llvm/ADT/MapVector.h"
 
 namespace mlir {
 
@@ -136,18 +137,108 @@ public:
       SmallVectorImpl<NamedAttributeList> &convertedArgAttrs);
 };
 
+/// This class describes a specific conversion target.
+class ConversionTarget {
+public:
+  /// This enumeration corresponds to the specific action to take when
+  /// considering an operation legal for this conversion target.
+  enum class LegalizationAction {
+    /// The target supports this operation.
+    Legal,
+
+    /// This operation has dynamic legalization constraints that must be checked
+    /// by the target.
+    Dynamic
+  };
+
+  /// The type used to store operation legality information.
+  using LegalityMapTy = llvm::MapVector<OperationName, LegalizationAction>;
+
+  ConversionTarget(MLIRContext &ctx) : ctx(ctx) {}
+  virtual ~ConversionTarget() = default;
+
+  /// Runs a custom legalization query for the given operation. This should
+  /// return true if the given operation is legal, otherwise false.
+  virtual bool isLegal(Operation *op) const {
+    llvm_unreachable(
+        "targets with custom legalization must override 'isLegal'");
+  }
+
+  /// Register a legality action for the given operation.
+  void setOpAction(OperationName op, LegalizationAction action) {
+    legalOperations[op] = action;
+  }
+  template <typename OpT> void setOpAction(LegalizationAction action) {
+    setOpAction(OperationName(OpT::getOperationName(), &ctx), action);
+  }
+
+  /// Register the given operations as legal.
+  template <typename OpT> void addLegalOp() {
+    setOpAction<OpT>(LegalizationAction::Legal);
+  }
+  template <typename OpT, typename OpT2, typename... OpTs> void addLegalOp() {
+    addLegalOp<OpT>();
+    addLegalOp<OpT2, OpTs...>();
+  }
+
+  /// Register the operations of the given dialects as legal.
+  void addLegalDialects(ArrayRef<StringRef> dialectNames);
+  template <typename... Names>
+  void addLegalDialects(StringRef name, Names... names) {
+    SmallVector<StringRef, 2> dialectNames({name, names...});
+    addLegalDialects(dialectNames);
+  }
+  template <typename... Args> void addLegalDialects() {
+    SmallVector<StringRef, 2> dialectNames({Args::getDialectNamespace()...});
+    addLegalDialects(dialectNames);
+  }
+
+  /// Register the given operation as dynamically legal, i.e. requiring custom
+  /// handling by the target via 'isLegal'.
+  template <typename OpT> void addDynamicallyLegalOp() {
+    setOpAction<OpT>(LegalizationAction::Dynamic);
+  }
+  template <typename OpT, typename OpT2, typename... OpTs>
+  void addDynamicallyLegalOp() {
+    addDynamicallyLegalOp<OpT>();
+    addDynamicallyLegalOp<OpT2, OpTs...>();
+  }
+
+  /// Get the legality action for the given operation.
+  llvm::Optional<LegalizationAction> getOpAction(OperationName op) const {
+    auto it = legalOperations.find(op);
+    if (it != legalOperations.end())
+      return it->second;
+    return llvm::None;
+  }
+
+  /// Returns a range of operations that this target has defined to be legal in
+  /// some capacity.
+  llvm::iterator_range<LegalityMapTy::const_iterator> getLegalOps() const {
+    return llvm::make_range(legalOperations.begin(), legalOperations.end());
+  }
+
+private:
+  /// A deterministic mapping of operation name to the specific legality action
+  /// to take.
+  LegalityMapTy legalOperations;
+
+  /// The current context this target applies to.
+  MLIRContext &ctx;
+};
+
 /// Convert the given module with the provided conversion patterns and type
 /// conversion object. If conversion fails for specific functions, those
 /// functions remains unmodified.
-LLVM_NODISCARD
-LogicalResult applyConversionPatterns(Module &module, TypeConverter &converter,
-                                      OwningRewritePatternList &&patterns);
+LLVM_NODISCARD LogicalResult applyConversionPatterns(
+    Module &module, ConversionTarget &target, TypeConverter &converter,
+    OwningRewritePatternList &&patterns);
 
 /// Convert the given function with the provided conversion patterns. This will
 /// convert as many of the operations within 'fn' as possible given the set of
 /// patterns.
 LLVM_NODISCARD
-LogicalResult applyConversionPatterns(Function &fn,
+LogicalResult applyConversionPatterns(Function &fn, ConversionTarget &target,
                                       OwningRewritePatternList &&patterns);
 
 } // end namespace mlir
index a9717e2..0e30a8e 100644 (file)
@@ -986,7 +986,11 @@ struct LLVMLoweringPass : public ModulePass<LLVMLoweringPass> {
     LLVMTypeConverter converter(&getContext());
     OwningRewritePatternList patterns;
     populateStdToLLVMConversionPatterns(converter, patterns);
-    if (failed(applyConversionPatterns(m, converter, std::move(patterns))))
+
+    ConversionTarget target(getContext());
+    target.addLegalDialects<LLVM::LLVMDialect>();
+    if (failed(
+            applyConversionPatterns(m, target, converter, std::move(patterns))))
       signalPassFailure();
   }
 };
index c686af8..60c0daf 100644 (file)
@@ -677,7 +677,10 @@ void LowerLinalgToLLVMPass::runOnModule() {
   populateStdToLLVMConversionPatterns(converter, patterns);
   populateLinalgToLLVMConversionPatterns(converter, patterns, &getContext());
 
-  if (failed(applyConversionPatterns(module, converter, std::move(patterns))))
+  ConversionTarget target(getContext());
+  target.addLegalDialects<LLVM::LLVMDialect>();
+  if (failed(applyConversionPatterns(module, target, converter,
+                                     std::move(patterns))))
     signalPassFailure();
 }
 
index b6a15f6..4c110d1 100644 (file)
 #include "mlir/IR/Function.h"
 #include "mlir/IR/Module.h"
 #include "mlir/Transforms/Utils.h"
+#include "llvm/ADT/SetVector.h"
+#include "llvm/ADT/SmallPtrSet.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/raw_ostream.h"
 
 using namespace mlir;
-using namespace mlir::impl;
+
+#define DEBUG_TYPE "dialect-conversion"
 
 //===----------------------------------------------------------------------===//
 // ArgConverter
@@ -136,7 +141,8 @@ struct DialectConversionRewriter final : public PatternRewriter {
     assert(newValues.size() == op->getNumResults());
     // Create mappings for any type changes.
     for (unsigned i = 0, e = newValues.size(); i < e; ++i)
-      if (op->getResult(i)->getType() != newValues[i]->getType())
+      if (newValues[i] &&
+          op->getResult(i)->getType() != newValues[i]->getType())
         mapping.map(op->getResult(i), newValues[i]);
 
     // Record the requested operation replacement.
@@ -223,17 +229,235 @@ void ConversionPattern::rewrite(Operation *op,
 }
 
 //===----------------------------------------------------------------------===//
+// ConversionTarget
+//===----------------------------------------------------------------------===//
+
+/// Register the operations of the given dialects as legal.
+void ConversionTarget::addLegalDialects(ArrayRef<StringRef> dialectNames) {
+  SmallPtrSet<Dialect *, 2> dialects;
+  for (auto dialectName : dialectNames)
+    if (auto *dialect = ctx.getRegisteredDialect(dialectName))
+      dialects.insert(dialect);
+
+  // Set all dialect operations as legal.
+  for (auto op : ctx.getRegisteredOperations())
+    if (dialects.count(&op->dialect))
+      setOpAction(op, LegalizationAction::Legal);
+}
+
+//===----------------------------------------------------------------------===//
+// OperationLegalizer
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// This class represents the information necessary for legalizing an operation
+/// kind.
+struct OpLegality {
+  /// This is the legalization action specified by the target, if it provided
+  /// one.
+  llvm::Optional<ConversionTarget::LegalizationAction> targetAction;
+
+  /// The set of patterns to apply to an instance of this operation to legalize
+  /// it.
+  SmallVector<RewritePattern *, 1> patterns;
+};
+
+/// This class defines a recursive operation legalizer.
+class OperationLegalizer {
+public:
+  OperationLegalizer(ConversionTarget &targetInfo,
+                     OwningRewritePatternList &patterns)
+      : target(targetInfo) {
+    buildLegalizationGraph(patterns);
+  }
+
+  /// Attempt to legalize the given operation. Returns success if the operation
+  /// was legalized, failure otherwise.
+  LogicalResult legalize(Operation *op, DialectConversionRewriter &rewriter);
+
+private:
+  /// Attempt to legalize the given operation by applying the provided pattern.
+  /// Returns success if the operation was legalized, failure otherwise.
+  LogicalResult legalizePattern(Operation *op, RewritePattern *pattern,
+                                DialectConversionRewriter &rewriter);
+
+  /// Build an optimistic legalization graph given the provided patterns. This
+  /// function populates 'legalOps' with the operations that are either legal,
+  /// or transitively legal for the current target given the provided patterns.
+  void buildLegalizationGraph(OwningRewritePatternList &patterns);
+
+  /// The current set of patterns that have been applied.
+  llvm::SmallPtrSet<RewritePattern *, 8> appliedPatterns;
+
+  /// The set of legality information for operations transitively supported by
+  /// the target.
+  DenseMap<OperationName, OpLegality> legalOps;
+
+  /// The legalization information provided by the target.
+  ConversionTarget &target;
+};
+} // namespace
+
+LogicalResult
+OperationLegalizer::legalize(Operation *op,
+                             DialectConversionRewriter &rewriter) {
+  LLVM_DEBUG(llvm::dbgs() << "Legalizing operation : " << op->getName()
+                          << "\n");
+
+  auto it = legalOps.find(op->getName());
+  if (it == legalOps.end()) {
+    LLVM_DEBUG(llvm::dbgs() << "-- FAIL : no known legalization path.\n");
+    return failure();
+  }
+
+  // Check if this was marked legal by the target.
+  auto &opInfo = it->second;
+  if (auto action = opInfo.targetAction) {
+    // Check if this operation is always legal.
+    if (*action == ConversionTarget::LegalizationAction::Legal)
+      return success();
+
+    // Otherwise, handle custom legalization.
+    LLVM_DEBUG(llvm::dbgs() << "- Trying dynamic legalization.\n");
+    if (target.isLegal(op))
+      return success();
+
+    // Fallthough to see if a pattern can convert this into a legal operation.
+  }
+
+  // Otherwise, we need to apply a legalization pattern to this operation.
+  // TODO(riverriddle) This currently has no cost model and doesn't prioritize
+  // specific patterns in any way.
+  for (auto *pattern : opInfo.patterns)
+    if (succeeded(legalizePattern(op, pattern, rewriter)))
+      return success();
+
+  LLVM_DEBUG(llvm::dbgs() << "-- FAIL : no matched legalization pattern.\n");
+  return failure();
+}
+
+LogicalResult
+OperationLegalizer::legalizePattern(Operation *op, RewritePattern *pattern,
+                                    DialectConversionRewriter &rewriter) {
+  LLVM_DEBUG({
+    llvm::dbgs() << "-* Applying rewrite pattern '" << op->getName() << " -> (";
+    interleaveComma(pattern->getGeneratedOps(), llvm::dbgs());
+    llvm::dbgs() << ")'.\n";
+  });
+
+  // Ensure that we don't cycle by not allowing the same pattern to be
+  // applied twice in the same recursion stack.
+  // TODO(riverriddle) We could eventually converge, but that requires more
+  // complicated analysis.
+  if (!appliedPatterns.insert(pattern).second) {
+    LLVM_DEBUG(llvm::dbgs() << "-- FAIL: Pattern was already applied.\n");
+    return failure();
+  }
+
+  auto curOpCount = rewriter.createdOps.size();
+  auto curReplCount = rewriter.replacements.size();
+  auto cleanupFailure = [&] {
+    // Pop all of the newly created operations and replacements.
+    while (rewriter.createdOps.size() != curOpCount)
+      rewriter.createdOps.pop_back_val()->erase();
+    rewriter.replacements.resize(curReplCount);
+    appliedPatterns.erase(pattern);
+    return failure();
+  };
+
+  // Try to rewrite with the given pattern.
+  rewriter.setInsertionPoint(op);
+  if (!pattern->matchAndRewrite(op, rewriter)) {
+    LLVM_DEBUG(llvm::dbgs() << "-- FAIL: Pattern failed to match.\n");
+    return cleanupFailure();
+  }
+
+  // Recursively legalize each of the new operations.
+  for (unsigned i = curOpCount, e = rewriter.createdOps.size(); i != e; ++i) {
+    if (succeeded(legalize(rewriter.createdOps[i], rewriter)))
+      continue;
+    LLVM_DEBUG(llvm::dbgs() << "-- FAIL: Generated operation was illegal.\n");
+    return cleanupFailure();
+  }
+
+  appliedPatterns.erase(pattern);
+  return success();
+}
+
+void OperationLegalizer::buildLegalizationGraph(
+    OwningRewritePatternList &patterns) {
+  // A mapping between an operation and a set of operations that can be used to
+  // generate it.
+  DenseMap<OperationName, SmallPtrSet<OperationName, 2>> parentOps;
+  // A mapping between an operation and any currently invalid patterns it has.
+  DenseMap<OperationName, SmallPtrSet<RewritePattern *, 2>> invalidPatterns;
+  // A worklist of patterns to consider for legality.
+  llvm::SetVector<RewritePattern *> patternWorklist;
+
+  // Collect the initial set of valid target ops.
+  for (auto &opInfoPair : target.getLegalOps())
+    legalOps[opInfoPair.first].targetAction = opInfoPair.second;
+
+  // Build the mapping from operations to the parent ops that may generate them.
+  for (auto &pattern : patterns) {
+    auto root = pattern->getRootKind();
+
+    // Skip operations that are known to always be legal.
+    auto it = legalOps.find(root);
+    if (it != legalOps.end() &&
+        it->second.targetAction == ConversionTarget::LegalizationAction::Legal)
+      continue;
+
+    // Add this pattern to the invalid set for the root op and record this root
+    // as a parent for any generated operations.
+    invalidPatterns[root].insert(pattern.get());
+    for (auto op : pattern->getGeneratedOps())
+      parentOps[op].insert(root);
+
+    // If this pattern doesn't generate any operations, optimistically add it to
+    // the worklist.
+    if (pattern->getGeneratedOps().empty())
+      patternWorklist.insert(pattern.get());
+  }
+
+  // Build the initial worklist with the patterns that generate operations that
+  // are known to be legal.
+  for (auto &opInfoPair : target.getLegalOps())
+    for (auto &parentOp : parentOps[opInfoPair.first])
+      patternWorklist.set_union(invalidPatterns[parentOp]);
+
+  while (!patternWorklist.empty()) {
+    auto *pattern = patternWorklist.pop_back_val();
+
+    // Check to see if any of the generated operations are invalid.
+    if (llvm::any_of(pattern->getGeneratedOps(),
+                     [&](OperationName op) { return !legalOps.count(op); }))
+      continue;
+
+    // Otherwise, if all of the generated operation are valid, this op is now
+    // legal so add all of the child patterns to the worklist.
+    legalOps[pattern->getRootKind()].patterns.push_back(pattern);
+    invalidPatterns[pattern->getRootKind()].erase(pattern);
+
+    // Add any invalid patterns of the parent operations to see if they have now
+    // become legal.
+    for (auto op : parentOps[pattern->getRootKind()])
+      patternWorklist.set_union(invalidPatterns[op]);
+  }
+}
+
+//===----------------------------------------------------------------------===//
 // FunctionConverter
 //===----------------------------------------------------------------------===//
 namespace {
 // This class converts a single function using the given pattern matcher. If a
 // TypeConverter object is provided, then the types of block arguments will be
 // converted using the appropriate 'convertType' calls.
-class FunctionConverter {
-public:
-  explicit FunctionConverter(MLIRContext *ctx, RewritePatternMatcher &matcher,
+struct FunctionConverter {
+  explicit FunctionConverter(MLIRContext *ctx, ConversionTarget &target,
+                             OwningRewritePatternList &patterns,
                              TypeConverter *conversion = nullptr)
-      : typeConverter(conversion), matcher(matcher) {}
+      : typeConverter(conversion), opLegalizer(target, patterns) {}
 
   /// Converts the given function to the dialect using hooks defined in
   /// `typeConverter`. Returns failure on error, success otherwise.
@@ -262,8 +486,8 @@ public:
   /// Pointer to a specific dialect conversion info.
   TypeConverter *typeConverter;
 
-  /// The matcher to use when converting operations.
-  RewritePatternMatcher &matcher;
+  /// The legalizer to use when converting operations.
+  OperationLegalizer opLegalizer;
 };
 } // end anonymous namespace
 
@@ -293,15 +517,14 @@ FunctionConverter::convertBlock(DialectConversionRewriter &rewriter,
 
   // Iterate over ops and convert them.
   for (Operation &op : llvm::make_early_inc_range(*block)) {
-    rewriter.setInsertionPoint(&op);
-    if (matcher.matchAndRewrite(&op, rewriter))
-      continue;
-
     // Traverse any held regions.
     for (auto &region : op.getRegions())
       if (!region.empty() &&
           failed(convertRegion(rewriter, region, op.getLoc())))
         return failure();
+
+    // Legalize the current operation.
+    (void)opLegalizer.legalize(&op, rewriter);
   }
 
   // Recurse to children that haven't been visited.
@@ -416,12 +639,12 @@ struct ConvertedFunction {
 /// conversion object. If conversion fails for specific functions, those
 /// functions remains unmodified.
 LogicalResult
-mlir::applyConversionPatterns(Module &module, TypeConverter &converter,
+mlir::applyConversionPatterns(Module &module, ConversionTarget &target,
+                              TypeConverter &converter,
                               OwningRewritePatternList &&patterns) {
-  // Grab the conversion patterns from the converter and create the pattern
-  // matcher.
-  MLIRContext *context = module.getContext();
-  RewritePatternMatcher matcher(std::move(patterns));
+  // Build the function converter.
+  FunctionConverter funcConverter(module.getContext(), target, patterns,
+                                  &converter);
 
   // Try to convert each of the functions within the module. Defer updating the
   // signatures of the functions until after all of the bodies have been
@@ -439,7 +662,6 @@ mlir::applyConversionPatterns(Module &module, TypeConverter &converter,
       return func.emitError("could not convert function type");
 
     // Convert the body of this function.
-    FunctionConverter funcConverter(context, matcher, &converter);
     if (failed(funcConverter.convertFunction(&func)))
       return failure();
 
@@ -461,10 +683,9 @@ mlir::applyConversionPatterns(Module &module, TypeConverter &converter,
 /// convert as many of the operations within 'fn' as possible given the set of
 /// patterns.
 LogicalResult
-mlir::applyConversionPatterns(Function &fn,
+mlir::applyConversionPatterns(Function &fn, ConversionTarget &target,
                               OwningRewritePatternList &&patterns) {
   // Convert the body of this function.
-  RewritePatternMatcher matcher(std::move(patterns));
-  FunctionConverter converter(fn.getContext(), matcher);
+  FunctionConverter converter(fn.getContext(), target, patterns);
   return converter.convertFunction(&fn);
 }