Vits2 prototype #3355

Closed · wants to merge 4 commits
181 changes: 181 additions & 0 deletions TTS/tts/configs/vits2_config.py
@@ -0,0 +1,181 @@
from dataclasses import dataclass, field
from typing import List

from TTS.tts.configs.shared_configs import BaseTTSConfig
from TTS.tts.models.vits2 import Vits2Args, Vits2AudioConfig


@dataclass
class Vits2Config(BaseTTSConfig):
"""Defines parameters for VITS2 End2End TTS model.

Args:
model (str):
Model name. Do not change unless you know what you are doing.

model_args (Vits2Args):
Model architecture arguments. Defaults to `Vits2Args()`.

audio (Vits2AudioConfig):
Audio processing configuration. Defaults to `Vits2AudioConfig()`.

grad_clip (List):
Gradient clipping thresholds for each optimizer. Defaults to `[1000.0, 1000.0]`.

lr_gen (float):
Initial learning rate for the generator. Defaults to 0.0002.

lr_disc (float):
Initial learning rate for the discriminator. Defaults to 0.0002.

lr_scheduler_gen (str):
Name of the learning rate scheduler for the generator. One of the `torch.optim.lr_scheduler.*`. Defaults to
`ExponentialLR`.

lr_scheduler_gen_params (dict):
Parameters for the learning rate scheduler of the generator. Defaults to `{'gamma': 0.999875, "last_epoch":-1}`.

lr_scheduler_disc (str):
Name of the learning rate scheduler for the discriminator. One of the `torch.optim.lr_scheduler.*`. Defaults to
`ExponentialLR`.

lr_scheduler_disc_params (dict):
Parameters for the learning rate scheduler of the discriminator. Defaults to `{'gamma': 0.999875, "last_epoch":-1}`.

scheduler_after_epoch (bool):
If true, step the schedulers after each epoch, else after each step. Defaults to `True`.

optimizer (str):
Name of the optimizer to use with both the generator and the discriminator networks. One of the
`torch.optim.*`. Defaults to `AdamW`.

kl_loss_alpha (float):
Loss weight for KL loss. Defaults to 1.0.

disc_loss_alpha (float):
Loss weight for the discriminator loss. Defaults to 1.0.

gen_loss_alpha (float):
Loss weight for the generator loss. Defaults to 1.0.

feat_loss_alpha (float):
Loss weight for the feature matching loss. Defaults to 1.0.

mel_loss_alpha (float):
Loss weight for the mel loss. Defaults to 45.0.

return_wav (bool):
If true, data loader returns the waveform as well as the other outputs. Do not change. Defaults to `True`.

compute_linear_spec (bool):
If true, the linear spectrogram is computed and returned alongside the mel output. Do not change. Defaults to `True`.

use_weighted_sampler (bool):
If true, use weighted sampler with bucketing for balancing samples between datasets used in training. Defaults to `False`.

weighted_sampler_attrs (dict):
Keys returned by the formatter to be used for weighted sampling, with their weights. For example `{"root_path": 2.0, "speaker_name": 1.0}` sets sample probabilities
by overweighting `root_path` by 2.0. Defaults to `{}`.

weighted_sampler_multipliers (dict):
Weight each unique value of a key returned by the formatter for weighted sampling.
For example `{"root_path":{"/raid/datasets/libritts-clean-16khz-bwe-coqui_44khz/LibriTTS/train-clean-100/":1.0, "/raid/datasets/libritts-clean-16khz-bwe-coqui_44khz/LibriTTS/train-clean-360/": 0.5}`.
It will sample instances from `train-clean-100` 2 times more than `train-clean-360`. Defaults to `{}`.

r (int):
Number of spectrogram frames to be generated at a time. Do not change. Defaults to `1`.

add_blank (bool):
If true, a blank token is added in between every character. Defaults to `True`.

test_sentences (List[List]):
List of sentences with speaker and language information to be used for testing.

language_ids_file (str):
Path to the language ids file.

use_language_embedding (bool):
If true, language embedding is used. Defaults to `False`.

Note:
Check :class:`TTS.tts.configs.shared_configs.BaseTTSConfig` for the inherited parameters.

Example:

>>> from TTS.tts.configs.vits2_config import Vits2Config
>>> config = Vits2Config()
"""

model: str = "vits2"
# model specific params
model_args: Vits2Args = field(default_factory=Vits2Args)
audio: Vits2AudioConfig = field(default_factory=Vits2AudioConfig)

# optimizer
grad_clip: List[float] = field(default_factory=lambda: [1000, 1000])
lr_gen: float = 0.0002
lr_disc: float = 0.0002
lr_dur: float = 0.0002

lr_scheduler_gen: str = "ExponentialLR"
lr_scheduler_gen_params: dict = field(default_factory=lambda: {"gamma": 0.999875, "last_epoch": -1})
lr_scheduler_disc: str = "ExponentialLR"
lr_scheduler_disc_params: dict = field(default_factory=lambda: {"gamma": 0.999875, "last_epoch": -1})
lr_scheduler_dur: str = "ExponentialLR"
lr_scheduler_dur_params: dict = field(default_factory=lambda: {"gamma": 0.999875, "last_epoch": -1})

scheduler_after_epoch: bool = True
optimizer: str = "AdamW"
optimizer_params: dict = field(default_factory=lambda: {"betas": [0.8, 0.99], "eps": 1e-9, "weight_decay": 0.01})

# loss params
kl_loss_alpha: float = 1.0
disc_loss_alpha: float = 1.0
gen_loss_alpha: float = 1.0
feat_loss_alpha: float = 1.0
mel_loss_alpha: float = 45.0
dur_loss_alpha: float = 1.0
speaker_encoder_loss_alpha: float = 1.0

# data loader params
return_wav: bool = True
compute_linear_spec: bool = True

# sampler params
use_weighted_sampler: bool = False # TODO: move it to the base config
weighted_sampler_attrs: dict = field(default_factory=lambda: {})
weighted_sampler_multipliers: dict = field(default_factory=lambda: {})

# overrides
r: int = 1 # DO NOT CHANGE
add_blank: bool = True

# testing
test_sentences: List[List] = field(
default_factory=lambda: [
["It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent."],
["Be a voice, not an echo."],
["I'm sorry Dave. I'm afraid I can't do that."],
["This cake is great. It's so delicious and moist."],
["Prior to November 22, 1963."],
]
)

# multi-speaker settings
# use speaker embedding layer
num_speakers: int = 0
use_speaker_embedding: bool = False
speakers_file: str = None
speaker_embedding_channels: int = 256
language_ids_file: str = None
use_language_embedding: bool = False

# use d-vectors
use_d_vector_file: bool = False
d_vector_file: List[str] = None
d_vector_dim: int = None

def __post_init__(self):
for key, val in self.model_args.items():
if hasattr(self, key):
self[key] = val
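
A minimal usage sketch for the config above (illustrative only, not part of this diff). Only fields defined in this file are overridden, and the values are arbitrary.

from TTS.tts.configs.vits2_config import Vits2Config

config = Vits2Config(
    lr_gen=2e-4,
    lr_disc=2e-4,
    use_weighted_sampler=True,
    weighted_sampler_attrs={"speaker_name": 1.0},
)
print(config.model)             # "vits2"
print(config.grad_clip)         # [1000, 1000]
print(config.lr_scheduler_gen)  # "ExponentialLR"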
130 changes: 130 additions & 0 deletions TTS/tts/layers/glow_tts/transformer.py
@@ -430,3 +430,133 @@ def forward(self, x, x_mask):
x = self.norm_layers_2[i](x + y)
x = x * x_mask
return x

class ConditionalRelativePositionTransformer(nn.Module):
"""Transformer with Relative Potional Encoding and conditioned on external embeddings at cond_layer_idx'th layer.

https://arxiv.org/abs/2307.16430

Args:
in_channels (int): number of channels of the input tensor.
out_channels (int): number of channels of the output tensor.
hidden_channels (int): model hidden channels.
hidden_channels_ffn (int): hidden channels of FeedForwardNetwork.
num_heads (int): number of attention heads.
num_layers (int): number of transformer layers.
kernel_size (int, optional): kernel size of feed-forward inner layers. Defaults to 1.
dropout_p (float, optional): dropout rate for the self-attention and feed-forward inner layers. Defaults to 0.
rel_attn_window_size (int, optional): relative attention window size.
If 4, each time step attends to the 4 previous and 4 following time steps.
If left at the default, relative encoding is disabled and this is a regular transformer.
Defaults to None.
input_length (int, optional): input length to limit position encoding. Defaults to None.
layer_norm_type (str, optional): type "1" uses torch tensor operations and type "2" uses the torch layer_norm
primitive. Use type "2"; type "1" is kept for backward compatibility. Defaults to "1".
cond_channels (int): number of channels of the external embeddings.
cond_layer_idx (int): index of the layer at which the conditioning embedding is added. Defaults to 2 (the third layer, as in the paper).
"""

def __init__(
self,
in_channels: int,
out_channels: int,
hidden_channels: int,
hidden_channels_ffn: int,
num_heads: int,
num_layers: int,
kernel_size=1,
dropout_p=0.0,
rel_attn_window_size: int = None,
input_length: int = None,
layer_norm_type: str = "1",
cond_channels: int = 0,
cond_layer_idx: int = 2,
):
super().__init__()
self.cond_channels = cond_channels
if cond_layer_idx < 0 or cond_layer_idx >= num_layers:
raise ValueError(" [!] cond_layer_idx should be in [0, num_layers)")
self.cond_layer_idx = cond_layer_idx
self.cond_proj = None
if self.cond_channels:
self.cond_proj = nn.Linear(cond_channels, hidden_channels)

self.hidden_channels = hidden_channels
self.hidden_channels_ffn = hidden_channels_ffn
self.num_heads = num_heads
self.num_layers = num_layers
self.kernel_size = kernel_size
self.dropout_p = dropout_p
self.rel_attn_window_size = rel_attn_window_size

self.dropout = nn.Dropout(dropout_p)
self.attn_layers = nn.ModuleList()
self.norm_layers_1 = nn.ModuleList()
self.ffn_layers = nn.ModuleList()
self.norm_layers_2 = nn.ModuleList()

for idx in range(self.num_layers):
self.attn_layers.append(
RelativePositionMultiHeadAttention(
hidden_channels if idx != 0 else in_channels,
hidden_channels,
num_heads,
rel_attn_window_size=rel_attn_window_size,
dropout_p=dropout_p,
input_length=input_length,
)
)
if layer_norm_type == "1":
self.norm_layers_1.append(LayerNorm(hidden_channels))
elif layer_norm_type == "2":
self.norm_layers_1.append(LayerNorm2(hidden_channels))
else:
raise ValueError(" [!] Unknown layer norm type")

if hidden_channels != out_channels and (idx + 1) == self.num_layers:
self.proj = nn.Conv1d(hidden_channels, out_channels, 1)

self.ffn_layers.append(
FeedForwardNetwork(
hidden_channels,
hidden_channels if (idx + 1) != self.num_layers else out_channels,
hidden_channels_ffn,
kernel_size,
dropout_p=dropout_p,
)
)

if layer_norm_type == "1":
self.norm_layers_2.append(LayerNorm(hidden_channels if (idx + 1) != self.num_layers else out_channels))
elif layer_norm_type == "2":
self.norm_layers_2.append(LayerNorm2(hidden_channels if (idx + 1) != self.num_layers else out_channels))
else:
raise ValueError(" [!] Unknown layer norm type")

def forward(self, x, x_mask, g=None):
"""
Shapes:
- x: :math:`[B, C, T]`
- x_mask: :math:`[B, 1, T]`
- g: :math:`[B, C, T]`
"""
attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
for i in range(self.num_layers):
if i == self.cond_layer_idx and self.cond_proj is not None:
g = self.cond_proj(g.transpose(1, 2))
g = g.transpose(1, 2)
x = x + g
x = x * x_mask
y = self.attn_layers[i](x, x, attn_mask)
y = self.dropout(y)
x = self.norm_layers_1[i](x + y)

y = self.ffn_layers[i](x, x_mask)
y = self.dropout(y)

if (i + 1) == self.num_layers and hasattr(self, "proj"):
x = self.proj(x)

x = self.norm_layers_2[i](x + y)
x = x * x_mask
return x
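
A shape-check sketch for the conditional transformer above (illustrative only, not part of this diff; the channel sizes and speaker-embedding width are arbitrary).

import torch
from TTS.tts.layers.glow_tts.transformer import ConditionalRelativePositionTransformer

encoder = ConditionalRelativePositionTransformer(
    in_channels=192,
    out_channels=192,
    hidden_channels=192,
    hidden_channels_ffn=768,
    num_heads=2,
    num_layers=6,
    kernel_size=3,
    dropout_p=0.1,
    rel_attn_window_size=4,
    cond_channels=256,  # width of the external (e.g. speaker) embedding `g`
    cond_layer_idx=2,   # condition at the third layer, as in the paper
)

x = torch.randn(2, 192, 50)                    # hidden states [B, C, T]
x_mask = torch.ones(2, 1, 50)                  # all frames valid [B, 1, T]
g = torch.randn(2, 256, 1).expand(-1, -1, 50)  # embedding broadcast over time [B, C_cond, T]

y = encoder(x, x_mask, g=g)
print(y.shape)  # torch.Size([2, 192, 50])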
33 changes: 33 additions & 0 deletions TTS/tts/layers/losses.py
@@ -740,6 +740,39 @@ def forward(
return_dict["loss"] = loss
return return_dict

class Vits2DurationLoss(nn.Module):
"""LSGAN-style discriminator loss for the VITS2 duration discriminator."""

def __init__(self, c: Coqpit):
super().__init__()
self.disc_loss_alpha = c.disc_loss_alpha

@staticmethod
def discriminator_loss(scores_real, scores_fake):
loss = 0
real_losses = []
fake_losses = []
for dr, dg in zip(scores_real, scores_fake):
dr = dr.float()
dg = dg.float()
real_loss = torch.mean((1 - dr) ** 2)
fake_loss = torch.mean(dg**2)
loss += real_loss + fake_loss
real_losses.append(real_loss.item())
fake_losses.append(fake_loss.item())
return loss, real_losses, fake_losses

def forward(self, scores_disc_real, scores_disc_fake):
loss = 0.0
return_dict = {}
loss_disc, loss_disc_real, _ = self.discriminator_loss(
scores_real=scores_disc_real, scores_fake=scores_disc_fake
)
return_dict["loss_dur_disc"] = loss_disc * self.disc_loss_alpha
loss = loss + return_dict["loss_dur_disc"]
return_dict["loss"] = loss

for i, ldr in enumerate(loss_disc_real):
return_dict[f"loss_dur_disc_real_{i}"] = ldr
return return_dict
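
# Illustrative sketch (not part of this diff): exercising Vits2DurationLoss with
# dummy duration-discriminator scores. `SimpleNamespace` merely stands in for the
# training config; any object exposing `disc_loss_alpha` works, and the score
# shapes below are arbitrary.
#
#   from types import SimpleNamespace
#   criterion = Vits2DurationLoss(SimpleNamespace(disc_loss_alpha=1.0))
#   scores_real = [torch.rand(4, 1), torch.rand(4, 1)]  # one tensor per sub-discriminator
#   scores_fake = [torch.rand(4, 1), torch.rand(4, 1)]
#   losses = criterion(scores_real, scores_fake)
#   # losses["loss"], losses["loss_dur_disc"], losses["loss_dur_disc_real_0"], ...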

class VitsDiscriminatorLoss(nn.Module):
def __init__(self, c: Coqpit):