[GPU/OpenCL] Initial version of Rotary Embedding with OpenCL ops
authorYash Singh <yash.singh@samsung.com>
Wed, 28 Aug 2024 12:10:24 +0000 (17:40 +0530)
committerJijoong Moon <jijoong.moon@samsung.com>
Sun, 20 Oct 2024 23:33:51 +0000 (08:33 +0900)
Added an initial version of the Rotary Embedding kernel for GPU. This includes both FP32 and FP16 implementations of the GPU kernel.

Signed-off-by: Yash Singh <yash.singh@samsung.com>
nntrainer/tensor/cl_operations/attention_kernel_interface.cpp [new file with mode: 0644]
nntrainer/tensor/cl_operations/attention_kernel_interface.h [new file with mode: 0644]
nntrainer/tensor/cl_operations/attention_kernels.cpp [new file with mode: 0644]
nntrainer/tensor/cl_operations/attention_kernels.h [new file with mode: 0644]
nntrainer/tensor/cl_operations/attention_kernels_fp16.cpp [new file with mode: 0644]
nntrainer/tensor/cl_operations/meson.build
nntrainer/tensor/cl_operations/testing_rotarty_emb.cpp [new file with mode: 0644]
test/jni/Android.mk
test/unittest/unittest_attention_kernels_cl.cpp [new file with mode: 0644]

