Refactor region type signature conversion to be explicit via patterns.
authorRiver Riddle <riverriddle@google.com>
Sun, 21 Jul 2019 02:05:41 +0000 (19:05 -0700)
committerA. Unique TensorFlower <gardener@tensorflow.org>
Sun, 21 Jul 2019 02:06:07 +0000 (19:06 -0700)
This cl enforces that the conversion of the type signatures for regions, and thus their entry blocks, is handled via ConversionPatterns. A new hook 'applySignatureConversion' is added to the ConversionPatternRewriter to perform the desired conversion on a region. This also means that the handling of rewriting the signature of a FuncOp is moved to a pattern. A default implementation is provided via 'mlir::populateFuncOpTypeConversionPattern'. This removes the hacky implicit 'dynamically legal' status of FuncOp that was present previously, and leaves it up to the user to decide when/how to convert the signature of a function.

PiperOrigin-RevId: 259161999

mlir/examples/Linalg/Linalg1/lib/ConvertToLLVMDialect.cpp
mlir/examples/Linalg/Linalg3/lib/ConvertToLLVMDialect.cpp
mlir/examples/toy/Ch5/mlir/LateLowering.cpp
mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
mlir/include/mlir/Transforms/DialectConversion.h
mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp
mlir/lib/Linalg/Transforms/LowerToLLVMDialect.cpp
mlir/lib/Transforms/DialectConversion.cpp
mlir/test/Transforms/test-legalizer.mlir
mlir/test/lib/TestDialect/TestPatterns.cpp

index 67b0ac0..411a7af 100644 (file)
@@ -425,7 +425,9 @@ LogicalResult linalg::convertToLLVM(mlir::ModuleOp module) {
 
   ConversionTarget target(*module.getContext());
   target.addLegalDialect<LLVM::LLVMDialect>();
-  target.addLegalOp<FuncOp, ModuleOp, ModuleTerminatorOp>();
+  target.addLegalOp<ModuleOp, ModuleTerminatorOp>();
+  target.addDynamicallyLegalOp<FuncOp>(
+      [&](FuncOp op) { return converter.isSignatureLegal(op.getType()); });
   return applyFullConversion(module, target, std::move(patterns), &converter);
 }
 
index 68a48d6..8c77737 100644 (file)
@@ -162,7 +162,9 @@ LogicalResult linalg::convertLinalg3ToLLVM(ModuleOp module) {
 
   ConversionTarget target(*module.getContext());
   target.addLegalDialect<LLVM::LLVMDialect>();
-  target.addLegalOp<FuncOp, ModuleOp, ModuleTerminatorOp>();
+  target.addLegalOp<ModuleOp, ModuleTerminatorOp>();
+  target.addDynamicallyLegalOp<FuncOp>(
+      [&](FuncOp op) { return converter.isSignatureLegal(op.getType()); });
   if (failed(
           applyFullConversion(module, target, std::move(patterns), &converter)))
     return failure();
index 8b80588..5a01122 100644 (file)
@@ -355,12 +355,17 @@ struct LateLoweringPass : public ModulePass<LateLoweringPass> {
     RewriteListBuilder<AddOpConversion, PrintOpConversion, ConstantOpConversion,
                        TransposeOpConversion,
                        ReturnOpConversion>::build(toyPatterns, &getContext());
+    mlir::populateFuncOpTypeConversionPattern(toyPatterns, &getContext(),
+                                              typeConverter);
 
     // Perform Toy specific lowering.
     ConversionTarget target(getContext());
     target.addLegalDialect<AffineOpsDialect, linalg::LinalgDialect,
                            LLVM::LLVMDialect, StandardOpsDialect>();
     target.addLegalOp<toy::AllocOp, toy::TypeCastOp>();
+    target.addDynamicallyLegalOp<FuncOp>([&](FuncOp op) {
+      return typeConverter.isSignatureLegal(op.getType());
+    });
     if (failed(applyPartialConversion(
             getModule(), target, std::move(toyPatterns), &typeConverter))) {
       emitError(UnknownLoc::get(getModule().getContext()),
index 864d805..d5c4c11 100644 (file)
@@ -63,12 +63,6 @@ public:
   LLVM::LLVMDialect *getDialect() { return llvmDialect; }
 
 protected:
-  /// Convert function signatures to LLVM IR.  In particular, convert functions
-  /// with multiple results into functions returning LLVM IR's structure type.
-  /// Use `convertType` to convert individual argument and result types.
-  LogicalResult convertSignature(FunctionType t,
-                                 SignatureConversion &result) final;
-
   /// LLVM IR module used to parse/create types.
   llvm::Module *module;
   LLVM::LLVMDialect *llvmDialect;
index 5543c21..1ffd5bb 100644 (file)
@@ -61,16 +61,8 @@ public:
       size_t inputNo, size;
     };
 
-    /// Return the converted type signature.
-    FunctionType getConvertedType(MLIRContext *ctx) const {
-      return FunctionType::get(argTypes, resultTypes, ctx);
-    }
-
     /// Return the argument types for the new signature.
-    ArrayRef<Type> getConvertedArgTypes() const { return argTypes; }
-
-    /// Return the result types for the new signature.
-    ArrayRef<Type> getConvertedResultTypes() const { return resultTypes; }
+    ArrayRef<Type> getConvertedTypes() const { return argTypes; }
 
     /// Get the input mapping for the given argument.
     llvm::Optional<InputMapping> getInputMapping(unsigned input) const {
@@ -81,9 +73,6 @@ public:
     // Conversion Hooks
     //===------------------------------------------------------------------===//
 
-    /// Append new result types to the signature conversion.
-    void addResults(ArrayRef<Type> results);
-
     /// Remap an input of the original signature with a new set of types. The
     /// new types are appended to the new signature conversion.
     void addInputs(unsigned origInputNo, ArrayRef<Type> types);
@@ -101,8 +90,8 @@ public:
     /// The remapping information for each of the original arguments.
     SmallVector<llvm::Optional<InputMapping>, 4> remappedInputs;
 
-    /// The set of argument and results types.
-    SmallVector<Type, 4> argTypes, resultTypes;
+    /// The set of new argument types.
+    SmallVector<Type, 4> argTypes;
   };
 
   /// This hooks allows for converting a type. This function should return
@@ -115,18 +104,19 @@ public:
   /// the type convert to on success, and a null type on failure.
   virtual Type convertType(Type t) { return t; }
 
-  /// Convert the given FunctionType signature. This functions returns a valid
-  /// SignatureConversion on success, None otherwise.
-  llvm::Optional<SignatureConversion> convertSignature(FunctionType type);
+  /// Convert the given set of types, filling 'results' as necessary. This
+  /// returns failure if the conversion of any of the types fails, success
+  /// otherwise.
+  LogicalResult convertTypes(ArrayRef<Type> types,
+                             SmallVectorImpl<Type> &results);
+
+  /// Return true if the given type is legal for this type converter, i.e. the
+  /// type converts to itself.
+  bool isLegal(Type type);
 
-  /// This hook allows for changing a FunctionType signature. This function
-  /// should populate 'result' with the new arguments and results on success,
-  /// otherwise return failure.
-  ///
-  /// The default behavior of this function is to call 'convertType' on
-  /// individual function operands and results.
-  virtual LogicalResult convertSignature(FunctionType type,
-                                         SignatureConversion &result);
+  /// Return true if the inputs and outputs of the given function type are
+  /// legal.
+  bool isSignatureLegal(FunctionType funcType);
 
   /// This hook allows for converting a specific argument of a signature. It
   /// takes as inputs the original argument input number, type.
@@ -134,22 +124,6 @@ public:
   virtual LogicalResult convertSignatureArg(unsigned inputNo, Type type,
                                             SignatureConversion &result);
 
-  /// This hook allows for converting the signature of a region 'regionIdx',
-  /// i.e. the signature of the entry to the region, on the given operation
-  /// 'op'. This function should return a valid conversion for the signature on
-  /// success, None otherwise. This hook is allowed to modify the attributes on
-  /// the provided operation if necessary.
-  ///
-  /// The default behavior of this function is to invoke 'convertBlockSignature'
-  /// on the entry block, if one is present. This function also provides special
-  /// handling for FuncOp to update the type signature.
-  ///
-  /// TODO(riverriddle) This should be replaced in favor of using patterns, but
-  /// the pattern rewriter needs to know how to properly replace/remap
-  /// arguments.
-  virtual llvm::Optional<SignatureConversion>
-  convertRegionSignature(Operation *op, unsigned regionIdx);
-
   /// This function converts the type signature of the given block, by invoking
   /// 'convertSignatureArg' for each argument. This function should return a
   /// valid conversion for the signature on success, None otherwise.
@@ -244,6 +218,12 @@ private:
   using RewritePattern::rewrite;
 };
 
+/// Add a pattern to the given pattern list to convert the signature of a FuncOp
+/// with the given type converter.
+void populateFuncOpTypeConversionPattern(OwningRewritePatternList &patterns,
+                                         MLIRContext *ctx,
+                                         TypeConverter &converter);
+
 //===----------------------------------------------------------------------===//
 // Conversion PatternRewriter
 //===----------------------------------------------------------------------===//
@@ -252,12 +232,24 @@ namespace detail {
 struct ConversionPatternRewriterImpl;
 } // end namespace detail
 
-/// This class implements a pattern rewriter for use with ConversionPatterns.
+/// This class implements a pattern rewriter for use with ConversionPatterns. It
+/// extends the base PatternRewriter and provides special conversion specific
+/// hooks.
 class ConversionPatternRewriter final : public PatternRewriter {
 public:
   ConversionPatternRewriter(MLIRContext *ctx, TypeConverter *converter);
   ~ConversionPatternRewriter() override;
 
+  /// Apply a signature conversion to the entry block of the given region.
+  void applySignatureConversion(Region *region,
+                                TypeConverter::SignatureConversion &conversion);
+
+  /// Clone the given operation without cloning its regions.
+  Operation *cloneWithoutRegions(Operation *op);
+  template <typename OpT> OpT cloneWithoutRegions(OpT op) {
+    return cast<OpT>(cloneWithoutRegions(op.getOperation()));
+  }
+
   //===--------------------------------------------------------------------===//
   // PatternRewriter Hooks
   //===--------------------------------------------------------------------===//
index 042e768..c17909b 100644 (file)
@@ -266,6 +266,44 @@ protected:
   LLVM::LLVMDialect &dialect;
 };
 
+struct FuncOpConversion : public LLVMLegalizationPattern<FuncOp> {
+  using LLVMLegalizationPattern<FuncOp>::LLVMLegalizationPattern;
+
+  PatternMatchResult
+  matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto funcOp = cast<FuncOp>(op);
+    FunctionType type = funcOp.getType();
+
+    // Convert the original function arguments.
+    TypeConverter::SignatureConversion result(type.getNumInputs());
+    for (unsigned i = 0, e = type.getNumInputs(); i != e; ++i)
+      if (failed(lowering.convertSignatureArg(i, type.getInput(i), result)))
+        return matchFailure();
+
+    // Pack the result types into a struct.
+    Type packedResult;
+    if (type.getNumResults() != 0) {
+      if (!(packedResult = lowering.packFunctionResults(type.getResults())))
+        return matchFailure();
+    }
+
+    // Create a new function with an updated signature.
+    auto newFuncOp = rewriter.cloneWithoutRegions(funcOp);
+    rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
+                                newFuncOp.end());
+    newFuncOp.setType(FunctionType::get(
+        result.getConvertedTypes(),
+        packedResult ? ArrayRef<Type>(packedResult) : llvm::None,
+        funcOp.getContext()));
+
+    // Tell the rewriter to convert the region signature.
+    rewriter.applySignatureConversion(&newFuncOp.getBody(), result);
+    rewriter.replaceOp(op, llvm::None);
+    return matchSuccess();
+  }
+};
+
 // Basic lowering implementation for one-to-one rewriting from Standard Ops to
 // LLVM Dialect Ops.
 template <typename SourceOp, typename TargetOp>
