From 409db505d24debc744e9bee277753b85bd2b53bf Mon Sep 17 00:00:00 2001 From: Jake Tae Date: Mon, 14 Aug 2023 15:04:44 -0400 Subject: [PATCH] Add device support in TTS and Synthesizer (#2855) * fix: resolve merge conflicts * fix: retain backwards compatability in functions * feature: utilize device for voice transfer * feature: use device for vocoder * chore: cleanup vocoder cpu logic * fix: add necessary vocoder output device check * fix: add necessary vocoder output device check * fix: indentation * fix: check if waveform is pt tensor before cpu conversion --------- Co-authored-by: Jake Tae --- TTS/api.py | 8 ++++- TTS/tts/utils/synthesis.py | 66 ++++++++++++++++++++++---------------- TTS/utils/synthesizer.py | 19 +++++++---- 3 files changed, 58 insertions(+), 35 deletions(-) diff --git a/TTS/api.py b/TTS/api.py index 5bb91362a3..2ee108ba70 100644 --- a/TTS/api.py +++ b/TTS/api.py @@ -1,8 +1,10 @@ import tempfile +import warnings from pathlib import Path from typing import Union import numpy as np +from torch import nn from TTS.cs_api import CS_API from TTS.utils.audio.numpy_transforms import save_wav @@ -10,7 +12,7 @@ from TTS.utils.synthesizer import Synthesizer -class TTS: +class TTS(nn.Module): """TODO: Add voice conversion and Capacitron support.""" def __init__( @@ -62,6 +64,7 @@ def __init__( Defaults to "XTTS". gpu (bool, optional): Enable/disable GPU. Some models might be too slow on CPU. Defaults to False. """ + super().__init__() self.manager = ModelManager(models_file=self.get_models_file_path(), progress_bar=progress_bar, verbose=False) self.synthesizer = None @@ -70,6 +73,9 @@ def __init__( self.cs_api_model = cs_api_model self.model_name = None + if gpu: + warnings.warn("`gpu` will be deprecated. Please use `tts.to(device)` instead.") + if model_name is not None: if "tts_models" in model_name or "coqui_studio" in model_name: self.load_tts_model_by_name(model_name, gpu) diff --git a/TTS/tts/utils/synthesis.py b/TTS/tts/utils/synthesis.py index 039816db1f..d3c29f9346 100644 --- a/TTS/tts/utils/synthesis.py +++ b/TTS/tts/utils/synthesis.py @@ -5,19 +5,21 @@ from torch import nn -def numpy_to_torch(np_array, dtype, cuda=False): +def numpy_to_torch(np_array, dtype, cuda=False, device="cpu"): + if cuda: + device = "cuda" if np_array is None: return None - tensor = torch.as_tensor(np_array, dtype=dtype) - if cuda: - return tensor.cuda() + tensor = torch.as_tensor(np_array, dtype=dtype, device=device) return tensor -def compute_style_mel(style_wav, ap, cuda=False): - style_mel = torch.FloatTensor(ap.melspectrogram(ap.load_wav(style_wav, sr=ap.sample_rate))).unsqueeze(0) +def compute_style_mel(style_wav, ap, cuda=False, device="cpu"): if cuda: - return style_mel.cuda() + device = "cuda" + style_mel = torch.FloatTensor( + ap.melspectrogram(ap.load_wav(style_wav, sr=ap.sample_rate)), device=device, + ).unsqueeze(0) return style_mel @@ -73,22 +75,22 @@ def inv_spectrogram(postnet_output, ap, CONFIG): return wav -def id_to_torch(aux_id, cuda=False): +def id_to_torch(aux_id, cuda=False, device="cpu"): + if cuda: + device = "cuda" if aux_id is not None: aux_id = np.asarray(aux_id) - aux_id = torch.from_numpy(aux_id) - if cuda: - return aux_id.cuda() + aux_id = torch.from_numpy(aux_id).to(device) return aux_id -def embedding_to_torch(d_vector, cuda=False): +def embedding_to_torch(d_vector, cuda=False, device="cpu"): + if cuda: + device = "cuda" if d_vector is not None: d_vector = np.asarray(d_vector) d_vector = torch.from_numpy(d_vector).type(torch.FloatTensor) - d_vector = d_vector.squeeze().unsqueeze(0) - if cuda: - return d_vector.cuda() + d_vector = d_vector.squeeze().unsqueeze(0).to(device) return d_vector @@ -162,6 +164,11 @@ def synthesis( language_id (int): Language ID passed to the language embedding layer in multi-langual model. Defaults to None. """ + # device + device = next(model.parameters()).device + if use_cuda: + device = "cuda" + # GST or Capacitron processing # TODO: need to handle the case of setting both gst and capacitron to true somewhere style_mel = None @@ -169,10 +176,10 @@ def synthesis( if isinstance(style_wav, dict): style_mel = style_wav else: - style_mel = compute_style_mel(style_wav, model.ap, cuda=use_cuda) + style_mel = compute_style_mel(style_wav, model.ap, device=device) if CONFIG.has("capacitron_vae") and CONFIG.use_capacitron_vae and style_wav is not None: - style_mel = compute_style_mel(style_wav, model.ap, cuda=use_cuda) + style_mel = compute_style_mel(style_wav, model.ap, device=device) style_mel = style_mel.transpose(1, 2) # [1, time, depth] language_name = None @@ -188,26 +195,26 @@ def synthesis( ) # pass tensors to backend if speaker_id is not None: - speaker_id = id_to_torch(speaker_id, cuda=use_cuda) + speaker_id = id_to_torch(speaker_id, device=device) if d_vector is not None: - d_vector = embedding_to_torch(d_vector, cuda=use_cuda) + d_vector = embedding_to_torch(d_vector, device=device) if language_id is not None: - language_id = id_to_torch(language_id, cuda=use_cuda) + language_id = id_to_torch(language_id, device=device) if not isinstance(style_mel, dict): # GST or Capacitron style mel - style_mel = numpy_to_torch(style_mel, torch.float, cuda=use_cuda) + style_mel = numpy_to_torch(style_mel, torch.float, device=device) if style_text is not None: style_text = np.asarray( model.tokenizer.text_to_ids(style_text, language=language_id), dtype=np.int32, ) - style_text = numpy_to_torch(style_text, torch.long, cuda=use_cuda) + style_text = numpy_to_torch(style_text, torch.long, device=device) style_text = style_text.unsqueeze(0) - text_inputs = numpy_to_torch(text_inputs, torch.long, cuda=use_cuda) + text_inputs = numpy_to_torch(text_inputs, torch.long, device=device) text_inputs = text_inputs.unsqueeze(0) # synthesize voice outputs = run_model_torch( @@ -290,22 +297,27 @@ def transfer_voice( do_trim_silence (bool): trim silence after synthesis. Defaults to False. """ + # device + device = next(model.parameters()).device + if use_cuda: + device = "cuda" + # pass tensors to backend if speaker_id is not None: - speaker_id = id_to_torch(speaker_id, cuda=use_cuda) + speaker_id = id_to_torch(speaker_id, device=device) if d_vector is not None: - d_vector = embedding_to_torch(d_vector, cuda=use_cuda) + d_vector = embedding_to_torch(d_vector, device=device) if reference_d_vector is not None: - reference_d_vector = embedding_to_torch(reference_d_vector, cuda=use_cuda) + reference_d_vector = embedding_to_torch(reference_d_vector, device=device) # load reference_wav audio reference_wav = embedding_to_torch( model.ap.load_wav( reference_wav, sr=model.args.encoder_sample_rate if model.args.encoder_sample_rate else model.ap.sample_rate ), - cuda=use_cuda, + device=device, ) if hasattr(model, "module"): diff --git a/TTS/utils/synthesizer.py b/TTS/utils/synthesizer.py index bc0e231df0..fbae32162d 100644 --- a/TTS/utils/synthesizer.py +++ b/TTS/utils/synthesizer.py @@ -5,6 +5,7 @@ import numpy as np import pysbd import torch +from torch import nn from TTS.config import load_config from TTS.tts.configs.vits_config import VitsConfig @@ -21,7 +22,7 @@ from TTS.vocoder.utils.generic_utils import interpolate_vocoder_input -class Synthesizer(object): +class Synthesizer(nn.Module): def __init__( self, tts_checkpoint: str = "", @@ -60,6 +61,7 @@ def __init__( vc_config (str, optional): path to the voice conversion config file. Defaults to `""`, use_cuda (bool, optional): enable/disable cuda. Defaults to False. """ + super().__init__() self.tts_checkpoint = tts_checkpoint self.tts_config_path = tts_config_path self.tts_speakers_file = tts_speakers_file @@ -356,7 +358,12 @@ def tts( if speaker_wav is not None and self.tts_model.speaker_manager is not None: speaker_embedding = self.tts_model.speaker_manager.compute_embedding_from_clip(speaker_wav) + vocoder_device = "cpu" use_gl = self.vocoder_model is None + if not use_gl: + vocoder_device = next(self.vocoder_model.parameters()).device + if self.use_cuda: + vocoder_device = "cuda" if not reference_wav: # not voice conversion for sen in sens: @@ -388,7 +395,6 @@ def tts( mel_postnet_spec = outputs["outputs"]["model_outputs"][0].detach().cpu().numpy() # denormalize tts output based on tts audio config mel_postnet_spec = self.tts_model.ap.denormalize(mel_postnet_spec.T).T - device_type = "cuda" if self.use_cuda else "cpu" # renormalize spectrogram based on vocoder config vocoder_input = self.vocoder_ap.normalize(mel_postnet_spec.T) # compute scale factor for possible sample rate mismatch @@ -403,8 +409,8 @@ def tts( vocoder_input = torch.tensor(vocoder_input).unsqueeze(0) # pylint: disable=not-callable # run vocoder model # [1, T, C] - waveform = self.vocoder_model.inference(vocoder_input.to(device_type)) - if self.use_cuda and not use_gl: + waveform = self.vocoder_model.inference(vocoder_input.to(vocoder_device)) + if torch.is_tensor(waveform) and waveform.device != torch.device("cpu") and not use_gl: waveform = waveform.cpu() if not use_gl: waveform = waveform.numpy() @@ -453,7 +459,6 @@ def tts( mel_postnet_spec = outputs[0].detach().cpu().numpy() # denormalize tts output based on tts audio config mel_postnet_spec = self.tts_model.ap.denormalize(mel_postnet_spec.T).T - device_type = "cuda" if self.use_cuda else "cpu" # renormalize spectrogram based on vocoder config vocoder_input = self.vocoder_ap.normalize(mel_postnet_spec.T) # compute scale factor for possible sample rate mismatch @@ -468,8 +473,8 @@ def tts( vocoder_input = torch.tensor(vocoder_input).unsqueeze(0) # pylint: disable=not-callable # run vocoder model # [1, T, C] - waveform = self.vocoder_model.inference(vocoder_input.to(device_type)) - if self.use_cuda: + waveform = self.vocoder_model.inference(vocoder_input.to(vocoder_device)) + if torch.is_tensor(waveform) and waveform.device != torch.device("cpu"): waveform = waveform.cpu() if not use_gl: waveform = waveform.numpy()