diff --git a/README.md b/README.md index ee81c9c..2d4e3c1 100644 --- a/README.md +++ b/README.md @@ -61,3 +61,10 @@ npm run build - Read, edit and save notes. (Saved as a `.md` file beside the model). - Change or remove a model's preview image. - View training tags and use the random tag generator to generate prompt ideas. (Inspired by the one in A1111.) + +### Scan Model Information + +Model Manager Demo Screenshot + +- Scan models and try to download information & preview. +- Support migration from `cdb-boop/ComfyUI-Model-Manager/main` diff --git a/__init__.py b/__init__.py index 31e6d87..c5a0770 100644 --- a/__init__.py +++ b/__init__.py @@ -3,9 +3,26 @@ from .py import config from .py import utils +extension_uri = utils.normalize_path(os.path.dirname(__file__)) + +requirements_path = utils.join_path(extension_uri, "requirements.txt") + +with open(requirements_path, "r", encoding="utf-8") as f: + requirements = f.readlines() + +requirements = [x.strip() for x in requirements] +requirements = [x for x in requirements if not x.startswith("#")] + +uninstalled_package = [p for p in requirements if not utils.is_installed(p)] + +if len(uninstalled_package) > 0: + utils.print_info(f"Install dependencies...") + for p in uninstalled_package: + utils.pip_install(p) + # Init config settings -config.extension_uri = utils.normalize_path(os.path.dirname(__file__)) +config.extension_uri = extension_uri utils.resolve_model_base_paths() version = utils.get_current_version() @@ -97,9 +114,8 @@ async def create_model(request): - downloadUrl: download url. - hash: a JSON string containing the hash value of the downloaded model. """ - post = await request.post() + task_data = await request.json() try: - task_data = dict(post) task_id = await services.create_model_download_task(task_data, request) return web.json_response({"success": True, "data": {"taskId": task_id}}) except Exception as e: @@ -158,13 +174,12 @@ async def update_model(request): index = int(request.match_info.get("index", None)) filename = request.match_info.get("filename", None) - post: dict = await request.post() + model_data: dict = await request.json() try: model_path = utils.get_valid_full_path(model_type, index, filename) if model_path is None: raise RuntimeError(f"File {filename} not found") - model_data = dict(post) services.update_model(model_path, model_data) return web.json_response({"success": True}) except Exception as e: @@ -194,6 +209,37 @@ async def delete_model(request): return web.json_response({"success": False, "error": error_msg}) +@routes.get("/model-manager/model-info") +async def fetch_model_info(request): + """ + Fetch model information from network with model page. + """ + try: + model_page = request.query.get("model-page", None) + result = services.fetch_model_info(model_page) + return web.json_response({"success": True, "data": result}) + except Exception as e: + error_msg = f"Fetch model info failed: {str(e)}" + utils.print_error(error_msg) + return web.json_response({"success": False, "error": error_msg}) + + +@routes.post("/model-manager/model-info/scan") +async def download_model_info(request): + """ + Create a task to download model information. + """ + post = await utils.get_request_body(request) + try: + scan_mode = post.get("scanMode", "diff") + await services.download_model_info(scan_mode) + return web.json_response({"success": True}) + except Exception as e: + error_msg = f"Download model info failed: {str(e)}" + utils.print_error(error_msg) + return web.json_response({"success": False, "error": error_msg}) + + @routes.get("/model-manager/preview/{type}/{index}/{filename:.*}") async def read_model_preview(request): """ @@ -236,6 +282,20 @@ async def read_download_preview(request): return web.FileResponse(preview_path) +@routes.post("/model-manager/migrate") +async def migrate_legacy_information(request): + """ + Migrate legacy information. + """ + try: + await services.migrate_legacy_information() + return web.json_response({"success": True}) + except Exception as e: + error_msg = f"Download model info failed: {str(e)}" + utils.print_error(error_msg) + return web.json_response({"success": False, "error": error_msg}) + + WEB_DIRECTORY = "web" NODE_CLASS_MAPPINGS = {} __all__ = ["WEB_DIRECTORY", "NODE_CLASS_MAPPINGS"] diff --git a/demo/scan-model-info.png b/demo/scan-model-info.png new file mode 100755 index 0000000..7a26a3e Binary files /dev/null and b/demo/scan-model-info.png differ diff --git a/package.json b/package.json index 3706b2e..50e0865 100644 --- a/package.json +++ b/package.json @@ -14,7 +14,6 @@ "@types/lodash": "^4.17.9", "@types/markdown-it": "^14.1.2", "@types/node": "^22.5.5", - "@types/turndown": "^5.0.5", "@vitejs/plugin-vue": "^5.1.4", "autoprefixer": "^10.4.20", "eslint": "^9.10.0", @@ -40,15 +39,13 @@ "markdown-it": "^14.1.0", "markdown-it-metadata-block": "^1.0.6", "primevue": "^4.0.7", - "turndown": "^7.2.0", "vue": "^3.4.31", "vue-i18n": "^9.13.1", "yaml": "^2.6.0" }, "lint-staged": { "./**/*.{js,ts,tsx,vue}": [ - "prettier --write", - "git add" + "prettier --write" ] } } diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml index 3fc0dc2..2b59fef 100644 --- a/pnpm-lock.yaml +++ b/pnpm-lock.yaml @@ -26,9 +26,6 @@ importers: primevue: specifier: ^4.0.7 version: 4.0.7(vue@3.5.6(typescript@5.6.2)) - turndown: - specifier: ^7.2.0 - version: 7.2.0 vue: specifier: ^3.4.31 version: 3.5.6(typescript@5.6.2) @@ -51,9 +48,6 @@ importers: '@types/node': specifier: ^22.5.5 version: 22.5.5 - '@types/turndown': - specifier: ^5.0.5 - version: 5.0.5 '@vitejs/plugin-vue': specifier: ^5.1.4 version: 5.1.4(vite@5.4.6(@types/node@22.5.5)(less@4.2.0))(vue@3.5.6(typescript@5.6.2)) @@ -349,9 +343,6 @@ packages: '@jridgewell/trace-mapping@0.3.25': resolution: {integrity: sha512-vNk6aEwybGtawWmy/PzwnGDOjCkLWSD2wqvjGGAgOAwCGWySYXfYoxt00IJkTF+8Lb57DwOb3Aa0o9CApepiYQ==} - '@mixmark-io/domino@2.2.0': - resolution: {integrity: sha512-Y28PR25bHXUg88kCV7nivXrP2Nj2RueZ3/l/jdx6J9f8J4nsEGcgX0Qe6lt7Pa+J79+kPiJU3LguR6O/6zrLOw==} - '@nodelib/fs.scandir@2.1.5': resolution: {integrity: sha512-vq24Bq3ym5HEQm2NKCr3yXDwjc7vTsEThRDnkp2DK9p1uqLR+DHurm/NOTo0KG7HYHU7eppKZj3MyqYuMBf62g==} engines: {node: '>= 8'} @@ -502,9 +493,6 @@ packages: '@types/node@22.5.5': resolution: {integrity: sha512-Xjs4y5UPO/CLdzpgR6GirZJx36yScjh73+2NlLlkFRSoQN8B0DpfXPdZGnvVmLRLOsqDpOfTNv7D9trgGhmOIA==} - '@types/turndown@5.0.5': - resolution: {integrity: sha512-TL2IgGgc7B5j78rIccBtlYAnkuv8nUQqhQc+DSYV5j9Be9XOcm/SKOVRuA47xAVI3680Tk9B1d8flK2GWT2+4w==} - '@typescript-eslint/eslint-plugin@8.13.0': resolution: {integrity: sha512-nQtBLiZYMUPkclSeC3id+x4uVd1SGtHuElTxL++SfP47jR0zfkZBJHc+gL4qPsgTuypz0k8Y2GheaDYn6Gy3rg==} engines: {node: ^18.18.0 || ^20.9.0 || >=21.1.0} @@ -1614,9 +1602,6 @@ packages: tslib@2.7.0: resolution: {integrity: sha512-gLXCKdN1/j47AiHiOkJN69hJmcbGTHI0ImLmbYLHykhgeN0jVGola9yVjFgzCUklsZQMW55o+dW7IXv3RCXDzA==} - turndown@7.2.0: - resolution: {integrity: sha512-eCZGBN4nNNqM9Owkv9HAtWRYfLA4h909E/WGAWWBpmB275ehNhZyk87/Tpvjbp0jjNl9XwCsbe6bm6CqFsgD+A==} - type-check@0.4.0: resolution: {integrity: sha512-XleUoc9uwGXqjWwXaUTZAmzMcFZ5858QA2vvx1Ur5xIcixXIP+8LnFDgRplU30us6teqdlskFfu+ae4K79Ooew==} engines: {node: '>= 0.8.0'} @@ -1929,8 +1914,6 @@ snapshots: '@jridgewell/resolve-uri': 3.1.2 '@jridgewell/sourcemap-codec': 1.5.0 - '@mixmark-io/domino@2.2.0': {} - '@nodelib/fs.scandir@2.1.5': dependencies: '@nodelib/fs.stat': 2.0.5 @@ -2038,8 +2021,6 @@ snapshots: dependencies: undici-types: 6.19.8 - '@types/turndown@5.0.5': {} - '@typescript-eslint/eslint-plugin@8.13.0(@typescript-eslint/parser@8.13.0(eslint@9.10.0(jiti@1.21.6))(typescript@5.6.2))(eslint@9.10.0(jiti@1.21.6))(typescript@5.6.2)': dependencies: '@eslint-community/regexpp': 4.12.1 @@ -3164,10 +3145,6 @@ snapshots: tslib@2.7.0: {} - turndown@7.2.0: - dependencies: - '@mixmark-io/domino': 2.2.0 - type-check@0.4.0: dependencies: prelude-ls: 1.2.1 diff --git a/py/download.py b/py/download.py index 0fee52b..84f61d2 100644 --- a/py/download.py +++ b/py/download.py @@ -24,6 +24,34 @@ class TaskStatus: bps: float = 0 error: Optional[str] = None + def __init__(self, **kwargs): + self.taskId = kwargs.get("taskId", None) + self.type = kwargs.get("type", None) + self.fullname = kwargs.get("fullname", None) + self.preview = kwargs.get("preview", None) + self.status = kwargs.get("status", "pause") + self.platform = kwargs.get("platform", None) + self.downloadedSize = kwargs.get("downloadedSize", 0) + self.totalSize = kwargs.get("totalSize", 0) + self.progress = kwargs.get("progress", 0) + self.bps = kwargs.get("bps", 0) + self.error = kwargs.get("error", None) + + def to_dict(self): + return { + "taskId": self.taskId, + "type": self.type, + "fullname": self.fullname, + "preview": self.preview, + "status": self.status, + "platform": self.platform, + "downloadedSize": self.downloadedSize, + "totalSize": self.totalSize, + "progress": self.progress, + "bps": self.bps, + "error": self.error, + } + @dataclass class TaskContent: @@ -33,9 +61,31 @@ class TaskContent: description: str downloadPlatform: str downloadUrl: str - sizeBytes: float + sizeBytes: int hashes: Optional[dict[str, str]] = None + def __init__(self, **kwargs): + self.type = kwargs.get("type", None) + self.pathIndex = int(kwargs.get("pathIndex", 0)) + self.fullname = kwargs.get("fullname", None) + self.description = kwargs.get("description", None) + self.downloadPlatform = kwargs.get("downloadPlatform", None) + self.downloadUrl = kwargs.get("downloadUrl", None) + self.sizeBytes = int(kwargs.get("sizeBytes", 0)) + self.hashes = kwargs.get("hashes", None) + + def to_dict(self): + return { + "type": self.type, + "pathIndex": self.pathIndex, + "fullname": self.fullname, + "description": self.description, + "downloadPlatform": self.downloadPlatform, + "downloadUrl": self.downloadUrl, + "sizeBytes": self.sizeBytes, + "hashes": self.hashes, + } + download_model_task_status: dict[str, TaskStatus] = {} download_thread_pool = thread.DownloadThreadPool() @@ -44,7 +94,7 @@ class TaskContent: def set_task_content(task_id: str, task_content: Union[TaskContent, dict]): download_path = utils.get_download_path() task_file_path = utils.join_path(download_path, f"{task_id}.task") - utils.save_dict_pickle_file(task_file_path, utils.unpack_dataclass(task_content)) + utils.save_dict_pickle_file(task_file_path, task_content) def get_task_content(task_id: str): @@ -53,8 +103,6 @@ def get_task_content(task_id: str): if not os.path.isfile(task_file): raise RuntimeError(f"Task {task_id} not found") task_content = utils.load_dict_pickle_file(task_file) - task_content["pathIndex"] = int(task_content.get("pathIndex", 0)) - task_content["sizeBytes"] = float(task_content.get("sizeBytes", 0)) return TaskContent(**task_content) @@ -106,14 +154,14 @@ async def scan_model_download_task_list(): for task_file in task_files: task_id = task_file.replace(".task", "") task_status = get_task_status(task_id) - task_list.append(task_status) + task_list.append(task_status.to_dict()) - return utils.unpack_dataclass(task_list) + return task_list async def create_model_download_task(task_data: dict, request): """ - Creates a download task for the given post. + Creates a download task for the given data. """ model_type = task_data.get("type", None) path_index = int(task_data.get("pathIndex", None)) @@ -132,8 +180,8 @@ async def create_model_download_task(task_data: dict, request): raise RuntimeError(f"Task {task_id} already exists") try: - previewFile = task_data.pop("previewFile", None) - utils.save_model_preview_image(task_path, previewFile) + preview_url = task_data.pop("preview", None) + utils.save_model_preview_image(task_path, preview_url) set_task_content(task_id, task_data) task_status = TaskStatus( taskId=task_id, @@ -144,7 +192,7 @@ async def create_model_download_task(task_data: dict, request): totalSize=float(task_data.get("sizeBytes", 0)), ) download_model_task_status[task_id] = task_status - await utils.send_json("create_download_task", task_status) + await utils.send_json("create_download_task", task_status.to_dict()) except Exception as e: await delete_model_download_task(task_id) raise RuntimeError(str(e)) from e @@ -183,7 +231,7 @@ async def delete_model_download_task(task_id: str): async def download_model(task_id: str, request): async def download_task(task_id: str): async def report_progress(task_status: TaskStatus): - await utils.send_json("update_download_task", task_status) + await utils.send_json("update_download_task", task_status.to_dict()) try: # When starting a task from the queue, the task may not exist @@ -193,7 +241,7 @@ async def report_progress(task_status: TaskStatus): # Update task status task_status.status = "doing" - await utils.send_json("update_download_task", task_status) + await utils.send_json("update_download_task", task_status.to_dict()) try: @@ -221,7 +269,7 @@ async def report_progress(task_status: TaskStatus): except Exception as e: task_status.status = "pause" task_status.error = str(e) - await utils.send_json("update_download_task", task_status) + await utils.send_json("update_download_task", task_status.to_dict()) task_status.error = None utils.print_error(str(e)) @@ -230,11 +278,11 @@ async def report_progress(task_status: TaskStatus): if status == "Waiting": task_status = get_task_status(task_id) task_status.status = "waiting" - await utils.send_json("update_download_task", task_status) + await utils.send_json("update_download_task", task_status.to_dict()) except Exception as e: task_status.status = "pause" task_status.error = str(e) - await utils.send_json("update_download_task", task_status) + await utils.send_json("update_download_task", task_status.to_dict()) task_status.error = None utils.print_error(str(e)) @@ -339,7 +387,7 @@ async def update_progress(): task_content.sizeBytes = total_size task_status.totalSize = total_size set_task_content(task_id, task_content) - await utils.send_json("update_download_task", task_content) + await utils.send_json("update_download_task", task_content.to_dict()) with open(download_tmp_file, "ab") as f: for chunk in response.iter_content(chunk_size=8192): @@ -358,4 +406,4 @@ async def update_progress(): await download_complete() else: task_status.status = "pause" - await utils.send_json("update_download_task", task_status) + await utils.send_json("update_download_task", task_status.to_dict()) diff --git a/py/searcher.py b/py/searcher.py new file mode 100644 index 0000000..a2a1926 --- /dev/null +++ b/py/searcher.py @@ -0,0 +1,317 @@ +import os +import re +import yaml +import requests +import markdownify + + +from abc import ABC, abstractmethod +from urllib.parse import urlparse, parse_qs + +from . import utils + + +class ModelSearcher(ABC): + """ + Abstract class for model searcher. + """ + + @abstractmethod + def search_by_url(self, url: str) -> list[dict]: + pass + + @abstractmethod + def search_by_hash(self, hash: str) -> dict: + pass + + +class UnknownWebsiteSearcher(ModelSearcher): + def search_by_url(self, url: str): + raise RuntimeError( + f"Unknown Website, please input a URL from huggingface.co or civitai.com." + ) + + def search_by_hash(self, hash: str): + raise RuntimeError(f"Unknown Website, unable to search with hash value.") + + +class CivitaiModelSearcher(ModelSearcher): + def search_by_url(self, url: str): + parsed_url = urlparse(url) + + pathname = parsed_url.path + match = re.match(r"^/models/(\d*)", pathname) + model_id = match.group(1) if match else None + + query_params = parse_qs(parsed_url.query) + version_id = query_params.get("modelVersionId", [None])[0] + + if not model_id: + return [] + + response = requests.get(f"https://civitai.com/api/v1/models/{model_id}") + response.raise_for_status() + res_data: dict = response.json() + + model_versions: list[dict] = res_data["modelVersions"] + if version_id: + model_versions = utils.filter_with(model_versions, {"id": int(version_id)}) + + models: list[dict] = [] + + for version in model_versions: + model_files: list[dict] = version.get("files", []) + model_files = utils.filter_with(model_files, {"type": "Model"}) + + shortname = version.get("name", None) if len(model_files) > 0 else None + + for file in model_files: + fullname = file.get("name", None) + extension = os.path.splitext(fullname)[1] + basename = os.path.splitext(fullname)[0] + + metadata_info = { + "website": "Civitai", + "modelPage": f"https://civitai.com/models/{model_id}?modelVersionId={version.get('id')}", + "author": res_data.get("creator", {}).get("username", None), + "baseModel": version.get("baseModel"), + "hashes": file.get("hashes"), + "metadata": file.get("metadata"), + "preview": [i["url"] for i in version["images"]], + } + + description_parts: list[str] = [] + description_parts.append("---") + description_parts.append(yaml.dump(metadata_info).strip()) + description_parts.append("---") + description_parts.append("") + description_parts.append(f"# Trigger Words") + description_parts.append("") + description_parts.append( + ", ".join(version.get("trainedWords", ["No trigger words"])) + ) + description_parts.append("") + description_parts.append(f"# About this version") + description_parts.append("") + description_parts.append( + markdownify.markdownify( + version.get( + "description", "

