Skip to content

Commit

Permalink
Merge pull request #135 from moeyensj/v2.0-ray-references
Browse files Browse the repository at this point in the history
Improve ray reference passing
  • Loading branch information
moeyensj authored Nov 24, 2023
2 parents 8b8c5a8 + f4d389b commit f623dd8
Show file tree
Hide file tree
Showing 8 changed files with 436 additions and 90 deletions.
24 changes: 12 additions & 12 deletions thor/checkpointing.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,46 +43,46 @@ class Config:

stage: Literal["cluster_and_link"]
filtered_observations: Union[Observations, ray.ObjectRef]
transformed_detections: TransformedDetections
transformed_detections: Union[TransformedDetections, ray.ObjectRef]


class InitialOrbitDetermination(pydantic.BaseModel):
class Config:
arbitrary_types_allowed = True

stage: Literal["initial_orbit_determination"]
filtered_observations: Observations
clusters: Clusters
cluster_members: ClusterMembers
filtered_observations: Union[Observations, ray.ObjectRef]
clusters: Union[Clusters, ray.ObjectRef]
cluster_members: Union[ClusterMembers, ray.ObjectRef]


class DifferentialCorrection(pydantic.BaseModel):
class Config:
arbitrary_types_allowed = True

stage: Literal["differential_correction"]
filtered_observations: Observations
iod_orbits: FittedOrbits
iod_orbit_members: FittedOrbitMembers
filtered_observations: Union[Observations, ray.ObjectRef]
iod_orbits: Union[FittedOrbits, ray.ObjectRef]
iod_orbit_members: Union[FittedOrbitMembers, ray.ObjectRef]


class RecoverOrbits(pydantic.BaseModel):
class Config:
arbitrary_types_allowed = True

stage: Literal["recover_orbits"]
filtered_observations: Observations
od_orbits: FittedOrbits
od_orbit_members: FittedOrbitMembers
filtered_observations: Union[Observations, ray.ObjectRef]
od_orbits: Union[FittedOrbits, ray.ObjectRef]
od_orbit_members: Union[FittedOrbitMembers, ray.ObjectRef]


class Complete(pydantic.BaseModel):
class Config:
arbitrary_types_allowed = True

stage: Literal["complete"]
recovered_orbits: FittedOrbits
recovered_orbit_members: FittedOrbitMembers
recovered_orbits: Union[FittedOrbits, ray.ObjectRef]
recovered_orbit_members: Union[FittedOrbitMembers, ray.ObjectRef]


CheckpointData = Annotated[
Expand Down
41 changes: 29 additions & 12 deletions thor/clusters.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import logging
import time
import uuid
from typing import List, Literal, Optional, Tuple
from typing import List, Literal, Optional, Tuple, Union

import numba
import numpy as np
Expand Down Expand Up @@ -31,7 +31,7 @@
"ClusterMembers",
]

logger = logging.getLogger("thor")
logger = logging.getLogger(__name__)


