auto memRefType = (*op->operand_type_begin()).cast<MemRefType>();
auto memRefShape = memRefType.getShape();
auto loc = op->getLoc();
- auto *llvmDialect =
- op->getContext()->getRegisteredDialect<LLVM::LLVMDialect>();
- assert(llvmDialect && "expected llvm dialect to be registered");
ModuleOp parentModule = op->getParentOfType<ModuleOp>();
// Get a symbol reference to the printf function, inserting it if necessary.
- auto printfRef = getOrInsertPrintf(rewriter, parentModule, llvmDialect);
+ auto printfRef = getOrInsertPrintf(rewriter, parentModule);
Value formatSpecifierCst = getOrCreateGlobalString(
- loc, rewriter, "frmt_spec", StringRef("%f \0", 4), parentModule,
- llvmDialect);
+ loc, rewriter, "frmt_spec", StringRef("%f \0", 4), parentModule);
Value newLineCst = getOrCreateGlobalString(
- loc, rewriter, "nl", StringRef("\n\0", 2), parentModule, llvmDialect);
+ loc, rewriter, "nl", StringRef("\n\0", 2), parentModule);
// Create a loop for each of the dimensions within the shape.
SmallVector<Value, 4> loopIvs;
/// Return a symbol reference to the printf function, inserting it into the
/// module if necessary.
static FlatSymbolRefAttr getOrInsertPrintf(PatternRewriter &rewriter,
- ModuleOp module,
- LLVM::LLVMDialect *llvmDialect) {
+ ModuleOp module) {
auto *context = module.getContext();
if (module.lookupSymbol<LLVM::LLVMFuncOp>("printf"))
return SymbolRefAttr::get("printf", context);
// Create a function declaration for printf, the signature is:
// * `i32 (i8*, ...)`
- auto llvmI32Ty = LLVM::LLVMType::getInt32Ty(llvmDialect);
- auto llvmI8PtrTy = LLVM::LLVMType::getInt8PtrTy(llvmDialect);
+ auto llvmI32Ty = LLVM::LLVMType::getInt32Ty(context);
+ auto llvmI8PtrTy = LLVM::LLVMType::getInt8PtrTy(context);
auto llvmFnType = LLVM::LLVMType::getFunctionTy(llvmI32Ty, llvmI8PtrTy,
/*isVarArg=*/true);
/// name, creating the string if necessary.
static Value getOrCreateGlobalString(Location loc, OpBuilder &builder,
StringRef name, StringRef value,
- ModuleOp module,
- LLVM::LLVMDialect *llvmDialect) {
+ ModuleOp module) {
// Create the global at the entry of the module.
LLVM::GlobalOp global;
if (!(global = module.lookupSymbol<LLVM::GlobalOp>(name))) {
OpBuilder::InsertionGuard insertGuard(builder);
builder.setInsertionPointToStart(module.getBody());
auto type = LLVM::LLVMType::getArrayTy(
- LLVM::LLVMType::getInt8Ty(llvmDialect), value.size());
+ LLVM::LLVMType::getInt8Ty(builder.getContext()), value.size());
global = builder.create<LLVM::GlobalOp>(loc, type, /*isConstant=*/true,
LLVM::Linkage::Internal, name,
builder.getStringAttr(value));
// Get the pointer to the first character in the global string.
Value globalPtr = builder.create<LLVM::AddressOfOp>(loc, global);
Value cst0 = builder.create<LLVM::ConstantOp>(
- loc, LLVM::LLVMType::getInt64Ty(llvmDialect),
+ loc, LLVM::LLVMType::getInt64Ty(builder.getContext()),
builder.getIntegerAttr(builder.getIndexType(), 0));
return builder.create<LLVM::GEPOp>(
- loc, LLVM::LLVMType::getInt8PtrTy(llvmDialect), globalPtr,
+ loc, LLVM::LLVMType::getInt8PtrTy(builder.getContext()), globalPtr,
ArrayRef<Value>({cst0, cst0}));
}
};
auto memRefType = (*op->operand_type_begin()).cast<MemRefType>();
auto memRefShape = memRefType.getShape();
auto loc = op->getLoc();
- auto *llvmDialect =
- op->getContext()->getRegisteredDialect<LLVM::LLVMDialect>();
- assert(llvmDialect && "expected llvm dialect to be registered");
ModuleOp parentModule = op->getParentOfType<ModuleOp>();
// Get a symbol reference to the printf function, inserting it if necessary.
- auto printfRef = getOrInsertPrintf(rewriter, parentModule, llvmDialect);
+ auto printfRef = getOrInsertPrintf(rewriter, parentModule);
Value formatSpecifierCst = getOrCreateGlobalString(
- loc, rewriter, "frmt_spec", StringRef("%f \0", 4), parentModule,
- llvmDialect);
+ loc, rewriter, "frmt_spec", StringRef("%f \0", 4), parentModule);
Value newLineCst = getOrCreateGlobalString(
- loc, rewriter, "nl", StringRef("\n\0", 2), parentModule, llvmDialect);
+ loc, rewriter, "nl", StringRef("\n\0", 2), parentModule);
// Create a loop for each of the dimensions within the shape.
SmallVector<Value, 4> loopIvs;
/// Return a symbol reference to the printf function, inserting it into the
/// module if necessary.
static FlatSymbolRefAttr getOrInsertPrintf(PatternRewriter &rewriter,
- ModuleOp module,
- LLVM::LLVMDialect *llvmDialect) {
+ ModuleOp module) {
auto *context = module.getContext();
if (module.lookupSymbol<LLVM::LLVMFuncOp>("printf"))
return SymbolRefAttr::get("printf", context);
// Create a function declaration for printf, the signature is:
// * `i32 (i8*, ...)`
- auto llvmI32Ty = LLVM::LLVMType::getInt32Ty(llvmDialect);
- auto llvmI8PtrTy = LLVM::LLVMType::getInt8PtrTy(llvmDialect);
+ auto llvmI32Ty = LLVM::LLVMType::getInt32Ty(context);
+ auto llvmI8PtrTy = LLVM::LLVMType::getInt8PtrTy(context);
auto llvmFnType = LLVM::LLVMType::getFunctionTy(llvmI32Ty, llvmI8PtrTy,
/*isVarArg=*/true);
/// name, creating the string if necessary.
static Value getOrCreateGlobalString(Location loc, OpBuilder &builder,
StringRef name, StringRef value,
- ModuleOp module,
- LLVM::LLVMDialect *llvmDialect) {
+ ModuleOp module) {
// Create the global at the entry of the module.
LLVM::GlobalOp global;
if (!(global = module.lookupSymbol<LLVM::GlobalOp>(name))) {
OpBuilder::InsertionGuard insertGuard(builder);
builder.setInsertionPointToStart(module.getBody());
auto type = LLVM::LLVMType::getArrayTy(
- LLVM::LLVMType::getInt8Ty(llvmDialect), value.size());
+ LLVM::LLVMType::getInt8Ty(builder.getContext()), value.size());
global = builder.create<LLVM::GlobalOp>(loc, type, /*isConstant=*/true,
LLVM::Linkage::Internal, name,
builder.getStringAttr(value));
// Get the pointer to the first character in the global string.
Value globalPtr = builder.create<LLVM::AddressOfOp>(loc, global);
Value cst0 = builder.create<LLVM::ConstantOp>(
- loc, LLVM::LLVMType::getInt64Ty(llvmDialect),
+ loc, LLVM::LLVMType::getInt64Ty(builder.getContext()),
builder.getIntegerAttr(builder.getIndexType(), 0));
return builder.create<LLVM::GEPOp>(
- loc, LLVM::LLVMType::getInt8PtrTy(llvmDialect), globalPtr,
+ loc, LLVM::LLVMType::getInt8PtrTy(builder.getContext()), globalPtr,
ArrayRef<Value>({cst0, cst0}));
}
};
/// global and use it to compute the address of the first character in the
/// string (operations inserted at the builder insertion point).
Value createGlobalString(Location loc, OpBuilder &builder, StringRef name,
- StringRef value, LLVM::Linkage linkage,
- LLVM::LLVMDialect *llvmDialect);
+ StringRef value, LLVM::Linkage linkage);
/// LLVM requires some operations to be inside of a Module operation. This
/// function confirms that the Operation has the desired properties.
"$_self.cast<::mlir::LLVM::LLVMType>().isIntegerTy(" # width # ")">]>,
"LLVM dialect " # width # "-bit integer">,
BuildableType<
- "::mlir::LLVM::LLVMType::getIntNTy("
- "$_builder.getContext()->getRegisteredDialect<LLVM::LLVMDialect>(),"
+ "::mlir::LLVM::LLVMType::getIntNTy($_builder.getContext(),"
# width # ")">;
def LLVMI1 : LLVMI<1>;
let builders = [OpBuilder<
"OpBuilder &b, OperationState &result, ICmpPredicate predicate, Value lhs, "
"Value rhs", [{
- LLVMDialect *dialect = &lhs.getType().cast<LLVMType>().getDialect();
- build(b, result, LLVMType::getInt1Ty(dialect),
+ build(b, result, LLVMType::getInt1Ty(lhs.getType().getContext()),
b.getI64IntegerAttr(static_cast<int64_t>(predicate)), lhs, rhs);
}]>];
let parser = [{ return parseCmpOp<ICmpPredicate>(parser, result); }];
let builders = [OpBuilder<
"OpBuilder &b, OperationState &result, FCmpPredicate predicate, Value lhs, "
"Value rhs", [{
- LLVMDialect *dialect = &lhs.getType().cast<LLVMType>().getDialect();
- build(b, result, LLVMType::getInt1Ty(dialect),
+ build(b, result, LLVMType::getInt1Ty(lhs.getType().getContext()),
b.getI64IntegerAttr(static_cast<int64_t>(predicate)), lhs, rhs);
}]>];
let parser = [{ return parseCmpOp<FCmpPredicate>(parser, result); }];
bool isStructTy();
/// Utilities used to generate floating point types.
- static LLVMType getDoubleTy(LLVMDialect *dialect);
- static LLVMType getFloatTy(LLVMDialect *dialect);
- static LLVMType getBFloatTy(LLVMDialect *dialect);
- static LLVMType getHalfTy(LLVMDialect *dialect);
- static LLVMType getFP128Ty(LLVMDialect *dialect);
- static LLVMType getX86_FP80Ty(LLVMDialect *dialect);
+ static LLVMType getDoubleTy(MLIRContext *context);
+ static LLVMType getFloatTy(MLIRContext *context);
+ static LLVMType getBFloatTy(MLIRContext *context);
+ static LLVMType getHalfTy(MLIRContext *context);
+ static LLVMType getFP128Ty(MLIRContext *context);
+ static LLVMType getX86_FP80Ty(MLIRContext *context);
/// Utilities used to generate integer types.
- static LLVMType getIntNTy(LLVMDialect *dialect, unsigned numBits);
- static LLVMType getInt1Ty(LLVMDialect *dialect) {
- return getIntNTy(dialect, /*numBits=*/1);
+ static LLVMType getIntNTy(MLIRContext *context, unsigned numBits);
+ static LLVMType getInt1Ty(MLIRContext *context) {
+ return getIntNTy(context, /*numBits=*/1);
}
- static LLVMType getInt8Ty(LLVMDialect *dialect) {
- return getIntNTy(dialect, /*numBits=*/8);
+ static LLVMType getInt8Ty(MLIRContext *context) {
+ return getIntNTy(context, /*numBits=*/8);
}
- static LLVMType getInt8PtrTy(LLVMDialect *dialect) {
- return getInt8Ty(dialect).getPointerTo();
+ static LLVMType getInt8PtrTy(MLIRContext *context) {
+ return getInt8Ty(context).getPointerTo();
}
- static LLVMType getInt16Ty(LLVMDialect *dialect) {
- return getIntNTy(dialect, /*numBits=*/16);
+ static LLVMType getInt16Ty(MLIRContext *context) {
+ return getIntNTy(context, /*numBits=*/16);
}
- static LLVMType getInt32Ty(LLVMDialect *dialect) {
- return getIntNTy(dialect, /*numBits=*/32);
+ static LLVMType getInt32Ty(MLIRContext *context) {
+ return getIntNTy(context, /*numBits=*/32);
}
- static LLVMType getInt64Ty(LLVMDialect *dialect) {
- return getIntNTy(dialect, /*numBits=*/64);
+ static LLVMType getInt64Ty(MLIRContext *context) {
+ return getIntNTy(context, /*numBits=*/64);
}
/// Utilities used to generate other miscellaneous types.
static LLVMType getFunctionTy(LLVMType result, bool isVarArg) {
return getFunctionTy(result, llvm::None, isVarArg);
}
- static LLVMType getStructTy(LLVMDialect *dialect, ArrayRef<LLVMType> elements,
+ static LLVMType getStructTy(MLIRContext *context, ArrayRef<LLVMType> elements,
bool isPacked = false);
- static LLVMType getStructTy(LLVMDialect *dialect, bool isPacked = false) {
- return getStructTy(dialect, llvm::None, isPacked);
+ static LLVMType getStructTy(MLIRContext *context, bool isPacked = false) {
+ return getStructTy(context, llvm::None, isPacked);
}
template <typename... Args>
static typename std::enable_if<llvm::are_base_of<LLVMType, Args...>::value,
LLVMType>::type
getStructTy(LLVMType elt1, Args... elts) {
SmallVector<LLVMType, 8> fields({elt1, elts...});
- return getStructTy(&elt1.getDialect(), fields);
+ return getStructTy(elt1.getContext(), fields);
}
static LLVMType getVectorTy(LLVMType elementType, unsigned numElements);
/// Void type utilities.
- static LLVMType getVoidTy(LLVMDialect *dialect);
+ static LLVMType getVoidTy(MLIRContext *context);
bool isVoidTy();
// Creation and setting of LLVM's identified struct types
- static LLVMType createStructTy(LLVMDialect *dialect,
+ static LLVMType createStructTy(MLIRContext *context,
ArrayRef<LLVMType> elements,
Optional<StringRef> name,
bool isPacked = false);
- static LLVMType createStructTy(LLVMDialect *dialect,
+ static LLVMType createStructTy(MLIRContext *context,
Optional<StringRef> name) {
- return createStructTy(dialect, llvm::None, name);
+ return createStructTy(context, llvm::None, name);
}
static LLVMType createStructTy(ArrayRef<LLVMType> elements,
assert(!elements.empty() &&
"This method may not be invoked with an empty list");
LLVMType ele0 = elements.front();
- return createStructTy(&ele0.getDialect(), elements, name, isPacked);
+ return createStructTy(ele0.getContext(), elements, name, isPacked);
}
template <typename... Args>
createStructTy(StringRef name, LLVMType elt1, Args... elts) {
SmallVector<LLVMType, 8> fields({elt1, elts...});
Optional<StringRef> opt_name(name);
- return createStructTy(&elt1.getDialect(), fields, opt_name);
+ return createStructTy(elt1.getContext(), fields, opt_name);
}
static LLVMType setStructTyBody(LLVMType structType,
LLVM::LLVMDialect *getLLVMDialect() { return llvmDialect; }
void initializeCachedTypes() {
- llvmVoidType = LLVM::LLVMType::getVoidTy(llvmDialect);
- llvmPointerType = LLVM::LLVMType::getInt8PtrTy(llvmDialect);
+ llvmVoidType = LLVM::LLVMType::getVoidTy(&getContext());
+ llvmPointerType = LLVM::LLVMType::getInt8PtrTy(&getContext());
llvmPointerPointerType = llvmPointerType.getPointerTo();
- llvmInt8Type = LLVM::LLVMType::getInt8Ty(llvmDialect);
- llvmInt32Type = LLVM::LLVMType::getInt32Ty(llvmDialect);
- llvmInt64Type = LLVM::LLVMType::getInt64Ty(llvmDialect);
+ llvmInt8Type = LLVM::LLVMType::getInt8Ty(&getContext());
+ llvmInt32Type = LLVM::LLVMType::getInt32Ty(&getContext());
+ llvmInt64Type = LLVM::LLVMType::getInt64Ty(&getContext());
llvmIntPtrType = LLVM::LLVMType::getIntNTy(
- llvmDialect, llvmDialect->getDataLayout().getPointerSizeInBits());
+ &getContext(), llvmDialect->getDataLayout().getPointerSizeInBits());
}
LLVM::LLVMType getVoidType() { return llvmVoidType; }
LLVM::LLVMType getIntPtrType() {
return LLVM::LLVMType::getIntNTy(
- getLLVMDialect(),
+ &getContext(),
getLLVMDialect()->getDataLayout().getPointerSizeInBits());
}
std::string(llvm::formatv("{0}_{1}_kernel_name", moduleName, name));
return LLVM::createGlobalString(
loc, builder, globalName, StringRef(kernelName.data(), kernelName.size()),
- LLVM::Linkage::Internal, llvmDialect);
+ LLVM::Linkage::Internal);
}
// Emits LLVM IR to launch a kernel function. Expects the module that contains
SmallString<128> nameBuffer(kernelModule.getName());
nameBuffer.append(kGpuBinaryStorageSuffix);
- Value data = LLVM::createGlobalString(
- loc, builder, nameBuffer.str(), binaryAttr.getValue(),
- LLVM::Linkage::Internal, getLLVMDialect());
+ Value data =
+ LLVM::createGlobalString(loc, builder, nameBuffer.str(),
+ binaryAttr.getValue(), LLVM::Linkage::Internal);
// Emit the load module call to load the module data. Error checking is done
// in the called helper function.
// Rewrite workgroup memory attributions to addresses of global buffers.
rewriter.setInsertionPointToStart(&gpuFuncOp.front());
unsigned numProperArguments = gpuFuncOp.getNumArguments();
- auto i32Type = LLVM::LLVMType::getInt32Ty(typeConverter.getDialect());
+ auto i32Type = LLVM::LLVMType::getInt32Ty(rewriter.getContext());
Value zero = nullptr;
if (!workgroupBuffers.empty())
// Rewrite private memory attributions to alloca'ed buffers.
unsigned numWorkgroupAttributions =
gpuFuncOp.getNumWorkgroupAttributions();
- auto int64Ty = LLVM::LLVMType::getInt64Ty(typeConverter.getDialect());
+ auto int64Ty = LLVM::LLVMType::getInt64Ty(rewriter.getContext());
for (auto en : llvm::enumerate(gpuFuncOp.getPrivateAttributions())) {
Value attribution = en.value();
auto type = attribution.getType().cast<MemRefType>();
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto loc = op->getLoc();
- auto dialect = typeConverter.getDialect();
+ MLIRContext *context = rewriter.getContext();
Value newOp;
switch (dimensionToIndex(cast<Op>(op))) {
case X:
- newOp = rewriter.create<XOp>(loc, LLVM::LLVMType::getInt32Ty(dialect));
+ newOp = rewriter.create<XOp>(loc, LLVM::LLVMType::getInt32Ty(context));
break;
case Y:
- newOp = rewriter.create<YOp>(loc, LLVM::LLVMType::getInt32Ty(dialect));
+ newOp = rewriter.create<YOp>(loc, LLVM::LLVMType::getInt32Ty(context));
break;
case Z:
- newOp = rewriter.create<ZOp>(loc, LLVM::LLVMType::getInt32Ty(dialect));
+ newOp = rewriter.create<ZOp>(loc, LLVM::LLVMType::getInt32Ty(context));
break;
default:
return failure();
if (indexBitwidth > 32) {
newOp = rewriter.create<LLVM::SExtOp>(
- loc, LLVM::LLVMType::getIntNTy(dialect, indexBitwidth), newOp);
+ loc, LLVM::LLVMType::getIntNTy(context, indexBitwidth), newOp);
} else if (indexBitwidth < 32) {
newOp = rewriter.create<LLVM::TruncOp>(
- loc, LLVM::LLVMType::getIntNTy(dialect, indexBitwidth), newOp);
+ loc, LLVM::LLVMType::getIntNTy(context, indexBitwidth), newOp);
}
rewriter.replaceOp(op, {newOp});
return operand;
return rewriter.create<LLVM::FPExtOp>(
- operand.getLoc(), LLVM::LLVMType::getFloatTy(&type.getDialect()),
+ operand.getLoc(), LLVM::LLVMType::getFloatTy(rewriter.getContext()),
operand);
}
Location loc = op->getLoc();
gpu::ShuffleOpAdaptor adaptor(operands);
- auto dialect = typeConverter.getDialect();
auto valueTy = adaptor.value().getType().cast<LLVM::LLVMType>();
- auto int32Type = LLVM::LLVMType::getInt32Ty(dialect);
- auto predTy = LLVM::LLVMType::getInt1Ty(dialect);
- auto resultTy = LLVM::LLVMType::getStructTy(dialect, {valueTy, predTy});
+ auto int32Type = LLVM::LLVMType::getInt32Ty(rewriter.getContext());
+ auto predTy = LLVM::LLVMType::getInt1Ty(rewriter.getContext());
+ auto resultTy =
+ LLVM::LLVMType::getStructTy(rewriter.getContext(), {valueTy, predTy});
Value one = rewriter.create<LLVM::ConstantOp>(
loc, int32Type, rewriter.getI32IntegerAttr(1));
: public ConvertVulkanLaunchFuncToVulkanCallsBase<
VulkanLaunchFuncToVulkanCallsPass> {
private:
- LLVM::LLVMDialect *getLLVMDialect() { return llvmDialect; }
-
void initializeCachedTypes() {
- llvmDialect = getContext().getRegisteredDialect<LLVM::LLVMDialect>();
- llvmFloatType = LLVM::LLVMType::getFloatTy(llvmDialect);
- llvmVoidType = LLVM::LLVMType::getVoidTy(llvmDialect);
- llvmPointerType = LLVM::LLVMType::getInt8PtrTy(llvmDialect);
- llvmInt32Type = LLVM::LLVMType::getInt32Ty(llvmDialect);
- llvmInt64Type = LLVM::LLVMType::getInt64Ty(llvmDialect);
+ llvmFloatType = LLVM::LLVMType::getFloatTy(&getContext());
+ llvmVoidType = LLVM::LLVMType::getVoidTy(&getContext());
+ llvmPointerType = LLVM::LLVMType::getInt8PtrTy(&getContext());
+ llvmInt32Type = LLVM::LLVMType::getInt32Ty(&getContext());
+ llvmInt64Type = LLVM::LLVMType::getInt64Ty(&getContext());
}
LLVM::LLVMType getMemRefType(uint32_t rank, LLVM::LLVMType elemenType) {
// `!llvm<"{ `element-type`*, `element-type`*, i64,
// [`rank` x i64], [`rank` x i64]}">`.
return LLVM::LLVMType::getStructTy(
- llvmDialect,
+ &getContext(),
{llvmPtrToElementType, llvmPtrToElementType, getInt64Type(),
llvmArrayRankElementSizeType, llvmArrayRankElementSizeType});
}
void runOnOperation() override;
private:
- LLVM::LLVMDialect *llvmDialect;
LLVM::LLVMType llvmFloatType;
LLVM::LLVMType llvmVoidType;
LLVM::LLVMType llvmPointerType;
// int16_t and bitcast the descriptor.
if (type.isHalfTy()) {
auto memRefTy =
- getMemRefType(rank, LLVM::LLVMType::getInt16Ty(llvmDialect));
+ getMemRefType(rank, LLVM::LLVMType::getInt16Ty(&getContext()));
ptrToMemRefDescriptor = builder.create<LLVM::BitcastOp>(
loc, memRefTy.getPointerTo(), ptrToMemRefDescriptor);
}
}
for (unsigned i = 1; i <= 3; i++) {
- for (LLVM::LLVMType type : {LLVM::LLVMType::getFloatTy(llvmDialect),
- LLVM::LLVMType::getInt32Ty(llvmDialect),
- LLVM::LLVMType::getInt16Ty(llvmDialect),
- LLVM::LLVMType::getInt8Ty(llvmDialect),
- LLVM::LLVMType::getHalfTy(llvmDialect)}) {
+ for (LLVM::LLVMType type : {LLVM::LLVMType::getFloatTy(&getContext()),
+ LLVM::LLVMType::getInt32Ty(&getContext()),
+ LLVM::LLVMType::getInt16Ty(&getContext()),
+ LLVM::LLVMType::getInt8Ty(&getContext()),
+ LLVM::LLVMType::getHalfTy(&getContext())}) {
std::string fnName = "bindMemRef" + std::to_string(i) + "D" +
std::string(stringifyType(type));
if (type.isHalfTy())
- type = getMemRefType(i, LLVM::LLVMType::getInt16Ty(llvmDialect));
+ type = getMemRefType(i, LLVM::LLVMType::getInt16Ty(&getContext()));
if (!module.lookupSymbol(fnName)) {
auto fnType = LLVM::LLVMType::getFunctionTy(
getVoidType(),
std::string entryPointGlobalName = (name + "_spv_entry_point_name").str();
return LLVM::createGlobalString(loc, builder, entryPointGlobalName,
- shaderName, LLVM::Linkage::Internal,
- getLLVMDialect());
+ shaderName, LLVM::Linkage::Internal);
}
void VulkanLaunchFuncToVulkanCallsPass::translateVulkanLaunchCall(
// that data to runtime call.
Value ptrToSPIRVBinary = LLVM::createGlobalString(
loc, builder, kSPIRVBinary, spirvAttributes.first.getValue(),
- LLVM::Linkage::Internal, getLLVMDialect());
+ LLVM::Linkage::Internal);
// Create LLVM constant for the size of SPIR-V binary shader.
Value binarySize = builder.create<LLVM::ConstantOp>(
llvm::map_range(type.getElementTypes(), [&](Type elementType) {
return converter.convertType(elementType).cast<LLVM::LLVMType>();
}));
- return LLVM::LLVMType::getStructTy(converter.getDialect(), elementsVector,
+ return LLVM::LLVMType::getStructTy(type.getContext(), elementsVector,
/*isPacked=*/true);
}
/// Creates LLVM dialect constant with the given value.
static Value createI32ConstantOf(Location loc, PatternRewriter &rewriter,
- LLVMTypeConverter &converter, unsigned value) {
+ unsigned value) {
return rewriter.create<LLVM::ConstantOp>(
- loc, LLVM::LLVMType::getInt32Ty(converter.getDialect()),
+ loc, LLVM::LLVMType::getInt32Ty(rewriter.getContext()),
rewriter.getIntegerAttr(rewriter.getI32Type(), value));
}
return failure();
Location loc = varOp.getLoc();
- Value size = createI32ConstantOf(loc, rewriter, typeConverter, 1);
+ Value size = createI32ConstantOf(loc, rewriter, 1);
if (!init) {
rewriter.replaceOpWithNewOp<LLVM::AllocaOp>(varOp, dstType, size);
return success();
}
LLVM::LLVMType LLVMTypeConverter::getIndexType() {
- return LLVM::LLVMType::getIntNTy(llvmDialect, getIndexTypeBitwidth());
+ return LLVM::LLVMType::getIntNTy(&getContext(), getIndexTypeBitwidth());
}
unsigned LLVMTypeConverter::getPointerBitwidth(unsigned addressSpace) {
}
Type LLVMTypeConverter::convertIntegerType(IntegerType type) {
- return LLVM::LLVMType::getIntNTy(llvmDialect, type.getWidth());
+ return LLVM::LLVMType::getIntNTy(&getContext(), type.getWidth());
}
Type LLVMTypeConverter::convertFloatType(FloatType type) {
switch (type.getKind()) {
case mlir::StandardTypes::F32:
- return LLVM::LLVMType::getFloatTy(llvmDialect);
+ return LLVM::LLVMType::getFloatTy(&getContext());
case mlir::StandardTypes::F64:
- return LLVM::LLVMType::getDoubleTy(llvmDialect);
+ return LLVM::LLVMType::getDoubleTy(&getContext());
case mlir::StandardTypes::F16:
- return LLVM::LLVMType::getHalfTy(llvmDialect);
+ return LLVM::LLVMType::getHalfTy(&getContext());
case mlir::StandardTypes::BF16: {
- return LLVM::LLVMType::getBFloatTy(llvmDialect);
+ return LLVM::LLVMType::getBFloatTy(&getContext());
}
default:
llvm_unreachable("non-float type in convertFloatType");
static constexpr unsigned kImaginaryPosInComplexNumberStruct = 1;
Type LLVMTypeConverter::convertComplexType(ComplexType type) {
auto elementType = convertType(type.getElementType()).cast<LLVM::LLVMType>();
- return LLVM::LLVMType::getStructTy(llvmDialect, {elementType, elementType});
+ return LLVM::LLVMType::getStructTy(&getContext(), {elementType, elementType});
}
// Except for signatures, MLIR function types are converted into LLVM
/// In signatures, unranked MemRef descriptors are expanded into a pair "rank,
/// pointer to descriptor".
SmallVector<Type, 2> LLVMTypeConverter::convertUnrankedMemRefSignature() {
- return {getIndexType(), LLVM::LLVMType::getInt8PtrTy(llvmDialect)};
+ return {getIndexType(), LLVM::LLVMType::getInt8PtrTy(&getContext())};
}
// Function types are converted to LLVM Function types by recursively converting
// a struct.
LLVM::LLVMType resultType =
type.getNumResults() == 0
- ? LLVM::LLVMType::getVoidTy(llvmDialect)
+ ? LLVM::LLVMType::getVoidTy(&getContext())
: unwrap(packFunctionResults(type.getResults()));
if (!resultType)
return {};
LLVM::LLVMType resultType =
type.getNumResults() == 0
- ? LLVM::LLVMType::getVoidTy(llvmDialect)
+ ? LLVM::LLVMType::getVoidTy(&getContext())
: unwrap(packFunctionResults(type.getResults()));
if (!resultType)
return {};
Type LLVMTypeConverter::convertUnrankedMemRefType(UnrankedMemRefType type) {
auto rankTy = getIndexType();
- auto ptrTy = LLVM::LLVMType::getInt8PtrTy(llvmDialect);
+ auto ptrTy = LLVM::LLVMType::getInt8PtrTy(&getContext());
return LLVM::LLVMType::getStructTy(rankTy, ptrTy);
}
}
LLVM::LLVMType ConvertToLLVMPattern::getVoidType() const {
- return LLVM::LLVMType::getVoidTy(&getDialect());
+ return LLVM::LLVMType::getVoidTy(&typeConverter.getContext());
}
LLVM::LLVMType ConvertToLLVMPattern::getVoidPtrType() const {
- return LLVM::LLVMType::getInt8PtrTy(&getDialect());
+ return LLVM::LLVMType::getInt8PtrTy(&typeConverter.getContext());
}
Value ConvertToLLVMPattern::createIndexConstant(
unrankedMemrefs, sizes);
// Get frequently used types.
- auto voidType = LLVM::LLVMType::getVoidTy(typeConverter.getDialect());
- auto voidPtrType = LLVM::LLVMType::getInt8PtrTy(typeConverter.getDialect());
- auto i1Type = LLVM::LLVMType::getInt1Ty(typeConverter.getDialect());
+ MLIRContext *context = builder.getContext();
+ auto voidType = LLVM::LLVMType::getVoidTy(context);
+ auto voidPtrType = LLVM::LLVMType::getInt8PtrTy(context);
+ auto i1Type = LLVM::LLVMType::getInt1Ty(context);
LLVM::LLVMType indexType = typeConverter.getIndexType();
// Find the malloc and free, or declare them if necessary.
// Append the cmpxchg op to the end of the loop block.
auto successOrdering = LLVM::AtomicOrdering::acq_rel;
auto failureOrdering = LLVM::AtomicOrdering::monotonic;
- auto boolType = LLVM::LLVMType::getInt1Ty(&getDialect());
+ auto boolType = LLVM::LLVMType::getInt1Ty(rewriter.getContext());
auto pairType = LLVM::LLVMType::getStructTy(valueType, boolType);
auto cmpxchg = rewriter.create<LLVM::AtomicCmpXchgOp>(
loc, pairType, dataPtr, loopArgument, result, successOrdering,
resultTypes.push_back(converted);
}
- return LLVM::LLVMType::getStructTy(llvmDialect, resultTypes);
+ return LLVM::LLVMType::getStructTy(&getContext(), resultTypes);
}
Value LLVMTypeConverter::promoteOneMemRefDescriptor(Location loc, Value operand,
OpBuilder &builder) {
auto *context = builder.getContext();
- auto int64Ty = LLVM::LLVMType::getInt64Ty(getDialect());
+ auto int64Ty = LLVM::LLVMType::getInt64Ty(builder.getContext());
auto indexType = IndexType::get(context);
// Alloca with proper alignment. We do not expect optimizations of this
// alloca op and so we omit allocating at the entry block.
// Remaining extraction of element from 1-D LLVM vector
auto position = positionAttrs.back().cast<IntegerAttr>();
- auto i64Type = LLVM::LLVMType::getInt64Ty(typeConverter.getDialect());
+ auto i64Type = LLVM::LLVMType::getInt64Ty(rewriter.getContext());
auto constant = rewriter.create<LLVM::ConstantOp>(loc, i64Type, position);
extracted =
rewriter.create<LLVM::ExtractElementOp>(loc, extracted, constant);
}
// Insertion of an element into a 1-D LLVM vector.
- auto i64Type = LLVM::LLVMType::getInt64Ty(typeConverter.getDialect());
+ auto i64Type = LLVM::LLVMType::getInt64Ty(rewriter.getContext());
auto constant = rewriter.create<LLVM::ConstantOp>(loc, i64Type, position);
Value inserted = rewriter.create<LLVM::InsertElementOp>(
loc, typeConverter.convertType(oneDVectorType), extracted,
if (failed(successStrides) || !isContiguous)
return failure();
- auto int64Ty = LLVM::LLVMType::getInt64Ty(typeConverter.getDialect());
+ auto int64Ty = LLVM::LLVMType::getInt64Ty(rewriter.getContext());
// Create descriptor.
auto desc = MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy);
int64_t rank) const {
Location loc = op->getLoc();
if (rank == 0) {
- if (value.getType() ==
- LLVM::LLVMType::getInt1Ty(typeConverter.getDialect())) {
+ if (value.getType() == LLVM::LLVMType::getInt1Ty(rewriter.getContext())) {
// Convert i1 (bool) to i32 so we can use the print_i32 method.
// This avoids the need for a print_i1 method with an unclear ABI.
- auto i32Type = LLVM::LLVMType::getInt32Ty(typeConverter.getDialect());
+ auto i32Type = LLVM::LLVMType::getInt32Ty(rewriter.getContext());
auto trueVal = rewriter.create<ConstantOp>(
loc, i32Type, rewriter.getI32IntegerAttr(1));
auto falseVal = rewriter.create<ConstantOp>(
}
// Helper for printer method declaration (first hit) and lookup.
- static Operation *getPrint(Operation *op, LLVM::LLVMDialect *dialect,
- StringRef name, ArrayRef<LLVM::LLVMType> params) {
+ static Operation *getPrint(Operation *op, StringRef name,
+ ArrayRef<LLVM::LLVMType> params) {
auto module = op->getParentOfType<ModuleOp>();
auto func = module.lookupSymbol<LLVM::LLVMFuncOp>(name);
if (func)
OpBuilder moduleBuilder(module.getBodyRegion());
return moduleBuilder.create<LLVM::LLVMFuncOp>(
op->getLoc(), name,
- LLVM::LLVMType::getFunctionTy(LLVM::LLVMType::getVoidTy(dialect),
- params, /*isVarArg=*/false));
+ LLVM::LLVMType::getFunctionTy(
+ LLVM::LLVMType::getVoidTy(op->getContext()), params,
+ /*isVarArg=*/false));
}
// Helpers for method names.
Operation *getPrintI32(Operation *op) const {
- LLVM::LLVMDialect *dialect = typeConverter.getDialect();
- return getPrint(op, dialect, "print_i32",
- LLVM::LLVMType::getInt32Ty(dialect));
+ return getPrint(op, "print_i32",
+ LLVM::LLVMType::getInt32Ty(op->getContext()));
}
Operation *getPrintI64(Operation *op) const {
- LLVM::LLVMDialect *dialect = typeConverter.getDialect();
- return getPrint(op, dialect, "print_i64",
- LLVM::LLVMType::getInt64Ty(dialect));
+ return getPrint(op, "print_i64",
+ LLVM::LLVMType::getInt64Ty(op->getContext()));
}
Operation *getPrintFloat(Operation *op) const {
- LLVM::LLVMDialect *dialect = typeConverter.getDialect();
- return getPrint(op, dialect, "print_f32",
- LLVM::LLVMType::getFloatTy(dialect));
+ return getPrint(op, "print_f32",
+ LLVM::LLVMType::getFloatTy(op->getContext()));
}
Operation *getPrintDouble(Operation *op) const {
- LLVM::LLVMDialect *dialect = typeConverter.getDialect();
- return getPrint(op, dialect, "print_f64",
- LLVM::LLVMType::getDoubleTy(dialect));
+ return getPrint(op, "print_f64",
+ LLVM::LLVMType::getDoubleTy(op->getContext()));
}
Operation *getPrintOpen(Operation *op) const {
- return getPrint(op, typeConverter.getDialect(), "print_open", {});
+ return getPrint(op, "print_open", {});
}
Operation *getPrintClose(Operation *op) const {
- return getPrint(op, typeConverter.getDialect(), "print_close", {});
+ return getPrint(op, "print_close", {});
}
Operation *getPrintComma(Operation *op) const {
- return getPrint(op, typeConverter.getDialect(), "print_comma", {});
+ return getPrint(op, "print_comma", {});
}
Operation *getPrintNewline(Operation *op) const {
- return getPrint(op, typeConverter.getDialect(), "print_newline", {});
+ return getPrint(op, "print_newline", {});
}
};
// The result type is either i1 or a vector type <? x i1> if the inputs are
// vectors.
- auto *dialect = builder.getContext()->getRegisteredDialect<LLVMDialect>();
- auto resultType = LLVMType::getInt1Ty(dialect);
+ auto resultType = LLVMType::getInt1Ty(builder.getContext());
auto argType = type.dyn_cast<LLVM::LLVMType>();
if (!argType)
return parser.emitError(trailingTypeLoc, "expected LLVM IR dialect type");
return parser.emitError(trailingTypeLoc,
"expected function with 0 or 1 result");
- auto *llvmDialect =
- builder.getContext()->getRegisteredDialect<LLVM::LLVMDialect>();
LLVM::LLVMType llvmResultType;
if (funcType.getNumResults() == 0) {
- llvmResultType = LLVM::LLVMType::getVoidTy(llvmDialect);
+ llvmResultType = LLVM::LLVMType::getVoidTy(builder.getContext());
} else {
llvmResultType = funcType.getResult(0).dyn_cast<LLVM::LLVMType>();
if (!llvmResultType)
"expected function with 0 or 1 result");
Builder &builder = parser.getBuilder();
- auto *llvmDialect =
- builder.getContext()->getRegisteredDialect<LLVM::LLVMDialect>();
LLVM::LLVMType llvmResultType;
if (funcType.getNumResults() == 0) {
- llvmResultType = LLVM::LLVMType::getVoidTy(llvmDialect);
+ llvmResultType = LLVM::LLVMType::getVoidTy(builder.getContext());
} else {
llvmResultType = funcType.getResult(0).dyn_cast<LLVM::LLVMType>();
if (!llvmResultType)
if (types.empty()) {
if (auto strAttr = value.dyn_cast_or_null<StringAttr>()) {
MLIRContext *context = parser.getBuilder().getContext();
- auto *dialect = context->getRegisteredDialect<LLVMDialect>();
auto arrayType = LLVM::LLVMType::getArrayTy(
- LLVM::LLVMType::getInt8Ty(dialect), strAttr.getValue().size());
+ LLVM::LLVMType::getInt8Ty(context), strAttr.getValue().size());
types.push_back(arrayType);
} else {
return parser.emitError(parser.getNameLoc(),
llvmInputs.push_back(llvmTy);
}
- // Get the dialect from the input type, if any exist. Look it up in the
- // context otherwise.
- LLVMDialect *dialect =
- llvmInputs.empty() ? b.getContext()->getRegisteredDialect<LLVMDialect>()
- : &llvmInputs.front().getDialect();
-
// No output is denoted as "void" in LLVM type system.
- LLVMType llvmOutput = outputs.empty() ? LLVMType::getVoidTy(dialect)
+ LLVMType llvmOutput = outputs.empty() ? LLVMType::getVoidTy(b.getContext())
: outputs.front().dyn_cast<LLVMType>();
if (!llvmOutput) {
parser.emitError(loc, "failed to construct function type: expected LLVM "
parser.resolveOperand(val, type, result.operands))
return failure();
- auto *dialect = builder.getContext()->getRegisteredDialect<LLVMDialect>();
- auto boolType = LLVMType::getInt1Ty(dialect);
+ auto boolType = LLVMType::getInt1Ty(builder.getContext());
auto resultType = LLVMType::getStructTy(type, boolType);
result.addTypes(resultType);
Value mlir::LLVM::createGlobalString(Location loc, OpBuilder &builder,
StringRef name, StringRef value,
- LLVM::Linkage linkage,
- LLVM::LLVMDialect *llvmDialect) {
+ LLVM::Linkage linkage) {
assert(builder.getInsertionBlock() &&
builder.getInsertionBlock()->getParentOp() &&
"expected builder to point to a block constrained in an op");
// Create the global at the entry of the module.
OpBuilder moduleBuilder(module.getBodyRegion());
- auto type = LLVM::LLVMType::getArrayTy(LLVM::LLVMType::getInt8Ty(llvmDialect),
- value.size());
+ MLIRContext *ctx = builder.getContext();
+ auto type =
+ LLVM::LLVMType::getArrayTy(LLVM::LLVMType::getInt8Ty(ctx), value.size());
auto global = moduleBuilder.create<LLVM::GlobalOp>(
loc, type, /*isConstant=*/true, linkage, name,
builder.getStringAttr(value));
// Get the pointer to the first character in the global string.
Value globalPtr = builder.create<LLVM::AddressOfOp>(loc, global);
Value cst0 = builder.create<LLVM::ConstantOp>(
- loc, LLVM::LLVMType::getInt64Ty(llvmDialect),
+ loc, LLVM::LLVMType::getInt64Ty(ctx),
builder.getIntegerAttr(builder.getIndexType(), 0));
- return builder.create<LLVM::GEPOp>(loc,
- LLVM::LLVMType::getInt8PtrTy(llvmDialect),
+ return builder.create<LLVM::GEPOp>(loc, LLVM::LLVMType::getInt8PtrTy(ctx),
globalPtr, ArrayRef<Value>({cst0, cst0}));
}
//----------------------------------------------------------------------------//
// Utilities used to generate floating point types.
-LLVMType LLVMType::getDoubleTy(LLVMDialect *dialect) {
- return LLVMDoubleType::get(dialect->getContext());
+LLVMType LLVMType::getDoubleTy(MLIRContext *context) {
+ return LLVMDoubleType::get(context);
}
-LLVMType LLVMType::getFloatTy(LLVMDialect *dialect) {
- return LLVMFloatType::get(dialect->getContext());
+LLVMType LLVMType::getFloatTy(MLIRContext *context) {
+ return LLVMFloatType::get(context);
}
-LLVMType LLVMType::getBFloatTy(LLVMDialect *dialect) {
- return LLVMBFloatType::get(dialect->getContext());
+LLVMType LLVMType::getBFloatTy(MLIRContext *context) {
+ return LLVMBFloatType::get(context);
}
-LLVMType LLVMType::getHalfTy(LLVMDialect *dialect) {
- return LLVMHalfType::get(dialect->getContext());
+LLVMType LLVMType::getHalfTy(MLIRContext *context) {
+ return LLVMHalfType::get(context);
}
-LLVMType LLVMType::getFP128Ty(LLVMDialect *dialect) {
- return LLVMFP128Type::get(dialect->getContext());
+LLVMType LLVMType::getFP128Ty(MLIRContext *context) {
+ return LLVMFP128Type::get(context);
}
-LLVMType LLVMType::getX86_FP80Ty(LLVMDialect *dialect) {
- return LLVMX86FP80Type::get(dialect->getContext());
+LLVMType LLVMType::getX86_FP80Ty(MLIRContext *context) {
+ return LLVMX86FP80Type::get(context);
}
//----------------------------------------------------------------------------//
// Utilities used to generate integer types.
-LLVMType LLVMType::getIntNTy(LLVMDialect *dialect, unsigned numBits) {
- return LLVMIntegerType::get(dialect->getContext(), numBits);
+LLVMType LLVMType::getIntNTy(MLIRContext *context, unsigned numBits) {
+ return LLVMIntegerType::get(context, numBits);
}
//----------------------------------------------------------------------------//
return LLVMFunctionType::get(result, params, isVarArg);
}
-LLVMType LLVMType::getStructTy(LLVMDialect *dialect,
+LLVMType LLVMType::getStructTy(MLIRContext *context,
ArrayRef<LLVMType> elements, bool isPacked) {
- return LLVMStructType::getLiteral(dialect->getContext(), elements, isPacked);
+ return LLVMStructType::getLiteral(context, elements, isPacked);
}
LLVMType LLVMType::getVectorTy(LLVMType elementType, unsigned numElements) {
//----------------------------------------------------------------------------//
// Void type utilities.
-LLVMType LLVMType::getVoidTy(LLVMDialect *dialect) {
- return LLVMVoidType::get(dialect->getContext());
+LLVMType LLVMType::getVoidTy(MLIRContext *context) {
+ return LLVMVoidType::get(context);
}
bool LLVMType::isVoidTy() { return isa<LLVMVoidType>(); }
//----------------------------------------------------------------------------//
// Creation and setting of LLVM's identified struct types
-LLVMType LLVMType::createStructTy(LLVMDialect *dialect,
+LLVMType LLVMType::createStructTy(MLIRContext *context,
ArrayRef<LLVMType> elements,
Optional<StringRef> name, bool isPacked) {
assert(name.hasValue() &&
std::string stringName = stringNameBase.str();
unsigned counter = 0;
do {
- auto type =
- LLVMStructType::getIdentified(dialect->getContext(), stringName);
+ auto type = LLVMStructType::getIdentified(context, stringName);
if (type.isInitialized() || failed(type.setBody(elements, isPacked))) {
counter += 1;
stringName =
p << " : " << op->getResultTypes();
}
-static LLVM::LLVMDialect *getLlvmDialect(OpAsmParser &parser) {
- return parser.getBuilder()
- .getContext()
- ->getRegisteredDialect<LLVM::LLVMDialect>();
-}
-
// <operation> ::=
// `llvm.nvvm.shfl.sync.bfly %dst, %val, %offset, %clamp_and_mask`
// ({return_value_and_is_valid})? : result_type
break;
}
- auto int32Ty = LLVM::LLVMType::getInt32Ty(getLlvmDialect(parser));
+ auto int32Ty = LLVM::LLVMType::getInt32Ty(parser.getBuilder().getContext());
return parser.resolveOperands(ops, {int32Ty, type, int32Ty, int32Ty},
parser.getNameLoc(), result.operands);
}
// <operation> ::= `llvm.nvvm.vote.ballot.sync %mask, %pred` : result_type
static ParseResult parseNVVMVoteBallotOp(OpAsmParser &parser,
OperationState &result) {
- auto llvmDialect = getLlvmDialect(parser);
- auto int32Ty = LLVM::LLVMType::getInt32Ty(llvmDialect);
- auto int1Ty = LLVM::LLVMType::getInt1Ty(llvmDialect);
+ MLIRContext *context = parser.getBuilder().getContext();
+ auto int32Ty = LLVM::LLVMType::getInt32Ty(context);
+ auto int1Ty = LLVM::LLVMType::getInt1Ty(context);
SmallVector<OpAsmParser::OperandType, 8> ops;
Type type;
}
static LogicalResult verify(MmaOp op) {
- auto dialect = op.getContext()->getRegisteredDialect<LLVM::LLVMDialect>();
- auto f16Ty = LLVM::LLVMType::getHalfTy(dialect);
+ MLIRContext *context = op.getContext();
+ auto f16Ty = LLVM::LLVMType::getHalfTy(context);
auto f16x2Ty = LLVM::LLVMType::getVectorTy(f16Ty, 2);
- auto f32Ty = LLVM::LLVMType::getFloatTy(dialect);
+ auto f32Ty = LLVM::LLVMType::getFloatTy(context);
auto f16x2x4StructTy = LLVM::LLVMType::getStructTy(
- dialect, {f16x2Ty, f16x2Ty, f16x2Ty, f16x2Ty});
+ context, {f16x2Ty, f16x2Ty, f16x2Ty, f16x2Ty});
auto f32x8StructTy = LLVM::LLVMType::getStructTy(
- dialect, {f32Ty, f32Ty, f32Ty, f32Ty, f32Ty, f32Ty, f32Ty, f32Ty});
+ context, {f32Ty, f32Ty, f32Ty, f32Ty, f32Ty, f32Ty, f32Ty, f32Ty});
SmallVector<Type, 12> operand_types(op.getOperandTypes().begin(),
op.getOperandTypes().end());
// Parsing for ROCDL ops
//===----------------------------------------------------------------------===//
-static LLVM::LLVMDialect *getLlvmDialect(OpAsmParser &parser) {
- return parser.getBuilder()
- .getContext()
- ->getRegisteredDialect<LLVM::LLVMDialect>();
-}
-
// <operation> ::=
// `llvm.amdgcn.buffer.load.* %rsrc, %vindex, %offset, %glc, %slc :
// result_type`
parser.addTypeToList(type, result.types))
return failure();
- auto int32Ty = LLVM::LLVMType::getInt32Ty(getLlvmDialect(parser));
- auto int1Ty = LLVM::LLVMType::getInt1Ty(getLlvmDialect(parser));
+ MLIRContext *context = parser.getBuilder().getContext();
+ auto int32Ty = LLVM::LLVMType::getInt32Ty(context);
+ auto int1Ty = LLVM::LLVMType::getInt1Ty(context);
auto i32x4Ty = LLVM::LLVMType::getVectorTy(int32Ty, 4);
return parser.resolveOperands(ops,
{i32x4Ty, int32Ty, int32Ty, int1Ty, int1Ty},
if (parser.parseOperandList(ops, 6) || parser.parseColonType(type))
return failure();
- auto int32Ty = LLVM::LLVMType::getInt32Ty(getLlvmDialect(parser));
- auto int1Ty = LLVM::LLVMType::getInt1Ty(getLlvmDialect(parser));
+ MLIRContext *context = parser.getBuilder().getContext();
+ auto int32Ty = LLVM::LLVMType::getInt32Ty(context);
+ auto int1Ty = LLVM::LLVMType::getInt1Ty(context);
auto i32x4Ty = LLVM::LLVMType::getVectorTy(int32Ty, 4);
if (parser.resolveOperands(ops,