Refactor DialectConversion to support different conversion modes.
authorRiver Riddle <riverriddle@google.com>
Tue, 16 Jul 2019 18:57:45 +0000 (11:57 -0700)
committerMehdi Amini <joker.eph@gmail.com>
Tue, 16 Jul 2019 20:45:41 +0000 (13:45 -0700)
Users generally want several different modes of conversion. This cl refactors DialectConversion to provide two:
* Partial (applyPartialConversion)
  - This mode allows for illegal operations to exist in the IR, and does not fail if an operation fails to be legalized.

* Full (applyFullConversion)
  - This mode fails if any operation is not properly legalized to the conversion target. This allows for ensuring that the IR after a conversion only contains operations legal for the target.

PiperOrigin-RevId: 258412243

mlir/examples/Linalg/Linalg1/lib/ConvertToLLVMDialect.cpp
mlir/examples/Linalg/Linalg3/lib/ConvertToLLVMDialect.cpp
mlir/examples/toy/Ch5/mlir/EarlyLowering.cpp
mlir/examples/toy/Ch5/mlir/LateLowering.cpp
mlir/include/mlir/Transforms/DialectConversion.h
mlir/lib/Conversion/ControlFlowToCFG/ConvertControlFlowToCFG.cpp
mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp
mlir/lib/Linalg/Transforms/LowerToLLVMDialect.cpp
mlir/lib/Transforms/DialectConversion.cpp
mlir/lib/Transforms/LowerAffine.cpp
mlir/test/lib/TestDialect/TestPatterns.cpp

