[mlir][DenseElementsAttr] Allow for custom floating point types in `getValues`
authorRiver Riddle <riddleriver@gmail.com>
Fri, 13 Nov 2020 06:43:06 +0000 (22:43 -0800)
committerRiver Riddle <riddleriver@gmail.com>
Fri, 13 Nov 2020 06:47:30 +0000 (22:47 -0800)
Some users have native c++ data types that correspond to floating point values stored within a DenseElementsAttr that do not have a corresponding native C++ data type(e.g. bfloat16/half/etc). This revision allows for such users to use those native types directly, and removes the need to go through APFloat when the much faster native value path is available.

Differential Revision: https://reviews.llvm.org/D91402

mlir/include/mlir/IR/Attributes.h

index f024708..0432d05 100644 (file)
@@ -675,6 +675,21 @@ class DenseElementsAttr : public ElementsAttr {
 public:
   using ElementsAttr::ElementsAttr;
 
+  /// Type trait used to check if the given type T is a potentially valid C++
+  /// floating point type that can be used to access the underlying element
+  /// types of a DenseElementsAttr.
+  // TODO: Use std::disjunction when C++17 is supported.
+  template <typename T> struct is_valid_cpp_fp_type {
+    /// The type is a valid floating point type if it is a builtin floating
+    /// point type, or is a potentially user defined floating point type. The
+    /// latter allows for supporting users that have custom types defined for
+    /// bfloat16/half/etc.
+    static inline constexpr bool value =
+        llvm::is_one_of<T, float, double>::value ||
+        (std::numeric_limits<T>::is_specialized &&
+         !std::numeric_limits<T>::is_integer);
+  };
+
   /// Method for support type inquiry through isa, cast and dyn_cast.
   static bool classof(Attribute attr);
 
@@ -690,7 +705,7 @@ public:
   /// static shape.
   template <typename T, typename = typename std::enable_if<
                             std::numeric_limits<T>::is_integer ||
-                            llvm::is_one_of<T, float, double>::value>::type>
+                            is_valid_cpp_fp_type<T>::value>::type>
   static DenseElementsAttr get(const ShapedType &type, ArrayRef<T> values) {
     const char *data = reinterpret_cast<const char *>(values.data());
     return getRawIntOrFloat(
@@ -701,7 +716,7 @@ public:
   /// Constructs a dense integer elements attribute from a single element.
   template <typename T, typename = typename std::enable_if<
                             std::numeric_limits<T>::is_integer ||
-                            llvm::is_one_of<T, float, double>::value ||
+                            is_valid_cpp_fp_type<T>::value ||
                             detail::is_complex_t<T>::value>::type>
   static DenseElementsAttr get(const ShapedType &type, T value) {
     return get(type, llvm::makeArrayRef(value));
@@ -714,7 +729,7 @@ public:
             typename = typename std::enable_if<
                 detail::is_complex_t<T>::value &&
                 (std::numeric_limits<ElementT>::is_integer ||
-                 llvm::is_one_of<ElementT, float, double>::value)>::type>
+                 is_valid_cpp_fp_type<ElementT>::value)>::type>
   static DenseElementsAttr get(const ShapedType &type, ArrayRef<T> values) {
     const char *data = reinterpret_cast<const char *>(values.data());
     return getRawComplex(type, ArrayRef<char>(data, values.size() * sizeof(T)),
@@ -944,7 +959,7 @@ public:
   template <typename T, typename = typename std::enable_if<
                             (!std::is_same<T, bool>::value &&
                              std::numeric_limits<T>::is_integer) ||
-                            llvm::is_one_of<T, float, double>::value>::type>
+                            is_valid_cpp_fp_type<T>::value>::type>
   llvm::iterator_range<ElementIterator<T>> getValues() const {
     assert(isValidIntOrFloat(sizeof(T), std::numeric_limits<T>::is_integer,
                              std::numeric_limits<T>::is_signed));
@@ -959,7 +974,7 @@ public:
             typename = typename std::enable_if<
                 detail::is_complex_t<T>::value &&
                 (std::numeric_limits<ElementT>::is_integer ||
-                 llvm::is_one_of<ElementT, float, double>::value)>::type>
+                 is_valid_cpp_fp_type<ElementT>::value)>::type>
   llvm::iterator_range<ElementIterator<T>> getValues() const {
     assert(isValidComplex(sizeof(T), std::numeric_limits<ElementT>::is_integer,
                           std::numeric_limits<ElementT>::is_signed));
@@ -1411,7 +1426,8 @@ private:
   template <typename T>
   typename std::enable_if<
       std::numeric_limits<T>::is_integer ||
-          llvm::is_one_of<T, float, double, StringRef>::value ||
+          DenseElementsAttr::is_valid_cpp_fp_type<T>::value ||
+          std::is_same<T, StringRef>::value ||
           (detail::is_complex_t<T>::value &&
            !llvm::is_one_of<T, std::complex<APInt>,
                             std::complex<APFloat>>::value),