diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 1601d30..4d9aefd 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -24,6 +24,7 @@ repos: additional_dependencies: - fastapi - pytest + - sqlalchemy - repo: https://github.com/astral-sh/ruff-pre-commit rev: 'v0.15.18' diff --git a/src/core/formatting.py b/src/core/formatting.py index faf6d42..a4b09e8 100644 --- a/src/core/formatting.py +++ b/src/core/formatting.py @@ -1,12 +1,9 @@ import html -from typing import TYPE_CHECKING from config import get_config +from database.schema.base import UntypedRow from schemas.datasets.openml import DatasetFileFormat -if TYPE_CHECKING: - from sqlalchemy.engine import Row - def _str_to_bool(string: str) -> bool: if string.casefold() in ["true", "1", "yes", "y"]: @@ -17,7 +14,7 @@ def _str_to_bool(string: str) -> bool: raise ValueError(msg) -def _format_parquet_url(dataset: Row) -> str | None: +def _format_parquet_url(dataset: UntypedRow) -> str | None: if dataset.format.lower() != DatasetFileFormat.ARFF: return None @@ -27,7 +24,7 @@ def _format_parquet_url(dataset: Row) -> str | None: return f"{minio_base_url}datasets/{ten_thousands_prefix}/{padded_id}/dataset_{dataset.did}.pq" -def _format_dataset_url(dataset: Row) -> str: +def _format_dataset_url(dataset: UntypedRow) -> str: base_url = get_config().routing.server_url filename = f"{html.escape(dataset.name)}.{dataset.format.lower()}" return f"{base_url}data/v1/download/{dataset.file_id}/{filename}" diff --git a/src/database/datasets.py b/src/database/datasets.py index 5f9c0e5..5cb99d0 100644 --- a/src/database/datasets.py +++ b/src/database/datasets.py @@ -13,15 +13,15 @@ DuplicatePrimaryKeyError, ForeignKeyConstraintError, ) +from database.schema.base import UntypedRow from routers.types import Identifier, TagString from schemas.datasets.openml import DatasetStatus, Feature if TYPE_CHECKING: - from sqlalchemy.engine import Row from sqlalchemy.ext.asyncio import AsyncConnection -async def get(id_: Identifier, connection: AsyncConnection) -> Row | None: +async def get(id_: Identifier, connection: AsyncConnection) -> UntypedRow | None: row = await connection.execute( text( """ @@ -35,7 +35,7 @@ async def get(id_: Identifier, connection: AsyncConnection) -> Row | None: return row.one_or_none() -async def get_file(*, file_id: Identifier, connection: AsyncConnection) -> Row | None: +async def get_file(*, file_id: Identifier, connection: AsyncConnection) -> UntypedRow | None: row = await connection.execute( text( """ @@ -53,7 +53,7 @@ async def get_tag( dataset_id: Identifier, tag: TagString, connection: AsyncConnection, -) -> Row | None: +) -> UntypedRow | None: return ( await connection.execute( text( @@ -111,6 +111,8 @@ async def tag(id_: int, tag_: str, *, user_id: int, connection: AsyncConnection) }, ) except IntegrityError as e: + if e.orig is None: + raise code, msg = e.orig.args if code == _FOREIGN_KEY_CONSTRAINT_FAILED: raise ForeignKeyConstraintError(msg) from e @@ -122,7 +124,7 @@ async def tag(id_: int, tag_: str, *, user_id: int, connection: AsyncConnection) async def get_description( id_: Identifier, connection: AsyncConnection, -) -> Row | None: +) -> UntypedRow | None: """Get the most recent description for the dataset.""" row = await connection.execute( text( @@ -160,7 +162,7 @@ async def get_status(id_: Identifier, connection: AsyncConnection) -> DatasetSta async def get_latest_processing_update( dataset_id: Identifier, connection: AsyncConnection, -) -> Row | None: +) -> UntypedRow | None: row = await connection.execute( text( """ diff --git a/src/database/evaluations.py b/src/database/evaluations.py index 382653f..14bec4d 100644 --- a/src/database/evaluations.py +++ b/src/database/evaluations.py @@ -1,16 +1,20 @@ from collections.abc import Sequence -from typing import TYPE_CHECKING, cast +from typing import TYPE_CHECKING -from sqlalchemy import Row, text +from sqlalchemy import text from core.formatting import _str_to_bool +from database.schema.base import UntypedRow from schemas.datasets.openml import EstimationProcedure if TYPE_CHECKING: from sqlalchemy.ext.asyncio import AsyncConnection -async def get_math_functions(function_type: str, connection: AsyncConnection) -> Sequence[Row]: +async def get_math_functions( + function_type: str, + connection: AsyncConnection, +) -> Sequence[UntypedRow]: rows = await connection.execute( text( """ @@ -21,10 +25,7 @@ async def get_math_functions(function_type: str, connection: AsyncConnection) -> ), parameters={"function_type": function_type}, ) - return cast( - "Sequence[Row]", - rows.all(), - ) + return rows.all() async def get_estimation_procedures(connection: AsyncConnection) -> list[EstimationProcedure]: diff --git a/src/database/flows.py b/src/database/flows.py index 7b1da5e..830a386 100644 --- a/src/database/flows.py +++ b/src/database/flows.py @@ -1,15 +1,16 @@ from collections.abc import Sequence -from typing import TYPE_CHECKING, cast +from typing import TYPE_CHECKING -from sqlalchemy import Row, text +from sqlalchemy import text +from database.schema.base import UntypedRow from routers.types import Identifier if TYPE_CHECKING: from sqlalchemy.ext.asyncio import AsyncConnection -async def get_subflows(for_flow: Identifier, expdb: AsyncConnection) -> Sequence[Row]: +async def get_subflows(for_flow: Identifier, expdb: AsyncConnection) -> Sequence[UntypedRow]: rows = await expdb.execute( text( """ @@ -20,10 +21,7 @@ async def get_subflows(for_flow: Identifier, expdb: AsyncConnection) -> Sequence ), parameters={"flow_id": for_flow}, ) - return cast( - "Sequence[Row]", - rows.all(), - ) + return rows.all() async def get_tags(flow_id: Identifier, expdb: AsyncConnection) -> list[str]: @@ -41,7 +39,7 @@ async def get_tags(flow_id: Identifier, expdb: AsyncConnection) -> list[str]: return [tag.tag for tag in tag_rows] -async def get_parameters(flow_id: Identifier, expdb: AsyncConnection) -> Sequence[Row]: +async def get_parameters(flow_id: Identifier, expdb: AsyncConnection) -> Sequence[UntypedRow]: rows = await expdb.execute( text( """ @@ -52,13 +50,14 @@ async def get_parameters(flow_id: Identifier, expdb: AsyncConnection) -> Sequenc ), parameters={"flow_id": flow_id}, ) - return cast( - "Sequence[Row]", - rows.all(), - ) + return rows.all() -async def get_by_name(name: str, external_version: str, expdb: AsyncConnection) -> Row | None: +async def get_by_name( + name: str, + external_version: str, + expdb: AsyncConnection, +) -> UntypedRow | None: """Get flow by name and external version.""" row = await expdb.execute( text( @@ -73,7 +72,7 @@ async def get_by_name(name: str, external_version: str, expdb: AsyncConnection) return row.one_or_none() -async def get(id_: Identifier, expdb: AsyncConnection) -> Row | None: +async def get(id_: Identifier, expdb: AsyncConnection) -> UntypedRow | None: row = await expdb.execute( text( """ diff --git a/src/database/runs.py b/src/database/runs.py index f0bc419..0c5c93c 100644 --- a/src/database/runs.py +++ b/src/database/runs.py @@ -3,8 +3,9 @@ from collections.abc import Sequence from typing import TYPE_CHECKING, cast -from sqlalchemy import Row, bindparam, text +from sqlalchemy import bindparam, text +from database.schema.base import UntypedRow from routers.types import Identifier if TYPE_CHECKING: @@ -26,7 +27,7 @@ async def exist(id_: Identifier, expdb: AsyncConnection) -> bool: return bool(row.one_or_none()) -async def get(run_id: Identifier, expdb: AsyncConnection) -> Row | None: +async def get(run_id: Identifier, expdb: AsyncConnection) -> UntypedRow | None: """Fetch the core run row from the `run` table. Returns the row if found, or None if no run with `run_id` exists. @@ -63,7 +64,7 @@ async def get_tags(run_id: int, expdb: AsyncConnection) -> list[str]: return [row.tag for row in rows.all()] -async def get_input_data(run_id: int, expdb: AsyncConnection) -> list[Row]: +async def get_input_data(run_id: int, expdb: AsyncConnection) -> list[UntypedRow]: """Fetch the dataset(s) used as input for a run, with name and url. Joins `input_data` with `dataset` to include the dataset name and ARFF URL. @@ -79,10 +80,10 @@ async def get_input_data(run_id: int, expdb: AsyncConnection) -> list[Row]: ), parameters={"run_id": run_id}, ) - return cast("list[Row]", rows.all()) + return cast("list[UntypedRow]", rows.all()) -async def get_output_files(run_id: int, expdb: AsyncConnection) -> list[Row]: +async def get_output_files(run_id: int, expdb: AsyncConnection) -> list[UntypedRow]: """Fetch output files attached to a run from the `runfile` table. Typical entries include the description XML and predictions ARFF. @@ -98,7 +99,7 @@ async def get_output_files(run_id: int, expdb: AsyncConnection) -> list[Row]: ), parameters={"run_id": run_id}, ) - return cast("list[Row]", rows.all()) + return cast("list[UntypedRow]", rows.all()) async def get_evaluations( @@ -106,7 +107,7 @@ async def get_evaluations( expdb: AsyncConnection, *, evaluation_engine_ids: list[int], -) -> list[Row]: +) -> list[UntypedRow]: """Fetch evaluation metric results for a run. Joins `evaluation` with `math_function` to resolve the metric name @@ -138,10 +139,10 @@ async def get_evaluations( query, parameters={"run_id": run_id, "engine_ids": evaluation_engine_ids}, ) - return cast("list[Row]", rows.all()) + return cast("list[UntypedRow]", rows.all()) -async def get_trace(run_id: int, expdb: AsyncConnection) -> Sequence[Row]: +async def get_trace(run_id: int, expdb: AsyncConnection) -> Sequence[UntypedRow]: """Get trace rows for a run from the trace table.""" rows = await expdb.execute( text( @@ -153,7 +154,4 @@ async def get_trace(run_id: int, expdb: AsyncConnection) -> Sequence[Row]: ), parameters={"run_id": run_id}, ) - return cast( - "Sequence[Row]", - rows.all(), - ) + return rows.all() diff --git a/src/database/schema/__init__.py b/src/database/schema/__init__.py new file mode 100644 index 0000000..6e6260c --- /dev/null +++ b/src/database/schema/__init__.py @@ -0,0 +1 @@ +"""Defines Object-Relational Mappings (ORM).""" diff --git a/src/database/schema/base.py b/src/database/schema/base.py new file mode 100644 index 0000000..0a75055 --- /dev/null +++ b/src/database/schema/base.py @@ -0,0 +1,52 @@ +"""Base classes for all ORM classes. + +When defining a new ORM class, use both `Base` and one of the `DeferredReflection` subclasses to +make sure that the class is populated with attributes that may not be defined explicitly. +For example, when creating a new mapping for a table from the `openml_expdb` database, use: + +class ClassName(ExpDBReflected, Base): + __tablename__ = "class_names" + + # any columns you wanted mapped explicitly + ... + +""" + +from typing import Any + +from sqlalchemy import Row +from sqlalchemy.ext.declarative import DeferredReflection +from sqlalchemy.orm import DeclarativeBase + +from database.setup import expdb_database, user_database + +UntypedRow = Row[Any] + + +class Base(DeclarativeBase): + """Base class for all ORM classes.""" + + +class ExpDBReflected(DeferredReflection): + """Base class for ORM classes to map onto a table in the `openml_expdb` database.""" + + __abstract__ = True + + +class UserDBReflected(DeferredReflection): + """Base class for ORM classes to map onto a table in the `openml` database.""" + + __abstract__ = True + + +async def reflect_db_schemas() -> None: + """Populate defined ORM classes with attributes defined from columns in the database. + + For example, the `dataset` class would automatically get a `collection_date` attribute, + even if it wasn't explicitly declared in the class definition, + because the `openml_expdb.dataset` table has a column `collection_date`. + """ + async with user_database().connect() as connection: + await connection.run_sync(UserDBReflected.prepare) # type: ignore[arg-type] # run_sync expects positional-only arg but `prepare` does not have it. + async with expdb_database().connect() as connection: + await connection.run_sync(ExpDBReflected.prepare) # type: ignore[arg-type] # as above. diff --git a/src/database/schema/tags.py b/src/database/schema/tags.py new file mode 100644 index 0000000..a6b6a78 --- /dev/null +++ b/src/database/schema/tags.py @@ -0,0 +1,30 @@ +"""ORM classes for the *_tag tables (task_tag, ...).""" + +from datetime import datetime + +from sqlalchemy import FetchedValue +from sqlalchemy.orm import Mapped, mapped_column + +from database.schema.base import Base, ExpDBReflected +from routers.types import Identifier, TagString + + +class Tag: + """Base class for all of the *_tag tables.""" + + # The identifier of the entity that is tagged (e.g., dataset id, task id) + entity_id: Mapped[Identifier] = mapped_column("id", primary_key=True) + tag: Mapped[TagString] = mapped_column(primary_key=True) + uploader_id: Mapped[Identifier] = mapped_column("uploader") + creation_date: Mapped[datetime] = mapped_column("date", server_default=FetchedValue()) + + +class TaskTag(ExpDBReflected, Tag, Base): + """Tags belonging to a task.""" + + __tablename__ = "task_tag" + + @property + def task_id(self) -> Identifier: + """Identifier of the task which is tagged by this tag.""" + return self.entity_id diff --git a/src/database/setups.py b/src/database/setups.py index ce0f0cc..c78986c 100644 --- a/src/database/setups.py +++ b/src/database/setups.py @@ -11,14 +11,15 @@ DuplicatePrimaryKeyError, ForeignKeyConstraintError, ) +from database.schema.base import UntypedRow from routers.types import Identifier, TagString if TYPE_CHECKING: - from sqlalchemy.engine import Row, RowMapping + from sqlalchemy.engine import RowMapping from sqlalchemy.ext.asyncio import AsyncConnection -async def get(setup_id: Identifier, connection: AsyncConnection) -> Row | None: +async def get(setup_id: Identifier, connection: AsyncConnection) -> UntypedRow | None: """Get the setup with id `setup_id` from the database.""" row = await connection.execute( text( @@ -60,7 +61,7 @@ async def get_parameters(setup_id: Identifier, connection: AsyncConnection) -> l return list(rows.mappings().all()) -async def get_tags(setup_id: Identifier, connection: AsyncConnection) -> list[Row]: +async def get_tags(setup_id: Identifier, connection: AsyncConnection) -> list[UntypedRow]: """Get all tags for setup with `setup_id` from the database.""" rows = await connection.execute( text( @@ -106,6 +107,8 @@ async def tag( parameters={"setup_id": setup_id, "tag": tag, "user_id": user_id}, ) except IntegrityError as e: + if e.orig is None: + raise code, msg = e.orig.args if code == _FOREIGN_KEY_CONSTRAINT_FAILED: raise ForeignKeyConstraintError(msg) from e diff --git a/src/database/studies.py b/src/database/studies.py index 1939452..b4a6936 100644 --- a/src/database/studies.py +++ b/src/database/studies.py @@ -3,8 +3,9 @@ from datetime import UTC, datetime from typing import TYPE_CHECKING, cast -from sqlalchemy import Row, text +from sqlalchemy import text +from database.schema.base import UntypedRow from database.users import User from routers.types import Identifier from schemas.study import CreateStudy, StudyType @@ -13,7 +14,7 @@ from sqlalchemy.ext.asyncio import AsyncConnection -async def get_by_id(id_: Identifier, connection: AsyncConnection) -> Row | None: +async def get_by_id(id_: Identifier, connection: AsyncConnection) -> UntypedRow | None: row = await connection.execute( text( """ @@ -27,7 +28,7 @@ async def get_by_id(id_: Identifier, connection: AsyncConnection) -> Row | None: return row.one_or_none() -async def get_by_alias(alias: str, connection: AsyncConnection) -> Row | None: +async def get_by_alias(alias: str, connection: AsyncConnection) -> UntypedRow | None: row = await connection.execute( text( """ @@ -41,7 +42,7 @@ async def get_by_alias(alias: str, connection: AsyncConnection) -> Row | None: return row.one_or_none() -async def get_study_data(study: Row, expdb: AsyncConnection) -> Sequence[Row]: +async def get_study_data(study: UntypedRow, expdb: AsyncConnection) -> Sequence[UntypedRow]: """Return data related to the study, content depends on the study type. For task studies: (task id, dataset id) @@ -58,10 +59,8 @@ async def get_study_data(study: Row, expdb: AsyncConnection) -> Sequence[Row]: ), parameters={"study_id": study.id}, ) - return cast( - "Sequence[Row]", - rows.all(), - ) + return rows.all() + rows = await expdb.execute( text( """ @@ -80,10 +79,7 @@ async def get_study_data(study: Row, expdb: AsyncConnection) -> Sequence[Row]: ), parameters={"study_id": study.id}, ) - return cast( - "Sequence[Row]", - rows.all(), - ) + return rows.all() async def create(study: CreateStudy, user: User, expdb: AsyncConnection) -> int: diff --git a/src/database/tasks.py b/src/database/tasks.py index b0f6010..9d1d120 100644 --- a/src/database/tasks.py +++ b/src/database/tasks.py @@ -1,7 +1,7 @@ from collections.abc import Sequence -from typing import TYPE_CHECKING, cast +from typing import TYPE_CHECKING -from sqlalchemy import Row, text +from sqlalchemy import select, text from sqlalchemy.exc import IntegrityError from database.exceptions import ( @@ -10,13 +10,15 @@ DuplicatePrimaryKeyError, ForeignKeyConstraintError, ) +from database.schema.base import UntypedRow +from database.schema.tags import TaskTag from routers.types import Identifier, TagString if TYPE_CHECKING: - from sqlalchemy.ext.asyncio import AsyncConnection + from sqlalchemy.ext.asyncio import AsyncConnection, AsyncSession -async def get(id_: Identifier, expdb: AsyncConnection) -> Row | None: +async def get(id_: Identifier, expdb: AsyncConnection) -> UntypedRow | None: row = await expdb.execute( text( """ @@ -30,7 +32,7 @@ async def get(id_: Identifier, expdb: AsyncConnection) -> Row | None: return row.one_or_none() -async def get_task_types(expdb: AsyncConnection) -> Sequence[Row]: +async def get_task_types(expdb: AsyncConnection) -> Sequence[UntypedRow]: rows = await expdb.execute( text( """ @@ -39,13 +41,10 @@ async def get_task_types(expdb: AsyncConnection) -> Sequence[Row]: """, ), ) - return cast( - "Sequence[Row]", - rows.all(), - ) + return rows.all() -async def get_task_type(task_type_id: Identifier, expdb: AsyncConnection) -> Row | None: +async def get_task_type(task_type_id: Identifier, expdb: AsyncConnection) -> UntypedRow | None: row = await expdb.execute( text( """ @@ -102,7 +101,10 @@ async def get_task_evaluation_measure(task_id: int, expdb: AsyncConnection) -> s return result.value if result else None -async def get_input_for_task_type(task_type_id: int, expdb: AsyncConnection) -> Sequence[Row]: +async def get_input_for_task_type( + task_type_id: int, + expdb: AsyncConnection, +) -> Sequence[UntypedRow]: rows = await expdb.execute( text( """ @@ -113,13 +115,10 @@ async def get_input_for_task_type(task_type_id: int, expdb: AsyncConnection) -> ), parameters={"ttid": task_type_id}, ) - return cast( - "Sequence[Row]", - rows.all(), - ) + return rows.all() -async def get_input_for_task(id_: Identifier, expdb: AsyncConnection) -> Sequence[Row]: +async def get_input_for_task(id_: Identifier, expdb: AsyncConnection) -> Sequence[UntypedRow]: rows = await expdb.execute( text( """ @@ -130,16 +129,13 @@ async def get_input_for_task(id_: Identifier, expdb: AsyncConnection) -> Sequenc ), parameters={"task_id": id_}, ) - return cast( - "Sequence[Row]", - rows.all(), - ) + return rows.all() async def get_task_type_inout_with_template( task_type: Identifier, expdb: AsyncConnection, -) -> Sequence[Row]: +) -> Sequence[UntypedRow]: rows = await expdb.execute( text( """ @@ -150,25 +146,12 @@ async def get_task_type_inout_with_template( ), parameters={"ttid": task_type}, ) - return cast( - "Sequence[Row]", - rows.all(), - ) + return rows.all() -async def get_tags(id_: Identifier, connection: AsyncConnection) -> list[str]: - rows = await connection.execute( - text( - """ - SELECT `tag` - FROM task_tag - WHERE `id` = :task_id - """, - ), - parameters={"task_id": id_}, - ) - tag_rows = rows.all() - return [row.tag for row in tag_rows] +async def get_tags(task_id: Identifier, session: AsyncSession) -> Sequence[TaskTag]: + stmt = select(TaskTag).where(TaskTag.entity_id == task_id) + return (await session.scalars(stmt)).all() async def tag( @@ -176,23 +159,15 @@ async def tag( tag_: TagString, *, user_id: Identifier, - connection: AsyncConnection, + session: AsyncSession, ) -> None: try: - await connection.execute( - text( - """ - INSERT INTO task_tag(`id`, `tag`, `uploader`) - VALUES (:task_id, :tag, :user_id) - """, - ), - parameters={ - "task_id": id_, - "user_id": user_id, - "tag": tag_, - }, - ) + tag = TaskTag(entity_id=id_, uploader_id=user_id, tag=tag_) + session.add(tag) + await session.flush() except IntegrityError as e: + if e.orig is None: + raise code, msg = e.orig.args if code == _FOREIGN_KEY_CONSTRAINT_FAILED: raise ForeignKeyConstraintError(msg) from e diff --git a/src/database/users.py b/src/database/users.py index 381f01d..379c6f8 100644 --- a/src/database/users.py +++ b/src/database/users.py @@ -106,9 +106,9 @@ async def get_user_groups_for( @dataclasses.dataclass class User: user_id: Identifier - _database: AsyncConnection first_name: str = "" last_name: str = "" + _database: AsyncConnection | None = None _groups: list[UserGroup] | None = None def __post_init__(self) -> None: @@ -136,11 +136,17 @@ async def fetch(cls, api_key: APIKey, user_db: AsyncConnection) -> Self | None: return None async def get_groups(self) -> list[UserGroup]: - if self._groups is None: - self._groups = await get_user_groups_for( - user_id=self.user_id, - connection=self._database, - ) + if self._groups: + return self._groups + + if self._database is None: + msg = "`get_groups` can only be used when `connection` is provided on instantiation." + raise RuntimeError(msg) + + self._groups = await get_user_groups_for( + user_id=self.user_id, + connection=self._database, + ) return self._groups async def is_admin(self) -> bool: diff --git a/src/main.py b/src/main.py index b4f7dc8..6207b15 100644 --- a/src/main.py +++ b/src/main.py @@ -26,6 +26,7 @@ request_response_logger, setup_log_sinks, ) +from database.schema.base import reflect_db_schemas from database.setup import close_databases from routers.openml.datasets import router as datasets_router from routers.openml.estimation_procedure import router as estimationprocedure_router @@ -45,6 +46,8 @@ async def lifespan( app: FastAPI | None, # noqa: ARG001 # parameter required by FastAPI/Starlette ) -> AsyncIterator[None]: """Manage application lifespan - startup and shutdown events.""" + logger.info("Reflecting database schemas") + await reflect_db_schemas() yield await asyncio.gather( logger.complete(), diff --git a/src/routers/dependencies.py b/src/routers/dependencies.py index ca4a965..2ee8548 100644 --- a/src/routers/dependencies.py +++ b/src/routers/dependencies.py @@ -4,6 +4,7 @@ from fastapi import Depends from loguru import logger from pydantic import BaseModel, Field +from sqlalchemy.ext.asyncio import AsyncSession from core.errors import AuthenticationFailedError, AuthenticationRequiredError from database.setup import expdb_database, user_database @@ -25,6 +26,13 @@ async def userdb_connection() -> AsyncIterator[AsyncConnection]: yield connection +async def expdb_session( + connection: Annotated[AsyncConnection, Depends(expdb_connection)], +) -> AsyncIterator[AsyncSession]: + async with AsyncSession(connection) as session, session.begin(): + yield session + + async def fetch_user( api_key: APIKey | None = None, user_data: Annotated[AsyncConnection | None, Depends(userdb_connection)] = None, diff --git a/src/routers/openml/datasets.py b/src/routers/openml/datasets.py index 270c71b..72c94de 100644 --- a/src/routers/openml/datasets.py +++ b/src/routers/openml/datasets.py @@ -156,6 +156,7 @@ def _quality_clause(quality: str, range_: str | None) -> str: @router.post(path="/list", description="Provided for convenience, same as `GET` endpoint.") @router.get(path="/list") async def list_datasets( # noqa: PLR0913, C901 + expdb_db: Annotated[AsyncConnection, Depends(expdb_connection)], pagination: Annotated[Pagination, Body(default_factory=Pagination)], data_name: Annotated[CasualString128 | None, Body()] = None, tag: Annotated[TagString | None, Body()] = None, @@ -180,9 +181,7 @@ async def list_datasets( # noqa: PLR0913, C901 number_missing_values: Annotated[IntegerRange | None, Body()] = None, status: Annotated[DatasetStatusFilter, Body()] = DatasetStatusFilter.ACTIVE, user: Annotated[User | None, Depends(fetch_user)] = None, - expdb_db: Annotated[AsyncConnection, Depends(expdb_connection)] = None, ) -> list[dict[str, Any]]: - assert expdb_db is not None # noqa: S101 status_subquery = text( """ SELECT ds1.`did`, ds1.`status` @@ -356,8 +355,8 @@ async def _get_dataset_raise_otherwise( @router.get("/features/{dataset_id}", response_model_exclude_none=True) async def get_dataset_features( dataset_id: Identifier, + expdb: Annotated[AsyncConnection, Depends(expdb_connection)], user: Annotated[User | None, Depends(fetch_user)] = None, - expdb: Annotated[AsyncConnection, Depends(expdb_connection)] = None, ) -> list[Feature]: assert expdb is not None # noqa: S101 await _get_dataset_raise_otherwise(dataset_id, user, expdb) @@ -453,9 +452,9 @@ async def update_dataset_status( ) async def get_dataset( dataset_id: Identifier, + user_db: Annotated[AsyncConnection, Depends(userdb_connection)], + expdb_db: Annotated[AsyncConnection, Depends(expdb_connection)], user: Annotated[User | None, Depends(fetch_user)] = None, - user_db: Annotated[AsyncConnection, Depends(userdb_connection)] = None, - expdb_db: Annotated[AsyncConnection, Depends(expdb_connection)] = None, ) -> DatasetMetadata: assert user_db is not None # noqa: S101 assert expdb_db is not None # noqa: S101 diff --git a/src/routers/openml/runs.py b/src/routers/openml/runs.py index 0abf40e..88296cf 100644 --- a/src/routers/openml/runs.py +++ b/src/routers/openml/runs.py @@ -6,9 +6,6 @@ from fastapi import APIRouter, Depends -if TYPE_CHECKING: - from sqlalchemy import Row - import config import database.flows import database.runs @@ -16,6 +13,7 @@ import database.tasks import database.users from core.errors import RunNotFoundError, RunTraceNotFoundError +from database.schema.base import UntypedRow from routers.dependencies import expdb_connection, userdb_connection from routers.types import Identifier from schemas.runs import ( @@ -72,17 +70,17 @@ class RunContext: uploader_name: str | None tags: list[str] - input_data_rows: list[Row] - output_file_rows: list[Row] - evaluation_rows: list[Row] + input_data_rows: list[UntypedRow] + output_file_rows: list[UntypedRow] + evaluation_rows: list[UntypedRow] task_type: str | None task_evaluation_measure: str | None - setup: Row | None - parameter_rows: list[Row] + setup: UntypedRow | None + parameter_rows: list[UntypedRow] async def _load_run_context( - run: Row, + run: UntypedRow, run_id: int, expdb: AsyncConnection, userdb: AsyncConnection, @@ -99,8 +97,7 @@ async def _load_run_context( setup, parameter_rows, ) = cast( - "tuple[Any, list[str], list[Row], list[Row], list[Row], str | None, str |" - "None, Row | None, list[Row]]", + "tuple[Any, list[str], list[UntypedRow], list[UntypedRow], list[UntypedRow], str | None, str | None, UntypedRow | None, list[UntypedRow]]", # noqa: E501 await asyncio.gather( database.users.get_user(user_id=run.uploader, connection=userdb), database.runs.get_tags(run_id, expdb), @@ -126,7 +123,7 @@ async def _load_run_context( ) -def _build_evaluations(rows: list[Row]) -> list[EvaluationScore]: +def _build_evaluations(rows: list[UntypedRow]) -> list[EvaluationScore]: def _normalise_value(v: object) -> object: if isinstance(v, (int, float)): return int(v) if float(v).is_integer() else float(v) @@ -181,7 +178,7 @@ async def get_run( setup_id=run.setup, setup_string=ctx.setup.setup_string if ctx.setup else None, parameter_setting=[ - ParameterSetting(name=p["name"], value=p["value"], component=p["flow_id"]) + ParameterSetting(name=p.name, value=p.value, component=p.flow_id) for p in ctx.parameter_rows ], error_message=error_messages, diff --git a/src/routers/openml/study.py b/src/routers/openml/study.py index ccdf9ff..dbb6581 100644 --- a/src/routers/openml/study.py +++ b/src/routers/openml/study.py @@ -16,6 +16,7 @@ StudyPrivateError, ) from core.formatting import _str_to_bool +from database.schema.base import UntypedRow from database.users import User from routers.dependencies import expdb_connection, fetch_user, fetch_user_or_raise from routers.types import Identifier @@ -23,7 +24,6 @@ from schemas.study import CreateStudy, Study, StudyStatus, StudyType if TYPE_CHECKING: - from sqlalchemy.engine import Row from sqlalchemy.ext.asyncio import AsyncConnection router = APIRouter(prefix="/studies", tags=["studies"]) @@ -33,7 +33,7 @@ async def _get_study_raise_otherwise( id_or_alias: Identifier | str, user: User | None, expdb: AsyncConnection, -) -> Row: +) -> UntypedRow: search_by_id = isinstance(id_or_alias, int) or id_or_alias.isdigit() if search_by_id: study = await database.studies.get_by_id(int(id_or_alias), expdb) @@ -67,7 +67,7 @@ async def attach_to_study( study_id: Annotated[Identifier, Body()], entity_ids: Annotated[list[Identifier], Body()], user: Annotated[User, Depends(fetch_user_or_raise)], - expdb: Annotated[AsyncConnection, Depends(expdb_connection)] = None, + expdb: Annotated[AsyncConnection, Depends(expdb_connection)], ) -> AttachDetachResponse: assert expdb is not None # noqa: S101 if user is None: @@ -90,16 +90,21 @@ async def attach_to_study( # We let the database handle the constraints on whether # the entity is already attached or if it even exists. - attach_kwargs = { - "study_id": study_id, - "user": user, - "connection": expdb, - } try: if study.type_ == StudyType.TASK: - await database.studies.attach_tasks(task_ids=entity_ids, **attach_kwargs) + await database.studies.attach_tasks( + task_ids=entity_ids, + study_id=study_id, + user=user, + connection=expdb, + ) else: - await database.studies.attach_runs(run_ids=entity_ids, **attach_kwargs) + await database.studies.attach_runs( + run_ids=entity_ids, + study_id=study_id, + user=user, + connection=expdb, + ) except ValueError as e: msg = str(e) raise StudyConflictError(msg) from e @@ -116,7 +121,7 @@ async def attach_to_study( async def create_study( study: CreateStudy, user: Annotated[User, Depends(fetch_user_or_raise)], - expdb: Annotated[AsyncConnection, Depends(expdb_connection)] = None, + expdb: Annotated[AsyncConnection, Depends(expdb_connection)], ) -> dict[Literal["study_id"], int]: assert expdb is not None # noqa: S101 if study.main_entity_type == StudyType.RUN and study.tasks: @@ -152,8 +157,8 @@ async def create_study( @router.get("/{alias_or_id}") async def get_study( alias_or_id: Identifier | str, + expdb: Annotated[AsyncConnection, Depends(expdb_connection)], user: Annotated[User | None, Depends(fetch_user)] = None, - expdb: Annotated[AsyncConnection, Depends(expdb_connection)] = None, ) -> Study: assert expdb is not None # noqa: S101 study = await _get_study_raise_otherwise(alias_or_id, user, expdb) diff --git a/src/routers/openml/tasks.py b/src/routers/openml/tasks.py index f68c4b8..55df619 100644 --- a/src/routers/openml/tasks.py +++ b/src/routers/openml/tasks.py @@ -14,8 +14,9 @@ from config import get_config from core.errors import InternalError, NoResultsError, TagAlreadyExistsError, TaskNotFoundError from database.exceptions import DuplicatePrimaryKeyError, ForeignKeyConstraintError +from database.schema.base import UntypedRow from database.users import User -from routers.dependencies import Pagination, expdb_connection, fetch_user_or_raise +from routers.dependencies import Pagination, expdb_connection, expdb_session, fetch_user_or_raise from routers.types import ( CasualString128, Identifier, @@ -26,8 +27,7 @@ from schemas.datasets.openml import Task if TYPE_CHECKING: - from sqlalchemy.engine import RowMapping - from sqlalchemy.ext.asyncio import AsyncConnection + from sqlalchemy.ext.asyncio import AsyncConnection, AsyncSession router = APIRouter(prefix="/tasks", tags=["tasks"]) @@ -39,10 +39,10 @@ async def tag_task( task_id: Annotated[Identifier, Body()], tag: Annotated[TagString, Body()], user: Annotated[User, Depends(fetch_user_or_raise)], - expdb_db: Annotated[AsyncConnection, Depends(expdb_connection)], + expdb_session: Annotated[AsyncSession, Depends(expdb_session)], ) -> dict[str, dict[str, Any]]: try: - await database.tasks.tag(task_id, tag, user_id=user.user_id, connection=expdb_db) + await database.tasks.tag(task_id, tag, user_id=user.user_id, session=expdb_session) except ForeignKeyConstraintError: msg = f"Task {task_id} not found." raise TaskNotFoundError(msg, code=472) from None @@ -52,10 +52,10 @@ async def tag_task( logger.info("Task {task_id} tagged '{tag}'.", task_id=task_id, tag=tag) - tags = await database.tasks.get_tags(task_id, expdb_db) + tags = await database.tasks.get_tags(task_id, expdb_session) return { - "task_tag": {"id": str(task_id), "tag": tags}, + "task_tag": {"id": str(task_id), "tag": [t.tag for t in tags]}, } @@ -70,7 +70,7 @@ def convert_template_xml_to_json(xml_template: str) -> dict[str, JSON]: async def fill_template( template: str, - task: RowMapping, + task: UntypedRow, task_inputs: dict[str, str | int], connection: AsyncConnection, ) -> dict[str, JSON]: @@ -137,7 +137,7 @@ async def fill_template( async def _fill_json_template( # noqa: C901 template: JSON, - task: RowMapping, + task: UntypedRow, task_inputs: dict[str, str | int], fetched_data: dict[str, str], connection: AsyncConnection, @@ -257,6 +257,7 @@ def _quality_clause(quality: str, range_: str | None) -> str: @router.post(path="/list", description="Provided for convenience, same as `GET` endpoint.") @router.get(path="/list") async def list_tasks( # noqa: PLR0913, PLR0912, C901, PLR0915 + expdb: Annotated[AsyncConnection, Depends(expdb_connection)], pagination: Annotated[Pagination, Body(default_factory=Pagination)], task_type_id: Annotated[Identifier | None, Body(description="Filter by task type id.")] = None, tag: Annotated[TagString | None, Body()] = None, @@ -275,7 +276,6 @@ async def list_tasks( # noqa: PLR0913, PLR0912, C901, PLR0915 number_features: Annotated[IntegerRange | None, Body()] = None, number_classes: Annotated[IntegerRange | None, Body()] = None, number_missing_values: Annotated[IntegerRange | None, Body()] = None, - expdb: Annotated[AsyncConnection, Depends(expdb_connection)] = None, ) -> list[dict[str, Any]]: """List tasks, optionally filtered by type, tag, status, dataset properties, and more.""" assert expdb is not None # noqa: S101 @@ -451,6 +451,7 @@ async def list_tasks( # noqa: PLR0913, PLR0912, C901, PLR0915 async def get_task( task_id: int, expdb: Annotated[AsyncConnection, Depends(expdb_connection)], + expdb_session: Annotated[AsyncSession, Depends(expdb_session)], ) -> Task: if not (task := await database.tasks.get(task_id, expdb)): msg = f"Task {task_id} not found." @@ -462,7 +463,7 @@ async def get_task( task_input_rows, ttios, tags = await asyncio.gather( database.tasks.get_input_for_task(task_id, expdb), database.tasks.get_task_type_inout_with_template(task_type.ttid, expdb), - database.tasks.get_tags(task_id, expdb), + database.tasks.get_tags(task_id, expdb_session), ) task_inputs = { row.input: int(row.value) if row.value.isdigit() else row.value for row in task_input_rows @@ -495,5 +496,5 @@ async def get_task( task_type=task_type.name, input_=inputs, output=outputs, - tags=tags, + tags=[t.tag for t in tags], ) diff --git a/tests/conftest.py b/tests/conftest.py index 4191c47..68a8e1c 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -12,6 +12,7 @@ from _pytest.nodes import Item # noqa: TC002 used during collection by Pytest from asgi_lifespan import LifespanManager from sqlalchemy import text +from sqlalchemy.ext.asyncio import AsyncSession from config import ( Configuration, @@ -71,6 +72,16 @@ async def expdb_test() -> AsyncIterator[AsyncConnection]: yield connection +@pytest.fixture +async def expdb_session(expdb_test: AsyncConnection) -> AsyncIterator[AsyncSession]: + # It is possible that this session `commits`, does the connection + # rollback then still take effect? Probably not. + async with AsyncSession(expdb_test) as session: + yield session + # Do we here again need to do some check on whether there is an active transation? + await session.rollback() + + @pytest.fixture async def user_test() -> AsyncIterator[AsyncConnection]: async with automatic_rollback(user_database()) as connection: @@ -85,7 +96,7 @@ async def php_api() -> AsyncIterator[httpx.AsyncClient]: yield client -@pytest.fixture(scope="session") +@pytest.fixture(scope="session", autouse=True) async def app() -> AsyncIterator[FastAPI]: config = Configuration( openml_database=DatabaseConfiguration(database="openml"), @@ -103,7 +114,9 @@ async def app() -> AsyncIterator[FastAPI]: @pytest.fixture async def py_api( - expdb_test: AsyncConnection, user_test: AsyncConnection, app: FastAPI + expdb_test: AsyncConnection, + user_test: AsyncConnection, + app: FastAPI, ) -> AsyncIterator[httpx.AsyncClient]: """Create test client which automatically rolls back database updates on teardown.""" # Using the function-scoped database fixtures automatically benefits the diff --git a/tests/routers/openml/runs_get_test.py b/tests/routers/openml/runs_get_test.py index d566403..af841f6 100644 --- a/tests/routers/openml/runs_get_test.py +++ b/tests/routers/openml/runs_get_test.py @@ -299,7 +299,7 @@ def __init__( self.fold = fold rows = [MockRow("test_metric", input_value, repeat=repeat, fold=fold)] - evals = _build_evaluations(rows) + evals = _build_evaluations(rows) # type: ignore[arg-type] assert len(evals) == 1 assert evals[0].value == expected_value diff --git a/tests/routers/openml/task_tag_test.py b/tests/routers/openml/task_tag_test.py index 087f76e..f6e04b1 100644 --- a/tests/routers/openml/task_tag_test.py +++ b/tests/routers/openml/task_tag_test.py @@ -14,7 +14,7 @@ if TYPE_CHECKING: import httpx - from sqlalchemy.ext.asyncio import AsyncConnection + from sqlalchemy.ext.asyncio import AsyncSession @pytest.mark.parametrize( @@ -41,44 +41,46 @@ async def test_task_tag_rejects_unauthorized(key: ApiKey, py_api: httpx.AsyncCli [ADMIN_USER, SOME_USER, OWNER_USER], ids=["administrator", "non-owner", "owner"], ) -async def test_task_tag(user: User, expdb_test: AsyncConnection, task_factory: TaskFactory) -> None: +async def test_task_tag(user: User, expdb_session: AsyncSession, task_factory: TaskFactory) -> None: tag = "test_task_tag" task = await task_factory() - result = await tag_task(task_id=task.id, tag=tag, user=user, expdb_db=expdb_test) + result = await tag_task(task_id=task.id, tag=tag, user=user, expdb_session=expdb_session) assert result == {"task_tag": {"id": str(task.id), "tag": [tag]}} - tags = await get_tags(id_=task.id, connection=expdb_test) - assert tag in tags + tags = await get_tags(task_id=task.id, session=expdb_session) + assert tag in [t.tag for t in tags] @pytest.mark.mut async def test_task_tag_returns_existing_tags( - task_factory: TaskFactory, expdb_test: AsyncConnection + task_factory: TaskFactory, expdb_session: AsyncSession ) -> None: task = await task_factory() - await tag_task(task_id=task.id, tag="first", user=ADMIN_USER, expdb_db=expdb_test) - result = await tag_task(task_id=task.id, tag="second", user=ADMIN_USER, expdb_db=expdb_test) + await tag_task(task_id=task.id, tag="first", user=ADMIN_USER, expdb_session=expdb_session) + result = await tag_task( + task_id=task.id, tag="second", user=ADMIN_USER, expdb_session=expdb_session + ) assert result == {"task_tag": {"id": str(task.id), "tag": ["first", "second"]}} @pytest.mark.mut async def test_task_tag_fails_if_tag_exists( - expdb_test: AsyncConnection, task_factory: TaskFactory + expdb_session: AsyncSession, task_factory: TaskFactory ) -> None: tag = "fails_if_exist" task = await task_factory() - await tag_task(task_id=task.id, tag=tag, user=ADMIN_USER, expdb_db=expdb_test) + await tag_task(task_id=task.id, tag=tag, user=ADMIN_USER, expdb_session=expdb_session) with pytest.raises(TagAlreadyExistsError) as e: - await tag_task(task_id=task.id, tag=tag, user=ADMIN_USER, expdb_db=expdb_test) + await tag_task(task_id=task.id, tag=tag, user=ADMIN_USER, expdb_session=expdb_session) assert str(task.id) in e.value.detail assert tag in e.value.detail -async def test_task_tag_fails_if_task_does_not_exist(expdb_test: AsyncConnection) -> None: +async def test_task_tag_fails_if_task_does_not_exist(expdb_session: AsyncSession) -> None: task_id = 1_000_000 with pytest.raises(TaskNotFoundError) as e: - await tag_task(task_id=task_id, tag="foo", user=ADMIN_USER, expdb_db=expdb_test) + await tag_task(task_id=task_id, tag="foo", user=ADMIN_USER, expdb_session=expdb_session) assert str(task_id) in e.value.detail task_not_found_in_tag_endpoint = TASK_NOT_FOUND_DURING_TAG assert e.value.code == task_not_found_in_tag_endpoint diff --git a/tests/users.py b/tests/users.py index c98ffb0..07d3c66 100644 --- a/tests/users.py +++ b/tests/users.py @@ -3,10 +3,10 @@ from database.users import User, UserGroup NO_USER = None -SOME_USER = User(user_id=2, _database=None, _groups=[UserGroup.READ_WRITE]) -OWNER_USER = User(user_id=3229, _database=None, _groups=[UserGroup.READ_WRITE]) -DATASET_130_OWNER = User(user_id=16, _database=None, _groups=[UserGroup.READ_WRITE]) -ADMIN_USER = User(user_id=1159, _database=None, _groups=[UserGroup.ADMIN, UserGroup.READ_WRITE]) +SOME_USER = User(user_id=2, _groups=[UserGroup.READ_WRITE]) +OWNER_USER = User(user_id=3229, _groups=[UserGroup.READ_WRITE]) +DATASET_130_OWNER = User(user_id=16, _groups=[UserGroup.READ_WRITE]) +ADMIN_USER = User(user_id=1159, _groups=[UserGroup.ADMIN, UserGroup.READ_WRITE]) class ApiKey(StrEnum):