diff --git a/TTS/demos/xtts_ft_demo/xtts_demo.py b/TTS/demos/xtts_ft_demo/xtts_demo.py index ebb11f29d1..f4fe00aa11 100644 --- a/TTS/demos/xtts_ft_demo/xtts_demo.py +++ b/TTS/demos/xtts_ft_demo/xtts_demo.py @@ -33,7 +33,8 @@ def load_model(xtts_checkpoint, xtts_config, xtts_vocab): config.load_json(xtts_config) XTTS_MODEL = Xtts.init_from_config(config) print("Loading XTTS model! ") - XTTS_MODEL.load_checkpoint(config, checkpoint_path=xtts_checkpoint, vocab_path=xtts_vocab, use_deepspeed=False) + xtts_checkpoint_dir = os.path.dirname(xtts_checkpoint) + XTTS_MODEL.load_checkpoint(config, checkpoint_path=xtts_checkpoint, checkpoint_dir=xtts_checkpoint_dir, vocab_path=xtts_vocab, use_deepspeed=False) if torch.cuda.is_available(): XTTS_MODEL.cuda()