diff --git a/TTS/tts/models/vits.py b/TTS/tts/models/vits.py
index f4f4c6391f..58769ddeff 100644
--- a/TTS/tts/models/vits.py
+++ b/TTS/tts/models/vits.py
@@ -1813,14 +1813,16 @@ def export_onnx(self, output_path: str = "coqui_vits.onnx", verbose: bool = True
 
         # rollback values
         _forward = self.forward
-        disc = self.disc
+        disc = None
+        if hasattr(self, "disc"):
+            disc = self.disc
         training = self.training
 
         # set export mode
         self.disc = None
         self.eval()
 
-        def onnx_inference(text, text_lengths, scales, sid=None):
+        def onnx_inference(text, text_lengths, scales, sid=None, langid=None):
             noise_scale = scales[0]
             length_scale = scales[1]
             noise_scale_dp = scales[2]
@@ -1833,7 +1835,7 @@ def onnx_inference(text, text_lengths, scales, sid=None):
                     "x_lengths": text_lengths,
                     "d_vectors": None,
                     "speaker_ids": sid,
-                    "language_ids": None,
+                    "language_ids": langid,
                     "durations": None,
                 },
             )["model_outputs"]
@@ -1844,11 +1846,14 @@ def onnx_inference(text, text_lengths, scales, sid=None):
         dummy_input_length = 100
         sequences = torch.randint(low=0, high=self.args.num_chars, size=(1, dummy_input_length), dtype=torch.long)
         sequence_lengths = torch.LongTensor([sequences.size(1)])
-        sepaker_id = None
+        speaker_id = None
+        language_id = None
         if self.num_speakers > 1:
-            sepaker_id = torch.LongTensor([0])
+            speaker_id = torch.LongTensor([0])
+        if self.num_languages > 0 and self.embedded_language_dim > 0:
+            language_id = torch.LongTensor([0])
         scales = torch.FloatTensor([self.inference_noise_scale, self.length_scale, self.inference_noise_scale_dp])
-        dummy_input = (sequences, sequence_lengths, scales, sepaker_id)
+        dummy_input = (sequences, sequence_lengths, scales, speaker_id, language_id)
 
         # export to ONNX
         torch.onnx.export(
@@ -1857,7 +1862,7 @@ def onnx_inference(text, text_lengths, scales, sid=None):
             opset_version=15,
             f=output_path,
             verbose=verbose,
-            input_names=["input", "input_lengths", "scales", "sid"],
+            input_names=["input", "input_lengths", "scales", "sid", "langid"],
             output_names=["output"],
             dynamic_axes={
                 "input": {0: "batch_size", 1: "phonemes"},
@@ -1870,7 +1875,8 @@ def onnx_inference(text, text_lengths, scales, sid=None):
         self.forward = _forward
         if training:
             self.train()
-        self.disc = disc
+        if disc is not None:
+            self.disc = disc
 
     def load_onnx(self, model_path: str, cuda=False):
         import onnxruntime as ort
@@ -1887,7 +1893,7 @@ def load_onnx(self, model_path: str, cuda=False):
             providers=providers,
         )
 
-    def inference_onnx(self, x, x_lengths=None, speaker_id=None):
+    def inference_onnx(self, x, x_lengths=None, speaker_id=None, language_id=None):
         """ONNX inference"""
 
         if isinstance(x, torch.Tensor):
@@ -1902,13 +1908,15 @@ def inference_onnx(self, x, x_lengths=None, speaker_id=None):
             [self.inference_noise_scale, self.length_scale, self.inference_noise_scale_dp],
             dtype=np.float32,
         )
+
         audio = self.onnx_sess.run(
             ["output"],
             {
                 "input": x,
                 "input_lengths": x_lengths,
                 "scales": scales,
-                "sid": torch.tensor([speaker_id]).cpu().numpy(),
+                "sid": None if speaker_id is None else torch.tensor([speaker_id]).cpu().numpy(),
+                "langid": None if language_id is None else torch.tensor([language_id]).cpu().numpy(),
             },
         )
         return audio[0][0]
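
Not part of the patch itself: a minimal end-to-end sketch of how the new `langid` plumbing might be exercised, assuming a multilingual VITS (e.g. YourTTS-style) checkpoint. The `config.json` / `best_model.pth` paths and the `0` ids are placeholders, and `language_id` only has an effect when the model was trained with language embeddings (`num_languages > 0`).

```python
import numpy as np

from TTS.tts.configs.vits_config import VitsConfig
from TTS.tts.models.vits import Vits

# Load a trained multilingual model (paths are placeholders).
config = VitsConfig()
config.load_json("config.json")
vits = Vits.init_from_config(config)
vits.load_checkpoint(config, "best_model.pth")

# With this patch, the exported graph gains a fifth input, "langid",
# whenever the model embeds languages.
vits.export_onnx(output_path="coqui_vits.onnx")

# Run the exported model through the patched inference_onnx().
vits.load_onnx("coqui_vits.onnx")
text_inputs = np.asarray(
    vits.tokenizer.text_to_ids("Bonjour tout le monde."),
    dtype=np.int64,
)[None, :]
audio = vits.inference_onnx(text_inputs, speaker_id=0, language_id=0)
```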