From 3dbf916650e147b9aff5bc811ea93115fc111c83 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Jos=C3=A9=20Sequeira?=
Date: Sat, 18 Oct 2025 17:50:30 +0200
Subject: [PATCH] fix(cohorts): Fix open transaction with cohort cursor

---
 posthog/models/cohort/cohort.py |  57 +++++++------
 posthog/models/person/person.py | 140 +++++++++++++++++++++++++++++++-
 2 files changed, 170 insertions(+), 27 deletions(-)

diff --git a/posthog/models/cohort/cohort.py b/posthog/models/cohort/cohort.py
index 5ec22fe31cce2..fd4dae29c3548 100644
--- a/posthog/models/cohort/cohort.py
+++ b/posthog/models/cohort/cohort.py
@@ -522,34 +522,43 @@ def _insert_users_list_with_batching(
             from django.db import connections
 
             persons_connection = connections[READ_DB_FOR_PERSONS]
-            cursor = persons_connection.cursor()
-            for batch_index, batch in batch_iterator:
-                current_batch_index = batch_index
-                persons_query = (
-                    Person.objects.db_manager(READ_DB_FOR_PERSONS)
-                    .filter(team_id=team_id)
-                    .filter(uuid__in=batch)
-                    .exclude(cohort__id=self.id)
-                )
-                if insert_in_clickhouse:
-                    insert_static_cohort(
-                        list(persons_query.values_list("uuid", flat=True)),
-                        self.pk,
-                        team_id=team_id,
+            with persons_connection.cursor() as cursor:
+                for batch_index, batch in batch_iterator:
+                    current_batch_index = batch_index
+                    persons_query = (
+                        Person.objects.db_manager(READ_DB_FOR_PERSONS)
+                        .filter(team_id=team_id)
+                        .filter(uuid__in=batch)
+                        .exclude(cohort__id=self.id)
                     )
-                sql, params = persons_query.distinct("pk").only("pk").query.sql_with_params()
-                query = UPDATE_QUERY.format(
-                    cohort_id=self.pk,
-                    values_query=sql.replace(
-                        'FROM "posthog_person"',
-                        f', {self.pk}, {self.version or "NULL"} FROM "posthog_person"',
-                        1,
-                    ),
-                )
-                cursor.execute(query, params)
+                    if insert_in_clickhouse:
+                        insert_static_cohort(
+                            list(persons_query.values_list("uuid", flat=True)),
+                            self.pk,
+                            team_id=team_id,
+                        )
+                    sql, params = persons_query.distinct("pk").only("pk").query.sql_with_params()
+                    query = UPDATE_QUERY.format(
+                        cohort_id=self.pk,
+                        values_query=sql.replace(
+                            'FROM "posthog_person"',
+                            f', {self.pk}, {self.version or "NULL"} FROM "posthog_person"',
+                            1,
+                        ),
+                    )
+                    cursor.execute(query, params)
+
+            # Commit the transaction after all batches are processed
+            persons_connection.commit()
 
         except Exception as err:
             processing_error = err
+            # Rollback the transaction on error
+            try:
+                persons_connection.rollback()
+            except Exception:
+                # Ignore rollback errors, focus on the original error
+                pass
             if settings.DEBUG:
                 raise
             # Add batch index context to the exception
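Note: the hunk above closes the read-replica cursor via a context manager and ends the transaction explicitly, so a failed batch no longer leaves an open transaction on the persons connection. Below is a minimal, self-contained sketch of that same pattern; the function name `run_update_batches` and the `statements` input are hypothetical illustrations, not PostHog code.

    from django.db import connections

    def run_update_batches(alias, statements):
        """Execute batched statements on one connection, then commit or roll back."""
        connection = connections[alias]
        try:
            # The context manager closes the cursor even if execute() raises,
            # so no cursor is left pinning the connection.
            with connection.cursor() as cursor:
                for sql, params in statements:
                    cursor.execute(sql, params)
            # One explicit commit after all batches; without it, a connection
            # that is not in autocommit keeps the transaction (and locks) open.
            connection.commit()
        except Exception:
            try:
                connection.rollback()
            except Exception:
                pass  # surface the original error, not the rollback failure
            raise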
diff --git a/posthog/models/person/person.py b/posthog/models/person/person.py
index 0dbdac1c55ef1..8ba42ed37cbb4 100644
--- a/posthog/models/person/person.py
+++ b/posthog/models/person/person.py
@@ -1,3 +1,4 @@
+import logging
 from typing import Any, Optional
 
 from django.db import connections, models, router, transaction
@@ -8,6 +9,8 @@
 from ..team import Team
 from .missing_person import uuidFromDistinctId
 
+logger = logging.getLogger(__name__)
+
 MAX_LIMIT_DISTINCT_IDS = 2500
 
 if "persons_db_reader" in connections:
@@ -87,23 +90,116 @@ def _add_distinct_ids(self, distinct_ids: list[str]) -> None:
 
     def split_person(self, main_distinct_id: Optional[str], max_splits: Optional[int] = None):
         original_person = Person.objects.get(pk=self.pk)
+
+        # Log which database we're reading distinct_ids from
+        logger.info(
+            "split_person: Fetching distinct_ids",
+            extra={
+                "person_id": self.pk,
+                "person_uuid": str(self.uuid),
+                "team_id": self.team_id,
+                "read_db": READ_DB_FOR_PERSONS,
+            },
+        )
+
         distinct_ids = original_person.distinct_ids
         original_person_version = original_person.version or 0
+
+        # Also fetch from the write database to compare
+        db_alias_write = router.db_for_write(PersonDistinctId) or "default"
+        distinct_ids_from_write_db = [
+            id[0]
+            for id in PersonDistinctId.objects.db_manager(db_alias_write)
+            .filter(person=self, team_id=self.team_id)
+            .order_by("id")
+            .values_list("distinct_id")
+        ]
+
+        # Check for discrepancies between read and write databases
+        missing_from_read = set(distinct_ids_from_write_db) - set(distinct_ids)
+        extra_in_read = set(distinct_ids) - set(distinct_ids_from_write_db)
+
+        if missing_from_read or extra_in_read:
+            logger.warning(
+                "split_person: Discrepancy between read and write databases",
+                extra={
+                    "person_id": self.pk,
+                    "read_db": READ_DB_FOR_PERSONS,
+                    "write_db": db_alias_write,
+                    "distinct_ids_from_read": distinct_ids,
+                    "distinct_ids_from_write": distinct_ids_from_write_db,
+                    "missing_from_read": list(missing_from_read),
+                    "extra_in_read": list(extra_in_read),
+                },
+            )
+
+        logger.info(
+            "split_person: Starting person split",
+            extra={
+                "person_id": self.pk,
+                "person_uuid": str(self.uuid),
+                "team_id": self.team_id,
+                "total_distinct_ids": len(distinct_ids),
+                "distinct_ids": distinct_ids,
+                "distinct_ids_from_write_db": distinct_ids_from_write_db,
+                "read_db": READ_DB_FOR_PERSONS,
+                "write_db": db_alias_write,
+                "max_splits": max_splits,
+                "original_person_version": original_person_version,
+            },
+        )
+
         if not main_distinct_id:
             self.properties = {}
             self.save()
             main_distinct_id = distinct_ids[0]
+            logger.info(
+                "split_person: No main_distinct_id provided, using first distinct_id",
+                extra={"person_id": self.pk, "main_distinct_id": main_distinct_id},
+            )
+        else:
+            logger.info(
+                "split_person: Using provided main_distinct_id",
+                extra={"person_id": self.pk, "main_distinct_id": main_distinct_id},
+            )
 
         if max_splits is not None and len(distinct_ids) > max_splits:
-            # Split the last N distinct_ids of the list
+            original_count = len(distinct_ids)
             distinct_ids = distinct_ids[-1 * max_splits :]
+            logger.info(
+                "split_person: Limiting splits due to max_splits",
+                extra={
+                    "person_id": self.pk,
+                    "original_count": original_count,
+                    "limited_count": len(distinct_ids),
+                    "max_splits": max_splits,
+                },
+            )
+
+        split_count = 0
+        failed_count = 0
+        skipped_count = 0
 
         for distinct_id in distinct_ids:
-            if not distinct_id == main_distinct_id:
-                db_alias = router.db_for_write(PersonDistinctId) or "default"
-                with transaction.atomic(using=db_alias):
+            if distinct_id == main_distinct_id:
+                logger.info(
+                    "split_person: Skipping main_distinct_id",
+                    extra={"person_id": self.pk, "distinct_id": distinct_id},
+                )
+                skipped_count += 1
+                continue
+
+            try:
+                logger.info(
+                    "split_person: Processing distinct_id",
+                    extra={"person_id": self.pk, "distinct_id": distinct_id},
+                )
+                db_alias = router.db_for_write(PersonDistinctId) or "default"
+                with transaction.atomic(using=db_alias):
                     pdi = PersonDistinctId.objects.select_for_update().get(person=self, distinct_id=distinct_id)
-                    person, _ = Person.objects.get_or_create(
+                    person, created = Person.objects.get_or_create(
                         uuid=uuidFromDistinctId(self.team_id, distinct_id),
                         team_id=self.team_id,
                         defaults={
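Note: the hunk above reads distinct_ids from the configured read database and re-reads them from the write database so replication discrepancies get logged before the split starts. A self-contained sketch of that comparison follows; the alias names "replica" and "default" are illustrative assumptions, not PostHog settings, and `diff_distinct_ids` is a hypothetical helper.

    def diff_distinct_ids(person, read_alias="replica", write_alias="default"):
        """Return (missing_from_read, extra_in_read) for one person's distinct_ids."""
        def ids_from(alias):
            return set(
                PersonDistinctId.objects.db_manager(alias)
                .filter(person=person, team_id=person.team_id)
                .values_list("distinct_id", flat=True)
            )
        read_ids = ids_from(read_alias)
        write_ids = ids_from(write_alias)
        # Rows visible on the writer but not yet on the reader point at
        # replication lag; the reverse direction points at a stale read.
        return write_ids - read_ids, read_ids - write_ids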
"person_created": created, + "pdi_version": pdi.version, + }, + ) + except Exception as e: + failed_count += 1 + logger.error( + "split_person: Failed to split distinct_id", + extra={ + "person_id": self.pk, + "distinct_id": distinct_id, + "error": str(e), + "error_type": type(e).__name__, + }, + exc_info=True, + ) + + logger.info( + "split_person: Completed person split", + extra={ + "person_id": self.pk, + "person_uuid": str(self.uuid), + "team_id": self.team_id, + "total_processed": len(distinct_ids), + "split_count": split_count, + "failed_count": failed_count, + "skipped_count": skipped_count, + }, + ) + class PersonDistinctId(models.Model): id = models.BigAutoField(primary_key=True)