From 54f5e29540174b05f3939573409a151fb6041b6e Mon Sep 17 00:00:00 2001
From: Yizhen
Date: Sun, 15 Sep 2024 02:05:52 +0800
Subject: [PATCH 1/2] lisa multigpu test

---
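LISA keeps only a few layers trainable and re-randomizes that choice every
`interval_steps`. Once the active layers change, the existing Adam state no
longer matches the trainable parameters, so this patch stashes a back-reference
to the trainer on the TrainingArguments object
(`setattr(self.args, '_trainer', self)`) and has the layer-switching callback
call `switch_optimizer()`, which discards the optimizer, rebuilds it over the
currently trainable parameters, and re-prepares it with accelerate. Before the
first switch, `create_optimizer()` builds a placeholder optimizer over a single
parameter (`idx == 0`), presumably so `accelerator.prepare()` has a non-empty
optimizer to wrap.

A minimal standalone sketch of the same mechanism (toy model, made-up layer
count and hyperparameters; not LMFlow code):

    import torch
    import torch.nn as nn

    model = nn.Sequential(*[nn.Linear(8, 8) for _ in range(4)])

    def switch_active_layers(model, active_idx):
        # Freeze everything, then re-enable one layer, as the LISA callback does.
        for p in model.parameters():
            p.requires_grad = False
        for p in model[active_idx].parameters():
            p.requires_grad = True

    def rebuild_optimizer(model, lr=1e-3, weight_decay=0.01):
        # Recreate Adam over only the trainable parameters, so optimizer state
        # always matches the active layers (the core of switch_optimizer()).
        params = [p for p in model.parameters() if p.requires_grad]
        return torch.optim.AdamW(params, lr=lr, weight_decay=weight_decay)

    interval_steps = 20
    optimizer = rebuild_optimizer(model)
    for step in range(100):
        if step % interval_steps == 0:
            switch_active_layers(model, (step // interval_steps) % len(model))
            optimizer = rebuild_optimizer(model)
        loss = model(torch.randn(2, 8)).sum()
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
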
+ if "params" in optimizer_kwargs: + optimizer_grouped_parameters = optimizer_kwargs.pop( + "params" + ) + + # For layer-wise dummy optimizers we overwrite + # optimizer_grouped_parameters with `optimizer_dict` to + # avoid arguments conflicts. + if "optimizer_dict" in optimizer_kwargs: + optimizer_grouped_parameters = optimizer_kwargs.pop( + "optimizer_dict" + ) + + self.optimizer = optimizer_cls( + optimizer_grouped_parameters, + **optimizer_kwargs + ) + if is_sagemaker_mp_enabled(): + self.optimizer = smp.DistributedOptimizer(self.optimizer) + + def switch_optimizer(self): + print('switch_optim') + self.lisa_optim_initialize_finished = True + self.optimizer = None + # self.accelerator.free_memory() + self.create_optimizer() + self.optimizer = self.accelerator.prepare(self.optimizer) + + return LISATrainer def tune(self, model, @@ -505,7 +587,7 @@ def compute_metrics(eval_preds): if training_args.use_customized_optim: BaseTrainer = FinetuningTrainer FinetuningTrainer = self.create_customized_optimizer( - BaseTrainer, model_args + BaseTrainer, model_args, training_args.use_lisa ) if training_args.use_lisa: @@ -544,6 +626,9 @@ def on_step_begin(self, args, state, control, **kwargs): # Check if it's time to switch active layers, including at step 0 if state.global_step % self.interval_steps == 0: self.switch_active_layers() + if hasattr(args, '_trainer'): + trainer = getattr(args, '_trainer') + trainer.switch_optimizer() def switch_active_layers(self): # First, disable gradients for all layers From 72379a9f734110dc1caed610908706df1beecc7e Mon Sep 17 00:00:00 2001 From: Yizhen Date: Wed, 25 Sep 2024 12:33:38 +0800 Subject: [PATCH 2/2] [bug fix] temporarily disable saving training args when lisa --- src/lmflow/pipeline/finetuner.py | 48 ++++++++++++++++++++++++++++++++ 1 file changed, 48 insertions(+) diff --git a/src/lmflow/pipeline/finetuner.py b/src/lmflow/pipeline/finetuner.py index 5b0307269..05d1b650b 100644 --- a/src/lmflow/pipeline/finetuner.py +++ b/src/lmflow/pipeline/finetuner.py @@ -490,6 +490,54 @@ def switch_optimizer(self): self.create_optimizer() self.optimizer = self.accelerator.prepare(self.optimizer) + def _save(self, output_dir: Optional[str] = None, state_dict=None): + import torch + from transformers.utils import ( + is_safetensors_available, + is_peft_available, + SAFE_WEIGHTS_NAME, + WEIGHTS_NAME + ) + if is_safetensors_available(): + import safetensors.torch + if is_peft_available(): + from peft import PeftModel + + # If we are executing this function, we are the process zero, so we don't check for that. + output_dir = output_dir if output_dir is not None else self.args.output_dir + os.makedirs(output_dir, exist_ok=True) + logger.info(f"Saving model checkpoint to {output_dir}") + + supported_classes = (PreTrainedModel,) if not is_peft_available() else (PreTrainedModel, PeftModel) + # Save a trained model and configuration using `save_pretrained()`. 
 src/lmflow/pipeline/finetuner.py | 48 ++++++++++++++++++++++++++++++++
 1 file changed, 48 insertions(+)

diff --git a/src/lmflow/pipeline/finetuner.py b/src/lmflow/pipeline/finetuner.py
index 5b0307269..05d1b650b 100644
--- a/src/lmflow/pipeline/finetuner.py
+++ b/src/lmflow/pipeline/finetuner.py
@@ -490,6 +490,54 @@ def switch_optimizer(self):
                 self.create_optimizer()
                 self.optimizer = self.accelerator.prepare(self.optimizer)
 
+            def _save(self, output_dir: Optional[str] = None, state_dict=None):
+                import torch
+                from transformers.utils import (
+                    is_safetensors_available,
+                    is_peft_available,
+                    SAFE_WEIGHTS_NAME,
+                    WEIGHTS_NAME
+                )
+                if is_safetensors_available():
+                    import safetensors.torch
+                if is_peft_available():
+                    from peft import PeftModel
+
+                # If we are executing this function, we are the process zero, so we don't check for that.
+                output_dir = output_dir if output_dir is not None else self.args.output_dir
+                os.makedirs(output_dir, exist_ok=True)
+                logger.info(f"Saving model checkpoint to {output_dir}")
+
+                supported_classes = (PreTrainedModel,) if not is_peft_available() else (PreTrainedModel, PeftModel)
+                # Save a trained model and configuration using `save_pretrained()`.
+                # They can then be reloaded using `from_pretrained()`
+                if not isinstance(self.model, supported_classes):
+                    if state_dict is None:
+                        state_dict = self.model.state_dict()
+
+                    if isinstance(self.accelerator.unwrap_model(self.model), supported_classes):
+                        self.accelerator.unwrap_model(self.model).save_pretrained(
+                            output_dir, state_dict=state_dict, safe_serialization=self.args.save_safetensors
+                        )
+                    else:
+                        logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.")
+                        if self.args.save_safetensors:
+                            safetensors.torch.save_file(
+                                state_dict, os.path.join(output_dir, SAFE_WEIGHTS_NAME), metadata={"format": "pt"}
+                            )
+                        else:
+                            torch.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME))
+                else:
+                    self.model.save_pretrained(
+                        output_dir, state_dict=state_dict, safe_serialization=self.args.save_safetensors
+                    )
+
+                if self.tokenizer is not None:
+                    self.tokenizer.save_pretrained(output_dir)
+
+                # Good practice: save your training arguments together with the trained model
+                # torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))
+
         return LISATrainer
 
     def tune(self,
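With both patches applied, the `__post_init__` coercion from PATCH 1/2 can be
checked directly. A quick sanity check (hypothetical output path; assumes an
LMFlow install with these patches applied):

    from lmflow.args import FinetunerArguments

    args = FinetunerArguments(
        output_dir="/tmp/lisa_test",   # required by TrainingArguments
        use_lisa=True,
        use_customized_optim=False,    # coerced to True, with a warning
        customized_optim="sgd",        # coerced to "adam", with a warning
    )
    assert args.use_customized_optim is True
    assert args.customized_optim == "adam"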