diff --git a/hemm/eval_pipelines/eval_pipeline.py b/hemm/eval_pipelines/eval_pipeline.py
index 74ccba1..2d7033b 100644
--- a/hemm/eval_pipelines/eval_pipeline.py
+++ b/hemm/eval_pipelines/eval_pipeline.py
@@ -1,6 +1,7 @@
 import asyncio
 from abc import ABC
-from typing import Dict, List, Union
+from functools import partial
+from typing import Any, Callable, Dict, List, Optional, Union
 
 import wandb
 import weave
@@ -10,6 +11,15 @@
 from .model import BaseDiffusionModel
 
 
+def evaluation_wrapper(name: str) -> Callable[[Callable], Callable]:
+    def wrapper(fn: Callable) -> Callable:
+        op = weave.op()(fn)
+        op.name = name  # type: ignore
+        return op
+
+    return wrapper
+
+
 class EvaluationPipeline(ABC):
     """Evaluation pipeline to evaluate the a multi-modal generative model.
 
@@ -18,7 +28,12 @@ class EvaluationPipeline(ABC):
         seed (int): Seed value for the random number generator.
     """
 
-    def __init__(self, model: BaseDiffusionModel, seed: int = 42) -> None:
+    def __init__(
+        self,
+        model: BaseDiffusionModel,
+        seed: int = 42,
+        configs: Optional[Dict[str, Any]] = None,
+    ) -> None:
         super().__init__()
         self.model = model
         self.model.initialize()
@@ -42,6 +57,7 @@ def __init__(self, model: BaseDiffusionModel, seed: int = 42) -> None:
             },
             "seed": seed,
             "diffusion_pipeline": dict(self.model._pipeline.config),
+            **configs,
         }
 
     def add_metric(self, metric_fn: BaseMetric):
@@ -105,10 +121,14 @@ def log_summary(self, summary: Dict[str, float]) -> None:
             {
                 f"evalution/{self.model.diffusion_model_name_or_path}": self.evaluation_table,
                 f"summary/{self.model.diffusion_model_name_or_path}": summary_table,
+                "summary": summary,
             }
         )
 
-    def __call__(self, dataset: Union[List[Dict], str]) -> Dict[str, float]:
+    @evaluation_wrapper("evaluate")
+    def __call__(
+        self, dataset: Union[List[Dict], str], name: Optional[str] = None
+    ) -> Dict[str, float]:
         """Evaluate the Stable Diffusion model on the given dataset.
 
         Args:
@@ -117,10 +137,16 @@ def __call__(self, dataset: Union[List[Dict], str]) -> Dict[str, float]:
         """
         dataset = weave.ref(dataset).get() if isinstance(dataset, str) else dataset
         evaluation = weave.Evaluation(
+            name=name,
             dataset=dataset,
-            scorers=[metric_fn.evaluate_async for metric_fn in self.metric_functions],
+            scorers=[
+                partial(
+                    metric_fn.evaluate_async,
+                    metadata=weave.Model.model_validate(self.evaluation_configs),
+                )
+                for metric_fn in self.metric_functions
+            ],
         )
-        with weave.attributes(self.evaluation_configs):
-            summary = asyncio.run(evaluation.evaluate(self.infer_async))
+        summary = asyncio.run(evaluation.evaluate(self.infer_async))
        self.log_summary(summary)
        return summary
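The pipeline change above has two moving parts: `evaluation_wrapper`, which wraps `__call__` in a `weave.op` and assigns it an explicit name, and the `functools.partial` call that pre-binds `metadata=weave.Model.model_validate(self.evaluation_configs)` onto every scorer, so each metric receives the run configuration without it having to live in the dataset. The library-free sketch below illustrates only the binding pattern; `run_scorers`, `toy_scorer`, and the config dict are hypothetical stand-ins, not part of this diff.

# Sketch of the scorer-binding pattern: partial() fixes `metadata` up front,
# so the remaining signature still matches what the evaluator passes per row.
from functools import partial
from typing import Any, Callable, Dict, List


def toy_scorer(prompt: str, model_output: Dict[str, Any], metadata: Dict[str, Any]) -> Dict[str, Any]:
    # A real metric would score model_output["image"] against the prompt; here we
    # only show that the pre-bound configuration arrives alongside the row fields.
    return {"prompt_length": len(prompt), "seed": metadata["seed"]}


def run_scorers(rows: List[Dict[str, Any]], scorers: List[Callable[..., Dict[str, Any]]]) -> List[Dict[str, Any]]:
    # Mimics how an evaluator calls each scorer with per-row fields plus model_output.
    return [scorer(prompt=row["prompt"], model_output={"image": None}) for row in rows for scorer in scorers]


configs = {"seed": 42, "diffusion_pipeline": {"name": "example"}}
bound_scorers = [partial(toy_scorer, metadata=configs)]
print(run_scorers([{"prompt": "a red cube on a blue sphere"}], bound_scorers))
# -> [{'prompt_length': 27, 'seed': 42}]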
diff --git a/hemm/metrics/base.py b/hemm/metrics/base.py
index fc38fa0..891f051 100644
--- a/hemm/metrics/base.py
+++ b/hemm/metrics/base.py
@@ -1,6 +1,8 @@
 from abc import ABC, abstractmethod
 from typing import Any, Dict
 
+import weave
+
 
 class BaseMetric(ABC):
 
@@ -8,9 +10,9 @@ def __init__(self) -> None:
         super().__init__()
 
     @abstractmethod
-    def evaluate(self) -> Dict[str, Any]:
+    def evaluate(self, metadata: weave.Model) -> Dict[str, Any]:
         pass
 
     @abstractmethod
-    def evaluate_async(self) -> Dict[str, Any]:
+    def evaluate_async(self, metadata: weave.Model) -> Dict[str, Any]:
         pass
diff --git a/hemm/metrics/image_quality/base.py b/hemm/metrics/image_quality/base.py
index b9f614e..55c4875 100644
--- a/hemm/metrics/image_quality/base.py
+++ b/hemm/metrics/image_quality/base.py
@@ -3,6 +3,7 @@
 from io import BytesIO
 from typing import Any, Dict, Union
 
+import weave
 from PIL import Image
 from pydantic import BaseModel
 
@@ -50,7 +51,11 @@ def compute_metric(
         pass
 
     def evaluate(
-        self, prompt: str, ground_truth_image: str, model_output: Dict[str, Any]
+        self,
+        prompt: str,
+        ground_truth_image: str,
+        model_output: Dict[str, Any],
+        metadata: weave.Model,
     ) -> Dict[str, float]:
         """Compute the metric for the given images. This method is used as the scorer
         function for `weave.Evaluation` in the evaluation pipelines.
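With the base-class change, every metric's `evaluate` and `evaluate_async` now receive the pipeline configuration as `metadata: weave.Model`, whether or not they use it; the built-in metrics below either thread it through to the parent class or discard it with `_ = metadata`. A custom metric written against the new contract might look like the following sketch (the class name and scoring logic are hypothetical, and it assumes `model_output["image"]` is a PIL image as in the built-in metrics).

from typing import Any, Dict

import weave

from hemm.metrics.base import BaseMetric


class ImageAspectRatioMetric(BaseMetric):
    """Toy metric: reports whether the generated image is square."""

    @weave.op()
    def evaluate(
        self, prompt: str, model_output: Dict[str, Any], metadata: weave.Model
    ) -> Dict[str, float]:
        _ = prompt
        _ = metadata  # accepted to satisfy the new signature, unused here
        image = model_output["image"]
        return {"is_square": float(image.width == image.height)}

    @weave.op()
    async def evaluate_async(
        self, prompt: str, model_output: Dict[str, Any], metadata: weave.Model
    ) -> Dict[str, float]:
        return self.evaluate(prompt, model_output, metadata)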
diff --git a/hemm/metrics/image_quality/lpips.py b/hemm/metrics/image_quality/lpips.py
index daf5bcd..e19b552 100644
--- a/hemm/metrics/image_quality/lpips.py
+++ b/hemm/metrics/image_quality/lpips.py
@@ -71,14 +71,22 @@ def compute_metric(
 
     @weave.op()
     def evaluate(
-        self, prompt: str, ground_truth_image: str, model_output: Dict[str, Any]
+        self,
+        prompt: str,
+        ground_truth_image: str,
+        model_output: Dict[str, Any],
+        metadata: weave.Model,
     ) -> Union[float, Dict[str, float]]:
         _ = "LPIPSMetric"
-        return super().evaluate(prompt, ground_truth_image, model_output)
+        return super().evaluate(prompt, ground_truth_image, model_output, metadata)
 
     @weave.op()
     async def evaluate_async(
-        self, prompt: str, ground_truth_image: str, model_output: Dict[str, Any]
+        self,
+        prompt: str,
+        ground_truth_image: str,
+        model_output: Dict[str, Any],
+        metadata: weave.Model,
     ) -> Union[float, Dict[str, float]]:
         _ = "LPIPSMetric"
-        return self.evaluate(prompt, ground_truth_image, model_output)
+        return self.evaluate(prompt, ground_truth_image, model_output, metadata)
diff --git a/hemm/metrics/image_quality/psnr.py b/hemm/metrics/image_quality/psnr.py
index c0ac83c..0f17628 100644
--- a/hemm/metrics/image_quality/psnr.py
+++ b/hemm/metrics/image_quality/psnr.py
@@ -65,14 +65,22 @@ def compute_metric(
 
     @weave.op()
     def evaluate(
-        self, prompt: str, ground_truth_image: str, model_output: Dict[str, Any]
+        self,
+        prompt: str,
+        ground_truth_image: str,
+        model_output: Dict[str, Any],
+        metadata: weave.Model,
     ) -> Union[float, Dict[str, float]]:
         _ = "PSNRMetric"
-        return super().evaluate(prompt, ground_truth_image, model_output)
+        return super().evaluate(prompt, ground_truth_image, model_output, metadata)
 
     @weave.op()
     async def evaluate_async(
-        self, prompt: str, ground_truth_image: str, model_output: Dict[str, Any]
+        self,
+        prompt: str,
+        ground_truth_image: str,
+        model_output: Dict[str, Any],
+        metadata: weave.Model,
     ) -> Union[float, Dict[str, float]]:
         _ = "PSNRMetric"
-        return self.evaluate(prompt, ground_truth_image, model_output)
+        return self.evaluate(prompt, ground_truth_image, model_output, metadata)
diff --git a/hemm/metrics/image_quality/ssim.py b/hemm/metrics/image_quality/ssim.py
index 8468613..ea47cd3 100644
--- a/hemm/metrics/image_quality/ssim.py
+++ b/hemm/metrics/image_quality/ssim.py
@@ -92,14 +92,22 @@ def compute_metric(
 
     @weave.op()
     def evaluate(
-        self, prompt: str, ground_truth_image: str, model_output: Dict[str, Any]
+        self,
+        prompt: str,
+        ground_truth_image: str,
+        model_output: Dict[str, Any],
+        metadata: weave.Model,
     ) -> Union[float, Dict[str, float]]:
         _ = "SSIMMetric"
-        return super().evaluate(prompt, ground_truth_image, model_output)
+        return super().evaluate(prompt, ground_truth_image, model_output, metadata)
 
     @weave.op()
     async def evaluate_async(
-        self, prompt: str, ground_truth_image: str, model_output: Dict[str, Any]
+        self,
+        prompt: str,
+        ground_truth_image: str,
+        model_output: Dict[str, Any],
+        metadata: weave.Model,
     ) -> Union[float, Dict[str, float]]:
         _ = "SSIMMetric"
-        return self.evaluate(prompt, ground_truth_image, model_output)
+        return self.evaluate(prompt, ground_truth_image, model_output, metadata)
diff --git a/hemm/metrics/prompt_alignment/base.py b/hemm/metrics/prompt_alignment/base.py
index 3ba2852..fa344d7 100644
--- a/hemm/metrics/prompt_alignment/base.py
+++ b/hemm/metrics/prompt_alignment/base.py
@@ -3,6 +3,7 @@
 from io import BytesIO
 from typing import Any, Dict, Union
 
+import weave
 from PIL import Image
 
 from ..base import BaseMetric
@@ -37,7 +38,9 @@ def compute_metric(
         """
         pass
 
-    def evaluate(self, prompt: str, model_output: Dict[str, Any]) -> Dict[str, float]:
+    def evaluate(
+        self, prompt: str, model_output: Dict[str, Any], metadata: weave.Model
+    ) -> Dict[str, float]:
         """Compute the metric for the given image. This method is used as the scorer
         function for `weave.Evaluation` in the evaluation pipelines.
 
diff --git a/hemm/metrics/prompt_alignment/blip_score.py b/hemm/metrics/prompt_alignment/blip_score.py
index 0caf9c1..6e639d4 100644
--- a/hemm/metrics/prompt_alignment/blip_score.py
+++ b/hemm/metrics/prompt_alignment/blip_score.py
@@ -48,13 +48,15 @@ def compute_metric(
         )
 
     @weave.op()
-    def evaluate(self, prompt: str, model_output: Dict[str, Any]) -> Dict[str, float]:
+    def evaluate(
+        self, prompt: str, model_output: Dict[str, Any], metadata: weave.Model
+    ) -> Dict[str, float]:
         _ = "BLIPScoreMertric"
-        return super().evaluate(prompt, model_output)
+        return super().evaluate(prompt, model_output, metadata)
 
     @weave.op()
     async def evaluate_async(
-        self, prompt: str, model_output: Dict[str, Any]
+        self, prompt: str, model_output: Dict[str, Any], metadata: weave.Model
     ) -> Dict[str, float]:
         _ = "BLIPScoreMertric"
-        return self.evaluate(prompt, model_output)
+        return self.evaluate(prompt, model_output, metadata)
diff --git a/hemm/metrics/prompt_alignment/clip_iqa_score.py b/hemm/metrics/prompt_alignment/clip_iqa_score.py
index ba61660..de64c63 100644
--- a/hemm/metrics/prompt_alignment/clip_iqa_score.py
+++ b/hemm/metrics/prompt_alignment/clip_iqa_score.py
@@ -80,13 +80,15 @@ def compute_metric(
         return score_dict
 
     @weave.op()
-    def evaluate(self, prompt: str, model_output: Dict[str, Any]) -> Dict[str, float]:
+    def evaluate(
+        self, prompt: str, model_output: Dict[str, Any], metadata: weave.Model
+    ) -> Dict[str, float]:
         _ = "CLIPImageQualityScoreMetric"
-        return super().evaluate(prompt, model_output)
+        return super().evaluate(prompt, model_output, metadata)
 
     @weave.op()
     async def evaluate_async(
-        self, prompt: str, model_output: Dict[str, Any]
+        self, prompt: str, model_output: Dict[str, Any], metadata: weave.Model
     ) -> Dict[str, float]:
         _ = "CLIPImageQualityScoreMetric"
-        return self.evaluate(prompt, model_output)
+        return self.evaluate(prompt, model_output, metadata)
diff --git a/hemm/metrics/prompt_alignment/clip_score.py b/hemm/metrics/prompt_alignment/clip_score.py
index bb1f79e..62ccc6f 100644
--- a/hemm/metrics/prompt_alignment/clip_score.py
+++ b/hemm/metrics/prompt_alignment/clip_score.py
@@ -45,13 +45,15 @@ def compute_metric(
         )
 
     @weave.op()
-    def evaluate(self, prompt: str, model_output: Dict[str, Any]) -> Dict[str, float]:
+    def evaluate(
+        self, prompt: str, model_output: Dict[str, Any], metadata: weave.Model
+    ) -> Dict[str, float]:
         _ = "CLIPScoreMetric"
-        return super().evaluate(prompt, model_output)
+        return super().evaluate(prompt, model_output, metadata)
 
     @weave.op()
     async def evaluate_async(
-        self, prompt: str, model_output: Dict[str, Any]
+        self, prompt: str, model_output: Dict[str, Any], metadata: weave.Model
     ) -> Dict[str, float]:
         _ = "CLIPScoreMetric"
-        return self.evaluate(prompt, model_output)
+        return self.evaluate(prompt, model_output, metadata)
diff --git a/hemm/metrics/spatial_relationship/spatial_relationship_2d.py b/hemm/metrics/spatial_relationship/spatial_relationship_2d.py
index 73078d0..f516615 100644
--- a/hemm/metrics/spatial_relationship/spatial_relationship_2d.py
+++ b/hemm/metrics/spatial_relationship/spatial_relationship_2d.py
@@ -221,6 +221,7 @@ def evaluate(
         entity_2: str,
         relationship: str,
         model_output: Dict[str, Any],
+        metadata: weave.Model,
     ) -> Dict[str, Union[bool, float, int]]:
         """Calculate the spatial relationship score for the given prompt and model output.
 
@@ -235,6 +236,7 @@ def evaluate(
             Dict[str, Union[bool, float, int]]: The comprehensive spatial relationship judgement.
         """
         _ = prompt
+        _ = metadata
         image = model_output["image"]
 
         boxes: List[BoundingBox] = self.judge.predict(image)
@@ -251,5 +253,8 @@ async def evaluate_async(
         entity_2: str,
         relationship: str,
         model_output: Dict[str, Any],
+        metadata: weave.Model,
     ) -> Dict[str, Union[bool, float, int]]:
-        return self.evaluate(prompt, entity_1, entity_2, relationship, model_output)
+        return self.evaluate(
+            prompt, entity_1, entity_2, relationship, model_output, metadata
+        )
diff --git a/hemm/metrics/vqa/disentangled_vqa.py b/hemm/metrics/vqa/disentangled_vqa.py
index 2e8bbee..26426ae 100644
--- a/hemm/metrics/vqa/disentangled_vqa.py
+++ b/hemm/metrics/vqa/disentangled_vqa.py
@@ -63,6 +63,7 @@ def evaluate(
         adj_2: str,
         noun_2: str,
         model_output: Dict[str, Any],
+        metadata: weave.Model,
     ) -> Dict[str, Any]:
         """Evaluate the attribute-binding capability of the model.
 
@@ -78,6 +79,7 @@ def evaluate(
             Dict[str, Any]: The evaluation result.
         """
         _ = prompt
+        _ = metadata
         judgement = self.judge.predict(
             adj_1, noun_1, adj_2, noun_2, model_output["image"]
         )
@@ -93,5 +95,8 @@ async def evaluate_async(
         adj_2: str,
         noun_2: str,
         model_output: Dict[str, Any],
+        metadata: weave.Model,
     ) -> Dict[str, Any]:
-        return self.evaluate(prompt, adj_1, noun_1, adj_2, noun_2, model_output)
+        return self.evaluate(
+            prompt, adj_1, noun_1, adj_2, noun_2, model_output, metadata
+        )
diff --git a/hemm/metrics/vqa/multi_modal_llm_eval.py b/hemm/metrics/vqa/multi_modal_llm_eval.py
index 573d4ea..47ff1e3 100644
--- a/hemm/metrics/vqa/multi_modal_llm_eval.py
+++ b/hemm/metrics/vqa/multi_modal_llm_eval.py
@@ -29,7 +29,9 @@ def __init__(
         self.name = name
 
     @weave.op()
-    def evaluate(self, prompt: str, model_output: Dict[str, Any]) -> Dict[str, Any]:
+    def evaluate(
+        self, prompt: str, model_output: Dict[str, Any], metadata: weave.Model
+    ) -> Dict[str, Any]:
         """Evaluate the generated image using the judge LLM model.
 
         Args:
@@ -50,6 +52,6 @@ def evaluate(self, prompt: str, model_output: Dict[str, Any]) -> Dict[str, Any]:
 
     @weave.op()
     async def evaluate_async(
-        self, prompt: str, model_output: Dict[str, Any]
+        self, prompt: str, model_output: Dict[str, Any], metadata: weave.Model
     ) -> Dict[str, Any]:
-        return self.evaluate(prompt, model_output)
+        return self.evaluate(prompt, model_output, metadata)
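Taken together, the user-facing additions are the `configs` argument on `EvaluationPipeline.__init__` (merged into `evaluation_configs`) and the `name` argument on `__call__` (forwarded to `weave.Evaluation`), with every scorer receiving that configuration as `metadata`. A rough usage sketch follows; the exported class names, constructor arguments, and dataset contents are inferred from the file paths in this diff rather than confirmed by it.

import weave

from hemm.eval_pipelines.eval_pipeline import EvaluationPipeline
from hemm.eval_pipelines.model import BaseDiffusionModel
from hemm.metrics.prompt_alignment.clip_score import CLIPScoreMetric

weave.init(project_name="hemm-eval")

# Constructor arguments below are assumptions for illustration.
model = BaseDiffusionModel(diffusion_model_name_or_path="stabilityai/stable-diffusion-2-1")

# `configs` is unpacked directly into evaluation_configs, so extra run metadata
# (experiment tags, sampler settings, ...) is supplied here as a plain dict.
pipeline = EvaluationPipeline(model=model, seed=42, configs={"experiment": "baseline"})
pipeline.add_metric(CLIPScoreMetric())

# `name` is forwarded to weave.Evaluation; the dataset can still be an in-memory
# list of rows or a Weave dataset reference.
summary = pipeline(
    dataset=[{"prompt": "a red cube on a blue sphere"}], name="sd-2.1-baseline"
)
print(summary)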