Skip to content

Commit

Permalink
Fix CLI and headless after changes to eventstream (#5949)
Browse files Browse the repository at this point in the history
Co-authored-by: Engel Nyst <[email protected]>
  • Loading branch information
rbren and enyst authored Jan 1, 2025
1 parent 2ec2f25 commit f3885ca
Show file tree
Hide file tree
Showing 6 changed files with 29 additions and 14 deletions.
13 changes: 7 additions & 6 deletions openhands/core/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ def display_event(event: Event, config: AppConfig):
display_confirmation(event.confirmation_state)


async def main():
async def main(loop):
"""Runs the agent in CLI mode"""

parser = get_parser()
Expand All @@ -112,7 +112,7 @@ async def main():

logger.setLevel(logging.WARNING)
config = load_app_config(config_file=args.config_file)
sid = 'cli'
sid = str(uuid4())

agent_cls: Type[Agent] = Agent.get_cls(config.default_agent)
agent_config = config.get_agent_config(config.default_agent)
Expand Down Expand Up @@ -150,7 +150,6 @@ async def main():

async def prompt_for_next_task():
# Run input() in a thread pool to avoid blocking the event loop
loop = asyncio.get_event_loop()
next_message = await loop.run_in_executor(
None, lambda: input('How can I help? >> ')
)
Expand All @@ -165,13 +164,12 @@ async def prompt_for_next_task():
event_stream.add_event(action, EventSource.USER)

async def prompt_for_user_confirmation():
loop = asyncio.get_event_loop()
user_confirmation = await loop.run_in_executor(
None, lambda: input('Confirm action (possible security risk)? (y/n) >> ')
)
return user_confirmation.lower() == 'y'

async def on_event(event: Event):
async def on_event_async(event: Event):
display_event(event, config)
if isinstance(event, AgentStateChangedObservation):
if event.agent_state in [
Expand All @@ -193,6 +191,9 @@ async def on_event(event: Event):
ChangeAgentStateAction(AgentState.USER_REJECTED), EventSource.USER
)

def on_event(event: Event) -> None:
loop.create_task(on_event_async(event))

event_stream.subscribe(EventStreamSubscriber.MAIN, on_event, str(uuid4()))

await runtime.connect()
Expand All @@ -208,7 +209,7 @@ async def on_event(event: Event):
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
loop.run_until_complete(main())
loop.run_until_complete(main(loop))
except KeyboardInterrupt:
print('Received keyboard interrupt, shutting down...')
except ConnectionRefusedError as e:
Expand Down
2 changes: 1 addition & 1 deletion openhands/core/config/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -385,7 +385,7 @@ def get_parser() -> argparse.ArgumentParser:
parser.add_argument(
'-n',
'--name',
default='default',
default='',
type=str,
help='Name for the session',
)
Expand Down
2 changes: 1 addition & 1 deletion openhands/core/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ async def run_controller(
# init with the provided actions
event_stream.add_event(initial_user_action, EventSource.USER)

async def on_event(event: Event):
def on_event(event: Event):
if isinstance(event, AgentStateChangedObservation):
if event.agent_state == AgentState.AWAITING_USER_INPUT:
if exit_on_message:
Expand Down
22 changes: 18 additions & 4 deletions openhands/events/stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,10 +229,24 @@ async def _process_queue(self):
for callback_id in callbacks:
callback = callbacks[callback_id]
pool = self._thread_pools[key][callback_id]
pool.submit(callback, event)

def _callback(self, callback: Callable, event: Event):
asyncio.run(callback(event))
future = pool.submit(callback, event)
future.add_done_callback(self._make_error_handler(callback_id, key))

def _make_error_handler(self, callback_id: str, subscriber_id: str):
def _handle_callback_error(fut):
try:
# This will raise any exception that occurred during callback execution
fut.result()
except Exception as e:
logger.error(
f'Error in event callback {callback_id} for subscriber {subscriber_id}: {str(e)}',
exc_info=True,
stack_info=True,
)
# Re-raise in the main thread so the error is not swallowed
raise e

return _handle_callback_error

def filtered_events_by_source(self, source: EventSource):
for event in self.get_events():
Expand Down
2 changes: 1 addition & 1 deletion openhands/resolver/resolve_issue.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,7 @@ async def process_issue(
runtime = create_runtime(config)
await runtime.connect()

async def on_event(evt):
def on_event(evt):
logger.info(evt)

runtime.event_stream.subscribe(EventStreamSubscriber.MAIN, on_event, str(uuid4()))
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/test_arg_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def test_parser_default_values():
assert args.eval_num_workers == 4
assert args.eval_note is None
assert args.llm_config is None
assert args.name == 'default'
assert args.name == ''
assert not args.no_auto_continue


Expand Down

0 comments on commit f3885ca

Please sign in to comment.