Skip to content

Commit

Permalink
Merge pull request #129 from moeyensj/v2.0-observations-with-states
Browse files Browse the repository at this point in the history
Add InputObservations
  • Loading branch information
moeyensj authored Nov 20, 2023
2 parents dc73b7e + a978bec commit 249a81c
Show file tree
Hide file tree
Showing 20 changed files with 639 additions and 453 deletions.
1 change: 1 addition & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ repos:
rev: v1.1.1
hooks:
- id: mypy
exclude: bench/
additional_dependencies:
- 'types-pyyaml'
- 'types-requests'
Expand Down
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ install_requires =
numba
pandas
pyarrow >= 14.0.0
pydantic
pydantic < 2.0.0
pyyaml >= 5.1
quivr @ git+https://github.com/moeyensj/quivr@concatenate-empty-attributes
ray[default]
Expand Down
273 changes: 273 additions & 0 deletions thor/checkpointing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,273 @@
import logging
import pathlib
from typing import Annotated, Dict, Literal, Optional, Type, Union

import pydantic
import quivr as qv
import ray

from thor.clusters import ClusterMembers, Clusters
from thor.observations.observations import Observations
from thor.orbit_determination.fitted_orbits import FittedOrbitMembers, FittedOrbits
from thor.range_and_transform import TransformedDetections

logger = logging.getLogger("thor")


# The pipeline stages, listed in execution order; "complete" marks a run
# whose final recovered orbits already exist on disk.
VALID_STAGES = Literal[
    "filter_observations",
    "range_and_transform",
    "cluster_and_link",
    "initial_orbit_determination",
    "differential_correction",
    "recover_orbits",
    "complete",
]


class FilterObservations(pydantic.BaseModel):
    """Checkpoint for the first stage: no prior outputs are required."""

    stage: Literal["filter_observations"]


class RangeAndTransform(pydantic.BaseModel):
    """Checkpoint inputs for the range-and-transform stage."""

    class Config:
        # Required so non-pydantic field types (quivr tables, ray object
        # references) are accepted without validators.
        arbitrary_types_allowed = True

    stage: Literal["range_and_transform"]
    # May be a ray.ObjectRef when the table lives in the ray object store.
    filtered_observations: Union[Observations, ray.ObjectRef]


class ClusterAndLink(pydantic.BaseModel):
    """Checkpoint inputs for the cluster-and-link stage."""

    class Config:
        # Required so non-pydantic field types (quivr tables, ray object
        # references) are accepted without validators.
        arbitrary_types_allowed = True

    stage: Literal["cluster_and_link"]
    # May be a ray.ObjectRef when the table lives in the ray object store.
    filtered_observations: Union[Observations, ray.ObjectRef]
    transformed_detections: TransformedDetections


class InitialOrbitDetermination(pydantic.BaseModel):
    """Checkpoint inputs for the initial orbit determination (IOD) stage."""

    class Config:
        # Required so non-pydantic field types (quivr tables) are accepted
        # without validators.
        arbitrary_types_allowed = True

    stage: Literal["initial_orbit_determination"]
    filtered_observations: Observations
    clusters: Clusters
    cluster_members: ClusterMembers


class DifferentialCorrection(pydantic.BaseModel):
    """Checkpoint inputs for the differential correction (OD) stage."""

    class Config:
        # Required so non-pydantic field types (quivr tables) are accepted
        # without validators.
        arbitrary_types_allowed = True

    stage: Literal["differential_correction"]
    filtered_observations: Observations
    # Orbits and their members produced by initial orbit determination.
    iod_orbits: FittedOrbits
    iod_orbit_members: FittedOrbitMembers


class RecoverOrbits(pydantic.BaseModel):
    """Checkpoint inputs for the orbit recovery (merge/extend) stage."""

    class Config:
        # Required so non-pydantic field types (quivr tables) are accepted
        # without validators.
        arbitrary_types_allowed = True

    stage: Literal["recover_orbits"]
    filtered_observations: Observations
    # Orbits and their members produced by differential correction.
    od_orbits: FittedOrbits
    od_orbit_members: FittedOrbitMembers


