Add an AttrBase class to simplify defining derived Attributes. This class serves...
authorRiver Riddle <riverriddle@google.com>
Thu, 9 May 2019 05:25:15 +0000 (22:25 -0700)
committerMehdi Amini <joker.eph@gmail.com>
Sat, 11 May 2019 02:24:46 +0000 (19:24 -0700)
--

PiperOrigin-RevId: 247358373

12 files changed:
mlir/include/mlir/IR/AttributeSupport.h
mlir/include/mlir/IR/Attributes.h
mlir/include/mlir/IR/Dialect.h
mlir/include/mlir/IR/StandardTypes.h
mlir/include/mlir/IR/StorageUniquerSupport.h [new file with mode: 0644]
mlir/include/mlir/IR/TypeSupport.h
mlir/include/mlir/IR/Types.h
mlir/lib/IR/AttributeDetail.h
mlir/lib/IR/Attributes.cpp
mlir/lib/IR/StandardTypes.cpp
mlir/lib/IR/Types.cpp
mlir/lib/Quantization/IR/QuantTypes.cpp

index 9bf43b0..88dbf8e 100644 (file)
@@ -23,7 +23,7 @@
 #define MLIR_IR_ATTRIBUTESUPPORT_H
 
 #include "mlir/IR/MLIRContext.h"
-#include "mlir/Support/StorageUniquer.h"
+#include "mlir/IR/StorageUniquerSupport.h"
 #include "llvm/ADT/PointerIntPair.h"
 #include "llvm/ADT/StringRef.h"
 
index 56b8c7b..4f4d2c6 100644 (file)
@@ -90,24 +90,30 @@ public:
     LAST_KIND = SparseElements,
   };
 
+  /// Utility class for implementing attributes.
+  template <typename ConcreteType, typename BaseType = Attribute,
+            typename StorageType = AttributeStorage>
+  using AttrBase = detail::StorageUserBase<ConcreteType, BaseType, StorageType,
+                                           detail::AttributeUniquer>;
+
   using ImplType = AttributeStorage;
   using ValueType = void;
 
-  Attribute() : attr(nullptr) {}
-  /* implicit */ Attribute(const ImplType *attr)
-      : attr(const_cast<ImplType *>(attr)) {}
+  Attribute() : impl(nullptr) {}
+  /* implicit */ Attribute(const ImplType *impl)
+      : impl(const_cast<ImplType *>(impl)) {}
 
-  Attribute(const Attribute &other) : attr(other.attr) {}
+  Attribute(const Attribute &other) : impl(other.impl) {}
   Attribute &operator=(Attribute other) {
-    attr = other.attr;
+    impl = other.impl;
     return *this;
   }
 
-  bool operator==(Attribute other) const { return attr == other.attr; }
+  bool operator==(Attribute other) const { return impl == other.impl; }
   bool operator!=(Attribute other) const { return !(*this == other); }
-  explicit operator bool() const { return attr; }
+  explicit operator bool() const { return impl; }
 
-  bool operator!() const { return attr == nullptr; }
+  bool operator!() const { return impl == nullptr; }
 
   template <typename U> bool isa() const;
   template <typename U> U dyn_cast() const;
@@ -145,7 +151,7 @@ public:
   void dump() const;
 
   /// Get an opaque pointer to the attribute.