diff --git a/nntrainer/tensor/cl_operations/attention_kernel_interface.cpp b/nntrainer/tensor/cl_operations/attention_kernel_interface.cpp
new file mode 100644 (file)
index 0000000..cf28840
--- /dev/null
@@ -0,0 +1,136 @@
+// SPDX-License-Identifier: Apache-2.0
+/**
+ * Copyright (C) 2024 Yash Singh <yash.singh@samsung.com>
+ *
+ * @file       attention_kernel_interface.cpp
+ * @date       28 August 2024
+ * @brief      Interface for attention OpenCL kernels
+ * @see                https://github.com/nnstreamer/nntrainer
+ * @author     Yash Singh <yash.singh@samsung.com>
+ * @bug                No known bugs except for NYI items
+ *
+ */
+
+#include <attention_kernel_interface.h>
+#include <attention_kernels.h>
+
+namespace nntrainer {
+/**
+ * @brief      compute frequency for rotary embedding
+ * @param[in]  dim hidden dim size
+ * @param[in]  seq_len sequence length (number of positions to tabulate)
+ * @param[out] freqs_cos cosine of the frequencies; replaced with a
+ *             seq_len x dim table
+ * @param[out] freqs_sin sine of the frequencies; replaced with a
+ *             seq_len x dim table
+ * @param[out] freqs base frequencies array to be used in the future
+ *             computation; dim / 2 values are appended to it
+ * @param[in]  theta rotary angle
+ */
+void precompute_freqs(int dim, unsigned int seq_len,
+                      std::vector<std::vector<float>> &freqs_cos,
+                      std::vector<std::vector<float>> &freqs_sin,
+                      std::vector<float> &freqs, float theta = 10000.0) {
+  unsigned int half_ = dim / 2;
+
+  // One allocation for the appended base frequencies instead of regrowing.
+  freqs.reserve(freqs.size() + half_);
+  for (unsigned int i = 0; i < half_; ++i) {
+    freqs.push_back(1.0 / (std::pow(theta, (2 * i) / static_cast<float>(dim))));
+  }
+
+  // Tables are built in locals so the outputs are replaced only once they
+  // are complete.
+  std::vector<std::vector<float>> cos_table(seq_len,
+                                            std::vector<float>(dim, 0));
+  std::vector<std::vector<float>> sin_table(seq_len,
+                                            std::vector<float>(dim, 0));
+
+  for (unsigned int i = 0; i < seq_len; ++i) {
+    std::vector<float> &cos_row = cos_table[i];
+    std::vector<float> &sin_row = sin_table[i];
+    for (unsigned int j = 0; j < half_; ++j) {
+      const float angle = i * freqs[j];
+      // Evaluate each trigonometric value once and store it in both halves.
+      const float cos_val = std::cos(angle);
+      const float sin_val = std::sin(angle);
+
+      cos_row[j] = cos_val;
+      cos_row[j + half_] = cos_val; // repeated 2 times
+
+      sin_row[j] = sin_val;
+      sin_row[j + half_] = sin_val; // repeated 2 times
+    }
+  }
+
+  // Hand the tables over without deep-copying seq_len * dim floats twice.
+  freqs_cos.swap(cos_table);
+  freqs_sin.swap(sin_table);
+}
+
+/**
+ * @brief     apply rotary embedding
+ * @param[in,out] in input tensor; overwritten with the contents of the
+ *                output tensor at the end of the call
+ * @param[in] dim hidden dim size
+ * @param[in] from sequence order
+ * @param[in] max_timestep maximum timestep
+ * @param[in] context layer context forwarded to the OpenCL kernel launcher
+ */
+void apply_rotary_emb_cl(Tensor &in, unsigned int dim, unsigned int from,
+                         unsigned int max_timestep, RunLayerContext &context) {
+  nntrainer::Tensor out(in.getDim());
+  unsigned int half_ = dim / 2;
+
+  std::vector<std::vector<float>> freqs_cos = {};
+  std::vector<std::vector<float>> freqs_sin = {};
+  std::vector<float> freqs;
+
+  precompute_freqs(dim, max_timestep, freqs_cos, freqs_sin, freqs);
+
+  // Scratch rows handed to the kernel. The kernel indexes them with k < dim
+  // and, for positions inside the table, copies a dim-wide row into them, so
+  // they must hold dim elements regardless of max_timestep (sizing them by
+  // max_timestep overruns the device buffer when dim > max_timestep).
+  std::vector<float> cos_(dim);
+  std::vector<float> sin_(dim);
+
+  if (from >= max_timestep) {
+    // `from` lies outside the precomputed table: compute its row directly.
+    for (unsigned int i = 0; i < half_; ++i) {
+      float angle = from * freqs[i];
+      cos_[i] = std::cos(angle);
+      cos_[i + half_] = std::cos(angle); // repeated 2 times
+
+      sin_[i] = std::sin(angle);
+      sin_[i + half_] = std::sin(angle); // repeated 2 times
+    }
+  }
+
+  unsigned int input_batch_size, input_height, input_width, input_channels;
+  input_batch_size = in.batch();
+  input_height = in.height();
+  input_width = in.width();
+  input_channels = in.channel();
+
+  // NOTE(review): when the data type is neither FP32 nor FP16 no kernel is
+  // launched and `out` is copied back exactly as it was constructed.
+  if (in.getDataType() == ml::train::TensorDim::DataType::FP32) {
+
+    unsigned int in_size = in.size();
+    unsigned int out_size = out.size();
+    float *data = in.getData();
+    float *rdata = out.getData();
+
+    rotary_emb_cl(data, rdata, freqs_cos, freqs_sin, cos_, sin_,
+                  input_batch_size, input_channels, input_height, input_width,
+                  dim, from, max_timestep, in_size, out_size, context);
+
+  } else if (in.getDataType() == ml::train::TensorDim::DataType::FP16) {
+#ifdef ENABLE_FP16
+
+    unsigned int in_size = in.size();
+    unsigned int out_size = out.size();
+    _FP16 *data = in.getData<_FP16>();
+    _FP16 *rdata = out.getData<_FP16>();
+
+    rotary_emb_cl(data, rdata, freqs_cos, freqs_sin, cos_, sin_,
+                  input_batch_size, input_channels, input_height, input_width,
+                  dim, from, max_timestep, in_size, out_size, context);
+#else
+    throw std::invalid_argument("Error: enable-fp16 is not enabled");
+#endif
+  }
+
+  in.copy(out);
+}
+} // namespace nntrainer
\ No newline at end of file
diff --git a/nntrainer/tensor/cl_operations/attention_kernel_interface.h b/nntrainer/tensor/cl_operations/attention_kernel_interface.h
new file mode 100644 (file)
index 0000000..878561b
--- /dev/null
@@ -0,0 +1,34 @@
+// SPDX-License-Identifier: Apache-2.0
+/**
+ * Copyright (C) 2024 Yash Singh <yash.singh@samsung.com>
+ *
+ * @file       attention_kernel_interface.h
+ * @date       28 August 2024
+ * @brief      Interface for attention OpenCL kernels
+ * @see                https://github.com/nnstreamer/nntrainer
+ * @author     Yash Singh <yash.singh@samsung.com>
+ * @bug                No known bugs except for NYI items
+ *
+ */
+
+#ifndef __ATTENTION_KERNEL_INTERFACE_H__
+#define __ATTENTION_KERNEL_INTERFACE_H__
+
+#include <layer_context.h>
+#include <string>
+
+namespace nntrainer {
+
+/**
+ * @brief     Rotary Embedding kernel
+ * @param[in] in input tensor
+ * @param[in] dim hidden dim size
+ * @param[in] from sequence order
+ * @param[in] max_timestep maximum timestep
+ * @param[in] context layer context to get the resource manager and queue id
+ */
+void apply_rotary_emb_cl(Tensor &in, unsigned int dim, unsigned int from,
+                         unsigned int max_timestep, RunLayerContext &context);
+
+} // namespace nntrainer
+#endif /* __ATTENTION_KERNEL_INTERFACE_H__ */
\ No newline at end of file
diff --git a/nntrainer/tensor/cl_operations/attention_kernels.cpp b/nntrainer/tensor/cl_operations/attention_kernels.cpp
new file mode 100644 (file)
index 0000000..355bb8e
--- /dev/null
@@ -0,0 +1,273 @@
+// SPDX-License-Identifier: Apache-2.0
+/**
+ * Copyright (C) 2024 Yash Singh <yash.singh@samsung.com>
+ *
+ * @file       attention_kernels.cpp
+ * @date       28 August 2024
+ * @brief      Common attention OpenCL kernels
+ * @see                https://github.com/nnstreamer/nntrainer
+ * @author     Yash Singh <yash.singh@samsung.com>
+ * @bug                No known bugs except for NYI items
+ *
+ */
+
+#include <attention_kernels.h>
+
+namespace nntrainer {
+/**
+ * @brief OpenCL C source of the FP32 rotary embedding kernel
+ *        (kernel entry point: rotary_emb_cl)
+ *
+ * The kernel addresses input/output as a contiguous
+ * [batch][channel][height][width] array. For every (b, c, h):
+ *  - if from + h < max_timestep, the dim-wide row (from + h) of
+ *    freqs_cos / freqs_sin is copied into cos_ / sin_; otherwise cos_ / sin_
+ *    are used with whatever they currently hold;
+ *  - the width axis is walked in chunks of dim and each element k of a chunk
+ *    becomes x[k] * cos_[k] + r * sin_[k], where r is -x[k + half_] for
+ *    k < half_ and x[k - half_] otherwise.
+ *
+ * NOTE(review): gid and gws are assigned but never used, so the loops cover
+ * the whole tensor in each work-item that executes the kernel.
+ */
+std::string rotary_emb_cl_kernel = R"(
+  #pragma OPENCL EXTENSION cl_khr_fp16 : enable
+__kernel void rotary_emb_cl(__global float *input,
+                                      __global float *output,
+                                      __global float *freqs_cos,
+                                      __global float *freqs_sin,
+                                      __global float *cos_,
+                                      __global float *sin_,
+                                      unsigned int batch,
+                                      unsigned int channel,
+                                      unsigned int height,
+                                      unsigned int width,
+                                      unsigned int dim,
+                                      unsigned int half_,
+                                      unsigned int max_timestep,
+                                      unsigned int from) {
+    unsigned int gid = get_global_id(0);
+    unsigned int gws = get_global_size(0);
+
+    __global float *cos_ptr = cos_;
+    __global float *sin_ptr = sin_;
+
+    float value = 0.0f;
+    float transformed_value = 0.0f;
+
+    for (unsigned int b = 0; b < batch; b++) {
+      for (unsigned int c = 0; c < channel; c++) {
+        for (unsigned int h = 0; h < height; h++) {
+          if (from + h < max_timestep) {
+            unsigned idx = (from + h)*dim;
+            for(unsigned int i = idx; i < idx + dim; i++){
+              cos_ptr[i - idx] = freqs_cos[i];
+              sin_ptr[i - idx] = freqs_sin[i];
+            }
+          }
+          for (unsigned int w = 0; w < width; w = w + dim) {
+            for (unsigned int k = 0; k < dim; k++) {
+              unsigned int span = w + k;
+              value = input[b * channel * height * width + c * height * width + h * width + span];
+              if (k < half_) {
+                transformed_value = -1.0f * input[b * channel * height * width + c * height * width + h * width + span + half_];
+              } else {
+                transformed_value = input[b * channel * height * width + c * height * width + h * width + span - half_];
+              }
+              value = value * cos_ptr[k] + transformed_value * sin_ptr[k];
+              // printf("GPU Batch: %u, Height: %u, Channel: %u, Width: %u, K: %u, Span: %u, Value: %f, Transformed Value: %f, cos_ptr[k]: %f, sin_ptr[k]: %f\n",  b, h, c, w, k, span, value, transformed_value, cos_ptr[k], sin_ptr[k]);
+              output[b * channel * height * width + c * height * width + h * width + span] = value;
+            }
+          }
+        }
+      }
+    }
+}
+)";
+
+/**
+ * @brief defining global kernel objects
+ */
+opencl::Kernel kernel_rotary_emb;
+
+/**
+ * @brief Host-side launcher of the FP32 rotary embedding kernel
+ *
+ * Creates the kernel, uploads the input, output, frequency tables and
+ * scratch rows, binds the 14 kernel arguments, dispatches the kernel and
+ * reads the result back into out. Every failing step prints a message and
+ * ends the call; no error is reported to the caller.
+ */
+void rotary_emb_cl(float *in, float *out,
+                   std::vector<std::vector<float>> freqs_cos,
+                   std::vector<std::vector<float>> freqs_sin,
+                   std::vector<float> cos_, std::vector<float> sin_,
+                   unsigned int batch, unsigned int channel,
+                   unsigned int height, unsigned int width, unsigned int dim,
+                   unsigned int from, unsigned int max_timestep,
+                   unsigned int in_size, unsigned int out_size,
+                   RunLayerContext &context) {
+  if (!context.clCreateKernel(rotary_emb_cl_kernel,
+                              context.LayerKernel::ROTARY_EMB,
+                              kernel_rotary_emb)) {
+    printf("Failed to create kernel for rotary_emb_cl\n");
+    return;
+  }
+
+  // Element counts of the host containers, as passed to the buffer sizes.
+  unsigned int cos_len = cos_.size();
+  unsigned int sin_len = sin_.size();
+  unsigned int freqs_cos_rows = freqs_cos.size();
+  unsigned int freqs_sin_rows = freqs_sin.size();
+
+  size_t input_bytes = sizeof(float) * in_size;
+  size_t output_bytes = sizeof(float) * out_size;
+  size_t cos_bytes = sizeof(float) * cos_len;
+  size_t sin_bytes = sizeof(float) * sin_len;
+  size_t freqs_cos_bytes =
+    sizeof(float) * freqs_cos_rows * dim; // max_timestep * dim
+  size_t freqs_sin_bytes = sizeof(float) * freqs_sin_rows * dim;
+
+  opencl::Buffer inputA(context.context_inst_, input_bytes, true, nullptr);
+  opencl::Buffer inOutRes(context.context_inst_, output_bytes, true, nullptr);
+  opencl::Buffer cosBuf(context.context_inst_, cos_bytes, true, nullptr);
+  opencl::Buffer sinBuf(context.context_inst_, sin_bytes, true, nullptr);
+  opencl::Buffer freqs_cosBuf(context.context_inst_, freqs_cos_bytes, true,
+                              nullptr);
+  opencl::Buffer freqs_sinBuf(context.context_inst_, freqs_sin_bytes, true,
+                              nullptr);
+
+  // The kernel expects the 2-D tables as one row-major float array.
+  std::vector<float> flat_cos_table;
+  std::vector<float> flat_sin_table;
+  for (const auto &cos_row : freqs_cos) {
+    flat_cos_table.insert(flat_cos_table.end(), cos_row.begin(),
+                          cos_row.end());
+  }
+  for (const auto &sin_row : freqs_sin) {
+    flat_sin_table.insert(flat_sin_table.end(), sin_row.begin(),
+                          sin_row.end());
+  }
+
+  // Copies host data into a device buffer; prints failure_msg on failure.
+  auto upload = [&context](opencl::Buffer &buffer, const void *host_data,
+                           const char *failure_msg) -> bool {
+    if (buffer.WriteData(context.command_queue_inst_, host_data)) {
+      return true;
+    }
+    printf("%s", failure_msg);
+    return false;
+  };
+
+  // Binds one kernel argument; prints failure_msg on failure.
+  auto bind_arg = [](unsigned int index, const void *value, size_t size,
+                     const char *failure_msg) -> bool {
+    if (kernel_rotary_emb.SetKernelArguments(index, value, size)) {
+      return true;
+    }
+    printf("%s", failure_msg);
+    return false;
+  };
+
+  // && short-circuits, so the first failing upload stops the sequence.
+  const bool uploaded =
+    upload(inputA, in, "Failed to write input data\n") &&
+    upload(inOutRes, out, "Failed to write output data\n") &&
+    upload(freqs_cosBuf, flat_cos_table.data(),
+           "Failed to write freqs cos data\n") &&
+    upload(freqs_sinBuf, flat_sin_table.data(),
+           "Failed to write freqs sin data\n") &&
+    upload(cosBuf, cos_.data(), "Failed to write cos data\n") &&
+    upload(sinBuf, sin_.data(), "Failed to write sin data\n");
+  if (!uploaded) {
+    return;
+  }
+
+  unsigned int half_ = dim / 2;
+
+  // Argument indices follow the parameter order of the kernel source.
+  const bool bound =
+    bind_arg(0, &inputA, sizeof(cl_mem), "Failed to set inputA argument\n") &&
+    bind_arg(1, &inOutRes, sizeof(cl_mem),
+             "Failed to set inOutRes argument\n") &&
+    bind_arg(2, &freqs_cosBuf, sizeof(cl_mem),
+             "Failed to set freqs_cosBuf argument\n") &&
+    bind_arg(3, &freqs_sinBuf, sizeof(cl_mem),
+             "Failed to set freqs_sinBuf argument\n") &&
+    bind_arg(4, &cosBuf, sizeof(cl_mem), "Failed to set cosBuf argument\n") &&
+    bind_arg(5, &sinBuf, sizeof(cl_mem), "Failed to set sinBuf argument\n") &&
+    bind_arg(6, &batch, sizeof(int), "Failed to set batch argument\n") &&
+    bind_arg(7, &channel, sizeof(int), "Failed to set channel argument\n") &&
+    bind_arg(8, &height, sizeof(int), "Failed to set height argument\n") &&
+    bind_arg(9, &width, sizeof(int), "Failed to set width argument\n") &&
+    bind_arg(10, &dim, sizeof(int), "Failed to set dim argument\n") &&
+    bind_arg(11, &half_, sizeof(int), "Failed to set half argument\n") &&
+    bind_arg(12, &max_timestep, sizeof(int),
+             "Failed to set timestamp argument\n") &&
+    bind_arg(13, &from, sizeof(int), "Failed to set from argument\n");
+  if (!bound) {
+    return;
+  }
+
+  const int work_groups_count[3] = {1, 1, 1};
+  const int work_group_size[3] = {32, 1, 1}; // test-value
+  if (!context.command_queue_inst_.DispatchCommand(
+        kernel_rotary_emb, work_groups_count, work_group_size)) {
+    printf("Failed to dispatch command\n");
+    return;
+  }
+
+  if (!inOutRes.ReadData(context.command_queue_inst_, out)) {
+    printf("Failed to read data\n");
+  }
+}
+} // namespace nntrainer
\ No newline at end of file
diff --git a/nntrainer/tensor/cl_operations/attention_kernels.h b/nntrainer/tensor/cl_operations/attention_kernels.h
new file mode 100644 (file)
index 0000000..432b232
--- /dev/null
@@ -0,0 +1,96 @@
+// SPDX-License-Identifier: Apache-2.0
+/**
+ * Copyright (C) 2024 Yash Singh <yash.singh@samsung.com>
+ *
+ * @file       attention_kernels.h
+ * @date       28 August 2024
+ * @brief      Common attention OpenCL kernels
+ * @see                https://github.com/nnstreamer/nntrainer
+ * @author     Yash Singh <yash.singh@samsung.com>
+ * @bug                No known bugs except for NYI items
+ *
+ */
+
+#ifndef __ATTENTION_KERNELS_H__
+#define __ATTENTION_KERNELS_H__
+
+#include <layer_context.h>
+#include <opencl_buffer.h>
+#include <opencl_kernel.h>
+#include <string>
+
+namespace nntrainer {
+
+/**
+ * @brief declaring global kernel objects
+ */
+extern opencl::Kernel kernel_rotary_emb;
+
+/**
+ * @brief     Rotary Embedding process (FP32)
+ * @param[in] in float * input
+ * @param[out] out float * output
+ * @param[in] freqs_cos cosine of the frequencies
+ * @param[in] freqs_sin sine of the frequencies
+ * @param[in] cos_ vector of cos values
+ * @param[in] sin_ vector of sin values
+ * @param[in] batch size of batch
+ * @param[in] channel channel of input
+ * @param[in] height height of input
+ * @param[in] width width of input
+ * @param[in] dim hidden dim size
+ * @param[in] from sequence order
+ * @param[in] max_timestep max timestep
+ * @param[in] in_size size of input
+ * @param[in] out_size size of output
+ * @param[in] context RunLayerContext reference
+ */
+void rotary_emb_cl(float *in, float *out,
+                   std::vector<std::vector<float>> freqs_cos,
+                   std::vector<std::vector<float>> freqs_sin,
+                   std::vector<float> cos_, std::vector<float> sin_,
+                   unsigned int batch, unsigned int channel,
+                   unsigned int height, unsigned int width, unsigned int dim,
+                   unsigned int from, unsigned int max_timestep,
+                   unsigned int in_size, unsigned int out_size,
+                   RunLayerContext &context);
+
+#ifdef ENABLE_FP16
+/**
+ * @brief declaring global fp16 kernel objects
+ */
+extern opencl::Kernel kernel_rotary_emb_fp16;
+
+/**
+ * @brief     Rotary Embedding process (FP16)
+ * @param[in] in __fp16 * input
+ * @param[out] out __fp16 * output
+ * @param[in] freqs_cos cosine of the frequencies
+ * @param[in] freqs_sin sine of the frequencies
+ * @param[in] cos_ vector of cos values
+ * @param[in] sin_ vector of sin values
+ * @param[in] batch size of batch
+ * @param[in] channel channel of input
+ * @param[in] height height of input
+ * @param[in] width width of input
+ * @param[in] dim hidden dim size
+ * @param[in] from sequence order
+ * @param[in] max_timestep max timestep
+ * @param[in] in_size size of input
+ * @param[in] out_size size of output
+ * @param[in] context RunLayerContext reference
+ */
+void rotary_emb_cl(__fp16 *in, __fp16 *out,
+                   std::vector<std::vector<float>> freqs_cos,
+                   std::vector<std::vector<float>> freqs_sin,
+                   std::vector<float> cos_, std::vector<float> sin_,
+                   unsigned int batch, unsigned int channel,
+                   unsigned int height, unsigned int width, unsigned int dim,
+                   unsigned int from, unsigned int max_timestep,
+                   unsigned int in_size, unsigned int out_size,
+                   RunLayerContext &context);
+
+#endif
+
+} // namespace nntrainer
+#endif /* __ATTENTION_KERNELS_H__ */
\ No newline at end of file
diff --git a/nntrainer/tensor/cl_operations/attention_kernels_fp16.cpp b/nntrainer/tensor/cl_operations/attention_kernels_fp16.cpp
new file mode 100644 (file)
index 0000000..b5d0ca5
--- /dev/null
@@ -0,0 +1,279 @@
+// SPDX-License-Identifier: Apache-2.0
+/**
+ * Copyright (C) 2024 Yash Singh <yash.singh@samsung.com>
+ *
+ * @file       attention_kernels_fp16.cpp
+ * @date       28 August 2024
+ * @brief      Common attention OpenCL fp16 kernels
+ * @see                https://github.com/nnstreamer/nntrainer
+ * @author     Yash Singh <yash.singh@samsung.com>
+ * @bug                No known bugs except for NYI items
+ *
+ */
+
+#include <attention_kernels.h>
+
+namespace nntrainer {
+/**
+ * @brief OpenCL C source of the FP16 rotary embedding kernel
+ *        (kernel entry point: rotary_emb_cl_fp16)
+ *
+ * input/output are half buffers addressed as a contiguous
+ * [batch][channel][height][width] array; the cos/sin tables and scratch rows
+ * are float. For every (b, c, h):
+ *  - if from + h < max_timestep, the dim-wide row (from + h) of
+ *    freqs_cos / freqs_sin is copied into cos_ / sin_; otherwise cos_ / sin_
+ *    are used with whatever they currently hold;
+ *  - the width axis is walked in chunks of dim; each element is converted to
+ *    float, combined as x[k] * cos_[k] + r * sin_[k] (r is -x[k + half_] for
+ *    k < half_ and x[k - half_] otherwise) and stored back as half.
+ *
+ * NOTE(review): gid and gws are assigned but never used, so the loops cover
+ * the whole tensor in each work-item that executes the kernel.
+ */
+std::string rotary_emb_cl_kernel_fp16 = R"(
+  #pragma OPENCL EXTENSION cl_khr_fp16 : enable
+__kernel void rotary_emb_cl_fp16(__global half *input,
+                                      __global half *output,
+                                      __global float *freqs_cos,
+                                      __global float *freqs_sin,
+                                      __global float *cos_,
+                                      __global float *sin_,
+                                      unsigned int batch,
+                                      unsigned int channel,
+                                      unsigned int height,
+                                      unsigned int width,
+                                      unsigned int dim,
+                                      unsigned int half_,
+                                      unsigned int max_timestep,
+                                      unsigned int from) {
+    unsigned int gid = get_global_id(0);
+    unsigned int gws = get_global_size(0);
+
+    __global float *cos_ptr = cos_;
+    __global float *sin_ptr = sin_;
+
+    float value = 0.0f;
+    float transformed_value = 0.0f;
+
+    for (unsigned int b = 0; b < batch; b++) {
+      for (unsigned int c = 0; c < channel; c++) {
+        for (unsigned int h = 0; h < height; h++) {
+          if (from + h < max_timestep) {
+            unsigned idx = (from + h)*dim;
+            for(int i = idx; i < idx + dim; i++ ){
+              cos_ptr[i - idx] = freqs_cos[i];
+              sin_ptr[i - idx] = freqs_sin[i];
+            }
+          }
+
+          for (unsigned int w = 0; w < width; w = w + dim) {
+            for (unsigned int k = 0; k < dim; k++) {
+              unsigned int span = w + k;
+              value = (float)input[b * channel * height * width + c * height * width + h * width + span];
+              if (k < half_) {
+                transformed_value = -1.0f * (float)input[b * channel * height * width + c * height * width + h * width + span + half_];
+              } else {
+                transformed_value = (float)input[b * channel * height * width + c * height * width + h * width + span - half_];
+              }
+              value = value * cos_ptr[k] + transformed_value * sin_ptr[k];
+              output[b * channel * height * width + c * height * width + h * width + span] = (half)value;
+            }
+          }
+        }
+      }
+    }
+}
+)";
+
+/**
+ * @brief defining global kernel objects
+ */
+opencl::Kernel kernel_rotary_emb_fp16;
+
+/**
+ * @brief Host-side launcher of the FP16 rotary embedding kernel
+ *
+ * Creates the kernel, uploads the half input/output and the float frequency
+ * tables and scratch rows, binds the 14 kernel arguments, dispatches the
+ * kernel and reads the result back into out. Every failing step prints a
+ * message naming that step and ends the call; no error is reported to the
+ * caller.
+ */
+void rotary_emb_cl(__fp16 *in, __fp16 *out,
+                   std::vector<std::vector<float>> freqs_cos,
+                   std::vector<std::vector<float>> freqs_sin,
+                   std::vector<float> cos_, std::vector<float> sin_,
+                   unsigned int batch, unsigned int channel,
+                   unsigned int height, unsigned int width, unsigned int dim,
+                   unsigned int from, unsigned int max_timestep,
+                   unsigned int in_size, unsigned int out_size,
+                   RunLayerContext &context) {
+  if (!context.clCreateKernel(rotary_emb_cl_kernel_fp16,
+                              context.LayerKernel::ROTARY_EMB_FP16,
+                              kernel_rotary_emb_fp16)) {
+    // Name the FP16 kernel so the log is distinguishable from the FP32 path.
+    printf("Failed to create kernel for rotary_emb_cl_fp16\n");
+    return;
+  }
+
+  unsigned int cos_dim = cos_.size();
+  unsigned int sin_dim = sin_.size();
+  unsigned int freqs_cos_dim = freqs_cos.size();
+  unsigned int freqs_sin_dim = freqs_sin.size();
+
+  // Input/output are half precision; tables and scratch rows stay float.
+  size_t dim1_size = sizeof(cl_half) * in_size;
+  size_t dim2_size = sizeof(cl_half) * out_size;
+  size_t dim3_size = sizeof(float) * cos_dim;
+  size_t dim4_size = sizeof(float) * sin_dim;
+  size_t dim5_size = sizeof(float) * freqs_cos_dim * dim;
+  size_t dim6_size = sizeof(float) * freqs_sin_dim * dim;
+
+  opencl::Buffer inputA(context.context_inst_, dim1_size, true, nullptr);
+  opencl::Buffer inOutRes(context.context_inst_, dim2_size, true, nullptr);
+  opencl::Buffer cosBuf(context.context_inst_, dim3_size, true, nullptr);
+  opencl::Buffer sinBuf(context.context_inst_, dim4_size, true, nullptr);
+  opencl::Buffer freqs_cosBuf(context.context_inst_, dim5_size, true,
+                              nullptr);
+  opencl::Buffer freqs_sinBuf(context.context_inst_, dim6_size, true,
+                              nullptr);
+
+  // The kernel expects the 2-D tables as one row-major float array.
+  std::vector<float> freqs_cos_flat;
+  std::vector<float> freqs_sin_flat;
+  for (const auto &row : freqs_cos) {
+    freqs_cos_flat.insert(freqs_cos_flat.end(), row.begin(), row.end());
+  }
+  for (const auto &row : freqs_sin) {
+    freqs_sin_flat.insert(freqs_sin_flat.end(), row.begin(), row.end());
+  }
+
+  // Copies host data into a device buffer; prints failure_msg on failure.
+  auto upload = [&context](opencl::Buffer &buffer, const void *host_data,
+                           const char *failure_msg) -> bool {
+    if (buffer.WriteData(context.command_queue_inst_, host_data)) {
+      return true;
+    }
+    printf("%s", failure_msg);
+    return false;
+  };
+
+  // Binds one kernel argument; prints failure_msg on failure.
+  auto bind_arg = [](unsigned int index, const void *value, size_t size,
+                     const char *failure_msg) -> bool {
+    if (kernel_rotary_emb_fp16.SetKernelArguments(index, value, size)) {
+      return true;
+    }
+    printf("%s", failure_msg);
+    return false;
+  };
+
+  // && short-circuits, so the first failing upload stops the sequence.
+  // The freqs table messages are distinct from the scratch-row messages.
+  const bool uploaded =
+    upload(inputA, in, "Failed to write input data\n") &&
+    upload(inOutRes, out, "Failed to write output data\n") &&
+    upload(freqs_cosBuf, freqs_cos_flat.data(),
+           "Failed to write freqs cos data\n") &&
+    upload(freqs_sinBuf, freqs_sin_flat.data(),
+           "Failed to write freqs sin data\n") &&
+    upload(cosBuf, cos_.data(), "Failed to write cos data\n") &&
+    upload(sinBuf, sin_.data(), "Failed to write sin data\n");
+  if (!uploaded) {
+    return;
+  }
+
+  unsigned int half_ = dim / 2;
+
+  // Argument indices follow the parameter order of the kernel source.
+  const bool bound =
+    bind_arg(0, &inputA, sizeof(cl_mem), "Failed to set inputA argument\n") &&
+    bind_arg(1, &inOutRes, sizeof(cl_mem),
+             "Failed to set inOutRes argument\n") &&
+    bind_arg(2, &freqs_cosBuf, sizeof(cl_mem),
+             "Failed to set freqs_cosBuf argument\n") &&
+    bind_arg(3, &freqs_sinBuf, sizeof(cl_mem),
+             "Failed to set freqs_sinBuf argument\n") &&
+    bind_arg(4, &cosBuf, sizeof(cl_mem), "Failed to set cosBuf argument\n") &&
+    bind_arg(5, &sinBuf, sizeof(cl_mem), "Failed to set sinBuf argument\n") &&
+    bind_arg(6, &batch, sizeof(int), "Failed to set batch argument\n") &&
+    bind_arg(7, &channel, sizeof(int), "Failed to set channel argument\n") &&
+    bind_arg(8, &height, sizeof(int), "Failed to set height argument\n") &&
+    bind_arg(9, &width, sizeof(int), "Failed to set width argument\n") &&
+    bind_arg(10, &dim, sizeof(int), "Failed to set dim argument\n") &&
+    bind_arg(11, &half_, sizeof(int), "Failed to set half argument\n") &&
+    bind_arg(12, &max_timestep, sizeof(int),
+             "Failed to set timestamp argument\n") &&
+    bind_arg(13, &from, sizeof(int), "Failed to set from argument\n");
+  if (!bound) {
+    return;
+  }
+
+  const int work_groups_count[3] = {1, 1, 1};
+  const int work_group_size[3] = {32, 1, 1}; // test-value
+  if (!context.command_queue_inst_.DispatchCommand(
+        kernel_rotary_emb_fp16, work_groups_count, work_group_size)) {
+    printf("Failed to dispatch command\n");
+    return;
+  }
+
+  if (!inOutRes.ReadData(context.command_queue_inst_, out)) {
+    printf("Failed to read data\n");
+  }
+}
+} // namespace nntrainer
\ No newline at end of file
index 43e95f7fe94b2ba852313b49e6d975ee30aa1022..3f186ec645a61010b05ffe1b363fa4e7d8a1cb53 100644 (file)
@@ -1,15 +1,19 @@
 cl_op_sources = [
   'blas_kernels.cpp',
   'blas_kernel_interface.cpp',
+  'attention_kernel_interface.cpp',
+  'attention_kernels.cpp',
 ]
 
 cl_op_headers = [
   'blas_kernel_interface.h',
   'blas_kernel_strings.h',
+  'attention_kernel_interface.h',
 ]
 
 if get_option('enable-fp16')
   cl_op_sources += 'blas_kernels_fp16.cpp'
