#define MLIR_IR_ATTRIBUTESUPPORT_H
#include "mlir/IR/MLIRContext.h"
-#include "mlir/Support/StorageUniquer.h"
+#include "mlir/IR/StorageUniquerSupport.h"
#include "llvm/ADT/PointerIntPair.h"
#include "llvm/ADT/StringRef.h"
LAST_KIND = SparseElements,
};
+ /// Utility class for implementing attributes.
+ template <typename ConcreteType, typename BaseType = Attribute,
+ typename StorageType = AttributeStorage>
+ using AttrBase = detail::StorageUserBase<ConcreteType, BaseType, StorageType,
+ detail::AttributeUniquer>;
+
using ImplType = AttributeStorage;
using ValueType = void;
- Attribute() : attr(nullptr) {}
- /* implicit */ Attribute(const ImplType *attr)
- : attr(const_cast<ImplType *>(attr)) {}
+ Attribute() : impl(nullptr) {}
+ /* implicit */ Attribute(const ImplType *impl)
+ : impl(const_cast<ImplType *>(impl)) {}
- Attribute(const Attribute &other) : attr(other.attr) {}
+ Attribute(const Attribute &other) : impl(other.impl) {}
Attribute &operator=(Attribute other) {
- attr = other.attr;
+ impl = other.impl;
return *this;
}
- bool operator==(Attribute other) const { return attr == other.attr; }
+ bool operator==(Attribute other) const { return impl == other.impl; }
bool operator!=(Attribute other) const { return !(*this == other); }
- explicit operator bool() const { return attr; }
+ explicit operator bool() const { return impl; }
- bool operator!() const { return attr == nullptr; }
+ bool operator!() const { return impl == nullptr; }
template <typename U> bool isa() const;
template <typename U> U dyn_cast() const;
void dump() const;
/// Get an opaque pointer to the attribute.
- const void *getAsOpaquePointer() const { return attr; }
+ const void *getAsOpaquePointer() const { return impl; }
/// Construct an attribute from the opaque pointer representation.
static Attribute getFromOpaquePointer(const void *ptr) {
return Attribute(
friend ::llvm::hash_code hash_value(Attribute arg);
protected:
- ImplType *attr;
+ ImplType *impl;
};
inline raw_ostream &operator<<(raw_ostream &os, Attribute attr) {
/// Unit attributes are attributes that hold no specific value and are given
/// meaning by their existence.
-class UnitAttr : public Attribute {
+class UnitAttr : public Attribute::AttrBase<UnitAttr> {
public:
- using Attribute::Attribute;
+ using Base::Base;
static UnitAttr get(MLIRContext *context);
static bool kindof(Kind kind) { return kind == Attribute::Kind::Unit; }
};
-class BoolAttr : public Attribute {
+class BoolAttr : public Attribute::AttrBase<BoolAttr, Attribute,
+ detail::BoolAttributeStorage> {
public:
- using Attribute::Attribute;
- using ImplType = detail::BoolAttributeStorage;
+ using Base::Base;
using ValueType = bool;
static BoolAttr get(bool value, MLIRContext *context);
static bool kindof(Kind kind) { return kind == Kind::Bool; }
};
-class IntegerAttr : public Attribute {
+class IntegerAttr
+ : public Attribute::AttrBase<IntegerAttr, Attribute,
+ detail::IntegerAttributeStorage> {
public:
- using Attribute::Attribute;
- using ImplType = detail::IntegerAttributeStorage;
+ using Base::Base;
using ValueType = APInt;
static IntegerAttr get(Type type, int64_t value);
static bool kindof(Kind kind) { return kind == Kind::Integer; }
};
-class FloatAttr : public Attribute {
+class FloatAttr : public Attribute::AttrBase<FloatAttr, Attribute,
+ detail::FloatAttributeStorage> {
public:
- using Attribute::Attribute;
- using ImplType = detail::FloatAttributeStorage;
+ using Base::Base;
using ValueType = APFloat;
/// Return a float attribute for the specified value in the specified type.
/// Return a float attribute for the specified value in the specified type.
static FloatAttr get(Type type, const APFloat &value);
+ static FloatAttr getChecked(Type type, const APFloat &value, Location loc);
APFloat getValue() const;
/// Methods for support type inquiry through isa, cast, and dyn_cast.
static bool kindof(Kind kind) { return kind == Kind::Float; }
+
+ /// Verify the construction invariants for a double value.
+ static LogicalResult
+ verifyConstructionInvariants(llvm::Optional<Location> loc, MLIRContext *ctx,
+ Type type, double value);
+ static LogicalResult
+ verifyConstructionInvariants(llvm::Optional<Location> loc, MLIRContext *ctx,
+ Type type, const APFloat &value);
};
-class StringAttr : public Attribute {
+class StringAttr : public Attribute::AttrBase<StringAttr, Attribute,
+ detail::StringAttributeStorage> {
public:
- using Attribute::Attribute;
- using ImplType = detail::StringAttributeStorage;
+ using Base::Base;
using ValueType = StringRef;
static StringAttr get(StringRef bytes, MLIRContext *context);
/// Array attributes are lists of other attributes. They are not necessarily
/// type homogenous given that attributes don't, in general, carry types.
-class ArrayAttr : public Attribute {
+class ArrayAttr : public Attribute::AttrBase<ArrayAttr, Attribute,
+ detail::ArrayAttributeStorage> {
public:
- using Attribute::Attribute;
- using ImplType = detail::ArrayAttributeStorage;
+ using Base::Base;
using ValueType = ArrayRef<Attribute>;
static ArrayAttr get(ArrayRef<Attribute> value, MLIRContext *context);
static bool kindof(Kind kind) { return kind == Kind::Array; }
};
-class AffineMapAttr : public Attribute {
+class AffineMapAttr
+ : public Attribute::AttrBase<AffineMapAttr, Attribute,
+ detail::AffineMapAttributeStorage> {
public:
- using Attribute::Attribute;
- using ImplType = detail::AffineMapAttributeStorage;
+ using Base::Base;
using ValueType = AffineMap;
static AffineMapAttr get(AffineMap value);
static bool kindof(Kind kind) { return kind == Kind::AffineMap; }
};
-class IntegerSetAttr : public Attribute {
+class IntegerSetAttr
+ : public Attribute::AttrBase<IntegerSetAttr, Attribute,
+ detail::IntegerSetAttributeStorage> {
public:
- using Attribute::Attribute;
- using ImplType = detail::IntegerSetAttributeStorage;
+ using Base::Base;
using ValueType = IntegerSet;
static IntegerSetAttr get(IntegerSet value);
static bool kindof(Kind kind) { return kind == Kind::IntegerSet; }
};
-class TypeAttr : public Attribute {
+class TypeAttr : public Attribute::AttrBase<TypeAttr, Attribute,
+ detail::TypeAttributeStorage> {
public:
- using Attribute::Attribute;
- using ImplType = detail::TypeAttributeStorage;
+ using Base::Base;
using ValueType = Type;
static TypeAttr get(Type value);
/// is deleted that had an attribute which referenced it. No references to this
/// attribute should persist across the transformation, but that attribute will
/// remain in MLIRContext.
-class FunctionAttr : public Attribute {
+class FunctionAttr
+ : public Attribute::AttrBase<FunctionAttr, Attribute,
+ detail::FunctionAttributeStorage> {
public:
- using Attribute::Attribute;
- using ImplType = detail::FunctionAttributeStorage;
+ using Base::Base;
using ValueType = Function *;
static FunctionAttr get(Function *value);
/// An attribute that represents a reference to a splat vecctor or tensor
/// constant, meaning all of the elements have the same value.
-class SplatElementsAttr : public ElementsAttr {
+class SplatElementsAttr
+ : public Attribute::AttrBase<SplatElementsAttr, ElementsAttr,
+ detail::SplatElementsAttributeStorage> {
public:
- using ElementsAttr::ElementsAttr;
- using ImplType = detail::SplatElementsAttributeStorage;
+ using Base::Base;
using ValueType = Attribute;
static SplatElementsAttr get(VectorOrTensorType type, Attribute elt);
/// An attribute that represents a reference to a dense integer vector or tensor
/// object.
-class DenseIntElementsAttr : public DenseElementsAttr {
+class DenseIntElementsAttr
+ : public Attribute::AttrBase<DenseIntElementsAttr, DenseElementsAttr,
+ detail::DenseElementsAttributeStorage> {
public:
/// DenseIntElementsAttr iterates on APInt, so we can use the raw element
/// iterator directly.
using iterator = DenseElementsAttr::RawElementIterator;
- using DenseElementsAttr::DenseElementsAttr;
+ using Base::Base;
using DenseElementsAttr::get;
using DenseElementsAttr::getValues;
- using DenseElementsAttr::ImplType;
/// Constructs a dense integer elements attribute from an array of APInt
/// values. Each APInt value is expected to have the same bitwidth as the
/// An attribute that represents a reference to a dense float vector or tensor
/// object. Each element is stored as a double.
-class DenseFPElementsAttr : public DenseElementsAttr {
+class DenseFPElementsAttr
+ : public Attribute::AttrBase<DenseFPElementsAttr, DenseElementsAttr,
+ detail::DenseElementsAttributeStorage> {
public:
/// DenseFPElementsAttr iterates on APFloat, so we need to wrap the raw
/// element iterator.
};
using iterator = ElementIterator;
- using DenseElementsAttr::DenseElementsAttr;
+ using Base::Base;
using DenseElementsAttr::get;
using DenseElementsAttr::getValues;
- using DenseElementsAttr::ImplType;
// Constructs a dense float elements attribute from an array of APFloat
// values. Each APFloat value is expected to have the same bitwidth as the
/// which the compiler may not need to interpret. This attribute is always
/// associated with a particular dialect, which provides a method to convert
/// tensor representation to a non-opaque format.
-class OpaqueElementsAttr : public ElementsAttr {
+class OpaqueElementsAttr
+ : public Attribute::AttrBase<OpaqueElementsAttr, ElementsAttr,
+ detail::OpaqueElementsAttributeStorage> {
public:
- using ElementsAttr::ElementsAttr;
- using ImplType = detail::OpaqueElementsAttributeStorage;
+ using Base::Base;
using ValueType = StringRef;
static OpaqueElementsAttr get(Dialect *dialect, VectorOrTensorType type,
/// [[1, 0, 0, 0],
/// [0, 0, 5, 0],
/// [0, 0, 0, 0]].
-class SparseElementsAttr : public ElementsAttr {
+class SparseElementsAttr
+ : public Attribute::AttrBase<SparseElementsAttr, ElementsAttr,
+ detail::SparseElementsAttributeStorage> {
public:
- using ElementsAttr::ElementsAttr;
- using ImplType = detail::SparseElementsAttributeStorage;
+ using Base::Base;
static SparseElementsAttr get(VectorOrTensorType type,
DenseIntElementsAttr indices,
};
template <typename U> bool Attribute::isa() const {
- assert(attr && "isa<> used on a null attribute.");
+ assert(impl && "isa<> used on a null attribute.");
return U::kindof(getKind());
}
template <typename U> U Attribute::dyn_cast() const {
- return isa<U>() ? U(attr) : U(nullptr);
+ return isa<U>() ? U(impl) : U(nullptr);
}
template <typename U> U Attribute::dyn_cast_or_null() const {
- return (attr && isa<U>()) ? U(attr) : U(nullptr);
+ return (impl && isa<U>()) ? U(impl) : U(nullptr);
}
template <typename U> U Attribute::cast() const {
assert(isa<U>());
- return U(attr);
+ return U(impl);
}
// Make Attribute hashable.
inline ::llvm::hash_code hash_value(Attribute arg) {
- return ::llvm::hash_value(arg.attr);
+ return ::llvm::hash_value(arg.impl);
}
/// NamedAttribute is used for named attribute lists, it holds an identifier for
template <typename First> struct VariadicTypeAdder<First> {
static void addToSet(Dialect &dialect) {
- dialect.addType(First::getTypeID());
+ dialect.addType(First::getClassID());
}
};
static MemRefType getImpl(ArrayRef<int64_t> shape, Type elementType,
ArrayRef<AffineMap> affineMapComposition,
unsigned memorySpace, Optional<Location> location);
+ using Base::getImpl;
};
/// The 'complex' type represents a complex number with a parameterized element
--- /dev/null
+//===- StorageUniquerSupport.h - MLIR Storage Uniquer Utilities -*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file defines utility classes for interfacing with StorageUniquer.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_IR_STORAGEUNIQUERSUPPORT_H
+#define MLIR_IR_STORAGEUNIQUERSUPPORT_H
+
+#include "mlir/IR/Location.h"
+#include "mlir/Support/LogicalResult.h"
+#include "mlir/Support/STLExtras.h"
+#include "mlir/Support/StorageUniquer.h"
+
+namespace mlir {
+namespace detail {
+/// Utility class for implementing users of storage classes uniqued by a
+/// StorageUniquer. Clients are not expected to interact with this class
+/// directly.
+template <typename ConcreteT, typename BaseT, typename StorageT,
+ typename UniquerT>
+class StorageUserBase : public BaseT {
+public:
+ using BaseT::BaseT;
+
+ /// Utility declarations for the concrete attribute class.
+ using Base = StorageUserBase<ConcreteT, BaseT, StorageT, UniquerT>;
+ using ImplType = StorageT;
+
+ /// Return a unique identifier for the concrete type.
+ static ClassID *getClassID() { return ClassID::getID<ConcreteT>(); }
+
+protected:
+ /// Get or create a new ConcreteT instance within the ctx. This
+ /// function is guaranteed to return a non null object and will assert if
+ /// the arguments provided are invalid.
+ template <typename Kind, typename... Args>
+ static ConcreteT get(MLIRContext *ctx, Kind kind, Args... args) {
+ // Ensure that the invariants are correct for construction.
+ assert(succeeded(
+ ConcreteT::verifyConstructionInvariants(llvm::None, ctx, args...)));
+ return UniquerT::template get<ConcreteT>(ctx, kind, args...);
+ }
+
+ /// Get or create a new ConcreteT instance within the ctx, defined at
+ /// the given, potentially unknown, location. If the arguments provided are
+ /// invalid then emit errors and return a null object.
+ template <typename Kind, typename... Args>
+ static ConcreteT getChecked(Location loc, MLIRContext *ctx, Kind kind,
+ Args... args) {
+ // If the construction invariants fail then we return a null attribute.
+ if (failed(ConcreteT::verifyConstructionInvariants(loc, ctx, args...)))
+ return ConcreteT();
+ return UniquerT::template get<ConcreteT>(ctx, kind, args...);
+ }
+
+ /// Default implementation that just returns success.
+ template <typename... Args>
+ static LogicalResult
+ verifyConstructionInvariants(llvm::Optional<Location> loc, MLIRContext *ctx,
+ Args... args) {
+ return success();
+ }
+
+ /// Utility for easy access to the storage instance.
+ ImplType *getImpl() const { return static_cast<ImplType *>(this->impl); }
+};
+} // namespace detail
+} // namespace mlir
+
+#endif
#define MLIR_IR_TYPE_SUPPORT_H
#include "mlir/IR/MLIRContext.h"
-#include "mlir/Support/StorageUniquer.h"
+#include "mlir/IR/StorageUniquerSupport.h"
#include "llvm/ADT/StringRef.h"
#include <memory>
/// Get the dialect that the type 'T' was registered with.
template <typename T>
static const Dialect &lookupDialectForType(MLIRContext *ctx) {
- return lookupDialectForType(ctx, T::getTypeID());
+ return lookupDialectForType(ctx, T::getClassID());
}
/// Get the dialect that registered the type with the provided typeid.
#ifndef MLIR_IR_TYPES_H
#define MLIR_IR_TYPES_H
-#include "mlir/IR/Location.h"
#include "mlir/IR/TypeSupport.h"
#include "mlir/Support/LLVM.h"
-#include "mlir/Support/LogicalResult.h"
-#include "mlir/Support/STLExtras.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/DenseMapInfo.h"
#include "DialectSymbolRegistry.def"
};
- /// Utility class for implementing types. Clients are not expected to interact
- /// with this class directly. The template arguments to this class are defined
- /// as follows:
- /// - ConcreteType
- /// * The top level derived class type.
- ///
- /// - BaseType
- /// * The base type class that this utility should derive from, e.g Type,
- /// TensorType, TensorOrVectorType.
- ///
- /// - StorageType
- /// * The type storage object containing the necessary instance
- /// information for the ConcreteType.
+ /// Utility class for implementing types.
template <typename ConcreteType, typename BaseType,
typename StorageType = DefaultTypeStorage>
- class TypeBase : public BaseType {
- public:
- using BaseType::BaseType;
-
- /// Utility declarations for the concrete type class.
- using Base = TypeBase<ConcreteType, BaseType, StorageType>;
- using ImplType = StorageType;
-
- /// Return a unique identifier for the concrete type.
- static ClassID *getTypeID() { return ClassID::getID<ConcreteType>(); }
-
- protected:
- /// Get or create a new ConcreteType instance within the context. This
- /// function is guaranteed to return a non null type and will assert if the
- /// arguments provided are invalid.
- template <typename... Args>
- static ConcreteType get(MLIRContext *context, unsigned kind, Args... args) {
- // Ensure that the invariants are correct for type construction.
- assert(succeeded(ConcreteType::verifyConstructionInvariants(
- llvm::None, context, args...)));
- return detail::TypeUniquer::get<ConcreteType>(context, kind, args...);
- }
-
- /// Get or create a new ConcreteType instance within the context, defined at
- /// the given, potentially unknown, location. If the arguments provided are
- /// invalid then emit errors and return a null type.
- template <typename... Args>
- static ConcreteType getChecked(Location loc, MLIRContext *context,
- unsigned kind, Args... args) {
- // If the construction invariants fail then we return a null type.
- if (failed(ConcreteType::verifyConstructionInvariants(loc, context,
- args...)))
- return ConcreteType();
- return detail::TypeUniquer::get<ConcreteType>(context, kind, args...);
- }
-
- /// Default implementation that just returns success.
- template <typename... Args>
- static LogicalResult
- verifyConstructionInvariants(llvm::Optional<Location> loc,
- MLIRContext *context, Args... args) {
- return success();
- }
-
- /// Utility for easy access to the storage instance.
- ImplType *getImpl() const { return static_cast<ImplType *>(this->type); }
- };
+ using TypeBase = detail::StorageUserBase<ConcreteType, BaseType, StorageType,
+ detail::TypeUniquer>;
using ImplType = TypeStorage;
- Type() : type(nullptr) {}
- /* implicit */ Type(const ImplType *type)
- : type(const_cast<ImplType *>(type)) {}
+ Type() : impl(nullptr) {}
+ /* implicit */ Type(const ImplType *impl)
+ : impl(const_cast<ImplType *>(impl)) {}
- Type(const Type &other) : type(other.type) {}
+ Type(const Type &other) : impl(other.impl) {}
Type &operator=(Type other) {
- type = other.type;
+ impl = other.impl;
return *this;
}
- bool operator==(Type other) const { return type == other.type; }
+ bool operator==(Type other) const { return impl == other.impl; }
bool operator!=(Type other) const { return !(*this == other); }
- explicit operator bool() const { return type; }
+ explicit operator bool() const { return impl; }
- bool operator!() const { return type == nullptr; }
+ bool operator!() const { return impl == nullptr; }
template <typename U> bool isa() const;
template <typename U> U dyn_cast() const;
/// Methods for supporting PointerLikeTypeTraits.
const void *getAsOpaquePointer() const {
- return static_cast<const void *>(type);
+ return static_cast<const void *>(impl);
}
static Type getFromOpaquePointer(const void *pointer) {
return Type(reinterpret_cast<ImplType *>(const_cast<void *>(pointer)));
}
protected:
- ImplType *type;
+ ImplType *impl;
};
inline raw_ostream &operator<<(raw_ostream &os, Type type) {
// Make Type hashable.
inline ::llvm::hash_code hash_value(Type arg) {
- return ::llvm::hash_value(arg.type);
+ return ::llvm::hash_value(arg.impl);
}
template <typename U> bool Type::isa() const {
- assert(type && "isa<> used on a null type.");
+ assert(impl && "isa<> used on a null type.");
return U::kindof(getKind());
}
template <typename U> U Type::dyn_cast() const {
- return isa<U>() ? U(type) : U(nullptr);
+ return isa<U>() ? U(impl) : U(nullptr);
}
template <typename U> U Type::dyn_cast_or_null() const {
- return (type && isa<U>()) ? U(type) : U(nullptr);
+ return (impl && isa<U>()) ? U(impl) : U(nullptr);
}
template <typename U> U Type::cast() const {
assert(isa<U>());
- return U(type);
+ return U(impl);
}
} // end namespace mlir
return llvm::hash_combine(key.first, llvm::hash_value(key.second));
}
+ /// Construct a key with a type and double.
+ static KeyTy getKey(Type type, double value) {
+ // Treat BF16 as double because it is not supported in LLVM's APFloat.
+ // TODO(b/121118307): add BF16 support to APFloat?
+ if (type.isBF16() || type.isF64())
+ return KeyTy(type, APFloat(value));
+
+ // This handles, e.g., F16 because there is no APFloat constructor for it.
+ bool unused;
+ APFloat val(value);
+ val.convert(type.cast<FloatType>().getFloatSemantics(),
+ APFloat::rmNearestTiesToEven, &unused);
+ return KeyTy(type, val);
+ }
+
/// Construct a new storage instance.
static FloatAttributeStorage *construct(AttributeStorageAllocator &allocator,
const KeyTy &key) {
//===----------------------------------------------------------------------===//
Attribute::Kind Attribute::getKind() const {
- return static_cast<Kind>(attr->getKind());
+ return static_cast<Kind>(impl->getKind());
}
/// Return the type of this attribute.
-Type Attribute::getType() const { return attr->getType(); }
+Type Attribute::getType() const { return impl->getType(); }
/// Return the context this attribute belongs to.
MLIRContext *Attribute::getContext() const { return getType().getContext(); }
bool Attribute::isOrContainsFunction() const {
- return attr->isOrContainsFunctionCache();
+ return impl->isOrContainsFunctionCache();
}
// Given an attribute that could refer to a function attribute in the remapping
BoolAttr BoolAttr::get(bool value, MLIRContext *context) {
// Note: The context is also used within the BoolAttrStorage.
- return AttributeUniquer::get<BoolAttr>(context, Attribute::Kind::Bool,
- context, value);
+ return Base::get(context, Attribute::Kind::Bool, context, value);
}
-bool BoolAttr::getValue() const { return static_cast<ImplType *>(attr)->value; }
+bool BoolAttr::getValue() const { return getImpl()->value; }
//===----------------------------------------------------------------------===//
// IntegerAttr
//===----------------------------------------------------------------------===//
IntegerAttr IntegerAttr::get(Type type, const APInt &value) {
- return AttributeUniquer::get<IntegerAttr>(
- type.getContext(), Attribute::Kind::Integer, type, value);
+ return Base::get(type.getContext(), Attribute::Kind::Integer, type, value);
}
IntegerAttr IntegerAttr::get(Type type, int64_t value) {
return get(type, APInt(intType.getWidth(), value));
}
-APInt IntegerAttr::getValue() const {
- return static_cast<ImplType *>(attr)->getValue();
-}
+APInt IntegerAttr::getValue() const { return getImpl()->getValue(); }
int64_t IntegerAttr::getInt() const { return getValue().getSExtValue(); }
// FloatAttr
//===----------------------------------------------------------------------===//
-FloatAttr FloatAttr::get(Type type, const APFloat &value) {
- assert(&type.cast<FloatType>().getFloatSemantics() == &value.getSemantics() &&
- "FloatAttr type doesn't match the type implied by its value");
- return AttributeUniquer::get<FloatAttr>(type.getContext(),
- Attribute::Kind::Float, type, value);
-}
-
-static FloatAttr getFloatAttr(Type type, double value,
- llvm::Optional<Location> loc) {
- if (!type.isa<FloatType>()) {
- if (loc)
- type.getContext()->emitError(*loc) << "expected floating point type";
- return nullptr;
- }
-
- // Treat BF16 as double because it is not supported in LLVM's APFloat.
- // TODO(b/121118307): add BF16 support to APFloat?
- if (type.isBF16() || type.isF64())
- return FloatAttr::get(type, APFloat(value));
-
- // This handles, e.g., F16 because there is no APFloat constructor for it.
- bool unused;
- APFloat val(value);
- val.convert(type.cast<FloatType>().getFloatSemantics(),
- APFloat::rmNearestTiesToEven, &unused);
- return FloatAttr::get(type, val);
+FloatAttr FloatAttr::get(Type type, double value) {
+ return Base::get(type.getContext(), Attribute::Kind::Float, type, value);
}
FloatAttr FloatAttr::getChecked(Type type, double value, Location loc) {
- return getFloatAttr(type, value, loc);
+ return Base::getChecked(loc, type.getContext(), Attribute::Kind::Float, type,
+ value);
}
-FloatAttr FloatAttr::get(Type type, double value) {
- auto res = getFloatAttr(type, value, /*loc=*/llvm::None);
- assert(res && "failed to construct float attribute");
- return res;
+FloatAttr FloatAttr::get(Type type, const APFloat &value) {
+ return Base::get(type.getContext(), Attribute::Kind::Float, type, value);
}
-APFloat FloatAttr::getValue() const {
- return static_cast<ImplType *>(attr)->getValue();
+FloatAttr FloatAttr::getChecked(Type type, const APFloat &value, Location loc) {
+ return Base::getChecked(loc, type.getContext(), Attribute::Kind::Float, type,
+ value);
}
+APFloat FloatAttr::getValue() const { return getImpl()->getValue(); }
+
double FloatAttr::getValueAsDouble() const {
return getValueAsDouble(getValue());
}
return value.convertToDouble();
}
+/// Verify construction invariants.
+static LogicalResult verifyFloatTypeInvariants(llvm::Optional<Location> loc,
+ Type type) {
+ if (!type.isa<FloatType>()) {
+ if (loc)
+ type.getContext()->emitError(*loc, "expected floating point type");
+ return failure();
+ }
+ return success();
+}
+
+LogicalResult FloatAttr::verifyConstructionInvariants(
+ llvm::Optional<Location> loc, MLIRContext *ctx, Type type, double value) {
+ return verifyFloatTypeInvariants(loc, type);
+}
+
+LogicalResult
+FloatAttr::verifyConstructionInvariants(llvm::Optional<Location> loc,
+ MLIRContext *ctx, Type type,
+ const APFloat &value) {
+ // Verify that the type is correct.
+ if (failed(verifyFloatTypeInvariants(loc, type)))
+ return failure();
+
+ // Verify that the type semantics match that of the value.
+ if (&type.cast<FloatType>().getFloatSemantics() != &value.getSemantics()) {
+ if (loc)
+ ctx->emitError(
+ *loc, "FloatAttr type doesn't match the type implied by its value");
+ return failure();
+ }
+ return success();
+}
+
//===----------------------------------------------------------------------===//
// StringAttr
//===----------------------------------------------------------------------===//
StringAttr StringAttr::get(StringRef bytes, MLIRContext *context) {
- return AttributeUniquer::get<StringAttr>(context, Attribute::Kind::String,
- bytes);
+ return Base::get(context, Attribute::Kind::String, bytes);
}
-StringRef StringAttr::getValue() const {
- return static_cast<ImplType *>(attr)->value;
-}
+StringRef StringAttr::getValue() const { return getImpl()->value; }
//===----------------------------------------------------------------------===//
// ArrayAttr
//===----------------------------------------------------------------------===//
ArrayAttr ArrayAttr::get(ArrayRef<Attribute> value, MLIRContext *context) {
- return AttributeUniquer::get<ArrayAttr>(context, Attribute::Kind::Array,
- value);
+ return Base::get(context, Attribute::Kind::Array, value);
}
-ArrayRef<Attribute> ArrayAttr::getValue() const {
- return static_cast<ImplType *>(attr)->value;
-}
+ArrayRef<Attribute> ArrayAttr::getValue() const { return getImpl()->value; }
//===----------------------------------------------------------------------===//
// AffineMapAttr
//===----------------------------------------------------------------------===//
AffineMapAttr AffineMapAttr::get(AffineMap value) {
- return AttributeUniquer::get<AffineMapAttr>(
- value.getResult(0).getContext(), Attribute::Kind::AffineMap, value);
+ return Base::get(value.getResult(0).getContext(), Attribute::Kind::AffineMap,
+ value);
}
-AffineMap AffineMapAttr::getValue() const {
- return static_cast<ImplType *>(attr)->value;
-}
+AffineMap AffineMapAttr::getValue() const { return getImpl()->value; }
//===----------------------------------------------------------------------===//
// IntegerSetAttr
//===----------------------------------------------------------------------===//
IntegerSetAttr IntegerSetAttr::get(IntegerSet value) {
- return AttributeUniquer::get<IntegerSetAttr>(
- value.getConstraint(0).getContext(), Attribute::Kind::IntegerSet, value);
+ return Base::get(value.getConstraint(0).getContext(),
+ Attribute::Kind::IntegerSet, value);
}
-IntegerSet IntegerSetAttr::getValue() const {
- return static_cast<ImplType *>(attr)->value;
-}
+IntegerSet IntegerSetAttr::getValue() const { return getImpl()->value; }
//===----------------------------------------------------------------------===//
// TypeAttr
//===----------------------------------------------------------------------===//
TypeAttr TypeAttr::get(Type value) {
- return AttributeUniquer::get<TypeAttr>(value.getContext(),
- Attribute::Kind::Type, value);
+ return Base::get(value.getContext(), Attribute::Kind::Type, value);
}
-Type TypeAttr::getValue() const { return static_cast<ImplType *>(attr)->value; }
+Type TypeAttr::getValue() const { return getImpl()->value; }
//===----------------------------------------------------------------------===//
// FunctionAttr
FunctionAttr FunctionAttr::get(Function *value) {
assert(value && "Cannot get FunctionAttr for a null function");
- return AttributeUniquer::get<FunctionAttr>(value->getContext(),
- Attribute::Kind::Function, value);
+ return Base::get(value->getContext(), Attribute::Kind::Function, value);
}
/// This function is used by the internals of the Function class to null out
Attribute::Kind::Function, value);
}
-Function *FunctionAttr::getValue() const {
- return static_cast<ImplType *>(attr)->value;
-}
+Function *FunctionAttr::getValue() const { return getImpl()->value; }
FunctionType FunctionAttr::getType() const {
return Attribute::getType().cast<FunctionType>();
Attribute elt) {
assert(elt.getType() == type.getElementType() &&
"value should be of the given element type");
- return AttributeUniquer::get<SplatElementsAttr>(
- type.getContext(), Attribute::Kind::SplatElements, type, elt);
+ return Base::get(type.getContext(), Attribute::Kind::SplatElements, type,
+ elt);
}
-Attribute SplatElementsAttr::getValue() const {
- return static_cast<ImplType *>(attr)->elt;
-}
+Attribute SplatElementsAttr::getValue() const { return getImpl()->elt; }
//===----------------------------------------------------------------------===//
// RawElementIterator
}
ArrayRef<char> DenseElementsAttr::getRawData() const {
- return static_cast<ImplType *>(attr)->data;
+ return static_cast<ImplType *>(impl)->data;
}
// Constructs a dense elements attribute from an array of raw APInt values.
StringRef bytes) {
assert(TensorType::isValidElementType(type.getElementType()) &&
"Input element type should be a valid tensor element type");
- return AttributeUniquer::get<OpaqueElementsAttr>(
- type.getContext(), Attribute::Kind::OpaqueElements, type, dialect, bytes);
+ return Base::get(type.getContext(), Attribute::Kind::OpaqueElements, type,
+ dialect, bytes);
}
-StringRef OpaqueElementsAttr::getValue() const {
- return static_cast<ImplType *>(attr)->bytes;
-}
+StringRef OpaqueElementsAttr::getValue() const { return getImpl()->bytes; }
/// Return the value at the given index. If index does not refer to a valid
/// element, then a null attribute is returned.
return Attribute();
}
-Dialect *OpaqueElementsAttr::getDialect() const {
- return static_cast<ImplType *>(attr)->dialect;
-}
+Dialect *OpaqueElementsAttr::getDialect() const { return getImpl()->dialect; }
bool OpaqueElementsAttr::decode(ElementsAttr &result) {
if (auto *d = getDialect())
DenseElementsAttr values) {
assert(indices.getType().getElementType().isInteger(64) &&
"expected sparse indices to be 64-bit integer values");
- return AttributeUniquer::get<SparseElementsAttr>(
- type.getContext(), Attribute::Kind::SparseElements, type, indices,
- values);
+ return Base::get(type.getContext(), Attribute::Kind::SparseElements, type,
+ indices, values);
}
DenseIntElementsAttr SparseElementsAttr::getIndices() const {
- return static_cast<ImplType *>(attr)->indices;
+ return getImpl()->indices;
}
DenseElementsAttr SparseElementsAttr::getValues() const {
- return static_cast<ImplType *>(attr)->values;
+ return getImpl()->values;
}
/// Return the value of the element at the given index.
//===----------------------------------------------------------------------===//
Type VectorOrTensorType::getElementType() const {
- return static_cast<ImplType *>(type)->elementType;
+ return static_cast<ImplType *>(impl)->elementType;
}
unsigned VectorOrTensorType::getElementTypeBitWidth() const {
cleanedAffineMapComposition, memorySpace);
}
-ArrayRef<int64_t> MemRefType::getShape() const {
- return static_cast<ImplType *>(type)->getShape();
-}
+ArrayRef<int64_t> MemRefType::getShape() const { return getImpl()->getShape(); }
-Type MemRefType::getElementType() const {
- return static_cast<ImplType *>(type)->elementType;
-}
+Type MemRefType::getElementType() const { return getImpl()->elementType; }
ArrayRef<AffineMap> MemRefType::getAffineMaps() const {
- return static_cast<ImplType *>(type)->getAffineMaps();
+ return getImpl()->getAffineMaps();
}
-unsigned MemRefType::getMemorySpace() const {
- return static_cast<ImplType *>(type)->memorySpace;
-}
+unsigned MemRefType::getMemorySpace() const { return getImpl()->memorySpace; }
unsigned MemRefType::getNumDynamicDims() const {
return llvm::count_if(getShape(), [](int64_t i) { return i < 0; });
using namespace mlir;
using namespace mlir::detail;
-unsigned Type::getKind() const { return type->getKind(); }
+unsigned Type::getKind() const { return impl->getKind(); }
/// Get the dialect this type is registered to.
-const Dialect &Type::getDialect() const { return type->getDialect(); }
+const Dialect &Type::getDialect() const { return impl->getDialect(); }
MLIRContext *Type::getContext() const { return getDialect().getContext(); }
-unsigned Type::getSubclassData() const { return type->getSubclassData(); }
-void Type::setSubclassData(unsigned val) { type->setSubclassData(val); }
+unsigned Type::getSubclassData() const { return impl->getSubclassData(); }
+void Type::setSubclassData(unsigned val) { impl->setSubclassData(val); }
/// Function Type.
using namespace mlir::quant::detail;
unsigned QuantizedType::getFlags() const {
- return static_cast<ImplType *>(type)->flags;
+ return static_cast<ImplType *>(impl)->flags;
}
LogicalResult QuantizedType::verifyConstructionInvariants(
}
Type QuantizedType::getStorageType() const {
- return static_cast<ImplType *>(type)->storageType;
+ return static_cast<ImplType *>(impl)->storageType;
}
int64_t QuantizedType::getStorageTypeMin() const {
- return static_cast<ImplType *>(type)->storageTypeMin;
+ return static_cast<ImplType *>(impl)->storageTypeMin;
}
int64_t QuantizedType::getStorageTypeMax() const {
- return static_cast<ImplType *>(type)->storageTypeMax;
+ return static_cast<ImplType *>(impl)->storageTypeMax;
}
unsigned QuantizedType::getStorageTypeIntegralWidth() const {
// NOTE: If ever supporting non-integral storage types, some other scheme
// for determining the width will be needed.
- return static_cast<ImplType *>(type)->storageType.getIntOrFloatBitWidth();
+ return static_cast<ImplType *>(impl)->storageType.getIntOrFloatBitWidth();
}
Type QuantizedType::getExpressedType() const {
- return static_cast<ImplType *>(type)->expressedType;
+ return static_cast<ImplType *>(impl)->expressedType;
}
bool QuantizedType::isCompatibleExpressedType(Type candidateExpressedType) {