Commit
* add common llm wrapper, rename openai module to openai_api
1 parent ec94f4d · commit c960c71
Showing 23 changed files with 310 additions and 46 deletions.
New file (Jupyter notebook):
@@ -0,0 +1,156 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Common Chat Completion\n",
"\n",
"This notebook covers not-again-ai's common abstraction around multiple chat completion model providers.\n",
"\n",
"Currently the supported providers are the [OpenAI API](https://openai.com/api) and [Ollama](https://github.com/ollama/ollama)."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Choosing a Client"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"from not_again_ai.llm.ollama.ollama_client import ollama_client\n",
"from not_again_ai.llm.openai_api.openai_client import openai_client\n",
"\n",
"client_openai = openai_client()\n",
"client_ollama = ollama_client()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Define common variables to send to different provider/model combinations"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"messages = [\n",
"    {\"role\": \"system\", \"content\": \"You are a helpful assistant.\"},\n",
"    {\n",
"        \"role\": \"user\",\n",
"        \"content\": \"Generate a random number between 0 and 100 and structure the response using JSON.\",\n",
"    },\n",
"]\n",
"\n",
"max_tokens = 200\n",
"temperature = 2\n",
"json_mode = True"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Use the OpenAI Client"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'message': {'randomNumber': 57},\n",
" 'completion_tokens': 10,\n",
" 'extras': {'prompt_tokens': 35, 'finish_reason': 'stop'}}"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from not_again_ai.llm.chat_completion import chat_completion\n",
"\n",
"chat_completion(\n",
"    messages=messages,\n",
"    model=\"gpt-3.5-turbo\",\n",
"    client=client_openai,\n",
"    max_tokens=max_tokens,\n",
"    temperature=temperature,\n",
"    json_mode=json_mode,\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Use the Ollama Client"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'message': {'random_number': 47},\n",
" 'completion_tokens': None,\n",
" 'extras': {'response_duration': None}}"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"chat_completion(\n",
"    messages=messages,\n",
"    model=\"phi3\",\n",
"    client=client_ollama,\n",
"    max_tokens=max_tokens,\n",
"    temperature=temperature,\n",
"    json_mode=json_mode,\n",
")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.4"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
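For quick reference outside the notebook, here is a minimal Python sketch that runs the same request through both provider/model pairs from the cells above and prints the standardized fields. It is not part of the commit; it assumes an OpenAI API key is configured and a local Ollama server with the phi3 model is available.

# Minimal sketch based on the notebook above (assumes both clients are configured).
from not_again_ai.llm.chat_completion import chat_completion
from not_again_ai.llm.ollama.ollama_client import ollama_client
from not_again_ai.llm.openai_api.openai_client import openai_client

messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {
        "role": "user",
        "content": "Generate a random number between 0 and 100 and structure the response using JSON.",
    },
]

# Same settings as the notebook cells above.
for client, model in [(openai_client(), "gpt-3.5-turbo"), (ollama_client(), "phi3")]:
    response = chat_completion(
        messages=messages,
        model=model,
        client=client,
        max_tokens=200,
        temperature=2,
        json_mode=True,
    )
    # The wrapper standardizes the top-level keys across providers.
    print(model, response["message"], response["completion_tokens"])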
pyproject.toml:
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "not-again-ai"
-version = "0.7.0"
+version = "0.8.0"
 description = "Designed to once and for all collect all the little things that come up over and over again in AI projects and put them in one place."
 authors = ["DaveCoDev <[email protected]>"]
 license = "MIT"
@@ -0,0 +1,76 @@ | ||
from typing import Any | ||
|
||
from ollama import Client | ||
from openai import OpenAI | ||
|
||
from not_again_ai.llm.ollama import chat_completion as chat_completion_ollama | ||
from not_again_ai.llm.openai_api import chat_completion as chat_completion_openai | ||
|
||
|
||
def chat_completion( | ||
messages: list[dict[str, Any]], | ||
model: str, | ||
client: OpenAI | Client, | ||
max_tokens: int | None = None, | ||
temperature: float = 0.7, | ||
json_mode: bool = False, | ||
seed: int | None = None, | ||
**kwargs: Any, | ||
) -> dict[str, Any]: | ||
"""Creates a common wrapper around chat completion models from different providers. | ||
Currently supports the OpenAI API and Ollama local models. | ||
All input parameters are supported by all providers in similar ways and the output is standardized. | ||
Args: | ||
messages (list[dict[str, Any]]): A list of messages to send to the model. | ||
model (str): The model name to use. | ||
client (OpenAI | Client): The client object to use for chat completion. | ||
max_tokens (int, optional): The maximum number of tokens to generate. | ||
temperature (float, optional): The temperature of the model. Increasing the temperature will make the model answer more creatively. | ||
json_mode (bool, optional): This will structure the response as a valid JSON object. | ||
seed (int, optional): The seed to use for the model for reproducible outputs. | ||
Returns: | ||
dict[str, Any]: A dictionary with the following keys | ||
message (str | dict): The content of the generated assistant message. | ||
If json_mode is True, this will be a dictionary. | ||
completion_tokens (int): The number of tokens used by the model to generate the completion. | ||
extras (dict): This will contain any additional fields returned by corresponding provider. | ||
""" | ||
# Determine which chat_completion function to call based on the client type | ||
if isinstance(client, OpenAI): | ||
response = chat_completion_openai.chat_completion( | ||
messages=messages, | ||
model=model, | ||
client=client, | ||
max_tokens=max_tokens, | ||
temperature=temperature, | ||
json_mode=json_mode, | ||
seed=seed, | ||
**kwargs, | ||
) | ||
elif isinstance(client, Client): | ||
response = chat_completion_ollama.chat_completion( | ||
messages=messages, | ||
model=model, | ||
client=client, | ||
max_tokens=max_tokens, | ||
temperature=temperature, | ||
json_mode=json_mode, | ||
seed=seed, | ||
**kwargs, | ||
) | ||
else: | ||
raise ValueError("Invalid client type") | ||
|
||
# Parse the responses to be consistent | ||
response_data = {} | ||
response_data["message"] = response.get("message", None) | ||
response_data["completion_tokens"] = response.get("completion_tokens", None) | ||
|
||
# Return any additional fields from the response in an "extras" dictionary | ||
extras = {k: v for k, v in response.items() if k not in response_data} | ||
if extras: | ||
response_data["extras"] = extras | ||
|
||
return response_data |
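A short sketch of how a caller might consume the standardized return value described in the docstring above. The extras keys named in the comments are taken from the notebook outputs (prompt_tokens and finish_reason for OpenAI, response_duration for Ollama) and vary by provider; the client setup is assumed.

# Sketch only (not part of the commit): consuming the standardized response.
from not_again_ai.llm.chat_completion import chat_completion
from not_again_ai.llm.openai_api.openai_client import openai_client

result = chat_completion(
    messages=[{"role": "user", "content": "Return the JSON object {\"ok\": true}."}],
    model="gpt-3.5-turbo",
    client=openai_client(),
    json_mode=True,
)

payload = result["message"]             # dict because json_mode=True, otherwise a str
used = result["completion_tokens"]      # may be None (the Ollama example above returns None)
extras = result.get("extras", {})       # provider-specific fields, e.g. prompt_tokens
finish_reason = extras.get("finish_reason")  # present for OpenAI, absent for Ollama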
File renamed without changes.
File renamed without changes.
2 changes: 1 addition & 1 deletion in ...again_ai/llm/openai/context_management.py → ...n_ai/llm/openai_api/context_management.py
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
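Because the openai module was renamed to openai_api, existing imports need a one-line change. The old paths below are an assumption that mirrors the new layout; only the new paths are confirmed by the notebook and wrapper in this commit.

# Before the rename (assumed old paths, not shown in this diff):
# from not_again_ai.llm.openai.openai_client import openai_client
# from not_again_ai.llm.openai.chat_completion import chat_completion

# After the rename (paths used elsewhere in this commit):
from not_again_ai.llm.openai_api.openai_client import openai_client
from not_again_ai.llm.openai_api.chat_completion import chat_completion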