NFC: Refactor block signature conversion to not erase the original arguments.
authorRiver Riddle <riverriddle@google.com>
Wed, 13 Nov 2019 18:27:21 +0000 (10:27 -0800)
committerA. Unique TensorFlower <gardener@tensorflow.org>
Wed, 13 Nov 2019 18:27:53 +0000 (10:27 -0800)
This refactors the implementation of block signature(type) conversion to not insert fake cast operations to perform the type conversion, but to instead create a new block containing the proper signature. This has the benefit of enabling the use of pre-computed analyses that rely on mapping values. It also leads to a much cleaner implementation overall. The major user facing change is that applySignatureConversion will now replace the entry block of the region, meaning that blocks generally shouldn't be cached over calls to applySignatureConversion.

PiperOrigin-RevId: 280226936

mlir/include/mlir/Transforms/DialectConversion.h
mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
mlir/lib/Transforms/DialectConversion.cpp
mlir/test/Transforms/test-legalizer-full.mlir
mlir/test/Transforms/test-legalizer.mlir
mlir/test/lib/TestDialect/TestPatterns.cpp

index 06762e5..2deb0c9 100644 (file)
@@ -67,7 +67,7 @@ public:
     ArrayRef<Type> getConvertedTypes() const { return argTypes; }
 
     /// Get the input mapping for the given argument.
