[mlir] LLVMType: make getUnderlyingType private
authorAlex Zinenko <zinenko@google.com>
Thu, 23 Jul 2020 08:32:12 +0000 (10:32 +0200)
committerAlex Zinenko <zinenko@google.com>
Wed, 29 Jul 2020 11:43:38 +0000 (13:43 +0200)
The current modeling of LLVM IR types in MLIR is based on the LLVMType class
that wraps a raw `llvm::Type *` and delegates uniquing, printing and parsing to
LLVM itself. This is model makes thread-safe type manipulation hard and is
being progressively replaced with a cleaner MLIR model that replicates the type
system. In the new model, LLVMType will no longer have an underlying LLVM IR
type. Restrict access to this type in the current model in preparation for the
change.

Reviewed By: nicolasvasilache

Differential Revision: https://reviews.llvm.org/D84389

12 files changed:
mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h
mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td
mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp
mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp
mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
mlir/tools/mlir-tblgen/LLVMIRConversionGen.cpp

index 078cb1cfa4e57c712062463202d3883e236bca1e..52acfbfa8e507b6b612232d1b465e80be0afdb94 100644 (file)
@@ -47,6 +47,14 @@ struct LLVMTypeStorage;
 struct LLVMDialectImpl;
 } // namespace detail
 
