
feat: model deletion
* Add support for the deletion of fine-tuned models

---------

Co-authored-by: arresejo <[email protected]>
jean-malo and arresejo authored Jun 19, 2024
1 parent cd72257 commit af48070
Showing 9 changed files with 93 additions and 4 deletions.
3 changes: 3 additions & 0 deletions examples/async_jobs_chat.py
@@ -54,6 +54,9 @@ async def main():
await client.files.delete(training_file.id)
await client.files.delete(validation_file.id)

# Delete fine-tuned model
await client.delete_model(created_job.fine_tuned_model)


if __name__ == "__main__":
asyncio.run(main())
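
The deletion above assumes the fine-tuning job created earlier in this example has completed, so that created_job.fine_tuned_model is populated. A hypothetical guard one might add (not part of this commit):

# Hypothetical guard: only attempt deletion once the job has produced a model.
# created_job is the fine-tuning job object used earlier in this example.
if created_job.fine_tuned_model is not None:
    await client.delete_model(created_job.fine_tuned_model)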
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
[tool.poetry]
name = "mistralai"
version = "0.4.0"
version = "0.4.1"
description = ""
authors = ["Bam4d <[email protected]>"]
readme = "README.md"
10 changes: 9 additions & 1 deletion src/mistralai/async_client.py
@@ -29,7 +29,7 @@
ToolChoice,
)
from mistralai.models.embeddings import EmbeddingResponse
from mistralai.models.models import ModelList
from mistralai.models.models import ModelDeleted, ModelList


class MistralAsyncClient(ClientBase):
@@ -304,6 +304,14 @@ async def list_models(self) -> ModelList:

raise MistralException("No response received")

async def delete_model(self, model_id: str) -> ModelDeleted:
single_response = self._request("delete", {}, f"v1/models/{model_id}")

async for response in single_response:
return ModelDeleted(**response)

raise MistralException("No response received")

async def completion(
self,
model: str,
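
A minimal standalone sketch of the new async method, assuming MISTRAL_API_KEY is set in the environment and using a placeholder model id; MistralException is the same error the method raises when no response comes back:

import asyncio
import os

from mistralai.async_client import MistralAsyncClient
from mistralai.exceptions import MistralException


async def delete_example():
    client = MistralAsyncClient(api_key=os.environ["MISTRAL_API_KEY"])
    try:
        # Returns a ModelDeleted instance parsed from the API response.
        deleted = await client.delete_model("<your-fine-tuned-model-id>")
        print(deleted.id, deleted.deleted)
    except MistralException as exc:
        print(f"Deletion failed: {exc}")


asyncio.run(delete_example())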
10 changes: 9 additions & 1 deletion src/mistralai/client.py
@@ -22,7 +22,7 @@
ToolChoice,
)
from mistralai.models.embeddings import EmbeddingResponse
from mistralai.models.models import ModelList
from mistralai.models.models import ModelDeleted, ModelList


class MistralClient(ClientBase):
@@ -298,6 +298,14 @@ def list_models(self) -> ModelList:

raise MistralException("No response received")

def delete_model(self, model_id: str) -> ModelDeleted:
single_response = self._request("delete", {}, f"v1/models/{model_id}")

for response in single_response:
return ModelDeleted(**response)

raise MistralException("No response received")

def completion(
self,
model: str,
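
The synchronous client mirrors the async method; a minimal usage sketch under the same assumptions (API key in the environment, placeholder model id):

import os

from mistralai.client import MistralClient

client = MistralClient(api_key=os.environ["MISTRAL_API_KEY"])

# deleted is a ModelDeleted instance; deleted.deleted is True once the API confirms removal.
deleted = client.delete_model("<your-fine-tuned-model-id>")
print(deleted.id, deleted.object, deleted.deleted)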
2 changes: 1 addition & 1 deletion src/mistralai/client_base.py
@@ -10,7 +10,7 @@
)
from mistralai.models.chat_completion import ChatMessage, Function, ResponseFormat, ToolChoice

CLIENT_VERSION = "0.4.0"
CLIENT_VERSION = "0.4.1"


class ClientBase(ABC):
6 changes: 6 additions & 0 deletions src/mistralai/models/models.py
@@ -31,3 +31,9 @@ class ModelCard(BaseModel):
class ModelList(BaseModel):
object: str
data: List[ModelCard]


class ModelDeleted(BaseModel):
id: str
object: str
deleted: bool
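
A quick sketch of how this model parses the deletion payload, with field names and example values taken from the tests and mock payload below:

from mistralai.models.models import ModelDeleted

payload = '{"id": "model_id", "object": "model", "deleted": true}'
deleted = ModelDeleted.model_validate_json(payload)
assert deleted.id == "model_id" and deleted.deleted is True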
26 changes: 26 additions & 0 deletions tests/test_delete_model.py
@@ -0,0 +1,26 @@
from mistralai.models.models import ModelDeleted

from .utils import mock_model_deleted_response_payload, mock_response


class TestDeleteModel:
def test_delete_model(self, client):
expected_response_model = ModelDeleted.model_validate_json(mock_model_deleted_response_payload())
client._client.request.return_value = mock_response(200, expected_response_model.json())

response_model = client.delete_model("model_id")

client._client.request.assert_called_once_with(
"delete",
"https://api.mistral.ai/v1/models/model_id",
headers={
"User-Agent": f"mistral-client-python/{client._version}",
"Accept": "application/json",
"Content-Type": "application/json",
"Authorization": "Bearer test_api_key",
},
json={},
data=None,
)

assert response_model == expected_response_model
28 changes: 28 additions & 0 deletions tests/test_delete_model_async.py
@@ -0,0 +1,28 @@
import pytest
from mistralai.models.models import ModelDeleted

from .utils import mock_model_deleted_response_payload, mock_response


class TestAsyncDeleteModel:
@pytest.mark.asyncio
async def test_delete_model(self, async_client):
expected_response_model = ModelDeleted.model_validate_json(mock_model_deleted_response_payload())
async_client._client.request.return_value = mock_response(200, expected_response_model.json())

response_model = await async_client.delete_model("model_id")

async_client._client.request.assert_called_once_with(
"delete",
"https://api.mistral.ai/v1/models/model_id",
headers={
"User-Agent": f"mistral-client-python/{async_client._version}",
"Accept": "application/json",
"Content-Type": "application/json",
"Authorization": "Bearer test_api_key",
},
json={},
data=None,
)

assert response_model == expected_response_model
10 changes: 10 additions & 0 deletions tests/utils.py
@@ -318,3 +318,13 @@ def mock_file_deleted_response_payload() -> str:
"deleted": True,
}
)


def mock_model_deleted_response_payload() -> str:
return orjson.dumps(
{
"id": "model_id",
"object": "model",
"deleted": True,
}
)
