Refactor the conversion of block argument types in DialectConversion.
authorRiver Riddle <riverriddle@google.com>
Wed, 17 Jul 2019 21:45:53 +0000 (14:45 -0700)
committerMehdi Amini <joker.eph@gmail.com>
Fri, 19 Jul 2019 18:38:45 +0000 (11:38 -0700)
This cl begins a large refactoring over how signature types are converted in the DialectConversion infrastructure. The signatures of blocks are now converted on-demand when an operation held by that block is being converted. This allows for handling the case where a region is created as part of a pattern, something that wasn't possible previously.

This cl also generalizes the region signature conversion used by FuncOp to work on any region of any operation. This generalization allows for removing the 'apply*Conversion' functions that were specific to FuncOp/ModuleOp. The implementation currently uses a new hook on TypeConverter, 'convertRegionSignature', but this should ideally be removed in favor of using Patterns. That depends on adding support to the PatternRewriter used by ConversionPattern to allow applying signature conversions to regions, which should be coming in a followup.

PiperOrigin-RevId: 258645733

16 files changed:
mlir/examples/Linalg/Linalg1/lib/ConvertToLLVMDialect.cpp
mlir/examples/Linalg/Linalg3/lib/ConvertToLLVMDialect.cpp
mlir/examples/toy/Ch5/mlir/LateLowering.cpp
mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
mlir/include/mlir/IR/Block.h
mlir/include/mlir/IR/PatternMatch.h
mlir/include/mlir/IR/Region.h
mlir/include/mlir/Transforms/DialectConversion.h
mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp
mlir/lib/IR/Block.cpp
mlir/lib/IR/Region.cpp
mlir/lib/Linalg/Transforms/LowerToLLVMDialect.cpp
mlir/lib/Transforms/DialectConversion.cpp
mlir/test/Transforms/test-legalizer.mlir
mlir/test/lib/TestDialect/TestOps.td
mlir/test/lib/TestDialect/TestPatterns.cpp

index 989915e56f887770db6373875b4f59970553cfda..c43a2ae116343e854379cfe04d82175da1bb31e9 100644 (file)
@@ -421,7 +421,8 @@ LogicalResult linalg::convertToLLVM(mlir::ModuleOp module) {
 
   ConversionTarget target(*module.getContext());
   target.addLegalDialect<LLVM::LLVMDialect>();
-  return applyFullConversion(module, target, converter, std::move(patterns));
+  target.addLegalOp<FuncOp, ModuleOp, ModuleTerminatorOp>();
+  return applyFullConversion(module, target, std::move(patterns), &converter);
 }
 
 namespace {
index e4eaca84f4e348c5e1072c162630b5ea4184d4d7..c86f5d79f347722cc71e84e660ea6466b01e4530 100644 (file)
@@ -160,8 +160,9 @@ LogicalResult linalg::convertLinalg3ToLLVM(ModuleOp module) {
 
   ConversionTarget target(*module.getContext());
   target.addLegalDialect<LLVM::LLVMDialect>();
+  target.addLegalOp<FuncOp, ModuleOp, ModuleTerminatorOp>();
   if (failed(
-          applyFullConversion(module, target, converter, std::move(patterns))))
+          applyFullConversion(module, target, std::move(patterns), &converter)))
     return failure();
 
   return success();