-    llvm::Optional<InputMapping> const &getInputMapping(unsigned input) const {
+    Optional<InputMapping> getInputMapping(unsigned input) const {
       return remappedInputs[input];
     }
 
@@ -322,9 +322,12 @@ public:
   ConversionPatternRewriter(MLIRContext *ctx, TypeConverter *converter);
   ~ConversionPatternRewriter() override;
 
-  /// Apply a signature conversion to the entry block of the given region.
-  void applySignatureConversion(Region *region,
-                                TypeConverter::SignatureConversion &conversion);
+  /// Apply a signature conversion to the entry block of the given region. This
+  /// replaces the entry block with a new block containing the updated
+  /// signature. The new entry block to the region is returned for convenience.
+  Block *
+  applySignatureConversion(Region *region,
+                           TypeConverter::SignatureConversion &conversion);
 
   /// Replace all the uses of the block argument `from` with value `to`.
   void replaceUsesOfBlockArgument(BlockArgument *from, Value *to);
index bf35c21..9a4d9bb 100644 (file)
@@ -98,8 +98,8 @@ ForOpConversion::matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
   TypeConverter::SignatureConversion signatureConverter(
       body->getNumArguments());
   signatureConverter.remapInput(0, newIndVar);
-  rewriter.applySignatureConversion(&forOp.getOperation()->getRegion(0),
-                                    signatureConverter);
+  body = rewriter.applySignatureConversion(&forOp.getLoopBody(),
+                                           signatureConverter);
 
   // Delete the loop terminator.
   rewriter.eraseOp(body->getTerminator());
index 6ddc045..a2065f1 100644 (file)
@@ -116,61 +116,90 @@ Value *ConversionValueMapping::lookupOrDefault(Value *from) const {
 //===----------------------------------------------------------------------===//
 namespace {
 /// This class provides a simple interface for converting the types of block
-/// arguments. This is done by inserting fake cast operations that map from the
-/// illegal type to the original type to allow for undoing pending rewrites in
-/// the case of failure.
+/// arguments. This is done by creating a new block that contains the new legal
+/// types and extracting the block that contains the old illegal types to allow
+/// for undoing pending rewrites in the case of failure.
 struct ArgConverter {
   ArgConverter(TypeConverter *typeConverter, PatternRewriter &rewriter)
-      : castOpName(kCastName, rewriter.getContext()),
-        loc(rewriter.getUnknownLoc()), typeConverter(typeConverter),
+      : loc(rewriter.getUnknownLoc()), typeConverter(typeConverter),
         rewriter(rewriter) {}
 
-  /// Erase any rewrites registered for arguments to blocks within the given
-  /// region. This function is called when the given region is to be destroyed.
-  void cancelPendingRewrites(Block *block);
+  /// This structure contains the information pertaining to an argument that has
+  /// been converted.
+  struct ConvertedArgInfo {
+    ConvertedArgInfo(unsigned newArgIdx, unsigned newArgSize,
+                     Value *castValue = nullptr)
+        : newArgIdx(newArgIdx), newArgSize(newArgSize), castValue(castValue) {}
 
-  /// Cleanup and undo any generated conversions for the arguments of block.
-  /// This method differs from 'cancelPendingRewrites' in that it returns the
-  /// block signature to its original state.
-  void discardPendingRewrites(Block *block);
+    /// The start index of in the new argument list that contains arguments that
+    /// replace the original.
+    unsigned newArgIdx;
 
-  /// Replace usages of the cast operations with the argument directly.
-  void applyRewrites();
+    /// The number of arguments that replaced the original argument.
+    unsigned newArgSize;
+
+    /// The cast value that was created to cast from the new arguments to the
+    /// old. This only used if 'newArgSize' > 1.
+    Value *castValue;
+  };
+
+  /// This structure contains information pertaining to a block that has had its
+  /// signature converted.
+  struct ConvertedBlockInfo {
+    ConvertedBlockInfo(Block *origBlock) : origBlock(origBlock) {}
+
+    /// The original block that was requested to have its signature converted.
+    Block *origBlock;
+
+    /// The conversion information for each of the arguments. The information is
+    /// None if the argument was dropped during conversion.
+    SmallVector<Optional<ConvertedArgInfo>, 1> argInfo;
+  };
 
   /// Return if the signature of the given block has already been converted.
-  bool hasBeenConverted(Block *block) const { return argMapping.count(block); }
+  bool hasBeenConverted(Block *block) const {
+    return conversionInfo.count(block);
+  }
+
+  //===--------------------------------------------------------------------===//
+  // Rewrite Application
+  //===--------------------------------------------------------------------===//
 
-  /// Attempt to convert the signature of the given block.
-  LogicalResult convertSignature(Block *block, ConversionValueMapping &mapping);
+  /// Erase any rewrites registered for the current block that is about to be
+  /// removed. This merely drops the rewrites without undoing them.
+  void notifyBlockRemoved(Block *block);
 
-  /// Apply the given signature conversion on the given block.
-  void applySignatureConversion(
+  /// Cleanup and undo any generated conversions for the arguments of block.
+  /// This method replaces the new block with the original, reverting the IR to
+  /// its original state.
+  void discardRewrites(Block *block);
+
+  /// Fully replace uses of the old arguments with the new, materializing cast
+  /// operations as necessary.
+  // FIXME(riverriddle) The 'mapping' parameter is only necessary because the
+  // implementation of replaceUsesOfBlockArgument is buggy.
+  void applyRewrites(ConversionValueMapping &mapping);
+
+  //===--------------------------------------------------------------------===//
+  // Conversion
+  //===--------------------------------------------------------------------===//
+
+  /// Attempt to convert the signature of the given block, if successful a new
+  /// block is returned containing the new arguments. On failure, nullptr is
+  /// returned.
+  Block *convertSignature(Block *block, ConversionValueMapping &mapping);
+
+  /// Apply the given signature conversion on the given block. The new block
+  /// containing the updated signature is returned.
+  Block *applySignatureConversion(
       Block *block, TypeConverter::SignatureConversion &signatureConversion,
       ConversionValueMapping &mapping);
 
-  /// Convert the given block argument given the provided set of new argument
-  /// values that are to replace it. This function returns the operation used
-  /// to perform the conversion.
-  Operation *convertArgument(BlockArgument *origArg,
-                             ArrayRef<Value *> newValues,
-                             ConversionValueMapping &mapping);
-
-  /// A utility function used to create a conversion cast operation with the
-  /// given input and result types.
-  Operation *createCast(ArrayRef<Value *> inputs, Type outputType);
-
-  /// This is an operation name for a fake operation that is inserted during the
-  /// conversion process. Operations of this type are guaranteed to never escape
-  /// the converter.
-  static constexpr StringLiteral kCastName = "__mlir_conversion.cast";
-  OperationName castOpName;
-
-  /// This is a collection of cast operations that were generated during the
-  /// conversion process when converting the types of block arguments.
-  llvm::MapVector<Block *, SmallVector<Operation *, 4>> argMapping;
-
-  /// An instance of the unknown location that is used when generating
-  /// producers.
+  /// A collection of blocks that have had their arguments converted.
+  llvm::MapVector<Block *, ConvertedBlockInfo> conversionInfo;
+
+  /// An instance of the unknown location that is used when materializing
+  /// conversions.
   Location loc;
 
   /// The type converter to use when changing types.
@@ -181,66 +210,64 @@ struct ArgConverter {
 };
 } // end anonymous namespace
 
-constexpr StringLiteral ArgConverter::kCastName;
+//===----------------------------------------------------------------------===//
+// Rewrite Application
 
-/// Erase any rewrites registered for arguments to the given block.
-void ArgConverter::cancelPendingRewrites(Block *block) {
-  auto it = argMapping.find(block);
-  if (it == argMapping.end())
+void ArgConverter::notifyBlockRemoved(Block *block) {
+  auto it = conversionInfo.find(block);
+  if (it == conversionInfo.end())
     return;
-  for (auto *op : it->second) {
-    op->dropAllDefinedValueUses();
-    op->erase();
-  }
-  argMapping.erase(it);
+
+  // Drop all uses of the original arguments and delete the original block.
+  Block *origBlock = it->second.origBlock;
+  for (BlockArgument *arg : origBlock->getArguments())
+    arg->dropAllUses();
+  delete origBlock;
+
+  conversionInfo.erase(it);
 }
 
-/// Cleanup and undo any generated conversions for the arguments of block.
-/// This method differs from 'cancelPendingRewrites' in that it returns the
-/// block signature to its original state.
-void ArgConverter::discardPendingRewrites(Block *block) {
-  auto it = argMapping.find(block);
-  if (it == argMapping.end())
+void ArgConverter::discardRewrites(Block *block) {
+  auto it = conversionInfo.find(block);
+  if (it == conversionInfo.end())
     return;
+  Block *origBlock = it->second.origBlock;
 
-  // Erase all of the new arguments.
-  for (int i = block->getNumArguments() - 1; i >= 0; --i) {
+  // Drop all uses of the new block arguments and replace uses of the new block.
+  for (int i = block->getNumArguments() - 1; i >= 0; --i)
     block->getArgument(i)->dropAllUses();
-    block->eraseArgument(i, /*updatePredTerms=*/false);
-  }
+  block->replaceAllUsesWith(origBlock);
 
-  // Re-instate the old arguments.
-  auto &mapping = it->second;
-  for (unsigned i = 0, e = mapping.size(); i != e; ++i) {
-    auto *op = mapping[i];
-    auto *arg = block->addArgument(op->getResult(0)->getType());
-    op->getResult(0)->replaceAllUsesWith(arg);
+  // Move the operations back the original block and the delete the new block.
+  origBlock->getOperations().splice(origBlock->end(), block->getOperations());
+  origBlock->insertBefore(block);
+  block->erase();
 
-    // If this operation is within a block, it will be cleaned up automatically.
-    if (!op->getBlock())
-      op->erase();
-  }
-  argMapping.erase(it);
+  conversionInfo.erase(it);
 }
 
-/// Replace usages of the cast operations with the argument directly.
-void ArgConverter::applyRewrites() {
-  Block *block;
-  ArrayRef<Operation *> argOps;
-  for (auto &mapping : argMapping) {
-    std::tie(block, argOps) = mapping;
+void ArgConverter::applyRewrites(ConversionValueMapping &mapping) {
+  for (auto &info : conversionInfo) {
+    Block *newBlock = info.first;
+    ConvertedBlockInfo &blockInfo = info.second;
+    Block *origBlock = blockInfo.origBlock;
 
     // Process the remapping for each of the original arguments.
-    for (unsigned i = 0, e = argOps.size(); i != e; ++i) {
-      auto *op = argOps[i];
-
-      // Handle the case of a 1->N value mapping.
-      if (op->getNumOperands() > 1) {
-        // If all of the uses were removed, we can drop this op. Otherwise,
-        // keep the operation alive and let the user handle any remaining
-        // usages.
-        if (op->use_empty())
-          op->erase();
+    for (unsigned i = 0, e = origBlock->getNumArguments(); i != e; ++i) {
+      Optional<ConvertedArgInfo> &argInfo = blockInfo.argInfo[i];
+      BlockArgument *origArg = origBlock->getArgument(i);
+
+      // Handle the case of a 1->0 value mapping.
+      if (!argInfo) {
+        // If there are any dangling uses then replace the argument with one
+        // generated by the type converter. This is necessary as the cast must
+        // persist in the IR after conversion.
+        if (!origArg->use_empty()) {
+          rewriter.setInsertionPointToStart(newBlock);
+          auto *newOp = typeConverter->materializeConversion(
+              rewriter, origArg->getType(), llvm::None, loc);
+          origArg->replaceAllUsesWith(newOp->getResult(0));
+        }
         continue;
       }
 
@@ -248,113 +275,105 @@ void ArgConverter::applyRewrites() {
       // operation.
       // FIXME(riverriddle) This should check that the result type and operand
       // type are the same, otherwise it should force a conversion to be
-      // materialized. This works around a current limitation with regards to
-      // region entry argument type conversion.
-      if (op->getNumOperands() == 1) {
-        op->getResult(0)->replaceAllUsesWith(op->getOperand(0));
-        op->destroy();
+      // materialized.
+      if (argInfo->newArgSize == 1) {
+        origArg->replaceAllUsesWith(
+            mapping.lookupOrDefault(newBlock->getArgument(argInfo->newArgIdx)));
         continue;
       }
 
-      // Otherwise, if there are any dangling uses then replace the fake
-      // conversion operation with one generated by the type converter. This
-      // is necessary as the cast must persist in the IR after conversion.
-      auto *opResult = op->getResult(0);
-      if (!opResult->use_empty()) {
-        rewriter.setInsertionPointToStart(block);
-        SmallVector<Value *, 1> operands(op->getOperands());
-        auto *newOp = typeConverter->materializeConversion(
-            rewriter, opResult->getType(), operands, op->getLoc());
-        opResult->replaceAllUsesWith(newOp->getResult(0));
-      }
-      op->destroy();
+      // Otherwise this is a 1->N value mapping.
+      Value *castValue = argInfo->castValue;
+      assert(argInfo->newArgSize > 1 && castValue && "expected 1->N mapping");
+
+      // If the argument is still used, replace it with the generated cast.
+      if (!origArg->use_empty())
+        origArg->replaceAllUsesWith(mapping.lookupOrDefault(castValue));
+
+      // If all users of the cast were removed, we can drop it. Otherwise, keep
+      // the operation alive and let the user handle any remaining usages.
+      if (castValue->use_empty())
+        castValue->getDefiningOp()->erase();
     }
+
+    // Drop the original block now the rewrites were applied.
+    delete origBlock;
   }
 }
 
-/// Converts the signature of the given entry block.
-LogicalResult ArgConverter::convertSignature(Block *block,
-                                             ConversionValueMapping &mapping) {
+//===----------------------------------------------------------------------===//
+// Conversion
+
+Block *ArgConverter::convertSignature(Block *block,
+                                      ConversionValueMapping &mapping) {
   if (auto conversion = typeConverter->convertBlockSignature(block))
-    return applySignatureConversion(block, *conversion, mapping), success();
-  return failure();
+    return applySignatureConversion(block, *conversion, mapping);
+  return nullptr;
 }
 
-/// Apply the given signature conversion on the given block.
-void ArgConverter::applySignatureConversion(
+Block *ArgConverter::applySignatureConversion(
     Block *block, TypeConverter::SignatureConversion &signatureConversion,
     ConversionValueMapping &mapping) {
+  // If no arguments are being changed or added, there is nothing to do.
   unsigned origArgCount = block->getNumArguments();
   auto convertedTypes = signatureConversion.getConvertedTypes();
   if (origArgCount == 0 && convertedTypes.empty())
-    return;
+    return block;
 
-  SmallVector<Value *, 4> newArgRange(block->addArguments(convertedTypes));
-  ArrayRef<Value *> newArgRef(newArgRange);
+  // Split the block at the beginning to get a new block to use for the updated
+  // signature.
+  Block *newBlock = block->splitBlock(block->begin());
+  block->replaceAllUsesWith(newBlock);
+
+  SmallVector<Value *, 4> newArgRange(newBlock->addArguments(convertedTypes));
+  ArrayRef<Value *> newArgs(newArgRange);
 
   // Remap each of the original arguments as determined by the signature
   // conversion.
-  auto &newArgMapping = argMapping[block];
+  ConvertedBlockInfo info(block);
+  info.argInfo.resize(origArgCount);
+
   OpBuilder::InsertionGuard guard(rewriter);
-  rewriter.setInsertionPointToStart(block);
+  rewriter.setInsertionPointToStart(newBlock);
   for (unsigned i = 0; i != origArgCount; ++i) {
-    ArrayRef<Value *> remappedValues;
-    if (auto &inputMap = signatureConversion.getInputMapping(i)) {
-      // If inputMap->replacementValue is not nullptr, then the argument is
-      // dropped and a replacement value is provided to be the remappedValue.
-      if (inputMap->replacementValue) {
-        assert(inputMap->size == 0 &&
-               "invalid to provide a replacement value when the argument isn't "
-               "dropped");
-        remappedValues = inputMap->replacementValue;
-      } else
-        remappedValues = newArgRef.slice(inputMap->inputNo, inputMap->size);
+    auto inputMap = signatureConversion.getInputMapping(i);
+    if (!inputMap)
+      continue;
+    BlockArgument *origArg = block->getArgument(i);
+
+    // If inputMap->replacementValue is not nullptr, then the argument is
+    // dropped and a replacement value is provided to be the remappedValue.
+    if (inputMap->replacementValue) {
+      assert(inputMap->size == 0 &&
+             "invalid to provide a replacement value when the argument isn't "
+             "dropped");
+      mapping.map(origArg, inputMap->replacementValue);
+      continue;
     }
 
-    BlockArgument *arg = block->getArgument(i);
-    newArgMapping.push_back(convertArgument(arg, remappedValues, mapping));
-  }
+    // If this is a 1->1 mapping, then map the argument directly.
+    if (inputMap->size == 1) {
+      mapping.map(origArg, newArgs[inputMap->inputNo]);
+      info.argInfo[i] = ConvertedArgInfo(inputMap->inputNo, inputMap->size);
+      continue;
+    }
 
-  // Erase all of the original arguments.
-  for (unsigned i = 0; i != origArgCount; ++i)
-    block->eraseArgument(0, /*updatePredTerms=*/false);
-}
-
-/// Convert the given block argument given the provided set of new argument
-/// values that are to replace it. This function returns the operation used
-/// to perform the conversion.
-Operation *ArgConverter::convertArgument(BlockArgument *origArg,
-                                         ArrayRef<Value *> newValues,
-                                         ConversionValueMapping &mapping) {
-  // Handle the cases of 1->0 or 1->1 mappings.
-  if (newValues.size() < 2) {
-    // Create a temporary producer for the argument during the conversion
-    // process.
-    auto *cast = createCast(newValues, origArg->getType());
-    origArg->replaceAllUsesWith(cast->getResult(0));
-
-    // Insert a mapping between this argument and the one that is replacing
-    // it.
-    if (!newValues.empty())
-      mapping.map(cast->getResult(0), newValues[0]);
-    return cast;
+    // Otherwise, this is a 1->N mapping. Call into the provided type converter
+    // to pack the new values.
+    auto replArgs = newArgs.slice(inputMap->inputNo, inputMap->size);
+    Operation *cast = typeConverter->materializeConversion(
+        rewriter, origArg->getType(), replArgs, loc);
+    assert(cast->getNumResults() == 1 &&
+           cast->getNumOperands() == replArgs.size());
+    mapping.map(origArg, cast->getResult(0));
+    info.argInfo[i] =
+        ConvertedArgInfo(inputMap->inputNo, inputMap->size, cast->getResult(0));
   }
 
-  // Otherwise, this is a 1->N mapping. Call into the provided type converter
-  // to pack the new values.
-  auto *cast = typeConverter->materializeConversion(
-      rewriter, origArg->getType(), newValues, loc);
-  assert(cast->getNumResults() == 1 &&
-         cast->getNumOperands() == newValues.size());
-  origArg->replaceAllUsesWith(cast->getResult(0));
-  return cast;
-}
-
-/// A utility function used to create a conversion cast operation with the
-/// given input and result types.
-Operation *ArgConverter::createCast(ArrayRef<Value *> inputs, Type outputType) {
-  return Operation::create(loc, castOpName, outputType, inputs, llvm::None,
-                           llvm::None, 0, false);
+  // Remove the original block from the region and return the new one.
+  newBlock->getParent()->getBlocks().remove(block);
+  conversionInfo.insert({newBlock, std::move(info)});
+  return newBlock;
 }
 
 //===----------------------------------------------------------------------===//
@@ -470,8 +489,9 @@ struct ConversionPatternRewriterImpl {
   LogicalResult convertBlockSignature(Block *block);
 
   /// Apply a signature conversion on the given region.
-  void applySignatureConversion(Region *region,
-                                TypeConverter::SignatureConversion &conversion);
+  Block *
+  applySignatureConversion(Region *region,
+                           TypeConverter::SignatureConversion &conversion);
 
   /// PatternRewriter hook for replacing the results of an operation.
   void replaceOp(Operation *op, ArrayRef<Value *> newValues,
@@ -589,7 +609,7 @@ void ConversionPatternRewriterImpl::undoBlockActions(
     }
     // Undo the type conversion.
     case BlockActionKind::TypeConversion: {
-      argConverter.discardPendingRewrites(action.block);
+      argConverter.discardRewrites(action.block);
       break;
     }
     }
@@ -619,7 +639,7 @@ void ConversionPatternRewriterImpl::applyRewrites() {
     if (argConverter.typeConverter && repl.op->getNumRegions()) {
       for (auto &region : repl.op->getRegions())
         for (auto &block : region)
-          argConverter.cancelPendingRewrites(&block);
+          argConverter.notifyBlockRemoved(&block);
     }
   }
 
@@ -629,7 +649,7 @@ void ConversionPatternRewriterImpl::applyRewrites() {
   for (auto &repl : llvm::reverse(replacements))
     repl.op->erase();
 
-  argConverter.applyRewrites();
+  argConverter.applyRewrites(mapping);
 }
 
 LogicalResult
@@ -639,23 +659,25 @@ ConversionPatternRewriterImpl::convertBlockSignature(Block *block) {
   // * The block has already been converted.
   // * This is an entry block, these are converted explicitly via patterns.
   if (!argConverter.typeConverter || argConverter.hasBeenConverted(block) ||
-      block->isEntryBlock())
+      !block->getParent() || block->isEntryBlock())
     return success();
 
   // Otherwise, try to convert the block signature.
-  if (failed(argConverter.convertSignature(block, mapping)))
-    return failure();
-  blockActions.push_back(BlockAction::getTypeConversion(block));
-  return success();
+  Block *newBlock = argConverter.convertSignature(block, mapping);
+  if (newBlock)
+    blockActions.push_back(BlockAction::getTypeConversion(newBlock));
+  return success(newBlock);
 }
 
-void ConversionPatternRewriterImpl::applySignatureConversion(
+Block *ConversionPatternRewriterImpl::applySignatureConversion(
     Region *region, TypeConverter::SignatureConversion &conversion) {
   if (!region->empty()) {
-    argConverter.applySignatureConversion(&region->front(), conversion,
-                                          mapping);
-    blockActions.push_back(BlockAction::getTypeConversion(&region->front()));
+    Block *newEntry = argConverter.applySignatureConversion(
+        &region->front(), conversion, mapping);
+    blockActions.push_back(BlockAction::getTypeConversion(newEntry));
+    return newEntry;
   }
+  return nullptr;
 }
 
 void ConversionPatternRewriterImpl::replaceOp(
@@ -759,9 +781,9 @@ void ConversionPatternRewriter::eraseOp(Operation *op) {
 }
 
 /// Apply a signature conversion to the entry block of the given region.
-void ConversionPatternRewriter::applySignatureConversion(
+Block *ConversionPatternRewriter::applySignatureConversion(
     Region *region, TypeConverter::SignatureConversion &conversion) {
-  impl->applySignatureConversion(region, conversion);
+  return impl->applySignatureConversion(region, conversion);
 }
 
 void ConversionPatternRewriter::replaceUsesOfBlockArgument(BlockArgument *from,
@@ -1266,7 +1288,7 @@ OperationConverter::convertBlockSignatures(ConversionPatternRewriter &rewriter,
     return success();
 
   for (auto &region : op->getRegions()) {
-    for (auto &block : region)
+    for (auto &block : llvm::make_early_inc_range(region))
       if (failed(rewriter.getImpl().convertBlockSignature(&block)))
         return failure();
   }
index 0b65868..d0fc4c9 100644 (file)
@@ -13,7 +13,7 @@ func @multi_level_mapping() {
 // CHECK-LABEL: func @dropped_region_with_illegal_ops
 func @dropped_region_with_illegal_ops() {
   // CHECK-NEXT: test.return
-  "test.drop_op"() ({
+  "test.drop_region_op"() ({
     %ignored = "test.illegal_op_f"() : () -> (i32)
     "test.return"() : () -> ()
   }) : () -> ()
index 1e56325..efb59b0 100644 (file)
@@ -91,7 +91,7 @@ func @remap_cloned_region_args() {
 func @remap_drop_region() {
   // CHECK-NEXT: return
   // CHECK-NEXT: }
-  "test.drop_op"() ({
+  "test.drop_region_op"() ({
     ^bb1(%i0: i64, %unused: i16, %i1: i64, %2: f32):
       "test.invalid"(%i0, %i1, %2) : (i64, i64, f32) -> ()
   }) : () -> ()
index cc935c7..936d763 100644 (file)
@@ -162,15 +162,32 @@ struct TestRegionRewriteUndo : public RewritePattern {
 //===----------------------------------------------------------------------===//
 // Type-Conversion Rewrite Testing
 
-/// This pattern simply erases the given operation.
-struct TestDropOp : public ConversionPattern {
-  TestDropOp(MLIRContext *ctx) : ConversionPattern("test.drop_op", 1, ctx) {}
+/// This patterns erases a region operation that has had a type conversion.
+struct TestDropOpSignatureConversion : public ConversionPattern {
+  TestDropOpSignatureConversion(MLIRContext *ctx, TypeConverter &converter)
+      : ConversionPattern("test.drop_region_op", 1, ctx), converter(converter) {
+  }
   PatternMatchResult
   matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
-                  ConversionPatternRewriter &rewriter) const final {
+                  ConversionPatternRewriter &rewriter) const override {
+    Region &region = op->getRegion(0);
+    Block *entry = &region.front();
+
+    // Convert the original entry arguments.
+    TypeConverter::SignatureConversion result(entry->getNumArguments());
+    for (unsigned i = 0, e = entry->getNumArguments(); i != e; ++i)
+      if (failed(converter.convertSignatureArg(
+              i, entry->getArgument(i)->getType(), result)))
+        return matchFailure();
+
+    // Convert the region signature and just drop the operation.
+    rewriter.applySignatureConversion(&region, result);
     rewriter.eraseOp(op);
     return matchSuccess();
   }
+
+  /// The type converter to use when rewriting the signature.
+  TypeConverter &converter;
 };
 /// This pattern simply updates the operands of the given operation.
 struct TestPassthroughInvalidOp : public ConversionPattern {
@@ -334,10 +351,11 @@ struct TestLegalizePatternDriver
     populateWithGenerated(&getContext(), &patterns);
     patterns
         .insert<TestRegionRewriteBlockMovement, TestRegionRewriteUndo,
-                TestDropOp, TestPassthroughInvalidOp, TestSplitReturnType,
+                TestPassthroughInvalidOp, TestSplitReturnType,
                 TestChangeProducerTypeI32ToF32, TestChangeProducerTypeF32ToF64,
                 TestChangeProducerTypeF32ToInvalid, TestUpdateConsumerType,
                 TestNonRootReplacement>(&getContext());
+    patterns.insert<TestDropOpSignatureConversion>(&getContext(), converter);
     mlir::populateFuncOpTypeConversionPattern(patterns, &getContext(),
                                               converter);