[blas/opencl] SGEMM OpenCL kernels added
authorDebadri Samaddar <s.debadri@samsung.com>
Thu, 20 Jun 2024 10:28:02 +0000 (15:58 +0530)
committerJijoong Moon <jijoong.moon@samsung.com>
Thu, 4 Jul 2024 06:53:44 +0000 (15:53 +0900)
Added all possible OpenCL kernels for SGEMM
Added unit tests

Signed-off-by: Debadri Samaddar <s.debadri@samsung.com>
nntrainer/layers/layer_context.cpp
nntrainer/layers/layer_context.h
nntrainer/tensor/cl_operations/blas_kernel_interface.cpp
nntrainer/tensor/cl_operations/blas_kernels.cpp
nntrainer/tensor/cl_operations/blas_kernels.h
nntrainer/tensor/cl_operations/blas_kernels_fp16.cpp
test/unittest/unittest_blas_kernels_cl.cpp

index c534165b48a7efc8d558a8c79d5e604f32fd7b08..532634ffc906b9aa8abe21b48ce0fb6aa83bd269 100644 (file)
@@ -681,16 +681,28 @@ std::string RunLayerContext::getKernelName(LayerKernel layerKernel) {
   switch (layerKernel) {
   case LayerKernel::SGEMV:
     return "sgemv_cl";
-  case LayerKernel::DOT:
-    return "dot_cl";
-  case LayerKernel::SGEMM:
-    return "sgemm_cl";
   case LayerKernel::SGEMV_FP16:
     return "sgemv_cl_fp16";
+  case LayerKernel::DOT:
+    return "dot_cl";
   case LayerKernel::DOT_FP16:
     return "dot_cl_fp16";
-  case LayerKernel::SGEMM_FP16:
-    return "sgemm_cl_fp16";
+  case LayerKernel::SGEMM_NOTRANS:
+    return "sgemm_cl_noTrans";
+  case LayerKernel::SGEMM_NOTRANS_FP16:
+    return "sgemm_cl_noTrans_fp16";
+  case LayerKernel::SGEMM_TRANSA:
+    return "sgemm_cl_transA";
+  case LayerKernel::SGEMM_TRANSA_FP16:
+    return "sgemm_cl_transA_fp16";
+  case LayerKernel::SGEMM_TRANSB:
+    return "sgemm_cl_transB";
+  case LayerKernel::SGEMM_TRANSB_FP16:
+    return "sgemm_cl_transB_fp16";
+  case LayerKernel::SGEMM_TRANSAB:
+    return "sgemm_cl_transAB";
+  case LayerKernel::SGEMM_TRANSAB_FP16:
+    return "sgemm_cl_transAB_fp16";
   case LayerKernel::ADD:
     return "addition_cl";
   case LayerKernel::ADD_FP16:
index 18fec0e8deb1be4c9feb7de715f1b5625458a42e..a37f16aca52d8040efb22c3c71ef1ca1ffa8e51a 100644 (file)
@@ -830,20 +830,26 @@ public:
    * getKernelName function.
    */
   enum LayerKernel {
-    SGEMV = 1 << 0,       /**< placeholder for kernel name */
-    DOT = 1 << 1,         /**< placeholder for kernel name */
-    SGEMM = 1 << 2,       /**< placeholder for kernel name */
-    SGEMV_FP16 = 1 << 3,  /**< placeholder for kernel name */
-    DOT_FP16 = 1 << 4,    /**< placeholder for kernel name */
-    SGEMM_FP16 = 1 << 5,  /**< placeholder for kernel name */
-    ADD = 1 << 6,         /**< placeholder for kernel name */
-    ADD_FP16 = 1 << 7,    /**< placeholder for kernel name */
-    SWIGLU = 1 << 8,      /**< placeholder for kernel name */
-    SWIGLU_FP16 = 1 << 9, /**< placeholder for kernel name */
-    SSCAL = 1 << 10,      /**< placeholder for kernel name */
-    SSCAL_FP16 = 1 << 11, /**< placeholder for kernel name */
-    COPY = 1 << 12,       /**< placeholder for kernel name */
-    COPY_FP16 = 1 << 13   /**< placeholder for kernel name */
+    SGEMV = 1 << 0,               /**< placeholder for kernel name */
+    SGEMV_FP16 = 1 << 1,          /**< placeholder for kernel name */
+    DOT = 1 << 2,                 /**< placeholder for kernel name */
+    DOT_FP16 = 1 << 3,            /**< placeholder for kernel name */
+    SGEMM_NOTRANS = 1 << 4,       /**< placeholder for kernel name */
+    SGEMM_NOTRANS_FP16 = 1 << 5,  /**< placeholder for kernel name */
+    SGEMM_TRANSA = 1 << 6,        /**< placeholder for kernel name */
+    SGEMM_TRANSA_FP16 = 1 << 7,   /**< placeholder for kernel name */
+    SGEMM_TRANSB = 1 << 8,        /**< placeholder for kernel name */
+    SGEMM_TRANSB_FP16 = 1 << 9,   /**< placeholder for kernel name */
+    SGEMM_TRANSAB = 1 << 10,      /**< placeholder for kernel name */
+    SGEMM_TRANSAB_FP16 = 1 << 11, /**< placeholder for kernel name */
+    ADD = 1 << 12,                /**< placeholder for kernel name */
+    ADD_FP16 = 1 << 13,           /**< placeholder for kernel name */
+    SWIGLU = 1 << 14,             /**< placeholder for kernel name */
+    SWIGLU_FP16 = 1 << 15,        /**< placeholder for kernel name */
+    SSCAL = 1 << 16,              /**< placeholder for kernel name */
+    SSCAL_FP16 = 1 << 17,         /**< placeholder for kernel name */
+    COPY = 1 << 18,               /**< placeholder for kernel name */
+    COPY_FP16 = 1 << 19,          /**< placeholder for kernel name */
   };
 
   /**
index 9abeca9b781d6d15e167fa2a7c03f65a63462508..9e8422d404f5b67d9799dfc694052cd4847deaed 100644 (file)
@@ -147,9 +147,8 @@ void dotCl(Tensor const &input, Tensor const &m, Tensor &result,
     }
     /// case others: use gemm
     else {
-      // transA == false, transB == false
-      sgemm_cl(data, mdata, rdata, M, N, K, lda, ldb, ldc, context);
-      // todo: other condition implementations
+      sgemm_cl(transA, transB, data, mdata, rdata, M, N, K, lda, ldb, ldc,
+               context);
     }
   } else if (input.getDataType() == ml::train::TensorDim::DataType::FP16) {
 #ifdef ENABLE_FP16
@@ -184,9 +183,8 @@ void dotCl(Tensor const &input, Tensor const &m, Tensor &result,
     }
     /// case others: use sgemm
     else {
-      // transA == false, transB == false
-      sgemm_cl(data, mdata, rdata, M, N, K, lda, ldb, ldc, context);
-      // todo: other condition implementations
+      sgemm_cl(transA, transB, data, mdata, rdata, M, N, K, lda, ldb, ldc,
+               context);
     }
 #else
     throw std::invalid_argument("Error: enable-fp16 is not enabled");
index 3d459232dcdcb3d41cb8119272ddbfe6c0f20e84..791cdc5e6b8c6e450314081b145e934d73ceebd1 100644 (file)
@@ -35,8 +35,8 @@ std::string dot_cl_kernel_ =
         }
     })";
 
-std::string sgemm_cl_kernel_ =
-  R"(__kernel void sgemm_cl(const __global float* A, const __global float* B,
+std::string sgemm_cl_noTrans_kernel_ =
+  R"(__kernel void sgemm_cl_noTrans(const __global float* A, const __global float* B,
                       __global float* C, unsigned int K, unsigned int lda, unsigned int ldb, unsigned int ldc) {
         
         unsigned int m = get_global_id(0);
@@ -51,6 +51,58 @@ std::string sgemm_cl_kernel_ =
         C[m * ldc + n] = c;
     })";
 
+std::string sgemm_cl_transA_kernel_ =
+  R"(__kernel void sgemm_cl_transA(const __global float* A, const __global float* B,
+                      __global float* C, unsigned int K, unsigned int lda, unsigned int ldb, unsigned int ldc) {
+        
+        unsigned int m = get_global_id(0);
+        unsigned int n = get_global_id(1);
+        float c = 0.0f;
+        for (unsigned int k = 0; k < K; ++k) {
+          float a, b;
+          a = A[k * lda + m];
+          b = B[k * ldb + n];
+          c += a * b;
+        }
+        C[m * ldc + n] = c;
+    })";
+
+std::string sgemm_cl_transB_kernel_ =
+  R"(__kernel void sgemm_cl_transB(const __global float *A, const __global float *B,
+                              __global float *C, unsigned int K,
+                              unsigned int lda, unsigned int ldb,
+                              unsigned int ldc) {
+
+        unsigned int m = get_global_id(0);
+        unsigned int n = get_global_id(1);
+        float c = 0.0f;
+        for (unsigned int k = 0; k < K; ++k) {
+          float a, b;
+          a = A[m * lda + k];
+          b = B[n * ldb + k];
+          c += a * b;
+        }
+        C[m * ldc + n] = c;
+    })";
+
+std::string sgemm_cl_transAB_kernel_ =
+  R"(__kernel void sgemm_cl_transAB(const __global float *A, const __global float *B,
+                               __global float *C, unsigned int K,
+                               unsigned int lda, unsigned int ldb,
+                               unsigned int ldc) {
+
+        unsigned int m = get_global_id(0);
+        unsigned int n = get_global_id(1);
+        float c = 0.0f;
+        for (unsigned int k = 0; k < K; ++k) {
+          float a, b;
+          a = A[k * lda + m];
+          b = B[n * ldb + k];
+          c += a * b;
+        }
+        C[m * ldc + n] = c;
+    })";
+
 std::string addition_cl_kernel_ =
   R"(__kernel void addition_cl(__global const float* input, __global float* output, const unsigned int size) {
     #pragma printf_support
@@ -71,7 +123,10 @@ std::string sscal_cl_kernel_ =
  * @brief defining global kernel objects
  */
 opencl::Kernel kernel_sgemv;
-opencl::Kernel kernel_sgemm;
+opencl::Kernel kernel_sgemm_transAB;
+opencl::Kernel kernel_sgemm_transA;
+opencl::Kernel kernel_sgemm_transB;
+opencl::Kernel kernel_sgemm_noTrans;
 opencl::Kernel kernel_dot;
 opencl::Kernel kernel_addition;
 opencl::Kernel kernel_sscal;
@@ -227,19 +282,43 @@ float dot_cl(const float *vecAdata, const float *vecXdata, unsigned int dim1,
   return cl_ret;
 }
 
-void sgemm_cl(const float *A, const float *B, float *C, unsigned int M,
-              unsigned int N, unsigned int K, unsigned int lda,
-              unsigned int ldb, unsigned int ldc, RunLayerContext &context) {
+void sgemm_cl(CBLAS_TRANSPOSE TransA, CBLAS_TRANSPOSE TransB, const float *A,
+              const float *B, float *C, unsigned int M, unsigned int N,
+              unsigned int K, unsigned int lda, unsigned int ldb,
+              unsigned int ldc, RunLayerContext &context) {
+
+  opencl::Kernel *kernel_sgemm = nullptr;
+  RunLayerContext::LayerKernel layerKernel;
+  std::string sgemm_cl_kernel_;
+
+  if (TransA != CblasTrans && TransB != CblasTrans) {
+    kernel_sgemm = &kernel_sgemm_noTrans;
+    layerKernel = context.LayerKernel::SGEMM_NOTRANS;
+    sgemm_cl_kernel_ = sgemm_cl_noTrans_kernel_;
+  } else if (TransA == CblasTrans && TransB != CblasTrans) {
+    kernel_sgemm = &kernel_sgemm_transA;
+    layerKernel = context.LayerKernel::SGEMM_TRANSA;
+    sgemm_cl_kernel_ = sgemm_cl_transA_kernel_;
+  } else if (TransA != CblasTrans && TransB == CblasTrans) {
+    kernel_sgemm = &kernel_sgemm_transB;
+    layerKernel = context.LayerKernel::SGEMM_TRANSB;
+    sgemm_cl_kernel_ = sgemm_cl_transB_kernel_;
+  } else {
+    kernel_sgemm = &kernel_sgemm_transAB;
+    layerKernel = context.LayerKernel::SGEMM_TRANSAB;
+    sgemm_cl_kernel_ = sgemm_cl_transAB_kernel_;
+  }
 
   bool result = false;
 
   do {
-    result = context.clCreateKernel(sgemm_cl_kernel_,
-                                    context.LayerKernel::SGEMM, kernel_sgemm);
+    result =
+      context.clCreateKernel(sgemm_cl_kernel_, layerKernel, *kernel_sgemm);
     if (!result) {
       break;
     }
 
+    // sizes will be same for transpose
     size_t m_k_size = M * K * sizeof(float);
     size_t k_n_size = K * N * sizeof(float);
     size_t m_n_size = M * N * sizeof(float);
@@ -265,37 +344,37 @@ void sgemm_cl(const float *A, const float *B, float *C, unsigned int M,
       break;
     }
 
-    result = kernel_sgemm.SetKernelArguments(0, &inputA, sizeof(cl_mem));
+    result = kernel_sgemm->SetKernelArguments(0, &inputA, sizeof(cl_mem));
     if (!result) {
       break;
     }
 
-    result = kernel_sgemm.SetKernelArguments(1, &inputB, sizeof(cl_mem));
+    result = kernel_sgemm->SetKernelArguments(1, &inputB, sizeof(cl_mem));
     if (!result) {
       break;
     }
 
-    result = kernel_sgemm.SetKernelArguments(2, &inOutC, sizeof(cl_mem));
+    result = kernel_sgemm->SetKernelArguments(2, &inOutC, sizeof(cl_mem));
     if (!result) {
       break;
     }
 
-    result = kernel_sgemm.SetKernelArguments(3, &K, sizeof(int));
+    result = kernel_sgemm->SetKernelArguments(3, &K, sizeof(int));
     if (!result) {
       break;
     }
 
-    result = kernel_sgemm.SetKernelArguments(4, &lda, sizeof(int));
+    result = kernel_sgemm->SetKernelArguments(4, &lda, sizeof(int));
     if (!result) {
       break;
     }
 
-    result = kernel_sgemm.SetKernelArguments(5, &ldb, sizeof(int));
+    result = kernel_sgemm->SetKernelArguments(5, &ldb, sizeof(int));
     if (!result) {
       break;
     }
 
-    result = kernel_sgemm.SetKernelArguments(6, &ldc, sizeof(int));
+    result = kernel_sgemm->SetKernelArguments(6, &ldc, sizeof(int));
     if (!result) {
       break;
     }
@@ -304,7 +383,7 @@ void sgemm_cl(const float *A, const float *B, float *C, unsigned int M,
     const int work_group_size[3] = {32, 32, 1}; // test-value
 
     result = context.command_queue_inst_.DispatchCommand(
-      kernel_sgemm, work_groups_count, work_group_size);
+      *kernel_sgemm, work_groups_count, work_group_size);
     if (!result) {
       break;
     }
index 3ae4ae97b387c24d77f32a4a767640cafc5b8db9..6b118c68dd0b92715a11ed4c46adfaa2a92c3f3d 100644 (file)
@@ -25,7 +25,10 @@ namespace nntrainer {
  * @brief declaring global kernel objects
  */
 extern opencl::Kernel kernel_sgemv;
-extern opencl::Kernel kernel_sgemm;
+extern opencl::Kernel kernel_sgemm_noTrans;
+extern opencl::Kernel kernel_sgemm_transAB;
+extern opencl::Kernel kernel_sgemm_transA;
+extern opencl::Kernel kernel_sgemm_transB;
 extern opencl::Kernel kernel_dot;
 extern opencl::Kernel kernel_addition;
 extern opencl::Kernel kernel_sscal;
@@ -58,6 +61,8 @@ float dot_cl(const float *vecAdata, const float *vecXdata, unsigned int dim1,
 /**
  * @brief     sgemm computation : Y = op(A)*op(B) + C,
  * where op(X) is one of X or X**T
+ * @param[in] transA CBLAS_TRANSPOSE
+ * @param[in] transB CBLAS_TRANSPOSE
  * @param[in] A float * for Matrix A
  * @param[in] B float * for Matrix B
  * @param[in] C float * for Matrix C
@@ -69,9 +74,10 @@ float dot_cl(const float *vecAdata, const float *vecXdata, unsigned int dim1,
  * @param[in] ldc number of C's columns
  * @param[in] context RunLayerContext reference
  */
-void sgemm_cl(const float *A, const float *B, float *C, unsigned int M,
-              unsigned int N, unsigned int K, unsigned int lda,
-              unsigned int ldb, unsigned int ldc, RunLayerContext &context);
+void sgemm_cl(CBLAS_TRANSPOSE TransA, CBLAS_TRANSPOSE TransB, const float *A,
+              const float *B, float *C, unsigned int M, unsigned int N,
+              unsigned int K, unsigned int lda, unsigned int ldb,
+              unsigned int ldc, RunLayerContext &context);
 
 /**
  * @brief     addition : sum of all input vectors
@@ -98,7 +104,10 @@ void sscal_cl(float *X, const unsigned int N, const float alpha,
  * @brief declaring global fp16 kernel objects
  */
 extern opencl::Kernel kernel_sgemv_fp16;
-extern opencl::Kernel kernel_sgemm_fp16;
+extern opencl::Kernel kernel_sgemm_noTrans_fp16;
+extern opencl::Kernel kernel_sgemm_transAB_fp16;
+extern opencl::Kernel kernel_sgemm_transA_fp16;
+extern opencl::Kernel kernel_sgemm_transB_fp16;
 extern opencl::Kernel kernel_dot_fp16;
 extern opencl::Kernel kernel_addition_fp16;
 extern opencl::Kernel kernel_sscal_fp16;
@@ -131,6 +140,8 @@ __fp16 dot_cl(const __fp16 *vecAdata, const __fp16 *vecXdata, unsigned int dim1,
 /**
  * @brief     fp16 sgemm computation : Y = op(A)*op(B) + C,
  * where op(X) is one of X or X**T
+ * @param[in] transA CBLAS_TRANSPOSE
+ * @param[in] transB CBLAS_TRANSPOSE
  * @param[in] A fp16 * for Matrix A
  * @param[in] B fp16 * for Matrix B
  * @param[in] C fp16 * for Matrix C
@@ -142,9 +153,10 @@ __fp16 dot_cl(const __fp16 *vecAdata, const __fp16 *vecXdata, unsigned int dim1,
  * @param[in] ldc number of C's columns
  * @param[in] context RunLayerContext reference
  */
-void sgemm_cl(const __fp16 *A, const __fp16 *B, __fp16 *C, unsigned int M,
-              unsigned int N, unsigned int K, unsigned int lda,
-              unsigned int ldb, unsigned int ldc, RunLayerContext &context);
+void sgemm_cl(CBLAS_TRANSPOSE TransA, CBLAS_TRANSPOSE TransB, const __fp16 *A,
+              const __fp16 *B, __fp16 *C, unsigned int M, unsigned int N,
+              unsigned int K, unsigned int lda, unsigned int ldb,
+              unsigned int ldc, RunLayerContext &context);
 
 /**
  * @brief     fp16 addition : sum of all input vectors
index 83f0d2136b934a458440f17145b6ea1b39a08559..96c7ce9c90ac9a8bf6c6f72366f70984c59ca9ed 100644 (file)
@@ -41,11 +41,11 @@ std::string dot_cl_kernel_fp16_ =
         }
     })";
 
-std::string sgemm_cl_kernel_fp16_ =
+std::string sgemm_cl_noTrans_kernel_fp16_ =
   R"(
     #pragma OPENCL EXTENSION cl_khr_fp16 : enable
 
-    __kernel void sgemm_cl_fp16(const __global half* A, const __global half* B,
+    __kernel void sgemm_cl_noTrans_fp16(const __global half* A, const __global half* B,
                       __global half* C, unsigned int K, unsigned int lda, unsigned int ldb, unsigned int ldc) {
         
         unsigned int m = get_global_id(0);
@@ -60,6 +60,63 @@ std::string sgemm_cl_kernel_fp16_ =
         C[m * ldc + n] = c;
     })";
 
+std::string sgemm_cl_transA_kernel_fp16_ =
+  R"(
+    #pragma OPENCL EXTENSION cl_khr_fp16 : enable
+
+    __kernel void sgemm_cl_transA_fp16(const __global half* A, const __global half* B,
+                      __global half* C, unsigned int K, unsigned int lda, unsigned int ldb, unsigned int ldc) {
+        
+        unsigned int m = get_global_id(0);
+        unsigned int n = get_global_id(1);
+        half c = 0.0f;
+        for (unsigned int k = 0; k < K; ++k) {
+          half a, b;
+          a = A[k * lda + m];
+          b = B[k * ldb + n];
+          c += a * b;
+        }
+        C[m * ldc + n] = c;
+    })";
+
+std::string sgemm_cl_transB_kernel_fp16_ =
+  R"(
+    #pragma OPENCL EXTENSION cl_khr_fp16 : enable
+
+    __kernel void sgemm_cl_transB_fp16(const __global half* A, const __global half* B,
+                      __global half* C, unsigned int K, unsigned int lda, unsigned int ldb, unsigned int ldc) {
+        
+        unsigned int m = get_global_id(0);
+        unsigned int n = get_global_id(1);
+        half c = 0.0f;
+        for (unsigned int k = 0; k < K; ++k) {
+          half a, b;
+          a = A[m * lda + k];
+          b = B[n * ldb + k];
+          c += a * b;
+        }
+        C[m * ldc + n] = c;
+    })";
+
+std::string sgemm_cl_transAB_kernel_fp16_ =
+  R"(
+    #pragma OPENCL EXTENSION cl_khr_fp16 : enable
+
+    __kernel void sgemm_cl_transAB_fp16(const __global half* A, const __global half* B,
+                      __global half* C, unsigned int K, unsigned int lda, unsigned int ldb, unsigned int ldc) {
+        
+        unsigned int m = get_global_id(0);
+        unsigned int n = get_global_id(1);
+        half c = 0.0f;
+        for (unsigned int k = 0; k < K; ++k) {
+          half a, b;
+          a = A[k * lda + m];
+          b = B[n * ldb + k];
+          c += a * b;
+        }
+        C[m * ldc + n] = c;
+    })";
+
 std::string addition_cl_kernel_fp16_ =
   R"(
     #pragma OPENCL EXTENSION cl_khr_fp16 : enable
@@ -85,7 +142,10 @@ std::string sscal_cl_kernel_fp16_ =
  * @brief defining global kernel objects
  */
 opencl::Kernel kernel_sgemv_fp16;
-opencl::Kernel kernel_sgemm_fp16;
+opencl::Kernel kernel_sgemm_transAB_fp16;
+opencl::Kernel kernel_sgemm_transA_fp16;
+opencl::Kernel kernel_sgemm_transB_fp16;
+opencl::Kernel kernel_sgemm_noTrans_fp16;
 opencl::Kernel kernel_dot_fp16;
 opencl::Kernel kernel_addition_fp16;
 opencl::Kernel kernel_sscal_fp16;
@@ -242,20 +302,43 @@ __fp16 dot_cl(const __fp16 *vecAdata, const __fp16 *vecXdata, unsigned int dim1,
   return cl_ret;
 }
 
-void sgemm_cl(const __fp16 *A, const __fp16 *B, __fp16 *C, unsigned int M,
-              unsigned int N, unsigned int K, unsigned int lda,
-              unsigned int ldb, unsigned int ldc, RunLayerContext &context) {
+void sgemm_cl(CBLAS_TRANSPOSE TransA, CBLAS_TRANSPOSE TransB, const __fp16 *A,
+              const __fp16 *B, __fp16 *C, unsigned int M, unsigned int N,
+              unsigned int K, unsigned int lda, unsigned int ldb,
+              unsigned int ldc, RunLayerContext &context) {
+
+  opencl::Kernel *kernel_sgemm_fp16 = nullptr;
+  RunLayerContext::LayerKernel layerKernel;
+  std::string sgemm_cl_kernel_fp16_;
+
+  if (TransA != CblasTrans && TransB != CblasTrans) {
+    kernel_sgemm_fp16 = &kernel_sgemm_noTrans_fp16;
+    layerKernel = context.LayerKernel::SGEMM_NOTRANS_FP16;
+    sgemm_cl_kernel_fp16_ = sgemm_cl_noTrans_kernel_fp16_;
+  } else if (TransA == CblasTrans && TransB != CblasTrans) {
+    kernel_sgemm_fp16 = &kernel_sgemm_transA_fp16;
+    layerKernel = context.LayerKernel::SGEMM_TRANSA_FP16;
+    sgemm_cl_kernel_fp16_ = sgemm_cl_transA_kernel_fp16_;
+  } else if (TransA != CblasTrans && TransB == CblasTrans) {
+    kernel_sgemm_fp16 = &kernel_sgemm_transB_fp16;
+    layerKernel = context.LayerKernel::SGEMM_TRANSB_FP16;
+    sgemm_cl_kernel_fp16_ = sgemm_cl_transB_kernel_fp16_;
+  } else {
+    kernel_sgemm_fp16 = &kernel_sgemm_transAB_fp16;
+    layerKernel = context.LayerKernel::SGEMM_TRANSAB_FP16;
+    sgemm_cl_kernel_fp16_ = sgemm_cl_transAB_kernel_fp16_;
+  }
 
   bool result = false;
 
   do {
-    result = context.clCreateKernel(sgemm_cl_kernel_fp16_,
-                                    context.LayerKernel::SGEMM_FP16,
-                                    kernel_sgemm_fp16);
+    result = context.clCreateKernel(sgemm_cl_kernel_fp16_, layerKernel,
+                                    *kernel_sgemm_fp16);
     if (!result) {
       break;
     }
 
+    // sizes will be same for transpose
     size_t m_k_size = M * K * sizeof(cl_half);
     size_t k_n_size = K * N * sizeof(cl_half);
     size_t m_n_size = M * N * sizeof(cl_half);
@@ -281,37 +364,37 @@ void sgemm_cl(const __fp16 *A, const __fp16 *B, __fp16 *C, unsigned int M,
       break;
     }
 
-    result = kernel_sgemm_fp16.SetKernelArguments(0, &inputA, sizeof(cl_mem));
+    result = kernel_sgemm_fp16->SetKernelArguments(0, &inputA, sizeof(cl_mem));
     if (!result) {
       break;
     }
 
-    result = kernel_sgemm_fp16.SetKernelArguments(1, &inputB, sizeof(cl_mem));
+    result = kernel_sgemm_fp16->SetKernelArguments(1, &inputB, sizeof(cl_mem));
     if (!result) {
       break;
     }
 
-    result = kernel_sgemm_fp16.SetKernelArguments(2, &inOutC, sizeof(cl_mem));
+    result = kernel_sgemm_fp16->SetKernelArguments(2, &inOutC, sizeof(cl_mem));
     if (!result) {
       break;
     }
 
-    result = kernel_sgemm_fp16.SetKernelArguments(3, &K, sizeof(int));
+    result = kernel_sgemm_fp16->SetKernelArguments(3, &K, sizeof(int));
     if (!result) {
       break;
     }
 
-    result = kernel_sgemm_fp16.SetKernelArguments(4, &lda, sizeof(int));
+    result = kernel_sgemm_fp16->SetKernelArguments(4, &lda, sizeof(int));
     if (!result) {
       break;
     }
 
-    result = kernel_sgemm_fp16.SetKernelArguments(5, &ldb, sizeof(int));
+    result = kernel_sgemm_fp16->SetKernelArguments(5, &ldb, sizeof(int));
     if (!result) {
       break;
     }
 
-    result = kernel_sgemm_fp16.SetKernelArguments(6, &ldc, sizeof(int));
+    result = kernel_sgemm_fp16->SetKernelArguments(6, &ldc, sizeof(int));
     if (!result) {
       break;
     }
@@ -320,7 +403,7 @@ void sgemm_cl(const __fp16 *A, const __fp16 *B, __fp16 *C, unsigned int M,
     const int work_group_size[3] = {32, 32, 1}; // test-value
 
     result = context.command_queue_inst_.DispatchCommand(
-      kernel_sgemm_fp16, work_groups_count, work_group_size);
+      *kernel_sgemm_fp16, work_groups_count, work_group_size);
     if (!result) {
       break;
     }
index cac5b9e964eabe2b353801c1c4ae4ce299921f44..d897d69e8db746208268918f4a9aa0885be3fe01 100644 (file)
@@ -44,7 +44,7 @@ TEST(blas_kernels, dotCL_sgemv) {
   int width = 768;
 
   int height_b = 768;
-  int width_b = 96000;
+  int width_b = 2048;
 
   bool transA = false;
   bool transB = false;
@@ -94,7 +94,7 @@ TEST(blas_kernels, dotCL_sgemv_n) {
   int width = 768;
 
   int height_b = 768;
-  int width_b = 96000;
+  int width_b = 2048;
 
   bool transA = true;
   bool transB = false;
@@ -166,6 +166,254 @@ TEST(nntrainer_Tensor, multiply_i) {
   EXPECT_IN_RANGE(cosSimNeon, 0.99, 1);
 }
 
+TEST(nntrainer_Tensor, dot_gemm_50_768_1024_noTrans) {
+  /// @note GEMM : A X B = C
+  RunLayerContext rc = setUpGpuContext();
+
+  int batch = 1;
+  int channel = 1;
+  int height = 50;
+  int width = 768;
+
+  int height_b = 768;
+  int width_b = 1024;
+
+  bool transA = false;
+  bool transB = false;
+
+  const float alpha = 1e-1;
+  const int MOD = 10;
+
+  nntrainer::TensorDim::TensorType t_type_nchw_fp16 = {
+    nntrainer::Tformat::NCHW, nntrainer::Tdatatype::FP16};
+
+  nntrainer::TensorDim::TensorType t_type_nchw_fp32 = {
+    nntrainer::Tformat::NCHW, nntrainer::Tdatatype::FP32};
+
+  nntrainer::Tensor A(batch, channel, height, width, t_type_nchw_fp16);
+  nntrainer::Tensor B(batch, channel, height_b, width_b, t_type_nchw_fp16);
+
+  nntrainer::Tensor A_fp32(batch, channel, height, width, t_type_nchw_fp32);
+  nntrainer::Tensor B_fp32(batch, channel, height_b, width_b, t_type_nchw_fp32);
+
+  GEN_TEST_INPUT(A, ((i * (batch * height * channel) + j * (batch * height) +
+                      k * (width) + l + 1) %
+                     MOD) *
+                      alpha);
+  GEN_TEST_INPUT_B(B, ((i * (batch * height_b * channel) +
+                        j * (batch * height_b) + k * (width_b) + l + 1) %
+                       MOD) *
+                        alpha);
+  GEN_TEST_INPUT(A_fp32, ((i * (batch * height * channel) +
+                           j * (batch * height) + k * (width) + l + 1) %
+                          MOD) *
+                           alpha);
+  GEN_TEST_INPUT_B(B_fp32, ((i * (batch * height_b * channel) +
+                             j * (batch * height_b) + k * (width_b) + l + 1) %
+                            MOD) *
+                             alpha);
+
+  nntrainer::Tensor C = dotCl(A_fp32, B_fp32, rc, transA, transB);
+  nntrainer::Tensor C_fp32 = A_fp32.dot(B_fp32, transA, transB);
+
+  float mseErrorNeon =
+    mse<float>(C.getData<float>(), C_fp32.getData<float>(), C.size());
+
+  double cosSimNeon = cosine_similarity<float>(
+    C.getData<float>(), C_fp32.getData<float>(), C.size());
+
+  const float epsilon = 1e-3 * width;
+
+  EXPECT_IN_RANGE(mseErrorNeon, 0, epsilon);
+  EXPECT_IN_RANGE((float)cosSimNeon, 0.99, 1);
+}
+
+TEST(nntrainer_Tensor, dot_gemm_50_768_2048_transB) {
+  /// @note GEMM : A X B = C
+  RunLayerContext rc = setUpGpuContext();
+
+  int batch = 1;
+  int channel = 1;
+  int height = 50;
+  int width = 768;
+
+  int height_b = 2048;
+  int width_b = 768;
+
+  bool transA = false;
+  bool transB = true;
+
+  const float alpha = 1e-1;
+  const int MOD = 10;
+
+  nntrainer::TensorDim::TensorType t_type_nchw_fp16 = {
+    nntrainer::Tformat::NCHW, nntrainer::Tdatatype::FP16};
+
+  nntrainer::TensorDim::TensorType t_type_nchw_fp32 = {
+    nntrainer::Tformat::NCHW, nntrainer::Tdatatype::FP32};
+
+  nntrainer::Tensor A(batch, channel, height, width, t_type_nchw_fp16);
+  nntrainer::Tensor B(batch, channel, height_b, width_b, t_type_nchw_fp16);
+
+  nntrainer::Tensor A_fp32(batch, channel, height, width, t_type_nchw_fp32);
+  nntrainer::Tensor B_fp32(batch, channel, height_b, width_b, t_type_nchw_fp32);
+
+  GEN_TEST_INPUT(A, ((i * (batch * height * channel) + j * (batch * height) +
+                      k * (width) + l + 1) %
+                     MOD) *
+                      alpha);
+  GEN_TEST_INPUT_B(B, ((i * (batch * height_b * channel) +
+                        j * (batch * height_b) + k * (width_b) + l + 1) %
+                       MOD) *
+                        alpha);
+  GEN_TEST_INPUT(A_fp32, ((i * (batch * height * channel) +
+                           j * (batch * height) + k * (width) + l + 1) %
+                          MOD) *
+                           alpha);
+  GEN_TEST_INPUT_B(B_fp32, ((i * (batch * height_b * channel) +
+                             j * (batch * height_b) + k * (width_b) + l + 1) %
+                            MOD) *
+                             alpha);
+
+  nntrainer::Tensor C = dotCl(A_fp32, B_fp32, rc, transA, transB);
+  nntrainer::Tensor C_fp32 = A_fp32.dot(B_fp32, transA, transB);
+
+  float mseErrorNeon =
+    mse<float>(C.getData<float>(), C_fp32.getData<float>(), C.size());
+
+  double cosSimNeon = cosine_similarity<float>(
+    C.getData<float>(), C_fp32.getData<float>(), C.size());
+
+  const float epsilon = 1e-3 * width;
+
+  EXPECT_IN_RANGE(mseErrorNeon, 0, epsilon);
+  EXPECT_IN_RANGE((float)cosSimNeon, 0.99, 1);
+}
+
+TEST(nntrainer_Tensor, dot_gemm_50_768_1024_transA) {
+  /// @note GEMM : A X B = C
+  RunLayerContext rc = setUpGpuContext();
+
+  int batch = 1;
+  int channel = 1;
+  int height = 768;
+  int width = 50;
+
+  int height_b = 768;
+  int width_b = 1024;
+
+  bool transA = true;
+  bool transB = false;
+
+  const float alpha = 1e-1;
+  const int MOD = 10;
+
+  nntrainer::TensorDim::TensorType t_type_nchw_fp16 = {
+    nntrainer::Tformat::NCHW, nntrainer::Tdatatype::FP16};
+
+  nntrainer::TensorDim::TensorType t_type_nchw_fp32 = {
+    nntrainer::Tformat::NCHW, nntrainer::Tdatatype::FP32};
+
+  nntrainer::Tensor A(batch, channel, height, width, t_type_nchw_fp16);
+  nntrainer::Tensor B(batch, channel, height_b, width_b, t_type_nchw_fp16);
+
+  nntrainer::Tensor A_fp32(batch, channel, height, width, t_type_nchw_fp32);
+  nntrainer::Tensor B_fp32(batch, channel, height_b, width_b, t_type_nchw_fp32);
+
+  GEN_TEST_INPUT(A, ((i * (batch * height * channel) + j * (batch * height) +
+                      k * (width) + l + 1) %
+                     MOD) *
+                      alpha);
+  GEN_TEST_INPUT_B(B, ((i * (batch * height_b * channel) +
+                        j * (batch * height_b) + k * (width_b) + l + 1) %
+                       MOD) *
+                        alpha);
+  GEN_TEST_INPUT(A_fp32, ((i * (batch * height * channel) +
+                           j * (batch * height) + k * (width) + l + 1) %
+                          MOD) *
+                           alpha);
+  GEN_TEST_INPUT_B(B_fp32, ((i * (batch * height_b * channel) +
+                             j * (batch * height_b) + k * (width_b) + l + 1) %
+                            MOD) *
+                             alpha);
+
+  nntrainer::Tensor C = dotCl(A_fp32, B_fp32, rc, transA, transB);
+  nntrainer::Tensor C_fp32 = A_fp32.dot(B_fp32, transA, transB);
+
+  float mseErrorNeon =
+    mse<float>(C.getData<float>(), C_fp32.getData<float>(), C.size());
+
+  double cosSimNeon = cosine_similarity<float>(
+    C.getData<float>(), C_fp32.getData<float>(), C.size());
+
+  const float epsilon = 1e-3 * width;
+
+  EXPECT_IN_RANGE(mseErrorNeon, 0, epsilon);
+  EXPECT_IN_RANGE((float)cosSimNeon, 0.99, 1);
+}
+
+TEST(nntrainer_Tensor, dot_gemm_50_768_2048_transAB) {
+  /// @note GEMM : A X B = C
+  RunLayerContext rc = setUpGpuContext();
+
+  int batch = 1;
+  int channel = 1;
+  int height = 768;
+  int width = 50;
+
+  int height_b = 2048;
+  int width_b = 768;
+
+  bool transA = true;
+  bool transB = true;
+
+  const float alpha = 1e-1;
+  const int MOD = 10;
+
+  nntrainer::TensorDim::TensorType t_type_nchw_fp16 = {
+    nntrainer::Tformat::NCHW, nntrainer::Tdatatype::FP16};
+
+  nntrainer::TensorDim::TensorType t_type_nchw_fp32 = {
+    nntrainer::Tformat::NCHW, nntrainer::Tdatatype::FP32};
+
+  nntrainer::Tensor A(batch, channel, height, width, t_type_nchw_fp16);
+  nntrainer::Tensor B(batch, channel, height_b, width_b, t_type_nchw_fp16);
+
+  nntrainer::Tensor A_fp32(batch, channel, height, width, t_type_nchw_fp32);
+  nntrainer::Tensor B_fp32(batch, channel, height_b, width_b, t_type_nchw_fp32);
+
+  GEN_TEST_INPUT(A, ((i * (batch * height * channel) + j * (batch * height) +
+                      k * (width) + l + 1) %
+                     MOD) *
+                      alpha);
+  GEN_TEST_INPUT_B(B, ((i * (batch * height_b * channel) +
+                        j * (batch * height_b) + k * (width_b) + l + 1) %
+                       MOD) *
+                        alpha);
+  GEN_TEST_INPUT(A_fp32, ((i * (batch * height * channel) +
+                           j * (batch * height) + k * (width) + l + 1) %
+                          MOD) *
+                           alpha);
+  GEN_TEST_INPUT_B(B_fp32, ((i * (batch * height_b * channel) +
+                             j * (batch * height_b) + k * (width_b) + l + 1) %
+                            MOD) *
+                             alpha);
+
+  nntrainer::Tensor C = dotCl(A_fp32, B_fp32, rc, transA, transB);
+  nntrainer::Tensor C_fp32 = A_fp32.dot(B_fp32, transA, transB);
+
+  float mseErrorNeon =
+    mse<float>(C.getData<float>(), C_fp32.getData<float>(), C.size());
+
+  double cosSimNeon = cosine_similarity<float>(
+    C.getData<float>(), C_fp32.getData<float>(), C.size());
+
+  const float epsilon = 1e-3 * width;
+
+  EXPECT_IN_RANGE(mseErrorNeon, 0, epsilon);
+  EXPECT_IN_RANGE((float)cosSimNeon, 0.99, 1);
+}
+
 GTEST_API_ int main(int argc, char **argv) {
   int result = -1;