Adds support for precomputing conditioning latents to XTTS, for significant performance gains when repeatedly inferring on the same reference wavs. #2956

Closed
wants to merge 1 commit into from
Changes from all commits
56 changes: 47 additions & 9 deletions TTS/tts/models/xtts.py
@@ -1,6 +1,8 @@
import os
from contextlib import contextmanager
from dataclasses import dataclass
+import time
+import librosa

import torch
import torch.nn.functional as F
@@ -404,6 +406,36 @@ def get_conditioning_latents(
        )
        return gpt_cond_latents.to(self.device), diffusion_cond_latents.to(self.device)

+    # When the reference wavs are going to be consistent across runs, there is no reason to recompute the conditioning latents every time,
+    # especially when they make up the bulk of the processing time during inference. When you can cache it, always cache it.
+    def precompute_conditioning_latents(
+        self,
+        audio_folder_path
+    ):
+        print("Beginning latent precomputation.")
+        for file in os.listdir(audio_folder_path):
+            if os.path.isdir(audio_folder_path + "/" + file):
+                continue
+            if file.endswith(".gpt") or file.endswith(".diffusion"):
+                continue
+            duration = librosa.get_duration(filename=audio_folder_path + "/" + file)  # probably a faster/cheaper way to do this
+            if duration < 3:
+                print("Skipping " + str(file) + "; duration is " + str(int(duration)))
+                continue
+            print("Computing " + str(file))
+            if os.path.exists(audio_folder_path + "/" + file + ".gpt"):
+                print("Skipping " + str(file) + ".gpt; already exists.")
+            else:
+                gpt_cond_latents = self.get_gpt_cond_latents(audio_folder_path + "/" + file, length=int(duration))  # [1, 1024, T]
+                torch.save(gpt_cond_latents, audio_folder_path + "/" + file + ".gpt")
+            if os.path.exists(audio_folder_path + "/" + file + ".diffusion"):
+                print("Skipping " + str(file) + ".diffusion; already exists.")
+            else:
+                diffusion_cond_latents = self.get_diffusion_cond_latents(
+                    audio_folder_path + "/" + file,
+                )
+                torch.save(diffusion_cond_latents, audio_folder_path + "/" + file + ".diffusion")
+        print("Latent precomputation complete, files saved adjacent to the input files with .gpt and .diffusion extensions.")
    def synthesize(self, text, config, speaker_wav, language, **kwargs):
        """Synthesize speech with the given input text.

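For context, here is a minimal usage sketch of the caching flow this hunk introduces. It is illustrative only and not part of the diff; the folder path and the assumption that model is an already-loaded Xtts instance are made up for the example.

# One-time pass over a folder of reference wavs. For each wav that is at least
# 3 seconds long this writes <wav>.gpt and <wav>.diffusion next to the input
# file, skipping any latents that were already computed.
model.precompute_conditioning_latents("/path/to/reference_wavs")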
@@ -463,6 +495,7 @@ def inference(
        top_p=0.85,
        gpt_cond_len=4,
        do_sample=True,
+        precomputed_latents=False,
        # Decoder inference
        decoder_iterations=100,
        cond_free=True,
@@ -480,6 +513,9 @@
            ref_audio_path: (str) Path to a reference audio file to be used for cloning. This audio file should be >3
                seconds long.

+            precomputed_latents: (bool) Whether to use precomputed conditioning latents. If enabled, make sure you ran
+                precompute_conditioning_latents on the model first so that <ref_audio_path>.gpt and <ref_audio_path>.diffusion exist.
+
            language: (str) Language of the voice to be generated.

            temperature: (float) The softmax temperature of the autoregressive model. Defaults to 0.65.
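A hedged sketch of how the new flag is meant to be called, following the docstring above. It assumes model is a loaded Xtts instance, that precompute_conditioning_latents has already been run on the folder containing the reference wav (so the .gpt and .diffusion files exist next to it), and that the sample text and paths are placeholders.

# With precomputed_latents=True the cached tensors are loaded from disk
# instead of being recomputed from the reference audio on every call.
out = model.inference(
    "Hello world.",
    ref_audio_path="/path/to/reference_wavs/speaker.wav",
    language="en",
    precomputed_latents=True,
)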
@@ -531,19 +567,22 @@ def inference(
        assert (
            text_tokens.shape[-1] < self.args.gpt_max_text_tokens
        ), " ❗ XTTS can only generate text with a maximum of 400 tokens."
-
-        (
-            gpt_cond_latent,
-            diffusion_conditioning,
-        ) = self.get_conditioning_latents(audio_path=ref_audio_path, gpt_cond_len=gpt_cond_len)
-
+        if precomputed_latents:
+            print("Using precomputed latents for " + ref_audio_path)
+            gpt_cond_latent = torch.load(ref_audio_path + ".gpt").to(self.device)
+            diffusion_conditioning = torch.load(ref_audio_path + ".diffusion").to(self.device)
+        else:
+            print("Using non-precomputed latents for " + ref_audio_path + "; if repeatedly inferring on the same reference speaker, consider precomputation to save inference time.")
+            (
+                gpt_cond_latent,
+                diffusion_conditioning,
+            ) = self.get_conditioning_latents(audio_path=ref_audio_path, gpt_cond_len=gpt_cond_len)
        diffuser = load_discrete_vocoder_diffuser(
            desired_diffusion_steps=decoder_iterations,
            cond_free=cond_free,
            cond_free_k=cond_free_k,
            sampler=decoder_sampler,
        )
-
        with torch.no_grad():
            self.gpt = self.gpt.to(self.device)
            with self.lazy_load_model(self.gpt) as gpt:
@@ -561,7 +600,6 @@ def inference(
                    output_attentions=False,
                    **hf_generate_kwargs,
                )
-
            with self.lazy_load_model(self.gpt) as gpt:
                expected_output_len = torch.tensor(
                    [gpt_codes.shape[-1] * self.gpt.code_stride_len], device=text_tokens.device
@@ -642,7 +680,7 @@ def load_checkpoint(
        self.init_models()
        if eval:
            self.gpt.init_gpt_for_inference(kv_cache=self.args.kv_cache)
-        self.load_state_dict(load_fsspec(model_path)["model"], strict=strict)
+        self.load_state_dict(load_fsspec(model_path, self.device)["model"], strict=strict)

        if eval:
            self.gpt.init_gpt_for_inference(kv_cache=self.args.kv_cache)
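The load_checkpoint change above passes the target device through to checkpoint loading. A rough sketch of the intent, assuming load_fsspec forwards its second argument to torch.load as map_location; the keyword form below is illustrative, not the exact call in this diff.

# Map checkpoint tensors onto the model's device at load time instead of
# deserializing them to the default location and moving them afterwards.
checkpoint = load_fsspec(model_path, map_location=self.device)
self.load_state_dict(checkpoint["model"], strict=strict)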