diff --git a/agency_swarm/agency/agency.py b/agency_swarm/agency/agency.py index 1f4a628c..da531083 100644 --- a/agency_swarm/agency/agency.py +++ b/agency_swarm/agency/agency.py @@ -172,7 +172,7 @@ def get_completion_stream(self, if not inspect.isclass(event_handler): raise Exception("Event handler must not be an instance.") - return self.main_thread.get_completion_stream(message=message, + res = self.main_thread.get_completion_stream(message=message, message_files=message_files, event_handler=event_handler, attachments=attachments, @@ -181,6 +181,10 @@ def get_completion_stream(self, tool_choice=tool_choice ) + event_handler.on_all_streams_end() + + return res + def demo_gradio(self, height=450, dark_mode=True, **kwargs): """ Launches a Gradio-based demo interface for the agency chatbot. diff --git a/agency_swarm/threads/thread.py b/agency_swarm/threads/thread.py index 2f9b0c5e..94af6ea3 100644 --- a/agency_swarm/threads/thread.py +++ b/agency_swarm/threads/thread.py @@ -208,9 +208,6 @@ def get_completion(self, continue - if event_handler: - event_handler.on_all_streams_end() - return full_message def _create_run(self, recipient_agent, additional_instructions, event_handler, tool_choice): diff --git a/tests/test_agency.py b/tests/test_agency.py index 67a56270..6fe01e90 100644 --- a/tests/test_agency.py +++ b/tests/test_agency.py @@ -259,6 +259,7 @@ def test_5_agent_communication_stream(self): test_tool_used = False test_agent2_used = False + num_on_all_streams_end_calls = 0 class EventHandler(AgencyEventHandler): @override @@ -273,6 +274,12 @@ def on_tool_call_done(self, tool_call: ToolCall) -> None: nonlocal test_tool_used test_tool_used = True + @override + @classmethod + def on_all_streams_end(cls): + nonlocal num_on_all_streams_end_calls + num_on_all_streams_end_calls += 1 + message = self.__class__.agency.get_completion_stream( "Please tell TestAgent1 to tell TestAgent 2 to use test tool.", event_handler=EventHandler, @@ -284,6 +291,7 @@ def on_tool_call_done(self, tool_call: ToolCall) -> None: self.assertTrue(test_tool_used) self.assertTrue(test_agent2_used) + self.assertTrue(num_on_all_streams_end_calls == 1) self.assertTrue(self.__class__.TestTool.shared_state.get("test_tool_used"))