Centralize download utilities #6

Merged (4 commits) on Aug 25, 2024
93 changes: 93 additions & 0 deletions azula/hub.py
@@ -0,0 +1,93 @@
r"""Utilities for downloading models."""

__all__ = [
    "get_dir",
    "set_dir",
    "download",
]

import gdown
import hashlib
import os
import re
import sys
import torch

from typing import Optional

AZULA_HUB: str = os.path.expanduser("~/.cache/azula/hub")


def get_dir() -> str:
    r"""Returns the cache directory used for storing models & weights."""

    return AZULA_HUB


def set_dir(cache_dir: str):
    r"""Sets the cache directory used for storing models & weights."""

    global AZULA_HUB

    cache_dir = os.path.expanduser(cache_dir)
    cache_dir = os.path.abspath(cache_dir)

    AZULA_HUB = cache_dir


def download(
    url: str,
    filename: Optional[str] = None,
    hash_prefix: Optional[str] = None,
    quiet: bool = False,
) -> str:
    r"""Downloads data at a given URL to a local file.

    Arguments:
        url: A URL. Google Drive URLs are supported.
        filename: A local file name. If :py:`None`, the sanitized URL is used instead.
            If a file with the same name exists, the download is skipped.
        hash_prefix: The expected hash prefix of the file, formatted as `"alg:prefix"`.
        quiet: Whether to suppress messages and progress bars in the terminal or not.
    """

    if filename is None:
        filename = re.sub("[ /\\\\|?%*:'\"<>]+", ".", url)
        filename = os.path.join(get_dir(), filename)
    else:
        filename = os.path.expanduser(filename)
        filename = os.path.abspath(filename)

    if os.path.exists(filename):
        if not quiet:
            print(f"Skipping download as {filename} already exists.", file=sys.stderr)
    else:
        if not quiet:
            print(f"Downloading {url} to {filename}", file=sys.stderr)

        if "drive.google" in url:
            gdown.download(url, filename, quiet=quiet)
        else:
            torch.hub.download_url_to_file(url, filename, progress=not quiet)

    if hash_prefix is not None:
        alg, prefix = hash_prefix.split(":")
        digest = hashlib.new(alg)

        with open(filename, "rb") as f:  # adapted from hashlib.file_digest
            buffer = bytearray(2**20)  # reusable 1MB buffer
            view = memoryview(buffer)
            while True:
                size = f.readinto(buffer)
                if size == 0:  # end of file
                    break
                digest.update(view[:size])

        hex_hash = digest.hexdigest()

        assert hex_hash.startswith(prefix), (
            f"The hash of the downloaded file ({alg}:{hex_hash}) does not match "
            f"the expected hash prefix ({alg}:{prefix})."
        )

    return filename
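
For reference, a minimal usage sketch of the new azula.hub API, reusing the LICENSE URL and hash prefix from the tests further down. Whether the cache directory must exist beforehand is an assumption here, hence the explicit os.makedirs.

import os

from azula.hub import download, get_dir, set_dir

# Optionally redirect the cache (defaults to ~/.cache/azula/hub).
set_dir("~/.cache/azula/hub")
os.makedirs(get_dir(), exist_ok=True)

# Fetch a file; the download is skipped if the target already exists, and the
# optional hash prefix guards against corrupted downloads.
path = download(
    url="https://raw.githubusercontent.com/probabilists/azula/master/LICENSE",
    hash_prefix="sha256:c8adb00fadb8f4bf",  # same prefix as in the tests below
)

print(path)  # absolute path inside the cache directory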
2 changes: 1 addition & 1 deletion azula/nn/utils.py
@@ -1,4 +1,4 @@
r"""Module helpers."""
r"""Utilities for modules and networks."""

import torch
import torch.nn as nn
22 changes: 4 additions & 18 deletions azula/plugins/adm/__init__.py
@@ -31,18 +31,14 @@
import torch
import torch.nn as nn

-from azula.debug import RaiseMock
from azula.denoise import Gaussian, GaussianDenoiser
+from azula.hub import download
from azula.nn.utils import FlattenWrapper
from azula.noise import Schedule
+from azula.plugins.utils import RaiseMock
from torch import LongTensor, Tensor
from typing import List, Sequence, Set, Tuple

-try:
-    from gdown import cached_download
-except ImportError as e:
-    cached_download = RaiseMock(name="gdown.cached_download", error=e)
-
try:
    from guided_diffusion import unet  # type: ignore
except ImportError as e:
@@ -172,7 +168,7 @@ def load_model(key: str, **kwargs) -> ImprovedDenoiser:

    Arguments:
        key: The pre-trained model key.
-        kwargs: Keyword arguments passed to :func:`torch.hub.load`.
+        kwargs: Keyword arguments passed to :func:`torch.load`.

    Returns:
        A pre-trained denoiser.
@@ -182,17 +178,7 @@ def load_model(key: str, **kwargs) -> ImprovedDenoiser:
    kwargs.setdefault("weights_only", True)

    url, config = database.get(key)
-
-    if "drive.google" in url:
-        state = torch.load(
-            f=cached_download(url=url),
-            **kwargs,
-        )
-    else:
-        state = torch.hub.load_state_dict_from_url(
-            url=url,
-            **kwargs,
-        )
+    state = torch.load(download(url), **kwargs)

    denoiser = make_model(**config)
    denoiser.backbone.wrappee.load_state_dict(state)
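
With the centralized helper, loading pre-trained weights from a plugin reduces to the call below. The model key is a hypothetical example; the available keys depend on the plugin's database.

from azula.plugins.adm import load_model

# "imagenet_256x256" is a hypothetical key; see the plugin's database for real entries.
denoiser = load_model("imagenet_256x256", map_location="cpu")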
7 changes: 3 additions & 4 deletions azula/debug.py → azula/plugins/utils.py
@@ -1,11 +1,10 @@
r"""Utilities for debugging."""
r"""Utilities for plugins."""

from typing import Any
from unittest.mock import Mock


class RaiseMock(Mock):
r"""Creates an object that raises an error whenever it or its children are called.
r"""Creates an object that raises an error whenever it or its attributes are called.

Arguments:
error: The error to be raised.
Expand All @@ -14,5 +13,5 @@ class RaiseMock(Mock):
def __init__(self, error: Exception, **kwargs):
super().__init__(side_effect=error, **kwargs)

def _get_child_mock(self, **kwargs: Any) -> Mock:
def _get_child_mock(self, **kwargs) -> Mock:
return super()._get_child_mock(error=self.side_effect, **kwargs)
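
The pattern served by RaiseMock is the one visible in the adm imports above: an optional dependency is swapped for a RaiseMock at import time, so the original ImportError only surfaces when the dependency is actually used. A minimal sketch, with a hypothetical optional package:

from azula.plugins.utils import RaiseMock

try:
    import some_backbone  # hypothetical optional dependency
except ImportError as e:
    some_backbone = RaiseMock(name="some_backbone", error=e)

# Importing the plugin still succeeds; the ImportError is raised lazily,
# only if some_backbone (or one of its attributes) is called.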
1 change: 1 addition & 0 deletions docs/api.rst
@@ -9,6 +9,7 @@ API

azula.denoise
azula.guidance
+azula.hub
azula.linalg
azula.nn
azula.noise
1 change: 1 addition & 0 deletions requirements.txt
@@ -1,4 +1,5 @@
einops>=0.7.0
+gdown>=5.1.0
numpy>=1.20.0
torch>=1.12.0
torchvision>=0.13
67 changes: 67 additions & 0 deletions tests/test_hub.py
@@ -0,0 +1,67 @@
r"""Tests for the azula.hub module."""

import os
import pytest

from azula.hub import download, get_dir, set_dir


def test_default_dir():
    default_dir = get_dir()

    assert isinstance(default_dir, str)

    os.makedirs(default_dir, exist_ok=True)


def test_set_dir(tmp_path):
    set_dir(tmp_path)
    cache_dir = get_dir()

    assert isinstance(cache_dir, str)
    assert os.path.samefile(cache_dir, tmp_path)


def test_download(tmp_path):
    # Set cache dir
    set_dir(tmp_path)

    # With filename
    download(
        url="https://raw.githubusercontent.com/probabilists/azula/master/LICENSE",
        filename=tmp_path / "LICENSE",
    )

    with open(tmp_path / "LICENSE") as f:
        text = f.read()

    assert "MIT License" in text
    assert "The Probabilists" in text

    # Without filename
    filename = download(
        url="https://raw.githubusercontent.com/probabilists/azula/master/LICENSE",
    )

    assert os.path.samefile(os.path.dirname(filename), tmp_path)

    with open(filename) as f:
        text = f.read()

    assert "MIT License" in text
    assert "The Probabilists" in text

    # Hash prefix
    download(
        url="https://raw.githubusercontent.com/probabilists/azula/master/LICENSE",
        hash_prefix="sha256:c8adb00fadb8f4bf",
    )

    with pytest.raises(AssertionError):
        download(
            url="https://raw.githubusercontent.com/probabilists/azula/master/LICENSE",
            hash_prefix="sha256:abcdefghijklmnop",
        )

# TODO (francois-rozet)
# Find a URL to test Google Drive download
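
As an aside, the hash prefix expected by download can be reproduced with the standard library alone. A minimal sketch, where the file path is a placeholder:

import hashlib

with open("LICENSE", "rb") as f:  # placeholder path to a downloaded file
    digest = hashlib.sha256(f.read()).hexdigest()

print("sha256:" + digest[:16])  # prefix in the "alg:prefix" format used above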