Skip to content

Commit

Permalink
[Feature] Add a switch to control whether the model checkpoint needs …
Browse files Browse the repository at this point in the history
…to be saved after each epoch ends (#5941)

* Add a switch to control whether the model checkpoint needs to be saved after each epoch ends

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
zhurunhua and pre-commit-ci[bot] authored Jul 26, 2024
1 parent 2069472 commit ad35a98
Showing 1 changed file with 15 additions and 3 deletions.
18 changes: 15 additions & 3 deletions applications/Colossal-LLaMA/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,12 @@ def main() -> None:
parser.add_argument("--zero", type=int, default=1)
parser.add_argument("--pad_token", choices=["eos", "unk"], default="eos")
parser.add_argument("--padding_mode", choices=["max_length", "longest"], default="max_length")
parser.add_argument(
"--skip_save_each_epoch",
action="store_true",
default=False,
help="skip saving the model checkpoint after each epoch is completed.",
)
args = parser.parse_args()

with open(args.config_file, "w") as f:
Expand Down Expand Up @@ -370,11 +376,17 @@ def main() -> None:
)
total_loss.fill_(0.0)
pbar.update()

# Save modeling.

if (args.save_interval > 0 and (step + 1) % (args.save_interval * args.accumulation_steps) == 0) or (
step + 1
) == len(dataloader):
save_model_condition = (
args.save_interval > 0 and (step + 1) % (args.save_interval * args.accumulation_steps) == 0
)

if not args.skip_save_each_epoch:
save_model_condition = save_model_condition or (step + 1) == len(dataloader)

if save_model_condition:
coordinator.print_on_master("\nStart saving model checkpoint with running states")

if args.use_neft:
Expand Down

0 comments on commit ad35a98

Please sign in to comment.