[ trivial ] Add missing docs and error message
authorskykongkong8 <ss.kong@samsung.com>
Wed, 7 Aug 2024 01:21:08 +0000 (10:21 +0900)
committerJijoong Moon <jijoong.moon@samsung.com>
Fri, 9 Aug 2024 00:24:54 +0000 (09:24 +0900)
- Add missing doxtgen tags : transpose boolean params
- error message : emit error when try to use full-fp16 kernel with experimental kernel build

**Self evaluation:**
1. Build test:     [X]Passed [ ]Failed [ ]Skipped
2. Run test:     [X]Passed [ ]Failed [ ]Skipped

Signed-off-by: skykongkong8 <ss.kong@samsung.com>
nntrainer/tensor/hgemm/hgemm.h
nntrainer/tensor/hgemm/hgemm_kernel/hgemm_kernel_8x16_experimental.cpp

index e67edec840c5186e16df2c28dd70b589e4bfa91d..2904302d2410b4888f60031da5d7a1e2135a288a 100644 (file)
@@ -23,6 +23,8 @@
  * @param[in] K number of op(A)'s and columns and op(B)'s rows
  * @param[in] alpha float number
  * @param[in] beta float number
+ * @param[in] TransA bool transpose info of lhs matrix
+ * @param[in] TransB bool transpose info of rhs matrix
  */
 void hgemm(const __fp16 *A, const __fp16 *B, __fp16 *C, unsigned int M,
            unsigned int N, unsigned int K, float alpha, float beta, bool TransA,
@@ -39,6 +41,8 @@ void hgemm(const __fp16 *A, const __fp16 *B, __fp16 *C, unsigned int M,
  * @param[in] K number of op(A)'s and columns and op(B)'s rows
  * @param[in] alpha float number
  * @param[in] beta float number
+ * @param[in] TransA bool transpose info of lhs matrix
+ * @param[in] TransB bool transpose info of rhs matrix
  */
 void hgemm_small(const __fp16 *A, const __fp16 *B, __fp16 *C, unsigned int M,
                  unsigned int N, unsigned int K, float alpha, float beta,
@@ -55,6 +59,8 @@ void hgemm_small(const __fp16 *A, const __fp16 *B, __fp16 *C, unsigned int M,
  * @param[in] K number of op(A)'s and columns and op(B)'s rows
  * @param[in] alpha float number
  * @param[in] beta float number
+ * @param[in] TransA bool transpose info of lhs matrix
+ * @param[in] TransB bool transpose info of rhs matrix
  */
 void hgemm_ensure_divisibility(const __fp16 *A, const __fp16 *B, float *C32,
                                unsigned int M, unsigned int N, unsigned int K,
@@ -72,6 +78,8 @@ void hgemm_ensure_divisibility(const __fp16 *A, const __fp16 *B, float *C32,
  * @param[in] K number of op(A)'s and columns and op(B)'s rows
  * @param[in] alpha float number
  * @param[in] beta float number
+ * @param[in] TransA bool transpose info of lhs matrix
+ * @param[in] TransB bool transpose info of rhs matrix
  */
 void hgemm_classify(const __fp16 *A, const __fp16 *B, float *C32,
                     unsigned int M, unsigned int N, unsigned int K,
@@ -88,6 +96,8 @@ void hgemm_classify(const __fp16 *A, const __fp16 *B, float *C32,
  * @param[in] K number of op(A)'s and columns and op(B)'s rows
  * @param[in] alpha float number
  * @param[in] beta float number
+ * @param[in] TransA bool transpose info of lhs matrix
+ * @param[in] TransB bool transpose info of rhs matrix
  */
 void hgemm_K1(const __fp16 *A, const __fp16 *B, __fp16 *C, unsigned int M,
               unsigned int N, unsigned int K, float alpha, float beta,
index 81033e16177fc676bc19655180c745a1eb96c1c4..b1c5ffe06af245307147aa34beda7e14afbc9884 100644 (file)
@@ -14,6 +14,7 @@
 #include <arm_neon.h>
 #include <assert.h>
 #include <hgemm_kernel.h>
+#include <stdexcept>
 #include <stdlib.h>
 
 #define INIT_KERNEL_8X16()       \
                         vcvt_f32_f16(vget_high_f16(v120_127))));               \
   } while (0)
 
-template<>
+template <>
 void hgemm_kernel_8x16(unsigned int M, unsigned int N, unsigned int K,
                        __fp16 *sa, __fp16 *sb, __fp16 *sc, unsigned int ldc) {
-//  std::invalid_argument("Error : should not reach experimental kernel + full fp16 usage in hgemm");
+  throw std::runtime_error(
+    "Error : should not reach for full-fp16 usage in experimental kernel");
 }
 
-template<>
+template <>
 void hgemm_kernel_8x16(unsigned int M, unsigned int N, unsigned int K,
                        __fp16 *sa, __fp16 *sb, float *sc, unsigned int ldc) {
   assert(M > 0 && N > 0 && K > 0);
@@ -803,4 +805,3 @@ void hgemm_kernel_8x16(unsigned int M, unsigned int N, unsigned int K,
     b = sb;
   }
 }
-