[hgemm] Optimizing dimension checks using bitmask
authorDebadri Samaddar <s.debadri@samsung.com>
Thu, 9 May 2024 08:15:22 +0000 (13:45 +0530)
committerMyungJoo Ham <myungjoo.ham@samsung.com>
Mon, 13 May 2024 05:35:30 +0000 (14:35 +0900)
Used bitmasks for dimension checks.
e.g: N % 8 is same as N & 0x7

Signed-off-by: Debadri Samaddar <s.debadri@samsung.com>
nntrainer/tensor/hgemm/hgemm.cpp

index c817040332997e2f4b860c1baf6f43d197f58d66..faf6d21cbee89d0df4bd94835651af2b9a91099f 100644 (file)
 void hgemm_noTrans(const __fp16 *A, const __fp16 *B, float *C32, unsigned int M,
                    unsigned int N, unsigned int K, float alpha, float beta) {
   if (alpha == 1.F && beta == 0.F) {
-    if (M % 8 == 0 && N % 16 == 0 && K % 8 == 0) {
+    // used bitwise operator instead of modulo for performance
+    // e.g (M % 8) is same as (M & 0x7) which will extract last 3 bits of M
+    if ((M & 0x7) == 0 && (N & 0xF) == 0 && (K & 0x7) == 0) {
       hgemm_noTrans_8x16(M, N, K, A, K, B, N, C32, N, alpha, beta);
-    } else if (M % 8 == 0 && N % 8 == 0 && K % 8 == 0) {
+    } else if ((M & 0x7) == 0 && (N & 0x7) == 0 && (K & 0x7) == 0) {
       hgemm_noTrans_8x8(M, N, K, A, K, B, N, C32, N, alpha, beta);
-    } else if (M % 4 == 0 && N % 8 == 0 && K % 4 == 0) {
+    } else if ((M & 0x3) == 0 && (N & 0x7) == 0 && (K & 0x3) == 0) {
       hgemm_noTrans_4x8(M, N, K, A, K, B, N, C32, N, alpha, beta);
-    } else if (K % 8 == 0 && N % 8 == 0) {
+    } else if ((K & 0x7) == 0 && (N & 0x7) == 0) {
       hgemm_noTrans_1x8(M, N, K, A, K, B, N, C32, N, alpha, beta);
-    } else if (K % 8 == 0 && N % 4 == 0) {
+    } else if ((K & 0x7) == 0 && (N & 0x3) == 0) {
       hgemm_noTrans_1x4(M, N, K, A, K, B, N, C32, N, alpha, beta);
     } else {
       hgemm_noTrans_fallback(M, N, K, A, K, B, N, C32, N, alpha, beta);
@@ -52,17 +54,19 @@ void hgemm_noTrans(const __fp16 *A, const __fp16 *B, float *C32, unsigned int M,
 void hgemm_noTrans(const __fp16 *A, const __fp16 *B, __fp16 *C, unsigned int M,
                    unsigned int N, unsigned int K, float alpha, float beta) {
   if (alpha == 1.F && beta == 0.F) {
-    if (M % 8 == 0 && N % 16 == 0 && K % 8 == 0) {
+    // used bitwise operator instead of modulo for performance
+    // e.g (M % 8) is same as (M & 0x7) which will extract last 3 bits of M
+    if ((M & 0x7) == 0 && (N & 0xF) == 0 && (K & 0x7) == 0) {
       hgemm_noTrans_8x16(M, N, K, A, K, B, N, C, N, alpha, beta);
-    } else if (M % 8 == 0 && N % 8 == 0 && K % 8 == 0) {
+    } else if ((M & 0x7) == 0 && (N & 0x7) == 0 && (K & 0x7) == 0) {
       hgemm_noTrans_8x8(M, N, K, A, K, B, N, C, N, alpha, beta);
-    } else if (M % 4 == 0 && N % 8 == 0 && K % 4 == 0) {
+    } else if ((M & 0x3) == 0 && (N & 0x7) == 0 && (K & 0x3) == 0) {
       hgemm_noTrans_4x8(M, N, K, A, K, B, N, C, N, alpha, beta);
-    } else if (M % 4 == 0 && N % 4 == 0 && K % 4 == 0) {
+    } else if ((M & 0x3) == 0 && (N & 0x3) == 0 && (K & 0x3) == 0) {
       hgemm_noTrans_4x4(M, N, K, A, K, B, N, C, N, alpha, beta);
-    } else if (K % 8 == 0 && N % 8 == 0) {
+    } else if ((K & 0x7) == 0 && (N & 0x7) == 0) {
       hgemm_noTrans_1x8(M, N, K, A, K, B, N, C, N, alpha, beta);
-    } else if (K % 8 == 0 && N % 4 == 0) {
+    } else if ((K & 0x7) == 0 && (N & 0x3) == 0) {
       hgemm_noTrans_1x4(M, N, K, A, K, B, N, C, N, alpha, beta);
     }
   }