/// Builds IR extracting the pointer to the first element of the size array.
static Value sizeBasePtr(OpBuilder &builder, Location loc,
LLVMTypeConverter &typeConverter,
- Value memRefDescPtr, LLVM::LLVMType elemPtrPtrType);
+ Value memRefDescPtr,
+ LLVM::LLVMPointerType elemPtrPtrType);
/// Builds IR extracting the size[index] from the descriptor.
static Value size(OpBuilder &builder, Location loc,
LLVMTypeConverter typeConverter, Value sizeBasePtr,
[{
auto llvmType = resultType.dyn_cast<LLVMType>(); (void)llvmType;
assert(llvmType && "result must be an LLVM type");
- assert(llvmType.isVoidTy() &&
+ assert(llvmType.isa<LLVMVoidType>() &&
"for zero-result operands, only 'void' is accepted as result type");
build($_builder, $_state, operands, attributes);
}]>;
OpBuilderDAG<(ins "Value":$addr, CArg<"unsigned", "0">:$alignment,
CArg<"bool", "false">:$isVolatile, CArg<"bool", "false">:$isNonTemporal),
[{
- auto type = addr.getType().cast<LLVMType>().getPointerElementTy();
+ auto type = addr.getType().cast<LLVMPointerType>().getElementType();
build($_builder, $_state, type, addr, alignment, isVolatile, isNonTemporal);
}]>,
OpBuilderDAG<(ins "Type":$t, "Value":$addr,
OpBuilderDAG<(ins "LLVMFuncOp":$func, "ValueRange":$operands,
CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
[{
- LLVMType resultType = func.getType().getFunctionResultType();
- if (!resultType.isVoidTy())
+ LLVMType resultType = func.getType().getReturnType();
+ if (!resultType.isa<LLVM::LLVMVoidType>())
$_state.addTypes(resultType);
$_state.addAttribute("callee", $_builder.getSymbolRefAttr(func));
$_state.addAttributes(attributes);
OpBuilderDAG<(ins "Value":$v1, "Value":$v2, "ArrayAttr":$mask,
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>];
let verifier = [{
- auto wrappedVectorType1 = v1().getType().cast<LLVMType>();
- auto wrappedVectorType2 = v2().getType().cast<LLVMType>();
- if (!wrappedVectorType2.isVectorTy())
- return emitOpError("expected LLVM IR Dialect vector type for operand #2");
- if (wrappedVectorType1.getVectorElementType() !=
- wrappedVectorType2.getVectorElementType())
+ auto wrappedVectorType1 = v1().getType().cast<LLVMVectorType>();
+ auto wrappedVectorType2 = v2().getType().cast<LLVMVectorType>();
+ if (wrappedVectorType1.getElementType() !=
+ wrappedVectorType2.getElementType())
return emitOpError("expected matching LLVM IR Dialect element types");
return success();
}];
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs),
[{
build($_builder, $_state,
- global.getType().getPointerTo(global.addr_space()),
+ LLVM::LLVMPointerType::get(global.getType(), global.addr_space()),
global.sym_name(), attrs);}]>,
OpBuilderDAG<(ins "LLVMFuncOp":$func,
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs),
[{
build($_builder, $_state,
- func.getType().getPointerTo(), func.getName(), attrs);}]>
+ LLVM::LLVMPointerType::get(func.getType()), func.getName(), attrs);}]>
];
let extraClassDeclaration = [{
// to match the signature of the function.
Block *addEntryBlock();
- LLVMType getType() {
+ LLVMFunctionType getType() {
return (*this)->getAttrOfType<TypeAttr>(getTypeAttrName())
- .getValue().cast<LLVMType>();
+ .getValue().cast<LLVMFunctionType>();
}
bool isVarArg() {
- return getType().isFunctionVarArg();
+ return getType().isVarArg();
}
// Hook for OpTrait::FunctionLike, returns the number of function arguments`.
LLVMDialect &getDialect();
- /// Returns the size of a primitive type (including vectors) in bits, for
- /// example, the size of !llvm.i16 is 16 and the size of !llvm.vec<4 x i16>
- /// is 64. Returns 0 for non-primitive (aggregates such as struct) or types
- /// that don't have a size (such as void).
- llvm::TypeSize getPrimitiveSizeInBits();
-
- /// Floating-point type utilities.
- bool isBFloatTy() { return isa<LLVMBFloatType>(); }
- bool isHalfTy() { return isa<LLVMHalfType>(); }
- bool isFloatTy() { return isa<LLVMFloatType>(); }
- bool isDoubleTy() { return isa<LLVMDoubleType>(); }
- bool isFP128Ty() { return isa<LLVMFP128Type>(); }
- bool isX86_FP80Ty() { return isa<LLVMX86FP80Type>(); }
- bool isFloatingPointTy() {
- return isa<LLVMHalfType>() || isa<LLVMBFloatType>() ||
- isa<LLVMFloatType>() || isa<LLVMDoubleType>() ||
- isa<LLVMFP128Type>() || isa<LLVMX86FP80Type>();
- }
-
- /// Array type utilities.
- LLVMType getArrayElementType();
- unsigned getArrayNumElements();
- bool isArrayTy();
-
- /// Integer type utilities.
- bool isIntegerTy() { return isa<LLVMIntegerType>(); }
- bool isIntegerTy(unsigned bitwidth);
- unsigned getIntegerBitWidth();
-
- /// Vector type utilities.
- LLVMType getVectorElementType();
- unsigned getVectorNumElements();
- llvm::ElementCount getVectorElementCount();
- bool isVectorTy();
-
- /// Function type utilities.
- LLVMType getFunctionParamType(unsigned argIdx);
- unsigned getFunctionNumParams();
- LLVMType getFunctionResultType();
- bool isFunctionTy();
- bool isFunctionVarArg();
-
- /// Pointer type utilities.
- LLVMType getPointerTo(unsigned addrSpace = 0);
- LLVMType getPointerElementTy();
- bool isPointerTy();
-
- /// Struct type utilities.
- LLVMType getStructElementType(unsigned i);
- unsigned getStructNumElements();
- bool isStructTy();
-
/// Utilities used to generate floating point types.
static LLVMType getDoubleTy(MLIRContext *context);
static LLVMType getFloatTy(MLIRContext *context);
static LLVMType getInt8Ty(MLIRContext *context) {
return getIntNTy(context, /*numBits=*/8);
}
- static LLVMType getInt8PtrTy(MLIRContext *context) {
- return getInt8Ty(context).getPointerTo();
- }
+ static LLVMType getInt8PtrTy(MLIRContext *context);
static LLVMType getInt16Ty(MLIRContext *context) {
return getIntNTy(context, /*numBits=*/16);
}
/// Void type utilities.
static LLVMType getVoidTy(MLIRContext *context);
- bool isVoidTy();
// Creation and setting of LLVM's identified struct types
static LLVMType createStructTy(MLIRContext *context,
void printType(LLVMType type, DialectAsmPrinter &printer);
} // namespace detail
+//===----------------------------------------------------------------------===//
+// Utility functions.
+//===----------------------------------------------------------------------===//
+
+/// Returns `true` if the given type is compatible with the LLVM dialect.
+inline bool isCompatibleType(Type type) { return type.isa<LLVMType>(); }
+
+inline bool isCompatibleFloatingPointType(Type type) {
+ return type.isa<LLVMHalfType, LLVMBFloatType, LLVMFloatType, LLVMDoubleType,
+ LLVMFP128Type, LLVMX86FP80Type>();
+}
+
+/// Returns the size of the given primitive LLVM dialect-compatible type
+/// (including vectors) in bits, for example, the size of !llvm.i16 is 16 and
+/// the size of !llvm.vec<4 x i16> is 64. Returns 0 for non-primitive
+/// (aggregates such as struct) or types that don't have a size (such as void).
+llvm::TypeSize getPrimitiveTypeSizeInBits(Type type);
+
} // namespace LLVM
} // namespace mlir
let verifier = [{
if (!(*this)->getAttrOfType<UnitAttr>("return_value_and_is_valid"))
return success();
- auto type = getType().cast<LLVM::LLVMType>();
- if (!type.isStructTy() || type.getStructNumElements() != 2 ||
- !type.getStructElementType(1).isIntegerTy(
- /*Bitwidth=*/1))
+ auto type = getType().dyn_cast<LLVM::LLVMStructType>();
+ auto elementType = (type && type.getBody().size() == 2)
+ ? type.getBody()[1].dyn_cast<LLVM::LLVMIntegerType>()
+ : nullptr;
+ if (!elementType || elementType.getBitWidth() != 1)
return emitError("expected return type to be a two-element struct with "
"i1 as the second element");
return success();
static FunctionType executeFunctionType(MLIRContext *ctx) {
auto hdl = LLVM::LLVMType::getInt8PtrTy(ctx);
- auto resume = resumeFunctionType(ctx).getPointerTo();
+ auto resume = LLVM::LLVMPointerType::get(resumeFunctionType(ctx));
return FunctionType::get(ctx, {hdl, resume}, {});
}
static FunctionType awaitAndExecuteFunctionType(MLIRContext *ctx) {
auto hdl = LLVM::LLVMType::getInt8PtrTy(ctx);
- auto resume = resumeFunctionType(ctx).getPointerTo();
+ auto resume = LLVM::LLVMPointerType::get(resumeFunctionType(ctx));
return FunctionType::get(ctx, {TokenType::get(ctx), hdl, resume}, {});
}
static FunctionType awaitAllAndExecuteFunctionType(MLIRContext *ctx) {
auto hdl = LLVM::LLVMType::getInt8PtrTy(ctx);
- auto resume = resumeFunctionType(ctx).getPointerTo();
+ auto resume = LLVM::LLVMPointerType::get(resumeFunctionType(ctx));
return FunctionType::get(ctx, {GroupType::get(ctx), hdl, resume}, {});
}
// A pointer to coroutine resume intrinsic wrapper.
auto resumeFnTy = AsyncAPI::resumeFunctionType(ctx);
auto resumePtr = builder.create<LLVM::AddressOfOp>(
- loc, resumeFnTy.getPointerTo(), kResume);
+ loc, LLVM::LLVMPointerType::get(resumeFnTy), kResume);
// Save the coroutine state: @llvm.coro.save
auto coroSave = builder.create<LLVM::CallOp>(
// A pointer to coroutine resume intrinsic wrapper.
auto resumeFnTy = AsyncAPI::resumeFunctionType(ctx);
auto resumePtr = builder.create<LLVM::AddressOfOp>(
- loc, resumeFnTy.getPointerTo(), kResume);
+ loc, LLVM::LLVMPointerType::get(resumeFnTy), kResume);
// Save the coroutine state: @llvm.coro.save
auto coroSave = builder.create<LLVM::CallOp>(
FunctionCallBuilder(StringRef functionName, LLVM::LLVMType returnType,
ArrayRef<LLVM::LLVMType> argumentTypes)
: functionName(functionName),
- functionType(LLVM::LLVMType::getFunctionTy(returnType, argumentTypes,
- /*isVarArg=*/false)) {}
+ functionType(LLVM::LLVMFunctionType::get(returnType, argumentTypes,
+ /*isVarArg=*/false)) {}
LLVM::CallOp create(Location loc, OpBuilder &builder,
ArrayRef<Value> arguments) const;
private:
StringRef functionName;
- LLVM::LLVMType functionType;
+ LLVM::LLVMFunctionType functionType;
};
template <typename OpTy>
LLVM::LLVMType llvmVoidType = LLVM::LLVMType::getVoidTy(context);
LLVM::LLVMType llvmPointerType = LLVM::LLVMType::getInt8PtrTy(context);
- LLVM::LLVMType llvmPointerPointerType = llvmPointerType.getPointerTo();
+ LLVM::LLVMType llvmPointerPointerType =
+ LLVM::LLVMPointerType::get(llvmPointerType);
LLVM::LLVMType llvmInt8Type = LLVM::LLVMType::getInt8Ty(context);
LLVM::LLVMType llvmInt32Type = LLVM::LLVMType::getInt32Ty(context);
LLVM::LLVMType llvmInt64Type = LLVM::LLVMType::getInt64Ty(context);
.create<LLVM::LLVMFuncOp>(loc, functionName, functionType);
}();
return builder.create<LLVM::CallOp>(
- loc, const_cast<LLVM::LLVMType &>(functionType).getFunctionResultType(),
+ loc, const_cast<LLVM::LLVMFunctionType &>(functionType).getReturnType(),
builder.getSymbolRefAttr(function), arguments);
}
auto one = builder.create<LLVM::ConstantOp>(loc, llvmInt32Type,
builder.getI32IntegerAttr(1));
auto structPtr = builder.create<LLVM::AllocaOp>(
- loc, structType.getPointerTo(), one, /*alignment=*/0);
+ loc, LLVM::LLVMPointerType::get(structType), one, /*alignment=*/0);
auto arraySize = builder.create<LLVM::ConstantOp>(
loc, llvmInt32Type, builder.getI32IntegerAttr(numArguments));
auto arrayPtr = builder.create<LLVM::AllocaOp>(loc, llvmPointerPointerType,
auto index = builder.create<LLVM::ConstantOp>(
loc, llvmInt32Type, builder.getI32IntegerAttr(en.index()));
auto fieldPtr = builder.create<LLVM::GEPOp>(
- loc, argumentTypes[en.index()].getPointerTo(), structPtr,
+ loc, LLVM::LLVMPointerType::get(argumentTypes[en.index()]), structPtr,
ArrayRef<Value>{zero, index.getResult()});
builder.create<LLVM::StoreOp>(loc, en.value(), fieldPtr);
auto elementPtr = builder.create<LLVM::GEPOp>(loc, llvmPointerPointerType,
// Rewrite the original GPU function to an LLVM function.
auto funcType = typeConverter->convertType(gpuFuncOp.getType())
- .template cast<LLVM::LLVMType>()
- .getPointerElementTy();
+ .template cast<LLVM::LLVMPointerType>()
+ .getElementType();
// Remap proper input types.
TypeConverter::SignatureConversion signatureConversion(
for (auto en : llvm::enumerate(workgroupBuffers)) {
LLVM::GlobalOp global = en.value();
Value address = rewriter.create<LLVM::AddressOfOp>(loc, global);
- auto elementType = global.getType().getArrayElementType();
+ auto elementType =
+ global.getType().cast<LLVM::LLVMArrayType>().getElementType();
Value memory = rewriter.create<LLVM::GEPOp>(
- loc, elementType.getPointerTo(global.addr_space()), address,
- ArrayRef<Value>{zero, zero});
+ loc, LLVM::LLVMPointerType::get(elementType, global.addr_space()),
+ address, ArrayRef<Value>{zero, zero});
// Build a memref descriptor pointing to the buffer to plug with the
// existing memref infrastructure. This may use more registers than
// Explicitly drop memory space when lowering private memory
// attributions since NVVM models it as `alloca`s in the default
// memory space and does not support `alloca`s with addrspace(5).
- auto ptrType = typeConverter->convertType(type.getElementType())
- .template cast<LLVM::LLVMType>()
- .getPointerTo(AllocaAddrSpace);
+ auto ptrType = LLVM::LLVMPointerType::get(
+ typeConverter->convertType(type.getElementType())
+ .template cast<LLVM::LLVMType>(),
+ AllocaAddrSpace);
Value numElements = rewriter.create<LLVM::ConstantOp>(
gpuFuncOp.getLoc(), int64Ty,
rewriter.getI64IntegerAttr(type.getNumElements()));
LLVMType resultType =
castedOperands.front().getType().cast<LLVM::LLVMType>();
LLVMType funcType = getFunctionType(resultType, castedOperands);
- StringRef funcName = getFunctionName(funcType.getFunctionResultType());
+ StringRef funcName = getFunctionName(
+ funcType.cast<LLVM::LLVMFunctionType>().getReturnType());
if (funcName.empty())
return failure();
private:
Value maybeCast(Value operand, PatternRewriter &rewriter) const {
LLVM::LLVMType type = operand.getType().cast<LLVM::LLVMType>();
- if (!type.isHalfTy())
+ if (!type.isa<LLVM::LLVMHalfType>())
return operand;
return rewriter.create<LLVM::FPExtOp>(
}
StringRef getFunctionName(LLVM::LLVMType type) const {
- if (type.isFloatTy())
+ if (type.isa<LLVM::LLVMFloatType>())
return f32Func;
- if (type.isDoubleTy())
+ if (type.isa<LLVM::LLVMDoubleType>())
return f64Func;
return "";
}
// int64_t sizes[Rank]; // omitted when rank == 0
// int64_t strides[Rank]; // omitted when rank == 0
// };
- auto llvmPtrToElementType = elemenType.getPointerTo();
+ auto llvmPtrToElementType = LLVM::LLVMPointerType::get(elemenType);
auto llvmArrayRankElementSizeType =
LLVM::LLVMType::getArrayTy(getInt64Type(), rank);
/// Returns a string representation from the given `type`.
StringRef stringifyType(LLVM::LLVMType type) {
- if (type.isFloatTy())
+ if (type.isa<LLVM::LLVMFloatType>())
return "Float";
- if (type.isHalfTy())
+ if (type.isa<LLVM::LLVMHalfType>())
return "Half";
- if (type.isIntegerTy(32))
- return "Int32";
- if (type.isIntegerTy(16))
- return "Int16";
- if (type.isIntegerTy(8))
- return "Int8";
+ if (auto intType = type.dyn_cast<LLVM::LLVMIntegerType>()) {
+ if (intType.getBitWidth() == 32)
+ return "Int32";
+ if (intType.getBitWidth() == 16)
+ return "Int16";
+ if (intType.getBitWidth() == 8)
+ return "Int8";
+ }
llvm_unreachable("unsupported type");
}
llvm::formatv("bindMemRef{0}D{1}", rank, stringifyType(type)).str();
// Special case for fp16 type. Since it is not a supported type in C we use
// int16_t and bitcast the descriptor.
- if (type.isHalfTy()) {
+ if (type.isa<LLVM::LLVMHalfType>()) {
auto memRefTy =
getMemRefType(rank, LLVM::LLVMType::getInt16Ty(&getContext()));
ptrToMemRefDescriptor = builder.create<LLVM::BitcastOp>(
- loc, memRefTy.getPointerTo(), ptrToMemRefDescriptor);
+ loc, LLVM::LLVMPointerType::get(memRefTy), ptrToMemRefDescriptor);
}
// Create call to `bindMemRef`.
builder.create<LLVM::CallOp>(
LogicalResult VulkanLaunchFuncToVulkanCallsPass::deduceMemRefRankAndType(
Value ptrToMemRefDescriptor, uint32_t &rank, LLVM::LLVMType &type) {
auto llvmPtrDescriptorTy =
- ptrToMemRefDescriptor.getType().dyn_cast<LLVM::LLVMType>();
+ ptrToMemRefDescriptor.getType().dyn_cast<LLVM::LLVMPointerType>();
if (!llvmPtrDescriptorTy)
return failure();
- auto llvmDescriptorTy = llvmPtrDescriptorTy.getPointerElementTy();
+ auto llvmDescriptorTy =
+ llvmPtrDescriptorTy.getElementType().dyn_cast<LLVM::LLVMStructType>();
// template <typename Elem, size_t Rank>
// struct {
// Elem *allocated;
// int64_t sizes[Rank]; // omitted when rank == 0
// int64_t strides[Rank]; // omitted when rank == 0
// };
- if (!llvmDescriptorTy || !llvmDescriptorTy.isStructTy())
+ if (!llvmDescriptorTy)
return failure();
- type = llvmDescriptorTy.getStructElementType(0).getPointerElementTy();
- if (llvmDescriptorTy.getStructNumElements() == 3) {
+ type = llvmDescriptorTy.getBody()[0]
+ .cast<LLVM::LLVMPointerType>()
+ .getElementType();
+ if (llvmDescriptorTy.getBody().size() == 3) {
rank = 0;
return success();
}
- rank = llvmDescriptorTy.getStructElementType(3).getArrayNumElements();
+ rank = llvmDescriptorTy.getBody()[3]
+ .cast<LLVM::LLVMArrayType>()
+ .getNumElements();
return success();
}
LLVM::LLVMType::getHalfTy(&getContext())}) {
std::string fnName = "bindMemRef" + std::to_string(i) + "D" +
std::string(stringifyType(type));
- if (type.isHalfTy())
+ if (type.isa<LLVM::LLVMHalfType>())
type = LLVM::LLVMType::getInt16Ty(&getContext());
if (!module.lookupSymbol(fnName)) {
auto fnType = LLVM::LLVMType::getFunctionTy(
getVoidType(),
{getPointerType(), getInt32Type(), getInt32Type(),
- getMemRefType(i, type).getPointerTo()},
+ LLVM::LLVMPointerType::get(getMemRefType(i, type))},
/*isVarArg=*/false);
builder.create<LLVM::LLVMFuncOp>(loc, fnName, fnType);
}
/// Returns the bit width of LLVMType integer or vector.
static unsigned getLLVMTypeBitWidth(LLVM::LLVMType type) {
- return type.isVectorTy() ? type.getVectorElementType().getIntegerBitWidth()
- : type.getIntegerBitWidth();
+ auto vectorType = type.dyn_cast<LLVM::LLVMVectorType>();
+ return (vectorType ? vectorType.getElementType() : type)
+ .cast<LLVM::LLVMIntegerType>()
+ .getBitWidth();
}
/// Creates `IntegerAttribute` with all bits set for given type
TypeConverter &converter) {
auto pointeeType =
converter.convertType(type.getPointeeType()).cast<LLVM::LLVMType>();
- return pointeeType.getPointerTo();
+ return LLVM::LLVMPointerType::get(pointeeType);
}
/// Converts SPIR-V runtime array to LLVM array. Since LLVM allows indexing over
SignatureConversion conversion(type.getNumInputs());
LLVM::LLVMType converted =
convertFunctionSignature(type, /*isVariadic=*/false, conversion);
- return converted.getPointerTo();
+ return LLVM::LLVMPointerType::get(converted);
}
if (!converted)
return {};
if (t.isa<MemRefType, UnrankedMemRefType>())
- converted = converted.getPointerTo();
+ converted = LLVM::LLVMPointerType::get(converted);
inputs.push_back(converted);
}
LLVM::LLVMType elementType = unwrap(convertType(type.getElementType()));
if (!elementType)
return {};
- auto ptrTy = elementType.getPointerTo(type.getMemorySpace());
+ auto ptrTy = LLVM::LLVMPointerType::get(elementType, type.getMemorySpace());
auto indexTy = getIndexType();
SmallVector<LLVM::LLVMType, 5> results = {ptrTy, ptrTy, indexTy};
LLVM::LLVMType elementType = unwrap(convertType(type.getElementType()));
if (!elementType)
return {};
- return elementType.getPointerTo(type.getMemorySpace());
+ return LLVM::LLVMPointerType::get(elementType, type.getMemorySpace());
}
// Convert an n-D vector type to an LLVM vector type via (n-1)-D array type when
Value StructBuilder::extractPtr(OpBuilder &builder, Location loc,
unsigned pos) {
- Type type = structType.cast<LLVM::LLVMType>().getStructElementType(pos);
+ Type type = structType.cast<LLVM::LLVMStructType>().getBody()[pos];
return builder.create<LLVM::ExtractValueOp>(loc, type, value,
builder.getI64ArrayAttr(pos));
}
MemRefDescriptor::MemRefDescriptor(Value descriptor)
: StructBuilder(descriptor) {
assert(value != nullptr && "value cannot be null");
- indexType = value.getType().cast<LLVM::LLVMType>().getStructElementType(
- kOffsetPosInMemRefDescriptor);
+ indexType = value.getType()
+ .cast<LLVM::LLVMStructType>()
+ .getBody()[kOffsetPosInMemRefDescriptor];
}
/// Builds IR creating an `undef` value of the descriptor type.
Value MemRefDescriptor::size(OpBuilder &builder, Location loc, Value pos,
int64_t rank) {
auto indexTy = indexType.cast<LLVM::LLVMType>();
- auto indexPtrTy = indexTy.getPointerTo();
+ auto indexPtrTy = LLVM::LLVMPointerType::get(indexTy);
auto arrayTy = LLVM::LLVMType::getArrayTy(indexTy, rank);
- auto arrayPtrTy = arrayTy.getPointerTo();
+ auto arrayPtrTy = LLVM::LLVMPointerType::get(arrayTy);
// Copy size values to stack-allocated memory.
auto zero = createIndexAttrConstant(builder, loc, indexType, 0);
LLVM::LLVMPointerType MemRefDescriptor::getElementPtrType() {
return value.getType()
- .cast<LLVM::LLVMType>()
- .getStructElementType(kAlignedPtrPosInMemRefDescriptor)
+ .cast<LLVM::LLVMStructType>()
+ .getBody()[kAlignedPtrPosInMemRefDescriptor]
.cast<LLVM::LLVMPointerType>();
}
Value offsetGep = builder.create<LLVM::GEPOp>(
loc, elemPtrPtrType, elementPtrPtr, ValueRange({two}));
offsetGep = builder.create<LLVM::BitcastOp>(
- loc, typeConverter.getIndexType().getPointerTo(), offsetGep);
+ loc, LLVM::LLVMPointerType::get(typeConverter.getIndexType()), offsetGep);
return builder.create<LLVM::LoadOp>(loc, offsetGep);
}
Value offsetGep = builder.create<LLVM::GEPOp>(
loc, elemPtrPtrType, elementPtrPtr, ValueRange({two}));
offsetGep = builder.create<LLVM::BitcastOp>(
- loc, typeConverter.getIndexType().getPointerTo(), offsetGep);
+ loc, LLVM::LLVMPointerType::get(typeConverter.getIndexType()), offsetGep);
builder.create<LLVM::StoreOp>(loc, offset, offsetGep);
}
-Value UnrankedMemRefDescriptor::sizeBasePtr(OpBuilder &builder, Location loc,
- LLVMTypeConverter &typeConverter,
- Value memRefDescPtr,
- LLVM::LLVMType elemPtrPtrType) {
- LLVM::LLVMType elemPtrTy = elemPtrPtrType.getPointerElementTy();
+Value UnrankedMemRefDescriptor::sizeBasePtr(
+ OpBuilder &builder, Location loc, LLVMTypeConverter &typeConverter,
+ Value memRefDescPtr, LLVM::LLVMPointerType elemPtrPtrType) {
+ LLVM::LLVMType elemPtrTy = elemPtrPtrType.getElementType();
LLVM::LLVMType indexTy = typeConverter.getIndexType();
- LLVM::LLVMType structPtrTy =
- LLVM::LLVMType::getStructTy(elemPtrTy, elemPtrTy, indexTy, indexTy)
- .getPointerTo();
+ LLVM::LLVMType structPtrTy = LLVM::LLVMPointerType::get(
+ LLVM::LLVMType::getStructTy(elemPtrTy, elemPtrTy, indexTy, indexTy));
Value structPtr =
builder.create<LLVM::BitcastOp>(loc, structPtrTy, memRefDescPtr);
createIndexAttrConstant(builder, loc, typeConverter.getIndexType(), 0);
Value three = builder.create<LLVM::ConstantOp>(loc, int32_type,
builder.getI32IntegerAttr(3));
- return builder.create<LLVM::GEPOp>(loc, indexTy.getPointerTo(), structPtr,
- ValueRange({zero, three}));
+ return builder.create<LLVM::GEPOp>(loc, LLVM::LLVMPointerType::get(indexTy),
+ structPtr, ValueRange({zero, three}));
}
Value UnrankedMemRefDescriptor::size(OpBuilder &builder, Location loc,
LLVMTypeConverter typeConverter,
Value sizeBasePtr, Value index) {
- LLVM::LLVMType indexPtrTy = typeConverter.getIndexType().getPointerTo();
+ LLVM::LLVMType indexPtrTy =
+ LLVM::LLVMPointerType::get(typeConverter.getIndexType());
Value sizeStoreGep = builder.create<LLVM::GEPOp>(loc, indexPtrTy, sizeBasePtr,
ValueRange({index}));
return builder.create<LLVM::LoadOp>(loc, sizeStoreGep);
LLVMTypeConverter typeConverter,
Value sizeBasePtr, Value index,
Value size) {
- LLVM::LLVMType indexPtrTy = typeConverter.getIndexType().getPointerTo();
+ LLVM::LLVMType indexPtrTy =
+ LLVM::LLVMPointerType::get(typeConverter.getIndexType());
Value sizeStoreGep = builder.create<LLVM::GEPOp>(loc, indexPtrTy, sizeBasePtr,
ValueRange({index}));
builder.create<LLVM::StoreOp>(loc, size, sizeStoreGep);
Value UnrankedMemRefDescriptor::strideBasePtr(OpBuilder &builder, Location loc,
LLVMTypeConverter &typeConverter,
Value sizeBasePtr, Value rank) {
- LLVM::LLVMType indexPtrTy = typeConverter.getIndexType().getPointerTo();
+ LLVM::LLVMType indexPtrTy =
+ LLVM::LLVMPointerType::get(typeConverter.getIndexType());
return builder.create<LLVM::GEPOp>(loc, indexPtrTy, sizeBasePtr,
ValueRange({rank}));
}
LLVMTypeConverter typeConverter,
Value strideBasePtr, Value index,
Value stride) {
- LLVM::LLVMType indexPtrTy = typeConverter.getIndexType().getPointerTo();
+ LLVM::LLVMType indexPtrTy =
+ LLVM::LLVMPointerType::get(typeConverter.getIndexType());
Value strideStoreGep = builder.create<LLVM::GEPOp>(
loc, indexPtrTy, strideBasePtr, ValueRange({index}));
return builder.create<LLVM::LoadOp>(loc, strideStoreGep);
LLVMTypeConverter typeConverter,
Value strideBasePtr, Value index,
Value stride) {
- LLVM::LLVMType indexPtrTy = typeConverter.getIndexType().getPointerTo();
+ LLVM::LLVMType indexPtrTy =
+ LLVM::LLVMPointerType::get(typeConverter.getIndexType());
Value strideStoreGep = builder.create<LLVM::GEPOp>(
loc, indexPtrTy, strideBasePtr, ValueRange({index}));
builder.create<LLVM::StoreOp>(loc, stride, strideStoreGep);
Type ConvertToLLVMPattern::getElementPtrType(MemRefType type) const {
auto elementType = type.getElementType();
auto structElementType = unwrap(typeConverter->convertType(elementType));
- return structElementType.getPointerTo(type.getMemorySpace());
+ return LLVM::LLVMPointerType::get(structElementType, type.getMemorySpace());
}
void ConvertToLLVMPattern::getMemRefDescriptorSizes(
// %0 = getelementptr %elementType* null, %indexType 1
// %1 = ptrtoint %elementType* %0 to %indexType
// which is a common pattern of getting the size of a type in bytes.
- auto convertedPtrType =
- typeConverter->convertType(type).cast<LLVM::LLVMType>().getPointerTo();
+ auto convertedPtrType = LLVM::LLVMPointerType::get(
+ typeConverter->convertType(type).cast<LLVM::LLVMType>());
auto nullPtr = rewriter.create<LLVM::NullOp>(loc, convertedPtrType);
auto gep = rewriter.create<LLVM::GEPOp>(
loc, convertedPtrType,
builder, loc, typeConverter, unrankedMemRefType,
wrapperArgsRange.take_front(numToDrop));
- auto ptrTy = packed.getType().cast<LLVM::LLVMType>().getPointerTo();
+ auto ptrTy =
+ LLVM::LLVMPointerType::get(packed.getType().cast<LLVM::LLVMType>());
Value one = builder.create<LLVM::ConstantOp>(
loc, typeConverter.convertType(builder.getIndexType()),
builder.getIntegerAttr(builder.getIndexType(), 1));
return info;
info.arraySizes.reserve(vectorType.getRank() - 1);
auto llvmTy = info.llvmArrayTy;
- while (llvmTy.isArrayTy()) {
- info.arraySizes.push_back(llvmTy.getArrayNumElements());
- llvmTy = llvmTy.getArrayElementType();
+ while (llvmTy.isa<LLVM::LLVMArrayType>()) {
+ info.arraySizes.push_back(
+ llvmTy.cast<LLVM::LLVMArrayType>().getNumElements());
+ llvmTy = llvmTy.cast<LLVM::LLVMArrayType>().getElementType();
}
- if (!llvmTy.isVectorTy())
+ if (!llvmTy.isa<LLVM::LLVMVectorType>())
return info;
info.llvmVectorTy = llvmTy;
return info;
return failure();
auto llvmArrayTy = operands[0].getType().cast<LLVM::LLVMType>();
- if (!llvmArrayTy.isArrayTy())
+ if (!llvmArrayTy.isa<LLVM::LLVMArrayType>())
return oneToOneRewrite(op, targetOp, operands, typeConverter, rewriter);
auto callback = [op, targetOp, &rewriter](LLVM::LLVMType llvmVectorTy,
LLVM::LLVMType arrayTy =
convertGlobalMemrefTypeToLLVM(type, *getTypeConverter());
auto addressOf = rewriter.create<LLVM::AddressOfOp>(
- loc, arrayTy.getPointerTo(memSpace), getGlobalOp.name());
+ loc, LLVM::LLVMPointerType::get(arrayTy, memSpace), getGlobalOp.name());
// Get the address of the first element in the array by creating a GEP with
// the address of the GV as the base, and (rank + 1) number of 0 indices.
LLVM::LLVMType elementType =
unwrap(typeConverter->convertType(type.getElementType()));
- LLVM::LLVMType elementPtrType = elementType.getPointerTo(memSpace);
+ LLVM::LLVMType elementPtrType =
+ LLVM::LLVMPointerType::get(elementType, memSpace);
SmallVector<Value, 4> operands = {addressOf};
operands.insert(operands.end(), type.getRank() + 1,
auto floatType = getElementTypeOrSelf(resultType).cast<FloatType>();
auto floatOne = rewriter.getFloatAttr(floatType, 1.0);
- if (!operandType.isArrayTy()) {
+ if (!operandType.isa<LLVM::LLVMArrayType>()) {
LLVM::ConstantOp one;
- if (operandType.isVectorTy()) {
+ if (operandType.isa<LLVM::LLVMVectorType>()) {
one = rewriter.create<LLVM::ConstantOp>(
loc, operandType,
SplatElementsAttr::get(resultType.cast<ShapedType>(), floatOne));
op.getOperation(), operands, *getTypeConverter(),
[&](LLVM::LLVMType llvmVectorTy, ValueRange operands) {
auto splatAttr = SplatElementsAttr::get(
- mlir::VectorType::get({llvmVectorTy.getVectorNumElements()},
- floatType),
+ mlir::VectorType::get(
+ {llvmVectorTy.cast<LLVM::LLVMFixedVectorType>()
+ .getNumElements()},
+ floatType),
floatOne);
auto one =
rewriter.create<LLVM::ConstantOp>(loc, llvmVectorTy, splatAttr);
// ptr = ExtractValueOp src, 1
auto ptr = memRefDesc.memRefDescPtr(rewriter, loc);
// castPtr = BitCastOp i8* to structTy*
- auto castPtr =
- rewriter
- .create<LLVM::BitcastOp>(
- loc, targetStructType.cast<LLVM::LLVMType>().getPointerTo(),
- ptr)
- .getResult();
+ auto castPtr = rewriter
+ .create<LLVM::BitcastOp>(
+ loc,
+ LLVM::LLVMPointerType::get(
+ targetStructType.cast<LLVM::LLVMType>()),
+ ptr)
+ .getResult();
// struct = LoadOp castPtr
auto loadOp = rewriter.create<LLVM::LoadOp>(loc, castPtr);
rewriter.replaceOp(memRefCastOp, loadOp.getResult());
Type elementType = operandType.cast<UnrankedMemRefType>().getElementType();
LLVM::LLVMType llvmElementType =
unwrap(typeConverter.convertType(elementType));
- LLVM::LLVMType elementPtrPtrType =
- llvmElementType.getPointerTo(memorySpace).getPointerTo();
+ LLVM::LLVMType elementPtrPtrType = LLVM::LLVMPointerType::get(
+ LLVM::LLVMPointerType::get(llvmElementType, memorySpace));
// Extract pointer to the underlying ranked memref descriptor and cast it to
// ElemType**.
MemRefType targetMemRefType =
castOp.getResult().getType().cast<MemRefType>();
auto llvmTargetDescriptorTy = typeConverter->convertType(targetMemRefType)
- .dyn_cast_or_null<LLVM::LLVMType>();
- if (!llvmTargetDescriptorTy || !llvmTargetDescriptorTy.isStructTy())
+ .dyn_cast_or_null<LLVM::LLVMStructType>();
+ if (!llvmTargetDescriptorTy)
return failure();
// Create descriptor.
// Set pointers and offset.
LLVM::LLVMType llvmElementType =
unwrap(typeConverter->convertType(elementType));
- LLVM::LLVMType elementPtrPtrType =
- llvmElementType.getPointerTo(addressSpace).getPointerTo();
+ auto elementPtrPtrType = LLVM::LLVMPointerType::get(
+ LLVM::LLVMPointerType::get(llvmElementType, addressSpace));
UnrankedMemRefDescriptor::setAllocatedPtr(rewriter, loc, underlyingDescPtr,
elementPtrPtrType, allocatedPtr);
UnrankedMemRefDescriptor::setAlignedPtr(rewriter, loc, *getTypeConverter(),
rewriter.setInsertionPointToStart(bodyBlock);
// Copy size from shape to descriptor.
- LLVM::LLVMType llvmIndexPtrType = indexType.getPointerTo();
+ LLVM::LLVMType llvmIndexPtrType = LLVM::LLVMPointerType::get(indexType);
Value sizeLoadGep = rewriter.create<LLVM::GEPOp>(
loc, llvmIndexPtrType, shapeOperandPtr, ValueRange{indexArg});
Value size = rewriter.create<LLVM::LoadOp>(loc, sizeLoadGep);
Value underlyingRankedDesc = unrankedDesc.memRefDescPtr(rewriter, loc);
Value scalarMemRefDescPtr = rewriter.create<LLVM::BitcastOp>(
loc,
- typeConverter->convertType(scalarMemRefType)
- .cast<LLVM::LLVMType>()
- .getPointerTo(addressSpace),
+ LLVM::LLVMPointerType::get(
+ typeConverter->convertType(scalarMemRefType).cast<LLVM::LLVMType>(),
+ addressSpace),
underlyingRankedDesc);
// Get pointer to offset field of memref<element_type> descriptor.
- Type indexPtrTy =
- getTypeConverter()->getIndexType().getPointerTo(addressSpace);
+ Type indexPtrTy = LLVM::LLVMPointerType::get(
+ getTypeConverter()->getIndexType(), addressSpace);
Value two = rewriter.create<LLVM::ConstantOp>(
loc, typeConverter->convertType(rewriter.getI32Type()),
rewriter.getI32IntegerAttr(2));
auto targetType =
typeConverter->convertType(indexCastOp.getResult().getType())
- .cast<LLVM::LLVMType>();
- auto sourceType = transformed.in().getType().cast<LLVM::LLVMType>();
- unsigned targetBits = targetType.getIntegerBitWidth();
- unsigned sourceBits = sourceType.getIntegerBitWidth();
+ .cast<LLVM::LLVMIntegerType>();
+ auto sourceType = transformed.in().getType().cast<LLVM::LLVMIntegerType>();
+ unsigned targetBits = targetType.getBitWidth();
+ unsigned sourceBits = sourceType.getBitWidth();
if (targetBits == sourceBits)
rewriter.replaceOp(indexCastOp, transformed.in());
// Copy the buffer pointer from the old descriptor to the new one.
Value extracted = sourceMemRef.allocatedPtr(rewriter, loc);
Value bitcastPtr = rewriter.create<LLVM::BitcastOp>(
- loc, targetElementTy.getPointerTo(viewMemRefType.getMemorySpace()),
+ loc,
+ LLVM::LLVMPointerType::get(targetElementTy,
+ viewMemRefType.getMemorySpace()),
extracted);
targetMemRef.setAllocatedPtr(rewriter, loc, bitcastPtr);
// Copy the buffer pointer from the old descriptor to the new one.
extracted = sourceMemRef.alignedPtr(rewriter, loc);
bitcastPtr = rewriter.create<LLVM::BitcastOp>(
- loc, targetElementTy.getPointerTo(viewMemRefType.getMemorySpace()),
+ loc,
+ LLVM::LLVMPointerType::get(targetElementTy,
+ viewMemRefType.getMemorySpace()),
extracted);
targetMemRef.setAlignedPtr(rewriter, loc, bitcastPtr);
Value allocatedPtr = sourceMemRef.allocatedPtr(rewriter, loc);
auto srcMemRefType = viewOp.source().getType().cast<MemRefType>();
Value bitcastPtr = rewriter.create<LLVM::BitcastOp>(
- loc, targetElementTy.getPointerTo(srcMemRefType.getMemorySpace()),
+ loc,
+ LLVM::LLVMPointerType::get(targetElementTy,
+ srcMemRefType.getMemorySpace()),
allocatedPtr);
targetMemRef.setAllocatedPtr(rewriter, loc, bitcastPtr);
alignedPtr = rewriter.create<LLVM::GEPOp>(loc, alignedPtr.getType(),
alignedPtr, adaptor.byte_shift());
bitcastPtr = rewriter.create<LLVM::BitcastOp>(
- loc, targetElementTy.getPointerTo(srcMemRefType.getMemorySpace()),
+ loc,
+ LLVM::LLVMPointerType::get(targetElementTy,
+ srcMemRefType.getMemorySpace()),
alignedPtr);
targetMemRef.setAlignedPtr(rewriter, loc, bitcastPtr);
auto indexType = IndexType::get(context);
// Alloca with proper alignment. We do not expect optimizations of this
// alloca op and so we omit allocating at the entry block.
- auto ptrType = operand.getType().cast<LLVM::LLVMType>().getPointerTo();
+ auto ptrType =
+ LLVM::LLVMPointerType::get(operand.getType().cast<LLVM::LLVMType>());
Value one = builder.create<LLVM::ConstantOp>(loc, int64Ty,
IntegerAttr::get(indexType, 1));
Value allocated =
Value base;
if (failed(getBase(rewriter, loc, memref, memRefType, base)))
return failure();
- auto pType = type.template cast<LLVM::LLVMType>().getPointerTo();
+ auto pType = LLVM::LLVMPointerType::get(type.template cast<LLVM::LLVMType>());
base = rewriter.create<LLVM::BitcastOp>(loc, pType, base);
ptr = rewriter.create<LLVM::GEPOp>(loc, pType, base);
return success();
return failure();
auto llvmSourceDescriptorTy =
- operands[0].getType().dyn_cast<LLVM::LLVMType>();
- if (!llvmSourceDescriptorTy || !llvmSourceDescriptorTy.isStructTy())
+ operands[0].getType().dyn_cast<LLVM::LLVMStructType>();
+ if (!llvmSourceDescriptorTy)
return failure();
MemRefDescriptor sourceMemRef(operands[0]);
auto llvmTargetDescriptorTy = typeConverter->convertType(targetMemRefType)
- .dyn_cast_or_null<LLVM::LLVMType>();
- if (!llvmTargetDescriptorTy || !llvmTargetDescriptorTy.isStructTy())
+ .dyn_cast_or_null<LLVM::LLVMStructType>();
+ if (!llvmTargetDescriptorTy)
return failure();
// Only contiguous source buffers supported atm.
// TODO: support alignment when possible.
Value dataPtr = this->getStridedElementPtr(
loc, memRefType, adaptor.source(), adaptor.indices(), rewriter);
- auto vecTy =
- toLLVMTy(xferOp.getVectorType()).template cast<LLVM::LLVMType>();
+ auto vecTy = toLLVMTy(xferOp.getVectorType())
+ .template cast<LLVM::LLVMFixedVectorType>();
Value vectorDataPtr;
if (memRefType.getMemorySpace() == 0)
- vectorDataPtr =
- rewriter.create<LLVM::BitcastOp>(loc, vecTy.getPointerTo(), dataPtr);
+ vectorDataPtr = rewriter.create<LLVM::BitcastOp>(
+ loc, LLVM::LLVMPointerType::get(vecTy), dataPtr);
else
vectorDataPtr = rewriter.create<LLVM::AddrSpaceCastOp>(
- loc, vecTy.getPointerTo(), dataPtr);
+ loc, LLVM::LLVMPointerType::get(vecTy), dataPtr);
if (!xferOp.isMaskedDim(0))
return replaceTransferOpWithLoadOrStore(rewriter,
//
// TODO: when the leaf transfer rank is k > 1, we need the last `k`
// dimensions here.
- unsigned vecWidth = vecTy.getVectorNumElements();
+ unsigned vecWidth = vecTy.getNumElements();
unsigned lastIndex = llvm::size(xferOp.indices()) - 1;
Value off = xferOp.indices()[lastIndex];
Value dim = rewriter.create<DimOp>(loc, xferOp.source(), lastIndex);
auto toLLVMTy = [&](Type t) {
return this->getTypeConverter()->convertType(t);
};
- LLVM::LLVMType vecTy =
- toLLVMTy(xferOp.getVectorType()).template cast<LLVM::LLVMType>();
- unsigned vecWidth = vecTy.getVectorNumElements();
+ auto vecTy = toLLVMTy(xferOp.getVectorType())
+ .template cast<LLVM::LLVMFixedVectorType>();
+ unsigned vecWidth = vecTy.getNumElements();
Location loc = xferOp->getLoc();
// The backend result vector scalarization have trouble scalarize
auto argType = type.dyn_cast<LLVM::LLVMType>();
if (!argType)
return parser.emitError(trailingTypeLoc, "expected LLVM IR dialect type");
- if (argType.isVectorTy())
- resultType =
- LLVMType::getVectorTy(resultType, argType.getVectorNumElements());
+ if (auto vecArgType = argType.dyn_cast<LLVM::LLVMFixedVectorType>())
+ resultType = LLVMType::getVectorTy(resultType, vecArgType.getNumElements());
+ assert(!argType.isa<LLVM::LLVMScalableVectorType>() &&
+ "unhandled scalable vector");
result.addTypes({resultType});
return success();
//===----------------------------------------------------------------------===//
static void printAllocaOp(OpAsmPrinter &p, AllocaOp &op) {
- auto elemTy = op.getType().cast<LLVM::LLVMType>().getPointerElementTy();
+ auto elemTy = op.getType().cast<LLVM::LLVMPointerType>().getElementType();
auto funcTy = FunctionType::get(op.getContext(), {op.arraySize().getType()},
{op.getType()});
// the resulting type wrapped in MLIR, or nullptr on error.
static Type getLoadStoreElementType(OpAsmParser &parser, Type type,
llvm::SMLoc trailingTypeLoc) {
- auto llvmTy = type.dyn_cast<LLVM::LLVMType>();
+ auto llvmTy = type.dyn_cast<LLVM::LLVMPointerType>();
if (!llvmTy)
- return parser.emitError(trailingTypeLoc, "expected LLVM IR dialect type"),
- nullptr;
- if (!llvmTy.isPointerTy())
return parser.emitError(trailingTypeLoc, "expected LLVM pointer type"),
nullptr;
- return llvmTy.getPointerElementTy();
+ return llvmTy.getElementType();
}
// <operation> ::= `llvm.load` `volatile` ssa-use attribute-dict? `:` type
auto llvmFuncType = LLVM::LLVMType::getFunctionTy(llvmResultType, argTypes,
/*isVarArg=*/false);
- auto wrappedFuncType = llvmFuncType.getPointerTo();
+ auto wrappedFuncType = LLVM::LLVMPointerType::get(llvmFuncType);
auto funcArguments = llvm::makeArrayRef(operands).drop_front();
for (unsigned idx = 0, ie = op.getNumOperands(); idx < ie; idx++) {
value = op.getOperand(idx);
- bool isFilter = value.getType().cast<LLVMType>().isArrayTy();
+ bool isFilter = value.getType().isa<LLVMArrayType>();
if (isFilter) {
// FIXME: Verify filter clauses when arrays are appropriately handled
} else {
for (auto value : op.getOperands()) {
// Similar to llvm - if clause is an array type then it is filter
// clause else catch clause
- bool isArrayTy = value.getType().cast<LLVMType>().isArrayTy();
+ bool isArrayTy = value.getType().isa<LLVMArrayType>();
p << '(' << (isArrayTy ? "filter " : "catch ") << value << " : "
<< value.getType() << ") ";
}
fnType = fn.getType();
}
- if (!fnType.isFunctionTy())
+
+ LLVMFunctionType funcType = fnType.dyn_cast<LLVMFunctionType>();
+ if (!funcType)
return op.emitOpError("callee does not have a functional type: ") << fnType;
// Verify that the operand and result types match the callee.
- if (!fnType.isFunctionVarArg() &&
- fnType.getFunctionNumParams() != (op.getNumOperands() - isIndirect))
+ if (!funcType.isVarArg() &&
+ funcType.getNumParams() != (op.getNumOperands() - isIndirect))
return op.emitOpError()
<< "incorrect number of operands ("
<< (op.getNumOperands() - isIndirect)
- << ") for callee (expecting: " << fnType.getFunctionNumParams()
- << ")";
+ << ") for callee (expecting: " << funcType.getNumParams() << ")";
- if (fnType.getFunctionNumParams() > (op.getNumOperands() - isIndirect))
+ if (funcType.getNumParams() > (op.getNumOperands() - isIndirect))
return op.emitOpError() << "incorrect number of operands ("
<< (op.getNumOperands() - isIndirect)
<< ") for varargs callee (expecting at least: "
- << fnType.getFunctionNumParams() << ")";
+ << funcType.getNumParams() << ")";
- for (unsigned i = 0, e = fnType.getFunctionNumParams(); i != e; ++i)
- if (op.getOperand(i + isIndirect).getType() !=
- fnType.getFunctionParamType(i))
+ for (unsigned i = 0, e = funcType.getNumParams(); i != e; ++i)
+ if (op.getOperand(i + isIndirect).getType() != funcType.getParamType(i))
return op.emitOpError() << "operand type mismatch for operand " << i
<< ": " << op.getOperand(i + isIndirect).getType()
- << " != " << fnType.getFunctionParamType(i);
+ << " != " << funcType.getParamType(i);
if (op.getNumResults() &&
- op.getResult(0).getType() != fnType.getFunctionResultType())
+ op.getResult(0).getType() != funcType.getReturnType())
return op.emitOpError()
<< "result type mismatch: " << op.getResult(0).getType()
- << " != " << fnType.getFunctionResultType();
+ << " != " << funcType.getReturnType();
return success();
}
}
auto llvmFuncType = LLVM::LLVMType::getFunctionTy(llvmResultType, argTypes,
/*isVarArg=*/false);
- auto wrappedFuncType = llvmFuncType.getPointerTo();
+ auto wrappedFuncType = LLVM::LLVMPointerType::get(llvmFuncType);
auto funcArguments =
ArrayRef<OpAsmParser::OperandType>(operands).drop_front();
void LLVM::ExtractElementOp::build(OpBuilder &b, OperationState &result,
Value vector, Value position,
ArrayRef<NamedAttribute> attrs) {
- auto wrappedVectorType = vector.getType().cast<LLVM::LLVMType>();
- auto llvmType = wrappedVectorType.getVectorElementType();
+ auto vectorType = vector.getType().cast<LLVM::LLVMVectorType>();
+ auto llvmType = vectorType.getElementType();
build(b, result, llvmType, vector, position);
result.addAttributes(attrs);
}
parser.resolveOperand(vector, type, result.operands) ||
parser.resolveOperand(position, positionType, result.operands))
return failure();
- auto wrappedVectorType = type.dyn_cast<LLVM::LLVMType>();
- if (!wrappedVectorType || !wrappedVectorType.isVectorTy())
+ auto vectorType = type.dyn_cast<LLVM::LLVMVectorType>();
+ if (!vectorType)
return parser.emitError(
loc, "expected LLVM IR dialect vector type for operand #1");
- result.addTypes(wrappedVectorType.getVectorElementType());
+ result.addTypes(vectorType.getElementType());
return success();
}
ArrayAttr positionAttr,
llvm::SMLoc attributeLoc,
llvm::SMLoc typeLoc) {
- auto wrappedContainerType = containerType.dyn_cast<LLVM::LLVMType>();
- if (!wrappedContainerType)
+ auto llvmType = containerType.dyn_cast<LLVM::LLVMType>();
+ if (!llvmType)
return parser.emitError(typeLoc, "expected LLVM IR Dialect type"), nullptr;
// Infer the element type from the structure type: iteratively step inside the
"expected an array of integer literals"),
nullptr;
int position = positionElementAttr.getInt();
- if (wrappedContainerType.isArrayTy()) {
- if (position < 0 || static_cast<unsigned>(position) >=
- wrappedContainerType.getArrayNumElements())
+ if (auto arrayType = llvmType.dyn_cast<LLVMArrayType>()) {
+ if (position < 0 ||
+ static_cast<unsigned>(position) >= arrayType.getNumElements())
return parser.emitError(attributeLoc, "position out of bounds"),
nullptr;
- wrappedContainerType = wrappedContainerType.getArrayElementType();
- } else if (wrappedContainerType.isStructTy()) {
- if (position < 0 || static_cast<unsigned>(position) >=
- wrappedContainerType.getStructNumElements())
+ llvmType = arrayType.getElementType();
+ } else if (auto structType = llvmType.dyn_cast<LLVMStructType>()) {
+ if (position < 0 ||
+ static_cast<unsigned>(position) >= structType.getBody().size())
return parser.emitError(attributeLoc, "position out of bounds"),
nullptr;
- wrappedContainerType =
- wrappedContainerType.getStructElementType(position);
+ llvmType = structType.getBody()[position];
} else {
- return parser.emitError(typeLoc,
- "expected wrapped LLVM IR structure/array type"),
+ return parser.emitError(typeLoc, "expected LLVM IR structure/array type"),
nullptr;
}
}
- return wrappedContainerType;
+ return llvmType;
}
// <operation> ::= `llvm.extractvalue` ssa-use
parser.parseColonType(vectorType))
return failure();
- auto wrappedVectorType = vectorType.dyn_cast<LLVM::LLVMType>();
- if (!wrappedVectorType || !wrappedVectorType.isVectorTy())
+ auto llvmVectorType = vectorType.dyn_cast<LLVM::LLVMVectorType>();
+ if (!llvmVectorType)
return parser.emitError(
loc, "expected LLVM IR dialect vector type for operand #1");
- auto valueType = wrappedVectorType.getVectorElementType();
+ Type valueType = llvmVectorType.getElementType();
if (!valueType)
return failure();
return op.emitOpError(
"must reference a global defined by 'llvm.mlir.global' or 'llvm.func'");
- if (global && global.getType().getPointerTo(global.addr_space()) !=
- op.getResult().getType())
+ if (global &&
+ LLVM::LLVMPointerType::get(global.getType(), global.addr_space()) !=
+ op.getResult().getType())
return op.emitOpError(
"the type must be a pointer to the type of the referenced global");
- if (function && function.getType().getPointerTo() != op.getResult().getType())
+ if (function && LLVM::LLVMPointerType::get(function.getType()) !=
+ op.getResult().getType())
return op.emitOpError(
"the type must be a pointer to the type of the referenced function");
if (vectorType.getRank() != 1)
return op->emitOpError("only 1-d vector is allowed");
- auto llvmVector = llvmType.dyn_cast<LLVMVectorType>();
- if (llvmVector.isa<LLVMScalableVectorType>())
+ auto llvmVector = llvmType.dyn_cast<LLVMFixedVectorType>();
+ if (!llvmVector)
return op->emitOpError("only fixed-sized vector is allowed");
- if (vectorType.getDimSize(0) != llvmVector.getVectorNumElements())
+ if (vectorType.getDimSize(0) != llvmVector.getNumElements())
return op->emitOpError(
"invalid cast between vectors with mismatching sizes");
"be an index-compatible integer");
auto ptrType = structType.getBody()[1].dyn_cast<LLVMPointerType>();
- if (!ptrType || !ptrType.getPointerElementTy().isIntegerTy(8))
+ auto ptrElementType =
+ ptrType ? ptrType.getElementType().dyn_cast<LLVMIntegerType>()
+ : nullptr;
+ if (!ptrElementType || ptrElementType.getBitWidth() != 8)
return op->emitOpError("expected second element of a memref descriptor "
"to be an !llvm.ptr<i8>");
return op.emitOpError("must appear at the module level");
if (auto strAttr = op.getValueOrNull().dyn_cast_or_null<StringAttr>()) {
- auto type = op.getType();
- if (!type.isArrayTy() || !type.getArrayElementType().isIntegerTy(8) ||
- type.getArrayNumElements() != strAttr.getValue().size())
+ auto type = op.getType().dyn_cast<LLVMArrayType>();
+ LLVMIntegerType elementType =
+ type ? type.getElementType().dyn_cast<LLVMIntegerType>() : nullptr;
+ if (!elementType || elementType.getBitWidth() != 8 ||
+ type.getNumElements() != strAttr.getValue().size())
return op.emitOpError(
"requires an i8 array type of the length equal to that of the string "
"attribute");
void LLVM::ShuffleVectorOp::build(OpBuilder &b, OperationState &result,
Value v1, Value v2, ArrayAttr mask,
ArrayRef<NamedAttribute> attrs) {
- auto wrappedContainerType1 = v1.getType().cast<LLVM::LLVMType>();
- auto vType = LLVMType::getVectorTy(
- wrappedContainerType1.getVectorElementType(), mask.size());
+ auto containerType = v1.getType().cast<LLVM::LLVMVectorType>();
+ auto vType =
+ LLVMType::getVectorTy(containerType.getElementType(), mask.size());
build(b, result, vType, v1, v2, mask);
result.addAttributes(attrs);
}
parser.resolveOperand(v1, typeV1, result.operands) ||
parser.resolveOperand(v2, typeV2, result.operands))
return failure();
- auto wrappedContainerType1 = typeV1.dyn_cast<LLVM::LLVMType>();
- if (!wrappedContainerType1 || !wrappedContainerType1.isVectorTy())
+ auto containerType = typeV1.dyn_cast<LLVM::LLVMVectorType>();
+ if (!containerType)
return parser.emitError(
loc, "expected LLVM IR dialect vector type for operand #1");
- auto vType = LLVMType::getVectorTy(
- wrappedContainerType1.getVectorElementType(), maskAttr.size());
+ auto vType =
+ LLVMType::getVectorTy(containerType.getElementType(), maskAttr.size());
result.addTypes(vType);
return success();
}
auto *entry = new Block;
push_back(entry);
- LLVMType type = getType();
- for (unsigned i = 0, e = type.getFunctionNumParams(); i < e; ++i)
- entry->addArgument(type.getFunctionParamType(i));
+ LLVMFunctionType type = getType();
+ for (unsigned i = 0, e = type.getNumParams(); i < e; ++i)
+ entry->addArgument(type.getParamType(i));
return entry;
}
if (argAttrs.empty())
return;
- unsigned numInputs = type.getFunctionNumParams();
+ unsigned numInputs = type.cast<LLVMFunctionType>().getNumParams();
assert(numInputs == argAttrs.size() &&
"expected as many argument attribute lists as arguments");
SmallString<8> argAttrName;
p << stringifyLinkage(op.linkage()) << ' ';
p.printSymbolName(op.getName());
- LLVMType fnType = op.getType();
+ LLVMFunctionType fnType = op.getType();
SmallVector<Type, 8> argTypes;
SmallVector<Type, 1> resTypes;
- argTypes.reserve(fnType.getFunctionNumParams());
- for (unsigned i = 0, e = fnType.getFunctionNumParams(); i < e; ++i)
- argTypes.push_back(fnType.getFunctionParamType(i));
+ argTypes.reserve(fnType.getNumParams());
+ for (unsigned i = 0, e = fnType.getNumParams(); i < e; ++i)
+ argTypes.push_back(fnType.getParamType(i));
- LLVMType returnType = fnType.getFunctionResultType();
- if (!returnType.isVoidTy())
+ LLVMType returnType = fnType.getReturnType();
+ if (!returnType.isa<LLVMVoidType>())
resTypes.push_back(returnType);
impl::printFunctionSignature(p, op, argTypes, op.isVarArg(), resTypes);
// attribute is present. This can check for preconditions of the
// getNumArguments hook not failing.
LogicalResult LLVMFuncOp::verifyType() {
- auto llvmType = getTypeAttr().getValue().dyn_cast_or_null<LLVMType>();
- if (!llvmType || !llvmType.isFunctionTy())
+ auto llvmType = getTypeAttr().getValue().dyn_cast_or_null<LLVMFunctionType>();
+ if (!llvmType)
return emitOpError("requires '" + getTypeAttrName() +
"' attribute of wrapped LLVM function type");
// Hook for OpTrait::FunctionLike, returns the number of function arguments.
// Depends on the type attribute being correct as checked by verifyType
-unsigned LLVMFuncOp::getNumFuncArguments() {
- return getType().getFunctionNumParams();
-}
+unsigned LLVMFuncOp::getNumFuncArguments() { return getType().getNumParams(); }
// Hook for OpTrait::FunctionLike, returns the number of function results.
// Depends on the type attribute being correct as checked by verifyType
// If we modeled a void return as one result, then it would be possible to
// attach an MLIR result attribute to it, and it isn't clear what semantics we
// would assign to that.
- if (getType().getFunctionResultType().isVoidTy())
+ if (getType().getReturnType().isa<LLVMVoidType>())
return 0;
return 1;
}
if (op.isVarArg())
return op.emitOpError("only external functions can be variadic");
- unsigned numArguments = op.getType().getFunctionNumParams();
+ unsigned numArguments = op.getType().getNumParams();
Block &entryBlock = op.front();
for (unsigned i = 0; i < numArguments; ++i) {
Type argType = entryBlock.getArgument(i).getType();
if (!argLLVMType)
return op.emitOpError("entry block argument #")
<< i << " is not of LLVM type";
- if (op.getType().getFunctionParamType(i) != argLLVMType)
+ if (op.getType().getParamType(i) != argLLVMType)
return op.emitOpError("the type of entry block argument #")
<< i << " does not match the function signature";
}
parseAtomicOrdering(parser, result, "ordering") ||
parser.parseOptionalAttrDict(result.attributes) ||
parser.parseColonType(type) ||
- parser.resolveOperand(ptr, type.getPointerTo(), result.operands) ||
+ parser.resolveOperand(ptr, LLVM::LLVMPointerType::get(type),
+ result.operands) ||
parser.resolveOperand(val, type, result.operands))
return failure();
}
static LogicalResult verify(AtomicRMWOp op) {
- auto ptrType = op.ptr().getType().cast<LLVM::LLVMType>();
+ auto ptrType = op.ptr().getType().cast<LLVM::LLVMPointerType>();
auto valType = op.val().getType().cast<LLVM::LLVMType>();
- if (valType != ptrType.getPointerElementTy())
+ if (valType != ptrType.getElementType())
return op.emitOpError("expected LLVM IR element type for operand #0 to "
"match type for operand #1");
auto resType = op.res().getType().cast<LLVM::LLVMType>();
return op.emitOpError(
"expected LLVM IR result type to match type for operand #1");
if (op.bin_op() == AtomicBinOp::fadd || op.bin_op() == AtomicBinOp::fsub) {
- if (!valType.isFloatingPointTy())
+ if (!mlir::LLVM::isCompatibleFloatingPointType(valType))
return op.emitOpError("expected LLVM IR floating point type");
} else if (op.bin_op() == AtomicBinOp::xchg) {
- if (!valType.isIntegerTy(8) && !valType.isIntegerTy(16) &&
- !valType.isIntegerTy(32) && !valType.isIntegerTy(64) &&
- !valType.isBFloatTy() && !valType.isHalfTy() && !valType.isFloatTy() &&
- !valType.isDoubleTy())
+ auto intType = valType.dyn_cast<LLVMIntegerType>();
+ unsigned intBitWidth = intType ? intType.getBitWidth() : 0;
+ if (intBitWidth != 8 && intBitWidth != 16 && intBitWidth != 32 &&
+ intBitWidth != 64 && !valType.isa<LLVMBFloatType>() &&
+ !valType.isa<LLVMHalfType>() && !valType.isa<LLVMFloatType>() &&
+ !valType.isa<LLVMDoubleType>())
return op.emitOpError("unexpected LLVM IR type for 'xchg' bin_op");
} else {
- if (!valType.isIntegerTy(8) && !valType.isIntegerTy(16) &&
- !valType.isIntegerTy(32) && !valType.isIntegerTy(64))
+ auto intType = valType.dyn_cast<LLVMIntegerType>();
+ unsigned intBitWidth = intType ? intType.getBitWidth() : 0;
+ if (intBitWidth != 8 && intBitWidth != 16 && intBitWidth != 32 &&
+ intBitWidth != 64)
return op.emitOpError("expected LLVM IR integer type");
}
return success();
parseAtomicOrdering(parser, result, "failure_ordering") ||
parser.parseOptionalAttrDict(result.attributes) ||
parser.parseColonType(type) ||
- parser.resolveOperand(ptr, type.getPointerTo(), result.operands) ||
+ parser.resolveOperand(ptr, LLVM::LLVMPointerType::get(type),
+ result.operands) ||
parser.resolveOperand(cmp, type, result.operands) ||
parser.resolveOperand(val, type, result.operands))
return failure();
}
static LogicalResult verify(AtomicCmpXchgOp op) {
- auto ptrType = op.ptr().getType().cast<LLVM::LLVMType>();
- if (!ptrType.isPointerTy())
+ auto ptrType = op.ptr().getType().cast<LLVM::LLVMPointerType>();
+ if (!ptrType)
return op.emitOpError("expected LLVM IR pointer type for operand #0");
auto cmpType = op.cmp().getType().cast<LLVM::LLVMType>();
auto valType = op.val().getType().cast<LLVM::LLVMType>();
- if (cmpType != ptrType.getPointerElementTy() || cmpType != valType)
+ if (cmpType != ptrType.getElementType() || cmpType != valType)
return op.emitOpError("expected LLVM IR element type for operand #0 to "
"match type for all other operands");
- if (!valType.isPointerTy() && !valType.isIntegerTy(8) &&
- !valType.isIntegerTy(16) && !valType.isIntegerTy(32) &&
- !valType.isIntegerTy(64) && !valType.isBFloatTy() &&
- !valType.isHalfTy() && !valType.isFloatTy() && !valType.isDoubleTy())
+ auto intType = valType.dyn_cast<LLVMIntegerType>();
+ unsigned intBitWidth = intType ? intType.getBitWidth() : 0;
+ if (!valType.isa<LLVMPointerType>() && intBitWidth != 8 &&
+ intBitWidth != 16 && intBitWidth != 32 && intBitWidth != 64 &&
+ !valType.isa<LLVMBFloatType>() && !valType.isa<LLVMHalfType>() &&
+ !valType.isa<LLVMFloatType>() && !valType.isa<LLVMDoubleType>())
return op.emitOpError("unexpected LLVM IR type");
if (op.success_ordering() < AtomicOrdering::monotonic ||
op.failure_ordering() < AtomicOrdering::monotonic)
return static_cast<LLVMDialect &>(Type::getDialect());
}
-//----------------------------------------------------------------------------//
-// Misc type utilities.
-
-llvm::TypeSize LLVMType::getPrimitiveSizeInBits() {
- return llvm::TypeSwitch<LLVMType, llvm::TypeSize>(*this)
- .Case<LLVMHalfType, LLVMBFloatType>(
- [](LLVMType) { return llvm::TypeSize::Fixed(16); })
- .Case<LLVMFloatType>([](LLVMType) { return llvm::TypeSize::Fixed(32); })
- .Case<LLVMDoubleType, LLVMX86MMXType>(
- [](LLVMType) { return llvm::TypeSize::Fixed(64); })
- .Case<LLVMIntegerType>([](LLVMIntegerType intTy) {
- return llvm::TypeSize::Fixed(intTy.getBitWidth());
- })
- .Case<LLVMX86FP80Type>([](LLVMType) { return llvm::TypeSize::Fixed(80); })
- .Case<LLVMPPCFP128Type, LLVMFP128Type>(
- [](LLVMType) { return llvm::TypeSize::Fixed(128); })
- .Case<LLVMVectorType>([](LLVMVectorType t) {
- llvm::TypeSize elementSize =
- t.getElementType().getPrimitiveSizeInBits();
- llvm::ElementCount elementCount = t.getElementCount();
- assert(!elementSize.isScalable() &&
- "vector type should have fixed-width elements");
- return llvm::TypeSize(elementSize.getFixedSize() *
- elementCount.getKnownMinValue(),
- elementCount.isScalable());
- })
- .Default([](LLVMType ty) {
- assert((ty.isa<LLVMVoidType, LLVMLabelType, LLVMMetadataType,
- LLVMTokenType, LLVMStructType, LLVMArrayType,
- LLVMPointerType, LLVMFunctionType>()) &&
- "unexpected missing support for primitive type");
- return llvm::TypeSize::Fixed(0);
- });
-}
-
-//----------------------------------------------------------------------------//
-// Integer type utilities.
-
-bool LLVMType::isIntegerTy(unsigned bitwidth) {
- if (auto intType = dyn_cast<LLVMIntegerType>())
- return intType.getBitWidth() == bitwidth;
- return false;
-}
-unsigned LLVMType::getIntegerBitWidth() {
- return cast<LLVMIntegerType>().getBitWidth();
-}
-
-LLVMType LLVMType::getArrayElementType() {
- return cast<LLVMArrayType>().getElementType();
-}
-
-//----------------------------------------------------------------------------//
-// Array type utilities.
-
-unsigned LLVMType::getArrayNumElements() {
- return cast<LLVMArrayType>().getNumElements();
-}
-
-bool LLVMType::isArrayTy() { return isa<LLVMArrayType>(); }
-
-//----------------------------------------------------------------------------//
-// Vector type utilities.
-
-LLVMType LLVMType::getVectorElementType() {
- return cast<LLVMVectorType>().getElementType();
-}
-
-unsigned LLVMType::getVectorNumElements() {
- return cast<LLVMFixedVectorType>().getNumElements();
-}
-llvm::ElementCount LLVMType::getVectorElementCount() {
- return cast<LLVMVectorType>().getElementCount();
-}
-
-bool LLVMType::isVectorTy() { return isa<LLVMVectorType>(); }
-
-//----------------------------------------------------------------------------//
-// Function type utilities.
-
-LLVMType LLVMType::getFunctionParamType(unsigned argIdx) {
- return cast<LLVMFunctionType>().getParamType(argIdx);
-}
-
-unsigned LLVMType::getFunctionNumParams() {
- return cast<LLVMFunctionType>().getNumParams();
-}
-
-LLVMType LLVMType::getFunctionResultType() {
- return cast<LLVMFunctionType>().getReturnType();
-}
-
-bool LLVMType::isFunctionTy() { return isa<LLVMFunctionType>(); }
-
-bool LLVMType::isFunctionVarArg() {
- return cast<LLVMFunctionType>().isVarArg();
-}
-
-//----------------------------------------------------------------------------//
-// Pointer type utilities.
-
-LLVMType LLVMType::getPointerTo(unsigned addrSpace) {
- return LLVMPointerType::get(*this, addrSpace);
-}
-
-LLVMType LLVMType::getPointerElementTy() {
- return cast<LLVMPointerType>().getElementType();
-}
-
-bool LLVMType::isPointerTy() { return isa<LLVMPointerType>(); }
-
-//----------------------------------------------------------------------------//
-// Struct type utilities.
-
-LLVMType LLVMType::getStructElementType(unsigned i) {
- return cast<LLVMStructType>().getBody()[i];
-}
-
-unsigned LLVMType::getStructNumElements() {
- return cast<LLVMStructType>().getBody().size();
-}
-
-bool LLVMType::isStructTy() { return isa<LLVMStructType>(); }
-
//----------------------------------------------------------------------------//
// Utilities used to generate floating point types.
return LLVMIntegerType::get(context, numBits);
}
+LLVMType LLVMType::getInt8PtrTy(MLIRContext *context) {
+ return LLVMPointerType::get(LLVMIntegerType::get(context, 8));
+}
+
//----------------------------------------------------------------------------//
// Utilities used to generate other miscellaneous types.
return LLVMVoidType::get(context);
}
-bool LLVMType::isVoidTy() { return isa<LLVMVoidType>(); }
-
//----------------------------------------------------------------------------//
// Creation and setting of LLVM's identified struct types
bool LLVMVectorType::isValidElementType(LLVMType type) {
return type.isa<LLVMIntegerType, LLVMPointerType>() ||
- type.isFloatingPointTy();
+ mlir::LLVM::isCompatibleFloatingPointType(type);
}
/// Support type casting functionality.
unsigned LLVMScalableVectorType::getMinNumElements() {
return getImpl()->numElements;
}
+
+//===----------------------------------------------------------------------===//
+// Utility functions.
+//===----------------------------------------------------------------------===//
+
+llvm::TypeSize mlir::LLVM::getPrimitiveTypeSizeInBits(Type type) {
+ assert(isCompatibleType(type) &&
+ "expected a type compatible with the LLVM dialect");
+
+ return llvm::TypeSwitch<Type, llvm::TypeSize>(type)
+ .Case<LLVMHalfType, LLVMBFloatType>(
+ [](LLVMType) { return llvm::TypeSize::Fixed(16); })
+ .Case<LLVMFloatType>([](LLVMType) { return llvm::TypeSize::Fixed(32); })
+ .Case<LLVMDoubleType, LLVMX86MMXType>(
+ [](LLVMType) { return llvm::TypeSize::Fixed(64); })
+ .Case<LLVMIntegerType>([](LLVMIntegerType intTy) {
+ return llvm::TypeSize::Fixed(intTy.getBitWidth());
+ })
+ .Case<LLVMX86FP80Type>([](LLVMType) { return llvm::TypeSize::Fixed(80); })
+ .Case<LLVMPPCFP128Type, LLVMFP128Type>(
+ [](LLVMType) { return llvm::TypeSize::Fixed(128); })
+ .Case<LLVMVectorType>([](LLVMVectorType t) {
+ llvm::TypeSize elementSize =
+ getPrimitiveTypeSizeInBits(t.getElementType());
+ llvm::ElementCount elementCount = t.getElementCount();
+ assert(!elementSize.isScalable() &&
+ "vector type should have fixed-width elements");
+ return llvm::TypeSize(elementSize.getFixedSize() *
+ elementCount.getKnownMinValue(),
+ elementCount.isScalable());
+ })
+ .Default([](Type ty) {
+ assert((ty.isa<LLVMVoidType, LLVMLabelType, LLVMMetadataType,
+ LLVMTokenType, LLVMStructType, LLVMArrayType,
+ LLVMPointerType, LLVMFunctionType>()) &&
+ "unexpected missing support for primitive type");
+ return llvm::TypeSize::Fixed(0);
+ });
+}
for (auto &attr : result.attributes) {
if (attr.first != "return_value_and_is_valid")
continue;
- if (type.isStructTy() && type.getStructNumElements() > 0)
- type = type.getStructElementType(0);
+ auto structType = type.dyn_cast<LLVM::LLVMStructType>();
+ if (structType && !structType.getBody().empty())
+ type = structType.getBody()[0];
break;
}
Error checkCompatibleReturnType(LLVM::LLVMFuncOp mainFunction);
template <>
Error checkCompatibleReturnType<int32_t>(LLVM::LLVMFuncOp mainFunction) {
- if (!mainFunction.getType().getFunctionResultType().isIntegerTy(32))
+ auto resultType = mainFunction.getType()
+ .cast<LLVM::LLVMFunctionType>()
+ .getReturnType()
+ .dyn_cast<LLVM::LLVMIntegerType>();
+ if (!resultType || resultType.getBitWidth() != 32)
return make_string_error("only single llvm.i32 function result supported");
return Error::success();
}
template <>
Error checkCompatibleReturnType<int64_t>(LLVM::LLVMFuncOp mainFunction) {
- if (!mainFunction.getType().getFunctionResultType().isIntegerTy(64))
+ auto resultType = mainFunction.getType()
+ .cast<LLVM::LLVMFunctionType>()
+ .getReturnType()
+ .dyn_cast<LLVM::LLVMIntegerType>();
+ if (!resultType || resultType.getBitWidth() != 64)
return make_string_error("only single llvm.i64 function result supported");
return Error::success();
}
template <>
Error checkCompatibleReturnType<float>(LLVM::LLVMFuncOp mainFunction) {
- if (!mainFunction.getType().getFunctionResultType().isFloatTy())
+ if (!mainFunction.getType()
+ .cast<LLVM::LLVMFunctionType>()
+ .getReturnType()
+ .isa<LLVM::LLVMFloatType>())
return make_string_error("only single llvm.f32 function result supported");
return Error::success();
}
if (!mainFunction || mainFunction.isExternal())
return make_string_error("entry point not found");
- if (mainFunction.getType().getFunctionNumParams() != 0)
+ if (mainFunction.getType().cast<LLVM::LLVMFunctionType>().getNumParams() != 0)
return make_string_error("function inputs not supported");
if (Error error = checkCompatibleReturnType<Type>(mainFunction))
if (!type)
return nullptr;
- if (type.isIntegerTy())
- return b.getIntegerType(type.getIntegerBitWidth());
+ if (auto intType = type.dyn_cast<LLVMIntegerType>())
+ return b.getIntegerType(intType.getBitWidth());
- if (type.isFloatTy())
+ if (type.isa<LLVMFloatType>())
return b.getF32Type();
- if (type.isDoubleTy())
+ if (type.isa<LLVMDoubleType>())
return b.getF64Type();
// LLVM vectors can only contain scalars.
- if (type.isVectorTy()) {
- auto numElements = type.getVectorElementCount();
+ if (auto vectorType = type.dyn_cast<LLVM::LLVMVectorType>()) {
+ auto numElements = vectorType.getElementCount();
if (numElements.isScalable()) {
emitError(unknownLoc) << "scalable vectors not supported";
return nullptr;
}
- Type elementType = getStdTypeForAttr(type.getVectorElementType());
+ Type elementType = getStdTypeForAttr(vectorType.getElementType());
if (!elementType)
return nullptr;
return VectorType::get(numElements.getKnownMinValue(), elementType);
}
// LLVM arrays can contain other arrays or vectors.
- if (type.isArrayTy()) {
+ if (auto arrayType = type.dyn_cast<LLVMArrayType>()) {
// Recover the nested array shape.
SmallVector<int64_t, 4> shape;
- shape.push_back(type.getArrayNumElements());
- while (type.getArrayElementType().isArrayTy()) {
- type = type.getArrayElementType();
- shape.push_back(type.getArrayNumElements());
+ shape.push_back(arrayType.getNumElements());
+ while (arrayType.getElementType().isa<LLVMArrayType>()) {
+ arrayType = arrayType.getElementType().cast<LLVMArrayType>();
+ shape.push_back(arrayType.getNumElements());
}
// If the innermost type is a vector, use the multi-dimensional vector as
// attribute type.
- if (type.getArrayElementType().isVectorTy()) {
- LLVMType vectorType = type.getArrayElementType();
- auto numElements = vectorType.getVectorElementCount();
+ if (auto vectorType =
+ arrayType.getElementType().dyn_cast<LLVMVectorType>()) {
+ auto numElements = vectorType.getElementCount();
if (numElements.isScalable()) {
emitError(unknownLoc) << "scalable vectors not supported";
return nullptr;
}
shape.push_back(numElements.getKnownMinValue());
- Type elementType = getStdTypeForAttr(vectorType.getVectorElementType());
+ Type elementType = getStdTypeForAttr(vectorType.getElementType());
if (!elementType)
return nullptr;
return VectorType::get(shape, elementType);
}
// Otherwise use a tensor.
- Type elementType = getStdTypeForAttr(type.getArrayElementType());
+ Type elementType = getStdTypeForAttr(arrayType.getElementType());
if (!elementType)
return nullptr;
return RankedTensorType::get(shape, elementType);
if (!attrType)
return nullptr;
- if (type.isIntegerTy()) {
+ if (type.isa<LLVMIntegerType>()) {
SmallVector<APInt, 8> values;
values.reserve(cd->getNumElements());
for (unsigned i = 0, e = cd->getNumElements(); i < e; ++i)
return DenseElementsAttr::get(attrType, values);
}
- if (type.isFloatTy() || type.isDoubleTy()) {
+ if (type.isa<LLVMFloatType>() || type.isa<LLVMDoubleType>()) {
SmallVector<APFloat, 8> values;
values.reserve(cd->getNumElements());
for (unsigned i = 0, e = cd->getNumElements(); i < e; ++i)
instMap.clear();
unknownInstMap.clear();
- LLVMType functionType = processType(f->getFunctionType());
+ auto functionType =
+ processType(f->getFunctionType()).dyn_cast<LLVMFunctionType>();
if (!functionType)
return failure();
// Add function arguments to the entry block.
for (auto kv : llvm::enumerate(f->args()))
- instMap[&kv.value()] = blockList[0]->addArgument(
- functionType.getFunctionParamType(kv.index()));
+ instMap[&kv.value()] =
+ blockList[0]->addArgument(functionType.getParamType(kv.index()));
for (auto bbs : llvm::zip(*f, blockList)) {
if (failed(processBasicBlock(&std::get<0>(bbs), std::get<1>(bbs))))
// NB: Attribute already verified to be boolean, so check if we can indeed
// attach the attribute to this argument, based on its type.
auto argTy = mlirArg.getType().dyn_cast<LLVM::LLVMType>();
- if (!argTy.isPointerTy())
+ if (!argTy.isa<LLVM::LLVMPointerType>())
return func.emitError(
"llvm.noalias attribute attached to LLVM non-pointer argument");
if (attr.getValue())
// NB: Attribute already verified to be int, so check if we can indeed
// attach the attribute to this argument, based on its type.
auto argTy = mlirArg.getType().dyn_cast<LLVM::LLVMType>();
- if (!argTy.isPointerTy())
+ if (!argTy.isa<LLVM::LLVMPointerType>())
return func.emitError(
"llvm.align attribute attached to LLVM non-pointer argument");
llvmArg.addAttrs(
// -----
func @load_non_llvm_type(%foo : memref<f32>) {
- // expected-error@+1 {{expected LLVM IR dialect type}}
+ // expected-error@+1 {{expected LLVM pointer type}}
llvm.load %foo : memref<f32>
}
// -----
func @store_non_llvm_type(%foo : memref<f32>, %bar : !llvm.float) {
- // expected-error@+1 {{expected LLVM IR dialect type}}
+ // expected-error@+1 {{expected LLVM pointer type}}
llvm.store %bar, %foo : memref<f32>
}
// -----
func @insertvalue_wrong_nesting() {
- // expected-error@+1 {{expected wrapped LLVM IR structure/array type}}
+ // expected-error@+1 {{expected LLVM IR structure/array type}}
llvm.insertvalue %a, %b[0,0] : !llvm.struct<(i32)>
}
// -----
func @extractvalue_wrong_nesting() {
- // expected-error@+1 {{expected wrapped LLVM IR structure/array type}}
+ // expected-error@+1 {{expected LLVM IR structure/array type}}
llvm.extractvalue %b[0,0] : !llvm.struct<(i32)>
}