AnimateDiff + Faster-Diffusion & DeepCache #179

Draft
wants to merge 174 commits into base: main

174 commits
bfc2db9
feat: SDXL
gabe56f Jul 21, 2023
f1fb2ae
feat: (mostly) working LWP with SDXL
gabe56f Jul 21, 2023
142336f
fix: SDXL loading
gabe56f Jul 21, 2023
49b66e1
fix: Fix generation on normal (non-xl) models
gabe56f Jul 22, 2023
fff3f17
feat: temporarily list models that contain "xl" as SDXL
gabe56f Jul 22, 2023
ce74ed8
fix: loading issue
gabe56f Jul 22, 2023
ab3e2c9
feat: refiner support
gabe56f Jul 22, 2023
e6b22e8
feat: SDXL Refiner frontend
gabe56f Jul 22, 2023
d21849a
Remove bislerp-original
gabe56f Jul 22, 2023
0435434
Merge branch 'experimental' into feature/sdxl
gabe56f Aug 2, 2023
adab052
Cleanup
gabe56f Aug 2, 2023
8c2cb20
Further cleanup
gabe56f Aug 2, 2023
e896ad5
Merge branch 'experimental' into feature/sdxl
gabe56f Aug 2, 2023
b925c67
Add SDXL to model selector dropdown
gabe56f Aug 2, 2023
52b5eea
Merge branch 'experimental' into feature/sdxl
gabe56f Aug 2, 2023
43e0219
Purge old downloader, bump deps, fix typing errors
Stax124 Aug 3, 2023
9313571
Remove module offload & fix offload & add VAE switching
gabe56f Aug 4, 2023
2e80767
Remove "module" option from config
gabe56f Aug 4, 2023
c922408
Build frontend
gabe56f Aug 4, 2023
ee976e5
Fix offload vae switching
gabe56f Aug 4, 2023
ae813f0
Fix offload
gabe56f Aug 4, 2023
628e4ed
Fixed offload for real this time
gabe56f Aug 5, 2023
ce7f092
Add "flashattention" to supported attention processors
gabe56f Sep 20, 2023
a55a7e5
Make flash-attn optional
gabe56f Oct 10, 2023
68ac5fb
Work on model detection
gabe56f Oct 18, 2023
0d2d110
Merge remote-tracking branch 'origin/experimental' into feature/sdxl
gabe56f Oct 18, 2023
a477829
fix more issues
gabe56f Oct 18, 2023
74cc7e2
...
gabe56f Oct 18, 2023
15a317d
Merge remote-tracking branch 'origin/experimental' into feature/sdxl
gabe56f Oct 18, 2023
5929044
fix prompt-expansion download
gabe56f Nov 4, 2023
ec1f5a3
Merge branch 'experimental' into feature/sdxl
gabe56f Nov 7, 2023
96d3401
Merge branch 'experimental' into feature/sdxl
gabe56f Nov 7, 2023
029c61b
Working SDXL loras, non-working SD1.5 loras
gabe56f Nov 8, 2023
aa3ac16
Fix non-SDXL loras
gabe56f Nov 8, 2023
a45aa32
Fix frontend
gabe56f Nov 8, 2023
e822f21
Reload-less more feature-proof clip skip
gabe56f Nov 9, 2023
270caa3
T2I
gabe56f Nov 9, 2023
2df7159
Merge branch 'feature/sdxl' of https://github.com/voltaML/voltaML-fas…
gabe56f Nov 9, 2023
b2e4900
Merge branch 'experimental' into feature/sdxl
Stax124 Nov 9, 2023
7a523e1
Fix loading
gabe56f Nov 9, 2023
bb43bc0
Merge branch 'feature/sdxl' of https://github.com/voltaML/voltaML-fas…
gabe56f Nov 9, 2023
7f85229
Merge SDXL and SD model loader tabs, few config typing changes, rebui…
Stax124 Nov 9, 2023
66b350d
Completely broken module offload
gabe56f Nov 9, 2023
661fbeb
Merge branch 'feature/sdxl' of https://github.com/voltaML/voltaML-fas…
gabe56f Nov 9, 2023
6582e73
Tag expansion, improved offload
gabe56f Nov 10, 2023
cad5969
Hardcode gpu load model for now, better error handling when searching…
Stax124 Nov 10, 2023
1d63a68
add few cases to determine_model_type, hook it up to autoloader
Stax124 Nov 10, 2023
26e2b7a
Update docstring for CachedModelList
Stax124 Nov 10, 2023
3a8c8f8
Fix Diffusers model loading
Stax124 Nov 10, 2023
5d0a804
implement cfg-related things, start work on sag rework
gabe56f Nov 11, 2023
eae53e6
Merge branch 'feature/sdxl' of https://github.com/voltaML/voltaML-fas…
gabe56f Nov 11, 2023
196df8f
Temporary ResizeFromDimensionsInput
Stax124 Nov 11, 2023
729eb6b
frontend
gabe56f Nov 11, 2023
1a3ff04
Merge branch 'feature/sdxl' of https://github.com/voltaML/voltaML-fas…
gabe56f Nov 11, 2023
db9bda5
rebuild
gabe56f Nov 11, 2023
6b5bae4
Resize from frontend and config
Stax124 Nov 11, 2023
256a605
Self-Attention Guidance on SDXL and K-Diff
gabe56f Nov 12, 2023
baa3f6a
Merge branch 'feature/sdxl' of https://github.com/voltaML/voltaML-fas…
gabe56f Nov 12, 2023
c9b82fa
Rebuild frontend
gabe56f Nov 12, 2023
d63661b
Fix ControlNet SAG
gabe56f Nov 12, 2023
6483bd8
Set default refiner method to separate, 'cause I know that works
gabe56f Nov 12, 2023
d3f8138
asdfjh
gabe56f Nov 14, 2023
614de26
Merge branch 'experimental' into feature/sdxl
Stax124 Nov 16, 2023
26c7fe3
Fix tags inside TopBar
Stax124 Nov 16, 2023
745efcc
SDXL Textual inversions
Stax124 Nov 16, 2023
fe9dbec
Fix ruff error
Stax124 Nov 16, 2023
4473a48
Send model type with loaded models
Stax124 Nov 17, 2023
5c19e2b
Fix Config, set VAE tiling to be enabled by default
Stax124 Nov 17, 2023
ec6f14e
Settings Diff Resolver, update debugger settings
Stax124 Nov 17, 2023
cec74ea
Maybe fix FP32 VAE once and for all??
gabe56f Nov 17, 2023
f29fb0e
Merge branch 'feature/sdxl' of https://github.com/voltaML/voltaML-fas…
gabe56f Nov 17, 2023
6c12830
Script for parsing keys from models
Stax124 Nov 17, 2023
2ff52f6
Fix
gabe56f Nov 17, 2023
72011da
Merge branch 'feature/sdxl' of https://github.com/voltaML/voltaML-fas…
gabe56f Nov 17, 2023
24a3549
Fixed hires image upscale, sdxl VAE loading of custom models
Stax124 Nov 19, 2023
04c0fb8
Fix config, fix some type errors
Stax124 Nov 19, 2023
4025d62
Start work on LCM
gabe56f Nov 19, 2023
1178d14
Revert "Start work on LCM"
gabe56f Nov 19, 2023
e29b9b1
Redo checkpoint loading, add new entries to package.json
Stax124 Nov 20, 2023
4ff1e9d
Merge branch 'feature/sdxl' of https://github.com/voltaML/voltaML-fas…
gabe56f Nov 20, 2023
364eb35
Merge branch 'experimental' into feature/sdxl
Stax124 Nov 21, 2023
d50ad2c
float8 support
gabe56f Nov 21, 2023
951ee80
Merge branch 'feature/sdxl' of https://github.com/voltaML/voltaML-fas…
gabe56f Nov 21, 2023
4beeab7
Heun++ & maybe working fp16 loras on fp8
gabe56f Nov 22, 2023
98ff461
maybe fixed 1.5 loras?
gabe56f Nov 22, 2023
eeecf3e
Revert "Revert "Start work on LCM"" & fix LoRAs
gabe56f Nov 24, 2023
b0411f6
Add sdxl tests, fix inpainting
Stax124 Nov 24, 2023
cf0e845
fix non inpaint pipelines
Stax124 Nov 24, 2023
02e42f1
Fix few sdxl inference errors
Stax124 Nov 24, 2023
f43bc33
update pytorch tests
Stax124 Nov 26, 2023
aa428db
Fix diffusers samplers with ControlNet
Stax124 Nov 26, 2023
a3b3923
fix: controlnets
gabe56f Nov 26, 2023
e72b860
Fix imports for older pytorch
tomas-novak-olc Nov 27, 2023
4f792fb
remove unused import
tomas-novak-olc Nov 27, 2023
9d6a1f6
update AIT tests with correct model
Stax124 Nov 27, 2023
a68feff
Merge pull request #127 from VoltaML/feature/sdxl
gabe56f Nov 27, 2023
ed40b23
Merge pull request #171 from VoltaML/fix/prompt-expansion
gabe56f Nov 27, 2023
95fe095
wtf is wrong
gabe56f Nov 27, 2023
dadf3d0
Kohya deepshrink working
gabe56f Nov 27, 2023
c81832e
fix
gabe56f Nov 27, 2023
855ac3f
Fix type error inside config
Stax124 Nov 28, 2023
ad12a52
update CI, handle missing .env file
Stax124 Nov 28, 2023
0ba1cd1
Create new CI for docker build on new tag
Stax124 Nov 28, 2023
5bfaf28
CODEOWNERS
gabe56f Nov 29, 2023
eb0660f
Better image grid for CivitAI browser, nsfw toggle per image
Stax124 Nov 30, 2023
7b5713b
Fix CivitAI model image grouping
Stax124 Dec 1, 2023
bc518a0
Extra data displayed in CivitAI model browser
Stax124 Dec 1, 2023
bdd599b
More CODEOWNERS stuff
gabe56f Dec 2, 2023
c3e4cf2
Update Docker compose templates, yarn commands for docker
Stax124 Dec 3, 2023
707065e
Try to fix Intel Arc install issues, fix types
Stax124 Dec 3, 2023
960523c
Fix dependencies for intel part 2
Stax124 Dec 3, 2023
e1b5809
Fix Intel Arc install part 3
Stax124 Dec 3, 2023
ea65063
Fix get_capabilities for non-CUDA users
Stax124 Dec 3, 2023
39b078b
Fix get_capabilities for Arc users pt2
Stax124 Dec 3, 2023
d6daeb5
Experimental custom TI impl.
gabe56f Dec 6, 2023
49f350c
Merge remote-tracking branch 'origin/experimental' into feature/deep-…
gabe56f Dec 6, 2023
c0a319f
Enable Upscale everywhere, prep for highres everywhere
Stax124 Dec 6, 2023
415f060
SASolver
gabe56f Dec 6, 2023
4c9f211
Scalecrafter
gabe56f Dec 9, 2023
97c9609
format
gabe56f Dec 9, 2023
c3c62e1
Forgot something
gabe56f Dec 9, 2023
e0b1f1d
Global postprocessing (Hires, Upscale, more to come)
Stax124 Dec 9, 2023
10d5c69
Merge remote-tracking branch 'origin/experimental' into feature/deep-…
gabe56f Dec 9, 2023
c5288d1
merge
gabe56f Dec 9, 2023
c5152cf
WIP: Frontend for deepshrink and scalecrafter
Stax124 Dec 10, 2023
4c0c1ed
Animatediff
gabe56f Dec 11, 2023
0f2b7f4
AnimateDiff v2
gabe56f Dec 11, 2023
d1a9b42
channels last 3d
gabe56f Dec 12, 2023
7cbacae
fix circular import
gabe56f Dec 12, 2023
84fa9e9
SAG with SDPA
gabe56f Dec 12, 2023
982bfb6
format
gabe56f Dec 12, 2023
c514e62
DeepShrink and scalecrafter default settings, better UI
Stax124 Dec 12, 2023
d9251b0
Update HighResFix visuals for other tabs
Stax124 Dec 12, 2023
37d0c86
Less VRAM usage
gabe56f Dec 13, 2023
4a8d189
WIP: Volta for mobile devices
Stax124 Dec 13, 2023
dc9c940
todo list
gabe56f Dec 14, 2023
acb108a
Merge branch 'feature/animate' of https://github.com/voltaML/voltaML-…
gabe56f Dec 14, 2023
dc5fbe3
initial freeinit
gabe56f Dec 14, 2023
356fb85
16 frames
gabe56f Dec 14, 2023
7371437
maybe fix euler
gabe56f Dec 14, 2023
e44af82
Merge branch 'experimental' into feature/deep-shrink
Stax124 Dec 15, 2023
0eb89e8
Fix icon bug, rebuild frontend
Stax124 Dec 15, 2023
86d1a82
Merge pull request #172 from VoltaML/feature/deep-shrink
Stax124 Dec 15, 2023
c2bd605
Free U v2, better FreeU settings, diffusers bump to v0.24.0
Stax124 Dec 15, 2023
ccfa548
Add GIF as valid extensions to API ouput paths
Stax124 Dec 15, 2023
4be30e7
AnimateDiff v3
gabe56f Dec 15, 2023
a8b970b
Prepare class for Adetailer
Stax124 Dec 17, 2023
d32c30a
Update defaults, fix path model loading issues when name was differen…
Stax124 Dec 19, 2023
988524a
Merge branch 'experimental' into adetailer
Stax124 Dec 19, 2023
ac7b50b
wrap up backend, redo frontend settings types
Stax124 Dec 20, 2023
c31274b
Faster-Diffusion + DeepCache (fast-diff is working)
gabe56f Dec 20, 2023
7f5b410
Merge remote-tracking branch 'origin/experimental' into feature/animate
gabe56f Dec 21, 2023
1e9a20d
some fixes
gabe56f Dec 21, 2023
c34f58e
fix all launch problems?
gabe56f Dec 21, 2023
9f8caf4
Format, bring SDXL up-to-date, documentation
gabe56f Dec 21, 2023
9a079b4
ADetailer frontend progress
Stax124 Dec 24, 2023
3ceacc8
ADetailer somewhat working
Stax124 Dec 26, 2023
de521e8
PIA support + fix loads of things
gabe56f Dec 29, 2023
3f7fbc1
Improve loading times from single files
gabe56f Dec 29, 2023
c218fdf
Further loading improvements & fixes
gabe56f Dec 29, 2023
eb984ce
Fix control net not returning preprocessed images
Stax124 Dec 29, 2023
68e1a1b
Update default sampler settings
Stax124 Dec 29, 2023
a6ea7ab
Improve loading times
gabe56f Dec 29, 2023
8cc0123
Add upscale and iterations settings
Stax124 Dec 30, 2023
37ab58b
Update websocket target for inpainting, remove extra comma from defau…
Stax124 Dec 30, 2023
8ed6df9
Disable iterations for production env as they are currently broken
Stax124 Dec 30, 2023
fbe4d06
Fix SDXL loading & other refactors
gabe56f Dec 31, 2023
2b9b0eb
j is == i.
gabe56f Dec 31, 2023
1211e5f
Implement SDXL support in the backend
gabe56f Dec 31, 2023
d0cd806
Fix a small oopsie
gabe56f Dec 31, 2023
b18ca52
Merge pull request #180 from VoltaML/adetailer
Stax124 Dec 31, 2023
35ecc60
Merge remote-tracking branch 'origin/experimental' into feature/animate
gabe56f Dec 31, 2023
e42b247
Daily pydantic hate
gabe56f Dec 31, 2023
2d6fea3
fix sdxl
gabe56f Jan 2, 2024
21 changes: 21 additions & 0 deletions .github/CODEOWNERS
@@ -0,0 +1,21 @@
# Most of the stuff is Staxs
# core is somewhat general
* @Stax124
/core/ @Stax124 @gabe56f
/frontend/ @Stax124 @gabe56f

# Stax-specific
/core/* @Stax124
/core/inference/esrgan/ @Stax124

# Gabe-specific stuff
/libs/ @gabe56f
/core/scheduling/ @gabe56f
/core/optimizations/ @gabe56f
/core/inference/injectables/ @gabe56f
/core/inference/utilities/kohya_hires.py @gabe56f
/core/inference/utilities/anisotropic.py @gabe56f
/core/inference/utilities/cfg.py @gabe56f
/core/inference/utilities/sag/ @gabe56f
/core/inference/utilities/prompt_expansion/ @gabe56f
/core/inference/onnx/ @gabe56f
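For readers unfamiliar with CODEOWNERS semantics: rules are evaluated top to bottom and the last matching pattern wins, which is why the blanket `* @Stax124` entry can coexist with the Gabe-specific paths below it. A toy Python matcher sketching that last-match-wins behaviour (simplified pattern handling, not GitHub's actual implementation):

```python
from fnmatch import fnmatch

# Subset of the CODEOWNERS rules above, in file order.
RULES = [
    ("*", ["@Stax124"]),
    ("/core/*", ["@Stax124"]),
    ("/core/scheduling/", ["@gabe56f"]),
]

def owners_for(path: str) -> list[str]:
    result: list[str] = []
    for pattern, owners in RULES:
        if pattern == "*":
            matched = True
        elif pattern.endswith("/"):
            # Directory patterns match everything under that directory
            matched = path.startswith(pattern.lstrip("/"))
        else:
            matched = fnmatch(path, pattern.lstrip("/"))
        if matched:
            result = owners  # later rules override earlier ones
    return result

print(owners_for("core/scheduling/unet.py"))  # ['@gabe56f']
print(owners_for("core/gpu.py"))              # ['@Stax124']
```

So a scheduling file ends up owned by @gabe56f even though two earlier, broader rules also match it.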
27 changes: 27 additions & 0 deletions .github/workflows/docker_build_tag.yml
@@ -0,0 +1,27 @@
name: Docker Build on Tag Push

on:
push:
tags:
- "v*"

jobs:
build:
runs-on: ubuntu-latest
steps:
- name: Checkout
uses: actions/checkout@v4
- name: Login to Docker Hub
uses: docker/login-action@v3
with:
username: ${{ secrets.DOCKERHUB_USERNAME }}
password: ${{ secrets.DOCKERHUB_TOKEN }}
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Build and push
uses: docker/build-push-action@v5
with:
context: .
file: docker/cuda/dockerfile
push: true
tags: ${{ secrets.DOCKERHUB_USERNAME }}/volta:${{ github.ref_name }}-cuda
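The workflow only fires for pushed tags matching `v*` (for example, `git tag v0.5.0 && git push origin v0.5.0` would trigger a `volta:v0.5.0-cuda` build). A tiny shell sketch of that trigger filter, for illustration only; GitHub evaluates the pattern server-side:

```shell
# Mirror of the "v*" tag filter in the workflow's on.push.tags list.
matches_release_tag() {
    case "$1" in
        v*) echo "build" ;;
        *)  echo "skip" ;;
    esac
}

matches_release_tag "v1.2.3"   # build
matches_release_tag "main"     # skip
```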
6 changes: 3 additions & 3 deletions .gitignore
@@ -48,12 +48,12 @@ cover/
# Diffusers convert files
out
traced_unet/
-onnx
+/onnx
converted

# Docker
-test.docker-compose.yml
-test-no-mount.docker-compose.yml
+*test.docker-compose.yml
+*test-no-mount.docker-compose.yml

# Docs
node_modules/
20 changes: 10 additions & 10 deletions .vscode/launch.json
@@ -1,12 +1,12 @@
{
-"configurations": [
-{
-"name": "Python: File",
-"type": "python",
-"request": "launch",
-"program": "main.py",
-"args": ["--log-level=DEBUG"],
-"justMyCode": false
-}
-]
+"configurations": [
+{
+"name": "VoltaML API Debug",
+"type": "python",
+"request": "launch",
+"program": "main.py",
+"args": ["--log-level=DEBUG"],
+"justMyCode": false
+}
+]
}
7 changes: 2 additions & 5 deletions .vscode/settings.json
@@ -4,11 +4,8 @@
"python.testing.pytestEnabled": true,
"python.analysis.typeCheckingMode": "basic",
"python.languageServer": "Pylance",
-"rust-analyzer.linkedProjects": [
-"./manager/Cargo.toml"
-],
+"rust-analyzer.linkedProjects": ["./manager/Cargo.toml"],
"[python]": {
"editor.defaultFormatter": "ms-python.black-formatter"
},
-"python.formatting.provider": "none"
-}
+}
33 changes: 29 additions & 4 deletions api/app.py
@@ -19,7 +19,9 @@
from api.websockets.data import Data
from api.websockets.notification import Notification
from core import shared
from core.files import get_full_model_path
from core.types import InferenceBackend
from core.utils import determine_model_type

logger = logging.getLogger(__name__)

@@ -47,12 +49,33 @@ async def validation_exception_handler(_request: Request, exc: RequestValidation

logger.debug(exc)

+    if exc._error_cache is not None and exc._error_cache[0]["loc"][0] == "body":
+        from core.config._config import Configuration
+
+        default_value = Configuration()
+        keys = [str(i) for i in exc._error_cache[0]["loc"][1:]]  # type: ignore
+        current_value = exc._error_cache[0]["ctx"]["given"]  # type: ignore
+
+        # Traverse the config object to find the correct value
+        for key in keys:
+            default_value = getattr(default_value, key)
+
+        websocket_manager.broadcast_sync(
+            data=Data(
+                data={
+                    "default_value": default_value,
+                    "key": keys,
+                    "current_value": current_value,
+                },
+                data_type="incorrect_settings_value",
+            )
+        )

try:
why = str(exc).split(":")[1].strip()
-        await websocket_manager.broadcast(
+        websocket_manager.broadcast_sync(
data=Notification(
severity="error",
-                message=f"Validation error: {why}",
+                message="Validation error",
+                title="Validation Error",
)
)
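The handler's new default-value lookup walks a fresh `Configuration` object key by key with `getattr`, following the key path reported by the validation error. A minimal standalone sketch of that traversal, using hypothetical stand-in dataclasses rather than Volta's real config tree:

```python
from dataclasses import dataclass, field
from functools import reduce

# Hypothetical stand-ins for the real Configuration hierarchy.
@dataclass
class Txt2Img:
    steps: int = 25

@dataclass
class Config:
    txt2img: Txt2Img = field(default_factory=Txt2Img)

def default_for(keys: list[str]):
    # Same idea as the handler: start from a fresh default config and
    # follow each key with getattr until the leaf value is reached.
    return reduce(getattr, keys, Config())

print(default_for(["txt2img", "steps"]))  # 25
```

This is what lets the frontend offer a sensible fallback when a settings value fails validation.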
@@ -130,7 +153,9 @@ async def startup_event():
for model in config.api.autoloaded_models:
if model in [i.path for i in all_models]:
backend: InferenceBackend = [i.backend for i in all_models if i.path == model][0] # type: ignore
-            gpu.load_model(model, backend)
+            model_type = determine_model_type(get_full_model_path(model))[1]
+
+            gpu.load_model(model, backend, type=model_type)
else:
logger.warning(f"Autoloaded model {model} not found, skipping")

31 changes: 21 additions & 10 deletions api/routes/models.py
@@ -20,10 +20,11 @@
DeleteModelRequest,
InferenceBackend,
ModelResponse,
PyTorchModelBase,
TextualInversionLoadRequest,
VaeLoadRequest,
)
from core.utils import download_file
from core.utils import determine_model_type, download_file

router = APIRouter(tags=["models"])
logger = logging.getLogger(__name__)
@@ -62,9 +63,11 @@ def list_loaded_models() -> List[ModelResponse]:

loaded_models = []
for model_id in gpu.loaded_models:
name, type_, stage = determine_model_type(get_full_model_path(model_id))

loaded_models.append(
ModelResponse(
-                name=Path(model_id).name
+                name=name
if (".ckpt" in model_id) or (".safetensors" in model_id)
else model_id,
backend=gpu.loaded_models[model_id].backend,
@@ -75,6 +78,8 @@ def list_loaded_models() -> List[ModelResponse]:
"textual_inversions", []
),
valid=True,
stage=stage,
type=type_,
)
)

@@ -92,11 +97,12 @@ def list_available_models() -> List[ModelResponse]:
def load_model(
model: str,
backend: InferenceBackend,
type: PyTorchModelBase,
):
"Loads a model into memory"

try:
-        gpu.load_model(model, backend)
+        gpu.load_model(model, backend, type)

websocket_manager.broadcast_sync(data=Data(data_type="refresh_models", data={}))
except torch.cuda.OutOfMemoryError: # type: ignore
@@ -106,7 +112,7 @@ def load_model(


@router.post("/unload")
-async def unload_model(model: str):
+def unload_model(model: str):
"Unloads a model from memory"

gpu.unload(model)
@@ -125,7 +131,7 @@ def unload_all_models():


@router.post("/load-vae")
-async def load_vae(req: VaeLoadRequest):
+def load_vae(req: VaeLoadRequest):
"Load a VAE into a model"

gpu.load_vae(req)
@@ -134,7 +140,7 @@ async def load_vae(req: VaeLoadRequest):


@router.post("/load-textual-inversion")
-async def load_textual_inversion(req: TextualInversionLoadRequest):
+def load_textual_inversion(req: TextualInversionLoadRequest):
"Load a Textual Inversion model into a model"

gpu.load_textual_inversion(req)
@@ -143,15 +149,15 @@ async def load_textual_inversion(req: TextualInversionLoadRequest):


@router.post("/memory-cleanup")
-async def cleanup():
+def cleanup():
"Free up memory manually"

gpu.memory_cleanup()
return {"message": "Memory cleaned up"}


@router.post("/download")
-async def download_model(model: str):
+def download_model(model: str):
"Download a model to the cache"

gpu.download_huggingface_model(model)
@@ -243,7 +249,7 @@ def delete_model(req: DeleteModelRequest):

@router.post("/download-model")
def download_checkpoint(
-    link: str, model_type: Literal["Checkpoint", "TextualInversion", "LORA"]
+    link: str, model_type: Literal["Checkpoint", "TextualInversion", "LORA", "VAE"]
) -> str:
"Download a model from a link and return the path to the downloaded file."

@@ -254,7 +260,12 @@ def download_checkpoint(
folder = "textual-inversion"
elif mtype == "lora":
folder = "lora"
elif mtype == "vae":
folder = "vae"
else:
raise ValueError(f"Unknown model type {mtype}")

-    return download_file(link, Path("data") / folder, True).as_posix()
+    saved_path = download_file(link, Path("data") / folder, True).as_posix()
+    websocket_manager.broadcast_sync(Data(data_type="refresh_models", data={}))
+
+    return saved_path
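The folder-selection chain in `download_checkpoint` could equally be expressed as a dict lookup. A sketch of that alternative; note the `models` folder name for the checkpoint branch is an assumption, since that branch is collapsed in the diff:

```python
from pathlib import Path

# Lowercased model type -> target folder, mirroring the if/elif chain.
# "models" for checkpoints is assumed; the other entries come from the diff.
FOLDERS = {
    "checkpoint": "models",
    "textualinversion": "textual-inversion",
    "lora": "lora",
    "vae": "vae",
}

def target_dir(model_type: str) -> Path:
    try:
        return Path("data") / FOLDERS[model_type.lower()]
    except KeyError:
        # Same failure mode as the else branch in the PR
        raise ValueError(f"Unknown model type {model_type}") from None

print(target_dir("VAE").as_posix())  # data/vae
```

A dict lookup keeps the mapping in one place and makes adding the next model type a one-line change.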
2 changes: 1 addition & 1 deletion api/routes/outputs.py
@@ -12,7 +12,7 @@
thread_pool = ThreadPoolExecutor()
logger = logging.getLogger(__name__)

-valid_extensions = ["png", "jpeg", "webp"]
+valid_extensions = ["png", "jpeg", "webp", "gif"]


def sort_images(images: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
2 changes: 1 addition & 1 deletion api/routes/settings.py
@@ -4,7 +4,7 @@
from fastapi import APIRouter

from core import config
-from core.config.config import update_config
+from core.config._config import update_config

router = APIRouter(tags=["settings"])

10 changes: 6 additions & 4 deletions core/config/__init__.py
@@ -1,14 +1,16 @@
from pathlib import Path

-from diffusers.utils.constants import DIFFUSERS_CACHE
+from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE as DIFFUSERS_CACHE

-from .config import (
+from ._config import (
     Configuration,
-    Img2ImgConfig,
-    Txt2ImgConfig,
     load_config,
     save_config,
 )
+from .default_settings import (
+    Txt2ImgConfig,
+    Img2ImgConfig,
+)

config = load_config()

80 changes: 80 additions & 0 deletions core/config/_config.py
@@ -0,0 +1,80 @@
import logging
from dataclasses import Field, dataclass, field, fields

from dataclasses_json import CatchAll, DataClassJsonMixin, Undefined, dataclass_json

from core.config.samplers.sampler_config import SamplerConfig

from .api_settings import APIConfig
from .bot_settings import BotConfig
from .default_settings import (
AITemplateConfig,
ControlNetConfig,
Img2ImgConfig,
InpaintingConfig,
ONNXConfig,
Txt2ImgConfig,
UpscaleConfig,
)
from .flags_settings import FlagsConfig
from .frontend_settings import FrontendConfig
from .interrogator_settings import InterrogatorConfig

logger = logging.getLogger(__name__)


@dataclass_json(undefined=Undefined.INCLUDE)
@dataclass
class Configuration(DataClassJsonMixin):
"Main configuration class for the application"

txt2img: Txt2ImgConfig = field(default_factory=Txt2ImgConfig)
img2img: Img2ImgConfig = field(default_factory=Img2ImgConfig)
inpainting: InpaintingConfig = field(default_factory=InpaintingConfig)
controlnet: ControlNetConfig = field(default_factory=ControlNetConfig)
upscale: UpscaleConfig = field(default_factory=UpscaleConfig)
api: APIConfig = field(default_factory=APIConfig)
interrogator: InterrogatorConfig = field(default_factory=InterrogatorConfig)
aitemplate: AITemplateConfig = field(default_factory=AITemplateConfig)
onnx: ONNXConfig = field(default_factory=ONNXConfig)
bot: BotConfig = field(default_factory=BotConfig)
frontend: FrontendConfig = field(default_factory=FrontendConfig)
sampler_config: SamplerConfig = field(default_factory=SamplerConfig)
flags: FlagsConfig = field(default_factory=FlagsConfig)
extra: CatchAll = field(default_factory=dict)


def save_config(config: Configuration):
"Save the configuration to a file"

logger.info("Saving configuration to data/settings.json")

with open("data/settings.json", "w", encoding="utf-8") as f:
f.write(config.to_json(ensure_ascii=False, indent=4))


def update_config(config: Configuration, new_config: Configuration):
"Update the configuration with new values instead of overwriting the pointer"

for cls_field in fields(new_config):
assert isinstance(cls_field, Field)
setattr(config, cls_field.name, getattr(new_config, cls_field.name))


def load_config():
"Load the configuration from a file"

logger.info("Loading configuration from data/settings.json")

try:
with open("data/settings.json", "r", encoding="utf-8") as f:
config = Configuration.from_json(f.read())
logger.info("Configuration loaded from data/settings.json")
return config

except FileNotFoundError:
logger.info("data/settings.json not found, creating a new one")
config = Configuration()
save_config(config)
logger.info("Configuration saved to data/settings.json")
return config
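`update_config` deliberately copies field by field onto the existing object instead of rebinding the module-level `config` name, so every module that already imported the object keeps seeing fresh values. A minimal sketch of that in-place update pattern with a stand-in dataclass (not Volta's full `Configuration`):

```python
from dataclasses import dataclass, fields

@dataclass
class Settings:  # minimal stand-in for the full Configuration
    steps: int = 25
    sampler: str = "euler_a"

def update_in_place(config: Settings, new_config: Settings) -> None:
    # Copy each field onto the existing object; rebinding the name instead
    # would leave other modules holding a stale reference.
    for f in fields(new_config):
        setattr(config, f.name, getattr(new_config, f.name))

shared_ref = Settings()
update_in_place(shared_ref, Settings(steps=40, sampler="dpmpp_2m"))
print(shared_ref.steps)  # 40
```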