From 7c09b2e124b2a201bfe9a913e145613776a10706 Mon Sep 17 00:00:00 2001 From: soumik12345 <19soumik.rakshit96@gmail.com> Date: Sat, 3 Aug 2024 11:36:56 +0000 Subject: [PATCH] rename: EvaluationPipeline.__call__ to EvaluationPipeline.evaluate --- hemm/eval_pipelines/eval_pipeline.py | 15 +++------------ 1 file changed, 3 insertions(+), 12 deletions(-) diff --git a/hemm/eval_pipelines/eval_pipeline.py b/hemm/eval_pipelines/eval_pipeline.py index 2d7033b..d60db82 100644 --- a/hemm/eval_pipelines/eval_pipeline.py +++ b/hemm/eval_pipelines/eval_pipeline.py @@ -1,7 +1,7 @@ import asyncio from abc import ABC from functools import partial -from typing import Any, Callable, Dict, List, Optional, Union +from typing import Any, Dict, List, Optional, Union import wandb import weave @@ -11,15 +11,6 @@ from .model import BaseDiffusionModel -def evaluation_wrapper(name: str) -> Callable[[Callable], Callable]: - def wrapper(fn: Callable) -> Callable: - op = weave.op()(fn) - op.name = name # type: ignore - return op - - return wrapper - - class EvaluationPipeline(ABC): """Evaluation pipeline to evaluate the a multi-modal generative model. @@ -125,8 +116,8 @@ def log_summary(self, summary: Dict[str, float]) -> None: } ) - @evaluation_wrapper("evaluate") - def __call__( + @weave.op() + def evaluate( self, dataset: Union[List[Dict], str], name: Optional[str] = None ) -> Dict[str, float]: """Evaluate the Stable Diffusion model on the given dataset.