Add ConstantDataVector::getRaw() to create a constant data vector from raw data.
authorNick Lewycky <nick@wasmer.io>
Tue, 9 Mar 2021 23:37:04 +0000 (15:37 -0800)
committerNick Lewycky <nicholas@mxc.ca>
Tue, 16 Mar 2021 18:57:53 +0000 (11:57 -0700)
This parallels ConstantDataArray::getRaw() and can be used with ConstantDataSequential::getRawDataValues() in the base class for both types.

Update BuildConstantData{Array,Vector} tests to test the getRaw API. Also removes its unused Module.

In passing, update some comments to include the support for half and bfloat. Update tests to include testing for bfloat.

Differential Revision: https://reviews.llvm.org/D98302

llvm/include/llvm/IR/Constants.h
llvm/unittests/IR/ConstantsTest.cpp

index 510163a..223e47a 100644 (file)
@@ -558,10 +558,10 @@ public:
 
 //===----------------------------------------------------------------------===//
 /// ConstantDataSequential - A vector or array constant whose element type is a
-/// simple 1/2/4/8-byte integer or float/double, and whose elements are just
-/// simple data values (i.e. ConstantInt/ConstantFP).  This Constant node has no
-/// operands because it stores all of the elements of the constant as densely
-/// packed data, instead of as Value*'s.
+/// simple 1/2/4/8-byte integer or half/bfloat/float/double, and whose elements
+/// are just simple data values (i.e. ConstantInt/ConstantFP).  This Constant
+/// node has no operands because it stores all of the elements of the constant
+/// as densely packed data, instead of as Value*'s.
 ///
 /// This is the common base class of ConstantDataArray and ConstantDataVector.
 ///
@@ -700,11 +700,11 @@ public:
     return ConstantDataArray::get(Context, makeArrayRef(Elts));
   }
 
-  /// get() constructor - Return a constant with array type with an element
+  /// getRaw() constructor - Return a constant with array type with an element
   /// count and element type matching the NumElements and ElementTy parameters
   /// passed in. Note that this can return a ConstantAggregateZero object.
