Skip to content

Commit

Permalink
refact
Browse files Browse the repository at this point in the history
  • Loading branch information
ljleb committed Aug 16, 2023
1 parent be023b8 commit d7082bb
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 41 deletions.
85 changes: 47 additions & 38 deletions lib_comfyui/webui/reverse_proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,61 +5,73 @@ def register_comfyui(fast_api):
if not settings.is_reverse_proxy_enabled():
return

from starlette.requests import Request
from starlette.responses import StreamingResponse
from starlette.background import BackgroundTask
from fastapi import WebSocket
import httpx
import websockets
import asyncio
from starlette.websockets import WebSocketDisconnect
from websockets.exceptions import ConnectionClosedOK

comfyui_url = settings.get_comfyui_server_url()

proxy_route = settings.get_comfyui_reverse_proxy_route()
proxy_route_bytes = bytes(proxy_route, "utf-8")

async def async_iter_raw_patched(response):
async for chunk in response.aiter_raw():
replacements = [
(b'/favicon', proxy_route_bytes + b'/favicon'),
(b'from "/scripts/', b'from "' + proxy_route_bytes + b'/scripts/'),
(b'from "/extensions/', b'from "' + proxy_route_bytes + b'/extensions/'),
(b'from "/webui_scripts/', b'from "' + proxy_route_bytes + b'/webui_scripts/'),
(proxy_route_bytes * 2, proxy_route_bytes),
]
for substring, replacement in replacements:
chunk = chunk.replace(substring, replacement)
yield chunk
create_http_reverse_proxy(fast_api, comfyui_url, proxy_route)
create_ws_reverse_proxy(fast_api, comfyui_url, proxy_route)
print("[sd-webui-comfyui]", f"Created a reverse proxy route to ComfyUI: {proxy_route}")


def create_http_reverse_proxy(fast_api, comfyui_url, proxy_route):
from starlette.requests import Request
from starlette.responses import StreamingResponse, Response
from starlette.background import BackgroundTask
import httpx

web_client = httpx.AsyncClient(base_url=comfyui_url)

async def reverse_proxy(request: Request):
"""Proxy incoming requests to another server."""
base_path = request.url.path.replace(proxy_route, "", 1)
url = httpx.URL(path=base_path, query=request.url.query.encode("utf-8"))
rp_req = web_client.build_request(request.method, url, headers=request.headers.raw, content=await request.body())
rp_resp = await web_client.send(rp_req, stream=True)
return StreamingResponse(
async_iter_raw_patched(rp_resp),
status_code=rp_resp.status_code,
headers=rp_resp.headers,
background=BackgroundTask(rp_resp.aclose),
)
try:
rp_resp = await web_client.send(rp_req, stream=True)
except httpx.ConnectError:
return Response(status_code=404)
else:
return StreamingResponse(
async_iter_raw_patched(rp_resp, proxy_route),
status_code=rp_resp.status_code,
headers=rp_resp.headers,
background=BackgroundTask(rp_resp.aclose),
)

fast_api.add_route(f"{proxy_route}/{{path:path}}", reverse_proxy, ["GET", "POST", "PUT", "DELETE"])


async def async_iter_raw_patched(response, proxy_route):
proxy_route_bytes = bytes(proxy_route, "utf-8")

async for chunk in response.aiter_raw():
paths_to_replace = [b"/scripts/", b"/extensions/", b"/webui_scripts/"]
replacements = [
(b'/favicon', proxy_route_bytes + b'/favicon'),
*(
(b'from "' + path, b'from "' + proxy_route_bytes + path)
for path in paths_to_replace
),
]
for substring, replacement in replacements:
chunk = chunk.replace(substring, replacement)
yield chunk


def create_ws_reverse_proxy(fast_api, comfyui_url, proxy_route):
from fastapi import WebSocket
import websockets
import asyncio
from starlette.websockets import WebSocketDisconnect
from websockets.exceptions import ConnectionClosedOK

ws_comfyui_url = http_to_ws(comfyui_url)

@fast_api.websocket(f"{proxy_route}/ws")
async def websocket_endpoint(ws_client: WebSocket):
"""Websocket endpoint to proxy incoming WS requests."""
await ws_client.accept()
async with websockets.connect(ws_comfyui_url) as ws_server:

async def listen_to_client():
"""Forward messages from client to server."""
try:
while True:
data = await ws_client.receive_text()
Expand All @@ -68,7 +80,6 @@ async def listen_to_client():
await ws_server.close()

async def listen_to_server():
"""Forward messages from server to client."""
try:
while True:
data = await ws_server.recv()
Expand All @@ -78,8 +89,6 @@ async def listen_to_server():

await asyncio.gather(listen_to_client(), listen_to_server())

print("[sd-webui-comfyui]", f"Created a reverse proxy route to ComfyUI: {proxy_route}")


def http_to_ws(url: str) -> str:
"""Convert http or https URL to its websocket equivalent."""
Expand All @@ -88,4 +97,4 @@ def http_to_ws(url: str) -> str:
parsed_url = urlparse(url)
ws_scheme = 'wss' if parsed_url.scheme == 'https' else 'ws'
ws_url = parsed_url._replace(scheme=ws_scheme)
return urlunparse(ws_url) + "ws"
return f"{urlunparse(ws_url)}/ws"
6 changes: 3 additions & 3 deletions lib_comfyui/webui/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ def get_comfyui_reverse_proxy_url():
"""
comfyui reverse proxy url, as seen from the browser
"""
return f"{get_comfyui_reverse_proxy_route()}/"
return get_comfyui_reverse_proxy_route()


def get_comfyui_reverse_proxy_route():
Expand All @@ -120,7 +120,7 @@ def get_comfyui_client_url():
"""
from modules import shared
loopback_address = '127.0.0.1'
server_url = "http://" + (get_setting_value('--listen') or getattr(shared.cmd_opts, 'comfyui_listen', loopback_address)) + ":" + str(get_port()) + "/"
server_url = "http://" + (get_setting_value('--listen') or getattr(shared.cmd_opts, 'comfyui_listen', loopback_address)) + ":" + str(get_port())
client_url = shared.opts.data.get('comfyui_client_address', None) or getattr(shared.cmd_opts, 'webui_comfyui_client_address', None) or server_url
if client_url.startswith(('http://0.0.0.0', 'https://0.0.0.0')):
print(textwrap.dedent(f"""
Expand All @@ -138,7 +138,7 @@ def get_comfyui_server_url():
"""
comfyui server url, as seen from the webui server
"""
return f"http://localhost:{get_port()}/"
return f"http://localhost:{get_port()}"


@ipc.restrict_to_process('webui')
Expand Down

0 comments on commit d7082bb

Please sign in to comment.