diff --git a/invokeai/app/invocations/fields.py b/invokeai/app/invocations/fields.py
index c8c89ed31cb..3152dc9f3ba 100644
--- a/invokeai/app/invocations/fields.py
+++ b/invokeai/app/invocations/fields.py
@@ -42,6 +42,7 @@ class UIType(str, Enum, metaclass=MetaEnum):
MainModel = "MainModelField"
CogView4MainModel = "CogView4MainModelField"
FluxMainModel = "FluxMainModelField"
+ BriaMainModel = "BriaMainModelField"
SD3MainModel = "SD3MainModelField"
SDXLMainModel = "SDXLMainModelField"
SDXLRefinerModel = "SDXLRefinerModelField"
diff --git a/invokeai/backend/model_manager/legacy_probe.py b/invokeai/backend/model_manager/legacy_probe.py
index 8a0e770d037..caff085b373 100644
--- a/invokeai/backend/model_manager/legacy_probe.py
+++ b/invokeai/backend/model_manager/legacy_probe.py
@@ -125,6 +125,7 @@ class ModelProbe(object):
}
CLASS2TYPE = {
+ "BriaPipeline": ModelType.Main,
"FluxPipeline": ModelType.Main,
"StableDiffusionPipeline": ModelType.Main,
"StableDiffusionInpaintPipeline": ModelType.Main,
@@ -861,6 +862,8 @@ def get_base_type(self) -> BaseModelType:
return BaseModelType.StableDiffusion3
elif transformer_conf["_class_name"] == "CogView4Transformer2DModel":
return BaseModelType.CogView4
+ elif transformer_conf["_class_name"] == "BriaTransformer2DModel":
+ return BaseModelType.Bria
else:
raise InvalidModelConfigException(f"Unknown base model for {self.model_path}")
diff --git a/invokeai/backend/model_manager/load/model_loaders/bria.py b/invokeai/backend/model_manager/load/model_loaders/bria.py
new file mode 100644
index 00000000000..6712e13896e
--- /dev/null
+++ b/invokeai/backend/model_manager/load/model_loaders/bria.py
@@ -0,0 +1,56 @@
+from pathlib import Path
+from typing import Optional
+
+from invokeai.backend.model_manager.config import (
+ AnyModelConfig,
+ CheckpointConfigBase,
+ DiffusersConfigBase,
+)
+from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry
+from invokeai.backend.model_manager.load.model_loaders.generic_diffusers import GenericDiffusersLoader
+from invokeai.backend.model_manager.taxonomy import (
+ AnyModel,
+ BaseModelType,
+ ModelFormat,
+ ModelType,
+ SubModelType,
+)
+
+
+@ModelLoaderRegistry.register(base=BaseModelType.Bria, type=ModelType.Main, format=ModelFormat.Diffusers)
+class BriaDiffusersModel(GenericDiffusersLoader):
+ """Class to load Bria main models."""
+
+ def _load_model(
+ self,
+ config: AnyModelConfig,
+ submodel_type: Optional[SubModelType] = None,
+ ) -> AnyModel:
+ if isinstance(config, CheckpointConfigBase):
+ raise NotImplementedError("CheckpointConfigBase is not implemented for Bria models.")
+
+ if submodel_type is None:
+ raise Exception("A submodel type must be provided when loading main pipelines.")
+
+ model_path = Path(config.path)
+ load_class = self.get_hf_load_class(model_path, submodel_type)
+ repo_variant = config.repo_variant if isinstance(config, DiffusersConfigBase) else None
+ variant = repo_variant.value if repo_variant else None
+ model_path = model_path / submodel_type.value
+
+ dtype = self._torch_dtype
+ try:
+ result: AnyModel = load_class.from_pretrained(
+ model_path,
+ torch_dtype=dtype,
+ variant=variant,
+ )
+ except OSError as e:
+ if variant and "no file named" in str(
+ e
+ ): # try without the variant, just in case user's preferences changed
+ result = load_class.from_pretrained(model_path, torch_dtype=dtype)
+ else:
+ raise e
+
+ return result
diff --git a/invokeai/backend/model_manager/taxonomy.py b/invokeai/backend/model_manager/taxonomy.py
index d77f5fc10ff..76c0ffe4f60 100644
--- a/invokeai/backend/model_manager/taxonomy.py
+++ b/invokeai/backend/model_manager/taxonomy.py
@@ -28,6 +28,7 @@ class BaseModelType(str, Enum):
CogView4 = "cogview4"
Imagen3 = "imagen3"
ChatGPT4o = "chatgpt-4o"
+ Bria = "bria"
class ModelType(str, Enum):
diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/InputFieldRenderer.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/InputFieldRenderer.tsx
index a4dc5414953..20fb19c1d31 100644
--- a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/InputFieldRenderer.tsx
+++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/InputFieldRenderer.tsx
@@ -23,6 +23,8 @@ import {
isBoardFieldInputTemplate,
isBooleanFieldInputInstance,
isBooleanFieldInputTemplate,
+ isBriaMainModelFieldInputInstance,
+ isBriaMainModelFieldInputTemplate,
isCLIPEmbedModelFieldInputInstance,
isCLIPEmbedModelFieldInputTemplate,
isCLIPGEmbedModelFieldInputInstance,
@@ -105,6 +107,7 @@ import { assert } from 'tsafe';
import BoardFieldInputComponent from './inputs/BoardFieldInputComponent';
import BooleanFieldInputComponent from './inputs/BooleanFieldInputComponent';
+import BriaMainModelFieldInputComponent from './inputs/BriaMainModelFieldInputComponent';
import CLIPEmbedModelFieldInputComponent from './inputs/CLIPEmbedModelFieldInputComponent';
import CLIPGEmbedModelFieldInputComponent from './inputs/CLIPGEmbedModelFieldInputComponent';
import CLIPLEmbedModelFieldInputComponent from './inputs/CLIPLEmbedModelFieldInputComponent';
@@ -408,6 +411,13 @@ export const InputFieldRenderer = memo(({ nodeId, fieldName, settings }: Props)
return ;
}
+ if (isBriaMainModelFieldInputTemplate(template)) {
+ if (!isBriaMainModelFieldInputInstance(field)) {
+ return null;
+ }
+ return ;
+ }
+
if (isSD3MainModelFieldInputTemplate(template)) {
if (!isSD3MainModelFieldInputInstance(field)) {
return null;
diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/BriaMainModelFieldInputComponent.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/BriaMainModelFieldInputComponent.tsx
new file mode 100644
index 00000000000..8d8af426a56
--- /dev/null
+++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/BriaMainModelFieldInputComponent.tsx
@@ -0,0 +1,44 @@
+import { useAppDispatch } from 'app/store/storeHooks';
+import { ModelFieldCombobox } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelFieldCombobox';
+import { fieldMainModelValueChanged } from 'features/nodes/store/nodesSlice';
+import type { BriaMainModelFieldInputInstance, BriaMainModelFieldInputTemplate } from 'features/nodes/types/field';
+import { memo, useCallback } from 'react';
+import { useBriaModels } from 'services/api/hooks/modelsByType';
+import type { MainModelConfig } from 'services/api/types';
+
+import type { FieldComponentProps } from './types';
+
+type Props = FieldComponentProps;
+
+const BriaMainModelFieldInputComponent = (props: Props) => {
+ const { nodeId, field } = props;
+ const dispatch = useAppDispatch();
+ const [modelConfigs, { isLoading }] = useBriaModels();
+ const onChange = useCallback(
+ (value: MainModelConfig | null) => {
+ if (!value) {
+ return;
+ }
+ dispatch(
+ fieldMainModelValueChanged({
+ nodeId,
+ fieldName: field.name,
+ value,
+ })
+ );
+ },
+ [dispatch, field.name, nodeId]
+ );
+
+ return (
+
+ );
+};
+
+export default memo(BriaMainModelFieldInputComponent);
diff --git a/invokeai/frontend/web/src/features/nodes/types/common.ts b/invokeai/frontend/web/src/features/nodes/types/common.ts
index b1b17de0726..b1098e41c2b 100644
--- a/invokeai/frontend/web/src/features/nodes/types/common.ts
+++ b/invokeai/frontend/web/src/features/nodes/types/common.ts
@@ -77,9 +77,10 @@ const zBaseModel = z.enum([
'cogview4',
'imagen3',
'chatgpt-4o',
+ 'bria',
]);
export type BaseModelType = z.infer;
-export const zMainModelBase = z.enum(['sd-1', 'sd-2', 'sd-3', 'sdxl', 'flux', 'cogview4', 'imagen3', 'chatgpt-4o']);
+export const zMainModelBase = z.enum(['sd-1', 'sd-2', 'sd-3', 'sdxl', 'flux', 'cogview4', 'imagen3', 'chatgpt-4o', 'bria']);
export type MainModelBase = z.infer;
export const isMainModelBase = (base: unknown): base is MainModelBase => zMainModelBase.safeParse(base).success;
const zModelType = z.enum([
diff --git a/invokeai/frontend/web/src/features/nodes/types/constants.ts b/invokeai/frontend/web/src/features/nodes/types/constants.ts
index a8ab6d231e3..0e6131e4882 100644
--- a/invokeai/frontend/web/src/features/nodes/types/constants.ts
+++ b/invokeai/frontend/web/src/features/nodes/types/constants.ts
@@ -52,6 +52,7 @@ export const FIELD_COLORS: { [key: string]: string } = {
LoRAModelField: 'teal.500',
MainModelField: 'teal.500',
FluxMainModelField: 'teal.500',
+ BriaMainModelField: 'teal.500',
SD3MainModelField: 'teal.500',
CogView4MainModelField: 'teal.500',
SDXLMainModelField: 'teal.500',
diff --git a/invokeai/frontend/web/src/features/nodes/types/field.ts b/invokeai/frontend/web/src/features/nodes/types/field.ts
index 8bc856f8bd7..0cc42e1bc0b 100644
--- a/invokeai/frontend/web/src/features/nodes/types/field.ts
+++ b/invokeai/frontend/web/src/features/nodes/types/field.ts
@@ -184,6 +184,10 @@ const zFluxMainModelFieldType = zFieldTypeBase.extend({
name: z.literal('FluxMainModelField'),
originalType: zStatelessFieldType.optional(),
});
+const zBriaMainModelFieldType = zFieldTypeBase.extend({
+ name: z.literal('BriaMainModelField'),
+ originalType: zStatelessFieldType.optional(),
+});
const zSDXLRefinerModelFieldType = zFieldTypeBase.extend({
name: z.literal('SDXLRefinerModelField'),
originalType: zStatelessFieldType.optional(),
@@ -304,6 +308,7 @@ const zStatefulFieldType = z.union([
zIntegerGeneratorFieldType,
zStringGeneratorFieldType,
zImageGeneratorFieldType,
+ zBriaMainModelFieldType,
]);
export type StatefulFieldType = z.infer;
const statefulFieldTypeNames = zStatefulFieldType.options.map((o) => o.shape.name.value);
@@ -320,6 +325,7 @@ const modelFieldTypeNames = [
zSD3MainModelFieldType.shape.name.value,
zCogView4MainModelFieldType.shape.name.value,
zFluxMainModelFieldType.shape.name.value,
+ zBriaMainModelFieldType.shape.name.value,
zSDXLRefinerModelFieldType.shape.name.value,
zVAEModelFieldType.shape.name.value,
zLoRAModelFieldType.shape.name.value,
@@ -863,6 +869,26 @@ export const isFluxMainModelFieldInputTemplate =
buildTemplateTypeGuard('FluxMainModelField');
// #endregion
+// #region BriaMainModelField
+const zBriaMainModelFieldValue = zMainModelFieldValue; // TODO: Narrow to SDXL models only.
+const zBriaMainModelFieldInputInstance = zFieldInputInstanceBase.extend({
+ value: zBriaMainModelFieldValue,
+});
+const zBriaMainModelFieldInputTemplate = zFieldInputTemplateBase.extend({
+ type: zBriaMainModelFieldType,
+ originalType: zFieldType.optional(),
+ default: zBriaMainModelFieldValue,
+});
+const zBriaMainModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({
+ type: zBriaMainModelFieldType,
+});
+export type BriaMainModelFieldInputInstance = z.infer;
+export type BriaMainModelFieldInputTemplate = z.infer;
+export const isBriaMainModelFieldInputInstance = buildInstanceTypeGuard(zBriaMainModelFieldInputInstance);
+export const isBriaMainModelFieldInputTemplate =
+ buildTemplateTypeGuard('BriaMainModelField');
+// #endregion
+
// #region SDXLRefinerModelField
/** @alias */ // tells knip to ignore this duplicate export
export const zSDXLRefinerModelFieldValue = zMainModelFieldValue; // TODO: Narrow to SDXL Refiner models only.
@@ -1790,6 +1816,7 @@ export const zStatefulFieldValue = z.union([
zMainModelFieldValue,
zSDXLMainModelFieldValue,
zFluxMainModelFieldValue,
+ zBriaMainModelFieldValue,
zSD3MainModelFieldValue,
zCogView4MainModelFieldValue,
zSDXLRefinerModelFieldValue,
@@ -1837,6 +1864,7 @@ const zStatefulFieldInputInstance = z.union([
zModelIdentifierFieldInputInstance,
zMainModelFieldInputInstance,
zFluxMainModelFieldInputInstance,
+ zBriaMainModelFieldInputInstance,
zSD3MainModelFieldInputInstance,
zCogView4MainModelFieldInputInstance,
zSDXLMainModelFieldInputInstance,
@@ -1879,6 +1907,7 @@ const zStatefulFieldInputTemplate = z.union([
zModelIdentifierFieldInputTemplate,
zMainModelFieldInputTemplate,
zFluxMainModelFieldInputTemplate,
+ zBriaMainModelFieldInputTemplate,
zSD3MainModelFieldInputTemplate,
zCogView4MainModelFieldInputTemplate,
zSDXLMainModelFieldInputTemplate,
@@ -1927,6 +1956,7 @@ const zStatefulFieldOutputTemplate = z.union([
zModelIdentifierFieldOutputTemplate,
zMainModelFieldOutputTemplate,
zFluxMainModelFieldOutputTemplate,
+ zBriaMainModelFieldOutputTemplate,
zSD3MainModelFieldOutputTemplate,
zCogView4MainModelFieldOutputTemplate,
zSDXLMainModelFieldOutputTemplate,
diff --git a/invokeai/frontend/web/src/features/nodes/util/schema/buildFieldInputInstance.ts b/invokeai/frontend/web/src/features/nodes/util/schema/buildFieldInputInstance.ts
index edae5a78f13..9ff5f537de6 100644
--- a/invokeai/frontend/web/src/features/nodes/util/schema/buildFieldInputInstance.ts
+++ b/invokeai/frontend/web/src/features/nodes/util/schema/buildFieldInputInstance.ts
@@ -17,6 +17,7 @@ const FIELD_VALUE_FALLBACK_MAP: Record =
SchedulerField: 'dpmpp_3m_k',
SDXLMainModelField: undefined,
FluxMainModelField: undefined,
+ BriaMainModelField: undefined,
SD3MainModelField: undefined,
CogView4MainModelField: undefined,
SDXLRefinerModelField: undefined,
diff --git a/invokeai/frontend/web/src/features/nodes/util/schema/buildFieldInputTemplate.ts b/invokeai/frontend/web/src/features/nodes/util/schema/buildFieldInputTemplate.ts
index b458ea180cc..cda0effa945 100644
--- a/invokeai/frontend/web/src/features/nodes/util/schema/buildFieldInputTemplate.ts
+++ b/invokeai/frontend/web/src/features/nodes/util/schema/buildFieldInputTemplate.ts
@@ -2,6 +2,7 @@ import { FieldParseError } from 'features/nodes/types/error';
import type {
BoardFieldInputTemplate,
BooleanFieldInputTemplate,
+ BriaMainModelFieldInputTemplate,
CLIPEmbedModelFieldInputTemplate,
CLIPGEmbedModelFieldInputTemplate,
CLIPLEmbedModelFieldInputTemplate,
@@ -338,6 +339,20 @@ const buildFluxMainModelFieldInputTemplate: FieldInputTemplateBuilder = ({
+ schemaObject,
+ baseField,
+ fieldType,
+}) => {
+ const template: BriaMainModelFieldInputTemplate = {
+ ...baseField,
+ type: fieldType,
+ default: schemaObject.default ?? undefined,
+ };
+
+ return template;
+};
+
const buildSD3MainModelFieldInputTemplate: FieldInputTemplateBuilder = ({
schemaObject,
baseField,
@@ -778,6 +793,7 @@ export const TEMPLATE_BUILDER_MAP: Record {
+ return config.type === 'main' && config.base === 'bria';
+};
+
export const isFluxFillMainModelModelConfig = (config: AnyModelConfig): config is MainModelConfig => {
return config.type === 'main' && config.base === 'flux' && config.variant === 'inpaint';
};