diff --git a/src/python/piper_train/vits/lightning.py b/src/python/piper_train/vits/lightning.py index a70eab46..5dc1c96c 100644 --- a/src/python/piper_train/vits/lightning.py +++ b/src/python/piper_train/vits/lightning.py @@ -368,8 +368,8 @@ def training_step_g(self, batch: Batch): self._y = y _y_d_hat_r, y_d_hat_g, fmap_r, fmap_g = self.model_d(y, y_hat) - if net_dur_disc is not None: - y_dur_hat_r, y_dur_hat_g = net_dur_disc(hidden_x, x_mask, logw_, logw) + if self.net_dur_disc is not None: + y_dur_hat_r, y_dur_hat_g = self.net_dur_disc(hidden_x, x_mask, logw_, logw) with autocast(self.device.type, enabled=False): # Generator loss loss_dur = torch.sum(l_length.float()) @@ -414,7 +414,7 @@ def training_step_d(self, batch: Batch): def training_step_dur(self, batch: Batch): if self.net_dur_disc is not None: - y_dur_hat_r, y_dur_hat_g = net_dur_disc( + y_dur_hat_r, y_dur_hat_g = self.net_dur_disc( self.hidden_x.detach(), self.x_mask.detach(), self.logw_.detach(), self.logw.detach() ) # logw is predicted duration, logw_ is real duration with autocast(self.device.type, enabled=False): @@ -422,6 +422,8 @@ def training_step_dur(self, batch: Batch): loss_dur_disc_all = loss_dur_disc self.log("loss_dur_disc_all", loss_dur_disc_all) + return loss_dur_disc_all + def validation_step(self, batch: Batch, batch_idx: int): val_loss = self.training_step_g(batch) + self.training_step_d(batch) + self.training_step_dur(batch) self.log("val_loss", val_loss) diff --git a/src/python/piper_train/vits/models.py b/src/python/piper_train/vits/models.py index a27c0da7..c85185f9 100644 --- a/src/python/piper_train/vits/models.py +++ b/src/python/piper_train/vits/models.py @@ -1675,7 +1675,7 @@ def __init__( self.emb_g = nn.Embedding(n_speakers, gin_channels) def forward(self, x, x_lengths, y, y_lengths, sid=None): - if self.n_speakers > 0: + if self.n_speakers > 1: g = self.emb_g(sid).unsqueeze(-1) # [b, h, 1] else: g = None @@ -1754,7 +1754,7 @@ def infer( noise_scale_w=1., max_len=None, ): - if self.n_speakers > 0: + if self.n_speakers > 1: assert sid is not None, "Missing speaker id" g = self.emb_g(sid).unsqueeze(-1) # [b, h, 1] else: