Refactor DialectConversion to use 'materializeConversion' when a type conversion...
authorRiver Riddle <riverriddle@google.com>
Fri, 28 Jun 2019 18:28:30 +0000 (11:28 -0700)
committerA. Unique TensorFlower <gardener@tensorflow.org>
Fri, 28 Jun 2019 18:29:04 +0000 (11:29 -0700)
During conversion, if a type conversion has dangling uses a type conversion must persist after conversion has finished to maintain valid IR. In these cases, we now query the TypeConverter to materialize a conversion for us. This allows for the default case of a full conversion to continue working as expected, but also handle the degenerate cases more robustly.

PiperOrigin-RevId: 255637171

mlir/examples/toy/Ch5/mlir/LateLowering.cpp
mlir/include/mlir/Transforms/DialectConversion.h
mlir/lib/Transforms/DialectConversion.cpp
mlir/test/Transforms/test-legalizer.mlir
mlir/test/lib/TestDialect/TestOps.td
mlir/test/lib/TestDialect/TestPatterns.cpp

index 34a2dc3..8b2a392 100644 (file)
@@ -331,6 +331,14 @@ protected:
       return array.toMemref();
     return t;
   }
+
+  /// Materialize a conversion to allow for partial lowering of types.
+  Operation *materializeConversion(PatternRewriter &rewriter, Type resultType,
+                                   ArrayRef<Value *> inputs,
+                                   Location loc) override {
+    assert(inputs.size() == 1 && "expected only one input value");
+    return rewriter.create<toy::TypeCastOp>(loc, inputs[0], resultType);
+  }
 };
 
 /// This is lowering to Linalg the parts that can be (matmul and add on arrays)
index 599c7f1..00da0d5 100644 (file)
@@ -221,13 +221,13 @@ public:
   /// one result type by generating a cast operation of some kind. The generated
   /// operation should produce one result, of 'resultType', with the provided
   /// 'inputs' as operands. This hook must be overridden when a type conversion
-  /// results in more than one type.
+  /// results in more than one type, or if a type conversion may persist after
+  /// the conversion has finished.
   virtual Operation *materializeConversion(PatternRewriter &rewriter,
                                            Type resultType,
                                            ArrayRef<Value *> inputs,
                                            Location loc) {
-    llvm_unreachable("expected 'materializeConversion' to be overridden when "
-                     "generating 1->N type conversions");
+    llvm_unreachable("expected 'materializeConversion' to be overridden");
   }
 };
 
@@ -337,14 +337,13 @@ private:
 
 /// Convert the given module with the provided conversion patterns and type
 /// conversion object. This function returns failure if a type conversion
-/// failed, potentially leaving the IR in an invalid state.
+/// failed.
 LLVM_NODISCARD LogicalResult applyConversionPatterns(
     Module &module, ConversionTarget &target, TypeConverter &converter,
     OwningRewritePatternList &&patterns);
 
 /// Convert the given functions with the provided conversion patterns. This
