From 76ca31022b94315fe7db217684de82e21f34dfe8 Mon Sep 17 00:00:00 2001
From: soumik12345 <19soumik.rakshit96@gmail.com>
Date: Sat, 29 Jun 2024 20:21:20 +0000
Subject: [PATCH] add: docstrings

---
 .../attribute_binding/dataset_generator.py | 49 +++++++++++++++++++
 .../attribute_binding/disentangled_vqa.py  | 44 +++++++++++++++++
 .../attribute_binding/judges/blip_vqa.py   | 20 ++++++++
 .../spatial_relationship_2d.py             |  2 +-
 4 files changed, 114 insertions(+), 1 deletion(-)

diff --git a/hemm/metrics/attribute_binding/dataset_generator.py b/hemm/metrics/attribute_binding/dataset_generator.py
index 8028042..b15d82f 100644
--- a/hemm/metrics/attribute_binding/dataset_generator.py
+++ b/hemm/metrics/attribute_binding/dataset_generator.py
@@ -25,6 +25,14 @@ class AttributeBindingModel(weave.Model):
+    """Weave Model to generate prompts for evaluation of the attribute binding capability of
+    image-generation models using an OpenAI model.
+
+    Args:
+        openai_model (Optional[str]): The OpenAI model to use for generating prompts.
+        num_prompts (Optional[int]): Number of prompts to generate.
+    """
+
     openai_model: Optional[str] = "gpt-3.5-turbo"
     num_prompts: Optional[int] = 20
     _openai_client: Optional[OpenAI] = None
@@ -85,6 +93,12 @@ def _initialize(self):
     @weave.op()
     def predict(self, seed: int) -> Dict[str, str]:
