Conversion to LLVM Dialect: integrate TypeConverter into LLVMLowering
authorAlex Zinenko <zinenko@google.com>
Thu, 9 May 2019 12:40:54 +0000 (05:40 -0700)
committerMehdi Amini <joker.eph@gmail.com>
Sat, 11 May 2019 02:25:36 +0000 (19:25 -0700)
    Historically, the conversion from standard and built-in types to the LLVM IR
    dialect types was performed by a dedicated class, TypeConverter.  This class
    served to contain references to the LLVM IR dialect and to the LLVM IR Module
    to allow querying the data layout.  Recently, the LLVMLowering class was
    introduced to make the conversion to the LLVM IR dialect extensible to other
    source dialects.  This class also includes the references to the LLVM IR
    dialect and module.  TypeConverter was extended with basic support for
    dialect-specific type conversion through callbacks.  This is not sufficient in
    cases where dialect-specific types appear inside other types, such as function
    or container types.

    Integrate TypeConverter into LLVMLowering.  Whenever a subtype needs to be
    converted during standard type conversion (e.g. an argument or a result of a
    FunctionType), the conversion will call to the virtual function
    `LLVMLowering::convertType`, which can be extended to support dialect-specific
    types.

    Provide a new LLVMOpConversion class that serves as a base class for all
    conversions to the LLVM IR dialect and gives them access to LLVMLowering for
    the purpose of type conversion.  Update Linalg to LLVM IR lowering to use this
    class.

--

PiperOrigin-RevId: 247407314

mlir/include/mlir/LLVMIR/LLVMLowering.h
mlir/include/mlir/LLVMIR/Transforms.h
mlir/lib/LLVMIR/Transforms/ConvertToLLVMDialect.cpp
mlir/lib/Linalg/Transforms/LowerToLLVMDialect.cpp

index 590d973..02bc816 100644 (file)
@@ -27,7 +27,9 @@
 #include "mlir/Transforms/DialectConversion.h"
 
 namespace llvm {
+class IntegerType;
 class Module;
+class Type;
 }
 
 namespace mlir {
@@ -38,6 +40,17 @@ class LLVMDialect;
 /// Conversion from the Standard dialect to the LLVM IR dialect.  Provides hooks
 /// for derived classes to extend the conversion.
 class LLVMLowering : public DialectConversion {
+public:
+  /// Convert types to LLVM IR.  This calls `convertAdditionalType` to convert
+  /// non-standard or non-builtin types.
+  Type convertType(Type t) override final;
+
+  /// Convert a non-empty list of types to be returned from a function into a
+  /// supported LLVM IR type.  In particular, if more than one values is
+  /// returned, create an LLVM IR structure type with elements that correspond
+  /// to each of the MLIR types converted with `convertType`.
+  Type packFunctionResults(ArrayRef<Type> types);
+
 protected:
   /// Create a set of converters that live in the pass object by passing them a
   /// reference to the LLVM IR dialect.  Store the module associated with the
@@ -52,9 +65,6 @@ protected:
     return {};
   };
 
-  /// Convert standard and builtin types to LLVM IR.
-  Type convertType(Type t) override final;
-
   /// Derived classes can override this function to convert custom types.  It
   /// will be called by convertType if the default conversion from standard and
   /// builtin types fails.  Derived classes can thus call convertType whenever
@@ -73,6 +83,62 @@ protected:
   /// LLVM IR module used to parse/create types.
   llvm::Module *module;
   LLVM::LLVMDialect *llvmDialect;
+
+private:
+  Type convertStandardType(Type type);
+
+  // Convert a function type.  The arguments and results are converted one by
+  // one.  Additionally, if the function returns more than one value, pack the
+  // results into an LLVM IR structure type so that the converted function type
+  // returns at most one result.
+  Type convertFunctionType(FunctionType type);
+
+  // Convert the index type.  Uses llvmModule data layout to create an integer
+  // of the pointer bitwidth.
+  Type convertIndexType(IndexType type);
+
+  // Convert an integer type `i*` to `!llvm<"i*">`.
+  Type convertIntegerType(IntegerType type);
+
+  // Convert a floating point type: `f16` to `!llvm.half`, `f32` to
+  // `!llvm.float` and `f64` to `!llvm.double`.  `bf16` is not supported
+  // by LLVM.
+  Type convertFloatType(FloatType type);
+
+  // Convert a memref type into an LLVM type that captures the relevant data.
+  // For statically-shaped memrefs, the resulting type is a pointer to the
+  // (converted) memref element type. For dynamically-shaped memrefs, the
+  // resulting type is an LLVM structure type that contains:
+  //   1. a pointer to the (converted) memref element type
+  //   2. as many index types as memref has dynamic dimensions.
+  Type convertMemRefType(MemRefType type);
+
+  // Convert a 1D vector type into an LLVM vector type.
+  Type convertVectorType(VectorType type);
+
+  // Get the LLVM representation of the index type based on the bitwidth of the
+  // pointer as defined by the data layout of the module.
+  llvm::IntegerType *getIndexType();
+
+  // Wrap the given LLVM IR type into an LLVM IR dialect type.
+  Type wrap(llvm::Type *llvmType);
+
+  // Extract an LLVM IR type from the LLVM IR dialect type.
+  llvm::Type *unwrap(Type type);
+};
+
+/// Base class for operation conversions targeting the LLVM IR dialect. Provides
+/// conversion patterns with an access to the containing LLVMLowering for the
+/// purpose of type conversions.
+class LLVMOpLowering : public DialectOpConversion {
+public:
+  LLVMOpLowering(StringRef rootOpName, MLIRContext *context,
+                 LLVMLowering &lowering);
+
+protected:
+  // Back-reference to the lowering class, used to call type and function
+  // conversions accounting for potential extensions.
+  LLVMLowering &lowering;
 };
 
 } // namespace mlir
