Skip to content

Commit bd5a95d

Browse files
authored
Merge pull request #112 from linagora/fix-tasks-permissions
fixed user access to /tasks
2 parents ddb9284 + 0e3e0fb commit bd5a95d

File tree

2 files changed

+29
-5
lines changed

2 files changed

+29
-5
lines changed

openrag/components/indexer/indexer.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -308,6 +308,7 @@ class TaskInfo:
308308
class TaskStateManager:
309309
def __init__(self):
310310
self.tasks: Dict[str, TaskInfo] = {}
311+
self.user_index: Dict[int, set[str]] = {}
311312
self.lock = asyncio.Lock()
312313

313314
async def _ensure_task(self, task_id: str) -> TaskInfo:
@@ -346,6 +347,7 @@ async def set_details(
346347
"metadata": metadata,
347348
"user_id": user_id,
348349
}
350+
self.user_index.setdefault(user_id, set()).add(task_id)
349351

350352
@ray.method(concurrency_group="set")
351353
async def set_object_ref(self, task_id: str, object_ref: dict):
@@ -394,6 +396,20 @@ async def get_all_info(self) -> Dict[str, dict]:
394396
for task_id, info in self.tasks.items()
395397
}
396398

399+
@ray.method(concurrency_group="queue_info")
400+
async def get_all_user_info(self, user_id: int) -> Dict[str, dict]:
401+
async with self.lock:
402+
task_ids = self.user_index.get(user_id, set())
403+
return {
404+
tid: {
405+
"state": self.tasks[tid].state,
406+
"error": self.tasks[tid].error,
407+
"details": self.tasks[tid].details,
408+
}
409+
for tid in task_ids
410+
if tid in self.tasks
411+
}
412+
397413
@ray.method(concurrency_group="queue_info")
398414
async def get_pool_info(self) -> Dict[str, int]:
399415
return {

openrag/routers/queue.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,13 @@
55
from fastapi.responses import JSONResponse
66
from utils.dependencies import get_task_state_manager
77

8-
from .utils import require_admin
8+
from .utils import current_user, require_admin
99

1010
# load config
1111
config = load_config()
1212

1313
# Create an APIRouter instance
14-
router = APIRouter(dependencies=[Depends(require_admin)])
14+
router = APIRouter()
1515

1616

1717
def _format_pool_info(worker_info: dict[str, int]) -> dict[str, int]:
@@ -26,7 +26,9 @@ def _format_pool_info(worker_info: dict[str, int]) -> dict[str, int]:
2626

2727

2828
@router.get("/info")
29-
async def get_queue_info(task_state_manager=Depends(get_task_state_manager)):
29+
async def get_queue_info(
30+
admin=Depends(require_admin), task_state_manager=Depends(get_task_state_manager)
31+
):
3032
all_states: dict = await task_state_manager.get_all_states.remote()
3133
status_counts = Counter(all_states.values())
3234

@@ -51,14 +53,20 @@ async def list_tasks(
5153
request: Request,
5254
task_status: str | None = None,
5355
task_state_manager=Depends(get_task_state_manager),
56+
user=Depends(current_user),
5457
):
5558
"""
5659
- ?task_status=active → QUEUED | SERIALIZING | CHUNKING | INSERTING
5760
- ?task_status=<exact> → exact match (case-insensitive)
5861
- (none) → all tasks
5962
"""
60-
# fetck task info
61-
all_info: dict[str, dict] = await task_state_manager.get_all_info.remote()
63+
# fetch task info
64+
if user.get("is_admin"):
65+
all_info: dict[str, dict] = await task_state_manager.get_all_info.remote()
66+
else:
67+
all_info: dict[str, dict] = await task_state_manager.get_all_user_info.remote(
68+
user.get("id")
69+
)
6270

6371
if task_status is None:
6472
filtered = all_info.items()

0 commit comments

Comments
 (0)