diff --git a/TTS/tts/configs/vits2_config.py b/TTS/tts/configs/vits2_config.py new file mode 100644 index 0000000000..e398dd2c28 --- /dev/null +++ b/TTS/tts/configs/vits2_config.py @@ -0,0 +1,181 @@ +from dataclasses import dataclass, field +from typing import List + +from TTS.tts.configs.shared_configs import BaseTTSConfig +from TTS.tts.models.vits2 import Vits2Args, Vits2AudioConfig + + +@dataclass +class Vits2Config(BaseTTSConfig): + """Defines parameters for VITS2 End2End TTS model. + + Args: + model (str): + Model name. Do not change unless you know what you are doing. + + model_args (Vits2Args): + Model architecture arguments. Defaults to `Vits2Args()`. + + audio (Vits2AudioConfig): + Audio processing configuration. Defaults to `Vits2AudioConfig()`. + + grad_clip (List): + Gradient clipping thresholds for each optimizer. Defaults to `[1000.0, 1000.0]`. + + lr_gen (float): + Initial learning rate for the generator. Defaults to 0.0002. + + lr_disc (float): + Initial learning rate for the discriminator. Defaults to 0.0002. + + lr_scheduler_gen (str): + Name of the learning rate scheduler for the generator. One of the `torch.optim.lr_scheduler.*`. Defaults to + `ExponentialLR`. + + lr_scheduler_gen_params (dict): + Parameters for the learning rate scheduler of the generator. Defaults to `{'gamma': 0.999875, "last_epoch":-1}`. + + lr_scheduler_disc (str): + Name of the learning rate scheduler for the discriminator. One of the `torch.optim.lr_scheduler.*`. Defaults to + `ExponentialLR`. + + lr_scheduler_disc_params (dict): + Parameters for the learning rate scheduler of the discriminator. Defaults to `{'gamma': 0.999875, "last_epoch":-1}`. + + scheduler_after_epoch (bool): + If true, step the schedulers after each epoch else after each step. Defaults to `False`. + + optimizer (str): + Name of the optimizer to use with both the generator and the discriminator networks. One of the + `torch.optim.*`. Defaults to `AdamW`. + + kl_loss_alpha (float): + Loss weight for KL loss. Defaults to 1.0. + + disc_loss_alpha (float): + Loss weight for the discriminator loss. Defaults to 1.0. + + gen_loss_alpha (float): + Loss weight for the generator loss. Defaults to 1.0. + + feat_loss_alpha (float): + Loss weight for the feature matching loss. Defaults to 1.0. + + mel_loss_alpha (float): + Loss weight for the mel loss. Defaults to 45.0. + + return_wav (bool): + If true, data loader returns the waveform as well as the other outputs. Do not change. Defaults to `True`. + + compute_linear_spec (bool): + If true, the linear spectrogram is computed and returned alongside the mel output. Do not change. Defaults to `True`. + + use_weighted_sampler (bool): + If true, use weighted sampler with bucketing for balancing samples between datasets used in training. Defaults to `False`. + + weighted_sampler_attrs (dict): + Key retuned by the formatter to be used for weighted sampler. For example `{"root_path": 2.0, "speaker_name": 1.0}` sets sample probabilities + by overweighting `root_path` by 2.0. Defaults to `{}`. + + weighted_sampler_multipliers (dict): + Weight each unique value of a key returned by the formatter for weighted sampling. + For example `{"root_path":{"/raid/datasets/libritts-clean-16khz-bwe-coqui_44khz/LibriTTS/train-clean-100/":1.0, "/raid/datasets/libritts-clean-16khz-bwe-coqui_44khz/LibriTTS/train-clean-360/": 0.5}`. + It will sample instances from `train-clean-100` 2 times more than `train-clean-360`. Defaults to `{}`. + + r (int): + Number of spectrogram frames to be generated at a time. Do not change. Defaults to `1`. + + add_blank (bool): + If true, a blank token is added in between every character. Defaults to `True`. + + test_sentences (List[List]): + List of sentences with speaker and language information to be used for testing. + + language_ids_file (str): + Path to the language ids file. + + use_language_embedding (bool): + If true, language embedding is used. Defaults to `False`. + + Note: + Check :class:`TTS.tts.configs.shared_configs.BaseTTSConfig` for the inherited parameters. + + Example: + + >>> from TTS.tts.configs.vits2_config import Vits2Config + >>> config = Vits2Config() + """ + + model: str = "vits2" + # model specific params + model_args: Vits2Args = field(default_factory=Vits2Args) + audio: Vits2AudioConfig = field(default_factory=Vits2AudioConfig) + + # optimizer + grad_clip: List[float] = field(default_factory=lambda: [1000, 1000]) + lr_gen: float = 0.0002 + lr_disc: float = 0.0002 + lr_dur: float = 0.0002 + + lr_scheduler_gen: str = "ExponentialLR" + lr_scheduler_gen_params: dict = field(default_factory=lambda: {"gamma": 0.999875, "last_epoch": -1}) + lr_scheduler_disc: str = "ExponentialLR" + lr_scheduler_disc_params: dict = field(default_factory=lambda: {"gamma": 0.999875, "last_epoch": -1}) + lr_scheduler_dur: str = "ExponentialLR" + lr_scheduler_dur_params: dict = field(default_factory=lambda: {"gamma": 0.999875, "last_epoch": -1}) + + scheduler_after_epoch: bool = True + optimizer: str = "AdamW" + optimizer_params: dict = field(default_factory=lambda: {"betas": [0.8, 0.99], "eps": 1e-9, "weight_decay": 0.01}) + + # loss params + kl_loss_alpha: float = 1.0 + disc_loss_alpha: float = 1.0 + gen_loss_alpha: float = 1.0 + feat_loss_alpha: float = 1.0 + mel_loss_alpha: float = 45.0 + dur_loss_alpha: float = 1.0 + speaker_encoder_loss_alpha: float = 1.0 + + # data loader params + return_wav: bool = True + compute_linear_spec: bool = True + + # sampler params + use_weighted_sampler: bool = False # TODO: move it to the base config + weighted_sampler_attrs: dict = field(default_factory=lambda: {}) + weighted_sampler_multipliers: dict = field(default_factory=lambda: {}) + + # overrides + r: int = 1 # DO NOT CHANGE + add_blank: bool = True + + # testing + test_sentences: List[List] = field( + default_factory=lambda: [ + ["It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent."], + ["Be a voice, not an echo."], + ["I'm sorry Dave. I'm afraid I can't do that."], + ["This cake is great. It's so delicious and moist."], + ["Prior to November 22, 1963."], + ] + ) + + # multi-speaker settings + # use speaker embedding layer + num_speakers: int = 0 + use_speaker_embedding: bool = False + speakers_file: str = None + speaker_embedding_channels: int = 256 + language_ids_file: str = None + use_language_embedding: bool = False + + # use d-vectors + use_d_vector_file: bool = False + d_vector_file: List[str] = None + d_vector_dim: int = None + + def __post_init__(self): + for key, val in self.model_args.items(): + if hasattr(self, key): + self[key] = val diff --git a/TTS/tts/layers/glow_tts/transformer.py b/TTS/tts/layers/glow_tts/transformer.py index 02688d611f..2324c6253f 100644 --- a/TTS/tts/layers/glow_tts/transformer.py +++ b/TTS/tts/layers/glow_tts/transformer.py @@ -430,3 +430,133 @@ def forward(self, x, x_mask): x = self.norm_layers_2[i](x + y) x = x * x_mask return x + +class ConditionalRelativePositionTransformer(nn.Module): + """Transformer with Relative Potional Encoding and conditioned on external embeddings at cond_layer_idx'th layer. + + https://arxiv.org/abs/2307.16430 + + Args: + in_channels (int): number of channels of the input tensor. + out_chanels (int): number of channels of the output tensor. + hidden_channels (int): model hidden channels. + hidden_channels_ffn (int): hidden channels of FeedForwardNetwork. + num_heads (int): number of attention heads. + num_layers (int): number of transformer layers. + kernel_size (int, optional): kernel size of feed-forward inner layers. Defaults to 1. + dropout_p (float, optional): dropout rate for self-attention and feed-forward inner layers_per_stack. Defaults to 0. + rel_attn_window_size (int, optional): relation attention window size. + If 4, for each time step next and previous 4 time steps are attended. + If default, relative encoding is disabled and it is a regular transformer. + Defaults to None. + input_length (int, optional): input lenght to limit position encoding. Defaults to None. + layer_norm_type (str, optional): type "1" uses torch tensor operations and type "2" uses torch layer_norm + primitive. Use type "2", type "1: is for backward compat. Defaults to "1". + cond_channels (int): number of channels of the external embeddings. + cond_layer_idx (int): layer index to condition at. (using 3rd layer by default as in the paper) + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + hidden_channels: int, + hidden_channels_ffn: int, + num_heads: int, + num_layers: int, + kernel_size=1, + dropout_p=0.0, + rel_attn_window_size: int = None, + input_length: int = None, + layer_norm_type: str = "1", + cond_channels: int = 0, + cond_layer_idx: int = 2, + ): + super().__init__() + self.cond_channels = cond_channels + if cond_layer_idx < 0 or cond_layer_idx >= num_layers: + raise ValueError(" [!] cond_layer_idx should be in [0, num_layers)") + self.cond_layer_idx = cond_layer_idx + self.cond_proj = None + if self.cond_channels: + self.cond_proj = nn.Linear(cond_channels, hidden_channels) + + self.hidden_channels = hidden_channels + self.hidden_channels_ffn = hidden_channels_ffn + self.num_heads = num_heads + self.num_layers = num_layers + self.kernel_size = kernel_size + self.dropout_p = dropout_p + self.rel_attn_window_size = rel_attn_window_size + + self.dropout = nn.Dropout(dropout_p) + self.attn_layers = nn.ModuleList() + self.norm_layers_1 = nn.ModuleList() + self.ffn_layers = nn.ModuleList() + self.norm_layers_2 = nn.ModuleList() + + for idx in range(self.num_layers): + self.attn_layers.append( + RelativePositionMultiHeadAttention( + hidden_channels if idx != 0 else in_channels, + hidden_channels, + num_heads, + rel_attn_window_size=rel_attn_window_size, + dropout_p=dropout_p, + input_length=input_length, + ) + ) + if layer_norm_type == "1": + self.norm_layers_1.append(LayerNorm(hidden_channels)) + elif layer_norm_type == "2": + self.norm_layers_1.append(LayerNorm2(hidden_channels)) + else: + raise ValueError(" [!] Unknown layer norm type") + + if hidden_channels != out_channels and (idx + 1) == self.num_layers: + self.proj = nn.Conv1d(hidden_channels, out_channels, 1) + + self.ffn_layers.append( + FeedForwardNetwork( + hidden_channels, + hidden_channels if (idx + 1) != self.num_layers else out_channels, + hidden_channels_ffn, + kernel_size, + dropout_p=dropout_p, + ) + ) + + if layer_norm_type == "1": + self.norm_layers_2.append(LayerNorm(hidden_channels if (idx + 1) != self.num_layers else out_channels)) + elif layer_norm_type == "2": + self.norm_layers_2.append(LayerNorm2(hidden_channels if (idx + 1) != self.num_layers else out_channels)) + else: + raise ValueError(" [!] Unknown layer norm type") + + def forward(self, x, x_mask, g=None): + """ + Shapes: + - x: :math:`[B, C, T]` + - x_mask: :math:`[B, 1, T]` + - g: :math:`[B, C, T]` + """ + attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1) + for i in range(self.num_layers): + if i == self.cond_layer_idx and self.cond_proj is not None: + g = self.cond_proj(g.transpose(1, 2)) + g = g.transpose(1, 2) + x = x + g + x = x * x_mask + y = self.attn_layers[i](x, x, attn_mask) + y = self.dropout(y) + x = self.norm_layers_1[i](x + y) + + y = self.ffn_layers[i](x, x_mask) + y = self.dropout(y) + + if (i + 1) == self.num_layers and hasattr(self, "proj"): + x = self.proj(x) + + x = self.norm_layers_2[i](x + y) + x = x * x_mask + return x \ No newline at end of file diff --git a/TTS/tts/layers/losses.py b/TTS/tts/layers/losses.py index de5f408c48..75880d4745 100644 --- a/TTS/tts/layers/losses.py +++ b/TTS/tts/layers/losses.py @@ -740,6 +740,39 @@ def forward( return_dict["loss"] = loss return return_dict +class Vits2DurationLoss(nn.Module): + def __init__(self, c: Coqpit): + super().__init__() + self.disc_loss_alpha = c.disc_loss_alpha + + @staticmethod + def discriminator_loss(scores_real, scores_fake): + loss = 0 + real_losses = [] + fake_losses = [] + for dr, dg in zip(scores_real, scores_fake): + dr = dr.float() + dg = dg.float() + real_loss = torch.mean((1 - dr) ** 2) + fake_loss = torch.mean(dg**2) + loss += real_loss + fake_loss + real_losses.append(real_loss.item()) + fake_losses.append(fake_loss.item()) + return loss, real_losses, fake_losses + + def forward(self, scores_disc_real, scores_disc_fake): + loss = 0.0 + return_dict = {} + loss_disc, loss_disc_real, _ = self.discriminator_loss( + scores_real=scores_disc_real, scores_fake=scores_disc_fake + ) + return_dict["loss_dur_disc"] = loss_disc * self.disc_loss_alpha + loss = loss + return_dict["loss_dur_disc"] + return_dict["loss"] = loss + + for i, ldr in enumerate(loss_disc_real): + return_dict[f"loss_dur_disc_real_{i}"] = ldr + return return_dict class VitsDiscriminatorLoss(nn.Module): def __init__(self, c: Coqpit): diff --git a/TTS/tts/layers/vits2/discriminator.py b/TTS/tts/layers/vits2/discriminator.py new file mode 100644 index 0000000000..148f283c90 --- /dev/null +++ b/TTS/tts/layers/vits2/discriminator.py @@ -0,0 +1,89 @@ +import torch +from torch import nn +from torch.nn.modules.conv import Conv1d + +from TTS.vocoder.models.hifigan_discriminator import DiscriminatorP, MultiPeriodDiscriminator + + +class DiscriminatorS(torch.nn.Module): + """HiFiGAN Scale Discriminator. Channel sizes are different from the original HiFiGAN. + + Args: + use_spectral_norm (bool): if `True` swith to spectral norm instead of weight norm. + """ + + def __init__(self, use_spectral_norm=False): + super().__init__() + norm_f = nn.utils.spectral_norm if use_spectral_norm else nn.utils.weight_norm + self.convs = nn.ModuleList( + [ + norm_f(Conv1d(1, 16, 15, 1, padding=7)), + norm_f(Conv1d(16, 64, 41, 4, groups=4, padding=20)), + norm_f(Conv1d(64, 256, 41, 4, groups=16, padding=20)), + norm_f(Conv1d(256, 1024, 41, 4, groups=64, padding=20)), + norm_f(Conv1d(1024, 1024, 41, 4, groups=256, padding=20)), + norm_f(Conv1d(1024, 1024, 5, 1, padding=2)), + ] + ) + self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1)) + + def forward(self, x): + """ + Args: + x (Tensor): input waveform. + + Returns: + Tensor: discriminator scores. + List[Tensor]: list of features from the convolutiona layers. + """ + feat = [] + for l in self.convs: + x = l(x) + x = torch.nn.functional.leaky_relu(x, 0.1) + feat.append(x) + x = self.conv_post(x) + feat.append(x) + x = torch.flatten(x, 1, -1) + return x, feat + + +class VitsDiscriminator(nn.Module): + """VITS discriminator wrapping one Scale Discriminator and a stack of Period Discriminator. + + :: + waveform -> ScaleDiscriminator() -> scores_sd, feats_sd --> append() -> scores, feats + |--> MultiPeriodDiscriminator() -> scores_mpd, feats_mpd ^ + + Args: + use_spectral_norm (bool): if `True` swith to spectral norm instead of weight norm. + """ + + def __init__(self, periods=(2, 3, 5, 7, 11), use_spectral_norm=False): + super().__init__() + self.nets = nn.ModuleList() + self.nets.append(DiscriminatorS(use_spectral_norm=use_spectral_norm)) + self.nets.extend([DiscriminatorP(i, use_spectral_norm=use_spectral_norm) for i in periods]) + + def forward(self, x, x_hat=None): + """ + Args: + x (Tensor): ground truth waveform. + x_hat (Tensor): predicted waveform. + + Returns: + List[Tensor]: discriminator scores. + List[List[Tensor]]: list of list of features from each layers of each discriminator. + """ + x_scores = [] + x_hat_scores = [] if x_hat is not None else None + x_feats = [] + x_hat_feats = [] if x_hat is not None else None + for net in self.nets: + x_score, x_feat = net(x) + x_scores.append(x_score) + x_feats.append(x_feat) + if x_hat is not None: + x_hat_score, x_hat_feat = net(x_hat) + x_hat_scores.append(x_hat_score) + x_hat_feats.append(x_hat_feat) + return x_scores, x_feats, x_hat_scores, x_hat_feats diff --git a/TTS/tts/layers/vits2/duration_discriminator.py b/TTS/tts/layers/vits2/duration_discriminator.py new file mode 100644 index 0000000000..c060a0f248 --- /dev/null +++ b/TTS/tts/layers/vits2/duration_discriminator.py @@ -0,0 +1,84 @@ +import torch +from torch import nn +from TTS.tts.layers.generic.normalization import LayerNorm2 + + +class DurationDiscriminator(nn.Module): #vits2 + """VITS-2 Duration Discriminator. + + :: + dur_r, dur_hat -> DurationDiscriminator() -> output_probs + + Args: + in_channels (int): number of input channels. + filter_channels (int): number of filter channels. + kernel_size (int): kernel size of the convolutional layers. + p_dropout (float): dropout probability. + gin_channels (int): number of global conditioning channels. + Unused for now. + + Returns: + List[Tensor]: list of discriminator scores. Real, Predicted/Generated. + """ + # TODO : not using "spk conditioning" for now according to the paper. + # Can be a better discriminator if we use it. + def __init__( + self, + in_channels, + filter_channels, + kernel_size, + p_dropout, + gin_channels=0 + ): + super().__init__() + + self.in_channels = in_channels + self.filter_channels = filter_channels + self.kernel_size = kernel_size + self.p_dropout = p_dropout + self.gin_channels = gin_channels + + self.drop = nn.Dropout(p_dropout) + self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size, padding=kernel_size//2) + self.conv_2 = nn.Conv1d(filter_channels, filter_channels, kernel_size, padding=kernel_size//2) + self.dur_proj = nn.Conv1d(1, filter_channels, 1) + + self.pre_out_conv_1 = nn.Conv1d(2*filter_channels, filter_channels, kernel_size, padding=kernel_size//2) + self.pre_out_norm_1 = LayerNorm2(filter_channels) + self.pre_out_conv_2 = nn.Conv1d(filter_channels, filter_channels, kernel_size, padding=kernel_size//2) + self.pre_out_norm_2 = LayerNorm2(filter_channels) + + # if gin_channels != 0: + # self.cond = nn.Conv1d(gin_channels, in_channels, 1) + + self.output_layer = nn.Sequential( + nn.Linear(filter_channels, 1), + nn.Sigmoid() + ) + + def forward_probability(self, x, x_mask, dur, g=None): + dur = self.dur_proj(dur) + x = torch.cat([x, dur], dim=1) + x = self.pre_out_conv_1(x * x_mask) + x = self.pre_out_conv_2(x * x_mask) + x = x * x_mask + x = x.transpose(1, 2) + output_prob = self.output_layer(x) + return output_prob + + def forward(self, x, x_mask, dur_r, dur_hat, g=None): + x = torch.detach(x) + # if g is not None: + # g = torch.detach(g) + # x = x + self.cond(g) + x = self.conv_1(x * x_mask) + # x = self.drop(x) + x = self.conv_2(x * x_mask) + # x = self.drop(x) + + output_probs = [] + for dur in [dur_r, dur_hat]: + output_prob = self.forward_probability(x, x_mask, dur, g) + output_probs.append(output_prob) + + return output_probs \ No newline at end of file diff --git a/TTS/tts/layers/vits2/networks.py b/TTS/tts/layers/vits2/networks.py new file mode 100644 index 0000000000..9c6bb3dd7c --- /dev/null +++ b/TTS/tts/layers/vits2/networks.py @@ -0,0 +1,382 @@ +import math + +import torch +from torch import nn + +from TTS.tts.layers.generic.wavenet import WN +from TTS.tts.layers.vits2.transformer import ConditionalRelativePositionTransformer, RelativePositionTransformer +from TTS.tts.utils.helpers import sequence_mask + +LRELU_SLOPE = 0.1 + + +def convert_pad_shape(pad_shape): + l = pad_shape[::-1] + pad_shape = [item for sublist in l for item in sublist] + return pad_shape + + +def init_weights(m, mean=0.0, std=0.01): + classname = m.__class__.__name__ + if classname.find("Conv") != -1: + m.weight.data.normal_(mean, std) + + +def get_padding(kernel_size, dilation=1): + return int((kernel_size * dilation - dilation) / 2) + + +class TextEncoder(nn.Module): + def __init__( + self, + n_vocab: int, + out_channels: int, + hidden_channels: int, + hidden_channels_ffn: int, + num_heads: int, + num_layers: int, + kernel_size: int, + dropout_p: float, + language_emb_dim: int = None, + speaker_emb_dim: int = None, + speaker_emb_layer_idx: int = None + ): + """Text Encoder for VITS-2 model. + + Args: + n_vocab (int): Number of characters for the embedding layer. + out_channels (int): Number of channels for the output. + hidden_channels (int): Number of channels for the hidden layers. + hidden_channels_ffn (int): Number of channels for the convolutional layers. + num_heads (int): Number of attention heads for the Transformer layers. + num_layers (int): Number of Transformer layers. + kernel_size (int): Kernel size for the FFN layers in Transformer network. + dropout_p (float): Dropout rate for the Transformer layers. + """ + super().__init__() + self.out_channels = out_channels + self.hidden_channels = hidden_channels + + self.emb = nn.Embedding(n_vocab, hidden_channels) + + nn.init.normal_(self.emb.weight, 0.0, hidden_channels**-0.5) + + if language_emb_dim: + hidden_channels += language_emb_dim + + self.speaker_conditioning = False + if speaker_emb_dim: + speaker_emb_layer_idx = speaker_emb_layer_idx if speaker_emb_layer_idx is not None else 2 + assert speaker_emb_layer_idx < num_layers, "speaker_emb_layer_idx should be less than num_layers" + assert speaker_emb_dim > 0, "speaker_emb_dim should be greater than 0" + self.speaker_conditioning = True + + self.encoder = ConditionalRelativePositionTransformer( + in_channels=hidden_channels, + out_channels=hidden_channels, + hidden_channels=hidden_channels, + hidden_channels_ffn=hidden_channels_ffn, + num_heads=num_heads, + num_layers=num_layers, + kernel_size=kernel_size, + dropout_p=dropout_p, + layer_norm_type="2", + rel_attn_window_size=4, + cond_channels=speaker_emb_dim, + cond_layer_idx=speaker_emb_layer_idx + ) + + self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1) + + def forward(self, x, x_lengths, lang_emb=None, speaker_emb=None): + """ + Shapes: + - x: :math:`[B, T]` + - x_length: :math:`[B]` + """ + assert x.shape[0] == x_lengths.shape[0] + x = self.emb(x) * math.sqrt(self.hidden_channels) # [b, t, h] + + # concat the lang emb in embedding chars + if lang_emb is not None: + x = torch.cat((x, lang_emb.transpose(2, 1).expand(x.size(0), x.size(1), -1)), dim=-1) + + if speaker_emb is None and self.speaker_conditioning: + raise ValueError("speaker_emb is None but speaker conditioning is enabled") + if speaker_emb is not None and not self.speaker_conditioning: + raise ValueError("speaker_emb is not None but speaker conditioning is disabled") + + x = torch.transpose(x, 1, -1) # [b, h, t] + x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype) # [b, 1, t] + + x = self.encoder(x * x_mask, x_mask, g=speaker_emb) # [b, h, t] + stats = self.proj(x) * x_mask + + m, logs = torch.split(stats, self.out_channels, dim=1) + return x, m, logs, x_mask + + +class ResidualCouplingTransformerLayer(nn.Module): + def __init__( + self, + channels, + hidden_channels, + kernel_size=3, + num_layers=2, + dropout_p=0.1, + mean_only=False, + ): + assert channels % 2 == 0, "channels should be divisible by 2" + super().__init__() + self.half_channels = channels // 2 + self.mean_only = mean_only + # input layer + self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1) + # coupling layers + self.pre_transformer = RelativePositionTransformer( + in_channels=self.half_channels, + out_channels=self.half_channels, + hidden_channels=hidden_channels, + hidden_channels_ffn=768, + num_heads=2, + num_layers=num_layers, + kernel_size=kernel_size, + dropout_p=dropout_p, + layer_norm_type="2", + rel_attn_window_size=None + ) + # output layer + # Initializing last layer to 0 makes the affine coupling layers + # do nothing at first. This helps with training stability + self.post = nn.Conv1d(hidden_channels, self.half_channels * (2 - mean_only), 1) + self.post.weight.data.zero_() + self.post.bias.data.zero_() + + def forward(self, x, x_mask, g=None, reverse=False): + """ + Note: + Set `reverse` to True for inference. + + Shapes: + - x: :math:`[B, C, T]` + - x_mask: :math:`[B, 1, T]` + - g: :math:`[B, C, 1]` + """ + x0, x1 = torch.split(x, [self.half_channels] * 2, 1) + x0_ = self.pre_transformer(x0 * x_mask, x_mask) + x0_ = x0_ + x0 + h = self.pre(x0_) * x_mask + stats = self.post(h) * x_mask + if not self.mean_only: + m, log_scale = torch.split(stats, [self.half_channels] * 2, 1) + else: + m = stats + log_scale = torch.zeros_like(m) + + if not reverse: + x1 = m + x1 * torch.exp(log_scale) * x_mask + x = torch.cat([x0, x1], 1) + logdet = torch.sum(log_scale, [1, 2]) + return x, logdet + else: + x1 = (x1 - m) * torch.exp(-log_scale) * x_mask + x = torch.cat([x0, x1], 1) + return x + +class ResidualCouplingBlock(nn.Module): + def __init__( + self, + channels, + hidden_channels, + kernel_size, + dilation_rate, + num_layers, + dropout_p=0, + cond_channels=0, + mean_only=False, + ): + assert channels % 2 == 0, "channels should be divisible by 2" + super().__init__() + self.half_channels = channels // 2 + self.mean_only = mean_only + # input layer + self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1) + # coupling layers + self.enc = WN( + hidden_channels, + hidden_channels, + kernel_size, + dilation_rate, + num_layers, + dropout_p=dropout_p, + c_in_channels=cond_channels, + ) + # output layer + # Initializing last layer to 0 makes the affine coupling layers + # do nothing at first. This helps with training stability + self.post = nn.Conv1d(hidden_channels, self.half_channels * (2 - mean_only), 1) + self.post.weight.data.zero_() + self.post.bias.data.zero_() + + def forward(self, x, x_mask, g=None, reverse=False): + """ + Note: + Set `reverse` to True for inference. + + Shapes: + - x: :math:`[B, C, T]` + - x_mask: :math:`[B, 1, T]` + - g: :math:`[B, C, 1]` + """ + x0, x1 = torch.split(x, [self.half_channels] * 2, 1) + h = self.pre(x0) * x_mask + h = self.enc(h, x_mask, g=g) + stats = self.post(h) * x_mask + if not self.mean_only: + m, log_scale = torch.split(stats, [self.half_channels] * 2, 1) + else: + m = stats + log_scale = torch.zeros_like(m) + + if not reverse: + x1 = m + x1 * torch.exp(log_scale) * x_mask + x = torch.cat([x0, x1], 1) + logdet = torch.sum(log_scale, [1, 2]) + return x, logdet + else: + x1 = (x1 - m) * torch.exp(-log_scale) * x_mask + x = torch.cat([x0, x1], 1) + return x + +class ResidualCouplingBlocks(nn.Module): + def __init__( + self, + channels: int, + hidden_channels: int, + kernel_size: int, + dilation_rate: int, + num_layers: int, + num_flows=4, + cond_channels=0, + use_transformer_flow_layer=True, + ): + """Redisual Coupling blocks for VITS-2 flow layers. + + Args: + channels (int): Number of input and output tensor channels. + hidden_channels (int): Number of hidden network channels. + kernel_size (int): Kernel size of the WaveNet layers. + dilation_rate (int): Dilation rate of the WaveNet layers. + num_layers (int): Number of the WaveNet layers. + num_flows (int, optional): Number of Residual Coupling blocks. Defaults to 4. + cond_channels (int, optional): Number of channels of the conditioning tensor. Defaults to 0. + use_transformer_flow_layer (bool, optional): Use Transformer flow layer. Defaults to True. + """ + super().__init__() + self.channels = channels + self.hidden_channels = hidden_channels + self.kernel_size = kernel_size + self.dilation_rate = dilation_rate + self.num_layers = num_layers + self.num_flows = num_flows + self.cond_channels = cond_channels + + self.flows = nn.ModuleList() + for _ in range(num_flows): + self.flows.append( + ResidualCouplingBlock( + channels, + hidden_channels, + kernel_size, + dilation_rate, + num_layers, + cond_channels=cond_channels, + mean_only=True, + ) + ) + self.use_transformer_flow_layer = use_transformer_flow_layer + if self.use_transformer_flow_layer: + self.TransformerFlowLayer = ResidualCouplingTransformerLayer( + channels=channels, + hidden_channels=hidden_channels, + kernel_size=kernel_size, + num_layers=num_layers, + mean_only=True, + ) + self.flows.append(self.TransformerFlowLayer) + + def forward(self, x, x_mask, g=None, reverse=False): + """ + Note: + Set `reverse` to True for inference. + + Shapes: + - x: :math:`[B, C, T]` + - x_mask: :math:`[B, 1, T]` + - g: :math:`[B, C, 1]` + """ + if not reverse: + for flow in self.flows: + x, _ = flow(x, x_mask, g=g, reverse=reverse) + x = torch.flip(x, [1]) + else: + for flow in reversed(self.flows): + x = torch.flip(x, [1]) + x = flow(x, x_mask, g=g, reverse=reverse) + return x + + +class PosteriorEncoder(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + hidden_channels: int, + kernel_size: int, + dilation_rate: int, + num_layers: int, + cond_channels=0, + ): + """Posterior Encoder of VITS-2 model. + + :: + x -> conv1x1() -> WaveNet() (non-causal) -> conv1x1() -> split() -> [m, s] -> sample(m, s) -> z + + Args: + in_channels (int): Number of input tensor channels. + out_channels (int): Number of output tensor channels. + hidden_channels (int): Number of hidden channels. + kernel_size (int): Kernel size of the WaveNet convolution layers. + dilation_rate (int): Dilation rate of the WaveNet layers. + num_layers (int): Number of the WaveNet layers. + cond_channels (int, optional): Number of conditioning tensor channels. Defaults to 0. + """ + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.hidden_channels = hidden_channels + self.kernel_size = kernel_size + self.dilation_rate = dilation_rate + self.num_layers = num_layers + self.cond_channels = cond_channels + + self.pre = nn.Conv1d(in_channels, hidden_channels, 1) + self.enc = WN( + hidden_channels, hidden_channels, kernel_size, dilation_rate, num_layers, c_in_channels=cond_channels + ) + self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1) + + def forward(self, x, x_lengths, g=None): + """ + Shapes: + - x: :math:`[B, C, T]` + - x_lengths: :math:`[B, 1]` + - g: :math:`[B, C, 1]` + """ + x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype) + x = self.pre(x) * x_mask + x = self.enc(x, x_mask, g=g) + stats = self.proj(x) * x_mask + mean, log_scale = torch.split(stats, self.out_channels, dim=1) + z = (mean + torch.randn_like(mean) * torch.exp(log_scale)) * x_mask + return z, mean, log_scale, x_mask diff --git a/TTS/tts/layers/vits2/stochastic_duration_predictor.py b/TTS/tts/layers/vits2/stochastic_duration_predictor.py new file mode 100644 index 0000000000..98dbf0935c --- /dev/null +++ b/TTS/tts/layers/vits2/stochastic_duration_predictor.py @@ -0,0 +1,294 @@ +import math + +import torch +from torch import nn +from torch.nn import functional as F + +from TTS.tts.layers.generic.normalization import LayerNorm2 +from TTS.tts.layers.vits.transforms import piecewise_rational_quadratic_transform + + +class DilatedDepthSeparableConv(nn.Module): + def __init__(self, channels, kernel_size, num_layers, dropout_p=0.0) -> torch.tensor: + """Dilated Depth-wise Separable Convolution module. + + :: + x |-> DDSConv(x) -> LayerNorm(x) -> GeLU(x) -> Conv1x1(x) -> LayerNorm(x) -> GeLU(x) -> + -> o + |-------------------------------------------------------------------------------------^ + + Args: + channels ([type]): [description] + kernel_size ([type]): [description] + num_layers ([type]): [description] + dropout_p (float, optional): [description]. Defaults to 0.0. + + Returns: + torch.tensor: Network output masked by the input sequence mask. + """ + super().__init__() + self.num_layers = num_layers + + self.convs_sep = nn.ModuleList() + self.convs_1x1 = nn.ModuleList() + self.norms_1 = nn.ModuleList() + self.norms_2 = nn.ModuleList() + for i in range(num_layers): + dilation = kernel_size**i + padding = (kernel_size * dilation - dilation) // 2 + self.convs_sep.append( + nn.Conv1d(channels, channels, kernel_size, groups=channels, dilation=dilation, padding=padding) + ) + self.convs_1x1.append(nn.Conv1d(channels, channels, 1)) + self.norms_1.append(LayerNorm2(channels)) + self.norms_2.append(LayerNorm2(channels)) + self.dropout = nn.Dropout(dropout_p) + + def forward(self, x, x_mask, g=None): + """ + Shapes: + - x: :math:`[B, C, T]` + - x_mask: :math:`[B, 1, T]` + """ + if g is not None: + x = x + g + for i in range(self.num_layers): + y = self.convs_sep[i](x * x_mask) + y = self.norms_1[i](y) + y = F.gelu(y) + y = self.convs_1x1[i](y) + y = self.norms_2[i](y) + y = F.gelu(y) + y = self.dropout(y) + x = x + y + return x * x_mask + + +class ElementwiseAffine(nn.Module): + """Element-wise affine transform like no-population stats BatchNorm alternative. + + Args: + channels (int): Number of input tensor channels. + """ + + def __init__(self, channels): + super().__init__() + self.translation = nn.Parameter(torch.zeros(channels, 1)) + self.log_scale = nn.Parameter(torch.zeros(channels, 1)) + + def forward(self, x, x_mask, reverse=False, **kwargs): # pylint: disable=unused-argument + if not reverse: + y = (x * torch.exp(self.log_scale) + self.translation) * x_mask + logdet = torch.sum(self.log_scale * x_mask, [1, 2]) + return y, logdet + x = (x - self.translation) * torch.exp(-self.log_scale) * x_mask + return x + + +class ConvFlow(nn.Module): + """Dilated depth separable convolutional based spline flow. + + Args: + in_channels (int): Number of input tensor channels. + hidden_channels (int): Number of in network channels. + kernel_size (int): Convolutional kernel size. + num_layers (int): Number of convolutional layers. + num_bins (int, optional): Number of spline bins. Defaults to 10. + tail_bound (float, optional): Tail bound for PRQT. Defaults to 5.0. + """ + + def __init__( + self, + in_channels: int, + hidden_channels: int, + kernel_size: int, + num_layers: int, + num_bins=10, + tail_bound=5.0, + ): + super().__init__() + self.num_bins = num_bins + self.tail_bound = tail_bound + self.hidden_channels = hidden_channels + self.half_channels = in_channels // 2 + + self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1) + self.convs = DilatedDepthSeparableConv(hidden_channels, kernel_size, num_layers, dropout_p=0.0) + self.proj = nn.Conv1d(hidden_channels, self.half_channels * (num_bins * 3 - 1), 1) + self.proj.weight.data.zero_() + self.proj.bias.data.zero_() + + def forward(self, x, x_mask, g=None, reverse=False): + x0, x1 = torch.split(x, [self.half_channels] * 2, 1) + h = self.pre(x0) + h = self.convs(h, x_mask, g=g) + h = self.proj(h) * x_mask + + b, c, t = x0.shape + h = h.reshape(b, c, -1, t).permute(0, 1, 3, 2) # [b, cx?, t] -> [b, c, t, ?] + + unnormalized_widths = h[..., : self.num_bins] / math.sqrt(self.hidden_channels) + unnormalized_heights = h[..., self.num_bins : 2 * self.num_bins] / math.sqrt(self.hidden_channels) + unnormalized_derivatives = h[..., 2 * self.num_bins :] + + x1, logabsdet = piecewise_rational_quadratic_transform( + x1, + unnormalized_widths, + unnormalized_heights, + unnormalized_derivatives, + inverse=reverse, + tails="linear", + tail_bound=self.tail_bound, + ) + + x = torch.cat([x0, x1], 1) * x_mask + logdet = torch.sum(logabsdet * x_mask, [1, 2]) + if not reverse: + return x, logdet + return x + + +class StochasticDurationPredictor(nn.Module): + """Stochastic duration predictor with Spline Flows. + + It applies Variational Dequantization and Variational Data Augmentation. + + Paper: + SDP: https://arxiv.org/pdf/2106.06103.pdf + Spline Flow: https://arxiv.org/abs/1906.04032 + + :: + ## Inference + + x -> TextCondEncoder() -> Flow() -> dr_hat + noise ----------------------^ + + ## Training + |---------------------| + x -> TextCondEncoder() -> + -> PosteriorEncoder() -> split() -> z_u, z_v -> (d - z_u) -> concat() -> Flow() -> noise + d -> DurCondEncoder() -> ^ | + |------------------------------------------------------------------------------| + + Args: + in_channels (int): Number of input tensor channels. + hidden_channels (int): Number of hidden channels. + kernel_size (int): Kernel size of convolutional layers. + dropout_p (float): Dropout rate. + num_flows (int, optional): Number of flow blocks. Defaults to 4. + cond_channels (int, optional): Number of channels of conditioning tensor. Defaults to 0. + """ + + def __init__( + self, + in_channels: int, + hidden_channels: int, + kernel_size: int, + dropout_p: float, + num_flows=4, + cond_channels=0, + language_emb_dim=0, + ): + super().__init__() + + # add language embedding dim in the input + if language_emb_dim: + in_channels += language_emb_dim + + # condition encoder text + self.pre = nn.Conv1d(in_channels, hidden_channels, 1) + self.convs = DilatedDepthSeparableConv(hidden_channels, kernel_size, num_layers=3, dropout_p=dropout_p) + self.proj = nn.Conv1d(hidden_channels, hidden_channels, 1) + + # posterior encoder + self.flows = nn.ModuleList() + self.flows.append(ElementwiseAffine(2)) + self.flows += [ConvFlow(2, hidden_channels, kernel_size, num_layers=3) for _ in range(num_flows)] + + # condition encoder duration + self.post_pre = nn.Conv1d(1, hidden_channels, 1) + self.post_convs = DilatedDepthSeparableConv(hidden_channels, kernel_size, num_layers=3, dropout_p=dropout_p) + self.post_proj = nn.Conv1d(hidden_channels, hidden_channels, 1) + + # flow layers + self.post_flows = nn.ModuleList() + self.post_flows.append(ElementwiseAffine(2)) + self.post_flows += [ConvFlow(2, hidden_channels, kernel_size, num_layers=3) for _ in range(num_flows)] + + if cond_channels != 0 and cond_channels is not None: + self.cond = nn.Conv1d(cond_channels, hidden_channels, 1) + + if language_emb_dim != 0 and language_emb_dim is not None: + self.cond_lang = nn.Conv1d(language_emb_dim, hidden_channels, 1) + + def forward(self, x, x_mask, dr=None, g=None, lang_emb=None, reverse=False, noise_scale=1.0): + """ + Shapes: + - x: :math:`[B, C, T]` + - x_mask: :math:`[B, 1, T]` + - dr: :math:`[B, 1, T]` + - g: :math:`[B, C]` + """ + # condition encoder text + x = self.pre(x) + if g is not None: + x = x + self.cond(g) + + if lang_emb is not None: + x = x + self.cond_lang(lang_emb) + + x = self.convs(x, x_mask) + x = self.proj(x) * x_mask + + if not reverse: + flows = self.flows + assert dr is not None + + # condition encoder duration + h = self.post_pre(dr) + h = self.post_convs(h, x_mask) + h = self.post_proj(h) * x_mask + noise = torch.randn(dr.size(0), 2, dr.size(2)).to(device=x.device, dtype=x.dtype) * x_mask + z_q = noise + + # posterior encoder + logdet_tot_q = 0.0 + for idx, flow in enumerate(self.post_flows): + z_q, logdet_q = flow(z_q, x_mask, g=(x + h)) + logdet_tot_q = logdet_tot_q + logdet_q + if idx > 0: + z_q = torch.flip(z_q, [1]) + + z_u, z_v = torch.split(z_q, [1, 1], 1) + u = torch.sigmoid(z_u) * x_mask + z0 = (dr - u) * x_mask + + # posterior encoder - neg log likelihood + logdet_tot_q += torch.sum((F.logsigmoid(z_u) + F.logsigmoid(-z_u)) * x_mask, [1, 2]) + nll_posterior_encoder = ( + torch.sum(-0.5 * (math.log(2 * math.pi) + (noise**2)) * x_mask, [1, 2]) - logdet_tot_q + ) + + z0 = torch.log(torch.clamp_min(z0, 1e-5)) * x_mask + logdet_tot = torch.sum(-z0, [1, 2]) + z = torch.cat([z0, z_v], 1) + + # flow layers + for idx, flow in enumerate(flows): + z, logdet = flow(z, x_mask, g=x, reverse=reverse) + logdet_tot = logdet_tot + logdet + if idx > 0: + z = torch.flip(z, [1]) + + # flow layers - neg log likelihood + nll_flow_layers = torch.sum(0.5 * (math.log(2 * math.pi) + (z**2)) * x_mask, [1, 2]) - logdet_tot + return nll_flow_layers + nll_posterior_encoder + + flows = list(reversed(self.flows)) + flows = flows[:-2] + [flows[-1]] # remove a useless vflow + z = torch.randn(x.size(0), 2, x.size(2)).to(device=x.device, dtype=x.dtype) * noise_scale + for flow in flows: + z = torch.flip(z, [1]) + z = flow(z, x_mask, g=x, reverse=reverse) + + z0, _ = torch.split(z, [1, 1], 1) + logw = z0 + return logw diff --git a/TTS/tts/layers/vits2/transformer.py b/TTS/tts/layers/vits2/transformer.py new file mode 100644 index 0000000000..2324c6253f --- /dev/null +++ b/TTS/tts/layers/vits2/transformer.py @@ -0,0 +1,562 @@ +import math + +import torch +from torch import nn +from torch.nn import functional as F + +from TTS.tts.layers.generic.normalization import LayerNorm, LayerNorm2 + + +class RelativePositionMultiHeadAttention(nn.Module): + """Multi-head attention with Relative Positional embedding. + https://arxiv.org/pdf/1809.04281.pdf + + It learns positional embeddings for a window of neighbours. For keys and values, + it learns different set of embeddings. Key embeddings are agregated with the attention + scores and value embeddings are aggregated with the output. + + Note: + Example with relative attention window size 2 + + - input = [a, b, c, d, e] + - rel_attn_embeddings = [e(t-2), e(t-1), e(t+1), e(t+2)] + + So it learns 4 embedding vectors (in total 8) separately for key and value vectors. + + Considering the input c + + - e(t-2) corresponds to c -> a + - e(t-2) corresponds to c -> b + - e(t-2) corresponds to c -> d + - e(t-2) corresponds to c -> e + + These embeddings are shared among different time steps. So input a, b, d and e also uses + the same embeddings. + + Embeddings are ignored when the relative window is out of limit for the first and the last + n items. + + Args: + channels (int): input and inner layer channels. + out_channels (int): output channels. + num_heads (int): number of attention heads. + rel_attn_window_size (int, optional): relation attention window size. + If 4, for each time step next and previous 4 time steps are attended. + If default, relative encoding is disabled and it is a regular transformer. + Defaults to None. + heads_share (bool, optional): [description]. Defaults to True. + dropout_p (float, optional): dropout rate. Defaults to 0.. + input_length (int, optional): intput length for positional encoding. Defaults to None. + proximal_bias (bool, optional): enable/disable proximal bias as in the paper. Defaults to False. + proximal_init (bool, optional): enable/disable poximal init as in the paper. + Init key and query layer weights the same. Defaults to False. + """ + + def __init__( + self, + channels, + out_channels, + num_heads, + rel_attn_window_size=None, + heads_share=True, + dropout_p=0.0, + input_length=None, + proximal_bias=False, + proximal_init=False, + ): + super().__init__() + assert channels % num_heads == 0, " [!] channels should be divisible by num_heads." + # class attributes + self.channels = channels + self.out_channels = out_channels + self.num_heads = num_heads + self.rel_attn_window_size = rel_attn_window_size + self.heads_share = heads_share + self.input_length = input_length + self.proximal_bias = proximal_bias + self.dropout_p = dropout_p + self.attn = None + # query, key, value layers + self.k_channels = channels // num_heads + self.conv_q = nn.Conv1d(channels, channels, 1) + self.conv_k = nn.Conv1d(channels, channels, 1) + self.conv_v = nn.Conv1d(channels, channels, 1) + # output layers + self.conv_o = nn.Conv1d(channels, out_channels, 1) + self.dropout = nn.Dropout(dropout_p) + # relative positional encoding layers + if rel_attn_window_size is not None: + n_heads_rel = 1 if heads_share else num_heads + rel_stddev = self.k_channels**-0.5 + emb_rel_k = nn.Parameter( + torch.randn(n_heads_rel, rel_attn_window_size * 2 + 1, self.k_channels) * rel_stddev + ) + emb_rel_v = nn.Parameter( + torch.randn(n_heads_rel, rel_attn_window_size * 2 + 1, self.k_channels) * rel_stddev + ) + self.register_parameter("emb_rel_k", emb_rel_k) + self.register_parameter("emb_rel_v", emb_rel_v) + + # init layers + nn.init.xavier_uniform_(self.conv_q.weight) + nn.init.xavier_uniform_(self.conv_k.weight) + # proximal bias + if proximal_init: + self.conv_k.weight.data.copy_(self.conv_q.weight.data) + self.conv_k.bias.data.copy_(self.conv_q.bias.data) + nn.init.xavier_uniform_(self.conv_v.weight) + + def forward(self, x, c, attn_mask=None): + """ + Shapes: + - x: :math:`[B, C, T]` + - c: :math:`[B, C, T]` + - attn_mask: :math:`[B, 1, T, T]` + """ + q = self.conv_q(x) + k = self.conv_k(c) + v = self.conv_v(c) + x, self.attn = self.attention(q, k, v, mask=attn_mask) + x = self.conv_o(x) + return x + + def attention(self, query, key, value, mask=None): + # reshape [b, d, t] -> [b, n_h, t, d_k] + b, d, t_s, t_t = (*key.size(), query.size(2)) + query = query.view(b, self.num_heads, self.k_channels, t_t).transpose(2, 3) + key = key.view(b, self.num_heads, self.k_channels, t_s).transpose(2, 3) + value = value.view(b, self.num_heads, self.k_channels, t_s).transpose(2, 3) + # compute raw attention scores + scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(self.k_channels) + # relative positional encoding for scores + if self.rel_attn_window_size is not None: + assert t_s == t_t, "Relative attention is only available for self-attention." + # get relative key embeddings + key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s) + rel_logits = self._matmul_with_relative_keys(query, key_relative_embeddings) + rel_logits = self._relative_position_to_absolute_position(rel_logits) + scores_local = rel_logits / math.sqrt(self.k_channels) + scores = scores + scores_local + # proximan bias + if self.proximal_bias: + assert t_s == t_t, "Proximal bias is only available for self-attention." + scores = scores + self._attn_proximity_bias(t_s).to(device=scores.device, dtype=scores.dtype) + # attention score masking + if mask is not None: + # add small value to prevent oor error. + scores = scores.masked_fill(mask == 0, -1e4) + if self.input_length is not None: + block_mask = torch.ones_like(scores).triu(-1 * self.input_length).tril(self.input_length) + scores = scores * block_mask + -1e4 * (1 - block_mask) + # attention score normalization + p_attn = F.softmax(scores, dim=-1) # [b, n_h, t_t, t_s] + # apply dropout to attention weights + p_attn = self.dropout(p_attn) + # compute output + output = torch.matmul(p_attn, value) + # relative positional encoding for values + if self.rel_attn_window_size is not None: + relative_weights = self._absolute_position_to_relative_position(p_attn) + value_relative_embeddings = self._get_relative_embeddings(self.emb_rel_v, t_s) + output = output + self._matmul_with_relative_values(relative_weights, value_relative_embeddings) + output = output.transpose(2, 3).contiguous().view(b, d, t_t) # [b, n_h, t_t, d_k] -> [b, d, t_t] + return output, p_attn + + @staticmethod + def _matmul_with_relative_values(p_attn, re): + """ + Args: + p_attn (Tensor): attention weights. + re (Tensor): relative value embedding vector. (a_(i,j)^V) + + Shapes: + -p_attn: :math:`[B, H, T, V]` + -re: :math:`[H or 1, V, D]` + -logits: :math:`[B, H, T, D]` + """ + logits = torch.matmul(p_attn, re.unsqueeze(0)) + return logits + + @staticmethod + def _matmul_with_relative_keys(query, re): + """ + Args: + query (Tensor): batch of query vectors. (x*W^Q) + re (Tensor): relative key embedding vector. (a_(i,j)^K) + + Shapes: + - query: :math:`[B, H, T, D]` + - re: :math:`[H or 1, V, D]` + - logits: :math:`[B, H, T, V]` + """ + # logits = torch.einsum('bhld, kmd -> bhlm', [query, re.to(query.dtype)]) + logits = torch.matmul(query, re.unsqueeze(0).transpose(-2, -1)) + return logits + + def _get_relative_embeddings(self, relative_embeddings, length): + """Convert embedding vestors to a tensor of embeddings""" + # Pad first before slice to avoid using cond ops. + pad_length = max(length - (self.rel_attn_window_size + 1), 0) + slice_start_position = max((self.rel_attn_window_size + 1) - length, 0) + slice_end_position = slice_start_position + 2 * length - 1 + if pad_length > 0: + padded_relative_embeddings = F.pad(relative_embeddings, [0, 0, pad_length, pad_length, 0, 0]) + else: + padded_relative_embeddings = relative_embeddings + used_relative_embeddings = padded_relative_embeddings[:, slice_start_position:slice_end_position] + return used_relative_embeddings + + @staticmethod + def _relative_position_to_absolute_position(x): + """Converts tensor from relative to absolute indexing for local attention. + Shapes: + x: :math:`[B, C, T, 2 * T - 1]` + Returns: + A Tensor of shape :math:`[B, C, T, T]` + """ + batch, heads, length, _ = x.size() + # Pad to shift from relative to absolute indexing. + x = F.pad(x, [0, 1, 0, 0, 0, 0, 0, 0]) + # Pad extra elements so to add up to shape (len+1, 2*len-1). + x_flat = x.view([batch, heads, length * 2 * length]) + x_flat = F.pad(x_flat, [0, length - 1, 0, 0, 0, 0]) + # Reshape and slice out the padded elements. + x_final = x_flat.view([batch, heads, length + 1, 2 * length - 1])[:, :, :length, length - 1 :] + return x_final + + @staticmethod + def _absolute_position_to_relative_position(x): + """ + Shapes: + - x: :math:`[B, C, T, T]` + - ret: :math:`[B, C, T, 2*T-1]` + """ + batch, heads, length, _ = x.size() + # padd along column + x = F.pad(x, [0, length - 1, 0, 0, 0, 0, 0, 0]) + x_flat = x.view([batch, heads, length**2 + length * (length - 1)]) + # add 0's in the beginning that will skew the elements after reshape + x_flat = F.pad(x_flat, [length, 0, 0, 0, 0, 0]) + x_final = x_flat.view([batch, heads, length, 2 * length])[:, :, :, 1:] + return x_final + + @staticmethod + def _attn_proximity_bias(length): + """Produce an attention mask that discourages distant + attention values. + Args: + length (int): an integer scalar. + Returns: + a Tensor with shape :math:`[1, 1, T, T]` + """ + # L + r = torch.arange(length, dtype=torch.float32) + # L x L + diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1) + # scale mask values + diff = -torch.log1p(torch.abs(diff)) + # 1 x 1 x L x L + return diff.unsqueeze(0).unsqueeze(0) + + +class FeedForwardNetwork(nn.Module): + """Feed Forward Inner layers for Transformer. + + Args: + in_channels (int): input tensor channels. + out_channels (int): output tensor channels. + hidden_channels (int): inner layers hidden channels. + kernel_size (int): conv1d filter kernel size. + dropout_p (float, optional): dropout rate. Defaults to 0. + """ + + def __init__(self, in_channels, out_channels, hidden_channels, kernel_size, dropout_p=0.0, causal=False): + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.hidden_channels = hidden_channels + self.kernel_size = kernel_size + self.dropout_p = dropout_p + + if causal: + self.padding = self._causal_padding + else: + self.padding = self._same_padding + + self.conv_1 = nn.Conv1d(in_channels, hidden_channels, kernel_size) + self.conv_2 = nn.Conv1d(hidden_channels, out_channels, kernel_size) + self.dropout = nn.Dropout(dropout_p) + + def forward(self, x, x_mask): + x = self.conv_1(self.padding(x * x_mask)) + x = torch.relu(x) + x = self.dropout(x) + x = self.conv_2(self.padding(x * x_mask)) + return x * x_mask + + def _causal_padding(self, x): + if self.kernel_size == 1: + return x + pad_l = self.kernel_size - 1 + pad_r = 0 + padding = [[0, 0], [0, 0], [pad_l, pad_r]] + x = F.pad(x, self._pad_shape(padding)) + return x + + def _same_padding(self, x): + if self.kernel_size == 1: + return x + pad_l = (self.kernel_size - 1) // 2 + pad_r = self.kernel_size // 2 + padding = [[0, 0], [0, 0], [pad_l, pad_r]] + x = F.pad(x, self._pad_shape(padding)) + return x + + @staticmethod + def _pad_shape(padding): + l = padding[::-1] + pad_shape = [item for sublist in l for item in sublist] + return pad_shape + + +class RelativePositionTransformer(nn.Module): + """Transformer with Relative Potional Encoding. + https://arxiv.org/abs/1803.02155 + + Args: + in_channels (int): number of channels of the input tensor. + out_chanels (int): number of channels of the output tensor. + hidden_channels (int): model hidden channels. + hidden_channels_ffn (int): hidden channels of FeedForwardNetwork. + num_heads (int): number of attention heads. + num_layers (int): number of transformer layers. + kernel_size (int, optional): kernel size of feed-forward inner layers. Defaults to 1. + dropout_p (float, optional): dropout rate for self-attention and feed-forward inner layers_per_stack. Defaults to 0. + rel_attn_window_size (int, optional): relation attention window size. + If 4, for each time step next and previous 4 time steps are attended. + If default, relative encoding is disabled and it is a regular transformer. + Defaults to None. + input_length (int, optional): input lenght to limit position encoding. Defaults to None. + layer_norm_type (str, optional): type "1" uses torch tensor operations and type "2" uses torch layer_norm + primitive. Use type "2", type "1: is for backward compat. Defaults to "1". + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + hidden_channels: int, + hidden_channels_ffn: int, + num_heads: int, + num_layers: int, + kernel_size=1, + dropout_p=0.0, + rel_attn_window_size: int = None, + input_length: int = None, + layer_norm_type: str = "1", + ): + super().__init__() + self.hidden_channels = hidden_channels + self.hidden_channels_ffn = hidden_channels_ffn + self.num_heads = num_heads + self.num_layers = num_layers + self.kernel_size = kernel_size + self.dropout_p = dropout_p + self.rel_attn_window_size = rel_attn_window_size + + self.dropout = nn.Dropout(dropout_p) + self.attn_layers = nn.ModuleList() + self.norm_layers_1 = nn.ModuleList() + self.ffn_layers = nn.ModuleList() + self.norm_layers_2 = nn.ModuleList() + + for idx in range(self.num_layers): + self.attn_layers.append( + RelativePositionMultiHeadAttention( + hidden_channels if idx != 0 else in_channels, + hidden_channels, + num_heads, + rel_attn_window_size=rel_attn_window_size, + dropout_p=dropout_p, + input_length=input_length, + ) + ) + if layer_norm_type == "1": + self.norm_layers_1.append(LayerNorm(hidden_channels)) + elif layer_norm_type == "2": + self.norm_layers_1.append(LayerNorm2(hidden_channels)) + else: + raise ValueError(" [!] Unknown layer norm type") + + if hidden_channels != out_channels and (idx + 1) == self.num_layers: + self.proj = nn.Conv1d(hidden_channels, out_channels, 1) + + self.ffn_layers.append( + FeedForwardNetwork( + hidden_channels, + hidden_channels if (idx + 1) != self.num_layers else out_channels, + hidden_channels_ffn, + kernel_size, + dropout_p=dropout_p, + ) + ) + + if layer_norm_type == "1": + self.norm_layers_2.append(LayerNorm(hidden_channels if (idx + 1) != self.num_layers else out_channels)) + elif layer_norm_type == "2": + self.norm_layers_2.append(LayerNorm2(hidden_channels if (idx + 1) != self.num_layers else out_channels)) + else: + raise ValueError(" [!] Unknown layer norm type") + + def forward(self, x, x_mask): + """ + Shapes: + - x: :math:`[B, C, T]` + - x_mask: :math:`[B, 1, T]` + """ + attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1) + for i in range(self.num_layers): + x = x * x_mask + y = self.attn_layers[i](x, x, attn_mask) + y = self.dropout(y) + x = self.norm_layers_1[i](x + y) + + y = self.ffn_layers[i](x, x_mask) + y = self.dropout(y) + + if (i + 1) == self.num_layers and hasattr(self, "proj"): + x = self.proj(x) + + x = self.norm_layers_2[i](x + y) + x = x * x_mask + return x + +class ConditionalRelativePositionTransformer(nn.Module): + """Transformer with Relative Potional Encoding and conditioned on external embeddings at cond_layer_idx'th layer. + + https://arxiv.org/abs/2307.16430 + + Args: + in_channels (int): number of channels of the input tensor. + out_chanels (int): number of channels of the output tensor. + hidden_channels (int): model hidden channels. + hidden_channels_ffn (int): hidden channels of FeedForwardNetwork. + num_heads (int): number of attention heads. + num_layers (int): number of transformer layers. + kernel_size (int, optional): kernel size of feed-forward inner layers. Defaults to 1. + dropout_p (float, optional): dropout rate for self-attention and feed-forward inner layers_per_stack. Defaults to 0. + rel_attn_window_size (int, optional): relation attention window size. + If 4, for each time step next and previous 4 time steps are attended. + If default, relative encoding is disabled and it is a regular transformer. + Defaults to None. + input_length (int, optional): input lenght to limit position encoding. Defaults to None. + layer_norm_type (str, optional): type "1" uses torch tensor operations and type "2" uses torch layer_norm + primitive. Use type "2", type "1: is for backward compat. Defaults to "1". + cond_channels (int): number of channels of the external embeddings. + cond_layer_idx (int): layer index to condition at. (using 3rd layer by default as in the paper) + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + hidden_channels: int, + hidden_channels_ffn: int, + num_heads: int, + num_layers: int, + kernel_size=1, + dropout_p=0.0, + rel_attn_window_size: int = None, + input_length: int = None, + layer_norm_type: str = "1", + cond_channels: int = 0, + cond_layer_idx: int = 2, + ): + super().__init__() + self.cond_channels = cond_channels + if cond_layer_idx < 0 or cond_layer_idx >= num_layers: + raise ValueError(" [!] cond_layer_idx should be in [0, num_layers)") + self.cond_layer_idx = cond_layer_idx + self.cond_proj = None + if self.cond_channels: + self.cond_proj = nn.Linear(cond_channels, hidden_channels) + + self.hidden_channels = hidden_channels + self.hidden_channels_ffn = hidden_channels_ffn + self.num_heads = num_heads + self.num_layers = num_layers + self.kernel_size = kernel_size + self.dropout_p = dropout_p + self.rel_attn_window_size = rel_attn_window_size + + self.dropout = nn.Dropout(dropout_p) + self.attn_layers = nn.ModuleList() + self.norm_layers_1 = nn.ModuleList() + self.ffn_layers = nn.ModuleList() + self.norm_layers_2 = nn.ModuleList() + + for idx in range(self.num_layers): + self.attn_layers.append( + RelativePositionMultiHeadAttention( + hidden_channels if idx != 0 else in_channels, + hidden_channels, + num_heads, + rel_attn_window_size=rel_attn_window_size, + dropout_p=dropout_p, + input_length=input_length, + ) + ) + if layer_norm_type == "1": + self.norm_layers_1.append(LayerNorm(hidden_channels)) + elif layer_norm_type == "2": + self.norm_layers_1.append(LayerNorm2(hidden_channels)) + else: + raise ValueError(" [!] Unknown layer norm type") + + if hidden_channels != out_channels and (idx + 1) == self.num_layers: + self.proj = nn.Conv1d(hidden_channels, out_channels, 1) + + self.ffn_layers.append( + FeedForwardNetwork( + hidden_channels, + hidden_channels if (idx + 1) != self.num_layers else out_channels, + hidden_channels_ffn, + kernel_size, + dropout_p=dropout_p, + ) + ) + + if layer_norm_type == "1": + self.norm_layers_2.append(LayerNorm(hidden_channels if (idx + 1) != self.num_layers else out_channels)) + elif layer_norm_type == "2": + self.norm_layers_2.append(LayerNorm2(hidden_channels if (idx + 1) != self.num_layers else out_channels)) + else: + raise ValueError(" [!] Unknown layer norm type") + + def forward(self, x, x_mask, g=None): + """ + Shapes: + - x: :math:`[B, C, T]` + - x_mask: :math:`[B, 1, T]` + - g: :math:`[B, C, T]` + """ + attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1) + for i in range(self.num_layers): + if i == self.cond_layer_idx and self.cond_proj is not None: + g = self.cond_proj(g.transpose(1, 2)) + g = g.transpose(1, 2) + x = x + g + x = x * x_mask + y = self.attn_layers[i](x, x, attn_mask) + y = self.dropout(y) + x = self.norm_layers_1[i](x + y) + + y = self.ffn_layers[i](x, x_mask) + y = self.dropout(y) + + if (i + 1) == self.num_layers and hasattr(self, "proj"): + x = self.proj(x) + + x = self.norm_layers_2[i](x + y) + x = x * x_mask + return x \ No newline at end of file diff --git a/TTS/tts/layers/vits2/transforms.py b/TTS/tts/layers/vits2/transforms.py new file mode 100644 index 0000000000..3cac1b8d6d --- /dev/null +++ b/TTS/tts/layers/vits2/transforms.py @@ -0,0 +1,202 @@ +# adopted from https://github.com/bayesiains/nflows + +import numpy as np +import torch +from torch.nn import functional as F + +DEFAULT_MIN_BIN_WIDTH = 1e-3 +DEFAULT_MIN_BIN_HEIGHT = 1e-3 +DEFAULT_MIN_DERIVATIVE = 1e-3 + + +def piecewise_rational_quadratic_transform( + inputs, + unnormalized_widths, + unnormalized_heights, + unnormalized_derivatives, + inverse=False, + tails=None, + tail_bound=1.0, + min_bin_width=DEFAULT_MIN_BIN_WIDTH, + min_bin_height=DEFAULT_MIN_BIN_HEIGHT, + min_derivative=DEFAULT_MIN_DERIVATIVE, +): + if tails is None: + spline_fn = rational_quadratic_spline + spline_kwargs = {} + else: + spline_fn = unconstrained_rational_quadratic_spline + spline_kwargs = {"tails": tails, "tail_bound": tail_bound} + + outputs, logabsdet = spline_fn( + inputs=inputs, + unnormalized_widths=unnormalized_widths, + unnormalized_heights=unnormalized_heights, + unnormalized_derivatives=unnormalized_derivatives, + inverse=inverse, + min_bin_width=min_bin_width, + min_bin_height=min_bin_height, + min_derivative=min_derivative, + **spline_kwargs, + ) + return outputs, logabsdet + + +def searchsorted(bin_locations, inputs, eps=1e-6): + bin_locations[..., -1] += eps + return torch.sum(inputs[..., None] >= bin_locations, dim=-1) - 1 + + +def unconstrained_rational_quadratic_spline( + inputs, + unnormalized_widths, + unnormalized_heights, + unnormalized_derivatives, + inverse=False, + tails="linear", + tail_bound=1.0, + min_bin_width=DEFAULT_MIN_BIN_WIDTH, + min_bin_height=DEFAULT_MIN_BIN_HEIGHT, + min_derivative=DEFAULT_MIN_DERIVATIVE, +): + inside_interval_mask = (inputs >= -tail_bound) & (inputs <= tail_bound) + outside_interval_mask = ~inside_interval_mask + + outputs = torch.zeros_like(inputs) + logabsdet = torch.zeros_like(inputs) + + if tails == "linear": + unnormalized_derivatives = F.pad(unnormalized_derivatives, pad=(1, 1)) + constant = np.log(np.exp(1 - min_derivative) - 1) + unnormalized_derivatives[..., 0] = constant + unnormalized_derivatives[..., -1] = constant + + outputs[outside_interval_mask] = inputs[outside_interval_mask] + logabsdet[outside_interval_mask] = 0 + else: + raise RuntimeError("{} tails are not implemented.".format(tails)) + + outputs[inside_interval_mask], logabsdet[inside_interval_mask] = rational_quadratic_spline( + inputs=inputs[inside_interval_mask], + unnormalized_widths=unnormalized_widths[inside_interval_mask, :], + unnormalized_heights=unnormalized_heights[inside_interval_mask, :], + unnormalized_derivatives=unnormalized_derivatives[inside_interval_mask, :], + inverse=inverse, + left=-tail_bound, + right=tail_bound, + bottom=-tail_bound, + top=tail_bound, + min_bin_width=min_bin_width, + min_bin_height=min_bin_height, + min_derivative=min_derivative, + ) + + return outputs, logabsdet + + +def rational_quadratic_spline( + inputs, + unnormalized_widths, + unnormalized_heights, + unnormalized_derivatives, + inverse=False, + left=0.0, + right=1.0, + bottom=0.0, + top=1.0, + min_bin_width=DEFAULT_MIN_BIN_WIDTH, + min_bin_height=DEFAULT_MIN_BIN_HEIGHT, + min_derivative=DEFAULT_MIN_DERIVATIVE, +): + if torch.min(inputs) < left or torch.max(inputs) > right: + raise ValueError("Input to a transform is not within its domain") + + num_bins = unnormalized_widths.shape[-1] + + if min_bin_width * num_bins > 1.0: + raise ValueError("Minimal bin width too large for the number of bins") + if min_bin_height * num_bins > 1.0: + raise ValueError("Minimal bin height too large for the number of bins") + + widths = F.softmax(unnormalized_widths, dim=-1) + widths = min_bin_width + (1 - min_bin_width * num_bins) * widths + cumwidths = torch.cumsum(widths, dim=-1) + cumwidths = F.pad(cumwidths, pad=(1, 0), mode="constant", value=0.0) + cumwidths = (right - left) * cumwidths + left + cumwidths[..., 0] = left + cumwidths[..., -1] = right + widths = cumwidths[..., 1:] - cumwidths[..., :-1] + + derivatives = min_derivative + F.softplus(unnormalized_derivatives) + + heights = F.softmax(unnormalized_heights, dim=-1) + heights = min_bin_height + (1 - min_bin_height * num_bins) * heights + cumheights = torch.cumsum(heights, dim=-1) + cumheights = F.pad(cumheights, pad=(1, 0), mode="constant", value=0.0) + cumheights = (top - bottom) * cumheights + bottom + cumheights[..., 0] = bottom + cumheights[..., -1] = top + heights = cumheights[..., 1:] - cumheights[..., :-1] + + if inverse: + bin_idx = searchsorted(cumheights, inputs)[..., None] + else: + bin_idx = searchsorted(cumwidths, inputs)[..., None] + + input_cumwidths = cumwidths.gather(-1, bin_idx)[..., 0] + input_bin_widths = widths.gather(-1, bin_idx)[..., 0] + + input_cumheights = cumheights.gather(-1, bin_idx)[..., 0] + delta = heights / widths + input_delta = delta.gather(-1, bin_idx)[..., 0] + + input_derivatives = derivatives.gather(-1, bin_idx)[..., 0] + input_derivatives_plus_one = derivatives[..., 1:].gather(-1, bin_idx)[..., 0] + + input_heights = heights.gather(-1, bin_idx)[..., 0] + + if inverse: + a = (inputs - input_cumheights) * ( + input_derivatives + input_derivatives_plus_one - 2 * input_delta + ) + input_heights * (input_delta - input_derivatives) + b = input_heights * input_derivatives - (inputs - input_cumheights) * ( + input_derivatives + input_derivatives_plus_one - 2 * input_delta + ) + c = -input_delta * (inputs - input_cumheights) + + discriminant = b.pow(2) - 4 * a * c + assert (discriminant >= 0).all() + + root = (2 * c) / (-b - torch.sqrt(discriminant)) + outputs = root * input_bin_widths + input_cumwidths + + theta_one_minus_theta = root * (1 - root) + denominator = input_delta + ( + (input_derivatives + input_derivatives_plus_one - 2 * input_delta) * theta_one_minus_theta + ) + derivative_numerator = input_delta.pow(2) * ( + input_derivatives_plus_one * root.pow(2) + + 2 * input_delta * theta_one_minus_theta + + input_derivatives * (1 - root).pow(2) + ) + logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator) + + return outputs, -logabsdet + else: + theta = (inputs - input_cumwidths) / input_bin_widths + theta_one_minus_theta = theta * (1 - theta) + + numerator = input_heights * (input_delta * theta.pow(2) + input_derivatives * theta_one_minus_theta) + denominator = input_delta + ( + (input_derivatives + input_derivatives_plus_one - 2 * input_delta) * theta_one_minus_theta + ) + outputs = input_cumheights + numerator / denominator + + derivative_numerator = input_delta.pow(2) * ( + input_derivatives_plus_one * theta.pow(2) + + 2 * input_delta * theta_one_minus_theta + + input_derivatives * (1 - theta).pow(2) + ) + logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator) + + return outputs, logabsdet diff --git a/TTS/tts/models/vits2.py b/TTS/tts/models/vits2.py new file mode 100644 index 0000000000..0342c1ea9f --- /dev/null +++ b/TTS/tts/models/vits2.py @@ -0,0 +1,2061 @@ +import math +import os +from dataclasses import dataclass, field, replace +from itertools import chain +from typing import Dict, List, Tuple, Union + +import numpy as np +import torch +import torch.distributed as dist +import torchaudio +from coqpit import Coqpit +from librosa.filters import mel as librosa_mel_fn +from torch import nn +from torch.cuda.amp.autocast_mode import autocast +from torch.nn import functional as F +from torch.utils.data import DataLoader +from torch.utils.data.sampler import WeightedRandomSampler +from trainer.torch import DistributedSampler, DistributedSamplerWrapper +from trainer.trainer_utils import get_optimizer, get_scheduler + +from TTS.tts.configs.shared_configs import CharactersConfig +from TTS.tts.datasets.dataset import TTSDataset, _parse_sample +from TTS.tts.layers.glow_tts.duration_predictor import DurationPredictor +from TTS.tts.layers.vits2.duration_discriminator import DurationDiscriminator +from TTS.tts.layers.vits2.discriminator import VitsDiscriminator +from TTS.tts.layers.vits2.networks import PosteriorEncoder, ResidualCouplingBlocks, TextEncoder +from TTS.tts.layers.vits2.stochastic_duration_predictor import StochasticDurationPredictor +from TTS.tts.models.base_tts import BaseTTS +from TTS.tts.utils.fairseq import rehash_fairseq_vits_checkpoint +from TTS.tts.utils.helpers import generate_path, maximum_path, rand_segments, segment, sequence_mask +from TTS.tts.utils.languages import LanguageManager +from TTS.tts.utils.speakers import SpeakerManager +from TTS.tts.utils.synthesis import synthesis +from TTS.tts.utils.text.characters import BaseCharacters, BaseVocabulary, _characters, _pad, _phonemes, _punctuations +from TTS.tts.utils.text.tokenizer import TTSTokenizer +from TTS.tts.utils.visual import plot_alignment +from TTS.utils.io import load_fsspec +from TTS.utils.samplers import BucketBatchSampler +from TTS.vocoder.models.hifigan_generator import HifiganGenerator +from TTS.vocoder.utils.generic_utils import plot_results + +############################## +# IO / Feature extraction +############################## + +# pylint: disable=global-statement +hann_window = {} +mel_basis = {} + + +@torch.no_grad() +def weights_reset(m: nn.Module): + # check if the current module has reset_parameters and if it is reset the weight + reset_parameters = getattr(m, "reset_parameters", None) + if callable(reset_parameters): + m.reset_parameters() + + +def get_module_weights_sum(mdl: nn.Module): + dict_sums = {} + for name, w in mdl.named_parameters(): + if "weight" in name: + value = w.data.sum().item() + dict_sums[name] = value + return dict_sums + + +def load_audio(file_path): + """Load the audio file normalized in [-1, 1] + + Return Shapes: + - x: :math:`[1, T]` + """ + x, sr = torchaudio.load(file_path) + assert (x > 1).sum() + (x < -1).sum() == 0 + return x, sr + + +def _amp_to_db(x, C=1, clip_val=1e-5): + return torch.log(torch.clamp(x, min=clip_val) * C) + + +def _db_to_amp(x, C=1): + return torch.exp(x) / C + + +def amp_to_db(magnitudes): + output = _amp_to_db(magnitudes) + return output + + +def db_to_amp(magnitudes): + output = _db_to_amp(magnitudes) + return output + + +def wav_to_spec(y, n_fft, hop_length, win_length, center=False): + """ + Args Shapes: + - y : :math:`[B, 1, T]` + + Return Shapes: + - spec : :math:`[B,C,T]` + """ + y = y.squeeze(1) + + if torch.min(y) < -1.0: + print("min value is ", torch.min(y)) + if torch.max(y) > 1.0: + print("max value is ", torch.max(y)) + + global hann_window + dtype_device = str(y.dtype) + "_" + str(y.device) + wnsize_dtype_device = str(win_length) + "_" + dtype_device + if wnsize_dtype_device not in hann_window: + hann_window[wnsize_dtype_device] = torch.hann_window(win_length).to(dtype=y.dtype, device=y.device) + + y = torch.nn.functional.pad( + y.unsqueeze(1), + (int((n_fft - hop_length) / 2), int((n_fft - hop_length) / 2)), + mode="reflect", + ) + y = y.squeeze(1) + + spec = torch.stft( + y, + n_fft, + hop_length=hop_length, + win_length=win_length, + window=hann_window[wnsize_dtype_device], + center=center, + pad_mode="reflect", + normalized=False, + onesided=True, + return_complex=False, + ) + + spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6) + return spec + + +def spec_to_mel(spec, n_fft, num_mels, sample_rate, fmin, fmax): + """ + Args Shapes: + - spec : :math:`[B,C,T]` + + Return Shapes: + - mel : :math:`[B,C,T]` + """ + global mel_basis + dtype_device = str(spec.dtype) + "_" + str(spec.device) + fmax_dtype_device = str(fmax) + "_" + dtype_device + if fmax_dtype_device not in mel_basis: + mel = librosa_mel_fn(sr=sample_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax) + mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(dtype=spec.dtype, device=spec.device) + mel = torch.matmul(mel_basis[fmax_dtype_device], spec) + mel = amp_to_db(mel) + return mel + + +def wav_to_mel(y, n_fft, num_mels, sample_rate, hop_length, win_length, fmin, fmax, center=False): + """ + Args Shapes: + - y : :math:`[B, 1, T]` + + Return Shapes: + - spec : :math:`[B,C,T]` + """ + y = y.squeeze(1) + + if torch.min(y) < -1.0: + print("min value is ", torch.min(y)) + if torch.max(y) > 1.0: + print("max value is ", torch.max(y)) + + global mel_basis, hann_window + dtype_device = str(y.dtype) + "_" + str(y.device) + fmax_dtype_device = str(fmax) + "_" + dtype_device + wnsize_dtype_device = str(win_length) + "_" + dtype_device + if fmax_dtype_device not in mel_basis: + mel = librosa_mel_fn(sr=sample_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax) + mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(dtype=y.dtype, device=y.device) + if wnsize_dtype_device not in hann_window: + hann_window[wnsize_dtype_device] = torch.hann_window(win_length).to(dtype=y.dtype, device=y.device) + + y = torch.nn.functional.pad( + y.unsqueeze(1), + (int((n_fft - hop_length) / 2), int((n_fft - hop_length) / 2)), + mode="reflect", + ) + y = y.squeeze(1) + + spec = torch.stft( + y, + n_fft, + hop_length=hop_length, + win_length=win_length, + window=hann_window[wnsize_dtype_device], + center=center, + pad_mode="reflect", + normalized=False, + onesided=True, + return_complex=False, + ) + + spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6) + spec = torch.matmul(mel_basis[fmax_dtype_device], spec) + spec = amp_to_db(spec) + return spec + + +############################# +# CONFIGS +############################# + + +@dataclass +class Vits2AudioConfig(Coqpit): + fft_size: int = 1024 + sample_rate: int = 22050 + win_length: int = 1024 + hop_length: int = 256 + num_mels: int = 80 + mel_fmin: int = 0 + mel_fmax: int = None + + +############################## +# DATASET +############################## + + +def get_attribute_balancer_weights(items: list, attr_name: str, multi_dict: dict = None): + """Create inverse frequency weights for balancing the dataset. + Use `multi_dict` to scale relative weights.""" + attr_names_samples = np.array([item[attr_name] for item in items]) + unique_attr_names = np.unique(attr_names_samples).tolist() + attr_idx = [unique_attr_names.index(l) for l in attr_names_samples] + attr_count = np.array([len(np.where(attr_names_samples == l)[0]) for l in unique_attr_names]) + weight_attr = 1.0 / attr_count + dataset_samples_weight = np.array([weight_attr[l] for l in attr_idx]) + dataset_samples_weight = dataset_samples_weight / np.linalg.norm(dataset_samples_weight) + if multi_dict is not None: + # check if all keys are in the multi_dict + for k in multi_dict: + assert k in unique_attr_names, f"{k} not in {unique_attr_names}" + # scale weights + multiplier_samples = np.array([multi_dict.get(item[attr_name], 1.0) for item in items]) + dataset_samples_weight *= multiplier_samples + return ( + torch.from_numpy(dataset_samples_weight).float(), + unique_attr_names, + np.unique(dataset_samples_weight).tolist(), + ) + + +class Vits2Dataset(TTSDataset): + def __init__(self, model_args, *args, **kwargs): + super().__init__(*args, **kwargs) + self.pad_id = self.tokenizer.characters.pad_id + self.model_args = model_args + + def __getitem__(self, idx): + item = self.samples[idx] + raw_text = item["text"] + + wav, _ = load_audio(item["audio_file"]) + if self.model_args.encoder_sample_rate is not None: + if wav.size(1) % self.model_args.encoder_sample_rate != 0: + wav = wav[:, : -int(wav.size(1) % self.model_args.encoder_sample_rate)] + + wav_filename = os.path.basename(item["audio_file"]) + + token_ids = self.get_token_ids(idx, item["text"]) + + # after phonemization the text length may change + # this is a shameful 🤭 hack to prevent longer phonemes + # TODO: find a better fix + if len(token_ids) > self.max_text_len or wav.shape[1] < self.min_audio_len: + self.rescue_item_idx += 1 + return self.__getitem__(self.rescue_item_idx) + + return { + "raw_text": raw_text, + "token_ids": token_ids, + "token_len": len(token_ids), + "wav": wav, + "wav_file": wav_filename, + "speaker_name": item["speaker_name"], + "language_name": item["language"], + "audio_unique_name": item["audio_unique_name"], + } + + @property + def lengths(self): + lens = [] + for item in self.samples: + _, wav_file, *_ = _parse_sample(item) + audio_len = os.path.getsize(wav_file) / 16 * 8 # assuming 16bit audio + lens.append(audio_len) + return lens + + def collate_fn(self, batch): + """ + Return Shapes: + - tokens: :math:`[B, T]` + - token_lens :math:`[B]` + - token_rel_lens :math:`[B]` + - waveform: :math:`[B, 1, T]` + - waveform_lens: :math:`[B]` + - waveform_rel_lens: :math:`[B]` + - speaker_names: :math:`[B]` + - language_names: :math:`[B]` + - audiofile_paths: :math:`[B]` + - raw_texts: :math:`[B]` + - audio_unique_names: :math:`[B]` + """ + # convert list of dicts to dict of lists + B = len(batch) + batch = {k: [dic[k] for dic in batch] for k in batch[0]} + + _, ids_sorted_decreasing = torch.sort( + torch.LongTensor([x.size(1) for x in batch["wav"]]), dim=0, descending=True + ) + + max_text_len = max([len(x) for x in batch["token_ids"]]) + token_lens = torch.LongTensor(batch["token_len"]) + token_rel_lens = token_lens / token_lens.max() + + wav_lens = [w.shape[1] for w in batch["wav"]] + wav_lens = torch.LongTensor(wav_lens) + wav_lens_max = torch.max(wav_lens) + wav_rel_lens = wav_lens / wav_lens_max + + token_padded = torch.LongTensor(B, max_text_len) + wav_padded = torch.FloatTensor(B, 1, wav_lens_max) + token_padded = token_padded.zero_() + self.pad_id + wav_padded = wav_padded.zero_() + self.pad_id + for i in range(len(ids_sorted_decreasing)): + token_ids = batch["token_ids"][i] + token_padded[i, : batch["token_len"][i]] = torch.LongTensor(token_ids) + + wav = batch["wav"][i] + wav_padded[i, :, : wav.size(1)] = torch.FloatTensor(wav) + + return { + "tokens": token_padded, + "token_lens": token_lens, + "token_rel_lens": token_rel_lens, + "waveform": wav_padded, # (B x T) + "waveform_lens": wav_lens, # (B) + "waveform_rel_lens": wav_rel_lens, + "speaker_names": batch["speaker_name"], + "language_names": batch["language_name"], + "audio_files": batch["wav_file"], + "raw_text": batch["raw_text"], + "audio_unique_names": batch["audio_unique_name"], + } + + +############################## +# MODEL DEFINITION +############################## + + +@dataclass +class Vits2Args(Coqpit): + """VITS2 model arguments. + + Args: + + num_chars (int): + Number of characters in the vocabulary. Defaults to 100. + + out_channels (int): + Number of output channels of the decoder. Defaults to 80 since we use mel-spec in vits2. + + spec_segment_size (int): + Decoder input segment size. Defaults to 32 `(32 * hoplength = waveform length)`. + + hidden_channels (int): + Number of hidden channels of the model. Defaults to 192. + + hidden_channels_ffn_text_encoder (int): + Number of hidden channels of the feed-forward layers of the text encoder transformer. Defaults to 256. + + num_heads_text_encoder (int): + Number of attention heads of the text encoder transformer. Defaults to 2. + + num_layers_text_encoder (int): + Number of transformer layers in the text encoder. Defaults to 6. + + kernel_size_text_encoder (int): + Kernel size of the text encoder transformer FFN layers. Defaults to 3. + + dropout_p_text_encoder (float): + Dropout rate of the text encoder. Defaults to 0.1. + + dropout_p_duration_predictor (float): + Dropout rate of the duration predictor. Defaults to 0.1. + + kernel_size_posterior_encoder (int): + Kernel size of the posterior encoder's WaveNet layers. Defaults to 5. + + dilatation_posterior_encoder (int): + Dilation rate of the posterior encoder's WaveNet layers. Defaults to 1. + + num_layers_posterior_encoder (int): + Number of posterior encoder's WaveNet layers. Defaults to 16. + + kernel_size_flow (int): + Kernel size of the Residual Coupling layers of the flow network. Defaults to 5. + + dilatation_flow (int): + Dilation rate of the Residual Coupling WaveNet layers of the flow network. Defaults to 1. + + num_layers_flow (int): + Number of Residual Coupling WaveNet layers of the flow network. Defaults to 6. + + resblock_type_decoder (str): + Type of the residual block in the decoder network. Defaults to "1". + + resblock_kernel_sizes_decoder (List[int]): + Kernel sizes of the residual blocks in the decoder network. Defaults to `[3, 7, 11]`. + + resblock_dilation_sizes_decoder (List[List[int]]): + Dilation sizes of the residual blocks in the decoder network. Defaults to `[[1, 3, 5], [1, 3, 5], [1, 3, 5]]`. + + upsample_rates_decoder (List[int]): + Upsampling rates for each concecutive upsampling layer in the decoder network. The multiply of these + values must be equal to the kop length used for computing spectrograms. Defaults to `[8, 8, 2, 2]`. + + upsample_initial_channel_decoder (int): + Number of hidden channels of the first upsampling convolution layer of the decoder network. Defaults to 512. + + upsample_kernel_sizes_decoder (List[int]): + Kernel sizes for each upsampling layer of the decoder network. Defaults to `[16, 16, 4, 4]`. + + periods_multi_period_discriminator (List[int]): + Periods values for Vits Multi-Period Discriminator. Defaults to `[2, 3, 5, 7, 11]`. + + use_sdp (bool): + Use Stochastic Duration Predictor. Defaults to True. + + noise_scale (float): + Noise scale used for the sample noise tensor in training. Defaults to 1.0. + + inference_noise_scale (float): + Noise scale used for the sample noise tensor in inference. Defaults to 0.667. + + length_scale (float): + Scale factor for the predicted duration values. Smaller values result faster speech. Defaults to 1. + + noise_scale_dp (float): + Noise scale used by the Stochastic Duration Predictor sample noise in training. Defaults to 1.0. + + inference_noise_scale_dp (float): + Noise scale for the Stochastic Duration Predictor in inference. Defaults to 0.8. + + max_inference_len (int): + Maximum inference length to limit the memory use. Defaults to None. + + init_discriminator (bool): + Initialize the disciminator network if set True. Set False for inference. Defaults to True. + + use_spectral_norm_disriminator (bool): + Use spectral normalization over weight norm in the discriminator. Defaults to False. + + use_speaker_embedding (bool): + Enable/Disable speaker embedding for multi-speaker models. Defaults to False. + + num_speakers (int): + Number of speakers for the speaker embedding layer. Defaults to 0. + + speakers_file (str): + Path to the speaker mapping file for the Speaker Manager. Defaults to None. + + speaker_embedding_channels (int): + Number of speaker embedding channels. Defaults to 256. + + use_d_vector_file (bool): + Enable/Disable the use of d-vectors for multi-speaker training. Defaults to False. + + d_vector_file (List[str]): + List of paths to the files including pre-computed speaker embeddings. Defaults to None. + + d_vector_dim (int): + Number of d-vector channels. Defaults to 0. + + detach_dp_input (bool): + Detach duration predictor's input from the network for stopping the gradients. Defaults to True. + + use_language_embedding (bool): + Enable/Disable language embedding for multilingual models. Defaults to False. + + embedded_language_dim (int): + Number of language embedding channels. Defaults to 4. + + num_languages (int): + Number of languages for the language embedding layer. Defaults to 0. + + language_ids_file (str): + Path to the language mapping file for the Language Manager. Defaults to None. + + use_speaker_encoder_as_loss (bool): + Enable/Disable Speaker Consistency Loss (SCL). Defaults to False. + + speaker_encoder_config_path (str): + Path to the file speaker encoder config file, to use for SCL. Defaults to "". + + speaker_encoder_model_path (str): + Path to the file speaker encoder checkpoint file, to use for SCL. Defaults to "". + + condition_dp_on_speaker (bool): + Condition the duration predictor on the speaker embedding. Defaults to True. + + freeze_encoder (bool): + Freeze the encoder weigths during training. Defaults to False. + + freeze_DP (bool): + Freeze the duration predictor weigths during training. Defaults to False. + + freeze_PE (bool): + Freeze the posterior encoder weigths during training. Defaults to False. + + freeze_flow_encoder (bool): + Freeze the flow encoder weigths during training. Defaults to False. + + freeze_waveform_decoder (bool): + Freeze the waveform decoder weigths during training. Defaults to False. + + encoder_sample_rate (int): + If not None this sample rate will be used for training the Posterior Encoder, + flow, text_encoder and duration predictor. The decoder part (vocoder) will be + trained with the `config.audio.sample_rate`. Defaults to None. + + interpolate_z (bool): + If `encoder_sample_rate` not None and this parameter True the nearest interpolation + will be used to upsampling the latent variable z with the sampling rate `encoder_sample_rate` + to the `config.audio.sample_rate`. If it is False you will need to add extra + `upsample_rates_decoder` to match the shape. Defaults to True. + + use_transformer_flow_layer (bool): + Use Transformer Flow layer inside of Residual Coupling Blocks. Defaults to True. + """ + + num_chars: int = 100 + out_channels: int = 80 + spec_segment_size: int = 32 + hidden_channels: int = 192 + hidden_channels_ffn_text_encoder: int = 768 + num_heads_text_encoder: int = 2 + num_layers_text_encoder: int = 6 + kernel_size_text_encoder: int = 3 + dropout_p_text_encoder: float = 0.1 + dropout_p_duration_predictor: float = 0.5 + kernel_size_posterior_encoder: int = 5 + dilation_rate_posterior_encoder: int = 1 + num_layers_posterior_encoder: int = 16 + kernel_size_flow: int = 5 + dilation_rate_flow: int = 1 + num_layers_flow: int = 4 + resblock_type_decoder: str = "1" + resblock_kernel_sizes_decoder: List[int] = field(default_factory=lambda: [3, 7, 11]) + resblock_dilation_sizes_decoder: List[List[int]] = field(default_factory=lambda: [[1, 3, 5], [1, 3, 5], [1, 3, 5]]) + upsample_rates_decoder: List[int] = field(default_factory=lambda: [8, 8, 2, 2]) + upsample_initial_channel_decoder: int = 512 + upsample_kernel_sizes_decoder: List[int] = field(default_factory=lambda: [16, 16, 4, 4]) + periods_multi_period_discriminator: List[int] = field(default_factory=lambda: [2, 3, 5, 7, 11]) + use_sdp: bool = True + noise_scale: float = 1.0 + inference_noise_scale: float = 0.667 + length_scale: float = 1 + noise_scale_dp: float = 1.0 + inference_noise_scale_dp: float = 1.0 + max_inference_len: int = None + init_discriminator: bool = True + use_spectral_norm_disriminator: bool = False + use_speaker_embedding: bool = False + num_speakers: int = 0 + speakers_file: str = None + d_vector_file: List[str] = None + speaker_embedding_channels: int = 256 + use_d_vector_file: bool = False + d_vector_dim: int = 0 + detach_dp_input: bool = True + use_language_embedding: bool = False + embedded_language_dim: int = 4 + num_languages: int = 0 + language_ids_file: str = None + use_speaker_encoder_as_loss: bool = False + speaker_encoder_config_path: str = "" + speaker_encoder_model_path: str = "" + condition_dp_on_speaker: bool = True + freeze_encoder: bool = False + freeze_DP: bool = False + freeze_PE: bool = False + freeze_flow_decoder: bool = False + freeze_waveform_decoder: bool = False + encoder_sample_rate: int = None + interpolate_z: bool = True + reinit_DP: bool = False + reinit_text_encoder: bool = False + use_transformer_flow: bool = True + use_noise_scaled_MAS: bool = True + init_dur_discriminator: bool = True + mas_noise_scale_initial: float = 2e-6 + noise_scale_delta: float = 0.01 + +class Vits2(BaseTTS): + """VITS2 TTS model + + Paper:: + https://arxiv.org/pdf/2307.16430.pdf + + Paper Abstract:: + ## TODO : add paper abstract + + Check :class:`TTS.tts.configs.vits2_config.Vits2Config` for class arguments. + + Examples: + >>> from TTS.tts.configs.vits_config import Vits2Config + >>> from TTS.tts.models.vits import Vits2 + >>> config = Vits2Config() + >>> model = Vits2(config) + """ + + def __init__( + self, + config: Coqpit, + ap: "AudioProcessor" = None, + tokenizer: "TTSTokenizer" = None, + speaker_manager: SpeakerManager = None, + language_manager: LanguageManager = None, + ): + super().__init__(config, ap, tokenizer, speaker_manager, language_manager) + + self.init_multispeaker(config) + self.init_multilingual(config) + self.init_upsampling() + + self.length_scale = self.args.length_scale + self.noise_scale = self.args.noise_scale + self.inference_noise_scale = self.args.inference_noise_scale + self.inference_noise_scale_dp = self.args.inference_noise_scale_dp + self.noise_scale_dp = self.args.noise_scale_dp + self.mas_noise_scale_initial=self.args.mas_noise_scale_initial + self.noise_scale_delta=self.args.noise_scale_delta + + self.max_inference_len = self.args.max_inference_len + self.spec_segment_size = self.args.spec_segment_size + self.use_transformer_flow_layer = self.args.use_transformer_flow_layer + self.text_encoder = TextEncoder( + self.args.num_chars, + self.args.hidden_channels, + self.args.hidden_channels, + self.args.hidden_channels_ffn_text_encoder, + self.args.num_heads_text_encoder, + self.args.num_layers_text_encoder, + self.args.kernel_size_text_encoder, + self.args.dropout_p_text_encoder, + language_emb_dim=self.embedded_language_dim, + ) + + self.posterior_encoder = PosteriorEncoder( + self.args.out_channels, + self.args.hidden_channels, + self.args.hidden_channels, + kernel_size=self.args.kernel_size_posterior_encoder, + dilation_rate=self.args.dilation_rate_posterior_encoder, + num_layers=self.args.num_layers_posterior_encoder, + cond_channels=self.embedded_speaker_dim, + ) + + self.flow = ResidualCouplingBlocks( + self.args.hidden_channels, + self.args.hidden_channels, + kernel_size=self.args.kernel_size_flow, + dilation_rate=self.args.dilation_rate_flow, + num_layers=self.args.num_layers_flow, + cond_channels=self.embedded_speaker_dim, + use_transformer_flow_layer=self.use_transformer_flow_layer, + ) + + if self.args.use_sdp: + self.duration_predictor = StochasticDurationPredictor( + self.args.hidden_channels, + 192, + 3, + self.args.dropout_p_duration_predictor, + 4, + cond_channels=self.embedded_speaker_dim if self.args.condition_dp_on_speaker else 0, + language_emb_dim=self.embedded_language_dim, + ) + else: + self.duration_predictor = DurationPredictor( + self.args.hidden_channels, + 256, + 3, + self.args.dropout_p_duration_predictor, + cond_channels=self.embedded_speaker_dim, + language_emb_dim=self.embedded_language_dim, + ) + + self.waveform_decoder = HifiganGenerator( + self.args.hidden_channels, + 1, + self.args.resblock_type_decoder, + self.args.resblock_dilation_sizes_decoder, + self.args.resblock_kernel_sizes_decoder, + self.args.upsample_kernel_sizes_decoder, + self.args.upsample_initial_channel_decoder, + self.args.upsample_rates_decoder, + inference_padding=0, + cond_channels=self.embedded_speaker_dim, + conv_pre_weight_norm=False, + conv_post_weight_norm=False, + conv_post_bias=False, + ) + + if self.args.init_discriminator: + self.disc = VitsDiscriminator( + periods=self.args.periods_multi_period_discriminator, + use_spectral_norm=self.args.use_spectral_norm_disriminator, + ) + + if self.args.init_dur_discriminator: + self.dur_disc = DurationDiscriminator( + in_channels=self.args.hidden_channels, + filter_channels=self.args.hidden_channels, + kernel_size=3, + p_dropout=0.1, + gin_channels=0 #chnage this if condition on spekaer in future + ) + + + @property + def device(self): + return next(self.parameters()).device + + def init_multispeaker(self, config: Coqpit): + """Initialize multi-speaker modules of a model. A model can be trained either with a speaker embedding layer + or with external `d_vectors` computed from a speaker encoder model. + + You must provide a `speaker_manager` at initialization to set up the multi-speaker modules. + + Args: + config (Coqpit): Model configuration. + data (List, optional): Dataset items to infer number of speakers. Defaults to None. + """ + self.embedded_speaker_dim = 0 + self.num_speakers = self.args.num_speakers + self.audio_transform = None + + if self.speaker_manager: + self.num_speakers = self.speaker_manager.num_speakers + + if self.args.use_speaker_embedding: + self._init_speaker_embedding() + + if self.args.use_d_vector_file: + self._init_d_vector() + + # TODO: make this a function + if self.args.use_speaker_encoder_as_loss: + if self.speaker_manager.encoder is None and ( + not self.args.speaker_encoder_model_path or not self.args.speaker_encoder_config_path + ): + raise RuntimeError( + " [!] To use the speaker consistency loss (SCL) you need to specify speaker_encoder_model_path and speaker_encoder_config_path !!" + ) + + self.speaker_manager.encoder.eval() + print(" > External Speaker Encoder Loaded !!") + + if ( + hasattr(self.speaker_manager.encoder, "audio_config") + and self.config.audio.sample_rate != self.speaker_manager.encoder.audio_config["sample_rate"] + ): + self.audio_transform = torchaudio.transforms.Resample( + orig_freq=self.config.audio.sample_rate, + new_freq=self.speaker_manager.encoder.audio_config["sample_rate"], + ) + + def _init_speaker_embedding(self): + # pylint: disable=attribute-defined-outside-init + if self.num_speakers > 0: + print(" > initialization of speaker-embedding layers.") + self.embedded_speaker_dim = self.args.speaker_embedding_channels + self.emb_g = nn.Embedding(self.num_speakers, self.embedded_speaker_dim) + + def _init_d_vector(self): + # pylint: disable=attribute-defined-outside-init + if hasattr(self, "emb_g"): + raise ValueError("[!] Speaker embedding layer already initialized before d_vector settings.") + self.embedded_speaker_dim = self.args.d_vector_dim + + def init_multilingual(self, config: Coqpit): + """Initialize multilingual modules of a model. + + Args: + config (Coqpit): Model configuration. + """ + if self.args.language_ids_file is not None: + self.language_manager = LanguageManager(language_ids_file_path=config.language_ids_file) + + if self.args.use_language_embedding and self.language_manager: + print(" > initialization of language-embedding layers.") + self.num_languages = self.language_manager.num_languages + self.embedded_language_dim = self.args.embedded_language_dim + self.emb_l = nn.Embedding(self.num_languages, self.embedded_language_dim) + torch.nn.init.xavier_uniform_(self.emb_l.weight) + else: + self.embedded_language_dim = 0 + + def init_upsampling(self): + """ + Initialize upsampling modules of a model. + """ + if self.args.encoder_sample_rate: + self.interpolate_factor = self.config.audio["sample_rate"] / self.args.encoder_sample_rate + self.audio_resampler = torchaudio.transforms.Resample( + orig_freq=self.config.audio["sample_rate"], new_freq=self.args.encoder_sample_rate + ) # pylint: disable=W0201 + + def on_epoch_start(self, trainer): # pylint: disable=W0613 + """Freeze layers at the beginning of an epoch""" + self._freeze_layers() + # set the device of speaker encoder + if self.args.use_speaker_encoder_as_loss: + self.speaker_manager.encoder = self.speaker_manager.encoder.to(self.device) + + def on_init_end(self, trainer): # pylint: disable=W0613 + """Reinit layes if needed""" + if self.args.reinit_DP: + before_dict = get_module_weights_sum(self.duration_predictor) + # Applies weights_reset recursively to every submodule of the duration predictor + self.duration_predictor.apply(fn=weights_reset) + after_dict = get_module_weights_sum(self.duration_predictor) + for key, value in after_dict.items(): + if value == before_dict[key]: + raise RuntimeError(" [!] The weights of Duration Predictor was not reinit check it !") + print(" > Duration Predictor was reinit.") + + if self.args.reinit_text_encoder: + before_dict = get_module_weights_sum(self.text_encoder) + # Applies weights_reset recursively to every submodule of the duration predictor + self.text_encoder.apply(fn=weights_reset) + after_dict = get_module_weights_sum(self.text_encoder) + for key, value in after_dict.items(): + if value == before_dict[key]: + raise RuntimeError(" [!] The weights of Text Encoder was not reinit check it !") + print(" > Text Encoder was reinit.") + + def get_aux_input(self, aux_input: Dict): + sid, g, lid, _ = self._set_cond_input(aux_input) + return {"speaker_ids": sid, "style_wav": None, "d_vectors": g, "language_ids": lid} + + def _freeze_layers(self): + if self.args.freeze_encoder: + for param in self.text_encoder.parameters(): + param.requires_grad = False + + if hasattr(self, "emb_l"): + for param in self.emb_l.parameters(): + param.requires_grad = False + + if self.args.freeze_PE: + for param in self.posterior_encoder.parameters(): + param.requires_grad = False + + if self.args.freeze_DP: + for param in self.duration_predictor.parameters(): + param.requires_grad = False + + if self.args.freeze_flow_decoder: + for param in self.flow.parameters(): + param.requires_grad = False + + if self.args.freeze_waveform_decoder: + for param in self.waveform_decoder.parameters(): + param.requires_grad = False + + @staticmethod + def _set_cond_input(aux_input: Dict): + """Set the speaker conditioning input based on the multi-speaker mode.""" + sid, g, lid, durations = None, None, None, None + if "speaker_ids" in aux_input and aux_input["speaker_ids"] is not None: + sid = aux_input["speaker_ids"] + if sid.ndim == 0: + sid = sid.unsqueeze_(0) + if "d_vectors" in aux_input and aux_input["d_vectors"] is not None: + g = F.normalize(aux_input["d_vectors"]).unsqueeze(-1) + if g.ndim == 2: + g = g.unsqueeze_(0) + + if "language_ids" in aux_input and aux_input["language_ids"] is not None: + lid = aux_input["language_ids"] + if lid.ndim == 0: + lid = lid.unsqueeze_(0) + + if "durations" in aux_input and aux_input["durations"] is not None: + durations = aux_input["durations"] + + return sid, g, lid, durations + + def _set_speaker_input(self, aux_input: Dict): + d_vectors = aux_input.get("d_vectors", None) + speaker_ids = aux_input.get("speaker_ids", None) + + if d_vectors is not None and speaker_ids is not None: + raise ValueError("[!] Cannot use d-vectors and speaker-ids together.") + + if speaker_ids is not None and not hasattr(self, "emb_g"): + raise ValueError("[!] Cannot use speaker-ids without enabling speaker embedding.") + + g = speaker_ids if speaker_ids is not None else d_vectors + return g + + def forward_mas(self, outputs, z_p, m_p, logs_p, x, x_mask, y_mask, g, lang_emb): + # find the alignment path + attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2) + #do not confuse logp with logs_p; + #logp is the probability likelihood; + #it is the likelihood that z_p is from distribution (m_p, logs_p). + with torch.no_grad(): + o_scale = torch.exp(-2 * logs_p) + logp1 = torch.sum(-0.5 * math.log(2 * math.pi) - logs_p, [1]).unsqueeze(-1) # [b, t, 1] + logp2 = torch.einsum("klm, kln -> kmn", [o_scale, -0.5 * (z_p**2)]) + logp3 = torch.einsum("klm, kln -> kmn", [m_p * o_scale, z_p]) + logp4 = torch.sum(-0.5 * (m_p**2) * o_scale, [1]).unsqueeze(-1) # [b, t, 1] + logp = logp2 + logp3 + logp1 + logp4 + if self.use_noise_scaled_mas: + epsilon = torch.std(logs_p) * torch.randn_like(logs_p) * self.current_mas_noise_scale + logp = logp + epsilon + attn = maximum_path(logp, attn_mask.squeeze(1)).unsqueeze(1).detach() # [b, 1, t, t'] + + # duration predictor + attn_durations = attn.sum(3) + if self.args.use_sdp: + loss_duration = self.duration_predictor( + x.detach() if self.args.detach_dp_input else x, + x_mask, + attn_durations, + g=g.detach() if self.args.detach_dp_input and g is not None else g, + lang_emb=lang_emb.detach() if self.args.detach_dp_input and lang_emb is not None else lang_emb, + ) + loss_duration = loss_duration / torch.sum(x_mask) + else: + attn_log_durations = torch.log(attn_durations + 1e-6) * x_mask + log_durations = self.duration_predictor( + x.detach() if self.args.detach_dp_input else x, + x_mask, + g=g.detach() if self.args.detach_dp_input and g is not None else g, + lang_emb=lang_emb.detach() if self.args.detach_dp_input and lang_emb is not None else lang_emb, + ) + loss_duration = torch.sum((log_durations - attn_log_durations) ** 2, [1, 2]) / torch.sum(x_mask) + outputs["loss_duration"] = loss_duration + return outputs, attn, attn_log_durations, log_durations + + def upsampling_z(self, z, slice_ids=None, y_lengths=None, y_mask=None): + spec_segment_size = self.spec_segment_size + if self.args.encoder_sample_rate: + # recompute the slices and spec_segment_size if needed + slice_ids = slice_ids * int(self.interpolate_factor) if slice_ids is not None else slice_ids + spec_segment_size = spec_segment_size * int(self.interpolate_factor) + # interpolate z if needed + if self.args.interpolate_z: + z = torch.nn.functional.interpolate(z, scale_factor=[self.interpolate_factor], mode="linear").squeeze(0) + # recompute the mask if needed + if y_lengths is not None and y_mask is not None: + y_mask = ( + sequence_mask(y_lengths * self.interpolate_factor, None).to(y_mask.dtype).unsqueeze(1) + ) # [B, 1, T_dec_resampled] + + return z, spec_segment_size, slice_ids, y_mask + + def forward( # pylint: disable=dangerous-default-value + self, + x: torch.tensor, + x_lengths: torch.tensor, + y: torch.tensor, + y_lengths: torch.tensor, + waveform: torch.tensor, + current_mas_noise_scale, + aux_input={"d_vectors": None, "speaker_ids": None, "language_ids": None}, + ) -> Dict: + """Forward pass of the model. + + Args: + x (torch.tensor): Batch of input character sequence IDs. + x_lengths (torch.tensor): Batch of input character sequence lengths. + y (torch.tensor): Batch of input spectrograms. + y_lengths (torch.tensor): Batch of input spectrogram lengths. + waveform (torch.tensor): Batch of ground truth waveforms per sample. + current_mas_noise_scale (float): Current MAS noise scale. + aux_input (dict, optional): Auxiliary inputs for multi-speaker and multi-lingual training. + Defaults to {"d_vectors": None, "speaker_ids": None, "language_ids": None}. + + Returns: + Dict: model outputs keyed by the output name. + + Shapes: + - x: :math:`[B, T_seq]` + - x_lengths: :math:`[B]` + - y: :math:`[B, C, T_spec]` + - y_lengths: :math:`[B]` + - waveform: :math:`[B, 1, T_wav]` + - d_vectors: :math:`[B, C, 1]` + - speaker_ids: :math:`[B]` + - language_ids: :math:`[B]` + + Return Shapes: + - model_outputs: :math:`[B, 1, T_wav]` + - alignments: :math:`[B, T_seq, T_dec]` + - z: :math:`[B, C, T_dec]` + - z_p: :math:`[B, C, T_dec]` + - m_p: :math:`[B, C, T_dec]` + - logs_p: :math:`[B, C, T_dec]` + - m_q: :math:`[B, C, T_dec]` + - logs_q: :math:`[B, C, T_dec]` + - waveform_seg: :math:`[B, 1, spec_seg_size * hop_length]` + - gt_spk_emb: :math:`[B, 1, speaker_encoder.proj_dim]` + - syn_spk_emb: :math:`[B, 1, speaker_encoder.proj_dim]` + - hidden_encoded_text: :math:`[B, T_seq, hidden_channels]` + - hidden_encoded_text_mask: :math:`[B, 1, T_seq]` + """ + outputs = {} + sid, g, lid, _ = self._set_cond_input(aux_input) + # speaker embedding + if self.args.use_speaker_embedding and sid is not None: + g = self.emb_g(sid).unsqueeze(-1) # [b, h, 1] + + # language embedding + lang_emb = None + if self.args.use_language_embedding and lid is not None: + lang_emb = self.emb_l(lid).unsqueeze(-1) + + x, m_p, logs_p, x_mask = self.text_encoder(x, x_lengths, lang_emb=lang_emb) + + # posterior encoder + z, m_q, logs_q, y_mask = self.posterior_encoder(y, y_lengths, g=g) + + # flow layers + z_p = self.flow(z, y_mask, g=g) + + # duration predictor + self.current_mas_noise_scale = current_mas_noise_scale + outputs, attn, attn_log_durations, log_durations = self.forward_mas(outputs, z_p, m_p, logs_p, x, x_mask, y_mask, g=g, lang_emb=lang_emb) + + # expand prior + m_p = torch.einsum("klmn, kjm -> kjn", [attn, m_p]) + logs_p = torch.einsum("klmn, kjm -> kjn", [attn, logs_p]) + + # select a random feature segment for the waveform decoder + z_slice, slice_ids = rand_segments(z, y_lengths, self.spec_segment_size, let_short_samples=True, pad_short=True) + + # interpolate z if needed + z_slice, spec_segment_size, slice_ids, _ = self.upsampling_z(z_slice, slice_ids=slice_ids) + + o = self.waveform_decoder(z_slice, g=g) + + wav_seg = segment( + waveform, + slice_ids * self.config.audio.hop_length, + spec_segment_size * self.config.audio.hop_length, + pad_short=True, + ) + + if self.args.use_speaker_encoder_as_loss and self.speaker_manager.encoder is not None: + # concate generated and GT waveforms + wavs_batch = torch.cat((wav_seg, o), dim=0) + + # resample audio to speaker encoder sample_rate + # pylint: disable=W0105 + if self.audio_transform is not None: + wavs_batch = self.audio_transform(wavs_batch) + + pred_embs = self.speaker_manager.encoder.forward(wavs_batch, l2_norm=True) + + # split generated and GT speaker embeddings + gt_spk_emb, syn_spk_emb = torch.chunk(pred_embs, 2, dim=0) + else: + gt_spk_emb, syn_spk_emb = None, None + + outputs.update( + { + "model_outputs": o, + "alignments": attn.squeeze(1), + "m_p": m_p, + "logs_p": logs_p, + "z": z, + "z_p": z_p, + "m_q": m_q, + "logs_q": logs_q, + "waveform_seg": wav_seg, + "gt_spk_emb": gt_spk_emb, + "syn_spk_emb": syn_spk_emb, + "slice_ids": slice_ids, + "hidden_encoded_text": x, + "hidden_encoded_text_mask": x_mask, + "real_durations": attn_log_durations, + "predicted_durations": log_durations, + } + ) + return outputs + + @staticmethod + def _set_x_lengths(x, aux_input): + if "x_lengths" in aux_input and aux_input["x_lengths"] is not None: + return aux_input["x_lengths"] + return torch.tensor(x.shape[1:2]).to(x.device) + + @torch.no_grad() + def inference( + self, + x, + aux_input={"x_lengths": None, "d_vectors": None, "speaker_ids": None, "language_ids": None, "durations": None}, + ): # pylint: disable=dangerous-default-value + """ + Note: + To run in batch mode, provide `x_lengths` else model assumes that the batch size is 1. + + Shapes: + - x: :math:`[B, T_seq]` + - x_lengths: :math:`[B]` + - d_vectors: :math:`[B, C]` + - speaker_ids: :math:`[B]` + + Return Shapes: + - model_outputs: :math:`[B, 1, T_wav]` + - alignments: :math:`[B, T_seq, T_dec]` + - z: :math:`[B, C, T_dec]` + - z_p: :math:`[B, C, T_dec]` + - m_p: :math:`[B, C, T_dec]` + - logs_p: :math:`[B, C, T_dec]` + """ + sid, g, lid, durations = self._set_cond_input(aux_input) + x_lengths = self._set_x_lengths(x, aux_input) + + # speaker embedding + if self.args.use_speaker_embedding and sid is not None: + g = self.emb_g(sid).unsqueeze(-1) + + # language embedding + lang_emb = None + if self.args.use_language_embedding and lid is not None: + lang_emb = self.emb_l(lid).unsqueeze(-1) + + x, m_p, logs_p, x_mask = self.text_encoder(x, x_lengths, lang_emb=lang_emb) + + if durations is None: + if self.args.use_sdp: + logw = self.duration_predictor( + x, + x_mask, + g=g if self.args.condition_dp_on_speaker else None, + reverse=True, + noise_scale=self.inference_noise_scale_dp, + lang_emb=lang_emb, + ) + else: + logw = self.duration_predictor( + x, x_mask, g=g if self.args.condition_dp_on_speaker else None, lang_emb=lang_emb + ) + w = torch.exp(logw) * x_mask * self.length_scale + else: + assert durations.shape[-1] == x.shape[-1] + w = durations.unsqueeze(0) + + w_ceil = torch.ceil(w) + y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long() + y_mask = sequence_mask(y_lengths, None).to(x_mask.dtype).unsqueeze(1) # [B, 1, T_dec] + + attn_mask = x_mask * y_mask.transpose(1, 2) # [B, 1, T_enc] * [B, T_dec, 1] + attn = generate_path(w_ceil.squeeze(1), attn_mask.squeeze(1).transpose(1, 2)) + + m_p = torch.matmul(attn.transpose(1, 2), m_p.transpose(1, 2)).transpose(1, 2) + logs_p = torch.matmul(attn.transpose(1, 2), logs_p.transpose(1, 2)).transpose(1, 2) + + z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * self.inference_noise_scale + z = self.flow(z_p, y_mask, g=g, reverse=True) + + # upsampling if needed + z, _, _, y_mask = self.upsampling_z(z, y_lengths=y_lengths, y_mask=y_mask) + + o = self.waveform_decoder((z * y_mask)[:, :, : self.max_inference_len], g=g) + + outputs = { + "model_outputs": o, + "alignments": attn.squeeze(1), + "durations": w_ceil, + "z": z, + "z_p": z_p, + "m_p": m_p, + "logs_p": logs_p, + "y_mask": y_mask, + } + return outputs + + @torch.no_grad() + def inference_voice_conversion( + self, reference_wav, speaker_id=None, d_vector=None, reference_speaker_id=None, reference_d_vector=None + ): + """Inference for voice conversion + + Args: + reference_wav (Tensor): Reference wavform. Tensor of shape [B, T] + speaker_id (Tensor): speaker_id of the target speaker. Tensor of shape [B] + d_vector (Tensor): d_vector embedding of target speaker. Tensor of shape `[B, C]` + reference_speaker_id (Tensor): speaker_id of the reference_wav speaker. Tensor of shape [B] + reference_d_vector (Tensor): d_vector embedding of the reference_wav speaker. Tensor of shape `[B, C]` + """ + # compute spectrograms + y = wav_to_spec( + reference_wav, + self.config.audio.fft_size, + self.config.audio.hop_length, + self.config.audio.win_length, + center=False, + ) + y_lengths = torch.tensor([y.size(-1)]).to(y.device) + speaker_cond_src = reference_speaker_id if reference_speaker_id is not None else reference_d_vector + speaker_cond_tgt = speaker_id if speaker_id is not None else d_vector + wav, _, _ = self.voice_conversion(y, y_lengths, speaker_cond_src, speaker_cond_tgt) + return wav + + def voice_conversion(self, y, y_lengths, speaker_cond_src, speaker_cond_tgt): + """Forward pass for voice conversion + + TODO: create an end-point for voice conversion + + Args: + y (Tensor): Reference spectrograms. Tensor of shape [B, T, C] + y_lengths (Tensor): Length of each reference spectrogram. Tensor of shape [B] + speaker_cond_src (Tensor): Reference speaker ID. Tensor of shape [B,] + speaker_cond_tgt (Tensor): Target speaker ID. Tensor of shape [B,] + """ + assert self.num_speakers > 0, "num_speakers have to be larger than 0." + # speaker embedding + if self.args.use_speaker_embedding and not self.args.use_d_vector_file: + g_src = self.emb_g(torch.from_numpy((np.array(speaker_cond_src))).unsqueeze(0)).unsqueeze(-1) + g_tgt = self.emb_g(torch.from_numpy((np.array(speaker_cond_tgt))).unsqueeze(0)).unsqueeze(-1) + elif not self.args.use_speaker_embedding and self.args.use_d_vector_file: + g_src = F.normalize(speaker_cond_src).unsqueeze(-1) + g_tgt = F.normalize(speaker_cond_tgt).unsqueeze(-1) + else: + raise RuntimeError(" [!] Voice conversion is only supported on multi-speaker models.") + + z, _, _, y_mask = self.posterior_encoder(y, y_lengths, g=g_src) + z_p = self.flow(z, y_mask, g=g_src) + z_hat = self.flow(z_p, y_mask, g=g_tgt, reverse=True) + o_hat = self.waveform_decoder(z_hat * y_mask, g=g_tgt) + return o_hat, y_mask, (z, z_p, z_hat) + + def train_step(self, batch: dict, criterion: nn.Module, optimizer_idx: int) -> Tuple[Dict, Dict]: + """Perform a single training step. Run the model forward pass and compute losses. + + Args: + batch (Dict): Input tensors. + criterion (nn.Module): Loss layer designed for the model. + optimizer_idx (int): Index of optimizer to use. 0 for the generator and 1 for the discriminator networks. + + Returns: + Tuple[Dict, Dict]: Model ouputs and computed losses. + """ + + # spec_lens = batch["spec_lens"] + spec_lens = batch["mel_lens"] #vits2 + + if optimizer_idx == 0: + tokens = batch["tokens"] + token_lenghts = batch["token_lens"] + # spec = batch["spec"] + spec = batch["mel"] #vits2 + + d_vectors = batch["d_vectors"] + speaker_ids = batch["speaker_ids"] + language_ids = batch["language_ids"] + waveform = batch["waveform"] + + # generator pass + outputs = self.forward( + tokens, + token_lenghts, + spec, + spec_lens, + waveform, + aux_input={"d_vectors": d_vectors, "speaker_ids": speaker_ids, "language_ids": language_ids}, + ) + + # cache tensors for the generator pass + self.model_outputs_cache = outputs # pylint: disable=attribute-defined-outside-init + + # compute scores and features + scores_disc_fake, _, scores_disc_real, _ = self.disc( + outputs["model_outputs"].detach(), outputs["waveform_seg"] + ) + + # compute loss + with autocast(enabled=False): # use float32 for the criterion + loss_dict = criterion[optimizer_idx]( + scores_disc_real, + scores_disc_fake, + ) + return outputs, loss_dict + + if optimizer_idx == 1: + mel = batch["mel"] + + # compute melspec segment + with autocast(enabled=False): + if self.args.encoder_sample_rate: + spec_segment_size = self.spec_segment_size * int(self.interpolate_factor) + else: + spec_segment_size = self.spec_segment_size + + mel_slice = segment( + mel.float(), self.model_outputs_cache["slice_ids"], spec_segment_size, pad_short=True + ) + mel_slice_hat = wav_to_mel( + y=self.model_outputs_cache["model_outputs"].float(), + n_fft=self.config.audio.fft_size, + sample_rate=self.config.audio.sample_rate, + num_mels=self.config.audio.num_mels, + hop_length=self.config.audio.hop_length, + win_length=self.config.audio.win_length, + fmin=self.config.audio.mel_fmin, + fmax=self.config.audio.mel_fmax, + center=False, + ) + + # compute discriminator scores and features + scores_disc_fake, feats_disc_fake, _, feats_disc_real = self.disc( + self.model_outputs_cache["model_outputs"], self.model_outputs_cache["waveform_seg"] + ) + + # compute losses + with autocast(enabled=False): # use float32 for the criterion + loss_dict = criterion[optimizer_idx]( + mel_slice_hat=mel_slice.float(), + mel_slice=mel_slice_hat.float(), + z_p=self.model_outputs_cache["z_p"].float(), + logs_q=self.model_outputs_cache["logs_q"].float(), + m_p=self.model_outputs_cache["m_p"].float(), + logs_p=self.model_outputs_cache["logs_p"].float(), + z_len=spec_lens, + scores_disc_fake=scores_disc_fake, + feats_disc_fake=feats_disc_fake, + feats_disc_real=feats_disc_real, + loss_duration=self.model_outputs_cache["loss_duration"], + use_speaker_encoder_as_loss=self.args.use_speaker_encoder_as_loss, + gt_spk_emb=self.model_outputs_cache["gt_spk_emb"], + syn_spk_emb=self.model_outputs_cache["syn_spk_emb"], + ) + + return self.model_outputs_cache, loss_dict + + if optimizer_idx == 2: + output_prob_for_real, output_probs_for_pred = self.dur_disc( + self.model_outputs_cache['hidden_encoded_text'], + self.model_outputs_cache['hidden_encoded_text_mask'], + self.model_outputs_cache['real_durations'], #logscaled + self.model_outputs_cache['predicted_durations'] #logscaled + ) + + outputs = { + "hidden_encoded_text" : self.model_outputs_cache['hidden_encoded_text'], + "hidden_encoded_text_mask" : self.model_outputs_cache['hidden_encoded_text_mask'], + "real_durations" : self.model_outputs_cache['real_durations'], #logscaled + "predicted_durations" : self.model_outputs_cache['predicted_durations'] #logscaled + } + with autocast(enabled=False): + loss_dict = criterion[optimizer_idx]( + output_prob_for_real, + output_probs_for_pred, + ) + return outputs, loss_dict + + raise ValueError(" [!] Unexpected `optimizer_idx`.") + + def _log(self, ap, batch, outputs, name_prefix="train"): # pylint: disable=unused-argument,no-self-use + y_hat = outputs[1]["model_outputs"] + y = outputs[1]["waveform_seg"] + figures = plot_results(y_hat, y, ap, name_prefix) + sample_voice = y_hat[0].squeeze(0).detach().cpu().numpy() + audios = {f"{name_prefix}/audio": sample_voice} + + alignments = outputs[1]["alignments"] + align_img = alignments[0].data.cpu().numpy().T + + figures.update( + { + "alignment": plot_alignment(align_img, output_fig=False), + } + ) + return figures, audios + + def train_log( + self, batch: dict, outputs: dict, logger: "Logger", assets: dict, steps: int + ): # pylint: disable=no-self-use + """Create visualizations and waveform examples. + + For example, here you can plot spectrograms and generate sample sample waveforms from these spectrograms to + be projected onto Tensorboard. + + Args: + ap (AudioProcessor): audio processor used at training. + batch (Dict): Model inputs used at the previous training step. + outputs (Dict): Model outputs generated at the previoud training step. + + Returns: + Tuple[Dict, np.ndarray]: training plots and output waveform. + """ + figures, audios = self._log(self.ap, batch, outputs, "train") + logger.train_figures(steps, figures) + logger.train_audios(steps, audios, self.ap.sample_rate) + + @torch.no_grad() + def eval_step(self, batch: dict, criterion: nn.Module, optimizer_idx: int): + return self.train_step(batch, criterion, optimizer_idx) + + def eval_log(self, batch: dict, outputs: dict, logger: "Logger", assets: dict, steps: int) -> None: + figures, audios = self._log(self.ap, batch, outputs, "eval") + logger.eval_figures(steps, figures) + logger.eval_audios(steps, audios, self.ap.sample_rate) + + def get_aux_input_from_test_sentences(self, sentence_info): + if hasattr(self.config, "model_args"): + config = self.config.model_args + else: + config = self.config + + # extract speaker and language info + text, speaker_name, style_wav, language_name = None, None, None, None + + if isinstance(sentence_info, list): + if len(sentence_info) == 1: + text = sentence_info[0] + elif len(sentence_info) == 2: + text, speaker_name = sentence_info + elif len(sentence_info) == 3: + text, speaker_name, style_wav = sentence_info + elif len(sentence_info) == 4: + text, speaker_name, style_wav, language_name = sentence_info + else: + text = sentence_info + + # get speaker id/d_vector + speaker_id, d_vector, language_id = None, None, None + if hasattr(self, "speaker_manager"): + if config.use_d_vector_file: + if speaker_name is None: + d_vector = self.speaker_manager.get_random_embedding() + else: + d_vector = self.speaker_manager.get_mean_embedding(speaker_name, num_samples=None, randomize=False) + elif config.use_speaker_embedding: + if speaker_name is None: + speaker_id = self.speaker_manager.get_random_id() + else: + speaker_id = self.speaker_manager.name_to_id[speaker_name] + + # get language id + if hasattr(self, "language_manager") and config.use_language_embedding and language_name is not None: + language_id = self.language_manager.name_to_id[language_name] + + return { + "text": text, + "speaker_id": speaker_id, + "style_wav": style_wav, + "d_vector": d_vector, + "language_id": language_id, + "language_name": language_name, + } + + @torch.no_grad() + def test_run(self, assets) -> Tuple[Dict, Dict]: + """Generic test run for `tts` models used by `Trainer`. + + You can override this for a different behaviour. + + Returns: + Tuple[Dict, Dict]: Test figures and audios to be projected to Tensorboard. + """ + print(" | > Synthesizing test sentences.") + test_audios = {} + test_figures = {} + test_sentences = self.config.test_sentences + for idx, s_info in enumerate(test_sentences): + aux_inputs = self.get_aux_input_from_test_sentences(s_info) + wav, alignment, _, _ = synthesis( + self, + aux_inputs["text"], + self.config, + "cuda" in str(next(self.parameters()).device), + speaker_id=aux_inputs["speaker_id"], + d_vector=aux_inputs["d_vector"], + style_wav=aux_inputs["style_wav"], + language_id=aux_inputs["language_id"], + use_griffin_lim=True, + do_trim_silence=False, + ).values() + test_audios["{}-audio".format(idx)] = wav + test_figures["{}-alignment".format(idx)] = plot_alignment(alignment.T, output_fig=False) + return {"figures": test_figures, "audios": test_audios} + + def test_log( + self, outputs: dict, logger: "Logger", assets: dict, steps: int # pylint: disable=unused-argument + ) -> None: + logger.test_audios(steps, outputs["audios"], self.ap.sample_rate) + logger.test_figures(steps, outputs["figures"]) + + def format_batch(self, batch: Dict) -> Dict: + """Compute speaker, langugage IDs and d_vector for the batch if necessary.""" + speaker_ids = None + language_ids = None + d_vectors = None + + # get numerical speaker ids from speaker names + if self.speaker_manager is not None and self.speaker_manager.name_to_id and self.args.use_speaker_embedding: + speaker_ids = [self.speaker_manager.name_to_id[sn] for sn in batch["speaker_names"]] + + if speaker_ids is not None: + speaker_ids = torch.LongTensor(speaker_ids) + + # get d_vectors from audio file names + if self.speaker_manager is not None and self.speaker_manager.embeddings and self.args.use_d_vector_file: + d_vector_mapping = self.speaker_manager.embeddings + d_vectors = [d_vector_mapping[w]["embedding"] for w in batch["audio_unique_names"]] + d_vectors = torch.FloatTensor(d_vectors) + + # get language ids from language names + if self.language_manager is not None and self.language_manager.name_to_id and self.args.use_language_embedding: + language_ids = [self.language_manager.name_to_id[ln] for ln in batch["language_names"]] + + if language_ids is not None: + language_ids = torch.LongTensor(language_ids) + + batch["language_ids"] = language_ids + batch["d_vectors"] = d_vectors + batch["speaker_ids"] = speaker_ids + return batch + + def format_batch_on_device(self, batch): + """Compute spectrograms on the device.""" + ac = self.config.audio + + if self.args.encoder_sample_rate: + wav = self.audio_resampler(batch["waveform"]) + else: + wav = batch["waveform"] + + # compute spectrograms + batch["spec"] = wav_to_spec(wav, ac.fft_size, ac.hop_length, ac.win_length, center=False) + + if self.args.encoder_sample_rate: + # recompute spec with high sampling rate to the loss + spec_mel = wav_to_spec(batch["waveform"], ac.fft_size, ac.hop_length, ac.win_length, center=False) + # remove extra stft frames if needed + if spec_mel.size(2) > int(batch["spec"].size(2) * self.interpolate_factor): + spec_mel = spec_mel[:, :, : int(batch["spec"].size(2) * self.interpolate_factor)] + else: + batch["spec"] = batch["spec"][:, :, : int(spec_mel.size(2) / self.interpolate_factor)] + else: + spec_mel = batch["spec"] + + batch["mel"] = spec_to_mel( + spec=spec_mel, + n_fft=ac.fft_size, + num_mels=ac.num_mels, + sample_rate=ac.sample_rate, + fmin=ac.mel_fmin, + fmax=ac.mel_fmax, + ) + + if self.args.encoder_sample_rate: + assert batch["spec"].shape[2] == int( + batch["mel"].shape[2] / self.interpolate_factor + ), f"{batch['spec'].shape[2]}, {batch['mel'].shape[2]}" + else: + assert batch["spec"].shape[2] == batch["mel"].shape[2], f"{batch['spec'].shape[2]}, {batch['mel'].shape[2]}" + + # compute spectrogram frame lengths + batch["spec_lens"] = (batch["spec"].shape[2] * batch["waveform_rel_lens"]).int() + batch["mel_lens"] = (batch["mel"].shape[2] * batch["waveform_rel_lens"]).int() + + if self.args.encoder_sample_rate: + assert (batch["spec_lens"] - (batch["mel_lens"] / self.interpolate_factor).int()).sum() == 0 + else: + assert (batch["spec_lens"] - batch["mel_lens"]).sum() == 0 + + # zero the padding frames + batch["spec"] = batch["spec"] * sequence_mask(batch["spec_lens"]).unsqueeze(1) + batch["mel"] = batch["mel"] * sequence_mask(batch["mel_lens"]).unsqueeze(1) + return batch + + def get_sampler(self, config: Coqpit, dataset: TTSDataset, num_gpus=1, is_eval=False): + weights = None + data_items = dataset.samples + if getattr(config, "use_weighted_sampler", False): + for attr_name, alpha in config.weighted_sampler_attrs.items(): + print(f" > Using weighted sampler for attribute '{attr_name}' with alpha '{alpha}'") + multi_dict = config.weighted_sampler_multipliers.get(attr_name, None) + print(multi_dict) + weights, attr_names, attr_weights = get_attribute_balancer_weights( + attr_name=attr_name, items=data_items, multi_dict=multi_dict + ) + weights = weights * alpha + print(f" > Attribute weights for '{attr_names}' \n | > {attr_weights}") + + # input_audio_lenghts = [os.path.getsize(x["audio_file"]) for x in data_items] + + if weights is not None: + w_sampler = WeightedRandomSampler(weights, len(weights)) + batch_sampler = BucketBatchSampler( + w_sampler, + data=data_items, + batch_size=config.eval_batch_size if is_eval else config.batch_size, + sort_key=lambda x: os.path.getsize(x["audio_file"]), + drop_last=True, + ) + else: + batch_sampler = None + # sampler for DDP + if batch_sampler is None: + batch_sampler = DistributedSampler(dataset) if num_gpus > 1 else None + else: # If a sampler is already defined use this sampler and DDP sampler together + batch_sampler = ( + DistributedSamplerWrapper(batch_sampler) if num_gpus > 1 else batch_sampler + ) # TODO: check batch_sampler with multi-gpu + return batch_sampler + + def get_data_loader( + self, + config: Coqpit, + assets: Dict, + is_eval: bool, + samples: Union[List[Dict], List[List]], + verbose: bool, + num_gpus: int, + rank: int = None, + ) -> "DataLoader": + if is_eval and not config.run_eval: + loader = None + else: + # init dataloader + dataset = Vits2Dataset( + model_args=self.args, + samples=samples, + batch_group_size=0 if is_eval else config.batch_group_size * config.batch_size, + min_text_len=config.min_text_len, + max_text_len=config.max_text_len, + min_audio_len=config.min_audio_len, + max_audio_len=config.max_audio_len, + phoneme_cache_path=config.phoneme_cache_path, + precompute_num_workers=config.precompute_num_workers, + verbose=verbose, + tokenizer=self.tokenizer, + start_by_longest=config.start_by_longest, + ) + + # wait all the DDP process to be ready + if num_gpus > 1: + dist.barrier() + + # sort input sequences from short to long + dataset.preprocess_samples() + + # get samplers + sampler = self.get_sampler(config, dataset, num_gpus) + if sampler is None: + loader = DataLoader( + dataset, + batch_size=config.eval_batch_size if is_eval else config.batch_size, + shuffle=False, # shuffle is done in the dataset. + collate_fn=dataset.collate_fn, + drop_last=False, # setting this False might cause issues in AMP training. + num_workers=config.num_eval_loader_workers if is_eval else config.num_loader_workers, + pin_memory=False, + ) + else: + if num_gpus > 1: + loader = DataLoader( + dataset, + sampler=sampler, + batch_size=config.eval_batch_size if is_eval else config.batch_size, + collate_fn=dataset.collate_fn, + num_workers=config.num_eval_loader_workers if is_eval else config.num_loader_workers, + pin_memory=False, + ) + else: + loader = DataLoader( + dataset, + batch_sampler=sampler, + collate_fn=dataset.collate_fn, + num_workers=config.num_eval_loader_workers if is_eval else config.num_loader_workers, + pin_memory=False, + ) + return loader + + def get_optimizer(self) -> List: + """Initiate and return the GAN optimizers based on the config parameters. + It returnes 3 optimizers in a list. First one is for the generator and the second one is for the discriminator. + Returns: + List: optimizers. + """ + # select generator parameters + optimizer0 = get_optimizer(self.config.optimizer, self.config.optimizer_params, self.config.lr_disc, self.disc) + + gen_parameters = chain(params for k, params in self.named_parameters() if not k.startswith("disc.")) + optimizer1 = get_optimizer( + self.config.optimizer, self.config.optimizer_params, self.config.lr_gen, parameters=gen_parameters + ) + optimizer2 = get_optimizer(self.config.optimizer, self.config.optimizer_params, self.config.lr_dur, self.dur_disc) + return [optimizer0, optimizer1, optimizer2] + + def get_lr(self) -> List: + """Set the initial learning rates for each optimizer. + + Returns: + List: learning rates for each optimizer. + """ + return [self.config.lr_disc, self.config.lr_gen, self.config.lr_dur] + + def get_scheduler(self, optimizer) -> List: + """Set the schedulers for each optimizer. + + Args: + optimizer (List[`torch.optim.Optimizer`]): List of optimizers. + + Returns: + List: Schedulers, one for each optimizer. + """ + scheduler_D = get_scheduler(self.config.lr_scheduler_disc, self.config.lr_scheduler_disc_params, optimizer[0]) + scheduler_G = get_scheduler(self.config.lr_scheduler_gen, self.config.lr_scheduler_gen_params, optimizer[1]) + scheduler_DUR = get_scheduler(self.config.lr_scheduler_dur, self.config.lr_scheduler_dur_params, optimizer[2]) + return [scheduler_D, scheduler_G] + + def get_criterion(self): + """Get criterions for each optimizer. The index in the output list matches the optimizer idx used in + `train_step()`""" + from TTS.tts.layers.losses import ( # pylint: disable=import-outside-toplevel + VitsDiscriminatorLoss, + VitsGeneratorLoss, + Vits2DurationLoss + ) + + return [VitsDiscriminatorLoss(self.config), VitsGeneratorLoss(self.config), Vits2DurationLoss(self.config)] + + def on_train_step_start(self, trainer): + """MAS noise scale.""" + current_mas_noise_scale = self.mas_noise_scale_initial - self.noise_scale_delta * trainer.epochs_done + #TODO is steps = epochs done * batch size * deivces? + self.current_mas_noise_scale = max(current_mas_noise_scale, 0.0) + + def load_checkpoint( + self, config, checkpoint_path, eval=False, strict=True, cache=False + ): # pylint: disable=unused-argument, redefined-builtin + """Load the model checkpoint and setup for training or inference""" + state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"), cache=cache) + # compat band-aid for the pre-trained models to not use the encoder baked into the model + # TODO: consider baking the speaker encoder into the model and call it from there. + # as it is probably easier for model distribution. + state["model"] = {k: v for k, v in state["model"].items() if "speaker_encoder" not in k} + + if self.args.encoder_sample_rate is not None and eval: + # audio resampler is not used in inference time + self.audio_resampler = None + + # handle fine-tuning from a checkpoint with additional speakers + if hasattr(self, "emb_g") and state["model"]["emb_g.weight"].shape != self.emb_g.weight.shape: + num_new_speakers = self.emb_g.weight.shape[0] - state["model"]["emb_g.weight"].shape[0] + print(f" > Loading checkpoint with {num_new_speakers} additional speakers.") + emb_g = state["model"]["emb_g.weight"] + new_row = torch.randn(num_new_speakers, emb_g.shape[1]) + emb_g = torch.cat([emb_g, new_row], axis=0) + state["model"]["emb_g.weight"] = emb_g + # load the model weights + self.load_state_dict(state["model"], strict=strict) + + if eval: + self.eval() + assert not self.training + + # def load_fairseq_checkpoint( + # self, config, checkpoint_dir, eval=False, strict=True + # ): # pylint: disable=unused-argument, redefined-builtin + # """Load VITS checkpoints released by fairseq here: https://github.com/facebookresearch/fairseq/tree/main/examples/mms + # Performs some changes for compatibility. + + # Args: + # config (Coqpit): 🐸TTS model config. + # checkpoint_dir (str): Path to the checkpoint directory. + # eval (bool, optional): Set to True for evaluation. Defaults to False. + # """ + # import json + + # from TTS.tts.utils.text.cleaners import basic_cleaners + + # self.disc = None + # # set paths + # config_file = os.path.join(checkpoint_dir, "config.json") + # checkpoint_file = os.path.join(checkpoint_dir, "G_100000.pth") + # vocab_file = os.path.join(checkpoint_dir, "vocab.txt") + # # set config params + # with open(config_file, "r", encoding="utf-8") as file: + # # Load the JSON data as a dictionary + # config_org = json.load(file) + # self.config.audio.sample_rate = config_org["data"]["sampling_rate"] + # # self.config.add_blank = config['add_blank'] + # # set tokenizer + # vocab = FairseqVocab(vocab_file) + # self.text_encoder.emb = nn.Embedding(vocab.num_chars, config.model_args.hidden_channels) + # self.tokenizer = TTSTokenizer( + # use_phonemes=False, + # text_cleaner=basic_cleaners, + # characters=vocab, + # phonemizer=None, + # add_blank=config_org["data"]["add_blank"], + # use_eos_bos=False, + # ) + # # load fairseq checkpoint + # new_chk = rehash_fairseq_vits_checkpoint(checkpoint_file) + # self.load_state_dict(new_chk, strict=strict) + # if eval: + # self.eval() + # assert not self.training + + @staticmethod + def init_from_config(config: "Vits2Config", samples: Union[List[List], List[Dict]] = None, verbose=True): + """Initiate model from config + + Args: + config (VitsConfig): Model config. + samples (Union[List[List], List[Dict]]): Training samples to parse speaker ids for training. + Defaults to None. + """ + from TTS.utils.audio import AudioProcessor + + upsample_rate = torch.prod(torch.as_tensor(config.model_args.upsample_rates_decoder)).item() + + if not config.model_args.encoder_sample_rate: + assert ( + upsample_rate == config.audio.hop_length + ), f" [!] Product of upsample rates must be equal to the hop length - {upsample_rate} vs {config.audio.hop_length}" + else: + encoder_to_vocoder_upsampling_factor = config.audio.sample_rate / config.model_args.encoder_sample_rate + effective_hop_length = config.audio.hop_length * encoder_to_vocoder_upsampling_factor + assert ( + upsample_rate == effective_hop_length + ), f" [!] Product of upsample rates must be equal to the hop length - {upsample_rate} vs {effective_hop_length}" + + ap = AudioProcessor.init_from_config(config, verbose=verbose) + tokenizer, new_config = TTSTokenizer.init_from_config(config) + speaker_manager = SpeakerManager.init_from_config(config, samples) + language_manager = LanguageManager.init_from_config(config) + + if config.model_args.speaker_encoder_model_path: + speaker_manager.init_encoder( + config.model_args.speaker_encoder_model_path, config.model_args.speaker_encoder_config_path + ) + return Vits2(new_config, ap, tokenizer, speaker_manager, language_manager) + + def export_onnx(self, output_path: str = "coqui_vits.onnx", verbose: bool = True): + """Export model to ONNX format for inference + + Args: + output_path (str): Path to save the exported model. + verbose (bool): Print verbose information. Defaults to True. + """ + + # rollback values + _forward = self.forward + disc = None + if hasattr(self, "disc"): + disc = self.disc + training = self.training + + # set export mode + self.disc = None + self.eval() + + def onnx_inference(text, text_lengths, scales, sid=None, langid=None): + noise_scale = scales[0] + length_scale = scales[1] + noise_scale_dp = scales[2] + self.noise_scale = noise_scale + self.length_scale = length_scale + self.noise_scale_dp = noise_scale_dp + return self.inference( + text, + aux_input={ + "x_lengths": text_lengths, + "d_vectors": None, + "speaker_ids": sid, + "language_ids": langid, + "durations": None, + }, + )["model_outputs"] + + self.forward = onnx_inference + + # set dummy inputs + dummy_input_length = 100 + sequences = torch.randint(low=0, high=2, size=(1, dummy_input_length), dtype=torch.long) + sequence_lengths = torch.LongTensor([sequences.size(1)]) + scales = torch.FloatTensor([self.inference_noise_scale, self.length_scale, self.inference_noise_scale_dp]) + dummy_input = (sequences, sequence_lengths, scales) + input_names = ["input", "input_lengths", "scales"] + + if self.num_speakers > 0: + speaker_id = torch.LongTensor([0]) + dummy_input += (speaker_id, ) + input_names.append("sid") + + if hasattr(self, 'num_languages') and self.num_languages > 0 and self.embedded_language_dim > 0: + language_id = torch.LongTensor([0]) + dummy_input += (language_id, ) + input_names.append("langid") + + # export to ONNX + torch.onnx.export( + model=self, + args=dummy_input, + opset_version=15, + f=output_path, + verbose=verbose, + input_names=input_names, + output_names=["output"], + dynamic_axes={ + "input": {0: "batch_size", 1: "phonemes"}, + "input_lengths": {0: "batch_size"}, + "output": {0: "batch_size", 1: "time1", 2: "time2"}, + }, + ) + + # rollback + self.forward = _forward + if training: + self.train() + if not disc is None: + self.disc = disc + + def load_onnx(self, model_path: str, cuda=False): + import onnxruntime as ort + + providers = [ + "CPUExecutionProvider" + if cuda is False + else ("CUDAExecutionProvider", {"cudnn_conv_algo_search": "DEFAULT"}) + ] + sess_options = ort.SessionOptions() + self.onnx_sess = ort.InferenceSession( + model_path, + sess_options=sess_options, + providers=providers, + ) + + def inference_onnx(self, x, x_lengths=None, speaker_id=None, language_id=None): + """ONNX inference""" + + if isinstance(x, torch.Tensor): + x = x.cpu().numpy() + + if x_lengths is None: + x_lengths = np.array([x.shape[1]], dtype=np.int64) + + if isinstance(x_lengths, torch.Tensor): + x_lengths = x_lengths.cpu().numpy() + scales = np.array( + [self.inference_noise_scale, self.length_scale, self.inference_noise_scale_dp], + dtype=np.float32, + ) + input_params = { + "input": x, + "input_lengths": x_lengths, + "scales": scales + } + if not speaker_id is None: + input_params["sid"] = torch.tensor([speaker_id]).cpu().numpy() + if not language_id is None: + input_params["langid"] = torch.tensor([language_id]).cpu().numpy() + + audio = self.onnx_sess.run( + ["output"], + input_params, + ) + return audio[0][0] + + +################################## +# VITS-2 CHARACTERS +################################## + + +class Vits2Characters(BaseCharacters): + """Characters class for VITs model for compatibility with pre-trained models""" + + def __init__( + self, + graphemes: str = _characters, + punctuations: str = _punctuations, + pad: str = _pad, + ipa_characters: str = _phonemes, + ) -> None: + if ipa_characters is not None: + graphemes += ipa_characters + super().__init__(graphemes, punctuations, pad, None, None, "", is_unique=False, is_sorted=True) + + def _create_vocab(self): + self._vocab = [self._pad] + list(self._punctuations) + list(self._characters) + [self._blank] + self._char_to_id = {char: idx for idx, char in enumerate(self.vocab)} + # pylint: disable=unnecessary-comprehension + self._id_to_char = {idx: char for idx, char in enumerate(self.vocab)} + + @staticmethod + def init_from_config(config: Coqpit): + if config.characters is not None: + _pad = config.characters["pad"] + _punctuations = config.characters["punctuations"] + _letters = config.characters["characters"] + _letters_ipa = config.characters["phonemes"] + return ( + Vits2Characters(graphemes=_letters, ipa_characters=_letters_ipa, punctuations=_punctuations, pad=_pad), + config, + ) + characters = Vits2Characters() + new_config = replace(config, characters=characters.to_config()) + return characters, new_config + + def to_config(self) -> "CharactersConfig": + return CharactersConfig( + characters=self._characters, + punctuations=self._punctuations, + pad=self._pad, + eos=None, + bos=None, + blank=self._blank, + is_unique=False, + is_sorted=True, + ) + + +class FairseqVocab(BaseVocabulary): + def __init__(self, vocab: str): + super(FairseqVocab).__init__() + self.vocab = vocab + + @property + def vocab(self): + """Return the vocabulary dictionary.""" + return self._vocab + + @vocab.setter + def vocab(self, vocab_file): + with open(vocab_file, encoding="utf-8") as f: + self._vocab = [x.replace("\n", "") for x in f.readlines()] + self.blank = self._vocab[0] + self.pad = " " + self._char_to_id = {s: i for i, s in enumerate(self._vocab)} # pylint: disable=unnecessary-comprehension + self._id_to_char = {i: s for i, s in enumerate(self._vocab)} # pylint: disable=unnecessary-comprehension