diff --git a/src/database/setups.py b/src/database/setups.py index 1c959b3..ce0f0cc 100644 --- a/src/database/setups.py +++ b/src/database/setups.py @@ -3,7 +3,14 @@ from typing import TYPE_CHECKING from sqlalchemy import text +from sqlalchemy.exc import IntegrityError +from database.exceptions import ( + _DUPLICATE_ENTRY, + _FOREIGN_KEY_CONSTRAINT_FAILED, + DuplicatePrimaryKeyError, + ForeignKeyConstraintError, +) from routers.types import Identifier, TagString if TYPE_CHECKING: @@ -88,12 +95,20 @@ async def tag( connection: AsyncConnection, ) -> None: """Add tag `tag` to setup with id `setup_id`.""" - await connection.execute( - text( - """ - INSERT INTO setup_tag (id, tag, uploader) - VALUES (:setup_id, :tag, :user_id) - """, - ), - parameters={"setup_id": setup_id, "tag": tag, "user_id": user_id}, - ) + try: + await connection.execute( + text( + """ + INSERT INTO setup_tag (id, tag, uploader) + VALUES (:setup_id, :tag, :user_id) + """, + ), + parameters={"setup_id": setup_id, "tag": tag, "user_id": user_id}, + ) + except IntegrityError as e: + code, msg = e.orig.args + if code == _FOREIGN_KEY_CONSTRAINT_FAILED: + raise ForeignKeyConstraintError(msg) from e + if code == _DUPLICATE_ENTRY: + raise DuplicatePrimaryKeyError(msg) from e + raise diff --git a/src/routers/openml/setups.py b/src/routers/openml/setups.py index ef71ce6..34b28e0 100644 --- a/src/routers/openml/setups.py +++ b/src/routers/openml/setups.py @@ -13,6 +13,7 @@ TagNotFoundError, TagNotOwnedError, ) +from database.exceptions import DuplicatePrimaryKeyError, ForeignKeyConstraintError from database.users import User from routers.dependencies import expdb_connection, fetch_user_or_raise from routers.types import Identifier, TagString @@ -54,22 +55,19 @@ async def tag_setup( expdb_db: Annotated[AsyncConnection, Depends(expdb_connection)], ) -> dict[str, dict[str, str | list[str]]]: """Add tag `tag` to setup with id `setup_id`.""" - setup, setup_tags = await asyncio.gather( - database.setups.get(setup_id, expdb_db), - database.setups.get_tags(setup_id, expdb_db), - ) - if not setup: + try: + await database.setups.tag(setup_id, tag, user.user_id, expdb_db) + except ForeignKeyConstraintError: msg = f"Setup {setup_id} not found." - raise SetupNotFoundError(msg) - matched_tag_row = next((t for t in setup_tags if t.tag.casefold() == tag.casefold()), None) + raise SetupNotFoundError(msg, code=472) from None + except DuplicatePrimaryKeyError: + msg = f"Setup {setup_id} already tagged with {tag!r}." + raise TagAlreadyExistsError(msg) from None - if matched_tag_row: - msg = f"Setup {setup_id} already has tag {tag!r}." - raise TagAlreadyExistsError(msg) - - await database.setups.tag(setup_id, tag, user.user_id, expdb_db) logger.info("Setup {setup_id} tagged '{tag}'.", setup_id=setup_id, tag=tag) - all_tags = [t.tag for t in setup_tags] + [tag] + all_tag_rows = await database.setups.get_tags(setup_id, expdb_db) + all_tags = [t.tag for t in all_tag_rows] + return {"setup_tag": {"id": str(setup_id), "tag": all_tags}} diff --git a/tests/routers/openml/setups_tag_test.py b/tests/routers/openml/setups_tag_test.py index c674ca2..32c85ef 100644 --- a/tests/routers/openml/setups_tag_test.py +++ b/tests/routers/openml/setups_tag_test.py @@ -65,7 +65,7 @@ async def test_setup_tag_already_exists(expdb_test: AsyncConnection) -> None: text("INSERT INTO setup_tag (id, tag, uploader) VALUES (1, :tag, 2);"), parameters={"tag": tag}, ) - with pytest.raises(TagAlreadyExistsError, match=rf"Setup 1 already has tag '{tag}'\."): + with pytest.raises(TagAlreadyExistsError, match=rf"Setup 1 already tagged with '{tag}'\."): await tag_setup( setup_id=1, tag=tag, @@ -203,4 +203,4 @@ async def test_setup_tag_response_is_identical_tag_already_exists( assert php_response.status_code == HTTPStatus.INTERNAL_SERVER_ERROR assert py_response.status_code == HTTPStatus.CONFLICT assert php_response.json()["error"]["message"] == "Entity already tagged by this tag." - assert py_response.json()["detail"] == f"Setup {setup_id} already has tag {tag!r}." + assert py_response.json()["detail"] == f"Setup {setup_id} already tagged with {tag!r}."