Skip to content

Commit

Permalink
changing visualize function to host results of zeno client instead of…
Browse files Browse the repository at this point in the history
… local server
  • Loading branch information
amangupt01 authored Oct 30, 2023
1 parent 67f5316 commit 2b61914
Showing 1 changed file with 42 additions and 3 deletions.
45 changes: 42 additions & 3 deletions zeno_build/reporting/visualize.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,14 @@
from typing import Any

import pandas as pd
from zeno import ModelReturn, ZenoParameters, model, zeno

from zeno import ModelReturn, ZenoParameters, model
from zeno.classes.base import ZenoColumnType
from zeno_build.experiments.experiment_run import ExperimentRun
from zeno.backend import ZenoBackend
from zeno_client import ZenoClient, ZenoMetric

import os
import time

def visualize(
df: pd.DataFrame,
Expand All @@ -17,6 +21,7 @@ def visualize(
data_column: str,
functions: list[Callable],
zeno_config: dict = {},
project_name:str = ""
) -> None:
"""Run Zeno to visualize the results of a parameter search run.
Expand All @@ -33,6 +38,16 @@ def visualize(
raise ValueError("Length of data and labels must be equal.")
if data_column not in df.columns:
raise ValueError(f"Data column {data_column} not in DataFrame.")
if "ZENO_CLIENT_KEY" not in os.environ:
raise ValueError(
"ZENO_CLIENT_KEY environment variable must be set to visualize results."
)

zeno_client = ZenoClient(os.environ["ZENO_CLIENT_KEY"])
if project_name == "":
project_name = time.strftime("%Y-%m-%d %H-%M-%S")
print(f"Creating project {project_name}")

model_results: dict[str, ExperimentRun] = {x.name: x for x in results}

@model
Expand Down Expand Up @@ -61,4 +76,28 @@ def pred(df, ops):
)
config = config.copy(update=zeno_config)

zeno(config)
zeno_instance = ZenoBackend(config)

# wait for the backend to complete processing
zeno_thread = zeno_instance.start_processing()
zeno_thread.join()

project = zeno_client.create_project(
name=project_name,
view=view,
metrics=[ZenoMetric(id=idx,name=metric, type='mean', columns=[metric[4:]]) for idx,metric in enumerate(zeno_instance.metric_functions.keys())])

# upload the dataset
dataset_columns = ["index", data_column, "label"]
dataset_df = zeno_instance.df[dataset_columns]
for predistill_function in zeno_instance.predistill_functions.keys():
dataset_df[predistill_function] = zeno_instance.df[ZenoColumnType.PREDISTILL+predistill_function]
project.upload_dataset(dataset_df, id_column="index", data_column=data_column, label_column="label")

# upload the model results
for model_name in model_results:
df_model = pd.DataFrame({"index": dataset_df["index"]})
df_model["output"] = zeno_instance.df[ZenoColumnType.OUTPUT+"output" + model_name]
for postdistill_function in zeno_instance.postdistill_functions.keys():
df_model[postdistill_function] = zeno_instance.df[ZenoColumnType.POSTDISTILL + postdistill_function + model_name]
project.upload_system(df_model, name=model_name, id_column="index", output_column="output")

0 comments on commit 2b61914

Please sign in to comment.