Rename DialectConversion to TypeConverter and split out pattern construction...
authorRiver Riddle <riverriddle@google.com>
Sat, 25 May 2019 01:17:50 +0000 (18:17 -0700)
committerMehdi Amini <joker.eph@gmail.com>
Sun, 2 Jun 2019 03:02:03 +0000 (20:02 -0700)
--

PiperOrigin-RevId: 249930583

mlir/examples/Linalg/Linalg1/include/linalg1/ConvertToLLVMDialect.h
mlir/examples/Linalg/Linalg1/lib/ConvertToLLVMDialect.cpp
mlir/examples/Linalg/Linalg3/lib/ConvertToLLVMDialect.cpp
mlir/examples/toy/Ch5/mlir/EarlyLowering.cpp
mlir/examples/toy/Ch5/mlir/LateLowering.cpp
mlir/include/mlir/LLVMIR/LLVMLowering.h
mlir/include/mlir/LLVMIR/Transforms.h
mlir/include/mlir/Transforms/DialectConversion.h
mlir/lib/LLVMIR/Transforms/ConvertToLLVMDialect.cpp
mlir/lib/Linalg/Transforms/LowerToLLVMDialect.cpp
mlir/lib/Transforms/DialectConversion.cpp

index 8bc80d2..e341705 100644 (file)
@@ -45,15 +45,8 @@ mlir::Type convertLinalgType(mlir::Type t);
 /// dialect to the LLVM IR dialect. The LLVM IR dialect must be registered. This
 /// function can be used to apply multiple conversion patterns in the same pass.
 /// It does not have to be called explicitly before the conversion.
-void getDescriptorConverters(mlir::OwningRewritePatternList &patterns,
-                             mlir::MLIRContext *context);
-
-/// Create a DialectConversion from the Linalg dialect to the LLVM IR dialect.
-/// The conversion is set up to convert types and function signatures using
-/// `convertLinalgType` and obtains operation converters by calling `initer`.
-std::unique_ptr<mlir::DialectConversion> makeLinalgToLLVMLowering(
-    std::function<void(mlir::OwningRewritePatternList &, mlir::MLIRContext *)>
-        initer);
+void populateLinalg1ToLLVMConversionPatterns(
+    mlir::OwningRewritePatternList &patterns, mlir::MLIRContext *context);
 
 /// Convert the Linalg dialect types and RangeOp, ViewOp and SliceOp operations
 /// to the LLVM IR dialect types and operations in the given `module`.  This is
index 9a97ee9..26c298a 100644 (file)
@@ -383,46 +383,26 @@ public:
   }
 };
 
