Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions openrag/components/indexer/indexer.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,6 +308,7 @@ class TaskInfo:
class TaskStateManager:
def __init__(self):
self.tasks: Dict[str, TaskInfo] = {}
self.user_index: Dict[int, set[str]] = {}
self.lock = asyncio.Lock()

async def _ensure_task(self, task_id: str) -> TaskInfo:
Expand Down Expand Up @@ -346,6 +347,7 @@ async def set_details(
"metadata": metadata,
"user_id": user_id,
}
self.user_index.setdefault(user_id, set()).add(task_id)

@ray.method(concurrency_group="set")
async def set_object_ref(self, task_id: str, object_ref: dict):
Expand Down Expand Up @@ -394,6 +396,20 @@ async def get_all_info(self) -> Dict[str, dict]:
for task_id, info in self.tasks.items()
}

@ray.method(concurrency_group="queue_info")
async def get_all_user_info(self, user_id: int) -> Dict[str, dict]:
async with self.lock:
task_ids = self.user_index.get(user_id, set())
return {
tid: {
"state": self.tasks[tid].state,
"error": self.tasks[tid].error,
"details": self.tasks[tid].details,
}
for tid in task_ids
if tid in self.tasks
}

@ray.method(concurrency_group="queue_info")
async def get_pool_info(self) -> Dict[str, int]:
return {
Expand Down
18 changes: 13 additions & 5 deletions openrag/routers/queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,13 @@
from fastapi.responses import JSONResponse
from utils.dependencies import get_task_state_manager

from .utils import require_admin
from .utils import current_user, require_admin

# load config
config = load_config()

# Create an APIRouter instance
router = APIRouter(dependencies=[Depends(require_admin)])
router = APIRouter()


def _format_pool_info(worker_info: dict[str, int]) -> dict[str, int]:
Expand All @@ -26,7 +26,9 @@ def _format_pool_info(worker_info: dict[str, int]) -> dict[str, int]:


@router.get("/info")
async def get_queue_info(task_state_manager=Depends(get_task_state_manager)):
async def get_queue_info(
admin=Depends(require_admin), task_state_manager=Depends(get_task_state_manager)
):
all_states: dict = await task_state_manager.get_all_states.remote()
status_counts = Counter(all_states.values())

Expand All @@ -51,14 +53,20 @@ async def list_tasks(
request: Request,
task_status: str | None = None,
task_state_manager=Depends(get_task_state_manager),
user=Depends(current_user),
):
"""
- ?task_status=active → QUEUED | SERIALIZING | CHUNKING | INSERTING
- ?task_status=<exact> → exact match (case-insensitive)
- (none) → all tasks
"""
# fetck task info
all_info: dict[str, dict] = await task_state_manager.get_all_info.remote()
# fetch task info
if user.get("is_admin"):
all_info: dict[str, dict] = await task_state_manager.get_all_info.remote()
else:
all_info: dict[str, dict] = await task_state_manager.get_all_user_info.remote(
user.get("id")
)

if task_status is None:
filtered = all_info.items()
Expand Down
Loading