diff --git a/.github/dependabot.yml b/.github/dependabot.yml
index 166376305a..af0a8ed0ee 100644
--- a/.github/dependabot.yml
+++ b/.github/dependabot.yml
@@ -3,10 +3,10 @@ updates:
   - package-ecosystem: pip
     directory: '/backend'
     schedule:
-      interval: weekly
+      interval: monthly
     target-branch: 'dev'
   - package-ecosystem: 'github-actions'
     directory: '/'
     schedule:
       # Check for updates to GitHub Actions every week
-      interval: 'weekly'
+      interval: monthly
diff --git a/.github/workflows/build-release.yml b/.github/workflows/build-release.yml
index cae363f42a..443d904199 100644
--- a/.github/workflows/build-release.yml
+++ b/.github/workflows/build-release.yml
@@ -10,61 +10,63 @@ jobs:
     runs-on: ubuntu-latest
     steps:
-    - name: Checkout repository
-      uses: actions/checkout@v4
+      - name: Checkout repository
+        uses: actions/checkout@v4
-    - name: Check for changes in package.json
-      run: |
-        git diff --cached --diff-filter=d package.json || {
-          echo "No changes to package.json"
-          exit 1
-        }
-
-    - name: Get version number from package.json
-      id: get_version
-      run: |
-        VERSION=$(jq -r '.version' package.json)
-        echo "::set-output name=version::$VERSION"
+      - name: Check for changes in package.json
+        run: |
+          git diff --cached --diff-filter=d package.json || {
+            echo "No changes to package.json"
+            exit 1
+          }
-    - name: Extract latest CHANGELOG entry
-      id: changelog
-      run: |
-        CHANGELOG_CONTENT=$(awk 'BEGIN {print_section=0;} /^## \[/ {if (print_section == 0) {print_section=1;} else {exit;}} print_section {print;}' CHANGELOG.md)
-        CHANGELOG_ESCAPED=$(echo "$CHANGELOG_CONTENT" | sed ':a;N;$!ba;s/\n/%0A/g')
-        echo "Extracted latest release notes from CHANGELOG.md:"
-        echo -e "$CHANGELOG_CONTENT"
-        echo "::set-output name=content::$CHANGELOG_ESCAPED"
+      - name: Get version number from package.json
+        id: get_version
+        run: |
+          VERSION=$(jq -r '.version' package.json)
+          echo "::set-output name=version::$VERSION"
-    - name: Create GitHub release
-      uses: actions/github-script@v7
-      with:
-        github-token: ${{ secrets.GITHUB_TOKEN }}
-        script: |
-          const changelog = `${{ steps.changelog.outputs.content }}`;
-          const release = await github.rest.repos.createRelease({
-            owner: context.repo.owner,
-            repo: context.repo.repo,
-            tag_name: `v${{ steps.get_version.outputs.version }}`,
-            name: `v${{ steps.get_version.outputs.version }}`,
-            body: changelog,
-          })
-          console.log(`Created release ${release.data.html_url}`)
+      - name: Extract latest CHANGELOG entry
+        id: changelog
+        run: |
+          CHANGELOG_CONTENT=$(awk 'BEGIN {print_section=0;} /^## \[/ {if (print_section == 0) {print_section=1;} else {exit;}} print_section {print;}' CHANGELOG.md)
+          CHANGELOG_ESCAPED=$(echo "$CHANGELOG_CONTENT" | sed ':a;N;$!ba;s/\n/%0A/g')
+          echo "Extracted latest release notes from CHANGELOG.md:"
+          echo -e "$CHANGELOG_CONTENT"
+          echo "::set-output name=content::$CHANGELOG_ESCAPED"
-    - name: Upload package to GitHub release
-      uses: actions/upload-artifact@v4
-      with:
-        name: package
-        path: .
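> Review note: the `::set-output` workflow command used in the `get_version` and `changelog` steps above has been deprecated by GitHub Actions since late 2022. A sketch of the `$GITHUB_OUTPUT` equivalent, keeping the same step ids and output names; the heredoc delimiter is the documented way to pass multiline values, which would make the `%0A` escaping unnecessary:

```yaml
- name: Get version number from package.json
  id: get_version
  run: |
    VERSION=$(jq -r '.version' package.json)
    echo "version=$VERSION" >> "$GITHUB_OUTPUT"

- name: Extract latest CHANGELOG entry
  id: changelog
  run: |
    CHANGELOG_CONTENT=$(awk 'BEGIN {print_section=0;} /^## \[/ {if (print_section == 0) {print_section=1;} else {exit;}} print_section {print;}' CHANGELOG.md)
    # Multiline outputs use a delimiter block instead of %0A escaping
    {
      echo 'content<<CHANGELOG_EOF'
      echo "$CHANGELOG_CONTENT"
      echo 'CHANGELOG_EOF'
    } >> "$GITHUB_OUTPUT"
```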
-      env:
-        GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
+      - name: Create GitHub release
+        uses: actions/github-script@v7
+        with:
+          github-token: ${{ secrets.GITHUB_TOKEN }}
+          script: |
+            const changelog = `${{ steps.changelog.outputs.content }}`;
+            const release = await github.rest.repos.createRelease({
+              owner: context.repo.owner,
+              repo: context.repo.repo,
+              tag_name: `v${{ steps.get_version.outputs.version }}`,
+              name: `v${{ steps.get_version.outputs.version }}`,
+              body: changelog,
+            })
+            console.log(`Created release ${release.data.html_url}`)
-    - name: Trigger Docker build workflow
-      uses: actions/github-script@v7
-      with:
-        script: |
-          github.rest.actions.createWorkflowDispatch({
-            owner: context.repo.owner,
-            repo: context.repo.repo,
-            workflow_id: 'docker-build.yaml',
-            ref: 'v${{ steps.get_version.outputs.version }}',
-          })
+      - name: Upload package to GitHub release
+        uses: actions/upload-artifact@v4
+        with:
+          name: package
+          path: |
+            .
+            !.git
+        env:
+          GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
+
+      - name: Trigger Docker build workflow
+        uses: actions/github-script@v7
+        with:
+          script: |
+            github.rest.actions.createWorkflowDispatch({
+              owner: context.repo.owner,
+              repo: context.repo.repo,
+              workflow_id: 'docker-build.yaml',
+              ref: 'v${{ steps.get_version.outputs.version }}',
+            })
diff --git a/CHANGELOG.md b/CHANGELOG.md
index 1b6bdd98fb..6b86b26001 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -5,6 +5,50 @@ All notable changes to this project will be documented in this file.
 The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/),
 and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

+## [0.3.15] - 2024-08-21
+
+### Added
+
+- **🔗 Temporary Chat Activation**: Integrated a new URL parameter 'temporary-chat=true' to enable temporary chat sessions directly through the URL.
+- **🌄 ComfyUI Seed Node Support**: Introduced seed node support in ComfyUI for image generation, allowing users to specify node IDs for randomized seed assignment.
+
+### Fixed
+
+- **🛠️ Tools and Functions**: Resolved a critical issue where Tools and Functions were not functioning properly, restoring full capability and reliability to these essential features.
+- **🔘 Chat Action Button in Many Model Chat**: Fixed malfunctioning chat action buttons in many model chat environments, ensuring smoother and more responsive user interaction.
+- **⏪ Many Model Chat Compatibility**: Restored backward compatibility for many model chats.
+
+## [0.3.14] - 2024-08-21
+
+### Added
+
+- **🛠️ Custom ComfyUI Workflow**: Introduces a new, customizable ComfyUI workflow for a more tailored user experience, deprecating several older environment variables.
+- **🔀 Merge Responses in Many Model Chat**: Enhances the dialogue by merging responses from multiple models into a single, coherent reply, improving the interaction quality in many model chats.
+- **✅ Multiple Instances of Same Model in Chats**: Enhanced many model chat to support adding multiple instances of the same model.
+- **🔧 Quick Actions in Model Workspace**: Enhanced Shift key quick actions for hiding/unhiding and deleting models, facilitating a smoother workflow.
+- **🗨️ Markdown Rendering in User Messages**: User messages are now rendered in Markdown, enhancing readability and interaction.
+- **💬 Temporary Chat Feature**: Introduced a temporary chat feature, deprecating the old chat history setting to enhance user interaction flexibility.
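> Review note: tying the two temporary-chat entries together — per the 0.3.15 "Temporary Chat Activation" entry above, a temporary session can now be opened directly by URL; the host below is hypothetical:

```text
https://your-open-webui.example.com/?temporary-chat=true
```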
+- **🖋️ User Message Editing**: Enhanced the user chat editing feature to allow saving changes without sending, providing more flexibility in message management.
+- **🛡️ Security Enhancements**: Various security improvements implemented across the platform to ensure a safer user experience.
+- **🌍 Updated Translations**: Enhanced translations for Chinese, Ukrainian, and Bahasa Malaysia, improving localization and user comprehension.
+
+### Fixed
+
+- **📑 Mermaid Rendering Issue**: Addressed issues with Mermaid chart rendering to ensure clean and clear visual data representation.
+- **🎭 PWA Icon Maskability**: Fixed the Progressive Web App icon to be maskable, ensuring proper display on various device home screens.
+- **🔀 Cloned Model Chat Freezing Issue**: Fixed a bug where cloning many model chats would cause freezing, enhancing stability and responsiveness.
+- **🔍 Generic Error Handling and Refinements**: Various minor fixes and refinements to address previously untracked issues, ensuring smoother operations.
+
+### Changed
+
+- **🖼️ Image Generation Refactor**: Overhauled image generation processes for improved efficiency and quality.
+- **🔨 Refactor Tool and Function Calling**: Refactored tool and function calling mechanisms for improved clarity and maintainability.
+- **🌐 Backend Library Updates**: Updated critical backend libraries, including SQLAlchemy, uvicorn[standard], faster-whisper, bcrypt, and boto3, for enhanced performance and security.
+
+### Removed
+
+- **🚫 Deprecated ComfyUI Environment Variables**: Removed several outdated environment variables related to ComfyUI settings, simplifying configuration management.
+
 ## [0.3.13] - 2024-08-14

 ### Added
diff --git a/backend/apps/audio/main.py b/backend/apps/audio/main.py
index 20519b59b1..d66a9fa11e 100644
--- a/backend/apps/audio/main.py
+++ b/backend/apps/audio/main.py
@@ -1,5 +1,12 @@
-import os
+import hashlib
+import json
 import logging
+import os
+import uuid
+from functools import lru_cache
+from pathlib import Path
+
+import requests

 from fastapi import (
     FastAPI,
     Request,
@@ -8,34 +15,14 @@
     status,
     UploadFile,
     File,
-    Form,
 )
-from fastapi.responses import StreamingResponse, JSONResponse, FileResponse
-
 from fastapi.middleware.cors import CORSMiddleware
+from fastapi.responses import FileResponse
 from pydantic import BaseModel
-
-import uuid
-import requests
-import hashlib
-from pathlib import Path
-import json
-
-from constants import ERROR_MESSAGES
-from utils.utils import (
-    decode_token,
-    get_current_user,
-    get_verified_user,
-    get_admin_user,
-)
-from utils.misc import calculate_sha256
-
-
 from config import (
     SRC_LOG_LEVELS,
     CACHE_DIR,
-    UPLOAD_DIR,
     WHISPER_MODEL,
     WHISPER_MODEL_DIR,
     WHISPER_MODEL_AUTO_UPDATE,
@@ -51,6 +38,13 @@
     AUDIO_TTS_MODEL,
     AUDIO_TTS_VOICE,
     AppConfig,
+    CORS_ALLOW_ORIGIN,
 )
+from constants import ERROR_MESSAGES
+from utils.utils import (
+    get_current_user,
+    get_verified_user,
+    get_admin_user,
+)

 log = logging.getLogger(__name__)
@@ -59,7 +53,7 @@
 app = FastAPI()
 app.add_middleware(
     CORSMiddleware,
-    allow_origins=["*"],
+    allow_origins=CORS_ALLOW_ORIGIN,
     allow_credentials=True,
     allow_methods=["*"],
     allow_headers=["*"],
@@ -261,6 +255,13 @@ async def speech(request: Request, user=Depends(get_verified_user)):
         raise HTTPException(status_code=400, detail="Invalid JSON payload")

     voice_id = payload.get("voice", "")
+
+    if voice_id not in get_available_voices():
+        raise HTTPException(
+            status_code=400,
+            detail="Invalid voice id",
+        )
+
     url = f"https://api.elevenlabs.io/v1/text-to-speech/{voice_id}"

     headers = {
@@ -466,39 +467,58 @@ async def get_models(user=Depends(get_verified_user)):
     return {"models": get_available_models()}

-def get_available_voices() -> list[dict]:
+def get_available_voices() -> dict:
+    """Returns {voice_id: voice_name} dict"""
+    ret = {}
     if app.state.config.TTS_ENGINE == "openai":
-        return [
-            {"name": "alloy", "id": "alloy"},
-            {"name": "echo", "id": "echo"},
-            {"name": "fable", "id": "fable"},
-            {"name": "onyx", "id": "onyx"},
-            {"name": "nova", "id": "nova"},
-            {"name": "shimmer", "id": "shimmer"},
-        ]
-    elif app.state.config.TTS_ENGINE == "elevenlabs":
-        headers = {
-            "xi-api-key": app.state.config.TTS_API_KEY,
-            "Content-Type": "application/json",
+        ret = {
+            "alloy": "alloy",
+            "echo": "echo",
+            "fable": "fable",
+            "onyx": "onyx",
+            "nova": "nova",
+            "shimmer": "shimmer",
         }
-
+    elif app.state.config.TTS_ENGINE == "elevenlabs":
         try:
-            response = requests.get(
-                "https://api.elevenlabs.io/v1/voices", headers=headers
-            )
-            response.raise_for_status()
-            voices_data = response.json()
+            ret = get_elevenlabs_voices()
+        except Exception:
+            # Don't cache failures here; get_elevenlabs_voices raises on error
+            pass

-            voices = []
-            for voice in voices_data.get("voices", []):
-                voices.append({"name": voice["name"], "id": voice["voice_id"]})
-            return voices
-        except requests.RequestException as e:
-            log.error(f"Error fetching voices: {str(e)}")
+    return ret
+
+
+@lru_cache
+def get_elevenlabs_voices() -> dict:
+    """
+    Note: set the following in your .env file to use ElevenLabs:
+    AUDIO_TTS_ENGINE=elevenlabs
+    AUDIO_TTS_API_KEY=sk_...  # Your ElevenLabs API key
+    AUDIO_TTS_VOICE=EXAVITQu4vr4xnSDxMaL  # From https://api.elevenlabs.io/v1/voices
+    AUDIO_TTS_MODEL=eleven_multilingual_v2
+    """
+    headers = {
+        "xi-api-key": app.state.config.TTS_API_KEY,
+        "Content-Type": "application/json",
+    }
+    try:
+        # TODO: Add retries
+        response = requests.get("https://api.elevenlabs.io/v1/voices", headers=headers)
+        response.raise_for_status()
+        voices_data = response.json()

-    return []
+        voices = {}
+        for voice in voices_data.get("voices", []):
+            voices[voice["voice_id"]] = voice["name"]
+    except requests.RequestException as e:
+        # Raise instead of returning, so @lru_cache does not cache a failed lookup
+        log.error(f"Error fetching voices: {str(e)}")
+        raise RuntimeError(f"Error fetching voices: {str(e)}")
+
+    return voices


 @app.get("/voices")
 async def get_voices(user=Depends(get_verified_user)):
-    return {"voices": get_available_voices()}
+    return {"voices": [{"id": k, "name": v} for k, v in get_available_voices().items()]}
diff --git a/backend/apps/images/main.py b/backend/apps/images/main.py
index d2f5ddd5d6..25ed2c5176 100644
--- a/backend/apps/images/main.py
+++ b/backend/apps/images/main.py
@@ -1,26 +1,10 @@
-import re
-import requests
-import base64
 from fastapi import (
     FastAPI,
     Request,
     Depends,
     HTTPException,
-    status,
-    UploadFile,
-    File,
-    Form,
 )
 from fastapi.middleware.cors import CORSMiddleware
-
-from constants import ERROR_MESSAGES
-from utils.utils import (
-    get_verified_user,
-    get_admin_user,
-)
-
-from apps.images.utils.comfyui import ImageGenerationPayload, comfyui_generate_image
-from utils.misc import calculate_sha256
 from typing import Optional
 from pydantic import BaseModel
 from pathlib import Path
@@ -29,7 +13,21 @@
 import base64
 import json
 import logging
+import re
+import requests
+
+from utils.utils import (
+    get_verified_user,
+    get_admin_user,
+)
+
+from apps.images.utils.comfyui import (
+    ComfyUIWorkflow,
+    ComfyUIGenerateImageForm,
+    comfyui_generate_image,
+)
+from constants import ERROR_MESSAGES
 from config import (
     SRC_LOG_LEVELS,
     CACHE_DIR,
@@ -38,18 +36,14 @@
     AUTOMATIC1111_BASE_URL,
     AUTOMATIC1111_API_AUTH,
     COMFYUI_BASE_URL,
-    COMFYUI_CFG_SCALE,
-    COMFYUI_SAMPLER,
-    COMFYUI_SCHEDULER,
-    COMFYUI_SD3,
-    COMFYUI_FLUX,
-    COMFYUI_FLUX_WEIGHT_DTYPE,
-    COMFYUI_FLUX_FP8_CLIP,
+    COMFYUI_WORKFLOW,
+    COMFYUI_WORKFLOW_NODES,
     IMAGES_OPENAI_API_BASE_URL,
     IMAGES_OPENAI_API_KEY,
     IMAGE_GENERATION_MODEL,
     IMAGE_SIZE,
     IMAGE_STEPS,
+    CORS_ALLOW_ORIGIN,
     AppConfig,
 )
@@ -62,7 +56,7 @@
 app = FastAPI()
 app.add_middleware(
     CORSMiddleware,
-    allow_origins=["*"],
+    allow_origins=CORS_ALLOW_ORIGIN,
     allow_credentials=True,
     allow_methods=["*"],
     allow_headers=["*"],
@@ -81,188 +75,210 @@
 app.state.config.AUTOMATIC1111_BASE_URL = AUTOMATIC1111_BASE_URL
 app.state.config.AUTOMATIC1111_API_AUTH = AUTOMATIC1111_API_AUTH
 app.state.config.COMFYUI_BASE_URL = COMFYUI_BASE_URL
+app.state.config.COMFYUI_WORKFLOW = COMFYUI_WORKFLOW
+app.state.config.COMFYUI_WORKFLOW_NODES = COMFYUI_WORKFLOW_NODES

 app.state.config.IMAGE_SIZE = IMAGE_SIZE
 app.state.config.IMAGE_STEPS = IMAGE_STEPS
-app.state.config.COMFYUI_CFG_SCALE = COMFYUI_CFG_SCALE
-app.state.config.COMFYUI_SAMPLER = COMFYUI_SAMPLER
-app.state.config.COMFYUI_SCHEDULER = COMFYUI_SCHEDULER
-app.state.config.COMFYUI_SD3 = COMFYUI_SD3
-app.state.config.COMFYUI_FLUX = COMFYUI_FLUX
-app.state.config.COMFYUI_FLUX_WEIGHT_DTYPE = COMFYUI_FLUX_WEIGHT_DTYPE
-app.state.config.COMFYUI_FLUX_FP8_CLIP = COMFYUI_FLUX_FP8_CLIP
-
-
-def get_automatic1111_api_auth():
-    if app.state.config.AUTOMATIC1111_API_AUTH is None:
-        return ""
-    else:
-        auth1111_byte_string = app.state.config.AUTOMATIC1111_API_AUTH.encode("utf-8")
-        auth1111_base64_encoded_bytes = base64.b64encode(auth1111_byte_string)
-        auth1111_base64_encoded_string = auth1111_base64_encoded_bytes.decode("utf-8")
-        return f"Basic {auth1111_base64_encoded_string}"


 @app.get("/config")
 async def get_config(request: Request, user=Depends(get_admin_user)):
     return {
-        "engine": app.state.config.ENGINE,
         "enabled": app.state.config.ENABLED,
+        "engine": app.state.config.ENGINE,
+        "openai": {
+            "OPENAI_API_BASE_URL": app.state.config.OPENAI_API_BASE_URL,
+            "OPENAI_API_KEY": app.state.config.OPENAI_API_KEY,
+        },
+        "automatic1111": {
+            "AUTOMATIC1111_BASE_URL": app.state.config.AUTOMATIC1111_BASE_URL,
+            "AUTOMATIC1111_API_AUTH": app.state.config.AUTOMATIC1111_API_AUTH,
+        },
+        "comfyui": {
+            "COMFYUI_BASE_URL": app.state.config.COMFYUI_BASE_URL,
+            "COMFYUI_WORKFLOW": app.state.config.COMFYUI_WORKFLOW,
+            "COMFYUI_WORKFLOW_NODES": app.state.config.COMFYUI_WORKFLOW_NODES,
+        },
     }


-class ConfigUpdateForm(BaseModel):
-    engine: str
+class OpenAIConfigForm(BaseModel):
+    OPENAI_API_BASE_URL: str
+    OPENAI_API_KEY: str
+
+
+class Automatic1111ConfigForm(BaseModel):
+    AUTOMATIC1111_BASE_URL: str
+    AUTOMATIC1111_API_AUTH: str
+
+
+class ComfyUIConfigForm(BaseModel):
+    COMFYUI_BASE_URL: str
+    COMFYUI_WORKFLOW: str
+    COMFYUI_WORKFLOW_NODES: list[dict]
+
+
+class ConfigForm(BaseModel):
     enabled: bool
+    engine: str
+    openai: OpenAIConfigForm
+    automatic1111: Automatic1111ConfigForm
+    comfyui: ComfyUIConfigForm


 @app.post("/config/update")
-async def update_config(form_data: ConfigUpdateForm, user=Depends(get_admin_user)):
+async def update_config(form_data: ConfigForm, user=Depends(get_admin_user)):
     app.state.config.ENGINE = form_data.engine
     app.state.config.ENABLED = form_data.enabled

-    return {
-        "engine": app.state.config.ENGINE,
-        "enabled": app.state.config.ENABLED,
-    }
+    app.state.config.OPENAI_API_BASE_URL = form_data.openai.OPENAI_API_BASE_URL
+    app.state.config.OPENAI_API_KEY = form_data.openai.OPENAI_API_KEY

-class EngineUrlUpdateForm(BaseModel):
-    AUTOMATIC1111_BASE_URL: Optional[str] = None
-    AUTOMATIC1111_API_AUTH: Optional[str] = None
-    COMFYUI_BASE_URL: Optional[str] = None
+    app.state.config.AUTOMATIC1111_BASE_URL = (
+        form_data.automatic1111.AUTOMATIC1111_BASE_URL
+    )
+    app.state.config.AUTOMATIC1111_API_AUTH = (
+        form_data.automatic1111.AUTOMATIC1111_API_AUTH
+    )

+    app.state.config.COMFYUI_BASE_URL = form_data.comfyui.COMFYUI_BASE_URL
+    app.state.config.COMFYUI_WORKFLOW = form_data.comfyui.COMFYUI_WORKFLOW
+    app.state.config.COMFYUI_WORKFLOW_NODES = form_data.comfyui.COMFYUI_WORKFLOW_NODES

-@app.get("/url")
-async def get_engine_url(user=Depends(get_admin_user)):
     return {
-        "AUTOMATIC1111_BASE_URL": app.state.config.AUTOMATIC1111_BASE_URL,
-        "AUTOMATIC1111_API_AUTH": app.state.config.AUTOMATIC1111_API_AUTH,
-        "COMFYUI_BASE_URL": app.state.config.COMFYUI_BASE_URL,
+        "enabled": app.state.config.ENABLED,
+        "engine": app.state.config.ENGINE,
+        "openai": {
+            "OPENAI_API_BASE_URL": app.state.config.OPENAI_API_BASE_URL,
+            "OPENAI_API_KEY": app.state.config.OPENAI_API_KEY,
+        },
+        "automatic1111": {
+            "AUTOMATIC1111_BASE_URL": app.state.config.AUTOMATIC1111_BASE_URL,
+            "AUTOMATIC1111_API_AUTH": app.state.config.AUTOMATIC1111_API_AUTH,
+        },
+        "comfyui": {
+            "COMFYUI_BASE_URL": app.state.config.COMFYUI_BASE_URL,
+            "COMFYUI_WORKFLOW": app.state.config.COMFYUI_WORKFLOW,
+            "COMFYUI_WORKFLOW_NODES": app.state.config.COMFYUI_WORKFLOW_NODES,
+        },
     }


-@app.post("/url/update")
-async def update_engine_url(
-    form_data: EngineUrlUpdateForm, user=Depends(get_admin_user)
-):
-    if form_data.AUTOMATIC1111_BASE_URL is None:
-        app.state.config.AUTOMATIC1111_BASE_URL = AUTOMATIC1111_BASE_URL
+def get_automatic1111_api_auth():
+    if app.state.config.AUTOMATIC1111_API_AUTH is None:
+        return ""
     else:
-        url = form_data.AUTOMATIC1111_BASE_URL.strip("/")
+        auth1111_byte_string = app.state.config.AUTOMATIC1111_API_AUTH.encode("utf-8")
+        auth1111_base64_encoded_bytes = base64.b64encode(auth1111_byte_string)
+        auth1111_base64_encoded_string = auth1111_base64_encoded_bytes.decode("utf-8")
+        return f"Basic {auth1111_base64_encoded_string}"
+
+
+@app.get("/config/url/verify")
+async def verify_url(user=Depends(get_admin_user)):
+    if app.state.config.ENGINE == "automatic1111":
         try:
-            r = requests.head(url)
+            r = requests.get(
+                url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options",
+                headers={"authorization": get_automatic1111_api_auth()},
+            )
             r.raise_for_status()
-            app.state.config.AUTOMATIC1111_BASE_URL = url
+            return True
         except Exception as e:
+            app.state.config.ENABLED = False
             raise HTTPException(status_code=400, detail=ERROR_MESSAGES.INVALID_URL)
-
-    if form_data.COMFYUI_BASE_URL is None:
-        app.state.config.COMFYUI_BASE_URL = COMFYUI_BASE_URL
-    else:
-        url = form_data.COMFYUI_BASE_URL.strip("/")
-
+    elif app.state.config.ENGINE == "comfyui":
         try:
-            r = requests.head(url)
+            r = requests.get(url=f"{app.state.config.COMFYUI_BASE_URL}/object_info")
             r.raise_for_status()
-            app.state.config.COMFYUI_BASE_URL = url
+            return True
         except Exception as e:
+            app.state.config.ENABLED = False
             raise HTTPException(status_code=400, detail=ERROR_MESSAGES.INVALID_URL)
-
-    if form_data.AUTOMATIC1111_API_AUTH is None:
-        app.state.config.AUTOMATIC1111_API_AUTH = AUTOMATIC1111_API_AUTH
     else:
-        app.state.config.AUTOMATIC1111_API_AUTH = form_data.AUTOMATIC1111_API_AUTH
-
-    return {
-        "AUTOMATIC1111_BASE_URL": app.state.config.AUTOMATIC1111_BASE_URL,
-        "AUTOMATIC1111_API_AUTH": app.state.config.AUTOMATIC1111_API_AUTH,
-        "COMFYUI_BASE_URL": app.state.config.COMFYUI_BASE_URL,
-        "status": True,
-    }
+        return True


-class OpenAIConfigUpdateForm(BaseModel):
-    url: str
-    key: str
+def set_image_model(model: str):
+    app.state.config.MODEL = model
+    if app.state.config.ENGINE in ["", "automatic1111"]:
+        api_auth = get_automatic1111_api_auth()
+        r = requests.get(
+            url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options",
+            headers={"authorization": api_auth},
+        )
+        options = r.json()
+        if model != options["sd_model_checkpoint"]:
+            options["sd_model_checkpoint"] = model
+            r = requests.post(
+                url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options",
+                json=options,
+                headers={"authorization": api_auth},
+            )
+    return app.state.config.MODEL


-@app.get("/openai/config")
-async def get_openai_config(user=Depends(get_admin_user)):
-    return {
-        "OPENAI_API_BASE_URL": app.state.config.OPENAI_API_BASE_URL,
-        "OPENAI_API_KEY": app.state.config.OPENAI_API_KEY,
-    }
+def get_image_model():
+    if app.state.config.ENGINE == "openai":
+        return app.state.config.MODEL if app.state.config.MODEL else "dall-e-2"
+    elif app.state.config.ENGINE == "comfyui":
+        return app.state.config.MODEL if app.state.config.MODEL else ""
+    elif app.state.config.ENGINE == "automatic1111" or app.state.config.ENGINE == "":
+        try:
+            r = requests.get(
+                url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options",
+                headers={"authorization": get_automatic1111_api_auth()},
+            )
+            options = r.json()
+            return options["sd_model_checkpoint"]
+        except Exception as e:
+            app.state.config.ENABLED = False
+            raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e))


-@app.post("/openai/config/update")
-async def update_openai_config(
-    form_data: OpenAIConfigUpdateForm, user=Depends(get_admin_user)
-):
-    if form_data.key == "":
-        raise HTTPException(status_code=400, detail=ERROR_MESSAGES.API_KEY_NOT_FOUND)
+class ImageConfigForm(BaseModel):
+    MODEL: str
+    IMAGE_SIZE: str
+    IMAGE_STEPS: int

-    app.state.config.OPENAI_API_BASE_URL = form_data.url
-    app.state.config.OPENAI_API_KEY = form_data.key

+@app.get("/image/config")
+async def get_image_config(user=Depends(get_admin_user)):
     return {
-        "status": True,
-        "OPENAI_API_BASE_URL": app.state.config.OPENAI_API_BASE_URL,
-        "OPENAI_API_KEY": app.state.config.OPENAI_API_KEY,
+        "MODEL": app.state.config.MODEL,
+        "IMAGE_SIZE": app.state.config.IMAGE_SIZE,
+        "IMAGE_STEPS": app.state.config.IMAGE_STEPS,
     }


-class ImageSizeUpdateForm(BaseModel):
-    size: str
-
-
-@app.get("/size")
-async def get_image_size(user=Depends(get_admin_user)):
-    return {"IMAGE_SIZE": app.state.config.IMAGE_SIZE}
-
+@app.post("/image/config/update")
+async def update_image_config(form_data: ImageConfigForm, user=Depends(get_admin_user)):
+    app.state.config.MODEL = form_data.MODEL

-@app.post("/size/update")
-async def update_image_size(
-    form_data: ImageSizeUpdateForm, user=Depends(get_admin_user)
-):
-    pattern = r"^\d+x\d+$"  # Regular expression pattern
-    if re.match(pattern, form_data.size):
-        app.state.config.IMAGE_SIZE = form_data.size
-        return {
-            "IMAGE_SIZE": app.state.config.IMAGE_SIZE,
-            "status": True,
-        }
+    pattern = r"^\d+x\d+$"
+    if re.match(pattern, form_data.IMAGE_SIZE):
+        app.state.config.IMAGE_SIZE = form_data.IMAGE_SIZE
     else:
         raise HTTPException(
             status_code=400,
             detail=ERROR_MESSAGES.INCORRECT_FORMAT("  (e.g., 512x512)."),
         )

-class ImageStepsUpdateForm(BaseModel):
-    steps: int
-
-
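> Review note: the per-field admin endpoints removed here and just below (`/size`, `/steps`, `/openai/config`, `/url`, `/models/default`) are consolidated into `/config`, `/config/url/verify`, and `/image/config`. A sketch of driving the new image config endpoint from a client; the host, the `/images/api/v1` mount prefix (as the images app is mounted in the backend's main.py), the model name, and the token are assumptions:

```python
import requests

BASE = "http://localhost:8080/images/api/v1"  # assumed host + mount prefix
HEADERS = {"Authorization": "Bearer <admin-token>"}  # placeholder admin JWT

# Payload mirrors ImageConfigForm: MODEL, IMAGE_SIZE ("WxH"), IMAGE_STEPS (>= 0)
resp = requests.post(
    f"{BASE}/image/config/update",
    headers=HEADERS,
    json={
        "MODEL": "sd_xl_base_1.0.safetensors",  # illustrative checkpoint name
        "IMAGE_SIZE": "512x512",
        "IMAGE_STEPS": 50,
    },
)
resp.raise_for_status()
print(resp.json())  # -> {"MODEL": ..., "IMAGE_SIZE": ..., "IMAGE_STEPS": ...}
```

The POST returns the same MODEL/IMAGE_SIZE/IMAGE_STEPS trio as `GET /image/config`, so clients can treat both responses uniformly.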
-@app.get("/steps") -async def get_image_size(user=Depends(get_admin_user)): - return {"IMAGE_STEPS": app.state.config.IMAGE_STEPS} - - -@app.post("/steps/update") -async def update_image_size( - form_data: ImageStepsUpdateForm, user=Depends(get_admin_user) -): - if form_data.steps >= 0: - app.state.config.IMAGE_STEPS = form_data.steps - return { - "IMAGE_STEPS": app.state.config.IMAGE_STEPS, - "status": True, - } + if form_data.IMAGE_STEPS >= 0: + app.state.config.IMAGE_STEPS = form_data.IMAGE_STEPS else: raise HTTPException( status_code=400, detail=ERROR_MESSAGES.INCORRECT_FORMAT(" (e.g., 50)."), ) + return { + "MODEL": app.state.config.MODEL, + "IMAGE_SIZE": app.state.config.IMAGE_SIZE, + "IMAGE_STEPS": app.state.config.IMAGE_STEPS, + } + @app.get("/models") def get_models(user=Depends(get_verified_user)): @@ -273,18 +289,51 @@ def get_models(user=Depends(get_verified_user)): {"id": "dall-e-3", "name": "DALL·E 3"}, ] elif app.state.config.ENGINE == "comfyui": - + # TODO - get models from comfyui r = requests.get(url=f"{app.state.config.COMFYUI_BASE_URL}/object_info") info = r.json() - return list( - map( - lambda model: {"id": model, "name": model}, - info["CheckpointLoaderSimple"]["input"]["required"]["ckpt_name"][0], + workflow = json.loads(app.state.config.COMFYUI_WORKFLOW) + model_node_id = None + + for node in app.state.config.COMFYUI_WORKFLOW_NODES: + if node["type"] == "model": + if node["node_ids"]: + model_node_id = node["node_ids"][0] + break + + if model_node_id: + model_list_key = None + + print(workflow[model_node_id]["class_type"]) + for key in info[workflow[model_node_id]["class_type"]]["input"][ + "required" + ]: + if "_name" in key: + model_list_key = key + break + + if model_list_key: + return list( + map( + lambda model: {"id": model, "name": model}, + info[workflow[model_node_id]["class_type"]]["input"][ + "required" + ][model_list_key][0], + ) + ) + else: + return list( + map( + lambda model: {"id": model, "name": model}, + info["CheckpointLoaderSimple"]["input"]["required"][ + "ckpt_name" + ][0], + ) ) - ) - - else: + elif ( + app.state.config.ENGINE == "automatic1111" or app.state.config.ENGINE == "" + ): r = requests.get( url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/sd-models", headers={"authorization": get_automatic1111_api_auth()}, @@ -301,69 +350,11 @@ def get_models(user=Depends(get_verified_user)): raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e)) -@app.get("/models/default") -async def get_default_model(user=Depends(get_admin_user)): - try: - if app.state.config.ENGINE == "openai": - return { - "model": ( - app.state.config.MODEL if app.state.config.MODEL else "dall-e-2" - ) - } - elif app.state.config.ENGINE == "comfyui": - return {"model": (app.state.config.MODEL if app.state.config.MODEL else "")} - else: - r = requests.get( - url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options", - headers={"authorization": get_automatic1111_api_auth()}, - ) - options = r.json() - return {"model": options["sd_model_checkpoint"]} - except Exception as e: - app.state.config.ENABLED = False - raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e)) - - -class UpdateModelForm(BaseModel): - model: str - - -def set_model_handler(model: str): - if app.state.config.ENGINE in ["openai", "comfyui"]: - app.state.config.MODEL = model - return app.state.config.MODEL - else: - api_auth = get_automatic1111_api_auth() - r = requests.get( - url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options", - 
headers={"authorization": api_auth}, - ) - options = r.json() - - if model != options["sd_model_checkpoint"]: - options["sd_model_checkpoint"] = model - r = requests.post( - url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options", - json=options, - headers={"authorization": api_auth}, - ) - - return options - - -@app.post("/models/default/update") -def update_default_model( - form_data: UpdateModelForm, - user=Depends(get_verified_user), -): - return set_model_handler(form_data.model) - - class GenerateImageForm(BaseModel): model: Optional[str] = None prompt: str - n: int = 1 size: Optional[str] = None + n: int = 1 negative_prompt: Optional[str] = None @@ -479,7 +470,6 @@ async def image_generations( return images elif app.state.config.ENGINE == "comfyui": - data = { "prompt": form_data.prompt, "width": width, @@ -493,32 +483,20 @@ async def image_generations( if form_data.negative_prompt is not None: data["negative_prompt"] = form_data.negative_prompt - if app.state.config.COMFYUI_CFG_SCALE: - data["cfg_scale"] = app.state.config.COMFYUI_CFG_SCALE - - if app.state.config.COMFYUI_SAMPLER is not None: - data["sampler"] = app.state.config.COMFYUI_SAMPLER - - if app.state.config.COMFYUI_SCHEDULER is not None: - data["scheduler"] = app.state.config.COMFYUI_SCHEDULER - - if app.state.config.COMFYUI_SD3 is not None: - data["sd3"] = app.state.config.COMFYUI_SD3 - - if app.state.config.COMFYUI_FLUX is not None: - data["flux"] = app.state.config.COMFYUI_FLUX - - if app.state.config.COMFYUI_FLUX_WEIGHT_DTYPE is not None: - data["flux_weight_dtype"] = app.state.config.COMFYUI_FLUX_WEIGHT_DTYPE - - if app.state.config.COMFYUI_FLUX_FP8_CLIP is not None: - data["flux_fp8_clip"] = app.state.config.COMFYUI_FLUX_FP8_CLIP - - data = ImageGenerationPayload(**data) - + form_data = ComfyUIGenerateImageForm( + **{ + "workflow": ComfyUIWorkflow( + **{ + "workflow": app.state.config.COMFYUI_WORKFLOW, + "nodes": app.state.config.COMFYUI_WORKFLOW_NODES, + } + ), + **data, + } + ) res = await comfyui_generate_image( app.state.config.MODEL, - data, + form_data, user.id, app.state.config.COMFYUI_BASE_URL, ) @@ -532,13 +510,15 @@ async def image_generations( file_body_path = IMAGE_CACHE_DIR.joinpath(f"{image_filename}.json") with open(file_body_path, "w") as f: - json.dump(data.model_dump(exclude_none=True), f) + json.dump(form_data.model_dump(exclude_none=True), f) log.debug(f"images: {images}") return images - else: + elif ( + app.state.config.ENGINE == "automatic1111" or app.state.config.ENGINE == "" + ): if form_data.model: - set_model_handler(form_data.model) + set_image_model(form_data.model) data = { "prompt": form_data.prompt, @@ -560,7 +540,6 @@ async def image_generations( ) res = r.json() - log.debug(f"res: {res}") images = [] @@ -577,7 +556,6 @@ async def image_generations( except Exception as e: error = e - if r != None: data = r.json() if "error" in data: diff --git a/backend/apps/images/utils/comfyui.py b/backend/apps/images/utils/comfyui.py index f11dca57c5..1584842236 100644 --- a/backend/apps/images/utils/comfyui.py +++ b/backend/apps/images/utils/comfyui.py @@ -15,245 +15,6 @@ from typing import Optional -COMFYUI_DEFAULT_PROMPT = """ -{ - "3": { - "inputs": { - "seed": 0, - "steps": 20, - "cfg": 8, - "sampler_name": "euler", - "scheduler": "normal", - "denoise": 1, - "model": [ - "4", - 0 - ], - "positive": [ - "6", - 0 - ], - "negative": [ - "7", - 0 - ], - "latent_image": [ - "5", - 0 - ] - }, - "class_type": "KSampler", - "_meta": { - "title": "KSampler" - } - }, - "4": { - "inputs": { - 
"ckpt_name": "model.safetensors" - }, - "class_type": "CheckpointLoaderSimple", - "_meta": { - "title": "Load Checkpoint" - } - }, - "5": { - "inputs": { - "width": 512, - "height": 512, - "batch_size": 1 - }, - "class_type": "EmptyLatentImage", - "_meta": { - "title": "Empty Latent Image" - } - }, - "6": { - "inputs": { - "text": "Prompt", - "clip": [ - "4", - 1 - ] - }, - "class_type": "CLIPTextEncode", - "_meta": { - "title": "CLIP Text Encode (Prompt)" - } - }, - "7": { - "inputs": { - "text": "Negative Prompt", - "clip": [ - "4", - 1 - ] - }, - "class_type": "CLIPTextEncode", - "_meta": { - "title": "CLIP Text Encode (Prompt)" - } - }, - "8": { - "inputs": { - "samples": [ - "3", - 0 - ], - "vae": [ - "4", - 2 - ] - }, - "class_type": "VAEDecode", - "_meta": { - "title": "VAE Decode" - } - }, - "9": { - "inputs": { - "filename_prefix": "ComfyUI", - "images": [ - "8", - 0 - ] - }, - "class_type": "SaveImage", - "_meta": { - "title": "Save Image" - } - } -} -""" - -FLUX_DEFAULT_PROMPT = """ -{ - "5": { - "inputs": { - "width": 1024, - "height": 1024, - "batch_size": 1 - }, - "class_type": "EmptyLatentImage" - }, - "6": { - "inputs": { - "text": "Input Text Here", - "clip": [ - "11", - 0 - ] - }, - "class_type": "CLIPTextEncode" - }, - "8": { - "inputs": { - "samples": [ - "13", - 0 - ], - "vae": [ - "10", - 0 - ] - }, - "class_type": "VAEDecode" - }, - "9": { - "inputs": { - "filename_prefix": "ComfyUI", - "images": [ - "8", - 0 - ] - }, - "class_type": "SaveImage" - }, - "10": { - "inputs": { - "vae_name": "ae.safetensors" - }, - "class_type": "VAELoader" - }, - "11": { - "inputs": { - "clip_name1": "clip_l.safetensors", - "clip_name2": "t5xxl_fp16.safetensors", - "type": "flux" - }, - "class_type": "DualCLIPLoader" - }, - "12": { - "inputs": { - "unet_name": "flux1-dev.safetensors", - "weight_dtype": "default" - }, - "class_type": "UNETLoader" - }, - "13": { - "inputs": { - "noise": [ - "25", - 0 - ], - "guider": [ - "22", - 0 - ], - "sampler": [ - "16", - 0 - ], - "sigmas": [ - "17", - 0 - ], - "latent_image": [ - "5", - 0 - ] - }, - "class_type": "SamplerCustomAdvanced" - }, - "16": { - "inputs": { - "sampler_name": "euler" - }, - "class_type": "KSamplerSelect" - }, - "17": { - "inputs": { - "scheduler": "simple", - "steps": 20, - "denoise": 1, - "model": [ - "12", - 0 - ] - }, - "class_type": "BasicScheduler" - }, - "22": { - "inputs": { - "model": [ - "12", - 0 - ], - "conditioning": [ - "6", - 0 - ] - }, - "class_type": "BasicGuider" - }, - "25": { - "inputs": { - "noise_seed": 778937779713005 - }, - "class_type": "RandomNoise" - } -} -""" - def queue_prompt(prompt, client_id, base_url): log.info("queue_prompt") @@ -311,82 +72,71 @@ def get_images(ws, prompt, client_id, base_url): return {"data": output_images} -class ImageGenerationPayload(BaseModel): +class ComfyUINodeInput(BaseModel): + type: Optional[str] = None + node_ids: list[str] = [] + key: Optional[str] = "text" + value: Optional[str] = None + + +class ComfyUIWorkflow(BaseModel): + workflow: str + nodes: list[ComfyUINodeInput] + + +class ComfyUIGenerateImageForm(BaseModel): + workflow: ComfyUIWorkflow + prompt: str - negative_prompt: Optional[str] = "" - steps: Optional[int] = None - seed: Optional[int] = None + negative_prompt: Optional[str] = None width: int height: int n: int = 1 - cfg_scale: Optional[float] = None - sampler: Optional[str] = None - scheduler: Optional[str] = None - sd3: Optional[bool] = None - flux: Optional[bool] = None - flux_weight_dtype: Optional[str] = None - flux_fp8_clip: Optional[bool] = None 
+
+    steps: Optional[int] = None
+    seed: Optional[int] = None


 async def comfyui_generate_image(
-    model: str, payload: ImageGenerationPayload, client_id, base_url
+    model: str, payload: ComfyUIGenerateImageForm, client_id, base_url
 ):
     ws_url = base_url.replace("http://", "ws://").replace("https://", "wss://")
-
-    comfyui_prompt = json.loads(COMFYUI_DEFAULT_PROMPT)
-
-    if payload.cfg_scale:
-        comfyui_prompt["3"]["inputs"]["cfg"] = payload.cfg_scale
-
-    if payload.sampler:
-        comfyui_prompt["3"]["inputs"]["sampler"] = payload.sampler
-
-    if payload.scheduler:
-        comfyui_prompt["3"]["inputs"]["scheduler"] = payload.scheduler
-
-    if payload.sd3:
-        comfyui_prompt["5"]["class_type"] = "EmptySD3LatentImage"
-
-    if payload.steps:
-        comfyui_prompt["3"]["inputs"]["steps"] = payload.steps
-
-    comfyui_prompt["4"]["inputs"]["ckpt_name"] = model
-    comfyui_prompt["7"]["inputs"]["text"] = payload.negative_prompt
-    comfyui_prompt["3"]["inputs"]["seed"] = (
-        payload.seed if payload.seed else random.randint(0, 18446744073709551614)
-    )
-
-    # as Flux uses a completely different workflow, we must treat it specially
-    if payload.flux:
-        comfyui_prompt = json.loads(FLUX_DEFAULT_PROMPT)
-        comfyui_prompt["12"]["inputs"]["unet_name"] = model
-        comfyui_prompt["25"]["inputs"]["noise_seed"] = (
-            payload.seed if payload.seed else random.randint(0, 18446744073709551614)
-        )
-
-        if payload.sampler:
-            comfyui_prompt["16"]["inputs"]["sampler_name"] = payload.sampler
-
-        if payload.steps:
-            comfyui_prompt["17"]["inputs"]["steps"] = payload.steps
-
-        if payload.scheduler:
-            comfyui_prompt["17"]["inputs"]["scheduler"] = payload.scheduler
-
-        if payload.flux_weight_dtype:
-            comfyui_prompt["12"]["inputs"]["weight_dtype"] = payload.flux_weight_dtype
-
-        if payload.flux_fp8_clip:
-            comfyui_prompt["11"]["inputs"][
-                "clip_name2"
-            ] = "t5xxl_fp8_e4m3fn.safetensors"
-
-    comfyui_prompt["5"]["inputs"]["batch_size"] = payload.n
-    comfyui_prompt["5"]["inputs"]["width"] = payload.width
-    comfyui_prompt["5"]["inputs"]["height"] = payload.height
-
-    # set the text prompt for our positive CLIPTextEncode
-    comfyui_prompt["6"]["inputs"]["text"] = payload.prompt
+    workflow = json.loads(payload.workflow.workflow)
+
+    for node in payload.workflow.nodes:
+        if node.type:
+            if node.type == "model":
+                for node_id in node.node_ids:
+                    workflow[node_id]["inputs"][node.key] = model
+            elif node.type == "prompt":
+                for node_id in node.node_ids:
+                    workflow[node_id]["inputs"]["text"] = payload.prompt
+            elif node.type == "negative_prompt":
+                for node_id in node.node_ids:
+                    workflow[node_id]["inputs"]["text"] = payload.negative_prompt
+            elif node.type == "width":
+                for node_id in node.node_ids:
+                    workflow[node_id]["inputs"]["width"] = payload.width
+            elif node.type == "height":
+                for node_id in node.node_ids:
+                    workflow[node_id]["inputs"]["height"] = payload.height
+            elif node.type == "n":
+                for node_id in node.node_ids:
+                    workflow[node_id]["inputs"]["batch_size"] = payload.n
+            elif node.type == "steps":
+                for node_id in node.node_ids:
+                    workflow[node_id]["inputs"]["steps"] = payload.steps
+            elif node.type == "seed":
+                seed = (
+                    payload.seed
+                    if payload.seed
+                    else random.randint(0, 18446744073709551614)
+                )
+                for node_id in node.node_ids:
+                    workflow[node_id]["inputs"][node.key] = seed
+        else:
+            for node_id in node.node_ids:
+                workflow[node_id]["inputs"][node.key] = node.value

     try:
         ws = websocket.WebSocket()
@@ -397,9 +147,9 @@
         return None

     try:
-        images = await asyncio.to_thread(
-            get_images, ws, comfyui_prompt, client_id, base_url
-        )
+        log.info("Sending workflow to WebSocket server.")
+        log.info(f"Workflow: {workflow}")
+        images = await asyncio.to_thread(get_images, ws, workflow, client_id, base_url)
     except Exception as e:
         log.exception(f"Error while receiving images: {e}")
         images = None
diff --git a/backend/apps/ollama/main.py b/backend/apps/ollama/main.py
index 03a8e198ee..d3931b1ab9 100644
--- a/backend/apps/ollama/main.py
+++ b/backend/apps/ollama/main.py
@@ -41,6 +41,7 @@
     MODEL_FILTER_LIST,
     UPLOAD_DIR,
     AppConfig,
+    CORS_ALLOW_ORIGIN,
 )
 from utils.misc import (
     calculate_sha256,
@@ -55,7 +56,7 @@
 app = FastAPI()
 app.add_middleware(
     CORSMiddleware,
-    allow_origins=["*"],
+    allow_origins=CORS_ALLOW_ORIGIN,
     allow_credentials=True,
     allow_methods=["*"],
     allow_headers=["*"],
@@ -147,13 +148,17 @@ async def cleanup_response(
         await session.close()


-async def post_streaming_url(url: str, payload: str, stream: bool = True):
+async def post_streaming_url(url: str, payload: Union[str, bytes], stream: bool = True):
     r = None
     try:
         session = aiohttp.ClientSession(
             trust_env=True, timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT)
         )
-        r = await session.post(url, data=payload)
+        r = await session.post(
+            url,
+            data=payload,
+            headers={"Content-Type": "application/json"},
+        )
         r.raise_for_status()

         if stream:
@@ -422,6 +427,7 @@ async def copy_model(
         r = requests.request(
             method="POST",
             url=f"{url}/api/copy",
+            headers={"Content-Type": "application/json"},
             data=form_data.model_dump_json(exclude_none=True).encode(),
         )
@@ -470,6 +476,7 @@ async def delete_model(
         r = requests.request(
             method="DELETE",
             url=f"{url}/api/delete",
+            headers={"Content-Type": "application/json"},
             data=form_data.model_dump_json(exclude_none=True).encode(),
         )
         try:
@@ -510,6 +517,7 @@ async def show_model_info(form_data: ModelNameForm, user=Depends(get_verified_us
     r = requests.request(
         method="POST",
         url=f"{url}/api/show",
+        headers={"Content-Type": "application/json"},
         data=form_data.model_dump_json(exclude_none=True).encode(),
     )
     try:
@@ -567,6 +575,7 @@ async def generate_embeddings(
         r = requests.request(
             method="POST",
             url=f"{url}/api/embeddings",
+            headers={"Content-Type": "application/json"},
             data=form_data.model_dump_json(exclude_none=True).encode(),
         )
         try:
@@ -616,6 +625,7 @@ def generate_ollama_embeddings(
         r = requests.request(
             method="POST",
             url=f"{url}/api/embeddings",
+            headers={"Content-Type": "application/json"},
             data=form_data.model_dump_json(exclude_none=True).encode(),
         )
         try:
@@ -721,11 +731,8 @@ async def generate_chat_completion(
     url_idx: Optional[int] = None,
     user=Depends(get_verified_user),
 ):
-    log.debug(f"{form_data.model_dump_json(exclude_none=True).encode()}=")
-
-    payload = {
-        **form_data.model_dump(exclude_none=True, exclude=["metadata"]),
-    }
+    payload = {**form_data.model_dump(exclude_none=True)}
+    log.debug(f"{payload = }")

     if "metadata" in payload:
         del payload["metadata"]
diff --git a/backend/apps/openai/main.py b/backend/apps/openai/main.py
index d344c66222..9ad67c40c7 100644
--- a/backend/apps/openai/main.py
+++ b/backend/apps/openai/main.py
@@ -32,6 +32,7 @@
     ENABLE_MODEL_FILTER,
     MODEL_FILTER_LIST,
     AppConfig,
+    CORS_ALLOW_ORIGIN,
 )
 from typing import Optional, Literal, overload
@@ -45,7 +46,7 @@
 app = FastAPI()
 app.add_middleware(
     CORSMiddleware,
-    allow_origins=["*"],
+    allow_origins=CORS_ALLOW_ORIGIN,
     allow_credentials=True,
     allow_methods=["*"],
     allow_headers=["*"],
diff --git a/backend/apps/rag/main.py b/backend/apps/rag/main.py
index f9788556bc..7b2fbc6794 100644
--- a/backend/apps/rag/main.py
+++ b/backend/apps/rag/main.py
@@ -129,6 +129,7 @@
     RAG_WEB_SEARCH_RESULT_COUNT,
     RAG_WEB_SEARCH_CONCURRENT_REQUESTS,
     RAG_EMBEDDING_OPENAI_BATCH_SIZE,
+    CORS_ALLOW_ORIGIN,
 )

 from constants import ERROR_MESSAGES
@@ -240,12 +241,9 @@ def update_reranking_model(
     app.state.config.RAG_EMBEDDING_OPENAI_BATCH_SIZE,
 )

-origins = ["*"]
-
-
 app.add_middleware(
     CORSMiddleware,
-    allow_origins=origins,
+    allow_origins=CORS_ALLOW_ORIGIN,
     allow_credentials=True,
     allow_methods=["*"],
     allow_headers=["*"],
diff --git a/backend/apps/webui/main.py b/backend/apps/webui/main.py
index dddf3fbb2a..bede4b4f83 100644
--- a/backend/apps/webui/main.py
+++ b/backend/apps/webui/main.py
@@ -26,6 +26,7 @@
     apply_model_system_prompt_to_body,
 )

+from utils.tools import get_tools

 from config import (
     SHOW_ADMIN_DETAILS,
@@ -43,10 +44,12 @@
     JWT_EXPIRES_IN,
     WEBUI_BANNERS,
     ENABLE_COMMUNITY_SHARING,
+    ENABLE_MESSAGE_RATING,
     AppConfig,
     OAUTH_USERNAME_CLAIM,
     OAUTH_PICTURE_CLAIM,
     OAUTH_EMAIL_CLAIM,
+    CORS_ALLOW_ORIGIN,
 )

 from apps.socket.main import get_event_call, get_event_emitter
@@ -59,8 +62,6 @@

 app = FastAPI()

-origins = ["*"]
-
 app.state.config = AppConfig()

 app.state.config.ENABLE_SIGNUP = ENABLE_SIGNUP
@@ -82,6 +83,7 @@
 app.state.config.BANNERS = WEBUI_BANNERS

 app.state.config.ENABLE_COMMUNITY_SHARING = ENABLE_COMMUNITY_SHARING
+app.state.config.ENABLE_MESSAGE_RATING = ENABLE_MESSAGE_RATING

 app.state.config.OAUTH_USERNAME_CLAIM = OAUTH_USERNAME_CLAIM
 app.state.config.OAUTH_PICTURE_CLAIM = OAUTH_PICTURE_CLAIM
@@ -93,7 +95,7 @@

 app.add_middleware(
     CORSMiddleware,
-    allow_origins=origins,
+    allow_origins=CORS_ALLOW_ORIGIN,
     allow_credentials=True,
     allow_methods=["*"],
     allow_headers=["*"],
@@ -274,7 +276,13 @@ def get_function_params(function_module, form_data, user, extra_params={}):
 async def generate_function_chat_completion(form_data, user):
     model_id = form_data.get("model")
     model_info = Models.get_model_by_id(model_id)
-    metadata = form_data.pop("metadata", None)
+    metadata = form_data.pop("metadata", {})
+    files = metadata.get("files", [])
+    tool_ids = metadata.get("tool_ids", [])
+
+    # Check if tool_ids is None
+    if tool_ids is None:
+        tool_ids = []

     __event_emitter__ = None
     __event_call__ = None
@@ -286,6 +294,21 @@ async def generate_function_chat_completion(form_data, user):
         __event_call__ = get_event_call(metadata)
         __task__ = metadata.get("task", None)

+    extra_params = {
+        "__event_emitter__": __event_emitter__,
+        "__event_call__": __event_call__,
+        "__task__": __task__,
+    }
+    tools_params = {
+        **extra_params,
+        "__model__": app.state.MODELS[form_data["model"]],
+        "__messages__": form_data["messages"],
+        "__files__": files,
+    }
+
+    tools = get_tools(app, tool_ids, user, tools_params)
+    extra_params["__tools__"] = tools
+
     if model_info:
         if model_info.base_model_id:
             form_data["model"] = model_info.base_model_id
@@ -298,16 +321,7 @@ async def generate_function_chat_completion(form_data, user):
     function_module = get_function_module(pipe_id)

     pipe = function_module.pipe
-    params = get_function_params(
-        function_module,
-        form_data,
-        user,
-        {
-            "__event_emitter__": __event_emitter__,
-            "__event_call__": __event_call__,
-            "__task__": __task__,
-        },
-    )
+    params = get_function_params(function_module, form_data, user, extra_params)

     if form_data["stream"]:
diff --git a/backend/apps/webui/models/users.py b/backend/apps/webui/models/users.py
index 36dfa4f855..b6e85e2ca2 100644
--- a/backend/apps/webui/models/users.py
+++ b/backend/apps/webui/models/users.py
@@ -1,12 +1,10 @@
-from pydantic import BaseModel, ConfigDict, parse_obj_as
-from typing import Union, Optional
+from pydantic import BaseModel, ConfigDict
+from typing import Optional

 import time

 from sqlalchemy import String, Column, BigInteger, Text

-from utils.misc import get_gravatar_url
-
-from apps.webui.internal.db import Base, JSONField, Session, get_db
+from apps.webui.internal.db import Base, JSONField, get_db
 from apps.webui.models.chats import Chats

 ####################
@@ -78,7 +76,6 @@ class UserUpdateForm(BaseModel):


 class UsersTable:
-
     def insert_new_user(
         self,
         id: str,
@@ -122,7 +119,6 @@ def get_user_by_id(self, id: str) -> Optional[UserModel]:
     def get_user_by_api_key(self, api_key: str) -> Optional[UserModel]:
         try:
             with get_db() as db:
-
                 user = db.query(User).filter_by(api_key=api_key).first()
                 return UserModel.model_validate(user)
         except Exception:
@@ -131,7 +127,6 @@ def get_user_by_api_key(self, api_key: str) -> Optional[UserModel]:
     def get_user_by_email(self, email: str) -> Optional[UserModel]:
         try:
             with get_db() as db:
-
                 user = db.query(User).filter_by(email=email).first()
                 return UserModel.model_validate(user)
         except Exception:
@@ -140,7 +135,6 @@ def get_user_by_email(self, email: str) -> Optional[UserModel]:
     def get_user_by_oauth_sub(self, sub: str) -> Optional[UserModel]:
         try:
             with get_db() as db:
-
                 user = db.query(User).filter_by(oauth_sub=sub).first()
                 return UserModel.model_validate(user)
         except Exception:
@@ -195,7 +189,6 @@ def update_user_profile_image_url_by_id(
     def update_user_last_active_by_id(self, id: str) -> Optional[UserModel]:
         try:
             with get_db() as db:
-
                 db.query(User).filter_by(id=id).update(
                     {"last_active_at": int(time.time())}
                 )
diff --git a/backend/apps/webui/routers/auths.py b/backend/apps/webui/routers/auths.py
index e68cad7f00..91414bc643 100644
--- a/backend/apps/webui/routers/auths.py
+++ b/backend/apps/webui/routers/auths.py
@@ -356,6 +356,7 @@ async def get_admin_config(request: Request, user=Depends(get_admin_user)):
         "DEFAULT_USER_ROLE": request.app.state.config.DEFAULT_USER_ROLE,
         "JWT_EXPIRES_IN": request.app.state.config.JWT_EXPIRES_IN,
         "ENABLE_COMMUNITY_SHARING": request.app.state.config.ENABLE_COMMUNITY_SHARING,
+        "ENABLE_MESSAGE_RATING": request.app.state.config.ENABLE_MESSAGE_RATING,
     }


@@ -365,6 +366,7 @@ class AdminConfig(BaseModel):
     DEFAULT_USER_ROLE: str
     JWT_EXPIRES_IN: str
     ENABLE_COMMUNITY_SHARING: bool
+    ENABLE_MESSAGE_RATING: bool


 @router.post("/admin/config")
@@ -386,6 +388,7 @@ async def update_admin_config(
     request.app.state.config.ENABLE_COMMUNITY_SHARING = (
         form_data.ENABLE_COMMUNITY_SHARING
     )
+    request.app.state.config.ENABLE_MESSAGE_RATING = form_data.ENABLE_MESSAGE_RATING

     return {
         "SHOW_ADMIN_DETAILS": request.app.state.config.SHOW_ADMIN_DETAILS,
@@ -393,6 +396,7 @@ async def update_admin_config(
         "DEFAULT_USER_ROLE": request.app.state.config.DEFAULT_USER_ROLE,
         "JWT_EXPIRES_IN": request.app.state.config.JWT_EXPIRES_IN,
         "ENABLE_COMMUNITY_SHARING": request.app.state.config.ENABLE_COMMUNITY_SHARING,
+        "ENABLE_MESSAGE_RATING": request.app.state.config.ENABLE_MESSAGE_RATING,
     }
diff --git a/backend/apps/webui/routers/utils.py b/backend/apps/webui/routers/utils.py
index 7a3c339324..8bf8267da1 100644
--- a/backend/apps/webui/routers/utils.py
+++ b/backend/apps/webui/routers/utils.py
@@ -85,9 +85,10 @@ async def download_chat_as_pdf(
     pdf.add_font("NotoSans", "i", f"{FONTS_DIR}/NotoSans-Italic.ttf")
     pdf.add_font("NotoSansKR", "", f"{FONTS_DIR}/NotoSansKR-Regular.ttf")
     pdf.add_font("NotoSansJP", "", f"{FONTS_DIR}/NotoSansJP-Regular.ttf")
+    pdf.add_font("NotoSansSC", "", f"{FONTS_DIR}/NotoSansSC-Regular.ttf")

     pdf.set_font("NotoSans", size=12)
-    pdf.set_fallback_fonts(["NotoSansKR", "NotoSansJP"])
+    pdf.set_fallback_fonts(["NotoSansKR", "NotoSansJP", "NotoSansSC"])

     pdf.set_auto_page_break(auto=True, margin=15)
diff --git a/backend/config.py b/backend/config.py
index d564aed2ce..05913065bc 100644
--- a/backend/config.py
+++ b/backend/config.py
@@ -3,6 +3,8 @@
 import logging
 import importlib.metadata
 import pkgutil
+from urllib.parse import urlparse
+
 import chromadb
 from chromadb import Settings
 from bs4 import BeautifulSoup
@@ -174,7 +176,6 @@ def parse_section(section):

 CHANGELOG = changelog_json

-
 ####################################
 # SAFE_MODE
 ####################################
@@ -806,10 +807,24 @@ def create_config_file(file_path):
     os.environ.get("USER_PERMISSIONS_CHAT_DELETION", "True").lower() == "true"
 )

+USER_PERMISSIONS_CHAT_EDITING = (
+    os.environ.get("USER_PERMISSIONS_CHAT_EDITING", "True").lower() == "true"
+)
+
+USER_PERMISSIONS_CHAT_TEMPORARY = (
+    os.environ.get("USER_PERMISSIONS_CHAT_TEMPORARY", "True").lower() == "true"
+)
+
 USER_PERMISSIONS = PersistentConfig(
     "USER_PERMISSIONS",
     "ui.user_permissions",
-    {"chat": {"deletion": USER_PERMISSIONS_CHAT_DELETION}},
+    {
+        "chat": {
+            "deletion": USER_PERMISSIONS_CHAT_DELETION,
+            "editing": USER_PERMISSIONS_CHAT_EDITING,
+            "temporary": USER_PERMISSIONS_CHAT_TEMPORARY,
+        }
+    },
 )

 ENABLE_MODEL_FILTER = PersistentConfig(
@@ -840,6 +855,47 @@
     os.environ.get("ENABLE_COMMUNITY_SHARING", "True").lower() == "true",
 )

+ENABLE_MESSAGE_RATING = PersistentConfig(
+    "ENABLE_MESSAGE_RATING",
+    "ui.enable_message_rating",
+    os.environ.get("ENABLE_MESSAGE_RATING", "True").lower() == "true",
+)
+
+
+def validate_cors_origins(origins):
+    for origin in origins:
+        if origin != "*":
+            validate_cors_origin(origin)
+
+
+def validate_cors_origin(origin):
+    parsed_url = urlparse(origin)
+
+    # Check if the scheme is either http or https
+    if parsed_url.scheme not in ["http", "https"]:
+        raise ValueError(
+            f"Invalid scheme in CORS_ALLOW_ORIGIN: '{origin}'. Only 'http' and 'https' are allowed."
+        )
+
+    # Ensure that the netloc (domain + port) is present, indicating it's a valid URL
+    if not parsed_url.netloc:
+        raise ValueError(f"Invalid URL structure in CORS_ALLOW_ORIGIN: '{origin}'.")
+
+
+# For production, you should only need one host as
+# fastapi serves the svelte-kit built frontend and backend from the same host and port.
+# To test CORS_ALLOW_ORIGIN locally, you can set something like
+# CORS_ALLOW_ORIGIN=http://localhost:5173;http://localhost:8080
+# in your .env file depending on your frontend port, 5173 in this case.
+CORS_ALLOW_ORIGIN = os.environ.get("CORS_ALLOW_ORIGIN", "*").split(";")
+
+if "*" in CORS_ALLOW_ORIGIN:
+    log.warning(
+        "\n\nWARNING: CORS_ALLOW_ORIGIN IS SET TO '*' - NOT RECOMMENDED FOR PRODUCTION DEPLOYMENTS.\n"
+    )
+
+validate_cors_origins(CORS_ALLOW_ORIGIN)
+

 class BannerModel(BaseModel):
     id: str
@@ -895,10 +951,7 @@ class BannerModel(BaseModel):
         "task.title.prompt_template",
         os.environ.get(
             "TITLE_GENERATION_PROMPT_TEMPLATE",
-            """Here is the query:
-{{prompt:middletruncate:8000}}
-
-Create a concise, 3-5 word phrase with an emoji as a title for the previous query. Suitable Emojis for the summary can be used to enhance understanding but avoid quotation marks or special formatting. RESPOND ONLY WITH THE TITLE TEXT.
+            """Create a concise, 3-5 word title with an emoji as a title for the prompt in the given language. Suitable Emojis for the summary can be used to enhance understanding but avoid quotation marks or special formatting. RESPOND ONLY WITH THE TITLE TEXT.

 Examples of titles:
 📉 Stock Market Trends
 🍪 Perfect Chocolate Chip Recipe
 Evolution of Music Streaming
 Remote Work Productivity Tips
 Artificial Intelligence in Healthcare
-🎮 Video Game Development Insights""",
+🎮 Video Game Development Insights
+
+Prompt: {{prompt:middletruncate:8000}}""",
         ),
     )
@@ -939,8 +994,7 @@ class BannerModel(BaseModel):
         "task.tools.prompt_template",
         os.environ.get(
             "TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE",
-            """Tools: {{TOOLS}}
-If a function tool doesn't match the query, return an empty string. Else, pick a function tool, fill in the parameters from the function tool's schema, and return it in the format { "name": \"functionName\", "parameters": { "key": "value" } }. Only pick a function if the user asks. Only return the object. Do not return any other text.""",
+            """Available Tools: {{TOOLS}}\nReturn an empty string if no tools match the query. If a function tool matches, construct and return a JSON object in the format {\"name\": \"functionName\", \"parameters\": {\"requiredFunctionParamKey\": \"requiredFunctionParamValue\"}} using the appropriate tool and its parameters. Only return the object and limit the response to the JSON object without additional text.""",
         ),
     )
@@ -1056,7 +1110,7 @@
 RAG_EMBEDDING_OPENAI_BATCH_SIZE = PersistentConfig(
     "RAG_EMBEDDING_OPENAI_BATCH_SIZE",
     "rag.embedding_openai_batch_size",
-    os.environ.get("RAG_EMBEDDING_OPENAI_BATCH_SIZE", 1),
+    int(os.environ.get("RAG_EMBEDDING_OPENAI_BATCH_SIZE", "1")),
 )

 RAG_RERANKING_MODEL = PersistentConfig(
@@ -1263,7 +1317,7 @@
 IMAGE_GENERATION_ENGINE = PersistentConfig(
     "IMAGE_GENERATION_ENGINE",
     "image_generation.engine",
-    os.getenv("IMAGE_GENERATION_ENGINE", ""),
+    os.getenv("IMAGE_GENERATION_ENGINE", "openai"),
 )

 ENABLE_IMAGE_GENERATION = PersistentConfig(
@@ -1288,46 +1342,127 @@
     os.getenv("COMFYUI_BASE_URL", ""),
 )

-COMFYUI_CFG_SCALE = PersistentConfig(
-    "COMFYUI_CFG_SCALE",
-    "image_generation.comfyui.cfg_scale",
-    os.getenv("COMFYUI_CFG_SCALE", ""),
-)
-
-COMFYUI_SAMPLER = PersistentConfig(
-    "COMFYUI_SAMPLER",
-    "image_generation.comfyui.sampler",
-    os.getenv("COMFYUI_SAMPLER", ""),
-)
-
-COMFYUI_SCHEDULER = PersistentConfig(
-    "COMFYUI_SCHEDULER",
-    "image_generation.comfyui.scheduler",
-    os.getenv("COMFYUI_SCHEDULER", ""),
-)
-
-COMFYUI_SD3 = PersistentConfig(
-    "COMFYUI_SD3",
-    "image_generation.comfyui.sd3",
-    os.environ.get("COMFYUI_SD3", "").lower() == "true",
-)
+COMFYUI_DEFAULT_WORKFLOW = """
+{
+  "3": {
+    "inputs": {
+      "seed": 0,
+      "steps": 20,
+      "cfg": 8,
+      "sampler_name": "euler",
+      "scheduler": "normal",
+      "denoise": 1,
+      "model": [
+        "4",
+        0
+      ],
+      "positive": [
+        "6",
+        0
+      ],
+      "negative": [
+        "7",
+        0
+      ],
+      "latent_image": [
+        "5",
+        0
+      ]
+    },
+    "class_type": "KSampler",
+    "_meta": {
+      "title": "KSampler"
+    }
+  },
+  "4": {
+    "inputs": {
+      "ckpt_name": "model.safetensors"
+    },
+    "class_type": "CheckpointLoaderSimple",
+    "_meta": {
+      "title": "Load Checkpoint"
+    }
+  },
+  "5": {
+    "inputs": {
+      "width": 512,
+      "height": 512,
+      "batch_size": 1
+    },
+    "class_type": "EmptyLatentImage",
+    "_meta": {
+      "title": "Empty Latent Image"
+    }
+  },
+  "6": {
+    "inputs": {
+      "text": "Prompt",
+      "clip": [
+        "4",
+        1
+      ]
+    },
+    "class_type": "CLIPTextEncode",
+    "_meta": {
+      "title": "CLIP Text Encode (Prompt)"
+    }
+  },
+  "7": {
+    "inputs": {
+      "text": "",
+      "clip": [
+        "4",
+        1
+      ]
+    },
+    "class_type": "CLIPTextEncode",
+    "_meta": {
+      "title": "CLIP Text Encode (Prompt)"
+    }
+  },
+  "8": {
+    "inputs": {
+      "samples": [
+        "3",
+        0
+      ],
+      "vae": [
+        "4",
+        2
+      ]
+    },
+    "class_type": "VAEDecode",
+    "_meta": {
+      "title": "VAE Decode"
+    }
+  },
+  "9": {
+    "inputs": {
+      "filename_prefix": "ComfyUI",
+      "images": [
+        "8",
+        0
+      ]
+    },
+    "class_type": "SaveImage",
+    "_meta": {
+      "title": "Save Image"
+    }
+  }
+}
+"""

-COMFYUI_FLUX = PersistentConfig(
-    "COMFYUI_FLUX",
-    "image_generation.comfyui.flux",
-    os.environ.get("COMFYUI_FLUX", "").lower() == "true",
-)

-COMFYUI_FLUX_WEIGHT_DTYPE = PersistentConfig(
-    "COMFYUI_FLUX_WEIGHT_DTYPE",
-    "image_generation.comfyui.flux_weight_dtype",
-    os.getenv("COMFYUI_FLUX_WEIGHT_DTYPE", ""),
+COMFYUI_WORKFLOW = PersistentConfig(
+    "COMFYUI_WORKFLOW",
+    "image_generation.comfyui.workflow",
+    os.getenv("COMFYUI_WORKFLOW", COMFYUI_DEFAULT_WORKFLOW),
 )

-COMFYUI_FLUX_FP8_CLIP = PersistentConfig(
-    "COMFYUI_FLUX_FP8_CLIP",
-    "image_generation.comfyui.flux_fp8_clip",
-    os.environ.get("COMFYUI_FLUX_FP8_CLIP", "").lower() == "true",
+COMFYUI_WORKFLOW_NODES = PersistentConfig(
+    "COMFYUI_WORKFLOW_NODES",
+    "image_generation.comfyui.nodes",
+    [],
 )

 IMAGES_OPENAI_API_BASE_URL = PersistentConfig(
@@ -1410,13 +1545,13 @@
 AUDIO_TTS_MODEL = PersistentConfig(
     "AUDIO_TTS_MODEL",
     "audio.tts.model",
-    os.getenv("AUDIO_TTS_MODEL", "tts-1"),
+    os.getenv("AUDIO_TTS_MODEL", "tts-1"),  # OpenAI default model
 )

 AUDIO_TTS_VOICE = PersistentConfig(
     "AUDIO_TTS_VOICE",
     "audio.tts.voice",
-    os.getenv("AUDIO_TTS_VOICE", "alloy"),
+    os.getenv("AUDIO_TTS_VOICE", "alloy"),  # OpenAI default voice
 )
diff --git a/backend/constants.py b/backend/constants.py
index b9c7fc430d..d55216bb5d 100644
--- a/backend/constants.py
+++ b/backend/constants.py
@@ -100,3 +100,4 @@ def __str__(self) -> str:
     EMOJI_GENERATION = "emoji_generation"
     QUERY_GENERATION = "query_generation"
     FUNCTION_CALLING = "function_calling"
+    MOA_RESPONSE_GENERATION = "moa_response_generation"
diff --git a/backend/main.py b/backend/main.py
index d8ce5f5d78..c59f631b3c 100644
--- a/backend/main.py
+++ b/backend/main.py
@@ -14,6 +14,7 @@
 import mimetypes
 import shutil
 import inspect
+from typing import Optional

 from fastapi import FastAPI, Request, Depends, status, UploadFile, File, Form
 from fastapi.staticfiles import StaticFiles
@@ -51,15 +52,13 @@

 from pydantic import BaseModel
-from typing import Optional

 from apps.webui.models.auths import Auths
 from apps.webui.models.models import Models
-from apps.webui.models.tools import Tools
 from apps.webui.models.functions import Functions
-from apps.webui.models.users import Users
+from apps.webui.models.users import Users, UserModel

-from apps.webui.utils import load_toolkit_module_by_id, load_function_module_by_id
+from apps.webui.utils import load_function_module_by_id

 from utils.utils import (
     get_admin_user,
@@ -68,12 +67,16 @@
     get_http_authorization_cred,
     get_password_hash,
     create_token,
+    decode_token,
 )
 from utils.task import (
     title_generation_template,
     search_query_generation_template,
     tools_function_calling_generation_template,
+    moa_response_generation_template,
 )
+
+from utils.tools import get_tools
 from utils.misc import (
     get_last_user_message,
     add_or_update_system_message,
@@ -118,6 +121,7 @@
     WEBUI_SESSION_COOKIE_SECURE,
     ENABLE_ADMIN_CHAT_ACCESS,
     AppConfig,
+    CORS_ALLOW_ORIGIN,
 )
 from constants import ERROR_MESSAGES, WEBHOOK_MESSAGES, TASKS
@@ -208,8 +212,6 @@ async def lifespan(app: FastAPI):

 app.state.MODELS = {}

-origins = ["*"]
-

 ##################################
 #
@@ -218,25 +220,6 @@ async def lifespan(app: FastAPI):
 ##################################


-async def get_body_and_model_and_user(request):
-    # Read the original request body
-    body = await request.body()
-    body_str = body.decode("utf-8")
-    body = json.loads(body_str) if body_str else {}
-
-    model_id = body["model"]
-    if model_id not in app.state.MODELS:
-        raise Exception("Model not found")
-    model = app.state.MODELS[model_id]
-
-    user = get_current_user(
-        request,
-        get_http_authorization_cred(request.headers.get("Authorization")),
-    )
-
-    return body, model, user
-
-
 def get_task_model_id(default_model_id):
     # Set the task model
     task_model_id = default_model_id
@@ -261,6 +244,7 @@ def get_filter_function_ids(model):
     def get_priority(function_id):
         function = Functions.get_function_by_id(function_id)
         if function is not None and hasattr(function, "valves"):
+            # TODO: Fix FunctionModel
             return (function.valves if function.valves else {}).get("priority", 0)
         return 0

@@ -282,164 +266,7 @@ def get_priority(function_id):
     return filter_ids


-async def get_function_call_response(
-    messages,
-    files,
-    tool_id,
-    template,
-    task_model_id,
-    user,
-    __event_emitter__=None,
-    __event_call__=None,
-):
-    tool = Tools.get_tool_by_id(tool_id)
-    tools_specs = json.dumps(tool.specs, indent=2)
-    content = tools_function_calling_generation_template(template, tools_specs)
-
-    user_message = get_last_user_message(messages)
-    prompt = (
-        "History:\n"
-        + "\n".join(
-            [
-                f"{message['role'].upper()}: \"\"\"{message['content']}\"\"\""
-                for message in messages[::-1][:4]
-            ]
-        )
-        + f"\nQuery: {user_message}"
-    )
-
-    print(prompt)
-
-    payload = {
-        "model": task_model_id,
-        "messages": [
-            {"role": "system", "content": content},
-            {"role": "user", "content": f"Query: {prompt}"},
-        ],
-        "stream": False,
-        "task": str(TASKS.FUNCTION_CALLING),
-    }
-
-    try:
-        payload = filter_pipeline(payload, user)
-    except Exception as e:
-        raise e
-
-    model = app.state.MODELS[task_model_id]
-
-    response = None
-    try:
-        response = await generate_chat_completions(form_data=payload, user=user)
-        content = None
-
-        if hasattr(response, "body_iterator"):
-            async for chunk in response.body_iterator:
-                data = json.loads(chunk.decode("utf-8"))
-                content = data["choices"][0]["message"]["content"]
-
-            # Cleanup any remaining background tasks if necessary
-            if response.background is not None:
-                await response.background()
-        else:
-            content = response["choices"][0]["message"]["content"]
-
-        if content is None:
-            return None, None, False
-
-        # Parse the function response
-        print(f"content: {content}")
-        result = json.loads(content)
-        print(result)
-
-        citation = None
-
-        if "name" not in result:
-            return None, None, False
-
-        # Call the function
-        if tool_id in webui_app.state.TOOLS:
-            toolkit_module = webui_app.state.TOOLS[tool_id]
-        else:
-            toolkit_module, _ = load_toolkit_module_by_id(tool_id)
-            webui_app.state.TOOLS[tool_id] = toolkit_module
-
-        file_handler = False
-        # check if toolkit_module has file_handler self variable
-        if hasattr(toolkit_module, "file_handler"):
-            file_handler = True
-            print("file_handler: ", file_handler)
-
-        if hasattr(toolkit_module, "valves") and hasattr(toolkit_module, "Valves"):
-            valves = Tools.get_tool_valves_by_id(tool_id)
-            toolkit_module.valves = toolkit_module.Valves(**(valves if valves else {}))
-
-        function = getattr(toolkit_module, result["name"])
-        function_result = None
-        try:
-            # Get the signature of the function
-            sig = inspect.signature(function)
-            params = result["parameters"]
-
-            # Extra parameters to be passed to the function
-            extra_params = {
-                "__model__": model,
-                "__id__": tool_id,
-                "__messages__": messages,
-                "__files__": files,
-                "__event_emitter__": __event_emitter__,
-                "__event_call__": __event_call__,
-            }
-
-            # Add extra params in contained in function signature
-            for key, value in extra_params.items():
-                if key in sig.parameters:
-                    params[key] = value
-
-            if "__user__" in sig.parameters:
-                # Call the function with the '__user__' parameter included
-                __user__ = {
-                    "id": user.id,
-                    "email": user.email,
-                    "name": user.name,
-                    "role": user.role,
-                }
-
-                try:
-                    if hasattr(toolkit_module, "UserValves"):
-                        __user__["valves"] = toolkit_module.UserValves(
-                            **Tools.get_user_valves_by_id_and_user_id(tool_id, user.id)
-                        )
-                except Exception as e:
-                    print(e)
-
-                params = {**params, "__user__": __user__}
-
-            if inspect.iscoroutinefunction(function):
-                function_result = await function(**params)
-            else:
-                function_result = function(**params)
-
-            if hasattr(toolkit_module, "citation") and toolkit_module.citation:
-                citation = {
-                    "source": {"name": f"TOOL:{tool.name}/{result['name']}"},
-                    "document": [function_result],
-                    "metadata": [{"source": result["name"]}],
-                }
-        except Exception as e:
-            print(e)
-
-        # Add the function result to the system prompt
-        if function_result is not None:
-            return function_result, citation, file_handler
-    except Exception as e:
-        print(f"Error: {e}")
-
-    return None, None, False
-
-
-async def chat_completion_functions_handler(
-    body, model, user, __event_emitter__, __event_call__
-):
+async def chat_completion_filter_functions_handler(body, model, extra_params):
     skip_files = None

     filter_ids = get_filter_function_ids(model)
@@ -475,37 +302,20 @@ async def chat_completion_functions_handler(
             params = {"body": body}

             # Extra parameters to be passed to the function
-            extra_params = {
-                "__model__": model,
-                "__id__": filter_id,
-                "__event_emitter__": __event_emitter__,
-                "__event_call__": __event_call__,
-            }
-
-            # Add extra params in contained in function signature
-            for key, value in extra_params.items():
-                if key in sig.parameters:
-                    params[key] = value
-
-            if "__user__" in sig.parameters:
-                __user__ = {
-                    "id": user.id,
-                    "email": user.email,
-                    "name": user.name,
-                    "role": user.role,
-                }
-
+            custom_params = {**extra_params, "__model__": model, "__id__": filter_id}
+            if hasattr(function_module, "UserValves") and "__user__" in sig.parameters:
                 try:
-                    if hasattr(function_module, "UserValves"):
-                        __user__["valves"] = function_module.UserValves(
-                            **Functions.get_user_valves_by_id_and_user_id(
-                                filter_id, user.id
-                            )
-                        )
+                    uid = custom_params["__user__"]["id"]
+                    custom_params["__user__"]["valves"] = function_module.UserValves(
+                        **Functions.get_user_valves_by_id_and_user_id(filter_id, uid)
+                    )
                 except Exception as e:
                     print(e)

-                params = {**params, "__user__": __user__}
+            # Add extra params contained in the function signature
+            for key, value in custom_params.items():
+                if key in sig.parameters:
+                    params[key] = value

             if inspect.iscoroutinefunction(inlet):
                 body = await inlet(**params)
@@ -516,74 +326,146 @@ async def chat_completion_functions_handler(
             print(f"Error: {e}")
             raise e

-    if skip_files:
-        if "files" in body:
-            del body["files"]
+    if skip_files and "files" in body.get("metadata", {}):
+        del body["metadata"]["files"]

     return body, {}


-async def chat_completion_tools_handler(body, user, __event_emitter__, __event_call__):
-    skip_files = None
+def
get_tools_function_calling_payload(messages, task_model_id, content): + user_message = get_last_user_message(messages) + history = "\n".join( + f"{message['role'].upper()}: \"\"\"{message['content']}\"\"\"" + for message in messages[::-1][:4] + ) + + prompt = f"History:\n{history}\nQuery: {user_message}" + + return { + "model": task_model_id, + "messages": [ + {"role": "system", "content": content}, + {"role": "user", "content": f"Query: {prompt}"}, + ], + "stream": False, + "metadata": {"task": str(TASKS.FUNCTION_CALLING)}, + } + +async def get_content_from_response(response) -> Optional[str]: + content = None + if hasattr(response, "body_iterator"): + async for chunk in response.body_iterator: + data = json.loads(chunk.decode("utf-8")) + content = data["choices"][0]["message"]["content"] + + # Cleanup any remaining background tasks if necessary + if response.background is not None: + await response.background() + else: + content = response["choices"][0]["message"]["content"] + return content + + +async def chat_completion_tools_handler( + body: dict, user: UserModel, extra_params: dict +) -> tuple[dict, dict]: + # If tool_ids field is present, call the functions + metadata = body.get("metadata", {}) + tool_ids = metadata.get("tool_ids", None) + if not tool_ids: + return body, {} + + skip_files = False contexts = [] - citations = None + citations = [] task_model_id = get_task_model_id(body["model"]) - # If tool_ids field is present, call the functions - if "tool_ids" in body: - print(body["tool_ids"]) - for tool_id in body["tool_ids"]: - print(tool_id) - try: - response, citation, file_handler = await get_function_call_response( - messages=body["messages"], - files=body.get("files", []), - tool_id=tool_id, - template=app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE, - task_model_id=task_model_id, - user=user, - __event_emitter__=__event_emitter__, - __event_call__=__event_call__, - ) + log.debug(f"{tool_ids=}") - print(file_handler) - if isinstance(response, str): - contexts.append(response) + custom_params = { + **extra_params, + "__model__": app.state.MODELS[task_model_id], + "__messages__": body["messages"], + "__files__": metadata.get("files", []), + } + tools = get_tools(webui_app, tool_ids, user, custom_params) + log.info(f"{tools=}") - if citation: - if citations is None: - citations = [citation] - else: - citations.append(citation) + specs = [tool["spec"] for tool in tools.values()] + tools_specs = json.dumps(specs) - if file_handler: - skip_files = True + tools_function_calling_prompt = tools_function_calling_generation_template( + app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE, tools_specs + ) + log.info(f"{tools_function_calling_prompt=}") + payload = get_tools_function_calling_payload( + body["messages"], task_model_id, tools_function_calling_prompt + ) - except Exception as e: - print(f"Error: {e}") - del body["tool_ids"] - print(f"tool_contexts: {contexts}") + try: + payload = filter_pipeline(payload, user) + except Exception as e: + raise e - if skip_files: - if "files" in body: - del body["files"] + try: + response = await generate_chat_completions(form_data=payload, user=user) + log.debug(f"{response=}") + content = await get_content_from_response(response) + log.debug(f"{content=}") - return body, { - **({"contexts": contexts} if contexts is not None else {}), - **({"citations": citations} if citations is not None else {}), - } + if not content: + return body, {} + result = json.loads(content) -async def chat_completion_files_handler(body): - contexts = 
[] - citations = None + tool_function_name = result.get("name", None) + if tool_function_name not in tools: + return body, {} - if "files" in body: - files = body["files"] - del body["files"] + tool_function_params = result.get("parameters", {}) + try: + tool_output = await tools[tool_function_name]["callable"]( + **tool_function_params + ) + except Exception as e: + tool_output = str(e) + + if tools[tool_function_name]["citation"]: + citations.append( + { + "source": { + "name": f"TOOL:{tools[tool_function_name]['toolkit_id']}/{tool_function_name}" + }, + "document": [tool_output], + "metadata": [{"source": tool_function_name}], + } + ) + if tools[tool_function_name]["file_handler"]: + skip_files = True + + if isinstance(tool_output, str): + contexts.append(tool_output) + + except Exception as e: + log.exception(f"Error: {e}") + content = None + + log.debug(f"tool_contexts: {contexts}") + + if skip_files and "files" in body.get("metadata", {}): + del body["metadata"]["files"] + + return body, {"contexts": contexts, "citations": citations} + + +async def chat_completion_files_handler(body) -> tuple[dict, dict[str, list]]: + contexts = [] + citations = [] + + if files := body.get("metadata", {}).get("files", None): contexts, citations = get_rag_context( files=files, messages=body["messages"], @@ -596,152 +478,168 @@ async def chat_completion_files_handler(body): log.debug(f"rag_contexts: {contexts}, citations: {citations}") - return body, { - **({"contexts": contexts} if contexts is not None else {}), - **({"citations": citations} if citations is not None else {}), - } + return body, {"contexts": contexts, "citations": citations} -class ChatCompletionMiddleware(BaseHTTPMiddleware): - async def dispatch(self, request: Request, call_next): - if request.method == "POST" and any( - endpoint in request.url.path - for endpoint in ["/ollama/api/chat", "/chat/completions"] - ): - log.debug(f"request.url.path: {request.url.path}") +def is_chat_completion_request(request): + return request.method == "POST" and any( + endpoint in request.url.path + for endpoint in ["/ollama/api/chat", "/chat/completions"] + ) - try: - body, model, user = await get_body_and_model_and_user(request) - except Exception as e: - return JSONResponse( - status_code=status.HTTP_400_BAD_REQUEST, - content={"detail": str(e)}, - ) - metadata = { - "chat_id": body.pop("chat_id", None), - "message_id": body.pop("id", None), - "session_id": body.pop("session_id", None), - "valves": body.pop("valves", None), - } +async def get_body_and_model_and_user(request): + # Read the original request body + body = await request.body() + body_str = body.decode("utf-8") + body = json.loads(body_str) if body_str else {} - __event_emitter__ = get_event_emitter(metadata) - __event_call__ = get_event_call(metadata) + model_id = body["model"] + if model_id not in app.state.MODELS: + raise Exception("Model not found") + model = app.state.MODELS[model_id] - # Initialize data_items to store additional data to be sent to the client - data_items = [] + user = get_current_user( + request, + get_http_authorization_cred(request.headers.get("Authorization")), + ) - # Initialize context, and citations - contexts = [] - citations = [] + return body, model, user - try: - body, flags = await chat_completion_functions_handler( - body, model, user, __event_emitter__, __event_call__ - ) - except Exception as e: - return JSONResponse( - status_code=status.HTTP_400_BAD_REQUEST, - content={"detail": str(e)}, - ) - try: - body, flags = await 
chat_completion_tools_handler(
-                    body, user, __event_emitter__, __event_call__
-                )
+class ChatCompletionMiddleware(BaseHTTPMiddleware):
+    async def dispatch(self, request: Request, call_next):
+        if not is_chat_completion_request(request):
+            return await call_next(request)
+        log.debug(f"request.url.path: {request.url.path}")

-                contexts.extend(flags.get("contexts", []))
-                citations.extend(flags.get("citations", []))
-            except Exception as e:
-                print(e)
-                pass
+        try:
+            body, model, user = await get_body_and_model_and_user(request)
+        except Exception as e:
+            return JSONResponse(
+                status_code=status.HTTP_400_BAD_REQUEST,
+                content={"detail": str(e)},
+            )

-            try:
-                body, flags = await chat_completion_files_handler(body)
+        metadata = {
+            "chat_id": body.pop("chat_id", None),
+            "message_id": body.pop("id", None),
+            "session_id": body.pop("session_id", None),
+            "tool_ids": body.get("tool_ids", None),
+            "files": body.get("files", None),
+        }
+        body["metadata"] = metadata

-                contexts.extend(flags.get("contexts", []))
-                citations.extend(flags.get("citations", []))
-            except Exception as e:
-                print(e)
-                pass
+        __user__ = {
+            "id": user.id,
+            "email": user.email,
+            "name": user.name,
+            "role": user.role,
+        }

-            # If context is not empty, insert it into the messages
-            if len(contexts) > 0:
-                context_string = "/n".join(contexts).strip()
-                prompt = get_last_user_message(body["messages"])
-
-                # Workaround for Ollama 2.0+ system prompt issue
-                # TODO: replace with add_or_update_system_message
-                if model["owned_by"] == "ollama":
-                    body["messages"] = prepend_to_first_user_message_content(
-                        rag_template(
-                            rag_app.state.config.RAG_TEMPLATE, context_string, prompt
-                        ),
-                        body["messages"],
-                    )
-                else:
-                    body["messages"] = add_or_update_system_message(
-                        rag_template(
-                            rag_app.state.config.RAG_TEMPLATE, context_string, prompt
-                        ),
-                        body["messages"],
-                    )
+        extra_params = {
+            "__user__": __user__,
+            "__event_emitter__": get_event_emitter(metadata),
+            "__event_call__": get_event_call(metadata),
+        }

-            # If there are citations, add them to the data_items
-            if len(citations) > 0:
-                data_items.append({"citations": citations})
-
-            body["metadata"] = metadata
-            modified_body_bytes = json.dumps(body).encode("utf-8")
-            # Replace the request body with the modified one
-            request._body = modified_body_bytes
-            # Set custom header to ensure content-length matches new body length
-            request.headers.__dict__["_list"] = [
-                (b"content-length", str(len(modified_body_bytes)).encode("utf-8")),
-                *[
-                    (k, v)
-                    for k, v in request.headers.raw
-                    if k.lower() != b"content-length"
-                ],
-            ]
-
-            response = await call_next(request)
-            if isinstance(response, StreamingResponse):
-                # If it's a streaming response, inject it as SSE event or NDJSON line
-                content_type = response.headers.get("Content-Type")
-                if "text/event-stream" in content_type:
-                    return StreamingResponse(
-                        self.openai_stream_wrapper(response.body_iterator, data_items),
-                    )
-                if "application/x-ndjson" in content_type:
-                    return StreamingResponse(
-                        self.ollama_stream_wrapper(response.body_iterator, data_items),
-                    )
+        # Initialize data_items to store additional data to be sent to the client
+        # Initialize contexts and citations
+        data_items = []
+        contexts = []
+        citations = []

-            return response
+        try:
+            body, flags = await chat_completion_filter_functions_handler(
+                body, model, extra_params
+            )
+        except Exception as e:
+            return JSONResponse(
+                status_code=status.HTTP_400_BAD_REQUEST,
+                content={"detail": str(e)},
+            )
+
+        metadata = {
+            **metadata,
+            "tool_ids": body.pop("tool_ids", None),
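+            # tool_ids/files are only read (not popped) before the filter pass so
+            # filter functions can still rewrite them; here they are popped from
+            # the body into metadata, the single place downstream handlers read.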
+            "files": body.pop("files", None),
+        }
+        body["metadata"] = metadata
+
+        try:
+            body, flags = await chat_completion_tools_handler(body, user, extra_params)
+            contexts.extend(flags.get("contexts", []))
+            citations.extend(flags.get("citations", []))
+        except Exception as e:
+            log.exception(e)
+
+        try:
+            body, flags = await chat_completion_files_handler(body)
+            contexts.extend(flags.get("contexts", []))
+            citations.extend(flags.get("citations", []))
+        except Exception as e:
+            log.exception(e)
+
+        # If context is not empty, insert it into the messages
+        if len(contexts) > 0:
+            context_string = "\n".join(contexts).strip()
+            prompt = get_last_user_message(body["messages"])
+            if prompt is None:
+                raise Exception("No user message found")
+            # Workaround for Ollama 2.0+ system prompt issue
+            # TODO: replace with add_or_update_system_message
+            if model["owned_by"] == "ollama":
+                body["messages"] = prepend_to_first_user_message_content(
+                    rag_template(
+                        rag_app.state.config.RAG_TEMPLATE, context_string, prompt
+                    ),
+                    body["messages"],
+                )
             else:
-                return response
+                body["messages"] = add_or_update_system_message(
+                    rag_template(
+                        rag_app.state.config.RAG_TEMPLATE, context_string, prompt
+                    ),
+                    body["messages"],
+                )
+
+        # If there are citations, add them to the data_items
+        if len(citations) > 0:
+            data_items.append({"citations": citations})
+
+        modified_body_bytes = json.dumps(body).encode("utf-8")
+        # Replace the request body with the modified one
+        request._body = modified_body_bytes
+        # Set custom header to ensure content-length matches new body length
+        request.headers.__dict__["_list"] = [
+            (b"content-length", str(len(modified_body_bytes)).encode("utf-8")),
+            *[(k, v) for k, v in request.headers.raw if k.lower() != b"content-length"],
+        ]

-        # If it's not a chat completion request, just pass it through
         response = await call_next(request)
-        return response
+        if not isinstance(response, StreamingResponse):
+            return response

-    async def _receive(self, body: bytes):
-        return {"type": "http.request", "body": body, "more_body": False}
+        content_type = response.headers["Content-Type"]
+        is_openai = "text/event-stream" in content_type
+        is_ollama = "application/x-ndjson" in content_type
+        if not is_openai and not is_ollama:
+            return response

-    async def openai_stream_wrapper(self, original_generator, data_items):
-        for item in data_items:
-            yield f"data: {json.dumps(item)}\n\n"
+        def wrap_item(item):
+            return f"data: {item}\n\n" if is_openai else f"{item}\n"

-        async for data in original_generator:
-            yield data
+        async def stream_wrapper(original_generator, data_items):
+            for item in data_items:
+                yield wrap_item(json.dumps(item))

-    async def ollama_stream_wrapper(self, original_generator, data_items):
-        for item in data_items:
-            yield f"{json.dumps(item)}\n"
+            async for data in original_generator:
+                yield data

-        async for data in original_generator:
-            yield data
+        return StreamingResponse(stream_wrapper(response.body_iterator, data_items))
+
+    async def _receive(self, body: bytes):
+        return {"type": "http.request", "body": body, "more_body": False}


 app.add_middleware(ChatCompletionMiddleware)
@@ -790,19 +688,21 @@ def filter_pipeline(payload, user):
                 url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx]
                 key = openai_app.state.config.OPENAI_API_KEYS[urlIdx]

-                if key != "":
-                    headers = {"Authorization": f"Bearer {key}"}
-                    r = requests.post(
-                        f"{url}/{filter['id']}/filter/inlet",
-                        headers=headers,
-                        json={
-                            "user": user,
-                            "body": payload,
-                        },
-                    )
+                if key == "":
+                    continue
+
+                headers =
{"Authorization": f"Bearer {key}"} + r = requests.post( + f"{url}/{filter['id']}/filter/inlet", + headers=headers, + json={ + "user": user, + "body": payload, + }, + ) - r.raise_for_status() - payload = r.json() + r.raise_for_status() + payload = r.json() except Exception as e: # Handle connection error here print(f"Connection error: {e}") @@ -817,44 +717,39 @@ def filter_pipeline(payload, user): class PipelineMiddleware(BaseHTTPMiddleware): async def dispatch(self, request: Request, call_next): - if request.method == "POST" and ( - "/ollama/api/chat" in request.url.path - or "/chat/completions" in request.url.path - ): - log.debug(f"request.url.path: {request.url.path}") - - # Read the original request body - body = await request.body() - # Decode body to string - body_str = body.decode("utf-8") - # Parse string to JSON - data = json.loads(body_str) if body_str else {} - - user = get_current_user( - request, - get_http_authorization_cred(request.headers.get("Authorization")), - ) + if not is_chat_completion_request(request): + return await call_next(request) - try: - data = filter_pipeline(data, user) - except Exception as e: - return JSONResponse( - status_code=e.args[0], - content={"detail": e.args[1]}, - ) + log.debug(f"request.url.path: {request.url.path}") + + # Read the original request body + body = await request.body() + # Decode body to string + body_str = body.decode("utf-8") + # Parse string to JSON + data = json.loads(body_str) if body_str else {} + + user = get_current_user( + request, + get_http_authorization_cred(request.headers["Authorization"]), + ) - modified_body_bytes = json.dumps(data).encode("utf-8") - # Replace the request body with the modified one - request._body = modified_body_bytes - # Set custom header to ensure content-length matches new body length - request.headers.__dict__["_list"] = [ - (b"content-length", str(len(modified_body_bytes)).encode("utf-8")), - *[ - (k, v) - for k, v in request.headers.raw - if k.lower() != b"content-length" - ], - ] + try: + data = filter_pipeline(data, user) + except Exception as e: + return JSONResponse( + status_code=e.args[0], + content={"detail": e.args[1]}, + ) + + modified_body_bytes = json.dumps(data).encode("utf-8") + # Replace the request body with the modified one + request._body = modified_body_bytes + # Set custom header to ensure content-length matches new body length + request.headers.__dict__["_list"] = [ + (b"content-length", str(len(modified_body_bytes)).encode("utf-8")), + *[(k, v) for k, v in request.headers.raw if k.lower() != b"content-length"], + ] response = await call_next(request) return response @@ -868,7 +763,7 @@ async def _receive(self, body: bytes): app.add_middleware( CORSMiddleware, - allow_origins=origins, + allow_origins=CORS_ALLOW_ORIGIN, allow_credentials=True, allow_methods=["*"], allow_headers=["*"], @@ -1019,6 +914,8 @@ async def get_all_models(): model["actions"] = [] for action_id in action_ids: action = Functions.get_function_by_id(action_id) + if action is None: + raise Exception(f"Action not found: {action_id}") if action_id in webui_app.state.FUNCTIONS: function_module = webui_app.state.FUNCTIONS[action_id] @@ -1026,6 +923,10 @@ async def get_all_models(): function_module, _, _ = load_function_module_by_id(action_id) webui_app.state.FUNCTIONS[action_id] = function_module + __webui__ = False + if hasattr(function_module, "__webui__"): + __webui__ = function_module.__webui__ + if hasattr(function_module, "actions"): actions = function_module.actions model["actions"].extend( @@ 
-1039,6 +940,7 @@ async def get_all_models():
                         "icon_url": _action.get(
                             "icon_url", action.meta.manifest.get("icon_url", None)
                         ),
+                        **({"__webui__": __webui__} if __webui__ else {}),
                     }
                     for _action in actions
                 ]
@@ -1050,6 +952,7 @@ async def get_all_models():
                     "name": action.name,
                     "description": action.meta.description,
                     "icon_url": action.meta.manifest.get("icon_url", None),
+                    **({"__webui__": __webui__} if __webui__ else {}),
                 }
             )
@@ -1092,23 +995,9 @@ async def generate_chat_completions(form_data: dict, user=Depends(get_verified_u
             detail="Model not found",
         )
     model = app.state.MODELS[model_id]
-
-    # `task` field is used to determine the type of the request, e.g. `title_generation`, `query_generation`, etc.
-    task = None
-    if "task" in form_data:
-        task = form_data["task"]
-        del form_data["task"]
-
-    if task:
-        if "metadata" in form_data:
-            form_data["metadata"]["task"] = task
-        else:
-            form_data["metadata"] = {"task": task}
-
     if model.get("pipe"):
         return await generate_function_chat_completion(form_data, user=user)
     if model["owned_by"] == "ollama":
-        print("generate_ollama_chat_completion")
         return await generate_ollama_chat_completion(form_data, user=user)
     else:
         return await generate_openai_chat_completion(form_data, user=user)
@@ -1192,6 +1081,7 @@ async def chat_completed(form_data: dict, user=Depends(get_verified_user)):
     def get_priority(function_id):
         function = Functions.get_function_by_id(function_id)
         if function is not None and hasattr(function, "valves"):
+            # TODO: Fix FunctionModel to include valves
             return (function.valves if function.valves else {}).get("priority", 0)
         return 0
@@ -1481,7 +1371,7 @@ async def generate_title(form_data: dict, user=Depends(get_verified_user)):
         "stream": False,
         "max_tokens": 50,
         "chat_id": form_data.get("chat_id", None),
-        "task": str(TASKS.TITLE_GENERATION),
+        "metadata": {"task": str(TASKS.TITLE_GENERATION)},
     }

     log.debug(payload)
@@ -1534,7 +1424,7 @@ async def generate_search_query(form_data: dict, user=Depends(get_verified_user)
         "messages": [{"role": "user", "content": content}],
         "stream": False,
         "max_tokens": 30,
-        "task": str(TASKS.QUERY_GENERATION),
+        "metadata": {"task": str(TASKS.QUERY_GENERATION)},
     }

     print(payload)
@@ -1591,7 +1481,7 @@ async def generate_emoji(form_data: dict, user=Depends(get_verified_user)):
         "stream": False,
         "max_tokens": 4,
         "chat_id": form_data.get("chat_id", None),
-        "task": str(TASKS.EMOJI_GENERATION),
+        "metadata": {"task": str(TASKS.EMOJI_GENERATION)},
    }

     log.debug(payload)
@@ -1610,9 +1500,9 @@ async def generate_emoji(form_data: dict, user=Depends(get_verified_user)):
     return await generate_chat_completions(form_data=payload, user=user)


-@app.post("/api/task/tools/completions")
-async def get_tools_function_calling(form_data: dict, user=Depends(get_verified_user)):
-    print("get_tools_function_calling")
+@app.post("/api/task/moa/completions")
+async def generate_moa_response(form_data: dict, user=Depends(get_verified_user)):
+    print("generate_moa_response")

     model_id = form_data["model"]
     if model_id not in app.state.MODELS:
@@ -1624,26 +1514,43 @@ async def get_tools_function_calling(form_data: dict, user=Depends(get_verified_
     # Check if the user has a custom task model
     # If the user has a custom task model, use that model
     model_id = get_task_model_id(model_id)
-    print(model_id)
-    template = app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE
+
+    template = """You have been provided with a set of responses from various models to the latest user query: "{{prompt}}"
+
+Your task is to synthesize these responses into a
single, high-quality response. It is crucial to critically evaluate the information provided in these responses, recognizing that some of it may be biased or incorrect. Your response should not simply replicate the given answers but should offer a refined, accurate, and comprehensive reply to the instruction. Ensure your response is well-structured, coherent, and adheres to the highest standards of accuracy and reliability. + +Responses from models: {{responses}}""" + + content = moa_response_generation_template( + template, + form_data["prompt"], + form_data["responses"], + ) + + payload = { + "model": model_id, + "messages": [{"role": "user", "content": content}], + "stream": form_data.get("stream", False), + "chat_id": form_data.get("chat_id", None), + "metadata": {"task": str(TASKS.MOA_RESPONSE_GENERATION)}, + } + + log.debug(payload) try: - context, _, _ = await get_function_call_response( - form_data["messages"], - form_data.get("files", []), - form_data["tool_id"], - template, - model_id, - user, - ) - return context + payload = filter_pipeline(payload, user) except Exception as e: return JSONResponse( status_code=e.args[0], content={"detail": e.args[1]}, ) + if "chat_id" in payload: + del payload["chat_id"] + + return await generate_chat_completions(form_data=payload, user=user) + ################################## # @@ -1683,7 +1590,7 @@ async def upload_pipeline( ): print("upload_pipeline", urlIdx, file.filename) # Check if the uploaded file is a python file - if not file.filename.endswith(".py"): + if not (file.filename and file.filename.endswith(".py")): raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail="Only Python (.py) files are allowed.", @@ -1980,40 +1887,61 @@ async def update_pipeline_valves( @app.get("/api/config") -async def get_app_config(): +async def get_app_config(request: Request): + user = None + if "token" in request.cookies: + token = request.cookies.get("token") + data = decode_token(token) + if data is not None and "id" in data: + user = Users.get_user_by_id(data["id"]) + return { "status": True, "name": WEBUI_NAME, "version": VERSION, "default_locale": str(DEFAULT_LOCALE), - "default_models": webui_app.state.config.DEFAULT_MODELS, - "default_prompt_suggestions": webui_app.state.config.DEFAULT_PROMPT_SUGGESTIONS, + "oauth": { + "providers": { + name: config.get("name", name) + for name, config in OAUTH_PROVIDERS.items() + } + }, "features": { "auth": WEBUI_AUTH, "auth_trusted_header": bool(webui_app.state.AUTH_TRUSTED_EMAIL_HEADER), "enable_signup": webui_app.state.config.ENABLE_SIGNUP, "enable_login_form": webui_app.state.config.ENABLE_LOGIN_FORM, - "enable_web_search": rag_app.state.config.ENABLE_RAG_WEB_SEARCH, - "enable_image_generation": images_app.state.config.ENABLED, - "enable_community_sharing": webui_app.state.config.ENABLE_COMMUNITY_SHARING, - "enable_admin_export": ENABLE_ADMIN_EXPORT, - "enable_admin_chat_access": ENABLE_ADMIN_CHAT_ACCESS, - }, - "audio": { - "tts": { - "engine": audio_app.state.config.TTS_ENGINE, - "voice": audio_app.state.config.TTS_VOICE, - }, - "stt": { - "engine": audio_app.state.config.STT_ENGINE, - }, + **( + { + "enable_web_search": rag_app.state.config.ENABLE_RAG_WEB_SEARCH, + "enable_image_generation": images_app.state.config.ENABLED, + "enable_community_sharing": webui_app.state.config.ENABLE_COMMUNITY_SHARING, + "enable_message_rating": webui_app.state.config.ENABLE_MESSAGE_RATING, + "enable_admin_export": ENABLE_ADMIN_EXPORT, + "enable_admin_chat_access": ENABLE_ADMIN_CHAT_ACCESS, + } + if user is 
not None
+                else {}
+            ),
         },
-    "oauth": {
-        "providers": {
-            name: config.get("name", name)
-            for name, config in OAUTH_PROVIDERS.items()
+        **(
+            {
+                "default_models": webui_app.state.config.DEFAULT_MODELS,
+                "default_prompt_suggestions": webui_app.state.config.DEFAULT_PROMPT_SUGGESTIONS,
+                "audio": {
+                    "tts": {
+                        "engine": audio_app.state.config.TTS_ENGINE,
+                        "voice": audio_app.state.config.TTS_VOICE,
+                    },
+                    "stt": {
+                        "engine": audio_app.state.config.STT_ENGINE,
+                    },
+                },
+                "permissions": {**webui_app.state.config.USER_PERMISSIONS},
             }
-    },
+            if user is not None
+            else {}
+        ),
     }
@@ -2132,7 +2060,10 @@ async def oauth_login(provider: str, request: Request):
     redirect_uri = OAUTH_PROVIDERS[provider].get("redirect_uri") or request.url_for(
         "oauth_callback", provider=provider
     )
-    return await oauth.create_client(provider).authorize_redirect(request, redirect_uri)
+    client = oauth.create_client(provider)
+    if client is None:
+        raise HTTPException(404)
+    return await client.authorize_redirect(request, redirect_uri)


 # OAuth login logic is as follows:
@@ -2264,7 +2195,20 @@ async def get_manifest_json():
         "display": "standalone",
         "background_color": "#343541",
         "orientation": "portrait-primary",
-        "icons": [{"src": "/static/logo.png", "type": "image/png", "sizes": "500x500"}],
+        "icons": [
+            {
+                "src": "/static/logo.png",
+                "type": "image/png",
+                "sizes": "500x500",
+                "purpose": "any",
+            },
+            {
+                "src": "/static/logo.png",
+                "type": "image/png",
+                "sizes": "500x500",
+                "purpose": "maskable",
+            },
+        ],
     }
diff --git a/backend/requirements.txt b/backend/requirements.txt
index 6ef299b5fa..04b3261916 100644
--- a/backend/requirements.txt
+++ b/backend/requirements.txt
@@ -1,5 +1,5 @@
 fastapi==0.111.0
-uvicorn[standard]==0.22.0
+uvicorn[standard]==0.30.6
 pydantic==2.8.2
 python-multipart==0.0.9
@@ -13,17 +13,17 @@ passlib[bcrypt]==1.7.4
 requests==2.32.3
 aiohttp==3.10.2

-sqlalchemy==2.0.31
+sqlalchemy==2.0.32
 alembic==1.13.2
 peewee==3.17.6
 peewee-migrate==1.12.2
 psycopg2-binary==2.9.9
 PyMySQL==1.1.1
-bcrypt==4.1.3
+bcrypt==4.2.0

 pymongo
 redis
-boto3==1.34.153
+boto3==1.35.0

 argon2-cffi==23.1.0
 APScheduler==3.10.4
@@ -44,7 +44,7 @@ sentence-transformers==3.0.1
 pypdf==4.3.1
 docx2txt==0.8
 python-pptx==1.0.0
-unstructured==0.15.0
+unstructured==0.15.5
 Markdown==3.6
 pypandoc==1.13
 pandas==2.2.2
@@ -60,7 +60,7 @@ rapidocr-onnxruntime==1.3.24
 fpdf2==2.7.9
 rank-bm25==0.2.2

-faster-whisper==1.0.2
+faster-whisper==1.0.3

 PyJWT[crypto]==2.9.0
 authlib==1.3.1
diff --git a/backend/static/fonts/NotoSansSC-Regular.ttf b/backend/static/fonts/NotoSansSC-Regular.ttf
new file mode 100644
index 0000000000..7056f5e97a
Binary files /dev/null and b/backend/static/fonts/NotoSansSC-Regular.ttf differ
diff --git a/backend/utils/schemas.py b/backend/utils/schemas.py
new file mode 100644
index 0000000000..09b24897b9
--- /dev/null
+++ b/backend/utils/schemas.py
@@ -0,0 +1,104 @@
+from pydantic import BaseModel, Field, create_model
+from typing import Any, Optional, Type
+
+
+def json_schema_to_model(tool_dict: dict[str, Any]) -> Type[BaseModel]:
+    """
+    Converts a JSON schema to a Pydantic BaseModel class.
+
+    Args:
+        tool_dict: The tool dictionary; its "parameters" entry holds the JSON schema to convert.
+
+    Returns:
+        A Pydantic BaseModel class.
+    """
+
+    # Extract the model name and parameter schema from the tool dict.
+    model_name = tool_dict["name"]
+    schema = tool_dict["parameters"]
+
+    # Extract the field definitions from the schema properties.
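+    # Illustrative example (not from the spec itself): a schema such as
+    #   {"properties": {"city": {"type": "string"}}, "required": ["city"]}
+    # yields a model with a single required `city: str` field.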
+    field_definitions = {
+        name: json_schema_to_pydantic_field(name, prop, schema.get("required", []))
+        for name, prop in schema.get("properties", {}).items()
+    }
+
+    # Create the BaseModel class using create_model().
+    return create_model(model_name, **field_definitions)
+
+
+def json_schema_to_pydantic_field(
+    name: str, json_schema: dict[str, Any], required: list[str]
+) -> Any:
+    """
+    Converts a JSON schema property to a Pydantic field definition.
+
+    Args:
+        name: The field name.
+        json_schema: The JSON schema property.
+        required: The names of the schema's required fields.
+
+    Returns:
+        A Pydantic field definition.
+    """
+
+    # Get the field type.
+    type_ = json_schema_to_pydantic_type(json_schema)
+
+    # Get the field description.
+    description = json_schema.get("description")
+
+    # Get the field examples.
+    examples = json_schema.get("examples")
+
+    # Create a Field object with the type, description, and examples.
+    # Fields listed in 'required' get an ellipsis default, which marks them
+    # as required; all other fields default to None.
+    return (
+        type_,
+        Field(
+            description=description,
+            examples=examples,
+            default=... if name in required else None,
+        ),
+    )
+
+
+def json_schema_to_pydantic_type(json_schema: dict[str, Any]) -> Any:
+    """
+    Converts a JSON schema type to a Pydantic type.
+
+    Args:
+        json_schema: The JSON schema to convert.
+
+    Returns:
+        A Pydantic type.
+    """
+
+    type_ = json_schema.get("type")
+
+    if type_ == "string" or type_ == "str":
+        return str
+    elif type_ == "integer" or type_ == "int":
+        return int
+    elif type_ == "number" or type_ == "float":
+        return float
+    elif type_ == "boolean" or type_ == "bool":
+        return bool
+    elif type_ == "array":
+        items_schema = json_schema.get("items")
+        if items_schema:
+            item_type = json_schema_to_pydantic_type(items_schema)
+            return list[item_type]
+        else:
+            return list
+    elif type_ == "object":
+        # Handle nested models.
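+        # Illustrative example (assumed shape): {"type": "object", "properties":
+        # {"lat": {"type": "number"}}} recurses into json_schema_to_model,
+        # while a bare {"type": "object"} simply maps to dict.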
+        properties = json_schema.get("properties")
+        if properties:
+            # Wrap the bare property schema so json_schema_to_model finds the
+            # "name"/"parameters" keys it expects; "NestedModel" is a fallback name.
+            nested_model = json_schema_to_model(
+                {"name": json_schema.get("title", "NestedModel"), "parameters": json_schema}
+            )
+            return nested_model
+        else:
+            return dict
+    elif type_ == "null":
+        return Optional[Any]  # Use Optional[Any] for nullable fields
+    else:
+        raise ValueError(f"Unsupported JSON schema type: {type_}")
diff --git a/backend/utils/task.py b/backend/utils/task.py
index 1b2276c9c5..ea9254c4f7 100644
--- a/backend/utils/task.py
+++ b/backend/utils/task.py
@@ -121,6 +121,43 @@ def replacement_function(match):
     return template


+def moa_response_generation_template(
+    template: str, prompt: str, responses: list[str]
+) -> str:
+    def replacement_function(match):
+        full_match = match.group(0)
+        start_length = match.group(1)
+        end_length = match.group(2)
+        middle_length = match.group(3)
+
+        if full_match == "{{prompt}}":
+            return prompt
+        elif start_length is not None:
+            return prompt[: int(start_length)]
+        elif end_length is not None:
+            return prompt[-int(end_length) :]
+        elif middle_length is not None:
+            middle_length = int(middle_length)
+            if len(prompt) <= middle_length:
+                return prompt
+            start = prompt[: math.ceil(middle_length / 2)]
+            end = prompt[-math.floor(middle_length / 2) :]
+            return f"{start}...{end}"
+        return ""
+
+    template = re.sub(
+        r"{{prompt}}|{{prompt:start:(\d+)}}|{{prompt:end:(\d+)}}|{{prompt:middletruncate:(\d+)}}",
+        replacement_function,
+        template,
+    )
+
+    responses = [f'"""{response}"""' for response in responses]
+    responses = "\n\n".join(responses)
+
+    template = template.replace("{{responses}}", responses)
+    return template
+
+
 def tools_function_calling_generation_template(template: str, tools_specs: str) -> str:
     template = template.replace("{{TOOLS}}", tools_specs)
     return template
diff --git a/backend/utils/tools.py b/backend/utils/tools.py
index eac36b5d90..1a2fea32b0 100644
--- a/backend/utils/tools.py
+++ b/backend/utils/tools.py
@@ -1,5 +1,90 @@
 import inspect
-from typing import get_type_hints
+import logging
+from typing import Awaitable, Callable, get_type_hints
+
+from apps.webui.models.tools import Tools
+from apps.webui.models.users import UserModel
+from apps.webui.utils import load_toolkit_module_by_id
+
+from utils.schemas import json_schema_to_model
+
+log = logging.getLogger(__name__)
+
+
+def apply_extra_params_to_tool_function(
+    function: Callable, extra_params: dict
+) -> Callable[..., Awaitable]:
+    sig = inspect.signature(function)
+    extra_params = {
+        key: value for key, value in extra_params.items() if key in sig.parameters
+    }
+    is_coroutine = inspect.iscoroutinefunction(function)
+
+    async def new_function(**kwargs):
+        extra_kwargs = kwargs | extra_params
+        if is_coroutine:
+            return await function(**extra_kwargs)
+        return function(**extra_kwargs)
+
+    return new_function
+
+
+# Mutation on extra_params
+def get_tools(
+    webui_app, tool_ids: list[str], user: UserModel, extra_params: dict
+) -> dict[str, dict]:
+    tools = {}
+    for tool_id in tool_ids:
+        toolkit = Tools.get_tool_by_id(tool_id)
+        if toolkit is None:
+            continue
+
+        module = webui_app.state.TOOLS.get(tool_id, None)
+        if module is None:
+            module, _ = load_toolkit_module_by_id(tool_id)
+            webui_app.state.TOOLS[tool_id] = module
+
+        extra_params["__id__"] = tool_id
+        if hasattr(module, "valves") and hasattr(module, "Valves"):
+            valves = Tools.get_tool_valves_by_id(tool_id) or {}
+            module.valves = module.Valves(**valves)
+
+        if hasattr(module, "UserValves"):
+            extra_params["__user__"]["valves"] = module.UserValves(  # type: ignore
+                **Tools.get_user_valves_by_id_and_user_id(tool_id, user.id)
+            )
+
+        for spec in toolkit.specs:
+            # TODO: Fix hack for OpenAI API
+            for val in spec.get("parameters", {}).get("properties", {}).values():
+                if val["type"] == "str":
+                    val["type"] = "string"
+            function_name = spec["name"]
+
+            # convert to function that takes only model params and inserts custom params
+            original_func = getattr(module, function_name)
+            callable = apply_extra_params_to_tool_function(original_func, extra_params)
+            if hasattr(original_func, "__doc__"):
+                callable.__doc__ = original_func.__doc__
+
+            # TODO: This needs to be a pydantic model
+            tool_dict = {
+                "toolkit_id": tool_id,
+                "callable": callable,
+                "spec": spec,
+                "pydantic_model": json_schema_to_model(spec),
+                "file_handler": hasattr(module, "file_handler") and module.file_handler,
+                "citation": hasattr(module, "citation") and module.citation,
+            }
+
+            # TODO: if collision, prepend toolkit name
+            if function_name in tools:
+                log.warning(f"Tool {function_name} already exists in another toolkit!")
+                log.warning(f"Collision between {toolkit} and {tool_id}.")
+                log.warning(f"Discarding {toolkit}.{function_name}")
+            else:
+                tools[function_name] = tool_dict
+    return tools


 def doc_to_dict(docstring):
diff --git a/docs/SECURITY.md b/docs/SECURITY.md
index 1cf539b3e7..507e3c6069 100644
--- a/docs/SECURITY.md
+++ b/docs/SECURITY.md
@@ -9,10 +9,18 @@ Our primary goal is to ensure the protection and confidentiality of sensitive da
 | main   | :white_check_mark: |
 | others | :x:                |

+## Zero Tolerance for External Platforms
+
+Based on a precedent of an unacceptable degree of spamming and unsolicited communications from third-party platforms, we forcefully reaffirm our stance. **We refuse to engage with, join, or monitor any platforms outside of GitHub for vulnerability reporting.** Our reasons are not just procedural but are deep-seated in the ethos of our project, which champions transparency and direct community interaction inherent in the open-source culture. Any attempts to divert our processes to external platforms will be met with outright rejection. This policy is non-negotiable and admits no exceptions.
+
+Any reports or solicitations arriving from sources other than our designated GitHub repository will be dismissed without consideration. We’ve seen how external engagements can dilute and compromise the integrity of community-driven projects, and we’re not here to gamble with the security and privacy of our user community.
+
 ## Reporting a Vulnerability

 We appreciate the community's interest in identifying potential vulnerabilities. However, effective immediately, we will **not** accept low-effort vulnerability reports. To ensure that submissions are constructive and actionable, please adhere to the following guidelines:

+Reports not submitted through our designated GitHub repository will be disregarded, and we will categorically reject invitations to collaborate on external platforms. Our aggressive stance on this matter underscores our commitment to a secure, transparent, and open community where all operations are visible and contributors are accountable.
+
 1. **No Vague Reports**: Submissions such as "I found a vulnerability" without any details will be treated as spam and will not be accepted.

 2. **In-Depth Understanding Required**: Reports must reflect a clear understanding of the codebase and provide specific details about the vulnerability, including the affected components and potential impacts.
@@ -23,7 +31,7 @@ We appreciate the community's interest in identifying potential vulnerabilities.

 5.
**Streamlined Merging Process**: When vulnerability reports meet the above criteria, we can consider them for immediate merging, similar to regular pull requests. Well-structured and thorough submissions will expedite the process of enhancing our security. -Submissions that do not meet these criteria will be closed, and repeat offenders may face a ban from future submissions. We aim to create a respectful and constructive reporting environment, where high-quality submissions foster better security for everyone. +**Non-compliant submissions will be closed, and repeat violators may be banned.** Our goal is to foster a constructive reporting environment where quality submissions promote better security for all users. ## Product Security @@ -33,4 +41,4 @@ For immediate concerns or detailed reports that meet our guidelines, please crea --- -_Last updated on **2024-08-06**._ +_Last updated on **2024-08-19**._ diff --git a/package-lock.json b/package-lock.json index 52c5f89335..d612dddcb8 100644 --- a/package-lock.json +++ b/package-lock.json @@ -1,12 +1,12 @@ { "name": "open-webui", - "version": "0.3.13", + "version": "0.3.15", "lockfileVersion": 3, "requires": true, "packages": { "": { "name": "open-webui", - "version": "0.3.13", + "version": "0.3.15", "dependencies": { "@codemirror/lang-javascript": "^6.2.2", "@codemirror/lang-python": "^6.1.6", @@ -18,6 +18,7 @@ "codemirror": "^6.0.1", "crc-32": "^1.2.2", "dayjs": "^1.11.10", + "dompurify": "^3.1.6", "eventsource-parser": "^1.1.2", "file-saver": "^2.0.5", "fuse.js": "^7.0.0", @@ -29,7 +30,6 @@ "js-sha256": "^0.10.1", "katex": "^0.16.9", "marked": "^9.1.0", - "marked-katex-extension": "^5.1.1", "mermaid": "^10.9.1", "pyodide": "^0.26.1", "socket.io-client": "^4.2.0", @@ -1545,11 +1545,6 @@ "integrity": "sha512-5+fP8P8MFNC+AyZCDxrB2pkZFPGzqQWUzpSeuuVLvm8VMcorNYavBqoFcxK8bQz4Qsbn4oUEEem4wDLfcysGHA==", "dev": true }, - "node_modules/@types/katex": { - "version": "0.16.7", - "resolved": "https://registry.npmjs.org/@types/katex/-/katex-0.16.7.tgz", - "integrity": "sha512-HMwFiRujE5PjrgwHQ25+bsLJgowjGjm5Z8FVSf0N6PwgJrwxH0QxzHYDcKsTfV3wva0vzrpqMTJS2jXPr5BMEQ==" - }, "node_modules/@types/mdast": { "version": "3.0.15", "resolved": "https://registry.npmjs.org/@types/mdast/-/mdast-3.0.15.tgz", @@ -3918,9 +3913,9 @@ } }, "node_modules/dompurify": { - "version": "3.1.5", - "resolved": "https://registry.npmjs.org/dompurify/-/dompurify-3.1.5.tgz", - "integrity": "sha512-lwG+n5h8QNpxtyrJW/gJWckL+1/DQiYMX8f7t8Z2AZTPw1esVrqjI63i7Zc2Gz0aKzLVMYC1V1PL/ky+aY/NgA==" + "version": "3.1.6", + "resolved": "https://registry.npmjs.org/dompurify/-/dompurify-3.1.6.tgz", + "integrity": "sha512-cTOAhc36AalkjtBpfG6O8JimdTMWNXjiePT2xQH/ppBGi/4uIpmj8eKyIkMJErXWARyINV/sB38yf8JCLF5pbQ==" }, "node_modules/domutils": { "version": "3.1.0", @@ -6042,18 +6037,6 @@ "node": ">= 16" } }, - "node_modules/marked-katex-extension": { - "version": "5.1.1", - "resolved": "https://registry.npmjs.org/marked-katex-extension/-/marked-katex-extension-5.1.1.tgz", - "integrity": "sha512-piquiCyZpZ1aiocoJlJkRXr+hkk5UI4xw9GhRZiIAAgvX5rhzUDSJ0seup1JcsgueC8MLNDuqe5cRcAzkFE42Q==", - "dependencies": { - "@types/katex": "^0.16.7" - }, - "peerDependencies": { - "katex": ">=0.16 <0.17", - "marked": ">=4 <15" - } - }, "node_modules/matcher-collection": { "version": "2.0.1", "resolved": "https://registry.npmjs.org/matcher-collection/-/matcher-collection-2.0.1.tgz", diff --git a/package.json b/package.json index 2d32422d1b..7252d8829a 100644 --- a/package.json +++ b/package.json @@ -1,6 +1,6 @@ { "name": 
"open-webui", - "version": "0.3.13", + "version": "0.3.15", "private": true, "scripts": { "dev": "npm run pyodide:fetch && vite dev --host", @@ -59,6 +59,7 @@ "codemirror": "^6.0.1", "crc-32": "^1.2.2", "dayjs": "^1.11.10", + "dompurify": "^3.1.6", "eventsource-parser": "^1.1.2", "file-saver": "^2.0.5", "fuse.js": "^7.0.0", @@ -70,7 +71,6 @@ "js-sha256": "^0.10.1", "katex": "^0.16.9", "marked": "^9.1.0", - "marked-katex-extension": "^5.1.1", "mermaid": "^10.9.1", "pyodide": "^0.26.1", "socket.io-client": "^4.2.0", diff --git a/pyproject.toml b/pyproject.toml index 159bce0727..61c3e5417c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -7,7 +7,7 @@ authors = [ license = { file = "LICENSE" } dependencies = [ "fastapi==0.111.0", - "uvicorn[standard]==0.22.0", + "uvicorn[standard]==0.30.6", "pydantic==2.8.2", "python-multipart==0.0.9", @@ -21,17 +21,17 @@ dependencies = [ "requests==2.32.3", "aiohttp==3.10.2", - "sqlalchemy==2.0.31", + "sqlalchemy==2.0.32", "alembic==1.13.2", "peewee==3.17.6", "peewee-migrate==1.12.2", "psycopg2-binary==2.9.9", "PyMySQL==1.1.1", - "bcrypt==4.1.3", + "bcrypt==4.2.0", "pymongo", "redis", - "boto3==1.34.153", + "boto3==1.35.0", "argon2-cffi==23.1.0", "APScheduler==3.10.4", @@ -51,7 +51,7 @@ dependencies = [ "pypdf==4.3.1", "docx2txt==0.8", "python-pptx==1.0.0", - "unstructured==0.15.0", + "unstructured==0.15.5", "Markdown==3.6", "pypandoc==1.13", "pandas==2.2.2", @@ -67,7 +67,7 @@ dependencies = [ "fpdf2==2.7.9", "rank-bm25==0.2.2", - "faster-whisper==1.0.2", + "faster-whisper==1.0.3", "PyJWT[crypto]==2.9.0", "authlib==1.3.1", diff --git a/requirements-dev.lock b/requirements-dev.lock index 6b3f518512..01dcaa2c3c 100644 --- a/requirements-dev.lock +++ b/requirements-dev.lock @@ -50,7 +50,7 @@ backoff==2.2.1 # via langfuse # via posthog # via unstructured -bcrypt==4.1.3 +bcrypt==4.2.0 # via chromadb # via open-webui # via passlib @@ -63,9 +63,9 @@ black==24.8.0 # via open-webui blinker==1.8.2 # via flask -boto3==1.34.153 +boto3==1.35.0 # via open-webui -botocore==1.34.155 +botocore==1.35.2 # via boto3 # via s3transfer build==1.2.1 @@ -156,7 +156,7 @@ fastapi==0.111.0 # via open-webui fastapi-cli==0.0.4 # via fastapi -faster-whisper==1.0.2 +faster-whisper==1.0.3 # via open-webui filelock==3.14.0 # via huggingface-hub @@ -632,7 +632,7 @@ sniffio==1.3.1 # via openai soupsieve==2.5 # via beautifulsoup4 -sqlalchemy==2.0.31 +sqlalchemy==2.0.32 # via alembic # via langchain # via langchain-community @@ -703,7 +703,7 @@ tzlocal==5.2 # via extract-msg ujson==5.10.0 # via fastapi -unstructured==0.15.0 +unstructured==0.15.5 # via open-webui unstructured-client==0.22.0 # via unstructured @@ -715,7 +715,7 @@ urllib3==2.2.1 # via kubernetes # via requests # via unstructured-client -uvicorn==0.22.0 +uvicorn==0.30.6 # via chromadb # via fastapi # via open-webui diff --git a/requirements.lock b/requirements.lock index 6b3f518512..01dcaa2c3c 100644 --- a/requirements.lock +++ b/requirements.lock @@ -50,7 +50,7 @@ backoff==2.2.1 # via langfuse # via posthog # via unstructured -bcrypt==4.1.3 +bcrypt==4.2.0 # via chromadb # via open-webui # via passlib @@ -63,9 +63,9 @@ black==24.8.0 # via open-webui blinker==1.8.2 # via flask -boto3==1.34.153 +boto3==1.35.0 # via open-webui -botocore==1.34.155 +botocore==1.35.2 # via boto3 # via s3transfer build==1.2.1 @@ -156,7 +156,7 @@ fastapi==0.111.0 # via open-webui fastapi-cli==0.0.4 # via fastapi -faster-whisper==1.0.2 +faster-whisper==1.0.3 # via open-webui filelock==3.14.0 # via huggingface-hub @@ -632,7 +632,7 @@ sniffio==1.3.1 # via 
openai soupsieve==2.5 # via beautifulsoup4 -sqlalchemy==2.0.31 +sqlalchemy==2.0.32 # via alembic # via langchain # via langchain-community @@ -703,7 +703,7 @@ tzlocal==5.2 # via extract-msg ujson==5.10.0 # via fastapi -unstructured==0.15.0 +unstructured==0.15.5 # via open-webui unstructured-client==0.22.0 # via unstructured @@ -715,7 +715,7 @@ urllib3==2.2.1 # via kubernetes # via requests # via unstructured-client -uvicorn==0.22.0 +uvicorn==0.30.6 # via chromadb # via fastapi # via open-webui diff --git a/src/app.css b/src/app.css index 4345bb3777..a421d90ae4 100644 --- a/src/app.css +++ b/src/app.css @@ -34,6 +34,10 @@ math { @apply rounded-lg; } +.markdown-prose { + @apply prose dark:prose-invert prose-p:my-0 prose-img:my-1 prose-headings:my-1 prose-pre:my-0 prose-table:my-0 prose-blockquote:my-0 prose-ul:-my-0 prose-ol:-my-0 prose-li:-my-0 whitespace-pre-line; +} + .markdown a { @apply underline; } diff --git a/src/lib/apis/images/index.ts b/src/lib/apis/images/index.ts index 3f624704eb..2e6510437b 100644 --- a/src/lib/apis/images/index.ts +++ b/src/lib/apis/images/index.ts @@ -1,6 +1,6 @@ import { IMAGES_API_BASE_URL } from '$lib/constants'; -export const getImageGenerationConfig = async (token: string = '') => { +export const getConfig = async (token: string = '') => { let error = null; const res = await fetch(`${IMAGES_API_BASE_URL}/config`, { @@ -32,11 +32,7 @@ export const getImageGenerationConfig = async (token: string = '') => { return res; }; -export const updateImageGenerationConfig = async ( - token: string = '', - engine: string, - enabled: boolean -) => { +export const updateConfig = async (token: string = '', config: object) => { let error = null; const res = await fetch(`${IMAGES_API_BASE_URL}/config/update`, { @@ -47,8 +43,7 @@ export const updateImageGenerationConfig = async ( ...(token && { authorization: `Bearer ${token}` }) }, body: JSON.stringify({ - engine, - enabled + ...config }) }) .then(async (res) => { @@ -72,10 +67,10 @@ export const updateImageGenerationConfig = async ( return res; }; -export const getOpenAIConfig = async (token: string = '') => { +export const verifyConfigUrl = async (token: string = '') => { let error = null; - const res = await fetch(`${IMAGES_API_BASE_URL}/openai/config`, { + const res = await fetch(`${IMAGES_API_BASE_URL}/config/url/verify`, { method: 'GET', headers: { Accept: 'application/json', @@ -104,46 +99,10 @@ export const getOpenAIConfig = async (token: string = '') => { return res; }; -export const updateOpenAIConfig = async (token: string = '', url: string, key: string) => { - let error = null; - - const res = await fetch(`${IMAGES_API_BASE_URL}/openai/config/update`, { - method: 'POST', - headers: { - Accept: 'application/json', - 'Content-Type': 'application/json', - ...(token && { authorization: `Bearer ${token}` }) - }, - body: JSON.stringify({ - url: url, - key: key - }) - }) - .then(async (res) => { - if (!res.ok) throw await res.json(); - return res.json(); - }) - .catch((err) => { - console.log(err); - if ('detail' in err) { - error = err.detail; - } else { - error = 'Server connection failed'; - } - return null; - }); - - if (error) { - throw error; - } - - return res; -}; - -export const getImageGenerationEngineUrls = async (token: string = '') => { +export const getImageGenerationConfig = async (token: string = '') => { let error = null; - const res = await fetch(`${IMAGES_API_BASE_URL}/url`, { + const res = await fetch(`${IMAGES_API_BASE_URL}/image/config`, { method: 'GET', headers: { Accept: 'application/json', @@ 
-172,19 +131,17 @@ export const getImageGenerationEngineUrls = async (token: string = '') => { return res; }; -export const updateImageGenerationEngineUrls = async (token: string = '', urls: object = {}) => { +export const updateImageGenerationConfig = async (token: string = '', config: object) => { let error = null; - const res = await fetch(`${IMAGES_API_BASE_URL}/url/update`, { + const res = await fetch(`${IMAGES_API_BASE_URL}/image/config/update`, { method: 'POST', headers: { Accept: 'application/json', 'Content-Type': 'application/json', ...(token && { authorization: `Bearer ${token}` }) }, - body: JSON.stringify({ - ...urls - }) + body: JSON.stringify({ ...config }) }) .then(async (res) => { if (!res.ok) throw await res.json(); @@ -207,138 +164,6 @@ export const updateImageGenerationEngineUrls = async (token: string = '', urls: return res; }; -export const getImageSize = async (token: string = '') => { - let error = null; - - const res = await fetch(`${IMAGES_API_BASE_URL}/size`, { - method: 'GET', - headers: { - Accept: 'application/json', - 'Content-Type': 'application/json', - ...(token && { authorization: `Bearer ${token}` }) - } - }) - .then(async (res) => { - if (!res.ok) throw await res.json(); - return res.json(); - }) - .catch((err) => { - console.log(err); - if ('detail' in err) { - error = err.detail; - } else { - error = 'Server connection failed'; - } - return null; - }); - - if (error) { - throw error; - } - - return res.IMAGE_SIZE; -}; - -export const updateImageSize = async (token: string = '', size: string) => { - let error = null; - - const res = await fetch(`${IMAGES_API_BASE_URL}/size/update`, { - method: 'POST', - headers: { - Accept: 'application/json', - 'Content-Type': 'application/json', - ...(token && { authorization: `Bearer ${token}` }) - }, - body: JSON.stringify({ - size: size - }) - }) - .then(async (res) => { - if (!res.ok) throw await res.json(); - return res.json(); - }) - .catch((err) => { - console.log(err); - if ('detail' in err) { - error = err.detail; - } else { - error = 'Server connection failed'; - } - return null; - }); - - if (error) { - throw error; - } - - return res.IMAGE_SIZE; -}; - -export const getImageSteps = async (token: string = '') => { - let error = null; - - const res = await fetch(`${IMAGES_API_BASE_URL}/steps`, { - method: 'GET', - headers: { - Accept: 'application/json', - 'Content-Type': 'application/json', - ...(token && { authorization: `Bearer ${token}` }) - } - }) - .then(async (res) => { - if (!res.ok) throw await res.json(); - return res.json(); - }) - .catch((err) => { - console.log(err); - if ('detail' in err) { - error = err.detail; - } else { - error = 'Server connection failed'; - } - return null; - }); - - if (error) { - throw error; - } - - return res.IMAGE_STEPS; -}; - -export const updateImageSteps = async (token: string = '', steps: number) => { - let error = null; - - const res = await fetch(`${IMAGES_API_BASE_URL}/steps/update`, { - method: 'POST', - headers: { - Accept: 'application/json', - 'Content-Type': 'application/json', - ...(token && { authorization: `Bearer ${token}` }) - }, - body: JSON.stringify({ steps }) - }) - .then(async (res) => { - if (!res.ok) throw await res.json(); - return res.json(); - }) - .catch((err) => { - console.log(err); - if ('detail' in err) { - error = err.detail; - } else { - error = 'Server connection failed'; - } - return null; - }); - - if (error) { - throw error; - } - - return res.IMAGE_STEPS; -}; - export const getImageGenerationModels = async (token: string = '') => 
{ let error = null; @@ -371,73 +196,6 @@ export const getImageGenerationModels = async (token: string = '') => { return res; }; -export const getDefaultImageGenerationModel = async (token: string = '') => { - let error = null; - - const res = await fetch(`${IMAGES_API_BASE_URL}/models/default`, { - method: 'GET', - headers: { - Accept: 'application/json', - 'Content-Type': 'application/json', - ...(token && { authorization: `Bearer ${token}` }) - } - }) - .then(async (res) => { - if (!res.ok) throw await res.json(); - return res.json(); - }) - .catch((err) => { - console.log(err); - if ('detail' in err) { - error = err.detail; - } else { - error = 'Server connection failed'; - } - return null; - }); - - if (error) { - throw error; - } - - return res.model; -}; - -export const updateDefaultImageGenerationModel = async (token: string = '', model: string) => { - let error = null; - - const res = await fetch(`${IMAGES_API_BASE_URL}/models/default/update`, { - method: 'POST', - headers: { - Accept: 'application/json', - 'Content-Type': 'application/json', - ...(token && { authorization: `Bearer ${token}` }) - }, - body: JSON.stringify({ - model: model - }) - }) - .then(async (res) => { - if (!res.ok) throw await res.json(); - return res.json(); - }) - .catch((err) => { - console.log(err); - if ('detail' in err) { - error = err.detail; - } else { - error = 'Server connection failed'; - } - return null; - }); - - if (error) { - throw error; - } - - return res.model; -}; - export const imageGenerations = async (token: string = '', prompt: string) => { let error = null; diff --git a/src/lib/apis/index.ts b/src/lib/apis/index.ts index c4778cadbd..8432554785 100644 --- a/src/lib/apis/index.ts +++ b/src/lib/apis/index.ts @@ -333,6 +333,42 @@ export const generateSearchQuery = async ( return res?.choices[0]?.message?.content.replace(/["']/g, '') ?? 
prompt; }; +export const generateMoACompletion = async ( + token: string = '', + model: string, + prompt: string, + responses: string[] +) => { + const controller = new AbortController(); + let error = null; + + const res = await fetch(`${WEBUI_BASE_URL}/api/task/moa/completions`, { + signal: controller.signal, + method: 'POST', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + Authorization: `Bearer ${token}` + }, + body: JSON.stringify({ + model: model, + prompt: prompt, + responses: responses, + stream: true + }) + }).catch((err) => { + console.log(err); + error = err; + return null; + }); + + if (error) { + throw error; + } + + return [res, controller]; +}; + export const getPipelinesList = async (token: string = '') => { let error = null; @@ -629,6 +665,7 @@ export const getBackendConfig = async () => { const res = await fetch(`${WEBUI_BASE_URL}/api/config`, { method: 'GET', + credentials: 'include', headers: { 'Content-Type': 'application/json' } @@ -913,6 +950,7 @@ export interface ModelConfig { export interface ModelMeta { description?: string; capabilities?: object; + profile_image_url?: string; } export interface ModelParams {} diff --git a/src/lib/apis/ollama/index.ts b/src/lib/apis/ollama/index.ts index c4c449156c..d4e994312e 100644 --- a/src/lib/apis/ollama/index.ts +++ b/src/lib/apis/ollama/index.ts @@ -396,7 +396,7 @@ export const deleteModel = async (token: string, tagName: string, urlIdx: string return res; }; -export const pullModel = async (token: string, tagName: string, urlIdx: string | null = null) => { +export const pullModel = async (token: string, tagName: string, urlIdx: number | null = null) => { let error = null; const controller = new AbortController(); diff --git a/src/lib/components/admin/Settings.svelte b/src/lib/components/admin/Settings.svelte index afb8736ea1..e242ab632a 100644 --- a/src/lib/components/admin/Settings.svelte +++ b/src/lib/components/admin/Settings.svelte @@ -336,8 +336,11 @@
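Reviewer note: the removed per-field helpers (`getImageSize`/`updateImageSize`, `getImageSteps`/`updateImageSteps`, and `getDefaultImageGenerationModel`/`updateDefaultImageGenerationModel`) are folded into the single `updateImageGenerationConfig` call. A minimal usage sketch follows; the payload key names (`MODEL`, `IMAGE_SIZE`, `IMAGE_STEPS`) are inferred from the fields the deleted getters returned (`res.model`, `res.IMAGE_SIZE`, `res.IMAGE_STEPS`) and are assumptions, not confirmed by this diff.

```typescript
// Hypothetical usage sketch, not part of this diff.
// Key names are assumed from the removed getters' return fields;
// the consolidated /image/config/update payload may differ.
import { updateImageGenerationConfig } from '$lib/apis/images';

const config = await updateImageGenerationConfig(localStorage.token, {
	MODEL: 'stable-diffusion-xl', // replaces updateDefaultImageGenerationModel
	IMAGE_SIZE: '512x512', // replaces updateImageSize
	IMAGE_STEPS: 50 // replaces updateImageSteps
});
```

One consolidated request replaces what previously took up to three sequential POSTs from the settings UI.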
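Reviewer note: a minimal sketch of consuming the new `generateMoACompletion` stream on the client. Because the request sets `stream: true`, the sketch assumes the `/api/task/moa/completions` endpoint emits OpenAI-style server-sent events (`data: {...}` lines carrying `choices[0].delta.content`); that wire format is an assumption, not confirmed by this diff, as are the placeholder model id, prompt, and gathered responses.

```typescript
// Hypothetical usage sketch, not part of this diff.
// 'llama3', the prompt, responseA, and responseB are placeholders for
// the models and answers collected in the current many-model chat.
const responseA = '...'; // answer from model A (placeholder)
const responseB = '...'; // answer from model B (placeholder)
let merged = '';

const [res, controller] = (await generateMoACompletion(
	localStorage.token,
	'llama3',
	'Merge the answers into one response.',
	[responseA, responseB]
)) as [Response | null, AbortController];

if (res && res.body) {
	const reader = res.body.getReader();
	const decoder = new TextDecoder();
	for (;;) {
		const { value, done } = await reader.read();
		if (done) break;
		// Each network chunk may carry several `data:` lines.
		for (const line of decoder.decode(value).split('\n')) {
			if (line.startsWith('data: ') && !line.includes('[DONE]')) {
				const data = JSON.parse(line.slice('data: '.length));
				merged += data?.choices?.[0]?.delta?.content ?? '';
			}
		}
	}
}
// controller.abort() cancels the request if the user stops generation.
```

Returning the `AbortController` alongside the raw `Response`, as the other completion helpers in this file do, is what lets the caller cancel mid-stream.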