diff --git a/README.md b/README.md
index 32e86665..02cd97a3 100644
--- a/README.md
+++ b/README.md
@@ -259,7 +259,7 @@ Bug finding and pull requests are also highly appreciated to keep this project g
 
 * [ ] Add benchmarking code (TEDLIUM for spd/WER & word segmentation)
 
-* [ ] Allow silero-vad as alternative VAD option
+* [x] Allow silero-vad as alternative VAD option
 
 * [ ] Improve diarization (word level). *Harder than first thought...*
 
@@ -281,7 +281,9 @@ Borrows important alignment code from [PyTorch tutorial on forced alignment](htt
 
 And uses the wonderful pyannote VAD / Diarization https://github.com/pyannote/pyannote-audio
 
-Valuable VAD & Diarization Models from [pyannote audio][https://github.com/pyannote/pyannote-audio]
+Valuable VAD & Diarization Models from:
+- [pyannote audio](https://github.com/pyannote/pyannote-audio)
+- [silero vad](https://github.com/snakers4/silero-vad)
 
 Great backend from [faster-whisper](https://github.com/guillaumekln/faster-whisper) and [CTranslate2](https://github.com/OpenNMT/CTranslate2)
 
diff --git a/whisperx/__init__.py b/whisperx/__init__.py
index 20abaaed..92e6d424 100644
--- a/whisperx/__init__.py
+++ b/whisperx/__init__.py
@@ -1,4 +1,4 @@
-from .transcribe import load_model
 from .alignment import load_align_model, align
 from .audio import load_audio
-from .diarize import assign_word_speakers, DiarizationPipeline
\ No newline at end of file
+from .diarize import assign_word_speakers, DiarizationPipeline
+from .asr import load_model
diff --git a/whisperx/asr.py b/whisperx/asr.py
index 0ccaf92b..4cec0f70 100644
--- a/whisperx/asr.py
+++ b/whisperx/asr.py
@@ -1,5 +1,4 @@
 import os
-import warnings
 from typing import List, Union, Optional, NamedTuple
 
 import ctranslate2
@@ -10,7 +9,7 @@ from transformers.pipelines.pt_utils import PipelineIterator
 
 from .audio import N_SAMPLES, SAMPLE_RATE, load_audio, log_mel_spectrogram
-from .vad import load_vad_model, merge_chunks
+import whisperx.vads
 from .types import TranscriptionResult, SingleSegment
 
 def find_numeral_symbol_tokens(tokenizer):
@@ -183,7 +182,16 @@ def data(audio, segments):
                 # print(f2-f1)
                 yield {'inputs': audio[f1:f2]}
 
-        vad_segments = self.vad_model({"waveform": torch.from_numpy(audio).unsqueeze(0), "sample_rate": SAMPLE_RATE})
+        # Pre-process audio and merge chunks as defined by the respective VAD child class
+        # In case vad_model is manually assigned (see 'load_model'), follow the functionality of the pyannote toolkit
+        if issubclass(type(self.vad_model), whisperx.vads.Vad):
+            waveform = self.vad_model.preprocess_audio(audio)
+            merge_chunks = self.vad_model.merge_chunks
+        else:
+            waveform = whisperx.vads.Pyannote.preprocess_audio(audio)
+            merge_chunks = whisperx.vads.Pyannote.merge_chunks
+
+        vad_segments = self.vad_model({"waveform": waveform, "sample_rate": SAMPLE_RATE})
         vad_segments = merge_chunks(
             vad_segments,
             chunk_size,
@@ -263,6 +271,7 @@ def load_model(whisper_arch,
               asr_options=None,
               language : Optional[str] = None,
               vad_model=None,
+              vad_method=None,
               vad_options=None,
               model : Optional[WhisperModel] = None,
               task="transcribe",
@@ -273,6 +282,7 @@ def load_model(whisper_arch,
        whisper_arch: str - The name of the Whisper model to load.
        device: str - The device to load the model on.
        compute_type: str - The compute type to use for the model.
+        vad_method: str - The VAD method to use. vad_model has higher priority if it is not None.
        options: dict - A dictionary of options to use for the model.
        language: str - The language of the model.
(use English for now) model: Optional[WhisperModel] - The WhisperModel instance to use. @@ -334,6 +344,7 @@ def load_model(whisper_arch, default_asr_options = faster_whisper.transcribe.TranscriptionOptions(**default_asr_options) default_vad_options = { + "chunk_size": 30, # needed by silero since binarization happens before merge_chunks "vad_onset": 0.500, "vad_offset": 0.363 } @@ -341,10 +352,16 @@ def load_model(whisper_arch, if vad_options is not None: default_vad_options.update(vad_options) + # Note: manually assigned vad_model has higher priority than vad_method! if vad_model is not None: + print("Use manually assigned vad_model. vad_method is ignored.") vad_model = vad_model else: - vad_model = load_vad_model(torch.device(device), use_auth_token=None, **default_vad_options) + match vad_method: + case "silero": + vad_model = whisperx.vads.Silero(**default_vad_options) + case "pyannote" | _: + vad_model = whisperx.vads.Pyannote(torch.device(device), use_auth_token=None, **default_vad_options) return FasterWhisperPipeline( model=model, diff --git a/whisperx/transcribe.py b/whisperx/transcribe.py index edd27648..de008dfa 100644 --- a/whisperx/transcribe.py +++ b/whisperx/transcribe.py @@ -39,6 +39,7 @@ def cli(): parser.add_argument("--return_char_alignments", action='store_true', help="Return character-level alignments in the output json file") # vad params + parser.add_argument("--vad_method", type=str, default="pyannote", choices=["pyannote", "silero"], help="VAD method to be used") parser.add_argument("--vad_onset", type=float, default=0.500, help="Onset threshold for VAD (see pyannote.audio), reduce this if speech is not being detected") parser.add_argument("--vad_offset", type=float, default=0.363, help="Offset threshold for VAD (see pyannote.audio), reduce this if speech is not being detected.") parser.add_argument("--chunk_size", type=int, default=30, help="Chunk size for merging VAD segments. 
Default is 30, reduce this if the chunk is too long.") @@ -102,6 +103,7 @@ def cli(): return_char_alignments: bool = args.pop("return_char_alignments") hf_token: str = args.pop("hf_token") + vad_method: str = args.pop("vad_method") vad_onset: float = args.pop("vad_onset") vad_offset: float = args.pop("vad_offset") @@ -167,7 +169,7 @@ def cli(): results = [] tmp_results = [] # model = load_model(model_name, device=device, download_root=model_dir) - model = load_model(model_name, device=device, device_index=device_index, download_root=model_dir, compute_type=compute_type, language=args['language'], asr_options=asr_options, vad_options={"vad_onset": vad_onset, "vad_offset": vad_offset}, task=task, threads=faster_whisper_threads) + model = load_model(model_name, device=device, device_index=device_index, download_root=model_dir, compute_type=compute_type, language=args['language'], asr_options=asr_options, vad_method=vad_method, vad_options={"chunk_size":chunk_size, "vad_onset": vad_onset, "vad_offset": vad_offset}, task=task, threads=faster_whisper_threads) for audio_path in args.pop("audio"): audio = load_audio(audio_path) diff --git a/whisperx/vads/__init__.py b/whisperx/vads/__init__.py new file mode 100644 index 00000000..9dd82bf7 --- /dev/null +++ b/whisperx/vads/__init__.py @@ -0,0 +1,3 @@ +from whisperx.vads.pyannote import Pyannote +from whisperx.vads.silero import Silero +from whisperx.vads.vad import Vad \ No newline at end of file diff --git a/whisperx/vad.py b/whisperx/vads/pyannote.py similarity index 59% rename from whisperx/vad.py rename to whisperx/vads/pyannote.py index ab2c7bbf..9f20dfd5 100644 --- a/whisperx/vad.py +++ b/whisperx/vads/pyannote.py @@ -1,62 +1,24 @@ import hashlib import os import urllib -from typing import Callable, Optional, Text, Union +from typing import Callable, Text, Union +from typing import Optional import numpy as np -import pandas as pd import torch from pyannote.audio import Model from pyannote.audio.core.io import AudioFile from pyannote.audio.pipelines import VoiceActivityDetection from pyannote.audio.pipelines.utils import PipelineModel -from pyannote.core import Annotation, Segment, SlidingWindowFeature +from pyannote.core import Annotation, SlidingWindowFeature +from pyannote.core import Segment from tqdm import tqdm -from .diarize import Segment as SegmentX +from whisperx.diarize import Segment as SegmentX +from whisperx.vads.vad import Vad VAD_SEGMENTATION_URL = "https://whisperx.s3.eu-west-2.amazonaws.com/model_weights/segmentation/0b5b3216d60a2d32fc086b47ea8c67589aaeb26b7e07fcbe620d6d0b83e209ea/pytorch_model.bin" -def load_vad_model(device, vad_onset=0.500, vad_offset=0.363, use_auth_token=None, model_fp=None): - model_dir = torch.hub._get_torch_home() - os.makedirs(model_dir, exist_ok = True) - if model_fp is None: - model_fp = os.path.join(model_dir, "whisperx-vad-segmentation.bin") - if os.path.exists(model_fp) and not os.path.isfile(model_fp): - raise RuntimeError(f"{model_fp} exists and is not a regular file") - - if not os.path.isfile(model_fp): - with urllib.request.urlopen(VAD_SEGMENTATION_URL) as source, open(model_fp, "wb") as output: - with tqdm( - total=int(source.info().get("Content-Length")), - ncols=80, - unit="iB", - unit_scale=True, - unit_divisor=1024, - ) as loop: - while True: - buffer = source.read(8192) - if not buffer: - break - - output.write(buffer) - loop.update(len(buffer)) - - model_bytes = open(model_fp, "rb").read() - if hashlib.sha256(model_bytes).hexdigest() != VAD_SEGMENTATION_URL.split('/')[-2]: - raise 
RuntimeError( - "Model has been downloaded but the SHA256 checksum does not not match. Please retry loading the model." - ) - - vad_model = Model.from_pretrained(model_fp, use_auth_token=use_auth_token) - hyperparameters = {"onset": vad_onset, - "offset": vad_offset, - "min_duration_on": 0.1, - "min_duration_off": 0.1} - vad_pipeline = VoiceActivitySegmentation(segmentation=vad_model, device=torch.device(device)) - vad_pipeline.instantiate(hyperparameters) - - return vad_pipeline class Binarize: """Binarize detection scores using hysteresis thresholding, with min-cut operation @@ -85,21 +47,21 @@ class Binarize: Gregory Gelly and Jean-Luc Gauvain. "Minimum Word Error Training of RNN-based Voice Activity Detection", InterSpeech 2015. - Modified by Max Bain to include WhisperX's min-cut operation + Modified by Max Bain to include WhisperX's min-cut operation https://arxiv.org/abs/2303.00747 - + Pyannote-audio """ def __init__( - self, - onset: float = 0.5, - offset: Optional[float] = None, - min_duration_on: float = 0.0, - min_duration_off: float = 0.0, - pad_onset: float = 0.0, - pad_offset: float = 0.0, - max_duration: float = float('inf') + self, + onset: float = 0.5, + offset: Optional[float] = None, + min_duration_on: float = 0.0, + min_duration_off: float = 0.0, + pad_onset: float = 0.0, + pad_offset: float = 0.0, + max_duration: float = float('inf') ): super().__init__() @@ -145,7 +107,7 @@ def __call__(self, scores: SlidingWindowFeature) -> Annotation: t = start for t, y in zip(timestamps[1:], k_scores[1:]): # currently active - if is_active: + if is_active: curr_duration = t - start if curr_duration > self.max_duration: search_after = len(curr_scores) // 2 @@ -155,8 +117,8 @@ def __call__(self, scores: SlidingWindowFeature) -> Annotation: region = Segment(start - self.pad_onset, min_score_t + self.pad_offset) active[region, k] = label start = curr_timestamps[min_score_div_idx] - curr_scores = curr_scores[min_score_div_idx+1:] - curr_timestamps = curr_timestamps[min_score_div_idx+1:] + curr_scores = curr_scores[min_score_div_idx + 1:] + curr_timestamps = curr_timestamps[min_score_div_idx + 1:] # switching from active to inactive elif y < self.offset: region = Segment(start - self.pad_onset, t + self.pad_offset) @@ -197,11 +159,11 @@ def __call__(self, scores: SlidingWindowFeature) -> Annotation: class VoiceActivitySegmentation(VoiceActivityDetection): def __init__( - self, - segmentation: PipelineModel = "pyannote/segmentation", - fscore: bool = False, - use_auth_token: Union[Text, None] = None, - **inference_kwargs, + self, + segmentation: PipelineModel = "pyannote/segmentation", + fscore: bool = False, + use_auth_token: Union[Text, None] = None, + **inference_kwargs, ): super().__init__(segmentation=segmentation, fscore=fscore, use_auth_token=use_auth_token, **inference_kwargs) @@ -240,72 +202,72 @@ def apply(self, file: AudioFile, hook: Optional[Callable] = None) -> Annotation: return segmentations -def merge_vad(vad_arr, pad_onset=0.0, pad_offset=0.0, min_duration_off=0.0, min_duration_on=0.0): - - active = Annotation() - for k, vad_t in enumerate(vad_arr): - region = Segment(vad_t[0] - pad_onset, vad_t[1] + pad_offset) - active[region, k] = 1 - - - if pad_offset > 0.0 or pad_onset > 0.0 or min_duration_off > 0.0: - active = active.support(collar=min_duration_off) - - # remove tracks shorter than min_duration_on - if min_duration_on > 0: - for segment, track in list(active.itertracks()): - if segment.duration < min_duration_on: - del active[segment, track] - - active = 
active.for_json() - active_segs = pd.DataFrame([x['segment'] for x in active['content']]) - return active_segs - -def merge_chunks( - segments, - chunk_size, - onset: float = 0.5, - offset: Optional[float] = None, -): - """ - Merge operation described in paper - """ - curr_end = 0 - merged_segments = [] - seg_idxs = [] - speaker_idxs = [] - - assert chunk_size > 0 - binarize = Binarize(max_duration=chunk_size, onset=onset, offset=offset) - segments = binarize(segments) - segments_list = [] - for speech_turn in segments.get_timeline(): - segments_list.append(SegmentX(speech_turn.start, speech_turn.end, "UNKNOWN")) - - if len(segments_list) == 0: - print("No active speech found in audio") - return [] - # assert segments_list, "segments_list is empty." - # Make sur the starting point is the start of the segment. - curr_start = segments_list[0].start - - for seg in segments_list: - if seg.end - curr_start > chunk_size and curr_end-curr_start > 0: - merged_segments.append({ - "start": curr_start, - "end": curr_end, - "segments": seg_idxs, - }) - curr_start = seg.start - seg_idxs = [] - speaker_idxs = [] - curr_end = seg.end - seg_idxs.append((seg.start, seg.end)) - speaker_idxs.append(seg.speaker) - # add final - merged_segments.append({ - "start": curr_start, - "end": curr_end, - "segments": seg_idxs, - }) - return merged_segments +class Pyannote(Vad): + + def __init__(self, device, use_auth_token=None, model_fp=None, **kwargs): + print(">>Performing voice activity detection using Pyannote...") + super().__init__(kwargs['vad_onset']) + + model_dir = torch.hub._get_torch_home() + os.makedirs(model_dir, exist_ok=True) + if model_fp is None: + model_fp = os.path.join(model_dir, "whisperx-vad-segmentation.bin") + if os.path.exists(model_fp) and not os.path.isfile(model_fp): + raise RuntimeError(f"{model_fp} exists and is not a regular file") + + if not os.path.isfile(model_fp): + with urllib.request.urlopen(VAD_SEGMENTATION_URL) as source, open(model_fp, "wb") as output: + with tqdm( + total=int(source.info().get("Content-Length")), + ncols=80, + unit="iB", + unit_scale=True, + unit_divisor=1024, + ) as loop: + while True: + buffer = source.read(8192) + if not buffer: + break + + output.write(buffer) + loop.update(len(buffer)) + + model_bytes = open(model_fp, "rb").read() + if hashlib.sha256(model_bytes).hexdigest() != VAD_SEGMENTATION_URL.split('/')[-2]: + raise RuntimeError( + "Model has been downloaded but the SHA256 checksum does not not match. Please retry loading the model." 
+ ) + + vad_model = Model.from_pretrained(model_fp, use_auth_token=use_auth_token) + hyperparameters = {"onset": kwargs['vad_onset'], + "offset": kwargs['vad_offset'], + "min_duration_on": 0.1, + "min_duration_off": 0.1} + self.vad_pipeline = VoiceActivitySegmentation(segmentation=vad_model, device=torch.device(device)) + self.vad_pipeline.instantiate(hyperparameters) + + def __call__(self, audio: AudioFile, **kwargs): + return self.vad_pipeline(audio) + + @staticmethod + def preprocess_audio(audio): + return torch.from_numpy(audio).unsqueeze(0) + + @staticmethod + def merge_chunks(segments, + chunk_size, + onset: float = 0.5, + offset: Optional[float] = None, + ): + assert chunk_size > 0 + binarize = Binarize(max_duration=chunk_size, onset=onset, offset=offset) + segments = binarize(segments) + segments_list = [] + for speech_turn in segments.get_timeline(): + segments_list.append(SegmentX(speech_turn.start, speech_turn.end, "UNKNOWN")) + + if len(segments_list) == 0: + print("No active speech found in audio") + return [] + assert segments_list, "segments_list is empty." + return Vad.merge_chunks(segments_list, chunk_size, onset, offset) diff --git a/whisperx/vads/silero.py b/whisperx/vads/silero.py new file mode 100644 index 00000000..e7b44cc4 --- /dev/null +++ b/whisperx/vads/silero.py @@ -0,0 +1,62 @@ +from io import IOBase +from pathlib import Path +from typing import Mapping, Text +from typing import Optional +from typing import Union + +import torch + +from whisperx.diarize import Segment as SegmentX +from whisperx.vads.vad import Vad + +AudioFile = Union[Text, Path, IOBase, Mapping] + + +class Silero(Vad): + # check again default values + def __init__(self, **kwargs): + print(">>Performing voice activity detection using Silero...") + super().__init__(kwargs['vad_onset']) + + self.vad_onset = kwargs['vad_onset'] + self.chunk_size = kwargs['chunk_size'] + self.vad_pipeline, vad_utils = torch.hub.load(repo_or_dir='snakers4/silero-vad', + model='silero_vad', + force_reload=False, + onnx=False, + trust_repo=True) + (self.get_speech_timestamps, _, self.read_audio, _, _) = vad_utils + + def __call__(self, audio: AudioFile, **kwargs): + """use silero to get segments of speech""" + # Only accept 16000 Hz for now. + # Note: Silero models support both 8000 and 16000 Hz. Although other values are not directly supported, + # multiples of 16000 (e.g. 32000 or 48000) are cast to 16000 inside of the JIT model! + sample_rate = audio["sample_rate"] + if sample_rate != 16000: + raise ValueError("Only 16000Hz sample rate is allowed") + + timestamps = self.get_speech_timestamps(audio["waveform"], + model=self.vad_pipeline, + sampling_rate=sample_rate, + max_speech_duration_s=self.chunk_size, + threshold=self.vad_onset + # min_silence_duration_ms = self.min_duration_off/1000 + # min_speech_duration_ms = self.min_duration_on/1000 + # ... 
+ # See silero documentation for full option list + ) + return [SegmentX(i['start'] / sample_rate, i['end'] / sample_rate, "UNKNOWN") for i in timestamps] + + @staticmethod + def preprocess_audio(audio): + return audio + + @staticmethod + def merge_chunks(segments, + chunk_size, + onset: float = 0.5, + offset: Optional[float] = None, + ): + assert chunk_size > 0 + return Vad.merge_chunks(segments, chunk_size, onset, offset) diff --git a/whisperx/vads/vad.py b/whisperx/vads/vad.py new file mode 100644 index 00000000..d96184c5 --- /dev/null +++ b/whisperx/vads/vad.py @@ -0,0 +1,74 @@ +from typing import Optional + +import pandas as pd +from pyannote.core import Annotation, Segment + + +class Vad: + def __init__(self, vad_onset): + if not (0 < vad_onset < 1): + raise ValueError( + "vad_onset is a decimal value between 0 and 1." + ) + + @staticmethod + def preprocess_audio(audio): + pass + + # keep merge_chunks as static so it can be also used by manually assigned vad_model (see 'load_model') + @staticmethod + def merge_chunks(segments, + chunk_size, + onset: float, + offset: Optional[float]): + """ + Merge operation described in paper + """ + curr_end = 0 + merged_segments = [] + seg_idxs = [] + speaker_idxs = [] + + curr_start = segments[0].start + for seg in segments: + if seg.end - curr_start > chunk_size and curr_end - curr_start > 0: + merged_segments.append({ + "start": curr_start, + "end": curr_end, + "segments": seg_idxs, + }) + curr_start = seg.start + seg_idxs = [] + speaker_idxs = [] + curr_end = seg.end + seg_idxs.append((seg.start, seg.end)) + speaker_idxs.append(seg.speaker) + # add final + merged_segments.append({ + "start": curr_start, + "end": curr_end, + "segments": seg_idxs, + }) + + return merged_segments + + # Unused function + @staticmethod + def merge_vad(vad_arr, pad_onset=0.0, pad_offset=0.0, min_duration_off=0.0, min_duration_on=0.0): + active = Annotation() + for k, vad_t in enumerate(vad_arr): + region = Segment(vad_t[0] - pad_onset, vad_t[1] + pad_offset) + active[region, k] = 1 + + if pad_offset > 0.0 or pad_onset > 0.0 or min_duration_off > 0.0: + active = active.support(collar=min_duration_off) + + # remove tracks shorter than min_duration_on + if min_duration_on > 0: + for segment, track in list(active.itertracks()): + if segment.duration < min_duration_on: + del active[segment, track] + + active = active.for_json() + active_segs = pd.DataFrame([x['segment'] for x in active['content']]) + return active_segs
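Usage sketch (not part of the diff): a minimal example of the vad_method option introduced above, based on the
load_model signature and the --vad_method CLI flag added in this change. The model name, device, compute_type
value, audio path, and the model.transcribe(...) / `whisperx` command-line entry points are assumed from the
existing WhisperX interface and are not defined in this diff.

    import whisperx

    audio = whisperx.load_audio("audio.wav")   # hypothetical input file

    # vad_method defaults to "pyannote"; "silero" selects the torch.hub-based Silero VAD.
    # A manually passed vad_model still takes priority and makes vad_method ignored.
    model = whisperx.load_model(
        "small",                               # assumed model name
        device="cpu",                          # assumed device
        compute_type="int8",                   # assumed compute type for CPU
        vad_method="silero",
        vad_options={"chunk_size": 30, "vad_onset": 0.500, "vad_offset": 0.363},
    )
    # Note: the Silero backend only uses vad_onset and chunk_size;
    # vad_offset applies to the pyannote backend.
    result = model.transcribe(audio)           # assumed existing transcribe API

    # Rough CLI equivalent (assuming the existing `whisperx` entry point):
    #   whisperx audio.wav --vad_method silero --chunk_size 30 --vad_onset 0.5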