-/// function returns failure if a type conversion failed, potentially leaving
-/// the IR in an invalid state.
+/// function returns failure if a type conversion failed.
 LLVM_NODISCARD
 LogicalResult applyConversionPatterns(ArrayRef<Function *> fns,
                                       ConversionTarget &target,
index 88db7b6..f6d6329 100644 (file)
@@ -40,8 +40,10 @@ namespace {
 /// illegal type to the original type to allow for undoing pending rewrites in
 /// the case of failure.
 struct ArgConverter {
-  ArgConverter(MLIRContext *ctx)
-      : castOpName(kCastName, ctx), loc(UnknownLoc::get(ctx)) {}
+  ArgConverter(TypeConverter *typeConverter, PatternRewriter &rewriter)
+      : castOpName(kCastName, rewriter.getContext()),
+        loc(rewriter.getUnknownLoc()), typeConverter(typeConverter),
+        rewriter(rewriter) {}
 
   /// Erase any rewrites registered for arguments to blocks within the given
   /// region. This function is called when the given region is to be destroyed.
@@ -51,26 +53,21 @@ struct ArgConverter {
   void discardRewrites();
 
   /// Replace usages of the cast operations with the argument directly.
-  LogicalResult applyRewrites();
+  void applyRewrites();
 
   /// Converts the signature of the given entry block.
-  void convertSignature(Block *block, PatternRewriter &rewriter,
-                        TypeConverter &converter,
+  void convertSignature(Block *block,
                         TypeConverter::SignatureConversion &signatureConversion,
                         BlockAndValueMapping &mapping);
 
   /// Converts the arguments of the given block.
-  LogicalResult convertArguments(Block *block, PatternRewriter &rewriter,
-                                 TypeConverter &converter,
-                                 BlockAndValueMapping &mapping);
+  LogicalResult convertArguments(Block *block, BlockAndValueMapping &mapping);
 
   /// Convert the given block argument given the provided set of new argument
   /// values that are to replace it. This function returns the operation used
   /// to perform the conversion.
   Operation *convertArgument(BlockArgument *origArg,
                              ArrayRef<Value *> newValues,
-                             PatternRewriter &rewriter,
-                             TypeConverter &converter,
                              BlockAndValueMapping &mapping);
 
   /// A utility function used to create a conversion cast operation with the
@@ -90,6 +87,12 @@ struct ArgConverter {
   /// An instance of the unknown location that is used when generating
   /// producers.
   Location loc;
+
+  /// The type converter to use when changing types.
+  TypeConverter *typeConverter;
+
+  /// The pattern rewriter to use when materializing conversions.
+  PatternRewriter &rewriter;
 };
 
 constexpr StringLiteral ArgConverter::kCastName;
@@ -147,11 +150,9 @@ void ArgConverter::discardRewrites() {
 }
 
 /// Replace usages of the cast operations with the argument directly.
-LogicalResult ArgConverter::applyRewrites() {
+void ArgConverter::applyRewrites() {
   Block *block;
   ArrayRef<Operation *> argOps;
-
-  LogicalResult result = success();
   for (auto &mapping : argMapping) {
     std::tie(block, argOps) = mapping;
 
@@ -169,41 +170,25 @@ LogicalResult ArgConverter::applyRewrites() {
         continue;
       }
 
-      // Handle the case where this argument had a direct mapping.
-      if (op->getNumOperands() == 1) {
-        op->getResult(0)->replaceAllUsesWith(op->getOperand(0));
-        // Otherwise, this argument was expected to be dropped.
-      } else if (!op->getResult(0)->use_empty()) {
-        // Don't emit another error if we already have one.
-        if (!failed(result)) {
-          auto *parent = block->getParent();
-          auto diag = emitError(parent->getLoc())
-                      << "block argument #" << i << " with type "
-                      << op->getResult(0)->getType()
-                      << " has unexpected remaining uses";
-          auto *user = *op->getResult(0)->user_begin();
-          diag.attachNote(user->getLoc())
-              << "unexpected user defined here : " << *user;
-          result = failure();
-        }
-        // Move this fake producer to the beginning of the parent block, we
-        // can't recover from this failure and we want to make sure the
-        // operations get cleaned up. Recovering from this would require
-        // detecting that an argument would be unused before applying all of
-        // the operation rewrites, which can get quite expensive.
-        block->push_front(op);
-        continue;
+      // Otherwise, if there are any dangling uses then replace the fake
+      // conversion operation with one generated by the type converter. This
+      // is necessary as the cast must persist in the IR after conversion.
+      auto *opResult = op->getResult(0);
+      if (!opResult->use_empty()) {
+        rewriter.setInsertionPointToStart(block);
+        SmallVector<Value *, 1> operands(op->getOperands());
+        auto *newOp = typeConverter->materializeConversion(
+            rewriter, opResult->getType(), operands, op->getLoc());
+        opResult->replaceAllUsesWith(newOp->getResult(0));
       }
       op->destroy();
     }
   }
-  return result;
 }
 
 /// Converts the signature of the given entry block.
 void ArgConverter::convertSignature(
-    Block *block, PatternRewriter &rewriter, TypeConverter &converter,
-    TypeConverter::SignatureConversion &signatureConversion,
+    Block *block, TypeConverter::SignatureConversion &signatureConversion,
     BlockAndValueMapping &mapping) {
   unsigned origArgCount = block->getNumArguments();
   auto convertedTypes = signatureConversion.getConvertedArgTypes();
@@ -223,8 +208,7 @@ void ArgConverter::convertSignature(
       remappedValues = newArgRef.slice(inputMap->inputNo, inputMap->size);
 
     BlockArgument *arg = block->getArgument(i);
-    newArgMapping.push_back(
-        convertArgument(arg, remappedValues, rewriter, converter, mapping));
+    newArgMapping.push_back(convertArgument(arg, remappedValues, mapping));
   }
 
   // Erase all of the original arguments.
@@ -234,8 +218,6 @@ void ArgConverter::convertSignature(
 
 /// Converts the arguments of the given block.
 LogicalResult ArgConverter::convertArguments(Block *block,
-                                             PatternRewriter &rewriter,
-                                             TypeConverter &converter,
                                              BlockAndValueMapping &mapping) {
   unsigned origArgCount = block->getNumArguments();
   if (origArgCount == 0)
@@ -245,7 +227,7 @@ LogicalResult ArgConverter::convertArguments(Block *block,
   SmallVector<SmallVector<Type, 1>, 4> newArgTypes(origArgCount);
   for (unsigned i = 0; i != origArgCount; ++i) {
     auto *arg = block->getArgument(i);
-    if (failed(converter.convertType(arg->getType(), newArgTypes[i])))
+    if (failed(typeConverter->convertType(arg->getType(), newArgTypes[i])))
       return emitError(block->getParent()->getLoc())
              << "could not convert block argument of type " << arg->getType();
   }
@@ -255,8 +237,8 @@ LogicalResult ArgConverter::convertArguments(Block *block,
   rewriter.setInsertionPointToStart(block);
   for (unsigned i = 0; i != origArgCount; ++i) {
     SmallVector<Value *, 1> newArgs(block->addArguments(newArgTypes[i]));
-    newArgMapping.push_back(convertArgument(block->getArgument(i), newArgs,
-                                            rewriter, converter, mapping));
+    newArgMapping.push_back(
+        convertArgument(block->getArgument(i), newArgs, mapping));
   }
 
   // Erase all of the original arguments.
@@ -270,8 +252,6 @@ LogicalResult ArgConverter::convertArguments(Block *block,
 /// to perform the conversion.
 Operation *ArgConverter::convertArgument(BlockArgument *origArg,
                                          ArrayRef<Value *> newValues,
-                                         PatternRewriter &rewriter,
-                                         TypeConverter &converter,
                                          BlockAndValueMapping &mapping) {
   // Handle the cases of 1->0 or 1->1 mappings.
   if (newValues.size() < 2) {
@@ -289,8 +269,8 @@ Operation *ArgConverter::convertArgument(BlockArgument *origArg,
 
   // Otherwise, this is a 1->N mapping. Call into the provided type converter
   // to pack the new values.
-  auto *cast = converter.materializeConversion(rewriter, origArg->getType(),
-                                               newValues, loc);
+  auto *cast = typeConverter->materializeConversion(
+      rewriter, origArg->getType(), newValues, loc);
   assert(cast->getNumResults() == 1 &&
          cast->getNumOperands() == newValues.size());
   origArg->replaceAllUsesWith(cast->getResult(0));
@@ -370,8 +350,8 @@ struct DialectConversionRewriter final : public PatternRewriter {
     BlockActionKind kind;
   };
 
-  DialectConversionRewriter(Region &region)
-      : PatternRewriter(region), argConverter(region.getContext()) {}
+  DialectConversionRewriter(Region &region, TypeConverter *converter)
+      : PatternRewriter(region), argConverter(converter, *this) {}
   ~DialectConversionRewriter() = default;
 
   /// Return the current state of the rewriter.
@@ -438,7 +418,7 @@ struct DialectConversionRewriter final : public PatternRewriter {
 
   /// Apply all requested operation rewrites. This method is invoked when the
   /// conversion process succeeds.
-  LogicalResult applyRewrites() {
+  void applyRewrites() {
     // Apply all of the rewrites replacements requested during conversion.
     for (auto &repl : replacements) {
       for (unsigned i = 0, e = repl.newValues.size(); i != e; ++i)
@@ -454,7 +434,7 @@ struct DialectConversionRewriter final : public PatternRewriter {
       repl.op->erase();
     }
 
-    return argConverter.applyRewrites();
+    argConverter.applyRewrites();
   }
 
   /// PatternRewriter hook for replacing the results of an operation.
@@ -943,8 +923,8 @@ FunctionConverter::convertRegion(DialectConversionRewriter &rewriter,
   if (typeConverter) {
     for (Block &block :
          llvm::drop_begin(region.getBlocks(), convertEntryTypes ? 0 : 1)) {
-      if (failed(rewriter.argConverter.convertArguments(
-              &block, rewriter, *typeConverter, rewriter.mapping)))
+      if (failed(
+              rewriter.argConverter.convertArguments(&block, rewriter.mapping)))
         return failure();
     }
   }
@@ -973,13 +953,12 @@ LogicalResult FunctionConverter::convertFunction(
   if (f->isExternal())
     return success();
 
-  DialectConversionRewriter rewriter(f->getBody());
+  DialectConversionRewriter rewriter(f->getBody(), typeConverter);
 
   // Update the signature of the entry block.
   if (signatureConversion) {
-    rewriter.argConverter.convertSignature(&f->getBody().front(), rewriter,
-                                           *typeConverter, *signatureConversion,
-                                           rewriter.mapping);
+    rewriter.argConverter.convertSignature(
+        &f->getBody().front(), *signatureConversion, rewriter.mapping);
   }
 
   // Rewrite the function body.
@@ -990,8 +969,9 @@ LogicalResult FunctionConverter::convertFunction(
     return failure();
   }
 
-  // Otherwise the body conversion succeeded, so try to apply all rewrites.
-  return rewriter.applyRewrites();
+  // Otherwise the body conversion succeeded, so apply all rewrites.
+  rewriter.applyRewrites();
+  return success();
 }
 
 //===----------------------------------------------------------------------===//
index 7dc076b..66777d7 100644 (file)
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -test-legalize-patterns -split-input-file -verify-diagnostics %s | FileCheck %s --dump-input=fail
+// RUN: mlir-opt -test-legalize-patterns %s
 
 // CHECK-LABEL: verifyDirectPattern
 func @verifyDirectPattern() -> i32 {
@@ -17,16 +17,16 @@ func @verifyLargerBenefit() -> i32 {
 // CHECK-LABEL: func @remap_input_1_to_0()
 func @remap_input_1_to_0(i16)
 
-// CHECK-LABEL: func @remap_input_1_to_1(%arg0: f64) -> f64
-func @remap_input_1_to_1(%arg0: i64) -> i64 {
- // CHECK-NEXT: return %arg0 : f64
- return %arg0 : i64
+// CHECK-LABEL: func @remap_input_1_to_1(%arg0: f64)
+func @remap_input_1_to_1(%arg0: i64) {
+  // CHECK-NEXT: "test.valid"{{.*}} : (f64)
+  "test.invalid"(%arg0) : (i64) -> ()
 }
 
-// CHECK-LABEL: func @remap_input_1_to_N(%arg0: f16, %arg1: f16) -> (f16, f16)
+// CHECK-LABEL: func @remap_input_1_to_N(%arg0: f16, %arg1: f16)
 func @remap_input_1_to_N(%arg0: f32) -> f32 {
- // CHECK-NEXT: "test.return"(%arg0, %arg1) : (f16, f16) -> ()
- "test.return"(%arg0) : (f32) -> ()
+ // CHECK-NEXT: "test.valid"(%arg0, %arg1) : (f16, f16) -> ()
+ "test.invalid"(%arg0) : (f32) -> ()
 }
 
 // CHECK-LABEL: func @remap_input_1_to_N_remaining_use(%arg0: f16, %arg1: f16)
@@ -38,8 +38,8 @@ func @remap_input_1_to_N_remaining_use(%arg0: f32) {
 
 // CHECK-LABEL: func @remap_multi(%arg0: f64, %arg1: f64) -> (f64, f64)
 func @remap_multi(%arg0: i64, %unused: i16, %arg1: i64) -> (i64, i64) {
- // CHECK-NEXT: return %arg0, %arg1 : f64, f64
- return %arg0, %arg1 : i64, i64
+ // CHECK-NEXT: "test.valid"{{.*}} : (f64, f64)
+ "test.invalid"(%arg0, %arg1) : (i64, i64) -> ()
 }
 
 // CHECK-LABEL: func @remap_nested
@@ -48,8 +48,8 @@ func @remap_nested() {
   "foo.region"() ({
     // CHECK-NEXT: ^bb1(%i0: f64, %i1: f64):
     ^bb1(%i0: i64, %unused: i16, %i1: i64):
-      // CHECK-NEXT: "work"{{.*}} : (f64, f64)
-      "work"(%i0, %i1) : (i64, i64) -> ()
+      // CHECK-NEXT: "test.valid"{{.*}} : (f64, f64)
+      "test.invalid"(%i0, %i1) : (i64, i64) -> ()
   }) : () -> ()
   return
 }
@@ -59,10 +59,10 @@ func @remap_moved_region_args() {
   // CHECK-NEXT: return
   // CHECK-NEXT: ^bb1(%{{.*}}: f64, %{{.*}}: f64, %{{.*}}: f16, %{{.*}}: f16):
   // CHECK-NEXT: "test.cast"{{.*}} : (f16, f16) -> f32
-  // CHECK-NEXT: "work"{{.*}} : (f64, f64, f32)
+  // CHECK-NEXT: "test.valid"{{.*}} : (f64, f64, f32)
   "test.region"() ({
     ^bb1(%i0: i64, %unused: i16, %i1: i64, %2: f32):
-      "work"(%i0, %i1, %2) : (i64, i64, f32) -> ()
+      "test.invalid"(%i0, %i1, %2) : (i64, i64, f32) -> ()
   }) : () -> ()
   return
 }
@@ -73,15 +73,14 @@ func @remap_drop_region() {
   // CHECK-NEXT: }
   "test.drop_op"() ({
     ^bb1(%i0: i64, %unused: i16, %i1: i64, %2: f32):
-      "work"(%i0, %i1, %2) : (i64, i64, f32) -> ()
+      "test.invalid"(%i0, %i1, %2) : (i64, i64, f32) -> ()
   }) : () -> ()
   return
 }
 
-// -----
-
+// CHECK-LABEL: func @dropped_input_in_use
 func @dropped_input_in_use(%arg: i16, %arg2: i64) {
-  // expected-error@-1 {{block argument #0 with type 'i16' has unexpected remaining uses}}
-  // expected-note@+1 {{unexpected user defined here}}
+  // CHECK-NEXT: "test.cast"{{.*}} : () -> i16
+  // CHECK-NEXT: "work"{{.*}} : (i16)
   "work"(%arg) : (i16) -> ()
 }
index e12aaac..ba5362b 100644 (file)
@@ -245,5 +245,9 @@ def TestReturnOp : TEST_Op<"return", [Terminator]>,
   Arguments<(ins Variadic<AnyType>:$inputs)>;
 def TestCastOp : TEST_Op<"cast">,
   Arguments<(ins Variadic<AnyType>:$inputs)>, Results<(outs AnyType:$res)>;
+def TestInvalidOp : TEST_Op<"invalid", [Terminator]>,
+  Arguments<(ins Variadic<AnyType>:$inputs)>;
+def TestValidOp : TEST_Op<"valid", [Terminator]>,
+  Arguments<(ins Variadic<AnyType>:$inputs)>;
 
 #endif // TEST_OPS
index f323d7f..bde01f7 100644 (file)
@@ -78,6 +78,17 @@ struct TestDropOp : public ConversionPattern {
     return matchSuccess();
   }
 };
+/// This pattern simply updates the operands of the given operation.
+struct TestPassthroughInvalidOp : public ConversionPattern {
+  TestPassthroughInvalidOp(MLIRContext *ctx)
+      : ConversionPattern("test.invalid", 1, ctx) {}
+  PatternMatchResult matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+                                     PatternRewriter &rewriter) const final {
+    rewriter.replaceOpWithNewOp<TestValidOp>(op, llvm::None, operands,
+                                             llvm::None);
+    return matchSuccess();
+  }
+};
 /// This pattern handles the case of a split return value.
 struct TestSplitReturnType : public ConversionPattern {
   TestSplitReturnType(MLIRContext *ctx)
@@ -139,7 +150,7 @@ struct TestTypeConverter : public TypeConverter {
 
 struct TestConversionTarget : public ConversionTarget {
   TestConversionTarget(MLIRContext &ctx) : ConversionTarget(ctx) {
-    addLegalOp<LegalOpA>();
+    addLegalOp<LegalOpA, TestValidOp>();
     addDynamicallyLegalOp<TestReturnOp>();
   }
   bool isDynamicallyLegal(Operation *op) const final {
@@ -155,6 +166,7 @@ struct TestLegalizePatternDriver
     mlir::OwningRewritePatternList patterns;
     populateWithGenerated(&getContext(), &patterns);
     RewriteListBuilder<TestRegionRewriteBlockMovement, TestDropOp,
+                       TestPassthroughInvalidOp,
                        TestSplitReturnType>::build(patterns, &getContext());
 
     TestTypeConverter converter;