index ebc81ef2be999da6bb38912285ea1cd690f87689..cd826fb25a520416981a1a94e3d8f76b42cecc33 100644 (file)
@@ -356,8 +356,8 @@ struct LateLoweringPass : public ModulePass<LateLoweringPass> {
     target.addLegalDialect<AffineOpsDialect, linalg::LinalgDialect,
                            LLVM::LLVMDialect, StandardOpsDialect>();
     target.addLegalOp<toy::AllocOp, toy::TypeCastOp>();
-    if (failed(applyPartialConversion(getModule(), target, typeConverter,
-                                      std::move(toyPatterns)))) {
+    if (failed(applyPartialConversion(
+            getModule(), target, std::move(toyPatterns), &typeConverter))) {
       emitError(UnknownLoc::get(getModule().getContext()),
                 "Error lowering Toy\n");
       signalPassFailure();
index 83b156b3a222049fd01a3cb4f2535a15c6546904..864d8057fff032e5d19b7bdae94a152363deb025 100644 (file)
@@ -67,7 +67,6 @@ protected:
   /// with multiple results into functions returning LLVM IR's structure type.
   /// Use `convertType` to convert individual argument and result types.
   LogicalResult convertSignature(FunctionType t,
-                                 ArrayRef<NamedAttributeList> argAttrs,
                                  SignatureConversion &result) final;
 
   /// LLVM IR module used to parse/create types.
index 9277266963938038fa4603a269e923eebfdbaf62..50cca52b7ab8f7b07c07f125ca31c8a2bd813a1e 100644 (file)
@@ -99,6 +99,9 @@ public:
   /// nullptr if this is a top-level block.
   Operation *getContainingOp();
 
+  /// Return if this block is the entry block in the parent region.
+  bool isEntryBlock();
+
   /// Insert this block (which must not already be in a function) right before
   /// the specified block.
   void insertBefore(Block *block);
index 97efae159797c75b196c4c60fdedceff6d228bbf..de68e4b38dad6cd9683eb8a60303a39e7ec6f36f 100644 (file)
@@ -302,6 +302,10 @@ public:
     return OpTy();
   }
 
+  /// This is implemented to create the specified operations and serves as a
+  /// notification hook for rewriters that want to know about new operations.
+  virtual Operation *createOperation(const OperationState &state) = 0;
+
   /// Move the blocks that belong to "region" before the given position in
   /// another region "parent".  The two regions must be different.  The caller
   /// is responsible for creating or updating the operation transferring flow
@@ -362,10 +366,6 @@ protected:
   // These are the callback methods that subclasses can choose to implement if
   // they would like to be notified about certain types of mutations.
 
-  /// This is implemented to create the specified operations and serves as a
-  /// notification hook for rewriters that want to know about new operations.
-  virtual Operation *createOperation(const OperationState &state) = 0;
-
   /// Notify the pattern rewriter that the specified operation has been mutated
   /// in place.  This is called after the mutation is done.
   virtual void notifyRootUpdated(Operation *op) {}
index d6ca2188621e5a12ec609c98e8689e28348b7e41..5f21226cd29a6c7a53c7ae86c31ffb51ca093db1 100644 (file)
@@ -85,6 +85,9 @@ public:
     return ParentT();
   }
 
+  /// Return the number of this region in the parent operation.
+  unsigned getRegionNumber();
+
   /// Return true if this region is a proper ancestor of the `other` region.
   bool isProperAncestor(Region *other);
 
index 994467933e7c4522a258ec5b3d986894458975b1..bfe367417148dd9eae8893fbdea3b3ab49adb591 100644 (file)
@@ -47,8 +47,8 @@ class TypeConverter {
 public:
   virtual ~TypeConverter() = default;
 
-  /// This class provides all of the information necessary to convert a
-  /// FunctionType signature.
+  /// This class provides all of the information necessary to convert a type
+  /// signature.
   class SignatureConversion {
   public:
     SignatureConversion(unsigned numOrigInputs)
@@ -71,11 +71,6 @@ public:
     /// Return the result types for the new signature.
     ArrayRef<Type> getConvertedResultTypes() const { return resultTypes; }
 
-    /// Returns the attributes for the arguments of the new signature.
-    ArrayRef<NamedAttributeList> getConvertedArgAttrs() const {
-      return argAttrs;
-    }
-
     /// Get the input mapping for the given argument.
     llvm::Optional<InputMapping> getInputMapping(unsigned input) const {
       return remappedInputs[input];
@@ -90,13 +85,11 @@ public:
 
     /// Remap an input of the original signature with a new set of types. The
     /// new types are appended to the new signature conversion.
-    void addInputs(unsigned origInputNo, ArrayRef<Type> types,
-                   ArrayRef<NamedAttributeList> attrs = llvm::None);
+    void addInputs(unsigned origInputNo, ArrayRef<Type> types);
 
     /// Append new input types to the signature conversion, this should only be
     /// used if the new types are not intended to remap an existing input.
-    void addInputs(ArrayRef<Type> types,
-                   ArrayRef<NamedAttributeList> attrs = llvm::None);
+    void addInputs(ArrayRef<Type> types);
 
     /// Remap an input of the original signature with a range of types in the
     /// new signature.
@@ -109,9 +102,6 @@ public:
 
     /// The set of argument and results types.
     SmallVector<Type, 4> argTypes, resultTypes;
-
-    /// The set of attributes for each new argument type.
-    SmallVector<NamedAttributeList, 4> argAttrs;
   };
 
   /// This hooks allows for converting a type. This function should return
@@ -126,31 +116,44 @@ public:
 
   /// Convert the given FunctionType signature. This functions returns a valid
   /// SignatureConversion on success, None otherwise.
-  llvm::Optional<SignatureConversion>
-  convertSignature(FunctionType type, ArrayRef<NamedAttributeList> argAttrs);
-  llvm::Optional<SignatureConversion> convertSignature(FunctionType type) {
-    SmallVector<NamedAttributeList, 4> argAttrs(type.getNumInputs());
-    return convertSignature(type, argAttrs);
-  }
+  llvm::Optional<SignatureConversion> convertSignature(FunctionType type);
 
   /// This hook allows for changing a FunctionType signature. This function
-  /// should populate 'result' with the new arguments and result on success,
+  /// should populate 'result' with the new arguments and results on success,
   /// otherwise return failure.
   ///
   /// The default behavior of this function is to call 'convertType' on
-  /// individual function operands and results. Any argument attributes are
-  /// dropped if the resultant conversion is not a 1->1 mapping.
+  /// individual function operands and results.
   virtual LogicalResult convertSignature(FunctionType type,
-                                         ArrayRef<NamedAttributeList> argAttrs,
                                          SignatureConversion &result);
 
   /// This hook allows for converting a specific argument of a signature. It
-  /// takes as inputs the original argument input number, type, and attributes.
+  /// takes as inputs the original argument input number, type.
   /// On success, this function should populate 'result' with any new mappings.
   virtual LogicalResult convertSignatureArg(unsigned inputNo, Type type,
-                                            NamedAttributeList attrs,
                                             SignatureConversion &result);
 
+  /// This hook allows for converting the signature of a region 'regionIdx',
+  /// i.e. the signature of the entry to the region, on the given operation
+  /// 'op'. This function should return a valid conversion for the signature on
+  /// success, None otherwise. This hook is allowed to modify the attributes on
+  /// the provided operation if necessary.
+  ///
+  /// The default behavior of this function is to invoke 'convertBlockSignature'
+  /// on the entry block, if one is present. This function also provides special
+  /// handling for FuncOp to update the type signature.
+  ///
+  /// TODO(riverriddle) This should be replaced in favor of using patterns, but
+  /// the pattern rewriter needs to know how to properly replace/remap
+  /// arguments.
+  virtual llvm::Optional<SignatureConversion>
+  convertRegionSignature(Operation *op, unsigned regionIdx);
+
+  /// This function converts the type signature of the given block, by invoking
+  /// 'convertSignatureArg' for each argument. This function should return a
+  /// valid conversion for the signature on success, None otherwise.
+  llvm::Optional<SignatureConversion> convertBlockSignature(Block *block);
+
   /// This hook allows for materializing a conversion from a set of types into
   /// one result type by generating a cast operation of some kind. The generated
   /// operation should produce one result, of 'resultType', with the provided
@@ -381,54 +384,26 @@ private:
 /// operations. This method converts as many operations to the target as
 /// possible, ignoring operations that failed to legalize. This method only
 /// returns failure if there are unreachable blocks in any of the regions nested
-/// within 'ops'.
-LLVM_NODISCARD LogicalResult
-applyPartialConversion(ArrayRef<Operation *> ops, ConversionTarget &target,
-                       OwningRewritePatternList &&patterns);
-LLVM_NODISCARD LogicalResult
-applyPartialConversion(Operation *op, ConversionTarget &target,
-                       OwningRewritePatternList &&patterns);
+/// within 'ops'. If 'converter' is provided, the signatures of blocks and
+/// regions are also converted.
+LLVM_NODISCARD LogicalResult applyPartialConversion(
+    ArrayRef<Operation *> ops, ConversionTarget &target,
+    OwningRewritePatternList &&patterns, TypeConverter *converter = nullptr);
+LLVM_NODISCARD LogicalResult applyPartialConversion(
+    Operation *op, ConversionTarget &target,
+    OwningRewritePatternList &&patterns, TypeConverter *converter = nullptr);
 
 /// Apply a complete conversion on the given operations, and all nested
 /// operations. This method returns failure if the conversion of any operation
 /// fails, or if there are unreachable blocks in any of the regions nested
-/// within 'ops'.
-LLVM_NODISCARD LogicalResult
-applyFullConversion(ArrayRef<Operation *> ops, ConversionTarget &target,
-                    OwningRewritePatternList &&patterns);
-LLVM_NODISCARD LogicalResult
-applyFullConversion(Operation *op, ConversionTarget &target,
-                    OwningRewritePatternList &&patterns);
-
-//===----------------------------------------------------------------------===//
-// Op + Type Conversion Entry Points
-//===----------------------------------------------------------------------===//
-
-/// Apply a partial conversion on the function operations within the given
-/// module. This method returns failure if a type conversion was encountered.
-LLVM_NODISCARD LogicalResult applyPartialConversion(
-    ModuleOp module, ConversionTarget &target, TypeConverter &converter,
-    OwningRewritePatternList &&patterns);
-
-/// Apply a partial conversion on the given function operations. This method
-/// returns failure if a type conversion was encountered.
-LLVM_NODISCARD LogicalResult applyPartialConversion(
-    MutableArrayRef<FuncOp> fns, ConversionTarget &target,
-    TypeConverter &converter, OwningRewritePatternList &&patterns);
-
-/// Apply a full conversion on the function operations within the given
-/// module. This method returns failure if a type conversion was encountered, or
-/// if the conversion of any operations failed.
+/// within 'ops'. If 'converter' is provided, the signatures of blocks and
+/// regions are also converted.
 LLVM_NODISCARD LogicalResult applyFullConversion(
-    ModuleOp module, ConversionTarget &target, TypeConverter &converter,
-    OwningRewritePatternList &&patterns);
-
-/// Apply a partial conversion on the given function operations. This method
-/// returns failure if a type conversion was encountered, or if the conversion
-/// of any operation failed.
+    ArrayRef<Operation *> ops, ConversionTarget &target,
+    OwningRewritePatternList &&patterns, TypeConverter *converter = nullptr);
 LLVM_NODISCARD LogicalResult applyFullConversion(
-    MutableArrayRef<FuncOp> fns, ConversionTarget &target,
-    TypeConverter &converter, OwningRewritePatternList &&patterns);
+    Operation *op, ConversionTarget &target,
+    OwningRewritePatternList &&patterns, TypeConverter *converter = nullptr);
 } // end namespace mlir
 
 #endif // MLIR_TRANSFORMS_DIALECTCONVERSION_H_
