[Feature] LISA multi GPU support #899

Open
wants to merge 2 commits into base: main
14 changes: 14 additions & 0 deletions src/lmflow/args.py
@@ -755,6 +755,20 @@ class FinetunerArguments(TrainingArguments):
}
)

def __post_init__(self):
super().__post_init__()
if self.use_lisa:
if not self.use_customized_optim:
logger.warning(
"`use_lisa` is `True` but `use_customized_optim` is `False`. "
"Setting `use_customized_optim` to `True`.")
self.use_customized_optim = True
if self.customized_optim != "adam":
logger.warning(
"Only the adam optimizer is currently supported with LISA. "
"Setting `customized_optim` to `adam`.")
self.customized_optim = "adam"

@dataclass
class RewardModelTunerArguments(FinetunerArguments):
"""
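As a usage note (not part of the diff): a minimal sketch of how the new `__post_init__` validation behaves, assuming the existing `use_lisa`, `use_customized_optim`, and `customized_optim` fields of `FinetunerArguments`. Enabling LISA with mismatched optimizer settings is corrected with a warning instead of raising.

from lmflow.args import FinetunerArguments

# `output_dir` is required by the underlying HF TrainingArguments;
# "./lisa_run" is just a placeholder path.
args = FinetunerArguments(output_dir="./lisa_run", use_lisa=True)

# Both fields are coerced by the new __post_init__ (each coercion logs a warning).
assert args.use_customized_optim is True
assert args.customized_optim == "adam"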
143 changes: 138 additions & 5 deletions src/lmflow/pipeline/finetuner.py
@@ -7,7 +7,7 @@
import logging
import os
import sys
from typing import Any, Iterable, Optional, Tuple
from typing import Any, Iterable, Optional, Tuple, Type, Union

import datasets
import transformers
@@ -211,7 +211,12 @@ def group_texts(examples):

return lm_datasets

def create_customized_optimizer(self, base_trainer_class, model_args):
def create_customized_optimizer(
self,
base_trainer_class: Type[Union[Trainer, PeftTrainer]],
model_args,
use_lisa: bool = False,
):
class CustomizedOptimTrainer(base_trainer_class):

@staticmethod
@@ -407,8 +412,133 @@ def create_optimizer(self):
)
if is_sagemaker_mp_enabled():
self.optimizer = smp.DistributedOptimizer(self.optimizer)

return CustomizedOptimTrainer

if not use_lisa:
return CustomizedOptimTrainer

class LISATrainer(CustomizedOptimTrainer):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# Expose this trainer on the training arguments so that the LISA
# callback, which only receives `args`, can call switch_optimizer()
# after it rotates the active layers.
setattr(self.args, '_trainer', self)
self.lisa_optim_initialize_finished = False

def create_optimizer(self):
opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model

if self.optimizer is None:
decay_parameters = self.get_decay_parameter_names(opt_model)
if not self.lisa_optim_initialize_finished:
# Before LISA has picked its first set of active layers, build a
# placeholder optimizer over a single parameter; switch_optimizer()
# later rebuilds it with the real parameter groups.
optimizer_grouped_parameters = [
{
"params": [
p for idx, (n, p) in enumerate(opt_model.named_parameters())
if idx == 0
],
"weight_decay": self.args.weight_decay,
},
]
else:
optimizer_grouped_parameters = [
{
"params": [
p for n, p in opt_model.named_parameters()
if (n in decay_parameters and p.requires_grad)
],
"weight_decay": self.args.weight_decay,
},
{
"params": [
p for n, p in opt_model.named_parameters()
if (n not in decay_parameters and p.requires_grad)
],
"weight_decay": 0.0,
},
]

optimizer_cls, optimizer_kwargs = CustomizedOptimTrainer.get_optimizer_cls_and_kwargs(self.args, opt_model)

# Overwrite `params` in case it's created by
# `get_optimizer_cls_and_kwargs` e.g. for GaLore optimizer.
if "params" in optimizer_kwargs:
optimizer_grouped_parameters = optimizer_kwargs.pop(
"params"
)

# For layer-wise dummy optimizers we overwrite
# optimizer_grouped_parameters with `optimizer_dict` to
# avoid arguments conflicts.
if "optimizer_dict" in optimizer_kwargs:
optimizer_grouped_parameters = optimizer_kwargs.pop(
"optimizer_dict"
)

self.optimizer = optimizer_cls(
optimizer_grouped_parameters,
**optimizer_kwargs
)
if is_sagemaker_mp_enabled():
self.optimizer = smp.DistributedOptimizer(self.optimizer)

def switch_optimizer(self):
# Drop the current optimizer (and its state), rebuild it over the
# parameters of the newly activated layers, and let accelerate
# re-prepare it so the switch also works in multi-GPU training.
self.lisa_optim_initialize_finished = True
self.optimizer = None
self.create_optimizer()
self.optimizer = self.accelerator.prepare(self.optimizer)

def _save(self, output_dir: Optional[str] = None, state_dict=None):
import torch
from transformers.utils import (
is_safetensors_available,
is_peft_available,
SAFE_WEIGHTS_NAME,
WEIGHTS_NAME
)
if is_safetensors_available():
import safetensors.torch
if is_peft_available():
from peft import PeftModel

# If we are executing this function, we are the process zero, so we don't check for that.
output_dir = output_dir if output_dir is not None else self.args.output_dir
os.makedirs(output_dir, exist_ok=True)
logger.info(f"Saving model checkpoint to {output_dir}")

supported_classes = (PreTrainedModel,) if not is_peft_available() else (PreTrainedModel, PeftModel)
# Save a trained model and configuration using `save_pretrained()`.
# They can then be reloaded using `from_pretrained()`
if not isinstance(self.model, supported_classes):
if state_dict is None:
state_dict = self.model.state_dict()

if isinstance(self.accelerator.unwrap_model(self.model), supported_classes):
self.accelerator.unwrap_model(self.model).save_pretrained(
output_dir, state_dict=state_dict, safe_serialization=self.args.save_safetensors
)
else:
logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.")
if self.args.save_safetensors:
safetensors.torch.save_file(
state_dict, os.path.join(output_dir, SAFE_WEIGHTS_NAME), metadata={"format": "pt"}
)
else:
torch.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME))
else:
self.model.save_pretrained(
output_dir, state_dict=state_dict, safe_serialization=self.args.save_safetensors
)

if self.tokenizer is not None:
self.tokenizer.save_pretrained(output_dir)

# Good practice: save your training arguments together with the trained model
# torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))

return LISATrainer

def tune(self,
model,
@@ -505,7 +635,7 @@ def compute_metrics(eval_preds):
if training_args.use_customized_optim:
BaseTrainer = FinetuningTrainer
FinetuningTrainer = self.create_customized_optimizer(
BaseTrainer, model_args
BaseTrainer, model_args, training_args.use_lisa
)

if training_args.use_lisa:
@@ -544,6 +674,9 @@ def on_step_begin(self, args, state, control, **kwargs):
# Check if it's time to switch active layers, including at step 0
if state.global_step % self.interval_steps == 0:
self.switch_active_layers()
# Rebuild the optimizer so its parameter groups track the newly
# activated layers; `_trainer` is the handle set by LISATrainer.__init__.
if hasattr(args, '_trainer'):
args._trainer.switch_optimizer()

def switch_active_layers(self):
# First, disable gradients for all layers
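To make the moving parts easier to follow, here is a self-contained toy sketch (illustrative names only, not LMFlow code) of the handshake this PR sets up: `LISATrainer.__init__` stores the trainer on the training arguments, and `DynamicLayerActivationCallback.on_step_begin` uses that handle to call `switch_optimizer()` every `interval_steps` steps, so the optimizer is rebuilt over the newly activated layers.

class ToyArgs:
    """Stands in for TrainingArguments; used only as a carrier for `_trainer`."""
    pass


class ToyTrainer:
    def __init__(self, args):
        self.args = args
        # Same trick as LISATrainer.__init__: expose the trainer to callbacks,
        # which receive `args` but not the trainer itself.
        setattr(self.args, '_trainer', self)
        self.optimizer = None
        self.switch_count = 0

    def switch_optimizer(self):
        # Stand-in for: self.optimizer = None; self.create_optimizer();
        # self.optimizer = self.accelerator.prepare(self.optimizer)
        self.optimizer = object()
        self.switch_count += 1


class ToyLISACallback:
    def __init__(self, interval_steps):
        self.interval_steps = interval_steps

    def on_step_begin(self, args, global_step):
        # Mirrors DynamicLayerActivationCallback.on_step_begin: every
        # `interval_steps` steps, rotate the active layers (omitted here)
        # and then rebuild the optimizer through the stored handle.
        if global_step % self.interval_steps == 0 and hasattr(args, '_trainer'):
            args._trainer.switch_optimizer()


args = ToyArgs()
trainer = ToyTrainer(args)
callback = ToyLISACallback(interval_steps=20)
for step in range(41):
    callback.on_step_begin(args, global_step=step)
assert trainer.switch_count == 3  # optimizer rebuilt at steps 0, 20, and 40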