+  cl_op_sources += 'attention_kernels_fp16.cpp'
 endif
 
 foreach s : cl_op_sources
diff --git a/nntrainer/tensor/cl_operations/testing_rotarty_emb.cpp b/nntrainer/tensor/cl_operations/testing_rotarty_emb.cpp
new file mode 100644 (file)
index 0000000..c13cbc0
--- /dev/null
@@ -0,0 +1,195 @@
+// SPDX-License-Identifier: Apache-2.0
+/**
+ * Copyright (C) 2024 Yash Singh <yash.singh@samsung.com>
+ *
+ * @file       testing_rotary_emb.cpp
+ * @date       28 August 2024
+ * @brief      Rotary Embedding CPU code
+ * @see                https://github.com/nnstreamer/nntrainer
+ * @author     Yash Singh <yash.singh@samsung.com>
+ * @bug                No known bugs except for NYI items
+ *
+ */
+
+#include "tensor.h"
+#include <cmath>
+#include <stdexcept>
+#include <string>
+#include <utility>
+#include <vector>
+
+/**
+ * @brief     compute frequency for rotary embedding
+ * @param[in] dim hidden dim size
+ * @param[in] seq_len sequence length
+ * @param[out] freqs_cos cosine of the frequencies, seq_len rows of dim values;
+ * nothing is computed when this is already non-empty
+ * @param[out] freqs_sin sine of the frequencies, same layout as freqs_cos
+ * @param[out] freqs base frequencies array to be used in computation of cos and
+ * sin values for each position in sequence (dim / 2 entries)
+ * @param[in] theta rotary angle
+ */
+void precompute_freqs(int dim, unsigned int seq_len,
+                      std::vector<std::vector<float>> &freqs_cos,
+                      std::vector<std::vector<float>> &freqs_sin,
+                      std::vector<float> &freqs, float theta = 10000.0) {
+  // A populated cosine table means the work has been done already.
+  if (!freqs_cos.empty()) {
+    return;
+  }
+
+  const unsigned int half_ = dim / 2;
+
+  // Drop anything left in freqs: the loops below index it from 0, so stale
+  // leading entries would silently produce wrong angles.
+  freqs.clear();
+  freqs.reserve(half_);
+  for (unsigned int i = 0; i < half_; ++i) {
+    freqs.push_back(1.0 /
+                    (std::pow(theta, (2 * i) / static_cast<float>(dim))));
+  }
+
+  std::vector<std::vector<float>> cos_table(seq_len,
+                                            std::vector<float>(dim, 0));
+  std::vector<std::vector<float>> sin_table(seq_len,
+                                            std::vector<float>(dim, 0));
+
+  for (unsigned int i = 0; i < seq_len; ++i) {
+#ifdef USE_NEON
+    nntrainer::calc_trigonometric_vals_dup(half_, freqs.data(),
+                                           cos_table[i].data(),
+                                           sin_table[i].data(), i);
+#else
+    for (unsigned int j = 0; j < half_; ++j) {
+      const float angle = i * freqs[j];
+      const float cos_val = std::cos(angle);
+      const float sin_val = std::sin(angle);
+
+      // each value is stored twice: once per half of the row
+      cos_table[i][j] = cos_val;
+      cos_table[i][j + half_] = cos_val;
+
+      sin_table[i][j] = sin_val;
+      sin_table[i][j + half_] = sin_val;
+    }
+#endif
+  }
+
+  // Hand the tables over without a second deep copy.
+  freqs_cos = std::move(cos_table);
+  freqs_sin = std::move(sin_table);
+}
+
+/**
+ * @brief     apply rotary embedding
+ * @param[in,out] in input tensor, overwritten in place with the rotated values
+ * @param[in] dim hidden dim size
+ * @param[in] from sequence order
+ * @param[in] max_timestep maximum timestep
+ * @throws std::invalid_argument if the tensor is FP16 and ENABLE_FP16 is not
+ * defined
+ */
+void apply_rotary_emb_tensor(nntrainer::Tensor &in, unsigned int dim,
+                             unsigned int from, unsigned int max_timestep) {
+  nntrainer::Tensor out(in.getDim());
+  float value = 0;
+  float transformed_value = 0.0;
+  unsigned int half_ = dim / 2;
+
+  std::vector<std::vector<float>> freqs_cos = {};
+  std::vector<std::vector<float>> freqs_sin = {};
+  std::vector<float> freqs;
+
+  // Builds max_timestep rows of dim cos/sin values each.
+  precompute_freqs(dim, max_timestep, freqs_cos, freqs_sin, freqs);
+
+  // Row of cos/sin factors currently in use.
+  std::vector<float> cos_;
+  std::vector<float> sin_;
+
+  if (from >= max_timestep) {
+    // `from` lies outside the precomputed tables: compute its row directly.
+    // No row is loaded from the tables in the loops below in this case.
+    cos_ = std::vector<float>(dim);
+    sin_ = std::vector<float>(dim);
+#ifdef USE_NEON
+    nntrainer::calc_trigonometric_vals_dup(half_, freqs.data(), cos_.data(),
+                                           sin_.data(), from);
+#else
+    for (unsigned int i = 0; i < half_; ++i) {
+      float angle = from * freqs[i];
+      cos_[i] = std::cos(angle);
+      cos_[i + half_] = std::cos(angle); // repeated 2 times
+
+      sin_[i] = std::sin(angle);
+      sin_[i + half_] = std::sin(angle); // repeated 2 times
+    }
+#endif
+  }
+  // When from < max_timestep the h == 0 iteration below always loads the row
+  // for `from` before cos_/sin_ are read, so they need no pre-sizing here.
+
+  if (in.getDataType() == ml::train::TensorDim::DataType::FP32) {
+    for (unsigned int b = 0; b < in.batch(); b++) {
+      for (unsigned int c = 0; c < in.channel(); c++) {
+        for (unsigned int h = 0; h < in.height(); h++) {
+          // Once from + h reaches max_timestep the previously loaded row keeps
+          // being used.
+          // NOTE(review): presumably intentional so that this reference matches
+          // the OpenCL kernel it is compared against - confirm.
+          if (from + h < max_timestep) {
+            cos_ = freqs_cos[from + h];
+            sin_ = freqs_sin[from + h];
+          }
+
+          for (unsigned int w = 0; w < in.width(); w = w + dim) {
+            for (unsigned int k = 0; k < dim; k++) {
+              unsigned int span = w + k;
+              // NOTE(review): only span itself is bounds-checked; span + half_
+              // is assumed to be valid, i.e. width is assumed to be a multiple
+              // of dim - confirm against callers.
+              if (span < in.width()) {
+                value = in.getValue<float>(b, c, h, span);
+                // rotate-half: first half pairs with the negated second half,
+                // second half pairs with the first half
+                if (k < half_) {
+                  transformed_value =
+                    -1.0 * in.getValue<float>(b, c, h, span + half_);
+                } else {
+                  transformed_value = in.getValue<float>(b, c, h, span - half_);
+                }
+                value = value * cos_[k] + transformed_value * sin_[k];
+                out.setValue(b, c, h, span, value);
+              }
+            }
+          }
+        }
+      }
+    }
+  } else if (in.getDataType() == ml::train::TensorDim::DataType::FP16) {
+#ifdef ENABLE_FP16
+    for (unsigned int b = 0; b < in.batch(); b++) {
+      for (unsigned int c = 0; c < in.channel(); c++) {
+        for (unsigned int h = 0; h < in.height(); h++) {
+          // Same row selection rule as in the FP32 branch.
+          if (from + h < max_timestep) {
+            cos_ = freqs_cos[from + h];
+            sin_ = freqs_sin[from + h];
+          }
+          for (unsigned int w = 0; w < in.width(); w = w + dim) {
+#ifdef USE_NEON
+            nntrainer::compute_rotary_embedding_value(
+              dim, half_, w, in.getData<_FP16>() + in.getIndex(b, c, h, 0),
+              out.getData<_FP16>() + out.getIndex(b, c, h, 0), cos_.data(),
+              sin_.data());
+#else
+            // NOTE(review): unlike the FP32 branch there is no span < width
+            // check here, so width is assumed to be a multiple of dim.
+            for (unsigned int k = 0; k < dim; k++) {
+              unsigned int span = w + k;
+              value = static_cast<float>(in.getValue<_FP16>(b, c, h, span));
+
+              if (k < half_) {
+                transformed_value =
+                  -1.0 *
+                  static_cast<float>(in.getValue<_FP16>(b, c, h, half_ + span));
+              } else {
+                transformed_value =
+                  static_cast<float>(in.getValue<_FP16>(b, c, h, span - half_));
+              }
+              out.setValue(b, c, h, span,
+                           static_cast<_FP16>(value * cos_[k] +
+                                              transformed_value * sin_[k]));
+            }
+#endif
+          }
+        }
+      }
+    }
+#else
+    throw std::invalid_argument("Error: enable-fp16 is not enabled");
+#endif
+  }
+
+  // Write the result back into the caller's tensor.
+  in.copy(out);
+}
\ No newline at end of file
index 153b4eb84020119f9d376232655c468a139a0a05..faaba46f45eb5852604580a9d518eee3d0b37bf2 100644 (file)
@@ -499,6 +499,22 @@ LOCAL_SHARED_LIBRARIES := nntrainer ccapi-nntrainer
 LOCAL_STATIC_LIBRARIES := googletest_main test_util
 include $(BUILD_EXECUTABLE)
 
