diff --git a/TTS/tts/configs/bark_config.py b/TTS/tts/configs/bark_config.py
index 4d1cd1374a..d4849c5d8c 100644
--- a/TTS/tts/configs/bark_config.py
+++ b/TTS/tts/configs/bark_config.py
@@ -1,6 +1,6 @@
 import os
 from dataclasses import dataclass, field
-from typing import Dict
+from typing import Dict, List

 from TTS.tts.configs.shared_configs import BaseTTSConfig
 from TTS.tts.layers.bark.model import GPTConfig
@@ -48,9 +48,46 @@ class BarkConfig(BaseTTSConfig):
     model: str = "bark"
     audio: BarkAudioConfig = field(default_factory=BarkAudioConfig)
     num_chars: int = 0
-    semantic_config: GPTConfig = field(default_factory=GPTConfig)
-    fine_config: FineGPTConfig = field(default_factory=FineGPTConfig)
-    coarse_config: GPTConfig = field(default_factory=GPTConfig)
+
+    semantic_gpt_config: GPTConfig = field(
+        default_factory=lambda: GPTConfig(
+            block_size=1024,
+            input_vocab_size=129600,
+            output_vocab_size=10048,
+            n_layer=24,
+            n_head=16,
+            n_embd=1024,
+            dropout=0.0,
+            bias=False,
+        )
+    )
+    coarse_gpt_config: GPTConfig = field(
+        default_factory=lambda: GPTConfig(
+            block_size=1024,
+            input_vocab_size=12096,
+            output_vocab_size=12096,
+            n_layer=24,
+            n_head=16,
+            n_embd=1024,
+            dropout=0.0,
+            bias=False,
+        )
+    )
+    fine_gpt_config: FineGPTConfig = field(
+        default_factory=lambda: FineGPTConfig(
+            block_size=1024,
+            input_vocab_size=1056,
+            output_vocab_size=1056,
+            n_layer=24,
+            n_head=16,
+            n_embd=1024,
+            dropout=0.0,
+            bias=False,
+            n_codes_total=8,
+            n_codes_given=1,
+        )
+    )
+
     CONTEXT_WINDOW_SIZE: int = 1024
     SEMANTIC_RATE_HZ: float = 49.9
     SEMANTIC_VOCAB_SIZE: int = 10_000
@@ -75,6 +112,32 @@ class BarkConfig(BaseTTSConfig):
     CACHE_DIR: str = str(get_user_data_dir("tts/suno/bark_v0"))
     DEF_SPEAKER_DIR: str = str(get_user_data_dir("tts/bark_v0/speakers"))

+    # training parameters
+    training_mode: str = None
+    batch_size: int = 32
+    num_workers: int = 4
+    epochs: int = 1000
+
+    # data parameters
+    train_semantic_data_settings: dict = field(
+        default_factory=lambda: {"max_semantic_tokens_len": 511, "max_text_tokens_len": 256}
+    )
+    train_coarse_data_settings: dict = field(
+        default_factory=lambda: {"max_semantic_tokens_len": 256, "max_coarse_tokens_len": 768}
+    )
+    train_fine_data_settings: dict = field(
+        default_factory=lambda: {"max_semantic_tokens_len": 256, "max_fine_tokens_len": 512}
+    )
+
+    # optimizer
+    grad_clip: float = 1000
+    lr: float = 0.0002
+    lr_scheduler: str = "ExponentialLR"
+    lr_scheduler_params: dict = field(default_factory=lambda: {"gamma": 0.999875, "last_epoch": -1})
+    scheduler_after_epoch: bool = True
+    optimizer: str = "AdamW"
+    optimizer_params: dict = field(default_factory=lambda: {"betas": [0.8, 0.99], "eps": 1e-9, "weight_decay": 0.01})
+
     def __post_init__(self):
         self.REMOTE_MODEL_PATHS = {
             "text": {
diff --git a/TTS/tts/layers/bark/hubert/kmeans_hubert.py b/TTS/tts/layers/bark/hubert/kmeans_hubert.py
index a6a3b9aeb1..c57ec50a30 100644
--- a/TTS/tts/layers/bark/hubert/kmeans_hubert.py
+++ b/TTS/tts/layers/bark/hubert/kmeans_hubert.py
@@ -47,12 +47,11 @@ def __init__(self, checkpoint_path, target_sample_hz=16000, seq_len_multiple_of=
         self.target_sample_hz = target_sample_hz
         self.seq_len_multiple_of = seq_len_multiple_of
         self.output_layer = output_layer
-        if device is not None:
-            self.to(device)
+
         self.model = HubertModel.from_pretrained("facebook/hubert-base-ls960")
-        if device is not None:
-            self.model.to(device)
         self.model.eval()
+        if device is not None:
+            self.to(device)

     @property
     def groups(self):
@@ -73,10 +72,9 @@ def forward(self, wav_input, flatten=True, input_sample_hz=None):
             output_hidden_states=True,
         )
         embed = outputs["hidden_states"][self.output_layer]
-        embed, packed_shape = pack([embed], "* d")
-        codebook_indices = torch.from_numpy(embed.cpu().detach().numpy()).to(device)
         if flatten:
-            return codebook_indices
+            embed, packed_shape = pack([embed], "* d")
+            embed = torch.from_numpy(embed.cpu().detach().numpy()).to(device)
+            return embed

-        (codebook_indices,) = unpack(codebook_indices, packed_shape, "*")
-        return codebook_indices
+        return embed
diff --git a/TTS/tts/layers/bark/inference_funcs.py b/TTS/tts/layers/bark/inference_funcs.py
index d7f3f79345..c0248a24cb 100644
--- a/TTS/tts/layers/bark/inference_funcs.py
+++ b/TTS/tts/layers/bark/inference_funcs.py
@@ -11,6 +11,7 @@
 import tqdm
 from encodec.utils import convert_audio
 from scipy.special import softmax
+from torch import nn
 from torch.nn import functional as F

 from TTS.tts.layers.bark.hubert.hubert_manager import HubertManager
@@ -102,51 +103,82 @@ def compute_average_bass_energy(audio_data, sample_rate, max_bass_freq=250):
     return bass_energy


+class BarkHubertAudioTokenizer(nn.Module):
+    def __init__(self, config, lazy_load=True) -> None:
+        super().__init__()
+        self.__device_param = nn.Parameter(torch.empty(0))
+        self.config = config
+        self.lazy_load = lazy_load
+        if lazy_load:
+            self.load_hubert(config, self.device)
+
+    @property
+    def device(self):
+        return self.__device_param.device
+
+    def load_hubert(self, config, device):
+        hubert_manager = HubertManager()
+        hubert_manager.make_sure_tokenizer_installed(model_path=self.config.LOCAL_MODEL_PATHS["hubert_tokenizer"])
+        self.hubert_model = CustomHubert(checkpoint_path=self.config.LOCAL_MODEL_PATHS["hubert"]).to(device)
+        self.tokenizer = HubertTokenizer.load_from_checkpoint(
+            config.LOCAL_MODEL_PATHS["hubert_tokenizer"], map_location=device
+        )
+
+    def encode(self, audio, device):
+        """Encode an audio file into a sequence of tokens.
+
+        Args:
+            audio (str or Tensor): The audio to encode. In shape (B, T).
+            device (str): The device to use for encoding.
+
+        Returns:
+            Tensor: The encoded tokens.
+        """
+        if isinstance(audio, str):
+            audio, sr = torchaudio.load(audio)
+            audio = convert_audio(audio, sr, self.config.sample_rate, 1)
+        audio = audio.to(device)
+
+        if not self.lazy_load:
+            self.load_hubert(self.config, self.device)
+
+        semantic_vectors = self.hubert_model.forward(audio, flatten=False, input_sample_hz=self.config.sample_rate)
+        semantic_tokens = self.tokenizer.get_token(semantic_vectors)
+        semantic_tokens = semantic_tokens
+        return semantic_tokens
+
+
 def generate_voice(
     audio,
     model,
-    output_path,
+    output_path=None,
 ):
-    """Generate a new voice from a given audio and text prompt.
+    """Generate a new voice from a given audio.

     Args:
         audio (np.ndarray): The audio to use as a base for the new voice.
-        text (str): Transcription of the audio you are clonning.
         model (BarkModel): The BarkModel to use for generating the new voice.
-        output_path (str): The path to save the generated voice to.
+        output_path (str): The path to save the generated voice to. If None, return computed tokens.
""" if isinstance(audio, str): audio, sr = torchaudio.load(audio) audio = convert_audio(audio, sr, model.config.sample_rate, model.encodec.channels) audio = audio.unsqueeze(0).to(model.device) - with torch.no_grad(): - encoded_frames = model.encodec.encode(audio) - codes = torch.cat([encoded[0] for encoded in encoded_frames], dim=-1).squeeze() # [n_q, T] - - # move codes to cpu - codes = codes.cpu().numpy() + # Coarse and fine tokens + fine_tokens, coarse_tokens = model.generate_coarse_fine_tokens(audio) + fine_tokens = fine_tokens.cpu().numpy() + coarse_tokens = coarse_tokens.cpu().numpy() - # generate semantic tokens - # Load the HuBERT model - hubert_manager = HubertManager() - # hubert_manager.make_sure_hubert_installed(model_path=model.config.LOCAL_MODEL_PATHS["hubert"]) - hubert_manager.make_sure_tokenizer_installed(model_path=model.config.LOCAL_MODEL_PATHS["hubert_tokenizer"]) + # Semantic tokens + semantic_tokens = model.generate_semantic_tokens(audio).cpu().numpy() - hubert_model = CustomHubert(checkpoint_path=model.config.LOCAL_MODEL_PATHS["hubert"]).to(model.device) - - # Load the CustomTokenizer model - tokenizer = HubertTokenizer.load_from_checkpoint( - model.config.LOCAL_MODEL_PATHS["hubert_tokenizer"], map_location=model.device - ) - # semantic_tokens = model.text_to_semantic( - # text, max_gen_duration_s=seconds, top_k=50, top_p=0.95, temp=0.7 - # ) # not 100% - semantic_vectors = hubert_model.forward(audio[0], input_sample_hz=model.config.sample_rate) - semantic_tokens = tokenizer.get_token(semantic_vectors) - semantic_tokens = semantic_tokens.cpu().numpy() - - np.savez(output_path, fine_prompt=codes, coarse_prompt=codes[:2, :], semantic_prompt=semantic_tokens) + if output_path is not None: + np.savez( + output_path, fine_prompt=fine_tokens, coarse_prompt=coarse_tokens[:2, :], semantic_prompt=semantic_tokens + ) + else: + return {"fine_prompt": fine_tokens, "coarse_prompt": coarse_tokens, "semantic_prompt": semantic_tokens} def generate_text_semantic( @@ -162,7 +194,7 @@ def generate_text_semantic( allow_early_stop=True, base=None, use_kv_caching=True, - **kwargs, # pylint: disable=unused-argument + **kwargs, # pylint: disable=unused-argument ): """Generate semantic tokens from text. 
@@ -242,7 +274,7 @@ def generate_text_semantic(
                 x_input = x[:, [-1]]
             else:
                 x_input = x
-            logits, kv_cache = model.semantic_model(
+            logits, kv_cache = model.semantic_model.inference(
                 x_input, merge_context=True, use_cache=use_kv_caching, past_kv=kv_cache
             )
             relevant_logits = logits[0, 0, : model.config.SEMANTIC_VOCAB_SIZE]
@@ -296,7 +328,7 @@ def generate_text_semantic(


 def _flatten_codebooks(arr, offset_size):
-    assert len(arr.shape) == 2
+    assert len(arr.shape) == 2, f" ❗ Codebooks must be 2D, got {len(arr.shape)}D"
     arr = arr.copy()
     if offset_size is not None:
         for n in range(1, arr.shape[0]):
diff --git a/TTS/tts/layers/bark/model.py b/TTS/tts/layers/bark/model.py
index c84022bd08..c31c4d08b9 100644
--- a/TTS/tts/layers/bark/model.py
+++ b/TTS/tts/layers/bark/model.py
@@ -175,7 +175,37 @@ def get_num_params(self, non_embedding=True):
             n_params -= self.transformer.wpe.weight.numel()
         return n_params

-    def forward(self, idx, merge_context=False, past_kv=None, position_ids=None, use_cache=False):
+    def forward(self, idx):
+        device = idx.device
+        _, t = idx.size()
+        assert (
+            t <= self.config.block_size
+        ), f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}"
+
+        # forward the GPT model itself
+        tok_emb = self.transformer.wte(idx)  # token embeddings of shape (b, t, n_embd)
+
+        past_length = 0
+        past_kv = tuple([None] * len(self.transformer.h))
+
+        position_ids = torch.arange(past_length, t + past_length, dtype=torch.long, device=device)
+        position_ids = position_ids.unsqueeze(0)  # shape (1, t)
+        assert position_ids.shape == (1, t)
+        pos_emb = self.transformer.wpe(position_ids)  # position embeddings of shape (1, t, n_embd)
+
+        x = self.transformer.drop(tok_emb + pos_emb)
+
+        for _, (block, past_layer_kv) in enumerate(zip(self.transformer.h, past_kv)):
+            x, kv = block(x, past_kv=past_layer_kv, use_cache=False)
+
+        x = self.transformer.ln_f(x)
+
+        # training-time forward: compute logits for every position in the sequence
+        logits = self.lm_head(x)
+
+        return logits
+
+    def inference(self, idx, merge_context=False, past_kv=None, position_ids=None, use_cache=False):
         device = idx.device
         _, t = idx.size()
         if past_kv is not None:
diff --git a/TTS/tts/models/bark.py b/TTS/tts/models/bark.py
index f198c3d58a..fa07ff35a2 100644
--- a/TTS/tts/models/bark.py
+++ b/TTS/tts/models/bark.py
@@ -1,14 +1,25 @@
 import os
 from dataclasses import dataclass
-from typing import Optional
+from typing import Dict, List, Optional, Tuple, Union

 import numpy as np
+import torch
+import torch.distributed as dist
+import torchaudio
 from coqpit import Coqpit
 from encodec import EncodecModel
+from torch.nn import functional as F
+from torch.utils.data import DataLoader, Dataset
+from trainer.trainer_utils import get_optimizer, get_scheduler
 from transformers import BertTokenizer

+from TTS.config.shared_configs import BaseDatasetConfig
+from TTS.tts.datasets import load_tts_samples
+from TTS.tts.datasets.dataset import _parse_sample
 from TTS.tts.layers.bark.inference_funcs import (
+    BarkHubertAudioTokenizer,
     codec_decode,
+    convert_audio,
     generate_coarse,
     generate_fine,
     generate_text_semantic,
@@ -21,6 +32,104 @@
 from TTS.tts.models.base_tts import BaseTTS


+def load_audio(file_path, sr):
+    """Load the audio file normalized in [-1, 1]
+
+    Return Shapes:
+        - x: :math:`[1, T]`
+    """
+    x, _sr = torchaudio.load(file_path)
+
+    # resample if needed
+    if sr != _sr:
+        x = torchaudio.transforms.Resample(_sr, sr)(x)
+
+    assert (x > 1).sum() + (x < -1).sum() == 0
(x < -1).sum() == 0 + return x, sr + + +class BarkDataset(Dataset): + def __init__(self, config, samples): + super().__init__() + self.samples = samples + self.config = config + + def __getitem__(self, idx): + item = self.samples[idx] + raw_text = item["text"] + + wav, _ = load_audio(item["audio_file"], self.config.sample_rate) + wav_filename = os.path.basename(item["audio_file"]) + + return { + "raw_text": raw_text, + "text_len": len(raw_text), + "wav": wav, + "wav_len": wav.shape[1], + "wav_file": wav_filename, + "speaker_name": item["speaker_name"], + "language_name": item["language"], + "audio_unique_name": item["audio_unique_name"], + } + + def __len__(self): + return len(self.samples) + + @property + def lengths(self): + lens = [] + for item in self.samples: + _, wav_file, *_ = _parse_sample(item) + audio_len = os.path.getsize(wav_file) / 16 * 8 # assuming 16bit audio + lens.append(audio_len) + return lens + + def collate_fn(self, batch): + """ + Return Shapes: + - tokens: :math:`[B, T]` + - token_lens :math:`[B]` + - token_rel_lens :math:`[B]` + - waveform: :math:`[B, 1, T]` + - waveform_lens: :math:`[B]` + - waveform_rel_lens: :math:`[B]` + - speaker_names: :math:`[B]` + - language_names: :math:`[B]` + - audiofile_paths: :math:`[B]` + - raw_texts: :math:`[B]` + - audio_unique_names: :math:`[B]` + """ + # convert list of dicts to dict of lists + B = len(batch) + batch = {k: [dic[k] for dic in batch] for k in batch[0]} + + _, ids_sorted_decreasing = torch.sort( + torch.LongTensor([x.size(1) for x in batch["wav"]]), dim=0, descending=True + ) + + wav_lens = [w.shape[1] for w in batch["wav"]] + wav_lens = torch.LongTensor(wav_lens) + wav_lens_max = torch.max(wav_lens) + wav_rel_lens = wav_lens / wav_lens_max + + wav_padded = torch.FloatTensor(B, 1, wav_lens_max) + wav_padded = wav_padded.zero_() + for i in range(len(ids_sorted_decreasing)): + wav = batch["wav"][i] + wav_padded[i, :, : wav.size(1)] = torch.FloatTensor(wav) + + return { + "waveform": wav_padded, # (B x T) + "waveform_lens": wav_lens, # (B) + "waveform_rel_lens": wav_rel_lens, + "speaker_names": batch["speaker_name"], + "language_names": batch["language_name"], + "audio_files": batch["wav_file"], + "raw_text": batch["raw_text"], + "audio_unique_names": batch["audio_unique_name"], + } + + @dataclass class BarkAudioConfig(Coqpit): sample_rate: int = 24000 @@ -36,16 +145,26 @@ def __init__( super().__init__(config=config, ap=None, tokenizer=None, speaker_manager=None, language_manager=None) self.config.num_chars = len(tokenizer) self.tokenizer = tokenizer - self.semantic_model = GPT(config.semantic_config) - self.coarse_model = GPT(config.coarse_config) - self.fine_model = FineGPT(config.fine_config) + self.semantic_model = GPT(config.semantic_gpt_config) + self.coarse_model = GPT(config.coarse_gpt_config) + self.fine_model = FineGPT(config.fine_gpt_config) self.encodec = EncodecModel.encodec_model_24khz() self.encodec.set_target_bandwidth(6.0) + self.semantic_tokenizer = BarkHubertAudioTokenizer(self.config, lazy_load=self.config.training_mode) @property def device(self): return next(self.parameters()).device + @property + def pad_token(self): + if self.config.training_mode == "semantic": + return self.config.SEMANTIC_PAD_TOKEN + elif self.config.training_mode in ["coarse", "fine"]: + return self.config.COARSE_SEMANTIC_PAD_TOKEN + else: + raise ValueError("Invalid training mode: {}".format(self.config.training_mode)) + def load_bark_models(self): self.semantic_model, self.config = load_model( 
ckpt_path=self.config.LOCAL_MODEL_PATHS["text"], device=self.device, config=self.config, model_type="text" @@ -60,10 +179,23 @@ def load_bark_models(self): ckpt_path=self.config.LOCAL_MODEL_PATHS["fine"], device=self.device, config=self.config, model_type="fine" ) - def train_step( + def generate_coarse_fine_tokens( self, + audio, ): - pass + if isinstance(audio, str): + audio, sr = torchaudio.load(audio) + audio = convert_audio(audio, sr, self.config.sample_rate, self.encodec.channels) + audio = audio.unsqueeze(0).to(self.device) + + # Coarse and fine tokens + with torch.no_grad(): + encoded_frames = self.encodec.encode(audio) + codes = torch.cat([encoded[0] for encoded in encoded_frames], dim=-1).squeeze() # [n_q, T] + return codes, codes[:2, :] # fine, corse + + def generate_semantic_tokens(self, audio): + return self.semantic_tokenizer.encode(audio, self.device) def text_to_semantic( self, @@ -225,7 +357,117 @@ def synthesize( return return_dict - def eval_step(self): + def format_batch(self, batch): + """Tokenize input text. + + Args: + batch (dict): batch of data to format + + Returns: + formatted batch + """ + tokenss = [] + max_len = 0 + for i, text in enumerate(batch["raw_text"]): + tokens = np.array(self.tokenizer.encode(text, add_special_tokens=False)) + self.config.TEXT_ENCODING_OFFSET + tokens = torch.from_numpy(tokens).long() + tokenss.append(tokens) + max_len = max(max_len, len(tokens)) + + if self.config.training_mode == "semantic": + # pad and collate into batch + for i, tokens in enumerate(tokenss): + tokenss[i] = torch.nn.functional.pad(tokens, (0, max_len - len(tokens)), value=self.pad_token) + tokens = torch.stack(tokenss, dim=0) + batch["input_ids"] = tokens[:, : self.config.train_semantic_data_settings["max_text_tokens_len"]] + return batch + + def format_batch_on_device(self, batch): + """Tokenize input audio. + + Args: + batch (dict): Batch of input data. + + Returns: + dict: Formatted batch. 
+ """ + # TODO: Make padding and truncation based on exact length of the waveforms + if self.config.training_mode == "semantic": + batch["semantic_tokens"] = self.generate_semantic_tokens(batch["waveform"][:, 0])[ + :, : self.config.max_semantic_tokens_len + ] + elif self.config.training_mode == "coarse": + semantic_to_coarse_ratio = ( + self.config.COARSE_RATE_HZ / self.config.SEMANTIC_RATE_HZ * self.config.N_COARSE_CODEBOOKS + ) + + batch["semantic_tokens"] = self.generate_semantic_tokens(batch["waveform"][:, 0])[ + :, : self.config.train_coarse_data_settings["max_semantic_tokens_len"] + ] + batch["semantic_tokens"] = torch.nn.functional.pad( + batch["semantic_tokens"], (0, 1), value=self.config.COARSE_INFER_TOKEN + ) + + batch["coarse_tokens"] = self.generate_coarse_fine_tokens(batch["waveform"])[1] + batch["coarse_tokens"] = ( + batch["coarse_tokens"].flatten(start_dim=1) + + self.config.CODEBOOK_SIZE + + self.config.SEMANTIC_VOCAB_SIZE + ) + batch["coarse_tokens"] = batch["coarse_tokens"][ + :, : self.config.train_coarse_data_settings["max_coarse_tokens_len"] + ] + elif self.config.training_mode == "fine": + batch["coarse_tokens"], batch["fine_tokens"] = self.generate_coarse_fine_tokens(batch["waveform"])[ + :, : self.config.max_coarse_tokens_len + ] + return batch + + def train_step_semantic(self, batch: dict, criterion: torch.nn.Module) -> Tuple[Dict, Dict]: + """Train semantic encoder""" + tokens = batch["semantic_tokens"] + target_tokens = tokens[:, 1:].contiguous() + input_tokens = tokens[:, :-1].contiguous() + + inputs = torch.cat([batch["input_ids"], input_tokens], dim=1) + logits = self.semantic_model(inputs) + + logits = logits[:, batch["input_ids"].size(1) :].contiguous() + + loss = criterion(logits.view(-1, self.config.semantic_gpt_config.output_vocab_size), target_tokens.view(-1)) + loss_dict = {"loss": loss} + return {}, loss_dict + + def train_step_coarse(self, batch: dict, criterion: torch.nn.Module) -> Tuple[Dict, Dict]: + """Train coarse encoder""" + tokens = batch["coarse_tokens"] + target_tokens = tokens[:, 1:].contiguous() + input_tokens = tokens[:, :-1].contiguous() + + inputs = torch.cat([batch["semantic_tokens"], input_tokens], dim=1) + logits = self.coarse_model(inputs) + + logits = logits[:, batch["semantic_tokens"].size(1) :].contiguous() + + loss = criterion(logits.view(-1, self.config.coarse_gpt_config.output_vocab_size), target_tokens.view(-1)) + loss_dict = {"loss": loss} + return {}, loss_dict + + def train_step_fine(self): + ... + + def train_step(self, *args, **kwargs): + if self.config.training_mode == "semantic": + return self.train_step_semantic(*args, **kwargs) + elif self.config.training_mode == "coarse": + return self.train_step_coarse(*args, **kwargs) + elif self.config.training_mode == "fine": + raise NotImplemented() + + def eval_step(self, *args, **kwargs): + self.train_step(*args, **kwargs) + + def test_run(self, *args, **kwargs): ... def forward(self): @@ -234,6 +476,72 @@ def forward(self): def inference(self): ... 

+    def _get_test_aux_inputs(self):
+        return None
+
+    def get_criterion(self):
+        return torch.nn.CrossEntropyLoss(ignore_index=self.pad_token)
+
+    def get_optimizer(self):
+        if self.config.training_mode == "semantic":
+            optimizer = get_optimizer(
+                self.config.optimizer, self.config.optimizer_params, self.config.lr, self.semantic_model
+            )
+        elif self.config.training_mode == "coarse":
+            optimizer = get_optimizer(
+                self.config.optimizer, self.config.optimizer_params, self.config.lr, self.coarse_model
+            )
+        elif self.config.training_mode == "fine":
+            optimizer = get_optimizer(
+                self.config.optimizer, self.config.optimizer_params, self.config.lr, self.fine_model
+            )
+        else:
+            raise ValueError(" ❗ Invalid training mode: {}".format(self.config.training_mode))
+        return optimizer
+
+    def get_scheduler(self, optimizer):
+        scheduler = get_scheduler(self.config.lr_scheduler, self.config.lr_scheduler_params, optimizer)
+        return scheduler
+
+    def get_data_loader(
+        self,
+        config: Coqpit,
+        assets: Dict,
+        is_eval: bool,
+        samples: Union[List[Dict], List[List]],
+        verbose: bool,
+        num_gpus: int,
+        rank: int = None,
+    ) -> "DataLoader":
+        from trainer.torch import DistributedSampler
+
+        if is_eval and not config.run_eval:
+            loader = None
+        else:
+            # init dataloader
+            dataset = BarkDataset(
+                config=self.config,
+                samples=samples,
+            )
+
+            # wait all the DDP process to be ready
+            if num_gpus > 1:
+                dist.barrier()
+
+            # init data loader
+            sampler = DistributedSampler(dataset) if num_gpus > 1 else None
+            loader = DataLoader(
+                dataset,
+                batch_size=config.eval_batch_size if is_eval else config.batch_size,
+                shuffle=False,  # shuffle is done in the dataset.
+                collate_fn=dataset.collate_fn,
+                drop_last=True,  # setting this False might cause issues in AMP training.
+                sampler=sampler,
+                num_workers=config.num_eval_loader_workers if is_eval else config.num_loader_workers,
+                pin_memory=True,
+            )
+        return loader
+
     @staticmethod
     def init_from_config(config: "BarkConfig", **kwargs):  # pylint: disable=unused-argument
         return Bark(config)
@@ -276,3 +584,81 @@ def load_checkpoint(

         if eval:
             self.eval()
+
+
+if __name__ == "__main__":
+    # from TTS.tts.configs.bark_config import BarkConfig
+
+    # bark_config = BarkConfig()
+
+    # bark_config.training_mode = "semantic"
+    # bark_config.batch_size = 2
+
+    # bark = Bark.init_from_config(bark_config)
+
+    # # batch = {"waveform": torch.randn(2, 48000), "raw_text": ["hello world", "how are you"]}
+    # # batch = bark.format_batch(batch)
+    # # batch = bark.format_batch_on_device(batch)
+
+    # from trainer import Trainer, TrainerArgs
+
+    # dataset_config = BaseDatasetConfig(
+    #     formatter="ljspeech", meta_file_train="metadata.csv", path="/data/TTS-public/tests/data/ljspeech/"
+    # )
+
+    # train_samples, eval_samples = load_tts_samples(
+    #     dataset_config,
+    #     eval_split=True,
+    #     eval_split_max_size=4,
+    #     eval_split_size=4,
+    # )
+
+    # trainer = Trainer(
+    #     model=bark,
+    #     config=bark_config,
+    #     output_path="./",
+    #     args=TrainerArgs(),
+    #     train_samples=train_samples,
+    #     eval_samples=eval_samples,
+    # )
+    # trainer.fit()
+
+    from TTS.tts.configs.bark_config import BarkConfig
+
+    bark_config = BarkConfig()
+
+    bark_config.training_mode = "coarse"
+    bark_config.batch_size = 2
+    bark_config.run_eval = False
+    bark_config.save_checkpoints = False
+    bark_config.save_best_after = 100000
+    bark_config.print_step = 1
+
+    bark = Bark.init_from_config(bark_config)
+
+    # batch = {"waveform": torch.randn(2, 48000), "raw_text": ["hello world", "how are you"]}
+    # batch = bark.format_batch(batch)
+    # batch = bark.format_batch_on_device(batch)
+
+    from trainer import Trainer, TrainerArgs
+
+    dataset_config = BaseDatasetConfig(
+        formatter="ljspeech", meta_file_train="metadata.csv", path="/data/TTS-public/tests/data/ljspeech/"
+    )
+
+    train_samples, eval_samples = load_tts_samples(
+        dataset_config,
+        eval_split=True,
+        eval_split_max_size=4,
+        eval_split_size=4,
+    )
+
+    trainer = Trainer(
+        model=bark,
+        config=bark_config,
+        output_path="./",
+        args=TrainerArgs(),
+        train_samples=train_samples,
+        eval_samples=eval_samples,
+    )
+    trainer.fit()
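
Editorial usage note (not part of the patch): the __main__ block above wires up only the coarse stage. Because the new BarkConfig.training_mode flag selects which sub-model, pad token, and train_step branch are used, the same Trainer setup can in principle be reused for the other stages. The sketch below is a minimal, hypothetical driver under that assumption; the per-stage output directories are illustrative, the dataset config is copied from the example above, and the fine stage is skipped because its train_step is not implemented in this patch.

# Hypothetical stage-by-stage training driver (sketch, assumptions noted above).
from trainer import Trainer, TrainerArgs

from TTS.config.shared_configs import BaseDatasetConfig
from TTS.tts.configs.bark_config import BarkConfig
from TTS.tts.datasets import load_tts_samples
from TTS.tts.models.bark import Bark

dataset_config = BaseDatasetConfig(
    formatter="ljspeech", meta_file_train="metadata.csv", path="/data/TTS-public/tests/data/ljspeech/"
)
train_samples, eval_samples = load_tts_samples(
    dataset_config, eval_split=True, eval_split_max_size=4, eval_split_size=4
)

for stage in ["semantic", "coarse"]:  # "fine" training is not implemented in this patch
    config = BarkConfig()
    config.training_mode = stage  # selects the sub-model, pad token, and train_step branch
    config.batch_size = 2
    config.run_eval = False

    model = Bark.init_from_config(config)
    trainer = Trainer(
        model=model,
        config=config,
        output_path=f"./bark_{stage}",  # illustrative: one output directory per stage
        args=TrainerArgs(),
        train_samples=train_samples,
        eval_samples=eval_samples,
    )
    trainer.fit()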