Skip to content

Commit

Permalink
If using ray do not store unfiltered observations in the ray object s…
Browse files Browse the repository at this point in the history
…tore in link_test_orbit and instead do so in filter_observations
  • Loading branch information
moeyensj committed Jan 24, 2024
1 parent 3ac98c9 commit 2a4645b
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 19 deletions.
20 changes: 1 addition & 19 deletions thor/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,14 +125,6 @@ def link_test_orbit(
)

refs_to_free = []
if (
use_ray
and observations is not None
and not isinstance(observations, (ray.ObjectRef, str))
):
observations = ray.put(observations)
refs_to_free.append(observations)
logger.info("Placed observations in the object store.")

checkpoint = load_initial_checkpoint_values(test_orbit_directory)
logger.info(f"Starting at stage: {checkpoint.stage}")
Expand All @@ -152,12 +144,6 @@ def link_test_orbit(
)

if checkpoint.stage == "filter_observations":
if use_ray:
if not isinstance(observations, (ray.ObjectRef, str)):
observations = ray.put(observations)
refs_to_free.append(observations)
logger.info("Placed observations in the object store.")

filtered_observations = filter_observations(
observations, test_orbit, config, filters
)
Expand Down Expand Up @@ -186,11 +172,7 @@ def link_test_orbit(
filtered_observations=filtered_observations,
)

# Observations are no longer needed. If we are using ray
# lets explicitly free the memory.
if use_ray and isinstance(observations, ray.ObjectRef):
ray.internal.free([observations])
logger.info("Removed observations from the object store.")
# Observations are no longer needed
del observations

if checkpoint.stage == "range_and_transform":
Expand Down
9 changes: 9 additions & 0 deletions thor/observations/filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,6 +324,11 @@ def filter_observations(

use_ray = initialize_use_ray(num_cpus=config.max_processes)
if use_ray:

if isinstance(observations, Observations):
observations = ray.put(observations)
logger.info("Placed observations in the object store.")

futures = []
for state_id_chunk in _iterate_chunks(state_ids, chunk_size):

Expand All @@ -344,6 +349,10 @@ def filter_observations(
if filtered_observations.fragmented():
filtered_observations = qv.defragment(filtered_observations)

if isinstance(observations, ray.ObjectRef):
ray.internal.free([observations])
logger.info("Removed observations from the object store.")

else:

for state_id_chunk in _iterate_chunks(state_ids, chunk_size):
Expand Down

0 comments on commit 2a4645b

Please sign in to comment.