From 6ee522d5fc9cf5eef5e6d526c822491ad737ec85 Mon Sep 17 00:00:00 2001 From: Alec Koumjian Date: Mon, 8 Jan 2024 12:01:16 -0500 Subject: [PATCH] Passes linkage ids as ray references for remote iod and od workers to avoid massive data duplication --- setup.cfg | 2 +- thor/orbits/iod.py | 54 +++++++++++++++++++++++++++++++++++++++++----- thor/orbits/od.py | 48 ++++++++++++++++++++++++++++++++++++++--- 3 files changed, 95 insertions(+), 9 deletions(-) diff --git a/setup.cfg b/setup.cfg index 4d2d09cd..038a3106 100644 --- a/setup.cfg +++ b/setup.cfg @@ -31,7 +31,7 @@ setup_requires = wheel setuptools_scm >= 6.0 install_requires = - adam-core @ git+https://github.com/B612-Asteroid-Institute/adam_core@ef8ee48976dbf9c70580c166de4cc4fd6195fa36#egg=adam_core + adam-core @ git+https://github.com/B612-Asteroid-Institute/adam_core@11292b1edb9d0183bf23f7c29ea9a4744162aa95#egg=adam_core astropy >= 5.3.1 astroquery difi diff --git a/thor/orbits/iod.py b/thor/orbits/iod.py index 9fd3c5ee..4a516c5e 100644 --- a/thor/orbits/iod.py +++ b/thor/orbits/iod.py @@ -1,7 +1,7 @@ import logging import time from itertools import combinations -from typing import Literal, Optional, Tuple, Type, Union +from typing import Iterable, Literal, Optional, Sequence, Tuple, Type, Union import numpy as np import numpy.typing as npt @@ -11,7 +11,7 @@ import ray from adam_core.coordinates.residuals import Residuals from adam_core.propagator import PYOORB, Propagator -from adam_core.propagator.utils import _iterate_chunks +from adam_core.propagator.utils import _iterate_chunk_indices, _iterate_chunks from adam_core.ray_cluster import initialize_use_ray from ..clusters import ClusterMembers @@ -180,7 +180,45 @@ def iod_worker( return iod_orbits, iod_orbit_members -iod_worker_remote = ray.remote(iod_worker) +@ray.remote +def iod_worker_remote( + linkage_ids: Union[npt.NDArray[np.str_], ray.ObjectRef], + linkage_members_indices: Tuple[int, int], + observations: Union[Observations, ray.ObjectRef], + linkage_members: Union[ClusterMembers, FittedOrbitMembers, ray.ObjectRef], + min_obs: int = 6, + min_arc_length: float = 1.0, + contamination_percentage: float = 0.0, + rchi2_threshold: float = 200, + observation_selection_method: Literal[ + "combinations", "first+middle+last", "thirds" + ] = "combinations", + linkage_id_col: str = "cluster_id", + iterate: bool = False, + light_time: bool = True, + propagator: Type[Propagator] = PYOORB, + propagator_kwargs: dict = {}, +) -> Tuple[FittedOrbits, FittedOrbitMembers]: + + + # Select linkage ids from linkage_members_indices + linkage_id_chunk = linkage_ids[linkage_members_indices[0] : linkage_members_indices[1]] + return iod_worker( + linkage_id_chunk, + observations, + linkage_members, + min_obs=min_obs, + min_arc_length=min_arc_length, + contamination_percentage=contamination_percentage, + rchi2_threshold=rchi2_threshold, + observation_selection_method=observation_selection_method, + linkage_id_col=linkage_id_col, + iterate=iterate, + light_time=light_time, + propagator=propagator, + propagator_kwargs=propagator_kwargs, + ) + iod_worker_remote.options(num_returns=1, num_cpus=1) @@ -587,6 +625,11 @@ def initial_orbit_determination( use_ray = initialize_use_ray(num_cpus=max_processes) if use_ray: refs_to_free = [] + + linkage_ids_ref = ray.put(linkage_ids) + refs_to_free.append(linkage_ids_ref) + logger.info("Placed linkage IDs in the object store.") + if linkage_members_ref is None: linkage_members_ref = ray.put(linkage_members) refs_to_free.append(linkage_members_ref) @@ -598,10 +641,11 @@ def initial_orbit_determination( logger.info("Placed observations in the object store.") futures = [] - for linkage_id_chunk in _iterate_chunks(linkage_ids, chunk_size): + for linkage_id_chunk_indices in _iterate_chunk_indices(linkage_ids, chunk_size): futures.append( iod_worker_remote.remote( - linkage_id_chunk, + linkage_ids_ref, + linkage_id_chunk_indices, observations_ref, linkage_members_ref, min_obs=min_obs, diff --git a/thor/orbits/od.py b/thor/orbits/od.py index 9d06f6db..2f2cfce6 100644 --- a/thor/orbits/od.py +++ b/thor/orbits/od.py @@ -11,6 +11,7 @@ from adam_core.coordinates.residuals import Residuals from adam_core.orbits import Orbits from adam_core.propagator import PYOORB, _iterate_chunks +from adam_core.propagator.utils import _iterate_chunk_indices from adam_core.ray_cluster import initialize_use_ray from scipy.linalg import solve @@ -80,7 +81,42 @@ def od_worker( return od_orbits, od_orbit_members -od_worker_remote = ray.remote(od_worker) +@ray.remote +def od_worker_remote( + orbit_ids: npt.NDArray[np.str_], + orbit_ids_indices: Tuple[int, int], + orbits: FittedOrbits, + orbit_members: FittedOrbitMembers, + observations: Observations, + rchi2_threshold: float = 100, + min_obs: int = 5, + min_arc_length: float = 1.0, + contamination_percentage: float = 0.0, + delta: float = 1e-6, + max_iter: int = 20, + method: Literal["central", "finite"] = "central", + fit_epoch: bool = False, + propagator: Literal["PYOORB"] = "PYOORB", + propagator_kwargs: dict = {}, +) -> Tuple[FittedOrbits, FittedOrbitMembers]: + orbit_ids_chunk = orbit_ids[orbit_ids_indices[0] : orbit_ids_indices[1]] + return od_worker( + orbit_ids_chunk, + orbits, + orbit_members, + observations, + rchi2_threshold=rchi2_threshold, + min_obs=min_obs, + min_arc_length=min_arc_length, + contamination_percentage=contamination_percentage, + delta=delta, + max_iter=max_iter, + method=method, + fit_epoch=fit_epoch, + propagator=propagator, + propagator_kwargs=propagator_kwargs, + ) + od_worker_remote.options(num_returns=1, num_cpus=1) @@ -619,6 +655,11 @@ def differential_correction( use_ray = initialize_use_ray(num_cpus=max_processes) if use_ray: refs_to_free = [] + + orbit_ids_ref = ray.put(orbit_ids) + refs_to_free.append(orbit_ids_ref) + logger.info("Placed orbit IDs in the object store.") + if orbits_ref is None: orbits_ref = ray.put(orbits) refs_to_free.append(orbits_ref) @@ -635,10 +676,11 @@ def differential_correction( logger.info("Placed observations in the object store.") futures = [] - for orbit_ids_chunk in _iterate_chunks(orbit_ids, chunk_size): + for orbit_ids_indices in _iterate_chunk_indices(orbit_ids, chunk_size): futures.append( od_worker_remote.remote( - orbit_ids_chunk, + orbit_ids_ref, + orbit_ids_indices, orbits_ref, orbit_members_ref, observations_ref,