Merge pull request #153 from VoltaML/feature/kdiffusion
Major refactoring & K-Diffusion
Stax124 authored Oct 15, 2023
2 parents 6b7fcad + 44f4748 commit 7f57d0d
Showing 119 changed files with 8,189 additions and 5,163 deletions.
16 changes: 9 additions & 7 deletions api/app.py
@@ -5,16 +5,14 @@
from pathlib import Path

from api_analytics.fastapi import Analytics
from fastapi import Depends, FastAPI, Request
from fastapi import Depends, FastAPI, Request, status
from fastapi.exceptions import RequestValidationError
from fastapi.responses import FileResponse
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse, JSONResponse
from fastapi.staticfiles import StaticFiles
from fastapi_simple_cachecontrol.middleware import CacheControlMiddleware
from fastapi_simple_cachecontrol.types import CacheControl
from huggingface_hub.hf_api import LocalTokenNotFoundError
from starlette import status
from starlette.middleware.cors import CORSMiddleware
from starlette.responses import JSONResponse

from api import websocket_manager
from api.routes import static, ws
@@ -188,16 +186,20 @@ async def shutdown_event():
static_app.mount("/", StaticFiles(directory="frontend/dist/assets"), name="assets")
app.mount("/assets", static_app)

origins = ["*"]

# Allow CORS for specified origins
app.add_middleware(
CORSMiddleware,
allow_origins=origins,
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
expose_headers=["*"],
)
static_app.add_middleware(
CORSMiddleware,
allow_origins=origins,
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
expose_headers=["*"],
)
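
This hunk moves the starlette imports over to FastAPI's re-exports and mirrors the permissive CORS policy onto the static sub-application. A condensed sketch of the resulting setup (the loop is shorthand; the diff writes the two add_middleware calls out explicitly):

from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware  # re-exported from starlette

app = FastAPI()
static_app = FastAPI()
app.mount("/assets", static_app)

origins = ["*"]

# The diff applies the same policy to the main app and, explicitly,
# to the mounted static sub-app as well.
for application in (app, static_app):
    application.add_middleware(
        CORSMiddleware,
        allow_origins=origins,
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
        expose_headers=["*"],
    )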
12 changes: 9 additions & 3 deletions api/routes/outputs.py
@@ -15,6 +15,12 @@
valid_extensions = ["png", "jpeg", "webp"]


def sort_images(images: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
"Sort images by time"

return sorted(images, key=lambda x: x["time"], reverse=True)


@router.get("/txt2img")
def txt2img() -> List[Dict[str, Any]]:
"List all generated images"
@@ -31,7 +37,7 @@ def txt2img() -> List[Dict[str, Any]]:
{"path": i.as_posix(), "time": os.path.getmtime(i), "id": Path(i).stem}
)

return data
return sort_images(data)


@router.get("/img2img")
@@ -50,7 +56,7 @@ def img2img() -> List[Dict[str, Any]]:
{"path": i.as_posix(), "time": os.path.getmtime(i), "id": Path(i).stem}
)

return data
return sort_images(data)


@router.get("/extra")
@@ -69,7 +75,7 @@ def extra() -> List[Dict[str, Any]]:
{"path": i.as_posix(), "time": os.path.getmtime(i), "id": Path(i).stem}
)

return data
return sort_images(data)


@router.get("/data")
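
The new sort_images helper is what gives each gallery route its newest-first ordering; a standalone sketch with illustrative data (the paths and timestamps are made up):

from typing import Any, Dict, List

def sort_images(images: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    "Sort images by modification time, newest first"
    return sorted(images, key=lambda x: x["time"], reverse=True)

# Entries shaped the way the route handlers build them (path / mtime / stem).
data = [
    {"path": "outputs/txt2img/a.png", "time": 1_697_000_000.0, "id": "a"},
    {"path": "outputs/txt2img/b.png", "time": 1_697_100_000.0, "id": "b"},
]
assert [i["id"] for i in sort_images(data)] == ["b", "a"]  # newest first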
7 changes: 2 additions & 5 deletions api/routes/settings.py
@@ -16,11 +16,8 @@ async def save_configuration(settings: config.Configuration):
"Receive settings from the frontend and save them to the config file"

reload_required = False
if config.config.api.device_id != settings.api.device_id:
logger.info(f"Device ID was changed to {settings.api.device_id}")
reload_required = True
if config.config.api.device_type != settings.api.device_type:
logger.info(f"Device type was changed to {settings.api.device_type}")
if config.config.api.device != settings.api.device:
logger.info(f"Device was changed to {settings.api.device}")
reload_required = True
if config.config.api.data_type != settings.api.data_type:
logger.info(f"Precision changed to {settings.api.data_type}")
3 changes: 1 addition & 2 deletions api/routes/ws.py
@@ -1,6 +1,5 @@
from fastapi import APIRouter
from fastapi.websockets import WebSocket
from starlette.websockets import WebSocketDisconnect
from fastapi.websockets import WebSocket, WebSocketDisconnect

from api import websocket_manager
from api.websockets.data import Data
8 changes: 4 additions & 4 deletions api/websockets/manager.py
@@ -119,9 +119,9 @@ def disconnect(self, websocket: WebSocket):
and config.api.clear_memory_policy == "after_disconnect"
):
if torch.cuda.is_available():
logger.debug(f"Cleaning up GPU memory: {config.api.device_id}")
logger.debug(f"Cleaning up GPU memory: {config.api.device}")

with torch.cuda.device(config.api.device_id):
with torch.device(config.api.device):
torch.cuda.empty_cache()
torch.cuda.ipc_collect()

@@ -177,9 +177,9 @@ async def close_all(self):

if config.api.clear_memory_policy == "after_disconnect":
if torch.cuda.is_available():
logger.debug(f"Cleaning up GPU memory: {config.api.device_id}")
logger.debug(f"Cleaning up GPU memory: {config.api.device}")

with torch.cuda.device(config.api.device_id):
with torch.cuda.device(config.api.device):
torch.cuda.empty_cache()
torch.cuda.ipc_collect()

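
Both cleanup paths now take the flat device string from config.api.device (note the two hunks use torch.device and torch.cuda.device respectively; the sketch below follows the latter). A minimal sketch, assuming a string such as "cuda:0" — torch.cuda.device accepts a device string or index:

import torch

def clear_gpu_memory(device: str = "cuda:0") -> None:
    # No-op on machines without CUDA.
    if not torch.cuda.is_available():
        return
    # Scope the cleanup to the configured device.
    with torch.cuda.device(device):
        torch.cuda.empty_cache()   # release cached blocks back to the driver
        torch.cuda.ipc_collect()   # reclaim memory held by CUDA IPC handles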
123 changes: 49 additions & 74 deletions core/config/config.py
@@ -7,9 +7,29 @@
from dataclasses_json import CatchAll, DataClassJsonMixin, Undefined, dataclass_json
from diffusers.schedulers.scheduling_utils import KarrasDiffusionSchedulers

from core.config.samplers.sampler_config import SamplerConfig
from core.types import SigmaScheduler

logger = logging.getLogger(__name__)


@dataclass
class BaseDiffusionMixin:
width: int = 512
height: int = 512
batch_count: int = 1
batch_size: int = 1
seed: int = -1
cfg_scale: int = 7
steps: int = 40
prompt: str = ""
negative_prompt: str = ""
sampler: Union[
int, str
] = KarrasDiffusionSchedulers.DPMSolverSinglestepScheduler.value
sigmas: SigmaScheduler = "automatic"
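
The duplicated per-pipeline fields collapse into this single base class; a trimmed sketch of how the inheritance behaves (only a few fields kept for brevity):

from dataclasses import dataclass

@dataclass
class BaseDiffusionMixin:
    width: int = 512
    height: int = 512
    steps: int = 40
    prompt: str = ""

@dataclass
class Txt2ImgConfig(BaseDiffusionMixin):
    self_attention_scale: float = 0.0

cfg = Txt2ImgConfig(steps=25)
print(cfg.width, cfg.steps, cfg.self_attention_scale)  # 512 25 0.0

Because every field on the mixin carries a default, subclasses such as Txt2ImgConfig remain valid dataclasses when they append their own defaulted fields.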


@dataclass
class QuantDict:
vae_decoder: Optional[bool] = None
@@ -19,72 +39,32 @@ class QuantDict:


@dataclass
class Txt2ImgConfig:
class Txt2ImgConfig(BaseDiffusionMixin):
"Configuration for the text to image pipeline"

width: int = 512
height: int = 512
seed: int = -1
cfg_scale: int = 7
sampler: int = KarrasDiffusionSchedulers.DPMSolverSinglestepScheduler.value
prompt: str = ""
negative_prompt: str = ""
steps: int = 40
batch_count: int = 1
batch_size: int = 1
self_attention_scale: float = 0.0


@dataclass
class Img2ImgConfig:
class Img2ImgConfig(BaseDiffusionMixin):
"Configuration for the image to image pipeline"

width: int = 512
height: int = 512
seed: int = -1
cfg_scale: int = 7
sampler: int = KarrasDiffusionSchedulers.DPMSolverSinglestepScheduler.value
prompt: str = ""
negative_prompt: str = ""
steps: int = 40
batch_count: int = 1
batch_size: int = 1
resize_method: int = 0
denoising_strength: float = 0.6
self_attention_scale: float = 0.0


@dataclass
class InpaintingConfig:
class InpaintingConfig(BaseDiffusionMixin):
"Configuration for the inpainting pipeline"

prompt: str = ""
negative_prompt: str = ""
width: int = 512
height: int = 512
steps: int = 40
cfg_scale: int = 7
seed: int = -1
batch_count: int = 1
batch_size: int = 1
sampler: int = KarrasDiffusionSchedulers.DPMSolverSinglestepScheduler.value
self_attention_scale: float = 0.0


@dataclass
class ControlNetConfig:
class ControlNetConfig(BaseDiffusionMixin):
"Configuration for the inpainting pipeline"

prompt: str = ""
negative_prompt: str = ""
width: int = 512
height: int = 512
seed: int = -1
cfg_scale: int = 7
steps: int = 40
batch_count: int = 1
batch_size: int = 1
sampler: int = KarrasDiffusionSchedulers.DPMSolverSinglestepScheduler.value
controlnet: str = "lllyasviel/sd-controlnet-canny"
controlnet_conditioning_scale: float = 1.0
detection_resolution: int = 512
@@ -107,6 +87,11 @@ class UpscaleConfig:
class APIConfig:
"Configuration for the API"

# Autoload
autoloaded_textual_inversions: List[str] = field(default_factory=list)
autoloaded_models: List[str] = field(default_factory=list)
autoloaded_vae: Dict[str, str] = field(default_factory=dict)

# Websockets and intervals
websocket_sync_interval: float = 0.02
websocket_perf_interval: float = 1.0
@@ -116,8 +101,6 @@ class APIConfig:
tomesd_ratio: float = 0.25 # had to tone this down, 0.4 is too big of a context loss even on short prompts
tomesd_downsample_layers: Literal[1, 2, 4, 8] = 1

image_preview_delay: float = 2.0

# General optimizations
autocast: bool = False
attention_processor: Literal[
@@ -126,8 +109,6 @@ class APIConfig:
subquadratic_size: int = 512
attention_slicing: Union[int, Literal["auto", "disabled"]] = "disabled"
channels_last: bool = True
vae_slicing: bool = True
vae_tiling: bool = False
trace_model: bool = False
clear_memory_policy: Literal["always", "after_disconnect", "never"] = "always"
offload: Literal["module", "model", "disabled"] = "disabled"
@@ -139,8 +120,7 @@ class APIConfig:
deterministic_generation: bool = False

# Device settings
device_id: int = 0
device_type: Literal["cpu", "cuda", "mps", "directml", "intel", "vulkan"] = "cuda"
device: str = "cuda:0"

# Critical
enable_shutdown: bool = True
@@ -149,11 +129,6 @@ class APIConfig:
clip_skip: int = 1
clip_quantization: Literal["full", "int8", "int4"] = "full"

# Autoload
autoloaded_textual_inversions: List[str] = field(default_factory=list)
autoloaded_models: List[str] = field(default_factory=list)
autoloaded_vae: Dict[str, str] = field(default_factory=dict)

huggingface_style_parsing: bool = False

# Saving
@@ -176,6 +151,19 @@ class APIConfig:
"max-autotune",
] = "reduce-overhead"

# K_Diffusion
sgm_noise_multiplier: bool = False # also known as "alternate DDIM ODE"
kdiffusers_quantization: bool = True # improves sampling quality

# "philox" is what a "cuda" generator would be, except, it's on cpu
generator: Literal["device", "cpu", "philox"] = "device"

# VAE
live_preview_method: Literal["disabled", "approximation", "taesd"] = "approximation"
live_preview_delay: float = 2.0
vae_slicing: bool = True
vae_tiling: bool = False
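
The vae_slicing and vae_tiling flags correspond to toggles that diffusers exposes on its pipelines; a sketch of how they would be applied, assuming a stock StableDiffusionPipeline (the model id is illustrative):

from diffusers import StableDiffusionPipeline

# Flags mirror the config defaults above.
vae_slicing, vae_tiling = True, False

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
if vae_slicing:
    pipe.enable_vae_slicing()  # decode batched latents one image at a time
if vae_tiling:
    pipe.enable_vae_tiling()   # decode large latents in overlapping tiles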

@property
def dtype(self):
"Return selected data type"
@@ -186,26 +174,12 @@ def dtype(self):
return torch.float32

@property
def device(self):
"Return the device"

if self.device_type == "intel":
from core.inference.functions import is_ipex_available

return torch.device("xpu" if is_ipex_available() else "cpu")

if self.device_type in ["cpu", "mps"]:
return torch.device(self.device_type)

if self.device_type in ["vulkan", "cuda"]:
return torch.device(f"{self.device_type}:{self.device_id}")

if self.device_type == "directml":
import torch_directml # pylint: disable=import-error
def overwrite_generator(self) -> bool:
"Whether the generator needs to be overwritten with 'cpu.'"

return torch_directml.device()
else:
raise ValueError(f"Device type {self.device_type} not supported")
return any(
map(lambda x: x in self.device, ["mps", "directml", "vulkan", "intel"])
)
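
With device_type and device_id folded into one string, backend checks reduce to substring tests; a sketch of how the property might feed seeding (make_generator is a hypothetical helper for illustration, not this codebase's API):

import torch

def overwrite_generator(device: str) -> bool:
    # Backends without full torch.Generator support fall back to CPU.
    return any(x in device for x in ["mps", "directml", "vulkan", "intel"])

def make_generator(device: str, seed: int) -> torch.Generator:  # hypothetical helper
    target = "cpu" if overwrite_generator(device) else device
    return torch.Generator(target).manual_seed(seed)

gen = make_generator("cuda:0", seed=42)   # stays on cuda:0
gen_cpu = make_generator("mps", seed=42)  # falls back to cpu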


@dataclass
Expand Down Expand Up @@ -279,6 +253,7 @@ class Configuration(DataClassJsonMixin):
onnx: ONNXConfig = field(default_factory=ONNXConfig)
bot: BotConfig = field(default_factory=BotConfig)
frontend: FrontendConfig = field(default_factory=FrontendConfig)
sampler_config: SamplerConfig = field(default_factory=SamplerConfig)
extra: CatchAll = field(default_factory=dict)


