[ matrix_transpose_neon ] Implement NEON-accelerated half-precision matrix transpose
author     skykongkong8 <ss.kong@samsung.com>
           Thu, 9 May 2024 23:51:13 +0000 (08:51 +0900)
committer  Jijoong Moon <jijoong.moon@samsung.com>
           Tue, 4 Jun 2024 09:55:20 +0000 (18:55 +0900)
- Previously, matrix transpose relied on a naive for-loop implementation.
- With SIMD instructions, there is room for latency optimization.
- Note that the current implementation only supports half-precision matrices;
  a caller-side usage sketch is given below.
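
A minimal caller-side sketch (illustrative only, not part of this commit; the
include path is assumed, and an aarch64 toolchain with __fp16 support is
required):

#include <vector>
#include "matrix_transpose_neon.h" // include path assumed

void transpose_example() {
  const unsigned int M = 16, N = 24; // multiples of 8: only 8x8 kernels run
  std::vector<__fp16> src(M * N), dst(N * M);
  // ... fill src (row-major, M rows x N columns) ...
  // dst becomes the N x M transpose; ld_src / ld_dst are the leading
  // dimensions (row strides) of the source and destination buffers.
  transpose_neon(M, N, src.data(), /*ld_src=*/N, dst.data(), /*ld_dst=*/M);
}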

**Self evaluation:**
1. Build test:     [X]Passed [ ]Failed [ ]Skipped
2. Run test:     [X]Passed [ ]Failed [ ]Skipped

Signed-off-by: skykongkong8 <ss.kong@samsung.com>
nntrainer/tensor/matrix_transpose_neon/mask_neon.h [new file with mode: 0644]
nntrainer/tensor/matrix_transpose_neon/matrix_transpose_neon.cpp [new file with mode: 0644]
nntrainer/tensor/matrix_transpose_neon/matrix_transpose_neon.h [new file with mode: 0644]
nntrainer/tensor/matrix_transpose_neon/meson.build [new file with mode: 0644]
nntrainer/tensor/matrix_transpose_neon/transpose_utils_neon.h [new file with mode: 0644]

diff --git a/nntrainer/tensor/matrix_transpose_neon/mask_neon.h b/nntrainer/tensor/matrix_transpose_neon/mask_neon.h
new file mode 100644 (file)
index 0000000..35ca5b6
--- /dev/null
@@ -0,0 +1,40 @@
+// SPDX-License-Identifier: Apache-2.0
+/**
+ * Copyright (C) 2024 Sungsik Kong <ss.kong@samsung.com>
+ *
+ * @file   mask_neon.h
+ * @date   09 May 2024
+ * @see    https://github.com/nnstreamer/nntrainer
+ * @author Sungsik Kong <ss.kong@samsung.com>
+ * @bug    No known bugs except for NYI items
+ * @brief  Mask-shaped filters used to select specific lanes with SIMD
+ *
+ */
+
+#include <cstdint>
+
+// clang-format off
+alignas(16) static const int16_t neon_16bit_masks[9][8] = {
+  {  0,  0,  0,  0,  0,  0,  0,  0,  },
+  { -1,  0,  0,  0,  0,  0,  0,  0,  },
+  { -1, -1,  0,  0,  0,  0,  0,  0,  },
+  { -1, -1, -1,  0,  0,  0,  0,  0,  },
+  { -1, -1, -1, -1,  0,  0,  0,  0,  },
+  { -1, -1, -1, -1, -1,  0,  0,  0,  },
+  { -1, -1, -1, -1, -1, -1,  0,  0,  },
+  { -1, -1, -1, -1, -1, -1, -1,  0,  },
+  { -1, -1, -1, -1, -1, -1, -1, -1,  },
+};
+
+alignas(16) static const int16_t masks[5][4] = {
+  {  0,  0,  0,  0, },
+  { -1,  0,  0,  0, },
+  { -1, -1,  0,  0, },
+  { -1, -1, -1,  0, },
+  { -1, -1, -1, -1, },
+};
+
+alignas(16) static const int16_t shuffle_masks[8] = {
+    -1, -1, 0, 0, -1, -1,  0,  0, 
+};
+// clang-format on
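
For illustration, a hypothetical helper (not part of this patch) showing how
these masks gate partial rows; it assumes an fp16-capable aarch64 toolchain and
that mask_neon.h is on the include path:

#include <arm_neon.h>
#include "mask_neon.h" // include path assumed

// Load 8 fp16 lanes but keep only the first `n` (n <= 8), zeroing the rest:
// all-ones mask lanes take the loaded value, zero lanes take 0. The full
// 8-lane load assumes 8 readable elements at `row`, which is how the kernels
// in this patch use these masks.
static inline float16x8_t load_first_n(const __fp16 *row, unsigned int n) {
  uint16x8_t mask =
    vld1q_u16(reinterpret_cast<const uint16_t *>(neon_16bit_masks[n]));
  return vbslq_f16(mask, vld1q_f16(row), vmovq_n_f16(0.F));
}
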
diff --git a/nntrainer/tensor/matrix_transpose_neon/matrix_transpose_neon.cpp b/nntrainer/tensor/matrix_transpose_neon/matrix_transpose_neon.cpp
new file mode 100644 (file)
index 0000000..7475229
--- /dev/null
@@ -0,0 +1,175 @@
+// SPDX-License-Identifier: Apache-2.0
+/**
+ * Copyright (C) 2024 Sungsik Kong <ss.kong@samsung.com>
+ *
+ * @file   matrix_transpose_neon.cpp
+ * @date   09 May 2024
+ * @see    https://github.com/nnstreamer/nntrainer
+ * @author Sungsik Kong <ss.kong@samsung.com>
+ * @bug    No known bugs except for NYI items
+ * @brief  This is source file for matrix transpose using NEON
+ *
+ */
+
+#include <arm_neon.h>
+#include "./transpose_utils_neon.h"
+#include "./matrix_transpose_neon.h"
+
+template <>
+void transpose_fallback(
+    unsigned int M,
+    unsigned int N,
+    const __fp16* src,
+    unsigned int ld_src,
+    __fp16* dst,
+    unsigned int ld_dst) {
+  for (unsigned int j = 0; j < N; j++) {
+    for (unsigned int i = 0; i < M; i++) {
+      dst[i + j * ld_dst] = src[i * ld_src + j];
+    }
+  }
+}
+
+template <>
+void transpose_neon(
+    unsigned int M,
+    unsigned int N,
+    const __fp16* src,
+    unsigned int ld_src,
+    __fp16* dst,
+    unsigned int ld_dst) {
+  unsigned int ib = 0, jb = 0;
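+  // Complete 8x8 tiles are handled by transpose_kernel_8x8_neon; the ragged
+  // right edge (the N % 8 leftover columns) is handled by 4-wide kernels when
+  // the remainder is at most 4, and by masked 8-wide kernels otherwise.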
+  if (N % 8 > 0 && N % 8 < 4) {
+    for (ib = 0; ib + 8 <= M; ib += 8) {
+      for (jb = 0; jb + 8 <= N; jb += 8) {
+        transpose_kernel_8x8_neon(
+            &src[ib * ld_src + jb], ld_src, &dst[ib + jb * ld_dst], ld_dst);
+      }
+      for (unsigned int i = ib; i < ib + 8; i += 4) {
+        transpose_kernel_mxn_neon_128<4>(
+            N - jb,
+            &src[i * ld_src + jb],
+            ld_src,
+            &dst[i + jb * ld_dst],
+            ld_dst);
+      }
+    }
+  } else if (N % 8 == 4) {
+    for (ib = 0; ib + 8 <= M; ib += 8) {
+      for (jb = 0; jb + 8 <= N; jb += 8) {
+        transpose_kernel_8x8_neon(
+            &src[ib * ld_src + jb], ld_src, &dst[ib + jb * ld_dst], ld_dst);
+      }
+      for (unsigned int i = ib; i < ib + 8; i += 4) {
+        transpose_kernel_4x4_neon(
+            &src[i * ld_src + jb], ld_src, &dst[i + jb * ld_dst], ld_dst);
+      }
+    }
+  } else {
+    for (ib = 0; ib + 8 <= M; ib += 8) {
+      for (jb = 0; jb + 8 <= N; jb += 8) {
+        transpose_kernel_8x8_neon(
+            &src[ib * ld_src + jb], ld_src, &dst[ib + jb * ld_dst], ld_dst);
+      }
+      if (jb < N) {
+        transpose_kernel_mxn_neon_256<8>(
+            N - jb,
+            &src[ib * ld_src + jb],
+            ld_src,
+            &dst[ib + jb * ld_dst],
+            ld_dst);
+      }
+    }
+  }
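+  // Transpose the remaining M - ib (< 8) source rows with masked kernels.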
+  switch (M - ib) {
+    case 1:
+      for (unsigned int j = 0; j < N; ++j) {
+        dst[ib + j * ld_dst] = src[ib * ld_src + j];
+      }
+      break;
+    case 2:
+      for (jb = 0; jb + 4 <= N; jb += 4) {
+        transpose_kernel_mxn_neon_128<2>(
+            4, &src[ib * ld_src + jb], ld_src, &dst[ib + jb * ld_dst], ld_dst);
+      }
+      if (jb < N) {
+        transpose_kernel_mxn_neon_128<2>(
+            N - jb,
+            &src[ib * ld_src + jb],
+            ld_src,
+            &dst[ib + jb * ld_dst],
+            ld_dst);
+      }
+      break;
+    case 3:
+      for (jb = 0; jb + 4 <= N; jb += 4) {
+        transpose_kernel_mxn_neon_128<3>(
+            4, &src[ib * ld_src + jb], ld_src, &dst[ib + jb * ld_dst], ld_dst);
+      }
+      if (jb < N) {
+        transpose_kernel_mxn_neon_128<3>(
+            N - jb,
+            &src[ib * ld_src + jb],
+            ld_src,
+            &dst[ib + jb * ld_dst],
+            ld_dst);
+      }
+      break;
+    case 4:
+      for (jb = 0; jb + 4 <= N; jb += 4) {
+        transpose_kernel_4x4_neon(
+            &src[ib * ld_src + jb], ld_src, &dst[ib + jb * ld_dst], ld_dst);
+      }
+      if (jb < N) {
+        transpose_kernel_mxn_neon_128<4>(
+            N - jb,
+            &src[ib * ld_src + jb],
+            ld_src,
+            &dst[ib + jb * ld_dst],
+            ld_dst);
+      }
+      break;
+    case 5:
+      for (jb = 0; jb + 8 <= N; jb += 8) {
+        transpose_kernel_mxn_neon_256<5>(
+            8, &src[ib * ld_src + jb], ld_src, &dst[ib + jb * ld_dst], ld_dst);
+      }
+      if (jb < N) {
+        transpose_kernel_mxn_neon_256<5>(
+            N - jb,
+            &src[ib * ld_src + jb],
+            ld_src,
+            &dst[ib + jb * ld_dst],
+            ld_dst);
+      }
+      break;
+    case 6:
+      for (jb = 0; jb + 8 <= N; jb += 8) {
+        transpose_kernel_mxn_neon_256<6>(
+            8, &src[ib * ld_src + jb], ld_src, &dst[ib + jb * ld_dst], ld_dst);
+      }
+      if (jb < N) {
+        transpose_kernel_mxn_neon_256<6>(
+            N - jb,
+            &src[ib * ld_src + jb],
+            ld_src,
+            &dst[ib + jb * ld_dst],
+            ld_dst);
+      }
+      break;
+    case 7:
+      for (jb = 0; jb + 8 <= N; jb += 8) {
+        transpose_kernel_mxn_neon_256<7>(
+            8, &src[ib * ld_src + jb], ld_src, &dst[ib + jb * ld_dst], ld_dst);
+      }
+      if (jb < N) {
+        transpose_kernel_mxn_neon_256<7>(
+            N - jb,
+            &src[ib * ld_src + jb],
+            ld_src,
+            &dst[ib + jb * ld_dst],
+            ld_dst);
+      }
+      break;
+  }
+}
diff --git a/nntrainer/tensor/matrix_transpose_neon/matrix_transpose_neon.h b/nntrainer/tensor/matrix_transpose_neon/matrix_transpose_neon.h
new file mode 100644 (file)
index 0000000..1f9d372
--- /dev/null
@@ -0,0 +1,48 @@
+// SPDX-License-Identifier: Apache-2.0
+/**
+ * Copyright (C) 2024 Sungsik Kong <ss.kong@samsung.com>
+ *
+ * @file   matrix_transpose_neon.h
+ * @date   09 May 2024
+ * @see    https://github.com/nnstreamer/nntrainer
+ * @author Sungsik Kong <ss.kong@samsung.com>
+ * @bug    No known bugs except for NYI items
+ * @brief  This is header file for matrix transpose using NEON
+ *
+ */
+
+#include <cstdint>
+
+/**
+ * @brief Matrix transpose fallback. Note that this transposes a matrix,
+ * i.e., a 2D tensor.
+ *
+ * @tparam T data type of the incoming matrix. Implement more kernels and
+ * connect them to this function in order to support more data types.
+ * @param M number of rows of the input matrix
+ * @param N number of columns of the input matrix
+ * @param src source data of the input matrix
+ * @param ld_src leading dimension (row stride) of the input matrix
+ * @param dst destination buffer for the transposed matrix
+ * @param ld_dst leading dimension (row stride) of the output matrix
+ */
+template <typename T>
+void transpose_fallback(unsigned int M, unsigned int N, const T *src,
+                        unsigned int ld_src, T *dst, unsigned int ld_dst);
+
+/**
+ * @brief Matrix transpose using NEON. Note that this transposes a matrix,
+ * i.e., a 2D tensor.
+ *
+ * @tparam T data type of the incoming matrix. Implement more kernels and
+ * connect them to this function in order to support more data types.
+ * @param M number of rows of the input matrix
+ * @param N number of columns of the input matrix
+ * @param src source data of the input matrix
+ * @param ld_src leading dimension (row stride) of the input matrix
+ * @param dst destination buffer for the transposed matrix
+ * @param ld_dst leading dimension (row stride) of the output matrix
+ */
+template <typename T>
+void transpose_neon(unsigned int M, unsigned int N, const T *src,
+                    unsigned int ld_src, T *dst, unsigned int ld_dst);
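
The templates above are declared generically, while only the __fp16 explicit
specializations are defined in matrix_transpose_neon.cpp. As a hypothetical
sketch (not part of this patch), another data type could be wired in by adding
a specialization next to them, e.g. a plain-loop float32 version until
dedicated NEON kernels exist:

template <>
void transpose_neon(unsigned int M, unsigned int N, const float *src,
                    unsigned int ld_src, float *dst, unsigned int ld_dst) {
  // Placeholder body: naive transpose until float32 NEON kernels are added.
  for (unsigned int j = 0; j < N; j++) {
    for (unsigned int i = 0; i < M; i++) {
      dst[i + j * ld_dst] = src[i * ld_src + j];
    }
  }
}
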
diff --git a/nntrainer/tensor/matrix_transpose_neon/meson.build b/nntrainer/tensor/matrix_transpose_neon/meson.build
new file mode 100644 (file)
index 0000000..d0cd308
--- /dev/null
@@ -0,0 +1,15 @@
+matrix_transpose_neon_sources = [
+    'matrix_transpose_neon.cpp',
+]
+
+matrix_transpose_neon_headers = [
+
+]
+
+foreach s : matrix_transpose_neon_sources
+  nntrainer_sources += meson.current_source_dir() / s
+endforeach
+
+foreach h : matrix_transpose_neon_headers
+  nntrainer_headers += meson.current_source_dir() / h
+endforeach
diff --git a/nntrainer/tensor/matrix_transpose_neon/transpose_utils_neon.h b/nntrainer/tensor/matrix_transpose_neon/transpose_utils_neon.h
new file mode 100644 (file)
index 0000000..b911fb7
--- /dev/null
@@ -0,0 +1,210 @@
+// SPDX-License-Identifier: Apache-2.0
+/**
+ * Copyright (C) 2024 Sungsik Kong <ss.kong@samsung.com>
+ *
+ * @file   transpose_utils_neon.h
+ * @date   09 May 2024
+ * @see    https://github.com/nnstreamer/nntrainer
+ * @author Sungsik Kong <ss.kong@samsung.com>
+ * @bug    No known bugs except for NYI items
+ * @brief  These are internal util functions for transposing matrix with NEON
+ *
+ */
+
+#include <arm_neon.h>
+#include <cassert>
+#include <cstdint>
+#include "./mask_neon.h"
+
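+/**
+ * @brief In-register transpose of a 4x4 fp16 tile held in row0..row3; the
+ * four float16x4_t rows are overwritten with the transposed columns.
+ */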
+#define TRANSPOSE_FP16_4x4(row0, row1, row2, row3)                             \
+  float16x4x2_t row01 = vtrn_f16(row0, row1);                                  \
+  float16x4x2_t row23 = vtrn_f16(row2, row3);                                  \
+  row0 = vcvt_f16_f32(vcombine_f32(vget_low_f32(vcvt_f32_f16(row01.val[0])),   \
+                                   vget_low_f32(vcvt_f32_f16(row23.val[0])))); \
+  row1 = vcvt_f16_f32(vcombine_f32(vget_low_f32(vcvt_f32_f16(row01.val[1])),   \
+                                   vget_low_f32(vcvt_f32_f16(row23.val[1])))); \
+  row2 =                                                                       \
+    vcvt_f16_f32(vcombine_f32(vget_high_f32(vcvt_f32_f16(row01.val[0])),       \
+                              vget_high_f32(vcvt_f32_f16(row23.val[0]))));     \
+  row3 =                                                                       \
+    vcvt_f16_f32(vcombine_f32(vget_high_f32(vcvt_f32_f16(row01.val[1])),       \
+                              vget_high_f32(vcvt_f32_f16(row23.val[1]))));
+
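+/**
+ * @brief Transpose one fully-populated 4x4 fp16 tile from src to dst.
+ */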
+static inline void transpose_kernel_4x4_neon(const __fp16 *src,
+                                             unsigned int ld_src, __fp16 *dst,
+                                             unsigned int ld_dst) {
+  float16x4_t a = vld1_f16(&src[0 * ld_src]);
+  float16x4_t b = vld1_f16(&src[1 * ld_src]);
+  float16x4_t c = vld1_f16(&src[2 * ld_src]);
+  float16x4_t d = vld1_f16(&src[3 * ld_src]);
+
+  TRANSPOSE_FP16_4x4(a, b, c, d);
+
+  vst1_f16(&dst[0 * ld_dst], a);
+  vst1_f16(&dst[1 * ld_dst], b);
+  vst1_f16(&dst[2 * ld_dst], c);
+  vst1_f16(&dst[3 * ld_dst], d);
+}
+
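+/**
+ * @brief Transpose a partial tile of M rows x N columns (M, N <= 4), using
+ * the lane masks from mask_neon.h to zero out lanes past the tile boundary.
+ */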
+template <unsigned int M>
+static void transpose_kernel_mxn_neon_128(unsigned int N, const __fp16 *src,
+                                          unsigned int ld_src, __fp16 *dst,
+                                          unsigned int ld_dst) {
+  uint16x4_t bitmask_v8 =
+    vld1_u16(reinterpret_cast<const uint16_t *>(masks[N]));
+  float16x4_t input[4];
+  float16x4_t ZEROS = vmov_n_f16(0.F);
+
+  unsigned i;
+  for (i = 0; i < M; ++i) {
+    input[i] = vbsl_f16(bitmask_v8, vld1_f16(&src[i * ld_src]), ZEROS);
+  }
+  for (; i < 4; ++i) {
+    input[i] = vmov_n_f16(0.F);
+  }
+
+  float16x4_t temp[4];
+  for (i = 0; i < (M + 1) / 2; ++i) {
+    temp[2 * i] = vzip1_f16(input[2 * i], input[2 * i + 1]);
+    temp[2 * i + 1] = vzip2_f16(input[2 * i], input[2 * i + 1]);
+  }
+  for (i = i * 2; i < 4; ++i) {
+    temp[i] = vmov_n_f16(0.F);
+  }
+
+  bitmask_v8 = vld1_u16(reinterpret_cast<const uint16_t *>(masks[M]));
+  for (i = 0; i < N; ++i) {
+    if (i % 2 == 0) {
+      input[i] =
+        vcvt_f16_f32(vcombine_f32(vget_low_f32(vcvt_f32_f16(temp[i / 2])),
+                                  vget_low_f32(vcvt_f32_f16(temp[2 + i / 2]))));
+    } else {
+      input[i] =
+        vcvt_f16_f32(vcombine_f32(vget_high_f32(vcvt_f32_f16(temp[i / 2])),
+                                  vget_high_f32(vcvt_f32_f16(temp[2 + i / 2]))));
+    }
+    vst1_f16(&dst[i * ld_dst], vbsl_f16(bitmask_v8, input[i], ZEROS));
+  }
+}
+
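+/**
+ * @brief Transpose one fully-populated 8x8 fp16 tile from src to dst using
+ * zip, extract and bit-select operations.
+ */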
+static inline void transpose_kernel_8x8_neon(const __fp16 *src,
+                                             unsigned int ld_src, __fp16 *dst,
+                                             unsigned int ld_dst) {
+  float16x8_t a = vld1q_f16(&src[0 * ld_src]);
+  float16x8_t b = vld1q_f16(&src[1 * ld_src]);
+  float16x8_t c = vld1q_f16(&src[2 * ld_src]);
+  float16x8_t d = vld1q_f16(&src[3 * ld_src]);
+  float16x8_t e = vld1q_f16(&src[4 * ld_src]);
+  float16x8_t f = vld1q_f16(&src[5 * ld_src]);
+  float16x8_t g = vld1q_f16(&src[6 * ld_src]);
+  float16x8_t h = vld1q_f16(&src[7 * ld_src]);
+
+  float16x8_t ab0145, ab2367, cd0145, cd2367, ef0145, ef2367, gh0145, gh2367;
+  float16x8_t abcd04, abcd15, efgh04, efgh15, abcd26, abcd37, efgh26, efgh37;
+
+  ab0145 = vcombine_f16(vzip1_f16(vget_low_f16(a), vget_low_f16(b)),
+                        vzip1_f16(vget_high_f16(a), vget_high_f16(b)));
+  ab2367 = vcombine_f16(vzip2_f16(vget_low_f16(a), vget_low_f16(b)),
+                        vzip2_f16(vget_high_f16(a), vget_high_f16(b)));
+  cd0145 = vcombine_f16(vzip1_f16(vget_low_f16(c), vget_low_f16(d)),
+                        vzip1_f16(vget_high_f16(c), vget_high_f16(d)));
+  cd2367 = vcombine_f16(vzip2_f16(vget_low_f16(c), vget_low_f16(d)),
+                        vzip2_f16(vget_high_f16(c), vget_high_f16(d)));
+  ef0145 = vcombine_f16(vzip1_f16(vget_low_f16(e), vget_low_f16(f)),
+                        vzip1_f16(vget_high_f16(e), vget_high_f16(f)));
+  ef2367 = vcombine_f16(vzip2_f16(vget_low_f16(e), vget_low_f16(f)),
+                        vzip2_f16(vget_high_f16(e), vget_high_f16(f)));
+  gh0145 = vcombine_f16(vzip1_f16(vget_low_f16(g), vget_low_f16(h)),
+                        vzip1_f16(vget_high_f16(g), vget_high_f16(h)));
+  gh2367 = vcombine_f16(vzip2_f16(vget_low_f16(g), vget_low_f16(h)),
+                        vzip2_f16(vget_high_f16(g), vget_high_f16(h)));
+
+  uint16x8_t shuffle_mask =
+    vld1q_u16(reinterpret_cast<const uint16_t *>(shuffle_masks));
+  abcd04 = vbslq_f16(shuffle_mask, ab0145, vextq_f16(cd0145, cd0145, 6));
+  abcd15 = vbslq_f16(shuffle_mask, vextq_f16(ab0145, ab0145, 2), cd0145);
+
+  efgh04 = vbslq_f16(shuffle_mask, ef0145, vextq_f16(gh0145, gh0145, 6));
+  efgh15 = vbslq_f16(shuffle_mask, vextq_f16(ef0145, ef0145, 2), gh0145);
+
+  abcd26 = vbslq_f16(shuffle_mask, ab2367, vextq_f16(cd2367, cd2367, 6));
+  abcd37 = vbslq_f16(shuffle_mask, vextq_f16(ab2367, ab2367, 2), cd2367);
+
+  efgh26 = vbslq_f16(shuffle_mask, ef2367, vextq_f16(gh2367, gh2367, 6));
+  efgh37 = vbslq_f16(shuffle_mask, vextq_f16(ef2367, ef2367, 2), gh2367);
+
+  a = vcombine_f16(vget_low_f16(abcd04), vget_low_f16(efgh04));
+  b = vcombine_f16(vget_low_f16(abcd15), vget_low_f16(efgh15));
+  c = vcombine_f16(vget_low_f16(abcd26), vget_low_f16(efgh26));
+  d = vcombine_f16(vget_low_f16(abcd37), vget_low_f16(efgh37));
+  e = vcombine_f16(vget_high_f16(abcd04), vget_high_f16(efgh04));
+  f = vcombine_f16(vget_high_f16(abcd15), vget_high_f16(efgh15));
+  g = vcombine_f16(vget_high_f16(abcd26), vget_high_f16(efgh26));
+  h = vcombine_f16(vget_high_f16(abcd37), vget_high_f16(efgh37));
+
+  vst1q_f16(&dst[0 * ld_dst], a);
+  vst1q_f16(&dst[1 * ld_dst], b);
+  vst1q_f16(&dst[2 * ld_dst], c);
+  vst1q_f16(&dst[3 * ld_dst], d);
+  vst1q_f16(&dst[4 * ld_dst], e);
+  vst1q_f16(&dst[5 * ld_dst], f);
+  vst1q_f16(&dst[6 * ld_dst], g);
+  vst1q_f16(&dst[7 * ld_dst], h);
+}
+
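+/**
+ * @brief Transpose a partial tile of M rows x N columns (M, N <= 8), using
+ * the lane masks from mask_neon.h to zero out lanes past the tile boundary.
+ */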
+template <unsigned int M>
+static void transpose_kernel_mxn_neon_256(unsigned int N, const __fp16 *src,
+                                          unsigned int ld_src, __fp16 *dst,
+                                          unsigned int ld_dst) {
+  float16x8_t ZEROS = vmovq_n_f16(0.F);
+  uint16x8_t bitmask_v8 =
+    vld1q_u16(reinterpret_cast<const uint16_t *>(neon_16bit_masks[N]));
+  float16x8_t input[8];
+  unsigned i;
+  for (i = 0; i < M; ++i) {
+    input[i] = vbslq_f16(bitmask_v8, vld1q_f16(&src[i * ld_src]), ZEROS);
+  }
+  for (; i < 8; ++i) {
+    input[i] = ZEROS;
+  }
+  float16x8_t temp[8];
+  for (i = 0; i < (M + 1) / 2; ++i) {
+    temp[2 * i] = vcombine_f16(
+      vzip1_f16(vget_low_f16(input[2 * i]), vget_low_f16(input[2 * i + 1])),
+      vzip1_f16(vget_high_f16(input[2 * i]), vget_high_f16(input[2 * i + 1])));
+    temp[2 * i + 1] = vcombine_f16(
+      vzip2_f16(vget_low_f16(input[2 * i]), vget_low_f16(input[2 * i + 1])),
+      vzip2_f16(vget_high_f16(input[2 * i]), vget_high_f16(input[2 * i + 1])));
+  }
+  for (i = i * 2; i < 8; ++i) {
+    temp[i] = ZEROS;
+  }
+
+  uint16x8_t shuffle_mask =
+    vld1q_u16(reinterpret_cast<const uint16_t *>(shuffle_masks));
+  for (i = 0; i < (M + 3) / 4; ++i) {
+    input[4 * i] = vbslq_f16(shuffle_mask, temp[4 * i],
+                             vextq_f16(temp[4 * i + 2], temp[4 * i + 2], 6));
+    input[4 * i + 1] = vbslq_f16(
+      shuffle_mask, vextq_f16(temp[4 * i], temp[4 * i], 2), temp[4 * i + 2]);
+    input[4 * i + 2] =
+      vbslq_f16(shuffle_mask, temp[4 * i + 1],
+                vextq_f16(temp[4 * i + 3], temp[4 * i + 3], 6));
+    input[4 * i + 3] =
+      vbslq_f16(shuffle_mask, vextq_f16(temp[4 * i + 1], temp[4 * i + 1], 2),
+                temp[4 * i + 3]);
+  }
+  bitmask_v8 = vld1q_u16(
+    reinterpret_cast<const uint16_t *>(neon_16bit_masks[M]));
+  for (i = 0; i < N; ++i) {
+    if (i < 4) {
+      temp[i] =
+        vcombine_f16(vget_low_f16(input[i]), vget_low_f16(input[4 + i]));
+    } else {
+      temp[i] = vcombine_f16(vget_high_f16(input[i - 4]),
+                             vget_high_f16(input[i]));
+    }
+    vst1q_f16(&dst[i * ld_dst], vbslq_f16(bitmask_v8, temp[i], ZEROS));
+  }
+}
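
For reference, a standalone self-check sketch of the dense 4x4 kernel
(hypothetical harness, not part of this patch; the include path is assumed and
an fp16-capable aarch64 toolchain, e.g. -march=armv8.2-a+fp16, is required):

#include <cassert>
#include "transpose_utils_neon.h" // include path assumed

int main() {
  __fp16 src[16], dst[16];
  for (int i = 0; i < 16; ++i)
    src[i] = static_cast<__fp16>(i);
  // One dense 4x4 tile: ld_src == ld_dst == 4.
  transpose_kernel_4x4_neon(src, 4, dst, 4);
  for (int r = 0; r < 4; ++r)
    for (int c = 0; c < 4; ++c)
      assert(dst[c * 4 + r] == src[r * 4 + c]);
  return 0;
}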