diff --git a/.github/workflows/python.yml b/.github/workflows/python.yml index edca023..272dcdf 100644 --- a/.github/workflows/python.yml +++ b/.github/workflows/python.yml @@ -29,7 +29,7 @@ jobs: env: OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }} OPENAI_ORG_ID: ${{ secrets.OPENAI_ORG_ID }} - SKIP_TESTS_NAAI: "tests/llm/ollama" + SKIP_TESTS_NAAI: "tests/llm/ollama tests/llm/test_chat_completion.py" run: poetry run nox -s test-${{ matrix.python-version }} quality: runs-on: ubuntu-22.04 diff --git a/notebooks/llm/common_chat_completion.ipynb b/notebooks/llm/common_chat_completion.ipynb new file mode 100644 index 0000000..5ea6447 --- /dev/null +++ b/notebooks/llm/common_chat_completion.ipynb @@ -0,0 +1,156 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Common Chat Completion\n", + "\n", + "This notebook covers not-again-ai's common abstraction around multiple chat completion model providers. \n", + "\n", + "Currently the supported providers are the [OpenAI API](https://openai.com/api) and [Ollama](https://github.com/ollama/ollama)." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Choosing a Client" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "from not_again_ai.llm.ollama.ollama_client import ollama_client\n", + "from not_again_ai.llm.openai_api.openai_client import openai_client\n", + "\n", + "client_openai = openai_client()\n", + "client_ollama = ollama_client()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Define common variables we can try sending to different provider/model combinations." + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "messages = [\n", + " {\"role\": \"system\", \"content\": \"You are a helpful assistant.\"},\n", + " {\n", + " \"role\": \"user\",\n", + " \"content\": \"Generate a random number between 0 and 100 and structure the response using JSON.\",\n", + " },\n", + "]\n", + "\n", + "max_tokens = 200\n", + "temperature = 2\n", + "json_mode = True" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Use the OpenAI Client" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'message': {'randomNumber': 57},\n", + " 'completion_tokens': 10,\n", + " 'extras': {'prompt_tokens': 35, 'finish_reason': 'stop'}}" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from not_again_ai.llm.chat_completion import chat_completion\n", + "\n", + "chat_completion(\n", + " messages=messages,\n", + " model=\"gpt-3.5-turbo\",\n", + " client=client_openai,\n", + " max_tokens=max_tokens,\n", + " temperature=temperature,\n", + " json_mode=json_mode,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Use the Ollama Client" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'message': {'random_number': 47},\n", + " 'completion_tokens': None,\n", + " 'extras': {'response_duration': None}}" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "chat_completion(\n", + " messages=messages,\n", + " model=\"phi3\",\n", + " client=client_ollama,\n", + " max_tokens=max_tokens,\n", + " temperature=temperature,\n", + " json_mode=json_mode,\n", +
")" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": ".venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.4" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/notebooks/llm/gpt-4-v.ipynb b/notebooks/llm/gpt-4-v.ipynb index 9a07ed0..12c816d 100644 --- a/notebooks/llm/gpt-4-v.ipynb +++ b/notebooks/llm/gpt-4-v.ipynb @@ -109,7 +109,7 @@ "source": [ "from pathlib import Path\n", "\n", - "from not_again_ai.llm.openai.prompts import chat_prompt\n", + "from not_again_ai.llm.openai_api.prompts import chat_prompt\n", "\n", "sk_infographic = Path.cwd().parent.parent / \"tests\" / \"llm\" / \"sample_images\" / \"SKInfographic.png\"\n", "sk_diagram = Path.cwd().parent.parent / \"tests\" / \"llm\" / \"sample_images\" / \"SKDiagram.png\"\n", @@ -164,7 +164,7 @@ { "data": { "text/plain": [ - "ChatCompletion(id='chatcmpl-9LYbe6AAOTg0B1EPYQZjwPV0lYDuE', choices=[Choice(finish_reason='stop', index=0, logprobs=None, message=ChatCompletionMessage(content='Semantic Kernel integrates various AI services and plugins to process and contextualize inputs, enhancing the precision of its output by selecting optimal models and templates based on the contextualized information it gathers.', role='assistant', function_call=None, tool_calls=None))], created=1714924942, model='gpt-4-turbo-2024-04-09', object='chat.completion', system_fingerprint='fp_5d12056990', usage=CompletionUsage(completion_tokens=36, prompt_tokens=1565, total_tokens=1601))" + "ChatCompletion(id='chatcmpl-9MOpZiTg5uUclrUH9EbvY9xGVcCm0', choices=[Choice(finish_reason='stop', index=0, logprobs=None, message=ChatCompletionMessage(content='Semantic Kernel works by integrating multiple plugins and APIs to recall memory, create plans, and perform custom tasks while maintaining context, ultimately providing a tailored output based on specific AI services and monitored interactions.', role='assistant', function_call=None, tool_calls=None))], created=1715125693, model='gpt-4-turbo-2024-04-09', object='chat.completion', system_fingerprint='fp_5d12056990', usage=CompletionUsage(completion_tokens=38, prompt_tokens=1565, total_tokens=1603))" ] }, "execution_count": 2, @@ -207,10 +207,10 @@ "data": { "text/plain": [ "{'choices': [{'finish_reason': 'stop',\n", - " 'message': 'The Semantic Kernel works by selecting and invoking AI services based on a rendered prompt, which it uses to parse responses and create function results, integrating these processes with telemetry, monitoring, and event notifications for reliable and responsible AI applications.'},\n", + " 'message': 'The Semantic Kernel functions by selecting and invoking AI services based on a model, rendering prompts, and parsing responses, while managing reliability and telemetry through a centralized kernel that integrates with applications via event notifications and result creation.'},\n", " {'finish_reason': 'stop',\n", - " 'message': 'Semantic Kernel operates by selecting and invoking AI services based on rendered prompts, managing reliability and model selection, and parsing responses to create functional results, all coordinated through a central kernel that also handles event notifications and telemetry monitoring.'}],\n", - " 'completion_tokens': 88,\n", + " 'message': \"The Semantic Kernel functions by selecting and 
invoking AI services based on rendered prompts, managing responses through parsing and templating, and integrating reliability, telemetry, and monitoring within the kernel's architecture to ensure responsible AI usage.\"}],\n", + " 'completion_tokens': 84,\n", " 'prompt_tokens': 1565,\n", " 'system_fingerprint': 'fp_5d12056990'}" ] @@ -221,8 +221,8 @@ } ], "source": [ - "from not_again_ai.llm.openai.chat_completion import chat_completion\n", - "from not_again_ai.llm.openai.openai_client import openai_client\n", + "from not_again_ai.llm.openai_api.chat_completion import chat_completion\n", + "from not_again_ai.llm.openai_api.openai_client import openai_client\n", "\n", "client = openai_client()\n", "\n", @@ -241,7 +241,7 @@ "data": { "text/plain": [ "{'finish_reason': 'stop',\n", - " 'message': 'The Semantic Kernel works by selecting and invoking AI services based on a rendered prompt, which it uses to parse responses and create function results, integrating these processes with telemetry, monitoring, and event notifications for reliable and responsible AI applications.'}" + " 'message': 'The Semantic Kernel functions by selecting and invoking AI services based on a model, rendering prompts, and parsing responses, while managing reliability and telemetry through a centralized kernel that integrates with applications via event notifications and result creation.'}" ] }, "execution_count": 4, diff --git a/notebooks/llm/ollama_intro.ipynb b/notebooks/llm/ollama_intro.ipynb index 1a68262..0d24a6d 100644 --- a/notebooks/llm/ollama_intro.ipynb +++ b/notebooks/llm/ollama_intro.ipynb @@ -61,9 +61,9 @@ { "data": { "text/plain": [ - "{'message': \"Hello there! How can I assist you today? If you have any questions or need information on various topics, feel free to ask. I'm here to help!\",\n", - " 'completion_tokens': 35,\n", - " 'response_duration': 2.74299}" + "{'message': 'Hello there! 
How can I assist you today?\\n',\n", + " 'completion_tokens': 12,\n", + " 'response_duration': 2.48722}" ] }, "execution_count": 2, @@ -150,7 +150,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "[{'name': 'phi3:latest', 'model': 'phi3:latest', 'modified_at': '2024-05-05T18:06:07.598394632Z', 'size': 2318920898, 'size_readable': '2.16 GB', 'details': {'parent_model': '', 'format': 'gguf', 'family': 'llama', 'families': ['llama'], 'parameter_size': '4B', 'quantization_level': 'Q4_K_M'}}]\n" + "[{'name': 'llama3:8b', 'model': 'llama3:8b', 'modified_at': '2024-05-07T12:50:01.227242442Z', 'size': 4661224578, 'size_readable': '4.34 GB', 'details': {'parent_model': '', 'format': 'gguf', 'family': 'llama', 'families': ['llama'], 'parameter_size': '8B', 'quantization_level': 'Q4_0'}}, {'name': 'phi3:latest', 'model': 'phi3:latest', 'modified_at': '2024-05-05T18:06:59.995187156Z', 'size': 2318920898, 'size_readable': '2.16 GB', 'details': {'parent_model': '', 'format': 'gguf', 'family': 'llama', 'families': ['llama'], 'parameter_size': '4B', 'quantization_level': 'Q4_K_M'}}]\n" ] } ], @@ -192,7 +192,7 @@ "data": { "text/plain": [ "{'modelfile': '# Modelfile generated by \"ollama show\"\\n# To build a new Modelfile based on this one, replace the FROM line with:\\n# FROM phi3:latest\\n\\nFROM /usr/share/ollama/.ollama/models/blobs/sha256-4fed7364ee3e0c7cb4fe0880148bfdfcd1b630981efa0802a6b62ee52e7da97e\\nTEMPLATE \"\"\"{{ if .System }}<|system|>\\n{{ .System }}<|end|>\\n{{ end }}{{ if .Prompt }}<|user|>\\n{{ .Prompt }}<|end|>\\n{{ end }}<|assistant|>\\n{{ .Response }}<|end|>\\n\"\"\"\\nPARAMETER num_keep 4\\nPARAMETER stop \"<|user|>\"\\nPARAMETER stop \"<|assistant|>\"\\nPARAMETER stop \"<|system|>\"\\nPARAMETER stop \"<|end|>\"\\nPARAMETER stop \"<|endoftext|>\"',\n", - " 'parameters': 'num_keep 4\\nstop \"<|user|>\"\\nstop \"<|assistant|>\"\\nstop \"<|system|>\"\\nstop \"<|end|>\"\\nstop \"<|endoftext|>\"',\n", + " 'parameters': 'stop \"<|user|>\"\\nstop \"<|assistant|>\"\\nstop \"<|system|>\"\\nstop \"<|end|>\"\\nstop \"<|endoftext|>\"\\nnum_keep 4',\n", " 'template': '{{ if .System }}<|system|>\\n{{ .System }}<|end|>\\n{{ end }}{{ if .Prompt }}<|user|>\\n{{ .Prompt }}<|end|>\\n{{ end }}<|assistant|>\\n{{ .Response }}<|end|>\\n',\n", " 'details': {'parent_model': '',\n", " 'format': 'gguf',\n", diff --git a/notebooks/llm/openai_chat_completion.ipynb b/notebooks/llm/openai_chat_completion.ipynb index 82a5043..1d56190 100644 --- a/notebooks/llm/openai_chat_completion.ipynb +++ b/notebooks/llm/openai_chat_completion.ipynb @@ -26,7 +26,7 @@ "metadata": {}, "outputs": [], "source": [ - "from not_again_ai.llm.openai.openai_client import openai_client\n", + "from not_again_ai.llm.openai_api.openai_client import openai_client\n", "\n", "client = openai_client()" ] @@ -59,7 +59,7 @@ } ], "source": [ - "from not_again_ai.llm.openai.chat_completion import chat_completion\n", + "from not_again_ai.llm.openai_api.chat_completion import chat_completion\n", "\n", "messages = [{\"role\": \"system\", \"content\": \"You are a helpful assistant.\"}, {\"role\": \"user\", \"content\": \"Hello!\"}]\n", "response = chat_completion(messages=messages, model=\"gpt-3.5-turbo\", max_tokens=100, client=client)\n", @@ -98,7 +98,7 @@ } ], "source": [ - "from not_again_ai.llm.openai.prompts import chat_prompt\n", + "from not_again_ai.llm.openai_api.prompts import chat_prompt\n", "\n", "place_extraction_prompt = [\n", " {\n", @@ -147,7 +147,7 @@ } ], "source": [ - "from not_again_ai.llm.openai.tokens import 
num_tokens_from_messages\n", + "from not_again_ai.llm.openai_api.tokens import num_tokens_from_messages\n", "\n", "num_tokens = num_tokens_from_messages(messages=messages, model=\"gpt-3.5-turbo\")\n", "print(num_tokens)" @@ -180,8 +180,7 @@ " 'tool_names': ['get_current_weather'],\n", " 'tool_args_list': [{'location': 'Boston, MA', 'format': 'celsius'}]}],\n", " 'completion_tokens': 40,\n", - " 'prompt_tokens': 105,\n", - " 'system_fingerprint': 'fp_3b956da36b'}" + " 'prompt_tokens': 105}" ] }, "execution_count": 5, diff --git a/pyproject.toml b/pyproject.toml index 50a2b5f..640420f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "not-again-ai" -version = "0.7.0" +version = "0.8.0" description = "Designed to once and for all collect all the little things that come up over and over again in AI projects and put them in one place." authors = ["DaveCoDev "] license = "MIT" diff --git a/src/not_again_ai/llm/chat_completion.py b/src/not_again_ai/llm/chat_completion.py new file mode 100644 index 0000000..d2b1711 --- /dev/null +++ b/src/not_again_ai/llm/chat_completion.py @@ -0,0 +1,76 @@ +from typing import Any + +from ollama import Client +from openai import OpenAI + +from not_again_ai.llm.ollama import chat_completion as chat_completion_ollama +from not_again_ai.llm.openai_api import chat_completion as chat_completion_openai + + +def chat_completion( + messages: list[dict[str, Any]], + model: str, + client: OpenAI | Client, + max_tokens: int | None = None, + temperature: float = 0.7, + json_mode: bool = False, + seed: int | None = None, + **kwargs: Any, +) -> dict[str, Any]: + """Creates a common wrapper around chat completion models from different providers. + Currently supports the OpenAI API and Ollama local models. + All input parameters are supported by all providers in similar ways and the output is standardized. + + Args: + messages (list[dict[str, Any]]): A list of messages to send to the model. + model (str): The model name to use. + client (OpenAI | Client): The client object to use for chat completion. + max_tokens (int, optional): The maximum number of tokens to generate. + temperature (float, optional): The temperature of the model. Increasing the temperature will make the model answer more creatively. + json_mode (bool, optional): This will structure the response as a valid JSON object. + seed (int, optional): The seed to use for the model for reproducible outputs. + + Returns: + dict[str, Any]: A dictionary with the following keys: + message (str | dict): The content of the generated assistant message. + If json_mode is True, this will be a dictionary. + completion_tokens (int): The number of tokens used by the model to generate the completion. + extras (dict): This will contain any additional fields returned by the corresponding provider.
+ """ + # Determine which chat_completion function to call based on the client type + if isinstance(client, OpenAI): + response = chat_completion_openai.chat_completion( + messages=messages, + model=model, + client=client, + max_tokens=max_tokens, + temperature=temperature, + json_mode=json_mode, + seed=seed, + **kwargs, + ) + elif isinstance(client, Client): + response = chat_completion_ollama.chat_completion( + messages=messages, + model=model, + client=client, + max_tokens=max_tokens, + temperature=temperature, + json_mode=json_mode, + seed=seed, + **kwargs, + ) + else: + raise ValueError("Invalid client type") + + # Parse the responses to be consistent + response_data = {} + response_data["message"] = response.get("message", None) + response_data["completion_tokens"] = response.get("completion_tokens", None) + + # Return any additional fields from the response in an "extras" dictionary + extras = {k: v for k, v in response.items() if k not in response_data} + if extras: + response_data["extras"] = extras + + return response_data diff --git a/src/not_again_ai/llm/openai/__init__.py b/src/not_again_ai/llm/openai_api/__init__.py similarity index 100% rename from src/not_again_ai/llm/openai/__init__.py rename to src/not_again_ai/llm/openai_api/__init__.py diff --git a/src/not_again_ai/llm/openai/chat_completion.py b/src/not_again_ai/llm/openai_api/chat_completion.py similarity index 100% rename from src/not_again_ai/llm/openai/chat_completion.py rename to src/not_again_ai/llm/openai_api/chat_completion.py diff --git a/src/not_again_ai/llm/openai/context_management.py b/src/not_again_ai/llm/openai_api/context_management.py similarity index 97% rename from src/not_again_ai/llm/openai/context_management.py rename to src/not_again_ai/llm/openai_api/context_management.py index ce8a418..12c7efb 100644 --- a/src/not_again_ai/llm/openai/context_management.py +++ b/src/not_again_ai/llm/openai_api/context_management.py @@ -1,6 +1,6 @@ import copy -from not_again_ai.llm.openai.tokens import num_tokens_from_messages, truncate_str +from not_again_ai.llm.openai_api.tokens import num_tokens_from_messages, truncate_str def _inject_variable( diff --git a/src/not_again_ai/llm/openai/embeddings.py b/src/not_again_ai/llm/openai_api/embeddings.py similarity index 100% rename from src/not_again_ai/llm/openai/embeddings.py rename to src/not_again_ai/llm/openai_api/embeddings.py diff --git a/src/not_again_ai/llm/openai/openai_client.py b/src/not_again_ai/llm/openai_api/openai_client.py similarity index 100% rename from src/not_again_ai/llm/openai/openai_client.py rename to src/not_again_ai/llm/openai_api/openai_client.py diff --git a/src/not_again_ai/llm/openai/prompts.py b/src/not_again_ai/llm/openai_api/prompts.py similarity index 100% rename from src/not_again_ai/llm/openai/prompts.py rename to src/not_again_ai/llm/openai_api/prompts.py diff --git a/src/not_again_ai/llm/openai/tokens.py b/src/not_again_ai/llm/openai_api/tokens.py similarity index 100% rename from src/not_again_ai/llm/openai/tokens.py rename to src/not_again_ai/llm/openai_api/tokens.py diff --git a/tests/llm/ollama/test_chat_completion.py b/tests/llm/ollama/test_chat_completion.py index b6d16e3..41f16ba 100644 --- a/tests/llm/ollama/test_chat_completion.py +++ b/tests/llm/ollama/test_chat_completion.py @@ -4,42 +4,47 @@ from not_again_ai.llm.ollama.chat_completion import chat_completion from not_again_ai.llm.ollama.ollama_client import ollama_client -MODEL = "phi3" +MODELS = ["phi3", "llama3:8b"] -def test_chat_completion() -> None: 
+@pytest.fixture(params=MODELS) +def model(request): # type: ignore + return request.param + + +def test_chat_completion(model: str) -> None: client = ollama_client() messages = [ {"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": "Hello!"}, ] - response = chat_completion(messages, model=MODEL, client=client) + response = chat_completion(messages, model=model, client=client) print(response) -def test_chat_completion_max_tokens() -> None: +def test_chat_completion_max_tokens(model: str) -> None: client = ollama_client() messages = [ {"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": "Hello!"}, ] - response = chat_completion(messages, model=MODEL, client=client, max_tokens=2) + response = chat_completion(messages, model=model, client=client, max_tokens=2) print(response) -def test_chat_completion_context_window() -> None: +def test_chat_completion_context_window(model: str) -> None: client = ollama_client() messages = [ {"role": "user", "content": "Orange, kiwi, watermelon. List the three fruits I just named."}, ] - response = chat_completion(messages, model=MODEL, client=client, context_window=1, max_tokens=200) + response = chat_completion(messages, model=model, client=client, context_window=1, max_tokens=200) print(response) -def test_chat_completion_json_mode() -> None: +def test_chat_completion_json_mode(model: str) -> None: client = ollama_client() messages = [ { @@ -55,24 +60,24 @@ def test_chat_completion_json_mode() -> None: }, ] - response = chat_completion(messages, model=MODEL, client=client, json_mode=True, max_tokens=200) + response = chat_completion(messages, model=model, client=client, json_mode=True, max_tokens=200) print(response) -def test_chat_completion_seed() -> None: +def test_chat_completion_seed(model: str) -> None: client = ollama_client() messages = [ {"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": "Generate a random number between 0 and 100."}, ] - response1 = chat_completion(messages, model=MODEL, client=client, seed=6, temperature=2) - response2 = chat_completion(messages, model=MODEL, client=client, seed=6, temperature=2) + response1 = chat_completion(messages, model=model, client=client, seed=6, temperature=2) + response2 = chat_completion(messages, model=model, client=client, seed=6, temperature=2) assert response1["message"] == response2["message"] -def test_chat_completion_all() -> None: +def test_chat_completion_all(model: str) -> None: client = ollama_client() messages = [ {"role": "system", "content": "You are a helpful assistant."}, @@ -84,7 +89,7 @@ response = chat_completion( messages, - model=MODEL, + model=model, client=client, max_tokens=300, context_window=1000, diff --git a/tests/llm/openai/__init__.py b/tests/llm/open_ai/__init__.py similarity index 100% rename from tests/llm/openai/__init__.py rename to tests/llm/open_ai/__init__.py diff --git a/tests/llm/openai/test_chat_completion.py b/tests/llm/open_ai/test_chat_completion.py similarity index 99% rename from tests/llm/openai/test_chat_completion.py rename to tests/llm/open_ai/test_chat_completion.py index 3675ac5..51e0b5d 100644 --- a/tests/llm/openai/test_chat_completion.py +++ b/tests/llm/open_ai/test_chat_completion.py @@ -3,9 +3,9 @@ import pytest -from not_again_ai.llm.openai.chat_completion import chat_completion -from not_again_ai.llm.openai.openai_client import openai_client
-from not_again_ai.llm.openai.prompts import encode_image +from not_again_ai.llm.openai_api.chat_completion import chat_completion +from not_again_ai.llm.openai_api.openai_client import openai_client +from not_again_ai.llm.openai_api.prompts import encode_image image_dir = Path(__file__).parent.parent / "sample_images" cat_image = image_dir / "cat.jpg" diff --git a/tests/llm/openai/test_context_management.py b/tests/llm/open_ai/test_context_management.py similarity index 96% rename from tests/llm/openai/test_context_management.py rename to tests/llm/open_ai/test_context_management.py index 4477f68..21babf0 100644 --- a/tests/llm/openai/test_context_management.py +++ b/tests/llm/open_ai/test_context_management.py @@ -1,5 +1,5 @@ -from not_again_ai.llm.openai.context_management import priority_truncation -from not_again_ai.llm.openai.tokens import num_tokens_from_messages +from not_again_ai.llm.openai_api.context_management import priority_truncation +from not_again_ai.llm.openai_api.tokens import num_tokens_from_messages def test_priority_truncation_simple() -> None: diff --git a/tests/llm/openai/test_embeddings.py b/tests/llm/open_ai/test_embeddings.py similarity index 92% rename from tests/llm/openai/test_embeddings.py rename to tests/llm/open_ai/test_embeddings.py index 82189a4..47c41df 100644 --- a/tests/llm/openai/test_embeddings.py +++ b/tests/llm/open_ai/test_embeddings.py @@ -1,5 +1,5 @@ -from not_again_ai.llm.openai.embeddings import embed_text -from not_again_ai.llm.openai.openai_client import openai_client +from not_again_ai.llm.openai_api.embeddings import embed_text +from not_again_ai.llm.openai_api.openai_client import openai_client def test_embeddings_basic() -> None: diff --git a/tests/llm/openai/test_openai_client.py b/tests/llm/open_ai/test_openai_client.py similarity index 92% rename from tests/llm/openai/test_openai_client.py rename to tests/llm/open_ai/test_openai_client.py index a37c3a9..2e1caa6 100644 --- a/tests/llm/openai/test_openai_client.py +++ b/tests/llm/open_ai/test_openai_client.py @@ -3,7 +3,7 @@ from openai import OpenAI import pytest -from not_again_ai.llm.openai.openai_client import InvalidOAIAPITypeError, openai_client +from not_again_ai.llm.openai_api.openai_client import InvalidOAIAPITypeError, openai_client def test_openai_client_client_default_type() -> None: diff --git a/tests/llm/openai/test_prompts.py b/tests/llm/open_ai/test_prompts.py similarity index 98% rename from tests/llm/openai/test_prompts.py rename to tests/llm/open_ai/test_prompts.py index b63c02e..07cb58f 100644 --- a/tests/llm/openai/test_prompts.py +++ b/tests/llm/open_ai/test_prompts.py @@ -1,7 +1,7 @@ from pathlib import Path from typing import Any -from not_again_ai.llm.openai.prompts import chat_prompt, encode_image +from not_again_ai.llm.openai_api.prompts import chat_prompt, encode_image image_dir = Path(__file__).parent.parent / "sample_images" sk_infographic = image_dir / "SKInfographic.png" diff --git a/tests/llm/openai/test_tokens.py b/tests/llm/open_ai/test_tokens.py similarity index 95% rename from tests/llm/openai/test_tokens.py rename to tests/llm/open_ai/test_tokens.py index ff2f867..099632a 100644 --- a/tests/llm/openai/test_tokens.py +++ b/tests/llm/open_ai/test_tokens.py @@ -1,6 +1,6 @@ import pytest -from not_again_ai.llm.openai.tokens import num_tokens_from_messages, num_tokens_in_string, truncate_str +from not_again_ai.llm.openai_api.tokens import num_tokens_from_messages, num_tokens_in_string, truncate_str # Tests for truncate_str function diff --git 
a/tests/llm/test_chat_completion.py b/tests/llm/test_chat_completion.py new file mode 100644 index 0000000..6242101 --- /dev/null +++ b/tests/llm/test_chat_completion.py @@ -0,0 +1,25 @@ +from not_again_ai.llm.chat_completion import chat_completion +from not_again_ai.llm.ollama.ollama_client import ollama_client +from not_again_ai.llm.openai_api.openai_client import openai_client + + +def test_chat_completion_ollama() -> None: + client = ollama_client() + messages = [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "Hello!"}, + ] + + response = chat_completion(messages, model="phi3", client=client) + print(response) + + +def test_chat_completion_openai() -> None: + client = openai_client() + messages = [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "Hello!"}, + ] + + response = chat_completion(messages, model="gpt-3.5-turbo", client=client) + print(response)
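Note on the new common wrapper (src/not_again_ai/llm/chat_completion.py): it dispatches on the client type (OpenAI vs. ollama Client) and normalizes every response to the keys message, completion_tokens, and extras, with provider-specific fields such as finish_reason or response_duration collected under extras. The sketch below is not part of this patch; it shows how calling code might consume that standardized shape. It assumes a local Ollama server with the phi3 model pulled and an OPENAI_API_KEY in the environment, and the pick_client helper plus the NAAI_PROVIDER variable are hypothetical illustrations, not part of not-again-ai.

import os

from not_again_ai.llm.chat_completion import chat_completion
from not_again_ai.llm.ollama.ollama_client import ollama_client
from not_again_ai.llm.openai_api.openai_client import openai_client


def pick_client(provider: str):  # hypothetical helper, not part of the library
    """Map a provider name to a (client, default model) pair."""
    if provider == "openai":
        return openai_client(), "gpt-3.5-turbo"
    return ollama_client(), "phi3"


client, model = pick_client(os.environ.get("NAAI_PROVIDER", "ollama"))

messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Reply with a short greeting."},
]

response = chat_completion(messages=messages, model=model, client=client, max_tokens=50)

# The top-level keys are the same for both providers; anything provider-specific
# (finish_reason, prompt_tokens, response_duration, ...) ends up under "extras".
print(response["message"])
print(response.get("completion_tokens"))
print(response.get("extras", {}))

Because both providers return the same top-level keys, downstream code only needs to branch on the provider when it cares about something inside extras.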