From 2e71dad3328e03337eca53352e7d45b6efc7e0a2 Mon Sep 17 00:00:00 2001 From: River Riddle Date: Thu, 12 Nov 2020 22:43:06 -0800 Subject: [PATCH] [mlir][DenseElementsAttr] Allow for custom floating point types in `getValues` Some users have native c++ data types that correspond to floating point values stored within a DenseElementsAttr that do not have a corresponding native C++ data type(e.g. bfloat16/half/etc). This revision allows for such users to use those native types directly, and removes the need to go through APFloat when the much faster native value path is available. Differential Revision: https://reviews.llvm.org/D91402 --- mlir/include/mlir/IR/Attributes.h | 28 ++++++++++++++++++++++------ 1 file changed, 22 insertions(+), 6 deletions(-) diff --git a/mlir/include/mlir/IR/Attributes.h b/mlir/include/mlir/IR/Attributes.h index f024708..0432d05 100644 --- a/mlir/include/mlir/IR/Attributes.h +++ b/mlir/include/mlir/IR/Attributes.h @@ -675,6 +675,21 @@ class DenseElementsAttr : public ElementsAttr { public: using ElementsAttr::ElementsAttr; + /// Type trait used to check if the given type T is a potentially valid C++ + /// floating point type that can be used to access the underlying element + /// types of a DenseElementsAttr. + // TODO: Use std::disjunction when C++17 is supported. + template struct is_valid_cpp_fp_type { + /// The type is a valid floating point type if it is a builtin floating + /// point type, or is a potentially user defined floating point type. The + /// latter allows for supporting users that have custom types defined for + /// bfloat16/half/etc. + static inline constexpr bool value = + llvm::is_one_of::value || + (std::numeric_limits::is_specialized && + !std::numeric_limits::is_integer); + }; + /// Method for support type inquiry through isa, cast and dyn_cast. static bool classof(Attribute attr); @@ -690,7 +705,7 @@ public: /// static shape. template ::is_integer || - llvm::is_one_of::value>::type> + is_valid_cpp_fp_type::value>::type> static DenseElementsAttr get(const ShapedType &type, ArrayRef values) { const char *data = reinterpret_cast(values.data()); return getRawIntOrFloat( @@ -701,7 +716,7 @@ public: /// Constructs a dense integer elements attribute from a single element. template ::is_integer || - llvm::is_one_of::value || + is_valid_cpp_fp_type::value || detail::is_complex_t::value>::type> static DenseElementsAttr get(const ShapedType &type, T value) { return get(type, llvm::makeArrayRef(value)); @@ -714,7 +729,7 @@ public: typename = typename std::enable_if< detail::is_complex_t::value && (std::numeric_limits::is_integer || - llvm::is_one_of::value)>::type> + is_valid_cpp_fp_type::value)>::type> static DenseElementsAttr get(const ShapedType &type, ArrayRef values) { const char *data = reinterpret_cast(values.data()); return getRawComplex(type, ArrayRef(data, values.size() * sizeof(T)), @@ -944,7 +959,7 @@ public: template ::value && std::numeric_limits::is_integer) || - llvm::is_one_of::value>::type> + is_valid_cpp_fp_type::value>::type> llvm::iterator_range> getValues() const { assert(isValidIntOrFloat(sizeof(T), std::numeric_limits::is_integer, std::numeric_limits::is_signed)); @@ -959,7 +974,7 @@ public: typename = typename std::enable_if< detail::is_complex_t::value && (std::numeric_limits::is_integer || - llvm::is_one_of::value)>::type> + is_valid_cpp_fp_type::value)>::type> llvm::iterator_range> getValues() const { assert(isValidComplex(sizeof(T), std::numeric_limits::is_integer, std::numeric_limits::is_signed)); @@ -1411,7 +1426,8 @@ private: template typename std::enable_if< std::numeric_limits::is_integer || - llvm::is_one_of::value || + DenseElementsAttr::is_valid_cpp_fp_type::value || + std::is_same::value || (detail::is_complex_t::value && !llvm::is_one_of, std::complex>::value), -- 2.7.4