Use weave.Model for metadata management #14

Closed · wants to merge 3 commits
2 changes: 0 additions & 2 deletions examples/2d_spatial_eval/evaluate_spatial_relationship.py
@@ -1,8 +1,6 @@
import os
from typing import Optional, Tuple

import fire
import jsonlines
import wandb
import weave

29 changes: 23 additions & 6 deletions hemm/eval_pipelines/eval_pipeline.py
@@ -1,6 +1,7 @@
import asyncio
from abc import ABC
from typing import Dict, List, Union
from functools import partial
from typing import Any, Dict, List, Optional, Union

import wandb
import weave
@@ -18,7 +19,12 @@ class EvaluationPipeline(ABC):
seed (int): Seed value for the random number generator.
"""

def __init__(self, model: BaseDiffusionModel, seed: int = 42) -> None:
def __init__(
self,
model: BaseDiffusionModel,
seed: int = 42,
configs: Optional[Dict[str, Any]] = None,
) -> None:
super().__init__()
self.model = model
self.model.initialize()
@@ -42,6 +48,7 @@ def __init__(self, model: BaseDiffusionModel, seed: int = 42) -> None:
},
"seed": seed,
"diffusion_pipeline": dict(self.model._pipeline.config),
**(configs or {}),
}

def add_metric(self, metric_fn: BaseMetric):
@@ -105,10 +112,14 @@ def log_summary(self, summary: Dict[str, float]) -> None:
{
f"evalution/{self.model.diffusion_model_name_or_path}": self.evaluation_table,
f"summary/{self.model.diffusion_model_name_or_path}": summary_table,
"summary": summary,
}
)

def __call__(self, dataset: Union[List[Dict], str]) -> Dict[str, float]:
@weave.op()
def evaluate(
self, dataset: Union[List[Dict], str], name: Optional[str] = None
) -> Dict[str, float]:
"""Evaluate the Stable Diffusion model on the given dataset.

Args:
@@ -117,10 +128,16 @@ def __call__(self, dataset: Union[List[Dict], str]) -> Dict[str, float]:
"""
dataset = weave.ref(dataset).get() if isinstance(dataset, str) else dataset
evaluation = weave.Evaluation(
name=name,
dataset=dataset,
scorers=[metric_fn.evaluate_async for metric_fn in self.metric_functions],
scorers=[
partial(
metric_fn.evaluate_async,
metadata=weave.Model.model_validate(self.evaluation_configs),
)
for metric_fn in self.metric_functions
],
)
with weave.attributes(self.evaluation_configs):
summary = asyncio.run(evaluation.evaluate(self.infer_async))
summary = asyncio.run(evaluation.evaluate(self.infer_async))
self.log_summary(summary)
return summary
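A minimal end-to-end sketch of the reworked entry point, assuming the import paths shown, that `BaseDiffusionModel` and `CLIPScoreMetric` can be constructed as below, and an illustrative project name, checkpoint, and prompt:

import weave
from hemm.eval_pipelines.eval_pipeline import BaseDiffusionModel, EvaluationPipeline
from hemm.metrics.prompt_alignment.clip_score import CLIPScoreMetric  # class name taken from the diff

weave.init(project_name="hemm-eval")  # illustrative project name

# `configs` is merged into `evaluation_configs`, which the pipeline validates into a
# weave.Model and passes to every scorer as `metadata`.
pipeline = EvaluationPipeline(
    model=BaseDiffusionModel(
        diffusion_model_name_or_path="stabilityai/stable-diffusion-2-1"  # other arguments may be required
    ),
    seed=42,
    configs={"experiment": "weave-model-metadata"},
)
pipeline.add_metric(CLIPScoreMetric())  # constructor arguments may differ

# `evaluate` replaces the previous `__call__` entry point and is itself a weave.op.
summary = pipeline.evaluate(
    dataset=[{"prompt": "a photograph of an astronaut riding a horse"}],
    name="clip-score-demo",
)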
6 changes: 4 additions & 2 deletions hemm/metrics/base.py
@@ -1,16 +1,18 @@
from abc import ABC, abstractmethod
from typing import Any, Dict

import weave


class BaseMetric(ABC):

def __init__(self) -> None:
super().__init__()

@abstractmethod
def evaluate(self) -> Dict[str, Any]:
def evaluate(self, metadata: weave.Model) -> Dict[str, Any]:
pass

@abstractmethod
def evaluate_async(self) -> Dict[str, Any]:
def evaluate_async(self, metadata: weave.Model) -> Dict[str, Any]:
pass
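A toy metric conforming to the updated interface could look as follows; the class name and score are illustrative, only the signatures mirror the diff above:

from typing import Any, Dict

import weave

from hemm.metrics.base import BaseMetric


class PromptLengthMetric(BaseMetric):
    """Toy scorer that returns the prompt length and ignores the generated image."""

    @weave.op()
    def evaluate(
        self, prompt: str, model_output: Dict[str, Any], metadata: weave.Model
    ) -> Dict[str, Any]:
        _ = model_output
        _ = metadata  # threaded through only so Weave records the run configuration
        return {"prompt_length": len(prompt)}

    @weave.op()
    async def evaluate_async(
        self, prompt: str, model_output: Dict[str, Any], metadata: weave.Model
    ) -> Dict[str, Any]:
        return self.evaluate(prompt, model_output, metadata)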
7 changes: 6 additions & 1 deletion hemm/metrics/image_quality/base.py
@@ -3,6 +3,7 @@
from io import BytesIO
from typing import Any, Dict, Union

import weave
from PIL import Image
from pydantic import BaseModel

@@ -50,7 +51,11 @@ def compute_metric(
pass

def evaluate(
self, prompt: str, ground_truth_image: str, model_output: Dict[str, Any]
self,
prompt: str,
ground_truth_image: str,
model_output: Dict[str, Any],
metadata: weave.Model,
) -> Dict[str, float]:
"""Compute the metric for the given images. This method is used as the scorer
function for `weave.Evaluation` in the evaluation pipelines.
16 changes: 12 additions & 4 deletions hemm/metrics/image_quality/lpips.py
@@ -71,14 +71,22 @@ def compute_metric(

@weave.op()
def evaluate(
self, prompt: str, ground_truth_image: str, model_output: Dict[str, Any]
self,
prompt: str,
ground_truth_image: str,
model_output: Dict[str, Any],
metadata: weave.Model,
) -> Union[float, Dict[str, float]]:
_ = "LPIPSMetric"
return super().evaluate(prompt, ground_truth_image, model_output)
return super().evaluate(prompt, ground_truth_image, model_output, metadata)

@weave.op()
async def evaluate_async(
self, prompt: str, ground_truth_image: str, model_output: Dict[str, Any]
self,
prompt: str,
ground_truth_image: str,
model_output: Dict[str, Any],
metadata: weave.Model,
) -> Union[float, Dict[str, float]]:
_ = "LPIPSMetric"
return self.evaluate(prompt, ground_truth_image, model_output)
return self.evaluate(prompt, ground_truth_image, model_output, metadata)
16 changes: 12 additions & 4 deletions hemm/metrics/image_quality/psnr.py
@@ -65,14 +65,22 @@ def compute_metric(

@weave.op()
def evaluate(
self, prompt: str, ground_truth_image: str, model_output: Dict[str, Any]
self,
prompt: str,
ground_truth_image: str,
model_output: Dict[str, Any],
metadata: weave.Model,
) -> Union[float, Dict[str, float]]:
_ = "PSNRMetric"
return super().evaluate(prompt, ground_truth_image, model_output)
return super().evaluate(prompt, ground_truth_image, model_output, metadata)

@weave.op()
async def evaluate_async(
self, prompt: str, ground_truth_image: str, model_output: Dict[str, Any]
self,
prompt: str,
ground_truth_image: str,
model_output: Dict[str, Any],
metadata: weave.Model,
) -> Union[float, Dict[str, float]]:
_ = "PSNRMetric"
return self.evaluate(prompt, ground_truth_image, model_output)
return self.evaluate(prompt, ground_truth_image, model_output, metadata)
16 changes: 12 additions & 4 deletions hemm/metrics/image_quality/ssim.py
@@ -92,14 +92,22 @@ def compute_metric(

@weave.op()
def evaluate(
self, prompt: str, ground_truth_image: str, model_output: Dict[str, Any]
self,
prompt: str,
ground_truth_image: str,
model_output: Dict[str, Any],
metadata: weave.Model,
) -> Union[float, Dict[str, float]]:
_ = "SSIMMetric"
return super().evaluate(prompt, ground_truth_image, model_output)
return super().evaluate(prompt, ground_truth_image, model_output, metadata)

@weave.op()
async def evaluate_async(
self, prompt: str, ground_truth_image: str, model_output: Dict[str, Any]
self,
prompt: str,
ground_truth_image: str,
model_output: Dict[str, Any],
metadata: weave.Model,
) -> Union[float, Dict[str, float]]:
_ = "SSIMMetric"
return self.evaluate(prompt, ground_truth_image, model_output)
return self.evaluate(prompt, ground_truth_image, model_output, metadata)
5 changes: 4 additions & 1 deletion hemm/metrics/prompt_alignment/base.py
@@ -3,6 +3,7 @@
from io import BytesIO
from typing import Any, Dict, Union

import weave
from PIL import Image

from ..base import BaseMetric
@@ -37,7 +38,9 @@ def compute_metric(
"""
pass

def evaluate(self, prompt: str, model_output: Dict[str, Any]) -> Dict[str, float]:
def evaluate(
self, prompt: str, model_output: Dict[str, Any], metadata: weave.Model
) -> Dict[str, float]:
"""Compute the metric for the given image. This method is used as the scorer
function for `weave.Evaluation` in the evaluation pipelines.

10 changes: 6 additions & 4 deletions hemm/metrics/prompt_alignment/blip_score.py
@@ -48,13 +48,15 @@ def compute_metric(
)

@weave.op()
def evaluate(self, prompt: str, model_output: Dict[str, Any]) -> Dict[str, float]:
def evaluate(
self, prompt: str, model_output: Dict[str, Any], metadata: weave.Model
) -> Dict[str, float]:
_ = "BLIPScoreMertric"
return super().evaluate(prompt, model_output)
return super().evaluate(prompt, model_output, metadata)

@weave.op()
async def evaluate_async(
self, prompt: str, model_output: Dict[str, Any]
self, prompt: str, model_output: Dict[str, Any], metadata: weave.Model
) -> Dict[str, float]:
_ = "BLIPScoreMertric"
return self.evaluate(prompt, model_output)
return self.evaluate(prompt, model_output, metadata)
10 changes: 6 additions & 4 deletions hemm/metrics/prompt_alignment/clip_iqa_score.py
@@ -80,13 +80,15 @@ def compute_metric(
return score_dict

@weave.op()
def evaluate(self, prompt: str, model_output: Dict[str, Any]) -> Dict[str, float]:
def evaluate(
self, prompt: str, model_output: Dict[str, Any], metadata: weave.Model
) -> Dict[str, float]:
_ = "CLIPImageQualityScoreMetric"
return super().evaluate(prompt, model_output)
return super().evaluate(prompt, model_output, metadata)

@weave.op()
async def evaluate_async(
self, prompt: str, model_output: Dict[str, Any]
self, prompt: str, model_output: Dict[str, Any], metadata: weave.Model
) -> Dict[str, float]:
_ = "CLIPImageQualityScoreMetric"
return self.evaluate(prompt, model_output)
return self.evaluate(prompt, model_output, metadata)
10 changes: 6 additions & 4 deletions hemm/metrics/prompt_alignment/clip_score.py
@@ -45,13 +45,15 @@ def compute_metric(
)

@weave.op()
def evaluate(self, prompt: str, model_output: Dict[str, Any]) -> Dict[str, float]:
def evaluate(
self, prompt: str, model_output: Dict[str, Any], metadata: weave.Model
) -> Dict[str, float]:
_ = "CLIPScoreMetric"
return super().evaluate(prompt, model_output)
return super().evaluate(prompt, model_output, metadata)

@weave.op()
async def evaluate_async(
self, prompt: str, model_output: Dict[str, Any]
self, prompt: str, model_output: Dict[str, Any], metadata: weave.Model
) -> Dict[str, float]:
_ = "CLIPScoreMetric"
return self.evaluate(prompt, model_output)
return self.evaluate(prompt, model_output, metadata)
7 changes: 6 additions & 1 deletion hemm/metrics/spatial_relationship/spatial_relationship_2d.py
@@ -221,6 +221,7 @@ def evaluate(
entity_2: str,
relationship: str,
model_output: Dict[str, Any],
metadata: weave.Model,
) -> Dict[str, Union[bool, float, int]]:
"""Calculate the spatial relationship score for the given prompt and model output.

@@ -235,6 +236,7 @@
Dict[str, Union[bool, float, int]]: The comprehensive spatial relationship judgement.
"""
_ = prompt
_ = metadata

image = model_output["image"]
boxes: List[BoundingBox] = self.judge.predict(image)
@@ -251,5 +253,8 @@ async def evaluate_async(
entity_2: str,
relationship: str,
model_output: Dict[str, Any],
metadata: weave.Model,
) -> Dict[str, Union[bool, float, int]]:
return self.evaluate(prompt, entity_1, entity_2, relationship, model_output)
return self.evaluate(
prompt, entity_1, entity_2, relationship, model_output, metadata
)
7 changes: 6 additions & 1 deletion hemm/metrics/vqa/disentangled_vqa.py
@@ -63,6 +63,7 @@ def evaluate(
adj_2: str,
noun_2: str,
model_output: Dict[str, Any],
metadata: weave.Model,
) -> Dict[str, Any]:
"""Evaluate the attribute-binding capability of the model.

@@ -78,6 +79,7 @@
Dict[str, Any]: The evaluation result.
"""
_ = prompt
_ = metadata
judgement = self.judge.predict(
adj_1, noun_1, adj_2, noun_2, model_output["image"]
)
@@ -93,5 +95,8 @@ async def evaluate_async(
adj_2: str,
noun_2: str,
model_output: Dict[str, Any],
metadata: weave.Model,
) -> Dict[str, Any]:
return self.evaluate(prompt, adj_1, noun_1, adj_2, noun_2, model_output)
return self.evaluate(
prompt, adj_1, noun_1, adj_2, noun_2, model_output, metadata
)
8 changes: 5 additions & 3 deletions hemm/metrics/vqa/multi_modal_llm_eval.py
@@ -29,7 +29,9 @@ def __init__(
self.name = name

@weave.op()
def evaluate(self, prompt: str, model_output: Dict[str, Any]) -> Dict[str, Any]:
def evaluate(
self, prompt: str, model_output: Dict[str, Any], metadata: weave.Model
) -> Dict[str, Any]:
"""Evaluate the generated image using the judge LLM model.

Args:
@@ -50,6 +52,6 @@ def evaluate(self, prompt: str, model_output: Dict[str, Any]) -> Dict[str, Any]:

@weave.op()
async def evaluate_async(
self, prompt: str, model_output: Dict[str, Any]
self, prompt: str, model_output: Dict[str, Any], metadata: weave.Model
) -> Dict[str, Any]:
return self.evaluate(prompt, model_output)
return self.evaluate(prompt, model_output, metadata)
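For context, the pipeline change pre-binds the shared run configuration to every scorer with `functools.partial` before handing it to `weave.Evaluation`. A framework-free sketch of that binding pattern, with a made-up scorer and values:

from functools import partial


def scorer(prompt: str, model_output: dict, metadata: dict) -> dict:
    # `metadata` is not used for scoring; it only rides along with each call.
    return {"prompt_length": len(prompt)}


bound_scorer = partial(scorer, metadata={"seed": 42, "experiment": "demo"})
# Only `prompt` and `model_output` need to be supplied at call time.
print(bound_scorer("a red cube on a blue sphere", model_output={}))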