-  const void *getAsOpaquePointer() const { return attr; }
+  const void *getAsOpaquePointer() const { return impl; }
   /// Construct an attribute from the opaque pointer representation.
   static Attribute getFromOpaquePointer(const void *ptr) {
     return Attribute(
@@ -155,7 +161,7 @@ public:
   friend ::llvm::hash_code hash_value(Attribute arg);
 
 protected:
-  ImplType *attr;
+  ImplType *impl;
 };
 
 inline raw_ostream &operator<<(raw_ostream &os, Attribute attr) {
@@ -165,19 +171,19 @@ inline raw_ostream &operator<<(raw_ostream &os, Attribute attr) {
 
 /// Unit attributes are attributes that hold no specific value and are given
 /// meaning by their existence.
-class UnitAttr : public Attribute {
+class UnitAttr : public Attribute::AttrBase<UnitAttr> {
 public:
-  using Attribute::Attribute;
+  using Base::Base;
 
   static UnitAttr get(MLIRContext *context);
 
   static bool kindof(Kind kind) { return kind == Attribute::Kind::Unit; }
 };
 
-class BoolAttr : public Attribute {
+class BoolAttr : public Attribute::AttrBase<BoolAttr, Attribute,
+                                            detail::BoolAttributeStorage> {
 public:
-  using Attribute::Attribute;
-  using ImplType = detail::BoolAttributeStorage;
+  using Base::Base;
   using ValueType = bool;
 
   static BoolAttr get(bool value, MLIRContext *context);
@@ -188,10 +194,11 @@ public:
   static bool kindof(Kind kind) { return kind == Kind::Bool; }
 };
 
-class IntegerAttr : public Attribute {
+class IntegerAttr
+    : public Attribute::AttrBase<IntegerAttr, Attribute,
+                                 detail::IntegerAttributeStorage> {
 public:
-  using Attribute::Attribute;
-  using ImplType = detail::IntegerAttributeStorage;
+  using Base::Base;
   using ValueType = APInt;
 
   static IntegerAttr get(Type type, int64_t value);
@@ -205,10 +212,10 @@ public:
   static bool kindof(Kind kind) { return kind == Kind::Integer; }
 };
 
-class FloatAttr : public Attribute {
+class FloatAttr : public Attribute::AttrBase<FloatAttr, Attribute,
+                                             detail::FloatAttributeStorage> {
 public:
-  using Attribute::Attribute;
-  using ImplType = detail::FloatAttributeStorage;
+  using Base::Base;
   using ValueType = APFloat;
 
   /// Return a float attribute for the specified value in the specified type.
@@ -219,6 +226,7 @@ public:
 
   /// Return a float attribute for the specified value in the specified type.
   static FloatAttr get(Type type, const APFloat &value);
+  static FloatAttr getChecked(Type type, const APFloat &value, Location loc);
 
   APFloat getValue() const;
 
@@ -229,12 +237,20 @@ public:
 
   /// Methods for support type inquiry through isa, cast, and dyn_cast.
   static bool kindof(Kind kind) { return kind == Kind::Float; }
+
+  /// Verify the construction invariants for a double value.
+  static LogicalResult
+  verifyConstructionInvariants(llvm::Optional<Location> loc, MLIRContext *ctx,
+                               Type type, double value);
+  static LogicalResult
+  verifyConstructionInvariants(llvm::Optional<Location> loc, MLIRContext *ctx,
+                               Type type, const APFloat &value);
 };
 
-class StringAttr : public Attribute {
+class StringAttr : public Attribute::AttrBase<StringAttr, Attribute,
+                                              detail::StringAttributeStorage> {
 public:
-  using Attribute::Attribute;
-  using ImplType = detail::StringAttributeStorage;
+  using Base::Base;
   using ValueType = StringRef;
 
   static StringAttr get(StringRef bytes, MLIRContext *context);
@@ -247,10 +263,10 @@ public:
 
 /// Array attributes are lists of other attributes.  They are not necessarily
 /// type homogenous given that attributes don't, in general, carry types.
-class ArrayAttr : public Attribute {
+class ArrayAttr : public Attribute::AttrBase<ArrayAttr, Attribute,
+                                             detail::ArrayAttributeStorage> {
 public:
-  using Attribute::Attribute;
-  using ImplType = detail::ArrayAttributeStorage;
+  using Base::Base;
   using ValueType = ArrayRef<Attribute>;
 
   static ArrayAttr get(ArrayRef<Attribute> value, MLIRContext *context);
@@ -267,10 +283,11 @@ public:
   static bool kindof(Kind kind) { return kind == Kind::Array; }
 };
 
-class AffineMapAttr : public Attribute {
+class AffineMapAttr
+    : public Attribute::AttrBase<AffineMapAttr, Attribute,
+                                 detail::AffineMapAttributeStorage> {
 public:
-  using Attribute::Attribute;
-  using ImplType = detail::AffineMapAttributeStorage;
+  using Base::Base;
   using ValueType = AffineMap;
 
   static AffineMapAttr get(AffineMap value);
@@ -281,10 +298,11 @@ public:
   static bool kindof(Kind kind) { return kind == Kind::AffineMap; }
 };
 
-class IntegerSetAttr : public Attribute {
+class IntegerSetAttr
+    : public Attribute::AttrBase<IntegerSetAttr, Attribute,
+                                 detail::IntegerSetAttributeStorage> {
 public:
-  using Attribute::Attribute;
-  using ImplType = detail::IntegerSetAttributeStorage;
+  using Base::Base;
   using ValueType = IntegerSet;
 
   static IntegerSetAttr get(IntegerSet value);
@@ -295,10 +313,10 @@ public:
   static bool kindof(Kind kind) { return kind == Kind::IntegerSet; }
 };
 
-class TypeAttr : public Attribute {
+class TypeAttr : public Attribute::AttrBase<TypeAttr, Attribute,
+                                            detail::TypeAttributeStorage> {
 public:
-  using Attribute::Attribute;
-  using ImplType = detail::TypeAttributeStorage;
+  using Base::Base;
   using ValueType = Type;
 
   static TypeAttr get(Type value);
@@ -316,10 +334,11 @@ public:
 /// is deleted that had an attribute which referenced it.  No references to this
 /// attribute should persist across the transformation, but that attribute will
 /// remain in MLIRContext.
-class FunctionAttr : public Attribute {
+class FunctionAttr
+    : public Attribute::AttrBase<FunctionAttr, Attribute,
+                                 detail::FunctionAttributeStorage> {
 public:
-  using Attribute::Attribute;
-  using ImplType = detail::FunctionAttributeStorage;
+  using Base::Base;
   using ValueType = Function *;
 
   static FunctionAttr get(Function *value);
@@ -356,10 +375,11 @@ public:
 
 /// An attribute that represents a reference to a splat vecctor or tensor
 /// constant, meaning all of the elements have the same value.
-class SplatElementsAttr : public ElementsAttr {
+class SplatElementsAttr
+    : public Attribute::AttrBase<SplatElementsAttr, ElementsAttr,
+                                 detail::SplatElementsAttributeStorage> {
 public:
-  using ElementsAttr::ElementsAttr;
-  using ImplType = detail::SplatElementsAttributeStorage;
+  using Base::Base;
   using ValueType = Attribute;
 
   static SplatElementsAttr get(VectorOrTensorType type, Attribute elt);
@@ -467,16 +487,17 @@ protected:
 
 /// An attribute that represents a reference to a dense integer vector or tensor
 /// object.
-class DenseIntElementsAttr : public DenseElementsAttr {
+class DenseIntElementsAttr
+    : public Attribute::AttrBase<DenseIntElementsAttr, DenseElementsAttr,
+                                 detail::DenseElementsAttributeStorage> {
 public:
   /// DenseIntElementsAttr iterates on APInt, so we can use the raw element
   /// iterator directly.
   using iterator = DenseElementsAttr::RawElementIterator;
 
-  using DenseElementsAttr::DenseElementsAttr;
+  using Base::Base;
   using DenseElementsAttr::get;
   using DenseElementsAttr::getValues;
-  using DenseElementsAttr::ImplType;
 
   /// Constructs a dense integer elements attribute from an array of APInt
   /// values. Each APInt value is expected to have the same bitwidth as the
@@ -503,7 +524,9 @@ public:
 
 /// An attribute that represents a reference to a dense float vector or tensor
 /// object. Each element is stored as a double.
-class DenseFPElementsAttr : public DenseElementsAttr {
+class DenseFPElementsAttr
+    : public Attribute::AttrBase<DenseFPElementsAttr, DenseElementsAttr,
+                                 detail::DenseElementsAttributeStorage> {
 public:
   /// DenseFPElementsAttr iterates on APFloat, so we need to wrap the raw
   /// element iterator.
@@ -517,10 +540,9 @@ public:
   };
   using iterator = ElementIterator;
 
-  using DenseElementsAttr::DenseElementsAttr;
+  using Base::Base;
   using DenseElementsAttr::get;
   using DenseElementsAttr::getValues;
-  using DenseElementsAttr::ImplType;
 
   // Constructs a dense float elements attribute from an array of APFloat
   // values. Each APFloat value is expected to have the same bitwidth as the
@@ -544,10 +566,11 @@ public:
 /// which the compiler may not need to interpret. This attribute is always
 /// associated with a particular dialect, which provides a method to convert
 /// tensor representation to a non-opaque format.
-class OpaqueElementsAttr : public ElementsAttr {
+class OpaqueElementsAttr
+    : public Attribute::AttrBase<OpaqueElementsAttr, ElementsAttr,
+                                 detail::OpaqueElementsAttributeStorage> {
 public:
-  using ElementsAttr::ElementsAttr;
-  using ImplType = detail::OpaqueElementsAttributeStorage;
+  using Base::Base;
   using ValueType = StringRef;
 
   static OpaqueElementsAttr get(Dialect *dialect, VectorOrTensorType type,
@@ -589,10 +612,11 @@ public:
 /// [[1, 0, 0, 0],
 ///  [0, 0, 5, 0],
 ///  [0, 0, 0, 0]].
-class SparseElementsAttr : public ElementsAttr {
+class SparseElementsAttr
+    : public Attribute::AttrBase<SparseElementsAttr, ElementsAttr,
+                                 detail::SparseElementsAttributeStorage> {
 public:
-  using ElementsAttr::ElementsAttr;
-  using ImplType = detail::SparseElementsAttributeStorage;
+  using Base::Base;
 
   static SparseElementsAttr get(VectorOrTensorType type,
                                 DenseIntElementsAttr indices,
@@ -610,23 +634,23 @@ public:
 };
 
 template <typename U> bool Attribute::isa() const {
-  assert(attr && "isa<> used on a null attribute.");
+  assert(impl && "isa<> used on a null attribute.");
   return U::kindof(getKind());
 }
 template <typename U> U Attribute::dyn_cast() const {
-  return isa<U>() ? U(attr) : U(nullptr);
+  return isa<U>() ? U(impl) : U(nullptr);
 }
 template <typename U> U Attribute::dyn_cast_or_null() const {
-  return (attr && isa<U>()) ? U(attr) : U(nullptr);
+  return (impl && isa<U>()) ? U(impl) : U(nullptr);
 }
 template <typename U> U Attribute::cast() const {
   assert(isa<U>());
-  return U(attr);
+  return U(impl);
 }
 
 // Make Attribute hashable.
 inline ::llvm::hash_code hash_value(Attribute arg) {
-  return ::llvm::hash_value(arg.attr);
+  return ::llvm::hash_value(arg.impl);
 }
 
 /// NamedAttribute is used for named attribute lists, it holds an identifier for
index 928c6f7..c279ffd 100644 (file)
@@ -189,7 +189,7 @@ protected:
 
   template <typename First> struct VariadicTypeAdder<First> {
     static void addToSet(Dialect &dialect) {
-      dialect.addType(First::getTypeID());
+      dialect.addType(First::getClassID());
     }
   };
 
index 27c4b49..55be6ed 100644 (file)
@@ -422,6 +422,7 @@ private:
   static MemRefType getImpl(ArrayRef<int64_t> shape, Type elementType,
                             ArrayRef<AffineMap> affineMapComposition,
                             unsigned memorySpace, Optional<Location> location);
+  using Base::getImpl;
 };
 
 /// The 'complex' type represents a complex number with a parameterized element
diff --git a/mlir/include/mlir/IR/StorageUniquerSupport.h b/mlir/include/mlir/IR/StorageUniquerSupport.h
new file mode 100644 (file)
index 0000000..c510d7e
--- /dev/null
@@ -0,0 +1,86 @@
+//===- StorageUniquerSupport.h - MLIR Storage Uniquer Utilities -*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file defines utility classes for interfacing with StorageUniquer.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_IR_STORAGEUNIQUERSUPPORT_H
+#define MLIR_IR_STORAGEUNIQUERSUPPORT_H
+
+#include "mlir/IR/Location.h"
+#include "mlir/Support/LogicalResult.h"
+#include "mlir/Support/STLExtras.h"
+#include "mlir/Support/StorageUniquer.h"
+
+namespace mlir {
+namespace detail {
+/// Utility class for implementing users of storage classes uniqued by a
+/// StorageUniquer. Clients are not expected to interact with this class
+/// directly.
+template <typename ConcreteT, typename BaseT, typename StorageT,
+          typename UniquerT>
+class StorageUserBase : public BaseT {
+public:
+  using BaseT::BaseT;
+
+  /// Utility declarations for the concrete attribute class.
+  using Base = StorageUserBase<ConcreteT, BaseT, StorageT, UniquerT>;
+  using ImplType = StorageT;
+
+  /// Return a unique identifier for the concrete type.
+  static ClassID *getClassID() { return ClassID::getID<ConcreteT>(); }
+
+protected:
+  /// Get or create a new ConcreteT instance within the ctx. This
+  /// function is guaranteed to return a non null object and will assert if
+  /// the arguments provided are invalid.
+  template <typename Kind, typename... Args>
+  static ConcreteT get(MLIRContext *ctx, Kind kind, Args... args) {
+    // Ensure that the invariants are correct for construction.
+    assert(succeeded(
+        ConcreteT::verifyConstructionInvariants(llvm::None, ctx, args...)));
+    return UniquerT::template get<ConcreteT>(ctx, kind, args...);
+  }
+
+  /// Get or create a new ConcreteT instance within the ctx, defined at
+  /// the given, potentially unknown, location. If the arguments provided are
+  /// invalid then emit errors and return a null object.
+  template <typename Kind, typename... Args>
+  static ConcreteT getChecked(Location loc, MLIRContext *ctx, Kind kind,
+                              Args... args) {
+    // If the construction invariants fail then we return a null attribute.
+    if (failed(ConcreteT::verifyConstructionInvariants(loc, ctx, args...)))
+      return ConcreteT();
+    return UniquerT::template get<ConcreteT>(ctx, kind, args...);
+  }
+
+  /// Default implementation that just returns success.
+  template <typename... Args>
+  static LogicalResult
+  verifyConstructionInvariants(llvm::Optional<Location> loc, MLIRContext *ctx,
+                               Args... args) {
+    return success();
+  }
+
+  /// Utility for easy access to the storage instance.
+  ImplType *getImpl() const { return static_cast<ImplType *>(this->impl); }
+};
+} // namespace detail
+} // namespace mlir
+
+#endif
index d8504de..f5b8b18 100644 (file)
@@ -23,7 +23,7 @@
 #define MLIR_IR_TYPE_SUPPORT_H
 
 #include "mlir/IR/MLIRContext.h"
-#include "mlir/Support/StorageUniquer.h"
+#include "mlir/IR/StorageUniquerSupport.h"
 #include "llvm/ADT/StringRef.h"
 #include <memory>
 
@@ -110,7 +110,7 @@ private:
   /// Get the dialect that the type 'T' was registered with.
   template <typename T>
   static const Dialect &lookupDialectForType(MLIRContext *ctx) {
-    return lookupDialectForType(ctx, T::getTypeID());
+    return lookupDialectForType(ctx, T::getClassID());
   }
 
   /// Get the dialect that registered the type with the provided typeid.
index 40e2df4..c963f9e 100644 (file)
 #ifndef MLIR_IR_TYPES_H
 #define MLIR_IR_TYPES_H
 
-#include "mlir/IR/Location.h"
 #include "mlir/IR/TypeSupport.h"
 #include "mlir/Support/LLVM.h"
-#include "mlir/Support/LogicalResult.h"
-#include "mlir/Support/STLExtras.h"
 #include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/DenseMapInfo.h"
 
@@ -112,86 +109,29 @@ public:
 #include "DialectSymbolRegistry.def"
   };
 
-  /// Utility class for implementing types. Clients are not expected to interact
-  /// with this class directly. The template arguments to this class are defined
-  /// as follows:
-  ///   - ConcreteType
-  ///     * The top level derived class type.
-  ///
-  ///   - BaseType
-  ///     * The base type class that this utility should derive from, e.g Type,
-  ///       TensorType, TensorOrVectorType.
-  ///
-  ///   - StorageType
-  ///     * The type storage object containing the necessary instance
-  ///       information for the ConcreteType.
+  /// Utility class for implementing types.
   template <typename ConcreteType, typename BaseType,
             typename StorageType = DefaultTypeStorage>
-  class TypeBase : public BaseType {
-  public:
-    using BaseType::BaseType;
-
-    /// Utility declarations for the concrete type class.
-    using Base = TypeBase<ConcreteType, BaseType, StorageType>;
-    using ImplType = StorageType;
-
-    /// Return a unique identifier for the concrete type.
-    static ClassID *getTypeID() { return ClassID::getID<ConcreteType>(); }
-
-  protected:
-    /// Get or create a new ConcreteType instance within the context. This
-    /// function is guaranteed to return a non null type and will assert if the
-    /// arguments provided are invalid.
-    template <typename... Args>
-    static ConcreteType get(MLIRContext *context, unsigned kind, Args... args) {
-      // Ensure that the invariants are correct for type construction.
-      assert(succeeded(ConcreteType::verifyConstructionInvariants(
-          llvm::None, context, args...)));
-      return detail::TypeUniquer::get<ConcreteType>(context, kind, args...);
-    }
-
-    /// Get or create a new ConcreteType instance within the context, defined at
-    /// the given, potentially unknown, location. If the arguments provided are
-    /// invalid then emit errors and return a null type.
-    template <typename... Args>
-    static ConcreteType getChecked(Location loc, MLIRContext *context,
-                                   unsigned kind, Args... args) {
-      // If the construction invariants fail then we return a null type.
-      if (failed(ConcreteType::verifyConstructionInvariants(loc, context,
-                                                            args...)))
-        return ConcreteType();
-      return detail::TypeUniquer::get<ConcreteType>(context, kind, args...);
-    }
-
-    /// Default implementation that just returns success.
-    template <typename... Args>
-    static LogicalResult
-    verifyConstructionInvariants(llvm::Optional<Location> loc,
-                                 MLIRContext *context, Args... args) {
-      return success();
-    }
-
-    /// Utility for easy access to the storage instance.
-    ImplType *getImpl() const { return static_cast<ImplType *>(this->type); }
-  };
+  using TypeBase = detail::StorageUserBase<ConcreteType, BaseType, StorageType,
+                                           detail::TypeUniquer>;
 
   using ImplType = TypeStorage;
 
-  Type() : type(nullptr) {}
-  /* implicit */ Type(const ImplType *type)
-      : type(const_cast<ImplType *>(type)) {}
+  Type() : impl(nullptr) {}
+  /* implicit */ Type(const ImplType *impl)
+      : impl(const_cast<ImplType *>(impl)) {}
 
-  Type(const Type &other) : type(other.type) {}
+  Type(const Type &other) : impl(other.impl) {}
   Type &operator=(Type other) {
-    type = other.type;
+    impl = other.impl;
     return *this;
   }
 
-  bool operator==(Type other) const { return type == other.type; }
+  bool operator==(Type other) const { return impl == other.impl; }
   bool operator!=(Type other) const { return !(*this == other); }
-  explicit operator bool() const { return type; }
+  explicit operator bool() const { return impl; }
 
-  bool operator!() const { return type == nullptr; }
+  bool operator!() const { return impl == nullptr; }
 
   template <typename U> bool isa() const;
   template <typename U> U dyn_cast() const;
@@ -240,14 +180,14 @@ public:
 
   /// Methods for supporting PointerLikeTypeTraits.
   const void *getAsOpaquePointer() const {
-    return static_cast<const void *>(type);
+    return static_cast<const void *>(impl);
   }
   static Type getFromOpaquePointer(const void *pointer) {
     return Type(reinterpret_cast<ImplType *>(const_cast<void *>(pointer)));
   }
 
 protected:
-  ImplType *type;
+  ImplType *impl;
 };
 
 inline raw_ostream &operator<<(raw_ostream &os, Type type) {
@@ -317,22 +257,22 @@ public:
 
 // Make Type hashable.
 inline ::llvm::hash_code hash_value(Type arg) {
-  return ::llvm::hash_value(arg.type);
+  return ::llvm::hash_value(arg.impl);
 }
 
 template <typename U> bool Type::isa() const {
-  assert(type && "isa<> used on a null type.");
+  assert(impl && "isa<> used on a null type.");
   return U::kindof(getKind());
 }
 template <typename U> U Type::dyn_cast() const {
-  return isa<U>() ? U(type) : U(nullptr);
+  return isa<U>() ? U(impl) : U(nullptr);
 }
 template <typename U> U Type::dyn_cast_or_null() const {
-  return (type && isa<U>()) ? U(type) : U(nullptr);
+  return (impl && isa<U>()) ? U(impl) : U(nullptr);
 }
 template <typename U> U Type::cast() const {
   assert(isa<U>());
-  return U(type);
+  return U(impl);
 }
 
 } // end namespace mlir
index aab4445..cd4c88a 100644 (file)
@@ -123,6 +123,21 @@ struct FloatAttributeStorage final
     return llvm::hash_combine(key.first, llvm::hash_value(key.second));
   }
 
+  /// Construct a key with a type and double.
+  static KeyTy getKey(Type type, double value) {
+    // Treat BF16 as double because it is not supported in LLVM's APFloat.
+    // TODO(b/121118307): add BF16 support to APFloat?
+    if (type.isBF16() || type.isF64())
+      return KeyTy(type, APFloat(value));
+
+    // This handles, e.g., F16 because there is no APFloat constructor for it.
+    bool unused;
+    APFloat val(value);
+    val.convert(type.cast<FloatType>().getFloatSemantics(),
+                APFloat::rmNearestTiesToEven, &unused);
+    return KeyTy(type, val);
+  }
+
   /// Construct a new storage instance.
   static FloatAttributeStorage *construct(AttributeStorageAllocator &allocator,
                                           const KeyTy &key) {
index 62c3c93..0504c49 100644 (file)
@@ -61,17 +61,17 @@ AttributeUniquer::getInitFn(MLIRContext *ctx) {
 //===----------------------------------------------------------------------===//
 
 Attribute::Kind Attribute::getKind() const {
-  return static_cast<Kind>(attr->getKind());
+  return static_cast<Kind>(impl->getKind());
 }
 
 /// Return the type of this attribute.
-Type Attribute::getType() const { return attr->getType(); }
+Type Attribute::getType() const { return impl->getType(); }
 
 /// Return the context this attribute belongs to.
 MLIRContext *Attribute::getContext() const { return getType().getContext(); }
 
 bool Attribute::isOrContainsFunction() const {
-  return attr->isOrContainsFunctionCache();
+  return impl->isOrContainsFunctionCache();
 }
 
 // Given an attribute that could refer to a function attribute in the remapping
@@ -120,19 +120,17 @@ UnitAttr UnitAttr::get(MLIRContext *context) {
 
 BoolAttr BoolAttr::get(bool value, MLIRContext *context) {
   // Note: The context is also used within the BoolAttrStorage.
-  return AttributeUniquer::get<BoolAttr>(context, Attribute::Kind::Bool,
-                                         context, value);
+  return Base::get(context, Attribute::Kind::Bool, context, value);
 }
 
-bool BoolAttr::getValue() const { return static_cast<ImplType *>(attr)->value; }
+bool BoolAttr::getValue() const { return getImpl()->value; }
 
 //===----------------------------------------------------------------------===//
 // IntegerAttr
 //===----------------------------------------------------------------------===//
 
 IntegerAttr IntegerAttr::get(Type type, const APInt &value) {
-  return AttributeUniquer::get<IntegerAttr>(
-      type.getContext(), Attribute::Kind::Integer, type, value);
+  return Base::get(type.getContext(), Attribute::Kind::Integer, type, value);
 }
 
 IntegerAttr IntegerAttr::get(Type type, int64_t value) {
@@ -144,9 +142,7 @@ IntegerAttr IntegerAttr::get(Type type, int64_t value) {
   return get(type, APInt(intType.getWidth(), value));
 }
 
-APInt IntegerAttr::getValue() const {
-  return static_cast<ImplType *>(attr)->getValue();
-}
+APInt IntegerAttr::getValue() const { return getImpl()->getValue(); }
 
 int64_t IntegerAttr::getInt() const { return getValue().getSExtValue(); }
 
@@ -154,48 +150,26 @@ int64_t IntegerAttr::getInt() const { return getValue().getSExtValue(); }
 // FloatAttr
 //===----------------------------------------------------------------------===//
 
-FloatAttr FloatAttr::get(Type type, const APFloat &value) {
-  assert(&type.cast<FloatType>().getFloatSemantics() == &value.getSemantics() &&
-         "FloatAttr type doesn't match the type implied by its value");
-  return AttributeUniquer::get<FloatAttr>(type.getContext(),
-                                          Attribute::Kind::Float, type, value);
-}
-
-static FloatAttr getFloatAttr(Type type, double value,
-                              llvm::Optional<Location> loc) {
-  if (!type.isa<FloatType>()) {
-    if (loc)
-      type.getContext()->emitError(*loc) << "expected floating point type";
-    return nullptr;
-  }
-
-  // Treat BF16 as double because it is not supported in LLVM's APFloat.
-  // TODO(b/121118307): add BF16 support to APFloat?
-  if (type.isBF16() || type.isF64())
-    return FloatAttr::get(type, APFloat(value));
-
-  // This handles, e.g., F16 because there is no APFloat constructor for it.
-  bool unused;
-  APFloat val(value);
-  val.convert(type.cast<FloatType>().getFloatSemantics(),
-              APFloat::rmNearestTiesToEven, &unused);
-  return FloatAttr::get(type, val);
+FloatAttr FloatAttr::get(Type type, double value) {
+  return Base::get(type.getContext(), Attribute::Kind::Float, type, value);
 }
 
 FloatAttr FloatAttr::getChecked(Type type, double value, Location loc) {
-  return getFloatAttr(type, value, loc);
+  return Base::getChecked(loc, type.getContext(), Attribute::Kind::Float, type,
+                          value);
 }
 
-FloatAttr FloatAttr::get(Type type, double value) {
-  auto res = getFloatAttr(type, value, /*loc=*/llvm::None);
-  assert(res && "failed to construct float attribute");
-  return res;
+FloatAttr FloatAttr::get(Type type, const APFloat &value) {
+  return Base::get(type.getContext(), Attribute::Kind::Float, type, value);
 }
 
-APFloat FloatAttr::getValue() const {
-  return static_cast<ImplType *>(attr)->getValue();
+FloatAttr FloatAttr::getChecked(Type type, const APFloat &value, Location loc) {
+  return Base::getChecked(loc, type.getContext(), Attribute::Kind::Float, type,
+                          value);
 }
 
+APFloat FloatAttr::getValue() const { return getImpl()->getValue(); }
+
 double FloatAttr::getValueAsDouble() const {
   return getValueAsDouble(getValue());
 }
@@ -208,68 +182,91 @@ double FloatAttr::getValueAsDouble(APFloat value) {
   return value.convertToDouble();
 }
 
+/// Verify construction invariants.
+static LogicalResult verifyFloatTypeInvariants(llvm::Optional<Location> loc,
+                                               Type type) {
+  if (!type.isa<FloatType>()) {
+    if (loc)
+      type.getContext()->emitError(*loc, "expected floating point type");
+    return failure();
+  }
+  return success();
+}
+
+LogicalResult FloatAttr::verifyConstructionInvariants(
+    llvm::Optional<Location> loc, MLIRContext *ctx, Type type, double value) {
+  return verifyFloatTypeInvariants(loc, type);
+}
+
+LogicalResult
+FloatAttr::verifyConstructionInvariants(llvm::Optional<Location> loc,
+                                        MLIRContext *ctx, Type type,
+                                        const APFloat &value) {
+  // Verify that the type is correct.
+  if (failed(verifyFloatTypeInvariants(loc, type)))
+    return failure();
+
+  // Verify that the type semantics match that of the value.
+  if (&type.cast<FloatType>().getFloatSemantics() != &value.getSemantics()) {
+    if (loc)
+      ctx->emitError(
+          *loc, "FloatAttr type doesn't match the type implied by its value");
+    return failure();
+  }
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // StringAttr
 //===----------------------------------------------------------------------===//
 
 StringAttr StringAttr::get(StringRef bytes, MLIRContext *context) {
-  return AttributeUniquer::get<StringAttr>(context, Attribute::Kind::String,
-                                           bytes);
+  return Base::get(context, Attribute::Kind::String, bytes);
 }
 
-StringRef StringAttr::getValue() const {
-  return static_cast<ImplType *>(attr)->value;
-}
+StringRef StringAttr::getValue() const { return getImpl()->value; }
 
 //===----------------------------------------------------------------------===//
 // ArrayAttr
 //===----------------------------------------------------------------------===//
 
 ArrayAttr ArrayAttr::get(ArrayRef<Attribute> value, MLIRContext *context) {
-  return AttributeUniquer::get<ArrayAttr>(context, Attribute::Kind::Array,
-                                          value);
+  return Base::get(context, Attribute::Kind::Array, value);
 }
 
-ArrayRef<Attribute> ArrayAttr::getValue() const {
-  return static_cast<ImplType *>(attr)->value;
-}
+ArrayRef<Attribute> ArrayAttr::getValue() const { return getImpl()->value; }
 
 //===----------------------------------------------------------------------===//
 // AffineMapAttr
 //===----------------------------------------------------------------------===//
 
 AffineMapAttr AffineMapAttr::get(AffineMap value) {
-  return AttributeUniquer::get<AffineMapAttr>(
-      value.getResult(0).getContext(), Attribute::Kind::AffineMap, value);
+  return Base::get(value.getResult(0).getContext(), Attribute::Kind::AffineMap,
+                   value);
 }
 
-AffineMap AffineMapAttr::getValue() const {
-  return static_cast<ImplType *>(attr)->value;
-}
+AffineMap AffineMapAttr::getValue() const { return getImpl()->value; }
 
 //===----------------------------------------------------------------------===//
 // IntegerSetAttr
 //===----------------------------------------------------------------------===//
 
 IntegerSetAttr IntegerSetAttr::get(IntegerSet value) {
-  return AttributeUniquer::get<IntegerSetAttr>(
-      value.getConstraint(0).getContext(), Attribute::Kind::IntegerSet, value);
+  return Base::get(value.getConstraint(0).getContext(),
+                   Attribute::Kind::IntegerSet, value);
 }
 
-IntegerSet IntegerSetAttr::getValue() const {
-  return static_cast<ImplType *>(attr)->value;
-}
+IntegerSet IntegerSetAttr::getValue() const { return getImpl()->value; }
 
 //===----------------------------------------------------------------------===//
 // TypeAttr
 //===----------------------------------------------------------------------===//
 
 TypeAttr TypeAttr::get(Type value) {
-  return AttributeUniquer::get<TypeAttr>(value.getContext(),
-                                         Attribute::Kind::Type, value);
+  return Base::get(value.getContext(), Attribute::Kind::Type, value);
 }
 
-Type TypeAttr::getValue() const { return static_cast<ImplType *>(attr)->value; }
+Type TypeAttr::getValue() const { return getImpl()->value; }
 
 //===----------------------------------------------------------------------===//
 // FunctionAttr
@@ -277,8 +274,7 @@ Type TypeAttr::getValue() const { return static_cast<ImplType *>(attr)->value; }
 
 FunctionAttr FunctionAttr::get(Function *value) {
   assert(value && "Cannot get FunctionAttr for a null function");
-  return AttributeUniquer::get<FunctionAttr>(value->getContext(),
-                                             Attribute::Kind::Function, value);
+  return Base::get(value->getContext(), Attribute::Kind::Function, value);
 }
 
 /// This function is used by the internals of the Function class to null out
@@ -288,9 +284,7 @@ void FunctionAttr::dropFunctionReference(Function *value) {
                                         Attribute::Kind::Function, value);
 }
 
-Function *FunctionAttr::getValue() const {
-  return static_cast<ImplType *>(attr)->value;
-}
+Function *FunctionAttr::getValue() const { return getImpl()->value; }
 
 FunctionType FunctionAttr::getType() const {
   return Attribute::getType().cast<FunctionType>();
@@ -330,13 +324,11 @@ SplatElementsAttr SplatElementsAttr::get(VectorOrTensorType type,
                                          Attribute elt) {
   assert(elt.getType() == type.getElementType() &&
          "value should be of the given element type");
-  return AttributeUniquer::get<SplatElementsAttr>(
-      type.getContext(), Attribute::Kind::SplatElements, type, elt);
+  return Base::get(type.getContext(), Attribute::Kind::SplatElements, type,
+                   elt);
 }
 
-Attribute SplatElementsAttr::getValue() const {
-  return static_cast<ImplType *>(attr)->elt;
-}
+Attribute SplatElementsAttr::getValue() const { return getImpl()->elt; }
 
 //===----------------------------------------------------------------------===//
 // RawElementIterator
@@ -502,7 +494,7 @@ void DenseElementsAttr::getValues(SmallVectorImpl<Attribute> &values) const {
 }
 
 ArrayRef<char> DenseElementsAttr::getRawData() const {
-  return static_cast<ImplType *>(attr)->data;
+  return static_cast<ImplType *>(impl)->data;
 }
 
 // Constructs a dense elements attribute from an array of raw APInt values.
@@ -652,13 +644,11 @@ OpaqueElementsAttr OpaqueElementsAttr::get(Dialect *dialect,
                                            StringRef bytes) {
   assert(TensorType::isValidElementType(type.getElementType()) &&
          "Input element type should be a valid tensor element type");
-  return AttributeUniquer::get<OpaqueElementsAttr>(
-      type.getContext(), Attribute::Kind::OpaqueElements, type, dialect, bytes);
+  return Base::get(type.getContext(), Attribute::Kind::OpaqueElements, type,
+                   dialect, bytes);
 }
 
-StringRef OpaqueElementsAttr::getValue() const {
-  return static_cast<ImplType *>(attr)->bytes;
-}
+StringRef OpaqueElementsAttr::getValue() const { return getImpl()->bytes; }
 
 /// Return the value at the given index. If index does not refer to a valid
 /// element, then a null attribute is returned.
@@ -668,9 +658,7 @@ Attribute OpaqueElementsAttr::getValue(ArrayRef<uint64_t> index) const {
   return Attribute();
 }
 
-Dialect *OpaqueElementsAttr::getDialect() const {
-  return static_cast<ImplType *>(attr)->dialect;
-}
+Dialect *OpaqueElementsAttr::getDialect() const { return getImpl()->dialect; }
 
 bool OpaqueElementsAttr::decode(ElementsAttr &result) {
   if (auto *d = getDialect())
@@ -687,17 +675,16 @@ SparseElementsAttr SparseElementsAttr::get(VectorOrTensorType type,
                                            DenseElementsAttr values) {
   assert(indices.getType().getElementType().isInteger(64) &&
          "expected sparse indices to be 64-bit integer values");
-  return AttributeUniquer::get<SparseElementsAttr>(
-      type.getContext(), Attribute::Kind::SparseElements, type, indices,
-      values);
+  return Base::get(type.getContext(), Attribute::Kind::SparseElements, type,
+                   indices, values);
 }
 
 DenseIntElementsAttr SparseElementsAttr::getIndices() const {
-  return static_cast<ImplType *>(attr)->indices;
+  return getImpl()->indices;
 }
 
 DenseElementsAttr SparseElementsAttr::getValues() const {
-  return static_cast<ImplType *>(attr)->values;
+  return getImpl()->values;
 }
 
 /// Return the value of the element at the given index.
index 036fd09..d727fb1 100644 (file)
@@ -108,7 +108,7 @@ unsigned Type::getIntOrFloatBitWidth() const {
 //===----------------------------------------------------------------------===//
 
 Type VectorOrTensorType::getElementType() const {
-  return static_cast<ImplType *>(type)->elementType;
+  return static_cast<ImplType *>(impl)->elementType;
 }
 
 unsigned VectorOrTensorType::getElementTypeBitWidth() const {
@@ -360,21 +360,15 @@ MemRefType MemRefType::getImpl(ArrayRef<int64_t> shape, Type elementType,
                    cleanedAffineMapComposition, memorySpace);
 }
 
-ArrayRef<int64_t> MemRefType::getShape() const {
-  return static_cast<ImplType *>(type)->getShape();
-}
+ArrayRef<int64_t> MemRefType::getShape() const { return getImpl()->getShape(); }
 
-Type MemRefType::getElementType() const {
-  return static_cast<ImplType *>(type)->elementType;
-}
+Type MemRefType::getElementType() const { return getImpl()->elementType; }
 
 ArrayRef<AffineMap> MemRefType::getAffineMaps() const {
-  return static_cast<ImplType *>(type)->getAffineMaps();
+  return getImpl()->getAffineMaps();
 }
 
-unsigned MemRefType::getMemorySpace() const {
-  return static_cast<ImplType *>(type)->memorySpace;
-}
+unsigned MemRefType::getMemorySpace() const { return getImpl()->memorySpace; }
 
 unsigned MemRefType::getNumDynamicDims() const {
   return llvm::count_if(getShape(), [](int64_t i) { return i < 0; });
index 197e669..a8f8c6c 100644 (file)
 using namespace mlir;
 using namespace mlir::detail;
 
-unsigned Type::getKind() const { return type->getKind(); }
+unsigned Type::getKind() const { return impl->getKind(); }
 
 /// Get the dialect this type is registered to.
-const Dialect &Type::getDialect() const { return type->getDialect(); }
+const Dialect &Type::getDialect() const { return impl->getDialect(); }
 
 MLIRContext *Type::getContext() const { return getDialect().getContext(); }
 
-unsigned Type::getSubclassData() const { return type->getSubclassData(); }
-void Type::setSubclassData(unsigned val) { type->setSubclassData(val); }
+unsigned Type::getSubclassData() const { return impl->getSubclassData(); }
+void Type::setSubclassData(unsigned val) { impl->setSubclassData(val); }
 
 /// Function Type.
 
index 88ba4ae..d3d4c60 100644 (file)
@@ -28,7 +28,7 @@ using namespace mlir::quant;
 using namespace mlir::quant::detail;
 
 unsigned QuantizedType::getFlags() const {
-  return static_cast<ImplType *>(type)->flags;
+  return static_cast<ImplType *>(impl)->flags;
 }
 
 LogicalResult QuantizedType::verifyConstructionInvariants(
@@ -75,25 +75,25 @@ LogicalResult QuantizedType::verifyConstructionInvariants(
 }
 
 Type QuantizedType::getStorageType() const {
-  return static_cast<ImplType *>(type)->storageType;
+  return static_cast<ImplType *>(impl)->storageType;
 }
 
 int64_t QuantizedType::getStorageTypeMin() const {
-  return static_cast<ImplType *>(type)->storageTypeMin;
+  return static_cast<ImplType *>(impl)->storageTypeMin;
 }
 
 int64_t QuantizedType::getStorageTypeMax() const {
-  return static_cast<ImplType *>(type)->storageTypeMax;
+  return static_cast<ImplType *>(impl)->storageTypeMax;
 }
 
 unsigned QuantizedType::getStorageTypeIntegralWidth() const {
   // NOTE: If ever supporting non-integral storage types, some other scheme
   // for determining the width will be needed.
-  return static_cast<ImplType *>(type)->storageType.getIntOrFloatBitWidth();
+  return static_cast<ImplType *>(impl)->storageType.getIntOrFloatBitWidth();
 }
 
 Type QuantizedType::getExpressedType() const {
-  return static_cast<ImplType *>(type)->expressedType;
+  return static_cast<ImplType *>(impl)->expressedType;
 }
 
 bool QuantizedType::isCompatibleExpressedType(Type candidateExpressedType) {