/// element, then a null attribute is returned.
Attribute getValue(ArrayRef<uint64_t> index) const;
+ /// Returns the number of elements held by this attribute.
+ int64_t getNumElements() const;
+
/// Generates a new ElementsAttr by mapping each int value to a new
/// underlying APInt. The new values can represent either a integer or float.
/// This ElementsAttr should contain integers.
}
};
+namespace detail {
+/// DenseElementsAttr data is aligned to uint64_t, so this traits class is
+/// necessary to interop with PointerIntPair.
+class DenseElementDataPointerTypeTraits {
+public:
+ static inline const void *getAsVoidPointer(const char *ptr) { return ptr; }
+ static inline const char *getFromVoidPointer(const void *ptr) {
+ return static_cast<const char *>(ptr);
+ }
+
+ // Note: We could steal more bits if the need arises.
+ enum { NumLowBitsAvailable = 1 };
+};
+
+/// Pair of raw pointer and a boolean flag of whether the pointer holds a splat,
+using DenseIterPtrAndSplat =
+ llvm::PointerIntPair<const char *, 1, bool,
+ DenseElementDataPointerTypeTraits>;
+
+/// Impl iterator for indexed DenseElementAttr iterators that records a data
+/// pointer and data index that is adjusted for the case of a splat attribute.
+template <typename ConcreteT, typename T, typename PointerT = T *,
+ typename ReferenceT = T &>
+class DenseElementIndexedIteratorImpl
+ : public indexed_accessor_iterator<ConcreteT, DenseIterPtrAndSplat, T,
+ PointerT, ReferenceT> {
+protected:
+ DenseElementIndexedIteratorImpl(const char *data, bool isSplat,
+ size_t dataIndex)
+ : indexed_accessor_iterator<ConcreteT, DenseIterPtrAndSplat, T, PointerT,
+ ReferenceT>({data, isSplat}, dataIndex) {}
+
+ /// Return the current index for this iterator, adjusted for the case of a
+ /// splat.
+ ptrdiff_t getDataIndex() const {
+ bool isSplat = this->object.getInt();
+ return isSplat ? 0 : this->index;
+ }
+
+ /// Return the data object pointer.
+ const char *getData() const { return this->object.getPointer(); }
+};
+} // namespace detail
+
/// An attribute that represents a reference to a dense vector or tensor object.
///
class DenseElementsAttr
AttributeElementIterator(DenseElementsAttr attr, size_t index);
};
+ /// Iterator for walking raw element values of the specified type 'T', which
+ /// may be any c++ data type matching the stored representation: int32_t,
+ /// float, etc.
+ template <typename T>
+ class ElementIterator
+ : public detail::DenseElementIndexedIteratorImpl<ElementIterator<T>,
+ const T> {
+ public:
+ /// Accesses the raw value at this iterator position.
+ const T &operator*() const {
+ return reinterpret_cast<const T *>(this->getData())[this->getDataIndex()];
+ }
+
+ private:
+ friend DenseElementsAttr;
+
+ /// Constructs a new iterator.
+ ElementIterator(const char *data, bool isSplat, size_t dataIndex)
+ : detail::DenseElementIndexedIteratorImpl<ElementIterator<T>, const T>(
+ data, isSplat, dataIndex) {}
+ };
+
/// A utility iterator that allows walking over the internal raw APInt values.
class IntElementIterator
- : public indexed_accessor_iterator<IntElementIterator, const char *,
- APInt, APInt, APInt> {
+ : public detail::DenseElementIndexedIteratorImpl<IntElementIterator,
+ APInt, APInt, APInt> {
public:
/// Accesses the raw APInt value at this iterator position.
APInt operator*() const;
// Value Querying
//===--------------------------------------------------------------------===//
- /// Returns the number of raw elements held by this attribute.
- size_t rawSize() const;
-
/// Returns if this attribute corresponds to a splat, i.e. if all element
/// values are the same.
bool isSplat() const;
/// element, then a null attribute is returned.
Attribute getValue(ArrayRef<uint64_t> index) const;
- /// Return the held element values as an array of integer or floating-point
+ /// Return the held element values as a range of integer or floating-point
/// values.
template <typename T, typename = typename std::enable_if<
(!std::is_same<T, bool>::value &&
std::numeric_limits<T>::is_integer) ||
llvm::is_one_of<T, float, double>::value>::type>
- ArrayRef<T> getValues() const {
+ llvm::iterator_range<ElementIterator<T>> getValues() const {
assert(isValidIntOrFloat(sizeof(T), std::numeric_limits<T>::is_integer));
- auto rawData = getRawData();
- return ArrayRef<T>(reinterpret_cast<const T *>(rawData.data()),
- rawData.size() / sizeof(T));
+ auto rawData = getRawData().data();
+ bool splat = isSplat();
+ return {ElementIterator<T>(rawData, splat, 0),
+ ElementIterator<T>(rawData, splat, getNumElements())};
}
/// Return the held element values as a range of Attributes.
return IntElementIterator(*this, 0);
}
IntElementIterator raw_int_end() const {
- return IntElementIterator(*this, rawSize());
+ return IntElementIterator(*this, getNumElements());
}
/// Constructs a dense elements attribute from an array of raw APInt values.
const UniformQuantizedValueConverter &converter) {
// Convert to corresponding quantized value attributes.
SmallVector<APInt, 8> quantValues;
- quantValues.reserve(realFPElementsAttr.rawSize());
- for (APFloat realVal : realFPElementsAttr) {
- quantValues.push_back(converter.quantizeFloatToInt(realVal));
+ if (realFPElementsAttr.isSplat()) {
+ quantValues.push_back(
+ converter.quantizeFloatToInt(*realFPElementsAttr.begin()));
+ } else {
+ quantValues.reserve(realFPElementsAttr.getNumElements());
+ for (APFloat realVal : realFPElementsAttr) {
+ quantValues.push_back(converter.quantizeFloatToInt(realVal));
+ }
}
// Cast from an expressed-type-based type to storage-type-based type,
/// Print the integer element of the given DenseElementsAttr at 'index'.
static void printDenseIntElement(DenseElementsAttr attr, raw_ostream &os,
unsigned index) {
- APInt value = *std::next(attr.getIntValues().begin(), index);
+ APInt value = *std::next(attr.int_value_begin(), index);
if (value.getBitWidth() == 1)
os << (value.getBoolValue() ? "true" : "false");
else
/// Print the float element of the given DenseElementsAttr at 'index'.
static void printDenseFloatElement(DenseElementsAttr attr, raw_ostream &os,
unsigned index) {
- APFloat value = *std::next(attr.getFloatValues().begin(), index);
+ APFloat value = *std::next(attr.float_value_begin(), index);
printFloatValue(value, os);
}
return Attribute::getType().cast<ShapedType>();
}
+/// Returns the number of elements held by this attribute.
+int64_t ElementsAttr::getNumElements() const {
+ return getType().getNumElements();
+}
+
/// Return the value at the given index. If index does not refer to a valid
/// element, then a null attribute is returned.
Attribute ElementsAttr::getValue(ArrayRef<uint64_t> index) const {
/// Constructs a new iterator.
DenseElementsAttr::IntElementIterator::IntElementIterator(
DenseElementsAttr attr, size_t index)
- : indexed_accessor_iterator<IntElementIterator, const char *, APInt, APInt,
- APInt>(attr.getRawData().data(), index),
+ : DenseElementIndexedIteratorImpl<IntElementIterator, APInt, APInt, APInt>(
+ attr.getRawData().data(), attr.isSplat(), index),
bitWidth(getDenseElementBitwidth(attr.getType().getElementType())) {}
/// Accesses the raw APInt value at this iterator position.
APInt DenseElementsAttr::IntElementIterator::operator*() const {
- return readBits(object, index * getDenseElementStorageWidth(bitWidth),
+ return readBits(getData(),
+ getDataIndex() * getDenseElementStorageWidth(bitWidth),
bitWidth);
}
return static_cast<ImplType *>(impl)->data;
}
-/// Returns the number of raw elements held by this attribute.
-size_t DenseElementsAttr::rawSize() const {
- return isSplat() ? 1 : getType().getNumElements();
-}
-
/// Returns if this attribute corresponds to a splat, i.e. if all element
/// values are the same.
bool DenseElementsAttr::isSplat() const { return getImpl()->isSplat; }
return AttributeElementIterator(*this, 0);
}
auto DenseElementsAttr::attr_value_end() const -> AttributeElementIterator {
- return AttributeElementIterator(*this, rawSize());
+ return AttributeElementIterator(*this, getNumElements());
}
/// Return the held element values as a range of APInts. The element type of
else
assert(newArrayType && "Unhandled tensor type");
- data.resize(llvm::divideCeil(storageBitWidth, CHAR_BIT) * attr.rawSize());
+ size_t numRawElements = attr.isSplat() ? 1 : newArrayType.getNumElements();
+ data.resize(llvm::divideCeil(storageBitWidth, CHAR_BIT) * numRawElements);
- uint64_t elementIdx = 0;
- for (auto value : attr) {
+ // Functor used to process a single element value of the attribute.
+ auto processElt = [&](decltype(*attr.begin()) value, size_t index) {
auto newInt = mapping(value);
assert(newInt.getBitWidth() == bitWidth);
- writeBits(data.data(), elementIdx * storageBitWidth, newInt);
- ++elementIdx;
+ writeBits(data.data(), index * storageBitWidth, newInt);
+ };
+
+ // Check for the splat case.
+ if (attr.isSplat()) {
+ processElt(*attr.begin(), /*index=*/0);
+ return newArrayType;
}
+ // Otherwise, process all of the element values.
+ uint64_t elementIdx = 0;
+ for (auto value : attr)
+ processElt(value, elementIdx++);
return newArrayType;
}
// The sparse indices are 64-bit integers, so we can reinterpret the raw data
// as a 1-D index array.
auto sparseIndices = getIndices();
- ArrayRef<uint64_t> sparseIndexValues = sparseIndices.getValues<uint64_t>();
+ auto sparseIndexValues = sparseIndices.getValues<uint64_t>();
// Check to see if the indices are a splat.
if (sparseIndices.isSplat()) {
// If the index is also not a splat of the index value, we know that the
// value is zero.
- auto splatIndex = sparseIndexValues.front();
+ auto splatIndex = *sparseIndexValues.begin();
if (llvm::any_of(index, [=](uint64_t i) { return i != splatIndex; }))
return getZeroAttr();
llvm::SmallDenseMap<llvm::ArrayRef<uint64_t>, size_t> mappedIndices;
auto numSparseIndices = sparseIndices.getType().getDimSize(0);
for (size_t i = 0, e = numSparseIndices; i != e; ++i)
- mappedIndices.try_emplace({&sparseIndexValues[i * rank], rank}, i);
+ mappedIndices.try_emplace(
+ {&*std::next(sparseIndexValues.begin(), i * rank), rank}, i);
// Look for the provided index key within the mapped indices. If the provided
// index is not found, then return a zero attribute.