+# unittest_attention_kernels_cl
+include $(CLEAR_VARS)
+
+LOCAL_MODULE := unittest_attention_kernels_cl
+LOCAL_CFLAGS := -Igoogletest/include -I../include -I../unittest/layers -I../../nntrainer/layers/loss -pthread -fexceptions -fopenmp -static-openmp -DMIN_CPP_VERSION=201703L -DNNTR_NUM_THREADS=1 -D__LOGGING__=1 -DENABLE_TEST=1 -DREDUCE_TOLERANCE=1 -march=armv8.2-a+fp16 -mfpu=neon-fp16 -mfloat-abi=softfp -O3 -frtti -DNDK_BUILD=1 -DENABLE_FP16=1 -DENABLE_OPENCL=1 
+LOCAL_CXXFLAGS      += -std=c++17 -frtti -fexceptions
+LOCAL_LDLIBS        := -llog -landroid -fopenmp -static-openmp
+
+LOCAL_SRC_FILES := \
+        ../unittest/unittest_attention_kernels_cl.cpp
+
+LOCAL_C_INCLUDES += $(NNTRAINER_INCLUDES)
+
+LOCAL_SHARED_LIBRARIES := nntrainer ccapi-nntrainer
+LOCAL_STATIC_LIBRARIES := googletest_main test_util
+include $(BUILD_EXECUTABLE)
+
 # unittest_ccapi
 include $(CLEAR_VARS)
 
