Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

增设前端推理界面,并修改api_v2以进行适配。 #1681

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 81 additions & 0 deletions api_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,9 +110,12 @@
import signal
import numpy as np
import soundfile as sf
import shutil
from fastapi import FastAPI, Request, HTTPException, Response
from fastapi.responses import StreamingResponse, JSONResponse
from fastapi import FastAPI, UploadFile, File
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
import uvicorn
from io import BytesIO
from tools.i18n.i18n import I18nAuto
Expand All @@ -139,6 +142,7 @@
config_path = "GPT-SoVITS/configs/tts_infer.yaml"

tts_config = TTS_Config(config_path)
print("以下为TTS_CONFIG配置, 如需修改请查看/GPT_SoVITS/configs/tts_infer.yaml")
print(tts_config)
tts_pipeline = TTS(tts_config)

Expand Down Expand Up @@ -447,7 +451,84 @@ async def set_sovits_weights(weights_path: str = None):
return JSONResponse(status_code=400, content={"message": f"change sovits weight failed", "Exception": str(e)})
return JSONResponse(status_code=200, content={"message": "success"})

APP.add_middleware(
CORSMiddleware,
allow_origins=["*"], # 允许所有域名的请求
allow_credentials=True,
allow_methods=["*"], # 允许所有方法
allow_headers=["*"], # 允许所有请求头
)

@APP.get("/info")
async def get_info():
try:
gpt_weights_dir_v2 = 'GPT_weights_v2'
sovits_weights_dir_v2 = 'SoVITS_weights_v2'
gpt_weights_dir = 'GPT_weights'
sovits_weights_dir = 'SoVITS_weights'

gpt_filenames = []
sovits_filenames = []

for dir in [gpt_weights_dir_v2, gpt_weights_dir]:
if os.path.exists(dir):
gpt_filenames.extend([f"{dir}/{f}" for f in os.listdir(dir) if os.path.isfile(os.path.join(dir, f))])

for dir in [sovits_weights_dir_v2, sovits_weights_dir]:
if os.path.exists(dir):
sovits_filenames.extend([f"{dir}/{f}" for f in os.listdir(dir) if os.path.isfile(os.path.join(dir, f))])

if not gpt_filenames:
return JSONResponse(status_code=404, content={"message": "No GPT weights files found"})
if not sovits_filenames:
return JSONResponse(status_code=404, content={"message": "No SoVITS weights files found"})

return JSONResponse(status_code=200, content={
"gpt_weights_files": gpt_filenames,
"sovits_weights_files": sovits_filenames,
"server_port": port
})
except Exception as e:
return JSONResponse(status_code=500, content={"message": f"Error retrieving weights info", "error": str(e)})

@APP.post("/tts")
async def tts_post_endpoint(request: TTS_Request):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

上面有个同名函数,是不是应该把那个删了?

req = request.model_dump()
print("\nProcessed request (req):")
print(f"Type: {type(req)}")
print("Content:")
for key, value in req.items():
print(f" {key}: {value}")

return await tts_handle(req)

@APP.post("/upload_file")
async def upload_file(file: UploadFile = File(...)):
try:
# Create a temporary directory if it doesn't exist
temp_dir = "temp_files"
os.makedirs(temp_dir, exist_ok=True)

# Define the path to save the uploaded file
file_path = os.path.join(temp_dir, file.filename)

# Save the uploaded file to the temporary directory
with open(file_path, "wb") as buffer:
shutil.copyfileobj(file.file, buffer)

return JSONResponse(status_code=200, content={"message": "File uploaded successfully", "file_path": file_path})
except Exception as e:
return JSONResponse(status_code=500, content={"message": "File upload failed", "error": str(e)})

APP.mount("/", StaticFiles(directory="dist", html=True), name="static")
print("--------------------------------")
print(f"前端界面已在 http://{host}:{port} 开启。")
print("目前的前端版本只适配默认端口9880, 更改api端口会导致前端页面无法工作, 但不影响后端api运行。")
print("在前端界面中上传的音频文件将会保存在 ./temp_files 目录下,如有需要请手动删除。")
print("请至少运行一遍webui.py, 放好模型, 再运行本API, 以确保存放模型的文件夹SoVITS_weights和GPT_weights存在。")
print("如遇配置错误,请检查命令行上方输出的配置详情,并修改文件/GPT_SoVITS/configs/tts_infer.yaml")
print("如果运行环境是mac, 请将tts_infer.yaml内custom条目下的device改为cpu, is_half改为false")
print("--------------------------------")

if __name__ == "__main__":
try:
Expand Down
Loading