Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

toolkit: add pinned chats #774

Merged
merged 14 commits into from
Sep 24, 2024
Merged
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
"""add is pinned flag

Revision ID: 20b03fd331e8
Revises: ac3933258035
Create Date: 2024-09-16 20:16:55.080572

"""
from typing import Sequence, Union

import sqlalchemy as sa
from alembic import op

# revision identifiers, used by Alembic.
revision: str = '20b03fd331e8'
down_revision: Union[str, None] = 'ac3933258035'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.add_column('conversations', sa.Column('is_pinned', sa.Boolean(), nullable=False, server_default=sa.false()))
# ### end Alembic commands ###


def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_column('conversations', 'is_pinned')
# ### end Alembic commands ###
35 changes: 34 additions & 1 deletion src/backend/crud/conversation.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
from sqlalchemy import desc
from sqlalchemy.orm import Session

from backend.database_models.conversation import (
Conversation,
ConversationFileAssociation,
)
from backend.schemas.conversation import UpdateConversationRequest
from backend.schemas.conversation import (
ToggleConversationPinRequest,
UpdateConversationRequest,
)
from backend.services.transaction import validate_transaction


Expand Down Expand Up @@ -54,6 +58,7 @@ def get_conversations(
user_id: str,
offset: int = 0,
limit: int = 100,
order_by: str | None = None,
agent_id: str | None = None,
organization_id: str | None = None,
) -> list[Conversation]:
Expand All @@ -67,6 +72,7 @@ def get_conversations(
agent_id (str): Agent ID.
offset (int): Offset to start the list.
limit (int): Limit of conversations to be listed.
order_by (str): A field by which to order the conversations.

Returns:
list[Conversation]: List of conversations.
Expand All @@ -76,6 +82,9 @@ def get_conversations(
query = query.filter(Conversation.agent_id == agent_id)
if organization_id is not None:
query = query.filter(Conversation.organization_id == organization_id)
if order_by is not None:
order_column = getattr(Conversation, order_by)
query = query.order_by(desc(order_column))
query = query.order_by(Conversation.updated_at.desc()).offset(offset).limit(limit)

return query.all()
Expand Down Expand Up @@ -104,6 +113,30 @@ def update_conversation(
return conversation


@validate_transaction
def toggle_conversation_pin(
db: Session, conversation: Conversation, new_conversation_pin: ToggleConversationPinRequest
) -> Conversation:
"""
Update conversation pin by conversation ID.

Args:
db (Session): Database session.
conversation (Conversation): Conversation to be updated.
new_conversation_pin (ToggleConversationPinRequest): New conversation pin data.

Returns:
Conversation: Updated conversation.
"""
db.query(Conversation).filter(Conversation.id == conversation.id).update({
Conversation.is_pinned: new_conversation_pin.is_pinned,
Conversation.updated_at: conversation.updated_at
})
db.commit()
db.refresh(conversation)
return conversation


@validate_transaction
def delete_conversation(db: Session, conversation_id: str, user_id: str) -> None:
"""
Expand Down
10 changes: 9 additions & 1 deletion src/backend/database_models/conversation.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,14 @@
from typing import List, Optional
from uuid import uuid4

from sqlalchemy import ForeignKey, Index, PrimaryKeyConstraint, String, UniqueConstraint
from sqlalchemy import (
Boolean,
ForeignKey,
Index,
PrimaryKeyConstraint,
String,
UniqueConstraint,
)
from sqlalchemy.orm import Mapped, mapped_column, relationship

from backend.database_models.base import Base
Expand Down Expand Up @@ -46,6 +53,7 @@ class Conversation(Base):
ondelete="CASCADE",
)
)
is_pinned: Mapped[bool] = mapped_column(Boolean, default=False)

@property
def messages(self):
Expand Down
42 changes: 41 additions & 1 deletion src/backend/routers/conversation.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
ConversationWithoutMessages,
DeleteConversationResponse,
GenerateTitleResponse,
ToggleConversationPinRequest,
UpdateConversationRequest,
)
from backend.schemas.file import (
Expand Down Expand Up @@ -98,6 +99,7 @@ async def get_conversation(
description=conversation.description,
agent_id=conversation.agent_id,
organization_id=conversation.organization_id,
is_pinned=conversation.is_pinned,
)

_ = validate_conversation(session, conversation_id, user_id)
Expand All @@ -109,6 +111,7 @@ async def list_conversations(
*,
offset: int = 0,
limit: int = 100,
order_by: str = None,
agent_id: str = None,
session: DBSessionDep,
request: Request,
Expand All @@ -120,6 +123,7 @@ async def list_conversations(
Args:
offset (int): Offset to start the list.
limit (int): Limit of conversations to be listed.
order_by (str): A field by which to order the conversations.
agent_id (str): Query parameter for agent ID to optionally filter conversations by agent.
session (DBSessionDep): Database session.
request (Request): Request object.
Expand All @@ -130,7 +134,7 @@ async def list_conversations(
user_id = ctx.get_user_id()

conversations = conversation_crud.get_conversations(
session, offset=offset, limit=limit, user_id=user_id, agent_id=agent_id
session, offset=offset, limit=limit, order_by=order_by, user_id=user_id, agent_id=agent_id
)

results = []
Expand All @@ -153,6 +157,7 @@ async def list_conversations(
agent_id=conversation.agent_id,
messages=[],
organization_id=conversation.organization_id,
is_pinned=conversation.is_pinned,
)
)

Expand Down Expand Up @@ -205,6 +210,40 @@ async def update_conversation(
description=conversation.description,
agent_id=conversation.agent_id,
organization_id=conversation.organization_id,
is_pinned=conversation.is_pinned,
)


@router.put("/{conversation_id}/toggle-pin", response_model=ConversationWithoutMessages)
async def toggle_conversation_pin(
conversation_id: str,
new_conversation_pin: ToggleConversationPinRequest,
session: DBSessionDep,
ctx: Context = Depends(get_context),
) -> ConversationWithoutMessages:
user_id = ctx.get_user_id()
conversation = validate_conversation(session, conversation_id, user_id)
conversation = conversation_crud.toggle_conversation_pin(
session, conversation, new_conversation_pin
)
files = get_file_service().get_files_by_conversation_id(
session, user_id, conversation.id, ctx
)
files_with_conversation_id = attach_conversation_id_to_files(
conversation.id, files
)
return ConversationWithoutMessages(
id=conversation.id,
user_id=user_id,
created_at=conversation.created_at,
updated_at=conversation.updated_at,
title=conversation.title,
files=files_with_conversation_id,
description=conversation.description,
agent_id=conversation.agent_id,
messages=[],
organization_id=conversation.organization_id,
is_pinned=conversation.is_pinned,
)


Expand Down Expand Up @@ -313,6 +352,7 @@ async def search_conversations(
agent_id=conversation.agent_id,
messages=[],
organization_id=conversation.organization_id,
is_pinned=conversation.is_pinned,
)
)
return results
Expand Down
5 changes: 5 additions & 0 deletions src/backend/schemas/conversation.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ class Conversation(ConversationBase):
files: List[ConversationFilePublic]
description: Optional[str]
agent_id: Optional[str]
is_pinned: bool

@computed_field(return_type=int)
def total_file_size(self):
Expand All @@ -48,6 +49,10 @@ class Config:
from_attributes = True


class ToggleConversationPinRequest(BaseModel):
is_pinned: bool


class DeleteConversationResponse(BaseModel):
pass

Expand Down
1 change: 1 addition & 0 deletions src/backend/tests/unit/factories/conversation.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ class Meta:
text_messages = []
agent_id = None
organization_id = None
is_pinned = False


class ConversationFileAssociationFactory(BaseFactory):
Expand Down
30 changes: 30 additions & 0 deletions src/backend/tests/unit/routers/test_conversation.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,6 +285,36 @@ def test_update_conversations_missing_user_id(
assert results == {"detail": "User-Id required in request headers."}


def test_toggle_conversation_pin(
session_client: TestClient,
session: Session,
user: User,
) -> None:
conversation = get_factory("Conversation", session).create(
is_pinned=False, user_id=user.id
)
response = session_client.put(
f"/v1/conversations/{conversation.id}/toggle-pin",
json={"is_pinned": True},
headers={"User-Id": user.id},
)
response_conversation = response.json()

assert response.status_code == 200
assert response_conversation["is_pinned"]
assert response_conversation["updated_at"] == conversation.updated_at.isoformat()

# Check if the conversation was updated
updated_conversation = (
session.query(Conversation)
.filter_by(id=conversation.id, user_id=conversation.user_id)
.first()
)
assert updated_conversation is not None
assert updated_conversation.is_pinned
assert updated_conversation.updated_at == conversation.updated_at


def test_delete_conversation(
session_client: TestClient,
session: Session,
Expand Down
18 changes: 18 additions & 0 deletions src/interfaces/assistants_web/src/assets/icons/Pin.tsx
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
import * as React from 'react';
import { SVGProps } from 'react';

import { cn } from '@/utils';

export const Pin: React.FC<SVGProps<SVGSVGElement>> = ({ className, ...props }) => (
<svg
xmlns="http://www.w3.org/2000/svg"
viewBox="0 0 16 16"
className={cn('h-full w-full fill-inherit', className)}
{...props}
>
<path
d="M14.5 5.09375L10.703125 1.296875C10.453125 1.046875 10.152344 0.921875 9.804688 0.921875C9.457031 0.921875 9.15625 1.046875 8.90625 1.296875L5.203125 5.09375C5.152344 5.144531 5.082031 5.171875 5 5.171875C4.917969 5.171875 4.847656 5.144531 4.796875 5.09375C4.597656 5 4.386719 4.917969 4.164062 4.851562C3.941406 4.785156 3.652344 4.800781 3.296875 4.90625C2.996094 5 2.695312 5.136719 2.398438 5.320312C2.101562 5.503906 1.800781 5.699219 1.5 5.90625C1.355469 6 1.234375 6.136719 1.140625 6.3125C1.046875 6.488281 1 6.652344 1 6.796875C1 7.152344 1.027344 7.371094 1.078125 7.453125C1.128906 7.535156 1.203125 7.652344 1.296875 7.796875L4.203125 10.703125C4.253906 10.753906 4.28125 10.824219 4.28125 10.90625C4.28125 10.988281 4.253906 11.050781 4.203125 11.09375L1.09375 14.203125C1 14.296875 0.953125 14.410156 0.953125 14.546875C0.953125 14.683594 1 14.800781 1.09375 14.90625C1.199219 15 1.332031 15.03125 1.5 15C1.667969 14.96875 1.800781 14.902344 1.90625 14.796875L4.90625 11.59375C4.949219 11.550781 5.011719 11.53125 5.09375 11.53125C5.175781 11.53125 5.246094 11.550781 5.296875 11.59375L8.203125 14.5C8.296875 14.605469 8.40625 14.683594 8.53125 14.734375C8.65625 14.785156 8.84375 14.8125 9.09375 14.8125C9.34375 14.8125 9.53125 14.765625 9.65625 14.671875C9.78125 14.578125 9.894531 14.453125 10 14.296875C10.199219 13.996094 10.386719 13.660156 10.570312 13.296875C10.753906 12.933594 10.894531 12.597656 11 12.296875C11.050781 12.046875 11.058594 11.800781 11.015625 11.5625C10.972656 11.324219 10.902344 11.105469 10.796875 10.90625C10.746094 10.855469 10.71875 10.800781 10.71875 10.75C10.71875 10.699219 10.746094 10.644531 10.796875 10.59375L14.5 6.90625C14.75 6.65625 14.875 6.355469 14.875 6C14.875 5.644531 14.75 5.34375 14.5 5.09375ZM9.90625 12.09375C9.855469 12.300781 9.746094 12.5625 9.578125 12.875C9.410156 13.1875 9.25 13.496094 9.09375 13.796875C9.050781 13.847656 9.003906 13.890625 8.953125 13.921875C8.902344 13.953125 8.847656 13.949219 8.796875 13.90625C8.777344 13.875 8.417969 13.511719 7.71875 12.8125C7.03125 12.105469 6.277344 11.332031 5.453125 10.5C4.628906 9.667969 3.871094 8.894531 3.171875 8.1875C2.484375 7.488281 2.125 7.125 2.09375 7.09375C2.050781 7.050781 2.023438 6.988281 2.007812 6.898438C1.992188 6.808594 2.019531 6.746094 2.09375 6.703125C2.34375 6.546875 2.59375 6.398438 2.84375 6.257812C3.09375 6.117188 3.34375 6 3.59375 5.90625C3.699219 5.855469 3.847656 5.847656 4.046875 5.890625C4.246094 5.933594 4.394531 6 4.5 6.09375L9.703125 11.296875C9.796875 11.402344 9.863281 11.523438 9.90625 11.664062C9.949219 11.804688 9.949219 11.949219 9.90625 12.09375ZM9.703125 9.90625L5.90625 6.09375C5.855469 6.050781 5.828125 5.988281 5.828125 5.90625C5.828125 5.824219 5.855469 5.753906 5.90625 5.703125L9.59375 2C9.644531 1.949219 9.714844 1.921875 9.796875 1.921875C9.878906 1.921875 9.949219 1.949219 10 2L13.796875 5.796875C13.847656 5.847656 13.875 5.917969 13.875 6C13.875 6.082031 13.847656 6.152344 13.796875 6.203125L10.09375 9.90625C10.050781 9.949219 9.988281 9.96875 9.90625 9.96875C9.824219 9.96875 9.753906 9.949219 9.703125 9.90625ZM9.703125 9.90625"
fill="inherit"
/>
</svg>
);
1 change: 1 addition & 0 deletions src/interfaces/assistants_web/src/assets/icons/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ export * from './Moon';
export * from './NewMessage';
export * from './OneDrive';
export * from './Paperclip';
export * from './Pin';
export * from './Profile';
export * from './Regenerate';
export * from './Search';
Expand Down
17 changes: 16 additions & 1 deletion src/interfaces/assistants_web/src/cohere-client/client.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import {
CreateSnapshotRequest,
CreateUser,
Fetch,
ToggleConversationPinRequest,
UpdateAgentRequest,
UpdateConversationRequest,
UpdateDeploymentEnv,
Expand Down Expand Up @@ -151,7 +152,12 @@ export class CohereClient {
});
}

public listConversations(params: { offset?: number; limit?: number; agentId?: string }) {
public listConversations(params: {
offset?: number;
limit?: number;
orderBy?: string;
agentId?: string;
}) {
return this.cohereService.default.listConversationsV1ConversationsGet(params);
}

Expand All @@ -174,6 +180,15 @@ export class CohereClient {
});
}

public toggleConversationPin(requestBody: ToggleConversationPinRequest, conversationId: string) {
return this.cohereService.default.toggleConversationPinV1ConversationsConversationIdTogglePinPut(
{
conversationId: conversationId,
requestBody,
}
);
}

public listTools({ agentId }: { agentId?: string | null }) {
return this.cohereService.default.listToolsV1ToolsGet({ agentId });
}
Expand Down
Loading
Loading