diff --git a/__init__.py b/__init__.py index 704aadd..b791144 100644 --- a/__init__.py +++ b/__init__.py @@ -8,6 +8,8 @@ import importlib import re import base64 +import hashlib +import markdownify from aiohttp import web import server @@ -187,6 +189,7 @@ def ui_rules(): Rule("model-show-add-button", True, bool), Rule("model-show-copy-button", True, bool), Rule("model-show-load-workflow-button", True, bool), + Rule("model-show-open-model-url-button", False, bool), Rule("model-info-button-on-left", False, bool), Rule("model-add-embedding-extension", False, bool), @@ -240,6 +243,93 @@ def get_def_headers(url=""): return def_headers +def civitai_get_model_info(sha256_hash: str): + url_api_hash = r"https://civitai.com/api/v1/model-versions/by-hash/" + sha256_hash + hash_response = requests.get(url_api_hash) + if hash_response.status_code != 200: + return {} + hash_info = hash_response.json() + if len(hash_info) == 0: + return {} + model_id = hash_info["modelId"] + + url_api_model = r"https://civitai.com/api/v1/models/" + str(model_id) + model_response = requests.get(url_api_model) + if model_response.status_code != 200: + return {} + return model_response.json() + + +def search_web_for_model_url(sha256_hash): + model_info = civitai_get_model_info(sha256_hash) + if len(model_info) > 0: + model_id = model_info["id"] + version_id = None + for model_version in model_info["modelVersions"]: + for files in model_version["files"]: + if files["hashes"]["SHA256"].lower() == sha256_hash.lower(): + version_id = model_version["id"] + break + if version_id is not None: break + return f"https://civitai.com/models/{model_id}?modelVersionId={version_id}" + + # TODO: search other websites + + return "" + + +def search_web_for_model_notes(sha256_hash): + model_info = civitai_get_model_info(sha256_hash) + if len(model_info) > 0: + model_description = model_info.get("description", "") + model_version_description = "" + model_trigger_words = [] + for model_version in model_info["modelVersions"]: + for files in model_version["files"]: + if files["hashes"]["SHA256"].lower() == sha256_hash.lower(): + model_version_description = model_version.get("description", "") + model_trigger_words = model_version.get("trainedWords", "") + break + if model_version_description != "": break + + notes = "" + if len(model_trigger_words) > 0: + notes += "# Trigger Words\n\n" + model_trigger_words = [re.sub(",$", "", s.strip()) for s in model_trigger_words] + join_separator = ', ' + for s in model_trigger_words: + if ',' in s: + join_separator = '\n' + break + if join_separator == '\n' and len(model_trigger_words) > 1: + model_trigger_words = ["* " + s for s in model_trigger_words] + notes += join_separator.join(model_trigger_words) + if model_version_description != "": + if len(notes) > 0: notes += "\n\n" + notes += "# About this version\n\n" + notes += markdownify.markdownify(model_version_description) + if model_description != "": + if len(notes) > 0: notes += "\n\n" + notes += "# " + model_info.get("name", str(model_info["id"])) + "\n\n" + notes += markdownify.markdownify(model_description) + notes = notes.strip() + return notes + + # TODO: search other websites + + return "" + + +def hash_file(path, buffer_size=1024*1024): + sha256 = hashlib.sha256() + with open(path, 'rb') as f: + while True: + data = f.read(buffer_size) + if not data: break + sha256.update(data) + return sha256.hexdigest() + + @server.PromptServer.instance.routes.get("/model-manager/timestamp") async def get_timestamp(request): return web.json_response({ "timestamp": datetime.now().timestamp() }) @@ -1004,6 +1094,23 @@ async def get_model_info(request): return web.json_response(result) +@server.PromptServer.instance.routes.get("/model-manager/model/info/web-url") +async def get_model_web_url(request): + model_path = request.query.get("path", None) + if model_path is None: + web.json_response({ "url": "" }) + model_path = urllib.parse.unquote(model_path) + + abs_path, model_type = search_path_to_system_path(model_path) + if abs_path is None: + web.json_response({ "url": "" }) + + sha256_hash = hash_file(abs_path) + url = search_web_for_model_url(sha256_hash) + + return web.json_response({ "url": url }) + + @server.PromptServer.instance.routes.get("/model-manager/system-separator") async def get_system_separator(request): return web.json_response(os.path.sep) @@ -1218,12 +1325,50 @@ async def set_notes(request): except ValueError as e: print(e, file=sys.stderr, flush=True) result["alert"] = "Failed to save notes!\n\n" + str(e) - web.json_response(result) + return web.json_response(result) result["success"] = True return web.json_response(result) +@server.PromptServer.instance.routes.post("/model-manager/notes/download") +async def try_download_notes(request): + result = { "success": False } + + model_path = request.query.get("path", None) + if model_path is None: + result["alert"] = "Missing model path!" + return web.json_response(result) + model_path = urllib.parse.unquote(model_path) + + abs_path, model_type = search_path_to_system_path(model_path) + if abs_path is None: + result["alert"] = "Invalid model path!" + return web.json_response(result) + + overwrite = request.query.get("overwrite", None) + overwrite = not (overwrite == "False" or overwrite == "false" or overwrite == None) + notes_path = os.path.splitext(abs_path)[0] + ".txt" + if not overwrite and os.path.isfile(notes_path): + result["alert"] = "Notes already exist!" + return web.json_response(result) + + sha256_hash = hash_file(abs_path) + notes = search_web_for_model_notes(sha256_hash) + if not notes.isspace() and notes != "": + try: + with open(notes_path, "w", encoding="utf-8") as f: + f.write(notes) + result["success"] = True + except ValueError as e: + print(e, file=sys.stderr, flush=True) + result["alert"] = "Failed to save notes!\n\n" + str(e) + return web.json_response(result) + + result["notes"] = notes + return web.json_response(result) + + WEB_DIRECTORY = "web" NODE_CLASS_MAPPINGS = {} __all__ = ["NODE_CLASS_MAPPINGS"] diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..06a83f1 --- /dev/null +++ b/requirements.txt @@ -0,0 +1 @@ +markdownify \ No newline at end of file diff --git a/web/model-manager.js b/web/model-manager.js index a6cedd1..ef310d2 100644 --- a/web/model-manager.js +++ b/web/model-manager.js @@ -139,6 +139,22 @@ async function loadWorkflow(url) { app.handleFile(file); } +/** + * @param {string} modelPath + * @returns {Promise} + */ +async function tryOpenModelUrl(modelPath) { + const webUrlResponse = await comfyRequest(`/model-manager/model/info/web-url?path=${modelPath}`); + try { + const modelUrl = new URL(webUrlResponse["url"]); + window.open(modelUrl, '_blank').focus(); + } + catch (exception) { + return false; + } + return true; +} + const modelNodeType = { "checkpoints": "CheckpointLoaderSimple", "clip": "CLIPLoader", @@ -1885,6 +1901,7 @@ class ModelGrid { const showAddButton = canShowButtons && settingsElements["model-show-add-button"].checked; const showCopyButton = canShowButtons && settingsElements["model-show-copy-button"].checked; const showLoadWorkflowButton = canShowButtons && settingsElements["model-show-load-workflow-button"].checked; + const shouldShowTryOpenModelUrl = canShowButtons && settingsElements["model-show-open-model-url-button"].checked; const strictDragToAdd = settingsElements["model-add-drag-strict-on-field"].checked; const addOffset = parseInt(settingsElements["model-add-offset"].value); const showModelExtension = settingsElements["model-show-label-extensions"].checked; @@ -1957,6 +1974,22 @@ class ModelGrid { }).element, ); } + if (shouldShowTryOpenModelUrl) { + actionButtons.push( + new ComfyButton({ + icon: "open-in-new", + tooltip: "Attempt to open model url page in a new tab.", + classList: "comfyui-button icon-button model-button", + action: async (e) => { + const [button, icon, span] = comfyButtonDisambiguate(e.target); + button.disabled = true; + const success = await tryOpenModelUrl(searchPath); + comfyButtonAlert(e.target, success, "mdi-check-bold", "mdi-close-thick"); + button.disabled = false; + }, + }).element + ); + } const infoButtons = [ new ComfyButton({ icon: "information-outline", @@ -2494,17 +2527,30 @@ class ModelInfo { innerHtml.push($el("div", [ previewSelect.elements.previews, + $el("div.row.tab-header", { style: { "flex-direction": "row" } }, [ + new ComfyButton({ + icon: "arrow-bottom-left-bold-box-outline", + tooltip: "Attempt to load preview image workflow", + classList: "comfyui-button icon-button", + action: async () => { + const urlString = previewSelect.elements.defaultPreviews.children[0].src; + await loadWorkflow(urlString); + }, + }).element, + new ComfyButton({ + icon: "open-in-new", + tooltip: "Attempt to open model url page in a new tab.", + classList: "comfyui-button icon-button", + action: async (e) => { + const [button, icon, span] = comfyButtonDisambiguate(e.target); + button.disabled = true; + const success = await tryOpenModelUrl(path); + comfyButtonAlert(e.target, success, "mdi-check-bold", "mdi-close-thick"); + button.disabled = false; + }, + }).element, + ]), $el("div.row.tab-header", [ - $el("div", [ - new ComfyButton({ - content: "Load Workflow", - tooltip: "Attempt to load preview image workflow", - action: async () => { - const urlString = previewSelect.elements.defaultPreviews.children[0].src; - await loadWorkflow(urlString); - }, - }).element, - ]), $el("div.row.tab-header-flex-block", [ previewSelect.elements.radioGroup, ]), @@ -2691,6 +2737,46 @@ class ModelInfo { }, }).element; + const downloadNotesButton = new ComfyButton({ + icon: "earth-arrow-down", + tooltip: "Attempt to download model info from the internet.", + classList: "comfyui-button icon-button", + action: async (e) => { + if (this.#savedNotesValue !== "") { + const overwriteNoteConfirmation = window.confirm("Overwrite note?"); + if (!overwriteNoteConfirmation) { + comfyButtonAlert(e.target, false, "mdi-check-bold", "mdi-close-thick"); + return; + } + } + + const [button, icon, span] = comfyButtonDisambiguate(e.target); + button.disabled = true; + const [success, downloadedNotesValue] = await comfyRequest( + `/model-manager/notes/download?path=${path}&overwrite=True`, + { + method: "POST", + body: {}, + } + ).then((data) => { + const success = data["success"]; + const message = data["alert"]; + if (message !== undefined) { + window.alert(message); + } + return [success, data["notes"]]; + }).catch((err) => { + return [false, ""]; + }); + if (success) { + this.#savedNotesValue = downloadedNotesValue; + this.elements.notes.value = downloadedNotesValue; + } + comfyButtonAlert(e.target, success, "mdi-check-bold", "mdi-close-thick"); + button.disabled = false; + }, + }).element; + const saveDebounce = debounce(async() => { const saveIconClass = "mdi-" + saveIcon; const savingIconClass = "mdi-" + savingIcon; @@ -2750,6 +2836,7 @@ class ModelInfo { }, [ $el("h1", ["Notes"]), saveNotesButton, + downloadNotesButton, ]), $el("div", { style: { "display": "flex", "height": "100%", "min-height": "60px" }, @@ -3162,7 +3249,7 @@ async function getModelInfos(urlText) { const description = [ tags !== undefined ? "# Trigger Words" : undefined, tags?.join(tags.some((tag) => { return tag.includes(","); }) ? "\n" : ", "), - version["description"] !== undefined ? "# About this version " : undefined, + version["description"] !== undefined ? "# About this version" : undefined, version["description"], civitaiInfo["description"] !== undefined ? "# " + name : undefined, civitaiInfo["description"], @@ -3882,6 +3969,7 @@ class SettingsView { /** @type {HTMLInputElement} */ "model-show-add-button": null, /** @type {HTMLInputElement} */ "model-show-copy-button": null, /** @type {HTMLInputElement} */ "model-show-load-workflow-button": null, + /** @type {HTMLInputElement} */ "model-show-open-model-url-button": null, /** @type {HTMLInputElement} */ "model-info-button-on-left": null, /** @type {HTMLInputElement} */ "model-add-embedding-extension": null, @@ -4111,6 +4199,9 @@ class SettingsView { $checkbox({ $: (el) => (settings["model-show-load-workflow-button"] = el), textContent: "Show \"Load Workflow\" button", + }),$checkbox({ + $: (el) => (settings["model-show-open-model-url-button"] = el), + textContent: "Show \"Open Model Url\" button", }), $checkbox({ $: (el) => (settings["model-info-button-on-left"] = el),