From 0a7916913ba45138b73b2438498a7fceaba51eee Mon Sep 17 00:00:00 2001 From: FangXinyu-0913 Date: Sun, 11 Aug 2024 13:14:37 +0800 Subject: [PATCH 1/8] add support for mmbench-video api packed fps sample --- run.py | 8 +++- vlmeval/api/gpt.py | 4 +- vlmeval/dataset/mmbench_video.py | 10 ++--- vlmeval/dataset/video_base.py | 66 +++++++++++++++++++++++++------- vlmeval/inference_video.py | 25 ++++++++---- 5 files changed, 84 insertions(+), 29 deletions(-) diff --git a/run.py b/run.py index ad6539bd..1410bc1b 100644 --- a/run.py +++ b/run.py @@ -19,6 +19,7 @@ def parse_args(): parser.add_argument('--nframe', type=int, default=8) parser.add_argument('--pack', action='store_true') parser.add_argument('--use-subtitle', action='store_true') + parser.add_argument('--fps', type=int, default=-1) # Work Dir parser.add_argument('--work-dir', type=str, default='.', help='select the output directory') # Infer + Eval or Infer Only @@ -35,6 +36,7 @@ def parse_args(): parser.add_argument('--ignore', action='store_true', help='Ignore failed indices. ') # Rerun: will remove all evaluation temp files parser.add_argument('--rerun', action='store_true') + args = parser.parse_args() return args @@ -100,6 +102,9 @@ def main(): subtitlestr = 'subs' if args.use_subtitle else 'nosubs' result_file = f'{pred_root}/{model_name}_{dataset_name}_{args.nframe}frame_{packstr}_{subtitlestr}.xlsx' + if args.fps > 0: + result_file = result_file.replace('.xlsx', f'_fps{args.fps}.xlsx') + if dataset.TYPE == 'MT': result_file = result_file.replace('.xlsx', '.tsv') @@ -121,7 +126,8 @@ def main(): pack=args.pack, verbose=args.verbose, subtitle=args.use_subtitle, - api_nproc=args.nproc) + api_nproc=args.nproc, + fps=args.fps) elif dataset.TYPE == 'MT': model = infer_data_job_mt( model, diff --git a/vlmeval/api/gpt.py b/vlmeval/api/gpt.py index f308e7c1..0fd9423a 100644 --- a/vlmeval/api/gpt.py +++ b/vlmeval/api/gpt.py @@ -41,9 +41,9 @@ def __init__(self, verbose: bool = True, system_prompt: str = None, temperature: float = 0, - timeout: int = 60, + timeout: int = 6000, api_base: str = None, - max_tokens: int = 1024, + max_tokens: int = 409600, img_size: int = 512, img_detail: str = 'low', use_azure: bool = False, diff --git a/vlmeval/dataset/mmbench_video.py b/vlmeval/dataset/mmbench_video.py index cded905c..0fd63926 100644 --- a/vlmeval/dataset/mmbench_video.py +++ b/vlmeval/dataset/mmbench_video.py @@ -88,7 +88,7 @@ def check_integrity(pth): return dict(data_file=data_file, root=osp.join(dataset_path, 'video')) - def build_prompt_pack(self, line, num_frames): + def build_prompt_pack(self, line, num_frames, fps=-1): if isinstance(line, int): assert line < len(self) video = self.videos[line] @@ -97,9 +97,9 @@ def build_prompt_pack(self, line, num_frames): elif isinstance(line, str): video = line - frames = self.save_video_frames(video, num_frames) + frames = self.save_video_frames(video, num_frames, fps) sub = self.data[self.data['video'] == video] - sys_prompt = self.SYS + self.FRAMES_TMPL_PACK.format(num_frames) + sys_prompt = self.SYS + self.FRAMES_TMPL_PACK.format(len(frames)) message = [dict(type='text', value=sys_prompt)] for im in frames: message.append(dict(type='image', value=im)) @@ -130,9 +130,9 @@ def build_prompt_nopack(self, line, num_frames, video_llm): message.append(dict(type='text', value=prompt)) return message - def build_prompt(self, line, num_frames, video_llm): + def build_prompt(self, line, num_frames, video_llm, fps): if self.pack and not video_llm: - return self.build_prompt_pack(line, num_frames) + return 
self.build_prompt_pack(line, num_frames, fps)
         else:
             return self.build_prompt_nopack(line, num_frames, video_llm)

diff --git a/vlmeval/dataset/video_base.py b/vlmeval/dataset/video_base.py
index d2e0b4a5..678d8ef3 100644
--- a/vlmeval/dataset/video_base.py
+++ b/vlmeval/dataset/video_base.py
@@ -21,6 +21,7 @@ def __init__(self,
         self.frame_root = osp.join(lmu_root, 'images', dataset)
         os.makedirs(self.frame_root, exist_ok=True)
         self.frame_tmpl = 'frame-{}-of-{}.jpg'
+        self.frame_tmpl_fps = 'frame-{}-of-{}-{}fps.jpg'

         self.data_root = ret['root']
         self.data_file = ret['data_file']
@@ -49,21 +50,58 @@ def frame_paths(self, video, num_frames=8):
         os.makedirs(frame_root, exist_ok=True)
         return [osp.join(frame_root, self.frame_tmpl.format(i, num_frames)) for i in range(1, num_frames + 1)]

-    def save_video_frames(self, video, num_frames=8):
-        frame_paths = self.frame_paths(video, num_frames)
-        flag = np.all([osp.exists(p) for p in frame_paths])
-        if flag:
+    def frame_paths_fps(self, video, num_frames=8, fps=-1):
+        frame_root = osp.join(self.frame_root, video)
+        os.makedirs(frame_root, exist_ok=True)
+        return [osp.join(frame_root, self.frame_tmpl_fps.format(i, num_frames, fps)) for i in range(1, num_frames + 1)]
+
+    def save_video_frames(self, video, num_frames=8, fps=-1):
+        if fps > 0:
+            vid_path = osp.join(self.data_root, video + '.mp4')
+            vid = decord.VideoReader(vid_path)
+
+            # compute the video's total frame count and duration
+            total_frames = len(vid)
+            video_fps = vid.get_avg_fps()
+            total_duration = total_frames / video_fps
+
+            # total number of frames to extract at the target fps
+            required_frames = int(total_duration * fps)
+
+            # sampling interval between extracted frames
+            step_size = video_fps / fps
+
+            # indices of the frames to extract
+            indices = [int(i * step_size) for i in range(required_frames)]
+
+            # extract the frames and save them
+            frame_paths = self.frame_paths_fps(video, len(indices), fps)
+            flag = np.all([osp.exists(p) for p in frame_paths])
+            if flag:
+                return frame_paths
+
+            images = [vid[i].asnumpy() for i in indices]
+            images = [Image.fromarray(arr) for arr in images]
+            for im, pth in zip(images, frame_paths):
+                if not osp.exists(pth):
+                    im.save(pth)
+            return frame_paths
+
+        else:
+            frame_paths = self.frame_paths(video, num_frames)
+            flag = np.all([osp.exists(p) for p in frame_paths])
+            if flag:
+                return frame_paths
+            vid_path = osp.join(self.data_root, video + '.mp4')
+            vid = decord.VideoReader(vid_path)
+            step_size = len(vid) / (num_frames + 1)
+            indices = [int(i * step_size) for i in range(1, num_frames + 1)]
+            images = [vid[i].asnumpy() for i in indices]
+            images = [Image.fromarray(arr) for arr in images]
+            for im, pth in zip(images, frame_paths):
+                if not osp.exists(pth):
+                    im.save(pth)
             return frame_paths
-        vid_path = osp.join(self.data_root, video + '.mp4')
-        vid = decord.VideoReader(vid_path)
-        step_size = len(vid) / (num_frames + 1)
-        indices = [int(i * step_size) for i in range(1, num_frames + 1)]
-        images = [vid[i].asnumpy() for i in indices]
-        images = [Image.fromarray(arr) for arr in images]
-        for im, pth in zip(images, frame_paths):
-            if not osp.exists(pth):
-                im.save(pth)
-        return frame_paths

     # Return a list of dataset names that are supported by this class, can override
     @classmethod
diff --git a/vlmeval/inference_video.py b/vlmeval/inference_video.py
index 825b69a9..4a06b509 100644
--- a/vlmeval/inference_video.py
+++ b/vlmeval/inference_video.py
@@ -18,7 +18,7 @@ def parse_args():


 # Only API model is accepted
-def infer_data_api(work_dir, model_name, dataset, nframe=8, pack=False, samples_dict={}, api_nproc=4):
+def infer_data_api(work_dir, model_name, dataset, nframe=8, pack=False, samples_dict={}, api_nproc=4, 
fps=-1):
     rank, world_size = get_rank_and_world_size()
     assert rank == 0 and world_size == 1
     dataset_name = dataset.dataset_name
@@ -26,10 +26,12 @@ def infer_data_api(work_dir, model_name, dataset, nframe=8, pack=False, samples_
     assert getattr(model, 'is_api', False)

     indices = list(samples_dict.keys())
-    structs = [dataset.build_prompt(samples_dict[idx], num_frames=nframe) for idx in indices]
+    structs = [dataset.build_prompt(samples_dict[idx], num_frames=nframe, video_llm=getattr(model, 'VIDEO_LLM', False), fps=fps) for idx in indices]

     packstr = 'pack' if pack else 'nopack'
     out_file = f'{work_dir}/{model_name}_{dataset_name}_{nframe}frame_{packstr}_supp.pkl'
+    if fps > 0:
+        out_file = out_file.replace('.pkl', f'_fps{fps}.pkl')
     res = load(out_file) if osp.exists(out_file) else {}

     structs = [s for i, s in zip(indices, structs) if i not in res]
@@ -45,7 +47,7 @@ def infer_data_api(work_dir, model_name, dataset, nframe=8, pack=False, samples_
     return res


-def infer_data(model_name, work_dir, dataset, out_file, nframe=8, pack=False, verbose=False, api_nproc=4):
+def infer_data(model_name, work_dir, dataset, out_file, nframe=8, pack=False, verbose=False, api_nproc=4, fps=-1):
     res = load(out_file) if osp.exists(out_file) else {}
     rank, world_size = get_rank_and_world_size()
     dataset_name = dataset.dataset_name
@@ -71,13 +73,16 @@ def infer_data(model_name, work_dir, dataset, out_file, nframe=8, pack=False, ve
             nframe=nframe,
             pack=pack,
             samples_dict={k: sample_map[k] for k in sample_indices_subrem},
-            api_nproc=api_nproc)
+            api_nproc=api_nproc,
+            fps=fps)
         for k in sample_indices_subrem:
             assert k in supp
         res.update(supp)
         dump(res, out_file)
         return model_name

+    #TODO: not fix below with fps
+
     for i, idx in tqdm(enumerate(sample_indices_subrem)):
         if idx in res:
             continue
@@ -109,7 +114,8 @@ def infer_data_job_video(
         pack=False,
         verbose=False,
         subtitle=False,
-        api_nproc=4):
+        api_nproc=4,
+        fps=-1):

     dataset_name = dataset.dataset_name
     packstr = 'pack' if pack else 'nopack'
@@ -118,7 +124,8 @@ def infer_data_job_video(
     if dataset_name == 'Video-MME':
         subtitle_str = 'subs' if subtitle else 'nosubs'
         result_file = result_file.replace('.xlsx', f'_{subtitle_str}.xlsx')
-
+    if fps > 0:
+        result_file = result_file.replace('.xlsx', f'_fps{fps}.xlsx')
     # Dump Predictions to Prev File if result file exists
     if osp.exists(result_file):
         return model_name
@@ -127,6 +134,9 @@ def infer_data_job_video(
     if dataset_name == 'Video-MME':
         subtitle_str = 'subs' if subtitle else 'nosubs'
         tmpl = tmpl.replace('.pkl', f'_{subtitle_str}.pkl')
+
+    if fps > 0:
+        tmpl = tmpl.replace('.pkl', f'_fps{fps}.pkl')
     out_file = tmpl.format(rank)

     model = infer_data(
@@ -137,7 +147,8 @@ def infer_data_job_video(
         pack=pack,
         out_file=out_file,
         verbose=verbose,
-        api_nproc=api_nproc)
+        api_nproc=api_nproc,
+        fps=fps)

     if world_size > 1:
         dist.barrier()

From a20061b5aeb640f1c845021452405733106378d2 Mon Sep 17 00:00:00 2001
From: FangXinyu-0913
Date: Mon, 12 Aug 2024 15:57:26 +0800
Subject: [PATCH 2/8] add support for gpt-4o-mini low-detail mode and cap
 image inputs at 250

---
 vlmeval/api/gpt.py | 8 +++++++-
 vlmeval/config.py  | 1 +
 2 files changed, 8 insertions(+), 1 deletion(-)

diff --git a/vlmeval/api/gpt.py b/vlmeval/api/gpt.py
index 0fd9423a..5c03e465 100644
--- a/vlmeval/api/gpt.py
+++ b/vlmeval/api/gpt.py
@@ -43,7 +43,7 @@ def __init__(self,
                  temperature: float = 0,
                  timeout: int = 6000,
                  api_base: str = None,
-                 max_tokens: int = 409600,
+                 max_tokens: int = 1024,
                  img_size: int = 512,
                  img_detail: str = 'low',
                  use_azure: bool = False,
@@ -133,17 +133,22 @@ def 
__init__(self, def prepare_itlist(self, inputs): assert np.all([isinstance(x, dict) for x in inputs]) has_images = np.sum([x['type'] == 'image' for x in inputs]) + img_counts = 0 if has_images: content_list = [] for msg in inputs: if msg['type'] == 'text': content_list.append(dict(type='text', text=msg['value'])) elif msg['type'] == 'image': + if img_counts >= 250: # for gpt-4o-mini + continue + from PIL import Image img = Image.open(msg['value']) b64 = encode_image_to_base64(img, target_size=self.img_size) img_struct = dict(url=f'data:image/jpeg;base64,{b64}', detail=self.img_detail) content_list.append(dict(type='image_url', image_url=img_struct)) + img_counts += 1 else: assert all([x['type'] == 'text' for x in inputs]) text = '\n'.join([x['value'] for x in inputs]) @@ -171,6 +176,7 @@ def generate_inner(self, inputs, **kwargs) -> str: context_window = GPT_context_window(self.model) max_tokens = min(max_tokens, context_window - self.get_token_len(inputs)) + print(f'Context Window: {context_window}, Tokens: {max_tokens}, existing token length: {self.get_token_len(inputs)}') if 0 < max_tokens <= 100: self.logger.warning( 'Less than 100 tokens left, ' diff --git a/vlmeval/config.py b/vlmeval/config.py index 9f046e31..f6a069e3 100644 --- a/vlmeval/config.py +++ b/vlmeval/config.py @@ -44,6 +44,7 @@ 'GPT4o_HIGH': partial(GPT4V, model='gpt-4o-2024-05-13', temperature=0, img_size=-1, img_detail='high', retry=10), 'GPT4o_20240806': partial(GPT4V, model='gpt-4o-2024-08-06', temperature=0, img_size=-1, img_detail='high', retry=10), 'GPT4o_MINI': partial(GPT4V, model='gpt-4o-mini-2024-07-18', temperature=0, img_size=-1, img_detail='high', retry=10), + 'GPT4o_MINI_LOW': partial(GPT4V, model='gpt-4o-mini-2024-07-18', temperature=0, img_size=512, img_detail='low', retry=10), # Gemini 'GeminiProVision': partial(GeminiProVision, model='gemini-1.0-pro', temperature=0, retry=10), 'GeminiPro1-5': partial(GeminiProVision, model='gemini-1.5-pro', temperature=0, retry=10), From 76ce43ee6e835884820ff77aa9d3ab243019e9f3 Mon Sep 17 00:00:00 2001 From: FangXinyu-0913 Date: Mon, 19 Aug 2024 15:08:53 +0800 Subject: [PATCH 3/8] update mmbench video fps sample --- vlmeval/dataset/mmbench_video.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/vlmeval/dataset/mmbench_video.py b/vlmeval/dataset/mmbench_video.py index 0fd63926..04521313 100644 --- a/vlmeval/dataset/mmbench_video.py +++ b/vlmeval/dataset/mmbench_video.py @@ -110,7 +110,7 @@ def build_prompt_pack(self, line, num_frames, fps=-1): message.append(dict(type='text', value=prompt)) return message - def build_prompt_nopack(self, line, num_frames, video_llm): + def build_prompt_nopack(self, line, num_frames, video_llm, fps): if isinstance(line, int): assert line < len(self) line = self.data.iloc[line] @@ -121,8 +121,8 @@ def build_prompt_nopack(self, line, num_frames, video_llm): message.append(dict(type='video', value=os.path.join(self.video_path, video_idx_path))) return message else: - frames = self.save_video_frames(line['video'], num_frames) - sys_prompt = self.FRAMES_TMPL_NOPACK.format(num_frames) + frames = self.save_video_frames(line['video'], num_frames, fps) + sys_prompt = self.FRAMES_TMPL_NOPACK.format(len(frames)) message = [dict(type='text', value=sys_prompt)] for im in frames: message.append(dict(type='image', value=im)) @@ -134,7 +134,7 @@ def build_prompt(self, line, num_frames, video_llm, fps): if self.pack and not video_llm: return self.build_prompt_pack(line, num_frames, fps) else: - return 
self.build_prompt_nopack(line, num_frames, video_llm) + return self.build_prompt_nopack(line, num_frames, video_llm, fps) @staticmethod def remove_side_quote(s, syms=[',', '"', "'"]): From ddb1e95d25f776f57767553180e92bd8f26eddca Mon Sep 17 00:00:00 2001 From: FangXinyu-0913 Date: Mon, 19 Aug 2024 15:17:56 +0800 Subject: [PATCH 4/8] update with lint --- vlmeval/api/gpt.py | 1 - vlmeval/inference_video.py | 7 +++---- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/vlmeval/api/gpt.py b/vlmeval/api/gpt.py index 5c03e465..0cef2adf 100644 --- a/vlmeval/api/gpt.py +++ b/vlmeval/api/gpt.py @@ -176,7 +176,6 @@ def generate_inner(self, inputs, **kwargs) -> str: context_window = GPT_context_window(self.model) max_tokens = min(max_tokens, context_window - self.get_token_len(inputs)) - print(f'Context Window: {context_window}, Tokens: {max_tokens}, existing token length: {self.get_token_len(inputs)}') if 0 < max_tokens <= 100: self.logger.warning( 'Less than 100 tokens left, ' diff --git a/vlmeval/inference_video.py b/vlmeval/inference_video.py index 4a06b509..bc63043e 100644 --- a/vlmeval/inference_video.py +++ b/vlmeval/inference_video.py @@ -26,7 +26,8 @@ def infer_data_api(work_dir, model_name, dataset, nframe=8, pack=False, samples_ assert getattr(model, 'is_api', False) indices = list(samples_dict.keys()) - structs = [dataset.build_prompt(samples_dict[idx], num_frames=nframe, video_llm=getattr(model, 'VIDEO_LLM', False), fps=fps) for idx in indices] + structs = [dataset.build_prompt(samples_dict[idx], num_frames=nframe, + video_llm=getattr(model, 'VIDEO_LLM', False), fps=fps) for idx in indices] packstr = 'pack' if pack else 'nopack' out_file = f'{work_dir}/{model_name}_{dataset_name}_{nframe}frame_{packstr}_supp.pkl' @@ -81,14 +82,12 @@ def infer_data(model_name, work_dir, dataset, out_file, nframe=8, pack=False, ve dump(res, out_file) return model_name - #TODO: not fix below with fps - for i, idx in tqdm(enumerate(sample_indices_subrem)): if idx in res: continue nframe = getattr(model, 'nframe', 0) if getattr(model, 'nframe', 0) > 0 else nframe # when using video-llm, build prompt returns video+question; otherwise, several frames+question - struct = dataset.build_prompt(sample_map[idx], num_frames=nframe, video_llm=getattr(model, 'VIDEO_LLM', False)) + struct = dataset.build_prompt(sample_map[idx], num_frames=nframe, video_llm=getattr(model, 'VIDEO_LLM', False), fps=fps) response = model.generate(message=struct, dataset=dataset_name) torch.cuda.empty_cache() From 9d69e276a06df9b5bccb618ba56001c44586b16e Mon Sep 17 00:00:00 2001 From: FangXinyu-0913 Date: Mon, 19 Aug 2024 15:20:09 +0800 Subject: [PATCH 5/8] update lint --- vlmeval/inference_video.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/vlmeval/inference_video.py b/vlmeval/inference_video.py index bc63043e..5d45d84e 100644 --- a/vlmeval/inference_video.py +++ b/vlmeval/inference_video.py @@ -87,7 +87,8 @@ def infer_data(model_name, work_dir, dataset, out_file, nframe=8, pack=False, ve continue nframe = getattr(model, 'nframe', 0) if getattr(model, 'nframe', 0) > 0 else nframe # when using video-llm, build prompt returns video+question; otherwise, several frames+question - struct = dataset.build_prompt(sample_map[idx], num_frames=nframe, video_llm=getattr(model, 'VIDEO_LLM', False), fps=fps) + struct = dataset.build_prompt(sample_map[idx], num_frames=nframe, + video_llm=getattr(model, 'VIDEO_LLM', False), fps=fps) response = model.generate(message=struct, dataset=dataset_name) torch.cuda.empty_cache() 
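Taken together, patches 1-5 thread one idea through the stack: when --fps is given, frame indices are derived from each clip's duration and native frame rate rather than from a fixed nframe count. Stripped of the dataset plumbing, the sampling rule that PATCH 1 adds to VideoBaseDataset.save_video_frames reduces to the standalone sketch below (the function name and output directory are illustrative, not part of the patches; decord and Pillow are the same libraries the patched code already uses):

    import os
    import decord
    from PIL import Image

    def extract_frames_at_fps(vid_path, out_dir, fps=1.0):
        # Take one frame every video_fps / fps source frames,
        # i.e. roughly duration * fps frames in total.
        vid = decord.VideoReader(vid_path)
        video_fps = vid.get_avg_fps()
        total_duration = len(vid) / video_fps
        required_frames = int(total_duration * fps)
        step_size = video_fps / fps
        indices = [int(i * step_size) for i in range(required_frames)]

        os.makedirs(out_dir, exist_ok=True)
        frame_paths = []
        for rank, idx in enumerate(indices, start=1):
            # Same naming scheme as frame_tmpl_fps, so cached frames are reused across runs.
            pth = os.path.join(out_dir, f'frame-{rank}-of-{len(indices)}-{fps}fps.jpg')
            if not os.path.exists(pth):
                Image.fromarray(vid[idx].asnumpy()).save(pth)
            frame_paths.append(pth)
        return frame_paths

Because int(total_duration * fps) truncates, a 59.9 s clip sampled at 1 fps yields 59 frames, and the last index, int((required_frames - 1) * step_size), always stays within the clip.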
From 052f58b84c9a86a4fb356b90c7aaa3b02a6adb87 Mon Sep 17 00:00:00 2001
From: FangXinyu-0913
Date: Sat, 12 Oct 2024 21:47:56 +0800
Subject: [PATCH 6/8] update with video-mme and mvbench_mp4

---
 run.py                                 | 19 ++++++++--
 vlmeval/dataset/mvbench.py             | 51 ++++++++++++++++++--------
 vlmeval/dataset/videomme.py            | 23 ++++++++----
 vlmeval/inference_video.py             | 35 ++++++++++++++----
 vlmeval/vlm/video_llm/chat_uni_vi.py   |  6 ++-
 vlmeval/vlm/video_llm/llama_vid.py     |  9 +++--
 vlmeval/vlm/video_llm/video_chatgpt.py |  2 +-
 vlmeval/vlm/video_llm/video_llava.py   |  4 +-
 8 files changed, 104 insertions(+), 45 deletions(-)

diff --git a/run.py b/run.py
index 36f3f832..70e77abb 100644
--- a/run.py
+++ b/run.py
@@ -19,7 +19,7 @@ def parse_args():
     parser.add_argument('--nframe', type=int, default=8)
     parser.add_argument('--pack', action='store_true')
     parser.add_argument('--use-subtitle', action='store_true')
-    parser.add_argument('--fps', type=int, default=-1)
+    parser.add_argument('--fps', type=float, default=-1)
     # Work Dir
     parser.add_argument('--work-dir', type=str, default='./outputs', help='select the output directory')
     # Infer + Eval or Infer Only
@@ -95,15 +95,25 @@ def main():
                 continue

             result_file = f'{pred_root}/{model_name}_{dataset_name}.xlsx'
+            if args.fps > 0:  # for video datasets, fps takes priority over nframe
+                if dataset_name == 'MVBench':
+                    raise ValueError('MVBench does not support fps sampling, please switch to MVBench_MP4!')
+                args.nframe = 0
             if dataset_name in ['MMBench-Video']:
                 packstr = 'pack' if args.pack else 'nopack'
-                result_file = f'{pred_root}/{model_name}_{dataset_name}_{args.nframe}frame_{packstr}.xlsx'
+                if args.nframe > 0:
+                    result_file = f'{pred_root}/{model_name}_{dataset_name}_{args.nframe}frame_{packstr}.xlsx'
+                else:
+                    result_file = f'{pred_root}/{model_name}_{dataset_name}_{args.fps}fps_{packstr}.xlsx'
             elif dataset.MODALITY == 'VIDEO':
                 if args.pack:
                     logger.info(f'{dataset_name} not support Pack Mode, directly change to unpack')
                     args.pack = False
                 packstr = 'pack' if args.pack else 'nopack'
-                result_file = f'{pred_root}/{model_name}_{dataset_name}_{args.nframe}frame_{packstr}.xlsx'
+                if args.nframe > 0:
+                    result_file = f'{pred_root}/{model_name}_{dataset_name}_{args.nframe}frame_{packstr}.xlsx'
+                else:
+                    result_file = f'{pred_root}/{model_name}_{dataset_name}_{args.fps}fps_{packstr}.xlsx'
                 if dataset_name in ['Video-MME']:
                     subtitlestr = 'subs' if args.use_subtitle else 'nosubs'
                     result_file = result_file.replace('.xlsx', f'_{subtitlestr}.xlsx')
@@ -129,7 +139,8 @@ def main():
                         pack=args.pack,
                         verbose=args.verbose,
                         subtitle=args.use_subtitle,
-                        api_nproc=args.nproc)
+                        api_nproc=args.nproc,
+                        fps=args.fps)
                 elif dataset.TYPE == 'MT':
                     model = infer_data_job_mt(
                         model,
diff --git a/vlmeval/dataset/mvbench.py b/vlmeval/dataset/mvbench.py
index ad9b59bd..8f022c57 100644
--- a/vlmeval/dataset/mvbench.py
+++ b/vlmeval/dataset/mvbench.py
@@ -23,7 +23,7 @@


 class MVBench(VideoBaseDataset):

-    MD5 = 'ae2a2607e2f8618155709220c6e927a6'
+    MD5 = 'fd21d36522cdedd46d84dc46715ad832'

     SYS = """Carefully watch the video and pay attention to the cause and sequence of events, \
the detail and movement of objects, and the action and pose of persons. \
Based on your observations, select the best option that accurately addresses the question. 
@@ -123,7 +123,7 @@ def generate_tsv(pth): for data in json_data: self.data_list.append({ 'task_type': k, - 'prefix': v[1].replace('your_data_path', os.path.join(dataset_path, 'video')), + 'prefix': v[1].replace('your_data_path', 'video'), 'data_type': v[2], 'bound': v[3], 'start': data['start'] if 'start' in data.keys() else None, @@ -274,7 +274,7 @@ def qa_template(self, data): return question, answer def load_into_video_and_process(self, line): - video_path = os.path.join(line['prefix'], line['video']) + video_path = os.path.join(self.data_root, line['prefix'], line['video']) if line['data_type'] in ['gif'] or os.path.splitext(video_path)[1] in ['.webm']: processed_video_path = video_path.replace(os.path.splitext(video_path)[1], '.mp4') @@ -315,7 +315,7 @@ def save_video_into_images(self, line, num_frames): line['start'], line['end'], ) - video_path = os.path.join(line['prefix'], line['video']) + video_path = os.path.join(self.data_root, line['prefix'], line['video']) decord_method = self.decord_method[line['data_type']] self.num_segments = num_frames if num_frames > 0 else self.nframe torch_imgs = decord_method(video_path, bound) @@ -510,7 +510,7 @@ def qa_template(self, data): answer = f"({chr(ord('A') + answer_idx)}) {answer}" return question, answer - def get_index(self, max_frame): + def get_index_by_frame(self, max_frame): seg_size = float(max_frame) / self.num_segments frame_indices = np.array([ int((seg_size / 2) + np.round(seg_size * idx)) @@ -518,21 +518,37 @@ def get_index(self, max_frame): ]) return frame_indices - def read_video(self, video_path, bound=None): + def get_index_by_fps(self, vid, fps): + total_frames = len(vid) + video_fps = vid.get_avg_fps() + total_duration = total_frames / video_fps + required_frames = int(total_duration * fps) + step_size = video_fps / fps + frame_indices = np.array([int(i * step_size) for i in range(required_frames)]) + self.num_segments = len(frame_indices) + return frame_indices + + def read_video(self, video_path, fps=-1): vr = VideoReader(video_path, ctx=cpu(0), num_threads=1) max_frame = len(vr) - 1 images_group = list() - frame_indices = self.get_index(max_frame) + if fps < 0: + frame_indices = self.get_index_by_frame(max_frame) + else: + frame_indices = self.get_index_by_fps(vr, fps) + for frame_index in frame_indices: img = Image.fromarray(vr[frame_index].asnumpy()) images_group.append(img) torch_imgs = self.transform(images_group) return torch_imgs - def save_video_frames(self, imgs, video_name, frames): - - frame_paths = self.frame_paths(video_name, frames) + def save_video_frames(self, imgs, video_name, frames, fps): + if fps > 0: + frame_paths = self.frame_paths_fps(video_name, frames, fps) + else: + frame_paths = self.frame_paths(video_name, frames) flag = np.all([osp.exists(p) for p in frame_paths]) if not flag: @@ -546,14 +562,17 @@ def save_video_frames(self, imgs, video_name, frames): return frame_paths - def save_video_into_images(self, line, num_frames): + def save_video_into_images(self, line, num_frames, fps=-1): video_path = os.path.join(self.data_root, line['prefix'], line['video']) - self.num_segments = num_frames if num_frames > 0 else self.nframe - torch_imgs = self.read_video(video_path) - img_frame_paths = self.save_video_frames(torch_imgs, line['video'], self.num_segments) + if fps <= 0: + self.num_segments = num_frames if num_frames > 0 else self.nframe + else: + self.num_segments = 0 + torch_imgs = self.read_video(video_path, fps) + img_frame_paths = self.save_video_frames(torch_imgs, line['video'], 
self.num_segments, fps) return img_frame_paths - def build_prompt(self, line, num_frames, video_llm): + def build_prompt(self, line, num_frames, video_llm, fps): if isinstance(line, int): assert line < len(self) line = self.data.iloc[line] @@ -565,7 +584,7 @@ def build_prompt(self, line, num_frames, video_llm): if video_llm: message.append(dict(type='video', value=video_path)) else: - img_frame_paths = self.save_video_into_images(line, num_frames) + img_frame_paths = self.save_video_into_images(line, num_frames, fps) for im in img_frame_paths: message.append(dict(type='image', value=im)) message.append(dict(type='text', value='\nOnly give the best option.')) diff --git a/vlmeval/dataset/videomme.py b/vlmeval/dataset/videomme.py index b8a29afe..985b03b0 100644 --- a/vlmeval/dataset/videomme.py +++ b/vlmeval/dataset/videomme.py @@ -148,26 +148,33 @@ def generate_tsv(pth): return dict(data_file=data_file, root=dataset_path) - def save_video_frames(self, video, num_frames=8): + def save_video_frames(self, video, num_frames=8, fps=-1, video_llm=False): vid_path = osp.join(self.data_root, 'video', video + '.mp4') vid = decord.VideoReader(vid_path) - step_size = len(vid) / (num_frames + 1) - indices = [int(i * step_size) for i in range(1, num_frames + 1)] - video_info = { 'fps': vid.get_avg_fps(), 'n_frames': len(vid), } + if num_frames > 0 and fps < 0: + step_size = len(vid) / (num_frames + 1) + indices = [int(i * step_size) for i in range(1, num_frames + 1)] + frame_paths = self.frame_paths(video, num_frames) + elif fps > 0: + # not constrained by num_frames, get frames by fps + total_duration = video_info['n_frames'] / video_info['fps'] + required_frames = int(total_duration * fps) + step_size = video_info['fps'] / fps + indices = [int(i * step_size) for i in range(required_frames)] + frame_paths = self.frame_paths_fps(video, len(indices), fps) - frame_paths = self.frame_paths(video, num_frames) flag = np.all([osp.exists(p) for p in frame_paths]) if not flag: images = [vid[i].asnumpy() for i in indices] images = [Image.fromarray(arr) for arr in images] for im, pth in zip(images, frame_paths): - if not osp.exists(pth): + if not osp.exists(pth) and not video_llm: im.save(pth) return frame_paths, indices, video_info @@ -176,12 +183,12 @@ def save_video_into_images(self, line, num_frames=8): frame_paths, indices, video_info = self.save_video_frames(line['video'], num_frames) return frame_paths - def build_prompt(self, line, num_frames, video_llm): + def build_prompt(self, line, num_frames, video_llm, fps): if isinstance(line, int): assert line < len(self) line = self.data.iloc[line] - frames, indices, video_info = self.save_video_frames(line['video'], num_frames) + frames, indices, video_info = self.save_video_frames(line['video'], num_frames, fps, video_llm) if self.use_subtitle and os.path.exists(osp.join(self.data_root, line['subtitle_path'])): import pysubs2 diff --git a/vlmeval/inference_video.py b/vlmeval/inference_video.py index a94037db..ff697c19 100644 --- a/vlmeval/inference_video.py +++ b/vlmeval/inference_video.py @@ -30,9 +30,10 @@ def infer_data_api(work_dir, model_name, dataset, nframe=8, pack=False, samples_ video_llm=getattr(model, 'VIDEO_LLM', False), fps=fps) for idx in indices] packstr = 'pack' if pack else 'nopack' - out_file = f'{work_dir}/{model_name}_{dataset_name}_{nframe}frame_{packstr}_supp.pkl' - if fps > 0: - out_file = out_file.replace('.pkl', f'_fps{fps}.pkl') + if nframe > 0: + out_file = f'{work_dir}/{model_name}_{dataset_name}_{nframe}frame_{packstr}_supp.pkl' + 
else:
+        out_file = f'{work_dir}/{model_name}_{dataset_name}_{fps}fps_{packstr}_supp.pkl'

     res = load(out_file) if osp.exists(out_file) else {}
     structs = [s for i, s in zip(indices, structs) if i not in res]
@@ -85,10 +86,22 @@ def infer_data(model_name, work_dir, dataset, out_file, nframe=8, pack=False, ve
     for i, idx in tqdm(enumerate(sample_indices_subrem)):
         if idx in res:
             continue
-        # adapt to model frame sample number first
-        nframe = getattr(model, 'nframe', 0) if getattr(model, 'nframe', 0) > 0 else nframe
-        # when using video-llm, build prompt returns video+question; otherwise, several frames+question
+        if getattr(model, 'nframe', 0) > 0:
+            if nframe > 0:
+                print(f'{model_name} is a video-llm; overriding its default with nframe={nframe}')
+                setattr(model, 'nframe', nframe)
+            else:
+                raise ValueError(f'{model_name} samples frames by nframe; please use --nframe instead of --fps')
+        if getattr(model, 'fps', 0) > 0:
+            if fps > 0:
+                print(f'{model_name} is a video-llm; overriding its default with fps={fps}')
+                setattr(model, 'fps', fps)
+            else:
+                raise ValueError(f'{model_name} samples frames by fps; please use --fps instead of --nframe')
+        if hasattr(model, 'use_custom_prompt') and model.use_custom_prompt(dataset_name):
+            if nframe == 0:
+                raise ValueError(f'custom prompts require nframe; fps sampling is not supported for {model_name}')
         struct = model.build_prompt(
             dataset.data.iloc[sample_map[idx]], dataset=dataset, num_frames=nframe,
             video_llm=getattr(model, 'VIDEO_LLM', False)
@@ -129,7 +142,10 @@ def infer_data_job_video(
     dataset_name = dataset.dataset_name
     packstr = 'pack' if pack else 'nopack'
     rank, world_size = get_rank_and_world_size()
-    result_file = osp.join(work_dir, f'{model_name}_{dataset_name}_{nframe}frame_{packstr}.xlsx')
+    if nframe > 0:
+        result_file = osp.join(work_dir, f'{model_name}_{dataset_name}_{nframe}frame_{packstr}.xlsx')
+    else:
+        result_file = osp.join(work_dir, f'{model_name}_{dataset_name}_{fps}fps_{packstr}.xlsx')
     if dataset_name == 'Video-MME':
         subtitle_str = 'subs' if subtitle else 'nosubs'
         result_file = result_file.replace('.xlsx', f'_{subtitle_str}.xlsx')
@@ -139,7 +155,10 @@ def infer_data_job_video(
     if osp.exists(result_file):
         return model_name

-    tmpl = osp.join(work_dir, '{}' + f'{world_size}_{dataset_name}_{nframe}frame_{packstr}.pkl')
+    if nframe > 0:
+        tmpl = osp.join(work_dir, '{}' + f'{world_size}_{dataset_name}_{nframe}frame_{packstr}.pkl')
+    else:
+        tmpl = osp.join(work_dir, '{}' + f'{world_size}_{dataset_name}_{fps}fps_{packstr}.pkl')
     if dataset_name == 'Video-MME':
         subtitle_str = 'subs' if subtitle else 'nosubs'
         tmpl = tmpl.replace('.pkl', f'_{subtitle_str}.pkl')
diff --git a/vlmeval/vlm/video_llm/chat_uni_vi.py b/vlmeval/vlm/video_llm/chat_uni_vi.py
index 48a1c856..cd6b92ef 100644
--- a/vlmeval/vlm/video_llm/chat_uni_vi.py
+++ b/vlmeval/vlm/video_llm/chat_uni_vi.py
@@ -86,6 +86,7 @@ class Chatunivi(BaseModel):
     INSTALL_REQ = True
     INTERLEAVE = False
     VIDEO_LLM = True
+    # samples the video at 1 fps (at most 64 frames)

     def __init__(self, model_path='Chat-UniVi/Chat-UniVi', **kwargs):
         assert model_path is not None
@@ -105,7 +106,7 @@ def __init__(self, model_path='Chat-UniVi/Chat-UniVi', **kwargs):
         self.processor = image_processor
         self.context_len = context_len
         self.kwargs = kwargs
-        self.nframe = 64
+        self.fps = 1
         self.resolution = 224
         if 'v1.5' in model_path:
             self.resolution = 336
@@ -138,7 +139,8 @@ def get_model_output(self, model, video_processor, tokenizer, video, qs):
             m = m.to(dtype=torch.bfloat16)

         video_frames, slice_len = _get_rawvideo_dec(
-            video, video_processor, 
max_frames=MAX_IMAGE_LENGTH, image_resolution=self.resolution + video, video_processor, max_frames=MAX_IMAGE_LENGTH, + image_resolution=self.resolution, video_framerate=self.fps ) if model.config.mm_use_im_start_end: diff --git a/vlmeval/vlm/video_llm/llama_vid.py b/vlmeval/vlm/video_llm/llama_vid.py index dc7a01b1..d171d8df 100644 --- a/vlmeval/vlm/video_llm/llama_vid.py +++ b/vlmeval/vlm/video_llm/llama_vid.py @@ -11,11 +11,11 @@ from huggingface_hub import snapshot_download -def load_video(video_path): +def load_video(video_path, setting_fps): vr = VideoReader(video_path, ctx=cpu(0)) total_frame_num = len(vr) fps = round(vr.get_avg_fps()) - frame_idx = [i for i in range(0, total_frame_num, fps)] + frame_idx = [i for i in range(0, total_frame_num, int(fps / setting_fps))] spare_frames = vr.get_batch(frame_idx).asnumpy() return spare_frames @@ -31,6 +31,7 @@ class LLaMAVID(BaseModel): INSTALL_REQ = True INTERLEAVE = False VIDEO_LLM = True + # sample 1 fps from the video def __init__(self, model_path='YanweiLi/llama-vid-7b-full-224-video-fps-1', **kwargs): assert model_path is not None @@ -61,7 +62,7 @@ def __init__(self, model_path='YanweiLi/llama-vid-7b-full-224-video-fps-1', **kw self.processor = image_processor self.context_len = context_len self.kwargs = kwargs - self.nframe = 8 + self.fps = 1 def get_model_output(self, model, video_processor, tokenizer, video, qs): from llamavid.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN @@ -83,7 +84,7 @@ def get_model_output(self, model, video_processor, tokenizer, video, qs): # Check if the video exists if os.path.exists(video): - video = load_video(video) + video = load_video(video, self.fps) video = video_processor.preprocess(video, return_tensors='pt')['pixel_values'].half().cuda() video = [video] diff --git a/vlmeval/vlm/video_llm/video_chatgpt.py b/vlmeval/vlm/video_llm/video_chatgpt.py index 2829bcee..3cf6a39f 100644 --- a/vlmeval/vlm/video_llm/video_chatgpt.py +++ b/vlmeval/vlm/video_llm/video_chatgpt.py @@ -14,6 +14,7 @@ class VideoChatGPT(BaseModel): INSTALL_REQ = True INTERLEAVE = False VIDEO_LLM = True + # sample a video in 100 frames def __init__(self, model_path='MBZUAI/Video-ChatGPT-7B', dir_root=None, **kwargs): assert model_path is not None @@ -40,7 +41,6 @@ def __init__(self, model_path='MBZUAI/Video-ChatGPT-7B', dir_root=None, **kwargs self.context_len = video_token_len self.kwargs = kwargs self.vision_tower = vision_tower - self.nframe = 8 def get_model_output(self, model, video_processor, tokenizer, video, qs): from video_chatgpt.eval.model_utils import load_video diff --git a/vlmeval/vlm/video_llm/video_llava.py b/vlmeval/vlm/video_llm/video_llava.py index 19dfe18d..72a82ea9 100644 --- a/vlmeval/vlm/video_llm/video_llava.py +++ b/vlmeval/vlm/video_llm/video_llava.py @@ -26,6 +26,7 @@ class VideoLLaVA_HF(BaseModel): INSTALL_REQ = False INTERLEAVE = False VIDEO_LLM = True + # sample a video in 8 frames def __init__(self, model_path='LanguageBind/Video-LLaVA-7B-hf', **kwargs): try: @@ -42,7 +43,6 @@ def __init__(self, model_path='LanguageBind/Video-LLaVA-7B-hf', **kwargs): self.model.eval().cuda() self.processor = VideoLlavaProcessor.from_pretrained(model_path) self.kwargs = kwargs - self.nframe = 8 torch.cuda.empty_cache() def generate_inner(self, message, dataset=None): @@ -81,6 +81,7 @@ class VideoLLaVA(BaseModel): INSTALL_REQ = True INTERLEAVE = False VIDEO_LLM = True + # sample a video in 8 frames def __init__(self, model_path='LanguageBind/Video-LLaVA-7B', **kwargs): assert model_path is not None @@ -104,7 
+105,6 @@ def __init__(self, model_path='LanguageBind/Video-LLaVA-7B', **kwargs):
         self.processor = processor
         self.context_len = context_len
         self.kwargs = kwargs
-        self.nframe = 8

     def get_model_output(self, model, video_processor, tokenizer, video, qs):
         from videollava.conversation import conv_templates, SeparatorStyle

From 0d1dccb474a2dad813f7d9c81d3118272e5bc57f Mon Sep 17 00:00:00 2001
From: root
Date: Sat, 12 Oct 2024 22:03:19 +0800
Subject: [PATCH 7/8] revert temporary gpt.py changes

---
 vlmeval/api/gpt.py | 7 +------
 1 file changed, 1 insertion(+), 6 deletions(-)

diff --git a/vlmeval/api/gpt.py b/vlmeval/api/gpt.py
index 0cef2adf..f308e7c1 100644
--- a/vlmeval/api/gpt.py
+++ b/vlmeval/api/gpt.py
@@ -41,7 +41,7 @@ def __init__(self,
                  verbose: bool = True,
                  system_prompt: str = None,
                  temperature: float = 0,
-                 timeout: int = 6000,
+                 timeout: int = 60,
                  api_base: str = None,
                  max_tokens: int = 1024,
                  img_size: int = 512,
@@ -133,22 +133,17 @@ def __init__(self,
     def prepare_itlist(self, inputs):
         assert np.all([isinstance(x, dict) for x in inputs])
         has_images = np.sum([x['type'] == 'image' for x in inputs])
-        img_counts = 0
         if has_images:
             content_list = []
             for msg in inputs:
                 if msg['type'] == 'text':
                     content_list.append(dict(type='text', text=msg['value']))
                 elif msg['type'] == 'image':
-                    if img_counts >= 250:  # for gpt-4o-mini
-                        continue
-                    from PIL import Image
                     img = Image.open(msg['value'])
                     b64 = encode_image_to_base64(img, target_size=self.img_size)
                     img_struct = dict(url=f'data:image/jpeg;base64,{b64}', detail=self.img_detail)
                     content_list.append(dict(type='image_url', image_url=img_struct))
-                    img_counts += 1
         else:
             assert all([x['type'] == 'text' for x in inputs])
             text = '\n'.join([x['value'] for x in inputs])

From 5f154da5f692606920c411d0da4cecc7446ef237 Mon Sep 17 00:00:00 2001
From: root
Date: Sat, 12 Oct 2024 22:05:09 +0800
Subject: [PATCH 8/8] revert temporary config change

---
 vlmeval/config.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/vlmeval/config.py b/vlmeval/config.py
index d7c4c2f7..c0245cf6 100644
--- a/vlmeval/config.py
+++ b/vlmeval/config.py
@@ -59,7 +59,6 @@
     'GPT4o_HIGH': partial(GPT4V, model='gpt-4o-2024-05-13', temperature=0, img_size=-1, img_detail='high', retry=10),
     'GPT4o_20240806': partial(GPT4V, model='gpt-4o-2024-08-06', temperature=0, img_size=-1, img_detail='high', retry=10),
     'GPT4o_MINI': partial(GPT4V, model='gpt-4o-mini-2024-07-18', temperature=0, img_size=-1, img_detail='high', retry=10),
-    'GPT4o_MINI_LOW': partial(GPT4V, model='gpt-4o-mini-2024-07-18', temperature=0, img_size=512, img_detail='low', retry=10),
     # Gemini
     'GeminiPro1-0': partial(GeminiProVision, model='gemini-1.0-pro', temperature=0, retry=10),  # now GeminiPro1-0 is only supported by vertex backend
     'GeminiPro1-5': partial(GeminiProVision, model='gemini-1.5-pro', temperature=0, retry=10),
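After patch 6, --fps and --nframe are mutually exclusive sampling modes: run.py zeroes nframe whenever fps > 0, and every artifact name records whichever mode was active, so a hypothetical invocation would look something like python run.py --model GPT4o_MINI --data MMBench-Video --fps 1 --pack. The naming convention that run.py and infer_data_job_video now share can be summarized in the illustrative helper below (a sketch for reference, not code from the series; the subtitle handling mirrors the Video-MME branch):

    def result_filename(model, dataset, nframe=8, fps=-1, pack=False, subtitle=False):
        # fps > 0 takes priority over nframe, mirroring run.py after patch 6.
        sample_str = f'{fps}fps' if fps > 0 else f'{nframe}frame'
        pack_str = 'pack' if pack else 'nopack'
        name = f'{model}_{dataset}_{sample_str}_{pack_str}.xlsx'
        if dataset == 'Video-MME':
            # Video-MME results additionally record whether subtitles were used.
            name = name.replace('.xlsx', '_subs.xlsx' if subtitle else '_nosubs.xlsx')
        return name

    assert result_filename('GPT4o_MINI', 'MMBench-Video', fps=1, pack=True) == \
        'GPT4o_MINI_MMBench-Video_1fps_pack.xlsx'

Keeping fps out of the name when nframe is used (and vice versa) means reruns with a different sampling mode never collide with cached predictions from an earlier run.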