update: metadata management for metrics using weave.Model
soumik12345 committed Aug 3, 2024
1 parent 482c029 commit fed51c3
Showing 13 changed files with 117 additions and 39 deletions.
38 changes: 32 additions & 6 deletions hemm/eval_pipelines/eval_pipeline.py
@@ -1,6 +1,7 @@
import asyncio
from abc import ABC
from typing import Dict, List, Union
from functools import partial
from typing import Any, Callable, Dict, List, Optional, Union

import wandb
import weave
@@ -10,6 +11,15 @@
from .model import BaseDiffusionModel


def evaluation_wrapper(name: str) -> Callable[[Callable], Callable]:
def wrapper(fn: Callable) -> Callable:
op = weave.op()(fn)
op.name = name # type: ignore
return op

return wrapper


class EvaluationPipeline(ABC):
"""Evaluation pipeline to evaluate the a multi-modal generative model.
@@ -18,7 +28,12 @@ class EvaluationPipeline(ABC):
seed (int): Seed value for the random number generator.
"""

def __init__(self, model: BaseDiffusionModel, seed: int = 42) -> None:
def __init__(
self,
model: BaseDiffusionModel,
seed: int = 42,
configs: Optional[Dict[str, Any]] = None,
) -> None:
super().__init__()
self.model = model
self.model.initialize()
@@ -42,6 +57,7 @@ def __init__(self, model: BaseDiffusionModel, seed: int = 42) -> None:
},
"seed": seed,
"diffusion_pipeline": dict(self.model._pipeline.config),
**configs,
}

def add_metric(self, metric_fn: BaseMetric):
@@ -105,10 +121,14 @@ def log_summary(self, summary: Dict[str, float]) -> None:
{
f"evalution/{self.model.diffusion_model_name_or_path}": self.evaluation_table,
f"summary/{self.model.diffusion_model_name_or_path}": summary_table,
"summary": summary,
}
)

def __call__(self, dataset: Union[List[Dict], str]) -> Dict[str, float]:
@evaluation_wrapper("evaluate")
def __call__(
self, dataset: Union[List[Dict], str], name: Optional[str] = None
) -> Dict[str, float]:
"""Evaluate the Stable Diffusion model on the given dataset.
Args:
@@ -117,10 +137,16 @@ def __call__(self, dataset: Union[List[Dict], str]) -> Dict[str, float]:
"""
dataset = weave.ref(dataset).get() if isinstance(dataset, str) else dataset
evaluation = weave.Evaluation(
name=name,
dataset=dataset,
scorers=[metric_fn.evaluate_async for metric_fn in self.metric_functions],
scorers=[
partial(
metric_fn.evaluate_async,
metadata=weave.Model.model_validate(self.evaluation_configs),
)
for metric_fn in self.metric_functions
],
)
with weave.attributes(self.evaluation_configs):
summary = asyncio.run(evaluation.evaluate(self.infer_async))
summary = asyncio.run(evaluation.evaluate(self.infer_async))
self.log_summary(summary)
return summary
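
To see the pipeline-level changes end to end, here is a minimal usage sketch, not part of the commit. It assumes that BaseDiffusionModel accepts diffusion_model_name_or_path as a constructor field, that the class in clip_score.py is named CLIPScoreMetric (as the string sentinel in its ops suggests) and works with its default constructor arguments, and that the project name and dataset reference are placeholders. With this commit, __call__ is published as a weave op named "evaluate" via evaluation_wrapper, accepts an optional evaluation name, and the extra configs dict is merged into evaluation_configs and handed to every scorer as a weave.Model built with model_validate.

import weave

from hemm.eval_pipelines.eval_pipeline import EvaluationPipeline
from hemm.eval_pipelines.model import BaseDiffusionModel
from hemm.metrics.prompt_alignment.clip_score import CLIPScoreMetric

weave.init(project_name="hemm-eval")  # hypothetical project name

# Assumption: diffusion_model_name_or_path is a constructor field of BaseDiffusionModel.
model = BaseDiffusionModel(diffusion_model_name_or_path="stabilityai/stable-diffusion-2-1")

# Always pass a configs dict: the new **configs spread in __init__ is applied
# unconditionally, so the default of None would not work here.
pipeline = EvaluationPipeline(model=model, seed=42, configs={"experiment": "metadata-demo"})
pipeline.add_metric(CLIPScoreMetric())  # assumption: default constructor arguments suffice

# __call__ is now a weave op named "evaluate" and takes an optional evaluation name;
# the dataset string below is a placeholder for a published weave dataset reference.
summary = pipeline(dataset="parti-prompts:v0", name="metadata-demo-run")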
6 changes: 4 additions & 2 deletions hemm/metrics/base.py
@@ -1,16 +1,18 @@
from abc import ABC, abstractmethod
from typing import Any, Dict

import weave


class BaseMetric(ABC):

def __init__(self) -> None:
super().__init__()

@abstractmethod
def evaluate(self) -> Dict[str, Any]:
def evaluate(self, metadata: weave.Model) -> Dict[str, Any]:
pass

@abstractmethod
def evaluate_async(self) -> Dict[str, Any]:
def evaluate_async(self, metadata: weave.Model) -> Dict[str, Any]:
pass
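
For a concrete picture of the updated abstract interface, the sketch below defines a hypothetical metric, not part of the commit, against the new signatures. It follows the conventions visible in the existing metrics, where unused scorer arguments are discarded with throwaway assignments, and it assumes model_output["image"] is a PIL image, as elsewhere in the repository.

from typing import Any, Dict

import weave

from hemm.metrics.base import BaseMetric


class ImageResolutionMetric(BaseMetric):
    """Hypothetical metric that reports the resolution of the generated image."""

    @weave.op()
    def evaluate(
        self, prompt: str, model_output: Dict[str, Any], metadata: weave.Model
    ) -> Dict[str, Any]:
        _ = prompt
        _ = metadata  # captured on the weave call; this toy metric does not use it
        width, height = model_output["image"].size  # assumption: a PIL image
        return {"width": width, "height": height}

    @weave.op()
    async def evaluate_async(
        self, prompt: str, model_output: Dict[str, Any], metadata: weave.Model
    ) -> Dict[str, Any]:
        return self.evaluate(prompt, model_output, metadata)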
7 changes: 6 additions & 1 deletion hemm/metrics/image_quality/base.py
@@ -3,6 +3,7 @@
from io import BytesIO
from typing import Any, Dict, Union

import weave
from PIL import Image
from pydantic import BaseModel

@@ -50,7 +51,11 @@ def compute_metric(
pass

def evaluate(
self, prompt: str, ground_truth_image: str, model_output: Dict[str, Any]
self,
prompt: str,
ground_truth_image: str,
model_output: Dict[str, Any],
metadata: weave.Model,
) -> Dict[str, float]:
"""Compute the metric for the given images. This method is used as the scorer
function for `weave.Evaluation` in the evaluation pipelines.
16 changes: 12 additions & 4 deletions hemm/metrics/image_quality/lpips.py
@@ -71,14 +71,22 @@ def compute_metric(

@weave.op()
def evaluate(
self, prompt: str, ground_truth_image: str, model_output: Dict[str, Any]
self,
prompt: str,
ground_truth_image: str,
model_output: Dict[str, Any],
metadata: weave.Model,
) -> Union[float, Dict[str, float]]:
_ = "LPIPSMetric"
return super().evaluate(prompt, ground_truth_image, model_output)
return super().evaluate(prompt, ground_truth_image, model_output, metadata)

@weave.op()
async def evaluate_async(
self, prompt: str, ground_truth_image: str, model_output: Dict[str, Any]
self,
prompt: str,
ground_truth_image: str,
model_output: Dict[str, Any],
metadata: weave.Model,
) -> Union[float, Dict[str, float]]:
_ = "LPIPSMetric"
return self.evaluate(prompt, ground_truth_image, model_output)
return self.evaluate(prompt, ground_truth_image, model_output, metadata)
16 changes: 12 additions & 4 deletions hemm/metrics/image_quality/psnr.py
@@ -65,14 +65,22 @@ def compute_metric(

@weave.op()
def evaluate(
self, prompt: str, ground_truth_image: str, model_output: Dict[str, Any]
self,
prompt: str,
ground_truth_image: str,
model_output: Dict[str, Any],
metadata: weave.Model,
) -> Union[float, Dict[str, float]]:
_ = "PSNRMetric"
return super().evaluate(prompt, ground_truth_image, model_output)
return super().evaluate(prompt, ground_truth_image, model_output, metadata)

@weave.op()
async def evaluate_async(
self, prompt: str, ground_truth_image: str, model_output: Dict[str, Any]
self,
prompt: str,
ground_truth_image: str,
model_output: Dict[str, Any],
metadata: weave.Model,
) -> Union[float, Dict[str, float]]:
_ = "PSNRMetric"
return self.evaluate(prompt, ground_truth_image, model_output)
return self.evaluate(prompt, ground_truth_image, model_output, metadata)
16 changes: 12 additions & 4 deletions hemm/metrics/image_quality/ssim.py
@@ -92,14 +92,22 @@ def compute_metric(

@weave.op()
def evaluate(
self, prompt: str, ground_truth_image: str, model_output: Dict[str, Any]
self,
prompt: str,
ground_truth_image: str,
model_output: Dict[str, Any],
metadata: weave.Model,
) -> Union[float, Dict[str, float]]:
_ = "SSIMMetric"
return super().evaluate(prompt, ground_truth_image, model_output)
return super().evaluate(prompt, ground_truth_image, model_output, metadata)

@weave.op()
async def evaluate_async(
self, prompt: str, ground_truth_image: str, model_output: Dict[str, Any]
self,
prompt: str,
ground_truth_image: str,
model_output: Dict[str, Any],
metadata: weave.Model,
) -> Union[float, Dict[str, float]]:
_ = "SSIMMetric"
return self.evaluate(prompt, ground_truth_image, model_output)
return self.evaluate(prompt, ground_truth_image, model_output, metadata)
5 changes: 4 additions & 1 deletion hemm/metrics/prompt_alignment/base.py
@@ -3,6 +3,7 @@
from io import BytesIO
from typing import Any, Dict, Union

import weave
from PIL import Image

from ..base import BaseMetric
@@ -37,7 +38,9 @@ def compute_metric(
"""
pass

def evaluate(self, prompt: str, model_output: Dict[str, Any]) -> Dict[str, float]:
def evaluate(
self, prompt: str, model_output: Dict[str, Any], metadata: weave.Model
) -> Dict[str, float]:
"""Compute the metric for the given image. This method is used as the scorer
function for `weave.Evaluation` in the evaluation pipelines.
10 changes: 6 additions & 4 deletions hemm/metrics/prompt_alignment/blip_score.py
@@ -48,13 +48,15 @@ def compute_metric(
)

@weave.op()
def evaluate(self, prompt: str, model_output: Dict[str, Any]) -> Dict[str, float]:
def evaluate(
self, prompt: str, model_output: Dict[str, Any], metadata: weave.Model
) -> Dict[str, float]:
_ = "BLIPScoreMertric"
return super().evaluate(prompt, model_output)
return super().evaluate(prompt, model_output, metadata)

@weave.op()
async def evaluate_async(
self, prompt: str, model_output: Dict[str, Any]
self, prompt: str, model_output: Dict[str, Any], metadata: weave.Model
) -> Dict[str, float]:
_ = "BLIPScoreMertric"
return self.evaluate(prompt, model_output)
return self.evaluate(prompt, model_output, metadata)
10 changes: 6 additions & 4 deletions hemm/metrics/prompt_alignment/clip_iqa_score.py
@@ -80,13 +80,15 @@ def compute_metric(
return score_dict

@weave.op()
def evaluate(self, prompt: str, model_output: Dict[str, Any]) -> Dict[str, float]:
def evaluate(
self, prompt: str, model_output: Dict[str, Any], metadata: weave.Model
) -> Dict[str, float]:
_ = "CLIPImageQualityScoreMetric"
return super().evaluate(prompt, model_output)
return super().evaluate(prompt, model_output, metadata)

@weave.op()
async def evaluate_async(
self, prompt: str, model_output: Dict[str, Any]
self, prompt: str, model_output: Dict[str, Any], metadata: weave.Model
) -> Dict[str, float]:
_ = "CLIPImageQualityScoreMetric"
return self.evaluate(prompt, model_output)
return self.evaluate(prompt, model_output, metadata)
10 changes: 6 additions & 4 deletions hemm/metrics/prompt_alignment/clip_score.py
@@ -45,13 +45,15 @@ def compute_metric(
)

@weave.op()
def evaluate(self, prompt: str, model_output: Dict[str, Any]) -> Dict[str, float]:
def evaluate(
self, prompt: str, model_output: Dict[str, Any], metadata: weave.Model
) -> Dict[str, float]:
_ = "CLIPScoreMetric"
return super().evaluate(prompt, model_output)
return super().evaluate(prompt, model_output, metadata)

@weave.op()
async def evaluate_async(
self, prompt: str, model_output: Dict[str, Any]
self, prompt: str, model_output: Dict[str, Any], metadata: weave.Model
) -> Dict[str, float]:
_ = "CLIPScoreMetric"
return self.evaluate(prompt, model_output)
return self.evaluate(prompt, model_output, metadata)
7 changes: 6 additions & 1 deletion hemm/metrics/spatial_relationship/spatial_relationship_2d.py
@@ -221,6 +221,7 @@ def evaluate(
entity_2: str,
relationship: str,
model_output: Dict[str, Any],
metadata: weave.Model,
) -> Dict[str, Union[bool, float, int]]:
"""Calculate the spatial relationship score for the given prompt and model output.
@@ -235,6 +236,7 @@
Dict[str, Union[bool, float, int]]: The comprehensive spatial relationship judgement.
"""
_ = prompt
_ = metadata

image = model_output["image"]
boxes: List[BoundingBox] = self.judge.predict(image)
@@ -251,5 +253,8 @@ async def evaluate_async(
entity_2: str,
relationship: str,
model_output: Dict[str, Any],
metadata: weave.Model,
) -> Dict[str, Union[bool, float, int]]:
return self.evaluate(prompt, entity_1, entity_2, relationship, model_output)
return self.evaluate(
prompt, entity_1, entity_2, relationship, model_output, metadata
)
7 changes: 6 additions & 1 deletion hemm/metrics/vqa/disentangled_vqa.py
@@ -63,6 +63,7 @@ def evaluate(
adj_2: str,
noun_2: str,
model_output: Dict[str, Any],
metadata: weave.Model,
) -> Dict[str, Any]:
"""Evaluate the attribute-binding capability of the model.
@@ -78,6 +79,7 @@
Dict[str, Any]: The evaluation result.
"""
_ = prompt
_ = metadata
judgement = self.judge.predict(
adj_1, noun_1, adj_2, noun_2, model_output["image"]
)
@@ -93,5 +95,8 @@ async def evaluate_async(
adj_2: str,
noun_2: str,
model_output: Dict[str, Any],
metadata: weave.Model,
) -> Dict[str, Any]:
return self.evaluate(prompt, adj_1, noun_1, adj_2, noun_2, model_output)
return self.evaluate(
prompt, adj_1, noun_1, adj_2, noun_2, model_output, metadata
)
8 changes: 5 additions & 3 deletions hemm/metrics/vqa/multi_modal_llm_eval.py
@@ -29,7 +29,9 @@ def __init__(
self.name = name

@weave.op()
def evaluate(self, prompt: str, model_output: Dict[str, Any]) -> Dict[str, Any]:
def evaluate(
self, prompt: str, model_output: Dict[str, Any], metadata: weave.Model
) -> Dict[str, Any]:
"""Evaluate the generated image using the judge LLM model.
Args:
@@ -50,6 +52,6 @@ def evaluate(self, prompt: str, model_output: Dict[str, Any]) -> Dict[str, Any]:

@weave.op()
async def evaluate_async(
self, prompt: str, model_output: Dict[str, Any]
self, prompt: str, model_output: Dict[str, Any], metadata: weave.Model
) -> Dict[str, Any]:
return self.evaluate(prompt, model_output)
return self.evaluate(prompt, model_output, metadata)
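
Every metric file in this commit follows the same mechanical pattern: evaluate and evaluate_async gain a trailing metadata: weave.Model parameter that is either forwarded to the parent class or discarded. The standalone sketch below (illustrative names only, not part of the commit) isolates the binding mechanism the pipeline relies on: functools.partial pre-fills the metadata keyword on each scorer, so weave.Evaluation only has to supply the per-row arguments such as prompt and model_output.

import asyncio
from functools import partial

import weave


class EvaluationMetadata(weave.Model):
    # Hypothetical stand-in; the pipeline instead builds the object with
    # weave.Model.model_validate(self.evaluation_configs).
    seed: int = 42

    @weave.op()
    def predict(self) -> dict:
        # weave Models conventionally expose a predict op; it is unused in this sketch.
        return {"seed": self.seed}


async def score_prompt_length(
    prompt: str, model_output: dict, metadata: weave.Model
) -> dict:
    # metadata arrives pre-bound; only prompt and model_output come from the dataset row.
    return {"prompt_length": len(prompt), "seed": metadata.seed}


scorer = partial(score_prompt_length, metadata=EvaluationMetadata(seed=1234))
result = asyncio.run(scorer(prompt="a red cube on a blue sphere", model_output={"image": None}))
print(result)  # {'prompt_length': 27, 'seed': 1234}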
