__kernel void rmsnorm_cl_fp16(
__global const half *input, // Input tensor
__global half *output, // Output tensor
- __global const half *alpha, // Alpha values (one for each channel)
+ __global const half *alpha, // Alpha values (one for each width)
half epsilon,
int B, // Number of batches
int C, // Number of channels
half rms_norm = sqrt(sum_squares + epsilon);
// Each work item processes all width elements for its specific n, h, c
for (int w = 0; w < W; ++w) {
- output[index+w] = (input[index+w] / rms_norm) * alpha[c];
+ output[index+w] = (input[index+w] / rms_norm) * alpha[index+w];
}
}
)";
R"(__kernel void rmsnorm_cl(
__global const float *input, // Input tensor
__global float *output, // Output tensor
- __global const float *alpha, // Alpha values (one for each channel)
+ __global const float *alpha, // Alpha values (one for each width)
float epsilon,
int B, // Number of batches
int C, // Number of channels
float rms_norm = sqrt(sum_squares + epsilon);
// Each work item processes all width elements for its specific n, h, c
for (int w = 0; w < W; ++w) {
- output[index+w] = (input[index+w] / rms_norm) * alpha[c];
+ output[index+w] = (input[index+w] / rms_norm) * alpha[index+w];
}
}
)";
auto &epsilon = std::get<props::Epsilon>(rmsnorm_props).get();
if (in.getDataType() == ml::train::TensorDim::DataType::FP32) {
rmsnormProcess(in, out, gamma, epsilon, context);
- } else{
+ } else {
rmsnormProcess_fp16(in, out, gamma, epsilon, context);
}
}
if (!ret) {
break;
}
- ret = RMSNormLayerCl::kernel_rmsnorm_fp16.SetKernelArguments(
- 4, &b, sizeof(int));
+ ret = RMSNormLayerCl::kernel_rmsnorm_fp16.SetKernelArguments(4, &b,
+ sizeof(int));
if (!ret) {
break;
}
- ret = RMSNormLayerCl::kernel_rmsnorm_fp16.SetKernelArguments(3, &epsilon,
- sizeof(cl_half));
+ ret = RMSNormLayerCl::kernel_rmsnorm_fp16.SetKernelArguments(
+ 3, &epsilon, sizeof(cl_half));
if (!ret) {
break;
}
break;
}
} while (false);
-
}
void RMSNormLayerCl::incremental_forwarding(nntrainer::RunLayerContext &context,
- unsigned int from, unsigned int to,
- bool training) {
+ unsigned int from, unsigned int to,
+ bool training) {
Tensor &in = context.getInput(SINGLE_INOUT_IDX);
Tensor &out = context.getOutput(SINGLE_INOUT_IDX);
Tensor &gamma = context.getWeight(wt_idx[RMSParams::gamma]);
}
} // namespace nntrainer
-