struct LLVMDialectImpl;
} // namespace detail
+class LLVMType;
+
+/// Converts an MLIR LLVM dialect type to LLVM IR type. Note that this function
+/// exists exclusively for the purpose of gradual transition to the first-party
+/// modeling of LLVM types. It should not be used outside MLIR-to-LLVM
+/// translation.
+llvm::Type *convertLLVMType(LLVMType type);
+
class LLVMType : public mlir::Type::TypeBase<LLVMType, mlir::Type,
detail::LLVMTypeStorage> {
public:
static bool kindof(unsigned kind) { return kind == LLVM_TYPE; }
LLVMDialect &getDialect();
- llvm::Type *getUnderlyingType() const;
/// Utilities to identify types.
bool isBFloatTy() { return getUnderlyingType()->isBFloatTy(); }
bool isHalfTy() { return getUnderlyingType()->isHalfTy(); }
bool isFloatTy() { return getUnderlyingType()->isFloatTy(); }
bool isDoubleTy() { return getUnderlyingType()->isDoubleTy(); }
- bool isIntegerTy() { return getUnderlyingType()->isIntegerTy(); }
- bool isIntegerTy(unsigned bitwidth) {
- return getUnderlyingType()->isIntegerTy(bitwidth);
- }
+ bool isFloatingPointTy() { return getUnderlyingType()->isFloatingPointTy(); }
/// Array type utilities.
LLVMType getArrayElementType();
unsigned getArrayNumElements();
bool isArrayTy();
+ /// Integer type utilities.
+ unsigned getIntegerBitWidth() {
+ return getUnderlyingType()->getIntegerBitWidth();
+ }
+ bool isIntegerTy() { return getUnderlyingType()->isIntegerTy(); }
+ bool isIntegerTy(unsigned bitwidth) {
+ return getUnderlyingType()->isIntegerTy(bitwidth);
+ }
+
/// Vector type utilities.
LLVMType getVectorElementType();
unsigned getVectorNumElements();
+ llvm::ElementCount getVectorElementCount();
bool isVectorTy();
/// Function type utilities.
unsigned getFunctionNumParams();
LLVMType getFunctionResultType();
bool isFunctionTy();
+ bool isFunctionVarArg();
/// Pointer type utilities.
LLVMType getPointerTo(unsigned addrSpace = 0);
LLVMType getPointerElementTy();
bool isPointerTy();
+ static bool isValidPointerElementType(LLVMType type);
/// Struct type utilities.
LLVMType getStructElementType(unsigned i);
private:
friend LLVMDialect;
+ friend llvm::Type *convertLLVMType(LLVMType type);
+
+ /// Get the underlying LLVM IR type.
+ llvm::Type *getUnderlyingType() const;
+
+ /// Get the underlying LLVM IR types for the given array of types.
+ static void getUnderlyingTypes(ArrayRef<LLVMType> types,
+ SmallVectorImpl<llvm::Type *> &result);
/// Get an LLVMType with a pre-existing llvm type.
static LLVMType get(MLIRContext *context, llvm::Type *llvmType);
// or result in the operation.
def LLVM_IntrPatterns {
string operand =
- [{opInst.getOperand($0).getType()
- .cast<LLVM::LLVMType>().getUnderlyingType()}];
+ [{convertType(opInst.getOperand($0).getType().cast<LLVM::LLVMType>())}];
string result =
- [{opInst.getResult($0).getType()
- .cast<LLVM::LLVMType>().getUnderlyingType()}];
+ [{convertType(opInst.getResult($0).getType().cast<LLVM::LLVMType>())}];
}
[{
auto llvmType = resultType.dyn_cast<LLVM::LLVMType>(); (void)llvmType;
assert(llvmType && "result must be an LLVM type");
- assert(llvmType.getUnderlyingType() &&
- llvmType.getUnderlyingType()->isVoidTy() &&
- "for zero-result operands, only 'void' is accepted as result type");
+ assert(llvmType.isVoidTy() &&
+ "for zero-result operands, only 'void' is accepted as result type");
build(builder, result, operands, attributes);
}]>;
let verifier = [{
auto wrappedVectorType1 = v1().getType().cast<LLVM::LLVMType>();
auto wrappedVectorType2 = v2().getType().cast<LLVM::LLVMType>();
- if (!wrappedVectorType2.getUnderlyingType()->isVectorTy())
+ if (!wrappedVectorType2.isVectorTy())
return emitOpError("expected LLVM IR Dialect vector type for operand #2");
if (wrappedVectorType1.getVectorElementType() !=
wrappedVectorType2.getVectorElementType())
.getValue().cast<LLVMType>();
}
bool isVarArg() {
- return getType().getUnderlyingType()->isFunctionVarArg();
+ return getType().isFunctionVarArg();
}
- // Hook for OpTrait::FunctionLike, returns the number of function arguments.
+ // Hook for OpTrait::FunctionLike, returns the number of function arguments`.
// Depends on the type attribute being correct as checked by verifyType.
unsigned getNumFuncArguments();
// Vector buffer load/store intrinsics
def ROCDL_MubufLoadOp :
- ROCDL_Op<"buffer.load">,
+ ROCDL_Op<"buffer.load">,
Results<(outs LLVM_Type:$res)>,
Arguments<(ins LLVM_Type:$rsrc,
LLVM_Type:$vindex,
}
def ROCDL_MubufStoreOp :
- ROCDL_Op<"buffer.store">,
+ ROCDL_Op<"buffer.store">,
Arguments<(ins LLVM_Type:$vdata,
LLVM_Type:$rsrc,
LLVM_Type:$vindex,
LLVM_Type:$glc,
LLVM_Type:$slc)>{
string llvmBuilder = [{
- auto vdataType = op.vdata().getType().cast<LLVM::LLVMType>()
- .getUnderlyingType();
+ auto vdataType = convertType(op.vdata().getType().cast<LLVM::LLVMType>());
createIntrinsicCall(builder,
- llvm::Intrinsic::amdgcn_buffer_store, {$vdata, $rsrc, $vindex,
+ llvm::Intrinsic::amdgcn_buffer_store, {$vdata, $rsrc, $vindex,
$offset, $glc, $slc}, {vdataType});
}];
let parser = [{ return parseROCDLMubufStoreOp(parser, result); }];
- let printer = [{
+ let printer = [{
Operation *op = this->getOperation();
p << op->getName() << " " << op->getOperands()
<< " : " << vdata().getType();
llvm::IRBuilder<> &builder);
virtual LogicalResult convertOmpParallel(Operation &op,
llvm::IRBuilder<> &builder);
+
+ /// Converts the type from MLIR LLVM dialect to LLVM.
+ llvm::Type *convertType(LLVMType type);
+
static std::unique_ptr<llvm::Module> prepareLLVMModule(Operation *m);
/// A helper to look up remapped operands in the value remapping table.
/// Returns the bit width of LLVMType integer or vector.
static unsigned getLLVMTypeBitWidth(LLVM::LLVMType type) {
- return type.isVectorTy() ? type.getVectorElementType()
- .getUnderlyingType()
- ->getIntegerBitWidth()
- : type.getUnderlyingType()->getIntegerBitWidth();
+ return type.isVectorTy() ? type.getVectorElementType().getIntegerBitWidth()
+ : type.getIntegerBitWidth();
}
/// Creates `IntegerAttribute` with all bits set for given type
op, operands, typeConverter,
[&](LLVM::LLVMType llvmVectorTy, ValueRange operands) {
auto splatAttr = SplatElementsAttr::get(
- mlir::VectorType::get(
- {cast<llvm::FixedVectorType>(llvmVectorTy.getUnderlyingType())
- ->getNumElements()},
- floatType),
+ mlir::VectorType::get({llvmVectorTy.getVectorNumElements()},
+ floatType),
floatOne);
auto one =
rewriter.create<LLVM::ConstantOp>(loc, llvmVectorTy, splatAttr);
this->typeConverter.convertType(indexCastOp.getResult().getType())
.cast<LLVM::LLVMType>();
auto sourceType = transformed.in().getType().cast<LLVM::LLVMType>();
- unsigned targetBits = targetType.getUnderlyingType()->getIntegerBitWidth();
- unsigned sourceBits = sourceType.getUnderlyingType()->getIntegerBitWidth();
+ unsigned targetBits = targetType.getIntegerBitWidth();
+ unsigned sourceBits = sourceType.getIntegerBitWidth();
if (targetBits == sourceBits)
rewriter.replaceOp(op, transformed.in());
auto dataLayout = typeConverter.getDialect()->getLLVMModule().getDataLayout();
align = dataLayout.getPrefTypeAlignment(
- elementTy.cast<LLVM::LLVMType>().getUnderlyingType());
+ LLVM::convertLLVMType(elementTy.cast<LLVM::LLVMType>()));
return success();
}
auto argType = type.dyn_cast<LLVM::LLVMType>();
if (!argType)
return parser.emitError(trailingTypeLoc, "expected LLVM IR dialect type");
- if (argType.getUnderlyingType()->isVectorTy())
- resultType = LLVMType::getVectorTy(
- resultType,
- llvm::cast<llvm::FixedVectorType>(argType.getUnderlyingType())
- ->getNumElements());
+ if (argType.isVectorTy())
+ resultType =
+ LLVMType::getVectorTy(resultType, argType.getVectorNumElements());
result.addTypes({resultType});
return success();
if (!llvmTy)
return parser.emitError(trailingTypeLoc, "expected LLVM IR dialect type"),
nullptr;
- if (!llvmTy.getUnderlyingType()->isPointerTy())
+ if (!llvmTy.isPointerTy())
return parser.emitError(trailingTypeLoc, "expected LLVM pointer type"),
nullptr;
return llvmTy.getPointerElementTy();
parser.resolveOperand(position, positionType, result.operands))
return failure();
auto wrappedVectorType = type.dyn_cast<LLVM::LLVMType>();
- if (!wrappedVectorType ||
- !wrappedVectorType.getUnderlyingType()->isVectorTy())
+ if (!wrappedVectorType || !wrappedVectorType.isVectorTy())
return parser.emitError(
loc, "expected LLVM IR dialect vector type for operand #1");
result.addTypes(wrappedVectorType.getVectorElementType());
"expected an array of integer literals"),
nullptr;
int position = positionElementAttr.getInt();
- auto *llvmContainerType = wrappedContainerType.getUnderlyingType();
- if (llvmContainerType->isArrayTy()) {
+ if (wrappedContainerType.isArrayTy()) {
if (position < 0 || static_cast<unsigned>(position) >=
- llvmContainerType->getArrayNumElements())
+ wrappedContainerType.getArrayNumElements())
return parser.emitError(attributeLoc, "position out of bounds"),
nullptr;
wrappedContainerType = wrappedContainerType.getArrayElementType();
- } else if (llvmContainerType->isStructTy()) {
+ } else if (wrappedContainerType.isStructTy()) {
if (position < 0 || static_cast<unsigned>(position) >=
- llvmContainerType->getStructNumElements())
+ wrappedContainerType.getStructNumElements())
return parser.emitError(attributeLoc, "position out of bounds"),
nullptr;
wrappedContainerType =
return failure();
auto wrappedVectorType = vectorType.dyn_cast<LLVM::LLVMType>();
- if (!wrappedVectorType ||
- !wrappedVectorType.getUnderlyingType()->isVectorTy())
+ if (!wrappedVectorType || !wrappedVectorType.isVectorTy())
return parser.emitError(
loc, "expected LLVM IR dialect vector type for operand #1");
auto valueType = wrappedVectorType.getVectorElementType();
}
static LogicalResult verify(GlobalOp op) {
- if (!llvm::PointerType::isValidElementType(op.getType().getUnderlyingType()))
+ if (!LLVMType::isValidPointerElementType(op.getType()))
return op.emitOpError(
"expects type to be a valid element type for an LLVM pointer");
if (op.getParentOp() && !satisfiesLLVMModule(op.getParentOp()))
if (auto strAttr = op.getValueOrNull().dyn_cast_or_null<StringAttr>()) {
auto type = op.getType();
- if (!type.getUnderlyingType()->isArrayTy() ||
- !type.getArrayElementType().getUnderlyingType()->isIntegerTy(8) ||
+ if (!type.isArrayTy() || !type.getArrayElementType().isIntegerTy(8) ||
type.getArrayNumElements() != strAttr.getValue().size())
return op.emitOpError(
"requires an i8 array type of the length equal to that of the string "
parser.resolveOperand(v2, typeV2, result.operands))
return failure();
auto wrappedContainerType1 = typeV1.dyn_cast<LLVM::LLVMType>();
- if (!wrappedContainerType1 ||
- !wrappedContainerType1.getUnderlyingType()->isVectorTy())
+ if (!wrappedContainerType1 || !wrappedContainerType1.isVectorTy())
return parser.emitError(
loc, "expected LLVM IR dialect vector type for operand #1");
auto vType = LLVMType::getVectorTy(
if (argAttrs.empty())
return;
- unsigned numInputs = type.getUnderlyingType()->getFunctionNumParams();
+ unsigned numInputs = type.getFunctionNumParams();
assert(numInputs == argAttrs.size() &&
"expected as many argument attribute lists as arguments");
SmallString<8> argAttrName;
// getNumArguments hook not failing.
LogicalResult LLVMFuncOp::verifyType() {
auto llvmType = getTypeAttr().getValue().dyn_cast_or_null<LLVMType>();
- if (!llvmType || !llvmType.getUnderlyingType()->isFunctionTy())
+ if (!llvmType || !llvmType.isFunctionTy())
return emitOpError("requires '" + getTypeAttrName() +
"' attribute of wrapped LLVM function type");
// Hook for OpTrait::FunctionLike, returns the number of function arguments.
// Depends on the type attribute being correct as checked by verifyType
unsigned LLVMFuncOp::getNumFuncArguments() {
- return getType().getUnderlyingType()->getFunctionNumParams();
+ return getType().getFunctionNumParams();
}
// Hook for OpTrait::FunctionLike, returns the number of function results.
if (op.isVarArg())
return op.emitOpError("only external functions can be variadic");
- auto *funcType = cast<llvm::FunctionType>(op.getType().getUnderlyingType());
- unsigned numArguments = funcType->getNumParams();
+ unsigned numArguments = op.getType().getFunctionNumParams();
Block &entryBlock = op.front();
for (unsigned i = 0; i < numArguments; ++i) {
Type argType = entryBlock.getArgument(i).getType();
if (!argLLVMType)
return op.emitOpError("entry block argument #")
<< i << " is not of LLVM type";
- if (funcType->getParamType(i) != argLLVMType.getUnderlyingType())
+ if (op.getType().getFunctionParamType(i) != argLLVMType)
return op.emitOpError("the type of entry block argument #")
<< i << " does not match the function signature";
}
return op.emitOpError(
"expected LLVM IR result type to match type for operand #1");
if (op.bin_op() == AtomicBinOp::fadd || op.bin_op() == AtomicBinOp::fsub) {
- if (!valType.getUnderlyingType()->isFloatingPointTy())
+ if (!valType.isFloatingPointTy())
return op.emitOpError("expected LLVM IR floating point type");
} else if (op.bin_op() == AtomicBinOp::xchg) {
if (!valType.isIntegerTy(8) && !valType.isIntegerTy(16) &&
return getImpl()->underlyingType;
}
+void LLVMType::getUnderlyingTypes(ArrayRef<LLVMType> types,
+ SmallVectorImpl<llvm::Type *> &result) {
+ result.reserve(result.size() + types.size());
+ for (LLVMType ty : types)
+ result.push_back(ty.getUnderlyingType());
+}
+
/// Array type utilities.
LLVMType LLVMType::getArrayElementType() {
return get(getContext(), getUnderlyingType()->getArrayElementType());
return llvm::cast<llvm::FixedVectorType>(getUnderlyingType())
->getNumElements();
}
+llvm::ElementCount LLVMType::getVectorElementCount() {
+ return llvm::cast<llvm::VectorType>(getUnderlyingType())->getElementCount();
+}
bool LLVMType::isVectorTy() { return getUnderlyingType()->isVectorTy(); }
/// Function type utilities.
llvm::cast<llvm::FunctionType>(getUnderlyingType())->getReturnType());
}
bool LLVMType::isFunctionTy() { return getUnderlyingType()->isFunctionTy(); }
+bool LLVMType::isFunctionVarArg() {
+ return getUnderlyingType()->isFunctionVarArg();
+}
/// Pointer type utilities.
LLVMType LLVMType::getPointerTo(unsigned addrSpace) {
return get(getContext(), getUnderlyingType()->getPointerElementType());
}
bool LLVMType::isPointerTy() { return getUnderlyingType()->isPointerTy(); }
+bool LLVMType::isValidPointerElementType(LLVMType type) {
+ return llvm::PointerType::isValidElementType(type.getUnderlyingType());
+}
/// Struct type utilities.
LLVMType LLVMType::getStructElementType(unsigned i) {
isPacked);
});
}
-inline static SmallVector<llvm::Type *, 8>
-toUnderlyingTypes(ArrayRef<LLVMType> elements) {
- SmallVector<llvm::Type *, 8> llvmElements;
- for (auto elt : elements)
- llvmElements.push_back(elt.getUnderlyingType());
- return llvmElements;
-}
LLVMType LLVMType::createStructTy(LLVMDialect *dialect,
ArrayRef<LLVMType> elements,
Optional<StringRef> name, bool isPacked) {
StringRef sr = name.hasValue() ? *name : "";
- SmallVector<llvm::Type *, 8> llvmElements(toUnderlyingTypes(elements));
+ SmallVector<llvm::Type *, 8> llvmElements;
+ getUnderlyingTypes(elements, llvmElements);
return getLocked(dialect, [=] {
auto *rv = llvm::StructType::create(dialect->getLLVMContext(), sr);
if (!llvmElements.empty())
ArrayRef<LLVMType> elements, bool isPacked) {
llvm::StructType *st =
llvm::cast<llvm::StructType>(structType.getUnderlyingType());
- SmallVector<llvm::Type *, 8> llvmElements(toUnderlyingTypes(elements));
+ SmallVector<llvm::Type *, 8> llvmElements;
+ getUnderlyingTypes(elements, llvmElements);
return getLocked(&structType.getDialect(), [=] {
st->setBody(llvmElements, isPacked);
return st;
bool LLVMType::isVoidTy() { return getUnderlyingType()->isVoidTy(); }
+llvm::Type *mlir::LLVM::convertLLVMType(LLVMType type) {
+ return type.getUnderlyingType();
+}
+
//===----------------------------------------------------------------------===//
// Utility functions.
//===----------------------------------------------------------------------===//
return nullptr;
if (type.isIntegerTy())
- return b.getIntegerType(type.getUnderlyingType()->getIntegerBitWidth());
+ return b.getIntegerType(type.getIntegerBitWidth());
- if (type.getUnderlyingType()->isFloatTy())
+ if (type.isFloatTy())
return b.getF32Type();
- if (type.getUnderlyingType()->isDoubleTy())
+ if (type.isDoubleTy())
return b.getF64Type();
// LLVM vectors can only contain scalars.
if (type.isVectorTy()) {
- auto numElements = llvm::cast<llvm::VectorType>(type.getUnderlyingType())
- ->getElementCount();
+ auto numElements = type.getVectorElementCount();
if (numElements.Scalable) {
emitError(unknownLoc) << "scalable vectors not supported";
return nullptr;
// attribute type.
if (type.getArrayElementType().isVectorTy()) {
LLVMType vectorType = type.getArrayElementType();
- auto numElements =
- llvm::cast<llvm::VectorType>(vectorType.getUnderlyingType())
- ->getElementCount();
+ auto numElements = vectorType.getVectorElementCount();
if (numElements.Scalable) {
emitError(unknownLoc) << "scalable vectors not supported";
return nullptr;
}
if (auto lpOp = dyn_cast<LLVM::LandingpadOp>(opInst)) {
- llvm::Type *ty = lpOp.getType().dyn_cast<LLVMType>().getUnderlyingType();
+ llvm::Type *ty = convertType(lpOp.getType().cast<LLVMType>());
llvm::LandingPadInst *lpi =
builder.CreateLandingPad(ty, lpOp.getNumOperands());
if (!wrappedType)
return emitError(bb.front().getLoc(),
"block argument does not have an LLVM type");
- llvm::Type *type = wrappedType.getUnderlyingType();
+ llvm::Type *type = convertType(wrappedType);
llvm::PHINode *phi = builder.CreatePHI(type, numPredecessors);
valueMapping[arg] = phi;
}
llvm::sys::SmartScopedLock<true> scopedLock(
llvmDialect->getLLVMContextMutex());
for (auto op : getModuleBody(mlirModule).getOps<LLVM::GlobalOp>()) {
- llvm::Type *type = op.getType().getUnderlyingType();
+ llvm::Type *type = convertType(op.getType());
llvm::Constant *cst = llvm::UndefValue::get(type);
if (op.getValueOrNull()) {
// String attributes are treated separately because they cannot appear as
// NB: Attribute already verified to be boolean, so check if we can indeed
// attach the attribute to this argument, based on its type.
auto argTy = mlirArg.getType().dyn_cast<LLVM::LLVMType>();
- if (!argTy.getUnderlyingType()->isPointerTy())
+ if (!argTy.isPointerTy())
return func.emitError(
"llvm.noalias attribute attached to LLVM non-pointer argument");
if (attr.getValue())
// NB: Attribute already verified to be int, so check if we can indeed
// attach the attribute to this argument, based on its type.
auto argTy = mlirArg.getType().dyn_cast<LLVM::LLVMType>();
- if (!argTy.getUnderlyingType()->isPointerTy())
+ if (!argTy.isPointerTy())
return func.emitError(
"llvm.align attribute attached to LLVM non-pointer argument");
llvmArg.addAttrs(
for (auto function : getModuleBody(mlirModule).getOps<LLVMFuncOp>()) {
llvm::FunctionCallee llvmFuncCst = llvmModule->getOrInsertFunction(
function.getName(),
- cast<llvm::FunctionType>(function.getType().getUnderlyingType()));
+ cast<llvm::FunctionType>(convertType(function.getType())));
llvm::Function *llvmFunc = cast<llvm::Function>(llvmFuncCst.getCallee());
llvmFunc->setLinkage(convertLinkageToLLVM(function.linkage()));
functionMapping[function.getName()] = llvmFunc;
return success();
}
+llvm::Type *ModuleTranslation::convertType(LLVMType type) {
+ return LLVM::convertLLVMType(type);
+}
+
/// A helper to look up remapped operands in the value remapping table.`
SmallVector<llvm::Value *, 8>
ModuleTranslation::lookupValues(ValueRange values) {
} else if (isResultName(op, name)) {
bs << formatv("valueMapping[op.{0}()]", name);
} else if (name == "_resultType") {
- bs << "op.getResult().getType().cast<LLVM::LLVMType>()."
- "getUnderlyingType()";
+ bs << "convertType(op.getResult().getType().cast<LLVM::LLVMType>())";
} else if (name == "_hasResult") {
bs << "opInst.getNumResults() == 1";
} else if (name == "_location") {