Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 11 additions & 2 deletions src/kernelbot/api/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,12 @@ def init_background_submission_manager(_manager: BackgroundSubmissionManager):
return background_submission_manager


async def get_queue_snapshot(submission_id: int | None = None) -> dict[str, Any] | None:
if background_submission_manager is None:
return None
return await background_submission_manager.queue_snapshot(submission_id)


@app.exception_handler(KernelBotError)
async def kernel_bot_error_handler(req: Request, exc: KernelBotError):
return JSONResponse(status_code=exc.http_code, content={"message": str(exc)})
Expand Down Expand Up @@ -511,7 +517,8 @@ async def enqueue_background_job(
job_id = db.upsert_submission_job_status(sub_id, "initial", None)
# put submission request in queue
await manager.enqueue(req, mode, sub_id)
return sub_id, job_id
queue = await get_queue_snapshot(sub_id)
return sub_id, job_id, queue


@app.post("/submission/{leaderboard_name}/{gpu_type}/{submission_mode}")
Expand Down Expand Up @@ -564,14 +571,15 @@ async def run_submission_async(
raise HTTPException(status_code=400, detail="Invalid GPU type")

# put submission request to background manager to run in background
sub_id, job_status_id = await enqueue_background_job(
sub_id, job_status_id, queue = await enqueue_background_job(
req, submission_mode_enum, backend_instance, background_submission_manager
)

return JSONResponse(
status_code=202,
content={
"details": {"id": sub_id, "job_status_id": job_status_id},
"queue": queue,
"status": "accepted",
},
)
Expand Down Expand Up @@ -1000,6 +1008,7 @@ async def get_user_submission(
"error": submission.get("job_error"),
"last_heartbeat": submission.get("job_last_heartbeat"),
},
"queue": await get_queue_snapshot(submission_id),
}
except HTTPException:
raise
Expand Down
22 changes: 22 additions & 0 deletions src/libkernelbot/background_submission_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import datetime as dt
import logging
from dataclasses import dataclass
from typing import Any

from libkernelbot.backend import KernelBackend
from libkernelbot.consts import SubmissionMode
Expand Down Expand Up @@ -140,6 +141,27 @@ async def enqueue(
await self._autoscale_up()
return job_id, sub_id

async def queue_snapshot(self, sub_id: int | None = None) -> dict[str, Any]:
async with self._state_lock:
queued = list(self.queue._queue) # noqa: SLF001 - asyncio.Queue has no public snapshot API.
queued_ids = [item.sub_id for item in queued]
position = None
if sub_id in queued_ids:
position = queued_ids.index(sub_id) + 1
stage = "queued" if position is not None else "dispatched"
message = (
"In KernelBot queue"
if stage == "queued"
else "Job dispatched to Modal/GitHub runner"
)

return {
"stage": stage,
"message": message,
"position": position,
"jobs_ahead": None if position is None else position - 1,
}

async def _worker_loop(self):
"""
A worker will keep listening to the queue, and process the job in the queue.
Expand Down
29 changes: 29 additions & 0 deletions tests/test_background_submission_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,35 @@ async def fake_submit_full(req, mode, reporter, sub_id, skip_precheck=False):
await manager.stop()


@pytest.mark.asyncio
async def test_queue_snapshot_reports_position(mock_backend):
db_context = mock_backend.db
db_context.upsert_submission_job_status = mock.Mock(
side_effect=lambda *a, **k: a[0]
)

manager = BackgroundSubmissionManager(
mock_backend, min_workers=0, max_workers=0, idle_seconds=0.1
)
await manager.start()

await manager.enqueue(get_req(1), SubmissionMode.TEST, sub_id=41)
await manager.enqueue(get_req(2), SubmissionMode.TEST, sub_id=42)

snapshot = await manager.queue_snapshot(42)

assert snapshot["stage"] == "queued"
assert snapshot["message"] == "In KernelBot queue"
assert snapshot["position"] == 2
assert snapshot["jobs_ahead"] == 1

manager.queue.get_nowait()
manager.queue.task_done()
manager.queue.get_nowait()
manager.queue.task_done()
await manager.stop()


@pytest.mark.asyncio
async def test_hacked_submission_sets_hacked_status(mock_backend):
db_context = mock_backend.db
Expand Down
Loading