add: docstrings
soumik12345 committed Jun 29, 2024
1 parent 0979344 commit 76ca310
Showing 4 changed files with 114 additions and 1 deletion.
49 changes: 49 additions & 0 deletions hemm/metrics/attribute_binding/dataset_generator.py
@@ -25,6 +25,14 @@ class AttributeBindingEvaluationResponse(BaseModel):


class AttributeBindingModel(weave.Model):
"""Weave Model to generate prompts for evaluation of attribute binding capability of
image-generation models using an OpenAI model.
Args:
openai_model (Optional[str]): The OpenAI model to use for generating prompts.
num_prompts (Optional[int]): Number of prompts to generate.
"""

openai_model: Optional[str] = "gpt-3.5-turbo"
num_prompts: Optional[int] = 20
_openai_client: Optional[OpenAI] = None
@@ -85,6 +93,12 @@ def _initialize(self):

@weave.op()
def predict(self, seed: int) -> Dict[str, str]:
"""Generate prompts and corresponding metadata for evaluation of attribute binding
capability of image-generation models.
Args:
seed (int): OpenAI seed to use for generating prompts.
"""
return {
"response": self._openai_client.chat.completions.create(
model=self.openai_model,
@@ -107,6 +121,36 @@ def predict(self, seed: int) -> Dict[str, str]:


class AttributeBindingDatasetGenerator:
"""Dataset generator for evaluation of attribute binding capability of image-generation models.
This class enables us to generate the dataset consisting of prompts in the format
`“a {adj_1} {noun_1} and a {adj_2} {noun_2}”` and the corresponding metadata using an LLM capable
of generating json objects like GPT4-O. The dataset is then published both as a
[W&B dataset artifact](https://docs.wandb.ai/guides/artifacts) and as a
[weave dataset](https://wandb.github.io/weave/guides/core-types/datasets).
??? example "Sample usage"
```python
from hemm.metrics.attribute_binding import AttributeBindingDatasetGenerator
dataset_generator = AttributeBindingDatasetGenerator(
openai_model="gpt-4o",
openai_seed=42,
num_prompts_in_single_call=20,
num_api_calls=50,
project_name="disentangled_vqa",
)
dataset_generator(dump_dir="./dump")
```
Args:
openai_model (Optional[str]): The OpenAI model to use for generating prompts.
openai_seed (Optional[Union[int, List[int]]]): Seed to use for generating prompts.
If not provided, seeds will be auto-generated.
num_prompts_in_single_call (Optional[int]): Number of prompts to generate in a single API call.
num_api_calls (Optional[int]): Number of API calls to make.
project_name (Optional[str]): Name of the Weave project to use for logging the dataset.
"""

def __init__(
self,
@@ -189,6 +233,11 @@ async def evaluate_generated_response(
return eval_response.model_dump()

def __call__(self, dump_dir: Optional[str] = "./dump") -> None:
"""Generate the dataset and publish it to Weave.
Args:
dump_dir (Optional[str]): Directory to dump the dataset.
"""
wandb.init(
project=self.project_name,
job_type="attribute_binding_dataset",
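The `AttributeBindingModel` documented above can also be exercised on its own, outside the dataset generator. Below is a minimal sketch, not taken from the repository: the import path is assumed from the file changed in this commit, and it assumes the model's OpenAI client is set up internally (e.g., via its `_initialize` hook) before `predict` runs.

```python
import weave

# Assumed import path, based on the file shown in this commit.
from hemm.metrics.attribute_binding.dataset_generator import AttributeBindingModel

weave.init(project_name="disentangled_vqa")

# Fields shown in the diff: the OpenAI model name and the number of prompts per call.
model = AttributeBindingModel(openai_model="gpt-3.5-turbo", num_prompts=20)

# `predict` wraps a chat-completions call; per the hunk above, the returned dict
# exposes the raw completion under the "response" key.
result = model.predict(seed=42)
print(result["response"])
```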
44 changes: 44 additions & 0 deletions hemm/metrics/attribute_binding/disentangled_vqa.py
@@ -7,6 +7,37 @@


class DisentangledVQAMetric(BaseMetric):
"""Disentangled VQA metric to evaluate the attribute-binding capability
for image generation models as proposed in Section 4.1 from the paper
[T2I-CompBench: A Comprehensive Benchmark for Open-world Compositional Text-to-image Generation](https://arxiv.org/pdf/2307.06350).
??? example "Sample usage"
```python
import wandb
import weave
wandb.init(project=project, entity=entity, job_type="evaluation")
weave.init(project_name=project)
diffusion_model = BaseDiffusionModel(
diffusion_model_name_or_path=diffusion_model_address,
enable_cpu_offfload=diffusion_model_enable_cpu_offfload,
image_height=image_size[0],
image_width=image_size[1],
)
evaluation_pipeline = EvaluationPipeline(model=diffusion_model)
judge = BlipVQAJudge()
metric = DisentangledVQAMetric(judge=judge, name="disentangled_blip_metric")
evaluation_pipeline.add_metric(metric)
evaluation_pipeline(dataset=dataset)
```
Args:
judge (Union[weave.Model, BlipVQAJudge]): The judge model to evaluate the attribute-binding capability.
name (Optional[str]): The name of the metric. Defaults to "disentangled_vlm_metric".
"""

def __init__(
self,
Expand All @@ -30,6 +61,19 @@ def evaluate(
noun_2: str,
model_output: Dict[str, Any],
) -> Dict[str, Any]:
"""Evaluate the attribute-binding capability of the model.
Args:
prompt (str): The prompt for the model.
adj_1 (str): The first adjective.
noun_1 (str): The first noun.
adj_2 (str): The second adjective.
noun_2 (str): The second noun.
model_output (Dict[str, Any]): The model output.
Returns:
Dict[str, Any]: The evaluation result.
"""
_ = prompt
judgement = self.judge.predict(
adj_1, noun_1, adj_2, noun_2, model_output["image"]
20 changes: 20 additions & 0 deletions hemm/metrics/attribute_binding/judges/blip_vqa.py
@@ -9,6 +9,14 @@


class BlipVQAJudge(weave.Model):
"""Weave Model to judge the presence of entities in an image using the
[Blip-VQA model](https://huggingface.co/Salesforce/blip-vqa-base).
Args:
blip_processor_address (str): The address of the BlipProcessor model.
blip_vqa_address (str): The address of the BlipForQuestionAnswering model.
device (str): The device to use for inference
"""

blip_processor_address: str = "Salesforce/blip-vqa-base"
blip_vqa_address: str = "Salesforce/blip-vqa-base"
@@ -57,6 +65,18 @@ def get_target_token_probability(
def predict(
self, adj_1: str, noun_1: str, adj_2: str, noun_2: str, image: str
) -> Dict:
"""Predict the probabilities presence of entities in an image using the Blip-VQA model.
Args:
adj_1 (str): The adjective of the first entity.
noun_1 (str): The noun of the first entity.
adj_2 (str): The adjective of the second entity.
noun_2 (str): The noun of the second entity.
image (str): The base64 encoded image.
Returns:
Dict: The probabilities of the presence of the entities.
"""
question_1 = f"is {adj_1} {noun_1} present in the picture?"
question_2 = f"is {adj_2} {noun_2} present in the picture?"
return {
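For reference, here is a minimal standalone sketch of the kind of Blip-VQA call this judge wraps: answering a yes/no question about the image and reading off the probability assigned to the "yes" token at the first decoding step. This is illustrative only and is not the judge's implementation; the image path and question text are made up, and the image is loaded from disk rather than from a base64 string.

```python
import torch
from PIL import Image
from transformers import BlipForQuestionAnswering, BlipProcessor

processor = BlipProcessor.from_pretrained("Salesforce/blip-vqa-base")
model = BlipForQuestionAnswering.from_pretrained("Salesforce/blip-vqa-base")

# Hypothetical inputs: a locally saved generated image and one yes/no question.
image = Image.open("generated_image.png").convert("RGB")
question = "is red car present in the picture?"

inputs = processor(image, question, return_tensors="pt")
outputs = model.generate(**inputs, output_scores=True, return_dict_in_generate=True)

# Decode the textual answer ("yes" / "no").
answer = processor.decode(outputs.sequences[0], skip_special_tokens=True)

# Probability mass the first decoding step assigns to the "yes" token.
yes_token_id = processor.tokenizer("yes", add_special_tokens=False).input_ids[0]
yes_probability = torch.softmax(outputs.scores[0][0], dim=-1)[yes_token_id].item()

print(answer, yes_probability)
```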
@@ -13,7 +13,7 @@


class SpatialRelationshipMetric2D(BaseMetric):
"""Spatial relationship metric for 2D images as proposed by Section 4.2 from the paper
"""Spatial relationship metric for image generation as proposed in Section 4.2 from the paper
[T2I-CompBench: A Comprehensive Benchmark for Open-world Compositional Text-to-image Generation](https://arxiv.org/pdf/2307.06350).
??? example "Sample usage"
