[ BLAS ] Support non-4-divisible case in matrix transpose
author skykongkong8 <ss.kong@samsung.com>
Mon, 13 May 2024 07:53:18 +0000 (16:53 +0900)
committer Jijoong Moon <jijoong.moon@samsung.com>
Tue, 4 Jun 2024 09:55:20 +0000 (18:55 +0900)
- Previously, there was a code defect when transposing a matrix whose column length is not divisible by 4.
- Fix the bug and refactor the calling interface: move the transpose fallback out of the NEON code so that transpose_neon is used unconditionally when NEON is supported.
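- A minimal sketch of the affected case, with hypothetical dimensions (assumes an ENABLE_FP16 build and <vector>):

    // 8 x 7 fp16 matrix: the column length 7 is not divisible by 4.
    std::vector<_FP16> src(8 * 7), dst(7 * 8);
    for (unsigned int i = 0; i < src.size(); ++i)
      src[i] = static_cast<_FP16>(i);
    nntrainer::transpose_matrix(8, 7, src.data(), 7, dst.data(), 8);
    // Expected: dst[i + j * 8] == src[i * 7 + j] for every i < 8, j < 7.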

**Self evaluation:**
1. Build test:     [X]Passed [ ]Failed [ ]Skipped
2. Run test:     [X]Passed [ ]Failed [ ]Skipped

Signed-off-by: skykongkong8 <ss.kong@samsung.com>
nntrainer/tensor/blas_interface.cpp
nntrainer/tensor/matrix_transpose_neon/matrix_transpose_kernels_neon.h [new file with mode: 0644]
nntrainer/tensor/matrix_transpose_neon/matrix_transpose_neon.cpp
nntrainer/tensor/matrix_transpose_neon/matrix_transpose_neon.h
nntrainer/tensor/matrix_transpose_neon/meson.build
nntrainer/tensor/matrix_transpose_neon/transpose_utils_neon.h [deleted file]

index 8c6adffed04bde8bb28bb667ff440f2ef4359467..ac2085eaaf02e015a3e7cf550bb91031ac55f4ed 100644 (file)
 
 namespace nntrainer {
 
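+/**
+ * @brief Matrix transpose fallback. Note that this transposes a matrix, which
+ * is a 2D Tensor.
+ *
+ * @tparam T data type of the incoming matrix
+ * @param M row length of the input matrix
+ * @param N col length of the input matrix
+ * @param src source data of the input matrix
+ * @param ld_src data offset of the input matrix
+ * @param dst destination data of this function
+ * @param ld_dst data offset of the output matrix
+ */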
+template <typename T>
+static inline void transpose_fallback(
+    unsigned int M,
+    unsigned int N,
+    const T* src,
+    unsigned int ld_src,
+    T* dst,
+    unsigned int ld_dst) {
+  for (unsigned int j = 0; j < N; j++) {
+    for (unsigned int i = 0; i < M; i++) {
+      dst[i + j * ld_dst] = src[i * ld_src + j];
+    }
+  }
+}
+
 #ifdef ENABLE_FP16
 static void saxpy_FP16(const unsigned int N, const float alpha, const _FP16 *X,
                        const int incX, _FP16 *Y, const int incY) {
@@ -535,21 +550,9 @@ void transpose_matrix(const unsigned int M, const unsigned int N,
                       const _FP16 *src, unsigned int ld_src, _FP16 *dst,
                       unsigned int ld_dst) {
 #ifdef USE_NEON
-/// @note Final form of transpose_neon is NOT having fallback. Debugging WIP.
-  if ((M & 0x3) == 0) {
-    transpose_neon<_FP16>(M, N, src, ld_src, dst, ld_dst);
-  } else {
-    transpose_fallback<_FP16>(M, N, src, ld_src, dst, ld_dst);
-  }
+  transpose_neon<_FP16>(M, N, src, ld_src, dst, ld_dst);
 #else
-  /// @note This code should be replaced with:
-  /// transpose_fallback<_FP16>(M, N, src, ld_src, dst, ld_dst);
-  /// during arch-dep freeing refactorization.
-  for (unsigned int j = 0; j < N; j++) {
-    for (unsigned int i = 0; i < M; i++) {
-      dst[i + j * ld_dst] = src[i * ld_src + j];
-    }
-  }
+  transpose_fallback<_FP16>(M, N, src, ld_src, dst, ld_dst);
 #endif
 }
 #endif
diff --git a/nntrainer/tensor/matrix_transpose_neon/matrix_transpose_kernels_neon.h b/nntrainer/tensor/matrix_transpose_neon/matrix_transpose_kernels_neon.h
new file mode 100644 (file)
index 0000000..b8f7f2d
--- /dev/null
@@ -0,0 +1,212 @@
+// SPDX-License-Identifier: Apache-2.0
+/**
+ * Copyright (C) 2024 Sungsik Kong <ss.kong@samsung.com>
+ *
+ * @file   matrix_transpose_kernels_neon.h
+ * @date   09 May 2024
+ * @see    https://github.com/nnstreamer/nntrainer
+ * @author Sungsik Kong <ss.kong@samsung.com>
+ * @bug    No known bugs except for NYI items
+ * @brief  These are internal util functions for transposing matrices with NEON
+ *
+ */
+
+#include <arm_neon.h>
+#include <cassert>
+#include <cstdint>
+#include <mask_neon.h>
+
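+/// @brief In-register transpose of a 4x4 fp16 tile held in four float16x4_t rows.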
+#define TRANSPOSE_FP16_4x4(row0, row1, row2, row3)                             \
+  float16x4x2_t row01 = vtrn_f16(row0, row1);                                  \
+  float16x4x2_t row23 = vtrn_f16(row2, row3);                                  \
+  row0 = vcvt_f16_f32(vcombine_f32(vget_low_f32(vcvt_f32_f16(row01.val[0])),   \
+                                   vget_low_f32(vcvt_f32_f16(row23.val[0])))); \
+  row1 = vcvt_f16_f32(vcombine_f32(vget_low_f32(vcvt_f32_f16(row01.val[1])),   \
+                                   vget_low_f32(vcvt_f32_f16(row23.val[1])))); \
+  row2 =                                                                       \
+    vcvt_f16_f32(vcombine_f32(vget_high_f32(vcvt_f32_f16(row01.val[0])),       \
+                              vget_high_f32(vcvt_f32_f16(row23.val[0]))));     \
+  row3 =                                                                       \
+    vcvt_f16_f32(vcombine_f32(vget_high_f32(vcvt_f32_f16(row01.val[1])),       \
+                              vget_high_f32(vcvt_f32_f16(row23.val[1]))));
+
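+/**
+ * @brief Transpose a full 4x4 fp16 block from src (leading dimension ld_src)
+ * into dst (leading dimension ld_dst) using NEON registers.
+ */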
+static inline void transpose_kernel_4x4_neon(const __fp16 *src,
+                                             unsigned int ld_src, __fp16 *dst,
+                                             unsigned int ld_dst) {
+  float16x4_t a = vld1_f16(&src[0 * ld_src]);
+  float16x4_t b = vld1_f16(&src[1 * ld_src]);
+  float16x4_t c = vld1_f16(&src[2 * ld_src]);
+  float16x4_t d = vld1_f16(&src[3 * ld_src]);
+
+  TRANSPOSE_FP16_4x4(a, b, c, d);
+
+  vst1_f16(&dst[0 * ld_dst], a);
+  vst1_f16(&dst[1 * ld_dst], b);
+  vst1_f16(&dst[2 * ld_dst], c);
+  vst1_f16(&dst[3 * ld_dst], d);
+}
+
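+/**
+ * @brief Transpose a partial M x N fp16 block with M, N <= 4 using 64-bit NEON
+ * vectors. Lanes outside the block are masked, and the final store blends with
+ * the existing dst contents so neighboring elements are preserved.
+ */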
+template <unsigned int M>
+static void transpose_kernel_mxn_neon_128(unsigned int N, const __fp16 *src,
+                                          unsigned int ld_src, __fp16 *dst,
+                                          unsigned int ld_dst) {
+
+  uint16x4_t bitmask_v8 =
+    vld1_u16(reinterpret_cast<const uint16_t *>(masks[N]));
+  float16x4_t input[4];
+  float16x4_t ZEROS = vmov_n_f16(0.F);
+
+  unsigned i;
+  for (i = 0; i < M; ++i) {
+    input[i] = vbsl_f16(bitmask_v8, vld1_f16(&src[i * ld_src]), ZEROS);
+  }
+  for (; i < 4; ++i) {
+    input[i] = vmov_n_f16(0.F);
+  }
+
+  float16x4_t temp[4];
+  for (i = 0; i < (M + 1) / 2; ++i) {
+    temp[2 * i] = vzip1_f16(input[2 * i], input[2 * i + 1]);
+    temp[2 * i + 1] = vzip2_f16(input[2 * i], input[2 * i + 1]);
+  }
+  for (i = i * 2; i < 4; ++i) {
+    temp[i] = vmov_n_f16(0.F);
+  }
+
+  bitmask_v8 = vld1_u16(reinterpret_cast<const uint16_t *>(masks[M]));
+  for (i = 0; i < N; ++i) {
+    if (i % 2 == 0) {
+      input[i] =
+        vcvt_f16_f32(vcombine_f32(vget_low_f32(vcvt_f32_f16(temp[i / 2])),
+                                  vget_low_f32(vcvt_f32_f16(temp[2 + i / 2]))));
+    } else {
+      input[i] = vcvt_f16_f32(
+        vcombine_f32(vget_high_f32(vcvt_f32_f16(temp[i / 2])),
+                     vget_high_f32(vcvt_f32_f16(temp[2 + i / 2]))));
+    }
+    vst1_f16(&dst[i * ld_dst],
+             vbsl_f16(bitmask_v8, input[i], vld1_f16(&dst[i * ld_dst])));
+  }
+}
+
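+/**
+ * @brief Transpose a full 8x8 fp16 block from src into dst using NEON
+ * zip / blend / extract operations.
+ */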
+static inline void transpose_kernel_8x8_neon(const __fp16 *src,
+                                             unsigned int ld_src, __fp16 *dst,
+                                             unsigned int ld_dst) {
+  float16x8_t a = vld1q_f16(&src[0 * ld_src]);
+  float16x8_t b = vld1q_f16(&src[1 * ld_src]);
+  float16x8_t c = vld1q_f16(&src[2 * ld_src]);
+  float16x8_t d = vld1q_f16(&src[3 * ld_src]);
+  float16x8_t e = vld1q_f16(&src[4 * ld_src]);
+  float16x8_t f = vld1q_f16(&src[5 * ld_src]);
+  float16x8_t g = vld1q_f16(&src[6 * ld_src]);
+  float16x8_t h = vld1q_f16(&src[7 * ld_src]);
+
+  float16x8_t ab0145, ab2367, cd0145, cd2367, ef0145, ef2367, gh0145, gh2367;
+  float16x8_t abcd04, abcd15, efgh04, efgh15, abcd26, abcd37, efgh26, efgh37;
+
+  ab0145 = vcombine_f16(vzip1_f16(vget_low_f16(a), vget_low_f16(b)),
+                        vzip1_f16(vget_high_f16(a), vget_high_f16(b)));
+  ab2367 = vcombine_f16(vzip2_f16(vget_low_f16(a), vget_low_f16(b)),
+                        vzip2_f16(vget_high_f16(a), vget_high_f16(b)));
+  cd0145 = vcombine_f16(vzip1_f16(vget_low_f16(c), vget_low_f16(d)),
+                        vzip1_f16(vget_high_f16(c), vget_high_f16(d)));
+  cd2367 = vcombine_f16(vzip2_f16(vget_low_f16(c), vget_low_f16(d)),
+                        vzip2_f16(vget_high_f16(c), vget_high_f16(d)));
+  ef0145 = vcombine_f16(vzip1_f16(vget_low_f16(e), vget_low_f16(f)),
+                        vzip1_f16(vget_high_f16(e), vget_high_f16(f)));
+  ef2367 = vcombine_f16(vzip2_f16(vget_low_f16(e), vget_low_f16(f)),
+                        vzip2_f16(vget_high_f16(e), vget_high_f16(f)));
+  gh0145 = vcombine_f16(vzip1_f16(vget_low_f16(g), vget_low_f16(h)),
+                        vzip1_f16(vget_high_f16(g), vget_high_f16(h)));
+  gh2367 = vcombine_f16(vzip2_f16(vget_low_f16(g), vget_low_f16(h)),
+                        vzip2_f16(vget_high_f16(g), vget_high_f16(h)));
+
+  uint16x8_t shuffle_mask =
+    vld1q_u16(reinterpret_cast<const uint16_t *>(shuffle_masks));
+  abcd04 = vbslq_f16(shuffle_mask, ab0145, vextq_f16(cd0145, cd0145, 6));
+  abcd15 = vbslq_f16(shuffle_mask, vextq_f16(ab0145, ab0145, 2), cd0145);
+
+  efgh04 = vbslq_f16(shuffle_mask, ef0145, vextq_f16(gh0145, gh0145, 6));
+  efgh15 = vbslq_f16(shuffle_mask, vextq_f16(ef0145, ef0145, 2), gh0145);
+
+  abcd26 = vbslq_f16(shuffle_mask, ab2367, vextq_f16(cd2367, cd2367, 6));
+  abcd37 = vbslq_f16(shuffle_mask, vextq_f16(ab2367, ab2367, 2), cd2367);
+
+  efgh26 = vbslq_f16(shuffle_mask, ef2367, vextq_f16(gh2367, gh2367, 6));
+  efgh37 = vbslq_f16(shuffle_mask, vextq_f16(ef2367, ef2367, 2), gh2367);
+
+  a = vcombine_f16(vget_low_f16(abcd04), vget_low_f16(efgh04));
+  b = vcombine_f16(vget_low_f16(abcd15), vget_low_f16(efgh15));
+  c = vcombine_f16(vget_low_f16(abcd26), vget_low_f16(efgh26));
+  d = vcombine_f16(vget_low_f16(abcd37), vget_low_f16(efgh37));
+  e = vcombine_f16(vget_high_f16(abcd04), vget_high_f16(efgh04));
+  f = vcombine_f16(vget_high_f16(abcd15), vget_high_f16(efgh15));
+  g = vcombine_f16(vget_high_f16(abcd26), vget_high_f16(efgh26));
+  h = vcombine_f16(vget_high_f16(abcd37), vget_high_f16(efgh37));
+
+  vst1q_f16(&dst[0 * ld_dst], a);
+  vst1q_f16(&dst[1 * ld_dst], b);
+  vst1q_f16(&dst[2 * ld_dst], c);
+  vst1q_f16(&dst[3 * ld_dst], d);
+  vst1q_f16(&dst[4 * ld_dst], e);
+  vst1q_f16(&dst[5 * ld_dst], f);
+  vst1q_f16(&dst[6 * ld_dst], g);
+  vst1q_f16(&dst[7 * ld_dst], h);
+}
+
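+/**
+ * @brief Transpose a partial M x N fp16 block with M, N <= 8 using 128-bit
+ * NEON vectors. Masked lanes are merged with the existing dst contents so
+ * tail rows and columns are handled correctly.
+ */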
+template <unsigned int M>
+static void transpose_kernel_mxn_neon_256(unsigned int N, const __fp16 *src,
+                                          unsigned int ld_src, __fp16 *dst,
+                                          unsigned int ld_dst) {
+  float16x8_t ZEROS = vmovq_n_f16(0.F);
+  uint16x8_t bitmask_v8 =
+    vld1q_u16(reinterpret_cast<const uint16_t *>(neon_16bit_masks[N]));
+  float16x8_t input[8];
+  unsigned i;
+  for (i = 0; i < M; ++i) {
+    input[i] = vbslq_f16(bitmask_v8, vld1q_f16(&src[i * ld_src]), ZEROS);
+  }
+  for (; i < 8; ++i) {
+    input[i] = ZEROS;
+  }
+  float16x8_t temp[8];
+  for (i = 0; i < (M + 1) / 2; ++i) {
+    temp[2 * i] = vcombine_f16(
+      vzip1_f16(vget_low_f16(input[2 * i]), vget_low_f16(input[2 * i + 1])),
+      vzip1_f16(vget_high_f16(input[2 * i]), vget_high_f16(input[2 * i + 1])));
+    temp[2 * i + 1] = vcombine_f16(
+      vzip2_f16(vget_low_f16(input[2 * i]), vget_low_f16(input[2 * i + 1])),
+      vzip2_f16(vget_high_f16(input[2 * i]), vget_high_f16(input[2 * i + 1])));
+  }
+  for (i = i * 2; i < 8; ++i) {
+    temp[i] = ZEROS;
+  }
+
+  uint16x8_t shuffle_mask =
+    vld1q_u16(reinterpret_cast<const uint16_t *>(shuffle_masks));
+  for (i = 0; i < (M + 3) / 4; ++i) {
+    input[4 * i] = vbslq_f16(shuffle_mask, temp[4 * i],
+                             vextq_f16(temp[4 * i + 2], temp[4 * i + 2], 6));
+    input[4 * i + 1] = vbslq_f16(
+      shuffle_mask, vextq_f16(temp[4 * i], temp[4 * i], 2), temp[4 * i + 2]);
+    input[4 * i + 2] =
+      vbslq_f16(shuffle_mask, temp[4 * i + 1],
+                vextq_f16(temp[4 * i + 3], temp[4 * i + 3], 6));
+    input[4 * i + 3] =
+      vbslq_f16(shuffle_mask, vextq_f16(temp[4 * i + 1], temp[4 * i + 1], 2),
+                temp[4 * i + 3]);
+  }
+  bitmask_v8 =
+    vld1q_u16(reinterpret_cast<const uint16_t *>(neon_16bit_masks[M]));
+  for (i = 0; i < N; ++i) {
+    if (i < 4) {
+      temp[i] =
+        vcombine_f16(vget_low_f16(input[i]), vget_low_f16(input[4 + i]));
+    } else {
+      temp[i] =
+        vcombine_f16(vget_high_f16(input[i - 4]), vget_high_f16(input[i]));
+    }
+    vst1q_f16(&dst[i * ld_dst],
+              vbslq_f16(bitmask_v8, temp[i], vld1q_f16(&dst[i * ld_dst])));
+  }
+}
index 7475229809e16a6018248ce29164a67e0baf1196..aee839caa188cb49650b7c293f61c681bdcd58ca 100644 (file)
  */
 
 #include <arm_neon.h>
-#include "./transpose_utils_neon.h"
-#include "./matrix_transpose_neon.h"
+#include <matrix_transpose_kernels_neon.h>
+#include <matrix_transpose_neon.h>
 
 template <>
-void transpose_fallback(
-    unsigned int M,
-    unsigned int N,
-    const __fp16* src,
-    unsigned int ld_src,
-    __fp16* dst,
-    unsigned int ld_dst) {
-  for (unsigned int j = 0; j < N; j++) {
-    for (unsigned int i = 0; i < M; i++) {
-      dst[i + j * ld_dst] = src[i * ld_src + j];
-    }
-  }
-}
-
-template <>
-void transpose_neon(
-    unsigned int M,
-    unsigned int N,
-    const __fp16* src,
-    unsigned int ld_src,
-    __fp16* dst,
-    unsigned int ld_dst) {
+void transpose_neon(unsigned int M, unsigned int N, const __fp16 *src,
+                    unsigned int ld_src, __fp16 *dst, unsigned int ld_dst) {
   unsigned int ib = 0, jb = 0;
   if (N % 8 > 0 && N % 8 < 4) {
     for (ib = 0; ib + 8 <= M; ib += 8) {
       for (jb = 0; jb + 8 <= N; jb += 8) {
-        transpose_kernel_8x8_neon(
-            &src[ib * ld_src + jb], ld_src, &dst[ib + jb * ld_dst], ld_dst);
+        transpose_kernel_8x8_neon(&src[ib * ld_src + jb], ld_src,
+                                  &dst[ib + jb * ld_dst], ld_dst);
       }
       for (unsigned int i = ib; i < ib + 8; i += 4) {
-        transpose_kernel_mxn_neon_128<4>(
-            N - jb,
-            &src[i * ld_src + jb],
-            ld_src,
-            &dst[i + jb * ld_dst],
-            ld_dst);
+        transpose_kernel_mxn_neon_128<4>(N - jb, &src[i * ld_src + jb], ld_src,
+                                         &dst[i + jb * ld_dst], ld_dst);
       }
     }
   } else if (N % 8 == 4) {
     for (ib = 0; ib + 8 <= M; ib += 8) {
       for (jb = 0; jb + 8 <= N; jb += 8) {
-        transpose_kernel_8x8_neon(
-            &src[ib * ld_src + jb], ld_src, &dst[ib + jb * ld_dst], ld_dst);
+        transpose_kernel_8x8_neon(&src[ib * ld_src + jb], ld_src,
+                                  &dst[ib + jb * ld_dst], ld_dst);
       }
       for (unsigned int i = ib; i < ib + 8; i += 4) {
-        transpose_kernel_4x4_neon(
-            &src[i * ld_src + jb], ld_src, &dst[i + jb * ld_dst], ld_dst);
+        transpose_kernel_4x4_neon(&src[i * ld_src + jb], ld_src,
+                                  &dst[i + jb * ld_dst], ld_dst);
       }
     }
   } else {
     for (ib = 0; ib + 8 <= M; ib += 8) {
       for (jb = 0; jb + 8 <= N; jb += 8) {
-        transpose_kernel_8x8_neon(
-            &src[ib * ld_src + jb], ld_src, &dst[ib + jb * ld_dst], ld_dst);
+        transpose_kernel_8x8_neon(&src[ib * ld_src + jb], ld_src,
+                                  &dst[ib + jb * ld_dst], ld_dst);
       }
       if (jb < N) {
-        transpose_kernel_mxn_neon_256<8>(
-            N - jb,
-            &src[ib * ld_src + jb],
-            ld_src,
-            &dst[ib + jb * ld_dst],
-            ld_dst);
+        transpose_kernel_mxn_neon_256<8>(N - jb, &src[ib * ld_src + jb], ld_src,
+                                         &dst[ib + jb * ld_dst], ld_dst);
       }
     }
   }
   switch (M - ib) {
-    case 1:
-      for (unsigned int j = 0; j < N; ++j) {
-        dst[ib + j * ld_dst] = src[ib * ld_src + j];
-      }
-      break;
-    case 2:
-      for (jb = 0; jb + 4 <= N; jb += 4) {
-        transpose_kernel_mxn_neon_128<2>(
-            4, &src[ib * ld_src + jb], ld_src, &dst[ib + jb * ld_dst], ld_dst);
-      }
-      if (jb < N) {
-        transpose_kernel_mxn_neon_128<2>(
-            N - jb,
-            &src[ib * ld_src + jb],
-            ld_src,
-            &dst[ib + jb * ld_dst],
-            ld_dst);
-      }
-      break;
-    case 3:
-      for (jb = 0; jb + 4 <= N; jb += 4) {
-        transpose_kernel_mxn_neon_128<3>(
-            4, &src[ib * ld_src + jb], ld_src, &dst[ib + jb * ld_dst], ld_dst);
-      }
-      if (jb < N) {
-        transpose_kernel_mxn_neon_128<3>(
-            N - jb,
-            &src[ib * ld_src + jb],
-            ld_src,
-            &dst[ib + jb * ld_dst],
-            ld_dst);
-      }
-      break;
-    case 4:
-      for (jb = 0; jb + 4 <= N; jb += 4) {
-        transpose_kernel_4x4_neon(
-            &src[ib * ld_src + jb], ld_src, &dst[ib + jb * ld_dst], ld_dst);
-      }
-      if (jb < N) {
-        transpose_kernel_mxn_neon_128<4>(
-            N - jb,
-            &src[ib * ld_src + jb],
-            ld_src,
-            &dst[ib + jb * ld_dst],
-            ld_dst);
-      }
-      break;
-    case 5:
-      for (jb = 0; jb + 8 <= N; jb += 8) {
-        transpose_kernel_mxn_neon_256<5>(
-            8, &src[ib * ld_src + jb], ld_src, &dst[ib + jb * ld_dst], ld_dst);
-      }
-      if (jb < N) {
-        transpose_kernel_mxn_neon_256<5>(
-            N - jb,
-            &src[ib * ld_src + jb],
-            ld_src,
-            &dst[ib + jb * ld_dst],
-            ld_dst);
-      }
-      break;
-    case 6:
-      for (jb = 0; jb + 8 <= N; jb += 8) {
-        transpose_kernel_mxn_neon_256<6>(
-            8, &src[ib * ld_src + jb], ld_src, &dst[ib + jb * ld_dst], ld_dst);
-      }
-      if (jb < N) {
-        transpose_kernel_mxn_neon_256<6>(
-            N - jb,
-            &src[ib * ld_src + jb],
-            ld_src,
-            &dst[ib + jb * ld_dst],
-            ld_dst);
-      }
-      break;
-    case 7:
-      for (jb = 0; jb + 8 <= N; jb += 8) {
-        transpose_kernel_mxn_neon_256<7>(
-            8, &src[ib * ld_src + jb], ld_src, &dst[ib + jb * ld_dst], ld_dst);
-      }
-      if (jb < N) {
-        transpose_kernel_mxn_neon_256<7>(
-            N - jb,
-            &src[ib * ld_src + jb],
-            ld_src,
-            &dst[ib + jb * ld_dst],
-            ld_dst);
-      }
-      break;
+  case 1:
+    for (unsigned int j = 0; j < N; ++j) {
+      dst[ib + j * ld_dst] = src[ib * ld_src + j];
+    }
+    break;
+  case 2:
+    for (jb = 0; jb + 4 <= N; jb += 4) {
+      transpose_kernel_mxn_neon_128<2>(4, &src[ib * ld_src + jb], ld_src,
+                                       &dst[ib + jb * ld_dst], ld_dst);
+    }
+    if (jb < N) {
+      transpose_kernel_mxn_neon_128<2>(N - jb, &src[ib * ld_src + jb], ld_src,
+                                       &dst[ib + jb * ld_dst], ld_dst);
+    }
+    break;
+  case 3:
+    for (jb = 0; jb + 4 <= N; jb += 4) {
+      transpose_kernel_mxn_neon_128<3>(4, &src[ib * ld_src + jb], ld_src,
+                                       &dst[ib + jb * ld_dst], ld_dst);
+    }
+    if (jb < N) {
+      transpose_kernel_mxn_neon_128<3>(N - jb, &src[ib * ld_src + jb], ld_src,
+                                       &dst[ib + jb * ld_dst], ld_dst);
+    }
+    break;
+  case 4:
+    for (jb = 0; jb + 4 <= N; jb += 4) {
+      transpose_kernel_4x4_neon(&src[ib * ld_src + jb], ld_src,
+                                &dst[ib + jb * ld_dst], ld_dst);
+    }
+    if (jb < N) {
+      transpose_kernel_mxn_neon_128<4>(N - jb, &src[ib * ld_src + jb], ld_src,
+                                       &dst[ib + jb * ld_dst], ld_dst);
+    }
+    break;
+  case 5:
+    for (jb = 0; jb + 8 <= N; jb += 8) {
+      transpose_kernel_mxn_neon_256<5>(8, &src[ib * ld_src + jb], ld_src,
+                                       &dst[ib + jb * ld_dst], ld_dst);
+    }
+    if (jb < N) {
+      transpose_kernel_mxn_neon_256<5>(N - jb, &src[ib * ld_src + jb], ld_src,
+                                       &dst[ib + jb * ld_dst], ld_dst);
+    }
+    break;
+  case 6:
+    for (jb = 0; jb + 8 <= N; jb += 8) {
+      transpose_kernel_mxn_neon_256<6>(8, &src[ib * ld_src + jb], ld_src,
+                                       &dst[ib + jb * ld_dst], ld_dst);
+    }
+    if (jb < N) {
+      transpose_kernel_mxn_neon_256<6>(N - jb, &src[ib * ld_src + jb], ld_src,
+                                       &dst[ib + jb * ld_dst], ld_dst);
+    }
+    break;
+  case 7:
+    for (jb = 0; jb + 8 <= N; jb += 8) {
+      transpose_kernel_mxn_neon_256<7>(8, &src[ib * ld_src + jb], ld_src,
+                                       &dst[ib + jb * ld_dst], ld_dst);
+    }
+    if (jb < N) {
+      transpose_kernel_mxn_neon_256<7>(N - jb, &src[ib * ld_src + jb], ld_src,
+                                       &dst[ib + jb * ld_dst], ld_dst);
+    }
+    break;
   }
 }
index 1f9d3723a2cb9857d1276f8bf0e328f81db8955b..0b70c4974943e1df78cab46015ed308f60644f4c 100644 (file)
  *
  */
 
-#include <cstdint>
-
-/**
- * @brief Matrix Transpose fallback. Note that this transposes a matrix, which
- * is 2D Tensor
- *
- * @tparam T dataType of the incoming matrix. Implement more Kernels and connect
- * to this function in order to support more datatypes.
- * @param M row length of input matrix
- * @param N col length of input matrix
- * @param src source data of input matrix
- * @param ld_src data offset of input matrix
- * @param dst destination data of this function
- * @param ld_dst data offset of output matrix 
- */
-template <typename T>
-void transpose_fallback(unsigned int M, unsigned int N, const T *src,
-                   unsigned int ld_src, T *dst, unsigned int ld_dst);
-
 /**
  * @brief Matrix Transpose using NEON. Note that this transposes a matrix, which
  * is 2D Tensor
@@ -41,7 +22,7 @@ void transpose_fallback(unsigned int M, unsigned int N, const T *src,
  * @param src source data of input matrix
  * @param ld_src data offset of input matrix
  * @param dst destination data of this function
- * @param ld_dst data offset of output matrix 
+ * @param ld_dst data offset of output matrix
  */
 template <typename T>
 void transpose_neon(unsigned int M, unsigned int N, const T *src,
index d0cd3081c1515789763dabdcf204ae28704f581d..7fc75512b3de6da9bf47b8f30c7bcb46ce1e4ab8 100644 (file)
@@ -3,7 +3,9 @@ matrix_transpose_neon_sources = [
 ]
 
 matrix_transpose_neon_headers = [
-
+    'mask_neon.h',
+    'matrix_transpose_kernels_neon.h',
+    'matrix_transpose_neon.h',
 ]
 
 foreach s : matrix_transpose_neon_sources
diff --git a/nntrainer/tensor/matrix_transpose_neon/transpose_utils_neon.h b/nntrainer/tensor/matrix_transpose_neon/transpose_utils_neon.h
deleted file mode 100644 (file)
index b911fb7..0000000
+++ /dev/null
@@ -1,210 +0,0 @@
-// SPDX-License-Identifier: Apache-2.0
-/**
- * Copyright (C) 2024 Sungsik Kong <ss.kong@samsung.com>
- *
- * @file   transpose_utils_neon.h
- * @date   09 May 2024
- * @see    https://github.com/nnstreamer/nntrainer
- * @author Sungsik Kong <ss.kong@samsung.com>
- * @bug    No known bugs except for NYI items
- * @brief  These are internal util functions for transposing matrix with NEON
- *
- */
-
-#include <arm_neon.h>
-#include <cassert>
-#include <cstdint>
-#include "./mask_neon.h"
-
-#define TRANSPOSE_FP16_4x4(row0, row1, row2, row3)                             \
-  float16x4x2_t row01 = vtrn_f16(row0, row1);                                  \
-  float16x4x2_t row23 = vtrn_f16(row2, row3);                                  \
-  row0 = vcvt_f16_f32(vcombine_f32(vget_low_f32(vcvt_f32_f16(row01.val[0])),   \
-                                   vget_low_f32(vcvt_f32_f16(row23.val[0])))); \
-  row1 = vcvt_f16_f32(vcombine_f32(vget_low_f32(vcvt_f32_f16(row01.val[1])),   \
-                                   vget_low_f32(vcvt_f32_f16(row23.val[1])))); \
-  row2 =                                                                       \
-    vcvt_f16_f32(vcombine_f32(vget_high_f32(vcvt_f32_f16(row01.val[0])),       \
-                              vget_high_f32(vcvt_f32_f16(row23.val[0]))));     \
-  row3 =                                                                       \
-    vcvt_f16_f32(vcombine_f32(vget_high_f32(vcvt_f32_f16(row01.val[1])),       \
-                              vget_high_f32(vcvt_f32_f16(row23.val[1]))));
-
-
-static inline void transpose_kernel_4x4_neon(const __fp16 *src,
-                                             unsigned int ld_src, __fp16 *dst,
-                                             unsigned int ld_dst) {
-  float16x4_t a = vld1_f16(&src[0 * ld_src]);
-  float16x4_t b = vld1_f16(&src[1 * ld_src]);
-  float16x4_t c = vld1_f16(&src[2 * ld_src]);
-  float16x4_t d = vld1_f16(&src[3 * ld_src]);
-
-  TRANSPOSE_FP16_4x4(a, b, c, d);
-
-  vst1_f16(&dst[0 * ld_dst], a);
-  vst1_f16(&dst[1 * ld_dst], b);
-  vst1_f16(&dst[2 * ld_dst], c);
-  vst1_f16(&dst[3 * ld_dst], d);
-}
-
-template <unsigned int M>
-static void transpose_kernel_mxn_neon_128(unsigned int N, const __fp16 *src,
-                                          unsigned int ld_src, __fp16 *dst,
-                                          unsigned int ld_dst) {
-
-
-  uint16x4_t bitmask_v8 = vld1_u16(reinterpret_cast<const uint16_t *>(masks[N]));
-  float16x4_t input[4];
-  float16x4_t ZEROS = vmov_n_f16(0.F);
-
-  unsigned i;
-  for (i = 0; i < M; ++i) {
-    input[i] = vbsl_f16(bitmask_v8, vld1_f16(&src[i * ld_src]), ZEROS);
-  }
-  for (; i < 4; ++i) {
-    input[i] = vmov_n_f16(0.F);
-  }
-
-  float16x4_t temp[4];
-  for (i = 0; i < (M + 1) / 2; ++i) {
-    temp[2 * i] = vzip1_f16(input[2 * i], input[2 * i + 1]);
-    temp[2 * i + 1] = vzip2_f16(input[2 * i], input[2 * i + 1]);
-  }
-  for (i = i * 2; i < 4; ++i) {
-    temp[i] = vmov_n_f16(0.F);
-  }
-
-  bitmask_v8 = vld1_u16(reinterpret_cast<const uint16_t *>(masks[M]));
-  for (i = 0; i < N; ++i) {
-    if (i % 2 == 0) {
-      input[i] =
-        vcvt_f16_f32(vcombine_f32(vget_low_f32(vcvt_f32_f16(temp[i / 2])),
-                                  vget_low_f32(vcvt_f32_f16(temp[2 + i / 2]))));
-    } else {
-      input[i] =
-        vcvt_f16_f32(vcombine_f32(vget_high_f32(vcvt_f32_f16(temp[i / 2])),
-                                  vget_high_f32(vcvt_f32_f16(temp[2 + i / 2]))));
-    }
-    vst1_f16(&dst[i * ld_dst], vbsl_f16(bitmask_v8, input[i], ZEROS));
-  }
-
-}
-
-static inline void transpose_kernel_8x8_neon(const __fp16 *src,
-                                             unsigned int ld_src, __fp16 *dst,
-                                             unsigned int ld_dst) {
-  float16x8_t a = vld1q_f16(&src[0 * ld_src]);
-  float16x8_t b = vld1q_f16(&src[1 * ld_src]);
-  float16x8_t c = vld1q_f16(&src[2 * ld_src]);
-  float16x8_t d = vld1q_f16(&src[3 * ld_src]);
-  float16x8_t e = vld1q_f16(&src[4 * ld_src]);
-  float16x8_t f = vld1q_f16(&src[5 * ld_src]);
-  float16x8_t g = vld1q_f16(&src[6 * ld_src]);
-  float16x8_t h = vld1q_f16(&src[7 * ld_src]);
-
-  float16x8_t ab0145, ab2367, cd0145, cd2367, ef0145, ef2367, gh0145, gh2367;
-  float16x8_t abcd04, abcd15, efgh04, efgh15, abcd26, abcd37, efgh26, efgh37;
-
-  ab0145 = vcombine_f16(vzip1_f16(vget_low_f16(a), vget_low_f16(b)),
-                        vzip1_f16(vget_high_f16(a), vget_high_f16(b)));
-  ab2367 = vcombine_f16(vzip2_f16(vget_low_f16(a), vget_low_f16(b)),
-                        vzip2_f16(vget_high_f16(a), vget_high_f16(b)));
-  cd0145 = vcombine_f16(vzip1_f16(vget_low_f16(c), vget_low_f16(d)),
-                        vzip1_f16(vget_high_f16(c), vget_high_f16(d)));
-  cd2367 = vcombine_f16(vzip2_f16(vget_low_f16(c), vget_low_f16(d)),
-                        vzip2_f16(vget_high_f16(c), vget_high_f16(d)));
-  ef0145 = vcombine_f16(vzip1_f16(vget_low_f16(e), vget_low_f16(f)),
-                        vzip1_f16(vget_high_f16(e), vget_high_f16(f)));
-  ef2367 = vcombine_f16(vzip2_f16(vget_low_f16(e), vget_low_f16(f)),
-                        vzip2_f16(vget_high_f16(e), vget_high_f16(f)));
-  gh0145 = vcombine_f16(vzip1_f16(vget_low_f16(g), vget_low_f16(h)),
-                        vzip1_f16(vget_high_f16(g), vget_high_f16(h)));
-  gh2367 = vcombine_f16(vzip2_f16(vget_low_f16(g), vget_low_f16(h)),
-                        vzip2_f16(vget_high_f16(g), vget_high_f16(h)));
-
-  uint16x8_t shuffle_mask =
-    vld1q_u16(reinterpret_cast<const uint16_t *>(shuffle_masks));
-  abcd04 = vbslq_f16(shuffle_mask, ab0145, vextq_f16(cd0145, cd0145, 6));
-  abcd15 = vbslq_f16(shuffle_mask, vextq_f16(ab0145, ab0145, 2), cd0145);
-  
-  efgh04 = vbslq_f16(shuffle_mask, ef0145, vextq_f16(gh0145, gh0145, 6));
-  efgh15 = vbslq_f16(shuffle_mask, vextq_f16(ef0145, ef0145, 2), gh0145);
-
-  abcd26 = vbslq_f16(shuffle_mask, ab2367, vextq_f16(cd2367, cd2367, 6));
-  abcd37 = vbslq_f16(shuffle_mask, vextq_f16(ab2367, ab2367, 2), cd2367);
-
-  efgh26 = vbslq_f16(shuffle_mask, ef2367, vextq_f16(gh2367, gh2367, 6));
-  efgh37 = vbslq_f16(shuffle_mask, vextq_f16(ef2367, ef2367, 2), gh2367);
-
-  a = vcombine_f16(vget_low_f16(abcd04), vget_low_f16(efgh04));
-  b = vcombine_f16(vget_low_f16(abcd15), vget_low_f16(efgh15));
-  c = vcombine_f16(vget_low_f16(abcd26), vget_low_f16(efgh26));
-  d = vcombine_f16(vget_low_f16(abcd37), vget_low_f16(efgh37));
-  e = vcombine_f16(vget_high_f16(abcd04), vget_high_f16(efgh04));
-  f = vcombine_f16(vget_high_f16(abcd15), vget_high_f16(efgh15));
-  g = vcombine_f16(vget_high_f16(abcd26), vget_high_f16(efgh26));
-  h = vcombine_f16(vget_high_f16(abcd37), vget_high_f16(efgh37));
-
-  vst1q_f16(&dst[0 * ld_dst], a);
-  vst1q_f16(&dst[1 * ld_dst], b);
-  vst1q_f16(&dst[2 * ld_dst], c);
-  vst1q_f16(&dst[3 * ld_dst], d);
-  vst1q_f16(&dst[4 * ld_dst], e);
-  vst1q_f16(&dst[5 * ld_dst], f);
-  vst1q_f16(&dst[6 * ld_dst], g);
-  vst1q_f16(&dst[7 * ld_dst], h);
-}
-
-template <unsigned int M>
-static void transpose_kernel_mxn_neon_256(unsigned int N, const __fp16 *src,
-                                          unsigned int ld_src, __fp16 *dst,
-                                          unsigned int ld_dst) {
-  float16x8_t ZEROS = vmovq_n_f16(0.F);
-  uint16x8_t bitmask_v8 =
-    vld1q_u16(reinterpret_cast<const uint16_t *>(neon_16bit_masks[N]));
-  float16x8_t input[8];
-  unsigned i;
-  for (i = 0; i < M; ++i) {
-    input[i] = vbslq_f16(bitmask_v8, vld1q_f16(&src[i * ld_src]), ZEROS);
-  }
-  for (; i < 8; ++i) {
-    input[i] = ZEROS;
-  }
-  float16x8_t temp[8];
-  for (i = 0; i < (M + 1) / 2; ++i) {
-    temp[2 * i] = vcombine_f16(
-      vzip1_f16(vget_low_f16(input[2 * i]), vget_low_f16(input[2 * i + 1])),
-      vzip1_f16(vget_high_f16(input[2 * i]), vget_high_f16(input[2 * i + 1])));
-    temp[2 * i + 1] = vcombine_f16(
-      vzip2_f16(vget_low_f16(input[2 * i]), vget_low_f16(input[2 * i + 1])),
-      vzip2_f16(vget_high_f16(input[2 * i]), vget_high_f16(input[2 * i + 1])));
-  }
-  for (i = i * 2; i < 8; ++i) {
-    temp[i] = ZEROS;
-  }
-
-  uint16x8_t shuffle_mask =
-    vld1q_u16(reinterpret_cast<const uint16_t *>(shuffle_masks));
-  for (i = 0; i < (M + 3) / 4; ++i) {
-    input[4 * i] = vbslq_f16(shuffle_mask, temp[4 * i],
-                             vextq_f16(temp[4 * i + 2], temp[4 * i + 2], 6));
-    input[4 * i + 1] = vbslq_f16(
-      shuffle_mask, vextq_f16(temp[4 * i], temp[4 * i], 2), temp[4 * i + 2]);
-    input[4 * i + 2] =
-      vbslq_f16(shuffle_mask, temp[4 * i + 1],
-                vextq_f16(temp[4 * i + 3], temp[4 * i + 3], 6));
-    input[4 * i + 3] =
-      vbslq_f16(shuffle_mask, vextq_f16(temp[4 * i + 1], temp[4 * i + 1], 2),
-                temp[4 * i + 3]);
-  }
-  bitmask_v8 = vld1q_u16(
-    reinterpret_cast<const uint16_t *>(neon_16bit_masks[M]));
-  for (i = 0; i < N; ++i) {
-    if (i < 4) {
-      temp[i] = vcombine_f16(vget_low_f16(input[i]), vget_low_f16(input[4 + i]));
-    } else {
-      temp[i] = vcombine_f16(vget_high_f16(input[i - 4]), vget_high_f16(input[i])); 
-    }
-    vst1q_f16(&dst[i * ld_dst], vbslq_f16(bitmask_v8, temp[i], ZEROS));
-  }
-}