Merge pull request #5 from DaveCoDev/DavidKoleczke/gpt-4-v
GPT-4-V Support
DavidKoleczek authored Apr 6, 2024
2 parents 8027c09 + 24604ea commit ab9f00c
Showing 12 changed files with 828 additions and 291 deletions.
509 changes: 255 additions & 254 deletions poetry.lock


6 changes: 3 additions & 3 deletions pyproject.toml
@@ -1,6 +1,6 @@
[tool.poetry]
name = "not-again-ai"
version = "0.4.5"
version = "0.5.0"
description = "Designed to once and for all collect all the little things that come up over and over again in AI projects and put them in one place."
authors = ["DaveCoDev <[email protected]>"]
license = "MIT"
@@ -28,10 +28,10 @@ python = "^3.11, <3.13"

# Optional dependencies are defined here, and groupings are defined below.
numpy = { version = "^1.26.4", optional = true }
openai = { version = "^1.14.3", optional = true }
openai = { version = "^1.16.2", optional = true }
pandas = { version = "^2.2.1", optional = true }
python-liquid = { version = "^1.12.1", optional = true }
scipy = { version = "^1.12.0", optional = true }
scipy = { version = "^1.13.0", optional = true }
scikit-learn = { version = "^1.4.1.post1", optional = true }
seaborn = { version = "^0.13.2", optional = true }
tiktoken = { version = "^0.6.0", optional = true }
4 changes: 1 addition & 3 deletions src/not_again_ai/llm/chat_completion.py
@@ -105,8 +105,7 @@ def chat_completion(
finish_reason = response_choice.finish_reason
response_data_curr["finish_reason"] = finish_reason

-        # Not checking finish_reason=="tool_calls" here because when a user provides function name as tool_choice,
-        # the finish reason is "stop", not "tool_calls"
+        # We first check for tool calls because even if the finish_reason is stop, the model may have called a tool
tool_calls = response_choice.message.tool_calls
if tool_calls:
tool_names = []
@@ -159,7 +158,6 @@ def chat_completion(
response_data["system_fingerprint"] = response.system_fingerprint

if len(response_data["choices"]) == 1:
-        # Add all the fields in the first choice dict to the response_data dict
response_data.update(response_data["choices"][0])
del response_data["choices"]

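The comment change above reflects a real API quirk: when a request forces a specific function via tool_choice, the response can come back with finish_reason "stop" even though the message contains tool calls. A minimal sketch of that behavior against the raw OpenAI client (the model id and the get_weather tool are illustrative assumptions, not part of this diff):

from openai import OpenAI

client = OpenAI()
tools = [
    {
        "type": "function",
        "function": {
            "name": "get_weather",  # hypothetical tool, for illustration only
            "description": "Get the current weather for a city.",
            "parameters": {
                "type": "object",
                "properties": {"city": {"type": "string"}},
                "required": ["city"],
            },
        },
    }
]
response = client.chat.completions.create(
    model="gpt-3.5-turbo",  # assumed tool-capable model id
    messages=[{"role": "user", "content": "What is the weather in Boston?"}],
    tools=tools,
    # Forcing a specific function; finish_reason may then be "stop" rather than "tool_calls".
    tool_choice={"type": "function", "function": {"name": "get_weather"}},
)
choice = response.choices[0]
# Check message.tool_calls first; relying on finish_reason alone would miss the forced call.
if choice.message.tool_calls:
    print([tool_call.function.name for tool_call in choice.message.tool_calls])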
88 changes: 88 additions & 0 deletions src/not_again_ai/llm/chat_completion_vision.py
@@ -0,0 +1,88 @@
from typing import Any

from openai import OpenAI


def chat_completion_vision(
messages: list[dict[str, Any]],
model: str,
client: OpenAI,
max_tokens: int | None = None,
temperature: float = 0.7,
seed: int | None = None,
n: int = 1,
**kwargs: Any,
) -> dict[str, Any]:
"""Get an OpenAI chat completion response for vision models only: https://platform.openai.com/docs/guides/vision
Args:
messages (list): A list of messages comprising the conversation so far.
See https://platform.openai.com/docs/api-reference/chat/create for details on the format
model (str): ID of the model to use for generating chat completions. Refer to OpenAI's documentation
for details on available models.
client (OpenAI): An instance of the OpenAI client, used to make requests to the API.
max_tokens (int | None, optional): The maximum number of tokens to generate in the chat completion.
If None, defaults to the model's maximum context length. Defaults to None.
temperature (float, optional): Controls the randomness of the output. A higher temperature produces
more varied results, whereas a lower temperature results in more deterministic and predictable text.
Must be between 0 and 2. Defaults to 0.7.
seed (int | None, optional): A seed used for deterministic generation. Providing a seed ensures that
the same input will produce the same output across different runs. Defaults to None.
n (int, optional): The number of chat completion choices to generate for each input message.
Defaults to 1.
**kwargs (Any): Additional keyword arguments to pass to the OpenAI client chat completion method.
Returns:
dict[str, Any]: A dictionary containing the generated responses and metadata. Key components include:
'finish_reason' (str): The reason the model stopped generating further tokens.
                Can be 'stop' or 'length'.
            'message' (str): The content of the generated assistant message.
'choices' (list[dict], optional): A list of chat completion choices if n > 1 where each dict contains the above fields.
'completion_tokens' (int): The number of tokens used by the model to generate the completion.
                NOTE: If n > 1 this is the sum across all completions and thus will be the same value in each dict.
            'prompt_tokens' (int): The number of tokens in the prompt.
'system_fingerprint' (str, optional): If seed is set, a unique identifier for the model used to generate the response.
"""
kwargs.update(
{
"messages": messages,
"model": model,
"max_tokens": max_tokens,
"temperature": temperature,
"n": n,
}
)

if seed is not None:
kwargs["seed"] = seed

response = client.chat.completions.create(**kwargs)

response_data: dict[str, Any] = {"choices": []}
for response_choice in response.choices:
response_data_curr = {}
finish_reason = response_choice.finish_reason
response_data_curr["finish_reason"] = finish_reason

if finish_reason == "stop" or finish_reason == "length":
message = response_choice.message.content
response_data_curr["message"] = message

response_data["choices"].append(response_data_curr)

usage = response.usage
if usage is not None:
response_data["completion_tokens"] = usage.completion_tokens
response_data["prompt_tokens"] = usage.prompt_tokens

if seed is not None and response.system_fingerprint is not None:
response_data["system_fingerprint"] = response.system_fingerprint

if len(response_data["choices"]) == 1:
response_data.update(response_data["choices"][0])
del response_data["choices"]

return response_data
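For reference, a minimal usage sketch of the new function (the model id and image URL are assumptions for illustration; any vision-capable chat model and reachable image should work):

from openai import OpenAI

from not_again_ai.llm.chat_completion_vision import chat_completion_vision

client = OpenAI()
messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {
        "role": "user",
        "content": [
            {"type": "text", "text": "Describe this image in one sentence."},
            # Publicly reachable image URL, assumed for the example.
            {"type": "image_url", "image_url": {"url": "https://example.com/cat.jpg"}},
        ],
    },
]
response = chat_completion_vision(
    messages=messages,
    model="gpt-4-vision-preview",  # assumed vision-capable model id
    client=client,
    max_tokens=200,
)
print(response["message"])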
171 changes: 167 additions & 4 deletions src/not_again_ai/llm/prompts.py
@@ -1,3 +1,9 @@
import base64
from copy import deepcopy
import mimetypes
from pathlib import Path
from typing import Any

from liquid import Template


@@ -8,12 +14,12 @@ def _validate_message(message: dict[str, str]) -> bool:
valid_fields = ["role", "content", "name", "tool_call_id", "tool_calls"]
# Check if the only keys in the message are in valid_fields
if not all(key in valid_fields for key in message):
-        return False
+        raise ValueError(f"Message contains invalid fields: {message.keys()}")

# Check if the only roles in the message are in valid_fields
valid_roles = ["system", "user", "assistant", "tool"]
if message["role"] not in valid_roles:
-        return False
+        raise ValueError(f"Message contains invalid role: {message['role']}")

return True

@@ -46,12 +52,169 @@ def chat_prompt(messages_unformatted: list[dict[str, str]], variables: dict[str,
]
"""

-    messages_formatted = messages_unformatted.copy()
+    messages_formatted = deepcopy(messages_unformatted)
for message in messages_formatted:
if not _validate_message(message):
raise ValueError(f"Invalid message: {message}")
raise ValueError()

liquid_template = Template(message["content"])
message["content"] = liquid_template.render(**variables)

return messages_formatted


def _validate_message_vision(message: dict[str, list[dict[str, Path | str]] | str]) -> bool:
"""Validates that a message for a vision model is valid"""
valid_fields = ["role", "content"]
if not all(key in valid_fields for key in message):
raise ValueError(f"Message contains invalid fields: {message.keys()}")

valid_roles = ["system", "user", "assistant"]
if message["role"] not in valid_roles:
raise ValueError(f"Message contains invalid role: {message['role']}")

if not isinstance(message["content"], list) and not isinstance(message["content"], str):
raise ValueError(f"content must be a list of dictionaries or a string: {message['content']}")

if isinstance(message["content"], list):
for content_part in message["content"]:
if isinstance(content_part, dict):
if "image" not in content_part:
raise ValueError(f"Dictionary content part must contain 'image' key: {content_part}")
if "detail" in content_part and content_part["detail"] not in ["low", "high"]:
raise ValueError(f"Optional 'detail' key must be 'low' or 'high': {content_part['detail']}")
elif not isinstance(content_part, str):
raise ValueError(f"content_part must be a dictionary or a string: {content_part}")

return True


def encode_image(image_path: Path) -> str:
"""Encodes an image file at the given Path to base64.
Args:
image_path: The path to the image file to encode.
Returns:
The base64 encoded image as a string.
"""
with Path.open(image_path, "rb") as image_file:
return base64.b64encode(image_file.read()).decode("utf-8")


def create_image_url(image_path: Path) -> str:
"""Creates a data URL for an image file at the given Path.
Args:
image_path: The path to the image file to encode.
Returns:
The data URL for the image.
"""
image_data = encode_image(image_path)

valid_mime_types = ["image/jpeg", "image/png", "image/webp", "image/gif"]

# Get the MIME type from the image file extension
mime_type = mimetypes.guess_type(image_path)[0]

# Check if the MIME type is valid
# List of valid types is here: https://platform.openai.com/docs/guides/vision/what-type-of-files-can-i-upload
if mime_type not in valid_mime_types:
raise ValueError(f"Invalid MIME type for image: {mime_type}")

return f"data:{mime_type};base64,{image_data}"


def chat_prompt_vision(messages_unformatted: list[dict[str, Any]], variables: dict[str, str]) -> list[dict[str, Any]]:
"""Formats a list of messages for OpenAI's chat completion API for vision models only using Liquid templating.
Args:
messages_unformatted (list[dict[str, list[dict[str, Path | str]] | str]]):
A list of dictionaries where each dictionary represents a message.
Each message must have 'role' and 'content' keys. `role` must be 'system', 'user', or 'assistant'.
`content` can be a Liquid template string or a list of dictionaries where each dictionary
represents a content part. Each content part can be a string or a dictionary with 'image' and 'detail' keys.
The 'image' key must be a Path or a string representing a URL. The 'detail' key is optional and must be 'low' or 'high'.
variables: A dictionary where each key-value pair represents a variable
name and its value for template rendering.
Returns:
A list which represents messages in the format that OpenAI expects for its chat completions API.
See here for details: https://platform.openai.com/docs/api-reference/chat/create
Examples:
>>> # Assume cat_image and dog_image are Path objects to image files
>>> messages = [
... {"role": "system", "content": "You are a helpful assistant."},
... {
... "role": "user",
... "content": ["Describe the animal in the image in one word.", {"image": cat_image, "detail": "low"}],
... },
... {"role": "assistant", "content": "{{ answer }}"},
... {
... "role": "user",
... "content": ["What about this animal?", {"image": dog_image, "detail": "high"}],
... }
... ]
>>> vars = {"answer": "Cat"}
>>> chat_prompt_vision(messages, vars)
[
{"role": "system", "content": "You are a helpful assistant."},
{
"role": "user",
"content": [
{"type": "text", "text": "Describe the animal in the image in one word."},
{
"type": "image_url",
"image_url": {"url": f"data:image/jpeg;base64,<encoding>", "detail": "low"},
},
],
},
{"role": "assistant", "content": "Cat"},
{
"role": "user",
"content": [
{"type": "text", "text": "What about this animal?"},
{
"type": "image_url",
"image_url": {"url": f"data:image/jpeg;base64,<encoding>", "detail": "high"},
},
],
},
]
"""
messages_formatted = deepcopy(messages_unformatted)
for message in messages_formatted:
if not _validate_message_vision(message):
raise ValueError()

if isinstance(message["content"], list):
for i in range(len(message["content"])):
content_part = message["content"][i]
if isinstance(content_part, dict):
image_path = content_part["image"]
if isinstance(image_path, Path):
temp_content_part: dict[str, Any] = {
"type": "image_url",
"image_url": {
"url": create_image_url(image_path),
},
}
if "detail" in content_part:
temp_content_part["image_url"]["detail"] = content_part["detail"]
elif isinstance(image_path, str):
# Assume its a valid URL
pass
else:
raise ValueError(f"Image path must be a Path or str: {image_path}")
message["content"][i] = temp_content_part
elif isinstance(content_part, str):
message["content"][i] = {
"type": "text",
"text": Template(content_part).render(**variables),
}
elif isinstance(message["content"], str):
message["content"] = Template(message["content"]).render(**variables)

return messages_formatted
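Putting the two new pieces together, a sketch of the intended end-to-end flow (the model id is an assumption; the image path is one of the sample files added in this PR, and the message shape follows the docstring example above):

from pathlib import Path

from openai import OpenAI

from not_again_ai.llm.chat_completion_vision import chat_completion_vision
from not_again_ai.llm.prompts import chat_prompt_vision

cat_image = Path("tests/llm/sample_images/cat.jpg")  # sample image added in this PR
# Render the Liquid template and convert the local image to a base64 data URL.
messages = chat_prompt_vision(
    [
        {"role": "system", "content": "You are a helpful assistant."},
        {
            "role": "user",
            "content": ["Describe the animal in the image in one {{ unit }}.", {"image": cat_image, "detail": "low"}],
        },
    ],
    {"unit": "word"},
)
client = OpenAI()
response = chat_completion_vision(
    messages=messages,
    model="gpt-4-vision-preview",  # assumed vision-capable model id
    client=client,
    max_tokens=100,
)
print(response["message"])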
Binary file added tests/llm/sample_images/SKDiagram.png
Binary file added tests/llm/sample_images/SKInfographic.png
Binary file added tests/llm/sample_images/cat.jpg
Binary file added tests/llm/sample_images/dog.jpg
Binary file added tests/llm/sample_images/numbers.png
