Add several utility 'getValues<T>' functions to DenseElementsAttr that return ranges...
authorRiver Riddle <riverriddle@google.com>
Thu, 13 Jun 2019 20:22:32 +0000 (13:22 -0700)
committerMehdi Amini <joker.eph@gmail.com>
Thu, 20 Jun 2019 06:01:03 +0000 (23:01 -0700)
PiperOrigin-RevId: 253092550

mlir/include/mlir/IR/Attributes.h
mlir/lib/IR/Attributes.cpp

index 6e0665e257bad62878b95cfa9e3adde1935015b8..1ffa2762a3a8569e1cca8b11e5b3de0f86790822 100644 (file)
@@ -548,11 +548,62 @@ public:
   static DenseElementsAttr get(ShapedType type, ArrayRef<APFloat> values);
 
   //===--------------------------------------------------------------------===//
-  // Value Querying
+  // Iterators
   //===--------------------------------------------------------------------===//
 
-  /// Return the raw storage data held by this attribute.
-  ArrayRef<char> getRawData() const;
+  /// A utility iterator that allows walking over the internal raw APInt values.
+  class IntElementIterator
+      : public llvm::iterator_facade_base<IntElementIterator,
+                                          std::bidirectional_iterator_tag,
+                                          APInt, std::ptrdiff_t, APInt, APInt> {
+  public:
+    /// Iterator movement.
+    IntElementIterator &operator++() {
+      ++index;
+      return *this;
+    }
+    IntElementIterator &operator--() {
+      --index;
+      return *this;
+    }
+
+    /// Accesses the raw APInt value at this iterator position.
+    APInt operator*() const;
+
+    /// Iterator equality.
+    bool operator==(const IntElementIterator &rhs) const {
+      return rawData == rhs.rawData && index == rhs.index;
+    }
+
+  private:
+    friend DenseElementsAttr;
+
+    /// Constructs a new iterator.
+    IntElementIterator(DenseElementsAttr attr, size_t index);
+
+    /// The base address of the raw data buffer.
+    const char *rawData;
+
+    /// The current element index.
+    size_t index;
+
+    /// The bitwidth of the element type.
+    size_t bitWidth;
+  };
+
+  /// Iterator for walking over APFloat values.
+  class FloatElementIterator final
+      : public llvm::mapped_iterator<IntElementIterator,
+                                     std::function<APFloat(const APInt &)>> {
+    friend DenseElementsAttr;
+
+    /// Initializes the float element iterator to the specified iterator.
+    FloatElementIterator(const llvm::fltSemantics &smt, IntElementIterator it);
+  };
+
+  //===--------------------------------------------------------------------===//
+  // Value Querying
+  //===--------------------------------------------------------------------===//
 
   /// Returns the number of raw elements held by this attribute.
   size_t rawSize() const;
@@ -572,6 +623,37 @@ public:
   /// Return the held element values as Attributes in 'values'.
   void getValues(SmallVectorImpl<Attribute> &values) const;
 
+  /// Return the held element values as an array of integer or floating-point
+  /// values.
+  template <typename T, typename = typename std::enable_if<
+                            (!std::is_same<T, bool>::value &&
+                             std::numeric_limits<T>::is_integer) ||
+                            llvm::is_one_of<T, float, double>::value>::type>
+  ArrayRef<T> getValues() const {
+    assert(isValidIntOrFloat(sizeof(T), std::numeric_limits<T>::is_integer));
+    auto rawData = getRawData();
+    return ArrayRef<T>(reinterpret_cast<const T *>(rawData.data()),
+                       rawData.size() / sizeof(T));
+  }
+
+  /// Return the held element values as a range of APInts. The element type of
+  /// this attribute must be of integer type.
+  llvm::iterator_range<IntElementIterator> getIntValues() const;
+  template <typename T, typename = typename std::enable_if<
+                            std::is_same<T, APInt>::value>::type>
+  llvm::iterator_range<IntElementIterator> getValues() const {
+    return getIntValues();
+  }
+
+  /// Return the held element values as a range of APFloat. The element type of
+  /// this attribute must be of float type.
+  llvm::iterator_range<FloatElementIterator> getFloatValues() const;
+  template <typename T, typename = typename std::enable_if<
+                            std::is_same<T, APFloat>::value>::type>
+  llvm::iterator_range<FloatElementIterator> getValues() const {
+    return getFloatValues();
+  }
+
   //===--------------------------------------------------------------------===//
   // Mutation Utilities
   //===--------------------------------------------------------------------===//
