From 68dfd4a36515175e56ece6d79eba361fdc166bb6 Mon Sep 17 00:00:00 2001 From: Hagen Wierstorf Date: Wed, 14 Aug 2024 15:47:50 +0200 Subject: [PATCH 1/6] Add map argument to audb.load_table() --- audb/core/load.py | 31 +++++++++++++++++++- tests/test_load.py | 72 ++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 102 insertions(+), 1 deletion(-) diff --git a/audb/core/load.py b/audb/core/load.py index d50b32c6..faabfe89 100644 --- a/audb/core/load.py +++ b/audb/core/load.py @@ -1635,6 +1635,7 @@ def load_table( table: str, *, version: str = None, + map: typing.Dict[str, typing.Union[str, typing.Sequence[str]]] = None, pickle_tables: bool = True, cache_root: str = None, num_workers: typing.Optional[int] = 1, @@ -1654,6 +1655,15 @@ def load_table( name: name of database table: load table from database version: version of database + map: map scheme or scheme fields to column values. + For example if your table holds a column ``speaker`` with + speaker IDs, which is assigned to a scheme that contains a + dict mapping speaker IDs to age and gender entries, + ``map={'speaker': ['age', 'gender']}`` + will replace the column with two new columns that map ID + values to age and gender, respectively. + To also keep the original column with speaker IDs, you can do + ``map={'speaker': ['speaker', 'age', 'gender']}`` pickle_tables: if ``True``, tables are cached locally in their original format @@ -1688,6 +1698,20 @@ def load_table( wav/03a01Nc.wav neutral 1.00 wav/03a01Wa.wav anger 0.95 + >>> df = load_table( + ... "emodb", + ... "files", + ... version="1.4.1", + ... map={"speaker": "age"}, + ... verbose=False, + ... 
) + >>> df[:3] + duration transcription age + file + wav/03a01Fa.wav 0 days 00:00:01.898250 a01 31 + wav/03a01Nc.wav 0 days 00:00:01.611250 a01 31 + wav/03a01Wa.wav 0 days 00:00:01.877812500 a01 31 + """ if version is None: version = latest_version(name) @@ -1747,4 +1771,9 @@ def load_table( ) db[_table].load(table_file) - return db[table]._df + if map is None: + df = db[table]._df + else: + df = db[table].get(map=map) + + return df diff --git a/tests/test_load.py b/tests/test_load.py index 1b0e20a2..50c4ed43 100644 --- a/tests/test_load.py +++ b/tests/test_load.py @@ -1,6 +1,8 @@ import os +import random import shutil +import numpy as np import pandas as pd import pytest @@ -44,6 +46,8 @@ def dbs(tmpdir_factory, persistent_repository, storage_format): dictionary containing root folder for each version """ + random.seed(1) + # Collect single database paths # and return them in the end paths = {} @@ -728,6 +732,74 @@ def test_load_table(version, table): assert files == expected_files +@pytest.mark.parametrize( + "version, table, map, expected", + [ + ( + "1.0.0", + "files", + None, + pd.concat( + [ + pd.Series( + index=audformat.filewise_index( + [f"audio/00{n + 1}.wav" for n in range(5)] + ), + name="speaker", + data=["adam", "adam", "eve", "eve", None], + dtype=pd.CategoricalDtype(["adam", "eve"]), + ), + pd.Series( + index=audformat.filewise_index( + [f"audio/00{n + 1}.wav" for n in range(5)] + ), + name="misc", + data=[0, 1, 1, 2, np.nan], + dtype=pd.CategoricalDtype([0, 1, 2]), + ), + ], + axis=1, + ), + ), + ( + "1.0.0", + "files", + {"misc": "emotion"}, + pd.concat( + [ + pd.Series( + index=audformat.filewise_index( + [f"audio/00{n + 1}.wav" for n in range(5)] + ), + name="speaker", + data=["adam", "adam", "eve", "eve", None], + dtype=pd.CategoricalDtype(["adam", "eve"]), + ), + pd.Series( + index=audformat.filewise_index( + [f"audio/00{n + 1}.wav" for n in range(5)] + ), + name="emotion", + data=["positive", "positive", "positive", "negative", None], + 
dtype=pd.CategoricalDtype(["positive", "neutral", "negative"]), + ), + ], + axis=1, + ), + ), + ], +) +def test_load_table_map(version, table, map, expected): + df = audb.load_table( + DB_NAME, + table, + version=version, + map=map, + verbose=False, + ) + pd.testing.assert_frame_equal(df, expected) + + @pytest.mark.parametrize( "version", [ From 61f59fa6b8989c7846597959cefedefda9ae44d2 Mon Sep 17 00:00:00 2001 From: Hagen Wierstorf Date: Wed, 14 Aug 2024 15:51:09 +0200 Subject: [PATCH 2/6] Improve docstring example --- audb/core/load.py | 1 - 1 file changed, 1 deletion(-) diff --git a/audb/core/load.py b/audb/core/load.py index faabfe89..ce461c1f 100644 --- a/audb/core/load.py +++ b/audb/core/load.py @@ -1697,7 +1697,6 @@ def load_table( wav/03a01Fa.wav happiness 0.90 wav/03a01Nc.wav neutral 1.00 wav/03a01Wa.wav anger 0.95 - >>> df = load_table( ... "emodb", ... "files", From b760c8d524da4a8fa89462ef2dc2bdd805c9692f Mon Sep 17 00:00:00 2001 From: Hagen Wierstorf Date: Wed, 14 Aug 2024 15:58:59 +0200 Subject: [PATCH 3/6] Add another example --- audb/core/load.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/audb/core/load.py b/audb/core/load.py index ce461c1f..4803b205 100644 --- a/audb/core/load.py +++ b/audb/core/load.py @@ -1685,18 +1685,20 @@ def load_table( that is not part of the database Examples: - >>> df = load_table( - ... "emodb", - ... "emotion", - ... version="1.4.1", - ... verbose=False, - ... ) + >>> df = load_table("emodb", "emotion", version="1.4.1", verbose=False) >>> df[:3] emotion emotion.confidence file wav/03a01Fa.wav happiness 0.90 wav/03a01Nc.wav neutral 1.00 wav/03a01Wa.wav anger 0.95 + >>> df = load_table("emodb", "files", version="1.4.1", verbose=False) + >>> df[:3] + duration speaker transcription + file + wav/03a01Fa.wav 0 days 00:00:01.898250 3 a01 + wav/03a01Nc.wav 0 days 00:00:01.611250 3 a01 + wav/03a01Wa.wav 0 days 00:00:01.877812500 3 a01 >>> df = load_table( ... "emodb", ... 
"files", From 264884d1f060464470f55d9de9956cd723321575 Mon Sep 17 00:00:00 2001 From: Hagen Wierstorf Date: Thu, 15 Aug 2024 08:33:11 +0200 Subject: [PATCH 4/6] Download only required misc tables --- audb/core/load.py | 29 +++++++++++++++++++++++------ 1 file changed, 23 insertions(+), 6 deletions(-) diff --git a/audb/core/load.py b/audb/core/load.py index 4803b205..fef85ef7 100644 --- a/audb/core/load.py +++ b/audb/core/load.py @@ -810,7 +810,15 @@ def _load_files( def _misc_tables_used_in_scheme( db: audformat.Database, ) -> typing.List[str]: - r"""List of misc tables that are used inside a scheme.""" + r"""List of misc tables that are used inside a scheme. + + Args: + db: database object + + Returns: + unique list of misc tables used in schemes + + """ misc_tables_used_in_scheme = [] for scheme in db.schemes.values(): if scheme.uses_table: @@ -1747,14 +1755,23 @@ def load_table( version, ) + # Find misc tables used in schemes of the requested table + scheme_misc_tables = [] + for column_id, column in db[table].columns.items(): + if column.scheme_id is not None: + scheme = db.schemes[column.scheme_id] + if scheme.uses_table: + scheme_misc_tables.append(scheme.labels) + scheme_misc_tables = audeer.unique(scheme_misc_tables) + # Load table - tables = _misc_tables_used_in_scheme(db) + [table] + tables = scheme_misc_tables + [table] for _table in tables: table_file = os.path.join(db_root, f"db.{_table}") - if not ( - os.path.exists(f"{table_file}.csv") - or os.path.exists(f"{table_file}.pkl") - ): + # `_load_files()` downloads a table + # from the backend, + # if it cannot find its corresponding csv or parquet file + if not os.path.exists(f"{table_file}.pkl"): _load_files( [_table], "table", From 3a7c9653e731ab4d87c2eb9fc283b40938833d39 Mon Sep 17 00:00:00 2001 From: Hagen Wierstorf Date: Mon, 19 Aug 2024 16:23:35 +0200 Subject: [PATCH 5/6] Update docstring --- audb/core/load.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/audb/core/load.py 
b/audb/core/load.py index fef85ef7..32507974 100644 --- a/audb/core/load.py +++ b/audb/core/load.py @@ -1755,7 +1755,7 @@ def load_table( version, ) - # Find misc tables used in schemes of the requested table + # Find only those misc tables used in schemes of the requested table scheme_misc_tables = [] for column_id, column in db[table].columns.items(): if column.scheme_id is not None: From 1e47e960538fb7597e33abb3f4b35dfd00f239a0 Mon Sep 17 00:00:00 2001 From: Hagen Wierstorf Date: Mon, 19 Aug 2024 16:25:44 +0200 Subject: [PATCH 6/6] Add comment for seed --- tests/test_load.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/test_load.py b/tests/test_load.py index 50c4ed43..e944193d 100644 --- a/tests/test_load.py +++ b/tests/test_load.py @@ -46,6 +46,7 @@ def dbs(tmpdir_factory, persistent_repository, storage_format): dictionary containing root folder for each version """ + # Fix seed for audformat.testing random.seed(1) # Collect single database paths