From fee51eb0ad0f081a192edba8b2305c0e8bc52860 Mon Sep 17 00:00:00 2001 From: Jake Tae Date: Mon, 14 Aug 2023 17:09:03 -0400 Subject: [PATCH] feature: add device flag to tts cli --- TTS/bin/synthesize.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/TTS/bin/synthesize.py b/TTS/bin/synthesize.py index d4350cd5e8..5ded306743 100755 --- a/TTS/bin/synthesize.py +++ b/TTS/bin/synthesize.py @@ -169,6 +169,7 @@ def main(): help="Output wav file path.", ) parser.add_argument("--use_cuda", type=bool, help="Run model on CUDA.", default=False) + parser.add_argument("--device", type=str, help="Device to run model on.", default="cpu") parser.add_argument( "--vocoder_path", type=str, @@ -391,6 +392,10 @@ def main(): if args.encoder_path is not None: encoder_path = args.encoder_path encoder_config_path = args.encoder_config_path + + device = args.device + if args.use_cuda: + device = "cuda" # load models synthesizer = Synthesizer( @@ -406,8 +411,7 @@ def main(): vc_config_path, model_dir, args.voice_dir, - args.use_cuda, - ) + ).to(device) # query speaker ids of a multi-speaker model. if args.list_speaker_idxs: