diff --git a/TTS/bin/synthesize.py b/TTS/bin/synthesize.py index d4350cd5e8..5ded306743 100755 --- a/TTS/bin/synthesize.py +++ b/TTS/bin/synthesize.py @@ -169,6 +169,7 @@ def main(): help="Output wav file path.", ) parser.add_argument("--use_cuda", type=bool, help="Run model on CUDA.", default=False) + parser.add_argument("--device", type=str, help="Device to run model on.", default="cpu") parser.add_argument( "--vocoder_path", type=str, @@ -391,6 +392,10 @@ def main(): if args.encoder_path is not None: encoder_path = args.encoder_path encoder_config_path = args.encoder_config_path + + device = args.device + if args.use_cuda: + device = "cuda" # load models synthesizer = Synthesizer( @@ -406,8 +411,7 @@ def main(): vc_config_path, model_dir, args.voice_dir, - args.use_cuda, - ) + ).to(device) # query speaker ids of a multi-speaker model. if args.list_speaker_idxs: