diff --git a/zeno_build/reporting/visualize.py b/zeno_build/reporting/visualize.py index 8af66f3..d8c5a00 100644 --- a/zeno_build/reporting/visualize.py +++ b/zeno_build/reporting/visualize.py @@ -4,10 +4,14 @@ from typing import Any import pandas as pd -from zeno import ModelReturn, ZenoParameters, model, zeno - +from zeno import ModelReturn, ZenoParameters, model +from zeno.classes.base import ZenoColumnType from zeno_build.experiments.experiment_run import ExperimentRun +from zeno.backend import ZenoBackend +from zeno_client import ZenoClient, ZenoMetric +import os +import time def visualize( df: pd.DataFrame, @@ -17,6 +21,7 @@ def visualize( data_column: str, functions: list[Callable], zeno_config: dict = {}, + project_name:str = "" ) -> None: """Run Zeno to visualize the results of a parameter search run. @@ -33,6 +38,16 @@ def visualize( raise ValueError("Length of data and labels must be equal.") if data_column not in df.columns: raise ValueError(f"Data column {data_column} not in DataFrame.") + if "ZENO_CLIENT_KEY" not in os.environ: + raise ValueError( + "ZENO_CLIENT_KEY environment variable must be set to visualize results." + ) + + zeno_client = ZenoClient(os.environ["ZENO_CLIENT_KEY"]) + if project_name == "": + project_name = time.strftime("%Y-%m-%d %H-%M-%S") + print(f"Creating project {project_name}") + model_results: dict[str, ExperimentRun] = {x.name: x for x in results} @model @@ -61,4 +76,28 @@ def pred(df, ops): ) config = config.copy(update=zeno_config) - zeno(config) + zeno_instance = ZenoBackend(config) + + # wait for the backend to complete processing + zeno_thread = zeno_instance.start_processing() + zeno_thread.join() + + project = zeno_client.create_project( + name=project_name, + view=view, + metrics=[ZenoMetric(id=idx,name=metric, type='mean', columns=[metric[4:]]) for idx,metric in enumerate(zeno_instance.metric_functions.keys())]) + + # upload the dataset + dataset_columns = ["index", data_column, "label"] + dataset_df = zeno_instance.df[dataset_columns] + for predistill_function in zeno_instance.predistill_functions.keys(): + dataset_df[predistill_function] = zeno_instance.df[ZenoColumnType.PREDISTILL+predistill_function] + project.upload_dataset(dataset_df, id_column="index", data_column=data_column, label_column="label") + + # upload the model results + for model_name in model_results: + df_model = pd.DataFrame({"index": dataset_df["index"]}) + df_model["output"] = zeno_instance.df[ZenoColumnType.OUTPUT+"output" + model_name] + for postdistill_function in zeno_instance.postdistill_functions.keys(): + df_model[postdistill_function] = zeno_instance.df[ZenoColumnType.POSTDISTILL + postdistill_function + model_name] + project.upload_system(df_model, name=model_name, id_column="index", output_column="output")