diff --git a/ee/clickhouse/models/test/test_cohort.py b/ee/clickhouse/models/test/test_cohort.py index 7929a1d62e7f6..c8b52e18a74c7 100644 --- a/ee/clickhouse/models/test/test_cohort.py +++ b/ee/clickhouse/models/test/test_cohort.py @@ -1500,37 +1500,6 @@ def test_calculate_people_ch_in_multiteam_project(self): self.assertCountEqual([r[0] for r in results_team1], [person2_team1.uuid]) self.assertCountEqual([r[0] for r in results_team2], [person1_team2.uuid]) - def test_static_cohort_size_in_multiteam_project(self): - from posthog.models.cohort.util import get_static_cohort_size - - # Create another team in the same project - team2 = Team.objects.create(organization=self.organization, project=self.team.project) - - # Create people in team 1 - _create_person(team_id=self.team.pk, distinct_ids=["person1_team1"]) - _create_person(team_id=self.team.pk, distinct_ids=["person2_team1"]) - - # Create people in team 2 - _create_person(team_id=team2.pk, distinct_ids=["person1_team2"]) - _create_person(team_id=team2.pk, distinct_ids=["person2_team2"]) - _create_person(team_id=team2.pk, distinct_ids=["person3_team2"]) - - # Create a static cohort in team2 - static_cohort = Cohort.objects.create(team=team2, is_static=True, name="static cohort") - - # Add people from both teams to the cohort - static_cohort.insert_users_by_list(["person1_team1", "person2_team1"]) # Team 1 people - static_cohort.insert_users_by_list(["person1_team2", "person2_team2"]) # Team 2 people - - # Verify get_static_cohort_size correctly filters by team_id - # The cohort belongs to team2, so filtering by team2 should count team2 people - count_team2 = get_static_cohort_size(cohort_id=static_cohort.pk, team_id=team2.pk) - self.assertEqual(count_team2, 2) - - # Filtering by team1 should return 0 since cohort belongs to team2 - count_team1 = get_static_cohort_size(cohort_id=static_cohort.pk, team_id=self.team.pk) - self.assertEqual(count_team1, 0) - def test_cohortpeople_action_all_events(self): # Create an action that matches all events (no specific event defined) action = Action.objects.create(team=self.team, name="all events", steps_json=[{"event": None}]) diff --git a/posthog/api/test/__snapshots__/test_feature_flag.ambr b/posthog/api/test/__snapshots__/test_feature_flag.ambr index cd078d0c324e4..632ebe2243078 100644 --- a/posthog/api/test/__snapshots__/test_feature_flag.ambr +++ b/posthog/api/test/__snapshots__/test_feature_flag.ambr @@ -817,9 +817,9 @@ ''' SELECT COUNT(*) AS "__count" FROM "posthog_cohortpeople" - INNER JOIN "posthog_cohort" ON ("posthog_cohortpeople"."cohort_id" = "posthog_cohort"."id") - WHERE ("posthog_cohort"."team_id" = 99999 - AND "posthog_cohortpeople"."cohort_id" = 99999) + INNER JOIN "posthog_person" ON ("posthog_cohortpeople"."person_id" = "posthog_person"."id") + WHERE ("posthog_cohortpeople"."cohort_id" = 99999 + AND "posthog_person"."team_id" = 99999) ''' # --- # name: TestCohortGenerationForFeatureFlag.test_creating_static_cohort_iterator.24 @@ -1185,9 +1185,9 @@ ''' SELECT COUNT(*) AS "__count" FROM "posthog_cohortpeople" - INNER JOIN "posthog_cohort" ON ("posthog_cohortpeople"."cohort_id" = "posthog_cohort"."id") - WHERE ("posthog_cohort"."team_id" = 99999 - AND "posthog_cohortpeople"."cohort_id" = 99999) + INNER JOIN "posthog_person" ON ("posthog_cohortpeople"."person_id" = "posthog_person"."id") + WHERE ("posthog_cohortpeople"."cohort_id" = 99999 + AND "posthog_person"."team_id" = 99999) ''' # --- # name: TestCohortGenerationForFeatureFlag.test_creating_static_cohort_iterator.7 @@ -1423,9 +1423,9 @@ ''' SELECT COUNT(*) AS "__count" FROM "posthog_cohortpeople" - INNER JOIN "posthog_cohort" ON ("posthog_cohortpeople"."cohort_id" = "posthog_cohort"."id") - WHERE ("posthog_cohort"."team_id" = 99999 - AND "posthog_cohortpeople"."cohort_id" = 99999) + INNER JOIN "posthog_person" ON ("posthog_cohortpeople"."person_id" = "posthog_person"."id") + WHERE ("posthog_cohortpeople"."cohort_id" = 99999 + AND "posthog_person"."team_id" = 99999) ''' # --- # name: TestCohortGenerationForFeatureFlag.test_creating_static_cohort_with_cohort_flag_adds_cohort_props_as_default_too.13 @@ -2535,9 +2535,9 @@ ''' SELECT COUNT(*) AS "__count" FROM "posthog_cohortpeople" - INNER JOIN "posthog_cohort" ON ("posthog_cohortpeople"."cohort_id" = "posthog_cohort"."id") - WHERE ("posthog_cohort"."team_id" = 99999 - AND "posthog_cohortpeople"."cohort_id" = 99999) + INNER JOIN "posthog_person" ON ("posthog_cohortpeople"."person_id" = "posthog_person"."id") + WHERE ("posthog_cohortpeople"."cohort_id" = 99999 + AND "posthog_person"."team_id" = 99999) ''' # --- # name: TestCohortGenerationForFeatureFlag.test_creating_static_cohort_with_default_person_properties_adjustment.3 @@ -3043,9 +3043,9 @@ ''' SELECT COUNT(*) AS "__count" FROM "posthog_cohortpeople" - INNER JOIN "posthog_cohort" ON ("posthog_cohortpeople"."cohort_id" = "posthog_cohort"."id") - WHERE ("posthog_cohort"."team_id" = 99999 - AND "posthog_cohortpeople"."cohort_id" = 99999) + INNER JOIN "posthog_person" ON ("posthog_cohortpeople"."person_id" = "posthog_person"."id") + WHERE ("posthog_cohortpeople"."cohort_id" = 99999 + AND "posthog_person"."team_id" = 99999) ''' # --- # name: TestCohortGenerationForFeatureFlag.test_creating_static_cohort_with_default_person_properties_adjustment.8 @@ -3205,9 +3205,9 @@ ''' SELECT COUNT(*) AS "__count" FROM "posthog_cohortpeople" - INNER JOIN "posthog_cohort" ON ("posthog_cohortpeople"."cohort_id" = "posthog_cohort"."id") - WHERE ("posthog_cohort"."team_id" = 99999 - AND "posthog_cohortpeople"."cohort_id" = 99999) + INNER JOIN "posthog_person" ON ("posthog_cohortpeople"."person_id" = "posthog_person"."id") + WHERE ("posthog_cohortpeople"."cohort_id" = 99999 + AND "posthog_person"."team_id" = 99999) ''' # --- # name: TestCohortGenerationForFeatureFlag.test_creating_static_cohort_with_experience_continuity_flag.11 @@ -4101,9 +4101,9 @@ ''' SELECT COUNT(*) AS "__count" FROM "posthog_cohortpeople" - INNER JOIN "posthog_cohort" ON ("posthog_cohortpeople"."cohort_id" = "posthog_cohort"."id") - WHERE ("posthog_cohort"."team_id" = 99999 - AND "posthog_cohortpeople"."cohort_id" = 99999) + INNER JOIN "posthog_person" ON ("posthog_cohortpeople"."person_id" = "posthog_person"."id") + WHERE ("posthog_cohortpeople"."cohort_id" = 99999 + AND "posthog_person"."team_id" = 99999) ''' # --- # name: TestFeatureFlag.test_creating_static_cohort.2 diff --git a/posthog/models/cohort/util.py b/posthog/models/cohort/util.py index 6903dd8d466d5..b7eac8f519bb4 100644 --- a/posthog/models/cohort/util.py +++ b/posthog/models/cohort/util.py @@ -420,7 +420,12 @@ def remove_person_from_static_cohort(person_uuid: uuid.UUID, cohort_id: int, *, def get_static_cohort_size(*, cohort_id: int, team_id: int) -> int: - count = CohortPeople.objects.filter(cohort_id=cohort_id, cohort__team_id=team_id).count() + # First check if cohort belongs to the team (cohort is in default DB) + if not Cohort.objects.filter(id=cohort_id, team_id=team_id).exists(): + return 0 + + # Then count cohortpeople (in persons DB, no cross-DB join) + count = CohortPeople.objects.filter(cohort_id=cohort_id).count() return count diff --git a/posthog/test/test_cohort_model.py b/posthog/test/test_cohort_model.py index 28931df2021c6..29ab872b1930e 100644 --- a/posthog/test/test_cohort_model.py +++ b/posthog/test/test_cohort_model.py @@ -339,6 +339,29 @@ def test_insert_users_list_by_uuid(self): assert cohort_person_uuids == set(uuids) assert cohort.is_calculating is False + def test_static_cohort_size_validates_team(self): + from posthog.models.cohort.util import get_static_cohort_size + + # Create another team in the same organization (different project) + team2 = Team.objects.create(organization=self.organization) + + # Create people in both teams + Person.objects.create(team=self.team, distinct_ids=["person1_team1"]) + Person.objects.create(team=self.team, distinct_ids=["person2_team1"]) + Person.objects.create(team=team2, distinct_ids=["person1_team2"]) + + # Create a static cohort in team1 + cohort = Cohort.objects.create(team=self.team, is_static=True, name="test cohort") + cohort.insert_users_by_list(["person1_team1", "person2_team1"]) + + # Count should work for the correct team + count_correct = get_static_cohort_size(cohort_id=cohort.pk, team_id=self.team.pk) + assert count_correct == 2 + + # Count should be 0 for a different team (validates team ownership) + count_wrong_team = get_static_cohort_size(cohort_id=cohort.pk, team_id=team2.pk) + assert count_wrong_team == 0 + def test_insert_users_by_list_avoids_duplicates_with_batching(self): """Test that batching with duplicates works correctly - people already in cohort are not re-inserted.""" # Create people with distinct IDs