No description about this version

" + ) + ).strip() + ) + description_parts.append("") + description_parts.append(f"# {res_data.get('name')}") + description_parts.append("") + description_parts.append( + markdownify.markdownify( + res_data.get( + "description", "

No description about this model

" + ) + ).strip() + ) + description_parts.append("") + + model = { + "id": file.get("id"), + "shortname": shortname or basename, + "fullname": fullname, + "basename": basename, + "extension": extension, + "preview": metadata_info.get("preview"), + "sizeBytes": file.get("sizeKB", 0) * 1024, + "type": self._resolve_model_type(res_data.get("type", "unknown")), + "pathIndex": 0, + "description": "\n".join(description_parts), + "metadata": file.get("metadata"), + "downloadPlatform": "civitai", + "downloadUrl": file.get("downloadUrl"), + "hashes": file.get("hashes"), + } + models.append(model) + + return models + + def search_by_hash(self, hash: str): + if not hash: + raise RuntimeError(f"Hash value is empty.") + + response = requests.get( + f"https://civitai.com/api/v1/model-versions/by-hash/{hash}" + ) + response.raise_for_status() + version: dict = response.json() + + model_id = version.get("modelId") + version_id = version.get("id") + + model_page = ( + f"https://civitai.com/models/{model_id}?modelVersionId={version_id}" + ) + + models = self.search_by_url(model_page) + + for model in models: + sha256 = model.get("hashes", {}).get("SHA256") + if sha256 == hash: + return model + + return models[0] + + def _resolve_model_type(self, model_type: str): + map_legacy = { + "TextualInversion": "embeddings", + "LoCon": "loras", + "DoRA": "loras", + "Controlnet": "controlnet", + "Upscaler": "upscale_models", + "VAE": "vae", + "unknown": "unknown", + } + return map_legacy.get(model_type, f"{model_type.lower()}s") + + +class HuggingfaceModelSearcher(ModelSearcher): + def search_by_url(self, url: str): + parsed_url = urlparse(url) + + pathname = parsed_url.path + + space, name, *rest_paths = pathname.strip("/").split("/") + + model_id = f"{space}/{name}" + rest_pathname = "/".join(rest_paths) + + response = requests.get(f"https://huggingface.co/api/models/{model_id}") + response.raise_for_status() + res_data: dict = response.json() + + sibling_files: list[str] = [ + x.get("rfilename") for x in res_data.get("siblings", []) + ] + + model_files = utils.filter_with( + utils.filter_with(sibling_files, self._match_model_files()), + self._match_tree_files(rest_pathname), + ) + + image_files = utils.filter_with( + utils.filter_with(sibling_files, self._match_image_files()), + self._match_tree_files(rest_pathname), + ) + image_files = [ + f"https://huggingface.co/{model_id}/resolve/main/{filename}" + for filename in image_files + ] + + models: list[dict] = [] + + for filename in model_files: + fullname = os.path.basename(filename) + extension = os.path.splitext(fullname)[1] + basename = os.path.splitext(fullname)[0] + + description_parts: list[str] = [] + + metadata_info = { + "website": "HuggingFace", + "modelPage": f"https://huggingface.co/{model_id}", + "author": res_data.get("author", None), + "preview": image_files, + } + + description_parts: list[str] = [] + description_parts.append("---") + description_parts.append(yaml.dump(metadata_info).strip()) + description_parts.append("---") + description_parts.append("") + description_parts.append(f"# Trigger Words") + description_parts.append("") + description_parts.append("No trigger words") + description_parts.append("") + description_parts.append(f"# About this version") + description_parts.append("") + description_parts.append("No description about this version") + description_parts.append("") + description_parts.append(f"# {res_data.get('name')}") + description_parts.append("") + description_parts.append("No description about this model") + description_parts.append("") + + model = { + "id": filename, + "shortname": filename, + "fullname": fullname, + "basename": basename, + "extension": extension, + "preview": image_files, + "sizeBytes": 0, + "type": "unknown", + "pathIndex": 0, + "description": "\n".join(description_parts), + "metadata": {}, + "downloadPlatform": "", + "downloadUrl": f"https://huggingface.co/{model_id}/resolve/main/{filename}?download=true", + } + models.append(model) + + return models + + def search_by_hash(self, hash: str): + raise RuntimeError("Hash search is not supported by Huggingface.") + + def _match_model_files(self): + extension = [ + ".bin", + ".ckpt", + ".gguf", + ".onnx", + ".pt", + ".pth", + ".safetensors", + ] + + def _filter_model_files(file: str): + return any(file.endswith(ext) for ext in extension) + + return _filter_model_files + + def _match_image_files(self): + extension = [ + ".png", + ".webp", + ".jpeg", + ".jpg", + ".jfif", + ".gif", + ".apng", + ] + + def _filter_image_files(file: str): + return any(file.endswith(ext) for ext in extension) + + return _filter_image_files + + def _match_tree_files(self, pathname: str): + target, *paths = pathname.split("/") + + def _filter_tree_files(file: str): + if not target: + return True + if target != "tree" and target != "blob": + return True + + prefix_path = "/".join(paths) + return file.startswith(prefix_path) + + return _filter_tree_files + + +def get_model_searcher_by_url(url: str) -> ModelSearcher: + parsed_url = urlparse(url) + host_name = parsed_url.hostname + if host_name == "civitai.com": + return CivitaiModelSearcher() + elif host_name == "huggingface.co": + return HuggingfaceModelSearcher() + return UnknownWebsiteSearcher() diff --git a/py/services.py b/py/services.py index 9a6c7e5..6729987 100644 --- a/py/services.py +++ b/py/services.py @@ -5,6 +5,7 @@ from . import config from . import utils from . import download +from . import searcher def scan_models(): @@ -128,3 +129,180 @@ async def resume_model_download_task(task_id, request): async def delete_model_download_task(task_id): return await download.delete_model_download_task(task_id) + + +def fetch_model_info(model_page: str): + if not model_page: + return [] + + model_searcher = searcher.get_model_searcher_by_url(model_page) + result = model_searcher.search_by_url(model_page) + return result + + +async def download_model_info(scan_mode: str): + utils.print_info(f"Download model info for {scan_mode}") + model_base_paths = config.model_base_paths + for model_type in model_base_paths: + + folders, extensions = folder_paths.folder_names_and_paths[model_type] + for path_index, base_path in enumerate(folders): + files = utils.recursive_search_files(base_path) + + models = folder_paths.filter_files_extensions(files, extensions) + images = folder_paths.filter_files_content_types(files, ["image"]) + image_dict = utils.file_list_to_name_dict(images) + descriptions = folder_paths.filter_files_extensions(files, [".md"]) + description_dict = utils.file_list_to_name_dict(descriptions) + + for fullname in models: + fullname = utils.normalize_path(fullname) + basename = os.path.splitext(fullname)[0] + + abs_model_path = utils.join_path(base_path, fullname) + + image_name = image_dict.get(basename, "no-preview.png") + abs_image_path = utils.join_path(base_path, image_name) + + has_preview = os.path.isfile(abs_image_path) + + description_name = description_dict.get(basename, None) + abs_description_path = ( + utils.join_path(base_path, description_name) + if description_name + else None + ) + has_description = ( + os.path.isfile(abs_description_path) + if abs_description_path + else False + ) + + try: + + utils.print_info(f"Checking model {abs_model_path}") + utils.print_debug(f"Scan mode: {scan_mode}") + utils.print_debug(f"Has preview: {has_preview}") + utils.print_debug(f"Has description: {has_description}") + + if scan_mode != "full" and (has_preview and has_description): + continue + + utils.print_debug(f"Calculate sha256 for {abs_model_path}") + hash_value = utils.calculate_sha256(abs_model_path) + utils.print_info(f"Searching model info by hash {hash_value}") + model_info = searcher.CivitaiModelSearcher().search_by_hash( + hash_value + ) + + preview_url_list = model_info.get("preview", []) + preview_image_url = ( + preview_url_list[0] if preview_url_list else None + ) + if preview_image_url: + utils.print_debug(f"Save preview image to {abs_image_path}") + utils.save_model_preview_image( + abs_model_path, preview_image_url + ) + + description = model_info.get("description", None) + if description: + utils.save_model_description(abs_model_path, description) + except Exception as e: + utils.print_error( + f"Failed to download model info for {abs_model_path}: {e}" + ) + + utils.print_debug("Completed scan model information.") + + +async def migrate_legacy_information(): + import json + import yaml + from PIL import Image + + utils.print_info(f"Migrating legacy information...") + + model_base_paths = config.model_base_paths + for model_type in model_base_paths: + + folders, extensions = folder_paths.folder_names_and_paths[model_type] + for path_index, base_path in enumerate(folders): + files = utils.recursive_search_files(base_path) + + models = folder_paths.filter_files_extensions(files, extensions) + + for fullname in models: + fullname = utils.normalize_path(fullname) + + abs_model_path = utils.join_path(base_path, fullname) + + base_file_name = os.path.splitext(abs_model_path)[0] + + utils.print_debug(f"Try to migrate legacy info for {abs_model_path}") + + preview_path = utils.join_path( + os.path.dirname(abs_model_path), + utils.get_model_preview_name(abs_model_path), + ) + new_preview_path = f"{base_file_name}.webp" + + if os.path.isfile(preview_path) and preview_path != new_preview_path: + utils.print_info(f"Migrate preview image from {fullname}") + with Image.open(preview_path) as image: + image.save(new_preview_path, format="WEBP") + os.remove(preview_path) + + description_path = f"{base_file_name}.md" + + metadata_info = { + "website": "Civitai", + } + + url_info_path = f"{base_file_name}.url" + if os.path.isfile(url_info_path): + with open(url_info_path, "r", encoding="utf-8") as f: + for line in f: + if line.startswith("URL="): + model_page_url = line[len("URL=") :].strip() + metadata_info.update({"modelPage": model_page_url}) + + json_info_path = f"{base_file_name}.json" + if os.path.isfile(json_info_path): + with open(json_info_path, "r", encoding="utf-8") as f: + version = json.load(f) + metadata_info.update( + { + "baseModel": version.get("baseModel"), + "preview": [i["url"] for i in version["images"]], + } + ) + + description_parts: list[str] = [ + "---", + yaml.dump(metadata_info).strip(), + "---", + "", + ] + + text_info_path = f"{base_file_name}.txt" + if os.path.isfile(text_info_path): + with open(text_info_path, "r", encoding="utf-8") as f: + description_parts.append(f.read()) + + description_path = f"{base_file_name}.md" + + if os.path.isfile(text_info_path): + utils.print_info(f"Migrate description from {fullname}") + with open(description_path, "w", encoding="utf-8", newline="") as f: + f.write("\n".join(description_parts)) + + def try_to_remove_file(file_path): + if os.path.isfile(file_path): + os.remove(file_path) + + try_to_remove_file(url_info_path) + try_to_remove_file(text_info_path) + try_to_remove_file(json_info_path) + + utils.print_debug("Completed migrate model information.") diff --git a/py/utils.py b/py/utils.py index a3efa54..5e53db0 100644 --- a/py/utils.py +++ b/py/utils.py @@ -29,6 +29,27 @@ def print_debug(msg, *args, **kwargs): logging.debug(f"[{config.extension_tag}] {msg}", *args, **kwargs) +def _matches(predicate: dict): + def _filter(obj: dict): + return all(obj.get(key, None) == value for key, value in predicate.items()) + + return _filter + + +def filter_with(list: list, predicate): + if isinstance(predicate, dict): + predicate = _matches(predicate) + + return [item for item in list if predicate(item)] + + +async def get_request_body(request) -> dict: + try: + return await request.json() + except: + return {} + + def normalize_path(path: str): normpath = os.path.normpath(path) return normpath.replace(os.path.sep, "/") @@ -202,41 +223,22 @@ def get_model_preview_name(model_path: str): return images[0] if len(images) > 0 else "no-preview.png" -def save_model_preview_image(model_path: str, image_file: Any): - if not isinstance(image_file, web.FileField): - raise RuntimeError("Invalid image file") - - content_type: str = image_file.content_type - if not content_type.startswith("image/"): - raise RuntimeError(f"FileTypeError: expected image, got {content_type}") - - base_dirname = os.path.dirname(model_path) +from PIL import Image +from io import BytesIO - # remove old preview images - old_preview_images = get_model_all_images(model_path) - a1111_civitai_helper_image = False - for image in old_preview_images: - if os.path.splitext(image)[1].endswith(".preview"): - a1111_civitai_helper_image = True - image_path = join_path(base_dirname, image) - os.remove(image_path) - # save new preview image - basename = os.path.splitext(os.path.basename(model_path))[0] - extension = f".{content_type.split('/')[1]}" - new_preview_path = join_path(base_dirname, f"{basename}{extension}") +def save_model_preview_image(model_path: str, image_url: str): + try: + image_response = requests.get(image_url) + image_response.raise_for_status() - with open(new_preview_path, "wb") as f: - f.write(image_file.file.read()) + basename = os.path.splitext(model_path)[0] + preview_path = f"{basename}.webp" + image = Image.open(BytesIO(image_response.content)) + image.save(preview_path, "WEBP") - # TODO Is it possible to abandon the current rules and adopt the rules of a1111 civitai_helper? - if a1111_civitai_helper_image: - """ - Keep preview image of a1111_civitai_helper - """ - new_preview_path = join_path(base_dirname, f"{basename}.preview{extension}") - with open(new_preview_path, "wb") as f: - f.write(image_file.file.read()) + except Exception as e: + print_error(f"Failed to download image: {e}") def get_model_all_descriptions(model_path: str): @@ -361,20 +363,43 @@ def get_setting_value(request: web.Request, key: str, default: Any = None) -> An return settings.get(setting_id, default) -from dataclasses import asdict, is_dataclass +async def send_json(event: str, data: Any, sid: str = None): + await config.serverInstance.send_json(event, data, sid) -def unpack_dataclass(data: Any): - if isinstance(data, dict): - return {key: unpack_dataclass(value) for key, value in data.items()} - elif isinstance(data, list): - return [unpack_dataclass(x) for x in data] - elif is_dataclass(data): - return asdict(data) - else: - return data +import sys +import subprocess +import importlib.util +import importlib.metadata -async def send_json(event: str, data: Any, sid: str = None): - detail = unpack_dataclass(data) - await config.serverInstance.send_json(event, detail, sid) +def is_installed(package_name: str): + try: + dist = importlib.metadata.distribution(package_name) + except importlib.metadata.PackageNotFoundError: + try: + spec = importlib.util.find_spec(package_name) + except ModuleNotFoundError: + return False + + return spec is not None + + return dist is not None + + +def pip_install(package_name: str): + subprocess.run([sys.executable, "-m", "pip", "install", package_name], check=True) + + +import hashlib + + +def calculate_sha256(path, buffer_size=1024 * 1024): + sha256 = hashlib.sha256() + with open(path, "rb") as f: + while True: + data = f.read(buffer_size) + if not data: + break + sha256.update(data) + return sha256.hexdigest() diff --git a/pyproject.toml b/pyproject.toml index 95fc4b0..3a398ef 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -3,6 +3,7 @@ name = "comfyui-model-manager" description = "Manage models: browsing, download and delete." version = "2.0.3" license = "LICENSE" +dependencies = ["markdownify"] [project.urls] Repository = "https://github.com/hayden-fr/ComfyUI-Model-Manager" diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..06a83f1 --- /dev/null +++ b/requirements.txt @@ -0,0 +1 @@ +markdownify \ No newline at end of file diff --git a/src/components/DialogCreateTask.vue b/src/components/DialogCreateTask.vue index e7dcef7..089449d 100644 --- a/src/components/DialogCreateTask.vue +++ b/src/components/DialogCreateTask.vue @@ -70,7 +70,6 @@ import { request } from 'hooks/request' import { useToast } from 'hooks/toast' import Button from 'primevue/button' import { VersionModel } from 'types/typings' -import { previewUrlToFile } from 'utils/common' import { ref } from 'vue' const { isMobile } = useConfig() @@ -89,38 +88,11 @@ const searchModelsByUrl = async () => { } const createDownTask = async (data: VersionModel) => { - const formData = new FormData() - loading.show() - // set base info - formData.append('type', data.type) - formData.append('pathIndex', data.pathIndex.toString()) - formData.append('fullname', data.fullname) - // set preview - const previewFile = await previewUrlToFile(data.preview as string).catch( - () => { - loading.hide() - toast.add({ - severity: 'error', - summary: 'Error', - detail: 'Failed to download preview', - life: 15000, - }) - throw new Error('Failed to download preview') - }, - ) - formData.append('previewFile', previewFile) - // set description - formData.append('description', data.description) - // set model download info - formData.append('downloadPlatform', data.downloadPlatform) - formData.append('downloadUrl', data.downloadUrl) - formData.append('sizeBytes', data.sizeBytes.toString()) - formData.append('hashes', JSON.stringify(data.hashes)) await request('/model', { method: 'POST', - body: formData, + body: JSON.stringify(data), }) .then(() => { dialog.close({ key: 'model-manager-create-task' }) diff --git a/src/hooks/config.ts b/src/hooks/config.ts index 358b148..0674d3d 100644 --- a/src/hooks/config.ts +++ b/src/hooks/config.ts @@ -1,9 +1,10 @@ -import { useRequest } from 'hooks/request' +import { request, useRequest } from 'hooks/request' import { defineStore } from 'hooks/store' -import { app } from 'scripts/comfyAPI' +import { $el, app, ComfyDialog } from 'scripts/comfyAPI' import { onMounted, onUnmounted, ref } from 'vue' +import { useToast } from './toast' -export const useConfig = defineStore('config', () => { +export const useConfig = defineStore('config', (store) => { const mobileDeviceBreakPoint = 759 const isMobile = ref(window.innerWidth < mobileDeviceBreakPoint) @@ -36,7 +37,7 @@ export const useConfig = defineStore('config', () => { refresh, } - useAddConfigSettings() + useAddConfigSettings(store) return config }) @@ -49,7 +50,41 @@ declare module 'hooks/store' { } } -function useAddConfigSettings() { +function useAddConfigSettings(store: import('hooks/store').StoreProvider) { + const { toast } = useToast() + + const confirm = (opts: { + message?: string + accept?: () => void + reject?: () => void + }) => { + const dialog = new ComfyDialog('div', []) + + dialog.show( + $el('div', [ + $el('p', { textContent: opts.message }), + $el('div.flex.gap-4', [ + $el('button.flex-1', { + textContent: 'Cancel', + onclick: () => { + opts.reject?.() + dialog.close() + document.body.removeChild(dialog.element) + }, + }), + $el('button.flex-1', { + textContent: 'Continue', + onclick: () => { + opts.accept?.() + dialog.close() + document.body.removeChild(dialog.element) + }, + }), + ]), + ]), + ) + } + onMounted(() => { // API keys app.ui?.settings.addSetting({ @@ -65,5 +100,144 @@ function useAddConfigSettings() { type: 'text', defaultValue: undefined, }) + + // Migrate + app.ui?.settings.addSetting({ + id: 'ModelManager.Migrate.Migrate', + name: 'Migrate information from cdb-boop/main', + defaultValue: '', + type: () => { + return $el('button.p-button.p-component.p-button-secondary', { + textContent: 'Migrate', + onclick: () => { + confirm({ + message: [ + 'This operation will delete old files and override current files if it exists.', + // 'This may take a while and generate MANY server requests!', + 'Continue?', + ].join('\n'), + accept: () => { + store.loading.loading.value = true + request('/migrate', { + method: 'POST', + }) + .then(() => { + toast.add({ + severity: 'success', + summary: 'Complete migration', + life: 2000, + }) + store.models.refresh() + }) + .catch((err) => { + toast.add({ + severity: 'error', + summary: 'Error', + detail: err.message ?? 'Failed to migrate information', + life: 15000, + }) + }) + .finally(() => { + store.loading.loading.value = false + }) + }, + }) + }, + }) + }, + }) + + // Scan information + app.ui?.settings.addSetting({ + id: 'ModelManager.ScanFiles.Full', + name: "Override all models' information and preview", + defaultValue: '', + type: () => { + return $el('button.p-button.p-component.p-button-secondary', { + textContent: 'Full Scan', + onclick: () => { + confirm({ + message: [ + 'This operation will override current files.', + 'This may take a while and generate MANY server requests!', + 'USE AT YOUR OWN RISK! Continue?', + ].join('\n'), + accept: () => { + store.loading.loading.value = true + request('/model-info/scan', { + method: 'POST', + body: JSON.stringify({ scanMode: 'full' }), + }) + .then(() => { + toast.add({ + severity: 'success', + summary: 'Complete download information', + life: 2000, + }) + store.models.refresh() + }) + .catch((err) => { + toast.add({ + severity: 'error', + summary: 'Error', + detail: err.message ?? 'Failed to download information', + life: 15000, + }) + }) + .finally(() => { + store.loading.loading.value = false + }) + }, + }) + }, + }) + }, + }) + + app.ui?.settings.addSetting({ + id: 'ModelManager.ScanFiles.Incremental', + name: 'Download missing information or preview', + defaultValue: '', + type: () => { + return $el('button.p-button.p-component.p-button-secondary', { + textContent: 'Diff Scan', + onclick: () => { + confirm({ + message: [ + 'Download missing information or preview.', + 'This may take a while and generate MANY server requests!', + 'USE AT YOUR OWN RISK! Continue?', + ].join('\n'), + accept: () => { + store.loading.loading.value = true + request('/model-info/scan', { + method: 'POST', + body: JSON.stringify({ scanMode: 'diff' }), + }) + .then(() => { + toast.add({ + severity: 'success', + summary: 'Complete download information', + life: 2000, + }) + store.models.refresh() + }) + .catch((err) => { + toast.add({ + severity: 'error', + summary: 'Error', + detail: err.message ?? 'Failed to download information', + life: 15000, + }) + }) + .finally(() => { + store.loading.loading.value = false + }) + }, + }) + }, + }) + }, + }) }) } diff --git a/src/hooks/download.ts b/src/hooks/download.ts index b4d3c1a..9945540 100644 --- a/src/hooks/download.ts +++ b/src/hooks/download.ts @@ -1,5 +1,4 @@ import { useLoading } from 'hooks/loading' -import { MarkdownTool, useMarkdown } from 'hooks/markdown' import { request } from 'hooks/request' import { defineStore } from 'hooks/store' import { useToast } from 'hooks/toast' @@ -157,253 +156,8 @@ declare module 'hooks/store' { } } -abstract class ModelSearch { - constructor(readonly md: MarkdownTool) {} - - abstract search(pathname: string): Promise -} - -class Civitai extends ModelSearch { - async search(searchUrl: string): Promise { - const { pathname, searchParams } = new URL(searchUrl) - - const [, modelId] = pathname.match(/^\/models\/(\d*)/) ?? [] - const versionId = searchParams.get('modelVersionId') - - if (!modelId) { - return Promise.resolve([]) - } - - return fetch(`https://civitai.com/api/v1/models/${modelId}`) - .then((response) => response.json()) - .then((resData) => { - const modelVersions: any[] = resData.modelVersions.filter( - (version: any) => { - if (versionId) { - return version.id == versionId - } - return true - }, - ) - - const models: VersionModel[] = [] - - for (const version of modelVersions) { - const modelFiles: any[] = version.files.filter( - (file: any) => file.type === 'Model', - ) - - const shortname = modelFiles.length > 0 ? version.name : undefined - - for (const file of modelFiles) { - const fullname = file.name - const extension = `.${fullname.split('.').pop()}` - const basename = fullname.replace(extension, '') - - models.push({ - id: file.id, - shortname: shortname ?? basename, - fullname: fullname, - basename: basename, - extension: extension, - preview: version.images.map((i: any) => i.url), - sizeBytes: file.sizeKB * 1024, - type: this.resolveType(resData.type), - pathIndex: 0, - description: [ - '---', - ...[ - `website: Civitai`, - `modelPage: https://civitai.com/models/${modelId}?modelVersionId=${version.id}`, - `author: ${resData.creator?.username}`, - version.baseModel && `baseModel: ${version.baseModel}`, - file.hashes && `hashes:`, - ...Object.entries(file.hashes ?? {}).map( - ([key, value]) => ` ${key}: ${value}`, - ), - file.metadata && `metadata:`, - ...Object.entries(file.metadata ?? {}).map( - ([key, value]) => ` ${key}: ${value}`, - ), - ].filter(Boolean), - '---', - '', - '# Trigger Words', - `\n${(version.trainedWords ?? ['No trigger words']).join(', ')}\n`, - '# About this version', - this.resolveDescription( - version.description, - '\nNo description about this version\n', - ), - `# ${resData.name}`, - this.resolveDescription( - resData.description, - 'No description about this model', - ), - ].join('\n'), - metadata: file.metadata, - downloadPlatform: 'civitai', - downloadUrl: file.downloadUrl, - hashes: file.hashes, - }) - } - } - - return models - }) - } - - private resolveType(type: string) { - const mapLegacy = { - TextualInversion: 'embeddings', - LoCon: 'loras', - DoRA: 'loras', - Controlnet: 'controlnet', - Upscaler: 'upscale_models', - VAE: 'vae', - } - return mapLegacy[type] ?? `${type.toLowerCase()}s` - } - - private resolveDescription(content: string, defaultContent: string) { - const mdContent = this.md.parse(content ?? '').trim() - return mdContent || defaultContent - } -} - -class Huggingface extends ModelSearch { - async search(searchUrl: string): Promise { - const { pathname } = new URL(searchUrl) - const [, space, name, ...restPaths] = pathname.split('/') - - if (!space || !name) { - return Promise.resolve([]) - } - - const modelId = `${space}/${name}` - const restPathname = restPaths.join('/') - - return fetch(`https://huggingface.co/api/models/${modelId}`) - .then((response) => response.json()) - .then((resData) => { - const siblingFiles: string[] = resData.siblings.map( - (item: any) => item.rfilename, - ) - - const modelFiles: string[] = this.filterTreeFiles( - this.filterModelFiles(siblingFiles), - restPathname, - ) - const images: string[] = this.filterTreeFiles( - this.filterImageFiles(siblingFiles), - restPathname, - ).map((filename) => { - return `https://huggingface.co/${modelId}/resolve/main/${filename}` - }) - - const models: VersionModel[] = [] - - for (const filename of modelFiles) { - const fullname = filename.split('/').pop()! - const extension = `.${fullname.split('.').pop()}` - const basename = fullname.replace(extension, '') - - models.push({ - id: filename, - shortname: filename, - fullname: fullname, - basename: basename, - extension: extension, - preview: images, - sizeBytes: 0, - type: 'unknown', - pathIndex: 0, - description: [ - '---', - ...[ - `website: HuggingFace`, - `modelPage: https://huggingface.co/${modelId}`, - `author: ${resData.author}`, - ].filter(Boolean), - '---', - '', - '# Trigger Words', - '\nNo trigger words\n', - '# About this version', - '\nNo description about this version\n', - `# ${resData.modelId}`, - '\nNo description about this model\n', - ].join('\n'), - metadata: {}, - downloadPlatform: 'huggingface', - downloadUrl: `https://huggingface.co/${modelId}/resolve/main/${filename}?download=true`, - }) - } - - return models - }) - } - - private filterTreeFiles(files: string[], pathname: string) { - const [target, , ...paths] = pathname.split('/') - - if (!target) return files - - if (target !== 'tree' && target !== 'blob') return files - - const pathPrefix = paths.join('/') - return files.filter((file) => { - return file.startsWith(pathPrefix) - }) - } - - private filterModelFiles(files: string[]) { - const extension = [ - '.bin', - '.ckpt', - '.gguf', - '.onnx', - '.pt', - '.pth', - '.safetensors', - ] - return files.filter((file) => { - const ext = file.split('.').pop() - return ext ? extension.includes(`.${ext}`) : false - }) - } - - private filterImageFiles(files: string[]) { - const extension = [ - '.png', - '.webp', - '.jpeg', - '.jpg', - '.jfif', - '.gif', - '.apng', - ] - - return files.filter((file) => { - const ext = file.split('.').pop() - return ext ? extension.includes(`.${ext}`) : false - }) - } -} - -class UnknownWebsite extends ModelSearch { - async search(): Promise { - return Promise.reject( - new Error( - 'Unknown Website, please input a URL from huggingface.co or civitai.com.', - ), - ) - } -} - export const useModelSearch = () => { const loading = useLoading() - const md = useMarkdown() const { toast } = useToast() const data = ref<(SelectOptions & { item: VersionModel })[]>([]) const current = ref() @@ -414,22 +168,9 @@ export const useModelSearch = () => { return Promise.resolve([]) } - let instance: ModelSearch = new UnknownWebsite(md) - - const { hostname } = new URL(url ?? '') - - if (hostname === 'civitai.com') { - instance = new Civitai(md) - } - - if (hostname === 'huggingface.co') { - instance = new Huggingface(md) - } - loading.show() - return instance - .search(url) - .then((resData) => { + return request(`/model-info?model-page=${encodeURIComponent(url)}`, {}) + .then((resData: VersionModel[]) => { data.value = resData.map((item) => ({ label: item.shortname, value: item.id, diff --git a/src/hooks/loading.ts b/src/hooks/loading.ts index 349b7d2..2a66005 100644 --- a/src/hooks/loading.ts +++ b/src/hooks/loading.ts @@ -31,6 +31,12 @@ export const useGlobalLoading = defineStore('loading', () => { return { loading } }) +declare module 'hooks/store' { + interface StoreProvider { + loading: ReturnType + } +} + export const useLoading = () => { const timer = ref() diff --git a/src/hooks/markdown.ts b/src/hooks/markdown.ts index a358d1e..6fd005c 100644 --- a/src/hooks/markdown.ts +++ b/src/hooks/markdown.ts @@ -1,6 +1,5 @@ import MarkdownIt from 'markdown-it' import metadata_block from 'markdown-it-metadata-block' -import TurndownService from 'turndown' import yaml from 'yaml' interface MarkdownOptions { @@ -31,19 +30,7 @@ export const useMarkdown = (opts?: MarkdownOptions) => { return self.renderToken(tokens, idx, options) } - const turndown = new TurndownService({ - headingStyle: 'atx', - bulletListMarker: '-', - }) - - turndown.addRule('paragraph', { - filter: 'p', - replacement: function (content) { - return `\n\n${content}` - }, - }) - - return { render: md.render.bind(md), parse: turndown.turndown.bind(turndown) } + return { render: md.render.bind(md) } } export type MarkdownTool = ReturnType diff --git a/src/hooks/model.ts b/src/hooks/model.ts index 268a057..3a652ea 100644 --- a/src/hooks/model.ts +++ b/src/hooks/model.ts @@ -7,7 +7,7 @@ import { useToast } from 'hooks/toast' import { cloneDeep } from 'lodash' import { app } from 'scripts/comfyAPI' import { BaseModel, Model, SelectEvent } from 'types/typings' -import { bytesToSize, formatDate, previewUrlToFile } from 'utils/common' +import { bytesToSize, formatDate } from 'utils/common' import { ModelGrid } from 'utils/legacy' import { genModelKey, resolveModelTypeLoader } from 'utils/model' import { @@ -29,18 +29,17 @@ export const useModels = defineStore('models', (store) => { const loading = useLoading() const updateModel = async (model: BaseModel, data: BaseModel) => { - const formData = new FormData() + const updateData = new Map() let oldKey: string | null = null // Check current preview if (model.preview !== data.preview) { - const previewFile = await previewUrlToFile(data.preview as string) - formData.append('previewFile', previewFile) + updateData.set('previewFile', data.preview) } // Check current description if (model.description !== data.description) { - formData.append('description', data.description) + updateData.set('description', data.description) } // Check current name and pathIndex @@ -49,19 +48,19 @@ export const useModels = defineStore('models', (store) => { model.pathIndex !== data.pathIndex ) { oldKey = genModelKey(model) - formData.append('type', data.type) - formData.append('pathIndex', data.pathIndex.toString()) - formData.append('fullname', data.fullname) + updateData.set('type', data.type) + updateData.set('pathIndex', data.pathIndex.toString()) + updateData.set('fullname', data.fullname) } - if (formData.keys().next().done) { + if (updateData.size === 0) { return } loading.show() await request(`/model/${model.type}/${model.pathIndex}/${model.fullname}`, { method: 'PUT', - body: formData, + body: JSON.stringify(Object.fromEntries(updateData.entries())), }) .catch((err) => { const error_message = err.message ?? err.error @@ -246,14 +245,17 @@ export const useModelBaseInfoEditor = (formInstance: ModelFormInstance) => { interface FieldsItem { key: keyof Model - formatter: (val: any) => string + formatter: (val: any) => string | undefined | null } const baseInfo = computed(() => { const fields: FieldsItem[] = [ { key: 'type', - formatter: () => modelData.value.type, + formatter: () => + modelData.value.type in modelFolders.value + ? modelData.value.type + : undefined, }, { key: 'pathIndex', diff --git a/src/scripts/comfyAPI.ts b/src/scripts/comfyAPI.ts index 61394ff..1f5445b 100644 --- a/src/scripts/comfyAPI.ts +++ b/src/scripts/comfyAPI.ts @@ -5,3 +5,4 @@ export const $el = window.comfyAPI.ui.$el export const ComfyApp = window.comfyAPI.app.ComfyApp export const ComfyButton = window.comfyAPI.button.ComfyButton +export const ComfyDialog = window.comfyAPI.dialog.ComfyDialog diff --git a/src/types/global.d.ts b/src/types/global.d.ts index 9a4e612..f3fb9ee 100644 --- a/src/types/global.d.ts +++ b/src/types/global.d.ts @@ -112,6 +112,7 @@ declare namespace ComfyAPI { settings: ComfySettingsDialog menuHamburger?: HTMLDivElement menuContainer?: HTMLDivElement + dialog: dialog.ComfyDialog } type SettingInputType = @@ -197,6 +198,15 @@ declare namespace ComfyAPI { constructor(...buttons: (HTMLElement | ComfyButton)[]): ComfyButtonGroup } } + + namespace dialog { + class ComfyDialog { + constructor(type = 'div', buttons: HTMLElement[] = null) + element: HTMLElement + close(): void + show(html: string | HTMLElement | HTMLElement[]): void + } + } } declare namespace lightGraph { diff --git a/src/utils/common.ts b/src/utils/common.ts index 5d6a259..fd1c0f7 100644 --- a/src/utils/common.ts +++ b/src/utils/common.ts @@ -26,14 +26,3 @@ export const bytesToSize = ( export const formatDate = (date: number | string | Date) => { return dayjs(date).format('YYYY-MM-DD HH:mm:ss') } - -export const previewUrlToFile = async (url: string) => { - return fetch(url) - .then((res) => res.blob()) - .then((blob) => { - const type = blob.type - const extension = type.split('/')[1] - const file = new File([blob], `preview.${extension}`, { type }) - return file - }) -}