Showing 6 changed files with 3,059 additions and 0 deletions.
@@ -0,0 +1,89 @@
import torch
from torch import nn
from torch.nn.modules.conv import Conv1d

from TTS.vocoder.models.hifigan_discriminator import DiscriminatorP, MultiPeriodDiscriminator


class DiscriminatorS(torch.nn.Module):
    """HiFiGAN Scale Discriminator. Channel sizes are different from the original HiFiGAN.

    Args:
        use_spectral_norm (bool): if `True`, switch to spectral norm instead of weight norm.
    """

    def __init__(self, use_spectral_norm=False):
        super().__init__()
        norm_f = nn.utils.spectral_norm if use_spectral_norm else nn.utils.weight_norm
        self.convs = nn.ModuleList(
            [
                norm_f(Conv1d(1, 16, 15, 1, padding=7)),
                norm_f(Conv1d(16, 64, 41, 4, groups=4, padding=20)),
                norm_f(Conv1d(64, 256, 41, 4, groups=16, padding=20)),
                norm_f(Conv1d(256, 1024, 41, 4, groups=64, padding=20)),
                norm_f(Conv1d(1024, 1024, 41, 4, groups=256, padding=20)),
                norm_f(Conv1d(1024, 1024, 5, 1, padding=2)),
            ]
        )
        self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1))

    def forward(self, x):
        """
        Args:
            x (Tensor): input waveform.

        Returns:
            Tensor: discriminator scores.
            List[Tensor]: list of features from the convolutional layers.
        """
        feat = []
        for l in self.convs:
            x = l(x)
            x = torch.nn.functional.leaky_relu(x, 0.1)
            feat.append(x)
        x = self.conv_post(x)
        feat.append(x)
        x = torch.flatten(x, 1, -1)
        return x, feat
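
# Shape sketch (illustrative comment, not part of the committed file): for a batch of
# 1-channel waveforms of shape [B, 1, T], the four stride-4 convolutions above reduce the
# time axis by a factor of roughly 256, so the flattened score tensor is about [B, T // 256]
# and `feat` holds 7 intermediate activations (the 6 convs plus `conv_post`).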


class VitsDiscriminator(nn.Module):
    """VITS discriminator wrapping one Scale Discriminator and a stack of Period Discriminators.

    ::
        waveform -> ScaleDiscriminator() -> scores_sd, feats_sd --> append() -> scores, feats
               |--> MultiPeriodDiscriminator() -> scores_mpd, feats_mpd ^

    Args:
        use_spectral_norm (bool): if `True`, switch to spectral norm instead of weight norm.
    """

    def __init__(self, periods=(2, 3, 5, 7, 11), use_spectral_norm=False):
        super().__init__()
        self.nets = nn.ModuleList()
        self.nets.append(DiscriminatorS(use_spectral_norm=use_spectral_norm))
        self.nets.extend([DiscriminatorP(i, use_spectral_norm=use_spectral_norm) for i in periods])

    def forward(self, x, x_hat=None):
        """
        Args:
            x (Tensor): ground truth waveform.
            x_hat (Tensor): predicted waveform.

        Returns:
            List[Tensor]: discriminator scores.
            List[List[Tensor]]: list of list of features from each layer of each discriminator.
        """
        x_scores = []
        x_hat_scores = [] if x_hat is not None else None
        x_feats = []
        x_hat_feats = [] if x_hat is not None else None
        for net in self.nets:
            x_score, x_feat = net(x)
            x_scores.append(x_score)
            x_feats.append(x_feat)
            if x_hat is not None:
                x_hat_score, x_hat_feat = net(x_hat)
                x_hat_scores.append(x_hat_score)
                x_hat_feats.append(x_hat_feat)
        return x_scores, x_feats, x_hat_scores, x_hat_feats
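A quick usage sketch (not part of the commit; the batch size and waveform length below are arbitrary): instantiate the wrapper and score a real/generated waveform pair.

# Illustrative usage only; assumes the classes defined above are importable.
import torch

disc = VitsDiscriminator()                 # 1 scale + 5 period discriminators
x = torch.randn(2, 1, 8192)                # ground-truth waveforms [B, 1, T]
x_hat = torch.randn(2, 1, 8192)            # generator outputs, same shape

real_scores, real_feats, fake_scores, fake_feats = disc(x, x_hat)
print(len(real_scores))                    # 6 score tensors, one per sub-discriminator
print(len(real_feats[0]))                  # 7 feature maps from the scale discriminator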