-  /// ElementTy needs to be one of i8/i16/i32/i64/float/double. Data is the
-  /// buffer containing the elements. Be careful to make sure Data uses the
+  /// ElementTy must be one of i8/i16/i32/i64/half/bfloat/float/double. Data is
+  /// the buffer containing the elements. Be careful to make sure Data uses the
   /// right endianness, the buffer will be used as-is.
   static Constant *getRaw(StringRef Data, uint64_t NumElements,
                           Type *ElementTy) {
@@ -772,6 +772,18 @@ public:
   static Constant *get(LLVMContext &Context, ArrayRef<float> Elts);
   static Constant *get(LLVMContext &Context, ArrayRef<double> Elts);
 
+  /// getRaw() constructor - Return a constant with vector type with an element
+  /// count and element type matching the NumElements and ElementTy parameters
+  /// passed in. Note that this can return a ConstantAggregateZero object.
+  /// ElementTy must be one of i8/i16/i32/i64/half/bfloat/float/double. Data is
+  /// the buffer containing the elements. Be careful to make sure Data uses the
+  /// right endianness, the buffer will be used as-is.
+  static Constant *getRaw(StringRef Data, uint64_t NumElements,
+                          Type *ElementTy) {
+    Type *Ty = VectorType::get(ElementTy, ElementCount::getFixed(NumElements));
+    return getImpl(Data, Ty);
+  }
+
   /// getFP() constructors - Return a constant of vector type with a float
   /// element type taken from argument `ElementType', and count taken from
   /// argument `Elts'.  The amount of bits of the contained type must match the
@@ -784,7 +796,7 @@ public:
 
   /// Return a ConstantVector with the specified constant in each element.
   /// The specified constant has to be a of a compatible type (i8/i16/
-  /// i32/i64/float/double) and must be a ConstantFP or ConstantInt.
+  /// i32/i64/half/bfloat/float/double) and must be a ConstantFP or ConstantInt.
   static Constant *getSplat(unsigned NumElts, Constant *Elt);
 
   /// Returns true if this is a splat constant, meaning that all elements have
index 44dbb90..50eb3e0 100644 (file)
@@ -418,45 +418,55 @@ static std::string getNameOfType(Type *T) {
 
 TEST(ConstantsTest, BuildConstantDataArrays) {
   LLVMContext Context;
-  std::unique_ptr<Module> M(new Module("MyModule", Context));
 
   for (Type *T : {Type::getInt8Ty(Context), Type::getInt16Ty(Context),
                   Type::getInt32Ty(Context), Type::getInt64Ty(Context)}) {
     ArrayType *ArrayTy = ArrayType::get(T, 2);
     Constant *Vals[] = {ConstantInt::get(T, 0), ConstantInt::get(T, 1)};
-    Constant *CDV = ConstantArray::get(ArrayTy, Vals);
-    ASSERT_TRUE(dyn_cast<ConstantDataArray>(CDV) != nullptr)
-        << " T = " << getNameOfType(T);
+    Constant *CA = ConstantArray::get(ArrayTy, Vals);
+    ASSERT_TRUE(isa<ConstantDataArray>(CA)) << " T = " << getNameOfType(T);
+    auto *CDA = cast<ConstantDataArray>(CA);
+    Constant *CA2 = ConstantDataArray::getRaw(
+        CDA->getRawDataValues(), CDA->getNumElements(), CDA->getElementType());
+    ASSERT_TRUE(CA == CA2) << " T = " << getNameOfType(T);
   }
 
-  for (Type *T : {Type::getHalfTy(Context), Type::getFloatTy(Context),
-                  Type::getDoubleTy(Context)}) {
+  for (Type *T : {Type::getHalfTy(Context), Type::getBFloatTy(Context),
+                  Type::getFloatTy(Context), Type::getDoubleTy(Context)}) {
     ArrayType *ArrayTy = ArrayType::get(T, 2);
     Constant *Vals[] = {ConstantFP::get(T, 0), ConstantFP::get(T, 1)};
-    Constant *CDV = ConstantArray::get(ArrayTy, Vals);
-    ASSERT_TRUE(dyn_cast<ConstantDataArray>(CDV) != nullptr)
-        << " T = " << getNameOfType(T);
+    Constant *CA = ConstantArray::get(ArrayTy, Vals);
+    ASSERT_TRUE(isa<ConstantDataArray>(CA)) << " T = " << getNameOfType(T);
+    auto *CDA = cast<ConstantDataArray>(CA);
+    Constant *CA2 = ConstantDataArray::getRaw(
+        CDA->getRawDataValues(), CDA->getNumElements(), CDA->getElementType());
+    ASSERT_TRUE(CA == CA2) << " T = " << getNameOfType(T);
   }
 }
 
 TEST(ConstantsTest, BuildConstantDataVectors) {
   LLVMContext Context;
-  std::unique_ptr<Module> M(new Module("MyModule", Context));
 
   for (Type *T : {Type::getInt8Ty(Context), Type::getInt16Ty(Context),
                   Type::getInt32Ty(Context), Type::getInt64Ty(Context)}) {
     Constant *Vals[] = {ConstantInt::get(T, 0), ConstantInt::get(T, 1)};
-    Constant *CDV = ConstantVector::get(Vals);
-    ASSERT_TRUE(dyn_cast<ConstantDataVector>(CDV) != nullptr)
-        << " T = " << getNameOfType(T);
+    Constant *CV = ConstantVector::get(Vals);
+    ASSERT_TRUE(isa<ConstantDataVector>(CV)) << " T = " << getNameOfType(T);
+    auto *CDV = cast<ConstantDataVector>(CV);
+    Constant *CV2 = ConstantDataVector::getRaw(
+        CDV->getRawDataValues(), CDV->getNumElements(), CDV->getElementType());
+    ASSERT_TRUE(CV == CV2) << " T = " << getNameOfType(T);
   }
 
-  for (Type *T : {Type::getHalfTy(Context), Type::getFloatTy(Context),
-                  Type::getDoubleTy(Context)}) {
+  for (Type *T : {Type::getHalfTy(Context), Type::getBFloatTy(Context),
+                  Type::getFloatTy(Context), Type::getDoubleTy(Context)}) {
     Constant *Vals[] = {ConstantFP::get(T, 0), ConstantFP::get(T, 1)};
-    Constant *CDV = ConstantVector::get(Vals);
-    ASSERT_TRUE(dyn_cast<ConstantDataVector>(CDV) != nullptr)
-        << " T = " << getNameOfType(T);
+    Constant *CV = ConstantVector::get(Vals);
+    ASSERT_TRUE(isa<ConstantDataVector>(CV)) << " T = " << getNameOfType(T);
+    auto *CDV = cast<ConstantDataVector>(CV);
+    Constant *CV2 = ConstantDataVector::getRaw(
+        CDV->getRawDataValues(), CDV->getNumElements(), CDV->getElementType());
+    ASSERT_TRUE(CV == CV2) << " T = " << getNameOfType(T);
   }
 }