Accept additional conversions in the LLVM lowering
authorAlex Zinenko <zinenko@google.com>
Fri, 3 May 2019 12:31:55 +0000 (05:31 -0700)
committerMehdi Amini <joker.eph@gmail.com>
Mon, 6 May 2019 15:26:24 +0000 (08:26 -0700)
    Extend the LLVM lowering class following the original idea of the "bag of
    conversions".  LLVMLowering class is now exposed as and can be derived from.
    It provides hooks for derived classes to inject operation conversions and to
    convert custom types.  It is under responsibility of the caller to make sure
    patterns don't overlap.

    Update the lowering from the Linalg dialect to the LLVM IR dialect to use this
    new approach.

--

PiperOrigin-RevId: 246492919

mlir/include/mlir/LLVMIR/LLVMLowering.h [new file with mode: 0644]
mlir/lib/LLVMIR/Transforms/ConvertToLLVMDialect.cpp
mlir/lib/Linalg/Transforms/LowerToLLVMDialect.cpp

diff --git a/mlir/include/mlir/LLVMIR/LLVMLowering.h b/mlir/include/mlir/LLVMIR/LLVMLowering.h
new file mode 100644 (file)
index 0000000..590d973
--- /dev/null
@@ -0,0 +1,80 @@
+//===- LLVMLowering.h - Lowering to the LLVM IR dialect ---------*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// Provides a dialect conversion targeting the LLVM IR dialect.  By default, it
+// converts Standard ops and types and provides hooks for dialect-specific
+// extensions to the conversion.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_LLVMIR_LLVMLOWERING_H
+#define MLIR_LLVMIR_LLVMLOWERING_H
+
+#include "mlir/Transforms/DialectConversion.h"
+
+namespace llvm {
+class Module;
+}
+
+namespace mlir {
+namespace LLVM {
+class LLVMDialect;
+}
+
+/// Conversion from the Standard dialect to the LLVM IR dialect.  Provides hooks
+/// for derived classes to extend the conversion.
+class LLVMLowering : public DialectConversion {
+protected:
+  /// Create a set of converters that live in the pass object by passing them a
+  /// reference to the LLVM IR dialect.  Store the module associated with the
+  /// dialect for further type conversion.
+  llvm::DenseSet<DialectOpConversion *>
+  initConverters(MLIRContext *mlirContext) override final;
+
+  /// Derived classes can override this function to initialize custom converters
+  /// in addition to the existing converters from Standard operations.  It will
+  /// be called after the `module` and `llvmDialect` have been made available.
+  virtual llvm::DenseSet<DialectOpConversion *> initAdditionalConverters() {
+    return {};
+  };
+
+  /// Convert standard and builtin types to LLVM IR.
+  Type convertType(Type t) override final;
+
+  /// Derived classes can override this function to convert custom types.  It
+  /// will be called by convertType if the default conversion from standard and
+  /// builtin types fails.  Derived classes can thus call convertType whenever
+  /// they need type conversion that supports both default and custom types.
+  virtual Type convertAdditionalType(Type t) { return t; }
+
+  /// Convert function signatures to LLVM IR.  In particular, convert functions
+  /// with multiple results into functions returning LLVM IR's structure type.
+  /// Use `convertType` to convert individual argument and result types.
+  FunctionType convertFunctionSignatureType(
+      FunctionType t, ArrayRef<NamedAttributeList> argAttrs,
+      SmallVectorImpl<NamedAttributeList> &convertedArgAttrs) override final;
+
+  /// Storage for the conversion patterns.
+  llvm::BumpPtrAllocator converterStorage;
+  /// LLVM IR module used to parse/create types.
+  llvm::Module *module;
+  LLVM::LLVMDialect *llvmDialect;
+};
+
+} // namespace mlir
+
+#endif // MLIR_LLVMIR_LLVMLOWERING_H
index d103430..587bd6b 100644 (file)
@@ -25,6 +25,7 @@
 #include "mlir/IR/Module.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/LLVMIR/LLVMDialect.h"
+#include "mlir/LLVMIR/LLVMLowering.h"
 #include "mlir/LLVMIR/Transforms.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/StandardOps/Ops.h"
@@ -63,11 +64,13 @@ public:
 
   // Convert a function signature type to the LLVM IR dialect.  The outer
   // function type remains `mlir::FunctionType`.  Argument types are converted
