From cc38d1012407206c868c437673de00f17a076436 Mon Sep 17 00:00:00 2001 From: Sanjay Nadhavajhala Date: Mon, 12 Feb 2024 01:48:31 -0800 Subject: [PATCH] init task executor service --- services/backend/task_executor/Dockerfile | 24 + services/backend/task_executor/__init__.py | 0 .../backend/task_executor/app/__init__.py | 0 .../task_executor/app/backend_models.py | 410 +++ .../task_executor/app/celery_config.py | 6 + .../backend/task_executor/app/local_model.py | 395 +++ services/backend/task_executor/app/models.py | 2874 +++++++++++++++++ services/backend/task_executor/app/tasks.py | 505 +++ .../app/tools/file_knowledge_tool.py | 50 + .../task_executor/app/tools/get_date_tool.py | 35 + .../app/tools/web_browse_tool/browser.py | 147 + .../tools/web_browse_tool/web_browse_tool.py | 94 + .../app/vector_db/milvus/CustomEmbeddings.py | 55 + .../app/vector_db/milvus/main.py | 119 + .../app/vector_db/milvus/query_milvus.py | 825 +++++ .../backend/task_executor/requirements.txt | 20 + 16 files changed, 5559 insertions(+) create mode 100644 services/backend/task_executor/Dockerfile create mode 100644 services/backend/task_executor/__init__.py create mode 100644 services/backend/task_executor/app/__init__.py create mode 100644 services/backend/task_executor/app/backend_models.py create mode 100644 services/backend/task_executor/app/celery_config.py create mode 100644 services/backend/task_executor/app/local_model.py create mode 100644 services/backend/task_executor/app/models.py create mode 100644 services/backend/task_executor/app/tasks.py create mode 100644 services/backend/task_executor/app/tools/file_knowledge_tool.py create mode 100644 services/backend/task_executor/app/tools/get_date_tool.py create mode 100644 services/backend/task_executor/app/tools/web_browse_tool/browser.py create mode 100644 services/backend/task_executor/app/tools/web_browse_tool/web_browse_tool.py create mode 100644 services/backend/task_executor/app/vector_db/milvus/CustomEmbeddings.py create mode 100644 services/backend/task_executor/app/vector_db/milvus/main.py create mode 100644 services/backend/task_executor/app/vector_db/milvus/query_milvus.py create mode 100644 services/backend/task_executor/requirements.txt diff --git a/services/backend/task_executor/Dockerfile b/services/backend/task_executor/Dockerfile new file mode 100644 index 0000000..c1e9efe --- /dev/null +++ b/services/backend/task_executor/Dockerfile @@ -0,0 +1,24 @@ +FROM python:3.10.7-slim + +# Set the working directory in the container to /app +WORKDIR /app + +# First, add only the requirements file (to leverage Docker cache) +ADD requirements.txt /app/ + +# Install any needed packages specified in requirements.txt +RUN pip install --no-cache-dir -r requirements.txt +RUN spacy download en_core_web_sm +RUN playwright install +RUN playwright install-deps + + +# Add the current directory contents into the container at /app +# This is done after installing dependencies to leverage Docker's cache +ADD . /app + +# Set environment variables +ENV OBJC_DISABLE_INITIALIZE_FORK_SAFETY=YES + +# Run app.py when the container launches +CMD ["celery", "-A", "app.tasks", "worker", "--loglevel=info"] diff --git a/services/backend/task_executor/__init__.py b/services/backend/task_executor/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/services/backend/task_executor/app/__init__.py b/services/backend/task_executor/app/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/services/backend/task_executor/app/backend_models.py b/services/backend/task_executor/app/backend_models.py new file mode 100644 index 0000000..ab92e0c --- /dev/null +++ b/services/backend/task_executor/app/backend_models.py @@ -0,0 +1,410 @@ +# backend_models.py +from beanie import Document +from pydantic import ( + AnyUrl, + BaseModel, + Extra, + Field, + PositiveFloat, + confloat, + conint, + constr, +) +from typing import Any, Dict, List, Optional, Union +from .models import ( + AssistantFileObject, + AssistantToolsFunction, + CreateAssistantRequest, + FineTuningJob, + ImagesResponse, + LastError, + ListAssistantFilesResponse, + ListFilesResponse, + ListFineTuneEventsResponse, + Object7, + Object14, + Purpose1, + Status, + Object20, + Object21, + Object22, + Object23, + Object24, + Object25, + Object27, + Object28, + RequiredAction, + Role7, + Role8, + Type16, + Status2, + Status3, + LastError1, + MessageContentTextObject, + MessageContentImageFileObject, + RunStepDetailsMessageCreationObject, + RunStepDetailsToolCallsObject, + AssistantToolsCode, + AssistantToolsRetrieval, + AssistantToolsFunction, + AssistantToolsBrowser +) + +class AssistantObject(Document): + assistant_id: str = Field( + ..., description="The identifier, which can be referenced in API endpoints.", alias="id" + ) + object: Object20 = Field( + ..., description="The object type, which is always `assistant`." + ) + created_at: int = Field( + ..., + description="The Unix timestamp (in seconds) for when the assistant was created.", + ) + name: constr(max_length=256) = Field( + ..., + description="The name of the assistant. The maximum length is 256 characters.\n", + ) + description: constr(max_length=512) = Field( + ..., + description="The description of the assistant. The maximum length is 512 characters.\n", + ) + model: str = Field( + ..., + description="ID of the model to use. You can use the [List models](/docs/api-reference/models/list) API to see all of your available models, or see our [Model overview](/docs/models/overview) for descriptions of them.\n", + ) + instructions: constr(max_length=32768) = Field( + ..., + description="The system instructions that the assistant uses. The maximum length is 32768 characters.\n", + ) + tools: List[ + Union[AssistantToolsCode, AssistantToolsRetrieval, AssistantToolsFunction, AssistantToolsBrowser] + ] = Field( + ..., + description="A list of tool enabled on the assistant. There can be a maximum of 128 tools per assistant. Tools can be of types `code_interpreter`, `retrieval`, or `function`.\n", + max_items=128, + ) + file_ids: List[str] = Field( + ..., + description="A list of [file](/docs/api-reference/files) IDs attached to this assistant. There can be a maximum of 20 files attached to the assistant. Files are ordered by their creation date in ascending order.\n", + max_items=20, + ) + metadata: Dict[str, Any] = Field( + ..., + description="Set of 16 key-value pairs that can be attached to an object. This can be useful for storing additional information about the object in a structured format. Keys can be a maximum of 64 characters long and values can be a maxium of 512 characters long.\n", + ) + + class Settings: + name = "assistants" + +class ListAssistantsResponse(BaseModel): + object: str = Field(..., example="list") + data: List[AssistantObject] + first_id: str = Field(..., example="asst_hLBK7PXBv5Lr2NQT7KLY0ag1") + last_id: str = Field(..., example="asst_QLoItBbqwyAJEzlTy4y9kOMM") + has_more: bool = Field(..., example=False) + +class ThreadObject(Document): + thread_id: str = Field( + ..., description="The identifier, which can be referenced in API endpoints.", alias="id" + ) + object: Object23 = Field( + ..., description="The object type, which is always `thread`." + ) + created_at: int = Field( + ..., + description="The Unix timestamp (in seconds) for when the thread was created.", + ) + metadata: Dict[str, Any] = Field( + ..., + description="Set of 16 key-value pairs that can be attached to an object. This can be useful for storing additional information about the object in a structured format. Keys can be a maximum of 64 characters long and values can be a maximum of 512 characters long." + ) + + class Settings: + name = "threads" + +class MessageObject(Document): + message_id: str = Field( + ..., description="The identifier, which can be referenced in API endpoints.", alias="id" + ) + object: Object25 = Field( + ..., description="The object type, which is always `thread.message`." + ) + created_at: int = Field( + ..., + description="The Unix timestamp (in seconds) for when the message was created.", + ) + thread_id: str = Field( + ..., + description="The [thread](/docs/api-reference/threads) ID that this message belongs to.", + ) + role: Role7 = Field( + ..., + description="The entity that produced the message. One of `user` or `assistant`.", + ) + content: List[ + Union[MessageContentImageFileObject, MessageContentTextObject] + ] = Field( + ..., description="The content of the message in array of text and/or images." + ) + assistant_id: str = Field( + ..., + description="If applicable, the ID of the [assistant](/docs/api-reference/assistants) that authored this message.", + ) + run_id: str = Field( + ..., + description="If applicable, the ID of the [run](/docs/api-reference/runs) associated with the authoring of this message.", + ) + file_ids: List[str] = Field( + ..., + description="A list of [file](/docs/api-reference/files) IDs that the assistant should use. Useful for tools like retrieval and code_interpreter that can access files. A maximum of 10 files can be attached to a message.", + max_items=10, + ) + metadata: Dict[str, Any] = Field( + ..., + description="Set of 16 key-value pairs that can be attached to an object. This can be useful for storing additional information about the object in a structured format. Keys can be a maximum of 64 characters long and values can be a maxium of 512 characters long.\n", + ) + + class Settings: + name = "messages" + +class ListMessagesResponse(BaseModel): + object: str = Field(..., example="list") + data: List[MessageObject] + first_id: str = Field(..., example="msg_hLBK7PXBv5Lr2NQT7KLY0ag1") + last_id: str = Field(..., example="msg_QLoItBbqwyAJEzlTy4y9kOMM") + has_more: bool = Field(..., example=False) + +class RunObject(Document): + run_id: str = Field( + ..., description="The identifier, which can be referenced in API endpoints.", alias="id" + ) + object: Object22 = Field( + ..., description="The object type, which is always `thread.run`." + ) + created_at: int = Field( + ..., description="The Unix timestamp (in seconds) for when the run was created." + ) + thread_id: str = Field( + ..., + description="The ID of the [thread](/docs/api-reference/threads) that was executed on as a part of this run.", + ) + assistant_id: str = Field( + ..., + description="The ID of the [assistant](/docs/api-reference/assistants) used for execution of this run.", + ) + status: Status2 = Field( + ..., + description="The status of the run, which can be either `queued`, `in_progress`, `requires_action`, `cancelling`, `cancelled`, `failed`, `completed`, or `expired`.", + ) + required_action: RequiredAction = Field( + None, + description="Details on the action required to continue the run. Will be `null` if no action is required.", + ) + last_error: LastError = Field( + None, + description="The last error associated with this run. Will be `null` if there are no errors.", + ) + expires_at: int = Field( + None, description="The Unix timestamp (in seconds) for when the run will expire." + ) + started_at: int = Field( + None, description="The Unix timestamp (in seconds) for when the run was started." + ) + cancelled_at: int = Field( + None, + description="The Unix timestamp (in seconds) for when the run was cancelled.", + ) + failed_at: int = Field( + None, description="The Unix timestamp (in seconds) for when the run failed." + ) + completed_at: int = Field( + None, + description="The Unix timestamp (in seconds) for when the run was completed.", + ) + model: str = Field( + ..., + description="The model that the [assistant](/docs/api-reference/assistants) used for this run.", + ) + instructions: str = Field( + ..., + description="The instructions that the [assistant](/docs/api-reference/assistants) used for this run.", + ) + tools: List[ + Union[AssistantToolsCode, AssistantToolsRetrieval, AssistantToolsFunction, AssistantToolsBrowser] + ] = Field( + ..., + description="The list of tools that the [assistant](/docs/api-reference/assistants) used for this run.", + max_items=20, + ) + file_ids: List[str] = Field( + ..., + description="The list of [File](/docs/api-reference/files) IDs the [assistant](/docs/api-reference/assistants) used for this run.", + ) + metadata: Dict[str, Any] = Field( + ..., + description="Set of 16 key-value pairs that can be attached to an object. This can be useful for storing additional information about the object in a structured format. Keys can be a maximum of 64 characters long and values can be a maxium of 512 characters long.\n", + ) + + class Settings: + name = "runs" + + + +class OpenAIFile(Document): + file_id: str = Field( + ..., + description="The file identifier, which can be referenced in the API endpoints.", alias="id" + ) + bytes: int = Field(..., description="The size of the file, in bytes.") + created_at: int = Field( + ..., + description="The Unix timestamp (in seconds) for when the file was created.", + ) + filename: str = Field(..., description="The name of the file.") + object: Object14 = Field( + ..., description="The object type, which is always `file`." + ) + purpose: Purpose1 = Field( + ..., + description="The intended purpose of the file. Supported values are `fine-tune`, `fine-tune-results`, `assistants`, and `assistants_output`.", + ) + status: Status = Field( + ..., + description="Deprecated. The current status of the file, which can be either `uploaded`, `processed`, or `error`.", + ) + status_details: Optional[str] = Field( + None, + description="Deprecated. For details on why a fine-tuning training file failed validation, see the `error` field on `fine_tuning.job`.", + ) + + class Settings: + name = "files" + +class ListFilesResponse(BaseModel): + data: List[OpenAIFile] + object: Object7 + +class FilesStorageObject(Document): + file_id: str = Field( + ..., + description="The file identifier, which can be referenced in the API endpoints.", alias="id" + ) + content: bytes = Field(..., description="The file content") + content_type: str = Field(..., description="The file content type") + + class Settings: + name = "files_storage" + + +class AssistantFileObject(Document): + file_id: str = Field( + ..., description="The identifier, which can be referenced in API endpoints." , alias="id" + ) + object: Object28 = Field( + ..., description="The object type, which is always `assistant.file`." + ) + created_at: int = Field( + ..., + description="The Unix timestamp (in seconds) for when the assistant file was created.", + ) + assistant_id: str = Field( + ..., description="The assistant ID that the file is attached to." + ) + + class Settings: + name = "assistant_files" + + +class ListAssistantFilesResponse(BaseModel): + object: str = Field(..., example="list") + data: List[AssistantFileObject] + first_id: str = Field(..., example="file-hLBK7PXBv5Lr2NQT7KLY0ag1") + last_id: str = Field(..., example="file-QLoItBbqwyAJEzlTy4y9kOMM") + has_more: bool = Field(..., example=False) + + +class RunStepObject(Document): + run_step_id: str = Field( + ..., + description="The identifier of the run step, which can be referenced in API endpoints.", alias="id" + ) + object: Object27 = Field( + ..., description="The object type, which is always `thread.run.step``." + ) + created_at: int = Field( + ..., + description="The Unix timestamp (in seconds) for when the run step was created.", + ) + assistant_id: str = Field( + ..., + description="The ID of the [assistant](/docs/api-reference/assistants) associated with the run step.", + ) + thread_id: str = Field( + ..., + description="The ID of the [thread](/docs/api-reference/threads) that was run.", + ) + run_id: str = Field( + ..., + description="The ID of the [run](/docs/api-reference/runs) that this run step is a part of.", + ) + type: Type16 = Field( + ..., + description="The type of run step, which can be either `message_creation` or `tool_calls`.", + ) + status: Status3 = Field( + ..., + description="The status of the run step, which can be either `in_progress`, `cancelled`, `failed`, `completed`, or `expired`.", + ) + step_details: Union[ + RunStepDetailsMessageCreationObject, RunStepDetailsToolCallsObject + ] = Field(..., description="The details of the run step.") + last_error: LastError1 = Field( + ..., + description="The last error associated with this run step. Will be `null` if there are no errors.", + ) + expired_at: int = Field( + ..., + description="The Unix timestamp (in seconds) for when the run step expired. A step is considered expired if the parent run is expired.", + ) + cancelled_at: int = Field( + ..., + description="The Unix timestamp (in seconds) for when the run step was cancelled.", + ) + failed_at: int = Field( + ..., description="The Unix timestamp (in seconds) for when the run step failed." + ) + completed_at: int = Field( + ..., + description="The Unix timestamp (in seconds) for when the run step completed.", + ) + metadata: Dict[str, Any] = Field( + ..., + description="Set of 16 key-value pairs that can be attached to an object. This can be useful for storing additional information about the object in a structured format. Keys can be a maximum of 64 characters long and values can be a maxium of 512 characters long.\n", + ) + class Settings: + name = "run_steps" + + +class ListRunStepsResponse(BaseModel): + object: str = Field(..., example="list") + data: List[RunStepObject] + first_id: str = Field(..., example="step_hLBK7PXBv5Lr2NQT7KLY0ag1") + last_id: str = Field(..., example="step_QLoItBbqwyAJEzlTy4y9kOMM") + has_more: bool = Field(..., example=False) + + +class ListRunsResponse(BaseModel): + object: str = Field(..., example="list") + data: List[RunObject] + first_id: str = Field(..., example="run_hLBK7PXBv5Lr2NQT7KLY0ag1") + last_id: str = Field(..., example="run_QLoItBbqwyAJEzlTy4y9kOMM") + has_more: bool = Field(..., example=False) + + +class FileUpload(BaseModel): + purpose: str + +class ApiKeysUpdateModel(BaseModel): + OPENAI_API_KEY: Optional[bool] = None + ANTHROPIC_API_KEY: Optional[bool] = None diff --git a/services/backend/task_executor/app/celery_config.py b/services/backend/task_executor/app/celery_config.py new file mode 100644 index 0000000..d0deef8 --- /dev/null +++ b/services/backend/task_executor/app/celery_config.py @@ -0,0 +1,6 @@ +import os + +CELERY_REDIS_HOST = os.getenv("REDIS_HOST", "localhost") +BROKER_URL = f'redis://{CELERY_REDIS_HOST}:6379/0' # Redis configuration +CELERY_RESULT_BACKEND = f'redis://{CELERY_REDIS_HOST}:6379/0' +CELERY_IMPORTS = ("app.tasks", ) diff --git a/services/backend/task_executor/app/local_model.py b/services/backend/task_executor/app/local_model.py new file mode 100644 index 0000000..9ed8241 --- /dev/null +++ b/services/backend/task_executor/app/local_model.py @@ -0,0 +1,395 @@ +# Standard Library +import json +import logging +import os +import random + +# Third Party +import spacy +from openai import OpenAI + +# Local +from app.models import Role7, Type8, Type824 +from app.tools.file_knowledge_tool import FileKnowledgeTool +from app.tools.get_date_tool import get_date +from app.tools.web_browse_tool.web_browse_tool import ( + WebBrowseTool, + create_google_search_url, + parse_url, +) + +ner = spacy.load("en_core_web_sm") + +pattern = r">(.*?) bool: + """Detect if the query contains a date entity""" + # Process the text + doc = ner(query) + + # Search for entities in the text + for ent in doc.ents: + if ent.label_ == "DATE": + return True + + return False + + +def query_enhancement_relative_date(query: str) -> str: + PROMPT = f""" +today's date is : {get_date()} +Given the QUESTION, generate a new query that can be used to google search. +Find out the indent of the QUESTION, if it's looking for information with relative date, improve it with fixed date, for example: +EXAMPLE1: "who won the US election yesterday" -> "US election result {get_date(day_delta=-1)}" +EXAMPLE2: "tomorrow's weather in SF" -> "weather in SF on {get_date(day_delta=1)}" + +Your response should be in the following json format: +{QUERY_FORMAT} + +QUESTION: {query} +""" + messages = [{"role": "user", "content": PROMPT}] + response = oai_client.chat.completions.create( + model="openai/custom", messages=messages, stream=False + ) + res = response.choices[0].message.content + try: + res = json.loads(res)["query"] + except Exception as e: + logging.error( + "LLM response is not in valid json format, using the original query." + ) + res = query + return res + + +def query_enhancement_no_date(query: str) -> str: + PROMPT = f""" +Today's date is : {get_date()} +Given the QUESTION, generate a new query that can be used to google search. +Find out the intent of the QUESTION, if it is looking for latest information, then add date properly for example: +EXAMPLE1: "who won the Taiwan election" -> "Taiwan election result {get_date(year_only=True)}" +EXAMPLE2: "the result of IOWA vote" -> "the result of IOWA vote {get_date(year_only=True)}" + +Your response should be in the following json format: +{QUERY_FORMAT} + +QUESTION: {query} +""" + messages = [{"role": "user", "content": PROMPT}] + response = oai_client.chat.completions.create( + model="openai/custom", messages=messages, stream=False + ) + res = response.choices[0].message.content + try: + res = json.loads(res)["query"] + except Exception as e: + logging.error( + "LLM response is not in valid json format, using the original query." + ) + res = query + return res + + +def simple_summarize(query: str, context: str) -> str: + if len(context.strip()) == 0: + return "Sorry, I can't find any relevant information." + PROMPT = f""" +Do NOT infer or assume anything, generate an answer to the QUESTION only based on the search results you got, include as much information as possible. +If the search results is irrelevant, politely express that you can't help. +-------------------- +SEARCH RESULTS: +{context} +-------------------- +QUESTION: {query} +""" + messages = [{"role": "user", "content": PROMPT}] + response = oai_client.chat.completions.create( + model="openai/custom", temperature=0.1, messages=messages, stream=False + ) + return response.choices[0].message.content + + +QA_FORMAT_FALSE = "Not enough information" + + +def simple_qa(query: str, context: str) -> str: + PROMPT = f""" +Do NOT infer or assume anything, only answer the QUESTION based on the search results. +When the search results do not directly provide enough information for answering the qestion, response with txt format: {QA_FORMAT_FALSE} +Otherwise, simply response with you answer that only focus on answer the question, and do not infer or assume anything. +-------------------- +SEARCH RESULTS: +{context} +-------------------- +QUESTION: {query} +""" + messages = [{"role": "user", "content": PROMPT}] + response = oai_client.chat.completions.create( + model="openai/custom", + temperature=0.1, + messages=messages, + stream=False, + response_format="web", + ) + return response.choices[0].message.content + + +def simple_judge(query: str, answer: str) -> str: + PROMPT = f""" +Given the QUESTION and the ANSWER, determine if the ANSWER is sufficient to the QUESTION. +if the ANSWER does answer the QUESTION completely, response with word "Yes". +Otherwise, if the answer indicates that there's not enough information to fully answer the question, response with word "No". + +QUESTION: {query}? ANSWER: {answer} +""" + messages = [{"role": "user", "content": PROMPT}] + response = oai_client.chat.completions.create( + model="openai/custom", temperature=0.0, messages=messages, stream=False + ) + return response.choices[0].message.content + + +def multi_step_summarize(query: str, context: str) -> str: + # Third Party + from langchain.text_splitter import RecursiveCharacterTextSplitter + + text_splitter = RecursiveCharacterTextSplitter(chunk_size=5000, chunk_overlap=300) + + page_res_list = [] + for i, c in enumerate(context): + ctx = parse_url(c["url"], query) + logging.info(f"=========visited url:========== {c['url']}") + if len(ctx) == 0: + continue + texts = text_splitter.split_text(ctx) + + useful_chunk = 0 + for i, t in enumerate(texts): + if i >= 10: # arbitrary limit + break + chunk_qa = simple_qa(query=query, context=t) + logging.debug(f"CHUNK {i} : {chunk_qa}") + if not chunk_qa.startswith(QA_FORMAT_FALSE): + page_res_list.append(chunk_qa) + useful_chunk += 1 + + if useful_chunk > 0: + page_qa = simple_summarize(query=query, context="".join(page_res_list)) + logging.info(f"======PAGE_QA:======\n {page_qa}") + is_sufficient = simple_judge(query=query, answer=page_qa) + logging.info(f"======IS_SUFFICIENT:======\n {is_sufficient}") + if is_sufficient.strip() == "Yes": + return page_qa + return page_qa + + +class RubraLocalAgent: + def __init__(self, assistant_id, tools): + self.assistant_id = assistant_id + self.setup_tools(tools) + + def setup_tools(self, tools): + self.available_tools = {} + self.tools = [] + for t in tools: + if t["type"] == Type8.retrieval.value: + file_search_tool = FileKnowledgeTool() + + def post_processed_file_search(query: str): + context = file_search_tool._run( + query=query, assistant_id=self.assistant_id + ) + summarized_ans = simple_summarize(query=query, context=context) + return summarized_ans + + self.tools.append( + { + "name": file_search_tool.name, + "description": file_search_tool.description, + "parameters": file_search_tool.parameters, + } + ) + self.available_tools[file_search_tool.name] = post_processed_file_search + elif t["type"] == Type824.retrieval.value: + search_tool = WebBrowseTool() + + def post_processed_google_search(query: str): + date_detected = ner_date_detection(query) + if date_detected: + logging.info( + "Date detected in query, using relative date enhancement." + ) + # new_query = query_enhancement_relative_date(query) # TODO: need improvement + new_query = query + else: + logging.info( + "No date detected in query, using no date enhancement." + ) + new_query = query_enhancement_no_date(query) + logging.debug(f"enhanced search query : {new_query}") + context = search_tool._run( + new_query, web_browse=False, concat_text=False + ) + # summarized_ans = simple_summarize(new_query, context=context) + summarized_ans = multi_step_summarize(query, context=context) + + # Add search reference for web-browse/google-search + google_search_url = create_google_search_url(new_query) + word_pool1 = ["I did", "After", "With"] + word_pool2 = ["I found", "I have", "I got", "I discovered"] + search_res_prefix = f"{random.choice(word_pool1)} a [quick search]({google_search_url}), here's what {random.choice(word_pool2)}:\n\n" + final_ans = search_res_prefix + summarized_ans + return final_ans + + self.tools.append( + { + "name": search_tool.name, + "description": search_tool.description, + "parameters": search_tool.parameters, + } + ) + self.available_tools[search_tool.name] = post_processed_google_search + + def validate_function_call(self, msg: str) -> (bool, str): + try: + funtion_call_json = json.loads(msg) + except Exception as e: + logging.warning("invalid json format") + logging.warning(e) + try: + funtion_call_json = json.loads( + msg + "}" + ) # sometimes the msg is not complete with a closing bracket + except Exception as e: + logging.error(e) + return False, msg + + if "function" in funtion_call_json: + return True, json.dumps(funtion_call_json) + else: + return False, funtion_call_json["content"] + + def chat( + self, + msgs: list, + sys_instruction: str = "", + stream: bool = True, + ): + messages = [] + + if sys_instruction is None or sys_instruction == "": + system_instruction = "You are a helpful assistant." + else: + system_instruction = sys_instruction + + if len(self.tools) > 0: + response_format = {"type": "json_object"} + system_instruction += f""" +You have access to the following tool: +``` +{self.tools[0]} +``` +To use a tool, response strictly with the following json format: +{FUNCTION_CALL_FORMAT} + +To chat with user, response strictly with the following json format: +{CHAT_FORMAT} + +You MUST only answer user's question based on the output from tools, include as much information as possible. +If there is no tools or no relevant information that matches user's request, you should response that you can't help. +""" + else: + response_format = None + + messages.append({"role": "system", "content": system_instruction}) + for msg in msgs: + if ( + msg["role"] == "user" + or msg["role"] == "assistant" + or msg["role"] == "tool_output" + ): + messages.append(msg) + + response = oai_client.chat.completions.create( + model="openai/custom", + messages=messages, + stream=stream, + temperature=0.1, + response_format=response_format, + ) + + return response + + def get_function_response(self, function_call_json): + try: + function_name = function_call_json["function"] + function_to_call = self.available_tools[function_name] + function_args = function_call_json["args"] + logging.info( + f"Calling function : {function_name} with args: {function_args}" + ) + function_response = function_to_call(**function_args) + return function_response + except Exception as e: + logging.error(e) + return "Rubra Backend Error: Failed to process function." + + def conversation( + self, + msgs: list, + sys_instruction: str = "", + stream: bool = True, + ): + messages = msgs.copy() + + response = self.chat(msgs=msgs, sys_instruction=sys_instruction, stream=stream) + + msg = "" + for r in response: + if r.choices[0].delta.content != None: + msg += r.choices[0].delta.content + + if len(self.tools) == 0: + msgs.append({"role": "assistant", "content": msg}) + return response, msgs + + is_function_call, parsed_msg = self.validate_function_call(msg) + print(parsed_msg) + + if is_function_call: + print("=====function call========") + msgs.append({"role": "assistant", "content": parsed_msg}) + + function_response = self.get_function_response( + function_call_json=json.loads(parsed_msg) + ) + + msg = function_response + print(f"\n MESSAGE: {msg}\n") + parsed_msg = msg + + msgs.append({"role": "assistant", "content": parsed_msg}) + messages = msgs + return response, messages diff --git a/services/backend/task_executor/app/models.py b/services/backend/task_executor/app/models.py new file mode 100644 index 0000000..17a4501 --- /dev/null +++ b/services/backend/task_executor/app/models.py @@ -0,0 +1,2874 @@ +# generated by fastapi-codegen: +# filename: /Users/sanjay_acorn/Downloads/openapi.yaml +# timestamp: 2023-11-18T07:40:51+00:00 + +from __future__ import annotations + +from enum import Enum +from typing import Any, Dict, List, Optional, Union + +from pydantic import ( + AnyUrl, + BaseModel, + Extra, + Field, + PositiveFloat, + confloat, + conint, + constr, +) + + +class Error(BaseModel): + code: str + message: str + param: str + type: str + + +class ErrorResponse(BaseModel): + error: Error + + +class Object(Enum): + list = "list" + + +class DeleteModelResponse(BaseModel): + id: str + deleted: bool + object: str + + +class ModelEnum(Enum): + babbage_002 = "babbage-002" + davinci_002 = "davinci-002" + gpt_3_5_turbo_instruct = "gpt-3.5-turbo-instruct" + text_davinci_003 = "text-davinci-003" + text_davinci_002 = "text-davinci-002" + text_davinci_001 = "text-davinci-001" + code_davinci_002 = "code-davinci-002" + text_curie_001 = "text-curie-001" + text_babbage_001 = "text-babbage-001" + text_ada_001 = "text-ada-001" + + +class PromptItem(BaseModel): + __root__: List[Any] + + +class CreateCompletionRequest(BaseModel): + model: Union[str, ModelEnum] = Field( + ..., + description="ID of the model to use. You can use the [List models](/docs/api-reference/models/list) API to see all of your available models, or see our [Model overview](/docs/models/overview) for descriptions of them.\n", + ) + prompt: Union[str, List[str], List[int], List[PromptItem]] = Field( + ..., + description="The prompt(s) to generate completions for, encoded as a string, array of strings, array of tokens, or array of token arrays.\n\nNote that <|endoftext|> is the document separator that the model sees during training, so if a prompt is not specified the model will generate as if from the beginning of a new document.\n", + ) + best_of: Optional[conint(ge=0, le=20)] = Field( + 1, + description='Generates `best_of` completions server-side and returns the "best" (the one with the highest log probability per token). Results cannot be streamed.\n\nWhen used with `n`, `best_of` controls the number of candidate completions and `n` specifies how many to return – `best_of` must be greater than `n`.\n\n**Note:** Because this parameter generates many completions, it can quickly consume your token quota. Use carefully and ensure that you have reasonable settings for `max_tokens` and `stop`.\n', + ) + echo: Optional[bool] = Field( + False, description="Echo back the prompt in addition to the completion\n" + ) + frequency_penalty: Optional[confloat(ge=-2.0, le=2.0)] = Field( + 0, + description="Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far, decreasing the model's likelihood to repeat the same line verbatim.\n\n[See more information about frequency and presence penalties.](/docs/guides/gpt/parameter-details)\n", + ) + logit_bias: Optional[Dict[str, int]] = Field( + None, + description='Modify the likelihood of specified tokens appearing in the completion.\n\nAccepts a JSON object that maps tokens (specified by their token ID in the GPT tokenizer) to an associated bias value from -100 to 100. You can use this [tokenizer tool](/tokenizer?view=bpe) (which works for both GPT-2 and GPT-3) to convert text to token IDs. Mathematically, the bias is added to the logits generated by the model prior to sampling. The exact effect will vary per model, but values between -1 and 1 should decrease or increase likelihood of selection; values like -100 or 100 should result in a ban or exclusive selection of the relevant token.\n\nAs an example, you can pass `{"50256": -100}` to prevent the <|endoftext|> token from being generated.\n', + ) + logprobs: Optional[conint(ge=0, le=5)] = Field( + None, + description="Include the log probabilities on the `logprobs` most likely tokens, as well the chosen tokens. For example, if `logprobs` is 5, the API will return a list of the 5 most likely tokens. The API will always return the `logprob` of the sampled token, so there may be up to `logprobs+1` elements in the response.\n\nThe maximum value for `logprobs` is 5.\n", + ) + max_tokens: Optional[conint(ge=0)] = Field( + 16, + description="The maximum number of [tokens](/tokenizer) to generate in the completion.\n\nThe token count of your prompt plus `max_tokens` cannot exceed the model's context length. [Example Python code](https://cookbook.openai.com/examples/how_to_count_tokens_with_tiktoken) for counting tokens.\n", + example=16, + ) + n: Optional[conint(ge=1, le=128)] = Field( + 1, + description="How many completions to generate for each prompt.\n\n**Note:** Because this parameter generates many completions, it can quickly consume your token quota. Use carefully and ensure that you have reasonable settings for `max_tokens` and `stop`.\n", + example=1, + ) + presence_penalty: Optional[confloat(ge=-2.0, le=2.0)] = Field( + 0, + description="Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the text so far, increasing the model's likelihood to talk about new topics.\n\n[See more information about frequency and presence penalties.](/docs/guides/gpt/parameter-details)\n", + ) + seed: Optional[conint(ge=-9223372036854775808, le=9223372036854775808)] = Field( + None, + description="If specified, our system will make a best effort to sample deterministically, such that repeated requests with the same `seed` and parameters should return the same result.\n\nDeterminism is not guaranteed, and you should refer to the `system_fingerprint` response parameter to monitor changes in the backend.\n", + ) + stop: Optional[Union[str, List[str]]] = Field( + None, + description="Up to 4 sequences where the API will stop generating further tokens. The returned text will not contain the stop sequence.\n", + ) + stream: Optional[bool] = Field( + False, + description="Whether to stream back partial progress. If set, tokens will be sent as data-only [server-sent events](https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#Event_stream_format) as they become available, with the stream terminated by a `data: [DONE]` message. [Example Python code](https://cookbook.openai.com/examples/how_to_stream_completions).\n", + ) + suffix: Optional[str] = Field( + None, + description="The suffix that comes after a completion of inserted text.", + example="test.", + ) + temperature: Optional[confloat(ge=0.0, le=2.0)] = Field( + 1, + description="What sampling temperature to use, between 0 and 2. Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic.\n\nWe generally recommend altering this or `top_p` but not both.\n", + example=1, + ) + top_p: Optional[confloat(ge=0.0, le=1.0)] = Field( + 1, + description="An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are considered.\n\nWe generally recommend altering this or `temperature` but not both.\n", + example=1, + ) + user: Optional[str] = Field( + None, + description="A unique identifier representing your end-user, which can help OpenAI to monitor and detect abuse. [Learn more](/docs/guides/safety-best-practices/end-user-ids).\n", + example="user-1234", + ) + + +class FinishReason(Enum): + stop = "stop" + length = "length" + content_filter = "content_filter" + + +class Logprobs(BaseModel): + text_offset: Optional[List[int]] = None + token_logprobs: Optional[List[float]] = None + tokens: Optional[List[str]] = None + top_logprobs: Optional[List[Dict[str, float]]] = None + + +class Choice(BaseModel): + finish_reason: FinishReason = Field( + ..., + description="The reason the model stopped generating tokens. This will be `stop` if the model hit a natural stop point or a provided stop sequence,\n`length` if the maximum number of tokens specified in the request was reached,\nor `content_filter` if content was omitted due to a flag from our content filters.\n", + ) + index: int + logprobs: Logprobs + text: str + + +class Object1(Enum): + text_completion = "text_completion" + + +class Type(Enum): + image_url = "image_url" + + +class Detail(Enum): + auto = "auto" + low = "low" + high = "high" + + +class ImageUrl(BaseModel): + url: AnyUrl = Field( + ..., description="Either a URL of the image or the base64 encoded image data." + ) + detail: Optional[Detail] = Field( + "auto", description="Specifies the detail level of the image." + ) + + +class ChatCompletionRequestMessageContentPartImage(BaseModel): + type: Type = Field(..., description="The type of the content part.") + image_url: ImageUrl + + +class Type1(Enum): + text = "text" + + +class ChatCompletionRequestMessageContentPartText(BaseModel): + type: Type1 = Field(..., description="The type of the content part.") + text: str = Field(..., description="The text content.") + + +class Role(Enum): + system = "system" + + +class ChatCompletionRequestSystemMessage(BaseModel): + content: str = Field(..., description="The contents of the system message.") + role: Role = Field( + ..., description="The role of the messages author, in this case `system`." + ) + + +class Role1(Enum): + user = "user" + + +class Role2(Enum): + assistant = "assistant" + + +class FunctionCall(BaseModel): + arguments: str = Field( + ..., + description="The arguments to call the function with, as generated by the model in JSON format. Note that the model does not always generate valid JSON, and may hallucinate parameters not defined by your function schema. Validate the arguments in your code before calling your function.", + ) + name: str = Field(..., description="The name of the function to call.") + + +class Role3(Enum): + tool = "tool" + + +class ChatCompletionRequestToolMessage(BaseModel): + role: Role3 = Field( + ..., description="The role of the messages author, in this case `tool`." + ) + content: str = Field(..., description="The contents of the tool message.") + tool_call_id: str = Field( + ..., description="Tool call that this message is responding to." + ) + + +class Role4(Enum): + function = "function" + + +class ChatCompletionRequestFunctionMessage(BaseModel): + role: Role4 = Field( + ..., description="The role of the messages author, in this case `function`." + ) + content: str = Field( + ..., + description="The return value from the function call, to return to the model.", + ) + name: str = Field(..., description="The name of the function to call.") + + +class FunctionParameters(BaseModel): + pass + + class Config: + extra = Extra.allow + + +class ChatCompletionFunctions(BaseModel): + description: Optional[str] = Field( + None, + description="A description of what the function does, used by the model to choose when and how to call the function.", + ) + name: str = Field( + ..., + description="The name of the function to be called. Must be a-z, A-Z, 0-9, or contain underscores and dashes, with a maximum length of 64.", + ) + parameters: FunctionParameters + + +class ChatCompletionFunctionCallOption(BaseModel): + name: str = Field(..., description="The name of the function to call.") + + +class Type2(Enum): + function = "function" + + +class FunctionObject(BaseModel): + description: Optional[str] = Field( + None, + description="A description of what the function does, used by the model to choose when and how to call the function.", + ) + name: str = Field( + ..., + description="The name of the function to be called. Must be a-z, A-Z, 0-9, or contain underscores and dashes, with a maximum length of 64.", + ) + parameters: FunctionParameters + + +class ChatCompletionToolChoiceOptionEnum(Enum): + none = "none" + auto = "auto" + + +class Type3(Enum): + function = "function" + + +class Function(BaseModel): + name: str = Field(..., description="The name of the function to call.") + + +class ChatCompletionNamedToolChoice(BaseModel): + type: Optional[Type3] = Field( + None, + description="The type of the tool. Currently, only `function` is supported.", + ) + function: Optional[Function] = None + + +class Type4(Enum): + function = "function" + + +class Function1(BaseModel): + name: str = Field(..., description="The name of the function to call.") + arguments: str = Field( + ..., + description="The arguments to call the function with, as generated by the model in JSON format. Note that the model does not always generate valid JSON, and may hallucinate parameters not defined by your function schema. Validate the arguments in your code before calling your function.", + ) + + +class ChatCompletionMessageToolCall(BaseModel): + id: str = Field(..., description="The ID of the tool call.") + type: Type4 = Field( + ..., + description="The type of the tool. Currently, only `function` is supported.", + ) + function: Function1 = Field(..., description="The function that the model called.") + + +class Type5(Enum): + function = "function" + + +class Function2(BaseModel): + name: Optional[str] = Field(None, description="The name of the function to call.") + arguments: Optional[str] = Field( + None, + description="The arguments to call the function with, as generated by the model in JSON format. Note that the model does not always generate valid JSON, and may hallucinate parameters not defined by your function schema. Validate the arguments in your code before calling your function.", + ) + + +class ChatCompletionMessageToolCallChunk(BaseModel): + index: int + id: Optional[str] = Field(None, description="The ID of the tool call.") + type: Optional[Type5] = Field( + None, + description="The type of the tool. Currently, only `function` is supported.", + ) + function: Optional[Function2] = None + + +class ChatCompletionRole(Enum): + system = "system" + user = "user" + assistant = "assistant" + tool = "tool" + function = "function" + + +class Role5(Enum): + assistant = "assistant" + + +class FunctionCall1(BaseModel): + arguments: str = Field( + ..., + description="The arguments to call the function with, as generated by the model in JSON format. Note that the model does not always generate valid JSON, and may hallucinate parameters not defined by your function schema. Validate the arguments in your code before calling your function.", + ) + name: str = Field(..., description="The name of the function to call.") + + +class FunctionCall2(BaseModel): + arguments: Optional[str] = Field( + None, + description="The arguments to call the function with, as generated by the model in JSON format. Note that the model does not always generate valid JSON, and may hallucinate parameters not defined by your function schema. Validate the arguments in your code before calling your function.", + ) + name: Optional[str] = Field(None, description="The name of the function to call.") + + +class Role6(Enum): + system = "system" + user = "user" + assistant = "assistant" + tool = "tool" + + +class ChatCompletionStreamResponseDelta(BaseModel): + content: Optional[str] = Field( + None, description="The contents of the chunk message." + ) + function_call: Optional[FunctionCall2] = Field( + None, + description="Deprecated and replaced by `tool_calls`. The name and arguments of a function that should be called, as generated by the model.", + ) + tool_calls: Optional[List[ChatCompletionMessageToolCallChunk]] = None + role: Optional[Role6] = Field( + None, description="The role of the author of this message." + ) + + +class ModelEnum1(Enum): + gpt_4_1106_preview = "gpt-4-1106-preview" + gpt_4_vision_preview = "gpt-4-vision-preview" + gpt_4 = "gpt-4" + gpt_4_0314 = "gpt-4-0314" + gpt_4_0613 = "gpt-4-0613" + gpt_4_32k = "gpt-4-32k" + gpt_4_32k_0314 = "gpt-4-32k-0314" + gpt_4_32k_0613 = "gpt-4-32k-0613" + gpt_3_5_turbo_1106 = "gpt-3.5-turbo-1106" + gpt_3_5_turbo = "gpt-3.5-turbo" + gpt_3_5_turbo_16k = "gpt-3.5-turbo-16k" + gpt_3_5_turbo_0301 = "gpt-3.5-turbo-0301" + gpt_3_5_turbo_0613 = "gpt-3.5-turbo-0613" + gpt_3_5_turbo_16k_0613 = "gpt-3.5-turbo-16k-0613" + + +class Type6(Enum): + text = "text" + json_object = "json_object" + + +class ResponseFormat(BaseModel): + type: Optional[Type6] = Field( + "text", + description="Must be one of `text` or `json_object`.", + example="json_object", + ) + + +class FunctionCallEnum(Enum): + none = "none" + auto = "auto" + + +class FinishReason1(Enum): + stop = "stop" + length = "length" + tool_calls = "tool_calls" + content_filter = "content_filter" + function_call = "function_call" + + +class Object2(Enum): + chat_completion = "chat.completion" + + +class FinishReason2(Enum): + stop = "stop" + length = "length" + function_call = "function_call" + content_filter = "content_filter" + + +class Object3(Enum): + chat_completion = "chat.completion" + + +class Object4(Enum): + list = "list" + + +class FinishReason3(Enum): + stop = "stop" + length = "length" + tool_calls = "tool_calls" + content_filter = "content_filter" + function_call = "function_call" + + +class Choice3(BaseModel): + delta: ChatCompletionStreamResponseDelta + finish_reason: FinishReason3 = Field( + ..., + description="The reason the model stopped generating tokens. This will be `stop` if the model hit a natural stop point or a provided stop sequence,\n`length` if the maximum number of tokens specified in the request was reached,\n`content_filter` if content was omitted due to a flag from our content filters,\n`tool_calls` if the model called a tool, or `function_call` (deprecated) if the model called a function.\n", + ) + index: int = Field( + ..., description="The index of the choice in the list of choices." + ) + + +class Object5(Enum): + chat_completion_chunk = "chat.completion.chunk" + + +class CreateChatCompletionStreamResponse(BaseModel): + id: str = Field( + ..., + description="A unique identifier for the chat completion. Each chunk has the same ID.", + ) + choices: List[Choice3] = Field( + ..., + description="A list of chat completion choices. Can be more than one if `n` is greater than 1.", + ) + created: int = Field( + ..., + description="The Unix timestamp (in seconds) of when the chat completion was created. Each chunk has the same timestamp.", + ) + model: str = Field(..., description="The model to generate the completion.") + system_fingerprint: Optional[str] = Field( + None, + description="This fingerprint represents the backend configuration that the model runs with.\nCan be used in conjunction with the `seed` request parameter to understand when backend changes have been made that might impact determinism.\n", + ) + object: Object5 = Field( + ..., description="The object type, which is always `chat.completion.chunk`." + ) + + +class CreateChatCompletionImageResponse(BaseModel): + pass + + +class ModelEnum2(Enum): + text_davinci_edit_001 = "text-davinci-edit-001" + code_davinci_edit_001 = "code-davinci-edit-001" + + +class CreateEditRequest(BaseModel): + instruction: str = Field( + ..., + description="The instruction that tells the model how to edit the prompt.", + example="Fix the spelling mistakes.", + ) + model: Union[str, ModelEnum2] = Field( + ..., + description="ID of the model to use. You can use the `text-davinci-edit-001` or `code-davinci-edit-001` model with this endpoint.", + example="text-davinci-edit-001", + ) + input: Optional[str] = Field( + "", + description="The input text to use as a starting point for the edit.", + example="What day of the wek is it?", + ) + n: Optional[conint(ge=1, le=20)] = Field( + 1, + description="How many edits to generate for the input and instruction.", + example=1, + ) + temperature: Optional[confloat(ge=0.0, le=2.0)] = Field( + 1, + description="What sampling temperature to use, between 0 and 2. Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic.\n\nWe generally recommend altering this or `top_p` but not both.\n", + example=1, + ) + top_p: Optional[confloat(ge=0.0, le=1.0)] = Field( + 1, + description="An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are considered.\n\nWe generally recommend altering this or `temperature` but not both.\n", + example=1, + ) + + +class FinishReason4(Enum): + stop = "stop" + length = "length" + + +class Choice4(BaseModel): + finish_reason: FinishReason4 = Field( + ..., + description="The reason the model stopped generating tokens. This will be `stop` if the model hit a natural stop point or a provided stop sequence,\n`length` if the maximum number of tokens specified in the request was reached,\nor `content_filter` if content was omitted due to a flag from our content filters.\n", + ) + index: int = Field( + ..., description="The index of the choice in the list of choices." + ) + text: str = Field(..., description="The edited result.") + + +class Object6(Enum): + edit = "edit" + + +class ModelEnum3(Enum): + dall_e_2 = "dall-e-2" + dall_e_3 = "dall-e-3" + + +class Quality(Enum): + standard = "standard" + hd = "hd" + + +class ResponseFormat1(Enum): + url = "url" + b64_json = "b64_json" + + +class Size(Enum): + field_256x256 = "256x256" + field_512x512 = "512x512" + field_1024x1024 = "1024x1024" + field_1792x1024 = "1792x1024" + field_1024x1792 = "1024x1792" + + +class Style(Enum): + vivid = "vivid" + natural = "natural" + + +class CreateImageRequest(BaseModel): + prompt: str = Field( + ..., + description="A text description of the desired image(s). The maximum length is 1000 characters for `dall-e-2` and 4000 characters for `dall-e-3`.", + example="A cute baby sea otter", + ) + model: Optional[Union[str, ModelEnum3]] = Field( + "dall-e-2", + description="The model to use for image generation.", + example="dall-e-3", + ) + n: Optional[conint(ge=1, le=10)] = Field( + 1, + description="The number of images to generate. Must be between 1 and 10. For `dall-e-3`, only `n=1` is supported.", + example=1, + ) + quality: Optional[Quality] = Field( + "standard", + description="The quality of the image that will be generated. `hd` creates images with finer details and greater consistency across the image. This param is only supported for `dall-e-3`.", + example="standard", + ) + response_format: Optional[ResponseFormat1] = Field( + "url", + description="The format in which the generated images are returned. Must be one of `url` or `b64_json`.", + example="url", + ) + size: Optional[Size] = Field( + "1024x1024", + description="The size of the generated images. Must be one of `256x256`, `512x512`, or `1024x1024` for `dall-e-2`. Must be one of `1024x1024`, `1792x1024`, or `1024x1792` for `dall-e-3` models.", + example="1024x1024", + ) + style: Optional[Style] = Field( + "vivid", + description="The style of the generated images. Must be one of `vivid` or `natural`. Vivid causes the model to lean towards generating hyper-real and dramatic images. Natural causes the model to produce more natural, less hyper-real looking images. This param is only supported for `dall-e-3`.", + example="vivid", + ) + user: Optional[str] = Field( + None, + description="A unique identifier representing your end-user, which can help OpenAI to monitor and detect abuse. [Learn more](/docs/guides/safety-best-practices/end-user-ids).\n", + example="user-1234", + ) + + +class Image(BaseModel): + b64_json: Optional[str] = Field( + None, + description="The base64-encoded JSON of the generated image, if `response_format` is `b64_json`.", + ) + url: Optional[str] = Field( + None, + description="The URL of the generated image, if `response_format` is `url` (default).", + ) + revised_prompt: Optional[str] = Field( + None, + description="The prompt that was used to generate the image, if there was any revision to the prompt.", + ) + + +class ModelEnum4(Enum): + dall_e_2 = "dall-e-2" + + +class Size1(Enum): + field_256x256 = "256x256" + field_512x512 = "512x512" + field_1024x1024 = "1024x1024" + + +class ResponseFormat2(Enum): + url = "url" + b64_json = "b64_json" + + +class CreateImageEditRequest(BaseModel): + image: bytes = Field( + ..., + description="The image to edit. Must be a valid PNG file, less than 4MB, and square. If mask is not provided, image must have transparency, which will be used as the mask.", + ) + prompt: str = Field( + ..., + description="A text description of the desired image(s). The maximum length is 1000 characters.", + example="A cute baby sea otter wearing a beret", + ) + mask: Optional[bytes] = Field( + None, + description="An additional image whose fully transparent areas (e.g. where alpha is zero) indicate where `image` should be edited. Must be a valid PNG file, less than 4MB, and have the same dimensions as `image`.", + ) + model: Optional[Union[str, ModelEnum4]] = Field( + "dall-e-2", + description="The model to use for image generation. Only `dall-e-2` is supported at this time.", + example="dall-e-2", + ) + n: Optional[conint(ge=1, le=10)] = Field( + 1, + description="The number of images to generate. Must be between 1 and 10.", + example=1, + ) + size: Optional[Size1] = Field( + "1024x1024", + description="The size of the generated images. Must be one of `256x256`, `512x512`, or `1024x1024`.", + example="1024x1024", + ) + response_format: Optional[ResponseFormat2] = Field( + "url", + description="The format in which the generated images are returned. Must be one of `url` or `b64_json`.", + example="url", + ) + user: Optional[str] = Field( + None, + description="A unique identifier representing your end-user, which can help OpenAI to monitor and detect abuse. [Learn more](/docs/guides/safety-best-practices/end-user-ids).\n", + example="user-1234", + ) + + +class ModelEnum5(Enum): + dall_e_2 = "dall-e-2" + + +class ResponseFormat3(Enum): + url = "url" + b64_json = "b64_json" + + +class Size2(Enum): + field_256x256 = "256x256" + field_512x512 = "512x512" + field_1024x1024 = "1024x1024" + + +class CreateImageVariationRequest(BaseModel): + image: bytes = Field( + ..., + description="The image to use as the basis for the variation(s). Must be a valid PNG file, less than 4MB, and square.", + ) + model: Optional[Union[str, ModelEnum5]] = Field( + "dall-e-2", + description="The model to use for image generation. Only `dall-e-2` is supported at this time.", + example="dall-e-2", + ) + n: Optional[conint(ge=1, le=10)] = Field( + 1, + description="The number of images to generate. Must be between 1 and 10. For `dall-e-3`, only `n=1` is supported.", + example=1, + ) + response_format: Optional[ResponseFormat3] = Field( + "url", + description="The format in which the generated images are returned. Must be one of `url` or `b64_json`.", + example="url", + ) + size: Optional[Size2] = Field( + "1024x1024", + description="The size of the generated images. Must be one of `256x256`, `512x512`, or `1024x1024`.", + example="1024x1024", + ) + user: Optional[str] = Field( + None, + description="A unique identifier representing your end-user, which can help OpenAI to monitor and detect abuse. [Learn more](/docs/guides/safety-best-practices/end-user-ids).\n", + example="user-1234", + ) + + +class ModelEnum6(Enum): + text_moderation_latest = "text-moderation-latest" + text_moderation_stable = "text-moderation-stable" + + +class CreateModerationRequest(BaseModel): + input: Union[str, List[str]] = Field(..., description="The input text to classify") + model: Optional[Union[str, ModelEnum6]] = Field( + "text-moderation-latest", + description="Two content moderations models are available: `text-moderation-stable` and `text-moderation-latest`.\n\nThe default is `text-moderation-latest` which will be automatically upgraded over time. This ensures you are always using our most accurate model. If you use `text-moderation-stable`, we will provide advanced notice before updating the model. Accuracy of `text-moderation-stable` may be slightly lower than for `text-moderation-latest`.\n", + example="text-moderation-stable", + ) + + +class Categories(BaseModel): + hate: bool = Field( + ..., + description="Content that expresses, incites, or promotes hate based on race, gender, ethnicity, religion, nationality, sexual orientation, disability status, or caste. Hateful content aimed at non-protected groups (e.g., chess players) is harrassment.", + ) + hate_threatening: bool = Field( + ..., + alias="hate/threatening", + description="Hateful content that also includes violence or serious harm towards the targeted group based on race, gender, ethnicity, religion, nationality, sexual orientation, disability status, or caste.", + ) + harassment: bool = Field( + ..., + description="Content that expresses, incites, or promotes harassing language towards any target.", + ) + harassment_threatening: bool = Field( + ..., + alias="harassment/threatening", + description="Harassment content that also includes violence or serious harm towards any target.", + ) + self_harm: bool = Field( + ..., + alias="self-harm", + description="Content that promotes, encourages, or depicts acts of self-harm, such as suicide, cutting, and eating disorders.", + ) + self_harm_intent: bool = Field( + ..., + alias="self-harm/intent", + description="Content where the speaker expresses that they are engaging or intend to engage in acts of self-harm, such as suicide, cutting, and eating disorders.", + ) + self_harm_instructions: bool = Field( + ..., + alias="self-harm/instructions", + description="Content that encourages performing acts of self-harm, such as suicide, cutting, and eating disorders, or that gives instructions or advice on how to commit such acts.", + ) + sexual: bool = Field( + ..., + description="Content meant to arouse sexual excitement, such as the description of sexual activity, or that promotes sexual services (excluding sex education and wellness).", + ) + sexual_minors: bool = Field( + ..., + alias="sexual/minors", + description="Sexual content that includes an individual who is under 18 years old.", + ) + violence: bool = Field( + ..., description="Content that depicts death, violence, or physical injury." + ) + violence_graphic: bool = Field( + ..., + alias="violence/graphic", + description="Content that depicts death, violence, or physical injury in graphic detail.", + ) + + +class CategoryScores(BaseModel): + hate: float = Field(..., description="The score for the category 'hate'.") + hate_threatening: float = Field( + ..., + alias="hate/threatening", + description="The score for the category 'hate/threatening'.", + ) + harassment: float = Field( + ..., description="The score for the category 'harassment'." + ) + harassment_threatening: float = Field( + ..., + alias="harassment/threatening", + description="The score for the category 'harassment/threatening'.", + ) + self_harm: float = Field( + ..., alias="self-harm", description="The score for the category 'self-harm'." + ) + self_harm_intent: float = Field( + ..., + alias="self-harm/intent", + description="The score for the category 'self-harm/intent'.", + ) + self_harm_instructions: float = Field( + ..., + alias="self-harm/instructions", + description="The score for the category 'self-harm/instructions'.", + ) + sexual: float = Field(..., description="The score for the category 'sexual'.") + sexual_minors: float = Field( + ..., + alias="sexual/minors", + description="The score for the category 'sexual/minors'.", + ) + violence: float = Field(..., description="The score for the category 'violence'.") + violence_graphic: float = Field( + ..., + alias="violence/graphic", + description="The score for the category 'violence/graphic'.", + ) + + +class Result(BaseModel): + flagged: bool = Field( + ..., + description="Whether the content violates [OpenAI's usage policies](/policies/usage-policies).", + ) + categories: Categories = Field( + ..., + description="A list of the categories, and whether they are flagged or not.", + ) + category_scores: CategoryScores = Field( + ..., + description="A list of the categories along with their scores as predicted by model.", + ) + + +class CreateModerationResponse(BaseModel): + id: str = Field( + ..., description="The unique identifier for the moderation request." + ) + model: str = Field( + ..., description="The model used to generate the moderation results." + ) + results: List[Result] = Field(..., description="A list of moderation objects.") + + +class Object7(Enum): + list = "list" + + +class Purpose(Enum): + fine_tune = "fine-tune" + assistants = "assistants" + + +class CreateFileRequest(BaseModel): + class Config: + extra = Extra.forbid + + file: bytes = Field( + ..., description="The File object (not file name) to be uploaded.\n" + ) + purpose: Purpose = Field( + ..., + description='The intended purpose of the uploaded file.\n\nUse "fine-tune" for [Fine-tuning](/docs/api-reference/fine-tuning) and "assistants" for [Assistants](/docs/api-reference/assistants) and [Messages](/docs/api-reference/messages). This allows us to validate the format of the uploaded file is correct for fine-tuning.\n', + ) + + +class Object8(Enum): + file = "file" + + +class DeleteFileResponse(BaseModel): + id: str + object: Object8 + deleted: bool + + +class ModelEnum7(Enum): + babbage_002 = "babbage-002" + davinci_002 = "davinci-002" + gpt_3_5_turbo = "gpt-3.5-turbo" + + +class BatchSizeEnum(Enum): + auto = "auto" + + +class LearningRateMultiplierEnum(Enum): + auto = "auto" + + +class NEpoch(Enum): + auto = "auto" + + +class Hyperparameters(BaseModel): + batch_size: Optional[Union[BatchSizeEnum, conint(ge=1, le=256)]] = Field( + "auto", + description="Number of examples in each batch. A larger batch size means that model parameters\nare updated less frequently, but with lower variance.\n", + ) + learning_rate_multiplier: Optional[ + Union[LearningRateMultiplierEnum, PositiveFloat] + ] = Field( + "auto", + description="Scaling factor for the learning rate. A smaller learning rate may be useful to avoid\noverfitting.\n", + ) + n_epochs: Optional[Union[NEpoch, conint(ge=1, le=50)]] = Field( + "auto", + description="The number of epochs to train the model for. An epoch refers to one full cycle \nthrough the training dataset.\n", + ) + + +class CreateFineTuningJobRequest(BaseModel): + model: Union[str, ModelEnum7] = Field( + ..., + description="The name of the model to fine-tune. You can select one of the\n[supported models](/docs/guides/fine-tuning/what-models-can-be-fine-tuned).\n", + example="gpt-3.5-turbo", + ) + training_file: str = Field( + ..., + description="The ID of an uploaded file that contains training data.\n\nSee [upload file](/docs/api-reference/files/upload) for how to upload a file.\n\nYour dataset must be formatted as a JSONL file. Additionally, you must upload your file with the purpose `fine-tune`.\n\nSee the [fine-tuning guide](/docs/guides/fine-tuning) for more details.\n", + example="file-abc123", + ) + hyperparameters: Optional[Hyperparameters] = Field( + None, description="The hyperparameters used for the fine-tuning job." + ) + suffix: Optional[constr(min_length=1, max_length=40)] = Field( + None, + description='A string of up to 18 characters that will be added to your fine-tuned model name.\n\nFor example, a `suffix` of "custom-model-name" would produce a model name like `ft:gpt-3.5-turbo:openai:custom-model-name:7p4lURel`.\n', + ) + validation_file: Optional[str] = Field( + None, + description="The ID of an uploaded file that contains validation data.\n\nIf you provide this file, the data is used to generate validation\nmetrics periodically during fine-tuning. These metrics can be viewed in\nthe fine-tuning results file.\nThe same data should not be present in both train and validation files.\n\nYour dataset must be formatted as a JSONL file. You must upload your file with the purpose `fine-tune`.\n\nSee the [fine-tuning guide](/docs/guides/fine-tuning) for more details.\n", + example="file-abc123", + ) + + +class Object9(Enum): + list = "list" + + +class NEpoch1(Enum): + auto = "auto" + + +class Hyperparameters1(BaseModel): + n_epochs: Optional[Union[NEpoch1, conint(ge=1, le=50)]] = Field( + "auto", + description="The number of epochs to train the model for. An epoch refers to one\nfull cycle through the training dataset.\n", + ) + + +class ModelEnum8(Enum): + ada = "ada" + babbage = "babbage" + curie = "curie" + davinci = "davinci" + + +class CreateFineTuneRequest(BaseModel): + training_file: str = Field( + ..., + description='The ID of an uploaded file that contains training data.\n\nSee [upload file](/docs/api-reference/files/upload) for how to upload a file.\n\nYour dataset must be formatted as a JSONL file, where each training\nexample is a JSON object with the keys "prompt" and "completion".\nAdditionally, you must upload your file with the purpose `fine-tune`.\n\nSee the [fine-tuning guide](/docs/guides/legacy-fine-tuning/creating-training-data) for more details.\n', + example="file-abc123", + ) + batch_size: Optional[int] = Field( + None, + description="The batch size to use for training. The batch size is the number of\ntraining examples used to train a single forward and backward pass.\n\nBy default, the batch size will be dynamically configured to be\n~0.2% of the number of examples in the training set, capped at 256 -\nin general, we've found that larger batch sizes tend to work better\nfor larger datasets.\n", + ) + classification_betas: Optional[List[float]] = Field( + None, + description="If this is provided, we calculate F-beta scores at the specified\nbeta values. The F-beta score is a generalization of F-1 score.\nThis is only used for binary classification.\n\nWith a beta of 1 (i.e. the F-1 score), precision and recall are\ngiven the same weight. A larger beta score puts more weight on\nrecall and less on precision. A smaller beta score puts more weight\non precision and less on recall.\n", + example=[0.6, 1, 1.5, 2], + ) + classification_n_classes: Optional[int] = Field( + None, + description="The number of classes in a classification task.\n\nThis parameter is required for multiclass classification.\n", + ) + classification_positive_class: Optional[str] = Field( + None, + description="The positive class in binary classification.\n\nThis parameter is needed to generate precision, recall, and F1\nmetrics when doing binary classification.\n", + ) + compute_classification_metrics: Optional[bool] = Field( + False, + description="If set, we calculate classification-specific metrics such as accuracy\nand F-1 score using the validation set at the end of every epoch.\nThese metrics can be viewed in the [results file](/docs/guides/legacy-fine-tuning/analyzing-your-fine-tuned-model).\n\nIn order to compute classification metrics, you must provide a\n`validation_file`. Additionally, you must\nspecify `classification_n_classes` for multiclass classification or\n`classification_positive_class` for binary classification.\n", + ) + hyperparameters: Optional[Hyperparameters1] = Field( + None, description="The hyperparameters used for the fine-tuning job." + ) + learning_rate_multiplier: Optional[float] = Field( + None, + description="The learning rate multiplier to use for training.\nThe fine-tuning learning rate is the original learning rate used for\npretraining multiplied by this value.\n\nBy default, the learning rate multiplier is the 0.05, 0.1, or 0.2\ndepending on final `batch_size` (larger learning rates tend to\nperform better with larger batch sizes). We recommend experimenting\nwith values in the range 0.02 to 0.2 to see what produces the best\nresults.\n", + ) + model: Optional[Union[str, ModelEnum8]] = Field( + "curie", + description='The name of the base model to fine-tune. You can select one of "ada",\n"babbage", "curie", "davinci", or a fine-tuned model created after 2022-04-21 and before 2023-08-22.\nTo learn more about these models, see the\n[Models](/docs/models) documentation.\n', + example="curie", + ) + prompt_loss_weight: Optional[float] = Field( + 0.01, + description="The weight to use for loss on the prompt tokens. This controls how\nmuch the model tries to learn to generate the prompt (as compared\nto the completion which always has a weight of 1.0), and can add\na stabilizing effect to training when completions are short.\n\nIf prompts are extremely long (relative to completions), it may make\nsense to reduce this weight so as to avoid over-prioritizing\nlearning the prompt.\n", + ) + suffix: Optional[constr(min_length=1, max_length=40)] = Field( + None, + description='A string of up to 40 characters that will be added to your fine-tuned model name.\n\nFor example, a `suffix` of "custom-model-name" would produce a model name like `ada:ft-your-org:custom-model-name-2022-02-15-04-21-04`.\n', + ) + validation_file: Optional[str] = Field( + None, + description='The ID of an uploaded file that contains validation data.\n\nIf you provide this file, the data is used to generate validation\nmetrics periodically during fine-tuning. These metrics can be viewed in\nthe [fine-tuning results file](/docs/guides/legacy-fine-tuning/analyzing-your-fine-tuned-model).\nYour train and validation data should be mutually exclusive.\n\nYour dataset must be formatted as a JSONL file, where each validation\nexample is a JSON object with the keys "prompt" and "completion".\nAdditionally, you must upload your file with the purpose `fine-tune`.\n\nSee the [fine-tuning guide](/docs/guides/legacy-fine-tuning/creating-training-data) for more details.\n', + example="file-abc123", + ) + + +class Object10(Enum): + list = "list" + + +class Object11(Enum): + list = "list" + + +class InputItem(BaseModel): + __root__: List[Any] + + +class ModelEnum9(Enum): + text_embedding_ada_002 = "text-embedding-ada-002" + + +class EncodingFormat(Enum): + float = "float" + base64 = "base64" + + +class CreateEmbeddingRequest(BaseModel): + class Config: + extra = Extra.forbid + + input: Union[str, List[str], List[int], List[InputItem]] = Field( + ..., + description="Input text to embed, encoded as a string or array of tokens. To embed multiple inputs in a single request, pass an array of strings or array of token arrays. The input must not exceed the max input tokens for the model (8192 tokens for `text-embedding-ada-002`) and cannot be an empty string. [Example Python code](https://cookbook.openai.com/examples/how_to_count_tokens_with_tiktoken) for counting tokens.\n", + example="The quick brown fox jumped over the lazy dog", + ) + model: Union[str, ModelEnum9] = Field( + ..., + description="ID of the model to use. You can use the [List models](/docs/api-reference/models/list) API to see all of your available models, or see our [Model overview](/docs/models/overview) for descriptions of them.\n", + example="text-embedding-ada-002", + ) + encoding_format: Optional[EncodingFormat] = Field( + "float", + description="The format to return the embeddings in. Can be either `float` or [`base64`](https://pypi.org/project/pybase64/).", + example="float", + ) + user: Optional[str] = Field( + None, + description="A unique identifier representing your end-user, which can help OpenAI to monitor and detect abuse. [Learn more](/docs/guides/safety-best-practices/end-user-ids).\n", + example="user-1234", + ) + + +class Object12(Enum): + list = "list" + + +class Usage(BaseModel): + prompt_tokens: int = Field( + ..., description="The number of tokens used by the prompt." + ) + total_tokens: int = Field( + ..., description="The total number of tokens used by the request." + ) + + +class ModelEnum10(Enum): + whisper_1 = "whisper-1" + + +class ResponseFormat4(Enum): + json = "json" + text = "text" + srt = "srt" + verbose_json = "verbose_json" + vtt = "vtt" + + +class CreateTranscriptionRequest(BaseModel): + class Config: + extra = Extra.forbid + + file: bytes = Field( + ..., + description="The audio file object (not file name) to transcribe, in one of these formats: flac, mp3, mp4, mpeg, mpga, m4a, ogg, wav, or webm.\n", + ) + model: Union[str, ModelEnum10] = Field( + ..., + description="ID of the model to use. Only `whisper-1` is currently available.\n", + example="whisper-1", + ) + language: Optional[str] = Field( + None, + description="The language of the input audio. Supplying the input language in [ISO-639-1](https://en.wikipedia.org/wiki/List_of_ISO_639-1_codes) format will improve accuracy and latency.\n", + ) + prompt: Optional[str] = Field( + None, + description="An optional text to guide the model's style or continue a previous audio segment. The [prompt](/docs/guides/speech-to-text/prompting) should match the audio language.\n", + ) + response_format: Optional[ResponseFormat4] = Field( + "json", + description="The format of the transcript output, in one of these options: `json`, `text`, `srt`, `verbose_json`, or `vtt`.\n", + ) + temperature: Optional[float] = Field( + 0, + description="The sampling temperature, between 0 and 1. Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic. If set to 0, the model will use [log probability](https://en.wikipedia.org/wiki/Log_probability) to automatically increase the temperature until certain thresholds are hit.\n", + ) + + +class CreateTranscriptionResponse(BaseModel): + text: str + + +class ModelEnum11(Enum): + whisper_1 = "whisper-1" + + +class CreateTranslationRequest(BaseModel): + class Config: + extra = Extra.forbid + + file: bytes = Field( + ..., + description="The audio file object (not file name) translate, in one of these formats: flac, mp3, mp4, mpeg, mpga, m4a, ogg, wav, or webm.\n", + ) + model: Union[str, ModelEnum11] = Field( + ..., + description="ID of the model to use. Only `whisper-1` is currently available.\n", + example="whisper-1", + ) + prompt: Optional[str] = Field( + None, + description="An optional text to guide the model's style or continue a previous audio segment. The [prompt](/docs/guides/speech-to-text/prompting) should be in English.\n", + ) + response_format: Optional[str] = Field( + "json", + description="The format of the transcript output, in one of these options: `json`, `text`, `srt`, `verbose_json`, or `vtt`.\n", + ) + temperature: Optional[float] = Field( + 0, + description="The sampling temperature, between 0 and 1. Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic. If set to 0, the model will use [log probability](https://en.wikipedia.org/wiki/Log_probability) to automatically increase the temperature until certain thresholds are hit.\n", + ) + + +class CreateTranslationResponse(BaseModel): + text: str + + +class ModelEnum12(Enum): + tts_1 = "tts-1" + tts_1_hd = "tts-1-hd" + + +class Voice(Enum): + alloy = "alloy" + echo = "echo" + fable = "fable" + onyx = "onyx" + nova = "nova" + shimmer = "shimmer" + + +class ResponseFormat5(Enum): + mp3 = "mp3" + opus = "opus" + aac = "aac" + flac = "flac" + + +class CreateSpeechRequest(BaseModel): + class Config: + extra = Extra.forbid + + model: Union[str, ModelEnum12] = Field( + ..., + description="One of the available [TTS models](/docs/models/tts): `tts-1` or `tts-1-hd`\n", + ) + input: constr(max_length=4096) = Field( + ..., + description="The text to generate audio for. The maximum length is 4096 characters.", + ) + voice: Voice = Field( + ..., + description="The voice to use when generating the audio. Supported voices are `alloy`, `echo`, `fable`, `onyx`, `nova`, and `shimmer`.", + ) + response_format: Optional[ResponseFormat5] = Field( + "mp3", + description="The format to audio in. Supported formats are `mp3`, `opus`, `aac`, and `flac`.", + ) + speed: Optional[confloat(ge=0.25, le=4.0)] = Field( + 1.0, + description="The speed of the generated audio. Select a value from `0.25` to `4.0`. `1.0` is the default.", + ) + + +class Object13(Enum): + model = "model" + + +class Model(BaseModel): + id: str = Field( + ..., + description="The model identifier, which can be referenced in the API endpoints.", + ) + created: int = Field( + ..., description="The Unix timestamp (in seconds) when the model was created." + ) + object: Object13 = Field( + ..., description='The object type, which is always "model".' + ) + owned_by: str = Field(..., description="The organization that owns the model.") + + +class Object14(Enum): + file = "file" + + +class Purpose1(Enum): + fine_tune = "fine_tune" + fine_tune_results = "fine_tune_results" + assistants = "assistants" + assistants_output = "assistants_output" + + +class Status(Enum): + uploaded = "uploaded" + processed = "processed" + error = "error" + + +class OpenAIFile(BaseModel): + id: str = Field( + ..., + description="The file identifier, which can be referenced in the API endpoints.", + ) + bytes: int = Field(..., description="The size of the file, in bytes.") + created_at: int = Field( + ..., + description="The Unix timestamp (in seconds) for when the file was created.", + ) + filename: str = Field(..., description="The name of the file.") + object: Object14 = Field( + ..., description="The object type, which is always `file`." + ) + purpose: Purpose1 = Field( + ..., + description="The intended purpose of the file. Supported values are `fine-tune`, `fine-tune-results`, `assistants`, and `assistants_output`.", + ) + status: Status = Field( + ..., + description="Deprecated. The current status of the file, which can be either `uploaded`, `processed`, or `error`.", + ) + status_details: Optional[str] = Field( + None, + description="Deprecated. For details on why a fine-tuning training file failed validation, see the `error` field on `fine_tuning.job`.", + ) + + +class Object15(Enum): + embedding = "embedding" + + +class Embedding(BaseModel): + index: int = Field( + ..., description="The index of the embedding in the list of embeddings." + ) + embedding: List[float] = Field( + ..., + description="The embedding vector, which is a list of floats. The length of vector depends on the model as listed in the [embedding guide](/docs/guides/embeddings).\n", + ) + object: Object15 = Field( + ..., description='The object type, which is always "embedding".' + ) + + +class Error1(BaseModel): + code: str = Field(..., description="A machine-readable error code.") + message: str = Field(..., description="A human-readable error message.") + param: str = Field( + ..., + description="The parameter that was invalid, usually `training_file` or `validation_file`. This field will be null if the failure was not parameter-specific.", + ) + + +class NEpoch2(Enum): + auto = "auto" + + +class Hyperparameters2(BaseModel): + n_epochs: Union[NEpoch2, conint(ge=1, le=50)] = Field( + ..., + description='The number of epochs to train the model for. An epoch refers to one full cycle through the training dataset.\n"auto" decides the optimal number of epochs based on the size of the dataset. If setting the number manually, we support any number between 1 and 50 epochs.', + ) + + +class Object16(Enum): + fine_tuning_job = "fine_tuning.job" + + +class Status1(Enum): + validating_files = "validating_files" + queued = "queued" + running = "running" + succeeded = "succeeded" + failed = "failed" + cancelled = "cancelled" + + +class FineTuningJob(BaseModel): + id: str = Field( + ..., + description="The object identifier, which can be referenced in the API endpoints.", + ) + created_at: int = Field( + ..., + description="The Unix timestamp (in seconds) for when the fine-tuning job was created.", + ) + error: Error1 = Field( + ..., + description="For fine-tuning jobs that have `failed`, this will contain more information on the cause of the failure.", + ) + fine_tuned_model: str = Field( + ..., + description="The name of the fine-tuned model that is being created. The value will be null if the fine-tuning job is still running.", + ) + finished_at: int = Field( + ..., + description="The Unix timestamp (in seconds) for when the fine-tuning job was finished. The value will be null if the fine-tuning job is still running.", + ) + hyperparameters: Hyperparameters2 = Field( + ..., + description="The hyperparameters used for the fine-tuning job. See the [fine-tuning guide](/docs/guides/fine-tuning) for more details.", + ) + model: str = Field(..., description="The base model that is being fine-tuned.") + object: Object16 = Field( + ..., description='The object type, which is always "fine_tuning.job".' + ) + organization_id: str = Field( + ..., description="The organization that owns the fine-tuning job." + ) + result_files: List[str] = Field( + ..., + description="The compiled results file ID(s) for the fine-tuning job. You can retrieve the results with the [Files API](/docs/api-reference/files/retrieve-contents).", + ) + status: Status1 = Field( + ..., + description="The current status of the fine-tuning job, which can be either `validating_files`, `queued`, `running`, `succeeded`, `failed`, or `cancelled`.", + ) + trained_tokens: int = Field( + ..., + description="The total number of billable tokens processed by this fine-tuning job. The value will be null if the fine-tuning job is still running.", + ) + training_file: str = Field( + ..., + description="The file ID used for training. You can retrieve the training data with the [Files API](/docs/api-reference/files/retrieve-contents).", + ) + validation_file: str = Field( + ..., + description="The file ID used for validation. You can retrieve the validation results with the [Files API](/docs/api-reference/files/retrieve-contents).", + ) + + +class Level(Enum): + info = "info" + warn = "warn" + error = "error" + + +class Object17(Enum): + fine_tuning_job_event = "fine_tuning.job.event" + + +class FineTuningJobEvent(BaseModel): + id: str + created_at: int + level: Level + message: str + object: Object17 + + +class Hyperparams(BaseModel): + batch_size: int = Field( + ..., + description="The batch size to use for training. The batch size is the number of\ntraining examples used to train a single forward and backward pass.\n", + ) + classification_n_classes: Optional[int] = Field( + None, + description="The number of classes to use for computing classification metrics.\n", + ) + classification_positive_class: Optional[str] = Field( + None, + description="The positive class to use for computing classification metrics.\n", + ) + compute_classification_metrics: Optional[bool] = Field( + None, + description="The classification metrics to compute using the validation dataset at the end of every epoch.\n", + ) + learning_rate_multiplier: float = Field( + ..., description="The learning rate multiplier to use for training.\n" + ) + n_epochs: int = Field( + ..., + description="The number of epochs to train the model for. An epoch refers to one\nfull cycle through the training dataset.\n", + ) + prompt_loss_weight: float = Field( + ..., description="The weight to use for loss on the prompt tokens.\n" + ) + + +class Object18(Enum): + fine_tune = "fine-tune" + + +class Object19(Enum): + fine_tune_event = "fine-tune-event" + + +class FineTuneEvent(BaseModel): + created_at: int + level: str + message: str + object: Object19 + + +class CompletionUsage(BaseModel): + completion_tokens: int = Field( + ..., description="Number of tokens in the generated completion." + ) + prompt_tokens: int = Field(..., description="Number of tokens in the prompt.") + total_tokens: int = Field( + ..., + description="Total number of tokens used in the request (prompt + completion).", + ) + + +class Object20(Enum): + assistant = "assistant" + + +class Object21(Enum): + assistant_deleted = "assistant.deleted" + + +class DeleteAssistantResponse(BaseModel): + id: str + deleted: bool + object: Object21 + + +class Type7(Enum): + code_interpreter = "code_interpreter" + + +class AssistantToolsCode(BaseModel): + type: Type7 = Field( + ..., description="The type of tool being defined: `code_interpreter`" + ) + + +class Type8(Enum): + retrieval = "retrieval" + + +class AssistantToolsRetrieval(BaseModel): + type: Type8 = Field(..., description="The type of tool being defined: `retrieval`") + + +class Type824(Enum): + retrieval = "browser" + + +class AssistantToolsBrowser(BaseModel): + type: Type824 = Field(..., description="The type of tool being defined: `browser`") + + +class Type9(Enum): + function = "function" + + +class AssistantToolsFunction(BaseModel): + type: Type9 = Field(..., description="The type of tool being defined: `function`") + function: FunctionObject + + +class Object22(Enum): + thread_run = "thread.run" + + +class Status2(Enum): + queued = "queued" + in_progress = "in_progress" + requires_action = "requires_action" + cancelling = "cancelling" + cancelled = "cancelled" + failed = "failed" + completed = "completed" + expired = "expired" + + +class Type10(Enum): + submit_tool_outputs = "submit_tool_outputs" + + +class Code(Enum): + server_error = "server_error" + rate_limit_exceeded = "rate_limit_exceeded" + + +class LastError(BaseModel): + code: Code = Field( + ..., description="One of `server_error` or `rate_limit_exceeded`." + ) + message: str = Field(..., description="A human-readable description of the error.") + + +class CreateRunRequest(BaseModel): + class Config: + extra = Extra.forbid + + assistant_id: str = Field( + ..., + description="The ID of the [assistant](/docs/api-reference/assistants) to use to execute this run.", + ) + model: Optional[str] = Field( + None, + description="The ID of the [Model](/docs/api-reference/models) to be used to execute this run. If a value is provided here, it will override the model associated with the assistant. If not, the model associated with the assistant will be used.", + ) + instructions: Optional[str] = Field( + None, + description="Override the default system message of the assistant. This is useful for modifying the behavior on a per-run basis.", + ) + tools: Optional[ + List[ + Union[ + AssistantToolsCode, + AssistantToolsRetrieval, + AssistantToolsFunction, + AssistantToolsBrowser, + ] + ] + ] = Field( + None, + description="Override the tools the assistant can use for this run. This is useful for modifying the behavior on a per-run basis.", + max_items=20, + ) + metadata: Optional[Dict[str, Any]] = Field( + None, + description="Set of 16 key-value pairs that can be attached to an object. This can be useful for storing additional information about the object in a structured format. Keys can be a maximum of 64 characters long and values can be a maxium of 512 characters long.\n", + ) + + +class ModifyRunRequest(BaseModel): + class Config: + extra = Extra.forbid + + metadata: Optional[Dict[str, Any]] = Field( + None, + description="Set of 16 key-value pairs that can be attached to an object. This can be useful for storing additional information about the object in a structured format. Keys can be a maximum of 64 characters long and values can be a maxium of 512 characters long.\n", + ) + + +class ToolOutput(BaseModel): + tool_call_id: Optional[str] = Field( + None, + description="The ID of the tool call in the `required_action` object within the run object the output is being submitted for.", + ) + output: Optional[str] = Field( + None, + description="The output of the tool call to be submitted to continue the run.", + ) + + +class SubmitToolOutputsRunRequest(BaseModel): + class Config: + extra = Extra.forbid + + tool_outputs: List[ToolOutput] = Field( + ..., description="A list of tools for which the outputs are being submitted." + ) + + +class Type11(Enum): + function = "function" + + +class Function3(BaseModel): + name: str = Field(..., description="The name of the function.") + arguments: str = Field( + ..., + description="The arguments that the model expects you to pass to the function.", + ) + + +class RunToolCallObject(BaseModel): + id: str = Field( + ..., + description="The ID of the tool call. This ID must be referenced when you submit the tool outputs in using the [Submit tool outputs to run](/docs/api-reference/runs/submitToolOutputs) endpoint.", + ) + type: Type11 = Field( + ..., + description="The type of tool call the output is required for. For now, this is always `function`.", + ) + function: Function3 = Field(..., description="The function definition.") + + +class Object23(Enum): + thread = "thread" + + +# class ThreadObject(BaseModel): +# id: str = Field( +# ..., description="The identifier, which can be referenced in API endpoints." +# ) +# object: Object23 = Field( +# ..., description="The object type, which is always `thread`." +# ) +# created_at: int = Field( +# ..., +# description="The Unix timestamp (in seconds) for when the thread was created.", +# ) +# metadata: Dict[str, Any] = Field( +# ..., +# description="Set of 16 key-value pairs that can be attached to an object. This can be useful for storing additional information about the object in a structured format. Keys can be a maximum of 64 characters long and values can be a maxium of 512 characters long.\n", +# ) + + +class ModifyThreadRequest(BaseModel): + class Config: + extra = Extra.forbid + + metadata: Optional[Dict[str, Any]] = Field( + None, + description="Set of 16 key-value pairs that can be attached to an object. This can be useful for storing additional information about the object in a structured format. Keys can be a maximum of 64 characters long and values can be a maxium of 512 characters long.\n", + ) + + +class Object24(Enum): + thread_deleted = "thread.deleted" + + +class DeleteThreadResponse(BaseModel): + id: str + deleted: bool + object: Object24 + + +class ListThreadsResponse(BaseModel): + object: str = Field(..., example="list") + data: List[ThreadObject] + first_id: str = Field(..., example="asst_hLBK7PXBv5Lr2NQT7KLY0ag1") + last_id: str = Field(..., example="asst_QLoItBbqwyAJEzlTy4y9kOMM") + has_more: bool = Field(..., example=False) + + +class Object25(Enum): + thread_message = "thread.message" + + +class Role7(Enum): + user = "user" + assistant = "assistant" + tool_call = "tool_call" + tool_output = "tool_output" + + +class Role8(Enum): + user = "user" + + +class CreateMessageRequest(BaseModel): + class Config: + extra = Extra.forbid + + role: Role8 = Field( + ..., + description="The role of the entity that is creating the message. Currently only `user` is supported.", + ) + content: constr(min_length=1, max_length=32768) = Field( + ..., description="The content of the message." + ) + file_ids: Optional[List[str]] = Field( + [], + description="A list of [File](/docs/api-reference/files) IDs that the message should use. There can be a maximum of 10 files attached to a message. Useful for tools like `retrieval` and `code_interpreter` that can access and use files.", + max_items=10, + min_items=1, + ) + metadata: Optional[Dict[str, Any]] = Field( + None, + description="Set of 16 key-value pairs that can be attached to an object. This can be useful for storing additional information about the object in a structured format. Keys can be a maximum of 64 characters long and values can be a maxium of 512 characters long.\n", + ) + + +class ModifyMessageRequest(BaseModel): + class Config: + extra = Extra.forbid + + metadata: Optional[Dict[str, Any]] = Field( + None, + description="Set of 16 key-value pairs that can be attached to an object. This can be useful for storing additional information about the object in a structured format. Keys can be a maximum of 64 characters long and values can be a maxium of 512 characters long.\n", + ) + + +class Object26(Enum): + thread_message_deleted = "thread.message.deleted" + + +class DeleteMessageResponse(BaseModel): + id: str + deleted: bool + object: Object26 + + +class Type12(Enum): + image_file = "image_file" + + +class ImageFile(BaseModel): + file_id: str = Field( + ..., + description="The [File](/docs/api-reference/files) ID of the image in the message content.", + ) + + +class MessageContentImageFileObject(BaseModel): + type: Type12 = Field(..., description="Always `image_file`.") + image_file: ImageFile + + +class Type13(Enum): + text = "text" + + +class Type14(Enum): + file_citation = "file_citation" + + +class FileCitation(BaseModel): + file_id: str = Field( + ..., description="The ID of the specific File the citation is from." + ) + quote: str = Field(..., description="The specific quote in the file.") + + +class MessageContentTextAnnotationsFileCitationObject(BaseModel): + type: Type14 = Field(..., description="Always `file_citation`.") + text: str = Field( + ..., description="The text in the message content that needs to be replaced." + ) + file_citation: FileCitation + start_index: conint(ge=0) + end_index: conint(ge=0) + + +class Type15(Enum): + file_path = "file_path" + + +class FilePath(BaseModel): + file_id: str = Field(..., description="The ID of the file that was generated.") + + +class MessageContentTextAnnotationsFilePathObject(BaseModel): + type: Type15 = Field(..., description="Always `file_path`.") + text: str = Field( + ..., description="The text in the message content that needs to be replaced." + ) + file_path: FilePath + start_index: conint(ge=0) + end_index: conint(ge=0) + + +class Object27(Enum): + thread_run_step = "thread.run.step" + + +class Type16(Enum): + message_creation = "message_creation" + tool_calls = "tool_calls" + + +class Status3(Enum): + in_progress = "in_progress" + cancelled = "cancelled" + failed = "failed" + completed = "completed" + expired = "expired" + + +class Code1(Enum): + server_error = "server_error" + rate_limit_exceeded = "rate_limit_exceeded" + + +class LastError1(BaseModel): + code: Code1 = Field( + ..., description="One of `server_error` or `rate_limit_exceeded`." + ) + message: str = Field(..., description="A human-readable description of the error.") + + +class Type17(Enum): + message_creation = "message_creation" + + +class MessageCreation(BaseModel): + message_id: str = Field( + ..., description="The ID of the message that was created by this run step." + ) + + +class RunStepDetailsMessageCreationObject(BaseModel): + type: Type17 = Field(..., description="Always `message_creation``.") + message_creation: MessageCreation + + +class Type18(Enum): + tool_calls = "tool_calls" + + +class Type19(Enum): + code_interpreter = "code_interpreter" + + +class Type20(Enum): + logs = "logs" + + +class RunStepDetailsToolCallsCodeOutputLogsObject(BaseModel): + type: Type20 = Field(..., description="Always `logs`.") + logs: str = Field( + ..., description="The text output from the Code Interpreter tool call." + ) + + +class Type21(Enum): + image = "image" + + +class Image1(BaseModel): + file_id: str = Field( + ..., description="The [file](/docs/api-reference/files) ID of the image." + ) + + +class RunStepDetailsToolCallsCodeOutputImageObject(BaseModel): + type: Type21 = Field(..., description="Always `image`.") + image: Image1 + + +class Type22(Enum): + retrieval = "retrieval" + + +class RunStepDetailsToolCallsRetrievalObject(BaseModel): + id: str = Field(..., description="The ID of the tool call object.") + type: Type22 = Field( + ..., + description="The type of tool call. This is always going to be `retrieval` for this type of tool call.", + ) + retrieval: Dict[str, Any] = Field( + ..., description="For now, this is always going to be an empty object." + ) + + +class Type23(Enum): + function = "function" + + +class Function4(BaseModel): + name: str = Field(..., description="The name of the function.") + arguments: str = Field(..., description="The arguments passed to the function.") + output: str = Field( + ..., + description="The output of the function. This will be `null` if the outputs have not been [submitted](/docs/api-reference/runs/submitToolOutputs) yet.", + ) + + +class RunStepDetailsToolCallsFunctionObject(BaseModel): + id: str = Field(..., description="The ID of the tool call object.") + type: Type23 = Field( + ..., + description="The type of tool call. This is always going to be `function` for this type of tool call.", + ) + function: Function4 = Field( + ..., description="The definition of the function that was called." + ) + + +class Object28(Enum): + assistant_file = "assistant.file" + + +class AssistantFileObject(BaseModel): + id: str = Field( + ..., description="The identifier, which can be referenced in API endpoints." + ) + object: Object28 = Field( + ..., description="The object type, which is always `assistant.file`." + ) + created_at: int = Field( + ..., + description="The Unix timestamp (in seconds) for when the assistant file was created.", + ) + assistant_id: str = Field( + ..., description="The assistant ID that the file is attached to." + ) + + +class CreateAssistantFileRequest(BaseModel): + class Config: + extra = Extra.forbid + + file_id: str = Field( + ..., + description='A [File](/docs/api-reference/files) ID (with `purpose="assistants"`) that the assistant should use. Useful for tools like `retrieval` and `code_interpreter` that can access files.', + ) + + +class Object29(Enum): + assistant_file_deleted = "assistant.file.deleted" + + +class DeleteAssistantFileResponse(BaseModel): + id: str + deleted: bool + object: Object29 + + +class ListAssistantFilesResponse(BaseModel): + object: str = Field(..., example="list") + data: List[AssistantFileObject] + first_id: str = Field(..., example="file-hLBK7PXBv5Lr2NQT7KLY0ag1") + last_id: str = Field(..., example="file-QLoItBbqwyAJEzlTy4y9kOMM") + has_more: bool = Field(..., example=False) + + +class Object30(Enum): + thread_message_file = "thread.message.file" + + +class MessageFileObject(BaseModel): + id: str = Field( + ..., description="The identifier, which can be referenced in API endpoints." + ) + object: Object30 = Field( + ..., description="The object type, which is always `thread.message.file`." + ) + created_at: int = Field( + ..., + description="The Unix timestamp (in seconds) for when the message file was created.", + ) + message_id: str = Field( + ..., + description="The ID of the [message](/docs/api-reference/messages) that the [File](/docs/api-reference/files) is attached to.", + ) + + +class ListMessageFilesResponse(BaseModel): + object: str = Field(..., example="list") + data: List[MessageFileObject] + first_id: str = Field(..., example="file-hLBK7PXBv5Lr2NQT7KLY0ag1") + last_id: str = Field(..., example="file-QLoItBbqwyAJEzlTy4y9kOMM") + has_more: bool = Field(..., example=False) + + +class Order(Enum): + asc = "asc" + desc = "desc" + + +class Order1(Enum): + asc = "asc" + desc = "desc" + + +class Order2(Enum): + asc = "asc" + desc = "desc" + + +class Order3(Enum): + asc = "asc" + desc = "desc" + + +class Order4(Enum): + asc = "asc" + desc = "desc" + + +class Order5(Enum): + asc = "asc" + desc = "desc" + + +class Order6(Enum): + asc = "asc" + desc = "desc" + + +class Order7(Enum): + asc = "asc" + desc = "desc" + + +class Order8(Enum): + asc = "asc" + desc = "desc" + + +class Order9(Enum): + asc = "asc" + desc = "desc" + + +class Order10(Enum): + asc = "asc" + desc = "desc" + + +class Order11(Enum): + asc = "asc" + desc = "desc" + + +class ListModelsResponse(BaseModel): + object: Object + data: List[Model] + + +class CreateCompletionResponse(BaseModel): + id: str = Field(..., description="A unique identifier for the completion.") + choices: List[Choice] = Field( + ..., + description="The list of completion choices the model generated for the input prompt.", + ) + created: int = Field( + ..., + description="The Unix timestamp (in seconds) of when the completion was created.", + ) + model: str = Field(..., description="The model used for completion.") + system_fingerprint: Optional[str] = Field( + None, + description="This fingerprint represents the backend configuration that the model runs with.\n\nCan be used in conjunction with the `seed` request parameter to understand when backend changes have been made that might impact determinism.\n", + ) + object: Object1 = Field( + ..., description='The object type, which is always "text_completion"' + ) + usage: Optional[CompletionUsage] = None + + +class ChatCompletionRequestMessageContentPart(BaseModel): + __root__: Union[ + ChatCompletionRequestMessageContentPartText, + ChatCompletionRequestMessageContentPartImage, + ] + + +class ChatCompletionRequestUserMessage(BaseModel): + content: Union[str, List[ChatCompletionRequestMessageContentPart]] = Field( + ..., description="The contents of the user message.\n" + ) + role: Role1 = Field( + ..., description="The role of the messages author, in this case `user`." + ) + + +class ChatCompletionTool(BaseModel): + type: Type2 = Field( + ..., + description="The type of the tool. Currently, only `function` is supported.", + ) + function: FunctionObject + + +class ChatCompletionToolChoiceOption(BaseModel): + __root__: Union[ + ChatCompletionToolChoiceOptionEnum, ChatCompletionNamedToolChoice + ] = Field( + ..., + description='Controls which (if any) function is called by the model.\n`none` means the model will not call a function and instead generates a message.\n`auto` means the model can pick between generating a message or calling a function.\nSpecifying a particular function via `{"type: "function", "function": {"name": "my_function"}}` forces the model to call that function.\n\n`none` is the default when no functions are present. `auto` is the default if functions are present.\n', + ) + + +class ChatCompletionMessageToolCalls(BaseModel): + __root__: List[ChatCompletionMessageToolCall] = Field( + ..., + description="The tool calls generated by the model, such as function calls.", + ) + + +class ChatCompletionResponseMessage(BaseModel): + content: str = Field(..., description="The contents of the message.") + tool_calls: Optional[ChatCompletionMessageToolCalls] = None + role: Role5 = Field(..., description="The role of the author of this message.") + function_call: Optional[FunctionCall1] = Field( + None, + description="Deprecated and replaced by `tool_calls`. The name and arguments of a function that should be called, as generated by the model.", + ) + + +class Choice1(BaseModel): + finish_reason: FinishReason1 = Field( + ..., + description="The reason the model stopped generating tokens. This will be `stop` if the model hit a natural stop point or a provided stop sequence,\n`length` if the maximum number of tokens specified in the request was reached,\n`content_filter` if content was omitted due to a flag from our content filters,\n`tool_calls` if the model called a tool, or `function_call` (deprecated) if the model called a function.\n", + ) + index: int = Field( + ..., description="The index of the choice in the list of choices." + ) + message: ChatCompletionResponseMessage + + +class CreateChatCompletionResponse(BaseModel): + id: str = Field(..., description="A unique identifier for the chat completion.") + choices: List[Choice1] = Field( + ..., + description="A list of chat completion choices. Can be more than one if `n` is greater than 1.", + ) + created: int = Field( + ..., + description="The Unix timestamp (in seconds) of when the chat completion was created.", + ) + model: str = Field(..., description="The model used for the chat completion.") + system_fingerprint: Optional[str] = Field( + None, + description="This fingerprint represents the backend configuration that the model runs with.\n\nCan be used in conjunction with the `seed` request parameter to understand when backend changes have been made that might impact determinism.\n", + ) + object: Object2 = Field( + ..., description="The object type, which is always `chat.completion`." + ) + usage: Optional[CompletionUsage] = None + + +class Choice2(BaseModel): + finish_reason: FinishReason2 = Field( + ..., + description="The reason the model stopped generating tokens. This will be `stop` if the model hit a natural stop point or a provided stop sequence, `length` if the maximum number of tokens specified in the request was reached, `content_filter` if content was omitted due to a flag from our content filters, or `function_call` if the model called a function.\n", + ) + index: int = Field( + ..., description="The index of the choice in the list of choices." + ) + message: ChatCompletionResponseMessage + + +class CreateChatCompletionFunctionResponse(BaseModel): + id: str = Field(..., description="A unique identifier for the chat completion.") + choices: List[Choice2] = Field( + ..., + description="A list of chat completion choices. Can be more than one if `n` is greater than 1.", + ) + created: int = Field( + ..., + description="The Unix timestamp (in seconds) of when the chat completion was created.", + ) + model: str = Field(..., description="The model used for the chat completion.") + system_fingerprint: Optional[str] = Field( + None, + description="This fingerprint represents the backend configuration that the model runs with.\n\nCan be used in conjunction with the `seed` request parameter to understand when backend changes have been made that might impact determinism.\n", + ) + object: Object3 = Field( + ..., description="The object type, which is always `chat.completion`." + ) + usage: Optional[CompletionUsage] = None + + +class ListPaginatedFineTuningJobsResponse(BaseModel): + data: List[FineTuningJob] + has_more: bool + object: Object4 + + +class CreateEditResponse(BaseModel): + choices: List[Choice4] = Field( + ..., + description="A list of edit choices. Can be more than one if `n` is greater than 1.", + ) + object: Object6 = Field(..., description="The object type, which is always `edit`.") + created: int = Field( + ..., description="The Unix timestamp (in seconds) of when the edit was created." + ) + usage: CompletionUsage + + +class ImagesResponse(BaseModel): + created: int + data: List[Image] + + +class ListFilesResponse(BaseModel): + data: List[OpenAIFile] + object: Object7 + + +class ListFineTuningJobEventsResponse(BaseModel): + data: List[FineTuningJobEvent] + object: Object9 + + +class ListFineTuneEventsResponse(BaseModel): + data: List[FineTuneEvent] + object: Object11 + + +class CreateEmbeddingResponse(BaseModel): + data: List[Embedding] = Field( + ..., description="The list of embeddings generated by the model." + ) + model: str = Field( + ..., description="The name of the model used to generate the embedding." + ) + object: Object12 = Field( + ..., description='The object type, which is always "list".' + ) + usage: Usage = Field(..., description="The usage information for the request.") + + +class FineTune(BaseModel): + id: str = Field( + ..., + description="The object identifier, which can be referenced in the API endpoints.", + ) + created_at: int = Field( + ..., + description="The Unix timestamp (in seconds) for when the fine-tuning job was created.", + ) + events: Optional[List[FineTuneEvent]] = Field( + None, + description="The list of events that have been observed in the lifecycle of the FineTune job.", + ) + fine_tuned_model: str = Field( + ..., description="The name of the fine-tuned model that is being created." + ) + hyperparams: Hyperparams = Field( + ..., + description="The hyperparameters used for the fine-tuning job. See the [fine-tuning guide](/docs/guides/legacy-fine-tuning/hyperparameters) for more details.", + ) + model: str = Field(..., description="The base model that is being fine-tuned.") + object: Object18 = Field( + ..., description='The object type, which is always "fine-tune".' + ) + organization_id: str = Field( + ..., description="The organization that owns the fine-tuning job." + ) + result_files: List[OpenAIFile] = Field( + ..., description="The compiled results files for the fine-tuning job." + ) + status: str = Field( + ..., + description="The current status of the fine-tuning job, which can be either `created`, `running`, `succeeded`, `failed`, or `cancelled`.", + ) + training_files: List[OpenAIFile] = Field( + ..., description="The list of files used for training." + ) + updated_at: int = Field( + ..., + description="The Unix timestamp (in seconds) for when the fine-tuning job was last updated.", + ) + validation_files: List[OpenAIFile] = Field( + ..., description="The list of files used for validation." + ) + + +# class AssistantObject(BaseModel): +# id: str = Field( +# ..., description="The identifier, which can be referenced in API endpoints." +# ) +# object: Object20 = Field( +# ..., description="The object type, which is always `assistant`." +# ) +# created_at: int = Field( +# ..., +# description="The Unix timestamp (in seconds) for when the assistant was created.", +# ) +# name: constr(max_length=256) = Field( +# ..., +# description="The name of the assistant. The maximum length is 256 characters.\n", +# ) +# description: constr(max_length=512) = Field( +# ..., +# description="The description of the assistant. The maximum length is 512 characters.\n", +# ) +# model: str = Field( +# ..., +# description="ID of the model to use. You can use the [List models](/docs/api-reference/models/list) API to see all of your available models, or see our [Model overview](/docs/models/overview) for descriptions of them.\n", +# ) +# instructions: constr(max_length=32768) = Field( +# ..., +# description="The system instructions that the assistant uses. The maximum length is 32768 characters.\n", +# ) +# tools: List[ +# Union[AssistantToolsCode, AssistantToolsRetrieval, AssistantToolsFunction] +# ] = Field( +# ..., +# description="A list of tool enabled on the assistant. There can be a maximum of 128 tools per assistant. Tools can be of types `code_interpreter`, `retrieval`, or `function`.\n", +# max_items=128, +# ) +# file_ids: List[str] = Field( +# ..., +# description="A list of [file](/docs/api-reference/files) IDs attached to this assistant. There can be a maximum of 20 files attached to the assistant. Files are ordered by their creation date in ascending order.\n", +# max_items=20, +# ) +# metadata: Dict[str, Any] = Field( +# ..., +# description="Set of 16 key-value pairs that can be attached to an object. This can be useful for storing additional information about the object in a structured format. Keys can be a maximum of 64 characters long and values can be a maxium of 512 characters long.\n", +# ) + + +class CreateAssistantRequest(BaseModel): + class Config: + extra = Extra.forbid + + model: str = Field( + ..., + description="ID of the model to use. You can use the [List models](/docs/api-reference/models/list) API to see all of your available models, or see our [Model overview](/docs/models/overview) for descriptions of them.\n", + ) + name: Optional[constr(max_length=256)] = Field( + None, + description="The name of the assistant. The maximum length is 256 characters.\n", + ) + description: Optional[constr(max_length=512)] = Field( + None, + description="The description of the assistant. The maximum length is 512 characters.\n", + ) + instructions: Optional[constr(max_length=32768)] = Field( + None, + description="The system instructions that the assistant uses. The maximum length is 32768 characters.\n", + ) + tools: Optional[ + List[ + Union[ + AssistantToolsCode, + AssistantToolsRetrieval, + AssistantToolsFunction, + AssistantToolsBrowser, + ] + ] + ] = Field( + [], + description="A list of tool enabled on the assistant. There can be a maximum of 128 tools per assistant. Tools can be of types `code_interpreter`, `retrieval`, or `function`.\n", + max_items=128, + ) + file_ids: Optional[List[str]] = Field( + [], + description="A list of [file](/docs/api-reference/files) IDs attached to this assistant. There can be a maximum of 20 files attached to the assistant. Files are ordered by their creation date in ascending order.\n", + max_items=20, + ) + metadata: Optional[Dict[str, Any]] = Field( + None, + description="Set of 16 key-value pairs that can be attached to an object. This can be useful for storing additional information about the object in a structured format. Keys can be a maximum of 64 characters long and values can be a maxium of 512 characters long.\n", + ) + + +class ModifyAssistantRequest(BaseModel): + class Config: + extra = Extra.forbid + + model: Optional[str] = Field( + None, + description="ID of the model to use. You can use the [List models](/docs/api-reference/models/list) API to see all of your available models, or see our [Model overview](/docs/models/overview) for descriptions of them.\n", + ) + name: Optional[constr(max_length=256)] = Field( + None, + description="The name of the assistant. The maximum length is 256 characters.\n", + ) + description: Optional[constr(max_length=512)] = Field( + None, + description="The description of the assistant. The maximum length is 512 characters.\n", + ) + instructions: Optional[constr(max_length=32768)] = Field( + None, + description="The system instructions that the assistant uses. The maximum length is 32768 characters.\n", + ) + tools: Optional[ + List[ + Union[ + AssistantToolsCode, + AssistantToolsRetrieval, + AssistantToolsFunction, + AssistantToolsBrowser, + ] + ] + ] = Field( + [], + description="A list of tool enabled on the assistant. There can be a maximum of 128 tools per assistant. Tools can be of types `code_interpreter`, `retrieval`, or `function`.\n", + max_items=128, + ) + file_ids: Optional[List[str]] = Field( + [], + description="A list of [File](/docs/api-reference/files) IDs attached to this assistant. There can be a maximum of 20 files attached to the assistant. Files are ordered by their creation date in ascending order. If a file was previosuly attached to the list but does not show up in the list, it will be deleted from the assistant.\n", + max_items=20, + ) + metadata: Optional[Dict[str, Any]] = Field( + None, + description="Set of 16 key-value pairs that can be attached to an object. This can be useful for storing additional information about the object in a structured format. Keys can be a maximum of 64 characters long and values can be a maxium of 512 characters long.\n", + ) + + +class SubmitToolOutputs(BaseModel): + tool_calls: List[RunToolCallObject] = Field( + ..., description="A list of the relevant tool calls." + ) + + +class RequiredAction(BaseModel): + type: Type10 = Field( + ..., description="For now, this is always `submit_tool_outputs`." + ) + submit_tool_outputs: SubmitToolOutputs = Field( + ..., description="Details on the tool outputs needed for this run to continue." + ) + + +# class RunObject(BaseModel): +# id: str = Field( +# ..., description="The identifier, which can be referenced in API endpoints." +# ) +# object: Object22 = Field( +# ..., description="The object type, which is always `thread.run`." +# ) +# created_at: int = Field( +# ..., description="The Unix timestamp (in seconds) for when the run was created." +# ) +# thread_id: str = Field( +# ..., +# description="The ID of the [thread](/docs/api-reference/threads) that was executed on as a part of this run.", +# ) +# assistant_id: str = Field( +# ..., +# description="The ID of the [assistant](/docs/api-reference/assistants) used for execution of this run.", +# ) +# status: Status2 = Field( +# ..., +# description="The status of the run, which can be either `queued`, `in_progress`, `requires_action`, `cancelling`, `cancelled`, `failed`, `completed`, or `expired`.", +# ) +# required_action: RequiredAction = Field( +# ..., +# description="Details on the action required to continue the run. Will be `null` if no action is required.", +# ) +# last_error: LastError = Field( +# ..., +# description="The last error associated with this run. Will be `null` if there are no errors.", +# ) +# expires_at: int = Field( +# ..., description="The Unix timestamp (in seconds) for when the run will expire." +# ) +# started_at: int = Field( +# ..., description="The Unix timestamp (in seconds) for when the run was started." +# ) +# cancelled_at: int = Field( +# ..., +# description="The Unix timestamp (in seconds) for when the run was cancelled.", +# ) +# failed_at: int = Field( +# ..., description="The Unix timestamp (in seconds) for when the run failed." +# ) +# completed_at: int = Field( +# ..., +# description="The Unix timestamp (in seconds) for when the run was completed.", +# ) +# model: str = Field( +# ..., +# description="The model that the [assistant](/docs/api-reference/assistants) used for this run.", +# ) +# instructions: str = Field( +# ..., +# description="The instructions that the [assistant](/docs/api-reference/assistants) used for this run.", +# ) +# tools: List[ +# Union[AssistantToolsCode, AssistantToolsRetrieval, AssistantToolsFunction] +# ] = Field( +# ..., +# description="The list of tools that the [assistant](/docs/api-reference/assistants) used for this run.", +# max_items=20, +# ) +# file_ids: List[str] = Field( +# ..., +# description="The list of [File](/docs/api-reference/files) IDs the [assistant](/docs/api-reference/assistants) used for this run.", +# ) +# metadata: Dict[str, Any] = Field( +# ..., +# description="Set of 16 key-value pairs that can be attached to an object. This can be useful for storing additional information about the object in a structured format. Keys can be a maximum of 64 characters long and values can be a maxium of 512 characters long.\n", +# ) + + +# class ListRunsResponse(BaseModel): +# object: str = Field(..., example="list") +# data: List[RunObject] +# first_id: str = Field(..., example="run_hLBK7PXBv5Lr2NQT7KLY0ag1") +# last_id: str = Field(..., example="run_QLoItBbqwyAJEzlTy4y9kOMM") +# has_more: bool = Field(..., example=False) + + +class CreateThreadRequest(BaseModel): + class Config: + extra = Extra.forbid + + messages: Optional[List[CreateMessageRequest]] = Field( + None, + description="A list of [messages](/docs/api-reference/messages) to start the thread with.", + ) + metadata: Optional[Dict[str, Any]] = Field( + None, + description="Set of 16 key-value pairs that can be attached to an object. This can be useful for storing additional information about the object in a structured format. Keys can be a maximum of 64 characters long and values can be a maxium of 512 characters long.\n", + ) + + +class Text(BaseModel): + value: str = Field(..., description="The data that makes up the text.") + annotations: List[ + Union[ + MessageContentTextAnnotationsFileCitationObject, + MessageContentTextAnnotationsFilePathObject, + ] + ] + + +class MessageContentTextObject(BaseModel): + type: Type13 = Field(..., description="Always `text`.") + text: Text + + +class CodeInterpreter(BaseModel): + input: str = Field(..., description="The input to the Code Interpreter tool call.") + outputs: List[ + Union[ + RunStepDetailsToolCallsCodeOutputLogsObject, + RunStepDetailsToolCallsCodeOutputImageObject, + ] + ] = Field( + ..., + description="The outputs from the Code Interpreter tool call. Code Interpreter can output one or more items, including text (`logs`) or images (`image`). Each of these are represented by a different object type.", + ) + + +class RunStepDetailsToolCallsCodeObject(BaseModel): + id: str = Field(..., description="The ID of the tool call.") + type: Type19 = Field( + ..., + description="The type of tool call. This is always going to be `code_interpreter` for this type of tool call.", + ) + code_interpreter: CodeInterpreter = Field( + ..., description="The Code Interpreter tool call definition." + ) + + +class ChatCompletionRequestAssistantMessage(BaseModel): + content: str = Field(..., description="The contents of the assistant message.\n") + role: Role2 = Field( + ..., description="The role of the messages author, in this case `assistant`." + ) + tool_calls: Optional[ChatCompletionMessageToolCalls] = None + function_call: Optional[FunctionCall] = Field( + None, + description="Deprecated and replaced by `tool_calls`. The name and arguments of a function that should be called, as generated by the model.", + ) + + +class ListFineTunesResponse(BaseModel): + data: List[FineTune] + object: Object10 + + +class CreateThreadAndRunRequest(BaseModel): + class Config: + extra = Extra.forbid + + assistant_id: str = Field( + ..., + description="The ID of the [assistant](/docs/api-reference/assistants) to use to execute this run.", + ) + thread: Optional[CreateThreadRequest] = Field( + None, description="If no thread is provided, an empty thread will be created." + ) + model: Optional[str] = Field( + None, + description="The ID of the [Model](/docs/api-reference/models) to be used to execute this run. If a value is provided here, it will override the model associated with the assistant. If not, the model associated with the assistant will be used.", + ) + instructions: Optional[str] = Field( + None, + description="Override the default system message of the assistant. This is useful for modifying the behavior on a per-run basis.", + ) + tools: Optional[ + List[ + Union[ + AssistantToolsCode, + AssistantToolsRetrieval, + AssistantToolsFunction, + AssistantToolsBrowser, + ] + ] + ] = Field( + None, + description="Override the tools the assistant can use for this run. This is useful for modifying the behavior on a per-run basis.", + max_items=20, + ) + metadata: Optional[Dict[str, Any]] = Field( + None, + description="Set of 16 key-value pairs that can be attached to an object. This can be useful for storing additional information about the object in a structured format. Keys can be a maximum of 64 characters long and values can be a maxium of 512 characters long.\n", + ) + + +# class MessageObject(BaseModel): +# id: str = Field( +# ..., description="The identifier, which can be referenced in API endpoints." +# ) +# object: Object25 = Field( +# ..., description="The object type, which is always `thread.message`." +# ) +# created_at: int = Field( +# ..., +# description="The Unix timestamp (in seconds) for when the message was created.", +# ) +# thread_id: str = Field( +# ..., +# description="The [thread](/docs/api-reference/threads) ID that this message belongs to.", +# ) +# role: Role7 = Field( +# ..., +# description="The entity that produced the message. One of `user` or `assistant`.", +# ) +# content: List[ +# Union[MessageContentImageFileObject, MessageContentTextObject] +# ] = Field( +# ..., description="The content of the message in array of text and/or images." +# ) +# assistant_id: str = Field( +# ..., +# description="If applicable, the ID of the [assistant](/docs/api-reference/assistants) that authored this message.", +# ) +# run_id: str = Field( +# ..., +# description="If applicable, the ID of the [run](/docs/api-reference/runs) associated with the authoring of this message.", +# ) +# file_ids: List[str] = Field( +# ..., +# description="A list of [file](/docs/api-reference/files) IDs that the assistant should use. Useful for tools like retrieval and code_interpreter that can access files. A maximum of 10 files can be attached to a message.", +# max_items=10, +# ) +# metadata: Dict[str, Any] = Field( +# ..., +# description="Set of 16 key-value pairs that can be attached to an object. This can be useful for storing additional information about the object in a structured format. Keys can be a maximum of 64 characters long and values can be a maxium of 512 characters long.\n", +# ) + + +class RunStepDetailsToolCallsObject(BaseModel): + type: Type18 = Field(..., description="Always `tool_calls`.") + tool_calls: List[ + Union[ + RunStepDetailsToolCallsCodeObject, + RunStepDetailsToolCallsRetrievalObject, + RunStepDetailsToolCallsFunctionObject, + ] + ] = Field( + ..., + description="An array of tool calls the run step was involved in. These can be associated with one of three types of tools: `code_interpreter`, `retrieval`, or `function`.\n", + ) + + +class ChatCompletionRequestMessage(BaseModel): + __root__: Union[ + ChatCompletionRequestSystemMessage, + ChatCompletionRequestUserMessage, + ChatCompletionRequestAssistantMessage, + ChatCompletionRequestToolMessage, + ChatCompletionRequestFunctionMessage, + ] + + +class CreateChatCompletionRequest(BaseModel): + messages: List[ChatCompletionRequestMessage] = Field( + ..., + description="A list of messages comprising the conversation so far. [Example Python code](https://cookbook.openai.com/examples/how_to_format_inputs_to_chatgpt_models).", + min_items=1, + ) + model: Union[str, ModelEnum1] = Field( + ..., + description="ID of the model to use. See the [model endpoint compatibility](/docs/models/model-endpoint-compatibility) table for details on which models work with the Chat API.", + example="gpt-3.5-turbo", + ) + frequency_penalty: Optional[confloat(ge=-2.0, le=2.0)] = Field( + 0, + description="Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far, decreasing the model's likelihood to repeat the same line verbatim.\n\n[See more information about frequency and presence penalties.](/docs/guides/gpt/parameter-details)\n", + ) + logit_bias: Optional[Dict[str, int]] = Field( + None, + description="Modify the likelihood of specified tokens appearing in the completion.\n\nAccepts a JSON object that maps tokens (specified by their token ID in the tokenizer) to an associated bias value from -100 to 100. Mathematically, the bias is added to the logits generated by the model prior to sampling. The exact effect will vary per model, but values between -1 and 1 should decrease or increase likelihood of selection; values like -100 or 100 should result in a ban or exclusive selection of the relevant token.\n", + ) + max_tokens: Optional[int] = Field( + "inf", + description="The maximum number of [tokens](/tokenizer) to generate in the chat completion.\n\nThe total length of input tokens and generated tokens is limited by the model's context length. [Example Python code](https://cookbook.openai.com/examples/how_to_count_tokens_with_tiktoken) for counting tokens.\n", + ) + n: Optional[conint(ge=1, le=128)] = Field( + 1, + description="How many chat completion choices to generate for each input message.", + example=1, + ) + presence_penalty: Optional[confloat(ge=-2.0, le=2.0)] = Field( + 0, + description="Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the text so far, increasing the model's likelihood to talk about new topics.\n\n[See more information about frequency and presence penalties.](/docs/guides/gpt/parameter-details)\n", + ) + response_format: Optional[ResponseFormat] = Field( + None, + description='An object specifying the format that the model must output. \n\nSetting to `{ "type": "json_object" }` enables JSON mode, which guarantees the message the model generates is valid JSON.\n\n**Important:** when using JSON mode, you **must** also instruct the model to produce JSON yourself via a system or user message. Without this, the model may generate an unending stream of whitespace until the generation reaches the token limit, resulting in increased latency and appearance of a "stuck" request. Also note that the message content may be partially cut off if `finish_reason="length"`, which indicates the generation exceeded `max_tokens` or the conversation exceeded the max context length.\n', + ) + seed: Optional[conint(ge=-9223372036854775808, le=9223372036854775808)] = Field( + None, + description="This feature is in Beta. \nIf specified, our system will make a best effort to sample deterministically, such that repeated requests with the same `seed` and parameters should return the same result.\nDeterminism is not guaranteed, and you should refer to the `system_fingerprint` response parameter to monitor changes in the backend.\n", + ) + stop: Optional[Union[str, List[str]]] = Field( + None, + description="Up to 4 sequences where the API will stop generating further tokens.\n", + ) + stream: Optional[bool] = Field( + False, + description="If set, partial message deltas will be sent, like in ChatGPT. Tokens will be sent as data-only [server-sent events](https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#Event_stream_format) as they become available, with the stream terminated by a `data: [DONE]` message. [Example Python code](https://cookbook.openai.com/examples/how_to_stream_completions).\n", + ) + temperature: Optional[confloat(ge=0.0, le=2.0)] = Field( + 1, + description="What sampling temperature to use, between 0 and 2. Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic.\n\nWe generally recommend altering this or `top_p` but not both.\n", + example=1, + ) + top_p: Optional[confloat(ge=0.0, le=1.0)] = Field( + 1, + description="An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are considered.\n\nWe generally recommend altering this or `temperature` but not both.\n", + example=1, + ) + tools: Optional[List[ChatCompletionTool]] = Field( + None, + description="A list of tools the model may call. Currently, only functions are supported as a tool. Use this to provide a list of functions the model may generate JSON inputs for.\n", + ) + tool_choice: Optional[ChatCompletionToolChoiceOption] = None + user: Optional[str] = Field( + None, + description="A unique identifier representing your end-user, which can help OpenAI to monitor and detect abuse. [Learn more](/docs/guides/safety-best-practices/end-user-ids).\n", + example="user-1234", + ) + function_call: Optional[ + Union[FunctionCallEnum, ChatCompletionFunctionCallOption] + ] = Field( + None, + description='Deprecated in favor of `tool_choice`.\n\nControls which (if any) function is called by the model.\n`none` means the model will not call a function and instead generates a message.\n`auto` means the model can pick between generating a message or calling a function.\nSpecifying a particular function via `{"name": "my_function"}` forces the model to call that function.\n\n`none` is the default when no functions are present. `auto`` is the default if functions are present.\n', + ) + functions: Optional[List[ChatCompletionFunctions]] = Field( + None, + description="Deprecated in favor of `tools`.\n\nA list of functions the model may generate JSON inputs for.\n", + max_items=128, + min_items=1, + ) + + +# class RunStepObject(BaseModel): +# id: str = Field( +# ..., +# description="The identifier of the run step, which can be referenced in API endpoints.", +# ) +# object: Object27 = Field( +# ..., description="The object type, which is always `thread.run.step``." +# ) +# created_at: int = Field( +# ..., +# description="The Unix timestamp (in seconds) for when the run step was created.", +# ) +# assistant_id: str = Field( +# ..., +# description="The ID of the [assistant](/docs/api-reference/assistants) associated with the run step.", +# ) +# thread_id: str = Field( +# ..., +# description="The ID of the [thread](/docs/api-reference/threads) that was run.", +# ) +# run_id: str = Field( +# ..., +# description="The ID of the [run](/docs/api-reference/runs) that this run step is a part of.", +# ) +# type: Type16 = Field( +# ..., +# description="The type of run step, which can be either `message_creation` or `tool_calls`.", +# ) +# status: Status3 = Field( +# ..., +# description="The status of the run step, which can be either `in_progress`, `cancelled`, `failed`, `completed`, or `expired`.", +# ) +# step_details: Union[ +# RunStepDetailsMessageCreationObject, RunStepDetailsToolCallsObject +# ] = Field(..., description="The details of the run step.") +# last_error: LastError1 = Field( +# ..., +# description="The last error associated with this run step. Will be `null` if there are no errors.", +# ) +# expired_at: int = Field( +# ..., +# description="The Unix timestamp (in seconds) for when the run step expired. A step is considered expired if the parent run is expired.", +# ) +# cancelled_at: int = Field( +# ..., +# description="The Unix timestamp (in seconds) for when the run step was cancelled.", +# ) +# failed_at: int = Field( +# ..., description="The Unix timestamp (in seconds) for when the run step failed." +# ) +# completed_at: int = Field( +# ..., +# description="The Unix timestamp (in seconds) for when the run step completed.", +# ) +# metadata: Dict[str, Any] = Field( +# ..., +# description="Set of 16 key-value pairs that can be attached to an object. This can be useful for storing additional information about the object in a structured format. Keys can be a maximum of 64 characters long and values can be a maxium of 512 characters long.\n", +# ) + + +# class ListRunStepsResponse(BaseModel): +# object: str = Field(..., example="list") +# data: List[RunStepObject] +# first_id: str = Field(..., example="step_hLBK7PXBv5Lr2NQT7KLY0ag1") +# last_id: str = Field(..., example="step_QLoItBbqwyAJEzlTy4y9kOMM") +# has_more: bool = Field(..., example=False) diff --git a/services/backend/task_executor/app/tasks.py b/services/backend/task_executor/app/tasks.py new file mode 100644 index 0000000..a807306 --- /dev/null +++ b/services/backend/task_executor/app/tasks.py @@ -0,0 +1,505 @@ +# tasks.py +# Standard Library +import json +import os +import sys +from functools import partial + +# Get the current working directory +current_directory = os.getcwd() + +# Add the current directory to sys.path +sys.path.append(current_directory) + +# Standard Library +import io +import logging +import uuid +from datetime import datetime + +# Third Party +import redis +from celery import Celery, shared_task, signals +from openai import OpenAI +from pymongo import MongoClient + +# Local +from app.backend_models import ( + AssistantObject, + FilesStorageObject, + MessageObject, + RunStepObject, +) +from app.local_model import RubraLocalAgent +from app.models import Role7, Status2, Type8, Type824 + +litellm_host = os.getenv("LITELLM_HOST", "localhost") +redis_host = os.getenv("REDIS_HOST", "localhost") +mongodb_host = os.getenv("MONGODB_HOST", "localhost") + +redis_client = redis.Redis(host=redis_host, port=6379, db=0) +app = Celery("tasks", broker=f"redis://{redis_host}:6379/0") +app.config_from_object("app.celery_config") +app.autodiscover_tasks(["app"]) # Explicitly discover tasks in 'app' package + +# MongoDB Configuration +MONGODB_URL = f"mongodb://{mongodb_host}:27017" +DATABASE_NAME = "rubra_db" + +# Global MongoDB client +mongo_client = None + + +@signals.worker_process_init.connect +def setup_mongo_connection(*args, **kwargs): + global mongo_client + mongo_client = MongoClient(f"mongodb://{mongodb_host}:27017") + + +def create_assistant_message( + thread_id, assistant_id, run_id, content_text, role=Role7.assistant.value +): + db = mongo_client[DATABASE_NAME] + + # Generate a unique ID for the message + message_id = f"msg_{uuid.uuid4().hex[:6]}" + + # Create the message object + message = { + "id": message_id, + "object": "thread.message", + "created_at": int(datetime.now().timestamp()), + "thread_id": thread_id, + "role": role, + "content": [ + {"type": "text", "text": {"value": content_text, "annotations": []}} + ], + "file_ids": [], + "assistant_id": assistant_id, + "run_id": run_id, + "metadata": {}, + } + + # Insert the message into MongoDB + db.messages.insert_one(message) + + +def rubra_local_agent_chat_completion( + chat_agent: RubraLocalAgent, + redis_channel, + chat_messages: list, + sys_instruction: str, + thread_id: str, + assistant_id: str, + run_id: str, +): + response = chat_agent.chat( + msgs=chat_messages, sys_instruction=sys_instruction, stream=True + ) + + msg = "" + if ( + len(chat_agent.tools) == 0 + ): # if no tool, stream the response and return msg directly + for r in response: + if r.choices and r.choices[0].delta and r.choices[0].delta.content: + msg += r.choices[0].delta.content + + redis_client.publish(redis_channel, str(r)) + return msg + + msg_state = "start" # possible states: start, chat, function + for r in response: + if r.choices and r.choices[0].delta and r.choices[0].delta.content: + msg += r.choices[0].delta.content + logging.debug(f" msg : {msg} msg_state: {msg_state}") + if msg_state == "function": + r.choices[0].delta.function_call = {"type": "function"} + + if '"function":' in msg or '"content": "' in msg: + if msg_state == "start": + if '"content": "' in msg: + msg_state = "chat" + r.choices[0].delta.content = msg.split('"content": "')[1] + else: + msg_state = "function" + r.choices[0].delta.content = msg + r.choices[0].delta.function_call = {"type": "function"} + elif msg_state == "chat": + if r.choices[0].delta.content.endswith('"'): + try: + json.loads(msg + "}") + r.choices[0].delta.content = r.choices[ + 0 + ].delta.content.split('"')[0] + except: + pass + if r.choices[0].delta.content.endswith("}"): + try: + json.loads(msg) + r.choices[0].delta.content = r.choices[ + 0 + ].delta.content.split("}")[0] + except: + pass + + if msg_state != "start": + redis_client.publish(redis_channel, str(r)) + + is_function_call, parsed_msg = chat_agent.validate_function_call(msg) + ## TODO: create run_step object. + + if is_function_call: + logging.info("=====function call========") + function_call_content = parsed_msg + chat_messages.append( + {"role": Role7.assistant.value, "content": function_call_content} + ) + create_assistant_message( + thread_id, assistant_id, run_id, content_text=function_call_content + ) + + function_response = chat_agent.get_function_response( + function_call_json=json.loads(parsed_msg) + ) + + parsed_msg = function_response + + last_chunk = r + last_chunk.choices[0].delta.function_call = None + last_chunk.choices[0].delta.content = parsed_msg + redis_client.publish(redis_channel, str(last_chunk)) + last_chunk.choices[0].delta.content = "" + redis_client.publish(redis_channel, str(last_chunk)) + print(f"message: {parsed_msg}") + + return parsed_msg + + +def form_openai_tools(tools, assistant_id: str): + # Local + from app.tools.file_knowledge_tool import FileKnowledgeTool + from app.tools.web_browse_tool.web_browse_tool import WebBrowseTool + + retrieval = FileKnowledgeTool() + googlesearch = WebBrowseTool() + res_tools = [] + available_function = {} + for t in tools: + if t["type"] == Type8.retrieval.value: + retrieval_tool = { + "type": "function", + "function": { + "name": retrieval.name, + "description": retrieval.description, + "parameters": retrieval.parameters, + }, + } + res_tools.append(retrieval_tool) + retrieval_func = partial(retrieval._run, assistant_id=assistant_id) + available_function[retrieval.name] = retrieval_func + elif t["type"] == Type824.retrieval.value: + gs_tool = { + "type": "function", + "function": { + "name": googlesearch.name, + "description": googlesearch.description, + "parameters": googlesearch.parameters, + }, + } + res_tools.append(gs_tool) + available_function[googlesearch.name] = googlesearch._run + else: + res_tools.append(t) + return res_tools, available_function + + +@shared_task +def execute_chat_completion(assistant_id, thread_id, redis_channel, run_id): + try: + oai_client = OpenAI( + base_url=f"http://{litellm_host}:8002/v1/", + api_key="abc", # point to litellm server + ) + db = mongo_client[DATABASE_NAME] + + # Fetch assistant and thread messages synchronously + assistant = db.assistants.find_one({"id": assistant_id}) + thread_messages = list(db.messages.find({"thread_id": thread_id})) + + if not assistant or not thread_messages: + raise ValueError("Assistant or Thread Messages not found") + + # Update the run status to in_progress and set the started_at timestamp + started_at = int(datetime.now().timestamp()) + db.runs.update_one( + {"id": run_id}, + {"$set": {"status": Status2.in_progress.value, "started_at": started_at}}, + ) + + print("Calling model:", assistant["model"]) + + if assistant["model"].startswith("claude-") or assistant["model"].startswith( + "gpt-" + ): + # Prepare the chat messages for OpenAI + chat_messages = [{"role": "system", "content": assistant["instructions"]}] + for msg in thread_messages: + content_text = ( + msg["content"][0]["text"]["value"] + if msg["content"] + and isinstance(msg["content"], list) + and "text" in msg["content"][0] + else "" + ) + chat_messages.append({"role": msg["role"], "content": content_text}) + + print("Chat Messages:", chat_messages) + + # Call OpenAI for chat completion + oai_tools, available_function = form_openai_tools( + assistant["tools"], assistant_id=assistant_id + ) + print("oai_tools", oai_tools) + # filter out code interpreter and browser for now + oai_tools = [ + tool for tool in oai_tools if tool["type"] != "code_interpreter" + ] + + if oai_tools: + response = oai_client.chat.completions.create( + model=assistant["model"], + messages=chat_messages, + tools=oai_tools, + tool_choice="auto", + temperature=0.1, + top_p=0.95, + stream=True, + ) + else: + response = oai_client.chat.completions.create( + model=assistant["model"], + messages=chat_messages, + temperature=0.1, + top_p=0.95, + stream=True, + ) + + # Iterate over the response chunks and construct the assistant's response + assistant_response = "" + function_call_list_dict = [] + for i, chunk in enumerate(response): + if chunk.choices and chunk.choices[0].delta: + if chunk.choices[ + 0 + ].delta.tool_calls: # openai can do multi-tool-call + print(chunk.choices[0].delta) + for tc in chunk.choices[0].delta.tool_calls: + if tc.function.name: # the first chunk + function_call_list_dict.append( + { + "name": tc.function.name, + "id": tc.id, + "argument": tc.function.arguments or "", + } + ) + + else: + function_call_list_dict[-1][ + "argument" + ] += tc.function.arguments + elif chunk.choices[0].delta.content: + assistant_message = chunk.choices[0].delta.content + assistant_response += assistant_message + redis_client.publish(redis_channel, str(chunk)) + + while function_call_list_dict: + print(f"called function : {function_call_list_dict}") + print(f"num of functions: {len(function_call_list_dict)}") + + function_call_msg = { + "content": "", + "role": "assistant", + "tool_calls": [], + } + for j, fc in enumerate(function_call_list_dict): + this_fc = { + "index": 0, + "id": fc["id"], + "function": {"arguments": fc["argument"], "name": fc["name"]}, + "type": "function", + } + function_call_msg["tool_calls"].append(this_fc) + print(function_call_msg) + chat_messages.append(function_call_msg) + for ftc in function_call_list_dict: + if ftc["name"] not in available_function: + # TODO: in this case, add a runstep object. And the user is responsible to submit the tool output + pass + else: + function_args = json.loads(ftc["argument"]) + function_response = available_function[ftc["name"]]( + **function_args + ) + function_response_content = json.dumps(function_response) + + chat_messages.append( + { + "tool_call_id": ftc["id"], + "role": "tool", + "name": ftc["name"], + "content": function_response_content, + } + ) + + response = oai_client.chat.completions.create( + model=assistant["model"], + messages=chat_messages, + tools=oai_tools, + tool_choice="auto", + temperature=0.1, + top_p=0.95, + stream=True, + ) + # Iterate over the response chunks and construct the assistant's response + assistant_response = "" + function_call_list_dict = [] + for i, chunk in enumerate(response): + if chunk.choices and chunk.choices[0].delta: + if chunk.choices[ + 0 + ].delta.tool_calls: # openai can do multi-tool-call + print(chunk.choices[0].delta.tool_calls) + for i, tc in enumerate(chunk.choices[0].delta.tool_calls): + if tc.function.name: # the first chunk + function_call_list_dict.append( + { + "name": tc.function.name, + "id": tc.id, + "argument": tc.function.arguments or "", + } + ) + + else: + function_call_list_dict[-1][ + "argument" + ] += tc.function.arguments + elif chunk.choices[0].delta.content: + assistant_message = chunk.choices[0].delta.content + assistant_response += assistant_message + # redis_client.publish(redis_channel, str(assistant_message)) + redis_client.publish(redis_channel, str(chunk)) + + else: # assume local model + chat_messages = [] + for msg in thread_messages: + content_text = ( + msg["content"][0]["text"]["value"] + if msg["content"] + and isinstance(msg["content"], list) + and "text" in msg["content"][0] + else "" + ) + chat_messages.append({"role": msg["role"], "content": content_text}) + + chat_agent = RubraLocalAgent( + assistant_id=assistant_id, tools=assistant["tools"] + ) + sys_instruction = assistant["instructions"] + assistant_response = rubra_local_agent_chat_completion( + chat_agent, + redis_channel, + chat_messages, + sys_instruction, + thread_id, + assistant_id, + run_id, + ) + + # Check if there's a valid response to add as an assistant message + if assistant_response.strip(): + # Create a new message from the assistant + create_assistant_message( + thread_id, assistant_id, run_id, assistant_response + ) + print("Assistant message created:", assistant_response) + + # Update the run status to completed and set the completed_at timestamp after successful completion + completed_at = int(datetime.now().timestamp()) + db.runs.update_one( + {"id": run_id}, + {"$set": {"status": Status2.completed.value, "completed_at": completed_at}}, + ) + redis_client.publish( + f"task_status_{thread_id}", + json.dumps( + {"thread_id": thread_id, "run_id": run_id, "status": "completed"} + ), + ) + + except Exception as e: + print(f"Error in execute_chat_completion: {str(e)}") + # Update the run status to failed and set the failed_at timestamp in case of an exception + failed_at = int(datetime.now().timestamp()) + db.runs.update_one( + {"id": run_id}, + {"$set": {"status": Status2.failed.value, "failed_at": failed_at}}, + ) + redis_client.publish( + f"task_status_{thread_id}", + json.dumps({"thread_id": thread_id, "run_id": run_id, "status": "failed"}), + ) + + raise + + +@app.task +def execute_asst_file_create(file_id: str, assistant_id: str): + # Standard Library + import json + + # Third Party + from langchain.text_splitter import RecursiveCharacterTextSplitter + + # Local + from app.vector_db.milvus.main import add_texts + + try: + db = mongo_client[DATABASE_NAME] + collection_name = assistant_id + text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=0) + parsed_text = "" + + file_object = db.files_storage.find_one({"id": file_id}) + logging.info(f"processing file id : {file_id}") + + if file_object["content_type"] == "application/pdf": + # Third Party + from PyPDF2 import PdfReader + + reader = PdfReader(io.BytesIO(file_object["content"])) + text = "" + for page in reader.pages: + text += page.extract_text() + "\n" + + parsed_text = text + + elif file_object["content_type"] == "application/json": + res = json.loads(file_object["content"]) + parsed_text = res + else: ## try to read plain text + try: + parsed_text = file_object["content"].decode() + + except Exception as e: + print(f"Load Error: {e}") + + if parsed_text != "": + # Split docs and add to milvus vector DB + texts = text_splitter.split_text(parsed_text) + metadatas = [{"file_id": file_id} for t in texts] + pks = add_texts(collection_name, texts=texts, metadatas=metadatas) + + logging.info(f"file {file_id} processing completed") + except Exception as e: + print(f"Error in execute_asst_file_create: {str(e)}") diff --git a/services/backend/task_executor/app/tools/file_knowledge_tool.py b/services/backend/task_executor/app/tools/file_knowledge_tool.py new file mode 100644 index 0000000..fc62637 --- /dev/null +++ b/services/backend/task_executor/app/tools/file_knowledge_tool.py @@ -0,0 +1,50 @@ +# Standard Library +import json +import os + +# Third Party +import requests + +VECTOR_DB_HOST = os.getenv("VECTOR_DB_HOST", "localhost") +VECTOR_DB_MATCH_URL = f"http://{VECTOR_DB_HOST}:8010/similarity_match" + + +class FileKnowledgeTool: + name = "FileKnowledge" + description = "Useful for search knowledge or information from user's file" + parameters = { + "type": "object", + "properties": { + "query": { + "type": "string", + "description": "the completed question to search", + }, + }, + "required": ["query"], + } + + async def _arun(self, query: str, assistant_id): + return await file_knowledge_search_api(query, assistant_id) + + def _run(self, query: str, assistant_id): + return file_knowledge_search_api(query, assistant_id) + + +def file_knowledge_search_api(query: str, assistant_id: str): + headers = {"accept": "application/json", "Content-Type": "application/json"} + data = json.dumps( + { + "text": query, + "collection_name": assistant_id, + "topk": 10, + "rerank": False, + "topr": 5, + } + ) + + response = requests.post(VECTOR_DB_MATCH_URL, headers=headers, data=data) + res = response.json()["response"] + txt = "" + for r in res: + txt += r["text"] + "\n\n" + return txt diff --git a/services/backend/task_executor/app/tools/get_date_tool.py b/services/backend/task_executor/app/tools/get_date_tool.py new file mode 100644 index 0000000..3493953 --- /dev/null +++ b/services/backend/task_executor/app/tools/get_date_tool.py @@ -0,0 +1,35 @@ +# Standard Library +import datetime + +# Third Party +from dateutil.relativedelta import relativedelta + + +class GetNowTool: + name = "GetNowTool" + description = "Get date of the current day. Returns year, month, day" + parameters = { + "type": "object", + "properties": {}, + "required": [], + } + + async def _arun(self, time=None): + return get_date() + + def _run(self): + return get_date() + + +def get_date(day_delta=0, month_delta=0, year_delta=0, year_only=False): + res_date = datetime.date.today() + if day_delta != 0: + res_date += datetime.timedelta(days=day_delta) + if month_delta != 0: + res_date += relativedelta(months=month_delta) + if year_delta != 0: + res_date += relativedelta(years=year_delta) + if year_only: + res_date = res_date.year + res = f"{str(res_date)}" + return res diff --git a/services/backend/task_executor/app/tools/web_browse_tool/browser.py b/services/backend/task_executor/app/tools/web_browse_tool/browser.py new file mode 100644 index 0000000..c747d71 --- /dev/null +++ b/services/backend/task_executor/app/tools/web_browse_tool/browser.py @@ -0,0 +1,147 @@ +# Standard Library +import asyncio +import os +import re +from tempfile import TemporaryDirectory +from urllib.parse import urljoin + +# Third Party +from bs4 import BeautifulSoup +from langchain.docstore.document import Document +from langchain.text_splitter import TokenTextSplitter +from markdownify import markdownify as md +from playwright.async_api import async_playwright + + +class WebPageBrowser: + def __init__(self): + self.browser = WebBrowser() + self.tmpdir = TemporaryDirectory() + self.text_splitter = TokenTextSplitter() + self.url_map = {} + + async def initialize(self): + await self.browser.initialize() + + async def abrowse(self, url: str) -> str: + page_source = await self.browser.goto(url) + + # Filter and Parse HTML + soup = BeautifulSoup(page_source, "html.parser") + # Update relative URLs to absolute URLs + for a_tag in soup.find_all("a", href=True): + a_tag["href"] = urljoin(url, a_tag["href"]) + # Remove unwanted tags like 'script', 'style', etc. + for script in soup(["script", "style", "noscript"]): + script.extract() + filtered_html = str(soup) + + # Convert HTML to Markdown + results = md(filtered_html) + + # Remove consecutive newlines + results = re.sub(r"\n+", "\n", results).strip() + + return results + + async def close(self): + await self.browser.close() + + def browse(self, url: str) -> str: + return asyncio.get_event_loop().run_until_complete(self.abrowse(url)) + + async def async_browse_and_save(self, url: str) -> str: + content = await self.async_browse(url) + filepath = self.save(url) + return filepath + + def browse_and_save(self, url: str) -> str: + return asyncio.get_event_loop().run_until_complete( + self.async_browse_and_save(url) + ) + + def save(self, url: str) -> str: + content = self.browse(url) + + # Generate filename based on the URL + filename = f"{url.replace('https://', '').replace('/', '_')}.md" + filepath = os.path.join(self.tmpdir.name, filename) + + # Save the content to a file in the temporary directory + with open(filepath, "w") as f: + f.write(content) + + self.url_map[filename] = url + + self.filename = filename + + return filepath + + def read(self, filename: str) -> str: + filepath = os.path.join(self.tmpdir.name, filename) + with open(filepath, "r") as f: + content = f.read() + return content + + def explore_website(self, query: str) -> str: + """Assumes you have already fetched the contents of a webpage that is saved to `filename`. Using the query, it will try to respond to the query using the contents of the webpage. If a response cannot be generated, but you think there is a link that may lead you to the answer respond with the links""" + + if self.filename is None: + raise ValueError("Filename is not set. Cannot explore website.") + + file_contents = self.read(self.filename) + filename_only = os.path.basename(self.filename) + # Extract the URL from filename for metadata + url = self.url_map.get(filename_only, filename_only) + # Prepare documents + docs = [Document(page_content=file_contents, metadata={"source": url})] + web_docs = self.text_splitter.split_documents(docs) + results = [] + for i in range(0, len(web_docs), 4): + input_docs = web_docs[i : i + 4] + window_result = self.qa_chain( + {"input_documents": input_docs, "question": query}, + return_only_outputs=True, + ) + results.append(f"Response from window {i} - {window_result}") + results_docs = [ + Document(page_content="\n".join(results), metadata={"source": url}) + ] + res = self.qa_chain( + {"input_documents": results_docs, "question": query}, + return_only_outputs=True, + ) + return res + + +class WebBrowser: + def __init__(self): + self.browser = None + self.playwright = None + + async def initialize(self): + self.playwright = await async_playwright().start() + self.browser = await self.playwright.chromium.launch(headless=True) + + async def goto(self, url: str): + # Create a new browser context + if self.browser is None: + await self.initialize() + context = await self.browser.new_context( + java_script_enabled=False + ) # Turn off JavaScript + page = await context.new_page() + try: + await page.goto(url, wait_until="domcontentloaded") + except Exception: + pass # Ignore timeout or other exceptions + content = await page.content() + await page.close() + await context.close() + return content + + async def close(self): + if self.browser: + await self.browser.close() + if self.playwright: + await self.playwright.stop() diff --git a/services/backend/task_executor/app/tools/web_browse_tool/web_browse_tool.py b/services/backend/task_executor/app/tools/web_browse_tool/web_browse_tool.py new file mode 100644 index 0000000..468fff8 --- /dev/null +++ b/services/backend/task_executor/app/tools/web_browse_tool/web_browse_tool.py @@ -0,0 +1,94 @@ +# Standard Library +import logging +import urllib.parse + +# Third Party +from googlesearch import search as search_api + +from .browser import WebPageBrowser + + +def web_scraper(url: str): + wb = WebPageBrowser() + result = wb.browse(url) + return result + + +class WebBrowseTool: + name = "GoogleSearchTool" + description = "Useful for search information on internet." + parameters = { + "type": "object", + "properties": { + "query": { + "type": "string", + "description": "the completed question to search", + }, + }, + "required": ["query"], + } + + async def _arun(self, query: str): + return await google_search_api(query) + + def _run(self, query: str, web_browse: bool = False, concat_text: bool = True): + res = google_search_api(query) + summaries = [] + if web_browse: + for i, r in enumerate(res): + context = parse_url(r["url"], query) + res[i]["text"] += "\n" + context + + if concat_text: + for i, r in enumerate(res): + formatted_summary = f"{i+1}.\nTITLE:{r['title']}\nTEXT:{r['text']}\nSOURCE_URL:{r['url']}" + summaries.append(formatted_summary) + + joint_res = "\n\n".join(summaries) + return joint_res + else: + return res + + +def google_search_api(query: str) -> str: + max_results = 5 + res = [ + {"title": r.title, "url": r.url, "text": r.description} + for i, r in enumerate(search_api(query, advanced=True)) + if i < max_results + ] + return res + + +def create_google_search_url(query: str) -> str: + encoded_query = urllib.parse.quote_plus(query) + url = f"https://www.google.com/search?q={encoded_query}" + return url + + +def num_tokens_from_string(string: str, encoding_name: str = "cl100k_base") -> int: + """Returns the number of tokens in a text string.""" + # Third Party + import tiktoken + + encoding = tiktoken.get_encoding(encoding_name) + num_tokens = len(encoding.encode(string)) + return num_tokens + + +def parse_url(url: str, query: str, stream: bool = False) -> str: + """ + Args: + query (str): _description_ + stream (bool, optional): _description_. Defaults to False. + summarize (bool, optional): _description_. Defaults to False. + + Returns: + _type_: _description_ + """ + try: + context = web_scraper(url) + except Exception as e: + logging.error(f"Error in web_scraper: {e}") + context = "" + return context diff --git a/services/backend/task_executor/app/vector_db/milvus/CustomEmbeddings.py b/services/backend/task_executor/app/vector_db/milvus/CustomEmbeddings.py new file mode 100644 index 0000000..085f281 --- /dev/null +++ b/services/backend/task_executor/app/vector_db/milvus/CustomEmbeddings.py @@ -0,0 +1,55 @@ +# Standard Library +import json +import os +from typing import List + +# Third Party +import requests +from langchain.embeddings.base import Embeddings + +HOST = os.getenv("EMBEDDING_HOST", "localhost") +EMBEDDING_URL = f"http://{HOST}:8020/embed_multiple" + + +def embed_text(texts: List[str]) -> List[List[float]]: + """Embed a list of texts using a remote service. + + Args: + texts (List[str]): List of texts to be embedded. + + Returns: + List[List[float]]: List of embedded texts. + """ + headers = {"accept": "application/json", "Content-Type": "application/json"} + data = json.dumps(texts) + + response = requests.post(EMBEDDING_URL, headers=headers, data=data) + response = response.json() + + return response["embeddings"] + + +class CustomEmbeddings(Embeddings): + """Custom embeddings class that uses a remote service for embedding.""" + + def embed_documents(self, texts: List[str]) -> List[List[float]]: + """Embed a list of documents. + + Args: + texts (List[str]): List of documents to be embedded. + + Returns: + List[List[float]]: List of embedded documents. + """ + return embed_text(texts) + + def embed_query(self, text: str) -> List[float]: + """Embed a single query. + + Args: + text (str): Query to be embedded. + + Returns: + List[float]: Embedded query. + """ + return embed_text([text])[0] diff --git a/services/backend/task_executor/app/vector_db/milvus/main.py b/services/backend/task_executor/app/vector_db/milvus/main.py new file mode 100644 index 0000000..e3922d5 --- /dev/null +++ b/services/backend/task_executor/app/vector_db/milvus/main.py @@ -0,0 +1,119 @@ +# Standard library imports +# Standard Library +import os +from typing import List, Optional + +# Third Party +# Third party imports +from fastapi import FastAPI +from pydantic import BaseModel + +# Local application imports +from .CustomEmbeddings import CustomEmbeddings +from .query_milvus import Milvus + +MILVUS_HOST = os.getenv("MILVUS_HOST", "localhost") + +model = {} +top_re_rank = 5 +top_k_match = 10 +app = FastAPI() + + +class Query(BaseModel): + text: str + collection_name: str + topk: int = top_k_match + rerank: bool = False + topr: int = top_re_rank + + +@app.on_event("startup") +async def app_startup(): + pass + + +def drop_collection(collection_name: str): + load_collection(collection_name).drop_collection() + + +def load_collection(collection_name: str) -> Milvus: + return Milvus( + embedding_function=CustomEmbeddings(), + collection_name=collection_name, + connection_args={ + "host": MILVUS_HOST, + "port": "19530", + "user": "username", + "password": "password", + }, + index_params={ + "metric_type": "IP", + "index_type": "FLAT", + "params": {"nlist": 16384}, + }, + search_params={"metric_type": "IP", "params": {"nprobe": 32}}, + ) + + +@app.post("/add_texts") +async def add_texts_embeddings( + collection_name: str, + texts: List[str], + metadatas: Optional[List[dict]] = None, +): + """_summary_ + + Args: + texts (List[str]): _description_ + connlection_name (str): this should reflect user's random id + the assistant_id they created. + """ + pks = add_texts(collection_name, texts, metadatas) + + +def add_texts( + collection_name: str, + texts: List[str], + metadatas: Optional[List[dict]] = None, +): + c = load_collection(collection_name) + pks = c.add_texts(texts=texts, metadatas=metadatas) + print(pks) + return pks + + +@app.delete("/delete_docs") +async def delete_docs_api(collection_name: str, expr: str): + delete_docs(collection_name, expr) + + +def delete_docs(collection_name: str, expr: str): + c = load_collection(collection_name) + c.delete_entities(expr=expr) + + +def get_top_k_biencoder_match_milvus(query: Query): + c = load_collection(query.collection_name) + + docs = c.similarity_search(query.text, k=query.topk) + res = [] + for i, d in enumerate(docs): + thisd = {"id": i, "metadata": d.metadata, "text": d.page_content} + res.append(thisd) + return res + + +def get_similar_match(query, biencoder_match_method: str, rerank: bool = False): + query_biencoder_matches = get_top_k_biencoder_match_milvus(query) + return query_biencoder_matches[: query.topr] + + +@app.post("/similarity_match") +def text_similarity_match(query: Query): + res = get_similar_match(query, biencoder_match_method="milvus", rerank=query.rerank) + return {"response": res} + + +@app.get("/ping") +def ping(): + return {"response": "Pong!"} diff --git a/services/backend/task_executor/app/vector_db/milvus/query_milvus.py b/services/backend/task_executor/app/vector_db/milvus/query_milvus.py new file mode 100644 index 0000000..4e7779c --- /dev/null +++ b/services/backend/task_executor/app/vector_db/milvus/query_milvus.py @@ -0,0 +1,825 @@ +"""Wrapper around the Milvus vector database.""" + +# Standard Library +import logging +import os +from typing import Any, Dict, Iterable, List, Optional, Tuple, Union +from uuid import uuid4 + +# Third Party +import numpy as np +from langchain.docstore.document import Document +from langchain.embeddings.base import Embeddings +from langchain.vectorstores.base import VectorStore +from langchain.vectorstores.utils import maximal_marginal_relevance + +logger = logging.getLogger(__name__) + +DEFAULT_MILVUS_CONNECTION = { + "host": os.getenv("MILVUS_HOST", "localhost"), + "port": "19530", + "user": "", + "password": "", + "secure": False, +} + + +class Milvus(VectorStore): + """Wrapper around the Milvus vector database.""" + + def __init__( + self, + embedding_function: Embeddings, + collection_name: str = "DefaultCollection", + connection_args: Optional[Dict[str, Any]] = None, + consistency_level: str = "Session", + index_params: Optional[Dict[str, Any]] = None, + search_params: Optional[Dict[str, Any]] = None, + drop_old: Optional[bool] = False, + ): + """Initialize wrapper around the milvus vector database. + + In order to use this you need to have `pymilvus` installed and a + running Milvus/Zilliz Cloud instance. + + See the following documentation for how to run a Milvus instance: + https://milvus.io/docs/install_standalone-docker.md + + If looking for a hosted Milvus, take a looka this documentation: + https://zilliz.com/cloud + + IF USING L2/IP metric IT IS HIGHLY SUGGESTED TO NORMALIZE YOUR DATA. + + The connection args used for this class comes in the form of a dict, + here are a few of the options: + address (str): The actual address of Milvus + instance. Example address: "localhost:19530" + uri (str): The uri of Milvus instance. Example uri: + "http://randomwebsite:19530", + "tcp:foobarsite:19530", + "https://ok.s3.south.com:19530". + host (str): The host of Milvus instance. Default at "localhost", + PyMilvus will fill in the default host if only port is provided. + port (str/int): The port of Milvus instance. Default at 19530, PyMilvus + will fill in the default port if only host is provided. + user (str): Use which user to connect to Milvus instance. If user and + password are provided, we will add related header in every RPC call. + password (str): Required when user is provided. The password + corresponding to the user. + secure (bool): Default is false. If set to true, tls will be enabled. + client_key_path (str): If use tls two-way authentication, need to + write the client.key path. + client_pem_path (str): If use tls two-way authentication, need to + write the client.pem path. + ca_pem_path (str): If use tls two-way authentication, need to write + the ca.pem path. + server_pem_path (str): If use tls one-way authentication, need to + write the server.pem path. + server_name (str): If use tls, need to write the common name. + + Args: + embedding_function (Embeddings): Function used to embed the text. + collection_name (str): Which Milvus collection to use. Defaults to + "LangChainCollection". + connection_args (Optional[dict[str, any]]): The arguments for connection to + Milvus/Zilliz instance. Defaults to DEFAULT_MILVUS_CONNECTION. + consistency_level (str): The consistency level to use for a collection. + Defaults to "Session". + index_params (Optional[dict]): Which index params to use. Defaults to + HNSW/AUTOINDEX depending on service. + search_params (Optional[dict]): Which search params to use. Defaults to + default of index. + drop_old (Optional[bool]): Whether to drop the current collection. Defaults + to False. + """ + try: + # Third Party + from pymilvus import Collection, utility + except ImportError: + raise ValueError( + "Could not import pymilvus python package. " + "Please install it with `pip install pymilvus`." + ) + + # Default search params when one is not provided. + self.default_search_params = { + "IVF_FLAT": {"metric_type": "L2", "params": {"nprobe": 10}}, + "IVF_SQ8": {"metric_type": "L2", "params": {"nprobe": 10}}, + "IVF_PQ": {"metric_type": "L2", "params": {"nprobe": 10}}, + "HNSW": {"metric_type": "L2", "params": {"ef": 10}}, + "RHNSW_FLAT": {"metric_type": "L2", "params": {"ef": 10}}, + "RHNSW_SQ": {"metric_type": "L2", "params": {"ef": 10}}, + "RHNSW_PQ": {"metric_type": "L2", "params": {"ef": 10}}, + "IVF_HNSW": {"metric_type": "L2", "params": {"nprobe": 10, "ef": 10}}, + "ANNOY": {"metric_type": "L2", "params": {"search_k": 10}}, + "AUTOINDEX": {"metric_type": "L2", "params": {}}, + } + + self.embedding_func = embedding_function + self.collection_name = collection_name + self.index_params = index_params + self.search_params = search_params + self.consistency_level = consistency_level + + # In order for a collection to be compatible, pk needs to be auto'id and int + self._primary_field = "pk" + # In order for compatiblility, the text field will need to be called "text" + self._text_field = "text" + # In order for compatbility, the vector field needs to be called "vector" + self._vector_field = "vector" + self.fields: list[str] = [] + # Create the connection to the server + if connection_args is None: + connection_args = DEFAULT_MILVUS_CONNECTION + self.alias = self._create_connection_alias(connection_args) + self.col: Optional[Collection] = None + + # Grab the existing colection if it exists + if utility.has_collection(self.collection_name, using=self.alias): + self.col = Collection( + self.collection_name, + using=self.alias, + ) + # If need to drop old, drop it + if drop_old and isinstance(self.col, Collection): + self.col.drop() + self.col = None + + # Initialize the vector store + self._init() + + def drop_collection(self): + # Third Party + from pymilvus import utility + + utility.drop_collection(collection_name=self.collection_name, using=self.alias) + + def _create_connection_alias(self, connection_args: dict) -> str: + """Create the connection to the Milvus server.""" + # Third Party + from pymilvus import MilvusException, connections + + # Grab the connection arguments that are used for checking existing connection + host: str = connection_args.get("host", None) + port: Union[str, int] = connection_args.get("port", None) + address: str = connection_args.get("address", None) + uri: str = connection_args.get("uri", None) + user = connection_args.get("user", None) + + # Order of use is host/port, uri, address + if host is not None and port is not None: + given_address = str(host) + ":" + str(port) + elif uri is not None: + given_address = uri.split("https://")[1] + elif address is not None: + given_address = address + else: + given_address = None + logger.debug("Missing standard address type for reuse atttempt") + + # User defaults to empty string when getting connection info + if user is not None: + tmp_user = user + else: + tmp_user = "" + + # If a valid address was given, then check if a connection exists + if given_address is not None: + for con in connections.list_connections(): + addr = connections.get_connection_addr(con[0]) + if ( + con[1] + and ("address" in addr) + and (addr["address"] == given_address) + and ("user" in addr) + and (addr["user"] == tmp_user) + ): + logger.debug("Using previous connection: %s", con[0]) + return con[0] + + # Generate a new connection if one doesnt exist + alias = uuid4().hex + try: + connections.connect(alias=alias, **connection_args) + logger.debug("Created new connection using: %s", alias) + return alias + except MilvusException as e: + logger.error("Failed to create new connection using: %s", alias) + raise e + + def _init( + self, embeddings: Optional[list] = None, metadatas: Optional[List[dict]] = None + ) -> None: + if embeddings is not None: + self._create_collection(embeddings, metadatas) + self._extract_fields() + self._create_index() + self._create_search_params() + self._load() + + def _create_collection( + self, embeddings: list, metadatas: Optional[List[dict]] = None + ) -> None: + # Third Party + from pymilvus import ( + Collection, + CollectionSchema, + DataType, + FieldSchema, + MilvusException, + ) + from pymilvus.orm.types import infer_dtype_bydata + + # Determine embedding dim + dim = len(embeddings[0]) + fields = [] + # Determine metadata schema + if metadatas: + # Create FieldSchema for each entry in metadata. + for key, value in metadatas[0].items(): + # Infer the corresponding datatype of the metadata + dtype = infer_dtype_bydata(value) + # Datatype isnt compatible + if dtype == DataType.UNKNOWN or dtype == DataType.NONE: + logger.error( + "Failure to create collection, unrecognized dtype for key: %s", + key, + ) + raise ValueError(f"Unrecognized datatype for {key}.") + # Dataype is a string/varchar equivalent + elif dtype == DataType.VARCHAR: + fields.append(FieldSchema(key, DataType.VARCHAR, max_length=65_535)) + else: + fields.append(FieldSchema(key, dtype)) + + # Create the text field + fields.append( + FieldSchema(self._text_field, DataType.VARCHAR, max_length=65_535) + ) + # Create the primary key field + fields.append( + FieldSchema( + self._primary_field, DataType.INT64, is_primary=True, auto_id=True + ) + ) + # Create the vector field, supports binary or float vectors + fields.append( + FieldSchema(self._vector_field, infer_dtype_bydata(embeddings[0]), dim=dim) + ) + + # Create the schema for the collection + schema = CollectionSchema(fields) + + # Create the collection + try: + self.col = Collection( + name=self.collection_name, + schema=schema, + consistency_level=self.consistency_level, + using=self.alias, + ) + except MilvusException as e: + logger.error( + "Failed to create collection: %s error: %s", self.collection_name, e + ) + raise e + + def _extract_fields(self) -> None: + """Grab the existing fields from the Collection""" + # Third Party + from pymilvus import Collection + + if isinstance(self.col, Collection): + schema = self.col.schema + for x in schema.fields: + self.fields.append(x.name) + # Since primary field is auto-id, no need to track it + self.fields.remove(self._primary_field) + + def _get_index(self) -> Optional[Dict[str, Any]]: + """Return the vector index information if it exists""" + # Third Party + from pymilvus import Collection + + if isinstance(self.col, Collection): + for x in self.col.indexes: + if x.field_name == self._vector_field: + return x.to_dict() + return None + + def _create_index(self) -> None: + """Create a index on the collection""" + # Third Party + from pymilvus import Collection, MilvusException + + if isinstance(self.col, Collection) and self._get_index() is None: + try: + # If no index params, use a default HNSW based one + if self.index_params is None: + self.index_params = { + "metric_type": "L2", + "index_type": "HNSW", + "params": {"M": 8, "efConstruction": 64}, + } + + try: + self.col.create_index( + self._vector_field, + index_params=self.index_params, + using=self.alias, + ) + + # If default did not work, most likely on Zilliz Cloud + except MilvusException: + # Use AUTOINDEX based index + self.index_params = { + "metric_type": "L2", + "index_type": "AUTOINDEX", + "params": {}, + } + self.col.create_index( + self._vector_field, + index_params=self.index_params, + using=self.alias, + ) + logger.debug( + "Successfully created an index on collection: %s", + self.collection_name, + ) + + except MilvusException as e: + logger.error( + "Failed to create an index on collection: %s", self.collection_name + ) + raise e + + def _create_search_params(self) -> None: + """Generate search params based on the current index type""" + # Third Party + from pymilvus import Collection + + if isinstance(self.col, Collection) and self.search_params is None: + index = self._get_index() + if index is not None: + index_type: str = index["index_param"]["index_type"] + metric_type: str = index["index_param"]["metric_type"] + self.search_params = self.default_search_params[index_type] + self.search_params["metric_type"] = metric_type + + def _load(self) -> None: + """Load the collection if available.""" + # Third Party + from pymilvus import Collection + + if isinstance(self.col, Collection) and self._get_index() is not None: + self.col.load() + + def add_texts( + self, + texts: Iterable[str], + metadatas: Optional[List[dict]] = None, + timeout: Optional[int] = None, + batch_size: int = 1000, + **kwargs: Any, + ) -> List[str]: + """Insert text data into Milvus. + + Inserting data when the collection has not be made yet will result + in creating a new Collection. The data of the first entity decides + the schema of the new collection, the dim is extracted from the first + embedding and the columns are decided by the first metadata dict. + Metada keys will need to be present for all inserted values. At + the moment there is no None equivalent in Milvus. + + Args: + texts (Iterable[str]): The texts to embed, it is assumed + that they all fit in memory. + metadatas (Optional[List[dict]]): Metadata dicts attached to each of + the texts. Defaults to None. + timeout (Optional[int]): Timeout for each batch insert. Defaults + to None. + batch_size (int, optional): Batch size to use for insertion. + Defaults to 1000. + + Raises: + MilvusException: Failure to add texts + + Returns: + List[str]: The resulting keys for each inserted element. + """ + # Third Party + from pymilvus import Collection, MilvusException + + texts = list(texts) + + try: + embeddings = self.embedding_func.embed_documents(texts) + except NotImplementedError: + embeddings = [self.embedding_func.embed_query(x) for x in texts] + + if len(embeddings) == 0: + logger.debug("Nothing to insert, skipping.") + return [] + # If the collection hasnt been initialized yet, perform all steps to do so + if not isinstance(self.col, Collection): + self._init(embeddings, metadatas) + + # Dict to hold all insert columns + insert_dict: dict[str, list] = { + self._text_field: texts, + self._vector_field: embeddings, + } + + # Collect the metadata into the insert dict. + if metadatas is not None: + for d in metadatas: + for key in self.fields: + if key not in [self._text_field, self._vector_field]: + value = d.get( + key, "" + ) # Use a default value (e.g., None) if the key is not present in the metadata + insert_dict.setdefault(key, []).append(value) + + # Total insert count + vectors: list = insert_dict[self._vector_field] + total_count = len(vectors) + + pks: list[str] = [] + + assert isinstance(self.col, Collection) + for i in range(0, total_count, batch_size): + # Grab end index + end = min(i + batch_size, total_count) + # Convert dict to list of lists batch for insertion + insert_list = [insert_dict[x][i:end] for x in self.fields] + # Insert into the collection. + try: + res: Collection + res = self.col.insert(insert_list, timeout=timeout, **kwargs) + pks.extend(res.primary_keys) + except MilvusException as e: + logger.error( + "Failed to insert batch starting at entity: %s/%s", i, total_count + ) + raise e + return pks + + def delete_entities(self, expr: str): + if self.col is None: + logger.debug("No existing collection to search.") + + self.col.delete(expr=expr) + + def similarity_search( + self, + query: str, + k: int = 5, + param: Optional[dict] = None, + expr: Optional[str] = None, + timeout: Optional[int] = None, + **kwargs: Any, + ) -> List[Document]: + """Perform a similarity search against the query string. + + Args: + query (str): The text to search. + k (int, optional): How many results to return. Defaults to 4. + param (dict, optional): The search params for the index type. + Defaults to None. + expr (str, optional): Filtering expression. Defaults to None. + timeout (int, optional): How long to wait before timeout error. + Defaults to None. + kwargs: Collection.search() keyword arguments. + + Returns: + List[Document]: Document results for search. + """ + if self.col is None: + logger.debug("No existing collection to search.") + return [] + res = self.similarity_search_with_score( + query=query, k=k, param=param, expr=expr, timeout=timeout, **kwargs + ) + return [doc for doc, _ in res] + + def similarity_search_by_vector( + self, + embedding: List[float], + k: int = 4, + param: Optional[dict] = None, + expr: Optional[str] = None, + timeout: Optional[int] = None, + **kwargs: Any, + ) -> List[Document]: + """Perform a similarity search against the query string. + + Args: + embedding (List[float]): The embedding vector to search. + k (int, optional): How many results to return. Defaults to 4. + param (dict, optional): The search params for the index type. + Defaults to None. + expr (str, optional): Filtering expression. Defaults to None. + timeout (int, optional): How long to wait before timeout error. + Defaults to None. + kwargs: Collection.search() keyword arguments. + + Returns: + List[Document]: Document results for search. + """ + if self.col is None: + logger.debug("No existing collection to search.") + return [] + res = self.similarity_search_with_score_by_vector( + embedding=embedding, k=k, param=param, expr=expr, timeout=timeout, **kwargs + ) + return [doc for doc, _ in res] + + def similarity_search_with_score( + self, + query: str, + k: int = 4, + param: Optional[dict] = None, + expr: Optional[str] = None, + timeout: Optional[int] = None, + **kwargs: Any, + ) -> List[Tuple[Document, float]]: + """Perform a search on a query string and return results with score. + + For more information about the search parameters, take a look at the pymilvus + documentation found here: + https://milvus.io/api-reference/pymilvus/v2.2.6/Collection/search().md + + Args: + query (str): The text being searched. + k (int, optional): The amount of results ot return. Defaults to 4. + param (dict): The search params for the specified index. + Defaults to None. + expr (str, optional): Filtering expression. Defaults to None. + timeout (int, optional): How long to wait before timeout error. + Defaults to None. + kwargs: Collection.search() keyword arguments. + + Returns: + List[float], List[Tuple[Document, any, any]]: + """ + if self.col is None: + logger.debug("No existing collection to search.") + return [] + + # Embed the query text. + embedding = self.embedding_func.embed_query(query) + + res = self.similarity_search_with_score_by_vector( + embedding=embedding, k=k, param=param, expr=expr, timeout=timeout, **kwargs + ) + return res + + def similarity_search_with_score_by_vector( + self, + embedding: List[float], + k: int = 4, + param: Optional[dict] = None, + expr: Optional[str] = None, + timeout: Optional[int] = None, + **kwargs: Any, + ) -> List[Tuple[Document, float]]: + """Perform a search on a query string and return results with score. + + For more information about the search parameters, take a look at the pymilvus + documentation found here: + https://milvus.io/api-reference/pymilvus/v2.2.6/Collection/search().md + + Args: + embedding (List[float]): The embedding vector being searched. + k (int, optional): The amount of results ot return. Defaults to 4. + param (dict): The search params for the specified index. + Defaults to None. + expr (str, optional): Filtering expression. Defaults to None. + timeout (int, optional): How long to wait before timeout error. + Defaults to None. + kwargs: Collection.search() keyword arguments. + + Returns: + List[Tuple[Document, float]]: Result doc and score. + """ + if self.col is None: + logger.debug("No existing collection to search.") + return [] + + if param is None: + param = self.search_params + + # Determine result metadata fields. + output_fields = self.fields[:] + output_fields.remove(self._vector_field) + + # Perform the search. + res = self.col.search( + data=[embedding], + anns_field=self._vector_field, + param=param, + limit=k, + expr=expr, + output_fields=output_fields, + timeout=timeout, + **kwargs, + ) + # Organize results. + ret = [] + for result in res[0]: + meta = {x: result.entity.get(x) for x in output_fields} + doc = Document(page_content=meta.pop(self._text_field), metadata=meta) + pair = (doc, result.score) + ret.append(pair) + + return ret + + def max_marginal_relevance_search( + self, + query: str, + k: int = 4, + fetch_k: int = 20, + lambda_mult: float = 0.5, + param: Optional[dict] = None, + expr: Optional[str] = None, + timeout: Optional[int] = None, + **kwargs: Any, + ) -> List[Document]: + """Perform a search and return results that are reordered by MMR. + + Args: + query (str): The text being searched. + k (int, optional): How many results to give. Defaults to 4. + fetch_k (int, optional): Total results to select k from. + Defaults to 20. + lambda_mult: Number between 0 and 1 that determines the degree + of diversity among the results with 0 corresponding + to maximum diversity and 1 to minimum diversity. + Defaults to 0.5 + param (dict, optional): The search params for the specified index. + Defaults to None. + expr (str, optional): Filtering expression. Defaults to None. + timeout (int, optional): How long to wait before timeout error. + Defaults to None. + kwargs: Collection.search() keyword arguments. + + + Returns: + List[Document]: Document results for search. + """ + if self.col is None: + logger.debug("No existing collection to search.") + return [] + + embedding = self.embedding_func.embed_query(query) + + return self.max_marginal_relevance_search_by_vector( + embedding=embedding, + k=k, + fetch_k=fetch_k, + lambda_mult=lambda_mult, + param=param, + expr=expr, + timeout=timeout, + **kwargs, + ) + + def max_marginal_relevance_search_by_vector( + self, + embedding: List[float], + k: int = 4, + fetch_k: int = 20, + lambda_mult: float = 0.5, + param: Optional[dict] = None, + expr: Optional[str] = None, + timeout: Optional[int] = None, + **kwargs: Any, + ) -> List[Document]: + """Perform a search and return results that are reordered by MMR. + + Args: + embedding (str): The embedding vector being searched. + k (int, optional): How many results to give. Defaults to 4. + fetch_k (int, optional): Total results to select k from. + Defaults to 20. + lambda_mult: Number between 0 and 1 that determines the degree + of diversity among the results with 0 corresponding + to maximum diversity and 1 to minimum diversity. + Defaults to 0.5 + param (dict, optional): The search params for the specified index. + Defaults to None. + expr (str, optional): Filtering expression. Defaults to None. + timeout (int, optional): How long to wait before timeout error. + Defaults to None. + kwargs: Collection.search() keyword arguments. + + Returns: + List[Document]: Document results for search. + """ + if self.col is None: + logger.debug("No existing collection to search.") + return [] + + if param is None: + param = self.search_params + + # Determine result metadata fields. + output_fields = self.fields[:] + output_fields.remove(self._vector_field) + + # Perform the search. + res = self.col.search( + data=[embedding], + anns_field=self._vector_field, + param=param, + limit=fetch_k, + expr=expr, + output_fields=output_fields, + timeout=timeout, + **kwargs, + ) + # Organize results. + ids = [] + documents = [] + scores = [] + for result in res[0]: + meta = {x: result.entity.get(x) for x in output_fields} + doc = Document(page_content=meta.pop(self._text_field), metadata=meta) + documents.append(doc) + scores.append(result.score) + ids.append(result.id) + + vectors = self.col.query( + expr=f"{self._primary_field} in {ids}", + output_fields=[self._primary_field, self._vector_field], + timeout=timeout, + ) + # Reorganize the results from query to match search order. + vectors = {x[self._primary_field]: x[self._vector_field] for x in vectors} + + ordered_result_embeddings = [vectors[x] for x in ids] + + # Get the new order of results. + new_ordering = maximal_marginal_relevance( + np.array(embedding), ordered_result_embeddings, k=k, lambda_mult=lambda_mult + ) + + # Reorder the values and return. + ret = [] + for x in new_ordering: + # Function can return -1 index + if x == -1: + break + else: + ret.append(documents[x]) + return ret + + @classmethod + def from_texts( + cls, + texts: List[str], + embedding: Embeddings, + metadatas: Optional[List[dict]] = None, + collection_name: str = "LangChainCollection", + connection_args: Dict[str, Any] = DEFAULT_MILVUS_CONNECTION, + consistency_level: str = "Session", + index_params: Optional[dict] = None, + search_params: Optional[dict] = None, + drop_old: bool = False, + **kwargs: Any, + ): + """Create a Milvus collection, indexes it with HNSW, and insert data. + + Args: + texts (List[str]): Text data. + embedding (Embeddings): Embedding function. + metadatas (Optional[List[dict]]): Metadata for each text if it exists. + Defaults to None. + collection_name (str, optional): Collection name to use. Defaults to + "LangChainCollection". + connection_args (dict[str, Any], optional): Connection args to use. Defaults + to DEFAULT_MILVUS_CONNECTION. + consistency_level (str, optional): Which consistency level to use. Defaults + to "Session". + index_params (Optional[dict], optional): Which index_params to use. Defaults + to None. + search_params (Optional[dict], optional): Which search params to use. + Defaults to None. + drop_old (Optional[bool], optional): Whether to drop the collection with + that name if it exists. Defaults to False. + + Returns: + Milvus: Milvus Vector Store + """ + vector_db = cls( + embedding_function=embedding, + collection_name=collection_name, + connection_args=connection_args, + consistency_level=consistency_level, + index_params=index_params, + search_params=search_params, + drop_old=drop_old, + **kwargs, + ) + vector_db.add_texts(texts=texts, metadatas=metadatas) + return vector_db diff --git a/services/backend/task_executor/requirements.txt b/services/backend/task_executor/requirements.txt new file mode 100644 index 0000000..24d4970 --- /dev/null +++ b/services/backend/task_executor/requirements.txt @@ -0,0 +1,20 @@ +aioredis==2.0.1 +beanie==1.23.6 +celery==5.3.6 +fastapi==0.105.0 +googlesearch-python==1.2.3 +langchain==0.0.351 +motor==3.3.2 +openai==1.6.1 +pydantic==1.10.9 +python-multipart==0.0.6 +redis==5.0.1 +requests==2.31.0 +uvicorn==0.25.0 +websockets==12.0 +pymilvus==2.3.4 +pypdf2==3.0.1 +spacy==3.7.2 +markdownify==0.11.6 +playwright==1.39.0 +tiktoken==0.5.2