Skip to content

Commit

Permalink
cast as anypath on load and nowhere else
Browse files Browse the repository at this point in the history
  • Loading branch information
klwetstone committed Aug 15, 2023
1 parent 93a7611 commit ca4f28a
Show file tree
Hide file tree
Showing 5 changed files with 19 additions and 9 deletions.
6 changes: 6 additions & 0 deletions cyano/cli.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import yaml

from cloudpathlib import AnyPath
from pathlib import Path
import typer

Expand Down Expand Up @@ -35,6 +36,8 @@ def predict(
"""Load an existing cyanobacteria prediction model and generate
severity level predictions for a set of samples.
"""
samples_path = AnyPath(samples_path)

pipeline = CyanoModelPipeline.from_disk(model_zip)
pipeline.run_prediction(samples_path, output_path)

Expand All @@ -53,6 +56,9 @@ def evaluate(
default=Path.cwd() / "metrics", help="Folder in which to save out metrics and plots."
),
):
y_pred_csv = AnyPath(y_pred_csv)
y_true_csv = AnyPath(y_true_csv)

EvaluatePreds(
y_pred_csv=y_pred_csv, y_true_csv=y_true_csv, save_dir=save_dir
).calculate_all_and_save()
Expand Down
5 changes: 2 additions & 3 deletions cyano/evaluate.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import json
from pathlib import Path

from cloudpathlib import AnyPath
import pandas as pd
import lightgbm as lgb
import matplotlib.pyplot as plt
Expand Down Expand Up @@ -88,7 +87,7 @@ def __init__(self, y_true_csv: Path, y_pred_csv: Path, save_dir: Path, model: lg
"""
self.model = model

y_true_df = pd.read_csv(AnyPath(y_true_csv))
y_true_df = pd.read_csv(y_true_csv)

if "severity" not in y_true_df.columns:
raise ValueError("Evaluation data must include a `severity` column to evaluate.")
Expand All @@ -98,7 +97,7 @@ def __init__(self, y_true_csv: Path, y_pred_csv: Path, save_dir: Path, model: lg
self.y_true = y_true_df["severity"].rename("y_true")
self.metadata = y_true_df.drop(columns=["severity"])

y_pred_df = pd.read_csv(AnyPath(y_pred_csv)).set_index("sample_id")
y_pred_df = pd.read_csv(y_pred_csv).set_index("sample_id")

try:
self.y_pred = y_pred_df.loc[self.y_true.index]["severity"].rename("y_pred")
Expand Down
7 changes: 6 additions & 1 deletion cyano/experiment/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@
from typing import Union
import yaml

from cloudpathlib import AnyPath
from loguru import logger
from pydantic import BaseModel, ConfigDict, field_serializer
from pydantic import BaseModel, ConfigDict, field_serializer, field_validator

from cyano.config import FeaturesConfig, ModelTrainingConfig
from cyano.pipeline import CyanoModelPipeline
Expand All @@ -20,6 +21,10 @@ class ExperimentConfig(BaseModel):
debug: bool = False
num_processes: int = 4

@field_validator("train_csv", "predict_csv")
def convert_filepaths(cls, path_field):
return AnyPath(path_field)

# Avoid conflict with pydantic protected namespace
model_config = ConfigDict(protected_namespaces=())

Expand Down
8 changes: 4 additions & 4 deletions cyano/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import yaml
from zipfile import ZipFile

from cloudpathlib import AnyPath
import lightgbm as lgb
from loguru import logger
import pandas as pd
Expand Down Expand Up @@ -41,7 +40,7 @@ def __init__(

def _prep_train_data(self, data, debug: bool = False):
"""Load labels and save out samples with UIDs"""
labels = pd.read_csv(AnyPath(data))
labels = pd.read_csv(data)
labels = labels[["date", "latitude", "longitude", "severity"]]
labels = add_unique_identifier(labels)
if debug:
Expand Down Expand Up @@ -79,8 +78,9 @@ def _prepare_features(self, samples):
features = generate_features(samples, satellite_meta, self.features_config, self.cache_dir)
save_features_to = self.cache_dir / "features_train.csv"
features.to_csv(save_features_to, index=True)
pct_with_features = features.index.nunique() / samples.shape[0]
logger.success(
f"{features.shape[1]:,} features for {features.index.nunique():,} samples (of {samples.shape[0]:,}) saved to {save_features_to}"
f"{features.shape[1]:,} features for {features.index.nunique():,} samples ({pct_with_features:.0%}) saved to {save_features_to}"
)

return features
Expand Down Expand Up @@ -123,7 +123,7 @@ def from_disk(cls, filepath, cache_dir=None):
return cls(features_config=features_config, model=model, cache_dir=cache_dir)

def _prep_predict_data(self, data, debug: bool = False):
df = pd.read_csv(AnyPath(data))
df = pd.read_csv(data)
df = add_unique_identifier(df)

samples = df[["date", "latitude", "longitude"]]
Expand Down
2 changes: 1 addition & 1 deletion tests/assets/evaluate_data.csv
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
latitude,longitude,date,split,region,severity,density
latitude,longitude,date,split,region,severity,density_cells_per_ml
40.090275,-76.873132,2018-05-21,train,northeast,1,0.0
35.7200811863161,-79.1374207771809,2013-05-22,train,south,2,29046.0
35.6940254103693,-79.1858165585188,2016-10-18,train,south,1,94.0
Expand Down

0 comments on commit ca4f28a

Please sign in to comment.