[ hgemm ] Implement NYI functions from matrix A/B hgemm_padding
authorskykongkong8 <ss.kong@samsung.com>
Mon, 15 Jul 2024 09:41:43 +0000 (18:41 +0900)
committerJijoong Moon <jijoong.moon@samsung.com>
Tue, 30 Jul 2024 22:45:30 +0000 (07:45 +0900)
- Missing implementations might trigger unittest fails on Android.
- This patch will now support padding function for all combinations of following conditions : matrix A / B, trans/noTrans, M/K/N direction

**Self evaluation:**
1. Build test:     [X]Passed [ ]Failed [ ]Skipped
2. Run test:     [X]Passed [ ]Failed [ ]Skipped

Signed-off-by: skykongkong8 <ss.kong@samsung.com>
nntrainer/tensor/hgemm/hgemm.cpp
nntrainer/tensor/hgemm/hgemm_padding/hgemm_padding_a.cpp
nntrainer/tensor/hgemm/hgemm_padding/hgemm_padding_b.cpp

index e2d584a9d4a2c3df996e647f11e7563b3900e124..2fdef140e5f86a49be969aeda4725790db3765e5 100644 (file)
@@ -22,6 +22,7 @@
 #include <hgemm_transAB.h>
 #include <hgemm_transB.h>
 #include <hgemm_util.h>
