diff --git a/invokeai/app/invocations/batch.py b/invokeai/app/invocations/batch.py
new file mode 100644
index 00000000000..14b90338646
--- /dev/null
+++ b/invokeai/app/invocations/batch.py
@@ -0,0 +1,118 @@
+from typing import Literal
+
+from invokeai.app.invocations.baseinvocation import (
+ BaseInvocation,
+ Classification,
+ invocation,
+)
+from invokeai.app.invocations.fields import (
+ ImageField,
+ Input,
+ InputField,
+)
+from invokeai.app.invocations.primitives import FloatOutput, ImageOutput, IntegerOutput, StringOutput
+from invokeai.app.services.shared.invocation_context import InvocationContext
+
+BATCH_GROUP_IDS = Literal[
+ "None",
+ "Group 1",
+ "Group 2",
+ "Group 3",
+ "Group 4",
+ "Group 5",
+]
+
+
+class NotExecutableNodeError(Exception):
+ def __init__(self, message: str = "This class should never be executed or instantiated directly."):
+ super().__init__(message)
+
+ pass
+
+
+class BaseBatchInvocation(BaseInvocation):
+ batch_group_id: BATCH_GROUP_IDS = InputField(
+ default="None",
+ description="The ID of this batch node's group. If provided, all batch nodes in with the same ID will be 'zipped' before execution, and all nodes' collections must be of the same size.",
+ input=Input.Direct,
+ title="Batch Group",
+ )
+
+ def __init__(self):
+ raise NotExecutableNodeError()
+
+
+@invocation(
+ "image_batch",
+ title="Image Batch",
+ tags=["primitives", "image", "batch", "special"],
+ category="primitives",
+ version="1.0.0",
+ classification=Classification.Special,
+)
+class ImageBatchInvocation(BaseBatchInvocation):
+ """Create a batched generation, where the workflow is executed once for each image in the batch."""
+
+ images: list[ImageField] = InputField(
+ default=[], min_length=1, description="The images to batch over", input=Input.Direct
+ )
+
+ def invoke(self, context: InvocationContext) -> ImageOutput:
+ raise NotExecutableNodeError()
+
+
+@invocation(
+ "string_batch",
+ title="String Batch",
+ tags=["primitives", "string", "batch", "special"],
+ category="primitives",
+ version="1.0.0",
+ classification=Classification.Special,
+)
+class StringBatchInvocation(BaseBatchInvocation):
+ """Create a batched generation, where the workflow is executed once for each string in the batch."""
+
+ strings: list[str] = InputField(
+ default=[], min_length=1, description="The strings to batch over", input=Input.Direct
+ )
+
+ def invoke(self, context: InvocationContext) -> StringOutput:
+ raise NotExecutableNodeError()
+
+
+@invocation(
+ "integer_batch",
+ title="Integer Batch",
+ tags=["primitives", "integer", "number", "batch", "special"],
+ category="primitives",
+ version="1.0.0",
+ classification=Classification.Special,
+)
+class IntegerBatchInvocation(BaseBatchInvocation):
+ """Create a batched generation, where the workflow is executed once for each integer in the batch."""
+
+ integers: list[int] = InputField(
+ default=[], min_length=1, description="The integers to batch over", input=Input.Direct
+ )
+
+ def invoke(self, context: InvocationContext) -> IntegerOutput:
+ raise NotExecutableNodeError()
+
+
+@invocation(
+ "float_batch",
+ title="Float Batch",
+ tags=["primitives", "float", "number", "batch", "special"],
+ category="primitives",
+ version="1.0.0",
+ classification=Classification.Special,
+)
+class FloatBatchInvocation(BaseBatchInvocation):
+ """Create a batched generation, where the workflow is executed once for each float in the batch."""
+
+ floats: list[float] = InputField(
+ default=[], min_length=1, description="The floats to batch over", input=Input.Direct
+ )
+
+ def invoke(self, context: InvocationContext) -> FloatOutput:
+ raise NotExecutableNodeError()
diff --git a/invokeai/app/invocations/primitives.py b/invokeai/app/invocations/primitives.py
index 3d0ac3d87c4..97de3eb8981 100644
--- a/invokeai/app/invocations/primitives.py
+++ b/invokeai/app/invocations/primitives.py
@@ -1,13 +1,12 @@
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
-from typing import Literal, Optional
+from typing import Optional
import torch
from invokeai.app.invocations.baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
- Classification,
invocation,
invocation_output,
)
@@ -539,100 +538,3 @@ def invoke(self, context: InvocationContext) -> BoundingBoxOutput:
# endregion
-
-BATCH_GROUP_IDS = Literal[
- "None",
- "Group 1",
- "Group 2",
- "Group 3",
- "Group 4",
- "Group 5",
-]
-
-
-class BaseBatchInvocation(BaseInvocation):
- batch_group_id: BATCH_GROUP_IDS = InputField(
- default="None",
- description="The ID of this batch node's group. If provided, all batch nodes in with the same ID will be 'zipped' before execution, and all nodes' collections must be of the same size.",
- input=Input.Direct,
- title="Batch Group",
- )
-
- def __init__(self):
- raise NotImplementedError("This class should never be executed or instantiated directly.")
-
-
-@invocation(
- "image_batch",
- title="Image Batch",
- tags=["primitives", "image", "batch", "special"],
- category="primitives",
- version="1.0.0",
- classification=Classification.Special,
-)
-class ImageBatchInvocation(BaseBatchInvocation):
- """Create a batched generation, where the workflow is executed once for each image in the batch."""
-
- images: list[ImageField] = InputField(
- default=[], min_length=1, description="The images to batch over", input=Input.Direct
- )
-
- def invoke(self, context: InvocationContext) -> ImageOutput:
- raise NotImplementedError("This class should never be executed or instantiated directly.")
-
-
-@invocation(
- "string_batch",
- title="String Batch",
- tags=["primitives", "string", "batch", "special"],
- category="primitives",
- version="1.0.0",
- classification=Classification.Special,
-)
-class StringBatchInvocation(BaseBatchInvocation):
- """Create a batched generation, where the workflow is executed once for each string in the batch."""
-
- strings: list[str] = InputField(
- default=[], min_length=1, description="The strings to batch over", input=Input.Direct
- )
-
- def invoke(self, context: InvocationContext) -> StringOutput:
- raise NotImplementedError("This class should never be executed or instantiated directly.")
-
-
-@invocation(
- "integer_batch",
- title="Integer Batch",
- tags=["primitives", "integer", "number", "batch", "special"],
- category="primitives",
- version="1.0.0",
- classification=Classification.Special,
-)
-class IntegerBatchInvocation(BaseBatchInvocation):
- """Create a batched generation, where the workflow is executed once for each integer in the batch."""
-
- integers: list[int] = InputField(
- default=[], min_length=1, description="The integers to batch over", input=Input.Direct
- )
-
- def invoke(self, context: InvocationContext) -> IntegerOutput:
- raise NotImplementedError("This class should never be executed or instantiated directly.")
-
-
-@invocation(
- "float_batch",
- title="Float Batch",
- tags=["primitives", "float", "number", "batch", "special"],
- category="primitives",
- version="1.0.0",
- classification=Classification.Special,
-)
-class FloatBatchInvocation(BaseBatchInvocation):
- """Create a batched generation, where the workflow is executed once for each float in the batch."""
-
- floats: list[float] = InputField(
- default=[], min_length=1, description="The floats to batch over", input=Input.Direct
- )
-
- def invoke(self, context: InvocationContext) -> FloatOutput:
- raise NotImplementedError("This class should never be executed or instantiated directly.")
diff --git a/invokeai/frontend/web/public/locales/en.json b/invokeai/frontend/web/public/locales/en.json
index 799707e93b0..5253e3418df 100644
--- a/invokeai/frontend/web/public/locales/en.json
+++ b/invokeai/frontend/web/public/locales/en.json
@@ -855,7 +855,13 @@
},
"nodes": {
"noBatchGroup": "no group",
+ "generator": "Generator",
+ "generatedValues": "Generated Values",
+ "commitValues": "Commit Values",
+ "addValue": "Add Value",
"addNode": "Add Node",
+ "lockLinearView": "Lock Linear View",
+ "unlockLinearView": "Unlock Linear View",
"addNodeToolTip": "Add Node (Shift+A, Space)",
"addLinearView": "Add to Linear View",
"animatedEdges": "Animated Edges",
@@ -994,11 +1000,7 @@
"imageAccessError": "Unable to find image {{image_name}}, resetting to default",
"boardAccessError": "Unable to find board {{board_id}}, resetting to default",
"modelAccessError": "Unable to find model {{key}}, resetting to default",
- "saveToGallery": "Save To Gallery",
- "addItem": "Add Item",
- "generateValues": "Generate Values",
- "floatRangeGenerator": "Float Range Generator",
- "integerRangeGenerator": "Integer Range Generator"
+ "saveToGallery": "Save To Gallery"
},
"parameters": {
"aspect": "Aspect",
diff --git a/invokeai/frontend/web/src/app/components/App.tsx b/invokeai/frontend/web/src/app/components/App.tsx
index a0d7f481e26..2902c344adb 100644
--- a/invokeai/frontend/web/src/app/components/App.tsx
+++ b/invokeai/frontend/web/src/app/components/App.tsx
@@ -22,8 +22,6 @@ import { DynamicPromptsModal } from 'features/dynamicPrompts/components/DynamicP
import DeleteBoardModal from 'features/gallery/components/Boards/DeleteBoardModal';
import { ImageContextMenu } from 'features/gallery/components/ImageContextMenu/ImageContextMenu';
import { useStarterModelsToast } from 'features/modelManagerV2/hooks/useStarterModelsToast';
-import { FloatRangeGeneratorModal } from 'features/nodes/components/FloatRangeGeneratorModal';
-import { IntegerRangeGeneratorModal } from 'features/nodes/components/IntegerRangeGeneratorModal';
import { ShareWorkflowModal } from 'features/nodes/components/sidePanel/WorkflowListMenu/ShareWorkflowModal';
import { ClearQueueConfirmationsAlertDialog } from 'features/queue/components/ClearQueueConfirmationAlertDialog';
import { DeleteStylePresetDialog } from 'features/stylePresets/components/DeleteStylePresetDialog';
@@ -112,8 +110,6 @@ const App = ({ config = DEFAULT_CONFIG, studioInitAction }: Props) => {
-
-
);
};
diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/enqueueRequestedNodes.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/enqueueRequestedNodes.ts
index 2e87e925d6e..06ab99a0fd9 100644
--- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/enqueueRequestedNodes.ts
+++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/enqueueRequestedNodes.ts
@@ -9,6 +9,7 @@ import {
isIntegerFieldCollectionInputInstance,
isStringFieldCollectionInputInstance,
} from 'features/nodes/types/field';
+import { resolveNumberFieldCollectionValue } from 'features/nodes/types/fieldValidators';
import type { InvocationNodeEdge } from 'features/nodes/types/invocation';
import { isBatchNode, isInvocationNode } from 'features/nodes/types/invocation';
import { buildNodesGraph } from 'features/nodes/util/graph/buildNodesGraph';
@@ -140,10 +141,11 @@ export const addEnqueueRequestedNodes = (startAppListening: AppStartListening) =
// Find outgoing edges from the batch node, we will remove these from the graph and create batch data collection items from them instead
const edgesFromStringBatch = nodes.edges.filter((e) => e.source === node.id && e.sourceHandle === 'value');
+ const resolvedValue = resolveNumberFieldCollectionValue(integers);
if (batchGroupId !== 'None') {
- addZippedBatchDataCollectionItem(edgesFromStringBatch, integers.value);
+ addZippedBatchDataCollectionItem(edgesFromStringBatch, resolvedValue);
} else {
- addProductBatchDataCollectionItem(edgesFromStringBatch, integers.value);
+ addProductBatchDataCollectionItem(edgesFromStringBatch, resolvedValue);
}
}
@@ -163,10 +165,11 @@ export const addEnqueueRequestedNodes = (startAppListening: AppStartListening) =
// Find outgoing edges from the batch node, we will remove these from the graph and create batch data collection items from them instead
const edgesFromStringBatch = nodes.edges.filter((e) => e.source === node.id && e.sourceHandle === 'value');
+ const resolvedValue = resolveNumberFieldCollectionValue(floats);
if (batchGroupId !== 'None') {
- addZippedBatchDataCollectionItem(edgesFromStringBatch, floats.value);
+ addZippedBatchDataCollectionItem(edgesFromStringBatch, resolvedValue);
} else {
- addProductBatchDataCollectionItem(edgesFromStringBatch, floats.value);
+ addProductBatchDataCollectionItem(edgesFromStringBatch, resolvedValue);
}
}
diff --git a/invokeai/frontend/web/src/app/store/store.ts b/invokeai/frontend/web/src/app/store/store.ts
index a36300cca98..83676c6b6cb 100644
--- a/invokeai/frontend/web/src/app/store/store.ts
+++ b/invokeai/frontend/web/src/app/store/store.ts
@@ -166,8 +166,10 @@ export const createStore = (uniqueStoreKey?: string, persist = true) =>
reducer: rememberedRootReducer,
middleware: (getDefaultMiddleware) =>
getDefaultMiddleware({
- serializableCheck: import.meta.env.MODE === 'development',
- immutableCheck: import.meta.env.MODE === 'development',
+ serializableCheck: false,
+ immutableCheck: false,
+ // serializableCheck: import.meta.env.MODE === 'development',
+ // immutableCheck: import.meta.env.MODE === 'development',
})
.concat(api.middleware)
.concat(dynamicMiddlewares)
diff --git a/invokeai/frontend/web/src/features/nodes/components/FloatRangeGeneratorModal.tsx b/invokeai/frontend/web/src/features/nodes/components/FloatRangeGeneratorModal.tsx
deleted file mode 100644
index 6ef8f72001f..00000000000
--- a/invokeai/frontend/web/src/features/nodes/components/FloatRangeGeneratorModal.tsx
+++ /dev/null
@@ -1,105 +0,0 @@
-import {
- Button,
- CompositeNumberInput,
- Flex,
- FormControl,
- FormLabel,
- IconButton,
- Modal,
- ModalBody,
- ModalCloseButton,
- ModalContent,
- ModalFooter,
- ModalHeader,
- ModalOverlay,
- Text,
-} from '@invoke-ai/ui-library';
-import { useStore } from '@nanostores/react';
-import { round } from 'lodash-es';
-import { atom } from 'nanostores';
-import { memo, useCallback, useMemo, useState } from 'react';
-import { useTranslation } from 'react-i18next';
-import { PiArrowCounterClockwiseBold } from 'react-icons/pi';
-
-type FloatRangeGeneratorModalState = {
- isOpen: boolean;
- onSave: (values: number[]) => void;
-};
-
-const $floatRangeGeneratorModal = atom({
- isOpen: false,
- onSave: () => {},
-});
-
-export const openFloatRangeGeneratorModal = (onSave: (values: number[]) => void) => {
- $floatRangeGeneratorModal.set({ ...$floatRangeGeneratorModal.get(), isOpen: true, onSave });
-};
-
-const onClose = () => {
- $floatRangeGeneratorModal.set({ ...$floatRangeGeneratorModal.get(), isOpen: false });
-};
-
-export const FloatRangeGeneratorModal = memo(() => {
- const { isOpen, onSave } = useStore($floatRangeGeneratorModal);
- const { t } = useTranslation();
-
- const [start, setStart] = useState(0);
- const [step, setStep] = useState(1);
- const [count, setCount] = useState(1);
-
- const values = useMemo(() => Array.from({ length: count }, (_, i) => start + i * step), [start, step, count]);
-
- const onReset = useCallback(() => {
- setStart(0);
- setStep(1);
- setCount(1);
- }, []);
-
- const onClickSave = useCallback(() => {
- onSave(values);
- onClose();
- }, [onSave, values]);
-
- return (
-
-
-
- {t('nodes.floatRangeGenerator')}
-
-
-
-
- {t('common.start')}
-
-
-
- {t('common.count')}
-
-
-
- {t('common.step')}
-
-
- } onClick={onReset} variant="ghost" />
-
-
- {t('common.values')}
-
-
- {values.map((val) => round(val, 2)).join(', ')}
-
-
-
-
-
-
-
-
-
-
- );
-});
-
-FloatRangeGeneratorModal.displayName = 'FloatRangeGeneratorModal';
diff --git a/invokeai/frontend/web/src/features/nodes/components/IntegerRangeGeneratorModal.tsx b/invokeai/frontend/web/src/features/nodes/components/IntegerRangeGeneratorModal.tsx
deleted file mode 100644
index aff6fd8dcba..00000000000
--- a/invokeai/frontend/web/src/features/nodes/components/IntegerRangeGeneratorModal.tsx
+++ /dev/null
@@ -1,103 +0,0 @@
-import {
- Button,
- CompositeNumberInput,
- Flex,
- FormControl,
- FormLabel,
- IconButton,
- Modal,
- ModalBody,
- ModalCloseButton,
- ModalContent,
- ModalFooter,
- ModalHeader,
- ModalOverlay,
- Text,
-} from '@invoke-ai/ui-library';
-import { useStore } from '@nanostores/react';
-import { atom } from 'nanostores';
-import { memo, useCallback, useMemo, useState } from 'react';
-import { useTranslation } from 'react-i18next';
-import { PiArrowCounterClockwiseBold } from 'react-icons/pi';
-
-type IntegerRangeGeneratorModalState = {
- isOpen: boolean;
- onSave: (values: number[]) => void;
-};
-
-const $integerRangeGeneratorModal = atom({
- isOpen: false,
- onSave: () => {},
-});
-
-export const openIntegerRangeGeneratorModal = (onSave: (values: number[]) => void) => {
- $integerRangeGeneratorModal.set({ ...$integerRangeGeneratorModal.get(), isOpen: true, onSave });
-};
-
-const onClose = () => {
- $integerRangeGeneratorModal.set({ ...$integerRangeGeneratorModal.get(), isOpen: false });
-};
-
-export const IntegerRangeGeneratorModal = memo(() => {
- const { isOpen, onSave } = useStore($integerRangeGeneratorModal);
- const { t } = useTranslation();
- const [start, setStart] = useState(0);
- const [step, setStep] = useState(1);
- const [count, setCount] = useState(1);
-
- const values = useMemo(() => Array.from({ length: count }, (_, i) => start + i * step), [start, step, count]);
-
- const onReset = useCallback(() => {
- setStart(0);
- setStep(1);
- setCount(1);
- }, []);
-
- const onClickSave = useCallback(() => {
- onSave(values);
- onClose();
- }, [onSave, values]);
-
- return (
-
-
-
- {t('nodes.integerRangeGenerator')}
-
-
-
-
- {t('common.start')}
-
-
-
- {t('common.count')}
-
-
-
- {t('common.step')}
-
-
- } onClick={onReset} variant="ghost" />
-
-
- {t('common.values')}
-
-
- {values.join(', ')}
-
-
-
-
-
-
-
-
-
-
- );
-});
-
-IntegerRangeGeneratorModal.displayName = 'IntegerRangeGeneratorModal';
diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/InvocationNode.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/InvocationNode.tsx
index baa7fc262a8..46fb70b6598 100644
--- a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/InvocationNode.tsx
+++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/InvocationNode.tsx
@@ -43,7 +43,7 @@ const InvocationNode = ({ nodeId, isOpen, label, type, selected }: Props) => {
{fieldNames.connectionFields.map((fieldName, i) => (
-
+
))}
@@ -59,7 +59,7 @@ const InvocationNode = ({ nodeId, isOpen, label, type, selected }: Props) => {
nodeId={nodeId}
fieldName={fieldName}
>
-
+
))}
{fieldNames.missingFields.map((fieldName) => (
@@ -68,7 +68,7 @@ const InvocationNode = ({ nodeId, isOpen, label, type, selected }: Props) => {
nodeId={nodeId}
fieldName={fieldName}
>
-
+
))}
diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/FieldLinearViewToggle.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/FieldLinearViewToggle.tsx
index ff59e029167..ab1e0acccca 100644
--- a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/FieldLinearViewToggle.tsx
+++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/FieldLinearViewToggle.tsx
@@ -1,7 +1,7 @@
import { IconButton } from '@invoke-ai/ui-library';
import { createSelector } from '@reduxjs/toolkit';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
-import { useFieldValue } from 'features/nodes/hooks/useFieldValue';
+import { useFieldInputInstance } from 'features/nodes/hooks/useFieldInputInstance';
import {
selectWorkflowSlice,
workflowExposedFieldAdded,
@@ -19,7 +19,7 @@ type Props = {
const FieldLinearViewToggle = ({ nodeId, fieldName }: Props) => {
const dispatch = useAppDispatch();
const { t } = useTranslation();
- const value = useFieldValue(nodeId, fieldName);
+ const field = useFieldInputInstance(nodeId, fieldName);
const selectIsExposed = useMemo(
() =>
createSelector(selectWorkflowSlice, (workflow) => {
@@ -31,8 +31,11 @@ const FieldLinearViewToggle = ({ nodeId, fieldName }: Props) => {
const isExposed = useAppSelector(selectIsExposed);
const handleExposeField = useCallback(() => {
- dispatch(workflowExposedFieldAdded({ nodeId, fieldName, value }));
- }, [dispatch, fieldName, nodeId, value]);
+ if (!field) {
+ return;
+ }
+ dispatch(workflowExposedFieldAdded({ nodeId, fieldName, field }));
+ }, [dispatch, field, fieldName, nodeId]);
const handleUnexposeField = useCallback(() => {
dispatch(workflowExposedFieldRemoved({ nodeId, fieldName }));
diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/InputField.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/InputField.tsx
index 9a05a1c6b53..2597f29fd60 100644
--- a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/InputField.tsx
+++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/InputField.tsx
@@ -14,9 +14,10 @@ import { InputFieldWrapper } from './InputFieldWrapper';
interface Props {
nodeId: string;
fieldName: string;
+ isLinearView: boolean;
}
-const InputField = ({ nodeId, fieldName }: Props) => {
+const InputField = ({ nodeId, fieldName, isLinearView }: Props) => {
const fieldTemplate = useFieldInputTemplate(nodeId, fieldName);
const [isHovered, setIsHovered] = useState(false);
const isInvalid = useFieldIsInvalid(nodeId, fieldName);
@@ -69,12 +70,12 @@ const InputField = ({ nodeId, fieldName }: Props) => {
px={2}
>
-
+
{isHovered && }
{isHovered && }
-
+
diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/InputFieldRenderer.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/InputFieldRenderer.tsx
index 4095648e370..fe8500814df 100644
--- a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/InputFieldRenderer.tsx
+++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/InputFieldRenderer.tsx
@@ -99,109 +99,285 @@ import VAEModelFieldInputComponent from './inputs/VAEModelFieldInputComponent';
type InputFieldProps = {
nodeId: string;
fieldName: string;
+ isLinearView: boolean;
};
-const InputFieldRenderer = ({ nodeId, fieldName }: InputFieldProps) => {
+const InputFieldRenderer = ({ nodeId, fieldName, isLinearView }: InputFieldProps) => {
const fieldInstance = useFieldInputInstance(nodeId, fieldName);
const fieldTemplate = useFieldInputTemplate(nodeId, fieldName);
if (isStringFieldCollectionInputInstance(fieldInstance) && isStringFieldCollectionInputTemplate(fieldTemplate)) {
- return ;
+ return (
+
+ );
}
if (isStringFieldInputInstance(fieldInstance) && isStringFieldInputTemplate(fieldTemplate)) {
- return ;
+ return (
+
+ );
}
if (isBooleanFieldInputInstance(fieldInstance) && isBooleanFieldInputTemplate(fieldTemplate)) {
- return ;
+ return (
+
+ );
}
if (isIntegerFieldInputInstance(fieldInstance) && isIntegerFieldInputTemplate(fieldTemplate)) {
- return ;
+ return (
+
+ );
}
if (isFloatFieldInputInstance(fieldInstance) && isFloatFieldInputTemplate(fieldTemplate)) {
- return ;
+ return (
+
+ );
}
if (isIntegerFieldCollectionInputInstance(fieldInstance) && isIntegerFieldCollectionInputTemplate(fieldTemplate)) {
- return ;
+ return (
+
+ );
}
if (isFloatFieldCollectionInputInstance(fieldInstance) && isFloatFieldCollectionInputTemplate(fieldTemplate)) {
- return ;
+ return (
+
+ );
}
if (isEnumFieldInputInstance(fieldInstance) && isEnumFieldInputTemplate(fieldTemplate)) {
- return ;
+ return (
+
+ );
}
if (isImageFieldCollectionInputInstance(fieldInstance) && isImageFieldCollectionInputTemplate(fieldTemplate)) {
- return ;
+ return (
+
+ );
}
if (isImageFieldInputInstance(fieldInstance) && isImageFieldInputTemplate(fieldTemplate)) {
- return ;
+ return (
+
+ );
}
if (isBoardFieldInputInstance(fieldInstance) && isBoardFieldInputTemplate(fieldTemplate)) {
- return ;
+ return (
+
+ );
}
if (isMainModelFieldInputInstance(fieldInstance) && isMainModelFieldInputTemplate(fieldTemplate)) {
- return ;
+ return (
+
+ );
}
if (isModelIdentifierFieldInputInstance(fieldInstance) && isModelIdentifierFieldInputTemplate(fieldTemplate)) {
- return ;
+ return (
+
+ );
}
if (isSDXLRefinerModelFieldInputInstance(fieldInstance) && isSDXLRefinerModelFieldInputTemplate(fieldTemplate)) {
- return ;
+ return (
+
+ );
}
if (isVAEModelFieldInputInstance(fieldInstance) && isVAEModelFieldInputTemplate(fieldTemplate)) {
- return ;
+ return (
+
+ );
}
if (isT5EncoderModelFieldInputInstance(fieldInstance) && isT5EncoderModelFieldInputTemplate(fieldTemplate)) {
- return ;
+ return (
+
+ );
}
if (isCLIPEmbedModelFieldInputInstance(fieldInstance) && isCLIPEmbedModelFieldInputTemplate(fieldTemplate)) {
- return ;
+ return (
+
+ );
}
if (isCLIPLEmbedModelFieldInputInstance(fieldInstance) && isCLIPLEmbedModelFieldInputTemplate(fieldTemplate)) {
- return ;
+ return (
+
+ );
}
if (isCLIPGEmbedModelFieldInputInstance(fieldInstance) && isCLIPGEmbedModelFieldInputTemplate(fieldTemplate)) {
- return ;
+ return (
+
+ );
}
if (isControlLoRAModelFieldInputInstance(fieldInstance) && isControlLoRAModelFieldInputTemplate(fieldTemplate)) {
- return ;
+ return (
+
+ );
}
if (isFluxVAEModelFieldInputInstance(fieldInstance) && isFluxVAEModelFieldInputTemplate(fieldTemplate)) {
- return ;
+ return (
+
+ );
}
if (isLoRAModelFieldInputInstance(fieldInstance) && isLoRAModelFieldInputTemplate(fieldTemplate)) {
- return ;
+ return (
+
+ );
}
if (isControlNetModelFieldInputInstance(fieldInstance) && isControlNetModelFieldInputTemplate(fieldTemplate)) {
- return ;
+ return (
+
+ );
}
if (isIPAdapterModelFieldInputInstance(fieldInstance) && isIPAdapterModelFieldInputTemplate(fieldTemplate)) {
- return ;
+ return (
+
+ );
}
if (isT2IAdapterModelFieldInputInstance(fieldInstance) && isT2IAdapterModelFieldInputTemplate(fieldTemplate)) {
- return ;
+ return (
+
+ );
}
if (
@@ -213,28 +389,64 @@ const InputFieldRenderer = ({ nodeId, fieldName }: InputFieldProps) => {
nodeId={nodeId}
field={fieldInstance}
fieldTemplate={fieldTemplate}
+ isLinearView={isLinearView}
/>
);
}
if (isColorFieldInputInstance(fieldInstance) && isColorFieldInputTemplate(fieldTemplate)) {
- return ;
+ return (
+
+ );
}
if (isFluxMainModelFieldInputInstance(fieldInstance) && isFluxMainModelFieldInputTemplate(fieldTemplate)) {
- return ;
+ return (
+
+ );
}
if (isSD3MainModelFieldInputInstance(fieldInstance) && isSD3MainModelFieldInputTemplate(fieldTemplate)) {
- return ;
+ return (
+
+ );
}
if (isSDXLMainModelFieldInputInstance(fieldInstance) && isSDXLMainModelFieldInputTemplate(fieldTemplate)) {
- return ;
+ return (
+
+ );
}
if (isSchedulerFieldInputInstance(fieldInstance) && isSchedulerFieldInputTemplate(fieldTemplate)) {
- return ;
+ return (
+
+ );
}
if (fieldTemplate) {
diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/LinearViewField.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/LinearViewField.tsx
index a6d36c00389..859659ad235 100644
--- a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/LinearViewField.tsx
+++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/LinearViewField.tsx
@@ -97,7 +97,11 @@ const LinearViewFieldInternal = ({ fieldIdentifier }: Props) => {
icon={}
/>
-
+
diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/FloatRangeGenerator.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/FloatRangeGenerator.tsx
new file mode 100644
index 00000000000..8e4750f0d74
--- /dev/null
+++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/FloatRangeGenerator.tsx
@@ -0,0 +1,65 @@
+import { CompositeNumberInput, Flex, FormControl, FormLabel, IconButton } from '@invoke-ai/ui-library';
+import {
+ type FloatRangeStartStepCountGenerator,
+ getDefaultFloatRangeStartStepCountGenerator,
+} from 'features/nodes/types/generators';
+import { memo, useCallback } from 'react';
+import { useTranslation } from 'react-i18next';
+import { PiArrowCounterClockwiseBold } from 'react-icons/pi';
+
+type FloatRangeGeneratorProps = {
+ state: FloatRangeStartStepCountGenerator;
+ onChange: (state: FloatRangeStartStepCountGenerator) => void;
+};
+
+export const FloatRangeGenerator = memo(({ state, onChange }: FloatRangeGeneratorProps) => {
+ const { t } = useTranslation();
+
+ const onChangeStart = useCallback(
+ (start: number) => {
+ onChange({ ...state, start });
+ },
+ [onChange, state]
+ );
+ const onChangeStep = useCallback(
+ (step: number) => {
+ onChange({ ...state, step });
+ },
+ [onChange, state]
+ );
+ const onChangeCount = useCallback(
+ (count: number) => {
+ onChange({ ...state, count });
+ },
+ [onChange, state]
+ );
+
+ const onReset = useCallback(() => {
+ onChange(getDefaultFloatRangeStartStepCountGenerator());
+ }, [onChange]);
+
+ return (
+
+
+ {t('common.start')}
+
+
+
+ {t('common.count')}
+
+
+
+ {t('common.step')}
+
+
+ }
+ variant="ghost"
+ />
+
+ );
+});
+
+FloatRangeGenerator.displayName = 'FloatRangeGenerator';
diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/NumberFieldCollectionInputComponent.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/NumberFieldCollectionInputComponent.tsx
index fc0d4eefbdb..c0db9184203 100644
--- a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/NumberFieldCollectionInputComponent.tsx
+++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/NumberFieldCollectionInputComponent.tsx
@@ -1,33 +1,45 @@
import type { SystemStyleObject } from '@invoke-ai/ui-library';
import {
Button,
- ButtonGroup,
CompositeNumberInput,
Divider,
Flex,
+ FormControl,
FormLabel,
Grid,
GridItem,
IconButton,
+ Switch,
+ Text,
} from '@invoke-ai/ui-library';
import { NUMPY_RAND_MAX } from 'app/constants';
import { useAppStore } from 'app/store/nanostores/store';
import { getOverlayScrollbarsParams, overlayScrollbarsStyles } from 'common/components/OverlayScrollbars/constants';
-import { openFloatRangeGeneratorModal } from 'features/nodes/components/FloatRangeGeneratorModal';
-import { openIntegerRangeGeneratorModal } from 'features/nodes/components/IntegerRangeGeneratorModal';
+import { FloatRangeGenerator } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/FloatRangeGenerator';
import { useFieldIsInvalid } from 'features/nodes/hooks/useFieldIsInvalid';
-import { fieldNumberCollectionValueChanged } from 'features/nodes/store/nodesSlice';
+import {
+ fieldNumberCollectionGeneratorCommitted,
+ fieldNumberCollectionGeneratorStateChanged,
+ fieldNumberCollectionGeneratorToggled,
+ fieldNumberCollectionLockLinearViewToggled,
+ fieldNumberCollectionValueChanged,
+} from 'features/nodes/store/nodesSlice';
import type {
FloatFieldCollectionInputInstance,
FloatFieldCollectionInputTemplate,
IntegerFieldCollectionInputInstance,
IntegerFieldCollectionInputTemplate,
} from 'features/nodes/types/field';
-import { isNil } from 'lodash-es';
+import { resolveNumberFieldCollectionValue } from 'features/nodes/types/fieldValidators';
+import type {
+ FloatRangeStartStepCountGenerator,
+ IntegerRangeStartStepCountGenerator,
+} from 'features/nodes/types/generators';
+import { isNil, round } from 'lodash-es';
import { OverlayScrollbarsComponent } from 'overlayscrollbars-react';
import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
-import { PiXBold } from 'react-icons/pi';
+import { PiLockSimpleFill, PiLockSimpleOpenFill, PiXBold } from 'react-icons/pi';
import type { FieldComponentProps } from './types';
@@ -47,7 +59,7 @@ export const NumberFieldCollectionInputComponent = memo(
| FieldComponentProps
| FieldComponentProps
) => {
- const { nodeId, field, fieldTemplate } = props;
+ const { nodeId, field, fieldTemplate, isLinearView } = props;
const store = useAppStore();
const { t } = useTranslation();
@@ -77,17 +89,6 @@ export const NumberFieldCollectionInputComponent = memo(
store.dispatch(fieldNumberCollectionValueChanged({ nodeId, fieldName: field.name, value: newValue }));
}, [field.name, field.value, nodeId, store]);
- const onOpenGenerator = useCallback(() => {
- const onSave = (values: number[]) => {
- store.dispatch(fieldNumberCollectionValueChanged({ nodeId, fieldName: field.name, value: values }));
- };
- if (isIntegerField) {
- openIntegerRangeGeneratorModal(onSave);
- } else {
- openFloatRangeGeneratorModal(onSave);
- }
- }, [field.name, isIntegerField, nodeId, store]);
-
const min = useMemo(() => {
let min = -NUMPY_RAND_MAX;
if (!isNil(fieldTemplate.minimum)) {
@@ -124,6 +125,32 @@ export const NumberFieldCollectionInputComponent = memo(
return fieldTemplate.multipleOf;
}, [fieldTemplate.multipleOf, isIntegerField]);
+ const toggleGenerator = useCallback(() => {
+ store.dispatch(fieldNumberCollectionGeneratorToggled({ nodeId, fieldName: field.name }));
+ }, [field.name, nodeId, store]);
+
+ const onChangeGenerator = useCallback(
+ (generatorState: FloatRangeStartStepCountGenerator | IntegerRangeStartStepCountGenerator) => {
+ store.dispatch(fieldNumberCollectionGeneratorStateChanged({ nodeId, fieldName: field.name, generatorState }));
+ },
+ [field.name, nodeId, store]
+ );
+
+ const onCommitGenerator = useCallback(() => {
+ store.dispatch(fieldNumberCollectionGeneratorCommitted({ nodeId, fieldName: field.name }));
+ }, [field.name, nodeId, store]);
+
+ const onToggleLockLinearView = useCallback(() => {
+ store.dispatch(fieldNumberCollectionLockLinearViewToggled({ nodeId, fieldName: field.name }));
+ }, [field.name, nodeId, store]);
+
+ const valuesAsString = useMemo(() => {
+ const resolvedValue = resolveNumberFieldCollectionValue(field);
+ return resolvedValue ? resolvedValue.map((val) => round(val, 2)).join(', ') : '';
+ }, [field]);
+
+ const isLockedOnLinearView = !(field.lockLinearView && isLinearView);
+
return (
-
-
-
-
- {field.value && field.value.length > 0 && (
+
+ {!field.generator && (
+
+ )}
+ {field.generator && isLockedOnLinearView && (
+
+ }
+ onClick={onCommitGenerator}
+ variant="ghost"
+ flexGrow={1}
+ size="sm"
+ >
+ {t('nodes.commitValues')}
+
+ )}
+ {isLockedOnLinearView && (
+
+ {t('nodes.generator')}
+
+
+ )}
+ {!isLinearView && (
+ : }
+ variant="ghost"
+ size="sm"
+ />
+ )}
+
+ {!field.generator && field.value && field.value.length > 0 && (
<>
-
+ {!(field.lockLinearView && isLinearView) && }
>
)}
+ {field.generator && field.generator.type === 'float-range-generator-start-step-count' && (
+ <>
+ {!(field.lockLinearView && isLinearView) && }
+
+ >
+ )}
);
}
diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/types.ts b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/types.ts
index fea63960752..153a24ea548 100644
--- a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/types.ts
+++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/types.ts
@@ -4,4 +4,5 @@ export type FieldComponentProps {
-
+
);
};
diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useFieldIsInvalid.ts b/invokeai/frontend/web/src/features/nodes/hooks/useFieldIsInvalid.ts
index 2f97296087a..d97484cd180 100644
--- a/invokeai/frontend/web/src/features/nodes/hooks/useFieldIsInvalid.ts
+++ b/invokeai/frontend/web/src/features/nodes/hooks/useFieldIsInvalid.ts
@@ -63,13 +63,13 @@ export const useFieldIsInvalid = (nodeId: string, fieldName: string) => {
}
if (isIntegerFieldCollectionInputInstance(field) && isIntegerFieldCollectionInputTemplate(template)) {
- if (validateNumberFieldCollectionValue(field.value, template).length > 0) {
+ if (validateNumberFieldCollectionValue(field, template).length > 0) {
return true;
}
}
if (isFloatFieldCollectionInputInstance(field) && isFloatFieldCollectionInputTemplate(template)) {
- if (validateNumberFieldCollectionValue(field.value, template).length > 0) {
+ if (validateNumberFieldCollectionValue(field, template).length > 0) {
return true;
}
}
diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useFieldOriginalValue.ts b/invokeai/frontend/web/src/features/nodes/hooks/useFieldOriginalValue.ts
index a9ebc991e22..11f2c47acfe 100644
--- a/invokeai/frontend/web/src/features/nodes/hooks/useFieldOriginalValue.ts
+++ b/invokeai/frontend/web/src/features/nodes/hooks/useFieldOriginalValue.ts
@@ -1,8 +1,9 @@
-import { createSelector } from '@reduxjs/toolkit';
+import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
-import { useFieldValue } from 'features/nodes/hooks/useFieldValue';
+import { useFieldInputInstance } from 'features/nodes/hooks/useFieldInputInstance';
import { fieldValueReset } from 'features/nodes/store/nodesSlice';
import { selectWorkflowSlice } from 'features/nodes/store/workflowSlice';
+import { isFloatFieldCollectionInputInstance, isIntegerFieldCollectionInputInstance } from 'features/nodes/types/field';
import { isEqual } from 'lodash-es';
import { useCallback, useMemo } from 'react';
@@ -10,19 +11,38 @@ export const useFieldOriginalValue = (nodeId: string, fieldName: string) => {
const dispatch = useAppDispatch();
const selectOriginalExposedFieldValues = useMemo(
() =>
- createSelector(
- selectWorkflowSlice,
- (workflow) =>
- workflow.originalExposedFieldValues.find((v) => v.nodeId === nodeId && v.fieldName === fieldName)?.value
+ createMemoizedSelector(selectWorkflowSlice, (workflow) =>
+ workflow.originalExposedFieldValues.find((v) => v.nodeId === nodeId && v.fieldName === fieldName)
),
[nodeId, fieldName]
);
- const originalValue = useAppSelector(selectOriginalExposedFieldValues);
- const value = useFieldValue(nodeId, fieldName);
- const isValueChanged = useMemo(() => !isEqual(value, originalValue), [value, originalValue]);
+ const exposedField = useAppSelector(selectOriginalExposedFieldValues);
+ const field = useFieldInputInstance(nodeId, fieldName);
+ const isValueChanged = useMemo(() => {
+ if (!field) {
+ // Field is not found, so it is not changed
+ return false;
+ }
+ if (isFloatFieldCollectionInputInstance(field) && isFloatFieldCollectionInputInstance(exposedField?.field)) {
+ return !isEqual(field.generator, exposedField.field.generator);
+ }
+ if (isIntegerFieldCollectionInputInstance(field) && isIntegerFieldCollectionInputInstance(exposedField?.field)) {
+ return !isEqual(field.generator, exposedField.field.generator);
+ }
+ return !isEqual(field.value, exposedField?.field.value);
+ }, [field, exposedField]);
const onReset = useCallback(() => {
- dispatch(fieldValueReset({ nodeId, fieldName, value: originalValue }));
- }, [dispatch, fieldName, nodeId, originalValue]);
+ if (!exposedField) {
+ return;
+ }
+ const { value } = exposedField.field;
+ const generator =
+ isIntegerFieldCollectionInputInstance(exposedField.field) ||
+ isFloatFieldCollectionInputInstance(exposedField.field)
+ ? exposedField.field.generator
+ : undefined;
+ dispatch(fieldValueReset({ nodeId, fieldName, value, generator }));
+ }, [dispatch, fieldName, nodeId, exposedField]);
- return { originalValue, isValueChanged, onReset };
+ return { originalValue: exposedField, isValueChanged, onReset };
};
diff --git a/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts b/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts
index 5b031dc0727..72b329c0f22 100644
--- a/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts
+++ b/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts
@@ -36,6 +36,8 @@ import type {
VAEModelFieldValue,
} from 'features/nodes/types/field';
import {
+ isFloatFieldCollectionInputInstance,
+ isIntegerFieldCollectionInputInstance,
zBoardFieldValue,
zBooleanFieldValue,
zCLIPEmbedModelFieldValue,
@@ -66,6 +68,16 @@ import {
zT5EncoderModelFieldValue,
zVAEModelFieldValue,
} from 'features/nodes/types/field';
+import type {
+ FloatRangeStartStepCountGenerator,
+ IntegerRangeStartStepCountGenerator,
+} from 'features/nodes/types/generators';
+import {
+ floatRangeStartStepCountGenerator,
+ getDefaultFloatRangeStartStepCountGenerator,
+ getDefaultIntegerRangeStartStepCountGenerator,
+ integerRangeStartStepCountGenerator,
+} from 'features/nodes/types/generators';
import type { AnyNode, InvocationNodeEdge } from 'features/nodes/types/invocation';
import { isInvocationNode, isNotesNode } from 'features/nodes/types/invocation';
import { atom, computed } from 'nanostores';
@@ -83,11 +95,22 @@ const initialNodesState: NodesState = {
edges: [],
};
-type FieldValueAction = PayloadAction<{
- nodeId: string;
- fieldName: string;
- value: T;
-}>;
+type FieldValueAction = PayloadAction<
+ {
+ nodeId: string;
+ fieldName: string;
+ value: T;
+ } & U
+>;
+
+const selectField = (state: NodesState, nodeId: string, fieldName: string) => {
+ const nodeIndex = state.nodes.findIndex((n) => n.id === nodeId);
+ const node = state.nodes?.[nodeIndex];
+ if (!isInvocationNode(node)) {
+ return;
+ }
+ return node.data?.inputs[fieldName];
+};
const fieldValueReducer = (
state: NodesState,
@@ -95,17 +118,24 @@ const fieldValueReducer = (
schema: z.ZodTypeAny
) => {
const { nodeId, fieldName, value } = action.payload;
- const nodeIndex = state.nodes.findIndex((n) => n.id === nodeId);
- const node = state.nodes?.[nodeIndex];
- if (!isInvocationNode(node)) {
- return;
- }
- const input = node.data?.inputs[fieldName];
+ const field = selectField(state, nodeId, fieldName);
const result = schema.safeParse(value);
- if (!input || nodeIndex < 0 || !result.success) {
+ if (!field || !result.success) {
return;
}
- input.value = result.data;
+ field.value = result.data;
+ // Special handling if the field value is being reset
+ if (result.data === undefined) {
+ if (isFloatFieldCollectionInputInstance(field)) {
+ if (field.lockLinearView && field.generator) {
+ field.generator = getDefaultFloatRangeStartStepCountGenerator();
+ }
+ } else if (isIntegerFieldCollectionInputInstance(field)) {
+ if (field.lockLinearView && field.generator) {
+ field.generator = getDefaultIntegerRangeStartStepCountGenerator();
+ }
+ }
+ }
};
export const nodesSlice = createSlice({
@@ -310,8 +340,31 @@ export const nodesSlice = createSlice({
}
node.data.notes = notes;
},
- fieldValueReset: (state, action: FieldValueAction) => {
- fieldValueReducer(state, action, zStatefulFieldValue);
+ fieldValueReset: (
+ state,
+ action: FieldValueAction<
+ StatefulFieldValue,
+ { generator?: IntegerRangeStartStepCountGenerator | FloatRangeStartStepCountGenerator }
+ >
+ ) => {
+ const { nodeId, fieldName, value, generator } = action.payload;
+ const field = selectField(state, nodeId, fieldName);
+ const result = zStatefulFieldValue.safeParse(value);
+
+ if (!field || !result.success) {
+ return;
+ }
+
+ field.value = result.data;
+
+ if (isFloatFieldCollectionInputInstance(field) && generator?.type === 'float-range-generator-start-step-count') {
+ field.generator = generator;
+ } else if (
+ isIntegerFieldCollectionInputInstance(field) &&
+ generator?.type === 'integer-range-generator-start-step-count'
+ ) {
+ field.generator = generator;
+ }
},
fieldStringValueChanged: (state, action: FieldValueAction) => {
fieldValueReducer(state, action, zStringFieldValue);
@@ -325,6 +378,85 @@ export const nodesSlice = createSlice({
fieldNumberCollectionValueChanged: (state, action: FieldValueAction) => {
fieldValueReducer(state, action, zIntegerFieldCollectionValue.or(zFloatFieldCollectionValue));
},
+ fieldNumberCollectionGeneratorToggled: (state, action: PayloadAction<{ nodeId: string; fieldName: string }>) => {
+ const { nodeId, fieldName } = action.payload;
+ const field = selectField(state, nodeId, fieldName);
+ if (!field) {
+ return;
+ }
+ if (isFloatFieldCollectionInputInstance(field)) {
+ field.generator = field.generator ? undefined : getDefaultFloatRangeStartStepCountGenerator();
+ } else if (isIntegerFieldCollectionInputInstance(field)) {
+ field.generator = field.generator ? undefined : getDefaultIntegerRangeStartStepCountGenerator();
+ } else {
+ // This should never happen
+ }
+ },
+ fieldNumberCollectionGeneratorStateChanged: (
+ state,
+ action: PayloadAction<{
+ nodeId: string;
+ fieldName: string;
+ generatorState: FloatRangeStartStepCountGenerator | IntegerRangeStartStepCountGenerator;
+ }>
+ ) => {
+ const { nodeId, fieldName, generatorState } = action.payload;
+ const field = selectField(state, nodeId, fieldName);
+ if (!field) {
+ return;
+ }
+ if (
+ isFloatFieldCollectionInputInstance(field) &&
+ generatorState.type === 'float-range-generator-start-step-count'
+ ) {
+ field.generator = generatorState;
+ } else if (
+ isIntegerFieldCollectionInputInstance(field) &&
+ generatorState.type === 'integer-range-generator-start-step-count'
+ ) {
+ field.generator = generatorState;
+ } else {
+ // This should never happen
+ }
+ },
+ fieldNumberCollectionGeneratorCommitted: (state, action: PayloadAction<{ nodeId: string; fieldName: string }>) => {
+ const { nodeId, fieldName } = action.payload;
+ const field = selectField(state, nodeId, fieldName);
+ if (!field) {
+ return;
+ }
+ if (
+ isFloatFieldCollectionInputInstance(field) &&
+ field.generator &&
+ field.generator.type === 'float-range-generator-start-step-count'
+ ) {
+ field.value = floatRangeStartStepCountGenerator(field.generator);
+ field.generator = undefined;
+ } else if (
+ isIntegerFieldCollectionInputInstance(field) &&
+ field.generator &&
+ field.generator.type === 'integer-range-generator-start-step-count'
+ ) {
+ field.value = integerRangeStartStepCountGenerator(field.generator);
+ field.generator = undefined;
+ } else {
+ // This should never happen
+ }
+ },
+ fieldNumberCollectionLockLinearViewToggled: (
+ state,
+ action: PayloadAction<{ nodeId: string; fieldName: string }>
+ ) => {
+ const { nodeId, fieldName } = action.payload;
+ const field = selectField(state, nodeId, fieldName);
+ if (!field) {
+ return;
+ }
+ if (!isFloatFieldCollectionInputInstance(field) && !isIntegerFieldCollectionInputInstance(field)) {
+ return;
+ }
+ field.lockLinearView = !field.lockLinearView;
+ },
fieldBooleanValueChanged: (state, action: FieldValueAction) => {
fieldValueReducer(state, action, zBooleanFieldValue);
},
@@ -447,6 +579,10 @@ export const {
fieldMainModelValueChanged,
fieldNumberValueChanged,
fieldNumberCollectionValueChanged,
+ fieldNumberCollectionGeneratorToggled,
+ fieldNumberCollectionGeneratorStateChanged,
+ fieldNumberCollectionGeneratorCommitted,
+ fieldNumberCollectionLockLinearViewToggled,
fieldRefinerModelValueChanged,
fieldSchedulerValueChanged,
fieldStringValueChanged,
diff --git a/invokeai/frontend/web/src/features/nodes/store/types.ts b/invokeai/frontend/web/src/features/nodes/store/types.ts
index 5684d0f2b33..36f51485b12 100644
--- a/invokeai/frontend/web/src/features/nodes/store/types.ts
+++ b/invokeai/frontend/web/src/features/nodes/store/types.ts
@@ -1,8 +1,8 @@
import type {
FieldIdentifier,
+ FieldInputInstance,
FieldInputTemplate,
FieldOutputTemplate,
- StatefulFieldValue,
} from 'features/nodes/types/field';
import type {
AnyNode,
@@ -31,15 +31,15 @@ export type NodesState = {
};
export type WorkflowMode = 'edit' | 'view';
-export type FieldIdentifierWithValue = FieldIdentifier & {
- value: StatefulFieldValue;
+export type FieldIdentifierWithInstance = FieldIdentifier & {
+ field: FieldInputInstance;
};
export type WorkflowsState = Omit & {
- _version: 1;
+ _version: 2;
isTouched: boolean;
mode: WorkflowMode;
- originalExposedFieldValues: FieldIdentifierWithValue[];
+ originalExposedFieldValues: FieldIdentifierWithInstance[];
searchTerm: string;
orderBy?: WorkflowRecordOrderBy;
orderDirection: SQLiteDirection;
diff --git a/invokeai/frontend/web/src/features/nodes/store/workflowSlice.ts b/invokeai/frontend/web/src/features/nodes/store/workflowSlice.ts
index b1663b2a11d..bf6a49bb9d3 100644
--- a/invokeai/frontend/web/src/features/nodes/store/workflowSlice.ts
+++ b/invokeai/frontend/web/src/features/nodes/store/workflowSlice.ts
@@ -5,7 +5,7 @@ import { deepClone } from 'common/util/deepClone';
import { workflowLoaded } from 'features/nodes/store/actions';
import { isAnyNodeOrEdgeMutation, nodeEditorReset, nodesChanged } from 'features/nodes/store/nodesSlice';
import type {
- FieldIdentifierWithValue,
+ FieldIdentifierWithInstance,
WorkflowMode,
WorkflowsState as WorkflowState,
} from 'features/nodes/store/types';
@@ -31,7 +31,7 @@ const blankWorkflow: Omit = {
};
const initialWorkflowState: WorkflowState = {
- _version: 1,
+ _version: 2,
isTouched: false,
mode: 'view',
originalExposedFieldValues: [],
@@ -62,7 +62,7 @@ export const workflowSlice = createSlice({
const { id, isOpen } = action.payload;
state.categorySections[id] = isOpen;
},
- workflowExposedFieldAdded: (state, action: PayloadAction) => {
+ workflowExposedFieldAdded: (state, action: PayloadAction) => {
state.exposedFields = uniqBy(
state.exposedFields.concat(omit(action.payload, 'value')),
(field) => `${field.nodeId}-${field.fieldName}`
@@ -128,25 +128,25 @@ export const workflowSlice = createSlice({
builder.addCase(workflowLoaded, (state, action) => {
const { nodes, edges: _edges, ...workflowExtra } = action.payload;
- const originalExposedFieldValues: FieldIdentifierWithValue[] = [];
+ const originalExposedFieldValues: FieldIdentifierWithInstance[] = [];
- workflowExtra.exposedFields.forEach((field) => {
- const node = nodes.find((n) => n.id === field.nodeId);
+ workflowExtra.exposedFields.forEach(({ nodeId, fieldName }) => {
+ const node = nodes.find((n) => n.id === nodeId);
if (!isInvocationNode(node)) {
return;
}
- const input = node.data.inputs[field.fieldName];
+ const field = node.data.inputs[fieldName];
- if (!input) {
+ if (!field) {
return;
}
const originalExposedFieldValue = {
- nodeId: field.nodeId,
- fieldName: field.fieldName,
- value: input.value,
+ nodeId,
+ fieldName,
+ field,
};
originalExposedFieldValues.push(originalExposedFieldValue);
});
@@ -243,6 +243,9 @@ const migrateWorkflowState = (state: any): any => {
if (!('_version' in state)) {
state._version = 1;
}
+ if (state._version === 1) {
+ return deepClone(initialWorkflowState);
+ }
return state;
};
diff --git a/invokeai/frontend/web/src/features/nodes/types/field.ts b/invokeai/frontend/web/src/features/nodes/types/field.ts
index 1ccc6ec9bbe..f7b599a62ff 100644
--- a/invokeai/frontend/web/src/features/nodes/types/field.ts
+++ b/invokeai/frontend/web/src/features/nodes/types/field.ts
@@ -1,3 +1,7 @@
+import {
+ zFloatRangeStartStepCountGenerator,
+ zIntegerRangeStartStepCountGenerator,
+} from 'features/nodes/types/generators';
import { buildTypeGuard } from 'features/parameters/types/parameterSchemas';
import { z } from 'zod';
@@ -282,6 +286,8 @@ export const isIntegerFieldInputTemplate = buildTypeGuard(zIntegerFieldInputTemp
export const zIntegerFieldCollectionValue = z.array(zIntegerFieldValue).optional();
const zIntegerFieldCollectionInputInstance = zFieldInputInstanceBase.extend({
value: zIntegerFieldCollectionValue,
+ generator: zIntegerRangeStartStepCountGenerator.optional(),
+ lockLinearView: z.boolean().default(false),
});
const zIntegerFieldCollectionInputTemplate = zFieldInputTemplateBase
.extend({
@@ -343,9 +349,12 @@ export const isFloatFieldInputTemplate = buildTypeGuard(zFloatFieldInputTemplate
// #endregion
// #region FloatField Collection
+
export const zFloatFieldCollectionValue = z.array(zFloatFieldValue).optional();
const zFloatFieldCollectionInputInstance = zFieldInputInstanceBase.extend({
value: zFloatFieldCollectionValue,
+ generator: zFloatRangeStartStepCountGenerator.optional(),
+ lockLinearView: z.boolean().default(false),
});
const zFloatFieldCollectionInputTemplate = zFieldInputTemplateBase
.extend({
@@ -373,7 +382,6 @@ const zFloatFieldCollectionInputTemplate = zFieldInputTemplateBase
const zFloatFieldCollectionOutputTemplate = zFieldOutputTemplateBase.extend({
type: zFloatCollectionFieldType,
});
-export type FloatFieldCollectionValue = z.infer;
export type FloatFieldCollectionInputInstance = z.infer;
export type FloatFieldCollectionInputTemplate = z.infer;
export const isFloatFieldCollectionInputInstance = buildTypeGuard(zFloatFieldCollectionInputInstance);
diff --git a/invokeai/frontend/web/src/features/nodes/types/fieldValidators.ts b/invokeai/frontend/web/src/features/nodes/types/fieldValidators.ts
index 4dbfc588c0f..8dd19b09556 100644
--- a/invokeai/frontend/web/src/features/nodes/types/fieldValidators.ts
+++ b/invokeai/frontend/web/src/features/nodes/types/fieldValidators.ts
@@ -1,13 +1,17 @@
import type {
+ FloatFieldCollectionInputInstance,
FloatFieldCollectionInputTemplate,
- FloatFieldCollectionValue,
ImageFieldCollectionInputTemplate,
ImageFieldCollectionValue,
+ IntegerFieldCollectionInputInstance,
IntegerFieldCollectionInputTemplate,
- IntegerFieldCollectionValue,
StringFieldCollectionInputTemplate,
StringFieldCollectionValue,
} from 'features/nodes/types/field';
+import {
+ floatRangeStartStepCountGenerator,
+ integerRangeStartStepCountGenerator,
+} from 'features/nodes/types/generators';
import { t } from 'i18next';
export const validateImageFieldCollectionValue = (
@@ -67,12 +71,31 @@ export const validateStringFieldCollectionValue = (
return reasons;
};
+export const resolveNumberFieldCollectionValue = (
+ field: IntegerFieldCollectionInputInstance | FloatFieldCollectionInputInstance
+): number[] | undefined => {
+ if (field.generator?.type === 'float-range-generator-start-step-count') {
+ return floatRangeStartStepCountGenerator(field.generator);
+ } else if (field.generator?.type === 'integer-range-generator-start-step-count') {
+ return integerRangeStartStepCountGenerator(field.generator);
+ } else {
+ return field.value;
+ }
+};
+
export const validateNumberFieldCollectionValue = (
- value: NonNullable | NonNullable,
+ field: IntegerFieldCollectionInputInstance | FloatFieldCollectionInputInstance,
template: IntegerFieldCollectionInputTemplate | FloatFieldCollectionInputTemplate
): string[] => {
const reasons: string[] = [];
const { minItems, maxItems, minimum, maximum, exclusiveMinimum, exclusiveMaximum, multipleOf } = template;
+ const value = resolveNumberFieldCollectionValue(field);
+
+ if (value === undefined) {
+ reasons.push(t('parameters.invoke.collectionEmpty'));
+ return reasons;
+ }
+
const count = value.length;
// Image collections may have min or max items to validate
diff --git a/invokeai/frontend/web/src/features/nodes/types/generators.ts b/invokeai/frontend/web/src/features/nodes/types/generators.ts
new file mode 100644
index 00000000000..da3848a1d2e
--- /dev/null
+++ b/invokeai/frontend/web/src/features/nodes/types/generators.ts
@@ -0,0 +1,29 @@
+import { z } from 'zod';
+
+export const zFloatRangeStartStepCountGenerator = z.object({
+ type: z.literal('float-range-generator-start-step-count').default('float-range-generator-start-step-count'),
+ start: z.number().default(0),
+ step: z.number().default(1),
+ count: z.number().int().default(10),
+});
+export type FloatRangeStartStepCountGenerator = z.infer;
+export const floatRangeStartStepCountGenerator = (generator: FloatRangeStartStepCountGenerator): number[] => {
+ const { start, step, count } = generator;
+ return Array.from({ length: count }, (_, i) => start + i * step);
+};
+export const getDefaultFloatRangeStartStepCountGenerator = (): FloatRangeStartStepCountGenerator =>
+ zFloatRangeStartStepCountGenerator.parse({});
+
+export const zIntegerRangeStartStepCountGenerator = z.object({
+ type: z.literal('integer-range-generator-start-step-count').default('integer-range-generator-start-step-count'),
+ start: z.number().int().default(0),
+ step: z.number().int().default(1),
+ count: z.number().int().default(10),
+});
+export type IntegerRangeStartStepCountGenerator = z.infer;
+export const integerRangeStartStepCountGenerator = (generator: IntegerRangeStartStepCountGenerator): number[] => {
+ const { start, step, count } = generator;
+ return Array.from({ length: count }, (_, i) => start + i * step);
+};
+export const getDefaultIntegerRangeStartStepCountGenerator = (): IntegerRangeStartStepCountGenerator =>
+ zIntegerRangeStartStepCountGenerator.parse({});
diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/buildNodesGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/buildNodesGraph.ts
index 4cc977593c6..4aff783912b 100644
--- a/invokeai/frontend/web/src/features/nodes/util/graph/buildNodesGraph.ts
+++ b/invokeai/frontend/web/src/features/nodes/util/graph/buildNodesGraph.ts
@@ -1,5 +1,7 @@
import { logger } from 'app/logging/logger';
import type { NodesState } from 'features/nodes/store/types';
+import { isFloatFieldCollectionInputInstance, isIntegerFieldCollectionInputInstance } from 'features/nodes/types/field';
+import { resolveNumberFieldCollectionValue } from 'features/nodes/types/fieldValidators';
import { isBatchNode, isInvocationNode } from 'features/nodes/types/invocation';
import { negate, omit, reduce } from 'lodash-es';
import type { AnyInvocation, Graph } from 'services/api/types';
@@ -25,7 +27,11 @@ export const buildNodesGraph = (nodesState: NodesState): Graph => {
const transformedInputs = reduce(
inputs,
(inputsAccumulator, input, name) => {
- inputsAccumulator[name] = input.value;
+ if (isFloatFieldCollectionInputInstance(input) || isIntegerFieldCollectionInputInstance(input)) {
+ inputsAccumulator[name] = resolveNumberFieldCollectionValue(input);
+ } else {
+ inputsAccumulator[name] = input.value;
+ }
return inputsAccumulator;
},
diff --git a/invokeai/frontend/web/src/features/queue/store/readiness.ts b/invokeai/frontend/web/src/features/queue/store/readiness.ts
index e652602153a..13e56551ccd 100644
--- a/invokeai/frontend/web/src/features/queue/store/readiness.ts
+++ b/invokeai/frontend/web/src/features/queue/store/readiness.ts
@@ -30,6 +30,7 @@ import {
isStringFieldCollectionInputTemplate,
} from 'features/nodes/types/field';
import {
+ resolveNumberFieldCollectionValue,
validateImageFieldCollectionValue,
validateNumberFieldCollectionValue,
validateStringFieldCollectionValue,
@@ -176,14 +177,14 @@ const getReasonsWhyCannotEnqueueWorkflowsTab = (arg: {
isIntegerFieldCollectionInputInstance(field) &&
isIntegerFieldCollectionInputTemplate(fieldTemplate)
) {
- const errors = validateNumberFieldCollectionValue(field.value, fieldTemplate);
+ const errors = validateNumberFieldCollectionValue(field, fieldTemplate);
reasons.push(...errors.map((error) => ({ prefix, content: error })));
} else if (
field.value &&
isFloatFieldCollectionInputInstance(field) &&
isFloatFieldCollectionInputTemplate(fieldTemplate)
) {
- const errors = validateNumberFieldCollectionValue(field.value, fieldTemplate);
+ const errors = validateNumberFieldCollectionValue(field, fieldTemplate);
reasons.push(...errors.map((error) => ({ prefix, content: error })));
}
});
@@ -555,10 +556,10 @@ const getBatchCollectionSize = (batchNode: InvocationNode) => {
return batchNode.data.inputs.strings.value?.length ?? 0;
} else if (batchNode.data.type === 'float_batch') {
assert(isFloatFieldCollectionInputInstance(batchNode.data.inputs.floats));
- return batchNode.data.inputs.floats.value?.length ?? 0;
+ return resolveNumberFieldCollectionValue(batchNode.data.inputs.floats)?.length ?? 0;
} else if (batchNode.data.type === 'integer_batch') {
assert(isIntegerFieldCollectionInputInstance(batchNode.data.inputs.integers));
- return batchNode.data.inputs.integers.value?.length ?? 0;
+ return resolveNumberFieldCollectionValue(batchNode.data.inputs.integers)?.length ?? 0;
}
return 0;
};