--- /dev/null
+//===- LLVMLowering.h - Lowering to the LLVM IR dialect ---------*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// Provides a dialect conversion targeting the LLVM IR dialect. By default, it
+// converts Standard ops and types and provides hooks for dialect-specific
+// extensions to the conversion.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_LLVMIR_LLVMLOWERING_H
+#define MLIR_LLVMIR_LLVMLOWERING_H
+
+#include "mlir/Transforms/DialectConversion.h"
+
+namespace llvm {
+class Module;
+}
+
+namespace mlir {
+namespace LLVM {
+class LLVMDialect;
+}
+
+/// Conversion from the Standard dialect to the LLVM IR dialect. Provides hooks
+/// for derived classes to extend the conversion.
+class LLVMLowering : public DialectConversion {
+protected:
+ /// Create a set of converters that live in the pass object by passing them a
+ /// reference to the LLVM IR dialect. Store the module associated with the
+ /// dialect for further type conversion.
+ llvm::DenseSet<DialectOpConversion *>
+ initConverters(MLIRContext *mlirContext) override final;
+
+ /// Derived classes can override this function to initialize custom converters
+ /// in addition to the existing converters from Standard operations. It will
+ /// be called after the `module` and `llvmDialect` have been made available.
+ virtual llvm::DenseSet<DialectOpConversion *> initAdditionalConverters() {
+ return {};
+ };
+
+ /// Convert standard and builtin types to LLVM IR.
+ Type convertType(Type t) override final;
+
+ /// Derived classes can override this function to convert custom types. It
+ /// will be called by convertType if the default conversion from standard and
+ /// builtin types fails. Derived classes can thus call convertType whenever
+ /// they need type conversion that supports both default and custom types.
+ virtual Type convertAdditionalType(Type t) { return t; }
+
+ /// Convert function signatures to LLVM IR. In particular, convert functions
+ /// with multiple results into functions returning LLVM IR's structure type.
+ /// Use `convertType` to convert individual argument and result types.
+ FunctionType convertFunctionSignatureType(
+ FunctionType t, ArrayRef<NamedAttributeList> argAttrs,
+ SmallVectorImpl<NamedAttributeList> &convertedArgAttrs) override final;
+
+ /// Storage for the conversion patterns.
+ llvm::BumpPtrAllocator converterStorage;
+ /// LLVM IR module used to parse/create types.
+ llvm::Module *module;
+ LLVM::LLVMDialect *llvmDialect;
+};
+
+} // namespace mlir
+
+#endif // MLIR_LLVMIR_LLVMLOWERING_H
#include "mlir/IR/Module.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/LLVMIR/LLVMDialect.h"
+#include "mlir/LLVMIR/LLVMLowering.h"
#include "mlir/LLVMIR/Transforms.h"
#include "mlir/Pass/Pass.h"
#include "mlir/StandardOps/Ops.h"
// Convert a function signature type to the LLVM IR dialect. The outer
// function type remains `mlir::FunctionType`. Argument types are converted
- // to LLVM IR as is. If the function returns a single result, its type is
- // converted. Otherwise, the types of results are packed into an LLVM IR
- // structure type.
- static FunctionType convertFunctionSignature(FunctionType t,
- llvm::Module &llvmModule);
+ // to LLVM IR using `typeConversionCallback` if provided and using
+ // `TypeConverter::convert` otherwise. If the function returns a single
+ // result, its type is converted. Otherwise, the types of results are packed
+ // into an LLVM IR structure type.
+ static FunctionType convertFunctionSignature(
+ FunctionType t, llvm::Module &llvmModule,
+ llvm::function_ref<Type(Type)> typeConversionCallback = {});
private:
// Construct a type converter.
// Convert function type arguments and results without converting the
// function type itself.
- FunctionType convertFunctionSignatureType(FunctionType type);
+ FunctionType convertFunctionSignatureType(
+ FunctionType type, llvm::function_ref<Type(Type)> typeConversionCallback);
// Convert the index type. Uses llvmModule data layout to create an integer
// of the pointer bitwidth.
->getPointerTo());
}
-FunctionType TypeConverter::convertFunctionSignatureType(FunctionType type) {
+FunctionType TypeConverter::convertFunctionSignatureType(
+ FunctionType type, llvm::function_ref<Type(Type)> typeConversionCallback) {
+ if (!typeConversionCallback)
+ typeConversionCallback = [this](Type t) { return convertType(t); };
+
SmallVector<Type, 8> argTypes;
for (auto t : type.getInputs()) {
- auto converted = convertType(t);
+ auto converted = typeConversionCallback(t);
if (!converted)
return {};
argTypes.push_back(converted);
if (auto llvmType = type.dyn_cast<LLVM::LLVMType>())
return llvmType;
- std::string message;
- llvm::raw_string_ostream os(message);
- os << "unsupported type: ";
- type.print(os);
- mlirContext->emitError(UnknownLoc::get(mlirContext), os.str());
return {};
}
return TypeConverter(module, t.getContext()).convertType(t);
}
-FunctionType TypeConverter::convertFunctionSignature(FunctionType t,
- llvm::Module &module) {
- return TypeConverter(module, t.getContext()).convertFunctionSignatureType(t);
+FunctionType TypeConverter::convertFunctionSignature(
+ FunctionType t, llvm::Module &module,
+ llvm::function_ref<Type(Type)> typeConversionCallback) {
+ return TypeConverter(module, t.getContext())
+ .convertFunctionSignatureType(t, typeConversionCallback);
}
Type TypeConverter::getMemRefElementPtrType(MemRefType t,
}
};
-/// A dialect converter from the Standard dialect to the LLVM IR dialect.
-class LLVMLowering : public DialectConversion {
-protected:
- // Create a set of converters that live in the pass object by passing them a
- // reference to the LLVM IR dialect. Store the module associated with the
- // dialect for further type conversion.
- llvm::DenseSet<DialectOpConversion *>
- initConverters(MLIRContext *mlirContext) override {
- converterStorage.Reset();
- auto *llvmDialect = static_cast<LLVM::LLVMDialect *>(
- mlirContext->getRegisteredDialect("llvm"));
- if (!llvmDialect) {
- mlirContext->emitError(UnknownLoc::get(mlirContext),
- "LLVM IR dialect is not registered");
- return {};
- }
-
- module = &llvmDialect->getLLVMModule();
-
- // FIXME: this should be tablegen'ed
- return ConversionListBuilder<
- AddFOpLowering, AddIOpLowering, AndOpLowering, AllocOpLowering,
- BranchOpLowering, CallIndirectOpLowering, CallOpLowering,
- CmpIOpLowering, CondBranchOpLowering, ConstLLVMOpLowering,
- DeallocOpLowering, DimOpLowering, DivISOpLowering, DivIUOpLowering,
- DivFOpLowering, LoadOpLowering, MemRefCastOpLowering, MulFOpLowering,
- MulIOpLowering, OrOpLowering, RemISOpLowering, RemIUOpLowering,
- RemFOpLowering, ReturnOpLowering, SelectOpLowering, StoreOpLowering,
- SubFOpLowering, SubIOpLowering, XOrOpLowering>::build(&converterStorage,
- *llvmDialect);
+// Create a set of converters that live in the pass object by passing them a
+// reference to the LLVM IR dialect. Store the module associated with the
+// dialect for further type conversion.
+llvm::DenseSet<DialectOpConversion *>
+LLVMLowering::initConverters(MLIRContext *mlirContext) {
+ converterStorage.Reset();
+ llvmDialect = static_cast<LLVM::LLVMDialect *>(
+ mlirContext->getRegisteredDialect("llvm"));
+ if (!llvmDialect) {
+ mlirContext->emitError(UnknownLoc::get(mlirContext),
+ "LLVM IR dialect is not registered");
+ return {};
}
- // Convert types using the stored LLVM IR module.
- Type convertType(Type t) override {
- return TypeConverter::convert(t, *module);
- }
+ module = &llvmDialect->getLLVMModule();
+
+ // FIXME: this should be tablegen'ed
+ auto converters = ConversionListBuilder<
+ AddFOpLowering, AddIOpLowering, AndOpLowering, AllocOpLowering,
+ BranchOpLowering, CallIndirectOpLowering, CallOpLowering, CmpIOpLowering,
+ CondBranchOpLowering, ConstLLVMOpLowering, DeallocOpLowering,
+ DimOpLowering, DivISOpLowering, DivIUOpLowering, DivFOpLowering,
+ LoadOpLowering, MemRefCastOpLowering, MulFOpLowering, MulIOpLowering,
+ OrOpLowering, RemISOpLowering, RemIUOpLowering, RemFOpLowering,
+ ReturnOpLowering, SelectOpLowering, StoreOpLowering, SubFOpLowering,
+ SubIOpLowering, XOrOpLowering>::build(&converterStorage, *llvmDialect);
+ auto extraConverters = initAdditionalConverters();
+ converters.insert(extraConverters.begin(), extraConverters.end());
+ return converters;
+}
- // Convert function signatures using the stored LLVM IR module.
- FunctionType convertFunctionSignatureType(
- FunctionType t, ArrayRef<NamedAttributeList> argAttrs,
- SmallVectorImpl<NamedAttributeList> &convertedArgAttrs) override {
+// Convert types using the stored LLVM IR module.
+Type LLVMLowering::convertType(Type t) {
+ if (auto result = TypeConverter::convert(t, *module))
+ return result;
+ if (auto result = convertAdditionalType(t))
+ return result;
- convertedArgAttrs.reserve(argAttrs.size());
- for (auto attr : argAttrs)
- convertedArgAttrs.push_back(attr);
- return TypeConverter::convertFunctionSignature(t, *module);
- }
+ auto *mlirContext = llvmDialect->getContext();
+ std::string message;
+ llvm::raw_string_ostream os(message);
+ os << "unsupported type: ";
+ t.print(os);
+ mlirContext->emitError(UnknownLoc::get(mlirContext), os.str());
+ return {};
+}
-private:
- // Storage for the conversion patterns.
- llvm::BumpPtrAllocator converterStorage;
- // LLVM IR module used to parse/create types.
- llvm::Module *module;
+// Convert function signatures using the stored LLVM IR module.
+FunctionType LLVMLowering::convertFunctionSignatureType(
+ FunctionType t, ArrayRef<NamedAttributeList> argAttrs,
+ SmallVectorImpl<NamedAttributeList> &convertedArgAttrs) {
+
+ convertedArgAttrs.reserve(argAttrs.size());
+ for (auto attr : argAttrs)
+ convertedArgAttrs.push_back(attr);
+ return TypeConverter::convertFunctionSignature(
+ t, *module, [this](Type t) { return convertType(t); });
+}
+
+namespace {
+// Make sure LLVM conversion pass errors out on the unsupported types instead
+// of keeping them as is and resulting in a more cryptic verifier error.
+class LLVMStandardLowering : public LLVMLowering {
+protected:
+ Type convertAdditionalType(Type) override { return {}; }
};
+} // namespace
/// A pass converting MLIR Standard operations into the LLVM IR dialect.
class LLVMLoweringPass : public ModulePass<LLVMLoweringPass> {
}
private:
- LLVMLowering impl;
+ LLVMStandardLowering impl;
};
ModulePassBase *mlir::createConvertToLLVMIRPass() {
#include "mlir/IR/StandardTypes.h"
#include "mlir/IR/Types.h"
#include "mlir/LLVMIR/LLVMDialect.h"
+#include "mlir/LLVMIR/LLVMLowering.h"
#include "mlir/LLVMIR/Transforms.h"
#include "mlir/Linalg/IR/LinalgOps.h"
#include "mlir/Linalg/IR/LinalgTypes.h"
}
};
-llvm::DenseSet<mlir::DialectOpConversion *>
-allocateDescriptorConverters(llvm::BumpPtrAllocator *allocator,
- mlir::MLIRContext *context) {
- return ConversionListBuilder<BufferSizeOpConversion, DotOpConversion,
- RangeOpConversion, SliceOpConversion,
- ViewOpConversion>::build(allocator, context);
-}
-
namespace {
// The conversion class from Linalg to LLVMIR.
-class Lowering : public DialectConversion {
-public:
- explicit Lowering(std::function<llvm::DenseSet<mlir::DialectOpConversion *>(
- llvm::BumpPtrAllocator *, mlir::MLIRContext *context)>
- conversions)
- : setup(conversions) {}
-
- Lowering &setLLVMModule(MLIRContext *context) {
- llvmModule = getLLVMModule(context);
- return *this;
- }
-
+class Lowering : public LLVMLowering {
protected:
- // Initialize the list of converters.
- llvm::DenseSet<DialectOpConversion *>
- initConverters(MLIRContext *context) override {
- converterStorage.Reset();
- return setup(&converterStorage, context);
+ llvm::DenseSet<DialectOpConversion *> initAdditionalConverters() override {
+ return ConversionListBuilder<
+ BufferSizeOpConversion, DotOpConversion, RangeOpConversion,
+ SliceOpConversion, ViewOpConversion>::build(&converterStorage,
+ llvmDialect->getContext());
}
- // This gets called for block and region arguments, and attributes.
- Type convertType(Type t) override {
- if (auto res = convertLinalgType(t, *llvmModule))
- return res;
- return convertToLLVMDialectType(t, *llvmModule);
+ Type convertAdditionalType(Type t) override {
+ return convertLinalgType(t, *module);
}
-
-private:
- // Storage for individual converters.
- llvm::BumpPtrAllocator converterStorage;
-
- // Conversion setup.
- std::function<llvm::DenseSet<mlir::DialectOpConversion *>(
- llvm::BumpPtrAllocator *, mlir::MLIRContext *context)>
- setup;
-
- llvm::Module *llvmModule;
};
} // end anonymous namespace
-std::unique_ptr<mlir::DialectConversion> makeLinalgToLLVMLowering(
- std::function<llvm::DenseSet<mlir::DialectOpConversion *>(
- llvm::BumpPtrAllocator *, mlir::MLIRContext *context)>
- initer) {
- return llvm::make_unique<Lowering>(initer);
-}
-
namespace {
struct LowerLinalgToLLVMPass : public ModulePass<LowerLinalgToLLVMPass> {
void runOnModule();
void LowerLinalgToLLVMPass::runOnModule() {
auto &module = getModule();
- // Convert Linalg ops to the LLVM IR dialect using the converter defined
- // above.
- auto r = Lowering(allocateDescriptorConverters)
- .setLLVMModule(module.getContext())
- .convert(&module);
- if (failed(r))
- signalPassFailure();
-
- // Convert the remaining standard MLIR operations to the LLVM IR dialect using
- // the default converter.
- auto converter = createStdToLLVMConverter();
- r = converter->convert(&module);
+ // Convert to the LLVM IR dialect using the converter defined above.
+ auto r = Lowering().convert(&module);
if (failed(r))
signalPassFailure();
}