Skip to content

Commit

Permalink
Fix duplicate state initialization (#6089)
Browse files Browse the repository at this point in the history
  • Loading branch information
enyst authored Jan 7, 2025
1 parent 9747c9e commit aabbbb6
Show file tree
Hide file tree
Showing 3 changed files with 204 additions and 18 deletions.
10 changes: 6 additions & 4 deletions openhands/controller/agent_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -993,10 +993,12 @@ def _is_stuck(self) -> bool:

def __repr__(self):
return (
f'AgentController(id={self.id}, agent={self.agent!r}, '
f'event_stream={self.event_stream!r}, '
f'state={self.state!r}, '
f'delegate={self.delegate!r}, _pending_action={self._pending_action!r})'
f'AgentController(id={getattr(self, "id", "<uninitialized>")}, '
f'agent={getattr(self, "agent", "<uninitialized>")!r}, '
f'event_stream={getattr(self, "event_stream", "<uninitialized>")!r}, '
f'state={getattr(self, "state", "<uninitialized>")!r}, '
f'delegate={getattr(self, "delegate", "<uninitialized>")!r}, '
f'_pending_action={getattr(self, "_pending_action", "<uninitialized>")!r})'
)

def _is_awaiting_observation(self):
Expand Down
26 changes: 12 additions & 14 deletions openhands/server/session/agent_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,27 +274,25 @@ def _create_controller(
confirmation_mode=confirmation_mode,
headless_mode=False,
status_callback=self._status_callback,
initial_state=self._maybe_restore_state(),
)

# Note: We now attempt to restore the state from session here,
# but if it fails, we fall back to None and still initialize the controller
# with a fresh state. That way, the controller will always load events from the event stream
# even if the state file was corrupt.
return controller

def _maybe_restore_state(self) -> State | None:
"""Helper method to handle state restore logic."""
restored_state = None

# Attempt to restore the state from session.
# Use a heuristic to figure out if we should have a state:
# if we have events in the stream.
try:
restored_state = State.restore_from_session(self.sid, self.file_store)
logger.debug(f'Restored state from session, sid: {self.sid}')
except Exception as e:
if self.event_stream.get_latest_event_id() > 0:
# if we have events, we should have a state
logger.warning(f'State could not be restored: {e}')

# Set the initial state through the controller.
controller.set_initial_state(restored_state, max_iterations, confirmation_mode)
if restored_state:
logger.debug(f'Restored agent state from session, sid: {self.sid}')
else:
logger.debug('New session state created.')

logger.debug('Agent controller initialized.')
return controller
else:
logger.debug('No events found, no state to restore')
return restored_state
186 changes: 186 additions & 0 deletions tests/unit/test_agent_session.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,186 @@
from unittest.mock import AsyncMock, MagicMock, patch

import pytest

from openhands.controller.agent import Agent
from openhands.controller.agent_controller import AgentController
from openhands.controller.state.state import State
from openhands.core.config import AppConfig, LLMConfig
from openhands.events import EventStream, EventStreamSubscriber
from openhands.llm import LLM
from openhands.llm.metrics import Metrics
from openhands.runtime.base import Runtime
from openhands.server.session.agent_session import AgentSession
from openhands.storage.memory import InMemoryFileStore


@pytest.fixture
def mock_agent():
"""Create a properly configured mock agent with all required nested attributes"""
# Create the base mocks
agent = MagicMock(spec=Agent)
llm = MagicMock(spec=LLM)
metrics = MagicMock(spec=Metrics)
llm_config = MagicMock(spec=LLMConfig)

# Configure the LLM config
llm_config.model = 'test-model'
llm_config.base_url = 'http://test'
llm_config.draft_editor = None
llm_config.max_message_chars = 1000

# Set up the chain of mocks
llm.metrics = metrics
llm.config = llm_config
agent.llm = llm
agent.name = 'test-agent'
agent.sandbox_plugins = []

return agent


@pytest.mark.asyncio
async def test_agent_session_start_with_no_state(mock_agent):
"""Test that AgentSession.start() works correctly when there's no state to restore"""

# Setup
file_store = InMemoryFileStore({})
session = AgentSession(sid='test-session', file_store=file_store)

# Create a mock runtime and set it up
mock_runtime = MagicMock(spec=Runtime)

# Mock the runtime creation to set up the runtime attribute
async def mock_create_runtime(*args, **kwargs):
session.runtime = mock_runtime

session._create_runtime = AsyncMock(side_effect=mock_create_runtime)

# Create a mock EventStream with no events
mock_event_stream = MagicMock(spec=EventStream)
mock_event_stream.get_events.return_value = []
mock_event_stream.subscribe = MagicMock()
mock_event_stream.get_latest_event_id.return_value = 0

# Inject the mock event stream into the session
session.event_stream = mock_event_stream

# Create a spy on set_initial_state
class SpyAgentController(AgentController):
set_initial_state_call_count = 0
test_initial_state = None

def set_initial_state(self, *args, state=None, **kwargs):
self.set_initial_state_call_count += 1
self.test_initial_state = state
super().set_initial_state(*args, state=state, **kwargs)

# Patch AgentController and State.restore_from_session to fail
with patch(
'openhands.server.session.agent_session.AgentController', SpyAgentController
), patch(
'openhands.server.session.agent_session.EventStream',
return_value=mock_event_stream,
), patch(
'openhands.controller.state.state.State.restore_from_session',
side_effect=Exception('No state found'),
):
await session.start(
runtime_name='test-runtime',
config=AppConfig(),
agent=mock_agent,
max_iterations=10,
)

# Verify EventStream.subscribe was called with correct parameters
mock_event_stream.subscribe.assert_called_with(
EventStreamSubscriber.AGENT_CONTROLLER,
session.controller.on_event,
session.controller.id,
)

# Verify set_initial_state was called once with None as state
assert session.controller.set_initial_state_call_count == 1
assert session.controller.test_initial_state is None
assert session.controller.state.max_iterations == 10
assert session.controller.agent.name == 'test-agent'
assert session.controller.state.start_id == 0
assert session.controller.state.end_id == -1
assert session.controller.state.truncation_id == -1


@pytest.mark.asyncio
async def test_agent_session_start_with_restored_state(mock_agent):
"""Test that AgentSession.start() works correctly when there's a state to restore"""

# Setup
file_store = InMemoryFileStore({})
session = AgentSession(sid='test-session', file_store=file_store)

# Create a mock runtime and set it up
mock_runtime = MagicMock(spec=Runtime)

# Mock the runtime creation to set up the runtime attribute
async def mock_create_runtime(*args, **kwargs):
session.runtime = mock_runtime

session._create_runtime = AsyncMock(side_effect=mock_create_runtime)

# Create a mock EventStream with some events
mock_event_stream = MagicMock(spec=EventStream)
mock_event_stream.get_events.return_value = []
mock_event_stream.subscribe = MagicMock()
mock_event_stream.get_latest_event_id.return_value = 5 # Indicate some events exist

# Inject the mock event stream into the session
session.event_stream = mock_event_stream

# Create a mock restored state
mock_restored_state = MagicMock(spec=State)
mock_restored_state.start_id = -1
mock_restored_state.end_id = -1
mock_restored_state.truncation_id = -1
mock_restored_state.max_iterations = 5

# Create a spy on set_initial_state by subclassing AgentController
class SpyAgentController(AgentController):
set_initial_state_call_count = 0
test_initial_state = None

def set_initial_state(self, *args, state=None, **kwargs):
self.set_initial_state_call_count += 1
self.test_initial_state = state
super().set_initial_state(*args, state=state, **kwargs)

# Patch AgentController and State.restore_from_session to succeed
with patch(
'openhands.server.session.agent_session.AgentController', SpyAgentController
), patch(
'openhands.server.session.agent_session.EventStream',
return_value=mock_event_stream,
), patch(
'openhands.controller.state.state.State.restore_from_session',
return_value=mock_restored_state,
):
await session.start(
runtime_name='test-runtime',
config=AppConfig(),
agent=mock_agent,
max_iterations=10,
)

# Verify set_initial_state was called once with the restored state
assert session.controller.set_initial_state_call_count == 1

# Verify EventStream.subscribe was called with correct parameters
mock_event_stream.subscribe.assert_called_with(
EventStreamSubscriber.AGENT_CONTROLLER,
session.controller.on_event,
session.controller.id,
)
assert session.controller.test_initial_state is mock_restored_state
assert session.controller.state is mock_restored_state
assert session.controller.state.max_iterations == 5
assert session.controller.state.start_id == 0
assert session.controller.state.end_id == -1
assert session.controller.state.truncation_id == -1

0 comments on commit aabbbb6

Please sign in to comment.