Added initial version of Rotary Embedding kernel for GPU. This includes both FP32 and FP16 implementation got GPU kernel.
Signed-off-by: Yash Singh <yash.singh@samsung.com>
--- /dev/null
+// SPDX-License-Identifier: Apache-2.0
+/**
+ * Copyright (C) 2024 Yash Singh <yash.singh@samsung.com>
+ *
+ * @file attention_kernel_interface.cpp
+ * @date 28 August 2024
+ * @brief Interface for attention OpenCL kernels
+ * @see https://github.com/nnstreamer/nntrainer
+ * @author Yash Singh <yash.singh@samsung.com>
+ * @bug No known bugs except for NYI items
+ *
+ */
+
+#include <attention_kernel_interface.h>
+#include <attention_kernels.h>
+
+namespace nntrainer {
+/**
+ * @brief compute frequency for rotary embedding
+ * @param[in] dim hidden dim size
+ * @param[in] seq_len sequency length
+ * @param[out] freqs_cos cosine of the frequencies
+ * @param[out] freqs_sin sine of the frequencies
+ * @param[out] freqs base frequencies array to be used in the future computation
+ * @param[in] theta rotary angle
+ */
+void precompute_freqs(int dim, unsigned int seq_len,
+ std::vector<std::vector<float>> &freqs_cos,
+ std::vector<std::vector<float>> &freqs_sin,
+ std::vector<float> &freqs, float theta = 10000.0) {
+ unsigned int half_ = dim / 2;
+ for (unsigned int i = 0; i < half_; ++i) {
+ freqs.push_back(1.0 / (std::pow(theta, (2 * i) / static_cast<float>(dim))));
+ }
+
+ auto cos = std::vector<std::vector<float>>();
+ cos.assign(seq_len, std::vector<float>(dim, 0));
+
+ auto sin = std::vector<std::vector<float>>();
+ sin.assign(seq_len, std::vector<float>(dim, 0));
+
+ for (unsigned int i = 0; i < seq_len; ++i) {
+ for (unsigned int j = 0; j < half_; ++j) {
+ float angle = i * freqs[j];
+ cos[i][j] = std::cos(angle);
+ cos[i][j + half_] = std::cos(angle); // repeated 2 times
+
+ sin[i][j] = std::sin(angle);
+ sin[i][j + half_] = std::sin(angle); // repeated 2 times
+ }
+ }
+ freqs_cos = cos;
+ freqs_sin = sin;
+}
+
+/**
+ * @brief apply rotary embedding
+ * @param[in] in input tensor
+ * @param[in] dim hidden dim size
+ * @param[in] from sequence order
+ * @param[in] max_timestep maximum timestep
+ */
+void apply_rotary_emb_cl(Tensor &in, unsigned int dim, unsigned int from,
+ unsigned int max_timestep, RunLayerContext &context) {
+ nntrainer::Tensor out(in.getDim());
+ float value = 0;
+ float transformed_value = 0.0;
+ unsigned int half_ = dim / 2;
+
+ std::vector<std::vector<float>> freqs_cos = {};
+ std::vector<std::vector<float>> freqs_sin = {};
+ std::vector<float> freqs;
+
+ precompute_freqs(dim, max_timestep, freqs_cos, freqs_sin, freqs);
+
+ std::vector<float> cos_;
+ std::vector<float> sin_;
+
+ if (from >= max_timestep) {
+ cos_.resize(dim);
+ sin_.resize(dim);
+
+ for (unsigned int i = 0; i < half_; ++i) {
+ float angle = from * freqs[i];
+ cos_[i] = std::cos(angle);
+ cos_[i + half_] = std::cos(angle); // repeated 2 times
+
+ sin_[i] = std::sin(angle);
+ sin_[i + half_] = std::sin(angle); // repeated 2 times
+ }
+ } else {
+ cos_.resize(max_timestep);
+ sin_.resize(max_timestep);
+ }
+
+ unsigned int input_batch_size, input_height, input_width, input_channels;
+ input_batch_size = in.batch();
+ input_height = in.height();
+ input_width = in.width();
+ input_channels = in.channel();
+
+ if (in.getDataType() == ml::train::TensorDim::DataType::FP32) {
+
+ unsigned int in_size = in.size();
+ unsigned int out_size = out.size();
+ float *data = in.getData();
+ float *rdata = out.getData();
+
+ rotary_emb_cl(data, rdata, freqs_cos, freqs_sin, cos_, sin_,
+ input_batch_size, input_channels, input_height, input_width,
+ dim, from, max_timestep, in_size, out_size, context);
+
+ } else if (in.getDataType() == ml::train::TensorDim::DataType::FP16) {
+#ifdef ENABLE_FP16
+
+ unsigned int in_size = in.size();
+ unsigned int out_size = out.size();
+ _FP16 *data = in.getData<_FP16>();
+ _FP16 *rdata = out.getData<_FP16>();
+
+ rotary_emb_cl(data, rdata, freqs_cos, freqs_sin, cos_, sin_,
+ input_batch_size, input_channels, input_height, input_width,
+ dim, from, max_timestep, in_size, out_size, context);
+#else
+ throw std::invalid_argument("Error: enable-fp16 is not enabled");
+#endif
+ }
+
+ if (from >= max_timestep) {
+ cos_.clear();
+ sin_.clear();
+ }
+
+ in.copy(out);
+}
+} // namespace nntrainer
\ No newline at end of file
--- /dev/null
+// SPDX-License-Identifier: Apache-2.0
+/**
+ * Copyright (C) 2024 Yash Singh <yash.singh@samsung.com>
+ *
+ * @file blas_kernel_interface.h
+ * @date 28 August 2024
+ * @brief Interface for attention OpenCL kernels
+ * @see https://github.com/nnstreamer/nntrainer
+ * @author Yash Singh <yash.singh@samsung.com>
+ * @bug No known bugs except for NYI items
+ *
+ */
+
+#ifndef __ATTENTION_KERNEL_INTERFACE_H__
+#define __ATTENTION_KERNEL_INTERFACE_H__
+
+#include <layer_context.h>
+#include <string>
+
+namespace nntrainer {
+
+/**
+ * @brief Rotary Embedding kernel
+ * @param[in] in input tensor
+ * @param[in] dim hidden dim size
+ * @param[in] from sequence order
+ * @param[in] max_timestep maximum timestep
+ * @param[in] context layer context to get the resource manager and queue id
+ */
+void apply_rotary_emb_cl(Tensor &in, unsigned int dim, unsigned int from,
+ unsigned int max_timestep, RunLayerContext &context);
+
+} // namespace nntrainer
+#endif /* __ATTENTION_KERNEL_INTERFACE_H__ */
\ No newline at end of file
--- /dev/null
+// SPDX-License-Identifier: Apache-2.0
+/**
+ * Copyright (C) 2024 Yash Singh <yash.singh@samsung.com>
+ *
+ * @file attention_kernels.cpp
+ * @date 28 August 2024
+ * @brief Common attention OpenCL kernels
+ * @see https://github.com/nnstreamer/nntrainer
+ * @author Yash Singh <yash.singh@samsung.com>
+ * @bug No known bugs except for NYI items
+ *
+ */
+
+#include <attention_kernels.h>
+
+namespace nntrainer {
+std::string rotary_emb_cl_kernel = R"(
+ #pragma OPENCL EXTENSION cl_khr_fp16 : enable
+__kernel void rotary_emb_cl(__global float *input,
+ __global float *output,
+ __global float *freqs_cos,
+ __global float *freqs_sin,
+ __global float *cos_,
+ __global float *sin_,
+ unsigned int batch,
+ unsigned int channel,
+ unsigned int height,
+ unsigned int width,
+ unsigned int dim,
+ unsigned int half_,
+ unsigned int max_timestep,
+ unsigned int from) {
+ unsigned int gid = get_global_id(0);
+ unsigned int gws = get_global_size(0);
+
+ __global float *cos_ptr = cos_;
+ __global float *sin_ptr = sin_;
+
+ float value = 0.0f;
+ float transformed_value = 0.0f;
+
+ for (unsigned int b = 0; b < batch; b++) {
+ for (unsigned int c = 0; c < channel; c++) {
+ for (unsigned int h = 0; h < height; h++) {
+ if (from + h < max_timestep) {
+ unsigned idx = (from + h)*dim;
+ for(unsigned int i = idx; i < idx + dim; i++){
+ cos_ptr[i - idx] = freqs_cos[i];
+ sin_ptr[i - idx] = freqs_sin[i];
+ }
+ }
+ for (unsigned int w = 0; w < width; w = w + dim) {
+ for (unsigned int k = 0; k < dim; k++) {
+ unsigned int span = w + k;
+ value = input[b * channel * height * width + c * height * width + h * width + span];
+ if (k < half_) {
+ transformed_value = -1.0f * input[b * channel * height * width + c * height * width + h * width + span + half_];
+ } else {
+ transformed_value = input[b * channel * height * width + c * height * width + h * width + span - half_];
+ }
+ value = value * cos_ptr[k] + transformed_value * sin_ptr[k];
+ // printf("GPU Batch: %u, Height: %u, Channel: %u, Width: %u, K: %u, Span: %u, Value: %f, Transformed Value: %f, cos_ptr[k]: %f, sin_ptr[k]: %f\n", b, h, c, w, k, span, value, transformed_value, cos_ptr[k], sin_ptr[k]);
+ output[b * channel * height * width + c * height * width + h * width + span] = value;
+ }
+ }
+ }
+ }
+ }
+}
+)";
+
+/**
+ * @brief defining global kernel objects
+ */
+opencl::Kernel kernel_rotary_emb;
+
+void rotary_emb_cl(float *in, float *out,
+ std::vector<std::vector<float>> freqs_cos,
+ std::vector<std::vector<float>> freqs_sin,
+ std::vector<float> cos_, std::vector<float> sin_,
+ unsigned int batch, unsigned int channel,
+ unsigned int height, unsigned int width, unsigned int dim,
+ unsigned int from, unsigned int max_timestep,
+ unsigned int in_size, unsigned int out_size,
+ RunLayerContext &context) {
+ bool result = false;
+
+ do {
+ result = context.clCreateKernel(
+ rotary_emb_cl_kernel, context.LayerKernel::ROTARY_EMB, kernel_rotary_emb);
+ if (!result) {
+ printf("Failed to create kernel for rotary_emb_cl\n");
+ break;
+ }
+ unsigned int cos_dim = cos_.size();
+ unsigned int sin_dim = sin_.size();
+ unsigned int freqs_cos_dim = freqs_cos.size();
+ unsigned int freqs_sin_dim = freqs_sin.size();
+
+ size_t dim1_size = sizeof(float) * in_size;
+ size_t dim2_size = sizeof(float) * out_size;
+ size_t dim3_size = sizeof(float) * cos_dim;
+ size_t dim4_size = sizeof(float) * sin_dim;
+ size_t dim5_size =
+ sizeof(float) * freqs_cos_dim * dim; // max_timestep * dim
+ size_t dim6_size = sizeof(float) * freqs_sin_dim * dim;
+
+ opencl::Buffer inputA(context.context_inst_, dim1_size, true, nullptr);
+
+ opencl::Buffer inOutRes(context.context_inst_, dim2_size, true, nullptr);
+
+ opencl::Buffer cosBuf(context.context_inst_, dim3_size, true, nullptr);
+
+ opencl::Buffer sinBuf(context.context_inst_, dim4_size, true, nullptr);
+
+ opencl::Buffer freqs_cosBuf(context.context_inst_, dim5_size, true,
+ nullptr);
+
+ opencl::Buffer freqs_sinBuf(context.context_inst_, dim6_size, true,
+ nullptr);
+
+ std::vector<float> freqs_cos_flat;
+ std::vector<float> freqs_sin_flat;
+ for (const auto &row : freqs_cos) {
+ freqs_cos_flat.insert(freqs_cos_flat.end(), row.begin(), row.end());
+ }
+ for (const auto &row : freqs_sin) {
+ freqs_sin_flat.insert(freqs_sin_flat.end(), row.begin(), row.end());
+ }
+
+ result = inputA.WriteData(context.command_queue_inst_, in);
+ if (!result) {
+ printf("Failed to write input data\n");
+ break;
+ }
+
+ result = inOutRes.WriteData(context.command_queue_inst_, out);
+ if (!result) {
+ printf("Failed to write output data\n");
+ break;
+ }
+
+ result = freqs_cosBuf.WriteData(context.command_queue_inst_,
+ freqs_cos_flat.data());
+ if (!result) {
+ printf("Failed to write freqs cos data\n");
+ break;
+ }
+
+ result = freqs_sinBuf.WriteData(context.command_queue_inst_,
+ freqs_sin_flat.data());
+ if (!result) {
+ printf("Failed to write freqs sin data\n");
+ break;
+ }
+
+ result = cosBuf.WriteData(context.command_queue_inst_, cos_.data());
+ if (!result) {
+ printf("Failed to write cos data\n");
+ break;
+ }
+
+ result = sinBuf.WriteData(context.command_queue_inst_, sin_.data());
+ if (!result) {
+ printf("Failed to write sin data\n");
+ break;
+ }
+
+ result = kernel_rotary_emb.SetKernelArguments(0, &inputA, sizeof(cl_mem));
+ if (!result) {
+ printf("Failed to set inputA argument\n");
+ break;
+ }
+
+ result = kernel_rotary_emb.SetKernelArguments(1, &inOutRes, sizeof(cl_mem));
+ if (!result) {
+ printf("Failed to set inOutRes argument\n");
+ break;
+ }
+
+ result =
+ kernel_rotary_emb.SetKernelArguments(2, &freqs_cosBuf, sizeof(cl_mem));
+ if (!result) {
+ printf("Failed to set freqs_cosBuf argument\n");
+ break;
+ }
+
+ result =
+ kernel_rotary_emb.SetKernelArguments(3, &freqs_sinBuf, sizeof(cl_mem));
+ if (!result) {
+ printf("Failed to set freqs_sinBuf argument\n");
+ break;
+ }
+
+ result = kernel_rotary_emb.SetKernelArguments(4, &cosBuf, sizeof(cl_mem));
+ if (!result) {
+ printf("Failed to set cosBuf argument\n");
+ break;
+ }
+
+ result = kernel_rotary_emb.SetKernelArguments(5, &sinBuf, sizeof(cl_mem));
+ if (!result) {
+ printf("Failed to set sinBuf argument\n");
+ break;
+ }
+
+ result = kernel_rotary_emb.SetKernelArguments(6, &batch, sizeof(int));
+ if (!result) {
+ printf("Failed to set batch argument\n");
+ break;
+ }
+
+ result = kernel_rotary_emb.SetKernelArguments(7, &channel, sizeof(int));
+ if (!result) {
+ printf("Failed to set channel argument\n");
+ break;
+ }
+
+ result = kernel_rotary_emb.SetKernelArguments(8, &height, sizeof(int));
+ if (!result) {
+ printf("Failed to set height argument\n");
+ break;
+ }
+
+ result = kernel_rotary_emb.SetKernelArguments(9, &width, sizeof(int));
+ if (!result) {
+ printf("Failed to set width argument\n");
+ break;
+ }
+
+ result = kernel_rotary_emb.SetKernelArguments(10, &dim, sizeof(int));
+ if (!result) {
+ printf("Failed to set dim argument\n");
+ break;
+ }
+ unsigned int half_ = dim / 2;
+ result = kernel_rotary_emb.SetKernelArguments(11, &half_, sizeof(int));
+ if (!result) {
+ printf("Failed to set half argument\n");
+ break;
+ }
+
+ result =
+ kernel_rotary_emb.SetKernelArguments(12, &max_timestep, sizeof(int));
+ if (!result) {
+ printf("Failed to set timestamp argument\n");
+ break;
+ }
+
+ result = kernel_rotary_emb.SetKernelArguments(13, &from, sizeof(int));
+ if (!result) {
+ printf("Failed to set from argument\n");
+ break;
+ }
+
+ const int work_groups_count[3] = {1, 1, 1};
+ const int work_group_size[3] = {32, 1, 1}; // test-value
+ result = context.command_queue_inst_.DispatchCommand(
+ kernel_rotary_emb, work_groups_count, work_group_size);
+ if (!result) {
+ printf("Failed to dispatch command\n");
+ break;
+ }
+
+ result = inOutRes.ReadData(context.command_queue_inst_, out);
+ if (!result) {
+ printf("Failed to read data\n");
+ break;
+ }
+
+ } while (false);
+}
+} // namespace nntrainer
\ No newline at end of file
--- /dev/null
+// SPDX-License-Identifier: Apache-2.0
+/**
+ * Copyright (C) 2024 Yash Singh <yash.singh@samsung.com>
+ *
+ * @file attention_kernels.h
+ * @date 28 August 2024
+ * @brief Common attention OpenCL kernels
+ * @see https://github.com/nnstreamer/nntrainer
+ * @author Yash Singh <yash.singh@samsung.com>
+ * @bug No known bugs except for NYI items
+ *
+ */
+
+#ifndef __ATTENTION_KERNELS_H__
+#define __ATTENTION_KERNELS_H__
+
+#include <layer_context.h>
+#include <opencl_buffer.h>
+#include <opencl_kernel.h>
+#include <string>
+
+namespace nntrainer {
+
+/**
+ * @brief declaring global kernel objects
+ */
+extern opencl::Kernel kernel_rotary_emb;
+
+/**
+ * @brief Rotary Embedding process
+ * @param[in] in __fp16 * input
+ * @param[in] out __fp16 * output
+ * @param[out] freqs_cos cosine of the frequencies
+ * @param[out] freqs_sin sine of the frequencies
+ * @param[in] cos_ vector of cos values
+ * @param[in] sin_ vector of sin values
+ * @param[in] batch size of batch
+ * @param[in] channel channel of input
+ * @param[in] height height of input
+ * @param[in] width width of input
+ * @param[in] dim hidden dim size
+ * @param[in] from sequence order
+ * @param[in] max_timestep max timestep
+ * @param[in] in_size size of input
+ * @param[in] out_size size of output
+ * @param[in] context RunLayerContext reference
+ */
+void rotary_emb_cl(float *in, float *out,
+ std::vector<std::vector<float>> freqs_cos,
+ std::vector<std::vector<float>> freqs_sin,
+ std::vector<float> cos_, std::vector<float> sin_,
+ unsigned int batch, unsigned int channel,
+ unsigned int height, unsigned int width, unsigned int dim,
+ unsigned int from, unsigned int max_timestamp,
+ unsigned int in_size, unsigned int out_size,
+ RunLayerContext &context);
+
+#ifdef ENABLE_FP16
+/**
+ * @brief declaring global fp16 kernel objects
+ */
+extern opencl::Kernel kernel_rotary_emb_fp16;
+
+/**
+ * @brief Rotary Embedding process
+ * @param[in] in __fp16 * input
+ * @param[in] out __fp16 * output
+ * @param[out] freqs_cos cosine of the frequencies
+ * @param[out] freqs_sin sine of the frequencies
+ * @param[in] cos_ vector of cos values
+ * @param[in] sin_ vector of sin values
+ * @param[in] batch size of batch
+ * @param[in] channel channel of input
+ * @param[in] height height of input
+ * @param[in] width width of input
+ * @param[in] dim hidden dim size
+ * @param[in] from sequence order
+ * @param[in] max_timestep max timestep
+ * @param[in] in_size size of input
+ * @param[in] out_size size of output
+ * @param[in] context RunLayerContext reference
+ */
+void rotary_emb_cl(__fp16 *in, __fp16 *out,
+ std::vector<std::vector<float>> freqs_cos,
+ std::vector<std::vector<float>> freqs_sin,
+ std::vector<float> cos_, std::vector<float> sin_,
+ unsigned int batch, unsigned int channel,
+ unsigned int height, unsigned int width, unsigned int dim,
+ unsigned int from, unsigned int max_timestamp,
+ unsigned int in_size, unsigned int out_size,
+ RunLayerContext &context);
+
+#endif
+
+} // namespace nntrainer
+#endif /* __ATTENTION_KERNELS_H__ */
\ No newline at end of file
--- /dev/null
+// SPDX-License-Identifier: Apache-2.0
+/**
+ * Copyright (C) 2024 Yash Singh <yash.singh@samsung.com>
+ *
+ * @file attention_kernels_fp16.cpp
+ * @date 28 August 2024
+ * @brief Common attention OpenCL fp16 kernels
+ * @see https://github.com/nnstreamer/nntrainer
+ * @author Yash Singh <yash.singh@samsung.com>
+ * @bug No known bugs except for NYI items
+ *
+ */
+
+#include <attention_kernels.h>
+
+namespace nntrainer {
+std::string rotary_emb_cl_kernel_fp16 = R"(
+ #pragma OPENCL EXTENSION cl_khr_fp16 : enable
+__kernel void rotary_emb_cl_fp16(__global half *input,
+ __global half *output,
+ __global float *freqs_cos,
+ __global float *freqs_sin,
+ __global float *cos_,
+ __global float *sin_,
+ unsigned int batch,
+ unsigned int channel,
+ unsigned int height,
+ unsigned int width,
+ unsigned int dim,
+ unsigned int half_,
+ unsigned int max_timestep,
+ unsigned int from) {
+ unsigned int gid = get_global_id(0);
+ unsigned int gws = get_global_size(0);
+
+ __global float *cos_ptr = cos_;
+ __global float *sin_ptr = sin_;
+
+ float value = 0.0f;
+ float transformed_value = 0.0f;
+
+ for (unsigned int b = 0; b < batch; b++) {
+ for (unsigned int c = 0; c < channel; c++) {
+ for (unsigned int h = 0; h < height; h++) {
+ if (from + h < max_timestep) {
+ unsigned idx = (from + h)*dim;
+ for(int i = idx; i < idx + dim; i++ ){
+ cos_ptr[i - idx] = freqs_cos[i];
+ sin_ptr[i - idx] = freqs_sin[i];
+ }
+ }
+
+ for (unsigned int w = 0; w < width; w = w + dim) {
+ for (unsigned int k = 0; k < dim; k++) {
+ unsigned int span = w + k;
+ value = (float)input[b * channel * height * width + c * height * width + h * width + span];
+ if (k < half_) {
+ transformed_value = -1.0f * (float)input[b * channel * height * width + c * height * width + h * width + span + half_];
+ } else {
+ transformed_value = (float)input[b * channel * height * width + c * height * width + h * width + span - half_];
+ }
+ value = value * cos_ptr[k] + transformed_value * sin_ptr[k];
+ output[b * channel * height * width + c * height * width + h * width + span] = (half)value;
+ }
+ }
+ }
+ }
+ }
+}
+)";
+
+/**
+ * @brief defining global kernel objects
+ */
+opencl::Kernel kernel_rotary_emb_fp16;
+
+void rotary_emb_cl(__fp16 *in, __fp16 *out,
+ std::vector<std::vector<float>> freqs_cos,
+ std::vector<std::vector<float>> freqs_sin,
+ std::vector<float> cos_, std::vector<float> sin_,
+ unsigned int batch, unsigned int channel,
+ unsigned int height, unsigned int width, unsigned int dim,
+ unsigned int from, unsigned int max_timestep,
+ unsigned int in_size, unsigned int out_size,
+ RunLayerContext &context) {
+
+ bool result = false;
+ do {
+ result = context.clCreateKernel(rotary_emb_cl_kernel_fp16,
+ context.LayerKernel::ROTARY_EMB_FP16,
+ kernel_rotary_emb_fp16);
+ if (!result) {
+ printf("Failed to create kernel for rotary_emb_cl\n");
+ break;
+ }
+
+ unsigned int cos_dim = cos_.size();
+ unsigned int sin_dim = sin_.size();
+ unsigned int freqs_cos_dim = freqs_cos.size();
+ unsigned int freqs_sin_dim = freqs_sin.size();
+
+ size_t dim1_size = sizeof(cl_half) * in_size;
+ size_t dim2_size = sizeof(cl_half) * out_size;
+ size_t dim3_size = sizeof(float) * cos_dim;
+ size_t dim4_size = sizeof(float) * sin_dim;
+ size_t dim5_size = sizeof(float) * freqs_cos_dim * dim;
+ size_t dim6_size = sizeof(float) * freqs_sin_dim * dim;
+
+ opencl::Buffer inputA(context.context_inst_, dim1_size, true, nullptr);
+
+ opencl::Buffer inOutRes(context.context_inst_, dim2_size, true, nullptr);
+
+ opencl::Buffer cosBuf(context.context_inst_, dim3_size, true, nullptr);
+
+ opencl::Buffer sinBuf(context.context_inst_, dim4_size, true, nullptr);
+
+ opencl::Buffer freqs_cosBuf(context.context_inst_, dim5_size, true,
+ nullptr);
+
+ opencl::Buffer freqs_sinBuf(context.context_inst_, dim6_size, true,
+ nullptr);
+
+ std::vector<float> freqs_cos_flat;
+ std::vector<float> freqs_sin_flat;
+ for (const auto &row : freqs_cos) {
+ freqs_cos_flat.insert(freqs_cos_flat.end(), row.begin(), row.end());
+ }
+ for (const auto &row : freqs_sin) {
+ freqs_sin_flat.insert(freqs_sin_flat.end(), row.begin(), row.end());
+ }
+
+ result = inputA.WriteData(context.command_queue_inst_, in);
+ if (!result) {
+ printf("Failed to write input data\n");
+ break;
+ }
+
+ result = inOutRes.WriteData(context.command_queue_inst_, out);
+ if (!result) {
+ printf("Failed to write output data\n");
+ break;
+ }
+
+ result = freqs_cosBuf.WriteData(context.command_queue_inst_,
+ freqs_cos_flat.data());
+ if (!result) {
+ printf("Failed to write cos data\n");
+ break;
+ }
+
+ result = freqs_sinBuf.WriteData(context.command_queue_inst_,
+ freqs_sin_flat.data());
+ if (!result) {
+ printf("Failed to write sin data\n");
+ break;
+ }
+
+ result = cosBuf.WriteData(context.command_queue_inst_, cos_.data());
+ if (!result) {
+ printf("Failed to write cos data\n");
+ break;
+ }
+
+ result = sinBuf.WriteData(context.command_queue_inst_, sin_.data());
+ if (!result) {
+ printf("Failed to write sin data\n");
+ break;
+ }
+
+ result =
+ kernel_rotary_emb_fp16.SetKernelArguments(0, &inputA, sizeof(cl_mem));
+ if (!result) {
+ printf("Failed to set inputA argument\n");
+ break;
+ }
+
+ result =
+ kernel_rotary_emb_fp16.SetKernelArguments(1, &inOutRes, sizeof(cl_mem));
+ if (!result) {
+ printf("Failed to set inOutRes argument\n");
+ break;
+ }
+
+ result = kernel_rotary_emb_fp16.SetKernelArguments(2, &freqs_cosBuf,
+ sizeof(cl_mem));
+ if (!result) {
+ printf("Failed to set freqs_cosBuf argument\n");
+ break;
+ }
+
+ result = kernel_rotary_emb_fp16.SetKernelArguments(3, &freqs_sinBuf,
+ sizeof(cl_mem));
+ if (!result) {
+ printf("Failed to set freqs_sinBuf argument\n");
+ break;
+ }
+
+ result =
+ kernel_rotary_emb_fp16.SetKernelArguments(4, &cosBuf, sizeof(cl_mem));
+ if (!result) {
+ printf("Failed to set cosBuf argument\n");
+ break;
+ }
+
+ result =
+ kernel_rotary_emb_fp16.SetKernelArguments(5, &sinBuf, sizeof(cl_mem));
+ if (!result) {
+ printf("Failed to set sinBuf argument\n");
+ break;
+ }
+
+ result = kernel_rotary_emb_fp16.SetKernelArguments(6, &batch, sizeof(int));
+ if (!result) {
+ printf("Failed to set batch argument\n");
+ break;
+ }
+
+ result =
+ kernel_rotary_emb_fp16.SetKernelArguments(7, &channel, sizeof(int));
+ if (!result) {
+ printf("Failed to set channel argument\n");
+ break;
+ }
+
+ result = kernel_rotary_emb_fp16.SetKernelArguments(8, &height, sizeof(int));
+ if (!result) {
+ printf("Failed to set height argument\n");
+ break;
+ }
+
+ result = kernel_rotary_emb_fp16.SetKernelArguments(9, &width, sizeof(int));
+ if (!result) {
+ printf("Failed to set width argument\n");
+ break;
+ }
+
+ result = kernel_rotary_emb_fp16.SetKernelArguments(10, &dim, sizeof(int));
+ if (!result) {
+ printf("Failed to set dim argument\n");
+ break;
+ }
+ unsigned int half_ = dim / 2;
+ result = kernel_rotary_emb_fp16.SetKernelArguments(11, &half_, sizeof(int));
+ if (!result) {
+ printf("Failed to set half argument\n");
+ break;
+ }
+
+ result =
+ kernel_rotary_emb_fp16.SetKernelArguments(12, &max_timestep, sizeof(int));
+ if (!result) {
+ printf("Failed to set timestamp argument\n");
+ break;
+ }
+
+ result = kernel_rotary_emb_fp16.SetKernelArguments(13, &from, sizeof(int));
+ if (!result) {
+ printf("Failed to set from argument\n");
+ break;
+ }
+
+ const int work_groups_count[3] = {1, 1, 1};
+ const int work_group_size[3] = {32, 1, 1}; // test-value
+ result = context.command_queue_inst_.DispatchCommand(
+ kernel_rotary_emb_fp16, work_groups_count, work_group_size);
+ if (!result) {
+ printf("Failed to dispatch command\n");
+ break;
+ }
+
+ result = inOutRes.ReadData(context.command_queue_inst_, out);
+ if (!result) {
+ printf("Failed to read data\n");
+ break;
+ }
+
+ } while (false);
+}
+} // namespace nntrainer
\ No newline at end of file
cl_op_sources = [
'blas_kernels.cpp',
'blas_kernel_interface.cpp',
+ 'attention_kernel_interface.cpp',
+ 'attention_kernels.cpp',
]
cl_op_headers = [
'blas_kernel_interface.h',
'blas_kernel_strings.h',
+ 'attention_kernel_interface.h',
]
if get_option('enable-fp16')
cl_op_sources += 'blas_kernels_fp16.cpp'
+ cl_op_sources += 'attention_kernels_fp16.cpp'
endif
foreach s : cl_op_sources
--- /dev/null
+// SPDX-License-Identifier: Apache-2.0
+/**
+ * Copyright (C) 2024 Yash Singh <yash.singh@samsung.com>
+ *
+ * @file testing_rotary_emb.cpp
+ * @date 28 August 2024
+ * @brief Rotary Embedding CPU code
+ * @see https://github.com/nnstreamer/nntrainer
+ * @author Yash Singh <yash.singh@samsung.com>
+ * @bug No known bugs except for NYI items
+ *
+ */
+
+#include "tensor.h"
+#include <string>
+
+/**
+ * @brief compute frequency for rotary embedding
+ * @param[in] dim hidden dim size
+ * @param[in] seq_len sequency length
+ * @param[out] freqs_cos cosine of the frequencies
+ * @param[out] freqs_sin sine of the frequencies
+ * @param[out] freqs base frequencies array to be used in computation of cos and
+ * sin values for each position in sequence
+ * @param[in] theta rotary angle
+ */
+void precompute_freqs(int dim, unsigned int seq_len,
+ std::vector<std::vector<float>> &freqs_cos,
+ std::vector<std::vector<float>> &freqs_sin,
+ std::vector<float> &freqs, float theta = 10000.0) {
+ if (freqs_cos.empty()) {
+ unsigned int half_ = dim / 2;
+ for (unsigned int i = 0; i < half_; ++i) {
+ freqs.push_back(1.0 /
+ (std::pow(theta, (2 * i) / static_cast<float>(dim))));
+ }
+
+ auto cos = std::vector<std::vector<float>>();
+ cos.assign(seq_len, std::vector<float>(dim, 0));
+
+ auto sin = std::vector<std::vector<float>>();
+ sin.assign(seq_len, std::vector<float>(dim, 0));
+
+ for (unsigned int i = 0; i < seq_len; ++i) {
+#ifdef USE_NEON
+ nntrainer::calc_trigonometric_vals_dup(half_, freqs.data(), cos[i].data(),
+ sin[i].data(), i);
+#else
+ for (unsigned int j = 0; j < half_; ++j) {
+ float angle = i * freqs[j];
+ cos[i][j] = std::cos(angle);
+ cos[i][j + half_] = std::cos(angle); // repeated 2 times
+
+ sin[i][j] = std::sin(angle);
+ sin[i][j + half_] = std::sin(angle); // repeated 2 times
+ }
+#endif
+ }
+ freqs_cos = cos;
+ freqs_sin = sin;
+ }
+}
+
+/**
+ * @brief apply rotary embedding
+ * @param[in] in input tensor
+ * @param[in] dim hidden dim size
+ * @param[in] from sequence order
+ * @param[in] max_timestep maximum timestep
+ */
+void apply_rotary_emb_tensor(nntrainer::Tensor &in, unsigned int dim,
+ unsigned int from, unsigned int max_timestep) {
+ nntrainer::Tensor out(in.getDim());
+ float value = 0;
+ float transformed_value = 0.0;
+ unsigned int half_ = dim / 2;
+
+ std::vector<std::vector<float>> freqs_cos = {};
+ std::vector<std::vector<float>> freqs_sin = {};
+ std::vector<float> freqs;
+
+ precompute_freqs(dim, max_timestep, freqs_cos, freqs_sin, freqs);
+
+ std::vector<float> cos_;
+ std::vector<float> sin_;
+
+ if (from >= max_timestep) {
+ cos_ = std::vector<float>(dim);
+ sin_ = std::vector<float>(dim);
+#ifdef USE_NEON
+ nntrainer::calc_trigonometric_vals_dup(half_, freqs.data(), cos_.data(),
+ sin_.data(), from);
+#else
+ for (unsigned int i = 0; i < half_; ++i) {
+ float angle = from * freqs[i];
+ cos_[i] = std::cos(angle);
+ cos_[i + half_] = std::cos(angle); // repeated 2 times
+
+ sin_[i] = std::sin(angle);
+ sin_[i + half_] = std::sin(angle); // repeated 2 times
+ }
+#endif
+ } else {
+ cos_.resize(max_timestep);
+ sin_.resize(max_timestep);
+ }
+
+ if (in.getDataType() == ml::train::TensorDim::DataType::FP32) {
+
+ unsigned int input_batch_size, input_height, input_width, input_channels;
+ input_batch_size = in.batch();
+ input_height = in.height();
+ input_width = in.width();
+ input_channels = in.channel();
+
+ for (unsigned int b = 0; b < in.batch(); b++) {
+ for (unsigned int c = 0; c < in.channel(); c++) {
+ for (unsigned int h = 0; h < in.height(); h++) {
+ if (from + h < max_timestep) {
+ cos_ = freqs_cos[from + h];
+ sin_ = freqs_sin[from + h];
+ }
+
+ for (unsigned int w = 0; w < in.width(); w = w + dim) {
+ for (unsigned int k = 0; k < dim; k++) {
+ unsigned int span = w + k;
+ if (span < in.width()) {
+ value = in.getValue<float>(b, c, h, span);
+ if (k < half_) {
+ transformed_value =
+ -1.0 * in.getValue<float>(b, c, h, span + half_);
+ } else {
+ transformed_value = in.getValue<float>(b, c, h, span - half_);
+ }
+ value = value * cos_[k] + transformed_value * sin_[k];
+ // printf("CPU Batch: %u, Channel: %u, Height: %u, Width: %u, K:
+ // %u, Span: %u, Value: %f, Transformed Value: %f, cos_ptr[k]:
+ // %f, sin_ptr[k]: %f\n ", b, c, h, w, k, span, value,
+ // transformed_value, cos_[k], sin_[k]);
+ out.setValue(b, c, h, span, value);
+ }
+ }
+ }
+ }
+ }
+ }
+ } else if (in.getDataType() == ml::train::TensorDim::DataType::FP16) {
+#ifdef ENABLE_FP16
+ for (unsigned int b = 0; b < in.batch(); b++) {
+ for (unsigned int c = 0; c < in.channel(); c++) {
+ for (unsigned int h = 0; h < in.height(); h++) {
+ if (from + h < max_timestep) {
+ cos_ = freqs_cos[from + h];
+ sin_ = freqs_sin[from + h];
+ }
+ for (unsigned int w = 0; w < in.width(); w = w + dim) {
+#ifdef USE_NEON
+ nntrainer::compute_rotary_embedding_value(
+ dim, half_, w, in.getData<_FP16>() + in.getIndex(b, c, h, 0),
+ out.getData<_FP16>() + out.getIndex(b, c, h, 0), cos_.data(),
+ sin_.data());
+#else
+ for (unsigned int k = 0; k < dim; k++) {
+ unsigned int span = w + k;
+ value = static_cast<float>(in.getValue<_FP16>(b, c, h, span));
+
+ if (k < half_) {
+ transformed_value =
+ -1.0 *
+ static_cast<float>(in.getValue<_FP16>(b, c, h, half_ + span));
+ } else {
+ transformed_value =
+ static_cast<float>(in.getValue<_FP16>(b, c, h, span - half_));
+ }
+ out.setValue(b, c, h, span,
+ static_cast<_FP16>(value * cos_[k] +
+ transformed_value * sin_[k]));
+ }
+#endif
+ }
+ }
+ }
+ }
+#else
+ throw std::invalid_argument("Error: enable-fp16 is not enabled");
+#endif
+ }
+
+ if (from >= max_timestep) {
+ cos_.clear();
+ sin_.clear();
+ }
+
+ in.copy(out);
+}
\ No newline at end of file
LOCAL_STATIC_LIBRARIES := googletest_main test_util
include $(BUILD_EXECUTABLE)
+include $(CLEAR_VARS)
+
+LOCAL_MODULE := unittest_attention_kernels_cl
+LOCAL_CFLAGS := -Igoogletest/include -I../include -I../unittest/layers -I../../nntrainer/layers/loss -pthread -fexceptions -fopenmp -static-openmp -DMIN_CPP_VERSION=201703L -DNNTR_NUM_THREADS=1 -D__LOGGING__=1 -DENABLE_TEST=1 -DREDUCE_TOLERANCE=1 -march=armv8.2-a+fp16 -mfpu=neon-fp16 -mfloat-abi=softfp -O3 -frtti -DNDK_BUILD=1 -DENABLE_FP16=1 -DENABLE_OPENCL=1
+LOCAL_CXXFLAGS += -std=c++17 -frtti -fexceptions
+LOCAL_LDLIBS := -llog -landroid -fopenmp -static-openmp
+
+LOCAL_SRC_FILES := \
+ ../unittest/unittest_attention_kernels_cl.cpp
+
+LOCAL_C_INCLUDES += $(NNTRAINER_INCLUDES)
+
+LOCAL_SHARED_LIBRARIES := nntrainer ccapi-nntrainer
+LOCAL_STATIC_LIBRARIES := googletest_main test_util
+include $(BUILD_EXECUTABLE)
+
# unittest_ccapi
include $(CLEAR_VARS)
--- /dev/null
+// SPDX-License-Identifier: Apache-2.0
+/**
+ * Copyright (C) 2024 Yash Singh <yash.singh@samsung.com>
+ *
+ * @file unittest_attention_kernels_cl.cpp
+ * @date 28 August 2024
+ * @brief Test setup for blas OpenCL kernels
+ * @see https://github.com/nnstreamer/nntrainer
+ * @author Yash Singh <yash.singh@samsung.com>
+ * @bug No known bugs except for NYI items
+ */
+
+#include <fstream>
+#include <gtest/gtest.h>
+#include <type_traits>
+
+#include "nntrainer_test_util.h"
+#include "util_func.h"
+#include <attention_kernel_interface.h>
+#include <cl_context.h>
+#include <layer_context.h>
+#include <tensor.h>
+
+#include "testing_rotarty_emb.cpp"
+
+#define EXPECT_IN_RANGE(VAL, MIN, MAX) \
+ EXPECT_GE((VAL), (MIN)); \
+ EXPECT_LE((VAL), (MAX))
+
+using namespace nntrainer;
+
+static RunLayerContext setUpGpuContext() {
+
+ auto &ac = nntrainer::ClContext::Global();
+ auto rc = RunLayerContext();
+
+ return rc;
+}
+
+TEST(attention_kernels, rotary_emb_kernel_FP32) {
+ RunLayerContext rc = setUpGpuContext();
+
+ int batch = 1;
+ int channel = 1;
+ int height = 4;
+ int width = 4;
+
+ unsigned int dim = 2;
+ unsigned int from = 4;
+ unsigned int max_timestep = 4;
+
+ const float alpha = 1e-1;
+ const int MOD = 10;
+
+ nntrainer::TensorDim::TensorType t_type_nchw_fp32 = {
+ nntrainer::Tformat::NCHW, nntrainer::Tdatatype::FP32};
+
+ nntrainer::Tensor A_fp32(batch, channel, height, width, t_type_nchw_fp32);
+ nntrainer::Tensor B_fp32(batch, channel, height, width, t_type_nchw_fp32);
+
+ GEN_TEST_INPUT(A_fp32, ((i * (batch * height * channel) +
+ j * (batch * height) + k * (width) + l + 1) %
+ MOD) *
+ alpha);
+
+ B_fp32.copy(A_fp32);
+
+ // std::cout << "\nA_fp32 and B_fp32 before rotary embedding:" << std::endl;
+ // for (unsigned int i = 0; i < A_fp32.size(); ++i) {
+ // std::cout << "Element " << i << " -> " << *(A_fp32.getData<float>() + i)
+ // <<"\t"<<*(B_fp32.getData<float>() + i)<< std::endl;
+ // }
+
+ apply_rotary_emb_cl(A_fp32, dim, from, max_timestep, rc);
+ apply_rotary_emb_tensor(B_fp32, dim, from, max_timestep);
+
+ float mseErrorNeon_fp32 =
+ mse<float>(A_fp32.getData<float>(), B_fp32.getData<float>(), A_fp32.size());
+
+ double cosSimNeon_fp32 = cosine_similarity<float>(
+ A_fp32.getData<float>(), B_fp32.getData<float>(), A_fp32.size());
+
+ const float epsilon = 1e-3 * width;
+
+ EXPECT_IN_RANGE(mseErrorNeon_fp32, 0, epsilon);
+ EXPECT_IN_RANGE((float)cosSimNeon_fp32, 0.99, 1);
+}
+
+TEST(attention_kernels, rotary_emb_kernel_FP32_case2) {
+ RunLayerContext rc = setUpGpuContext();
+
+ int batch = 4;
+ int channel = 4;
+ int height = 8;
+ int width = 8;
+
+ unsigned int dim = 2;
+ unsigned int from = 2;
+ unsigned int max_timestep = 4;
+
+ const float alpha = 1e-1;
+ const int MOD = 10;
+
+ nntrainer::TensorDim::TensorType t_type_nchw_fp32 = {
+ nntrainer::Tformat::NCHW, nntrainer::Tdatatype::FP32};
+
+ nntrainer::Tensor A_fp32(batch, channel, height, width, t_type_nchw_fp32);
+ nntrainer::Tensor B_fp32(batch, channel, height, width, t_type_nchw_fp32);
+
+ GEN_TEST_INPUT(A_fp32, ((i * (batch * height * channel) +
+ j * (batch * height) + k * (width) + l + 1) %
+ MOD) *
+ alpha);
+
+ B_fp32.copy(A_fp32);
+
+ apply_rotary_emb_cl(A_fp32, dim, from, max_timestep, rc);
+ apply_rotary_emb_tensor(B_fp32, dim, from, max_timestep);
+
+ float mseErrorNeon_fp32 =
+ mse<float>(A_fp32.getData<float>(), B_fp32.getData<float>(), A_fp32.size());
+
+ double cosSimNeon_fp32 = cosine_similarity<float>(
+ A_fp32.getData<float>(), B_fp32.getData<float>(), A_fp32.size());
+
+ const float epsilon = 1e-3 * width;
+
+ EXPECT_IN_RANGE(mseErrorNeon_fp32, 0, epsilon);
+ EXPECT_IN_RANGE((float)cosSimNeon_fp32, 0.99, 1);
+}
+
+TEST(attention_kernels, rotary_emb_kernel_FP16) {
+ RunLayerContext rc = setUpGpuContext();
+
+ int batch = 1;
+ int channel = 1;
+ int height = 4;
+ int width = 4;
+
+ unsigned int dim = 2;
+ unsigned int from = 4;
+ unsigned int max_timestep = 4;
+
+ const float alpha = 1e-1;
+ const int MOD = 10;
+
+ nntrainer::TensorDim::TensorType t_type_nchw_fp16 = {
+ nntrainer::Tformat::NCHW, nntrainer::Tdatatype::FP16};
+
+ nntrainer::Tensor A_fp16(batch, channel, height, width, t_type_nchw_fp16);
+ nntrainer::Tensor B_fp16(batch, channel, height, width, t_type_nchw_fp16);
+
+ GEN_TEST_INPUT(A_fp16, i * (batch * height * channel) * alpha +
+ j * (batch * height) * alpha + k * (width)*alpha +
+ l + 1);
+
+ B_fp16.copy(A_fp16);
+
+ apply_rotary_emb_cl(A_fp16, dim, from, max_timestep, rc);
+ apply_rotary_emb_tensor(B_fp16, dim, from, max_timestep);
+
+ float mseErrorNeon_fp16 = mse<__fp16>(
+ A_fp16.getData<__fp16>(), B_fp16.getData<__fp16>(), A_fp16.size());
+
+ double cosSimNeon_fp16 = cosine_similarity<__fp16>(
+ A_fp16.getData<__fp16>(), B_fp16.getData<__fp16>(), A_fp16.size());
+
+ const float epsilon = 1e-3 * width;
+
+ EXPECT_IN_RANGE(mseErrorNeon_fp16, 0, epsilon);
+ EXPECT_IN_RANGE((float)cosSimNeon_fp16, 0.99, 1);
+}
+
+TEST(attention_kernels, rotary_emb_kernel_FP16_case2) {
+ RunLayerContext rc = setUpGpuContext();
+
+ int batch = 4;
+ int channel = 4;
+ int height = 8;
+ int width = 8;
+
+ unsigned int dim = 4;
+ unsigned int from = 4;
+ unsigned int max_timestep = 8;
+
+ const float alpha = 1e-1;
+ const int MOD = 10;
+
+ nntrainer::TensorDim::TensorType t_type_nchw_fp16 = {
+ nntrainer::Tformat::NCHW, nntrainer::Tdatatype::FP16};
+
+ nntrainer::Tensor A_fp16(batch, channel, height, width, t_type_nchw_fp16);
+ nntrainer::Tensor B_fp16(batch, channel, height, width, t_type_nchw_fp16);
+
+ GEN_TEST_INPUT(A_fp16, i * (batch * height * channel) * alpha +
+ j * (batch * height) * alpha + k * (width)*alpha +
+ l + 1);
+
+ B_fp16.copy(A_fp16);
+
+ apply_rotary_emb_cl(A_fp16, dim, from, max_timestep, rc);
+ apply_rotary_emb_tensor(B_fp16, dim, from, max_timestep);
+
+ float mseErrorNeon_fp16 = mse<__fp16>(
+ A_fp16.getData<__fp16>(), B_fp16.getData<__fp16>(), A_fp16.size());
+
+ double cosSimNeon_fp16 = cosine_similarity<__fp16>(
+ A_fp16.getData<__fp16>(), B_fp16.getData<__fp16>(), A_fp16.size());
+
+ const float epsilon = 1e-3 * width;
+
+ EXPECT_IN_RANGE(mseErrorNeon_fp16, 0, epsilon);
+ EXPECT_IN_RANGE((float)cosSimNeon_fp16, 0.99, 1);
+}
+
+GTEST_API_ int main(int argc, char **argv) {
+ int result = -1;
+
+ try {
+ testing::InitGoogleTest(&argc, argv);
+ } catch (...) {
+ std::cerr << "Error during InitGoogleTest" << std::endl;
+ return 0;
+ }
+
+ try {
+ result = RUN_ALL_TESTS();
+ } catch (...) {
+ std::cerr << "Error during RUN_ALL_TESTS()" << std::endl;
+ }
+
+ return result;
+}