diff --git a/server/main.py b/server/main.py index ddef7b1..6dbe082 100644 --- a/server/main.py +++ b/server/main.py @@ -112,12 +112,8 @@ class StreamingInputs(BaseModel): def predict_streaming_generator(parsed_input: dict = Body(...)): - speaker_embedding = ( - torch.tensor(parsed_input.speaker_embedding).unsqueeze(0).unsqueeze(-1) - ) - gpt_cond_latent = ( - torch.tensor(parsed_input.gpt_cond_latent).reshape((-1, 1024)).unsqueeze(0) - ) + speaker_embedding = torch.tensor(parsed_input.speaker_embedding).unsqueeze(0).unsqueeze(-1) + gpt_cond_latent = torch.tensor(parsed_input.gpt_cond_latent).reshape((-1, 1024)).unsqueeze(0) text = parsed_input.text language = parsed_input.language @@ -158,12 +154,8 @@ class TTSInputs(BaseModel): @app.post("/tts") def predict_speech(parsed_input: TTSInputs): - speaker_embedding = ( - torch.tensor(parsed_input.speaker_embedding).unsqueeze(0).unsqueeze(-1) - ).cuda() - gpt_cond_latent = ( - torch.tensor(parsed_input.gpt_cond_latent).reshape((-1, 1024)).unsqueeze(0) - ).cuda() + speaker_embedding = torch.tensor(parsed_input.speaker_embedding).unsqueeze(0).unsqueeze(-1) + gpt_cond_latent = torch.tensor(parsed_input.gpt_cond_latent).reshape((-1, 1024)).unsqueeze(0) text = parsed_input.text language = parsed_input.language