Skip to content

Commit

Permalink
[Fix] Fix Qwen2-VL
Browse files (browse the repository at this point in the history)
  • Loading branch information
kennymckormick committed Sep 24, 2024
1 parent 063a020 commit c83968e
Showing 1 changed file with 11 additions and 4 deletions.
15 changes: 11 additions & 4 deletions vlmeval/vlm/qwen2_vl/model.py
Original file line number | Diff line number | Diff line change
[Expand Up] @@ -54,10 +54,17 @@ def __init__(
assert model_path is not None
self.model_path = model_path
self.processor = Qwen2VLProcessor.from_pretrained(model_path)
self.model = Qwen2VLForConditionalGeneration.from_pretrained(
model_path, torch_dtype='auto', device='cpu', device_map='cpu', attn_implementation='flash_attention_2'
)
self.model.cuda().eval()
if '72b' not in self.model_path.lower():
self.model = Qwen2VLForConditionalGeneration.from_pretrained(
model_path, torch_dtype='auto', device_map='cpu', attn_implementation='flash_attention_2'
)
self.model.cuda().eval()
else:
self.model = Qwen2VLForConditionalGeneration.from_pretrained(
model_path, torch_dtype='auto', device_map='auto', attn_implementation='flash_attention_2'
)
self.model.cuda().eval()

torch.cuda.empty_cache()

def _prepare_content(self, inputs: list[dict[str, str]], dataset: str | None = None) -> list[dict[str, str]]:
Expand Down

0 comments on commit c83968e

Please sign in to comment.