Decouple running a conversion from the DialectConversion class. The DialectConver...
authorRiver Riddle <riverriddle@google.com>
Thu, 23 May 2019 16:23:33 +0000 (09:23 -0700)
committerMehdi Amini <joker.eph@gmail.com>
Sun, 2 Jun 2019 02:58:04 +0000 (19:58 -0700)
--

PiperOrigin-RevId: 249657549

mlir/examples/Linalg/Linalg1/lib/ConvertToLLVMDialect.cpp
mlir/examples/Linalg/Linalg3/lib/ConvertToLLVMDialect.cpp
mlir/examples/toy/Ch5/mlir/EarlyLowering.cpp
mlir/examples/toy/Ch5/mlir/LateLowering.cpp
mlir/include/mlir/Transforms/DialectConversion.h
mlir/lib/LLVMIR/Transforms/ConvertToLLVMDialect.cpp
mlir/lib/Linalg/Transforms/LowerToLLVMDialect.cpp
mlir/lib/Transforms/DialectConversion.cpp

index 7a0e1a5..b81fc25 100644 (file)
@@ -436,7 +436,8 @@ void linalg::convertToLLVM(mlir::Module &module) {
 
   // Convert Linalg ops to the LLVM IR dialect using the converter defined
   // above.
-  auto r = Lowering(getDescriptorConverters).convert(&module);
+  Lowering lowering(getDescriptorConverters);
+  auto r = applyConverter(module, lowering);
   (void)r;
   assert(succeeded(r) && "conversion failed");
 }
index 9f2d46c..4a258f8 100644 (file)
@@ -143,7 +143,7 @@ void linalg::convertLinalg3ToLLVM(Module &module) {
   assert(succeeded(rr) && "affine loop lowering failed");
 
   auto lowering = makeLinalgToLLVMLowering(getConversions);
-  auto r = lowering->convert(&module);
+  auto r = applyConverter(module, *lowering);
   (void)r;
   assert(succeeded(r) && "conversion failed");
 }
