[mlir] provide same APIs as existing LLVMType in the new LLVM type modeling
authorAlex Zinenko <zinenko@google.com>
Tue, 4 Aug 2020 09:37:31 +0000 (11:37 +0200)
committerAlex Zinenko <zinenko@google.com>
Tue, 4 Aug 2020 11:49:14 +0000 (13:49 +0200)
These are intended to smoothen the transition and may be removed in the future
in favor of more MLIR-compatible APIs. They intentionally have the same
semantics as the existing functions, which must remain stable until the
transition is complete.

Depends On D85019

Reviewed By: nicolasvasilache

Differential Revision: https://reviews.llvm.org/D85020

mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h
mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp

index 6764f98..e409d68 100644 (file)
@@ -26,6 +26,8 @@ class DialectAsmParser;
 class DialectAsmPrinter;
 
 namespace LLVM {
+class LLVMDialect;
+
 namespace detail {
 struct LLVMFunctionTypeStorage;
 struct LLVMIntegerTypeStorage;
@@ -34,6 +36,12 @@ struct LLVMStructTypeStorage;
 struct LLVMTypeAndSizeStorage;
 } // namespace detail
 
+class LLVMBFloatType;
+class LLVMHalfType;
+class LLVMFloatType;
+class LLVMDoubleType;
+class LLVMIntegerType;
+
 //===----------------------------------------------------------------------===//
 // LLVMTypeNew.
 //===----------------------------------------------------------------------===//
@@ -96,6 +104,150 @@ public:
   static bool kindof(unsigned kind) {
     return FIRST_NEW_LLVM_TYPE <= kind && kind <= LAST_NEW_LLVM_TYPE;
   }
+
+  LLVMDialect &getDialect();
+
+  /// Floating-point type utilities.
+  bool isBFloatTy() { return isa<LLVMBFloatType>(); }
+  bool isHalfTy() { return isa<LLVMHalfType>(); }
+  bool isFloatTy() { return isa<LLVMFloatType>(); }
+  bool isDoubleTy() { return isa<LLVMDoubleType>(); }
+  bool isFloatingPointTy() {
+    return isa<LLVMHalfType>() || isa<LLVMBFloatType>() ||
+           isa<LLVMFloatType>() || isa<LLVMDoubleType>();
+  }
+
+  /// Array type utilities.
+  LLVMTypeNew getArrayElementType();
+  unsigned getArrayNumElements();
+  bool isArrayTy();
+
+  /// Integer type utilities.
+  bool isIntegerTy() { return isa<LLVMIntegerType>(); }
+  bool isIntegerTy(unsigned bitwidth);
+  unsigned getIntegerBitWidth();
+
+  /// Vector type utilities.
+  LLVMTypeNew getVectorElementType();
+  unsigned getVectorNumElements();
+  llvm::ElementCount getVectorElementCount();
+  bool isVectorTy();
+
+  /// Function type utilities.
+  LLVMTypeNew getFunctionParamType(unsigned argIdx);
+  unsigned getFunctionNumParams();
+  LLVMTypeNew getFunctionResultType();
+  bool isFunctionTy();
+  bool isFunctionVarArg();
+
+  /// Pointer type utilities.
+  LLVMTypeNew getPointerTo(unsigned addrSpace = 0);
+  LLVMTypeNew getPointerElementTy();
+  bool isPointerTy();
+  static bool isValidPointerElementType(LLVMTypeNew type);
+
+  /// Struct type utilities.
+  LLVMTypeNew getStructElementType(unsigned i);
+  unsigned getStructNumElements();
+  bool isStructTy();
+
+  /// Utilities used to generate floating point types.
+  static LLVMTypeNew getDoubleTy(LLVMDialect *dialect);
+  static LLVMTypeNew getFloatTy(LLVMDialect *dialect);
+  static LLVMTypeNew getBFloatTy(LLVMDialect *dialect);
+  static LLVMTypeNew getHalfTy(LLVMDialect *dialect);
+  static LLVMTypeNew getFP128Ty(LLVMDialect *dialect);
+  static LLVMTypeNew getX86_FP80Ty(LLVMDialect *dialect);
+
+  /// Utilities used to generate integer types.
+  static LLVMTypeNew getIntNTy(LLVMDialect *dialect, unsigned numBits);
+  static LLVMTypeNew getInt1Ty(LLVMDialect *dialect) {
+    return getIntNTy(dialect, /*numBits=*/1);
+  }
+  static LLVMTypeNew getInt8Ty(LLVMDialect *dialect) {
+    return getIntNTy(dialect, /*numBits=*/8);
+  }
+  static LLVMTypeNew getInt8PtrTy(LLVMDialect *dialect) {
+    return getInt8Ty(dialect).getPointerTo();
+  }
+  static LLVMTypeNew getInt16Ty(LLVMDialect *dialect) {
+    return getIntNTy(dialect, /*numBits=*/16);
+  }
+  static LLVMTypeNew getInt32Ty(LLVMDialect *dialect) {
+    return getIntNTy(dialect, /*numBits=*/32);
+  }
+  static LLVMTypeNew getInt64Ty(LLVMDialect *dialect) {
+    return getIntNTy(dialect, /*numBits=*/64);
+  }
+
+  /// Utilities used to generate other miscellaneous types.
+  static LLVMTypeNew getArrayTy(LLVMTypeNew elementType, uint64_t numElements);
+  static LLVMTypeNew getFunctionTy(LLVMTypeNew result,
+                                   ArrayRef<LLVMTypeNew> params, bool isVarArg);
+  static LLVMTypeNew getFunctionTy(LLVMTypeNew result, bool isVarArg) {
+    return getFunctionTy(result, llvm::None, isVarArg);
+  }
+  static LLVMTypeNew getStructTy(LLVMDialect *dialect,
+                                 ArrayRef<LLVMTypeNew> elements,
+                                 bool isPacked = false);
+  static LLVMTypeNew getStructTy(LLVMDialect *dialect, bool isPacked = false) {
+    return getStructTy(dialect, llvm::None, isPacked);
+  }
+  template <typename... Args>
+  static typename std::enable_if<llvm::are_base_of<LLVMTypeNew, Args...>::value,
+                                 LLVMTypeNew>::type
+  getStructTy(LLVMTypeNew elt1, Args... elts) {
+    SmallVector<LLVMTypeNew, 8> fields({elt1, elts...});
+    return getStructTy(&elt1.getDialect(), fields);
+  }
+  static LLVMTypeNew getVectorTy(LLVMTypeNew elementType, unsigned numElements);
+
+  /// Void type utilities.
+  static LLVMTypeNew getVoidTy(LLVMDialect *dialect);
+  bool isVoidTy();
+
+  // Creation and setting of LLVM's identified struct types
+  static LLVMTypeNew createStructTy(LLVMDialect *dialect,
+                                    ArrayRef<LLVMTypeNew> elements,
+                                    Optional<StringRef> name,
+                                    bool isPacked = false);
+
+  static LLVMTypeNew createStructTy(LLVMDialect *dialect,
+                                    Optional<StringRef> name) {
+    return createStructTy(dialect, llvm::None, name);
+  }
+
+  static LLVMTypeNew createStructTy(ArrayRef<LLVMTypeNew> elements,
+                                    Optional<StringRef> name,
+                                    bool isPacked = false) {
+    assert(!elements.empty() &&
+           "This method may not be invoked with an empty list");
+    LLVMTypeNew ele0 = elements.front();
+    return createStructTy(&ele0.getDialect(), elements, name, isPacked);
+  }
+
+  template <typename... Args>
+  static
+      typename std::enable_if_t<llvm::are_base_of<LLVMTypeNew, Args...>::value,
+                                LLVMTypeNew>
+      createStructTy(StringRef name, LLVMTypeNew elt1, Args... elts) {
+    SmallVector<LLVMTypeNew, 8> fields({elt1, elts...});
+    Optional<StringRef> opt_name(name);
+    return createStructTy(&elt1.getDialect(), fields, opt_name);
+  }
+
+  static LLVMTypeNew setStructTyBody(LLVMTypeNew structType,
+                                     ArrayRef<LLVMTypeNew> elements,
+                                     bool isPacked = false);
+
+  template <typename... Args>
+  static
+      typename std::enable_if_t<llvm::are_base_of<LLVMTypeNew, Args...>::value,
+                                LLVMTypeNew>
+      setStructTyBody(LLVMTypeNew structType, LLVMTypeNew elt1, Args... elts) {
+    SmallVector<LLVMTypeNew, 8> fields({elt1, elts...});
+    return setStructTyBody(structType, fields);
+  }
 };
 
 //===----------------------------------------------------------------------===//