-  // to LLVM IR as is.  If the function returns a single result, its type is
-  // converted.  Otherwise, the types of results are packed into an LLVM IR
-  // structure type.
-  static FunctionType convertFunctionSignature(FunctionType t,
-                                               llvm::Module &llvmModule);
+  // to LLVM IR using `typeConversionCallback` if provided and using
+  // `TypeConverter::convert` otherwise.  If the function returns a single
+  // result, its type is converted.  Otherwise, the types of results are packed
+  // into an LLVM IR structure type.
+  static FunctionType convertFunctionSignature(
+      FunctionType t, llvm::Module &llvmModule,
+      llvm::function_ref<Type(Type)> typeConversionCallback = {});
 
 private:
   // Construct a type converter.
@@ -83,7 +86,8 @@ private:
 
   // Convert function type arguments and results without converting the
   // function type itself.
-  FunctionType convertFunctionSignatureType(FunctionType type);
+  FunctionType convertFunctionSignatureType(
+      FunctionType type, llvm::function_ref<Type(Type)> typeConversionCallback);
 
   // Convert the index type.  Uses llvmModule data layout to create an integer
   // of the pointer bitwidth.
@@ -224,10 +228,14 @@ Type TypeConverter::convertFunctionType(FunctionType type) {
                   ->getPointerTo());
 }
 
