-
Notifications
You must be signed in to change notification settings - Fork 0
/
main.py
65 lines (55 loc) · 1.94 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import os
import asyncio
from sentence_transformers import SentenceTransformer
import threading
from discover import Discover
from fastapi import FastAPI, Request
import uvicorn
from contextlib import asynccontextmanager
print("Creating app")
# Config file handed to discover.setup_db() and discover.update_index();
# the DISCOVER_CONFIG_PATH environment variable overrides the default.
config_path = os.getenv("DISCOVER_CONFIG_PATH", "ord.yaml")
discover = Discover()
# Shared encoder passed to every text/image search helper on `discover`.
model = SentenceTransformer('clip-ViT-L-14')
# NOTE(review): 768 presumably matches the clip-ViT-L-14 output width — confirm.
# The /ntotal endpoint reads discover.index, not this module-level name.
index = discover.get_index(768)
@asynccontextmanager
async def lifespan(app: "FastAPI"):
    """Run startup/shutdown work around the FastAPI server's lifetime.

    Startup awaits ``discover.setup_db(config_path)`` and then starts a
    background thread that executes ``discover.update_index(model, config_path)``
    with ``asyncio.run`` (i.e. on that thread's own event loop). Shutdown sets
    ``discover.stop_indexer = True`` and waits for the thread to finish.

    Args:
        app: The FastAPI application; unused here, required by the lifespan
            protocol.
    """
    # Setup database & api pool
    await discover.setup_db(config_path)
    # The coroutine object is created on the server loop but executed inside
    # the worker thread by asyncio.run, on a separate event loop.
    # NOTE(review): if update_index touches resources bound to the server
    # loop (e.g. a pool created by setup_db), confirm they are loop-safe.
    index_thread = threading.Thread(
        target=asyncio.run,
        args=(discover.update_index(model, config_path),),
    )
    index_thread.start()
    try:
        yield  # API server can run now
    finally:
        # Cleanup runs even when the lifespan is exited via an exception.
        # update_index presumably polls discover.stop_indexer; without this
        # signal the non-daemon thread could keep the process alive at exit.
        print("cleaning up")
        discover.stop_indexer = True
        print("FastAPI exited, waiting on index thread to finish..")
        # Join in an executor so waiting does not block the event loop itself.
        await asyncio.get_running_loop().run_in_executor(None, index_thread.join)
# The ASGI application; startup/shutdown work is delegated to `lifespan`.
app = FastAPI(
    lifespan=lifespan,
)
@app.get("/")
async def hello_world():
    """Return a fixed greeting string from the root path."""
    greeting = "Hello, World!"
    return greeting
@app.get("/ntotal")
def ntotal():
    """Return ``discover.index.ntotal``.

    NOTE(review): presumably the number of vectors stored in the search
    index (FAISS-style attribute name) — confirm against Discover.
    """
    search_index = discover.index
    return search_index.ntotal
@app.get("/search/{search_term}")
async def search(search_term, n: int = 9):
    """Look up inscription numbers for a text query.

    Delegates to ``discover.get_text_to_inscription_numbers`` with the shared
    model. ``n`` (default 9) is capped at 50 before being forwarded.
    """
    limit = min(n, 50)
    return await discover.get_text_to_inscription_numbers(model, search_term, limit)
@app.post("/search_by_image")
async def search_by_image(request: Request, n: int = 9):
    """Look up inscription numbers for an image sent as the raw request body.

    Delegates to ``discover.get_image_to_inscription_numbers`` with the shared
    model. ``n`` (default 9) is capped at 50 before being forwarded.
    """
    # NOTE(review): the whole body is read into memory with no size limit.
    payload = await request.body()
    limit = min(n, 50)
    return await discover.get_image_to_inscription_numbers(model, payload, limit)
@app.get("/similar/{sha256}")
async def similar(sha256, n: int = 9):
    """Find images similar to the one identified by the ``sha256`` path value.

    Delegates to ``discover.get_similar_images`` with the shared model.
    ``n`` (default 9) is capped at 50 before being forwarded.
    """
    return await discover.get_similar_images(model, sha256, min(n, 50))
@app.get("/get_class/{dbclass}")
async def get_class(dbclass, n: int = 9):
    """Look up inscription numbers belonging to ``dbclass``.

    Delegates to ``discover.get_dbclass_to_inscription_numbers``. ``n``
    (default 9) is capped at 50 before being forwarded.
    """
    # Cap n at 50 like every other result-count endpoint in this app, so a
    # single request cannot ask for an unbounded number of results.
    response = await discover.get_dbclass_to_inscription_numbers(dbclass, min(n, 50))
    return response
if __name__ == "__main__":
    print("main hit")
    # Serve the already-built `app` object instead of the "main:app" import
    # string. The string form makes uvicorn import this file a second time as
    # module `main`, re-running all module-level setup: a second
    # SentenceTransformer load and a second Discover().
    # It also breaks if the file is not named main.py.
    uvicorn.run(app, port=4080, log_level="info")