lamb fused just for mlperf, not merged #161

Open

wants to merge 1 commit into base: 0.2

Changes from all commits
1 change: 1 addition & 0 deletions intel_pytorch_extension_py/optim/__init__.py
@@ -1,2 +1,3 @@
from .split_sgd import is_available
from .split_sgd import SplitSGD
from .lamb import Lamb
127 changes: 127 additions & 0 deletions intel_pytorch_extension_py/optim/lamb.py
@@ -0,0 +1,127 @@
"""Lamb optimizer."""

import collections
import math

import torch
from tensorboardX import SummaryWriter
from torch.optim import Optimizer
from _torch_ipex import lamb_fused_step_

def log_lamb_rs(optimizer: Optimizer, event_writer: SummaryWriter, token_count: int):
    """Log a histogram of trust ratio scalars across layers."""
    results = collections.defaultdict(list)
    for group in optimizer.param_groups:
        for p in group['params']:
            state = optimizer.state[p]
            for i in ('weight_norm', 'adam_norm', 'trust_ratio'):
                if i in state:
                    results[i].append(state[i])

    for k, v in results.items():
        event_writer.add_histogram(f'lamb/{k}', torch.tensor(v), token_count)

class Lamb(Optimizer):
    r"""Implements the Lamb algorithm.

    It has been proposed in `Large Batch Optimization for Deep Learning: Training BERT in 76 minutes`_.

    Arguments:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float, optional): learning rate (default: 1e-3)
        betas (Tuple[float, float], optional): coefficients used for computing
            running averages of gradient and its square (default: (0.9, 0.999))
        eps (float, optional): term added to the denominator to improve
            numerical stability (default: 1e-6)
        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
        adam (bool, optional): always use trust ratio = 1, which turns this into
            Adam. Useful for comparison purposes.
        bf16 (bool, optional): use the fused bf16 kernel, which keeps the fp32
            master weights split across two bf16 tensors (default: False)

    .. _Large Batch Optimization for Deep Learning: Training BERT in 76 minutes:
        https://arxiv.org/abs/1904.00962
    """

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-6,
                 weight_decay=0, adam=False, bf16=False):
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {}".format(eps))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
        defaults = dict(lr=lr, betas=betas, eps=eps,
                        weight_decay=weight_decay)
        self.adam = adam
        self.bf16 = bf16
        super(Lamb, self).__init__(params, defaults)

    def set_bf16(self, bf16=False):
        self.bf16 = bf16

    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad.data
                if grad.is_sparse:
                    raise RuntimeError('Lamb does not support sparse gradients, consider SparseAdam instead.')

                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    state['step'] = 0
                    # Exponential moving average of gradient values
                    state['exp_avg'] = torch.zeros_like(p.data, dtype=torch.float32)
                    # Exponential moving average of squared gradient values
                    state['exp_avg_sq'] = torch.zeros_like(p.data, dtype=torch.float32)
                    if self.bf16:
                        # bottom half of the fp32 master weights; the top half lives in p.data
                        state['bot_half'] = torch.zeros_like(p.data, dtype=torch.bfloat16, device=p.data.device)

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                beta1, beta2 = group['betas']

                state['step'] += 1
                if self.bf16:
                    lamb_fused_step_(p, p.grad, state['bot_half'], exp_avg, exp_avg_sq, state['step'],
                                     group['lr'], beta1, beta2, group['weight_decay'], group['eps'])
                else:
                    step_size = group['lr']  # * math.sqrt(bias_correction2) / bias_correction1
                    # m_t
                    exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
                    # v_t
                    exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
                    weight_norm = p.data.pow(2).sum().sqrt().clamp(0, 10)
                    adam_step = exp_avg / exp_avg_sq.sqrt().add(group['eps'])
                    adam_norm = adam_step.pow(2).sum().sqrt()

                    if group['weight_decay'] != 0:
                        adam_step.add_(p.data, alpha=group['weight_decay'])

                    if weight_norm == 0 or adam_norm == 0:
                        trust_ratio = 1
                    else:
                        trust_ratio = weight_norm / adam_norm
                    state['weight_norm'] = weight_norm
                    state['adam_norm'] = adam_norm
                    state['trust_ratio'] = trust_ratio
                    if self.adam:
                        trust_ratio = 1
                    p.data.add_(adam_step, alpha=-step_size * trust_ratio)

        return loss
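For context, a minimal usage sketch of the optimizer added above. It assumes the package is importable as intel_pytorch_extension_py (matching the __init__.py change in this PR); the model and training loop are illustrative placeholders, and bf16=True additionally requires the fused _torch_ipex.lamb_fused_step_ kernel added below.

import torch
from intel_pytorch_extension_py.optim import Lamb  # exported via the __init__.py change above

model = torch.nn.Linear(128, 2)  # placeholder model
optimizer = Lamb(model.parameters(), lr=1e-3, betas=(0.9, 0.999),
                 eps=1e-6, weight_decay=0.01, bf16=False)

for _ in range(3):  # illustrative loop
    x = torch.randn(32, 128)
    loss = model(x).pow(2).mean()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()  # bf16=False takes the pure-Python path; bf16=True calls lamb_fused_step_

In both paths the Adam-style step is rescaled per tensor by trust_ratio = clamp(||w||_2, 0, 10) / ||adam_step||_2 before it is applied with the learning rate.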
99 changes: 99 additions & 0 deletions torch_ipex/csrc/cpu/ExtendOPs.cpp
@@ -12,6 +12,105 @@
#include "DevOPs.h"

namespace torch_ipex {
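// The fused LAMB step stores each fp32 master weight split across two bf16 tensors:
// the high 16 bits are the regular bf16 parameter ("top half") and the low 16 bits are
// kept in a companion tensor ("bottom half"). pack_bfloat16_float() reassembles the
// fp32 value from the two halves, and unpack_float_bfloat16() splits an updated fp32
// value back into (top, bottom). For example, fp32 1.0f (bits 0x3F800000) packs from
// top = 0x3F80 (bf16 1.0) and bottom = 0x0000.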
inline float pack_bfloat16_float(at::BFloat16 a, at::BFloat16 b) {
  uint16_t* ap = reinterpret_cast<uint16_t*>(&a);
  uint16_t* bp = reinterpret_cast<uint16_t*>(&b);
  uint32_t hi = static_cast<uint32_t>(*ap);
  uint32_t lo = static_cast<uint32_t>(*bp);
  uint32_t out = (hi << 16) + lo;
  float* outp = reinterpret_cast<float*>(&out);
  return *outp;
}

inline std::tuple<at::BFloat16, at::BFloat16> unpack_float_bfloat16(float a) {
  uint32_t* ap = reinterpret_cast<uint32_t*>(&a);
  uint16_t hi = static_cast<uint16_t>((*ap) >> 16);
  uint16_t lo = static_cast<uint16_t>((*ap));
  at::BFloat16* hip = reinterpret_cast<at::BFloat16*>(&hi);
  at::BFloat16* lop = reinterpret_cast<at::BFloat16*>(&lo);
  return std::make_tuple(*hip, *lop);
}

void AtenIpexTypeExt::lamb_fused_step_(at::Tensor & param, at::Tensor & grad, at::Tensor & param2, at::Tensor & exp_avg, at::Tensor & exp_avg_sq, int64_t step, float lr, float beta1, float beta2, float weight_decay, float eps) {
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(grad.scalar_type() == at::ScalarType::BFloat16);
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(param.scalar_type() == at::ScalarType::BFloat16);
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(param2.scalar_type() == at::ScalarType::BFloat16);
  RECORD_FUNCTION("ipex::lamb_fused_step", std::vector<c10::IValue>({param, param2, grad}), torch::autograd::Node::peek_at_next_sequence_nr());
  at::BFloat16* param_data = param.data_ptr<at::BFloat16>();
  float* exp_avg_data = exp_avg.data_ptr<float>();
  float* exp_avg_sq_data = exp_avg_sq.data_ptr<float>();
  at::BFloat16* grad_data = grad.data_ptr<at::BFloat16>();
  at::BFloat16* param2_data = param2.data_ptr<at::BFloat16>();
  int num_threads = at::get_num_threads();
  float param_norm_acc[num_threads];
  float rtw_norm_acc[num_threads];
  std::fill_n(&param_norm_acc[0], num_threads, float(0));
  std::fill_n(&rtw_norm_acc[0], num_threads, float(0));
  int64_t numel = param.numel();
  at::Tensor workspace = at::empty({numel}, exp_avg.options());
  float* workspace_data = workspace.data_ptr<float>();
  int64_t grain_size = 512;

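  // First pass: update the Adam moments, stash the raw Adam step in the fp32 workspace,
  // and accumulate per-thread partial sums of ||w||^2 and ||adam_step||^2.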
  at::parallel_for(0, numel, grain_size, [&](int64_t begin, int64_t end) {
    int tid = at::get_thread_num();
    // local pointers
    at::BFloat16* param_ptr = param_data + begin;
    float* exp_avg_ptr = exp_avg_data + begin;
    float* exp_avg_sq_ptr = exp_avg_sq_data + begin;
    at::BFloat16* grad_ptr = grad_data + begin;
    at::BFloat16* param2_ptr = param2_data + begin;
    float* workspace_ptr = workspace_data + begin;
    const int64_t size = end - begin;
    float sum1_val = float(0);
    float sum2_val = float(0);
    int64_t d = 0;
    for (; d < size; d++) {
      float grad_val = float(grad_ptr[d]);
      exp_avg_ptr[d] = exp_avg_ptr[d] * beta1 + grad_val * (1 - beta1);
      exp_avg_sq_ptr[d] = exp_avg_sq_ptr[d] * beta2 + grad_val * grad_val * (1 - beta2);
      float adam_step_val = exp_avg_ptr[d] / (std::sqrt(exp_avg_sq_ptr[d]) + eps);

      float param_val = pack_bfloat16_float(param_ptr[d], param2_ptr[d]);
      //adam_step_val += param_val * weight_decay;
      workspace_ptr[d] = adam_step_val;

      sum1_val += param_val * param_val;
      sum2_val += adam_step_val * adam_step_val;
    }
    param_norm_acc[tid] = sum1_val;
    rtw_norm_acc[tid] = sum2_val;
  });
  //std::cout<< "grad: " <<grad<<std::endl;
  //std::cout <<"param: "<<param <<std::endl;
  // std::cout <<"param2: "<<param2 <<std::endl;
  float param_norm_sum = float(0);
  float rtw_norm_sum = float(0);
  for (int64_t tid = 0; tid < num_threads; tid++) {
    param_norm_sum += param_norm_acc[tid];
    rtw_norm_sum += rtw_norm_acc[tid];
  }

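  // The trust ratio mirrors the Python fallback path: clamp(||w||_2, 0, 10) / ||adam_step||_2.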
  float true_ratio = std::min(float(10), std::max(float(0), std::sqrt(param_norm_sum))) / std::sqrt(rtw_norm_sum);
  //printf("param_norm_sum= %f, rtw_norm_sum= %f, true_ratio in fused kernel = %f\n", std::min(float(10), std::max(float(0), std::sqrt(param_norm_sum))),rtw_norm_sum,true_ratio);
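  // Second pass: apply lr * true_ratio * adam_step to the packed fp32 master weight and
  // split the result back into the bf16 top/bottom halves.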
  at::parallel_for(0, numel, grain_size, [&](int64_t begin, int64_t end) {
    at::BFloat16* param_ptr = param_data + begin;
    at::BFloat16* param2_ptr = param2_data + begin;
    float* workspace_ptr = workspace_data + begin;

    const int64_t size = end - begin;

    int64_t d = 0;
    for (; d < size; d++) {
      float param_val = pack_bfloat16_float(param_ptr[d], param2_ptr[d]);
      param_val -= workspace_ptr[d] * lr * true_ratio;
      std::tie(param_ptr[d], param2_ptr[d]) = unpack_float_bfloat16(param_val);
    }
  });
}

void AtenIpexTypeExt::packed_add_(at::Tensor & top_half, at::Tensor & bot_half, const at::Tensor & grad, float alpha) {
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(grad.scalar_type() == at::ScalarType::BFloat16);
1 change: 1 addition & 0 deletions torch_ipex/csrc/cpu/ExtendOPs.h
@@ -7,6 +7,7 @@ namespace torch_ipex {

class AtenIpexTypeExt {
public:
  static void lamb_fused_step_(at::Tensor & weight, at::Tensor & grad, at::Tensor & bot_half, at::Tensor & exp_avg, at::Tensor & exp_avg_sq, int64_t step, float lr, float beta1, float beta2, float weight_decay, float eps);
  static void packed_add_(at::Tensor & top_half, at::Tensor & bot_half, const at::Tensor & grad, float alpha);
  static at::Tensor interaction_forward(const std::vector<at::Tensor> & input);
  static std::vector<at::Tensor> interaction_backward(const at::Tensor & grad_out, const std::vector<at::Tensor> & input);
5 changes: 5 additions & 0 deletions torch_ipex/csrc/init_python_bindings.cpp
@@ -65,6 +65,11 @@ void InitIpexModuleBindings(py::module m) {
  m.def("enable_pure_bf16", []() { AutoOptConfig::singleton().set_pure_bf16(true); });
  m.def("disable_pure_bf16", []() { AutoOptConfig::singleton().set_pure_bf16(false); });
  m.def("get_pure_bf16", []() { return AutoOptConfig::singleton().get_pure_bf16(); });
  m.def("lamb_fused_step_",
        [](at::Tensor &param, at::Tensor &grad, at::Tensor &param2, at::Tensor &exp_avg, at::Tensor &exp_avg_sq, int64_t step, float lr, float beta1, float beta2, float weight_decay, float eps) {
          AtenIpexTypeExt::lamb_fused_step_(param, grad, param2, exp_avg, exp_avg_sq, step, lr, beta1, beta2, weight_decay, eps);
        });

  m.def("packed_add_",
        [](at::Tensor &top_half, at::Tensor &bot_half,
           const at::Tensor &grad, float alpha) {