using namespace mlir;
-namespace {
-// Type converter for the LLVM IR dialect. Converts MLIR standard and builtin
-// types into equivalent LLVM IR dialect types.
-class TypeConverter {
-public:
- // Convert one type `t ` and register it in the `llvmModule`. The latter may
- // be used to extract information specific to the data layout.
- // Dispatches to the private functions below based on the actual type.
- static Type convert(Type t, llvm::Module &llvmModule);
-
- // Convert the element type of the memref `t` to to an LLVM type, get a
- // pointer LLVM type pointing to the converted `t`, wrap it into the MLIR LLVM
- // dialect type and return.
- static Type getMemRefElementPtrType(MemRefType t, llvm::Module &llvmModule);
-
- // Convert a non-empty list of types to an LLVM IR dialect type wrapping an
- // LLVM IR structure type, elements of which are formed by converting
- // individual types in the given list. Register the type in the `llvmModule`.
- // The module may be also used to query the data layout.
- static Type pack(ArrayRef<Type> types, llvm::Module &llvmModule,
- MLIRContext &context);
-
- // Convert a function signature type to the LLVM IR dialect. The outer
- // function type remains `mlir::FunctionType`. Argument types are converted
- // to LLVM IR using `typeConversionCallback` if provided and using
- // `TypeConverter::convert` otherwise. If the function returns a single
- // result, its type is converted. Otherwise, the types of results are packed
- // into an LLVM IR structure type.
- static FunctionType convertFunctionSignature(
- FunctionType t, llvm::Module &llvmModule,
- llvm::function_ref<Type(Type)> typeConversionCallback = {});
-
-private:
- // Construct a type converter.
- explicit TypeConverter(llvm::Module &llvmModule, MLIRContext *context)
- : module(llvmModule), llvmContext(llvmModule.getContext()),
- builder(llvmModule.getContext()), mlirContext(context) {}
-
- // Convert a function type. The arguments and results are converted one by
- // one. Additionally, if the function returns more than one value, pack the
- // results into an LLVM IR structure type so that the converted function type
- // returns at most one result.
- Type convertFunctionType(FunctionType type);
-
- // Convert function type arguments and results without converting the
- // function type itself.
- FunctionType convertFunctionSignatureType(
- FunctionType type, llvm::function_ref<Type(Type)> typeConversionCallback);
-
- // Convert the index type. Uses llvmModule data layout to create an integer
- // of the pointer bitwidth.
- Type convertIndexType(IndexType type);
-
- // Convert an integer type `i*` to `!llvm<"i*">`.
- Type convertIntegerType(IntegerType type);
-
- // Convert a floating point type: `f16` to `!llvm.half`, `f32` to
- // `!llvm.float` and `f64` to `!llvm.double`. `bf16` is not supported
- // by LLVM.
- Type convertFloatType(FloatType type);
-
- // Convert a memref type into an LLVM type that captures the relevant data.
- // For statically-shaped memrefs, the resulting type is a pointer to the
- // (converted) memref element type. For dynamically-shaped memrefs, the
- // resulting type is an LLVM structure type that contains:
- // 1. a pointer to the (converted) memref element type
- // 2. as many index types as memref has dynamic dimensions.
- Type convertMemRefType(MemRefType type);
-
- // Convert a 1D vector type into an LLVM vector type.
- Type convertVectorType(VectorType type);
-
- // Convert a non-empty list of types into an LLVM structure type containing
- // those types. If the list contains a single element, convert the element
- // directly.
- Type getPackedResultType(ArrayRef<Type> types);
-
- // Convert a type to the LLVM IR dialect. Returns a null type in case of
- // error.
- Type convertType(Type type);
-
- // Get the LLVM representation of the index type based on the bitwidth of the
- // pointer as defined by the data layout of the module.
- llvm::IntegerType *getIndexType();
-
- // Wrap the given LLVM IR type into an LLVM IR dialect type.
- Type wrap(llvm::Type *llvmType) {
- return LLVM::LLVMType::get(mlirContext, llvmType);
- }
-
- // Extract an LLVM IR type from the LLVM IR dialect type.
- llvm::Type *unwrap(Type type) {
- if (!type)
- return nullptr;
- auto wrappedLLVMType = type.dyn_cast<LLVM::LLVMType>();
- if (!wrappedLLVMType)
- return mlirContext->emitError(UnknownLoc::get(mlirContext),
- "conversion resulted in a non-LLVM type"),
- nullptr;
- return wrappedLLVMType.getUnderlyingType();
- }
-
- llvm::Module &module;
- llvm::LLVMContext &llvmContext;
- llvm::IRBuilder<> builder;
+// Wrap the given LLVM IR type into an LLVM IR dialect type.
+Type LLVMLowering::wrap(llvm::Type *llvmType) {
+ return LLVM::LLVMType::get(llvmDialect->getContext(), llvmType);
+}
- MLIRContext *mlirContext;
-};
-} // end anonymous namespace
+// Extract an LLVM IR type from the LLVM IR dialect type.
+llvm::Type *LLVMLowering::unwrap(Type type) {
+ if (!type)
+ return nullptr;
+ auto *mlirContext = type.getContext();
+ auto wrappedLLVMType = type.dyn_cast<LLVM::LLVMType>();
+ if (!wrappedLLVMType)
+ return mlirContext->emitError(UnknownLoc::get(mlirContext),
+ "conversion resulted in a non-LLVM type"),
+ nullptr;
+ return wrappedLLVMType.getUnderlyingType();
+}
-llvm::IntegerType *TypeConverter::getIndexType() {
- return builder.getIntNTy(module.getDataLayout().getPointerSizeInBits());
+llvm::IntegerType *LLVMLowering::getIndexType() {
+ return llvm::IntegerType::get(llvmDialect->getLLVMContext(),
+ module->getDataLayout().getPointerSizeInBits());
}
-Type TypeConverter::convertIndexType(IndexType type) {
+Type LLVMLowering::convertIndexType(IndexType type) {
return wrap(getIndexType());
}
-Type TypeConverter::convertIntegerType(IntegerType type) {
- return wrap(builder.getIntNTy(type.getWidth()));
+Type LLVMLowering::convertIntegerType(IntegerType type) {
+ return wrap(
+ llvm::Type::getIntNTy(llvmDialect->getLLVMContext(), type.getWidth()));
}
-Type TypeConverter::convertFloatType(FloatType type) {
+Type LLVMLowering::convertFloatType(FloatType type) {
switch (type.getKind()) {
case mlir::StandardTypes::F32:
- return wrap(builder.getFloatTy());
+ return wrap(llvm::Type::getFloatTy(llvmDialect->getLLVMContext()));
case mlir::StandardTypes::F64:
- return wrap(builder.getDoubleTy());
+ return wrap(llvm::Type::getDoubleTy(llvmDialect->getLLVMContext()));
case mlir::StandardTypes::F16:
- return wrap(builder.getHalfTy());
- case mlir::StandardTypes::BF16:
+ return wrap(llvm::Type::getHalfTy(llvmDialect->getLLVMContext()));
+ case mlir::StandardTypes::BF16: {
+ auto *mlirContext = llvmDialect->getContext();
return mlirContext->emitError(UnknownLoc::get(mlirContext),
"unsupported type: BF16"),
Type();
+ }
default:
llvm_unreachable("non-float type in convertFloatType");
}
}
-// If `types` has more than one type, pack them into an LLVM StructType,
-// otherwise just convert the type.
-Type TypeConverter::getPackedResultType(ArrayRef<Type> types) {
- // We don't convert zero-valued functions to one-valued functions returning
- // void yet.
- assert(!types.empty() && "empty type list");
-
- // Convert result types one by one and check for errors.
- SmallVector<llvm::Type *, 8> resultTypes;
- for (auto t : types) {
- llvm::Type *converted = unwrap(convertType(t));
- if (!converted)
- return {};
- resultTypes.push_back(converted);
- }
-
- // LLVM does not support tuple returns. If there are more than 2 results,
- // pack them into an LLVM struct type.
- if (resultTypes.size() == 1)
- return wrap(resultTypes.front());
- return wrap(llvm::StructType::get(llvmContext, resultTypes));
-}
-
// Function types are converted to LLVM Function types by recursively converting
// argument and result types. If MLIR Function has zero results, the LLVM
// Function has one VoidType result. If MLIR Function has more than one result,
// they are into an LLVM StructType in their order of appearance.
-Type TypeConverter::convertFunctionType(FunctionType type) {
+Type LLVMLowering::convertFunctionType(FunctionType type) {
// Convert argument types one by one and check for errors.
SmallVector<llvm::Type *, 8> argTypes;
for (auto t : type.getInputs()) {
// If function does not return anything, create the void result type,
// if it returns on element, convert it, otherwise pack the result types into
// a struct.
- llvm::Type *resultType = type.getNumResults() == 0
- ? llvm::Type::getVoidTy(llvmContext)
- : unwrap(getPackedResultType(type.getResults()));
+ llvm::Type *resultType =
+ type.getNumResults() == 0
+ ? llvm::Type::getVoidTy(llvmDialect->getLLVMContext())
+ : unwrap(packFunctionResults(type.getResults()));
if (!resultType)
return {};
return wrap(llvm::FunctionType::get(resultType, argTypes, /*isVarArg=*/false)
->getPointerTo());
}
-FunctionType TypeConverter::convertFunctionSignatureType(
- FunctionType type, llvm::function_ref<Type(Type)> typeConversionCallback) {
- if (!typeConversionCallback)
- typeConversionCallback = [this](Type t) { return convertType(t); };
-
- SmallVector<Type, 8> argTypes;
- for (auto t : type.getInputs()) {
- auto converted = typeConversionCallback(t);
- if (!converted)
- return {};
- argTypes.push_back(converted);
- }
-
- // If function does not return anything, return immediately.
- if (type.getNumResults() == 0)
- return FunctionType::get(argTypes, {}, mlirContext);
-
- // Otherwise pack the result types into a struct.
- if (auto result = getPackedResultType(type.getResults()))
- return FunctionType::get(argTypes, {result}, mlirContext);
-
- return {};
-}
-
// Convert a MemRef to an LLVM type. If the memref is statically-shaped, then
// we return a pointer to the converted element type. Otherwise we return an
// LLVM stucture type, where the first element of the structure type is a
// pointer to the elemental type of the MemRef and the following N elements are
// values of the Index type, one for each of N dynamic dimensions of the MemRef.
-Type TypeConverter::convertMemRefType(MemRefType type) {
+Type LLVMLowering::convertMemRefType(MemRefType type) {
llvm::Type *elementType = unwrap(convertType(type.getElementType()));
if (!elementType)
return {};
SmallVector<llvm::Type *, 8> types(numDynamicSizes + 1, getIndexType());
types.front() = ptrType;
- return wrap(llvm::StructType::get(llvmContext, types));
+ return wrap(llvm::StructType::get(llvmDialect->getLLVMContext(), types));
}
// Convert a 1D vector type to an LLVM vector type.
-Type TypeConverter::convertVectorType(VectorType type) {
+Type LLVMLowering::convertVectorType(VectorType type) {
if (type.getRank() != 1) {
+ auto *mlirContext = llvmDialect->getContext();
mlirContext->emitError(UnknownLoc::get(mlirContext),
"only 1D vectors are supported");
return {};
}
// Dispatch based on the actual type. Return null type on error.
-Type TypeConverter::convertType(Type type) {
+Type LLVMLowering::convertStandardType(Type type) {
if (auto funcType = type.dyn_cast<FunctionType>())
return convertFunctionType(funcType);
if (auto intType = type.dyn_cast<IntegerType>())
return {};
}
-Type TypeConverter::convert(Type t, llvm::Module &module) {
- return TypeConverter(module, t.getContext()).convertType(t);
-}
-
-FunctionType TypeConverter::convertFunctionSignature(
- FunctionType t, llvm::Module &module,
- llvm::function_ref<Type(Type)> typeConversionCallback) {
- return TypeConverter(module, t.getContext())
- .convertFunctionSignatureType(t, typeConversionCallback);
-}
-
-Type TypeConverter::getMemRefElementPtrType(MemRefType t,
- llvm::Module &module) {
+// Convert the element type of the memref `t` to to an LLVM type using
+// `lowering`, get a pointer LLVM type pointing to the converted `t`, wrap it
+// into the MLIR LLVM dialect type and return.
+static Type getMemRefElementPtrType(MemRefType t, LLVMLowering &lowering) {
auto elementType = t.getElementType();
- auto converted = convert(elementType, module);
+ auto converted = lowering.convertType(elementType);
if (!converted)
return {};
llvm::Type *llvmType = converted.cast<LLVM::LLVMType>().getUnderlyingType();
return LLVM::LLVMType::get(t.getContext(), llvmType->getPointerTo());
}
-Type TypeConverter::pack(ArrayRef<Type> types, llvm::Module &module,
- MLIRContext &mlirContext) {
- return TypeConverter(module, &mlirContext).getPackedResultType(types);
-}
+LLVMOpLowering::LLVMOpLowering(StringRef rootOpName, MLIRContext *context,
+ LLVMLowering &lowering_)
+ : DialectOpConversion(rootOpName, /*benefit=*/1, context),
+ lowering(lowering_) {}
namespace {
// Base class for Standard to LLVM IR op conversions. Matches the Op type
// provided as template argument. Carries a reference to the LLVM dialect in
// case it is necessary for rewriters.
template <typename SourceOp>
-class LLVMLegalizationPattern : public DialectOpConversion {
+class LLVMLegalizationPattern : public LLVMOpLowering {
public:
// Construct a conversion pattern.
- explicit LLVMLegalizationPattern(LLVM::LLVMDialect &dialect)
- : DialectOpConversion(SourceOp::getOperationName(), 1,
- dialect.getContext()),
- dialect(dialect) {}
+ explicit LLVMLegalizationPattern(LLVM::LLVMDialect &dialect_,
+ LLVMLowering &lowering_)
+ : LLVMOpLowering(SourceOp::getOperationName(), dialect_.getContext(),
+ lowering_),
+ dialect(dialect_) {}
// Match by type.
PatternMatchResult match(Operation *op) const override {
SmallVector<Value *, 4> rewrite(Operation *op, ArrayRef<Value *> operands,
FuncBuilder &rewriter) const override {
unsigned numResults = op->getNumResults();
- auto *mlirContext = op->getContext();
Type packedType;
if (numResults != 0) {
packedType =
- TypeConverter::pack(getTypes(op->getResults()),
- this->dialect.getLLVMModule(), *mlirContext);
+ this->lowering.packFunctionResults(getTypes(op->getResults()));
assert(packedType && "type conversion failed, such operation should not "
"have been matched");
}
SmallVector<Value *, 4> results;
results.reserve(numResults);
for (unsigned i = 0; i < numResults; ++i) {
- auto type = TypeConverter::convert(op->getResult(i)->getType(),
- this->dialect.getLLVMModule());
+ auto type = this->lowering.convertType(op->getResult(i)->getType());
results.push_back(rewriter.create<LLVM::ExtractValueOp>(
op->getLoc(), type, newOp.getOperation()->getResult(0),
this->getIntegerArrayAttr(rewriter, i)));
rewriter.getFunctionAttr(mallocFunc),
cumulativeSize)
.getResult(0);
- auto structElementType = TypeConverter::convert(elementType, getModule());
+ auto structElementType = lowering.convertType(elementType);
auto elementPtrType = LLVM::LLVMType::get(
op->getContext(), structElementType.cast<LLVM::LLVMType>()
.getUnderlyingType()
}
// Create the MemRef descriptor.
- auto structType = TypeConverter::convert(type, getModule());
+ auto structType = lowering.convertType(type);
Value *memRefDescriptor = rewriter.create<LLVM::UndefOp>(
op->getLoc(), structType, ArrayRef<Value *>{});
auto sourceType = memRefCastOp.getOperand()->getType().cast<MemRefType>();
// Copy the data buffer pointer.
- auto elementTypePtr =
- TypeConverter::getMemRefElementPtrType(targetType, getModule());
+ auto elementTypePtr = getMemRefElementPtrType(targetType, lowering);
Value *buffer =
extractMemRefElementPtr(rewriter, op->getLoc(), operands[0],
elementTypePtr, sourceType.hasStaticShape());
}
// Create the new MemRef descriptor.
- auto structType = TypeConverter::convert(targetType, getModule());
+ auto structType = lowering.convertType(targetType);
Value *newDescriptor = rewriter.create<LLVM::UndefOp>(
op->getLoc(), structType, ArrayRef<Value *>{});
// Otherwise target type is dynamic memref, so create a proper descriptor.
Value *getDataPtr(Location loc, MemRefType type, Value *dataPtr,
ArrayRef<Value *> indices, FuncBuilder &rewriter,
llvm::Module &module) const {
- auto ptrType = TypeConverter::getMemRefElementPtrType(type, module);
+ auto ptrType = getMemRefElementPtrType(type, this->lowering);
auto shape = type.getShape();
if (type.hasStaticShape()) {
// NB: If memref was statically-shaped, dataPtr is pointer to raw data.
Value *dataPtr = getDataPtr(op->getLoc(), type, operands.front(),
operands.drop_front(), rewriter, getModule());
- auto elementType =
- TypeConverter::convert(type.getElementType(), getModule());
+ auto elementType = lowering.convertType(type.getElementType());
SmallVector<Value *, 4> results;
results.push_back(rewriter.create<LLVM::LoadOp>(
// Otherwise, we need to pack the arguments into an LLVM struct type before
// returning.
- auto *mlirContext = op->getContext();
- auto packedType = TypeConverter::pack(
- getTypes(op->getOperands()), dialect.getLLVMModule(), *mlirContext);
+ auto packedType = lowering.packFunctionResults(getTypes(op->getOperands()));
Value *packed = rewriter.create<LLVM::UndefOp>(op->getLoc(), packedType);
for (unsigned i = 0; i < numArguments; ++i) {
LoadOpLowering, MemRefCastOpLowering, MulFOpLowering, MulIOpLowering,
OrOpLowering, RemISOpLowering, RemIUOpLowering, RemFOpLowering,
ReturnOpLowering, SelectOpLowering, StoreOpLowering, SubFOpLowering,
- SubIOpLowering, XOrOpLowering>::build(&converterStorage, *llvmDialect);
+ SubIOpLowering, XOrOpLowering>::build(&converterStorage, *llvmDialect,
+ *this);
auto extraConverters = initAdditionalConverters();
converters.insert(extraConverters.begin(), extraConverters.end());
return converters;
// Convert types using the stored LLVM IR module.
Type LLVMLowering::convertType(Type t) {
- if (auto result = TypeConverter::convert(t, *module))
+ if (auto result = convertStandardType(t))
return result;
if (auto result = convertAdditionalType(t))
return result;
return {};
}
+static llvm::Type *unwrapType(Type type) {
+ return type.cast<LLVM::LLVMType>().getUnderlyingType();
+}
+
+// Create an LLVM IR structure type if there is more than one result.
+Type LLVMLowering::packFunctionResults(ArrayRef<Type> types) {
+ assert(!types.empty() && "expected non-empty list of type");
+
+ if (types.size() == 1)
+ return convertType(types.front());
+
+ SmallVector<llvm::Type *, 8> resultTypes;
+ resultTypes.reserve(types.size());
+ for (auto t : types) {
+ Type converted = convertType(t);
+ if (!converted)
+ return {};
+ resultTypes.push_back(unwrapType(converted));
+ }
+
+ return LLVM::LLVMType::get(
+ llvmDialect->getContext(),
+ llvm::StructType::get(llvmDialect->getLLVMContext(), resultTypes));
+}
+
// Convert function signatures using the stored LLVM IR module.
FunctionType LLVMLowering::convertFunctionSignatureType(
- FunctionType t, ArrayRef<NamedAttributeList> argAttrs,
+ FunctionType type, ArrayRef<NamedAttributeList> argAttrs,
SmallVectorImpl<NamedAttributeList> &convertedArgAttrs) {
convertedArgAttrs.reserve(argAttrs.size());
for (auto attr : argAttrs)
convertedArgAttrs.push_back(attr);
- return TypeConverter::convertFunctionSignature(
- t, *module, [this](Type t) { return convertType(t); });
+
+ SmallVector<Type, 8> argTypes;
+ for (auto t : type.getInputs()) {
+ auto converted = convertType(t);
+ if (!converted)
+ return {};
+ argTypes.push_back(converted);
+ }
+
+ // If function does not return anything, return immediately.
+ if (type.getNumResults() == 0)
+ return FunctionType::get(argTypes, {}, type.getContext());
+
+ // Otherwise pack the result types into a struct.
+ if (auto result = packFunctionResults(type.getResults()))
+ return FunctionType::get(argTypes, result, result.getContext());
+
+ return {};
}
namespace {
static PassRegistration<LLVMLoweringPass>
pass("lower-to-llvm", "Convert all functions to the LLVM IR dialect");
-
-Type mlir::LLVM::convertToLLVMDialectType(Type t, llvm::Module &llvmModule) {
- return TypeConverter::convert(t, llvmModule);
-}
using sub = ValueBuilder<mlir::LLVM::SubOp>;
using mul = ValueBuilder<mlir::LLVM::MulOp>;
-static llvm::Module *getLLVMModule(MLIRContext *context) {
- auto *llvmDialect =
- static_cast<LLVM::LLVMDialect *>(context->getRegisteredDialect("llvm"));
- if (!llvmDialect) {
- context->emitError(UnknownLoc::get(context),
- "LLVM IR dialect is not registered");
- return nullptr;
- }
- return &llvmDialect->getLLVMModule();
-}
-
template <typename T>
static llvm::Type *getPtrToElementType(T containerType,
- llvm::Module &llvmModule) {
- return convertToLLVMDialectType(containerType.getElementType(), llvmModule)
+ LLVMLowering &lowering) {
+ return lowering.convertType(containerType.getElementType())
.template cast<LLVMType>()
.getUnderlyingType()
->getPointerTo();
// - an F32 type is converted into an LLVM float type
// - a Buffer, Range or View is converted into an LLVM structure type
// containing the respective dynamic values.
-static Type convertLinalgType(Type t, llvm::Module &llvmModule) {
+static Type convertLinalgType(Type t, LLVMLowering &lowering) {
auto *context = t.getContext();
- auto *int64Ty = llvm::Type::getInt64Ty(llvmModule.getContext());
+ auto *int64Ty = lowering.convertType(IntegerType::get(64, context))
+ .cast<LLVM::LLVMType>()
+ .getUnderlyingType();
// A buffer descriptor contains the pointer to a flat region of storage and
// the size of the region.
// int64_t size;
// };
if (auto bufferTy = t.dyn_cast<BufferType>()) {
- auto *ptrTy = getPtrToElementType(bufferTy, llvmModule);
+ auto *ptrTy = getPtrToElementType(bufferTy, lowering);
auto *structTy = llvm::StructType::get(ptrTy, int64Ty);
return LLVMType::get(context, structTy);
}
// int64_t strides[Rank];
// };
if (auto viewTy = t.dyn_cast<ViewType>()) {
- auto *ptrTy = getPtrToElementType(viewTy, llvmModule);
+ auto *ptrTy = getPtrToElementType(viewTy, lowering);
auto *arrayTy = llvm::ArrayType::get(int64Ty, viewTy.getRank());
auto *structTy = llvm::StructType::get(ptrTy, int64Ty, arrayTy, arrayTy);
return LLVMType::get(context, structTy);
}
// BufferSizeOp creates a new `index` value.
-class BufferSizeOpConversion : public DialectOpConversion {
+class BufferSizeOpConversion : public LLVMOpLowering {
public:
- explicit BufferSizeOpConversion(MLIRContext *context)
- : DialectOpConversion(BufferSizeOp::getOperationName(), 1, context),
- llvmModule(*getLLVMModule(context)) {}
+ BufferSizeOpConversion(MLIRContext *context, LLVMLowering &lowering_)
+ : LLVMOpLowering(BufferSizeOp::getOperationName(), context, lowering_) {}
SmallVector<Value *, 4> rewrite(Operation *op, ArrayRef<Value *> operands,
FuncBuilder &rewriter) const override {
- auto bufferSizeType =
- convertToLLVMDialectType(operands[0]->getType(), llvmModule);
+ auto bufferSizeType = lowering.convertType(operands[0]->getType());
edsc::ScopedContext context(rewriter, op->getLoc());
return {extractvalue(bufferSizeType, operands[0],
makePositionAttr(rewriter, 1))};
}
-
- llvm::Module &llvmModule;
};
// RangeOp creates a new range descriptor.
-class RangeOpConversion : public DialectOpConversion {
+class RangeOpConversion : public LLVMOpLowering {
public:
- explicit RangeOpConversion(MLIRContext *context)
- : DialectOpConversion(RangeOp::getOperationName(), 1, context),
- llvmModule(*getLLVMModule(context)) {}
+ explicit RangeOpConversion(MLIRContext *context, LLVMLowering &lowering_)
+ : LLVMOpLowering(RangeOp::getOperationName(), context, lowering_) {}
SmallVector<Value *, 4> rewrite(Operation *op, ArrayRef<Value *> operands,
FuncBuilder &rewriter) const override {
auto rangeOp = op->cast<RangeOp>();
auto rangeDescriptorType =
- convertLinalgType(rangeOp.getResult()->getType(), llvmModule);
+ convertLinalgType(rangeOp.getResult()->getType(), lowering);
edsc::ScopedContext context(rewriter, op->getLoc());
return {desc};
}
-
- llvm::Module &llvmModule;
};
-class SliceOpConversion : public DialectOpConversion {
+class SliceOpConversion : public LLVMOpLowering {
public:
- explicit SliceOpConversion(MLIRContext *context)
- : DialectOpConversion(SliceOp::getOperationName(), 1, context),
- llvmModule(*getLLVMModule(context)) {}
+ explicit SliceOpConversion(MLIRContext *context, LLVMLowering &lowering_)
+ : LLVMOpLowering(SliceOp::getOperationName(), context, lowering_) {}
SmallVector<Value *, 4> rewrite(Operation *op, ArrayRef<Value *> operands,
FuncBuilder &rewriter) const override {
auto sliceOp = op->cast<SliceOp>();
auto viewDescriptorType =
- convertLinalgType(sliceOp.getViewType(), llvmModule);
+ convertLinalgType(sliceOp.getViewType(), lowering);
auto viewType = sliceOp.getBaseViewType();
- auto int64Ty =
- convertToLLVMDialectType(rewriter.getIntegerType(64), llvmModule);
+ auto int64Ty = lowering.convertType(rewriter.getIntegerType(64));
// Helper function to create an integer array attribute out of a list of
// values.
auto getViewPtr = [pos, &rewriter, this](ViewType type,
Value *view) -> Value * {
auto elementPtrTy =
- rewriter.getType<LLVMType>(getPtrToElementType(type, llvmModule));
+ rewriter.getType<LLVMType>(getPtrToElementType(type, lowering));
return extractvalue(elementPtrTy, view, pos(0));
};
return {desc};
}
-
- llvm::Module &llvmModule;
};
-class ViewOpConversion : public DialectOpConversion {
+class ViewOpConversion : public LLVMOpLowering {
public:
- explicit ViewOpConversion(MLIRContext *context)
- : DialectOpConversion(ViewOp::getOperationName(), 1, context),
- llvmModule(*getLLVMModule(context)) {}
+ explicit ViewOpConversion(MLIRContext *context, LLVMLowering &lowering_)
+ : LLVMOpLowering(ViewOp::getOperationName(), context, lowering_) {}
SmallVector<Value *, 4> rewrite(Operation *op, ArrayRef<Value *> operands,
FuncBuilder &rewriter) const override {
auto viewOp = op->cast<ViewOp>();
- auto viewDescriptorType =
- convertLinalgType(viewOp.getViewType(), llvmModule);
+ auto viewDescriptorType = convertLinalgType(viewOp.getViewType(), lowering);
auto elementType = rewriter.getType<LLVMType>(
- getPtrToElementType(viewOp.getViewType(), llvmModule));
- auto int64Ty =
- convertToLLVMDialectType(rewriter.getIntegerType(64), llvmModule);
+ getPtrToElementType(viewOp.getViewType(), lowering));
+ auto int64Ty = lowering.convertType(rewriter.getIntegerType(64));
auto pos = [&rewriter](ArrayRef<int> values) {
return makePositionAttr(rewriter, values);
return {desc};
}
-
- llvm::Module &llvmModule;
};
// DotOp creates a new range descriptor.
-class DotOpConversion : public DialectOpConversion {
+class DotOpConversion : public LLVMOpLowering {
public:
- explicit DotOpConversion(MLIRContext *context)
- : DialectOpConversion(DotOp::getOperationName(), 1, context) {}
+ explicit DotOpConversion(MLIRContext *context, LLVMLowering &lowering_)
+ : LLVMOpLowering(DotOp::getOperationName(), context, lowering_) {}
static StringRef libraryFunctionName() { return "linalg_dot"; }
return ConversionListBuilder<
BufferSizeOpConversion, DotOpConversion, RangeOpConversion,
SliceOpConversion, ViewOpConversion>::build(&converterStorage,
- llvmDialect->getContext());
+ llvmDialect->getContext(),
+ *this);
}
Type convertAdditionalType(Type t) override {
- return convertLinalgType(t, *module);
+ return convertLinalgType(t, *this);
}
};
} // end anonymous namespace