diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 584248dc..a373542f 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -2,8 +2,8 @@ fail_fast: false default_language_version: python: python3 default_stages: - - commit - - push + - pre-commit + - pre-push minimum_pre_commit_version: 3.0.0 repos: - repo: https://github.com/pre-commit/mirrors-mypy @@ -13,7 +13,7 @@ repos: additional_dependencies: [numpy>=1.25.0] files: ^src - repo: https://github.com/psf/black - rev: 24.8.0 + rev: 24.10.0 hooks: - id: black additional_dependencies: [toml] @@ -29,7 +29,7 @@ repos: additional_dependencies: [toml] args: [--order-by-type] - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.6.0 + rev: v5.0.0 hooks: - id: check-merge-conflict - id: check-ast @@ -42,12 +42,12 @@ repos: - id: check-yaml - id: check-toml - repo: https://github.com/asottile/pyupgrade - rev: v3.17.0 + rev: v3.18.0 hooks: - id: pyupgrade args: [--py3-plus, --py38-plus, --keep-runtime-typing] - repo: https://github.com/asottile/blacken-docs - rev: 1.18.0 + rev: 1.19.0 hooks: - id: blacken-docs additional_dependencies: [black==23.1.0] @@ -63,7 +63,7 @@ repos: - id: doc8 - repo: https://github.com/astral-sh/ruff-pre-commit # Ruff version. - rev: v0.6.4 + rev: v0.6.9 hooks: - id: ruff args: [--fix, --exit-non-zero-on-fix] diff --git a/pyproject.toml b/pyproject.toml index 20ac87c6..2fd66c9a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -47,7 +47,7 @@ dependencies = [ "numpy>=1.20.0", "scipy>=1.7.0", "pandas>=2.0.1", - "networkx>=2.6.3", + "networkx>=3.2", # https://github.com/scverse/scanpy/issues/2411 "matplotlib>=3.5.0", "anndata>=0.9.1", diff --git a/src/moscot/backends/ott/solver.py b/src/moscot/backends/ott/solver.py index 5a66d47a..dba12b5a 100644 --- a/src/moscot/backends/ott/solver.py +++ b/src/moscot/backends/ott/solver.py @@ -533,7 +533,7 @@ def _prepare( # type: ignore[override] ) -> Tuple[MultiLoader, MultiLoader]: train_loaders = [] validate_loaders = [] - seed = kwargs.get("seed", None) + seed = kwargs.get("seed") is_aligned = kwargs.get("is_aligned", False) if train_size == 1.0: for sample_pair in sample_pairs: diff --git a/src/moscot/datasets.py b/src/moscot/datasets.py index 1dc3c381..fb34a573 100644 --- a/src/moscot/datasets.py +++ b/src/moscot/datasets.py @@ -565,7 +565,9 @@ def _get_random_trees( assert len(leaf_names[i]) == n_leaves trees = [] for tree_idx in range(n_trees): - G = nx.random_tree(n_initial_nodes, seed=seed, create_using=nx.DiGraph) + tempG = nx.random_labeled_tree(n_initial_nodes, seed=seed) + G = nx.DiGraph() + G.add_edges_from(tempG.edges) leaves = [x for x in G.nodes() if G.out_degree(x) == 0 and G.in_degree(x) == 1] inner_nodes = list(set(G.nodes()) - set(leaves)) leaves_updated = leaves.copy()