class Clusters(qv.Table):
Expand Down Expand Up @@ -585,7 +585,7 @@ def cluster_velocity_worker(


def cluster_and_link(
observations: TransformedDetections,
observations: Union[TransformedDetections, ray.ObjectRef],
vx_range: List[float] = [-0.1, 0.1],
vy_range: List[float] = [-0.1, 0.1],
vx_bins: int = 100,
Expand Down Expand Up @@ -673,6 +673,10 @@ def cluster_and_link(
logger.info("Max sample distance: {}".format(radius))
logger.info("Minimum samples: {}".format(min_obs))

if isinstance(observations, ray.ObjectRef):
observations = ray.get(observations)
logger.info("Retrieved observations from the object store.")

clusters_list = []
cluster_members_list = []
if len(observations) > 0:
Expand All @@ -690,14 +694,22 @@ def cluster_and_link(
if max_processes is None or max_processes > 1:

if not ray.is_initialized():
ray.init(address="auto")
logger.info(
f"Ray is not initialized. Initializing with {max_processes}..."
)
ray.init(address="auto", num_cpus=max_processes)

# Put all arrays (which can be large) in ray's
# local object store ahead of time
obs_ids_oid = ray.put(obs_ids)
theta_x_oid = ray.put(theta_x)
theta_y_oid = ray.put(theta_y)
dt_oid = ray.put(dt)
obs_ids_ref = ray.put(obs_ids)
theta_x_ref = ray.put(theta_x)
theta_y_ref = ray.put(theta_y)
dt_ref = ray.put(dt)
refs_to_free = [obs_ids_ref, theta_x_ref, theta_y_ref, dt_ref]
logger.info("Placed gnomonic coordinate arrays in the object store.")
# TODO: transformed detections are already in the object store so we might
# want to instead pass references to those rather than extract arrays
# from them and put them in the object store again.

futures = []
for vxi_chunk, vyi_chunk in zip(
Expand All @@ -707,10 +719,10 @@ def cluster_and_link(
cluster_velocity_remote.remote(
vxi_chunk,
vyi_chunk,
obs_ids_oid,
theta_x_oid,
theta_y_oid,
dt_oid,
obs_ids_ref,
theta_x_ref,
theta_y_ref,
dt_ref,
radius=radius,
min_obs=min_obs,
min_arc_length=min_arc_length,
Expand All @@ -724,6 +736,11 @@ def cluster_and_link(
clusters_list.append(result[0])
cluster_members_list.append(result[1])

ray.internal.free(refs_to_free)
logger.info(
f"Removed {len(refs_to_free)} references from the object store."
)

else:

for vxi_chunk, vyi_chunk in zip(
Expand Down
109 changes: 95 additions & 14 deletions thor/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@ def initialize_use_ray(config: Config) -> bool:
if config.max_processes is None or config.max_processes > 1:
# Initialize ray
if not ray.is_initialized():
logger.debug(
f"Ray is not initialized. Initializing with {config.max_processes}..."
logger.info(
f"Ray is not initialized. Initializing with {config.max_processes} cpus..."
)
ray.init(num_cpus=config.max_processes)

Expand Down Expand Up @@ -134,12 +134,15 @@ def link_test_orbit(

use_ray = initialize_use_ray(config)

refs_to_free = []
if (
use_ray
and observations is not None
and not isinstance(observations, ray.ObjectRef)
):
observations = ray.put(observations)
refs_to_free.append(observations)
logger.info("Placed observations in the object store.")

checkpoint = load_initial_checkpoint_values(test_orbit_directory)
logger.info(f"Starting at stage: {checkpoint.stage}")
Expand All @@ -159,6 +162,12 @@ def link_test_orbit(
)

if checkpoint.stage == "filter_observations":
if use_ray:
if not isinstance(observations, ray.ObjectRef):
observations = ray.put(observations)
refs_to_free.append(observations)
logger.info("Placed observations in the object store.")

filtered_observations = filter_observations(
observations, test_orbit, config, filters
)
Expand All @@ -176,18 +185,31 @@ def link_test_orbit(
path=(filtered_observations_path,),
)

if use_ray:
if not isinstance(filtered_observations, ray.ObjectRef):
filtered_observations = ray.put(filtered_observations)
refs_to_free.append(filtered_observations)
logger.info("Placed filtered observations in the object store.")

checkpoint = create_checkpoint_data(
"range_and_transform",
filtered_observations=filtered_observations,
)

# Observations are no longer needed, so we can delete them
# Observations are no longer needed. If we are using ray
# lets explicitly free the memory.
if use_ray and isinstance(observations, ray.ObjectRef):
ray.internal.free([observations])
logger.info("Removed observations from the object store.")
del observations

if checkpoint.stage == "range_and_transform":
filtered_observations = checkpoint.filtered_observations
if use_ray and not isinstance(filtered_observations, ray.ObjectRef):
filtered_observations = ray.put(filtered_observations)
if use_ray:
if not isinstance(filtered_observations, ray.ObjectRef):
filtered_observations = ray.put(filtered_observations)
refs_to_free.append(filtered_observations)
logger.info("Placed filtered observations in the object store.")

# Range and transform the observations
transformed_detections = range_and_transform(
Expand All @@ -211,21 +233,26 @@ def link_test_orbit(
path=(transformed_detections_path,),
)

if use_ray:
if not isinstance(filtered_observations, ray.ObjectRef):
filtered_observations = ray.put(filtered_observations)
refs_to_free.append(filtered_observations)
logger.info("Placed filtered observations in the object store.")
if not isinstance(transformed_detections, ray.ObjectRef):
transformed_detections = ray.put(transformed_detections)
refs_to_free.append(transformed_detections)
logger.info("Placed transformed detections in the object store.")

checkpoint = create_checkpoint_data(
"cluster_and_link",
filtered_observations=filtered_observations,
transformed_detections=transformed_detections,
)

# TODO: ray support for the rest of the pipeline has not yet been implemented
# so we will convert the ray objects to regular objects for now
if use_ray:
if isinstance(checkpoint.filtered_observations, ray.ObjectRef):
checkpoint.filtered_observations = ray.get(filtered_observations)

if checkpoint.stage == "cluster_and_link":
transformed_detections = checkpoint.transformed_detections
filtered_observations = checkpoint.filtered_observations
transformed_detections = checkpoint.transformed_detections

# Run clustering
clusters, cluster_members = cluster_and_link(
transformed_detections,
Expand Down Expand Up @@ -258,6 +285,20 @@ def link_test_orbit(
path=(clusters_path, cluster_members_path),
)

if use_ray:
if not isinstance(filtered_observations, ray.ObjectRef):
filtered_observations = ray.put(filtered_observations)
refs_to_free.append(filtered_observations)
logger.info("Placed filtered observations in the object store.")
if not isinstance(clusters, ray.ObjectRef):
clusters = ray.put(clusters)
refs_to_free.append(clusters)
logger.info("Placed clusters in the object store.")
if not isinstance(cluster_members, ray.ObjectRef):
cluster_members = ray.put(cluster_members)
refs_to_free.append(cluster_members)
logger.info("Placed cluster members in the object store.")

checkpoint = create_checkpoint_data(
"initial_orbit_determination",
filtered_observations=filtered_observations,
Expand Down Expand Up @@ -306,6 +347,20 @@ def link_test_orbit(
path=(iod_orbits_path, iod_orbit_members_path),
)

if use_ray:
if not isinstance(filtered_observations, ray.ObjectRef):
filtered_observations = ray.put(filtered_observations)
refs_to_free.append(filtered_observations)
logger.info("Placed filtered observations in the object store.")
if not isinstance(iod_orbits, ray.ObjectRef):
iod_orbits = ray.put(iod_orbits)
refs_to_free.append(iod_orbits)
logger.info("Placed initial orbits in the object store.")
if not isinstance(iod_orbit_members, ray.ObjectRef):
iod_orbit_members = ray.put(iod_orbit_members)
refs_to_free.append(iod_orbit_members)
logger.info("Placed initial orbit members in the object store.")

checkpoint = create_checkpoint_data(
"differential_correction",
filtered_observations=filtered_observations,
Expand All @@ -314,9 +369,10 @@ def link_test_orbit(
)

if checkpoint.stage == "differential_correction":
filtered_observations = checkpoint.filtered_observations
iod_orbits = checkpoint.iod_orbits
iod_orbit_members = checkpoint.iod_orbit_members
filtered_observations = checkpoint.filtered_observations

# Run differential correction
od_orbits, od_orbit_members = differential_correction(
iod_orbits,
Expand Down Expand Up @@ -354,6 +410,24 @@ def link_test_orbit(
path=(od_orbits_path, od_orbit_members_path),
)

if use_ray:
if not isinstance(filtered_observations, ray.ObjectRef):
filtered_observations = ray.put(filtered_observations)
refs_to_free.append(filtered_observations)
logger.info("Placed filtered observations in the object store.")
if not isinstance(od_orbits, ray.ObjectRef):
od_orbits = ray.put(od_orbits)
refs_to_free.append(od_orbits)
logger.info(
"Placed differentially corrected orbits in the object store."
)
if not isinstance(od_orbit_members, ray.ObjectRef):
od_orbit_members = ray.put(od_orbit_members)
refs_to_free.append(od_orbit_members)
logger.info(
"Placed differentially corrected orbit members in the object store."
)

checkpoint = create_checkpoint_data(
"recover_orbits",
filtered_observations=filtered_observations,
Expand All @@ -362,9 +436,10 @@ def link_test_orbit(
)

if checkpoint.stage == "recover_orbits":
filtered_observations = checkpoint.filtered_observations
od_orbits = checkpoint.od_orbits
od_orbit_members = checkpoint.od_orbit_members
filtered_observations = checkpoint.filtered_observations

# Run arc extension
recovered_orbits, recovered_orbit_members = merge_and_extend_orbits(
od_orbits,
Expand All @@ -387,6 +462,12 @@ def link_test_orbit(
observations_chunk_size=100000,
)

if use_ray and len(refs_to_free) > 0:
ray.internal.free(refs_to_free)
logger.info(
f"Removed {len(refs_to_free)} references from the object store."
)

recovered_orbits_path = None
recovered_orbit_members_path = None
if test_orbit_directory is not None:
Expand Down
Loading

0 comments on commit f623dd8

Please sign in to comment.