48 changes: 19 additions & 29 deletions label_studio/projects/functions/next_task.py
@@ -6,7 +6,7 @@
 from core.utils.common import conditional_atomic, db_is_not_sqlite, load_func
 from core.utils.db import fast_first
 from django.conf import settings
-from django.db.models import BooleanField, Case, Count, Exists, F, Max, OuterRef, Q, QuerySet, Value, When
+from django.db.models import Case, Count, Exists, F, Max, OuterRef, Q, QuerySet, When
 from django.db.models.fields import DecimalField
 from projects.functions.stream_history import add_stream_history
 from projects.models import Project
@@ -75,37 +75,27 @@ def _try_tasks_with_overlap(tasks: QuerySet[Task]) -> Tuple[Union[Task, None], Q
     return None, tasks.filter(overlap=1)


-def _try_breadth_first(
-    tasks: QuerySet[Task], user: User, project: Project, attempt_gt_first: bool = False
-) -> Union[Task, None]:
+def _try_breadth_first(tasks: QuerySet[Task], user: User, project: Project) -> Union[Task, None]:
     """Try to find tasks with maximum amount of annotations, since we are trying to label tasks as fast as possible"""
 
-    # Exclude ground truth annotations from the count when not in onboarding window
-    # to prevent GT tasks from being prioritized via breadth-first logic
-    annotation_filter = ~Q(annotations__completed_by=user)
-    if not attempt_gt_first:
-        annotation_filter &= ~Q(annotations__ground_truth=True)
+    if project.annotator_evaluation_enabled:
Collaborator comment: Leave a comment here explaining why we're doing this and how it applies to onboarding/ongoing evaluation.

+        tasks = _annotate_has_ground_truths(tasks)
+        tasks = tasks.filter(has_ground_truths=False)
 
-    tasks = tasks.annotate(annotations_count=Count('annotations', filter=annotation_filter))
+    tasks = tasks.annotate(annotations_count=Count('annotations', filter=~Q(annotations__completed_by=user)))
     max_annotations_count = tasks.aggregate(Max('annotations_count'))['annotations_count__max']
-    if max_annotations_count == 0:
-        # there is no any labeled tasks found
-        return

-    # find any task with maximal amount of created annotations
-    not_solved_tasks_labeling_started = tasks.annotate(
-        reach_max_annotations_count=Case(
-            When(annotations_count=max_annotations_count, then=Value(True)),
-            default=Value(False),
-            output_field=BooleanField(),
-        )
-    )
-    not_solved_tasks_labeling_with_max_annotations = not_solved_tasks_labeling_started.filter(
-        reach_max_annotations_count=True
-    )
-    if not_solved_tasks_labeling_with_max_annotations.exists():
-        # try to complete tasks that are already in progress
-        return _get_random_unlocked(not_solved_tasks_labeling_with_max_annotations, user)
-
+    if max_annotations_count == 0 or max_annotations_count is None:
+        # No tasks with annotations, let the next step in the pipeline handle it
+        return None
+
+    # Find tasks at the maximum amount of annotations
+    candidates = tasks.filter(annotations_count=max_annotations_count)
Collaborator comment: good simplification :D

+    if candidates.exists():
+        # Select randomly from candidates
+        result = _get_random_unlocked(candidates, user)
+        return result
+    return None


 def _try_uncertainty_sampling(
@@ -289,7 +279,7 @@ def get_next_task_without_dm_queue(
     if not next_task and project.maximum_annotations > 1:
         # if there are already labeled tasks, but task.overlap still < project.maximum_annotations, randomly sampling from them
         logger.debug(f'User={user} tries depth first from prepared tasks')
-        next_task = _try_breadth_first(not_solved_tasks, user, project, attempt_gt_first)
+        next_task = _try_breadth_first(not_solved_tasks, user, project)
         if next_task:
            queue_info += (' & ' if queue_info else '') + 'Breadth first queue'
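
The net effect of the refactor is easier to see outside the ORM. Below is a minimal, self-contained sketch of the selection policy the simplified `_try_breadth_first` implements — prefer tasks that already have the most annotations (not counting the requesting user's own), picking randomly among ties — using plain Python objects in place of querysets. The `Task` dataclass and `pick_breadth_first` are illustrative stand-ins, not Label Studio APIs:

```python
import random
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class Task:
    id: int
    annotations_count: int = 0        # annotations not authored by the requesting user
    has_ground_truths: bool = False


def pick_breadth_first(tasks: List[Task], evaluation_enabled: bool) -> Optional[Task]:
    """Prefer tasks that already have the most annotations (breadth-first),
    skipping ground-truth tasks when annotator evaluation is enabled."""
    if evaluation_enabled:
        tasks = [t for t in tasks if not t.has_ground_truths]
    max_count = max((t.annotations_count for t in tasks), default=None)
    if not max_count:  # None (no tasks left) or 0 (nothing started yet)
        return None
    candidates = [t for t in tasks if t.annotations_count == max_count]
    return random.choice(candidates)  # stands in for _get_random_unlocked


# The in-progress task (2 annotations) wins over the fresh one; the
# ground-truth task is excluded because evaluation is enabled.
tasks = [Task(1, 2), Task(2, 1), Task(3, 3, has_ground_truths=True)]
assert pick_breadth_first(tasks, evaluation_enabled=True).id == 1
```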
1 change: 1 addition & 0 deletions label_studio/tasks/tests/factories.py
@@ -15,6 +15,7 @@ class TaskFactory(factory.django.DjangoModelFactory):
         }
     )
     project = factory.SubFactory(load_func(settings.PROJECT_FACTORY))
+    overlap = factory.LazyAttribute(lambda obj: obj.project.maximum_annotations)
 
     class Meta:
         model = Task
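
This one-line factory change makes each built task's `overlap` track whatever project the factory attaches it to, rather than a fixed default. A minimal sketch of the `LazyAttribute` mechanism, using plain classes with `factory.Factory` so it runs without a Django database — the `Project`/`Task` classes here are illustrative stand-ins, not the real models:

```python
import factory


class Project:
    def __init__(self, maximum_annotations):
        self.maximum_annotations = maximum_annotations


class Task:
    def __init__(self, project, overlap):
        self.project = project
        self.overlap = overlap


class ProjectFactory(factory.Factory):
    class Meta:
        model = Project

    maximum_annotations = 3


class TaskFactory(factory.Factory):
    class Meta:
        model = Task

    project = factory.SubFactory(ProjectFactory)
    # LazyAttribute is evaluated after `project` is built, so overlap
    # always matches the parent project's maximum_annotations.
    overlap = factory.LazyAttribute(lambda obj: obj.project.maximum_annotations)


task = TaskFactory()
assert task.overlap == task.project.maximum_annotations == 3
```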
95 changes: 95 additions & 0 deletions label_studio/tests/test_next_task.py
@@ -8,8 +8,13 @@
 from core.redis import redis_healthcheck
 from django.apps import apps
 from django.db.models import Q
+from django.test import TestCase
+from projects.functions.next_task import _try_breadth_first
 from projects.models import Project
+from projects.tests.factories import ProjectFactory
 from tasks.models import Annotation, Prediction, Task
+from tasks.tests.factories import AnnotationFactory, TaskFactory
+from users.tests.factories import UserFactory
 
 from .utils import (
     _client_is_annotator,
@@ -1327,3 +1332,93 @@ def complete_task(annotator):
     else:
         assert not all_tasks_with_overlap_are_labeled
         assert not all_tasks_without_overlap_are_not_labeled


+class TestTryBreadthFirst(TestCase):
+    @classmethod
+    def setUpTestData(cls):
+        cls.user = UserFactory()
+        cls.other_user = UserFactory()
+
+        # Project with evaluation enabled
+        cls.project_with_eval = ProjectFactory(
+            maximum_annotations=3,
+            annotator_evaluation_enabled=True,
+        )
+
+        # Project without evaluation
+        cls.project_without_eval = ProjectFactory(
+            maximum_annotations=3,
+            annotator_evaluation_enabled=False,
+        )
+
+    def test_excludes_ground_truth_tasks_when_evaluation_enabled(self):
+        """
+        Test that _try_breadth_first excludes GT tasks when annotator_evaluation_enabled=True.
+        """
+        # Create tasks with varying annotation counts
+        task_1 = TaskFactory(project=self.project_with_eval)  # 2 regular annotations (max among non-GT)
+        task_2 = TaskFactory(project=self.project_with_eval)  # 1 regular annotation
+        task_3_gt = TaskFactory(project=self.project_with_eval)  # 3 annotations BUT has GT
+
+        # Add regular annotations to task_1 (should be selected)
+        AnnotationFactory.create_batch(2, task=task_1, ground_truth=False)
+
+        # Add a regular annotation to task_2
+        AnnotationFactory(task=task_2, ground_truth=False)
+
+        # Add a GT annotation to task_3_gt plus two regular ones
+        AnnotationFactory(task=task_3_gt, ground_truth=True)
+        AnnotationFactory(task=task_3_gt, ground_truth=False)
+        AnnotationFactory(task=task_3_gt, ground_truth=False)
+
+        # Get all tasks
+        tasks = Task.objects.filter(project=self.project_with_eval)
+
+        # Execute
+        result = _try_breadth_first(tasks, self.user, self.project_with_eval)
+
+        # Assert: should return task_1 (max annotations among non-GT tasks), not task_3_gt
+        assert result == task_1
+
+    def test_includes_ground_truth_tasks_when_evaluation_disabled(self):
+        """
+        Test that _try_breadth_first includes GT tasks when annotator_evaluation_enabled=False.
+        """
+        # Create tasks with varying annotation counts
+        task_1 = TaskFactory(project=self.project_without_eval)  # 2 regular annotations
+        task_2 = TaskFactory(project=self.project_without_eval)  # 1 regular annotation
+        task_3_gt = TaskFactory(project=self.project_without_eval)  # 3 annotations (max) BUT has GT
+
+        # Add regular annotations to task_1
+        AnnotationFactory.create_batch(2, task=task_1, ground_truth=False)
+
+        # Add a regular annotation to task_2
+        AnnotationFactory(task=task_2, ground_truth=False)
+
+        # Add a GT annotation to task_3_gt plus two regular ones
+        AnnotationFactory(task=task_3_gt, ground_truth=True)
+        AnnotationFactory(task=task_3_gt, ground_truth=False)
+        AnnotationFactory(task=task_3_gt, ground_truth=False)
+
+        # Get all tasks
+        tasks = Task.objects.filter(project=self.project_without_eval)
+
+        # Execute
+        result = _try_breadth_first(tasks, self.user, self.project_without_eval)
+
+        # Assert: should return task_3_gt (max annotations overall), not task_1 or task_2
+        assert result == task_3_gt
+
+    def test_returns_none_when_no_tasks_with_annotations_and_evaluation_enabled(self):
+        """
+        Test that _try_breadth_first returns None when the only annotated task is a GT task
+        and annotator_evaluation_enabled=True.
+        """
+        task_gt = TaskFactory(project=self.project_with_eval)
+        AnnotationFactory(task=task_gt, ground_truth=True)
+
+        tasks = Task.objects.filter(project=self.project_with_eval)
+
+        # Execute
+        result = _try_breadth_first(tasks, self.user, self.project_with_eval)
+
+        # Assert: should return None
+        assert result is None
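
Both the production change and these tests lean on `_annotate_has_ground_truths`, whose body is not part of this diff. A plausible shape for that helper — a sketch only, assuming `Annotation` has a `task` foreign key and a `ground_truth` boolean flag (both consistent with the factories used above); the real implementation may differ:

```python
from django.db.models import Exists, OuterRef, QuerySet
from tasks.models import Annotation, Task


def _annotate_has_ground_truths(tasks: QuerySet[Task]) -> QuerySet[Task]:
    """Attach a boolean `has_ground_truths` to each task, true when any of
    its annotations is marked as ground truth (sketch, not the real helper)."""
    return tasks.annotate(
        has_ground_truths=Exists(
            Annotation.objects.filter(task=OuterRef('pk'), ground_truth=True)
        )
    )
```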