Refactor DialectConversion to convert the signatures of blocks when they are moved.
authorRiver Riddle <riverriddle@google.com>
Fri, 16 Aug 2019 17:16:09 +0000 (10:16 -0700)
committerA. Unique TensorFlower <gardener@tensorflow.org>
Fri, 16 Aug 2019 17:16:38 +0000 (10:16 -0700)
Often we want to ensure that block arguments are converted before operations that use them. This refactors the current implementation to be cleaner/less frequent by triggering conversion when a set of blocks are moved/inlined; or when legalization is successful.

PiperOrigin-RevId: 263795005

mlir/lib/Transforms/DialectConversion.cpp

index 5a4145a..adbed67 100644 (file)
@@ -499,11 +499,11 @@ void ConversionPatternRewriterImpl::applyRewrites() {
 LogicalResult
 ConversionPatternRewriterImpl::convertBlockSignature(Block *block) {
   // Check to see if this block should not be converted:
-  // * The block is invalid, or there is no type converter.
+  // * There is no type converter.
   // * The block has already been converted.
   // * This is an entry block, these are converted explicitly via patterns.
-  if (!block || !argConverter.typeConverter ||
-      argConverter.hasBeenConverted(block) || block->isEntryBlock())
+  if (!argConverter.typeConverter || argConverter.hasBeenConverted(block) ||
+      block->isEntryBlock())
     return success();
 
   // Otherwise, try to convert the block signature.
@@ -738,10 +738,6 @@ bool OperationLegalizer::isIllegal(Operation *op) const {
 LogicalResult
 OperationLegalizer::legalize(Operation *op,
                              ConversionPatternRewriter &rewriter) {
-  // Make sure that the signature of the parent block has been converted.
-  if (failed(rewriter.getImpl().convertBlockSignature(op->getBlock())))
-    return failure();
-
   LLVM_DEBUG(llvm::dbgs() << "Legalizing operation : " << op->getName()
                           << "\n");
 
@@ -802,6 +798,24 @@ OperationLegalizer::legalizePattern(Operation *op, RewritePattern *pattern,
     return cleanupFailure();
   }
 
+  // If the pattern moved any blocks, try to legalize their types. This ensures
+  // that the types of the block arguments are legal for the region they were
+  // moved into.
+  for (unsigned i = curState.numBlockActions,
+                e = rewriterImpl.blockActions.size();
+       i != e; ++i) {
+    auto &action = rewriterImpl.blockActions[i];
+    if (action.kind != ConversionPatternRewriterImpl::BlockActionKind::Move)
+      continue;
+
+    // Convert the block signature.
+    if (failed(rewriterImpl.convertBlockSignature(action.block))) {
+      LLVM_DEBUG(llvm::dbgs()
+                 << "-- FAIL: failed to convert types of moved block.\n");
+      return cleanupFailure();
+    }
+  }
+
   // Recursively legalize each of the new operations.
   for (unsigned i = curState.numCreatedOperations,
                 e = rewriterImpl.createdOps.size();
@@ -958,9 +972,9 @@ enum OpConversionMode {
   Analysis,
 };
 
-// This class converts operations using the given pattern matcher. If a
-// TypeConverter object is provided, then the types of block arguments will be
-// converted using the appropriate 'convertType' calls.
+// This class converts operations to a given conversion target via a set of
+// rewrite patterns. The conversion behaves differently depending on the
+// conversion mode.
 struct OperationConverter {
   explicit OperationConverter(ConversionTarget &target,
                               const OwningRewritePatternList &patterns,
@@ -981,8 +995,7 @@ private:
   LogicalResult computeConversionSet(Region &region,
                                      std::vector<Operation *> &toConvert);
 
-  /// Converts the type signatures of the blocks nested within 'op' that have
-  /// yet to be converted.
+  /// Converts the type signatures of the blocks nested within 'op'.
   LogicalResult convertBlockSignatures(ConversionPatternRewriter &rewriter,
                                        Operation *op);
 
@@ -1001,18 +1014,14 @@ private:
 LogicalResult
 OperationConverter::convertBlockSignatures(ConversionPatternRewriter &rewriter,
                                            Operation *op) {
-  SmallVector<Region *, 8> worklist;
-  for (auto &region : op->getRegions())
-    worklist.push_back(&region);
+  // Check to see if type signatures need to be converted.
+  if (!rewriter.getImpl().argConverter.typeConverter)
+    return success();
 
-  while (!worklist.empty()) {
-    for (auto &block : *worklist.pop_back_val()) {
+  for (auto &region : op->getRegions()) {
+    for (auto &block : region)
       if (failed(rewriter.getImpl().convertBlockSignature(&block)))
         return failure();
-      for (auto &nestedOp : block)
-        for (auto &region : nestedOp.getRegions())
-          worklist.push_back(&region);
-    }
   }
   return success();
 }
@@ -1065,10 +1074,17 @@ LogicalResult OperationConverter::convert(ConversionPatternRewriter &rewriter,
       return op->emitError()
              << "failed to legalize operation '" << op->getName()
              << "' that was explicitly marked illegal";
-  } else if (mode == OpConversionMode::Analysis) {
-    /// Analysis conversions don't fail if any operations fail to legalize, they
-    /// are only interested in the operations that were successfully legalized.
-    legalizableOps->insert(op);
+  } else {
+    /// Analysis conversions don't fail if any operations fail to legalize,
+    /// they are only interested in the operations that were successfully
+    /// legalized.
+    if (mode == OpConversionMode::Analysis)
+      legalizableOps->insert(op);
+
+    // If legalization succeeded, convert the types any of the blocks within
+    // this operation.
+    if (failed(convertBlockSignatures(rewriter, op)))
+      return failure();
   }
   return success();
 }
@@ -1094,14 +1110,6 @@ OperationConverter::convertOperations(ArrayRef<Operation *> ops,
     if (failed(convert(rewriter, op)))
       return rewriter.getImpl().discardRewrites(), failure();
 
-  // If a type converter was provided, ensure that all blocks have had their
-  // signatures properly converted.
-  if (typeConverter) {
-    for (auto *op : ops)
-      if (failed(convertBlockSignatures(rewriter, op)))
-        return rewriter.getImpl().discardRewrites(), failure();
-  }
-
   // Otherwise, the body conversion succeeded. Apply rewrites if this is not an
   // analysis conversion.
   if (mode == OpConversionMode::Analysis)