diff --git a/thor/checkpointing.py b/thor/checkpointing.py index 27ed2e46..93a0b187 100644 --- a/thor/checkpointing.py +++ b/thor/checkpointing.py @@ -43,7 +43,7 @@ class Config: stage: Literal["cluster_and_link"] filtered_observations: Union[Observations, ray.ObjectRef] - transformed_detections: TransformedDetections + transformed_detections: Union[TransformedDetections, ray.ObjectRef] class InitialOrbitDetermination(pydantic.BaseModel): @@ -51,9 +51,9 @@ class Config: arbitrary_types_allowed = True stage: Literal["initial_orbit_determination"] - filtered_observations: Observations - clusters: Clusters - cluster_members: ClusterMembers + filtered_observations: Union[Observations, ray.ObjectRef] + clusters: Union[Clusters, ray.ObjectRef] + cluster_members: Union[ClusterMembers, ray.ObjectRef] class DifferentialCorrection(pydantic.BaseModel): @@ -61,9 +61,9 @@ class Config: arbitrary_types_allowed = True stage: Literal["differential_correction"] - filtered_observations: Observations - iod_orbits: FittedOrbits - iod_orbit_members: FittedOrbitMembers + filtered_observations: Union[Observations, ray.ObjectRef] + iod_orbits: Union[FittedOrbits, ray.ObjectRef] + iod_orbit_members: Union[FittedOrbitMembers, ray.ObjectRef] class RecoverOrbits(pydantic.BaseModel): @@ -71,9 +71,9 @@ class Config: arbitrary_types_allowed = True stage: Literal["recover_orbits"] - filtered_observations: Observations - od_orbits: FittedOrbits - od_orbit_members: FittedOrbitMembers + filtered_observations: Union[Observations, ray.ObjectRef] + od_orbits: Union[FittedOrbits, ray.ObjectRef] + od_orbit_members: Union[FittedOrbitMembers, ray.ObjectRef] class Complete(pydantic.BaseModel): @@ -81,8 +81,8 @@ class Config: arbitrary_types_allowed = True stage: Literal["complete"] - recovered_orbits: FittedOrbits - recovered_orbit_members: FittedOrbitMembers + recovered_orbits: Union[FittedOrbits, ray.ObjectRef] + recovered_orbit_members: Union[FittedOrbitMembers, ray.ObjectRef] CheckpointData = Annotated[ 
diff --git a/thor/clusters.py b/thor/clusters.py index ef4c3d38..8fbae0c7 100644 --- a/thor/clusters.py +++ b/thor/clusters.py @@ -1,7 +1,7 @@ import logging import time import uuid -from typing import List, Literal, Optional, Tuple +from typing import List, Literal, Optional, Tuple, Union import numba import numpy as np @@ -31,7 +31,7 @@ "ClusterMembers", ] -logger = logging.getLogger("thor") +logger = logging.getLogger(__name__) class Clusters(qv.Table): @@ -585,7 +585,7 @@ def cluster_velocity_worker( def cluster_and_link( - observations: TransformedDetections, + observations: Union[TransformedDetections, ray.ObjectRef], vx_range: List[float] = [-0.1, 0.1], vy_range: List[float] = [-0.1, 0.1], vx_bins: int = 100, @@ -673,6 +673,10 @@ def cluster_and_link( logger.info("Max sample distance: {}".format(radius)) logger.info("Minimum samples: {}".format(min_obs)) + if isinstance(observations, ray.ObjectRef): + observations = ray.get(observations) + logger.info("Retrieved observations from the object store.") + clusters_list = [] cluster_members_list = [] if len(observations) > 0: @@ -690,14 +694,22 @@ def cluster_and_link( if max_processes is None or max_processes > 1: if not ray.is_initialized(): - ray.init(address="auto") + logger.info( + f"Ray is not initialized. Initializing with {max_processes}..." 
+ ) + ray.init(num_cpus=max_processes) # Put all arrays (which can be large) in ray's # local object store ahead of time - obs_ids_oid = ray.put(obs_ids) - theta_x_oid = ray.put(theta_x) - theta_y_oid = ray.put(theta_y) - dt_oid = ray.put(dt) + obs_ids_ref = ray.put(obs_ids) + theta_x_ref = ray.put(theta_x) + theta_y_ref = ray.put(theta_y) + dt_ref = ray.put(dt) + refs_to_free = [obs_ids_ref, theta_x_ref, theta_y_ref, dt_ref] + logger.info("Placed gnomonic coordinate arrays in the object store.") + # TODO: transformed detections are already in the object store so we might + # want to instead pass references to those rather than extract arrays + # from them and put them in the object store again. futures = [] for vxi_chunk, vyi_chunk in zip( @@ -707,10 +719,10 @@ def cluster_and_link( cluster_velocity_remote.remote( vxi_chunk, vyi_chunk, - obs_ids_oid, - theta_x_oid, - theta_y_oid, - dt_oid, + obs_ids_ref, + theta_x_ref, + theta_y_ref, + dt_ref, radius=radius, min_obs=min_obs, min_arc_length=min_arc_length, @@ -724,6 +736,11 @@ def cluster_and_link( clusters_list.append(result[0]) cluster_members_list.append(result[1]) + ray.internal.free(refs_to_free) + logger.info( + f"Removed {len(refs_to_free)} references from the object store." + ) + else: for vxi_chunk, vyi_chunk in zip( diff --git a/thor/main.py b/thor/main.py index 4755d805..c2feb06b 100644 --- a/thor/main.py +++ b/thor/main.py @@ -30,8 +30,8 @@ def initialize_use_ray(config: Config) -> bool: if config.max_processes is None or config.max_processes > 1: # Initialize ray if not ray.is_initialized(): - logger.debug( - f"Ray is not initialized. Initializing with {config.max_processes}..." + logger.info( + f"Ray is not initialized. Initializing with {config.max_processes} cpus..." 
) ray.init(num_cpus=config.max_processes) @@ -134,12 +134,15 @@ def link_test_orbit( use_ray = initialize_use_ray(config) + refs_to_free = [] if ( use_ray and observations is not None and not isinstance(observations, ray.ObjectRef) ): observations = ray.put(observations) + refs_to_free.append(observations) + logger.info("Placed observations in the object store.") checkpoint = load_initial_checkpoint_values(test_orbit_directory) logger.info(f"Starting at stage: {checkpoint.stage}") @@ -159,6 +162,12 @@ def link_test_orbit( ) if checkpoint.stage == "filter_observations": + if use_ray: + if not isinstance(observations, ray.ObjectRef): + observations = ray.put(observations) + refs_to_free.append(observations) + logger.info("Placed observations in the object store.") + filtered_observations = filter_observations( observations, test_orbit, config, filters ) @@ -176,18 +185,31 @@ def link_test_orbit( path=(filtered_observations_path,), ) + if use_ray: + if not isinstance(filtered_observations, ray.ObjectRef): + filtered_observations = ray.put(filtered_observations) + refs_to_free.append(filtered_observations) + logger.info("Placed filtered observations in the object store.") + checkpoint = create_checkpoint_data( "range_and_transform", filtered_observations=filtered_observations, ) - # Observations are no longer needed, so we can delete them + # Observations are no longer needed. If we are using ray + # lets explicitly free the memory. 
+ if use_ray and isinstance(observations, ray.ObjectRef): + ray.internal.free([observations]) + logger.info("Removed observations from the object store.") del observations if checkpoint.stage == "range_and_transform": filtered_observations = checkpoint.filtered_observations - if use_ray and not isinstance(filtered_observations, ray.ObjectRef): - filtered_observations = ray.put(filtered_observations) + if use_ray: + if not isinstance(filtered_observations, ray.ObjectRef): + filtered_observations = ray.put(filtered_observations) + refs_to_free.append(filtered_observations) + logger.info("Placed filtered observations in the object store.") # Range and transform the observations transformed_detections = range_and_transform( @@ -211,21 +233,26 @@ def link_test_orbit( path=(transformed_detections_path,), ) + if use_ray: + if not isinstance(filtered_observations, ray.ObjectRef): + filtered_observations = ray.put(filtered_observations) + refs_to_free.append(filtered_observations) + logger.info("Placed filtered observations in the object store.") + if not isinstance(transformed_detections, ray.ObjectRef): + transformed_detections = ray.put(transformed_detections) + refs_to_free.append(transformed_detections) + logger.info("Placed transformed detections in the object store.") + checkpoint = create_checkpoint_data( "cluster_and_link", filtered_observations=filtered_observations, transformed_detections=transformed_detections, ) - # TODO: ray support for the rest of the pipeline has not yet been implemented - # so we will convert the ray objects to regular objects for now - if use_ray: - if isinstance(checkpoint.filtered_observations, ray.ObjectRef): - checkpoint.filtered_observations = ray.get(filtered_observations) - if checkpoint.stage == "cluster_and_link": - transformed_detections = checkpoint.transformed_detections filtered_observations = checkpoint.filtered_observations + transformed_detections = checkpoint.transformed_detections + # Run clustering clusters, 
cluster_members = cluster_and_link( transformed_detections, @@ -258,6 +285,20 @@ def link_test_orbit( path=(clusters_path, cluster_members_path), ) + if use_ray: + if not isinstance(filtered_observations, ray.ObjectRef): + filtered_observations = ray.put(filtered_observations) + refs_to_free.append(filtered_observations) + logger.info("Placed filtered observations in the object store.") + if not isinstance(clusters, ray.ObjectRef): + clusters = ray.put(clusters) + refs_to_free.append(clusters) + logger.info("Placed clusters in the object store.") + if not isinstance(cluster_members, ray.ObjectRef): + cluster_members = ray.put(cluster_members) + refs_to_free.append(cluster_members) + logger.info("Placed cluster members in the object store.") + checkpoint = create_checkpoint_data( "initial_orbit_determination", filtered_observations=filtered_observations, @@ -306,6 +347,20 @@ def link_test_orbit( path=(iod_orbits_path, iod_orbit_members_path), ) + if use_ray: + if not isinstance(filtered_observations, ray.ObjectRef): + filtered_observations = ray.put(filtered_observations) + refs_to_free.append(filtered_observations) + logger.info("Placed filtered observations in the object store.") + if not isinstance(iod_orbits, ray.ObjectRef): + iod_orbits = ray.put(iod_orbits) + refs_to_free.append(iod_orbits) + logger.info("Placed initial orbits in the object store.") + if not isinstance(iod_orbit_members, ray.ObjectRef): + iod_orbit_members = ray.put(iod_orbit_members) + refs_to_free.append(iod_orbit_members) + logger.info("Placed initial orbit members in the object store.") + checkpoint = create_checkpoint_data( "differential_correction", filtered_observations=filtered_observations, @@ -314,9 +369,10 @@ def link_test_orbit( ) if checkpoint.stage == "differential_correction": + filtered_observations = checkpoint.filtered_observations iod_orbits = checkpoint.iod_orbits iod_orbit_members = checkpoint.iod_orbit_members - filtered_observations = checkpoint.filtered_observations + # 
Run differential correction od_orbits, od_orbit_members = differential_correction( iod_orbits, @@ -354,6 +410,24 @@ def link_test_orbit( path=(od_orbits_path, od_orbit_members_path), ) + if use_ray: + if not isinstance(filtered_observations, ray.ObjectRef): + filtered_observations = ray.put(filtered_observations) + refs_to_free.append(filtered_observations) + logger.info("Placed filtered observations in the object store.") + if not isinstance(od_orbits, ray.ObjectRef): + od_orbits = ray.put(od_orbits) + refs_to_free.append(od_orbits) + logger.info( + "Placed differentially corrected orbits in the object store." + ) + if not isinstance(od_orbit_members, ray.ObjectRef): + od_orbit_members = ray.put(od_orbit_members) + refs_to_free.append(od_orbit_members) + logger.info( + "Placed differentially corrected orbit members in the object store." + ) + checkpoint = create_checkpoint_data( "recover_orbits", filtered_observations=filtered_observations, @@ -362,9 +436,10 @@ def link_test_orbit( ) if checkpoint.stage == "recover_orbits": + filtered_observations = checkpoint.filtered_observations od_orbits = checkpoint.od_orbits od_orbit_members = checkpoint.od_orbit_members - filtered_observations = checkpoint.filtered_observations + # Run arc extension recovered_orbits, recovered_orbit_members = merge_and_extend_orbits( od_orbits, @@ -387,6 +462,12 @@ def link_test_orbit( observations_chunk_size=100000, ) + if use_ray and len(refs_to_free) > 0: + ray.internal.free(refs_to_free) + logger.info( + f"Removed {len(refs_to_free)} references from the object store." 
+ ) + recovered_orbits_path = None recovered_orbit_members_path = None if test_orbit_directory is not None: diff --git a/thor/observations/filters.py b/thor/observations/filters.py index 54b07251..6740ab4d 100644 --- a/thor/observations/filters.py +++ b/thor/observations/filters.py @@ -77,7 +77,7 @@ class ObservationFilter(abc.ABC): @abc.abstractmethod def apply( self, - observations: "Observations", + observations: Union["Observations", ray.ObjectRef], test_orbit: TestOrbits, max_processes: Optional[int] = 1, ) -> "Observations": @@ -149,27 +149,36 @@ def apply( logger.info("Applying TestOrbitRadiusObservationFilter...") logger.info(f"Using radius = {self.radius:.5f} deg") + if isinstance(observations, ray.ObjectRef): + observations_ref = observations + observations = ray.get(observations) + logger.info("Retrieved observations from the object store.") + else: + observations_ref = None + # Generate an ephemeris for every observer time/location in the dataset ephemeris = test_orbit.generate_ephemeris_from_observations(observations) filtered_observations_list = [] if max_processes is None or max_processes > 1: if not ray.is_initialized(): - logger.debug( + logger.info( f"Ray is not initialized. Initializing with {max_processes}..." 
) ray.init(num_cpus=max_processes) - if isinstance(observations, ray.ObjectRef): - observations_ref = observations - observations = ray.get(observations_ref) - else: + refs_to_free = [] + if observations_ref is None: observations_ref = ray.put(observations) + refs_to_free.append(observations_ref) + logger.info("Placed observations in the object store.") - if isinstance(ephemeris, ray.ObjectRef): - ephemeris_ref = ephemeris - else: + if not isinstance(ephemeris, ray.ObjectRef): ephemeris_ref = ray.put(ephemeris) + refs_to_free.append(ephemeris_ref) + logger.info("Placed ephemeris in the object store.") + else: + ephemeris_ref = ephemeris state_ids = observations.state_id.unique().sort() futures = [] @@ -187,6 +196,12 @@ def apply( finished, futures = ray.wait(futures, num_returns=1) filtered_observations_list.append(ray.get(finished[0])) + if len(refs_to_free) > 0: + ray.internal.free(refs_to_free) + logger.info( + f"Removed {len(refs_to_free)} references from the object store." + ) + else: state_ids = observations.state_id.unique().sort() for state_id in state_ids: @@ -270,7 +285,7 @@ def _within_radius( def filter_observations( - observations: Observations, + observations: Union[Observations, ray.ObjectRef], test_orbit: TestOrbits, config: Config, filters: Optional[List[ObservationFilter]] = None, @@ -304,12 +319,31 @@ def filter_observations( # By default we always filter by radius from the predicted position of the test orbit filters = [TestOrbitRadiusObservationFilter(radius=config.cell_radius)] + use_ray = config.max_processes is None or config.max_processes > 1 + refs_to_free = [] + if use_ray: + if not isinstance(observations, ray.ObjectRef): + observations = ray.put(observations) + refs_to_free.append(observations) + logger.info("Placed observations in the object store.") + filtered_observations = observations for filter_i in filters: + if use_ray and not isinstance(filtered_observations, ray.ObjectRef): + filtered_observations = 
ray.put(filtered_observations) + refs_to_free.append(filtered_observations) + logger.info("Placed filtered observations in the object store.") + filtered_observations = filter_i.apply( filtered_observations, test_orbit, config.max_processes ) + # We are done filtering so lets free up the references that were added within + # the scope of this function + if len(refs_to_free) > 0: + ray.internal.free(refs_to_free) + logger.info(f"Removed {len(refs_to_free)} references from the object store.") + # Defragment the observations if len(filtered_observations) > 0: filtered_observations = qv.defragment(filtered_observations) diff --git a/thor/orbits/attribution.py b/thor/orbits/attribution.py index bea4fb01..b6083b03 100644 --- a/thor/orbits/attribution.py +++ b/thor/orbits/attribution.py @@ -1,6 +1,6 @@ import logging import time -from typing import Literal, Optional, Tuple +from typing import Literal, Optional, Tuple, Union import numpy as np import numpy.typing as npt @@ -105,7 +105,7 @@ def drop_coincident_attributions( def attribution_worker( orbit_ids: npt.NDArray[np.str_], observation_indices: npt.NDArray[np.int64], - orbits: FittedOrbits, + orbits: Union[Orbits, FittedOrbits], observations: Observations, radius: float = 1 / 3600, propagator: Literal["PYOORB"] = "PYOORB", @@ -116,6 +116,9 @@ def attribution_worker( else: raise ValueError(f"Invalid propagator '{propagator}'.") + if isinstance(orbits, FittedOrbits): + orbits = orbits.to_orbits() + # Select the orbits and observations for this batch observations = observations.take(observation_indices) orbits = orbits.apply_mask(pc.is_in(orbits.orbit_id, orbit_ids)) @@ -233,29 +236,64 @@ def attribution_worker( def attribute_observations( - orbits: Orbits, - observations: Observations, + orbits: Union[Orbits, FittedOrbits, ray.ObjectRef], + observations: Union[Observations, ray.ObjectRef], radius: float = 5 / 3600, propagator: Literal["PYOORB"] = "PYOORB", propagator_kwargs: dict = {}, orbits_chunk_size: int = 10, 
observations_chunk_size: int = 100000, max_processes: Optional[int] = 1, + orbit_ids: Optional[npt.NDArray[np.str_]] = None, + obs_ids: Optional[npt.NDArray[np.str_]] = None, ) -> Attributions: logger.info("Running observation attribution...") time_start = time.time() - orbit_ids = orbits.orbit_id + if isinstance(orbits, ray.ObjectRef): + orbits_ref = orbits + orbits = ray.get(orbits) + logger.info("Retrieved orbits from the object store.") + + if orbit_ids is not None: + orbits = orbits.apply_mask(pc.is_in(orbits.orbit_id, orbit_ids)) + logger.info("Applied orbit ID mask to orbits.") + else: + orbits_ref = None + + if isinstance(observations, ray.ObjectRef): + observations_ref = observations + observations = ray.get(observations) + logger.info("Retrieved observations from the object store.") + if obs_ids is not None: + observations = observations.apply_mask(pc.is_in(observations.id, obs_ids)) + logger.info("Applied observation ID mask to observations.") + else: + observations_ref = None + + if isinstance(orbits, FittedOrbits): + orbits = orbits.to_orbits() + + if orbit_ids is None: + orbit_ids = orbits.orbit_id observation_indices = np.arange(0, len(observations)) attributions_list = [] if max_processes is None or max_processes > 1: if not ray.is_initialized(): - ray.init(address="auto") - - observations_ref = ray.put(observations) - orbits_ref = ray.put(orbits) + logger.info(f"Ray is not initialized. 
Initializing with {max_processes}...") + ray.init(num_cpus=max_processes) + + refs_to_free = [] + if orbits_ref is None: + orbits_ref = ray.put(orbits) + refs_to_free.append(orbits_ref) + logger.info("Placed orbits in the object store.") + if observations_ref is None: + observations_ref = ray.put(observations) + refs_to_free.append(observations_ref) + logger.info("Placed observations in the object store.") futures = [] for orbit_id_chunk in _iterate_chunks(orbit_ids, orbits_chunk_size): @@ -278,6 +316,12 @@ def attribute_observations( finished, futures = ray.wait(futures, num_returns=1) attributions_list.append(ray.get(finished[0])) + if len(refs_to_free) > 0: + ray.internal.free(refs_to_free) + logger.info( + f"Removed {len(refs_to_free)} references from the object store." + ) + else: for orbit_id_chunk in _iterate_chunks(orbit_ids, orbits_chunk_size): for observations_indices_chunk in _iterate_chunks( @@ -308,9 +352,9 @@ def attribute_observations( def merge_and_extend_orbits( - orbits: FittedOrbits, - orbit_members: FittedOrbitMembers, - observations: Observations, + orbits: Union[FittedOrbits, ray.ObjectRef], + orbit_members: Union[FittedOrbitMembers, ray.ObjectRef], + observations: Union[Observations, ray.ObjectRef], min_obs: int = 6, min_arc_length: float = 1.0, contamination_percentage: float = 20.0, @@ -345,9 +389,29 @@ def merge_and_extend_orbits( Which parallelization backend to use {'ray', 'mp', cf}. Defaults to using Python's concurrent.futures module ('cf'). 
""" - time_start = time.time() + time_start = time.perf_counter() logger.info("Running orbit extension and merging...") + if isinstance(orbits, ray.ObjectRef): + orbits_ref = orbits + orbits = ray.get(orbits) + logger.info("Retrieved orbits from the object store.") + else: + orbits_ref = None + + if isinstance(orbit_members, ray.ObjectRef): + orbit_members = ray.get(orbit_members) + logger.info("Retrieved orbit members from the object store.") + + if isinstance(observations, ray.ObjectRef): + observations_ref = observations + observations = ray.get(observations) + logger.info("Retrieved observations from the object store.") + else: + observations_ref = None + + use_ray = max_processes is None or max_processes > 1 + # Set the running variables orbits_iter = orbits orbit_members_iter = orbit_members @@ -358,18 +422,48 @@ def merge_and_extend_orbits( odp_orbit_members_list = [] if len(orbits_iter) > 0 and len(observations_iter) > 0: + if use_ray: + if not ray.is_initialized(): + logger.info( + f"Ray is not initialized. Initializing with {max_processes}..." 
+ ) + ray.init(num_cpus=max_processes) + + refs_to_free = [] + if observations_ref is None: + observations_ref = ray.put(observations) + refs_to_free.append(observations_ref) + logger.info("Placed observations in the object store.") + converged = False while not converged: + + if use_ray: + # Orbits will change with differential correction so we need to add them + # to the object store at the start of each iteration (we cannot simply + # pass references to the same immutable object) + orbits_ref = ray.put(orbits_iter) + logger.info("Placed orbits in the object store.") + + orbits_in = orbits_ref + observations_in = observations_ref + + else: + orbits_in = orbits_iter + observations_in = observations_iter + # Run attribution attributions = attribute_observations( - orbits_iter.to_orbits(), - observations_iter, + orbits_in, + observations_in, radius=radius, propagator=propagator, propagator_kwargs=propagator_kwargs, orbits_chunk_size=orbits_chunk_size, observations_chunk_size=observations_chunk_size, max_processes=max_processes, + orbit_ids=orbits_iter.orbit_id, + obs_ids=observations_iter.id, ) # For orbits with coincident observations: multiple observations attributed at @@ -389,9 +483,9 @@ def merge_and_extend_orbits( # Run differential orbit correction orbits_iter, orbit_members_iter = differential_correction( - orbits_iter, + orbits_in, orbit_members_iter, - observations_iter, + observations_in, rchi2_threshold=rchi2_threshold, min_obs=min_obs, min_arc_length=min_arc_length, @@ -404,6 +498,8 @@ def merge_and_extend_orbits( propagator_kwargs=propagator_kwargs, chunk_size=orbits_chunk_size, max_processes=max_processes, + orbit_ids=orbits_iter.orbit_id, + obs_ids=pc.unique(orbit_members_iter.obs_id), ) orbit_members_iter = orbit_members_iter.drop_outliers() @@ -476,6 +572,13 @@ def merge_and_extend_orbits( pc.is_in(orbits_iter.orbit_id, orbit_members_iter.orbit_id.unique()) ) + # Remove orbits from the object store (the underlying state vectors 
may + # change with differential correction so we need to add them again at + # the start of the next iteration) + if use_ray: + ray.internal.free([orbits_ref]) + logger.info("Removed orbits from the object store.") + iterations += 1 if len(orbits_iter) == 0: converged = True @@ -506,11 +609,17 @@ def merge_and_extend_orbits( ) odp_orbit_members = odp_orbit_members.drop_outliers() + if use_ray: + if len(refs_to_free) > 0: + ray.internal.free(refs_to_free) + logger.info( + f"Removed {len(refs_to_free)} references from the object store." + ) else: odp_orbits = FittedOrbits.empty() odp_orbit_members = FittedOrbitMembers.empty() - time_end = time.time() + time_end = time.perf_counter() logger.info( f"Number of attribution / differential correction iterations: {iterations}" ) diff --git a/thor/orbits/iod.py b/thor/orbits/iod.py index a3b21c96..d7ac70d0 100644 --- a/thor/orbits/iod.py +++ b/thor/orbits/iod.py @@ -554,9 +554,23 @@ def initial_orbit_determination( "outlier" : Flag to indicate which observations are potential outliers (their chi2 is higher than the chi2 threshold) [float] """ - time_start = time.time() + time_start = time.perf_counter() logger.info("Running initial orbit determination...") + if isinstance(linkage_members, ray.ObjectRef): + linkage_members_ref = linkage_members + linkage_members = ray.get(linkage_members) + logger.info("Retrieved linkage members from the object store.") + else: + linkage_members_ref = None + + if isinstance(observations, ray.ObjectRef): + observations_ref = observations + observations = ray.get(observations) + logger.info("Retrieved observations from the object store.") + else: + observations_ref = None + iod_orbits_list = [] iod_orbit_members_list = [] if len(observations) > 0 and len(linkage_members) > 0: @@ -567,10 +581,21 @@ def initial_orbit_determination( if max_processes is None or max_processes > 1: if not ray.is_initialized(): - ray.init(address="auto") + logger.info( + f"Ray is not initialized. 
Initializing with {max_processes}..." + ) + ray.init(num_cpus=max_processes) + + refs_to_free = [] + if linkage_members_ref is None: + linkage_members_ref = ray.put(linkage_members) + refs_to_free.append(linkage_members_ref) + logger.info("Placed linkage members in the object store.") - observations_ref = ray.put(observations) - linkage_members_ref = ray.put(linkage_members) + if observations_ref is None: + observations_ref = ray.put(observations) + refs_to_free.append(observations_ref) + logger.info("Placed observations in the object store.") futures = [] for linkage_id_chunk in _iterate_chunks(linkage_ids, chunk_size): @@ -598,6 +623,12 @@ def initial_orbit_determination( iod_orbits_list.append(result[0]) iod_orbit_members_list.append(result[1]) + if len(refs_to_free) > 0: + ray.internal.free(refs_to_free) + logger.info( + f"Removed {len(refs_to_free)} references from the object store." + ) + else: for linkage_id_chunk in _iterate_chunks(linkage_ids, chunk_size): iod_orbits_chunk, iod_orbit_members_chunk = iod_worker( @@ -642,7 +673,7 @@ def initial_orbit_determination( iod_orbits = FittedOrbits.empty() iod_orbit_members = FittedOrbitMembers.empty() - time_end = time.time() + time_end = time.perf_counter() logger.info( "Initial orbit determination completed in {:.3f} seconds.".format( time_end - time_start diff --git a/thor/orbits/od.py b/thor/orbits/od.py index cf954c0a..25ab50c8 100644 --- a/thor/orbits/od.py +++ b/thor/orbits/od.py @@ -1,6 +1,6 @@ import logging import time -from typing import Literal, Optional, Tuple +from typing import Literal, Optional, Tuple, Union import numpy as np import numpy.typing as npt @@ -540,9 +540,9 @@ def od( def differential_correction( - orbits: FittedOrbits, - orbit_members: FittedOrbitMembers, - observations: Observations, + orbits: Union[FittedOrbits, ray.ObjectRef], + orbit_members: Union[FittedOrbitMembers, ray.ObjectRef], + observations: Union[Observations, ray.ObjectRef], min_obs: int = 5, 
min_arc_length: float = 1.0, contamination_percentage: float = 20, @@ -555,6 +555,8 @@ def differential_correction( propagator_kwargs: dict = {}, chunk_size: int = 10, max_processes: Optional[int] = 1, + orbit_ids: Optional[npt.NDArray[np.str_]] = None, + obs_ids: Optional[npt.NDArray[np.str_]] = None, ) -> Tuple[FittedOrbits, FittedOrbitMembers]: """ Differentially correct (via finite/central differencing). @@ -569,9 +571,48 @@ def differential_correction( Which parallelization backend to use {'ray', 'mp', 'cf'}. Defaults to using Python's concurrent.futures module ('cf'). """ + time_start = time.perf_counter() logger.info("Running differential correction...") - time_start = time.time() + if isinstance(orbits, ray.ObjectRef): + orbits_ref = orbits + orbits = ray.get(orbits) + logger.info("Retrieved orbits from the object store.") + + if orbit_ids is not None: + orbits = orbits.apply_mask(pc.is_in(orbits.orbit_id, orbit_ids)) + logger.info("Applied mask to orbits.") + else: + orbits_ref = None + + if isinstance(orbit_members, ray.ObjectRef): + orbit_members_ref = orbit_members + orbit_members = ray.get(orbit_members) + logger.info("Retrieved orbit members from the object store.") + + if obs_ids is not None: + orbit_members = orbit_members.apply_mask( + pc.is_in(orbit_members.obs_id, obs_ids) + ) + logger.info("Applied mask to orbit members.") + if orbit_ids is not None: + orbit_members = orbit_members.apply_mask( + pc.is_in(orbit_members.orbit_id, orbit_ids) + ) + logger.info("Applied mask to orbit members.") + else: + orbit_members_ref = None + + if isinstance(observations, ray.ObjectRef): + observations_ref = observations + observations = ray.get(observations) + logger.info("Retrieved observations from the object store.") + + if obs_ids is not None: + observations = observations.apply_mask(pc.is_in(observations.id, obs_ids)) + logger.info("Applied mask to observations.") + else: + observations_ref = None if len(orbits) > 0 and len(orbit_members) > 0: @@ 
-582,11 +623,26 @@ def differential_correction( if max_processes is None or max_processes > 1: if not ray.is_initialized(): - ray.init(address="auto") + logger.info( + f"Ray is not initialized. Initializing with {max_processes}..." + ) + ray.init(num_cpus=max_processes) - orbits_ref = ray.put(orbits) - orbit_members_ref = ray.put(orbit_members) - observations_ref = ray.put(observations) + refs_to_free = [] + if orbits_ref is None: + orbits_ref = ray.put(orbits) + refs_to_free.append(orbits_ref) + logger.info("Placed orbits in the object store.") + + if orbit_members_ref is None: + orbit_members_ref = ray.put(orbit_members) + refs_to_free.append(orbit_members_ref) + logger.info("Placed orbit members in the object store.") + + if observations_ref is None: + observations_ref = ray.put(observations) + refs_to_free.append(observations_ref) + logger.info("Placed observations in the object store.") futures = [] for orbit_ids_chunk in _iterate_chunks(orbit_ids, chunk_size): @@ -615,6 +671,12 @@ def differential_correction( od_orbits_list.append(results[0]) od_orbit_members_list.append(results[1]) + if len(refs_to_free) > 0: + ray.internal.free(refs_to_free) + logger.info( + f"Removed {len(refs_to_free)} references from the object store." 
+ ) + else: for orbit_ids_chunk in _iterate_chunks(orbit_ids, chunk_size): @@ -644,7 +706,7 @@ def differential_correction( od_orbits = FittedOrbits.empty() od_orbit_members = FittedOrbitMembers.empty() - time_end = time.time() + time_end = time.perf_counter() logger.info("Differentially corrected {} orbits.".format(len(od_orbits))) logger.info( "Differential correction completed in {:.3f} seconds.".format( diff --git a/thor/range_and_transform.py b/thor/range_and_transform.py index c0fedfff..400c77e4 100644 --- a/thor/range_and_transform.py +++ b/thor/range_and_transform.py @@ -128,7 +128,11 @@ def range_and_transform( logger.info(f"Assuming v = {test_orbit.coordinates.v[0]} au/d") if isinstance(observations, ray.ObjectRef): + observations_ref = observations observations = ray.get(observations) + logger.info("Retrieved observations from the object store.") + else: + observations_ref = None prop = propagator(**propagator_kwargs) @@ -160,21 +164,23 @@ def range_and_transform( if max_processes is None or max_processes > 1: if not ray.is_initialized(): - logger.debug( + logger.info( f"Ray is not initialized. Initializing with {max_processes}..." 
) ray.init(num_cpus=max_processes) - if isinstance(observations, ray.ObjectRef): - observations_ref = observations - observations = ray.get(observations_ref) - else: + refs_to_free = [] + if observations_ref is None: observations_ref = ray.put(observations) + refs_to_free.append(observations_ref) + logger.info("Placed observations in the object store.") - if isinstance(ephemeris, ray.ObjectRef): - ephemeris_ref = ephemeris - else: + if not isinstance(ephemeris, ray.ObjectRef): ephemeris_ref = ray.put(ephemeris) + refs_to_free.append(ephemeris_ref) + logger.info("Placed ephemeris in the object store.") + else: + ephemeris_ref = ephemeris ranged_detections_cartesian_ref = ray.put(ranged_detections_cartesian) @@ -195,6 +201,12 @@ def range_and_transform( finished, futures = ray.wait(futures, num_returns=1) transformed_detection_list.append(ray.get(finished[0])) + if len(refs_to_free) > 0: + ray.internal.free(refs_to_free) + logger.info( + f"Removed {len(refs_to_free)} references from the object store." + ) + else: # Get state IDs state_ids = observations.state_id.unique().sort()