diff --git a/test/unittest/unittest_attention_kernels_cl.cpp b/test/unittest/unittest_attention_kernels_cl.cpp
new file mode 100644 (file)
index 0000000..7a09e5c
--- /dev/null
@@ -0,0 +1,233 @@
+// SPDX-License-Identifier: Apache-2.0
+/**
+ * Copyright (C) 2024 Yash Singh <yash.singh@samsung.com>
+ *
+ * @file       unittest_attention_kernels_cl.cpp
+ * @date       28 August 2024
+ * @brief      Test setup for attention OpenCL kernels
+ * @see                https://github.com/nnstreamer/nntrainer
+ * @author     Yash Singh <yash.singh@samsung.com>
+ * @bug                No known bugs except for NYI items
+ */
+
+#include <fstream>
+#include <gtest/gtest.h>
+#include <type_traits>
+
+#include "nntrainer_test_util.h"
+#include "util_func.h"
+#include <attention_kernel_interface.h>
+#include <cl_context.h>
+#include <layer_context.h>
+#include <tensor.h>
+
+#include "testing_rotarty_emb.cpp"
+
+/**
+ * @brief Expect MIN <= VAL <= MAX (non-fatal, both bounds are checked)
+ * @note VAL is evaluated twice. Wrapped in do/while so the macro behaves as a
+ * single statement, e.g. under an unbraced if/else.
+ */
+#define EXPECT_IN_RANGE(VAL, MIN, MAX) \
+  do {                                 \
+    EXPECT_GE((VAL), (MIN));           \
+    EXPECT_LE((VAL), (MAX));           \
+  } while (0)
+
+using namespace nntrainer;
+
+/**
+ * @brief Call nntrainer::ClContext::Global() and create a run context
+ * @return default-constructed RunLayerContext
+ * @note The call presumably makes sure the global OpenCL context exists before
+ * any kernel is used - confirm against ClContext.
+ */
+static RunLayerContext setUpGpuContext() {
+  // Only the call itself is needed; the returned reference is not used.
+  static_cast<void>(nntrainer::ClContext::Global());
+
+  return RunLayerContext();
+}
+
+/**
+ * @brief Compare apply_rotary_emb_cl with the CPU reference on a 1x1x4x4 FP32
+ * tensor, with from == max_timestep
+ */
+TEST(attention_kernels, rotary_emb_kernel_FP32) {
+  RunLayerContext rc = setUpGpuContext();
+
+  // These four keep their exact names: GEN_TEST_INPUT is defined outside this
+  // file and presumably reads batch/channel/height/width directly - confirm
+  // before renaming.
+  int batch = 1;
+  int channel = 1;
+  int height = 4;
+  int width = 4;
+
+  unsigned int dim = 2;
+  unsigned int from = 4;
+  unsigned int max_timestep = 4;
+
+  const float alpha = 1e-1;
+  const int MOD = 10;
+
+  nntrainer::TensorDim::TensorType fp32_nchw = {nntrainer::Tformat::NCHW,
+                                                nntrainer::Tdatatype::FP32};
+
+  nntrainer::Tensor cl_out(batch, channel, height, width, fp32_nchw);
+  nntrainer::Tensor cpu_out(batch, channel, height, width, fp32_nchw);
+
+  GEN_TEST_INPUT(cl_out, ((i * (batch * height * channel) +
+                           j * (batch * height) + k * (width) + l + 1) %
+                          MOD) *
+                           alpha);
+
+  // Both implementations start from the same data.
+  cpu_out.copy(cl_out);
+
+  apply_rotary_emb_cl(cl_out, dim, from, max_timestep, rc);
+  apply_rotary_emb_tensor(cpu_out, dim, from, max_timestep);
+
+  float mse_error =
+    mse<float>(cl_out.getData<float>(), cpu_out.getData<float>(), cl_out.size());
+
+  double cos_sim = cosine_similarity<float>(
+    cl_out.getData<float>(), cpu_out.getData<float>(), cl_out.size());
+
+  const float tolerance = 1e-3 * width;
+
+  EXPECT_IN_RANGE(mse_error, 0, tolerance);
+  EXPECT_IN_RANGE(static_cast<float>(cos_sim), 0.99, 1);
+}
+
+/**
+ * @brief Compare apply_rotary_emb_cl with the CPU reference on a 4x4x8x8 FP32
+ * tensor, with from < max_timestep < from + height
+ */
+TEST(attention_kernels, rotary_emb_kernel_FP32_case2) {
+  RunLayerContext rc = setUpGpuContext();
+
+  // These four keep their exact names: GEN_TEST_INPUT is defined outside this
+  // file and presumably reads batch/channel/height/width directly - confirm
+  // before renaming.
+  int batch = 4;
+  int channel = 4;
+  int height = 8;
+  int width = 8;
+
+  unsigned int dim = 2;
+  unsigned int from = 2;
+  unsigned int max_timestep = 4;
+
+  const float alpha = 1e-1;
+  const int MOD = 10;
+
+  nntrainer::TensorDim::TensorType fp32_nchw = {nntrainer::Tformat::NCHW,
+                                                nntrainer::Tdatatype::FP32};
+
+  nntrainer::Tensor cl_out(batch, channel, height, width, fp32_nchw);
+  nntrainer::Tensor cpu_out(batch, channel, height, width, fp32_nchw);
+
+  GEN_TEST_INPUT(cl_out, ((i * (batch * height * channel) +
+                           j * (batch * height) + k * (width) + l + 1) %
+                          MOD) *
+                           alpha);
+
+  // Both implementations start from the same data.
+  cpu_out.copy(cl_out);
+
+  apply_rotary_emb_cl(cl_out, dim, from, max_timestep, rc);
+  apply_rotary_emb_tensor(cpu_out, dim, from, max_timestep);
+
+  float mse_error =
+    mse<float>(cl_out.getData<float>(), cpu_out.getData<float>(), cl_out.size());
+
+  double cos_sim = cosine_similarity<float>(
+    cl_out.getData<float>(), cpu_out.getData<float>(), cl_out.size());
+
+  const float tolerance = 1e-3 * width;
+
+  EXPECT_IN_RANGE(mse_error, 0, tolerance);
+  EXPECT_IN_RANGE(static_cast<float>(cos_sim), 0.99, 1);
+}
+
+#ifdef ENABLE_FP16
+/**
+ * @brief Compare apply_rotary_emb_cl with the CPU reference on a 1x1x4x4 FP16
+ * tensor, with from == max_timestep
+ * @note Only built with ENABLE_FP16: without it the CPU reference throws for
+ * FP16 tensors and the half type is not available.
+ */
+TEST(attention_kernels, rotary_emb_kernel_FP16) {
+  RunLayerContext rc = setUpGpuContext();
+
+  // These four keep their exact names: GEN_TEST_INPUT is defined outside this
+  // file and presumably reads batch/channel/height/width directly - confirm
+  // before renaming.
+  int batch = 1;
+  int channel = 1;
+  int height = 4;
+  int width = 4;
+
+  unsigned int dim = 2;
+  unsigned int from = 4;
+  unsigned int max_timestep = 4;
+
+  const float alpha = 1e-1;
+
+  nntrainer::TensorDim::TensorType t_type_nchw_fp16 = {
+    nntrainer::Tformat::NCHW, nntrainer::Tdatatype::FP16};
+
+  nntrainer::Tensor A_fp16(batch, channel, height, width, t_type_nchw_fp16);
+  nntrainer::Tensor B_fp16(batch, channel, height, width, t_type_nchw_fp16);
+
+  GEN_TEST_INPUT(A_fp16, i * (batch * height * channel) * alpha +
+                           j * (batch * height) * alpha + k * (width)*alpha +
+                           l + 1);
+
+  // Both implementations start from the same data.
+  B_fp16.copy(A_fp16);
+
+  apply_rotary_emb_cl(A_fp16, dim, from, max_timestep, rc);
+  apply_rotary_emb_tensor(B_fp16, dim, from, max_timestep);
+
+  float mse_error_fp16 = mse<_FP16>(A_fp16.getData<_FP16>(),
+                                    B_fp16.getData<_FP16>(), A_fp16.size());
+
+  double cos_sim_fp16 = cosine_similarity<_FP16>(
+    A_fp16.getData<_FP16>(), B_fp16.getData<_FP16>(), A_fp16.size());
+
+  const float epsilon = 1e-3 * width;
+
+  EXPECT_IN_RANGE(mse_error_fp16, 0, epsilon);
+  EXPECT_IN_RANGE((float)cos_sim_fp16, 0.99, 1);
+}
+#endif // ENABLE_FP16
+
+#ifdef ENABLE_FP16
+/**
+ * @brief Compare apply_rotary_emb_cl with the CPU reference on a 4x4x8x8 FP16
+ * tensor, with from < max_timestep < from + height
+ * @note Only built with ENABLE_FP16: without it the CPU reference throws for
+ * FP16 tensors and the half type is not available.
+ */
+TEST(attention_kernels, rotary_emb_kernel_FP16_case2) {
+  RunLayerContext rc = setUpGpuContext();
+
+  // These four keep their exact names: GEN_TEST_INPUT is defined outside this
+  // file and presumably reads batch/channel/height/width directly - confirm
+  // before renaming.
+  int batch = 4;
+  int channel = 4;
+  int height = 8;
+  int width = 8;
+
+  unsigned int dim = 4;
+  unsigned int from = 4;
+  unsigned int max_timestep = 8;
+
+  const float alpha = 1e-1;
+
+  nntrainer::TensorDim::TensorType t_type_nchw_fp16 = {
+    nntrainer::Tformat::NCHW, nntrainer::Tdatatype::FP16};
+
+  nntrainer::Tensor A_fp16(batch, channel, height, width, t_type_nchw_fp16);
+  nntrainer::Tensor B_fp16(batch, channel, height, width, t_type_nchw_fp16);
+
+  GEN_TEST_INPUT(A_fp16, i * (batch * height * channel) * alpha +
+                           j * (batch * height) * alpha + k * (width)*alpha +
+                           l + 1);
+
+  // Both implementations start from the same data.
+  B_fp16.copy(A_fp16);
+
+  apply_rotary_emb_cl(A_fp16, dim, from, max_timestep, rc);
+  apply_rotary_emb_tensor(B_fp16, dim, from, max_timestep);
+
+  float mse_error_fp16 = mse<_FP16>(A_fp16.getData<_FP16>(),
+                                    B_fp16.getData<_FP16>(), A_fp16.size());
+
+  double cos_sim_fp16 = cosine_similarity<_FP16>(
+    A_fp16.getData<_FP16>(), B_fp16.getData<_FP16>(), A_fp16.size());
+
+  const float epsilon = 1e-3 * width;
+
+  EXPECT_IN_RANGE(mse_error_fp16, 0, epsilon);
+  EXPECT_IN_RANGE((float)cos_sim_fp16, 0.99, 1);
+}
+#endif // ENABLE_FP16
+
+/**
+ * @brief Test runner entry point
+ * @param[in] argc argument count, forwarded to InitGoogleTest
+ * @param[in] argv argument vector, forwarded to InitGoogleTest
+ * @return value of RUN_ALL_TESTS(), or -1 if initialization or the run itself
+ * threw
+ */
+GTEST_API_ int main(int argc, char **argv) {
+  int result = -1;
+
+  try {
+    testing::InitGoogleTest(&argc, argv);
+  } catch (...) {
+    std::cerr << "Error during InitGoogleTest" << std::endl;
+    // No test ran: report failure rather than exiting with success (0).
+    return result;
+  }
+
+  try {
+    result = RUN_ALL_TESTS();
+  } catch (...) {
+    std::cerr << "Error during RUN_ALL_TESTS()" << std::endl;
+  }
+
+  return result;
+}