Draft Bark finetuning #2846

Closed
wants to merge 9 commits
71 changes: 67 additions & 4 deletions TTS/tts/configs/bark_config.py
@@ -1,6 +1,6 @@
import os
from dataclasses import dataclass, field
from typing import Dict
from typing import Dict, List

from TTS.tts.configs.shared_configs import BaseTTSConfig
from TTS.tts.layers.bark.model import GPTConfig
@@ -48,9 +48,46 @@ class BarkConfig(BaseTTSConfig):
model: str = "bark"
audio: BarkAudioConfig = field(default_factory=BarkAudioConfig)
num_chars: int = 0
semantic_config: GPTConfig = field(default_factory=GPTConfig)
fine_config: FineGPTConfig = field(default_factory=FineGPTConfig)
coarse_config: GPTConfig = field(default_factory=GPTConfig)

semantic_gpt_config: GPTConfig = field(
default_factory=lambda: GPTConfig(
block_size=1024,
input_vocab_size=129600,
output_vocab_size=10048,
n_layer=24,
n_head=16,
n_embd=1024,
dropout=0.0,
bias=False,
)
)
coarse_gpt_config: GPTConfig = field(
default_factory=lambda: GPTConfig(
block_size=1024,
input_vocab_size=12096,
output_vocab_size=12096,
n_layer=24,
n_head=16,
n_embd=1024,
dropout=0.0,
bias=False,
)
)
fine_gpt_config: FineGPTConfig = field(
default_factory=lambda: FineGPTConfig(
block_size=1024,
input_vocab_size=1056,
output_vocab_size=1056,
n_layer=24,
n_head=16,
n_embd=1024,
dropout=0.0,
bias=False,
n_codes_total=8,
n_codes_given=1,
)
)

CONTEXT_WINDOW_SIZE: int = 1024
SEMANTIC_RATE_HZ: float = 49.9
SEMANTIC_VOCAB_SIZE: int = 10_000
@@ -75,6 +112,32 @@ class BarkConfig(BaseTTSConfig):
CACHE_DIR: str = str(get_user_data_dir("tts/suno/bark_v0"))
DEF_SPEAKER_DIR: str = str(get_user_data_dir("tts/bark_v0/speakers"))

# training parameters
training_mode: str = None
batch_size: int = 32
num_workers: int = 4
epochs: int = 1000

# data parameters
train_semantic_data_settings: dict = field(
default_factory=lambda: {"max_semantic_tokens_len": 511, "max_text_tokens_len": 256}
)
train_coarse_data_settings: dict = field(
default_factory=lambda: {"max_semantic_tokens_len": 256, "max_coarse_tokens_len": 768}
)
train_fine_data_settings: dict = field(
default_factory=lambda: {"max_semantic_tokens_len": 256, "max_fine_tokens_len": 512}
)

# optimizer
grad_clip: float = 1000
lr: float = 0.0002
lr_scheduler: str = "ExponentialLR"
lr_scheduler_params: dict = field(default_factory=lambda: {"gamma": 0.999875, "last_epoch": -1})
scheduler_after_epoch: bool = True
optimizer: str = "AdamW"
optimizer_params: dict = field(default_factory=lambda: {"betas": [0.8, 0.99], "eps": 1e-9, "weight_decay": 0.01})

def __post_init__(self):
self.REMOTE_MODEL_PATHS = {
"text": {
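For orientation, a minimal sketch of how the new training fields above might be exercised; the "semantic" value for training_mode and the concrete numbers are illustrative assumptions, not something this config file enforces.

from TTS.tts.configs.bark_config import BarkConfig

# Hypothetical fine-tuning config; field names come from the diff above,
# the values are placeholders for illustration only.
config = BarkConfig(
    training_mode="semantic",  # assumed selector for which GPT stage is trained
    batch_size=8,
    epochs=100,
    lr=1e-4,
)
# The per-stage GPT defaults and optimizer settings can be overridden afterwards, e.g.:
config.semantic_gpt_config.dropout = 0.1
config.lr_scheduler_params["gamma"] = 0.9999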
16 changes: 7 additions & 9 deletions TTS/tts/layers/bark/hubert/kmeans_hubert.py
@@ -47,12 +47,11 @@ def __init__(self, checkpoint_path, target_sample_hz=16000, seq_len_multiple_of=
self.target_sample_hz = target_sample_hz
self.seq_len_multiple_of = seq_len_multiple_of
self.output_layer = output_layer
if device is not None:
self.to(device)

self.model = HubertModel.from_pretrained("facebook/hubert-base-ls960")
if device is not None:
self.model.to(device)
self.model.eval()
if device is not None:
self.to(device)

@property
def groups(self):
@@ -73,10 +72,9 @@ def forward(self, wav_input, flatten=True, input_sample_hz=None):
output_hidden_states=True,
)
embed = outputs["hidden_states"][self.output_layer]
embed, packed_shape = pack([embed], "* d")
codebook_indices = torch.from_numpy(embed.cpu().detach().numpy()).to(device)
if flatten:
return codebook_indices
embed, packed_shape = pack([embed], "* d")
embed = torch.from_numpy(embed.cpu().detach().numpy()).to(device)
return embed

(codebook_indices,) = unpack(codebook_indices, packed_shape, "*")
return codebook_indices
return embed
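For orientation, a rough sketch of how the reworked forward pass gets called elsewhere in this PR (the tokenizer class added to inference_funcs.py drives it with flatten=False); the checkpoint and audio paths are placeholders.

import torchaudio
from TTS.tts.layers.bark.hubert.kmeans_hubert import CustomHubert

hubert = CustomHubert(checkpoint_path="/path/to/hubert.pt")  # placeholder path
wav, sr = torchaudio.load("speaker_sample.wav")              # placeholder file
# flatten=False keeps per-frame HuBERT features for the downstream discrete tokenizer
features = hubert.forward(wav, flatten=False, input_sample_hz=sr)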
96 changes: 64 additions & 32 deletions TTS/tts/layers/bark/inference_funcs.py
@@ -11,6 +11,7 @@
import tqdm
from encodec.utils import convert_audio
from scipy.special import softmax
from torch import nn
from torch.nn import functional as F

from TTS.tts.layers.bark.hubert.hubert_manager import HubertManager
@@ -102,51 +103,82 @@ def compute_average_bass_energy(audio_data, sample_rate, max_bass_freq=250):
return bass_energy


class BarkHubertAudioTokenizer(nn.Module):
def __init__(self, config, lazy_load=True) -> None:
super().__init__()
self.__device_param = nn.Parameter(torch.empty(0))
self.config = config
self.lazy_load = lazy_load
if not lazy_load:
self.load_hubert(config, self.device)

@property
def device(self):
return self.__device_param.device

def load_hubert(self, config, device):
hubert_manager = HubertManager()
hubert_manager.make_sure_tokenizer_installed(model_path=self.config.LOCAL_MODEL_PATHS["hubert_tokenizer"])
self.hubert_model = CustomHubert(checkpoint_path=self.config.LOCAL_MODEL_PATHS["hubert"]).to(device)
self.tokenizer = HubertTokenizer.load_from_checkpoint(
config.LOCAL_MODEL_PATHS["hubert_tokenizer"], map_location=device
)

def encode(self, audio, device):
"""Encode an audio file into a sequence of tokens.

Args:
audio (str or Tensor): The audio to encode. In shape (B, T).
device (str): The device to use for encoding.

Returns:
Tensor: The encoded tokens.
"""
if isinstance(audio, str):
audio, sr = torchaudio.load(audio)
audio = convert_audio(audio, sr, self.config.sample_rate, 1)
audio = audio.to(device)

if self.lazy_load and not hasattr(self, "hubert_model"):
self.load_hubert(self.config, self.device)

semantic_vectors = self.hubert_model.forward(audio, flatten=False, input_sample_hz=self.config.sample_rate)
semantic_tokens = self.tokenizer.get_token(semantic_vectors)
return semantic_tokens


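A minimal usage sketch for the class above; the config object is assumed to be a BarkConfig with its LOCAL_MODEL_PATHS populated, and the audio path is a placeholder.

# Hypothetical call; with lazy_load=True the HuBERT weights are pulled in on the first encode().
audio_tokenizer = BarkHubertAudioTokenizer(config, lazy_load=True)
semantic_tokens = audio_tokenizer.encode("speaker_sample.wav", device=audio_tokenizer.device)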
def generate_voice(
audio,
model,
output_path,
output_path=None,
):
"""Generate a new voice from a given audio and text prompt.
"""Generate a new voice from a given audioZ.

Args:
audio (np.ndarray): The audio to use as a base for the new voice.
text (str): Transcription of the audio you are cloning.
model (BarkModel): The BarkModel to use for generating the new voice.
output_path (str): The path to save the generated voice to.
output_path (str): The path to save the generated voice to. If None, return computed tokens.
"""
if isinstance(audio, str):
audio, sr = torchaudio.load(audio)
audio = convert_audio(audio, sr, model.config.sample_rate, model.encodec.channels)
audio = audio.unsqueeze(0).to(model.device)

with torch.no_grad():
encoded_frames = model.encodec.encode(audio)
codes = torch.cat([encoded[0] for encoded in encoded_frames], dim=-1).squeeze() # [n_q, T]

# move codes to cpu
codes = codes.cpu().numpy()
# Coarse and fine tokens
fine_tokens, coarse_tokens = model.generate_coarse_fine_tokens(audio)
fine_tokens = fine_tokens.cpu().numpy()
coarse_tokens = coarse_tokens.cpu().numpy()

# generate semantic tokens
# Load the HuBERT model
hubert_manager = HubertManager()
# hubert_manager.make_sure_hubert_installed(model_path=model.config.LOCAL_MODEL_PATHS["hubert"])
hubert_manager.make_sure_tokenizer_installed(model_path=model.config.LOCAL_MODEL_PATHS["hubert_tokenizer"])
# Semantic tokens
semantic_tokens = model.generate_semantic_tokens(audio).cpu().numpy()

hubert_model = CustomHubert(checkpoint_path=model.config.LOCAL_MODEL_PATHS["hubert"]).to(model.device)

# Load the CustomTokenizer model
tokenizer = HubertTokenizer.load_from_checkpoint(
model.config.LOCAL_MODEL_PATHS["hubert_tokenizer"], map_location=model.device
)
# semantic_tokens = model.text_to_semantic(
# text, max_gen_duration_s=seconds, top_k=50, top_p=0.95, temp=0.7
# ) # not 100%
semantic_vectors = hubert_model.forward(audio[0], input_sample_hz=model.config.sample_rate)
semantic_tokens = tokenizer.get_token(semantic_vectors)
semantic_tokens = semantic_tokens.cpu().numpy()

np.savez(output_path, fine_prompt=codes, coarse_prompt=codes[:2, :], semantic_prompt=semantic_tokens)
if output_path is not None:
np.savez(
output_path, fine_prompt=fine_tokens, coarse_prompt=coarse_tokens[:2, :], semantic_prompt=semantic_tokens
)
else:
return {"fine_prompt": fine_tokens, "coarse_prompt": coarse_tokens, "semantic_prompt": semantic_tokens}


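A short example of the two ways the refactored generate_voice can now be used; model is assumed to be an already-initialized Bark instance with its encodec and GPT checkpoints loaded, and the paths are placeholders.

# Return the prompts in memory:
prompts = generate_voice("speaker_sample.wav", model, output_path=None)
fine, coarse, semantic = prompts["fine_prompt"], prompts["coarse_prompt"], prompts["semantic_prompt"]

# Or write a reusable .npz voice prompt to disk:
generate_voice("speaker_sample.wav", model, output_path="bark_voices/new_speaker.npz")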
def generate_text_semantic(
@@ -162,7 +194,7 @@ def generate_text_semantic(
allow_early_stop=True,
base=None,
use_kv_caching=True,
**kwargs, # pylint: disable=unused-argument
**kwargs, # pylint: disable=unused-argument
):
"""Generate semantic tokens from text.

@@ -242,7 +274,7 @@ def generate_text_semantic(
x_input = x[:, [-1]]
else:
x_input = x
logits, kv_cache = model.semantic_model(
logits, kv_cache = model.semantic_model.inference(
x_input, merge_context=True, use_cache=use_kv_caching, past_kv=kv_cache
)
relevant_logits = logits[0, 0, : model.config.SEMANTIC_VOCAB_SIZE]
@@ -296,7 +328,7 @@ def generate_text_semantic(


def _flatten_codebooks(arr, offset_size):
assert len(arr.shape) == 2
assert len(arr.shape) == 2, f" ❗ Codebooks must be 2D, got {len(arr.shape)}D"
arr = arr.copy()
if offset_size is not None:
for n in range(1, arr.shape[0]):
32 changes: 31 additions & 1 deletion TTS/tts/layers/bark/model.py
@@ -175,7 +175,37 @@ def get_num_params(self, non_embedding=True):
n_params -= self.transformer.wpe.weight.numel()
return n_params

def forward(self, idx, merge_context=False, past_kv=None, position_ids=None, use_cache=False):
def forward(self, idx):
device = idx.device
_, t = idx.size()
assert (
t <= self.config.block_size
), f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}"

# forward the GPT model itself
tok_emb = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd)

past_length = 0
past_kv = tuple([None] * len(self.transformer.h))

position_ids = torch.arange(past_length, t + past_length, dtype=torch.long, device=device)
position_ids = position_ids.unsqueeze(0) # shape (1, t)
assert position_ids.shape == (1, t)
pos_emb = self.transformer.wpe(position_ids) # position embeddings of shape (1, t, n_embd)

x = self.transformer.drop(tok_emb + pos_emb)

for _, (block, past_layer_kv) in enumerate(zip(self.transformer.h, past_kv)):
x, kv = block(x, past_kv=past_layer_kv, use_cache=False)

x = self.transformer.ln_f(x)

# training forward: compute logits over every position so a loss can be taken on the full sequence
logits = self.lm_head(x)

return logits

def inference(self, idx, merge_context=False, past_kv=None, position_ids=None, use_cache=False):
device = idx.device
_, t = idx.size()
if past_kv is not None:
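To make the new split concrete, a small sketch of how the two entry points are intended to be used; gpt, token_ids, next_token, targets, and the loss shown are illustrative placeholders, not code from this diff.

import torch
from torch.nn import functional as F

# Training: forward() returns logits for every position, suitable for a cross-entropy loss.
logits = gpt.forward(token_ids)  # (B, T, vocab)
loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), targets.reshape(-1))

# Generation: inference() keeps the original cached decoding path.
logits, kv_cache = gpt.inference(token_ids, use_cache=True)                      # prime the cache
logits, kv_cache = gpt.inference(next_token, past_kv=kv_cache, use_cache=True)   # one token at a time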