Skip to content

Setup Probe and UI to accept bria main models #7973

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions invokeai/app/invocations/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ class UIType(str, Enum, metaclass=MetaEnum):
MainModel = "MainModelField"
CogView4MainModel = "CogView4MainModelField"
FluxMainModel = "FluxMainModelField"
BriaMainModel = "BriaMainModelField"
SD3MainModel = "SD3MainModelField"
SDXLMainModel = "SDXLMainModelField"
SDXLRefinerModel = "SDXLRefinerModelField"
Expand Down
3 changes: 3 additions & 0 deletions invokeai/backend/model_manager/legacy_probe.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,7 @@ class ModelProbe(object):
}

CLASS2TYPE = {
"BriaPipeline": ModelType.Main,
"FluxPipeline": ModelType.Main,
"StableDiffusionPipeline": ModelType.Main,
"StableDiffusionInpaintPipeline": ModelType.Main,
Expand Down Expand Up @@ -861,6 +862,8 @@ def get_base_type(self) -> BaseModelType:
return BaseModelType.StableDiffusion3
elif transformer_conf["_class_name"] == "CogView4Transformer2DModel":
return BaseModelType.CogView4
elif transformer_conf["_class_name"] == "BriaTransformer2DModel":
return BaseModelType.Bria
else:
raise InvalidModelConfigException(f"Unknown base model for {self.model_path}")

Expand Down
56 changes: 56 additions & 0 deletions invokeai/backend/model_manager/load/model_loaders/bria.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
from pathlib import Path
from typing import Optional

from invokeai.backend.model_manager.config import (
AnyModelConfig,
CheckpointConfigBase,
DiffusersConfigBase,
)
from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry
from invokeai.backend.model_manager.load.model_loaders.generic_diffusers import GenericDiffusersLoader
from invokeai.backend.model_manager.taxonomy import (
AnyModel,
BaseModelType,
ModelFormat,
ModelType,
SubModelType,
)


@ModelLoaderRegistry.register(base=BaseModelType.Bria, type=ModelType.Main, format=ModelFormat.Diffusers)
class BriaDiffusersModel(GenericDiffusersLoader):
"""Class to load Bria main models."""

def _load_model(
self,
config: AnyModelConfig,
submodel_type: Optional[SubModelType] = None,
) -> AnyModel:
if isinstance(config, CheckpointConfigBase):
raise NotImplementedError("CheckpointConfigBase is not implemented for Bria models.")

if submodel_type is None:
raise Exception("A submodel type must be provided when loading main pipelines.")

model_path = Path(config.path)
load_class = self.get_hf_load_class(model_path, submodel_type)
repo_variant = config.repo_variant if isinstance(config, DiffusersConfigBase) else None
variant = repo_variant.value if repo_variant else None
model_path = model_path / submodel_type.value

dtype = self._torch_dtype
try:
result: AnyModel = load_class.from_pretrained(
model_path,
torch_dtype=dtype,
variant=variant,
)
except OSError as e:
if variant and "no file named" in str(
e
): # try without the variant, just in case user's preferences changed
result = load_class.from_pretrained(model_path, torch_dtype=dtype)
else:
raise e

return result
1 change: 1 addition & 0 deletions invokeai/backend/model_manager/taxonomy.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ class BaseModelType(str, Enum):
CogView4 = "cogview4"
Imagen3 = "imagen3"
ChatGPT4o = "chatgpt-4o"
Bria = "bria"


class ModelType(str, Enum):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ import {
isBoardFieldInputTemplate,
isBooleanFieldInputInstance,
isBooleanFieldInputTemplate,
isBriaMainModelFieldInputInstance,
isBriaMainModelFieldInputTemplate,
isCLIPEmbedModelFieldInputInstance,
isCLIPEmbedModelFieldInputTemplate,
isCLIPGEmbedModelFieldInputInstance,
Expand Down Expand Up @@ -105,6 +107,7 @@ import { assert } from 'tsafe';

import BoardFieldInputComponent from './inputs/BoardFieldInputComponent';
import BooleanFieldInputComponent from './inputs/BooleanFieldInputComponent';
import BriaMainModelFieldInputComponent from './inputs/BriaMainModelFieldInputComponent';
import CLIPEmbedModelFieldInputComponent from './inputs/CLIPEmbedModelFieldInputComponent';
import CLIPGEmbedModelFieldInputComponent from './inputs/CLIPGEmbedModelFieldInputComponent';
import CLIPLEmbedModelFieldInputComponent from './inputs/CLIPLEmbedModelFieldInputComponent';
Expand Down Expand Up @@ -408,6 +411,13 @@ export const InputFieldRenderer = memo(({ nodeId, fieldName, settings }: Props)
return <FluxMainModelFieldInputComponent nodeId={nodeId} field={field} fieldTemplate={template} />;
}

if (isBriaMainModelFieldInputTemplate(template)) {
if (!isBriaMainModelFieldInputInstance(field)) {
return null;
}
return <BriaMainModelFieldInputComponent nodeId={nodeId} field={field} fieldTemplate={template} />;
}

if (isSD3MainModelFieldInputTemplate(template)) {
if (!isSD3MainModelFieldInputInstance(field)) {
return null;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
import { useAppDispatch } from 'app/store/storeHooks';
import { ModelFieldCombobox } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelFieldCombobox';
import { fieldMainModelValueChanged } from 'features/nodes/store/nodesSlice';
import type { BriaMainModelFieldInputInstance, BriaMainModelFieldInputTemplate } from 'features/nodes/types/field';
import { memo, useCallback } from 'react';
import { useBriaModels } from 'services/api/hooks/modelsByType';
import type { MainModelConfig } from 'services/api/types';

import type { FieldComponentProps } from './types';

type Props = FieldComponentProps<BriaMainModelFieldInputInstance, BriaMainModelFieldInputTemplate>;

const BriaMainModelFieldInputComponent = (props: Props) => {
const { nodeId, field } = props;
const dispatch = useAppDispatch();
const [modelConfigs, { isLoading }] = useBriaModels();
const onChange = useCallback(
(value: MainModelConfig | null) => {
if (!value) {
return;
}
dispatch(
fieldMainModelValueChanged({
nodeId,
fieldName: field.name,
value,
})
);
},
[dispatch, field.name, nodeId]
);

return (
<ModelFieldCombobox
value={field.value}
modelConfigs={modelConfigs}
isLoadingConfigs={isLoading}
onChange={onChange}
required={props.fieldTemplate.required}
/>
);
};

export default memo(BriaMainModelFieldInputComponent);
3 changes: 2 additions & 1 deletion invokeai/frontend/web/src/features/nodes/types/common.ts
Original file line number Diff line number Diff line change
Expand Up @@ -77,9 +77,10 @@ const zBaseModel = z.enum([
'cogview4',
'imagen3',
'chatgpt-4o',
'bria',
]);
export type BaseModelType = z.infer<typeof zBaseModel>;
export const zMainModelBase = z.enum(['sd-1', 'sd-2', 'sd-3', 'sdxl', 'flux', 'cogview4', 'imagen3', 'chatgpt-4o']);
export const zMainModelBase = z.enum(['sd-1', 'sd-2', 'sd-3', 'sdxl', 'flux', 'cogview4', 'imagen3', 'chatgpt-4o', 'bria']);
export type MainModelBase = z.infer<typeof zMainModelBase>;
export const isMainModelBase = (base: unknown): base is MainModelBase => zMainModelBase.safeParse(base).success;
const zModelType = z.enum([
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ export const FIELD_COLORS: { [key: string]: string } = {
LoRAModelField: 'teal.500',
MainModelField: 'teal.500',
FluxMainModelField: 'teal.500',
BriaMainModelField: 'teal.500',
SD3MainModelField: 'teal.500',
CogView4MainModelField: 'teal.500',
SDXLMainModelField: 'teal.500',
Expand Down
30 changes: 30 additions & 0 deletions invokeai/frontend/web/src/features/nodes/types/field.ts
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,10 @@ const zFluxMainModelFieldType = zFieldTypeBase.extend({
name: z.literal('FluxMainModelField'),
originalType: zStatelessFieldType.optional(),
});
const zBriaMainModelFieldType = zFieldTypeBase.extend({
name: z.literal('BriaMainModelField'),
originalType: zStatelessFieldType.optional(),
});
const zSDXLRefinerModelFieldType = zFieldTypeBase.extend({
name: z.literal('SDXLRefinerModelField'),
originalType: zStatelessFieldType.optional(),
Expand Down Expand Up @@ -304,6 +308,7 @@ const zStatefulFieldType = z.union([
zIntegerGeneratorFieldType,
zStringGeneratorFieldType,
zImageGeneratorFieldType,
zBriaMainModelFieldType,
]);
export type StatefulFieldType = z.infer<typeof zStatefulFieldType>;
const statefulFieldTypeNames = zStatefulFieldType.options.map((o) => o.shape.name.value);
Expand All @@ -320,6 +325,7 @@ const modelFieldTypeNames = [
zSD3MainModelFieldType.shape.name.value,
zCogView4MainModelFieldType.shape.name.value,
zFluxMainModelFieldType.shape.name.value,
zBriaMainModelFieldType.shape.name.value,
zSDXLRefinerModelFieldType.shape.name.value,
zVAEModelFieldType.shape.name.value,
zLoRAModelFieldType.shape.name.value,
Expand Down Expand Up @@ -863,6 +869,26 @@ export const isFluxMainModelFieldInputTemplate =
buildTemplateTypeGuard<FluxMainModelFieldInputTemplate>('FluxMainModelField');
// #endregion

// #region BriaMainModelField
const zBriaMainModelFieldValue = zMainModelFieldValue; // TODO: Narrow to SDXL models only.
const zBriaMainModelFieldInputInstance = zFieldInputInstanceBase.extend({
value: zBriaMainModelFieldValue,
});
const zBriaMainModelFieldInputTemplate = zFieldInputTemplateBase.extend({
type: zBriaMainModelFieldType,
originalType: zFieldType.optional(),
default: zBriaMainModelFieldValue,
});
const zBriaMainModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({
type: zBriaMainModelFieldType,
});
export type BriaMainModelFieldInputInstance = z.infer<typeof zBriaMainModelFieldInputInstance>;
export type BriaMainModelFieldInputTemplate = z.infer<typeof zBriaMainModelFieldInputTemplate>;
export const isBriaMainModelFieldInputInstance = buildInstanceTypeGuard(zBriaMainModelFieldInputInstance);
export const isBriaMainModelFieldInputTemplate =
buildTemplateTypeGuard<BriaMainModelFieldInputTemplate>('BriaMainModelField');
// #endregion

// #region SDXLRefinerModelField
/** @alias */ // tells knip to ignore this duplicate export
export const zSDXLRefinerModelFieldValue = zMainModelFieldValue; // TODO: Narrow to SDXL Refiner models only.
Expand Down Expand Up @@ -1790,6 +1816,7 @@ export const zStatefulFieldValue = z.union([
zMainModelFieldValue,
zSDXLMainModelFieldValue,
zFluxMainModelFieldValue,
zBriaMainModelFieldValue,
zSD3MainModelFieldValue,
zCogView4MainModelFieldValue,
zSDXLRefinerModelFieldValue,
Expand Down Expand Up @@ -1837,6 +1864,7 @@ const zStatefulFieldInputInstance = z.union([
zModelIdentifierFieldInputInstance,
zMainModelFieldInputInstance,
zFluxMainModelFieldInputInstance,
zBriaMainModelFieldInputInstance,
zSD3MainModelFieldInputInstance,
zCogView4MainModelFieldInputInstance,
zSDXLMainModelFieldInputInstance,
Expand Down Expand Up @@ -1879,6 +1907,7 @@ const zStatefulFieldInputTemplate = z.union([
zModelIdentifierFieldInputTemplate,
zMainModelFieldInputTemplate,
zFluxMainModelFieldInputTemplate,
zBriaMainModelFieldInputTemplate,
zSD3MainModelFieldInputTemplate,
zCogView4MainModelFieldInputTemplate,
zSDXLMainModelFieldInputTemplate,
Expand Down Expand Up @@ -1927,6 +1956,7 @@ const zStatefulFieldOutputTemplate = z.union([
zModelIdentifierFieldOutputTemplate,
zMainModelFieldOutputTemplate,
zFluxMainModelFieldOutputTemplate,
zBriaMainModelFieldOutputTemplate,
zSD3MainModelFieldOutputTemplate,
zCogView4MainModelFieldOutputTemplate,
zSDXLMainModelFieldOutputTemplate,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ const FIELD_VALUE_FALLBACK_MAP: Record<StatefulFieldType['name'], FieldValue> =
SchedulerField: 'dpmpp_3m_k',
SDXLMainModelField: undefined,
FluxMainModelField: undefined,
BriaMainModelField: undefined,
SD3MainModelField: undefined,
CogView4MainModelField: undefined,
SDXLRefinerModelField: undefined,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ import { FieldParseError } from 'features/nodes/types/error';
import type {
BoardFieldInputTemplate,
BooleanFieldInputTemplate,
BriaMainModelFieldInputTemplate,
CLIPEmbedModelFieldInputTemplate,
CLIPGEmbedModelFieldInputTemplate,
CLIPLEmbedModelFieldInputTemplate,
Expand Down Expand Up @@ -338,6 +339,20 @@ const buildFluxMainModelFieldInputTemplate: FieldInputTemplateBuilder<FluxMainMo
return template;
};

const buildBriaMainModelFieldInputTemplate: FieldInputTemplateBuilder<BriaMainModelFieldInputTemplate> = ({
schemaObject,
baseField,
fieldType,
}) => {
const template: BriaMainModelFieldInputTemplate = {
...baseField,
type: fieldType,
default: schemaObject.default ?? undefined,
};

return template;
};

const buildSD3MainModelFieldInputTemplate: FieldInputTemplateBuilder<SD3MainModelFieldInputTemplate> = ({
schemaObject,
baseField,
Expand Down Expand Up @@ -778,6 +793,7 @@ export const TEMPLATE_BUILDER_MAP: Record<StatefulFieldType['name'], FieldInputT
SD3MainModelField: buildSD3MainModelFieldInputTemplate,
CogView4MainModelField: buildCogView4MainModelFieldInputTemplate,
FluxMainModelField: buildFluxMainModelFieldInputTemplate,
BriaMainModelField: buildBriaMainModelFieldInputTemplate,
SDXLRefinerModelField: buildRefinerModelFieldInputTemplate,
StringField: buildStringFieldInputTemplate,
T2IAdapterModelField: buildT2IAdapterModelFieldInputTemplate,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import {
} from 'services/api/endpoints/models';
import type { AnyModelConfig } from 'services/api/types';
import {
isBriaMainModelModelConfig,
isCLIPEmbedModelConfig,
isCLIPVisionModelConfig,
isCogView4MainModelModelConfig,
Expand Down Expand Up @@ -60,6 +61,7 @@ export const useMainModels = buildModelsHook(isNonRefinerMainModelConfig);
export const useNonSDXLMainModels = buildModelsHook(isNonSDXLMainModelConfig);
export const useRefinerModels = buildModelsHook(isRefinerMainModelModelConfig);
export const useFluxModels = buildModelsHook(isFluxMainModelModelConfig);
export const useBriaModels = buildModelsHook(isBriaMainModelModelConfig);
export const useSD3Models = buildModelsHook(isSD3MainModelModelConfig);
export const useCogView4Models = buildModelsHook(isCogView4MainModelModelConfig);
export const useSDXLModels = buildModelsHook(isSDXLMainModelModelConfig);
Expand Down
4 changes: 2 additions & 2 deletions invokeai/frontend/web/src/services/api/schema.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2062,7 +2062,7 @@ export type components = {
* @description Base model type.
* @enum {string}
*/
BaseModelType: "any" | "sd-1" | "sd-2" | "sd-3" | "sdxl" | "sdxl-refiner" | "flux" | "cogview4" | "imagen3" | "chatgpt-4o";
BaseModelType: "any" | "sd-1" | "sd-2" | "sd-3" | "sdxl" | "sdxl-refiner" | "flux" | "cogview4" | "imagen3" | "chatgpt-4o" | "bria";
/** Batch */
Batch: {
/**
Expand Down Expand Up @@ -21258,7 +21258,7 @@ export type components = {
* used, and the type will be ignored. They are included here for backwards compatibility.
* @enum {string}
*/
UIType: "MainModelField" | "CogView4MainModelField" | "FluxMainModelField" | "SD3MainModelField" | "SDXLMainModelField" | "SDXLRefinerModelField" | "ONNXModelField" | "VAEModelField" | "FluxVAEModelField" | "LoRAModelField" | "ControlNetModelField" | "IPAdapterModelField" | "T2IAdapterModelField" | "T5EncoderModelField" | "CLIPEmbedModelField" | "CLIPLEmbedModelField" | "CLIPGEmbedModelField" | "SpandrelImageToImageModelField" | "ControlLoRAModelField" | "SigLipModelField" | "FluxReduxModelField" | "LLaVAModelField" | "SchedulerField" | "AnyField" | "CollectionField" | "CollectionItemField" | "DEPRECATED_Boolean" | "DEPRECATED_Color" | "DEPRECATED_Conditioning" | "DEPRECATED_Control" | "DEPRECATED_Float" | "DEPRECATED_Image" | "DEPRECATED_Integer" | "DEPRECATED_Latents" | "DEPRECATED_String" | "DEPRECATED_BooleanCollection" | "DEPRECATED_ColorCollection" | "DEPRECATED_ConditioningCollection" | "DEPRECATED_ControlCollection" | "DEPRECATED_FloatCollection" | "DEPRECATED_ImageCollection" | "DEPRECATED_IntegerCollection" | "DEPRECATED_LatentsCollection" | "DEPRECATED_StringCollection" | "DEPRECATED_BooleanPolymorphic" | "DEPRECATED_ColorPolymorphic" | "DEPRECATED_ConditioningPolymorphic" | "DEPRECATED_ControlPolymorphic" | "DEPRECATED_FloatPolymorphic" | "DEPRECATED_ImagePolymorphic" | "DEPRECATED_IntegerPolymorphic" | "DEPRECATED_LatentsPolymorphic" | "DEPRECATED_StringPolymorphic" | "DEPRECATED_UNet" | "DEPRECATED_Vae" | "DEPRECATED_CLIP" | "DEPRECATED_Collection" | "DEPRECATED_CollectionItem" | "DEPRECATED_Enum" | "DEPRECATED_WorkflowField" | "DEPRECATED_IsIntermediate" | "DEPRECATED_BoardField" | "DEPRECATED_MetadataItem" | "DEPRECATED_MetadataItemCollection" | "DEPRECATED_MetadataItemPolymorphic" | "DEPRECATED_MetadataDict";
UIType: "MainModelField" | "CogView4MainModelField" | "FluxMainModelField" | "BriaMainModelField" | "SD3MainModelField" | "SDXLMainModelField" | "SDXLRefinerModelField" | "ONNXModelField" | "VAEModelField" | "FluxVAEModelField" | "LoRAModelField" | "ControlNetModelField" | "IPAdapterModelField" | "T2IAdapterModelField" | "T5EncoderModelField" | "CLIPEmbedModelField" | "CLIPLEmbedModelField" | "CLIPGEmbedModelField" | "SpandrelImageToImageModelField" | "ControlLoRAModelField" | "SigLipModelField" | "FluxReduxModelField" | "LLaVAModelField" | "SchedulerField" | "AnyField" | "CollectionField" | "CollectionItemField" | "DEPRECATED_Boolean" | "DEPRECATED_Color" | "DEPRECATED_Conditioning" | "DEPRECATED_Control" | "DEPRECATED_Float" | "DEPRECATED_Image" | "DEPRECATED_Integer" | "DEPRECATED_Latents" | "DEPRECATED_String" | "DEPRECATED_BooleanCollection" | "DEPRECATED_ColorCollection" | "DEPRECATED_ConditioningCollection" | "DEPRECATED_ControlCollection" | "DEPRECATED_FloatCollection" | "DEPRECATED_ImageCollection" | "DEPRECATED_IntegerCollection" | "DEPRECATED_LatentsCollection" | "DEPRECATED_StringCollection" | "DEPRECATED_BooleanPolymorphic" | "DEPRECATED_ColorPolymorphic" | "DEPRECATED_ConditioningPolymorphic" | "DEPRECATED_ControlPolymorphic" | "DEPRECATED_FloatPolymorphic" | "DEPRECATED_ImagePolymorphic" | "DEPRECATED_IntegerPolymorphic" | "DEPRECATED_LatentsPolymorphic" | "DEPRECATED_StringPolymorphic" | "DEPRECATED_UNet" | "DEPRECATED_Vae" | "DEPRECATED_CLIP" | "DEPRECATED_Collection" | "DEPRECATED_CollectionItem" | "DEPRECATED_Enum" | "DEPRECATED_WorkflowField" | "DEPRECATED_IsIntermediate" | "DEPRECATED_BoardField" | "DEPRECATED_MetadataItem" | "DEPRECATED_MetadataItemCollection" | "DEPRECATED_MetadataItemPolymorphic" | "DEPRECATED_MetadataDict";
/** UNetField */
UNetField: {
/** @description Info to load unet submodel */
Expand Down
4 changes: 4 additions & 0 deletions invokeai/frontend/web/src/services/api/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,10 @@ export const isFluxMainModelModelConfig = (config: AnyModelConfig): config is Ma
return config.type === 'main' && config.base === 'flux';
};

export const isBriaMainModelModelConfig = (config: AnyModelConfig): config is MainModelConfig => {
return config.type === 'main' && config.base === 'bria';
};

export const isFluxFillMainModelModelConfig = (config: AnyModelConfig): config is MainModelConfig => {
return config.type === 'main' && config.base === 'flux' && config.variant === 'inpaint';
};
Expand Down
Loading