diff --git a/src/llama_recipes/utils/train_utils.py b/src/llama_recipes/utils/train_utils.py
index dec024520..c5a2feb0e 100644
--- a/src/llama_recipes/utils/train_utils.py
+++ b/src/llama_recipes/utils/train_utils.py
@@ -220,72 +220,78 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
         # Update the learning rate as needed
         lr_scheduler.step()
+        should_save_model = False
         if train_config.run_validation:
             eval_ppl, eval_epoch_loss, temp_val_loss, temp_step_perplexity = evaluation(model, train_config, eval_dataloader, local_rank, tokenizer, wandb_run)
             if train_config.save_metrics:
                 val_step_loss.extend(temp_val_loss)
                 val_step_perplexity.extend(temp_step_perplexity)
-
+            should_save_model = eval_epoch_loss < best_val_loss
+        else:
+            should_save_model = True
+
+        if should_save_model:
             checkpoint_start_time = time.perf_counter()
-            if train_config.save_model and eval_epoch_loss < best_val_loss:
+            if train_config.enable_fsdp:
+                dist.barrier()
+            if train_config.use_peft:
                 if train_config.enable_fsdp:
-                    dist.barrier()
-                if train_config.use_peft:
-                    if train_config.enable_fsdp:
-                        if rank==0:
-                            print(f"we are about to save the PEFT modules")
-                    else:
+                    if rank==0:
                         print(f"we are about to save the PEFT modules")
-                    save_peft_checkpoint(model, train_config.output_dir)
-                    if train_config.enable_fsdp:
-                        if rank==0:
-                            print(f"PEFT modules are saved in {train_config.output_dir} directory")
-                    else:
+                else:
+                    print(f"we are about to save the PEFT modules")
+                save_peft_checkpoint(model, train_config.output_dir)
+                if train_config.enable_fsdp:
+                    if rank==0:
                         print(f"PEFT modules are saved in {train_config.output_dir} directory")
-                else:
-                    if not train_config.enable_fsdp:
-                        save_model_checkpoint(model, train_config.output_dir)
-
-                    elif fsdp_config.checkpoint_type == StateDictType.FULL_STATE_DICT:
-                        print(" Saving the FSDP model checkpoint using FULL_STATE_DICT")
+                else:
+                    print(f"PEFT modules are saved in {train_config.output_dir} directory")
+
+            else:
+                if not train_config.enable_fsdp:
+                    save_model_checkpoint(model, train_config.output_dir)
+
+                elif fsdp_config.checkpoint_type == StateDictType.FULL_STATE_DICT:
+                    print(" Saving the FSDP model checkpoint using FULL_STATE_DICT")
+                    print("=====================================================")
+                    save_fsdp_model_checkpoint_full(
+                        model, optimizer, rank, train_config, epoch=epoch
+                    )
+
+                    if train_config.save_optimizer:
+                        print(" Saving the FSDP optimizer using FULL_STATE_DICT")
                         print("=====================================================")
-                        save_fsdp_model_checkpoint_full(
+                        save_optimizer_checkpoint(
                             model, optimizer, rank, train_config, epoch=epoch
                         )
-
-                        if train_config.save_optimizer:
-                            print(" Saving the FSDP optimizer using FULL_STATE_DICT")
-                            print("=====================================================")
-                            save_optimizer_checkpoint(
-                                model, optimizer, rank, train_config, epoch=epoch
-                            )
-
-                    elif fsdp_config.checkpoint_type == StateDictType.SHARDED_STATE_DICT:
-
-                        if train_config.save_optimizer:
-                            print(" Saving the FSDP model checkpoints using SHARDED_STATE_DICT")
-                            print("=====================================================")
-                            save_model_and_optimizer_sharded(model, rank, train_config, optim=optimizer)
-                        else:
-                            print(" Saving the FSDP model checkpoints and optimizer using SHARDED_STATE_DICT")
-                            print("=====================================================")
-                            save_model_and_optimizer_sharded(model, rank, train_config)
+
+                elif fsdp_config.checkpoint_type == StateDictType.SHARDED_STATE_DICT:
-
-                if train_config.enable_fsdp:
-                    dist.barrier()
+                    if train_config.save_optimizer:
+                        print(" Saving the FSDP model checkpoints using SHARDED_STATE_DICT")
+                        print("=====================================================")
+                        save_model_and_optimizer_sharded(model, rank, train_config, optim=optimizer)
+                    else:
+                        print(" Saving the FSDP model checkpoints and optimizer using SHARDED_STATE_DICT")
+                        print("=====================================================")
+                        save_model_and_optimizer_sharded(model, rank, train_config)
+
+
+            if train_config.enable_fsdp:
+                dist.barrier()
             checkpoint_end_time = time.perf_counter() - checkpoint_start_time
             checkpoint_times.append(checkpoint_end_time)
-            if eval_epoch_loss < best_val_loss:
-                best_val_loss = eval_epoch_loss
-                if train_config.enable_fsdp:
-                    if rank==0:
+
+            if train_config.run_validation:
+                if eval_epoch_loss < best_val_loss:
+                    best_val_loss = eval_epoch_loss
+                    if train_config.enable_fsdp:
+                        if rank==0:
+                            print(f"best eval loss on epoch {epoch+1} is {best_val_loss}")
+                    else:
                         print(f"best eval loss on epoch {epoch+1} is {best_val_loss}")
-                else:
-                    print(f"best eval loss on epoch {epoch+1} is {best_val_loss}")
-            val_loss.append(float(best_val_loss))
-            val_prep.append(float(eval_ppl))
+                val_loss.append(float(best_val_loss))
+                val_prep.append(float(eval_ppl))
         if train_config.enable_fsdp:
             if rank==0:
                 print(f"Epoch {epoch+1}: train_perplexity={train_perplexity:.4f}, train_epoch_loss={train_epoch_loss:.4f}, epoch time {epoch_end_time}s")
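
The net effect of the patch is to separate the decision "should we checkpoint this epoch?" from the checkpointing itself: the decision is computed up front in should_save_model (an improved eval loss when validation runs, every epoch when it does not), and the existing PEFT / FULL_STATE_DICT / SHARDED_STATE_DICT branches then run under a single if should_save_model. A minimal sketch of that decision, using a hypothetical helper name should_save_checkpoint and a placeholder save_checkpoint that are not part of the patch:

from typing import Optional

def should_save_checkpoint(run_validation: bool,
                           eval_epoch_loss: Optional[float],
                           best_val_loss: float) -> bool:
    # Mirrors the should_save_model logic introduced by this diff.
    if run_validation:
        # Validation ran: checkpoint only when the eval loss improved.
        return eval_epoch_loss < best_val_loss
    # No validation metric to compare against: checkpoint every epoch.
    return True

# Usage mirroring the new control flow (save_checkpoint stands in for the
# PEFT / FULL_STATE_DICT / SHARDED_STATE_DICT branches in the diff above):
#
#   if should_save_checkpoint(train_config.run_validation, eval_epoch_loss, best_val_loss):
#       save_checkpoint(...)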