From db86a206ad689890104b4e0d8f4cb773cabc51b5 Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Tue, 22 Oct 2024 15:51:13 +0200 Subject: [PATCH] (chore): add tests for all axis combinations --- tests/test_views.py | 25 +++++++++++++++++++++---- 1 file changed, 21 insertions(+), 4 deletions(-) diff --git a/tests/test_views.py b/tests/test_views.py index 73a9cf585..99b1aa74f 100644 --- a/tests/test_views.py +++ b/tests/test_views.py @@ -3,6 +3,7 @@ from contextlib import ExitStack from copy import deepcopy from operator import mul +from typing import TYPE_CHECKING import joblib import numpy as np @@ -35,6 +36,9 @@ ) from anndata.utils import asarray +if TYPE_CHECKING: + from typing import Literal + IGNORE_SPARSE_EFFICIENCY_WARNING = pytest.mark.filterwarnings( "ignore:Changing the sparsity structure:scipy.sparse.SparseEfficiencyWarning" ) @@ -786,11 +790,24 @@ def test_dataframe_view_index_setting(): assert a2.obs.index.values.tolist() == ["a", "b"] -def test_ellipsis_index(adata, subset_func, matrix_type): +@pytest.mark.parametrize("axis", ["obs", "var", None]) +def test_ellipsis_index( + adata: ad.AnnData, subset_func, matrix_type, axis: Literal["obs", "var"] | None +): adata = gen_adata((10, 10), X_type=matrix_type, **GEN_ADATA_DASK_ARGS) - subset_obs_names = subset_func(adata.obs_names) - subset_ellipsis = adata[subset_obs_names, ...] - subset = adata[subset_obs_names, :] + if axis is not None: + axis_subset = subset_func(getattr(adata, f"{axis}_names")) + subset_with_ellipsis = ( + (axis_subset, Ellipsis) if axis == "obs" else (Ellipsis, axis_subset) + ) + subset_with_slice = ( + (axis_subset, slice(None)) if axis == "obs" else (slice(None), axis_subset) + ) + else: + subset_with_ellipsis = Ellipsis + subset_with_slice = (slice(None), slice(None)) + subset_ellipsis = adata[subset_with_ellipsis] + subset = adata[subset_with_slice] assert_equal(subset_ellipsis, subset)