diff --git a/TTS/tts/models/xtts.py b/TTS/tts/models/xtts.py
index 0836870e29..53a63a9185 100644
--- a/TTS/tts/models/xtts.py
+++ b/TTS/tts/models/xtts.py
@@ -1,6 +1,8 @@
 import os
 from contextlib import contextmanager
 from dataclasses import dataclass
+import time
+import librosa
 
 import torch
 import torch.nn.functional as F
@@ -404,6 +406,36 @@ def get_conditioning_latents(
         )
         return gpt_cond_latents.to(self.device), diffusion_cond_latents.to(self.device)
 
+    # When the reference wavs are consistent between runs, there is no reason to recompute the conditioning latents on every call, especially since they account for the bulk of inference time. When you can cache them, cache them.
+    def precompute_conditioning_latents(
+        self,
+        audio_folder_path
+    ):
+        print("Beginning latent precomputation.")
+        for file in os.listdir(audio_folder_path):
+            if os.path.isdir(audio_folder_path + "/" + file):
+                continue
+            if file.endswith(".gpt") or file.endswith(".diffusion"):
+                continue
+            duration = librosa.get_duration(filename=audio_folder_path + "/" + file)  # probably a faster/cheaper way to do this
+            if duration < 3:
+                print("Skipping " + str(file) + "; duration is " + str(int(duration)))
+                continue
+            print("Computing " + str(file))
+            if os.path.exists(audio_folder_path + "/" + file + ".gpt"):
+                print("Skipping " + str(file) + ".gpt; already exists.")
+            else:
+                gpt_cond_latents = self.get_gpt_cond_latents(audio_folder_path + "/" + file, length=int(duration))  # [1, 1024, T]
+                torch.save(gpt_cond_latents, audio_folder_path + "/" + file + ".gpt")
+            if os.path.exists(audio_folder_path + "/" + file + ".diffusion"):
+                print("Skipping " + str(file) + ".diffusion; already exists.")
+            else:
+                diffusion_cond_latents = self.get_diffusion_cond_latents(
+                    audio_folder_path + "/" + file,
+                )
+                torch.save(diffusion_cond_latents, audio_folder_path + "/" + file + ".diffusion")
+        print("Latent precomputation complete, files saved adjacent to the input files with .gpt and .diffusion extensions.")
+
     def synthesize(self, text, config, speaker_wav, language, **kwargs):
         """Synthesize speech with the given input text.
 
@@ -463,6 +495,7 @@ def inference(
         top_p=0.85,
         gpt_cond_len=4,
         do_sample=True,
+        precomputed_latents=False,
         # Decoder inference
         decoder_iterations=100,
         cond_free=True,
@@ -480,6 +513,9 @@ def inference(
             ref_audio_path: (str) Path to a reference audio file to be used for cloning. This audio file should be >3
                 seconds long.
 
+            precomputed_latents: (bool) Whether to load precomputed conditioning latents instead of computing them from
+                `ref_audio_path`. If enabled, run `precompute_conditioning_latents` on the model first.
+
             language: (str) Language of the voice to be generated.
 
             temperature: (float) The softmax temperature of the autoregressive model. Defaults to 0.65.
@@ -531,19 +567,22 @@ def inference(
         assert (
             text_tokens.shape[-1] < self.args.gpt_max_text_tokens
         ), " ❗ XTTS can only generate text with a maximum of 400 tokens."
-
-        (
-            gpt_cond_latent,
-            diffusion_conditioning,
-        ) = self.get_conditioning_latents(audio_path=ref_audio_path, gpt_cond_len=gpt_cond_len)
-
+        if precomputed_latents:
+            print("Using precomputed latents for " + ref_audio_path)
+            gpt_cond_latent = torch.load(ref_audio_path + ".gpt").to(self.device)
+            diffusion_conditioning = torch.load(ref_audio_path + ".diffusion").to(self.device)
+        else:
+            print("Using non-precomputed latents for " + ref_audio_path + "; if repeatedly inferring on the same reference speaker, consider precomputation to save inference time.")
+            (
+                gpt_cond_latent,
+                diffusion_conditioning,
+            ) = self.get_conditioning_latents(audio_path=ref_audio_path, gpt_cond_len=gpt_cond_len)
         diffuser = load_discrete_vocoder_diffuser(
             desired_diffusion_steps=decoder_iterations,
             cond_free=cond_free,
             cond_free_k=cond_free_k,
             sampler=decoder_sampler,
         )
-
         with torch.no_grad():
             self.gpt = self.gpt.to(self.device)
             with self.lazy_load_model(self.gpt) as gpt:
@@ -561,7 +600,6 @@ def inference(
                     output_attentions=False,
                     **hf_generate_kwargs,
                 )
-
             with self.lazy_load_model(self.gpt) as gpt:
                 expected_output_len = torch.tensor(
                     [gpt_codes.shape[-1] * self.gpt.code_stride_len], device=text_tokens.device
@@ -642,7 +680,7 @@ def load_checkpoint(
         self.init_models()
         if eval:
             self.gpt.init_gpt_for_inference(kv_cache=self.args.kv_cache)
-        self.load_state_dict(load_fsspec(model_path)["model"], strict=strict)
+        self.load_state_dict(load_fsspec(model_path, self.device)["model"], strict=strict)
 
         if eval:
             self.gpt.init_gpt_for_inference(kv_cache=self.args.kv_cache)
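A minimal usage sketch of the precompute-then-reuse flow this patch enables, assuming `model` is an already-loaded `Xtts` instance and `/data/speakers` is a hypothetical folder of reference wavs longer than 3 seconds; keyword names other than `precomputed_latents`, `ref_audio_path`, and `language` are assumed from the surrounding signature:

```python
# One-time pass: writes "<wav>.gpt" and "<wav>.diffusion" next to each reference file.
model.precompute_conditioning_latents("/data/speakers")

# Subsequent calls load the cached latents from disk instead of recomputing them,
# skipping the most expensive part of conditioning for a repeated reference speaker.
out = model.inference(
    text="Hello there!",
    ref_audio_path="/data/speakers/alice.wav",  # hypothetical reference wav
    language="en",
    precomputed_latents=True,
)
```

Because the cached latents are looked up as `ref_audio_path + ".gpt"` and `ref_audio_path + ".diffusion"`, `ref_audio_path` must point at the same file that was present when `precompute_conditioning_latents` ran.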