@@ -596,53 +678,15 @@ public:
             llvm::function_ref<APInt(const APFloat &)> mapping) const;
 
 protected:
-  /// A utility iterator that allows walking over the internal raw APInt values.
-  class RawElementIterator
-      : public llvm::iterator_facade_base<RawElementIterator,
-                                          std::bidirectional_iterator_tag,
-                                          APInt, std::ptrdiff_t, APInt, APInt> {
-  public:
-    /// Iterator movement.
-    RawElementIterator &operator++() {
-      ++index;
-      return *this;
-    }
-    RawElementIterator &operator--() {
-      --index;
-      return *this;
-    }
-
-    /// Accesses the raw APInt value at this iterator position.
-    APInt operator*() const;
-
-    /// Iterator equality.
-    bool operator==(const RawElementIterator &rhs) const {
-      return rawData == rhs.rawData && index == rhs.index;
-    }
-    bool operator!=(const RawElementIterator &rhs) const {
-      return !(*this == rhs);
-    }
-
-  private:
-    friend DenseElementsAttr;
-
-    /// Constructs a new iterator.
-    RawElementIterator(DenseElementsAttr attr, size_t index);
-
-    /// The base address of the raw data buffer.
-    const char *rawData;
-
-    /// The current element index.
-    size_t index;
-
-    /// The bitwidth of the element type.
-    size_t bitWidth;
-  };
+  /// Return the raw storage data held by this attribute.
+  ArrayRef<char> getRawData() const;
 
-  /// Raw element iterators for this attribute.
-  RawElementIterator raw_begin() const { return RawElementIterator(*this, 0); }
-  RawElementIterator raw_end() const {
-    return RawElementIterator(*this, rawSize());
+  /// Get iterators to the raw APInt values for each element in this attribute.
+  IntElementIterator raw_int_begin() const {
+    return IntElementIterator(*this, 0);
+  }
+  IntElementIterator raw_int_end() const {
+    return IntElementIterator(*this, rawSize());
   }
 
   /// Constructs a dense elements attribute from an array of raw APInt values.
@@ -661,6 +705,11 @@ protected:
   static DenseElementsAttr getRawIntOrFloat(ShapedType type,
                                             ArrayRef<char> data,
                                             int64_t dataEltSize, bool isInt);
+
+  /// Check the information for a c++ data type, check if this type is valid for
+  /// the current attribute. This method is used to verify specific type
+  /// invariants that the templatized 'getValues' method cannot.
+  bool isValidIntOrFloat(int64_t dataEltSize, bool isInt) const;
 };
 
 /// An attribute that represents a reference to a dense integer vector or tensor
