diff --git a/.env.example b/.env.example
index 158b471..3653517 100644
--- a/.env.example
+++ b/.env.example
@@ -21,6 +21,10 @@ LLM_PROVIDER='ollama'
 OLLAMA_MODEL='dolphin-llama3:8b-v2.9-q4_K_M'
 # Smaller option
 # OLLAMA_MODEL='tinydolphin:1.1b-v2.8-q4_K_M'
+GROQ_API_KEY=
+GROQ_MODEL='llama3-8b-8192'
+OPENAI_API_KEY=
+OPENAI_MODEL='gpt-4o-2024-05-13'
 LLM_TEMPERATURE=0
 JWT_SECRET=
 SENTRY_DSN=
diff --git a/.github/workflows/push.yml b/.github/workflows/push.yml
index 201f383..1c4a839 100644
--- a/.github/workflows/push.yml
+++ b/.github/workflows/push.yml
@@ -72,11 +72,11 @@ jobs:
           docker rmi -f $(docker images -f "dangling=true" -q)
           docker volume rm -f $(docker volume ls -f "dangling=true" -q)
           # Update the service
-          docker compose pull backend gradio
-          docker compose stop backend gradio && docker compose up -d --wait
+          docker compose pull backend chat
+          docker compose stop backend chat && docker compose up -d --wait
           # Check update
           docker inspect -f '{{ .Created }}' $(docker compose images -q backend)
-          docker inspect -f '{{ .Created }}' $(docker compose images -q gradio)
+          docker inspect -f '{{ .Created }}' $(docker compose images -q chat)
           # Clean up
           docker rm -fv $(docker ps -aq)
           docker rmi -f $(docker images -f "dangling=true" -q)
diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
index 913766e..e62515d 100644
--- a/CONTRIBUTING.md
+++ b/CONTRIBUTING.md
@@ -77,10 +77,12 @@ If you are wondering how to do something with Companion API, or a more general q
 - [Git](https://git-scm.com/book/en/v2/Getting-Started-Installing-Git)
 - [Docker](https://docs.docker.com/engine/install/)
 - [Docker compose](https://docs.docker.com/compose/)
-- [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/latest/install-guide.html) and a GPU (>= 6 Gb VRAM for good performance/latency balance)
+- [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/latest/install-guide.html) and a GPU (>= 6 Gb VRAM for good performance/latency balance)*
 - [Poetry](https://python-poetry.org/docs/)
 - [Make](https://www.gnu.org/software/make/) (optional)
 
+_*If you don't have a GPU, you can use alternative LLM providers (currently supported: Groq, OpenAI)_
+
 ### Configure your fork
 
@@ -130,8 +132,12 @@ This file contains all the information to run the project.
 - `SUPERADMIN_PWD`: the password of the initial admin user
 
 #### Other optional values
-- `SECRET_KEY`: if set, tokens can be reused between sessions. All instances sharing the same secret key can use the same token.
+- `JWT_SECRET`: if set, tokens can be reused between sessions. All instances sharing the same secret key can use the same token.
 - `OLLAMA_MODEL`: the model tag in [Ollama library](https://ollama.com/library) that will be used for the API.
+- `GROQ_API_KEY`: your [Groq API KEY](https://console.groq.com/keys), required if you select `groq` as `LLM_PROVIDER`.
+- `GROQ_MODEL`: the model tag in [Groq supported models](https://console.groq.com/docs/models) that will be used for the API.
+- `OPENAI_API_KEY`: your [OpenAI API KEY](https://platform.openai.com/api-keys), required if you select `openai` as `LLM_PROVIDER`.
+- `OPENAI_MODEL`: the model tag in [OpenAI supported models](https://platform.openai.com/docs/models) that will be used for the API.
 - `SENTRY_DSN`: the DSN for your [Sentry](https://sentry.io/) project, which monitors back-end errors and report them back.
 - `SERVER_NAME`: the server tag that will be used to report events to Sentry.
 - `POSTHOG_HOST`: the host for PostHog [PostHog](https://eu.posthog.com/settings/project-details).
diff --git a/docker-compose.prod.yml b/docker-compose.prod.yml
index f57284c..98db081 100644
--- a/docker-compose.prod.yml
+++ b/docker-compose.prod.yml
@@ -72,7 +72,7 @@ services:
       - POSTGRES_URL=postgresql+asyncpg://${POSTGRES_USER}:${POSTGRES_PASSWORD}@${POSTGRES_HOST}:${POSTGRES_PORT}/${POSTGRES_DB}
       - SUPERADMIN_LOGIN=${SUPERADMIN_LOGIN}
       - SUPERADMIN_PWD=${SUPERADMIN_PWD}
-      - SECRET_KEY=${SECRET_KEY}
+      - JWT_SECRET=${JWT_SECRET}
       - OLLAMA_ENDPOINT=http://ollama:11434
       - OLLAMA_MODEL=${OLLAMA_MODEL}
       - OLLAMA_TIMEOUT=${OLLAMA_TIMEOUT:-60}
diff --git a/docker-compose.yml b/docker-compose.yml
index 0d9598c..28158ff 100644
--- a/docker-compose.yml
+++ b/docker-compose.yml
@@ -56,6 +56,8 @@ services:
       - OLLAMA_MODEL=${OLLAMA_MODEL}
      - GROQ_API_KEY=${GROQ_API_KEY}
       - GROQ_MODEL=${GROQ_MODEL}
+      - OPENAI_API_KEY=${OPENAI_API_KEY}
+      - OPENAI_MODEL=${OPENAI_MODEL}
       - OLLAMA_TIMEOUT=${OLLAMA_TIMEOUT:-60}
       - SUPPORT_EMAIL=${SUPPORT_EMAIL}
       - DEBUG=true
diff --git a/docs/developers/self-hosting.mdx b/docs/developers/self-hosting.mdx
index 0699c3f..aa03dc0 100644
--- a/docs/developers/self-hosting.mdx
+++ b/docs/developers/self-hosting.mdx
@@ -10,7 +10,7 @@ Whatever your installation method, you'll need at least the following to be inst
 1. [Docker](https://docs.docker.com/engine/install/) (and [Docker compose](https://docs.docker.com/compose/) if you're using an old version)
 2. [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/latest/install-guide.html) and a GPU
 
-_We recommend min 5Gb of VRAM on your GPU for good performance/latency balance._
+_We recommend min 5Gb of VRAM on your GPU for good performance/latency balance. Please note that by default, this will run your LLM locally (available offline), but if you don't have a GPU, you can use online LLM providers (currently supported: Groq, OpenAI)._
 
 ### 60 seconds setup ⏱️
diff --git a/poetry.lock b/poetry.lock
index 84d506c..02050eb 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -1552,6 +1552,29 @@ files = [
 [package.dependencies]
 httpx = ">=0.27.0,<0.28.0"
 
+[[package]]
+name = "openai"
+version = "1.30.1"
+description = "The official Python library for the openai API"
+optional = false
+python-versions = ">=3.7.1"
+files = [
+    {file = "openai-1.30.1-py3-none-any.whl", hash = "sha256:c9fb3c3545c118bbce8deb824397b9433a66d0d0ede6a96f7009c95b76de4a46"},
+    {file = "openai-1.30.1.tar.gz", hash = "sha256:4f85190e577cba0b066e1950b8eb9b11d25bc7ebcc43a86b326ce1bfa564ec74"},
+]
+
+[package.dependencies]
+anyio = ">=3.5.0,<5"
+distro = ">=1.7.0,<2"
+httpx = ">=0.23.0,<1"
+pydantic = ">=1.9.0,<3"
+sniffio = "*"
+tqdm = ">4"
+typing-extensions = ">=4.7,<5"
+
+[package.extras]
+datalib = ["numpy (>=1)", "pandas (>=1.2.3)", "pandas-stubs (>=1.1.0.11)"]
+
 [[package]]
 name = "orjson"
 version = "3.9.15"
@@ -2803,6 +2826,17 @@ files = [
 [package.dependencies]
 urllib3 = ">=2"
 
+[[package]]
+name = "types-urllib3"
+version = "1.26.25.14"
+description = "Typing stubs for urllib3"
+optional = false
+python-versions = "*"
+files = [
+    {file = "types-urllib3-1.26.25.14.tar.gz", hash = "sha256:229b7f577c951b8c1b92c1bc2b2fdb0b49847bd2af6d1cc2a2e3dd340f3bda8f"},
+    {file = "types_urllib3-1.26.25.14-py3-none-any.whl", hash = "sha256:9683bbb7fb72e32bfe9d2be6e04875fbe1b3eeec3cbb4ea231435aa7fd6b4f0e"},
+]
+
 [[package]]
 name = "typing-extensions"
 version = "4.8.0"
@@ -3022,4 +3056,4 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p
 [metadata]
 lock-version = "2.0"
 python-versions = "^3.9"
-content-hash = "5e3fc4ebd14893da2f70f02dea074e976425f3eda4634643859a8a8b32d48348"
+content-hash = "e534d24c22548f9876c27ffaa4dbf5afdb4123971a1edfbfb2fb3ff8b482b3a1"
diff --git a/pyproject.toml b/pyproject.toml
index e8a2b9b..ee9f586 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -30,6 +30,7 @@ posthog = "^3.0.0"
 prometheus-fastapi-instrumentator = "^6.1.0"
 groq = "^0.5.0"
 ollama = "^0.1.9"
+openai = "^1.29.0"
 uvloop = "^0.19.0"
 httptools = "^0.6.1"
 
@@ -40,6 +41,7 @@ optional = true
 ruff = "==0.4.4"
 mypy = "==1.10.0"
 types-requests = ">=2.0.0"
+types-urllib3 = ">=1.26.25"
 types-passlib = ">=1.7.0"
 pre-commit = "^3.6.0"
diff --git a/src/app/api/api_v1/endpoints/code.py b/src/app/api/api_v1/endpoints/code.py
index 5df78de..41f23d5 100644
--- a/src/app/api/api_v1/endpoints/code.py
+++ b/src/app/api/api_v1/endpoints/code.py
@@ -29,7 +29,7 @@ async def chat(
     guidelines: GuidelineCRUD = Depends(get_guideline_crud),
     token_payload: TokenPayload = Security(get_quack_jwt, scopes=[UserScope.ADMIN, UserScope.USER]),
 ) -> StreamingResponse:
-    telemetry_client.capture(token_payload.sub, event="compute-chat")
+    telemetry_client.capture(token_payload.sub, event="code-chat")
     # Validate payload
     if len(payload.messages) == 0:
         raise HTTPException(
diff --git a/src/app/core/config.py b/src/app/core/config.py
index 714a490..2369b11 100644
--- a/src/app/core/config.py
+++ b/src/app/core/config.py
@@ -53,7 +53,7 @@ def sqlachmey_uri(cls, v: str) -> str:
     GROQ_API_KEY: Union[str, None] = os.environ.get("GROQ_API_KEY")
     GROQ_MODEL: str = os.environ.get("GROQ_MODEL", "llama3-8b-8192")
     OPENAI_API_KEY: Union[str, None] = os.environ.get("OPENAI_API_KEY")
-    OPENAI_MODEL: str = os.environ.get("OPENAI_MODEL", "gpt-4-turbo-2024-04-09")
+    OPENAI_MODEL: str = os.environ.get("OPENAI_MODEL", "gpt-4o-2024-05-13")
 
     # Error monitoring
     SENTRY_DSN: Union[str, None] = os.environ.get("SENTRY_DSN")
diff --git a/src/app/schemas/services.py b/src/app/schemas/services.py
index 62919c3..37b8ee0 100644
--- a/src/app/schemas/services.py
+++ b/src/app/schemas/services.py
@@ -4,65 +4,21 @@
 # See LICENSE or go to <https://www.apache.org/licenses/LICENSE-2.0> for full license details.
 
 from enum import Enum
-from typing import Any, Dict, List, Union
+from typing import List, Union
 
 from pydantic import BaseModel, HttpUrl
 
 __all__ = ["ChatCompletion"]
 
 
-class OpenAIModel(str, Enum):
-    # https://platform.openai.com/docs/models/overview
-    GPT3_5_TURBO: str = "gpt-3.5-turbo-0125"
-    GPT3_5_TURBO_LEGACY: str = "gpt-3.5-turbo-1106"
-    GPT4_TURBO: str = "gpt-4-0125-preview"
-    GPT4_TURBO_LEGACY: str = "gpt-4-1106-preview"
-
-
-class OpenAIChatRole(str, Enum):
+class ChatRole(str, Enum):
     SYSTEM: str = "system"
     USER: str = "user"
     ASSISTANT: str = "assistant"
 
 
-class FieldSchema(BaseModel):
-    type: str
-    description: str
-
-
-class ObjectSchema(BaseModel):
-    type: str = "object"
-    properties: Dict[str, Any]
-    required: List[str]
-
-
-class ArraySchema(BaseModel):
-    type: str = "array"
-    items: ObjectSchema
-
-
-class OpenAIFunction(BaseModel):
-    name: str
-    description: str
-    parameters: ObjectSchema
-
-
-class OpenAITool(BaseModel):
-    type: str = "function"
-    function: OpenAIFunction
-
-
-class _FunctionName(BaseModel):
-    name: str
-
-
-class _OpenAIToolChoice(BaseModel):
-    type: str = "function"
-    function: _FunctionName
-
-
-class OpenAIMessage(BaseModel):
-    role: OpenAIChatRole
+class ChatMessage(BaseModel):
+    role: ChatRole
     content: str
 
 
@@ -72,9 +28,7 @@ class _ResponseFormat(BaseModel):
 
 class ChatCompletion(BaseModel):
     model: str
-    messages: List[OpenAIMessage]
-    functions: List[OpenAIFunction]
-    function_call: Dict[str, str]
+    messages: List[ChatMessage]
     temperature: float = 0.0
     frequency_penalty: float = 1.0
     response_format: _ResponseFormat = _ResponseFormat(type="json_object")
diff --git a/src/app/services/llm/groq.py b/src/app/services/llm/groq.py
index 15fdffc..983207a 100644
--- a/src/app/services/llm/groq.py
+++ b/src/app/services/llm/groq.py
@@ -11,6 +11,8 @@
 from groq import Groq, Stream
 from groq.lib.chat_completion_chunk import ChatCompletionChunk
 
+from .utils import CHAT_PROMPT
+
 logger = logging.getLogger("uvicorn.error")
 
 
@@ -20,12 +22,6 @@ class GroqModel(str, Enum):
     MIXTRAL_8X7b: str = "mixtral-8x7b-32768"
 
 
-CHAT_PROMPT = (
-    "You are an AI programming assistant, developed by the company Quack AI, and you only answer questions related to computer science "
-    "(refuse to answer for the rest)."
-)
-
-
 class GroqClient:
     def __init__(
         self,
diff --git a/src/app/services/llm/llm.py b/src/app/services/llm/llm.py
index 3497975..43c4330 100644
--- a/src/app/services/llm/llm.py
+++ b/src/app/services/llm/llm.py
@@ -3,16 +3,40 @@
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://www.apache.org/licenses/LICENSE-2.0> for full license details.
 
+import re
 from enum import Enum
-from typing import Union
+from typing import Dict, Union
+
+from fastapi import HTTPException, status
 
 from app.core.config import settings
 
 from .groq import GroqClient
 from .ollama import OllamaClient
+from .openai import OpenAIClient
 
 __all__ = ["llm_client"]
 
+EXAMPLE_PROMPT = (
+    "You are responsible for producing concise illustrations of the company coding guidelines. "
+    "This will be used to teach new developers our way of engineering software. "
+    "Make sure your code is in the specified programming language and functional, don't add extra comments or explanations.\n"
+    # Format
+    "You should output two code blocks: "
+    "a minimal code snippet where the instruction was correctly followed, "
+    "and the same snippet with minimal modifications that invalidates the instruction."
+)
+# Strangely, this doesn't work when compiled
+EXAMPLE_PATTERN = r"```[a-zA-Z]*\n(?P<positive>.*?)```\n.*```[a-zA-Z]*\n(?P<negative>.*?)```"
+
+
+def validate_example_response(response: str) -> Dict[str, str]:
+    matches = re.search(EXAMPLE_PATTERN, response.strip(), re.DOTALL)
+    if matches is None:
+        raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Failed output schema validation")
+
+    return matches.groupdict()
+
 
 class LLMProvider(str, Enum):
     OLLAMA: str = "ollama"
@@ -20,7 +44,7 @@ class LLMProvider(str, Enum):
     GROQ: str = "groq"
 
 
-llm_client: Union[OllamaClient, GroqClient]
+llm_client: Union[OllamaClient, GroqClient, OpenAIClient]
 if settings.LLM_PROVIDER == LLMProvider.OLLAMA:
     if not settings.OLLAMA_ENDPOINT:
         raise ValueError("Please provide a value for `OLLAMA_ENDPOINT`")
@@ -29,5 +53,9 @@ class LLMProvider(str, Enum):
     if not settings.GROQ_API_KEY:
         raise ValueError("Please provide a value for `GROQ_API_KEY`")
     llm_client = GroqClient(settings.GROQ_API_KEY, settings.GROQ_MODEL, settings.LLM_TEMPERATURE)  # type: ignore[arg-type]
+elif settings.LLM_PROVIDER == LLMProvider.OPENAI:
+    if not settings.OPENAI_API_KEY:
+        raise ValueError("Please provide a value for `OPENAI_API_KEY`")
+    llm_client = OpenAIClient(settings.OPENAI_API_KEY, settings.OPENAI_MODEL, settings.LLM_TEMPERATURE)  # type: ignore[arg-type]
 else:
     raise NotImplementedError("LLM provider is not implemented")
diff --git a/src/app/services/llm/ollama.py b/src/app/services/llm/ollama.py
index 3e42f1b..7f6660b 100644
--- a/src/app/services/llm/ollama.py
+++ b/src/app/services/llm/ollama.py
@@ -3,73 +3,16 @@
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://www.apache.org/licenses/LICENSE-2.0> for full license details.
 
-import json
 import logging
-import re
-from typing import Dict, Generator, List, TypeVar, Union
+from typing import Dict, Generator, List, Union
 
-from fastapi import HTTPException, status
 from ollama import Client
 
-logger = logging.getLogger("uvicorn.error")
-
-ValidationOut = TypeVar("ValidationOut")
-
-
-EXAMPLE_PROMPT = (
-    "You are responsible for producing concise illustrations of the company coding guidelines. "
-    "This will be used to teach new developers our way of engineering software. "
-    "Make sure your code is in the specified programming language and functional, don't add extra comments or explanations.\n"
-    # Format
-    "You should output two code blocks: "
-    "a minimal code snippet where the instruction was correctly followed, "
-    "and the same snippet with minimal modifications that invalidates the instruction."
-)
-# Strangely, this doesn't work when compiled
-EXAMPLE_PATTERN = r"```[a-zA-Z]*\n(?P<positive>.*?)```\n.*```[a-zA-Z]*\n(?P<negative>.*?)```"
-
-
-def validate_example_response(response: str) -> Dict[str, str]:
-    matches = re.search(EXAMPLE_PATTERN, response.strip(), re.DOTALL)
-    if matches is None:
-        raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Failed output schema validation")
-
-    return matches.groupdict()
+from .utils import CHAT_PROMPT
 
+__all__ = ["OllamaClient"]
 
-PARSING_PROMPT = (
-    "You are responsible for summarizing the list of distinct coding guidelines for the company, by going through documentation. "
-    "This list will be used by developers to avoid hesitations in code reviews and to onboard new members. "
-    "Consider only guidelines that can be verified for a specific snippet of code (nothing about git, commits or community interactions) "
-    "by a human developer without running additional commands or tools, it should only relate to the code within each file. "
-    "Only include guidelines for which you could generate positive and negative code snippets, "
-    "don't invent anything that isn't present in the input text.\n"
-    # Format
-    "You should answer with a list of JSON dictionaries, one dictionary per guideline, where each dictionary has two keys with string values:\n"
-    "- title: a short summary title of the guideline\n"
-    "- details: a descriptive, comprehensive and inambiguous explanation of the guideline."
-)
-PARSING_PATTERN = r"\{\s*\"title\":\s+\"(?P<title>.*?)\",\s+\"details\":\s+\"(?P<details>.*?)\"\s*\}"
-
-
-CHAT_PROMPT = (
-    "You are an AI programming assistant, developed by the company Quack AI, and you only answer questions related to computer science "
-    "(refuse to answer for the rest)."
-)
-
-GUIDELINE_PROMPT = (
-    "When answering user requests, you should at all times keep in mind the following software development guidelines:"
-)
-
-
-def validate_parsing_response(response: str) -> List[Dict[str, str]]:
-    guideline_list = json.loads(response.strip())
-    if not isinstance(guideline_list, list) or any(
-        not isinstance(val, str) for guideline in guideline_list for val in guideline.values()
-    ):
-        raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Failed output schema validation")
-
-    return json.loads(response.strip())
+logger = logging.getLogger("uvicorn.error")
 
 
 class OllamaClient:
diff --git a/src/app/services/llm/openai.py b/src/app/services/llm/openai.py
index 07eedf0..0b7726f 100644
--- a/src/app/services/llm/openai.py
+++ b/src/app/services/llm/openai.py
@@ -3,386 +3,69 @@
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://www.apache.org/licenses/LICENSE-2.0> for full license details.
 
-import json
 import logging
-from concurrent.futures import ThreadPoolExecutor
 from datetime import datetime
 from enum import Enum
-from typing import Any, Dict, List, Type, TypeVar, Union
+from typing import Dict, Generator, List, Union, cast
 
-import requests
-from fastapi import HTTPException, status
-from pydantic import ValidationError
+from openai import OpenAI, Stream
+from openai.types.chat import ChatCompletionChunk
 
-from app.models import Guideline
-from app.schemas.code import ComplianceResult
-from app.schemas.guidelines import GuidelineContent, GuidelineExample
-from app.schemas.services import (
-    ArraySchema,
-    ChatCompletion,
-    FieldSchema,
-    ObjectSchema,
-    OpenAIChatRole,
-    OpenAIFunction,
-    OpenAIMessage,
-    OpenAIModel,
-)
+from .utils import CHAT_PROMPT
 
 logger = logging.getLogger("uvicorn.error")
 
-# __all__ = ["openai_client"]
-
-class ExecutionMode(str, Enum):
-    SINGLE = "single-thread"
-    MULTI = "multi-thread"
-
-
-MONO_SCHEMA = ObjectSchema(
-    type="object",
-    properties={
-        "is_compliant": FieldSchema(
-            type="boolean",
-            description="whether the guideline has been followed in the code snippet",
-        ),
-        "comment": FieldSchema(
-            type="string",
-            description="short instruction to make the snippet compliant, addressed to the developer who wrote the code. Should be empty if the snippet is compliant",
-        ),
-        # "suggestion": FieldSchema(type="string", description="the modified code snippet that meets the guideline, with minimal modifications. Should be empty if the snippet is compliant"),
-    },
-    required=["is_compliant", "comment"],
-)
-
-MONO_PROMPT = (
-    "As a code compliance agent, you are going to receive requests from user with two elements: a code snippet, and a guideline. "
-    "You should answer in JSON format with an analysis result. The comment should be an empty string if the code is compliant with the guideline."
-)
-
-
-MULTI_SCHEMA = ObjectSchema(
-    type="object",
-    properties={
-        "result": ArraySchema(
-            type="array",
-            items=ObjectSchema(
-                type="object",
-                properties={
-                    "is_compliant": FieldSchema(
-                        type="boolean",
-                        description="whether the guideline has been followed in the code snippet",
-                    ),
-                    "comment": FieldSchema(
-                        type="string",
-                        description="short instruction to make the snippet compliant, addressed to the developer who wrote the code. Should be empty if the snippet is compliant",
-                    ),
-                    # "suggestion": FieldSchema(type="string", description="the modified code snippet that meets the guideline, with minimal modifications. Should be empty if the snippet is compliant"),
-                },
-                required=["is_compliant", "comment"],
-            ),
-        ),
-    },
-    required=["result"],
-)
-
-MULTI_PROMPT = (
-    "As a code compliance agent, you are going to receive requests from user with two elements: a code snippet, and a list of guidelines. "
-    "You should answer in JSON format with a list of compliance results, one for each guideline, no more, no less (in the same order). "
-    "For a given compliance results, the comment should be an empty string if the code is compliant with the corresponding guideline."
-)
-
-PARSING_SCHEMA = ObjectSchema(
-    type="object",
-    properties={
-        "result": ArraySchema(
-            type="array",
-            items=ObjectSchema(
-                type="object",
-                properties={
-                    "source": FieldSchema(
-                        type="string",
-                        description="the text section the guideline was extracted from",
-                    ),
-                    "category": {
-                        "type": "string",
-                        "description": "the high-level category of the guideline",
-                        "enum": [
-                            "naming",
-                            "error handling",
-                            "syntax",
-                            "comments",
-                            "docstring",
-                            "documentation",
-                            "testing",
-                            "signature",
-                            "type hint",
-                            "formatting",
-                            "other",
-                        ],
-                    },
-                    "details": FieldSchema(
-                        type="string",
-                        description="a descriptive, comprehensive, inambiguous and specific explanation of the guideline.",
-                    ),
-                    "title": FieldSchema(type="string", description="a summary title of the guideline"),
-                    "examples": ObjectSchema(
-                        type="object",
-                        properties={
-                            "positive": FieldSchema(
-                                type="string",
-                                description="a short code snippet where the guideline was correctly followed.",
-                            ),
-                            "negative": FieldSchema(
-                                type="string",
-                                description="the same snippet with minimal modification that invalidate the instruction.",
-                            ),
-                        },
-                        required=["positive", "negative"],
-                    ),
-                },
-                required=["category", "details", "title", "examples"],
-            ),
-        ),
-    },
-    required=["result"],
-)
-
-PARSING_PROMPT = (
-    "You are responsible for summarizing the list of distinct coding guidelines for the company, by going through documentation. "
-    "This list will be used by developers to avoid hesitations in code reviews and to onboard new members. "
-    "Consider only guidelines that can be verified for a specific snippet of code (nothing about git, commits or community interactions) "
-    "by a human developer without running additional commands or tools, it should only relate to the code within each file. "
-    "Only include guidelines for which you could generate positive and negative code snippets, "
-    "don't invent anything that isn't present in the input text or someone will die. "
-    "You should answer in JSON format with only the list of string guidelines."
-)
-
-EXAMPLE_SCHEMA = ObjectSchema(
-    type="object",
-    properties={
-        "positive": FieldSchema(
-            type="string",
-            description="a short code snippet where the instruction was correctly followed.",
-        ),
-        "negative": FieldSchema(
-            type="string",
-            description="the same snippet with minimal modification that invalidate the instruction.",
-        ),
-    },
-    required=["positive", "negative"],
-)
-
-EXAMPLE_PROMPT = (
-    "You are responsible for producing concise code snippets to illustrate the company coding guidelines. "
-    "This will be used to teach new developers our way of engineering software. "
-    "You should answer in JSON format with only two short code snippets in the specified programming language: one that follows the rule correctly, "
-    "and a similar version with minimal modifications that violates the rule. "
-    "Make sure your code is functional, don't extra comments or explanation, or someone will die."
-)
-
-ModelInp = TypeVar("ModelInp")
-
-
-def validate_model(model: Type[ModelInp], data: Dict[str, Any]) -> Union[ModelInp, None]:
-    try:
-        return model(**data)
-    except (ValidationError, TypeError):
-        return None
+class OpenAIModel(str, Enum):
+    # https://platform.openai.com/docs/models/overview
+    GPT4o: str = "gpt-4o-2024-05-13"
+    GPT3_5: str = "gpt-3.5-turbo-0125"
 
 
 class OpenAIClient:
-    ENDPOINT: str = "https://api.openai.com/v1/chat/completions"
-
     def __init__(
         self,
         api_key: str,
         model: OpenAIModel,
         temperature: float = 0.0,
-        frequency_penalty: float = 1.0,
     ) -> None:
-        self.headers = self._get_headers(api_key)
+        self._client = OpenAI(api_key=api_key)
         # Validate model
-        model_card = requests.get(f"https://api.openai.com/v1/models/{model}", headers=self.headers, timeout=5)
-        if model_card.status_code != 200:
-            raise HTTPException(status_code=model_card.status_code, detail=model_card.json()["error"]["message"])
+        model_card = self._client.models.retrieve(model)
         self.model = model
         self.temperature = temperature
-        self.frequency_penalty = frequency_penalty
         logger.info(
-            f"Using OpenAI model: {self.model} (created at {datetime.fromtimestamp(model_card.json()['created']).isoformat()})",
+            f"Using OpenAI w/ {self.model} (created at {datetime.fromtimestamp(model_card.created).isoformat()})",
         )
 
-    @staticmethod
-    def _get_headers(api_key: str) -> Dict[str, str]:
-        return {"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"}
-
-    def check_code_against_guidelines(
-        self,
-        code: str,
-        guidelines: List[Guideline],
-        mode: ExecutionMode = ExecutionMode.SINGLE,
-        **kwargs,
-    ) -> List[ComplianceResult]:
-        # Check args before sending a request
-        if len(code) == 0 or len(guidelines) == 0 or any(len(guideline.content) == 0 for guideline in guidelines):
-            raise HTTPException(
-                status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
-                detail="No code or guideline provided for analysis.",
-            )
-        # Ideas: check which programming language & whether it's correct code
-        if mode == ExecutionMode.SINGLE:
-            parsed_response = self._analyze(
-                MULTI_PROMPT,
-                {"code": code, "guidelines": [guideline.content for guideline in guidelines]},
-                MULTI_SCHEMA,
-                **kwargs,
-            )["result"]
-            if len(parsed_response) != len(guidelines):
-                raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Invalid model response")
-        elif mode == ExecutionMode.MULTI:
-            with ThreadPoolExecutor() as executor:
-                tasks = [
-                    executor.submit(
-                        self._analyze,
-                        MONO_PROMPT,
-                        {"code": code, "guideline": guideline.content},
-                        MONO_SCHEMA,
-                        **kwargs,
-                    )
-                    for guideline in guidelines
-                ]
-                # Collect results
-                parsed_response = [task.result() for task in tasks]
-        else:
-            raise ValueError("invalid value for argument `mode`")
-
-        # Return with pydantic validation
-        return [
-            # ComplianceResult(is_compliant=res["is_compliant"], comment="" if res["is_compliant"] else res["comment"])
-            ComplianceResult(guideline_id=guideline.id, **res)
-            for guideline, res in zip(guidelines, parsed_response)
-        ]
-
-    def check_code(self, code: str, guideline: Guideline, **kwargs) -> ComplianceResult:
-        # Check args before sending a request
-        if len(code) == 0 or len(guideline.content) == 0:
-            raise HTTPException(
-                status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
-                detail="No code or guideline provided for analysis.",
-            )
-        res = self._analyze(MONO_PROMPT, {"code": code, "guideline": guideline.content}, MONO_SCHEMA, **kwargs)
-        # Return with pydantic validation
-        return ComplianceResult(guideline_id=guideline.id, **res)
-
-    def _request(
+    def chat(
         self,
-        system_prompt: str,
-        openai_fn: OpenAIFunction,
-        message: str,
-        timeout: int = 20,
-        model: Union[OpenAIModel, None] = None,
-        user_id: Union[str, None] = None,
-    ) -> Dict[str, Any]:
+        messages: List[Dict[str, str]],
+        system: Union[str, None] = None,
+    ) -> Generator[str, None, None]:
         # Prepare the request
-        _payload = ChatCompletion(
-            model=model or self.model,
-            messages=[
-                OpenAIMessage(
-                    role=OpenAIChatRole.SYSTEM,
-                    content=system_prompt,
+        _system = CHAT_PROMPT if not system else f"{CHAT_PROMPT} {system}"
+        stream = cast(
+            Stream[ChatCompletionChunk],
+            self._client.chat.completions.create(  # type: ignore[call-overload]
+                messages=(
+                    {"role": "system", "content": _system},
+                    *messages,
                 ),
-                OpenAIMessage(
-                    role=OpenAIChatRole.USER,
-                    content=message,
-                ),
-            ],
-            functions=[openai_fn],
-            function_call={"name": openai_fn.name},
-            temperature=self.temperature,
-            frequency_penalty=self.frequency_penalty,
-            user=user_id,
-        )
-        # Send the request
-        response = requests.post(self.ENDPOINT, json=_payload.model_dump(), headers=self.headers, timeout=timeout)
-
-        # Check status
-        if response.status_code != 200:
-            raise HTTPException(status_code=response.status_code, detail=response.json()["error"]["message"])
-
-        return json.loads(response.json()["choices"][0]["message"]["function_call"]["arguments"])
-
-    def _analyze(
-        self,
-        prompt: str,
-        payload: Dict[str, Any],
-        schema: ObjectSchema,
-        **kwargs,
-    ) -> Dict[str, Any]:
-        return self._request(
-            prompt,
-            OpenAIFunction(
-                name="analyze_code",
-                description="Analyze code",
-                parameters=schema,
+                model=self.model,
+                # Optional
+                temperature=self.temperature,
+                max_tokens=2048,
+                top_p=1,
+                stop=None,
+                stream=True,
+                stream_options={"include_usage": True},
             ),
-            json.dumps(payload),
-            **kwargs,
-        )
-
-    def parse_guidelines_from_text(
-        self,
-        corpus: str,
-        **kwargs,
-    ) -> List[GuidelineContent]:
-        if not isinstance(corpus, str):
-            raise HTTPException(
-                status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
-                detail="The input corpus needs to be a string.",
-            )
-        if len(corpus) == 0:
-            return []
-
-        response = self._request(
-            PARSING_PROMPT,
-            OpenAIFunction(
-                name="validate_guidelines_from_text",
-                description="Validate extracted coding guidelines from corpus",
-                parameters=PARSING_SCHEMA,
-            ),
-            json.dumps(corpus),
-            **kwargs,
-        )
-        guidelines = [validate_model(GuidelineContent, elt) for elt in response["result"]]
-        if any(guideline is None for guideline in guidelines):
-            logger.info("Validation errors on some guidelines")
-        return [guideline for guideline in guidelines if guideline is not None]
-
-    def generate_examples_for_instruction(
-        self,
-        instruction: str,
-        language: str,
-        **kwargs,
-    ) -> GuidelineExample:
-        if (
-            not isinstance(instruction, str)
-            or len(instruction) == 0
-            or not isinstance(language, str)
-            or len(language) == 0
-        ):
-            raise HTTPException(
-                status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
-                detail="The instruction and language need to be non-empty strings.",
-            )
-
-        return GuidelineExample(
-            **self._request(
-                EXAMPLE_PROMPT,
-                OpenAIFunction(
-                    name="generate_examples_from_instruction",
-                    description="Generate code examples for a coding instruction",
-                    parameters=EXAMPLE_SCHEMA,
-                ),
-                json.dumps({"instruction": instruction, "language": language}),
-                **kwargs,
-            )
         )
+        for chunk in stream:
+            if len(chunk.choices) > 0 and isinstance(chunk.choices[0].delta.content, str):
+                yield chunk.choices[0].delta.content
+            if chunk.usage:
+                logger.info(
+                    f"OpenAI ({self.model}): {chunk.usage.prompt_tokens} prompt tokens | {chunk.usage.completion_tokens} completion tokens",
+                )
diff --git a/src/app/services/llm/utils.py b/src/app/services/llm/utils.py
new file mode 100644
index 0000000..d7e6fe0
--- /dev/null
+++ b/src/app/services/llm/utils.py
@@ -0,0 +1,65 @@
+# Copyright (C) 2024, Quack AI.
+
+# This program is licensed under the Apache License 2.0.
+# See LICENSE or go to <https://www.apache.org/licenses/LICENSE-2.0> for full license details.
+
+import json
+import re
+from typing import Dict, List
+
+from fastapi import HTTPException, status
+
+__all__ = ["CHAT_PROMPT"]
+
+EXAMPLE_PROMPT = (
+    "You are responsible for producing concise illustrations of the company coding guidelines. "
+    "This will be used to teach new developers our way of engineering software. "
+    "Make sure your code is in the specified programming language and functional, don't add extra comments or explanations.\n"
+    # Format
+    "You should output two code blocks: "
+    "a minimal code snippet where the instruction was correctly followed, "
+    "and the same snippet with minimal modifications that invalidates the instruction."
+)
+# Strangely, this doesn't work when compiled
+EXAMPLE_PATTERN = r"```[a-zA-Z]*\n(?P<positive>.*?)```\n.*```[a-zA-Z]*\n(?P<negative>.*?)```"
+
+PARSING_PROMPT = (
+    "You are responsible for summarizing the list of distinct coding guidelines for the company, by going through documentation. "
+    "This list will be used by developers to avoid hesitations in code reviews and to onboard new members. "
+    "Consider only guidelines that can be verified for a specific snippet of code (nothing about git, commits or community interactions) "
+    "by a human developer without running additional commands or tools, it should only relate to the code within each file. "
+    "Only include guidelines for which you could generate positive and negative code snippets, "
+    "don't invent anything that isn't present in the input text.\n"
+    # Format
+    "You should answer with a list of JSON dictionaries, one dictionary per guideline, where each dictionary has two keys with string values:\n"
+    "- title: a short summary title of the guideline\n"
+    "- details: a descriptive, comprehensive and inambiguous explanation of the guideline."
+)
+PARSING_PATTERN = r"\{\s*\"title\":\s+\"(?P<title>.*?)\",\s+\"details\":\s+\"(?P<details>.*?)\"\s*\}"
+
+CHAT_PROMPT = (
+    "You are an AI programming assistant, developed by the company Quack AI, and you only answer questions related to computer science "
+    "(refuse to answer for the rest)."
+)
+
+GUIDELINE_PROMPT = (
+    "When answering user requests, you should at all times keep in mind the following software development guidelines:"
+)
+
+
+def validate_example_response(response: str) -> Dict[str, str]:
+    matches = re.search(EXAMPLE_PATTERN, response.strip(), re.DOTALL)
+    if matches is None:
+        raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Failed output schema validation")
+
+    return matches.groupdict()
+
+
+def validate_parsing_response(response: str) -> List[Dict[str, str]]:
+    guideline_list = json.loads(response.strip())
+    if not isinstance(guideline_list, list) or any(
+        not isinstance(val, str) for guideline in guideline_list for val in guideline.values()
+    ):
+        raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Failed output schema validation")
+
+    return json.loads(response.strip())
diff --git a/src/tests/services/test_llm.py b/src/tests/services/test_llm.py
index a5b1c92..8704cdd 100644
--- a/src/tests/services/test_llm.py
+++ b/src/tests/services/test_llm.py
@@ -1,13 +1,17 @@
 import types
 
 import pytest
-from groq import AuthenticationError, NotFoundError
+from groq import AuthenticationError as GAuthError
+from groq import NotFoundError as GNotFoundError
 from httpx import ConnectError
 from ollama import ResponseError
+from openai import AuthenticationError as OAIAuthError
+from openai import NotFoundError as OAINotFounderError
 
 from app.core.config import settings
 from app.services.llm.groq import GroqClient
 from app.services.llm.ollama import OllamaClient
+from app.services.llm.openai import OpenAIClient
 
 
 @pytest.mark.parametrize(
@@ -34,10 +38,10 @@ def test_ollamaclient_chat():
 
 
 def test_groqclient_constructor():
-    with pytest.raises(AuthenticationError):
+    with pytest.raises(GAuthError):
         GroqClient("api_key", settings.GROQ_MODEL)
     if isinstance(settings.GROQ_API_KEY, str):
-        with pytest.raises(NotFoundError):
+        with pytest.raises(GNotFoundError):
             GroqClient(settings.GROQ_API_KEY, "quack")
         GroqClient(settings.GROQ_API_KEY, settings.GROQ_MODEL)
 
@@ -49,3 +53,21 @@ def test_groqclient_chat():
     assert isinstance(stream, types.GeneratorType)
     for chunk in stream:
         assert isinstance(chunk, str)
+
+
+def test_openaiclient_constructor():
+    with pytest.raises(OAIAuthError):
+        OpenAIClient("api_key", settings.OPENAI_MODEL)
+    if isinstance(settings.OPENAI_API_KEY, str):
+        with pytest.raises(OAINotFounderError):
+            OpenAIClient(settings.OPENAI_API_KEY, "quack")
+        OpenAIClient(settings.OPENAI_API_KEY, settings.OPENAI_MODEL)
+
+
+@pytest.mark.skipif("settings.OPENAI_API_KEY is None")
+def test_openaiclient_chat():
+    llm_client = OpenAIClient(settings.OPENAI_API_KEY, settings.OPENAI_MODEL)
+    stream = llm_client.chat([{"role": "user", "content": "hello"}])
+    assert isinstance(stream, types.GeneratorType)
+    for chunk in stream:
+        assert isinstance(chunk, str)
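Note on configuration: the dispatch added in `src/app/services/llm/llm.py` resolves the LLM client once at import time, based on `LLM_PROVIDER` and the matching credentials read in `src/app/core/config.py`. A minimal sketch of driving the new OpenAI path through environment variables follows; the key value is a placeholder, and the rest of the backend configuration (database, superadmin credentials, etc.) is assumed to already be in place.

```python
import os

# Assumed deployment values; variable names come from .env.example / src/app/core/config.py.
# The key is a placeholder and must be a real OpenAI key, since the client validates the
# model against the OpenAI API at construction time.
os.environ["LLM_PROVIDER"] = "openai"
os.environ["OPENAI_API_KEY"] = "sk-..."  # placeholder
os.environ["OPENAI_MODEL"] = "gpt-4o-2024-05-13"  # default in src/app/core/config.py

# llm_client is built at import time: OllamaClient, GroqClient or OpenAIClient depending
# on LLM_PROVIDER; a missing OPENAI_API_KEY raises a ValueError instead.
from app.services.llm.llm import llm_client

print(type(llm_client).__name__)  # "OpenAIClient" with the settings above
```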
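Usage sketch for the new streaming `OpenAIClient.chat` method, mirroring `src/tests/services/test_llm.py`; the question text is illustrative, and a valid `OPENAI_API_KEY`/`OPENAI_MODEL` pair is assumed.

```python
from app.core.config import settings
from app.services.llm.openai import OpenAIClient

# The constructor validates the model via models.retrieve() and logs its creation date.
client = OpenAIClient(settings.OPENAI_API_KEY, settings.OPENAI_MODEL, settings.LLM_TEMPERATURE)

# chat() prepends CHAT_PROMPT (from src/app/services/llm/utils.py) as the system message,
# optionally extended through the `system` argument, and yields the completion chunk by chunk.
stream = client.chat([{"role": "user", "content": "How do I reverse a list in Python?"}])
print("".join(stream))
```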
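The validators shared through `src/app/services/llm/utils.py` expect the model to reply with two fenced code blocks; `validate_example_response` extracts them under the `positive`/`negative` keys or raises a 500. A small illustration with an invented response:

```python
from app.services.llm.utils import validate_example_response

# Invented model output: a compliant snippet followed by a minimally modified non-compliant one.
response = (
    "```python\ndef add(a: int, b: int) -> int:\n    return a + b\n```\n"
    "Here is the same snippet without type hints:\n"
    "```python\ndef add(a, b):\n    return a + b\n```"
)

parsed = validate_example_response(response)
print(parsed["positive"])  # snippet that follows the guideline
print(parsed["negative"])  # snippet that breaks it
```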