/// element, then a null attribute is returned.
Attribute getValue(ArrayRef<uint64_t> index) const;
+ /// Generates a new ElementsAttr by mapping each int value to a new
+ /// underlying APInt. The new values can represent either a integer or float.
+ /// This ElementsAttr should contain integers.
+ ElementsAttr
+ mapValues(Type newElementType,
+ llvm::function_ref<APInt(const APInt &)> mapping) const;
+
+ /// Generates a new ElementsAttr by mapping each float value to a new
+ /// underlying APInt. The new values can represent either a integer or float.
+ /// This ElementsAttr should contain floats.
+ ElementsAttr
+ mapValues(Type newElementType,
+ llvm::function_ref<APInt(const APFloat &)> mapping) const;
+
/// Method for support type inquiry through isa, cast and dyn_cast.
static bool classof(Attribute attr) {
return attr.getKind() >= StandardAttributes::FIRST_ELEMENTS_ATTR &&
static SplatElementsAttr get(ShapedType type, Attribute elt);
Attribute getValue() const;
+ /// Generates a new SplatElementsAttr by mapping each int value to a new
+ /// underlying APInt. The new values can represent either a integer or float.
+ /// This ElementsAttr should contain integers.
+ SplatElementsAttr
+ mapValues(Type newElementType,
+ llvm::function_ref<APInt(const APInt &)> mapping) const;
+
+ /// Generates a new SplatElementsAttr by mapping each float value to a new
+ /// underlying APInt. The new values can represent either a integer or float.
+ /// This ElementsAttr should contain floats.
+ SplatElementsAttr
+ mapValues(Type newElementType,
+ llvm::function_ref<APInt(const APFloat &)> mapping) const;
+
/// Method for support type inquiry through isa, cast and dyn_cast.
static bool kindof(unsigned kind) {
return kind == StandardAttributes::SplatElements;
void getValues(SmallVectorImpl<Attribute> &values) const;
+ /// Generates a new DenseElementsAttr by mapping each int value to a new
+ /// underlying APInt. The new values can represent either a integer or float.
+ /// This underlying type must be an DenseIntElementsAttr.
+ DenseElementsAttr
+ mapValues(Type newElementType,
+ llvm::function_ref<APInt(const APInt &)> mapping) const;
+
+ /// Generates a new DenseElementsAttr by mapping each float value to a new
+ /// underlying APInt. the new values can represent either a integer or float.
+ /// This underlying type must be an DenseFPElementsAttr.
+ DenseElementsAttr
+ mapValues(Type newElementType,
+ llvm::function_ref<APInt(const APFloat &)> mapping) const;
+
ArrayRef<char> getRawData() const;
/// Writes value to the bit position `bitPos` in array `rawData`. 'rawData' is
/// type of 'type'.
static DenseIntElementsAttr get(ShapedType type, ArrayRef<int64_t> values);
+ /// Generates a new DenseElementsAttr by mapping each value attribute, and
+ /// constructing the DenseElementsAttr given the new element type.
+ DenseElementsAttr
+ mapValues(Type newElementType,
+ llvm::function_ref<APInt(const APInt &)> mapping) const;
+
/// Gets the integer value of each of the dense elements.
void getValues(SmallVectorImpl<APInt> &values) const;
/// Gets the float value of each of the dense elements.
void getValues(SmallVectorImpl<APFloat> &values) const;
+ /// Generates a new DenseElementsAttr by mapping each value attribute, and
+ /// constructing the DenseElementsAttr given the new element type.
+ DenseElementsAttr
+ mapValues(Type newElementType,
+ llvm::function_ref<APInt(const APFloat &)> mapping) const;
+
/// Iterator access to the float element values.
iterator begin() const;
iterator end() const;
}
}
+ElementsAttr ElementsAttr::mapValues(
+ Type newElementType,
+ llvm::function_ref<APInt(const APInt &)> mapping) const {
+ switch (getKind()) {
+ case StandardAttributes::DenseIntElements:
+ case StandardAttributes::DenseFPElements:
+ return cast<DenseElementsAttr>().mapValues(newElementType, mapping);
+ case StandardAttributes::SplatElements:
+ return cast<SplatElementsAttr>().mapValues(newElementType, mapping);
+ default:
+ llvm_unreachable("unsupported ElementsAttr subtype");
+ }
+}
+
+ElementsAttr ElementsAttr::mapValues(
+ Type newElementType,
+ llvm::function_ref<APInt(const APFloat &)> mapping) const {
+ switch (getKind()) {
+ case StandardAttributes::DenseIntElements:
+ case StandardAttributes::DenseFPElements:
+ return cast<DenseElementsAttr>().mapValues(newElementType, mapping);
+ case StandardAttributes::SplatElements:
+ return cast<SplatElementsAttr>().mapValues(newElementType, mapping);
+ default:
+ llvm_unreachable("unsupported ElementsAttr subtype");
+ }
+}
+
//===----------------------------------------------------------------------===//
// SplatElementsAttr
//===----------------------------------------------------------------------===//
Attribute SplatElementsAttr::getValue() const { return getImpl()->elt; }
+SplatElementsAttr SplatElementsAttr::mapValues(
+ Type newElementType,
+ llvm::function_ref<APInt(const APInt &)> mapping) const {
+ ShapedType inType = getType();
+
+ ShapedType newArrayType;
+ if (inType.isa<RankedTensorType>())
+ newArrayType = RankedTensorType::get(inType.getShape(), newElementType);
+ else if (inType.isa<UnrankedTensorType>())
+ newArrayType = RankedTensorType::get(inType.getShape(), newElementType);
+ else if (inType.isa<VectorType>())
+ newArrayType = VectorType::get(inType.getShape(), newElementType);
+ else
+ assert(false && "Unhandled tensor type");
+
+ assert(getType().getElementType().isa<IntegerType>() &&
+ "Attempting to map non-integer array as integers");
+
+ if (newElementType.isa<IntegerType>()) {
+ APInt newValue = mapping(getValue().cast<IntegerAttr>().getValue());
+ auto newAttr = IntegerAttr::get(newElementType, newValue);
+ return get(newArrayType, newAttr);
+ }
+
+ if (newElementType.isa<FloatType>()) {
+ APFloat newValue(newElementType.cast<FloatType>().getFloatSemantics(),
+ mapping(getValue().cast<IntegerAttr>().getValue()));
+ auto newAttr = FloatAttr::get(newElementType, newValue);
+ return get(newArrayType, newAttr);
+ }
+
+ llvm_unreachable("unknown output splat type");
+}
+
+SplatElementsAttr SplatElementsAttr::mapValues(
+ Type newElementType,
+ llvm::function_ref<APInt(const APFloat &)> mapping) const {
+ Type inType = getType();
+
+ ShapedType newArrayType;
+ if (inType.isa<RankedTensorType>()) {
+ newArrayType = RankedTensorType::get(getType().getShape(), newElementType);
+ } else if (inType.isa<UnrankedTensorType>()) {
+ newArrayType = RankedTensorType::get(getType().getShape(), newElementType);
+ }
+
+ assert(newArrayType && "Unhandled tensor type");
+ assert(getType().getElementType().isa<FloatType>() &&
+ "mapping function expects float tensor");
+
+ Attribute newAttr;
+ if (newElementType.isa<IntegerType>()) {
+ APInt newValue = mapping(getValue().cast<FloatAttr>().getValue());
+ newAttr = IntegerAttr::get(newElementType, newValue);
+ return get(newArrayType, newAttr);
+ }
+
+ if (newElementType.isa<FloatType>()) {
+ APFloat newValue(newElementType.cast<FloatType>().getFloatSemantics(),
+ mapping(getValue().cast<FloatAttr>().getValue()));
+ newAttr = FloatAttr::get(newElementType, newValue);
+ return get(newArrayType, newAttr);
+ }
+
+ llvm_unreachable("unknown output splat type");
+}
+
//===----------------------------------------------------------------------===//
// RawElementIterator
//===----------------------------------------------------------------------===//
}
}
+DenseElementsAttr DenseElementsAttr::mapValues(
+ Type newElementType,
+ llvm::function_ref<APInt(const APInt &)> mapping) const {
+ return cast<DenseIntElementsAttr>().mapValues(newElementType, mapping);
+}
+
+DenseElementsAttr DenseElementsAttr::mapValues(
+ Type newElementType,
+ llvm::function_ref<APInt(const APFloat &)> mapping) const {
+ return cast<DenseFPElementsAttr>().mapValues(newElementType, mapping);
+}
+
ArrayRef<char> DenseElementsAttr::getRawData() const {
return static_cast<ImplType *>(impl)->data;
}
values.assign(raw_begin(), raw_end());
}
+template<typename Fn, typename Attr>
+static ShapedType mappingHelper(
+ Fn mapping, Attr& attr, ShapedType inType, Type newElementType,
+ llvm::SmallVectorImpl<char>& data) {
+ size_t bitWidth = getDenseElementBitwidth(newElementType);
+
+ ShapedType newArrayType;
+ if (inType.isa<RankedTensorType>())
+ newArrayType = RankedTensorType::get(inType.getShape(), newElementType);
+ else if (inType.isa<UnrankedTensorType>())
+ newArrayType = RankedTensorType::get(inType.getShape(), newElementType);
+ else if (inType.isa<VectorType>())
+ newArrayType = VectorType::get(inType.getShape(), newElementType);
+ else
+ assert(newArrayType && "Unhandled tensor type");
+
+ data.resize(APInt::getNumWords(bitWidth * inType.getNumElements()) *
+ APInt::APINT_WORD_SIZE);
+
+ uint64_t elementIdx = 0;
+ for (auto value : attr) {
+ auto newInt = mapping(value);
+ assert(newInt.getBitWidth() == bitWidth);
+ attr.writeBits(data.data(), elementIdx * bitWidth, newInt);
+ ++elementIdx;
+ }
+
+ return newArrayType;
+}
+
+DenseElementsAttr DenseIntElementsAttr::mapValues(
+ Type newElementType,
+ llvm::function_ref<APInt(const APInt &)> mapping) const {
+ llvm::SmallVector<char, 8> elementData;
+ auto newArrayType = mappingHelper(
+ mapping, *this, getType(), newElementType, elementData);
+
+ return get(newArrayType, elementData);
+}
+
//===----------------------------------------------------------------------===//
// DenseFPElementsAttr
//===----------------------------------------------------------------------===//
values.assign(begin(), end());
}
+DenseElementsAttr DenseFPElementsAttr::mapValues(
+ Type newElementType,
+ llvm::function_ref<APInt(const APFloat &)> mapping) const {
+ llvm::SmallVector<char, 8> elementData;
+ auto newArrayType = mappingHelper(
+ mapping, *this, getType(), newElementType, elementData);
+
+ return get(newArrayType, elementData);
+}
+
/// Iterator access to the float element values.
DenseFPElementsAttr::iterator DenseFPElementsAttr::begin() const {
auto elementType = getType().getElementType().cast<FloatType>();