Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(product-assistant): a human in the loop for the taxonomy agent #26767

Open
wants to merge 66 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
66 commits
Select commit Hold shift + click to select a range
dcb1bca
feat: in-memory human in the loop
skoob13 Dec 4, 2024
7ef5260
feat: force the agent to ask for clarification
skoob13 Dec 4, 2024
d26ad87
chore: bump deps
skoob13 Dec 5, 2024
6c686f5
feat: checkpointer v1
skoob13 Dec 5, 2024
ac9491b
fix: tests
skoob13 Dec 5, 2024
80d39e5
test: more tests
skoob13 Dec 5, 2024
1b80601
feat: thread retrieval
skoob13 Dec 6, 2024
80222a7
fix: debug
skoob13 Dec 6, 2024
a38229a
fix: lock avoiding race conditions
skoob13 Dec 6, 2024
c2a3c62
fix: lock order
skoob13 Dec 6, 2024
88614f9
fix: refactor locking
skoob13 Dec 9, 2024
bad9d66
test: test interruptions and time travelling
skoob13 Dec 9, 2024
288fea1
refactor: move some methods on the model level
skoob13 Dec 9, 2024
207b832
test: make sure a checkpoint is copied
skoob13 Dec 9, 2024
cb45539
Merge branch 'master' of github.com:PostHog/posthog into feat/agent-h…
skoob13 Dec 9, 2024
8ebab16
chore: migrations
skoob13 Dec 9, 2024
6c40f48
fix: mypy issues
skoob13 Dec 9, 2024
c21d4e8
refactor: restoring state
skoob13 Dec 9, 2024
8f2c583
feat: split human in the loop prompt
skoob13 Dec 9, 2024
f6a60dd
test: writes
skoob13 Dec 9, 2024
221308a
refactor: message history reconstruction
skoob13 Dec 9, 2024
273c0c9
fix: message restoration for taxonomy agent
skoob13 Dec 10, 2024
8bf9c19
feat: track initiators
skoob13 Dec 10, 2024
1dafce5
Merge branch 'master' of github.com:PostHog/posthog into feat/agent-h…
skoob13 Dec 10, 2024
1b5695f
fix: rename filter conversation
skoob13 Dec 10, 2024
32b783d
fix: test failures
skoob13 Dec 10, 2024
a5d018d
fix: utils test
skoob13 Dec 10, 2024
ce766d4
fix: assistant tests
skoob13 Dec 10, 2024
6ee0a5c
fix: tests
skoob13 Dec 10, 2024
55a1244
feat: assign id to human messages
skoob13 Dec 10, 2024
4207e79
feat: use ids everywhere
skoob13 Dec 10, 2024
2734866
fix: taxonomy_agent conversation filtering
skoob13 Dec 10, 2024
fdc5d92
fix: more tests
skoob13 Dec 10, 2024
0b1839c
fix: taxonomy agent tests
skoob13 Dec 11, 2024
7b75257
fix: rest of tests
skoob13 Dec 11, 2024
a5b2d18
Update UI snapshots for `chromium` (1)
github-actions[bot] Dec 11, 2024
43e92c7
test: node interruptions
skoob13 Dec 11, 2024
10f1f62
Merge branch 'feat/agent-human-in-the-loop' of github.com:PostHog/pos…
skoob13 Dec 11, 2024
7bbed0d
Merge branch 'master' of github.com:PostHog/posthog into feat/agent-h…
skoob13 Dec 11, 2024
f2d6260
chore: rename models
skoob13 Dec 11, 2024
195004b
fix: use new message instead of a full conversation
skoob13 Dec 11, 2024
d9b42de
feat: conversation endpoint
skoob13 Dec 11, 2024
15f42df
fix: tests
skoob13 Dec 11, 2024
560f409
fix: tests
skoob13 Dec 12, 2024
792b77f
test: more tests
skoob13 Dec 12, 2024
11d4122
fix: mypy
skoob13 Dec 12, 2024
b39d894
test: fix tests after changing the way we send conversations
skoob13 Dec 12, 2024
4ececa5
chore: readability
skoob13 Dec 12, 2024
3cc78e2
feat: frontend
skoob13 Dec 12, 2024
ab2e8b1
Update UI snapshots for `chromium` (1)
github-actions[bot] Dec 12, 2024
bffa82a
Update query snapshots
github-actions[bot] Dec 12, 2024
73d9e7c
Update query snapshots
github-actions[bot] Dec 12, 2024
c88ad92
feat: pydantic schemas in state
skoob13 Dec 12, 2024
075bc84
fix: update tests
skoob13 Dec 12, 2024
f08e615
fix: deal with reserved langgraph's values
skoob13 Dec 12, 2024
6f4ee5e
fix: tests
skoob13 Dec 12, 2024
fc39eb1
fix: imports
skoob13 Dec 12, 2024
4f6108a
fix: mypy
skoob13 Dec 12, 2024
967482d
fix: github warning
skoob13 Dec 12, 2024
250e3e4
Merge branch 'master' of github.com:PostHog/posthog into feat/agent-h…
skoob13 Dec 12, 2024
b8dc2c8
fix: loop
skoob13 Dec 12, 2024
a37b1d2
fix: test
skoob13 Dec 12, 2024
6f63035
fix: storybook
skoob13 Dec 12, 2024
1e055d1
Merge branch 'feat/agent-human-in-the-loop' of github.com:PostHog/pos…
skoob13 Dec 12, 2024
582f175
Update UI snapshots for `chromium` (1)
github-actions[bot] Dec 12, 2024
1ef8667
Update query snapshots
github-actions[bot] Dec 12, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 69 additions & 0 deletions ee/api/conversation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
from typing import cast