@@ -669,11 +718,9 @@ class DenseIntElementsAttr : public DenseElementsAttr {
 public:
   /// DenseIntElementsAttr iterates on APInt, so we can use the raw element
   /// iterator directly.
-  using iterator = DenseElementsAttr::RawElementIterator;
+  using iterator = DenseElementsAttr::IntElementIterator;
 
   using DenseElementsAttr::DenseElementsAttr;
-  using DenseElementsAttr::get;
-  using DenseElementsAttr::getValues;
 
   /// Generates a new DenseElementsAttr by mapping each value attribute, and
   /// constructing the DenseElementsAttr given the new element type.
@@ -681,12 +728,9 @@ public:
   mapValues(Type newElementType,
             llvm::function_ref<APInt(const APInt &)> mapping) const;
 
-  /// Gets the integer value of each of the dense elements.
-  void getValues(SmallVectorImpl<APInt> &values) const;
-
   /// Iterator access to the integer element values.
-  iterator begin() const { return raw_begin(); }
-  iterator end() const { return raw_end(); }
+  iterator begin() const { return raw_int_begin(); }
+  iterator end() const { return raw_int_end(); }
 
   /// Method for supporting type inquiry through isa, cast and dyn_cast.
   static bool classof(Attribute attr);
@@ -696,24 +740,9 @@ public:
 /// object. Each element is stored as a double.
 class DenseFPElementsAttr : public DenseElementsAttr {
 public:
-  /// DenseFPElementsAttr iterates on APFloat, so we need to wrap the raw
-  /// element iterator.
-  class ElementIterator final
-      : public llvm::mapped_iterator<RawElementIterator,
-                                     std::function<APFloat(const APInt &)>> {
-    friend DenseFPElementsAttr;
-
-    /// Initializes the float element iterator to the specified iterator.
-    ElementIterator(const llvm::fltSemantics &smt, RawElementIterator it);
-  };
-  using iterator = ElementIterator;
+  using iterator = DenseElementsAttr::FloatElementIterator;
 
   using DenseElementsAttr::DenseElementsAttr;
-  using DenseElementsAttr::get;
-  using DenseElementsAttr::getValues;
-
-  /// Gets the float value of each of the dense elements.
-  void getValues(SmallVectorImpl<APFloat> &values) const;
 
   /// Generates a new DenseElementsAttr by mapping each value attribute, and
   /// constructing the DenseElementsAttr given the new element type.
@@ -722,8 +751,8 @@ public:
             llvm::function_ref<APInt(const APFloat &)> mapping) const;
 
   /// Iterator access to the float element values.
-  iterator begin() const;
-  iterator end() const;
+  iterator begin() const { return getFloatValues().begin(); }
+  iterator end() const { return getFloatValues().end(); }
 
   /// Method for supporting type inquiry through isa, cast and dyn_cast.
   static bool classof(Attribute attr);
index 7b48d52fbbfd377ff301313780ba887790417e3d..b8bdbd3d232c92247bdcd983d98d13b232ce3cd6 100644 (file)
@@ -556,17 +556,23 @@ static bool hasSameElementsOrSplat(ShapedType type, const Values &values) {
 }
 
 /// Constructs a new iterator.
-DenseElementsAttr::RawElementIterator::RawElementIterator(
+DenseElementsAttr::IntElementIterator::IntElementIterator(
     DenseElementsAttr attr, size_t index)
     : rawData(attr.getRawData().data()), index(index),
       bitWidth(getDenseElementBitwidth(attr.getType().getElementType())) {}
 
 /// Accesses the raw APInt value at this iterator position.
-APInt DenseElementsAttr::RawElementIterator::operator*() const {
+APInt DenseElementsAttr::IntElementIterator::operator*() const {
   return readBits(rawData, index * getDenseElementStorageWidth(bitWidth),
                   bitWidth);
 }
 
+DenseElementsAttr::FloatElementIterator::FloatElementIterator(
+    const llvm::fltSemantics &smt, IntElementIterator it)
+    : llvm::mapped_iterator<IntElementIterator,
+                            std::function<APFloat(const APInt &)>>(
+          it, [&](const APInt &val) { return APFloat(smt, val); }) {}
+
 //===----------------------------------------------------------------------===//
 // DenseElementsAttr
 //===----------------------------------------------------------------------===//
@@ -582,7 +588,8 @@ DenseElementsAttr DenseElementsAttr::get(ShapedType type,
   size_t storageBitWidth = getDenseElementStorageWidth(bitWidth);
 
   // Compress the attribute values into a character buffer.
-  SmallVector<char, 8> data((storageBitWidth / CHAR_BIT) * values.size());
+  SmallVector<char, 8> data(llvm::divideCeil(storageBitWidth, CHAR_BIT) *
+                            values.size());
   APInt intVal;
   for (unsigned i = 0, e = values.size(); i < e; ++i) {
     assert(eltType == values[i].getType() &&
@@ -653,7 +660,8 @@ DenseElementsAttr DenseElementsAttr::getRaw(ShapedType type,
 
   size_t bitWidth = getDenseElementBitwidth(type.getElementType());
   size_t storageBitWidth = getDenseElementStorageWidth(bitWidth);
-  std::vector<char> elementData((storageBitWidth / CHAR_BIT) * values.size());
+  std::vector<char> elementData(llvm::divideCeil(storageBitWidth, CHAR_BIT) *
+                                values.size());
   for (unsigned i = 0, e = values.size(); i != e; ++i) {
     assert(values[i].getBitWidth() == bitWidth);
     writeBits(elementData.data(), i * storageBitWidth, values[i]);
@@ -670,6 +678,20 @@ DenseElementsAttr DenseElementsAttr::getRaw(ShapedType type,
                    data, isSplat);
 }
 
+/// Check the information for a c++ data type, check if this type is valid for
+/// the current attribute. This method is used to verify specific type
+/// invariants that the templatized 'getValues' method cannot.
+static bool isValidIntOrFloat(ShapedType type, int64_t dataEltSize,
+                              bool isInt) {
+  // Make sure that the data element size is the same as the type element width.
+  if ((dataEltSize * CHAR_BIT) != type.getElementTypeBitWidth())
+    return false;
+
+  // Check that the element type is valid.
+  return isInt ? type.getElementType().isa<IntegerType>()
+               : type.getElementType().isa<FloatType>();
+}
+
 /// Overload of the 'getRaw' method that asserts that the given type is of
 /// integer type. This method is used to verify type invariants that the
 /// templatized 'get' method cannot.
@@ -677,15 +699,20 @@ DenseElementsAttr DenseElementsAttr::getRawIntOrFloat(ShapedType type,
                                                       ArrayRef<char> data,
                                                       int64_t dataEltSize,
                                                       bool isInt) {
-  assert(isInt ? type.getElementType().isa<IntegerType>()
-               : type.getElementType().isa<FloatType>());
-  assert((dataEltSize * CHAR_BIT) == type.getElementTypeBitWidth());
+  assert(::isValidIntOrFloat(type, dataEltSize, isInt));
 
   int64_t numElements = data.size() / dataEltSize;
   assert(numElements == 1 || numElements == type.getNumElements());
   return getRaw(type, data, /*isSplat=*/numElements == 1);
 }
 
+/// A method used to verify specific type invariants that the templatized 'get'
+/// method cannot.
+bool DenseElementsAttr::isValidIntOrFloat(int64_t dataEltSize,
+                                          bool isInt) const {
+  return ::isValidIntOrFloat(getType(), dataEltSize, isInt);
+}
+
 /// Return the raw storage data held by this attribute.
 ArrayRef<char> DenseElementsAttr::getRawData() const {
   return static_cast<ImplType *>(impl)->data;
@@ -708,10 +735,10 @@ Attribute DenseElementsAttr::getSplatValue() const {
 
   auto elementType = getType().getElementType();
   if (elementType.isa<IntegerType>())
-    return IntegerAttr::get(elementType, *raw_begin());
+    return IntegerAttr::get(elementType, *raw_int_begin());
   if (auto fType = elementType.dyn_cast<FloatType>())
     return FloatAttr::get(elementType,
-                          APFloat(fType.getFloatSemantics(), *raw_begin()));
+                          APFloat(fType.getFloatSemantics(), *raw_int_begin()));
   llvm_unreachable("unexpected element type");
 }
 
@@ -760,30 +787,44 @@ Attribute DenseElementsAttr::getValue(ArrayRef<uint64_t> index) const {
 }
 
 void DenseElementsAttr::getValues(SmallVectorImpl<Attribute> &values) const {
+  values.reserve(rawSize());
+
   auto elementType = getType().getElementType();
   if (elementType.isa<IntegerType>()) {
-    // Get the raw APInt values.
-    SmallVector<APInt, 8> intValues;
-    cast<DenseIntElementsAttr>().getValues(intValues);
-
-    // Convert each to an IntegerAttr.
-    for (auto &intVal : intValues)
+    // Convert each value to an IntegerAttr.
+    for (auto intVal : getIntValues())
       values.push_back(IntegerAttr::get(elementType, intVal));
     return;
   }
   if (elementType.isa<FloatType>()) {
-    // Get the raw APFloat values.
-    SmallVector<APFloat, 8> floatValues;
-    cast<DenseFPElementsAttr>().getValues(floatValues);
-
-    // Convert each to an FloatAttr.
-    for (auto &floatVal : floatValues)
+    // Convert each value to a FloatAttr.
+    for (auto floatVal : getFloatValues())
       values.push_back(FloatAttr::get(elementType, floatVal));
     return;
   }
   llvm_unreachable("unexpected element type");
 }
 
+/// Return the held element values as a range of APInts. The element type of
+/// this attribute must be of integer type.
+auto DenseElementsAttr::getIntValues() const
+    -> llvm::iterator_range<IntElementIterator> {
+  assert(getType().getElementType().isa<IntegerType>() &&
+         "expected integer type");
+  return {raw_int_begin(), raw_int_end()};
+}
+
+/// Return the held element values as a range of APFloat. The element type of
+/// this attribute must be of float type.
+auto DenseElementsAttr::getFloatValues() const
+    -> llvm::iterator_range<FloatElementIterator> {
+  auto elementType = getType().getElementType().cast<FloatType>();
+  assert(elementType.isa<FloatType>() && "expected float type");
+  const auto &elementSemantics = elementType.getFloatSemantics();
+  return {FloatElementIterator(elementSemantics, raw_int_begin()),
+          FloatElementIterator(elementSemantics, raw_int_end())};
+}
+
 /// Return a new DenseElementsAttr that has the same data as the current
 /// attribute, but has been reshaped to 'newType'. The new type must have the
 /// same total number of elements as well as element type.
@@ -816,11 +857,6 @@ DenseElementsAttr DenseElementsAttr::mapValues(
 // DenseIntElementsAttr
 //===----------------------------------------------------------------------===//
 
-void DenseIntElementsAttr::getValues(SmallVectorImpl<APInt> &values) const {
-  values.reserve(rawSize());
-  values.assign(raw_begin(), raw_end());
-}
-
 template <typename Fn, typename Attr>
 static ShapedType mappingHelper(Fn mapping, Attr &attr, ShapedType inType,
                                 Type newElementType,
@@ -838,7 +874,7 @@ static ShapedType mappingHelper(Fn mapping, Attr &attr, ShapedType inType,
   else
     assert(newArrayType && "Unhandled tensor type");
 
-  data.resize(storageBitWidth * inType.getNumElements());
+  data.resize(llvm::divideCeil(storageBitWidth, CHAR_BIT) * attr.rawSize());
 
   uint64_t elementIdx = 0;
   for (auto value : attr) {
@@ -871,17 +907,6 @@ bool DenseIntElementsAttr::classof(Attribute attr) {
 // DenseFPElementsAttr
 //===----------------------------------------------------------------------===//
 
-DenseFPElementsAttr::ElementIterator::ElementIterator(
-    const llvm::fltSemantics &smt, RawElementIterator it)
-    : llvm::mapped_iterator<RawElementIterator,
-                            std::function<APFloat(const APInt &)>>(
-          it, [&](const APInt &val) { return APFloat(smt, val); }) {}
-
-void DenseFPElementsAttr::getValues(SmallVectorImpl<APFloat> &values) const {
-  values.reserve(rawSize());
-  values.assign(begin(), end());
-}
-
 DenseElementsAttr DenseFPElementsAttr::mapValues(
     Type newElementType,
     llvm::function_ref<APInt(const APFloat &)> mapping) const {
@@ -892,18 +917,6 @@ DenseElementsAttr DenseFPElementsAttr::mapValues(
   return getRaw(newArrayType, elementData, isSplat());
 }
 
-/// Iterator access to the float element values.
-DenseFPElementsAttr::iterator DenseFPElementsAttr::begin() const {
-  auto elementType = getType().getElementType().cast<FloatType>();
-  const auto &elementSemantics = elementType.getFloatSemantics();
-  return {elementSemantics, raw_begin()};
-}
-DenseFPElementsAttr::iterator DenseFPElementsAttr::end() const {
-  auto elementType = getType().getElementType().cast<FloatType>();
-  const auto &elementSemantics = elementType.getFloatSemantics();
-  return {elementSemantics, raw_end()};
-}
-
 /// Method for supporting type inquiry through isa, cast and dyn_cast.
 bool DenseFPElementsAttr::classof(Attribute attr) {
   return attr.isa<DenseElementsAttr>() &&
@@ -985,14 +998,13 @@ Attribute SparseElementsAttr::getValue(ArrayRef<uint64_t> index) const {
   // The sparse indices are 64-bit integers, so we can reinterpret the raw data
   // as a 1-D index array.
   auto sparseIndices = getIndices();
-  const uint64_t *sparseIndexValues =
-      reinterpret_cast<const uint64_t *>(sparseIndices.getRawData().data());
+  ArrayRef<uint64_t> sparseIndexValues = sparseIndices.getValues<uint64_t>();
 
   // Check to see if the indices are a splat.
   if (sparseIndices.isSplat()) {
     // If the index is also not a splat of the index value, we know that the
     // value is zero.
-    auto splatIndex = *sparseIndexValues;
+    auto splatIndex = sparseIndexValues.front();
     if (llvm::any_of(index, [=](uint64_t i) { return i != splatIndex; }))
       return getZeroAttr();
 
@@ -1005,7 +1017,7 @@ Attribute SparseElementsAttr::getValue(ArrayRef<uint64_t> index) const {
   llvm::SmallDenseMap<llvm::ArrayRef<uint64_t>, size_t> mappedIndices;
   auto numSparseIndices = sparseIndices.getType().getDimSize(0);
   for (size_t i = 0, e = numSparseIndices; i != e; ++i)
-    mappedIndices.try_emplace({sparseIndexValues + (i * rank), rank}, i);
+    mappedIndices.try_emplace({&sparseIndexValues[i * rank], rank}, i);
 
   // Look for the provided index key within the mapped indices. If the provided
   // index is not found, then return a zero attribute.