Refactor DenseElementAttr::getValues methods to return full ranges for splats.
authorRiver Riddle <riverriddle@google.com>
Sun, 11 Aug 2019 00:26:35 +0000 (17:26 -0700)
committerA. Unique TensorFlower <gardener@tensorflow.org>
Mon, 12 Aug 2019 01:17:28 +0000 (18:17 -0700)
The current implementation only returns one element for the splat case, which often comes as a surprise; leading to subtle/confusing bugs. The new behavior will include an iterate over the full range of elements, as defined by the shaped type, by providing the splat value for each iterator index.

PiperOrigin-RevId: 262756780

mlir/include/mlir/IR/Attributes.h
mlir/lib/Dialect/QuantOps/Utils/QuantizeUtils.cpp
mlir/lib/IR/AsmPrinter.cpp
mlir/lib/IR/Attributes.cpp

index 323473f..e75f102 100644 (file)
@@ -461,6 +461,9 @@ public:
   /// element, then a null attribute is returned.
   Attribute getValue(ArrayRef<uint64_t> index) const;
 
+  /// Returns the number of elements held by this attribute.
+  int64_t getNumElements() const;
+
   /// Generates a new ElementsAttr by mapping each int value to a new
   /// underlying APInt. The new values can represent either a integer or float.
   /// This ElementsAttr should contain integers.
@@ -482,6 +485,50 @@ public:
   }
 };
 
+namespace detail {
+/// DenseElementsAttr data is aligned to uint64_t, so this traits class is
+/// necessary to interop with PointerIntPair.
+class DenseElementDataPointerTypeTraits {
+public:
+  static inline const void *getAsVoidPointer(const char *ptr) { return ptr; }
+  static inline const char *getFromVoidPointer(const void *ptr) {
+    return static_cast<const char *>(ptr);
+  }
+
+  // Note: We could steal more bits if the need arises.
+  enum { NumLowBitsAvailable = 1 };
+};
+
+/// Pair of raw pointer and a boolean flag of whether the pointer holds a splat,
+using DenseIterPtrAndSplat =
+    llvm::PointerIntPair<const char *, 1, bool,
+                         DenseElementDataPointerTypeTraits>;
+
+/// Impl iterator for indexed DenseElementAttr iterators that records a data
+/// pointer and data index that is adjusted for the case of a splat attribute.
+template <typename ConcreteT, typename T, typename PointerT = T *,
+          typename ReferenceT = T &>
+class DenseElementIndexedIteratorImpl
+    : public indexed_accessor_iterator<ConcreteT, DenseIterPtrAndSplat, T,
+                                       PointerT, ReferenceT> {
+protected:
+  DenseElementIndexedIteratorImpl(const char *data, bool isSplat,
+                                  size_t dataIndex)
+      : indexed_accessor_iterator<ConcreteT, DenseIterPtrAndSplat, T, PointerT,
+                                  ReferenceT>({data, isSplat}, dataIndex) {}
+
+  /// Return the current index for this iterator, adjusted for the case of a
+  /// splat.
+  ptrdiff_t getDataIndex() const {
+    bool isSplat = this->object.getInt();
+    return isSplat ? 0 : this->index;
+  }
+
+  /// Return the data object pointer.
+  const char *getData() const { return this->object.getPointer(); }
+};
+} // namespace detail
+
 /// An attribute that represents a reference to a dense vector or tensor object.
 ///
 class DenseElementsAttr
@@ -566,10 +613,32 @@ public:
     AttributeElementIterator(DenseElementsAttr attr, size_t index);
   };
 
+  /// Iterator for walking raw element values of the specified type 'T', which
+  /// may be any c++ data type matching the stored representation: int32_t,
+  /// float, etc.
+  template <typename T>
+  class ElementIterator
+      : public detail::DenseElementIndexedIteratorImpl<ElementIterator<T>,
+                                                       const T> {
+  public:
+    /// Accesses the raw value at this iterator position.
+    const T &operator*() const {
+      return reinterpret_cast<const T *>(this->getData())[this->getDataIndex()];
+    }
+
+  private:
+    friend DenseElementsAttr;
+
+    /// Constructs a new iterator.
+    ElementIterator(const char *data, bool isSplat, size_t dataIndex)
+        : detail::DenseElementIndexedIteratorImpl<ElementIterator<T>, const T>(
+              data, isSplat, dataIndex) {}
+  };
+
   /// A utility iterator that allows walking over the internal raw APInt values.
   class IntElementIterator
