Skip to content

Commit

Permalink
rename: EvaluationPipeline.__call__ to EvaluationPipeline.evaluate
Browse files Browse the repository at this point in the history
  • Loading branch information
soumik12345 committed Aug 3, 2024
1 parent fed51c3 commit 7c09b2e
Showing 1 changed file with 3 additions and 12 deletions.
15 changes: 3 additions & 12 deletions hemm/eval_pipelines/eval_pipeline.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import asyncio
from abc import ABC
from functools import partial
from typing import Any, Callable, Dict, List, Optional, Union
from typing import Any, Dict, List, Optional, Union

import wandb
import weave
Expand All @@ -11,15 +11,6 @@
from .model import BaseDiffusionModel


def evaluation_wrapper(name: str) -> Callable[[Callable], Callable]:
def wrapper(fn: Callable) -> Callable:
op = weave.op()(fn)
op.name = name # type: ignore
return op

return wrapper


class EvaluationPipeline(ABC):
"""Evaluation pipeline to evaluate the a multi-modal generative model.
Expand Down Expand Up @@ -125,8 +116,8 @@ def log_summary(self, summary: Dict[str, float]) -> None:
}
)

@evaluation_wrapper("evaluate")
def __call__(
@weave.op()
def evaluate(
self, dataset: Union[List[Dict], str], name: Optional[str] = None
) -> Dict[str, float]:
"""Evaluate the Stable Diffusion model on the given dataset.
Expand Down

0 comments on commit 7c09b2e

Please sign in to comment.