diff --git a/changes/3647.feature.md b/changes/3647.feature.md new file mode 100644 index 00000000000..734aae3b05e --- /dev/null +++ b/changes/3647.feature.md @@ -0,0 +1 @@ +Implement `Image` status filtering logics. (e.g. adding an optional argument to the `Image`, `ImageNode` GQL resolvers to enable querying deleted images as well.) diff --git a/docs/manager/graphql-reference/schema.graphql b/docs/manager/graphql-reference/schema.graphql index 37ff52ea46e..22ed5411e02 100644 --- a/docs/manager/graphql-reference/schema.graphql +++ b/docs/manager/graphql-reference/schema.graphql @@ -92,6 +92,9 @@ type Queries { is_installed: Boolean is_operation: Boolean @deprecated(reason: "Deprecated since 24.03.4. This field is ignored if `load_filters` is specified and is not null.") + """Added in 25.4.0.""" + filter_by_statuses: [ImageStatus] = [ALIVE] + """ Added in 24.03.8. Allowed values are: [general, operational, customized]. When superuser queries with `customized` option set the resolver will return every customized images (including those not owned by callee). To resolve images owned by user only call `customized_images`. """ @@ -121,6 +124,9 @@ type Queries { """Default is read_attribute.""" permission: ImagePermissionValueField = "read_attribute" + + """Added in 25.4.0.""" + filter_by_statuses: [ImageStatus] = [ALIVE] filter: String order: String offset: Int @@ -905,6 +911,12 @@ type Image { hash: String } +"""Added in 25.4.0.""" +enum ImageStatus { + ALIVE + DELETED +} + """Added in 25.3.0.""" type ImageConnection { """Pagination data for this connection.""" diff --git a/src/ai/backend/manager/api/session.py b/src/ai/backend/manager/api/session.py index 80b31ffbaaa..469a8b78fcd 100644 --- a/src/ai/backend/manager/api/session.py +++ b/src/ai/backend/manager/api/session.py @@ -51,7 +51,7 @@ from ai.backend.common.bgtask import ProgressReporter from ai.backend.common.docker import ImageRef from ai.backend.manager.models.group import GroupRow -from ai.backend.manager.models.image import ImageIdentifier, rescan_images +from ai.backend.manager.models.image import ImageIdentifier, ImageStatus, rescan_images if TYPE_CHECKING: from sqlalchemy.ext.asyncio import AsyncConnection as SAConnection @@ -1261,6 +1261,7 @@ async def _commit_and_upload(reporter: ProgressReporter) -> None: == f"{params.image_visibility.value}:{image_owner_id}" ) ) + .where(ImageRow.status == ImageStatus.ALIVE) ) existing_image_count = await sess.scalar(query) @@ -1278,14 +1279,13 @@ async def _commit_and_upload(reporter: ProgressReporter) -> None: # check if image with same name exists and reuse ID it if is query = sa.select(ImageRow).where( - ImageRow.name.like(f"{new_canonical}%") - & ( + sa.and_( + ImageRow.name.like(f"{new_canonical}%"), ImageRow.labels["ai.backend.customized-image.owner"].as_string() - == f"{params.image_visibility.value}:{image_owner_id}" - ) - & ( + == f"{params.image_visibility.value}:{image_owner_id}", ImageRow.labels["ai.backend.customized-image.name"].as_string() - == params.image_name + == params.image_name, + ImageRow.status == ImageStatus.ALIVE, ) ) existing_row = await sess.scalar(query) diff --git a/src/ai/backend/manager/cli/image_impl.py b/src/ai/backend/manager/cli/image_impl.py index 553e93b50d8..29db7ef8633 100644 --- a/src/ai/backend/manager/cli/image_impl.py +++ b/src/ai/backend/manager/cli/image_impl.py @@ -16,7 +16,7 @@ from ai.backend.common.types import ImageAlias from ai.backend.logging import BraceStyleAdapter -from ..models.image import ImageAliasRow, ImageIdentifier, ImageRow +from ..models.image import ImageAliasRow, ImageIdentifier, ImageRow, ImageStatus from ..models.image import rescan_images as rescan_images_func from ..models.utils import connect_database from .context import CLIContext, redis_ctx @@ -33,6 +33,7 @@ async def list_images(cli_ctx, short, installed_only): ): displayed_items = [] try: + # Idea: Add `--include-deleted` option to include deleted images? items = await ImageRow.list(session) # NOTE: installed/installed_agents fields are no longer provided in CLI, # until we finish the epic refactoring of image metadata db. @@ -137,7 +138,7 @@ async def purge_image(cli_ctx, canonical_or_alias, architecture): ImageIdentifier(canonical_or_alias, architecture), ImageAlias(canonical_or_alias), ], - load_only_active=False, + filter_by_statuses=None, ) await session.delete(image_row) except UnknownImageReference: @@ -238,6 +239,15 @@ async def validate_image_alias(cli_ctx, alias: str) -> None: log.exception(f"An error occurred. Error: {e}") +def _resolve_architecture(current: bool, architecture: Optional[str]) -> str: + if architecture is not None: + return architecture + if current: + return CURRENT_ARCH + + raise ValueError("Unreachable code!") + + async def validate_image_canonical( cli_ctx, canonical: str, current: bool, architecture: Optional[str] = None ) -> None: @@ -247,22 +257,23 @@ async def validate_image_canonical( ): try: if current or architecture is not None: - if current: - architecture = architecture or CURRENT_ARCH - image_row = await session.scalar( - sa.select(ImageRow).where( - (ImageRow.name == canonical) & (ImageRow.architecture == architecture) - ) + resolved_arch = _resolve_architecture(current, architecture) + image_row = await ImageRow.resolve( + session, [ImageIdentifier(canonical, resolved_arch)] ) - if image_row is None: - raise UnknownImageReference(f"{canonical}/{architecture}") + + print(f"{'architecture':<40}: {resolved_arch}") for key, value in validate_image_labels(image_row.labels).items(): print(f"{key:<40}: ", end="") if isinstance(value, list): value = f"{', '.join(value)}" print(value) else: - rows = await session.scalars(sa.select(ImageRow).where(ImageRow.name == canonical)) + rows = await session.scalars( + sa.select(ImageRow).where( + sa.and_(ImageRow.name == canonical, ImageRow.status == ImageStatus.ALIVE) + ) + ) image_rows = rows.fetchall() if not image_rows: raise UnknownImageReference(f"{canonical}") diff --git a/src/ai/backend/manager/container_registry/base.py b/src/ai/backend/manager/container_registry/base.py index b2b1bbef626..993d6905115 100644 --- a/src/ai/backend/manager/container_registry/base.py +++ b/src/ai/backend/manager/container_registry/base.py @@ -41,7 +41,7 @@ from ai.backend.logging import BraceStyleAdapter from ..defs import INTRINSIC_SLOTS_MIN -from ..models.image import ImageIdentifier, ImageRow, ImageType +from ..models.image import ImageIdentifier, ImageRow, ImageStatus, ImageType from ..models.utils import ExtendedAsyncSAEngine log = BraceStyleAdapter(logging.getLogger(__spec__.name)) @@ -153,7 +153,7 @@ async def commit_rescan_result(self) -> list[ImageRow]: existing_images = await session.scalars( sa.select(ImageRow).where( sa.func.ROW(ImageRow.name, ImageRow.architecture).in_(image_identifiers), - ), + ) ) is_local = self.registry_name == "local" @@ -169,6 +169,15 @@ async def commit_rescan_result(self) -> list[ImageRow]: image_row.is_local = is_local scanned_images.append(image_row) + if image_row.status == ImageStatus.DELETED: + image_row.status = ImageStatus.ALIVE + + progress_msg = f"Restored deleted image - {image_ref.canonical}/{image_ref.architecture} ({update['config_digest']})" + log.info(progress_msg) + + if (reporter := progress_reporter.get()) is not None: + await reporter.update(1, message=progress_msg) + for image_identifier, update in _all_updates.items(): try: parsed_img = ImageRef.from_image_str( @@ -200,6 +209,7 @@ async def commit_rescan_result(self) -> list[ImageRow]: accelerators=update.get("accels"), labels=update["labels"], resources=update["resources"], + status=ImageStatus.ALIVE, ) session.add(image_row) scanned_images.append(image_row) diff --git a/src/ai/backend/manager/container_registry/local.py b/src/ai/backend/manager/container_registry/local.py index 0507013af96..ed95394412d 100644 --- a/src/ai/backend/manager/container_registry/local.py +++ b/src/ai/backend/manager/container_registry/local.py @@ -12,7 +12,7 @@ from ai.backend.common.docker import arch_name_aliases, get_docker_connector from ai.backend.logging import BraceStyleAdapter -from ..models.image import ImageRow +from ..models.image import ImageRow, ImageStatus from .base import ( BaseContainerRegistry, concurrency_sema, @@ -83,9 +83,12 @@ async def _read_image_info( async with self.db.begin_readonly_session() as db_session: already_exists = await db_session.scalar( sa.select([sa.func.count(ImageRow.id)]).where( - ImageRow.config_digest == config_digest, - ImageRow.is_local == sa.false(), - ) + sa.and_( + ImageRow.config_digest == config_digest, + ImageRow.is_local == sa.false(), + ImageRow.status == ImageStatus.ALIVE, + ) + ), ) if already_exists > 0: return {}, "already synchronized from a remote registry" diff --git a/src/ai/backend/manager/models/gql.py b/src/ai/backend/manager/models/gql.py index 6878ed31d72..cb28639efd3 100644 --- a/src/ai/backend/manager/models/gql.py +++ b/src/ai/backend/manager/models/gql.py @@ -117,6 +117,7 @@ ImageConnection, ImageNode, ImagePermissionValueField, + ImageStatusType, ModifyImage, PreloadImage, PurgeImageById, @@ -149,6 +150,7 @@ ) from .image import ( ImageLoadFilter, + ImageStatus, PublicImageLoadFilter, ) from .kernel import ( @@ -581,6 +583,11 @@ class Queries(graphene.ObjectType): is_operation=graphene.Boolean( deprecation_reason="Deprecated since 24.03.4. This field is ignored if `load_filters` is specified and is not null." ), + filter_by_statuses=graphene.List( + ImageStatusType, + default_value=[ImageStatus.ALIVE], + description="Added in 25.4.0.", + ), load_filters=graphene.List( graphene.String, default_value=None, @@ -614,6 +621,11 @@ class Queries(graphene.ObjectType): default_value=ImagePermission.READ_ATTRIBUTE, description=f"Default is {ImagePermission.READ_ATTRIBUTE.value}.", ), + filter_by_statuses=graphene.List( + ImageStatusType, + default_value=[ImageStatus.ALIVE], + description="Added in 25.4.0.", + ), ) user = graphene.Field( @@ -1473,13 +1485,15 @@ async def resolve_image( client_role = ctx.user["role"] client_domain = ctx.user["domain_name"] if id: - item = await Image.load_item_by_id(info.context, uuid.UUID(id)) + item = await Image.load_item_by_id(info.context, uuid.UUID(id), filter_by_statuses=None) else: if not (reference and architecture): raise InvalidAPIParameters( "reference/architecture and id can't be omitted at the same time!" ) - item = await Image.load_item(info.context, reference, architecture) + item = await Image.load_item( + info.context, reference, architecture, filter_by_statuses=None + ) if client_role == UserRole.SUPERADMIN: pass elif client_role in (UserRole.ADMIN, UserRole.USER): @@ -1528,6 +1542,7 @@ async def resolve_images( *, is_installed: bool | None = None, is_operation=False, + filter_by_statuses: Optional[list[ImageStatus]] = [ImageStatus.ALIVE], load_filters: list[str] | None = None, image_filters: list[str] | None = None, ) -> Sequence[Image]: @@ -1559,7 +1574,9 @@ async def resolve_images( # but to conform with previous implementation... image_load_types.add(ImageLoadFilter.OPERATIONAL) - items = await Image.load_all(ctx, types=image_load_types) + items = await Image.load_all( + ctx, types=image_load_types, filter_by_statuses=filter_by_statuses + ) if client_role == UserRole.SUPERADMIN: pass elif client_role in (UserRole.ADMIN, UserRole.USER): @@ -1747,6 +1764,7 @@ async def resolve_image_nodes( info: graphene.ResolveInfo, *, scope_id: ScopeType, + filter_by_statuses: Optional[list[ImageStatus]] = [ImageStatus.ALIVE], permission: ImagePermission = ImagePermission.READ_ATTRIBUTE, filter: Optional[str] = None, order: Optional[str] = None, @@ -1760,6 +1778,7 @@ async def resolve_image_nodes( info, scope_id, permission, + filter_by_statuses, filter_expr=filter, order_expr=order, offset=offset, diff --git a/src/ai/backend/manager/models/gql_models/image.py b/src/ai/backend/manager/models/gql_models/image.py index 294d17a66c0..77fba33e31f 100644 --- a/src/ai/backend/manager/models/gql_models/image.py +++ b/src/ai/backend/manager/models/gql_models/image.py @@ -124,6 +124,8 @@ "accelerators": ("accelerators", None), } +ImageStatusType = graphene.Enum.from_enum(ImageStatus, description="Added in 25.4.0.") + class Image(graphene.ObjectType): id = graphene.UUID() @@ -246,12 +248,15 @@ async def batch_load_by_canonical( cls, graph_ctx: GraphQueryContext, image_names: Sequence[str], + filter_by_statuses: Optional[list[ImageStatus]] = [ImageStatus.ALIVE], ) -> Sequence[Optional[Image]]: query = ( sa.select(ImageRow) .where(ImageRow.name.in_(image_names)) .options(selectinload(ImageRow.aliases)) ) + if filter_by_statuses: + query = query.where(ImageRow.status.in_(filter_by_statuses)) async with graph_ctx.db.begin_readonly_session() as session: result = await session.execute(query) return [await Image.from_row(graph_ctx, row) for row in result.scalars().all()] @@ -261,18 +266,22 @@ async def batch_load_by_image_ref( cls, graph_ctx: GraphQueryContext, image_refs: Sequence[ImageRef], + filter_by_statuses: Optional[list[ImageStatus]] = [ImageStatus.ALIVE], ) -> Sequence[Optional[Image]]: image_names = [x.canonical for x in image_refs] - return await cls.batch_load_by_canonical(graph_ctx, image_names) + return await cls.batch_load_by_canonical(graph_ctx, image_names, filter_by_statuses) @classmethod async def load_item_by_id( cls, ctx: GraphQueryContext, id: UUID, + filter_by_statuses: Optional[list[ImageStatus]] = [ImageStatus.ALIVE], ) -> Image: async with ctx.db.begin_readonly_session() as session: - row = await ImageRow.get(session, id, load_aliases=True) + row = await ImageRow.get( + session, id, load_aliases=True, filter_by_statuses=filter_by_statuses + ) if not row: raise ImageNotFound @@ -284,6 +293,7 @@ async def load_item( ctx: GraphQueryContext, reference: str, architecture: str, + filter_by_statuses: Optional[list[ImageStatus]] = [ImageStatus.ALIVE], ) -> Image: try: async with ctx.db.begin_readonly_session() as session: @@ -293,6 +303,7 @@ async def load_item( ImageIdentifier(reference, architecture), ImageAlias(reference), ], + filter_by_statuses=filter_by_statuses, ) except UnknownImageReference: raise ImageNotFound @@ -304,9 +315,12 @@ async def load_all( ctx: GraphQueryContext, *, types: set[ImageLoadFilter] = set(), + filter_by_statuses: Optional[list[ImageStatus]] = [ImageStatus.ALIVE], ) -> Sequence[Image]: async with ctx.db.begin_readonly_session() as session: - rows = await ImageRow.list(session, load_aliases=True) + rows = await ImageRow.list( + session, load_aliases=True, filter_by_statuses=filter_by_statuses + ) items: list[Image] = [ item async for item in cls.bulk_load(ctx, rows) if item.matches_filter(ctx, types) ] @@ -429,12 +443,16 @@ async def batch_load_by_name_and_arch( cls, graph_ctx: GraphQueryContext, name_and_arch: Sequence[tuple[str, str]], + filter_by_statuses: Optional[list[ImageStatus]] = [ImageStatus.ALIVE], ) -> Sequence[Sequence[ImageNode]]: query = ( sa.select(ImageRow) .where(sa.tuple_(ImageRow.name, ImageRow.architecture).in_(name_and_arch)) .options(selectinload(ImageRow.aliases)) ) + if filter_by_statuses: + query = query.where(ImageRow.status.in_(filter_by_statuses)) + async with graph_ctx.db.begin_readonly_session() as db_session: return await batch_multiresult_in_scalar_stream( graph_ctx, @@ -450,9 +468,12 @@ async def batch_load_by_image_identifier( cls, graph_ctx: GraphQueryContext, image_ids: Sequence[ImageIdentifier], + filter_by_statuses: Optional[list[ImageStatus]] = [ImageStatus.ALIVE], ) -> Sequence[Sequence[ImageNode]]: name_and_arch_tuples = [(img.canonical, img.architecture) for img in image_ids] - return await cls.batch_load_by_name_and_arch(graph_ctx, name_and_arch_tuples) + return await cls.batch_load_by_name_and_arch( + graph_ctx, name_and_arch_tuples, filter_by_statuses + ) @overload @classmethod @@ -511,6 +532,7 @@ def from_row( supported_accelerators=(row.accelerators or "").split(","), aliases=[alias_row.alias for alias_row in row.aliases], permissions=[] if permissions is None else permissions, + status=row.status, ) return result @@ -540,6 +562,7 @@ def from_legacy_image( supported_accelerators=row.supported_accelerators, aliases=row.aliases, permissions=[] if permissions is None else permissions, + status=row.status, ) return result @@ -587,6 +610,7 @@ async def get_connection( info: graphene.ResolveInfo, scope_id: ScopeType, permission: ImagePermission, + filter_by_statuses: Optional[list[ImageStatus]] = [ImageStatus.ALIVE], filter_expr: Optional[str] = None, order_expr: Optional[str] = None, offset: Optional[int] = None, @@ -636,6 +660,11 @@ async def get_connection( return ConnectionResolverResult([], cursor, pagination_order, page_size, 0) query = query.where(cond).options(selectinload(ImageRow.aliases)) cnt_query = cnt_query.where(cond) + + if filter_by_statuses: + query = query.where(ImageRow.status.in_(filter_by_statuses)) + cnt_query = cnt_query.where(ImageRow.status.in_(filter_by_statuses)) + async with graph_ctx.db.begin_readonly_session(db_conn) as db_session: image_rows = (await db_session.scalars(query)).all() total_cnt = await db_session.scalar(cnt_query) diff --git a/src/ai/backend/manager/models/image.py b/src/ai/backend/manager/models/image.py index c4ba016bf3c..6a2c3896714 100644 --- a/src/ai/backend/manager/models/image.py +++ b/src/ai/backend/manager/models/image.py @@ -404,6 +404,7 @@ def __init__( accelerators=None, labels=None, resources=None, + status=ImageStatus.ALIVE, ) -> None: self.name = name self.project = project @@ -419,6 +420,7 @@ def __init__( self.accelerators = accelerators self.labels = labels self.resources = resources + self.status = status @property def trimmed_digest(self) -> str: @@ -445,6 +447,7 @@ async def from_alias( session: AsyncSession, alias: str, load_aliases: bool = False, + filter_by_statuses: Optional[list[ImageStatus]] = [ImageStatus.ALIVE], *, loading_options: Iterable[RelationLoadingOption] = tuple(), ) -> ImageRow: @@ -455,6 +458,9 @@ async def from_alias( ) if load_aliases: query = query.options(selectinload(ImageRow.aliases)) + if filter_by_statuses: + query = query.where(ImageRow.status.in_(filter_by_statuses)) + query = _apply_loading_option(query, loading_options) result = await session.scalar(query) if result is not None: @@ -468,6 +474,7 @@ async def from_image_identifier( session: AsyncSession, identifier: ImageIdentifier, load_aliases: bool = True, + filter_by_statuses: Optional[list[ImageStatus]] = [ImageStatus.ALIVE], *, loading_options: Iterable[RelationLoadingOption] = tuple(), ) -> ImageRow: @@ -478,6 +485,9 @@ async def from_image_identifier( if load_aliases: query = query.options(selectinload(ImageRow.aliases)) + if filter_by_statuses: + query = query.where(ImageRow.status.in_(filter_by_statuses)) + query = _apply_loading_option(query, loading_options) result = await session.execute(query) @@ -496,6 +506,7 @@ async def from_image_ref( *, strict_arch: bool = False, load_aliases: bool = False, + filter_by_statuses: Optional[list[ImageStatus]] = [ImageStatus.ALIVE], loading_options: Iterable[RelationLoadingOption] = tuple(), ) -> ImageRow: """ @@ -508,6 +519,9 @@ async def from_image_ref( query = sa.select(ImageRow).where(ImageRow.name == ref.canonical) if load_aliases: query = query.options(selectinload(ImageRow.aliases)) + if filter_by_statuses: + query = query.where(ImageRow.status.in_(filter_by_statuses)) + query = _apply_loading_option(query, loading_options) result = await session.execute(query) @@ -529,6 +543,7 @@ async def resolve( reference_candidates: list[ImageAlias | ImageRef | ImageIdentifier], *, strict_arch: bool = False, + filter_by_statuses: Optional[list[ImageStatus]] = [ImageStatus.ALIVE], load_aliases: bool = True, loading_options: Iterable[RelationLoadingOption] = tuple(), ) -> ImageRow: @@ -579,7 +594,11 @@ async def resolve( searched_refs.append(f"identifier:{reference!r}") try: if row := await resolver_func( - session, reference, load_aliases=load_aliases, loading_options=loading_options + session, + reference, + load_aliases=load_aliases, + filter_by_statuses=filter_by_statuses, + loading_options=loading_options, ): return row except UnknownImageReference: @@ -588,19 +607,34 @@ async def resolve( @classmethod async def get( - cls, session: AsyncSession, image_id: UUID, load_aliases=False + cls, + session: AsyncSession, + image_id: UUID, + filter_by_statuses: Optional[list[ImageStatus]] = [ImageStatus.ALIVE], + load_aliases: bool = False, ) -> ImageRow | None: query = sa.select(ImageRow).where(ImageRow.id == image_id) if load_aliases: query = query.options(selectinload(ImageRow.aliases)) + if filter_by_statuses: + query = query.where(ImageRow.status.in_(filter_by_statuses)) + result = await session.execute(query) return result.scalar() @classmethod - async def list(cls, session: AsyncSession, load_aliases=False) -> List[ImageRow]: + async def list( + cls, + session: AsyncSession, + filter_by_statuses: Optional[list[ImageStatus]] = [ImageStatus.ALIVE], + load_aliases: bool = False, + ) -> List[ImageRow]: query = sa.select(ImageRow) if load_aliases: query = query.options(selectinload(ImageRow.aliases)) + if filter_by_statuses: + query = query.where(ImageRow.status.in_(filter_by_statuses)) + result = await session.execute(query) return result.scalars().all()