[GPU/OpenCL] RMSNorm Bug Fix - Corrected the index value of alpha in the kernel logic.
author    Niket Agarwal <niket.a@samsung.com>
Thu, 10 Oct 2024 11:00:58 +0000 (16:30 +0530)
committer Jijoong Moon <jijoong.moon@samsung.com>
Mon, 14 Oct 2024 11:21:21 +0000 (20:21 +0900)
Fixed the alpha (gamma) indexing in the RMSNorm OpenCL kernels: the gamma buffer
holds only W elements, so it is now indexed by w instead of index + w.
Also moved RMSNorm to the new shared_ptr kernel flow, replacing clCreateKernel
with registerClKernel, and re-enabled the layer registration and build.
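
For reference, a minimal sketch of the corrected per-row loop in the FP32 kernel
(names follow the kernel source in this patch); because alpha holds only W
elements, it must be indexed by w alone:

    // index points at the first element of this work item's (n, c, h) row;
    // alpha (gamma) has W entries, so indexing it with index + w read past
    // the end of the buffer.
    float rms_norm = sqrt(sum_squares + epsilon);
    for (int w = 0; w < W; ++w) {
        output[index + w] = (input[index + w] / rms_norm) * alpha[w];
    }

The FP16 kernel applies the same fix with half in place of float.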

Self evaluation:

  Build test: [X]Passed [ ]Failed [ ]Skipped
  Run test:   [X]Passed [ ]Failed [ ]Skipped

Signed-off-by: Niket Agarwal <niket.a@samsung.com>
nntrainer/cl_context.cpp
nntrainer/layers/cl_layers/meson.build
nntrainer/layers/cl_layers/rmsnorm_layer_cl.cpp
nntrainer/layers/cl_layers/rmsnorm_layer_cl.h

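For context, the new host-side flow used by rmsnormProcess is sketched below
(a condensed sketch of the pattern in this patch; the do/break error-handling
chain and most of the kernel-argument calls are elided):

    // Obtain the kernel as a shared_ptr from the context-wide registry
    // instead of creating it per layer via clCreateKernel.
    ClContext::SharedPtrClKernel kernel_rmsnorm_ptr =
      cl_context_ref.registerClKernel(rmsnorm_cl_kernel_, "rmsnorm_cl");
    if (!kernel_rmsnorm_ptr) {
      break; // registration failed
    }

    // Buffers and the command queue now come from the ClContext singleton
    // rather than the RunLayerContext that was previously passed in.
    opencl::Buffer gammabuf(cl_context_ref.context_inst_,
                            input.width() * sizeof(float), true, nullptr);
    ret = kernel_rmsnorm_ptr->SetKernelArguments(2, &gammabuf, sizeof(cl_mem));

    ret = cl_context_ref.command_queue_inst_.DispatchCommand(
      kernel_rmsnorm_ptr, work_groups_count, work_group_size);

The FP16 path follows the same structure with "rmsnorm_cl_fp16" and cl_half
buffers, and is now compiled only when ENABLE_FP16 is defined.
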
index 818a77e0d39efee40ae3f0263540d8c6470ec809..c18a02ce5d8e55e33e6bd336524109359c1b11e4 100644 (file)
@@ -45,9 +45,8 @@ static void add_default_object(ClContext &cc) {
   cc.registerFactory(nntrainer::createLayer<ReshapeLayerCl>,
                      ReshapeLayerCl::type, ml::train::LayerType::LAYER_RESHAPE);
 
-  // cc.registerFactory(nntrainer::createLayer<RMSNormLayerCl>,
-  //                    RMSNormLayerCl::type,
-  //                    ml::train::LayerType::LAYER_RMSNORM);
+  cc.registerFactory(nntrainer::createLayer<RMSNormLayerCl>,
+                     RMSNormLayerCl::type, ml::train::LayerType::LAYER_RMSNORM);
 
   cc.registerFactory(nntrainer::createLayer<ConcatLayerCl>, ConcatLayerCl::type,
                      ml::train::LayerType::LAYER_CONCAT);
index c75328f69aa44e708c9636bd96136602f1c7c589..fbfd46961baf1128d64da05ced94dc23ebcddd67 100644 (file)
@@ -3,7 +3,7 @@ cl_layer_sources = [
   # 'addition_layer_cl.cpp',
    'swiglu_cl.cpp',
    'reshape_cl.cpp',
-  # 'rmsnorm_layer_cl.cpp',
+   'rmsnorm_layer_cl.cpp',
    'concat_cl.cpp',
 ]
 
index 96e9a530693234134d5cdb4a85801c253555134c..179b89fa8a3ad398ae25855212c42e7e1ccddb1a 100644 (file)
@@ -50,8 +50,8 @@ std::string rmsnorm_cl_kernel_fp16_ =
     half rms_norm = sqrt(sum_squares + epsilon);
     // Each work item processes all width elements for its specific n, h, c
     for (int w = 0; w < W; ++w) {
-        output[index+w] = (input[index+w] / rms_norm) * alpha[index+w];
-    }
+        output[index+w] = (input[index+w] / rms_norm) * alpha[w];
+    } 
 }
 )";
 
@@ -80,7 +80,7 @@ std::string rmsnorm_cl_kernel_ =
     float rms_norm = sqrt(sum_squares + epsilon);
     // Each work item processes all width elements for its specific n, h, c
     for (int w = 0; w < W; ++w) {
-        output[index+w] = (input[index+w] / rms_norm) * alpha[index+w];
+        output[index+w] = (input[index+w] / rms_norm) * alpha[w];
     }
 }
 )";
@@ -113,9 +113,13 @@ void RMSNormLayerCl::forwarding(RunLayerContext &context, bool training) {
   Tensor &gamma = context.getWeight(wt_idx[RMSParams::gamma]);
   auto &epsilon = std::get<props::Epsilon>(rmsnorm_props).get();
   if (in.getDataType() == ml::train::TensorDim::DataType::FP32) {
-    rmsnormProcess(in, out, gamma, epsilon, context);
+    rmsnormProcess(in, out, gamma, epsilon);
   } else {
-    rmsnormProcess_fp16(in, out, gamma, epsilon, context);
+#ifdef ENABLE_FP16
+    rmsnormProcess_fp16(in, out, gamma, epsilon);
+#else
+    throw std::invalid_argument("Error: enable-fp16 is not enabled");
+#endif
   }
 }
 
@@ -123,8 +127,7 @@ opencl::Kernel RMSNormLayerCl::kernel_rmsnorm;
 opencl::Kernel RMSNormLayerCl::kernel_rmsnorm_fp16;
 
 void RMSNormLayerCl::rmsnormProcess(Tensor const &input, Tensor &result,
-                                    Tensor const &gamma, const float epsilon,
-                                    RunLayerContext &context) {
+                                    Tensor const &gamma, const float epsilon) {
   bool ret = false;
   int dim1 = input.batch() * input.height() * input.width() * input.channel();
   CREATE_IF_EMPTY_DIMS(result, input.batch(), input.channel(), input.height(),
@@ -133,86 +136,82 @@ void RMSNormLayerCl::rmsnormProcess(Tensor const &input, Tensor &result,
   int c = input.channel();
   int h = input.height();
   int w = input.width();
+
   do {
-    ret =
-      context.clCreateKernel(rmsnorm_cl_kernel_, context.LayerKernel::RMSNORM,
-                             RMSNormLayerCl::kernel_rmsnorm);
-    if (!ret) {
+    ClContext::SharedPtrClKernel kernel_rmsnorm_ptr =
+      cl_context_ref.registerClKernel(rmsnorm_cl_kernel_, "rmsnorm_cl");
+    if (!kernel_rmsnorm_ptr) {
       break;
     }
 
-    opencl::Buffer inputbuf(context.context_inst_, dim1 * sizeof(float), true,
-                            nullptr);
+    opencl::Buffer inputbuf(cl_context_ref.context_inst_, dim1 * sizeof(float),
+                            true, nullptr);
 
-    opencl::Buffer gammabuf(context.context_inst_,
+    opencl::Buffer gammabuf(cl_context_ref.context_inst_,
                             input.width() * sizeof(float), true, nullptr);
-    opencl::Buffer resultbuf(context.context_inst_, dim1 * sizeof(float), true,
-                             nullptr);
+    opencl::Buffer resultbuf(cl_context_ref.context_inst_, dim1 * sizeof(float),
+                             true, nullptr);
 
     const float *data = input.getData();
     float *rdata = result.getData();
     const float *gdata = gamma.getData();
-    ret = inputbuf.WriteData(context.command_queue_inst_, data);
+    ret = inputbuf.WriteData(cl_context_ref.command_queue_inst_, data);
     if (!ret) {
       break;
     }
 
-    ret = gammabuf.WriteData(context.command_queue_inst_, gdata);
+    ret = gammabuf.WriteData(cl_context_ref.command_queue_inst_, gdata);
     if (!ret) {
       break;
     }
-    ret = RMSNormLayerCl::kernel_rmsnorm.SetKernelArguments(0, &inputbuf,
-                                                            sizeof(cl_mem));
+    ret = kernel_rmsnorm_ptr->SetKernelArguments(0, &inputbuf, sizeof(cl_mem));
     if (!ret) {
       break;
     }
 
-    ret = RMSNormLayerCl::kernel_rmsnorm.SetKernelArguments(1, &resultbuf,
-                                                            sizeof(cl_mem));
+    ret = kernel_rmsnorm_ptr->SetKernelArguments(1, &resultbuf, sizeof(cl_mem));
     if (!ret) {
       break;
     }
 
-    ret = RMSNormLayerCl::kernel_rmsnorm.SetKernelArguments(2, &gammabuf,
-                                                            sizeof(cl_mem));
+    ret = kernel_rmsnorm_ptr->SetKernelArguments(2, &gammabuf, sizeof(cl_mem));
     if (!ret) {
       break;
     }
-    ret = RMSNormLayerCl::kernel_rmsnorm.SetKernelArguments(4, &b, sizeof(int));
+    ret = kernel_rmsnorm_ptr->SetKernelArguments(4, &b, sizeof(int));
 
     if (!ret) {
       break;
     }
 
-    ret = RMSNormLayerCl::kernel_rmsnorm.SetKernelArguments(3, &epsilon,
-                                                            sizeof(float));
+    ret = kernel_rmsnorm_ptr->SetKernelArguments(3, &epsilon, sizeof(float));
     if (!ret) {
       break;
     }
 
-    ret = RMSNormLayerCl::kernel_rmsnorm.SetKernelArguments(5, &c, sizeof(int));
+    ret = kernel_rmsnorm_ptr->SetKernelArguments(5, &c, sizeof(int));
     if (!ret) {
       break;
     }
 
-    ret = RMSNormLayerCl::kernel_rmsnorm.SetKernelArguments(6, &h, sizeof(int));
+    ret = kernel_rmsnorm_ptr->SetKernelArguments(6, &h, sizeof(int));
     if (!ret) {
       break;
     }
-    ret = RMSNormLayerCl::kernel_rmsnorm.SetKernelArguments(7, &w, sizeof(int));
+    ret = kernel_rmsnorm_ptr->SetKernelArguments(7, &w, sizeof(int));
     if (!ret) {
       break;
     }
     const int work_groups_count[3] = {b * c, h, 1};
     const int work_group_size[3] = {32, 32, 1}; // test-value
 
-    ret = context.command_queue_inst_.DispatchCommand(
-      RMSNormLayerCl::kernel_rmsnorm, work_groups_count, work_group_size);
+    ret = cl_context_ref.command_queue_inst_.DispatchCommand(
+      kernel_rmsnorm_ptr, work_groups_count, work_group_size);
     if (!ret) {
       break;
     }
 
-    ret = resultbuf.ReadData(context.command_queue_inst_, rdata);
+    ret = resultbuf.ReadData(cl_context_ref.command_queue_inst_, rdata);
     if (!ret) {
       break;
     }
@@ -222,8 +221,7 @@ void RMSNormLayerCl::rmsnormProcess(Tensor const &input, Tensor &result,
 
 void RMSNormLayerCl::rmsnormProcess_fp16(Tensor const &input, Tensor &result,
                                          Tensor const &gamma,
-                                         const float epsilon,
-                                         RunLayerContext &context) {
+                                         const float epsilon) {
 
   bool ret = false;
   int dim1 = input.batch() * input.height() * input.width() * input.channel();
@@ -234,85 +232,77 @@ void RMSNormLayerCl::rmsnormProcess_fp16(Tensor const &input, Tensor &result,
   int h = input.height();
   int w = input.width();
   do {
-    ret = context.clCreateKernel(rmsnorm_cl_kernel_fp16_,
-                                 context.LayerKernel::RMSNORM_FP16,
-                                 RMSNormLayerCl::kernel_rmsnorm_fp16);
-    if (!ret) {
+    ClContext::SharedPtrClKernel kernel_rmsnorm_ptr =
+      cl_context_ref.registerClKernel(rmsnorm_cl_kernel_fp16_,
+                                      "rmsnorm_cl_fp16");
+    if (!kernel_rmsnorm_ptr) {
       break;
     }
-    opencl::Buffer inputbuf(context.context_inst_, dim1 * sizeof(cl_half), true,
-                            nullptr);
+    opencl::Buffer inputbuf(cl_context_ref.context_inst_,
+                            dim1 * sizeof(cl_half), true, nullptr);
 
-    opencl::Buffer gammabuf(context.context_inst_,
+    opencl::Buffer gammabuf(cl_context_ref.context_inst_,
                             input.width() * sizeof(cl_half), true, nullptr);
-    opencl::Buffer resultbuf(context.context_inst_, dim1 * sizeof(cl_half),
-                             true, nullptr);
+    opencl::Buffer resultbuf(cl_context_ref.context_inst_,
+                             dim1 * sizeof(cl_half), true, nullptr);
 
     const __fp16 *data = input.getData<__fp16>();
     __fp16 *rdata = result.getData<__fp16>();
     const __fp16 *gdata = gamma.getData<__fp16>();
-    ret = inputbuf.WriteData(context.command_queue_inst_, data);
+    ret = inputbuf.WriteData(cl_context_ref.command_queue_inst_, data);
     if (!ret) {
       break;
     }
 
-    ret = gammabuf.WriteData(context.command_queue_inst_, gdata);
+    ret = gammabuf.WriteData(cl_context_ref.command_queue_inst_, gdata);
     if (!ret) {
       break;
     }
-    ret = RMSNormLayerCl::kernel_rmsnorm_fp16.SetKernelArguments(
-      0, &inputbuf, sizeof(cl_mem));
+    ret = kernel_rmsnorm_ptr->SetKernelArguments(0, &inputbuf, sizeof(cl_mem));
     if (!ret) {
       break;
     }
-    ret = RMSNormLayerCl::kernel_rmsnorm_fp16.SetKernelArguments(
-      1, &resultbuf, sizeof(cl_mem));
+    ret = kernel_rmsnorm_ptr->SetKernelArguments(1, &resultbuf, sizeof(cl_mem));
     if (!ret) {
       break;
     }
 
-    ret = RMSNormLayerCl::kernel_rmsnorm_fp16.SetKernelArguments(
-      2, &gammabuf, sizeof(cl_mem));
+    ret = kernel_rmsnorm_ptr->SetKernelArguments(2, &gammabuf, sizeof(cl_mem));
     if (!ret) {
       break;
     }
-    ret = RMSNormLayerCl::kernel_rmsnorm_fp16.SetKernelArguments(4, &b,
-                                                                 sizeof(int));
+    ret = kernel_rmsnorm_ptr->SetKernelArguments(4, &b, sizeof(int));
     if (!ret) {
       break;
     }
 
-    ret = RMSNormLayerCl::kernel_rmsnorm_fp16.SetKernelArguments(
-      3, &epsilon, sizeof(cl_half));
+    ret = kernel_rmsnorm_ptr->SetKernelArguments(3, &epsilon, sizeof(cl_half));
     if (!ret) {
       break;
     }
 
-    ret = RMSNormLayerCl::kernel_rmsnorm_fp16.SetKernelArguments(5, &c,
-                                                                 sizeof(int));
+    ret = kernel_rmsnorm_ptr->SetKernelArguments(5, &c, sizeof(int));
     if (!ret) {
       break;
     }
-    ret = RMSNormLayerCl::kernel_rmsnorm_fp16.SetKernelArguments(6, &h,
-                                                                 sizeof(int));
+    ret = kernel_rmsnorm_ptr->SetKernelArguments(6, &h, sizeof(int));
     if (!ret) {
       break;
     }
-    ret = RMSNormLayerCl::kernel_rmsnorm_fp16.SetKernelArguments(7, &w,
-                                                                 sizeof(int));
+    ret = kernel_rmsnorm_ptr->SetKernelArguments(7, &w, sizeof(int));
     if (!ret) {
       break;
     }
     const int work_groups_count[3] = {b * c, h, 1};
     const int work_group_size[3] = {32, 32, 1}; // test-value
 
-    ret = context.command_queue_inst_.DispatchCommand(
-      RMSNormLayerCl::kernel_rmsnorm_fp16, work_groups_count, work_group_size);
+    ret = cl_context_ref.command_queue_inst_.DispatchCommand(
+      kernel_rmsnorm_ptr, work_groups_count, work_group_size);
     if (!ret) {
       break;
     }
 
-    ret = resultbuf.ReadData(context.command_queue_inst_, rdata);
+    ret = resultbuf.ReadData(cl_context_ref.command_queue_inst_, rdata);
     if (!ret) {
       break;
     }
@@ -347,9 +337,9 @@ void RMSNormLayerCl::incremental_forwarding(nntrainer::RunLayerContext &context,
   auto &epsilon = std::get<props::Epsilon>(rmsnorm_props).get();
 
   if (in_step.getDataType() == ml::train::TensorDim::DataType::FP32) {
-    rmsnormProcess(in, out, gamma, epsilon, context);
+    rmsnormProcess(in, out, gamma, epsilon);
   } else {
-    rmsnormProcess_fp16(in, out, gamma, epsilon, context);
+    rmsnormProcess_fp16(in, out, gamma, epsilon);
   }
 }
 
index 4b34729409fb2c9d722042946e130df26cfaf82d..43f942ea1e47e721408c7eae2b65d67aae644605 100644 (file)
@@ -19,6 +19,7 @@
 #include <layer_impl.h>
 #include <nntrainer_log.h>
 
+#include <cl_context.h>
 #include <opencl_buffer.h>
 #include <opencl_kernel.h>
 
@@ -49,7 +50,12 @@ public:
  * @class   RMSNormLayer
  * @brief   RMS Norm layer
  */
+
 class RMSNormLayerCl : public LayerImpl {
+
+private:
+  inline static ClContext cl_context_ref;
+
 public:
   /**
    * @brief     Constructor of RMS Norm Layer
@@ -84,9 +90,9 @@ public:
   void forwarding(RunLayerContext &context, bool training) override;
 
   /**
-   * @copydoc Layer::incremental_forwarding(RunLayerContext &context, unsigned
-   * int from, unsigned int to, bool training)
-   */
+   * @copydoc Layer::incremental_forwarding(RunLayerContext &context, unsigned
+   * int from, unsigned int to, bool training)
+   */
   void incremental_forwarding(RunLayerContext &context, unsigned int from,
                               unsigned int to, bool training) override;
 
@@ -121,24 +127,22 @@ public:
    * @param[in] result Tensor
    * @param[in] gamma Tensor
    * @param[in] epsilon float
-   * @param[in] RunLayerContext reference
    */
 
   void rmsnormProcess(Tensor const &input, Tensor &result, Tensor const &gamma,
-                      const float epsilon, RunLayerContext &context);
-
+                      const float epsilon);
+#ifdef ENABLE_FP16
   /**
    * @brief Process data and dimensions for FP16 rms norm operation
    * @param[in] input Tensor
    * @param[in] result Tensor
    * @param[in] gamma Tensor
    * @param[in] epsilon float
-   * @param[in] RunLayerContext reference
    */
 
   void rmsnormProcess_fp16(Tensor const &input, Tensor &result,
-                           Tensor const &gamma, const float epsilon,
-                           RunLayerContext &context);
+                           Tensor const &gamma, const float epsilon);
+#endif
   /**
    * @copydoc Layer::supportBackwarding()
    */