diff --git a/audb/core/load.py b/audb/core/load.py index d50b32c6..32507974 100644 --- a/audb/core/load.py +++ b/audb/core/load.py @@ -810,7 +810,15 @@ def _load_files( def _misc_tables_used_in_scheme( db: audformat.Database, ) -> typing.List[str]: - r"""List of misc tables that are used inside a scheme.""" + r"""List of misc tables that are used inside a scheme. + + Args: + db: database object + + Returns: + unique list of misc tables used in schemes + + """ misc_tables_used_in_scheme = [] for scheme in db.schemes.values(): if scheme.uses_table: @@ -1635,6 +1643,7 @@ def load_table( table: str, *, version: str = None, + map: typing.Dict[str, typing.Union[str, typing.Sequence[str]]] = None, pickle_tables: bool = True, cache_root: str = None, num_workers: typing.Optional[int] = 1, @@ -1654,6 +1663,15 @@ def load_table( name: name of database table: load table from database version: version of database + map: map scheme or scheme fields to column values. + For example if your table holds a column ``speaker`` with + speaker IDs, which is assigned to a scheme that contains a + dict mapping speaker IDs to age and gender entries, + ``map={'speaker': ['age', 'gender']}`` + will replace the column with two new columns that map ID + values to age and gender, respectively. + To also keep the original column with speaker IDS, you can do + ``map={'speaker': ['speaker', 'age', 'gender']}`` pickle_tables: if ``True``, tables are cached locally in their original format @@ -1675,18 +1693,33 @@ def load_table( that is not part of the database Examples: + >>> df = load_table("emodb", "emotion", version="1.4.1", verbose=False) + >>> df[:3] + emotion emotion.confidence + file + wav/03a01Fa.wav happiness 0.90 + wav/03a01Nc.wav neutral 1.00 + wav/03a01Wa.wav anger 0.95 + >>> df = load_table("emodb", "files", version="1.4.1", verbose=False) + >>> df[:3] + duration speaker transcription + file + wav/03a01Fa.wav 0 days 00:00:01.898250 3 a01 + wav/03a01Nc.wav 0 days 00:00:01.611250 3 a01 + wav/03a01Wa.wav 0 days 00:00:01.877812500 3 a01 >>> df = load_table( ... "emodb", - ... "emotion", + ... "files", ... version="1.4.1", + ... map={"speaker": "age"}, ... verbose=False, ... ) >>> df[:3] - emotion emotion.confidence + duration transcription age file - wav/03a01Fa.wav happiness 0.90 - wav/03a01Nc.wav neutral 1.00 - wav/03a01Wa.wav anger 0.95 + wav/03a01Fa.wav 0 days 00:00:01.898250 a01 31 + wav/03a01Nc.wav 0 days 00:00:01.611250 a01 31 + wav/03a01Wa.wav 0 days 00:00:01.877812500 a01 31 """ if version is None: @@ -1722,14 +1755,23 @@ def load_table( version, ) + # Find only those misc tables used in schemes of the requested table + scheme_misc_tables = [] + for column_id, column in db[table].columns.items(): + if column.scheme_id is not None: + scheme = db.schemes[column.scheme_id] + if scheme.uses_table: + scheme_misc_tables.append(scheme.labels) + scheme_misc_tables = audeer.unique(scheme_misc_tables) + # Load table - tables = _misc_tables_used_in_scheme(db) + [table] + tables = scheme_misc_tables + [table] for _table in tables: table_file = os.path.join(db_root, f"db.{_table}") - if not ( - os.path.exists(f"{table_file}.csv") - or os.path.exists(f"{table_file}.pkl") - ): + # `_load_files()` downloads a table + # from the backend, + # if it cannot find its corresponding csv or parquet file + if not os.path.exists(f"{table_file}.pkl"): _load_files( [_table], "table", @@ -1747,4 +1789,9 @@ def load_table( ) db[_table].load(table_file) - return db[table]._df + if map is None: + df = db[table]._df + else: + df = db[table].get(map=map) + + return df diff --git a/tests/test_load.py b/tests/test_load.py index 1b0e20a2..e944193d 100644 --- a/tests/test_load.py +++ b/tests/test_load.py @@ -1,6 +1,8 @@ import os +import random import shutil +import numpy as np import pandas as pd import pytest @@ -44,6 +46,9 @@ def dbs(tmpdir_factory, persistent_repository, storage_format): dictionary containing root folder for each version """ + # Fix seed for audformat.testing + random.seed(1) + # Collect single database paths # and return them in the end paths = {} @@ -728,6 +733,74 @@ def test_load_table(version, table): assert files == expected_files +@pytest.mark.parametrize( + "version, table, map, expected", + [ + ( + "1.0.0", + "files", + None, + pd.concat( + [ + pd.Series( + index=audformat.filewise_index( + [f"audio/00{n + 1}.wav" for n in range(5)] + ), + name="speaker", + data=["adam", "adam", "eve", "eve", None], + dtype=pd.CategoricalDtype(["adam", "eve"]), + ), + pd.Series( + index=audformat.filewise_index( + [f"audio/00{n + 1}.wav" for n in range(5)] + ), + name="misc", + data=[0, 1, 1, 2, np.nan], + dtype=pd.CategoricalDtype([0, 1, 2]), + ), + ], + axis=1, + ), + ), + ( + "1.0.0", + "files", + {"misc": "emotion"}, + pd.concat( + [ + pd.Series( + index=audformat.filewise_index( + [f"audio/00{n + 1}.wav" for n in range(5)] + ), + name="speaker", + data=["adam", "adam", "eve", "eve", None], + dtype=pd.CategoricalDtype(["adam", "eve"]), + ), + pd.Series( + index=audformat.filewise_index( + [f"audio/00{n + 1}.wav" for n in range(5)] + ), + name="emotion", + data=["positive", "positive", "positive", "negative", None], + dtype=pd.CategoricalDtype(["positive", "neutral", "negative"]), + ), + ], + axis=1, + ), + ), + ], +) +def test_load_table_map(version, table, map, expected): + df = audb.load_table( + DB_NAME, + table, + version=version, + map=map, + verbose=False, + ) + pd.testing.assert_frame_equal(df, expected) + + @pytest.mark.parametrize( "version", [