diff --git a/invokeai/app/invocations/fields.py b/invokeai/app/invocations/fields.py index c8c89ed31cb..3152dc9f3ba 100644 --- a/invokeai/app/invocations/fields.py +++ b/invokeai/app/invocations/fields.py @@ -42,6 +42,7 @@ class UIType(str, Enum, metaclass=MetaEnum): MainModel = "MainModelField" CogView4MainModel = "CogView4MainModelField" FluxMainModel = "FluxMainModelField" + BriaMainModel = "BriaMainModelField" SD3MainModel = "SD3MainModelField" SDXLMainModel = "SDXLMainModelField" SDXLRefinerModel = "SDXLRefinerModelField" diff --git a/invokeai/backend/model_manager/legacy_probe.py b/invokeai/backend/model_manager/legacy_probe.py index 8a0e770d037..caff085b373 100644 --- a/invokeai/backend/model_manager/legacy_probe.py +++ b/invokeai/backend/model_manager/legacy_probe.py @@ -125,6 +125,7 @@ class ModelProbe(object): } CLASS2TYPE = { + "BriaPipeline": ModelType.Main, "FluxPipeline": ModelType.Main, "StableDiffusionPipeline": ModelType.Main, "StableDiffusionInpaintPipeline": ModelType.Main, @@ -861,6 +862,8 @@ def get_base_type(self) -> BaseModelType: return BaseModelType.StableDiffusion3 elif transformer_conf["_class_name"] == "CogView4Transformer2DModel": return BaseModelType.CogView4 + elif transformer_conf["_class_name"] == "BriaTransformer2DModel": + return BaseModelType.Bria else: raise InvalidModelConfigException(f"Unknown base model for {self.model_path}") diff --git a/invokeai/backend/model_manager/load/model_loaders/bria.py b/invokeai/backend/model_manager/load/model_loaders/bria.py new file mode 100644 index 00000000000..6712e13896e --- /dev/null +++ b/invokeai/backend/model_manager/load/model_loaders/bria.py @@ -0,0 +1,56 @@ +from pathlib import Path +from typing import Optional + +from invokeai.backend.model_manager.config import ( + AnyModelConfig, + CheckpointConfigBase, + DiffusersConfigBase, +) +from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry +from invokeai.backend.model_manager.load.model_loaders.generic_diffusers import GenericDiffusersLoader +from invokeai.backend.model_manager.taxonomy import ( + AnyModel, + BaseModelType, + ModelFormat, + ModelType, + SubModelType, +) + + +@ModelLoaderRegistry.register(base=BaseModelType.Bria, type=ModelType.Main, format=ModelFormat.Diffusers) +class BriaDiffusersModel(GenericDiffusersLoader): + """Class to load Bria main models.""" + + def _load_model( + self, + config: AnyModelConfig, + submodel_type: Optional[SubModelType] = None, + ) -> AnyModel: + if isinstance(config, CheckpointConfigBase): + raise NotImplementedError("CheckpointConfigBase is not implemented for Bria models.") + + if submodel_type is None: + raise Exception("A submodel type must be provided when loading main pipelines.") + + model_path = Path(config.path) + load_class = self.get_hf_load_class(model_path, submodel_type) + repo_variant = config.repo_variant if isinstance(config, DiffusersConfigBase) else None + variant = repo_variant.value if repo_variant else None + model_path = model_path / submodel_type.value + + dtype = self._torch_dtype + try: + result: AnyModel = load_class.from_pretrained( + model_path, + torch_dtype=dtype, + variant=variant, + ) + except OSError as e: + if variant and "no file named" in str( + e + ): # try without the variant, just in case user's preferences changed + result = load_class.from_pretrained(model_path, torch_dtype=dtype) + else: + raise e + + return result diff --git a/invokeai/backend/model_manager/taxonomy.py b/invokeai/backend/model_manager/taxonomy.py index d77f5fc10ff..76c0ffe4f60 100644 --- a/invokeai/backend/model_manager/taxonomy.py +++ b/invokeai/backend/model_manager/taxonomy.py @@ -28,6 +28,7 @@ class BaseModelType(str, Enum): CogView4 = "cogview4" Imagen3 = "imagen3" ChatGPT4o = "chatgpt-4o" + Bria = "bria" class ModelType(str, Enum): diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/InputFieldRenderer.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/InputFieldRenderer.tsx index a4dc5414953..20fb19c1d31 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/InputFieldRenderer.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/InputFieldRenderer.tsx @@ -23,6 +23,8 @@ import { isBoardFieldInputTemplate, isBooleanFieldInputInstance, isBooleanFieldInputTemplate, + isBriaMainModelFieldInputInstance, + isBriaMainModelFieldInputTemplate, isCLIPEmbedModelFieldInputInstance, isCLIPEmbedModelFieldInputTemplate, isCLIPGEmbedModelFieldInputInstance, @@ -105,6 +107,7 @@ import { assert } from 'tsafe'; import BoardFieldInputComponent from './inputs/BoardFieldInputComponent'; import BooleanFieldInputComponent from './inputs/BooleanFieldInputComponent'; +import BriaMainModelFieldInputComponent from './inputs/BriaMainModelFieldInputComponent'; import CLIPEmbedModelFieldInputComponent from './inputs/CLIPEmbedModelFieldInputComponent'; import CLIPGEmbedModelFieldInputComponent from './inputs/CLIPGEmbedModelFieldInputComponent'; import CLIPLEmbedModelFieldInputComponent from './inputs/CLIPLEmbedModelFieldInputComponent'; @@ -408,6 +411,13 @@ export const InputFieldRenderer = memo(({ nodeId, fieldName, settings }: Props) return ; } + if (isBriaMainModelFieldInputTemplate(template)) { + if (!isBriaMainModelFieldInputInstance(field)) { + return null; + } + return ; + } + if (isSD3MainModelFieldInputTemplate(template)) { if (!isSD3MainModelFieldInputInstance(field)) { return null; diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/BriaMainModelFieldInputComponent.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/BriaMainModelFieldInputComponent.tsx new file mode 100644 index 00000000000..8d8af426a56 --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/BriaMainModelFieldInputComponent.tsx @@ -0,0 +1,44 @@ +import { useAppDispatch } from 'app/store/storeHooks'; +import { ModelFieldCombobox } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelFieldCombobox'; +import { fieldMainModelValueChanged } from 'features/nodes/store/nodesSlice'; +import type { BriaMainModelFieldInputInstance, BriaMainModelFieldInputTemplate } from 'features/nodes/types/field'; +import { memo, useCallback } from 'react'; +import { useBriaModels } from 'services/api/hooks/modelsByType'; +import type { MainModelConfig } from 'services/api/types'; + +import type { FieldComponentProps } from './types'; + +type Props = FieldComponentProps; + +const BriaMainModelFieldInputComponent = (props: Props) => { + const { nodeId, field } = props; + const dispatch = useAppDispatch(); + const [modelConfigs, { isLoading }] = useBriaModels(); + const onChange = useCallback( + (value: MainModelConfig | null) => { + if (!value) { + return; + } + dispatch( + fieldMainModelValueChanged({ + nodeId, + fieldName: field.name, + value, + }) + ); + }, + [dispatch, field.name, nodeId] + ); + + return ( + + ); +}; + +export default memo(BriaMainModelFieldInputComponent); diff --git a/invokeai/frontend/web/src/features/nodes/types/common.ts b/invokeai/frontend/web/src/features/nodes/types/common.ts index b1b17de0726..b1098e41c2b 100644 --- a/invokeai/frontend/web/src/features/nodes/types/common.ts +++ b/invokeai/frontend/web/src/features/nodes/types/common.ts @@ -77,9 +77,10 @@ const zBaseModel = z.enum([ 'cogview4', 'imagen3', 'chatgpt-4o', + 'bria', ]); export type BaseModelType = z.infer; -export const zMainModelBase = z.enum(['sd-1', 'sd-2', 'sd-3', 'sdxl', 'flux', 'cogview4', 'imagen3', 'chatgpt-4o']); +export const zMainModelBase = z.enum(['sd-1', 'sd-2', 'sd-3', 'sdxl', 'flux', 'cogview4', 'imagen3', 'chatgpt-4o', 'bria']); export type MainModelBase = z.infer; export const isMainModelBase = (base: unknown): base is MainModelBase => zMainModelBase.safeParse(base).success; const zModelType = z.enum([ diff --git a/invokeai/frontend/web/src/features/nodes/types/constants.ts b/invokeai/frontend/web/src/features/nodes/types/constants.ts index a8ab6d231e3..0e6131e4882 100644 --- a/invokeai/frontend/web/src/features/nodes/types/constants.ts +++ b/invokeai/frontend/web/src/features/nodes/types/constants.ts @@ -52,6 +52,7 @@ export const FIELD_COLORS: { [key: string]: string } = { LoRAModelField: 'teal.500', MainModelField: 'teal.500', FluxMainModelField: 'teal.500', + BriaMainModelField: 'teal.500', SD3MainModelField: 'teal.500', CogView4MainModelField: 'teal.500', SDXLMainModelField: 'teal.500', diff --git a/invokeai/frontend/web/src/features/nodes/types/field.ts b/invokeai/frontend/web/src/features/nodes/types/field.ts index 8bc856f8bd7..0cc42e1bc0b 100644 --- a/invokeai/frontend/web/src/features/nodes/types/field.ts +++ b/invokeai/frontend/web/src/features/nodes/types/field.ts @@ -184,6 +184,10 @@ const zFluxMainModelFieldType = zFieldTypeBase.extend({ name: z.literal('FluxMainModelField'), originalType: zStatelessFieldType.optional(), }); +const zBriaMainModelFieldType = zFieldTypeBase.extend({ + name: z.literal('BriaMainModelField'), + originalType: zStatelessFieldType.optional(), +}); const zSDXLRefinerModelFieldType = zFieldTypeBase.extend({ name: z.literal('SDXLRefinerModelField'), originalType: zStatelessFieldType.optional(), @@ -304,6 +308,7 @@ const zStatefulFieldType = z.union([ zIntegerGeneratorFieldType, zStringGeneratorFieldType, zImageGeneratorFieldType, + zBriaMainModelFieldType, ]); export type StatefulFieldType = z.infer; const statefulFieldTypeNames = zStatefulFieldType.options.map((o) => o.shape.name.value); @@ -320,6 +325,7 @@ const modelFieldTypeNames = [ zSD3MainModelFieldType.shape.name.value, zCogView4MainModelFieldType.shape.name.value, zFluxMainModelFieldType.shape.name.value, + zBriaMainModelFieldType.shape.name.value, zSDXLRefinerModelFieldType.shape.name.value, zVAEModelFieldType.shape.name.value, zLoRAModelFieldType.shape.name.value, @@ -863,6 +869,26 @@ export const isFluxMainModelFieldInputTemplate = buildTemplateTypeGuard('FluxMainModelField'); // #endregion +// #region BriaMainModelField +const zBriaMainModelFieldValue = zMainModelFieldValue; // TODO: Narrow to SDXL models only. +const zBriaMainModelFieldInputInstance = zFieldInputInstanceBase.extend({ + value: zBriaMainModelFieldValue, +}); +const zBriaMainModelFieldInputTemplate = zFieldInputTemplateBase.extend({ + type: zBriaMainModelFieldType, + originalType: zFieldType.optional(), + default: zBriaMainModelFieldValue, +}); +const zBriaMainModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({ + type: zBriaMainModelFieldType, +}); +export type BriaMainModelFieldInputInstance = z.infer; +export type BriaMainModelFieldInputTemplate = z.infer; +export const isBriaMainModelFieldInputInstance = buildInstanceTypeGuard(zBriaMainModelFieldInputInstance); +export const isBriaMainModelFieldInputTemplate = + buildTemplateTypeGuard('BriaMainModelField'); +// #endregion + // #region SDXLRefinerModelField /** @alias */ // tells knip to ignore this duplicate export export const zSDXLRefinerModelFieldValue = zMainModelFieldValue; // TODO: Narrow to SDXL Refiner models only. @@ -1790,6 +1816,7 @@ export const zStatefulFieldValue = z.union([ zMainModelFieldValue, zSDXLMainModelFieldValue, zFluxMainModelFieldValue, + zBriaMainModelFieldValue, zSD3MainModelFieldValue, zCogView4MainModelFieldValue, zSDXLRefinerModelFieldValue, @@ -1837,6 +1864,7 @@ const zStatefulFieldInputInstance = z.union([ zModelIdentifierFieldInputInstance, zMainModelFieldInputInstance, zFluxMainModelFieldInputInstance, + zBriaMainModelFieldInputInstance, zSD3MainModelFieldInputInstance, zCogView4MainModelFieldInputInstance, zSDXLMainModelFieldInputInstance, @@ -1879,6 +1907,7 @@ const zStatefulFieldInputTemplate = z.union([ zModelIdentifierFieldInputTemplate, zMainModelFieldInputTemplate, zFluxMainModelFieldInputTemplate, + zBriaMainModelFieldInputTemplate, zSD3MainModelFieldInputTemplate, zCogView4MainModelFieldInputTemplate, zSDXLMainModelFieldInputTemplate, @@ -1927,6 +1956,7 @@ const zStatefulFieldOutputTemplate = z.union([ zModelIdentifierFieldOutputTemplate, zMainModelFieldOutputTemplate, zFluxMainModelFieldOutputTemplate, + zBriaMainModelFieldOutputTemplate, zSD3MainModelFieldOutputTemplate, zCogView4MainModelFieldOutputTemplate, zSDXLMainModelFieldOutputTemplate, diff --git a/invokeai/frontend/web/src/features/nodes/util/schema/buildFieldInputInstance.ts b/invokeai/frontend/web/src/features/nodes/util/schema/buildFieldInputInstance.ts index edae5a78f13..9ff5f537de6 100644 --- a/invokeai/frontend/web/src/features/nodes/util/schema/buildFieldInputInstance.ts +++ b/invokeai/frontend/web/src/features/nodes/util/schema/buildFieldInputInstance.ts @@ -17,6 +17,7 @@ const FIELD_VALUE_FALLBACK_MAP: Record = SchedulerField: 'dpmpp_3m_k', SDXLMainModelField: undefined, FluxMainModelField: undefined, + BriaMainModelField: undefined, SD3MainModelField: undefined, CogView4MainModelField: undefined, SDXLRefinerModelField: undefined, diff --git a/invokeai/frontend/web/src/features/nodes/util/schema/buildFieldInputTemplate.ts b/invokeai/frontend/web/src/features/nodes/util/schema/buildFieldInputTemplate.ts index b458ea180cc..cda0effa945 100644 --- a/invokeai/frontend/web/src/features/nodes/util/schema/buildFieldInputTemplate.ts +++ b/invokeai/frontend/web/src/features/nodes/util/schema/buildFieldInputTemplate.ts @@ -2,6 +2,7 @@ import { FieldParseError } from 'features/nodes/types/error'; import type { BoardFieldInputTemplate, BooleanFieldInputTemplate, + BriaMainModelFieldInputTemplate, CLIPEmbedModelFieldInputTemplate, CLIPGEmbedModelFieldInputTemplate, CLIPLEmbedModelFieldInputTemplate, @@ -338,6 +339,20 @@ const buildFluxMainModelFieldInputTemplate: FieldInputTemplateBuilder = ({ + schemaObject, + baseField, + fieldType, +}) => { + const template: BriaMainModelFieldInputTemplate = { + ...baseField, + type: fieldType, + default: schemaObject.default ?? undefined, + }; + + return template; +}; + const buildSD3MainModelFieldInputTemplate: FieldInputTemplateBuilder = ({ schemaObject, baseField, @@ -778,6 +793,7 @@ export const TEMPLATE_BUILDER_MAP: Record { + return config.type === 'main' && config.base === 'bria'; +}; + export const isFluxFillMainModelModelConfig = (config: AnyModelConfig): config is MainModelConfig => { return config.type === 'main' && config.base === 'flux' && config.variant === 'inpaint'; };