Common LLM Wrapper (#7)
* add common llm wrapper, rename openai module to openai_api
DavidKoleczek authored May 8, 2024
1 parent ec94f4d commit c960c71
Showing 23 changed files with 310 additions and 46 deletions.
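
With the module rename in this commit, imports of the OpenAI helpers move from not_again_ai.llm.openai to not_again_ai.llm.openai_api; a minimal before/after sketch based on the diffs below (the function names themselves are unchanged):

# Before this commit:
# from not_again_ai.llm.openai.chat_completion import chat_completion
# from not_again_ai.llm.openai.openai_client import openai_client

# After this commit:
from not_again_ai.llm.openai_api.chat_completion import chat_completion
from not_again_ai.llm.openai_api.openai_client import openai_client
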
2 changes: 1 addition & 1 deletion .github/workflows/python.yml
@@ -29,7 +29,7 @@ jobs:
env:
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
OPENAI_ORG_ID: ${{ secrets.OPENAI_ORG_ID }}
SKIP_TESTS_NAAI: "tests/llm/ollama"
SKIP_TESTS_NAAI: "tests/llm/ollama tests/llm/test_chat_completion.py"
run: poetry run nox -s test-${{ matrix.python-version }}
quality:
runs-on: ubuntu-22.04
156 changes: 156 additions & 0 deletions notebooks/llm/common_chat_completion.ipynb
@@ -0,0 +1,156 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Common Chat Completion\n",
"\n",
"This notebook covers not-again-ai's common abstraction around multiple chat completion model providers. \n",
"\n",
"Currently the supported providers are the [OpenAI API](https://openai.com/api) and [Ollama](https://github.com/ollama/ollama)."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Choosing a Client"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"from not_again_ai.llm.ollama.ollama_client import ollama_client\n",
"from not_again_ai.llm.openai_api.openai_client import openai_client\n",
"\n",
"client_openai = openai_client()\n",
"client_ollama = ollama_client()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Define common variables we can try sending to different provider/model combinations."
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"messages = [\n",
" {\"role\": \"system\", \"content\": \"You are a helpful assistant.\"},\n",
" {\n",
" \"role\": \"user\",\n",
" \"content\": \"Generate a random number between 0 and 100 and structure the response in using JSON.\",\n",
" },\n",
"]\n",
"\n",
"max_tokens = 200\n",
"temperature = 2\n",
"json_mode = True"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Use the OpenAI Client"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'message': {'randomNumber': 57},\n",
" 'completion_tokens': 10,\n",
" 'extras': {'prompt_tokens': 35, 'finish_reason': 'stop'}}"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from not_again_ai.llm.chat_completion import chat_completion\n",
"\n",
"chat_completion(\n",
" messages=messages,\n",
" model=\"gpt-3.5-turbo\",\n",
" client=client_openai,\n",
" max_tokens=max_tokens,\n",
" temperature=temperature,\n",
" json_mode=json_mode,\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Use the Ollama Client"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'message': {'random_number': 47},\n",
" 'completion_tokens': None,\n",
" 'extras': {'response_duration': None}}"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"chat_completion(\n",
" messages=messages,\n",
" model=\"phi3\",\n",
" client=client_ollama,\n",
" max_tokens=max_tokens,\n",
" temperature=temperature,\n",
" json_mode=json_mode,\n",
")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.4"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
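
For reference outside the notebook above, a minimal sketch of consuming the standardized response dict returned by chat_completion; the model name, prompt, and parameter values here are illustrative, and only the keys under "extras" vary by provider:

from not_again_ai.llm.chat_completion import chat_completion
from not_again_ai.llm.openai_api.openai_client import openai_client

client = openai_client()  # assumes OPENAI_API_KEY (and optionally OPENAI_ORG_ID) are set
messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Generate a random number between 0 and 100 and structure the response using JSON."},
]

response = chat_completion(
    messages=messages,
    model="gpt-3.5-turbo",
    client=client,
    max_tokens=200,
    json_mode=True,
)

# Keys standardized across providers:
print(response["message"])            # a dict when json_mode=True, otherwise a str
print(response["completion_tokens"])  # may be None if the provider does not report it
# Provider-specific fields (e.g. finish_reason and prompt_tokens for OpenAI,
# response_duration for Ollama) are collected under "extras" when present.
print(response.get("extras", {}))
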
16 changes: 8 additions & 8 deletions notebooks/llm/gpt-4-v.ipynb
@@ -109,7 +109,7 @@
"source": [
"from pathlib import Path\n",
"\n",
"from not_again_ai.llm.openai.prompts import chat_prompt\n",
"from not_again_ai.llm.openai_api.prompts import chat_prompt\n",
"\n",
"sk_infographic = Path.cwd().parent.parent / \"tests\" / \"llm\" / \"sample_images\" / \"SKInfographic.png\"\n",
"sk_diagram = Path.cwd().parent.parent / \"tests\" / \"llm\" / \"sample_images\" / \"SKDiagram.png\"\n",
@@ -164,7 +164,7 @@
{
"data": {
"text/plain": [
"ChatCompletion(id='chatcmpl-9LYbe6AAOTg0B1EPYQZjwPV0lYDuE', choices=[Choice(finish_reason='stop', index=0, logprobs=None, message=ChatCompletionMessage(content='Semantic Kernel integrates various AI services and plugins to process and contextualize inputs, enhancing the precision of its output by selecting optimal models and templates based on the contextualized information it gathers.', role='assistant', function_call=None, tool_calls=None))], created=1714924942, model='gpt-4-turbo-2024-04-09', object='chat.completion', system_fingerprint='fp_5d12056990', usage=CompletionUsage(completion_tokens=36, prompt_tokens=1565, total_tokens=1601))"
"ChatCompletion(id='chatcmpl-9MOpZiTg5uUclrUH9EbvY9xGVcCm0', choices=[Choice(finish_reason='stop', index=0, logprobs=None, message=ChatCompletionMessage(content='Semantic Kernel works by integrating multiple plugins and APIs to recall memory, create plans, and perform custom tasks while maintaining context, ultimately providing a tailored output based on specific AI services and monitored interactions.', role='assistant', function_call=None, tool_calls=None))], created=1715125693, model='gpt-4-turbo-2024-04-09', object='chat.completion', system_fingerprint='fp_5d12056990', usage=CompletionUsage(completion_tokens=38, prompt_tokens=1565, total_tokens=1603))"
]
},
"execution_count": 2,
@@ -207,10 +207,10 @@
"data": {
"text/plain": [
"{'choices': [{'finish_reason': 'stop',\n",
" 'message': 'The Semantic Kernel works by selecting and invoking AI services based on a rendered prompt, which it uses to parse responses and create function results, integrating these processes with telemetry, monitoring, and event notifications for reliable and responsible AI applications.'},\n",
" 'message': 'The Semantic Kernel functions by selecting and invoking AI services based on a model, rendering prompts, and parsing responses, while managing reliability and telemetry through a centralized kernel that integrates with applications via event notifications and result creation.'},\n",
" {'finish_reason': 'stop',\n",
" 'message': 'Semantic Kernel operates by selecting and invoking AI services based on rendered prompts, managing reliability and model selection, and parsing responses to create functional results, all coordinated through a central kernel that also handles event notifications and telemetry monitoring.'}],\n",
" 'completion_tokens': 88,\n",
" 'message': \"The Semantic Kernel functions by selecting and invoking AI services based on rendered prompts, managing responses through parsing and templating, and integrating reliability, telemetry, and monitoring within the kernel's architecture to ensure responsible AI usage.\"}],\n",
" 'completion_tokens': 84,\n",
" 'prompt_tokens': 1565,\n",
" 'system_fingerprint': 'fp_5d12056990'}"
]
}
],
"source": [
"from not_again_ai.llm.openai.chat_completion import chat_completion\n",
"from not_again_ai.llm.openai.openai_client import openai_client\n",
"from not_again_ai.llm.openai_api.chat_completion import chat_completion\n",
"from not_again_ai.llm.openai_api.openai_client import openai_client\n",
"\n",
"client = openai_client()\n",
"\n",
Expand All @@ -241,7 +241,7 @@
"data": {
"text/plain": [
"{'finish_reason': 'stop',\n",
" 'message': 'The Semantic Kernel works by selecting and invoking AI services based on a rendered prompt, which it uses to parse responses and create function results, integrating these processes with telemetry, monitoring, and event notifications for reliable and responsible AI applications.'}"
" 'message': 'The Semantic Kernel functions by selecting and invoking AI services based on a model, rendering prompts, and parsing responses, while managing reliability and telemetry through a centralized kernel that integrates with applications via event notifications and result creation.'}"
]
},
"execution_count": 4,
10 changes: 5 additions & 5 deletions notebooks/llm/ollama_intro.ipynb
@@ -61,9 +61,9 @@
{
"data": {
"text/plain": [
"{'message': \"Hello there! How can I assist you today? If you have any questions or need information on various topics, feel free to ask. I'm here to help!\",\n",
" 'completion_tokens': 35,\n",
" 'response_duration': 2.74299}"
"{'message': 'Hello there! How can I assist you today?\\n',\n",
" 'completion_tokens': 12,\n",
" 'response_duration': 2.48722}"
]
},
"execution_count": 2,
@@ -150,7 +150,7 @@
"name": "stdout",
"output_type": "stream",
"text": [
"[{'name': 'phi3:latest', 'model': 'phi3:latest', 'modified_at': '2024-05-05T18:06:07.598394632Z', 'size': 2318920898, 'size_readable': '2.16 GB', 'details': {'parent_model': '', 'format': 'gguf', 'family': 'llama', 'families': ['llama'], 'parameter_size': '4B', 'quantization_level': 'Q4_K_M'}}]\n"
"[{'name': 'llama3:8b', 'model': 'llama3:8b', 'modified_at': '2024-05-07T12:50:01.227242442Z', 'size': 4661224578, 'size_readable': '4.34 GB', 'details': {'parent_model': '', 'format': 'gguf', 'family': 'llama', 'families': ['llama'], 'parameter_size': '8B', 'quantization_level': 'Q4_0'}}, {'name': 'phi3:latest', 'model': 'phi3:latest', 'modified_at': '2024-05-05T18:06:59.995187156Z', 'size': 2318920898, 'size_readable': '2.16 GB', 'details': {'parent_model': '', 'format': 'gguf', 'family': 'llama', 'families': ['llama'], 'parameter_size': '4B', 'quantization_level': 'Q4_K_M'}}]\n"
]
}
],
@@ -192,7 +192,7 @@
"data": {
"text/plain": [
"{'modelfile': '# Modelfile generated by \"ollama show\"\\n# To build a new Modelfile based on this one, replace the FROM line with:\\n# FROM phi3:latest\\n\\nFROM /usr/share/ollama/.ollama/models/blobs/sha256-4fed7364ee3e0c7cb4fe0880148bfdfcd1b630981efa0802a6b62ee52e7da97e\\nTEMPLATE \"\"\"{{ if .System }}<|system|>\\n{{ .System }}<|end|>\\n{{ end }}{{ if .Prompt }}<|user|>\\n{{ .Prompt }}<|end|>\\n{{ end }}<|assistant|>\\n{{ .Response }}<|end|>\\n\"\"\"\\nPARAMETER num_keep 4\\nPARAMETER stop \"<|user|>\"\\nPARAMETER stop \"<|assistant|>\"\\nPARAMETER stop \"<|system|>\"\\nPARAMETER stop \"<|end|>\"\\nPARAMETER stop \"<|endoftext|>\"',\n",
" 'parameters': 'num_keep 4\\nstop \"<|user|>\"\\nstop \"<|assistant|>\"\\nstop \"<|system|>\"\\nstop \"<|end|>\"\\nstop \"<|endoftext|>\"',\n",
" 'parameters': 'stop \"<|user|>\"\\nstop \"<|assistant|>\"\\nstop \"<|system|>\"\\nstop \"<|end|>\"\\nstop \"<|endoftext|>\"\\nnum_keep 4',\n",
" 'template': '{{ if .System }}<|system|>\\n{{ .System }}<|end|>\\n{{ end }}{{ if .Prompt }}<|user|>\\n{{ .Prompt }}<|end|>\\n{{ end }}<|assistant|>\\n{{ .Response }}<|end|>\\n',\n",
" 'details': {'parent_model': '',\n",
" 'format': 'gguf',\n",
11 changes: 5 additions & 6 deletions notebooks/llm/openai_chat_completion.ipynb
@@ -26,7 +26,7 @@
"metadata": {},
"outputs": [],
"source": [
"from not_again_ai.llm.openai.openai_client import openai_client\n",
"from not_again_ai.llm.openai_api.openai_client import openai_client\n",
"\n",
"client = openai_client()"
]
@@ -59,7 +59,7 @@
}
],
"source": [
"from not_again_ai.llm.openai.chat_completion import chat_completion\n",
"from not_again_ai.llm.openai_api.chat_completion import chat_completion\n",
"\n",
"messages = [{\"role\": \"system\", \"content\": \"You are a helpful assistant.\"}, {\"role\": \"user\", \"content\": \"Hello!\"}]\n",
"response = chat_completion(messages=messages, model=\"gpt-3.5-turbo\", max_tokens=100, client=client)\n",
@@ -98,7 +98,7 @@
}
],
"source": [
"from not_again_ai.llm.openai.prompts import chat_prompt\n",
"from not_again_ai.llm.openai_api.prompts import chat_prompt\n",
"\n",
"place_extraction_prompt = [\n",
" {\n",
@@ -147,7 +147,7 @@
}
],
"source": [
"from not_again_ai.llm.openai.tokens import num_tokens_from_messages\n",
"from not_again_ai.llm.openai_api.tokens import num_tokens_from_messages\n",
"\n",
"num_tokens = num_tokens_from_messages(messages=messages, model=\"gpt-3.5-turbo\")\n",
"print(num_tokens)"
@@ -180,8 +180,7 @@
" 'tool_names': ['get_current_weather'],\n",
" 'tool_args_list': [{'location': 'Boston, MA', 'format': 'celsius'}]}],\n",
" 'completion_tokens': 40,\n",
" 'prompt_tokens': 105,\n",
" 'system_fingerprint': 'fp_3b956da36b'}"
" 'prompt_tokens': 105}"
]
},
"execution_count": 5,
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
[tool.poetry]
name = "not-again-ai"
version = "0.7.0"
version = "0.8.0"
description = "Designed to once and for all collect all the little things that come up over and over again in AI projects and put them in one place."
authors = ["DaveCoDev <[email protected]>"]
license = "MIT"
76 changes: 76 additions & 0 deletions src/not_again_ai/llm/chat_completion.py
@@ -0,0 +1,76 @@
from typing import Any

from ollama import Client
from openai import OpenAI

from not_again_ai.llm.ollama import chat_completion as chat_completion_ollama
from not_again_ai.llm.openai_api import chat_completion as chat_completion_openai


def chat_completion(
    messages: list[dict[str, Any]],
    model: str,
    client: OpenAI | Client,
    max_tokens: int | None = None,
    temperature: float = 0.7,
    json_mode: bool = False,
    seed: int | None = None,
    **kwargs: Any,
) -> dict[str, Any]:
    """A common wrapper around chat completion models from different providers.
    Currently supports the OpenAI API and Ollama local models.
    All input parameters are supported by all providers in similar ways and the output is standardized.

    Args:
        messages (list[dict[str, Any]]): A list of messages to send to the model.
        model (str): The model name to use.
        client (OpenAI | Client): The client object to use for chat completion.
        max_tokens (int, optional): The maximum number of tokens to generate.
        temperature (float, optional): The temperature of the model. Increasing the temperature will make the model answer more creatively.
        json_mode (bool, optional): This will structure the response as a valid JSON object.
        seed (int, optional): The seed to use for the model for reproducible outputs.
        **kwargs (Any): Additional keyword arguments passed through to the provider-specific chat_completion function.

    Returns:
        dict[str, Any]: A dictionary with the following keys:
            message (str | dict): The content of the generated assistant message.
                If json_mode is True, this will be a dictionary.
            completion_tokens (int): The number of tokens used by the model to generate the completion.
            extras (dict): Any additional fields returned by the corresponding provider.
    """
    # Determine which chat_completion function to call based on the client type
    if isinstance(client, OpenAI):
        response = chat_completion_openai.chat_completion(
            messages=messages,
            model=model,
            client=client,
            max_tokens=max_tokens,
            temperature=temperature,
            json_mode=json_mode,
            seed=seed,
            **kwargs,
        )
    elif isinstance(client, Client):
        response = chat_completion_ollama.chat_completion(
            messages=messages,
            model=model,
            client=client,
            max_tokens=max_tokens,
            temperature=temperature,
            json_mode=json_mode,
            seed=seed,
            **kwargs,
        )
    else:
        raise ValueError("Invalid client type")

    # Parse the responses to be consistent
    response_data: dict[str, Any] = {}
    response_data["message"] = response.get("message", None)
    response_data["completion_tokens"] = response.get("completion_tokens", None)

    # Return any additional fields from the response in an "extras" dictionary
    extras = {k: v for k, v in response.items() if k not in response_data}
    if extras:
        response_data["extras"] = extras

    return response_data
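
The workflow change above skips tests/llm/test_chat_completion.py in CI. As a hypothetical sketch (not the repository's actual test file), a minimal test of the dispatch logic could exercise the ValueError branch without calling any provider:

import pytest

from not_again_ai.llm.chat_completion import chat_completion


def test_invalid_client_type() -> None:
    # Anything that is neither an openai.OpenAI nor an ollama.Client should be rejected.
    with pytest.raises(ValueError):
        chat_completion(
            messages=[{"role": "user", "content": "Hello!"}],
            model="gpt-3.5-turbo",
            client="not a client",  # type: ignore[arg-type]
        )
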
File renamed without changes.
@@ -1,6 +1,6 @@
import copy

from not_again_ai.llm.openai.tokens import num_tokens_from_messages, truncate_str
from not_again_ai.llm.openai_api.tokens import num_tokens_from_messages, truncate_str


def _inject_variable(
File renamed without changes.
File renamed without changes.
File renamed without changes.