diff --git a/lib_comfyui/webui/reverse_proxy.py b/lib_comfyui/webui/reverse_proxy.py index 40e894c..e1d3813 100644 --- a/lib_comfyui/webui/reverse_proxy.py +++ b/lib_comfyui/webui/reverse_proxy.py @@ -5,61 +5,73 @@ def register_comfyui(fast_api): if not settings.is_reverse_proxy_enabled(): return - from starlette.requests import Request - from starlette.responses import StreamingResponse - from starlette.background import BackgroundTask - from fastapi import WebSocket - import httpx - import websockets - import asyncio - from starlette.websockets import WebSocketDisconnect - from websockets.exceptions import ConnectionClosedOK - comfyui_url = settings.get_comfyui_server_url() - proxy_route = settings.get_comfyui_reverse_proxy_route() - proxy_route_bytes = bytes(proxy_route, "utf-8") - async def async_iter_raw_patched(response): - async for chunk in response.aiter_raw(): - replacements = [ - (b'/favicon', proxy_route_bytes + b'/favicon'), - (b'from "/scripts/', b'from "' + proxy_route_bytes + b'/scripts/'), - (b'from "/extensions/', b'from "' + proxy_route_bytes + b'/extensions/'), - (b'from "/webui_scripts/', b'from "' + proxy_route_bytes + b'/webui_scripts/'), - (proxy_route_bytes * 2, proxy_route_bytes), - ] - for substring, replacement in replacements: - chunk = chunk.replace(substring, replacement) - yield chunk + create_http_reverse_proxy(fast_api, comfyui_url, proxy_route) + create_ws_reverse_proxy(fast_api, comfyui_url, proxy_route) + print("[sd-webui-comfyui]", f"Created a reverse proxy route to ComfyUI: {proxy_route}") + + +def create_http_reverse_proxy(fast_api, comfyui_url, proxy_route): + from starlette.requests import Request + from starlette.responses import StreamingResponse, Response + from starlette.background import BackgroundTask + import httpx web_client = httpx.AsyncClient(base_url=comfyui_url) async def reverse_proxy(request: Request): - """Proxy incoming requests to another server.""" base_path = request.url.path.replace(proxy_route, "", 1) url = httpx.URL(path=base_path, query=request.url.query.encode("utf-8")) rp_req = web_client.build_request(request.method, url, headers=request.headers.raw, content=await request.body()) - rp_resp = await web_client.send(rp_req, stream=True) - return StreamingResponse( - async_iter_raw_patched(rp_resp), - status_code=rp_resp.status_code, - headers=rp_resp.headers, - background=BackgroundTask(rp_resp.aclose), - ) + try: + rp_resp = await web_client.send(rp_req, stream=True) + except httpx.ConnectError: + return Response(status_code=404) + else: + return StreamingResponse( + async_iter_raw_patched(rp_resp, proxy_route), + status_code=rp_resp.status_code, + headers=rp_resp.headers, + background=BackgroundTask(rp_resp.aclose), + ) fast_api.add_route(f"{proxy_route}/{{path:path}}", reverse_proxy, ["GET", "POST", "PUT", "DELETE"]) + +async def async_iter_raw_patched(response, proxy_route): + proxy_route_bytes = bytes(proxy_route, "utf-8") + + async for chunk in response.aiter_raw(): + paths_to_replace = [b"/scripts/", b"/extensions/", b"/webui_scripts/"] + replacements = [ + (b'/favicon', proxy_route_bytes + b'/favicon'), + *( + (b'from "' + path, b'from "' + proxy_route_bytes + path) + for path in paths_to_replace + ), + ] + for substring, replacement in replacements: + chunk = chunk.replace(substring, replacement) + yield chunk + + +def create_ws_reverse_proxy(fast_api, comfyui_url, proxy_route): + from fastapi import WebSocket + import websockets + import asyncio + from starlette.websockets import WebSocketDisconnect + from websockets.exceptions import ConnectionClosedOK + ws_comfyui_url = http_to_ws(comfyui_url) @fast_api.websocket(f"{proxy_route}/ws") async def websocket_endpoint(ws_client: WebSocket): - """Websocket endpoint to proxy incoming WS requests.""" await ws_client.accept() async with websockets.connect(ws_comfyui_url) as ws_server: async def listen_to_client(): - """Forward messages from client to server.""" try: while True: data = await ws_client.receive_text() @@ -68,7 +80,6 @@ async def listen_to_client(): await ws_server.close() async def listen_to_server(): - """Forward messages from server to client.""" try: while True: data = await ws_server.recv() @@ -78,8 +89,6 @@ async def listen_to_server(): await asyncio.gather(listen_to_client(), listen_to_server()) - print("[sd-webui-comfyui]", f"Created a reverse proxy route to ComfyUI: {proxy_route}") - def http_to_ws(url: str) -> str: """Convert http or https URL to its websocket equivalent.""" @@ -88,4 +97,4 @@ def http_to_ws(url: str) -> str: parsed_url = urlparse(url) ws_scheme = 'wss' if parsed_url.scheme == 'https' else 'ws' ws_url = parsed_url._replace(scheme=ws_scheme) - return urlunparse(ws_url) + "ws" + return f"{urlunparse(ws_url)}/ws" diff --git a/lib_comfyui/webui/settings.py b/lib_comfyui/webui/settings.py index 5e7d636..4f64c20 100644 --- a/lib_comfyui/webui/settings.py +++ b/lib_comfyui/webui/settings.py @@ -106,7 +106,7 @@ def get_comfyui_reverse_proxy_url(): """ comfyui reverse proxy url, as seen from the browser """ - return f"{get_comfyui_reverse_proxy_route()}/" + return get_comfyui_reverse_proxy_route() def get_comfyui_reverse_proxy_route(): @@ -120,7 +120,7 @@ def get_comfyui_client_url(): """ from modules import shared loopback_address = '127.0.0.1' - server_url = "http://" + (get_setting_value('--listen') or getattr(shared.cmd_opts, 'comfyui_listen', loopback_address)) + ":" + str(get_port()) + "/" + server_url = "http://" + (get_setting_value('--listen') or getattr(shared.cmd_opts, 'comfyui_listen', loopback_address)) + ":" + str(get_port()) client_url = shared.opts.data.get('comfyui_client_address', None) or getattr(shared.cmd_opts, 'webui_comfyui_client_address', None) or server_url if client_url.startswith(('http://0.0.0.0', 'https://0.0.0.0')): print(textwrap.dedent(f""" @@ -138,7 +138,7 @@ def get_comfyui_server_url(): """ comfyui server url, as seen from the webui server """ - return f"http://localhost:{get_port()}/" + return f"http://localhost:{get_port()}" @ipc.restrict_to_process('webui')