from django.http import StreamingHttpResponse
from pydantic import ValidationError
from rest_framework import serializers
from rest_framework.renderers import BaseRenderer
from rest_framework.request import Request
from rest_framework.viewsets import GenericViewSet

from ee.hogai.assistant import Assistant
from ee.models.assistant import Conversation
from posthog.api.routing import TeamAndOrgViewSetMixin
from posthog.models.user import User
from posthog.rate_limit import AIBurstRateThrottle, AISustainedRateThrottle
from posthog.schema import HumanMessage


class MessageSerializer(serializers.Serializer):
content = serializers.CharField(required=True, max_length=1000)
conversation = serializers.UUIDField(required=False)

def validate(self, data):
try:
message = HumanMessage(content=data["content"])
data["message"] = message
except ValidationError:
raise serializers.ValidationError("Invalid message content.")
return data


class ServerSentEventRenderer(BaseRenderer):
media_type = "text/event-stream"
format = "txt"

def render(self, data, accepted_media_type=None, renderer_context=None):
return data


class ConversationViewSet(TeamAndOrgViewSetMixin, GenericViewSet):
scope_object = "INTERNAL"
serializer_class = MessageSerializer
renderer_classes = [ServerSentEventRenderer]
queryset = Conversation.objects.all()
lookup_url_kwarg = "conversation"

def safely_get_queryset(self, queryset):
# Only allow access to conversations created by the current user
return queryset.filter(user=self.request.user)

def get_throttles(self):
return [AIBurstRateThrottle(), AISustainedRateThrottle()]

def create(self, request: Request, *args, **kwargs):
serializer = self.get_serializer(data=request.data)
serializer.is_valid(raise_exception=True)
conversation_id = serializer.validated_data.get("conversation")
if conversation_id:
self.kwargs[self.lookup_url_kwarg] = conversation_id
conversation = self.get_object()
else:
conversation = self.get_queryset().create(user=request.user, team=self.team)
assistant = Assistant(
self.team,
conversation,
serializer.validated_data["message"],
user=cast(User, request.user),
is_new_conversation=not conversation_id,
)
return StreamingHttpResponse(assistant.stream(), content_type=ServerSentEventRenderer.media_type)
157 changes: 157 additions & 0 deletions ee/api/test/test_conversation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,157 @@
from unittest.mock import patch

from rest_framework import status

