[GPU/OPENCL] RMSNorm Accuracy Fix
authorThummala Pallavi <t.pallavi@samsung.com>
Fri, 2 Aug 2024 06:41:34 +0000 (12:11 +0530)
committerJijoong Moon <jijoong.moon@samsung.com>
Mon, 19 Aug 2024 07:28:18 +0000 (16:28 +0900)
The alpha values were not picked correctly.

Signed-off-by: Thummala Pallavi <t.pallavi@samsung.com>
nntrainer/layers/cl_layers/rmsnorm_layer_cl.cpp

index 0dd1f15c852aa58fd993dd144bededea6a045007..96e9a530693234134d5cdb4a85801c253555134c 100644 (file)
@@ -26,7 +26,7 @@ std::string rmsnorm_cl_kernel_fp16_ =
     __kernel void rmsnorm_cl_fp16(
     __global const half *input,  // Input tensor
     __global half *output,    // Output tensor
-    __global const half *alpha,  // Alpha values (one for each channel)
+    __global const half *alpha,  // Alpha values (one for each width)
     half epsilon,
     int B,                  // Number of batches
     int C,                  // Number of channels
@@ -50,7 +50,7 @@ std::string rmsnorm_cl_kernel_fp16_ =
     half rms_norm = sqrt(sum_squares + epsilon);
     // Each work item processes all width elements for its specific n, h, c
     for (int w = 0; w < W; ++w) {
-        output[index+w] = (input[index+w] / rms_norm) * alpha[c];
+        output[index+w] = (input[index+w] / rms_norm) * alpha[index+w];
     }
 }
 )";
@@ -59,7 +59,7 @@ std::string rmsnorm_cl_kernel_ =
   R"(__kernel void rmsnorm_cl(
     __global const float *input,  // Input tensor
     __global float *output,    // Output tensor
-    __global const float *alpha,  // Alpha values (one for each channel)
+    __global const float *alpha,  // Alpha values (one for each width)
     float epsilon,
     int B,                  // Number of batches
     int C,                  // Number of channels
@@ -80,7 +80,7 @@ std::string rmsnorm_cl_kernel_ =
     float rms_norm = sqrt(sum_squares + epsilon);
     // Each work item processes all width elements for its specific n, h, c
     for (int w = 0; w < W; ++w) {
-        output[index+w] = (input[index+w] / rms_norm) * alpha[c];
+        output[index+w] = (input[index+w] / rms_norm) * alpha[index+w];
     }
 }
 )";
@@ -114,7 +114,7 @@ void RMSNormLayerCl::forwarding(RunLayerContext &context, bool training) {
   auto &epsilon = std::get<props::Epsilon>(rmsnorm_props).get();
   if (in.getDataType() == ml::train::TensorDim::DataType::FP32) {
     rmsnormProcess(in, out, gamma, epsilon, context);
-  } else{
+  } else {
     rmsnormProcess_fp16(in, out, gamma, epsilon, context);
   }
 }
@@ -276,14 +276,14 @@ void RMSNormLayerCl::rmsnormProcess_fp16(Tensor const &input, Tensor &result,
     if (!ret) {
       break;
     }
-    ret = RMSNormLayerCl::kernel_rmsnorm_fp16.SetKernelArguments(
-      4, &b, sizeof(int));
+    ret = RMSNormLayerCl::kernel_rmsnorm_fp16.SetKernelArguments(4, &b,
+                                                                 sizeof(int));
     if (!ret) {
       break;
     }
 
-    ret = RMSNormLayerCl::kernel_rmsnorm_fp16.SetKernelArguments(3, &epsilon,
-                                                                 sizeof(cl_half));
+    ret = RMSNormLayerCl::kernel_rmsnorm_fp16.SetKernelArguments(
+      3, &epsilon, sizeof(cl_half));
     if (!ret) {
       break;
     }
@@ -317,12 +317,11 @@ void RMSNormLayerCl::rmsnormProcess_fp16(Tensor const &input, Tensor &result,
       break;
     }
   } while (false);
-
 }
 
 void RMSNormLayerCl::incremental_forwarding(nntrainer::RunLayerContext &context,
-                                          unsigned int from, unsigned int to,
-                                          bool training) {
+                                            unsigned int from, unsigned int to,
+                                            bool training) {
   Tensor &in = context.getInput(SINGLE_INOUT_IDX);
   Tensor &out = context.getOutput(SINGLE_INOUT_IDX);
   Tensor &gamma = context.getWeight(wt_idx[RMSParams::gamma]);
@@ -374,4 +373,3 @@ void RMSNormLayerCl::setProperty(const std::vector<std::string> &values) {
 }
 
 } // namespace nntrainer
-