From e024afc81ca3656ad176740ba2dcf6e6d409b30d Mon Sep 17 00:00:00 2001 From: "Timothy J. Baek" Date: Sun, 2 Jun 2024 16:56:35 -0700 Subject: [PATCH] feat: openai manifold --- .../providers/ollama_manifold_pipeline.py | 2 +- .../providers/openai_manifold_pipeline.py | 123 ++++++++++++++++++ main.py | 14 ++ 3 files changed, 138 insertions(+), 1 deletion(-) create mode 100644 examples/pipelines/providers/openai_manifold_pipeline.py diff --git a/examples/pipelines/providers/ollama_manifold_pipeline.py b/examples/pipelines/providers/ollama_manifold_pipeline.py index d2e9fcec..4393b3de 100644 --- a/examples/pipelines/providers/ollama_manifold_pipeline.py +++ b/examples/pipelines/providers/ollama_manifold_pipeline.py @@ -58,7 +58,7 @@ def get_ollama_models(self): print(f"Error: {e}") return [ { - "id": self.id, + "id": "error", "name": "Could not fetch models from Ollama, please update the URL in the valves.", }, ] diff --git a/examples/pipelines/providers/openai_manifold_pipeline.py b/examples/pipelines/providers/openai_manifold_pipeline.py new file mode 100644 index 00000000..0d667f56 --- /dev/null +++ b/examples/pipelines/providers/openai_manifold_pipeline.py @@ -0,0 +1,123 @@ +from typing import List, Union, Generator, Iterator +from schemas import OpenAIChatMessage +from pydantic import BaseModel + +import os +import requests + + +class Pipeline: + class Valves(BaseModel): + OPENAI_API_BASE_URL: str = "https://api.openai.com/v1" + OPENAI_API_KEY: str = "" + pass + + def __init__(self): + self.type = "manifold" + # Optionally, you can set the id and name of the pipeline. + # Best practice is to not specify the id so that it can be automatically inferred from the filename, so that users can install multiple versions of the same pipeline. + # The identifier must be unique across all pipelines. + # The identifier must be an alphanumeric string that can include underscores or hyphens. It cannot contain spaces, special characters, slashes, or backslashes. + # self.id = "openai_pipeline" + self.name = "OpenAI: " + + self.valves = self.Valves( + **{ + "OPENAI_API_KEY": os.getenv( + "OPENAI_API_KEY", "your-openai-api-key-here" + ) + } + ) + + self.pipelines = self.get_openai_models() + pass + + async def on_startup(self): + # This function is called when the server is started. + print(f"on_startup:{__name__}") + pass + + async def on_shutdown(self): + # This function is called when the server is stopped. + print(f"on_shutdown:{__name__}") + pass + + async def on_valves_updated(self): + # This function is called when the valves are updated. + print(f"on_valves_updated:{__name__}") + self.pipelines = self.get_openai_models() + pass + + def get_openai_models(self): + if self.valves.OPENAI_API_KEY: + try: + headers = {} + headers["Authorization"] = f"Bearer {self.valves.OPENAI_API_KEY}" + headers["Content-Type"] = "application/json" + + r = requests.get( + f"{self.valves.OPENAI_API_BASE_URL}/models", headers=headers + ) + + models = r.json() + return [ + { + "id": model["id"], + "name": model["name"] if "name" in model else model["id"], + } + for model in models["data"] + if "gpt" in model["id"] + ] + + except Exception as e: + + print(f"Error: {e}") + return [ + { + "id": "error", + "name": "Could not fetch models from OpenAI, please update the API Key in the valves.", + }, + ] + else: + return [] + + def pipe( + self, user_message: str, model_id: str, messages: List[dict], body: dict + ) -> Union[str, Generator, Iterator]: + # This is where you can add your custom pipelines like RAG. + print(f"pipe:{__name__}") + + print(messages) + print(user_message) + + headers = {} + headers["Authorization"] = f"Bearer {self.valves.OPENAI_API_KEY}" + headers["Content-Type"] = "application/json" + + payload = {**body, "model": model_id} + + if "user" in payload: + del payload["user"] + if "chat_id" in payload: + del payload["chat_id"] + if "title" in payload: + del payload["title"] + + print(payload) + + try: + r = requests.post( + url=f"{self.valves.OPENAI_API_BASE_URL}/chat/completions", + json=payload, + headers=headers, + stream=True, + ) + + r.raise_for_status() + + if body["stream"]: + return r.iter_lines() + else: + return r.json() + except Exception as e: + return f"Error: {e}" diff --git a/main.py b/main.py index 6e5c23bc..0bb917ff 100644 --- a/main.py +++ b/main.py @@ -506,6 +506,13 @@ async def filter_inlet(pipeline_id: str, form_data: FilterForm): detail=f"Filter {pipeline_id} not found", ) + try: + pipeline = app.state.PIPELINES[form_data.body["model"]] + if pipeline["type"] == "manifold": + pipeline_id = pipeline_id.split(".")[0] + except: + pass + pipeline = PIPELINE_MODULES[pipeline_id] try: @@ -531,6 +538,13 @@ async def filter_outlet(pipeline_id: str, form_data: FilterForm): detail=f"Filter {pipeline_id} not found", ) + try: + pipeline = app.state.PIPELINES[form_data.body["model"]] + if pipeline["type"] == "manifold": + pipeline_id = pipeline_id.split(".")[0] + except: + pass + pipeline = PIPELINE_MODULES[pipeline_id] try: