Simplify roundtrip io tests (#1702)
flying-sheep authored Oct 1, 2024
1 parent e939e95 commit 8e9eb88
Showing 1 changed file with 45 additions and 46 deletions.
91 changes: 45 additions & 46 deletions tests/test_readwrite.py
@@ -3,6 +3,7 @@
 import re
 import warnings
 from contextlib import contextmanager
+from functools import partial
 from importlib.util import find_spec
 from pathlib import Path
 from string import ascii_letters
@@ -25,6 +26,7 @@
 
 if TYPE_CHECKING:
     from os import PathLike
+    from typing import Literal
 
 HERE = Path(__file__).parent
 
@@ -658,30 +660,13 @@ def random_cats(n):
     assert_equal(orig, curr)
 
 
-def test_write_string_types(tmp_path, diskfmt):
-    # https://github.com/scverse/anndata/issues/456
-    adata_pth = tmp_path / f"adata.{diskfmt}"
-
-    adata = ad.AnnData(
-        obs=pd.DataFrame(
-            np.ones((3, 2)),
-            columns=["a", np.str_("b")],
-            index=["a", "b", "c"],
-        ),
-    )
-
-    write = getattr(adata, f"write_{diskfmt}")
-    read = getattr(ad, f"read_{diskfmt}")
-
-    write(adata_pth)
-    from_disk = read(adata_pth)
-
-    assert_equal(adata, from_disk)
-
+def test_write_string_type_error(tmp_path, diskfmt):
+    adata = ad.AnnData(obs=dict(obs_names=list("abc")))
     adata.obs[b"c"] = np.zeros(3)
+
     # This should error, and tell you which key is at fault
     with pytest.raises(TypeError, match=r"writing key 'obs'") as exc_info:
-        write(adata_pth)
+        getattr(adata, f"write_{diskfmt}")(tmp_path / f"adata.{diskfmt}")
 
     assert "b'c'" in str(exc_info.value)
 
@@ -722,15 +707,39 @@ def test_zarr_chunk_X(tmp_path):
 # Round-tripping scanpy datasets
 ################################
 
 diskfmt2 = diskfmt
 
+def _do_roundtrip(
+    adata: ad.AnnData, pth: Path, diskfmt: Literal["h5ad", "zarr"]
+) -> ad.AnnData:
+    getattr(adata, f"write_{diskfmt}")(pth)
+    return getattr(ad, f"read_{diskfmt}")(pth)
+
+
+@pytest.fixture
+def roundtrip(diskfmt):
+    return partial(_do_roundtrip, diskfmt=diskfmt)
+
+
+def test_write_string_types(tmp_path, diskfmt, roundtrip):
+    # https://github.com/scverse/anndata/issues/456
+    adata_pth = tmp_path / f"adata.{diskfmt}"
+
+    adata = ad.AnnData(
+        obs=pd.DataFrame(
+            np.ones((3, 2)),
+            columns=["a", np.str_("b")],
+            index=["a", "b", "c"],
+        ),
+    )
+
+    from_disk = roundtrip(adata, adata_pth)
+
+    assert_equal(adata, from_disk)
+
 
 @pytest.mark.skipif(not find_spec("scanpy"), reason="Scanpy is not installed")
-def test_scanpy_pbmc68k(tmp_path, diskfmt, diskfmt2):
-    read1 = lambda pth: getattr(ad, f"read_{diskfmt}")(pth)
-    write1 = lambda adata, pth: getattr(adata, f"write_{diskfmt}")(pth)
-    read2 = lambda pth: getattr(ad, f"read_{diskfmt2}")(pth)
-    write2 = lambda adata, pth: getattr(adata, f"write_{diskfmt2}")(pth)
+def test_scanpy_pbmc68k(tmp_path, diskfmt, roundtrip, diskfmt2):
+    roundtrip2 = partial(_do_roundtrip, diskfmt=diskfmt2)
 
     filepth1 = tmp_path / f"test1.{diskfmt}"
     filepth2 = tmp_path / f"test2.{diskfmt2}"
@@ -745,17 +754,15 @@ def test_scanpy_pbmc68k(tmp_path, diskfmt, diskfmt2):
         warnings.simplefilter("ignore", ad.OldFormatWarning)
         pbmc = sc.datasets.pbmc68k_reduced()
 
-    write1(pbmc, filepth1)
-    from_disk1 = read1(filepth1)  # Do we read okay
-    write2(from_disk1, filepth2)  # Can we round trip
-    from_disk2 = read2(filepth2)
+    from_disk1 = roundtrip(pbmc, filepth1)  # Do we read okay
+    from_disk2 = roundtrip2(from_disk1, filepth2)  # Can we round trip
 
     assert_equal(pbmc, from_disk1)  # Not expected to be exact due to `nan`s
     assert_equal(pbmc, from_disk2)
 
 
 @pytest.mark.skipif(not find_spec("scanpy"), reason="Scanpy is not installed")
-def test_scanpy_krumsiek11(tmp_path, diskfmt):
+def test_scanpy_krumsiek11(tmp_path, diskfmt, roundtrip):
     filepth = tmp_path / f"test.{diskfmt}"
     with warnings.catch_warnings():
         warnings.filterwarnings(
@@ -769,11 +776,10 @@ def test_scanpy_krumsiek11(tmp_path, diskfmt):
     del orig.uns["highlights"]  # Can’t write int keys
     # Can’t write "string" dtype: https://github.com/scverse/anndata/issues/679
     orig.obs["cell_type"] = orig.obs["cell_type"].astype(str)
-    getattr(orig, f"write_{diskfmt}")(filepth)
     with pytest.warns(UserWarning, match=r"Observation names are not unique"):
-        read = getattr(ad, f"read_{diskfmt}")(filepth)
+        curr = roundtrip(orig, filepth)
 
-    assert_equal(orig, read, exact=True)
+    assert_equal(orig, curr, exact=True)
 
 
 # Checking if we can read legacy zarr files
@@ -808,11 +814,8 @@ def test_backwards_compat_zarr():
     assert_equal(pbmc_zarr, pbmc_orig)
 
 
-# TODO: use diskfmt fixture once zarr backend implemented
-def test_adata_in_uns(tmp_path, diskfmt):
+def test_adata_in_uns(tmp_path, diskfmt, roundtrip):
     pth = tmp_path / f"adatas_in_uns.{diskfmt}"
-    read = lambda pth: getattr(ad, f"read_{diskfmt}")(pth)
-    write = lambda adata, pth: getattr(adata, f"write_{diskfmt}")(pth)
 
     orig = gen_adata((4, 5))
     orig.uns["adatas"] = {
@@ -823,20 +826,16 @@ def test_adata_in_uns(tmp_path, diskfmt):
     another_one.raw = gen_adata((2, 7))
     orig.uns["adatas"]["b"].uns["another_one"] = another_one
 
-    write(orig, pth)
-    curr = read(pth)
+    curr = roundtrip(orig, pth)
 
     assert_equal(orig, curr)
 
 
-def test_io_dtype(tmp_path, diskfmt, dtype):
+def test_io_dtype(tmp_path, diskfmt, dtype, roundtrip):
     pth = tmp_path / f"adata_dtype.{diskfmt}"
-    read = lambda pth: getattr(ad, f"read_{diskfmt}")(pth)
-    write = lambda adata, pth: getattr(adata, f"write_{diskfmt}")(pth)
 
     orig = ad.AnnData(np.ones((5, 8), dtype=dtype))
-    write(orig, pth)
-    curr = read(pth)
+    curr = roundtrip(orig, pth)
 
     assert curr.X.dtype == dtype
 

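The pattern the commit converges on also reads well outside the diff: one helper writes an AnnData object in the parametrized on-disk format and reads it back, and a fixture binds the format so each test passes only the object and a target path. The sketch below is illustrative rather than part of the commit; `_do_roundtrip`, `roundtrip`, and the format dispatch mirror the diff, while the standalone `diskfmt` fixture and the test name are assumptions made to keep the example self-contained.

# Standalone sketch of the roundtrip-fixture pattern introduced by this commit.
# Assumptions: pytest, numpy, and anndata are installed; the `diskfmt` fixture
# and the test name below are illustrative, not taken from the repository.
from functools import partial
from pathlib import Path
from typing import Literal

import numpy as np
import pytest

import anndata as ad
from anndata.tests.helpers import assert_equal


@pytest.fixture(params=["h5ad", "zarr"])
def diskfmt(request) -> Literal["h5ad", "zarr"]:
    return request.param


def _do_roundtrip(
    adata: ad.AnnData, pth: Path, diskfmt: Literal["h5ad", "zarr"]
) -> ad.AnnData:
    # Dispatch by name to AnnData.write_h5ad/write_zarr and ad.read_h5ad/read_zarr
    getattr(adata, f"write_{diskfmt}")(pth)
    return getattr(ad, f"read_{diskfmt}")(pth)


@pytest.fixture
def roundtrip(diskfmt):
    # Bind the on-disk format so tests supply only the AnnData and a path
    return partial(_do_roundtrip, diskfmt=diskfmt)


def test_roundtrip_preserves_x(tmp_path, diskfmt, roundtrip):
    orig = ad.AnnData(np.ones((5, 8), dtype=np.float32))
    curr = roundtrip(orig, tmp_path / f"adata.{diskfmt}")
    assert_equal(orig, curr)

Replacing the per-test read/write lambdas with this single fixture is what accounts for most of the 46 deleted lines in the diff above.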