From f5bc6414f9c149a606fbb6c720becae28b484df7 Mon Sep 17 00:00:00 2001 From: PGijsbers Date: Thu, 25 Jun 2026 08:54:08 +0200 Subject: [PATCH 01/11] Add tag endpoint --- src/database/tasks.py | 39 ++++++++++++++++++++++++++++++++++++- src/routers/openml/tasks.py | 32 ++++++++++++++++++++++++++++-- 2 files changed, 68 insertions(+), 3 deletions(-) diff --git a/src/database/tasks.py b/src/database/tasks.py index 788fc93..cdd9248 100644 --- a/src/database/tasks.py +++ b/src/database/tasks.py @@ -2,8 +2,15 @@ from typing import TYPE_CHECKING, cast from sqlalchemy import Row, text +from sqlalchemy.exc import IntegrityError -from routers.types import Identifier +from database.exceptions import ( + _DUPLICATE_ENTRY, + _FOREIGN_KEY_CONSTRAINT_FAILED, + DuplicatePrimaryKeyError, + ForeignKeyConstraintError, +) +from routers.types import Identifier, TagString if TYPE_CHECKING: from sqlalchemy.ext.asyncio import AsyncConnection @@ -162,3 +169,33 @@ async def get_tags(id_: Identifier, expdb: AsyncConnection) -> list[str]: ) tag_rows = rows.all() return [row.tag for row in tag_rows] + + +async def tag( + id_: Identifier, + tag_: TagString, + *, + user_id: Identifier, + connection: AsyncConnection, +) -> None: + try: + await connection.execute( + text( + """ + INSERT INTO task_tag(`id`, `tag`, `uploader`) + VALUES (:task_id, :tag, :user_id) + """, + ), + parameters={ + "task_id": id_, + "user_id": user_id, + "tag": tag_, + }, + ) + except IntegrityError as e: + code, msg = e.orig.args + if code == _FOREIGN_KEY_CONSTRAINT_FAILED: + raise ForeignKeyConstraintError(msg) from e + if code == _DUPLICATE_ENTRY: + raise DuplicatePrimaryKeyError(msg) from e + raise diff --git a/src/routers/openml/tasks.py b/src/routers/openml/tasks.py index 8faa898..f68c4b8 100644 --- a/src/routers/openml/tasks.py +++ b/src/routers/openml/tasks.py @@ -6,13 +6,16 @@ import xmltodict from fastapi import APIRouter, Body, Depends +from loguru import logger from sqlalchemy import bindparam, text import database.datasets import database.tasks from config import get_config -from core.errors import InternalError, NoResultsError, TaskNotFoundError -from routers.dependencies import Pagination, expdb_connection +from core.errors import InternalError, NoResultsError, TagAlreadyExistsError, TaskNotFoundError +from database.exceptions import DuplicatePrimaryKeyError, ForeignKeyConstraintError +from database.users import User +from routers.dependencies import Pagination, expdb_connection, fetch_user_or_raise from routers.types import ( CasualString128, Identifier, @@ -31,6 +34,31 @@ type JSON = dict[str, "JSON"] | list["JSON"] | str | int | float | bool | None +@router.post(path="/tag") +async def tag_task( + task_id: Annotated[Identifier, Body()], + tag: Annotated[TagString, Body()], + user: Annotated[User, Depends(fetch_user_or_raise)], + expdb_db: Annotated[AsyncConnection, Depends(expdb_connection)], +) -> dict[str, dict[str, Any]]: + try: + await database.tasks.tag(task_id, tag, user_id=user.user_id, connection=expdb_db) + except ForeignKeyConstraintError: + msg = f"Task {task_id} not found." + raise TaskNotFoundError(msg, code=472) from None + except DuplicatePrimaryKeyError: + msg = f"Task {task_id} already tagged with {tag!r}." + raise TagAlreadyExistsError(msg) from None + + logger.info("Task {task_id} tagged '{tag}'.", task_id=task_id, tag=tag) + + tags = await database.tasks.get_tags(task_id, expdb_db) + + return { + "task_tag": {"id": str(task_id), "tag": tags}, + } + + def convert_template_xml_to_json(xml_template: str) -> dict[str, JSON]: json_template = xmltodict.parse(xml_template.replace("oml:", "")) json_str = json.dumps(json_template) From 2af428f6082602926dd113b59dd5a818cc230340 Mon Sep 17 00:00:00 2001 From: PGijsbers Date: Thu, 25 Jun 2026 09:24:48 +0200 Subject: [PATCH 02/11] Add tests --- src/database/tasks.py | 4 +- tests/routers/openml/task_tag_test.py | 183 ++++++++++++++++++++++++++ 2 files changed, 185 insertions(+), 2 deletions(-) create mode 100644 tests/routers/openml/task_tag_test.py diff --git a/src/database/tasks.py b/src/database/tasks.py index cdd9248..b0f6010 100644 --- a/src/database/tasks.py +++ b/src/database/tasks.py @@ -156,8 +156,8 @@ async def get_task_type_inout_with_template( ) -async def get_tags(id_: Identifier, expdb: AsyncConnection) -> list[str]: - rows = await expdb.execute( +async def get_tags(id_: Identifier, connection: AsyncConnection) -> list[str]: + rows = await connection.execute( text( """ SELECT `tag` diff --git a/tests/routers/openml/task_tag_test.py b/tests/routers/openml/task_tag_test.py new file mode 100644 index 0000000..b42bb89 --- /dev/null +++ b/tests/routers/openml/task_tag_test.py @@ -0,0 +1,183 @@ +import re +from http import HTTPStatus +from typing import TYPE_CHECKING + +import pytest + +from core.conversions import nested_remove_single_element_list +from core.errors import TagAlreadyExistsError, TaskNotFoundError +from database.tasks import get_tags +from database.users import User +from routers.openml.tasks import tag_task +from tests import constants +from tests.users import ADMIN_USER, OWNER_USER, SOME_USER, ApiKey + +if TYPE_CHECKING: + import httpx + from sqlalchemy.ext.asyncio import AsyncConnection + + +@pytest.mark.parametrize( + "key", + [None, ApiKey.INVALID], + ids=["no authentication", "invalid key"], +) +async def test_task_tag_rejects_unauthorized(key: ApiKey, py_api: httpx.AsyncClient) -> None: + apikey = "" if key is None else f"?api_key={key}" + response = await py_api.post( + f"/tasks/tag{apikey}", + json={"task_id": next(iter(constants.PRIVATE_DATASET_ID)), "tag": "test"}, + ) + assert response.status_code == HTTPStatus.UNAUTHORIZED + + +# ── Direct call tests: tag_task ── + + +@pytest.mark.mut +@pytest.mark.parametrize( + "user", + [ADMIN_USER, SOME_USER, OWNER_USER], + ids=["administrator", "non-owner", "owner"], +) +async def test_task_tag(user: User, expdb_test: AsyncConnection) -> None: + task_id, tag = 2, "test" + result = await tag_task( + task_id=task_id, + tag=tag, + user=user, + expdb_db=expdb_test, + ) + assert result == {"task_tag": {"id": str(task_id), "tag": [tag]}} + + tags = await get_tags(id_=task_id, connection=expdb_test) + assert tag in tags + + +@pytest.mark.mut +async def test_task_tag_returns_existing_tags(expdb_test: AsyncConnection) -> None: + task_id, tag = 1, "test" # Task 1 already is tagged with 'OpenML100' + result = await tag_task( + task_id=task_id, + tag=tag, + user=ADMIN_USER, + expdb_db=expdb_test, + ) + assert result == {"task_tag": {"id": str(task_id), "tag": ["OpenML100", tag]}} + + +@pytest.mark.mut +async def test_task_tag_fails_if_tag_exists(expdb_test: AsyncConnection) -> None: + task_id, tag = 1, "OpenML100" # Task 1 already is tagged with 'OpenML100' + with pytest.raises(TagAlreadyExistsError) as e: + await tag_task( + task_id=task_id, + tag=tag, + user=ADMIN_USER, + expdb_db=expdb_test, + ) + assert str(task_id) in e.value.detail + assert tag in e.value.detail + + +async def test_task_tag_fails_if_task_does_not_exist(expdb_test: AsyncConnection) -> None: + task_id = 1_000_000 + with pytest.raises(TaskNotFoundError) as e: + await tag_task( + task_id=task_id, + tag="foo", + user=ADMIN_USER, + expdb_db=expdb_test, + ) + assert str(task_id) in e.value.detail + task_not_found_in_tag_endpoint = 472 + assert e.value.code == task_not_found_in_tag_endpoint + + +# -- migration tests -- + + +@pytest.mark.mut +@pytest.mark.parametrize( + "task_id", + [ + *range(1, 10), + 101, + constants.SOME_DEACTIVATED_DATASET_ID, + constants.DATASET_ID_THAT_DOES_NOT_EXIST, + ], +) +@pytest.mark.parametrize( + "api_key", + [ApiKey.ADMIN, ApiKey.SOME_USER, ApiKey.OWNER_USER], + ids=["Administrator", "regular user", "possible owner"], +) +@pytest.mark.parametrize( + "tag", + ["study_14", "totally_new_tag_for_migration_testing"], + ids=["typically existing tag", "new tag"], +) +async def test_task_tag_response_is_identical( + task_id: int, + tag: str, + api_key: str, + py_api: httpx.AsyncClient, + php_api: httpx.AsyncClient, +) -> None: + # PHP request must happen first to check state, can't parallelize + php_response = await php_api.post( + "/task/tag", + data={"api_key": api_key, "tag": tag, "task_id": task_id}, + ) + already_tagged = ( + php_response.status_code == HTTPStatus.INTERNAL_SERVER_ERROR + and "already tagged" in php_response.json()["error"]["message"] + ) + if not already_tagged: + # undo the tag, because we don't want to persist this change to the taskbase + # Sometimes a change is already committed to the taskbase even if an error occurs. + await php_api.post( + "/task/untag", + data={"api_key": api_key, "tag": tag, "task_id": task_id}, + ) + if ( + php_response.status_code != HTTPStatus.OK + and php_response.json()["error"]["message"] == "An Elastic Search Exception occured." + ): + pytest.skip("Encountered Elastic Search error.") + + py_response = await py_api.post( + f"/tasks/tag?api_key={api_key}", + json={"task_id": task_id, "tag": tag}, + ) + + # RFC 9457: Tag conflict now returns 409 instead of 500 + if php_response.status_code == HTTPStatus.INTERNAL_SERVER_ERROR and already_tagged: + assert py_response.status_code == HTTPStatus.CONFLICT + assert py_response.json()["code"] == php_response.json()["error"]["code"] + assert php_response.json()["error"]["message"] == "Entity already tagged by this tag." + assert re.match( + pattern=r"Task \d+ already tagged with " + f"'{tag}'.", + string=py_response.json()["detail"], + ) + return + + if py_response.status_code == HTTPStatus.NOT_FOUND: + assert php_response.status_code == HTTPStatus.PRECONDITION_FAILED + py_error = py_response.json() + php_error = php_response.json()["error"] + assert py_error["code"] == php_error["code"] + assert php_error["message"] == "Entity not found." + assert re.match(r"Task \d+ not found.", py_error["detail"]) + return + + assert py_response.status_code == php_response.status_code, php_response.json() + if py_response.status_code != HTTPStatus.OK: + assert py_response.json()["code"] == php_response.json()["error"]["code"] + assert py_response.json()["detail"] == php_response.json()["error"]["message"] + return + + php_json = php_response.json() + py_json = py_response.json() + py_json = nested_remove_single_element_list(py_json) + assert py_json == php_json From 3db1974e48053efb4a7289bcf0ebf1121da34a32 Mon Sep 17 00:00:00 2001 From: PGijsbers Date: Thu, 25 Jun 2026 11:07:18 +0200 Subject: [PATCH 03/11] Toward unifying tag tests --- tests/routers/openml/dataset_tag_test.py | 61 +----------------- tests/routers/openml/tag_test_helper.py | 79 ++++++++++++++++++++++++ tests/routers/openml/task_tag_test.py | 63 +------------------ 3 files changed, 84 insertions(+), 119 deletions(-) create mode 100644 tests/routers/openml/tag_test_helper.py diff --git a/tests/routers/openml/dataset_tag_test.py b/tests/routers/openml/dataset_tag_test.py index bb12a74..16f719a 100644 --- a/tests/routers/openml/dataset_tag_test.py +++ b/tests/routers/openml/dataset_tag_test.py @@ -1,15 +1,14 @@ -import re from http import HTTPStatus from typing import TYPE_CHECKING import pytest -from core.conversions import nested_remove_single_element_list from core.errors import DatasetNotFoundError, TagAlreadyExistsError from database.datasets import get_tags_for from database.users import User from routers.openml.datasets import tag_dataset from tests import constants +from tests.routers.openml.tag_test_helper import assert_tag_response_is_identical from tests.users import ADMIN_USER, OWNER_USER, SOME_USER, ApiKey if TYPE_CHECKING: @@ -124,60 +123,4 @@ async def test_dataset_tag_response_is_identical( py_api: httpx.AsyncClient, php_api: httpx.AsyncClient, ) -> None: - # PHP request must happen first to check state, can't parallelize - php_response = await php_api.post( - "/data/tag", - data={"api_key": api_key, "tag": tag, "data_id": dataset_id}, - ) - already_tagged = ( - php_response.status_code == HTTPStatus.INTERNAL_SERVER_ERROR - and "already tagged" in php_response.json()["error"]["message"] - ) - if not already_tagged: - # undo the tag, because we don't want to persist this change to the database - # Sometimes a change is already committed to the database even if an error occurs. - await php_api.post( - "/data/untag", - data={"api_key": api_key, "tag": tag, "data_id": dataset_id}, - ) - if ( - php_response.status_code != HTTPStatus.OK - and php_response.json()["error"]["message"] == "An Elastic Search Exception occured." - ): - pytest.skip("Encountered Elastic Search error.") - - py_response = await py_api.post( - f"/datasets/tag?api_key={api_key}", - json={"data_id": dataset_id, "tag": tag}, - ) - - # RFC 9457: Tag conflict now returns 409 instead of 500 - if php_response.status_code == HTTPStatus.INTERNAL_SERVER_ERROR and already_tagged: - assert py_response.status_code == HTTPStatus.CONFLICT - assert py_response.json()["code"] == php_response.json()["error"]["code"] - assert php_response.json()["error"]["message"] == "Entity already tagged by this tag." - assert re.match( - pattern=r"Dataset \d+ already tagged with " + f"'{tag}'.", - string=py_response.json()["detail"], - ) - return - - if py_response.status_code == HTTPStatus.NOT_FOUND: - assert php_response.status_code == HTTPStatus.PRECONDITION_FAILED - py_error = py_response.json() - php_error = php_response.json()["error"] - assert py_error["code"] == php_error["code"] - assert php_error["message"] == "Entity not found." - assert re.match(r"Dataset \d+ not found.", py_error["detail"]) - return - - assert py_response.status_code == php_response.status_code, php_response.json() - if py_response.status_code != HTTPStatus.OK: - assert py_response.json()["code"] == php_response.json()["error"]["code"] - assert py_response.json()["detail"] == php_response.json()["error"]["message"] - return - - php_json = php_response.json() - py_json = py_response.json() - py_json = nested_remove_single_element_list(py_json) - assert py_json == php_json + await assert_tag_response_is_identical(dataset_id, tag, api_key, "dataset", py_api, php_api) diff --git a/tests/routers/openml/tag_test_helper.py b/tests/routers/openml/tag_test_helper.py new file mode 100644 index 0000000..4c4a99d --- /dev/null +++ b/tests/routers/openml/tag_test_helper.py @@ -0,0 +1,79 @@ +import re +from http import HTTPStatus +from typing import TYPE_CHECKING + +import pytest + +from core.conversions import nested_remove_single_element_list + +if TYPE_CHECKING: + import httpx + + +async def assert_tag_response_is_identical( # noqa: PLR0913 + identifier: int, + tag: str, + api_key: str, + entity: str, + py_api: httpx.AsyncClient, + php_api: httpx.AsyncClient, +) -> None: + php_alias = "data" if entity == "dataset" else entity + # PHP request must happen first to check state, can't parallelize + php_response = await php_api.post( + f"/{php_alias}/tag", + data={"api_key": api_key, "tag": tag, f"{php_alias}_id": identifier}, + ) + already_tagged = ( + php_response.status_code == HTTPStatus.INTERNAL_SERVER_ERROR + and "already tagged" in php_response.json()["error"]["message"] + ) + if not already_tagged: + # undo the tag, because we don't want to persist this change to the taskbase + # Sometimes a change is already committed to the taskbase even if an error occurs. + await php_api.post( + f"/{php_alias}/untag", + data={"api_key": api_key, "tag": tag, f"{php_alias}_id": identifier}, + ) + if ( + php_response.status_code != HTTPStatus.OK + and php_response.json()["error"]["message"] == "An Elastic Search Exception occured." + ): + pytest.skip("Encountered Elastic Search error.") + + entity_plural = f"{entity}s" + py_response = await py_api.post( + f"/{entity_plural}/tag?api_key={api_key}", + json={f"{php_alias}_id": identifier, "tag": tag}, + ) + + # RFC 9457: Tag conflict now returns 409 instead of 500 + if php_response.status_code == HTTPStatus.INTERNAL_SERVER_ERROR and already_tagged: + assert py_response.status_code == HTTPStatus.CONFLICT + assert py_response.json()["code"] == php_response.json()["error"]["code"] + assert php_response.json()["error"]["message"] == "Entity already tagged by this tag." + assert re.match( + pattern=rf"{entity.capitalize()} \d+ already tagged with " + f"'{tag}'.", + string=py_response.json()["detail"], + ) + return + + if py_response.status_code == HTTPStatus.NOT_FOUND: + assert php_response.status_code == HTTPStatus.PRECONDITION_FAILED + py_error = py_response.json() + php_error = php_response.json()["error"] + assert py_error["code"] == php_error["code"] + assert php_error["message"] == "Entity not found." + assert re.match(rf"{entity.capitalize()} \d+ not found.", py_error["detail"]) + return + + assert py_response.status_code == php_response.status_code, php_response.json() + if py_response.status_code != HTTPStatus.OK: + assert py_response.json()["code"] == php_response.json()["error"]["code"] + assert py_response.json()["detail"] == php_response.json()["error"]["message"] + return + + php_json = php_response.json() + py_json = py_response.json() + py_json = nested_remove_single_element_list(py_json) + assert py_json == php_json diff --git a/tests/routers/openml/task_tag_test.py b/tests/routers/openml/task_tag_test.py index b42bb89..98bf457 100644 --- a/tests/routers/openml/task_tag_test.py +++ b/tests/routers/openml/task_tag_test.py @@ -1,15 +1,14 @@ -import re from http import HTTPStatus from typing import TYPE_CHECKING import pytest -from core.conversions import nested_remove_single_element_list from core.errors import TagAlreadyExistsError, TaskNotFoundError from database.tasks import get_tags from database.users import User from routers.openml.tasks import tag_task from tests import constants +from tests.routers.openml.tag_test_helper import assert_tag_response_is_identical from tests.users import ADMIN_USER, OWNER_USER, SOME_USER, ApiKey if TYPE_CHECKING: @@ -114,7 +113,7 @@ async def test_task_tag_fails_if_task_does_not_exist(expdb_test: AsyncConnection ) @pytest.mark.parametrize( "tag", - ["study_14", "totally_new_tag_for_migration_testing"], + ["OpenML100", "totally_new_tag_for_migration_testing"], ids=["typically existing tag", "new tag"], ) async def test_task_tag_response_is_identical( @@ -124,60 +123,4 @@ async def test_task_tag_response_is_identical( py_api: httpx.AsyncClient, php_api: httpx.AsyncClient, ) -> None: - # PHP request must happen first to check state, can't parallelize - php_response = await php_api.post( - "/task/tag", - data={"api_key": api_key, "tag": tag, "task_id": task_id}, - ) - already_tagged = ( - php_response.status_code == HTTPStatus.INTERNAL_SERVER_ERROR - and "already tagged" in php_response.json()["error"]["message"] - ) - if not already_tagged: - # undo the tag, because we don't want to persist this change to the taskbase - # Sometimes a change is already committed to the taskbase even if an error occurs. - await php_api.post( - "/task/untag", - data={"api_key": api_key, "tag": tag, "task_id": task_id}, - ) - if ( - php_response.status_code != HTTPStatus.OK - and php_response.json()["error"]["message"] == "An Elastic Search Exception occured." - ): - pytest.skip("Encountered Elastic Search error.") - - py_response = await py_api.post( - f"/tasks/tag?api_key={api_key}", - json={"task_id": task_id, "tag": tag}, - ) - - # RFC 9457: Tag conflict now returns 409 instead of 500 - if php_response.status_code == HTTPStatus.INTERNAL_SERVER_ERROR and already_tagged: - assert py_response.status_code == HTTPStatus.CONFLICT - assert py_response.json()["code"] == php_response.json()["error"]["code"] - assert php_response.json()["error"]["message"] == "Entity already tagged by this tag." - assert re.match( - pattern=r"Task \d+ already tagged with " + f"'{tag}'.", - string=py_response.json()["detail"], - ) - return - - if py_response.status_code == HTTPStatus.NOT_FOUND: - assert php_response.status_code == HTTPStatus.PRECONDITION_FAILED - py_error = py_response.json() - php_error = php_response.json()["error"] - assert py_error["code"] == php_error["code"] - assert php_error["message"] == "Entity not found." - assert re.match(r"Task \d+ not found.", py_error["detail"]) - return - - assert py_response.status_code == php_response.status_code, php_response.json() - if py_response.status_code != HTTPStatus.OK: - assert py_response.json()["code"] == php_response.json()["error"]["code"] - assert py_response.json()["detail"] == php_response.json()["error"]["message"] - return - - php_json = php_response.json() - py_json = py_response.json() - py_json = nested_remove_single_element_list(py_json) - assert py_json == php_json + await assert_tag_response_is_identical(task_id, tag, api_key, "task", py_api, php_api) From e43b177c629c6189d6ba98fb815be969fa1ab0e4 Mon Sep 17 00:00:00 2001 From: PGijsbers Date: Thu, 25 Jun 2026 15:57:18 +0200 Subject: [PATCH 04/11] Refactor the tag task tests --- tests/conftest.py | 37 ++++++++++++++++- tests/routers/openml/task_tag_test.py | 58 +++++++++++---------------- 2 files changed, 59 insertions(+), 36 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 23a5295..bd8cb5b 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,8 +1,8 @@ import contextlib import json -from collections.abc import AsyncIterator, Callable, Iterable, Iterator +from collections.abc import AsyncIterator, Awaitable, Callable, Iterable, Iterator from pathlib import Path -from typing import TYPE_CHECKING, Any, NamedTuple +from typing import TYPE_CHECKING, Any, NamedTuple, Protocol import _pytest.mark import httpx @@ -22,6 +22,7 @@ from database.setup import expdb_database, user_database from main import create_api from routers.dependencies import expdb_connection, userdb_connection +from routers.types import Identifier from tests.users import OWNER_USER if TYPE_CHECKING: @@ -139,6 +140,38 @@ def dataset_130() -> Iterator[dict[str, Any]]: yield json.load(dataset_file) +class Task(NamedTuple): + """To be replaced by an actual ORM class.""" + + id: Identifier + task_type: Identifier + creator: Identifier + + +class TaskFactory(Protocol): + def __call__( + self, *, task_id: Identifier = 42_000, task_type: Identifier = 1, creator: Identifier = 1 + ) -> Awaitable[Task]: ... + + +@pytest.fixture +async def task_factory( + expdb_test: AsyncConnection, +) -> TaskFactory: + async def create_task( + *, task_id: Identifier = 42_000, task_type: Identifier = 1, creator: Identifier = 1 + ) -> Task: + await expdb_test.execute( + text(""" + INSERT INTO task (task_id, ttid, creator) VALUES (:task_id, :ttid, :creator); + """), + parameters={"task_id": task_id, "ttid": task_type, "creator": creator}, + ) + return Task(task_id, task_type, creator) + + return create_task + + class Flow(NamedTuple): """To be replaced by an actual ORM class.""" diff --git a/tests/routers/openml/task_tag_test.py b/tests/routers/openml/task_tag_test.py index 98bf457..35ab120 100644 --- a/tests/routers/openml/task_tag_test.py +++ b/tests/routers/openml/task_tag_test.py @@ -8,6 +8,7 @@ from database.users import User from routers.openml.tasks import tag_task from tests import constants +from tests.conftest import TaskFactory from tests.routers.openml.tag_test_helper import assert_tag_response_is_identical from tests.users import ADMIN_USER, OWNER_USER, SOME_USER, ApiKey @@ -39,55 +40,44 @@ async def test_task_tag_rejects_unauthorized(key: ApiKey, py_api: httpx.AsyncCli [ADMIN_USER, SOME_USER, OWNER_USER], ids=["administrator", "non-owner", "owner"], ) -async def test_task_tag(user: User, expdb_test: AsyncConnection) -> None: - task_id, tag = 2, "test" - result = await tag_task( - task_id=task_id, - tag=tag, - user=user, - expdb_db=expdb_test, - ) - assert result == {"task_tag": {"id": str(task_id), "tag": [tag]}} +async def test_task_tag(user: User, expdb_test: AsyncConnection, task_factory: TaskFactory) -> None: + tag = "test_task_tag" + task = await task_factory() + result = await tag_task(task_id=task.id, tag=tag, user=user, expdb_db=expdb_test) + assert result == {"task_tag": {"id": str(task.id), "tag": [tag]}} - tags = await get_tags(id_=task_id, connection=expdb_test) + tags = await get_tags(id_=task.id, connection=expdb_test) assert tag in tags @pytest.mark.mut -async def test_task_tag_returns_existing_tags(expdb_test: AsyncConnection) -> None: - task_id, tag = 1, "test" # Task 1 already is tagged with 'OpenML100' - result = await tag_task( - task_id=task_id, - tag=tag, - user=ADMIN_USER, - expdb_db=expdb_test, - ) - assert result == {"task_tag": {"id": str(task_id), "tag": ["OpenML100", tag]}} +async def test_task_tag_returns_existing_tags( + task_factory: TaskFactory, expdb_test: AsyncConnection +) -> None: + task = await task_factory() + await tag_task(task_id=task.id, tag="first", user=ADMIN_USER, expdb_db=expdb_test) + result = await tag_task(task_id=task.id, tag="second", user=ADMIN_USER, expdb_db=expdb_test) + assert result == {"task_tag": {"id": str(task.id), "tag": ["first", "second"]}} @pytest.mark.mut -async def test_task_tag_fails_if_tag_exists(expdb_test: AsyncConnection) -> None: - task_id, tag = 1, "OpenML100" # Task 1 already is tagged with 'OpenML100' +async def test_task_tag_fails_if_tag_exists( + expdb_test: AsyncConnection, task_factory: TaskFactory +) -> None: + tag = "fails_if_exist" + task = await task_factory() + await tag_task(task_id=task.id, tag=tag, user=ADMIN_USER, expdb_db=expdb_test) + with pytest.raises(TagAlreadyExistsError) as e: - await tag_task( - task_id=task_id, - tag=tag, - user=ADMIN_USER, - expdb_db=expdb_test, - ) - assert str(task_id) in e.value.detail + await tag_task(task_id=task.id, tag=tag, user=ADMIN_USER, expdb_db=expdb_test) + assert str(task.id) in e.value.detail assert tag in e.value.detail async def test_task_tag_fails_if_task_does_not_exist(expdb_test: AsyncConnection) -> None: task_id = 1_000_000 with pytest.raises(TaskNotFoundError) as e: - await tag_task( - task_id=task_id, - tag="foo", - user=ADMIN_USER, - expdb_db=expdb_test, - ) + await tag_task(task_id=task_id, tag="foo", user=ADMIN_USER, expdb_db=expdb_test) assert str(task_id) in e.value.detail task_not_found_in_tag_endpoint = 472 assert e.value.code == task_not_found_in_tag_endpoint From 25992b2320befa2e8aed61a6d9ce2f7138891db7 Mon Sep 17 00:00:00 2001 From: PGijsbers Date: Thu, 25 Jun 2026 16:19:28 +0200 Subject: [PATCH 05/11] Refactor dataset tag tests to not rely on database state --- tests/conftest.py | 44 +++++++++++++++++++++-- tests/routers/openml/dataset_tag_test.py | 45 ++++++++++++------------ 2 files changed, 64 insertions(+), 25 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index bd8cb5b..e604f0f 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,4 +1,5 @@ import contextlib +import datetime import json from collections.abc import AsyncIterator, Awaitable, Callable, Iterable, Iterator from pathlib import Path @@ -150,7 +151,11 @@ class Task(NamedTuple): class TaskFactory(Protocol): def __call__( - self, *, task_id: Identifier = 42_000, task_type: Identifier = 1, creator: Identifier = 1 + self, + *, + task_id: Identifier = 42_000, + task_type: Identifier = 1, + creator: Identifier = OWNER_USER.user_id, ) -> Awaitable[Task]: ... @@ -159,7 +164,10 @@ async def task_factory( expdb_test: AsyncConnection, ) -> TaskFactory: async def create_task( - *, task_id: Identifier = 42_000, task_type: Identifier = 1, creator: Identifier = 1 + *, + task_id: Identifier = 42_000, + task_type: Identifier = 1, + creator: Identifier = OWNER_USER.user_id, ) -> Task: await expdb_test.execute( text(""" @@ -172,6 +180,38 @@ async def create_task( return create_task +class DatasetFactory(Protocol): + def __call__( + self, *, dataset_id: Identifier = 42_000, creator: Identifier = OWNER_USER.user_id + ) -> Awaitable[Identifier]: ... + + +@pytest.fixture +async def dataset_factory( + expdb_test: AsyncConnection, +) -> DatasetFactory: + async def create_dataset( + *, dataset_id: Identifier = 42_000, creator: Identifier = OWNER_USER.user_id + ) -> Identifier: + await expdb_test.execute( + text(""" + INSERT INTO dataset + (did, uploader, name, version, format, upload_date, licence, url, visibility) + VALUES + (:dataset_id, :creator, 'dataset-name', 'dataset-version', 'dataset-format', + :now, 'public', 'dataset-url', 'public'); + """), + parameters={ + "dataset_id": dataset_id, + "creator": creator, + "now": datetime.datetime.now(tz=datetime.UTC), + }, + ) + return dataset_id + + return create_dataset + + class Flow(NamedTuple): """To be replaced by an actual ORM class.""" diff --git a/tests/routers/openml/dataset_tag_test.py b/tests/routers/openml/dataset_tag_test.py index 16f719a..307927e 100644 --- a/tests/routers/openml/dataset_tag_test.py +++ b/tests/routers/openml/dataset_tag_test.py @@ -8,6 +8,7 @@ from database.users import User from routers.openml.datasets import tag_dataset from tests import constants +from tests.conftest import DatasetFactory from tests.routers.openml.tag_test_helper import assert_tag_response_is_identical from tests.users import ADMIN_USER, OWNER_USER, SOME_USER, ApiKey @@ -39,14 +40,12 @@ async def test_dataset_tag_rejects_unauthorized(key: ApiKey, py_api: httpx.Async [ADMIN_USER, SOME_USER, OWNER_USER], ids=["administrator", "non-owner", "owner"], ) -async def test_dataset_tag(user: User, expdb_test: AsyncConnection) -> None: - dataset_id, tag = next(iter(constants.PRIVATE_DATASET_ID)), "test" - result = await tag_dataset( - data_id=dataset_id, - tag=tag, - user=user, - expdb_db=expdb_test, - ) +async def test_dataset_tag( + user: User, expdb_test: AsyncConnection, dataset_factory: DatasetFactory +) -> None: + dataset_id = await dataset_factory() + tag = "test_tag" + result = await tag_dataset(data_id=dataset_id, tag=tag, user=user, expdb_db=expdb_test) assert result == {"data_tag": {"id": str(dataset_id), "tag": [tag]}} tags = await get_tags_for(id_=dataset_id, connection=expdb_test) @@ -54,27 +53,27 @@ async def test_dataset_tag(user: User, expdb_test: AsyncConnection) -> None: @pytest.mark.mut -async def test_dataset_tag_returns_existing_tags(expdb_test: AsyncConnection) -> None: - dataset_id, tag = 1, "test" # Dataset 1 already is tagged with 'study_14' +async def test_dataset_tag_returns_existing_tags( + expdb_test: AsyncConnection, dataset_factory: DatasetFactory +) -> None: + dataset_id = await dataset_factory() + await tag_dataset(data_id=dataset_id, tag="first", user=OWNER_USER, expdb_db=expdb_test) result = await tag_dataset( - data_id=dataset_id, - tag=tag, - user=ADMIN_USER, - expdb_db=expdb_test, + data_id=dataset_id, tag="second", user=ADMIN_USER, expdb_db=expdb_test ) - assert result == {"data_tag": {"id": str(dataset_id), "tag": ["study_14", tag]}} + assert result == {"data_tag": {"id": str(dataset_id), "tag": ["first", "second"]}} @pytest.mark.mut -async def test_dataset_tag_fails_if_tag_exists(expdb_test: AsyncConnection) -> None: - dataset_id, tag = 1, "study_14" # Dataset 1 already is tagged with 'study_14' +async def test_dataset_tag_fails_if_tag_exists( + expdb_test: AsyncConnection, dataset_factory: DatasetFactory +) -> None: + tag = "repeated_tag" + dataset_id = await dataset_factory() + await tag_dataset(data_id=dataset_id, tag=tag, user=OWNER_USER, expdb_db=expdb_test) + with pytest.raises(TagAlreadyExistsError) as e: - await tag_dataset( - data_id=dataset_id, - tag=tag, - user=ADMIN_USER, - expdb_db=expdb_test, - ) + await tag_dataset(data_id=dataset_id, tag=tag, user=ADMIN_USER, expdb_db=expdb_test) assert str(dataset_id) in e.value.detail assert tag in e.value.detail From 3078682fda368bfa6b37061e226282b1c4ea0436 Mon Sep 17 00:00:00 2001 From: PGijsbers Date: Thu, 25 Jun 2026 16:28:38 +0200 Subject: [PATCH 06/11] indicate that the identifier does not matter to the test --- tests/routers/openml/dataset_tag_test.py | 3 ++- tests/routers/openml/task_tag_test.py | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/routers/openml/dataset_tag_test.py b/tests/routers/openml/dataset_tag_test.py index 307927e..42ef658 100644 --- a/tests/routers/openml/dataset_tag_test.py +++ b/tests/routers/openml/dataset_tag_test.py @@ -24,9 +24,10 @@ ) async def test_dataset_tag_rejects_unauthorized(key: ApiKey, py_api: httpx.AsyncClient) -> None: apikey = "" if key is None else f"?api_key={key}" + any_dataset_identifier = 1 response = await py_api.post( f"/datasets/tag{apikey}", - json={"data_id": next(iter(constants.PRIVATE_DATASET_ID)), "tag": "test"}, + json={"data_id": any_dataset_identifier, "tag": "test"}, ) assert response.status_code == HTTPStatus.UNAUTHORIZED diff --git a/tests/routers/openml/task_tag_test.py b/tests/routers/openml/task_tag_test.py index 35ab120..ba76b61 100644 --- a/tests/routers/openml/task_tag_test.py +++ b/tests/routers/openml/task_tag_test.py @@ -24,9 +24,10 @@ ) async def test_task_tag_rejects_unauthorized(key: ApiKey, py_api: httpx.AsyncClient) -> None: apikey = "" if key is None else f"?api_key={key}" + any_task_id = 1 response = await py_api.post( f"/tasks/tag{apikey}", - json={"task_id": next(iter(constants.PRIVATE_DATASET_ID)), "tag": "test"}, + json={"task_id": any_task_id, "tag": "test"}, ) assert response.status_code == HTTPStatus.UNAUTHORIZED From 0396bca80e9cff37e1e00321adfc7a34fb878e93 Mon Sep 17 00:00:00 2001 From: PGijsbers Date: Thu, 25 Jun 2026 16:30:32 +0200 Subject: [PATCH 07/11] generalize the name since it should be valid for all entities --- tests/constants.py | 2 +- tests/routers/openml/dataset_tag_test.py | 2 +- tests/routers/openml/task_tag_test.py | 3 +-- 3 files changed, 3 insertions(+), 4 deletions(-) diff --git a/tests/constants.py b/tests/constants.py index 5563dec..991c2e9 100644 --- a/tests/constants.py +++ b/tests/constants.py @@ -1,4 +1,4 @@ -DATASET_ID_THAT_DOES_NOT_EXIST = 9_9999_999 +ENTITY_ID_THAT_DOES_NOT_EXIST = 9_9999_999 SOME_PRIVATE_DATASET_ID = 130 PRIVATE_DATASET_ID = {130} IN_PREPARATION_ID = {33, 161, 162, 163} diff --git a/tests/routers/openml/dataset_tag_test.py b/tests/routers/openml/dataset_tag_test.py index 42ef658..dd80966 100644 --- a/tests/routers/openml/dataset_tag_test.py +++ b/tests/routers/openml/dataset_tag_test.py @@ -103,7 +103,7 @@ async def test_dataset_tag_fails_if_dataset_does_not_exist(expdb_test: AsyncConn *range(1, 10), 101, constants.SOME_DEACTIVATED_DATASET_ID, - constants.DATASET_ID_THAT_DOES_NOT_EXIST, + constants.ENTITY_ID_THAT_DOES_NOT_EXIST, ], ) @pytest.mark.parametrize( diff --git a/tests/routers/openml/task_tag_test.py b/tests/routers/openml/task_tag_test.py index ba76b61..a2d9fc9 100644 --- a/tests/routers/openml/task_tag_test.py +++ b/tests/routers/openml/task_tag_test.py @@ -93,8 +93,7 @@ async def test_task_tag_fails_if_task_does_not_exist(expdb_test: AsyncConnection [ *range(1, 10), 101, - constants.SOME_DEACTIVATED_DATASET_ID, - constants.DATASET_ID_THAT_DOES_NOT_EXIST, + constants.ENTITY_ID_THAT_DOES_NOT_EXIST, ], ) @pytest.mark.parametrize( From d7a406a3b39633dd8e2637a2bf5e8e1810aadf0e Mon Sep 17 00:00:00 2001 From: PGijsbers Date: Fri, 26 Jun 2026 09:12:21 +0200 Subject: [PATCH 08/11] remove dead code --- tests/routers/openml/tag_test_helper.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/tests/routers/openml/tag_test_helper.py b/tests/routers/openml/tag_test_helper.py index 4c4a99d..ba3560d 100644 --- a/tests/routers/openml/tag_test_helper.py +++ b/tests/routers/openml/tag_test_helper.py @@ -68,10 +68,6 @@ async def assert_tag_response_is_identical( # noqa: PLR0913 return assert py_response.status_code == php_response.status_code, php_response.json() - if py_response.status_code != HTTPStatus.OK: - assert py_response.json()["code"] == php_response.json()["error"]["code"] - assert py_response.json()["detail"] == php_response.json()["error"]["message"] - return php_json = php_response.json() py_json = py_response.json() From c11d5aece94b799d9d06800fff477c5a69689b89 Mon Sep 17 00:00:00 2001 From: PGijsbers Date: Fri, 26 Jun 2026 09:14:39 +0200 Subject: [PATCH 09/11] Make task not found in tag error code a constant --- src/core/errors.py | 2 ++ tests/routers/openml/task_tag_test.py | 4 ++-- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/src/core/errors.py b/src/core/errors.py index f8de607..2ae6b6f 100644 --- a/src/core/errors.py +++ b/src/core/errors.py @@ -251,6 +251,8 @@ class AccountHasResourcesError(ProblemDetailError): # Tag Errors # ============================================================================= +TASK_NOT_FOUND_DURING_TAG = 472 + class TagAlreadyExistsError(ProblemDetailError): """Raised when trying to add a tag that already exists.""" diff --git a/tests/routers/openml/task_tag_test.py b/tests/routers/openml/task_tag_test.py index a2d9fc9..087f76e 100644 --- a/tests/routers/openml/task_tag_test.py +++ b/tests/routers/openml/task_tag_test.py @@ -3,7 +3,7 @@ import pytest -from core.errors import TagAlreadyExistsError, TaskNotFoundError +from core.errors import TASK_NOT_FOUND_DURING_TAG, TagAlreadyExistsError, TaskNotFoundError from database.tasks import get_tags from database.users import User from routers.openml.tasks import tag_task @@ -80,7 +80,7 @@ async def test_task_tag_fails_if_task_does_not_exist(expdb_test: AsyncConnection with pytest.raises(TaskNotFoundError) as e: await tag_task(task_id=task_id, tag="foo", user=ADMIN_USER, expdb_db=expdb_test) assert str(task_id) in e.value.detail - task_not_found_in_tag_endpoint = 472 + task_not_found_in_tag_endpoint = TASK_NOT_FOUND_DURING_TAG assert e.value.code == task_not_found_in_tag_endpoint From e19bd0d77874e34b18ee0f35ad988b3c3bfe1714 Mon Sep 17 00:00:00 2001 From: PGijsbers Date: Fri, 26 Jun 2026 09:27:12 +0200 Subject: [PATCH 10/11] Make the dataset and task factories callable multiple times --- tests/conftest.py | 24 +++++++++++++++++++++--- 1 file changed, 21 insertions(+), 3 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index e604f0f..4191c47 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -159,16 +159,32 @@ def __call__( ) -> Awaitable[Task]: ... +def _create_identifier_factory() -> Callable[[], Identifier]: + _identifier_counter: Identifier = 10_000_000 + + def _get() -> Identifier: + nonlocal _identifier_counter + _identifier_counter += 1 + return _identifier_counter + + return _get + + +_identifier_factory = _create_identifier_factory() + + @pytest.fixture async def task_factory( expdb_test: AsyncConnection, ) -> TaskFactory: async def create_task( *, - task_id: Identifier = 42_000, + task_id: Identifier | None = None, task_type: Identifier = 1, creator: Identifier = OWNER_USER.user_id, ) -> Task: + task_id = task_id or _identifier_factory() + await expdb_test.execute( text(""" INSERT INTO task (task_id, ttid, creator) VALUES (:task_id, :ttid, :creator); @@ -191,20 +207,22 @@ async def dataset_factory( expdb_test: AsyncConnection, ) -> DatasetFactory: async def create_dataset( - *, dataset_id: Identifier = 42_000, creator: Identifier = OWNER_USER.user_id + *, dataset_id: Identifier | None = None, creator: Identifier = OWNER_USER.user_id ) -> Identifier: + dataset_id = dataset_id or _identifier_factory() await expdb_test.execute( text(""" INSERT INTO dataset (did, uploader, name, version, format, upload_date, licence, url, visibility) VALUES - (:dataset_id, :creator, 'dataset-name', 'dataset-version', 'dataset-format', + (:dataset_id, :creator, :name, 'dataset-version', 'dataset-format', :now, 'public', 'dataset-url', 'public'); """), parameters={ "dataset_id": dataset_id, "creator": creator, "now": datetime.datetime.now(tz=datetime.UTC), + "name": f"dataset-name-{dataset_id}", }, ) return dataset_id From 46ae998450480fe50e4d7d3910762a22f5fa153f Mon Sep 17 00:00:00 2001 From: PGijsbers Date: Fri, 26 Jun 2026 09:35:10 +0200 Subject: [PATCH 11/11] Simplify control flow, fix comments --- tests/routers/openml/tag_test_helper.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/routers/openml/tag_test_helper.py b/tests/routers/openml/tag_test_helper.py index ba3560d..eb8fb97 100644 --- a/tests/routers/openml/tag_test_helper.py +++ b/tests/routers/openml/tag_test_helper.py @@ -29,8 +29,8 @@ async def assert_tag_response_is_identical( # noqa: PLR0913 and "already tagged" in php_response.json()["error"]["message"] ) if not already_tagged: - # undo the tag, because we don't want to persist this change to the taskbase - # Sometimes a change is already committed to the taskbase even if an error occurs. + # undo the tag, because we don't want to persist this change to the database + # Sometimes a change is already committed to the database even if an error occurs. await php_api.post( f"/{php_alias}/untag", data={"api_key": api_key, "tag": tag, f"{php_alias}_id": identifier}, @@ -47,8 +47,8 @@ async def assert_tag_response_is_identical( # noqa: PLR0913 json={f"{php_alias}_id": identifier, "tag": tag}, ) - # RFC 9457: Tag conflict now returns 409 instead of 500 - if php_response.status_code == HTTPStatus.INTERNAL_SERVER_ERROR and already_tagged: + # RFC 9457: Tag conflict now returns 409 (CONFLICT) instead of 500 (INTERNAL SERVER ERROR) + if already_tagged: assert py_response.status_code == HTTPStatus.CONFLICT assert py_response.json()["code"] == php_response.json()["error"]["code"] assert php_response.json()["error"]["message"] == "Entity already tagged by this tag."