index 392d9be..989915e 100644 (file)
@@ -421,8 +421,7 @@ LogicalResult linalg::convertToLLVM(mlir::ModuleOp module) {
 
   ConversionTarget target(*module.getContext());
   target.addLegalDialect<LLVM::LLVMDialect>();
-  return applyConversionPatterns(module, target, converter,
-                                 std::move(patterns));
+  return applyFullConversion(module, target, converter, std::move(patterns));
 }
 
 namespace {
index 54d1f55..e4eaca8 100644 (file)
@@ -160,8 +160,8 @@ LogicalResult linalg::convertLinalg3ToLLVM(ModuleOp module) {
 
   ConversionTarget target(*module.getContext());
   target.addLegalDialect<LLVM::LLVMDialect>();
-  if (failed(applyConversionPatterns(module, target, converter,
-                                     std::move(patterns))))
+  if (failed(
+          applyFullConversion(module, target, converter, std::move(patterns))))
     return failure();
 
   return success();
index 5677f35..e4df917 100644 (file)
@@ -132,8 +132,8 @@ struct EarlyLoweringPass : public FunctionPass<EarlyLoweringPass> {
 
     OwningRewritePatternList patterns;
     RewriteListBuilder<MulOpConversion>::build(patterns, &getContext());
-    if (failed(applyConversionPatterns(getFunction(), target,
-                                       std::move(patterns)))) {
+    if (failed(applyPartialConversion(getFunction(), target,
+                                      std::move(patterns)))) {
       emitError(mlir::UnknownLoc::get(&getContext()), "Error lowering Toy\n");
       signalPassFailure();
     }
index ca8185c..ebc81ef 100644 (file)
@@ -356,8 +356,8 @@ struct LateLoweringPass : public ModulePass<LateLoweringPass> {
     target.addLegalDialect<AffineOpsDialect, linalg::LinalgDialect,
                            LLVM::LLVMDialect, StandardOpsDialect>();
     target.addLegalOp<toy::AllocOp, toy::TypeCastOp>();
-    if (failed(applyConversionPatterns(getModule(), target, typeConverter,
-                                       std::move(toyPatterns)))) {
+    if (failed(applyPartialConversion(getModule(), target, typeConverter,
+                                      std::move(toyPatterns)))) {
       emitError(UnknownLoc::get(getModule().getContext()),
                 "Error lowering Toy\n");
       signalPassFailure();
index 2e8ecfa..79f0d38 100644 (file)
@@ -349,31 +349,61 @@ private:
 };
 
 //===----------------------------------------------------------------------===//
-// Conversion Application
+// Op Conversion Entry Points
 //===----------------------------------------------------------------------===//
 
-/// Convert the given module with the provided conversion patterns and type
-/// conversion object. This function returns failure if a type conversion
-/// failed.
-LLVM_NODISCARD LogicalResult applyConversionPatterns(
+/// Apply a partial conversion on the given operations, and all nested
+/// operations. This method converts as many operations to the target as
+/// possible, ignoring operations that failed to legalize. This method only
+/// returns failure if there are unreachable blocks in any of the regions nested
+/// within 'ops'.
+LLVM_NODISCARD LogicalResult
+applyPartialConversion(ArrayRef<Operation *> ops, ConversionTarget &target,
+                       OwningRewritePatternList &&patterns);
+LLVM_NODISCARD LogicalResult
+applyPartialConversion(Operation *op, ConversionTarget &target,
+                       OwningRewritePatternList &&patterns);
+
+/// Apply a complete conversion on the given operations, and all nested
+/// operations. This method returns failure if the conversion of any operation
+/// fails, or if there are unreachable blocks in any of the regions nested
+/// within 'ops'.
+LLVM_NODISCARD LogicalResult
+applyFullConversion(ArrayRef<Operation *> ops, ConversionTarget &target,
+                    OwningRewritePatternList &&patterns);
+LLVM_NODISCARD LogicalResult
+applyFullConversion(Operation *op, ConversionTarget &target,
+                    OwningRewritePatternList &&patterns);
+
+//===----------------------------------------------------------------------===//
+// Op + Type Conversion Entry Points
+//===----------------------------------------------------------------------===//
+
+/// Apply a partial conversion on the function operations within the given
+/// module. This method returns failure if a type conversion was encountered.
+LLVM_NODISCARD LogicalResult applyPartialConversion(
     ModuleOp module, ConversionTarget &target, TypeConverter &converter,
     OwningRewritePatternList &&patterns);
 
-/// Convert the given functions with the provided conversion patterns. This
-/// function returns failure if a type conversion failed.
-LLVM_NODISCARD
-LogicalResult applyConversionPatterns(MutableArrayRef<FuncOp> fns,
-                                      ConversionTarget &target,
-                                      TypeConverter &converter,
-                                      OwningRewritePatternList &&patterns);
-
-/// Convert the given function with the provided conversion patterns. This will
-/// convert as many of the operations within 'fn' as possible given the set of
-/// patterns.
-LLVM_NODISCARD
-LogicalResult applyConversionPatterns(FuncOp fn, ConversionTarget &target,
-                                      OwningRewritePatternList &&patterns);
+/// Apply a partial conversion on the given function operations. This method
+/// returns failure if a type conversion was encountered.
+LLVM_NODISCARD LogicalResult applyPartialConversion(
+    MutableArrayRef<FuncOp> fns, ConversionTarget &target,
+    TypeConverter &converter, OwningRewritePatternList &&patterns);
+
+/// Apply a full conversion on the function operations within the given
+/// module. This method returns failure if a type conversion was encountered, or
+/// if the conversion of any operations failed.
+LLVM_NODISCARD LogicalResult applyFullConversion(
+    ModuleOp module, ConversionTarget &target, TypeConverter &converter,
+    OwningRewritePatternList &&patterns);
 
+/// Apply a partial conversion on the given function operations. This method
+/// returns failure if a type conversion was encountered, or if the conversion
+/// of any operation failed.
+LLVM_NODISCARD LogicalResult applyFullConversion(
+    MutableArrayRef<FuncOp> fns, ConversionTarget &target,
+    TypeConverter &converter, OwningRewritePatternList &&patterns);
 } // end namespace mlir
 
 #endif // MLIR_TRANSFORMS_DIALECTCONVERSION_H_
index 5064bba..9c2053d 100644 (file)
@@ -282,7 +282,7 @@ void ControlFlowToCFGPass::runOnFunction() {
   ConversionTarget target(getContext());
   target.addLegalDialect<StandardOpsDialect>();
   if (failed(
-          applyConversionPatterns(getFunction(), target, std::move(patterns))))
+          applyPartialConversion(getFunction(), target, std::move(patterns))))
     signalPassFailure();
 }
 
index 2e8e313..bc5cbfa 100644 (file)
@@ -1064,13 +1064,13 @@ struct LLVMLoweringPass : public ModulePass<LLVMLoweringPass> {
 
     ConversionTarget target(getContext());
     target.addLegalDialect<LLVM::LLVMDialect>();
-    if (failed(applyConversionPatterns(m, target, *typeConverter,
-                                       std::move(patterns))))
+    if (failed(applyPartialConversion(m, target, *typeConverter,
+                                      std::move(patterns))))
       signalPassFailure();
   }
 
   // Callback for creating a list of patterns.  It is called every time in
-  // runOnModule since applyConversionPatterns consumes the list.
+  // runOnModule since applyPartialConversion consumes the list.
   LLVMPatternListFiller patternListFiller;
 
   // Callback for creating an instance of type converter.  The converter
index 298f978..c7ea50f 100644 (file)
@@ -763,8 +763,8 @@ void LowerLinalgToLLVMPass::runOnModule() {
 
   ConversionTarget target(getContext());
   target.addLegalDialect<LLVM::LLVMDialect>();
-  if (failed(applyConversionPatterns(module, target, converter,
-                                     std::move(patterns)))) {
+  if (failed(applyPartialConversion(module, target, converter,
+                                    std::move(patterns)))) {
     signalPassFailure();
   }
 
index a0fba2c..fd01ad9 100644 (file)
@@ -841,30 +841,48 @@ void OperationLegalizer::computeLegalizationGraphBenefit() {
 }
 
 //===----------------------------------------------------------------------===//
-// FunctionConverter
+// OperationConverter
 //===----------------------------------------------------------------------===//
 namespace {
-// This class converts a single function using the given pattern matcher. If a
+enum OpConversionMode {
+  // In this mode, the conversion will ignore failed conversions to allow
+  // illegal operations to co-exist in the IR.
+  Partial,
+
+  // In this mode, all operations must be legal for the given target for the
+  // conversion to succeeed.
+  Full,
+};
+
+// This class converts operations using the given pattern matcher. If a
 // TypeConverter object is provided, then the types of block arguments will be
 // converted using the appropriate 'convertType' calls.
-struct FunctionConverter {
-  explicit FunctionConverter(ConversionTarget &target,
-                             OwningRewritePatternList &patterns,
-                             TypeConverter *conversion = nullptr)
-      : typeConverter(conversion), opLegalizer(target, patterns) {}
+struct OperationConverter {
+  explicit OperationConverter(ConversionTarget &target,
+                              OwningRewritePatternList &patterns,
+                              OpConversionMode mode,
+                              TypeConverter *conversion = nullptr)
+      : typeConverter(conversion), opLegalizer(target, patterns), mode(mode) {}
 
   /// Converts the given function to the conversion target. Returns failure on
-  /// error, success otherwise. If 'signatureConversion' is provided, the
-  /// arguments of the entry block are updated accordingly.
+  /// error, success otherwise.
   LogicalResult
   convertFunction(FuncOp f,
-                  TypeConverter::SignatureConversion *signatureConversion);
+                  TypeConverter::SignatureConversion &signatureConversion);
+
+  /// Converts the given operations to the conversion target.
+  LogicalResult convertOperations(ArrayRef<Operation *> ops);
 
 private:
   /// Converts a block or operation with the given rewriter.
   LogicalResult convert(DialectConversionRewriter &rewriter,
                         llvm::PointerUnion<Operation *, Block *> &ptr);
 
+  /// Converts a set of blocks/operations with the given rewriter.
+  LogicalResult
+  convert(DialectConversionRewriter &rewriter,
+          std::vector<llvm::PointerUnion<Operation *, Block *>> &toConvert);
+
   /// Recursively collect all of the blocks, and operations, to convert from
   /// within 'region'.
   LogicalResult computeConversionSet(
@@ -876,11 +894,14 @@ private:
 
   /// The legalizer to use when converting operations.
   OperationLegalizer opLegalizer;
+
+  /// The conversion mode to use when legalizing operations.
+  OpConversionMode mode;
 };
 } // end anonymous namespace
 
 /// Recursively collect all of the blocks to convert from within 'region'.
-LogicalResult FunctionConverter::computeConversionSet(
+LogicalResult OperationConverter::computeConversionSet(
     Region &region,
     std::vector<llvm::PointerUnion<Operation *, Block *>> &toConvert) {
   if (region.empty())
@@ -919,49 +940,78 @@ LogicalResult FunctionConverter::computeConversionSet(
 
 /// Converts a block or operation with the given rewriter.
 LogicalResult
-FunctionConverter::convert(DialectConversionRewriter &rewriter,
-                           llvm::PointerUnion<Operation *, Block *> &ptr) {
+OperationConverter::convert(DialectConversionRewriter &rewriter,
+                            llvm::PointerUnion<Operation *, Block *> &ptr) {
   // If this is a block, then convert the types of each of the arguments.
   if (auto *block = ptr.dyn_cast<Block *>()) {
     assert(typeConverter && "expected valid type converter");
     return rewriter.argConverter.convertArguments(block, rewriter.mapping);
   }
 
-  // Otherwise, this is an operation to legalize.
-  (void)opLegalizer.legalize(ptr.get<Operation *>(), rewriter);
+  // Otherwise, legalize the given operation.
+  auto *op = ptr.get<Operation *>();
+  auto result = opLegalizer.legalize(op, rewriter);
+
+  // Failed conversions are only important if this is a full conversion.
+  if (mode == OpConversionMode::Full && failed(result))
+    return op->emitError() << "failed to legalize operation '" << op->getName()
+                           << "'";
+
+  // In any other case, illegal operations are allowed to remain in the IR.
   return success();
 }
 
-LogicalResult FunctionConverter::convertFunction(
-    FuncOp f, TypeConverter::SignatureConversion *signatureConversion) {
+LogicalResult OperationConverter::convert(
+    DialectConversionRewriter &rewriter,
+    std::vector<llvm::PointerUnion<Operation *, Block *>> &toConvert) {
+  // Convert each operation/block and discard rewrites on failure.
+  for (auto &it : toConvert) {
+    if (failed(convert(rewriter, it))) {
+      rewriter.discardRewrites();
+      return failure();
+    }
+  }
+
+  // Otherwise the body conversion succeeded, so apply all rewrites.
+  rewriter.applyRewrites();
+  return success();
+}
+
+LogicalResult OperationConverter::convertFunction(
+    FuncOp f, TypeConverter::SignatureConversion &signatureConversion) {
   // If this is an external function, there is nothing else to do.
   if (f.isExternal())
     return success();
 
-  DialectConversionRewriter rewriter(f.getContext(), typeConverter);
-
   // Update the signature of the entry block.
-  if (signatureConversion) {
-    rewriter.argConverter.convertSignature(
-        &f.getBody().front(), *signatureConversion, rewriter.mapping);
-  }
+  DialectConversionRewriter rewriter(f.getContext(), typeConverter);
+  rewriter.argConverter.convertSignature(&f.getBody().front(),
+                                         signatureConversion, rewriter.mapping);
 
-  /// Compute the set of operations and blocks to convert.
+  // Compute the set of operations and blocks to convert.
   std::vector<llvm::PointerUnion<Operation *, Block *>> toConvert;
   if (failed(computeConversionSet(f.getBody(), toConvert)))
     return failure();
+  return convert(rewriter, toConvert);
+}
 
-  // Convert each operation/block and discard rewrites on failure.
-  for (auto &it : toConvert) {
-    if (failed(convert(rewriter, it))) {
-      rewriter.discardRewrites();
-      return failure();
-    }
+/// Converts the given top-level operation to the conversion target.
+LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
+  if (ops.empty())
+    return success();
+
+  /// Compute the set of operations and blocks to convert.
+  std::vector<llvm::PointerUnion<Operation *, Block *>> toConvert;
+  for (auto *op : ops) {
+    toConvert.emplace_back(op);
+    for (auto &region : op->getRegions())
+      if (failed(computeConversionSet(region, toConvert)))
+        return failure();
   }
 
-  // Otherwise the body conversion succeeded, so apply all rewrites.
-  rewriter.applyRewrites();
-  return success();
+  // Rewrite the blocks and operations.
+  DialectConversionRewriter rewriter(ops.front()->getContext(), typeConverter);
+  return convert(rewriter, toConvert);
 }
 
 //===----------------------------------------------------------------------===//
@@ -1102,34 +1152,59 @@ auto ConversionTarget::getOpAction(OperationName op) const
 }
 
 //===----------------------------------------------------------------------===//
-// Conversion Application
+// Op Conversion Entry Points
 //===----------------------------------------------------------------------===//
 
-/// Convert the given module with the provided conversion patterns and type
-/// conversion object. If conversion fails for specific functions, those
-/// functions remains unmodified.
+/// Apply a partial conversion on the given operations, and all nested
+/// operations. This method converts as many operations to the target as
+/// possible, ignoring operations that failed to legalize.
 LogicalResult
-mlir::applyConversionPatterns(ModuleOp module, ConversionTarget &target,
-                              TypeConverter &converter,
-                              OwningRewritePatternList &&patterns) {
-  SmallVector<FuncOp, 32> allFunctions(module.getOps<FuncOp>());
-  return applyConversionPatterns(allFunctions, target, converter,
-                                 std::move(patterns));
+mlir::applyPartialConversion(ArrayRef<Operation *> ops,
+                             ConversionTarget &target,
+                             OwningRewritePatternList &&patterns) {
+  OperationConverter converter(target, patterns, OpConversionMode::Partial);
+  return converter.convertOperations(ops);
+}
+LogicalResult
+mlir::applyPartialConversion(Operation *op, ConversionTarget &target,
+                             OwningRewritePatternList &&patterns) {
+  return applyPartialConversion(llvm::makeArrayRef(op), target,
+                                std::move(patterns));
+}
+
+/// Apply a complete conversion on the given operations, and all nested
+/// operations. This method will return failure if the conversion of any
+/// operation fails.
+LogicalResult mlir::applyFullConversion(ArrayRef<Operation *> ops,
+                                        ConversionTarget &target,
+                                        OwningRewritePatternList &&patterns) {
+  OperationConverter converter(target, patterns, OpConversionMode::Full);
+  return converter.convertOperations(ops);
 }
+LogicalResult mlir::applyFullConversion(Operation *op, ConversionTarget &target,
+                                        OwningRewritePatternList &&patterns) {
+  return applyFullConversion(llvm::makeArrayRef(op), target,
+                             std::move(patterns));
+}
+
+//===----------------------------------------------------------------------===//
+// Op + Type Conversion Entry Points
+//===----------------------------------------------------------------------===//
 
-/// Convert the given functions with the provided conversion patterns.
-LogicalResult mlir::applyConversionPatterns(
-    MutableArrayRef<FuncOp> fns, ConversionTarget &target,
-    TypeConverter &converter, OwningRewritePatternList &&patterns) {
+static LogicalResult applyConversion(MutableArrayRef<FuncOp> fns,
+                                     ConversionTarget &target,
+                                     TypeConverter &converter,
+                                     OwningRewritePatternList &&patterns,
+                                     OpConversionMode mode) {
   if (fns.empty())
     return success();
 
   // Build the function converter.
-  auto *ctx = fns.front().getContext();
-  FunctionConverter funcConverter(target, patterns, &converter);
+  OperationConverter funcConverter(target, patterns, mode, &converter);
 
   // Try to convert each of the functions within the module.
   SmallVector<NamedAttributeList, 4> argAttrs;
+  auto *ctx = fns.front().getContext();
   for (auto func : fns) {
     argAttrs.clear();
     func.getAllArgAttrs(argAttrs);
@@ -1144,20 +1219,53 @@ LogicalResult mlir::applyConversionPatterns(
     func.setAllArgAttrs(conversion->getConvertedArgAttrs());
 
     // Convert the body of this function.
-    if (failed(funcConverter.convertFunction(func, &*conversion)))
+    if (failed(funcConverter.convertFunction(func, *conversion)))
       return failure();
   }
 
   return success();
 }
 
-/// Convert the given function with the provided conversion patterns. This will
-/// convert as many of the operations within 'fn' as possible given the set of
-/// patterns.
+/// Apply a partial conversion on the function operations within the given
+/// module. This method returns failure if a type conversion was encountered.
 LogicalResult
-mlir::applyConversionPatterns(FuncOp fn, ConversionTarget &target,
-                              OwningRewritePatternList &&patterns) {
-  // Convert the body of this function.
-  FunctionConverter converter(target, patterns);
-  return converter.convertFunction(fn, /*signatureConversion=*/nullptr);
+mlir::applyPartialConversion(ModuleOp module, ConversionTarget &target,
+                             TypeConverter &converter,
+                             OwningRewritePatternList &&patterns) {
+  SmallVector<FuncOp, 32> allFunctions(module.getOps<FuncOp>());
+  return applyPartialConversion(allFunctions, target, converter,
+                                std::move(patterns));
+}
+
+/// Apply a partial conversion on the given function operations. This method
+/// returns failure if a type conversion was encountered.
+LogicalResult
+mlir::applyPartialConversion(MutableArrayRef<FuncOp> fns,
+                             ConversionTarget &target, TypeConverter &converter,
+                             OwningRewritePatternList &&patterns) {
+  return applyConversion(fns, target, converter, std::move(patterns),
+                         OpConversionMode::Partial);
+}
+
+/// Apply a full conversion on the function operations within the given module.
+/// This method returns failure if a type conversion was encountered, or if the
+/// conversion of any operations failed.
+LogicalResult mlir::applyFullConversion(ModuleOp module,
+                                        ConversionTarget &target,
+                                        TypeConverter &converter,
+                                        OwningRewritePatternList &&patterns) {
+  SmallVector<FuncOp, 32> allFunctions(module.getOps<FuncOp>());
+  return applyFullConversion(allFunctions, target, converter,
+                             std::move(patterns));
+}
+
+/// Apply a full conversion on the given function operations. This method
+/// returns failure if a type conversion was encountered, or if the conversion
+/// of any operation failed.
+LogicalResult mlir::applyFullConversion(MutableArrayRef<FuncOp> fns,
+                                        ConversionTarget &target,
+                                        TypeConverter &converter,
+                                        OwningRewritePatternList &&patterns) {
+  return applyConversion(fns, target, converter, std::move(patterns),
+                         OpConversionMode::Full);
 }
index 82b7074..20a9134 100644 (file)
@@ -521,8 +521,8 @@ class LowerAffinePass : public FunctionPass<LowerAffinePass> {
     populateAffineToStdConversionPatterns(patterns, &getContext());
     ConversionTarget target(getContext());
     target.addLegalDialect<loop::LoopOpsDialect, StandardOpsDialect>();
-    if (failed(applyConversionPatterns(getFunction(), target,
-                                       std::move(patterns))))
+    if (failed(
+            applyPartialConversion(getFunction(), target, std::move(patterns))))
       signalPassFailure();
   }
 };
index d4aef38..ed5aea8 100644 (file)
@@ -177,8 +177,8 @@ struct TestLegalizePatternDriver
 
     TestTypeConverter converter;
     TestConversionTarget target(getContext());
-    if (failed(applyConversionPatterns(getModule(), target, converter,
-                                       std::move(patterns))))
+    if (failed(applyPartialConversion(getModule(), target, converter,
+                                      std::move(patterns))))
       signalPassFailure();
   }
 };