Skip to content

Commit

Permalink
Merge pull request #995 from dimagi/cs/scheduled_message_updates
Browse files Browse the repository at this point in the history
Fix: Use correct versioned bot for scheduled messages
  • Loading branch information
SmittieC authored Dec 18, 2024
2 parents f4a633e + 1b2891c commit 7b85ecc
Show file tree
Hide file tree
Showing 4 changed files with 73 additions and 6 deletions.
24 changes: 21 additions & 3 deletions apps/events/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -352,14 +352,13 @@ def safe_trigger(self):
logger.exception(f"An error occured while trying to send scheduled messsage {self.id}. Error: {e}")

def _trigger(self):
experiment_id = self.params.get("experiment_id", self.experiment.id)
experiment_session = self.participant.get_latest_session(experiment=self.experiment)
if not experiment_session:
# Schedules probably created by the API
return
experiment_to_use = Experiment.objects.get(id=experiment_id)

experiment_session.ad_hoc_bot_message(
self.params["prompt_text"], fail_silently=False, use_experiment=experiment_to_use
self.params["prompt_text"], fail_silently=False, use_experiment=self._get_experiment_to_generate_response()
)

utc_now = timezone.now()
Expand All @@ -373,6 +372,25 @@ def _trigger(self):

self.save()

def _get_experiment_to_generate_response(self) -> Experiment:
"""
- If no child bot was specified to generate the response, use the default experiment version
- If a child bot was specified to generate the response, we must find the version of the child bot that is
linked to the default router.
"""
default_router_experiment = self.experiment.default_version
experiment_id = self.params.get("experiment_id")
if experiment_id and int(experiment_id) != self.experiment.id:
if default_router_experiment.is_a_version and default_router_experiment.child_links.count() > 0:
# Find the child of this version that has the specified experiment as its working version
return (
default_router_experiment.child_links.filter(child__working_version_id=experiment_id).first().child
)

return Experiment.objects.get(id=experiment_id)

return default_router_experiment

def _should_mark_complete(self):
return bool(not self.remaining_triggers or (self.end_date and self.end_date <= timezone.now()))

Expand Down
51 changes: 50 additions & 1 deletion apps/events/tests/test_scheduled_messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,9 @@

from apps.events.models import EventActionType, ScheduledMessage, TimePeriod
from apps.events.tasks import _get_messages_to_fire, poll_scheduled_messages
from apps.experiments.models import ExperimentRoute
from apps.utils.factories.events import EventActionFactory, ScheduledMessageFactory
from apps.utils.factories.experiment import ExperimentSessionFactory
from apps.utils.factories.experiment import ExperimentFactory, ExperimentSessionFactory
from apps.utils.time import timedelta_to_relative_delta


Expand Down Expand Up @@ -260,3 +261,51 @@ def test_schedule_update():
_assert_next_trigger_date(message3, message3_next_trigger_data)
assert message1.next_trigger_date < message1_prev_trigger_date
assert message2.next_trigger_date < message2_prev_trigger_date


@pytest.mark.django_db()
def test_schedule_trigger_for_versioned_routes():
router = ExperimentFactory()
child = ExperimentFactory(team=router.team)
session = ExperimentSessionFactory(experiment=router)

ExperimentRoute.objects.create(
team=router.team,
parent=router,
child=child,
keyword="keyword1",
)

event_action, params = _construct_event_action(frequency=1, time_period=TimePeriod.DAYS, experiment_id=None)
sm = ScheduledMessageFactory(
team=router.team, action=event_action, experiment=router, participant=session.participant
)
# No experiment specified, so the router should be used (no version yet, so router == default version)
assert sm._get_experiment_to_generate_response() == router

new_params = sm.action.params
new_params["experiment_id"] = child.id
sm.action.params = new_params
sm.action.save()

# No versions yet, so the working version of the child should be used
assert sm._get_experiment_to_generate_response() == child

default_router = router.create_new_version(make_default=True)
router.refresh_from_db()
del router.default_version
sm.refresh_from_db()
child_version = default_router.child_links.first().child
assert new_params["experiment_id"] == child_version.working_version_id
# The router is versioned and the deployed version is not the working version, so the child of the deployed version
# should be used
assert sm._get_experiment_to_generate_response() == child_version

new_params = sm.action.params
new_params["experiment_id"] = None
sm.action.params = new_params
sm.action.save()
sm.refresh_from_db()
# Clear the `params` cached property on ScheduledMessage
del sm.params
assert sm._get_experiment_to_generate_response() == router.default_version
2 changes: 1 addition & 1 deletion apps/experiments/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -1101,7 +1101,7 @@ def eligible_children(cls, team: Team, parent: Experiment | None = None):
else:
eligible_experiments = Experiment.objects.filter(team=team).exclude(id__in=parent_ids)

return eligible_experiments
return eligible_experiments.filter(working_version_id=None)

@transaction.atomic()
def create_new_version(self, new_parent: Experiment) -> "ExperimentRoute":
Expand Down
2 changes: 1 addition & 1 deletion apps/experiments/tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -591,7 +591,7 @@ def test_eligible_children(self):
assert len(queryset) == 2

queryset = ExperimentRoute.eligible_children(team=parent.team)
assert len(queryset) == 4
assert len(queryset) == 3

def test_compare_with_model_testcase_1(self):
"""
Expand Down

0 comments on commit 7b85ecc

Please sign in to comment.