Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
20ef2d8
tidy(nodes): move batch nodes to own file
psychedelicious Jan 14, 2025
55b5231
tidy(nodes): code dedupe for batch node init errors
psychedelicious Jan 14, 2025
cddfd95
feat(nodes): generators as nodes
psychedelicious Jan 16, 2025
000e444
chore(ui): typegen
psychedelicious Jan 16, 2025
53250be
feat(ui): support generator nodes (wip)
psychedelicious Jan 16, 2025
1f63472
feat(ui): don't show generator preview for random generators
psychedelicious Jan 16, 2025
a835348
chore(ui): lint
psychedelicious Jan 16, 2025
e98613c
feat(nodes): add integer generator nodes
psychedelicious Jan 16, 2025
0c67f9a
feat(nodes): remove default values for generator; let UI handle it
psychedelicious Jan 16, 2025
c49a800
fix(ui): use utils to get default float generator values
psychedelicious Jan 16, 2025
a0e1999
feat(ui): support integer generators
psychedelicious Jan 16, 2025
2de67d7
fix(ui): translation for generators
psychedelicious Jan 16, 2025
a45b3c6
feat(ui): rip out generator modal functionality
psychedelicious Jan 16, 2025
8cf3dac
fix(ui): batch size calculations
psychedelicious Jan 16, 2025
f254662
fix(ui): remove nonfunctional button
psychedelicious Jan 16, 2025
5eed9ef
feat(ui): reworked float/int generators (arithmetic sequence, linear …
psychedelicious Jan 16, 2025
808de6b
feat(ui): add integer & float parse string generators
psychedelicious Jan 16, 2025
6261685
tidy(ui): remove extraneous reset button on generators
psychedelicious Jan 16, 2025
1e426f3
feat(ui): better preview for generators
psychedelicious Jan 17, 2025
a35b47a
feat(ui): improved generator text area styling
psychedelicious Jan 17, 2025
cdb2db8
chore(ui): lint
psychedelicious Jan 17, 2025
245a792
chore(ui): typegen
psychedelicious Jan 17, 2025
bbac75a
chore(ui): lint
psychedelicious Jan 17, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
200 changes: 200 additions & 0 deletions invokeai/app/invocations/batch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,200 @@
from typing import Literal

from pydantic import BaseModel

from invokeai.app.invocations.baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
Classification,
invocation,
invocation_output,
)
from invokeai.app.invocations.fields import (
ImageField,
Input,
InputField,
OutputField,
)
from invokeai.app.invocations.primitives import (
FloatOutput,
ImageOutput,
IntegerOutput,
StringOutput,
)
from invokeai.app.services.shared.invocation_context import InvocationContext

BATCH_GROUP_IDS = Literal[
"None",
"Group 1",
"Group 2",
"Group 3",
"Group 4",
"Group 5",
]


class NotExecutableNodeError(Exception):
def __init__(self, message: str = "This class should never be executed or instantiated directly."):
super().__init__(message)

pass


class BaseBatchInvocation(BaseInvocation):
batch_group_id: BATCH_GROUP_IDS = InputField(
default="None",
description="The ID of this batch node's group. If provided, all batch nodes in with the same ID will be 'zipped' before execution, and all nodes' collections must be of the same size.",
input=Input.Direct,
title="Batch Group",
)

def __init__(self):
raise NotExecutableNodeError()


@invocation(
"image_batch",
title="Image Batch",
tags=["primitives", "image", "batch", "special"],
category="primitives",
version="1.0.0",
classification=Classification.Special,
)
class ImageBatchInvocation(BaseBatchInvocation):
"""Create a batched generation, where the workflow is executed once for each image in the batch."""

images: list[ImageField] = InputField(
default=[], min_length=1, description="The images to batch over", input=Input.Direct
)

def invoke(self, context: InvocationContext) -> ImageOutput:
raise NotExecutableNodeError()


@invocation(
"string_batch",
title="String Batch",
tags=["primitives", "string", "batch", "special"],
category="primitives",
version="1.0.0",
classification=Classification.Special,
)
class StringBatchInvocation(BaseBatchInvocation):
"""Create a batched generation, where the workflow is executed once for each string in the batch."""

strings: list[str] = InputField(
default=[], min_length=1, description="The strings to batch over", input=Input.Direct
)

def invoke(self, context: InvocationContext) -> StringOutput:
raise NotExecutableNodeError()