+        """Generate prompts and corresponding metadata for evaluation of the attribute binding
+        capability of image-generation models.
+
+        Args:
+            seed (int): OpenAI seed to use for generating prompts.
+        """
         return {
             "response": self._openai_client.chat.completions.create(
                 model=self.openai_model,
@@ -107,6 +121,36 @@ def predict(self, seed: int) -> Dict[str, str]:
 class AttributeBindingDatasetGenerator:
+    """Dataset generator for evaluation of the attribute binding capability of image-generation models.
+    This class generates a dataset consisting of prompts in the format
+    `“a {adj_1} {noun_1} and a {adj_2} {noun_2}”` and the corresponding metadata using an LLM capable
+    of generating JSON objects, such as GPT-4o. The dataset is then published both as a
+    [W&B dataset artifact](https://docs.wandb.ai/guides/artifacts) and as a
+    [weave dataset](https://wandb.github.io/weave/guides/core-types/datasets).
+
+    ??? example "Sample usage"
+        ```python
+        from hemm.metrics.attribute_binding import AttributeBindingDatasetGenerator
+
+        dataset_generator = AttributeBindingDatasetGenerator(
+            openai_model="gpt-4o",
+            openai_seed=42,
+            num_prompts_in_single_call=20,
+            num_api_calls=50,
+            project_name="disentangled_vqa",
+        )
+
+        dataset_generator(dump_dir="./dump")
+        ```
+
+    Args:
+        openai_model (Optional[str]): The OpenAI model to use for generating prompts.
+        openai_seed (Optional[Union[int, List[int]]]): Seed to use for generating prompts.
+            If not provided, seeds will be auto-generated.
+        num_prompts_in_single_call (Optional[int]): Number of prompts to generate in a single API call.
+        num_api_calls (Optional[int]): Number of API calls to make.
+        project_name (Optional[str]): Name of the Weave project to use for logging the dataset.
+    """
     def __init__(
         self,
@@ -189,6 +233,11 @@ async def evaluate_generated_response(
         return eval_response.model_dump()
     def __call__(self, dump_dir: Optional[str] = "./dump") -> None:
+        """Generate the dataset and publish it to Weave.
+
+        Args:
+            dump_dir (Optional[str]): Directory to dump the dataset.
+ """ wandb.init( project=self.project_name, job_type="attribute_binding_dataset", diff --git a/hemm/metrics/attribute_binding/disentangled_vqa.py b/hemm/metrics/attribute_binding/disentangled_vqa.py index eb4fe06..58ba772 100644 --- a/hemm/metrics/attribute_binding/disentangled_vqa.py +++ b/hemm/metrics/attribute_binding/disentangled_vqa.py @@ -7,6 +7,37 @@ class DisentangledVQAMetric(BaseMetric): + """Disentangled VQA metric to evaluate the attribute-binding capability + for image generation models as proposed in Section 4.1 from the paper + [T2I-CompBench: A Comprehensive Benchmark for Open-world Compositional Text-to-image Generation](https://arxiv.org/pdf/2307.06350). + + ??? example "Sample usage" + ```python + import wandb + import weave + + wandb.init(project=project, entity=entity, job_type="evaluation") + weave.init(project_name=project) + + diffusion_model = BaseDiffusionModel( + diffusion_model_name_or_path=diffusion_model_address, + enable_cpu_offfload=diffusion_model_enable_cpu_offfload, + image_height=image_size[0], + image_width=image_size[1], + ) + evaluation_pipeline = EvaluationPipeline(model=diffusion_model) + + judge = BlipVQAJudge() + metric = DisentangledVQAMetric(judge=judge, name="disentangled_blip_metric") + evaluation_pipeline.add_metric(metric) + + evaluation_pipeline(dataset=dataset) + ``` + + Args: + judge (Union[weave.Model, BlipVQAJudge]): The judge model to evaluate the attribute-binding capability. + name (Optional[str]): The name of the metric. Defaults to "disentangled_vlm_metric". + """ def __init__( self, @@ -30,6 +61,19 @@ def evaluate( noun_2: str, model_output: Dict[str, Any], ) -> Dict[str, Any]: + """Evaluate the attribute-binding capability of the model. + + Args: + prompt (str): The prompt for the model. + adj_1 (str): The first adjective. + noun_1 (str): The first noun. + adj_2 (str): The second adjective. + noun_2 (str): The second noun. + model_output (Dict[str, Any]): The model output. + + Returns: + Dict[str, Any]: The evaluation result. + """ _ = prompt judgement = self.judge.predict( adj_1, noun_1, adj_2, noun_2, model_output["image"] diff --git a/hemm/metrics/attribute_binding/judges/blip_vqa.py b/hemm/metrics/attribute_binding/judges/blip_vqa.py index 78b6cba..ca452e3 100644 --- a/hemm/metrics/attribute_binding/judges/blip_vqa.py +++ b/hemm/metrics/attribute_binding/judges/blip_vqa.py @@ -9,6 +9,14 @@ class BlipVQAJudge(weave.Model): + """Weave Model to judge the presence of entities in an image using the + [Blip-VQA model](https://huggingface.co/Salesforce/blip-vqa-base). + + Args: + blip_processor_address (str): The address of the BlipProcessor model. + blip_vqa_address (str): The address of the BlipForQuestionAnswering model. + device (str): The device to use for inference + """ blip_processor_address: str = "Salesforce/blip-vqa-base" blip_vqa_address: str = "Salesforce/blip-vqa-base" @@ -57,6 +65,18 @@ def get_target_token_probability( def predict( self, adj_1: str, noun_1: str, adj_2: str, noun_2: str, image: str ) -> Dict: + """Predict the probabilities presence of entities in an image using the Blip-VQA model. + + Args: + adj_1 (str): The adjective of the first entity. + noun_1 (str): The noun of the first entity. + adj_2 (str): The adjective of the second entity. + noun_2 (str): The noun of the second entity. + image (str): The base64 encoded image. + + Returns: + Dict: The probabilities of the presence of the entities. + """ question_1 = f"is {adj_1} {noun_1} present in the picture?" 
question_2 = f"is {adj_2} {noun_2} present in the picture?" return { diff --git a/hemm/metrics/spatial_relationship/spatial_relationship_2d.py b/hemm/metrics/spatial_relationship/spatial_relationship_2d.py index cbc2a56..1925f4b 100644 --- a/hemm/metrics/spatial_relationship/spatial_relationship_2d.py +++ b/hemm/metrics/spatial_relationship/spatial_relationship_2d.py @@ -13,7 +13,7 @@ class SpatialRelationshipMetric2D(BaseMetric): - """Spatial relationship metric for 2D images as proposed by Section 4.2 from the paper + """Spatial relationship metric for image generation as proposed in Section 4.2 from the paper [T2I-CompBench: A Comprehensive Benchmark for Open-world Compositional Text-to-image Generation](https://arxiv.org/pdf/2307.06350). ??? example "Sample usage"