[mlir] Remove instance methods from LLVMType
authorAlex Zinenko <zinenko@google.com>
Tue, 22 Dec 2020 10:22:21 +0000 (11:22 +0100)
committerAlex Zinenko <zinenko@google.com>
Tue, 22 Dec 2020 22:34:54 +0000 (23:34 +0100)
LLVMType contains multiple instance methods that were introduced initially for
compatibility with LLVM API. These methods boil down to `cast` followed by
type-specific call. Arguably, they are mostly used in an LLVM cast-follows-isa
anti-pattern. This doesn't connect nicely to the rest of the MLIR
infrastructure and actively prevents it from making the LLVM dialect type
system more open, e.g., reusing built-in types when appropriate. Remove such
instance methods and replaces their uses with apporpriate casts and methods on
derived classes. In some cases, the result may look slightly more verbose, but
most cases should actually use a stricter subtype of LLVMType anyway and avoid
the isa/cast.

Reviewed By: mehdi_amini

Differential Revision: https://reviews.llvm.org/D93680

20 files changed:
mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h
mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp
mlir/lib/Conversion/GPUCommon/ConvertLaunchFuncToRuntimeCalls.cpp
mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h
mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h
mlir/lib/Conversion/GPUToVulkan/ConvertLaunchFuncToVulkanCalls.cpp
mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp
mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
mlir/lib/Conversion/VectorToROCDL/VectorToROCDL.cpp
mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
mlir/lib/ExecutionEngine/JitRunner.cpp
mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp
mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
mlir/test/Dialect/LLVMIR/invalid.mlir

