unsigned getPointerBitwidth(unsigned addressSpace = 0);
protected:
- /// LLVM IR module used to parse/create types.
- llvm::Module *module;
+ /// Pointer to the LLVM dialect.
LLVM::LLVMDialect *llvmDialect;
private:
/// Returns the LLVM IR context.
llvm::LLVMContext &getContext() const;
- /// Returns the LLVM IR module associated with the LLVM dialect.
- llvm::Module &getModule() const;
-
/// Gets the MLIR type wrapping the LLVM integer type whose bit width is
/// defined by the used type converter.
LLVM::LLVMType getIndexType() const;
ConversionPatternRewriter &rewriter) const;
Value getDataPtr(Location loc, MemRefType type, Value memRefDesc,
- ValueRange indices, ConversionPatternRewriter &rewriter,
- llvm::Module &module) const;
+ ValueRange indices,
+ ConversionPatternRewriter &rewriter) const;
/// Returns the type of a pointer to an element of the memref.
Type getElementPtrType(MemRefType type) const;
private:
LLVM::LLVMDialect *getLLVMDialect() { return llvmDialect; }
- llvm::LLVMContext &getLLVMContext() {
- return getLLVMDialect()->getLLVMContext();
- }
-
void initializeCachedTypes() {
- const llvm::Module &module = llvmDialect->getLLVMModule();
llvmVoidType = LLVM::LLVMType::getVoidTy(llvmDialect);
llvmPointerType = LLVM::LLVMType::getInt8PtrTy(llvmDialect);
llvmPointerPointerType = llvmPointerType.getPointerTo();
llvmInt32Type = LLVM::LLVMType::getInt32Ty(llvmDialect);
llvmInt64Type = LLVM::LLVMType::getInt64Ty(llvmDialect);
llvmIntPtrType = LLVM::LLVMType::getIntNTy(
- llvmDialect, module.getDataLayout().getPointerSizeInBits());
+ llvmDialect, llvmDialect->getDataLayout().getPointerSizeInBits());
}
LLVM::LLVMType getVoidType() { return llvmVoidType; }
LLVM::LLVMType getInt64Type() { return llvmInt64Type; }
LLVM::LLVMType getIntPtrType() {
- const llvm::Module &module = getLLVMDialect()->getLLVMModule();
return LLVM::LLVMType::getIntNTy(
- getLLVMDialect(), module.getDataLayout().getPointerSizeInBits());
+ getLLVMDialect(),
+ getLLVMDialect()->getDataLayout().getPointerSizeInBits());
}
// Allocate a void pointer on the stack.
: llvmDialect(ctx->getRegisteredDialect<LLVM::LLVMDialect>()),
options(options) {
assert(llvmDialect && "LLVM IR dialect is not registered");
- module = &llvmDialect->getLLVMModule();
if (options.indexBitwidth == kDeriveIndexBitwidthFromDataLayout)
this->options.indexBitwidth =
- module->getDataLayout().getPointerSizeInBits();
+ llvmDialect->getDataLayout().getPointerSizeInBits();
// Register conversions for the standard types.
addConversion([&](ComplexType type) { return convertComplexType(type); });
/// Get the LLVM context.
llvm::LLVMContext &LLVMTypeConverter::getLLVMContext() {
- return module->getContext();
+ return llvmDialect->getLLVMContext();
}
LLVM::LLVMType LLVMTypeConverter::getIndexType() {
}
unsigned LLVMTypeConverter::getPointerBitwidth(unsigned addressSpace) {
- return module->getDataLayout().getPointerSizeInBits(addressSpace);
+ return llvmDialect->getDataLayout().getPointerSizeInBits(addressSpace);
}
Type LLVMTypeConverter::convertIndexType(IndexType type) {
return typeConverter.getLLVMContext();
}
-llvm::Module &ConvertToLLVMPattern::getModule() const {
- return getDialect().getLLVMModule();
-}
-
LLVM::LLVMType ConvertToLLVMPattern::getIndexType() const {
return typeConverter.getIndexType();
}
return rewriter.create<LLVM::GEPOp>(loc, elementTypePtr, base, offsetValue);
}
-Value ConvertToLLVMPattern::getDataPtr(Location loc, MemRefType type,
- Value memRefDesc, ValueRange indices,
- ConversionPatternRewriter &rewriter,
- llvm::Module &module) const {
+Value ConvertToLLVMPattern::getDataPtr(
+ Location loc, MemRefType type, Value memRefDesc, ValueRange indices,
+ ConversionPatternRewriter &rewriter) const {
LLVM::LLVMType ptrType = MemRefDescriptor(memRefDesc).getElementType();
int64_t offset;
SmallVector<int64_t, 4> strides;
auto type = loadOp.getMemRefType();
Value dataPtr = getDataPtr(op->getLoc(), type, transformed.memref(),
- transformed.indices(), rewriter, getModule());
+ transformed.indices(), rewriter);
rewriter.replaceOpWithNewOp<LLVM::LoadOp>(op, dataPtr);
return success();
}
StoreOp::Adaptor transformed(operands);
Value dataPtr = getDataPtr(op->getLoc(), type, transformed.memref(),
- transformed.indices(), rewriter, getModule());
+ transformed.indices(), rewriter);
rewriter.replaceOpWithNewOp<LLVM::StoreOp>(op, transformed.value(),
dataPtr);
return success();
auto type = prefetchOp.getMemRefType();
Value dataPtr = getDataPtr(op->getLoc(), type, transformed.memref(),
- transformed.indices(), rewriter, getModule());
+ transformed.indices(), rewriter);
// Replace with llvm.prefetch.
auto llvmI32Type = typeConverter.convertType(rewriter.getIntegerType(32));
auto resultType = adaptor.value().getType();
auto memRefType = atomicOp.getMemRefType();
auto dataPtr = getDataPtr(op->getLoc(), memRefType, adaptor.memref(),
- adaptor.indices(), rewriter, getModule());
+ adaptor.indices(), rewriter);
rewriter.replaceOpWithNewOp<LLVM::AtomicRMWOp>(
op, resultType, *maybeKind, dataPtr, adaptor.value(),
LLVM::AtomicOrdering::acq_rel);
rewriter.setInsertionPointToEnd(initBlock);
auto memRefType = atomicOp.memref().getType().cast<MemRefType>();
auto dataPtr = getDataPtr(loc, memRefType, adaptor.memref(),
- adaptor.indices(), rewriter, getModule());
+ adaptor.indices(), rewriter);
Value init = rewriter.create<LLVM::LoadOp>(loc, dataPtr);
rewriter.create<LLVM::BrOp>(loc, init, loopBlock);