index bc5cbfa934d32dfc6df4d908087c21abd298d421..aa72a7b1d48c47233c025afd5212a455f9f03917 100644 (file)
@@ -1006,13 +1006,11 @@ Type LLVMTypeConverter::packFunctionResults(ArrayRef<Type> types) {
 }
 
 // Convert function signatures using the stored LLVM IR module.
-LogicalResult
-LLVMTypeConverter::convertSignature(FunctionType type,
-                                    ArrayRef<NamedAttributeList> argAttrs,
-                                    SignatureConversion &result) {
+LogicalResult LLVMTypeConverter::convertSignature(FunctionType type,
+                                                  SignatureConversion &result) {
   // Convert the original function arguments.
   for (unsigned i = 0, e = type.getNumInputs(); i != e; ++i)
-    if (failed(convertSignatureArg(i, type.getInput(i), argAttrs[i], result)))
+    if (failed(convertSignatureArg(i, type.getInput(i), result)))
       return failure();
 
   // If function does not return anything, return immediately.
@@ -1064,8 +1062,8 @@ struct LLVMLoweringPass : public ModulePass<LLVMLoweringPass> {
 
     ConversionTarget target(getContext());
     target.addLegalDialect<LLVM::LLVMDialect>();
-    if (failed(applyPartialConversion(m, target, *typeConverter,
-                                      std::move(patterns))))
+    if (failed(applyPartialConversion(m, target, std::move(patterns),
+                                      typeConverter.get())))
       signalPassFailure();
   }
 
index 93f5fe6e976769db06450e3e88a9d11c912fd5f8..efa76548a8883574e9cdf4ba090ddda3b8bd921f 100644 (file)
@@ -39,6 +39,10 @@ Block::~Block() {
   assert(!verifyInstOrder() && "Expected valid operation ordering.");
   clear();
 
+  for (auto *arg : arguments)
+    if (!arg->use_empty())
+      arg->user_begin()->dump();
+
   llvm::DeleteContainerPointers(arguments);
 }
 
@@ -50,6 +54,9 @@ Operation *Block::getContainingOp() {
   return getParent() ? getParent()->getContainingOp() : nullptr;
 }
 
+/// Return if this block is the entry block in the parent region.
+bool Block::isEntryBlock() { return this == &getParent()->front(); }
+
 /// Insert this block (which must not already be in a region) right before the
 /// specified block.
 void Block::insertBefore(Block *block) {
index 54deb8ac19737c7ad3060710d58a9abc75c716ea..551d59ca96f3cb88aa2cc5baf9b4f7f54766c643 100644 (file)
@@ -60,6 +60,13 @@ bool Region::isProperAncestor(Region *other) {
   return false;
 }
 
+/// Return the number of this region in the parent operation.
+unsigned Region::getRegionNumber() {
+  // Regions are always stored consecutively, so use pointer subtraction to
+  // figure out what number this is.
+  return this - &getContainingOp()->getRegions()[0];
+}
+
 /// Clone the internal blocks from this region into `dest`. Any
 /// cloned blocks are appended to the back of dest.
 void Region::cloneInto(Region *dest, BlockAndValueMapping &mapper) {
index c7ea50fa59b80c0f501ae646526a416efbe1f0d4..fb26f855c46fe40e5b143cd20fb62e1a77cffd87 100644 (file)
@@ -763,8 +763,8 @@ void LowerLinalgToLLVMPass::runOnModule() {
 
   ConversionTarget target(getContext());
   target.addLegalDialect<LLVM::LLVMDialect>();
-  if (failed(applyPartialConversion(module, target, converter,
-                                    std::move(patterns)))) {
+  if (failed(applyPartialConversion(module, target, std::move(patterns),
+                                    &converter))) {
     signalPassFailure();
   }
 
index 806eb6461e76c3362c6e7ea73813266d7fe10bf1..3ef9766cd8671bf9e65d1db0a6112737d6bea3a9 100644 (file)
@@ -47,21 +47,29 @@ struct ArgConverter {
 
   /// Erase any rewrites registered for arguments to blocks within the given
   /// region. This function is called when the given region is to be destroyed.
-  void cancelPendingRewrites(Region &region);
+  void cancelPendingRewrites(Block *block);
 
-  /// Cleanup and undo any generated conversion values.
-  void discardRewrites();
+  /// Cleanup and undo any generated conversions for the arguments of block.
+  /// This method differs from 'cancelPendingRewrites' in that it returns the
+  /// block signature to its original state.
+  void discardPendingRewrites(Block *block);
 
   /// Replace usages of the cast operations with the argument directly.
   void applyRewrites();
 
-  /// Converts the signature of the given entry block.
-  void convertSignature(Block *block,
-                        TypeConverter::SignatureConversion &signatureConversion,
-                        BlockAndValueMapping &mapping);
+  /// Return if the signature of the given block has already been converted.
+  bool hasBeenConverted(Block *block) const { return argMapping.count(block); }
 
-  /// Converts the arguments of the given block.
-  LogicalResult convertArguments(Block *block, BlockAndValueMapping &mapping);
+  /// Attempt to convert the signature of the given region.
+  LogicalResult convertSignature(Region &region, BlockAndValueMapping &mapping);
+
+  /// Attempt to convert the signature of the given block.
+  LogicalResult convertSignature(Block *block, BlockAndValueMapping &mapping);
+
+  /// Apply the given signature conversion on the given block.
+  void applySignatureConversion(
+      Block *block, TypeConverter::SignatureConversion &signatureConversion,
+      BlockAndValueMapping &mapping);
 
   /// Convert the given block argument given the provided set of new argument
   /// values that are to replace it. This function returns the operation used
@@ -97,56 +105,44 @@ struct ArgConverter {
 
 constexpr StringLiteral ArgConverter::kCastName;
 
-/// Erase any rewrites registered for arguments to blocks within the given
-/// region. This function is called when the given region is to be destroyed.
-void ArgConverter::cancelPendingRewrites(Region &region) {
-  for (auto &block : region) {
-    auto it = argMapping.find(&block);
-    if (it == argMapping.end())
-      continue;
-    for (auto *op : it->second) {
-      // If the operation exists within the parent block, like with 1->N cast
-      // operations, we don't need to drop them. They will be automatically
-      // cleaned up with the region is destroyed.
-      if (op->getBlock())
-        continue;
-
-      op->dropAllDefinedValueUses();
-      op->destroy();
-    }
-    argMapping.erase(it);
+/// Erase any rewrites registered for arguments to the given block.
+void ArgConverter::cancelPendingRewrites(Block *block) {
+  auto it = argMapping.find(block);
+  if (it == argMapping.end())
+    return;
+  for (auto *op : it->second) {
+    op->dropAllDefinedValueUses();
+    op->erase();
   }
+  argMapping.erase(it);
 }
 
-/// Cleanup and undo any generated conversion values.
-void ArgConverter::discardRewrites() {
-  // On failure reinstate all of the original block arguments.
-  Block *block;
-  ArrayRef<Operation *> argOps;
-  for (auto &mapping : argMapping) {
-    std::tie(block, argOps) = mapping;
+/// Cleanup and undo any generated conversions for the arguments of block.
+/// This method differs from 'cancelPendingRewrites' in that it returns the
+/// block signature to its original state.
+void ArgConverter::discardPendingRewrites(Block *block) {
+  auto it = argMapping.find(block);
+  if (it == argMapping.end())
+    return;
 
-    // Erase all of the new arguments.
-    for (int i = block->getNumArguments() - 1; i >= 0; --i) {
-      block->getArgument(i)->dropAllUses();
-      block->eraseArgument(i, /*updatePredTerms=*/false);
-    }
+  // Erase all of the new arguments.
+  for (int i = block->getNumArguments() - 1; i >= 0; --i) {
+    block->getArgument(i)->dropAllUses();
+    block->eraseArgument(i, /*updatePredTerms=*/false);
+  }
 
-    // Re-instate the old arguments.
-    for (unsigned i = 0, e = argOps.size(); i != e; ++i) {
-      auto *op = argOps[i];
-      auto *arg = block->addArgument(op->getResult(0)->getType());
-      op->getResult(0)->replaceAllUsesWith(arg);
+  // Re-instate the old arguments.
+  auto &mapping = it->second;
+  for (unsigned i = 0, e = mapping.size(); i != e; ++i) {
+    auto *op = mapping[i];
+    auto *arg = block->addArgument(op->getResult(0)->getType());
+    op->getResult(0)->replaceAllUsesWith(arg);
 
-      // If this was a 1->N value mapping it exists within the parent block so
-      // erase it instead of destroying.
-      if (op->getBlock())
-        op->erase();
-      else
-        op->destroy();
-    }
+    // If this operation is within a block, it will be cleaned up automatically.
+    if (!op->getBlock())
+      op->erase();
   }
-  argMapping.clear();
+  argMapping.erase(it);
 }
 
 /// Replace usages of the cast operations with the argument directly.
@@ -198,8 +194,29 @@ void ArgConverter::applyRewrites() {
   }
 }
 
+/// Converts the signature of the given region.
+LogicalResult ArgConverter::convertSignature(Region &region,
+                                             BlockAndValueMapping &mapping) {
+  if (auto conversion = typeConverter->convertRegionSignature(
+          region.getContainingOp(), region.getRegionNumber())) {
+    if (!region.empty())
+      applySignatureConversion(&region.front(), *conversion, mapping);
+    return success();
+  }
+  return failure();
+}
+
 /// Converts the signature of the given entry block.
-void ArgConverter::convertSignature(
+LogicalResult ArgConverter::convertSignature(Block *block,
+                                             BlockAndValueMapping &mapping) {
+  auto conversion = typeConverter->convertBlockSignature(block);
+  if (conversion)
+    return applySignatureConversion(block, *conversion, mapping), success();
+  return failure();
+}
+
+/// Apply the given signature conversion on the given block.
+void ArgConverter::applySignatureConversion(
     Block *block, TypeConverter::SignatureConversion &signatureConversion,
     BlockAndValueMapping &mapping) {
   unsigned origArgCount = block->getNumArguments();
@@ -228,37 +245,6 @@ void ArgConverter::convertSignature(
     block->eraseArgument(0, /*updatePredTerms=*/false);
 }
 
-/// Converts the arguments of the given block.
-LogicalResult ArgConverter::convertArguments(Block *block,
-                                             BlockAndValueMapping &mapping) {
-  unsigned origArgCount = block->getNumArguments();
-  if (origArgCount == 0 || argMapping.count(block))
-    return success();
-
-  // Convert the types of each of the block arguments.
-  SmallVector<SmallVector<Type, 1>, 4> newArgTypes(origArgCount);
-  for (unsigned i = 0; i != origArgCount; ++i) {
-    auto *arg = block->getArgument(i);
-    if (failed(typeConverter->convertType(arg->getType(), newArgTypes[i])))
-      return emitError(block->getParent()->getLoc())
-             << "could not convert block argument of type " << arg->getType();
-  }
-
-  // Remap all of the original argument values.
-  auto &newArgMapping = argMapping[block];
-  rewriter.setInsertionPointToStart(block);
-  for (unsigned i = 0; i != origArgCount; ++i) {
-    SmallVector<Value *, 1> newArgs(block->addArguments(newArgTypes[i]));
-    newArgMapping.push_back(
-        convertArgument(block->getArgument(i), newArgs, mapping));
-  }
-
-  // Erase all of the original arguments.
-  for (unsigned i = 0; i != origArgCount; ++i)
-    block->eraseArgument(0, /*updatePredTerms=*/false);
-  return success();
-}
-
 /// Convert the given block argument given the provided set of new argument
 /// values that are to replace it. This function returns the operation used
 /// to perform the conversion.
@@ -304,9 +290,10 @@ Operation *ArgConverter::createCast(ArrayRef<Value *> inputs, Type outputType) {
 /// This is useful when saving and undoing a set of rewrites.
 struct RewriterState {
   RewriterState(unsigned numCreatedOperations, unsigned numReplacements,
-                unsigned numBlockActions)
+                unsigned numBlockActions, unsigned numTypeConversions)
       : numCreatedOperations(numCreatedOperations),
-        numReplacements(numReplacements), numBlockActions(numBlockActions) {}
+        numReplacements(numReplacements), numBlockActions(numBlockActions),
+        numTypeConversions(numTypeConversions) {}
 
   /// The current number of created operations.
   unsigned numCreatedOperations;
@@ -316,6 +303,9 @@ struct RewriterState {
 
   /// The current number of block actions performed.
   unsigned numBlockActions;
+
+  /// The current number of type conversion actions performed.
+  unsigned numTypeConversions;
 };
 
 /// This class implements a pattern rewriter for ConversionPattern
@@ -362,6 +352,16 @@ struct DialectConversionRewriter final : public PatternRewriter {
     BlockActionKind kind;
   };
 
+  /// A storage class representing a type conversion of a block or a region.
+  struct TypeConversion {
+    /// The region, or block, that had its types converted.
+    llvm::PointerUnion<Region *, Block *> object;
+
+    /// If the object is a region, this corresponds to the original attributes
+    /// of the parent operation.
+    NamedAttributeList originalParentAttributes;
+  };
+
   DialectConversionRewriter(MLIRContext *ctx, TypeConverter *converter)
       : PatternRewriter(ctx), argConverter(converter, *this) {}
   ~DialectConversionRewriter() = default;
@@ -369,11 +369,15 @@ struct DialectConversionRewriter final : public PatternRewriter {
   /// Return the current state of the rewriter.
   RewriterState getCurrentState() {
     return RewriterState(createdOps.size(), replacements.size(),
-                         blockActions.size());
+                         blockActions.size(), typeConversions.size());
   }
 
   /// Reset the state of the rewriter to a previously saved point.
   void resetState(RewriterState state) {
+    // Undo any type conversions or block actions.
+    undoTypeConversions(state.numTypeConversions);
+    undoBlockActions(state.numBlockActions);
+
     // Reset any replaced operations and undo any saved mappings.
     for (auto &repl : llvm::drop_begin(replacements, state.numReplacements))
       for (auto *result : repl.op->getResults())
@@ -383,9 +387,6 @@ struct DialectConversionRewriter final : public PatternRewriter {
     // Pop all of the newly created operations.
     while (createdOps.size() != state.numCreatedOperations)
       createdOps.pop_back_val()->erase();
-
-    // Undo any block operations.
-    undoBlockActions(state.numBlockActions);
   }
 
   /// Undo the block actions (motions, splits) one by one in reverse order until
@@ -412,20 +413,34 @@ struct DialectConversionRewriter final : public PatternRewriter {
       }
       }
     }
+    blockActions.resize(numActionsToKeep);
+  }
+
+  /// Undo the type conversion actions one by one, until "numActionsToKeep"
+  /// actions remain.
+  void undoTypeConversions(unsigned numActionsToKeep = 0) {
+    for (auto &conversion :
+         llvm::drop_begin(typeConversions, numActionsToKeep)) {
+      if (auto *region = conversion.object.dyn_cast<Region *>())
+        region->getContainingOp()->setAttrs(
+            conversion.originalParentAttributes);
+      else
+        argConverter.discardPendingRewrites(conversion.object.get<Block *>());
+    }
+    typeConversions.resize(numActionsToKeep);
   }
 
   /// Cleanup and destroy any generated rewrite operations. This method is
   /// invoked when the conversion process fails.
   void discardRewrites() {
-    argConverter.discardRewrites();
+    undoTypeConversions();
+    undoBlockActions();
 
     // Remove any newly created ops.
     for (auto *op : createdOps) {
       op->dropAllDefinedValueUses();
       op->erase();
     }
-
-    undoBlockActions();
   }
 
   /// Apply all requested operation rewrites. This method is invoked when the
@@ -439,9 +454,10 @@ struct DialectConversionRewriter final : public PatternRewriter {
 
       // if this operation defines any regions, drop any pending argument
       // rewrites.
-      if (repl.op->getNumRegions() && !argConverter.argMapping.empty()) {
+      if (argConverter.typeConverter && repl.op->getNumRegions()) {
         for (auto &region : repl.op->getRegions())
-          argConverter.cancelPendingRewrites(region);
+          for (auto &block : region)
+            argConverter.cancelPendingRewrites(&block);
       }
     }
 
@@ -454,6 +470,32 @@ struct DialectConversionRewriter final : public PatternRewriter {
     argConverter.applyRewrites();
   }
 
+  /// Return if the given block has already been converted.
+  bool hasSignatureBeenConverted(Block *block) {
+    return argConverter.hasBeenConverted(block);
+  }
+
+  /// Convert the signature of the given region.
+  LogicalResult convertRegionSignature(Region &region) {
+    auto parentAttrs = region.getContainingOp()->getAttrList();
+    auto result = argConverter.convertSignature(region, mapping);
+    if (succeeded(result)) {
+      typeConversions.push_back(TypeConversion{&region, parentAttrs});
+      if (!region.empty())
+        typeConversions.push_back(
+            TypeConversion{&region.front(), NamedAttributeList()});
+    }
+    return result;
+  }
+
+  /// Convert the signature of the given block.
+  LogicalResult convertBlockSignature(Block *block) {
+    auto result = argConverter.convertSignature(block, mapping);
+    if (succeeded(result))
+      typeConversions.push_back(TypeConversion{block, NamedAttributeList()});
+    return result;
+  }
+
   /// PatternRewriter hook for replacing the results of an operation.
   void replaceOp(Operation *op, ArrayRef<Value *> newValues,
                  ArrayRef<Value *> valuesToRemoveIfDead) override {
@@ -535,6 +577,9 @@ struct DialectConversionRewriter final : public PatternRewriter {
 
   /// Ordered list of block operations (creations, splits, motions).
   SmallVector<BlockAction, 4> blockActions;
+
+  /// Ordered list of type conversion actions.
+  SmallVector<TypeConversion, 4> typeConversions;
 };
 } // end anonymous namespace
 
@@ -649,6 +694,18 @@ bool OperationLegalizer::isIllegal(Operation *op) const {
 LogicalResult
 OperationLegalizer::legalize(Operation *op,
                              DialectConversionRewriter &rewriter) {
+  // Make sure that the signature of the parent block of this operation has been
+  // converted.
+  if (rewriter.argConverter.typeConverter) {
+    auto *block = op->getBlock();
+    if (block && !rewriter.hasSignatureBeenConverted(block)) {
+      if (failed(block->isEntryBlock()
+                     ? rewriter.convertRegionSignature(*block->getParent())
+                     : rewriter.convertBlockSignature(block)))
+        return failure();
+    }
+  }
+
   LLVM_DEBUG(llvm::dbgs() << "Legalizing operation : " << op->getName()
                           << "\n");
 
@@ -875,37 +932,21 @@ enum OpConversionMode {
 struct OperationConverter {
   explicit OperationConverter(ConversionTarget &target,
                               OwningRewritePatternList &patterns,
-                              OpConversionMode mode,
-                              TypeConverter *conversion = nullptr)
-      : typeConverter(conversion), opLegalizer(target, patterns), mode(mode) {}
-
-  /// Converts the given function to the conversion target. Returns failure on
-  /// error, success otherwise.
-  LogicalResult
-  convertFunction(FuncOp f,
-                  TypeConverter::SignatureConversion &signatureConversion);
+                              OpConversionMode mode)
+      : opLegalizer(target, patterns), mode(mode) {}
 
   /// Converts the given operations to the conversion target.
-  LogicalResult convertOperations(ArrayRef<Operation *> ops);
+  LogicalResult convertOperations(ArrayRef<Operation *> ops,
+                                  TypeConverter *typeConverter);
 
 private:
-  /// Converts a block or operation with the given rewriter.
-  LogicalResult convert(DialectConversionRewriter &rewriter,
-                        llvm::PointerUnion<Operation *, Block *> &ptr);
-
-  /// Converts a set of blocks/operations with the given rewriter.
-  LogicalResult
-  convert(DialectConversionRewriter &rewriter,
-          std::vector<llvm::PointerUnion<Operation *, Block *>> &toConvert);
-
-  /// Recursively collect all of the blocks, and operations, to convert from
-  /// within 'region'.
-  LogicalResult computeConversionSet(
-      Region &region,
-      std::vector<llvm::PointerUnion<Operation *, Block *>> &toConvert);
-
-  /// Pointer to the type converter.
-  TypeConverter *typeConverter;
+  /// Converts an operation with the given rewriter.
+  LogicalResult convert(DialectConversionRewriter &rewriter, Operation *op);
+
+  /// Recursively collect all of the operations, to convert from within
+  /// 'region'.
+  LogicalResult computeConversionSet(Region &region,
+                                     std::vector<Operation *> &toConvert);
 
   /// The legalizer to use when converting operations.
   OperationLegalizer opLegalizer;
@@ -916,9 +957,9 @@ private:
 } // end anonymous namespace
 
 /// Recursively collect all of the blocks to convert from within 'region'.
-LogicalResult OperationConverter::computeConversionSet(
-    Region &region,
-    std::vector<llvm::PointerUnion<Operation *, Block *>> &toConvert) {
+LogicalResult
+OperationConverter::computeConversionSet(Region &region,
+                                         std::vector<Operation *> &toConvert) {
   if (region.empty())
     return success();
 
@@ -929,10 +970,6 @@ LogicalResult OperationConverter::computeConversionSet(
   while (!worklist.empty()) {
     auto *block = worklist.pop_back_val();
 
-    // We only need to process blocks if we are changing argument types.
-    if (typeConverter)
-      toConvert.emplace_back(block);
-
     // Compute the conversion set of each of the nested operations.
     for (auto &op : *block) {
       toConvert.emplace_back(&op);
@@ -953,18 +990,10 @@ LogicalResult OperationConverter::computeConversionSet(
   return success();
 }
 
-/// Converts a block or operation with the given rewriter.
-LogicalResult
-OperationConverter::convert(DialectConversionRewriter &rewriter,
-                            llvm::PointerUnion<Operation *, Block *> &ptr) {
-  // If this is a block, then convert the types of each of the arguments.
-  if (auto *block = ptr.dyn_cast<Block *>()) {
-    assert(typeConverter && "expected valid type converter");
-    return rewriter.argConverter.convertArguments(block, rewriter.mapping);
-  }
-
-  // Otherwise, legalize the given operation.
-  auto *op = ptr.get<Operation *>();
+/// Converts an operation with the given rewriter.
+LogicalResult OperationConverter::convert(DialectConversionRewriter &rewriter,
+                                          Operation *op) {
+  // Legalize the given operation.
   if (failed(opLegalizer.legalize(op, rewriter))) {
     // Handle the case of a failed conversion for each of the different modes.
     /// Full conversions expect all operations to be converted.
@@ -978,50 +1007,30 @@ OperationConverter::convert(DialectConversionRewriter &rewriter,
              << "failed to legalize operation '" << op->getName()
              << "' that was explicitly marked illegal";
   }
-  return success();
-}
 
-LogicalResult OperationConverter::convert(
-    DialectConversionRewriter &rewriter,
-    std::vector<llvm::PointerUnion<Operation *, Block *>> &toConvert) {
-  // Convert each operation/block and discard rewrites on failure.
-  for (auto &it : toConvert) {
-    if (failed(convert(rewriter, it))) {
-      rewriter.discardRewrites();
-      return failure();
-    }
+  // Convert the signature of any empty regions of this operation, non-empty
+  // regions are converted on demand when converting any operations contained
+  // within.
+  // FIXME(riverriddle) This should be replaced by patterns when the pattern
+  // rewriter exposes functionality to remap region signatures.
+  if (rewriter.argConverter.typeConverter) {
+    for (auto &region : op->getRegions())
+      if (region.empty() && failed(rewriter.convertRegionSignature(region)))
+        return failure();
   }
 
-  // Otherwise the body conversion succeeded, so apply all rewrites.
-  rewriter.applyRewrites();
   return success();
 }
 
-LogicalResult OperationConverter::convertFunction(
-    FuncOp f, TypeConverter::SignatureConversion &signatureConversion) {
-  // If this is an external function, there is nothing else to do.
-  if (f.isExternal())
-    return success();
-
-  // Update the signature of the entry block.
-  DialectConversionRewriter rewriter(f.getContext(), typeConverter);
-  rewriter.argConverter.convertSignature(&f.getBody().front(),
-                                         signatureConversion, rewriter.mapping);
-
-  // Compute the set of operations and blocks to convert.
-  std::vector<llvm::PointerUnion<Operation *, Block *>> toConvert;
-  if (failed(computeConversionSet(f.getBody(), toConvert)))
-    return failure();
-  return convert(rewriter, toConvert);
-}
-
 /// Converts the given top-level operation to the conversion target.
-LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
+LogicalResult
+OperationConverter::convertOperations(ArrayRef<Operation *> ops,
+                                      TypeConverter *typeConverter) {
   if (ops.empty())
     return success();
 
   /// Compute the set of operations and blocks to convert.
-  std::vector<llvm::PointerUnion<Operation *, Block *>> toConvert;
+  std::vector<Operation *> toConvert;
   for (auto *op : ops) {
     toConvert.emplace_back(op);
     for (auto &region : op->getRegions())
@@ -1029,9 +1038,18 @@ LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
         return failure();
   }
 
-  // Rewrite the blocks and operations.
+  // Convert each operation and discard rewrites on failure.
   DialectConversionRewriter rewriter(ops.front()->getContext(), typeConverter);
-  return convert(rewriter, toConvert);
+  for (auto *op : toConvert) {
+    if (failed(convert(rewriter, op))) {
+      rewriter.discardRewrites();
+      return failure();
+    }
+  }
+
+  // Otherwise the body conversion succeeded, so apply all rewrites.
+  rewriter.applyRewrites();
+  return success();
 }
 
 //===----------------------------------------------------------------------===//
@@ -1045,27 +1063,19 @@ void TypeConverter::SignatureConversion::addResults(ArrayRef<Type> results) {
 
 /// Remap an input of the original signature with a new set of types. The
 /// new types are appended to the new signature conversion.
-void TypeConverter::SignatureConversion::addInputs(
-    unsigned origInputNo, ArrayRef<Type> types,
-    ArrayRef<NamedAttributeList> attrs) {
+void TypeConverter::SignatureConversion::addInputs(unsigned origInputNo,
+                                                   ArrayRef<Type> types) {
   assert(!types.empty() && "expected valid types");
   remapInput(origInputNo, /*newInputNo=*/argTypes.size(), types.size());
-  addInputs(types, attrs);
+  addInputs(types);
 }
 
 /// Append new input types to the signature conversion, this should only be
 /// used if the new types are not intended to remap an existing input.
-void TypeConverter::SignatureConversion::addInputs(
-    ArrayRef<Type> types, ArrayRef<NamedAttributeList> attrs) {
+void TypeConverter::SignatureConversion::addInputs(ArrayRef<Type> types) {
   assert(!types.empty() &&
          "1->0 type remappings don't need to be added explicitly");
-  assert(attrs.empty() || types.size() == attrs.size());
-
   argTypes.append(types.begin(), types.end());
-  if (attrs.empty())
-    argAttrs.resize(argTypes.size());
-  else
-    argAttrs.append(attrs.begin(), attrs.end());
 }
 
 /// Remap an input of the original signature with a range of types in the
@@ -1089,23 +1099,20 @@ LogicalResult TypeConverter::convertType(Type t,
 }
 
 /// Convert the given FunctionType signature.
-auto TypeConverter::convertSignature(FunctionType type,
-                                     ArrayRef<NamedAttributeList> argAttrs)
+auto TypeConverter::convertSignature(FunctionType type)
     -> llvm::Optional<SignatureConversion> {
   SignatureConversion result(type.getNumInputs());
-  if (failed(convertSignature(type, argAttrs, result)))
+  if (failed(convertSignature(type, result)))
     return llvm::None;
   return result;
 }
 
 /// This hook allows for changing a FunctionType signature.
-LogicalResult
-TypeConverter::convertSignature(FunctionType type,
-                                ArrayRef<NamedAttributeList> argAttrs,
-                                SignatureConversion &result) {
+LogicalResult TypeConverter::convertSignature(FunctionType type,
+                                              SignatureConversion &result) {
   // Convert the original function arguments.
   for (unsigned i = 0, e = type.getNumInputs(); i != e; ++i)
-    if (failed(convertSignatureArg(i, type.getInput(i), argAttrs[i], result)))
+    if (failed(convertSignatureArg(i, type.getInput(i), result)))
       return failure();
 
   // Convert the original function results.
@@ -1122,7 +1129,6 @@ TypeConverter::convertSignature(FunctionType type,
 
 /// This hook allows for converting a specific argument of a signature.
 LogicalResult TypeConverter::convertSignatureArg(unsigned inputNo, Type type,
-                                                 NamedAttributeList attrs,
                                                  SignatureConversion &result) {
   // Try to convert the given input type.
   SmallVector<Type, 1> convertedTypes;
@@ -1134,12 +1140,53 @@ LogicalResult TypeConverter::convertSignatureArg(unsigned inputNo, Type type,
     return success();
 
   // Otherwise, add the new inputs.
-  auto convertedAttrs =
-      convertedTypes.size() == 1 ? llvm::makeArrayRef(attrs) : llvm::None;
-  result.addInputs(inputNo, convertedTypes, convertedAttrs);
+  result.addInputs(inputNo, convertedTypes);
   return success();
 }
 
+/// This hook defines how the signature of a region 'regionIdx', i.e. the
+/// signature of the entry to the region, on the given operation 'op' is
+/// converted. This function should return a valid conversion for the signature
+/// on success, None otherwise.
+///
+/// The default behavior of this function is to invoke 'convertBlockSignature'
+/// on the entry block, if one is present. This function also provides special
+/// handling for FuncOp to update the type signature.
+///
+/// TODO(riverriddle) This should be replaced in favor of using patterns, but
+/// the pattern rewriter needs to know how to properly replace/remap
+/// arguments.
+auto TypeConverter::convertRegionSignature(Operation *op, unsigned regionIdx)
+    -> llvm::Optional<SignatureConversion> {
+  // Provide explicit handling for FuncOp.
+  if (auto funcOp = dyn_cast<FuncOp>(op)) {
+    auto conversion = convertSignature(funcOp.getType());
+    if (conversion)
+      funcOp.setType(conversion->getConvertedType(funcOp.getContext()));
+    return conversion;
+  }
+
+  // Otherwise, default to handle the arguments of the entry block for the given
+  // region.
+  auto &region = op->getRegion(regionIdx);
+  if (region.empty())
+    return SignatureConversion(/*numOrigInputs=*/0);
+  return convertBlockSignature(&region.front());
+}
+
+/// This function converts the type signature of the given block, by invoking
+/// 'convertSignatureArg' for each argument. This function should return a valid
+/// conversion for the signature on success, None otherwise.
+auto TypeConverter::convertBlockSignature(Block *block)
+    -> llvm::Optional<SignatureConversion> {
+  SignatureConversion conversion(block->getNumArguments());
+  for (unsigned i = 0, e = block->getNumArguments(); i != e; ++i)
+    if (failed(convertSignatureArg(i, block->getArgument(i)->getType(),
+                                   conversion)))
+      return llvm::None;
+  return conversion;
+}
+
 //===----------------------------------------------------------------------===//
 // ConversionTarget
 //===----------------------------------------------------------------------===//
@@ -1178,18 +1225,19 @@ auto ConversionTarget::getOpAction(OperationName op) const
 /// Apply a partial conversion on the given operations, and all nested
 /// operations. This method converts as many operations to the target as
 /// possible, ignoring operations that failed to legalize.
-LogicalResult
-mlir::applyPartialConversion(ArrayRef<Operation *> ops,
-                             ConversionTarget &target,
-                             OwningRewritePatternList &&patterns) {
-  OperationConverter converter(target, patterns, OpConversionMode::Partial);
-  return converter.convertOperations(ops);
+LogicalResult mlir::applyPartialConversion(ArrayRef<Operation *> ops,
+                                           ConversionTarget &target,
+                                           OwningRewritePatternList &&patterns,
+                                           TypeConverter *converter) {
+  OperationConverter opConverter(target, patterns, OpConversionMode::Partial);
+  return opConverter.convertOperations(ops, converter);
 }
-LogicalResult
-mlir::applyPartialConversion(Operation *op, ConversionTarget &target,
-                             OwningRewritePatternList &&patterns) {
+LogicalResult mlir::applyPartialConversion(Operation *op,
+                                           ConversionTarget &target,
+                                           OwningRewritePatternList &&patterns,
+                                           TypeConverter *converter) {
   return applyPartialConversion(llvm::makeArrayRef(op), target,
-                                std::move(patterns));
+                                std::move(patterns), converter);
 }
 
 /// Apply a complete conversion on the given operations, and all nested
@@ -1197,95 +1245,14 @@ mlir::applyPartialConversion(Operation *op, ConversionTarget &target,
 /// operation fails.
 LogicalResult mlir::applyFullConversion(ArrayRef<Operation *> ops,
                                         ConversionTarget &target,
-                                        OwningRewritePatternList &&patterns) {
-  OperationConverter converter(target, patterns, OpConversionMode::Full);
-  return converter.convertOperations(ops);
+                                        OwningRewritePatternList &&patterns,
+                                        TypeConverter *converter) {
+  OperationConverter opConverter(target, patterns, OpConversionMode::Full);
+  return opConverter.convertOperations(ops, converter);
 }
 LogicalResult mlir::applyFullConversion(Operation *op, ConversionTarget &target,
-                                        OwningRewritePatternList &&patterns) {
+                                        OwningRewritePatternList &&patterns,
+                                        TypeConverter *converter) {
   return applyFullConversion(llvm::makeArrayRef(op), target,
-                             std::move(patterns));
-}
-
-//===----------------------------------------------------------------------===//
-// Op + Type Conversion Entry Points
-//===----------------------------------------------------------------------===//
-
-static LogicalResult applyConversion(MutableArrayRef<FuncOp> fns,
-                                     ConversionTarget &target,
-                                     TypeConverter &converter,
-                                     OwningRewritePatternList &&patterns,
-                                     OpConversionMode mode) {
-  if (fns.empty())
-    return success();
-
-  // Build the function converter.
-  OperationConverter funcConverter(target, patterns, mode, &converter);
-
-  // Try to convert each of the functions within the module.
-  SmallVector<NamedAttributeList, 4> argAttrs;
-  auto *ctx = fns.front().getContext();
-  for (auto func : fns) {
-    argAttrs.clear();
-    func.getAllArgAttrs(argAttrs);
-
-    // Convert the function type using the type converter.
-    auto conversion = converter.convertSignature(func.getType(), argAttrs);
-    if (!conversion)
-      return failure();
-
-    // Update the function signature.
-    func.setType(conversion->getConvertedType(ctx));
-    func.setAllArgAttrs(conversion->getConvertedArgAttrs());
-
-    // Convert the body of this function.
-    if (failed(funcConverter.convertFunction(func, *conversion)))
-      return failure();
-  }
-
-  return success();
-}
-
-/// Apply a partial conversion on the function operations within the given
-/// module. This method returns failure if a type conversion was encountered.
-LogicalResult
-mlir::applyPartialConversion(ModuleOp module, ConversionTarget &target,
-                             TypeConverter &converter,
-                             OwningRewritePatternList &&patterns) {
-  SmallVector<FuncOp, 32> allFunctions(module.getOps<FuncOp>());
-  return applyPartialConversion(allFunctions, target, converter,
-                                std::move(patterns));
-}
-
-/// Apply a partial conversion on the given function operations. This method
-/// returns failure if a type conversion was encountered.
-LogicalResult
-mlir::applyPartialConversion(MutableArrayRef<FuncOp> fns,
-                             ConversionTarget &target, TypeConverter &converter,
-                             OwningRewritePatternList &&patterns) {
-  return applyConversion(fns, target, converter, std::move(patterns),
-                         OpConversionMode::Partial);
-}
-
-/// Apply a full conversion on the function operations within the given module.
-/// This method returns failure if a type conversion was encountered, or if the
-/// conversion of any operations failed.
-LogicalResult mlir::applyFullConversion(ModuleOp module,
-                                        ConversionTarget &target,
-                                        TypeConverter &converter,
-                                        OwningRewritePatternList &&patterns) {
-  SmallVector<FuncOp, 32> allFunctions(module.getOps<FuncOp>());
-  return applyFullConversion(allFunctions, target, converter,
-                             std::move(patterns));
-}
-
-/// Apply a full conversion on the given function operations. This method
-/// returns failure if a type conversion was encountered, or if the conversion
-/// of any operation failed.
-LogicalResult mlir::applyFullConversion(MutableArrayRef<FuncOp> fns,
-                                        ConversionTarget &target,
-                                        TypeConverter &converter,
-                                        OwningRewritePatternList &&patterns) {
-  return applyConversion(fns, target, converter, std::move(patterns),
-                         OpConversionMode::Full);
+                             std::move(patterns), converter);
 }
index dbcca99c5f428df4b69d2db9b844b39e65c1a30a..9fb6cc99dd72022f398660905dd2e793ab332227 100644 (file)
@@ -107,3 +107,29 @@ func @fail_to_convert_illegal_op() -> i32 {
   %result = "test.illegal_op_f"() : () -> (i32)
   return %result : i32
 }
+
+// -----
+
+func @fail_to_convert_illegal_op_in_region() {
+  // expected-error@+1 {{failed to legalize operation 'test.region_builder'}}
+  "test.region_builder"() : () -> ()
+  return
+}
+
+// -----
+
+// Check that the entry block arguments of a region are untouched in the case
+// of failure.
+
+// CHECK-LABEL: func @fail_to_convert_region
+func @fail_to_convert_region() {
+  // CHECK-NEXT: "test.drop_op"
+  // CHECK-NEXT: ^bb{{.*}}(%{{.*}}: i64):
+  "test.drop_op"() ({
+    ^bb1(%i0: i64):
+      // expected-error@+1 {{failed to legalize operation 'test.region_builder'}}
+      "test.region_builder"() : () -> ()
+      "test.valid"() : () -> ()
+  }) : () -> ()
+  return
+}
index 63e985546b9a7346322684762010114013321d66..719fca27b252de0d0bc74ea18a7e05287ccaaed6 100644 (file)
@@ -449,6 +449,7 @@ def : Pat<(TestRewriteOp $input), (replaceWithValue $input)>;
 // Test Type Legalization
 //===----------------------------------------------------------------------===//
 
+def TestRegionBuilderOp : TEST_Op<"region_builder">;
 def TestReturnOp : TEST_Op<"return", [Terminator]>,
   Arguments<(ins Variadic<AnyType>:$inputs)>;
 def TestCastOp : TEST_Op<"cast">,
index 8b89a3a05af6248a2bc54b57ed7a140a1bdcf48d..6b1266ec6523b0a68929ef8a4c4c4413fe266fe5 100644 (file)
@@ -74,6 +74,30 @@ struct TestRegionRewriteBlockMovement : public ConversionPattern {
     return matchSuccess();
   }
 };
+/// This pattern is a simple pattern that generates a region containing an
+/// illegal operation.
+struct TestRegionRewriteUndo : public RewritePattern {
+  TestRegionRewriteUndo(MLIRContext *ctx)
+      : RewritePattern("test.region_builder", 1, ctx) {}
+
+  PatternMatchResult matchAndRewrite(Operation *op,
+                                     PatternRewriter &rewriter) const final {
+    // Create the region operation with an entry block containing arguments.
+    OperationState newRegion(op->getLoc(), "test.region");
+    newRegion.addRegion();
+    auto *regionOp = rewriter.createOperation(newRegion);
+    auto *entryBlock = rewriter.createBlock(&regionOp->getRegion(0));
+    entryBlock->addArgument(rewriter.getIntegerType(64));
+
+    // Add an explicitly illegal operation to ensure the conversion fails.
+    rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getIntegerType(32));
+    rewriter.create<TestValidOp>(op->getLoc(), ArrayRef<Value *>());
+
+    // Drop this operation.
+    rewriter.replaceOp(op, llvm::None);
+    return matchSuccess();
+  }
+};
 /// This pattern simply erases the given operation.
 struct TestDropOp : public ConversionPattern {
   TestDropOp(MLIRContext *ctx) : ConversionPattern("test.drop_op", 1, ctx) {}
@@ -158,7 +182,7 @@ struct TestConversionTarget : public ConversionTarget {
   TestConversionTarget(MLIRContext &ctx) : ConversionTarget(ctx) {
     addLegalOp<LegalOpA, TestValidOp>();
     addDynamicallyLegalOp<TestReturnOp>();
-    addIllegalOp<ILLegalOpF>();
+    addIllegalOp<ILLegalOpF, TestRegionBuilderOp>();
   }
   bool isDynamicallyLegal(Operation *op) const final {
     // Don't allow F32 operands.
@@ -172,15 +196,14 @@ struct TestLegalizePatternDriver
   void runOnModule() override {
     mlir::OwningRewritePatternList patterns;
     populateWithGenerated(&getContext(), &patterns);
-    RewriteListBuilder<TestRegionRewriteBlockMovement, TestDropOp,
-                       TestPassthroughInvalidOp,
+    RewriteListBuilder<TestRegionRewriteBlockMovement, TestRegionRewriteUndo,
+                       TestDropOp, TestPassthroughInvalidOp,
                        TestSplitReturnType>::build(patterns, &getContext());
 
     TestTypeConverter converter;
     TestConversionTarget target(getContext());
-    if (failed(applyPartialConversion(getModule(), target, converter,
-                                      std::move(patterns))))
-      signalPassFailure();
+    (void)applyPartialConversion(getModule(), target, std::move(patterns),
+                                 &converter);
   }
 };
 } // end anonymous namespace