Skip to content

Commit

Permalink
feat: openai manifold
Browse files Browse the repository at this point in the history
  • Loading branch information
tjbck committed Jun 2, 2024
1 parent c19f52b commit e024afc
Show file tree
Hide file tree
Showing 3 changed files with 138 additions and 1 deletion.
2 changes: 1 addition & 1 deletion examples/pipelines/providers/ollama_manifold_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def get_ollama_models(self):
print(f"Error: {e}")
return [
{
"id": self.id,
"id": "error",
"name": "Could not fetch models from Ollama, please update the URL in the valves.",
},
]
Expand Down
123 changes: 123 additions & 0 deletions examples/pipelines/providers/openai_manifold_pipeline.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
from typing import List, Union, Generator, Iterator
from schemas import OpenAIChatMessage
from pydantic import BaseModel

import os
import requests


class Pipeline:
class Valves(BaseModel):
OPENAI_API_BASE_URL: str = "https://api.openai.com/v1"
OPENAI_API_KEY: str = ""
pass

def __init__(self):
self.type = "manifold"
# Optionally, you can set the id and name of the pipeline.
# Best practice is to not specify the id so that it can be automatically inferred from the filename, so that users can install multiple versions of the same pipeline.
# The identifier must be unique across all pipelines.
# The identifier must be an alphanumeric string that can include underscores or hyphens. It cannot contain spaces, special characters, slashes, or backslashes.
# self.id = "openai_pipeline"
self.name = "OpenAI: "

self.valves = self.Valves(
**{
"OPENAI_API_KEY": os.getenv(
"OPENAI_API_KEY", "your-openai-api-key-here"
)
}
)

self.pipelines = self.get_openai_models()
pass

async def on_startup(self):
# This function is called when the server is started.
print(f"on_startup:{__name__}")
pass

async def on_shutdown(self):
# This function is called when the server is stopped.
print(f"on_shutdown:{__name__}")
pass

async def on_valves_updated(self):
# This function is called when the valves are updated.
print(f"on_valves_updated:{__name__}")
self.pipelines = self.get_openai_models()
pass

def get_openai_models(self):
if self.valves.OPENAI_API_KEY:
try:
headers = {}
headers["Authorization"] = f"Bearer {self.valves.OPENAI_API_KEY}"
headers["Content-Type"] = "application/json"

r = requests.get(
f"{self.valves.OPENAI_API_BASE_URL}/models", headers=headers
)

models = r.json()
return [
{
"id": model["id"],
"name": model["name"] if "name" in model else model["id"],
}
for model in models["data"]
if "gpt" in model["id"]
]

except Exception as e:

print(f"Error: {e}")
return [
{
"id": "error",
"name": "Could not fetch models from OpenAI, please update the API Key in the valves.",
},
]
else:
return []

def pipe(
self, user_message: str, model_id: str, messages: List[dict], body: dict
) -> Union[str, Generator, Iterator]:
# This is where you can add your custom pipelines like RAG.
print(f"pipe:{__name__}")

print(messages)
print(user_message)

headers = {}
headers["Authorization"] = f"Bearer {self.valves.OPENAI_API_KEY}"
headers["Content-Type"] = "application/json"

payload = {**body, "model": model_id}

if "user" in payload:
del payload["user"]
if "chat_id" in payload:
del payload["chat_id"]
if "title" in payload:
del payload["title"]

print(payload)

try:
r = requests.post(
url=f"{self.valves.OPENAI_API_BASE_URL}/chat/completions",
json=payload,
headers=headers,
stream=True,
)

r.raise_for_status()

if body["stream"]:
return r.iter_lines()
else:
return r.json()
except Exception as e:
return f"Error: {e}"
14 changes: 14 additions & 0 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -506,6 +506,13 @@ async def filter_inlet(pipeline_id: str, form_data: FilterForm):
detail=f"Filter {pipeline_id} not found",
)

try:
pipeline = app.state.PIPELINES[form_data.body["model"]]
if pipeline["type"] == "manifold":
pipeline_id = pipeline_id.split(".")[0]
except:
pass

pipeline = PIPELINE_MODULES[pipeline_id]

try:
Expand All @@ -531,6 +538,13 @@ async def filter_outlet(pipeline_id: str, form_data: FilterForm):
detail=f"Filter {pipeline_id} not found",
)

try:
pipeline = app.state.PIPELINES[form_data.body["model"]]
if pipeline["type"] == "manifold":
pipeline_id = pipeline_id.split(".")[0]
except:
pass

pipeline = PIPELINE_MODULES[pipeline_id]

try:
Expand Down

0 comments on commit e024afc

Please sign in to comment.