[Python] OpenAI wrapper improvements
## What

1. Save outputs
2. Allow passing a file path
3. Allow passing an AIConfig object (both entry points are shown in the sketch below)
4. Persist outputs with streaming enabled (pass-through streaming)
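A minimal sketch of both entry points (the file name is illustrative):

```
import openai
from aiconfig.Config import AIConfigRuntime
from aiconfig.ChatCompletion import create_and_save_to_config

# Pass a file path: the config is loaded from disk, or created fresh if loading fails.
openai.ChatCompletion.create = create_and_save_to_config("my-aiconfig.json")

# Or pass an AIConfig object directly.
runtime = AIConfigRuntime.create()
openai.ChatCompletion.create = create_and_save_to_config(aiconfig=runtime)
```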

Notes:
- With streaming, we can either capture the response without returning it to the user, or pass the stream through to the caller. With pass-through, if the stream is never fully iterated (i.e., the user never consumes the completion), the outputs don't get serialized. By the nature of streaming, we can't have both.
  - We chose pass-through streaming (see the sketch after these notes).

- If a completion's prompt is already in the config (same input, settings, etc.) but has different outputs, the existing outputs are overridden with the new ones.

- If one completion is streamed and another is not, they are considered different prompts and are serialized as such.
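Sketch of the pass-through caveat (model and messages are illustrative):

```
completion = openai.ChatCompletion.create(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "Say hi"}],
    stream=True,
)

# Outputs are only persisted once the stream is fully consumed.
for chunk in completion:
    pass
```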


## Why

The wrapper needs to be customizable and flexible.
Ankush Pala [email protected] committed Nov 10, 2023
1 parent d6a2314 commit 6806c16
Showing 1 changed file with 152 additions and 9 deletions.
python/src/aiconfig/ChatCompletion.py
import openai
from aiconfig.ChatCompletion import create_and_save_to_config
openai.ChatCompletion.create = create_and_save_to_config('my-aiconfig.json')
```
"""

import copy
import types
from typing import Dict, List, Optional
from aiconfig.default_parsers.openai import multi_choice_message_reducer
from aiconfig.schema import ExecuteResult, Output, Prompt
from aiconfig.Config import AIConfigRuntime
import openai
import asyncio

import nest_asyncio

# Make asyncio re-entrant so run_until_complete() works inside a running loop
# (e.g., in notebooks); async_run_serialize_helper relies on this.
nest_asyncio.apply()

# Keep a reference to the original create() so the wrapper can delegate to it.
openai_chat_completion_create = openai.ChatCompletion.create


def create_and_save_to_config(
    config_file_path: Optional[str] = None, aiconfig: Optional[AIConfigRuntime] = None
):
"""
Overrides OpenAI's ChatCompletion.create method to serialize prompts and save them along with their outputs to a configuration file.
Args:
file_path (str, optional): Path to the configuration file.
aiconfig (AIConfigRuntime, optional): An instance of AIConfigRuntime to be used.
Returns:
A modified version of the OpenAI ChatCompletion.create function, with additional logic to handle prompt serialization and configuration saving.
"""
if aiconfig is None:
try:
aiconfig = AIConfigRuntime.load(config_file_path)
except:
aiconfig = AIConfigRuntime.create()

    def _create_chat_completion_with_config_saving(*args, **kwargs):
        response = openai_chat_completion_create(*args, **kwargs)

        serialized_prompts = async_run_serialize_helper(aiconfig, kwargs)

        # Serialize output from response
        outputs = []

        # Check if response is a stream
        stream = kwargs.get("stream", False) is True and isinstance(
            response, types.GeneratorType
        )

        # Convert response to output for the last prompt
        if not stream:
            outputs = extract_outputs_from_response(response)

            # Add outputs to last prompt
            serialized_prompts[-1].outputs = outputs

            validate_and_add_prompts_to_config(serialized_prompts, aiconfig)

            # Save config to file
            aiconfig.save(config_file_path, include_outputs=True)

            # Return original response
            return response
        else:
            # If the response is a stream, build the outputs as the stream is
            # iterated through; the wrapper returns a generator instead.

            def generate_streamed_response():
                stream_outputs = {}
                messages = {}
                for chunk in response:
                    # Streaming returns one chunk, one choice at a time, and the
                    # order in which choices are returned is not guaranteed.
                    messages = multi_choice_message_reducer(messages, chunk)

                    for i, choice in enumerate(chunk["choices"]):
                        index = choice.get("index")
                        accumulated_message_for_choice = messages.get(index, {})
                        output = ExecuteResult(
                            **{
                                "output_type": "execute_result",
                                "data": copy.deepcopy(accumulated_message_for_choice),
                                "execution_count": index,
                                "metadata": {"finish_reason": choice.get("finish_reason")},
                            }
                        )
                        stream_outputs[index] = output
                    yield chunk
                stream_outputs = [stream_outputs[i] for i in sorted(list(stream_outputs.keys()))]

                # Add outputs to last prompt
                serialized_prompts[-1].outputs = stream_outputs

                validate_and_add_prompts_to_config(serialized_prompts, aiconfig)

                # Save config to file
                aiconfig.save(config_file_path, include_outputs=True)

            return generate_streamed_response()

    return _create_chat_completion_with_config_saving


def validate_and_add_prompts_to_config(prompts: List[Prompt], aiconfig) -> None:
"""
Validates and adds new prompts to the AI configuration, ensuring no duplicates and updating outputs if necessary.
Args:
prompts (List[Prompt]): List of prompts to be validated and added.
aiconfig (AIConfigRuntime): Configuration runtime instance to which the prompts are to be added.
"""
    for i, new_prompt in enumerate(prompts):
        in_config = False
        for config_prompt in aiconfig.prompts:
            # A duplicate shares the same input and metadata (settings, model, etc.).
            if (
                config_prompt.input == new_prompt.input
                and new_prompt.metadata == config_prompt.metadata
            ):
                in_config = True
                # Update outputs if different
                if config_prompt.outputs != new_prompt.outputs:
                    config_prompt.outputs = new_prompt.outputs
                break
        if not in_config:
            new_prompt_name = "prompt_{}".format(str(len(aiconfig.prompts)))
            new_prompt.name = new_prompt_name
            aiconfig.add_prompt(new_prompt.name, new_prompt)

def extract_outputs_from_response(response) -> List[Output]:
"""
Extracts outputs from the OpenAI ChatCompletion response and transforms them into a structured format.
Args:
response (dict): The response dictionary received from OpenAI's ChatCompletion.
Returns:
List[Output]: A list of outputs extracted and formatted from the response.
"""
    outputs = []

    response_without_choices = {
        key: copy.deepcopy(value) for key, value in response.items() if key != "choices"
    }
    for i, choice in enumerate(response.get("choices")):
        response_without_choices.update({"finish_reason": choice.get("finish_reason")})
        output = ExecuteResult(
            **{
                "output_type": "execute_result",
                "data": choice["message"],
                "execution_count": i,
                # Snapshot the metadata per choice so later finish_reason
                # updates don't overwrite earlier outputs.
                "metadata": dict(response_without_choices),
            }
        )

        outputs.append(output)
    return outputs
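
# Illustrative example (mock response, not a real API payload):
#
#   response = {
#       "id": "chatcmpl-abc",
#       "model": "gpt-3.5-turbo",
#       "choices": [
#           {"index": 0, "finish_reason": "stop",
#            "message": {"role": "assistant", "content": "Hi!"}},
#       ],
#   }
#   extract_outputs_from_response(response)
#   # -> [ExecuteResult(output_type="execute_result",
#   #                   data={"role": "assistant", "content": "Hi!"},
#   #                   execution_count=0,
#   #                   metadata={"id": ..., "model": ..., "finish_reason": "stop"})]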


def async_run_serialize_helper(aiconfig: AIConfigRuntime, request_kwargs: Dict):
"""
Method serialize() of AIConfig is an async method. If not, create a new one and await serialize().
"""
    in_event_loop = asyncio.get_event_loop().is_running()

    serialized_prompts = None

    async def run_and_await_serialize():
        result = await aiconfig.serialize(request_kwargs.get("model"), request_kwargs, "prompt")
        return result

    # Serialize prompts from ChatCompletion kwargs
    if in_event_loop:
        event_loop = asyncio.get_event_loop()
        serialized_prompts = event_loop.run_until_complete(run_and_await_serialize())
    else:
        serialized_prompts = asyncio.run(run_and_await_serialize())
    return serialized_prompts
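
# Sketch of the event-loop handling above (illustrative; assumes the default
# OpenAI parsers are registered so serialize() recognizes the model):
#
#   aiconfig = AIConfigRuntime.create()
#   kwargs = {"model": "gpt-3.5-turbo",
#             "messages": [{"role": "user", "content": "hi"}]}
#
#   # Plain script, no loop running: falls through to asyncio.run().
#   prompts = async_run_serialize_helper(aiconfig, kwargs)
#
#   # Notebook or async app, loop already running: nest_asyncio lets
#   # run_until_complete() re-enter the active loop instead of raising
#   # "This event loop is already running".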
