Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adds multi-language support for VITS onnx, fixes onnx exporting and inference errors #2816

Merged
merged 1 commit into from
Jul 31, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 18 additions & 10 deletions TTS/tts/models/vits.py
Original file line number Diff line number Diff line change
Expand Up @@ -1813,14 +1813,16 @@ def export_onnx(self, output_path: str = "coqui_vits.onnx", verbose: bool = True

# rollback values
_forward = self.forward
disc = self.disc
disc = None
if hasattr(self, 'disc'):
disc = self.disc
training = self.training

# set export mode
self.disc = None
self.eval()

def onnx_inference(text, text_lengths, scales, sid=None):
def onnx_inference(text, text_lengths, scales, sid=None, langid=None):
noise_scale = scales[0]
length_scale = scales[1]
noise_scale_dp = scales[2]
Expand All @@ -1833,7 +1835,7 @@ def onnx_inference(text, text_lengths, scales, sid=None):
"x_lengths": text_lengths,
"d_vectors": None,
"speaker_ids": sid,
"language_ids": None,
"language_ids": langid,
"durations": None,
},
)["model_outputs"]
Expand All @@ -1844,11 +1846,14 @@ def onnx_inference(text, text_lengths, scales, sid=None):
dummy_input_length = 100
sequences = torch.randint(low=0, high=self.args.num_chars, size=(1, dummy_input_length), dtype=torch.long)
sequence_lengths = torch.LongTensor([sequences.size(1)])
sepaker_id = None
speaker_id = None
language_id = None
if self.num_speakers > 1:
sepaker_id = torch.LongTensor([0])
speaker_id = torch.LongTensor([0])
if self.num_languages > 0 and self.embedded_language_dim > 0:
language_id = torch.LongTensor([0])
scales = torch.FloatTensor([self.inference_noise_scale, self.length_scale, self.inference_noise_scale_dp])
dummy_input = (sequences, sequence_lengths, scales, sepaker_id)
dummy_input = (sequences, sequence_lengths, scales, speaker_id, language_id)

# export to ONNX
torch.onnx.export(
Expand All @@ -1857,7 +1862,7 @@ def onnx_inference(text, text_lengths, scales, sid=None):
opset_version=15,
f=output_path,
verbose=verbose,
input_names=["input", "input_lengths", "scales", "sid"],
input_names=["input", "input_lengths", "scales", "sid", "langid"],
output_names=["output"],
dynamic_axes={
"input": {0: "batch_size", 1: "phonemes"},
Expand All @@ -1870,7 +1875,8 @@ def onnx_inference(text, text_lengths, scales, sid=None):
self.forward = _forward
if training:
self.train()
self.disc = disc
if not disc is None:
self.disc = disc

def load_onnx(self, model_path: str, cuda=False):
import onnxruntime as ort
Expand All @@ -1887,7 +1893,7 @@ def load_onnx(self, model_path: str, cuda=False):
providers=providers,
)

def inference_onnx(self, x, x_lengths=None, speaker_id=None):
def inference_onnx(self, x, x_lengths=None, speaker_id=None, language_id=None):
"""ONNX inference"""

if isinstance(x, torch.Tensor):
Expand All @@ -1902,13 +1908,15 @@ def inference_onnx(self, x, x_lengths=None, speaker_id=None):
[self.inference_noise_scale, self.length_scale, self.inference_noise_scale_dp],
dtype=np.float32,
)

audio = self.onnx_sess.run(
["output"],
{
"input": x,
"input_lengths": x_lengths,
"scales": scales,
"sid": torch.tensor([speaker_id]).cpu().numpy(),
"sid": None if speaker_id is None else torch.tensor([speaker_id]).cpu().numpy(),
"langid": None if language_id is None else torch.tensor([language_id]).cpu().numpy()
},
)
return audio[0][0]
Expand Down
Loading