From 9fd9381716e8e0ca2e6d87bee495b86c1914e28a Mon Sep 17 00:00:00 2001
From: Marc-Antoine Parent
Date: Tue, 3 Oct 2023 11:06:17 -0400
Subject: [PATCH] Add async versions of the main functions using
 asyncio.to_thread; factor out common code

---
 easycompletion/__init__.py |   5 +-
 easycompletion/model.py    | 481 ++++++++++++++++++++++++++-----------
 easycompletion/prompt.py   |   6 +
 3 files changed, 354 insertions(+), 138 deletions(-)

diff --git a/easycompletion/__init__.py b/easycompletion/__init__.py
index c4a47da..c7e01f8 100644
--- a/easycompletion/__init__.py
+++ b/easycompletion/__init__.py
@@ -1,7 +1,10 @@
 from .model import (
     function_completion,
+    function_completion_async,
     text_completion,
-    chat_completion
+    text_completion_async,
+    chat_completion,
+    chat_completion_async
 )
 
 openai_function_call = function_completion
diff --git a/easycompletion/model.py b/easycompletion/model.py
index 0689929..009a9c1 100644
--- a/easycompletion/model.py
+++ b/easycompletion/model.py
@@ -4,6 +4,7 @@
 import re
 import json
 import ast
+import asyncio
 
 from dotenv import load_dotenv
 
@@ -153,43 +154,15 @@ def validate_functions(response, functions, function_call, debug=DEBUG):
     log("Function call is valid", type="success", log=debug)
     return True
 
-def chat_completion(
-    messages,
-    model_failure_retries=5,
-    model=None,
-    chunk_length=DEFAULT_CHUNK_LENGTH,
-    api_key=EASYCOMPLETION_API_KEY,
-    debug=DEBUG,
-    temperature=0.0,
-):
-    """
-    Function for sending chat messages and returning a chat response.
-
-    Parameters:
-        messages (str): Messages to send to the model. In the form {<role>: string, <content>: string} - roles are "user" and "assistant"
-        model_failure_retries (int, optional): Number of retries if the request fails. Default is 5.
-        model (str, optional): The model to use. Default is the TEXT_MODEL defined in constants.py.
-        chunk_length (int, optional): Maximum length of text chunk to process. Default is defined in constants.py.
-        api_key (str, optional): OpenAI API key. If not provided, it uses the one defined in constants.py.
-
-    Returns:
-        str: The response content from the model.
-
-    Example:
-        >>> text_completion("Hello, how are you?", model_failure_retries=3, model='gpt-3.5-turbo', chunk_length=1024, api_key='your_openai_api_key')
-    """
+# Shared validation for every completion entry point: checks the API key and the
+# prompt length, logs the prompt, and returns an error dict, or None on success.
+def sanity_check(prompt, model=TEXT_MODEL, chunk_length=DEFAULT_CHUNK_LENGTH, api_key=EASYCOMPLETION_API_KEY, debug=DEBUG):
     # Validate the API key
     if not api_key.strip():
         return {"error": "Invalid OpenAI API key"}
 
     openai.api_key = api_key
 
-    # Use the default model if no model is specified
-    if model == None:
-        model = TEXT_MODEL
-
     # Count tokens in the input text
-    total_tokens = count_tokens(messages, model=model)
+    total_tokens = count_tokens(prompt, model=model)
 
     # If text is longer than chunk_length and model is not for long texts, switch to the long text model
     if total_tokens > chunk_length and "16k" not in model:
@@ -209,19 +182,27 @@ def chat_completion(
             "error": "Message too long",
         }
 
-    log(f"Prompt:\n{str(messages)}", type="prompt", log=debug)
+    if isinstance(prompt, dict):
+        for key, value in prompt.items():
+            if value:
+                log(f"Prompt {key} ({count_tokens(value)} tokens):\n{str(value)}", type="prompt", log=debug)
+    else:
+        log(f"Prompt ({total_tokens} tokens):\n{str(prompt)}", type="prompt", log=debug)
+
+# Shared request helper: retries the OpenAI call and returns a (response, error)
+# pair, exactly one of which is None.
+def do_chat_completion(
+    messages, model=TEXT_MODEL, temperature=0.8, functions=None, function_call=None, model_failure_retries=5, debug=DEBUG):
     # Try to make a request for a specified number of times
     response = None
     for i in range(model_failure_retries):
         try:
-            response = openai.ChatCompletion.create(
-                model=model, messages=messages, temperature=temperature
-            )
+            args = dict(model=model, messages=messages, temperature=temperature)
+            # Only pass the function-calling parameters when functions are given;
+            # the API rejects explicit nulls for "functions" and "function_call".
+            if functions is not None:
+                args["functions"] = functions
+                args["function_call"] = function_call
+            response = openai.ChatCompletion.create(**args)
             break
         except Exception as e:
             log(f"OpenAI Error: {e}", type="error", log=debug)
-            continue
 
     # If response is not valid, print an error message and return None
     if (
@@ -229,12 +210,105 @@ def chat_completion(
         or response["choices"] is None
         or response["choices"][0] is None
     ):
-        return {
+        return None, {
             "text": None,
             "usage": None,
             "finish_reason": None,
             "error": "Error: Could not get a successful response from OpenAI API",
         }
+    return response, None
+
+def chat_completion(
+    messages,
+    model_failure_retries=5,
+    model=None,
+    chunk_length=DEFAULT_CHUNK_LENGTH,
+    api_key=EASYCOMPLETION_API_KEY,
+    debug=DEBUG,
+    temperature=0.0,
+):
+    """
+    Function for sending chat messages and returning a chat response.
+
+    Parameters:
+        messages (list): Messages to send to the model, each in the form {"role": string, "content": string} - roles are "user" and "assistant"
+        model_failure_retries (int, optional): Number of retries if the request fails. Default is 5.
+        model (str, optional): The model to use. Default is the TEXT_MODEL defined in constants.py.
+        chunk_length (int, optional): Maximum length of text chunk to process. Default is defined in constants.py.
+        api_key (str, optional): OpenAI API key. If not provided, it uses the one defined in constants.py.
+
+    Returns:
+        dict: A dictionary with "text", "usage", "finish_reason" and "error" keys.
+
+    Example:
+        >>> chat_completion([{"role": "user", "content": "Hello, how are you?"}], model_failure_retries=3, model='gpt-3.5-turbo', chunk_length=1024, api_key='your_openai_api_key')
+    """
+    openai.api_key = api_key
+
+    # Use the default model if no model is specified
+    model = model or TEXT_MODEL
+    error = sanity_check(messages, model=model, chunk_length=chunk_length, api_key=api_key, debug=debug)
+    if error:
+        return error
+
+    # Try to make a request for a specified number of times
+    response, error = do_chat_completion(
+        model=model, messages=messages, temperature=temperature, model_failure_retries=model_failure_retries, debug=debug)
+
+    if error:
+        return error
+
+    # Extract content from the response
+    text = response["choices"][0]["message"]["content"]
+    finish_reason = response["choices"][0]["finish_reason"]
+    usage = response["usage"]
+
+    return {
+        "text": text,
+        "usage": usage,
+        "finish_reason": finish_reason,
+        "error": None,
+    }
+
+
+async def chat_completion_async(
+    messages,
+    model_failure_retries=5,
+    model=None,
+    chunk_length=DEFAULT_CHUNK_LENGTH,
+    api_key=EASYCOMPLETION_API_KEY,
+    debug=DEBUG,
+    temperature=0.0,
+):
+    """
+    Async version of chat_completion: sends chat messages and returns a chat
+    response, running the blocking OpenAI call in a thread via asyncio.to_thread.
+
+    Parameters:
+        messages (list): Messages to send to the model, each in the form {"role": string, "content": string} - roles are "user" and "assistant"
+        model_failure_retries (int, optional): Number of retries if the request fails. Default is 5.
+        model (str, optional): The model to use. Default is the TEXT_MODEL defined in constants.py.
+        chunk_length (int, optional): Maximum length of text chunk to process. Default is defined in constants.py.
+        api_key (str, optional): OpenAI API key. If not provided, it uses the one defined in constants.py.
+
+    Returns:
+        dict: A dictionary with "text", "usage", "finish_reason" and "error" keys.
+
+    Example:
+        >>> await chat_completion_async([{"role": "user", "content": "Hello, how are you?"}], model_failure_retries=3, model='gpt-3.5-turbo', chunk_length=1024, api_key='your_openai_api_key')
+    """
+
+    # Use the default model if no model is specified
+    model = model or TEXT_MODEL
+    error = sanity_check(messages, model=model, chunk_length=chunk_length, api_key=api_key, debug=debug)
+    if error:
+        return error
+
+    # Try to make a request for a specified number of times
+    response, error = await asyncio.to_thread(lambda: do_chat_completion(
+        model=model, messages=messages, temperature=temperature, model_failure_retries=model_failure_retries, debug=debug))
+
+    if error:
+        return error
 
     # Extract content from the response
     text = response["choices"][0]["message"]["content"]
@@ -275,66 +349,74 @@ def text_completion(
         >>> text_completion("Hello, how are you?", model_failure_retries=3, model='gpt-3.5-turbo', chunk_length=1024, api_key='your_openai_api_key')
     """
 
-    # Override the API key if provided as parameter
-    if api_key is not None:
-        openai.api_key = api_key
-
     # Use the default model if no model is specified
-    if model == None:
-        model = TEXT_MODEL
+    model = model or TEXT_MODEL
+    error = sanity_check(text, model=model, chunk_length=chunk_length, api_key=api_key, debug=debug)
+    if error:
+        return error
 
-    # Count tokens in the input text
-    total_tokens = count_tokens(text, model=model)
+    # Prepare messages for the API call
+    messages = [{"role": "user", "content": text}]
 
-    # If text is longer than chunk_length and model is not for long texts, switch to the long text model
-    if total_tokens > chunk_length and "16k" not in model:
-        model = LONG_TEXT_MODEL
-        if not os.environ.get("SUPPRESS_WARNINGS"):
-            print(
-                "Warning: Message is long. Using 16k model (to hide this message, set SUPPRESS_WARNINGS=1)"
-            )
+    # Try to make a request for a specified number of times
+    response, error = do_chat_completion(
+        model=model, messages=messages, temperature=temperature, model_failure_retries=model_failure_retries, debug=debug)
+    if error:
+        return error
 
-    # If text is too long even for long text model, return None
-    if total_tokens > (16384 - chunk_length):
-        print("Error: Message too long")
-        return {
-            "text": None,
-            "usage": None,
-            "finish_reason": None,
-            "error": "Message too long",
-        }
+    # Extract content from the response
+    text = response["choices"][0]["message"]["content"]
+    finish_reason = response["choices"][0]["finish_reason"]
+    usage = response["usage"]
+
+    return {
+        "text": text,
+        "usage": usage,
+        "finish_reason": finish_reason,
+        "error": None,
+    }
+
+async def text_completion_async(
+    text,
+    model_failure_retries=5,
+    model=None,
+    chunk_length=DEFAULT_CHUNK_LENGTH,
+    api_key=EASYCOMPLETION_API_KEY,
+    debug=DEBUG,
+    temperature=0.0,
+):
+    """
+    Async version of text_completion: sends text and returns a text completion
+    response, running the blocking OpenAI call in a thread via asyncio.to_thread.
+
+    Parameters:
+        text (str): Text to send to the model.
+        model_failure_retries (int, optional): Number of retries if the request fails. Default is 5.
+        model (str, optional): The model to use. Default is the TEXT_MODEL defined in constants.py.
+        chunk_length (int, optional): Maximum length of text chunk to process. Default is defined in constants.py.
+        api_key (str, optional): OpenAI API key. If not provided, it uses the one defined in constants.py.
+
+    Returns:
+        dict: A dictionary with "text", "usage", "finish_reason" and "error" keys.
+
+    Example:
+        >>> await text_completion_async("Hello, how are you?", model_failure_retries=3, model='gpt-3.5-turbo', chunk_length=1024, api_key='your_openai_api_key')
+    """
+
+    # Use the default model if no model is specified
+    model = model or TEXT_MODEL
+    error = sanity_check(text, model=model, chunk_length=chunk_length, api_key=api_key, debug=debug)
+    if error:
+        return error
 
     # Prepare messages for the API call
     messages = [{"role": "user", "content": text}]
 
-    log(f"Prompt:\n{text}", type="prompt", log=debug)
-
     # Try to make a request for a specified number of times
-    response = None
-    for i in range(model_failure_retries):
-        try:
-            response = openai.ChatCompletion.create(
-                model=model, messages=messages, temperature=temperature
-            )
-            break
-        except Exception as e:
-            log(f"OpenAI Error: {e}", type="error", log=debug)
-            continue
-            # wait 1 second
-            time.sleep(1)
+    response, error = await asyncio.to_thread(lambda: do_chat_completion(
+        model=model, messages=messages, temperature=temperature, model_failure_retries=model_failure_retries, debug=debug))
 
-    # If response is not valid, print an error message and return None
-    if (
-        response is None
-        or response["choices"] is None
-        or response["choices"][0] is None
-    ):
-        return {
-            "text": None,
-            "usage": None,
-            "finish_reason": None,
-            "error": "Error: Could not get a successful response from OpenAI API",
-        }
+    if error:
+        return error
 
     # Extract content from the response
     text = response["choices"][0]["message"]["content"]
@@ -387,9 +469,8 @@ def function_completion(
         >>> function_completion("Call the function.", function)
     """
 
-    # Check if the user provided an API key
-    if api_key is not None:
-        openai.api_key = api_key
+    # Use the default model if no model is specified
+    model = model or TEXT_MODEL
 
     # Ensure that functions are provided
     if functions is None:
@@ -445,9 +526,11 @@ def function_completion(
                "error": "function_call had an invalid name. Should be a string of the function name or an object with a name property"
            }
 
-    # Use the default text model if no model is specified
-    if model is None:
-        model = TEXT_MODEL
+    error = sanity_check(dict(
+        text=text, functions=functions, messages=messages, system_message=system_message
+    ), model=model, chunk_length=chunk_length, api_key=api_key, debug=debug)
+    if error:
+        return error
 
     # Count the number of tokens in the message
     message_tokens = count_tokens(text, model=model)
@@ -456,23 +539,174 @@ def function_completion(
     function_call_tokens = count_tokens(functions, model=model)
     total_tokens += function_call_tokens + 3  # Additional tokens for the user
 
+    all_messages = []
+
+    if system_message is not None:
+        all_messages.append({"role": "system", "content": system_message})
+
+    if messages is not None:
+        all_messages += messages
+
+    # Prepare the messages to be sent to the API
+    if text is not None and text != "":
+        all_messages.append({"role": "user", "content": text})
+
+    # Retry function call and model calls according to the specified retry counts
+    response = None
+    for _ in range(function_failure_retries):
+        # Try to make a request for a specified number of times
+        response, error = do_chat_completion(
+            model=model, messages=all_messages, temperature=temperature, function_call=function_call,
+            functions=functions, model_failure_retries=model_failure_retries, debug=debug)
+        if error:
+            time.sleep(1)
+            continue
+        if validate_functions(response, functions, function_call):
+            break
+        time.sleep(1)
+
+    # Check if we have a valid response from the model
+    if not response:
+        return error
+
+    # Extracting the content and function call response from API response
+    response_data = response["choices"][0]["message"]
+    finish_reason = response["choices"][0]["finish_reason"]
+    usage = response["usage"]
+
+    text = response_data["content"]
+    function_call_response = response_data.get("function_call", None)
+
+    # If no function call in response, return an error
+    if function_call_response is None:
+        log(f"No function call in response\n{response}", type="error", log=debug)
+        return {"error": "No function call in response"}
+    function_name = function_call_response["name"]
+    arguments = parse_arguments(function_call_response["arguments"])
     log(
-        f"Message tokens: {str(message_tokens)}"
-        + f"\nFunction call tokens: {str(function_call_tokens)}"
-        + f"\nTotal tokens: {str(total_tokens)}",
-        type="info",
+        f"Response\n\nFunction Name: {function_name}\n\nArguments:\n{arguments}\n\nText:\n{text}\n\nFinish Reason: {finish_reason}\n\nUsage:\n{usage}",
+        type="response",
         log=debug,
     )
+    # Return the final result with the text response, function name, arguments and no error
+    return {
+        "text": text,
+        "function_name": function_name,
+        "arguments": arguments,
+        "usage": usage,
+        "finish_reason": finish_reason,
+        "error": None,
+    }
 
-    # Switch to a larger model if the message is too long for the default model
-    if total_tokens > chunk_length and "16k" not in model:
-        model = LONG_TEXT_MODEL
-        log("Warning: Message is long. Using 16k model", type="warning", log=debug)
+async def function_completion_async(
+    text=None,
+    messages=None,
+    system_message=None,
+    functions=None,
+    model_failure_retries=5,
+    function_call=None,
+    function_failure_retries=10,
+    chunk_length=DEFAULT_CHUNK_LENGTH,
+    model=None,
+    api_key=EASYCOMPLETION_API_KEY,
+    debug=DEBUG,
+    temperature=0.0,
+):
+    """
+    Async version of function_completion (the blocking OpenAI call runs in a thread via asyncio.to_thread).
+    Send text and a list of functions to the model and return optional text and a function call. The function call is validated against the functions array.
+    The input text is sent to the chat model and is treated as a user message.
 
-    # Check if the total number of tokens exceeds the maximum allowable tokens for the model
-    if total_tokens > (16384 - chunk_length):
-        log("Error: Message too long", type="error", log=debug)
-        return {"error": "Message too long"}
+    Args:
+        text (str): Text that will be sent as the user message to the model.
+        functions (list[dict] | dict | None): List of functions or a single function dictionary to be sent to the model.
+        model_failure_retries (int): Number of times to retry the request if it fails (default is 5).
+        function_call (str | dict | None): 'auto' to let the model decide, or a function name or a dictionary containing the function name (default is "auto").
+        function_failure_retries (int): Number of times to retry the request if the function call is invalid (default is 10).
+        chunk_length (int): The length of each chunk to be processed.
+        model (str | None): The model to use (default is the TEXT_MODEL, i.e. gpt-3.5-turbo).
+        api_key (str | None): If you'd like to pass in a key to override the environment variable EASYCOMPLETION_API_KEY.
+
+    Returns:
+        dict: On most errors, returns a dictionary with an "error" key. On success, returns a dictionary containing
+        "text" (response from the model), "function_name" (name of the function called), "arguments" (arguments for the function), "error" (None).
+
+    Example:
+        >>> function = {'name': 'function1', 'parameters': {'param1': 'value1'}}
+        >>> await function_completion_async("Call the function.", functions=function)
+    """
+
+    # Use the default model if no model is specified
+    model = model or TEXT_MODEL
+
+    # Ensure that functions are provided
+    if functions is None:
+        return {"error": "functions is required"}
+
+    # Check if a list of functions is provided
+    if not isinstance(functions, list):
+        if (
+            isinstance(functions, dict)
+            and "name" in functions
+            and "parameters" in functions
+        ):
+            # A single function is provided as a dictionary, convert it to a list
+            functions = [functions]
+        else:
+            # Functions must be either a list of dictionaries or a single dictionary
+            return {
+                "error": "functions must be a list of functions or a single function"
+            }
+
+    # Set the function call to the name of the function if only one function is provided
+    # If there are multiple functions, use "auto"
+    if function_call is None:
+        function_call = functions[0]["name"] if len(functions) == 1 else "auto"
+
+    # Make sure text is provided
+    if text is None:
+        log("Text is required", type="error", log=debug)
+        return {"error": "text is required"}
+
+    function_call_names = [function["name"] for function in functions]
+    # Check that all function_call_names are unique and appear in the text
+    if len(function_call_names) != len(set(function_call_names)):
+        log("Function names must be unique", type="error", log=debug)
+        return {"error": "Function names must be unique"}
+
+    if len(function_call_names) > 1 and not any(
+        function_call_name in text for function_call_name in function_call_names
+    ):
+        log(
+            "WARNING: Function and argument names should be in the text",
+            type="warning",
+            log=debug,
+        )
+
+    # Check if the function call is valid
+    if function_call != "auto":
+        if isinstance(function_call, str):
+            function_call = {"name": function_call}
+        elif "name" not in function_call:
+            log("function_call must have a name property", type="error", log=debug)
+            return {
+                "error": "function_call had an invalid name. Should be a string of the function name or an object with a name property"
+            }
+
+    error = sanity_check(dict(
+        text=text, functions=functions, messages=messages, system_message=system_message
+    ), model=model, chunk_length=chunk_length, api_key=api_key, debug=debug)
+    if error:
+        return error
+
+    # Count the number of tokens in the message
+    message_tokens = count_tokens(text, model=model)
+    total_tokens = message_tokens
+
+    function_call_tokens = count_tokens(functions, model=model)
+    total_tokens += function_call_tokens + 3  # Additional tokens for the user
 
     all_messages = []
 
@@ -486,52 +720,25 @@ def function_completion(
     if system_message is not None:
         all_messages.append({"role": "system", "content": system_message})
 
     if messages is not None:
         all_messages += messages
 
     # Prepare the messages to be sent to the API
     if text is not None and text != "":
         all_messages.append({"role": "user", "content": text})
 
-    log(
-        f"Prompt:\n{text}\n\nFunctions:\n{json.dumps(functions, indent=4)}",
-        type="prompt",
-        log=debug,
-    )
-
     # Retry function call and model calls according to the specified retry counts
     response = None
     for _ in range(function_failure_retries):
-        for _ in range(model_failure_retries):
-            try:
-                # If there are function(s) to call
-                response = openai.ChatCompletion.create(
-                    model=model,
-                    messages=all_messages,
-                    functions=functions,
-                    function_call=function_call,
-                    temperature=temperature,
-                )
-                print('***** openai response')
-                print(response)
-                if not response.get("choices") or response["choices"][0] is None:
-                    log("No choices in response", type="error", log=debug)
-                    continue
-                break
-            except Exception as e:
-                print('**** ERROR')
-                print(e)
-                log(f"OpenAI Error: {e}", type="error", log=debug)
+        # Try to make a request for a specified number of times
+        response, error = await asyncio.to_thread(lambda: do_chat_completion(
+            model=model, messages=all_messages, temperature=temperature, function_call=function_call,
+            functions=functions, model_failure_retries=model_failure_retries, debug=debug))
+        if error:
             time.sleep(1)
-        # Check if we have a valid response from the model
+            continue
-        print('***** response')
-        print(response)
-        print('***** functions')
-        print(functions)
-        print('***** function_call')
-        print(function_call)
         if validate_functions(response, functions, function_call):
             break
         time.sleep(1)
 
     # Check if we have a valid response from the model
-    if response is None:
-        error = "Could not get a successful response from the model. Check your API key and arguments."
-        log(error, type="error", log=True)
-        return {"error": error}
+    if not response:
+        return error
 
     # Extracting the content and function call response from API response
     response_data = response["choices"][0]["message"]
diff --git a/easycompletion/prompt.py b/easycompletion/prompt.py
index 8254b9e..0569142 100644
--- a/easycompletion/prompt.py
+++ b/easycompletion/prompt.py
@@ -108,6 +108,12 @@ def count_tokens(prompt: str, model=TEXT_MODEL) -> int:
         count_tokens("This is a test.")
         Output: 5
     """
+    if not prompt:
+        return 0
+    if isinstance(prompt, (list, tuple)):
+        return sum(count_tokens(p, model) for p in prompt)
+    if isinstance(prompt, dict):
+        return sum(count_tokens(v, model) for v in prompt.values())
     if not isinstance(prompt, str):
         prompt = str(prompt)
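
Reviewer note (not part of the patch): a minimal usage sketch of the new async entry points, assuming the patch is applied, Python 3.9+ (required for asyncio.to_thread), and EASYCOMPLETION_API_KEY set in the environment. The write_song schema below is a hypothetical example, not something shipped with the library.

    import asyncio

    from easycompletion import chat_completion_async, function_completion_async

    # Hypothetical function schema, for illustration only
    write_song = {
        "name": "write_song",
        "description": "Write a short song",
        "parameters": {
            "type": "object",
            "properties": {
                "lyrics": {"type": "string", "description": "The lyrics of the song"},
            },
            "required": ["lyrics"],
        },
    }

    async def main():
        # Each call pushes its blocking OpenAI request onto a worker thread via
        # asyncio.to_thread, so the two completions run concurrently.
        chat, func = await asyncio.gather(
            chat_completion_async([{"role": "user", "content": "Hello, how are you?"}]),
            function_completion_async("Write a short song about AI", functions=write_song),
        )
        print(chat["text"])
        print(func["arguments"]["lyrics"])

    asyncio.run(main())

Since both wrappers funnel through do_chat_completion, retries and error handling behave identically on the sync and async paths; the only difference is where the blocking call runs.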