[GPU/OPENCL] RMSNorm Accuracy Fix
The alpha values were not picked correctly.

Signed-off-by: Thummala Pallavi <[email protected]>
pallaviNNT committed Aug 12, 2024
1 parent 8877d6a commit d31a28c
Showing 1 changed file with 11 additions and 13 deletions.
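
For context, RMSNorm normalizes each row by its root mean square and then scales every element by a learnable alpha (the gamma weight). Below is a rough plain-C++ reference of that per-row computation, assuming one alpha value per width position as the updated kernel comments state; it is illustrative only, not nntrainer's implementation.

#include <cmath>
#include <cstddef>
#include <vector>

// Illustrative per-width RMSNorm over one row of width W:
//   y[w] = x[w] / sqrt(mean(x^2) + eps) * alpha[w]
// alpha is assumed to hold one learnable scale per width position.
std::vector<float> rmsnorm_reference(const std::vector<float> &x,
                                     const std::vector<float> &alpha,
                                     float eps) {
  float mean_sq = 0.0f;
  for (float v : x)
    mean_sq += v * v;
  mean_sq /= static_cast<float>(x.size());
  const float inv_rms = 1.0f / std::sqrt(mean_sq + eps);
  std::vector<float> y(x.size());
  for (std::size_t w = 0; w < x.size(); ++w)
    y[w] = x[w] * inv_rms * alpha[w]; // scale picked per width position, not per channel
  return y;
}
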
24 changes: 11 additions & 13 deletions nntrainer/layers/cl_layers/rmsnorm_layer_cl.cpp
@@ -26,7 +26,7 @@ std::string rmsnorm_cl_kernel_fp16_ =
__kernel void rmsnorm_cl_fp16(
__global const half *input, // Input tensor
__global half *output, // Output tensor
- __global const half *alpha, // Alpha values (one for each channel)
+ __global const half *alpha, // Alpha values (one for each width)
half epsilon,
int B, // Number of batches
int C, // Number of channels
@@ -50,7 +50,7 @@ std::string rmsnorm_cl_kernel_fp16_ =
half rms_norm = sqrt(sum_squares + epsilon);
// Each work item processes all width elements for its specific n, h, c
for (int w = 0; w < W; ++w) {
- output[index+w] = (input[index+w] / rms_norm) * alpha[c];
+ output[index+w] = (input[index+w] / rms_norm) * alpha[index+w];
}
}
)";
@@ -59,7 +59,7 @@ std::string rmsnorm_cl_kernel_ =
R"(__kernel void rmsnorm_cl(
__global const float *input, // Input tensor
__global float *output, // Output tensor
- __global const float *alpha, // Alpha values (one for each channel)
+ __global const float *alpha, // Alpha values (one for each width)
float epsilon,
int B, // Number of batches
int C, // Number of channels
@@ -80,7 +80,7 @@ std::string rmsnorm_cl_kernel_ =
float rms_norm = sqrt(sum_squares + epsilon);
// Each work item processes all width elements for its specific n, h, c
for (int w = 0; w < W; ++w) {
- output[index+w] = (input[index+w] / rms_norm) * alpha[c];
+ output[index+w] = (input[index+w] / rms_norm) * alpha[index+w];
}
}
)";
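
As the comments in both kernels note, each work item handles all W width elements for its (b, c, h) position. Below is a hypothetical CPU analogue of that per-row loop, purely for illustration: the flat alpha indexing simply mirrors the updated kernel, and the mean-of-squares step is assumed to follow the usual RMSNorm definition.

#include <cmath>
#include <cstddef>

// Hypothetical CPU analogue of one rmsnorm_cl work item; `base` is the flat
// offset of the (b, c, h) row and W is the width. Not the actual kernel.
void rmsnorm_row(const float *input, float *output, const float *alpha,
                 float epsilon, std::size_t base, int W) {
  float sum_squares = 0.0f;
  for (int w = 0; w < W; ++w) {
    const float v = input[base + w];
    sum_squares += v * v;
  }
  sum_squares /= static_cast<float>(W); // assumed mean of squares over the width
  const float rms_norm = std::sqrt(sum_squares + epsilon);
  for (int w = 0; w < W; ++w) {
    // alpha indexed at the element's flat offset, mirroring the fixed kernel
    output[base + w] = (input[base + w] / rms_norm) * alpha[base + w];
  }
}
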
@@ -114,7 +114,7 @@ void RMSNormLayerCl::forwarding(RunLayerContext &context, bool training) {
auto &epsilon = std::get<props::Epsilon>(rmsnorm_props).get();
if (in.getDataType() == ml::train::TensorDim::DataType::FP32) {
rmsnormProcess(in, out, gamma, epsilon, context);
- } else{
+ } else {
rmsnormProcess_fp16(in, out, gamma, epsilon, context);
}
}
@@ -276,14 +276,14 @@ void RMSNormLayerCl::rmsnormProcess_fp16(Tensor const &input, Tensor &result,
if (!ret) {
break;
}
- ret = RMSNormLayerCl::kernel_rmsnorm_fp16.SetKernelArguments(
- 4, &b, sizeof(int));
+ ret = RMSNormLayerCl::kernel_rmsnorm_fp16.SetKernelArguments(4, &b,
+ sizeof(int));
if (!ret) {
break;
}

- ret = RMSNormLayerCl::kernel_rmsnorm_fp16.SetKernelArguments(3, &epsilon,
- sizeof(cl_half));
+ ret = RMSNormLayerCl::kernel_rmsnorm_fp16.SetKernelArguments(
+ 3, &epsilon, sizeof(cl_half));
if (!ret) {
break;
}
@@ -317,12 +317,11 @@ void RMSNormLayerCl::rmsnormProcess_fp16(Tensor const &input, Tensor &result,
break;
}
} while (false);
-
}

void RMSNormLayerCl::incremental_forwarding(nntrainer::RunLayerContext &context,
- unsigned int from, unsigned int to,
- bool training) {
+ unsigned int from, unsigned int to,
+ bool training) {
Tensor &in = context.getInput(SINGLE_INOUT_IDX);
Tensor &out = context.getOutput(SINGLE_INOUT_IDX);
Tensor &gamma = context.getWeight(wt_idx[RMSParams::gamma]);
@@ -374,4 +373,3 @@ void RMSNormLayerCl::setProperty(const std::vector<std::string> &values) {
}

} // namespace nntrainer
-