from ee.hogai.assistant import Assistant
from ee.models.assistant import Conversation
from posthog.models.team.team import Team
from posthog.models.user import User
from posthog.test.base import APIBaseTest


class TestConversation(APIBaseTest):
def setUp(self):
super().setUp()
self.other_team = Team.objects.create(organization=self.organization, name="other team")
self.other_user = User.objects.create_and_join(
organization=self.organization,
email="[email protected]",
password="password",
first_name="Other",
)

def _get_streaming_content(self, response):
return b"".join(response.streaming_content)

def test_create_conversation(self):
with patch.object(Assistant, "_stream", return_value=["test response"]) as stream_mock:
response = self.client.post(
"/api/projects/@current/conversations/",
{"content": "test query"},
)
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(self._get_streaming_content(response), b"test response")
self.assertEqual(Conversation.objects.count(), 1)
conversation: Conversation = Conversation.objects.first()
self.assertEqual(conversation.user, self.user)
self.assertEqual(conversation.team, self.team)
stream_mock.assert_called_once()

def test_add_message_to_existing_conversation(self):
with patch.object(Assistant, "_stream", return_value=["test response"]) as stream_mock:
conversation = Conversation.objects.create(user=self.user, team=self.team)
response = self.client.post(
"/api/projects/@current/conversations/",
{
"conversation": str(conversation.id),
"content": "test query",
},
)
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(self._get_streaming_content(response), b"test response")
self.assertEqual(Conversation.objects.count(), 1)
stream_mock.assert_called_once()

def test_cant_access_other_users_conversation(self):
conversation = Conversation.objects.create(user=self.other_user, team=self.team)

self.client.force_login(self.user)
response = self.client.post(
"/api/projects/@current/conversations/",
{"conversation": conversation.id, "content": "test query"},
)

self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)

def test_cant_access_other_teams_conversation(self):
conversation = Conversation.objects.create(user=self.user, team=self.other_team)
response = self.client.post(
"/api/projects/@current/conversations/",
{"conversation": conversation.id, "content": "test query"},
)
self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)

def test_invalid_message_format(self):
response = self.client.post("/api/projects/@current/conversations/")
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)

def test_rate_limit_burst(self):
# Create multiple requests to trigger burst rate limit
with patch.object(Assistant, "_stream", return_value=["test response"]):
for _ in range(11): # Assuming burst limit is less than this
response = self.client.post(
"/api/projects/@current/conversations/",
{"content": "test query"},
)
self.assertEqual(response.status_code, status.HTTP_429_TOO_MANY_REQUESTS)

def test_empty_content(self):
response = self.client.post(
"/api/projects/@current/conversations/",
{"content": ""},
)
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)

def test_content_too_long(self):
response = self.client.post(
"/api/projects/@current/conversations/",
{"content": "x" * 1001}, # Very long message
)
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)

def test_invalid_conversation_id(self):
response = self.client.post(
"/api/projects/@current/conversations/",
{
"conversation": "not-a-valid-uuid",
"content": "test query",
},
)
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)

def test_nonexistent_conversation(self):
response = self.client.post(
"/api/projects/@current/conversations/",
{
"conversation": "12345678-1234-5678-1234-567812345678",
"content": "test query",
},
)
self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)

def test_deleted_conversation(self):
# Create and then delete a conversation
conversation = Conversation.objects.create(user=self.user, team=self.team)
conversation_id = conversation.id
conversation.delete()

response = self.client.post(
"/api/projects/@current/conversations/",
{
"conversation": str(conversation_id),
"content": "test query",
},
)
self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)

def test_unauthenticated_request(self):
self.client.logout()
response = self.client.post(
"/api/projects/@current/conversations/",
{"content": "test query"},
)
self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED)

def test_streaming_error_handling(self):
def raise_error():
yield "some content"
raise Exception("Streaming error")

with patch.object(Assistant, "_stream", side_effect=raise_error):
response = self.client.post(
"/api/projects/@current/conversations/",
{"content": "test query"},
)
with self.assertRaises(Exception) as context:
b"".join(response.streaming_content)
self.assertTrue("Streaming error" in str(context.exception))
Loading
Loading