Skip to content

Commit

Permalink
Merge pull request #47 from hayden-fr/feature-multi-user
Browse files Browse the repository at this point in the history
feat: adapt to multi user
  • Loading branch information
hayden-fr authored Nov 11, 2024
2 parents ae518b5 + a1e5761 commit f2e1774
Show file tree
Hide file tree
Showing 9 changed files with 154 additions and 257 deletions.
60 changes: 54 additions & 6 deletions __init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,61 @@
routes = config.routes


@routes.get("/model-manager/ws")
async def socket_handler(request):
@routes.get("/model-manager/download/task")
async def scan_download_tasks(request):
"""
Handle websocket connection.
Read download task list.
"""
ws = await services.connect_websocket(request)
return ws
try:
result = await services.scan_model_download_task_list()
return web.json_response({"success": True, "data": result})
except Exception as e:
error_msg = f"Read download task list failed: {e}"
logging.error(error_msg)
logging.debug(traceback.format_exc())
return web.json_response({"success": False, "error": error_msg})


@routes.put("/model-manager/download/{task_id}")
async def resume_download_task(request):
"""
Toggle download task status.
"""
try:
task_id = request.match_info.get("task_id", None)
if task_id is None:
raise web.HTTPBadRequest(reason="Invalid task id")
json_data = await request.json()
status = json_data.get("status", None)
if status == "pause":
await services.pause_model_download_task(task_id)
elif status == "resume":
await services.resume_model_download_task(task_id, request)
else:
raise web.HTTPBadRequest(reason="Invalid status")

return web.json_response({"success": True})
except Exception as e:
error_msg = f"Resume download task failed: {str(e)}"
logging.error(error_msg)
logging.debug(traceback.format_exc())
return web.json_response({"success": False, "error": error_msg})


@routes.delete("/model-manager/download/{task_id}")
async def delete_model_download_task(request):
"""
Delete download task.
"""
task_id = request.match_info.get("task_id", None)
try:
await services.delete_model_download_task(task_id)
return web.json_response({"success": True})
except Exception as e:
error_msg = f"Delete download task failed: {str(e)}"
logging.error(error_msg)
logging.debug(traceback.format_exc())
return web.json_response({"success": False, "error": error_msg})


