initial prototype for vits2
p0p4k committed Aug 4, 2023
1 parent 4e7f8cd commit e65cab7
Showing 6 changed files with 3,059 additions and 0 deletions.
130 changes: 130 additions & 0 deletions TTS/tts/layers/glow_tts/transformer.py
@@ -430,3 +430,133 @@ def forward(self, x, x_mask):
x = self.norm_layers_2[i](x + y)
x = x * x_mask
return x

class ConditionalRelativePositionTransformer(nn.Module):
"""Transformer with Relative Potional Encoding and conditioned on external embeddings at cond_layer_idx'th layer.
https://arxiv.org/abs/2307.16430
Args:
in_channels (int): number of channels of the input tensor.
out_chanels (int): number of channels of the output tensor.
hidden_channels (int): model hidden channels.
hidden_channels_ffn (int): hidden channels of FeedForwardNetwork.
num_heads (int): number of attention heads.
num_layers (int): number of transformer layers.
kernel_size (int, optional): kernel size of feed-forward inner layers. Defaults to 1.
dropout_p (float, optional): dropout rate for self-attention and feed-forward inner layers_per_stack. Defaults to 0.
rel_attn_window_size (int, optional): relation attention window size.
If 4, for each time step next and previous 4 time steps are attended.
If default, relative encoding is disabled and it is a regular transformer.
Defaults to None.
input_length (int, optional): input lenght to limit position encoding. Defaults to None.
layer_norm_type (str, optional): type "1" uses torch tensor operations and type "2" uses torch layer_norm
primitive. Use type "2", type "1: is for backward compat. Defaults to "1".
cond_channels (int): number of channels of the external embeddings.
cond_layer_idx (int): layer index to condition at. (using 3rd layer by default as in the paper)
"""

def __init__(
self,
in_channels: int,
out_channels: int,
hidden_channels: int,
hidden_channels_ffn: int,
num_heads: int,
num_layers: int,
kernel_size=1,
dropout_p=0.0,
rel_attn_window_size: int = None,
input_length: int = None,
layer_norm_type: str = "1",
cond_channels: int = 0,
cond_layer_idx: int = 2,
):
super().__init__()
self.cond_channels = cond_channels
if cond_layer_idx < 0 or cond_layer_idx >= num_layers:
raise ValueError(" [!] cond_layer_idx should be in [0, num_layers)")
self.cond_layer_idx = cond_layer_idx
self.cond_proj = None
if self.cond_channels:
self.cond_proj = nn.Linear(cond_channels, hidden_channels)

self.hidden_channels = hidden_channels
self.hidden_channels_ffn = hidden_channels_ffn
self.num_heads = num_heads
self.num_layers = num_layers
self.kernel_size = kernel_size
self.dropout_p = dropout_p
self.rel_attn_window_size = rel_attn_window_size

self.dropout = nn.Dropout(dropout_p)
self.attn_layers = nn.ModuleList()
self.norm_layers_1 = nn.ModuleList()
self.ffn_layers = nn.ModuleList()
self.norm_layers_2 = nn.ModuleList()

for idx in range(self.num_layers):
self.attn_layers.append(
RelativePositionMultiHeadAttention(
hidden_channels if idx != 0 else in_channels,
hidden_channels,
num_heads,
rel_attn_window_size=rel_attn_window_size,
dropout_p=dropout_p,
input_length=input_length,
)
)
if layer_norm_type == "1":
self.norm_layers_1.append(LayerNorm(hidden_channels))
elif layer_norm_type == "2":
self.norm_layers_1.append(LayerNorm2(hidden_channels))
else:
raise ValueError(" [!] Unknown layer norm type")

if hidden_channels != out_channels and (idx + 1) == self.num_layers:
self.proj = nn.Conv1d(hidden_channels, out_channels, 1)

self.ffn_layers.append(
FeedForwardNetwork(
hidden_channels,
hidden_channels if (idx + 1) != self.num_layers else out_channels,
hidden_channels_ffn,
kernel_size,
dropout_p=dropout_p,
)
)

if layer_norm_type == "1":
self.norm_layers_2.append(LayerNorm(hidden_channels if (idx + 1) != self.num_layers else out_channels))
elif layer_norm_type == "2":
self.norm_layers_2.append(LayerNorm2(hidden_channels if (idx + 1) != self.num_layers else out_channels))
else:
raise ValueError(" [!] Unknown layer norm type")

def forward(self, x, x_mask, g=None):
"""
Shapes:
- x: :math:`[B, C, T]`
- x_mask: :math:`[B, 1, T]`
- g: :math:`[B, C, T]`
"""
        attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
        for i in range(self.num_layers):
            if i == self.cond_layer_idx and self.cond_proj is not None:
                # Project the external embedding to the hidden size and add it to
                # the hidden states at the chosen conditioning layer.
                g = self.cond_proj(g.transpose(1, 2))
                g = g.transpose(1, 2)
                x = x + g
                x = x * x_mask
y = self.attn_layers[i](x, x, attn_mask)
y = self.dropout(y)
x = self.norm_layers_1[i](x + y)

y = self.ffn_layers[i](x, x_mask)
y = self.dropout(y)

if (i + 1) == self.num_layers and hasattr(self, "proj"):
x = self.proj(x)

x = self.norm_layers_2[i](x + y)
x = x * x_mask
return x
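For orientation, a minimal usage sketch (not part of the commit; the channel sizes, sequence length, and conditioning dimension below are illustrative assumptions):

# Hypothetical usage sketch for ConditionalRelativePositionTransformer.
import torch

encoder = ConditionalRelativePositionTransformer(
    in_channels=192,
    out_channels=192,
    hidden_channels=192,
    hidden_channels_ffn=768,
    num_heads=2,
    num_layers=6,
    kernel_size=3,
    dropout_p=0.1,
    rel_attn_window_size=4,
    cond_channels=256,  # external (e.g. speaker) embedding size
    cond_layer_idx=2,   # inject conditioning at the 3rd layer, as in the paper
)
x = torch.randn(1, 192, 50)                    # [B, C, T] input features
x_mask = torch.ones(1, 1, 50)                  # [B, 1, T] validity mask
g = torch.randn(1, 256, 1).expand(-1, -1, 50)  # [B, cond_channels, T] conditioning
y = encoder(x, x_mask, g=g)                    # [B, out_channels, T]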
89 changes: 89 additions & 0 deletions TTS/tts/layers/vits2/discriminator.py
@@ -0,0 +1,89 @@
import torch
from torch import nn
from torch.nn.modules.conv import Conv1d

from TTS.vocoder.models.hifigan_discriminator import DiscriminatorP, MultiPeriodDiscriminator


class DiscriminatorS(torch.nn.Module):
"""HiFiGAN Scale Discriminator. Channel sizes are different from the original HiFiGAN.
Args:
        use_spectral_norm (bool): if `True`, switch to spectral norm instead of weight norm.
"""

def __init__(self, use_spectral_norm=False):
super().__init__()
norm_f = nn.utils.spectral_norm if use_spectral_norm else nn.utils.weight_norm
self.convs = nn.ModuleList(
[
norm_f(Conv1d(1, 16, 15, 1, padding=7)),
norm_f(Conv1d(16, 64, 41, 4, groups=4, padding=20)),
norm_f(Conv1d(64, 256, 41, 4, groups=16, padding=20)),
norm_f(Conv1d(256, 1024, 41, 4, groups=64, padding=20)),
norm_f(Conv1d(1024, 1024, 41, 4, groups=256, padding=20)),
norm_f(Conv1d(1024, 1024, 5, 1, padding=2)),
]
)
self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1))

def forward(self, x):
"""
Args:
x (Tensor): input waveform.
Returns:
Tensor: discriminator scores.
            List[Tensor]: list of features from the convolutional layers.
"""
feat = []
for l in self.convs:
x = l(x)
x = torch.nn.functional.leaky_relu(x, 0.1)
feat.append(x)
x = self.conv_post(x)
feat.append(x)
x = torch.flatten(x, 1, -1)
return x, feat
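The per-layer features returned next to the score are what a feature-matching loss consumes during GAN training. A minimal sketch of that use, assuming the standard L1 feature-matching formulation from HiFi-GAN (this loss is not part of this commit):

# Hypothetical feature-matching loss (standard HiFi-GAN style, not from this commit).
import torch

def feature_matching_loss(feats_real, feats_fake):
    # Sum of L1 distances between real and generated feature maps, layer by layer.
    loss = 0.0
    for fr, ff in zip(feats_real, feats_fake):
        loss = loss + torch.mean(torch.abs(fr.detach() - ff))
    return loss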


class VitsDiscriminator(nn.Module):
"""VITS discriminator wrapping one Scale Discriminator and a stack of Period Discriminator.
::
waveform -> ScaleDiscriminator() -> scores_sd, feats_sd --> append() -> scores, feats
|--> MultiPeriodDiscriminator() -> scores_mpd, feats_mpd ^
Args:
use_spectral_norm (bool): if `True` swith to spectral norm instead of weight norm.
"""

def __init__(self, periods=(2, 3, 5, 7, 11), use_spectral_norm=False):
super().__init__()
self.nets = nn.ModuleList()
self.nets.append(DiscriminatorS(use_spectral_norm=use_spectral_norm))
self.nets.extend([DiscriminatorP(i, use_spectral_norm=use_spectral_norm) for i in periods])

def forward(self, x, x_hat=None):
"""
Args:
x (Tensor): ground truth waveform.
x_hat (Tensor): predicted waveform.
Returns:
List[Tensor]: discriminator scores.
            List[List[Tensor]]: list of lists of features from the layers of each discriminator.
"""
x_scores = []
x_hat_scores = [] if x_hat is not None else None
x_feats = []
x_hat_feats = [] if x_hat is not None else None
for net in self.nets:
x_score, x_feat = net(x)
x_scores.append(x_score)
x_feats.append(x_feat)
if x_hat is not None:
x_hat_score, x_hat_feat = net(x_hat)
x_hat_scores.append(x_hat_score)
x_hat_feats.append(x_hat_feat)
return x_scores, x_feats, x_hat_scores, x_hat_feats
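A minimal usage sketch (not part of the commit; batch size and waveform length are illustrative):

# Hypothetical usage sketch for VitsDiscriminator.
import torch

disc = VitsDiscriminator()
wav = torch.randn(1, 1, 8192)      # [B, 1, T] ground-truth waveform
wav_hat = torch.randn(1, 1, 8192)  # [B, 1, T] generated waveform
scores, feats, scores_hat, feats_hat = disc(wav, wav_hat)
# One score tensor and one feature list per sub-discriminator: 1 scale + 5 periods.
assert len(scores) == len(feats) == 6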