@@ -323,6 +475,9 @@ public:
   /// Checks if a struct is opaque.
   bool isOpaque();
 
+  /// Checks if a struct is initialized.
+  bool isInitialized();
+
   /// Returns the name of an identified struct.
   StringRef getName();
 
index 3540091..abecbcc 100644 (file)
@@ -13,6 +13,7 @@
 
 #include "TypeDetail.h"
 
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/Dialect/LLVMIR/LLVMTypes.h"
 #include "mlir/IR/DialectImplementation.h"
 #include "mlir/IR/TypeSupport.h"
@@ -23,6 +24,213 @@ using namespace mlir;
 using namespace mlir::LLVM;
 
 //===----------------------------------------------------------------------===//
+// LLVMTypeNew.
+//===----------------------------------------------------------------------===//
+
+// TODO: when these types are registered with the LLVMDialect, this method
+// should be removed and the regular Type::getDialect should just work.
+LLVMDialect &LLVMTypeNew::getDialect() {
+  return *getContext()->getRegisteredDialect<LLVM::LLVMDialect>();
+}
+
+//----------------------------------------------------------------------------//
+// Integer type utilities.
+
+bool LLVMTypeNew::isIntegerTy(unsigned bitwidth) {
+  if (auto intType = dyn_cast<LLVMIntegerType>())
+    return intType.getBitWidth() == bitwidth;
+  return false;
+}
+
+unsigned LLVMTypeNew::getIntegerBitWidth() {
+  return cast<LLVMIntegerType>().getBitWidth();
+}
+
+LLVMTypeNew LLVMTypeNew::getArrayElementType() {
+  return cast<LLVMArrayType>().getElementType();
+}
+
+//----------------------------------------------------------------------------//
+// Array type utilities.
+
+unsigned LLVMTypeNew::getArrayNumElements() {
+  return cast<LLVMArrayType>().getNumElements();
+}
+
+bool LLVMTypeNew::isArrayTy() { return isa<LLVMArrayType>(); }
+
+//----------------------------------------------------------------------------//
+// Vector type utilities.
+
+LLVMTypeNew LLVMTypeNew::getVectorElementType() {
+  return cast<LLVMVectorType>().getElementType();
+}
+
+unsigned LLVMTypeNew::getVectorNumElements() {
+  return cast<LLVMFixedVectorType>().getNumElements();
+}
+llvm::ElementCount LLVMTypeNew::getVectorElementCount() {
+  return cast<LLVMVectorType>().getElementCount();
+}
+
+bool LLVMTypeNew::isVectorTy() { return isa<LLVMVectorType>(); }
+
+//----------------------------------------------------------------------------//
+// Function type utilities.
+
+LLVMTypeNew LLVMTypeNew::getFunctionParamType(unsigned argIdx) {
+  return cast<LLVMFunctionType>().getParamType(argIdx);
+}
+
+unsigned LLVMTypeNew::getFunctionNumParams() {
+  return cast<LLVMFunctionType>().getNumParams();
+}
+
+LLVMTypeNew LLVMTypeNew::getFunctionResultType() {
+  return cast<LLVMFunctionType>().getReturnType();
+}
+
+bool LLVMTypeNew::isFunctionTy() { return isa<LLVMFunctionType>(); }
+
+bool LLVMTypeNew::isFunctionVarArg() {
+  return cast<LLVMFunctionType>().isVarArg();
+}
+
+//----------------------------------------------------------------------------//
+// Pointer type utilities.
+
+LLVMTypeNew LLVMTypeNew::getPointerTo(unsigned addrSpace) {
+  return LLVMPointerType::get(*this, addrSpace);
+}
+
+LLVMTypeNew LLVMTypeNew::getPointerElementTy() {
+  return cast<LLVMPointerType>().getElementType();
+}
+
+bool LLVMTypeNew::isPointerTy() { return isa<LLVMPointerType>(); }
+
+bool LLVMTypeNew::isValidPointerElementType(LLVMTypeNew type) {
+  return !type.isa<LLVMVoidType>() && !type.isa<LLVMTokenType>() &&
+         !type.isa<LLVMMetadataType>() && !type.isa<LLVMLabelType>();
+}
+
+//----------------------------------------------------------------------------//
+// Struct type utilities.
+
+LLVMTypeNew LLVMTypeNew::getStructElementType(unsigned i) {
+  return cast<LLVMStructType>().getBody()[i];
+}
+
+unsigned LLVMTypeNew::getStructNumElements() {
+  return cast<LLVMStructType>().getBody().size();
+}
+
+bool LLVMTypeNew::isStructTy() { return isa<LLVMStructType>(); }
+
+//----------------------------------------------------------------------------//
+// Utilities used to generate floating point types.
+
+LLVMTypeNew LLVMTypeNew::getDoubleTy(LLVMDialect *dialect) {
+  return LLVMDoubleType::get(dialect->getContext());
+}
+
+LLVMTypeNew LLVMTypeNew::getFloatTy(LLVMDialect *dialect) {
+  return LLVMFloatType::get(dialect->getContext());
+}
+
+LLVMTypeNew LLVMTypeNew::getBFloatTy(LLVMDialect *dialect) {
+  return LLVMBFloatType::get(dialect->getContext());
+}
+
+LLVMTypeNew LLVMTypeNew::getHalfTy(LLVMDialect *dialect) {
+  return LLVMHalfType::get(dialect->getContext());
+}
+
+LLVMTypeNew LLVMTypeNew::getFP128Ty(LLVMDialect *dialect) {
+  return LLVMFP128Type::get(dialect->getContext());
+}
+
+LLVMTypeNew LLVMTypeNew::getX86_FP80Ty(LLVMDialect *dialect) {
+  return LLVMX86FP80Type::get(dialect->getContext());
+}
+
+//----------------------------------------------------------------------------//
+// Utilities used to generate integer types.
+
+LLVMTypeNew LLVMTypeNew::getIntNTy(LLVMDialect *dialect, unsigned numBits) {
+  return LLVMIntegerType::get(dialect->getContext(), numBits);
+}
+
+//----------------------------------------------------------------------------//
+// Utilities used to generate other miscellaneous types.
+
+LLVMTypeNew LLVMTypeNew::getArrayTy(LLVMTypeNew elementType,
+                                    uint64_t numElements) {
+  return LLVMArrayType::get(elementType, numElements);
+}
+
+LLVMTypeNew LLVMTypeNew::getFunctionTy(LLVMTypeNew result,
+                                       ArrayRef<LLVMTypeNew> params,
+                                       bool isVarArg) {
+  return LLVMFunctionType::get(result, params, isVarArg);
+}
+
+LLVMTypeNew LLVMTypeNew::getStructTy(LLVMDialect *dialect,
+                                     ArrayRef<LLVMTypeNew> elements,
+                                     bool isPacked) {
+  return LLVMStructType::getLiteral(dialect->getContext(), elements, isPacked);
+}
+
+LLVMTypeNew LLVMTypeNew::getVectorTy(LLVMTypeNew elementType,
+                                     unsigned numElements) {
+  return LLVMFixedVectorType::get(elementType, numElements);
+}
+
+//----------------------------------------------------------------------------//
+// Void type utilities.
+
+LLVMTypeNew LLVMTypeNew::getVoidTy(LLVMDialect *dialect) {
+  return LLVMVoidType::get(dialect->getContext());
+}
+
+bool LLVMTypeNew::isVoidTy() { return isa<LLVMVoidType>(); }
+
+//----------------------------------------------------------------------------//
+// Creation and setting of LLVM's identified struct types
+
+LLVMTypeNew LLVMTypeNew::createStructTy(LLVMDialect *dialect,
+                                        ArrayRef<LLVMTypeNew> elements,
+                                        Optional<StringRef> name,
+                                        bool isPacked) {
+  assert(name.hasValue() &&
+         "identified structs with no identifier not supported");
+  StringRef stringNameBase = name.getValueOr("");
+  std::string stringName = stringNameBase.str();
+  unsigned counter = 0;
+  do {
+    auto type =
+        LLVMStructType::getIdentified(dialect->getContext(), stringName);
+    if (type.isInitialized() || failed(type.setBody(elements, isPacked))) {
+      counter += 1;
+      stringName =
+          (Twine(stringNameBase) + "." + std::to_string(counter)).str();
+      continue;
+    }
+    return type;
+  } while (true);
+}
+
+LLVMTypeNew LLVMTypeNew::setStructTyBody(LLVMTypeNew structType,
+                                         ArrayRef<LLVMTypeNew> elements,
+                                         bool isPacked) {
+  LogicalResult couldSet =
+      structType.cast<LLVMStructType>().setBody(elements, isPacked);
+  assert(succeeded(couldSet) && "failed to set the body");
+  (void)couldSet;
+  return structType;
+}
+
+//===----------------------------------------------------------------------===//
 // Array type.
 
 LLVMArrayType LLVMArrayType::get(LLVMTypeNew elementType,
@@ -117,6 +325,7 @@ bool LLVMStructType::isIdentified() { return getImpl()->isIdentified(); }
 bool LLVMStructType::isOpaque() {
   return getImpl()->isOpaque() || !getImpl()->isInitialized();
 }
+bool LLVMStructType::isInitialized() { return getImpl()->isInitialized(); }
 StringRef LLVMStructType::getName() { return getImpl()->getIdentifier(); }
 ArrayRef<LLVMTypeNew> LLVMStructType::getBody() {
   return isIdentified() ? getImpl()->getIdentifiedStructBody()