[GPU/Enhance] Registering Attention kernels and removind cl_context dependency
authorYash Singh <yash.singh@samsung.com>
Tue, 8 Oct 2024 07:13:17 +0000 (12:43 +0530)
committerJijoong Moon <jijoong.moon@samsung.com>
Sun, 20 Oct 2024 23:33:51 +0000 (08:33 +0900)
Added registerCLKernel function to register custom OpenCL kernels as well as in-house kernels.
Modified attention kernels to remove cl_context related dependencies.
Added initAttentionCLKernels function to register default attention kernels.
Modified unittest to remove layer_context dependency
attention_kernel_strings.h added to handle attention kernels at one place.
Rebased the PR with current log.

Signed-off-by: Yash Singh <yash.singh@samsung.com>
nntrainer/cl_context.cpp
nntrainer/cl_context.h
nntrainer/tensor/cl_operations/attention_kernel_interface.cpp
nntrainer/tensor/cl_operations/attention_kernel_interface.h
nntrainer/tensor/cl_operations/attention_kernel_strings.h [new file with mode: 0644]
nntrainer/tensor/cl_operations/attention_kernels.cpp
nntrainer/tensor/cl_operations/attention_kernels.h
nntrainer/tensor/cl_operations/attention_kernels_fp16.cpp
nntrainer/tensor/cl_operations/meson.build
test/unittest/unittest_attention_kernels_cl.cpp

index 5ecf80f838b2bd7b7d1206da8344a20561537567..10e3ecdbb7891971ec8edd440487bc506cbdaf0b 100644 (file)
@@ -15,6 +15,7 @@
  */
 
 #include <addition_layer_cl.h>
+#include <attention_kernel_strings.h>
 #include <blas_kernel_strings.h>
 #include <cl_context.h>
 #include <concat_cl.h>
@@ -149,6 +150,21 @@ void ClContext::initBlasClKernels() {
   blas_kernels_initialized = true;
 }
 
