Skip to content

Commit

Permalink
feat: Reflect feedback
Browse files Browse the repository at this point in the history
  • Loading branch information
jopemachine committed Feb 20, 2025
1 parent b5ceb60 commit aa2079b
Show file tree
Hide file tree
Showing 7 changed files with 237 additions and 10 deletions.
1 change: 1 addition & 0 deletions changes/3712.feature.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add `AuditLog` table to record `ImageAuditLogOperationType` events.
5 changes: 5 additions & 0 deletions src/ai/backend/agent/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -526,6 +526,7 @@ async def _pull(reporter: ProgressReporter, *, img_conf: ImageConfig) -> None:
await self.agent.produce_event(
ImagePullStartedEvent(
image=str(img_ref),
image_ref=img_ref,
agent_id=self.agent.id,
timestamp=datetime.now(timezone.utc).timestamp(),
)
Expand All @@ -544,6 +545,7 @@ async def _pull(reporter: ProgressReporter, *, img_conf: ImageConfig) -> None:
await self.agent.produce_event(
ImagePullFailedEvent(
image=str(img_ref),
image_ref=img_ref,
agent_id=self.agent.id,
msg=f"timeout (s:{image_pull_timeout})",
)
Expand All @@ -553,6 +555,7 @@ async def _pull(reporter: ProgressReporter, *, img_conf: ImageConfig) -> None:
await self.agent.produce_event(
ImagePullFailedEvent(
image=str(img_ref),
image_ref=img_ref,
agent_id=self.agent.id,
msg=repr(e),
)
Expand All @@ -562,6 +565,7 @@ async def _pull(reporter: ProgressReporter, *, img_conf: ImageConfig) -> None:
await self.agent.produce_event(
ImagePullFinishedEvent(
image=str(img_ref),
image_ref=img_ref,
agent_id=self.agent.id,
timestamp=datetime.now(timezone.utc).timestamp(),
)
Expand All @@ -571,6 +575,7 @@ async def _pull(reporter: ProgressReporter, *, img_conf: ImageConfig) -> None:
await self.agent.produce_event(
ImagePullFinishedEvent(
image=str(img_ref),
image_ref=img_ref,
agent_id=self.agent.id,
timestamp=datetime.now(timezone.utc).timestamp(),
msg="Image already exists",
Expand Down
47 changes: 43 additions & 4 deletions src/ai/backend/common/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
from aiotools.taskgroup.types import AsyncExceptionHandler
from redis.asyncio import ConnectionPool

from ai.backend.common.docker import ImageRef
from ai.backend.logging import BraceStyleAdapter, LogLevel

from . import msgpack, redis_helper
Expand Down Expand Up @@ -221,70 +222,108 @@ def deserialize(cls, value: tuple):
class ImagePullStartedEvent(AbstractEvent):
    """Event payload announcing that an image pull has started.

    Wire format (tuple): ``(image, agent_id, timestamp[, image_ref])``.
    The trailing ``image_ref`` element is only emitted when it is set, so
    payloads without it keep the original three-element shape.
    """

    name = "image_pull_started"

    image: str = attrs.field()  # deprecated, use image_ref
    agent_id: AgentId = attrs.field()
    timestamp: float = attrs.field()
    image_ref: Optional[ImageRef] = attrs.field(default=None)

    def serialize(self) -> tuple:
        """Return the tuple form, appending ``image_ref`` only when present."""
        legacy_payload = (self.image, str(self.agent_id), self.timestamp)
        if self.image_ref is None:
            return legacy_payload
        return (*legacy_payload, self.image_ref)

    @classmethod
    def deserialize(cls, value: tuple) -> ImagePullStartedEvent:
        """Rebuild the event from its tuple form.

        Three-element tuples (no ``image_ref``) are accepted for backward
        compatibility; ``image_ref`` is then left as ``None``.
        """
        return cls(
            image=value[0],
            agent_id=AgentId(value[1]),
            timestamp=value[2],
            image_ref=value[3] if len(value) > 3 else None,
        )


@attrs.define(slots=True, frozen=True)
class ImagePullFinishedEvent(AbstractEvent):
    """Event payload announcing that an image pull has finished.

    Wire format (tuple): ``(image, agent_id, timestamp, msg[, image_ref])``.
    """

    name = "image_pull_finished"

    image: str = attrs.field()  # deprecated, use image_ref
    agent_id: AgentId = attrs.field()
    timestamp: float = attrs.field()
    msg: Optional[str] = attrs.field(default=None)
    image_ref: Optional[ImageRef] = attrs.field(default=None)

    def serialize(self) -> tuple:
        """Return the five-element tuple form.

        ``image_ref`` is always included here, even when it is ``None``.
        """
        return (
            self.image,
            str(self.agent_id),
            self.timestamp,
            self.msg,
            self.image_ref,
        )

    @classmethod
    def deserialize(cls, value: tuple) -> ImagePullFinishedEvent:
        """Rebuild the event from its tuple form.

        Four-element tuples (no ``image_ref``) are accepted for backward
        compatibility; ``image_ref`` is then left as ``None``.
        """
        return cls(
            image=value[0],
            agent_id=AgentId(value[1]),
            timestamp=value[2],
            msg=value[3],
            image_ref=value[4] if len(value) > 4 else None,
        )


@attrs.define(slots=True, frozen=True)
class ImagePullFailedEvent(AbstractEvent):
    """Event payload announcing that an image pull has failed.

    Wire format (tuple): ``(image, agent_id, msg[, image_ref])``.
    The trailing ``image_ref`` element is only emitted when it is set, so
    payloads without it keep the original three-element shape.
    """

    name = "image_pull_failed"

    image: str = attrs.field()  # deprecated, use image_ref
    agent_id: AgentId = attrs.field()
    msg: str = attrs.field()
    image_ref: Optional[ImageRef] = attrs.field(default=None)

    def serialize(self) -> tuple:
        """Return the tuple form, appending ``image_ref`` only when present."""
        # No unconditional early return here: it would make the image_ref
        # branch below unreachable and silently drop the field.
        legacy_payload = (self.image, str(self.agent_id), self.msg)
        if self.image_ref is None:
            return legacy_payload
        return (*legacy_payload, self.image_ref)

    @classmethod
    def deserialize(cls, value: tuple) -> ImagePullFailedEvent:
        """Rebuild the event from its tuple form.

        Three-element tuples (no ``image_ref``) are accepted for backward
        compatibility; ``image_ref`` is then left as ``None``.
        """
        return cls(
            image=value[0],
            agent_id=AgentId(value[1]),
            msg=value[2],
            image_ref=value[3] if len(value) > 3 else None,
        )


Expand Down
13 changes: 13 additions & 0 deletions src/ai/backend/manager/container_registry/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,11 @@
from ai.backend.common.types import SlotName, SSLContextType
from ai.backend.common.utils import join_non_empty
from ai.backend.logging import BraceStyleAdapter
from ai.backend.manager.models.audit_log import (
AuditLogEntityType,
AuditLogRow,
ImageAuditLogOperationType,
)

from ..defs import INTRINSIC_SLOTS_MIN
from ..models.image import ImageIdentifier, ImageRow, ImageType
Expand Down Expand Up @@ -185,6 +190,14 @@ async def commit_rescan_result(self) -> None:
progress_msg = f"Updated image - {parsed_img.canonical}/{image_identifier.architecture} ({update['config_digest']})"
log.info(progress_msg)

session.add(
AuditLogRow(
entity_type=AuditLogEntityType.IMAGE,
operation=ImageAuditLogOperationType.SESSION_CREATE,
entity_id=image_row.id,
)
)

if (reporter := progress_reporter.get()) is not None:
await reporter.update(1, message=progress_msg)

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
"""Add AuditLog table
Revision ID: 683ca0a32f41
Revises: 8f85e9d0bd4e
Create Date: 2025-02-14 10:56:10.191119
"""

import sqlalchemy as sa
from alembic import op

from ai.backend.manager.models.base import GUID

# revision identifiers, used by Alembic.
revision = "683ca0a32f41"
down_revision = "8f85e9d0bd4e"
branch_labels = None
depends_on = None


def upgrade() -> None:
    """Create the ``audit_logs`` table and one index per queried column."""
    table_name = "audit_logs"

    columns = [
        sa.Column("id", GUID(), server_default=sa.text("uuid_generate_v4()"), nullable=False),
        sa.Column("entity_type", sa.String, nullable=False),
        sa.Column("operation", sa.String, nullable=False),
        sa.Column("entity_id", sa.String, nullable=False),
        sa.Column(
            "created_at",
            sa.DateTime(timezone=True),
            server_default=sa.text("now()"),
            nullable=True,
        ),
    ]
    primary_key = sa.PrimaryKeyConstraint("id", name=op.f("pk_audit_logs"))
    op.create_table(table_name, *columns, primary_key)

    # (index name, indexed column), created in this order.
    index_specs = (
        (op.f("ix_audit_logs_created_at"), "created_at"),
        (op.f("ix_audit_logs_entity_type"), "entity_type"),
        (op.f("ix_audit_logs_operation"), "operation"),
        (op.f("ix_audit_logs_entity_id"), "entity_id"),
    )
    for index_name, column_name in index_specs:
        op.create_index(index_name, table_name, [column_name], unique=False)


def downgrade() -> None:
    """Drop the ``audit_logs`` indexes, then the table itself."""
    table_name = "audit_logs"

    index_names = (
        op.f("ix_audit_logs_created_at"),
        op.f("ix_audit_logs_entity_type"),
        op.f("ix_audit_logs_operation"),
        op.f("ix_audit_logs_entity_id"),
    )
    for index_name in index_names:
        op.drop_index(index_name, table_name=table_name)

    op.drop_table(table_name)
89 changes: 89 additions & 0 deletions src/ai/backend/manager/models/audit_log.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
from __future__ import annotations

import enum
import logging
import uuid

import sqlalchemy as sa
from sqlalchemy.ext.asyncio import AsyncSession

from ai.backend.logging import BraceStyleAdapter

from .base import (
Base,
IDColumn,
)

log = BraceStyleAdapter(logging.getLogger(__spec__.name))

# Every public name defined in this module is exported, including the
# entity-type enum and the generic operation alias that appear in
# AuditLogRow's signatures.
__all__ = (
    "AuditLogEntityType",
    "AuditLogOperationType",
    "AuditLogRow",
    "ImageAuditLogOperationType",
)


class AuditLogEntityType(enum.StrEnum):
IMAGE = "image"


class ImageAuditLogOperationType(enum.StrEnum):
SESSION_CREATE = "session_create"
UPDATE = "update" # Rescan and update image metadata
PULL = "pull" # Pull image from the registry


# Generic operation type used in AuditLogRow.__init__'s signature. Only image
# operations are defined in this module so far, so it is a plain alias.
AuditLogOperationType = ImageAuditLogOperationType


class AuditLogRow(Base):
    """ORM row of the ``audit_logs`` table.

    Each row stores which operation was applied to which entity; the
    ``created_at`` timestamp is filled in by the database server default.
    """

    __tablename__ = "audit_logs"

    id = IDColumn("id")

    entity_type = sa.Column("entity_type", sa.String, index=True, nullable=False)
    operation = sa.Column("operation", sa.String, index=True, nullable=False)
    entity_id = sa.Column("entity_id", sa.String, nullable=False, index=True)
    created_at = sa.Column(
        "created_at",
        sa.DateTime(timezone=True),
        server_default=sa.func.now(),
        index=True,
    )

    def __init__(
        self,
        entity_type: AuditLogEntityType,
        operation: AuditLogOperationType,
        entity_id: str | uuid.UUID,
    ):
        """Populate the row from enum members and an entity identifier.

        The enums are stored by their string ``value``; a ``uuid.UUID``
        identifier is converted to its string form, anything else is stored
        as given.
        """
        if isinstance(entity_id, uuid.UUID):
            stored_entity_id = str(entity_id)
        else:
            stored_entity_id = entity_id

        self.entity_type = entity_type.value
        self.operation = operation.value
        self.entity_id = stored_entity_id

    def __str__(self) -> str:
        """Return a one-line summary of the row's columns (except ``id``)."""
        summary = (
            f"entity_type: {self.entity_type}, "
            f"operation: {self.operation}, "
            f"created_at: {self.created_at}, "
            f"entity_id: {self.entity_id}"
        )
        return f"AuditLogRow({summary})"

    def __repr__(self) -> str:
        """Return the same text as ``__str__``."""
        return self.__str__()

    @classmethod
    async def report_image(
        cls,
        db_session: AsyncSession,
        entity_type: AuditLogEntityType,
        operation: ImageAuditLogOperationType,
        entity_id: str | uuid.UUID,
    ) -> None:
        """Add a new audit row to ``db_session`` and flush the session."""
        row = cls(entity_type, operation, entity_id)
        db_session.add(row)
        await db_session.flush()
Loading

0 comments on commit aa2079b

Please sign in to comment.