From: Alex Zinenko Date: Thu, 23 Jul 2020 08:32:12 +0000 (+0200) Subject: [mlir] LLVMType: make getUnderlyingType private X-Git-Tag: llvmorg-13-init~16461 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=aec38c619dfa1c41b6b0f6c52a31d221ac108b6b;p=platform%2Fupstream%2Fllvm.git [mlir] LLVMType: make getUnderlyingType private The current modeling of LLVM IR types in MLIR is based on the LLVMType class that wraps a raw `llvm::Type *` and delegates uniquing, printing and parsing to LLVM itself. This is model makes thread-safe type manipulation hard and is being progressively replaced with a cleaner MLIR model that replicates the type system. In the new model, LLVMType will no longer have an underlying LLVM IR type. Restrict access to this type in the current model in preparation for the change. Reviewed By: nicolasvasilache Differential Revision: https://reviews.llvm.org/D84389 --- diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h b/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h index 078cb1cfa4e5..52acfbfa8e50 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h @@ -47,6 +47,14 @@ struct LLVMTypeStorage; struct LLVMDialectImpl; } // namespace detail +class LLVMType; + +/// Converts an MLIR LLVM dialect type to LLVM IR type. Note that this function +/// exists exclusively for the purpose of gradual transition to the first-party +/// modeling of LLVM types. It should not be used outside MLIR-to-LLVM +/// translation. +llvm::Type *convertLLVMType(LLVMType type); + class LLVMType : public mlir::Type::TypeBase { public: @@ -59,26 +67,32 @@ public: static bool kindof(unsigned kind) { return kind == LLVM_TYPE; } LLVMDialect &getDialect(); - llvm::Type *getUnderlyingType() const; /// Utilities to identify types. bool isBFloatTy() { return getUnderlyingType()->isBFloatTy(); } bool isHalfTy() { return getUnderlyingType()->isHalfTy(); } bool isFloatTy() { return getUnderlyingType()->isFloatTy(); } bool isDoubleTy() { return getUnderlyingType()->isDoubleTy(); } - bool isIntegerTy() { return getUnderlyingType()->isIntegerTy(); } - bool isIntegerTy(unsigned bitwidth) { - return getUnderlyingType()->isIntegerTy(bitwidth); - } + bool isFloatingPointTy() { return getUnderlyingType()->isFloatingPointTy(); } /// Array type utilities. LLVMType getArrayElementType(); unsigned getArrayNumElements(); bool isArrayTy(); + /// Integer type utilities. + unsigned getIntegerBitWidth() { + return getUnderlyingType()->getIntegerBitWidth(); + } + bool isIntegerTy() { return getUnderlyingType()->isIntegerTy(); } + bool isIntegerTy(unsigned bitwidth) { + return getUnderlyingType()->isIntegerTy(bitwidth); + } + /// Vector type utilities. LLVMType getVectorElementType(); unsigned getVectorNumElements(); + llvm::ElementCount getVectorElementCount(); bool isVectorTy(); /// Function type utilities. @@ -86,11 +100,13 @@ public: unsigned getFunctionNumParams(); LLVMType getFunctionResultType(); bool isFunctionTy(); + bool isFunctionVarArg(); /// Pointer type utilities. LLVMType getPointerTo(unsigned addrSpace = 0); LLVMType getPointerElementTy(); bool isPointerTy(); + static bool isValidPointerElementType(LLVMType type); /// Struct type utilities. LLVMType getStructElementType(unsigned i); @@ -194,6 +210,14 @@ public: private: friend LLVMDialect; + friend llvm::Type *convertLLVMType(LLVMType type); + + /// Get the underlying LLVM IR type. + llvm::Type *getUnderlyingType() const; + + /// Get the underlying LLVM IR types for the given array of types. + static void getUnderlyingTypes(ArrayRef types, + SmallVectorImpl &result); /// Get an LLVMType with a pre-existing llvm type. static LLVMType get(MLIRContext *context, llvm::Type *llvmType); diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td index d88b372dbf43..4d99bf265c65 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td @@ -134,11 +134,9 @@ class ListIntSubst values> { // or result in the operation. def LLVM_IntrPatterns { string operand = - [{opInst.getOperand($0).getType() - .cast().getUnderlyingType()}]; + [{convertType(opInst.getOperand($0).getType().cast())}]; string result = - [{opInst.getResult($0).getType() - .cast().getUnderlyingType()}]; + [{convertType(opInst.getResult($0).getType().cast())}]; } diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td index 29d7fd930030..4da90575524b 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td @@ -61,9 +61,8 @@ def LLVM_VoidResultTypeOpBuilder : OpBuilder< [{ auto llvmType = resultType.dyn_cast(); (void)llvmType; assert(llvmType && "result must be an LLVM type"); - assert(llvmType.getUnderlyingType() && - llvmType.getUnderlyingType()->isVoidTy() && - "for zero-result operands, only 'void' is accepted as result type"); + assert(llvmType.isVoidTy() && + "for zero-result operands, only 'void' is accepted as result type"); build(builder, result, operands, attributes); }]>; @@ -477,7 +476,7 @@ def LLVM_ShuffleVectorOp let verifier = [{ auto wrappedVectorType1 = v1().getType().cast(); auto wrappedVectorType2 = v2().getType().cast(); - if (!wrappedVectorType2.getUnderlyingType()->isVectorTy()) + if (!wrappedVectorType2.isVectorTy()) return emitOpError("expected LLVM IR Dialect vector type for operand #2"); if (wrappedVectorType1.getVectorElementType() != wrappedVectorType2.getVectorElementType()) @@ -765,10 +764,10 @@ def LLVM_LLVMFuncOp .getValue().cast(); } bool isVarArg() { - return getType().getUnderlyingType()->isFunctionVarArg(); + return getType().isFunctionVarArg(); } - // Hook for OpTrait::FunctionLike, returns the number of function arguments. + // Hook for OpTrait::FunctionLike, returns the number of function arguments`. // Depends on the type attribute being correct as checked by verifyType. unsigned getNumFuncArguments(); diff --git a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td index 786b9ef217bd..0cd11690daa8 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td @@ -139,7 +139,7 @@ def ROCDL_mfma_i32_16x16x16i8 : ROCDL_Mfma_IntrOp<"mfma.i32.16x16x16i8">; // Vector buffer load/store intrinsics def ROCDL_MubufLoadOp : - ROCDL_Op<"buffer.load">, + ROCDL_Op<"buffer.load">, Results<(outs LLVM_Type:$res)>, Arguments<(ins LLVM_Type:$rsrc, LLVM_Type:$vindex, @@ -160,7 +160,7 @@ def ROCDL_MubufLoadOp : } def ROCDL_MubufStoreOp : - ROCDL_Op<"buffer.store">, + ROCDL_Op<"buffer.store">, Arguments<(ins LLVM_Type:$vdata, LLVM_Type:$rsrc, LLVM_Type:$vindex, @@ -168,14 +168,13 @@ def ROCDL_MubufStoreOp : LLVM_Type:$glc, LLVM_Type:$slc)>{ string llvmBuilder = [{ - auto vdataType = op.vdata().getType().cast() - .getUnderlyingType(); + auto vdataType = convertType(op.vdata().getType().cast()); createIntrinsicCall(builder, - llvm::Intrinsic::amdgcn_buffer_store, {$vdata, $rsrc, $vindex, + llvm::Intrinsic::amdgcn_buffer_store, {$vdata, $rsrc, $vindex, $offset, $glc, $slc}, {vdataType}); }]; let parser = [{ return parseROCDLMubufStoreOp(parser, result); }]; - let printer = [{ + let printer = [{ Operation *op = this->getOperation(); p << op->getName() << " " << op->getOperands() << " : " << vdata().getType(); diff --git a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h index e44ae976e0dd..61f8f9fce64c 100644 --- a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h +++ b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h @@ -89,6 +89,10 @@ protected: llvm::IRBuilder<> &builder); virtual LogicalResult convertOmpParallel(Operation &op, llvm::IRBuilder<> &builder); + + /// Converts the type from MLIR LLVM dialect to LLVM. + llvm::Type *convertType(LLVMType type); + static std::unique_ptr prepareLLVMModule(Operation *m); /// A helper to look up remapped operands in the value remapping table. diff --git a/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp b/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp index 25a3ac07d5f4..fd0e96b79d2b 100644 --- a/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp +++ b/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp @@ -64,10 +64,8 @@ static unsigned getBitWidth(Type type) { /// Returns the bit width of LLVMType integer or vector. static unsigned getLLVMTypeBitWidth(LLVM::LLVMType type) { - return type.isVectorTy() ? type.getVectorElementType() - .getUnderlyingType() - ->getIntegerBitWidth() - : type.getUnderlyingType()->getIntegerBitWidth(); + return type.isVectorTy() ? type.getVectorElementType().getIntegerBitWidth() + : type.getIntegerBitWidth(); } /// Creates `IntegerAttribute` with all bits set for given type diff --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp index 1e6fa6a8754b..0d154796f049 100644 --- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp +++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp @@ -2248,10 +2248,8 @@ struct RsqrtOpLowering : public ConvertOpToLLVMPattern { op, operands, typeConverter, [&](LLVM::LLVMType llvmVectorTy, ValueRange operands) { auto splatAttr = SplatElementsAttr::get( - mlir::VectorType::get( - {cast(llvmVectorTy.getUnderlyingType()) - ->getNumElements()}, - floatType), + mlir::VectorType::get({llvmVectorTy.getVectorNumElements()}, + floatType), floatOne); auto one = rewriter.create(loc, llvmVectorTy, splatAttr); @@ -2511,8 +2509,8 @@ struct IndexCastOpLowering : public ConvertOpToLLVMPattern { this->typeConverter.convertType(indexCastOp.getResult().getType()) .cast(); auto sourceType = transformed.in().getType().cast(); - unsigned targetBits = targetType.getUnderlyingType()->getIntegerBitWidth(); - unsigned sourceBits = sourceType.getUnderlyingType()->getIntegerBitWidth(); + unsigned targetBits = targetType.getIntegerBitWidth(); + unsigned sourceBits = sourceType.getIntegerBitWidth(); if (targetBits == sourceBits) rewriter.replaceOp(op, transformed.in()); diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp index 5dbc8394b03a..4fa7b573f84e 100644 --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp @@ -127,7 +127,7 @@ LogicalResult getMemRefAlignment(LLVMTypeConverter &typeConverter, T op, auto dataLayout = typeConverter.getDialect()->getLLVMModule().getDataLayout(); align = dataLayout.getPrefTypeAlignment( - elementTy.cast().getUnderlyingType()); + LLVM::convertLLVMType(elementTy.cast())); return success(); } diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp index cf7a5d926528..17848c6bf3ee 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp @@ -105,11 +105,9 @@ static ParseResult parseCmpOp(OpAsmParser &parser, OperationState &result) { auto argType = type.dyn_cast(); if (!argType) return parser.emitError(trailingTypeLoc, "expected LLVM IR dialect type"); - if (argType.getUnderlyingType()->isVectorTy()) - resultType = LLVMType::getVectorTy( - resultType, - llvm::cast(argType.getUnderlyingType()) - ->getNumElements()); + if (argType.isVectorTy()) + resultType = + LLVMType::getVectorTy(resultType, argType.getVectorNumElements()); result.addTypes({resultType}); return success(); @@ -214,7 +212,7 @@ static Type getLoadStoreElementType(OpAsmParser &parser, Type type, if (!llvmTy) return parser.emitError(trailingTypeLoc, "expected LLVM IR dialect type"), nullptr; - if (!llvmTy.getUnderlyingType()->isPointerTy()) + if (!llvmTy.isPointerTy()) return parser.emitError(trailingTypeLoc, "expected LLVM pointer type"), nullptr; return llvmTy.getPointerElementTy(); @@ -683,8 +681,7 @@ static ParseResult parseExtractElementOp(OpAsmParser &parser, parser.resolveOperand(position, positionType, result.operands)) return failure(); auto wrappedVectorType = type.dyn_cast(); - if (!wrappedVectorType || - !wrappedVectorType.getUnderlyingType()->isVectorTy()) + if (!wrappedVectorType || !wrappedVectorType.isVectorTy()) return parser.emitError( loc, "expected LLVM IR dialect vector type for operand #1"); result.addTypes(wrappedVectorType.getVectorElementType()); @@ -725,16 +722,15 @@ static LLVM::LLVMType getInsertExtractValueElementType(OpAsmParser &parser, "expected an array of integer literals"), nullptr; int position = positionElementAttr.getInt(); - auto *llvmContainerType = wrappedContainerType.getUnderlyingType(); - if (llvmContainerType->isArrayTy()) { + if (wrappedContainerType.isArrayTy()) { if (position < 0 || static_cast(position) >= - llvmContainerType->getArrayNumElements()) + wrappedContainerType.getArrayNumElements()) return parser.emitError(attributeLoc, "position out of bounds"), nullptr; wrappedContainerType = wrappedContainerType.getArrayElementType(); - } else if (llvmContainerType->isStructTy()) { + } else if (wrappedContainerType.isStructTy()) { if (position < 0 || static_cast(position) >= - llvmContainerType->getStructNumElements()) + wrappedContainerType.getStructNumElements()) return parser.emitError(attributeLoc, "position out of bounds"), nullptr; wrappedContainerType = @@ -803,8 +799,7 @@ static ParseResult parseInsertElementOp(OpAsmParser &parser, return failure(); auto wrappedVectorType = vectorType.dyn_cast(); - if (!wrappedVectorType || - !wrappedVectorType.getUnderlyingType()->isVectorTy()) + if (!wrappedVectorType || !wrappedVectorType.isVectorTy()) return parser.emitError( loc, "expected LLVM IR dialect vector type for operand #1"); auto valueType = wrappedVectorType.getVectorElementType(); @@ -1125,7 +1120,7 @@ static ParseResult parseGlobalOp(OpAsmParser &parser, OperationState &result) { } static LogicalResult verify(GlobalOp op) { - if (!llvm::PointerType::isValidElementType(op.getType().getUnderlyingType())) + if (!LLVMType::isValidPointerElementType(op.getType())) return op.emitOpError( "expects type to be a valid element type for an LLVM pointer"); if (op.getParentOp() && !satisfiesLLVMModule(op.getParentOp())) @@ -1133,8 +1128,7 @@ static LogicalResult verify(GlobalOp op) { if (auto strAttr = op.getValueOrNull().dyn_cast_or_null()) { auto type = op.getType(); - if (!type.getUnderlyingType()->isArrayTy() || - !type.getArrayElementType().getUnderlyingType()->isIntegerTy(8) || + if (!type.isArrayTy() || !type.getArrayElementType().isIntegerTy(8) || type.getArrayNumElements() != strAttr.getValue().size()) return op.emitOpError( "requires an i8 array type of the length equal to that of the string " @@ -1197,8 +1191,7 @@ static ParseResult parseShuffleVectorOp(OpAsmParser &parser, parser.resolveOperand(v2, typeV2, result.operands)) return failure(); auto wrappedContainerType1 = typeV1.dyn_cast(); - if (!wrappedContainerType1 || - !wrappedContainerType1.getUnderlyingType()->isVectorTy()) + if (!wrappedContainerType1 || !wrappedContainerType1.isVectorTy()) return parser.emitError( loc, "expected LLVM IR dialect vector type for operand #1"); auto vType = LLVMType::getVectorTy( @@ -1239,7 +1232,7 @@ void LLVMFuncOp::build(OpBuilder &builder, OperationState &result, if (argAttrs.empty()) return; - unsigned numInputs = type.getUnderlyingType()->getFunctionNumParams(); + unsigned numInputs = type.getFunctionNumParams(); assert(numInputs == argAttrs.size() && "expected as many argument attribute lists as arguments"); SmallString<8> argAttrName; @@ -1374,7 +1367,7 @@ static void printLLVMFuncOp(OpAsmPrinter &p, LLVMFuncOp op) { // getNumArguments hook not failing. LogicalResult LLVMFuncOp::verifyType() { auto llvmType = getTypeAttr().getValue().dyn_cast_or_null(); - if (!llvmType || !llvmType.getUnderlyingType()->isFunctionTy()) + if (!llvmType || !llvmType.isFunctionTy()) return emitOpError("requires '" + getTypeAttrName() + "' attribute of wrapped LLVM function type"); @@ -1384,7 +1377,7 @@ LogicalResult LLVMFuncOp::verifyType() { // Hook for OpTrait::FunctionLike, returns the number of function arguments. // Depends on the type attribute being correct as checked by verifyType unsigned LLVMFuncOp::getNumFuncArguments() { - return getType().getUnderlyingType()->getFunctionNumParams(); + return getType().getFunctionNumParams(); } // Hook for OpTrait::FunctionLike, returns the number of function results. @@ -1424,8 +1417,7 @@ static LogicalResult verify(LLVMFuncOp op) { if (op.isVarArg()) return op.emitOpError("only external functions can be variadic"); - auto *funcType = cast(op.getType().getUnderlyingType()); - unsigned numArguments = funcType->getNumParams(); + unsigned numArguments = op.getType().getFunctionNumParams(); Block &entryBlock = op.front(); for (unsigned i = 0; i < numArguments; ++i) { Type argType = entryBlock.getArgument(i).getType(); @@ -1433,7 +1425,7 @@ static LogicalResult verify(LLVMFuncOp op) { if (!argLLVMType) return op.emitOpError("entry block argument #") << i << " is not of LLVM type"; - if (funcType->getParamType(i) != argLLVMType.getUnderlyingType()) + if (op.getType().getFunctionParamType(i) != argLLVMType) return op.emitOpError("the type of entry block argument #") << i << " does not match the function signature"; } @@ -1566,7 +1558,7 @@ static LogicalResult verify(AtomicRMWOp op) { return op.emitOpError( "expected LLVM IR result type to match type for operand #1"); if (op.bin_op() == AtomicBinOp::fadd || op.bin_op() == AtomicBinOp::fsub) { - if (!valType.getUnderlyingType()->isFloatingPointTy()) + if (!valType.isFloatingPointTy()) return op.emitOpError("expected LLVM IR floating point type"); } else if (op.bin_op() == AtomicBinOp::xchg) { if (!valType.isIntegerTy(8) && !valType.isIntegerTy(16) && @@ -1842,6 +1834,13 @@ llvm::Type *LLVMType::getUnderlyingType() const { return getImpl()->underlyingType; } +void LLVMType::getUnderlyingTypes(ArrayRef types, + SmallVectorImpl &result) { + result.reserve(result.size() + types.size()); + for (LLVMType ty : types) + result.push_back(ty.getUnderlyingType()); +} + /// Array type utilities. LLVMType LLVMType::getArrayElementType() { return get(getContext(), getUnderlyingType()->getArrayElementType()); @@ -1861,6 +1860,9 @@ unsigned LLVMType::getVectorNumElements() { return llvm::cast(getUnderlyingType()) ->getNumElements(); } +llvm::ElementCount LLVMType::getVectorElementCount() { + return llvm::cast(getUnderlyingType())->getElementCount(); +} bool LLVMType::isVectorTy() { return getUnderlyingType()->isVectorTy(); } /// Function type utilities. @@ -1876,6 +1878,9 @@ LLVMType LLVMType::getFunctionResultType() { llvm::cast(getUnderlyingType())->getReturnType()); } bool LLVMType::isFunctionTy() { return getUnderlyingType()->isFunctionTy(); } +bool LLVMType::isFunctionVarArg() { + return getUnderlyingType()->isFunctionVarArg(); +} /// Pointer type utilities. LLVMType LLVMType::getPointerTo(unsigned addrSpace) { @@ -1888,6 +1893,9 @@ LLVMType LLVMType::getPointerElementTy() { return get(getContext(), getUnderlyingType()->getPointerElementType()); } bool LLVMType::isPointerTy() { return getUnderlyingType()->isPointerTy(); } +bool LLVMType::isValidPointerElementType(LLVMType type) { + return llvm::PointerType::isValidElementType(type.getUnderlyingType()); +} /// Struct type utilities. LLVMType LLVMType::getStructElementType(unsigned i) { @@ -1974,18 +1982,12 @@ LLVMType LLVMType::getStructTy(LLVMDialect *dialect, isPacked); }); } -inline static SmallVector -toUnderlyingTypes(ArrayRef elements) { - SmallVector llvmElements; - for (auto elt : elements) - llvmElements.push_back(elt.getUnderlyingType()); - return llvmElements; -} LLVMType LLVMType::createStructTy(LLVMDialect *dialect, ArrayRef elements, Optional name, bool isPacked) { StringRef sr = name.hasValue() ? *name : ""; - SmallVector llvmElements(toUnderlyingTypes(elements)); + SmallVector llvmElements; + getUnderlyingTypes(elements, llvmElements); return getLocked(dialect, [=] { auto *rv = llvm::StructType::create(dialect->getLLVMContext(), sr); if (!llvmElements.empty()) @@ -1997,7 +1999,8 @@ LLVMType LLVMType::setStructTyBody(LLVMType structType, ArrayRef elements, bool isPacked) { llvm::StructType *st = llvm::cast(structType.getUnderlyingType()); - SmallVector llvmElements(toUnderlyingTypes(elements)); + SmallVector llvmElements; + getUnderlyingTypes(elements, llvmElements); return getLocked(&structType.getDialect(), [=] { st->setBody(llvmElements, isPacked); return st; @@ -2017,6 +2020,10 @@ LLVMType LLVMType::getVoidTy(LLVMDialect *dialect) { bool LLVMType::isVoidTy() { return getUnderlyingType()->isVoidTy(); } +llvm::Type *mlir::LLVM::convertLLVMType(LLVMType type) { + return type.getUnderlyingType(); +} + //===----------------------------------------------------------------------===// // Utility functions. //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp b/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp index 9754c614efdf..77897d65e1a5 100644 --- a/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp +++ b/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp @@ -234,18 +234,17 @@ Type Importer::getStdTypeForAttr(LLVMType type) { return nullptr; if (type.isIntegerTy()) - return b.getIntegerType(type.getUnderlyingType()->getIntegerBitWidth()); + return b.getIntegerType(type.getIntegerBitWidth()); - if (type.getUnderlyingType()->isFloatTy()) + if (type.isFloatTy()) return b.getF32Type(); - if (type.getUnderlyingType()->isDoubleTy()) + if (type.isDoubleTy()) return b.getF64Type(); // LLVM vectors can only contain scalars. if (type.isVectorTy()) { - auto numElements = llvm::cast(type.getUnderlyingType()) - ->getElementCount(); + auto numElements = type.getVectorElementCount(); if (numElements.Scalable) { emitError(unknownLoc) << "scalable vectors not supported"; return nullptr; @@ -270,9 +269,7 @@ Type Importer::getStdTypeForAttr(LLVMType type) { // attribute type. if (type.getArrayElementType().isVectorTy()) { LLVMType vectorType = type.getArrayElementType(); - auto numElements = - llvm::cast(vectorType.getUnderlyingType()) - ->getElementCount(); + auto numElements = vectorType.getVectorElementCount(); if (numElements.Scalable) { emitError(unknownLoc) << "scalable vectors not supported"; return nullptr; diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp index 3a70dd3932e9..a0aefc988a5d 100644 --- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp @@ -574,7 +574,7 @@ LogicalResult ModuleTranslation::convertOperation(Operation &opInst, } if (auto lpOp = dyn_cast(opInst)) { - llvm::Type *ty = lpOp.getType().dyn_cast().getUnderlyingType(); + llvm::Type *ty = convertType(lpOp.getType().cast()); llvm::LandingPadInst *lpi = builder.CreateLandingPad(ty, lpOp.getNumOperands()); @@ -661,7 +661,7 @@ LogicalResult ModuleTranslation::convertBlock(Block &bb, bool ignoreArguments) { if (!wrappedType) return emitError(bb.front().getLoc(), "block argument does not have an LLVM type"); - llvm::Type *type = wrappedType.getUnderlyingType(); + llvm::Type *type = convertType(wrappedType); llvm::PHINode *phi = builder.CreatePHI(type, numPredecessors); valueMapping[arg] = phi; } @@ -687,7 +687,7 @@ LogicalResult ModuleTranslation::convertGlobals() { llvm::sys::SmartScopedLock scopedLock( llvmDialect->getLLVMContextMutex()); for (auto op : getModuleBody(mlirModule).getOps()) { - llvm::Type *type = op.getType().getUnderlyingType(); + llvm::Type *type = convertType(op.getType()); llvm::Constant *cst = llvm::UndefValue::get(type); if (op.getValueOrNull()) { // String attributes are treated separately because they cannot appear as @@ -826,7 +826,7 @@ LogicalResult ModuleTranslation::convertOneFunction(LLVMFuncOp func) { // NB: Attribute already verified to be boolean, so check if we can indeed // attach the attribute to this argument, based on its type. auto argTy = mlirArg.getType().dyn_cast(); - if (!argTy.getUnderlyingType()->isPointerTy()) + if (!argTy.isPointerTy()) return func.emitError( "llvm.noalias attribute attached to LLVM non-pointer argument"); if (attr.getValue()) @@ -837,7 +837,7 @@ LogicalResult ModuleTranslation::convertOneFunction(LLVMFuncOp func) { // NB: Attribute already verified to be int, so check if we can indeed // attach the attribute to this argument, based on its type. auto argTy = mlirArg.getType().dyn_cast(); - if (!argTy.getUnderlyingType()->isPointerTy()) + if (!argTy.isPointerTy()) return func.emitError( "llvm.align attribute attached to LLVM non-pointer argument"); llvmArg.addAttrs( @@ -896,7 +896,7 @@ LogicalResult ModuleTranslation::convertFunctionSignatures() { for (auto function : getModuleBody(mlirModule).getOps()) { llvm::FunctionCallee llvmFuncCst = llvmModule->getOrInsertFunction( function.getName(), - cast(function.getType().getUnderlyingType())); + cast(convertType(function.getType()))); llvm::Function *llvmFunc = cast(llvmFuncCst.getCallee()); llvmFunc->setLinkage(convertLinkageToLLVM(function.linkage())); functionMapping[function.getName()] = llvmFunc; @@ -928,6 +928,10 @@ LogicalResult ModuleTranslation::convertFunctions() { return success(); } +llvm::Type *ModuleTranslation::convertType(LLVMType type) { + return LLVM::convertLLVMType(type); +} + /// A helper to look up remapped operands in the value remapping table.` SmallVector ModuleTranslation::lookupValues(ValueRange values) { diff --git a/mlir/tools/mlir-tblgen/LLVMIRConversionGen.cpp b/mlir/tools/mlir-tblgen/LLVMIRConversionGen.cpp index 9edbad8fdd54..f62e7aebe24a 100644 --- a/mlir/tools/mlir-tblgen/LLVMIRConversionGen.cpp +++ b/mlir/tools/mlir-tblgen/LLVMIRConversionGen.cpp @@ -135,8 +135,7 @@ static bool emitOneBuilder(const Record &record, raw_ostream &os) { } else if (isResultName(op, name)) { bs << formatv("valueMapping[op.{0}()]", name); } else if (name == "_resultType") { - bs << "op.getResult().getType().cast()." - "getUnderlyingType()"; + bs << "convertType(op.getResult().getType().cast())"; } else if (name == "_hasResult") { bs << "opInst.getNumResults() == 1"; } else if (name == "_location") {