Skip to content

Commit

Permalink
Merge pull request #698 from dimagi/cs/chat_to_version
Browse files Browse the repository at this point in the history
Ability to chat to a specific version
  • Loading branch information
SmittieC authored Oct 9, 2024
2 parents df2890c + d932292 commit 6f32256
Show file tree
Hide file tree
Showing 15 changed files with 196 additions and 50 deletions.
20 changes: 20 additions & 0 deletions apps/channels/tests/test_web_channel.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import pytest
from django.http import Http404
from django.test import override_settings
from mock.mock import Mock, patch

from apps.channels.models import ChannelPlatform
from apps.chat.channels import WebChannel
from apps.chat.models import Chat


@pytest.mark.django_db()
Expand Down Expand Up @@ -40,6 +42,12 @@ def test_start_new_session(new_user_message, get_ai_message_id, with_seed_messag
assert message.attachments == []


@pytest.mark.django_db()
def test_404_raised_when_version_is_not_found(experiment):
with pytest.raises(Http404):
WebChannel.start_new_session(experiment, "[email protected]", version=20)


@pytest.mark.django_db()
class TestVersioning:
@patch("apps.events.tasks.enqueue_static_triggers", Mock())
Expand All @@ -54,3 +62,15 @@ def test_start_new_session_uses_default_version(self, check_and_process_seed_mes
_session_used, experiment_used = check_and_process_seed_message.call_args[0]
assert experiment_used == new_version
assert session.experiment == experiment
assert session.chat.metadata.get(Chat.MetadataKeys.EXPERIMENT_VERSION) == experiment.DEFAULT_VERSION_NUMBER

@patch("apps.events.tasks.enqueue_static_triggers", Mock())
@patch("apps.chat.channels.WebChannel.check_and_process_seed_message")
def test_start_new_session_uses_specified_version(self, check_and_process_seed_message, experiment):
new_version = experiment.create_new_version()
session = WebChannel.start_new_session(experiment, "[email protected]", version=1)

_session_used, experiment_used = check_and_process_seed_message.call_args[0]
assert experiment_used == new_version
assert session.experiment == experiment
assert session.chat.metadata.get(Chat.MetadataKeys.EXPERIMENT_VERSION) == 1
5 changes: 2 additions & 3 deletions apps/chat/agent/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,10 +227,9 @@ def create_schedule_message(
}


def get_tools(experiment_session, for_assistant=False) -> list[BaseTool]:
def get_tools(experiment_session, experiment) -> list[BaseTool]:
tools = []
experiment = experiment_session.experiment
tool_names = experiment.assistant.tools if for_assistant else experiment.tools
tool_names = experiment.assistant.tools if experiment.assistant else experiment.tools
for tool_name in tool_names:
tool_cls = TOOL_CLASS_MAP[tool_name]
tools.append(tool_cls(experiment_session=experiment_session))
Expand Down
13 changes: 11 additions & 2 deletions apps/chat/channels.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import emoji
import requests
from django.db import transaction
from django.http import Http404
from telebot import TeleBot
from telebot.util import antiflood, smart_split

Expand All @@ -21,7 +22,7 @@
ParticipantNotAllowedException,
VersionedExperimentSessionsNotAllowedException,
)
from apps.chat.models import ChatMessage, ChatMessageType
from apps.chat.models import Chat, ChatMessage, ChatMessageType
from apps.events.models import StaticTriggerType
from apps.events.tasks import enqueue_static_triggers
from apps.experiments.models import (
Expand Down Expand Up @@ -542,12 +543,20 @@ def start_new_session(
participant_user: CustomUser | None = None,
session_status: SessionStatus = SessionStatus.ACTIVE,
timezone: str | None = None,
version: int = Experiment.DEFAULT_VERSION_NUMBER,
):
experiment_channel = ExperimentChannel.objects.get_team_web_channel(working_experiment.team)
session = super().start_new_session(
working_experiment, experiment_channel, participant_identifier, participant_user, session_status, timezone
)
WebChannel.check_and_process_seed_message(session, working_experiment.default_version)

try:
experiment_version = working_experiment.get_version(version)
session.chat.set_metadata(Chat.MetadataKeys.EXPERIMENT_VERSION, version)
except Experiment.DoesNotExist:
raise Http404(f"Experiment with version {version} not found")

WebChannel.check_and_process_seed_message(session, experiment_version)
return session

@classmethod
Expand Down
1 change: 1 addition & 0 deletions apps/chat/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ class Chat(BaseTeamModel, TaggedModelMixin, UserCommentsMixin):

class MetadataKeys(StrEnum):
OPENAI_THREAD_ID = "openai_thread_id"
EXPERIMENT_VERSION = "experiment_version"

# must match or be greater than experiment name field
name = models.CharField(max_length=128, default="Unnamed Chat")
Expand Down
27 changes: 27 additions & 0 deletions apps/experiments/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -436,6 +436,9 @@ class Experiment(BaseTeamModel, VersionsMixin):
Each experiment can be run as a chatbot.
"""

# 0 is a reserved version number, meaning the default version
DEFAULT_VERSION_NUMBER = 0

owner = models.ForeignKey(settings.AUTH_USER_MODEL, on_delete=models.CASCADE)
name = models.CharField(max_length=128)
description = models.TextField(null=True, default="", verbose_name="A longer description of the experiment.") # noqa DJ001
Expand Down Expand Up @@ -596,6 +599,18 @@ def save(self, *args, **kwargs):
def get_absolute_url(self):
return reverse("experiments:single_experiment_home", args=[self.team.slug, self.id])

def get_version(self, version: int) -> "Experiment":
"""
Returns the version of this experiment family matching `version`. If `version` is 0, the default version is
returned.
"""
working_version = self.get_working_version()
if version == self.DEFAULT_VERSION_NUMBER:
return working_version.default_version
elif working_version.version_number == version:
return working_version
return working_version.versions.get(version_number=version)

@property
def tools_enabled(self):
return len(self.tools) > 0
Expand Down Expand Up @@ -1264,6 +1279,11 @@ def working_experiment(self) -> Experiment:
"""Returns the default experiment, or if there is none, the working experiment"""
return self.experiment.get_working_version()

@property
def experiment_version_for_display(self):
version_number = self.get_experiment_version_number()
return "Default version" if version_number == Experiment.DEFAULT_VERSION_NUMBER else f"v{version_number}"

def get_participant_timezone(self):
participant_data = self.participant_data_from_experiment
return participant_data.get("timezone")
Expand All @@ -1280,3 +1300,10 @@ def get_participant_data(self, use_participant_tz=False):
if scheduled_messages:
participant_data = {**participant_data, "scheduled_messages": scheduled_messages}
return self.participant.global_data | participant_data

def get_experiment_version_number(self) -> int:
"""
Returns the version that is being chatted to. If it's the default version, return 0 which is the default
experiment's version number
"""
return self.chat.metadata.get(Chat.MetadataKeys.EXPERIMENT_VERSION, Experiment.DEFAULT_VERSION_NUMBER)
3 changes: 2 additions & 1 deletion apps/experiments/tables.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,7 @@ class ExperimentSessionsTable(tables.Table):
verbose_name="Tags",
template_name="annotations/tag_ui.html",
)
version = columns.Column(verbose_name="Version", accessor="experiment_version_for_display")
actions = actions.chip_column(label="Session Details", align="center", verbose_name="")

def render_tags(self, record, bound_column):
Expand All @@ -156,7 +157,7 @@ class ExperimentVersionsTable(tables.Table):
template_code="""{% if record.is_default_version %}
<span aria-label="true">✓</span>
{% endif %}""",
verbose_name="Default Version",
verbose_name="Deployed",
)
details = columns.TemplateColumn(
template_name="experiments/components/experiment_version_details_button.html",
Expand Down
40 changes: 32 additions & 8 deletions apps/experiments/tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from django.utils import timezone
from freezegun import freeze_time

from apps.chat.models import Chat
from apps.events.actions import ScheduleTriggerAction
from apps.events.models import EventActionType, ScheduledMessage, TimePeriod
from apps.experiments.models import (
Expand Down Expand Up @@ -105,6 +106,7 @@ def test_get_for_team_do_not_include_other_team_exclusive_voices(self):
assert voice1 not in voices_queryset


@pytest.mark.django_db()
class TestExperimentSession:
def _construct_event_action(self, time_period: TimePeriod, experiment_id: int, frequency=1, repetitions=1) -> tuple:
params = self._get_params(experiment_id, time_period, frequency, repetitions)
Expand All @@ -120,7 +122,6 @@ def _get_params(self, experiment_id: int, time_period: TimePeriod = TimePeriod.D
"experiment_id": experiment_id,
}

@pytest.mark.django_db()
@freeze_time("2024-01-01")
def test_get_participant_scheduled_messages_custom_params(self):
session = ExperimentSessionFactory()
Expand Down Expand Up @@ -175,7 +176,6 @@ def _make_expected_dict(external_id):
]
assert participant.get_schedules_for_experiment(experiment, as_dict=True) == expected_dict_version

@pytest.mark.django_db()
@pytest.mark.parametrize(
("repetitions", "total_triggers", "expected_triggers_remaining"),
[
Expand Down Expand Up @@ -209,7 +209,6 @@ def test_get_schedules_for_experiment_as_dict(self, repetitions, total_triggers,
assert schedule["total_triggers"] == total_triggers
assert schedule["triggers_remaining"] == expected_triggers_remaining

@pytest.mark.django_db()
@freeze_time("2024-01-01")
@pytest.mark.parametrize(
("time_period", "repetitions", "total_triggers", "expected"),
Expand Down Expand Up @@ -288,7 +287,6 @@ def test_get_schedules_for_experiment_as_string(self, time_period, repetitions,
next_trigger = "Next trigger is at Monday, 01 January 2024 00:00:00 UTC."
assert schedule == expected.format(message=message, next_trigger=next_trigger)

@pytest.mark.django_db()
def test_get_participant_scheduled_messages_includes_child_experiments(self):
session = ExperimentSessionFactory()
team = session.team
Expand All @@ -304,7 +302,6 @@ def test_get_participant_scheduled_messages_includes_child_experiments(self):
assert len(participant.get_schedules_for_experiment(session2.experiment)) == 1
assert len(participant.get_schedules_for_experiment(session.experiment)) == 2

@pytest.mark.django_db()
@pytest.mark.parametrize("use_custom_experiment", [False, True])
def test_scheduled_message_experiment(self, use_custom_experiment):
"""ScheduledMessages should use the experiment specified in the linked action's params"""
Expand Down Expand Up @@ -350,7 +347,6 @@ def test_should_mark_complete(self, repetitions, total_triggers, end_date, expec
)
assert scheduled_message._should_mark_complete() == expected

@pytest.mark.django_db()
def test_get_participant_data_name(self):
participant = ParticipantFactory()
session = ExperimentSessionFactory(participant=participant, team=participant.team)
Expand All @@ -376,7 +372,6 @@ def test_get_participant_data_name(self):
"first_name": "Jimmy",
}

@pytest.mark.django_db()
@freeze_time("2022-01-01 08:00:00")
@pytest.mark.parametrize("use_participant_tz", [False, True])
def test_get_participant_data_timezone(self, use_participant_tz):
Expand Down Expand Up @@ -407,7 +402,6 @@ def test_get_participant_data_timezone(self, use_participant_tz):
participant_data.pop("scheduled_messages")
assert participant_data == expected_data

@pytest.mark.django_db()
@pytest.mark.parametrize("fail_silently", [True, False])
@patch("apps.chat.channels.ChannelBase.from_experiment_session")
@patch("apps.chat.bots.TopicBot.process_input")
Expand All @@ -432,6 +426,17 @@ def _test():
else:
_test()

@pytest.mark.parametrize(
("chat_metadata_version", "expected_display_val"),
[
(Experiment.DEFAULT_VERSION_NUMBER, "Default version"),
("1", "v1"),
],
)
def test_experiment_version_for_display(self, chat_metadata_version, expected_display_val, experiment_session):
experiment_session.chat.set_metadata(Chat.MetadataKeys.EXPERIMENT_VERSION, chat_metadata_version)
assert experiment_session.experiment_version_for_display == expected_display_val


class TestParticipant:
@pytest.mark.django_db()
Expand Down Expand Up @@ -723,6 +728,25 @@ def _assert_attribute_duplicated(self, attr_name, original_experiment, new_versi
expected_changed_fields=["id", "working_version_id"],
)

def test_get_version(self, experiment):
"""Test that we are able to find a specific experiment version using any experiment in the version family"""
first_version = experiment.create_new_version()
second_version = experiment.create_new_version(make_default=True)
experiment.refresh_from_db()
first_version.refresh_from_db()

# Get the default version
assert experiment.get_version(Experiment.DEFAULT_VERSION_NUMBER) == second_version
assert first_version.get_version(Experiment.DEFAULT_VERSION_NUMBER) == second_version

# Get the working version. Its version number should be 3
assert experiment.get_version(3) == experiment
assert first_version.get_version(3) == experiment

# Get a specific version
assert experiment.get_version(1) == first_version
assert first_version.get_version(1) == first_version


@pytest.mark.django_db()
class TestExperimentObjectManager:
Expand Down
39 changes: 34 additions & 5 deletions apps/experiments/tests/test_views.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from waffle.testutils import override_flag

from apps.chat.channels import WebChannel
from apps.chat.models import Chat
from apps.experiments.models import (
AgentTools,
Experiment,
Expand Down Expand Up @@ -354,17 +355,21 @@ def test_timezone_saved_in_participant_data(_trigger_mock):


@pytest.mark.django_db()
@pytest.mark.parametrize("version", [Experiment.DEFAULT_VERSION_NUMBER, 1])
@mock.patch("apps.chat.channels.enqueue_static_triggers", mock.Mock())
@mock.patch("apps.experiments.views.experiment.get_response_for_webchat_task.delay")
def test_experiment_session_message_view_creates_files(delay_mock, experiment, client):
def test_experiment_session_message_view_creates_files(delay_mock, version, experiment, client):
task = mock.Mock()
task.task_id = 1
delay_mock.return_value = task
session = ExperimentSessionFactory(experiment=experiment, participant=ParticipantFactory(user=experiment.owner))
url = reverse(
"experiments:experiment_session_message",
kwargs={"team_slug": experiment.team.slug, "experiment_id": experiment.id, "session_id": session.id},
)
url_kwargs = {
"team_slug": experiment.team.slug,
"experiment_id": experiment.id,
"session_id": session.id,
"version_number": version,
}
url = reverse("experiments:experiment_session_message", kwargs=url_kwargs)

client.force_login(experiment.owner)
file_search_file = BytesIO(b"some content")
Expand Down Expand Up @@ -396,3 +401,27 @@ def test_get_queryset(self, experiment):
view = ExperimentTableView()
view.request = request
assert list(view.get_queryset().all()) == [experiment]


@pytest.mark.django_db()
@pytest.mark.parametrize("version_number", [1, 2])
def test_start_authed_web_session_with_version(version_number, client):
team = TeamWithUsersFactory()
working_experiment = ExperimentFactory(team=team)
working_experiment.create_new_version()

client.force_login(working_experiment.team.members.first())
url = reverse(
"experiments:start_authed_web_session",
kwargs={
"team_slug": working_experiment.team.slug,
"experiment_id": working_experiment.id,
"version_number": version_number,
},
)

response = client.post(url, data={})
assert response.status_code == 302
assert working_experiment.sessions.count() == 1
expected_chat_metadata = {Chat.MetadataKeys.EXPERIMENT_VERSION: version_number}
assert working_experiment.sessions.first().chat.metadata == expected_chat_metadata
8 changes: 5 additions & 3 deletions apps/experiments/urls.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,17 +60,19 @@
path("e/<int:pk>/add_file/", views.AddFileToExperiment.as_view(), name="add_file"),
path("e/<int:pk>/delete_file/<int:file_id>/", views.DeleteFileFromExperiment.as_view(), name="remove_file"),
path(
"e/<int:experiment_id>/start_authed_web_session/",
"e/<int:experiment_id>/v/<int:version_number>/start_authed_web_session/",
views.start_authed_web_session,
name="start_authed_web_session",
),
path("e/<int:experiment_id>/create_channel/", views.create_channel, name="create_channel"),
path("e/<int:experiment_id>/update_channel/<int:channel_id>/", views.update_delete_channel, name="update_channel"),
path(
"e/<int:experiment_id>/session/<int:session_id>/", views.experiment_chat_session, name="experiment_chat_session"
"e/<int:experiment_id>/v/<int:version_number>/session/<int:session_id>/",
views.experiment_chat_session,
name="experiment_chat_session",
),
path(
"e/<int:experiment_id>/session/<int:session_id>/message/",
"e/<int:experiment_id>/v/<int:version_number>/session/<int:session_id>/message/",
views.experiment_session_message,
name="experiment_session_message",
),
Expand Down
Loading

0 comments on commit 6f32256

Please sign in to comment.