Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion api/staleness_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from app.logging import get_logger
from app.models import Staleness
from app.models import db
from app.models.utils import StalenessCache
from app.staleness_serialization import AttrDict
from app.staleness_serialization import build_serialized_acc_staleness_obj
from app.staleness_serialization import build_staleness_sys_default
Expand All @@ -11,13 +12,17 @@


def get_staleness_obj(org_id: str) -> AttrDict:
    """Return the serialized staleness configuration for *org_id*.

    Checks the batch-scoped StalenessCache first; on a miss, loads the
    org's custom Staleness row and serializes it, falling back to the
    system default configuration when no custom row exists.  The result
    (including the default fallback) is cached so repeated lookups in
    the same batch skip the DB round-trip.

    Note: ``.one()`` will also raise MultipleResultsFound if more than
    one row matches — org_id is presumably unique; verify at the schema.
    """
    cached = StalenessCache.get(org_id)
    if cached is not None:
        return cached

    try:
        staleness = db.session.query(Staleness).filter(Staleness.org_id == org_id).one()
        logger.info(f"Using custom staleness for org {org_id}.")
        staleness = build_serialized_acc_staleness_obj(staleness)
    except NoResultFound:
        logger.debug(f"No custom staleness data found for org {org_id}, using system default values instead.")
        staleness = build_staleness_sys_default(org_id)

    # Removed a stray duplicated `return staleness` (diff artifact) that
    # made this cache write unreachable.
    StalenessCache.put(org_id, staleness)
    return staleness
11 changes: 10 additions & 1 deletion app/models/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,20 @@
from app.staleness_serialization import build_serialized_acc_staleness_obj
from app.staleness_serialization import build_staleness_sys_default
from app.staleness_serialization import get_staleness_timestamps
from lib.batch_cache import ThreadLocalBatchCache

logger = get_logger(__name__)


class StalenessCache(ThreadLocalBatchCache):
    """Batch-scoped, thread-local cache for staleness config lookups.

    Keyed by org_id.  Activated as a context manager around batch
    processing (see ThreadLocalBatchCache); outside an active context,
    ``get`` always misses and ``put`` is a no-op, and the cache is
    cleared automatically when the batch ends.
    """


def _get_staleness_obj(org_id):
cached = StalenessCache.get(org_id)
if cached is not None:
return cached

try:
from app.models.staleness import Staleness

Expand All @@ -24,8 +33,8 @@ def _get_staleness_obj(org_id):
except NoResultFound:
logger.debug(f"No staleness data found for org {org_id}, using system default values for model")
staleness = build_staleness_sys_default(org_id)
return staleness

StalenessCache.put(org_id, staleness)
return staleness


Expand Down
4 changes: 3 additions & 1 deletion app/queue/host_mq.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@
from app.models import host_app_data
from app.models import schemas as model_schemas
from app.models.system_profile_static import HostStaticSystemProfile
from app.models.utils import StalenessCache
from app.payload_tracker import PayloadTrackerContext
from app.payload_tracker import PayloadTrackerProcessingContext
from app.payload_tracker import get_payload_tracker
Expand Down Expand Up @@ -83,6 +84,7 @@
from lib.db import session_guard
from lib.feature_flags import FLAG_INVENTORY_REJECT_RHSM_PAYLOADS
from lib.feature_flags import get_flag_value
from lib.group_repository import UngroupedGroupCache
from utils.system_profile_log import extract_host_dict_sp_to_log

