
Commit

Merge pull request #2894 from coqui-ai/dev
v0.16.5
erogol authored Aug 26, 2023
2 parents c4e5eff + c0b5e61 commit 530a893
Showing 7 changed files with 79 additions and 41 deletions.
10 changes: 6 additions & 4 deletions TTS/.models.json
@@ -15,8 +15,10 @@
                     "hf_url": [
                         "https://coqui.gateway.scarf.sh/hf/bark/coarse_2.pt",
                         "https://coqui.gateway.scarf.sh/hf/bark/fine_2.pt",
-                        "https://coqui.gateway.scarf.sh/hf/bark/text_2.pt",
-                        "https://coqui.gateway.scarf.sh/hf/bark/config.json"
+                        "https://app.coqui.ai/tts_model/text_2.pt",
+                        "https://coqui.gateway.scarf.sh/hf/bark/config.json",
+                        "https://coqui.gateway.scarf.sh/hf/bark/hubert.pt",
+                        "https://coqui.gateway.scarf.sh/hf/bark/tokenizer.pth"
                     ],
                     "default_vocoder": null,
                     "commit": "e9a1953e",
@@ -238,7 +240,7 @@
                 "tortoise-v2": {
                     "description": "Tortoise tts model https://github.com/neonbjb/tortoise-tts",
                     "github_rls_url": [
-                        "https://coqui.gateway.scarf.sh/v0.14.1_models/autoregressive.pth",
+                        "https://app.coqui.ai/tts_model/autoregressive.pth",
                         "https://coqui.gateway.scarf.sh/v0.14.1_models/clvp2.pth",
                         "https://coqui.gateway.scarf.sh/v0.14.1_models/cvvp.pth",
                         "https://coqui.gateway.scarf.sh/v0.14.1_models/diffusion_decoder.pth",
@@ -879,4 +881,4 @@
             }
         }
     }
-}
+}
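
Note: with this change, instantiating Bark by name pulls hubert.pt and tokenizer.pth along with the main checkpoints, since the manager downloads every entry in hf_url; these two files back the voice-cloning path wired up in TTS/tts/models/bark.py below. A minimal sketch in Python, assuming the registered Bark model name shown here is current:

    # Sketch: first use downloads all six hf_url files,
    # including the new hubert.pt and tokenizer.pth.
    from TTS.api import TTS

    tts = TTS("tts_models/multilingual/multi-dataset/bark", progress_bar=True)
    tts.tts_to_file(text="Hello from Bark.", file_path="bark_out.wav")
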
2 changes: 1 addition & 1 deletion TTS/VERSION
@@ -1 +1 @@
-0.16.3
+0.16.5
8 changes: 7 additions & 1 deletion TTS/api.py
@@ -1,16 +1,18 @@
 import tempfile
+import warnings
 from pathlib import Path
 from typing import Union
 
 import numpy as np
+from torch import nn
 
 from TTS.cs_api import CS_API
 from TTS.utils.audio.numpy_transforms import save_wav
 from TTS.utils.manage import ModelManager
 from TTS.utils.synthesizer import Synthesizer
 
 
-class TTS:
+class TTS(nn.Module):
     """TODO: Add voice conversion and Capacitron support."""
 
     def __init__(
@@ -62,6 +64,7 @@ def __init__(
                 Defaults to "XTTS".
             gpu (bool, optional): Enable/disable GPU. Some models might be too slow on CPU. Defaults to False.
         """
+        super().__init__()
         self.manager = ModelManager(models_file=self.get_models_file_path(), progress_bar=progress_bar, verbose=False)
 
         self.synthesizer = None
@@ -70,6 +73,9 @@
         self.cs_api_model = cs_api_model
         self.model_name = None
 
+        if gpu:
+            warnings.warn("`gpu` will be deprecated. Please use `tts.to(device)` instead.")
+
         if model_name is not None:
             if "tts_models" in model_name or "coqui_studio" in model_name:
                 self.load_tts_model_by_name(model_name, gpu)
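
Note: since TTS now subclasses nn.Module, device placement follows the standard PyTorch idiom that the new deprecation warning points to. A minimal sketch; the model name is illustrative:

    # Sketch: prefer .to(device) over the deprecated gpu=True flag.
    import torch
    from TTS.api import TTS

    device = "cuda" if torch.cuda.is_available() else "cpu"
    tts = TTS(model_name="tts_models/en/ljspeech/tacotron2-DDC")
    tts.to(device)  # possible because TTS is an nn.Module now
    tts.tts_to_file(text="Hello world.", file_path="out.wav")
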
6 changes: 6 additions & 0 deletions TTS/tts/models/bark.py
@@ -246,6 +246,8 @@ def load_checkpoint(
         text_model_path=None,
         coarse_model_path=None,
         fine_model_path=None,
+        hubert_model_path=None,
+        hubert_tokenizer_path=None,
         eval=False,
         strict=True,
         **kwargs,
@@ -267,10 +269,14 @@
         text_model_path = text_model_path or os.path.join(checkpoint_dir, "text_2.pt")
         coarse_model_path = coarse_model_path or os.path.join(checkpoint_dir, "coarse_2.pt")
         fine_model_path = fine_model_path or os.path.join(checkpoint_dir, "fine_2.pt")
+        hubert_model_path = hubert_model_path or os.path.join(checkpoint_dir, "hubert.pt")
+        hubert_tokenizer_path = hubert_tokenizer_path or os.path.join(checkpoint_dir, "tokenizer.pth")
 
         self.config.LOCAL_MODEL_PATHS["text"] = text_model_path
         self.config.LOCAL_MODEL_PATHS["coarse"] = coarse_model_path
         self.config.LOCAL_MODEL_PATHS["fine"] = fine_model_path
+        self.config.LOCAL_MODEL_PATHS["hubert"] = hubert_model_path
+        self.config.LOCAL_MODEL_PATHS["hubert_tokenizer"] = hubert_tokenizer_path
 
         self.load_bark_models()
 
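
Note: both new arguments fall back to files inside checkpoint_dir, so existing callers keep working; they only matter when the HuBERT assets live elsewhere. A sketch of an explicit override; the paths are placeholders and the config/model construction is assumed, not taken from this diff:

    # Sketch: pointing Bark at custom HuBERT assets.
    from TTS.tts.configs.bark_config import BarkConfig
    from TTS.tts.models.bark import Bark

    config = BarkConfig()
    model = Bark.init_from_config(config)
    model.load_checkpoint(
        config,
        checkpoint_dir="/path/to/bark/",  # holds text_2.pt, coarse_2.pt, fine_2.pt
        hubert_model_path="/custom/hubert.pt",  # new argument
        hubert_tokenizer_path="/custom/tokenizer.pth",  # new argument
        eval=True,
    )
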
67 changes: 40 additions & 27 deletions TTS/tts/utils/synthesis.py
@@ -5,19 +5,22 @@
 from torch import nn
 
 
-def numpy_to_torch(np_array, dtype, cuda=False):
+def numpy_to_torch(np_array, dtype, cuda=False, device="cpu"):
+    if cuda:
+        device = "cuda"
     if np_array is None:
         return None
-    tensor = torch.as_tensor(np_array, dtype=dtype)
-    if cuda:
-        return tensor.cuda()
+    tensor = torch.as_tensor(np_array, dtype=dtype, device=device)
     return tensor
 
 
-def compute_style_mel(style_wav, ap, cuda=False):
-    style_mel = torch.FloatTensor(ap.melspectrogram(ap.load_wav(style_wav, sr=ap.sample_rate))).unsqueeze(0)
+def compute_style_mel(style_wav, ap, cuda=False, device="cpu"):
     if cuda:
-        return style_mel.cuda()
+        device = "cuda"
+    style_mel = torch.FloatTensor(
+        ap.melspectrogram(ap.load_wav(style_wav, sr=ap.sample_rate)),
+        device=device,
+    ).unsqueeze(0)
     return style_mel
 
 
@@ -73,22 +76,22 @@ def inv_spectrogram(postnet_output, ap, CONFIG):
     return wav
 
 
-def id_to_torch(aux_id, cuda=False):
+def id_to_torch(aux_id, cuda=False, device="cpu"):
+    if cuda:
+        device = "cuda"
     if aux_id is not None:
         aux_id = np.asarray(aux_id)
-        aux_id = torch.from_numpy(aux_id)
-    if cuda:
-        return aux_id.cuda()
+        aux_id = torch.from_numpy(aux_id).to(device)
     return aux_id
 
 
-def embedding_to_torch(d_vector, cuda=False):
+def embedding_to_torch(d_vector, cuda=False, device="cpu"):
+    if cuda:
+        device = "cuda"
     if d_vector is not None:
         d_vector = np.asarray(d_vector)
         d_vector = torch.from_numpy(d_vector).type(torch.FloatTensor)
-        d_vector = d_vector.squeeze().unsqueeze(0)
-    if cuda:
-        return d_vector.cuda()
+        d_vector = d_vector.squeeze().unsqueeze(0).to(device)
     return d_vector
 
 

@@ -162,17 +165,22 @@ def synthesis(
         language_id (int):
             Language ID passed to the language embedding layer in multi-langual model. Defaults to None.
     """
+    # device
+    device = next(model.parameters()).device
+    if use_cuda:
+        device = "cuda"
+
     # GST or Capacitron processing
     # TODO: need to handle the case of setting both gst and capacitron to true somewhere
     style_mel = None
     if CONFIG.has("gst") and CONFIG.gst and style_wav is not None:
         if isinstance(style_wav, dict):
             style_mel = style_wav
         else:
-            style_mel = compute_style_mel(style_wav, model.ap, cuda=use_cuda)
+            style_mel = compute_style_mel(style_wav, model.ap, device=device)
 
     if CONFIG.has("capacitron_vae") and CONFIG.use_capacitron_vae and style_wav is not None:
-        style_mel = compute_style_mel(style_wav, model.ap, cuda=use_cuda)
+        style_mel = compute_style_mel(style_wav, model.ap, device=device)
         style_mel = style_mel.transpose(1, 2)  # [1, time, depth]
 
     language_name = None
@@ -188,26 +196,26 @@
     )
     # pass tensors to backend
     if speaker_id is not None:
-        speaker_id = id_to_torch(speaker_id, cuda=use_cuda)
+        speaker_id = id_to_torch(speaker_id, device=device)
 
     if d_vector is not None:
-        d_vector = embedding_to_torch(d_vector, cuda=use_cuda)
+        d_vector = embedding_to_torch(d_vector, device=device)
 
     if language_id is not None:
-        language_id = id_to_torch(language_id, cuda=use_cuda)
+        language_id = id_to_torch(language_id, device=device)
 
     if not isinstance(style_mel, dict):
         # GST or Capacitron style mel
-        style_mel = numpy_to_torch(style_mel, torch.float, cuda=use_cuda)
+        style_mel = numpy_to_torch(style_mel, torch.float, device=device)
         if style_text is not None:
             style_text = np.asarray(
                 model.tokenizer.text_to_ids(style_text, language=language_id),
                 dtype=np.int32,
             )
-            style_text = numpy_to_torch(style_text, torch.long, cuda=use_cuda)
+            style_text = numpy_to_torch(style_text, torch.long, device=device)
             style_text = style_text.unsqueeze(0)
 
-    text_inputs = numpy_to_torch(text_inputs, torch.long, cuda=use_cuda)
+    text_inputs = numpy_to_torch(text_inputs, torch.long, device=device)
     text_inputs = text_inputs.unsqueeze(0)
     # synthesize voice
     outputs = run_model_torch(
@@ -290,22 +298,27 @@ def transfer_voice(
         do_trim_silence (bool):
             trim silence after synthesis. Defaults to False.
     """
+    # device
+    device = next(model.parameters()).device
+    if use_cuda:
+        device = "cuda"
+
     # pass tensors to backend
     if speaker_id is not None:
-        speaker_id = id_to_torch(speaker_id, cuda=use_cuda)
+        speaker_id = id_to_torch(speaker_id, device=device)
 
     if d_vector is not None:
-        d_vector = embedding_to_torch(d_vector, cuda=use_cuda)
+        d_vector = embedding_to_torch(d_vector, device=device)
 
     if reference_d_vector is not None:
-        reference_d_vector = embedding_to_torch(reference_d_vector, cuda=use_cuda)
+        reference_d_vector = embedding_to_torch(reference_d_vector, device=device)
 
     # load reference_wav audio
     reference_wav = embedding_to_torch(
         model.ap.load_wav(
             reference_wav, sr=model.args.encoder_sample_rate if model.args.encoder_sample_rate else model.ap.sample_rate
         ),
-        cuda=use_cuda,
+        device=device,
     )
 
     if hasattr(model, "module"):
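
Note: the helpers now take the target device explicitly (defaulting to "cpu"), the legacy cuda flag survives and simply forces device = "cuda", and synthesis()/transfer_voice() derive the device from the model's own parameters instead of a global flag. A small sketch of the updated helpers:

    # Sketch: explicit device placement with the updated helpers.
    import numpy as np
    import torch
    from TTS.tts.utils.synthesis import embedding_to_torch, id_to_torch

    device = "cuda" if torch.cuda.is_available() else "cpu"
    speaker_id = id_to_torch(3, device=device)  # integer tensor on the target device
    vec = np.random.rand(512).astype(np.float32)
    d_vector = embedding_to_torch(vec, device=device)  # float tensor, shape [1, 512]
    # The legacy flag still works and overrides device:
    # d_vector = embedding_to_torch(vec, cuda=True)
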
8 changes: 7 additions & 1 deletion TTS/utils/generic_utils.py
@@ -126,7 +126,13 @@ def get_import_path(obj: object) -> str:
 
 
 def get_user_data_dir(appname):
-    if sys.platform == "win32":
+    TTS_HOME = os.environ.get("TTS_HOME")
+    XDG_DATA_HOME = os.environ.get("XDG_DATA_HOME")
+    if TTS_HOME is not None:
+        ans = Path(TTS_HOME).expanduser().resolve(strict=False)
+    elif XDG_DATA_HOME is not None:
+        ans = Path(XDG_DATA_HOME).expanduser().resolve(strict=False)
+    elif sys.platform == "win32":
         import winreg  # pylint: disable=import-outside-toplevel
 
         key = winreg.OpenKey(
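
Note: get_user_data_dir() now honors TTS_HOME first and XDG_DATA_HOME second, before the per-platform defaults, so the model cache can be relocated without code changes. A sketch, assuming the function joins appname onto the resolved base as the surrounding code suggests:

    # Sketch: relocating the downloaded-models directory via the environment.
    import os

    os.environ["TTS_HOME"] = "/data/tts-cache"  # highest precedence
    # os.environ["XDG_DATA_HOME"] = "/data/share"  # used only if TTS_HOME is unset

    from TTS.utils.generic_utils import get_user_data_dir
    print(get_user_data_dir("tts"))  # expected: /data/tts-cache/tts
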
19 changes: 12 additions & 7 deletions TTS/utils/synthesizer.py
@@ -5,6 +5,7 @@
 import numpy as np
 import pysbd
 import torch
+from torch import nn
 
 from TTS.config import load_config
 from TTS.tts.configs.vits_config import VitsConfig
@@ -21,7 +22,7 @@
 from TTS.vocoder.utils.generic_utils import interpolate_vocoder_input
 
 
-class Synthesizer(object):
+class Synthesizer(nn.Module):
     def __init__(
         self,
         tts_checkpoint: str = "",
@@ -60,6 +61,7 @@ def __init__(
             vc_config (str, optional): path to the voice conversion config file. Defaults to `""`,
             use_cuda (bool, optional): enable/disable cuda. Defaults to False.
         """
+        super().__init__()
         self.tts_checkpoint = tts_checkpoint
         self.tts_config_path = tts_config_path
         self.tts_speakers_file = tts_speakers_file
@@ -356,7 +358,12 @@ def tts(
         if speaker_wav is not None and self.tts_model.speaker_manager is not None:
             speaker_embedding = self.tts_model.speaker_manager.compute_embedding_from_clip(speaker_wav)
 
+        vocoder_device = "cpu"
         use_gl = self.vocoder_model is None
+        if not use_gl:
+            vocoder_device = next(self.vocoder_model.parameters()).device
+        if self.use_cuda:
+            vocoder_device = "cuda"
 
         if not reference_wav:  # not voice conversion
             for sen in sens:
@@ -388,7 +395,6 @@
                         mel_postnet_spec = outputs["outputs"]["model_outputs"][0].detach().cpu().numpy()
                         # denormalize tts output based on tts audio config
                         mel_postnet_spec = self.tts_model.ap.denormalize(mel_postnet_spec.T).T
-                        device_type = "cuda" if self.use_cuda else "cpu"
                         # renormalize spectrogram based on vocoder config
                         vocoder_input = self.vocoder_ap.normalize(mel_postnet_spec.T)
                         # compute scale factor for possible sample rate mismatch
@@ -403,8 +409,8 @@
                             vocoder_input = torch.tensor(vocoder_input).unsqueeze(0)  # pylint: disable=not-callable
                         # run vocoder model
                         # [1, T, C]
-                        waveform = self.vocoder_model.inference(vocoder_input.to(device_type))
-                    if self.use_cuda and not use_gl:
+                        waveform = self.vocoder_model.inference(vocoder_input.to(vocoder_device))
+                    if torch.is_tensor(waveform) and waveform.device != torch.device("cpu") and not use_gl:
                         waveform = waveform.cpu()
                     if not use_gl:
                         waveform = waveform.numpy()
@@ -453,7 +459,6 @@ def tts(
                 mel_postnet_spec = outputs[0].detach().cpu().numpy()
                 # denormalize tts output based on tts audio config
                 mel_postnet_spec = self.tts_model.ap.denormalize(mel_postnet_spec.T).T
-                device_type = "cuda" if self.use_cuda else "cpu"
                 # renormalize spectrogram based on vocoder config
                 vocoder_input = self.vocoder_ap.normalize(mel_postnet_spec.T)
                 # compute scale factor for possible sample rate mismatch
@@ -468,8 +473,8 @@ def tts(
                     vocoder_input = torch.tensor(vocoder_input).unsqueeze(0)  # pylint: disable=not-callable
                 # run vocoder model
                 # [1, T, C]
-                waveform = self.vocoder_model.inference(vocoder_input.to(device_type))
-            if self.use_cuda:
+                waveform = self.vocoder_model.inference(vocoder_input.to(vocoder_device))
+            if torch.is_tensor(waveform) and waveform.device != torch.device("cpu"):
                 waveform = waveform.cpu()
             if not use_gl:
                 waveform = waveform.numpy()
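
Note: because Synthesizer is now an nn.Module, moving it moves its registered sub-modules, and the vocoder branch reads its device from the vocoder's own parameters (use_cuda remains as an override). A sketch with placeholder checkpoint paths; whether .to() relocates every loaded component depends on them being registered as sub-modules:

    # Sketch: device handling with the nn.Module-based Synthesizer.
    import torch
    from TTS.utils.synthesizer import Synthesizer

    synth = Synthesizer(tts_checkpoint="model.pth", tts_config_path="config.json")
    synth.to("cuda" if torch.cuda.is_available() else "cpu")
    wav = synth.tts("The vocoder input now follows the device its weights live on.")
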