@routes.get("/model-manager/base-folders")
Expand Down Expand Up @@ -56,7 +104,7 @@ async def create_model(request):
"""
post = await request.post()
try:
task_id = await services.create_model_download_task(post)
task_id = await services.create_model_download_task(post, request)
return web.json_response({"success": True, "data": {"taskId": task_id}})
except Exception as e:
error_msg = f"Create model download task failed: {str(e)}"
Expand Down
12 changes: 0 additions & 12 deletions py/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,3 @@

serverInstance = PromptServer.instance
routes = serverInstance.routes


class FakeRequest:
def __init__(self):
self.headers = {}


class CustomException(BaseException):
def __init__(self, type: str, message: str = None) -> None:
self.type = type
self.message = message
super().__init__(message)
68 changes: 31 additions & 37 deletions py/download.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
from dataclasses import dataclass
from . import config
from . import utils
from . import socket
from . import thread


Expand Down Expand Up @@ -93,33 +92,28 @@ def delete_task_status(task_id: str):
download_model_task_status.pop(task_id, None)


async def scan_model_download_task_list(sid: str):
async def scan_model_download_task_list():
"""
Scan the download directory and send the task list to the client.
"""
try:
download_dir = utils.get_download_path()
task_files = utils.search_files(download_dir)
task_files = folder_paths.filter_files_extensions(task_files, [".task"])
task_files = sorted(
task_files,
key=lambda x: os.stat(utils.join_path(download_dir, x)).st_ctime,
reverse=True,
)
task_list: list[dict] = []
for task_file in task_files:
task_id = task_file.replace(".task", "")
task_status = get_task_status(task_id)
task_list.append(task_status)
download_dir = utils.get_download_path()
task_files = utils.search_files(download_dir)
task_files = folder_paths.filter_files_extensions(task_files, [".task"])
task_files = sorted(
task_files,
key=lambda x: os.stat(utils.join_path(download_dir, x)).st_ctime,
reverse=True,
)
task_list: list[dict] = []
for task_file in task_files:
task_id = task_file.replace(".task", "")
task_status = get_task_status(task_id)
task_list.append(task_status)

await socket.send_json("downloadTaskList", task_list, sid)
except Exception as e:
error_msg = f"Refresh task list failed: {e}"
await socket.send_json("error", error_msg, sid)
logging.error(error_msg)
return utils.unpack_dataclass(task_list)


async def create_model_download_task(post: dict):
async def create_model_download_task(post: dict, request):
"""
Creates a download task for the given post.
"""
Expand Down Expand Up @@ -152,12 +146,12 @@ async def create_model_download_task(post: dict):
totalSize=float(post.get("sizeBytes", 0)),
)
download_model_task_status[task_id] = task_status
await socket.send_json("createDownloadTask", task_status)
await utils.send_json("create_download_task", task_status)
except Exception as e:
await delete_model_download_task(task_id)
raise RuntimeError(str(e)) from e

await download_model(task_id)
await download_model(task_id, request)
return task_id


Expand All @@ -170,7 +164,7 @@ async def delete_model_download_task(task_id: str):
task_status = get_task_status(task_id)
is_running = task_status.status == "doing"
task_status.status = "waiting"
await socket.send_json("deleteDownloadTask", task_id)
await utils.send_json("delete_download_task", task_id)

# Pause the task
if is_running:
Expand All @@ -185,13 +179,13 @@ async def delete_model_download_task(task_id: str):
delete_task_status(task_id)
os.remove(utils.join_path(download_dir, task_file))

await socket.send_json("deleteDownloadTask", task_id)
await utils.send_json("delete_download_task", task_id)


async def download_model(task_id: str):
async def download_model(task_id: str, request):
async def download_task(task_id: str):
async def report_progress(task_status: TaskStatus):
await socket.send_json("updateDownloadTask", task_status)
await utils.send_json("update_download_task", task_status)

try:
# When starting a task from the queue, the task may not exist
Expand All @@ -201,7 +195,7 @@ async def report_progress(task_status: TaskStatus):

# Update task status
task_status.status = "doing"
await socket.send_json("updateDownloadTask", task_status)
await utils.send_json("update_download_task", task_status)

try:

Expand All @@ -210,12 +204,12 @@ async def report_progress(task_status: TaskStatus):

download_platform = task_status.platform
if download_platform == "civitai":
api_key = utils.get_setting_value("api_key.civitai")
api_key = utils.get_setting_value(request, "api_key.civitai")
if api_key:
headers["Authorization"] = f"Bearer {api_key}"

elif download_platform == "huggingface":
api_key = utils.get_setting_value("api_key.huggingface")
api_key = utils.get_setting_value(request, "api_key.huggingface")
if api_key:
headers["Authorization"] = f"Bearer {api_key}"

Expand All @@ -229,7 +223,7 @@ async def report_progress(task_status: TaskStatus):
except Exception as e:
task_status.status = "pause"
task_status.error = str(e)
await socket.send_json("updateDownloadTask", task_status)
await utils.send_json("update_download_task", task_status)
task_status.error = None
logging.error(str(e))

Expand All @@ -238,11 +232,11 @@ async def report_progress(task_status: TaskStatus):
if status == "Waiting":
task_status = get_task_status(task_id)
task_status.status = "waiting"
await socket.send_json("updateDownloadTask", task_status)
await utils.send_json("update_download_task", task_status)
except Exception as e:
task_status.status = "pause"
task_status.error = str(e)
await socket.send_json("updateDownloadTask", task_status)
await utils.send_json("update_download_task", task_status)
task_status.error = None
logging.error(traceback.format_exc())

Expand Down Expand Up @@ -275,7 +269,7 @@ async def download_complete():
time.sleep(1)
task_file = utils.join_path(download_path, f"{task_id}.task")
os.remove(task_file)
await socket.send_json("completeDownloadTask", task_id)
await utils.send_json("complete_download_task", task_id)

async def update_progress():
nonlocal last_update_time
Expand Down Expand Up @@ -347,7 +341,7 @@ async def update_progress():
task_content.sizeBytes = total_size
task_status.totalSize = total_size
set_task_content(task_id, task_content)
await socket.send_json("updateDownloadTask", task_content)
await utils.send_json("update_download_task", task_content)

with open(download_tmp_file, "ab") as f:
for chunk in response.iter_content(chunk_size=8192):
Expand All @@ -366,4 +360,4 @@ async def update_progress():
await download_complete()
else:
task_status.status = "pause"
await socket.send_json("updateDownloadTask", task_status)
await utils.send_json("update_download_task", task_status)
46 changes: 19 additions & 27 deletions py/services.py
Original file line number Diff line number Diff line change
@@ -1,37 +1,13 @@
import os
import logging
import traceback

import folder_paths

from typing import Any
from multidict import MultiDictProxy
from . import config
from . import utils
from . import socket
from . import download


async def connect_websocket(request):
async def message_handler(event_type: str, detail: Any, sid: str):
try:
if event_type == "downloadTaskList":
await download.scan_model_download_task_list(sid=sid)

if event_type == "resumeDownloadTask":
await download.download_model(task_id=detail)

if event_type == "pauseDownloadTask":
await download.pause_model_download_task(task_id=detail)

if event_type == "deleteDownloadTask":
await download.delete_model_download_task(task_id=detail)
except Exception:
logging.error(traceback.format_exc())

ws = await socket.create_websocket_handler(request, handler=message_handler)
return ws


def scan_models():
result = []
model_base_paths = config.model_base_paths
Expand Down Expand Up @@ -135,6 +111,22 @@ def remove_model(model_path: str):
os.remove(utils.join_path(model_dirname, description))


async def create_model_download_task(post):
async def create_model_download_task(post, request):
dict_post = dict(post)
return await download.create_model_download_task(dict_post)
return await download.create_model_download_task(dict_post, request)


async def scan_model_download_task_list():
return await download.scan_model_download_task_list()


async def pause_model_download_task(task_id):
return await download.pause_model_download_task(task_id)


async def resume_model_download_task(task_id, request):
return await download.download_model(task_id, request)


async def delete_model_download_task(task_id):
return await download.delete_model_download_task(task_id)
63 changes: 0 additions & 63 deletions py/socket.py

This file was deleted.

Loading

0 comments on commit f2e1774

Please sign in to comment.