Skip to content

Commit

Permalink
dont need to revalidate Message.role with different allowed values
Browse files Browse the repository at this point in the history
  • Loading branch information
Chris Woodson committed Sep 3, 2024
1 parent a3103e1 commit 7f5f560
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 9 deletions.
7 changes: 0 additions & 7 deletions memgpt/schemas/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,13 +70,6 @@ class Message(BaseMessage):
id: str = BaseMessage.generate_id_field()
created_at: datetime = Field(default_factory=get_utc_time, description="The time the message was created.")

@field_validator("role")
@classmethod
def validate_role(cls, v: str) -> str:
roles = ["system", "assistant", "user", "tool"]
assert v in roles, f"Role must be one of {roles}"
return v

def to_json(self):
json_message = vars(self)
if json_message["tool_calls"] is not None:
Expand Down
9 changes: 7 additions & 2 deletions tests/unit/test_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ def _storage_data(self, type: TableType, user_id: str = None, agent_id: str = No
case TableType.RECALL_MEMORY:
data = [MessageCreate(
agent_id=agent_id,
user_id=user_id,
role=random.choice(list(MessageRole)),
text=faker.text(),
name=faker.name(),
Expand All @@ -59,8 +60,12 @@ def test_insert_storage(self, storage, db_session):

storage.insert_many(records=_data)
storage_records = storage.get_all()

item_ids = [item.id for item in storage_records]
with db_session as session:
records = storage.SQLModel.list(db_session=session)

assert len(records) == len(_data) == len(storage_records)
assert len(records) == len(_data) == len(storage_records)

random_item = storage.get(random.choice(item_ids))
assert random_item is not None
assert random_item.text in [item.text for item in _data]

0 comments on commit 7f5f560

Please sign in to comment.