From 7e361d0079ec0a82f4b9f11799583bcbb16dd946 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Fran=C3=A7ois=20Rozet?=
Date: Mon, 26 Aug 2024 22:40:54 +0200
Subject: [PATCH] =?UTF-8?q?=E2=9C=A8=20Add=20EDM=20plugin=20(#9)?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 README.md                     |   2 +-
 azula/plugins/adm/__init__.py |  35 +++++-------
 azula/plugins/adm/database.py |  24 +++++---
 azula/plugins/edm/__init__.py | 103 ++++++++++++++++++++++++++++++++++
 azula/plugins/edm/database.py |  28 +++++++++
 docs/index.rst                |   2 +-
 docs/tutorials/guidance.ipynb |   2 +-
 7 files changed, 164 insertions(+), 32 deletions(-)
 create mode 100644 azula/plugins/edm/__init__.py
 create mode 100644 azula/plugins/edm/database.py

diff --git a/README.md b/README.md
index cb02eb5..ab601ae 100644
--- a/README.md
+++ b/README.md
@@ -76,7 +76,7 @@ from azula.plugins import adm
 from azula.sample import DDIMSampler
 
 # Download weights from openai/guided-diffusion
-denoiser = adm.load_model("imagenet_256x256_uncond")
+denoiser = adm.load_model("imagenet_256x256")
 
 # Generate a batch of 4 images
 sampler = DDIMSampler(denoiser, steps=64).cuda()
diff --git a/azula/plugins/adm/__init__.py b/azula/plugins/adm/__init__.py
index 0513ddb..97658f6 100644
--- a/azula/plugins/adm/__init__.py
+++ b/azula/plugins/adm/__init__.py
@@ -8,11 +8,13 @@
 
     git clone https://github.com/openai/guided-diffusion
 
-and add it to your Python path.
+and add it to your Python path before importing the plugin.
 
 .. code-block:: python
 
     import sys; sys.path.append("path/to/guided-diffusion")
+    ...
+    from azula.plugins import adm
 
 References:
     | Diffusion Models Beat GANs on Image Synthesis (Dhariwal et al., 2021)
@@ -24,7 +26,6 @@
     "ImprovedDenoiser",
     "list_models",
     "load_model",
-    "make_model",
 ]
 
 import numpy as np
@@ -163,7 +164,7 @@ def list_models() -> List[str]:
     return database.keys()
 
 