class Complete(pydantic.BaseModel):
    """Checkpoint for a finished run: the final recovered orbits exist."""

    class Config:
        # Required so non-pydantic field types (quivr tables) are accepted
        # without validators.
        arbitrary_types_allowed = True

    stage: Literal["complete"]
    recovered_orbits: FittedOrbits
    recovered_orbit_members: FittedOrbitMembers


# Discriminated union over all stage models: pydantic picks the concrete
# model from the value of the "stage" field.
CheckpointData = Annotated[
    Union[
        FilterObservations,
        RangeAndTransform,
        ClusterAndLink,
        InitialOrbitDetermination,
        DifferentialCorrection,
        RecoverOrbits,
        Complete,
    ],
    pydantic.Field(discriminator="stage"),
]

# A mapping from stage name to its checkpoint model class; used by
# create_checkpoint_data to instantiate the right model.
stage_to_model: Dict[str, Type[pydantic.BaseModel]] = {
    "filter_observations": FilterObservations,
    "range_and_transform": RangeAndTransform,
    "cluster_and_link": ClusterAndLink,
    "initial_orbit_determination": InitialOrbitDetermination,
    "differential_correction": DifferentialCorrection,
    "recover_orbits": RecoverOrbits,
    "complete": Complete,
}


def create_checkpoint_data(stage: VALID_STAGES, **data) -> CheckpointData:
    """
    Build the checkpoint model for ``stage``, populated from ``data``.

    Raises
    ------
    ValueError
        If ``stage`` is not one of the known pipeline stages.
    """
    # Guard clause: reject unknown stages before attempting construction.
    model = stage_to_model.get(stage)
    if model is None:
        raise ValueError(f"Invalid stage: {stage}")
    return model(stage=stage, **data)


