Commit

update: examples
soumik12345 committed Oct 18, 2024
1 parent d8a9b91 commit 0d085d6
Showing 1 changed file with 59 additions and 0 deletions.
examples/multimodal_llm_eval/evaluate_mllm_metric_action.py (59 additions, 0 deletions)

The new example script generates images with a diffusion model (Stable Diffusion 2.1 by default) and scores them on action-category prompts using a multimodal-LLM judge.

@@ -0,0 +1,59 @@
from typing import Optional

import fire
import weave

import wandb
from hemm.eval_pipelines import EvaluationPipeline
from hemm.metrics.vqa import MultiModalLLMEvaluationMetric
from hemm.metrics.vqa.judges.mmllm_judges import OpenAIJudge, PromptCategory
from hemm.models import BaseDiffusionModel


def main(
    project: str = "mllm-eval",
    entity: str = "hemm-eval",
    dataset_ref: Optional[str] = "attribute_binding_dataset:v1",
    dataset_limit: Optional[int] = None,
    diffusion_model_address: str = "stabilityai/stable-diffusion-2-1",
    diffusion_model_enable_cpu_offload: bool = False,
    openai_judge_model: str = "gpt-4o",
    image_height: int = 1024,
    image_width: int = 1024,
    num_inference_steps: int = 50,
    mock_inference_dataset_address: Optional[str] = None,
    save_inference_dataset_name: Optional[str] = None,
):
    # Initialize W&B experiment tracking and Weave tracing.
    wandb.init(project=project, entity=entity, job_type="evaluation")
    weave.init(project_name=f"{entity}/{project}")

    # Fetch the evaluation dataset from Weave, optionally truncating it.
    dataset = weave.ref(dataset_ref).get()
    dataset = dataset.rows[:dataset_limit] if dataset_limit else dataset

    # Build the diffusion model under evaluation. The keyword
    # `enable_cpu_offfload` (triple "f") matches the hemm API's spelling.
    diffusion_model = BaseDiffusionModel(
        diffusion_model_name_or_path=diffusion_model_address,
        enable_cpu_offfload=diffusion_model_enable_cpu_offload,
        image_height=image_height,
        image_width=image_width,
        num_inference_steps=num_inference_steps,
    )
    # Silence the per-image progress bars of the underlying diffusers pipeline.
    diffusion_model._pipeline.set_progress_bar_config(disable=True)
    evaluation_pipeline = EvaluationPipeline(
        model=diffusion_model,
        mock_inference_dataset_address=mock_inference_dataset_address,
        save_inference_dataset_name=save_inference_dataset_name,
    )

    # Judge generated images with an OpenAI model on action-category prompts.
    judge = OpenAIJudge(
        prompt_property=PromptCategory.action, openai_model=openai_judge_model
    )
    metric = MultiModalLLMEvaluationMetric(judge=judge)
    evaluation_pipeline.add_metric(metric)

    # Run the evaluation, then finish the W&B run and clean up.
    evaluation_pipeline(dataset=dataset)
    wandb.finish()
    evaluation_pipeline.cleanup()


if __name__ == "__main__":
    fire.Fire(main)
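
Because the script hands main to fire.Fire, every keyword argument above is exposed as a command-line flag. A hypothetical invocation (the flag values below are illustrative, not defaults):

    python examples/multimodal_llm_eval/evaluate_mllm_metric_action.py \
        --dataset_limit=10 \
        --image_height=512 \
        --image_width=512 \
        --openai_judge_model="gpt-4o"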