@@ -985,10 +1023,10 @@ void mlir::populateStdToLLVMConversionPatterns(
       BranchOpLowering, CallIndirectOpLowering, CallOpLowering, CmpIOpLowering,
       CondBranchOpLowering, ConstLLVMOpLowering, DeallocOpLowering,
       DimOpLowering, DivISOpLowering, DivIUOpLowering, DivFOpLowering,
-      IndexCastOpLowering, LoadOpLowering, MemRefCastOpLowering, MulFOpLowering,
-      MulIOpLowering, OrOpLowering, RemISOpLowering, RemIUOpLowering,
-      RemFOpLowering, ReturnOpLowering, SelectOpLowering, StoreOpLowering,
-      SubFOpLowering, SubIOpLowering,
+      FuncOpConversion, IndexCastOpLowering, LoadOpLowering,
+      MemRefCastOpLowering, MulFOpLowering, MulIOpLowering, OrOpLowering,
+      RemISOpLowering, RemIUOpLowering, RemFOpLowering, ReturnOpLowering,
+      SelectOpLowering, StoreOpLowering, SubFOpLowering, SubIOpLowering,
       XOrOpLowering>::build(patterns, *converter.getDialect(), converter);
 }
 
@@ -1014,27 +1052,6 @@ Type LLVMTypeConverter::packFunctionResults(ArrayRef<Type> types) {
   return LLVM::LLVMType::getStructTy(llvmDialect, resultTypes);
 }
 
