Skip to content

Commit

Permalink
Adds multi-language support for VITS onnx, fixes onnx inference error…
Browse files Browse the repository at this point in the history
… when speaker_id is None or not passed, fixes onnx exporting for models with init_discriminator=false (coqui-ai#2816)
  • Loading branch information
SystemPanic authored and Tindell committed Aug 14, 2023
1 parent 00e449d commit e645035
Showing 1 changed file with 18 additions and 10 deletions.
28 changes: 18 additions & 10 deletions TTS/tts/models/vits.py
Original file line number Diff line number Diff line change
Expand Up @@ -1814,14 +1814,16 @@ def export_onnx(self, output_path: str = "coqui_vits.onnx", verbose: bool = True

# rollback values
_forward = self.forward
disc = self.disc
disc = None
if hasattr(self, 'disc'):
disc = self.disc
training = self.training

# set export mode
self.disc = None
self.eval()

def onnx_inference(text, text_lengths, scales, sid=None):
def onnx_inference(text, text_lengths, scales, sid=None, langid=None):
noise_scale = scales[0]
length_scale = scales[1]
noise_scale_dp = scales[2]
Expand All @@ -1834,7 +1836,7 @@ def onnx_inference(text, text_lengths, scales, sid=None):
"x_lengths": text_lengths,
"d_vectors": None,
"speaker_ids": sid,
"language_ids": None,
"language_ids": langid,
"durations": None,
},
)["model_outputs"]
Expand All @@ -1845,11 +1847,14 @@ def onnx_inference(text, text_lengths, scales, sid=None):
dummy_input_length = 100
sequences = torch.randint(low=0, high=self.args.num_chars, size=(1, dummy_input_length), dtype=torch.long)
sequence_lengths = torch.LongTensor([sequences.size(1)])
sepaker_id = None
speaker_id = None
language_id = None
if self.num_speakers > 1:
sepaker_id = torch.LongTensor([0])
speaker_id = torch.LongTensor([0])
if self.num_languages > 0 and self.embedded_language_dim > 0:
language_id = torch.LongTensor([0])
scales = torch.FloatTensor([self.inference_noise_scale, self.length_scale, self.inference_noise_scale_dp])
dummy_input = (sequences, sequence_lengths, scales, sepaker_id)
dummy_input = (sequences, sequence_lengths, scales, speaker_id, language_id)

# export to ONNX
torch.onnx.export(
Expand All @@ -1858,7 +1863,7 @@ def onnx_inference(text, text_lengths, scales, sid=None):
opset_version=15,
f=output_path,
verbose=verbose,
input_names=["input", "input_lengths", "scales", "sid"],
input_names=["input", "input_lengths", "scales", "sid", "langid"],
output_names=["output"],
dynamic_axes={
"input": {0: "batch_size", 1: "phonemes"},
Expand All @@ -1871,7 +1876,8 @@ def onnx_inference(text, text_lengths, scales, sid=None):
self.forward = _forward
if training:
self.train()
self.disc = disc
if not disc is None:
self.disc = disc

def load_onnx(self, model_path: str, cuda=False):
import onnxruntime as ort
Expand All @@ -1888,7 +1894,7 @@ def load_onnx(self, model_path: str, cuda=False):
providers=providers,
)

def inference_onnx(self, x, x_lengths=None, speaker_id=None):
def inference_onnx(self, x, x_lengths=None, speaker_id=None, language_id=None):
"""ONNX inference"""

if isinstance(x, torch.Tensor):
Expand All @@ -1903,13 +1909,15 @@ def inference_onnx(self, x, x_lengths=None, speaker_id=None):
[self.inference_noise_scale, self.length_scale, self.inference_noise_scale_dp],
dtype=np.float32,
)

audio = self.onnx_sess.run(
["output"],
{
"input": x,
"input_lengths": x_lengths,
"scales": scales,
"sid": torch.tensor([speaker_id]).cpu().numpy(),
"sid": None if speaker_id is None else torch.tensor([speaker_id]).cpu().numpy(),
"langid": None if language_id is None else torch.tensor([language_id]).cpu().numpy()
},
)
return audio[0][0]
Expand Down

0 comments on commit e645035

Please sign in to comment.