Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add tiktoken cache #962

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 22 additions & 2 deletions src/marvin/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,7 @@ class MarvinSettings(BaseSettings):

def __setattr__(self, name: str, value: Any) -> None:
# wrap bare strings in SecretStr if the field is annotated with SecretStr
field = self.model_fields.get(name)
if field:
if field := self.model_fields.get(name):
annotation = field.annotation
base_types = (
getattr(annotation, "__args__", None)
Expand All @@ -32,6 +31,18 @@ def __setattr__(self, name: str, value: Any) -> None:
super().__setattr__(name, value)


class TiktokenSettings(MarvinSettings):
    # Settings controlling how tiktoken encodings are fetched and cached.
    # Values are read from environment variables carrying the
    # "marvin_tiktoken_" prefix; unrecognised extras are ignored.
    model_config = SettingsConfigDict(
        extra="ignore",
        env_prefix="marvin_tiktoken_",
    )

    # When set, the encoder property on ChatCompletionSettings exports this
    # path as TIKTOKEN_CACHE_DIR before loading an encoding.
    cache_dir: Optional[str] = Field(
        description="Directory to store cached tiktoken encoding files.",
        default=None,
    )

    # When False, the encoder property on ChatCompletionSettings swaps the
    # default HTTPS context for an unverified one.
    verify_ssl: bool = Field(
        description="Whether to verify SSL certificates for tiktoken requests.",
        default=True,
    )


class ChatCompletionSettings(MarvinSettings):
model_config = SettingsConfigDict(
env_prefix="marvin_chat_completions_", extra="ignore"
Expand All @@ -40,10 +51,19 @@ class ChatCompletionSettings(MarvinSettings):

temperature: float = Field(description="The default temperature to use.", default=1)

tiktoken: TiktokenSettings = Field(default_factory=TiktokenSettings)

@property
def encoder(self):
import tiktoken

if self.tiktoken.cache_dir:
os.environ["TIKTOKEN_CACHE_DIR"] = self.tiktoken.cache_dir
if not self.tiktoken.verify_ssl:
import ssl

ssl._create_default_https_context = ssl._create_unverified_context

try:
encoding = tiktoken.encoding_for_model(self.model)
except KeyError:
Expand Down
35 changes: 35 additions & 0 deletions tests/utilities/test_tiktoken.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
import os
from unittest.mock import MagicMock, patch

from marvin.settings import ChatCompletionSettings, settings, temporary_settings


def test_tiktoken_cache_dir_setting(tmp_path):
    """Check that a configured tiktoken cache dir is exported to the environment.

    The encoder property writes ``TIKTOKEN_CACHE_DIR`` into ``os.environ``, so
    the whole test runs against a snapshot of the environment that is restored
    on exit; nothing it sets can leak into other tests.
    """
    with patch.dict(os.environ):
        # Start from a known state regardless of the developer's shell or of
        # tests that ran earlier in the same process.
        os.environ.pop("TIKTOKEN_CACHE_DIR", None)

        # Stub the encoding lookup (as test_tiktoken_default_behavior does) so
        # the test never has to load real encoding files into the empty
        # tmp_path cache.
        with patch("tiktoken.encoding_for_model"):
            with temporary_settings(
                openai__chat__completions__tiktoken__cache_dir=str(tmp_path)
            ):
                _ = settings.openai.chat.completions.encoder
                assert os.environ.get("TIKTOKEN_CACHE_DIR") == str(tmp_path)

        # Check that the environment is cleaned up after the test
        # NOTE(review): the encoder itself never removes the variable; this
        # relies on temporary_settings (or other cleanup) doing so - confirm.
        assert "TIKTOKEN_CACHE_DIR" not in os.environ


def test_tiktoken_default_behavior():
    """Check the encoder with default settings (no cache dir, SSL verification on).

    With defaults the encoder must neither export ``TIKTOKEN_CACHE_DIR`` nor
    replace the default HTTPS context factory, and it must look up the
    encoding for the configured model exactly once.
    """
    import ssl

    # Work on a snapshot of the environment so the assertion below measures
    # what the encoder did, not what the surrounding shell or an earlier test
    # happened to leave behind. The snapshot is restored on exit.
    with patch.dict(os.environ):
        os.environ.pop("TIKTOKEN_CACHE_DIR", None)

        with patch("tiktoken.encoding_for_model") as mock_encoding:
            mock_encoding.return_value = MagicMock()

            chat_settings = ChatCompletionSettings()
            _ = chat_settings.encoder

            # Check that TIKTOKEN_CACHE_DIR is not set
            assert "TIKTOKEN_CACHE_DIR" not in os.environ

            # Check that SSL verification is not modified
            assert (
                ssl._create_default_https_context
                is not ssl._create_unverified_context
            )

            mock_encoding.assert_called_once_with(chat_settings.model)