Skip to content

Commit

Permalink
Fix internal API use
Browse files Browse the repository at this point in the history
  • Loading branch information
flying-sheep committed Oct 24, 2024
1 parent be96642 commit 73bcecd
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 16 deletions.
8 changes: 4 additions & 4 deletions src/anndata/_core/anndata.py
Original file line number Diff line number Diff line change
Expand Up @@ -1429,13 +1429,13 @@ def to_memory(self, *, copy: bool = False) -> AnnData:
]:
attr = getattr(self, attr_name, None)
if attr is not None:
new[attr_name] = to_memory(attr, copy)
new[attr_name] = to_memory(attr, copy=copy)

if self.raw is not None:
new["raw"] = {
"X": to_memory(self.raw.X, copy),
"var": to_memory(self.raw.var, copy),
"varm": to_memory(self.raw.varm, copy),
"X": to_memory(self.raw.X, copy=copy),
"var": to_memory(self.raw.var, copy=copy),
"varm": to_memory(self.raw.varm, copy=copy),
}

if self.isbacked:
Expand Down
27 changes: 15 additions & 12 deletions src/anndata/tests/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -573,7 +573,7 @@ def assert_equal(
def assert_equal_cupy(
a: CupyArray, b: object, *, exact: bool = False, elem_name: str | None = None
):
assert_equal(b, a.get(), exact, elem_name)
assert_equal(b, a.get(), exact=exact, elem_name=elem_name)


@assert_equal.register(np.ndarray)
Expand All @@ -594,7 +594,10 @@ def assert_equal_ndarray(
# Reshaping to allow >2d arrays
assert a.shape == b.shape, format_msg(elem_name)
assert_equal(
pd.DataFrame(a.reshape(-1)), pd.DataFrame(b.reshape(-1)), exact, elem_name
pd.DataFrame(a.reshape(-1)),
pd.DataFrame(b.reshape(-1)),
exact=exact,
elem_name=elem_name,
)
else:
assert np.all(a == b), format_msg(elem_name)
Expand All @@ -617,22 +620,22 @@ def assert_equal_sparse(
elem_name: str | None = None,
):
a = asarray(a)
assert_equal(b, a, exact, elem_name=elem_name)
assert_equal(b, a, exact=exact, elem_name=elem_name)


@assert_equal.register(SpArray)
def assert_equal_sparse_array(
a: SpArray, b: object, *, exact: bool = False, elem_name: str | None = None
):
return assert_equal_sparse(a, b, exact, elem_name)
return assert_equal_sparse(a, b, exact=exact, elem_name=elem_name)


@assert_equal.register(CupySparseMatrix)
def assert_equal_cupy_sparse(
a: CupySparseMatrix, b: object, *, exact: bool = False, elem_name: str | None = None
):
a = a.toarray()
assert_equal(b, a, exact, elem_name=elem_name)
assert_equal(b, a, exact=exact, elem_name=elem_name)


@assert_equal.register(h5py.Dataset)
Expand All @@ -641,22 +644,22 @@ def assert_equal_h5py_dataset(
a: ArrayStorageType, b: object, *, exact: bool = False, elem_name: str | None = None
):
a = asarray(a)
assert_equal(b, a, exact, elem_name=elem_name)
assert_equal(b, a, exact=exact, elem_name=elem_name)


@assert_equal.register(DaskArray)
def assert_equal_dask_array(
a: DaskArray, b: object, *, exact: bool = False, elem_name: str | None = None
):
assert_equal(b, a.compute(), exact, elem_name)
assert_equal(b, a.compute(), exact=exact, elem_name=elem_name)


@assert_equal.register(pd.DataFrame)
def are_equal_dataframe(
a: pd.DataFrame, b: object, *, exact: bool = False, elem_name: str | None = None
):
if not isinstance(b, pd.DataFrame):
assert_equal(b, a, exact, elem_name) # , a.values maybe?
assert_equal(b, a, exact=exact, elem_name=elem_name) # , a.values maybe?

report_name(pd.testing.assert_frame_equal)(
a,
Expand Down Expand Up @@ -690,7 +693,7 @@ def assert_equal_mapping(
for k in a.keys():
if elem_name is None:
elem_name = ""
assert_equal(a[k], b[k], exact, f"{elem_name}/{k}")
assert_equal(a[k], b[k], exact=exact, elem_name=f"{elem_name}/{k}")


@assert_equal.register(AlignedMappingBase)
Expand Down Expand Up @@ -783,8 +786,8 @@ def fmt_name(x):

# There may be issues comparing views, since np.allclose
# can modify ArrayViews if they contain `nan`s
assert_equal(a.obs_names, b.obs_names, exact, elem_name=fmt_name("obs_names"))
assert_equal(a.var_names, b.var_names, exact, elem_name=fmt_name("var_names"))
assert_equal(a.obs_names, b.obs_names, exact=exact, elem_name=fmt_name("obs_names"))
assert_equal(a.var_names, b.var_names, exact=exact, elem_name=fmt_name("var_names"))
if not exact:
# Reorder all elements if necessary
idx = [slice(None), slice(None)]
Expand Down Expand Up @@ -813,7 +816,7 @@ def fmt_name(x):
assert_equal(
getattr(a, attr),
getattr(b, attr),
exact,
exact=exact,
elem_name=fmt_name(attr),
)

Expand Down

0 comments on commit 73bcecd

Please sign in to comment.