update checkpoint io for model and ema
oahzxl committed Feb 18, 2024
1 parent a6c4787 commit 24b3754
Showing 3 changed files with 107 additions and 22 deletions.
65 changes: 65 additions & 0 deletions dit/models/utils/ckpt_utils.py
@@ -1,6 +1,71 @@
import json
import logging
import os
from typing import Tuple

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler

from colossalai.booster import Booster
from colossalai.cluster import DistCoordinator


def load_json(file_path: str):
    with open(file_path, "r") as f:
        return json.load(f)


def save_json(data, file_path: str):
    with open(file_path, "w") as f:
        json.dump(data, f, indent=4)


def save(
    booster: Booster,
    model: nn.Module,
    ema: nn.Module,
    optimizer: Optimizer,
    lr_scheduler: _LRScheduler,
    epoch: int,
    step: int,
    batch_size: int,
    coordinator: DistCoordinator,
    save_dir: str,
):
    save_dir = os.path.join(save_dir, f"epoch{epoch}-step{step}")
    os.makedirs(os.path.join(save_dir, "model"), exist_ok=True)

    booster.save_model(model, os.path.join(save_dir, "model"), shard=True)
    # ema is not boosted, so we don't need to use booster.save_model
    torch.save(ema.state_dict(), os.path.join(save_dir, "ema.pt"))
    booster.save_optimizer(optimizer, os.path.join(save_dir, "optimizer"), shard=True, size_per_shard=4096)
    if lr_scheduler is not None:
        booster.save_lr_scheduler(lr_scheduler, os.path.join(save_dir, "lr_scheduler"))
    running_states = {
        "epoch": epoch,
        "step": step,
        "sample_start_index": step * batch_size,
    }
    if coordinator.is_master():
        save_json(running_states, os.path.join(save_dir, "running_states.json"))
    dist.barrier()


def load(
    booster: Booster, model: nn.Module, ema: nn.Module, optimizer: Optimizer, lr_scheduler: _LRScheduler, load_dir: str
) -> Tuple[int, int, int]:
    booster.load_model(model, os.path.join(load_dir, "model"))
    # ema is not boosted, so we don't use booster.load_model
    ema.load_state_dict(torch.load(os.path.join(load_dir, "ema.pt")))
    booster.load_optimizer(optimizer, os.path.join(load_dir, "optimizer"))
    if lr_scheduler is not None:
        booster.load_lr_scheduler(lr_scheduler, os.path.join(load_dir, "lr_scheduler"))
    running_states = load_json(os.path.join(load_dir, "running_states.json"))
    dist.barrier()
    return running_states["epoch"], running_states["step"], running_states["sample_start_index"]


def create_logger(logging_dir):
    ...
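Taken together, save writes an epoch{N}-step{M}/ directory holding a sharded model/, a plain ema.pt, a sharded optimizer/, an optional lr_scheduler/, and running_states.json; load restores all of them and returns the resume position. A minimal round-trip sketch, assuming a ColossalAI process group has already been launched (the plugin choice, module sizes, and paths here are illustrative, not part of this commit):

# Sketch only: assumes colossalai.launch_from_torch has initialized the
# distributed backend; plugin, model sizes, and paths are placeholders.
import torch.nn as nn
from colossalai.booster import Booster
from colossalai.booster.plugin import LowLevelZeroPlugin
from colossalai.cluster import DistCoordinator
from colossalai.nn.optimizer import HybridAdam
from models.utils.ckpt_utils import load, save

booster = Booster(plugin=LowLevelZeroPlugin(stage=2))
coordinator = DistCoordinator()
model, ema = nn.Linear(8, 8).cuda(), nn.Linear(8, 8)
optimizer = HybridAdam(model.parameters(), lr=1e-4)
model, optimizer, *_ = booster.boost(model=model, optimizer=optimizer)

# Writes outputs/epoch0-step10/{model/, ema.pt, optimizer/, running_states.json}
save(booster, model, ema, optimizer, None, epoch=0, step=10,
     batch_size=2, coordinator=coordinator, save_dir="outputs")
# Note the asymmetry: save appends epoch{N}-step{M}, load takes the full path.
epoch, step, start_idx = load(booster, model, ema, optimizer, None,
                              "outputs/epoch0-step10")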
62 changes: 41 additions & 21 deletions dit/train.py
@@ -19,7 +19,7 @@
 from diffusers.models import AutoencoderKL
 from models.diffusion import create_diffusion
 from models.dit import DiT_models
-from models.utils.ckpt_utils import create_logger
+from models.utils.ckpt_utils import create_logger, load, save
 from models.utils.data_utils import center_crop_arr, prepare_dataloader
 from torch.utils.tensorboard import SummaryWriter
 from torchvision import transforms
@@ -109,17 +109,15 @@ def main(args):
     # ==============================
     # Setup an experiment folder
     # ==============================
+    # Make outputs folder (holds all experiment subfolders)
+    os.makedirs(args.outputs, exist_ok=True)
+    experiment_index = len(glob(f"{args.outputs}/*"))
+    # e.g., DiT-XL/2 --> DiT-XL-2 (for naming folders)
+    model_string_name = args.model.replace("/", "-")
+    # Create an experiment folder
+    experiment_dir = f"{args.outputs}/{experiment_index:03d}-{model_string_name}"
     if coordinator.is_master():
-        # Make results folder (holds all experiment subfolders)
-        os.makedirs(args.results_dir, exist_ok=True)
-        experiment_index = len(glob(f"{args.results_dir}/*"))
-        # e.g., DiT-XL/2 --> DiT-XL-2 (for naming folders)
-        model_string_name = args.model.replace("/", "-")
-        # Create an experiment folder
-        experiment_dir = f"{args.results_dir}/{experiment_index:03d}-{model_string_name}"
-        # Stores saved model checkpoints
-        checkpoint_dir = f"{experiment_dir}/checkpoints"
-        os.makedirs(checkpoint_dir, exist_ok=True)
+        os.makedirs(experiment_dir, exist_ok=True)
         with open(f"{experiment_dir}/config.txt", "w") as f:
             json.dump(args.__dict__, f, indent=4)
         logger = create_logger(experiment_dir)
@@ -203,6 +201,8 @@ def main(args):
     # Setup optimizer
     # We used default Adam betas=(0.9, 0.999) and a constant learning rate of 1e-4 in our paper
     optimizer = HybridAdam(model.parameters(), lr=args.lr, weight_decay=0, adamw_mode=True)
+    # You can use a lr scheduler if you want
+    lr_scheduler = None

     # Prepare models for training
     # Ensure EMA is initialized with synced weights
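A note on the new lr_scheduler = None line: if you do want a schedule, create it at this point so it flows through booster.boost below and through the save_lr_scheduler/load_lr_scheduler calls in ckpt_utils. A hedged sketch using a stock PyTorch scheduler (illustrative; the commit itself keeps the constant-LR default):

# Illustrative only: replaces lr_scheduler = None with a real schedule.
from torch.optim.lr_scheduler import CosineAnnealingLR

lr_scheduler = CosineAnnealingLR(optimizer, T_max=args.epochs)  # decay over training
# booster.boost(..., lr_scheduler=lr_scheduler) wraps it so the checkpoint
# helpers can save and restore its state alongside model and optimizer.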
@@ -212,11 +212,6 @@
     # EMA model should always be in eval mode
     ema.eval()

-    # Boost model for distributed training
-    torch.set_default_dtype(dtype)
-    model, optimizer, _, _, _ = booster.boost(model=model, optimizer=optimizer)
-    torch.set_default_dtype(torch.float)
-
     # Setup data:
     transform = transforms.Compose(
         [
@@ -237,10 +232,22 @@
     )
     logger.info(f"Dataset contains {len(dataset):,} images ({args.data_path})")

+    # Boost model for distributed training
+    torch.set_default_dtype(dtype)
+    model, optimizer, _, dataloader, lr_scheduler = booster.boost(
+        model=model, optimizer=optimizer, lr_scheduler=lr_scheduler, dataloader=dataloader
+    )
+    torch.set_default_dtype(torch.float)
+
     # Variables for monitoring/logging purposes:
     start_epoch = 0
     start_step = 0
+    sampler_start_idx = 0
+    if args.load is not None:
+        logger.info("Loading checkpoint")
+        start_epoch, start_step, sampler_start_idx = load(booster, model, ema, optimizer, lr_scheduler, args.load)
+        logger.info(f"Loaded checkpoint {args.load} at epoch {start_epoch} step {start_step}")

     num_steps_per_epoch = len(dataloader)

     logger.info(f"Training for {args.epochs} epochs...")
@@ -287,8 +294,20 @@ def main(args):
                 writer.add_scalar("loss", loss.item(), epoch * num_steps_per_epoch + step)

             if args.ckpt_every > 0 and (step + 1) % args.ckpt_every == 0:
-                coordinator.print_on_master(f"Saving checkpoint")
-                coordinator.print_on_master(f"Saved checkpoint at epoch {epoch} step {step + 1}")
+                logger.info(f"Saving checkpoint")
+                save(
+                    booster,
+                    model,
+                    ema,
+                    optimizer,
+                    lr_scheduler,
+                    epoch,
+                    step + 1,
+                    args.batch_size,
+                    coordinator,
+                    experiment_dir,
+                )
+                logger.info(f"Saved checkpoint at epoch {epoch} step {step + 1} to {experiment_dir}")

         # subsequent epochs are not resumed mid-epoch, so reset the sampler start index and start step
         dataloader.sampler.set_start_index(0)
@@ -307,7 +326,8 @@
     parser.add_argument(
         "--plugin", type=str, default="zero2", choices=["gemini", "gemini_auto", "zero2", "zero2_cpu", "3d"]
     )
-    parser.add_argument("--results-dir", type=str, default="results")
+    parser.add_argument("--outputs", type=str, default="outputs")
+    parser.add_argument("--load", type=str, default=None)
     parser.add_argument("--model", type=str, choices=list(DiT_models.keys()), default="DiT-XL/2")
     parser.add_argument("--image-size", type=int, choices=[256, 512], default=256)
     parser.add_argument("--num-classes", type=int, default=1000)
@@ -316,8 +336,8 @@
parser.add_argument("--global-seed", type=int, default=0)
parser.add_argument("--vae", type=str, choices=["ema", "mse"], default="ema") # Choice doesn't affect training
parser.add_argument("--num-workers", type=int, default=4)
parser.add_argument("--log-every", type=int, default=100)
parser.add_argument("--ckpt-every", type=int, default=50_000)
parser.add_argument("--log-every", type=int, default=2)
parser.add_argument("--ckpt-every", type=int, default=10)
parser.add_argument("--mixed_precision", type=str, default="bf16", choices=["bf16", "fp16"])
parser.add_argument("--grad_clip", type=float, default=1.0, help="Gradient clipping value")
parser.add_argument("--lr", type=float, default=1e-4, help="Gradient clipping value")
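One design note the comments in ckpt_utils.py rely on: ema is deliberately never boosted, so it remains an ordinary nn.Module whose state dict can be torch.save'd and loaded directly, independent of the plugin's sharding. The usual DiT-style EMA update is consistent with that; a sketch (the repository's actual helper may differ, and under ZeRO the training model's parameters may be sharded):

import torch

@torch.no_grad()
def update_ema(ema_model, model, decay=0.9999):
    # In-place exponential moving average: ema = decay * ema + (1 - decay) * param.
    ema_params = dict(ema_model.named_parameters())
    for name, param in model.named_parameters():
        ema_params[name].mul_(decay).add_(param.data, alpha=1 - decay)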
2 changes: 1 addition & 1 deletion dit/train.sh
@@ -1 +1 @@
-torchrun --nnodes=1 --nproc_per_node=1 train.py --model DiT-XL/2 --grad_checkpoint
+torchrun --nnodes=1 --nproc_per_node=2 train.py --model DiT-XL/2 --grad_checkpoint
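Bumping --nproc_per_node from 1 to 2 launches two ranks per node, presumably so the new sharded save/load paths (booster.save_model and save_optimizer with shard=True write per-rank shard files) are actually exercised; a resumed run would additionally pass --load pointing at an epoch{N}-step{M} checkpoint directory.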
