[GPU/OpenCl] Kernel optimization
authorYash Singh <yash.singh@samsung.com>
Tue, 3 Sep 2024 11:39:35 +0000 (17:09 +0530)
committerJijoong Moon <jijoong.moon@samsung.com>
Sun, 20 Oct 2024 23:33:51 +0000 (08:33 +0900)
Kernel Optimized for GPU. Some trivial changes in code.

Signed-off-by: Yash Singh <yash.singh@samsung.com>
nntrainer/tensor/cl_operations/attention_kernel_interface.cpp
nntrainer/tensor/cl_operations/attention_kernels.cpp
nntrainer/tensor/cl_operations/attention_kernels_fp16.cpp
nntrainer/tensor/cl_operations/testing_rotarty_emb.cpp

index 155127f472547caff2269073950fa7ebe11a525e..85c3331eddda2678fbd45eb7dc1fa111140e3fad 100644 (file)
@@ -24,7 +24,7 @@ namespace nntrainer {
  * @param[out] freqs base frequencies array to be used in the future computation
  * @param[in]  theta rotary angle
  */
-void precompute_freqs(int dim, unsigned int seq_len,
+void precompute_freqs(unsigned int dim, unsigned int seq_len,
                       std::vector<std::vector<float>> &freqs_cos,
                       std::vector<std::vector<float>> &freqs_sin,
                       std::vector<float> &freqs, float theta = 10000.0) {
@@ -33,24 +33,24 @@ void precompute_freqs(int dim, unsigned int seq_len,
     freqs.push_back(1.0 / (std::pow(theta, (2 * i) / static_cast<float>(dim))));
   }
 
-  auto cos = std::vector<std::vector<float>>();
-  cos.assign(seq_len, std::vector<float>(dim, 0));
+  auto cos_vec = std::vector<std::vector<float>>();
+  cos_vec.assign(seq_len, std::vector<float>(dim, 0));
 
-  auto sin = std::vector<std::vector<float>>();
-  sin.assign(seq_len, std::vector<float>(dim, 0));
+  auto sin_vec = std::vector<std::vector<float>>();
+  sin_vec.assign(seq_len, std::vector<float>(dim, 0));
 
   for (unsigned int i = 0; i < seq_len; ++i) {
     for (unsigned int j = 0; j < half_; ++j) {
       float angle = i * freqs[j];
-      cos[i][j] = std::cos(angle);
-      cos[i][j + half_] = std::cos(angle); // repeated 2 times
+      cos_vec[i][j] = std::cos(angle);
+      cos_vec[i][j + half_] = std::cos(angle); // repeated 2 times
 
-      sin[i][j] = std::sin(angle);
-      sin[i][j + half_] = std::sin(angle); // repeated 2 times
+      sin_vec[i][j] = std::sin(angle);
+      sin_vec[i][j + half_] = std::sin(angle); // repeated 2 times
     }
   }
-  freqs_cos = cos;
-  freqs_sin = sin;
+  freqs_cos = cos_vec;
+  freqs_sin = sin_vec;
 }
 
 /**
@@ -59,12 +59,15 @@ void precompute_freqs(int dim, unsigned int seq_len,
  * @param[in] dim hidden dim size
  * @param[in] from sequence order
  * @param[in] max_timestep maximum timestep
+ * @param[in] context layer context to get the resource manager and queue id
+ *
+ * @todo      Calling precompute_freqs in finalize to reduce code redundancy.
  */
 void apply_rotary_emb_cl(Tensor &in, unsigned int dim, unsigned int from,
                          unsigned int max_timestep, RunLayerContext &context) {
   nntrainer::Tensor out(in.getDim());
-  float value = 0;
-  float transformed_value = 0.0;
+  float value = 0.0f;
+  float transformed_value = 0.0f;
   unsigned int half_ = dim / 2;
 
   std::vector<std::vector<float>> freqs_cos = {};
index 9b5cb7e69998d048583c875187f8cf368482d6fa..5fd646b7c153b699d02a647e5b5bb3c7c6a2a4bd 100644 (file)
@@ -30,37 +30,36 @@ __kernel void rotary_emb_cl(__global float *input,
                                       unsigned int half_,
                                       unsigned int max_timestep,
                                       unsigned int from) {
-    unsigned int gid = get_global_id(0);
-    unsigned int gws = get_global_size(0);
-
     __global float *cos_ptr = cos_;
     __global float *sin_ptr = sin_;
 
     float value = 0.0f;
     float transformed_value = 0.0f;
 
-    for (unsigned int b = 0; b < batch; b++) {
-      for (unsigned int c = 0; c < channel; c++) {
-        for (unsigned int h = 0; h < height; h++) {
-          if (from + h < max_timestep) {
-            unsigned idx = (from + h)*dim;
-            for(unsigned int i = idx; i < idx + dim; i++){
-              cos_ptr[i - idx] = freqs_cos[i];
-              sin_ptr[i - idx] = freqs_sin[i];
-            }
+    unsigned int b = get_global_id(0);
+    unsigned int c = get_global_id(1);
+    
+    if(b < batch && c < channel){
+      for (unsigned int h = 0; h < height; h++) {
+        if (from + h < max_timestep) {
+          unsigned idx = (from + h)*dim;
+          for(unsigned int i = idx; i < idx + dim; i++){
+            cos_ptr[i - idx] = freqs_cos[i];
+            sin_ptr[i - idx] = freqs_sin[i];
           }
-          for (unsigned int w = 0; w < width; w = w + dim) {
-            for (unsigned int k = 0; k < dim; k++) {
-              unsigned int span = w + k;
-              value = input[b * channel * height * width + c * height * width + h * width + span];
-              if (k < half_) {
-                transformed_value = -1.0f * input[b * channel * height * width + c * height * width + h * width + span + half_];
-              } else {
-                transformed_value = input[b * channel * height * width + c * height * width + h * width + span - half_];
-              }
-              value = value * cos_ptr[k] + transformed_value * sin_ptr[k];
-              output[b * channel * height * width + c * height * width + h * width + span] = value;
+        }
+
+        for (unsigned int w = 0; w < width; w = w + dim) {
+          for (unsigned int k = 0; k < dim; k++) {
+            unsigned int span = w + k;
+            value = input[b * channel * height * width + c * height * width + h * width + span];
+            if (k < half_) {
+              transformed_value = -1.0f * input[b * channel * height * width + c * height * width + h * width + span + half_];
+            } else {
+              transformed_value = input[b * channel * height * width + c * height * width + h * width + span - half_];
             }
+            value = value * cos_ptr[k] + transformed_value * sin_ptr[k];
+            output[b * channel * height * width + c * height * width + h * width + span] = value;
           }
         }
       }
@@ -252,8 +251,8 @@ void rotary_emb_cl(float *in, float *out,
       break;
     }
 
-    const int work_groups_count[3] = {1, 1, 1};
-    const int work_group_size[3] = {32, 1, 1}; // test-value
+    const int work_groups_count[3] = {(int)batch, (int)channel, 1};
+    const int work_group_size[3] = {32, 32, 1}; // test-value
     result = context.command_queue_inst_.DispatchCommand(
       kernel_rotary_emb, work_groups_count, work_group_size);
     if (!result) {
index c6b1fbb263e58672c9e7773dafa6b6094ea95537..7c2c995020faca3979b2fdb2ef5e743bda1d8b19 100644 (file)
@@ -30,38 +30,36 @@ __kernel void rotary_emb_cl_fp16(__global half *input,
                                       unsigned int half_,
                                       unsigned int max_timestep,
                                       unsigned int from) {
-    unsigned int gid = get_global_id(0);
-    unsigned int gws = get_global_size(0);
-
     __global float *cos_ptr = cos_;
     __global float *sin_ptr = sin_;
 
     float value = 0.0f;
     float transformed_value = 0.0f;
 
-    for (unsigned int b = 0; b < batch; b++) {
-      for (unsigned int c = 0; c < channel; c++) {
-        for (unsigned int h = 0; h < height; h++) {
-          if (from + h < max_timestep) {
-            unsigned idx = (from + h)*dim;
-            for(int i = idx; i < idx + dim; i++ ){
-              cos_ptr[i - idx] = freqs_cos[i];
-              sin_ptr[i - idx] = freqs_sin[i];
-            }
+    unsigned int b = get_global_id(0);
+    unsigned int c = get_global_id(1);
+    
+    if(b < batch && c < channel){
+      for (unsigned int h = 0; h < height; h++) {
+        if (from + h < max_timestep) {
+          unsigned idx = (from + h)*dim;
+          for(int i = idx; i < idx + dim; i++ ){
+            cos_ptr[i - idx] = freqs_cos[i];
+            sin_ptr[i - idx] = freqs_sin[i];
           }
+        }
 
-          for (unsigned int w = 0; w < width; w = w + dim) {
-            for (unsigned int k = 0; k < dim; k++) {
-              unsigned int span = w + k;
-              value = (float)input[b * channel * height * width + c * height * width + h * width + span];
-              if (k < half_) {
-                transformed_value = -1.0f * (float)input[b * channel * height * width + c * height * width + h * width + span + half_];
-              } else {
-                transformed_value = (float)input[b * channel * height * width + c * height * width + h * width + span - half_];
-              }
-              value = value * cos_ptr[k] + transformed_value * sin_ptr[k];
-              output[b * channel * height * width + c * height * width + h * width + span] = (half)value;
+        for (unsigned int w = 0; w < width; w = w + dim) {
+          for (unsigned int k = 0; k < dim; k++) {
+            unsigned int span = w + k;
+            value = (float)input[b * channel * height * width + c * height * width + h * width + span];
+            if (k < half_) {
+              transformed_value = -1.0f * (float)input[b * channel * height * width + c * height * width + h * width + span + half_];
+            } else {
+              transformed_value = (float)input[b * channel * height * width + c * height * width + h * width + span - half_];
             }
+            value = value * cos_ptr[k] + transformed_value * sin_ptr[k];
+            output[b * channel * height * width + c * height * width + h * width + span] = (half)value;
           }
         }
       }
@@ -259,8 +257,8 @@ void rotary_emb_cl(__fp16 *in, __fp16 *out,
       break;
     }
 
-    const int work_groups_count[3] = {1, 1, 1};
-    const int work_group_size[3] = {32, 1, 1}; // test-value
+    const int work_groups_count[3] = {(int)batch, (int)channel, 1};
+    const int work_group_size[3] = {32, 32, 1}; // test-value
     result = context.command_queue_inst_.DispatchCommand(
       kernel_rotary_emb_fp16, work_groups_count, work_group_size);
     if (!result) {
index 4ebde87332c16f5cccfc26b8ff8800aaf5032f40..d7bab6cc49cbd74f0997dfe20cd381b2dbf16483 100644 (file)
@@ -15,7 +15,8 @@
 #include <string>
 
 /**
- * @brief     compute frequency for rotary embedding
+ * @brief     Testing code for CPU results and compute frequency for rotary
+ * embedding
  * @param[in] dim hidden dim size
  * @param[in] seq_len sequency length
  * @param[out] freqs_cos cosine of the frequencies
@@ -24,7 +25,7 @@
  * sin values for each position in sequence
  * @param[in] theta rotary angle
  */
-void precompute_freqs(int dim, unsigned int seq_len,
+void precompute_freqs(unsigned int dim, unsigned int seq_len,
                       std::vector<std::vector<float>> &freqs_cos,
                       std::vector<std::vector<float>> &freqs_sin,
                       std::vector<float> &freqs, float theta = 10000.0) {
@@ -35,29 +36,29 @@ void precompute_freqs(int dim, unsigned int seq_len,
                       (std::pow(theta, (2 * i) / static_cast<float>(dim))));
     }
 
-    auto cos = std::vector<std::vector<float>>();
-    cos.assign(seq_len, std::vector<float>(dim, 0));
+    auto cos_vec = std::vector<std::vector<float>>();
+    cos_vec.assign(seq_len, std::vector<float>(dim, 0));
 
-    auto sin = std::vector<std::vector<float>>();
-    sin.assign(seq_len, std::vector<float>(dim, 0));
+    auto sin_vec = std::vector<std::vector<float>>();
+    sin_vec.assign(seq_len, std::vector<float>(dim, 0));
 
     for (unsigned int i = 0; i < seq_len; ++i) {
       for (unsigned int j = 0; j < half_; ++j) {
         float angle = i * freqs[j];
-        cos[i][j] = std::cos(angle);
-        cos[i][j + half_] = std::cos(angle); // repeated 2 times
+        cos_vec[i][j] = std::cos(angle);
+        cos_vec[i][j + half_] = std::cos(angle); // repeated 2 times
 
-        sin[i][j] = std::sin(angle);
-        sin[i][j + half_] = std::sin(angle); // repeated 2 times
+        sin_vec[i][j] = std::sin(angle);
+        sin_vec[i][j + half_] = std::sin(angle); // repeated 2 times
       }
     }
-    freqs_cos = cos;
-    freqs_sin = sin;
+    freqs_cos = cos_vec;
+    freqs_sin = sin_vec;
   }
 }
 
 /**
- * @brief     apply rotary embedding
+ * @brief     Testing code for CPU results and apply rotary embedding
  * @param[in] in input tensor
  * @param[in] dim hidden dim size
  * @param[in] from sequence order
@@ -66,8 +67,8 @@ void precompute_freqs(int dim, unsigned int seq_len,
 void apply_rotary_emb_tensor(nntrainer::Tensor &in, unsigned int dim,
                              unsigned int from, unsigned int max_timestep) {
   nntrainer::Tensor out(in.getDim());
-  float value = 0;
-  float transformed_value = 0.0;
+  float value = 0.0f;
+  float transformed_value = 0.0f;
   unsigned int half_ = dim / 2;
 
   std::vector<std::vector<float>> freqs_cos = {};