Skip to content

Commit

Permalink
Add map argument to audb.load_table()
Browse files Browse the repository at this point in the history
  • Loading branch information
hagenw committed Aug 14, 2024
1 parent 593776c commit 68dfd4a
Show file tree
Hide file tree
Showing 2 changed files with 102 additions and 1 deletion.
31 changes: 30 additions & 1 deletion audb/core/load.py
Original file line number Diff line number Diff line change
Expand Up @@ -1635,6 +1635,7 @@ def load_table(
table: str,
*,
version: str = None,
map: typing.Dict[str, typing.Union[str, typing.Sequence[str]]] = None,
pickle_tables: bool = True,
cache_root: str = None,
num_workers: typing.Optional[int] = 1,
Expand All @@ -1654,6 +1655,15 @@ def load_table(
name: name of database
table: load table from database
version: version of database
map: map scheme or scheme fields to column values.
For example if your table holds a column ``speaker`` with
speaker IDs, which is assigned to a scheme that contains a
dict mapping speaker IDs to age and gender entries,
``map={'speaker': ['age', 'gender']}``
will replace the column with two new columns that map ID
values to age and gender, respectively.
To also keep the original column with speaker IDS, you can do
``map={'speaker': ['speaker', 'age', 'gender']}``
pickle_tables: if ``True``,
tables are cached locally
in their original format
Expand Down Expand Up @@ -1688,6 +1698,20 @@ def load_table(
wav/03a01Nc.wav neutral 1.00
wav/03a01Wa.wav anger 0.95
>>> df = load_table(
... "emodb",
... "files",
... version="1.4.1",
... map={"speaker": "age"},
... verbose=False,
... )
>>> df[:3]
duration transcription age
file
wav/03a01Fa.wav 0 days 00:00:01.898250 a01 31
wav/03a01Nc.wav 0 days 00:00:01.611250 a01 31
wav/03a01Wa.wav 0 days 00:00:01.877812500 a01 31
"""
if version is None:
version = latest_version(name)
Expand Down Expand Up @@ -1747,4 +1771,9 @@ def load_table(
)
db[_table].load(table_file)

return db[table]._df
if map is None:
df = db[table]._df
else:
df = db[table].get(map=map)

return df
72 changes: 72 additions & 0 deletions tests/test_load.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import os
import random
import shutil

import numpy as np
import pandas as pd
import pytest

Expand Down Expand Up @@ -44,6 +46,8 @@ def dbs(tmpdir_factory, persistent_repository, storage_format):
dictionary containing root folder for each version
"""
random.seed(1)

# Collect single database paths
# and return them in the end
paths = {}
Expand Down Expand Up @@ -728,6 +732,74 @@ def test_load_table(version, table):
assert files == expected_files


@pytest.mark.parametrize(
"version, table, map, expected",
[
(
"1.0.0",
"files",
None,
pd.concat(
[
pd.Series(
index=audformat.filewise_index(
[f"audio/00{n + 1}.wav" for n in range(5)]
),
name="speaker",
data=["adam", "adam", "eve", "eve", None],
dtype=pd.CategoricalDtype(["adam", "eve"]),
),
pd.Series(
index=audformat.filewise_index(
[f"audio/00{n + 1}.wav" for n in range(5)]
),
name="misc",
data=[0, 1, 1, 2, np.nan],
dtype=pd.CategoricalDtype([0, 1, 2]),
),
],
axis=1,
),
),
(
"1.0.0",
"files",
{"misc": "emotion"},
pd.concat(
[
pd.Series(
index=audformat.filewise_index(
[f"audio/00{n + 1}.wav" for n in range(5)]
),
name="speaker",
data=["adam", "adam", "eve", "eve", None],
dtype=pd.CategoricalDtype(["adam", "eve"]),
),
pd.Series(
index=audformat.filewise_index(
[f"audio/00{n + 1}.wav" for n in range(5)]
),
name="emotion",
data=["positive", "positive", "positive", "negative", None],
dtype=pd.CategoricalDtype(["positive", "neutral", "negative"]),
),
],
axis=1,
),
),
],
)
def test_load_table_map(version, table, map, expected):
df = audb.load_table(
DB_NAME,
table,
version=version,
map=map,
verbose=False,
)
pd.testing.assert_frame_equal(df, expected)


@pytest.mark.parametrize(
"version",
[
Expand Down

0 comments on commit 68dfd4a

Please sign in to comment.