index b6e0703..72ef800 100644 (file)
@@ -134,7 +134,8 @@ protected:
 /// dialect.
 struct EarlyLoweringPass : public ModulePass<EarlyLoweringPass> {
   void runOnModule() override {
-    if (failed(EarlyLowering().convert(&getModule()))) {
+    EarlyLowering lowering;
+    if (failed(applyConverter(getModule(), lowering))) {
       getModule().getContext()->emitError(
           mlir::UnknownLoc::get(getModule().getContext()),
           "Error lowering Toy\n");
index 5a0a901..2837807 100644 (file)
@@ -343,8 +343,9 @@ protected:
 /// and is targeting LLVM otherwise.
 struct LateLoweringPass : public ModulePass<LateLoweringPass> {
   void runOnModule() override {
-    // Perform Toy specific lowering
-    if (failed(LateLowering().convert(&getModule()))) {
+    // Perform Toy specific lowering.
+    LateLowering lowering;
+    if (failed(applyConverter(getModule(), lowering))) {
       getModule().getContext()->emitError(
           UnknownLoc::get(getModule().getContext()), "Error lowering Toy\n");
       signalPassFailure();
index 3facb7a..d1b3318 100644 (file)
@@ -39,8 +39,8 @@ class Value;
 /// Base class for the dialect conversion patterns that require type changes.
 /// Specific conversions must derive this class and implement least one
 /// `rewrite` method.
-/// NOTE: These conversion patterns can only be used with the DialectConversion
-/// class.
+/// NOTE: These conversion patterns can only be used with the 'apply*' methods
+/// below.
 class DialectConversionPattern : public RewritePattern {
 public:
   /// Construct an DialectConversionPattern.  `rootName` must correspond to the
@@ -112,22 +112,10 @@ private:
 //       match against the list of conversions.  On the first match, call
 //       `rewrite` for the operations, and advance to the next iteration.  If no
 //       match is found, replicate the operation as is.
-/// 3. Update all attributes of function type to point to the new functions.
-/// 4. Replace old functions with new functions in the module.
-/// If any error happened during the conversion, the pass fails as soon as
-/// possible.
-///
-/// If conversion fails for a specific function, that functions remains
-/// unmodified. Otherwise, successfully converted functions will remain
-/// converted.
 class DialectConversion {
 public:
   virtual ~DialectConversion() = default;
 
-  /// Run the converter on the provided module.
-  LLVM_NODISCARD
-  LogicalResult convert(Module *m);
-
   /// Derived classes must implement this hook to produce a set of conversion
   /// patterns to apply.  They may use `mlirContext` to obtain registered
   /// dialects or operations.  This will be called in the beginning of the
@@ -170,6 +158,19 @@ public:
       SmallVectorImpl<NamedAttributeList> &convertedArgAttrs);
 };
 
+/// Convert the given module with the provided dialect conversion object.
+/// If conversion fails for a specific function, those functions remains
+/// unmodified.
+LLVM_NODISCARD
+LogicalResult applyConverter(Module &module, DialectConversion &converter);
+
+/// Convert the given function with the provided conversion patterns. This will
+/// convert as many of the operations within 'fn' as possible given the set of
+/// patterns.
+LLVM_NODISCARD
+LogicalResult applyConversionPatterns(Function &fn,
+                                      OwningRewritePatternList &&patterns);
+
 } // end namespace mlir
 
 #endif // MLIR_TRANSFORMS_DIALECTCONVERSION_H_
index a2476dc..347280d 100644 (file)
@@ -1006,9 +1006,9 @@ class LLVMLoweringPass : public ModulePass<LLVMLoweringPass> {
 public:
   // Run the dialect converter on the module.
   void runOnModule() override {
-    Module *m = &getModule();
-    LLVM::ensureDistinctSuccessors(m);
-    if (failed(impl.convert(m)))
+    Module &m = getModule();
+    LLVM::ensureDistinctSuccessors(&m);
+    if (failed(applyConverter(m, impl)))
       signalPassFailure();
   }
 
index 8c2bdb7..cf8a2cc 100644 (file)
@@ -608,7 +608,8 @@ void LowerLinalgToLLVMPass::runOnModule() {
     signalPassFailure();
 
   // Convert to the LLVM IR dialect using the converter defined above.
-  if (failed(Lowering().convert(&module)))
+  Lowering lowering;
+  if (failed(applyConverter(module, lowering)))
     signalPassFailure();
 }
 
index de66d01..389b5ad 100644 (file)
@@ -226,13 +226,13 @@ void DialectConversionPattern::rewrite(Operation *op,
 // FunctionConverter
 //===----------------------------------------------------------------------===//
 namespace {
-// This class converts a single function using a given DialectConversion
-// structure.
+// This class converts a single function using the given pattern matcher. If a
+// DialectConversion object is also provided, then the types of block arguments
+// will be converted using the appropriate 'convertType' calls.
 class FunctionConverter {
 public:
-  // Constructs a FunctionConverter.
-  explicit FunctionConverter(MLIRContext *ctx, DialectConversion *conversion,
-                             RewritePatternMatcher &matcher)
+  explicit FunctionConverter(MLIRContext *ctx, RewritePatternMatcher &matcher,
+                             DialectConversion *conversion = nullptr)
       : dialectConversion(conversion), matcher(matcher) {}
 
   /// Converts the given function to the dialect using hooks defined in
@@ -319,11 +319,15 @@ FunctionConverter::convertRegion(DialectConversionRewriter &rewriter,
                                  Region &region, RegionParent *parent) {
   assert(!region.empty() && "expected non-empty region");
 
-  // Create the arguments of each of the blocks in the region.
-  for (Block &block : region)
-    for (auto *arg : block.getArguments())
-      if (failed(convertArgument(rewriter, arg, parent->getLoc())))
-        return failure();
+  // Create the arguments of each of the blocks in the region. If a type
+  // converter was not provided, then we don't need to change any of the block
+  // types.
+  if (dialectConversion) {
+    for (Block &block : region)
+      for (auto *arg : block.getArguments())
+        if (failed(convertArgument(rewriter, arg, parent->getLoc())))
+          return failure();
+  }
 
   // Start a DFS-order traversal of the CFG to make sure defs are converted
   // before uses in dominated blocks.
@@ -346,8 +350,8 @@ LogicalResult FunctionConverter::convertFunction(Function *f) {
   // Rewrite the function body.
   DialectConversionRewriter rewriter(f);
   if (failed(convertRegion(rewriter, f->getBody(), f))) {
-    // Reset any of the converted arguments.
-    rewriter.argConverter.discardRewrites();
+    // Reset any of the generated rewrites.
+    rewriter.discardRewrites();
     return failure();
   }
 
@@ -360,24 +364,6 @@ LogicalResult FunctionConverter::convertFunction(Function *f) {
 // DialectConversion
 //===----------------------------------------------------------------------===//
 
-namespace {
-/// This class represents a function to be converted. It allows for converting
-/// the body of functions and the signature in two phases.
-struct ConvertedFunction {
-  ConvertedFunction(Function *fn, FunctionType newType,
-                    ArrayRef<NamedAttributeList> newFunctionArgAttrs)
-      : fn(fn), newType(newType),
-        newFunctionArgAttrs(newFunctionArgAttrs.begin(),
-                            newFunctionArgAttrs.end()) {}
-
-  /// The function to convert.
-  Function *fn;
-  /// The new type and argument attributes for the function.
-  FunctionType newType;
-  SmallVector<NamedAttributeList, 4> newFunctionArgAttrs;
-};
-} // end anonymous namespace
-
 // Create a function type with arguments and results converted, and argument
 // attributes passed through.
 FunctionType DialectConversion::convertFunctionSignatureType(
@@ -403,21 +389,38 @@ FunctionType DialectConversion::convertFunctionSignatureType(
   return FunctionType::get(arguments, results, type.getContext());
 }
 
-// Converts the module as follows.
-// 1. Call `convertFunction` on each function of the module and collect the
-// mapping between old and new functions.
-// 2. Remap all function attributes in the new functions to point to the new
-// functions instead of the old ones.
-// 3. Replace old functions with the new in the module.
-LogicalResult DialectConversion::convert(Module *module) {
-  if (!module)
-    return failure();
+//===----------------------------------------------------------------------===//
+// applyConversionPatterns
+//===----------------------------------------------------------------------===//
 
+namespace {
+/// This class represents a function to be converted. It allows for converting
+/// the body of functions and the signature in two phases.
+struct ConvertedFunction {
+  ConvertedFunction(Function *fn, FunctionType newType,
+                    ArrayRef<NamedAttributeList> newFunctionArgAttrs)
+      : fn(fn), newType(newType),
+        newFunctionArgAttrs(newFunctionArgAttrs.begin(),
+                            newFunctionArgAttrs.end()) {}
+
+  /// The function to convert.
+  Function *fn;
+  /// The new type and argument attributes for the function.
+  FunctionType newType;
+  SmallVector<NamedAttributeList, 4> newFunctionArgAttrs;
+};
+} // end anonymous namespace
+
+/// Convert the given module with the provided dialect conversion object.
+/// If conversion fails for a specific function, those functions remains
+/// unmodified.
+LogicalResult mlir::applyConverter(Module &module,
+                                   DialectConversion &converter) {
   // Grab the conversion patterns from the converter and create the pattern
   // matcher.
-  MLIRContext *context = module->getContext();
+  MLIRContext *context = module.getContext();
   OwningRewritePatternList patterns;
-  initConverters(patterns, context);
+  converter.initConverters(patterns, context);
   RewritePatternMatcher matcher(std::move(patterns));
 
   // Try to convert each of the functions within the module. Defer updating the
@@ -426,18 +429,18 @@ LogicalResult DialectConversion::convert(Module *module) {
   // public signatures of the functions within the module before they are
   // updated.
   std::vector<ConvertedFunction> toConvert;
-  toConvert.reserve(module->getFunctions().size());
-  for (auto &func : *module) {
+  toConvert.reserve(module.getFunctions().size());
+  for (auto &func : module) {
     // Convert the function type using the dialect converter.
     SmallVector<NamedAttributeList, 4> newFunctionArgAttrs;
-    FunctionType newType = convertFunctionSignatureType(
+    FunctionType newType = converter.convertFunctionSignatureType(
         func.getType(), func.getAllArgAttrs(), newFunctionArgAttrs);
     if (!newType || !newType.isa<FunctionType>())
       return func.emitError("could not convert function type");
 
     // Convert the body of this function.
-    FunctionConverter converter(context, this, matcher);
-    if (failed(converter.convertFunction(&func)))
+    FunctionConverter funcConverter(context, matcher, &converter);
+    if (failed(funcConverter.convertFunction(&func)))
       return failure();
 
     // Add function signature to be updated.
@@ -453,3 +456,15 @@ LogicalResult DialectConversion::convert(Module *module) {
 
   return success();
 }
+
+/// Convert the given function with the provided conversion patterns. This will
+/// convert as many of the operations within 'fn' as possible given the set of
+/// patterns.
+LogicalResult
+mlir::applyConversionPatterns(Function &fn,
+                              OwningRewritePatternList &&patterns) {
+  // Convert the body of this function.
+  RewritePatternMatcher matcher(std::move(patterns));
+  FunctionConverter converter(fn.getContext(), matcher);
+  return converter.convertFunction(&fn);
+}