Commit
first try
AlexCuadron committed Nov 11, 2024
1 parent dbd7ad4 commit a9e346a
Showing 4 changed files with 370 additions and 428 deletions.
111 changes: 96 additions & 15 deletions openhands/agenthub/codeact_agent/codeact_agent.py
@@ -1,4 +1,3 @@
import json
import os
from collections import deque
from itertools import islice
@@ -13,6 +12,7 @@
from openhands.core.config.llm_config import LLMConfig
from openhands.core.logger import openhands_logger as logger
from openhands.core.message import ImageContent, Message, TextContent
from openhands.core.utils import json
from openhands.events.action import (
Action,
AgentDelegateAction,
@@ -73,6 +73,8 @@ class CodeActAgent(Agent):
JupyterRequirement(),
]
obs_prefix = 'OBSERVATION:\n'
when_to_stop = 6  # checkpoint interval: finish every N local iterations so a supervisor can review progress
number_of_events = -1  # number of history events already covered by a restored trajectory; -1 until first computed
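
The two new class attributes drive a checkpoint protocol: when_to_stop is the interval, in local iterations, at which the agent finishes so a supervisor can review, and number_of_events later records how many history events a restored trajectory already covered. A minimal sketch of how the checkpoint gate added in step() below behaves with the class default; the print is illustrative only:

# Illustrative sketch of the checkpoint gate in step(), assuming the
# class default when_to_stop = 6.
when_to_stop = 6
for local_iteration in range(1, 14):
    if when_to_stop > 0 and local_iteration % when_to_stop == 0:
        print(f'iteration {local_iteration}: finish and report to the supervisor')
# Fires at iterations 6 and 12.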

def __init__(
self,
@@ -85,6 +87,7 @@ def __init__(
- llm (LLM): The llm to be used by this agent
"""

llm_config = LLMConfig(
model='litellm_proxy/claude-3-5-sonnet-20241022',
api_key='REDACTED',
@@ -93,10 +96,9 @@
)
llm = LLM(llm_config)
# TODO: Remove this once we have a real AgentConfig
config = AgentConfig(llm_config='o1-mini')
config = AgentConfig()
super().__init__(llm, config)
self.reset()

self.micro_agent = (
MicroAgent(
os.path.join(
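
The LLM settings above are hard-coded for the proxy, with the key redacted, and the TODO acknowledges the shortcut. A minimal sketch of sourcing the same fields from the environment instead; the LLM_MODEL and LLM_API_KEY variable names are assumptions, and only the parameters visible in this diff are passed:

import os

from openhands.core.config.llm_config import LLMConfig
from openhands.llm.llm import LLM

# Hypothetical: pull the proxy settings from the environment rather than
# committing them in source.
llm_config = LLMConfig(
    model=os.environ.get('LLM_MODEL', 'litellm_proxy/claude-3-5-sonnet-20241022'),
    api_key=os.environ['LLM_API_KEY'],  # required; never hard-code keys
)
llm = LLM(llm_config)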
@@ -343,14 +345,33 @@ def step(self, state: State) -> Action:
- MessageAction(content) - Message action to run (e.g. ask for clarification)
- AgentFinishAction() - end the interaction
"""

# If this agent has a supervisor, take the checkpoint interval from the supervisor's inputs
if self.when_to_stop < 0 and state.inputs.get('when_to_stop'):
    self.when_to_stop = state.inputs['when_to_stop']

# Continue with pending actions if any
if self.pending_actions:
    return self.pending_actions.popleft()

# If the user asked to exit, finish and hand the serialized conversation back
last_user_message = state.get_last_user_message()
if last_user_message and last_user_message.strip() == '/exit':
    return AgentFinishAction()
    messages = self._get_messages(state)
    serialized_messages = [msg.model_dump() for msg in messages]
    return AgentFinishAction(
        outputs={'fixed': True, 'trayectory': serialized_messages}
    )

# At the checkpoint interval, finish so the supervisor can evaluate the approach so far
if self.when_to_stop > 0 and state.local_iteration % self.when_to_stop == 0:
    messages = self._get_messages(state)
    serialized_messages = [
        msg.model_dump() for msg in messages
    ]  # Serialize each Message object
    return AgentFinishAction(
        outputs={'trayectory': serialized_messages, 'fixed': False}
    )

# prepare what we want to send to the LLM
messages = self._get_messages(state)
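
Taken together, the /exit and checkpoint branches turn step() into a resumable worker: it periodically returns an AgentFinishAction whose outputs carry the serialized conversation. A minimal sketch of the loop a supervising agent might run on top of this; the supervisor is not part of this diff, so supervise() and delegate() and the guidance string are assumptions, while the 'fixed', 'trayectory', 'when_to_stop', and 'next_step' keys match the code above:

# Hypothetical supervisor loop -- not part of this commit.
def supervise(delegate, task: str, when_to_stop: int = 6, max_rounds: int = 5):
    inputs = {'task': task, 'when_to_stop': when_to_stop}
    outputs: dict = {}
    for _ in range(max_rounds):
        # delegate() runs CodeActAgent until it returns an AgentFinishAction
        # and yields that action's outputs dict.
        outputs = delegate(inputs)
        if outputs.get('fixed'):
            break  # the worker believes the task is solved
        inputs = {
            'trayectory': outputs['trayectory'],  # key spelling as in the diff
            'when_to_stop': when_to_stop,
            'next_step': 'Re-evaluate the approach and continue.',  # illustrative guidance
        }
    return outputs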
@@ -409,17 +430,60 @@ def _get_messages(self, state: State) -> list[Message]:
- Messages from the same role are combined to prevent consecutive same-role messages
- For Anthropic models, specific messages are cached according to their documentation
"""
messages: list[Message] = [
    Message(
        role='system',
        content=[
            TextContent(
                text=self.system_prompt,
                cache_prompt=self.llm.is_caching_prompt_active(),  # Cache system prompt
            )
        ],
    )
]
messages: list[Message] = []
trayectory = state.inputs.get('trayectory', '')
# If there is no trajectory, it's the first time we are seeing the task
if not trayectory:
    messages.append(
        Message(
            role='system',
            content=[
                TextContent(
                    text=self.system_prompt,
                    cache_prompt=self.llm.is_caching_prompt_active(),  # Cache system prompt
                )
            ],
        )
    )
if state.inputs.get('task', '') != '':
# During AgentDelegation the history is empty, so we add the task as the user message.
messages.append(
Message(
role='user',
content=[TextContent(text=state.inputs['task'])],
)
)

if state.inputs.get('augmented_task', ''):
messages.append(
Message(
role='user',
content=[TextContent(text=state.inputs['augmented_task'])],
)
)
else:
    # If there is a previous trajectory, restore it.
    # NOTE: every restored message is re-ingested with role='user', so
    # assistant and tool turns come back as user-visible context.
    deserialized_trajectory = [
        Message(
            role='user',
            content=[
                TextContent(text=content_text)
                for content_text in [
                    msg_dict['content'][0]['text']
                    if isinstance(msg_dict['content'], list)
                    else msg_dict['content']
                ]
                if content_text  # Skip empty content
            ],
            tool_call_id=msg_dict.get('tool_call_id'),
            name=msg_dict.get('name'),
        )
        for msg_dict in trayectory
        if msg_dict.get('content')  # Skip messages with no content
    ]
    messages.extend(deserialized_trajectory)

if self.initial_user_message:
messages.append(
Message(
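
The restore branch above is the inverse of the model_dump() serialization done in step(). A minimal round-trip sketch under the same assumptions the diff makes (content is a list whose first element has a 'text' field); note that, as written, every message comes back with role='user':

from openhands.core.message import Message, TextContent

# Serialize as step() does, then restore as the branch above does.
original = [Message(role='assistant', content=[TextContent(text='Ran the tests; two failures remain.')])]
serialized = [msg.model_dump() for msg in original]  # plain, JSON-safe dicts

restored = [
    Message(role='user', content=[TextContent(text=d['content'][0]['text'])])
    for d in serialized
    if d.get('content')
]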
@@ -431,7 +495,9 @@ def _get_messages(self, state: State) -> list[Message]:
pending_tool_call_action_messages: dict[str, Message] = {}
tool_call_id_to_message: dict[str, Message] = {}
events = list(state.history)
for event in events:
# Record how many events the history held on the first call; this marks the
# boundary where the supervisor's 'next_step' guidance should be injected.
if self.number_of_events < 0:
    self.number_of_events = len(events)
for i, event in enumerate(events):
# create a regular message from an event
if isinstance(event, Action):
messages_to_add = self.get_action_message(
@@ -446,6 +512,14 @@
else:
raise ValueError(f'Unknown event type: {type(event)}')

# At the boundary between restored and new events, replace the event's own
# message with the supervisor's 'next_step' guidance.
if i == self.number_of_events and state.inputs.get('next_step', ''):
    messages_to_add = [
        Message(
            role='user',
            content=[TextContent(text=state.inputs['next_step'])],
        )
    ]

# Check pending tool call action messages and see if they are complete
_response_ids_to_remove = []
for (
@@ -488,6 +562,13 @@ def _get_messages(self, state: State) -> list[Message]:
else:
messages.append(message)

# If no new events have been appended since the checkpoint, add the
# 'next_step' guidance at the end instead.
if self.number_of_events == len(events) and state.inputs.get('next_step', ''):
    messages.append(
        Message(
            role='user', content=[TextContent(text=state.inputs['next_step'])]
        )
    )

if self.llm.is_caching_prompt_active():
# NOTE: this is only needed for anthropic
# following logic here:
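
To summarize the resume path on the agent side: a caller restarts the delegate with the saved trajectory plus optional guidance, and _get_messages() splices the 'next_step' text in at the boundary recorded by number_of_events. A hypothetical resume payload; the guidance string is illustrative:

# Hypothetical inputs for resuming a checkpointed run.
state.inputs = {
    'trayectory': serialized_messages,  # saved via model_dump() at the last checkpoint
    'when_to_stop': 6,                  # keep checkpointing every 6 iterations
    'next_step': 'Your last patch did not compile; revisit the edit.',  # illustrative
}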
11 changes: 9 additions & 2 deletions openhands/agenthub/codeact_agent/function_calling.py
@@ -13,6 +13,7 @@
)

from openhands.core.logger import openhands_logger as logger
from openhands.core.message import Message
from openhands.events.action import (
Action,
AgentDelegateAction,
@@ -448,7 +449,11 @@ def combine_thought(action: Action, thought: str) -> Action:
return action


def response_to_actions(response: ModelResponse) -> list[Action]:
def response_to_actions(
response: ModelResponse, messages: list[Message] | None = None
) -> list[Action]:
if messages is None:
messages = []
actions: list[Action] = []
assert len(response.choices) == 1, 'Only one choice is supported for now'
assistant_msg = response.choices[0].message
@@ -481,7 +486,9 @@ def response_to_actions(response: ModelResponse) -> list[Action]:
inputs=arguments,
)
elif tool_call.function.name == 'finish':
action = AgentFinishAction()
action = AgentFinishAction(
outputs={'fixed': True, 'trayectory': messages}
)
elif tool_call.function.name == 'edit_file':
action = FileEditAction(**arguments)
elif tool_call.function.name == 'str_replace_editor':
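
With the new messages parameter, a 'finish' tool call can now carry the whole conversation back to the delegating agent. A minimal usage sketch; response is a litellm ModelResponse and messages is the list built by _get_messages(), both as used elsewhere in this diff:

# Sketch: surface the trajectory attached by a 'finish' tool call.
actions = response_to_actions(response, messages=messages)
for action in actions:
    if isinstance(action, AgentFinishAction):
        trajectory = action.outputs.get('trayectory', [])  # the list[Message] passed in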
