Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add fps sample in mmbench video #381

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion run.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ def parse_args():
parser.add_argument('--nframe', type=int, default=8)
parser.add_argument('--pack', action='store_true')
parser.add_argument('--use-subtitle', action='store_true')
parser.add_argument('--fps', type=int, default=-1)
# Work Dir
parser.add_argument('--work-dir', type=str, default='.', help='select the output directory')
# Infer + Eval or Infer Only
Expand All @@ -35,6 +36,7 @@ def parse_args():
parser.add_argument('--ignore', action='store_true', help='Ignore failed indices. ')
# Rerun: will remove all evaluation temp files
parser.add_argument('--rerun', action='store_true')

args = parser.parse_args()
return args

Expand Down Expand Up @@ -100,6 +102,9 @@ def main():
subtitlestr = 'subs' if args.use_subtitle else 'nosubs'
result_file = f'{pred_root}/{model_name}_{dataset_name}_{args.nframe}frame_{packstr}_{subtitlestr}.xlsx'

if args.fps > 0:
result_file = result_file.replace('.xlsx', f'_fps{args.fps}.xlsx')

if dataset.TYPE == 'MT':
result_file = result_file.replace('.xlsx', '.tsv')

Expand All @@ -121,7 +126,8 @@ def main():
pack=args.pack,
verbose=args.verbose,
subtitle=args.use_subtitle,
api_nproc=args.nproc)
api_nproc=args.nproc,
fps=args.fps)
elif dataset.TYPE == 'MT':
model = infer_data_job_mt(
model,
Expand Down
7 changes: 6 additions & 1 deletion vlmeval/api/gpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def __init__(self,
verbose: bool = True,
system_prompt: str = None,
temperature: float = 0,
timeout: int = 60,
timeout: int = 6000,
api_base: str = None,
max_tokens: int = 1024,
img_size: int = 512,
Expand Down Expand Up @@ -133,17 +133,22 @@ def __init__(self,
def prepare_itlist(self, inputs):
assert np.all([isinstance(x, dict) for x in inputs])
has_images = np.sum([x['type'] == 'image' for x in inputs])
img_counts = 0
if has_images:
content_list = []
for msg in inputs:
if msg['type'] == 'text':
content_list.append(dict(type='text', text=msg['value']))
elif msg['type'] == 'image':
if img_counts >= 250: # for gpt-4o-mini
continue

from PIL import Image
img = Image.open(msg['value'])
b64 = encode_image_to_base64(img, target_size=self.img_size)
img_struct = dict(url=f'data:image/jpeg;base64,{b64}', detail=self.img_detail)
content_list.append(dict(type='image_url', image_url=img_struct))
img_counts += 1
else:
assert all([x['type'] == 'text' for x in inputs])
text = '\n'.join([x['value'] for x in inputs])
Expand Down
1 change: 1 addition & 0 deletions vlmeval/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
'GPT4o_HIGH': partial(GPT4V, model='gpt-4o-2024-05-13', temperature=0, img_size=-1, img_detail='high', retry=10),
'GPT4o_20240806': partial(GPT4V, model='gpt-4o-2024-08-06', temperature=0, img_size=-1, img_detail='high', retry=10),
'GPT4o_MINI': partial(GPT4V, model='gpt-4o-mini-2024-07-18', temperature=0, img_size=-1, img_detail='high', retry=10),
'GPT4o_MINI_LOW': partial(GPT4V, model='gpt-4o-mini-2024-07-18', temperature=0, img_size=512, img_detail='low', retry=10),
# Gemini
'GeminiProVision': partial(GeminiProVision, model='gemini-1.0-pro', temperature=0, retry=10),
'GeminiPro1-5': partial(GeminiProVision, model='gemini-1.5-pro', temperature=0, retry=10),
Expand Down
18 changes: 9 additions & 9 deletions vlmeval/dataset/mmbench_video.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ def check_integrity(pth):

return dict(data_file=data_file, root=osp.join(dataset_path, 'video'))

def build_prompt_pack(self, line, num_frames):
def build_prompt_pack(self, line, num_frames, fps=-1):
if isinstance(line, int):
assert line < len(self)
video = self.videos[line]
Expand All @@ -97,9 +97,9 @@ def build_prompt_pack(self, line, num_frames):
elif isinstance(line, str):
video = line

frames = self.save_video_frames(video, num_frames)
frames = self.save_video_frames(video, num_frames, fps)
sub = self.data[self.data['video'] == video]
sys_prompt = self.SYS + self.FRAMES_TMPL_PACK.format(num_frames)
sys_prompt = self.SYS + self.FRAMES_TMPL_PACK.format(len(frames))
message = [dict(type='text', value=sys_prompt)]
for im in frames:
message.append(dict(type='image', value=im))
Expand All @@ -110,7 +110,7 @@ def build_prompt_pack(self, line, num_frames):
message.append(dict(type='text', value=prompt))
return message

def build_prompt_nopack(self, line, num_frames, video_llm):
def build_prompt_nopack(self, line, num_frames, video_llm, fps):
if isinstance(line, int):
assert line < len(self)
line = self.data.iloc[line]
Expand All @@ -121,20 +121,20 @@ def build_prompt_nopack(self, line, num_frames, video_llm):
message.append(dict(type='video', value=os.path.join(self.video_path, video_idx_path)))
return message
else:
frames = self.save_video_frames(line['video'], num_frames)
sys_prompt = self.FRAMES_TMPL_NOPACK.format(num_frames)
frames = self.save_video_frames(line['video'], num_frames, fps)
sys_prompt = self.FRAMES_TMPL_NOPACK.format(len(frames))
message = [dict(type='text', value=sys_prompt)]
for im in frames:
message.append(dict(type='image', value=im))
prompt = 'Question: {}\nAnswer: '.format(line['question'])
message.append(dict(type='text', value=prompt))
return message

def build_prompt(self, line, num_frames, video_llm, fps):
    """Build the model input message for one sample.

    Dispatches to the packed builder (all questions of a video in one
    request) when ``self.pack`` is set and the model is not a video-LLM;
    otherwise falls back to the per-question (no-pack) builder.

    Args:
        line: Sample index, data row, or video name (interpreted by the
            downstream builders).
        num_frames: Number of frames to sample when ``fps`` <= 0.
        video_llm: True when the model consumes raw video instead of frames.
        fps: If > 0, sample frames at this rate instead of a fixed count.

    Returns:
        A list of message dicts (text/image/video entries).
    """
    if self.pack and not video_llm:
        return self.build_prompt_pack(line, num_frames, fps)
    else:
        return self.build_prompt_nopack(line, num_frames, video_llm, fps)

@staticmethod
def remove_side_quote(s, syms=[',', '"', "'"]):
Expand Down
66 changes: 52 additions & 14 deletions vlmeval/dataset/video_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ def __init__(self,
self.frame_root = osp.join(lmu_root, 'images', dataset)
os.makedirs(self.frame_root, exist_ok=True)
self.frame_tmpl = 'frame-{}-of-{}.jpg'
self.frame_tmpl_fps = 'frame-{}-of-{}-{}fps.jpg'

self.data_root = ret['root']
self.data_file = ret['data_file']
Expand Down Expand Up @@ -49,21 +50,58 @@ def frame_paths(self, video, num_frames=8):
os.makedirs(frame_root, exist_ok=True)
return [osp.join(frame_root, self.frame_tmpl.format(i, num_frames)) for i in range(1, num_frames + 1)]

def save_video_frames(self, video, num_frames=8):
frame_paths = self.frame_paths(video, num_frames)
flag = np.all([osp.exists(p) for p in frame_paths])
if flag:
def frame_paths_fps(self, video, num_frames=8, fps=-1):
    """Return the cache-file paths for fps-based frame sampling.

    Creates the per-video frame directory if it does not exist yet; the
    filenames embed both the frame count and the sampling rate so caches
    for different settings never collide.
    """
    target_dir = osp.join(self.frame_root, video)
    os.makedirs(target_dir, exist_ok=True)
    template = self.frame_tmpl_fps
    return [
        osp.join(target_dir, template.format(idx, num_frames, fps))
        for idx in range(1, num_frames + 1)
    ]

def save_video_frames(self, video, num_frames=8, fps=-1):
    """Extract frames from a video and cache them as JPEG files on disk.

    Two sampling modes:
      * fps <= 0 (default): sample ``num_frames`` frames uniformly,
        skipping the very start/end of the video.
      * fps > 0: sample at ``fps`` frames per second; ``num_frames``
        is ignored and the actual count depends on the video duration.

    Already-cached frames are reused: when every expected file exists,
    no decoding is performed.

    Args:
        video: Video identifier; '<data_root>/<video>.mp4' is read.
        num_frames: Number of frames when sampling by count.
        fps: Sampling rate in frames per second; <= 0 disables fps mode.

    Returns:
        List of paths to the cached frame images.
    """
    if fps > 0:
        vid_path = osp.join(self.data_root, video + '.mp4')
        vid = decord.VideoReader(vid_path)
        # Total frame count and duration of the source video.
        total_frames = len(vid)
        video_fps = vid.get_avg_fps()
        total_duration = total_frames / video_fps
        # Number of samples needed to cover the duration at the requested rate.
        required_frames = int(total_duration * fps)
        # Source-frame stride between consecutive samples.
        step_size = video_fps / fps
        indices = [int(i * step_size) for i in range(required_frames)]
        frame_paths = self.frame_paths_fps(video, len(indices), fps)
        if np.all([osp.exists(p) for p in frame_paths]):
            return frame_paths
    else:
        frame_paths = self.frame_paths(video, num_frames)
        if np.all([osp.exists(p) for p in frame_paths]):
            return frame_paths
        vid_path = osp.join(self.data_root, video + '.mp4')
        vid = decord.VideoReader(vid_path)
        # num_frames interior samples: positions 1..num_frames of
        # (num_frames + 1) equal steps, so the clip edges are skipped.
        step_size = len(vid) / (num_frames + 1)
        indices = [int(i * step_size) for i in range(1, num_frames + 1)]
    # Decode only the selected frames and write any that are missing.
    images = [Image.fromarray(vid[i].asnumpy()) for i in indices]
    for im, pth in zip(images, frame_paths):
        if not osp.exists(pth):
            im.save(pth)
    return frame_paths

# Return a list of dataset names that are supported by this class, can override
@classmethod
Expand Down
27 changes: 19 additions & 8 deletions vlmeval/inference_video.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,18 +18,21 @@ def parse_args():


# Only API model is accepted
def infer_data_api(work_dir, model_name, dataset, nframe=8, pack=False, samples_dict={}, api_nproc=4):
def infer_data_api(work_dir, model_name, dataset, nframe=8, pack=False, samples_dict={}, api_nproc=4, fps=-1):
rank, world_size = get_rank_and_world_size()
assert rank == 0 and world_size == 1
dataset_name = dataset.dataset_name
model = supported_VLM[model_name]() if isinstance(model_name, str) else model_name
assert getattr(model, 'is_api', False)

indices = list(samples_dict.keys())
structs = [dataset.build_prompt(samples_dict[idx], num_frames=nframe) for idx in indices]
structs = [dataset.build_prompt(samples_dict[idx], num_frames=nframe,
video_llm=getattr(model, 'VIDEO_LLM', False), fps=fps) for idx in indices]

packstr = 'pack' if pack else 'nopack'
out_file = f'{work_dir}/{model_name}_{dataset_name}_{nframe}frame_{packstr}_supp.pkl'
if fps > 0:
out_file = out_file.replace('.pkl', f'_fps{fps}.pkl')
res = load(out_file) if osp.exists(out_file) else {}

structs = [s for i, s in zip(indices, structs) if i not in res]
Expand All @@ -45,7 +48,7 @@ def infer_data_api(work_dir, model_name, dataset, nframe=8, pack=False, samples_
return res


def infer_data(model_name, work_dir, dataset, out_file, nframe=8, pack=False, verbose=False, api_nproc=4):
def infer_data(model_name, work_dir, dataset, out_file, nframe=8, pack=False, verbose=False, api_nproc=4, fps=-1):
res = load(out_file) if osp.exists(out_file) else {}
rank, world_size = get_rank_and_world_size()
dataset_name = dataset.dataset_name
Expand All @@ -71,7 +74,8 @@ def infer_data(model_name, work_dir, dataset, out_file, nframe=8, pack=False, ve
nframe=nframe,
pack=pack,
samples_dict={k: sample_map[k] for k in sample_indices_subrem},
api_nproc=api_nproc)
api_nproc=api_nproc,
fps=fps)
for k in sample_indices_subrem:
assert k in supp
res.update(supp)
Expand All @@ -83,7 +87,8 @@ def infer_data(model_name, work_dir, dataset, out_file, nframe=8, pack=False, ve
continue
nframe = getattr(model, 'nframe', 0) if getattr(model, 'nframe', 0) > 0 else nframe
# when using video-llm, build prompt returns video+question; otherwise, several frames+question
struct = dataset.build_prompt(sample_map[idx], num_frames=nframe, video_llm=getattr(model, 'VIDEO_LLM', False))
struct = dataset.build_prompt(sample_map[idx], num_frames=nframe,
video_llm=getattr(model, 'VIDEO_LLM', False), fps=fps)
response = model.generate(message=struct, dataset=dataset_name)
torch.cuda.empty_cache()

Expand All @@ -109,7 +114,8 @@ def infer_data_job_video(
pack=False,
verbose=False,
subtitle=False,
api_nproc=4):
api_nproc=4,
fps=-1):

dataset_name = dataset.dataset_name
packstr = 'pack' if pack else 'nopack'
Expand All @@ -118,7 +124,8 @@ def infer_data_job_video(
if dataset_name == 'Video-MME':
subtitle_str = 'subs' if subtitle else 'nosubs'
result_file = result_file.replace('.xlsx', f'_{subtitle_str}.xlsx')

if fps > 0:
result_file = result_file.replace('.xlsx', f'_fps{fps}.xlsx')
# Dump Predictions to Prev File if result file exists
if osp.exists(result_file):
return model_name
Expand All @@ -127,6 +134,9 @@ def infer_data_job_video(
if dataset_name == 'Video-MME':
subtitle_str = 'subs' if subtitle else 'nosubs'
tmpl = tmpl.replace('.pkl', f'_{subtitle_str}.pkl')

if fps > 0:
tmpl = tmpl.replace('.pkl', f'_fps{fps}.pkl')
out_file = tmpl.format(rank)

model = infer_data(
Expand All @@ -137,7 +147,8 @@ def infer_data_job_video(
pack=pack,
out_file=out_file,
verbose=verbose,
api_nproc=api_nproc)
api_nproc=api_nproc,
fps=fps)

if world_size > 1:
dist.barrier()
Expand Down
Loading