[ Trivial ] Remove redundant comments and format
author skykongkong8 <ss.kong@samsung.com>
Mon, 15 Apr 2024 04:11:24 +0000 (13:11 +0900)
committer Jijoong Moon <jijoong.moon@samsung.com>
Wed, 22 May 2024 23:13:42 +0000 (08:13 +0900)
- Since the macro kernels are now used adaptively, the previous per-macro accuracy/latency comments no longer apply, so they are removed.
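
For context, a minimal sketch of what "adaptive macro kernel usage" means here, assuming the kernel walks K with the largest ACCx block that still fits and finishes the remainder with smaller ones (the names and loop below are illustrative, not the actual nntrainer kernel code):

  #include <cstdio>

  // Illustrative only: each printf stands in for one KERNEL_MxN_ACCx() expansion.
  void run_partial_sums(unsigned int K) {
    unsigned int l = 0;
    for (; l + 16 <= K; l += 16)
      std::printf("ACC16 block at k = %u\n", l); // e.g. KERNEL_8x16_ACC16()
    for (; l + 8 <= K; l += 8)
      std::printf("ACC8 block at k = %u\n", l);  // e.g. KERNEL_8x16_ACC8()
    for (; l + 4 <= K; l += 4)
      std::printf("ACC4 block at k = %u\n", l);  // e.g. KERNEL_8x16_ACC4()
    for (; l < K; ++l)
      std::printf("ACC1 step at k = %u\n", l);   // e.g. KERNEL_8x16_ACC1()
  }

  int main() { run_partial_sums(29); } // 16 + 8 + 4 + 1

Since the accumulation depth now varies with K at run time, a fixed "best/worst accuracy, best/worst latency" label on any single macro no longer describes its actual behavior, which is why the labels are dropped.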

**Self evaluation:**
1. Build test:     [X]Passed [ ]Failed [ ]Skipped
2. Run test:     [X]Passed [ ]Failed [ ]Skipped

Signed-off-by: skykongkong8 <ss.kong@samsung.com>
nntrainer/tensor/hgemm/hgemm.cpp
nntrainer/tensor/hgemm/hgemm_kernel_4x4.h
nntrainer/tensor/hgemm/hgemm_kernel_4x8.h
nntrainer/tensor/hgemm/hgemm_kernel_8x16.h
nntrainer/tensor/hgemm/hgemm_kernel_8x8.h

index 61353e595b8feb6a6225e23f603d41f0bfc32a39..4aaadf331c46b28314197dce965d0d9e5ba9a943 100644
@@ -64,9 +64,9 @@ void hgemm_noTrans(const __fp16 *A, const __fp16 *B, __fp16 *C, unsigned int M,
       hgemm_noTrans_4x8(M, N, K, A, K, B, N, C, N, alpha, beta);
     } else if ((M & 0x3) == 0 && (N & 0x3) == 0 && (K & 0x3) == 0) {
       hgemm_noTrans_4x4(M, N, K, A, K, B, N, C, N, alpha, beta);
-    } else if ((K & 0x7) == 0 && (N & 0x7) == 0) {
+    } else if ((N & 0x7) == 0 && (K & 0x7) == 0) {
       hgemm_noTrans_1x8(M, N, K, A, K, B, N, C, N, alpha, beta);
-    } else if ((K & 0x7) == 0 && (N & 0x3) == 0) {
+    } else if ((N & 0x3) == 0 && (K & 0x7) == 0) {
       hgemm_noTrans_1x4(M, N, K, A, K, B, N, C, N, alpha, beta);
     }
   }
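
A note on the reordered conditions above: the masks are power-of-two divisibility tests, so (N & 0x7) == 0 is N % 8 == 0 and (N & 0x3) == 0 is N % 4 == 0, and swapping the operands of && only changes evaluation order, not which kernel gets selected. A standalone sanity check (not part of the patch):

  #include <cassert>

  int main() {
    for (unsigned int n = 0; n < 4096; ++n) {
      assert(((n & 0x7) == 0) == (n % 8 == 0)); // divisible by 8
      assert(((n & 0x3) == 0) == (n % 4 == 0)); // divisible by 4
    }
    return 0;
  }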
index ab49faaca7d6b16157eca62b54263a870ee7443b..7bf75b13b7a7ba9115692afeb19437f2f40ecaa0 100644
@@ -20,7 +20,7 @@
   v26 = vdup_n_f16(0.F);  \
   v27 = vdup_n_f16(0.F);
 
-// 1. Partial sum 256 digits : Medium accuracy, medium latency
+// 1. Partial sum 256 digits
 #define KERNEL_4x4_ACC16()               \
   dv0 = vld1_f16(a);                     \
   vb0 = vld1_f16(b);                     \
   b += 4 * 16;                           \
   a += 4 * 16;
 
-// 2. Partial sum 128 digits : Medium accuracy, medium latency
+// 2. Partial sum 128 digits
 #define KERNEL_4x4_ACC8()                \
   dv0 = vld1_f16(a);                     \
   vb0 = vld1_f16(b);                     \
   b += 4 * 8;                            \
   a += 4 * 8;
 
-// 2. Partial sum 16 digits : Best accuracy, worst latency
+// 3. Partial sum 16 digits
 #define KERNEL_4x4_ACC1()                \
   dv0 = vld1_f16(a);                     \
   vb0 = vld1_f16(b);                     \
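
As a reading aid for the "Partial sum N digits" labels kept above (an interpretation inferred from the numbers used across these headers, not stated in the patch): a KERNEL_MxN_ACCx macro accumulates x steps over an M x N output tile before the partial sum is flushed, so the label equals M * N * x. A hypothetical helper reproducing the figures:

  #include <cstdio>

  // Hypothetical helper: reproduces the "partial sum" element counts.
  constexpr unsigned int partial_sum_digits(unsigned int m, unsigned int n,
                                            unsigned int acc) {
    return m * n * acc;
  }

  int main() {
    std::printf("4x4  ACC16 -> %u\n", partial_sum_digits(4, 4, 16));  // 256
    std::printf("4x4  ACC8  -> %u\n", partial_sum_digits(4, 4, 8));   // 128
    std::printf("4x8  ACC4  -> %u\n", partial_sum_digits(4, 8, 4));   // 128
    std::printf("8x16 ACC16 -> %u\n", partial_sum_digits(8, 16, 16)); // 2048
    std::printf("8x8  ACC1  -> %u\n", partial_sum_digits(8, 8, 1));   // 64
  }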
index 064f0a7b73a66bacd0e6a43b4726cc60022ed0bb..01204457e9c8901cbc05337efa65a3bdc32be6ec 100644
@@ -20,7 +20,7 @@
   v6 = vdupq_n_f16(0.F);  \
   v9 = vdupq_n_f16(0.F);
 
-// 1. Partial sum 256 digits : worst accuracy, best latency
+// 1. Partial sum 512 digits
 #define KERNEL_4x8_ACC16()              \
   dv0 = vld1_f16(a);                    \
   v24 = vld1q_f16(b);                   \
   b += 8 * 16;                          \
   a += 4 * 16;
 
-// 1. Partial sum 256 digits : worst accuracy, best latency
+// 2. Partial sum 256 digits
 #define KERNEL_4x8_ACC8()               \
   dv0 = vld1_f16(a);                    \
   v24 = vld1q_f16(b);                   \
   b += 8 * 8;                           \
   a += 4 * 8;
 
-// 2. Partial sum 128 digits : medium accuracy, medium latency
+// 3. Partial sum 128 digits
 #define KERNEL_4x8_ACC4()               \
   dv0 = vld1_f16(a);                    \
   v24 = vld1q_f16(b);                   \
   b += 8 * 4;                           \
   a += 4 * 4;
 
-// 3. Partial sum 32 digits : Best accuracy, worst latency
+// 4. Partial sum 32 digits
 #define KERNEL_4x8_ACC1()               \
   dv0 = vld1_f16(a);                    \
   v24 = vld1q_f16(b);                   \
index 38778ea8f39324ce4c5e80b15c35d2b51d7eb19d..a89a6b542181a80c2183114cf8d1f3f10c29ac70 100644
@@ -32,7 +32,7 @@
   v112_119 = vdupq_n_f16(0.F); \
   v120_127 = vdupq_n_f16(0.F);
 
-// 0. Partial sum 2048 digits : Best latency, worst accuracy.
+// 1. Partial sum 2048 digits
 #define KERNEL_8x16_ACC16()                          \
   va0 = vld1q_f16(a);                                \
   v24 = vld1q_f16(b);                                \
   b += 16 * 16;                                      \
   a += 8 * 16;
 
-// 1. Partial sum 1024 digits : Medium-high accuracy, medium latency
+// 2. Partial sum 1024 digits
 #define KERNEL_8x16_ACC8()                           \
   va0 = vld1q_f16(a);                                \
   v24 = vld1q_f16(b);                                \
   b += 16 * 8;                                       \
   a += 8 * 8;
 
-// 2. Partial sum 512 digits : Medium accuracy, medium latency
+// 3. Partial sum 512 digits
 #define KERNEL_8x16_ACC4()                           \
   va0 = vld1q_f16(a);                                \
   v24 = vld1q_f16(b);                                \
   b += 16 * 4;                                       \
   a += 8 * 4;
 
-// 3. Partial sum 128 digits : Best accuracy, worst latency
+// 4. Partial sum 128 digits
 #define KERNEL_8x16_ACC1()                           \
   va0 = vld1q_f16(a);                                \
   v24 = vld1q_f16(b);                                \
@@ -740,28 +740,26 @@ void hgemm_kernel_8x16(unsigned int M, unsigned int N, unsigned int K,
       for (; l < K;) {
         KERNEL_8x16_ACC1();
       }
-       vst1q_f16(c, vaddq_f16(vld1q_f16(c), v0_7));
-        vst1q_f16(c + 8, vaddq_f16(vld1q_f16(c + 8), v64_71));
-        vst1q_f16(c + ldc, vaddq_f16(vld1q_f16(c + ldc), v8_15));
-        vst1q_f16(c + ldc + 8, vaddq_f16(vld1q_f16(c + ldc + 8), v72_79));
-        vst1q_f16(c + 2 * ldc, vaddq_f16(vld1q_f16(c + 2 * ldc), v16_23));
-        vst1q_f16(c + 2 * ldc + 8,
-                  vaddq_f16(vld1q_f16(c + 2 * ldc + 8), v80_87));
-        vst1q_f16(c + 3 * ldc, vaddq_f16(vld1q_f16(c + 3 * ldc), v24_31));
-        vst1q_f16(c + 3 * ldc + 8,
-                  vaddq_f16(vld1q_f16(c + 3 * ldc + 8), v88_95));
-        vst1q_f16(c + 4 * ldc, vaddq_f16(vld1q_f16(c + 4 * ldc), v32_39));
-        vst1q_f16(c + 4 * ldc + 8,
-                  vaddq_f16(vld1q_f16(c + 4 * ldc + 8), v96_103));
-        vst1q_f16(c + 5 * ldc, vaddq_f16(vld1q_f16(c + 5 * ldc), v40_47));
-        vst1q_f16(c + 5 * ldc + 8,
-                  vaddq_f16(vld1q_f16(c + 5 * ldc + 8), v104_111));
-        vst1q_f16(c + 6 * ldc, vaddq_f16(vld1q_f16(c + 6 * ldc), v48_55));
-        vst1q_f16(c + 6 * ldc + 8,
-                  vaddq_f16(vld1q_f16(c + 6 * ldc + 8), v112_119));
-        vst1q_f16(c + 7 * ldc, vaddq_f16(vld1q_f16(c + 7 * ldc), v56_63));
-        vst1q_f16(c + 7 * ldc + 8,
-                  vaddq_f16(vld1q_f16(c + 7 * ldc + 8), v120_127));
+      vst1q_f16(c, vaddq_f16(vld1q_f16(c), v0_7));
+      vst1q_f16(c + 8, vaddq_f16(vld1q_f16(c + 8), v64_71));
+      vst1q_f16(c + ldc, vaddq_f16(vld1q_f16(c + ldc), v8_15));
+      vst1q_f16(c + ldc + 8, vaddq_f16(vld1q_f16(c + ldc + 8), v72_79));
+      vst1q_f16(c + 2 * ldc, vaddq_f16(vld1q_f16(c + 2 * ldc), v16_23));
+      vst1q_f16(c + 2 * ldc + 8, vaddq_f16(vld1q_f16(c + 2 * ldc + 8), v80_87));
+      vst1q_f16(c + 3 * ldc, vaddq_f16(vld1q_f16(c + 3 * ldc), v24_31));
+      vst1q_f16(c + 3 * ldc + 8, vaddq_f16(vld1q_f16(c + 3 * ldc + 8), v88_95));
+      vst1q_f16(c + 4 * ldc, vaddq_f16(vld1q_f16(c + 4 * ldc), v32_39));
+      vst1q_f16(c + 4 * ldc + 8,
+                vaddq_f16(vld1q_f16(c + 4 * ldc + 8), v96_103));
+      vst1q_f16(c + 5 * ldc, vaddq_f16(vld1q_f16(c + 5 * ldc), v40_47));
+      vst1q_f16(c + 5 * ldc + 8,
+                vaddq_f16(vld1q_f16(c + 5 * ldc + 8), v104_111));
+      vst1q_f16(c + 6 * ldc, vaddq_f16(vld1q_f16(c + 6 * ldc), v48_55));
+      vst1q_f16(c + 6 * ldc + 8,
+                vaddq_f16(vld1q_f16(c + 6 * ldc + 8), v112_119));
+      vst1q_f16(c + 7 * ldc, vaddq_f16(vld1q_f16(c + 7 * ldc), v56_63));
+      vst1q_f16(c + 7 * ldc + 8,
+                vaddq_f16(vld1q_f16(c + 7 * ldc + 8), v120_127));
       c += 16;
       a -= 8 * K;
     }
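
The re-indented store block above is a whitespace-only change; each statement still loads an 8-wide slice of the 8x16 C tile, adds the matching accumulator, and writes it back. A scalar sketch of one such statement, with float standing in for __fp16 for portability:

  // Scalar equivalent of:
  //   vst1q_f16(c + r * ldc + off,
  //             vaddq_f16(vld1q_f16(c + r * ldc + off), acc));
  // where acc stands in for one of the v0_7 ... v120_127 accumulators.
  void add_slice(float *c, unsigned int r, unsigned int ldc, unsigned int off,
                 const float acc[8]) {
    for (unsigned int i = 0; i < 8; ++i)
      c[r * ldc + off + i] += acc[i]; // C tile row r, columns off .. off+7
  }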
index c913bdd040c93668431b8e385dcd275e21365ff6..4901c3f5182e3ce7e448154fe0d36873ac988f38 100644
@@ -24,7 +24,7 @@
   v30 = vdupq_n_f16(0.F); \
   v31 = vdupq_n_f16(0.F);
 
-// 1. Partial sum 1024 digits : Worst accuracy, best latency
+// 1. Partial sum 1024 digits
 #define KERNEL_8x8_ACC16()                 \
   va0 = vld1q_f16(a);                      \
   v16 = vld1q_f16(b);                      \
   b += 8 * 16;                             \
   a += 8 * 16;
 
-// 2. Partial sum 512 digits : Medium accuracy, medium latency
+// 2. Partial sum 512 digits
 #define KERNEL_8x8_ACC8()                  \
   va0 = vld1q_f16(a);                      \
   v16 = vld1q_f16(b);                      \
   b += 8 * 8;                              \
   a += 8 * 8;
 
-// 3. Partial sum 256 digits : Medium accuracy, medium latency
+// 3. Partial sum 256 digits
 #define KERNEL_8x8_ACC4()                  \
   va0 = vld1q_f16(a);                      \
   v16 = vld1q_f16(b);                      \
   b += 8 * 4;                              \
   a += 8 * 4;
 
-// 4. Partial sum 64 digits : Best accuracy, worst latency
+// 4. Partial sum 64 digits
 #define KERNEL_8x8_ACC1()                  \
   va0 = vld1q_f16(a);                      \
   v16 = vld1q_f16(b);                      \
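
Finally, for readers wondering what the removed accuracy/latency annotations referred to: a shorter accumulation chain flushes its partial sum more often, which limits how much rounding error can compound within a single low-precision chain, at the cost of extra loads and stores. A stand-in demonstration, using float as the low-precision accumulator and double as the wider one (the real kernels work in __fp16; the numbers are purely illustrative):

  #include <cstdio>

  int main() {
    const unsigned int K = 1u << 20;
    const float x = 0.1f;

    // One long low-precision chain: error grows as the running sum gets large.
    float long_chain = 0.f;
    for (unsigned int i = 0; i < K; ++i)
      long_chain += x;

    // Blocked accumulation: flush a short low-precision partial sum every 16
    // steps into a wider accumulator, loosely analogous to splitting K into
    // ACC16/ACC8/ACC4/ACC1 blocks.
    double blocked = 0.0;
    for (unsigned int i = 0; i < K; i += 16) {
      float partial = 0.f;
      for (unsigned int j = 0; j < 16; ++j)
        partial += x;
      blocked += partial;
    }

    std::printf("long chain: %.3f  blocked: %.3f  exact: %.1f\n",
                long_chain, blocked, (double)K * 0.1);
  }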