From 68dfd4a36515175e56ece6d79eba361fdc166bb6 Mon Sep 17 00:00:00 2001 From: Hagen Wierstorf Date: Wed, 14 Aug 2024 15:47:50 +0200 Subject: [PATCH] Add map argument to audb.load_table() --- audb/core/load.py | 31 +++++++++++++++++++- tests/test_load.py | 72 ++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 102 insertions(+), 1 deletion(-) diff --git a/audb/core/load.py b/audb/core/load.py index d50b32c6..faabfe89 100644 --- a/audb/core/load.py +++ b/audb/core/load.py @@ -1635,6 +1635,7 @@ def load_table( table: str, *, version: str = None, + map: typing.Dict[str, typing.Union[str, typing.Sequence[str]]] = None, pickle_tables: bool = True, cache_root: str = None, num_workers: typing.Optional[int] = 1, @@ -1654,6 +1655,15 @@ def load_table( name: name of database table: load table from database version: version of database + map: map scheme or scheme fields to column values. + For example if your table holds a column ``speaker`` with + speaker IDs, which is assigned to a scheme that contains a + dict mapping speaker IDs to age and gender entries, + ``map={'speaker': ['age', 'gender']}`` + will replace the column with two new columns that map ID + values to age and gender, respectively. + To also keep the original column with speaker IDS, you can do + ``map={'speaker': ['speaker', 'age', 'gender']}`` pickle_tables: if ``True``, tables are cached locally in their original format @@ -1688,6 +1698,20 @@ def load_table( wav/03a01Nc.wav neutral 1.00 wav/03a01Wa.wav anger 0.95 + >>> df = load_table( + ... "emodb", + ... "files", + ... version="1.4.1", + ... map={"speaker": "age"}, + ... verbose=False, + ... ) + >>> df[:3] + duration transcription age + file + wav/03a01Fa.wav 0 days 00:00:01.898250 a01 31 + wav/03a01Nc.wav 0 days 00:00:01.611250 a01 31 + wav/03a01Wa.wav 0 days 00:00:01.877812500 a01 31 + """ if version is None: version = latest_version(name) @@ -1747,4 +1771,9 @@ def load_table( ) db[_table].load(table_file) - return db[table]._df + if map is None: + df = db[table]._df + else: + df = db[table].get(map=map) + + return df diff --git a/tests/test_load.py b/tests/test_load.py index 1b0e20a2..50c4ed43 100644 --- a/tests/test_load.py +++ b/tests/test_load.py @@ -1,6 +1,8 @@ import os +import random import shutil +import numpy as np import pandas as pd import pytest @@ -44,6 +46,8 @@ def dbs(tmpdir_factory, persistent_repository, storage_format): dictionary containing root folder for each version """ + random.seed(1) + # Collect single database paths # and return them in the end paths = {} @@ -728,6 +732,74 @@ def test_load_table(version, table): assert files == expected_files +@pytest.mark.parametrize( + "version, table, map, expected", + [ + ( + "1.0.0", + "files", + None, + pd.concat( + [ + pd.Series( + index=audformat.filewise_index( + [f"audio/00{n + 1}.wav" for n in range(5)] + ), + name="speaker", + data=["adam", "adam", "eve", "eve", None], + dtype=pd.CategoricalDtype(["adam", "eve"]), + ), + pd.Series( + index=audformat.filewise_index( + [f"audio/00{n + 1}.wav" for n in range(5)] + ), + name="misc", + data=[0, 1, 1, 2, np.nan], + dtype=pd.CategoricalDtype([0, 1, 2]), + ), + ], + axis=1, + ), + ), + ( + "1.0.0", + "files", + {"misc": "emotion"}, + pd.concat( + [ + pd.Series( + index=audformat.filewise_index( + [f"audio/00{n + 1}.wav" for n in range(5)] + ), + name="speaker", + data=["adam", "adam", "eve", "eve", None], + dtype=pd.CategoricalDtype(["adam", "eve"]), + ), + pd.Series( + index=audformat.filewise_index( + [f"audio/00{n + 1}.wav" for n in range(5)] + ), + name="emotion", + data=["positive", "positive", "positive", "negative", None], + dtype=pd.CategoricalDtype(["positive", "neutral", "negative"]), + ), + ], + axis=1, + ), + ), + ], +) +def test_load_table_map(version, table, map, expected): + df = audb.load_table( + DB_NAME, + table, + version=version, + map=map, + verbose=False, + ) + pd.testing.assert_frame_equal(df, expected) + + @pytest.mark.parametrize( "version", [