Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: cursor pagination of get_all_users in /admin/users route #1424

14 changes: 10 additions & 4 deletions memgpt/metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -655,11 +655,17 @@ def get_user(self, user_id: uuid.UUID) -> Optional[User]:
return results[0].to_record()

@enforce_types
def get_all_users(self) -> List[User]:
# TODO make paginated
def get_all_users(self, cursor: Optional[uuid.UUID], limit: int) -> (Optional[uuid.UUID], List[User]):
with self.session_maker() as session:
results = session.query(UserModel).all()
return [r.to_record() for r in results]
query = session.query(UserModel).order_by(UserModel.id)
if cursor:
query = query.filter(UserModel.id > cursor)
results = query.limit(limit).all()
user_records = [r.to_record() for r in results]
next_cursor = user_records[-1].id
assert isinstance(next_cursor, uuid.UUID)

return next_cursor, user_records

@enforce_types
def get_source(
Expand Down
18 changes: 15 additions & 3 deletions memgpt/server/rest_api/admin/users.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,13 @@
router = APIRouter()


class GetAllUsersRequest(BaseModel):
cursor: Optional[uuid.UUID] = Field(None, description="Cursor to which to start the paginated request.")
limit: int = Field(..., description="Maximum number of users to retrieve per page.")


class GetAllUsersResponse(BaseModel):
cursor: Optional[uuid.UUID] = Field(None, description="Cursor for the next page in the response.")
user_list: List[dict] = Field(..., description="A list of users.")


Expand Down Expand Up @@ -54,18 +60,24 @@ class DeleteUserResponse(BaseModel):

def setup_admin_router(server: SyncServer, interface: QueuingInterface):
@router.get("/users", tags=["admin"], response_model=GetAllUsersResponse)
def get_all_users():
def get_all_users(
cursor: Optional[uuid.UUID] = Query(None, description="Cursor to which to start the paginated request."),
limit: int = Query(100, description="Maximum number of users to retrieve per page."),
):
"""
Get a list of all users in the database
"""
try:
users = server.ms.get_all_users()
# Validate with the Pydantic model
request = GetAllUsersRequest(cursor=cursor, limit=limit)

next_cursor, users = server.ms.get_all_users(request.cursor, request.limit)
processed_users = [{"user_id": user.id} for user in users]
except HTTPException:
raise
except Exception as e:
raise HTTPException(status_code=500, detail=f"{e}")
return GetAllUsersResponse(user_list=processed_users)
return GetAllUsersResponse(cursor=next_cursor, user_list=processed_users)

@router.post("/users", tags=["admin"], response_model=CreateUserResponse)
def create_user(request: Optional[CreateUserRequest] = Body(None)):
Expand Down
Loading