Skip to content

Commit

Permalink
feature: cursor pagination of get_all_users in /admin/users route
Browse files Browse the repository at this point in the history
  • Loading branch information
ajanitshimanga committed May 27, 2024
1 parent ec894cd commit 0c12004
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 7 deletions.
14 changes: 10 additions & 4 deletions memgpt/metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -655,11 +655,17 @@ def get_user(self, user_id: uuid.UUID) -> Optional[User]:
return results[0].to_record()

@enforce_types
def get_all_users(self) -> List[User]:
# TODO make paginated
def get_all_users(self, cursor: Optional[uuid.UUID], limit: int) -> (Optional[uuid.UUID], List[User]):
with self.session_maker() as session:
results = session.query(UserModel).all()
return [r.to_record() for r in results]
query = session.query(UserModel).order_by(UserModel.id)
if cursor:
query = query.filter(UserModel.id > cursor)
results = query.limit(limit).all()
user_records = [r.to_record() for r in results]
next_cursor = user_records[-1].id
assert isinstance(next_cursor, uuid.UUID)

return next_cursor, user_records

@enforce_types
def get_source(
Expand Down
18 changes: 15 additions & 3 deletions memgpt/server/rest_api/admin/users.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,13 @@
router = APIRouter()


class GetAllUsersRequest(BaseModel):
cursor: Optional[uuid.UUID] = Field(None, description="Cursor to which to start the paginated request.")
limit: int = Field(..., description="Maximum number of users to retrieve per page.")


class GetAllUsersResponse(BaseModel):
cursor: Optional[uuid.UUID] = Field(None, description="Cursor for the next page in the response.")
user_list: List[dict] = Field(..., description="A list of users.")


Expand Down Expand Up @@ -54,18 +60,24 @@ class DeleteUserResponse(BaseModel):

def setup_admin_router(server: SyncServer, interface: QueuingInterface):
@router.get("/users", tags=["admin"], response_model=GetAllUsersResponse)
def get_all_users():
def get_all_users(
cursor: Optional[uuid.UUID] = Query(None, description="Cursor to which to start the paginated request."),
limit: int = Query(100, description="Maximum number of users to retrieve per page."),
):
"""
Get a list of all users in the database
"""
try:
users = server.ms.get_all_users()
# Validate with the Pydantic model
request = GetAllUsersRequest(cursor=cursor, limit=limit)

next_cursor, users = server.ms.get_all_users(request.cursor, request.limit)
processed_users = [{"user_id": user.id} for user in users]
except HTTPException:
raise
except Exception as e:
raise HTTPException(status_code=500, detail=f"{e}")
return GetAllUsersResponse(user_list=processed_users)
return GetAllUsersResponse(cursor=next_cursor, user_list=processed_users)

@router.post("/users", tags=["admin"], response_model=CreateUserResponse)
def create_user(request: Optional[CreateUserRequest] = Body(None)):
Expand Down

0 comments on commit 0c12004

Please sign in to comment.