Skip to content

Commit

Permalink
prompt node
Browse files Browse the repository at this point in the history
  • Loading branch information
ljleb committed Aug 13, 2023
1 parent 3741422 commit fbdf0f0
Show file tree
Hide file tree
Showing 3 changed files with 81 additions and 33 deletions.
Original file line number Diff line number Diff line change
@@ -1,33 +1,57 @@
from lib_comfyui.webui import proxies


class WebuiCheckpointLoader:
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"void": ("VOID", ),
},
}
RETURN_TYPES = ("MODEL", "CLIP", "VAE")
FUNCTION = "load_checkpoint"

CATEGORY = "loaders"

def load_checkpoint(self, void):
config = proxies.get_comfy_model_config()
proxies.raise_on_unsupported_model_type(config)
return (
proxies.ModelPatcher(proxies.Model()),
proxies.ClipWrapper(proxies.Clip()),
proxies.VaeWrapper(proxies.Vae()),
)


NODE_CLASS_MAPPINGS = {
"WebuiCheckpointLoader": WebuiCheckpointLoader,
}

NODE_DISPLAY_NAME_MAPPINGS = {
"WebuiCheckpointLoader": 'Webui Checkpoint',
}
from lib_comfyui.webui import proxies
from lib_comfyui import global_state


class WebuiCheckpointLoader:
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"void": ("VOID", ),
},
}
RETURN_TYPES = ("MODEL", "CLIP", "VAE")
FUNCTION = "load_checkpoint"

CATEGORY = "loaders"

def load_checkpoint(self, void):
config = proxies.get_comfy_model_config()
proxies.raise_on_unsupported_model_type(config)
return (
proxies.ModelPatcher(proxies.Model()),
proxies.ClipWrapper(proxies.Clip()),
proxies.VaeWrapper(proxies.Vae()),
)


class WebuiPrompts:
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"void": ("VOID", ),
},
}
RETURN_TYPES = ("STRING", "STRING")
RETURN_NAMES = ("positive", "negative")
FUNCTION = "get_prompts"

CATEGORY = "loaders"

def get_prompts(self, void):
return (
getattr(global_state, 'last_positive_prompt', ''),
getattr(global_state, 'last_negative_prompt', ''),
)


NODE_CLASS_MAPPINGS = {
"WebuiCheckpointLoader": WebuiCheckpointLoader,
"WebuiPrompts": WebuiPrompts,
}

NODE_DISPLAY_NAME_MAPPINGS = {
"WebuiCheckpointLoader": 'Webui Checkpoint',
"WebuiPrompts": "Webui Prompts",
}
6 changes: 6 additions & 0 deletions lib_comfyui/webui/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ def register_callbacks():
from modules import script_callbacks
script_callbacks.on_ui_tabs(on_ui_tabs)
script_callbacks.on_ui_settings(on_ui_settings)
script_callbacks.on_after_component(on_after_component)
script_callbacks.on_app_started(on_app_started)
script_callbacks.on_script_unloaded(on_script_unloaded)

Expand All @@ -21,6 +22,11 @@ def on_ui_settings():
return settings.create_section()


@ipc.restrict_to_process('webui')
def on_after_component(*args, **kwargs):
return workflow_patcher.watch_prompts(*args, **kwargs)


@ipc.restrict_to_process('webui')
def on_app_started(_gr_root, _fast_api):
comfyui_process.start()
Expand Down
18 changes: 18 additions & 0 deletions lib_comfyui/webui/workflow_patcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,24 @@ def apply_patches():
sd_samplers.create_sampler = functools.partial(create_sampler_hijack, original_function=sd_samplers.create_sampler)


@ipc.restrict_to_process('webui')
def watch_prompts(component, **kwargs):
possible_ids = {
f'{tab}{negative}_prompt': bool(negative)
for tab in ('txt2img', 'img2img')
for negative in ('', '_neg')
}
event_listeners = ('change', 'input', 'blur')

if (elem_id := getattr(component, 'elem_id', None)) in possible_ids:
attribute = f'last_{"negative" if possible_ids[elem_id] else "positive"}_prompt'
for event_listener in event_listeners:
getattr(component, event_listener)(
fn = lambda p: setattr(global_state, attribute, p),
inputs=[component]
)


@ipc.restrict_to_process('webui')
def clear_patches():
from modules import sd_samplers
Expand Down

0 comments on commit fbdf0f0

Please sign in to comment.