Skip to content

Commit

Permalink
Ensure correct boolean dtype in misc table index (#431)
Browse files Browse the repository at this point in the history
* Ensure correct boolean dtype in misc table index

* Remove unneeded code

* Improve comment

* Add dtype tests for table column

* Improve docstrings

* Extend docstring

* Fix return type

* Add another test
  • Loading branch information
hagenw authored Jun 11, 2024
1 parent 1876c84 commit c813739
Show file tree
Hide file tree
Showing 5 changed files with 419 additions and 12 deletions.
17 changes: 15 additions & 2 deletions audformat/core/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -359,8 +359,21 @@ def to_audformat_dtype(dtype: typing.Union[str, typing.Type]) -> str:
return define.DataType.OBJECT


def to_pandas_dtype(dtype: str) -> str:
r"""Convert audformat to pandas dtype."""
def to_pandas_dtype(dtype: str) -> typing.Optional[str]:
r"""Convert audformat to pandas dtype.
We use ``"Int64"`` instead of ``"int64"``,
and ``"boolean"`` instead of ``"bool"``
to allow for nullable entries,
e.g. ``[0, 2, <NA>]``.
Args:
dtype: audformat dtype
Returns:
pandas dtype
"""
if dtype == define.DataType.BOOL:
return "boolean"
elif dtype == define.DataType.DATE:
Expand Down
6 changes: 4 additions & 2 deletions audformat/core/table.py
Original file line number Diff line number Diff line change
Expand Up @@ -1072,8 +1072,10 @@ def __init__(
if isinstance(index, pd.MultiIndex) and index.nlevels == 1:
index = index.get_level_values(0)

# Ensure integers are always stored as Int64
index = utils._maybe_convert_int_dtype(index)
# Ensure integers are stored as Int64,
# and bool values as boolean,
# compare audformat.core.common.to_pandas_dtype()
index = utils._maybe_convert_pandas_dtype(index)

levels = utils._levels(index)
if not all(levels) or len(levels) > len(set(levels)):
Expand Down
38 changes: 31 additions & 7 deletions audformat/core/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -531,7 +531,7 @@ def difference(
if not objs:
return pd.Index([])

objs = [_maybe_convert_int_dtype(obj) for obj in objs]
objs = [_maybe_convert_pandas_dtype(obj) for obj in objs]

if len(objs) == 1:
return objs[0]
Expand Down Expand Up @@ -846,7 +846,7 @@ def intersect(
if not objs:
return pd.Index([])

objs = [_maybe_convert_int_dtype(obj) for obj in objs]
objs = [_maybe_convert_pandas_dtype(obj) for obj in objs]

if len(objs) == 1:
return _alike_index(objs[0])
Expand Down Expand Up @@ -1910,7 +1910,7 @@ def union(
if not objs:
return pd.Index([])

objs = [_maybe_convert_int_dtype(obj) for obj in objs]
objs = [_maybe_convert_pandas_dtype(obj) for obj in objs]

if len(objs) == 1:
return objs[0]
Expand Down Expand Up @@ -2083,19 +2083,43 @@ def _maybe_convert_filewise_index(
return objs


def _maybe_convert_pandas_dtype(
    index: pd.Index,
) -> pd.Index:
    r"""Ensure desired pandas dtypes.

    Applies the following conversions
    to every level of the index:

    * integer -> Int64
    * bool -> boolean

    Levels with any other dtype are left untouched.

    Args:
        index: index object

    Returns:
        index object

    """
    levels = _levels(index)
    level_dtypes = _dtypes(index)

    # Ensure integers are stored as Int64
    int_dtypes = {
        level: "Int64"
        for level, dtype in zip(levels, level_dtypes)
        if pd.api.types.is_integer_dtype(dtype)
    }
    # Ensure bool values are stored as boolean
    bool_dtypes = {
        level: "boolean"
        for level, dtype in zip(levels, level_dtypes)
        if pd.api.types.is_bool_dtype(dtype)
    }
    # Merge both mappings into a single conversion request,
    # using a dedicated name
    # instead of rebinding the list of current level dtypes
    converted_dtypes = {**int_dtypes, **bool_dtypes}

    return set_index_dtypes(index, converted_dtypes)


def _maybe_convert_single_level_multi_index(
Expand Down
184 changes: 184 additions & 0 deletions tests/test_column.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,190 @@ def test_access():
assert column.rater is None


@pytest.mark.parametrize(
    "column_values, column_dtype, expected_pandas_dtype, expected_audformat_dtype",
    [
        # Empty columns
        (
            [],
            None,
            "object",
            audformat.define.DataType.OBJECT,
        ),
        (
            [],
            bool,
            "boolean",
            audformat.define.DataType.BOOL,
        ),
        (
            [],
            "boolean",
            "boolean",
            audformat.define.DataType.BOOL,
        ),
        (
            [],
            "datetime64[ns]",
            "datetime64[ns]",
            audformat.define.DataType.DATE,
        ),
        (
            [],
            float,
            "float64",
            audformat.define.DataType.FLOAT,
        ),
        (
            [],
            int,
            "Int64",
            audformat.define.DataType.INTEGER,
        ),
        (
            [],
            "int64",
            "Int64",
            audformat.define.DataType.INTEGER,
        ),
        (
            [],
            "Int64",
            "Int64",
            audformat.define.DataType.INTEGER,
        ),
        (
            [],
            str,
            "object",
            audformat.define.DataType.OBJECT,
        ),
        (
            [],
            "string",
            "string",
            audformat.define.DataType.STRING,
        ),
        (
            [],
            "timedelta64[ns]",
            "timedelta64[ns]",
            audformat.define.DataType.TIME,
        ),
        # Columns containing values
        (
            [0],
            "datetime64[ns]",
            "datetime64[ns]",
            audformat.define.DataType.DATE,
        ),
        (
            [0.0],
            None,
            "float64",
            audformat.define.DataType.FLOAT,
        ),
        (
            [0],
            None,
            "Int64",
            audformat.define.DataType.INTEGER,
        ),
        # ``np.nan`` is used instead of ``np.NaN``,
        # as the latter alias was removed in NumPy 2.0
        (
            [np.nan],
            "Int64",
            "Int64",
            audformat.define.DataType.INTEGER,
        ),
        (
            [0, np.nan],
            "Int64",
            "Int64",
            audformat.define.DataType.INTEGER,
        ),
        (
            ["0"],
            None,
            "object",
            audformat.define.DataType.OBJECT,
        ),
        (
            [0],
            "timedelta64[ns]",
            "timedelta64[ns]",
            audformat.define.DataType.TIME,
        ),
        (
            [True],
            None,
            "boolean",
            audformat.define.DataType.BOOL,
        ),
        (
            [True, False],
            bool,
            "boolean",
            audformat.define.DataType.BOOL,
        ),
        (
            [True, False],
            "boolean",
            "boolean",
            audformat.define.DataType.BOOL,
        ),
    ],
)
def test_dtype(
    tmpdir,
    column_values,
    column_dtype,
    expected_pandas_dtype,
    expected_audformat_dtype,
):
    r"""Test table columns have correct dtype.

    Ensures that a dataframe column,
    associated with a table,
    has the dtype,
    which corresponds to the scheme of the column.
    The check is repeated
    after storing the database as CSV
    and loading it again.

    Args:
        tmpdir: pytest tmpdir fixture
        column_values: values assigned to the column
        column_dtype: pandas dtype of values assigned to column.
            If ``None``, the values are passed as ``"object"``
        expected_pandas_dtype: pandas dtype of column after assignment
        expected_audformat_dtype: audformat dtype corresponding
            to the expected pandas dtype.
            This is assigned to the scheme of the column

    """
    y = pd.Series(column_values, dtype=column_dtype or "object")

    # One file entry per column value
    index_values = [f"f{n}" for n in range(len(column_values))]
    index = audformat.filewise_index(index_values)

    db = audformat.testing.create_db(minimal=True)
    db["table"] = audformat.Table(index)
    db.schemes["column"] = audformat.Scheme(expected_audformat_dtype)
    db["table"]["column"] = audformat.Column(scheme_id="column")
    db["table"]["column"].set(y.values)

    assert db["table"]["column"].scheme.dtype == expected_audformat_dtype
    assert db["table"].df["column"].dtype == expected_pandas_dtype

    # Store and load table
    db_root = tmpdir.join("db")
    db.save(db_root, storage_format="csv")
    db_new = audformat.Database.load(db_root)

    assert db_new["table"]["column"].scheme.dtype == expected_audformat_dtype
    assert db_new["table"].df["column"].dtype == expected_pandas_dtype


def test_exceptions():
column = audformat.Column()
with pytest.raises(RuntimeError):
Expand Down
Loading

0 comments on commit c813739

Please sign in to comment.