#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Function.h"
#include "mlir/IR/Identifier.h"
#include "mlir/IR/IntegerSet.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/Support/StorageUniquer.h"
#include "llvm/ADT/APFloat.h"
+#include "llvm/ADT/PointerIntPair.h"
#include "llvm/Support/TrailingObjects.h"
namespace mlir {
namespace detail {
-/// Base storage class appearing in an attribute.
+/// Base storage class appearing in an attribute. Derived storage classes should
+/// only be constructed within the context of the AttributeUniquer.
struct AttributeStorage : public StorageUniquer::BaseStorage {
- AttributeStorage(bool isOrContainsFunctionCache = false)
- : isOrContainsFunctionCache(isOrContainsFunctionCache) {}
-
- /// This field is true if this is, or contains, a function attribute.
- bool isOrContainsFunctionCache : 1;
+ /// Construct a new attribute storage instance with the given type and a
+ /// boolean that signals if the derived attribute is or contains a function
+ /// pointer.
+ /// Note: All attributes require a valid type. If a null type is provided
+ /// here, the type of the attribute will automatically default to
+ /// NoneType upon initialization in the uniquer.
+ AttributeStorage(Type type = {}, bool isOrContainsFunctionCache = false)
+ : typeAndContainsFunctionAttrPair(type, isOrContainsFunctionCache) {}
+ AttributeStorage(bool isOrContainsFunctionCache)
+ : AttributeStorage(/*type=*/{}, isOrContainsFunctionCache) {}
+
+ bool isOrContainsFunctionCache() const {
+ return typeAndContainsFunctionAttrPair.getInt();
+ }
+
+ Type getType() const { return typeAndContainsFunctionAttrPair.getPointer(); }
+ void setType(Type type) { typeAndContainsFunctionAttrPair.setPointer(type); }
+
+ /// This field is a pair of:
+ /// - The type of the attribute value.
+ /// - A boolean that is true if this is, or contains, a function attribute.
+ llvm::PointerIntPair<Type, 1, bool> typeAndContainsFunctionAttrPair;
};
// A utility class to get, or create, unique instances of attributes within an
!std::is_same<typename T::ImplType, AttributeStorage>::value, T>::type
get(MLIRContext *ctx, Attribute::Kind kind, Args &&... args) {
return ctx->getAttributeUniquer().getComplex<typename T::ImplType>(
- /*initFn=*/{}, static_cast<unsigned>(kind),
+ getInitFn(ctx), static_cast<unsigned>(kind),
std::forward<Args>(args)...);
}
std::is_same<typename T::ImplType, AttributeStorage>::value, T>::type
get(MLIRContext *ctx, Attribute::Kind kind) {
return ctx->getAttributeUniquer().getSimple<AttributeStorage>(
- /*initFn=*/{}, static_cast<unsigned>(kind));
+ getInitFn(ctx), static_cast<unsigned>(kind));
}
/// Erase a uniqued instance of attribute T. This overload is used for
return ctx->getAttributeUniquer().eraseComplex<typename T::ImplType>(
static_cast<unsigned>(kind), std::forward<Args>(args)...);
}
+
+ /// Generate a functor to initialize a new attribute storage instance.
+ static std::function<void(AttributeStorage *)> getInitFn(MLIRContext *ctx) {
+ return [ctx](AttributeStorage *storage) {
+ // If the attribute did not provide a type, then default to NoneType.
+ if (!storage->getType())
+ storage->setType(NoneType::get(ctx));
+ };
+ }
};
using AttributeStorageAllocator = StorageUniquer::StorageAllocator;
struct BoolAttributeStorage : public AttributeStorage {
using KeyTy = std::pair<MLIRContext *, bool>;
- BoolAttributeStorage(Type type, bool value) : type(type), value(value) {}
+ BoolAttributeStorage(Type type, bool value)
+ : AttributeStorage(type), value(value) {}
/// We only check equality for and hash with the boolean key parameter.
bool operator==(const KeyTy &key) const { return key.second == value; }
BoolAttributeStorage(IntegerType::get(1, key.first), key.second);
}
- Type type;
bool value;
};
using KeyTy = std::pair<Type, APInt>;
IntegerAttributeStorage(Type type, size_t numObjects)
- : type(type), numObjects(numObjects) {
+ : AttributeStorage(type), numObjects(numObjects) {
assert((type.isIndex() || type.isa<IntegerType>()) && "invalid type");
}
/// Key equality and hash functions.
bool operator==(const KeyTy &key) const {
- return key == KeyTy(type, getValue());
+ return key == KeyTy(getType(), getValue());
}
static unsigned hashKey(const KeyTy &key) {
return llvm::hash_combine(key.first, llvm::hash_value(key.second));
/// Returns an APInt representing the stored value.
APInt getValue() const {
- if (type.isIndex())
+ if (getType().isIndex())
return APInt(64, {getTrailingObjects<uint64_t>(), numObjects});
- return APInt(type.getIntOrFloatBitWidth(),
+ return APInt(getType().getIntOrFloatBitWidth(),
{getTrailingObjects<uint64_t>(), numObjects});
}
- Type type;
size_t numObjects;
};
FloatAttributeStorage(const llvm::fltSemantics &semantics, Type type,
size_t numObjects)
- : semantics(semantics), type(type.cast<FloatType>()),
- numObjects(numObjects) {}
+ : AttributeStorage(type), semantics(semantics), numObjects(numObjects) {}
/// Key equality and hash functions.
bool operator==(const KeyTy &key) const {
- return key.first == type && key.second.bitwiseIsEqual(getValue());
+ return key.first == getType() && key.second.bitwiseIsEqual(getValue());
}
static unsigned hashKey(const KeyTy &key) {
return llvm::hash_combine(key.first, llvm::hash_value(key.second));
}
const llvm::fltSemantics &semantics;
- FloatType type;
size_t numObjects;
};
using KeyTy = Function *;
FunctionAttributeStorage(Function *value)
- : AttributeStorage(/*isOrContainsFunctionCache=*/true), value(value) {}
+ : AttributeStorage(value->getType(), /*isOrContainsFunctionCache=*/true),
+ value(value) {}
/// Key equality function.
bool operator==(const KeyTy &key) const { return key == value; }
Function *value;
};
-/// A base attribute representing a reference to a vector or tensor constant.
-struct ElementsAttributeStorage : public AttributeStorage {
- ElementsAttributeStorage(VectorOrTensorType type) : type(type) {}
- VectorOrTensorType type;
-};
-
/// An attribute representing a reference to a vector or tensor constant,
/// inwhich all elements have the same value.
-struct SplatElementsAttributeStorage : public ElementsAttributeStorage {
- using KeyTy = std::pair<VectorOrTensorType, Attribute>;
+struct SplatElementsAttributeStorage : public AttributeStorage {
+ using KeyTy = std::pair<Type, Attribute>;
- SplatElementsAttributeStorage(VectorOrTensorType type, Attribute elt)
- : ElementsAttributeStorage(type), elt(elt) {}
+ SplatElementsAttributeStorage(Type type, Attribute elt)
+ : AttributeStorage(type), elt(elt) {}
/// Key equality and hash functions.
bool operator==(const KeyTy &key) const {
- return key == std::make_pair(type, elt);
- }
- static unsigned hashKey(const KeyTy &key) {
- return llvm::hash_combine(key.first, key.second);
+ return key == std::make_pair(getType(), elt);
}
/// Construct a new storage instance.
};
/// An attribute representing a reference to a dense vector or tensor object.
-struct DenseElementsAttributeStorage : public ElementsAttributeStorage {
- using KeyTy = std::pair<VectorOrTensorType, ArrayRef<char>>;
+struct DenseElementsAttributeStorage : public AttributeStorage {
+ using KeyTy = std::pair<Type, ArrayRef<char>>;
- DenseElementsAttributeStorage(VectorOrTensorType ty, ArrayRef<char> data)
- : ElementsAttributeStorage(ty), data(data) {}
+ DenseElementsAttributeStorage(Type ty, ArrayRef<char> data)
+ : AttributeStorage(ty), data(data) {}
/// Key equality and hash functions.
- bool operator==(const KeyTy &key) const { return key == KeyTy(type, data); }
- static unsigned hashKey(const KeyTy &key) {
- return llvm::hash_combine(key.first, key.second);
+ bool operator==(const KeyTy &key) const {
+ return key == KeyTy(getType(), data);
}
/// Construct a new storage instance.
/// An attribute representing a reference to a tensor constant with opaque
/// content.
-struct OpaqueElementsAttributeStorage : public ElementsAttributeStorage {
- using KeyTy = std::tuple<VectorOrTensorType, Dialect *, StringRef>;
+struct OpaqueElementsAttributeStorage : public AttributeStorage {
+ using KeyTy = std::tuple<Type, Dialect *, StringRef>;
- OpaqueElementsAttributeStorage(VectorOrTensorType type, Dialect *dialect,
- StringRef bytes)
- : ElementsAttributeStorage(type), dialect(dialect), bytes(bytes) {}
+ OpaqueElementsAttributeStorage(Type type, Dialect *dialect, StringRef bytes)
+ : AttributeStorage(type), dialect(dialect), bytes(bytes) {}
/// Key equality and hash functions.
bool operator==(const KeyTy &key) const {
- return key == std::make_tuple(type, dialect, bytes);
+ return key == std::make_tuple(getType(), dialect, bytes);
}
static unsigned hashKey(const KeyTy &key) {
return llvm::hash_combine(std::get<0>(key), std::get<1>(key),
};
/// An attribute representing a reference to a sparse vector or tensor object.
-struct SparseElementsAttributeStorage : public ElementsAttributeStorage {
- using KeyTy =
- std::tuple<VectorOrTensorType, DenseIntElementsAttr, DenseElementsAttr>;
+struct SparseElementsAttributeStorage : public AttributeStorage {
+ using KeyTy = std::tuple<Type, DenseIntElementsAttr, DenseElementsAttr>;
- SparseElementsAttributeStorage(VectorOrTensorType type,
- DenseIntElementsAttr indices,
+ SparseElementsAttributeStorage(Type type, DenseIntElementsAttr indices,
DenseElementsAttr values)
- : ElementsAttributeStorage(type), indices(indices), values(values) {}
+ : AttributeStorage(type), indices(indices), values(values) {}
/// Key equality and hash functions.
bool operator==(const KeyTy &key) const {
- return key == std::make_tuple(type, indices, values);
+ return key == std::make_tuple(getType(), indices, values);
}
static unsigned hashKey(const KeyTy &key) {
return llvm::hash_combine(std::get<0>(key), std::get<1>(key),
return static_cast<Kind>(attr->getKind());
}
+/// Return the type of this attribute.
+Type Attribute::getType() const { return attr->getType(); }
+
bool Attribute::isOrContainsFunction() const {
- return attr->isOrContainsFunctionCache;
+ return attr->isOrContainsFunctionCache();
}
// Given an attribute that could refer to a function attribute in the remapping
// NumericAttr
//===----------------------------------------------------------------------===//
-Type NumericAttr::getType() const {
- if (auto boolAttr = dyn_cast<BoolAttr>())
- return boolAttr.getType();
- if (auto intAttr = dyn_cast<IntegerAttr>())
- return intAttr.getType();
- if (auto floatAttr = dyn_cast<FloatAttr>())
- return floatAttr.getType();
- if (auto elemAttr = dyn_cast<ElementsAttr>())
- return elemAttr.getType();
-
- llvm_unreachable("unhandled NumericAttr subclass");
-}
-
bool NumericAttr::kindof(Kind kind) {
return BoolAttr::kindof(kind) || IntegerAttr::kindof(kind) ||
FloatAttr::kindof(kind) || ElementsAttr::kindof(kind);
bool BoolAttr::getValue() const { return static_cast<ImplType *>(attr)->value; }
-Type BoolAttr::getType() const { return static_cast<ImplType *>(attr)->type; }
-
//===----------------------------------------------------------------------===//
// IntegerAttr
//===----------------------------------------------------------------------===//
int64_t IntegerAttr::getInt() const { return getValue().getSExtValue(); }
-Type IntegerAttr::getType() const {
- return static_cast<ImplType *>(attr)->type;
-}
-
//===----------------------------------------------------------------------===//
// FloatAttr
//===----------------------------------------------------------------------===//
return static_cast<ImplType *>(attr)->getValue();
}
-Type FloatAttr::getType() const { return static_cast<ImplType *>(attr)->type; }
-
double FloatAttr::getValueAsDouble() const {
const auto &semantics = getType().cast<FloatType>().getFloatSemantics();
auto value = getValue();
return static_cast<ImplType *>(attr)->value;
}
-FunctionType FunctionAttr::getType() const { return getValue()->getType(); }
+FunctionType FunctionAttr::getType() const {
+ return Attribute::getType().cast<FunctionType>();
+}
//===----------------------------------------------------------------------===//
// ElementsAttr
//===----------------------------------------------------------------------===//
VectorOrTensorType ElementsAttr::getType() const {
- return static_cast<ImplType *>(attr)->type;
+ return Attribute::getType().cast<VectorOrTensorType>();
}
/// Return the value at the given index. If index does not refer to a valid
SplatElementsAttr SplatElementsAttr::get(VectorOrTensorType type,
Attribute elt) {
- assert(elt.cast<NumericAttr>().getType() == type.getElementType() &&
- "value should be of the given type");
+ assert(elt.getType() == type.getElementType() &&
+ "value should be of the given element type");
return AttributeUniquer::get<SplatElementsAttr>(
type.getContext(), Attribute::Kind::SplatElements, type, elt);
}