+#include <iostream>
 #include <limits>
 
 void hgemm(const __fp16 *A, const __fp16 *B, __fp16 *C, unsigned int M,
@@ -31,55 +32,62 @@ void hgemm(const __fp16 *A, const __fp16 *B, __fp16 *C, unsigned int M,
     return hgemm_K1(A, B, C, M, N, K, alpha, beta, TransA, TransB);
   }
   // dynamic creation to avoid reaching stack limit(causes segmentation fault)
-  float *C32 = (float *)malloc(M * N * sizeof(float));
+  const unsigned int M8_high = ((M - 1) / 8 + 1) * 8;
+  const unsigned int K8_high = ((K - 1) / 8 + 1) * 8;
+  const unsigned int N16_high = ((N - 1) / 16 + 1) * 16;
+  const unsigned int N8_low = (N >> 3) << 3;
+  float32x4_t ZEROS = vmovq_n_f32(0.F);
 
-  // performing beta*C
-  unsigned int idx = 0;
-  unsigned int size = M * N;
+  // void* C_ptr = 0;
+  // int iRet = posix_memalign(&C_ptr, 64, M8_high * N16_high * sizeof(float));
+  // float* C32 = (float*) C_ptr;
+  float *C32 = (float *)malloc(M8_high * N16_high * sizeof(float));
+
+  unsigned int size = M8_high * N16_high;
   unsigned int size8 = (size >> 3) << 3;
   unsigned int size4 = (size >> 2) << 2;
 
   if (std::fpclassify(beta) != FP_ZERO) {
-    for (; idx < size8; idx += 8) {
-      float16x8_t c =
-        vmulq_n_f16(vld1q_f16(&C[idx]), static_cast<__fp16>(beta));
-
-      vst1q_f32(&C32[idx], vcvt_f32_f16(vget_low_f16(c)));
-      vst1q_f32(&C32[idx + 4], vcvt_f32_f16(vget_high_f16(c)));
-    }
-    // remaining 4
-    for (; idx < size4; idx += 4) {
-      float16x4_t c = vmul_n_f16(vld1_f16(&C[idx]), static_cast<__fp16>(beta));
-
-      vst1q_f32(&C32[idx], vcvt_f32_f16(c));
+    for (unsigned int m = 0; m < M; ++m) {
+      for (unsigned int n = 0; n < N8_low; n += 8) {
+        float16x8_t c =
+          vmulq_n_f16(vld1q_f16(&C[m * N + n]), static_cast<__fp16>(beta));
+        vst1q_f32(&C32[m * N16_high + n], vcvt_f32_f16(vget_low_f16(c)));
+        vst1q_f32(&C32[m * N16_high + n + 4], vcvt_f32_f16(vget_high_f16(c)));
+      }
+      for (unsigned int n = N8_low; n < N; ++n) {
+        C32[m * N16_high + n] = beta * C[m * N + n];
+      }
+      for (unsigned int n = N; n < N16_high; ++n) {
+        C32[m * N16_high + n] = 0.F;
+      }
     }
-
-    // remaining values if dimensions not a multiple of 8
-    for (; idx < size; idx++) {
-      C32[idx] = C[idx] * beta;
+    for (unsigned m = M; m < M8_high; ++m) {
+      for (unsigned int n = 0; n < N16_high; n += 4) {
+        vst1q_f32(&C32[m * N16_high + n], ZEROS);
+      }
     }
   } else {
-    float32x4_t zeros = vmovq_n_f32(0.F);
-    for (; idx < size4; idx += 4) {
-      vst1q_f32(&C32[idx], zeros);
+    for (unsigned int idx = 0; idx < size4; idx += 4) {
+      vst1q_f32(&C32[idx], ZEROS);
     }
-    for (; idx < size; idx++) {
+    for (unsigned int idx = size4; idx < size; idx++) {
       C32[idx] = 0.F;
     }
   }
 
   hgemm_ensure_divisibility(A, B, C32, M, N, K, alpha, beta, TransA, TransB);
 
-  for (unsigned int idx = 0; idx < size8; idx += 8) {
-    float32x4_t x1 = vld1q_f32(&C32[idx]);
-    float32x4_t x2 = vld1q_f32(&C32[idx + 4]);
-
-    float16x8_t y1 = vcombine_f16(vcvt_f16_f32(x1), vcvt_f16_f32(x2));
-
-    vst1q_f16(&C[idx], y1);
-  }
-  for (unsigned int idx = size8; idx < size; ++idx) {
-    C[idx] = static_cast<__fp16>(C32[idx]);
+  for (unsigned int m = 0; m < M; ++m) {
+    for (unsigned int n = 0; n < N8_low; n += 8) {
+      float32x4_t x1 = vld1q_f32(&C32[m * N16_high + n]);
+      float32x4_t x2 = vld1q_f32(&C32[m * N16_high + n + 4]);
+      vst1q_f16(&C[m * N + n],
+                vcombine_f16(vcvt_f16_f32(x1), vcvt_f16_f32(x2)));
+    }
+    for (unsigned int n = N8_low; n < N; ++n) {
+      C[m * N + n] = C32[m * N16_high + n];
+    }
   }
 
   free(C32);
@@ -91,16 +99,16 @@ void hgemm_ensure_divisibility(const __fp16 *A, const __fp16 *B, float *C32,
                                bool TransB) {
   /// @note Padding standard : 8x16 is the only KERNEL that outperforms single
   /// precision GEMM 'so far'. Padding will forcibly make every GEMM cases to
-  /// use it. Note that padding is not the optimal way here, but just an option
+  /// use it. Note that padding is not an optimal way here, but just an option
   /// that is easier to implement. Fine-grained packing, blocking, and
-  /// corresponding kernels should be supported on the future for optimal
-  /// performance.
+  /// corresponding kernels should be supported in the future for optimal
+  /// performance in terms of both latency and memory.
 
   __fp16 *A_ = (__fp16 *)A, *B_ = (__fp16 *)B;
   unsigned int M_ = M, N_ = N, K_ = K;
   bool pad_A = false, pad_B = false;
 
-  // Case 2 : smaller than 8, 16 | padding would be redundant
+  // Smaller than 8, 16 -> padding would be redundant
   if (M < 8 && K < 16 && N < 16)
     return hgemm_classify(A_, B_, C32, M_, N_, K_, alpha, beta, TransA, TransB);
 
@@ -108,7 +116,8 @@ void hgemm_ensure_divisibility(const __fp16 *A, const __fp16 *B, float *C32,
   __fp16 *Bp;
 
   const unsigned int M8_high = ((M - 1) / 8 + 1) * 8;
-  const unsigned int K8_high = ((K - 1) / 8 + 1) * 8;
+  // const unsigned int K8_high = ((K - 1) / 8 + 1) * 8;
+  const unsigned int K8_high = ((K - 1) / 16 + 1) * 16;
   const unsigned int N16_high = ((N - 1) / 16 + 1) * 16;
 
   if ((M8_high != M) || (K8_high != K)) {
@@ -128,6 +137,48 @@ void hgemm_ensure_divisibility(const __fp16 *A, const __fp16 *B, float *C32,
     N_ = N16_high;
   }
 
+  // std::cout << "A matrix\n";
+  // for (unsigned int m = 0; m < M; m += 1) {
+  //   for (unsigned int k = 0; k < K; ++k) {
+  //     std::cout << A[m * K + k] << "\t";
+  //   }
+  //   std::cout << std::endl;
+  // }
+  // std::cout << std::endl;
+  // if (pad_A) {
+  //   std::cout << "B padding\n";
+  //   for (unsigned int m = 0; m < M; m += 1) {
+  //     for (unsigned int k = 0; k < K8_high; ++k) {
+  //       std::cout << A_[m * K8_high + k] << "\t";
+  //     }
+  //     std::cout << std::endl;
+  //   }
+  //   std::cout << std::endl;
+  // }
+  // std::cout << "B matrix\n";
+  // for (unsigned int k = 0; k < K; ++k) {
+  //   for (unsigned int n = 0; n < N; n += 1) {
+  //     std::cout << B[k * N + n] << "\t";
+  //   }
+  //   std::cout << std::endl;
+  // }
+  // std::cout << std::endl;
+  // if (pad_B) {
+  //   std::cout << "B padding\n";
+  //   for (unsigned int k = 0; k < K; ++k) {
+  //     for (unsigned int n = 0; n < N16_high; n += 1) {
+  //       std::cout << B_[k * N16_high + n] << "\t";
+  //     }
+  //     std::cout << std::endl;
+  //   }
+  //   std::cout << std::endl;
+  // }
+
+  // std::cout << "A matrix\n";
+  // matrix_printer<__fp16>(A_, M_, K_);
+  // std::cout << "B matrix\n";
+  // matrix_printer<__fp16>(B_, K_, N_);
+
   hgemm_classify(A_, B_, C32, M_, N_, K_, alpha, beta, TransA, TransB);
 
   if (pad_A)
index 297852719b0c0c0707f596fcef9ac95f40a9482b..898ca2447f282bf4adf0a4789208b1b774a67628 100644 (file)
@@ -94,7 +94,7 @@ void hgemm_padding_A_noTrans_wrt_MK(const __fp16 *A, __fp16 *Ap, unsigned int M,
   float16x8_t ZEROS = vmovq_n_f16(0.F);
 
   for (unsigned int m = 0; m < M; ++m) {
-    for (unsigned int k = 0; k < K8_low; ++k) {
+    for (unsigned int k = 0; k < K8_low; k += 8) {
       vst1q_f16(&Ap[m * K8 + k], vld1q_f16(&A[m * K + k]));
     }
     for (unsigned int k = K8_low; k < K; ++k) {
@@ -105,8 +105,8 @@ void hgemm_padding_A_noTrans_wrt_MK(const __fp16 *A, __fp16 *Ap, unsigned int M,
     }
   }
   for (unsigned int m = M; m < M8; ++m) {
-    for (unsigned int k = K; k < K8; ++k) {
-      Ap[m * K8 + k] = ZEROS;
+    for (unsigned int k = 0; k < K8; k += 8) {
+      vst1q_f16(&Ap[m * K8 + k], ZEROS);
     }
   }
 }
@@ -115,16 +115,15 @@ void hgemm_padding_A_Trans_wrt_M(const __fp16 *A, __fp16 *Ap, unsigned int M,
                                  unsigned int K, unsigned int M8,
                                  unsigned int K8) {
   const unsigned int M8_low = (M >> 3) << 3;
-
   for (unsigned int k = 0; k < K; ++k) {
     for (unsigned int m = 0; m < M8_low; m += 8) {
-      vst1q_f16(&Ap[k * M + m], vld1q_f16(&A[k * M + m]));
+      vst1q_f16(&Ap[k * M8 + m], vld1q_f16(&A[k * M + m]));
     }
     for (unsigned int m = M8_low; m < M; ++m) {
-      Ap[k * M + m] = A[k * M + m];
+      Ap[k * M8 + m] = A[k * M + m];
     }
     for (unsigned int m = M; m < M8; ++m) {
-      Ap[k * M + m] = 0.F;
+      Ap[k * M8 + m] = 0.F;
     }
   }
 }
@@ -132,11 +131,38 @@ void hgemm_padding_A_Trans_wrt_M(const __fp16 *A, __fp16 *Ap, unsigned int M,
 void hgemm_padding_A_Trans_wrt_K(const __fp16 *A, __fp16 *Ap, unsigned int M,
                                  unsigned int K, unsigned int M8,
                                  unsigned int K8) {
-  std::cerr << "Error : hgemm_padding_A_Trans_wrt_K NYI!\n";
+  float16x8_t ZEROS = vmovq_n_f16(0.F);
+  for (unsigned int k = 0; k < K; ++k) {
+    for (unsigned int m = 0; m < M; m += 8) {
+      vst1q_f16(&Ap[k * M8 + m], vld1q_f16(&A[k * M + m]));
+    }
+  }
+  for (unsigned int k = K; k < K8; ++k) {
+    for (unsigned int m = 0; m < M; m += 8) {
+      vst1q_f16(&Ap[k * M8 + m], ZEROS);
+    }
+  }
 }
 
 void hgemm_padding_A_Trans_wrt_MK(const __fp16 *A, __fp16 *Ap, unsigned int M,
                                   unsigned int K, unsigned int M8,
                                   unsigned int K8) {
-  std::cerr << "Error : hgemm_padding_A_Trans_wrt_MK NYI!\n";
+  float16x8_t ZEROS = vmovq_n_f16(0.F);
+  const unsigned int M8_low = (M >> 3) << 3;
+  for (unsigned int k = 0; k < K; ++k) {
+    for (unsigned int m = 0; m < M8_low; m += 8) {
+      vst1q_f16(&Ap[k * M8 + m], vld1q_f16(&A[k * M + m]));
+    }
+    for (unsigned int m = M8_low; m < M; ++m) {
+      Ap[k * M8 + m] = A[k * M + m];
+    }
+    for (unsigned int m = M; m < M8; ++m) {
+      Ap[k * M8 + m] = 0.F;
+    }
+  }
+  for (unsigned int k = K; k < K8; ++k) {
+    for (unsigned int m = 0; m < M8; m += 8) {
+      vst1q_f16(&Ap[k * M8 + m], ZEROS);
+    }
+  }
 }
index f0ef1052b2ef6128ebf796b308dd6ee72e44b6f6..3eb437b3cf565c7e8d10e0863c0ad133a43474a7 100644 (file)
@@ -20,9 +20,9 @@ void hgemm_padding_B(const __fp16 *B, __fp16 *Bp, unsigned int K,
                      unsigned int N, unsigned int K8, unsigned int N16,
                      bool transB) {
   if (transB) {
-    hgemm_padding_B_Trans(B, Bp, K, N, K8, N16);
+    return hgemm_padding_B_Trans(B, Bp, K, N, K8, N16);
   } else {
-    hgemm_padding_B_noTrans(B, Bp, K, N, K8, N16);
+    return hgemm_padding_B_noTrans(B, Bp, K, N, K8, N16);
   }
 }
 
@@ -30,11 +30,11 @@ void hgemm_padding_B_noTrans(const __fp16 *B, __fp16 *Bp, unsigned int K,
                              unsigned int N, unsigned int K8,
                              unsigned int N16) {
   if (K != K8 && N != N16) {
-    hgemm_padding_B_noTrans_wrt_KN(B, Bp, K, N, K8, N16);
+    return hgemm_padding_B_noTrans_wrt_KN(B, Bp, K, N, K8, N16);
   } else if (K != K8) {
-    hgemm_padding_B_noTrans_wrt_K(B, Bp, K, N, K8, N16);
+    return hgemm_padding_B_noTrans_wrt_K(B, Bp, K, N, K8, N16);
   } else if (N != N16) {
-    hgemm_padding_B_noTrans_wrt_N(B, Bp, K, N, K8, N16);
+    return hgemm_padding_B_noTrans_wrt_N(B, Bp, K, N, K8, N16);
   } else {
     std::cerr << "Error : No room for matrix B padding\n";
   }
@@ -43,11 +43,11 @@ void hgemm_padding_B_noTrans(const __fp16 *B, __fp16 *Bp, unsigned int K,
 void hgemm_padding_B_Trans(const __fp16 *B, __fp16 *Bp, unsigned int K,
                            unsigned int N, unsigned int K8, unsigned int N16) {
   if (K != K8 && N != N16) {
-    hgemm_padding_B_Trans_wrt_KN(B, Bp, K, N, K8, N16);
+    return hgemm_padding_B_Trans_wrt_KN(B, Bp, K, N, K8, N16);
   } else if (K != K8) {
-    hgemm_padding_B_Trans_wrt_K(B, Bp, K, N, K8, N16);
+    return hgemm_padding_B_Trans_wrt_K(B, Bp, K, N, K8, N16);
   } else if (N != N16) {
-    hgemm_padding_B_Trans_wrt_N(B, Bp, K, N, K8, N16);
+    return hgemm_padding_B_Trans_wrt_N(B, Bp, K, N, K8, N16);
   } else {
     std::cerr << "Error : No room for matrix B padding\n";
   }
@@ -56,7 +56,18 @@ void hgemm_padding_B_Trans(const __fp16 *B, __fp16 *Bp, unsigned int K,
 void hgemm_padding_B_noTrans_wrt_N(const __fp16 *B, __fp16 *Bp, unsigned int K,
                                    unsigned int N, unsigned int K8,
                                    unsigned int N16) {
-  std::cerr << "Error : hgemm_padding_B_noTrans_wrt_N NYI!\n";
+  const unsigned int N8_low = (N >> 3) << 3;
+  for (unsigned int k = 0; k < K; ++k) {
+    for (unsigned int n = 0; n < N8_low; n += 8) {
+      vst1q_f16(&Bp[k * N16 + n], vld1q_f16(&B[k * N + n]));
+    }
+    for (unsigned int n = N8_low; n < N; ++n) {
+      Bp[k * N16 + n] = B[k * N + n];
+    }
+    for (unsigned int n = N; n < N16; ++n) {
+      Bp[k * N16 + n] = 0.F;
+    }
+  }
 }
 
 void hgemm_padding_B_noTrans_wrt_K(const __fp16 *B, __fp16 *Bp, unsigned int K,
@@ -79,13 +90,41 @@ void hgemm_padding_B_noTrans_wrt_K(const __fp16 *B, __fp16 *Bp, unsigned int K,
 void hgemm_padding_B_noTrans_wrt_KN(const __fp16 *B, __fp16 *Bp, unsigned int K,
                                     unsigned int N, unsigned int K8,
                                     unsigned int N16) {
-  std::cerr << "Error : hgemm_padding_B_noTrans_wrt_KN NYI!\n";
+  unsigned int N8_low = (N >> 3) << 3;
+  float16x8_t ZEROS = vmovq_n_f16(0.F);
+  for (unsigned int k = 0; k < K; ++k) {
+    for (unsigned int n = 0; n < N8_low; n += 8) {
+      vst1q_f16(&Bp[k * N16 + n], vld1q_f16(&B[k * N + n]));
+    }
+    for (unsigned int n = N8_low; n < N; ++n) {
+      Bp[k * N16 + n] = B[k * N + n];
+    }
+    for (unsigned int n = N; n < N16; ++n) {
+      Bp[k * N16 + n] = 0.F;
+    }
+  }
+  for (unsigned int k = K; k < K8; ++k) {
+    for (unsigned int n = 0; n < N16; n += 8) {
+      vst1q_f16(&Bp[k * N16 + n], ZEROS);
+    }
+  }
 }
 
 void hgemm_padding_B_Trans_wrt_N(const __fp16 *B, __fp16 *Bp, unsigned int K,
                                  unsigned int N, unsigned int K8,
                                  unsigned int N16) {
-  std::cerr << "Error : hgemm_padding_B_Trans_wrt_N NYI!\n";
+  float16x8_t ZEROS = vmovq_n_f16(0.F);
+
+  for (unsigned int n = 0; n < N; ++n) {
+    for (unsigned int k = 0; k < K; k += 8) {
+      vst1q_f16(&Bp[n * K8 + k], vld1q_f16(&B[n * K + k]));
+    }
+  }
+  for (unsigned int n = N; n < N16; ++n) {
+    for (unsigned int k = 0; k < K; k += 8) {
+      vst1q_f16(&Bp[n * K8 + k], ZEROS);
+    }
+  }
 }
 
 void hgemm_padding_B_Trans_wrt_K(const __fp16 *B, __fp16 *Bp, unsigned int K,
@@ -110,5 +149,23 @@ void hgemm_padding_B_Trans_wrt_K(const __fp16 *B, __fp16 *Bp, unsigned int K,
 void hgemm_padding_B_Trans_wrt_KN(const __fp16 *B, __fp16 *Bp, unsigned int K,
                                   unsigned int N, unsigned int K8,
                                   unsigned int N16) {
-  std::cerr << "Error : hgemm_padding_B_Trans_wrt_KN NYI!\n";
+  unsigned int K8_low = (K >> 3) << 3;
+  float16x8_t ZEROS = vmovq_n_f16(0.F);
+
+  for (unsigned int n = 0; n < N; ++n) {
+    for (unsigned int k = 0; k < K8_low; k += 8) {
+      vst1q_f16(&Bp[n * K8 + k], vld1q_f16(&B[n * K + k]));
+    }
+    for (unsigned int k = K8_low; k < K; ++k) {
+      Bp[n * K8 + k] = B[n * K + k];
+    }
+    for (unsigned int k = K; k < K8; ++k) {
+      Bp[n * K8 + k] = 0.F;
+    }
+  }
+  for (unsigned int n = N; n < N16; ++n) {
+    for (unsigned int k = 0; k < K8; k += 8) {
+      vst1q_f16(&Bp[n * K8 + k], ZEROS);
+    }
+  }
 }