Added registerClKernel function to register custom OpenCL kernels as well as in-house kernels.
Modified attention kernels to remove cl_context-related dependencies.
Added initAttentionClKernels function to register default attention kernels.
Modified unittest to remove layer_context dependency.
Added attention_kernel_strings.h to keep the attention kernel strings in one place.
Rebased the PR with current log.
Signed-off-by: Yash Singh <yash.singh@samsung.com>
*/
#include <addition_layer_cl.h>
+#include <attention_kernel_strings.h>
#include <blas_kernel_strings.h>
#include <cl_context.h>
#include <concat_cl.h>
blas_kernels_initialized = true;
}
+void ClContext::initAttentionClKernels() { // one-time registration of the default attention kernels
+ if (attention_kernels_initialized) { // a previous call already registered them
+ ml_logi("ClContext: Default attention kernels already registered and "
+ "initialized");
+ return;
+ }
+ /* Register the FP32 rotary embedding kernel source under "rotary_emb_cl".
+  * NOTE(review): the handle returned by registerClKernel is not checked
+  * here, so the flag below is set even if a registration did not succeed -
+  * confirm this is intended. */
+ registerClKernel(rotary_emb_cl_kernel_, "rotary_emb_cl");
+ // The FP16 variant is registered only in builds that define ENABLE_FP16.
+#ifdef ENABLE_FP16
+ registerClKernel(rotary_emb_cl_kernel_fp16_, "rotary_emb_cl_fp16");
+#endif
+ attention_kernels_initialized = true; // later calls take the early return above
+}
+
const ClContext::SharedPtrClKernel
ClContext::registerClKernel(std::string kernel_string,
std::string kernel_name) {
*/
void initBlasClKernels();
+ /**
+ * @brief Initialize and register the default attention OpenCL kernels
+ */
+ void initAttentionClKernels();
+
/**
* @brief destructor to release opencl commandQueue
*/
// flag to check default blas kernels registered or not
bool blas_kernels_initialized = false;
+ // flag to check default attention kernels registered or not
+ bool attention_kernels_initialized = false;
+
FactoryMap<nntrainer::Layer> factory_map;
template <typename Args, typename T> struct isSupportedHelper;
* @param[in] dim hidden dim size
* @param[in] from sequence order
* @param[in] max_timestep maximum timestep
- * @param[in] context layer context to get the resource manager and queue id
*
* @todo Calling precompute_freqs in finalize to reduce code redundancy.
*/
void apply_rotary_emb_cl(Tensor &in, unsigned int dim, unsigned int from,
- unsigned int max_timestep, RunLayerContext &context) {
+ unsigned int max_timestep) {
nntrainer::Tensor out(in.getDim());
float value = 0.0f;
float transformed_value = 0.0f;
rotary_emb_cl(data, rdata, freqs_cos, freqs_sin, cos_, sin_,
input_batch_size, input_channels, input_height, input_width,
- dim, from, max_timestep, in_size, out_size, context);
+ dim, from, max_timestep, in_size, out_size);
} else if (in.getDataType() == ml::train::TensorDim::DataType::FP16) {
#ifdef ENABLE_FP16
rotary_emb_cl(data, rdata, freqs_cos, freqs_sin, cos_, sin_,
input_batch_size, input_channels, input_height, input_width,
- dim, from, max_timestep, in_size, out_size, context);
+ dim, from, max_timestep, in_size, out_size);
#else
throw std::invalid_argument("Error: enable-fp16 is not enabled");
#endif
#ifndef __ATTENTION_KERNEL_INTERFACE_H__
#define __ATTENTION_KERNEL_INTERFACE_H__
-#include <layer_context.h>
#include <string>
+#include <tensor.h>
namespace nntrainer {
* @param[in] dim hidden dim size
* @param[in] from sequence order
* @param[in] max_timestep maximum timestep
- * @param[in] context layer context to get the resource manager and queue id
*/
void apply_rotary_emb_cl(Tensor &in, unsigned int dim, unsigned int from,
- unsigned int max_timestep, RunLayerContext &context);
+ unsigned int max_timestep);
} // namespace nntrainer
#endif /* __ATTENTION_KERNEL_INTERFACE_H__ */
--- /dev/null
+// SPDX-License-Identifier: Apache-2.0
+/**
+ * Copyright (C) 2024 Yash Singh <yash.singh@samsung.com>
+ *
+ * @file attention_kernel_strings.h
+ * @date 8 October 2024
+ * @brief All attention OpenCL kernel strings
+ * @see https://github.com/nnstreamer/nntrainer
+ * @author Yash Singh <yash.singh@samsung.com>
+ * @bug No known bugs except for NYI items
+ *
+ */
+
+#ifndef __ATTENTION_KERNEL_STRINGS_H__
+#define __ATTENTION_KERNEL_STRINGS_H__
+
+#include <string>
+/*
+ * OpenCL C sources of the rotary embedding kernels, kept as raw string
+ * literals.
+ *
+ * rotary_emb_cl_kernel_ (kernel name "rotary_emb_cl", float input/output):
+ *  - The work-item with (get_global_id(0), get_global_id(1)) == (b, c)
+ *    walks every row h and every column of batch b, channel c; ids with
+ *    b >= batch or c >= channel do nothing.
+ *  - For each row h, when from + h < max_timestep, the dim floats of row
+ *    (from + h) of freqs_cos / freqs_sin are copied into cos_ / sin_;
+ *    otherwise cos_ / sin_ keep whatever they already contain.
+ *  - The row is processed in chunks of dim columns. At position k of a
+ *    chunk the element is combined with the element half_ columns ahead
+ *    (negated) when k < half_, or half_ columns behind otherwise:
+ *    output = value * cos_[k] + paired * sin_[k].
+ *
+ * NOTE(review): cos_ / sin_ are __global buffers that every work-item
+ * writes and then reads; with more than one (b, c) work-item this looks
+ * like a data race - confirm.
+ * NOTE(review): cl_khr_fp16 is enabled in this source although the kernel
+ * only uses float - confirm the pragma is needed here.
+ */
+namespace nntrainer {
+static const std::string rotary_emb_cl_kernel_ = R"(
+
+ #pragma OPENCL EXTENSION cl_khr_fp16 : enable
+
+__kernel void rotary_emb_cl(__global float *input,
+ __global float *output,
+ __global float *freqs_cos,
+ __global float *freqs_sin,
+ __global float *cos_,
+ __global float *sin_,
+ unsigned int batch,
+ unsigned int channel,
+ unsigned int height,
+ unsigned int width,
+ unsigned int dim,
+ unsigned int half_,
+ unsigned int max_timestep,
+ unsigned int from) {
+ __global float *cos_ptr = cos_;
+ __global float *sin_ptr = sin_;
+
+ float value = 0.0f;
+ float transformed_value = 0.0f;
+
+ unsigned int b = get_global_id(0);
+ unsigned int c = get_global_id(1);
+
+ if(b < batch && c < channel){
+ for (unsigned int h = 0; h < height; h++) {
+ if (from + h < max_timestep) {
+ unsigned idx = (from + h)*dim;
+ for(unsigned int i = idx; i < idx + dim; i++){
+ cos_ptr[i - idx] = freqs_cos[i];
+ sin_ptr[i - idx] = freqs_sin[i];
+ }
+ }
+
+ for (unsigned int w = 0; w < width; w = w + dim) {
+ for (unsigned int k = 0; k < dim; k++) {
+ unsigned int span = w + k;
+ value = input[b * channel * height * width + c * height * width + h * width + span];
+ if (k < half_) {
+ transformed_value = -1.0f * input[b * channel * height * width + c * height * width + h * width + span + half_];
+ } else {
+ transformed_value = input[b * channel * height * width + c * height * width + h * width + span - half_];
+ }
+ value = value * cos_ptr[k] + transformed_value * sin_ptr[k];
+ output[b * channel * height * width + c * height * width + h * width + span] = value;
+ }
+ }
+ }
+ }
+}
+)";
+/*
+ * rotary_emb_cl_kernel_fp16_ (kernel name "rotary_emb_cl_fp16"), compiled
+ * in only when ENABLE_FP16 is defined. It follows the same steps as the
+ * float kernel above, but input and output are half: each element is
+ * converted to float for the arithmetic and the result is cast back to
+ * half on store. freqs_cos, freqs_sin, cos_ and sin_ remain float.
+ *
+ * NOTE(review): the row-copy loop here declares `int i` and compares it
+ * with an unsigned bound, whereas the float kernel uses `unsigned int i` -
+ * confirm whether the two should match.
+ */
+#ifdef ENABLE_FP16
+static const std::string rotary_emb_cl_kernel_fp16_ = R"(
+
+ #pragma OPENCL EXTENSION cl_khr_fp16 : enable
+
+__kernel void rotary_emb_cl_fp16(__global half *input,
+ __global half *output,
+ __global float *freqs_cos,
+ __global float *freqs_sin,
+ __global float *cos_,
+ __global float *sin_,
+ unsigned int batch,
+ unsigned int channel,
+ unsigned int height,
+ unsigned int width,
+ unsigned int dim,
+ unsigned int half_,
+ unsigned int max_timestep,
+ unsigned int from) {
+ __global float *cos_ptr = cos_;
+ __global float *sin_ptr = sin_;
+
+ float value = 0.0f;
+ float transformed_value = 0.0f;
+
+ unsigned int b = get_global_id(0);
+ unsigned int c = get_global_id(1);
+
+ if(b < batch && c < channel){
+ for (unsigned int h = 0; h < height; h++) {
+ if (from + h < max_timestep) {
+ unsigned idx = (from + h)*dim;
+ for(int i = idx; i < idx + dim; i++ ){
+ cos_ptr[i - idx] = freqs_cos[i];
+ sin_ptr[i - idx] = freqs_sin[i];
+ }
+ }
+
+ for (unsigned int w = 0; w < width; w = w + dim) {
+ for (unsigned int k = 0; k < dim; k++) {
+ unsigned int span = w + k;
+ value = (float)input[b * channel * height * width + c * height * width + h * width + span];
+ if (k < half_) {
+ transformed_value = -1.0f * (float)input[b * channel * height * width + c * height * width + h * width + span + half_];
+ } else {
+ transformed_value = (float)input[b * channel * height * width + c * height * width + h * width + span - half_];
+ }
+ value = value * cos_ptr[k] + transformed_value * sin_ptr[k];
+ output[b * channel * height * width + c * height * width + h * width + span] = (half)value;
+ }
+ }
+ }
+ }
+}
+)";
+
+#endif
+} // namespace nntrainer
+#endif /* __ATTENTION_KERNEL_STRINGS_H__ */
*
*/
+#include <attention_kernel_strings.h>
#include <attention_kernels.h>
namespace nntrainer {
-std::string rotary_emb_cl_kernel = R"(
- #pragma OPENCL EXTENSION cl_khr_fp16 : enable
-__kernel void rotary_emb_cl(__global float *input,
- __global float *output,
- __global float *freqs_cos,
- __global float *freqs_sin,
- __global float *cos_,
- __global float *sin_,
- unsigned int batch,
- unsigned int channel,
- unsigned int height,
- unsigned int width,
- unsigned int dim,
- unsigned int half_,
- unsigned int max_timestep,
- unsigned int from) {
- __global float *cos_ptr = cos_;
- __global float *sin_ptr = sin_;
-
- float value = 0.0f;
- float transformed_value = 0.0f;
-
- unsigned int b = get_global_id(0);
- unsigned int c = get_global_id(1);
-
- if(b < batch && c < channel){
- for (unsigned int h = 0; h < height; h++) {
- if (from + h < max_timestep) {
- unsigned idx = (from + h)*dim;
- for(unsigned int i = idx; i < idx + dim; i++){
- cos_ptr[i - idx] = freqs_cos[i];
- sin_ptr[i - idx] = freqs_sin[i];
- }
- }
-
- for (unsigned int w = 0; w < width; w = w + dim) {
- for (unsigned int k = 0; k < dim; k++) {
- unsigned int span = w + k;
- value = input[b * channel * height * width + c * height * width + h * width + span];
- if (k < half_) {
- transformed_value = -1.0f * input[b * channel * height * width + c * height * width + h * width + span + half_];
- } else {
- transformed_value = input[b * channel * height * width + c * height * width + h * width + span - half_];
- }
- value = value * cos_ptr[k] + transformed_value * sin_ptr[k];
- output[b * channel * height * width + c * height * width + h * width + span] = value;
- }
- }
- }
- }
-}
-)";
-
-/**
- * @brief defining global kernel objects
- */
-opencl::Kernel kernel_rotary_emb;
void rotary_emb_cl(float *in, float *out,
std::vector<std::vector<float>> freqs_cos,
unsigned int batch, unsigned int channel,
unsigned int height, unsigned int width, unsigned int dim,
unsigned int from, unsigned int max_timestep,
- unsigned int in_size, unsigned int out_size,
- RunLayerContext &context) {
+ unsigned int in_size, unsigned int out_size) {
bool result = false;
do {
- result = context.clCreateKernel(
- rotary_emb_cl_kernel, context.LayerKernel::ROTARY_EMB, kernel_rotary_emb);
- if (!result) {
- printf("Failed to create kernel for rotary_emb_cl\n");
+ ClContext::SharedPtrClKernel kernel_rotaryEmb_ptr =
+ cl_context_ref.registerClKernel(rotary_emb_cl_kernel_, "rotary_emb_cl");
+ if (!kernel_rotaryEmb_ptr) {
break;
}
+
unsigned int cos_dim = cos_.size();
unsigned int sin_dim = sin_.size();
unsigned int freqs_cos_dim = freqs_cos.size();
sizeof(float) * freqs_cos_dim * dim; // max_timestep * dim
size_t dim6_size = sizeof(float) * freqs_sin_dim * dim;
- opencl::Buffer inputA(context.context_inst_, dim1_size, true, nullptr);
+ opencl::Buffer inputA(cl_context_ref.context_inst_, dim1_size, true,
+ nullptr);
- opencl::Buffer inOutRes(context.context_inst_, dim2_size, true, nullptr);
+ opencl::Buffer inOutRes(cl_context_ref.context_inst_, dim2_size, true,
+ nullptr);
- opencl::Buffer cosBuf(context.context_inst_, dim3_size, true, nullptr);
+ opencl::Buffer cosBuf(cl_context_ref.context_inst_, dim3_size, true,
+ nullptr);
- opencl::Buffer sinBuf(context.context_inst_, dim4_size, true, nullptr);
+ opencl::Buffer sinBuf(cl_context_ref.context_inst_, dim4_size, true,
+ nullptr);
- opencl::Buffer freqs_cosBuf(context.context_inst_, dim5_size, true,
+ opencl::Buffer freqs_cosBuf(cl_context_ref.context_inst_, dim5_size, true,
nullptr);
- opencl::Buffer freqs_sinBuf(context.context_inst_, dim6_size, true,
+ opencl::Buffer freqs_sinBuf(cl_context_ref.context_inst_, dim6_size, true,
nullptr);
std::vector<float> freqs_cos_flat;
freqs_sin_flat.insert(freqs_sin_flat.end(), row.begin(), row.end());
}
- result = inputA.WriteData(context.command_queue_inst_, in);
+ result = inputA.WriteData(cl_context_ref.command_queue_inst_, in);
if (!result) {
printf("Failed to write input data\n");
break;
}
- result = inOutRes.WriteData(context.command_queue_inst_, out);
+ result = inOutRes.WriteData(cl_context_ref.command_queue_inst_, out);
if (!result) {
printf("Failed to write output data\n");
break;
}
- result = freqs_cosBuf.WriteData(context.command_queue_inst_,
+ result = freqs_cosBuf.WriteData(cl_context_ref.command_queue_inst_,
freqs_cos_flat.data());
if (!result) {
printf("Failed to write freqs cos data\n");
break;
}
- result = freqs_sinBuf.WriteData(context.command_queue_inst_,
+ result = freqs_sinBuf.WriteData(cl_context_ref.command_queue_inst_,
freqs_sin_flat.data());
if (!result) {
printf("Failed to write freqs sin data\n");
break;
}
- result = cosBuf.WriteData(context.command_queue_inst_, cos_.data());
+ result = cosBuf.WriteData(cl_context_ref.command_queue_inst_, cos_.data());
if (!result) {
printf("Failed to write cos data\n");
break;
}
- result = sinBuf.WriteData(context.command_queue_inst_, sin_.data());
+ result = sinBuf.WriteData(cl_context_ref.command_queue_inst_, sin_.data());
if (!result) {
printf("Failed to write sin data\n");
break;
}
- result = kernel_rotary_emb.SetKernelArguments(0, &inputA, sizeof(cl_mem));
+ result =
+ kernel_rotaryEmb_ptr->SetKernelArguments(0, &inputA, sizeof(cl_mem));
if (!result) {
printf("Failed to set inputA argument\n");
break;
}
- result = kernel_rotary_emb.SetKernelArguments(1, &inOutRes, sizeof(cl_mem));
+ result =
+ kernel_rotaryEmb_ptr->SetKernelArguments(1, &inOutRes, sizeof(cl_mem));
if (!result) {
printf("Failed to set inOutRes argument\n");
break;
}
- result =
- kernel_rotary_emb.SetKernelArguments(2, &freqs_cosBuf, sizeof(cl_mem));
+ result = kernel_rotaryEmb_ptr->SetKernelArguments(2, &freqs_cosBuf,
+ sizeof(cl_mem));
if (!result) {
printf("Failed to set freqs_cosBuf argument\n");
break;
}
- result =
- kernel_rotary_emb.SetKernelArguments(3, &freqs_sinBuf, sizeof(cl_mem));
+ result = kernel_rotaryEmb_ptr->SetKernelArguments(3, &freqs_sinBuf,
+ sizeof(cl_mem));
if (!result) {
printf("Failed to set freqs_sinBuf argument\n");
break;
}
- result = kernel_rotary_emb.SetKernelArguments(4, &cosBuf, sizeof(cl_mem));
+ result =
+ kernel_rotaryEmb_ptr->SetKernelArguments(4, &cosBuf, sizeof(cl_mem));
if (!result) {
printf("Failed to set cosBuf argument\n");
break;
}
- result = kernel_rotary_emb.SetKernelArguments(5, &sinBuf, sizeof(cl_mem));
+ result =
+ kernel_rotaryEmb_ptr->SetKernelArguments(5, &sinBuf, sizeof(cl_mem));
if (!result) {
printf("Failed to set sinBuf argument\n");
break;
}
- result = kernel_rotary_emb.SetKernelArguments(6, &batch, sizeof(int));
+ result = kernel_rotaryEmb_ptr->SetKernelArguments(6, &batch, sizeof(int));
if (!result) {
printf("Failed to set batch argument\n");
break;
}
- result = kernel_rotary_emb.SetKernelArguments(7, &channel, sizeof(int));
+ result = kernel_rotaryEmb_ptr->SetKernelArguments(7, &channel, sizeof(int));
if (!result) {
printf("Failed to set channel argument\n");
break;
}
- result = kernel_rotary_emb.SetKernelArguments(8, &height, sizeof(int));
+ result = kernel_rotaryEmb_ptr->SetKernelArguments(8, &height, sizeof(int));
if (!result) {
printf("Failed to set height argument\n");
break;
}
- result = kernel_rotary_emb.SetKernelArguments(9, &width, sizeof(int));
+ result = kernel_rotaryEmb_ptr->SetKernelArguments(9, &width, sizeof(int));
if (!result) {
printf("Failed to set width argument\n");
break;
}
- result = kernel_rotary_emb.SetKernelArguments(10, &dim, sizeof(int));
+ result = kernel_rotaryEmb_ptr->SetKernelArguments(10, &dim, sizeof(int));
if (!result) {
printf("Failed to set dim argument\n");
break;
}
unsigned int half_ = dim / 2;
- result = kernel_rotary_emb.SetKernelArguments(11, &half_, sizeof(int));
+ result = kernel_rotaryEmb_ptr->SetKernelArguments(11, &half_, sizeof(int));
if (!result) {
printf("Failed to set half argument\n");
break;
}
result =
- kernel_rotary_emb.SetKernelArguments(12, &max_timestep, sizeof(int));
+ kernel_rotaryEmb_ptr->SetKernelArguments(12, &max_timestep, sizeof(int));
if (!result) {
printf("Failed to set timestamp argument\n");
break;
}
- result = kernel_rotary_emb.SetKernelArguments(13, &from, sizeof(int));
+ result = kernel_rotaryEmb_ptr->SetKernelArguments(13, &from, sizeof(int));
if (!result) {
printf("Failed to set from argument\n");
break;
const int work_groups_count[3] = {(int)batch, (int)channel, 1};
const int work_group_size[3] = {32, 32, 1}; // test-value
- result = context.command_queue_inst_.DispatchCommand(
- kernel_rotary_emb, work_groups_count, work_group_size);
+ result = cl_context_ref.command_queue_inst_.DispatchCommand(
+ kernel_rotaryEmb_ptr, work_groups_count, work_group_size);
if (!result) {
printf("Failed to dispatch command\n");
break;
}
- result = inOutRes.ReadData(context.command_queue_inst_, out);
+ result = inOutRes.ReadData(cl_context_ref.command_queue_inst_, out);
if (!result) {
printf("Failed to read data\n");
break;
#ifndef __ATTENTION_KERNELS_H__
#define __ATTENTION_KERNELS_H__
-#include <layer_context.h>
+#include <cl_context.h>
#include <opencl_buffer.h>
#include <opencl_kernel.h>
#include <string>
namespace nntrainer {
-/**
- * @brief declaring global kernel objects
- */
-extern opencl::Kernel kernel_rotary_emb;
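+// ClContext instance used by the attention kernels to register kernels and
+// to reach the OpenCL context and command queue.
+// NOTE(review): this is a `static` object defined in a header, so each
+// translation unit that includes it gets its own ClContext; it is not a
+// reference to ClContext::Global(). Confirm this is intended.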
+static ClContext cl_context_ref;
/**
* @brief Rotary Embedding process
* @param[in] max_timestep max timestep
* @param[in] in_size size of input
* @param[in] out_size size of output
- * @param[in] context RunLayerContext reference
*/
void rotary_emb_cl(float *in, float *out,
std::vector<std::vector<float>> freqs_cos,
unsigned int batch, unsigned int channel,
unsigned int height, unsigned int width, unsigned int dim,
unsigned int from, unsigned int max_timestamp,
- unsigned int in_size, unsigned int out_size,
- RunLayerContext &context);
+ unsigned int in_size, unsigned int out_size);
#ifdef ENABLE_FP16
-/**
- * @brief declaring global fp16 kernel objects
- */
-extern opencl::Kernel kernel_rotary_emb_fp16;
/**
* @brief Rotary Embedding process
* @param[in] max_timestep max timestep
* @param[in] in_size size of input
* @param[in] out_size size of output
- * @param[in] context RunLayerContext reference
*/
void rotary_emb_cl(__fp16 *in, __fp16 *out,
std::vector<std::vector<float>> freqs_cos,
unsigned int batch, unsigned int channel,
unsigned int height, unsigned int width, unsigned int dim,
unsigned int from, unsigned int max_timestamp,
- unsigned int in_size, unsigned int out_size,
- RunLayerContext &context);
+ unsigned int in_size, unsigned int out_size);
#endif
*
*/
+#include <attention_kernel_strings.h>
#include <attention_kernels.h>
namespace nntrainer {
-std::string rotary_emb_cl_kernel_fp16 = R"(
- #pragma OPENCL EXTENSION cl_khr_fp16 : enable
-__kernel void rotary_emb_cl_fp16(__global half *input,
- __global half *output,
- __global float *freqs_cos,
- __global float *freqs_sin,
- __global float *cos_,
- __global float *sin_,
- unsigned int batch,
- unsigned int channel,
- unsigned int height,
- unsigned int width,
- unsigned int dim,
- unsigned int half_,
- unsigned int max_timestep,
- unsigned int from) {
- __global float *cos_ptr = cos_;
- __global float *sin_ptr = sin_;
-
- float value = 0.0f;
- float transformed_value = 0.0f;
-
- unsigned int b = get_global_id(0);
- unsigned int c = get_global_id(1);
-
- if(b < batch && c < channel){
- for (unsigned int h = 0; h < height; h++) {
- if (from + h < max_timestep) {
- unsigned idx = (from + h)*dim;
- for(int i = idx; i < idx + dim; i++ ){
- cos_ptr[i - idx] = freqs_cos[i];
- sin_ptr[i - idx] = freqs_sin[i];
- }
- }
-
- for (unsigned int w = 0; w < width; w = w + dim) {
- for (unsigned int k = 0; k < dim; k++) {
- unsigned int span = w + k;
- value = (float)input[b * channel * height * width + c * height * width + h * width + span];
- if (k < half_) {
- transformed_value = -1.0f * (float)input[b * channel * height * width + c * height * width + h * width + span + half_];
- } else {
- transformed_value = (float)input[b * channel * height * width + c * height * width + h * width + span - half_];
- }
- value = value * cos_ptr[k] + transformed_value * sin_ptr[k];
- output[b * channel * height * width + c * height * width + h * width + span] = (half)value;
- }
- }
- }
- }
-}
-)";
-
-/**
- * @brief defining global kernel objects
- */
-opencl::Kernel kernel_rotary_emb_fp16;
void rotary_emb_cl(__fp16 *in, __fp16 *out,
std::vector<std::vector<float>> freqs_cos,
unsigned int batch, unsigned int channel,
unsigned int height, unsigned int width, unsigned int dim,
unsigned int from, unsigned int max_timestep,
- unsigned int in_size, unsigned int out_size,
- RunLayerContext &context) {
+ unsigned int in_size, unsigned int out_size) {
bool result = false;
do {
- result = context.clCreateKernel(rotary_emb_cl_kernel_fp16,
- context.LayerKernel::ROTARY_EMB_FP16,
- kernel_rotary_emb_fp16);
- if (!result) {
- printf("Failed to create kernel for rotary_emb_cl\n");
+ ClContext::SharedPtrClKernel kernel_rotaryEmb_fp16_ptr =
+ cl_context_ref.registerClKernel(rotary_emb_cl_kernel_fp16_,
+ "rotary_emb_cl_fp16");
+ if (!kernel_rotaryEmb_fp16_ptr) {
break;
}
size_t dim5_size = sizeof(float) * freqs_cos_dim * dim;
size_t dim6_size = sizeof(float) * freqs_sin_dim * dim;
- opencl::Buffer inputA(context.context_inst_, dim1_size, true, nullptr);
+ opencl::Buffer inputA(cl_context_ref.context_inst_, dim1_size, true,
+ nullptr);
- opencl::Buffer inOutRes(context.context_inst_, dim2_size, true, nullptr);
+ opencl::Buffer inOutRes(cl_context_ref.context_inst_, dim2_size, true,
+ nullptr);
- opencl::Buffer cosBuf(context.context_inst_, dim3_size, true, nullptr);
+ opencl::Buffer cosBuf(cl_context_ref.context_inst_, dim3_size, true,
+ nullptr);
- opencl::Buffer sinBuf(context.context_inst_, dim4_size, true, nullptr);
+ opencl::Buffer sinBuf(cl_context_ref.context_inst_, dim4_size, true,
+ nullptr);
- opencl::Buffer freqs_cosBuf(context.context_inst_, dim5_size, true,
+ opencl::Buffer freqs_cosBuf(cl_context_ref.context_inst_, dim5_size, true,
nullptr);
- opencl::Buffer freqs_sinBuf(context.context_inst_, dim6_size, true,
+ opencl::Buffer freqs_sinBuf(cl_context_ref.context_inst_, dim6_size, true,
nullptr);
std::vector<float> freqs_cos_flat;
freqs_sin_flat.insert(freqs_sin_flat.end(), row.begin(), row.end());
}
- result = inputA.WriteData(context.command_queue_inst_, in);
+ result = inputA.WriteData(cl_context_ref.command_queue_inst_, in);
if (!result) {
printf("Failed to write input data\n");
break;
}
- result = inOutRes.WriteData(context.command_queue_inst_, out);
+ result = inOutRes.WriteData(cl_context_ref.command_queue_inst_, out);
if (!result) {
printf("Failed to write output data\n");
break;
}
- result = freqs_cosBuf.WriteData(context.command_queue_inst_,
+ result = freqs_cosBuf.WriteData(cl_context_ref.command_queue_inst_,
freqs_cos_flat.data());
if (!result) {
printf("Failed to write freqs cos data\n");
break;
}
- result = freqs_sinBuf.WriteData(context.command_queue_inst_,
+ result = freqs_sinBuf.WriteData(cl_context_ref.command_queue_inst_,
freqs_sin_flat.data());
if (!result) {
printf("Failed to write freqs sin data\n");
break;
}
- result = cosBuf.WriteData(context.command_queue_inst_, cos_.data());
+ result = cosBuf.WriteData(cl_context_ref.command_queue_inst_, cos_.data());
if (!result) {
printf("Failed to write cos data\n");
break;
}
- result = sinBuf.WriteData(context.command_queue_inst_, sin_.data());
+ result = sinBuf.WriteData(cl_context_ref.command_queue_inst_, sin_.data());
if (!result) {
printf("Failed to write sin data\n");
break;
}
result =
- kernel_rotary_emb_fp16.SetKernelArguments(0, &inputA, sizeof(cl_mem));
+ kernel_rotaryEmb_fp16_ptr->SetKernelArguments(0, &inputA, sizeof(cl_mem));
if (!result) {
printf("Failed to set inputA argument\n");
break;
}
- result =
- kernel_rotary_emb_fp16.SetKernelArguments(1, &inOutRes, sizeof(cl_mem));
+ result = kernel_rotaryEmb_fp16_ptr->SetKernelArguments(1, &inOutRes,
+ sizeof(cl_mem));
if (!result) {
printf("Failed to set inOutRes argument\n");
break;
}
- result = kernel_rotary_emb_fp16.SetKernelArguments(2, &freqs_cosBuf,
- sizeof(cl_mem));
+ result = kernel_rotaryEmb_fp16_ptr->SetKernelArguments(2, &freqs_cosBuf,
+ sizeof(cl_mem));
if (!result) {
printf("Failed to set freqs_cosBuf argument\n");
break;
}
- result = kernel_rotary_emb_fp16.SetKernelArguments(3, &freqs_sinBuf,
- sizeof(cl_mem));
+ result = kernel_rotaryEmb_fp16_ptr->SetKernelArguments(3, &freqs_sinBuf,
+ sizeof(cl_mem));
if (!result) {
printf("Failed to set freqs_sinBuf argument\n");
break;
}
result =
- kernel_rotary_emb_fp16.SetKernelArguments(4, &cosBuf, sizeof(cl_mem));
+ kernel_rotaryEmb_fp16_ptr->SetKernelArguments(4, &cosBuf, sizeof(cl_mem));
if (!result) {
printf("Failed to set cosBuf argument\n");
break;
}
result =
- kernel_rotary_emb_fp16.SetKernelArguments(5, &sinBuf, sizeof(cl_mem));
+ kernel_rotaryEmb_fp16_ptr->SetKernelArguments(5, &sinBuf, sizeof(cl_mem));
if (!result) {
printf("Failed to set sinBuf argument\n");
break;
}
- result = kernel_rotary_emb_fp16.SetKernelArguments(6, &batch, sizeof(int));
+ result =
+ kernel_rotaryEmb_fp16_ptr->SetKernelArguments(6, &batch, sizeof(int));
if (!result) {
printf("Failed to set batch argument\n");
break;
}
result =
- kernel_rotary_emb_fp16.SetKernelArguments(7, &channel, sizeof(int));
+ kernel_rotaryEmb_fp16_ptr->SetKernelArguments(7, &channel, sizeof(int));
if (!result) {
printf("Failed to set channel argument\n");
break;
}
- result = kernel_rotary_emb_fp16.SetKernelArguments(8, &height, sizeof(int));
+ result =
+ kernel_rotaryEmb_fp16_ptr->SetKernelArguments(8, &height, sizeof(int));
if (!result) {
printf("Failed to set height argument\n");
break;
}
- result = kernel_rotary_emb_fp16.SetKernelArguments(9, &width, sizeof(int));
+ result =
+ kernel_rotaryEmb_fp16_ptr->SetKernelArguments(9, &width, sizeof(int));
if (!result) {
printf("Failed to set width argument\n");
break;
}
- result = kernel_rotary_emb_fp16.SetKernelArguments(10, &dim, sizeof(int));
+ result =
+ kernel_rotaryEmb_fp16_ptr->SetKernelArguments(10, &dim, sizeof(int));
if (!result) {
printf("Failed to set dim argument\n");
break;
}
unsigned int half_ = dim / 2;
- result = kernel_rotary_emb_fp16.SetKernelArguments(11, &half_, sizeof(int));
+ result =
+ kernel_rotaryEmb_fp16_ptr->SetKernelArguments(11, &half_, sizeof(int));
if (!result) {
printf("Failed to set half argument\n");
break;
}
- result =
- kernel_rotary_emb_fp16.SetKernelArguments(12, &max_timestep, sizeof(int));
+ result = kernel_rotaryEmb_fp16_ptr->SetKernelArguments(12, &max_timestep,
+ sizeof(int));
if (!result) {
printf("Failed to set timestamp argument\n");
break;
}
- result = kernel_rotary_emb_fp16.SetKernelArguments(13, &from, sizeof(int));
+ result =
+ kernel_rotaryEmb_fp16_ptr->SetKernelArguments(13, &from, sizeof(int));
if (!result) {
printf("Failed to set from argument\n");
break;
const int work_groups_count[3] = {(int)batch, (int)channel, 1};
const int work_group_size[3] = {32, 32, 1}; // test-value
- result = context.command_queue_inst_.DispatchCommand(
- kernel_rotary_emb_fp16, work_groups_count, work_group_size);
+ result = cl_context_ref.command_queue_inst_.DispatchCommand(
+ kernel_rotaryEmb_fp16_ptr, work_groups_count, work_group_size);
if (!result) {
printf("Failed to dispatch command\n");
break;
}
- result = inOutRes.ReadData(context.command_queue_inst_, out);
+ result = inOutRes.ReadData(cl_context_ref.command_queue_inst_, out);
if (!result) {
printf("Failed to read data\n");
break;
'blas_kernel_interface.h',
'blas_kernel_strings.h',
'attention_kernel_interface.h',
+ 'attention_kernel_strings.h',
]
if get_option('enable-fp16')
using namespace nntrainer;
-static RunLayerContext setUpGpuContext() {
-
+static void setUpGpuContext() {
auto &ac = nntrainer::ClContext::Global();
- auto rc = RunLayerContext();
-
- return rc;
+ ac.initAttentionClKernels();
}
TEST(attention_kernels, rotary_emb_kernel_FP32) {
- RunLayerContext rc = setUpGpuContext();
+ setUpGpuContext();
int batch = 1;
int channel = 1;
B_fp32.copy(A_fp32);
- apply_rotary_emb_cl(A_fp32, dim, from, max_timestep, rc);
+ apply_rotary_emb_cl(A_fp32, dim, from, max_timestep);
apply_rotary_emb_tensor(B_fp32, dim, from, max_timestep);
float mseErrorNeon_fp32 =
}
TEST(attention_kernels, rotary_emb_kernel_FP32_case2) {
- RunLayerContext rc = setUpGpuContext();
+ setUpGpuContext();
int batch = 4;
int channel = 4;
B_fp32.copy(A_fp32);
- apply_rotary_emb_cl(A_fp32, dim, from, max_timestep, rc);
+ apply_rotary_emb_cl(A_fp32, dim, from, max_timestep);
apply_rotary_emb_tensor(B_fp32, dim, from, max_timestep);
float mseErrorNeon_fp32 =
}
TEST(attention_kernels, rotary_emb_kernel_FP16) {
- RunLayerContext rc = setUpGpuContext();
+ setUpGpuContext();
int batch = 1;
int channel = 1;
B_fp16.copy(A_fp16);
- apply_rotary_emb_cl(A_fp16, dim, from, max_timestep, rc);
+ apply_rotary_emb_cl(A_fp16, dim, from, max_timestep);
apply_rotary_emb_tensor(B_fp16, dim, from, max_timestep);
float mseErrorNeon_fp16 = mse<__fp16>(
}
TEST(attention_kernels, rotary_emb_kernel_FP16_case2) {
- RunLayerContext rc = setUpGpuContext();
+ setUpGpuContext();
int batch = 4;
int channel = 4;
B_fp16.copy(A_fp16);
- apply_rotary_emb_cl(A_fp16, dim, from, max_timestep, rc);
+ apply_rotary_emb_cl(A_fp16, dim, from, max_timestep);
apply_rotary_emb_tensor(B_fp16, dim, from, max_timestep);
float mseErrorNeon_fp16 = mse<__fp16>(