Add map argument to audb.load_table() #447

Merged · 6 commits · Aug 20, 2024
Changes from all commits
71 changes: 59 additions & 12 deletions audb/core/load.py
@@ -810,7 +810,15 @@ def _load_files(
def _misc_tables_used_in_scheme(
db: audformat.Database,
) -> typing.List[str]:
r"""List of misc tables that are used inside a scheme."""
r"""List of misc tables that are used inside a scheme.

Args:
db: database object

Returns:
unique list of misc tables used in schemes
ChristianGeng marked this conversation as resolved.
Show resolved Hide resolved

"""
misc_tables_used_in_scheme = []
for scheme in db.schemes.values():
if scheme.uses_table:
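
For context: a scheme "uses a table" when its labels are stored in a misc table; ``scheme.uses_table`` is then ``True`` and ``scheme.labels`` holds that misc table's ID. A rough sketch of such a setup, with made-up names and assuming the usual audformat API for misc-table-backed schemes:

import pandas as pd
import audformat

# Hypothetical database in which the "speaker" scheme draws its labels
# from a misc table of the same name (similar to emodb's speaker metadata)
db = audformat.Database("mydb")
db["speaker"] = audformat.MiscTable(
    pd.Index(["spk1", "spk2"], dtype="string", name="speaker")
)
db.schemes["speaker"] = audformat.Scheme("str", labels="speaker")

# These are the two properties the helper above relies on
assert db.schemes["speaker"].uses_table
assert db.schemes["speaker"].labels == "speaker"
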
@@ -1635,6 +1643,7 @@ def load_table(
table: str,
*,
version: str = None,
map: typing.Dict[str, typing.Union[str, typing.Sequence[str]]] = None,
pickle_tables: bool = True,
cache_root: str = None,
num_workers: typing.Optional[int] = 1,
@@ -1654,6 +1663,15 @@
name: name of database
table: load table from database
version: version of database
map: map scheme or scheme fields to column values.
For example, if your table holds a column ``speaker`` with
speaker IDs, which is assigned to a scheme that contains a
dict mapping speaker IDs to age and gender entries,
``map={'speaker': ['age', 'gender']}``
will replace the column with two new columns that map ID
values to age and gender, respectively.
To also keep the original column with speaker IDs, you can do
``map={'speaker': ['speaker', 'age', 'gender']}``
pickle_tables: if ``True``,
tables are cached locally
in their original format
@@ -1675,18 +1693,33 @@
that is not part of the database

Examples:
>>> df = load_table("emodb", "emotion", version="1.4.1", verbose=False)
>>> df[:3]
emotion emotion.confidence
file
wav/03a01Fa.wav happiness 0.90
wav/03a01Nc.wav neutral 1.00
wav/03a01Wa.wav anger 0.95
>>> df = load_table("emodb", "files", version="1.4.1", verbose=False)
>>> df[:3]
duration speaker transcription
file
wav/03a01Fa.wav 0 days 00:00:01.898250 3 a01
wav/03a01Nc.wav 0 days 00:00:01.611250 3 a01
wav/03a01Wa.wav 0 days 00:00:01.877812500 3 a01
>>> df = load_table(
... "emodb",
... "emotion",
... "files",
... version="1.4.1",
... map={"speaker": "age"},
... verbose=False,
... )
>>> df[:3]
emotion emotion.confidence
duration transcription age
file
wav/03a01Fa.wav happiness 0.90
wav/03a01Nc.wav neutral 1.00
wav/03a01Wa.wav anger 0.95
wav/03a01Fa.wav 0 days 00:00:01.898250 a01 31
wav/03a01Nc.wav 0 days 00:00:01.611250 a01 31
wav/03a01Wa.wav 0 days 00:00:01.877812500 a01 31

"""
if version is None:
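
Usage note (illustrative, not part of this diff): the ``map`` argument also accepts the form that keeps the original column, as described in the docstring above. Against emodb this would look roughly like:

import audb

# Keep the original speaker IDs and add the mapped age and gender values;
# column names follow the emodb example above, the call is illustrative only
df = audb.load_table(
    "emodb",
    "files",
    version="1.4.1",
    map={"speaker": ["speaker", "age", "gender"]},
    verbose=False,
)
# df then holds duration, transcription, speaker, age, and gender columns
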
@@ -1722,14 +1755,23 @@
version,
)

+    # Find only those misc tables used in schemes of the requested table
+    scheme_misc_tables = []
+    for column_id, column in db[table].columns.items():
+        if column.scheme_id is not None:
+            scheme = db.schemes[column.scheme_id]
+            if scheme.uses_table:
+                scheme_misc_tables.append(scheme.labels)
+    scheme_misc_tables = audeer.unique(scheme_misc_tables)
+
     # Load table
-    tables = _misc_tables_used_in_scheme(db) + [table]
+    tables = scheme_misc_tables + [table]
     for _table in tables:
         table_file = os.path.join(db_root, f"db.{_table}")
-        if not (
-            os.path.exists(f"{table_file}.csv")
-            or os.path.exists(f"{table_file}.pkl")
-        ):
+        # `_load_files()` downloads a table
+        # from the backend,
+        # if it cannot find its corresponding csv or parquet file
+        if not os.path.exists(f"{table_file}.pkl"):
             _load_files(
                 [_table],
                 "table",
@@ -1747,4 +1789,9 @@
)
db[_table].load(table_file)

-    return db[table]._df
+    if map is None:
+        df = db[table]._df
+    else:
+        df = db[table].get(map=map)
+
+    return df
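
The new return path is a thin wrapper around ``audformat.Table.get()``. A minimal sketch of the same distinction on a database that is already available locally (the path is hypothetical):

import audformat

db = audformat.Database.load("/path/to/emodb")  # hypothetical local copy

# map is None: hand back the raw dataframe of the table
df_raw = db["files"]._df

# map given: let audformat resolve the scheme labels into extra columns
df_mapped = db["files"].get(map={"speaker": "age"})
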
73 changes: 73 additions & 0 deletions tests/test_load.py
@@ -1,6 +1,8 @@
import os
import random
import shutil

import numpy as np
import pandas as pd
import pytest

@@ -44,6 +46,9 @@ def dbs(tmpdir_factory, persistent_repository, storage_format):
dictionary containing root folder for each version

"""
# Fix seed for audformat.testing
random.seed(1)

# Collect single database paths
# and return them in the end
paths = {}
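
The fixed seed matters because ``audformat.testing`` fills the generated databases with pseudo-random values, so the hard-coded ``expected`` frames in the new test below presumably depend on it. A rough sketch, assuming ``audformat.testing.create_db()`` is what the fixture builds on:

import random
import audformat.testing

# Without a fixed seed, the generated speaker/misc values could differ
# between runs and break the expectations in test_load_table_map
random.seed(1)
db = audformat.testing.create_db(minimal=False)
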
@@ -728,6 +733,74 @@ def test_load_table(version, table):
assert files == expected_files


@pytest.mark.parametrize(
    "version, table, map, expected",
    [
        (
            "1.0.0",
            "files",
            None,
            pd.concat(
                [
                    pd.Series(
                        index=audformat.filewise_index(
                            [f"audio/00{n + 1}.wav" for n in range(5)]
                        ),
                        name="speaker",
                        data=["adam", "adam", "eve", "eve", None],
                        dtype=pd.CategoricalDtype(["adam", "eve"]),
                    ),
                    pd.Series(
                        index=audformat.filewise_index(
                            [f"audio/00{n + 1}.wav" for n in range(5)]
                        ),
                        name="misc",
                        data=[0, 1, 1, 2, np.nan],
                        dtype=pd.CategoricalDtype([0, 1, 2]),
                    ),
                ],
                axis=1,
            ),
        ),
        (
            "1.0.0",
            "files",
            {"misc": "emotion"},
            pd.concat(
                [
                    pd.Series(
                        index=audformat.filewise_index(
                            [f"audio/00{n + 1}.wav" for n in range(5)]
                        ),
                        name="speaker",
                        data=["adam", "adam", "eve", "eve", None],
                        dtype=pd.CategoricalDtype(["adam", "eve"]),
                    ),
                    pd.Series(
                        index=audformat.filewise_index(
                            [f"audio/00{n + 1}.wav" for n in range(5)]
                        ),
                        name="emotion",
                        data=["positive", "positive", "positive", "negative", None],
                        dtype=pd.CategoricalDtype(["positive", "neutral", "negative"]),
                    ),
                ],
                axis=1,
            ),
        ),
    ],
)
def test_load_table_map(version, table, map, expected):
    df = audb.load_table(
        DB_NAME,
        table,
        version=version,
        map=map,
        verbose=False,
    )
    pd.testing.assert_frame_equal(df, expected)
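
To run just the new test, a pytest call along these lines should work (assuming the usual test environment and fixtures of this repository are available):

import pytest

# Select only the new map test; -k filters by test name
pytest.main(["tests/test_load.py", "-k", "test_load_table_map", "-v"])
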


@pytest.mark.parametrize(
"version",
[