diff --git a/label_studio/projects/functions/next_task.py b/label_studio/projects/functions/next_task.py
index 6d5163cf3e36..e640648ef47b 100644
--- a/label_studio/projects/functions/next_task.py
+++ b/label_studio/projects/functions/next_task.py
@@ -6,7 +6,7 @@
 from core.utils.common import conditional_atomic, db_is_not_sqlite, load_func
 from core.utils.db import fast_first
 from django.conf import settings
-from django.db.models import BooleanField, Case, Count, Exists, F, Max, OuterRef, Q, QuerySet, Value, When
+from django.db.models import Case, Count, Exists, F, Max, OuterRef, Q, QuerySet, When
 from django.db.models.fields import DecimalField
 from projects.functions.stream_history import add_stream_history
 from projects.models import Project
@@ -75,37 +75,32 @@ def _try_tasks_with_overlap(tasks: QuerySet[Task]) -> Tuple[Union[Task, None], Q
     return None, tasks.filter(overlap=1)
 
 
-def _try_breadth_first(
-    tasks: QuerySet[Task], user: User, project: Project, attempt_gt_first: bool = False
-) -> Union[Task, None]:
+def _try_breadth_first(tasks: QuerySet[Task], user: User, project: Project) -> Union[Task, None]:
     """Try to find tasks with maximum amount of annotations, since we are trying to label tasks as fast as possible"""
 
-    # Exclude ground truth annotations from the count when not in onboarding window
-    # to prevent GT tasks from being prioritized via breadth-first logic
-    annotation_filter = ~Q(annotations__completed_by=user)
-    if not attempt_gt_first:
-        annotation_filter &= ~Q(annotations__ground_truth=True)
+    if project.annotator_evaluation_enabled:
+        # When annotator evaluation is enabled, ground-truth tasks accumulate overlap regardless of the maximum annotations setting.
+        # If we included them, the breadth-first logic would eventually front-load them,
+        # so we exclude them from the candidates.
+        # Onboarding tasks are served by _try_ground_truth.
+        # When breadth-first finds no in-progress tasks, the next step in the pipeline will serve the remaining GT tasks.
+        tasks = _annotate_has_ground_truths(tasks)
+        tasks = tasks.filter(has_ground_truths=False)
 
-    tasks = tasks.annotate(annotations_count=Count('annotations', filter=annotation_filter))
+    tasks = tasks.annotate(annotations_count=Count('annotations', filter=~Q(annotations__completed_by=user)))
     max_annotations_count = tasks.aggregate(Max('annotations_count'))['annotations_count__max']
-    if max_annotations_count == 0:
-        # there is no any labeled tasks found
-        return
-
-    # find any task with maximal amount of created annotations
-    not_solved_tasks_labeling_started = tasks.annotate(
-        reach_max_annotations_count=Case(
-            When(annotations_count=max_annotations_count, then=Value(True)),
-            default=Value(False),
-            output_field=BooleanField(),
-        )
-    )
-    not_solved_tasks_labeling_with_max_annotations = not_solved_tasks_labeling_started.filter(
-        reach_max_annotations_count=True
-    )
-    if not_solved_tasks_labeling_with_max_annotations.exists():
-        # try to complete tasks that are already in progress
-        return _get_random_unlocked(not_solved_tasks_labeling_with_max_annotations, user)
+
+    if max_annotations_count is None or max_annotations_count == 0:
+        # No annotated tasks found; let the next step in the pipeline handle it
+        return None
+
+    # Find tasks with the maximum number of annotations
+    candidates = tasks.filter(annotations_count=max_annotations_count)
+    if candidates.exists():
+        # Select randomly from the candidates
+        result = _get_random_unlocked(candidates, user)
+        return result
+    return None
 
 
 def _try_uncertainty_sampling(
@@ -289,7 +284,7 @@ def get_next_task_without_dm_queue(
     if not next_task and project.maximum_annotations > 1:
         # if there are already labeled tasks, but task.overlap still < project.maximum_annotations, randomly sampling from them
         logger.debug(f'User={user} tries depth first from prepared tasks')
-        next_task = _try_breadth_first(not_solved_tasks, user, project, attempt_gt_first)
+        next_task = _try_breadth_first(not_solved_tasks, user, project)
 
     if next_task:
         queue_info += (' & ' if queue_info else '') + 'Breadth first queue'
diff --git a/label_studio/tasks/tests/factories.py b/label_studio/tasks/tests/factories.py
index 3a85864e40ad..74d69aa2f8e8 100644
--- a/label_studio/tasks/tests/factories.py
+++ b/label_studio/tasks/tests/factories.py
@@ -15,6 +15,7 @@ class TaskFactory(factory.django.DjangoModelFactory):
         }
     )
     project = factory.SubFactory(load_func(settings.PROJECT_FACTORY))
+    overlap = factory.LazyAttribute(lambda obj: obj.project.maximum_annotations if obj.project else 1)
 
     class Meta:
         model = Task
diff --git a/label_studio/tests/test_next_task.py b/label_studio/tests/test_next_task.py
index bb2bae010b45..77fbab108e79 100644
--- a/label_studio/tests/test_next_task.py
+++ b/label_studio/tests/test_next_task.py
@@ -8,8 +8,13 @@
 from core.redis import redis_healthcheck
 from django.apps import apps
 from django.db.models import Q
+from django.test import TestCase
+from projects.functions.next_task import _try_breadth_first
 from projects.models import Project
+from projects.tests.factories import ProjectFactory
 from tasks.models import Annotation, Prediction, Task
+from tasks.tests.factories import AnnotationFactory, TaskFactory
+from users.tests.factories import UserFactory
 
 from .utils import (
     _client_is_annotator,
@@ -1327,3 +1332,93 @@ def complete_task(annotator):
     else:
         assert not all_tasks_with_overlap_are_labeled
         assert not all_tasks_without_overlap_are_not_labeled
+
+
+class TestTryBreadthFirst(TestCase):
+    @classmethod
+    def setUpTestData(cls):
+        cls.user = UserFactory()
+        cls.other_user = UserFactory()
+
+        # Project with evaluation enabled
+        cls.project_with_eval = ProjectFactory(
+            maximum_annotations=3,
+            annotator_evaluation_enabled=True,
+        )
+
+        # Project without evaluation
+        cls.project_without_eval = ProjectFactory(
+            maximum_annotations=3,
+            annotator_evaluation_enabled=False,
+        )
+
+    def test_excludes_ground_truth_tasks_when_evaluation_enabled(self):
+        """
+        Test that _try_breadth_first excludes GT tasks when annotator_evaluation_enabled=True.
+        """
+        # Create tasks with varying annotation counts
+        task_1 = TaskFactory(project=self.project_with_eval)  # 2 regular annotations (max)
+        task_2 = TaskFactory(project=self.project_with_eval)  # 1 regular annotation
+        task_3_gt = TaskFactory(project=self.project_with_eval)  # 3 annotations BUT has GT
+
+        # Add regular annotations to task_1 (should be selected)
+        AnnotationFactory.create_batch(2, task=task_1, ground_truth=False)
+
+        # Add regular annotation to task_2
+        AnnotationFactory(task=task_2, ground_truth=False)
+
+        # Add a GT annotation to task_3_gt plus two regular ones
+        AnnotationFactory(task=task_3_gt, ground_truth=True)
+        AnnotationFactory(task=task_3_gt, ground_truth=False)
+        AnnotationFactory(task=task_3_gt, ground_truth=False)
+
+        # Get all tasks
+        tasks = Task.objects.filter(project=self.project_with_eval)
+
+        # Execute
+        result = _try_breadth_first(tasks, self.user, self.project_with_eval)
+
+        # Assert: should return task_1 (max annotations, not GT), not task_3_gt
+        assert result == task_1
+
+    def test_includes_ground_truth_tasks_when_evaluation_disabled(self):
+        """
+        Test that _try_breadth_first includes GT tasks when annotator_evaluation_enabled=False.
+        """
+        # Create tasks with varying annotation counts
+        task_1 = TaskFactory(project=self.project_without_eval)  # 2 regular annotations
+        task_2 = TaskFactory(project=self.project_without_eval)  # 1 regular annotation
+        task_3_gt = TaskFactory(project=self.project_without_eval)  # 3 annotations, including a GT one
+
+        # Add regular annotations to task_1
+        AnnotationFactory.create_batch(2, task=task_1, ground_truth=False)
+
+        # Add regular annotation to task_2
+        AnnotationFactory(task=task_2, ground_truth=False)
+
+        # Add a GT annotation to task_3_gt plus two regular ones
+        AnnotationFactory(task=task_3_gt, ground_truth=True)
+        AnnotationFactory(task=task_3_gt, ground_truth=False)
+        AnnotationFactory(task=task_3_gt, ground_truth=False)
+
+        # Get all tasks
+        tasks = Task.objects.filter(project=self.project_without_eval)
+
+        # Execute
+        result = _try_breadth_first(tasks, self.user, self.project_without_eval)
+
+        # Assert: should return task_3_gt (max annotations, GT included), not task_1 or task_2
+        assert result == task_3_gt
+
+    def test_returns_none_when_no_tasks_with_annotations_and_evaluation_enabled(self):
+        """Test that _try_breadth_first returns None when the only annotated task is a GT task and evaluation is enabled."""
+        task_gt = TaskFactory(project=self.project_with_eval)
+        AnnotationFactory(task=task_gt, ground_truth=True)
+
+        tasks = Task.objects.filter(project=self.project_with_eval)
+
+        # Execute
+        result = _try_breadth_first(tasks, self.user, self.project_with_eval)
+
+        # Assert: should return None
+        assert result is None
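
Note: _annotate_has_ground_truths is called in the new annotator-evaluation branch above but is not defined anywhere in this diff. For review context, here is a minimal sketch of what such a helper could look like, assuming it flags tasks that carry at least one ground-truth annotation via an Exists subquery; the implementation below is an assumption for illustration, not the code shipped with this patch. (Exists and OuterRef remain in the module imports after the change, which is at least consistent with this approach.)

from django.db.models import Exists, OuterRef, QuerySet

from tasks.models import Annotation, Task


def _annotate_has_ground_truths(tasks: QuerySet[Task]) -> QuerySet[Task]:
    # Assumed helper (not part of this diff): mark each task with a boolean
    # has_ground_truths flag that is True when at least one ground-truth
    # annotation exists for it, so _try_breadth_first can drop GT tasks
    # from its candidate pool with tasks.filter(has_ground_truths=False).
    return tasks.annotate(
        has_ground_truths=Exists(Annotation.objects.filter(task=OuterRef('pk'), ground_truth=True))
    )

Implemented this way, the exclusion compiles to a correlated EXISTS subquery rather than a join on annotations, so filtering on has_ground_truths=False does not duplicate task rows and stays cheap on large task tables.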