[ layer ] Mixed precision forwarding / backwarding for bn layer @open sesame 03/07 10:42 #2462

Status: Closed (wants to merge 2 commits)
Changes from all commits
215 changes: 166 additions & 49 deletions nntrainer/layers/bn_layer.cpp
@@ -174,32 +174,86 @@ void BatchNormalizationLayer::forwarding(RunLayerContext &context,
  /** use hidden_ as temporary tensor before setting the result in hidden */
  Tensor t_full = hidden_;
  Tensor &cvar = context.getTensor(wt_idx[BNParams::cvar]);

  if (input_.getDataType() == ml::train::TensorDim::DataType::FP16) {
#ifdef ENABLE_FP16
    Tensor mu32 = mu.getSingleTensor();
    Tensor var32 = var.getSingleTensor();
    Tensor gamma32 = gamma.getSingleTensor();
    Tensor beta32 = beta.getSingleTensor();
    Tensor input_32 = input_.getSingleTensor();
    Tensor hidden_32 = hidden_.getSingleTensor();
    Tensor t_full32 = hidden_32;
    Tensor deviation32 = deviation.getSingleTensor();
    Tensor invstd32 = invstd.getSingleTensor();
    Tensor t_reduced32 = t_reduced.getSingleTensor();
    Tensor cvar32 = cvar.getSingleTensor();

    if (training) {
      input_32.average(axes_to_reduce, t_reduced32);
      input_32.subtract(t_reduced32, deviation32);

      mu32.multiply_i(momentum);
      mu32.add_i(t_reduced32, 1 - momentum);

      deviation32.pow(2.0f, t_full32);
      t_full32.average(axes_to_reduce, cvar32);

      var32.multiply_i(momentum);
      var32.add_i(cvar32, 1 - momentum);

      cvar32.add_i(epsilon);
      cvar32.pow(-0.5f, invstd32);
    } else {
      input_32.subtract(mu32, deviation32);
      /** @todo do below 2 lines only for first iteration */
      var32.add(epsilon, invstd32);
      invstd32.pow_i(-0.5f);
    }

    deviation32.multiply(invstd32, hidden_32);
    hidden_32.multiply_i(gamma32);
    hidden_32.add_i(beta32);

    mu.copyData(mu32);
    var.copyData(var32);
    gamma.copyData(gamma32);
    beta.copyData(beta32);
    input_.copyData(input_32);
    hidden_.copyData(hidden_32);
    deviation.copyData(deviation32);
    invstd.copyData(invstd32);
    t_reduced.copyData(t_reduced32);
    cvar.copyData(cvar32);
#else
    throw std::runtime_error("enable-fp16 is not enabled");
#endif
  } else {
    if (training) {
      input_.average(axes_to_reduce, t_reduced);
      input_.subtract(t_reduced, deviation);

      mu.multiply_i(momentum);
      mu.add_i(t_reduced, 1 - momentum);

      deviation.pow(2.0f, t_full);
      t_full.average(axes_to_reduce, cvar);

      var.multiply_i(momentum);
      var.add_i(cvar, 1 - momentum);

      cvar.add_i(epsilon);
      cvar.pow(-0.5f, invstd);
    } else {
      input_.subtract(mu, deviation);
      /** @todo do below 2 lines only for first iteration */
      var.add(epsilon, invstd);
      invstd.pow_i(-0.5f);
    }

    deviation.multiply(invstd, hidden_);
    hidden_.multiply_i(gamma);
    hidden_.add_i(beta);
  }
}
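The FP16 branch above follows one pattern: widen every half-precision operand to an FP32 working copy with getSingleTensor(), run the batch-norm arithmetic entirely in FP32, then narrow the results back with copyData(). Below is a condensed sketch of that pattern written against the Tensor calls that appear in this diff; the helper name scaleAndShiftMixed and the include path are illustrative assumptions, not part of the PR.

```cpp
#include <stdexcept>

#include <tensor.h> // nntrainer Tensor (include path assumed)

using namespace nntrainer;

// Illustrative helper, not part of the PR: apply y = gamma * y + beta with
// the same mixed-precision strategy the BN layer uses above.
void scaleAndShiftMixed(Tensor &y, const Tensor &gamma, const Tensor &beta) {
  if (y.getDataType() == ml::train::TensorDim::DataType::FP16) {
#ifdef ENABLE_FP16
    Tensor y32 = y.getSingleTensor();     // FP32 copies of the FP16 data
    Tensor gamma32 = gamma.getSingleTensor();
    Tensor beta32 = beta.getSingleTensor();

    y32.multiply_i(gamma32);              // all arithmetic stays in FP32
    y32.add_i(beta32);

    y.copyData(y32);                      // narrow the result back to FP16
#else
    throw std::runtime_error("enable-fp16 is not enabled");
#endif
  } else {
    y.multiply_i(gamma);                  // FP32 input: unchanged path
    y.add_i(beta);
  }
}
```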

@@ -213,42 +213,105 @@ void BatchNormalizationLayer::calcDerivative(RunLayerContext &context) {

  Tensor &t_reduced = context.getTensor(wt_idx[BNParams::t_reduced]);
  Tensor &t_full = context.getTensor(wt_idx[BNParams::t_full]);

  if (deriv.getDataType() == ml::train::TensorDim::DataType::FP16) {
#ifdef ENABLE_FP16
    Tensor gamma32 = gamma.getSingleTensor();
    Tensor deriv32 = deriv.getSingleTensor();
    Tensor dx32 = dx.getSingleTensor();
    Tensor deviation32 = deviation.getSingleTensor();
    Tensor invstd32 = invstd.getSingleTensor();
    Tensor cvar32 = cvar.getSingleTensor();
    Tensor t_reduced32 = t_reduced.getSingleTensor();
    Tensor t_full32 = t_full.getSingleTensor();

    deviation32.multiply(deriv32, t_full32);
    t_full32.average(axes_to_reduce, t_reduced32);
    t_reduced32.divide_i(cvar32);
    deviation32.multiply_i(t_reduced32);

    if (context.getTrainable()) {
      /**
       * This calculates dgamma tensor.
       */
      Tensor dgamma = context.getWeightGrad(wt_idx[BNParams::gamma]);
      Tensor dgamma32 = dgamma.getSingleTensor();
      t_full32.multiply_i(invstd32);
      t_full32.sum(axes_to_reduce, dgamma32);
      dgamma.copyData(dgamma32);

      /**
       * This implementation depends on the pre-calculated dbeta calculated.
       */
      Tensor &dbeta = context.getWeightGrad(wt_idx[BNParams::beta]);
      Tensor dbeta32 = dbeta.getSingleTensor();
      dbeta32.divide(divider, t_reduced32);
    } else {
      deriv32.average(axes_to_reduce, t_reduced32);
    }

    deriv32.subtract(t_reduced32, dx32);
    dx32.subtract_i(deviation32);

    invstd32.multiply_i(gamma32);
    dx32.multiply_i(invstd32);

    gamma.copyData(gamma32);
    dx.copyData(dx32);
    deviation.copyData(deviation32);
    invstd.copyData(invstd32);
    cvar.copyData(cvar32);
    t_reduced.copyData(t_reduced32);
    t_full.copyData(t_full32);
#else
    throw std::runtime_error("enable-fp16 is not enabled");
#endif
  } else {
    deviation.multiply(deriv, t_full);
    t_full.average(axes_to_reduce, t_reduced);
    t_reduced.divide_i(cvar);
    deviation.multiply_i(t_reduced);

    if (context.getTrainable()) {
      /**
       * This calculates dgamma tensor.
       */
      Tensor &dgamma = context.getWeightGrad(wt_idx[BNParams::gamma]);
      t_full.multiply_i(invstd);
      t_full.sum(axes_to_reduce, dgamma);

      /**
       * This implementation depends on the pre-calculated dbeta calculated.
       */
      Tensor &dbeta = context.getWeightGrad(wt_idx[BNParams::beta]);
      dbeta.divide(divider, t_reduced);
    } else {
      deriv.average(axes_to_reduce, t_reduced);
    }

    deriv.subtract(t_reduced, dx);
    dx.subtract_i(deviation);

    invstd.multiply_i(gamma);
    dx.multiply_i(invstd);
  }
}

Review comment (Contributor), on the line `Tensor dgamma = context.getWeightGrad(wt_idx[BNParams::gamma]);` in the FP16 branch above:

is it `Tensor &dgamma = context.getWeightGrad(wt_idx[BNParams::gamma]);` ?
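On the reviewer's question above: whether the missing `&` matters here depends on nntrainer Tensor's copy semantics, which this note does not assume; the FP16 branch explicitly writes the result back with `dgamma.copyData(dgamma32)` either way. As a plain C++ refresher (a generic sketch with stand-in names, not nntrainer code), binding a returned reference to a value produces an independent object, while binding it to a reference aliases the original:

```cpp
#include <iostream>

// Generic illustration of value vs. reference binding; `Grad` and
// `getWeightGrad()` are stand-ins, not nntrainer APIs.
struct Grad {
  float v = 0.0f;
};

static Grad storage;

Grad &getWeightGrad() { return storage; }

int main() {
  Grad by_value = getWeightGrad(); // independent copy of `storage`
  by_value.v = 1.0f;               // storage.v is still 0.0f

  Grad &by_ref = getWeightGrad();  // alias of `storage`
  by_ref.v = 2.0f;                 // storage.v is now 2.0f

  std::cout << storage.v << '\n';  // prints 2
  return 0;
}
```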

void BatchNormalizationLayer::calcGradient(RunLayerContext &context) {
  /** dgamma is calculated in calcDerivative. dbeta is calculated here */
  Tensor &dbeta = context.getWeightGrad(wt_idx[BNParams::beta]);
  const Tensor &deriv = context.getIncomingDerivative(SINGLE_INOUT_IDX);

  if (deriv.getDataType() == ml::train::TensorDim::DataType::FP16) {
#ifdef ENABLE_FP16
    Tensor dbeta32 = dbeta.getSingleTensor();
    Tensor deriv32 = deriv.getSingleTensor();
    deriv32.sum(axes_to_reduce, dbeta32);
    dbeta.copyData(dbeta32);
#else
    throw std::runtime_error("enable-fp16 is not enabled");
#endif
  } else {
    deriv.sum(axes_to_reduce, dbeta);
  }
}

void BatchNormalizationLayer::exportTo(
14 changes: 14 additions & 0 deletions nntrainer/tensor/tensor.h
@@ -2014,6 +2014,20 @@ class Tensor {

    scale_factors_fp16 = scales;
  }

  /**
   * @brief Get a single-precision (FP32) copy of this tensor
   *
   * @return Tensor a new FP32 tensor holding a copy of this tensor's data
   */
  Tensor getSingleTensor() const {
    TensorDim output_dim = getDim();
    output_dim.setDataType(ml::train::TensorDim::DataType::FP32);
    Tensor output(output_dim, true);
    output.copyData(*this);
    return output;
  }
#endif

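For reference, a minimal usage sketch for the new getSingleTensor() helper. It assumes an enable-fp16 build; the include path, the use of setValue(), and the arbitrary 1x1x2x4 shape are illustrative assumptions, not taken from this PR.

```cpp
#include <tensor.h> // nntrainer Tensor (include path assumed)

using namespace nntrainer;

void upcastExample() {
  TensorDim dim(1, 1, 2, 4);
  dim.setDataType(ml::train::TensorDim::DataType::FP16);

  Tensor half(dim, true);                // FP16 tensor, allocated immediately
  half.setValue(1.5f);                   // fill with a test value (assumed API)

  Tensor full = half.getSingleTensor();  // FP32 copy of the same values
  full.add_i(0.25f);                     // accumulate safely in FP32

  half.copyData(full);                   // narrow the result back into FP16
}
```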