* @param[in] K number of op(A)'s and columns and op(B)'s rows
* @param[in] alpha float number
* @param[in] beta float number
+ * @param[in] TransA bool transpose info of lhs matrix
+ * @param[in] TransB bool transpose info of rhs matrix
*/
void hgemm(const __fp16 *A, const __fp16 *B, __fp16 *C, unsigned int M,
unsigned int N, unsigned int K, float alpha, float beta, bool TransA,
* @param[in] K number of op(A)'s and columns and op(B)'s rows
* @param[in] alpha float number
* @param[in] beta float number
+ * @param[in] TransA bool transpose info of lhs matrix
+ * @param[in] TransB bool transpose info of rhs matrix
*/
void hgemm_small(const __fp16 *A, const __fp16 *B, __fp16 *C, unsigned int M,
unsigned int N, unsigned int K, float alpha, float beta,
* @param[in] K number of op(A)'s and columns and op(B)'s rows
* @param[in] alpha float number
* @param[in] beta float number
+ * @param[in] TransA bool transpose info of lhs matrix
+ * @param[in] TransB bool transpose info of rhs matrix
*/
void hgemm_ensure_divisibility(const __fp16 *A, const __fp16 *B, float *C32,
unsigned int M, unsigned int N, unsigned int K,
* @param[in] K number of op(A)'s and columns and op(B)'s rows
* @param[in] alpha float number
* @param[in] beta float number
+ * @param[in] TransA bool transpose info of lhs matrix
+ * @param[in] TransB bool transpose info of rhs matrix
*/
void hgemm_classify(const __fp16 *A, const __fp16 *B, float *C32,
unsigned int M, unsigned int N, unsigned int K,
* @param[in] K number of op(A)'s and columns and op(B)'s rows
* @param[in] alpha float number
* @param[in] beta float number
+ * @param[in] TransA bool transpose info of lhs matrix
+ * @param[in] TransB bool transpose info of rhs matrix
*/
void hgemm_K1(const __fp16 *A, const __fp16 *B, __fp16 *C, unsigned int M,
unsigned int N, unsigned int K, float alpha, float beta,
#include <arm_neon.h>
#include <assert.h>
#include <hgemm_kernel.h>
+#include <stdexcept>
#include <stdlib.h>
#define INIT_KERNEL_8X16() \
vcvt_f32_f16(vget_high_f16(v120_127)))); \
} while (0)
-template<>
+template <>
void hgemm_kernel_8x16(unsigned int M, unsigned int N, unsigned int K,
__fp16 *sa, __fp16 *sb, __fp16 *sc, unsigned int ldc) {
-// std::invalid_argument("Error : should not reach experimental kernel + full fp16 usage in hgemm");
+ throw std::runtime_error(
+ "Error : should not reach for full-fp16 usage in experimental kernel");
}
-template<>
+template <>
void hgemm_kernel_8x16(unsigned int M, unsigned int N, unsigned int K,
__fp16 *sa, __fp16 *sb, float *sc, unsigned int ldc) {
assert(M > 0 && N > 0 && K > 0);
b = sb;
}
}
-