-def load_model(key: str, **kwargs) -> ImprovedDenoiser:
+def load_model(key: str, **kwargs) -> GaussianDenoiser:
     r"""Loads a pre-trained ADM model.
 
     Arguments:
@@ -193,31 +194,21 @@ def make_model(
     # Schedule
     schedule_name: str = "linear",
     timesteps: int = 1000,
+    # Data
+    image_channels: int = 3,
+    image_size: int = 64,
     # Backbone
     attention_resolutions: Set[int] = {32, 16, 8},  # noqa: B006
     channel_mult: Sequence[int] = (1, 2, 3, 4),
     dropout: float = 0.0,
-    image_size: int = 64,
     num_channels: int = 128,
     num_classes: int = None,
     num_heads: int = 1,
     num_head_channels: int = 64,
     num_res_blocks: int = 3,
     **kwargs,
-) -> ImprovedDenoiser:
-    r"""Builds an ADM model.
-
-    Arguments:
-        learned_var: Whether the variance term is learned or not.
-        schedule_name: The beta schedule name.
-        timesteps: The number of schedule time steps.
-
-    The remaining arguments are for the :class:`guided_diffusion.unet.UNetModel`
-    backbone.
-
-    Returns:
-        A denoiser.
- """ +) -> GaussianDenoiser: + r"""Instantiates an ADM denoiser.""" kwargs.setdefault("resblock_updown", True) kwargs.setdefault("use_fp16", False) @@ -229,8 +220,8 @@ def make_model( backbone = FlattenWrapper( wrappee=unet.UNetModel( image_size=image_size, - in_channels=3, - out_channels=6 if learned_var else 3, + in_channels=image_channels, + out_channels=2 * image_channels if learned_var else image_channels, model_channels=num_channels, channel_mult=channel_mult, num_classes=num_classes, @@ -241,12 +232,12 @@ def make_model( dropout=dropout, **kwargs, ), - shape=(3, image_size, image_size), + shape=(image_channels, image_size, image_size), ) schedule = BetaSchedule(name=schedule_name, steps=timesteps) - return ImprovedDenoiser(backbone=backbone, schedule=schedule) + return ImprovedDenoiser(backbone, schedule) # fmt: off diff --git a/azula/plugins/adm/database.py b/azula/plugins/adm/database.py index b4993d8..b007f99 100644 --- a/azula/plugins/adm/database.py +++ b/azula/plugins/adm/database.py @@ -20,14 +20,15 @@ def keys() -> List[str]: URLS = { - "imagenet_64x64": "https://openaipublic.blob.core.windows.net/diffusion/jul-2021/64x64_diffusion.pt", - "imagenet_256x256": "https://openaipublic.blob.core.windows.net/diffusion/jul-2021/256x256_diffusion.pt", - "imagenet_256x256_uncond": "https://openaipublic.blob.core.windows.net/diffusion/jul-2021/256x256_diffusion_uncond.pt", + "imagenet_64x64_cond": "https://openaipublic.blob.core.windows.net/diffusion/jul-2021/64x64_diffusion.pt", + "imagenet_256x256": "https://openaipublic.blob.core.windows.net/diffusion/jul-2021/256x256_diffusion_uncond.pt", + "imagenet_256x256_cond": "https://openaipublic.blob.core.windows.net/diffusion/jul-2021/256x256_diffusion.pt", + "imagenet_512x512_cond": "https://openaipublic.blob.core.windows.net/diffusion/jul-2021/512x512_diffusion.pt", "ffhq_256x256": "https://drive.google.com/uc?id=1BGwhRWUoguF-D8wlZ65tf227gp3cDUDh", } CONFIGS = { - "imagenet_64x64": { + "imagenet_64x64_cond": { "schedule_name": "cosine", "attention_resolutions": {32, 16, 8}, "channel_mult": (1, 2, 3, 4), @@ -44,16 +45,25 @@ def keys() -> List[str]: "channel_mult": (1, 1, 2, 2, 4, 4), "image_size": 256, "num_channels": 256, - "num_classes": 1000, + "num_classes": None, "num_head_channels": 64, "num_res_blocks": 2, }, - "imagenet_256x256_uncond": { + "imagenet_256x256_cond": { "attention_resolutions": {32, 16, 8}, "channel_mult": (1, 1, 2, 2, 4, 4), "image_size": 256, "num_channels": 256, - "num_classes": None, + "num_classes": 1000, + "num_head_channels": 64, + "num_res_blocks": 2, + }, + "imagenet_512x512_cond": { + "attention_resolutions": {32, 16, 8}, + "channel_mult": (0.5, 1, 1, 2, 2, 4, 4), + "image_size": 256, + "num_channels": 256, + "num_classes": 1000, "num_head_channels": 64, "num_res_blocks": 2, }, diff --git a/azula/plugins/edm/__init__.py b/azula/plugins/edm/__init__.py new file mode 100644 index 0000000..e3c8156 --- /dev/null +++ b/azula/plugins/edm/__init__.py @@ -0,0 +1,103 @@ +r"""Elucidated diffusion model (EDM) plugin. + +This plugin depends on the `dnnlib`, `torch_utils` and `training` modules in the +`NVlabs/edm `_ repository. To use it, clone the +repository to you machine + +.. code-block:: console + + git clone https://github.com/NVlabs/edm + +and add it to your Python path before importing the plugin. + +.. code-block:: python + + import sys; sys.path.append("path/to/edm") + ... 
+    from azula.plugins import edm
+
+References:
+    | Elucidating the Design Space of Diffusion-Based Generative Models (Karras et al., 2022)
+    | https://arxiv.org/abs/2206.00364
+"""
+
+__all__ = [
+    "ElucidatedDenoiser",
+    "list_models",
+    "load_model",
+]
+
+import pickle
+import re
+import torch.nn as nn
+
+from azula.denoise import Gaussian, GaussianDenoiser
+from azula.hub import download
+from azula.nn.utils import FlattenWrapper
+from azula.noise import VESchedule
+from torch import Tensor
+from typing import List, Optional
+
+# isort: split
+from . import database
+
+
+class ElucidatedDenoiser(GaussianDenoiser):
+    r"""Creates an elucidated denoiser.
+
+    Arguments:
+        backbone: A noise conditional network.
+        schedule: A variance exploding (VE) schedule.
+    """
+
+    def __init__(self, backbone: nn.Module, schedule: Optional[VESchedule] = None):
+        super().__init__()
+
+        self.backbone = backbone
+
+        if schedule is None:
+            self.schedule = VESchedule()
+        else:
+            self.schedule = schedule
+
+    def forward(self, x_t: Tensor, t: Tensor, **kwargs) -> Gaussian:
+        _, sigma_t = self.schedule(t)  # alpha_t = 1
+
+        mean = self.backbone(x_t, sigma_t.squeeze(-1), **kwargs)
+        var = sigma_t**2 / (1 + sigma_t**2)
+
+        return Gaussian(mean=mean, var=var)
+
+
+def list_models() -> List[str]:
+    r"""Returns the list of available pre-trained models."""
+
+    return database.keys()
+
+
+def load_model(key: str) -> GaussianDenoiser:
+    r"""Loads a pre-trained EDM model.
+
+    Arguments:
+        key: The pre-trained model key.
+
+    Returns:
+        A pre-trained denoiser.
+    """
+
+    url = database.get(key)
+    filename = download(url)
+
+    with open(filename, "rb") as f:
+        model = pickle.load(f)["ema"]
+        model.eval()
+
+    image_size = re.search(r"(\d+)x(\d+)", key).groups()
+    image_size = map(int, image_size)
+
+    return ElucidatedDenoiser(
+        backbone=FlattenWrapper(
+            wrappee=model,
+            shape=(3, *image_size),
+        ),
+    )
diff --git a/azula/plugins/edm/database.py b/azula/plugins/edm/database.py
new file mode 100644
index 0000000..1014034
--- /dev/null
+++ b/azula/plugins/edm/database.py
@@ -0,0 +1,28 @@
+r"""Pre-trained models database."""
+
+from typing import List
+
+
+def get(key: str) -> str:
+    r"""Returns the URL of a pre-trained model.
+
+    Arguments:
+        key: The pre-trained model key.
+ """ + + return URLS[key] + + +def keys() -> List[str]: + r"""Returns the list of available pre-trained models.""" + + return list(URLS.keys()) + + +URLS = { + "cifar10_32x32": "https://nvlabs-fi-cdn.nvidia.com/edm/pretrained/edm-cifar10-32x32-uncond-ve.pkl", + "cifar10_32x32_cond": "https://nvlabs-fi-cdn.nvidia.com/edm/pretrained/edm-cifar10-32x32-cond-ve.pkl", + "afhq_64x64": "https://nvlabs-fi-cdn.nvidia.com/edm/pretrained/edm-afhqv2-64x64-uncond-ve.pkl", + "ffhq_64x64": "https://nvlabs-fi-cdn.nvidia.com/edm/pretrained/edm-ffhq-64x64-uncond-ve.pkl", + "imagenet_64x64_cond": "https://nvlabs-fi-cdn.nvidia.com/edm/pretrained/edm-imagenet-64x64-cond-adm.pkl", +} diff --git a/docs/index.rst b/docs/index.rst index c9d903e..29ff9b1 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -82,7 +82,7 @@ Alternatively, Azula's plugin interface allows to load pre-trained models and us from azula.sample import DDIMSampler # Download weights from openai/guided-diffusion - denoiser = adm.load_model("imagenet_256x256_uncond") + denoiser = adm.load_model("imagenet_256x256") # Generate a batch of 4 images sampler = DDIMSampler(denoiser, steps=64).cuda() diff --git a/docs/tutorials/guidance.ipynb b/docs/tutorials/guidance.ipynb index 3ae6d70..482f4cd 100644 --- a/docs/tutorials/guidance.ipynb +++ b/docs/tutorials/guidance.ipynb @@ -77,7 +77,7 @@ } ], "source": [ - "denoiser = adm.load_model(\"imagenet_256x256_uncond\").to(device)" + "denoiser = adm.load_model(\"imagenet_256x256\").to(device)" ] }, {