#include <arm_neon.h>
#include <cstdint>
#include <mask_neon.h>
-#define TRANSPOSE_FP16_4x4(row0, row1, row2, row3) \
- float16x4x2_t row01 = vtrn_f16(row0, row1); \
- float16x4x2_t row23 = vtrn_f16(row2, row3); \
- row0 = vcvt_f16_f32(vcombine_f32(vget_low_f32(vcvt_f32_f16(row01.val[0])), \
- vget_low_f32(vcvt_f32_f16(row23.val[0])))); \
- row1 = vcvt_f16_f32(vcombine_f32(vget_low_f32(vcvt_f32_f16(row01.val[1])), \
- vget_low_f32(vcvt_f32_f16(row23.val[1])))); \
- row2 = \
- vcvt_f16_f32(vcombine_f32(vget_high_f32(vcvt_f32_f16(row01.val[0])), \
- vget_high_f32(vcvt_f32_f16(row23.val[0])))); \
- row3 = \
- vcvt_f16_f32(vcombine_f32(vget_high_f32(vcvt_f32_f16(row01.val[1])), \
- vget_high_f32(vcvt_f32_f16(row23.val[1]))));
-
+#define TRANSPOSE_FP16_4x4(row0, row1, row2, row3) \
+ do { \
+ float16x4x2_t row01 = vtrn_f16(row0, row1); \
+ float16x4x2_t row23 = vtrn_f16(row2, row3); \
+ row0 = \
+ vcvt_f16_f32(vcombine_f32(vget_low_f32(vcvt_f32_f16(row01.val[0])), \
+ vget_low_f32(vcvt_f32_f16(row23.val[0])))); \
+ row1 = \
+ vcvt_f16_f32(vcombine_f32(vget_low_f32(vcvt_f32_f16(row01.val[1])), \
+ vget_low_f32(vcvt_f32_f16(row23.val[1])))); \
+ row2 = \
+ vcvt_f16_f32(vcombine_f32(vget_high_f32(vcvt_f32_f16(row01.val[0])), \
+ vget_high_f32(vcvt_f32_f16(row23.val[0])))); \
+ row3 = \
+ vcvt_f16_f32(vcombine_f32(vget_high_f32(vcvt_f32_f16(row01.val[1])), \
+ vget_high_f32(vcvt_f32_f16(row23.val[1])))); \
+ } while (0)
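
// Editorial note and usage sketch (not part of the patch). The new
// do { ... } while (0) wrapper makes the macro expand to a single statement,
// so it is safe inside an unbraced if/else and takes a trailing semicolon.
// vtrn_f16 transposes the 2x2 lane pairs, and the f16->f32->f16 round trip
// (exact for fp16 values) recombines the low/high halves across row pairs,
// leaving r0..r3 holding the columns of the input. Assumes <arm_neon.h> and
// an AArch64 toolchain with __fp16 support; the names below are illustrative.
static inline void transpose_4x4_f16_example(const __fp16 *m /* 4x4 row-major */,
                                             __fp16 *out /* 4x4 row-major */) {
  float16x4_t r0 = vld1_f16(m + 0);
  float16x4_t r1 = vld1_f16(m + 4);
  float16x4_t r2 = vld1_f16(m + 8);
  float16x4_t r3 = vld1_f16(m + 12);
  TRANSPOSE_FP16_4x4(r0, r1, r2, r3);
  vst1_f16(out + 0, r0);  // column 0 of m
  vst1_f16(out + 4, r1);  // column 1
  vst1_f16(out + 8, r2);  // column 2
  vst1_f16(out + 12, r3); // column 3
}
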
+/**
+ * @brief 4x4 kernel for matrix transpose in NEON
+ *
+ * @param src __fp16* source data
+ * @param ld_src leading dimension of src (number of columns)
+ * @param dst __fp16* destination data
+ * @param ld_dst leading dimension of dst (number of columns)
+ */
static inline void transpose_kernel_4x4_neon(const __fp16 *src,
unsigned int ld_src, __fp16 *dst,
unsigned int ld_dst) {
vst1_f16(&dst[3 * ld_dst], d);
}
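
// Illustrative tiling sketch (editorial, not part of the patch): how a 4x4
// kernel like the one above is typically swept over a full matrix. Assumes
// M and N are multiples of 4; leftover rows/columns would go through the
// mxn kernel below instead.
static void transpose_f16_4x4_tiled_example(unsigned int M, unsigned int N,
                                            const __fp16 *src,
                                            unsigned int ld_src, __fp16 *dst,
                                            unsigned int ld_dst) {
  for (unsigned int i = 0; i < M; i += 4) {
    for (unsigned int j = 0; j < N; j += 4) {
      // The (i, j) tile of src lands at (j, i) in dst.
      transpose_kernel_4x4_neon(&src[i * ld_src + j], ld_src,
                                &dst[j * ld_dst + i], ld_dst);
    }
  }
}
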
+/**
+ * @brief general-case MxN matrix transpose kernel using 128-bit SIMD
+ * registers
+ *
+ * @tparam M number of leftover rows
+ * @param N number of leftover columns
+ * @param src __fp16* source data
+ * @param ld_src leading dimension of src (number of columns)
+ * @param dst __fp16* destination data
+ * @param ld_dst leading dimension of dst (number of columns)
+ */
template <unsigned int M>
static void transpose_kernel_mxn_neon_128(unsigned int N, const __fp16 *src,
unsigned int ld_src, __fp16 *dst,
vbsl_f16(bitmask_v8, input[i], vld1_f16(&dst[i * ld_dst])));
}
}
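
// Editorial sketch of the masked-store idea visible in the kernel above:
// vbsl_f16 blends the transposed lanes with the previous dst contents so
// only the first N lanes are overwritten. The mask construction and names
// here are assumptions for illustration; the kernel itself takes its
// bitmask from mask_neon.h.
static inline void store_first_n_f16_example(__fp16 *dst, float16x4_t v,
                                             unsigned int n) {
  const uint16_t lane_idx[4] = {0, 1, 2, 3};
  uint16x4_t mask = vclt_u16(vld1_u16(lane_idx), vdup_n_u16((uint16_t)n));
  // Lanes with mask bits set take v; the rest keep their old dst values.
  vst1_f16(dst, vbsl_f16(mask, v, vld1_f16(dst)));
}
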
-
+/**
+ * @brief 8x8 kernel for matrix transpose in NEON
+ *
+ * @param src __fp16* source data
+ * @param ld_src leading dimension of src (number of columns)
+ * @param dst __fp16* destination data
+ * @param ld_dst leading dimension of dst (number of columns)
+ */
static inline void transpose_kernel_8x8_neon(const __fp16 *src,
unsigned int ld_src, __fp16 *dst,
unsigned int ld_dst) {
vst1q_f16(&dst[6 * ld_dst], g);
vst1q_f16(&dst[7 * ld_dst], h);
}
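
// Illustrative call site (editorial, not part of the patch): one 8x8 tile,
// using the same src/dst index swap as the 4x4 case but with full 128-bit
// float16x8_t rows (vld1q_f16 / vst1q_f16):
//
//   transpose_kernel_8x8_neon(&src[i * ld_src + j], ld_src,
//                             &dst[j * ld_dst + i], ld_dst);
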
-
+/**
+ * @brief general-case MxN matrix transpose kernel using 256-bit SIMD
+ * registers
+ *
+ * @tparam M number of leftover rows
+ * @param N number of leftover columns
+ * @param src __fp16* source data
+ * @param ld_src leading dimension of src (number of columns)
+ * @param dst __fp16* destination data
+ * @param ld_dst leading dimension of dst (number of columns)
+ */
template <unsigned int M>
static void transpose_kernel_mxn_neon_256(unsigned int N, const __fp16 *src,
unsigned int ld_src, __fp16 *dst,
if (is_format_nchw) {
for (unsigned int b = 0; b < batch(); ++b) {
for (unsigned int c = 0; c < channel(); ++c) {
- transpose_matrix(
- height(), width(), getData<_FP16>() + getIndex(b, c, 0, 0),
- width(), out.getData<_FP16>() + out.getIndex(b, c, 0, 0),
- out.width());
+ transpose_matrix(height(), width(),
+ getData<_FP16>() + getIndex(b, c, 0, 0), width(),
+ out.getData<_FP16>() + out.getIndex(b, c, 0, 0),
+ out.width());
}
}
} else {