Commit

Ready for testing.
rmcpantoja committed Apr 25, 2024
1 parent 9e3a65d commit 0ecfc14
Showing 2 changed files with 7 additions and 5 deletions.
8 changes: 5 additions & 3 deletions src/python/piper_train/vits/lightning.py
@@ -368,8 +368,8 @@ def training_step_g(self, batch: Batch):
         self._y = y

         _y_d_hat_r, y_d_hat_g, fmap_r, fmap_g = self.model_d(y, y_hat)
-        if net_dur_disc is not None:
-            y_dur_hat_r, y_dur_hat_g = net_dur_disc(hidden_x, x_mask, logw_, logw)
+        if self.net_dur_disc is not None:
+            y_dur_hat_r, y_dur_hat_g = self.net_dur_disc(hidden_x, x_mask, logw_, logw)
         with autocast(self.device.type, enabled=False):
             # Generator loss
             loss_dur = torch.sum(l_length.float())
@@ -414,14 +414,16 @@ def training_step_d(self, batch: Batch):

     def training_step_dur(self, batch: Batch):
         if self.net_dur_disc is not None:
-            y_dur_hat_r, y_dur_hat_g = net_dur_disc(
+            y_dur_hat_r, y_dur_hat_g = self.net_dur_disc(
                 self.hidden_x.detach(), self.x_mask.detach(), self.logw_.detach(), self.logw.detach()
             )  # logw is predicted duration, logw_ is real duration
             with autocast(self.device.type, enabled=False):
                 loss_dur_disc, losses_dur_disc_r, losses_dur_disc_g = discriminator_loss(y_dur_hat_r, y_dur_hat_g)
                 loss_dur_disc_all = loss_dur_disc
             self.log("loss_dur_disc_all", loss_dur_disc_all)

             return loss_dur_disc_all

     def validation_step(self, batch: Batch, batch_idx: int):
         val_loss = self.training_step_g(batch) + self.training_step_d(batch) + self.training_step_dur(batch)
         self.log("val_loss", val_loss)
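
Note: both lightning.py hunks replace bare references to net_dur_disc with self.net_dur_disc. The duration discriminator is stored as an attribute of the training module, so an unqualified name would raise NameError at runtime. Below is a minimal sketch of that pattern only; ToyDurationDiscriminator and ToyTrainerModule are hypothetical stand-ins, not classes from this repository.

import torch
import torch.nn as nn


class ToyDurationDiscriminator(nn.Module):
    """Hypothetical stand-in that scores real vs. predicted durations."""

    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(1, 1)

    def forward(self, hidden_x, x_mask, logw_, logw):
        # logw_ is the real duration, logw is the predicted duration.
        # The toy ignores hidden_x; it only keeps the same call signature.
        real = self.proj(logw_ * x_mask)
        fake = self.proj(logw * x_mask)
        return [real], [fake]


class ToyTrainerModule(nn.Module):
    def __init__(self, use_dur_disc: bool = True):
        super().__init__()
        # The discriminator lives on the module, so every later reference
        # must go through `self`; a bare `net_dur_disc` is an undefined name.
        self.net_dur_disc = ToyDurationDiscriminator() if use_dur_disc else None

    def training_step_dur(self, hidden_x, x_mask, logw_, logw):
        if self.net_dur_disc is None:
            return None
        y_dur_hat_r, y_dur_hat_g = self.net_dur_disc(
            hidden_x.detach(), x_mask.detach(), logw_.detach(), logw.detach()
        )
        # A real implementation would feed these into a discriminator loss;
        # the sketch just returns the raw scores.
        return y_dur_hat_r, y_dur_hat_g


b, t = 2, 6
module = ToyTrainerModule()
scores = module.training_step_dur(
    torch.randn(b, t, 8), torch.ones(b, t, 1), torch.rand(b, t, 1), torch.rand(b, t, 1)
)
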
4 changes: 2 additions & 2 deletions src/python/piper_train/vits/models.py
@@ -1675,7 +1675,7 @@ def __init__(
         self.emb_g = nn.Embedding(n_speakers, gin_channels)

     def forward(self, x, x_lengths, y, y_lengths, sid=None):
-        if self.n_speakers > 0:
+        if self.n_speakers > 1:
             g = self.emb_g(sid).unsqueeze(-1)  # [b, h, 1]
         else:
             g = None
@@ -1754,7 +1754,7 @@ def infer(
         noise_scale_w=1.,
         max_len=None,
     ):
-        if self.n_speakers > 0:
+        if self.n_speakers > 1:
             assert sid is not None, "Missing speaker id"
             g = self.emb_g(sid).unsqueeze(-1)  # [b, h, 1]
         else:
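
Note: the models.py change tightens the speaker-conditioning check from n_speakers > 0 to n_speakers > 1 in both forward and infer, so the speaker embedding lookup (and the speaker-id assertion in infer) only applies to genuinely multi-speaker models. A rough sketch of the effect follows; ToySynthesizer is hypothetical, and building the embedding table conditionally is a simplification for the sketch, not code from the repository.

import torch
import torch.nn as nn


class ToySynthesizer(nn.Module):
    """Hypothetical stand-in that only mirrors the changed speaker check."""

    def __init__(self, n_speakers: int, gin_channels: int = 8):
        super().__init__()
        self.n_speakers = n_speakers
        # Simplification for the sketch: only build the embedding table when
        # more than one speaker exists.
        self.emb_g = nn.Embedding(n_speakers, gin_channels) if n_speakers > 1 else None

    def speaker_conditioning(self, sid=None):
        # With `> 1`, a single-speaker model (n_speakers == 1) skips the
        # embedding lookup and can run without a speaker id.
        if self.n_speakers > 1:
            assert sid is not None, "Missing speaker id"
            return self.emb_g(sid).unsqueeze(-1)  # [b, gin_channels, 1]
        return None


print(ToySynthesizer(n_speakers=1).speaker_conditioning())                          # None
print(ToySynthesizer(n_speakers=4).speaker_conditioning(torch.tensor([2])).shape)   # torch.Size([1, 8, 1])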