index b021981..95244b8 100644 (file)
@@ -44,10 +44,6 @@ namespace LLVM {
 /// another block as a successor more than once with different values, insert
 /// a new dummy block for LLVM PHI nodes to tell the sources apart.
 void ensureDistinctSuccessors(Module *m);
-
-/// Converts a type in either MLIR standard or builtin type into LLVMIR dialect
-/// type.
-Type convertToLLVMDialectType(Type t, llvm::Module &llvmModule);
 } // namespace LLVM
 
 } // namespace mlir
index 5f57ebc..893a063 100644 (file)
 
 using namespace mlir;
 
-namespace {
-// Type converter for the LLVM IR dialect.  Converts MLIR standard and builtin
-// types into equivalent LLVM IR dialect types.
-class TypeConverter {
-public:
-  // Convert one type `t ` and register it in the `llvmModule`.  The latter may
-  // be used to extract information specific to the data layout.
-  // Dispatches to the private functions below based on the actual type.
-  static Type convert(Type t, llvm::Module &llvmModule);
-
-  // Convert the element type of the memref `t` to to an LLVM type, get a
-  // pointer LLVM type pointing to the converted `t`, wrap it into the MLIR LLVM
-  // dialect type and return.
-  static Type getMemRefElementPtrType(MemRefType t, llvm::Module &llvmModule);
-
-  // Convert a non-empty list of types to an LLVM IR dialect type wrapping an
-  // LLVM IR structure type, elements of which are formed by converting
-  // individual types in the given list.  Register the type in the `llvmModule`.
-  // The module may be also used to query the data layout.
-  static Type pack(ArrayRef<Type> types, llvm::Module &llvmModule,
-                   MLIRContext &context);
-
-  // Convert a function signature type to the LLVM IR dialect.  The outer
-  // function type remains `mlir::FunctionType`.  Argument types are converted
-  // to LLVM IR using `typeConversionCallback` if provided and using
-  // `TypeConverter::convert` otherwise.  If the function returns a single
-  // result, its type is converted.  Otherwise, the types of results are packed
-  // into an LLVM IR structure type.
-  static FunctionType convertFunctionSignature(
-      FunctionType t, llvm::Module &llvmModule,
-      llvm::function_ref<Type(Type)> typeConversionCallback = {});
-
-private:
-  // Construct a type converter.
-  explicit TypeConverter(llvm::Module &llvmModule, MLIRContext *context)
-      : module(llvmModule), llvmContext(llvmModule.getContext()),
-        builder(llvmModule.getContext()), mlirContext(context) {}
-
-  // Convert a function type.  The arguments and results are converted one by
-  // one.  Additionally, if the function returns more than one value, pack the
-  // results into an LLVM IR structure type so that the converted function type
-  // returns at most one result.
-  Type convertFunctionType(FunctionType type);
-
-  // Convert function type arguments and results without converting the
-  // function type itself.
-  FunctionType convertFunctionSignatureType(
-      FunctionType type, llvm::function_ref<Type(Type)> typeConversionCallback);
-
-  // Convert the index type.  Uses llvmModule data layout to create an integer
-  // of the pointer bitwidth.
-  Type convertIndexType(IndexType type);
-
-  // Convert an integer type `i*` to `!llvm<"i*">`.
-  Type convertIntegerType(IntegerType type);
-
-  // Convert a floating point type: `f16` to `!llvm.half`, `f32` to
-  // `!llvm.float` and `f64` to `!llvm.double`.  `bf16` is not supported
-  // by LLVM.
-  Type convertFloatType(FloatType type);
-
-  // Convert a memref type into an LLVM type that captures the relevant data.
-  // For statically-shaped memrefs, the resulting type is a pointer to the
-  // (converted) memref element type. For dynamically-shaped memrefs, the
-  // resulting type is an LLVM structure type that contains:
-  //   1. a pointer to the (converted) memref element type
-  //   2. as many index types as memref has dynamic dimensions.
-  Type convertMemRefType(MemRefType type);
-
-  // Convert a 1D vector type into an LLVM vector type.
-  Type convertVectorType(VectorType type);
-
-  // Convert a non-empty list of types into an LLVM structure type containing
-  // those types.  If the list contains a single element, convert the element
-  // directly.
-  Type getPackedResultType(ArrayRef<Type> types);
-
-  // Convert a type to the LLVM IR dialect.  Returns a null type in case of
-  // error.
-  Type convertType(Type type);
-
-  // Get the LLVM representation of the index type based on the bitwidth of the
-  // pointer as defined by the data layout of the module.
-  llvm::IntegerType *getIndexType();
-
-  // Wrap the given LLVM IR type into an LLVM IR dialect type.
-  Type wrap(llvm::Type *llvmType) {
-    return LLVM::LLVMType::get(mlirContext, llvmType);
-  }
-
-  // Extract an LLVM IR type from the LLVM IR dialect type.
-  llvm::Type *unwrap(Type type) {
-    if (!type)
-      return nullptr;
-    auto wrappedLLVMType = type.dyn_cast<LLVM::LLVMType>();
-    if (!wrappedLLVMType)
-      return mlirContext->emitError(UnknownLoc::get(mlirContext),
-                                    "conversion resulted in a non-LLVM type"),
-             nullptr;
-    return wrappedLLVMType.getUnderlyingType();
-  }
-
-  llvm::Module &module;
-  llvm::LLVMContext &llvmContext;
-  llvm::IRBuilder<> builder;
+// Wrap the given LLVM IR type into an LLVM IR dialect type.
+Type LLVMLowering::wrap(llvm::Type *llvmType) {
+  return LLVM::LLVMType::get(llvmDialect->getContext(), llvmType);
+}
 
-  MLIRContext *mlirContext;
-};
-} // end anonymous namespace
+// Extract an LLVM IR type from the LLVM IR dialect type.
+llvm::Type *LLVMLowering::unwrap(Type type) {
+  if (!type)
+    return nullptr;
+  auto *mlirContext = type.getContext();
+  auto wrappedLLVMType = type.dyn_cast<LLVM::LLVMType>();
+  if (!wrappedLLVMType)
+    return mlirContext->emitError(UnknownLoc::get(mlirContext),
+                                  "conversion resulted in a non-LLVM type"),
+           nullptr;
+  return wrappedLLVMType.getUnderlyingType();
+}
 
-llvm::IntegerType *TypeConverter::getIndexType() {
-  return builder.getIntNTy(module.getDataLayout().getPointerSizeInBits());
+llvm::IntegerType *LLVMLowering::getIndexType() {
+  return llvm::IntegerType::get(llvmDialect->getLLVMContext(),
+                                module->getDataLayout().getPointerSizeInBits());
 }
 
-Type TypeConverter::convertIndexType(IndexType type) {
+Type LLVMLowering::convertIndexType(IndexType type) {
   return wrap(getIndexType());
 }
 
-Type TypeConverter::convertIntegerType(IntegerType type) {
-  return wrap(builder.getIntNTy(type.getWidth()));
+Type LLVMLowering::convertIntegerType(IntegerType type) {
+  return wrap(
+      llvm::Type::getIntNTy(llvmDialect->getLLVMContext(), type.getWidth()));
 }
 
-Type TypeConverter::convertFloatType(FloatType type) {
+Type LLVMLowering::convertFloatType(FloatType type) {
   switch (type.getKind()) {
   case mlir::StandardTypes::F32:
-    return wrap(builder.getFloatTy());
+    return wrap(llvm::Type::getFloatTy(llvmDialect->getLLVMContext()));
   case mlir::StandardTypes::F64:
-    return wrap(builder.getDoubleTy());
+    return wrap(llvm::Type::getDoubleTy(llvmDialect->getLLVMContext()));
   case mlir::StandardTypes::F16:
-    return wrap(builder.getHalfTy());
-  case mlir::StandardTypes::BF16:
+    return wrap(llvm::Type::getHalfTy(llvmDialect->getLLVMContext()));
+  case mlir::StandardTypes::BF16: {
+    auto *mlirContext = llvmDialect->getContext();
     return mlirContext->emitError(UnknownLoc::get(mlirContext),
                                   "unsupported type: BF16"),
            Type();
+  }
   default:
     llvm_unreachable("non-float type in convertFloatType");
   }
 }
 
-// If `types` has more than one type, pack them into an LLVM StructType,
-// otherwise just convert the type.
-Type TypeConverter::getPackedResultType(ArrayRef<Type> types) {
-  // We don't convert zero-valued functions to one-valued functions returning
-  // void yet.
-  assert(!types.empty() && "empty type list");
-
-  // Convert result types one by one and check for errors.
-  SmallVector<llvm::Type *, 8> resultTypes;
-  for (auto t : types) {
-    llvm::Type *converted = unwrap(convertType(t));
-    if (!converted)
-      return {};
-    resultTypes.push_back(converted);
-  }
-
-  // LLVM does not support tuple returns.  If there are more than 2 results,
-  // pack them into an LLVM struct type.
-  if (resultTypes.size() == 1)
-    return wrap(resultTypes.front());
-  return wrap(llvm::StructType::get(llvmContext, resultTypes));
-}
-
 // Function types are converted to LLVM Function types by recursively converting
 // argument and result types.  If MLIR Function has zero results, the LLVM
 // Function has one VoidType result.  If MLIR Function has more than one result,
 // they are into an LLVM StructType in their order of appearance.
-Type TypeConverter::convertFunctionType(FunctionType type) {
+Type LLVMLowering::convertFunctionType(FunctionType type) {
   // Convert argument types one by one and check for errors.
   SmallVector<llvm::Type *, 8> argTypes;
   for (auto t : type.getInputs()) {
@@ -219,45 +108,22 @@ Type TypeConverter::convertFunctionType(FunctionType type) {
   // If function does not return anything, create the void result type,
   // if it returns on element, convert it, otherwise pack the result types into
   // a struct.
-  llvm::Type *resultType = type.getNumResults() == 0
-                               ? llvm::Type::getVoidTy(llvmContext)
-                               : unwrap(getPackedResultType(type.getResults()));
+  llvm::Type *resultType =
+      type.getNumResults() == 0
+          ? llvm::Type::getVoidTy(llvmDialect->getLLVMContext())
+          : unwrap(packFunctionResults(type.getResults()));
   if (!resultType)
     return {};
   return wrap(llvm::FunctionType::get(resultType, argTypes, /*isVarArg=*/false)
                   ->getPointerTo());
 }
 
-FunctionType TypeConverter::convertFunctionSignatureType(
-    FunctionType type, llvm::function_ref<Type(Type)> typeConversionCallback) {
-  if (!typeConversionCallback)
-    typeConversionCallback = [this](Type t) { return convertType(t); };
-
-  SmallVector<Type, 8> argTypes;
-  for (auto t : type.getInputs()) {
-    auto converted = typeConversionCallback(t);
-    if (!converted)
-      return {};
-    argTypes.push_back(converted);
-  }
-
-  // If function does not return anything, return immediately.
-  if (type.getNumResults() == 0)
-    return FunctionType::get(argTypes, {}, mlirContext);
-
-  // Otherwise pack the result types into a struct.
-  if (auto result = getPackedResultType(type.getResults()))
-    return FunctionType::get(argTypes, {result}, mlirContext);
-
-  return {};
-}
-
 // Convert a MemRef to an LLVM type. If the memref is statically-shaped, then
 // we return a pointer to the converted element type. Otherwise we return an
 // LLVM stucture type, where the first element of the structure type is a
 // pointer to the elemental type of the MemRef and the following N elements are
 // values of the Index type, one for each of N dynamic dimensions of the MemRef.
-Type TypeConverter::convertMemRefType(MemRefType type) {
+Type LLVMLowering::convertMemRefType(MemRefType type) {
   llvm::Type *elementType = unwrap(convertType(type.getElementType()));
   if (!elementType)
     return {};
@@ -272,12 +138,13 @@ Type TypeConverter::convertMemRefType(MemRefType type) {
   SmallVector<llvm::Type *, 8> types(numDynamicSizes + 1, getIndexType());
   types.front() = ptrType;
 
-  return wrap(llvm::StructType::get(llvmContext, types));
+  return wrap(llvm::StructType::get(llvmDialect->getLLVMContext(), types));
 }
 
 // Convert a 1D vector type to an LLVM vector type.
-Type TypeConverter::convertVectorType(VectorType type) {
+Type LLVMLowering::convertVectorType(VectorType type) {
   if (type.getRank() != 1) {
+    auto *mlirContext = llvmDialect->getContext();
     mlirContext->emitError(UnknownLoc::get(mlirContext),
                            "only 1D vectors are supported");
     return {};
@@ -290,7 +157,7 @@ Type TypeConverter::convertVectorType(VectorType type) {
 }
 
 // Dispatch based on the actual type.  Return null type on error.
-Type TypeConverter::convertType(Type type) {
+Type LLVMLowering::convertStandardType(Type type) {
   if (auto funcType = type.dyn_cast<FunctionType>())
     return convertFunctionType(funcType);
   if (auto intType = type.dyn_cast<IntegerType>())
@@ -309,44 +176,36 @@ Type TypeConverter::convertType(Type type) {
   return {};
 }
 
-Type TypeConverter::convert(Type t, llvm::Module &module) {
-  return TypeConverter(module, t.getContext()).convertType(t);
-}
-
-FunctionType TypeConverter::convertFunctionSignature(
-    FunctionType t, llvm::Module &module,
-    llvm::function_ref<Type(Type)> typeConversionCallback) {
-  return TypeConverter(module, t.getContext())
-      .convertFunctionSignatureType(t, typeConversionCallback);
-}
-
-Type TypeConverter::getMemRefElementPtrType(MemRefType t,
-                                            llvm::Module &module) {
+// Convert the element type of the memref `t` to to an LLVM type using
+// `lowering`, get a pointer LLVM type pointing to the converted `t`, wrap it
+// into the MLIR LLVM dialect type and return.
+static Type getMemRefElementPtrType(MemRefType t, LLVMLowering &lowering) {
   auto elementType = t.getElementType();
-  auto converted = convert(elementType, module);
+  auto converted = lowering.convertType(elementType);
   if (!converted)
     return {};
   llvm::Type *llvmType = converted.cast<LLVM::LLVMType>().getUnderlyingType();
   return LLVM::LLVMType::get(t.getContext(), llvmType->getPointerTo());
 }
 
-Type TypeConverter::pack(ArrayRef<Type> types, llvm::Module &module,
-                         MLIRContext &mlirContext) {
-  return TypeConverter(module, &mlirContext).getPackedResultType(types);
-}
+LLVMOpLowering::LLVMOpLowering(StringRef rootOpName, MLIRContext *context,
+                               LLVMLowering &lowering_)
+    : DialectOpConversion(rootOpName, /*benefit=*/1, context),
+      lowering(lowering_) {}
 
 namespace {
 // Base class for Standard to LLVM IR op conversions.  Matches the Op type
 // provided as template argument.  Carries a reference to the LLVM dialect in
 // case it is necessary for rewriters.
 template <typename SourceOp>
-class LLVMLegalizationPattern : public DialectOpConversion {
+class LLVMLegalizationPattern : public LLVMOpLowering {
 public:
   // Construct a conversion pattern.
-  explicit LLVMLegalizationPattern(LLVM::LLVMDialect &dialect)
-      : DialectOpConversion(SourceOp::getOperationName(), 1,
-                            dialect.getContext()),
-        dialect(dialect) {}
+  explicit LLVMLegalizationPattern(LLVM::LLVMDialect &dialect_,
+                                   LLVMLowering &lowering_)
+      : LLVMOpLowering(SourceOp::getOperationName(), dialect_.getContext(),
+                       lowering_),
+        dialect(dialect_) {}
 
   // Match by type.
   PatternMatchResult match(Operation *op) const override {
@@ -436,13 +295,11 @@ struct OneToOneLLVMOpLowering : public LLVMLegalizationPattern<SourceOp> {
   SmallVector<Value *, 4> rewrite(Operation *op, ArrayRef<Value *> operands,
                                   FuncBuilder &rewriter) const override {
     unsigned numResults = op->getNumResults();
-    auto *mlirContext = op->getContext();
 
     Type packedType;
     if (numResults != 0) {
       packedType =
-          TypeConverter::pack(getTypes(op->getResults()),
-                              this->dialect.getLLVMModule(), *mlirContext);
+          this->lowering.packFunctionResults(getTypes(op->getResults()));
       assert(packedType && "type conversion failed, such operation should not "
                            "have been matched");
     }
@@ -461,8 +318,7 @@ struct OneToOneLLVMOpLowering : public LLVMLegalizationPattern<SourceOp> {
     SmallVector<Value *, 4> results;
     results.reserve(numResults);
     for (unsigned i = 0; i < numResults; ++i) {
-      auto type = TypeConverter::convert(op->getResult(i)->getType(),
-                                         this->dialect.getLLVMModule());
+      auto type = this->lowering.convertType(op->getResult(i)->getType());
       results.push_back(rewriter.create<LLVM::ExtractValueOp>(
           op->getLoc(), type, newOp.getOperation()->getResult(0),
           this->getIntegerArrayAttr(rewriter, i)));
@@ -623,7 +479,7 @@ struct AllocOpLowering : public LLVMLegalizationPattern<AllocOp> {
                                   rewriter.getFunctionAttr(mallocFunc),
                                   cumulativeSize)
             .getResult(0);
-    auto structElementType = TypeConverter::convert(elementType, getModule());
+    auto structElementType = lowering.convertType(elementType);
     auto elementPtrType = LLVM::LLVMType::get(
         op->getContext(), structElementType.cast<LLVM::LLVMType>()
                               .getUnderlyingType()
@@ -637,7 +493,7 @@ struct AllocOpLowering : public LLVMLegalizationPattern<AllocOp> {
     }
 
     // Create the MemRef descriptor.
-    auto structType = TypeConverter::convert(type, getModule());
+    auto structType = lowering.convertType(type);
     Value *memRefDescriptor = rewriter.create<LLVM::UndefOp>(
         op->getLoc(), structType, ArrayRef<Value *>{});
 
@@ -718,8 +574,7 @@ struct MemRefCastOpLowering : public LLVMLegalizationPattern<MemRefCastOp> {
     auto sourceType = memRefCastOp.getOperand()->getType().cast<MemRefType>();
 
     // Copy the data buffer pointer.
-    auto elementTypePtr =
-        TypeConverter::getMemRefElementPtrType(targetType, getModule());
+    auto elementTypePtr = getMemRefElementPtrType(targetType, lowering);
     Value *buffer =
         extractMemRefElementPtr(rewriter, op->getLoc(), operands[0],
                                 elementTypePtr, sourceType.hasStaticShape());
@@ -729,7 +584,7 @@ struct MemRefCastOpLowering : public LLVMLegalizationPattern<MemRefCastOp> {
     }
 
     // Create the new MemRef descriptor.
-    auto structType = TypeConverter::convert(targetType, getModule());
+    auto structType = lowering.convertType(targetType);
     Value *newDescriptor = rewriter.create<LLVM::UndefOp>(
         op->getLoc(), structType, ArrayRef<Value *>{});
     // Otherwise target type is dynamic memref, so create a proper descriptor.
@@ -921,7 +776,7 @@ struct LoadStoreOpLowering : public LLVMLegalizationPattern<Derived> {
   Value *getDataPtr(Location loc, MemRefType type, Value *dataPtr,
                     ArrayRef<Value *> indices, FuncBuilder &rewriter,
                     llvm::Module &module) const {
-    auto ptrType = TypeConverter::getMemRefElementPtrType(type, module);
+    auto ptrType = getMemRefElementPtrType(type, this->lowering);
     auto shape = type.getShape();
     if (type.hasStaticShape()) {
       // NB: If memref was statically-shaped, dataPtr is pointer to raw data.
@@ -944,8 +799,7 @@ struct LoadOpLowering : public LoadStoreOpLowering<LoadOp> {
 
     Value *dataPtr = getDataPtr(op->getLoc(), type, operands.front(),
                                 operands.drop_front(), rewriter, getModule());
-    auto elementType =
-        TypeConverter::convert(type.getElementType(), getModule());
+    auto elementType = lowering.convertType(type.getElementType());
 
     SmallVector<Value *, 4> results;
     results.push_back(rewriter.create<LLVM::LoadOp>(
@@ -1018,9 +872,7 @@ struct ReturnOpLowering : public LLVMLegalizationPattern<ReturnOp> {
 
     // Otherwise, we need to pack the arguments into an LLVM struct type before
     // returning.
-    auto *mlirContext = op->getContext();
-    auto packedType = TypeConverter::pack(
-        getTypes(op->getOperands()), dialect.getLLVMModule(), *mlirContext);
+    auto packedType = lowering.packFunctionResults(getTypes(op->getOperands()));
 
     Value *packed = rewriter.create<LLVM::UndefOp>(op->getLoc(), packedType);
     for (unsigned i = 0; i < numArguments; ++i) {
@@ -1119,7 +971,8 @@ LLVMLowering::initConverters(MLIRContext *mlirContext) {
       LoadOpLowering, MemRefCastOpLowering, MulFOpLowering, MulIOpLowering,
       OrOpLowering, RemISOpLowering, RemIUOpLowering, RemFOpLowering,
       ReturnOpLowering, SelectOpLowering, StoreOpLowering, SubFOpLowering,
-      SubIOpLowering, XOrOpLowering>::build(&converterStorage, *llvmDialect);
+      SubIOpLowering, XOrOpLowering>::build(&converterStorage, *llvmDialect,
+                                            *this);
   auto extraConverters = initAdditionalConverters();
   converters.insert(extraConverters.begin(), extraConverters.end());
   return converters;
@@ -1127,7 +980,7 @@ LLVMLowering::initConverters(MLIRContext *mlirContext) {
 
 // Convert types using the stored LLVM IR module.
 Type LLVMLowering::convertType(Type t) {
-  if (auto result = TypeConverter::convert(t, *module))
+  if (auto result = convertStandardType(t))
     return result;
   if (auto result = convertAdditionalType(t))
     return result;
@@ -1138,16 +991,57 @@ Type LLVMLowering::convertType(Type t) {
   return {};
 }
 
+static llvm::Type *unwrapType(Type type) {
+  return type.cast<LLVM::LLVMType>().getUnderlyingType();
+}
+
+// Create an LLVM IR structure type if there is more than one result.
+Type LLVMLowering::packFunctionResults(ArrayRef<Type> types) {
+  assert(!types.empty() && "expected non-empty list of type");
+
+  if (types.size() == 1)
+    return convertType(types.front());
+
+  SmallVector<llvm::Type *, 8> resultTypes;
+  resultTypes.reserve(types.size());
+  for (auto t : types) {
+    Type converted = convertType(t);
+    if (!converted)
+      return {};
+    resultTypes.push_back(unwrapType(converted));
+  }
+
+  return LLVM::LLVMType::get(
+      llvmDialect->getContext(),
+      llvm::StructType::get(llvmDialect->getLLVMContext(), resultTypes));
+}
+
 // Convert function signatures using the stored LLVM IR module.
 FunctionType LLVMLowering::convertFunctionSignatureType(
-    FunctionType t, ArrayRef<NamedAttributeList> argAttrs,
+    FunctionType type, ArrayRef<NamedAttributeList> argAttrs,
     SmallVectorImpl<NamedAttributeList> &convertedArgAttrs) {
 
   convertedArgAttrs.reserve(argAttrs.size());
   for (auto attr : argAttrs)
     convertedArgAttrs.push_back(attr);
-  return TypeConverter::convertFunctionSignature(
-      t, *module, [this](Type t) { return convertType(t); });
+
+  SmallVector<Type, 8> argTypes;
+  for (auto t : type.getInputs()) {
+    auto converted = convertType(t);
+    if (!converted)
+      return {};
+    argTypes.push_back(converted);
+  }
+
+  // If function does not return anything, return immediately.
+  if (type.getNumResults() == 0)
+    return FunctionType::get(argTypes, {}, type.getContext());
+
+  // Otherwise pack the result types into a struct.
+  if (auto result = packFunctionResults(type.getResults()))
+    return FunctionType::get(argTypes, result, result.getContext());
+
+  return {};
 }
 
 namespace {
@@ -1184,7 +1078,3 @@ std::unique_ptr<DialectConversion> mlir::createStdToLLVMConverter() {
 
 static PassRegistration<LLVMLoweringPass>
     pass("lower-to-llvm", "Convert all functions to the LLVM IR dialect");
-
-Type mlir::LLVM::convertToLLVMDialectType(Type t, llvm::Module &llvmModule) {
-  return TypeConverter::convert(t, llvmModule);
-}
index 7463c71..108da6c 100644 (file)
@@ -54,21 +54,10 @@ using add = ValueBuilder<mlir::LLVM::AddOp>;
 using sub = ValueBuilder<mlir::LLVM::SubOp>;
 using mul = ValueBuilder<mlir::LLVM::MulOp>;
 
-static llvm::Module *getLLVMModule(MLIRContext *context) {
-  auto *llvmDialect =
-      static_cast<LLVM::LLVMDialect *>(context->getRegisteredDialect("llvm"));
-  if (!llvmDialect) {
-    context->emitError(UnknownLoc::get(context),
-                       "LLVM IR dialect is not registered");
-    return nullptr;
-  }
-  return &llvmDialect->getLLVMModule();
-}
-
 template <typename T>
 static llvm::Type *getPtrToElementType(T containerType,
-                                       llvm::Module &llvmModule) {
-  return convertToLLVMDialectType(containerType.getElementType(), llvmModule)
+                                       LLVMLowering &lowering) {
+  return lowering.convertType(containerType.getElementType())
       .template cast<LLVMType>()
       .getUnderlyingType()
       ->getPointerTo();
@@ -82,9 +71,11 @@ static llvm::Type *getPtrToElementType(T containerType,
 //   - an F32 type is converted into an LLVM float type
 //   - a Buffer, Range or View is converted into an LLVM structure type
 //     containing the respective dynamic values.
-static Type convertLinalgType(Type t, llvm::Module &llvmModule) {
+static Type convertLinalgType(Type t, LLVMLowering &lowering) {
   auto *context = t.getContext();
-  auto *int64Ty = llvm::Type::getInt64Ty(llvmModule.getContext());
+  auto *int64Ty = lowering.convertType(IntegerType::get(64, context))
+                      .cast<LLVM::LLVMType>()
+                      .getUnderlyingType();
 
   // A buffer descriptor contains the pointer to a flat region of storage and
   // the size of the region.
@@ -95,7 +86,7 @@ static Type convertLinalgType(Type t, llvm::Module &llvmModule) {
   //   int64_t size;
   // };
   if (auto bufferTy = t.dyn_cast<BufferType>()) {
-    auto *ptrTy = getPtrToElementType(bufferTy, llvmModule);
+    auto *ptrTy = getPtrToElementType(bufferTy, lowering);
     auto *structTy = llvm::StructType::get(ptrTy, int64Ty);
     return LLVMType::get(context, structTy);
   }
@@ -136,7 +127,7 @@ static Type convertLinalgType(Type t, llvm::Module &llvmModule) {
   //   int64_t strides[Rank];
   // };
   if (auto viewTy = t.dyn_cast<ViewType>()) {
-    auto *ptrTy = getPtrToElementType(viewTy, llvmModule);
+    auto *ptrTy = getPtrToElementType(viewTy, lowering);
     auto *arrayTy = llvm::ArrayType::get(int64Ty, viewTy.getRank());
     auto *structTy = llvm::StructType::get(ptrTy, int64Ty, arrayTy, arrayTy);
     return LLVMType::get(context, structTy);
@@ -157,36 +148,31 @@ static ArrayAttr makePositionAttr(FuncBuilder &builder,
 }
 
 // BufferSizeOp creates a new `index` value.
-class BufferSizeOpConversion : public DialectOpConversion {
+class BufferSizeOpConversion : public LLVMOpLowering {
 public:
-  explicit BufferSizeOpConversion(MLIRContext *context)
-      : DialectOpConversion(BufferSizeOp::getOperationName(), 1, context),
-        llvmModule(*getLLVMModule(context)) {}
+  BufferSizeOpConversion(MLIRContext *context, LLVMLowering &lowering_)
+      : LLVMOpLowering(BufferSizeOp::getOperationName(), context, lowering_) {}
 
   SmallVector<Value *, 4> rewrite(Operation *op, ArrayRef<Value *> operands,
                                   FuncBuilder &rewriter) const override {
-    auto bufferSizeType =
-        convertToLLVMDialectType(operands[0]->getType(), llvmModule);
+    auto bufferSizeType = lowering.convertType(operands[0]->getType());
     edsc::ScopedContext context(rewriter, op->getLoc());
     return {extractvalue(bufferSizeType, operands[0],
                          makePositionAttr(rewriter, 1))};
   }
-
-  llvm::Module &llvmModule;
 };
 
 // RangeOp creates a new range descriptor.
-class RangeOpConversion : public DialectOpConversion {
+class RangeOpConversion : public LLVMOpLowering {
 public:
-  explicit RangeOpConversion(MLIRContext *context)
-      : DialectOpConversion(RangeOp::getOperationName(), 1, context),
-        llvmModule(*getLLVMModule(context)) {}
+  explicit RangeOpConversion(MLIRContext *context, LLVMLowering &lowering_)
+      : LLVMOpLowering(RangeOp::getOperationName(), context, lowering_) {}
 
   SmallVector<Value *, 4> rewrite(Operation *op, ArrayRef<Value *> operands,
                                   FuncBuilder &rewriter) const override {
     auto rangeOp = op->cast<RangeOp>();
     auto rangeDescriptorType =
-        convertLinalgType(rangeOp.getResult()->getType(), llvmModule);
+        convertLinalgType(rangeOp.getResult()->getType(), lowering);
 
     edsc::ScopedContext context(rewriter, op->getLoc());
 
@@ -201,24 +187,20 @@ public:
 
     return {desc};
   }
-
-  llvm::Module &llvmModule;
 };
 
-class SliceOpConversion : public DialectOpConversion {
+class SliceOpConversion : public LLVMOpLowering {
 public:
-  explicit SliceOpConversion(MLIRContext *context)
-      : DialectOpConversion(SliceOp::getOperationName(), 1, context),
-        llvmModule(*getLLVMModule(context)) {}
+  explicit SliceOpConversion(MLIRContext *context, LLVMLowering &lowering_)
+      : LLVMOpLowering(SliceOp::getOperationName(), context, lowering_) {}
 
   SmallVector<Value *, 4> rewrite(Operation *op, ArrayRef<Value *> operands,
                                   FuncBuilder &rewriter) const override {
     auto sliceOp = op->cast<SliceOp>();
     auto viewDescriptorType =
-        convertLinalgType(sliceOp.getViewType(), llvmModule);
+        convertLinalgType(sliceOp.getViewType(), lowering);
     auto viewType = sliceOp.getBaseViewType();
-    auto int64Ty =
-        convertToLLVMDialectType(rewriter.getIntegerType(64), llvmModule);
+    auto int64Ty = lowering.convertType(rewriter.getIntegerType(64));
 
     // Helper function to create an integer array attribute out of a list of
     // values.
@@ -229,7 +211,7 @@ public:
     auto getViewPtr = [pos, &rewriter, this](ViewType type,
                                              Value *view) -> Value * {
       auto elementPtrTy =
-          rewriter.getType<LLVMType>(getPtrToElementType(type, llvmModule));
+          rewriter.getType<LLVMType>(getPtrToElementType(type, lowering));
       return extractvalue(elementPtrTy, view, pos(0));
     };
 
@@ -288,25 +270,20 @@ public:
 
     return {desc};
   }
-
-  llvm::Module &llvmModule;
 };
 
-class ViewOpConversion : public DialectOpConversion {
+class ViewOpConversion : public LLVMOpLowering {
 public:
-  explicit ViewOpConversion(MLIRContext *context)
-      : DialectOpConversion(ViewOp::getOperationName(), 1, context),
-        llvmModule(*getLLVMModule(context)) {}
+  explicit ViewOpConversion(MLIRContext *context, LLVMLowering &lowering_)
+      : LLVMOpLowering(ViewOp::getOperationName(), context, lowering_) {}
 
   SmallVector<Value *, 4> rewrite(Operation *op, ArrayRef<Value *> operands,
                                   FuncBuilder &rewriter) const override {
     auto viewOp = op->cast<ViewOp>();
-    auto viewDescriptorType =
-        convertLinalgType(viewOp.getViewType(), llvmModule);
+    auto viewDescriptorType = convertLinalgType(viewOp.getViewType(), lowering);
     auto elementType = rewriter.getType<LLVMType>(
-        getPtrToElementType(viewOp.getViewType(), llvmModule));
-    auto int64Ty =
-        convertToLLVMDialectType(rewriter.getIntegerType(64), llvmModule);
+        getPtrToElementType(viewOp.getViewType(), lowering));
+    auto int64Ty = lowering.convertType(rewriter.getIntegerType(64));
 
     auto pos = [&rewriter](ArrayRef<int> values) {
       return makePositionAttr(rewriter, values);
@@ -350,15 +327,13 @@ public:
 
     return {desc};
   }
-
-  llvm::Module &llvmModule;
 };
 
 // DotOp creates a new range descriptor.
-class DotOpConversion : public DialectOpConversion {
+class DotOpConversion : public LLVMOpLowering {
 public:
-  explicit DotOpConversion(MLIRContext *context)
-      : DialectOpConversion(DotOp::getOperationName(), 1, context) {}
+  explicit DotOpConversion(MLIRContext *context, LLVMLowering &lowering_)
+      : LLVMOpLowering(DotOp::getOperationName(), context, lowering_) {}
 
   static StringRef libraryFunctionName() { return "linalg_dot"; }
 
@@ -384,11 +359,12 @@ protected:
     return ConversionListBuilder<
         BufferSizeOpConversion, DotOpConversion, RangeOpConversion,
         SliceOpConversion, ViewOpConversion>::build(&converterStorage,
-                                                    llvmDialect->getContext());
+                                                    llvmDialect->getContext(),
+                                                    *this);
   }
 
   Type convertAdditionalType(Type t) override {
-    return convertLinalgType(t, *module);
+    return convertLinalgType(t, *this);
   }
 };
 } // end anonymous namespace