// auto oldType = constantOp.getType();
auto newType = rewriter.getTensorType(
reshapeType.getShape(), valueAttr.getType().getElementType());
- auto newAttr =
- mlir::DenseElementsAttr::get(newType, valueAttr.getRawData());
+ auto newAttr = valueAttr.reshape(newType);
rewriter.replaceOpWithNewOp<ConstantOp>(reshape, reshapeType.getShape(),
newAttr);
} else if (auto valueAttr =
// auto oldType = constantOp.getType();
auto newType = rewriter.getTensorType(
reshapeType.getShape(), valueAttr.getType().getElementType());
- auto newAttr =
- mlir::DenseElementsAttr::get(newType, valueAttr.getRawData());
+ auto newAttr = valueAttr.reshape(newType);
rewriter.replaceOpWithNewOp<ConstantOp>(reshape, reshapeType.getShape(),
newAttr);
} else if (auto valueAttr =
public:
using Base::Base;
- /// It assumes the elements in the input array have been truncated to the bits
- /// width specified by the element type. 'type' must be a vector or tensor
- /// with static shape.
- static DenseElementsAttr get(ShapedType type, ArrayRef<char> data);
-
/// Constructs a dense elements attribute from an array of element values.
/// Each element attribute value is expected to be an element of 'type'.
/// 'type' must be a vector or tensor with static shape.
static DenseElementsAttr get(ShapedType type, ArrayRef<Attribute> values);
+ /// Constructs a dense integer elements attribute from an array of integer
+ /// or floating-point values. Each value is expected to be the same bitwidth
+ /// of the element type of 'type'. 'type' must be a vector or tensor with
+ /// static shape.
+ template <typename ShapeT, typename T>
+ static DenseElementsAttr get(ShapeT type, ArrayRef<T> values) {
+ static_assert(std::numeric_limits<T>::is_integer ||
+ llvm::is_one_of<T, float, double>::value,
+ "expected integer or floating point element type");
+
+ assert(type.getNumElements() == static_cast<int64_t>(values.size()));
+ assert(type.getElementTypeBitWidth() == (sizeof(T) * CHAR_BIT));
+ const char *data = reinterpret_cast<const char *>(values.data());
+ return getRawIntOrFloat(type,
+ ArrayRef<char>(data, values.size() * sizeof(T)),
+ /*isInt=*/std::numeric_limits<T>::is_integer);
+ }
+
+ /// Overload of the above 'get' method that is specialized for boolean values.
+ static DenseElementsAttr get(ShapedType type, ArrayRef<bool> values);
+
/// Returns the number of elements held by this attribute.
size_t size() const;
/// element, then a null attribute is returned.
Attribute getValue(ArrayRef<uint64_t> index) const;
+ /// Return the held element values as Attributes in 'values'.
void getValues(SmallVectorImpl<Attribute> &values) const;
+ /// Return a new DenseElementsAttr that has the same data as the current
+ /// attribute, but has been reshaped to 'newType'. The new type must have the
+ /// same total number of elements as well as element type.
+ DenseElementsAttr reshape(ShapedType newType);
+
/// Generates a new DenseElementsAttr by mapping each int value to a new
/// underlying APInt. The new values can represent either a integer or float.
/// This underlying type must be an DenseIntElementsAttr.
return RawElementIterator(*this, size());
}
+ /// Get or create a new dense elements attribute instance with the given raw
+ /// data buffer. 'type' must be a vector or tensor with static shape.
+ static DenseElementsAttr getRaw(ShapedType type, ArrayRef<char> data);
+
+ /// Overload of the raw 'get' method that asserts that the given type is of
+ /// integer or floating-point type.
+ static DenseElementsAttr getRawIntOrFloat(ShapedType type,
+ ArrayRef<char> data, bool isInt);
+
/// Constructs a dense elements attribute from an array of raw APInt values.
/// Each APInt value is expected to have the same bitwidth as the element type
/// of 'type'. 'type' must be a vector or tensor with static shape.
/// shape.
static DenseIntElementsAttr get(ShapedType type, ArrayRef<APInt> values);
- /// Constructs a dense integer elements attribute from an array of integer
- /// values. Each value is expected to be within the bitwidth of the element
- /// type of 'type'. 'type' must be a vector or tensor with static shape.
- static DenseIntElementsAttr get(ShapedType type, ArrayRef<int64_t> values);
-
/// Generates a new DenseElementsAttr by mapping each value attribute, and
/// constructing the DenseElementsAttr given the new element type.
DenseElementsAttr
FunctionAttr getFunctionAttr(Function *value);
FunctionAttr getFunctionAttr(StringRef value);
ElementsAttr getSplatElementsAttr(ShapedType type, Attribute elt);
- ElementsAttr getDenseElementsAttr(ShapedType type, ArrayRef<char> data);
ElementsAttr getDenseElementsAttr(ShapedType type,
ArrayRef<Attribute> values);
ElementsAttr getDenseIntElementsAttr(ShapedType type,
// DenseElementsAttr
//===----------------------------------------------------------------------===//
-DenseElementsAttr DenseElementsAttr::get(ShapedType type, ArrayRef<char> data) {
+DenseElementsAttr DenseElementsAttr::getRaw(ShapedType type,
+ ArrayRef<char> data) {
assert((static_cast<uint64_t>(type.getSizeInBits()) <=
data.size() * APInt::APINT_WORD_SIZE) &&
"Input data bit size should be larger than that type requires");
data);
}
+/// Overload of the raw 'get' method that asserts that the given type is of
+/// integer type.
+DenseElementsAttr DenseElementsAttr::getRawIntOrFloat(ShapedType type,
+ ArrayRef<char> data,
+ bool isInt) {
+ assert(isInt ? type.getElementType().isa<IntegerType>()
+ : type.getElementType().isa<FloatType>());
+ return getRaw(type, data);
+}
+
+DenseElementsAttr DenseElementsAttr::get(ShapedType type,
+ ArrayRef<bool> values) {
+ assert(type.getNumElements() == static_cast<int64_t>(values.size()));
+ assert(type.getElementType().isInteger(1));
+
+ std::vector<char> buff(llvm::divideCeil(values.size(), CHAR_BIT));
+ for (int i = 0, e = values.size(); i != e; ++i)
+ writeBits(buff.data(), i, llvm::APInt(1, values[i]));
+ return getRaw(type, buff);
+}
+
DenseElementsAttr DenseElementsAttr::get(ShapedType type,
ArrayRef<Attribute> values) {
assert(type.getElementType().isIntOrFloat() &&
"expected value to have same bitwidth as element type");
writeBits(data.data(), i * storageBitWidth, intVal);
}
- return get(type, data);
+ return getRaw(type, data);
}
/// Returns the number of elements held by this attribute.
llvm_unreachable("unexpected element type");
}
+/// Return a new DenseElementsAttr that has the same data as the current
+/// attribute, but has been reshaped to 'newType'. The new type must have the
+/// same total number of elements as well as element type.
+DenseElementsAttr DenseElementsAttr::reshape(ShapedType newType) {
+ ShapedType curType = getType();
+ if (curType == newType)
+ return *this;
+
+ (void)curType;
+ assert(newType.getElementType() == curType.getElementType() &&
+ "expected the same element type");
+ assert(newType.getNumElements() == curType.getNumElements() &&
+ "expected the same number of elements");
+ return getRaw(newType, getRawData());
+}
+
DenseElementsAttr DenseElementsAttr::mapValues(
Type newElementType,
llvm::function_ref<APInt(const APInt &)> mapping) const {
assert(values[i].getBitWidth() == bitWidth);
writeBits(elementData.data(), i * storageBitWidth, values[i]);
}
- return get(type, elementData);
+ return getRaw(type, elementData);
}
/// Writes value to the bit position `bitPos` in array `rawData`.
return DenseElementsAttr::get(type, values).cast<DenseIntElementsAttr>();
}
-/// Constructs a dense integer elements attribute from an array of integer
-/// values. Each value is expected to be within the bitwidth of the element
-/// type of 'type'.
-DenseIntElementsAttr DenseIntElementsAttr::get(ShapedType type,
- ArrayRef<int64_t> values) {
- auto eltType = type.getElementType();
- size_t bitWidth = eltType.isBF16() ? 64 : eltType.getIntOrFloatBitWidth();
-
- // Convert the raw integer values to APInt.
- SmallVector<APInt, 8> apIntValues;
- apIntValues.reserve(values.size());
- for (auto value : values)
- apIntValues.emplace_back(APInt(bitWidth, value));
- return get(type, apIntValues);
-}
-
void DenseIntElementsAttr::getValues(SmallVectorImpl<APInt> &values) const {
values.reserve(size());
values.assign(raw_begin(), raw_end());
auto newArrayType =
mappingHelper(mapping, *this, getType(), newElementType, elementData);
- return get(newArrayType, elementData);
+ return getRaw(newArrayType, elementData);
}
/// Method for supporting type inquiry through isa, cast and dyn_cast.
auto newArrayType =
mappingHelper(mapping, *this, getType(), newElementType, elementData);
- return get(newArrayType, elementData);
+ return getRaw(newArrayType, elementData);
}
/// Iterator access to the float element values.
}
ElementsAttr Builder::getDenseElementsAttr(ShapedType type,
- ArrayRef<char> data) {
- return DenseElementsAttr::get(type, data);
-}
-
-ElementsAttr Builder::getDenseElementsAttr(ShapedType type,
ArrayRef<Attribute> values) {
return DenseElementsAttr::get(type, values);
}