From 4de10b1f176fb634cd49fb770769ef6e68669bf7 Mon Sep 17 00:00:00 2001
From: Sea-Snell
Date: Fri, 3 Jan 2025 13:12:30 -0800
Subject: [PATCH] fix null bug (#2326)

---
 fact_edit_experiments/run_gsm8k.py | 32 +++++++++++++++++++++++-------
 letta/agent.py                     |  5 +++--
 letta/llm_api/helpers.py           |  9 +++++++--
 letta/llm_api/llm_api_tools.py     | 11 +++++++++-
 letta/schemas/llm_config.py        |  8 ++++++++
 5 files changed, 53 insertions(+), 12 deletions(-)

diff --git a/fact_edit_experiments/run_gsm8k.py b/fact_edit_experiments/run_gsm8k.py
index b47e4f4e0a..a6a1de5d46 100644
--- a/fact_edit_experiments/run_gsm8k.py
+++ b/fact_edit_experiments/run_gsm8k.py
@@ -49,6 +49,7 @@ def trigger_rethink_memory(agent_state: "AgentState", message: Optional[str]) ->
     )
     '''
     OPENAI_CONFIG = LLMConfig.default_config("gpt-4o-mini")
+    # NOTE: if we start using this function, we might need to change the model here
     client.set_default_llm_config(OPENAI_CONFIG)
     client.set_default_embedding_config(EmbeddingConfig.default_config(model_name="letta"))
 
@@ -94,13 +95,26 @@ def rethink_memory(agent_state: "AgentState", new_memory: str, target_block_labe
 def run_memory_edits(gsm8k_input_file: str,
                      output_file: str,
                      human_block_filename: str = "human_accurate",
-                     persona_block_filename: str = "persona_block_verbose",
+                     persona_block_filename: str = "persona_verbose",
                      system_block_filename: str = "convo_base",
                      offline_system_block_filename: str = "offline_base",
                      random_example: bool = False,
                      few_shot: bool = True,
                      limit: int = None,
-                     skip_first: int = None) -> None:
+                     skip_first: int = None,
+                     offline_memory_model: Optional[str] = None,
+                     conversation_model: Optional[str] = None) -> None:
+
+    if offline_memory_model is None:
+        offline_openai_config = OPENAI_CONFIG
+    else:
+        offline_openai_config = LLMConfig.default_config(offline_memory_model)
+
+    if conversation_model is None:
+        conversation_openai_config = OPENAI_CONFIG
+    else:
+        conversation_openai_config = LLMConfig.default_config(conversation_model)
+
     if few_shot:
         with open("gsm8k_experiments/gsm8k-cot.yaml", "r") as f:
             test_yaml = f.read()
@@ -170,7 +184,7 @@ def run_memory_edits(gsm8k_input_file: str,
         name="conversation_agent",
         agent_type=AgentType.memgpt_agent,
         system=get_system_text(system_block_filename),
-        llm_config=OPENAI_CONFIG,
+        llm_config=conversation_openai_config,
         embedding_config=EmbeddingConfig.default_config("text-embedding-ada-002"),
         # tools=["send_message", trigger_rethink_memory_tool.name],
         tools=["send_message"],
@@ -185,7 +199,7 @@ def run_memory_edits(gsm8k_input_file: str,
         agent_type=AgentType.offline_memory_agent,
         system=get_system_text(offline_system_block_filename),
         memory=offline_memory,
-        llm_config=OPENAI_CONFIG,
+        llm_config=offline_openai_config,
         embedding_config=EmbeddingConfig.default_config("text-embedding-ada-002"),
         tools = ["rethink_memory", "finish_rethinking_memory"],
         # tool_ids=[rethink_memory_tool.id, finish_rethinking_memory_tool.id],
@@ -219,7 +233,7 @@ def run_memory_edits(gsm8k_input_file: str,
 
                 offline_message.updated_at = offline_message.updated_at.isoformat()
             '''
-            import ipdb; ipdb.set_trace()
+            # import ipdb; ipdb.set_trace()
             writer.write(
                 {
                     "question": example["question"],
@@ -252,13 +266,15 @@ def run_memory_edits(gsm8k_input_file: str,
     parser.add_argument("--input_file", type=str, default="./GSM8K_p2.jsonl", required=False)
     parser.add_argument("--output_file", default="./predictions-GSM8k_p2.jsonl", required=False)
     parser.add_argument("--human_block_filename", default="human_accurate", required=False)
-    parser.add_argument("--persona_block_filename", default="persona_block_verbose", required=False)
+    parser.add_argument("--persona_block_filename", default="persona_verbose", required=False)
     parser.add_argument("--system_block_filename", default="convo_base", required=False)
     parser.add_argument("--offline_system_block_filename", default="offline_base", required=False)
     parser.add_argument("--random_example", action="store_true")  # debug by using a random example
     parser.add_argument("--few_shot", default=8, required=False, type=int)
     parser.add_argument("--limit", default=None, required=False, type=int)
     parser.add_argument("--skip_first", default=0, required=False, type=int)
+    parser.add_argument("--offline_memory_model", default="gpt-4o-mini", required=False)
+    parser.add_argument("--conversation_model", default="gpt-4o-mini", required=False)
 
     args = parser.parse_args()
     run_memory_edits(args.input_file,
@@ -270,4 +286,6 @@ def run_memory_edits(gsm8k_input_file: str,
                      args.random_example,
                      args.few_shot,
                      args.limit,
-                     args.skip_first)
+                     args.skip_first,
+                     args.offline_memory_model,
+                     args.conversation_model)
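The new offline_memory_model and conversation_model arguments let the two agents run on different models instead of the single hard-coded OPENAI_CONFIG (gpt-4o-mini). A minimal sketch of the resolution logic, separate from the patch itself; it assumes only LLMConfig.default_config, which recognizes a fixed set of model names ("gpt-4o" joins that set in the letta/schemas/llm_config.py hunk below):

    # Sketch only: mirrors the new resolution inside run_memory_edits.
    from typing import Optional

    from letta.schemas.llm_config import LLMConfig

    OPENAI_CONFIG = LLMConfig.default_config("gpt-4o-mini")  # script-level default

    def resolve_config(model: Optional[str]) -> LLMConfig:
        # None keeps the script default; anything else must be a name
        # that LLMConfig.default_config knows.
        return OPENAI_CONFIG if model is None else LLMConfig.default_config(model)

    conversation_config = resolve_config("gpt-4o")  # enabled by this patch
    offline_config = resolve_config(None)           # stays on gpt-4o-mini
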
parser.add_argument("--persona_block_filename", default="persona_block_verbose", required=False) + parser.add_argument("--persona_block_filename", default="persona_verbose", required=False) parser.add_argument("--system_block_filename", default="convo_base", required=False) parser.add_argument("--offline_system_block_filename", default="offline_base", required=False) parser.add_argument("--random_example", action="store_true") # debug by using a random example parser.add_argument("--few_shot", default=8, required=False, type=int) parser.add_argument("--limit", default=None, required=False, type=int) parser.add_argument("--skip_first", default=0, required=False, type=int) + parser.add_argument("--offline_memory_model", default="gpt-4o-mini", required=False) + parser.add_argument("--conversation_model", default="gpt-4o-mini", required=False) args = parser.parse_args() run_memory_edits(args.input_file, @@ -270,4 +286,6 @@ def run_memory_edits(gsm8k_input_file: str, args.random_example, args.few_shot, args.limit, - args.skip_first) + args.skip_first, + args.offline_memory_model, + args.conversation_model) diff --git a/letta/agent.py b/letta/agent.py index 98e6f9ff1e..29eada35a1 100644 --- a/letta/agent.py +++ b/letta/agent.py @@ -444,7 +444,8 @@ def execute_tool_and_persist_state(self, function_name, function_to_call, functi ) function_response, updated_agent_state = sandbox_run_result.func_return, sandbox_run_result.agent_state assert orig_memory_str == self.agent_state.memory.compile(), "Memory should not be modified in a sandbox tool" - self.update_memory_if_change(updated_agent_state.memory) + if updated_agent_state is not None: + self.update_memory_if_change(updated_agent_state.memory) except Exception as e: # Need to catch error here, or else trunction wont happen # TODO: modify to function execution error @@ -584,7 +585,7 @@ def append_to_messages(self, added_messages: List[dict]): def _get_ai_reply( self, message_sequence: List[Message], - function_call: str = "auto", + function_call: Optional[str] = None, first_message: bool = False, stream: bool = False, # TODO move to config? 
diff --git a/letta/llm_api/helpers.py b/letta/llm_api/helpers.py
index 1244b6ffe7..b96d825c58 100644
--- a/letta/llm_api/helpers.py
+++ b/letta/llm_api/helpers.py
@@ -250,6 +250,7 @@ def unpack_all_inner_thoughts_from_kwargs(
 def unpack_inner_thoughts_from_kwargs(choice: Choice, inner_thoughts_key: str) -> Choice:
     message = choice.message
+    rewritten_choice = choice
     # inner thoughts unpacked out of the function
     if message.role == "assistant" and message.tool_calls and len(message.tool_calls) >= 1:
         if len(message.tool_calls) > 1:
             warnings.warn(f"Unpacking inner thoughts from more than one tool call ({len(message.tool_calls)}) is not supported")
@@ -271,14 +272,18 @@ def unpack_inner_thoughts_from_kwargs(choice: Choice, inner_thoughts_key: str) -
                     warnings.warn(f"Overwriting existing inner monologue ({new_choice.message.content}) with kwarg ({inner_thoughts})")
                 new_choice.message.content = inner_thoughts
-                return new_choice
+                # update the choice object
+                rewritten_choice = new_choice
             else:
                 warnings.warn(f"Did not find inner thoughts in tool call: {str(tool_call)}")
-            return choice
         except json.JSONDecodeError as e:
             warnings.warn(f"Failed to strip inner thoughts from kwargs: {e}")
             raise e
+    else:
+        warnings.warn(f"Did not find tool call in message: {str(message)}")
+
+    return rewritten_choice
 
 
 def is_context_overflow_error(exception: Union[requests.exceptions.RequestException, Exception]) -> bool:
diff --git a/letta/llm_api/llm_api_tools.py b/letta/llm_api/llm_api_tools.py
index 163c4e1868..fdba3e169d 100644
--- a/letta/llm_api/llm_api_tools.py
+++ b/letta/llm_api/llm_api_tools.py
@@ -110,7 +110,7 @@ def create(
     user_id: Optional[str] = None,  # optional UUID to associate request with
     functions: Optional[list] = None,
     functions_python: Optional[dict] = None,
-    function_call: str = "auto",
+    function_call: Optional[str] = None,
     # hint
     first_message: bool = False,
     # use tool naming?
@@ -147,10 +147,19 @@ def create(
 
     # openai
     if llm_config.model_endpoint_type == "openai":
+
         if model_settings.openai_api_key is None and llm_config.model_endpoint == "https://api.openai.com/v1":
             # this is only a problem if we are *not* using an OpenAI proxy
             raise LettaConfigurationError(message="OpenAI key is missing from letta config file", missing_fields=["openai_api_key"])
 
+        if function_call is None and functions is not None and len(functions) > 0:
+            # force function calling for reliability, see https://platform.openai.com/docs/api-reference/chat/create#chat-create-tool_choice
+            # TODO(matt) move into LLMConfig
+            if llm_config.model_endpoint == "https://inference.memgpt.ai":
+                function_call = "auto"  # TODO change to "required" once proxy supports it
+            else:
+                function_call = "required"
+
         data = build_openai_chat_completions_request(llm_config, messages, user_id, functions, function_call, use_tool_naming, max_tokens)
         if stream:  # Client requested token streaming
             data.stream = True
diff --git a/letta/schemas/llm_config.py b/letta/schemas/llm_config.py
index ed63e766b5..2544f7083c 100644
--- a/letta/schemas/llm_config.py
+++ b/letta/schemas/llm_config.py
@@ -90,6 +90,14 @@ def default_config(cls, model_name: str):
                 model_wrapper=None,
                 context_window=128000,
             )
+        elif model_name == "gpt-4o":
+            return cls(
+                model="gpt-4o",
+                model_endpoint_type="openai",
+                model_endpoint="https://api.openai.com/v1",
+                model_wrapper=None,
+                context_window=128000,
+            )
         elif model_name == "letta":
             return cls(
                 model="memgpt-openai",
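
With function_call now defaulting to None in both _get_ai_reply and create, create picks the tool choice itself when the caller leaves it unset: "required" for OpenAI when functions are supplied, and "auto" for the inference.memgpt.ai proxy until it supports "required". A sketch of that resolution; the helper name is illustrative, while the endpoint URLs and values come from the hunk above:

    # Sketch only: the same branching create() now performs inline.
    from typing import List, Optional

    def resolve_function_call(function_call: Optional[str],
                              functions: Optional[List[dict]],
                              model_endpoint: str) -> Optional[str]:
        if function_call is None and functions is not None and len(functions) > 0:
            # force function calling for reliability
            if model_endpoint == "https://inference.memgpt.ai":
                return "auto"  # proxy does not support "required" yet
            return "required"
        return function_call  # an explicit caller value is respected

    assert resolve_function_call(None, [{"name": "send_message"}],
                                 "https://api.openai.com/v1") == "required"
    assert resolve_function_call(None, [{"name": "send_message"}],
                                 "https://inference.memgpt.ai") == "auto"
    assert resolve_function_call("auto", [{"name": "send_message"}],
                                 "https://api.openai.com/v1") == "auto"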