// Simple conversions.
if (t.isa<IndexType>()) {
int width = dialect->getLLVMModule().getDataLayout().getPointerSizeInBits();
- auto *integerTy = llvm::IntegerType::get(dialect->getLLVMContext(), width);
- return LLVM::LLVMType::get(context, integerTy);
- }
- if (auto intTy = t.dyn_cast<IntegerType>()) {
- int width = intTy.getWidth();
- auto *integerTy = llvm::IntegerType::get(dialect->getLLVMContext(), width);
- return LLVM::LLVMType::get(context, integerTy);
- }
- if (t.isF32()) {
- auto *floatTy = llvm::Type::getFloatTy(dialect->getLLVMContext());
- return LLVM::LLVMType::get(context, floatTy);
- }
- if (t.isF64()) {
- auto *doubleTy = llvm::Type::getDoubleTy(dialect->getLLVMContext());
- return LLVM::LLVMType::get(context, doubleTy);
+ return LLVM::LLVMType::getIntNTy(dialect, width);
}
+ if (auto intTy = t.dyn_cast<IntegerType>())
+ return LLVM::LLVMType::getIntNTy(dialect, intTy.getWidth());
+ if (t.isF32())
+ return LLVM::LLVMType::getFloatTy(dialect);
+ if (t.isF64())
+ return LLVM::LLVMType::getDoubleTy(dialect);
// Range descriptor contains the range bounds and the step as 64-bit integers.
//
// int64_t step;
// };
if (auto rangeTy = t.dyn_cast<linalg::RangeType>()) {
- auto *int64Ty = llvm::Type::getInt64Ty(dialect->getLLVMContext());
- auto *structTy = llvm::StructType::get(int64Ty, int64Ty, int64Ty);
- return LLVM::LLVMType::get(context, structTy);
+ auto int64Ty = LLVM::LLVMType::getInt64Ty(dialect);
+ return LLVM::LLVMType::getStructTy(int64Ty, int64Ty, int64Ty);
}
// View descriptor contains the pointer to the data buffer, followed by a
// int64_t strides[Rank];
// };
if (auto viewTy = t.dyn_cast<linalg::ViewType>()) {
- auto *elemTy = linalg::convertLinalgType(viewTy.getElementType())
- .cast<LLVM::LLVMType>()
- .getUnderlyingType()
- ->getPointerTo();
- auto *int64Ty = llvm::Type::getInt64Ty(dialect->getLLVMContext());
- auto *arrayTy = llvm::ArrayType::get(int64Ty, viewTy.getRank());
- auto *structTy = llvm::StructType::get(elemTy, int64Ty, arrayTy, arrayTy);
- return LLVM::LLVMType::get(context, structTy);
+ auto elemTy = linalg::convertLinalgType(viewTy.getElementType())
+ .cast<LLVM::LLVMType>()
+ .getPointerTo();
+ auto int64Ty = LLVM::LLVMType::getInt64Ty(dialect);
+ auto arrayTy = LLVM::LLVMType::getArrayTy(int64Ty, viewTy.getRank());
+ return LLVM::LLVMType::getStructTy(elemTy, int64Ty, arrayTy, arrayTy);
}
// All other types are kept as is.
if (type.hasStaticShape())
return memref;
- auto elementTy = LLVM::LLVMType::get(
- type.getContext(), linalg::convertLinalgType(type.getElementType())
- .cast<LLVM::LLVMType>()
- .getUnderlyingType()
- ->getPointerTo());
+ auto elementTy = linalg::convertLinalgType(type.getElementType())
+ .cast<LLVM::LLVMType>()
+ .getPointerTo();
return intrinsics::extractvalue(elementTy, memref, pos(0));
};
auto sliceOp = cast<linalg::SliceOp>(op);
auto newViewDescriptorType =
linalg::convertLinalgType(sliceOp.getViewType());
- auto elementType = rewriter.getType<LLVM::LLVMType>(
- linalg::convertLinalgType(sliceOp.getElementType())
- .cast<LLVM::LLVMType>()
- .getUnderlyingType()
- ->getPointerTo());
+ auto elementType = linalg::convertLinalgType(sliceOp.getElementType())
+ .cast<LLVM::LLVMType>()
+ .getPointerTo();
auto int64Ty = linalg::convertLinalgType(rewriter.getIntegerType(64));
auto pos = [&rewriter](ArrayRef<int> values) {
auto loadOp = cast<Op>(op);
auto elementType =
loadOp.getViewType().template cast<linalg::ViewType>().getElementType();
- auto *llvmPtrType = linalg::convertLinalgType(elementType)
- .template cast<LLVM::LLVMType>()
- .getUnderlyingType()
- ->getPointerTo();
- elementType = rewriter.getType<LLVM::LLVMType>(llvmPtrType);
+ elementType = linalg::convertLinalgType(elementType)
+ .template cast<LLVM::LLVMType>()
+ .getPointerTo();
auto int64Ty = linalg::convertLinalgType(rewriter.getIntegerType(64));
auto pos = [&rewriter](ArrayRef<int> values) {
// Create a function declaration for printf, signature is `i32 (i8*, ...)`
Builder builder(&module);
- MLIRContext *context = module.getContext();
- auto *llvmDialect =
+ auto *dialect =
module.getContext()->getRegisteredDialect<LLVM::LLVMDialect>();
- auto &llvmModule = llvmDialect->getLLVMModule();
- llvm::IRBuilder<> llvmBuilder(llvmModule.getContext());
- auto llvmI32Ty = LLVM::LLVMType::get(context, llvmBuilder.getIntNTy(32));
- auto llvmI8PtrTy =
- LLVM::LLVMType::get(context, llvmBuilder.getIntNTy(8)->getPointerTo());
+ auto llvmI32Ty = LLVM::LLVMType::getInt32Ty(dialect);
+ auto llvmI8PtrTy = LLVM::LLVMType::getInt8Ty(dialect).getPointerTo();
auto printfTy = builder.getFunctionType({llvmI8PtrTy}, {llvmI32Ty});
printfFunc = new Function(builder.getUnknownLoc(), "printf", printfTy);
// It should be variadic, but we don't support it fully just yet.
namespace mlir {
namespace LLVM {
+class LLVMDialect;
namespace detail {
struct LLVMTypeStorage;
-}
+struct LLVMDialectImpl;
+} // namespace detail
class LLVMType : public mlir::Type::TypeBase<LLVMType, mlir::Type,
detail::LLVMTypeStorage> {
static bool kindof(unsigned kind) { return kind == LLVM_TYPE; }
- static LLVMType get(MLIRContext *context, llvm::Type *llvmType);
-
+ LLVMDialect &getDialect();
llvm::Type *getUnderlyingType() const;
+
+ /// Array type utilities.
+ LLVMType getArrayElementType();
+
+ /// Pointer type utilities.
+ LLVMType getPointerTo(unsigned addrSpace = 0);
+ LLVMType getPointerElementTy();
+
+ /// Struct type utilities.
+ LLVMType getStructElementType(unsigned i);
+
+ /// Utilities used to generate floating point types.
+ static LLVMType getDoubleTy(LLVMDialect *dialect);
+ static LLVMType getFloatTy(LLVMDialect *dialect);
+ static LLVMType getHalfTy(LLVMDialect *dialect);
+
+ /// Utilities used to generate integer types.
+ static LLVMType getIntNTy(LLVMDialect *dialect, unsigned numBits);
+ static LLVMType getInt1Ty(LLVMDialect *dialect) {
+ return getIntNTy(dialect, /*numBits=*/1);
+ }
+ static LLVMType getInt8Ty(LLVMDialect *dialect) {
+ return getIntNTy(dialect, /*numBits=*/8);
+ }
+ static LLVMType getInt8PtrTy(LLVMDialect *dialect) {
+ return getInt8Ty(dialect).getPointerTo();
+ }
+ static LLVMType getInt16Ty(LLVMDialect *dialect) {
+ return getIntNTy(dialect, /*numBits=*/16);
+ }
+ static LLVMType getInt32Ty(LLVMDialect *dialect) {
+ return getIntNTy(dialect, /*numBits=*/32);
+ }
+ static LLVMType getInt64Ty(LLVMDialect *dialect) {
+ return getIntNTy(dialect, /*numBits=*/64);
+ }
+
+ /// Utilities used to generate other miscellaneous types.
+ static LLVMType getArrayTy(LLVMType elementType, uint64_t numElements);
+ static LLVMType getFunctionTy(LLVMType result, ArrayRef<LLVMType> params,
+ bool isVarArg);
+ static LLVMType getFunctionTy(LLVMType result, bool isVarArg) {
+ return getFunctionTy(result, llvm::None, isVarArg);
+ }
+ static LLVMType getStructTy(LLVMDialect *dialect, ArrayRef<LLVMType> elements,
+ bool isPacked = false);
+ static LLVMType getStructTy(LLVMDialect *dialect, bool isPacked = false) {
+ return getStructTy(dialect, llvm::None, isPacked);
+ }
+ template <typename... Args>
+ static typename std::enable_if<llvm::are_base_of<LLVMType, Args...>::value,
+ LLVMType>::type
+ getStructTy(LLVMType elt1, Args... elts) {
+ SmallVector<LLVMType, 8> fields({elt1, elts...});
+ return getStructTy(&elt1.getDialect(), fields);
+ }
+ static LLVMType getVectorTy(LLVMType elementType, unsigned numElements);
+ static LLVMType getVoidTy(LLVMDialect *dialect);
+
+private:
+ friend LLVMDialect;
+
+ /// Get an LLVM type with a pre-existing llvm type.
+ static LLVMType get(MLIRContext *context, llvm::Type *llvmType);
};
///// Ops /////
class LLVMDialect : public Dialect {
public:
explicit LLVMDialect(MLIRContext *context);
+ ~LLVMDialect();
static StringRef getDialectNamespace() { return "llvm"; }
- llvm::LLVMContext &getLLVMContext() { return llvmContext; }
- llvm::Module &getLLVMModule() { return module; }
+ llvm::LLVMContext &getLLVMContext();
+ llvm::Module &getLLVMModule();
/// Parse a type registered to this dialect.
Type parseType(StringRef tyData, Location loc) const override;
NamedAttribute argAttr) override;
private:
- llvm::LLVMContext llvmContext;
- llvm::Module module;
+ friend LLVMType;
+
+ std::unique_ptr<detail::LLVMDialectImpl> impl;
};
} // end namespace LLVM
class LLVMContext;
class Module;
class Type;
-}
+} // namespace llvm
namespace mlir {
namespace LLVM {
class LLVMDialect;
+class LLVMType;
}
/// Conversion from the Standard dialect to the LLVM IR dialect. Provides hooks
/// Returns the LLVM context.
llvm::LLVMContext &getLLVMContext();
+ /// Returns the LLVM dialect.
+ LLVM::LLVMDialect *getDialect() { return llvmDialect; }
+
protected:
/// Add a set of converters to the given pattern list. Store the module
/// associated with the dialect for further type conversion.
// Get the LLVM representation of the index type based on the bitwidth of the
// pointer as defined by the data layout of the module.
- llvm::IntegerType *getIndexType();
+ LLVM::LLVMType getIndexType();
// Wrap the given LLVM IR type into an LLVM IR dialect type.
Type wrap(llvm::Type *llvmType);
- // Extract an LLVM IR type from the LLVM IR dialect type.
- llvm::Type *unwrap(Type type);
+ // Extract an LLVM IR dialect type.
+ LLVM::LLVMType unwrap(Type type);
};
/// Base class for operation conversions targeting the LLVM IR dialect. Provides
#include "llvm/IR/Attributes.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/Type.h"
+#include "llvm/Support/Mutex.h"
#include "llvm/Support/SourceMgr.h"
using namespace mlir;
using namespace mlir::LLVM;
-namespace mlir {
-namespace LLVM {
-namespace detail {
-struct LLVMTypeStorage : public ::mlir::TypeStorage {
- LLVMTypeStorage(llvm::Type *ty) : underlyingType(ty) {}
-
- // LLVM types are pointer-unique.
- using KeyTy = llvm::Type *;
- bool operator==(const KeyTy &key) const { return key == underlyingType; }
-
- static LLVMTypeStorage *construct(TypeStorageAllocator &allocator,
- llvm::Type *ty) {
- return new (allocator.allocate<LLVMTypeStorage>()) LLVMTypeStorage(ty);
- }
-
- llvm::Type *underlyingType;
-};
-} // end namespace detail
-} // end namespace LLVM
-} // end namespace mlir
-
-LLVMType LLVMType::get(MLIRContext *context, llvm::Type *llvmType) {
- return Base::get(context, FIRST_LLVM_TYPE, llvmType);
-}
-
-llvm::Type *LLVMType::getUnderlyingType() const {
- return getImpl()->underlyingType;
-}
-
static void printLLVMBinaryOp(OpAsmPrinter *p, Operation *op) {
// Fallback to the generic form if the op is not well-formed (may happen
// during incomplete rewrites, and used for debugging).
// The result type is either i1 or a vector type <? x i1> if the inputs are
// vectors.
auto *dialect = builder.getContext()->getRegisteredDialect<LLVMDialect>();
- llvm::Type *llvmResultType = llvm::Type::getInt1Ty(dialect->getLLVMContext());
+ auto resultType = LLVMType::getInt1Ty(dialect);
auto argType = type.dyn_cast<LLVM::LLVMType>();
if (!argType)
return parser->emitError(trailingTypeLoc, "expected LLVM IR dialect type");
if (argType.getUnderlyingType()->isVectorTy())
- llvmResultType = llvm::VectorType::get(
- llvmResultType, argType.getUnderlyingType()->getVectorNumElements());
- auto resultType = builder.getType<LLVM::LLVMType>(llvmResultType);
+ resultType = LLVMType::getVectorTy(
+ resultType, argType.getUnderlyingType()->getVectorNumElements());
result->attributes = attrs;
result->addTypes({resultType});
//===----------------------------------------------------------------------===//
static void printAllocaOp(OpAsmPrinter *p, AllocaOp &op) {
- auto *llvmPtrTy = op.getType().cast<LLVM::LLVMType>().getUnderlyingType();
- auto *llvmElemTy = llvm::cast<llvm::PointerType>(llvmPtrTy)->getElementType();
- auto elemTy = LLVM::LLVMType::get(op.getContext(), llvmElemTy);
+ auto elemTy = op.getType().cast<LLVM::LLVMType>().getPointerElementTy();
auto funcTy = FunctionType::get({op.arraySize()->getType()}, {op.getType()},
op.getContext());
if (!llvmTy)
return parser->emitError(trailingTypeLoc, "expected LLVM IR dialect type"),
nullptr;
- auto *llvmPtrTy = dyn_cast<llvm::PointerType>(llvmTy.getUnderlyingType());
- if (!llvmPtrTy)
+ if (!llvmTy.getUnderlyingType()->isPointerTy())
return parser->emitError(trailingTypeLoc, "expected LLVM pointer type"),
nullptr;
- auto elemTy = LLVM::LLVMType::get(parser->getBuilder().getContext(),
- llvmPtrTy->getElementType());
- return elemTy;
+ return llvmTy.getPointerElementTy();
}
// <operation> ::= `llvm.load` ssa-use attribute-dict? `:` type
Builder &builder = parser->getBuilder();
auto *llvmDialect =
builder.getContext()->getRegisteredDialect<LLVM::LLVMDialect>();
- llvm::Type *llvmResultType;
- Type wrappedResultType;
+ LLVM::LLVMType llvmResultType;
if (funcType.getNumResults() == 0) {
- llvmResultType = llvm::Type::getVoidTy(llvmDialect->getLLVMContext());
- wrappedResultType = builder.getType<LLVM::LLVMType>(llvmResultType);
+ llvmResultType = LLVM::LLVMType::getVoidTy(llvmDialect);
} else {
- wrappedResultType = funcType.getResult(0);
- auto wrappedLLVMResultType = wrappedResultType.dyn_cast<LLVM::LLVMType>();
- if (!wrappedLLVMResultType)
+ llvmResultType = funcType.getResult(0).dyn_cast<LLVM::LLVMType>();
+ if (!llvmResultType)
return parser->emitError(trailingTypeLoc,
"expected result to have LLVM type");
- llvmResultType = wrappedLLVMResultType.getUnderlyingType();
}
- SmallVector<llvm::Type *, 8> argTypes;
+ SmallVector<LLVM::LLVMType, 8> argTypes;
argTypes.reserve(funcType.getNumInputs());
for (int i = 0, e = funcType.getNumInputs(); i < e; ++i) {
auto argType = funcType.getInput(i).dyn_cast<LLVM::LLVMType>();
if (!argType)
return parser->emitError(trailingTypeLoc,
"expected LLVM types as inputs");
- argTypes.push_back(argType.getUnderlyingType());
+ argTypes.push_back(argType);
}
- auto *llvmFuncType = llvm::FunctionType::get(llvmResultType, argTypes,
- /*isVarArg=*/false);
- auto wrappedFuncType =
- builder.getType<LLVM::LLVMType>(llvmFuncType->getPointerTo());
+ auto llvmFuncType = LLVM::LLVMType::getFunctionTy(llvmResultType, argTypes,
+ /*isVarArg=*/false);
+ auto wrappedFuncType = llvmFuncType.getPointerTo();
auto funcArguments =
ArrayRef<OpAsmParser::OperandType>(operands).drop_front();
parser->getNameLoc(), result->operands))
return failure();
- result->addTypes(wrappedResultType);
+ result->addTypes(llvmResultType);
}
result->attributes = attrs;
// type by taking the element type, indexed by the position attribute for
// stuctures. Check the position index before accessing, it is supposed to be
// in bounds.
- llvm::Type *llvmContainerType = wrappedContainerType.getUnderlyingType();
for (Attribute subAttr : positionArrayAttr) {
auto positionElementAttr = subAttr.dyn_cast<IntegerAttr>();
if (!positionElementAttr)
"expected an array of integer literals"),
nullptr;
int position = positionElementAttr.getInt();
+ auto *llvmContainerType = wrappedContainerType.getUnderlyingType();
if (llvmContainerType->isArrayTy()) {
if (position < 0 || static_cast<unsigned>(position) >=
llvmContainerType->getArrayNumElements())
return parser->emitError(attributeLoc, "position out of bounds"),
nullptr;
- llvmContainerType = llvmContainerType->getArrayElementType();
+ wrappedContainerType = wrappedContainerType.getArrayElementType();
} else if (llvmContainerType->isStructTy()) {
if (position < 0 || static_cast<unsigned>(position) >=
llvmContainerType->getStructNumElements())
return parser->emitError(attributeLoc, "position out of bounds"),
nullptr;
- llvmContainerType = llvmContainerType->getStructElementType(position);
+ wrappedContainerType =
+ wrappedContainerType.getStructElementType(position);
} else {
return parser->emitError(typeLoc,
"expected wrapped LLVM IR structure/array type"),
nullptr;
}
}
-
- Builder &builder = parser->getBuilder();
- return builder.getType<LLVM::LLVMType>(llvmContainerType);
+ return wrappedContainerType;
}
// <operation> ::= `llvm.extractvalue` ssa-use
Builder &builder = parser->getBuilder();
auto *llvmDialect =
builder.getContext()->getRegisteredDialect<LLVM::LLVMDialect>();
- auto i1Type = builder.getType<LLVM::LLVMType>(
- llvm::Type::getInt1Ty(llvmDialect->getLLVMContext()));
+ auto i1Type = LLVM::LLVMType::getInt1Ty(llvmDialect);
if (parser->parseOperand(condition) || parser->parseComma() ||
parser->parseSuccessorAndUseList(trueDest, trueOperands) ||
// LLVMDialect initialization, type parsing, and registration.
//===----------------------------------------------------------------------===//
+namespace mlir {
+namespace LLVM {
+namespace detail {
+struct LLVMDialectImpl {
+ LLVMDialectImpl() : module("LLVMDialectModule", llvmContext) {}
+
+ llvm::LLVMContext llvmContext;
+ llvm::Module module;
+
+ /// A smart mutex to lock access to the llvm context. Unlike MLIR, LLVM is not
+ /// multi-threaded and requires locked access to prevent race conditions.
+ llvm::sys::SmartMutex<true> mutex;
+};
+} // end namespace detail
+} // end namespace LLVM
+} // end namespace mlir
+
LLVMDialect::LLVMDialect(MLIRContext *context)
: Dialect(getDialectNamespace(), context),
- module("LLVMDialectModule", llvmContext) {
+ impl(new detail::LLVMDialectImpl()) {
addTypes<LLVMType>();
addOperations<
#define GET_OP_LIST
allowUnknownOperations();
}
+LLVMDialect::~LLVMDialect() {}
+
#define GET_OP_CLASSES
#include "mlir/LLVMIR/LLVMOps.cpp.inc"
+llvm::LLVMContext &LLVMDialect::getLLVMContext() { return impl->llvmContext; }
+llvm::Module &LLVMDialect::getLLVMModule() { return impl->module; }
+
/// Parse a type registered to this dialect.
Type LLVMDialect::parseType(StringRef tyData, Location loc) const {
+ // LLVM is not thread-safe, so lock access to it.
+ llvm::sys::SmartScopedLock<true> lock(impl->mutex);
+
llvm::SMDiagnostic errorMessage;
- llvm::Type *type = llvm::parseType(tyData, errorMessage, module);
+ llvm::Type *type = llvm::parseType(tyData, errorMessage, impl->module);
if (!type)
return (getContext()->emitError(loc, errorMessage.getMessage()), nullptr);
return LLVMType::get(getContext(), type);
}
static DialectRegistration<LLVMDialect> llvmDialect;
+
+//===----------------------------------------------------------------------===//
+// LLVMType.
+//===----------------------------------------------------------------------===//
+
+namespace mlir {
+namespace LLVM {
+namespace detail {
+struct LLVMTypeStorage : public ::mlir::TypeStorage {
+ LLVMTypeStorage(llvm::Type *ty) : underlyingType(ty) {}
+
+ // LLVM types are pointer-unique.
+ using KeyTy = llvm::Type *;
+ bool operator==(const KeyTy &key) const { return key == underlyingType; }
+
+ static LLVMTypeStorage *construct(TypeStorageAllocator &allocator,
+ llvm::Type *ty) {
+ return new (allocator.allocate<LLVMTypeStorage>()) LLVMTypeStorage(ty);
+ }
+
+ llvm::Type *underlyingType;
+};
+} // end namespace detail
+} // end namespace LLVM
+} // end namespace mlir
+
+LLVMType LLVMType::get(MLIRContext *context, llvm::Type *llvmType) {
+ return Base::get(context, FIRST_LLVM_TYPE, llvmType);
+}
+
+LLVMDialect &LLVMType::getDialect() {
+ return static_cast<LLVMDialect &>(Type::getDialect());
+}
+
+llvm::Type *LLVMType::getUnderlyingType() const {
+ return getImpl()->underlyingType;
+}
+
+/// Array type utilities.
+LLVMType LLVMType::getArrayElementType() {
+ return get(getContext(), getUnderlyingType()->getArrayElementType());
+}
+
+/// Pointer type utilities.
+LLVMType LLVMType::getPointerTo(unsigned addrSpace) {
+ // Lock access to the dialect as this may modify the LLVM context.
+ llvm::sys::SmartScopedLock<true> lock(getDialect().impl->mutex);
+ return get(getContext(), getUnderlyingType()->getPointerTo(addrSpace));
+}
+LLVMType LLVMType::getPointerElementTy() {
+ return get(getContext(), getUnderlyingType()->getPointerElementType());
+}
+
+/// Struct type utilities.
+LLVMType LLVMType::getStructElementType(unsigned i) {
+ return get(getContext(), getUnderlyingType()->getStructElementType(i));
+}
+
+/// Utilities used to generate floating point types.
+LLVMType LLVMType::getDoubleTy(LLVMDialect *dialect) {
+ return get(dialect->getContext(),
+ llvm::Type::getDoubleTy(dialect->getLLVMContext()));
+}
+LLVMType LLVMType::getFloatTy(LLVMDialect *dialect) {
+ return get(dialect->getContext(),
+ llvm::Type::getFloatTy(dialect->getLLVMContext()));
+}
+LLVMType LLVMType::getHalfTy(LLVMDialect *dialect) {
+ return get(dialect->getContext(),
+ llvm::Type::getHalfTy(dialect->getLLVMContext()));
+}
+
+/// Utilities used to generate integer types.
+LLVMType LLVMType::getIntNTy(LLVMDialect *dialect, unsigned numBits) {
+ // Lock access to the dialect as this may modify the LLVM context.
+ llvm::sys::SmartScopedLock<true> lock(dialect->impl->mutex);
+ return get(dialect->getContext(),
+ llvm::Type::getIntNTy(dialect->getLLVMContext(), numBits));
+}
+
+/// Utilities used to generate other miscellaneous types.
+LLVMType LLVMType::getArrayTy(LLVMType elementType, uint64_t numElements) {
+ // Lock access to the dialect as this may modify the LLVM context.
+ llvm::sys::SmartScopedLock<true> lock(elementType.getDialect().impl->mutex);
+ return get(
+ elementType.getContext(),
+ llvm::ArrayType::get(elementType.getUnderlyingType(), numElements));
+}
+LLVMType LLVMType::getFunctionTy(LLVMType result, ArrayRef<LLVMType> params,
+ bool isVarArg) {
+ SmallVector<llvm::Type *, 8> llvmParams;
+ for (auto param : params)
+ llvmParams.push_back(param.getUnderlyingType());
+
+ // Lock access to the dialect as this may modify the LLVM context.
+ llvm::sys::SmartScopedLock<true> lock(result.getDialect().impl->mutex);
+ return get(result.getContext(),
+ llvm::FunctionType::get(result.getUnderlyingType(), llvmParams,
+ isVarArg));
+}
+LLVMType LLVMType::getStructTy(LLVMDialect *dialect,
+ ArrayRef<LLVMType> elements, bool isPacked) {
+ SmallVector<llvm::Type *, 8> llvmElements;
+ for (auto elt : elements)
+ llvmElements.push_back(elt.getUnderlyingType());
+
+ // Lock access to the dialect as this may modify the LLVM context.
+ llvm::sys::SmartScopedLock<true> lock(dialect->impl->mutex);
+ return get(
+ dialect->getContext(),
+ llvm::StructType::get(dialect->getLLVMContext(), llvmElements, isPacked));
+}
+LLVMType LLVMType::getVectorTy(LLVMType elementType, unsigned numElements) {
+ // Lock access to the dialect as this may modify the LLVM context.
+ llvm::sys::SmartScopedLock<true> lock(elementType.getDialect().impl->mutex);
+ return get(
+ elementType.getContext(),
+ llvm::VectorType::get(elementType.getUnderlyingType(), numElements));
+}
+LLVMType LLVMType::getVoidTy(LLVMDialect *dialect) {
+ return get(dialect->getContext(),
+ llvm::Type::getVoidTy(dialect->getLLVMContext()));
+}
return module->getContext();
}
-// Wrap the given LLVM IR type into an LLVM IR dialect type.
-Type LLVMLowering::wrap(llvm::Type *llvmType) {
- return LLVM::LLVMType::get(llvmDialect->getContext(), llvmType);
-}
-
// Extract an LLVM IR type from the LLVM IR dialect type.
-llvm::Type *LLVMLowering::unwrap(Type type) {
+LLVM::LLVMType LLVMLowering::unwrap(Type type) {
if (!type)
return nullptr;
auto *mlirContext = type.getContext();
auto wrappedLLVMType = type.dyn_cast<LLVM::LLVMType>();
if (!wrappedLLVMType)
- return mlirContext->emitError(UnknownLoc::get(mlirContext),
- "conversion resulted in a non-LLVM type"),
- nullptr;
- return wrappedLLVMType.getUnderlyingType();
+ mlirContext->emitError(UnknownLoc::get(mlirContext),
+ "conversion resulted in a non-LLVM type");
+ return wrappedLLVMType;
}
-llvm::IntegerType *LLVMLowering::getIndexType() {
- return llvm::IntegerType::get(llvmDialect->getLLVMContext(),
- module->getDataLayout().getPointerSizeInBits());
+LLVM::LLVMType LLVMLowering::getIndexType() {
+ return LLVM::LLVMType::getIntNTy(
+ llvmDialect, module->getDataLayout().getPointerSizeInBits());
}
-Type LLVMLowering::convertIndexType(IndexType type) {
- return wrap(getIndexType());
-}
+Type LLVMLowering::convertIndexType(IndexType type) { return getIndexType(); }
Type LLVMLowering::convertIntegerType(IntegerType type) {
- return wrap(
- llvm::Type::getIntNTy(llvmDialect->getLLVMContext(), type.getWidth()));
+ return LLVM::LLVMType::getIntNTy(llvmDialect, type.getWidth());
}
Type LLVMLowering::convertFloatType(FloatType type) {
switch (type.getKind()) {
case mlir::StandardTypes::F32:
- return wrap(llvm::Type::getFloatTy(llvmDialect->getLLVMContext()));
+ return LLVM::LLVMType::getFloatTy(llvmDialect);
case mlir::StandardTypes::F64:
- return wrap(llvm::Type::getDoubleTy(llvmDialect->getLLVMContext()));
+ return LLVM::LLVMType::getDoubleTy(llvmDialect);
case mlir::StandardTypes::F16:
- return wrap(llvm::Type::getHalfTy(llvmDialect->getLLVMContext()));
+ return LLVM::LLVMType::getHalfTy(llvmDialect);
case mlir::StandardTypes::BF16: {
auto *mlirContext = llvmDialect->getContext();
return mlirContext->emitError(UnknownLoc::get(mlirContext),
// they are into an LLVM StructType in their order of appearance.
Type LLVMLowering::convertFunctionType(FunctionType type) {
// Convert argument types one by one and check for errors.
- SmallVector<llvm::Type *, 8> argTypes;
+ SmallVector<LLVM::LLVMType, 8> argTypes;
for (auto t : type.getInputs()) {
auto converted = convertType(t);
if (!converted)
// If function does not return anything, create the void result type,
// if it returns on element, convert it, otherwise pack the result types into
// a struct.
- llvm::Type *resultType =
+ LLVM::LLVMType resultType =
type.getNumResults() == 0
- ? llvm::Type::getVoidTy(llvmDialect->getLLVMContext())
+ ? LLVM::LLVMType::getVoidTy(llvmDialect)
: unwrap(packFunctionResults(type.getResults()));
if (!resultType)
return {};
- return wrap(llvm::FunctionType::get(resultType, argTypes, /*isVarArg=*/false)
- ->getPointerTo());
+ return LLVM::LLVMType::getFunctionTy(resultType, argTypes, /*isVarArg=*/false)
+ .getPointerTo();
}
// Convert a MemRef to an LLVM type. If the memref is statically-shaped, then
// pointer to the elemental type of the MemRef and the following N elements are
// values of the Index type, one for each of N dynamic dimensions of the MemRef.
Type LLVMLowering::convertMemRefType(MemRefType type) {
- llvm::Type *elementType = unwrap(convertType(type.getElementType()));
+ LLVM::LLVMType elementType = unwrap(convertType(type.getElementType()));
if (!elementType)
return {};
- auto ptrType = elementType->getPointerTo();
+ auto ptrType = elementType.getPointerTo();
// Extra value for the memory space.
unsigned numDynamicSizes = type.getNumDynamicDims();
// If memref is statically-shaped we return the underlying pointer type.
- if (numDynamicSizes == 0) {
- return wrap(ptrType);
- }
- SmallVector<llvm::Type *, 8> types(numDynamicSizes + 1, getIndexType());
+ if (numDynamicSizes == 0)
+ return ptrType;
+
+ SmallVector<LLVM::LLVMType, 8> types(numDynamicSizes + 1, getIndexType());
types.front() = ptrType;
- return wrap(llvm::StructType::get(llvmDialect->getLLVMContext(), types));
+ return LLVM::LLVMType::getStructTy(llvmDialect, types);
}
// Convert a 1D vector type to an LLVM vector type.
return {};
}
- llvm::Type *elementType = unwrap(convertType(type.getElementType()));
+ LLVM::LLVMType elementType = unwrap(convertType(type.getElementType()));
return elementType
- ? wrap(llvm::VectorType::get(elementType, type.getShape().front()))
+ ? LLVM::LLVMType::getVectorTy(elementType, type.getShape().front())
: Type();
}
auto converted = lowering.convertType(elementType);
if (!converted)
return {};
- llvm::Type *llvmType = converted.cast<LLVM::LLVMType>().getUnderlyingType();
- return LLVM::LLVMType::get(t.getContext(), llvmType->getPointerTo());
+ return converted.cast<LLVM::LLVMType>().getPointerTo();
}
LLVMOpLowering::LLVMOpLowering(StringRef rootOpName, MLIRContext *context,
// Get the MLIR type wrapping the LLVM integer type whose bit width is defined
// by the pointer size used in the LLVM module.
LLVM::LLVMType getIndexType() const {
- llvm::Type *llvmType = llvm::Type::getIntNTy(
- getContext(), getModule().getDataLayout().getPointerSizeInBits());
- return LLVM::LLVMType::get(dialect.getContext(), llvmType);
+ return LLVM::LLVMType::getIntNTy(
+ &dialect, getModule().getDataLayout().getPointerSizeInBits());
}
// Get the MLIR type wrapping the LLVM i8* type.
LLVM::LLVMType getVoidPtrType() const {
- return LLVM::LLVMType::get(dialect.getContext(),
- llvm::Type::getInt8PtrTy(getContext()));
+ return LLVM::LLVMType::getInt8PtrTy(&dialect);
}
// Create an LLVM IR pseudo-operation defining the given index constant.
cumulativeSize)
.getResult(0);
auto structElementType = lowering.convertType(elementType);
- auto elementPtrType = LLVM::LLVMType::get(
- op->getContext(), structElementType.cast<LLVM::LLVMType>()
- .getUnderlyingType()
- ->getPointerTo());
+ auto elementPtrType =
+ structElementType.cast<LLVM::LLVMType>().getPointerTo();
allocated = rewriter.create<LLVM::BitcastOp>(op->getLoc(), elementPtrType,
ArrayRef<Value *>(allocated));
op->getFunction()->getModule()->getFunctions().push_back(freeFunc);
}
- auto *type =
- operands[0]->getType().cast<LLVM::LLVMType>().getUnderlyingType();
- auto hasStaticShape = type->isPointerTy();
- Type elementPtrType =
- (hasStaticShape)
- ? rewriter.getType<LLVM::LLVMType>(type)
- : rewriter.getType<LLVM::LLVMType>(
- cast<llvm::StructType>(type)->getStructElementType(0));
+ auto type = operands[0]->getType().cast<LLVM::LLVMType>();
+ auto hasStaticShape = type.getUnderlyingType()->isPointerTy();
+ Type elementPtrType = hasStaticShape ? type : type.getStructElementType(0);
Value *bufferPtr = extractMemRefElementPtr(
rewriter, op->getLoc(), operands[0], elementPtrType, hasStaticShape);
Value *casted = rewriter.create<LLVM::BitcastOp>(
return {};
}
-static llvm::Type *unwrapType(Type type) {
- return type.cast<LLVM::LLVMType>().getUnderlyingType();
-}
-
// Create an LLVM IR structure type if there is more than one result.
Type LLVMLowering::packFunctionResults(ArrayRef<Type> types) {
assert(!types.empty() && "expected non-empty list of type");
if (types.size() == 1)
return convertType(types.front());
- SmallVector<llvm::Type *, 8> resultTypes;
+ SmallVector<LLVM::LLVMType, 8> resultTypes;
resultTypes.reserve(types.size());
for (auto t : types) {
- Type converted = convertType(t);
+ auto converted = convertType(t).dyn_cast<LLVM::LLVMType>();
if (!converted)
return {};
- resultTypes.push_back(unwrapType(converted));
+ resultTypes.push_back(converted);
}
- return LLVM::LLVMType::get(
- llvmDialect->getContext(),
- llvm::StructType::get(llvmDialect->getLLVMContext(), resultTypes));
+ return LLVM::LLVMType::getStructTy(llvmDialect, resultTypes);
}
// Convert function signatures using the stored LLVM IR module.
using icmp = ValueBuilder<LLVM::ICmpOp>;
template <typename T>
-static llvm::Type *getPtrToElementType(T containerType,
- LLVMLowering &lowering) {
+static LLVMType getPtrToElementType(T containerType, LLVMLowering &lowering) {
return lowering.convertType(containerType.getElementType())
.template cast<LLVMType>()
- .getUnderlyingType()
- ->getPointerTo();
+ .getPointerTo();
}
// Convert the given type to the LLVM IR Dialect type. The following
// containing the respective dynamic values.
static Type convertLinalgType(Type t, LLVMLowering &lowering) {
auto *context = t.getContext();
- auto *int64Ty = lowering.convertType(IntegerType::get(64, context))
- .cast<LLVM::LLVMType>()
- .getUnderlyingType();
+ auto int64Ty = lowering.convertType(IntegerType::get(64, context))
+ .cast<LLVM::LLVMType>();
// A buffer descriptor contains the pointer to a flat region of storage and
// the size of the region.
// int64_t size;
// };
if (auto bufferType = t.dyn_cast<BufferType>()) {
- auto *ptrTy = getPtrToElementType(bufferType, lowering);
- auto *structTy = llvm::StructType::get(ptrTy, int64Ty);
- return LLVMType::get(context, structTy);
+ auto ptrTy = getPtrToElementType(bufferType, lowering);
+ return LLVMType::getStructTy(ptrTy, int64Ty);
}
// Range descriptor contains the range bounds and the step as 64-bit integers.
// int64_t max;
// int64_t step;
// };
- if (t.isa<RangeType>()) {
- auto *structTy = llvm::StructType::get(int64Ty, int64Ty, int64Ty);
- return LLVMType::get(context, structTy);
- }
+ if (t.isa<RangeType>())
+ return LLVMType::getStructTy(int64Ty, int64Ty, int64Ty);
// View descriptor contains the pointer to the data buffer, followed by a
// 64-bit integer containing the distance between the beginning of the buffer
// int64_t strides[Rank];
// };
if (auto viewType = t.dyn_cast<ViewType>()) {
- auto *ptrTy = getPtrToElementType(viewType, lowering);
- auto *arrayTy = llvm::ArrayType::get(int64Ty, viewType.getRank());
- auto *structTy = llvm::StructType::get(ptrTy, int64Ty, arrayTy, arrayTy);
- return LLVMType::get(context, structTy);
+ auto ptrTy = getPtrToElementType(viewType, lowering);
+ auto arrayTy = LLVMType::getArrayTy(int64Ty, viewType.getRank());
+ return LLVMType::getStructTy(ptrTy, int64Ty, arrayTy, arrayTy);
}
return Type();
void rewrite(Operation *op, ArrayRef<Value *> operands,
PatternRewriter &rewriter) const override {
auto indexType = IndexType::get(op->getContext());
- auto voidPtrTy = LLVM::LLVMType::get(
- op->getContext(),
- llvm::IntegerType::get(lowering.getLLVMContext(), 8)->getPointerTo());
+ auto voidPtrTy =
+ LLVM::LLVMType::getInt8Ty(lowering.getDialect()).getPointerTo();
auto int64Ty = lowering.convertType(operands[0]->getType());
// Insert the `malloc` declaration if it is not already present.
auto *module = op->getFunction()->getModule();
llvm::divideCeil(vectorType.getElementTypeBitWidth(), 8);
else
elementSize = llvm::divideCeil(elementType.getIntOrFloatBitWidth(), 8);
- auto elementPtrType = rewriter.getType<LLVMType>(getPtrToElementType(
- allocOp.getResult()->getType().cast<BufferType>(), lowering));
+ auto elementPtrType = getPtrToElementType(
+ allocOp.getResult()->getType().cast<BufferType>(), lowering);
auto bufferDescriptorType =
convertLinalgType(allocOp.getResult()->getType(), lowering);
void rewrite(Operation *op, ArrayRef<Value *> operands,
PatternRewriter &rewriter) const override {
- auto voidPtrTy = LLVM::LLVMType::get(
- op->getContext(),
- llvm::IntegerType::get(lowering.getLLVMContext(), 8)->getPointerTo());
+ auto voidPtrTy =
+ LLVM::LLVMType::getInt8Ty(lowering.getDialect()).getPointerTo();
// Insert the `free` declaration if it is not already present.
auto *module = op->getFunction()->getModule();
Function *freeFunc = module->getNamedFunction("free");
// Get MLIR types for extracting element pointer.
auto deallocOp = cast<BufferDeallocOp>(op);
- auto elementPtrTy = rewriter.getType<LLVMType>(getPtrToElementType(
- deallocOp.getOperand()->getType().cast<BufferType>(), lowering));
+ auto elementPtrTy = getPtrToElementType(
+ deallocOp.getOperand()->getType().cast<BufferType>(), lowering);
// Emit MLIR for buffer_dealloc.
edsc::ScopedContext context(rewriter, op->getLoc());
ArrayRef<Value *> indices,
PatternRewriter &rewriter) const {
auto loadOp = cast<Op>(op);
- auto elementTy = rewriter.getType<LLVMType>(
- getPtrToElementType(loadOp.getViewType(), lowering));
+ auto elementTy = getPtrToElementType(loadOp.getViewType(), lowering);
auto int64Ty = lowering.convertType(rewriter.getIntegerType(64));
auto pos = [&rewriter](ArrayRef<int> values) {
return positionAttr(rewriter, values);
// Helper function to obtain the ptr of the given `view`.
auto getViewPtr = [pos, &rewriter, this](ViewType type,
Value *view) -> Value * {
- auto elementPtrTy =
- rewriter.getType<LLVMType>(getPtrToElementType(type, lowering));
+ auto elementPtrTy = getPtrToElementType(type, lowering);
return extractvalue(elementPtrTy, view, pos(0));
};
PatternRewriter &rewriter) const override {
auto viewOp = cast<ViewOp>(op);
auto viewDescriptorTy = convertLinalgType(viewOp.getViewType(), lowering);
- auto elementTy = rewriter.getType<LLVMType>(
- getPtrToElementType(viewOp.getViewType(), lowering));
+ auto elementTy = getPtrToElementType(viewOp.getViewType(), lowering);
auto int64Ty = lowering.convertType(rewriter.getIntegerType(64));
auto pos = [&rewriter](ArrayRef<int> values) {