From 81b9c0569476e6314f9a415d389994ba84da8d42 Mon Sep 17 00:00:00 2001
From: mrhan1993 <50648276+mrhan1993@users.noreply.github.com>
Date: Mon, 8 Apr 2024 14:05:17 +0800
Subject: [PATCH] LoRA URL support for Replicate

---
 fooocusapi/utils/lora_manager.py | 56 ++++++++++++++++++++++++++++++++
 predict.py                       | 25 +++++++++++++-
 2 files changed, 80 insertions(+), 1 deletion(-)
 create mode 100644 fooocusapi/utils/lora_manager.py

diff --git a/fooocusapi/utils/lora_manager.py b/fooocusapi/utils/lora_manager.py
new file mode 100644
index 0000000..f6e72b1
--- /dev/null
+++ b/fooocusapi/utils/lora_manager.py
@@ -0,0 +1,56 @@
+"""
+Manage LoRAs from URLs
+
+@author: TechnikMax
+@github: https://github.com/TechnikMax
+"""
+import hashlib
+import os
+import requests
+
+
+def _hash_url(url):
+    """Generates a hash value for a given URL."""
+    return hashlib.md5(url.encode('utf-8')).hexdigest()
+
+
+class LoraManager:
+    """
+    Manage LoRAs from URLs
+    """
+    def __init__(self):
+        self.cache_dir = "/models/loras/"
+
+    def _download_lora(self, url):
+        """
+        Downloads a LoRA from a URL and saves it in the cache.
+        """
+        url_hash = _hash_url(url)
+        filepath = os.path.join(self.cache_dir, f"{url_hash}.safetensors")
+        file_name = f"{url_hash}.safetensors"
+
+        if not os.path.exists(filepath):
+            print(f"Start download for: {url}")
+
+            try:
+                response = requests.get(url, timeout=10, stream=True)
+                response.raise_for_status()
+                with open(filepath, 'wb') as f:
+                    for chunk in response.iter_content(chunk_size=8192):
+                        f.write(chunk)
+                print(f"Downloaded successfully, saved as {file_name}")
+
+            except Exception as e:
+                raise Exception(f"Error downloading {url}: {e}") from e
+
+        else:
+            print(f"LoRA already downloaded: {url}")
+        return file_name
+
+    def check(self, urls):
+        """Manages the specified LoRAs: downloads missing ones and returns their file names."""
+        paths = []
+        for url in urls:
+            path = self._download_lora(url)
+            paths.append(path)
+        return paths
diff --git a/predict.py b/predict.py
index 81cc89b..3a2ec9c 100644
--- a/predict.py
+++ b/predict.py
@@ -10,6 +10,7 @@
 from PIL import Image
 from cog import BasePredictor, BaseModel, Input, Path
 
+from fooocusapi.utils.lora_manager import LoraManager
 from fooocusapi.utils.file_utils import output_dir
 from fooocusapi.models.common.task import GenerationFinishReason
 from fooocusapi.parameters import (
@@ -72,6 +73,12 @@ def predict(
         image_seed: int = Input(
             default=-1,
             description="Seed to generate image, -1 for random"),
+        use_default_loras: bool = Input(
+            default=True,
+            description="Use default LoRAs"),
+        loras_custom_urls: str = Input(
+            default="",
+            description="Custom LoRA URLs in the format 'url,weight'; provide multiple separated by ';' (example: 'url1,0.3;url2,0.1')"),
         sharpness: float = Input(
             default=2.0,
             ge=0.0, le=30.0),
@@ -182,7 +189,23 @@
 
         base_model_name = default_base_model_name
         refiner_model_name = default_refiner_model_name
-        loras = copy.copy(default_loras)
+
+        lora_manager = LoraManager()
+
+        # Use default LoRAs if selected
+        loras = copy.copy(default_loras) if use_default_loras else []
+
+        # Add custom user LoRAs if provided
+        if loras_custom_urls:
+            urls = [url.strip() for url in loras_custom_urls.split(';')]
+
+            loras_with_weights = [url.split(',') for url in urls]
+
+            custom_lora_paths = lora_manager.check([lw[0] for lw in loras_with_weights])
+            custom_loras = [[path, float(lw[1]) if len(lw) > 1 else 1.0] for path, lw in
+                            zip(custom_lora_paths, loras_with_weights)]
+
+            loras.extend(custom_loras)
 
         style_selections_arr = []
         for s in style_selections.strip().split(','):
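
Reviewer note, not part of the patch: a minimal sketch of how the new loras_custom_urls input is expected to be parsed, assuming the 'url,weight;url,weight' format described above. The example URLs and the download_stub helper are hypothetical stand-ins; the patch itself uses LoraManager().check(), which downloads each URL into /models/loras/.

# Illustrative sketch only: parse 'url,weight;url,weight' into [filename, weight] pairs.
import hashlib

loras_custom_urls = "https://example.com/a.safetensors,0.3;https://example.com/b.safetensors"

urls = [u.strip() for u in loras_custom_urls.split(';')]   # one entry per LoRA
loras_with_weights = [u.split(',') for u in urls]          # each entry: [url] or [url, weight]


def download_stub(url_list):
    """Stand-in for LoraManager.check(): returns the md5-hashed .safetensors file names."""
    return [hashlib.md5(u.encode('utf-8')).hexdigest() + ".safetensors" for u in url_list]


custom_lora_paths = download_stub([lw[0] for lw in loras_with_weights])
custom_loras = [[path, float(lw[1]) if len(lw) > 1 else 1.0]
                for path, lw in zip(custom_lora_paths, loras_with_weights)]
print(custom_loras)
# e.g. [['<md5-of-first-url>.safetensors', 0.3], ['<md5-of-second-url>.safetensors', 1.0]]

Entries without an explicit weight fall back to 1.0, matching the float(lw[1]) if len(lw) > 1 else 1.0 expression added in predict.py.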