Skip to content

Commit

Permalink
prompt node (#132)
Browse files Browse the repository at this point in the history
* prompt node

* skip extra networks for now

* rename

* category

* category

* cleanup

* dont need input
  • Loading branch information
ljleb authored Aug 13, 2023
1 parent 0f92ca3 commit aaf9e96
Show file tree
Hide file tree
Showing 4 changed files with 90 additions and 33 deletions.
Original file line number Diff line number Diff line change
@@ -1,33 +1,59 @@
from lib_comfyui.webui import proxies


class WebuiCheckpointLoader:
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"void": ("VOID", ),
},
}
RETURN_TYPES = ("MODEL", "CLIP", "VAE")
FUNCTION = "load_checkpoint"

CATEGORY = "loaders"

def load_checkpoint(self, void):
config = proxies.get_comfy_model_config()
proxies.raise_on_unsupported_model_type(config)
return (
proxies.ModelPatcher(proxies.Model()),
proxies.ClipWrapper(proxies.Clip()),
proxies.VaeWrapper(proxies.Vae()),
)


NODE_CLASS_MAPPINGS = {
"WebuiCheckpointLoader": WebuiCheckpointLoader,
}

NODE_DISPLAY_NAME_MAPPINGS = {
"WebuiCheckpointLoader": 'Webui Checkpoint',
}
from lib_comfyui.webui import proxies
from lib_comfyui import global_state


class WebuiCheckpointLoader:
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"void": ("VOID", ),
},
}
RETURN_TYPES = ("MODEL", "CLIP", "VAE")
FUNCTION = "load_checkpoint"

CATEGORY = "loaders"

def load_checkpoint(self, void):
config = proxies.get_comfy_model_config()
proxies.raise_on_unsupported_model_type(config)
return (
proxies.ModelPatcher(proxies.Model()),
proxies.ClipWrapper(proxies.Clip()),
proxies.VaeWrapper(proxies.Vae()),
)


class WebuiPrompts:
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"void": ("VOID", ),
},
}
RETURN_TYPES = ("STRING", "STRING")
RETURN_NAMES = ("positive", "negative")
FUNCTION = "get_prompts"

CATEGORY = "text"

def get_prompts(self, void):
positive_prompts, _extra_networks = proxies.extra_networks_parse_prompts([getattr(global_state, 'last_positive_prompt', '')])

return (
positive_prompts[0],
getattr(global_state, 'last_negative_prompt', ''),
)


NODE_CLASS_MAPPINGS = {
"WebuiCheckpointLoader": WebuiCheckpointLoader,
"WebuiPrompts": WebuiPrompts,
}

NODE_DISPLAY_NAME_MAPPINGS = {
"WebuiCheckpointLoader": 'Webui Checkpoint',
"WebuiPrompts": "Webui Prompts",
}
6 changes: 6 additions & 0 deletions lib_comfyui/webui/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ def register_callbacks():
from modules import script_callbacks
script_callbacks.on_ui_tabs(on_ui_tabs)
script_callbacks.on_ui_settings(on_ui_settings)
script_callbacks.on_after_component(on_after_component)
script_callbacks.on_app_started(on_app_started)
script_callbacks.on_script_unloaded(on_script_unloaded)

Expand All @@ -21,6 +22,11 @@ def on_ui_settings():
return settings.create_section()


@ipc.restrict_to_process('webui')
def on_after_component(*args, **kwargs):
return workflow_patcher.watch_prompts(*args, **kwargs)


@ipc.restrict_to_process('webui')
def on_app_started(_gr_root, _fast_api):
comfyui_process.start()
Expand Down
6 changes: 6 additions & 0 deletions lib_comfyui/webui/proxies.py
Original file line number Diff line number Diff line change
Expand Up @@ -373,5 +373,11 @@ def sd_model_get_config():
return sd_models_config.find_checkpoint_config(shared.sd_model.state_dict(), sd_models.select_checkpoint())


@ipc.run_in_process('webui')
def extra_networks_parse_prompts(prompts):
from modules import extra_networks
return extra_networks.parse_prompts(prompts)


def soft_raise(message):
print(f'[sd-webui-comfyui] {message}', file=sys.stderr)
19 changes: 19 additions & 0 deletions lib_comfyui/webui/workflow_patcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,25 @@ def apply_patches():
sd_samplers.create_sampler = functools.partial(create_sampler_hijack, original_function=sd_samplers.create_sampler)


@ipc.restrict_to_process('webui')
def watch_prompts(component, **kwargs):
possible_elem_ids = {
f'{tab}{negative}_prompt': bool(negative)
for tab in ('txt2img', 'img2img')
for negative in ('', '_neg')
}
event_listeners = ('change', 'blur')

elem_id = getattr(component, 'elem_id', None)
if elem_id in possible_elem_ids:
attribute = f'last_{"negative" if possible_elem_ids[elem_id] else "positive"}_prompt'
for event_listener in event_listeners:
getattr(component, event_listener)(
fn = lambda p: setattr(global_state, attribute, p),
inputs=[component]
)


@ipc.restrict_to_process('webui')
def clear_patches():
from modules import sd_samplers
Expand Down

0 comments on commit aaf9e96

Please sign in to comment.