Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Phase 2 | Versioning code changes #687

Draft
wants to merge 12 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 9 additions & 2 deletions api-schema.yml
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ paths:
type: string
format: binary
description: ''
/api/openai/{experiment_id}/chat/completions:
/api/openai/{experiment_id}/{version}/chat/completions:
post:
operationId: openai_chat_completions
description: |2
Expand All @@ -91,7 +91,7 @@ paths:

client = OpenAI(
api_key="your API key",
base_url=f"https://chatbots.dimagi.com/api/openai/{experiment_id}",
base_url=f"https://chatbots.dimagi.com/api/openai/{experiment_id}/default",
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

will this break existing links?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah. There's 4 API channels, so we'll just have to tell 4 people to change the endpoint

)

completion = client.chat.completions.create(
Expand All @@ -112,6 +112,13 @@ paths:
type: string
description: Experiment ID
required: true
- in: path
name: version
schema:
type: string
description: The experiment version. This can be either 'default' or the version
number
required: true
tags:
- OpenAI
requestBody:
Expand Down
21 changes: 18 additions & 3 deletions apps/api/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

from apps.api.serializers import ExperimentSessionCreateSerializer, MessageSerializer
from apps.channels.tasks import handle_api_message
from apps.experiments.models import Experiment


@extend_schema(
Expand All @@ -30,7 +31,7 @@

client = OpenAI(
api_key="your API key",
base_url=f"https://chatbots.dimagi.com/api/openai/{experiment_id}",
base_url=f"https://chatbots.dimagi.com/api/openai/{experiment_id}/default",
)

completion = client.chat.completions.create(
Expand Down Expand Up @@ -83,10 +84,16 @@
location=OpenApiParameter.PATH,
description="Experiment ID",
),
OpenApiParameter(
name="version",
type=OpenApiTypes.STR,
location=OpenApiParameter.PATH,
description=("The experiment version. This can be either 'default' or the version number"),
),
],
)
@api_view(["POST"])
def chat_completions(request, experiment_id: str):
def chat_completions(request, experiment_id: str, version: str):
messages = request.data.get("messages", [])
try:
last_message = messages.pop()
Expand All @@ -107,9 +114,17 @@ def chat_completions(request, experiment_id: str):
return _make_error_response(400, str(e))

session = serializer.save()
if version == "default":
experiment_version = session.default_experiment_version
else:
try:
experiment_version = session.experiment_version(version)
except Experiment.DoesNotExist:
return _make_error_response(404, f"Version {version} was not found for this experiment")

response_message = handle_api_message(
request.user,
session.experiment_version,
experiment_version,
session.experiment_channel,
last_message.get("content"),
session.participant.identifier,
Expand Down
47 changes: 32 additions & 15 deletions apps/api/tests/test_openai_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from unittest.mock import patch

import pytest
from openai import OpenAI
from openai import NotFoundError, OpenAI
from pytest_django.fixtures import live_server_helper

from apps.api.models import UserAPIKey
Expand Down Expand Up @@ -48,26 +48,43 @@ def api_key(team_with_users):
available_apps=["apps.api", "apps.experiments", "apps.teams", "apps.users"],
serialized_rollback=True,
)
@pytest.mark.parametrize(
("version", "version_exists"),
[
("default", True),
(2, False),
],
)
@patch("apps.chat.channels.ApiChannel._get_bot_response")
def test_chat_completion(mock_experiment_response, experiment, api_key, live_server):
def test_chat_completion(mock_experiment_response, version, version_exists, experiment, api_key, live_server):
mock_experiment_response.return_value = "I am fine, thank you."

base_url = f"{live_server.url}/api/openai/{experiment.public_id}"
base_url = f"{live_server.url}/api/openai/{experiment.public_id}/{version}"

client = OpenAI(
api_key=api_key,
base_url=base_url,
)

completion = client.chat.completions.create(
model="gpt-3.5-turbo",
messages=[
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "Hi, how are you?"},
],
)

assert ExperimentSession.objects.count() == 1
assert completion.id == ExperimentSession.objects.first().external_id
assert completion.model == experiment.llm
assert completion.choices[0].message.content == "I am fine, thank you."
if version_exists:
completion = client.chat.completions.create(
model="gpt-3.5-turbo",
messages=[
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "Hi, how are you?"},
],
)