+void ClContext::initAttentionClKernels() {
+  if (attention_kernels_initialized) {
+    ml_logi("ClContext: Default attention kernels already registered and "
+            "initialized");
+    return;
+  }
+
+  registerClKernel(rotary_emb_cl_kernel_, "rotary_emb_cl");
+
+#ifdef ENABLE_FP16
+  registerClKernel(rotary_emb_cl_kernel_fp16_, "rotary_emb_cl_fp16");
+#endif
+  attention_kernels_initialized = true;
+}
+
 const ClContext::SharedPtrClKernel
 ClContext::registerClKernel(std::string kernel_string,
                             std::string kernel_name) {
index 7683453221bce6f93cd4197864530d70cb208dc6..025365546b370f65d6c100ccce93b9cb1d9cae84 100644 (file)
@@ -211,6 +211,11 @@ public:
    */
   void initBlasClKernels();
 
+  /**
+   * @brief Initialize and register all attention OpenCl kernels
+   */
+  void initAttentionClKernels();
+
   /**
    * @brief destructor to release opencl commandQueue
    */
@@ -229,6 +234,9 @@ private:
   // flag to check default blas kernels registered or not
   bool blas_kernels_initialized = false;
 
+  // flag to check default attention kernels registered or not
+  bool attention_kernels_initialized = false;
+
   FactoryMap<nntrainer::Layer> factory_map;
 
   template <typename Args, typename T> struct isSupportedHelper;
index 85c3331eddda2678fbd45eb7dc1fa111140e3fad..658e2a3d91a5e49c9b57c1318a623e35d428765a 100644 (file)
@@ -59,12 +59,11 @@ void precompute_freqs(unsigned int dim, unsigned int seq_len,
  * @param[in] dim hidden dim size
  * @param[in] from sequence order
  * @param[in] max_timestep maximum timestep
- * @param[in] context layer context to get the resource manager and queue id
  *
  * @todo      Calling precompute_freqs in finalize to reduce code redundancy.
  */
 void apply_rotary_emb_cl(Tensor &in, unsigned int dim, unsigned int from,
-                         unsigned int max_timestep, RunLayerContext &context) {
+                         unsigned int max_timestep) {
   nntrainer::Tensor out(in.getDim());
   float value = 0.0f;
   float transformed_value = 0.0f;
@@ -111,7 +110,7 @@ void apply_rotary_emb_cl(Tensor &in, unsigned int dim, unsigned int from,
 
     rotary_emb_cl(data, rdata, freqs_cos, freqs_sin, cos_, sin_,
                   input_batch_size, input_channels, input_height, input_width,
-                  dim, from, max_timestep, in_size, out_size, context);
+                  dim, from, max_timestep, in_size, out_size);
 
   } else if (in.getDataType() == ml::train::TensorDim::DataType::FP16) {
 #ifdef ENABLE_FP16
@@ -123,7 +122,7 @@ void apply_rotary_emb_cl(Tensor &in, unsigned int dim, unsigned int from,
 
     rotary_emb_cl(data, rdata, freqs_cos, freqs_sin, cos_, sin_,
                   input_batch_size, input_channels, input_height, input_width,
-                  dim, from, max_timestep, in_size, out_size, context);
+                  dim, from, max_timestep, in_size, out_size);
 #else
     throw std::invalid_argument("Error: enable-fp16 is not enabled");
 #endif
index b287cb0a4789161deeb7ecf471bd3ebc907ccb6a..fe9c0f8b0c7bd4ab65b47ec8f713a746b6c8da47 100644 (file)
@@ -14,8 +14,8 @@
 #ifndef __ATTENTION_KERNEL_INTERFACE_H__
 #define __ATTENTION_KERNEL_INTERFACE_H__
 
-#include <layer_context.h>
 #include <string>
+#include <tensor.h>
 
 namespace nntrainer {
 
@@ -25,10 +25,9 @@ namespace nntrainer {
  * @param[in] dim hidden dim size
  * @param[in] from sequence order
  * @param[in] max_timestep maximum timestep
- * @param[in] context layer context to get the resource manager and queue id
  */
 void apply_rotary_emb_cl(Tensor &in, unsigned int dim, unsigned int from,
-                         unsigned int max_timestep, RunLayerContext &context);
+                         unsigned int max_timestep);
 
 } // namespace nntrainer
 #endif /* __ATTENTION_KERNEL_INTERFACE_H__ */
diff --git a/nntrainer/tensor/cl_operations/attention_kernel_strings.h b/nntrainer/tensor/cl_operations/attention_kernel_strings.h
new file mode 100644 (file)
index 0000000..d58fd75
--- /dev/null
@@ -0,0 +1,133 @@
+// SPDX-License-Identifier: Apache-2.0
+/**
+ * Copyright (C) 2024 Yash Singh <yash.singh@samsung.com>
+ *
+ * @file       attention_kernel_strings.h
+ * @date       8 October 2024
+ * @brief      All attention OpenCL kernel strings
+ * @see                https://github.com/nnstreamer/nntrainer
+ * @author     Yash Singh <yash.singh@samsung.com>
+ * @bug                No known bugs except for NYI items
+ *
+ */
+
+#ifndef __ATTENTION_KERNEL_STRINGS_H__
+#define __ATTENTION_KERNEL_STRINGS_H__
+
+#include <string>
+
+namespace nntrainer {
+static const std::string rotary_emb_cl_kernel_ = R"(
+
+  #pragma OPENCL EXTENSION cl_khr_fp16 : enable
+
+__kernel void rotary_emb_cl(__global float *input,
+                                      __global float *output,
+                                      __global float *freqs_cos,
+                                      __global float *freqs_sin,
+                                      __global float *cos_,
+                                      __global float *sin_,
+                                      unsigned int batch,
+                                      unsigned int channel,
+                                      unsigned int height,
+                                      unsigned int width,
+                                      unsigned int dim,
+                                      unsigned int half_,
+                                      unsigned int max_timestep,
+                                      unsigned int from) {
+    __global float *cos_ptr = cos_;
+    __global float *sin_ptr = sin_;
+
+    float value = 0.0f;
+    float transformed_value = 0.0f;
+
+    unsigned int b = get_global_id(0);
+    unsigned int c = get_global_id(1);
+    
+    if(b < batch && c < channel){
+      for (unsigned int h = 0; h < height; h++) {
+        if (from + h < max_timestep) {
+          unsigned idx = (from + h)*dim;
+          for(unsigned int i = idx; i < idx + dim; i++){
+            cos_ptr[i - idx] = freqs_cos[i];
+            sin_ptr[i - idx] = freqs_sin[i];
+          }
+        }
+
+        for (unsigned int w = 0; w < width; w = w + dim) {
+          for (unsigned int k = 0; k < dim; k++) {
+            unsigned int span = w + k;
+            value = input[b * channel * height * width + c * height * width + h * width + span];
+            if (k < half_) {
+              transformed_value = -1.0f * input[b * channel * height * width + c * height * width + h * width + span + half_];
+            } else {
+              transformed_value = input[b * channel * height * width + c * height * width + h * width + span - half_];
+            }
+            value = value * cos_ptr[k] + transformed_value * sin_ptr[k];
+            output[b * channel * height * width + c * height * width + h * width + span] = value;
+          }
+        }
+      }
+    }
+}
+)";
+
+#ifdef ENABLE_FP16
+static const std::string rotary_emb_cl_kernel_fp16_ = R"(
+
+  #pragma OPENCL EXTENSION cl_khr_fp16 : enable
+  
+__kernel void rotary_emb_cl_fp16(__global half *input,
+                                      __global half *output,
+                                      __global float *freqs_cos,
+                                      __global float *freqs_sin,
+                                      __global float *cos_,
+                                      __global float *sin_,
+                                      unsigned int batch,
+                                      unsigned int channel,
+                                      unsigned int height,
+                                      unsigned int width,
+                                      unsigned int dim,
+                                      unsigned int half_,
+                                      unsigned int max_timestep,
+                                      unsigned int from) {
+    __global float *cos_ptr = cos_;
+    __global float *sin_ptr = sin_;
+
+    float value = 0.0f;
+    float transformed_value = 0.0f;
+
+    unsigned int b = get_global_id(0);
+    unsigned int c = get_global_id(1);
+    
+    if(b < batch && c < channel){
+      for (unsigned int h = 0; h < height; h++) {
+        if (from + h < max_timestep) {
+          unsigned idx = (from + h)*dim;
+          for(int i = idx; i < idx + dim; i++ ){
+            cos_ptr[i - idx] = freqs_cos[i];
+            sin_ptr[i - idx] = freqs_sin[i];
+          }
+        }
+
+        for (unsigned int w = 0; w < width; w = w + dim) {
+          for (unsigned int k = 0; k < dim; k++) {
+            unsigned int span = w + k;
+            value = (float)input[b * channel * height * width + c * height * width + h * width + span];
+            if (k < half_) {
+              transformed_value = -1.0f * (float)input[b * channel * height * width + c * height * width + h * width + span + half_];
+            } else {
+              transformed_value = (float)input[b * channel * height * width + c * height * width + h * width + span - half_];
+            }
+            value = value * cos_ptr[k] + transformed_value * sin_ptr[k];
+            output[b * channel * height * width + c * height * width + h * width + span] = (half)value;
+          }
+        }
+      }
+    }
+}
+)";
+
+#endif
+} // namespace nntrainer
+#endif /* __ATTENTION_KERNEL_INTERFACE_H__ */
index 5fd646b7c153b699d02a647e5b5bb3c7c6a2a4bd..388cc0805ff33d8fb90c4f5ba5edfdf1e0002c06 100644 (file)
  *
  */
 
+#include <attention_kernel_strings.h>
 #include <attention_kernels.h>
 
 namespace nntrainer {
-std::string rotary_emb_cl_kernel = R"(
-  #pragma OPENCL EXTENSION cl_khr_fp16 : enable
-__kernel void rotary_emb_cl(__global float *input,
-                                      __global float *output,
-                                      __global float *freqs_cos,
-                                      __global float *freqs_sin,
-                                      __global float *cos_,
-                                      __global float *sin_,
-                                      unsigned int batch,
-                                      unsigned int channel,
-                                      unsigned int height,
-                                      unsigned int width,
-                                      unsigned int dim,
-                                      unsigned int half_,
-                                      unsigned int max_timestep,
-                                      unsigned int from) {
-    __global float *cos_ptr = cos_;
-    __global float *sin_ptr = sin_;
-
-    float value = 0.0f;
-    float transformed_value = 0.0f;
-
-    unsigned int b = get_global_id(0);
-    unsigned int c = get_global_id(1);
-    
-    if(b < batch && c < channel){
-      for (unsigned int h = 0; h < height; h++) {
-        if (from + h < max_timestep) {
-          unsigned idx = (from + h)*dim;
-          for(unsigned int i = idx; i < idx + dim; i++){
-            cos_ptr[i - idx] = freqs_cos[i];
-            sin_ptr[i - idx] = freqs_sin[i];
-          }
-        }
-
-        for (unsigned int w = 0; w < width; w = w + dim) {
-          for (unsigned int k = 0; k < dim; k++) {
-            unsigned int span = w + k;
-            value = input[b * channel * height * width + c * height * width + h * width + span];
-            if (k < half_) {
-              transformed_value = -1.0f * input[b * channel * height * width + c * height * width + h * width + span + half_];
-            } else {
-              transformed_value = input[b * channel * height * width + c * height * width + h * width + span - half_];
-            }
-            value = value * cos_ptr[k] + transformed_value * sin_ptr[k];
-            output[b * channel * height * width + c * height * width + h * width + span] = value;
-          }
-        }
-      }
-    }
-}
-)";
-
-/**
- * @brief defining global kernel objects
- */
-opencl::Kernel kernel_rotary_emb;
 
 void rotary_emb_cl(float *in, float *out,
                    std::vector<std::vector<float>> freqs_cos,
@@ -79,17 +23,16 @@ void rotary_emb_cl(float *in, float *out,
                    unsigned int batch, unsigned int channel,
                    unsigned int height, unsigned int width, unsigned int dim,
                    unsigned int from, unsigned int max_timestep,
-                   unsigned int in_size, unsigned int out_size,
-                   RunLayerContext &context) {
+                   unsigned int in_size, unsigned int out_size) {
   bool result = false;
 
   do {
-    result = context.clCreateKernel(
-      rotary_emb_cl_kernel, context.LayerKernel::ROTARY_EMB, kernel_rotary_emb);
-    if (!result) {
-      printf("Failed to create kernel for rotary_emb_cl\n");
+    ClContext::SharedPtrClKernel kernel_rotaryEmb_ptr =
+      cl_context_ref.registerClKernel(rotary_emb_cl_kernel_, "rotary_emb_cl");
+    if (!kernel_rotaryEmb_ptr) {
       break;
     }
+
     unsigned int cos_dim = cos_.size();
     unsigned int sin_dim = sin_.size();
     unsigned int freqs_cos_dim = freqs_cos.size();
@@ -103,18 +46,22 @@ void rotary_emb_cl(float *in, float *out,
       sizeof(float) * freqs_cos_dim * dim; // max_timestep * dim
     size_t dim6_size = sizeof(float) * freqs_sin_dim * dim;
 
-    opencl::Buffer inputA(context.context_inst_, dim1_size, true, nullptr);
+    opencl::Buffer inputA(cl_context_ref.context_inst_, dim1_size, true,
+                          nullptr);
 
-    opencl::Buffer inOutRes(context.context_inst_, dim2_size, true, nullptr);
+    opencl::Buffer inOutRes(cl_context_ref.context_inst_, dim2_size, true,
+                            nullptr);
 
-    opencl::Buffer cosBuf(context.context_inst_, dim3_size, true, nullptr);
+    opencl::Buffer cosBuf(cl_context_ref.context_inst_, dim3_size, true,
+                          nullptr);
 
-    opencl::Buffer sinBuf(context.context_inst_, dim4_size, true, nullptr);
+    opencl::Buffer sinBuf(cl_context_ref.context_inst_, dim4_size, true,
+                          nullptr);
 
-    opencl::Buffer freqs_cosBuf(context.context_inst_, dim5_size, true,
+    opencl::Buffer freqs_cosBuf(cl_context_ref.context_inst_, dim5_size, true,
                                 nullptr);
 
-    opencl::Buffer freqs_sinBuf(context.context_inst_, dim6_size, true,
+    opencl::Buffer freqs_sinBuf(cl_context_ref.context_inst_, dim6_size, true,
                                 nullptr);
 
     std::vector<float> freqs_cos_flat;
@@ -126,126 +73,130 @@ void rotary_emb_cl(float *in, float *out,
       freqs_sin_flat.insert(freqs_sin_flat.end(), row.begin(), row.end());
     }
 
-    result = inputA.WriteData(context.command_queue_inst_, in);
+    result = inputA.WriteData(cl_context_ref.command_queue_inst_, in);
     if (!result) {
       printf("Failed to write input data\n");
       break;
     }
 
-    result = inOutRes.WriteData(context.command_queue_inst_, out);
+    result = inOutRes.WriteData(cl_context_ref.command_queue_inst_, out);
     if (!result) {
       printf("Failed to write output data\n");
       break;
     }
 
-    result = freqs_cosBuf.WriteData(context.command_queue_inst_,
+    result = freqs_cosBuf.WriteData(cl_context_ref.command_queue_inst_,
                                     freqs_cos_flat.data());
     if (!result) {
       printf("Failed to write freqs cos data\n");
       break;
     }
 
-    result = freqs_sinBuf.WriteData(context.command_queue_inst_,
+    result = freqs_sinBuf.WriteData(cl_context_ref.command_queue_inst_,
                                     freqs_sin_flat.data());
     if (!result) {
       printf("Failed to write freqs sin data\n");
       break;
     }
 
-    result = cosBuf.WriteData(context.command_queue_inst_, cos_.data());
+    result = cosBuf.WriteData(cl_context_ref.command_queue_inst_, cos_.data());
     if (!result) {
       printf("Failed to write cos data\n");
       break;
     }
 
-    result = sinBuf.WriteData(context.command_queue_inst_, sin_.data());
+    result = sinBuf.WriteData(cl_context_ref.command_queue_inst_, sin_.data());
     if (!result) {
       printf("Failed to write sin data\n");
       break;
     }
 
-    result = kernel_rotary_emb.SetKernelArguments(0, &inputA, sizeof(cl_mem));
+    result =
+      kernel_rotaryEmb_ptr->SetKernelArguments(0, &inputA, sizeof(cl_mem));
     if (!result) {
       printf("Failed to set inputA argument\n");
       break;
     }
 
-    result = kernel_rotary_emb.SetKernelArguments(1, &inOutRes, sizeof(cl_mem));
+    result =
+      kernel_rotaryEmb_ptr->SetKernelArguments(1, &inOutRes, sizeof(cl_mem));
     if (!result) {
       printf("Failed to set inOutRes argument\n");
       break;
     }
 
-    result =
-      kernel_rotary_emb.SetKernelArguments(2, &freqs_cosBuf, sizeof(cl_mem));
+    result = kernel_rotaryEmb_ptr->SetKernelArguments(2, &freqs_cosBuf,
+                                                      sizeof(cl_mem));
     if (!result) {
       printf("Failed to set freqs_cosBuf argument\n");
       break;
     }
 
-    result =
-      kernel_rotary_emb.SetKernelArguments(3, &freqs_sinBuf, sizeof(cl_mem));
+    result = kernel_rotaryEmb_ptr->SetKernelArguments(3, &freqs_sinBuf,
+                                                      sizeof(cl_mem));
     if (!result) {
       printf("Failed to set freqs_sinBuf argument\n");
       break;
     }
 
-    result = kernel_rotary_emb.SetKernelArguments(4, &cosBuf, sizeof(cl_mem));
+    result =
+      kernel_rotaryEmb_ptr->SetKernelArguments(4, &cosBuf, sizeof(cl_mem));
     if (!result) {
       printf("Failed to set cosBuf argument\n");
       break;
     }
 
-    result = kernel_rotary_emb.SetKernelArguments(5, &sinBuf, sizeof(cl_mem));
+    result =
+      kernel_rotaryEmb_ptr->SetKernelArguments(5, &sinBuf, sizeof(cl_mem));
     if (!result) {
       printf("Failed to set sinBuf argument\n");
       break;
     }
 
-    result = kernel_rotary_emb.SetKernelArguments(6, &batch, sizeof(int));
+    result = kernel_rotaryEmb_ptr->SetKernelArguments(6, &batch, sizeof(int));
     if (!result) {
       printf("Failed to set batch argument\n");
       break;
     }
 
-    result = kernel_rotary_emb.SetKernelArguments(7, &channel, sizeof(int));
+    result = kernel_rotaryEmb_ptr->SetKernelArguments(7, &channel, sizeof(int));
     if (!result) {
       printf("Failed to set channel argument\n");
       break;
     }
 
-    result = kernel_rotary_emb.SetKernelArguments(8, &height, sizeof(int));
+    result = kernel_rotaryEmb_ptr->SetKernelArguments(8, &height, sizeof(int));
     if (!result) {
       printf("Failed to set height argument\n");
       break;
     }
 
-    result = kernel_rotary_emb.SetKernelArguments(9, &width, sizeof(int));
+    result = kernel_rotaryEmb_ptr->SetKernelArguments(9, &width, sizeof(int));
     if (!result) {
       printf("Failed to set width argument\n");
       break;
     }
 
-    result = kernel_rotary_emb.SetKernelArguments(10, &dim, sizeof(int));
+    result = kernel_rotaryEmb_ptr->SetKernelArguments(10, &dim, sizeof(int));
     if (!result) {
       printf("Failed to set dim argument\n");
       break;
     }
     unsigned int half_ = dim / 2;
-    result = kernel_rotary_emb.SetKernelArguments(11, &half_, sizeof(int));
+    result = kernel_rotaryEmb_ptr->SetKernelArguments(11, &half_, sizeof(int));
     if (!result) {
       printf("Failed to set half argument\n");
       break;
     }
 
     result =
-      kernel_rotary_emb.SetKernelArguments(12, &max_timestep, sizeof(int));
+      kernel_rotaryEmb_ptr->SetKernelArguments(12, &max_timestep, sizeof(int));
     if (!result) {
       printf("Failed to set timestamp argument\n");
       break;
     }
 
-    result = kernel_rotary_emb.SetKernelArguments(13, &from, sizeof(int));
+    result = kernel_rotaryEmb_ptr->SetKernelArguments(13, &from, sizeof(int));
     if (!result) {
       printf("Failed to set from argument\n");
       break;
@@ -253,14 +204,14 @@ void rotary_emb_cl(float *in, float *out,
 
     const int work_groups_count[3] = {(int)batch, (int)channel, 1};
     const int work_group_size[3] = {32, 32, 1}; // test-value
-    result = context.command_queue_inst_.DispatchCommand(
-      kernel_rotary_emb, work_groups_count, work_group_size);
+    result = cl_context_ref.command_queue_inst_.DispatchCommand(
+      kernel_rotaryEmb_ptr, work_groups_count, work_group_size);
     if (!result) {
       printf("Failed to dispatch command\n");
       break;
     }
 
-    result = inOutRes.ReadData(context.command_queue_inst_, out);
+    result = inOutRes.ReadData(cl_context_ref.command_queue_inst_, out);
     if (!result) {
       printf("Failed to read data\n");
       break;
index 97e2a98ceac0c2b46b789279d4726a865d80e34d..37a3a4428aa07e754c3fdd493ff3a68ac5f21ac7 100644 (file)
 #ifndef __ATTENTION_KERNELS_H__
 #define __ATTENTION_KERNELS_H__
 
-#include <layer_context.h>
+#include <cl_context.h>
 #include <opencl_buffer.h>
 #include <opencl_kernel.h>
 #include <string>
 
 namespace nntrainer {
 
-/**
- * @brief declaring global kernel objects
- */
-extern opencl::Kernel kernel_rotary_emb;
+// get global cl_context to use in kernels
+static ClContext cl_context_ref;
 
 /**
  * @brief     Rotary Embedding process
@@ -43,7 +41,6 @@ extern opencl::Kernel kernel_rotary_emb;
  * @param[in] max_timestep max timestep
  * @param[in] in_size size of input
  * @param[in] out_size size of output
- * @param[in] context RunLayerContext reference
  */
 void rotary_emb_cl(float *in, float *out,
                    std::vector<std::vector<float>> freqs_cos,
@@ -52,14 +49,9 @@ void rotary_emb_cl(float *in, float *out,
                    unsigned int batch, unsigned int channel,
                    unsigned int height, unsigned int width, unsigned int dim,
                    unsigned int from, unsigned int max_timestamp,
-                   unsigned int in_size, unsigned int out_size,
-                   RunLayerContext &context);
+                   unsigned int in_size, unsigned int out_size);
 
 #ifdef ENABLE_FP16
-/**
- * @brief declaring global fp16 kernel objects
- */
-extern opencl::Kernel kernel_rotary_emb_fp16;
 
 /**
  * @brief     Rotary Embedding process
@@ -78,7 +70,6 @@ extern opencl::Kernel kernel_rotary_emb_fp16;
  * @param[in] max_timestep max timestep
  * @param[in] in_size size of input
  * @param[in] out_size size of output
- * @param[in] context RunLayerContext reference
  */
 void rotary_emb_cl(__fp16 *in, __fp16 *out,
                    std::vector<std::vector<float>> freqs_cos,
@@ -87,8 +78,7 @@ void rotary_emb_cl(__fp16 *in, __fp16 *out,
                    unsigned int batch, unsigned int channel,
                    unsigned int height, unsigned int width, unsigned int dim,
                    unsigned int from, unsigned int max_timestamp,
-                   unsigned int in_size, unsigned int out_size,
-                   RunLayerContext &context);
+                   unsigned int in_size, unsigned int out_size);
 
 #endif
 
index 7c2c995020faca3979b2fdb2ef5e743bda1d8b19..c1284b0a9c1482a2710c9fd52b3c6d16702bde7b 100644 (file)
  *
  */
 
+#include <attention_kernel_strings.h>
 #include <attention_kernels.h>
 
 namespace nntrainer {
-std::string rotary_emb_cl_kernel_fp16 = R"(
-  #pragma OPENCL EXTENSION cl_khr_fp16 : enable
-__kernel void rotary_emb_cl_fp16(__global half *input,
-                                      __global half *output,
-                                      __global float *freqs_cos,
-                                      __global float *freqs_sin,
-                                      __global float *cos_,
-                                      __global float *sin_,
-                                      unsigned int batch,
-                                      unsigned int channel,
-                                      unsigned int height,
-                                      unsigned int width,
-                                      unsigned int dim,
-                                      unsigned int half_,
-                                      unsigned int max_timestep,
-                                      unsigned int from) {
-    __global float *cos_ptr = cos_;
-    __global float *sin_ptr = sin_;
-
-    float value = 0.0f;
-    float transformed_value = 0.0f;
-
-    unsigned int b = get_global_id(0);
-    unsigned int c = get_global_id(1);
-    
-    if(b < batch && c < channel){
-      for (unsigned int h = 0; h < height; h++) {
-        if (from + h < max_timestep) {
-          unsigned idx = (from + h)*dim;
-          for(int i = idx; i < idx + dim; i++ ){
-            cos_ptr[i - idx] = freqs_cos[i];
-            sin_ptr[i - idx] = freqs_sin[i];
-          }
-        }
-
-        for (unsigned int w = 0; w < width; w = w + dim) {
-          for (unsigned int k = 0; k < dim; k++) {
-            unsigned int span = w + k;
-            value = (float)input[b * channel * height * width + c * height * width + h * width + span];
-            if (k < half_) {
-              transformed_value = -1.0f * (float)input[b * channel * height * width + c * height * width + h * width + span + half_];
-            } else {
-              transformed_value = (float)input[b * channel * height * width + c * height * width + h * width + span - half_];
-            }
-            value = value * cos_ptr[k] + transformed_value * sin_ptr[k];
-            output[b * channel * height * width + c * height * width + h * width + span] = (half)value;
-          }
-        }
-      }
-    }
-}
-)";
-
-/**
- * @brief defining global kernel objects
- */
-opencl::Kernel kernel_rotary_emb_fp16;
 
 void rotary_emb_cl(__fp16 *in, __fp16 *out,
                    std::vector<std::vector<float>> freqs_cos,
@@ -79,16 +23,14 @@ void rotary_emb_cl(__fp16 *in, __fp16 *out,
                    unsigned int batch, unsigned int channel,
                    unsigned int height, unsigned int width, unsigned int dim,
                    unsigned int from, unsigned int max_timestep,
-                   unsigned int in_size, unsigned int out_size,
-                   RunLayerContext &context) {
+                   unsigned int in_size, unsigned int out_size) {
 
   bool result = false;
   do {
-    result = context.clCreateKernel(rotary_emb_cl_kernel_fp16,
-                                    context.LayerKernel::ROTARY_EMB_FP16,
-                                    kernel_rotary_emb_fp16);
-    if (!result) {
-      printf("Failed to create kernel for rotary_emb_cl\n");
+    ClContext::SharedPtrClKernel kernel_rotaryEmb_fp16_ptr =
+      cl_context_ref.registerClKernel(rotary_emb_cl_kernel_fp16_,
+                                      "rotary_emb_cl_fp16");
+    if (!kernel_rotaryEmb_fp16_ptr) {
       break;
     }
 
@@ -104,18 +46,22 @@ void rotary_emb_cl(__fp16 *in, __fp16 *out,
     size_t dim5_size = sizeof(float) * freqs_cos_dim * dim;
     size_t dim6_size = sizeof(float) * freqs_sin_dim * dim;
 
-    opencl::Buffer inputA(context.context_inst_, dim1_size, true, nullptr);
+    opencl::Buffer inputA(cl_context_ref.context_inst_, dim1_size, true,
+                          nullptr);
 
-    opencl::Buffer inOutRes(context.context_inst_, dim2_size, true, nullptr);
+    opencl::Buffer inOutRes(cl_context_ref.context_inst_, dim2_size, true,
+                            nullptr);
 
-    opencl::Buffer cosBuf(context.context_inst_, dim3_size, true, nullptr);
+    opencl::Buffer cosBuf(cl_context_ref.context_inst_, dim3_size, true,
+                          nullptr);
 
-    opencl::Buffer sinBuf(context.context_inst_, dim4_size, true, nullptr);
+    opencl::Buffer sinBuf(cl_context_ref.context_inst_, dim4_size, true,
+                          nullptr);
 
-    opencl::Buffer freqs_cosBuf(context.context_inst_, dim5_size, true,
+    opencl::Buffer freqs_cosBuf(cl_context_ref.context_inst_, dim5_size, true,
                                 nullptr);
 
-    opencl::Buffer freqs_sinBuf(context.context_inst_, dim6_size, true,
+    opencl::Buffer freqs_sinBuf(cl_context_ref.context_inst_, dim6_size, true,
                                 nullptr);
 
     std::vector<float> freqs_cos_flat;
@@ -127,131 +73,137 @@ void rotary_emb_cl(__fp16 *in, __fp16 *out,
       freqs_sin_flat.insert(freqs_sin_flat.end(), row.begin(), row.end());
     }
 
-    result = inputA.WriteData(context.command_queue_inst_, in);
+    result = inputA.WriteData(cl_context_ref.command_queue_inst_, in);
     if (!result) {
       printf("Failed to write input data\n");
       break;
     }
 
-    result = inOutRes.WriteData(context.command_queue_inst_, out);
+    result = inOutRes.WriteData(cl_context_ref.command_queue_inst_, out);
     if (!result) {
       printf("Failed to write output data\n");
       break;
     }
 
-    result = freqs_cosBuf.WriteData(context.command_queue_inst_,
+    result = freqs_cosBuf.WriteData(cl_context_ref.command_queue_inst_,
                                     freqs_cos_flat.data());
     if (!result) {
       printf("Failed to write freqs cos data\n");
       break;
     }
 
-    result = freqs_sinBuf.WriteData(context.command_queue_inst_,
+    result = freqs_sinBuf.WriteData(cl_context_ref.command_queue_inst_,
                                     freqs_sin_flat.data());
     if (!result) {
       printf("Failed to write freqs sin data\n");
       break;
     }
 
-    result = cosBuf.WriteData(context.command_queue_inst_, cos_.data());
+    result = cosBuf.WriteData(cl_context_ref.command_queue_inst_, cos_.data());
     if (!result) {
       printf("Failed to write cos data\n");
       break;
     }
 
-    result = sinBuf.WriteData(context.command_queue_inst_, sin_.data());
+    result = sinBuf.WriteData(cl_context_ref.command_queue_inst_, sin_.data());
     if (!result) {
       printf("Failed to write sin data\n");
       break;
     }
 
     result =
-      kernel_rotary_emb_fp16.SetKernelArguments(0, &inputA, sizeof(cl_mem));
+      kernel_rotaryEmb_fp16_ptr->SetKernelArguments(0, &inputA, sizeof(cl_mem));
     if (!result) {
       printf("Failed to set inputA argument\n");
       break;
     }
 
-    result =
-      kernel_rotary_emb_fp16.SetKernelArguments(1, &inOutRes, sizeof(cl_mem));
+    result = kernel_rotaryEmb_fp16_ptr->SetKernelArguments(1, &inOutRes,
+                                                           sizeof(cl_mem));
     if (!result) {
       printf("Failed to set inOutRes argument\n");
       break;
     }
 
-    result = kernel_rotary_emb_fp16.SetKernelArguments(2, &freqs_cosBuf,
-                                                       sizeof(cl_mem));
+    result = kernel_rotaryEmb_fp16_ptr->SetKernelArguments(2, &freqs_cosBuf,
+                                                           sizeof(cl_mem));
     if (!result) {
       printf("Failed to set freqs_cosBuf argument\n");
       break;
     }
 
-    result = kernel_rotary_emb_fp16.SetKernelArguments(3, &freqs_sinBuf,
-                                                       sizeof(cl_mem));
+    result = kernel_rotaryEmb_fp16_ptr->SetKernelArguments(3, &freqs_sinBuf,
+                                                           sizeof(cl_mem));
     if (!result) {
       printf("Failed to set freqs_sinBuf argument\n");
       break;
     }
 
     result =
-      kernel_rotary_emb_fp16.SetKernelArguments(4, &cosBuf, sizeof(cl_mem));
+      kernel_rotaryEmb_fp16_ptr->SetKernelArguments(4, &cosBuf, sizeof(cl_mem));
     if (!result) {
       printf("Failed to set cosBuf argument\n");
       break;
     }
 
     result =
-      kernel_rotary_emb_fp16.SetKernelArguments(5, &sinBuf, sizeof(cl_mem));
+      kernel_rotaryEmb_fp16_ptr->SetKernelArguments(5, &sinBuf, sizeof(cl_mem));
     if (!result) {
       printf("Failed to set sinBuf argument\n");
       break;
     }
 
-    result = kernel_rotary_emb_fp16.SetKernelArguments(6, &batch, sizeof(int));
+    result =
+      kernel_rotaryEmb_fp16_ptr->SetKernelArguments(6, &batch, sizeof(int));
     if (!result) {
       printf("Failed to set batch argument\n");
       break;
     }
 
     result =
-      kernel_rotary_emb_fp16.SetKernelArguments(7, &channel, sizeof(int));
+      kernel_rotaryEmb_fp16_ptr->SetKernelArguments(7, &channel, sizeof(int));
     if (!result) {
       printf("Failed to set channel argument\n");
       break;
     }
 
-    result = kernel_rotary_emb_fp16.SetKernelArguments(8, &height, sizeof(int));
+    result =
+      kernel_rotaryEmb_fp16_ptr->SetKernelArguments(8, &height, sizeof(int));
     if (!result) {
       printf("Failed to set height argument\n");
       break;
     }
 
-    result = kernel_rotary_emb_fp16.SetKernelArguments(9, &width, sizeof(int));
+    result =
+      kernel_rotaryEmb_fp16_ptr->SetKernelArguments(9, &width, sizeof(int));
     if (!result) {
       printf("Failed to set width argument\n");
       break;
     }
 
-    result = kernel_rotary_emb_fp16.SetKernelArguments(10, &dim, sizeof(int));
+    result =
+      kernel_rotaryEmb_fp16_ptr->SetKernelArguments(10, &dim, sizeof(int));
     if (!result) {
       printf("Failed to set dim argument\n");
       break;
     }
     unsigned int half_ = dim / 2;
-    result = kernel_rotary_emb_fp16.SetKernelArguments(11, &half_, sizeof(int));
+    result =
+      kernel_rotaryEmb_fp16_ptr->SetKernelArguments(11, &half_, sizeof(int));
     if (!result) {
       printf("Failed to set half argument\n");
       break;
     }
 
-    result =
-      kernel_rotary_emb_fp16.SetKernelArguments(12, &max_timestep, sizeof(int));
+    result = kernel_rotaryEmb_fp16_ptr->SetKernelArguments(12, &max_timestep,
+                                                           sizeof(int));
     if (!result) {
       printf("Failed to set timestamp argument\n");
       break;
     }
 
-    result = kernel_rotary_emb_fp16.SetKernelArguments(13, &from, sizeof(int));
+    result =
+      kernel_rotaryEmb_fp16_ptr->SetKernelArguments(13, &from, sizeof(int));
     if (!result) {
       printf("Failed to set from argument\n");
       break;
@@ -259,14 +211,14 @@ void rotary_emb_cl(__fp16 *in, __fp16 *out,
 
     const int work_groups_count[3] = {(int)batch, (int)channel, 1};
     const int work_group_size[3] = {32, 32, 1}; // test-value
-    result = context.command_queue_inst_.DispatchCommand(
-      kernel_rotary_emb_fp16, work_groups_count, work_group_size);
+    result = cl_context_ref.command_queue_inst_.DispatchCommand(
+      kernel_rotaryEmb_fp16_ptr, work_groups_count, work_group_size);
     if (!result) {
       printf("Failed to dispatch command\n");
       break;
     }
 
-    result = inOutRes.ReadData(context.command_queue_inst_, out);
+    result = inOutRes.ReadData(cl_context_ref.command_queue_inst_, out);
     if (!result) {
       printf("Failed to read data\n");
       break;
index 3f186ec645a61010b05ffe1b363fa4e7d8a1cb53..a1b9b795bbf644ba8f50023f54912ba576b04e61 100644 (file)
@@ -9,6 +9,7 @@ cl_op_headers = [
   'blas_kernel_interface.h',
   'blas_kernel_strings.h',
   'attention_kernel_interface.h',
+  'attention_kernel_strings.h',
 ]
 
 if get_option('enable-fp16')
index d2a26cc9d3e489eefb5aabc972b32e2f08e3a2ad..a95937446da7379f2c2ad80ae51b51f6567ac330 100644 (file)
 
 using namespace nntrainer;
 
-static RunLayerContext setUpGpuContext() {
-
+static void setUpGpuContext() {
   auto &ac = nntrainer::ClContext::Global();
-  auto rc = RunLayerContext();
-
-  return rc;
+  ac.initAttentionClKernels();
 }
 
 TEST(attention_kernels, rotary_emb_kernel_FP32) {
-  RunLayerContext rc = setUpGpuContext();
+  setUpGpuContext();
 
   int batch = 1;
   int channel = 1;
@@ -65,7 +62,7 @@ TEST(attention_kernels, rotary_emb_kernel_FP32) {
 
   B_fp32.copy(A_fp32);
 
-  apply_rotary_emb_cl(A_fp32, dim, from, max_timestep, rc);
+  apply_rotary_emb_cl(A_fp32, dim, from, max_timestep);
   apply_rotary_emb_tensor(B_fp32, dim, from, max_timestep);
 
   float mseErrorNeon_fp32 =
@@ -81,7 +78,7 @@ TEST(attention_kernels, rotary_emb_kernel_FP32) {
 }
 
 TEST(attention_kernels, rotary_emb_kernel_FP32_case2) {
-  RunLayerContext rc = setUpGpuContext();
+  setUpGpuContext();
 
   int batch = 4;
   int channel = 4;
@@ -108,7 +105,7 @@ TEST(attention_kernels, rotary_emb_kernel_FP32_case2) {
 
   B_fp32.copy(A_fp32);
 
-  apply_rotary_emb_cl(A_fp32, dim, from, max_timestep, rc);
+  apply_rotary_emb_cl(A_fp32, dim, from, max_timestep);
   apply_rotary_emb_tensor(B_fp32, dim, from, max_timestep);
 
   float mseErrorNeon_fp32 =
@@ -124,7 +121,7 @@ TEST(attention_kernels, rotary_emb_kernel_FP32_case2) {
 }
 
 TEST(attention_kernels, rotary_emb_kernel_FP16) {
-  RunLayerContext rc = setUpGpuContext();
+  setUpGpuContext();
 
   int batch = 1;
   int channel = 1;
@@ -150,7 +147,7 @@ TEST(attention_kernels, rotary_emb_kernel_FP16) {
 
   B_fp16.copy(A_fp16);
 
-  apply_rotary_emb_cl(A_fp16, dim, from, max_timestep, rc);
+  apply_rotary_emb_cl(A_fp16, dim, from, max_timestep);
   apply_rotary_emb_tensor(B_fp16, dim, from, max_timestep);
 
   float mseErrorNeon_fp16 = mse<__fp16>(
@@ -166,7 +163,7 @@ TEST(attention_kernels, rotary_emb_kernel_FP16) {
 }
 
 TEST(attention_kernels, rotary_emb_kernel_FP16_case2) {
-  RunLayerContext rc = setUpGpuContext();
+  setUpGpuContext();
 
   int batch = 4;
   int channel = 4;
@@ -192,7 +189,7 @@ TEST(attention_kernels, rotary_emb_kernel_FP16_case2) {
 
   B_fp16.copy(A_fp16);
 
-  apply_rotary_emb_cl(A_fp16, dim, from, max_timestep, rc);
+  apply_rotary_emb_cl(A_fp16, dim, from, max_timestep);
   apply_rotary_emb_tensor(B_fp16, dim, from, max_timestep);
 
   float mseErrorNeon_fp16 = mse<__fp16>(