Commit
Fix CPU inference
WeberJulian committed Dec 13, 2023
1 parent da7af7a commit 281b1f4
Showing 1 changed file with 4 additions and 12 deletions.
16 changes: 4 additions & 12 deletions server/main.py
@@ -112,12 +112,8 @@ class StreamingInputs(BaseModel):


 def predict_streaming_generator(parsed_input: dict = Body(...)):
-    speaker_embedding = (
-        torch.tensor(parsed_input.speaker_embedding).unsqueeze(0).unsqueeze(-1)
-    )
-    gpt_cond_latent = (
-        torch.tensor(parsed_input.gpt_cond_latent).reshape((-1, 1024)).unsqueeze(0)
-    )
+    speaker_embedding = torch.tensor(parsed_input.speaker_embedding).unsqueeze(0).unsqueeze(-1)
+    gpt_cond_latent = torch.tensor(parsed_input.gpt_cond_latent).reshape((-1, 1024)).unsqueeze(0)
     text = parsed_input.text
     language = parsed_input.language

@@ -158,12 +154,8 @@ class TTSInputs(BaseModel):

@app.post("/tts")
def predict_speech(parsed_input: TTSInputs):
speaker_embedding = (
torch.tensor(parsed_input.speaker_embedding).unsqueeze(0).unsqueeze(-1)
).cuda()
gpt_cond_latent = (
torch.tensor(parsed_input.gpt_cond_latent).reshape((-1, 1024)).unsqueeze(0)
).cuda()
speaker_embedding = torch.tensor(parsed_input.speaker_embedding).unsqueeze(0).unsqueeze(-1)
gpt_cond_latent = torch.tensor(parsed_input.gpt_cond_latent).reshape((-1, 1024)).unsqueeze(0)
text = parsed_input.text
language = parsed_input.language

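Note on the change above: the /tts handler previously forced both conditioning tensors onto the GPU with .cuda(), which raises an error on CPU-only hosts; after this commit the tensors are simply created on CPU. Below is a minimal device-agnostic sketch for illustration only; the device selection and the build_latents helper are assumptions, not part of this commit.

import torch

# Sketch only (not from this commit): pick whichever device is available and
# move the conditioning tensors there explicitly.
device = "cuda" if torch.cuda.is_available() else "cpu"

def build_latents(parsed_input):
    # Speaker embedding shaped to (1, D, 1); GPT conditioning latent to (1, T, 1024).
    speaker_embedding = (
        torch.tensor(parsed_input.speaker_embedding).unsqueeze(0).unsqueeze(-1).to(device)
    )
    gpt_cond_latent = (
        torch.tensor(parsed_input.gpt_cond_latent).reshape((-1, 1024)).unsqueeze(0).to(device)
    )
    return speaker_embedding, gpt_cond_latent

The commit itself takes the simpler route of dropping the explicit .cuda() calls, leaving the tensors on CPU.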

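For reference, a hypothetical client call against the /tts route shown in the second hunk. The field names come from the TTSInputs fields visible in this diff; the host, port, placeholder values, and response handling are assumptions.

import requests

# Placeholders: in practice these come from a prior speaker-conditioning step.
speaker_embedding = [0.0] * 512            # flat list of floats (length is a placeholder)
gpt_cond_latent = [[0.0] * 1024] * 32      # any nesting reshapable to (-1, 1024)

payload = {
    "speaker_embedding": speaker_embedding,
    "gpt_cond_latent": gpt_cond_latent,
    "text": "Hello, this is a test.",
    "language": "en",
}

# Host and port are assumptions; the response payload format is not shown in this diff.
resp = requests.post("http://localhost:8000/tts", json=payload)
resp.raise_for_status()
print(resp.status_code)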