[GPU/OpenCL] Initial version of RMSNorm Layer
author    ThummalaPallavi <t.pallavi@samsung.com>
Tue, 11 Jun 2024 09:24:49 +0000 (14:54 +0530)
committer Jijoong Moon <jijoong.moon@samsung.com>
Wed, 10 Jul 2024 04:36:32 +0000 (13:36 +0900)
Added a naive OpenCL implementation of the RMSNorm layer.
Incorporated the OpenCL kernels for the ops used.
Added a unit test for rmsnorm_layer_cl.

Signed-off-by: ThummalaPallavi <t.pallavi@samsung.com>
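
For reference, the layer computes the standard RMS normalization over the
width axis, y = x / sqrt(mean(x^2) + epsilon) * gamma, which is what both
OpenCL kernels below implement for each (batch, channel, height) row. As a
minimal, illustrative sketch (not part of this change), the layer can be
created through the RMSNormCl helper added to api/ccapi/include/layer.h; the
property values here are example assumptions, while the keys are the ones the
layer handles:

    #include <layer.h>  // ccapi header extended by this commit
    #include <memory>

    // Illustrative only: construct the GPU RMSNorm layer via the new helper.
    // "epsilon" maps to props::Epsilon, "gamma_initializer" to the new
    // RMS_NORM_GAMMA_INIT_GPU property (ones by default).
    std::unique_ptr<ml::train::Layer> makeRmsNormGpu() {
      return ml::train::layer::RMSNormCl(
        {"epsilon=0.001", "gamma_initializer=ones"});
    }
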
api/ccapi/include/layer.h
api/nntrainer-api-common.h
nntrainer/cl_context.cpp
nntrainer/layers/cl_layers/meson.build
nntrainer/layers/cl_layers/rmsnorm_layer_cl.cpp [new file with mode: 0644]
nntrainer/layers/cl_layers/rmsnorm_layer_cl.h [new file with mode: 0644]
nntrainer/layers/layer_context.cpp
nntrainer/layers/layer_context.h
test/input_gen/gen_layer_tests.py
test/jni/Android.mk
test/unittest/layers/unittest_layers_rmsnorm_cl.cpp [new file with mode: 0644]

index d2ae90c872dffe95084d2e5c246bdbaa9ae46d01..5b09216a5fa32410510716b0109783ff2886153c 100644 (file)
@@ -101,6 +101,7 @@ enum LayerType {
   LAYER_LOSS_CONSTANT_DERIVATIVE, /**< Synthetic loss layer to feed constant
                                      derivative */
   LAYER_UPSAMPLE2D,               /**< Upsample 2D Layer type */
+  LAYER_RMSNORM = ML_TRAIN_LAYER_TYPE_RMSNORM, /**< RMS Norm Layer type */
   LAYER_UNKNOWN = ML_TRAIN_LAYER_TYPE_UNKNOWN /**< Unknown */
 };
 
@@ -307,6 +308,15 @@ Swiglu(const std::vector<std::string> &properties = {},
   return createLayer(LayerType::LAYER_SWIGLU, properties, compute_engine);
 }
 
+/**
+ * @brief Helper function to create RMS normalization layer for GPU
+ */
+inline std::unique_ptr<Layer> RMSNormCl(
+  const std::vector<std::string> &properties = {},
+  const LayerComputeEngine &compute_engine = LayerComputeEngine::GPU) {
+  return createLayer(LayerType::LAYER_RMSNORM, properties, compute_engine);
+}
+
 /**
  * @brief Helper function to create batch normalization layer
  */
index 4c762150cc10a99243a4c627c7fd1a98a323563d..76d9976f3b50597f1a89edf788e5bb0d6f4e18ac 100644 (file)
@@ -76,6 +76,7 @@ typedef enum {
                                        Sigmoid Loss Layer type (Since 6.5) */
   ML_TRAIN_LAYER_TYPE_LOSS_CROSS_ENTROPY_SOFTMAX = 502, /**< Cross Entropy with
                                        Softmax Loss Layer type (Since 6.5) */
+  ML_TRAIN_LAYER_TYPE_RMSNORM = 503, /**< RMS Norm Layer type */
   ML_TRAIN_LAYER_TYPE_UNKNOWN = 999                     /**< Unknown Layer */
 } ml_train_layer_type_e;
 
index 1c9a32779a3bb2fe680d6d37ecf338e73e92df57..2ba0a390d39bd98c89ad9ca40bdef48f144606a6 100644 (file)
@@ -7,6 +7,7 @@
  * @see     https://github.com/nnstreamer/nntrainer
  * @author  Debadri Samaddar <s.debadri@samsung.com>
  * @author  Niket Agarwal <niket.a@samsung.com>
+ * @author  Thummala Pallavi <t.pallavi@samsung.com>
  * @bug     No known bugs except for NYI items
  * @brief   This file contains app context related functions and classes that
  * manages the global configuration of the current OpenCL environment. It also
@@ -18,6 +19,7 @@
 #include <fc_layer_cl.h>
 #include <reshape_cl.h>
 #include <swiglu_cl.h>
+#include <rmsnorm_layer_cl.h>
 
 namespace nntrainer {
 
@@ -40,6 +42,9 @@ static void add_default_object(ClContext &cc) {
 
   cc.registerFactory(nntrainer::createLayer<ReshapeLayerCl>,
                      ReshapeLayerCl::type, ml::train::LayerType::LAYER_RESHAPE);
+
+  cc.registerFactory(nntrainer::createLayer<RMSNormLayerCl>,
+                     RMSNormLayerCl::type, ml::train::LayerType::LAYER_RMSNORM);
 }
 
 static void registerer(ClContext &cc) noexcept {
index aa30060a5002ef950330af193d25586388a42629..634a2d8fb56b1d04c3444ce5b19fac1770d362d1 100644 (file)
@@ -3,6 +3,7 @@ cl_layer_sources = [
   'addition_layer_cl.cpp',
   'swiglu_cl.cpp',
   'reshape_cl.cpp',
+  'rmsnorm_layer_cl.cpp'
 ]
 
 foreach s : cl_layer_sources
diff --git a/nntrainer/layers/cl_layers/rmsnorm_layer_cl.cpp b/nntrainer/layers/cl_layers/rmsnorm_layer_cl.cpp
new file mode 100644 (file)
index 0000000..0dd1f15
--- /dev/null
@@ -0,0 +1,377 @@
+// SPDX-License-Identifier: Apache-2.0
+/**
+ * Copyright (C) 2024 Thummala Pallavi <t.pallavi@samsung.com>
+ *
+ * @file        rmsnorm_layer_cl.cpp
+ * @date        8 June 2024
+ * @brief       This is RMSNorm Layer Class for Neural Network with
+ * OpenCL implementation
+ * @see         https://github.com/nnstreamer/nntrainer
+ * @author      Thummala Pallavi <t.pallavi@samsung.com>
+ * @bug         No known bugs except for NYI items
+ *
+ */
+
+#include <common_properties.h>
+#include <layer_context.h>
+#include <lazy_tensor.h>
+#include <nntrainer_error.h>
+#include <node_exporter.h>
+#include <rmsnorm_layer_cl.h>
+#include <util_func.h>
+
+std::string rmsnorm_cl_kernel_fp16_ =
+  R"(
+    #pragma OPENCL EXTENSION cl_khr_fp16 : enable
+    __kernel void rmsnorm_cl_fp16(
+    __global const half *input,  // Input tensor
+    __global half *output,    // Output tensor
+    __global const half *alpha,  // Alpha values (one for each channel)
+    float epsilon,
+    int B,                  // Number of batches
+    int C,                  // Number of channels
+    int H,                  // Height of feature map
+    int W                   // Width of feature map
+) {
+    int global_id = get_global_id(0);  // Get the global work item index
+
+    // Compute the corresponding batch, height, and channel indices
+    int n = global_id / C;             // Batch index
+    int c = global_id % C;             // Channel index
+    int h = get_global_id(1);          // Height index
+    int index = ((n * C + c) * H + h) * W;
+
+    // Calculate RMS norm for the current channel, height, and batch
+    half sum_squares = 0.0f;
+    for (int j = 0; j < W; ++j) {
+        sum_squares += input[index+j] * input[index+j];
+    }
+    sum_squares /= W;
+    half rms_norm = sqrt(sum_squares + epsilon);
+    // Each work item processes all width elements for its specific n, h, c
+    for (int w = 0; w < W; ++w) {
+        output[index+w] = (input[index+w] / rms_norm) * alpha[c];
+    }
+}
+)";
+
+std::string rmsnorm_cl_kernel_ =
+  R"(__kernel void rmsnorm_cl(
+    __global const float *input,  // Input tensor
+    __global float *output,    // Output tensor
+    __global const float *alpha,  // Alpha values (one for each channel)
+    float epsilon,
+    int B,                  // Number of batches
+    int C,                  // Number of channels
+    int H,                  // Height of feature map
+    int W                   // Width of feature map
+) {
+    // Compute the corresponding batch, height, and channel indices
+    int n = get_global_id(0) / C;
+    int c = get_global_id(0) % C;
+    int h = get_global_id(1);
+    int index = ((n * C + c) * H + h) * W;
+    // Calculate RMS norm for the current channel, height, and batch
+    float sum_squares = 0.0f;
+    for (int j = 0; j < W; ++j) {
+        sum_squares += input[index+j] * input[index+j];
+    }
+    sum_squares /= W;
+    float rms_norm = sqrt(sum_squares + epsilon);
+    // Each work item processes all width elements for its specific n, h, c
+    for (int w = 0; w < W; ++w) {
+        output[index+w] = (input[index+w] / rms_norm) * alpha[c];
+    }
+}
+)";
+
+namespace nntrainer {
+
+static constexpr size_t SINGLE_INOUT_IDX = 0;
+
+enum RMSParams { gamma };
+
+RMSNormLayerCl::RMSNormLayerCl() : LayerImpl() { wt_idx.fill(0); }
+
+void RMSNormLayerCl::finalize(InitLayerContext &context) {
+  std::vector<TensorDim> dim = context.getInputDimensions();
+  context.setOutputDimensions(dim);
+  auto &rmsparams_gamma =
+    std::get<props::RMS_NORM_GAMMA_INIT_GPU>(rmsnorm_props);
+
+  TensorDim gamma_dim(
+    1, 1, 1, dim[0].width(),
+    TensorDim::TensorType(context.getFormat(), context.getWeightDataType()));
+  wt_idx[RMSParams::gamma] =
+    context.requestWeight(gamma_dim, rmsparams_gamma, WeightRegularizer::NONE,
+                          1.0f, 0.0f, "gamma", false);
+}
+
+void RMSNormLayerCl::forwarding(RunLayerContext &context, bool training) {
+  Tensor &in = context.getInput(SINGLE_INOUT_IDX);
+  Tensor &out = context.getOutput(SINGLE_INOUT_IDX);
+  Tensor &gamma = context.getWeight(wt_idx[RMSParams::gamma]);
+  auto &epsilon = std::get<props::Epsilon>(rmsnorm_props).get();
+  if (in.getDataType() == ml::train::TensorDim::DataType::FP32) {
+    rmsnormProcess(in, out, gamma, epsilon, context);
+  } else {
+    rmsnormProcess_fp16(in, out, gamma, epsilon, context);
+  }
+}
+
+opencl::Kernel RMSNormLayerCl::kernel_rmsnorm;
+opencl::Kernel RMSNormLayerCl::kernel_rmsnorm_fp16;
+
+void RMSNormLayerCl::rmsnormProcess(Tensor const &input, Tensor &result,
+                                    Tensor const &gamma, const float epsilon,
+                                    RunLayerContext &context) {
+  bool ret = false;
+  int dim1 = input.batch() * input.height() * input.width() * input.channel();
+  CREATE_IF_EMPTY_DIMS(result, input.batch(), input.channel(), input.height(),
+                       input.width(), input.getTensorType());
+  int b = input.batch();
+  int c = input.channel();
+  int h = input.height();
+  int w = input.width();
+  do {
+    ret =
+      context.clCreateKernel(rmsnorm_cl_kernel_, context.LayerKernel::RMSNORM,
+                             RMSNormLayerCl::kernel_rmsnorm);
+    if (!ret) {
+      break;
+    }
+
+    opencl::Buffer inputbuf(context.context_inst_, dim1 * sizeof(float), true,
+                            nullptr);
+
+    opencl::Buffer gammabuf(context.context_inst_,
+                            input.width() * sizeof(float), true, nullptr);
+    opencl::Buffer resultbuf(context.context_inst_, dim1 * sizeof(float), true,
+                             nullptr);
+
+    const float *data = input.getData();
+    float *rdata = result.getData();
+    const float *gdata = gamma.getData();
+    ret = inputbuf.WriteData(context.command_queue_inst_, data);
+    if (!ret) {
+      break;
+    }
+
+    ret = gammabuf.WriteData(context.command_queue_inst_, gdata);
+    if (!ret) {
+      break;
+    }
+    ret = RMSNormLayerCl::kernel_rmsnorm.SetKernelArguments(0, &inputbuf,
+                                                            sizeof(cl_mem));
+    if (!ret) {
+      break;
+    }
+
+    ret = RMSNormLayerCl::kernel_rmsnorm.SetKernelArguments(1, &resultbuf,
+                                                            sizeof(cl_mem));
+    if (!ret) {
+      break;
+    }
+
+    ret = RMSNormLayerCl::kernel_rmsnorm.SetKernelArguments(2, &gammabuf,
+                                                            sizeof(cl_mem));
+    if (!ret) {
+      break;
+    }
+    ret = RMSNormLayerCl::kernel_rmsnorm.SetKernelArguments(3, &epsilon,
+                                                            sizeof(float));
+    if (!ret) {
+      break;
+    }
+
+    ret = RMSNormLayerCl::kernel_rmsnorm.SetKernelArguments(4, &b, sizeof(int));
+    if (!ret) {
+      break;
+    }
+
+    ret = RMSNormLayerCl::kernel_rmsnorm.SetKernelArguments(5, &c, sizeof(int));
+    if (!ret) {
+      break;
+    }
+
+    ret = RMSNormLayerCl::kernel_rmsnorm.SetKernelArguments(6, &h, sizeof(int));
+    if (!ret) {
+      break;
+    }
+    ret = RMSNormLayerCl::kernel_rmsnorm.SetKernelArguments(7, &w, sizeof(int));
+    if (!ret) {
+      break;
+    }
+    const int work_groups_count[3] = {b * c, h, 1};
+    const int work_group_size[3] = {32, 32, 1}; // test-value
+
+    ret = context.command_queue_inst_.DispatchCommand(
+      RMSNormLayerCl::kernel_rmsnorm, work_groups_count, work_group_size);
+    if (!ret) {
+      break;
+    }
+
+    ret = resultbuf.ReadData(context.command_queue_inst_, rdata);
+    if (!ret) {
+      break;
+    }
+
+  } while (false);
+}
+
+void RMSNormLayerCl::rmsnormProcess_fp16(Tensor const &input, Tensor &result,
+                                         Tensor const &gamma,
+                                         const float epsilon,
+                                         RunLayerContext &context) {
+
+  bool ret = false;
+  int dim1 = input.batch() * input.height() * input.width() * input.channel();
+  CREATE_IF_EMPTY_DIMS(result, input.batch(), input.channel(), input.height(),
+                       input.width(), input.getTensorType());
+  int b = input.batch();
+  int c = input.channel();
+  int h = input.height();
+  int w = input.width();
+  do {
+    ret = context.clCreateKernel(rmsnorm_cl_kernel_fp16_,
+                                 context.LayerKernel::RMSNORM_FP16,
+                                 RMSNormLayerCl::kernel_rmsnorm_fp16);
+    if (!ret) {
+      break;
+    }
+    opencl::Buffer inputbuf(context.context_inst_, dim1 * sizeof(cl_half), true,
+                            nullptr);
+
+    opencl::Buffer gammabuf(context.context_inst_,
+                            input.width() * sizeof(cl_half), true, nullptr);
+    opencl::Buffer resultbuf(context.context_inst_, dim1 * sizeof(cl_half),
+                             true, nullptr);
+
+    const __fp16 *data = input.getData<__fp16>();
+    __fp16 *rdata = result.getData<__fp16>();
+    const __fp16 *gdata = gamma.getData<__fp16>();
+    ret = inputbuf.WriteData(context.command_queue_inst_, data);
+    if (!ret) {
+      break;
+    }
+
+    ret = gammabuf.WriteData(context.command_queue_inst_, gdata);
+    if (!ret) {
+      break;
+    }
+    ret = RMSNormLayerCl::kernel_rmsnorm_fp16.SetKernelArguments(
+      0, &inputbuf, sizeof(cl_mem));
+    if (!ret) {
+      break;
+    }
+    ret = RMSNormLayerCl::kernel_rmsnorm_fp16.SetKernelArguments(
+      1, &resultbuf, sizeof(cl_mem));
+    if (!ret) {
+      break;
+    }
+
+    ret = RMSNormLayerCl::kernel_rmsnorm_fp16.SetKernelArguments(
+      2, &gammabuf, sizeof(cl_mem));
+    if (!ret) {
+      break;
+    }
+    ret = RMSNormLayerCl::kernel_rmsnorm_fp16.SetKernelArguments(3, &epsilon,
+                                                                 sizeof(float));
+    if (!ret) {
+      break;
+    }
+
+    ret = RMSNormLayerCl::kernel_rmsnorm_fp16.SetKernelArguments(
+      4, &b, sizeof(int));
+    if (!ret) {
+      break;
+    }
+
+    ret = RMSNormLayerCl::kernel_rmsnorm_fp16.SetKernelArguments(5, &c,
+                                                                 sizeof(int));
+    if (!ret) {
+      break;
+    }
+    ret = RMSNormLayerCl::kernel_rmsnorm_fp16.SetKernelArguments(6, &h,
+                                                                 sizeof(int));
+    if (!ret) {
+      break;
+    }
+    ret = RMSNormLayerCl::kernel_rmsnorm_fp16.SetKernelArguments(7, &w,
+                                                                 sizeof(int));
+    if (!ret) {
+      break;
+    }
+    const int work_groups_count[3] = {b * c, h, 1};
+    const int work_group_size[3] = {32, 32, 1}; // test-value
+
+    ret = context.command_queue_inst_.DispatchCommand(
+      RMSNormLayerCl::kernel_rmsnorm_fp16, work_groups_count, work_group_size);
+    if (!ret) {
+      break;
+    }
+
+    ret = resultbuf.ReadData(context.command_queue_inst_, rdata);
+    if (!ret) {
+      break;
+    }
+  } while (false);
+
+}
+
+void RMSNormLayerCl::incremental_forwarding(nntrainer::RunLayerContext &context,
+                                          unsigned int from, unsigned int to,
+                                          bool training) {
+  Tensor &in = context.getInput(SINGLE_INOUT_IDX);
+  Tensor &out = context.getOutput(SINGLE_INOUT_IDX);
+  Tensor &gamma = context.getWeight(wt_idx[RMSParams::gamma]);
+  ml::train::TensorDim in_dim = in.getDim();
+  ml::train::TensorDim out_dim = out.getDim();
+
+  ml::train::TensorDim in_step_dim = in_dim;
+  ml::train::TensorDim out_step_dim = out_dim;
+
+  if (from) {
+    NNTR_THROW_IF(to - from != 1, std::invalid_argument)
+      << "incremental step size is not 1";
+    from = 0;
+    to = 1;
+  }
+
+  in_step_dim.height(to - from);
+  out_step_dim.height(to - from);
+
+  Tensor in_step = in.getSharedDataTensor(in_step_dim, 0, true);
+  Tensor out_step = out.getSharedDataTensor(out_step_dim, 0, true);
+
+  auto &epsilon = std::get<props::Epsilon>(rmsnorm_props).get();
+
+  if (in_step.getDataType() == ml::train::TensorDim::DataType::FP32) {
+    rmsnormProcess(in_step, out_step, gamma, epsilon, context);
+  } else {
+    rmsnormProcess_fp16(in_step, out_step, gamma, epsilon, context);
+  }
+}
+
+void RMSNormLayerCl::calcDerivative(RunLayerContext &context) {
+  ml_logi("Training not supported");
+}
+
+void RMSNormLayerCl::calcGradient(RunLayerContext &context) {
+  ml_logi("Training not supported");
+}
+
+void RMSNormLayerCl::exportTo(Exporter &exporter,
+                              const ml::train::ExportMethods &method) const {
+  LayerImpl::exportTo(exporter, method);
+  exporter.saveResult(rmsnorm_props, method, this);
+}
+
+void RMSNormLayerCl::setProperty(const std::vector<std::string> &values) {
+  auto remain_props = loadProperties(values, rmsnorm_props);
+  LayerImpl::setProperty(remain_props);
+}
+
+} // namespace nntrainer
+
diff --git a/nntrainer/layers/cl_layers/rmsnorm_layer_cl.h b/nntrainer/layers/cl_layers/rmsnorm_layer_cl.h
new file mode 100644 (file)
index 0000000..cd2fc9d
--- /dev/null
@@ -0,0 +1,170 @@
+// SPDX-License-Identifier: Apache-2.0
+/**
+ * Copyright (C) 2024 Thummala Pallavi <t.pallavi@samsung.com>
+ *
+ * @file   rmsnorm_layer_cl.h
+ * @date   8 June 2024
+ * @brief  This is RMS Norm Layer Class of Neural Network with OpenCL
+ * implementation
+ * @see    https://github.com/nnstreamer/nntrainer
+ * @author Thummala Pallavi <t.pallavi@samsung.com>
+ * @bug    No known bugs except for NYI items
+ *
+ */
+
+#ifndef __RMSNORM_LAYER_CL_H__
+#define __RMSNORM_LAYER_CL_H__
+#ifdef __cplusplus
+
+#include <common_properties.h>
+#include <layer_impl.h>
+#include <nntrainer_log.h>
+
+#include <opencl_buffer.h>
+#include <opencl_kernel.h>
+
+namespace nntrainer {
+
+namespace props {
+
+/**
+ * @brief RMS_NORM_GAMMA_INIT_GPU Initialization Enumeration Information
+ *
+ */
+class RMS_NORM_GAMMA_INIT_GPU final
+  : public ::nntrainer::EnumProperty<::nntrainer::props::InitializerInfo> {
+public:
+  /**
+   * @brief Construct a RMS_NORM_GAMMA_INIT_GPU object
+   */
+  RMS_NORM_GAMMA_INIT_GPU(::nntrainer::Tensor::Initializer value =
+                        ::nntrainer::Tensor::Initializer::ONES) {
+    set(value);
+  };
+  using prop_tag = enum_class_prop_tag;
+  static constexpr const char *key = "gamma_initializer";
+};
+} // namespace props
+
+/**
+ * @class   RMSNormLayerCl
+ * @brief   RMS Norm layer
+ */
+class RMSNormLayerCl : public LayerImpl {
+public:
+  /**
+   * @brief     Constructor of RMS Norm Layer
+   */
+  RMSNormLayerCl();
+
+  /**
+   * @brief     Destructor of RMS Norm Layer
+   */
+  ~RMSNormLayerCl() = default;
+
+  /**
+   *  @brief  Move constructor.
+   *  @param[in] rhs RMSNormLayerCl to be moved.
+   */
+  RMSNormLayerCl(RMSNormLayerCl &&rhs) noexcept = default;
+
+  /**
+   * @brief  Move assignment operator.
+   * @param[in] rhs RMSNormLayerCl to be moved.
+   */
+  RMSNormLayerCl &operator=(RMSNormLayerCl &&rhs) = default;
+
+  /**
+   * @copydoc Layer::finalize(InitLayerContext &context)
+   */
+  void finalize(InitLayerContext &context) override;
+
+  /**
+   * @copydoc Layer::forwarding(RunLayerContext &context, bool training)
+   */
+  void forwarding(RunLayerContext &context, bool training) override;
+
+  /**
+   * @copydoc Layer::incremental_forwarding(RunLayerContext &context, unsigned
+   * int from, unsigned int to, bool training)
+   */
+  void incremental_forwarding(RunLayerContext &context, unsigned int from,
+                              unsigned int to, bool training) override;
+
+  /**
+   * @copydoc Layer::calcDerivative(RunLayerContext &context)
+   */
+  void calcDerivative(RunLayerContext &context) override;
+
+  /**
+   * @copydoc Layer::calcGradient(RunLayerContext &context)
+   */
+  void calcGradient(RunLayerContext &context) override;
+
+  /**
+   * @copydoc Layer::exportTo(Exporter &exporter, ml::train::ExportMethods
+   * method)
+   */
+  void exportTo(Exporter &exporter,
+                const ml::train::ExportMethods &method) const override;
+
+  /**
+   * @copydoc Layer::getType()
+   */
+  const std::string getType() const override {
+    return RMSNormLayerCl::type;
+  };
+
+  static opencl::Kernel kernel_rmsnorm;
+  static opencl::Kernel kernel_rmsnorm_fp16;
+
+  /**
+   * @brief Process data and dimensions for rms norm operation
+   * @param[in] input Tensor
+   * @param[in] result Tensor
+   * @param[in] gamma Tensor
+   * @param[in] epsilon float
+   * @param[in] RunLayerContext reference
+   */
+  void rmsnormProcess(Tensor const &input, Tensor &result, Tensor const &gamma,
+                      const float epsilon, RunLayerContext &context);
+
+  /**
+   * @brief Process data and dimensions for FP16 rms norm operation
+   * @param[in] input Tensor
+   * @param[in] result Tensor
+   * @param[in] gamma Tensor
+   * @param[in] epsilon float
+   * @param[in] RunLayerContext reference
+   */
+  void rmsnormProcess_fp16(Tensor const &input, Tensor &result,
+                           Tensor const &gamma, const float epsilon,
+                           RunLayerContext &context);
+
+  /**
+   * @copydoc Layer::supportBackwarding()
+   */
+  bool supportBackwarding() const override {
+    return false;
+  }
+
+  /**
+   * @copydoc Layer::setProperty(const std::vector<std::string> &values)
+   */
+  void setProperty(const std::vector<std::string> &values) override;
+
+  inline static const std::string type = "rmsnorm";
+
+private:
+  std::array<unsigned int, 1> wt_idx;
+  std::tuple<props::RMS_NORM_GAMMA_INIT_GPU, props::Epsilon>
+    rmsnorm_props; /**< rmsnorm layer properties */
+};
+} // namespace nntrainer
+
+#endif /* __cplusplus */
+#endif /* __RMSNORM_LAYER_CL_H__ */
+
index 532634ffc906b9aa8abe21b48ce0fb6aa83bd269..d71221c3525226733b3e57b596b864fad41def34 100644 (file)
@@ -719,6 +719,10 @@ std::string RunLayerContext::getKernelName(LayerKernel layerKernel) {
     return "copy_cl";
   case LayerKernel::COPY_FP16:
     return "copy_cl_fp16";
+  case LayerKernel::RMSNORM:
+    return "rmsnorm_cl";
+  case LayerKernel::RMSNORM_FP16:
+    return "rmsnorm_cl_fp16";
   default:
     return "";
   }
index a37f16aca52d8040efb22c3c71ef1ca1ffa8e51a..b8b8ffccd8d8a748ea358b39126a918645e9edd7 100644 (file)
@@ -850,6 +850,8 @@ public:
     SSCAL_FP16 = 1 << 17,         /**< placeholder for kernel name */
     COPY = 1 << 18,               /**< placeholder for kernel name */
     COPY_FP16 = 1 << 19,          /**< placeholder for kernel name */
+    RMSNORM = 1 << 20,            /**< placeholder for kernel name */
+    RMSNORM_FP16 = 1 << 21        /**< placeholder for kernel name */
   };
 
   /**
index 55ed99e6ef66747c373302d630d6aaee9e82df73..1300dcd8d79e05da81510fbef25879437ab9d72c 100644 (file)
@@ -19,6 +19,7 @@ Copyright (C) 2021 Jihoon Lee <jhoon.it.lee@samsung.com>
 @author Sungsik Kong <ss.kong@samsung.com>
 @author        Debadri Samaddar <s.debadri@samsung.com>
 @author        Niket Agarwal <niket.a@samsung.com>
+@author        Thummala Pallavi <t.pallavi@samsung.com>
 """
 
 import warnings
@@ -922,3 +923,34 @@ if __name__ == "__main__":
 
     reshape_layer = tf.keras.layers.Lambda(lambda x: reshape_tensor(x, 2, 3, 3, 3))
     record_single(reshape_layer, (2, 3, 3, 3), "reshape", input_type="float")
+
+    class RMSNorm(tf.keras.layers.Layer):
+        def __init__(self, epsilon=1e-3, **kwargs):
+            super(RMSNorm, self).__init__(**kwargs)
+            self.epsilon = epsilon
+
+        def build(self, input_shape):
+            # Initialize gamma as a fixed (non-trainable) scale of ones
+            self.gamma = self.add_weight(
+                shape=input_shape[-1:],
+                initializer=tf.keras.initializers.Ones(),
+                trainable=False,
+                name='gamma'
+            )
+            super(RMSNorm, self).build(input_shape)
+
+        def call(self, inputs):
+            # Compute the mean of the squares of the inputs along the last dimension
+            mean_square = tf.reduce_mean(tf.square(inputs), axis=[-1], keepdims=True)
+            print(mean_square)
+            # Compute the RMS value with epsilon for numerical stability
+            rms_value = tf.sqrt(mean_square + self.epsilon)
+            print(rms_value)
+            # Normalize inputs and scale by gamma
+            normalized_inputs = inputs / rms_value * self.gamma
+            return normalized_inputs
+
+    rms_normtest = RMSNorm()
+    rms_normtest_fp16 = RMSNorm()
+    record_single(rms_normtest, (2, 3, 3, 3), "rms_normtest")
+    record_single_fp16(rms_normtest_fp16, (2, 3, 3, 3), "rms_normtest_fp16_new")
index 1418a42137f3c1a8a7b2aa50f09e646164d66bcf..e0853d90b8932a6c59d46ae7452225fa7224247e 100644 (file)
@@ -448,6 +448,7 @@ LOCAL_SRC_FILES := \
         ../unittest/layers/unittest_layers_loss.cpp \
         ../unittest/layers/unittest_layers_reshape_cl.cpp \
         ../unittest/layers/unittest_layers_fully_connected.cpp \
+        ../unittest/layers/unittest_layers_rmsnorm_cl.cpp \
         ../unittest/layers/unittest_layers_batch_normalization.cpp \
         ../unittest/layers/unittest_layers_layer_normalization.cpp \
         ../unittest/layers/unittest_layers_convolution2d.cpp \
diff --git a/test/unittest/layers/unittest_layers_rmsnorm_cl.cpp b/test/unittest/layers/unittest_layers_rmsnorm_cl.cpp
new file mode 100644 (file)
index 0000000..ddd356e
--- /dev/null
@@ -0,0 +1,50 @@
+// SPDX-License-Identifier: Apache-2.0
+/**
+ * Copyright (C) 2024 Thummala Pallavi <t.pallavi@samsung.com>
+ *
+ * @file unittest_layers_rmsnorm_cl.cpp
+ * @date 7 June 2024
+ * @brief RMS Norm Layer Test
+ * @see        https://github.com/nnstreamer/nntrainer
+ * @author Thummala Pallavi <t.pallavi@samsung.com>
+ * @bug No known bugs except for NYI items
+ */
+#include <tuple>
+
+#include <gtest/gtest.h>
+
+#include <layers_common_tests.h>
+#include <rmsnorm_layer_cl.h>
+
+auto semantic_rms = LayerSemanticsParamType(
+  nntrainer::createLayer<nntrainer::RMSNormLayerCl>,
+  nntrainer::RMSNormLayerCl::type, {"epsilon=0.001"},
+  LayerCreateSetPropertyOptions::AVAILABLE_FROM_APP_CONTEXT, false, 1);
+
+GTEST_PARAMETER_TEST(RMSNormGPU, LayerSemanticsGpu,
+                     ::testing::Values(semantic_rms));
+
+auto rms_plain_skip_CG = LayerGoldenTestParamType(
+  nntrainer::createLayer<nntrainer::RMSNormLayerCl>, {"epsilon=0.001"},
+  "2:3:3:3", "rms_normtest.nnlayergolden",
+  LayerGoldenTestParamOptions::SKIP_CALC_DERIV |
+    LayerGoldenTestParamOptions::SKIP_CALC_GRAD |
+    LayerGoldenTestParamOptions::USE_INC_FORWARD,
+  "nchw", "fp32", "fp32");
+
+GTEST_PARAMETER_TEST(RMSNormGPU, LayerGoldenTest,
+                     ::testing::Values(rms_plain_skip_CG));
+
+#ifdef ENABLE_FP16
+auto rms_plain_skip_CG_fp16 = LayerGoldenTestParamType(
+  nntrainer::createLayer<nntrainer::RMSNormLayerCl>, {"epsilon=0.001"},
+  "2:3:3:3", "rms_normtest_fp16_new.nnlayergolden",
+  LayerGoldenTestParamOptions::SKIP_CALC_DERIV |
+    LayerGoldenTestParamOptions::SKIP_CALC_GRAD |
+    LayerGoldenTestParamOptions::USE_INC_FORWARD,
+  "nchw", "fp16", "fp16");
+
+GTEST_PARAMETER_TEST(RMSNormGPU16, LayerGoldenTest,
+                     ::testing::Values(rms_plain_skip_CG_fp16));
+
+#endif