-FunctionType TypeConverter::convertFunctionSignatureType(FunctionType type) {
+FunctionType TypeConverter::convertFunctionSignatureType(
+    FunctionType type, llvm::function_ref<Type(Type)> typeConversionCallback) {
+  if (!typeConversionCallback)
+    typeConversionCallback = [this](Type t) { return convertType(t); };
+
   SmallVector<Type, 8> argTypes;
   for (auto t : type.getInputs()) {
-    auto converted = convertType(t);
+    auto converted = typeConversionCallback(t);
     if (!converted)
       return {};
     argTypes.push_back(converted);
@@ -298,11 +306,6 @@ Type TypeConverter::convertType(Type type) {
   if (auto llvmType = type.dyn_cast<LLVM::LLVMType>())
     return llvmType;
 
-  std::string message;
-  llvm::raw_string_ostream os(message);
-  os << "unsupported type: ";
-  type.print(os);
-  mlirContext->emitError(UnknownLoc::get(mlirContext), os.str());
   return {};
 }
 
@@ -310,9 +313,11 @@ Type TypeConverter::convert(Type t, llvm::Module &module) {
   return TypeConverter(module, t.getContext()).convertType(t);
 }
 
-FunctionType TypeConverter::convertFunctionSignature(FunctionType t,
-                                                     llvm::Module &module) {
-  return TypeConverter(module, t.getContext()).convertFunctionSignatureType(t);
+FunctionType TypeConverter::convertFunctionSignature(
+    FunctionType t, llvm::Module &module,
+    llvm::function_ref<Type(Type)> typeConversionCallback) {
+  return TypeConverter(module, t.getContext())
+      .convertFunctionSignatureType(t, typeConversionCallback);
 }
 
 Type TypeConverter::getMemRefElementPtrType(MemRefType t,
@@ -1089,60 +1094,73 @@ void mlir::LLVM::ensureDistinctSuccessors(Module *m) {
   }
 };
 
-/// A dialect converter from the Standard dialect to the LLVM IR dialect.
-class LLVMLowering : public DialectConversion {
-protected:
-  // Create a set of converters that live in the pass object by passing them a
-  // reference to the LLVM IR dialect.  Store the module associated with the
-  // dialect for further type conversion.
-  llvm::DenseSet<DialectOpConversion *>
-  initConverters(MLIRContext *mlirContext) override {
-    converterStorage.Reset();
-    auto *llvmDialect = static_cast<LLVM::LLVMDialect *>(
-        mlirContext->getRegisteredDialect("llvm"));
-    if (!llvmDialect) {
-      mlirContext->emitError(UnknownLoc::get(mlirContext),
-                             "LLVM IR dialect is not registered");
-      return {};
-    }
-
-    module = &llvmDialect->getLLVMModule();
-
-    // FIXME: this should be tablegen'ed
-    return ConversionListBuilder<
-        AddFOpLowering, AddIOpLowering, AndOpLowering, AllocOpLowering,
-        BranchOpLowering, CallIndirectOpLowering, CallOpLowering,
-        CmpIOpLowering, CondBranchOpLowering, ConstLLVMOpLowering,
-        DeallocOpLowering, DimOpLowering, DivISOpLowering, DivIUOpLowering,
-        DivFOpLowering, LoadOpLowering, MemRefCastOpLowering, MulFOpLowering,
-        MulIOpLowering, OrOpLowering, RemISOpLowering, RemIUOpLowering,
-        RemFOpLowering, ReturnOpLowering, SelectOpLowering, StoreOpLowering,
-        SubFOpLowering, SubIOpLowering, XOrOpLowering>::build(&converterStorage,
-                                                              *llvmDialect);
+// Create a set of converters that live in the pass object by passing them a
+// reference to the LLVM IR dialect.  Store the module associated with the
+// dialect for further type conversion.
+llvm::DenseSet<DialectOpConversion *>
+LLVMLowering::initConverters(MLIRContext *mlirContext) {
+  converterStorage.Reset();
+  llvmDialect = static_cast<LLVM::LLVMDialect *>(
+      mlirContext->getRegisteredDialect("llvm"));
+  if (!llvmDialect) {
+    mlirContext->emitError(UnknownLoc::get(mlirContext),
+                           "LLVM IR dialect is not registered");
+    return {};
   }
 
-  // Convert types using the stored LLVM IR module.
-  Type convertType(Type t) override {
-    return TypeConverter::convert(t, *module);
-  }
+  module = &llvmDialect->getLLVMModule();
+
+  // FIXME: this should be tablegen'ed
+  auto converters = ConversionListBuilder<
+      AddFOpLowering, AddIOpLowering, AndOpLowering, AllocOpLowering,
+      BranchOpLowering, CallIndirectOpLowering, CallOpLowering, CmpIOpLowering,
+      CondBranchOpLowering, ConstLLVMOpLowering, DeallocOpLowering,
+      DimOpLowering, DivISOpLowering, DivIUOpLowering, DivFOpLowering,
+      LoadOpLowering, MemRefCastOpLowering, MulFOpLowering, MulIOpLowering,
+      OrOpLowering, RemISOpLowering, RemIUOpLowering, RemFOpLowering,
+      ReturnOpLowering, SelectOpLowering, StoreOpLowering, SubFOpLowering,
+      SubIOpLowering, XOrOpLowering>::build(&converterStorage, *llvmDialect);
+  auto extraConverters = initAdditionalConverters();
+  converters.insert(extraConverters.begin(), extraConverters.end());
+  return converters;
+}
 
-  // Convert function signatures using the stored LLVM IR module.
-  FunctionType convertFunctionSignatureType(
-      FunctionType t, ArrayRef<NamedAttributeList> argAttrs,
-      SmallVectorImpl<NamedAttributeList> &convertedArgAttrs) override {
+// Convert types using the stored LLVM IR module.
+Type LLVMLowering::convertType(Type t) {
+  if (auto result = TypeConverter::convert(t, *module))
+    return result;
+  if (auto result = convertAdditionalType(t))
+    return result;
 
-    convertedArgAttrs.reserve(argAttrs.size());
-    for (auto attr : argAttrs)
-      convertedArgAttrs.push_back(attr);
-    return TypeConverter::convertFunctionSignature(t, *module);
-  }
+  auto *mlirContext = llvmDialect->getContext();
+  std::string message;
+  llvm::raw_string_ostream os(message);
+  os << "unsupported type: ";
+  t.print(os);
+  mlirContext->emitError(UnknownLoc::get(mlirContext), os.str());
+  return {};
+}
 
-private:
-  // Storage for the conversion patterns.
-  llvm::BumpPtrAllocator converterStorage;
-  // LLVM IR module used to parse/create types.
-  llvm::Module *module;
+// Convert function signatures using the stored LLVM IR module.
+FunctionType LLVMLowering::convertFunctionSignatureType(
+    FunctionType t, ArrayRef<NamedAttributeList> argAttrs,
+    SmallVectorImpl<NamedAttributeList> &convertedArgAttrs) {
+
+  convertedArgAttrs.reserve(argAttrs.size());
+  for (auto attr : argAttrs)
+    convertedArgAttrs.push_back(attr);
+  return TypeConverter::convertFunctionSignature(
+      t, *module, [this](Type t) { return convertType(t); });
+}
+
+namespace {
+// Make sure LLVM conversion pass errors out on the unsupported types instead
+// of keeping them as is and resulting in a more cryptic verifier error.
+class LLVMStandardLowering : public LLVMLowering {
+protected:
+  Type convertAdditionalType(Type) override { return {}; }
 };
+} // namespace
 
 /// A pass converting MLIR Standard operations into the LLVM IR dialect.
 class LLVMLoweringPass : public ModulePass<LLVMLoweringPass> {
@@ -1156,7 +1174,7 @@ public:
   }
 
 private:
-  LLVMLowering impl;
+  LLVMStandardLowering impl;
 };
 
 ModulePassBase *mlir::createConvertToLLVMIRPass() {
index d399f6f..7463c71 100644 (file)
@@ -26,6 +26,7 @@
 #include "mlir/IR/StandardTypes.h"
 #include "mlir/IR/Types.h"
 #include "mlir/LLVMIR/LLVMDialect.h"
+#include "mlir/LLVMIR/LLVMLowering.h"
 #include "mlir/LLVMIR/Transforms.h"
 #include "mlir/Linalg/IR/LinalgOps.h"
 #include "mlir/Linalg/IR/LinalgTypes.h"
@@ -375,63 +376,23 @@ public:
   }
 };
 
-llvm::DenseSet<mlir::DialectOpConversion *>
-allocateDescriptorConverters(llvm::BumpPtrAllocator *allocator,
-                             mlir::MLIRContext *context) {
-  return ConversionListBuilder<BufferSizeOpConversion, DotOpConversion,
-                               RangeOpConversion, SliceOpConversion,
-                               ViewOpConversion>::build(allocator, context);
-}
-
 namespace {
 // The conversion class from Linalg to LLVMIR.
-class Lowering : public DialectConversion {
-public:
-  explicit Lowering(std::function<llvm::DenseSet<mlir::DialectOpConversion *>(
-                        llvm::BumpPtrAllocator *, mlir::MLIRContext *context)>
-                        conversions)
-      : setup(conversions) {}
-
-  Lowering &setLLVMModule(MLIRContext *context) {
-    llvmModule = getLLVMModule(context);
-    return *this;
-  }
-
+class Lowering : public LLVMLowering {
 protected:
-  // Initialize the list of converters.
-  llvm::DenseSet<DialectOpConversion *>
-  initConverters(MLIRContext *context) override {
-    converterStorage.Reset();
-    return setup(&converterStorage, context);
+  llvm::DenseSet<DialectOpConversion *> initAdditionalConverters() override {
+    return ConversionListBuilder<
+        BufferSizeOpConversion, DotOpConversion, RangeOpConversion,
+        SliceOpConversion, ViewOpConversion>::build(&converterStorage,
+                                                    llvmDialect->getContext());
   }
 
-  // This gets called for block and region arguments, and attributes.
-  Type convertType(Type t) override {
-    if (auto res = convertLinalgType(t, *llvmModule))
-      return res;
-    return convertToLLVMDialectType(t, *llvmModule);
+  Type convertAdditionalType(Type t) override {
+    return convertLinalgType(t, *module);
   }
-
-private:
-  // Storage for individual converters.
-  llvm::BumpPtrAllocator converterStorage;
-
-  // Conversion setup.
-  std::function<llvm::DenseSet<mlir::DialectOpConversion *>(
-      llvm::BumpPtrAllocator *, mlir::MLIRContext *context)>
-      setup;
-
-  llvm::Module *llvmModule;
 };
 } // end anonymous namespace
 
-std::unique_ptr<mlir::DialectConversion> makeLinalgToLLVMLowering(
-    std::function<llvm::DenseSet<mlir::DialectOpConversion *>(
-        llvm::BumpPtrAllocator *, mlir::MLIRContext *context)>
-        initer) {
-  return llvm::make_unique<Lowering>(initer);
-}
-
 namespace {
 struct LowerLinalgToLLVMPass : public ModulePass<LowerLinalgToLLVMPass> {
   void runOnModule();
@@ -441,18 +402,8 @@ struct LowerLinalgToLLVMPass : public ModulePass<LowerLinalgToLLVMPass> {
 void LowerLinalgToLLVMPass::runOnModule() {
   auto &module = getModule();
 
-  // Convert Linalg ops to the LLVM IR dialect using the converter defined
-  // above.
-  auto r = Lowering(allocateDescriptorConverters)
-               .setLLVMModule(module.getContext())
-               .convert(&module);
-  if (failed(r))
-    signalPassFailure();
-
-  // Convert the remaining standard MLIR operations to the LLVM IR dialect using
-  // the default converter.
-  auto converter = createStdToLLVMConverter();
-  r = converter->convert(&module);
+  // Convert to the LLVM IR dialect using the converter defined above.
+  auto r = Lowering().convert(&module);
   if (failed(r))
     signalPassFailure();
 }