Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(ui): image batching in workflows #7343

Merged
merged 18 commits into from
Nov 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
18 commits
Select commit Hold shift + click to select a range
8c6ffa0
feat(app): add `ImageField` as an allowed batching data type
psychedelicious Nov 16, 2024
8a99354
feat(app): add `Classification.Special`, used for batch nodes
psychedelicious Nov 16, 2024
effb70b
feat(nodes): add `ImageBatchInvocation`
psychedelicious Nov 16, 2024
dd1fc1d
chore(ui): typegen
psychedelicious Nov 16, 2024
0e8120f
feat(ui): image batching in workflows
psychedelicious Nov 16, 2024
966562a
feat(nodes): add minimum image count to ImageBatchInvocation
psychedelicious Nov 16, 2024
ebdb48b
chore(ui): typegen
psychedelicious Nov 16, 2024
79e4087
feat(ui): support min and max length for image collections
psychedelicious Nov 16, 2024
442e6df
fix(ui): zod schema refiners must return boolean
psychedelicious Nov 16, 2024
a8159b0
fix(ui): image field collection dnd adds instead of replaces
psychedelicious Nov 16, 2024
7a21597
feat(ui): autosize image collection field grid
psychedelicious Nov 16, 2024
c080a62
fix(ui): do not allow invoking when canvas is selectig object
psychedelicious Nov 18, 2024
7fc9fd8
feat(ui): add graph validation for image collection size
psychedelicious Nov 18, 2024
8fd9a63
feat(ui): update field validation logic to handle collection sizes
psychedelicious Nov 18, 2024
fbf1355
feat(ui): add reset to default value button to field title
psychedelicious Nov 18, 2024
ecc5f29
feat(ui): allow removing individual images from batch
psychedelicious Nov 18, 2024
1f86188
feat(ui): make image field collection scrollable
psychedelicious Nov 18, 2024
3ce766d
fix(ui): reactflow drag interactions with custom scrollbar
psychedelicious Nov 19, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions invokeai/app/invocations/baseinvocation.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,13 +63,15 @@ class Classification(str, Enum, metaclass=MetaEnum):
- `Prototype`: The invocation is not yet stable and may be removed from the application at any time. Workflows built around this invocation may break, and we are *not* committed to supporting this invocation.
- `Deprecated`: The invocation is deprecated and may be removed in a future version.
- `Internal`: The invocation is not intended for use by end-users. It may be changed or removed at any time, but is exposed for users to play with.
- `Special`: The invocation is a special case and does not fit into any of the other classifications.
"""

Stable = "stable"
Beta = "beta"
Prototype = "prototype"
Deprecated = "deprecated"
Internal = "internal"
Special = "special"


class UIConfigBase(BaseModel):
Expand Down
28 changes: 27 additions & 1 deletion invokeai/app/invocations/primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,13 @@

import torch

from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output
from invokeai.app.invocations.baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
Classification,
invocation,
invocation_output,
)
from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR
from invokeai.app.invocations.fields import (
BoundingBoxField,
Expand Down Expand Up @@ -533,3 +539,23 @@ def invoke(self, context: InvocationContext) -> BoundingBoxOutput:


# endregion


@invocation(
"image_batch",
title="Image Batch",
tags=["primitives", "image", "batch", "internal"],
category="primitives",
version="1.0.0",
classification=Classification.Special,
)
class ImageBatchInvocation(BaseInvocation):
"""Create a batched generation, where the workflow is executed once for each image in the batch."""

images: list[ImageField] = InputField(min_length=1, description="The images to batch over", input=Input.Direct)

def __init__(self):
raise NotImplementedError("This class should never be executed or instantiated directly.")

def invoke(self, context: InvocationContext) -> ImageOutput:
raise NotImplementedError("This class should never be executed or instantiated directly.")
7 changes: 2 additions & 5 deletions invokeai/app/services/session_queue/session_queue_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from pydantic_core import to_jsonable_python

from invokeai.app.invocations.baseinvocation import BaseInvocation
from invokeai.app.invocations.fields import ImageField
from invokeai.app.services.shared.graph import Graph, GraphExecutionState, NodeNotFoundError
from invokeai.app.services.workflow_records.workflow_records_common import (
WorkflowWithoutID,
Expand Down Expand Up @@ -51,11 +52,7 @@ class SessionQueueItemNotFoundError(ValueError):

# region Batch

BatchDataType = Union[
StrictStr,
float,
int,
]
BatchDataType = Union[StrictStr, float, int, ImageField]


class NodeFieldValue(BaseModel):
Expand Down
14 changes: 9 additions & 5 deletions invokeai/frontend/web/public/locales/en.json
Original file line number Diff line number Diff line change
Expand Up @@ -1015,8 +1015,11 @@
"addingImagesTo": "Adding images to",
"invoke": "Invoke",
"missingFieldTemplate": "Missing field template",
"missingInputForField": "{{nodeLabel}} -> {{fieldLabel}} missing input",
"missingInputForField": "{{nodeLabel}} -> {{fieldLabel}}: missing input",
"missingNodeTemplate": "Missing node template",
"collectionEmpty": "{{nodeLabel}} -> {{fieldLabel}} empty collection",
"collectionTooFewItems": "{{nodeLabel}} -> {{fieldLabel}}: too few items, minimum {{minItems}}",
"collectionTooManyItems": "{{nodeLabel}} -> {{fieldLabel}}: too many items, maximum {{maxItems}}",
"noModelSelected": "No model selected",
"noT5EncoderModelSelected": "No T5 Encoder model selected for FLUX generation",
"noFLUXVAEModelSelected": "No VAE model selected for FLUX generation",
Expand All @@ -1025,10 +1028,11 @@
"fluxModelIncompatibleBboxHeight": "$t(parameters.invoke.fluxRequiresDimensionsToBeMultipleOf16), bbox height is {{height}}",
"fluxModelIncompatibleScaledBboxWidth": "$t(parameters.invoke.fluxRequiresDimensionsToBeMultipleOf16), scaled bbox width is {{width}}",
"fluxModelIncompatibleScaledBboxHeight": "$t(parameters.invoke.fluxRequiresDimensionsToBeMultipleOf16), scaled bbox height is {{height}}",
"canvasIsFiltering": "Canvas is filtering",
"canvasIsTransforming": "Canvas is transforming",
"canvasIsRasterizing": "Canvas is rasterizing",
"canvasIsCompositing": "Canvas is compositing",
"canvasIsFiltering": "Canvas is busy (filtering)",
"canvasIsTransforming": "Canvas is busy (transforming)",
"canvasIsRasterizing": "Canvas is busy (rasterizing)",
"canvasIsCompositing": "Canvas is busy (compositing)",
"canvasIsSelectingObject": "Canvas is busy (selecting object)",
"noPrompts": "No prompts generated",
"noNodesInGraph": "No nodes in graph",
"systemDisconnected": "System disconnected",
Expand Down
Original file line number Diff line number Diff line change
@@ -1,10 +1,15 @@
import { logger } from 'app/logging/logger';
import { enqueueRequested } from 'app/store/actions';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { selectNodesSlice } from 'features/nodes/store/selectors';
import { isImageFieldCollectionInputInstance } from 'features/nodes/types/field';
import { isInvocationNode } from 'features/nodes/types/invocation';
import { buildNodesGraph } from 'features/nodes/util/graph/buildNodesGraph';
import { buildWorkflowWithValidation } from 'features/nodes/util/workflow/buildWorkflow';
import { queueApi } from 'services/api/endpoints/queue';
import type { BatchConfig } from 'services/api/types';
import type { Batch, BatchConfig } from 'services/api/types';

const log = logger('workflows');

export const addEnqueueRequestedNodes = (startAppListening: AppStartListening) => {
startAppListening({
Expand All @@ -26,13 +31,41 @@ export const addEnqueueRequestedNodes = (startAppListening: AppStartListening) =
delete builtWorkflow.id;
}

const data: Batch['data'] = [];

// Skip edges from batch nodes - these should not be in the graph, they exist only in the UI
const imageBatchNodes = nodes.nodes.filter(isInvocationNode).filter((node) => node.data.type === 'image_batch');
for (const node of imageBatchNodes) {
const images = node.data.inputs['images'];
if (!isImageFieldCollectionInputInstance(images)) {
log.warn({ nodeId: node.id }, 'Image batch images field is not an image collection');
break;
}
const edgesFromImageBatch = nodes.edges.filter((e) => e.source === node.id && e.sourceHandle === 'image');
const batchDataCollectionItem: NonNullable<Batch['data']>[number] = [];
for (const edge of edgesFromImageBatch) {
if (!edge.targetHandle) {
break;
}
batchDataCollectionItem.push({
node_path: edge.target,
field_name: edge.targetHandle,
items: images.value,
});
}
if (batchDataCollectionItem.length > 0) {
data.push(batchDataCollectionItem);
}
}

const batchConfig: BatchConfig = {
batch: {
graph,
workflow: builtWorkflow,
runs: state.params.iterations,
origin: 'workflows',
destination: 'gallery',
data,
},
prepend: action.payload.prepend,
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ import { deepClone } from 'common/util/deepClone';
import { merge } from 'lodash-es';
import { ClickScrollPlugin, OverlayScrollbars } from 'overlayscrollbars';
import type { UseOverlayScrollbarsParams } from 'overlayscrollbars-react';
import type { CSSProperties } from 'react';

OverlayScrollbars.plugin(ClickScrollPlugin);

Expand All @@ -27,3 +28,8 @@ export const getOverlayScrollbarsParams = (
merge(params, { options: { overflow: { y: overflowY, x: overflowX } } });
return params;
};

export const overlayScrollbarsStyles: CSSProperties = {
height: '100%',
width: '100%',
};
35 changes: 30 additions & 5 deletions invokeai/frontend/web/src/common/hooks/useImageUploadButton.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -141,11 +141,9 @@ export const useImageUploadButton = ({ onUpload, isDisabled, allowMultiple }: Us
};

const sx = {
borderColor: 'error.500',
borderStyle: 'solid',
borderWidth: 0,
borderRadius: 'base',
'&[data-error=true]': {
borderColor: 'error.500',
borderStyle: 'solid',
borderWidth: 1,
},
} satisfies SystemStyleObject;
Expand All @@ -164,7 +162,34 @@ export const UploadImageButton = ({
<>
<IconButton
aria-label="Upload image"
variant="ghost"
variant="outline"
sx={sx}
data-error={isError}
icon={<PiUploadBold />}
isLoading={uploadApi.request.isLoading}
{...rest}
{...uploadApi.getUploadButtonProps()}
/>
<input {...uploadApi.getUploadInputProps()} />
</>
);
};

export const UploadMultipleImageButton = ({
isDisabled = false,
onUpload,
isError = false,
...rest
}: {
onUpload?: (imageDTOs: ImageDTO[]) => void;
isError?: boolean;
} & SetOptional<IconButtonProps, 'aria-label'>) => {
const uploadApi = useImageUploadButton({ isDisabled, allowMultiple: true, onUpload });
return (
<>
<IconButton
aria-label="Upload image"
variant="outline"
sx={sx}
data-error={isError}
icon={<PiUploadBold />}
Expand Down
94 changes: 75 additions & 19 deletions invokeai/frontend/web/src/common/hooks/useIsReadyToEnqueue.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import { $templates } from 'features/nodes/store/nodesSlice';
import { selectNodesSlice } from 'features/nodes/store/selectors';
import type { Templates } from 'features/nodes/store/types';
import { selectWorkflowSettingsSlice } from 'features/nodes/store/workflowSettingsSlice';
import { isImageFieldCollectionInputInstance, isImageFieldCollectionInputTemplate } from 'features/nodes/types/field';
import { isInvocationNode } from 'features/nodes/types/invocation';
import { selectUpscaleSlice } from 'features/parameters/store/upscaleSlice';
import { selectConfigSlice } from 'features/system/store/configSlice';
Expand All @@ -30,15 +31,25 @@ const LAYER_TYPE_TO_TKEY = {
control_layer: 'controlLayers.controlLayer',
} as const;

const createSelector = (
templates: Templates,
isConnected: boolean,
canvasIsFiltering: boolean,
canvasIsTransforming: boolean,
canvasIsRasterizing: boolean,
canvasIsCompositing: boolean
) =>
createMemoizedSelector(
const createSelector = (arg: {
templates: Templates;
isConnected: boolean;
canvasIsFiltering: boolean;
canvasIsTransforming: boolean;
canvasIsRasterizing: boolean;
canvasIsCompositing: boolean;
canvasIsSelectingObject: boolean;
}) => {
const {
templates,
isConnected,
canvasIsFiltering,
canvasIsTransforming,
canvasIsRasterizing,
canvasIsCompositing,
canvasIsSelectingObject,
} = arg;
return createMemoizedSelector(
[
selectSystemSlice,
selectNodesSlice,
Expand Down Expand Up @@ -93,14 +104,45 @@ const createSelector = (
return;
}

const baseTKeyOptions = {
nodeLabel: node.data.label || nodeTemplate.title,
fieldLabel: field.label || fieldTemplate.title,
};

if (fieldTemplate.required && field.value === undefined && !hasConnection) {
reasons.push({
content: i18n.t('parameters.invoke.missingInputForField', {
nodeLabel: node.data.label || nodeTemplate.title,
fieldLabel: field.label || fieldTemplate.title,
}),
});
reasons.push({ content: i18n.t('parameters.invoke.missingInputForField', baseTKeyOptions) });
return;
} else if (
field.value &&
isImageFieldCollectionInputInstance(field) &&
isImageFieldCollectionInputTemplate(fieldTemplate)
) {
// Image collections may have min or max items to validate
// TODO(psyche): generalize this to other collection types
if (fieldTemplate.minItems !== undefined && fieldTemplate.minItems > 0 && field.value.length === 0) {
reasons.push({ content: i18n.t('parameters.invoke.collectionEmpty', baseTKeyOptions) });
return;
}
if (fieldTemplate.minItems !== undefined && field.value.length < fieldTemplate.minItems) {
reasons.push({
content: i18n.t('parameters.invoke.collectionTooFewItems', {
...baseTKeyOptions,
size: field.value.length,
minItems: fieldTemplate.minItems,
}),
});
return;
}
if (fieldTemplate.maxItems !== undefined && field.value.length > fieldTemplate.maxItems) {
reasons.push({
content: i18n.t('parameters.invoke.collectionTooManyItems', {
...baseTKeyOptions,
size: field.value.length,
maxItems: fieldTemplate.maxItems,
}),
});
return;
}
}
});
});
Expand Down Expand Up @@ -147,6 +189,9 @@ const createSelector = (
if (canvasIsCompositing) {
reasons.push({ content: i18n.t('parameters.invoke.canvasIsCompositing') });
}
if (canvasIsSelectingObject) {
reasons.push({ content: i18n.t('parameters.invoke.canvasIsSelectingObject') });
}

if (dynamicPrompts.prompts.length === 0 && getShouldProcessPrompt(positivePrompt)) {
reasons.push({ content: i18n.t('parameters.invoke.noPrompts') });
Expand Down Expand Up @@ -305,6 +350,7 @@ const createSelector = (
return { isReady: !reasons.length, reasons };
}
);
};

export const useIsReadyToEnqueue = () => {
const templates = useStore($templates);
Expand All @@ -313,18 +359,28 @@ export const useIsReadyToEnqueue = () => {
const canvasIsFiltering = useStore(canvasManager?.stateApi.$isFiltering ?? $true);
const canvasIsTransforming = useStore(canvasManager?.stateApi.$isTransforming ?? $true);
const canvasIsRasterizing = useStore(canvasManager?.stateApi.$isRasterizing ?? $true);
const canvasIsSelectingObject = useStore(canvasManager?.stateApi.$isSegmenting ?? $true);
const canvasIsCompositing = useStore(canvasManager?.compositor.$isBusy ?? $true);
const selector = useMemo(
() =>
createSelector(
createSelector({
templates,
isConnected,
canvasIsFiltering,
canvasIsTransforming,
canvasIsRasterizing,
canvasIsCompositing
),
[templates, isConnected, canvasIsFiltering, canvasIsTransforming, canvasIsRasterizing, canvasIsCompositing]
canvasIsCompositing,
canvasIsSelectingObject,
}),
[
templates,
isConnected,
canvasIsFiltering,
canvasIsTransforming,
canvasIsRasterizing,
canvasIsCompositing,
canvasIsSelectingObject,
]
);
const value = useAppSelector(selector);
return value;
Expand Down
13 changes: 8 additions & 5 deletions invokeai/frontend/web/src/features/dnd/DndImage.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,15 @@ const sx = {
},
} satisfies SystemStyleObject;

type Props = ImageProps & {
imageDTO: ImageDTO;
asThumbnail?: boolean;
};
/* eslint-disable-next-line @typescript-eslint/no-namespace */
export namespace DndImage {
export interface Props extends ImageProps {
imageDTO: ImageDTO;
asThumbnail?: boolean;
}
}

export const DndImage = memo(({ imageDTO, asThumbnail, ...rest }: Props) => {
export const DndImage = memo(({ imageDTO, asThumbnail, ...rest }: DndImage.Props) => {
const store = useAppStore();
const [isDragging, setIsDragging] = useState(false);
const [element, ref] = useState<HTMLImageElement | null>(null);
Expand Down
Loading
Loading