def load_initial_checkpoint_values(
    test_orbit_directory: Optional[pathlib.Path] = None,
) -> CheckpointData:
    """
    Determine the latest completed pipeline stage by probing for that stage's
    output parquet files on disk, and return checkpoint data carrying the
    inputs needed to resume from the following stage.

    Stages are probed from latest ("complete") to earliest so that only the
    objects required for the resume point are loaded into memory.

    Parameters
    ----------
    test_orbit_directory
        Directory holding the per-test-orbit checkpoint parquet files. If
        None, checkpointing is disabled and the pipeline starts from the
        beginning ("filter_observations").

    Returns
    -------
    CheckpointData
        One of the stage models defined in this module, populated with the
        data recovered from disk.
    """
    stage: VALID_STAGES = "filter_observations"
    # Without a checkpoint directory, we always start at the beginning
    if test_orbit_directory is None:
        return create_checkpoint_data(stage)

    # filtered_observations is always needed when it exists
    filtered_observations_path = pathlib.Path(
        test_orbit_directory, "filtered_observations.parquet"
    )
    # If it doesn't exist, start at the beginning.
    if not filtered_observations_path.exists():
        return create_checkpoint_data(stage)
    logger.info("Found filtered observations")
    filtered_observations = Observations.from_parquet(filtered_observations_path)

    # Unfortunately we have to reinitialize the times to set the attribute
    # correctly.
    filtered_observations = qv.defragment(filtered_observations)
    filtered_observations = filtered_observations.sort_by(
        [
            "coordinates.time.days",
            "coordinates.time.nanos",
            "coordinates.origin.code",
        ]
    )

    # If the pipeline was started but we have recovered_orbits already, we
    # are done and should exit early.
    recovered_orbits_path = pathlib.Path(
        test_orbit_directory, "recovered_orbits.parquet"
    )
    recovered_orbit_members_path = pathlib.Path(
        test_orbit_directory, "recovered_orbit_members.parquet"
    )
    # Both files must be present for the stage to count as complete.
    if recovered_orbits_path.exists() and recovered_orbit_members_path.exists():
        logger.info("Found recovered orbits in checkpoint")
        recovered_orbits = FittedOrbits.from_parquet(recovered_orbits_path)
        recovered_orbit_members = FittedOrbitMembers.from_parquet(
            recovered_orbit_members_path
        )

        # Unfortunately we have to reinitialize the times to set the attribute
        # correctly.
        recovered_orbits = qv.defragment(recovered_orbits)
        recovered_orbits = recovered_orbits.sort_by(
            [
                "coordinates.time.days",
                "coordinates.time.nanos",
            ]
        )

        return create_checkpoint_data(
            "complete",
            recovered_orbits=recovered_orbits,
            recovered_orbit_members=recovered_orbit_members,
        )

    # Now with filtered_observations available, we can check for the later
    # stages in reverse order.
    od_orbits_path = pathlib.Path(test_orbit_directory, "od_orbits.parquet")
    od_orbit_members_path = pathlib.Path(
        test_orbit_directory, "od_orbit_members.parquet"
    )
    if od_orbits_path.exists() and od_orbit_members_path.exists():
        logger.info("Found OD orbits in checkpoint")
        od_orbits = FittedOrbits.from_parquet(od_orbits_path)
        od_orbit_members = FittedOrbitMembers.from_parquet(od_orbit_members_path)

        # Unfortunately we have to reinitialize the times to set the attribute
        # correctly.
        od_orbits = qv.defragment(od_orbits)
        od_orbits = od_orbits.sort_by(
            [
                "coordinates.time.days",
                "coordinates.time.nanos",
            ]
        )

        # OD outputs exist, so resume at the orbit recovery stage.
        return create_checkpoint_data(
            "recover_orbits",
            filtered_observations=filtered_observations,
            od_orbits=od_orbits,
            od_orbit_members=od_orbit_members,
        )

    iod_orbits_path = pathlib.Path(test_orbit_directory, "iod_orbits.parquet")
    iod_orbit_members_path = pathlib.Path(
        test_orbit_directory, "iod_orbit_members.parquet"
    )
    if iod_orbits_path.exists() and iod_orbit_members_path.exists():
        logger.info("Found IOD orbits")
        iod_orbits = FittedOrbits.from_parquet(iod_orbits_path)
        iod_orbit_members = FittedOrbitMembers.from_parquet(iod_orbit_members_path)

        # Unfortunately we have to reinitialize the times to set the attribute
        # correctly.
        iod_orbits = qv.defragment(iod_orbits)
        iod_orbits = iod_orbits.sort_by(
            [
                "coordinates.time.days",
                "coordinates.time.nanos",
            ]
        )

        # IOD outputs exist, so resume at the differential correction stage.
        return create_checkpoint_data(
            "differential_correction",
            filtered_observations=filtered_observations,
            iod_orbits=iod_orbits,
            iod_orbit_members=iod_orbit_members,
        )

    clusters_path = pathlib.Path(test_orbit_directory, "clusters.parquet")
    cluster_members_path = pathlib.Path(test_orbit_directory, "cluster_members.parquet")
    if clusters_path.exists() and cluster_members_path.exists():
        logger.info("Found clusters")
        clusters = Clusters.from_parquet(clusters_path)
        cluster_members = ClusterMembers.from_parquet(cluster_members_path)

        # Clusters exist, so resume at initial orbit determination.
        return create_checkpoint_data(
            "initial_orbit_determination",
            filtered_observations=filtered_observations,
            clusters=clusters,
            cluster_members=cluster_members,
        )

    transformed_detections_path = pathlib.Path(
        test_orbit_directory, "transformed_detections.parquet"
    )
    if transformed_detections_path.exists():
        logger.info("Found transformed detections")
        transformed_detections = TransformedDetections.from_parquet(
            transformed_detections_path
        )

        # Transformed detections exist, so resume at cluster-and-link.
        return create_checkpoint_data(
            "cluster_and_link",
            filtered_observations=filtered_observations,
            transformed_detections=transformed_detections,
        )

    # Only filtered observations exist: resume at range-and-transform.
    return create_checkpoint_data(
        "range_and_transform", filtered_observations=filtered_observations
    )
3 changes: 1 addition & 2 deletions thor/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,7 @@

class Config(BaseModel):
max_processes: Optional[int] = None
propagator: str = "PYOORB"
parallel_backend: Literal["cf"] = "cf"
propagator: Literal["PYOORB"] = "PYOORB"
cell_radius: float = 10
vx_min: float = -0.1
vx_max: float = 0.1
Expand Down
Loading

0 comments on commit 249a81c

Please sign in to comment.