Skip to content

Commit

Permalink
🌡️ fix: Delete Endpoint Validation, List Index Errors, use PyJWT (#42)
Browse files Browse the repository at this point in the history
* 🌡️ refactor: add mongo health check and use enums for constants

* fix(delete_documents): use correct pydantic class

Changed the delete_documents endpoint to use a Pydantic model for better request body validation. This fixes the issue where the endpoint was incorrectly using query parameters.

* fix: list index out of range in query when documents are empty

* fix: ensure documents list is not empty in GET /documents

* fix: ensure documents list is not empty in GET /documents/{id}/context

* fix: ensure documents list is not empty in POST /query_multiple

* fix: delete route expected body

* chore: swap python-jose for PyJWT due to security advisories

* chore: add ATLAS warning as is not fully compatible
  • Loading branch information
danny-avila authored May 22, 2024
1 parent e8d52dc commit a3bee40
Show file tree
Hide file tree
Showing 7 changed files with 158 additions and 84 deletions.
57 changes: 34 additions & 23 deletions config.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
# config.py
import json
import os
import json
import logging
from enum import Enum
from datetime import datetime

from dotenv import find_dotenv, load_dotenv
from langchain_community.embeddings import (
HuggingFaceEmbeddings,
Expand All @@ -12,12 +12,24 @@
)
from langchain_openai import AzureOpenAIEmbeddings, OpenAIEmbeddings
from starlette.middleware.base import BaseHTTPMiddleware

from store_factory import get_vector_store

load_dotenv(find_dotenv())


class VectorDBType(Enum):
PGVECTOR = "pgvector"
ATLAS_MONGO = "atlas-mongo"


class EmbeddingsProvider(Enum):
OPENAI = "openai"
AZURE = "azure"
HUGGINGFACE = "huggingface"
HUGGINGFACETEI = "huggingfacetei"
OLLAMA = "ollama"


def get_env_variable(
var_name: str, default_value: str = None, required: bool = False
) -> str:
Expand All @@ -36,7 +48,9 @@ def get_env_variable(
if not os.path.exists(RAG_UPLOAD_DIR):
os.makedirs(RAG_UPLOAD_DIR, exist_ok=True)

VECTOR_DB_TYPE = get_env_variable("VECTOR_DB_TYPE", "pgvector")
VECTOR_DB_TYPE = VectorDBType(
get_env_variable("VECTOR_DB_TYPE", VectorDBType.PGVECTOR.value)
)
POSTGRES_DB = get_env_variable("POSTGRES_DB", "mydatabase")
POSTGRES_USER = get_env_variable("POSTGRES_USER", "myuser")
POSTGRES_PASSWORD = get_env_variable("POSTGRES_PASSWORD", "mypassword")
Expand Down Expand Up @@ -140,7 +154,6 @@ async def dispatch(self, request, call_next):

logging.getLogger("uvicorn.access").disabled = True


## Credentials

OPENAI_API_KEY = get_env_variable("OPENAI_API_KEY", "")
Expand All @@ -163,51 +176,49 @@ async def dispatch(self, request, call_next):


def init_embeddings(provider, model):
if provider == "openai":
if provider == EmbeddingsProvider.OPENAI:
return OpenAIEmbeddings(
model=model,
api_key=RAG_OPENAI_API_KEY,
openai_api_base=RAG_OPENAI_BASEURL,
openai_proxy=RAG_OPENAI_PROXY,
)
elif provider == "azure":
elif provider == EmbeddingsProvider.AZURE:
return AzureOpenAIEmbeddings(
azure_deployment=model,
api_key=RAG_AZURE_OPENAI_API_KEY,
azure_endpoint=RAG_AZURE_OPENAI_ENDPOINT,
api_version=RAG_AZURE_OPENAI_API_VERSION,
)
elif provider == "huggingface":
elif provider == EmbeddingsProvider.HUGGINGFACE:
return HuggingFaceEmbeddings(
model_name=model, encode_kwargs={"normalize_embeddings": True}
)
elif provider == "huggingfacetei":
elif provider == EmbeddingsProvider.HUGGINGFACETEI:
return HuggingFaceHubEmbeddings(model=model)
elif provider == "ollama":
elif provider == EmbeddingsProvider.OLLAMA:
return OllamaEmbeddings(model=model, base_url=OLLAMA_BASE_URL)
else:
raise ValueError(f"Unsupported embeddings provider: {provider}")


EMBEDDINGS_PROVIDER = get_env_variable("EMBEDDINGS_PROVIDER", "openai").lower()
EMBEDDINGS_PROVIDER = EmbeddingsProvider(
get_env_variable("EMBEDDINGS_PROVIDER", EmbeddingsProvider.OPENAI.value).lower()
)

if EMBEDDINGS_PROVIDER == "openai":
if EMBEDDINGS_PROVIDER == EmbeddingsProvider.OPENAI:
EMBEDDINGS_MODEL = get_env_variable("EMBEDDINGS_MODEL", "text-embedding-3-small")

elif EMBEDDINGS_PROVIDER == "azure":
elif EMBEDDINGS_PROVIDER == EmbeddingsProvider.AZURE:
EMBEDDINGS_MODEL = get_env_variable("EMBEDDINGS_MODEL", "text-embedding-3-small")

elif EMBEDDINGS_PROVIDER == "huggingface":
elif EMBEDDINGS_PROVIDER == EmbeddingsProvider.HUGGINGFACE:
EMBEDDINGS_MODEL = get_env_variable(
"EMBEDDINGS_MODEL", "sentence-transformers/all-MiniLM-L6-v2"
)

elif EMBEDDINGS_PROVIDER == "huggingfacetei":
elif EMBEDDINGS_PROVIDER == EmbeddingsProvider.HUGGINGFACETEI:
EMBEDDINGS_MODEL = get_env_variable(
"EMBEDDINGS_MODEL", "http://huggingfacetei:3000"
)

elif EMBEDDINGS_PROVIDER == "ollama":
elif EMBEDDINGS_PROVIDER == EmbeddingsProvider.OLLAMA:
EMBEDDINGS_MODEL = get_env_variable("EMBEDDINGS_MODEL", "nomic-embed-text")
else:
raise ValueError(f"Unsupported embeddings provider: {EMBEDDINGS_PROVIDER}")
Expand All @@ -217,15 +228,15 @@ def init_embeddings(provider, model):
logger.info(f"Initialized embeddings of type: {type(embeddings)}")

# Vector store
if VECTOR_DB_TYPE == "pgvector":
if VECTOR_DB_TYPE == VectorDBType.PGVECTOR:
vector_store = get_vector_store(
connection_string=CONNECTION_STRING,
embeddings=embeddings,
collection_name=COLLECTION_NAME,
mode="async",
)
elif VECTOR_DB_TYPE == "atlas-mongo":
# atlas-mongo vector:
elif VECTOR_DB_TYPE == VectorDBType.ATLAS_MONGO:
logger.warning("Using Atlas MongoDB as vector store is not fully supported yet.")
vector_store = get_vector_store(
connection_string=ATLAS_MONGO_DB_URI,
embeddings=embeddings,
Expand Down
67 changes: 50 additions & 17 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import hashlib
import aiofiles
import aiofiles.os
from typing import Iterable
from typing import Iterable, List
from shutil import copyfileobj

import uvicorn
Expand All @@ -13,14 +13,15 @@
from langchain_core.runnables.config import run_in_executor
from langchain.text_splitter import RecursiveCharacterTextSplitter
from fastapi import (
FastAPI,
File,
Form,
Body,
Query,
UploadFile,
HTTPException,
status,
FastAPI,
Request,
UploadFile,
HTTPException,
)
from langchain_community.document_loaders import (
WebBaseLoader,
Expand All @@ -35,11 +36,17 @@
UnstructuredExcelLoader,
)

from models import DocumentResponse, StoreDocument, QueryRequestBody, QueryMultipleBody
from models import (
StoreDocument,
QueryRequestBody,
DocumentResponse,
QueryMultipleBody,
)
from psql import PSQLDatabase, ensure_custom_id_index_on_embedding, pg_health_check
from middleware import security_middleware
from pgvector_routes import router as pgvector_router
from parsers import process_documents, clean_text
from middleware import security_middleware
from mongo import mongo_health_check
from constants import ERROR_MESSAGES
from store import AsyncPgVector

Expand All @@ -57,6 +64,7 @@
LogMiddleware,
RAG_HOST,
RAG_PORT,
VectorDBType,
# RAG_EMBEDDING_MODEL,
# RAG_EMBEDDING_MODEL_DEVICE_TYPE,
# RAG_TEMPLATE,
Expand Down Expand Up @@ -107,8 +115,10 @@ async def get_all_ids():


def isHealthOK():
if VECTOR_DB_TYPE == "pgvector":
if VECTOR_DB_TYPE == VectorDBType.PGVECTOR:
return pg_health_check()
if VECTOR_DB_TYPE == VectorDBType.ATLAS_MONGO:
return mongo_health_check()
else:
return True

Expand All @@ -131,9 +141,16 @@ async def get_documents_by_ids(ids: list[str] = Query(...)):
existing_ids = vector_store.get_all_ids()
documents = vector_store.get_documents_by_ids(ids)

# Ensure all requested ids exist
if not all(id in existing_ids for id in ids):
raise HTTPException(status_code=404, detail="One or more IDs not found")

# Ensure documents list is not empty
if not documents:
raise HTTPException(
status_code=404, detail="No documents found for the given IDs"
)

return documents
except HTTPException as http_exc:
raise http_exc
Expand All @@ -142,19 +159,19 @@ async def get_documents_by_ids(ids: list[str] = Query(...)):


@app.delete("/documents")
async def delete_documents(ids: list[str] = Query(...)):
async def delete_documents(document_ids: List[str] = Body(...)):
try:
if isinstance(vector_store, AsyncPgVector):
existing_ids = await vector_store.get_all_ids()
await vector_store.delete(ids=ids)
await vector_store.delete(ids=document_ids)
else:
existing_ids = vector_store.get_all_ids()
vector_store.delete(ids=ids)
vector_store.delete(ids=document_ids)

if not all(id in existing_ids for id in ids):
if not all(id in existing_ids for id in document_ids):
raise HTTPException(status_code=404, detail="One or more IDs not found")

file_count = len(ids)
file_count = len(document_ids)
return {
"message": f"Documents for {file_count} file{'s' if file_count > 1 else ''} deleted successfully"
}
Expand All @@ -164,12 +181,11 @@ async def delete_documents(ids: list[str] = Query(...)):

@app.post("/query")
async def query_embeddings_by_file_id(body: QueryRequestBody, request: Request):
if not hasattr(request.state, "user"):
user_authorized = "public"
else:
user_authorized = request.state.user.get("id")

user_authorized = (
"public" if not hasattr(request.state, "user") else request.state.user.get("id")
)
authorized_documents = []

try:
embedding = vector_store.embedding_function.embed_query(body.query)

Expand All @@ -186,6 +202,9 @@ async def query_embeddings_by_file_id(body: QueryRequestBody, request: Request):
embedding, k=body.k, filter={"file_id": body.file_id}
)

if not documents:
return authorized_documents

document, score = documents[0]
doc_metadata = document.metadata
doc_user_id = doc_metadata.get("user_id")
Expand All @@ -198,6 +217,7 @@ async def query_embeddings_by_file_id(body: QueryRequestBody, request: Request):
)

return authorized_documents

except Exception as e:
logger.error(e)
raise HTTPException(status_code=500, detail=str(e))
Expand Down Expand Up @@ -427,11 +447,18 @@ async def load_document_context(id: str):
existing_ids = vector_store.get_all_ids()
documents = vector_store.get_documents_by_ids(ids)

# Ensure the requested id exists
if not all(id in existing_ids for id in ids):
raise HTTPException(
status_code=404, detail="The specified file_id was not found"
)

# Ensure documents list is not empty
if not documents:
raise HTTPException(
status_code=404, detail="No document found for the given ID"
)

return process_documents(documents)
except Exception as e:
logger.error(e)
Expand Down Expand Up @@ -511,6 +538,12 @@ async def query_embeddings_by_file_ids(body: QueryMultipleBody):
embedding, k=body.k, filter={"file_id": {"$in": body.file_ids}}
)

# Ensure documents list is not empty
if not documents:
raise HTTPException(
status_code=404, detail="No documents found for the given query"
)

return documents
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
Expand Down
Loading

0 comments on commit a3bee40

Please sign in to comment.