From 92d8edec7647d8f5f6aa00bbdf75e7f579c4f477 Mon Sep 17 00:00:00 2001 From: marcvergees Date: Fri, 3 Jul 2026 14:04:52 +0200 Subject: [PATCH] feat: :sparkles: Allow users to choose AI model --- app/api/routes/forms.py | 11 +++++++++++ app/api/schemas/forms.py | 6 +++++- docs/1. SETUP.md | 13 +++++++++++++ tests/test_api.py | 20 ++++++++++++++++++++ 4 files changed, 49 insertions(+), 1 deletion(-) diff --git a/app/api/routes/forms.py b/app/api/routes/forms.py index be16029..607dd60 100644 --- a/app/api/routes/forms.py +++ b/app/api/routes/forms.py @@ -10,6 +10,7 @@ FormFillResponse, ModelsResponse, TranscriptionResponse, + ModelPullRequest, ) from app.core.config import OLLAMA_HOST, OLLAMA_MODEL, BASE_DIR, RETENTION_PERIOD_DAYS from app.services.whisper import call_whisper_asr @@ -86,6 +87,16 @@ def list_models(): return ModelsResponse(models=models, default=default_model) +@router.post("/pull") +def pull_model(req: ModelPullRequest): + try: + resp = requests.post(f"{OLLAMA_HOST}/api/pull", json={"name": req.model, "stream": False}, timeout=600) + resp.raise_for_status() + return {"status": "success", "message": f"Model {req.model} pulled successfully"} + except requests.exceptions.RequestException as e: + raise AppError(f"Failed to pull model: {e}", status_code=500, error_code="MODEL_PULL_ERROR") + + @router.post("/transcribe", response_model=TranscriptionResponse) def transcribe(audio: UploadFile = File(...)): """Forward recorded audio to the local Whisper ASR sidecar and return text. diff --git a/app/api/schemas/forms.py b/app/api/schemas/forms.py index 155b14f..5acdadc 100644 --- a/app/api/schemas/forms.py +++ b/app/api/schemas/forms.py @@ -74,4 +74,8 @@ class Config: class AsyncFormFillResponse(BaseModel): - jobs: list[AsyncJobSubmitResponse] \ No newline at end of file + jobs: list[AsyncJobSubmitResponse] + + +class ModelPullRequest(BaseModel): + model: str \ No newline at end of file diff --git a/docs/1. SETUP.md b/docs/1. 
SETUP.md index 30d2aee..23d3ecc 100644 --- a/docs/1. SETUP.md +++ b/docs/1. SETUP.md @@ -91,6 +91,19 @@ Check `make logs-app` for the actual error. The entrypoint runs database migrati **Want a clean slate** `make super-clean` stops everything and **deletes all volumes** database, uploads, and downloaded model weights. Only use it when you intend to wipe all local data. +## AI Model Selection + +FireForm lets you choose which AI model runs during form extraction, directly from the dropdown in the frontend "Fill Form" UI. + +The recommended models are: +- `qwen2.5:1.5b` (default, lightweight) +- `qwen2.5:3b` +- `qwen2.5:7b` +- `llama3.2:3b` +- `mistral:7b` + +If you select a model that is not yet installed (pulled) in your local Ollama instance, the app automatically asks the backend to download it via the `POST /api/v1/forms/pull` endpoint. While the model is downloading, form submission is temporarily disabled and its progress status is displayed. Once downloaded, the model is cached in Ollama's Docker volume for future use. 
+ ## Where to go next - **Join our [Discord](https://discord.gg/nBv5b6kF68)** — ask questions and coordinate with other contributors diff --git a/tests/test_api.py b/tests/test_api.py index 32104ae..f56dce2 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -283,6 +283,26 @@ def boom(*a, **k): assert resp.status_code == 200 assert resp.json()["models"] == ["qwen2.5:1.5b"] + def test_pull_model_success(self, client, monkeypatch): + from unittest.mock import MagicMock + fake_response = MagicMock() + fake_response.raise_for_status.return_value = None + monkeypatch.setattr("app.api.routes.forms.requests.post", lambda *a, **k: fake_response) + + resp = client.post(f"{API_PREFIX}/forms/pull", json={"model": "llama3.2:3b"}) + assert resp.status_code == 200 + assert resp.json()["status"] == "success" + + def test_pull_model_failure(self, client, monkeypatch): + import requests + def boom(*a, **k): + raise requests.exceptions.RequestException("pull failed") + monkeypatch.setattr("app.api.routes.forms.requests.post", boom) + + resp = client.post(f"{API_PREFIX}/forms/pull", json={"model": "llama3.2:3b"}) + assert resp.status_code == 500 + assert resp.json()["error_code"] == "MODEL_PULL_ERROR" + def test_fill_form_passes_model_override(self, client, mock_controller): """A `model` in the request reaches Controller.fill_form but isn't persisted.""" tpl_id = self._seed_template(client, mock_controller)