Improve model checkpoint saving logic #691

Open · wants to merge 1 commit into main
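This change moves the checkpoint-saving block out of the validation branch. A should_save_model flag is computed first: when run_validation is enabled, it is true only if the new eval_epoch_loss improves on best_val_loss; when validation is disabled, it is always true, so checkpoints are still written. The saving paths themselves (PEFT adapters, non-FSDP checkpoints, and the FSDP FULL_STATE_DICT / SHARDED_STATE_DICT branches) are unchanged apart from the new guard and re-indentation. A minimal sketch of the new gating, using illustrative names rather than the inline train_config fields in the diff (this helper is not part of the PR):

def should_save_checkpoint(run_validation: bool,
                           eval_epoch_loss: float,
                           best_val_loss: float) -> bool:
    # Save on improvement when validation runs; always save when it is disabled.
    if run_validation:
        return eval_epoch_loss < best_val_loss
    return True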
104 changes: 55 additions & 49 deletions src/llama_recipes/utils/train_utils.py
@@ -220,72 +220,78 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
         # Update the learning rate as needed
         lr_scheduler.step()
-
+        should_save_model = False
         if train_config.run_validation:
             eval_ppl, eval_epoch_loss, temp_val_loss, temp_step_perplexity = evaluation(model, train_config, eval_dataloader, local_rank, tokenizer, wandb_run)
             if train_config.save_metrics:
                 val_step_loss.extend(temp_val_loss)
                 val_step_perplexity.extend(temp_step_perplexity)
 
-            checkpoint_start_time = time.perf_counter()
-            if train_config.save_model and eval_epoch_loss < best_val_loss:
-                if train_config.enable_fsdp:
-                    dist.barrier()
-                if train_config.use_peft:
-                    if train_config.enable_fsdp:
-                        if rank==0:
-                            print(f"we are about to save the PEFT modules")
-                    else:
-                        print(f"we are about to save the PEFT modules")
-                    save_peft_checkpoint(model, train_config.output_dir)
-                    if train_config.enable_fsdp:
-                        if rank==0:
-                            print(f"PEFT modules are saved in {train_config.output_dir} directory")
-                    else:
-                        print(f"PEFT modules are saved in {train_config.output_dir} directory")
+            should_save_model = eval_epoch_loss < best_val_loss
+        else:
+            should_save_model = True
+
+        if should_save_model:
+            checkpoint_start_time = time.perf_counter()
+            if train_config.enable_fsdp:
+                dist.barrier()
+            if train_config.use_peft:
+                if train_config.enable_fsdp:
+                    if rank==0:
+                        print(f"we are about to save the PEFT modules")
+                else:
+                    print(f"we are about to save the PEFT modules")
+                save_peft_checkpoint(model, train_config.output_dir)
+                if train_config.enable_fsdp:
+                    if rank==0:
+                        print(f"PEFT modules are saved in {train_config.output_dir} directory")
+                else:
+                    print(f"PEFT modules are saved in {train_config.output_dir} directory")
 
-                else:
-                    if not train_config.enable_fsdp:
-                        save_model_checkpoint(model, train_config.output_dir)
+            else:
+                if not train_config.enable_fsdp:
+                    save_model_checkpoint(model, train_config.output_dir)
 
-                    elif fsdp_config.checkpoint_type == StateDictType.FULL_STATE_DICT:
-                        print(" Saving the FSDP model checkpoint using FULL_STATE_DICT")
-                        print("=====================================================")
-                        save_fsdp_model_checkpoint_full(
-                            model, optimizer, rank, train_config, epoch=epoch
-                        )
+                elif fsdp_config.checkpoint_type == StateDictType.FULL_STATE_DICT:
+                    print(" Saving the FSDP model checkpoint using FULL_STATE_DICT")
+                    print("=====================================================")
+                    save_fsdp_model_checkpoint_full(
+                        model, optimizer, rank, train_config, epoch=epoch
+                    )
 
-                        if train_config.save_optimizer:
-                            print(" Saving the FSDP optimizer using FULL_STATE_DICT")
-                            print("=====================================================")
-                            save_optimizer_checkpoint(
-                                model, optimizer, rank, train_config, epoch=epoch
-                            )
+                    if train_config.save_optimizer:
+                        print(" Saving the FSDP optimizer using FULL_STATE_DICT")
+                        print("=====================================================")
+                        save_optimizer_checkpoint(
+                            model, optimizer, rank, train_config, epoch=epoch
+                        )
 
-                    elif fsdp_config.checkpoint_type == StateDictType.SHARDED_STATE_DICT:
+                elif fsdp_config.checkpoint_type == StateDictType.SHARDED_STATE_DICT:
 
-                        if train_config.save_optimizer:
-                            print(" Saving the FSDP model checkpoints using SHARDED_STATE_DICT")
-                            print("=====================================================")
-                            save_model_and_optimizer_sharded(model, rank, train_config, optim=optimizer)
-                        else:
-                            print(" Saving the FSDP model checkpoints and optimizer using SHARDED_STATE_DICT")
-                            print("=====================================================")
-                            save_model_and_optimizer_sharded(model, rank, train_config)
+                    if train_config.save_optimizer:
+                        print(" Saving the FSDP model checkpoints using SHARDED_STATE_DICT")
+                        print("=====================================================")
+                        save_model_and_optimizer_sharded(model, rank, train_config, optim=optimizer)
+                    else:
+                        print(" Saving the FSDP model checkpoints and optimizer using SHARDED_STATE_DICT")
+                        print("=====================================================")
+                        save_model_and_optimizer_sharded(model, rank, train_config)
 
 
-                if train_config.enable_fsdp:
-                    dist.barrier()
+            if train_config.enable_fsdp:
+                dist.barrier()
             checkpoint_end_time = time.perf_counter() - checkpoint_start_time
             checkpoint_times.append(checkpoint_end_time)
+
+        if train_config.run_validation:
             if eval_epoch_loss < best_val_loss:
                 best_val_loss = eval_epoch_loss
                 if train_config.enable_fsdp:
                     if rank==0:
                         print(f"best eval loss on epoch {epoch+1} is {best_val_loss}")
                 else:
                     print(f"best eval loss on epoch {epoch+1} is {best_val_loss}")
             val_loss.append(float(best_val_loss))
             val_prep.append(float(eval_ppl))
         if train_config.enable_fsdp:
             if rank==0:
                 print(f"Epoch {epoch+1}: train_perplexity={train_perplexity:.4f}, train_epoch_loss={train_epoch_loss:.4f}, epoch time {epoch_end_time}s")
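For readability, the branch structure that the saving block walks through inline can be summarized as below. This is an illustrative sketch, not code from the PR; it assumes the same checkpoint helpers (save_peft_checkpoint, save_model_checkpoint, save_fsdp_model_checkpoint_full, save_optimizer_checkpoint, save_model_and_optimizer_sharded) and config fields that appear in the diff, and only restates the dispatch order.

from torch.distributed.fsdp import StateDictType

def save_checkpoint(model, optimizer, rank, train_config, fsdp_config, epoch):
    # Illustrative dispatch mirroring the branches in the diff above.
    if train_config.use_peft:
        # PEFT: only the adapter weights go to train_config.output_dir
        save_peft_checkpoint(model, train_config.output_dir)
    elif not train_config.enable_fsdp:
        # Single-process training: save the full model directly
        save_model_checkpoint(model, train_config.output_dir)
    elif fsdp_config.checkpoint_type == StateDictType.FULL_STATE_DICT:
        # FSDP, consolidated checkpoint; optimizer state saved separately if requested
        save_fsdp_model_checkpoint_full(model, optimizer, rank, train_config, epoch=epoch)
        if train_config.save_optimizer:
            save_optimizer_checkpoint(model, optimizer, rank, train_config, epoch=epoch)
    elif fsdp_config.checkpoint_type == StateDictType.SHARDED_STATE_DICT:
        # FSDP, sharded checkpoint; optimizer state can be saved in the same call
        if train_config.save_optimizer:
            save_model_and_optimizer_sharded(model, rank, train_config, optim=optimizer)
        else:
            save_model_and_optimizer_sharded(model, rank, train_config)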