fixes (#147)
* fixes

* rm console log

* dirty patch regex

* refact

* detail

* fix tests
ljleb authored Aug 18, 2023
1 parent 0749265 commit a332cfa
Showing 11 changed files with 104 additions and 70 deletions.
23 changes: 11 additions & 12 deletions comfyui_custom_nodes/webui_io.py
@@ -1,17 +1,12 @@
from lib_comfyui import global_state


class AnyType(str):
def __ne__(self, _) -> bool:
return False
class StaticProperty(object):
def __init__(self, f):
self.f = f


class AnyReturnTypes(tuple):
def __init__(self):
super().__init__()

def __getitem__(self, _):
return AnyType()
def __get__(self, *args):
return self.f()


class FromWebui:
@@ -22,15 +17,19 @@ def INPUT_TYPES(cls):
"void": ("VOID", ),
},
}
RETURN_TYPES = AnyReturnTypes()

@StaticProperty
def RETURN_TYPES():
return getattr(global_state, "current_workflow_input_types", ())

RETURN_NAMES = ()
FUNCTION = "get_node_inputs"

CATEGORY = "webui"

@staticmethod
def get_node_inputs(void):
return global_state.node_input_args
return global_state.node_inputs


class ToWebui:
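
Aside (illustrative sketch, not part of the commit): the diff above replaces the AnyType/AnyReturnTypes hack with a descriptor, so ComfyUI reads RETURN_TYPES freshly from global_state on every access. A minimal runnable sketch of that behaviour, where the Example class and the current_types global stand in for the real node and for global_state.current_workflow_input_types:

class StaticProperty:
    """Descriptor that re-runs its wrapped function on every class attribute access."""
    def __init__(self, f):
        self.f = f

    def __get__(self, *args):
        return self.f()


current_types = ()  # stand-in for global_state.current_workflow_input_types

class Example:
    @StaticProperty
    def RETURN_TYPES():
        return current_types

print(Example.RETURN_TYPES)  # () -- no workflow is running
current_types = ("IMAGE", "LATENT")
print(Example.RETURN_TYPES)  # ('IMAGE', 'LATENT') -- picked up without reloading the class
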
17 changes: 11 additions & 6 deletions comfyui_custom_scripts/extensions/webuiPatches.js
@@ -48,15 +48,11 @@ async function patchDefaultGraph(iframeInfo) {
}

app.original_loadGraphData = app.loadGraphData;
app.loadGraphData = (graphData) => {
if (graphData) {
const doLoadGraphData = graphData => {
if (graphData !== "auto") {
return app.original_loadGraphData(graphData);
}

if (iframeInfo.defaultWorkflow !== "auto") {
return app.original_loadGraphData(iframeInfo.defaultWorkflow);
}

app.graph.clear();

const from_webui = LiteGraph.createNode("FromWebui");
@@ -73,6 +69,15 @@ async function patchDefaultGraph(iframeInfo) {
app.graph.arrange();
};

app.loadGraphData = (graphData) => {
if (graphData) {
return doLoadGraphData(graphData);
}
else {
return doLoadGraphData(iframeInfo.defaultWorkflow);
}
};

app.loadGraphData();
}

4 changes: 2 additions & 2 deletions comfyui_custom_scripts/extensions/webuiRequests.js
@@ -41,13 +41,13 @@ async function registerClientToWebui(workflowTypeId, webuiClientId, sid) {

const webuiRequests = new Map([
["queue_prompt", async (json) => {
await app.queuePrompt(json.queueFront ? -1 : 0, 1);
await app.queuePrompt(json.detail.queueFront ? -1 : 0, 1);
}],
["serialize_graph", (json) => {
return app.graph.original_serialize();
}],
["set_workflow", (json) => {
app.loadGraphData(json.workflow);
app.loadGraphData(json.detail.workflow);
}],
]);

43 changes: 25 additions & 18 deletions lib_comfyui/comfyui/iframe_requests.py
@@ -34,33 +34,40 @@ def send(request, workflow_type, data=None):
@staticmethod
@ipc.restrict_to_process('webui')
def start_workflow_sync(
batch_input_args: Tuple[Any],
batch_input_args: Tuple[Any, ...],
workflow_type_id: str,
workflow_input_types: List[str],
queue_front: bool,
) -> List[Dict[str, Any]]:
from modules import shared
if shared.state.interrupted:
raise RuntimeError('The workflow was not started because the webui has been interrupted')

global_state.node_input_args = batch_input_args
global_state.node_inputs = batch_input_args
global_state.node_outputs = []
global_state.current_workflow_input_types = workflow_input_types

queue_tracker.setup_tracker_id()

# unsafe queue tracking
ComfyuiIFrameRequests.send(
request='webui_queue_prompt',
workflow_type=workflow_type_id,
data={
'requiredNodeTypes': [],
'queueFront': queue_front,
}
)

if not queue_tracker.wait_until_done():
raise RuntimeError('The workflow has not returned normally')

return global_state.node_outputs
try:
queue_tracker.setup_tracker_id()

# unsafe queue tracking
ComfyuiIFrameRequests.send(
request='webui_queue_prompt',
workflow_type=workflow_type_id,
data={
'requiredNodeTypes': [],
'queueFront': queue_front,
}
)

if not queue_tracker.wait_until_done():
raise RuntimeError('The workflow has not returned normally')

return global_state.node_outputs
finally:
global_state.current_workflow_input_types = ()
global_state.node_outputs = []
global_state.node_inputs = None

@staticmethod
@ipc.restrict_to_process('comfyui')
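
Aside (illustrative sketch, not part of the commit): the try/finally added above means the shared fields are reset whether the prompt completes, the send fails, or the queue tracker never reports completion. A reduced sketch of the pattern, with a stand-in object instead of lib_comfyui.global_state:

class _State:
    node_inputs = None
    node_outputs = []
    current_workflow_input_types = ()

state = _State()  # stand-in for lib_comfyui.global_state

def start_sync(batch_input_args, workflow_input_types):
    state.node_inputs = batch_input_args
    state.node_outputs = []
    state.current_workflow_input_types = workflow_input_types
    try:
        # queue the prompt and wait on it here; pretend the wait times out
        raise RuntimeError('The workflow has not returned normally')
    finally:
        # runs on success and on error, so the next RETURN_TYPES access sees no stale types
        state.current_workflow_input_types = ()
        state.node_outputs = []
        state.node_inputs = None

try:
    start_sync(("img",), ("IMAGE",))
except RuntimeError:
    pass
print(state.node_inputs, state.current_workflow_input_types)  # None ()
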
36 changes: 20 additions & 16 deletions lib_comfyui/external_code/api.py
@@ -246,7 +246,7 @@ def run_workflow(
if queue_front is None:
queue_front = getattr(global_state, 'queue_front', True)

batch_input_args = _normalize_batch_input_to_tuple(batch_input, workflow_type)
batch_input_args, input_types = _normalize_to_tuple(batch_input, workflow_type.input_types)

if not candidate_ids:
raise ValueError(f'The workflow type {workflow_type.pretty_str()} does not exist on tab {tab}. Valid tabs for the given workflow type are: {workflow_type.tabs}')
@@ -260,6 +260,7 @@ def run_workflow(
batch_output_params = ComfyuiIFrameRequests.start_workflow_sync(
batch_input_args=batch_input_args,
workflow_type_id=workflow_type_id,
workflow_input_types=input_types,
queue_front=queue_front,
)
except RuntimeError as e:
@@ -294,28 +295,31 @@ class WorkflowTypeDisabled(RuntimeError):
pass


def _normalize_batch_input_to_tuple(batch_input, workflow_type):
if isinstance(workflow_type.input_types, dict):
def _normalize_to_tuple(batch_input, input_types):
if isinstance(input_types, str):
return (batch_input,), (input_types,)
elif isinstance(input_types, tuple):
if not isinstance(batch_input, tuple):
raise TypeError(f'batch_input should be tuple but is instead {type(batch_input)}')

if len(batch_input) != len(input_types):
raise TypeError(
f'batch_input received {len(batch_input)} values instead of {len(input_types)} (signature is {input_types})')

return batch_input, input_types
elif isinstance(input_types, dict):
if not isinstance(batch_input, dict):
raise TypeError(f'batch_input should be dict but is instead {type(batch_input)}')

expected_keys = set(workflow_type.input_types.keys())
expected_keys = set(input_types.keys())
actual_keys = set(batch_input.keys())
if expected_keys - actual_keys:
raise TypeError(f'batch_input is missing keys: {expected_keys - actual_keys}')

# convert to tuple in the same order as the items in input_types
return tuple(batch_input[k] for k in workflow_type.input_types.keys())
elif isinstance(workflow_type.input_types, str):
return batch_input,
elif isinstance(workflow_type.input_types, tuple):
if not isinstance(batch_input, tuple):
raise TypeError(f'batch_input should be tuple but is instead {type(batch_input)}')

if len(batch_input) != len(workflow_type.input_types):
raise TypeError(
f'batch_input received {len(batch_input)} values instead of {len(workflow_type.input_types)} (signature is {workflow_type.input_types})')

return batch_input
return (
tuple(batch_input[k] for k in input_types.keys()),
tuple(input_types.values())
)
else:
raise TypeError(f'batch_input should be str, tuple or dict but is instead {type(batch_input)}')
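
Aside (illustrative sketch, not part of the commit): a hedged usage sketch of the reworked helper, showing the three accepted input_types shapes. The import path follows the file header above and the sample values are made up:

from lib_comfyui.external_code.api import _normalize_to_tuple

# str input_types: both sides are wrapped into 1-tuples
_normalize_to_tuple("img", "IMAGE")
# -> (("img",), ("IMAGE",))

# tuple input_types: lengths must match, both tuples pass through unchanged
_normalize_to_tuple(("img", "lat"), ("IMAGE", "LATENT"))
# -> (("img", "lat"), ("IMAGE", "LATENT"))

# dict input_types: batch_input values are reordered to follow the input_types key order
_normalize_to_tuple({"latent": "lat", "image": "img"}, {"image": "IMAGE", "latent": "LATENT"})
# -> (("img", "lat"), ("IMAGE", "LATENT"))
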
9 changes: 5 additions & 4 deletions lib_comfyui/global_state.py
@@ -1,6 +1,6 @@
import sys
from types import ModuleType
from typing import List, Tuple, Dict
from typing import List, Tuple, Dict, Any
from lib_comfyui import ipc


@@ -9,9 +9,10 @@
queue_front: bool

workflow_types: List
enabled_workflow_type_ids: Dict
batch_input_args: Tuple
batch_output_args: List[Dict]
enabled_workflow_type_ids: Dict[str, bool]
node_inputs: Tuple[Any, ...]
node_outputs: List[Dict[str, Any]]
current_workflow_input_types: Tuple[str, ...]

ipc_strategy_class: type
ipc_strategy_class_name: str
6 changes: 3 additions & 3 deletions lib_comfyui/webui/callbacks.py
@@ -1,5 +1,5 @@
from lib_comfyui import comfyui_process, ipc, global_state, external_code
from lib_comfyui.webui import tab, settings, workflow_patcher, reverse_proxy
from lib_comfyui.webui import tab, settings, patches, reverse_proxy


@ipc.restrict_to_process('webui')
@@ -24,7 +24,7 @@ def on_ui_settings():

@ipc.restrict_to_process('webui')
def on_after_component(*args, **kwargs):
return workflow_patcher.watch_prompts(*args, **kwargs)
return patches.watch_prompts(*args, **kwargs)


@ipc.restrict_to_process('webui')
@@ -36,6 +36,6 @@ def on_app_started(_gr_root, fast_api):
@ipc.restrict_to_process('webui')
def on_script_unloaded():
comfyui_process.stop()
workflow_patcher.clear_patches()
patches.clear_patches()
global_state.is_ui_instantiated = False
external_code.clear_workflow_types()
lib_comfyui/webui/{workflow_patcher.py → patches.py}
@@ -1,22 +1,27 @@
import functools
import re
import sys

import torch
from lib_comfyui import ipc, global_state, default_workflow_types, external_code
from lib_comfyui.comfyui import type_conversion


__original_create_sampler = None
__original_re_param_code = None


@ipc.restrict_to_process('webui')
def apply_patches():
from modules import sd_samplers
global __original_create_sampler
from modules import sd_samplers, generation_parameters_copypaste
global __original_create_sampler, __original_re_param_code

__original_create_sampler = sd_samplers.create_sampler
sd_samplers.create_sampler = functools.partial(create_sampler_hijack, original_function=sd_samplers.create_sampler)

__original_re_param_code = generation_parameters_copypaste.re_param_code
generation_parameters_copypaste.re_param_code = r'\s*([\w ]+):\s*("(?:\\.|[^\\"])+"|[^,]*)(?:,|$)'
generation_parameters_copypaste.re_param = re.compile(generation_parameters_copypaste.re_param_code)


@ipc.restrict_to_process('webui')
def watch_prompts(component, **kwargs):
@@ -39,10 +44,15 @@ def watch_prompts(component, **kwargs):

@ipc.restrict_to_process('webui')
def clear_patches():
from modules import sd_samplers
global __original_create_sampler
from modules import sd_samplers, generation_parameters_copypaste
global __original_create_sampler, __original_re_param_code

if __original_create_sampler is not None:
sd_samplers.create_sampler = __original_create_sampler

sd_samplers.create_sampler = __original_create_sampler
if __original_re_param_code is not None:
generation_parameters_copypaste.re_param_code = __original_re_param_code
generation_parameters_copypaste.re_param = re.compile(generation_parameters_copypaste.re_param_code)


@ipc.restrict_to_process('webui')
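
Aside (illustrative sketch, not part of the commit): the "dirty patch regex" from the commit message swaps webui's infotext parameter pattern for one whose keys may contain spaces and whose quoted values may contain commas and escaped quotes. A small sketch of what the new pattern extracts; the "ComfyUI Workflow" key and the sample infotext are made up:

import re

re_param_code = r'\s*([\w ]+):\s*("(?:\\.|[^\\"])+"|[^,]*)(?:,|$)'
re_param = re.compile(re_param_code)

infotext = 'Steps: 20, Sampler: Euler a, ComfyUI Workflow: "{\\"nodes\\": [1, 2]}"'
for key, value in re_param.findall(infotext):
    print(repr(key), repr(value))
# 'Steps' '20'
# 'Sampler' 'Euler a'
# 'ComfyUI Workflow' '"{\\"nodes\\": [1, 2]}"'   <- quoted value kept whole, commas included
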
4 changes: 4 additions & 0 deletions lib_comfyui/webui/proxies.py
@@ -35,6 +35,10 @@ def model_patches_to(self, device):
def model_dtype(self):
return self.model.dtype

@property
def current_device(self):
return self.model.device

def add_patches(self, *args, **kwargs):
soft_raise('patching a webui resource is not yet supported')
return []
6 changes: 3 additions & 3 deletions scripts/comfyui.py
@@ -2,7 +2,7 @@

from modules import scripts
from lib_comfyui import global_state, platform_utils, external_code, default_workflow_types, comfyui_process
from lib_comfyui.webui import callbacks, settings, workflow_patcher, gradio_utils, accordion
from lib_comfyui.webui import callbacks, settings, patches, gradio_utils, accordion
from lib_comfyui.comfyui import iframe_requests, type_conversion


@@ -51,7 +51,7 @@ def process(self, p, queue_front, enabled_workflow_type_ids, **kwargs):
global_state.enabled_workflow_type_ids.update(enabled_workflow_type_ids)

global_state.queue_front = queue_front
workflow_patcher.patch_processing(p)
patches.patch_processing(p)

def postprocess_batch_list(self, p, pp, *args, **kwargs):
if not getattr(global_state, 'enabled', True):
@@ -103,4 +103,4 @@ def extract_contiguous_buckets(images, batch_size):
callbacks.register_callbacks()
default_workflow_types.add_default_workflow_types()
settings.init_extension_base_dir()
workflow_patcher.apply_patches()
patches.apply_patches()
@@ -28,6 +28,7 @@ def test_valid_workflow_with_dict_types(self, mock_start_workflow):
mock_start_workflow.assert_called_once_with(
batch_input_args=("value", 123),
workflow_type_id="test_tab",
workflow_input_types=('IMAGE', 'LATENT'),
queue_front=True,
)
self.assertEqual(result, mock_start_workflow.return_value)
@@ -48,6 +49,7 @@ def test_valid_workflow_with_tuple_types(self, mock_start_workflow):
mock_start_workflow.assert_called_once_with(
batch_input_args=("value", 123),
workflow_type_id="test_tab",
workflow_input_types=('IMAGE', 'LATENT'),
queue_front=True,
)
self.assertEqual(result, [tuple(batch.values()) for batch in mock_start_workflow.return_value])
@@ -68,6 +70,7 @@ def test_valid_workflow_with_str_types(self, mock_start_workflow):
mock_start_workflow.assert_called_once_with(
batch_input_args=("value",),
workflow_type_id="test_tab",
workflow_input_types=('IMAGE',),
queue_front=True,
)
self.assertEqual(result, [next(iter(batch.values())) for batch in mock_start_workflow.return_value])
@@ -179,6 +182,7 @@ def test_identity_on_error(self, mock_start_workflow):
mock_start_workflow.assert_called_with(
batch_input_args=("value",),
workflow_type_id="test_tab",
workflow_input_types=('IMAGE',),
queue_front=True,
)