assert ExperimentSession.objects.count() == 1
assert completion.id == ExperimentSession.objects.first().external_id
assert completion.model == experiment.llm
assert completion.choices[0].message.content == "I am fine, thank you."
else:
with pytest.raises(NotFoundError, match=f"Version {version} was not found for this experiment"):
client.chat.completions.create(
model="gpt-3.5-turbo",
messages=[
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "Hi, how are you?"},
],
)
6 changes: 5 additions & 1 deletion apps/api/urls.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,11 @@

urlpatterns = [
path("participants/", views.update_participant_data, name="update-participant-data"),
path("openai/<str:experiment_id>/chat/completions", openai.chat_completions, name="openai-chat-completions"),
path(
"openai/<str:experiment_id>/<str:version>/chat/completions",
openai.chat_completions,
name="openai-chat-completions",
),
path("files/<int:pk>/content", views.file_content_view, name="file-content"),
path("", include(router.urls)),
]
4 changes: 2 additions & 2 deletions apps/chat/bots.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ class TopicBot:
"""

def __init__(self, session: ExperimentSession, experiment: Experiment | None = None, disable_tools: bool = False):
self.experiment = experiment or session.experiment_version
self.experiment = experiment or session.default_experiment_version
self.disable_tools = disable_tools
self.prompt = self.experiment.prompt_text
self.input_formatter = self.experiment.input_formatter
Expand Down Expand Up @@ -254,7 +254,7 @@ def filter_ai_messages(self) -> bool:

class PipelineBot:
def __init__(self, session: ExperimentSession):
self.experiment = session.experiment_version
self.experiment = session.default_experiment_version
self.session = session

def process_input(self, user_input: str, save_input_to_history=True, attachments: list["Attachment"] | None = None):
Expand Down
6 changes: 3 additions & 3 deletions apps/events/actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,9 +49,9 @@ def invoke(self, session: ExperimentSession, action) -> str:
history = session.chat.get_langchain_messages_until_summary()
current_summary = history.pop(0).content if history[0].type == ChatMessageType.SYSTEM else ""
messages = session.chat.get_langchain_messages()
summary = SummarizerMixin(llm=session.experiment.get_chat_model(), prompt=prompt).predict_new_summary(
messages, current_summary
)
summary = SummarizerMixin(
llm=session.default_experiment_version.get_chat_model(), prompt=prompt
).predict_new_summary(messages, current_summary)

return summary

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
# Generated by Django 5.1 on 2024-09-17 08:00

from django.db import migrations, models


class Migration(migrations.Migration):

dependencies = [
('events', '0017_statictrigger_working_version_and_more'),
]

operations = [
migrations.AddField(
model_name='statictrigger',
name='is_archived',
field=models.BooleanField(default=False),
),
migrations.AddField(
model_name='timeouttrigger',
name='is_archived',
field=models.BooleanField(default=False),
),
]
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def update_periodic_task(apps, schema_editor):
class Migration(migrations.Migration):

dependencies = [
('events', '0017_statictrigger_working_version_and_more'),
('events', '0018_statictrigger_is_archived_timeouttrigger_is_archived'),
]

operations = [migrations.RunPython(update_periodic_task, migrations.RunPython.noop)]
37 changes: 33 additions & 4 deletions apps/events/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from apps.events.const import TOTAL_FAILURES
from apps.experiments.models import Experiment, ExperimentSession, VersionsMixin
from apps.teams.models import BaseTeamModel
from apps.teams.utils import current_team
from apps.utils.models import BaseModel
from apps.utils.slug import get_next_unique_id
from apps.utils.time import pretty_date
Expand Down Expand Up @@ -89,6 +90,17 @@ class StaticTriggerType(models.TextChoices):
PARTICIPANT_JOINED_EXPERIMENT = ("participant_joined", "A new participant joined the experiment")


class StaticTriggerObjectManager(models.Manager):
def get_queryset(self):
return super().get_queryset().filter(is_archived=False)


# TODO: Can we have a versions Manager mixin?
class TimeoutTriggerObjectManager(models.Manager):
def get_queryset(self):
return super().get_queryset().filter(is_archived=False)


class StaticTrigger(BaseModel, VersionsMixin):
action = models.OneToOneField(EventAction, on_delete=models.CASCADE, related_name="static_trigger")
experiment = models.ForeignKey(Experiment, on_delete=models.CASCADE, related_name="static_triggers")
Expand All @@ -101,14 +113,17 @@ class StaticTrigger(BaseModel, VersionsMixin):
blank=True,
related_name="versions",
)
is_archived = models.BooleanField(default=False)
objects = StaticTriggerObjectManager()

@property
def trigger_type(self):
return "StaticTrigger"

def fire(self, session):
try:
result = ACTION_HANDLERS[self.action.action_type]().invoke(session, self.action)
with current_team(session.team):
result = ACTION_HANDLERS[self.action.action_type]().invoke(session, self.action)
self.event_logs.create(session=session, status=EventLogStatusChoices.SUCCESS, log=result)
return result
except Exception as e:
Expand All @@ -121,6 +136,10 @@ def delete(self, *args, **kwargs):
self.action.delete(*args, **kwargs)
return result

def archive(self):
self.is_archived = True
self.save()

@transaction.atomic()
def create_new_version(self, new_experiment: Experiment):
"""Create a duplicate and assign the `new_experiment` to it. Also duplicate all EventActions"""
Expand Down Expand Up @@ -149,6 +168,8 @@ class TimeoutTrigger(BaseModel, VersionsMixin):
blank=True,
related_name="versions",
)
is_archived = models.BooleanField(default=False)
objects = TimeoutTriggerObjectManager()

@transaction.atomic()
def create_new_version(self, new_experiment: Experiment):
Expand Down Expand Up @@ -279,6 +300,10 @@ def _has_triggers_left(self, session, message):

return not (has_succeeded or failed)

def archive(self):
self.is_archived = True
self.save()


class TimePeriod(models.TextChoices):
HOURS = ("hours", "Hours")
Expand Down Expand Up @@ -333,12 +358,16 @@ def safe_trigger(self):
logger.exception(f"An error occured while trying to send scheduled messsage {self.id}. Error: {e}")

def _trigger(self):
experiment_id = self.params.get("experiment_id", self.experiment.id)
experiment_session = self.participant.get_latest_session(experiment=self.experiment)
if experiment_id := self.params.get("experiment_id"):
experiment_to_use = Experiment.objects.get(id=experiment_id)
else:
experiment_to_use = self.experiment.default_version

experiment_session = self.participant.get_latest_session(experiment=self.experiment.get_working_version())
if not experiment_session:
# Schedules probably created by the API
return
experiment_to_use = Experiment.objects.get(id=experiment_id)

experiment_session.ad_hoc_bot_message(
self.params["prompt_text"], fail_silently=False, use_experiment=experiment_to_use
)
Expand Down
15 changes: 9 additions & 6 deletions apps/events/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from celery.app import shared_task
from django.db.models import functions

from apps.events.models import ScheduledMessage, StaticTrigger, TimeoutTrigger
from apps.events.models import ScheduledMessage, StaticTrigger, StaticTriggerType, TimeoutTrigger
from apps.experiments.models import ExperimentSession

logger = logging.getLogger(__name__)
Expand All @@ -12,14 +12,17 @@
@shared_task(ignore_result=True)
def enqueue_static_triggers(session_id, trigger_type):
session = ExperimentSession.objects.get(id=session_id)

trigger_ids = StaticTrigger.objects.filter(experiment_id=session.experiment_id, type=trigger_type).values_list(
"id", flat=True
)
for trigger_id in trigger_ids:
for trigger_id in _get_triggers_to_fire(session, trigger_type):
fire_static_trigger.delay(trigger_id, session_id)


def _get_triggers_to_fire(session: ExperimentSession, trigger_type: StaticTriggerType) -> list[int]:
trigger_ids = StaticTrigger.objects.filter(
experiment=session.default_experiment_version, type=trigger_type
).values_list("id", flat=True)
return trigger_ids


@shared_task(ignore_result=True)
def fire_static_trigger(trigger_id, session_id):
trigger = StaticTrigger.objects.get(id=trigger_id)
Expand Down
Loading
Loading