+class LLVMType;
+
+/// Converts an MLIR LLVM dialect type to LLVM IR type. Note that this function
+/// exists exclusively for the purpose of gradual transition to the first-party
+/// modeling of LLVM types. It should not be used outside MLIR-to-LLVM
+/// translation.
+llvm::Type *convertLLVMType(LLVMType type);
+
 class LLVMType : public mlir::Type::TypeBase<LLVMType, mlir::Type,
                                              detail::LLVMTypeStorage> {
 public:
@@ -59,26 +67,32 @@ public:
   static bool kindof(unsigned kind) { return kind == LLVM_TYPE; }
 
   LLVMDialect &getDialect();
-  llvm::Type *getUnderlyingType() const;
 
   /// Utilities to identify types.
   bool isBFloatTy() { return getUnderlyingType()->isBFloatTy(); }
   bool isHalfTy() { return getUnderlyingType()->isHalfTy(); }
   bool isFloatTy() { return getUnderlyingType()->isFloatTy(); }
   bool isDoubleTy() { return getUnderlyingType()->isDoubleTy(); }
-  bool isIntegerTy() { return getUnderlyingType()->isIntegerTy(); }
-  bool isIntegerTy(unsigned bitwidth) {
-    return getUnderlyingType()->isIntegerTy(bitwidth);
-  }
+  bool isFloatingPointTy() { return getUnderlyingType()->isFloatingPointTy(); }
 
   /// Array type utilities.
   LLVMType getArrayElementType();
   unsigned getArrayNumElements();
   bool isArrayTy();
 
+  /// Integer type utilities.
+  unsigned getIntegerBitWidth() {
+    return getUnderlyingType()->getIntegerBitWidth();
+  }
+  bool isIntegerTy() { return getUnderlyingType()->isIntegerTy(); }
+  bool isIntegerTy(unsigned bitwidth) {
+    return getUnderlyingType()->isIntegerTy(bitwidth);
+  }
+
   /// Vector type utilities.
   LLVMType getVectorElementType();
   unsigned getVectorNumElements();
+  llvm::ElementCount getVectorElementCount();
   bool isVectorTy();
 
   /// Function type utilities.
@@ -86,11 +100,13 @@ public:
   unsigned getFunctionNumParams();
   LLVMType getFunctionResultType();
   bool isFunctionTy();
+  bool isFunctionVarArg();
 
   /// Pointer type utilities.
   LLVMType getPointerTo(unsigned addrSpace = 0);
   LLVMType getPointerElementTy();
   bool isPointerTy();
+  static bool isValidPointerElementType(LLVMType type);
 
   /// Struct type utilities.
   LLVMType getStructElementType(unsigned i);
@@ -194,6 +210,14 @@ public:
 
 private:
   friend LLVMDialect;
+  friend llvm::Type *convertLLVMType(LLVMType type);
+
+  /// Get the underlying LLVM IR type.
+  llvm::Type *getUnderlyingType() const;
+
+  /// Get the underlying LLVM IR types for the given array of types.
+  static void getUnderlyingTypes(ArrayRef<LLVMType> types,
+                                 SmallVectorImpl<llvm::Type *> &result);
 
   /// Get an LLVMType with a pre-existing llvm type.
   static LLVMType get(MLIRContext *context, llvm::Type *llvmType);
index d88b372dbf43dfb8c6f9bb1d497047dd72b780f2..4d99bf265c65fdf6e096f3c7509071fa31573080 100644 (file)
@@ -134,11 +134,9 @@ class ListIntSubst<string pattern, list<int> values> {
 // or result in the operation.
 def LLVM_IntrPatterns {
   string operand =
-    [{opInst.getOperand($0).getType()
-      .cast<LLVM::LLVMType>().getUnderlyingType()}];
+    [{convertType(opInst.getOperand($0).getType().cast<LLVM::LLVMType>())}];
   string result =
-    [{opInst.getResult($0).getType()
-      .cast<LLVM::LLVMType>().getUnderlyingType()}];
+    [{convertType(opInst.getResult($0).getType().cast<LLVM::LLVMType>())}];
 }
 
 
index 29d7fd930030aca80c651d1c76ed9496a59d62d0..4da90575524b7c9b0b3c8429f1208dcb78a9f5e6 100644 (file)
@@ -61,9 +61,8 @@ def LLVM_VoidResultTypeOpBuilder : OpBuilder<
   [{
     auto llvmType = resultType.dyn_cast<LLVM::LLVMType>(); (void)llvmType;
     assert(llvmType && "result must be an LLVM type");
-    assert(llvmType.getUnderlyingType() &&
-            llvmType.getUnderlyingType()->isVoidTy() &&
-            "for zero-result operands, only 'void' is accepted as result type");
+    assert(llvmType.isVoidTy() &&
+           "for zero-result operands, only 'void' is accepted as result type");
     build(builder, result, operands, attributes);
   }]>;
 
@@ -477,7 +476,7 @@ def LLVM_ShuffleVectorOp
   let verifier = [{
     auto wrappedVectorType1 = v1().getType().cast<LLVM::LLVMType>();
     auto wrappedVectorType2 = v2().getType().cast<LLVM::LLVMType>();
-    if (!wrappedVectorType2.getUnderlyingType()->isVectorTy())
+    if (!wrappedVectorType2.isVectorTy())
       return emitOpError("expected LLVM IR Dialect vector type for operand #2");
     if (wrappedVectorType1.getVectorElementType() !=
         wrappedVectorType2.getVectorElementType())
@@ -765,10 +764,10 @@ def LLVM_LLVMFuncOp
           .getValue().cast<LLVMType>();
     }
     bool isVarArg() {
-      return getType().getUnderlyingType()->isFunctionVarArg();
+      return getType().isFunctionVarArg();
     }
 
-    // Hook for OpTrait::FunctionLike, returns the number of function arguments.
+    // Hook for OpTrait::FunctionLike, returns the number of function arguments`.
     // Depends on the type attribute being correct as checked by verifyType.
     unsigned getNumFuncArguments();
 
index 786b9ef217bd793df00aac4ce276602c5e5b3a80..0cd11690daa8ba4aaddf80054755f21f4e758471 100644 (file)
@@ -139,7 +139,7 @@ def ROCDL_mfma_i32_16x16x16i8 : ROCDL_Mfma_IntrOp<"mfma.i32.16x16x16i8">;
 // Vector buffer load/store intrinsics
 
 def ROCDL_MubufLoadOp :
-  ROCDL_Op<"buffer.load">, 
+  ROCDL_Op<"buffer.load">,
   Results<(outs LLVM_Type:$res)>,
   Arguments<(ins LLVM_Type:$rsrc,
                  LLVM_Type:$vindex,
@@ -160,7 +160,7 @@ def ROCDL_MubufLoadOp :
 }
 
 def ROCDL_MubufStoreOp :
-  ROCDL_Op<"buffer.store">, 
+  ROCDL_Op<"buffer.store">,
   Arguments<(ins LLVM_Type:$vdata,
                  LLVM_Type:$rsrc,
                  LLVM_Type:$vindex,
@@ -168,14 +168,13 @@ def ROCDL_MubufStoreOp :
                  LLVM_Type:$glc,
                  LLVM_Type:$slc)>{
   string llvmBuilder = [{
-    auto vdataType = op.vdata().getType().cast<LLVM::LLVMType>()
-                       .getUnderlyingType();
+    auto vdataType = convertType(op.vdata().getType().cast<LLVM::LLVMType>());
     createIntrinsicCall(builder,
-          llvm::Intrinsic::amdgcn_buffer_store, {$vdata, $rsrc, $vindex, 
+          llvm::Intrinsic::amdgcn_buffer_store, {$vdata, $rsrc, $vindex,
           $offset, $glc, $slc}, {vdataType});
   }];
   let parser = [{ return parseROCDLMubufStoreOp(parser, result); }];
-  let printer = [{ 
+  let printer = [{
     Operation *op = this->getOperation();
     p << op->getName() << " " << op->getOperands()
       << " : " << vdata().getType();
index e44ae976e0dd0050b85933544ea79eb3b9f806fe..61f8f9fce64c03dd624df4a94cc2a71af6b0f2d9 100644 (file)
@@ -89,6 +89,10 @@ protected:
                                             llvm::IRBuilder<> &builder);
   virtual LogicalResult convertOmpParallel(Operation &op,
                                            llvm::IRBuilder<> &builder);
+
+  /// Converts the type from MLIR LLVM dialect to LLVM.
+  llvm::Type *convertType(LLVMType type);
+
   static std::unique_ptr<llvm::Module> prepareLLVMModule(Operation *m);
 
   /// A helper to look up remapped operands in the value remapping table.
index 25a3ac07d5f4a661d9cdaef9a284aaed715af748..fd0e96b79d2b58c994601739a9d6757581d3c1bc 100644 (file)
@@ -64,10 +64,8 @@ static unsigned getBitWidth(Type type) {
 
 /// Returns the bit width of LLVMType integer or vector.
 static unsigned getLLVMTypeBitWidth(LLVM::LLVMType type) {
-  return type.isVectorTy() ? type.getVectorElementType()
-                                 .getUnderlyingType()
-                                 ->getIntegerBitWidth()
-                           : type.getUnderlyingType()->getIntegerBitWidth();
+  return type.isVectorTy() ? type.getVectorElementType().getIntegerBitWidth()
+                           : type.getIntegerBitWidth();
 }
 
 /// Creates `IntegerAttribute` with all bits set for given type
index 1e6fa6a8754ba739be2dc3a5b8688971f38a3242..0d154796f049c68d5366daf8932253c2180bb172 100644 (file)
@@ -2248,10 +2248,8 @@ struct RsqrtOpLowering : public ConvertOpToLLVMPattern<RsqrtOp> {
         op, operands, typeConverter,
         [&](LLVM::LLVMType llvmVectorTy, ValueRange operands) {
           auto splatAttr = SplatElementsAttr::get(
-              mlir::VectorType::get(
-                  {cast<llvm::FixedVectorType>(llvmVectorTy.getUnderlyingType())
-                       ->getNumElements()},
-                  floatType),
+              mlir::VectorType::get({llvmVectorTy.getVectorNumElements()},
+                                    floatType),
               floatOne);
           auto one =
               rewriter.create<LLVM::ConstantOp>(loc, llvmVectorTy, splatAttr);
@@ -2511,8 +2509,8 @@ struct IndexCastOpLowering : public ConvertOpToLLVMPattern<IndexCastOp> {
         this->typeConverter.convertType(indexCastOp.getResult().getType())
             .cast<LLVM::LLVMType>();
     auto sourceType = transformed.in().getType().cast<LLVM::LLVMType>();
-    unsigned targetBits = targetType.getUnderlyingType()->getIntegerBitWidth();
-    unsigned sourceBits = sourceType.getUnderlyingType()->getIntegerBitWidth();
+    unsigned targetBits = targetType.getIntegerBitWidth();
+    unsigned sourceBits = sourceType.getIntegerBitWidth();
 
     if (targetBits == sourceBits)
       rewriter.replaceOp(op, transformed.in());
index 5dbc8394b03a562a54cb2aae4db377ea943b001a..4fa7b573f84ef1c7ec47f02d095b4a67895fc6ae 100644 (file)
@@ -127,7 +127,7 @@ LogicalResult getMemRefAlignment(LLVMTypeConverter &typeConverter, T op,
 
   auto dataLayout = typeConverter.getDialect()->getLLVMModule().getDataLayout();
   align = dataLayout.getPrefTypeAlignment(
-      elementTy.cast<LLVM::LLVMType>().getUnderlyingType());
+      LLVM::convertLLVMType(elementTy.cast<LLVM::LLVMType>()));
   return success();
 }
 
index cf7a5d926528f461f080b01e169524c6aaefcb01..17848c6bf3ee914f8786a7324426fb1ba987e780 100644 (file)
@@ -105,11 +105,9 @@ static ParseResult parseCmpOp(OpAsmParser &parser, OperationState &result) {
   auto argType = type.dyn_cast<LLVM::LLVMType>();
   if (!argType)
     return parser.emitError(trailingTypeLoc, "expected LLVM IR dialect type");
-  if (argType.getUnderlyingType()->isVectorTy())
-    resultType = LLVMType::getVectorTy(
-        resultType,
-        llvm::cast<llvm::FixedVectorType>(argType.getUnderlyingType())
-            ->getNumElements());
+  if (argType.isVectorTy())
+    resultType =
+        LLVMType::getVectorTy(resultType, argType.getVectorNumElements());
 
   result.addTypes({resultType});
   return success();
@@ -214,7 +212,7 @@ static Type getLoadStoreElementType(OpAsmParser &parser, Type type,
   if (!llvmTy)
     return parser.emitError(trailingTypeLoc, "expected LLVM IR dialect type"),
            nullptr;
-  if (!llvmTy.getUnderlyingType()->isPointerTy())
+  if (!llvmTy.isPointerTy())
     return parser.emitError(trailingTypeLoc, "expected LLVM pointer type"),
            nullptr;
   return llvmTy.getPointerElementTy();
@@ -683,8 +681,7 @@ static ParseResult parseExtractElementOp(OpAsmParser &parser,
       parser.resolveOperand(position, positionType, result.operands))
     return failure();
   auto wrappedVectorType = type.dyn_cast<LLVM::LLVMType>();
-  if (!wrappedVectorType ||
-      !wrappedVectorType.getUnderlyingType()->isVectorTy())
+  if (!wrappedVectorType || !wrappedVectorType.isVectorTy())
     return parser.emitError(
         loc, "expected LLVM IR dialect vector type for operand #1");
   result.addTypes(wrappedVectorType.getVectorElementType());
@@ -725,16 +722,15 @@ static LLVM::LLVMType getInsertExtractValueElementType(OpAsmParser &parser,
                               "expected an array of integer literals"),
              nullptr;
     int position = positionElementAttr.getInt();
-    auto *llvmContainerType = wrappedContainerType.getUnderlyingType();
-    if (llvmContainerType->isArrayTy()) {
+    if (wrappedContainerType.isArrayTy()) {
       if (position < 0 || static_cast<unsigned>(position) >=
-                              llvmContainerType->getArrayNumElements())
+                              wrappedContainerType.getArrayNumElements())
         return parser.emitError(attributeLoc, "position out of bounds"),
                nullptr;
       wrappedContainerType = wrappedContainerType.getArrayElementType();
-    } else if (llvmContainerType->isStructTy()) {
+    } else if (wrappedContainerType.isStructTy()) {
       if (position < 0 || static_cast<unsigned>(position) >=
-                              llvmContainerType->getStructNumElements())
+                              wrappedContainerType.getStructNumElements())
         return parser.emitError(attributeLoc, "position out of bounds"),
                nullptr;
       wrappedContainerType =
@@ -803,8 +799,7 @@ static ParseResult parseInsertElementOp(OpAsmParser &parser,
     return failure();
 
   auto wrappedVectorType = vectorType.dyn_cast<LLVM::LLVMType>();
-  if (!wrappedVectorType ||
-      !wrappedVectorType.getUnderlyingType()->isVectorTy())
+  if (!wrappedVectorType || !wrappedVectorType.isVectorTy())
     return parser.emitError(
         loc, "expected LLVM IR dialect vector type for operand #1");
   auto valueType = wrappedVectorType.getVectorElementType();
@@ -1125,7 +1120,7 @@ static ParseResult parseGlobalOp(OpAsmParser &parser, OperationState &result) {
 }
 
 static LogicalResult verify(GlobalOp op) {
-  if (!llvm::PointerType::isValidElementType(op.getType().getUnderlyingType()))
+  if (!LLVMType::isValidPointerElementType(op.getType()))
     return op.emitOpError(
         "expects type to be a valid element type for an LLVM pointer");
   if (op.getParentOp() && !satisfiesLLVMModule(op.getParentOp()))
@@ -1133,8 +1128,7 @@ static LogicalResult verify(GlobalOp op) {
 
   if (auto strAttr = op.getValueOrNull().dyn_cast_or_null<StringAttr>()) {
     auto type = op.getType();
-    if (!type.getUnderlyingType()->isArrayTy() ||
-        !type.getArrayElementType().getUnderlyingType()->isIntegerTy(8) ||
+    if (!type.isArrayTy() || !type.getArrayElementType().isIntegerTy(8) ||
         type.getArrayNumElements() != strAttr.getValue().size())
       return op.emitOpError(
           "requires an i8 array type of the length equal to that of the string "
@@ -1197,8 +1191,7 @@ static ParseResult parseShuffleVectorOp(OpAsmParser &parser,
       parser.resolveOperand(v2, typeV2, result.operands))
     return failure();
   auto wrappedContainerType1 = typeV1.dyn_cast<LLVM::LLVMType>();
-  if (!wrappedContainerType1 ||
-      !wrappedContainerType1.getUnderlyingType()->isVectorTy())
+  if (!wrappedContainerType1 || !wrappedContainerType1.isVectorTy())
     return parser.emitError(
         loc, "expected LLVM IR dialect vector type for operand #1");
   auto vType = LLVMType::getVectorTy(
@@ -1239,7 +1232,7 @@ void LLVMFuncOp::build(OpBuilder &builder, OperationState &result,
   if (argAttrs.empty())
     return;
 
-  unsigned numInputs = type.getUnderlyingType()->getFunctionNumParams();
+  unsigned numInputs = type.getFunctionNumParams();
   assert(numInputs == argAttrs.size() &&
          "expected as many argument attribute lists as arguments");
   SmallString<8> argAttrName;
@@ -1374,7 +1367,7 @@ static void printLLVMFuncOp(OpAsmPrinter &p, LLVMFuncOp op) {
 // getNumArguments hook not failing.
 LogicalResult LLVMFuncOp::verifyType() {
   auto llvmType = getTypeAttr().getValue().dyn_cast_or_null<LLVMType>();
-  if (!llvmType || !llvmType.getUnderlyingType()->isFunctionTy())
+  if (!llvmType || !llvmType.isFunctionTy())
     return emitOpError("requires '" + getTypeAttrName() +
                        "' attribute of wrapped LLVM function type");
 
@@ -1384,7 +1377,7 @@ LogicalResult LLVMFuncOp::verifyType() {
 // Hook for OpTrait::FunctionLike, returns the number of function arguments.
 // Depends on the type attribute being correct as checked by verifyType
 unsigned LLVMFuncOp::getNumFuncArguments() {
-  return getType().getUnderlyingType()->getFunctionNumParams();
+  return getType().getFunctionNumParams();
 }
 
 // Hook for OpTrait::FunctionLike, returns the number of function results.
@@ -1424,8 +1417,7 @@ static LogicalResult verify(LLVMFuncOp op) {
   if (op.isVarArg())
     return op.emitOpError("only external functions can be variadic");
 
-  auto *funcType = cast<llvm::FunctionType>(op.getType().getUnderlyingType());
-  unsigned numArguments = funcType->getNumParams();
+  unsigned numArguments = op.getType().getFunctionNumParams();
   Block &entryBlock = op.front();
   for (unsigned i = 0; i < numArguments; ++i) {
     Type argType = entryBlock.getArgument(i).getType();
@@ -1433,7 +1425,7 @@ static LogicalResult verify(LLVMFuncOp op) {
     if (!argLLVMType)
       return op.emitOpError("entry block argument #")
              << i << " is not of LLVM type";
-    if (funcType->getParamType(i) != argLLVMType.getUnderlyingType())
+    if (op.getType().getFunctionParamType(i) != argLLVMType)
       return op.emitOpError("the type of entry block argument #")
              << i << " does not match the function signature";
   }
@@ -1566,7 +1558,7 @@ static LogicalResult verify(AtomicRMWOp op) {
     return op.emitOpError(
         "expected LLVM IR result type to match type for operand #1");
   if (op.bin_op() == AtomicBinOp::fadd || op.bin_op() == AtomicBinOp::fsub) {
-    if (!valType.getUnderlyingType()->isFloatingPointTy())
+    if (!valType.isFloatingPointTy())
       return op.emitOpError("expected LLVM IR floating point type");
   } else if (op.bin_op() == AtomicBinOp::xchg) {
     if (!valType.isIntegerTy(8) && !valType.isIntegerTy(16) &&
@@ -1842,6 +1834,13 @@ llvm::Type *LLVMType::getUnderlyingType() const {
   return getImpl()->underlyingType;
 }
 
+void LLVMType::getUnderlyingTypes(ArrayRef<LLVMType> types,
+                                  SmallVectorImpl<llvm::Type *> &result) {
+  result.reserve(result.size() + types.size());
+  for (LLVMType ty : types)
+    result.push_back(ty.getUnderlyingType());
+}
+
 /// Array type utilities.
 LLVMType LLVMType::getArrayElementType() {
   return get(getContext(), getUnderlyingType()->getArrayElementType());
@@ -1861,6 +1860,9 @@ unsigned LLVMType::getVectorNumElements() {
   return llvm::cast<llvm::FixedVectorType>(getUnderlyingType())
       ->getNumElements();
 }
+llvm::ElementCount LLVMType::getVectorElementCount() {
+  return llvm::cast<llvm::VectorType>(getUnderlyingType())->getElementCount();
+}
 bool LLVMType::isVectorTy() { return getUnderlyingType()->isVectorTy(); }
 
 /// Function type utilities.
@@ -1876,6 +1878,9 @@ LLVMType LLVMType::getFunctionResultType() {
       llvm::cast<llvm::FunctionType>(getUnderlyingType())->getReturnType());
 }
 bool LLVMType::isFunctionTy() { return getUnderlyingType()->isFunctionTy(); }
+bool LLVMType::isFunctionVarArg() {
+  return getUnderlyingType()->isFunctionVarArg();
+}
 
 /// Pointer type utilities.
 LLVMType LLVMType::getPointerTo(unsigned addrSpace) {
@@ -1888,6 +1893,9 @@ LLVMType LLVMType::getPointerElementTy() {
   return get(getContext(), getUnderlyingType()->getPointerElementType());
 }
 bool LLVMType::isPointerTy() { return getUnderlyingType()->isPointerTy(); }
+bool LLVMType::isValidPointerElementType(LLVMType type) {
+  return llvm::PointerType::isValidElementType(type.getUnderlyingType());
+}
 
 /// Struct type utilities.
 LLVMType LLVMType::getStructElementType(unsigned i) {
@@ -1974,18 +1982,12 @@ LLVMType LLVMType::getStructTy(LLVMDialect *dialect,
                                  isPacked);
   });
 }
-inline static SmallVector<llvm::Type *, 8>
-toUnderlyingTypes(ArrayRef<LLVMType> elements) {
-  SmallVector<llvm::Type *, 8> llvmElements;
-  for (auto elt : elements)
-    llvmElements.push_back(elt.getUnderlyingType());
-  return llvmElements;
-}
 LLVMType LLVMType::createStructTy(LLVMDialect *dialect,
                                   ArrayRef<LLVMType> elements,
                                   Optional<StringRef> name, bool isPacked) {
   StringRef sr = name.hasValue() ? *name : "";
-  SmallVector<llvm::Type *, 8> llvmElements(toUnderlyingTypes(elements));
+  SmallVector<llvm::Type *, 8> llvmElements;
+  getUnderlyingTypes(elements, llvmElements);
   return getLocked(dialect, [=] {
     auto *rv = llvm::StructType::create(dialect->getLLVMContext(), sr);
     if (!llvmElements.empty())
@@ -1997,7 +1999,8 @@ LLVMType LLVMType::setStructTyBody(LLVMType structType,
                                    ArrayRef<LLVMType> elements, bool isPacked) {
   llvm::StructType *st =
       llvm::cast<llvm::StructType>(structType.getUnderlyingType());
-  SmallVector<llvm::Type *, 8> llvmElements(toUnderlyingTypes(elements));
+  SmallVector<llvm::Type *, 8> llvmElements;
+  getUnderlyingTypes(elements, llvmElements);
   return getLocked(&structType.getDialect(), [=] {
     st->setBody(llvmElements, isPacked);
     return st;
@@ -2017,6 +2020,10 @@ LLVMType LLVMType::getVoidTy(LLVMDialect *dialect) {
 
 bool LLVMType::isVoidTy() { return getUnderlyingType()->isVoidTy(); }
 
+llvm::Type *mlir::LLVM::convertLLVMType(LLVMType type) {
+  return type.getUnderlyingType();
+}
+
 //===----------------------------------------------------------------------===//
 // Utility functions.
 //===----------------------------------------------------------------------===//
index 9754c614efdfbc2565f523c5af4e593effc704e2..77897d65e1a50bd9a2783cc4348214690c32f37a 100644 (file)
@@ -234,18 +234,17 @@ Type Importer::getStdTypeForAttr(LLVMType type) {
     return nullptr;
 
   if (type.isIntegerTy())
-    return b.getIntegerType(type.getUnderlyingType()->getIntegerBitWidth());
+    return b.getIntegerType(type.getIntegerBitWidth());
 
-  if (type.getUnderlyingType()->isFloatTy())
+  if (type.isFloatTy())
     return b.getF32Type();
 
-  if (type.getUnderlyingType()->isDoubleTy())
+  if (type.isDoubleTy())
     return b.getF64Type();
 
   // LLVM vectors can only contain scalars.
   if (type.isVectorTy()) {
-    auto numElements = llvm::cast<llvm::VectorType>(type.getUnderlyingType())
-                           ->getElementCount();
+    auto numElements = type.getVectorElementCount();
     if (numElements.Scalable) {
       emitError(unknownLoc) << "scalable vectors not supported";
       return nullptr;
@@ -270,9 +269,7 @@ Type Importer::getStdTypeForAttr(LLVMType type) {
     // attribute type.
     if (type.getArrayElementType().isVectorTy()) {
       LLVMType vectorType = type.getArrayElementType();
-      auto numElements =
-          llvm::cast<llvm::VectorType>(vectorType.getUnderlyingType())
-              ->getElementCount();
+      auto numElements = vectorType.getVectorElementCount();
       if (numElements.Scalable) {
         emitError(unknownLoc) << "scalable vectors not supported";
         return nullptr;
index 3a70dd3932e92ea76f09149f34f53f9208b2b824..a0aefc988a5daae376f4ec7e863eefc057d69b1a 100644 (file)
@@ -574,7 +574,7 @@ LogicalResult ModuleTranslation::convertOperation(Operation &opInst,
   }
 
   if (auto lpOp = dyn_cast<LLVM::LandingpadOp>(opInst)) {
-    llvm::Type *ty = lpOp.getType().dyn_cast<LLVMType>().getUnderlyingType();
+    llvm::Type *ty = convertType(lpOp.getType().cast<LLVMType>());
     llvm::LandingPadInst *lpi =
         builder.CreateLandingPad(ty, lpOp.getNumOperands());
 
@@ -661,7 +661,7 @@ LogicalResult ModuleTranslation::convertBlock(Block &bb, bool ignoreArguments) {
       if (!wrappedType)
         return emitError(bb.front().getLoc(),
                          "block argument does not have an LLVM type");
-      llvm::Type *type = wrappedType.getUnderlyingType();
+      llvm::Type *type = convertType(wrappedType);
       llvm::PHINode *phi = builder.CreatePHI(type, numPredecessors);
       valueMapping[arg] = phi;
     }
@@ -687,7 +687,7 @@ LogicalResult ModuleTranslation::convertGlobals() {
   llvm::sys::SmartScopedLock<true> scopedLock(
       llvmDialect->getLLVMContextMutex());
   for (auto op : getModuleBody(mlirModule).getOps<LLVM::GlobalOp>()) {
-    llvm::Type *type = op.getType().getUnderlyingType();
+    llvm::Type *type = convertType(op.getType());
     llvm::Constant *cst = llvm::UndefValue::get(type);
     if (op.getValueOrNull()) {
       // String attributes are treated separately because they cannot appear as
@@ -826,7 +826,7 @@ LogicalResult ModuleTranslation::convertOneFunction(LLVMFuncOp func) {
       // NB: Attribute already verified to be boolean, so check if we can indeed
       // attach the attribute to this argument, based on its type.
       auto argTy = mlirArg.getType().dyn_cast<LLVM::LLVMType>();
-      if (!argTy.getUnderlyingType()->isPointerTy())
+      if (!argTy.isPointerTy())
         return func.emitError(
             "llvm.noalias attribute attached to LLVM non-pointer argument");
       if (attr.getValue())
@@ -837,7 +837,7 @@ LogicalResult ModuleTranslation::convertOneFunction(LLVMFuncOp func) {
       // NB: Attribute already verified to be int, so check if we can indeed
       // attach the attribute to this argument, based on its type.
       auto argTy = mlirArg.getType().dyn_cast<LLVM::LLVMType>();
-      if (!argTy.getUnderlyingType()->isPointerTy())
+      if (!argTy.isPointerTy())
         return func.emitError(
             "llvm.align attribute attached to LLVM non-pointer argument");
       llvmArg.addAttrs(
@@ -896,7 +896,7 @@ LogicalResult ModuleTranslation::convertFunctionSignatures() {
   for (auto function : getModuleBody(mlirModule).getOps<LLVMFuncOp>()) {
     llvm::FunctionCallee llvmFuncCst = llvmModule->getOrInsertFunction(
         function.getName(),
-        cast<llvm::FunctionType>(function.getType().getUnderlyingType()));
+        cast<llvm::FunctionType>(convertType(function.getType())));
     llvm::Function *llvmFunc = cast<llvm::Function>(llvmFuncCst.getCallee());
     llvmFunc->setLinkage(convertLinkageToLLVM(function.linkage()));
     functionMapping[function.getName()] = llvmFunc;
@@ -928,6 +928,10 @@ LogicalResult ModuleTranslation::convertFunctions() {
   return success();
 }
 
+llvm::Type *ModuleTranslation::convertType(LLVMType type) {
+  return LLVM::convertLLVMType(type);
+}
+
 /// A helper to look up remapped operands in the value remapping table.`
 SmallVector<llvm::Value *, 8>
 ModuleTranslation::lookupValues(ValueRange values) {
index 9edbad8fdd5402aec4a41b1fb825096ff6940ed2..f62e7aebe24a6c9864c0031ddd97b1de69fcc258 100644 (file)
@@ -135,8 +135,7 @@ static bool emitOneBuilder(const Record &record, raw_ostream &os) {
     } else if (isResultName(op, name)) {
       bs << formatv("valueMapping[op.{0}()]", name);
     } else if (name == "_resultType") {
-      bs << "op.getResult().getType().cast<LLVM::LLVMType>()."
-            "getUnderlyingType()";
+      bs << "convertType(op.getResult().getType().cast<LLVM::LLVMType>())";
     } else if (name == "_hasResult") {
       bs << "opInst.getNumResults() == 1";
     } else if (name == "_location") {