Skip to content

Commit

Permalink
tdd_repro finishes a given agent run and saves the conversation histo…
Browse files Browse the repository at this point in the history
…ry + patch to a directory. add flags to run.py to continue a conversation with manual input
  • Loading branch information
toshok committed Sep 26, 2024
1 parent 3aa0068 commit 11c2881
Show file tree
Hide file tree
Showing 7 changed files with 180 additions and 14 deletions.
3 changes: 2 additions & 1 deletion config/default_with_tools.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ instance_template: |-
{issue}
</ISSUE_DESCRIPTION>
{tdd_results}
# INSTRUCTIONS
* Solve this issue on your own. Your terminal session has started and you're in the repository's root directory. Edit files and run any checks or tests as needed.
* YOU CAN ONLY MAKE ONE TOOL CALL (RUN A COMMAND) AT A TIME. You should always wait for feedback after every command.
Expand All @@ -29,6 +28,8 @@ instance_template: |-
3. CWD is the directory of the repo you are supposed to edit. Only modify files inside this directory. Always provide absolute file paths, prefixed with $PWD.
4. When editing files, it is easy to accidentally specify a wrong line number or to write code with incorrect indentation. Always check the code after you issue an edit to make sure that it reflects what you wanted to accomplish. If it didn't, issue another command to fix it.
# HOW TO BEGIN
To start things off, run the 'tdd_repro' command to reproduce the issue.
(Open file: {open_file})
(Current directory: {working_dir})
Expand Down
23 changes: 20 additions & 3 deletions run_instance.sh
Original file line number Diff line number Diff line change
@@ -1,8 +1,24 @@
set -euo pipefail

case $# in
1)
MANUAL_INPUT_ARGS=""
;;

3)
MANUAL_INPUT_ARGS="--manual_input_conversation_path $2 --manual_input_continuation_label $3"
;;

*)
echo "Usage: $0 <instance_id> [<manual_input_conversation_path> <manual_input_continuation_label>]"
exit 1
;;
esac

set -x

# This runs the instance from the official SWE-agent demo video.
# See: https://www.youtube.com/watch?v=CeMtJ4XObAM
echo $MANUAL_INPUT_ARGS

python3 run.py \
--model_name "claude-sonnet-3.5" \
--data_path "princeton-nlp/SWE-bench_Verified" \
Expand All @@ -12,5 +28,6 @@ python3 run.py \
--instance_filter "$1" \
--skip_existing False \
--cache_task_images \
--tdd
--tdd \
$MANUAL_INPUT_ARGS

53 changes: 46 additions & 7 deletions sweagent/agent/agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@

