Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Speedup caching of audbcards.Dataset #83

Merged
merged 22 commits into from
Apr 30, 2024
Merged
Show file tree
Hide file tree
Changes from 21 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion audbcards/core/datacard.py
Original file line number Diff line number Diff line change
Expand Up @@ -395,7 +395,7 @@ def _render_template(self) -> str:
template = environment.get_template("datacard.j2")

# Convert dataset object to dictionary
dataset = self.dataset.properties()
dataset = self.dataset._cached_properties()

# Add additional datacard only properties
dataset = self._expand_dataset(dataset)
Expand Down
174 changes: 117 additions & 57 deletions audbcards/core/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import pickle
import typing

import dohq_artifactory
import jinja2
import pandas as pd

Expand All @@ -18,19 +17,6 @@
from audbcards.core.utils import limit_presented_samples


def _getstate(self):
return self.name


def _setstate(self, state):
self.name = state


# Ensure we can pickle the repository
dohq_artifactory.GenericRepository.__getstate__ = _getstate
dohq_artifactory.GenericRepository.__setstate__ = _setstate


class _Dataset:
@classmethod
def create(
Expand All @@ -51,7 +37,7 @@ def create(
return obj

obj = cls(name, version, cache_root)
_ = obj.properties()
_ = obj._cached_properties()

cls._save_pickled(obj, dataset_cache_filename)
return obj
Expand All @@ -62,43 +48,41 @@ def __init__(
version: str,
cache_root: str = None,
):
self.cache_root = audeer.mkdir(audeer.path(cache_root))
self.header = audb.info.header(
name,
version=version,
load_tables=True, # ensure misc tables are loaded
)
self.deps = audb.dependencies(
name,
version=version,
verbose=False,
)
self.cache_root = audeer.mkdir(cache_root)
r"""Cache root folder."""

# Store name and version in private attributes here,
# ``self.name`` and ``self.version``
# are implemented as cached properties below
self._name = name
self._version = version
self._repository = audb.repository(name, version)
self._backend = audbackend.access(
name=self._repository.backend,
host=self._repository.host,
repository=self._repository.name,
)
if isinstance(self._backend, audbackend.Artifactory):
self._backend._use_legacy_file_structure() # pragma: nocover

# Private attributes,
# used inside corresponding properties.
self._header = self._load_header()
self._deps = self._load_dependencies()
self._repository_object = self._load_repository_object() # load before backend
self._backend = self._load_backend()

# Clean up cache
# by removing all other versions of the same dataset
# to reduce its storage size in CI runners
versions = audeer.list_dir_names(
audeer.path(self.cache_root, name),
audeer.path(cache_root, name),
basenames=True,
)
other_versions = [v for v in versions if v != version]
for other_version in other_versions:
audeer.rmdir(audeer.path(self.cache_root, name, other_version))
audeer.rmdir(cache_root, name, other_version)

def __getstate__(self):
    r"""Returns attributes to be pickled.

    Only the values of the ``functools.cached_property``
    attributes are pickled.
    Other private attributes
    (e.g. ``_backend``, ``_deps``, ``_header``, ``_repository_object``)
    are not stored,
    and are re-created lazily by the corresponding properties
    after unpickling.

    """
    return self._cached_properties()

@staticmethod
def _dataset_cache_path(name: str, version: str, cache_root: str) -> str:
r"""Generate the name of the cache file."""
cache_dir = audeer.mkdir(audeer.path(cache_root, name, version))
cache_dir = audeer.mkdir(cache_root, name, version)

cache_filename = audeer.path(
cache_dir,
Expand All @@ -123,6 +107,34 @@ def _save_pickled(obj, path: str):
with open(path, "wb") as f:
pickle.dump(obj, f, protocol=4)

@property
def backend(self) -> audbackend.Backend:
    r"""Dataset backend object.

    Re-created on first access
    when the dataset was restored from cache,
    as the backend is not pickled.

    """
    try:
        return self._backend
    except AttributeError:  # when loaded from cache
        self._backend = self._load_backend()
        return self._backend

@property
def deps(self) -> audb.Dependencies:
    r"""Dataset dependency table.

    Re-loaded on first access
    when the dataset was restored from cache.

    """
    if "_deps" not in self.__dict__:  # when loaded from cache
        self._deps = self._load_dependencies()
    return self._deps

@property
def header(self) -> audformat.Database:
    r"""Dataset header.

    Re-loaded on first access
    when the dataset was restored from cache.

    """
    try:
        return self._header
    except AttributeError:  # when loaded from cache
        self._header = self._load_header()
        return self._header

@property
def repository_object(self) -> audb.Repository:
    r"""Repository object containing dataset.

    Re-loaded on first access
    when the dataset was restored from cache.

    """
    if "_repository_object" not in self.__dict__:  # when loaded from cache
        self._repository_object = self._load_repository_object()
    return self._repository_object

@functools.cached_property
def archives(self) -> int:
r"""Number of archives of media files in dataset."""
Expand Down Expand Up @@ -228,44 +240,34 @@ def license_link(self) -> typing.Optional[str]:
@functools.cached_property
def name(self) -> str:
r"""Name of dataset."""
return self.header.name
return self._name

@functools.cached_property
def publication_date(self) -> str:
r"""Date dataset was uploaded to repository."""
path = self._backend.join("/", self.name, "db.yaml")
return self._backend.date(path, self._version)
path = self.backend.join("/", self.name, "db.yaml")
return self.backend.date(path, self.version)

@functools.cached_property
def publication_owner(self) -> str:
r"""User who uploaded dataset to repository."""
path = self._backend.join("/", self.name, "db.yaml")
return self._backend.owner(path, self._version)

def properties(self):
"""Get list of properties of the object."""
class_items = self.__class__.__dict__.items()
props = dict(
(k, getattr(self, k))
for k, v in class_items
if isinstance(v, functools.cached_property)
)
return props
path = self.backend.join("/", self.name, "db.yaml")
return self.backend.owner(path, self.version)

@functools.cached_property
def repository(self) -> str:
r"""Repository containing the dataset."""
return f"{self._repository.name}"
return f"{self.repository_object.name}"

@functools.cached_property
def repository_link(self) -> str:
r"""Link to repository in Artifactory web UI."""
# NOTE: this needs to be changed
# as we want to support different backends
return (
f"{self._repository.host}/"
f"{self.repository_object.host}/"
f"webapp/#/artifacts/browse/tree/General/"
f"{self._repository.name}/"
f"{self.repository}/"
f"{self.name}"
)

Expand All @@ -289,6 +291,18 @@ def schemes(self) -> typing.List[str]:
r"""Schemes of dataset."""
return list(self.header.schemes)

@functools.cached_property
def schemes_summary(self) -> str:
    r"""Summary of dataset schemes.

    It lists all schemes in a string,
    showing additional information
    on schemes named ``'emotion'`` and ``'speaker'``,
    e.g. ``'speaker: [age, gender, language]'``.

    Delegates the formatting
    to :func:`audbcards.core.utils.format_schemes`.

    """
    return format_schemes(self.header.schemes)

@functools.cached_property
def schemes_table(self) -> typing.List[typing.List[str]]:
"""Schemes table with name, type, min, max, labels, mappings.
Expand Down Expand Up @@ -361,6 +375,40 @@ def version(self) -> str:
r"""Version of dataset."""
return self._version

def _cached_properties(self):
    """Get list of cached properties of the object.

    Returns a dictionary
    mapping the name of every ``functools.cached_property``
    defined on the class
    to its (computed) value on this instance.
    Accessing the value triggers computation
    if it was not calculated before.

    """
    return {
        attr: getattr(self, attr)
        for attr, member in self.__class__.__dict__.items()
        if isinstance(member, functools.cached_property)
    }

def _load_backend(self) -> audbackend.Backend:
    r"""Load backend containing dataset.

    Accesses the backend
    described by the dataset's repository object,
    and switches Artifactory backends
    to the legacy file structure.

    Returns:
        backend object

    """
    backend = audbackend.access(
        name=self.repository_object.backend,
        host=self.repository_object.host,
        repository=self.repository,
    )
    if isinstance(backend, audbackend.Artifactory):
        backend._use_legacy_file_structure()  # pragma: nocover
    return backend

def _load_dependencies(self) -> audb.Dependencies:
    r"""Load dataset dependencies.

    Returns:
        dependency table of the dataset

    """
    deps = audb.dependencies(
        self.name,
        version=self.version,
        verbose=False,
    )
    return deps

def _load_header(self) -> audformat.Database:
    r"""Load dataset header.

    Returns:
        database header,
        with misc tables loaded

    """
    header = audb.info.header(
        self.name,
        version=self.version,
        load_tables=True,  # ensure misc tables are loaded
    )
    return header

def _load_repository_object(self) -> audb.Repository:
    r"""Load repository object containing dataset.

    Returns:
        repository object for this dataset name and version

    """
    repository = audb.repository(self.name, self.version)
    return repository

@functools.cached_property
def _scheme_table_columns(self) -> typing.List[str]:
"""Column names for the scheme table.
Expand Down Expand Up @@ -521,12 +569,24 @@ def __new__(
instance = _Dataset.create(name, version, cache_root=cache_root)
return instance

# Add an __init__() function,
# to allow documenting instance variables
def __init__(
    self,
    name: str,
    version: str,
    *,
    cache_root: str = None,
):
    r"""Instantiate dataset.

    Args:
        name: name of dataset
        version: version of dataset
        cache_root: cache folder;
            NOTE(review): presumably falls back to a default cache
            when ``None`` — confirm against ``audeer.mkdir()``
            and ``_Dataset.create()``

    """
    self.cache_root = audeer.mkdir(cache_root)
    r"""Cache root folder."""

# Copy attributes and methods
# to include in documentation
for prop in [
name
for name, value in inspect.getmembers(_Dataset)
if isinstance(value, functools.cached_property) and not name.startswith("_")
if not name.startswith("_") and name not in ["create"]
]:
vars()[prop] = getattr(_Dataset, prop)

Expand Down Expand Up @@ -589,7 +649,7 @@ def create_datasets_page(
dataset.short_description,
f"`{dataset.license} <{dataset.license_link}>`__",
dataset.version,
format_schemes(dataset.header.schemes),
dataset.schemes_summary,
)
for dataset in datasets
]
Expand Down
1 change: 1 addition & 0 deletions audbcards/sphinx/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import os

import sphinx
import sphinx.application

import audb
import audeer
Expand Down
1 change: 1 addition & 0 deletions tests/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
audb >=1.6.5 # for audb.Dependencies.__eq__()
audeer >=1.21.0
pytest
55 changes: 51 additions & 4 deletions tests/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def test_dataset_property_scope(tmpdir, db, request):
cache_root=dataset_cache,
)

props = [x for x in dataset.properties().keys()]
props = [x for x in dataset._cached_properties().keys()]

# should not exist in local scope
for prop in props:
Expand Down Expand Up @@ -73,7 +73,8 @@ def test_dataset(audb_cache, tmpdir, repository, db, request):
# __init__
assert dataset.name == db.name
assert dataset.version == pytest.VERSION
assert dataset._repository == repository
assert dataset.repository_object == repository
assert dataset.backend == backend
expected_header = audb.info.header(
db.name,
version=pytest.VERSION,
Expand Down Expand Up @@ -304,8 +305,8 @@ def test_cache_file_existence(self, constructor):
def test_props_equal(self, constructor):
"""Cached and uncached datasets have equal props."""
ds_uncached, ds_cached, _ = constructor
props_uncached = ds_uncached.properties()
props_cached = ds_cached.properties()
props_uncached = ds_uncached._cached_properties()
props_cached = ds_cached._cached_properties()
list_props_uncached = list(props_uncached.keys())
list_props_cached = list(props_cached.keys())
assert list_props_uncached == list_props_cached
Expand Down Expand Up @@ -364,3 +365,49 @@ def test_dataset_cache_path():
"emodb-1.2.1.pkl",
)
assert cache_path_calculated == cache_path_expected


@pytest.mark.parametrize(
    "db",
    [
        "medium_db",
    ],
)
def test_dataset_cache_loading(audb_cache, tmpdir, repository, db, request):
    """Test cached properties after loading from cache.

    We no longer store all attributes/properties
    in cache as pickle files,
    but limit ourselves to the cached properties.
    This test ensures,
    that other attributes will be re-calculated.

    """
    db = request.getfixturevalue(db)
    cache_root = audeer.mkdir(tmpdir, "cache")
    # First instantiation populates the pickle cache
    dataset = audbcards.Dataset(db.name, pytest.VERSION, cache_root=cache_root)
    del dataset
    # Second instantiation is served from the pickle cache,
    # so backend/deps/header/repository_object must be re-created lazily
    dataset = audbcards.Dataset(db.name, pytest.VERSION, cache_root=cache_root)
    # Expected values, loaded independently of the dataset object
    deps = audb.dependencies(
        db.name,
        version=pytest.VERSION,
        cache_root=audb_cache,
    )
    backend = audbackend.access(
        name=repository.backend,
        host=repository.host,
        repository=repository.name,
    )
    header = audb.info.header(
        db.name,
        version=pytest.VERSION,
        load_tables=True,
        cache_root=audb_cache,
    )
    assert dataset.backend == backend
    assert dataset.deps == deps
    # The dataset header is a not fully loaded `audformat.Database` object,
    # so we cannot directly use `audformat.Database.__eq__()`
    # to compare it.
    assert str(dataset.header) == str(header)
    assert dataset.repository_object == repository
Loading