diff --git a/.dockerignore b/.dockerignore index 0733e81..e1e4222 100644 --- a/.dockerignore +++ b/.dockerignore @@ -20,7 +20,6 @@ user_path_config-deprecated.txt build_chb.py experiment.py /modules/*.png -/repositories /venv /tmp /ui-config.json diff --git a/fooocusapi/parameters.py b/fooocusapi/parameters.py index c61bef6..0b6ef6c 100644 --- a/fooocusapi/parameters.py +++ b/fooocusapi/parameters.py @@ -35,7 +35,7 @@ default_base_model_name = "juggernautXL_v8Rundiffusion.safetensors" default_refiner_model_name = "None" default_refiner_switch = 0.5 -default_loras = [[True, "sd_xl_offset_example-lora_1.0.safetensors", 0.1]] +default_loras = [["sd_xl_offset_example-lora_1.0.safetensors", 0.1]] default_cfg_scale = 7.0 default_prompt_negative = "" default_aspect_ratio = "1152*896" diff --git a/fooocusapi/utils/lora_manager.py b/fooocusapi/utils/lora_manager.py new file mode 100644 index 0000000..fcff29a --- /dev/null +++ b/fooocusapi/utils/lora_manager.py @@ -0,0 +1,59 @@ +""" +Manager loras from url + +@author: TechnikMax +@github: https://github.com/TechnikMax +""" +import hashlib +import os +import requests + + +def _hash_url(url): + """Generates a hash value for a given URL.""" + return hashlib.md5(url.encode('utf-8')).hexdigest() + + +class LoraManager: + """ + Manager loras from url + """ + def __init__(self): + self.cache_dir = os.path.join( + os.path.dirname(os.path.realpath(__file__)), + '../../', + 'repositories/Fooocus/models/loras') + + def _download_lora(self, url): + """ + Downloads a LoRa from a URL and saves it in the cache. + """ + url_hash = _hash_url(url) + filepath = os.path.join(self.cache_dir, f"{url_hash}.safetensors") + file_name = f"{url_hash}.safetensors" + + if not os.path.exists(filepath): + print(f"start download for: {url}") + + try: + response = requests.get(url, timeout=10, stream=True) + response.raise_for_status() + with open(filepath, 'wb') as f: + for chunk in response.iter_content(chunk_size=8192): + f.write(chunk) + print(f"Download successfully, saved as {file_name}") + + except Exception as e: + raise Exception(f"error downloading {url}: {e}") from e + + else: + print(f"LoRa already downloaded {url}") + return file_name + + def check(self, urls): + """Manages the specified LoRAs: downloads missing ones and returns their file names.""" + paths = [] + for url in urls: + path = self._download_lora(url) + paths.append(path) + return paths diff --git a/fooocusapi/worker.py b/fooocusapi/worker.py index 3d1034f..f0f4404 100644 --- a/fooocusapi/worker.py +++ b/fooocusapi/worker.py @@ -464,7 +464,8 @@ def yield_result(_, imgs, tasks, extension='png'): pipeline.refresh_everything( refiner_model_name=refiner_model_name, base_model_name=base_model_name, - loras=loras, base_model_additional_loras=base_model_additional_loras, + loras=loras, + base_model_additional_loras=base_model_additional_loras, use_synthetic_refiner=use_synthetic_refiner) progressbar(async_task, 3, 'Processing prompts ...') diff --git a/predict.py b/predict.py index 81cc89b..5ac6624 100644 --- a/predict.py +++ b/predict.py @@ -10,6 +10,7 @@ from PIL import Image from cog import BasePredictor, BaseModel, Input, Path +from fooocusapi.utils.lora_manager import LoraManager from fooocusapi.utils.file_utils import output_dir from fooocusapi.models.common.task import GenerationFinishReason from fooocusapi.parameters import ( @@ -59,7 +60,7 @@ def predict( description="Fooocus styles applied for image generation, separated by comma"), performance_selection: str = Input( default='Speed', - choices=['Speed', 'Quality', 'Extreme Speed'], + choices=['Speed', 'Quality', 'Extreme Speed', 'Lightning'], description="Performance selection"), aspect_ratios_selection: str = Input( default='1152*896', @@ -72,6 +73,12 @@ def predict( image_seed: int = Input( default=-1, description="Seed to generate image, -1 for random"), + use_default_loras: bool = Input( + default=True, + description="Use default LoRAs"), + loras_custom_urls: str = Input( + default="", + description="Custom LoRAs URLs in the format 'url,weight' provide multiple seperated by ; (example 'url1,0.3;url2,0.1')"), sharpness: float = Input( default=2.0, ge=0.0, le=30.0), @@ -182,7 +189,23 @@ def predict( base_model_name = default_base_model_name refiner_model_name = default_refiner_model_name - loras = copy.copy(default_loras) + + lora_manager = LoraManager() + + # Use default loras if selected + loras = copy.copy(default_loras) if use_default_loras else [] + + # add custom user loras if provided + if loras_custom_urls: + urls = [url.strip() for url in loras_custom_urls.split(';')] + + loras_with_weights = [url.split(',') for url in urls] + + custom_lora_paths = lora_manager.check([lw[0] for lw in loras_with_weights]) + custom_loras = [[path, float(lw[1]) if len(lw) > 1 else 1.0] for path, lw in + zip(custom_lora_paths, loras_with_weights)] + + loras.extend(custom_loras) style_selections_arr = [] for s in style_selections.strip().split(','):