-// Convert function signatures using the stored LLVM IR module.
-LogicalResult LLVMTypeConverter::convertSignature(FunctionType type,
-                                                  SignatureConversion &result) {
-  // Convert the original function arguments.
-  for (unsigned i = 0, e = type.getNumInputs(); i != e; ++i)
-    if (failed(convertSignatureArg(i, type.getInput(i), result)))
-      return failure();
-
-  // If function does not return anything, return immediately.
-  if (type.getNumResults() == 0)
-    return success();
-
-  // Otherwise pack the result types into a struct.
-  if (auto packedRet = packFunctionResults(type.getResults())) {
-    result.addResults(packedRet);
-    return success();
-  }
-
-  return failure();
-}
-
 /// Create an instance of LLVMTypeConverter in the given context.
 static std::unique_ptr<LLVMTypeConverter>
 makeStandardToLLVMTypeConverter(MLIRContext *context) {
@@ -1071,6 +1088,9 @@ struct LLVMLoweringPass : public ModulePass<LLVMLoweringPass> {
 
     ConversionTarget target(getContext());
     target.addLegalDialect<LLVM::LLVMDialect>();
+    target.addDynamicallyLegalOp<FuncOp>([&](FuncOp op) {
+      return typeConverter->isSignatureLegal(op.getType());
+    });
     if (failed(applyPartialConversion(m, target, std::move(patterns),
                                       typeConverter.get())))
       signalPassFailure();
index 98be230..b6bfa58 100644 (file)
@@ -774,6 +774,8 @@ void LowerLinalgToLLVMPass::runOnModule() {
 
   ConversionTarget target(getContext());
   target.addLegalDialect<LLVM::LLVMDialect>();
+  target.addDynamicallyLegalOp<FuncOp>(
+      [&](FuncOp op) { return converter.isSignatureLegal(op.getType()); });
   if (failed(applyPartialConversion(module, target, std::move(patterns),
                                     &converter))) {
     signalPassFailure();
index 02ca31f..aac2e11 100644 (file)
@@ -61,9 +61,6 @@ struct ArgConverter {
   /// Return if the signature of the given block has already been converted.
   bool hasBeenConverted(Block *block) const { return argMapping.count(block); }
 
-  /// Attempt to convert the signature of the given region.
-  LogicalResult convertSignature(Region &region, BlockAndValueMapping &mapping);
-
   /// Attempt to convert the signature of the given block.
   LogicalResult convertSignature(Block *block, BlockAndValueMapping &mapping);
 
@@ -196,23 +193,10 @@ void ArgConverter::applyRewrites() {
   }
 }
 
-/// Converts the signature of the given region.
-LogicalResult ArgConverter::convertSignature(Region &region,
-                                             BlockAndValueMapping &mapping) {
-  if (auto conversion = typeConverter->convertRegionSignature(
-          region.getContainingOp(), region.getRegionNumber())) {
-    if (!region.empty())
-      applySignatureConversion(&region.front(), *conversion, mapping);
-    return success();
-  }
-  return failure();
-}
-
 /// Converts the signature of the given entry block.
 LogicalResult ArgConverter::convertSignature(Block *block,
                                              BlockAndValueMapping &mapping) {
-  auto conversion = typeConverter->convertBlockSignature(block);
-  if (conversion)
+  if (auto conversion = typeConverter->convertBlockSignature(block))
     return applySignatureConversion(block, *conversion, mapping), success();
   return failure();
 }
@@ -222,7 +206,7 @@ void ArgConverter::applySignatureConversion(
     Block *block, TypeConverter::SignatureConversion &signatureConversion,
     BlockAndValueMapping &mapping) {
   unsigned origArgCount = block->getNumArguments();
-  auto convertedTypes = signatureConversion.getConvertedArgTypes();
+  auto convertedTypes = signatureConversion.getConvertedTypes();
   if (origArgCount == 0 && convertedTypes.empty())
     return;
 
@@ -292,10 +276,9 @@ namespace {
 /// This is useful when saving and undoing a set of rewrites.
 struct RewriterState {
   RewriterState(unsigned numCreatedOperations, unsigned numReplacements,
-                unsigned numBlockActions, unsigned numTypeConversions)
+                unsigned numBlockActions)
       : numCreatedOperations(numCreatedOperations),
-        numReplacements(numReplacements), numBlockActions(numBlockActions),
-        numTypeConversions(numTypeConversions) {}
+        numReplacements(numReplacements), numBlockActions(numBlockActions) {}
 
   /// The current number of created operations.
   unsigned numCreatedOperations;
@@ -305,9 +288,6 @@ struct RewriterState {
 
   /// The current number of block actions performed.
   unsigned numBlockActions;
-
-  /// The current number of type conversion actions performed.
-  unsigned numTypeConversions;
 };
 } // end anonymous namespace
 
@@ -326,7 +306,7 @@ struct ConversionPatternRewriterImpl {
 
   /// The kind of the block action performed during the rewrite.  Actions can be
   /// undone if the conversion fails.
-  enum class BlockActionKind { Split, Move };
+  enum class BlockActionKind { Split, Move, TypeConversion };
 
   /// Original position of the given block in its parent region.  We cannot use
   /// a region iterator because it could have been invalidated by other region
@@ -339,6 +319,21 @@ struct ConversionPatternRewriterImpl {
   /// The storage class for an undoable block action (one of BlockActionKind),
   /// contains the information necessary to undo this action.
   struct BlockAction {
+    static BlockAction getSplit(Block *block, Block *originalBlock) {
+      BlockAction action{BlockActionKind::Split, block};
+      action.originalBlock = originalBlock;
+      return action;
+    }
+    static BlockAction getMove(Block *block, BlockPosition originalPos) {
+      return {BlockActionKind::Move, block, {originalPos}};
+    }
+    static BlockAction getTypeConversion(Block *block) {
+      return BlockAction{BlockActionKind::TypeConversion, block};
+    }
+
+    // The action kind.
+    BlockActionKind kind;
+
     // A pointer to the block that was created by the action.
     Block *block;
 
@@ -351,18 +346,6 @@ struct ConversionPatternRewriterImpl {
       // block that was split into two parts.
       Block *originalBlock;
     };
-
-    BlockActionKind kind;
-  };
-
-  /// A storage class representing a type conversion of a block or a region.
-  struct TypeConversion {
-    /// The region, or block, that had its types converted.
-    llvm::PointerUnion<Region *, Block *> object;
-
-    /// If the object is a region, this corresponds to the original attributes
-    /// of the parent operation.
-    NamedAttributeList originalParentAttributes;
   };
 
   ConversionPatternRewriterImpl(PatternRewriter &rewriter,
@@ -379,10 +362,6 @@ struct ConversionPatternRewriterImpl {
   /// "numActionsToKeep" actions remains.
   void undoBlockActions(unsigned numActionsToKeep = 0);
 
-  /// Undo the type conversion actions one by one, until "numActionsToKeep"
-  /// actions remain.
-  void undoTypeConversions(unsigned numActionsToKeep = 0);
-
   /// Cleanup and destroy any generated rewrite operations. This method is
   /// invoked when the conversion process fails.
   void discardRewrites();
@@ -391,15 +370,13 @@ struct ConversionPatternRewriterImpl {
   /// conversion process succeeds.
   void applyRewrites();
 
-  /// Return if the given block has already been converted.
-  bool hasSignatureBeenConverted(Block *block);
-
-  /// Convert the signature of the given region.
-  LogicalResult convertRegionSignature(Region &region);
-
   /// Convert the signature of the given block.
   LogicalResult convertBlockSignature(Block *block);
 
+  /// Apply a signature conversion on the given region.
+  void applySignatureConversion(Region *region,
+                                TypeConverter::SignatureConversion &conversion);
+
   /// PatternRewriter hook for replacing the results of an operation.
   void replaceOp(Operation *op, ArrayRef<Value *> newValues,
                  ArrayRef<Value *> valuesToRemoveIfDead);
@@ -430,21 +407,17 @@ struct ConversionPatternRewriterImpl {
 
   /// Ordered list of block operations (creations, splits, motions).
   SmallVector<BlockAction, 4> blockActions;
-
-  /// Ordered list of type conversion actions.
-  SmallVector<TypeConversion, 4> typeConversions;
 };
 } // end namespace detail
 } // end namespace mlir
 
 RewriterState ConversionPatternRewriterImpl::getCurrentState() {
   return RewriterState(createdOps.size(), replacements.size(),
-                       blockActions.size(), typeConversions.size());
+                       blockActions.size());
 }
 
 void ConversionPatternRewriterImpl::resetState(RewriterState state) {
-  // Undo any type conversions or block actions.
-  undoTypeConversions(state.numTypeConversions);
+  // Undo any block actions.
   undoBlockActions(state.numBlockActions);
 
   // Reset any replaced operations and undo any saved mappings.
@@ -478,24 +451,17 @@ void ConversionPatternRewriterImpl::undoBlockActions(
           action.block->getParent()->getBlocks(), action.block);
       break;
     }
+    // Undo the type conversion.
+    case BlockActionKind::TypeConversion: {
+      argConverter.discardPendingRewrites(action.block);
+      break;
+    }
     }
   }
   blockActions.resize(numActionsToKeep);
 }
 
-void ConversionPatternRewriterImpl::undoTypeConversions(
-    unsigned numActionsToKeep) {
-  for (auto &conversion : llvm::drop_begin(typeConversions, numActionsToKeep)) {
-    if (auto *region = conversion.object.dyn_cast<Region *>())
-      region->getContainingOp()->setAttrs(conversion.originalParentAttributes);
-    else
-      argConverter.discardPendingRewrites(conversion.object.get<Block *>());
-  }
-  typeConversions.resize(numActionsToKeep);
-}
-
 void ConversionPatternRewriterImpl::discardRewrites() {
-  undoTypeConversions();
   undoBlockActions();
 
   // Remove any newly created ops.
@@ -530,29 +496,30 @@ void ConversionPatternRewriterImpl::applyRewrites() {
   argConverter.applyRewrites();
 }
 
-bool ConversionPatternRewriterImpl::hasSignatureBeenConverted(Block *block) {
-  return argConverter.hasBeenConverted(block);
-}
-
 LogicalResult
-ConversionPatternRewriterImpl::convertRegionSignature(Region &region) {
-  auto parentAttrs = region.getContainingOp()->getAttrList();
-  auto result = argConverter.convertSignature(region, mapping);
-  if (succeeded(result)) {
-    typeConversions.push_back(TypeConversion{&region, parentAttrs});
-    if (!region.empty())
-      typeConversions.push_back(
-          TypeConversion{&region.front(), NamedAttributeList()});
-  }
-  return result;
+ConversionPatternRewriterImpl::convertBlockSignature(Block *block) {
+  // Check to see if this block should not be converted:
+  // * The block is invalid, or there is no type converter.
+  // * The block has already been converted.
+  // * This is an entry block, these are converted explicitly via patterns.
+  if (!block || !argConverter.typeConverter ||
+      argConverter.hasBeenConverted(block) || block->isEntryBlock())
+    return success();
+
+  // Otherwise, try to convert the block signature.
+  if (failed(argConverter.convertSignature(block, mapping)))
+    return failure();
+  blockActions.push_back(BlockAction::getTypeConversion(block));
+  return success();
 }
 
-LogicalResult
-ConversionPatternRewriterImpl::convertBlockSignature(Block *block) {
-  auto result = argConverter.convertSignature(block, mapping);
-  if (succeeded(result))
-    typeConversions.push_back(TypeConversion{block, NamedAttributeList()});
-  return result;
+void ConversionPatternRewriterImpl::applySignatureConversion(
+    Region *region, TypeConverter::SignatureConversion &conversion) {
+  if (!region->empty()) {
+    argConverter.applySignatureConversion(&region->front(), conversion,
+                                          mapping);
+    blockActions.push_back(BlockAction::getTypeConversion(&region->front()));
+  }
 }
 
 void ConversionPatternRewriterImpl::replaceOp(
@@ -574,11 +541,7 @@ void ConversionPatternRewriterImpl::replaceOp(
 
 void ConversionPatternRewriterImpl::notifySplitBlock(Block *block,
                                                      Block *continuation) {
-  BlockAction action;
-  action.kind = BlockActionKind::Split;
-  action.block = continuation;
-  action.originalBlock = block;
-  blockActions.push_back(action);
+  blockActions.push_back(BlockAction::getSplit(continuation, block));
 }
 
 void ConversionPatternRewriterImpl::notifyRegionIsBeingInlinedBefore(
@@ -586,11 +549,7 @@ void ConversionPatternRewriterImpl::notifyRegionIsBeingInlinedBefore(
   for (auto &pair : llvm::enumerate(region)) {
     Block &block = pair.value();
     unsigned position = pair.index();
-    BlockAction action;
-    action.kind = BlockActionKind::Move;
-    action.block = &block;
-    action.originalPosition = {&region, position};
-    blockActions.push_back(action);
+    blockActions.push_back(BlockAction::getMove(&block, {&region, position}));
   }
 }
 
@@ -618,6 +577,19 @@ void ConversionPatternRewriter::replaceOp(
   impl->replaceOp(op, newValues, valuesToRemoveIfDead);
 }
 
+/// Apply a signature conversion to the entry block of the given region.
+void ConversionPatternRewriter::applySignatureConversion(
+    Region *region, TypeConverter::SignatureConversion &conversion) {
+  impl->applySignatureConversion(region, conversion);
+}
+
+/// Clone the given operation without cloning its regions.
+Operation *ConversionPatternRewriter::cloneWithoutRegions(Operation *op) {
+  Operation *newOp = OpBuilder::cloneWithoutRegions(*op);
+  impl->createdOps.push_back(newOp);
+  return newOp;
+}
+
 /// PatternRewriter hook for splitting a block into two parts.
 Block *ConversionPatternRewriter::splitBlock(Block *block,
                                              Block::iterator before) {
@@ -766,18 +738,9 @@ bool OperationLegalizer::isIllegal(Operation *op) const {
 LogicalResult
 OperationLegalizer::legalize(Operation *op,
                              ConversionPatternRewriter &rewriter) {
-  // Make sure that the signature of the parent block of this operation has been
-  // converted.
-  auto &rewriterImpl = rewriter.getImpl();
-  if (rewriterImpl.argConverter.typeConverter) {
-    auto *block = op->getBlock();
-    if (block && !rewriterImpl.hasSignatureBeenConverted(block)) {
-      if (failed(block->isEntryBlock()
-                     ? rewriterImpl.convertRegionSignature(*block->getParent())
-                     : rewriterImpl.convertBlockSignature(block)))
-        return failure();
-    }
-  }
+  // Make sure that the signature of the parent block has been converted.
+  if (failed(rewriter.getImpl().convertBlockSignature(op->getBlock())))
+    return failure();
 
   LLVM_DEBUG(llvm::dbgs() << "Legalizing operation : " << op->getName()
                           << "\n");
@@ -1008,11 +971,15 @@ private:
   /// Converts an operation with the given rewriter.
   LogicalResult convert(ConversionPatternRewriter &rewriter, Operation *op);
 
-  /// Recursively collect all of the operations, to convert from within
-  /// 'region'.
+  /// Recursively collect all of the operations to convert from within 'region'.
   LogicalResult computeConversionSet(Region &region,
                                      std::vector<Operation *> &toConvert);
 
+  /// Converts the type signatures of the blocks nested within 'op' that have
+  /// yet to be converted.
+  LogicalResult convertBlockSignatures(ConversionPatternRewriter &rewriter,
+                                       Operation *op);
+
   /// The legalizer to use when converting operations.
   OperationLegalizer opLegalizer;
 
@@ -1021,7 +988,25 @@ private:
 };
 } // end anonymous namespace
 
-/// Recursively collect all of the blocks to convert from within 'region'.
+LogicalResult
+OperationConverter::convertBlockSignatures(ConversionPatternRewriter &rewriter,
+                                           Operation *op) {
+  SmallVector<Region *, 8> worklist;
+  for (auto &region : op->getRegions())
+    worklist.push_back(&region);
+
+  while (!worklist.empty()) {
+    for (auto &block : *worklist.pop_back_val()) {
+      if (failed(rewriter.getImpl().convertBlockSignature(&block)))
+        return failure();
+      for (auto &nestedOp : block)
+        for (auto &region : nestedOp.getRegions())
+          worklist.push_back(&region);
+    }
+  }
+  return success();
+}
+
 LogicalResult
 OperationConverter::computeConversionSet(Region &region,
                                          std::vector<Operation *> &toConvert) {
@@ -1055,7 +1040,6 @@ OperationConverter::computeConversionSet(Region &region,
   return success();
 }
 
-/// Converts an operation with the given rewriter.
 LogicalResult OperationConverter::convert(ConversionPatternRewriter &rewriter,
                                           Operation *op) {
   // Legalize the given operation.
@@ -1072,23 +1056,9 @@ LogicalResult OperationConverter::convert(ConversionPatternRewriter &rewriter,
              << "failed to legalize operation '" << op->getName()
              << "' that was explicitly marked illegal";
   }
-
-  // Convert the signature of any empty regions of this operation, non-empty
-  // regions are converted on demand when converting any operations contained
-  // within.
-  // FIXME(riverriddle) This should be replaced by patterns when the pattern
-  // rewriter exposes functionality to remap region signatures.
-  auto &rewriterImpl = rewriter.getImpl();
-  if (rewriterImpl.argConverter.typeConverter) {
-    for (auto &region : op->getRegions())
-      if (region.empty() && failed(rewriterImpl.convertRegionSignature(region)))
-        return failure();
-  }
-
   return success();
 }
 
-/// Converts the given operations to the conversion target.
 LogicalResult
 OperationConverter::convertOperations(ArrayRef<Operation *> ops,
                                       TypeConverter *typeConverter) {
@@ -1106,11 +1076,16 @@ OperationConverter::convertOperations(ArrayRef<Operation *> ops,
 
   // Convert each operation and discard rewrites on failure.
   ConversionPatternRewriter rewriter(ops.front()->getContext(), typeConverter);
-  for (auto *op : toConvert) {
-    if (failed(convert(rewriter, op))) {
-      rewriter.getImpl().discardRewrites();
-      return failure();
-    }
+  for (auto *op : toConvert)
+    if (failed(convert(rewriter, op)))
+      return rewriter.getImpl().discardRewrites(), failure();
+
+  // If a type converter was provided, ensure that all blocks have had their
+  // signatures properly converted.
+  if (typeConverter) {
+    for (auto *op : ops)
+      if (failed(convertBlockSignatures(rewriter, op)))
+        return rewriter.getImpl().discardRewrites(), failure();
   }
 
   // Otherwise the body conversion succeeded, so apply all rewrites.
@@ -1122,11 +1097,6 @@ OperationConverter::convertOperations(ArrayRef<Operation *> ops,
 // Type Conversion
 //===----------------------------------------------------------------------===//
 
-/// Append new result types to the signature conversion.
-void TypeConverter::SignatureConversion::addResults(ArrayRef<Type> results) {
-  resultTypes.append(results.begin(), results.end());
-}
-
 /// Remap an input of the original signature with a new set of types. The
 /// new types are appended to the new signature conversion.
 void TypeConverter::SignatureConversion::addInputs(unsigned origInputNo,
@@ -1164,33 +1134,31 @@ LogicalResult TypeConverter::convertType(Type t,
   return failure();
 }
 
-/// Convert the given FunctionType signature.
-auto TypeConverter::convertSignature(FunctionType type)
-    -> llvm::Optional<SignatureConversion> {
-  SignatureConversion result(type.getNumInputs());
-  if (failed(convertSignature(type, result)))
-    return llvm::None;
-  return result;
-}
-
-/// This hook allows for changing a FunctionType signature.
-LogicalResult TypeConverter::convertSignature(FunctionType type,
-                                              SignatureConversion &result) {
-  // Convert the original function arguments.
-  for (unsigned i = 0, e = type.getNumInputs(); i != e; ++i)
-    if (failed(convertSignatureArg(i, type.getInput(i), result)))
+/// Convert the given set of types, filling 'results' as necessary. This
+/// returns failure if the conversion of any of the types fails, success
+/// otherwise.
+LogicalResult TypeConverter::convertTypes(ArrayRef<Type> types,
+                                          SmallVectorImpl<Type> &results) {
+  for (auto type : types)
+    if (failed(convertType(type, results)))
       return failure();
+  return success();
+}
 
-  // Convert the original function results.
-  SmallVector<Type, 1> convertedTypes;
-  for (auto t : type.getResults()) {
-    convertedTypes.clear();
-    if (failed(convertType(t, convertedTypes)))
-      return failure();
-    result.addResults(convertedTypes);
-  }
+/// Return true if the given type is legal for this type converter, i.e. the
+/// type converts to itself.
+bool TypeConverter::isLegal(Type type) {
+  SmallVector<Type, 1> results;
+  return succeeded(convertType(type, results)) && results.size() == 1 &&
+         results.front() == type;
+}
 
-  return success();
+/// Return true if the inputs and outputs of the given function type are
+/// legal.
+bool TypeConverter::isSignatureLegal(FunctionType funcType) {
+  return llvm::all_of(
+      llvm::concat<const Type>(funcType.getInputs(), funcType.getResults()),
+      [this](Type type) { return isLegal(type); });
 }
 
 /// This hook allows for converting a specific argument of a signature.
@@ -1210,34 +1178,55 @@ LogicalResult TypeConverter::convertSignatureArg(unsigned inputNo, Type type,
   return success();
 }
 
-/// This hook defines how the signature of a region 'regionIdx', i.e. the
-/// signature of the entry to the region, on the given operation 'op' is
-/// converted. This function should return a valid conversion for the signature
-/// on success, None otherwise.
-///
-/// The default behavior of this function is to invoke 'convertBlockSignature'
-/// on the entry block, if one is present. This function also provides special
-/// handling for FuncOp to update the type signature.
-///
-/// TODO(riverriddle) This should be replaced in favor of using patterns, but
-/// the pattern rewriter needs to know how to properly replace/remap
-/// arguments.
-auto TypeConverter::convertRegionSignature(Operation *op, unsigned regionIdx)
-    -> llvm::Optional<SignatureConversion> {
-  // Provide explicit handling for FuncOp.
-  if (auto funcOp = dyn_cast<FuncOp>(op)) {
-    auto conversion = convertSignature(funcOp.getType());
-    if (conversion)
-      funcOp.setType(conversion->getConvertedType(funcOp.getContext()));
-    return conversion;
+/// Create a default conversion pattern that rewrites the type signature of a
+/// FuncOp.
+namespace {
+struct FuncOpSignatureConversion : public ConversionPattern {
+  FuncOpSignatureConversion(MLIRContext *ctx, TypeConverter &converter)
+      : ConversionPattern(FuncOp::getOperationName(), 1, ctx),
+        converter(converter) {}
+
+  /// Hook for derived classes to implement combined matching and rewriting.
+  PatternMatchResult
+  matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto funcOp = cast<FuncOp>(op);
+    FunctionType type = funcOp.getType();
+
+    // Convert the original function arguments.
+    TypeConverter::SignatureConversion result(type.getNumInputs());
+    for (unsigned i = 0, e = type.getNumInputs(); i != e; ++i)
+      if (failed(converter.convertSignatureArg(i, type.getInput(i), result)))
+        return matchFailure();
+
+    // Convert the original function results.
+    SmallVector<Type, 1> convertedResults;
+    if (failed(converter.convertTypes(type.getResults(), convertedResults)))
+      return matchFailure();
+
+    // Create a new function with an updated signature.
+    auto newFuncOp = rewriter.cloneWithoutRegions(funcOp);
+    rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
+                                newFuncOp.end());
+    newFuncOp.setType(FunctionType::get(result.getConvertedTypes(),
+                                        convertedResults, funcOp.getContext()));
+
+    // Tell the rewriter to convert the region signature.
+    rewriter.applySignatureConversion(&newFuncOp.getBody(), result);
+    rewriter.replaceOp(op, llvm::None);
+    return matchSuccess();
   }
 
-  // Otherwise, default to handle the arguments of the entry block for the given
-  // region.
-  auto &region = op->getRegion(regionIdx);
-  if (region.empty())
-    return SignatureConversion(/*numOrigInputs=*/0);
-  return convertBlockSignature(&region.front());
+  /// The type converter to use when rewriting the signature.
+  TypeConverter &converter;
+};
+} // end anonymous namespace
+
+void mlir::populateFuncOpTypeConversionPattern(
+    OwningRewritePatternList &patterns, MLIRContext *ctx,
+    TypeConverter &converter) {
+  RewriteListBuilder<FuncOpSignatureConversion>::build(patterns, ctx,
+                                                       converter);
 }
 
 /// This function converts the type signature of the given block, by invoking
index 9fb6cc9..c44e489 100644 (file)
@@ -49,13 +49,13 @@ func @remap_multi(%arg0: i64, %unused: i16, %arg1: i64) -> (i64, i64) {
  "test.invalid"(%arg0, %arg1) : (i64, i64) -> ()
 }
 
-// CHECK-LABEL: func @remap_nested
-func @remap_nested() {
+// CHECK-LABEL: func @no_remap_nested
+func @no_remap_nested() {
   // CHECK-NEXT: "foo.region"
   "foo.region"() ({
-    // CHECK-NEXT: ^bb0(%{{.*}}: f64, %{{.*}}: f64):
+    // CHECK-NEXT: ^bb0(%{{.*}}: i64, %{{.*}}: i16, %{{.*}}: i64):
     ^bb0(%i0: i64, %unused: i16, %i1: i64):
-      // CHECK-NEXT: "test.valid"{{.*}} : (f64, f64)
+      // CHECK-NEXT: "test.valid"{{.*}} : (i64, i64)
       "test.invalid"(%i0, %i1) : (i64, i64) -> ()
   }) : () -> ()
   return
index d452edb..1cbd253 100644 (file)
@@ -185,23 +185,26 @@ struct TestTypeConverter : public TypeConverter {
 struct TestLegalizePatternDriver
     : public ModulePass<TestLegalizePatternDriver> {
   void runOnModule() override {
+    TestTypeConverter converter;
     mlir::OwningRewritePatternList patterns;
     populateWithGenerated(&getContext(), &patterns);
     RewriteListBuilder<TestRegionRewriteBlockMovement, TestRegionRewriteUndo,
                        TestDropOp, TestPassthroughInvalidOp,
                        TestSplitReturnType>::build(patterns, &getContext());
+    mlir::populateFuncOpTypeConversionPattern(patterns, &getContext(),
+                                              converter);
 
     // Define the conversion target used for the test.
     ConversionTarget target(getContext());
-    target.addLegalOp<LegalOpA, TestValidOp>();
+    target.addLegalOp<LegalOpA, TestCastOp, TestValidOp>();
     target.addIllegalOp<ILLegalOpF, TestRegionBuilderOp>();
     target.addDynamicallyLegalOp<TestReturnOp>([](TestReturnOp op) {
       // Don't allow F32 operands.
       return llvm::none_of(op.getOperandTypes(),
                            [](Type type) { return type.isF32(); });
     });
-
-    TestTypeConverter converter;
+    target.addDynamicallyLegalOp<FuncOp>(
+        [&](FuncOp op) { return converter.isSignatureLegal(op.getType()); });
     (void)applyPartialConversion(getModule(), target, std::move(patterns),
                                  &converter);
   }