-      : public indexed_accessor_iterator<IntElementIterator, const char *,
-                                         APInt, APInt, APInt> {
+      : public detail::DenseElementIndexedIteratorImpl<IntElementIterator,
+                                                       APInt, APInt, APInt> {
   public:
     /// Accesses the raw APInt value at this iterator position.
     APInt operator*() const;
@@ -601,9 +670,6 @@ public:
   // Value Querying
   //===--------------------------------------------------------------------===//
 
-  /// Returns the number of raw elements held by this attribute.
-  size_t rawSize() const;
-
   /// Returns if this attribute corresponds to a splat, i.e. if all element
   /// values are the same.
   bool isSplat() const;
@@ -616,17 +682,18 @@ public:
   /// element, then a null attribute is returned.
   Attribute getValue(ArrayRef<uint64_t> index) const;
 
-  /// Return the held element values as an array of integer or floating-point
+  /// Return the held element values as a range of integer or floating-point
   /// values.
   template <typename T, typename = typename std::enable_if<
                             (!std::is_same<T, bool>::value &&
                              std::numeric_limits<T>::is_integer) ||
                             llvm::is_one_of<T, float, double>::value>::type>
-  ArrayRef<T> getValues() const {
+  llvm::iterator_range<ElementIterator<T>> getValues() const {
     assert(isValidIntOrFloat(sizeof(T), std::numeric_limits<T>::is_integer));
-    auto rawData = getRawData();
-    return ArrayRef<T>(reinterpret_cast<const T *>(rawData.data()),
-                       rawData.size() / sizeof(T));
+    auto rawData = getRawData().data();
+    bool splat = isSplat();
+    return {ElementIterator<T>(rawData, splat, 0),
+            ElementIterator<T>(rawData, splat, getNumElements())};
   }
 
   /// Return the held element values as a range of Attributes.
@@ -693,7 +760,7 @@ protected:
     return IntElementIterator(*this, 0);
   }
   IntElementIterator raw_int_end() const {
-    return IntElementIterator(*this, rawSize());
+    return IntElementIterator(*this, getNumElements());
   }
 
   /// Constructs a dense elements attribute from an array of raw APInt values.
index 7cfedf9..4733e56 100644 (file)
@@ -49,9 +49,14 @@ convertDenseFPElementsAttr(DenseFPElementsAttr realFPElementsAttr,
                            const UniformQuantizedValueConverter &converter) {
   // Convert to corresponding quantized value attributes.
   SmallVector<APInt, 8> quantValues;
-  quantValues.reserve(realFPElementsAttr.rawSize());
-  for (APFloat realVal : realFPElementsAttr) {
-    quantValues.push_back(converter.quantizeFloatToInt(realVal));
+  if (realFPElementsAttr.isSplat()) {
+    quantValues.push_back(
+        converter.quantizeFloatToInt(*realFPElementsAttr.begin()));
+  } else {
+    quantValues.reserve(realFPElementsAttr.getNumElements());
+    for (APFloat realVal : realFPElementsAttr) {
+      quantValues.push_back(converter.quantizeFloatToInt(realVal));
+    }
   }
 
   // Cast from an expressed-type-based type to storage-type-based type,
index a137f26..6825cb8 100644 (file)
@@ -726,7 +726,7 @@ void ModulePrinter::printAttribute(Attribute attr, bool mayElideType) {
 /// Print the integer element of the given DenseElementsAttr at 'index'.
 static void printDenseIntElement(DenseElementsAttr attr, raw_ostream &os,
                                  unsigned index) {
-  APInt value = *std::next(attr.getIntValues().begin(), index);
+  APInt value = *std::next(attr.int_value_begin(), index);
   if (value.getBitWidth() == 1)
     os << (value.getBoolValue() ? "true" : "false");
   else
@@ -736,7 +736,7 @@ static void printDenseIntElement(DenseElementsAttr attr, raw_ostream &os,
 /// Print the float element of the given DenseElementsAttr at 'index'.
 static void printDenseFloatElement(DenseElementsAttr attr, raw_ostream &os,
                                    unsigned index) {
-  APFloat value = *std::next(attr.getFloatValues().begin(), index);
+  APFloat value = *std::next(attr.float_value_begin(), index);
   printFloatValue(value, os);
 }
 
index e2a401c..df3ae71 100644 (file)
@@ -357,6 +357,11 @@ ShapedType ElementsAttr::getType() const {
   return Attribute::getType().cast<ShapedType>();
 }
 
+/// Returns the number of elements held by this attribute.
+int64_t ElementsAttr::getNumElements() const {
+  return getType().getNumElements();
+}
+
 /// Return the value at the given index. If index does not refer to a valid
 /// element, then a null attribute is returned.
 Attribute ElementsAttr::getValue(ArrayRef<uint64_t> index) const {
@@ -494,13 +499,14 @@ Attribute DenseElementsAttr::AttributeElementIterator::operator*() const {
 /// Constructs a new iterator.
 DenseElementsAttr::IntElementIterator::IntElementIterator(
     DenseElementsAttr attr, size_t index)
-    : indexed_accessor_iterator<IntElementIterator, const char *, APInt, APInt,
-                                APInt>(attr.getRawData().data(), index),
+    : DenseElementIndexedIteratorImpl<IntElementIterator, APInt, APInt, APInt>(
+          attr.getRawData().data(), attr.isSplat(), index),
       bitWidth(getDenseElementBitwidth(attr.getType().getElementType())) {}
 
 /// Accesses the raw APInt value at this iterator position.
 APInt DenseElementsAttr::IntElementIterator::operator*() const {
-  return readBits(object, index * getDenseElementStorageWidth(bitWidth),
+  return readBits(getData(),
+                  getDataIndex() * getDenseElementStorageWidth(bitWidth),
                   bitWidth);
 }
 
@@ -655,11 +661,6 @@ ArrayRef<char> DenseElementsAttr::getRawData() const {
   return static_cast<ImplType *>(impl)->data;
 }
 
-/// Returns the number of raw elements held by this attribute.
-size_t DenseElementsAttr::rawSize() const {
-  return isSplat() ? 1 : getType().getNumElements();
-}
-
 /// Returns if this attribute corresponds to a splat, i.e. if all element
 /// values are the same.
 bool DenseElementsAttr::isSplat() const { return getImpl()->isSplat; }
@@ -723,7 +724,7 @@ auto DenseElementsAttr::attr_value_begin() const -> AttributeElementIterator {
   return AttributeElementIterator(*this, 0);
 }
 auto DenseElementsAttr::attr_value_end() const -> AttributeElementIterator {
-  return AttributeElementIterator(*this, rawSize());
+  return AttributeElementIterator(*this, getNumElements());
 }
 
 /// Return the held element values as a range of APInts. The element type of
@@ -811,16 +812,26 @@ static ShapedType mappingHelper(Fn mapping, Attr &attr, ShapedType inType,
   else
     assert(newArrayType && "Unhandled tensor type");
 
-  data.resize(llvm::divideCeil(storageBitWidth, CHAR_BIT) * attr.rawSize());
+  size_t numRawElements = attr.isSplat() ? 1 : newArrayType.getNumElements();
+  data.resize(llvm::divideCeil(storageBitWidth, CHAR_BIT) * numRawElements);
 
-  uint64_t elementIdx = 0;
-  for (auto value : attr) {
+  // Functor used to process a single element value of the attribute.
+  auto processElt = [&](decltype(*attr.begin()) value, size_t index) {
     auto newInt = mapping(value);
     assert(newInt.getBitWidth() == bitWidth);
-    writeBits(data.data(), elementIdx * storageBitWidth, newInt);
-    ++elementIdx;
+    writeBits(data.data(), index * storageBitWidth, newInt);
+  };
+
+  // Check for the splat case.
+  if (attr.isSplat()) {
+    processElt(*attr.begin(), /*index=*/0);
+    return newArrayType;
   }
 
+  // Otherwise, process all of the element values.
+  uint64_t elementIdx = 0;
+  for (auto value : attr)
+    processElt(value, elementIdx++);
   return newArrayType;
 }
 
@@ -935,13 +946,13 @@ Attribute SparseElementsAttr::getValue(ArrayRef<uint64_t> index) const {
   // The sparse indices are 64-bit integers, so we can reinterpret the raw data
   // as a 1-D index array.
   auto sparseIndices = getIndices();
-  ArrayRef<uint64_t> sparseIndexValues = sparseIndices.getValues<uint64_t>();
+  auto sparseIndexValues = sparseIndices.getValues<uint64_t>();
 
   // Check to see if the indices are a splat.
   if (sparseIndices.isSplat()) {
     // If the index is also not a splat of the index value, we know that the
     // value is zero.
-    auto splatIndex = sparseIndexValues.front();
+    auto splatIndex = *sparseIndexValues.begin();
     if (llvm::any_of(index, [=](uint64_t i) { return i != splatIndex; }))
       return getZeroAttr();
 
@@ -954,7 +965,8 @@ Attribute SparseElementsAttr::getValue(ArrayRef<uint64_t> index) const {
   llvm::SmallDenseMap<llvm::ArrayRef<uint64_t>, size_t> mappedIndices;
   auto numSparseIndices = sparseIndices.getType().getDimSize(0);
   for (size_t i = 0, e = numSparseIndices; i != e; ++i)
-    mappedIndices.try_emplace({&sparseIndexValues[i * rank], rank}, i);
+    mappedIndices.try_emplace(
+        {&*std::next(sparseIndexValues.begin(), i * rank), rank}, i);
 
   // Look for the provided index key within the mapped indices. If the provided
   // index is not found, then return a zero attribute.