@invocation(
"integer_batch",
title="Integer Batch",
tags=["primitives", "integer", "number", "batch", "special"],
category="primitives",
version="1.0.0",
classification=Classification.Special,
)
class IntegerBatchInvocation(BaseBatchInvocation):
"""Create a batched generation, where the workflow is executed once for each integer in the batch."""

integers: list[int] = InputField(
default=[],
min_length=1,
description="The integers to batch over",
)

def invoke(self, context: InvocationContext) -> IntegerOutput:
raise NotExecutableNodeError()


@invocation_output("integer_generator_output")
class IntegerGeneratorOutput(BaseInvocationOutput):
integers: list[int] = OutputField(description="The generated integers")


class IntegerGeneratorField(BaseModel):
pass


@invocation(
"integer_generator",
title="Integer Generator",
tags=["primitives", "int", "number", "batch", "special"],
category="primitives",
version="1.0.0",
classification=Classification.Special,
)
class IntegerGenerator(BaseInvocation):
"""Generated a range of integers for use in a batched generation"""

generator: IntegerGeneratorField = InputField(
description="The integer generator.",
input=Input.Direct,
title="Generator Type",
)

def __init__(self):
raise NotExecutableNodeError()

def invoke(self, context: InvocationContext) -> IntegerGeneratorOutput:
raise NotExecutableNodeError()


@invocation(
"float_batch",
title="Float Batch",
tags=["primitives", "float", "number", "batch", "special"],
category="primitives",
version="1.0.0",
classification=Classification.Special,
)
class FloatBatchInvocation(BaseBatchInvocation):
"""Create a batched generation, where the workflow is executed once for each float in the batch."""

floats: list[float] = InputField(
default=[],
min_length=1,
description="The floats to batch over",
)

def invoke(self, context: InvocationContext) -> FloatOutput:
raise NotExecutableNodeError()


@invocation_output("float_generator_output")
class FloatGeneratorOutput(BaseInvocationOutput):
"""Base class for nodes that output a collection of floats"""

floats: list[float] = OutputField(description="The generated floats")


class FloatGeneratorField(BaseModel):
pass


@invocation(
"float_generator",
title="Float Generator",
tags=["primitives", "float", "number", "batch", "special"],
category="primitives",
version="1.0.0",
classification=Classification.Special,
)
class FloatGenerator(BaseInvocation):
"""Generated a range of floats for use in a batched generation"""

generator: FloatGeneratorField = InputField(
description="The float generator.",
input=Input.Direct,
title="Generator Type",
)

def __init__(self):
raise NotExecutableNodeError()

def invoke(self, context: InvocationContext) -> FloatGeneratorOutput:
raise NotExecutableNodeError()
100 changes: 1 addition & 99 deletions invokeai/app/invocations/primitives.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)

from typing import Literal, Optional
from typing import Optional

import torch

from invokeai.app.invocations.baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
Classification,
invocation,
invocation_output,
)
Expand Down Expand Up @@ -539,100 +538,3 @@ def invoke(self, context: InvocationContext) -> BoundingBoxOutput:


# endregion

BATCH_GROUP_IDS = Literal[
"None",
"Group 1",
"Group 2",
"Group 3",
"Group 4",
"Group 5",
]


class BaseBatchInvocation(BaseInvocation):
batch_group_id: BATCH_GROUP_IDS = InputField(
default="None",
description="The ID of this batch node's group. If provided, all batch nodes in with the same ID will be 'zipped' before execution, and all nodes' collections must be of the same size.",
input=Input.Direct,
title="Batch Group",
)

def __init__(self):
raise NotImplementedError("This class should never be executed or instantiated directly.")


@invocation(
"image_batch",
title="Image Batch",
tags=["primitives", "image", "batch", "special"],
category="primitives",
version="1.0.0",
classification=Classification.Special,
)
class ImageBatchInvocation(BaseBatchInvocation):
"""Create a batched generation, where the workflow is executed once for each image in the batch."""

images: list[ImageField] = InputField(
default=[], min_length=1, description="The images to batch over", input=Input.Direct
)

def invoke(self, context: InvocationContext) -> ImageOutput:
raise NotImplementedError("This class should never be executed or instantiated directly.")


@invocation(
"string_batch",
title="String Batch",
tags=["primitives", "string", "batch", "special"],
category="primitives",
version="1.0.0",
classification=Classification.Special,
)
class StringBatchInvocation(BaseBatchInvocation):
"""Create a batched generation, where the workflow is executed once for each string in the batch."""

strings: list[str] = InputField(
default=[], min_length=1, description="The strings to batch over", input=Input.Direct
)

