[libc][NFC] Simplify op_generic
authorGuillaume Chatelet <gchatelet@google.com>
Fri, 7 Apr 2023 11:58:20 +0000 (11:58 +0000)
committerGuillaume Chatelet <gchatelet@google.com>
Fri, 7 Apr 2023 11:58:38 +0000 (11:58 +0000)
libc/src/string/memory_utils/op_generic.h

index 0a4a7cb..fd63ac6 100644 (file)
 
 namespace __llvm_libc::generic {
 
-// Implements generic load, store, splat and set for unsigned integral types.
-template <typename T> struct ScalarType {
-  static_assert(cpp::is_integral_v<T> && !cpp::is_signed_v<T>);
-  using Type = T;
-  static constexpr size_t SIZE = sizeof(Type);
-
-  LIBC_INLINE static Type load(CPtr src) {
-    return ::__llvm_libc::load<Type>(src);
-  }
-  LIBC_INLINE static void store(Ptr dst, Type value) {
-    ::__llvm_libc::store<Type>(dst, value);
-  }
-  LIBC_INLINE static Type splat(uint8_t value) {
-    return Type(~0) / Type(0xFF) * Type(value);
-  }
-  LIBC_INLINE static void set(Ptr dst, uint8_t value) {
-    store(dst, splat(value));
-  }
+// Compiler types using the vector attributes.
+using uint8x1_t = uint8_t __attribute__((__vector_size__(1)));
+using uint8x2_t = uint8_t __attribute__((__vector_size__(2)));
+using uint8x4_t = uint8_t __attribute__((__vector_size__(4)));
+using uint8x8_t = uint8_t __attribute__((__vector_size__(8)));
+using uint8x16_t = uint8_t __attribute__((__vector_size__(16)));
+using uint8x32_t = uint8_t __attribute__((__vector_size__(32)));
+using uint8x64_t = uint8_t __attribute__((__vector_size__(64)));
+
+// We accept three types of values as elements for generic operations:
+// - scalar : unsigned integral types
+// - vector : compiler types using the vector attributes
+// - array  : a cpp::array<T, N> where T is itself either a scalar or a vector.
+// The following traits help discriminate between these cases.
+
+template <typename T>
+constexpr bool is_scalar_v = cpp::is_integral_v<T> && cpp::is_unsigned_v<T>;
+
+template <typename T>
+constexpr bool is_vector_v =
+    cpp::details::is_unqualified_any_of<T, uint8x1_t, uint8x2_t, uint8x4_t,
+                                        uint8x8_t, uint8x16_t, uint8x32_t,
+                                        uint8x64_t>();
+
+template <class T> struct is_array : cpp::false_type {};
+template <class T, size_t N> struct is_array<cpp::array<T, N>> {
+  static constexpr bool value = is_scalar_v<T> || is_vector_v<T>;
 };
+template <typename T> constexpr bool is_array_v = is_array<T>::value;
 
-// GCC can only take literals as __vector_size__ argument so we have to use
-// template specialization.
-template <size_t Size> struct VectorValueType {};
-template <> struct VectorValueType<1> {
-  using type = uint8_t __attribute__((__vector_size__(1)));
-};
-template <> struct VectorValueType<2> {
-  using type = uint8_t __attribute__((__vector_size__(2)));
-};
-template <> struct VectorValueType<4> {
-  using type = uint8_t __attribute__((__vector_size__(4)));
-};
-template <> struct VectorValueType<8> {
-  using type = uint8_t __attribute__((__vector_size__(8)));
-};
-template <> struct VectorValueType<16> {
-  using type = uint8_t __attribute__((__vector_size__(16)));
-};
-template <> struct VectorValueType<32> {
-  using type = uint8_t __attribute__((__vector_size__(32)));
-};
-template <> struct VectorValueType<64> {
-  using type = uint8_t __attribute__((__vector_size__(64)));
-};
+template <typename T>
+constexpr bool is_element_type_v =
+    is_scalar_v<T> || is_vector_v<T> || is_array_v<T>;
 
-// Implements generic load, store, splat and set for vector types.
-template <size_t Size> struct VectorType {
-  using Type = typename VectorValueType<Size>::type;
-  static constexpr size_t SIZE = Size;
-  LIBC_INLINE static Type load(CPtr src) {
-    return ::__llvm_libc::load<Type>(src);
+//
+template <class T> struct array_size {};
+template <class T, size_t N>
+struct array_size<cpp::array<T, N>> : cpp::integral_constant<size_t, N> {};
+template <typename T> constexpr size_t array_size_v = array_size<T>::value;
+
+// Generic operations for the above type categories.
+
+template <typename T> T load(CPtr src) {
+  static_assert(is_element_type_v<T>);
+  if constexpr (is_scalar_v<T> || is_vector_v<T>) {
+    return ::__llvm_libc::load<T>(src);
+  } else if constexpr (is_array_v<T>) {
+    using value_type = typename T::value_type;
+    T Value;
+    for (size_t I = 0; I < array_size_v<T>; ++I)
+      Value[I] = load<value_type>(src + (I * sizeof(value_type)));
+    return Value;
   }
-  LIBC_INLINE static void store(Ptr dst, Type value) {
-    ::__llvm_libc::store<Type>(dst, value);
+}
+
+template <typename T> void store(Ptr dst, T value) {
+  static_assert(is_element_type_v<T>);
+  if constexpr (is_scalar_v<T> || is_vector_v<T>) {
+    ::__llvm_libc::store<T>(dst, value);
+  } else if constexpr (is_array_v<T>) {
+    using value_type = typename T::value_type;
+    for (size_t I = 0; I < array_size_v<T>; ++I)
+      store<value_type>(dst + (I * sizeof(value_type)), value[I]);
   }
-  LIBC_INLINE static Type splat(uint8_t value) {
-    Type Out;
+}
+
+template <typename T> T splat(uint8_t value) {
+  static_assert(is_scalar_v<T> || is_vector_v<T>);
+  if constexpr (is_scalar_v<T>)
+    return T(~0) / T(0xFF) * T(value);
+  else if constexpr (is_vector_v<T>) {
+    T Out;
     // This for loop is optimized out for vector types.
-    for (size_t i = 0; i < Size; ++i)
+    for (size_t i = 0; i < sizeof(T); ++i)
       Out[i] = static_cast<uint8_t>(value);
     return Out;
   }
-  LIBC_INLINE static void set(Ptr dst, uint8_t value) {
-    store(dst, splat(value));
-  }
-};
+}
 
-// Implements load, store and set for sizes not natively supported by the
-// platform. SubType is either ScalarType or VectorType.
-template <typename SubType, size_t ArraySize> struct ArrayType {
-  using Type = cpp::array<typename SubType::Type, ArraySize>;
-  static constexpr size_t SizeOfElement = SubType::SIZE;
-  static constexpr size_t SIZE = SizeOfElement * ArraySize;
-  LIBC_INLINE static Type load(CPtr src) {
-    Type Value;
-    for (size_t I = 0; I < ArraySize; ++I)
-      Value[I] = SubType::load(src + (I * SizeOfElement));
-    return Value;
-  }
-  LIBC_INLINE static void store(Ptr dst, Type Value) {
-    for (size_t I = 0; I < ArraySize; ++I)
-      SubType::store(dst + (I * SizeOfElement), Value[I]);
+template <typename T> void set(Ptr dst, uint8_t value) {
+  static_assert(is_element_type_v<T>);
+  if constexpr (is_scalar_v<T> || is_vector_v<T>) {
+    store<T>(dst, splat<T>(value));
+  } else if constexpr (is_array_v<T>) {
+    using value_type = typename T::value_type;
+    const value_type Splat = splat<value_type>(value);
+    for (size_t I = 0; I < array_size_v<T>; ++I)
+      store<value_type>(dst + (I * sizeof(value_type)), Splat);
   }
-  LIBC_INLINE static void set(Ptr dst, uint8_t value) {
-    const auto Splat = SubType::splat(value);
-    for (size_t I = 0; I < ArraySize; ++I)
-      SubType::store(dst + (I * SizeOfElement), Splat);
-  }
-};
+}
 
 static_assert((UINTPTR_MAX == 4294967295U) ||
                   (UINTPTR_MAX == 18446744073709551615UL),
@@ -134,20 +135,28 @@ static_assert((UINTPTR_MAX == 4294967295U) ||
 #endif
 
 namespace details {
-// Checks that each type's SIZE is sorted in strictly decreasing order.
-// i.e. First::SIZE > Second::SIZE > ... > Last::SIZE
+// Checks that each type is sorted in strictly decreasing order of size.
+// i.e. sizeof(First) > sizeof(Second) > ... > sizeof(Last)
+template <typename First> constexpr bool is_decreasing_size() {
+  return sizeof(First) == 1;
+}
 template <typename First, typename Second, typename... Next>
 constexpr bool is_decreasing_size() {
-  if constexpr (sizeof...(Next) > 0) {
-    return (First::SIZE > Second::SIZE) &&
-           is_decreasing_size<Second, Next...>();
-  } else {
-    return First::SIZE > Second::SIZE;
-  }
+  if constexpr (sizeof...(Next) > 0)
+    return sizeof(First) > sizeof(Second) && is_decreasing_size<Next...>();
+  else
+    return sizeof(First) > sizeof(Second) && is_decreasing_size<Second>();
 }
 
-// Helper to test if a type is void.
-template <typename T> inline constexpr bool is_void_v = cpp::is_same_v<T, void>;
+template <size_t Size, typename... Ts> struct Largest;
+template <size_t Size> struct Largest<Size> {
+  using type = uint8_t;
+};
+template <size_t Size, typename T, typename... Ts>
+struct Largest<Size, T, Ts...> {
+  using next = Largest<Size, Ts...>;
+  using type = cpp::conditional_t<(Size >= sizeof(T)), T, typename next::type>;
+};
 
 } // namespace details
 
@@ -160,26 +169,14 @@ template <typename T> inline constexpr bool is_void_v = cpp::is_same_v<T, void>;
 // using ST = SupportedTypes<ScalarType<uint16_t>, ScalarType<uint8_t>>;
 // using Type = ST::TypeFor<10>;
 // static_assert(cpp:is_same_v<Type, ScalarType<uint16_t>>);
-template <typename First, typename Second, typename... Next>
-struct SupportedTypes {
-  static_assert(details::is_decreasing_size<First, Second, Next...>());
-  using MaxType = First;
 
-  template <size_t Size>
-  using TypeFor = cpp::conditional_t<
-      (Size >= First::SIZE), First,
-      typename SupportedTypes<Second, Next...>::template TypeFor<Size>>;
-};
+template <typename First, typename... Ts> struct SupportedTypes {
+  static_assert(details::is_decreasing_size<First, Ts...>());
 
-template <typename First, typename Second>
-struct SupportedTypes<First, Second> {
-  static_assert(details::is_decreasing_size<First, Second>());
   using MaxType = First;
 
   template <size_t Size>
-  using TypeFor = cpp::conditional_t<
-      (Size >= First::SIZE), First,
-      cpp::conditional_t<(Size >= Second::SIZE), Second, void>>;
+  using TypeFor = typename details::Largest<Size, First, Ts...>::type;
 };
 
 // Map from sizes to structures offering static load, store and splat methods.
@@ -189,19 +186,21 @@ struct SupportedTypes<First, Second> {
 // Lists a generic native types to use for Memset and Memmove operations.
 // TODO: Inject the native types within Memset and Memmove depending on the
 // target architectures and derive MaxSize from it.
-using NativeTypeMap =
-    SupportedTypes<VectorType<64>, //
-                   VectorType<32>, //
-                   VectorType<16>,
+using NativeTypeMap = SupportedTypes<uint8x64_t, //
+                                     uint8x32_t, //
+                                     uint8x16_t,
 #if defined(LLVM_LIBC_HAS_UINT64)
-                   ScalarType<uint64_t>, // Not available on 32bit
+                                     uint64_t, // Not available on 32bit
 #endif
-                   ScalarType<uint32_t>, //
-                   ScalarType<uint16_t>, //
-                   ScalarType<uint8_t>>;
+                                     uint32_t, //
+                                     uint16_t, //
+                                     uint8_t>;
 
 namespace details {
 
+// Helper to test if a type is void.
+template <typename T> inline constexpr bool is_void_v = cpp::is_same_v<T, void>;
+
 // In case the 'Size' is not supported we can fall back to a sequence of smaller
 // operations using the largest natively supported type.
 template <size_t Size, size_t MaxSize> static constexpr bool useArrayType() {
@@ -214,12 +213,11 @@ template <size_t Size, size_t MaxSize> static constexpr bool useArrayType() {
 template <size_t Size, size_t MaxSize>
 using getTypeFor = cpp::conditional_t<
     useArrayType<Size, MaxSize>(),
-    ArrayType<NativeTypeMap::TypeFor<MaxSize>, Size / MaxSize>,
+    cpp::array<NativeTypeMap::TypeFor<MaxSize>, Size / MaxSize>,
     NativeTypeMap::TypeFor<Size>>;
 
 } // namespace details
 
-
 ///////////////////////////////////////////////////////////////////////////////
 // Memset
 // The MaxSize template argument gives the maximum size handled natively by the
@@ -240,7 +238,7 @@ template <size_t Size, size_t MaxSize> struct Memset {
       if constexpr (details::is_void_v<T>) {
         deferred_static_assert("Unimplemented Size");
       } else {
-        T::set(dst, value);
+        set<T>(dst, value);
       }
     }
   }
@@ -278,7 +276,7 @@ template <size_t Size, size_t MaxSize> struct Memmove {
     if constexpr (details::is_void_v<T>) {
       deferred_static_assert("Unimplemented Size");
     } else {
-      T::store(dst, T::load(src));
+      store<T>(dst, load<T>(src));
     }
   }
 
@@ -290,10 +288,10 @@ template <size_t Size, size_t MaxSize> struct Memmove {
       // The load and store operations can be performed in any order as long as
       // they are not interleaved. More investigations are needed to determine
       // the best order.
-      const auto head = T::load(src);
-      const auto tail = T::load(src + offset);
-      T::store(dst, head);
-      T::store(dst + offset, tail);
+      const auto head = load<T>(src);
+      const auto tail = load<T>(src + offset);
+      store<T>(dst, head);
+      store<T>(dst + offset, tail);
     }
   }
 
@@ -373,14 +371,14 @@ template <size_t Size, size_t MaxSize> struct Memmove {
                                                 size_t count) {
     static_assert(Size > 1, "a loop of size 1 does not need tail");
     const size_t tail_offset = count - Size;
-    const auto tail_value = T::load(src + tail_offset);
+    const auto tail_value = load<T>(src + tail_offset);
     size_t offset = 0;
     LIBC_LOOP_NOUNROLL
     do {
       block(dst + offset, src + offset);
       offset += Size;
     } while (offset < count - Size);
-    T::store(dst + tail_offset, tail_value);
+    store<T>(dst + tail_offset, tail_value);
   }
 
   // Move backward suitable when dst > src. We load the head bytes before
@@ -400,14 +398,14 @@ template <size_t Size, size_t MaxSize> struct Memmove {
   LIBC_INLINE static void loop_and_tail_backward(Ptr dst, CPtr src,
                                                  size_t count) {
     static_assert(Size > 1, "a loop of size 1 does not need tail");
-    const auto head_value = T::load(src);
+    const auto head_value = load<T>(src);
     ptrdiff_t offset = count - Size;
     LIBC_LOOP_NOUNROLL
     do {
       block(dst + offset, src + offset);
       offset -= Size;
     } while (offset >= 0);
-    T::store(dst, head_value);
+    store<T>(dst, head_value);
   }
 };