diff --git a/thor/main.py b/thor/main.py index 00b5d013..12d4aa03 100644 --- a/thor/main.py +++ b/thor/main.py @@ -125,14 +125,6 @@ def link_test_orbit( ) refs_to_free = [] - if ( - use_ray - and observations is not None - and not isinstance(observations, (ray.ObjectRef, str)) - ): - observations = ray.put(observations) - refs_to_free.append(observations) - logger.info("Placed observations in the object store.") checkpoint = load_initial_checkpoint_values(test_orbit_directory) logger.info(f"Starting at stage: {checkpoint.stage}") @@ -152,12 +144,6 @@ def link_test_orbit( ) if checkpoint.stage == "filter_observations": - if use_ray: - if not isinstance(observations, (ray.ObjectRef, str)): - observations = ray.put(observations) - refs_to_free.append(observations) - logger.info("Placed observations in the object store.") - filtered_observations = filter_observations( observations, test_orbit, config, filters ) @@ -186,11 +172,7 @@ def link_test_orbit( filtered_observations=filtered_observations, ) - # Observations are no longer needed. If we are using ray - # lets explicitly free the memory. - if use_ray and isinstance(observations, ray.ObjectRef): - ray.internal.free([observations]) - logger.info("Removed observations from the object store.") + # Observations are no longer needed del observations if checkpoint.stage == "range_and_transform": diff --git a/thor/observations/filters.py b/thor/observations/filters.py index 55d46c86..d4998968 100644 --- a/thor/observations/filters.py +++ b/thor/observations/filters.py @@ -324,6 +324,11 @@ def filter_observations( use_ray = initialize_use_ray(num_cpus=config.max_processes) if use_ray: + + if isinstance(observations, Observations): + observations = ray.put(observations) + logger.info("Placed observations in the object store.") + futures = [] for state_id_chunk in _iterate_chunks(state_ids, chunk_size): @@ -344,6 +349,10 @@ def filter_observations( if filtered_observations.fragmented(): filtered_observations = qv.defragment(filtered_observations) + if isinstance(observations, ray.ObjectRef): + ray.internal.free([observations]) + logger.info("Removed observations from the object store.") + else: for state_id_chunk in _iterate_chunks(state_ids, chunk_size):