logger = get_logger(__name__)
Expand Down Expand Up @@ -235,7 +237,7 @@ def _process_batch(self) -> None:
InvalidRequestError: When the database session is in an invalid state
StaleDataError: When trying to update data modified by another transaction
"""
with session_guard(db.session, close=False), db.session.no_autoflush:
with session_guard(db.session, close=False), db.session.no_autoflush, StalenessCache(), UngroupedGroupCache():
messages = self.consumer.consume(
num_messages=inventory_config().mq_db_batch_max_messages,
timeout=inventory_config().mq_db_batch_max_seconds,
Expand Down
54 changes: 54 additions & 0 deletions lib/batch_cache.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
import threading


class ThreadLocalBatchCache:
    """Base class for batch-scoped, thread-local caches.

    Subclasses are used as context managers around batch processing.
    The cache is stored in thread-local storage and cleared on exit,
    ensuring no data leaks between batches or threads.

    Entry is reentrant per class and per thread: nested ``with`` blocks
    of the same cache class share one cache dict, and the cache is only
    created on the outermost enter and cleared on the outermost exit.
    (Previously a nested enter silently reset the outer block's cache.)

    Outside an active context, ``get`` always returns None and ``put``
    is a no-op, so callers need no "is the cache on?" checks.

    Usage:
        class MyCache(ThreadLocalBatchCache):
            pass

        with MyCache():
            MyCache.put("key", value)
            cached = MyCache.get("key")
        # cache is cleared here
    """

    # A single threading.local shared by the whole hierarchy; per-class
    # attribute names (below) keep each subclass's cache isolated.
    _local = threading.local()

    @classmethod
    def _cache_attr(cls):
        """Name of the thread-local attribute holding this class's cache dict."""
        return f"_cache_{cls.__name__}"

    @classmethod
    def _depth_attr(cls):
        """Name of the thread-local attribute holding this class's nesting depth."""
        return f"_depth_{cls.__name__}"

    @classmethod
    def get(cls, key):
        """Return the cached value for *key*, or None on a miss or inactive cache.

        Note: a stored value of None is indistinguishable from a miss.
        """
        cache = getattr(cls._local, cls._cache_attr(), None)
        if cache is None:
            return None
        return cache.get(key)

    @classmethod
    def put(cls, key, value):
        """Store *value* under *key*; silently does nothing when no context is active."""
        cache = getattr(cls._local, cls._cache_attr(), None)
        if cache is not None:
            cache[key] = value

    @classmethod
    def _create(cls):
        """Start a fresh, empty cache for the current thread."""
        setattr(cls._local, cls._cache_attr(), {})

    @classmethod
    def _clear(cls):
        """Drop the current thread's cache entirely."""
        setattr(cls._local, cls._cache_attr(), None)

    def __enter__(self):
        cls = type(self)
        depth = getattr(cls._local, cls._depth_attr(), 0)
        if depth == 0:
            # Outermost entry for this thread: create the cache.
            cls._create()
        setattr(cls._local, cls._depth_attr(), depth + 1)
        return self

    def __exit__(self, *exc):
        cls = type(self)
        depth = getattr(cls._local, cls._depth_attr(), 1) - 1
        setattr(cls._local, cls._depth_attr(), depth)
        if depth <= 0:
            # Outermost exit: clear so nothing leaks across batches.
            cls._clear()
        return False  # never suppress exceptions
16 changes: 13 additions & 3 deletions lib/group_repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
from app.serialization import serialize_rbac_workspace_with_host_count
from app.serialization import serialize_uuid
from app.staleness_serialization import AttrDict
from lib.batch_cache import ThreadLocalBatchCache
from lib.db import raw_db_connection
from lib.db import session_guard
from lib.host_repository import get_host_counts_batch
Expand All @@ -51,6 +52,10 @@
logger = get_logger(__name__)


class UngroupedGroupCache(ThreadLocalBatchCache):
    """Batch-scoped, thread-local cache for ungrouped-group lookups.

    Keyed by org_id.  Activated as a context manager around batch
    processing (see ThreadLocalBatchCache); outside an active context,
    ``get`` always misses and ``put`` is a no-op, and the cache is
    cleared automatically when the batch ends.
    """


def _update_hosts_for_group_changes(host_id_list: list[str], group_id_list: list[str], identity: Identity):
if not host_id_list:
return [], []
Expand Down Expand Up @@ -552,23 +557,28 @@ def get_group_using_host_id(host_id: str, org_id: str):


def get_or_create_ungrouped_hosts_group_for_identity(identity: Identity) -> Group:
    """Return the org's "Ungrouped Hosts" group, creating it if needed.

    Results are memoized per org_id in the batch-scoped
    UngroupedGroupCache, so repeated calls within one batch avoid
    redundant DB lookups and RBAC workspace calls.

    Removed a stray duplicated `return add_group(` line (diff artifact)
    that conflicted with the `group = add_group(` assignment below.
    """
    cached = UngroupedGroupCache.get(identity.org_id)
    if cached is not None:
        return cached

    group = get_ungrouped_group(identity)

    # If the "ungrouped" Group exists, return it.
    if group is not None:
        UngroupedGroupCache.put(identity.org_id, group)
        return group

    # Otherwise, create the workspace
    workspace_id = rbac_create_ungrouped_hosts_workspace(identity)

    # Create "ungrouped" group for this org using group ID == workspace ID
    group = add_group(
        group_name="Ungrouped Hosts",
        org_id=identity.org_id,
        account=getattr(identity, "account_number", None),
        group_id=workspace_id,
        ungrouped=True,
    )
    UngroupedGroupCache.put(identity.org_id, group)
    return group


def get_ungrouped_group(identity: Identity) -> Group:
Expand Down
47 changes: 47 additions & 0 deletions tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -2838,3 +2838,50 @@ def test_update_display_name_writes_when_reporter_changed(db_create_host, models

assert existing_host.display_name == "my-host"
assert existing_host.display_name_reporter == "yupana"


def test_staleness_cache_context_manager():
    """StalenessCache context manager should create and clear cache."""
    from app.models.utils import StalenessCache

    # Outside any context the cache is inactive: lookups always miss.
    assert StalenessCache.get("org1") is None

    with StalenessCache():
        payload = {"test": True}
        StalenessCache.put("org1", payload)
        assert StalenessCache.get("org1") == payload

    # Leaving the context wipes the cache again.
    assert StalenessCache.get("org1") is None


def test_staleness_cache_eliminates_redundant_queries(flask_app, mocker):  # noqa: ARG001
    """StalenessCache should prevent duplicate Staleness DB queries for the same org_id."""
    from app.models.staleness import Staleness
    from app.models.utils import StalenessCache
    from app.models.utils import _get_staleness_obj

    org_id = USER_IDENTITY["org_id"]

    with StalenessCache():
        # First call goes to the DB and populates the cache.
        initial = _get_staleness_obj(org_id)
        assert initial is not None

        # Patch the model's query so any further DB access is observable.
        mocker.patch("app.models.staleness.Staleness.query")

        # Second call must be served from the cache: same object, no query.
        repeat = _get_staleness_obj(org_id)
        assert repeat is initial
        Staleness.query.filter.assert_not_called()

    # The cache does not survive the context manager.
    assert StalenessCache.get(org_id) is None


def test_ungrouped_group_cache_context_manager(flask_app, mocker):  # noqa: ARG001
    """UngroupedGroupCache should prevent duplicate Group DB queries."""
    from lib.group_repository import UngroupedGroupCache

    # Inactive outside a context: lookups always miss.
    assert UngroupedGroupCache.get("org1") is None

    with UngroupedGroupCache():
        sentinel_group = mocker.Mock()
        UngroupedGroupCache.put("org1", sentinel_group)
        assert UngroupedGroupCache.get("org1") is sentinel_group

    # Exiting the context drops all cached entries.
    assert UngroupedGroupCache.get("org1") is None
Loading