Remove the ability to directly construct a DenseElementsAttr with a raw character...
authorRiver Riddle <riverriddle@google.com>
Fri, 7 Jun 2019 16:57:29 +0000 (09:57 -0700)
committerMehdi Amini <joker.eph@gmail.com>
Sun, 9 Jun 2019 23:23:34 +0000 (16:23 -0700)
* 'get' methods that allow constructing from an ArrayRef of integer or floating point values.
* A 'reshape' method to allow for changing the shape without changing the underlying data.

PiperOrigin-RevId: 252067898

mlir/examples/toy/Ch4/mlir/ToyCombine.cpp
mlir/examples/toy/Ch5/mlir/ToyCombine.cpp
mlir/include/mlir/IR/Attributes.h
mlir/include/mlir/IR/Builders.h
mlir/lib/IR/Attributes.cpp
mlir/lib/IR/Builders.cpp

index 8baa45c..92e80d2 100644 (file)
@@ -81,8 +81,7 @@ struct SimplifyReshapeConstant : public mlir::OpRewritePattern<ReshapeOp> {
       //      auto oldType = constantOp.getType();
       auto newType = rewriter.getTensorType(
           reshapeType.getShape(), valueAttr.getType().getElementType());
-      auto newAttr =
-          mlir::DenseElementsAttr::get(newType, valueAttr.getRawData());
+      auto newAttr = valueAttr.reshape(newType);
       rewriter.replaceOpWithNewOp<ConstantOp>(reshape, reshapeType.getShape(),
                                               newAttr);
     } else if (auto valueAttr =
index 64bd2c9..8e9e8eb 100644 (file)
@@ -83,8 +83,7 @@ struct SimplifyReshapeConstant : public mlir::OpRewritePattern<ReshapeOp> {
       //      auto oldType = constantOp.getType();
       auto newType = rewriter.getTensorType(
           reshapeType.getShape(), valueAttr.getType().getElementType());
-      auto newAttr =
-          mlir::DenseElementsAttr::get(newType, valueAttr.getRawData());
+      auto newAttr = valueAttr.reshape(newType);
       rewriter.replaceOpWithNewOp<ConstantOp>(reshape, reshapeType.getShape(),
                                               newAttr);
     } else if (auto valueAttr =
index 9adcd6c..08a5f33 100644 (file)
@@ -502,16 +502,32 @@ class DenseElementsAttr
 public:
   using Base::Base;
 
-  /// It assumes the elements in the input array have been truncated to the bits
-  /// width specified by the element type. 'type' must be a vector or tensor
-  /// with static shape.
-  static DenseElementsAttr get(ShapedType type, ArrayRef<char> data);
-
   /// Constructs a dense elements attribute from an array of element values.
   /// Each element attribute value is expected to be an element of 'type'.
   /// 'type' must be a vector or tensor with static shape.
   static DenseElementsAttr get(ShapedType type, ArrayRef<Attribute> values);
 
+  /// Constructs a dense integer elements attribute from an array of integer
+  /// or floating-point values. Each value is expected to be the same bitwidth
+  /// of the element type of 'type'. 'type' must be a vector or tensor with
+  /// static shape.
+  template <typename ShapeT, typename T>
+  static DenseElementsAttr get(ShapeT type, ArrayRef<T> values) {
+    static_assert(std::numeric_limits<T>::is_integer ||
+                      llvm::is_one_of<T, float, double>::value,
+                  "expected integer or floating point element type");
+
+    assert(type.getNumElements() == static_cast<int64_t>(values.size()));
+    assert(type.getElementTypeBitWidth() == (sizeof(T) * CHAR_BIT));
+    const char *data = reinterpret_cast<const char *>(values.data());
+    return getRawIntOrFloat(type,
+                            ArrayRef<char>(data, values.size() * sizeof(T)),
+                            /*isInt=*/std::numeric_limits<T>::is_integer);
+  }
+
+  /// Overload of the above 'get' method that is specialized for boolean values.
+  static DenseElementsAttr get(ShapedType type, ArrayRef<bool> values);
+
   /// Returns the number of elements held by this attribute.
   size_t size() const;
 
@@ -519,8 +535,14 @@ public:
   /// element, then a null attribute is returned.
   Attribute getValue(ArrayRef<uint64_t> index) const;
 
+  /// Return the held element values as Attributes in 'values'.
   void getValues(SmallVectorImpl<Attribute> &values) const;
 
+  /// Return a new DenseElementsAttr that has the same data as the current
+  /// attribute, but has been reshaped to 'newType'. The new type must have the
+  /// same total number of elements as well as element type.
+  DenseElementsAttr reshape(ShapedType newType);
+
   /// Generates a new DenseElementsAttr by mapping each int value to a new
   /// underlying APInt. The new values can represent either a integer or float.
   /// This underlying type must be an DenseIntElementsAttr.
@@ -600,6 +622,15 @@ protected:
     return RawElementIterator(*this, size());
   }
 
+  /// Get or create a new dense elements attribute instance with the given raw
+  /// data buffer. 'type' must be a vector or tensor with static shape.
+  static DenseElementsAttr getRaw(ShapedType type, ArrayRef<char> data);
+
+  /// Overload of the raw 'get' method that asserts that the given type is of
+  /// integer or floating-point type.
+  static DenseElementsAttr getRawIntOrFloat(ShapedType type,
+                                            ArrayRef<char> data, bool isInt);
+
   /// Constructs a dense elements attribute from an array of raw APInt values.
   /// Each APInt value is expected to have the same bitwidth as the element type
   /// of 'type'. 'type' must be a vector or tensor with static shape.
@@ -624,11 +655,6 @@ public:
   /// shape.
   static DenseIntElementsAttr get(ShapedType type, ArrayRef<APInt> values);
 
-  /// Constructs a dense integer elements attribute from an array of integer
-  /// values. Each value is expected to be within the bitwidth of the element
-  /// type of 'type'. 'type' must be a vector or tensor with static shape.
-  static DenseIntElementsAttr get(ShapedType type, ArrayRef<int64_t> values);
-
   /// Generates a new DenseElementsAttr by mapping each value attribute, and
   /// constructing the DenseElementsAttr given the new element type.
   DenseElementsAttr
index c192806..c9d9fb2 100644 (file)
@@ -116,7 +116,6 @@ public:
   FunctionAttr getFunctionAttr(Function *value);
   FunctionAttr getFunctionAttr(StringRef value);
   ElementsAttr getSplatElementsAttr(ShapedType type, Attribute elt);
-  ElementsAttr getDenseElementsAttr(ShapedType type, ArrayRef<char> data);
   ElementsAttr getDenseElementsAttr(ShapedType type,
                                     ArrayRef<Attribute> values);
   ElementsAttr getDenseIntElementsAttr(ShapedType type,
index b5c965d..2e91ac6 100644 (file)
@@ -532,7 +532,8 @@ APInt DenseElementsAttr::RawElementIterator::operator*() const {
 // DenseElementsAttr
 //===----------------------------------------------------------------------===//
 
-DenseElementsAttr DenseElementsAttr::get(ShapedType type, ArrayRef<char> data) {
+DenseElementsAttr DenseElementsAttr::getRaw(ShapedType type,
+                                            ArrayRef<char> data) {
   assert((static_cast<uint64_t>(type.getSizeInBits()) <=
           data.size() * APInt::APINT_WORD_SIZE) &&
          "Input data bit size should be larger than that type requires");
@@ -543,6 +544,27 @@ DenseElementsAttr DenseElementsAttr::get(ShapedType type, ArrayRef<char> data) {
                    data);
 }
 
+/// Overload of the raw 'get' method that asserts that the given type is of
+/// integer type.
+DenseElementsAttr DenseElementsAttr::getRawIntOrFloat(ShapedType type,
+                                                      ArrayRef<char> data,
+                                                      bool isInt) {
+  assert(isInt ? type.getElementType().isa<IntegerType>()
+               : type.getElementType().isa<FloatType>());
+  return getRaw(type, data);
+}
+
+DenseElementsAttr DenseElementsAttr::get(ShapedType type,
+                                         ArrayRef<bool> values) {
+  assert(type.getNumElements() == static_cast<int64_t>(values.size()));
+  assert(type.getElementType().isInteger(1));
+
+  std::vector<char> buff(llvm::divideCeil(values.size(), CHAR_BIT));
+  for (int i = 0, e = values.size(); i != e; ++i)
+    writeBits(buff.data(), i, llvm::APInt(1, values[i]));
+  return getRaw(type, buff);
+}
+
 DenseElementsAttr DenseElementsAttr::get(ShapedType type,
                                          ArrayRef<Attribute> values) {
   assert(type.getElementType().isIntOrFloat() &&
@@ -579,7 +601,7 @@ DenseElementsAttr DenseElementsAttr::get(ShapedType type,
            "expected value to have same bitwidth as element type");
     writeBits(data.data(), i * storageBitWidth, intVal);
   }
-  return get(type, data);
+  return getRaw(type, data);
 }
 
 /// Returns the number of elements held by this attribute.
@@ -650,6 +672,22 @@ void DenseElementsAttr::getValues(SmallVectorImpl<Attribute> &values) const {
   llvm_unreachable("unexpected element type");
 }
 
+/// Return a new DenseElementsAttr that has the same data as the current
+/// attribute, but has been reshaped to 'newType'. The new type must have the
+/// same total number of elements as well as element type.
+DenseElementsAttr DenseElementsAttr::reshape(ShapedType newType) {
+  ShapedType curType = getType();
+  if (curType == newType)
+    return *this;
+
+  (void)curType;
+  assert(newType.getElementType() == curType.getElementType() &&
+         "expected the same element type");
+  assert(newType.getNumElements() == curType.getNumElements() &&
+         "expected the same number of elements");
+  return getRaw(newType, getRawData());
+}
+
 DenseElementsAttr DenseElementsAttr::mapValues(
     Type newElementType,
     llvm::function_ref<APInt(const APInt &)> mapping) const {
@@ -681,7 +719,7 @@ DenseElementsAttr DenseElementsAttr::get(ShapedType type,
     assert(values[i].getBitWidth() == bitWidth);
     writeBits(elementData.data(), i * storageBitWidth, values[i]);
   }
-  return get(type, elementData);
+  return getRaw(type, elementData);
 }
 
 /// Writes value to the bit position `bitPos` in array `rawData`.
@@ -728,22 +766,6 @@ DenseIntElementsAttr DenseIntElementsAttr::get(ShapedType type,
   return DenseElementsAttr::get(type, values).cast<DenseIntElementsAttr>();
 }
 
-/// Constructs a dense integer elements attribute from an array of integer
-/// values. Each value is expected to be within the bitwidth of the element
-/// type of 'type'.
-DenseIntElementsAttr DenseIntElementsAttr::get(ShapedType type,
-                                               ArrayRef<int64_t> values) {
-  auto eltType = type.getElementType();
-  size_t bitWidth = eltType.isBF16() ? 64 : eltType.getIntOrFloatBitWidth();
-
-  // Convert the raw integer values to APInt.
-  SmallVector<APInt, 8> apIntValues;
-  apIntValues.reserve(values.size());
-  for (auto value : values)
-    apIntValues.emplace_back(APInt(bitWidth, value));
-  return get(type, apIntValues);
-}
-
 void DenseIntElementsAttr::getValues(SmallVectorImpl<APInt> &values) const {
   values.reserve(size());
   values.assign(raw_begin(), raw_end());
@@ -786,7 +808,7 @@ DenseElementsAttr DenseIntElementsAttr::mapValues(
   auto newArrayType =
       mappingHelper(mapping, *this, getType(), newElementType, elementData);
 
-  return get(newArrayType, elementData);
+  return getRaw(newArrayType, elementData);
 }
 
 /// Method for supporting type inquiry through isa, cast and dyn_cast.
@@ -829,7 +851,7 @@ DenseElementsAttr DenseFPElementsAttr::mapValues(
   auto newArrayType =
       mappingHelper(mapping, *this, getType(), newElementType, elementData);
 
-  return get(newArrayType, elementData);
+  return getRaw(newArrayType, elementData);
 }
 
 /// Iterator access to the float element values.
index 63efe47..e5945d4 100644 (file)
@@ -189,11 +189,6 @@ ElementsAttr Builder::getSplatElementsAttr(ShapedType type, Attribute elt) {
 }
 
 ElementsAttr Builder::getDenseElementsAttr(ShapedType type,
-                                           ArrayRef<char> data) {
-  return DenseElementsAttr::get(type, data);
-}
-
-ElementsAttr Builder::getDenseElementsAttr(ShapedType type,
                                            ArrayRef<Attribute> values) {
   return DenseElementsAttr::get(type, values);
 }