half rms_norm = sqrt(sum_squares + epsilon);
// Each work item processes all width elements for its specific n, h, c
for (int w = 0; w < W; ++w) {
-    output[index+w] = (input[index+w] / rms_norm) * alpha[index+w];
+    // gamma (alpha) holds one scale per width element; indexing it with
+    // index+w read past the buffer, so index by w alone
+    output[index+w] = (input[index+w] / rms_norm) * alpha[w];
  }
}
)";
float rms_norm = sqrt(sum_squares + epsilon);
// Each work item processes all width elements for its specific n, h, c
for (int w = 0; w < W; ++w) {
-    output[index+w] = (input[index+w] / rms_norm) * alpha[index+w];
+    // same out-of-bounds fix as the FP16 kernel: gamma is indexed by w only
+    output[index+w] = (input[index+w] / rms_norm) * alpha[w];
}
}
)";
Tensor &gamma = context.getWeight(wt_idx[RMSParams::gamma]);
auto &epsilon = std::get<props::Epsilon>(rmsnorm_props).get();
if (in.getDataType() == ml::train::TensorDim::DataType::FP32) {
- rmsnormProcess(in, out, gamma, epsilon, context);
+ rmsnormProcess(in, out, gamma, epsilon);
} else {
- rmsnormProcess_fp16(in, out, gamma, epsilon, context);
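+    // the FP16 path only exists when nntrainer is built with fp16 enabled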
+#ifdef ENABLE_FP16
+ rmsnormProcess_fp16(in, out, gamma, epsilon);
+#else
+ throw std::invalid_argument("Error: enable-fp16 is not enabled");
+#endif
}
}
-opencl::Kernel RMSNormLayerCl::kernel_rmsnorm_fp16;
void RMSNormLayerCl::rmsnormProcess(Tensor const &input, Tensor &result,
- Tensor const &gamma, const float epsilon,
- RunLayerContext &context) {
+ Tensor const &gamma, const float epsilon) {
bool ret = false;
int dim1 = input.batch() * input.height() * input.width() * input.channel();
  CREATE_IF_EMPTY_DIMS(result, input.batch(), input.channel(), input.height(),
                       input.width(), input.getTensorType());
  int b = input.batch();
  int c = input.channel();
int h = input.height();
int w = input.width();
+
do {
- ret =
- context.clCreateKernel(rmsnorm_cl_kernel_, context.LayerKernel::RMSNORM,
- RMSNormLayerCl::kernel_rmsnorm);
- if (!ret) {
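+    // registerClKernel compiles the program on first use and hands back the
+    // cached kernel on later calls; a null pointer means the build failed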
+ ClContext::SharedPtrClKernel kernel_rmsnorm_ptr =
+ cl_context_ref.registerClKernel(rmsnorm_cl_kernel_, "rmsnorm_cl");
+ if (!kernel_rmsnorm_ptr) {
break;
}
- opencl::Buffer inputbuf(context.context_inst_, dim1 * sizeof(float), true,
- nullptr);
+ opencl::Buffer inputbuf(cl_context_ref.context_inst_, dim1 * sizeof(float),
+ true, nullptr);
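+    // gamma has one value per width element, so its buffer is width-sized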
- opencl::Buffer gammabuf(context.context_inst_,
+ opencl::Buffer gammabuf(cl_context_ref.context_inst_,
input.width() * sizeof(float), true, nullptr);
- opencl::Buffer resultbuf(context.context_inst_, dim1 * sizeof(float), true,
- nullptr);
+ opencl::Buffer resultbuf(cl_context_ref.context_inst_, dim1 * sizeof(float),
+ true, nullptr);
const float *data = input.getData();
float *rdata = result.getData();
const float *gdata = gamma.getData();
- ret = inputbuf.WriteData(context.command_queue_inst_, data);
+ ret = inputbuf.WriteData(cl_context_ref.command_queue_inst_, data);
if (!ret) {
break;
}
- ret = gammabuf.WriteData(context.command_queue_inst_, gdata);
+ ret = gammabuf.WriteData(cl_context_ref.command_queue_inst_, gdata);
if (!ret) {
break;
}
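+    // kernel arguments: 0=input, 1=output, 2=gamma, 3=epsilon,
+    // 4=batch, 5=channel, 6=height, 7=width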
- ret = RMSNormLayerCl::kernel_rmsnorm.SetKernelArguments(0, &inputbuf,
- sizeof(cl_mem));
+ ret = kernel_rmsnorm_ptr->SetKernelArguments(0, &inputbuf, sizeof(cl_mem));
if (!ret) {
break;
}
- ret = RMSNormLayerCl::kernel_rmsnorm.SetKernelArguments(1, &resultbuf,
- sizeof(cl_mem));
+ ret = kernel_rmsnorm_ptr->SetKernelArguments(1, &resultbuf, sizeof(cl_mem));
if (!ret) {
break;
}
- ret = RMSNormLayerCl::kernel_rmsnorm.SetKernelArguments(2, &gammabuf,
- sizeof(cl_mem));
+ ret = kernel_rmsnorm_ptr->SetKernelArguments(2, &gammabuf, sizeof(cl_mem));
if (!ret) {
break;
}
- ret = RMSNormLayerCl::kernel_rmsnorm.SetKernelArguments(4, &b, sizeof(int));
+ ret = kernel_rmsnorm_ptr->SetKernelArguments(4, &b, sizeof(int));
if (!ret) {
break;
}
- ret = RMSNormLayerCl::kernel_rmsnorm.SetKernelArguments(3, &epsilon,
- sizeof(float));
+ ret = kernel_rmsnorm_ptr->SetKernelArguments(3, &epsilon, sizeof(float));
if (!ret) {
break;
}
- ret = RMSNormLayerCl::kernel_rmsnorm.SetKernelArguments(5, &c, sizeof(int));
+ ret = kernel_rmsnorm_ptr->SetKernelArguments(5, &c, sizeof(int));
if (!ret) {
break;
}
- ret = RMSNormLayerCl::kernel_rmsnorm.SetKernelArguments(6, &h, sizeof(int));
+ ret = kernel_rmsnorm_ptr->SetKernelArguments(6, &h, sizeof(int));
if (!ret) {
break;
}
- ret = RMSNormLayerCl::kernel_rmsnorm.SetKernelArguments(7, &w, sizeof(int));
+ ret = kernel_rmsnorm_ptr->SetKernelArguments(7, &w, sizeof(int));
if (!ret) {
break;
}
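+    // launch one work item per (batch*channel, height) pair; each item
+    // loops over the width dimension inside the kernel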
const int work_groups_count[3] = {b * c, h, 1};
const int work_group_size[3] = {32, 32, 1}; // test-value
- ret = context.command_queue_inst_.DispatchCommand(
- RMSNormLayerCl::kernel_rmsnorm, work_groups_count, work_group_size);
+ ret = cl_context_ref.command_queue_inst_.DispatchCommand(
+ kernel_rmsnorm_ptr, work_groups_count, work_group_size);
if (!ret) {
break;
}
- ret = resultbuf.ReadData(context.command_queue_inst_, rdata);
+ ret = resultbuf.ReadData(cl_context_ref.command_queue_inst_, rdata);
if (!ret) {
break;
}
void RMSNormLayerCl::rmsnormProcess_fp16(Tensor const &input, Tensor &result,
Tensor const &gamma,
- const float epsilon,
- RunLayerContext &context) {
+ const float epsilon) {
bool ret = false;
  int dim1 = input.batch() * input.height() * input.width() * input.channel();
  CREATE_IF_EMPTY_DIMS(result, input.batch(), input.channel(), input.height(),
                       input.width(), input.getTensorType());
  int b = input.batch();
  int c = input.channel();
  int h = input.height();
  int w = input.width();
do {
- ret = context.clCreateKernel(rmsnorm_cl_kernel_fp16_,
- context.LayerKernel::RMSNORM_FP16,
- RMSNormLayerCl::kernel_rmsnorm_fp16);
- if (!ret) {
+ ClContext::SharedPtrClKernel kernel_rmsnorm_ptr =
+ cl_context_ref.registerClKernel(rmsnorm_cl_kernel_fp16_,
+ "rmsnorm_cl_fp16");
+ if (!kernel_rmsnorm_ptr) {
break;
}
- opencl::Buffer inputbuf(context.context_inst_, dim1 * sizeof(cl_half), true,
- nullptr);
+ opencl::Buffer inputbuf(cl_context_ref.context_inst_,
+ dim1 * sizeof(cl_half), true, nullptr);
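+    // cl_half mirrors the 2-byte device-side half type used by this kernel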
- opencl::Buffer gammabuf(context.context_inst_,
+ opencl::Buffer gammabuf(cl_context_ref.context_inst_,
input.width() * sizeof(cl_half), true, nullptr);
- opencl::Buffer resultbuf(context.context_inst_, dim1 * sizeof(cl_half),
- true, nullptr);
+ opencl::Buffer resultbuf(cl_context_ref.context_inst_,
+ dim1 * sizeof(cl_half), true, nullptr);
const __fp16 *data = input.getData<__fp16>();
__fp16 *rdata = result.getData<__fp16>();
const __fp16 *gdata = gamma.getData<__fp16>();
- ret = inputbuf.WriteData(context.command_queue_inst_, data);
+ ret = inputbuf.WriteData(cl_context_ref.command_queue_inst_, data);
if (!ret) {
break;
}
- ret = gammabuf.WriteData(context.command_queue_inst_, gdata);
+ ret = gammabuf.WriteData(cl_context_ref.command_queue_inst_, gdata);
if (!ret) {
break;
}
- ret = RMSNormLayerCl::kernel_rmsnorm_fp16.SetKernelArguments(
- 0, &inputbuf, sizeof(cl_mem));
+ ret = kernel_rmsnorm_ptr->SetKernelArguments(0, &inputbuf, sizeof(cl_mem));
if (!ret) {
break;
}
- ret = RMSNormLayerCl::kernel_rmsnorm_fp16.SetKernelArguments(
- 1, &resultbuf, sizeof(cl_mem));
+ ret = kernel_rmsnorm_ptr->SetKernelArguments(1, &resultbuf, sizeof(cl_mem));
if (!ret) {
break;
}
- ret = RMSNormLayerCl::kernel_rmsnorm_fp16.SetKernelArguments(
- 2, &gammabuf, sizeof(cl_mem));
+ ret = kernel_rmsnorm_ptr->SetKernelArguments(2, &gammabuf, sizeof(cl_mem));
if (!ret) {
break;
}
- ret = RMSNormLayerCl::kernel_rmsnorm_fp16.SetKernelArguments(4, &b,
- sizeof(int));
+ ret = kernel_rmsnorm_ptr->SetKernelArguments(4, &b, sizeof(int));
if (!ret) {
break;
}
-    ret = RMSNormLayerCl::kernel_rmsnorm_fp16.SetKernelArguments(
-      3, &epsilon, sizeof(cl_half));
+    // the kernel takes epsilon as half (hence sizeof(cl_half)), but epsilon
+    // here is a float; convert explicitly instead of passing two raw bytes
+    // of the float's representation
+    __fp16 epsilon_half = epsilon;
+    ret = kernel_rmsnorm_ptr->SetKernelArguments(3, &epsilon_half,
+                                                 sizeof(cl_half));
if (!ret) {
break;
}
- ret = RMSNormLayerCl::kernel_rmsnorm_fp16.SetKernelArguments(5, &c,
- sizeof(int));
+ ret = kernel_rmsnorm_ptr->SetKernelArguments(5, &c, sizeof(int));
if (!ret) {
break;
}
- ret = RMSNormLayerCl::kernel_rmsnorm_fp16.SetKernelArguments(6, &h,
- sizeof(int));
+ ret = kernel_rmsnorm_ptr->SetKernelArguments(6, &h, sizeof(int));
if (!ret) {
break;
}
- ret = RMSNormLayerCl::kernel_rmsnorm_fp16.SetKernelArguments(7, &w,
- sizeof(int));
+ ret = kernel_rmsnorm_ptr->SetKernelArguments(7, &w, sizeof(int));
if (!ret) {
break;
}
const int work_groups_count[3] = {b * c, h, 1};
const int work_group_size[3] = {32, 32, 1}; // test-value
- ret = context.command_queue_inst_.DispatchCommand(
- RMSNormLayerCl::kernel_rmsnorm_fp16, work_groups_count, work_group_size);
+ ret = cl_context_ref.command_queue_inst_.DispatchCommand(
+ kernel_rmsnorm_ptr, work_groups_count, work_group_size);
if (!ret) {
break;
}
- ret = resultbuf.ReadData(context.command_queue_inst_, rdata);
+ ret = resultbuf.ReadData(cl_context_ref.command_queue_inst_, rdata);
if (!ret) {
break;
}
auto &epsilon = std::get<props::Epsilon>(rmsnorm_props).get();
if (in_step.getDataType() == ml::train::TensorDim::DataType::FP32) {
-    rmsnormProcess(in, out, gamma, epsilon, context);
+    // operate on the step views so only rows [from, to) are normalized
+    rmsnormProcess(in_step, out_step, gamma, epsilon);
  } else {
-    rmsnormProcess_fp16(in, out, gamma, epsilon, context);
+#ifdef ENABLE_FP16
+    rmsnormProcess_fp16(in_step, out_step, gamma, epsilon);
+#else
+    throw std::invalid_argument("Error: enable-fp16 is not enabled");
+#endif
}
}