namespace detail {
struct AttributeStorage;
-struct UnitAttributeStorage;
struct BoolAttributeStorage;
struct IntegerAttributeStorage;
struct FloatAttributeStorage;
class UnitAttr : public Attribute {
public:
using Attribute::Attribute;
- using ImplType = detail::UnitAttributeStorage;
static UnitAttr get(MLIRContext *context);
/// This should not be used directly.
StorageUniquer &getTypeUniquer();
+ /// Returns the storage uniquer used for constructing attribute storage
+ /// instances. This should not be used directly.
+ StorageUniquer &getAttributeUniquer();
+
private:
const std::unique_ptr<MLIRContextImpl> impl;
/// that builds a unique instance of the derived storage. The arguments to
/// this function are an allocator to store any uniqued data and the key
/// type for this storage.
+///
+/// - Provide a cleanup method:
+/// 'void cleanup()'
+/// that is called when erasing a storage instance. This should cleanup any
+/// fields of the storage as necessary and not attempt to free the memory
+/// of the storage itself.
class StorageUniquer {
public:
StorageUniquer();
/// Gets a uniqued instance of 'Storage'. 'initFn' is an optional parameter
/// that can be used to initialize a newly inserted storage instance. This
- /// overload is used for derived types that have complex storage or uniquing
+ /// function is used for derived types that have complex storage or uniquing
/// constraints.
template <typename Storage, typename... Args>
Storage *getComplex(std::function<void(Storage *)> initFn, unsigned kind,
/// Gets a uniqued instance of 'Storage'. 'initFn' is an optional parameter
/// that can be used to initialize a newly inserted storage instance. This
- /// overload is used for derived types that use no additional storage or
+ /// function is used for derived types that use no additional storage or
/// uniquing outside of the kind.
template <typename Storage>
Storage *getSimple(std::function<void(Storage *)> initFn, unsigned kind) {
return static_cast<Storage *>(getImpl(kind, ctorFn));
}
+ /// Erases a uniqued instance of 'Storage'. This function is used for derived
+ /// types that have complex storage or uniquing constraints.
+ template <typename Storage, typename... Args>
+ void eraseComplex(unsigned kind, Args &&... args) {
+ // Construct a value of the derived key type.
+ auto derivedKey = getKey<Storage>(args...);
+
+ // Create a hash of the kind and the derived key.
+ unsigned hashValue = getHash<Storage>(kind, derivedKey);
+
+ // Generate an equality function for the derived storage.
+ std::function<bool(const BaseStorage *)> isEqual =
+ [&derivedKey](const BaseStorage *existing) {
+ return static_cast<const Storage &>(*existing) == derivedKey;
+ };
+
+ // Attempt to erase the storage instance.
+ eraseImpl(kind, hashValue, isEqual, [](BaseStorage *storage) {
+ static_cast<Storage *>(storage)->cleanup();
+ });
+ }
+
private:
/// Implementation for getting/creating an instance of a derived type with
/// complex storage.
BaseStorage *getImpl(unsigned kind,
std::function<BaseStorage *(StorageAllocator &)> ctorFn);
+ /// Implementation for erasing an instance of a derived type with complex
+ /// storage.
+ void eraseImpl(unsigned kind, unsigned hashValue,
+ llvm::function_ref<bool(const BaseStorage *)> isEqual,
+ std::function<void(BaseStorage *)> cleanupFn);
+
/// The internal implementation class.
std::unique_ptr<detail::StorageUniquerImpl> impl;
#include "mlir/IR/IntegerSet.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/StandardTypes.h"
+#include "mlir/Support/StorageUniquer.h"
#include "llvm/ADT/APFloat.h"
#include "llvm/Support/TrailingObjects.h"
namespace mlir {
namespace detail {
-
/// Base storage class appearing in an attribute.
-struct AttributeStorage {
- AttributeStorage(Attribute::Kind kind, bool isOrContainsFunctionCache = false)
- : kind(kind), isOrContainsFunctionCache(isOrContainsFunctionCache) {}
-
- Attribute::Kind kind : 8;
+struct AttributeStorage : public StorageUniquer::BaseStorage {
+ AttributeStorage(bool isOrContainsFunctionCache = false)
+ : isOrContainsFunctionCache(isOrContainsFunctionCache) {}
/// This field is true if this is, or contains, a function attribute.
bool isOrContainsFunctionCache : 1;
};
-/// An attribute representing a unit value.
-struct UnitAttributeStorage : public AttributeStorage {
- UnitAttributeStorage() : AttributeStorage(Attribute::Kind::Unit) {}
+// A utility class to get, or create, unique instances of attributes within an
+// MLIRContext. This class manages all creation and uniquing of attributes.
+class AttributeUniquer {
+public:
+ /// Get an uniqued instance of attribute T. This overload is used for
+ /// derived attributes that have complex storage or uniquing constraints.
+ template <typename T, typename... Args>
+ static typename std::enable_if<
+ !std::is_same<typename T::ImplType, AttributeStorage>::value, T>::type
+ get(MLIRContext *ctx, Attribute::Kind kind, Args &&... args) {
+ return ctx->getAttributeUniquer().getComplex<typename T::ImplType>(
+ /*initFn=*/{}, static_cast<unsigned>(kind),
+ std::forward<Args>(args)...);
+ }
+
+ /// Get an uniqued instance of attribute T. This overload is used for
+ /// derived attributes that use the AttributeStorage directly and thus need no
+ /// additional storage or uniquing.
+ template <typename T, typename... Args>
+ static typename std::enable_if<
+ std::is_same<typename T::ImplType, AttributeStorage>::value, T>::type
+ get(MLIRContext *ctx, Attribute::Kind kind) {
+ return ctx->getAttributeUniquer().getSimple<AttributeStorage>(
+ /*initFn=*/{}, static_cast<unsigned>(kind));
+ }
+
+ /// Erase a uniqued instance of attribute T. This overload is used for
+ /// derived attributes that have complex storage or uniquing constraints.
+ template <typename T, typename... Args>
+ static typename std::enable_if<
+ !std::is_same<typename T::ImplType, AttributeStorage>::value>::type
+ erase(MLIRContext *ctx, Attribute::Kind kind, Args &&... args) {
+ return ctx->getAttributeUniquer().eraseComplex<typename T::ImplType>(
+ static_cast<unsigned>(kind), std::forward<Args>(args)...);
+ }
};
+using AttributeStorageAllocator = StorageUniquer::StorageAllocator;
+
/// An attribute representing a boolean value.
struct BoolAttributeStorage : public AttributeStorage {
- BoolAttributeStorage(Type type, bool value)
- : AttributeStorage(Attribute::Kind::Bool), type(type), value(value) {}
- const Type type;
+ using KeyTy = std::pair<MLIRContext *, bool>;
+
+ BoolAttributeStorage(Type type, bool value) : type(type), value(value) {}
+
+ /// We only check equality for and hash with the boolean key parameter.
+ bool operator==(const KeyTy &key) const { return key.second == value; }
+ static unsigned hashKey(const KeyTy &key) {
+ return llvm::hash_value(key.second);
+ }
+
+ static BoolAttributeStorage *construct(AttributeStorageAllocator &allocator,
+ const KeyTy &key) {
+ return new (allocator.allocate<BoolAttributeStorage>())
+ BoolAttributeStorage(IntegerType::get(1, key.first), key.second);
+ }
+
+ Type type;
bool value;
};
struct IntegerAttributeStorage final
: public AttributeStorage,
public llvm::TrailingObjects<IntegerAttributeStorage, uint64_t> {
+ using KeyTy = std::pair<Type, APInt>;
+
IntegerAttributeStorage(Type type, size_t numObjects)
- : AttributeStorage(Attribute::Kind::Integer), type(type),
- numObjects(numObjects) {
+ : type(type), numObjects(numObjects) {
assert((type.isIndex() || type.isa<IntegerType>()) && "invalid type");
}
- const Type type;
- size_t numObjects;
+ /// Key equality and hash functions.
+ bool operator==(const KeyTy &key) const {
+ return key == KeyTy(type, getValue());
+ }
+ static unsigned hashKey(const KeyTy &key) {
+ return llvm::hash_combine(key.first, llvm::hash_value(key.second));
+ }
+
+ /// Construct a new storage instance.
+ static IntegerAttributeStorage *
+ construct(AttributeStorageAllocator &allocator, const KeyTy &key) {
+ Type type;
+ APInt value;
+ std::tie(type, value) = key;
+
+ auto elements = ArrayRef<uint64_t>(value.getRawData(), value.getNumWords());
+ auto size =
+ IntegerAttributeStorage::totalSizeToAlloc<uint64_t>(elements.size());
+ auto rawMem = allocator.allocate(size, alignof(IntegerAttributeStorage));
+ auto result = ::new (rawMem) IntegerAttributeStorage(type, elements.size());
+ std::uninitialized_copy(elements.begin(), elements.end(),
+ result->getTrailingObjects<uint64_t>());
+ return result;
+ }
/// Returns an APInt representing the stored value.
APInt getValue() const {
return APInt(type.getIntOrFloatBitWidth(),
{getTrailingObjects<uint64_t>(), numObjects});
}
+
+ Type type;
+ size_t numObjects;
};
/// An attribute representing a floating point value.
struct FloatAttributeStorage final
: public AttributeStorage,
public llvm::TrailingObjects<FloatAttributeStorage, uint64_t> {
+ using KeyTy = std::pair<Type, APFloat>;
+
FloatAttributeStorage(const llvm::fltSemantics &semantics, Type type,
size_t numObjects)
- : AttributeStorage(Attribute::Kind::Float), semantics(semantics),
- type(type.cast<FloatType>()), numObjects(numObjects) {}
- const llvm::fltSemantics &semantics;
- const FloatType type;
- size_t numObjects;
+ : semantics(semantics), type(type.cast<FloatType>()),
+ numObjects(numObjects) {}
+
+ /// Key equality and hash functions.
+ bool operator==(const KeyTy &key) const {
+ return key.first == type && key.second.bitwiseIsEqual(getValue());
+ }
+ static unsigned hashKey(const KeyTy &key) {
+ return llvm::hash_combine(key.first, llvm::hash_value(key.second));
+ }
+
+ /// Construct a new storage instance.
+ static FloatAttributeStorage *construct(AttributeStorageAllocator &allocator,
+ const KeyTy &key) {
+ const auto &apint = key.second.bitcastToAPInt();
+
+ // Here one word's bitwidth equals to that of uint64_t.
+ auto elements = ArrayRef<uint64_t>(apint.getRawData(), apint.getNumWords());
+
+ auto byteSize =
+ FloatAttributeStorage::totalSizeToAlloc<uint64_t>(elements.size());
+ auto rawMem = allocator.allocate(byteSize, alignof(FloatAttributeStorage));
+ auto result = ::new (rawMem) FloatAttributeStorage(
+ key.second.getSemantics(), key.first, elements.size());
+ std::uninitialized_copy(elements.begin(), elements.end(),
+ result->getTrailingObjects<uint64_t>());
+ return result;
+ }
/// Returns an APFloat representing the stored value.
APFloat getValue() const {
{getTrailingObjects<uint64_t>(), numObjects});
return APFloat(semantics, val);
}
+
+ const llvm::fltSemantics &semantics;
+ FloatType type;
+ size_t numObjects;
};
/// An attribute representing a string value.
struct StringAttributeStorage : public AttributeStorage {
- StringAttributeStorage(StringRef value)
- : AttributeStorage(Attribute::Kind::String), value(value) {}
+ using KeyTy = StringRef;
+
+ StringAttributeStorage(StringRef value) : value(value) {}
+
+ /// Key equality function.
+ bool operator==(const KeyTy &key) const { return key == value; }
+
+ /// Construct a new storage instance.
+ static StringAttributeStorage *construct(AttributeStorageAllocator &allocator,
+ const KeyTy &key) {
+ return new (allocator.allocate<StringAttributeStorage>())
+ StringAttributeStorage(allocator.copyInto(key));
+ }
+
StringRef value;
};
/// An attribute representing an array of other attributes.
struct ArrayAttributeStorage : public AttributeStorage {
+ using KeyTy = ArrayRef<Attribute>;
+
ArrayAttributeStorage(bool hasFunctionAttr, ArrayRef<Attribute> value)
- : AttributeStorage(Attribute::Kind::Array, hasFunctionAttr),
- value(value) {}
+ : AttributeStorage(hasFunctionAttr), value(value) {}
+
+ /// Key equality function.
+ bool operator==(const KeyTy &key) const { return key == value; }
+
+ /// Construct a new storage instance.
+ static ArrayAttributeStorage *construct(AttributeStorageAllocator &allocator,
+ const KeyTy &key) {
+ // Check to see if any of the elements have a function attr.
+ bool hasFunctionAttr = llvm::any_of(
+ key, [](Attribute elt) { return elt.isOrContainsFunction(); });
+
+ // Initialize the memory using placement new.
+ return new (allocator.allocate<ArrayAttributeStorage>())
+ ArrayAttributeStorage(hasFunctionAttr, allocator.copyInto(key));
+ }
+
ArrayRef<Attribute> value;
};
// An attribute representing a reference to an affine map.
struct AffineMapAttributeStorage : public AttributeStorage {
- AffineMapAttributeStorage(AffineMap value)
- : AttributeStorage(Attribute::Kind::AffineMap), value(value) {}
+ using KeyTy = AffineMap;
+
+ AffineMapAttributeStorage(AffineMap value) : value(value) {}
+
+ /// Key equality function.
+ bool operator==(const KeyTy &key) const { return key == value; }
+
+ /// Construct a new storage instance.
+ static AffineMapAttributeStorage *
+ construct(AttributeStorageAllocator &allocator, KeyTy key) {
+ return new (allocator.allocate<AffineMapAttributeStorage>())
+ AffineMapAttributeStorage(key);
+ }
+
AffineMap value;
};
// An attribute representing a reference to an integer set.
struct IntegerSetAttributeStorage : public AttributeStorage {
- IntegerSetAttributeStorage(IntegerSet value)
- : AttributeStorage(Attribute::Kind::IntegerSet), value(value) {}
+ using KeyTy = IntegerSet;
+
+ IntegerSetAttributeStorage(IntegerSet value) : value(value) {}
+
+ /// Key equality function.
+ bool operator==(const KeyTy &key) const { return key == value; }
+
+ /// Construct a new storage instance.
+ static IntegerSetAttributeStorage *
+ construct(AttributeStorageAllocator &allocator, KeyTy key) {
+ return new (allocator.allocate<IntegerSetAttributeStorage>())
+ IntegerSetAttributeStorage(key);
+ }
+
IntegerSet value;
};
/// An attribute representing a reference to a type.
struct TypeAttributeStorage : public AttributeStorage {
- TypeAttributeStorage(Type value)
- : AttributeStorage(Attribute::Kind::Type), value(value) {}
+ using KeyTy = Type;
+
+ TypeAttributeStorage(Type value) : value(value) {}
+
+ /// Key equality function.
+ bool operator==(const KeyTy &key) const { return key == value; }
+
+ /// Construct a new storage instance.
+ static TypeAttributeStorage *construct(AttributeStorageAllocator &allocator,
+ KeyTy key) {
+ return new (allocator.allocate<TypeAttributeStorage>())
+ TypeAttributeStorage(key);
+ }
+
Type value;
};
/// An attribute representing a reference to a function.
struct FunctionAttributeStorage : public AttributeStorage {
+ using KeyTy = Function *;
+
FunctionAttributeStorage(Function *value)
- : AttributeStorage(Attribute::Kind::Function,
- /*isOrContainsFunctionCache=*/true),
- value(value) {}
+ : AttributeStorage(/*isOrContainsFunctionCache=*/true), value(value) {}
+
+ /// Key equality function.
+ bool operator==(const KeyTy &key) const { return key == value; }
+
+ /// Construct a new storage instance.
+ static FunctionAttributeStorage *
+ construct(AttributeStorageAllocator &allocator, KeyTy key) {
+ return new (allocator.allocate<FunctionAttributeStorage>())
+ FunctionAttributeStorage(key);
+ }
+
+ /// Storage cleanup function.
+ void cleanup() {
+ // Null out the function reference in the attribute to avoid dangling
+ // pointers.
+ value = nullptr;
+ }
+
Function *value;
};
/// A base attribute representing a reference to a vector or tensor constant.
struct ElementsAttributeStorage : public AttributeStorage {
- ElementsAttributeStorage(Attribute::Kind kind, VectorOrTensorType type)
- : AttributeStorage(kind), type(type) {}
+ ElementsAttributeStorage(VectorOrTensorType type) : type(type) {}
VectorOrTensorType type;
};
/// An attribute representing a reference to a vector or tensor constant,
/// inwhich all elements have the same value.
struct SplatElementsAttributeStorage : public ElementsAttributeStorage {
+ using KeyTy = std::pair<VectorOrTensorType, Attribute>;
+
SplatElementsAttributeStorage(VectorOrTensorType type, Attribute elt)
- : ElementsAttributeStorage(Attribute::Kind::SplatElements, type),
- elt(elt) {}
+ : ElementsAttributeStorage(type), elt(elt) {}
+
+ /// Key equality and hash functions.
+ bool operator==(const KeyTy &key) const {
+ return key == std::make_pair(type, elt);
+ }
+ static unsigned hashKey(const KeyTy &key) {
+ return llvm::hash_combine(key.first, key.second);
+ }
+
+ /// Construct a new storage instance.
+ static SplatElementsAttributeStorage *
+ construct(AttributeStorageAllocator &allocator, KeyTy key) {
+ return new (allocator.allocate<SplatElementsAttributeStorage>())
+ SplatElementsAttributeStorage(key.first, key.second);
+ }
+
Attribute elt;
};
/// An attribute representing a reference to a dense vector or tensor object.
struct DenseElementsAttributeStorage : public ElementsAttributeStorage {
- DenseElementsAttributeStorage(Attribute::Kind kind, VectorOrTensorType type,
- ArrayRef<char> data)
- : ElementsAttributeStorage(kind, type), data(data) {}
+ using KeyTy = std::pair<VectorOrTensorType, ArrayRef<char>>;
+
+ DenseElementsAttributeStorage(VectorOrTensorType ty, ArrayRef<char> data)
+ : ElementsAttributeStorage(ty), data(data) {}
+
+ /// Key equality and hash functions.
+ bool operator==(const KeyTy &key) const { return key == KeyTy(type, data); }
+ static unsigned hashKey(const KeyTy &key) {
+ return llvm::hash_combine(key.first, key.second);
+ }
+
+ /// Construct a new storage instance.
+ static DenseElementsAttributeStorage *
+ construct(AttributeStorageAllocator &allocator, KeyTy key) {
+ // If the data buffer is non-empty, we copy it into the allocator.
+ ArrayRef<char> data = key.second;
+ if (!data.empty()) {
+ // Rounding up the allocate size to multiples of APINT_WORD_SIZE, so
+ // the `readBits` will not fail when it accesses multiples of
+ // APINT_WORD_SIZE each time.
+ size_t sizeToAllocate =
+ llvm::alignTo(data.size(), APInt::APINT_WORD_SIZE);
+ auto *rawCopy = (char *)allocator.allocate(sizeToAllocate, 64);
+ std::uninitialized_copy(data.begin(), data.end(), rawCopy);
+ data = {rawCopy, data.size()};
+ }
+ return new (allocator.allocate<DenseElementsAttributeStorage>())
+ DenseElementsAttributeStorage(key.first, data);
+ }
+
ArrayRef<char> data;
};
/// An attribute representing a reference to a tensor constant with opaque
/// content.
struct OpaqueElementsAttributeStorage : public ElementsAttributeStorage {
+ using KeyTy = std::tuple<VectorOrTensorType, Dialect *, StringRef>;
+
OpaqueElementsAttributeStorage(VectorOrTensorType type, Dialect *dialect,
StringRef bytes)
- : ElementsAttributeStorage(Attribute::Kind::OpaqueElements, type),
- dialect(dialect), bytes(bytes) {}
+ : ElementsAttributeStorage(type), dialect(dialect), bytes(bytes) {}
+
+ /// Key equality and hash functions.
+ bool operator==(const KeyTy &key) const {
+ return key == std::make_tuple(type, dialect, bytes);
+ }
+ static unsigned hashKey(const KeyTy &key) {
+ return llvm::hash_combine(std::get<0>(key), std::get<1>(key),
+ std::get<2>(key));
+ }
+
+ /// Construct a new storage instance.
+ static OpaqueElementsAttributeStorage *
+ construct(AttributeStorageAllocator &allocator, KeyTy key) {
+ // TODO(b/131468830): Provide a way to avoid copying content of large opaque
+ // tensors This will likely require a new reference attribute kind.
+ return new (allocator.allocate<OpaqueElementsAttributeStorage>())
+ OpaqueElementsAttributeStorage(std::get<0>(key), std::get<1>(key),
+ allocator.copyInto(std::get<2>(key)));
+ }
+
Dialect *dialect;
StringRef bytes;
};
/// An attribute representing a reference to a sparse vector or tensor object.
struct SparseElementsAttributeStorage : public ElementsAttributeStorage {
+ using KeyTy =
+ std::tuple<VectorOrTensorType, DenseIntElementsAttr, DenseElementsAttr>;
+
SparseElementsAttributeStorage(VectorOrTensorType type,
DenseIntElementsAttr indices,
DenseElementsAttr values)
- : ElementsAttributeStorage(Attribute::Kind::SparseElements, type),
- indices(indices), values(values) {}
+ : ElementsAttributeStorage(type), indices(indices), values(values) {}
+
+ /// Key equality and hash functions.
+ bool operator==(const KeyTy &key) const {
+ return key == std::make_tuple(type, indices, values);
+ }
+ static unsigned hashKey(const KeyTy &key) {
+ return llvm::hash_combine(std::get<0>(key), std::get<1>(key),
+ std::get<2>(key));
+ }
+
+ /// Construct a new storage instance.
+ static SparseElementsAttributeStorage *
+ construct(AttributeStorageAllocator &allocator, KeyTy key) {
+ return new (allocator.allocate<SparseElementsAttributeStorage>())
+ SparseElementsAttributeStorage(std::get<0>(key), std::get<1>(key),
+ std::get<2>(key));
+ }
+
DenseIntElementsAttr indices;
DenseElementsAttr values;
};
using namespace mlir;
using namespace mlir::detail;
-Attribute::Kind Attribute::getKind() const { return attr->kind; }
+Attribute::Kind Attribute::getKind() const {
+ return static_cast<Kind>(attr->getKind());
+}
bool Attribute::isOrContainsFunction() const {
return attr->isOrContainsFunctionCache;
}
//===----------------------------------------------------------------------===//
+// UnitAttr
+//===----------------------------------------------------------------------===//
+
+UnitAttr UnitAttr::get(MLIRContext *context) {
+ return AttributeUniquer::get<UnitAttr>(context, Attribute::Kind::Unit);
+}
+
+//===----------------------------------------------------------------------===//
// NumericAttr
//===----------------------------------------------------------------------===//
// BoolAttr
//===----------------------------------------------------------------------===//
+BoolAttr BoolAttr::get(bool value, MLIRContext *context) {
+ // Note: The context is also used within the BoolAttrStorage.
+ return AttributeUniquer::get<BoolAttr>(context, Attribute::Kind::Bool,
+ context, value);
+}
+
bool BoolAttr::getValue() const { return static_cast<ImplType *>(attr)->value; }
Type BoolAttr::getType() const { return static_cast<ImplType *>(attr)->type; }
// IntegerAttr
//===----------------------------------------------------------------------===//
+IntegerAttr IntegerAttr::get(Type type, const APInt &value) {
+ return AttributeUniquer::get<IntegerAttr>(
+ type.getContext(), Attribute::Kind::Integer, type, value);
+}
+
+IntegerAttr IntegerAttr::get(Type type, int64_t value) {
+ // This uses 64 bit APInts by default for index type.
+ if (type.isIndex())
+ return get(type, APInt(64, value));
+
+ auto intType = type.cast<IntegerType>();
+ return get(type, APInt(intType.getWidth(), value));
+}
+
APInt IntegerAttr::getValue() const {
return static_cast<ImplType *>(attr)->getValue();
}
// FloatAttr
//===----------------------------------------------------------------------===//
+FloatAttr FloatAttr::get(Type type, const APFloat &value) {
+ assert(&type.cast<FloatType>().getFloatSemantics() == &value.getSemantics() &&
+ "FloatAttr type doesn't match the type implied by its value");
+ return AttributeUniquer::get<FloatAttr>(type.getContext(),
+ Attribute::Kind::Float, type, value);
+}
+
+static FloatAttr getFloatAttr(Type type, double value,
+ llvm::Optional<Location> loc) {
+ if (!type.isa<FloatType>()) {
+ if (loc)
+ type.getContext()->emitError(*loc, "expected floating point type");
+ return nullptr;
+ }
+
+ // Treat BF16 as double because it is not supported in LLVM's APFloat.
+ // TODO(b/121118307): add BF16 support to APFloat?
+ if (type.isBF16() || type.isF64())
+ return FloatAttr::get(type, APFloat(value));
+
+ // This handles, e.g., F16 because there is no APFloat constructor for it.
+ bool unused;
+ APFloat val(value);
+ val.convert(type.cast<FloatType>().getFloatSemantics(),
+ APFloat::rmNearestTiesToEven, &unused);
+ return FloatAttr::get(type, val);
+}
+
+FloatAttr FloatAttr::getChecked(Type type, double value, Location loc) {
+ return getFloatAttr(type, value, loc);
+}
+
+FloatAttr FloatAttr::get(Type type, double value) {
+ auto res = getFloatAttr(type, value, /*loc=*/llvm::None);
+ assert(res && "failed to construct float attribute");
+ return res;
+}
+
APFloat FloatAttr::getValue() const {
return static_cast<ImplType *>(attr)->getValue();
}
// StringAttr
//===----------------------------------------------------------------------===//
+StringAttr StringAttr::get(StringRef bytes, MLIRContext *context) {
+ return AttributeUniquer::get<StringAttr>(context, Attribute::Kind::String,
+ bytes);
+}
+
StringRef StringAttr::getValue() const {
return static_cast<ImplType *>(attr)->value;
}
// ArrayAttr
//===----------------------------------------------------------------------===//
+ArrayAttr ArrayAttr::get(ArrayRef<Attribute> value, MLIRContext *context) {
+ return AttributeUniquer::get<ArrayAttr>(context, Attribute::Kind::Array,
+ value);
+}
+
ArrayRef<Attribute> ArrayAttr::getValue() const {
return static_cast<ImplType *>(attr)->value;
}
// AffineMapAttr
//===----------------------------------------------------------------------===//
+AffineMapAttr AffineMapAttr::get(AffineMap value) {
+ return AttributeUniquer::get<AffineMapAttr>(
+ value.getResult(0).getContext(), Attribute::Kind::AffineMap, value);
+}
+
AffineMap AffineMapAttr::getValue() const {
return static_cast<ImplType *>(attr)->value;
}
// IntegerSetAttr
//===----------------------------------------------------------------------===//
+IntegerSetAttr IntegerSetAttr::get(IntegerSet value) {
+ return AttributeUniquer::get<IntegerSetAttr>(
+ value.getConstraint(0).getContext(), Attribute::Kind::IntegerSet, value);
+}
+
IntegerSet IntegerSetAttr::getValue() const {
return static_cast<ImplType *>(attr)->value;
}
// TypeAttr
//===----------------------------------------------------------------------===//
+TypeAttr TypeAttr::get(Type value, MLIRContext *context) {
+ return AttributeUniquer::get<TypeAttr>(context, Attribute::Kind::Type, value);
+}
+
Type TypeAttr::getValue() const { return static_cast<ImplType *>(attr)->value; }
//===----------------------------------------------------------------------===//
// FunctionAttr
//===----------------------------------------------------------------------===//
+FunctionAttr FunctionAttr::get(Function *value, MLIRContext *context) {
+ assert(value && "Cannot get FunctionAttr for a null function");
+ return AttributeUniquer::get<FunctionAttr>(context, Attribute::Kind::Function,
+ value);
+}
+
+/// This function is used by the internals of the Function class to null out
+/// attributes referring to functions that are about to be deleted.
+void FunctionAttr::dropFunctionReference(Function *value) {
+ AttributeUniquer::erase<FunctionAttr>(value->getContext(),
+ Attribute::Kind::Function, value);
+}
+
Function *FunctionAttr::getValue() const {
return static_cast<ImplType *>(attr)->value;
}
// SplatElementsAttr
//===----------------------------------------------------------------------===//
+SplatElementsAttr SplatElementsAttr::get(VectorOrTensorType type,
+ Attribute elt) {
+ assert(elt.cast<NumericAttr>().getType() == type.getElementType() &&
+ "value should be of the given type");
+ return AttributeUniquer::get<SplatElementsAttr>(
+ type.getContext(), Attribute::Kind::SplatElements, type, elt);
+}
+
Attribute SplatElementsAttr::getValue() const {
return static_cast<ImplType *>(attr)->elt;
}
// DenseElementsAttr
//===----------------------------------------------------------------------===//
+DenseElementsAttr DenseElementsAttr::get(VectorOrTensorType type,
+ ArrayRef<char> data) {
+ assert((type.getSizeInBits() <= data.size() * APInt::APINT_WORD_SIZE) &&
+ "Input data bit size should be larger than that type requires");
+
+ Attribute::Kind kind;
+ switch (type.getElementType().getKind()) {
+ case StandardTypes::BF16:
+ case StandardTypes::F16:
+ case StandardTypes::F32:
+ case StandardTypes::F64:
+ kind = Attribute::Kind::DenseFPElements;
+ break;
+ case StandardTypes::Integer:
+ kind = Attribute::Kind::DenseIntElements;
+ break;
+ default:
+ llvm_unreachable("unexpected element type");
+ }
+ return AttributeUniquer::get<DenseElementsAttr>(type.getContext(), kind, type,
+ data);
+}
+
+DenseElementsAttr DenseElementsAttr::get(VectorOrTensorType type,
+ ArrayRef<Attribute> values) {
+ assert(type.getElementType().isIntOrFloat() &&
+ "expected int or float element type");
+ assert(values.size() == type.getNumElements() &&
+ "expected 'values' to contain the same number of elements as 'type'");
+
+ // FIXME(b/121118307): using 64 bits for BF16 because it is currently stored
+ // with double semantics.
+ auto eltType = type.getElementType();
+ size_t bitWidth = eltType.isBF16() ? 64 : eltType.getIntOrFloatBitWidth();
+
+ // Compress the attribute values into a character buffer.
+ SmallVector<char, 8> data(APInt::getNumWords(bitWidth * values.size()) *
+ APInt::APINT_WORD_SIZE);
+ APInt intVal;
+ for (unsigned i = 0, e = values.size(); i < e; ++i) {
+ switch (eltType.getKind()) {
+ case StandardTypes::BF16:
+ case StandardTypes::F16:
+ case StandardTypes::F32:
+ case StandardTypes::F64:
+ assert(eltType == values[i].cast<FloatAttr>().getType() &&
+ "expected attribute value to have element type");
+ intVal = values[i].cast<FloatAttr>().getValue().bitcastToAPInt();
+ break;
+ case StandardTypes::Integer:
+ assert(eltType == values[i].cast<IntegerAttr>().getType() &&
+ "expected attribute value to have element type");
+ intVal = values[i].cast<IntegerAttr>().getValue();
+ break;
+ default:
+ llvm_unreachable("unexpected element type");
+ }
+ assert(intVal.getBitWidth() == bitWidth &&
+ "expected value to have same bitwidth as element type");
+ writeBits(data.data(), i * bitWidth, intVal);
+ }
+ return get(type, data);
+}
+
/// Returns the number of elements held by this attribute.
size_t DenseElementsAttr::size() const { return getType().getNumElements(); }
// OpaqueElementsAttr
//===----------------------------------------------------------------------===//
+OpaqueElementsAttr OpaqueElementsAttr::get(Dialect *dialect,
+ VectorOrTensorType type,
+ StringRef bytes) {
+ assert(TensorType::isValidElementType(type.getElementType()) &&
+ "Input element type should be a valid tensor element type");
+ return AttributeUniquer::get<OpaqueElementsAttr>(
+ type.getContext(), Attribute::Kind::OpaqueElements, type, dialect, bytes);
+}
+
StringRef OpaqueElementsAttr::getValue() const {
return static_cast<ImplType *>(attr)->bytes;
}
// SparseElementsAttr
//===----------------------------------------------------------------------===//
+SparseElementsAttr SparseElementsAttr::get(VectorOrTensorType type,
+ DenseIntElementsAttr indices,
+ DenseElementsAttr values) {
+ assert(indices.getType().getElementType().isInteger(64) &&
+ "expected sparse indices to be 64-bit integer values");
+ return AttributeUniquer::get<SparseElementsAttr>(
+ type.getContext(), Attribute::Kind::SparseElements, type, indices,
+ values);
+}
+
DenseIntElementsAttr SparseElementsAttr::getIndices() const {
return static_cast<ImplType *>(attr)->indices;
}
}
};
-struct FloatAttrKeyInfo : DenseMapInfo<FloatAttributeStorage *> {
- // Float attributes are uniqued based on wrapped APFloat.
- using KeyTy = std::pair<Type, APFloat>;
- using DenseMapInfo<FloatAttributeStorage *>::isEqual;
-
- static unsigned getHashValue(FloatAttributeStorage *key) {
- return getHashValue(KeyTy(key->type, key->getValue()));
- }
-
- static unsigned getHashValue(KeyTy key) {
- return hash_combine(key.first, llvm::hash_value(key.second));
- }
-
- static bool isEqual(const KeyTy &lhs, const FloatAttributeStorage *rhs) {
- if (rhs == getEmptyKey() || rhs == getTombstoneKey())
- return false;
- return lhs.first == rhs->type && lhs.second.bitwiseIsEqual(rhs->getValue());
- }
-};
-
-struct IntegerAttrKeyInfo : DenseMapInfo<IntegerAttributeStorage *> {
- // Integer attributes are uniqued based on wrapped APInt.
- using KeyTy = std::pair<Type, APInt>;
- using DenseMapInfo<IntegerAttributeStorage *>::isEqual;
-
- static unsigned getHashValue(IntegerAttributeStorage *key) {
- return getHashValue(KeyTy(key->type, key->getValue()));
- }
-
- static unsigned getHashValue(KeyTy key) {
- return hash_combine(key.first, llvm::hash_value(key.second));
- }
-
- static bool isEqual(const KeyTy &lhs, const IntegerAttributeStorage *rhs) {
- if (rhs == getEmptyKey() || rhs == getTombstoneKey())
- return false;
- assert((lhs.first.isIndex() || (lhs.first.isa<IntegerType>() &&
- lhs.first.cast<IntegerType>().getWidth() ==
- lhs.second.getBitWidth())) &&
- "mismatching integer type and value bitwidth");
- return lhs.first == rhs->type && lhs.second == rhs->getValue();
- }
-};
-
-struct ArrayAttrKeyInfo : DenseMapInfo<ArrayAttributeStorage *> {
- // Array attributes are uniqued based on their elements.
- using KeyTy = ArrayRef<Attribute>;
- using DenseMapInfo<ArrayAttributeStorage *>::isEqual;
-
- static unsigned getHashValue(ArrayAttributeStorage *key) {
- return getHashValue(KeyTy(key->value));
- }
-
- static unsigned getHashValue(KeyTy key) {
- return hash_combine_range(key.begin(), key.end());
- }
-
- static bool isEqual(const KeyTy &lhs, const ArrayAttributeStorage *rhs) {
- if (rhs == getEmptyKey() || rhs == getTombstoneKey())
- return false;
- return lhs == rhs->value;
- }
-};
-
struct AttributeListKeyInfo : DenseMapInfo<AttributeListStorage *> {
// Array attributes are uniqued based on their elements.
using KeyTy = ArrayRef<NamedAttribute>;
}
};
-struct DenseElementsAttrInfo : DenseMapInfo<DenseElementsAttributeStorage *> {
- using KeyTy = std::pair<VectorOrTensorType, ArrayRef<char>>;
- using DenseMapInfo<DenseElementsAttributeStorage *>::isEqual;
-
- static unsigned getHashValue(DenseElementsAttributeStorage *key) {
- return getHashValue(KeyTy(key->type, key->data));
- }
-
- static unsigned getHashValue(KeyTy key) {
- return hash_combine(
- key.first, hash_combine_range(key.second.begin(), key.second.end()));
- }
-
- static bool isEqual(const KeyTy &lhs,
- const DenseElementsAttributeStorage *rhs) {
- if (rhs == getEmptyKey() || rhs == getTombstoneKey())
- return false;
- return lhs == std::make_pair(rhs->type, rhs->data);
- }
-};
-
-struct OpaqueElementsAttrInfo : DenseMapInfo<OpaqueElementsAttributeStorage *> {
- // Opaque element attributes are uniqued based on their dialect, type and
- // value.
- using KeyTy = std::tuple<Dialect *, VectorOrTensorType, StringRef>;
- using DenseMapInfo<OpaqueElementsAttributeStorage *>::isEqual;
-
- static unsigned getHashValue(OpaqueElementsAttributeStorage *key) {
- return getHashValue(KeyTy(key->dialect, key->type, key->bytes));
- }
-
- static unsigned getHashValue(KeyTy key) {
- auto bytes = std::get<2>(key);
- return hash_combine(std::get<0>(key), std::get<1>(key),
- hash_combine_range(bytes.begin(), bytes.end()));
- }
-
- static bool isEqual(const KeyTy &lhs,
- const OpaqueElementsAttributeStorage *rhs) {
- if (rhs == getEmptyKey() || rhs == getTombstoneKey())
- return false;
- return lhs == std::make_tuple(rhs->dialect, rhs->type, rhs->bytes);
- }
-};
-
struct CallSiteLocationKeyInfo : DenseMapInfo<CallSiteLocationStorage *> {
// Call locations are uniqued based on their held concret location
// and the caller location.
//===--------------------------------------------------------------------===//
// Attribute uniquing
//===--------------------------------------------------------------------===//
+ StorageUniquer attributeUniquer;
- // Attribute allocator and mutex for thread safety.
+ // Attribute list allocator and mutex for thread safety.
llvm::BumpPtrAllocator attributeAllocator;
llvm::sys::SmartRWMutex<true> attributeMutex;
- UnitAttributeStorage unitAttr;
- BoolAttributeStorage *boolAttrs[2] = {nullptr};
- DenseSet<IntegerAttributeStorage *, IntegerAttrKeyInfo> integerAttrs;
- DenseSet<FloatAttributeStorage *, FloatAttrKeyInfo> floatAttrs;
- llvm::StringMap<StringAttributeStorage *> stringAttrs;
- using ArrayAttrSet = DenseSet<ArrayAttributeStorage *, ArrayAttrKeyInfo>;
- ArrayAttrSet arrayAttrs;
- DenseMap<AffineMap, AffineMapAttributeStorage *> affineMapAttrs;
- DenseMap<IntegerSet, IntegerSetAttributeStorage *> integerSetAttrs;
- DenseMap<Type, TypeAttributeStorage *> typeAttrs;
using AttributeListSet =
DenseSet<AttributeListStorage *, AttributeListKeyInfo>;
AttributeListSet attributeLists;
- DenseMap<Function *, FunctionAttributeStorage *> functionAttrs;
- DenseMap<std::pair<Type, Attribute>, SplatElementsAttributeStorage *>
- splatElementsAttrs;
- using DenseElementsAttrSet =
- DenseSet<DenseElementsAttributeStorage *, DenseElementsAttrInfo>;
- DenseElementsAttrSet denseElementsAttrs;
- using OpaqueElementsAttrSet =
- DenseSet<OpaqueElementsAttributeStorage *, OpaqueElementsAttrInfo>;
- OpaqueElementsAttrSet opaqueElementsAttrs;
- DenseMap<std::tuple<Type, Attribute, Attribute>,
- SparseElementsAttributeStorage *>
- sparseElementsAttrs;
public:
MLIRContextImpl()
// Attribute uniquing
//===----------------------------------------------------------------------===//
-UnitAttr UnitAttr::get(MLIRContext *context) {
- return &context->getImpl().unitAttr;
-}
-
-BoolAttr BoolAttr::get(bool value, MLIRContext *context) {
- auto &impl = context->getImpl();
-
- { // Check for an existing instance in read-only mode.
- llvm::sys::SmartScopedReader<true> attributeLock(impl.attributeMutex);
- if (auto *result = impl.boolAttrs[value])
- return result;
- }
-
- // Aquire the mutex in write mode so that we can safely construct the new
- // instance.
- llvm::sys::SmartScopedWriter<true> attributeLock(impl.attributeMutex);
-
- // Check for an existing instance again here, because another writer thread
- // may have already created one.
- auto *&result = impl.boolAttrs[value];
- if (result)
- return result;
-
- result = impl.attributeAllocator.Allocate<BoolAttributeStorage>();
- new (result) BoolAttributeStorage(IntegerType::get(1, context), value);
- return result;
-}
-
-IntegerAttr IntegerAttr::get(Type type, const APInt &value) {
- auto &impl = type.getContext()->getImpl();
- IntegerAttrKeyInfo::KeyTy key({type, value});
-
- // Safely get or create an attribute instance.
- return safeGetOrCreate(impl.integerAttrs, key, impl.attributeMutex, [&] {
- auto elements = ArrayRef<uint64_t>(value.getRawData(), value.getNumWords());
-
- auto byteSize =
- IntegerAttributeStorage::totalSizeToAlloc<uint64_t>(elements.size());
- auto rawMem = impl.attributeAllocator.Allocate(
- byteSize, alignof(IntegerAttributeStorage));
- auto result = ::new (rawMem) IntegerAttributeStorage(type, elements.size());
- std::uninitialized_copy(elements.begin(), elements.end(),
- result->getTrailingObjects<uint64_t>());
- return result;
- });
-}
-
-IntegerAttr IntegerAttr::get(Type type, int64_t value) {
- // This uses 64 bit APInts by default for index type.
- if (type.isIndex())
- return get(type, APInt(64, value));
-
- auto intType = type.dyn_cast<IntegerType>();
- assert(intType && "expected an integer type for an integer attribute");
- return get(type, APInt(intType.getWidth(), value));
-}
-
-static FloatAttr getFloatAttr(Type type, double value,
- llvm::Optional<Location> loc) {
- if (!type.isa<FloatType>()) {
- if (loc)
- type.getContext()->emitError(*loc, "expected floating point type");
- return nullptr;
- }
-
- // Treat BF16 as double because it is not supported in LLVM's APFloat.
- // TODO(jpienaar): add BF16 support to APFloat?
- if (type.isBF16() || type.isF64())
- return FloatAttr::get(type, APFloat(value));
-
- // This handles, e.g., F16 because there is no APFloat constructor for it.
- bool unused;
- APFloat val(value);
- val.convert(type.cast<FloatType>().getFloatSemantics(),
- APFloat::rmNearestTiesToEven, &unused);
- return FloatAttr::get(type, val);
-}
-
-FloatAttr FloatAttr::getChecked(Type type, double value, Location loc) {
- return getFloatAttr(type, value, loc);
-}
-
-FloatAttr FloatAttr::get(Type type, double value) {
- auto res = getFloatAttr(type, value, /*loc=*/llvm::None);
- assert(res && "failed to construct float attribute");
- return res;
-}
-
-FloatAttr FloatAttr::get(Type type, const APFloat &value) {
- auto fltType = type.cast<FloatType>();
- assert(&fltType.getFloatSemantics() == &value.getSemantics() &&
- "FloatAttr type doesn't match the type implied by its value");
- (void)fltType;
- auto &impl = type.getContext()->getImpl();
- FloatAttrKeyInfo::KeyTy key({type, value});
-
- // Safely get or create an attribute instance.
- return safeGetOrCreate(impl.floatAttrs, key, impl.attributeMutex, [&] {
- const auto &apint = value.bitcastToAPInt();
- // Here one word's bitwidth equals to that of uint64_t.
- auto elements = ArrayRef<uint64_t>(apint.getRawData(), apint.getNumWords());
-
- auto byteSize =
- FloatAttributeStorage::totalSizeToAlloc<uint64_t>(elements.size());
- auto rawMem = impl.attributeAllocator.Allocate(
- byteSize, alignof(FloatAttributeStorage));
- auto result = ::new (rawMem)
- FloatAttributeStorage(value.getSemantics(), type, elements.size());
- std::uninitialized_copy(elements.begin(), elements.end(),
- result->getTrailingObjects<uint64_t>());
- return result;
- });
-}
-
-StringAttr StringAttr::get(StringRef bytes, MLIRContext *context) {
- auto &impl = context->getImpl();
-
- { // Check for an existing instance in read-only mode.
- llvm::sys::SmartScopedReader<true> attributeLock(impl.attributeMutex);
- auto it = impl.stringAttrs.find(bytes);
- if (it != impl.stringAttrs.end())
- return it->second;
- }
-
- // Aquire the mutex in write mode so that we can safely construct the new
- // instance.
- llvm::sys::SmartScopedWriter<true> attributeLock(impl.attributeMutex);
-
- // Check for an existing instance again here, because another writer thread
- // may have already created one.
- auto it = impl.stringAttrs.insert({bytes, nullptr}).first;
- if (it->second)
- return it->second;
-
- auto result = new (impl.attributeAllocator.Allocate<StringAttributeStorage>())
- StringAttributeStorage(it->first());
- return it->second = result;
-}
-
-ArrayAttr ArrayAttr::get(ArrayRef<Attribute> value, MLIRContext *context) {
- auto &impl = context->getImpl();
-
- // Safely get or create an attribute instance.
- return safeGetOrCreate(impl.arrayAttrs, value, impl.attributeMutex, [&] {
- auto *result = impl.attributeAllocator.Allocate<ArrayAttributeStorage>();
-
- // Copy the elements into the bump pointer.
- value = copyArrayRefInto(impl.attributeAllocator, value);
-
- // Check to see if any of the elements have a function attr.
- bool hasFunctionAttr = false;
- for (auto elt : value)
- if (elt.isOrContainsFunction()) {
- hasFunctionAttr = true;
- break;
- }
-
- // Initialize the memory using placement new.
- return new (result) ArrayAttributeStorage(hasFunctionAttr, value);
- });
-}
-
-AffineMapAttr AffineMapAttr::get(AffineMap value) {
- auto *context = value.getResult(0).getContext();
- auto &impl = context->getImpl();
-
- // Safely get or create an attribute instance.
- return safeGetOrCreate(impl.affineMapAttrs, value, impl.attributeMutex, [&] {
- auto result = impl.attributeAllocator.Allocate<AffineMapAttributeStorage>();
- return new (result) AffineMapAttributeStorage(value);
- });
-}
-
-IntegerSetAttr IntegerSetAttr::get(IntegerSet value) {
- auto *context = value.getConstraint(0).getContext();
- auto &impl = context->getImpl();
-
- // Safely get or create an attribute instance.
- return safeGetOrCreate(impl.integerSetAttrs, value, impl.attributeMutex, [&] {
- auto result =
- impl.attributeAllocator.Allocate<IntegerSetAttributeStorage>();
- return new (result) IntegerSetAttributeStorage(value);
- });
-}
-
-TypeAttr TypeAttr::get(Type type, MLIRContext *context) {
- auto &impl = context->getImpl();
-
- // Safely get or create an attribute instance.
- return safeGetOrCreate(impl.typeAttrs, type, impl.attributeMutex, [&] {
- auto result = impl.attributeAllocator.Allocate<TypeAttributeStorage>();
- return new (result) TypeAttributeStorage(type);
- });
-}
-
-FunctionAttr FunctionAttr::get(Function *value, MLIRContext *context) {
- assert(value && "Cannot get FunctionAttr for a null function");
- auto &impl = context->getImpl();
-
- // Safely get or create an attribute instance.
- return safeGetOrCreate(impl.functionAttrs, value, impl.attributeMutex, [&] {
- auto result = impl.attributeAllocator.Allocate<FunctionAttributeStorage>();
- return new (result) FunctionAttributeStorage(value);
- });
-}
-
-/// This function is used by the internals of the Function class to null out
-/// attributes referring to functions that are about to be deleted.
-void FunctionAttr::dropFunctionReference(Function *value) {
- auto &impl = value->getContext()->getImpl();
-
- // Aquire the mutex in write mode so that we can safely remove the attribute
- // if it exists.
- llvm::sys::SmartScopedWriter<true> attributeLock(impl.attributeMutex);
-
- // Check to see if there was an attribute referring to this function.
- auto &functionAttrs = impl.functionAttrs;
-
- // If not, then we're done.
- auto it = functionAttrs.find(value);
- if (it == functionAttrs.end())
- return;
-
- // If so, null out the function reference in the attribute (to avoid dangling
- // pointers) and remove the entry from the map so the map doesn't contain
- // dangling keys.
- it->second->value = nullptr;
- functionAttrs.erase(it);
+/// Returns the storage uniquer used for constructing attribute storage
+/// instances. This should not be used directly.
+StorageUniquer &MLIRContext::getAttributeUniquer() {
+ return getImpl().attributeUniquer;
}
/// Perform a three-way comparison between the names of the specified
});
}
-// Returns false if the given `attr` is not of the given `type`.
-// Note: This function is only intended to be used for assertion. So it's
-// possibly allowing invalid cases that are unimplemented.
-static bool attrIsOfType(Attribute attr, Type type) {
- if (auto floatAttr = attr.dyn_cast<FloatAttr>())
- return floatAttr.getType() == type;
- if (auto intAttr = attr.dyn_cast<IntegerAttr>())
- return intAttr.getType() == type;
- if (auto elementsAttr = attr.dyn_cast<ElementsAttr>())
- return elementsAttr.getType() == type;
- // TODO: check the other cases
- return true;
-}
-
-SplatElementsAttr SplatElementsAttr::get(VectorOrTensorType type,
- Attribute elt) {
- auto attr = elt.dyn_cast<NumericAttr>();
- assert(attr && "expected numeric value");
- assert(attr.getType() == type.getElementType() &&
- "value should be of the given type");
- (void)attr;
-
- auto &impl = type.getContext()->getImpl();
-
- // Safely get or create an attribute instance.
- std::pair<Type, Attribute> key(type, elt);
- return safeGetOrCreate(
- impl.splatElementsAttrs, key, impl.attributeMutex, [&] {
- auto result =
- impl.attributeAllocator.Allocate<SplatElementsAttributeStorage>();
- return new (result) SplatElementsAttributeStorage(type, elt);
- });
-}
-
-DenseElementsAttr DenseElementsAttr::get(VectorOrTensorType type,
- ArrayRef<char> data) {
- auto bitsRequired = type.getSizeInBits();
- (void)bitsRequired;
- assert((bitsRequired <= data.size() * APInt::APINT_WORD_SIZE) &&
- "Input data bit size should be larger than that type requires");
-
- auto &impl = type.getContext()->getImpl();
- DenseElementsAttrInfo::KeyTy key({type, data});
-
- // Safely get or create an attribute instance.
- return safeGetOrCreate(
- impl.denseElementsAttrs, key, impl.attributeMutex, [&] {
- Attribute::Kind kind;
- switch (type.getElementType().getKind()) {
- case StandardTypes::BF16:
- case StandardTypes::F16:
- case StandardTypes::F32:
- case StandardTypes::F64:
- kind = Attribute::Kind::DenseFPElements;
- break;
- case StandardTypes::Integer:
- kind = Attribute::Kind::DenseIntElements;
- break;
- default:
- llvm_unreachable("unexpected element type");
- }
-
- // If the data buffer is non-empty, we copy it into the context.
- ArrayRef<char> copy;
- if (!data.empty()) {
- // Rounding up the allocate size to multiples of APINT_WORD_SIZE, so
- // the `readBits` will not fail when it accesses multiples of
- // APINT_WORD_SIZE each time.
- size_t sizeToAllocate =
- llvm::alignTo(data.size(), APInt::APINT_WORD_SIZE);
- auto *rawCopy =
- (char *)impl.attributeAllocator.Allocate(sizeToAllocate, 64);
- std::uninitialized_copy(data.begin(), data.end(), rawCopy);
- copy = {rawCopy, data.size()};
- }
- auto *result =
- impl.attributeAllocator.Allocate<DenseElementsAttributeStorage>();
- return new (result) DenseElementsAttributeStorage(kind, type, copy);
- });
-}
-
-DenseElementsAttr DenseElementsAttr::get(VectorOrTensorType type,
- ArrayRef<Attribute> values) {
- assert(type.getElementType().isIntOrFloat() &&
- "expected int or float element type");
- assert(values.size() == type.getNumElements() &&
- "expected 'values' to contain the same number of elements as 'type'");
-
- // FIXME(b/121118307): using 64 bits for BF16 because it is currently stored
- // with double semantics.
- auto eltType = type.getElementType();
- size_t bitWidth = eltType.isBF16() ? 64 : eltType.getIntOrFloatBitWidth();
-
- // Compress the attribute values into a character buffer.
- SmallVector<char, 8> data(APInt::getNumWords(bitWidth * values.size()) *
- APInt::APINT_WORD_SIZE);
- APInt intVal;
- for (unsigned i = 0, e = values.size(); i < e; ++i) {
- switch (eltType.getKind()) {
- case StandardTypes::BF16:
- case StandardTypes::F16:
- case StandardTypes::F32:
- case StandardTypes::F64:
- assert(eltType == values[i].cast<FloatAttr>().getType() &&
- "expected attribute value to have element type");
- intVal = values[i].cast<FloatAttr>().getValue().bitcastToAPInt();
- break;
- case StandardTypes::Integer:
- assert(eltType == values[i].cast<IntegerAttr>().getType() &&
- "expected attribute value to have element type");
- intVal = values[i].cast<IntegerAttr>().getValue();
- break;
- default:
- llvm_unreachable("unexpected element type");
- }
- assert(intVal.getBitWidth() == bitWidth &&
- "expected value to have same bitwidth as element type");
- writeBits(data.data(), i * bitWidth, intVal);
- }
- return get(type, data);
-}
-
-OpaqueElementsAttr OpaqueElementsAttr::get(Dialect *dialect,
- VectorOrTensorType type,
- StringRef bytes) {
- assert(TensorType::isValidElementType(type.getElementType()) &&
- "Input element type should be a valid tensor element type");
-
- auto &impl = type.getContext()->getImpl();
- OpaqueElementsAttrInfo::KeyTy key(dialect, type, bytes);
-
- return safeGetOrCreate(
- impl.opaqueElementsAttrs, key, impl.attributeMutex, [&] {
- auto *result =
- impl.attributeAllocator.Allocate<OpaqueElementsAttributeStorage>();
-
- // TODO: Provide a way to avoid copying content of large opaque tensors
- // This will likely require a new reference attribute kind.
- bytes = bytes.copy(impl.attributeAllocator);
- return new (result)
- OpaqueElementsAttributeStorage(type, dialect, bytes);
- });
-}
-
-SparseElementsAttr SparseElementsAttr::get(VectorOrTensorType type,
- DenseIntElementsAttr indices,
- DenseElementsAttr values) {
- assert(indices.getType().getElementType().isInteger(64) &&
- "expected sparse indices to be 64-bit integer values");
-
- auto &impl = type.getContext()->getImpl();
- auto key = std::make_tuple(type, indices, values);
-
- // Safely get or create an attribute instance.
- return safeGetOrCreate(
- impl.sparseElementsAttrs, key, impl.attributeMutex, [&] {
- return new (
- impl.attributeAllocator.Allocate<SparseElementsAttributeStorage>())
- SparseElementsAttributeStorage(type, indices, values);
- });
-}
-
//===----------------------------------------------------------------------===//
// AffineMap and AffineExpr uniquing
//===----------------------------------------------------------------------===//
return result = initializeStorage(kind, ctorFn);
}
+ /// Erase an instance of a complex derived type.
+ void erase(unsigned kind, unsigned hashValue,
+ llvm::function_ref<bool(const BaseStorage *)> isEqual,
+ llvm::function_ref<void(BaseStorage *)> cleanupFn) {
+ LookupKey lookupKey{kind, hashValue, isEqual};
+
+ // Acquire a writer-lock so that we can safely erase the type instance.
+ llvm::sys::SmartScopedWriter<true> typeLock(mutex);
+ auto existing = storageTypes.find_as(lookupKey);
+ if (existing == storageTypes.end())
+ return;
+
+ // Cleanup the storage and remove it from the map.
+ cleanupFn(existing->storage);
+ storageTypes.erase(existing);
+ }
+
//===--------------------------------------------------------------------===//
// Instance Storage
//===--------------------------------------------------------------------===//
-> BaseStorage * {
return impl->getOrCreate(kind, ctorFn);
}
+
+/// Implementation for erasing an instance of a derived type with complex
+/// storage.
+void StorageUniquer::eraseImpl(
+ unsigned kind, unsigned hashValue,
+ llvm::function_ref<bool(const BaseStorage *)> isEqual,
+ std::function<void(BaseStorage *)> cleanupFn) {
+ impl->erase(kind, hashValue, isEqual, cleanupFn);
+}