Skip to content

Commit

Permalink
add: additional parameter evaluation_in_async to EvaluationPipeline.__call__
Browse files Browse the repository at this point in the history
  • Loading branch information
soumik12345 committed Oct 7, 2024
1 parent ed76a04 commit 3765b2b
Showing 1 changed file with 10 additions and 3 deletions.
13 changes: 10 additions & 3 deletions hemm/eval_pipelines/eval_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,10 @@
from abc import ABC
from typing import Dict, List, Union

import wandb
import weave

import wandb

from ..metrics.base import BaseMetric
from .model import BaseDiffusionModel, FalDiffusionModel, StabilityAPIModel

Expand Down Expand Up @@ -120,17 +121,23 @@ def log_summary(self, summary: Dict[str, float]) -> None:
}
)

def __call__(self, dataset: Union[List[Dict], str]) -> Dict[str, float]:
def __call__(
self, dataset: Union[List[Dict], str], evaluation_in_async: bool = True
) -> Dict[str, float]:
"""Evaluate the Stable Diffusion model on the given dataset.
Args:
dataset (Union[List[Dict], str]): Dataset to evaluate the model on. If a string is
passed, it is assumed to be a Weave dataset reference.
evaluation_in_async (bool): Whether to evaluate the metrics in async mode.
"""
dataset = weave.ref(dataset).get() if isinstance(dataset, str) else dataset
evaluation = weave.Evaluation(
dataset=dataset,
scorers=[metric_fn.evaluate_async for metric_fn in self.metric_functions],
scorers=[
metric_fn.evaluate_async if evaluation_in_async else metric_fn.evaluate
for metric_fn in self.metric_functions
],
)
self.model.configs.update(self.evaluation_configs)
summary = asyncio.run(evaluation.evaluate(self.infer_async))
Expand Down

0 comments on commit 3765b2b

Please sign in to comment.