
Improve peak memory by pushing sorting of large tables upstream of functions.
akoumjian committed Jan 5, 2024
1 parent 15b3517 commit bc2674e
Showing 6 changed files with 175 additions and 55 deletions.
2 changes: 1 addition & 1 deletion docker-compose.yml
@@ -12,4 +12,4 @@ services:
- ".docker_bash_history.txt:/root/.bash_history"
- ".volumes:/opt/volumes/"
tmpfs:
- /dev/shm:size=8g
- /dev/shm:size=12g
111 changes: 78 additions & 33 deletions thor/clusters.py
@@ -26,6 +26,11 @@
else:
from sklearn.cluster import DBSCAN

import hashlib
from typing import List, Literal, Tuple, Union

import pyarrow as pa
import pyarrow.compute as pc

__all__ = [
"cluster_and_link",
@@ -36,61 +41,91 @@
logger = logging.getLogger(__name__)


def hash_obs_ids(obs_ids: List[str]) -> str:
"""
Create a unique string for each unique set of observation IDs.
We use hashes rather than the original strings in order to save memory.
"""
return hashlib.md5("".join(sorted(set(obs_ids))).encode()).hexdigest()

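As a quick illustration (a sketch, not part of the commit): because hash_obs_ids sorts and de-duplicates the IDs before hashing, two clusters with the same set of observation IDs map to the same key regardless of ordering or repeats.

# Illustrative only: same observation set -> same hash; different sets differ.
assert hash_obs_ids(["obs_02", "obs_01"]) == hash_obs_ids(["obs_01", "obs_02", "obs_02"])
assert hash_obs_ids(["obs_01"]) != hash_obs_ids(["obs_01", "obs_02"])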

def drop_duplicate_clusters(
clusters: "Clusters",
cluster_members: "ClusterMembers",
num_cpus: int = 1,
) -> Tuple["Clusters", "ClusterMembers"]:
"""
Drop clusters that have identical sets of observation IDs.
Parameters
----------
clusters: `~thor.clusters.Clusters`
A table of clusters. Must be sorted by cluster_id.
cluster_members: `~thor.clusters.ClusterMembers`
A table of cluster members.
A table of cluster members. Must be sorted by cluster_id.
Returns
-------
`~thor.clusters.Clusters`
`~thor.clusters.Clusters`, `~thor.clusters.ClusterMembers`
A table of clusters with duplicate clusters removed.
The cluster members belonging to those clusters.
"""
# Sort by cluster_id and obs_id
clusters = clusters.sort_by([("cluster_id", "ascending")])
cluster_members = cluster_members.sort_by(
[("cluster_id", "ascending"), ("obs_id", "ascending")]
)

# Group by cluster_id and aggregate a list of distinct obs_ids
grouped_by_cluster_id = cluster_members.table.group_by(
["cluster_id"], use_threads=False
).aggregate([("obs_id", "distinct")])
obs_ids_per_cluster = grouped_by_cluster_id["obs_id_distinct"].to_pylist()

# Group by with a distinct aggregation is not guaranteed to preserve the order of the elements within each list
# but does preserve the order of the lists themselves. So we sort each list of obs_ids and while we are
# sorting we also convert the lists to a single string on which we can group later.
# Pyarrow currently does not support groupby on lists of strings, this is likely a missing feature.
# As an example, the following code doesn't work:
# grouped_by_obs_lists = grouped_by_cluster_id.group_by(
# ["obs_id_distinct"],
# use_threads=False
# ).aggregate([("index", "first")])
for i, obs_ids_i in enumerate(obs_ids_per_cluster):
obs_ids_i.sort()
obs_ids_per_cluster[i] = "".join(obs_ids_i)

squashed_obs_ids = pa.table(
# Ensure clusters and cluster members are sorted by cluster id
# by spot checking the first few and last few rows are
# in sorted order
assert clusters.cluster_id[:3].to_pylist() == sorted(
clusters.cluster_id[:3].to_pylist()
), "clusters must be sorted by cluster_id" # noqa: E501
assert clusters.cluster_id[-3:].to_pylist() == sorted(
clusters.cluster_id[-3:].to_pylist()
), "clusters must be sorted by cluster_id" # noqa: E501
assert cluster_members.cluster_id[:3].to_pylist() == sorted(
cluster_members.cluster_id[:3].to_pylist()
), "cluster_members must be sorted by cluster_id" # noqa: E501
assert cluster_members.cluster_id[-3:].to_pylist() == sorted(
cluster_members.cluster_id[-3:].to_pylist()
), "cluster_members must be sorted by cluster_id" # noqa: E501

# We used to use a group by in pyarrow here,
# but found the memory accumulation was too high.
# A simple loop that accumulates the distinct obs ids
# for each cluster is more memory efficient.
logger.info(f"Accumulating cluster observation IDs into single strings.")
obs_ids_per_cluster: Union[List[str], pa.Array] = []
current_obs_ids: List[str] = []
current_cluster_id = None
for member in cluster_members:
cluster_id = member.cluster_id.to_pylist()[0]
obs_id = member.obs_id.to_pylist()[0]
if cluster_id != current_cluster_id:
if current_cluster_id is not None:
obs_ids_per_cluster.append(hash_obs_ids(current_obs_ids))
current_cluster_id = cluster_id
current_obs_ids = []
current_obs_ids.append(obs_id)
obs_ids_per_cluster.append(hash_obs_ids(current_obs_ids))

logger.info(f"Grouping by unique observation sets.")
obs_ids_per_cluster = pa.table(
{
"index": pa.array(np.arange(0, len(obs_ids_per_cluster))),
"obs_ids": obs_ids_per_cluster,
}
)
indices = (
squashed_obs_ids.group_by(["obs_ids"], use_threads=False)
.aggregate([("index", "first")])["index_first"]
.combine_chunks()
)

obs_ids_per_cluster = obs_ids_per_cluster.combine_chunks()
obs_ids_per_cluster = obs_ids_per_cluster.group_by(["obs_ids"], use_threads=False)

logger.info(f"Taking first index of each unique observation set.")
indices = obs_ids_per_cluster.aggregate([("index", "first")])["index_first"]
del obs_ids_per_cluster
indices = indices.combine_chunks()

logger.info(f"Taking clusters that belong to unique observation sets.")
clusters = clusters.take(indices)

logger.info(f"Taking cluster members that belong to unique clusters.")
cluster_members = cluster_members.apply_mask(
pc.is_in(cluster_members.cluster_id, clusters.cluster_id)
)
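For reference, a minimal standalone sketch of the "first index per unique key" step used above (toy data and variable names are assumed for illustration, not taken from the commit):

import numpy as np
import pyarrow as pa

keys = pa.array(["a", "b", "a", "c", "b"])
toy = pa.table({"index": pa.array(np.arange(len(keys))), "obs_ids": keys})
first_index = (
    toy.group_by(["obs_ids"], use_threads=False)
    .aggregate([("index", "first")])["index_first"]
)
# With use_threads=False the groups are emitted in first-appearance order,
# so each unique key keeps the index of its first occurrence (here 0, 1 and 3).
print(first_index.to_pylist())

Taking those indices from the clusters table is what leaves one representative cluster per unique observation set.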
@@ -821,6 +856,16 @@ def cluster_and_link(
time_start_drop = time.perf_counter()
logger.info("Removing duplicate clusters...")
num_clusters = len(clusters)

# Ensure clusters, cluster_members are defragmented and sorted
# prior to dropping duplicates. We do this here so that
# we don't sort inside the function and make a whole new copy
# while the old one stays referenced in memory
clusters = qv.defragment(clusters)
cluster_members = qv.defragment(cluster_members)
clusters = clusters.sort_by([("cluster_id", "ascending")])
cluster_members = cluster_members.sort_by([("cluster_id", "ascending")])

clusters, cluster_members = drop_duplicate_clusters(clusters, cluster_members)
logger.info(f"Removed {num_clusters - len(clusters)} duplicate clusters.")
time_end_drop = time.perf_counter()
57 changes: 36 additions & 21 deletions thor/observations/observations.py
@@ -61,24 +61,39 @@ def from_input_observations(cls, observations: InputObservations) -> "Observatio
observations : `~Observations`
A table of THOR observations.
"""
# Sort the observations by time and observatory code
observations_sorted = observations.sort_by(
["time.days", "time.nanos", "observatory_code"]
)

# If the times are not in UTC, convert them to UTC
if observations_sorted.time.scale != "utc":
observations_sorted = observations_sorted.set_column(
"time", observations_sorted.time.rescale("utc")
# Do a spot check that observations are pre-sorted by time
# and observatory code. We pre-sort to avoid large duplicate
# memory usage
assert (
observations[:3]
.to_dataframe()
.equals(
observations[:3]
.to_dataframe()
.sort_values(["time.days", "time.nanos", "observatory_code"])
)
), "Input observations must be sorted by day, time, observatory code"

assert (
observations[-3:]
.to_dataframe()
.equals(
observations[-3:]
.to_dataframe()
.sort_values(["time.days", "time.nanos", "observatory_code"])
)
), "Input observations must be sorted by day, time, observatory code"

assert observations.time.scale == "utc", "Input observations must be in UTC"

# Extract the sigma and covariance values for RA and Dec
ra_sigma = observations_sorted.ra_sigma.to_numpy(zero_copy_only=False)
dec_sigma = observations_sorted.dec_sigma.to_numpy(zero_copy_only=False)
ra_dec_cov = observations_sorted.ra_dec_cov.to_numpy(zero_copy_only=False)
ra_sigma = observations.ra_sigma.to_numpy(zero_copy_only=False)
dec_sigma = observations.dec_sigma.to_numpy(zero_copy_only=False)
ra_dec_cov = observations.ra_dec_cov.to_numpy(zero_copy_only=False)

# Create the covariance matrices
covariance_matrices = np.full((len(observations_sorted), 6, 6), np.nan)
covariance_matrices = np.full((len(observations), 6, 6), np.nan)
covariance_matrices[:, 1, 1] = ra_sigma**2
covariance_matrices[:, 2, 2] = dec_sigma**2
covariance_matrices[:, 1, 2] = ra_dec_cov
@@ -87,24 +102,24 @@ def from_input_observations(cls, observations: InputObservatio

# Create the coordinates table
coords = SphericalCoordinates.from_kwargs(
lon=observations_sorted.ra,
lat=observations_sorted.dec,
time=observations_sorted.time,
lon=observations.ra,
lat=observations.dec,
time=observations.time,
covariance=covariances,
origin=Origin.from_kwargs(code=observations_sorted.observatory_code),
origin=Origin.from_kwargs(code=observations.observatory_code),
frame="equatorial",
)

# Create the photometry table
photometry = Photometry.from_kwargs(
filter=observations_sorted.filter,
mag=observations_sorted.mag,
mag_sigma=observations_sorted.mag_sigma,
filter=observations.filter,
mag=observations.mag,
mag_sigma=observations.mag_sigma,
)

return cls.from_kwargs(
id=observations_sorted.id,
exposure_id=observations_sorted.exposure_id,
id=observations.id,
exposure_id=observations.exposure_id,
coordinates=coords,
photometry=photometry,
state_id=calculate_state_ids(coords),
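Caller-side, the new contract looks like the following minimal sketch (assuming `input_observations` is an InputObservations table; it mirrors the memory test added below): sort by time and observatory code and rescale to UTC before constructing Observations.

# Sketch of the caller's responsibility under the new contract.
input_observations = input_observations.sort_by(
    ["time.days", "time.nanos", "observatory_code"]
)
input_observations = input_observations.set_column(
    "time", input_observations.time.rescale("utc")
)
observations = Observations.from_input_observations(input_observations)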
Binary file not shown.
24 changes: 24 additions & 0 deletions thor/tests/memory/test_memory.py
@@ -105,6 +105,13 @@ def snapshot():
)


@pytest.fixture
def memory_input_observations():
from thor.observations import InputObservations

return InputObservations.from_parquet(FIXTURES_DIR / "input_observations.parquet")


# We are going to test all the major stages used in link_test_orbit
@pytest.fixture
def memory_observations():
@@ -201,6 +208,23 @@ def ray_cluster(memory_config):
ray.shutdown()


@pytest.mark.memory
@pytest.mark.parametrize("memory_config", CONFIG_PROCESSES, indirect=True)
def test_load_input_observations(
memory_snapshot, memory_config, ray_cluster, memory_input_observations
):
from thor.observations import Observations

# It's always necessary to sort ahead of time, so we include it in our test
memory_input_observations = memory_input_observations.sort_by(
["time.days", "time.nanos", "observatory_code"]
)
memory_input_observations = memory_input_observations.set_column(
"time", memory_input_observations.time.rescale("utc")
)
observations = Observations.from_input_observations(memory_input_observations)


@pytest.mark.memory
@pytest.mark.parametrize("memory_config", CONFIG_PROCESSES, indirect=True)
def test_filter_observations(
36 changes: 36 additions & 0 deletions thor/tests/test_clusters.py
@@ -1,5 +1,8 @@
import numpy as np
import pyarrow as pa
import pyarrow.compute as pc
import pytest
import quivr as qv

from ..clusters import (
ClusterMembers,
@@ -405,3 +408,36 @@ def test_Clusters_drop_duplicates():
cluster_members_filtered.obs_id.to_numpy(zero_copy_only=False),
np.hstack(np.array(obs_ids)),
)


def test_drop_duplicate_clusters_sorted():
"""
Test that drop_duplicate_clusters raises an AssertionError if its inputs are not sorted.
"""
clusters = Clusters.from_kwargs(
cluster_id=["c00005", "c00000", "c00001", "c00002", "c00003", "c00004"],
vtheta_x=np.full(6, 0.0),
vtheta_y=np.full(6, 0.0),
arc_length=np.full(6, 0.0),
num_obs=np.full(6, 5),
)

cluster_members = ClusterMembers.from_kwargs(
cluster_id=["c00005", "c00000", "c00001", "c00002", "c00003", "c00004"],
obs_id=[
"obs_01",
"obs_02",
"obs_03",
"obs_04",
"obs_05",
"obs_06",
],
)

with pytest.raises(AssertionError):
drop_duplicate_clusters(clusters, cluster_members)

clusters = clusters.sort_by([("cluster_id", "ascending")])
cluster_members = cluster_members.sort_by([("cluster_id", "ascending")])

drop_duplicate_clusters(clusters, cluster_members)
