diff --git a/vlmeval/config.py b/vlmeval/config.py
index 80dc810a..c0245cf6 100644
--- a/vlmeval/config.py
+++ b/vlmeval/config.py
@@ -292,6 +292,11 @@
     'Llama-3.2-90B-Vision-Instruct': partial(llama_vision, model_path='meta-llama/Llama-3.2-90B-Vision-Instruct'),
 }
 
+molmo_series={
+    'molmo-7B-D-0924': partial(molmo, model_path='allenai/Molmo-7B-D-0924'),
+    'molmo-7B-O-0924': partial(molmo, model_path='allenai/Molmo-7B-O-0924'),
+}
+
 kosmos_series={
     'Kosmos2': partial(Kosmos2, model_path='/root/kosmos-2-patch14-224')
 }
@@ -309,8 +314,8 @@
     deepseekvl_series, minicpm_series, cogvlm_series, wemm_series,
     cambrian_series, chameleon_series, video_models, ovis_series, vila_series,
     mantis_series, mmalaya_series, phi3_series, xgen_mm_series, qwen2vl_series,
-    slime_series, eagle_series, moondream_series, llama_series, kosmos_series,
-    points_series
+    slime_series, eagle_series, moondream_series, llama_series, molmo_series,
+    kosmos_series, points_series
 ]
 
 for grp in model_groups:
diff --git a/vlmeval/vlm/__init__.py b/vlmeval/vlm/__init__.py
index fe271001..2e46d49b 100644
--- a/vlmeval/vlm/__init__.py
+++ b/vlmeval/vlm/__init__.py
@@ -50,4 +50,5 @@
 from .mplug_owl3 import mPLUG_Owl3
 from .pixtral import Pixtral
 from .llama_vision import llama_vision
+from .molmo import molmo
 from .points import POINTS
diff --git a/vlmeval/vlm/molmo.py b/vlmeval/vlm/molmo.py
new file mode 100644
index 00000000..711bcf2b
--- /dev/null
+++ b/vlmeval/vlm/molmo.py
@@ -0,0 +1,67 @@
+import torch
+from PIL import Image
+import os.path as osp
+import sys
+from .base import BaseModel
+from ..smp import *
+from ..dataset import DATASET_TYPE
+
+
+class molmo(BaseModel):
+
+    INSTALL_REQ = False
+    INTERLEAVE = False
+
+    def __init__(self, model_path='allenai/Molmo-7B-D-0924', **kwargs):
+        try:
+            from transformers import AutoModelForCausalLM, AutoProcessor, GenerationConfig
+            import einops
+        except Exception:
+            warnings.warn('Please install transformers and einops before using molmo.')
+            sys.exit(-1)
+
+        self.model = AutoModelForCausalLM.from_pretrained(
+            model_path,
+            trust_remote_code=True,
+            torch_dtype='auto',
+            device_map='auto'
+        )
+        self.processor = AutoProcessor.from_pretrained(
+            model_path,
+            trust_remote_code=True,
+            torch_dtype='auto',
+            device_map='auto'
+        )
+        self.kwargs = kwargs
+        self.model_name = model_path
+
+    def generate_inner(self, message, dataset=None):
+        from transformers import GenerationConfig
+        prompt, image_path = self.message_to_promptimg(message, dataset=dataset)
+
+        image = Image.open(image_path)
+        if image.mode != "RGB":
+            image = image.convert("RGB")
+        # preprocess the image and text prompt into model inputs
+        inputs = self.processor.process(
+            images=[image],
+            text=prompt
+        )
+
+        # move inputs to the correct device and make a batch of size 1
+        inputs = {k: v.to(self.model.device).unsqueeze(0) for k, v in inputs.items()}
+
+        # generate output; at most 200 new tokens; stop when <|endoftext|> is generated
+        with torch.autocast(device_type="cuda", enabled=True, dtype=torch.bfloat16):
+            output = self.model.generate_from_batch(
+                inputs,
+                GenerationConfig(max_new_tokens=200, stop_strings="<|endoftext|>"),
+                tokenizer=self.processor.tokenizer
+            )
+
+        # keep only the newly generated tokens and decode them to text
+        generated_tokens = output[0, inputs['input_ids'].size(1):]
+        generated_text = self.processor.tokenizer.decode(generated_tokens, skip_special_tokens=True)
+
+        # return the decoded text
+        return generated_text
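
For reference, once these changes are applied, the new entries can be exercised through
VLMEvalKit's supported_VLM registry like any other model series. A minimal usage sketch,
run from the repo root (the image path and prompt are illustrative placeholders, not part
of this patch):

    # Quick smoke test for the new Molmo wrapper.
    from vlmeval.config import supported_VLM

    # Instantiates vlmeval.vlm.molmo with model_path='allenai/Molmo-7B-D-0924'.
    model = supported_VLM['molmo-7B-D-0924']()
    ret = model.generate(['assets/apple.jpg', 'What is in this image?'])
    print(ret)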