from sweagent.agent.commands import Command, ParseCommand
from sweagent.agent.history_processors import HistoryProcessor
from sweagent.agent.manual_input import ManualInput
from sweagent.agent.model_result import AnthropicModelResult
from sweagent.agent.models import (
APIStats,
ContextWindowExceededError,
Expand Down Expand Up @@ -194,6 +196,9 @@ class AgentArguments(FlattenedAccess, FrozenSerializable):
# We put tdd on the agent args because it needs it in post_init.
tdd: bool = False

manual_input_conversation_path: Path | None = None
manual_input_continuation_label: str | None = None

def __post_init__(self):
if self.config is None and self.config_file is not None:
# If unassigned, we load the config from the file to store its contents with the overall arguments
Expand Down Expand Up @@ -284,6 +289,8 @@ def __init__(self, name: str, args: AgentArguments, env: SWEEnv):
self.last_container_id = None
self.hooks = []
self.logger = get_logger(f"Agent[{name}]")
self.manual_input = ManualInput(args.manual_input_conversation_path, args.manual_input_continuation_label)
self.initial_model_response = None

def add_hook(self, hook: AgentHook):
"""Add hook to agent"""
Expand Down Expand Up @@ -318,7 +325,7 @@ def _system_repro_prompt(self):

# Unflag all tdd entries from history, so it can be compressed.
def _unflag_tdd_history(self):
self._assert_tdd_history_entries()
# self._assert_tdd_history_entries()

for entry in self.history:
if "tdd" in entry:
Expand Down Expand Up @@ -349,6 +356,21 @@ def setup(self, env: SWEEnv, instance_args, init_model_stats=None) -> None:
self.model.setup(init_model_stats)
self.instance_args = instance_args

if self.manual_input.enabled():
history_and_patch = self.manual_input.load_conversation()
if history_and_patch is not None:
history, patch = history_and_patch
self.initial_model_response = history[-1]
assert self.initial_model_response["role"] == "assistant"
assert self.initial_model_response["action"] == "tdd_repro"
self.history = history[:-1]
self.made_initial_prompt = True
if patch is not None:
logger.info(f"Applying patch:\n>>{patch}\n<<")
env._apply_patch(patch)
return


# Compose system prompt.
system_msg = self.config.system_template.format(**self.system_args)
system_msg = f"{system_msg}\n\n{self._system_repro_prompt()}"
Expand Down Expand Up @@ -571,7 +593,16 @@ def forward(self, observation: str, available_actions: list[str], state: str) ->
action: action that the model proposes
output: raw model output (not output of the action)
"""
thought, action, output = self.forward_with_error_check(observation, state)
if self.initial_model_response is not None:
thought, action, content = self.initial_model_response["thought"], self.initial_model_response["action"], self.initial_model_response["content"]
if isinstance(content, str):
output = content
else:
output = AnthropicModelResult(blocks=content)
self.initial_model_response = None
else:
thought, action, output = self.forward_with_error_check(observation, state)

last_tool_name = get_last_valid_tool_use_name(output)
last_command = self.get_command(last_tool_name)
ran_tdd_action = last_command.tdd if last_command else False
Expand Down Expand Up @@ -620,7 +651,7 @@ def forward_model(self, observation: str, state: str) -> ModelQueryResult:
if self.config.strategy_template is not None:
templates.append(self.config.strategy_template)
# Get tdd_results, to be rendered into the initial_prompt template.
state_vars["tdd_results"] = self._make_initial_tdd_result()
# state_vars["tdd_results"] = self._make_initial_tdd_result()
elif observation is None or observation.strip() == "":
# Show no output template if observation content was empty
templates = [self.config.next_step_no_output_template]
Expand All @@ -646,7 +677,7 @@ def forward_model(self, observation: str, state: str) -> ModelQueryResult:
self._append_history(
{
"role": "user",
"content": make_user_reply_content(message, None, self.history, False),
"content": make_user_reply_content(message, None, self.history, False, self.manual_input.load_continuation_file()),
"agent": self.name,
"tdd": self.env.tdd and is_init,
}
Expand All @@ -665,7 +696,7 @@ def retry_after_format_fail(self, output: ModelQueryResult) -> ModelQueryResult:

temp_history = self.local_history + [
{"role": "assistant", "content": make_assistant_content(output), "agent": self.name},
{"role": "user", "content": make_user_reply_content(format_error_template, output, self.history, True), "agent": self.name},
{"role": "user", "content": make_user_reply_content(format_error_template, output, self.history, True, None), "agent": self.name},
]
return self.model.query(temp_history)

Expand All @@ -679,7 +710,7 @@ def retry_after_blocklist_fail(self, output: ModelQueryResult, action: str) -> M

temp_history = self.local_history + [
{"role": "assistant", "content": make_assistant_content(output), "agent": self.name},
{"role": "user", "content": make_user_reply_content(blocklist_error_message, output, self.history, True), "agent": self.name},
{"role": "user", "content": make_user_reply_content(blocklist_error_message, output, self.history, True, None), "agent": self.name},
]
return self.model.query(temp_history)

Expand Down Expand Up @@ -896,13 +927,15 @@ def run(
If return_type is "info_trajectory", returns a tuple of
the info dictionary and the trajectory (list of dictionaries).
"""
self.made_initial_prompt = False
done = False
# mypy checks
assert env.container_obj is not None
assert env.record is not None
assert self.config is not None

self.made_initial_prompt = False
self.manual_input.set_instance_id(env.record["instance_id"])

if env.container_obj.id != self.last_container_id:
self.logger.info(f"Initializing agent settings for container {env.container_obj.id}")
self.init_environment_vars(env)
Expand All @@ -922,13 +955,15 @@ def run(
for hook in self.hooks:
hook.on_step_start()
state = env.communicate(self.state_command) if self.state_command else None
stop_on_tdd_repro = self.manual_input.enabled() and self.initial_model_response is None
thought, action, output = self.forward(observation, env.get_available_actions(), state)
for hook in self.hooks:
hook.on_actions_generated(thought=thought, action=action, output=repr(output))
observations = list()
run_action = self._guard_multiline_input(action)
for sub_action in self.split_actions(run_action):
if sub_action["agent"] == self.name or sub_action["cmd_name"] == self.config.submit_command:
logger.warning(f"ACTION: {sub_action['action']}")
for hook in self.hooks:
hook.on_sub_action_started(sub_action=sub_action)
obs, _, done, info = env.step(sub_action["action"])
Expand All @@ -937,6 +972,10 @@ def run(
observations.append(obs)
if sub_action["cmd_name"] == self.config.submit_command:
done = True
if sub_action["action"] == "tdd_repro" and stop_on_tdd_repro:
done = True
patch = env.communicate("git add -A && git diff --cached")
self.manual_input.save_conversation(self.history, patch)
if done:
break
else:
Expand Down
94 changes: 94 additions & 0 deletions sweagent/agent/manual_input.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
from __future__ import annotations
import os
from typing import Tuple

from sweagent.utils.log import get_logger
from sweagent.agent.model_cache import json_serialize_file, json_deserialize_file, json_serialize_str, json_deserialize_str, hash_string

ManualTDDInputEnvVar = "MANUAL_TDD_INPUT_DIRECTORY"

logger = get_logger("manual_tdd_input")

class ManualInput:
base_dir: str | None
instance_id: str | None
conversation_path: str | None
continuation_label: str | None

def __init__(self, conversation_path: str | None, continuation_label: str | None):
self.conversation_path = conversation_path
self.continuation_label = continuation_label
self.base_dir = None
if ManualTDDInputEnvVar in os.environ:
logger.warning("⚠ ManualInput is enabled")
self.base_dir = os.environ[ManualTDDInputEnvVar]

def enabled(self) -> bool:
return self.base_dir is not None

def set_instance_id(self, instance_id: str) -> None:
self.instance_id = instance_id

def _get_conversation_dir(self) -> str:
if self.conversation_path is None:
return os.path.join(self.base_dir, self.instance_id)
return os.path.join(self.base_dir, self.instance_id, self.conversation_path)

def load_continuation_file(self) -> str | None:
if not self.enabled():
return None

try:
with open(os.path.join(self._get_conversation_dir(), f"{self.continuation_label}.md"), "r") as f:
return f.read()
except FileNotFoundError:
return None

def save_conversation(self, conversation: list[dict[str, str]], patch: str | None) -> None:
if not self.enabled():
return None

parent_dir = self._get_conversation_dir()

# if continuation_label is left off, we're storing the root conversation
if self.continuation_label is None:
self.continuation_label = "root"

content = json_serialize_str(conversation)
hash = hash_string(content)

new_subdir = os.path.join(parent_dir, f"{self.continuation_label}-{hash}")
os.makedirs(new_subdir, exist_ok=True)

with open(os.path.join(new_subdir, "conversation.json"), "w") as f:
f.write(content)

if patch is not None:
with open(os.path.join(new_subdir, "patch.diff"), "w") as f:
f.write(patch)

def load_conversation(self) -> Tuple[list[dict[str, str]], str] | None:
if not self.enabled():
return None

dir = self._get_conversation_dir()
if not os.path.exists(dir):
return None

conversation_file_path = os.path.join(dir, "conversation.json")
if not os.path.exists(conversation_file_path):
return None

with open(conversation_file_path, "r") as f:
conversation = json_deserialize_str(f.read())

patch_file_path = os.path.join(dir, "patch.diff")

# a missing patch isn't an error (will happen with the first tdd_repro call)
if os.path.exists(patch_file_path):
with open(patch_file_path, "r") as f:
patch = f.read()
else:
patch = None

return conversation, patch
2 changes: 1 addition & 1 deletion sweagent/agent/model_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ def __init__(self):

def _get_file(self, history: list[dict[str, str]]) -> str:
hash_input = json_serialize_str(history)
print(f"HASH_INPUT\n{hash_input}\nEND_OF_HASH_INPUT")
# print(f"HASH_INPUT\n{hash_input}\nEND_OF_HASH_INPUT")
hash = hash_string(hash_input)
return f"{self.directory}/model-query-{hash}.json"

Expand Down
7 changes: 5 additions & 2 deletions sweagent/agent/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def get_last_valid_tool_use_block(model_result: ModelQueryResult | None, history
tool_use = last_assistant_block
return tool_use

def make_user_reply_content(action_result: str, model_result: ModelQueryResult | None, history: list[dict[str, any]], is_error: bool):
def make_user_reply_content(action_result: str, model_result: ModelQueryResult | None, history: list[dict[str, any]], is_error: bool, tool_extra_content: str | None):
"""
Create a tool_result block from the action_result, model_result, and history.
action_result: The return value of the action from the environment.
Expand All @@ -65,10 +65,13 @@ def make_user_reply_content(action_result: str, model_result: ModelQueryResult |
tool_use = get_last_valid_tool_use_block(model_result, history)
if tool_use:
# This is a reply to a tool_use:
if tool_use.name == "tdd_repro" and tool_extra_content is not None:
action_result += f"\n\n{tool_extra_content}"

result = {
"type": "tool_result",
"tool_use_id": tool_use.id,
"content": action_result
"content": action_result,
}
if is_error:
result["is_error"] = True
Expand Down
12 changes: 12 additions & 0 deletions sweagent/environment/swe_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -466,6 +466,18 @@ def _apply_test_patch(self):
)
self.logger.debug(f"[TDD] Applied test patch - output:\n{res}")

def _apply_patch(self, patch: str) -> None:
"""
Apply patch to source in repo
"""
patch_path = "/root/model.patch"
self.copy_string_to_container_file(patch, patch_path)
self.communicate_with_handling(
f"cd /{self._repo_name} && cat {patch_path} && git apply -v {patch_path} && rm {patch_path}",
error_msg="Failed to apply patch",
)
self.logger.debug(f"[TDD] Applied previous changes - output:\n{res}")

def step(self, action: str) -> tuple[str | None, int, bool, dict]:
"""
Runs an action proposed by the agent in the environment and returns the corresponding output.
Expand Down

0 comments on commit 11c2881

Please sign in to comment.