[mlir] Keep track of region signature conversions as argument replacements
authorAlex Zinenko <zinenko@google.com>
Fri, 29 Jan 2021 18:41:10 +0000 (19:41 +0100)
committerAlex Zinenko <zinenko@google.com>
Tue, 2 Feb 2021 09:38:31 +0000 (10:38 +0100)
In dialect conversion, signature conversions essentially perform block argument
replacement and are added to the general value remapping. However, the replaced
values were not tracked, so if a signature conversion was rolled back, the
construction of operand lists for the following patterns could have obtained
block arguments from the mapping and give them to the pattern leading to
use-after-free. Keep track of signature conversions similarly to normal block
argument replacement, and erase such replacements from the general mapping when
the conversion is rolled back.

Reviewed By: rriddle

Differential Revision: https://reviews.llvm.org/D95688

mlir/lib/Transforms/Utils/DialectConversion.cpp
mlir/test/Transforms/test-legalize-type-conversion.mlir
mlir/test/lib/Dialect/Test/TestOps.td
mlir/test/lib/Dialect/Test/TestPatterns.cpp

index ae62a63..ecbd57d 100644 (file)
@@ -256,8 +256,10 @@ struct ArgConverter {
   /// Attempt to convert the signature of the given block, if successful a new
   /// block is returned containing the new arguments. Returns `block` if it did
   /// not require conversion.
-  FailureOr<Block *> convertSignature(Block *block, TypeConverter &converter,
-                                      ConversionValueMapping &mapping);
+  FailureOr<Block *>
+  convertSignature(Block *block, TypeConverter &converter,
+                   ConversionValueMapping &mapping,
+                   SmallVectorImpl<BlockArgument> &argReplacements);
 
   /// Apply the given signature conversion on the given block. The new block
   /// containing the updated signature is returned. If no conversions were
@@ -268,7 +270,8 @@ struct ArgConverter {
   Block *applySignatureConversion(
       Block *block, TypeConverter &converter,
       TypeConverter::SignatureConversion &signatureConversion,
-      ConversionValueMapping &mapping);
+      ConversionValueMapping &mapping,
+      SmallVectorImpl<BlockArgument> &argReplacements);
 
   /// Insert a new conversion into the cache.
   void insertConversion(Block *newBlock, ConvertedBlockInfo &&info);
@@ -425,9 +428,9 @@ LogicalResult ArgConverter::materializeLiveConversions(
 //===----------------------------------------------------------------------===//
 // Conversion
 
-FailureOr<Block *>
-ArgConverter::convertSignature(Block *block, TypeConverter &converter,
-                               ConversionValueMapping &mapping) {
+FailureOr<Block *> ArgConverter::convertSignature(
+    Block *block, TypeConverter &converter, ConversionValueMapping &mapping,
+    SmallVectorImpl<BlockArgument> &argReplacements) {
   // Check if the block was already converted. If the block is detached,
   // conservatively assume it is going to be deleted.
   if (hasBeenConverted(block) || !block->getParent())
@@ -435,14 +438,16 @@ ArgConverter::convertSignature(Block *block, TypeConverter &converter,
 
   // Try to convert the signature for the block with the provided converter.
   if (auto conversion = converter.convertBlockSignature(block))
-    return applySignatureConversion(block, converter, *conversion, mapping);
+    return applySignatureConversion(block, converter, *conversion, mapping,
+                                    argReplacements);
   return failure();
 }
 
 Block *ArgConverter::applySignatureConversion(
     Block *block, TypeConverter &converter,
     TypeConverter::SignatureConversion &signatureConversion,
-    ConversionValueMapping &mapping) {
+    ConversionValueMapping &mapping,
+    SmallVectorImpl<BlockArgument> &argReplacements) {
   // If no arguments are being changed or added, there is nothing to do.
   unsigned origArgCount = block->getNumArguments();
   auto convertedTypes = signatureConversion.getConvertedTypes();
@@ -477,6 +482,7 @@ Block *ArgConverter::applySignatureConversion(
              "invalid to provide a replacement value when the argument isn't "
              "dropped");
       mapping.map(origArg, inputMap->replacementValue);
+      argReplacements.push_back(origArg);
       continue;
     }
 
@@ -492,6 +498,7 @@ Block *ArgConverter::applySignatureConversion(
       newArg = replArgs.front();
     }
     mapping.map(origArg, newArg);
+    argReplacements.push_back(origArg);
     info.argInfo[i] =
         ConvertedArgInfo(inputMap->inputNo, inputMap->size, newArg);
   }
@@ -1113,9 +1120,10 @@ FailureOr<Block *> ConversionPatternRewriterImpl::convertBlockSignature(
     Block *block, TypeConverter &converter,
     TypeConverter::SignatureConversion *conversion) {
   FailureOr<Block *> result =
-      conversion ? argConverter.applySignatureConversion(block, converter,
-                                                         *conversion, mapping)
-                 : argConverter.convertSignature(block, converter, mapping);
+      conversion ? argConverter.applySignatureConversion(
+                       block, converter, *conversion, mapping, argReplacements)
+                 : argConverter.convertSignature(block, converter, mapping,
+                                                 argReplacements);
   if (Block *newBlock = result.getValue()) {
     if (newBlock != block)
       blockActions.push_back(BlockAction::getTypeConversion(newBlock));
index c56b3c8..1ea8ddc 100644 (file)
@@ -62,3 +62,18 @@ func @test_valid_result_legalization() {
   %result = "test.type_producer"() : () -> f32
   "foo.return"(%result) : (f32) -> ()
 }
+
+// -----
+
+// Should not segfault here but gracefully fail.
+// CHECK-LABEL: func @test_signature_conversion_undo
+func @test_signature_conversion_undo() {
+  // CHECK: test.signature_conversion_undo
+  "test.signature_conversion_undo"() ({
+  // CHECK: ^{{.*}}(%{{.*}}: f32):
+  ^bb0(%arg0: f32):
+    "test.type_consumer"(%arg0) : (f32) -> ()
+    "test.return"(%arg0) : (f32) -> ()
+  }) : () -> ()
+  return
+}
index 89d2ee8..ec836f2 100644 (file)
@@ -1289,6 +1289,10 @@ def TestMergeBlocksOp : TEST_Op<"merge_blocks"> {
   let results = (outs Variadic<AnyType>:$result);
 }
 
+def TestSignatureConversionUndoOp : TEST_Op<"signature_conversion_undo"> {
+  let regions = (region AnyRegion);
+}
+
 //===----------------------------------------------------------------------===//
 // Test parser.
 //===----------------------------------------------------------------------===//
index a4c16a6..7da02ed 100644 (file)
@@ -774,6 +774,34 @@ struct TestTypeConversionProducer
   }
 };
 
+/// Call signature conversion and then fail the rewrite to trigger the undo
+/// mechanism.
+struct TestSignatureConversionUndo
+    : public OpConversionPattern<TestSignatureConversionUndoOp> {
+  using OpConversionPattern<TestSignatureConversionUndoOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(TestSignatureConversionUndoOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const final {
+    (void)rewriter.convertRegionTypes(&op->getRegion(0), *getTypeConverter());
+    return failure();
+  }
+};
+
+/// Just forward the operands to the root op. This is essentially a no-op
+/// pattern that is used to trigger target materialization.
+struct TestTypeConsumerForward
+    : public OpConversionPattern<TestTypeConsumerOp> {
+  using OpConversionPattern<TestTypeConsumerOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(TestTypeConsumerOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const final {
+    rewriter.updateRootInPlace(op, [&] { op->setOperands(operands); });
+    return success();
+  }
+};
+
 struct TestTypeConversionDriver
     : public PassWrapper<TestTypeConversionDriver, OperationPass<ModuleOp>> {
   void getDependentDialects(DialectRegistry &registry) const override {
@@ -836,7 +864,8 @@ struct TestTypeConversionDriver
 
     // Initialize the set of rewrite patterns.
     OwningRewritePatternList patterns;
-    patterns.insert<TestTypeConversionProducer>(converter, &getContext());
+    patterns.insert<TestTypeConsumerForward, TestTypeConversionProducer,
+                    TestSignatureConversionUndo>(converter, &getContext());
     mlir::populateFuncOpTypeConversionPattern(patterns, &getContext(),
                                               converter);