From d31a28c3bb35bd37d64719f7d4ce5aa601ad5194 Mon Sep 17 00:00:00 2001 From: Thummala Pallavi Date: Fri, 2 Aug 2024 12:11:34 +0530 Subject: [PATCH] [GPU/OPENCL] RMSNorm Accuracy Fix The alpha values were not picked correctly: alpha holds one value per width element, so the kernels now index it with the width position w instead of the channel index c. Signed-off-by: Thummala Pallavi --- .../layers/cl_layers/rmsnorm_layer_cl.cpp | 24 +++++++++---------- 1 file changed, 11 insertions(+), 13 deletions(-) diff --git a/nntrainer/layers/cl_layers/rmsnorm_layer_cl.cpp b/nntrainer/layers/cl_layers/rmsnorm_layer_cl.cpp index 0dd1f15c85..a0baa1e988 100644 --- a/nntrainer/layers/cl_layers/rmsnorm_layer_cl.cpp +++ b/nntrainer/layers/cl_layers/rmsnorm_layer_cl.cpp @@ -26,7 +26,7 @@ std::string rmsnorm_cl_kernel_fp16_ = __kernel void rmsnorm_cl_fp16( __global const half *input, // Input tensor __global half *output, // Output tensor - __global const half *alpha, // Alpha values (one for each channel) + __global const half *alpha, // Alpha values (one for each width) half epsilon, int B, // Number of batches int C, // Number of channels @@ -50,7 +50,7 @@ std::string rmsnorm_cl_kernel_fp16_ = half rms_norm = sqrt(sum_squares + epsilon); // Each work item processes all width elements for its specific n, h, c for (int w = 0; w < W; ++w) { - output[index+w] = (input[index+w] / rms_norm) * alpha[c]; + output[index+w] = (input[index+w] / rms_norm) * alpha[w]; } } )"; @@ -59,7 +59,7 @@ std::string rmsnorm_cl_kernel_ = R"(__kernel void rmsnorm_cl( __global const float *input, // Input tensor __global float *output, // Output tensor - __global const float *alpha, // Alpha values (one for each channel) + __global const float *alpha, // Alpha values (one for each width) float epsilon, int B, // Number of batches int C, // Number of channels @@ -80,7 +80,7 @@ std::string rmsnorm_cl_kernel_ = float rms_norm = sqrt(sum_squares + epsilon); // Each work item processes all width elements for its specific n, h, c for (int w = 0; w < W; ++w) { - output[index+w] = (input[index+w] / rms_norm) * alpha[c]; + 
output[index+w] = (input[index+w] / rms_norm) * alpha[w]; } } )"; @@ -114,7 +114,7 @@ void RMSNormLayerCl::forwarding(RunLayerContext &context, bool training) { auto &epsilon = std::get(rmsnorm_props).get(); if (in.getDataType() == ml::train::TensorDim::DataType::FP32) { rmsnormProcess(in, out, gamma, epsilon, context); - } else{ + } else { rmsnormProcess_fp16(in, out, gamma, epsilon, context); } } @@ -276,14 +276,14 @@ void RMSNormLayerCl::rmsnormProcess_fp16(Tensor const &input, Tensor &result, if (!ret) { break; } - ret = RMSNormLayerCl::kernel_rmsnorm_fp16.SetKernelArguments( - 4, &b, sizeof(int)); + ret = RMSNormLayerCl::kernel_rmsnorm_fp16.SetKernelArguments(4, &b, + sizeof(int)); if (!ret) { break; } - ret = RMSNormLayerCl::kernel_rmsnorm_fp16.SetKernelArguments(3, &epsilon, - sizeof(cl_half)); + ret = RMSNormLayerCl::kernel_rmsnorm_fp16.SetKernelArguments( + 3, &epsilon, sizeof(cl_half)); if (!ret) { break; } @@ -317,12 +317,11 @@ void RMSNormLayerCl::rmsnormProcess_fp16(Tensor const &input, Tensor &result, break; } } while (false); - } void RMSNormLayerCl::incremental_forwarding(nntrainer::RunLayerContext &context, - unsigned int from, unsigned int to, - bool training) { + unsigned int from, unsigned int to, + bool training) { Tensor &in = context.getInput(SINGLE_INOUT_IDX); Tensor &out = context.getOutput(SINGLE_INOUT_IDX); Tensor &gamma = context.getWeight(wt_idx[RMSParams::gamma]); @@ -374,4 +373,3 @@ void RMSNormLayerCl::setProperty(const std::vector &values) { } } // namespace nntrainer -