Commit

fix null bug (#2326)
Sea-Snell authored Jan 3, 2025
1 parent 00425a3 commit 4de10b1
Showing 5 changed files with 53 additions and 12 deletions.
32 changes: 25 additions & 7 deletions fact_edit_experiments/run_gsm8k.py
@@ -49,6 +49,7 @@ def trigger_rethink_memory(agent_state: "AgentState", message: Optional[str]) ->
 )
 '''
 OPENAI_CONFIG = LLMConfig.default_config("gpt-4o-mini")
+# NOTE: if we start using this function, we might need to change the model here

 client.set_default_llm_config(OPENAI_CONFIG)
 client.set_default_embedding_config(EmbeddingConfig.default_config(model_name="letta"))
@@ -94,13 +95,26 @@ def rethink_memory(agent_state: "AgentState", new_memory: str, target_block_labe
 def run_memory_edits(gsm8k_input_file: str,
                      output_file: str,
                      human_block_filename: str = "human_accurate",
-                     persona_block_filename: str = "persona_block_verbose",
+                     persona_block_filename: str = "persona_verbose",
                      system_block_filename: str = "convo_base",
                      offline_system_block_filename: str = "offline_base",
                      random_example: bool = False,
                      few_shot: bool = True,
                      limit: int = None,
-                     skip_first: int = None) -> None:
+                     skip_first: int = None,
+                     offline_memory_model: Optional[str] = None,
+                     conversation_model: Optional[str] = None) -> None:
+
+    if offline_memory_model is None:
+        offline_openai_config = OPENAI_CONFIG
+    else:
+        offline_openai_config = LLMConfig.default_config(offline_memory_model)
+
+    if conversation_model is None:
+        conversation_openai_config = OPENAI_CONFIG
+    else:
+        conversation_openai_config = LLMConfig.default_config(conversation_model)
+
     if few_shot:
         with open("gsm8k_experiments/gsm8k-cot.yaml", "r") as f:
             test_yaml = f.read()
@@ -170,7 +184,7 @@ def run_memory_edits(gsm8k_input_file: str,
         name="conversation_agent",
         agent_type=AgentType.memgpt_agent,
         system=get_system_text(system_block_filename),
-        llm_config=OPENAI_CONFIG,
+        llm_config=conversation_openai_config,
         embedding_config=EmbeddingConfig.default_config("text-embedding-ada-002"),
         # tools=["send_message", trigger_rethink_memory_tool.name],
         tools=["send_message"],
@@ -185,7 +199,7 @@ def run_memory_edits(gsm8k_input_file: str,
         agent_type=AgentType.offline_memory_agent,
         system=get_system_text(offline_system_block_filename),
         memory=offline_memory,
-        llm_config=OPENAI_CONFIG,
+        llm_config=offline_openai_config,
         embedding_config=EmbeddingConfig.default_config("text-embedding-ada-002"),
         tools=["rethink_memory", "finish_rethinking_memory"],
         # tool_ids=[rethink_memory_tool.id, finish_rethinking_memory_tool.id],
@@ -219,7 +233,7 @@ def run_memory_edits(gsm8k_input_file: str,
                 offline_message.updated_at = offline_message.updated_at.isoformat()
     '''

-        import ipdb; ipdb.set_trace()
+        # import ipdb; ipdb.set_trace()
         writer.write(
             {
                 "question": example["question"],
@@ -252,13 +266,15 @@ def run_memory_edits(gsm8k_input_file: str,
     parser.add_argument("--input_file", type=str, default="./GSM8K_p2.jsonl", required=False)
     parser.add_argument("--output_file", default="./predictions-GSM8k_p2.jsonl", required=False)
     parser.add_argument("--human_block_filename", default="human_accurate", required=False)
-    parser.add_argument("--persona_block_filename", default="persona_block_verbose", required=False)
+    parser.add_argument("--persona_block_filename", default="persona_verbose", required=False)
     parser.add_argument("--system_block_filename", default="convo_base", required=False)
     parser.add_argument("--offline_system_block_filename", default="offline_base", required=False)
     parser.add_argument("--random_example", action="store_true")  # debug by using a random example
     parser.add_argument("--few_shot", default=8, required=False, type=int)
     parser.add_argument("--limit", default=None, required=False, type=int)
     parser.add_argument("--skip_first", default=0, required=False, type=int)
+    parser.add_argument("--offline_memory_model", default="gpt-4o-mini", required=False)
+    parser.add_argument("--conversation_model", default="gpt-4o-mini", required=False)
     args = parser.parse_args()

     run_memory_edits(args.input_file,
@@ -270,4 +286,6 @@ def run_memory_edits(gsm8k_input_file: str,
                      args.random_example,
                      args.few_shot,
                      args.limit,
-                     args.skip_first)
+                     args.skip_first,
+                     args.offline_memory_model,
+                     args.conversation_model)
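
Taken together, the two new flags let the conversation agent and the offline memory agent run on different models. A hypothetical invocation (the flag names come from the diff above; the particular model split is only an example):

    python fact_edit_experiments/run_gsm8k.py \
        --input_file ./GSM8K_p2.jsonl \
        --conversation_model gpt-4o-mini \
        --offline_memory_model gpt-4o

Note that argparse defaults both new flags to "gpt-4o-mini", so the None fallbacks at the top of run_memory_edits only apply when the function is called directly rather than through the CLI.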
5 changes: 3 additions & 2 deletions letta/agent.py
@@ -444,7 +444,8 @@ def execute_tool_and_persist_state(self, function_name, function_to_call, functi
             )
             function_response, updated_agent_state = sandbox_run_result.func_return, sandbox_run_result.agent_state
             assert orig_memory_str == self.agent_state.memory.compile(), "Memory should not be modified in a sandbox tool"
-            self.update_memory_if_change(updated_agent_state.memory)
+            if updated_agent_state is not None:
+                self.update_memory_if_change(updated_agent_state.memory)
         except Exception as e:
             # Need to catch error here, or else truncation won't happen
             # TODO: modify to function execution error
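
As far as the diff shows, this guard is the null bug of the commit title: a sandboxed tool that never modifies agent state comes back with agent_state set to None, and the old unconditional updated_agent_state.memory access raised AttributeError. A standalone sketch of the failure mode and the fix (simplified stand-ins, not letta's actual classes):

    from dataclasses import dataclass
    from typing import Any, Optional

    @dataclass
    class SandboxRunResult:
        func_return: Any
        agent_state: Optional[Any]  # None when the tool did not touch state

    def persist_tool_state(agent: Any, result: SandboxRunResult) -> None:
        # Before the fix: result.agent_state.memory raised AttributeError on None.
        # After the fix: memory is only propagated when the sandbox returned state.
        if result.agent_state is not None:
            agent.update_memory_if_change(result.agent_state.memory)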
@@ -584,7 +585,7 @@ def append_to_messages(self, added_messages: List[dict]):
     def _get_ai_reply(
         self,
         message_sequence: List[Message],
-        function_call: str = "auto",
+        function_call: Optional[str] = None,
         first_message: bool = False,
         stream: bool = False,  # TODO move to config?
         empty_response_retry_limit: int = 3,
9 changes: 7 additions & 2 deletions letta/llm_api/helpers.py
@@ -250,6 +250,7 @@ def unpack_all_inner_thoughts_from_kwargs(

 def unpack_inner_thoughts_from_kwargs(choice: Choice, inner_thoughts_key: str) -> Choice:
     message = choice.message
+    rewritten_choice = choice  # inner thoughts unpacked out of the function
     if message.role == "assistant" and message.tool_calls and len(message.tool_calls) >= 1:
         if len(message.tool_calls) > 1:
             warnings.warn(f"Unpacking inner thoughts from more than one tool call ({len(message.tool_calls)}) is not supported")
@@ -271,14 +272,18 @@ def unpack_inner_thoughts_from_kwargs(choice: Choice, inner_thoughts_key: str) -> Choice:
                 warnings.warn(f"Overwriting existing inner monologue ({new_choice.message.content}) with kwarg ({inner_thoughts})")
                 new_choice.message.content = inner_thoughts

-                return new_choice
+                # update the choice object
+                rewritten_choice = new_choice
             else:
                 warnings.warn(f"Did not find inner thoughts in tool call: {str(tool_call)}")
-                return choice

         except json.JSONDecodeError as e:
             warnings.warn(f"Failed to strip inner thoughts from kwargs: {e}")
             raise e
+    else:
+        warnings.warn(f"Did not find tool call in message: {str(message)}")
+
+    return rewritten_choice


 def is_context_overflow_error(exception: Union[requests.exceptions.RequestException, Exception]) -> bool:
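The net change is a single-exit rewrite: rewritten_choice starts as the incoming choice, is replaced only when inner thoughts are successfully unpacked, and is returned on every path, so the no-tool-call case (which previously fell off the end of the function and implicitly returned None) now warns and returns the original choice. A minimal sketch of the control-flow shape, with the JSON unpacking stubbed out (standalone, not the real helper):

    import warnings
    from typing import Any, Callable, Optional

    def unpack_single_exit(choice: Any, try_unpack: Callable[[Any], Optional[Any]]) -> Any:
        # Default to the input; overwrite only on a successful unpack.
        rewritten_choice = choice
        if getattr(choice.message, "tool_calls", None):
            new_choice = try_unpack(choice)  # None when no inner thoughts found
            if new_choice is not None:
                rewritten_choice = new_choice
            else:
                warnings.warn("Did not find inner thoughts in tool call")
        else:
            warnings.warn("Did not find tool call in message")
        return rewritten_choice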
11 changes: 10 additions & 1 deletion letta/llm_api/llm_api_tools.py
@@ -110,7 +110,7 @@ def create(
     user_id: Optional[str] = None,  # optional UUID to associate request with
     functions: Optional[list] = None,
     functions_python: Optional[dict] = None,
-    function_call: str = "auto",
+    function_call: Optional[str] = None,
     # hint
     first_message: bool = False,
     # use tool naming?
@@ -147,10 +147,19 @@

     # openai
     if llm_config.model_endpoint_type == "openai":
+
         if model_settings.openai_api_key is None and llm_config.model_endpoint == "https://api.openai.com/v1":
             # this is only a problem if we are *not* using an openai proxy
             raise LettaConfigurationError(message="OpenAI key is missing from letta config file", missing_fields=["openai_api_key"])

+        if function_call is None and functions is not None and len(functions) > 0:
+            # force function calling for reliability, see https://platform.openai.com/docs/api-reference/chat/create#chat-create-tool_choice
+            # TODO(matt) move into LLMConfig
+            if llm_config.model_endpoint == "https://inference.memgpt.ai":
+                function_call = "auto"  # TODO change to "required" once proxy supports it
+            else:
+                function_call = "required"
+
         data = build_openai_chat_completions_request(llm_config, messages, user_id, functions, function_call, use_tool_naming, max_tokens)
         if stream:  # Client requested token streaming
             data.stream = True
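Combined with the _get_ai_reply default flipping from "auto" to None in letta/agent.py, the effect is that an unset tool choice is now resolved inside create(): callers that pass functions get "required" (forcing the model to call a tool), except against the inference.memgpt.ai proxy, which still gets "auto" until it supports "required". A standalone sketch of the defaulting rule (the endpoint constant is copied from the diff; this is not letta's actual helper):

    from typing import Optional

    def resolve_function_call(function_call: Optional[str],
                              functions: Optional[list],
                              model_endpoint: str) -> Optional[str]:
        # An explicit caller choice always wins; only fill in a default.
        if function_call is not None or not functions:
            return function_call
        # The memgpt.ai proxy does not yet support "required" (see the TODO above).
        if model_endpoint == "https://inference.memgpt.ai":
            return "auto"
        return "required"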
8 changes: 8 additions & 0 deletions letta/schemas/llm_config.py
@@ -90,6 +90,14 @@ def default_config(cls, model_name: str):
                 model_wrapper=None,
                 context_window=128000,
             )
+        elif model_name == "gpt-4o":
+            return cls(
+                model="gpt-4o",
+                model_endpoint_type="openai",
+                model_endpoint="https://api.openai.com/v1",
+                model_wrapper=None,
+                context_window=128000,
+            )
         elif model_name == "letta":
             return cls(
                 model="memgpt-openai",
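With the new branch, "gpt-4o" can be requested by name just like "gpt-4o-mini". A quick sanity check (the import path is inferred from the file path in this diff):

    from letta.schemas.llm_config import LLMConfig

    config = LLMConfig.default_config("gpt-4o")
    assert config.model_endpoint_type == "openai"
    assert config.model_endpoint == "https://api.openai.com/v1"
    assert config.context_window == 128000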