def invoke(self, context: InvocationContext) -> StringOutput:
raise NotImplementedError("This class should never be executed or instantiated directly.")


@invocation(
"integer_batch",
title="Integer Batch",
tags=["primitives", "integer", "number", "batch", "special"],
category="primitives",
version="1.0.0",
classification=Classification.Special,
)
class IntegerBatchInvocation(BaseBatchInvocation):
"""Create a batched generation, where the workflow is executed once for each integer in the batch."""

integers: list[int] = InputField(
default=[], min_length=1, description="The integers to batch over", input=Input.Direct
)

def invoke(self, context: InvocationContext) -> IntegerOutput:
raise NotImplementedError("This class should never be executed or instantiated directly.")


@invocation(
"float_batch",
title="Float Batch",
tags=["primitives", "float", "number", "batch", "special"],
category="primitives",
version="1.0.0",
classification=Classification.Special,
)
class FloatBatchInvocation(BaseBatchInvocation):
"""Create a batched generation, where the workflow is executed once for each float in the batch."""

floats: list[float] = InputField(
default=[], min_length=1, description="The floats to batch over", input=Input.Direct
)

def invoke(self, context: InvocationContext) -> FloatOutput:
raise NotImplementedError("This class should never be executed or instantiated directly.")
19 changes: 16 additions & 3 deletions invokeai/frontend/web/public/locales/en.json
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,11 @@
"start": "Start",
"count": "Count",
"step": "Step",
"values": "Values"
"end": "End",
"min": "Min",
"max": "Max",
"values": "Values",
"resetToDefaults": "Reset to Defaults"
},
"hrf": {
"hrf": "High Resolution Fix",
Expand Down Expand Up @@ -854,7 +858,15 @@
"defaultVAE": "Default VAE"
},
"nodes": {
"arithmeticSequence": "Arithmetic Sequence",
"linearDistribution": "Linear Distribution",
"uniformRandomDistribution": "Uniform Random Distribution",
"parseString": "Parse String",
"splitOn": "Split On",
"noBatchGroup": "no group",
"generatorNRandomValues_one": "{{count}} random value",
"generatorNRandomValues_other": "{{count}} random values",
"generatorNoValues": "empty",
"addNode": "Add Node",
"addNodeToolTip": "Add Node (Shift+A, Space)",
"addLinearView": "Add to Linear View",
Expand Down Expand Up @@ -1035,9 +1047,10 @@
"missingFieldTemplate": "Missing field template",
"missingInputForField": "missing input",
"missingNodeTemplate": "Missing node template",
"collectionEmpty": "empty collection",
"invalidBatchConfiguration": "Invalid batch configuration",
"emptyBatches": "empty batches",
"batchNodeNotConnected": "Batch node not connected: {{label}}",
"batchNodeEmptyCollection": "Some batch nodes have empty collections",
"invalidBatchConfigurationCannotCalculate": "Invalid batch configuration; cannot calculate",
"collectionTooFewItems": "too few items, minimum {{minItems}}",
"collectionTooManyItems": "too many items, maximum {{maxItems}}",
"collectionStringTooLong": "too long, max {{maxLength}}",
Expand Down
4 changes: 0 additions & 4 deletions invokeai/frontend/web/src/app/components/App.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,6 @@ import { DynamicPromptsModal } from 'features/dynamicPrompts/components/DynamicP
import DeleteBoardModal from 'features/gallery/components/Boards/DeleteBoardModal';
import { ImageContextMenu } from 'features/gallery/components/ImageContextMenu/ImageContextMenu';
import { useStarterModelsToast } from 'features/modelManagerV2/hooks/useStarterModelsToast';
import { FloatRangeGeneratorModal } from 'features/nodes/components/FloatRangeGeneratorModal';
import { IntegerRangeGeneratorModal } from 'features/nodes/components/IntegerRangeGeneratorModal';
import { ShareWorkflowModal } from 'features/nodes/components/sidePanel/WorkflowListMenu/ShareWorkflowModal';
import { ClearQueueConfirmationsAlertDialog } from 'features/queue/components/ClearQueueConfirmationAlertDialog';
import { DeleteStylePresetDialog } from 'features/stylePresets/components/DeleteStylePresetDialog';
Expand Down Expand Up @@ -112,8 +110,6 @@ const App = ({ config = DEFAULT_CONFIG, studioInitAction }: Props) => {
<ImageContextMenu />
<FullscreenDropzone />
<VideosModal />
<FloatRangeGeneratorModal />
<IntegerRangeGeneratorModal />
</ErrorBoundary>
);
};
Expand Down
Loading
Loading