-void linalg::getDescriptorConverters(mlir::OwningRewritePatternList &patterns,
-                                     mlir::MLIRContext *context) {
+void linalg::populateLinalg1ToLLVMConversionPatterns(
+    mlir::OwningRewritePatternList &patterns, mlir::MLIRContext *context) {
   RewriteListBuilder<DropConsumer, RangeOpConversion, SliceOpConversion,
                      ViewOpConversion>::build(patterns, context);
 }
 
 namespace {
-// The conversion class from Linalg to LLVMIR.
-class Lowering : public LLVMLowering {
-public:
-  explicit Lowering(std::function<void(mlir::OwningRewritePatternList &patterns,
-                                       mlir::MLIRContext *context)>
-                        conversions)
-      : setup(conversions) {}
-
-protected:
-  // Initialize the list of converters.
-  void initAdditionalConverters(OwningRewritePatternList &patterns) override {
-    setup(patterns, llvmDialect->getContext());
-  }
+/// A type conversion class that converts Linalg and Std types to LLVM.
+struct LinalgTypeConverter : public LLVMTypeConverter {
+  using LLVMTypeConverter::LLVMTypeConverter;
 
   // This gets called for block and region arguments, and attributes.
-  Type convertAdditionalType(Type t) override {
+  Type convertType(Type t) override {
+    if (auto result = LLVMTypeConverter::convertType(t))
+      return result;
     return linalg::convertLinalgType(t);
   }
-
-private:
-  // Conversion setup.
-  std::function<void(mlir::OwningRewritePatternList &patterns,
-                     mlir::MLIRContext *context)>
-      setup;
 };
 } // end anonymous namespace
 
-std::unique_ptr<mlir::DialectConversion> linalg::makeLinalgToLLVMLowering(
-    std::function<void(mlir::OwningRewritePatternList &, mlir::MLIRContext *)>
-        initer) {
-  return llvm::make_unique<Lowering>(initer);
-}
-
 void linalg::convertToLLVM(mlir::Module &module) {
   // Remove affine constructs if any by using an existing pass.
   PassManager pm;
@@ -433,8 +413,12 @@ void linalg::convertToLLVM(mlir::Module &module) {
 
   // Convert Linalg ops to the LLVM IR dialect using the converter defined
   // above.
-  Lowering lowering(getDescriptorConverters);
-  auto r = applyConverter(module, lowering);
+  LinalgTypeConverter converter(module.getContext());
+  OwningRewritePatternList patterns;
+  populateStdToLLVMConversionPatterns(converter, patterns);
+  populateLinalg1ToLLVMConversionPatterns(patterns, module.getContext());
+
+  auto r = applyConversionPatterns(module, converter, std::move(patterns));
   (void)r;
   assert(succeeded(r) && "conversion failed");
 }
index c9d52de..db9f496 100644 (file)
@@ -26,6 +26,7 @@
 #include "mlir/IR/StandardTypes.h"
 #include "mlir/IR/Types.h"
 #include "mlir/LLVMIR/LLVMDialect.h"
+#include "mlir/LLVMIR/LLVMLowering.h"
 #include "mlir/LLVMIR/Transforms.h"
 #include "mlir/Transforms/DialectConversion.h"
 
@@ -122,13 +123,23 @@ class StoreOpConversion : public LoadStoreOpConversion<linalg::StoreOp> {
   }
 };
 
+/// A type conversion class that converts Linalg and Std types to LLVM.
+struct LinalgTypeConverter : public LLVMTypeConverter {
+  using LLVMTypeConverter::LLVMTypeConverter;
+
+  // This gets called for block and region arguments, and attributes.
+  Type convertType(Type t) override {
+    if (auto result = LLVMTypeConverter::convertType(t))
+      return result;
+    return linalg::convertLinalgType(t);
+  }
+};
 } // end anonymous namespace
 
 // Helper function that allocates the descriptor converters and adds load/store
 // coverters to the list.
-static void getConversions(mlir::OwningRewritePatternList &patterns,
-                           mlir::MLIRContext *context) {
-  linalg::getDescriptorConverters(patterns, context);
+static void populateLinalg3ToLLVMConversionPatterns(
+    mlir::OwningRewritePatternList &patterns, mlir::MLIRContext *context) {
   RewriteListBuilder<LoadOpConversion, StoreOpConversion>::build(patterns,
                                                                  context);
 }
@@ -141,8 +152,15 @@ void linalg::convertLinalg3ToLLVM(Module &module) {
   (void)rr;
   assert(succeeded(rr) && "affine loop lowering failed");
 
-  auto lowering = makeLinalgToLLVMLowering(getConversions);
-  auto r = applyConverter(module, *lowering);
+  // Convert Linalg ops to the LLVM IR dialect using the converter defined
+  // above.
+  LinalgTypeConverter converter(module.getContext());
+  OwningRewritePatternList patterns;
+  populateStdToLLVMConversionPatterns(converter, patterns);
+  populateLinalg1ToLLVMConversionPatterns(patterns, module.getContext());
+  populateLinalg3ToLLVMConversionPatterns(patterns, module.getContext());
+
+  auto r = applyConversionPatterns(module, converter, std::move(patterns));
   (void)r;
   assert(succeeded(r) && "conversion failed");
 }
index d38ad4f..f4ac522 100644 (file)
@@ -119,26 +119,16 @@ public:
   }
 };
 
-// The conversion class from Toy IR Dialect to a mix of Linalg and LLVM.
-class EarlyLowering : public DialectConversion {
-protected:
-  // Initialize the list of converters.
-  void initConverters(OwningRewritePatternList &patterns,
-                      MLIRContext *context) override {
-    RewriteListBuilder<MulOpConversion>::build(patterns, context);
-  }
-};
-
 /// This is lowering to Linalg the parts that are computationally intensive
 /// (like matmul for example...) while keeping the rest of the code in the Toy
 /// dialect.
-struct EarlyLoweringPass : public ModulePass<EarlyLoweringPass> {
-  void runOnModule() override {
-    EarlyLowering lowering;
-    if (failed(applyConverter(getModule(), lowering))) {
-      getModule().getContext()->emitError(
-          mlir::UnknownLoc::get(getModule().getContext()),
-          "Error lowering Toy\n");
+struct EarlyLoweringPass : public FunctionPass<EarlyLoweringPass> {
+  void runOnFunction() override {
+    OwningRewritePatternList patterns;
+    RewriteListBuilder<MulOpConversion>::build(patterns, &getContext());
+    if (failed(applyConversionPatterns(getFunction(), std::move(patterns)))) {
+      getContext().emitError(mlir::UnknownLoc::get(&getContext()),
+                             "Error lowering Toy\n");
       signalPassFailure();
     }
   }
@@ -147,9 +137,4 @@ struct EarlyLoweringPass : public ModulePass<EarlyLoweringPass> {
 
 namespace toy {
 Pass *createEarlyLoweringPass() { return new EarlyLoweringPass(); }
-
-std::unique_ptr<mlir::DialectConversion> makeToyEarlyLowering() {
-  return llvm::make_unique<EarlyLowering>();
-}
-
 } // namespace toy
index 13b20b9..611d716 100644 (file)
@@ -316,16 +316,8 @@ public:
 
 /// This is the main class registering our individual converter classes with
 /// the DialectConversion framework in MLIR.
-class LateLowering : public DialectConversion {
+class ToyTypeConverter : public TypeConverter {
 protected:
-  /// Initialize the list of converters.
-  void initConverters(OwningRewritePatternList &patterns,
-                      MLIRContext *context) override {
-    RewriteListBuilder<AddOpConversion, PrintOpConversion, ConstantOpConversion,
-                       TransposeOpConversion,
-                       ReturnOpConversion>::build(patterns, context);
-  }
-
   /// Convert a Toy type, this gets called for block and region arguments, and
   /// attributes.
   Type convertType(Type t) override {
@@ -339,13 +331,20 @@ protected:
 /// and is targeting LLVM otherwise.
 struct LateLoweringPass : public ModulePass<LateLoweringPass> {
   void runOnModule() override {
+    ToyTypeConverter typeConverter;
+    OwningRewritePatternList toyPatterns;
+    RewriteListBuilder<AddOpConversion, PrintOpConversion, ConstantOpConversion,
+                       TransposeOpConversion,
+                       ReturnOpConversion>::build(toyPatterns, &getContext());
+
     // Perform Toy specific lowering.
-    LateLowering lowering;
-    if (failed(applyConverter(getModule(), lowering))) {
+    if (failed(applyConversionPatterns(getModule(), typeConverter,
+                                       std::move(toyPatterns)))) {
       getModule().getContext()->emitError(
           UnknownLoc::get(getModule().getContext()), "Error lowering Toy\n");
       signalPassFailure();
     }
+
     // At this point the IR is almost using only standard and affine dialects.
     // A few things remain before we emit LLVM IR. First to reuse as much of
     // MLIR as possible we will try to lower everything to the standard and/or
@@ -432,9 +431,4 @@ struct LateLoweringPass : public ModulePass<LateLoweringPass> {
 
 namespace toy {
 Pass *createLateLoweringPass() { return new LateLoweringPass(); }
-
-std::unique_ptr<DialectConversion> makeToyLateLowering() {
-  return llvm::make_unique<LateLowering>();
-}
-
 } // namespace toy
index 9db282f..64b6400 100644 (file)
@@ -37,15 +37,16 @@ namespace mlir {
 namespace LLVM {
 class LLVMDialect;
 class LLVMType;
-}
+} // namespace LLVM
 
-/// Conversion from the Standard dialect to the LLVM IR dialect.  Provides hooks
-/// for derived classes to extend the conversion.
-class LLVMLowering : public DialectConversion {
+/// Conversion from types in the Standard dialect to the LLVM IR dialect.
+class LLVMTypeConverter : public TypeConverter {
 public:
+  LLVMTypeConverter(MLIRContext *ctx);
+
   /// Convert types to LLVM IR.  This calls `convertAdditionalType` to convert
   /// non-standard or non-builtin types.
-  Type convertType(Type t) override final;
+  Type convertType(Type t) override;
 
   /// Convert a non-empty list of types to be returned from a function into a
   /// supported LLVM IR type.  In particular, if more than one values is
@@ -60,22 +61,6 @@ public:
   LLVM::LLVMDialect *getDialect() { return llvmDialect; }
 
 protected:
-  /// Add a set of converters to the given pattern list. Store the module
-  /// associated with the dialect for further type conversion.
-  void initConverters(OwningRewritePatternList &patterns,
-                      MLIRContext *mlirContext) override final;
-
-  /// Derived classes can override this function to initialize custom converters
-  /// in addition to the existing converters from Standard operations.  It will
-  /// be called after the `module` and `llvmDialect` have been made available.
-  virtual void initAdditionalConverters(OwningRewritePatternList &patterns) {}
-
-  /// Derived classes can override this function to convert custom types.  It
-  /// will be called by convertType if the default conversion from standard and
-  /// builtin types fails.  Derived classes can thus call convertType whenever
-  /// they need type conversion that supports both default and custom types.
-  virtual Type convertAdditionalType(Type t) { return t; }
-
   /// Convert function signatures to LLVM IR.  In particular, convert functions
   /// with multiple results into functions returning LLVM IR's structure type.
   /// Use `convertType` to convert individual argument and result types.
@@ -83,8 +68,6 @@ protected:
       FunctionType t, ArrayRef<NamedAttributeList> argAttrs,
       SmallVectorImpl<NamedAttributeList> &convertedArgAttrs) override final;
 
-  /// Storage for the conversion patterns.
-  llvm::BumpPtrAllocator converterStorage;
   /// LLVM IR module used to parse/create types.
   llvm::Module *module;
   LLVM::LLVMDialect *llvmDialect;
@@ -138,12 +121,12 @@ private:
 class LLVMOpLowering : public ConversionPattern {
 public:
   LLVMOpLowering(StringRef rootOpName, MLIRContext *context,
-                 LLVMLowering &lowering);
+                 LLVMTypeConverter &lowering);
 
 protected:
   // Back-reference to the lowering class, used to call type and function
   // conversions accounting for potential extensions.
-  LLVMLowering &lowering;
+  LLVMTypeConverter &lowering;
 };
 
 } // namespace mlir
index 95244b8..051fce8 100644 (file)
@@ -19,6 +19,7 @@
 #define MLIR_LLVMIR_TRANSFORMS_H_
 
 #include <memory>
+#include <vector>
 
 namespace llvm {
 class Module;
@@ -26,16 +27,20 @@ class Module;
 
 namespace mlir {
 class DialectConversion;
+class LLVMTypeConverter;
 class Module;
 class ModulePassBase;
+class RewritePattern;
 class Type;
 
+using OwningRewritePatternList = std::vector<std::unique_ptr<RewritePattern>>;
+
 /// Creates a pass to convert Standard dialects into the LLVMIR dialect.
 ModulePassBase *createConvertToLLVMIRPass();
 
-/// Creates a dialect converter from the standard dialect to the LLVM IR
-/// dialect and transfers ownership to the caller.
-std::unique_ptr<DialectConversion> createStdToLLVMConverter();
+/// Collect a set of patterns to convert from the Standard dialect to LLVM.
+void populateStdToLLVMConversionPatterns(LLVMTypeConverter &converter,
+                                         OwningRewritePatternList &patterns);
 
 namespace LLVM {
 /// Make argument-taking successors of each block distinct.  PHI nodes in LLVM
index bbbcd7d..165e065 100644 (file)
@@ -95,32 +95,11 @@ private:
   using RewritePattern::rewrite;
 };
 
-/// Base class for dialect conversion interface.  Specific converters must
+/// Base class for type conversion interface. Specific converters must
 /// derive this class and implement the pure virtual functions.
-///
-/// The module conversion proceeds as follows.
-/// 1. Call `initConverters` to obtain a set of conversions to apply, given the
-///    current MLIR context.
-/// 2. For each function in the module do the following.
-//    a. Create a new function with the same name and convert its signature
-//       using `convertType`.
-//    b. For each block in the function, create a block in the function with
-//       its arguments converted using `convertType`.
-//    c. Traverse blocks in DFS-preorder of successors starting from the entry
-//       block (if any), and convert individual operations as follows.  Pattern
-//       match against the list of conversions.  On the first match, call
-//       `rewrite` for the operations, and advance to the next iteration.  If no
-//       match is found, replicate the operation as is.
-class DialectConversion {
+class TypeConverter {
 public:
-  virtual ~DialectConversion() = default;
-
-  /// Derived classes must implement this hook to produce a set of conversion
-  /// patterns to apply.  They may use `mlirContext` to obtain registered
-  /// dialects or operations.  This will be called in the beginning of the
-  /// conversion.
-  virtual void initConverters(OwningRewritePatternList &patterns,
-                              MLIRContext *mlirContext) = 0;
+  virtual ~TypeConverter() = default;
 
   /// Derived classes must reimplement this hook if they need to convert
   /// block or function argument types or function result types.  If the target
@@ -157,11 +136,12 @@ public:
       SmallVectorImpl<NamedAttributeList> &convertedArgAttrs);
 };
 
-/// Convert the given module with the provided dialect conversion object.
-/// If conversion fails for a specific function, those functions remains
-/// unmodified.
+/// Convert the given module with the provided conversion patterns and type
+/// conversion object. If conversion fails for specific functions, those
+/// functions remains unmodified.
 LLVM_NODISCARD
-LogicalResult applyConverter(Module &module, DialectConversion &converter);
+LogicalResult applyConversionPatterns(Module &module, TypeConverter &converter,
+                                      OwningRewritePatternList &&patterns);
 
 /// Convert the given function with the provided conversion patterns. This will
 /// convert as many of the operations within 'fn' as possible given the set of
index 4354f82..3a756a9 100644 (file)
 
 using namespace mlir;
 
+LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx)
+    : llvmDialect(ctx->getRegisteredDialect<LLVM::LLVMDialect>()) {
+  assert(llvmDialect && "LLVM IR dialect is not registered");
+  module = &llvmDialect->getLLVMModule();
+}
+
 // Get the LLVM context.
-llvm::LLVMContext &LLVMLowering::getLLVMContext() {
+llvm::LLVMContext &LLVMTypeConverter::getLLVMContext() {
   return module->getContext();
 }
 
 // Extract an LLVM IR type from the LLVM IR dialect type.
-LLVM::LLVMType LLVMLowering::unwrap(Type type) {
+LLVM::LLVMType LLVMTypeConverter::unwrap(Type type) {
   if (!type)
     return nullptr;
   auto *mlirContext = type.getContext();
@@ -57,18 +63,20 @@ LLVM::LLVMType LLVMLowering::unwrap(Type type) {
   return wrappedLLVMType;
 }
 
-LLVM::LLVMType LLVMLowering::getIndexType() {
+LLVM::LLVMType LLVMTypeConverter::getIndexType() {
   return LLVM::LLVMType::getIntNTy(
       llvmDialect, module->getDataLayout().getPointerSizeInBits());
 }
 
-Type LLVMLowering::convertIndexType(IndexType type) { return getIndexType(); }
+Type LLVMTypeConverter::convertIndexType(IndexType type) {
+  return getIndexType();
+}
 
-Type LLVMLowering::convertIntegerType(IntegerType type) {
+Type LLVMTypeConverter::convertIntegerType(IntegerType type) {
   return LLVM::LLVMType::getIntNTy(llvmDialect, type.getWidth());
 }
 
-Type LLVMLowering::convertFloatType(FloatType type) {
+Type LLVMTypeConverter::convertFloatType(FloatType type) {
   switch (type.getKind()) {
   case mlir::StandardTypes::F32:
     return LLVM::LLVMType::getFloatTy(llvmDialect);
@@ -91,7 +99,7 @@ Type LLVMLowering::convertFloatType(FloatType type) {
 // argument and result types.  If MLIR Function has zero results, the LLVM
 // Function has one VoidType result.  If MLIR Function has more than one result,
 // they are into an LLVM StructType in their order of appearance.
-Type LLVMLowering::convertFunctionType(FunctionType type) {
+Type LLVMTypeConverter::convertFunctionType(FunctionType type) {
   // Convert argument types one by one and check for errors.
   SmallVector<LLVM::LLVMType, 8> argTypes;
   for (auto t : type.getInputs()) {
@@ -119,7 +127,7 @@ Type LLVMLowering::convertFunctionType(FunctionType type) {
 // LLVM stucture type, where the first element of the structure type is a
 // pointer to the elemental type of the MemRef and the following N elements are
 // values of the Index type, one for each of N dynamic dimensions of the MemRef.
-Type LLVMLowering::convertMemRefType(MemRefType type) {
+Type LLVMTypeConverter::convertMemRefType(MemRefType type) {
   LLVM::LLVMType elementType = unwrap(convertType(type.getElementType()));
   if (!elementType)
     return {};
@@ -138,7 +146,7 @@ Type LLVMLowering::convertMemRefType(MemRefType type) {
 }
 
 // Convert a 1D vector type to an LLVM vector type.
-Type LLVMLowering::convertVectorType(VectorType type) {
+Type LLVMTypeConverter::convertVectorType(VectorType type) {
   if (type.getRank() != 1) {
     auto *mlirContext = llvmDialect->getContext();
     mlirContext->emitError(UnknownLoc::get(mlirContext),
@@ -153,7 +161,7 @@ Type LLVMLowering::convertVectorType(VectorType type) {
 }
 
 // Dispatch based on the actual type.  Return null type on error.
-Type LLVMLowering::convertStandardType(Type type) {
+Type LLVMTypeConverter::convertStandardType(Type type) {
   if (auto funcType = type.dyn_cast<FunctionType>())
     return convertFunctionType(funcType);
   if (auto intType = type.dyn_cast<IntegerType>())
@@ -175,7 +183,7 @@ Type LLVMLowering::convertStandardType(Type type) {
 // Convert the element type of the memref `t` to to an LLVM type using
 // `lowering`, get a pointer LLVM type pointing to the converted `t`, wrap it
 // into the MLIR LLVM dialect type and return.
-static Type getMemRefElementPtrType(MemRefType t, LLVMLowering &lowering) {
+static Type getMemRefElementPtrType(MemRefType t, LLVMTypeConverter &lowering) {
   auto elementType = t.getElementType();
   auto converted = lowering.convertType(elementType);
   if (!converted)
@@ -184,7 +192,7 @@ static Type getMemRefElementPtrType(MemRefType t, LLVMLowering &lowering) {
 }
 
 LLVMOpLowering::LLVMOpLowering(StringRef rootOpName, MLIRContext *context,
-                               LLVMLowering &lowering_)
+                               LLVMTypeConverter &lowering_)
     : ConversionPattern(rootOpName, /*benefit=*/1, context),
       lowering(lowering_) {}
 
@@ -197,7 +205,7 @@ class LLVMLegalizationPattern : public LLVMOpLowering {
 public:
   // Construct a conversion pattern.
   explicit LLVMLegalizationPattern(LLVM::LLVMDialect &dialect_,
-                                   LLVMLowering &lowering_)
+                                   LLVMTypeConverter &lowering_)
       : LLVMOpLowering(SourceOp::getOperationName(), dialect_.getContext(),
                        lowering_),
         dialect(dialect_) {}
@@ -895,20 +903,9 @@ void mlir::LLVM::ensureDistinctSuccessors(Module *m) {
   }
 }
 
-// Create a set of converters that live in the pass object by passing them a
-// reference to the LLVM IR dialect.  Store the module associated with the
-// dialect for further type conversion.
-void LLVMLowering::initConverters(OwningRewritePatternList &patterns,
-                                  MLIRContext *mlirContext) {
-  llvmDialect = mlirContext->getRegisteredDialect<LLVM::LLVMDialect>();
-  if (!llvmDialect) {
-    mlirContext->emitError(UnknownLoc::get(mlirContext),
-                           "LLVM IR dialect is not registered");
-    return;
-  }
-
-  module = &llvmDialect->getLLVMModule();
-
+/// Collect a set of patterns to convert from the Standard dialect to LLVM.
+void mlir::populateStdToLLVMConversionPatterns(
+    LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
   // FIXME: this should be tablegen'ed
   RewriteListBuilder<
       AddFOpLowering, AddIOpLowering, AndOpLowering, AllocOpLowering,
@@ -918,25 +915,15 @@ void LLVMLowering::initConverters(OwningRewritePatternList &patterns,
       LoadOpLowering, MemRefCastOpLowering, MulFOpLowering, MulIOpLowering,
       OrOpLowering, RemISOpLowering, RemIUOpLowering, RemFOpLowering,
       ReturnOpLowering, SelectOpLowering, StoreOpLowering, SubFOpLowering,
-      SubIOpLowering, XOrOpLowering>::build(patterns, *llvmDialect, *this);
-  initAdditionalConverters(patterns);
+      SubIOpLowering, XOrOpLowering>::build(patterns, *converter.getDialect(),
+                                            converter);
 }
 
 // Convert types using the stored LLVM IR module.
-Type LLVMLowering::convertType(Type t) {
-  if (auto result = convertStandardType(t))
-    return result;
-  if (auto result = convertAdditionalType(t))
-    return result;
-
-  auto *mlirContext = llvmDialect->getContext();
-  mlirContext->emitError(UnknownLoc::get(mlirContext))
-      << "unsupported type: " << t;
-  return {};
-}
+Type LLVMTypeConverter::convertType(Type t) { return convertStandardType(t); }
 
 // Create an LLVM IR structure type if there is more than one result.
-Type LLVMLowering::packFunctionResults(ArrayRef<Type> types) {
+Type LLVMTypeConverter::packFunctionResults(ArrayRef<Type> types) {
   assert(!types.empty() && "expected non-empty list of type");
 
   if (types.size() == 1)
@@ -955,7 +942,7 @@ Type LLVMLowering::packFunctionResults(ArrayRef<Type> types) {
 }
 
 // Convert function signatures using the stored LLVM IR module.
-FunctionType LLVMLowering::convertFunctionSignatureType(
+FunctionType LLVMTypeConverter::convertFunctionSignatureType(
     FunctionType type, ArrayRef<NamedAttributeList> argAttrs,
     SmallVectorImpl<NamedAttributeList> &convertedArgAttrs) {
 
@@ -983,36 +970,25 @@ FunctionType LLVMLowering::convertFunctionSignatureType(
 }
 
 namespace {
-// Make sure LLVM conversion pass errors out on the unsupported types instead
-// of keeping them as is and resulting in a more cryptic verifier error.
-class LLVMStandardLowering : public LLVMLowering {
-protected:
-  Type convertAdditionalType(Type) override { return {}; }
-};
-} // namespace
-
 /// A pass converting MLIR Standard operations into the LLVM IR dialect.
-class LLVMLoweringPass : public ModulePass<LLVMLoweringPass> {
-public:
+struct LLVMLoweringPass : public ModulePass<LLVMLoweringPass> {
   // Run the dialect converter on the module.
   void runOnModule() override {
     Module &m = getModule();
     LLVM::ensureDistinctSuccessors(&m);
-    if (failed(applyConverter(m, impl)))
+
+    LLVMTypeConverter converter(&getContext());
+    OwningRewritePatternList patterns;
+    populateStdToLLVMConversionPatterns(converter, patterns);
+    if (failed(applyConversionPatterns(m, converter, std::move(patterns))))
       signalPassFailure();
   }
-
-private:
-  LLVMStandardLowering impl;
 };
+} // end anonymous namespace
 
 ModulePassBase *mlir::createConvertToLLVMIRPass() {
   return new LLVMLoweringPass();
 }
 
-std::unique_ptr<DialectConversion> mlir::createStdToLLVMConverter() {
-  return llvm::make_unique<LLVMLowering>();
-}
-
 static PassRegistration<LLVMLoweringPass>
     pass("lower-to-llvm", "Convert all functions to the LLVM IR dialect");
index f94868d..f51c25a 100644 (file)
@@ -64,7 +64,8 @@ using llvm_select = ValueBuilder<LLVM::SelectOp>;
 using icmp = ValueBuilder<LLVM::ICmpOp>;
 
 template <typename T>
-static LLVMType getPtrToElementType(T containerType, LLVMLowering &lowering) {
+static LLVMType getPtrToElementType(T containerType,
+                                    LLVMTypeConverter &lowering) {
   return lowering.convertType(containerType.getElementType())
       .template cast<LLVMType>()
       .getPointerTo();
@@ -78,7 +79,7 @@ static LLVMType getPtrToElementType(T containerType, LLVMLowering &lowering) {
 //   - an F32 type is converted into an LLVM float type
 //   - a Buffer, Range or View is converted into an LLVM structure type
 //     containing the respective dynamic values.
-static Type convertLinalgType(Type t, LLVMLowering &lowering) {
+static Type convertLinalgType(Type t, LLVMTypeConverter &lowering) {
   auto *context = t.getContext();
   auto int64Ty = lowering.convertType(IntegerType::get(64, context))
                      .cast<LLVM::LLVMType>();
@@ -152,7 +153,7 @@ static ArrayAttr positionAttr(Builder &builder, ArrayRef<int> position) {
 class BufferAllocOpConversion : public LLVMOpLowering {
 public:
   explicit BufferAllocOpConversion(MLIRContext *context,
-                                   LLVMLowering &lowering_)
+                                   LLVMTypeConverter &lowering_)
       : LLVMOpLowering(BufferAllocOp::getOperationName(), context, lowering_) {}
 
   void rewrite(Operation *op, ArrayRef<Value *> operands,
@@ -207,7 +208,7 @@ public:
 class BufferDeallocOpConversion : public LLVMOpLowering {
 public:
   explicit BufferDeallocOpConversion(MLIRContext *context,
-                                     LLVMLowering &lowering_)
+                                     LLVMTypeConverter &lowering_)
       : LLVMOpLowering(BufferDeallocOp::getOperationName(), context,
                        lowering_) {}
 
@@ -241,7 +242,7 @@ public:
 // BufferSizeOp creates a new `index` value.
 class BufferSizeOpConversion : public LLVMOpLowering {
 public:
-  BufferSizeOpConversion(MLIRContext *context, LLVMLowering &lowering_)
+  BufferSizeOpConversion(MLIRContext *context, LLVMTypeConverter &lowering_)
       : LLVMOpLowering(BufferSizeOp::getOperationName(), context, lowering_) {}
 
   void rewrite(Operation *op, ArrayRef<Value *> operands,
@@ -256,7 +257,7 @@ public:
 // DimOp creates a new `index` value.
 class DimOpConversion : public LLVMOpLowering {
 public:
-  explicit DimOpConversion(MLIRContext *context, LLVMLowering &lowering_)
+  explicit DimOpConversion(MLIRContext *context, LLVMTypeConverter &lowering_)
       : LLVMOpLowering(linalg::DimOp::getOperationName(), context, lowering_) {}
 
   void rewrite(Operation *op, ArrayRef<Value *> operands,
@@ -277,7 +278,8 @@ namespace {
 // LLVM IR Dialect.
 template <typename Op> class LoadStoreOpConversion : public LLVMOpLowering {
 public:
-  explicit LoadStoreOpConversion(MLIRContext *context, LLVMLowering &lowering_)
+  explicit LoadStoreOpConversion(MLIRContext *context,
+                                 LLVMTypeConverter &lowering_)
       : LLVMOpLowering(Op::getOperationName(), context, lowering_) {}
   using Base = LoadStoreOpConversion<Op>;
 
@@ -327,7 +329,7 @@ class LoadOpConversion : public LoadStoreOpConversion<linalg::LoadOp> {
 // RangeOp creates a new range descriptor.
 class RangeOpConversion : public LLVMOpLowering {
 public:
-  explicit RangeOpConversion(MLIRContext *context, LLVMLowering &lowering_)
+  explicit RangeOpConversion(MLIRContext *context, LLVMTypeConverter &lowering_)
       : LLVMOpLowering(RangeOp::getOperationName(), context, lowering_) {}
 
   void rewrite(Operation *op, ArrayRef<Value *> operands,
@@ -354,7 +356,7 @@ public:
 class RangeIntersectOpConversion : public LLVMOpLowering {
 public:
   explicit RangeIntersectOpConversion(MLIRContext *context,
-                                      LLVMLowering &lowering_)
+                                      LLVMTypeConverter &lowering_)
       : LLVMOpLowering(RangeIntersectOp::getOperationName(), context,
                        lowering_) {}
 
@@ -397,7 +399,7 @@ public:
 
 class SliceOpConversion : public LLVMOpLowering {
 public:
-  explicit SliceOpConversion(MLIRContext *context, LLVMLowering &lowering_)
+  explicit SliceOpConversion(MLIRContext *context, LLVMTypeConverter &lowering_)
       : LLVMOpLowering(SliceOp::getOperationName(), context, lowering_) {}
 
   void rewrite(Operation *op, ArrayRef<Value *> operands,
@@ -494,7 +496,7 @@ class StoreOpConversion : public LoadStoreOpConversion<linalg::StoreOp> {
 
 class ViewOpConversion : public LLVMOpLowering {
 public:
-  explicit ViewOpConversion(MLIRContext *context, LLVMLowering &lowering_)
+  explicit ViewOpConversion(MLIRContext *context, LLVMTypeConverter &lowering_)
       : LLVMOpLowering(ViewOp::getOperationName(), context, lowering_) {}
 
   void rewrite(Operation *op, ArrayRef<Value *> operands,
@@ -550,7 +552,7 @@ public:
 // DotOp creates a new range descriptor.
 class DotOpConversion : public LLVMOpLowering {
 public:
-  explicit DotOpConversion(MLIRContext *context, LLVMLowering &lowering_)
+  explicit DotOpConversion(MLIRContext *context, LLVMTypeConverter &lowering_)
       : LLVMOpLowering(DotOp::getOperationName(), context, lowering_) {}
 
   static StringRef libraryFunctionName() { return "linalg_dot"; }
@@ -574,25 +576,30 @@ public:
 
 namespace {
 // The conversion class from Linalg to LLVMIR.
-class Lowering : public LLVMLowering {
-protected:
-  void initAdditionalConverters(OwningRewritePatternList &patterns) override {
-    RewriteListBuilder<BufferAllocOpConversion, BufferDeallocOpConversion,
-                       BufferSizeOpConversion, DimOpConversion, DotOpConversion,
-                       LoadOpConversion, RangeOpConversion,
-                       RangeIntersectOpConversion, SliceOpConversion,
-                       StoreOpConversion,
-                       ViewOpConversion>::build(patterns,
-                                                llvmDialect->getContext(),
-                                                *this);
-  }
+struct LinalgTypeConverter : LLVMTypeConverter {
+  using LLVMTypeConverter::LLVMTypeConverter;
 
-  Type convertAdditionalType(Type t) override {
+  Type convertType(Type t) override {
+    if (auto result = LLVMTypeConverter::convertType(t))
+      return result;
     return convertLinalgType(t, *this);
   }
 };
 } // end anonymous namespace
 
+/// Populate the given list with patterns that convert from Linalg to LLVM.
+static void
+populateLinalgToLLVMConversionPatterns(LinalgTypeConverter &converter,
+                                       OwningRewritePatternList &patterns,
+                                       MLIRContext *ctx) {
+  RewriteListBuilder<BufferAllocOpConversion, BufferDeallocOpConversion,
+                     BufferSizeOpConversion, DimOpConversion, DotOpConversion,
+                     LoadOpConversion, RangeOpConversion,
+                     RangeIntersectOpConversion, SliceOpConversion,
+                     StoreOpConversion, ViewOpConversion>::build(patterns, ctx,
+                                                                 converter);
+}
+
 namespace {
 struct LowerLinalgToLLVMPass : public ModulePass<LowerLinalgToLLVMPass> {
   void runOnModule();
@@ -608,8 +615,12 @@ void LowerLinalgToLLVMPass::runOnModule() {
     signalPassFailure();
 
   // Convert to the LLVM IR dialect using the converter defined above.
-  Lowering lowering;
-  if (failed(applyConverter(module, lowering)))
+  OwningRewritePatternList patterns;
+  LinalgTypeConverter converter(&getContext());
+  populateStdToLLVMConversionPatterns(converter, patterns);
+  populateLinalgToLLVMConversionPatterns(converter, patterns, &getContext());
+
+  if (failed(applyConversionPatterns(module, converter, std::move(patterns))))
     signalPassFailure();
 }
 
index 4428302..b6a15f6 100644 (file)
@@ -227,16 +227,16 @@ void ConversionPattern::rewrite(Operation *op,
 //===----------------------------------------------------------------------===//
 namespace {
 // This class converts a single function using the given pattern matcher. If a
-// DialectConversion object is also provided, then the types of block arguments
-// will be converted using the appropriate 'convertType' calls.
+// TypeConverter object is provided, then the types of block arguments will be
+// converted using the appropriate 'convertType' calls.
 class FunctionConverter {
 public:
   explicit FunctionConverter(MLIRContext *ctx, RewritePatternMatcher &matcher,
-                             DialectConversion *conversion = nullptr)
-      : dialectConversion(conversion), matcher(matcher) {}
+                             TypeConverter *conversion = nullptr)
+      : typeConverter(conversion), matcher(matcher) {}
 
   /// Converts the given function to the dialect using hooks defined in
-  /// `dialectConversion`. Returns failure on error, success otherwise.
+  /// `typeConverter`. Returns failure on error, success otherwise.
   LogicalResult convertFunction(Function *f);
 
   /// Converts the given region starting from the entry block and following the
@@ -260,7 +260,7 @@ public:
                                 BlockArgument *arg, Location loc);
 
   /// Pointer to a specific dialect conversion info.
-  DialectConversion *dialectConversion;
+  TypeConverter *typeConverter;
 
   /// The matcher to use when converting operations.
   RewritePatternMatcher &matcher;
@@ -270,7 +270,7 @@ public:
 LogicalResult
 FunctionConverter::convertArgument(DialectConversionRewriter &rewriter,
                                    BlockArgument *arg, Location loc) {
-  auto convertedType = dialectConversion->convertType(arg->getType());
+  auto convertedType = typeConverter->convertType(arg->getType());
   if (!convertedType)
     return arg->getContext()->emitError(loc)
            << "could not convert block argument of type : " << arg->getType();
@@ -322,7 +322,7 @@ FunctionConverter::convertRegion(DialectConversionRewriter &rewriter,
   // Create the arguments of each of the blocks in the region. If a type
   // converter was not provided, then we don't need to change any of the block
   // types.
-  if (dialectConversion) {
+  if (typeConverter) {
     for (Block &block : region)
       for (auto *arg : block.getArguments())
         if (failed(convertArgument(rewriter, arg, loc)))
@@ -362,12 +362,12 @@ LogicalResult FunctionConverter::convertFunction(Function *f) {
 }
 
 //===----------------------------------------------------------------------===//
-// DialectConversion
+// TypeConverter
 //===----------------------------------------------------------------------===//
 
 // Create a function type with arguments and results converted, and argument
 // attributes passed through.
-FunctionType DialectConversion::convertFunctionSignatureType(
+FunctionType TypeConverter::convertFunctionSignatureType(
     FunctionType type, ArrayRef<NamedAttributeList> argAttrs,
     SmallVectorImpl<NamedAttributeList> &convertedArgAttrs) {
   SmallVector<Type, 8> arguments;
@@ -412,16 +412,15 @@ struct ConvertedFunction {
 };
 } // end anonymous namespace
 
-/// Convert the given module with the provided dialect conversion object.
-/// If conversion fails for a specific function, those functions remains
-/// unmodified.
-LogicalResult mlir::applyConverter(Module &module,
-                                   DialectConversion &converter) {
+/// Convert the given module with the provided conversion patterns and type
+/// conversion object. If conversion fails for specific functions, those
+/// functions remains unmodified.
+LogicalResult
+mlir::applyConversionPatterns(Module &module, TypeConverter &converter,
+                              OwningRewritePatternList &&patterns) {
   // Grab the conversion patterns from the converter and create the pattern
   // matcher.
   MLIRContext *context = module.getContext();
-  OwningRewritePatternList patterns;
-  converter.initConverters(patterns, context);
   RewritePatternMatcher matcher(std::move(patterns));
 
   // Try to convert each of the functions within the module. Defer updating the