index 7c069c9cd5566fed87bd7866db48dad2995dd64c..63ff16a84ab876e51e5cfafbb56d6138f7a6f2a6 100644 (file)
@@ -446,7 +446,8 @@ public:
   /// Builds IR extracting the pointer to the first element of the size array.
   static Value sizeBasePtr(OpBuilder &builder, Location loc,
                            LLVMTypeConverter &typeConverter,
-                           Value memRefDescPtr, LLVM::LLVMType elemPtrPtrType);
+                           Value memRefDescPtr,
+                           LLVM::LLVMPointerType elemPtrPtrType);
   /// Builds IR extracting the size[index] from the descriptor.
   static Value size(OpBuilder &builder, Location loc,
                     LLVMTypeConverter typeConverter, Value sizeBasePtr,
index df022ef47b3317d5b1376641e8ee3d9e4d1bd11d..552fe15e68997fbc0a813bccb7424051d163d8dd 100644 (file)
@@ -51,7 +51,7 @@ def LLVM_VoidResultTypeOpBuilder :
   [{
     auto llvmType = resultType.dyn_cast<LLVMType>(); (void)llvmType;
     assert(llvmType && "result must be an LLVM type");
-    assert(llvmType.isVoidTy() &&
+    assert(llvmType.isa<LLVMVoidType>() &&
            "for zero-result operands, only 'void' is accepted as result type");
     build($_builder, $_state, operands, attributes);
   }]>;
@@ -288,7 +288,7 @@ def LLVM_LoadOp : LLVM_Op<"load">, MemoryOpWithAlignmentAndAttributes {
     OpBuilderDAG<(ins "Value":$addr, CArg<"unsigned", "0">:$alignment,
       CArg<"bool", "false">:$isVolatile, CArg<"bool", "false">:$isNonTemporal),
     [{
-      auto type = addr.getType().cast<LLVMType>().getPointerElementTy();
+      auto type = addr.getType().cast<LLVMPointerType>().getElementType();
       build($_builder, $_state, type, addr, alignment, isVolatile, isNonTemporal);
     }]>,
     OpBuilderDAG<(ins "Type":$t, "Value":$addr,
@@ -443,8 +443,8 @@ def LLVM_CallOp : LLVM_Op<"call"> {
     OpBuilderDAG<(ins "LLVMFuncOp":$func, "ValueRange":$operands,
       CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
     [{
-      LLVMType resultType = func.getType().getFunctionResultType();
-      if (!resultType.isVoidTy())
+      LLVMType resultType = func.getType().getReturnType();
+      if (!resultType.isa<LLVM::LLVMVoidType>())
         $_state.addTypes(resultType);
       $_state.addAttribute("callee", $_builder.getSymbolRefAttr(func));
       $_state.addAttributes(attributes);
@@ -515,12 +515,10 @@ def LLVM_ShuffleVectorOp : LLVM_Op<"shufflevector", [NoSideEffect]> {
     OpBuilderDAG<(ins "Value":$v1, "Value":$v2, "ArrayAttr":$mask,
       CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>];
   let verifier = [{
-    auto wrappedVectorType1 = v1().getType().cast<LLVMType>();
-    auto wrappedVectorType2 = v2().getType().cast<LLVMType>();
-    if (!wrappedVectorType2.isVectorTy())
-      return emitOpError("expected LLVM IR Dialect vector type for operand #2");
-    if (wrappedVectorType1.getVectorElementType() !=
-        wrappedVectorType2.getVectorElementType())
+    auto wrappedVectorType1 = v1().getType().cast<LLVMVectorType>();
+    auto wrappedVectorType2 = v2().getType().cast<LLVMVectorType>();
+    if (wrappedVectorType1.getElementType() !=
+        wrappedVectorType2.getElementType())
       return emitOpError("expected matching LLVM IR Dialect element types");
     return success();
   }];
@@ -768,13 +766,13 @@ def LLVM_AddressOfOp : LLVM_Op<"mlir.addressof"> {
       CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs),
     [{
       build($_builder, $_state,
-            global.getType().getPointerTo(global.addr_space()),
+            LLVM::LLVMPointerType::get(global.getType(), global.addr_space()),
             global.sym_name(), attrs);}]>,
     OpBuilderDAG<(ins "LLVMFuncOp":$func,
       CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs),
     [{
       build($_builder, $_state,
-            func.getType().getPointerTo(), func.getName(), attrs);}]>
+            LLVM::LLVMPointerType::get(func.getType()), func.getName(), attrs);}]>
   ];
 
   let extraClassDeclaration = [{
@@ -970,12 +968,12 @@ def LLVM_LLVMFuncOp : LLVM_Op<"func",
     // to match the signature of the function.
     Block *addEntryBlock();
 
-    LLVMType getType() {
+    LLVMFunctionType getType() {
       return (*this)->getAttrOfType<TypeAttr>(getTypeAttrName())
-          .getValue().cast<LLVMType>();
+          .getValue().cast<LLVMFunctionType>();
     }
     bool isVarArg() {
-      return getType().isFunctionVarArg();
+      return getType().isVarArg();
     }
 
     // Hook for OpTrait::FunctionLike, returns the number of function arguments`.
index f92bdf9e3041ac609387465d4d05732788160af7..e1938c12c809e912d8e28775d6c626fc17ee7b1f 100644 (file)
@@ -80,58 +80,6 @@ public:
 
   LLVMDialect &getDialect();
 
-  /// Returns the size of a primitive type (including vectors) in bits, for
-  /// example, the size of !llvm.i16 is 16 and the size of !llvm.vec<4 x i16>
-  /// is 64. Returns 0 for non-primitive (aggregates such as struct) or types
-  /// that don't have a size (such as void).
-  llvm::TypeSize getPrimitiveSizeInBits();
-
-  /// Floating-point type utilities.
-  bool isBFloatTy() { return isa<LLVMBFloatType>(); }
-  bool isHalfTy() { return isa<LLVMHalfType>(); }
-  bool isFloatTy() { return isa<LLVMFloatType>(); }
-  bool isDoubleTy() { return isa<LLVMDoubleType>(); }
-  bool isFP128Ty() { return isa<LLVMFP128Type>(); }
-  bool isX86_FP80Ty() { return isa<LLVMX86FP80Type>(); }
-  bool isFloatingPointTy() {
-    return isa<LLVMHalfType>() || isa<LLVMBFloatType>() ||
-           isa<LLVMFloatType>() || isa<LLVMDoubleType>() ||
-           isa<LLVMFP128Type>() || isa<LLVMX86FP80Type>();
-  }
-
-  /// Array type utilities.
-  LLVMType getArrayElementType();
-  unsigned getArrayNumElements();
-  bool isArrayTy();
-
-  /// Integer type utilities.
-  bool isIntegerTy() { return isa<LLVMIntegerType>(); }
-  bool isIntegerTy(unsigned bitwidth);
-  unsigned getIntegerBitWidth();
-
-  /// Vector type utilities.
-  LLVMType getVectorElementType();
-  unsigned getVectorNumElements();
-  llvm::ElementCount getVectorElementCount();
-  bool isVectorTy();
-
-  /// Function type utilities.
-  LLVMType getFunctionParamType(unsigned argIdx);
-  unsigned getFunctionNumParams();
-  LLVMType getFunctionResultType();
-  bool isFunctionTy();
-  bool isFunctionVarArg();
-
-  /// Pointer type utilities.
-  LLVMType getPointerTo(unsigned addrSpace = 0);
-  LLVMType getPointerElementTy();
-  bool isPointerTy();
-
-  /// Struct type utilities.
-  LLVMType getStructElementType(unsigned i);
-  unsigned getStructNumElements();
-  bool isStructTy();
-
   /// Utilities used to generate floating point types.
   static LLVMType getDoubleTy(MLIRContext *context);
   static LLVMType getFloatTy(MLIRContext *context);
@@ -148,9 +96,7 @@ public:
   static LLVMType getInt8Ty(MLIRContext *context) {
     return getIntNTy(context, /*numBits=*/8);
   }
-  static LLVMType getInt8PtrTy(MLIRContext *context) {
-    return getInt8Ty(context).getPointerTo();
-  }
+  static LLVMType getInt8PtrTy(MLIRContext *context);
   static LLVMType getInt16Ty(MLIRContext *context) {
     return getIntNTy(context, /*numBits=*/16);
   }
@@ -184,7 +130,6 @@ public:
 
   /// Void type utilities.
   static LLVMType getVoidTy(MLIRContext *context);
-  bool isVoidTy();
 
   // Creation and setting of LLVM's identified struct types
   static LLVMType createStructTy(MLIRContext *context,
@@ -585,6 +530,24 @@ LLVMType parseType(DialectAsmParser &parser);
 void printType(LLVMType type, DialectAsmPrinter &printer);
 } // namespace detail
 
+//===----------------------------------------------------------------------===//
+// Utility functions.
+//===----------------------------------------------------------------------===//
+
+/// Returns `true` if the given type is compatible with the LLVM dialect.
+inline bool isCompatibleType(Type type) { return type.isa<LLVMType>(); }
+
+inline bool isCompatibleFloatingPointType(Type type) {
+  return type.isa<LLVMHalfType, LLVMBFloatType, LLVMFloatType, LLVMDoubleType,
+                  LLVMFP128Type, LLVMX86FP80Type>();
+}
+
+/// Returns the size of the given primitive LLVM dialect-compatible type
+/// (including vectors) in bits, for example, the size of !llvm.i16 is 16 and
+/// the size of !llvm.vec<4 x i16> is 64. Returns 0 for non-primitive
+/// (aggregates such as struct) or types that don't have a size (such as void).
+llvm::TypeSize getPrimitiveTypeSizeInBits(Type type);
+
 } // namespace LLVM
 } // namespace mlir
 
index 1f9b860eb52eb55c51ee039d1c9b4f5330a675a4..3c73cdf64eb707bcf489305f7c6b9323ad4cf5bf 100644 (file)
@@ -109,10 +109,11 @@ def NVVM_ShflBflyOp :
   let verifier = [{
     if (!(*this)->getAttrOfType<UnitAttr>("return_value_and_is_valid"))
       return success();
-    auto type = getType().cast<LLVM::LLVMType>();
-    if (!type.isStructTy() || type.getStructNumElements() != 2 ||
-        !type.getStructElementType(1).isIntegerTy(
-            /*Bitwidth=*/1))
+    auto type = getType().dyn_cast<LLVM::LLVMStructType>();
+    auto elementType = (type && type.getBody().size() == 2)
+                     ? type.getBody()[1].dyn_cast<LLVM::LLVMIntegerType>()
+                     : nullptr;
+    if (!elementType || elementType.getBitWidth() != 1)
       return emitError("expected return type to be a two-element struct with "
                        "i1 as the second element");
     return success();
index 273754fe2480c49ccf068296d4bc84847b15e4a0..65545d8ab2de1de5085f6adb1af50c748e318d97 100644 (file)
@@ -79,7 +79,7 @@ struct AsyncAPI {
 
   static FunctionType executeFunctionType(MLIRContext *ctx) {
     auto hdl = LLVM::LLVMType::getInt8PtrTy(ctx);
-    auto resume = resumeFunctionType(ctx).getPointerTo();
+    auto resume = LLVM::LLVMPointerType::get(resumeFunctionType(ctx));
     return FunctionType::get(ctx, {hdl, resume}, {});
   }
 
@@ -91,13 +91,13 @@ struct AsyncAPI {
 
   static FunctionType awaitAndExecuteFunctionType(MLIRContext *ctx) {
     auto hdl = LLVM::LLVMType::getInt8PtrTy(ctx);
-    auto resume = resumeFunctionType(ctx).getPointerTo();
+    auto resume = LLVM::LLVMPointerType::get(resumeFunctionType(ctx));
     return FunctionType::get(ctx, {TokenType::get(ctx), hdl, resume}, {});
   }
 
   static FunctionType awaitAllAndExecuteFunctionType(MLIRContext *ctx) {
     auto hdl = LLVM::LLVMType::getInt8PtrTy(ctx);
-    auto resume = resumeFunctionType(ctx).getPointerTo();
+    auto resume = LLVM::LLVMPointerType::get(resumeFunctionType(ctx));
     return FunctionType::get(ctx, {GroupType::get(ctx), hdl, resume}, {});
   }
 
@@ -507,7 +507,7 @@ outlineExecuteOp(SymbolTable &symbolTable, ExecuteOp execute) {
   // A pointer to coroutine resume intrinsic wrapper.
   auto resumeFnTy = AsyncAPI::resumeFunctionType(ctx);
   auto resumePtr = builder.create<LLVM::AddressOfOp>(
-      loc, resumeFnTy.getPointerTo(), kResume);
+      loc, LLVM::LLVMPointerType::get(resumeFnTy), kResume);
 
   // Save the coroutine state: @llvm.coro.save
   auto coroSave = builder.create<LLVM::CallOp>(
@@ -750,7 +750,7 @@ public:
       // A pointer to coroutine resume intrinsic wrapper.
       auto resumeFnTy = AsyncAPI::resumeFunctionType(ctx);
       auto resumePtr = builder.create<LLVM::AddressOfOp>(
-          loc, resumeFnTy.getPointerTo(), kResume);
+          loc, LLVM::LLVMPointerType::get(resumeFnTy), kResume);
 
       // Save the coroutine state: @llvm.coro.save
       auto coroSave = builder.create<LLVM::CallOp>(
index 41a079c44eea58af00a4614a47a6f098ee5e6977..bbb2bf1e04ff2212bb8a0e5dc6b8bce64c4e4be6 100644 (file)
@@ -55,14 +55,14 @@ public:
   FunctionCallBuilder(StringRef functionName, LLVM::LLVMType returnType,
                       ArrayRef<LLVM::LLVMType> argumentTypes)
       : functionName(functionName),
-        functionType(LLVM::LLVMType::getFunctionTy(returnType, argumentTypes,
-                                                   /*isVarArg=*/false)) {}
+        functionType(LLVM::LLVMFunctionType::get(returnType, argumentTypes,
+                                                 /*isVarArg=*/false)) {}
   LLVM::CallOp create(Location loc, OpBuilder &builder,
                       ArrayRef<Value> arguments) const;
 
 private:
   StringRef functionName;
-  LLVM::LLVMType functionType;
+  LLVM::LLVMFunctionType functionType;
 };
 
 template <typename OpTy>
@@ -76,7 +76,8 @@ protected:
 
   LLVM::LLVMType llvmVoidType = LLVM::LLVMType::getVoidTy(context);
   LLVM::LLVMType llvmPointerType = LLVM::LLVMType::getInt8PtrTy(context);
-  LLVM::LLVMType llvmPointerPointerType = llvmPointerType.getPointerTo();
+  LLVM::LLVMType llvmPointerPointerType =
+      LLVM::LLVMPointerType::get(llvmPointerType);
   LLVM::LLVMType llvmInt8Type = LLVM::LLVMType::getInt8Ty(context);
   LLVM::LLVMType llvmInt32Type = LLVM::LLVMType::getInt32Ty(context);
   LLVM::LLVMType llvmInt64Type = LLVM::LLVMType::getInt64Ty(context);
@@ -312,7 +313,7 @@ LLVM::CallOp FunctionCallBuilder::create(Location loc, OpBuilder &builder,
         .create<LLVM::LLVMFuncOp>(loc, functionName, functionType);
   }();
   return builder.create<LLVM::CallOp>(
-      loc, const_cast<LLVM::LLVMType &>(functionType).getFunctionResultType(),
+      loc, const_cast<LLVM::LLVMFunctionType &>(functionType).getReturnType(),
       builder.getSymbolRefAttr(function), arguments);
 }
 
@@ -518,7 +519,7 @@ Value ConvertLaunchFuncOpToGpuRuntimeCallPattern::generateParamsArray(
   auto one = builder.create<LLVM::ConstantOp>(loc, llvmInt32Type,
                                               builder.getI32IntegerAttr(1));
   auto structPtr = builder.create<LLVM::AllocaOp>(
-      loc, structType.getPointerTo(), one, /*alignment=*/0);
+      loc, LLVM::LLVMPointerType::get(structType), one, /*alignment=*/0);
   auto arraySize = builder.create<LLVM::ConstantOp>(
       loc, llvmInt32Type, builder.getI32IntegerAttr(numArguments));
   auto arrayPtr = builder.create<LLVM::AllocaOp>(loc, llvmPointerPointerType,
@@ -529,7 +530,7 @@ Value ConvertLaunchFuncOpToGpuRuntimeCallPattern::generateParamsArray(
     auto index = builder.create<LLVM::ConstantOp>(
         loc, llvmInt32Type, builder.getI32IntegerAttr(en.index()));
     auto fieldPtr = builder.create<LLVM::GEPOp>(
-        loc, argumentTypes[en.index()].getPointerTo(), structPtr,
+        loc, LLVM::LLVMPointerType::get(argumentTypes[en.index()]), structPtr,
         ArrayRef<Value>{zero, index.getResult()});
     builder.create<LLVM::StoreOp>(loc, en.value(), fieldPtr);
     auto elementPtr = builder.create<LLVM::GEPOp>(loc, llvmPointerPointerType,
index bf17200e594f1f21ba720dd2a88cf3d3cfd9a50a..914b7ee50cf93403d5249afab4b395f0ca2abda1 100644 (file)
@@ -51,8 +51,8 @@ struct GPUFuncOpLowering : ConvertOpToLLVMPattern<gpu::GPUFuncOp> {
 
     // Rewrite the original GPU function to an LLVM function.
     auto funcType = typeConverter->convertType(gpuFuncOp.getType())
-                        .template cast<LLVM::LLVMType>()
-                        .getPointerElementTy();
+                        .template cast<LLVM::LLVMPointerType>()
+                        .getElementType();
 
     // Remap proper input types.
     TypeConverter::SignatureConversion signatureConversion(
@@ -94,10 +94,11 @@ struct GPUFuncOpLowering : ConvertOpToLLVMPattern<gpu::GPUFuncOp> {
       for (auto en : llvm::enumerate(workgroupBuffers)) {
         LLVM::GlobalOp global = en.value();
         Value address = rewriter.create<LLVM::AddressOfOp>(loc, global);
-        auto elementType = global.getType().getArrayElementType();
+        auto elementType =
+            global.getType().cast<LLVM::LLVMArrayType>().getElementType();
         Value memory = rewriter.create<LLVM::GEPOp>(
-            loc, elementType.getPointerTo(global.addr_space()), address,
-            ArrayRef<Value>{zero, zero});
+            loc, LLVM::LLVMPointerType::get(elementType, global.addr_space()),
+            address, ArrayRef<Value>{zero, zero});
 
         // Build a memref descriptor pointing to the buffer to plug with the
         // existing memref infrastructure. This may use more registers than
@@ -123,9 +124,10 @@ struct GPUFuncOpLowering : ConvertOpToLLVMPattern<gpu::GPUFuncOp> {
         // Explicitly drop memory space when lowering private memory
         // attributions since NVVM models it as `alloca`s in the default
         // memory space and does not support `alloca`s with addrspace(5).
-        auto ptrType = typeConverter->convertType(type.getElementType())
-                           .template cast<LLVM::LLVMType>()
-                           .getPointerTo(AllocaAddrSpace);
+        auto ptrType = LLVM::LLVMPointerType::get(
+            typeConverter->convertType(type.getElementType())
+                .template cast<LLVM::LLVMType>(),
+            AllocaAddrSpace);
         Value numElements = rewriter.create<LLVM::ConstantOp>(
             gpuFuncOp.getLoc(), int64Ty,
             rewriter.getI64IntegerAttr(type.getNumElements()));
index 9d08aeee190610c8a99ebe2c6e183a6148554770..b2887aa1d78291526ebe8b8377b89ac7fd2029db 100644 (file)
@@ -57,7 +57,8 @@ public:
     LLVMType resultType =
         castedOperands.front().getType().cast<LLVM::LLVMType>();
     LLVMType funcType = getFunctionType(resultType, castedOperands);
-    StringRef funcName = getFunctionName(funcType.getFunctionResultType());
+    StringRef funcName = getFunctionName(
+        funcType.cast<LLVM::LLVMFunctionType>().getReturnType());
     if (funcName.empty())
       return failure();
 
@@ -80,7 +81,7 @@ public:
 private:
   Value maybeCast(Value operand, PatternRewriter &rewriter) const {
     LLVM::LLVMType type = operand.getType().cast<LLVM::LLVMType>();
-    if (!type.isHalfTy())
+    if (!type.isa<LLVM::LLVMHalfType>())
       return operand;
 
     return rewriter.create<LLVM::FPExtOp>(
@@ -100,9 +101,9 @@ private:
   }
 
   StringRef getFunctionName(LLVM::LLVMType type) const {
-    if (type.isFloatTy())
+    if (type.isa<LLVM::LLVMFloatType>())
       return f32Func;
-    if (type.isDoubleTy())
+    if (type.isa<LLVM::LLVMDoubleType>())
       return f64Func;
     return "";
   }
index 355bced96ae750db3a0c3c99a131bee12ffcd36f..c676cd256d66a8660a64d14c7cc095edf11b3929 100644 (file)
@@ -75,7 +75,7 @@ private:
     //   int64_t sizes[Rank]; // omitted when rank == 0
     //   int64_t strides[Rank]; // omitted when rank == 0
     // };
-    auto llvmPtrToElementType = elemenType.getPointerTo();
+    auto llvmPtrToElementType = LLVM::LLVMPointerType::get(elemenType);
     auto llvmArrayRankElementSizeType =
         LLVM::LLVMType::getArrayTy(getInt64Type(), rank);
 
@@ -131,16 +131,18 @@ private:
 
   /// Returns a string representation from the given `type`.
   StringRef stringifyType(LLVM::LLVMType type) {
-    if (type.isFloatTy())
+    if (type.isa<LLVM::LLVMFloatType>())
       return "Float";
-    if (type.isHalfTy())
+    if (type.isa<LLVM::LLVMHalfType>())
       return "Half";
-    if (type.isIntegerTy(32))
-      return "Int32";
-    if (type.isIntegerTy(16))
-      return "Int16";
-    if (type.isIntegerTy(8))
-      return "Int8";
+    if (auto intType = type.dyn_cast<LLVM::LLVMIntegerType>()) {
+      if (intType.getBitWidth() == 32)
+        return "Int32";
+      if (intType.getBitWidth() == 16)
+        return "Int16";
+      if (intType.getBitWidth() == 8)
+        return "Int8";
+    }
 
     llvm_unreachable("unsupported type");
   }
@@ -238,11 +240,11 @@ void VulkanLaunchFuncToVulkanCallsPass::createBindMemRefCalls(
         llvm::formatv("bindMemRef{0}D{1}", rank, stringifyType(type)).str();
     // Special case for fp16 type. Since it is not a supported type in C we use
     // int16_t and bitcast the descriptor.
-    if (type.isHalfTy()) {
+    if (type.isa<LLVM::LLVMHalfType>()) {
       auto memRefTy =
           getMemRefType(rank, LLVM::LLVMType::getInt16Ty(&getContext()));
       ptrToMemRefDescriptor = builder.create<LLVM::BitcastOp>(
-          loc, memRefTy.getPointerTo(), ptrToMemRefDescriptor);
+          loc, LLVM::LLVMPointerType::get(memRefTy), ptrToMemRefDescriptor);
     }
     // Create call to `bindMemRef`.
     builder.create<LLVM::CallOp>(
@@ -257,11 +259,12 @@ void VulkanLaunchFuncToVulkanCallsPass::createBindMemRefCalls(
 LogicalResult VulkanLaunchFuncToVulkanCallsPass::deduceMemRefRankAndType(
     Value ptrToMemRefDescriptor, uint32_t &rank, LLVM::LLVMType &type) {
   auto llvmPtrDescriptorTy =
-      ptrToMemRefDescriptor.getType().dyn_cast<LLVM::LLVMType>();
+      ptrToMemRefDescriptor.getType().dyn_cast<LLVM::LLVMPointerType>();
   if (!llvmPtrDescriptorTy)
     return failure();
 
-  auto llvmDescriptorTy = llvmPtrDescriptorTy.getPointerElementTy();
+  auto llvmDescriptorTy =
+      llvmPtrDescriptorTy.getElementType().dyn_cast<LLVM::LLVMStructType>();
   // template <typename Elem, size_t Rank>
   // struct {
   //   Elem *allocated;
@@ -270,15 +273,19 @@ LogicalResult VulkanLaunchFuncToVulkanCallsPass::deduceMemRefRankAndType(
   //   int64_t sizes[Rank]; // omitted when rank == 0
   //   int64_t strides[Rank]; // omitted when rank == 0
   // };
-  if (!llvmDescriptorTy || !llvmDescriptorTy.isStructTy())
+  if (!llvmDescriptorTy)
     return failure();
 
-  type = llvmDescriptorTy.getStructElementType(0).getPointerElementTy();
-  if (llvmDescriptorTy.getStructNumElements() == 3) {
+  type = llvmDescriptorTy.getBody()[0]
+             .cast<LLVM::LLVMPointerType>()
+             .getElementType();
+  if (llvmDescriptorTy.getBody().size() == 3) {
     rank = 0;
     return success();
   }
-  rank = llvmDescriptorTy.getStructElementType(3).getArrayNumElements();
+  rank = llvmDescriptorTy.getBody()[3]
+             .cast<LLVM::LLVMArrayType>()
+             .getNumElements();
   return success();
 }
 
@@ -326,13 +333,13 @@ void VulkanLaunchFuncToVulkanCallsPass::declareVulkanFunctions(Location loc) {
                                 LLVM::LLVMType::getHalfTy(&getContext())}) {
       std::string fnName = "bindMemRef" + std::to_string(i) + "D" +
                            std::string(stringifyType(type));
-      if (type.isHalfTy())
+      if (type.isa<LLVM::LLVMHalfType>())
         type = LLVM::LLVMType::getInt16Ty(&getContext());
       if (!module.lookupSymbol(fnName)) {
         auto fnType = LLVM::LLVMType::getFunctionTy(
             getVoidType(),
             {getPointerType(), getInt32Type(), getInt32Type(),
-             getMemRefType(i, type).getPointerTo()},
+             LLVM::LLVMPointerType::get(getMemRefType(i, type))},
             /*isVarArg=*/false);
         builder.create<LLVM::LLVMFuncOp>(loc, fnName, fnType);
       }
index cacb4787edd4eff47ccabd5c446b7cc2ef752630..7da9c47f9219926abc42f1acd4e7798fcbe5d792 100644 (file)
@@ -66,8 +66,10 @@ static unsigned getBitWidth(Type type) {
 
 /// Returns the bit width of LLVMType integer or vector.
 static unsigned getLLVMTypeBitWidth(LLVM::LLVMType type) {
-  return type.isVectorTy() ? type.getVectorElementType().getIntegerBitWidth()
-                           : type.getIntegerBitWidth();
+  auto vectorType = type.dyn_cast<LLVM::LLVMVectorType>();
+  return (vectorType ? vectorType.getElementType() : type)
+      .cast<LLVM::LLVMIntegerType>()
+      .getBitWidth();
 }
 
 /// Creates `IntegerAttribute` with all bits set for given type
@@ -265,7 +267,7 @@ static Type convertPointerType(spirv::PointerType type,
                                TypeConverter &converter) {
   auto pointeeType =
       converter.convertType(type.getPointeeType()).cast<LLVM::LLVMType>();
-  return pointeeType.getPointerTo();
+  return LLVM::LLVMPointerType::get(pointeeType);
 }
 
 /// Converts SPIR-V runtime array to LLVM array. Since LLVM allows indexing over
index 6fbcc220a86b08e061b37d54355b258c567eece6..e37e7e2dc0c11421c5071fdadc461f10e2b87c8a 100644 (file)
@@ -215,7 +215,7 @@ Type LLVMTypeConverter::convertFunctionType(FunctionType type) {
   SignatureConversion conversion(type.getNumInputs());
   LLVM::LLVMType converted =
       convertFunctionSignature(type, /*isVariadic=*/false, conversion);
-  return converted.getPointerTo();
+  return LLVM::LLVMPointerType::get(converted);
 }
 
 
@@ -267,7 +267,7 @@ LLVMTypeConverter::convertFunctionTypeCWrapper(FunctionType type) {
     if (!converted)
       return {};
     if (t.isa<MemRefType, UnrankedMemRefType>())
-      converted = converted.getPointerTo();
+      converted = LLVM::LLVMPointerType::get(converted);
     inputs.push_back(converted);
   }
 
@@ -324,7 +324,7 @@ LLVMTypeConverter::getMemRefDescriptorFields(MemRefType type,
   LLVM::LLVMType elementType = unwrap(convertType(type.getElementType()));
   if (!elementType)
     return {};
-  auto ptrTy = elementType.getPointerTo(type.getMemorySpace());
+  auto ptrTy = LLVM::LLVMPointerType::get(elementType, type.getMemorySpace());
   auto indexTy = getIndexType();
 
   SmallVector<LLVM::LLVMType, 5> results = {ptrTy, ptrTy, indexTy};
@@ -396,7 +396,7 @@ Type LLVMTypeConverter::convertMemRefToBarePtr(BaseMemRefType type) {
   LLVM::LLVMType elementType = unwrap(convertType(type.getElementType()));
   if (!elementType)
     return {};
-  return elementType.getPointerTo(type.getMemorySpace());
+  return LLVM::LLVMPointerType::get(elementType, type.getMemorySpace());
 }
 
 // Convert an n-D vector type to an LLVM vector type via (n-1)-D array type when
@@ -460,7 +460,7 @@ StructBuilder::StructBuilder(Value v) : value(v) {
 
 Value StructBuilder::extractPtr(OpBuilder &builder, Location loc,
                                 unsigned pos) {
-  Type type = structType.cast<LLVM::LLVMType>().getStructElementType(pos);
+  Type type = structType.cast<LLVM::LLVMStructType>().getBody()[pos];
   return builder.create<LLVM::ExtractValueOp>(loc, type, value,
                                               builder.getI64ArrayAttr(pos));
 }
@@ -507,8 +507,9 @@ Value ComplexStructBuilder::imaginary(OpBuilder &builder, Location loc) {
 MemRefDescriptor::MemRefDescriptor(Value descriptor)
     : StructBuilder(descriptor) {
   assert(value != nullptr && "value cannot be null");
-  indexType = value.getType().cast<LLVM::LLVMType>().getStructElementType(
-      kOffsetPosInMemRefDescriptor);
+  indexType = value.getType()
+                  .cast<LLVM::LLVMStructType>()
+                  .getBody()[kOffsetPosInMemRefDescriptor];
 }
 
 /// Builds IR creating an `undef` value of the descriptor type.
@@ -618,9 +619,9 @@ Value MemRefDescriptor::size(OpBuilder &builder, Location loc, unsigned pos) {
 Value MemRefDescriptor::size(OpBuilder &builder, Location loc, Value pos,
                              int64_t rank) {
   auto indexTy = indexType.cast<LLVM::LLVMType>();
-  auto indexPtrTy = indexTy.getPointerTo();
+  auto indexPtrTy = LLVM::LLVMPointerType::get(indexTy);
   auto arrayTy = LLVM::LLVMType::getArrayTy(indexTy, rank);
-  auto arrayPtrTy = arrayTy.getPointerTo();
+  auto arrayPtrTy = LLVM::LLVMPointerType::get(arrayTy);
 
   // Copy size values to stack-allocated memory.
   auto zero = createIndexAttrConstant(builder, loc, indexType, 0);
@@ -675,8 +676,8 @@ void MemRefDescriptor::setConstantStride(OpBuilder &builder, Location loc,
 
 LLVM::LLVMPointerType MemRefDescriptor::getElementPtrType() {
   return value.getType()
-      .cast<LLVM::LLVMType>()
-      .getStructElementType(kAlignedPtrPosInMemRefDescriptor)
+      .cast<LLVM::LLVMStructType>()
+      .getBody()[kAlignedPtrPosInMemRefDescriptor]
       .cast<LLVM::LLVMPointerType>();
 }
 
@@ -922,7 +923,7 @@ Value UnrankedMemRefDescriptor::offset(OpBuilder &builder, Location loc,
   Value offsetGep = builder.create<LLVM::GEPOp>(
       loc, elemPtrPtrType, elementPtrPtr, ValueRange({two}));
   offsetGep = builder.create<LLVM::BitcastOp>(
-      loc, typeConverter.getIndexType().getPointerTo(), offsetGep);
+      loc, LLVM::LLVMPointerType::get(typeConverter.getIndexType()), offsetGep);
   return builder.create<LLVM::LoadOp>(loc, offsetGep);
 }
 
@@ -939,19 +940,17 @@ void UnrankedMemRefDescriptor::setOffset(OpBuilder &builder, Location loc,
   Value offsetGep = builder.create<LLVM::GEPOp>(
       loc, elemPtrPtrType, elementPtrPtr, ValueRange({two}));
   offsetGep = builder.create<LLVM::BitcastOp>(
-      loc, typeConverter.getIndexType().getPointerTo(), offsetGep);
+      loc, LLVM::LLVMPointerType::get(typeConverter.getIndexType()), offsetGep);
   builder.create<LLVM::StoreOp>(loc, offset, offsetGep);
 }
 
-Value UnrankedMemRefDescriptor::sizeBasePtr(OpBuilder &builder, Location loc,
-                                            LLVMTypeConverter &typeConverter,
-                                            Value memRefDescPtr,
-                                            LLVM::LLVMType elemPtrPtrType) {
-  LLVM::LLVMType elemPtrTy = elemPtrPtrType.getPointerElementTy();
+Value UnrankedMemRefDescriptor::sizeBasePtr(
+    OpBuilder &builder, Location loc, LLVMTypeConverter &typeConverter,
+    Value memRefDescPtr, LLVM::LLVMPointerType elemPtrPtrType) {
+  LLVM::LLVMType elemPtrTy = elemPtrPtrType.getElementType();
   LLVM::LLVMType indexTy = typeConverter.getIndexType();
-  LLVM::LLVMType structPtrTy =
-      LLVM::LLVMType::getStructTy(elemPtrTy, elemPtrTy, indexTy, indexTy)
-          .getPointerTo();
+  LLVM::LLVMType structPtrTy = LLVM::LLVMPointerType::get(
+      LLVM::LLVMType::getStructTy(elemPtrTy, elemPtrTy, indexTy, indexTy));
   Value structPtr =
       builder.create<LLVM::BitcastOp>(loc, structPtrTy, memRefDescPtr);
 
@@ -961,14 +960,15 @@ Value UnrankedMemRefDescriptor::sizeBasePtr(OpBuilder &builder, Location loc,
       createIndexAttrConstant(builder, loc, typeConverter.getIndexType(), 0);
   Value three = builder.create<LLVM::ConstantOp>(loc, int32_type,
                                                  builder.getI32IntegerAttr(3));
-  return builder.create<LLVM::GEPOp>(loc, indexTy.getPointerTo(), structPtr,
-                                     ValueRange({zero, three}));
+  return builder.create<LLVM::GEPOp>(loc, LLVM::LLVMPointerType::get(indexTy),
+                                     structPtr, ValueRange({zero, three}));
 }
 
 Value UnrankedMemRefDescriptor::size(OpBuilder &builder, Location loc,
                                      LLVMTypeConverter typeConverter,
                                      Value sizeBasePtr, Value index) {
-  LLVM::LLVMType indexPtrTy = typeConverter.getIndexType().getPointerTo();
+  LLVM::LLVMType indexPtrTy =
+      LLVM::LLVMPointerType::get(typeConverter.getIndexType());
   Value sizeStoreGep = builder.create<LLVM::GEPOp>(loc, indexPtrTy, sizeBasePtr,
                                                    ValueRange({index}));
   return builder.create<LLVM::LoadOp>(loc, sizeStoreGep);
@@ -978,7 +978,8 @@ void UnrankedMemRefDescriptor::setSize(OpBuilder &builder, Location loc,
                                        LLVMTypeConverter typeConverter,
                                        Value sizeBasePtr, Value index,
                                        Value size) {
-  LLVM::LLVMType indexPtrTy = typeConverter.getIndexType().getPointerTo();
+  LLVM::LLVMType indexPtrTy =
+      LLVM::LLVMPointerType::get(typeConverter.getIndexType());
   Value sizeStoreGep = builder.create<LLVM::GEPOp>(loc, indexPtrTy, sizeBasePtr,
                                                    ValueRange({index}));
   builder.create<LLVM::StoreOp>(loc, size, sizeStoreGep);
@@ -987,7 +988,8 @@ void UnrankedMemRefDescriptor::setSize(OpBuilder &builder, Location loc,
 Value UnrankedMemRefDescriptor::strideBasePtr(OpBuilder &builder, Location loc,
                                               LLVMTypeConverter &typeConverter,
                                               Value sizeBasePtr, Value rank) {
-  LLVM::LLVMType indexPtrTy = typeConverter.getIndexType().getPointerTo();
+  LLVM::LLVMType indexPtrTy =
+      LLVM::LLVMPointerType::get(typeConverter.getIndexType());
   return builder.create<LLVM::GEPOp>(loc, indexPtrTy, sizeBasePtr,
                                      ValueRange({rank}));
 }
@@ -996,7 +998,8 @@ Value UnrankedMemRefDescriptor::stride(OpBuilder &builder, Location loc,
                                        LLVMTypeConverter typeConverter,
                                        Value strideBasePtr, Value index,
                                        Value stride) {
-  LLVM::LLVMType indexPtrTy = typeConverter.getIndexType().getPointerTo();
+  LLVM::LLVMType indexPtrTy =
+      LLVM::LLVMPointerType::get(typeConverter.getIndexType());
   Value strideStoreGep = builder.create<LLVM::GEPOp>(
       loc, indexPtrTy, strideBasePtr, ValueRange({index}));
   return builder.create<LLVM::LoadOp>(loc, strideStoreGep);
@@ -1006,7 +1009,8 @@ void UnrankedMemRefDescriptor::setStride(OpBuilder &builder, Location loc,
                                          LLVMTypeConverter typeConverter,
                                          Value strideBasePtr, Value index,
                                          Value stride) {
-  LLVM::LLVMType indexPtrTy = typeConverter.getIndexType().getPointerTo();
+  LLVM::LLVMType indexPtrTy =
+      LLVM::LLVMPointerType::get(typeConverter.getIndexType());
   Value strideStoreGep = builder.create<LLVM::GEPOp>(
       loc, indexPtrTy, strideBasePtr, ValueRange({index}));
   builder.create<LLVM::StoreOp>(loc, stride, strideStoreGep);
@@ -1100,7 +1104,7 @@ bool ConvertToLLVMPattern::isSupportedMemRefType(MemRefType type) const {
 Type ConvertToLLVMPattern::getElementPtrType(MemRefType type) const {
   auto elementType = type.getElementType();
   auto structElementType = unwrap(typeConverter->convertType(elementType));
-  return structElementType.getPointerTo(type.getMemorySpace());
+  return LLVM::LLVMPointerType::get(structElementType, type.getMemorySpace());
 }
 
 void ConvertToLLVMPattern::getMemRefDescriptorSizes(
@@ -1158,8 +1162,8 @@ Value ConvertToLLVMPattern::getSizeInBytes(
   //   %0 = getelementptr %elementType* null, %indexType 1
   //   %1 = ptrtoint %elementType* %0 to %indexType
   // which is a common pattern of getting the size of a type in bytes.
-  auto convertedPtrType =
-      typeConverter->convertType(type).cast<LLVM::LLVMType>().getPointerTo();
+  auto convertedPtrType = LLVM::LLVMPointerType::get(
+      typeConverter->convertType(type).cast<LLVM::LLVMType>());
   auto nullPtr = rewriter.create<LLVM::NullOp>(loc, convertedPtrType);
   auto gep = rewriter.create<LLVM::GEPOp>(
       loc, convertedPtrType,
@@ -1315,7 +1319,8 @@ static void wrapExternalFunction(OpBuilder &builder, Location loc,
                     builder, loc, typeConverter, unrankedMemRefType,
                     wrapperArgsRange.take_front(numToDrop));
 
-      auto ptrTy = packed.getType().cast<LLVM::LLVMType>().getPointerTo();
+      auto ptrTy =
+          LLVM::LLVMPointerType::get(packed.getType().cast<LLVM::LLVMType>());
       Value one = builder.create<LLVM::ConstantOp>(
           loc, typeConverter.convertType(builder.getIndexType()),
           builder.getIntegerAttr(builder.getIndexType(), 1));
@@ -1512,11 +1517,12 @@ static NDVectorTypeInfo extractNDVectorTypeInfo(VectorType vectorType,
     return info;
   info.arraySizes.reserve(vectorType.getRank() - 1);
   auto llvmTy = info.llvmArrayTy;
-  while (llvmTy.isArrayTy()) {
-    info.arraySizes.push_back(llvmTy.getArrayNumElements());
-    llvmTy = llvmTy.getArrayElementType();
+  while (llvmTy.isa<LLVM::LLVMArrayType>()) {
+    info.arraySizes.push_back(
+        llvmTy.cast<LLVM::LLVMArrayType>().getNumElements());
+    llvmTy = llvmTy.cast<LLVM::LLVMArrayType>().getElementType();
   }
-  if (!llvmTy.isVectorTy())
+  if (!llvmTy.isa<LLVM::LLVMVectorType>())
     return info;
   info.llvmVectorTy = llvmTy;
   return info;
@@ -1644,7 +1650,7 @@ LogicalResult LLVM::detail::vectorOneToOneRewrite(
     return failure();
 
   auto llvmArrayTy = operands[0].getType().cast<LLVM::LLVMType>();
-  if (!llvmArrayTy.isArrayTy())
+  if (!llvmArrayTy.isa<LLVM::LLVMArrayType>())
     return oneToOneRewrite(op, targetOp, operands, typeConverter, rewriter);
 
   auto callback = [op, targetOp, &rewriter](LLVM::LLVMType llvmVectorTy,
@@ -2457,13 +2463,14 @@ struct GetGlobalMemrefOpLowering : public AllocLikeOpLowering {
     LLVM::LLVMType arrayTy =
         convertGlobalMemrefTypeToLLVM(type, *getTypeConverter());
     auto addressOf = rewriter.create<LLVM::AddressOfOp>(
-        loc, arrayTy.getPointerTo(memSpace), getGlobalOp.name());
+        loc, LLVM::LLVMPointerType::get(arrayTy, memSpace), getGlobalOp.name());
 
     // Get the address of the first element in the array by creating a GEP with
     // the address of the GV as the base, and (rank + 1) number of 0 indices.
     LLVM::LLVMType elementType =
         unwrap(typeConverter->convertType(type.getElementType()));
-    LLVM::LLVMType elementPtrType = elementType.getPointerTo(memSpace);
+    LLVM::LLVMType elementPtrType =
+        LLVM::LLVMPointerType::get(elementType, memSpace);
 
     SmallVector<Value, 4> operands = {addressOf};
     operands.insert(operands.end(), type.getRank() + 1,
@@ -2504,9 +2511,9 @@ struct RsqrtOpLowering : public ConvertOpToLLVMPattern<RsqrtOp> {
     auto floatType = getElementTypeOrSelf(resultType).cast<FloatType>();
     auto floatOne = rewriter.getFloatAttr(floatType, 1.0);
 
-    if (!operandType.isArrayTy()) {
+    if (!operandType.isa<LLVM::LLVMArrayType>()) {
       LLVM::ConstantOp one;
-      if (operandType.isVectorTy()) {
+      if (operandType.isa<LLVM::LLVMVectorType>()) {
         one = rewriter.create<LLVM::ConstantOp>(
             loc, operandType,
             SplatElementsAttr::get(resultType.cast<ShapedType>(), floatOne));
@@ -2526,8 +2533,10 @@ struct RsqrtOpLowering : public ConvertOpToLLVMPattern<RsqrtOp> {
         op.getOperation(), operands, *getTypeConverter(),
         [&](LLVM::LLVMType llvmVectorTy, ValueRange operands) {
           auto splatAttr = SplatElementsAttr::get(
-              mlir::VectorType::get({llvmVectorTy.getVectorNumElements()},
-                                    floatType),
+              mlir::VectorType::get(
+                  {llvmVectorTy.cast<LLVM::LLVMFixedVectorType>()
+                       .getNumElements()},
+                  floatType),
               floatOne);
           auto one =
               rewriter.create<LLVM::ConstantOp>(loc, llvmVectorTy, splatAttr);
@@ -2614,12 +2623,13 @@ struct MemRefCastOpLowering : public ConvertOpToLLVMPattern<MemRefCastOp> {
       // ptr = ExtractValueOp src, 1
       auto ptr = memRefDesc.memRefDescPtr(rewriter, loc);
       // castPtr = BitCastOp i8* to structTy*
-      auto castPtr =
-          rewriter
-              .create<LLVM::BitcastOp>(
-                  loc, targetStructType.cast<LLVM::LLVMType>().getPointerTo(),
-                  ptr)
-              .getResult();
+      auto castPtr = rewriter
+                         .create<LLVM::BitcastOp>(
+                             loc,
+                             LLVM::LLVMPointerType::get(
+                                 targetStructType.cast<LLVM::LLVMType>()),
+                             ptr)
+                         .getResult();
       // struct = LoadOp castPtr
       auto loadOp = rewriter.create<LLVM::LoadOp>(loc, castPtr);
       rewriter.replaceOp(memRefCastOp, loadOp.getResult());
@@ -2654,8 +2664,8 @@ static void extractPointersAndOffset(Location loc,
   Type elementType = operandType.cast<UnrankedMemRefType>().getElementType();
   LLVM::LLVMType llvmElementType =
       unwrap(typeConverter.convertType(elementType));
-  LLVM::LLVMType elementPtrPtrType =
-      llvmElementType.getPointerTo(memorySpace).getPointerTo();
+  LLVM::LLVMType elementPtrPtrType = LLVM::LLVMPointerType::get(
+      LLVM::LLVMPointerType::get(llvmElementType, memorySpace));
 
   // Extract pointer to the underlying ranked memref descriptor and cast it to
   // ElemType**.
@@ -2700,8 +2710,8 @@ private:
     MemRefType targetMemRefType =
         castOp.getResult().getType().cast<MemRefType>();
     auto llvmTargetDescriptorTy = typeConverter->convertType(targetMemRefType)
-                                      .dyn_cast_or_null<LLVM::LLVMType>();
-    if (!llvmTargetDescriptorTy || !llvmTargetDescriptorTy.isStructTy())
+                                      .dyn_cast_or_null<LLVM::LLVMStructType>();
+    if (!llvmTargetDescriptorTy)
       return failure();
 
     // Create descriptor.
@@ -2804,8 +2814,8 @@ private:
     // Set pointers and offset.
     LLVM::LLVMType llvmElementType =
         unwrap(typeConverter->convertType(elementType));
-    LLVM::LLVMType elementPtrPtrType =
-        llvmElementType.getPointerTo(addressSpace).getPointerTo();
+    auto elementPtrPtrType = LLVM::LLVMPointerType::get(
+        LLVM::LLVMPointerType::get(llvmElementType, addressSpace));
     UnrankedMemRefDescriptor::setAllocatedPtr(rewriter, loc, underlyingDescPtr,
                                               elementPtrPtrType, allocatedPtr);
     UnrankedMemRefDescriptor::setAlignedPtr(rewriter, loc, *getTypeConverter(),
@@ -2858,7 +2868,7 @@ private:
     rewriter.setInsertionPointToStart(bodyBlock);
 
     // Copy size from shape to descriptor.
-    LLVM::LLVMType llvmIndexPtrType = indexType.getPointerTo();
+    LLVM::LLVMType llvmIndexPtrType = LLVM::LLVMPointerType::get(indexType);
     Value sizeLoadGep = rewriter.create<LLVM::GEPOp>(
         loc, llvmIndexPtrType, shapeOperandPtr, ValueRange{indexArg});
     Value size = rewriter.create<LLVM::LoadOp>(loc, sizeLoadGep);
@@ -2950,14 +2960,14 @@ private:
     Value underlyingRankedDesc = unrankedDesc.memRefDescPtr(rewriter, loc);
     Value scalarMemRefDescPtr = rewriter.create<LLVM::BitcastOp>(
         loc,
-        typeConverter->convertType(scalarMemRefType)
-            .cast<LLVM::LLVMType>()
-            .getPointerTo(addressSpace),
+        LLVM::LLVMPointerType::get(
+            typeConverter->convertType(scalarMemRefType).cast<LLVM::LLVMType>(),
+            addressSpace),
         underlyingRankedDesc);
 
     // Get pointer to offset field of memref<element_type> descriptor.
-    Type indexPtrTy =
-        getTypeConverter()->getIndexType().getPointerTo(addressSpace);
+    Type indexPtrTy = LLVM::LLVMPointerType::get(
+        getTypeConverter()->getIndexType()addressSpace);
     Value two = rewriter.create<LLVM::ConstantOp>(
         loc, typeConverter->convertType(rewriter.getI32Type()),
         rewriter.getI32IntegerAttr(2));
@@ -3120,10 +3130,10 @@ struct IndexCastOpLowering : public ConvertOpToLLVMPattern<IndexCastOp> {
 
     auto targetType =
         typeConverter->convertType(indexCastOp.getResult().getType())
-            .cast<LLVM::LLVMType>();
-    auto sourceType = transformed.in().getType().cast<LLVM::LLVMType>();
-    unsigned targetBits = targetType.getIntegerBitWidth();
-    unsigned sourceBits = sourceType.getIntegerBitWidth();
+            .cast<LLVM::LLVMIntegerType>();
+    auto sourceType = transformed.in().getType().cast<LLVM::LLVMIntegerType>();
+    unsigned targetBits = targetType.getBitWidth();
+    unsigned sourceBits = sourceType.getBitWidth();
 
     if (targetBits == sourceBits)
       rewriter.replaceOp(indexCastOp, transformed.in());
@@ -3462,14 +3472,18 @@ struct SubViewOpLowering : public ConvertOpToLLVMPattern<SubViewOp> {
     // Copy the buffer pointer from the old descriptor to the new one.
     Value extracted = sourceMemRef.allocatedPtr(rewriter, loc);
     Value bitcastPtr = rewriter.create<LLVM::BitcastOp>(
-        loc, targetElementTy.getPointerTo(viewMemRefType.getMemorySpace()),
+        loc,
+        LLVM::LLVMPointerType::get(targetElementTy,
+                                   viewMemRefType.getMemorySpace()),
         extracted);
     targetMemRef.setAllocatedPtr(rewriter, loc, bitcastPtr);
 
     // Copy the buffer pointer from the old descriptor to the new one.
     extracted = sourceMemRef.alignedPtr(rewriter, loc);
     bitcastPtr = rewriter.create<LLVM::BitcastOp>(
-        loc, targetElementTy.getPointerTo(viewMemRefType.getMemorySpace()),
+        loc,
+        LLVM::LLVMPointerType::get(targetElementTy,
+                                   viewMemRefType.getMemorySpace()),
         extracted);
     targetMemRef.setAlignedPtr(rewriter, loc, bitcastPtr);
 
@@ -3662,7 +3676,9 @@ struct ViewOpLowering : public ConvertOpToLLVMPattern<ViewOp> {
     Value allocatedPtr = sourceMemRef.allocatedPtr(rewriter, loc);
     auto srcMemRefType = viewOp.source().getType().cast<MemRefType>();
     Value bitcastPtr = rewriter.create<LLVM::BitcastOp>(
-        loc, targetElementTy.getPointerTo(srcMemRefType.getMemorySpace()),
+        loc,
+        LLVM::LLVMPointerType::get(targetElementTy,
+                                   srcMemRefType.getMemorySpace()),
         allocatedPtr);
     targetMemRef.setAllocatedPtr(rewriter, loc, bitcastPtr);
 
@@ -3671,7 +3687,9 @@ struct ViewOpLowering : public ConvertOpToLLVMPattern<ViewOp> {
     alignedPtr = rewriter.create<LLVM::GEPOp>(loc, alignedPtr.getType(),
                                               alignedPtr, adaptor.byte_shift());
     bitcastPtr = rewriter.create<LLVM::BitcastOp>(
-        loc, targetElementTy.getPointerTo(srcMemRefType.getMemorySpace()),
+        loc,
+        LLVM::LLVMPointerType::get(targetElementTy,
+                                   srcMemRefType.getMemorySpace()),
         alignedPtr);
     targetMemRef.setAlignedPtr(rewriter, loc, bitcastPtr);
 
@@ -4064,7 +4082,8 @@ Value LLVMTypeConverter::promoteOneMemRefDescriptor(Location loc, Value operand,
   auto indexType = IndexType::get(context);
   // Alloca with proper alignment. We do not expect optimizations of this
   // alloca op and so we omit allocating at the entry block.
-  auto ptrType = operand.getType().cast<LLVM::LLVMType>().getPointerTo();
+  auto ptrType =
+      LLVM::LLVMPointerType::get(operand.getType().cast<LLVM::LLVMType>());
   Value one = builder.create<LLVM::ConstantOp>(loc, int64Ty,
                                                IntegerAttr::get(indexType, 1));
   Value allocated =
index a982b90e0e93bc3714bcc2c06fad73357bd84b08..bcc91e304e72dece69e7cd5b7f91ab78e5323f07 100644 (file)
@@ -193,7 +193,7 @@ static LogicalResult getBasePtr(ConversionPatternRewriter &rewriter,
   Value base;
   if (failed(getBase(rewriter, loc, memref, memRefType, base)))
     return failure();
-  auto pType = type.template cast<LLVM::LLVMType>().getPointerTo();
+  auto pType = LLVM::LLVMPointerType::get(type.template cast<LLVM::LLVMType>());
   base = rewriter.create<LLVM::BitcastOp>(loc, pType, base);
   ptr = rewriter.create<LLVM::GEPOp>(loc, pType, base);
   return success();
@@ -1100,14 +1100,14 @@ public:
       return failure();
 
     auto llvmSourceDescriptorTy =
-        operands[0].getType().dyn_cast<LLVM::LLVMType>();
-    if (!llvmSourceDescriptorTy || !llvmSourceDescriptorTy.isStructTy())
+        operands[0].getType().dyn_cast<LLVM::LLVMStructType>();
+    if (!llvmSourceDescriptorTy)
       return failure();
     MemRefDescriptor sourceMemRef(operands[0]);
 
     auto llvmTargetDescriptorTy = typeConverter->convertType(targetMemRefType)
-                                      .dyn_cast_or_null<LLVM::LLVMType>();
-    if (!llvmTargetDescriptorTy || !llvmTargetDescriptorTy.isStructTy())
+                                      .dyn_cast_or_null<LLVM::LLVMStructType>();
+    if (!llvmTargetDescriptorTy)
       return failure();
 
     // Only contiguous source buffers supported atm.
@@ -1231,15 +1231,15 @@ public:
     // TODO: support alignment when possible.
     Value dataPtr = this->getStridedElementPtr(
         loc, memRefType, adaptor.source(), adaptor.indices(), rewriter);
-    auto vecTy =
-        toLLVMTy(xferOp.getVectorType()).template cast<LLVM::LLVMType>();
+    auto vecTy = toLLVMTy(xferOp.getVectorType())
+                     .template cast<LLVM::LLVMFixedVectorType>();
     Value vectorDataPtr;
     if (memRefType.getMemorySpace() == 0)
-      vectorDataPtr =
-          rewriter.create<LLVM::BitcastOp>(loc, vecTy.getPointerTo(), dataPtr);
+      vectorDataPtr = rewriter.create<LLVM::BitcastOp>(
+          loc, LLVM::LLVMPointerType::get(vecTy), dataPtr);
     else
       vectorDataPtr = rewriter.create<LLVM::AddrSpaceCastOp>(
-          loc, vecTy.getPointerTo(), dataPtr);
+          loc, LLVM::LLVMPointerType::get(vecTy), dataPtr);
 
     if (!xferOp.isMaskedDim(0))
       return replaceTransferOpWithLoadOrStore(rewriter,
@@ -1253,7 +1253,7 @@ public:
     //
     // TODO: when the leaf transfer rank is k > 1, we need the last `k`
     //       dimensions here.
-    unsigned vecWidth = vecTy.getVectorNumElements();
+    unsigned vecWidth = vecTy.getNumElements();
     unsigned lastIndex = llvm::size(xferOp.indices()) - 1;
     Value off = xferOp.indices()[lastIndex];
     Value dim = rewriter.create<DimOp>(loc, xferOp.source(), lastIndex);
index 973b116ef498ca62127bc37db380aec81a3fc0c4..1335f33e10aa73e593fc8f6266f2fc42893b6d7a 100644 (file)
@@ -78,9 +78,9 @@ public:
     auto toLLVMTy = [&](Type t) {
       return this->getTypeConverter()->convertType(t);
     };
-    LLVM::LLVMType vecTy =
-        toLLVMTy(xferOp.getVectorType()).template cast<LLVM::LLVMType>();
-    unsigned vecWidth = vecTy.getVectorNumElements();
+    auto vecTy = toLLVMTy(xferOp.getVectorType())
+                     .template cast<LLVM::LLVMFixedVectorType>();
+    unsigned vecWidth = vecTy.getNumElements();
     Location loc = xferOp->getLoc();
 
     // The backend result vector scalarization have trouble scalarize
index 7b1300da1783f8a1792581db631508b7469f8c17..2bdbb877ec84c72d15b5ed75a1af66f64cb9491e 100644 (file)
@@ -105,9 +105,10 @@ static ParseResult parseCmpOp(OpAsmParser &parser, OperationState &result) {
   auto argType = type.dyn_cast<LLVM::LLVMType>();
   if (!argType)
     return parser.emitError(trailingTypeLoc, "expected LLVM IR dialect type");
-  if (argType.isVectorTy())
-    resultType =
-        LLVMType::getVectorTy(resultType, argType.getVectorNumElements());
+  if (auto vecArgType = argType.dyn_cast<LLVM::LLVMFixedVectorType>())
+    resultType = LLVMType::getVectorTy(resultType, vecArgType.getNumElements());
+  assert(!argType.isa<LLVM::LLVMScalableVectorType>() &&
+         "unhandled scalable vector");
 
   result.addTypes({resultType});
   return success();
@@ -118,7 +119,7 @@ static ParseResult parseCmpOp(OpAsmParser &parser, OperationState &result) {
 //===----------------------------------------------------------------------===//
 
 static void printAllocaOp(OpAsmPrinter &p, AllocaOp &op) {
-  auto elemTy = op.getType().cast<LLVM::LLVMType>().getPointerElementTy();
+  auto elemTy = op.getType().cast<LLVM::LLVMPointerType>().getElementType();
 
   auto funcTy = FunctionType::get(op.getContext(), {op.arraySize().getType()},
                                   {op.getType()});
@@ -363,14 +364,11 @@ static void printLoadOp(OpAsmPrinter &p, LoadOp &op) {
 // the resulting type wrapped in MLIR, or nullptr on error.
 static Type getLoadStoreElementType(OpAsmParser &parser, Type type,
                                     llvm::SMLoc trailingTypeLoc) {
-  auto llvmTy = type.dyn_cast<LLVM::LLVMType>();
+  auto llvmTy = type.dyn_cast<LLVM::LLVMPointerType>();
   if (!llvmTy)
-    return parser.emitError(trailingTypeLoc, "expected LLVM IR dialect type"),
-           nullptr;
-  if (!llvmTy.isPointerTy())
     return parser.emitError(trailingTypeLoc, "expected LLVM pointer type"),
            nullptr;
-  return llvmTy.getPointerElementTy();
+  return llvmTy.getElementType();
 }
 
 // <operation> ::= `llvm.load` `volatile` ssa-use attribute-dict? `:` type
@@ -569,7 +567,7 @@ static ParseResult parseInvokeOp(OpAsmParser &parser, OperationState &result) {
 
     auto llvmFuncType = LLVM::LLVMType::getFunctionTy(llvmResultType, argTypes,
                                                       /*isVarArg=*/false);
-    auto wrappedFuncType = llvmFuncType.getPointerTo();
+    auto wrappedFuncType = LLVM::LLVMPointerType::get(llvmFuncType);
 
     auto funcArguments = llvm::makeArrayRef(operands).drop_front();
 
@@ -613,7 +611,7 @@ static LogicalResult verify(LandingpadOp op) {
 
   for (unsigned idx = 0, ie = op.getNumOperands(); idx < ie; idx++) {
     value = op.getOperand(idx);
-    bool isFilter = value.getType().cast<LLVMType>().isArrayTy();
+    bool isFilter = value.getType().isa<LLVMArrayType>();
     if (isFilter) {
       // FIXME: Verify filter clauses when arrays are appropriately handled
     } else {
@@ -646,7 +644,7 @@ static void printLandingpadOp(OpAsmPrinter &p, LandingpadOp &op) {
   for (auto value : op.getOperands()) {
     // Similar to llvm - if clause is an array type then it is filter
     // clause else catch clause
-    bool isArrayTy = value.getType().cast<LLVMType>().isArrayTy();
+    bool isArrayTy = value.getType().isa<LLVMArrayType>();
     p << '(' << (isArrayTy ? "filter " : "catch ") << value << " : "
       << value.getType() << ") ";
   }
@@ -728,37 +726,37 @@ static LogicalResult verify(CallOp &op) {
 
     fnType = fn.getType();
   }
-  if (!fnType.isFunctionTy())
+
+  LLVMFunctionType funcType = fnType.dyn_cast<LLVMFunctionType>();
+  if (!funcType)
     return op.emitOpError("callee does not have a functional type: ") << fnType;
 
   // Verify that the operand and result types match the callee.
 
-  if (!fnType.isFunctionVarArg() &&
-      fnType.getFunctionNumParams() != (op.getNumOperands() - isIndirect))
+  if (!funcType.isVarArg() &&
+      funcType.getNumParams() != (op.getNumOperands() - isIndirect))
     return op.emitOpError()
            << "incorrect number of operands ("
            << (op.getNumOperands() - isIndirect)
-           << ") for callee (expecting: " << fnType.getFunctionNumParams()
-           << ")";
+           << ") for callee (expecting: " << funcType.getNumParams() << ")";
 
-  if (fnType.getFunctionNumParams() > (op.getNumOperands() - isIndirect))
+  if (funcType.getNumParams() > (op.getNumOperands() - isIndirect))
     return op.emitOpError() << "incorrect number of operands ("
                             << (op.getNumOperands() - isIndirect)
                             << ") for varargs callee (expecting at least: "
-                            << fnType.getFunctionNumParams() << ")";
+                            << funcType.getNumParams() << ")";
 
-  for (unsigned i = 0, e = fnType.getFunctionNumParams(); i != e; ++i)
-    if (op.getOperand(i + isIndirect).getType() !=
-        fnType.getFunctionParamType(i))
+  for (unsigned i = 0, e = funcType.getNumParams(); i != e; ++i)
+    if (op.getOperand(i + isIndirect).getType() != funcType.getParamType(i))
       return op.emitOpError() << "operand type mismatch for operand " << i
                               << ": " << op.getOperand(i + isIndirect).getType()
-                              << " != " << fnType.getFunctionParamType(i);
+                              << " != " << funcType.getParamType(i);
 
   if (op.getNumResults() &&
-      op.getResult(0).getType() != fnType.getFunctionResultType())
+      op.getResult(0).getType() != funcType.getReturnType())
     return op.emitOpError()
            << "result type mismatch: " << op.getResult(0).getType()
-           << " != " << fnType.getFunctionResultType();
+           << " != " << funcType.getReturnType();
 
   return success();
 }
@@ -848,7 +846,7 @@ static ParseResult parseCallOp(OpAsmParser &parser, OperationState &result) {
     }
     auto llvmFuncType = LLVM::LLVMType::getFunctionTy(llvmResultType, argTypes,
                                                       /*isVarArg=*/false);
-    auto wrappedFuncType = llvmFuncType.getPointerTo();
+    auto wrappedFuncType = LLVM::LLVMPointerType::get(llvmFuncType);
 
     auto funcArguments =
         ArrayRef<OpAsmParser::OperandType>(operands).drop_front();
@@ -875,8 +873,8 @@ static ParseResult parseCallOp(OpAsmParser &parser, OperationState &result) {
 void LLVM::ExtractElementOp::build(OpBuilder &b, OperationState &result,
                                    Value vector, Value position,
                                    ArrayRef<NamedAttribute> attrs) {
-  auto wrappedVectorType = vector.getType().cast<LLVM::LLVMType>();
-  auto llvmType = wrappedVectorType.getVectorElementType();
+  auto vectorType = vector.getType().cast<LLVM::LLVMVectorType>();
+  auto llvmType = vectorType.getElementType();
   build(b, result, llvmType, vector, position);
   result.addAttributes(attrs);
 }
@@ -903,11 +901,11 @@ static ParseResult parseExtractElementOp(OpAsmParser &parser,
       parser.resolveOperand(vector, type, result.operands) ||
       parser.resolveOperand(position, positionType, result.operands))
     return failure();
-  auto wrappedVectorType = type.dyn_cast<LLVM::LLVMType>();
-  if (!wrappedVectorType || !wrappedVectorType.isVectorTy())
+  auto vectorType = type.dyn_cast<LLVM::LLVMVectorType>();
+  if (!vectorType)
     return parser.emitError(
         loc, "expected LLVM IR dialect vector type for operand #1");
-  result.addTypes(wrappedVectorType.getVectorElementType());
+  result.addTypes(vectorType.getElementType());
   return success();
 }
 
@@ -930,8 +928,8 @@ static LLVM::LLVMType getInsertExtractValueElementType(OpAsmParser &parser,
                                                        ArrayAttr positionAttr,
                                                        llvm::SMLoc attributeLoc,
                                                        llvm::SMLoc typeLoc) {
-  auto wrappedContainerType = containerType.dyn_cast<LLVM::LLVMType>();
-  if (!wrappedContainerType)
+  auto llvmType = containerType.dyn_cast<LLVM::LLVMType>();
+  if (!llvmType)
     return parser.emitError(typeLoc, "expected LLVM IR Dialect type"), nullptr;
 
   // Infer the element type from the structure type: iteratively step inside the
@@ -945,26 +943,24 @@ static LLVM::LLVMType getInsertExtractValueElementType(OpAsmParser &parser,
                               "expected an array of integer literals"),
              nullptr;
     int position = positionElementAttr.getInt();
-    if (wrappedContainerType.isArrayTy()) {
-      if (position < 0 || static_cast<unsigned>(position) >=
-                              wrappedContainerType.getArrayNumElements())
+    if (auto arrayType = llvmType.dyn_cast<LLVMArrayType>()) {
+      if (position < 0 ||
+          static_cast<unsigned>(position) >= arrayType.getNumElements())
         return parser.emitError(attributeLoc, "position out of bounds"),
                nullptr;
-      wrappedContainerType = wrappedContainerType.getArrayElementType();
-    } else if (wrappedContainerType.isStructTy()) {
-      if (position < 0 || static_cast<unsigned>(position) >=
-                              wrappedContainerType.getStructNumElements())
+      llvmType = arrayType.getElementType();
+    } else if (auto structType = llvmType.dyn_cast<LLVMStructType>()) {
+      if (position < 0 ||
+          static_cast<unsigned>(position) >= structType.getBody().size())
         return parser.emitError(attributeLoc, "position out of bounds"),
                nullptr;
-      wrappedContainerType =
-          wrappedContainerType.getStructElementType(position);
+      llvmType = structType.getBody()[position];
     } else {
-      return parser.emitError(typeLoc,
-                              "expected wrapped LLVM IR structure/array type"),
+      return parser.emitError(typeLoc, "expected LLVM IR structure/array type"),
              nullptr;
     }
   }
-  return wrappedContainerType;
+  return llvmType;
 }
 
 // <operation> ::= `llvm.extractvalue` ssa-use
@@ -1021,11 +1017,11 @@ static ParseResult parseInsertElementOp(OpAsmParser &parser,
       parser.parseColonType(vectorType))
     return failure();
 
-  auto wrappedVectorType = vectorType.dyn_cast<LLVM::LLVMType>();
-  if (!wrappedVectorType || !wrappedVectorType.isVectorTy())
+  auto llvmVectorType = vectorType.dyn_cast<LLVM::LLVMVectorType>();
+  if (!llvmVectorType)
     return parser.emitError(
         loc, "expected LLVM IR dialect vector type for operand #1");
-  auto valueType = wrappedVectorType.getVectorElementType();
+  Type valueType = llvmVectorType.getElementType();
   if (!valueType)
     return failure();
 
@@ -1145,12 +1141,14 @@ static LogicalResult verify(AddressOfOp op) {
     return op.emitOpError(
         "must reference a global defined by 'llvm.mlir.global' or 'llvm.func'");
 
-  if (global && global.getType().getPointerTo(global.addr_space()) !=
-                    op.getResult().getType())
+  if (global &&
+      LLVM::LLVMPointerType::get(global.getType(), global.addr_space()) !=
+          op.getResult().getType())
     return op.emitOpError(
         "the type must be a pointer to the type of the referenced global");
 
-  if (function && function.getType().getPointerTo() != op.getResult().getType())
+  if (function && LLVM::LLVMPointerType::get(function.getType()) !=
+                      op.getResult().getType())
     return op.emitOpError(
         "the type must be a pointer to the type of the referenced function");
 
@@ -1276,11 +1274,11 @@ static LogicalResult verifyCast(DialectCastOp op, LLVMType llvmType,
     if (vectorType.getRank() != 1)
       return op->emitOpError("only 1-d vector is allowed");
 
-    auto llvmVector = llvmType.dyn_cast<LLVMVectorType>();
-    if (llvmVector.isa<LLVMScalableVectorType>())
+    auto llvmVector = llvmType.dyn_cast<LLVMFixedVectorType>();
+    if (!llvmVector)
       return op->emitOpError("only fixed-sized vector is allowed");
 
-    if (vectorType.getDimSize(0) != llvmVector.getVectorNumElements())
+    if (vectorType.getDimSize(0) != llvmVector.getNumElements())
       return op->emitOpError(
           "invalid cast between vectors with mismatching sizes");
 
@@ -1375,7 +1373,10 @@ static LogicalResult verifyCast(DialectCastOp op, LLVMType llvmType,
                              "be an index-compatible integer");
 
     auto ptrType = structType.getBody()[1].dyn_cast<LLVMPointerType>();
-    if (!ptrType || !ptrType.getPointerElementTy().isIntegerTy(8))
+    auto ptrElementType =
+        ptrType ? ptrType.getElementType().dyn_cast<LLVMIntegerType>()
+                : nullptr;
+    if (!ptrElementType || ptrElementType.getBitWidth() != 8)
       return op->emitOpError("expected second element of a memref descriptor "
                              "to be an !llvm.ptr<i8>");
 
@@ -1503,9 +1504,11 @@ static LogicalResult verify(GlobalOp op) {
     return op.emitOpError("must appear at the module level");
 
   if (auto strAttr = op.getValueOrNull().dyn_cast_or_null<StringAttr>()) {
-    auto type = op.getType();
-    if (!type.isArrayTy() || !type.getArrayElementType().isIntegerTy(8) ||
-        type.getArrayNumElements() != strAttr.getValue().size())
+    auto type = op.getType().dyn_cast<LLVMArrayType>();
+    LLVMIntegerType elementType =
+        type ? type.getElementType().dyn_cast<LLVMIntegerType>() : nullptr;
+    if (!elementType || elementType.getBitWidth() != 8 ||
+        type.getNumElements() != strAttr.getValue().size())
       return op.emitOpError(
           "requires an i8 array type of the length equal to that of the string "
           "attribute");
@@ -1534,9 +1537,9 @@ static LogicalResult verify(GlobalOp op) {
 void LLVM::ShuffleVectorOp::build(OpBuilder &b, OperationState &result,
                                   Value v1, Value v2, ArrayAttr mask,
                                   ArrayRef<NamedAttribute> attrs) {
-  auto wrappedContainerType1 = v1.getType().cast<LLVM::LLVMType>();
-  auto vType = LLVMType::getVectorTy(
-      wrappedContainerType1.getVectorElementType(), mask.size());
+  auto containerType = v1.getType().cast<LLVM::LLVMVectorType>();
+  auto vType =
+      LLVMType::getVectorTy(containerType.getElementType(), mask.size());
   build(b, result, vType, v1, v2, mask);
   result.addAttributes(attrs);
 }
@@ -1566,12 +1569,12 @@ static ParseResult parseShuffleVectorOp(OpAsmParser &parser,
       parser.resolveOperand(v1, typeV1, result.operands) ||
       parser.resolveOperand(v2, typeV2, result.operands))
     return failure();
-  auto wrappedContainerType1 = typeV1.dyn_cast<LLVM::LLVMType>();
-  if (!wrappedContainerType1 || !wrappedContainerType1.isVectorTy())
+  auto containerType = typeV1.dyn_cast<LLVM::LLVMVectorType>();
+  if (!containerType)
     return parser.emitError(
         loc, "expected LLVM IR dialect vector type for operand #1");
-  auto vType = LLVMType::getVectorTy(
-      wrappedContainerType1.getVectorElementType(), maskAttr.size());
+  auto vType =
+      LLVMType::getVectorTy(containerType.getElementType(), maskAttr.size());
   result.addTypes(vType);
   return success();
 }
@@ -1588,9 +1591,9 @@ Block *LLVMFuncOp::addEntryBlock() {
   auto *entry = new Block;
   push_back(entry);
 
-  LLVMType type = getType();
-  for (unsigned i = 0, e = type.getFunctionNumParams(); i < e; ++i)
-    entry->addArgument(type.getFunctionParamType(i));
+  LLVMFunctionType type = getType();
+  for (unsigned i = 0, e = type.getNumParams(); i < e; ++i)
+    entry->addArgument(type.getParamType(i));
   return entry;
 }
 
@@ -1608,7 +1611,7 @@ void LLVMFuncOp::build(OpBuilder &builder, OperationState &result,
   if (argAttrs.empty())
     return;
 
-  unsigned numInputs = type.getFunctionNumParams();
+  unsigned numInputs = type.cast<LLVMFunctionType>().getNumParams();
   assert(numInputs == argAttrs.size() &&
          "expected as many argument attribute lists as arguments");
   SmallString<8> argAttrName;
@@ -1711,15 +1714,15 @@ static void printLLVMFuncOp(OpAsmPrinter &p, LLVMFuncOp op) {
     p << stringifyLinkage(op.linkage()) << ' ';
   p.printSymbolName(op.getName());
 
-  LLVMType fnType = op.getType();
+  LLVMFunctionType fnType = op.getType();
   SmallVector<Type, 8> argTypes;
   SmallVector<Type, 1> resTypes;
-  argTypes.reserve(fnType.getFunctionNumParams());
-  for (unsigned i = 0, e = fnType.getFunctionNumParams(); i < e; ++i)
-    argTypes.push_back(fnType.getFunctionParamType(i));
+  argTypes.reserve(fnType.getNumParams());
+  for (unsigned i = 0, e = fnType.getNumParams(); i < e; ++i)
+    argTypes.push_back(fnType.getParamType(i));
 
-  LLVMType returnType = fnType.getFunctionResultType();
-  if (!returnType.isVoidTy())
+  LLVMType returnType = fnType.getReturnType();
+  if (!returnType.isa<LLVMVoidType>())
     resTypes.push_back(returnType);
 
   impl::printFunctionSignature(p, op, argTypes, op.isVarArg(), resTypes);
@@ -1737,8 +1740,8 @@ static void printLLVMFuncOp(OpAsmPrinter &p, LLVMFuncOp op) {
 // attribute is present.  This can check for preconditions of the
 // getNumArguments hook not failing.
 LogicalResult LLVMFuncOp::verifyType() {
-  auto llvmType = getTypeAttr().getValue().dyn_cast_or_null<LLVMType>();
-  if (!llvmType || !llvmType.isFunctionTy())
+  auto llvmType = getTypeAttr().getValue().dyn_cast_or_null<LLVMFunctionType>();
+  if (!llvmType)
     return emitOpError("requires '" + getTypeAttrName() +
                        "' attribute of wrapped LLVM function type");
 
@@ -1747,9 +1750,7 @@ LogicalResult LLVMFuncOp::verifyType() {
 
 // Hook for OpTrait::FunctionLike, returns the number of function arguments.
 // Depends on the type attribute being correct as checked by verifyType
-unsigned LLVMFuncOp::getNumFuncArguments() {
-  return getType().getFunctionNumParams();
-}
+unsigned LLVMFuncOp::getNumFuncArguments() { return getType().getNumParams(); }
 
 // Hook for OpTrait::FunctionLike, returns the number of function results.
 // Depends on the type attribute being correct as checked by verifyType
@@ -1759,7 +1760,7 @@ unsigned LLVMFuncOp::getNumFuncResults() {
   // If we modeled a void return as one result, then it would be possible to
   // attach an MLIR result attribute to it, and it isn't clear what semantics we
   // would assign to that.
-  if (getType().getFunctionResultType().isVoidTy())
+  if (getType().getReturnType().isa<LLVMVoidType>())
     return 0;
   return 1;
 }
@@ -1788,7 +1789,7 @@ static LogicalResult verify(LLVMFuncOp op) {
   if (op.isVarArg())
     return op.emitOpError("only external functions can be variadic");
 
-  unsigned numArguments = op.getType().getFunctionNumParams();
+  unsigned numArguments = op.getType().getNumParams();
   Block &entryBlock = op.front();
   for (unsigned i = 0; i < numArguments; ++i) {
     Type argType = entryBlock.getArgument(i).getType();
@@ -1796,7 +1797,7 @@ static LogicalResult verify(LLVMFuncOp op) {
     if (!argLLVMType)
       return op.emitOpError("entry block argument #")
              << i << " is not of LLVM type";
-    if (op.getType().getFunctionParamType(i) != argLLVMType)
+    if (op.getType().getParamType(i) != argLLVMType)
       return op.emitOpError("the type of entry block argument #")
              << i << " does not match the function signature";
   }
@@ -1896,7 +1897,8 @@ static ParseResult parseAtomicRMWOp(OpAsmParser &parser,
       parseAtomicOrdering(parser, result, "ordering") ||
       parser.parseOptionalAttrDict(result.attributes) ||
       parser.parseColonType(type) ||
-      parser.resolveOperand(ptr, type.getPointerTo(), result.operands) ||
+      parser.resolveOperand(ptr, LLVM::LLVMPointerType::get(type),
+                            result.operands) ||
       parser.resolveOperand(val, type, result.operands))
     return failure();
 
@@ -1905,9 +1907,9 @@ static ParseResult parseAtomicRMWOp(OpAsmParser &parser,
 }
 
 static LogicalResult verify(AtomicRMWOp op) {
-  auto ptrType = op.ptr().getType().cast<LLVM::LLVMType>();
+  auto ptrType = op.ptr().getType().cast<LLVM::LLVMPointerType>();
   auto valType = op.val().getType().cast<LLVM::LLVMType>();
-  if (valType != ptrType.getPointerElementTy())
+  if (valType != ptrType.getElementType())
     return op.emitOpError("expected LLVM IR element type for operand #0 to "
                           "match type for operand #1");
   auto resType = op.res().getType().cast<LLVM::LLVMType>();
@@ -1915,17 +1917,21 @@ static LogicalResult verify(AtomicRMWOp op) {
     return op.emitOpError(
         "expected LLVM IR result type to match type for operand #1");
   if (op.bin_op() == AtomicBinOp::fadd || op.bin_op() == AtomicBinOp::fsub) {
-    if (!valType.isFloatingPointTy())
+    if (!mlir::LLVM::isCompatibleFloatingPointType(valType))
       return op.emitOpError("expected LLVM IR floating point type");
   } else if (op.bin_op() == AtomicBinOp::xchg) {
-    if (!valType.isIntegerTy(8) && !valType.isIntegerTy(16) &&
-        !valType.isIntegerTy(32) && !valType.isIntegerTy(64) &&
-        !valType.isBFloatTy() && !valType.isHalfTy() && !valType.isFloatTy() &&
-        !valType.isDoubleTy())
+    auto intType = valType.dyn_cast<LLVMIntegerType>();
+    unsigned intBitWidth = intType ? intType.getBitWidth() : 0;
+    if (intBitWidth != 8 && intBitWidth != 16 && intBitWidth != 32 &&
+        intBitWidth != 64 && !valType.isa<LLVMBFloatType>() &&
+        !valType.isa<LLVMHalfType>() && !valType.isa<LLVMFloatType>() &&
+        !valType.isa<LLVMDoubleType>())
       return op.emitOpError("unexpected LLVM IR type for 'xchg' bin_op");
   } else {
-    if (!valType.isIntegerTy(8) && !valType.isIntegerTy(16) &&
-        !valType.isIntegerTy(32) && !valType.isIntegerTy(64))
+    auto intType = valType.dyn_cast<LLVMIntegerType>();
+    unsigned intBitWidth = intType ? intType.getBitWidth() : 0;
+    if (intBitWidth != 8 && intBitWidth != 16 && intBitWidth != 32 &&
+        intBitWidth != 64)
       return op.emitOpError("expected LLVM IR integer type");
   }
   return success();
@@ -1958,7 +1964,8 @@ static ParseResult parseAtomicCmpXchgOp(OpAsmParser &parser,
       parseAtomicOrdering(parser, result, "failure_ordering") ||
       parser.parseOptionalAttrDict(result.attributes) ||
       parser.parseColonType(type) ||
-      parser.resolveOperand(ptr, type.getPointerTo(), result.operands) ||
+      parser.resolveOperand(ptr, LLVM::LLVMPointerType::get(type),
+                            result.operands) ||
       parser.resolveOperand(cmp, type, result.operands) ||
       parser.resolveOperand(val, type, result.operands))
     return failure();
@@ -1971,18 +1978,20 @@ static ParseResult parseAtomicCmpXchgOp(OpAsmParser &parser,
 }
 
 static LogicalResult verify(AtomicCmpXchgOp op) {
-  auto ptrType = op.ptr().getType().cast<LLVM::LLVMType>();
-  if (!ptrType.isPointerTy())
+  auto ptrType = op.ptr().getType().cast<LLVM::LLVMPointerType>();
+  if (!ptrType)
     return op.emitOpError("expected LLVM IR pointer type for operand #0");
   auto cmpType = op.cmp().getType().cast<LLVM::LLVMType>();
   auto valType = op.val().getType().cast<LLVM::LLVMType>();
-  if (cmpType != ptrType.getPointerElementTy() || cmpType != valType)
+  if (cmpType != ptrType.getElementType() || cmpType != valType)
     return op.emitOpError("expected LLVM IR element type for operand #0 to "
                           "match type for all other operands");
-  if (!valType.isPointerTy() && !valType.isIntegerTy(8) &&
-      !valType.isIntegerTy(16) && !valType.isIntegerTy(32) &&
-      !valType.isIntegerTy(64) && !valType.isBFloatTy() &&
-      !valType.isHalfTy() && !valType.isFloatTy() && !valType.isDoubleTy())
+  auto intType = valType.dyn_cast<LLVMIntegerType>();
+  unsigned intBitWidth = intType ? intType.getBitWidth() : 0;
+  if (!valType.isa<LLVMPointerType>() && intBitWidth != 8 &&
+      intBitWidth != 16 && intBitWidth != 32 && intBitWidth != 64 &&
+      !valType.isa<LLVMBFloatType>() && !valType.isa<LLVMHalfType>() &&
+      !valType.isa<LLVMFloatType>() && !valType.isa<LLVMDoubleType>())
     return op.emitOpError("unexpected LLVM IR type");
   if (op.success_ordering() < AtomicOrdering::monotonic ||
       op.failure_ordering() < AtomicOrdering::monotonic)
index a89287b764e5d714ca475d4badb6fd24bc0b016e..0616efb7ef3f998fd8fd653b158915c611b7121a 100644 (file)
@@ -36,129 +36,6 @@ LLVMDialect &LLVMType::getDialect() {
   return static_cast<LLVMDialect &>(Type::getDialect());
 }
 
-//----------------------------------------------------------------------------//
-// Misc type utilities.
-
-llvm::TypeSize LLVMType::getPrimitiveSizeInBits() {
-  return llvm::TypeSwitch<LLVMType, llvm::TypeSize>(*this)
-      .Case<LLVMHalfType, LLVMBFloatType>(
-          [](LLVMType) { return llvm::TypeSize::Fixed(16); })
-      .Case<LLVMFloatType>([](LLVMType) { return llvm::TypeSize::Fixed(32); })
-      .Case<LLVMDoubleType, LLVMX86MMXType>(
-          [](LLVMType) { return llvm::TypeSize::Fixed(64); })
-      .Case<LLVMIntegerType>([](LLVMIntegerType intTy) {
-        return llvm::TypeSize::Fixed(intTy.getBitWidth());
-      })
-      .Case<LLVMX86FP80Type>([](LLVMType) { return llvm::TypeSize::Fixed(80); })
-      .Case<LLVMPPCFP128Type, LLVMFP128Type>(
-          [](LLVMType) { return llvm::TypeSize::Fixed(128); })
-      .Case<LLVMVectorType>([](LLVMVectorType t) {
-        llvm::TypeSize elementSize =
-            t.getElementType().getPrimitiveSizeInBits();
-        llvm::ElementCount elementCount = t.getElementCount();
-        assert(!elementSize.isScalable() &&
-               "vector type should have fixed-width elements");
-        return llvm::TypeSize(elementSize.getFixedSize() *
-                                  elementCount.getKnownMinValue(),
-                              elementCount.isScalable());
-      })
-      .Default([](LLVMType ty) {
-        assert((ty.isa<LLVMVoidType, LLVMLabelType, LLVMMetadataType,
-                       LLVMTokenType, LLVMStructType, LLVMArrayType,
-                       LLVMPointerType, LLVMFunctionType>()) &&
-               "unexpected missing support for primitive type");
-        return llvm::TypeSize::Fixed(0);
-      });
-}
-
-//----------------------------------------------------------------------------//
-// Integer type utilities.
-
-bool LLVMType::isIntegerTy(unsigned bitwidth) {
-  if (auto intType = dyn_cast<LLVMIntegerType>())
-    return intType.getBitWidth() == bitwidth;
-  return false;
-}
-unsigned LLVMType::getIntegerBitWidth() {
-  return cast<LLVMIntegerType>().getBitWidth();
-}
-
-LLVMType LLVMType::getArrayElementType() {
-  return cast<LLVMArrayType>().getElementType();
-}
-
-//----------------------------------------------------------------------------//
-// Array type utilities.
-
-unsigned LLVMType::getArrayNumElements() {
-  return cast<LLVMArrayType>().getNumElements();
-}
-
-bool LLVMType::isArrayTy() { return isa<LLVMArrayType>(); }
-
-//----------------------------------------------------------------------------//
-// Vector type utilities.
-
-LLVMType LLVMType::getVectorElementType() {
-  return cast<LLVMVectorType>().getElementType();
-}
-
-unsigned LLVMType::getVectorNumElements() {
-  return cast<LLVMFixedVectorType>().getNumElements();
-}
-llvm::ElementCount LLVMType::getVectorElementCount() {
-  return cast<LLVMVectorType>().getElementCount();
-}
-
-bool LLVMType::isVectorTy() { return isa<LLVMVectorType>(); }
-
-//----------------------------------------------------------------------------//
-// Function type utilities.
-
-LLVMType LLVMType::getFunctionParamType(unsigned argIdx) {
-  return cast<LLVMFunctionType>().getParamType(argIdx);
-}
-
-unsigned LLVMType::getFunctionNumParams() {
-  return cast<LLVMFunctionType>().getNumParams();
-}
-
-LLVMType LLVMType::getFunctionResultType() {
-  return cast<LLVMFunctionType>().getReturnType();
-}
-
-bool LLVMType::isFunctionTy() { return isa<LLVMFunctionType>(); }
-
-bool LLVMType::isFunctionVarArg() {
-  return cast<LLVMFunctionType>().isVarArg();
-}
-
-//----------------------------------------------------------------------------//
-// Pointer type utilities.
-
-LLVMType LLVMType::getPointerTo(unsigned addrSpace) {
-  return LLVMPointerType::get(*this, addrSpace);
-}
-
-LLVMType LLVMType::getPointerElementTy() {
-  return cast<LLVMPointerType>().getElementType();
-}
-
-bool LLVMType::isPointerTy() { return isa<LLVMPointerType>(); }
-
-//----------------------------------------------------------------------------//
-// Struct type utilities.
-
-LLVMType LLVMType::getStructElementType(unsigned i) {
-  return cast<LLVMStructType>().getBody()[i];
-}
-
-unsigned LLVMType::getStructNumElements() {
-  return cast<LLVMStructType>().getBody().size();
-}
-
-bool LLVMType::isStructTy() { return isa<LLVMStructType>(); }
-
 //----------------------------------------------------------------------------//
 // Utilities used to generate floating point types.
 
@@ -193,6 +70,10 @@ LLVMType LLVMType::getIntNTy(MLIRContext *context, unsigned numBits) {
   return LLVMIntegerType::get(context, numBits);
 }
 
+LLVMType LLVMType::getInt8PtrTy(MLIRContext *context) {
+  return LLVMPointerType::get(LLVMIntegerType::get(context, 8));
+}
+
 //----------------------------------------------------------------------------//
 // Utilities used to generate other miscellaneous types.
 
@@ -221,8 +102,6 @@ LLVMType LLVMType::getVoidTy(MLIRContext *context) {
   return LLVMVoidType::get(context);
 }
 
-bool LLVMType::isVoidTy() { return isa<LLVMVoidType>(); }
-
 //----------------------------------------------------------------------------//
 // Creation and setting of LLVM's identified struct types
 
@@ -470,7 +349,7 @@ LLVMStructType::verifyConstructionInvariants(Location loc,
 
 bool LLVMVectorType::isValidElementType(LLVMType type) {
   return type.isa<LLVMIntegerType, LLVMPointerType>() ||
-         type.isFloatingPointTy();
+         mlir::LLVM::isCompatibleFloatingPointType(type);
 }
 
 /// Support type casting functionality.
@@ -536,3 +415,42 @@ LLVMScalableVectorType::getChecked(Location loc, LLVMType elementType,
 unsigned LLVMScalableVectorType::getMinNumElements() {
   return getImpl()->numElements;
 }
+
+//===----------------------------------------------------------------------===//
+// Utility functions.
+//===----------------------------------------------------------------------===//
+
+llvm::TypeSize mlir::LLVM::getPrimitiveTypeSizeInBits(Type type) {
+  assert(isCompatibleType(type) &&
+         "expected a type compatible with the LLVM dialect");
+
+  return llvm::TypeSwitch<Type, llvm::TypeSize>(type)
+      .Case<LLVMHalfType, LLVMBFloatType>(
+          [](LLVMType) { return llvm::TypeSize::Fixed(16); })
+      .Case<LLVMFloatType>([](LLVMType) { return llvm::TypeSize::Fixed(32); })
+      .Case<LLVMDoubleType, LLVMX86MMXType>(
+          [](LLVMType) { return llvm::TypeSize::Fixed(64); })
+      .Case<LLVMIntegerType>([](LLVMIntegerType intTy) {
+        return llvm::TypeSize::Fixed(intTy.getBitWidth());
+      })
+      .Case<LLVMX86FP80Type>([](LLVMType) { return llvm::TypeSize::Fixed(80); })
+      .Case<LLVMPPCFP128Type, LLVMFP128Type>(
+          [](LLVMType) { return llvm::TypeSize::Fixed(128); })
+      .Case<LLVMVectorType>([](LLVMVectorType t) {
+        llvm::TypeSize elementSize =
+            getPrimitiveTypeSizeInBits(t.getElementType());
+        llvm::ElementCount elementCount = t.getElementCount();
+        assert(!elementSize.isScalable() &&
+               "vector type should have fixed-width elements");
+        return llvm::TypeSize(elementSize.getFixedSize() *
+                                  elementCount.getKnownMinValue(),
+                              elementCount.isScalable());
+      })
+      .Default([](Type ty) {
+        assert((ty.isa<LLVMVoidType, LLVMLabelType, LLVMMetadataType,
+                       LLVMTokenType, LLVMStructType, LLVMArrayType,
+                       LLVMPointerType, LLVMFunctionType>()) &&
+               "unexpected missing support for primitive type");
+        return llvm::TypeSize::Fixed(0);
+      });
+}
index 707ff7c1b089badd458799b14adc079c8e097b50..c202075fa2066b2ef289265db0e3d12182af501e 100644 (file)
@@ -57,8 +57,9 @@ static ParseResult parseNVVMShflSyncBflyOp(OpAsmParser &parser,
   for (auto &attr : result.attributes) {
     if (attr.first != "return_value_and_is_valid")
       continue;
-    if (type.isStructTy() && type.getStructNumElements() > 0)
-      type = type.getStructElementType(0);
+    auto structType = type.dyn_cast<LLVM::LLVMStructType>();
+    if (structType && !structType.getBody().empty())
+      type = structType.getBody()[0];
     break;
   }
 
index a323e68170c1b739bd03d68a3a61e80917998f1a..bfdae2b4588d597d1daa9d840452b53005c64d63 100644 (file)
@@ -196,19 +196,30 @@ template <typename Type>
 Error checkCompatibleReturnType(LLVM::LLVMFuncOp mainFunction);
 template <>
 Error checkCompatibleReturnType<int32_t>(LLVM::LLVMFuncOp mainFunction) {
-  if (!mainFunction.getType().getFunctionResultType().isIntegerTy(32))
+  auto resultType = mainFunction.getType()
+                        .cast<LLVM::LLVMFunctionType>()
+                        .getReturnType()
+                        .dyn_cast<LLVM::LLVMIntegerType>();
+  if (!resultType || resultType.getBitWidth() != 32)
     return make_string_error("only single llvm.i32 function result supported");
   return Error::success();
 }
 template <>
 Error checkCompatibleReturnType<int64_t>(LLVM::LLVMFuncOp mainFunction) {
-  if (!mainFunction.getType().getFunctionResultType().isIntegerTy(64))
+  auto resultType = mainFunction.getType()
+                        .cast<LLVM::LLVMFunctionType>()
+                        .getReturnType()
+                        .dyn_cast<LLVM::LLVMIntegerType>();
+  if (!resultType || resultType.getBitWidth() != 64)
     return make_string_error("only single llvm.i64 function result supported");
   return Error::success();
 }
 template <>
 Error checkCompatibleReturnType<float>(LLVM::LLVMFuncOp mainFunction) {
-  if (!mainFunction.getType().getFunctionResultType().isFloatTy())
+  if (!mainFunction.getType()
+           .cast<LLVM::LLVMFunctionType>()
+           .getReturnType()
+           .isa<LLVM::LLVMFloatType>())
     return make_string_error("only single llvm.f32 function result supported");
   return Error::success();
 }
@@ -220,7 +231,7 @@ Error compileAndExecuteSingleReturnFunction(Options &options, ModuleOp module,
   if (!mainFunction || mainFunction.isExternal())
     return make_string_error("entry point not found");
 
-  if (mainFunction.getType().getFunctionNumParams() != 0)
+  if (mainFunction.getType().cast<LLVM::LLVMFunctionType>().getNumParams() != 0)
     return make_string_error("function inputs not supported");
 
   if (Error error = checkCompatibleReturnType<Type>(mainFunction))
index 7f89a41de5db723b6e77b9ecf4b669ef2ea0fc7b..9786751ef4b0db45300c8c0539543394d8aa3c61 100644 (file)
@@ -172,57 +172,57 @@ Type Importer::getStdTypeForAttr(LLVMType type) {
   if (!type)
     return nullptr;
 
-  if (type.isIntegerTy())
-    return b.getIntegerType(type.getIntegerBitWidth());
+  if (auto intType = type.dyn_cast<LLVMIntegerType>())
+    return b.getIntegerType(intType.getBitWidth());
 
-  if (type.isFloatTy())
+  if (type.isa<LLVMFloatType>())
     return b.getF32Type();
 
-  if (type.isDoubleTy())
+  if (type.isa<LLVMDoubleType>())
     return b.getF64Type();
 
   // LLVM vectors can only contain scalars.
-  if (type.isVectorTy()) {
-    auto numElements = type.getVectorElementCount();
+  if (auto vectorType = type.dyn_cast<LLVM::LLVMVectorType>()) {
+    auto numElements = vectorType.getElementCount();
     if (numElements.isScalable()) {
       emitError(unknownLoc) << "scalable vectors not supported";
       return nullptr;
     }
-    Type elementType = getStdTypeForAttr(type.getVectorElementType());
+    Type elementType = getStdTypeForAttr(vectorType.getElementType());
     if (!elementType)
       return nullptr;
     return VectorType::get(numElements.getKnownMinValue(), elementType);
   }
 
   // LLVM arrays can contain other arrays or vectors.
-  if (type.isArrayTy()) {
+  if (auto arrayType = type.dyn_cast<LLVMArrayType>()) {
     // Recover the nested array shape.
     SmallVector<int64_t, 4> shape;
-    shape.push_back(type.getArrayNumElements());
-    while (type.getArrayElementType().isArrayTy()) {
-      type = type.getArrayElementType();
-      shape.push_back(type.getArrayNumElements());
+    shape.push_back(arrayType.getNumElements());
+    while (arrayType.getElementType().isa<LLVMArrayType>()) {
+      arrayType = arrayType.getElementType().cast<LLVMArrayType>();
+      shape.push_back(arrayType.getNumElements());
     }
 
     // If the innermost type is a vector, use the multi-dimensional vector as
     // attribute type.
-    if (type.getArrayElementType().isVectorTy()) {
-      LLVMType vectorType = type.getArrayElementType();
-      auto numElements = vectorType.getVectorElementCount();
+    if (auto vectorType =
+            arrayType.getElementType().dyn_cast<LLVMVectorType>()) {
+      auto numElements = vectorType.getElementCount();
       if (numElements.isScalable()) {
         emitError(unknownLoc) << "scalable vectors not supported";
         return nullptr;
       }
       shape.push_back(numElements.getKnownMinValue());
 
-      Type elementType = getStdTypeForAttr(vectorType.getVectorElementType());
+      Type elementType = getStdTypeForAttr(vectorType.getElementType());
       if (!elementType)
         return nullptr;
       return VectorType::get(shape, elementType);
     }
 
     // Otherwise use a tensor.
-    Type elementType = getStdTypeForAttr(type.getArrayElementType());
+    Type elementType = getStdTypeForAttr(arrayType.getElementType());
     if (!elementType)
       return nullptr;
     return RankedTensorType::get(shape, elementType);
@@ -261,7 +261,7 @@ Attribute Importer::getConstantAsAttr(llvm::Constant *value) {
     if (!attrType)
       return nullptr;
 
-    if (type.isIntegerTy()) {
+    if (type.isa<LLVMIntegerType>()) {
       SmallVector<APInt, 8> values;
       values.reserve(cd->getNumElements());
       for (unsigned i = 0, e = cd->getNumElements(); i < e; ++i)
@@ -269,7 +269,7 @@ Attribute Importer::getConstantAsAttr(llvm::Constant *value) {
       return DenseElementsAttr::get(attrType, values);
     }
 
-    if (type.isFloatTy() || type.isDoubleTy()) {
+    if (type.isa<LLVMFloatType>() || type.isa<LLVMDoubleType>()) {
       SmallVector<APFloat, 8> values;
       values.reserve(cd->getNumElements());
       for (unsigned i = 0, e = cd->getNumElements(); i < e; ++i)
@@ -777,7 +777,8 @@ LogicalResult Importer::processFunction(llvm::Function *f) {
   instMap.clear();
   unknownInstMap.clear();
 
-  LLVMType functionType = processType(f->getFunctionType());
+  auto functionType =
+      processType(f->getFunctionType()).dyn_cast<LLVMFunctionType>();
   if (!functionType)
     return failure();
 
@@ -805,8 +806,8 @@ LogicalResult Importer::processFunction(llvm::Function *f) {
 
   // Add function arguments to the entry block.
   for (auto kv : llvm::enumerate(f->args()))
-    instMap[&kv.value()] = blockList[0]->addArgument(
-        functionType.getFunctionParamType(kv.index()));
+    instMap[&kv.value()] =
+        blockList[0]->addArgument(functionType.getParamType(kv.index()));
 
   for (auto bbs : llvm::zip(*f, blockList)) {
     if (failed(processBasicBlock(&std::get<0>(bbs), std::get<1>(bbs))))
index 8c650506e2d74d631b11300bc982437d21e97322..ae0745b0be28f4daffb1b00103dd0b64aade15b7 100644 (file)
@@ -969,7 +969,7 @@ LogicalResult ModuleTranslation::convertOneFunction(LLVMFuncOp func) {
       // NB: Attribute already verified to be boolean, so check if we can indeed
       // attach the attribute to this argument, based on its type.
       auto argTy = mlirArg.getType().dyn_cast<LLVM::LLVMType>();
-      if (!argTy.isPointerTy())
+      if (!argTy.isa<LLVM::LLVMPointerType>())
         return func.emitError(
             "llvm.noalias attribute attached to LLVM non-pointer argument");
       if (attr.getValue())
@@ -981,7 +981,7 @@ LogicalResult ModuleTranslation::convertOneFunction(LLVMFuncOp func) {
       // NB: Attribute already verified to be int, so check if we can indeed
       // attach the attribute to this argument, based on its type.
       auto argTy = mlirArg.getType().dyn_cast<LLVM::LLVMType>();
-      if (!argTy.isPointerTy())
+      if (!argTy.isa<LLVM::LLVMPointerType>())
         return func.emitError(
             "llvm.align attribute attached to LLVM non-pointer argument");
       llvmArg.addAttrs(
index 9461ebbd9ede92f6c74a3026433860c5c6efe5e0..d02c252c0bf3601f8a53efafc3f8c0822d8e938c 100644 (file)
@@ -98,7 +98,7 @@ func @gep_non_function_type(%pos : !llvm.i64, %base : !llvm.ptr<float>) {
 // -----
 
 func @load_non_llvm_type(%foo : memref<f32>) {
-  // expected-error@+1 {{expected LLVM IR dialect type}}
+  // expected-error@+1 {{expected LLVM pointer type}}
   llvm.load %foo : memref<f32>
 }
 
@@ -112,7 +112,7 @@ func @load_non_ptr_type(%foo : !llvm.float) {
 // -----
 
 func @store_non_llvm_type(%foo : memref<f32>, %bar : !llvm.float) {
-  // expected-error@+1 {{expected LLVM IR dialect type}}
+  // expected-error@+1 {{expected LLVM pointer type}}
   llvm.store %bar, %foo : memref<f32>
 }
 
@@ -267,7 +267,7 @@ func @insertvalue_array_out_of_bounds() {
 // -----
 
 func @insertvalue_wrong_nesting() {
-  // expected-error@+1 {{expected wrapped LLVM IR structure/array type}}
+  // expected-error@+1 {{expected LLVM IR structure/array type}}
   llvm.insertvalue %a, %b[0,0] : !llvm.struct<(i32)>
 }
 
@@ -311,7 +311,7 @@ func @extractvalue_array_out_of_bounds() {
 // -----
 
 func @extractvalue_wrong_nesting() {
-  // expected-error@+1 {{expected wrapped LLVM IR structure/array type}}
+  // expected-error@+1 {{expected LLVM IR structure/array type}}
   llvm.extractvalue %b[0,0] : !llvm.struct<(i32)>
 }