diff --git a/.gitignore b/.gitignore
index 81113fa9..78d9c302 100644
--- a/.gitignore
+++ b/.gitignore
@@ -159,3 +159,6 @@ coverage.xml
 # ignore testmon and coverage files
 .coverage
 .testmondata*
+
+# ignore data files
+datasets
diff --git a/dit/models/utils/ckpt_utils.py b/dit/models/utils/ckpt_utils.py
new file mode 100644
index 00000000..c13c459e
--- /dev/null
+++ b/dit/models/utils/ckpt_utils.py
@@ -0,0 +1,21 @@
+import logging
+
+import torch.distributed as dist
+
+
+def create_logger(logging_dir):
+    """
+    Create a logger that writes to a log file and stdout.
+    """
+    if dist.get_rank() == 0:  # real logger
+        logging.basicConfig(
+            level=logging.INFO,
+            format="[\033[34m%(asctime)s\033[0m] %(message)s",
+            datefmt="%Y-%m-%d %H:%M:%S",
+            handlers=[logging.StreamHandler(), logging.FileHandler(f"{logging_dir}/log.txt")],
+        )
+        logger = logging.getLogger(__name__)
+    else:  # dummy logger (does nothing)
+        logger = logging.getLogger(__name__)
+        logger.addHandler(logging.NullHandler())
+    return logger
diff --git a/dit/models/utils/data_utils.py b/dit/models/utils/data_utils.py
new file mode 100644
index 00000000..93a3f109
--- /dev/null
+++ b/dit/models/utils/data_utils.py
@@ -0,0 +1,110 @@
+import random
+from typing import Iterator, Optional
+
+import numpy as np
+import torch
+from PIL import Image
+from torch.distributed import ProcessGroup
+from torch.distributed.distributed_c10d import _get_default_group
+from torch.utils.data import DataLoader, Dataset, DistributedSampler
+from torch.utils.data.distributed import DistributedSampler
+
+
+class StatefulDistributedSampler(DistributedSampler):
+    def __init__(
+        self,
+        dataset: Dataset,
+        num_replicas: Optional[int] = None,
+        rank: Optional[int] = None,
+        shuffle: bool = True,
+        seed: int = 0,
+        drop_last: bool = False,
+    ) -> None:
+        super().__init__(dataset, num_replicas, rank, shuffle, seed, drop_last)
+        self.start_index: int = 0
+
+    def __iter__(self) -> Iterator:
+        iterator = super().__iter__()
+        indices = list(iterator)
+        indices = indices[self.start_index :]
+        return iter(indices)
+
+    def __len__(self) -> int:
+        return self.num_samples - self.start_index
+
+    def set_start_index(self, start_index: int) -> None:
+        self.start_index = start_index
+
+
+def prepare_dataloader(
+    dataset,
+    batch_size,
+    shuffle=False,
+    seed=1024,
+    drop_last=False,
+    pin_memory=False,
+    num_workers=0,
+    process_group: Optional[ProcessGroup] = None,
+    **kwargs,
+):
+    r"""
+    Prepare a dataloader for distributed training. The dataloader will be wrapped by
+    `torch.utils.data.DataLoader` and `StatefulDistributedSampler`.
+
+
+    Args:
+        dataset (`torch.utils.data.Dataset`): The dataset to be loaded.
+        batch_size (int): The number of samples in each batch on each process.
+        shuffle (bool, optional): Whether to shuffle the dataset. Defaults to False.
+        seed (int, optional): Random worker seed for sampling, defaults to 1024.
+        drop_last (bool, optional): Set to True to drop the last incomplete batch, if the dataset size
+            is not divisible by the batch size. If False and the size of dataset is not divisible by
+            the batch size, then the last batch will be smaller, defaults to False.
+        pin_memory (bool, optional): Whether to pin memory address in CPU memory. Defaults to False.
+        num_workers (int, optional): Number of worker threads for this dataloader. Defaults to 0.
+        kwargs (dict): optional parameters for ``torch.utils.data.DataLoader``, more details can be found in
+            `DataLoader <https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader>`_.
+
+    Returns:
+        :class:`torch.utils.data.DataLoader`: A DataLoader used for training or testing.
+    """
+    _kwargs = kwargs.copy()
+    process_group = process_group or _get_default_group()
+    sampler = StatefulDistributedSampler(
+        dataset, num_replicas=process_group.size(), rank=process_group.rank(), shuffle=shuffle
+    )
+
+    # Deterministic dataloader
+    def seed_worker(worker_id):
+        worker_seed = seed
+        np.random.seed(worker_seed)
+        torch.manual_seed(worker_seed)
+        random.seed(worker_seed)
+
+    return DataLoader(
+        dataset,
+        batch_size=batch_size,
+        sampler=sampler,
+        worker_init_fn=seed_worker,
+        drop_last=drop_last,
+        pin_memory=pin_memory,
+        num_workers=num_workers,
+        **_kwargs,
+    )
+
+
+def center_crop_arr(pil_image, image_size):
+    """
+    Center cropping implementation from ADM.
+    https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126
+    """
+    while min(*pil_image.size) >= 2 * image_size:
+        pil_image = pil_image.resize(tuple(x // 2 for x in pil_image.size), resample=Image.BOX)
+
+    scale = image_size / min(*pil_image.size)
+    pil_image = pil_image.resize(tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC)
+
+    arr = np.array(pil_image)
+    crop_y = (arr.shape[0] - image_size) // 2
+    crop_x = (arr.shape[1] - image_size) // 2
+    return Image.fromarray(arr[crop_y : crop_y + image_size, crop_x : crop_x + image_size])
diff --git a/dit/train.py b/dit/train.py
index 0b93698b..8d5eb4f7 100644
--- a/dit/train.py
+++ b/dit/train.py
@@ -7,28 +7,24 @@
 """
 A minimal training script for DiT using PyTorch DDP.
 """
-import torch
-
-# the first flag below was False when we tested this script but True makes A100 training a lot faster:
-torch.backends.cuda.matmul.allow_tf32 = True
-torch.backends.cudnn.allow_tf32 = True
 import argparse
 import json
-import logging
 import os
 from collections import OrderedDict
 from copy import deepcopy
 from glob import glob
-from time import time
 
-import numpy as np
+import torch
 import torch.distributed as dist
-import tqdm
 from diffusers.models import AutoencoderKL
 from models.diffusion import create_diffusion
 from models.dit import DiT_models
-from PIL import Image
+from models.utils.ckpt_utils import create_logger
+from models.utils.data_utils import center_crop_arr, prepare_dataloader
 from torch.utils.tensorboard import SummaryWriter
+from torchvision import transforms
+from torchvision.datasets import CIFAR10
+from tqdm import tqdm
 
 import colossalai
 from colossalai.booster import Booster
@@ -37,6 +33,10 @@
 from colossalai.nn.optimizer import HybridAdam
 from colossalai.utils import get_current_device
 
+# the first flag below was False when we tested this script but True makes A100 training a lot faster:
+torch.backends.cuda.matmul.allow_tf32 = True
+torch.backends.cudnn.allow_tf32 = True
+
 #################################################################################
 #                             Training Helper Functions                         #
 #################################################################################
@@ -75,7 +75,8 @@ def update_ema(ema_model, model, decay=0.9999):
     model_params = OrderedDict(model.named_parameters())
 
     for name, param in model_params.items():
-        # TODO: Consider applying only to params that require_grad to avoid small numerical changes of pos_embed
+        # TODO: Consider applying only to params that require_grad
+        # to avoid small numerical changes of pos_embed
         ema_params[name].mul_(decay).add_(param.data, alpha=1 - decay)
 
 
@@ -87,48 +88,6 @@ def requires_grad(model, flag=True):
         p.requires_grad = flag
 
 
-def cleanup():
-    """
-    End DDP training.
-    """
-    dist.destroy_process_group()
-
-
-def create_logger(logging_dir):
-    """
-    Create a logger that writes to a log file and stdout.
-    """
-    if dist.get_rank() == 0:  # real logger
-        logging.basicConfig(
-            level=logging.INFO,
-            format="[\033[34m%(asctime)s\033[0m] %(message)s",
-            datefmt="%Y-%m-%d %H:%M:%S",
-            handlers=[logging.StreamHandler(), logging.FileHandler(f"{logging_dir}/log.txt")],
-        )
-        logger = logging.getLogger(__name__)
-    else:  # dummy logger (does nothing)
-        logger = logging.getLogger(__name__)
-        logger.addHandler(logging.NullHandler())
-    return logger
-
-
-def center_crop_arr(pil_image, image_size):
-    """
-    Center cropping implementation from ADM.
-    https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126
-    """
-    while min(*pil_image.size) >= 2 * image_size:
-        pil_image = pil_image.resize(tuple(x // 2 for x in pil_image.size), resample=Image.BOX)
-
-    scale = image_size / min(*pil_image.size)
-    pil_image = pil_image.resize(tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC)
-
-    arr = np.array(pil_image)
-    crop_y = (arr.shape[0] - image_size) // 2
-    crop_x = (arr.shape[1] - image_size) // 2
-    return Image.fromarray(arr[crop_y : crop_y + image_size, crop_x : crop_x + image_size])
-
-
 #################################################################################
 #                                  Training Loop                                #
 #################################################################################
@@ -145,18 +104,21 @@ def main(args):
     # ==============================
     colossalai.launch_from_torch({})
     coordinator = DistCoordinator()
-    rank = dist.get_rank()
     device = get_current_device()
 
     # ==============================
     # Setup an experiment folder
     # ==============================
     if coordinator.is_master():
-        os.makedirs(args.results_dir, exist_ok=True)  # Make results folder (holds all experiment subfolders)
+        # Make results folder (holds all experiment subfolders)
+        os.makedirs(args.results_dir, exist_ok=True)
         experiment_index = len(glob(f"{args.results_dir}/*"))
-        model_string_name = args.model.replace("/", "-")  # e.g., DiT-XL/2 --> DiT-XL-2 (for naming folders)
-        experiment_dir = f"{args.results_dir}/{experiment_index:03d}-{model_string_name}"  # Create an experiment folder
-        checkpoint_dir = f"{experiment_dir}/checkpoints"  # Stores saved model checkpoints
+        # e.g., DiT-XL/2 --> DiT-XL-2 (for naming folders)
+        model_string_name = args.model.replace("/", "-")
+        # Create an experiment folder
+        experiment_dir = f"{args.results_dir}/{experiment_index:03d}-{model_string_name}"
+        # Stores saved model checkpoints
+        checkpoint_dir = f"{experiment_dir}/checkpoints"
         os.makedirs(checkpoint_dir, exist_ok=True)
         with open(f"{experiment_dir}/config.txt", "w") as f:
             json.dump(args.__dict__, f, indent=4)
@@ -171,7 +133,7 @@
     if coordinator.is_master():
         tensorboard_dir = f"{experiment_dir}/tensorboard"
         os.makedirs(tensorboard_dir, exist_ok=True)
-        SummaryWriter(tensorboard_dir)
+        writer = SummaryWriter(tensorboard_dir)
 
     # ==============================
     # Initialize Booster
@@ -231,18 +193,24 @@
 
     # Create ema and vae model
     # Note that parameter initialization is done within the DiT constructor
-    ema = deepcopy(model).to(device)  # Create an EMA of the model for use after training
+    # Create an EMA of the model for use after training
+    ema = deepcopy(model).to(device)
     requires_grad(ema, False)
-    diffusion = create_diffusion(timestep_respacing="")  # default: 1000 steps, linear noise schedule
+    # default: 1000 steps, linear noise schedule
+    diffusion = create_diffusion(timestep_respacing="")
    vae = AutoencoderKL.from_pretrained(f"stabilityai/sd-vae-ft-{args.vae}").to(device)
 
-    # Setup optimizer (we used default Adam betas=(0.9, 0.999) and a constant learning rate of 1e-4 in our paper)
+    # Setup optimizer
+    # We used default Adam betas=(0.9, 0.999) and a constant learning rate of 1e-4 in our paper
     optimizer = HybridAdam(model.parameters(), lr=args.lr, weight_decay=0, adamw_mode=True)
 
     # Prepare models for training
-    update_ema(ema, model, decay=0)  # Ensure EMA is initialized with synced weights
-    model.train()  # important! This enables embedding dropout for classifier-free guidance
-    ema.eval()  # EMA model should always be in eval mode
+    # Ensure EMA is initialized with synced weights
+    update_ema(ema, model, decay=0)
+    # important! This enables embedding dropout for classifier-free guidance
+    model.train()
+    # EMA model should always be in eval mode
+    ema.eval()
 
     # Boost model for distributed training
     torch.set_default_dtype(dtype)
@@ -250,116 +218,92 @@
     torch.set_default_dtype(torch.float)
 
     # Setup data:
-    # transform = transforms.Compose([
-    #     transforms.Lambda(lambda pil_image: center_crop_arr(pil_image, args.image_size)),
-    #     transforms.RandomHorizontalFlip(),
-    #     transforms.ToTensor(),
-    #     transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)
-    # ])
-    # dataset = ImageFolder(args.data_path, transform=transform)
-    # sampler = DistributedSampler(
-    #     dataset,
-    #     num_replicas=dist.get_world_size(),
-    #     rank=rank,
-    #     shuffle=True,
-    #     seed=args.global_seed
-    # )
-    # loader = DataLoader(
-    #     dataset,
-    #     batch_size=int(args.global_batch_size // dist.get_world_size()),
-    #     shuffle=False,
-    #     sampler=sampler,
-    #     num_workers=args.num_workers,
-    #     pin_memory=True,
-    #     drop_last=True
-    # )
-    # logger.info(f"Dataset contains {len(dataset):,} images ({args.data_path})")
+    transform = transforms.Compose(
+        [
+            transforms.Lambda(lambda pil_image: center_crop_arr(pil_image, args.image_size)),
+            transforms.RandomHorizontalFlip(),
+            transforms.ToTensor(),
+            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
+        ]
+    )
+    dataset = CIFAR10(args.data_path, transform=transform, download=True)
+    dataloader = prepare_dataloader(
+        dataset,
+        batch_size=args.batch_size,
+        shuffle=True,
+        drop_last=True,
+        pin_memory=True,
+        num_workers=args.num_workers,
+    )
+    logger.info(f"Dataset contains {len(dataset):,} images ({args.data_path})")
 
     # Variables for monitoring/logging purposes:
-    train_steps = 0
-    log_steps = 0
-    running_loss = 0
-    start_time = time()
-
-    batch = int(args.global_batch_size // dist.get_world_size())
-
-    # TODO: load ckpt
+    start_epoch = 0
+    start_step = 0
+    sampler_start_idx = 0
+    num_steps_per_epoch = len(dataloader)
 
     logger.info(f"Training for {args.epochs} epochs...")
-    for epoch in range(args.epochs):
-        # sampler.set_epoch(epoch)
+    # if resuming training, set the sampler start index to the correct value
+    dataloader.sampler.set_start_index(sampler_start_idx)
+    for epoch in range(start_epoch, args.epochs):
+        dataloader.sampler.set_epoch(epoch)
+        dataloader_iter = iter(dataloader)
         logger.info(f"Beginning epoch {epoch}...")
-        # for x, y in loader:
-        for _ in tqdm.tqdm(range(100)):
-            x = torch.randn(batch, 3, args.image_size, args.image_size).to(device)
-            y = torch.randint(0, 1000, (batch,)).to(device)
-            # x = x.to(device)
-            # y = y.to(device)
-
-            # VAE encode
-            with torch.no_grad():
-                # Map input images to latent space + normalize latents:
-                x = vae.encode(x).latent_dist.sample().mul_(0.18215)
-                x = x.to(dtype)
-
-            # Diffusion
-            t = torch.randint(0, diffusion.num_timesteps, (x.shape[0],), device=device)
-            model_kwargs = dict(y=y)
-            loss_dict = diffusion.training_losses(model, x, t, model_kwargs)
-            loss = loss_dict["loss"].mean()
-            booster.backward(loss=loss, optimizer=optimizer)
-            optimizer.step()
-            optimizer.zero_grad()
-
-            # Update EMA
-            update_ema(ema, model.module)
-
-            # Log loss values:
-            running_loss += loss.item()
-            log_steps += 1
-            train_steps += 1
-            if train_steps % args.log_every == 0:
-                # Measure training speed:
-                torch.cuda.synchronize()
-                end_time = time()
-                steps_per_sec = log_steps / (end_time - start_time)
-                # Reduce loss history over all processes:
-                avg_loss = torch.tensor(running_loss / log_steps, device=device)
-                dist.all_reduce(avg_loss, op=dist.ReduceOp.SUM)
-                avg_loss = avg_loss.item() / dist.get_world_size()
-                logger.info(
-                    f"(step={train_steps:07d}) Train Loss: {avg_loss:.4f}, Train Steps/Sec: {steps_per_sec:.2f}"
-                )
-                # Reset monitoring variables:
-                running_loss = 0
-                log_steps = 0
-                start_time = time()
-
-            # Save DiT checkpoint:
-            if train_steps % args.ckpt_every == 0 and train_steps > 0:
-                if rank == 0:
-                    checkpoint = {
-                        "model": model.module.state_dict(),
-                        "ema": ema.state_dict(),
-                        "opt": optimizer.state_dict(),
-                        "args": args,
-                    }
-                    checkpoint_path = f"{checkpoint_dir}/{train_steps:07d}.pt"
-                    torch.save(checkpoint, checkpoint_path)
-                    logger.info(f"Saved checkpoint to {checkpoint_path}")
-                dist.barrier()
+        with tqdm(
+            range(start_step, num_steps_per_epoch),
+            desc=f"Epoch {epoch}",
+            disable=not coordinator.is_master(),
+            total=num_steps_per_epoch,
+            initial=start_step,
+        ) as pbar:
+            for step in pbar:
+                x, y = next(dataloader_iter)
+                x = x.to(device)
+                y = y.to(device)
+
+                # VAE encode
+                with torch.no_grad():
+                    # Map input images to latent space + normalize latents:
+                    x = vae.encode(x).latent_dist.sample().mul_(0.18215)
+                    x = x.to(dtype)
+
+                # Diffusion
+                t = torch.randint(0, diffusion.num_timesteps, (x.shape[0],), device=device)
+                model_kwargs = dict(y=y)
+                loss_dict = diffusion.training_losses(model, x, t, model_kwargs)
+                loss = loss_dict["loss"].mean()
+                booster.backward(loss=loss, optimizer=optimizer)
+                optimizer.step()
+                optimizer.zero_grad()
+
+                # Update EMA
+                update_ema(ema, model.module)
+
+                # Log loss values:
+                all_reduce_mean(loss)
+                if coordinator.is_master() and (step + 1) % args.log_every == 0:
+                    pbar.set_postfix({"loss": loss.item()})
+                    writer.add_scalar("loss", loss.item(), epoch * num_steps_per_epoch + step)
+
+                if args.ckpt_every > 0 and (step + 1) % args.ckpt_every == 0:
+                    coordinator.print_on_master("Saving checkpoint")
+                    coordinator.print_on_master(f"Saved checkpoint at epoch {epoch} step {step + 1}")
+
+        # later epochs are not resumed mid-epoch, so reset the sampler start index and start step
+        dataloader.sampler.set_start_index(0)
+        start_step = 0
 
     model.eval()  # important! This disables randomized embedding dropout
     # do any sampling/FID calculation/etc. with ema (or model) in eval mode ...
 
     logger.info("Done!")
-    cleanup()
 
 
 if __name__ == "__main__":
     # Default args here will train DiT-XL/2 with the hyperparameters we used in our paper (except training iters).
     parser = argparse.ArgumentParser()
-    parser.add_argument("--data-path", type=str)
+    parser.add_argument("--data-path", type=str, default="./datasets")
     parser.add_argument(
         "--plugin", type=str, default="zero2", choices=["gemini", "gemini_auto", "zero2", "zero2_cpu", "3d"]
     )
@@ -368,7 +312,7 @@
     parser.add_argument("--image-size", type=int, choices=[256, 512], default=256)
     parser.add_argument("--num-classes", type=int, default=1000)
     parser.add_argument("--epochs", type=int, default=1400)
-    parser.add_argument("--global-batch-size", type=int, default=16)
+    parser.add_argument("--batch-size", type=int, default=2)
     parser.add_argument("--global-seed", type=int, default=0)
     parser.add_argument("--vae", type=str, choices=["ema", "mse"], default="ema")  # Choice doesn't affect training
     parser.add_argument("--num-workers", type=int, default=4)