Include improvements from rhasspy/piper#476.
rmcpantoja committed Apr 25, 2024
1 parent 0ecfc14 commit 445af0a
Showing 2 changed files with 40 additions and 28 deletions.
22 changes: 15 additions & 7 deletions src/python/piper_train/__main__.py
@@ -5,7 +5,7 @@
 
 import torch
 from pytorch_lightning import Trainer
-from pytorch_lightning.callbacks import ModelCheckpoint
+from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
 
 from .vits.lightning import VitsModel

@@ -27,6 +27,11 @@ def main():
         type=int,
         help="Save checkpoint every N epochs (default: 1)",
     )
+    parser.add_argument(
+        "--patience",
+        type=int,
+        help="Number of validation cycles to allow to pass without improvement before stopping training"
+    )
     parser.add_argument(
         "--quality",
         default="medium",
@@ -76,19 +81,22 @@ def main():
     num_speakers = int(config["num_speakers"])
     sample_rate = int(config["audio"]["sample_rate"])
 
-    trainer = Trainer.from_argparse_args(args)
+    callbacks = []
     if args.checkpoint_epochs is not None:
-        trainer.callbacks = [ModelCheckpoint(
-            every_n_epochs=args.checkpoint_epochs,
-            save_top_k=args.num_ckpt,
-            save_last=args.save_last
-        )]
+        callbacks.append(
+            ModelCheckpoint(every_n_epochs=args.checkpoint_epochs, monitor="val_loss", save_top_k=args.num_ckpt, save_last=args.save_last, mode="min")
+        )
         _LOGGER.debug(
             "Checkpoints will be saved every %s epoch(s)", args.checkpoint_epochs
         )
         _LOGGER.debug(
             "%s Checkpoints will be saved", args.num_ckpt
         )
+    if args.patience is not None:
+        callbacks.append(
+            EarlyStopping(monitor="val_loss", min_delta=0.00, patience=args.patience, verbose=False, mode="min")
+        )
+    trainer = Trainer.from_argparse_args(args, callbacks=callbacks)
 
     dict_args = vars(args)
     if args.quality == "x-low":
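In short, callbacks are now collected in a list and handed to the Trainer at construction time, rather than assigned to trainer.callbacks after the fact, which overwrote the callback list Lightning had already assembled. Below is a minimal, self-contained sketch of the same wiring, assuming the pytorch_lightning 1.x API (Trainer.from_argparse_args) that piper_train targets; the parser setup is illustrative, with flag names mirroring the diff above.

from argparse import ArgumentParser

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint

parser = ArgumentParser()
parser.add_argument("--checkpoint-epochs", type=int)
parser.add_argument("--num-ckpt", type=int, default=1)
parser.add_argument("--save-last", action="store_true")
parser.add_argument("--patience", type=int)
parser = Trainer.add_argparse_args(parser)
args = parser.parse_args()

callbacks = []
if args.checkpoint_epochs is not None:
    # Keep the best args.num_ckpt checkpoints, ranked by validation loss.
    callbacks.append(
        ModelCheckpoint(
            every_n_epochs=args.checkpoint_epochs,
            monitor="val_loss",
            save_top_k=args.num_ckpt,
            save_last=args.save_last,
            mode="min",
        )
    )
if args.patience is not None:
    # Stop training once val_loss has not improved for args.patience validation cycles.
    callbacks.append(
        EarlyStopping(monitor="val_loss", min_delta=0.00, patience=args.patience, mode="min")
    )

trainer = Trainer.from_argparse_args(args, callbacks=callbacks)

With --patience unset, no EarlyStopping callback is attached and training behavior is unchanged.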
46 changes: 25 additions & 21 deletions src/python/piper_train/vits/lightning.py
@@ -428,30 +428,34 @@ def validation_step(self, batch: Batch, batch_idx: int):
         val_loss = self.training_step_g(batch) + self.training_step_d(batch) + self.training_step_dur(batch)
         self.log("val_loss", val_loss)
         print(f"Epoch: {self.current_epoch}. Steps: {self.global_step}. Validation loss: {val_loss}")
-        # Generate audio examples
-        for utt_idx, test_utt in enumerate(self._test_dataset):
-            text = test_utt.phoneme_ids.unsqueeze(0).to(self.device)
-            text_lengths = torch.LongTensor([len(test_utt.phoneme_ids)]).to(self.device)
-            scales = [1.0, 1.0, 1.0]
-            sid = (
-                test_utt.speaker_id.to(self.device)
-                if test_utt.speaker_id is not None
-                else None
-            )
-            test_audio = self(text, text_lengths, scales, sid=sid).detach()
+        return val_loss
 
-            # Scale to make louder in [-1, 1]
-            test_audio = test_audio * (1.0 / max(0.01, abs(test_audio.max())))
+    def on_validation_end(self) -> None:
+        # Generate audio examples after validation, but not during sanity check
+        if not self.trainer.sanity_checking:
+            for utt_idx, test_utt in enumerate(self._test_dataset):
+                text = test_utt.phoneme_ids.unsqueeze(0).to(self.device)
+                text_lengths = torch.LongTensor([len(test_utt.phoneme_ids)]).to(self.device)
+                scales = [1.0, 1.0, 1.0]
+                sid = (
+                    test_utt.speaker_id.to(self.device)
+                    if test_utt.speaker_id is not None
+                    else None
+                )
+                test_audio = self(text, text_lengths, scales, sid=sid).detach()
 
-            tag = test_utt.text or str(utt_idx)
-            self.logger.experiment.add_audio(
-                tag,
-                test_audio,
-                self.global_step,
-                sample_rate=self.hparams.sample_rate
-            )
+                # Scale to make louder in [-1, 1]
+                test_audio = test_audio * (1.0 / max(0.01, abs(test_audio).max()))
 
-        return val_loss
+                tag = test_utt.text or str(utt_idx)
+                self.logger.experiment.add_audio(
+                    tag,
+                    test_audio,
+                    self.global_step,
+                    sample_rate=self.hparams.sample_rate
+                )
+
+        return super().on_validation_end()
 
     def configure_optimizers(self):
         optimizers = [
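Two behavioral fixes ride along with the move from validation_step to on_validation_end. First, audio examples are no longer rendered during the sanity-check pass: self.trainer.sanity_checking is Lightning's flag for the short validation run performed before training starts. Second, the peak normalization now computes abs(test_audio).max() instead of abs(test_audio.max()): the old expression took the signed maximum first, so a waveform whose loudest sample was negative was normalized against the wrong peak and could be scaled outside [-1, 1]. A small illustrative snippet:

import torch

audio = torch.tensor([-0.8, 0.1, 0.2])  # loudest sample is negative

old_peak = abs(audio.max())  # abs(0.2) = 0.2, ignores the -0.8 peak
new_peak = abs(audio).max()  # max(|x|) = 0.8

print(audio * (1.0 / max(0.01, old_peak)))  # tensor([-4.0000, 0.5000, 1.0000]) -- clips
print(audio * (1.0 / max(0.01, new_peak)))  # tensor([-1.0000, 0.1250, 0.2500]) -- in range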
