From f5bc6414f9c149a606fbb6c720becae28b484df7 Mon Sep 17 00:00:00 2001
From: PGijsbers
Date: Thu, 25 Jun 2026 08:54:08 +0200
Subject: [PATCH 01/11] Add tag endpoint
---
src/database/tasks.py | 39 ++++++++++++++++++++++++++++++++++++-
src/routers/openml/tasks.py | 32 ++++++++++++++++++++++++++++--
2 files changed, 68 insertions(+), 3 deletions(-)
diff --git a/src/database/tasks.py b/src/database/tasks.py
index 788fc93..cdd9248 100644
--- a/src/database/tasks.py
+++ b/src/database/tasks.py
@@ -2,8 +2,15 @@
from typing import TYPE_CHECKING, cast
from sqlalchemy import Row, text
+from sqlalchemy.exc import IntegrityError
-from routers.types import Identifier
+from database.exceptions import (
+ _DUPLICATE_ENTRY,
+ _FOREIGN_KEY_CONSTRAINT_FAILED,
+ DuplicatePrimaryKeyError,
+ ForeignKeyConstraintError,
+)
+from routers.types import Identifier, TagString
if TYPE_CHECKING:
from sqlalchemy.ext.asyncio import AsyncConnection
@@ -162,3 +169,33 @@ async def get_tags(id_: Identifier, expdb: AsyncConnection) -> list[str]:
)
tag_rows = rows.all()
return [row.tag for row in tag_rows]
+
+
+async def tag(
+ id_: Identifier,
+ tag_: TagString,
+ *,
+ user_id: Identifier,
+ connection: AsyncConnection,
+) -> None:
+ try:
+ await connection.execute(
+ text(
+ """
+ INSERT INTO task_tag(`id`, `tag`, `uploader`)
+ VALUES (:task_id, :tag, :user_id)
+ """,
+ ),
+ parameters={
+ "task_id": id_,
+ "user_id": user_id,
+ "tag": tag_,
+ },
+ )
+ except IntegrityError as e:
+ code, msg = e.orig.args
+ if code == _FOREIGN_KEY_CONSTRAINT_FAILED:
+ raise ForeignKeyConstraintError(msg) from e
+ if code == _DUPLICATE_ENTRY:
+ raise DuplicatePrimaryKeyError(msg) from e
+ raise
diff --git a/src/routers/openml/tasks.py b/src/routers/openml/tasks.py
index 8faa898..f68c4b8 100644
--- a/src/routers/openml/tasks.py
+++ b/src/routers/openml/tasks.py
@@ -6,13 +6,16 @@
import xmltodict
from fastapi import APIRouter, Body, Depends
+from loguru import logger
from sqlalchemy import bindparam, text
import database.datasets
import database.tasks
from config import get_config
-from core.errors import InternalError, NoResultsError, TaskNotFoundError
-from routers.dependencies import Pagination, expdb_connection
+from core.errors import InternalError, NoResultsError, TagAlreadyExistsError, TaskNotFoundError
+from database.exceptions import DuplicatePrimaryKeyError, ForeignKeyConstraintError
+from database.users import User
+from routers.dependencies import Pagination, expdb_connection, fetch_user_or_raise
from routers.types import (
CasualString128,
Identifier,
@@ -31,6 +34,31 @@
type JSON = dict[str, "JSON"] | list["JSON"] | str | int | float | bool | None
+@router.post(path="/tag")
+async def tag_task(
+ task_id: Annotated[Identifier, Body()],
+ tag: Annotated[TagString, Body()],
+ user: Annotated[User, Depends(fetch_user_or_raise)],
+ expdb_db: Annotated[AsyncConnection, Depends(expdb_connection)],
+) -> dict[str, dict[str, Any]]:
+ try:
+ await database.tasks.tag(task_id, tag, user_id=user.user_id, connection=expdb_db)
+ except ForeignKeyConstraintError:
+ msg = f"Task {task_id} not found."
+ raise TaskNotFoundError(msg, code=472) from None
+ except DuplicatePrimaryKeyError:
+ msg = f"Task {task_id} already tagged with {tag!r}."
+ raise TagAlreadyExistsError(msg) from None
+
+ logger.info("Task {task_id} tagged '{tag}'.", task_id=task_id, tag=tag)
+
+ tags = await database.tasks.get_tags(task_id, expdb_db)
+
+ return {
+ "task_tag": {"id": str(task_id), "tag": tags},
+ }
+
+
def convert_template_xml_to_json(xml_template: str) -> dict[str, JSON]:
json_template = xmltodict.parse(xml_template.replace("oml:", ""))
json_str = json.dumps(json_template)
From 2af428f6082602926dd113b59dd5a818cc230340 Mon Sep 17 00:00:00 2001
From: PGijsbers
Date: Thu, 25 Jun 2026 09:24:48 +0200
Subject: [PATCH 02/11] Add tests
---
src/database/tasks.py | 4 +-
tests/routers/openml/task_tag_test.py | 183 ++++++++++++++++++++++++++
2 files changed, 185 insertions(+), 2 deletions(-)
create mode 100644 tests/routers/openml/task_tag_test.py
diff --git a/src/database/tasks.py b/src/database/tasks.py
index cdd9248..b0f6010 100644
--- a/src/database/tasks.py
+++ b/src/database/tasks.py
@@ -156,8 +156,8 @@ async def get_task_type_inout_with_template(
)
-async def get_tags(id_: Identifier, expdb: AsyncConnection) -> list[str]:
- rows = await expdb.execute(
+async def get_tags(id_: Identifier, connection: AsyncConnection) -> list[str]:
+ rows = await connection.execute(
text(
"""
SELECT `tag`
diff --git a/tests/routers/openml/task_tag_test.py b/tests/routers/openml/task_tag_test.py
new file mode 100644
index 0000000..b42bb89
--- /dev/null
+++ b/tests/routers/openml/task_tag_test.py
@@ -0,0 +1,183 @@
+import re
+from http import HTTPStatus
+from typing import TYPE_CHECKING
+
+import pytest
+
+from core.conversions import nested_remove_single_element_list
+from core.errors import TagAlreadyExistsError, TaskNotFoundError
+from database.tasks import get_tags
+from database.users import User
+from routers.openml.tasks import tag_task
+from tests import constants
+from tests.users import ADMIN_USER, OWNER_USER, SOME_USER, ApiKey
+
+if TYPE_CHECKING:
+ import httpx
+ from sqlalchemy.ext.asyncio import AsyncConnection
+
+
+@pytest.mark.parametrize(
+ "key",
+ [None, ApiKey.INVALID],
+ ids=["no authentication", "invalid key"],
+)
+async def test_task_tag_rejects_unauthorized(key: ApiKey, py_api: httpx.AsyncClient) -> None:
+ apikey = "" if key is None else f"?api_key={key}"
+ response = await py_api.post(
+ f"/tasks/tag{apikey}",
+ json={"task_id": next(iter(constants.PRIVATE_DATASET_ID)), "tag": "test"},
+ )
+ assert response.status_code == HTTPStatus.UNAUTHORIZED
+
+
+# ── Direct call tests: tag_task ──
+
+
+@pytest.mark.mut
+@pytest.mark.parametrize(
+ "user",
+ [ADMIN_USER, SOME_USER, OWNER_USER],
+ ids=["administrator", "non-owner", "owner"],
+)
+async def test_task_tag(user: User, expdb_test: AsyncConnection) -> None:
+ task_id, tag = 2, "test"
+ result = await tag_task(
+ task_id=task_id,
+ tag=tag,
+ user=user,
+ expdb_db=expdb_test,
+ )
+ assert result == {"task_tag": {"id": str(task_id), "tag": [tag]}}
+
+ tags = await get_tags(id_=task_id, connection=expdb_test)
+ assert tag in tags
+
+
+@pytest.mark.mut
+async def test_task_tag_returns_existing_tags(expdb_test: AsyncConnection) -> None:
+ task_id, tag = 1, "test" # Task 1 already is tagged with 'OpenML100'
+ result = await tag_task(
+ task_id=task_id,
+ tag=tag,
+ user=ADMIN_USER,
+ expdb_db=expdb_test,
+ )
+ assert result == {"task_tag": {"id": str(task_id), "tag": ["OpenML100", tag]}}
+
+
+@pytest.mark.mut
+async def test_task_tag_fails_if_tag_exists(expdb_test: AsyncConnection) -> None:
+ task_id, tag = 1, "OpenML100" # Task 1 already is tagged with 'OpenML100'
+ with pytest.raises(TagAlreadyExistsError) as e:
+ await tag_task(
+ task_id=task_id,
+ tag=tag,
+ user=ADMIN_USER,
+ expdb_db=expdb_test,
+ )
+ assert str(task_id) in e.value.detail
+ assert tag in e.value.detail
+
+
+async def test_task_tag_fails_if_task_does_not_exist(expdb_test: AsyncConnection) -> None:
+ task_id = 1_000_000
+ with pytest.raises(TaskNotFoundError) as e:
+ await tag_task(
+ task_id=task_id,
+ tag="foo",
+ user=ADMIN_USER,
+ expdb_db=expdb_test,
+ )
+ assert str(task_id) in e.value.detail
+ task_not_found_in_tag_endpoint = 472
+ assert e.value.code == task_not_found_in_tag_endpoint
+
+
+# -- migration tests --
+
+
+@pytest.mark.mut
+@pytest.mark.parametrize(
+ "task_id",
+ [
+ *range(1, 10),
+ 101,
+ constants.SOME_DEACTIVATED_DATASET_ID,
+ constants.DATASET_ID_THAT_DOES_NOT_EXIST,
+ ],
+)
+@pytest.mark.parametrize(
+ "api_key",
+ [ApiKey.ADMIN, ApiKey.SOME_USER, ApiKey.OWNER_USER],
+ ids=["Administrator", "regular user", "possible owner"],
+)
+@pytest.mark.parametrize(
+ "tag",
+ ["study_14", "totally_new_tag_for_migration_testing"],
+ ids=["typically existing tag", "new tag"],
+)
+async def test_task_tag_response_is_identical(
+ task_id: int,
+ tag: str,
+ api_key: str,
+ py_api: httpx.AsyncClient,
+ php_api: httpx.AsyncClient,
+) -> None:
+ # PHP request must happen first to check state, can't parallelize
+ php_response = await php_api.post(
+ "/task/tag",
+ data={"api_key": api_key, "tag": tag, "task_id": task_id},
+ )
+ already_tagged = (
+ php_response.status_code == HTTPStatus.INTERNAL_SERVER_ERROR
+ and "already tagged" in php_response.json()["error"]["message"]
+ )
+ if not already_tagged:
+ # undo the tag, because we don't want to persist this change to the taskbase
+ # Sometimes a change is already committed to the taskbase even if an error occurs.
+ await php_api.post(
+ "/task/untag",
+ data={"api_key": api_key, "tag": tag, "task_id": task_id},
+ )
+ if (
+ php_response.status_code != HTTPStatus.OK
+ and php_response.json()["error"]["message"] == "An Elastic Search Exception occured."
+ ):
+ pytest.skip("Encountered Elastic Search error.")
+
+ py_response = await py_api.post(
+ f"/tasks/tag?api_key={api_key}",
+ json={"task_id": task_id, "tag": tag},
+ )
+
+ # RFC 9457: Tag conflict now returns 409 instead of 500
+ if php_response.status_code == HTTPStatus.INTERNAL_SERVER_ERROR and already_tagged:
+ assert py_response.status_code == HTTPStatus.CONFLICT
+ assert py_response.json()["code"] == php_response.json()["error"]["code"]
+ assert php_response.json()["error"]["message"] == "Entity already tagged by this tag."
+ assert re.match(
+ pattern=r"Task \d+ already tagged with " + f"'{tag}'.",
+ string=py_response.json()["detail"],
+ )
+ return
+
+ if py_response.status_code == HTTPStatus.NOT_FOUND:
+ assert php_response.status_code == HTTPStatus.PRECONDITION_FAILED
+ py_error = py_response.json()
+ php_error = php_response.json()["error"]
+ assert py_error["code"] == php_error["code"]
+ assert php_error["message"] == "Entity not found."
+ assert re.match(r"Task \d+ not found.", py_error["detail"])
+ return
+
+ assert py_response.status_code == php_response.status_code, php_response.json()
+ if py_response.status_code != HTTPStatus.OK:
+ assert py_response.json()["code"] == php_response.json()["error"]["code"]
+ assert py_response.json()["detail"] == php_response.json()["error"]["message"]
+ return
+
+ php_json = php_response.json()
+ py_json = py_response.json()
+ py_json = nested_remove_single_element_list(py_json)
+ assert py_json == php_json
From 3db1974e48053efb4a7289bcf0ebf1121da34a32 Mon Sep 17 00:00:00 2001
From: PGijsbers
Date: Thu, 25 Jun 2026 11:07:18 +0200
Subject: [PATCH 03/11] Toward unifying tag tests
---
tests/routers/openml/dataset_tag_test.py | 61 +-----------------
tests/routers/openml/tag_test_helper.py | 79 ++++++++++++++++++++++++
tests/routers/openml/task_tag_test.py | 63 +------------------
3 files changed, 84 insertions(+), 119 deletions(-)
create mode 100644 tests/routers/openml/tag_test_helper.py
diff --git a/tests/routers/openml/dataset_tag_test.py b/tests/routers/openml/dataset_tag_test.py
index bb12a74..16f719a 100644
--- a/tests/routers/openml/dataset_tag_test.py
+++ b/tests/routers/openml/dataset_tag_test.py
@@ -1,15 +1,14 @@
-import re
from http import HTTPStatus
from typing import TYPE_CHECKING
import pytest
-from core.conversions import nested_remove_single_element_list
from core.errors import DatasetNotFoundError, TagAlreadyExistsError
from database.datasets import get_tags_for
from database.users import User
from routers.openml.datasets import tag_dataset
from tests import constants
+from tests.routers.openml.tag_test_helper import assert_tag_response_is_identical
from tests.users import ADMIN_USER, OWNER_USER, SOME_USER, ApiKey
if TYPE_CHECKING:
@@ -124,60 +123,4 @@ async def test_dataset_tag_response_is_identical(
py_api: httpx.AsyncClient,
php_api: httpx.AsyncClient,
) -> None:
- # PHP request must happen first to check state, can't parallelize
- php_response = await php_api.post(
- "/data/tag",
- data={"api_key": api_key, "tag": tag, "data_id": dataset_id},
- )
- already_tagged = (
- php_response.status_code == HTTPStatus.INTERNAL_SERVER_ERROR
- and "already tagged" in php_response.json()["error"]["message"]
- )
- if not already_tagged:
- # undo the tag, because we don't want to persist this change to the database
- # Sometimes a change is already committed to the database even if an error occurs.
- await php_api.post(
- "/data/untag",
- data={"api_key": api_key, "tag": tag, "data_id": dataset_id},
- )
- if (
- php_response.status_code != HTTPStatus.OK
- and php_response.json()["error"]["message"] == "An Elastic Search Exception occured."
- ):
- pytest.skip("Encountered Elastic Search error.")
-
- py_response = await py_api.post(
- f"/datasets/tag?api_key={api_key}",
- json={"data_id": dataset_id, "tag": tag},
- )
-
- # RFC 9457: Tag conflict now returns 409 instead of 500
- if php_response.status_code == HTTPStatus.INTERNAL_SERVER_ERROR and already_tagged:
- assert py_response.status_code == HTTPStatus.CONFLICT
- assert py_response.json()["code"] == php_response.json()["error"]["code"]
- assert php_response.json()["error"]["message"] == "Entity already tagged by this tag."
- assert re.match(
- pattern=r"Dataset \d+ already tagged with " + f"'{tag}'.",
- string=py_response.json()["detail"],
- )
- return
-
- if py_response.status_code == HTTPStatus.NOT_FOUND:
- assert php_response.status_code == HTTPStatus.PRECONDITION_FAILED
- py_error = py_response.json()
- php_error = php_response.json()["error"]
- assert py_error["code"] == php_error["code"]
- assert php_error["message"] == "Entity not found."
- assert re.match(r"Dataset \d+ not found.", py_error["detail"])
- return
-
- assert py_response.status_code == php_response.status_code, php_response.json()
- if py_response.status_code != HTTPStatus.OK:
- assert py_response.json()["code"] == php_response.json()["error"]["code"]
- assert py_response.json()["detail"] == php_response.json()["error"]["message"]
- return
-
- php_json = php_response.json()
- py_json = py_response.json()
- py_json = nested_remove_single_element_list(py_json)
- assert py_json == php_json
+ await assert_tag_response_is_identical(dataset_id, tag, api_key, "dataset", py_api, php_api)
diff --git a/tests/routers/openml/tag_test_helper.py b/tests/routers/openml/tag_test_helper.py
new file mode 100644
index 0000000..4c4a99d
--- /dev/null
+++ b/tests/routers/openml/tag_test_helper.py
@@ -0,0 +1,79 @@
+import re
+from http import HTTPStatus
+from typing import TYPE_CHECKING
+
+import pytest
+
+from core.conversions import nested_remove_single_element_list
+
+if TYPE_CHECKING:
+ import httpx
+
+
+async def assert_tag_response_is_identical( # noqa: PLR0913
+ identifier: int,
+ tag: str,
+ api_key: str,
+ entity: str,
+ py_api: httpx.AsyncClient,
+ php_api: httpx.AsyncClient,
+) -> None:
+ php_alias = "data" if entity == "dataset" else entity
+ # PHP request must happen first to check state, can't parallelize
+ php_response = await php_api.post(
+ f"/{php_alias}/tag",
+ data={"api_key": api_key, "tag": tag, f"{php_alias}_id": identifier},
+ )
+ already_tagged = (
+ php_response.status_code == HTTPStatus.INTERNAL_SERVER_ERROR
+ and "already tagged" in php_response.json()["error"]["message"]
+ )
+ if not already_tagged:
+ # undo the tag, because we don't want to persist this change to the database
+ # Sometimes a change is already committed to the database even if an error occurs.
+ await php_api.post(
+ f"/{php_alias}/untag",
+ data={"api_key": api_key, "tag": tag, f"{php_alias}_id": identifier},
+ )
+ if (
+ php_response.status_code != HTTPStatus.OK
+ and php_response.json()["error"]["message"] == "An Elastic Search Exception occured."
+ ):
+ pytest.skip("Encountered Elastic Search error.")
+
+ entity_plural = f"{entity}s"
+ py_response = await py_api.post(
+ f"/{entity_plural}/tag?api_key={api_key}",
+ json={f"{php_alias}_id": identifier, "tag": tag},
+ )
+
+ # RFC 9457: Tag conflict now returns 409 instead of 500
+ if php_response.status_code == HTTPStatus.INTERNAL_SERVER_ERROR and already_tagged:
+ assert py_response.status_code == HTTPStatus.CONFLICT
+ assert py_response.json()["code"] == php_response.json()["error"]["code"]
+ assert php_response.json()["error"]["message"] == "Entity already tagged by this tag."
+ assert re.match(
+ pattern=rf"{entity.capitalize()} \d+ already tagged with " + f"'{tag}'.",
+ string=py_response.json()["detail"],
+ )
+ return
+
+ if py_response.status_code == HTTPStatus.NOT_FOUND:
+ assert php_response.status_code == HTTPStatus.PRECONDITION_FAILED
+ py_error = py_response.json()
+ php_error = php_response.json()["error"]
+ assert py_error["code"] == php_error["code"]
+ assert php_error["message"] == "Entity not found."
+ assert re.match(rf"{entity.capitalize()} \d+ not found.", py_error["detail"])
+ return
+
+ assert py_response.status_code == php_response.status_code, php_response.json()
+ if py_response.status_code != HTTPStatus.OK:
+ assert py_response.json()["code"] == php_response.json()["error"]["code"]
+ assert py_response.json()["detail"] == php_response.json()["error"]["message"]
+ return
+
+ php_json = php_response.json()
+ py_json = py_response.json()
+ py_json = nested_remove_single_element_list(py_json)
+ assert py_json == php_json
diff --git a/tests/routers/openml/task_tag_test.py b/tests/routers/openml/task_tag_test.py
index b42bb89..98bf457 100644
--- a/tests/routers/openml/task_tag_test.py
+++ b/tests/routers/openml/task_tag_test.py
@@ -1,15 +1,14 @@
-import re
from http import HTTPStatus
from typing import TYPE_CHECKING
import pytest
-from core.conversions import nested_remove_single_element_list
from core.errors import TagAlreadyExistsError, TaskNotFoundError
from database.tasks import get_tags
from database.users import User
from routers.openml.tasks import tag_task
from tests import constants
+from tests.routers.openml.tag_test_helper import assert_tag_response_is_identical
from tests.users import ADMIN_USER, OWNER_USER, SOME_USER, ApiKey
if TYPE_CHECKING:
@@ -114,7 +113,7 @@ async def test_task_tag_fails_if_task_does_not_exist(expdb_test: AsyncConnection
)
@pytest.mark.parametrize(
"tag",
- ["study_14", "totally_new_tag_for_migration_testing"],
+ ["OpenML100", "totally_new_tag_for_migration_testing"],
ids=["typically existing tag", "new tag"],
)
async def test_task_tag_response_is_identical(
@@ -124,60 +123,4 @@ async def test_task_tag_response_is_identical(
py_api: httpx.AsyncClient,
php_api: httpx.AsyncClient,
) -> None:
- # PHP request must happen first to check state, can't parallelize
- php_response = await php_api.post(
- "/task/tag",
- data={"api_key": api_key, "tag": tag, "task_id": task_id},
- )
- already_tagged = (
- php_response.status_code == HTTPStatus.INTERNAL_SERVER_ERROR
- and "already tagged" in php_response.json()["error"]["message"]
- )
- if not already_tagged:
- # undo the tag, because we don't want to persist this change to the taskbase
- # Sometimes a change is already committed to the taskbase even if an error occurs.
- await php_api.post(
- "/task/untag",
- data={"api_key": api_key, "tag": tag, "task_id": task_id},
- )
- if (
- php_response.status_code != HTTPStatus.OK
- and php_response.json()["error"]["message"] == "An Elastic Search Exception occured."
- ):
- pytest.skip("Encountered Elastic Search error.")
-
- py_response = await py_api.post(
- f"/tasks/tag?api_key={api_key}",
- json={"task_id": task_id, "tag": tag},
- )
-
- # RFC 9457: Tag conflict now returns 409 instead of 500
- if php_response.status_code == HTTPStatus.INTERNAL_SERVER_ERROR and already_tagged:
- assert py_response.status_code == HTTPStatus.CONFLICT
- assert py_response.json()["code"] == php_response.json()["error"]["code"]
- assert php_response.json()["error"]["message"] == "Entity already tagged by this tag."
- assert re.match(
- pattern=r"Task \d+ already tagged with " + f"'{tag}'.",
- string=py_response.json()["detail"],
- )
- return
-
- if py_response.status_code == HTTPStatus.NOT_FOUND:
- assert php_response.status_code == HTTPStatus.PRECONDITION_FAILED
- py_error = py_response.json()
- php_error = php_response.json()["error"]
- assert py_error["code"] == php_error["code"]
- assert php_error["message"] == "Entity not found."
- assert re.match(r"Task \d+ not found.", py_error["detail"])
- return
-
- assert py_response.status_code == php_response.status_code, php_response.json()
- if py_response.status_code != HTTPStatus.OK:
- assert py_response.json()["code"] == php_response.json()["error"]["code"]
- assert py_response.json()["detail"] == php_response.json()["error"]["message"]
- return
-
- php_json = php_response.json()
- py_json = py_response.json()
- py_json = nested_remove_single_element_list(py_json)
- assert py_json == php_json
+ await assert_tag_response_is_identical(task_id, tag, api_key, "task", py_api, php_api)
From e43b177c629c6189d6ba98fb815be969fa1ab0e4 Mon Sep 17 00:00:00 2001
From: PGijsbers
Date: Thu, 25 Jun 2026 15:57:18 +0200
Subject: [PATCH 04/11] Refactor the tag task tests
---
tests/conftest.py | 37 ++++++++++++++++-
tests/routers/openml/task_tag_test.py | 58 +++++++++++----------------
2 files changed, 59 insertions(+), 36 deletions(-)
diff --git a/tests/conftest.py b/tests/conftest.py
index 23a5295..bd8cb5b 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -1,8 +1,8 @@
import contextlib
import json
-from collections.abc import AsyncIterator, Callable, Iterable, Iterator
+from collections.abc import AsyncIterator, Awaitable, Callable, Iterable, Iterator
from pathlib import Path
-from typing import TYPE_CHECKING, Any, NamedTuple
+from typing import TYPE_CHECKING, Any, NamedTuple, Protocol
import _pytest.mark
import httpx
@@ -22,6 +22,7 @@
from database.setup import expdb_database, user_database
from main import create_api
from routers.dependencies import expdb_connection, userdb_connection
+from routers.types import Identifier
from tests.users import OWNER_USER
if TYPE_CHECKING:
@@ -139,6 +140,38 @@ def dataset_130() -> Iterator[dict[str, Any]]:
yield json.load(dataset_file)
+class Task(NamedTuple):
+ """To be replaced by an actual ORM class."""
+
+ id: Identifier
+ task_type: Identifier
+ creator: Identifier
+
+
+class TaskFactory(Protocol):
+ def __call__(
+ self, *, task_id: Identifier = 42_000, task_type: Identifier = 1, creator: Identifier = 1
+ ) -> Awaitable[Task]: ...
+
+
+@pytest.fixture
+async def task_factory(
+ expdb_test: AsyncConnection,
+) -> TaskFactory:
+ async def create_task(
+ *, task_id: Identifier = 42_000, task_type: Identifier = 1, creator: Identifier = 1
+ ) -> Task:
+ await expdb_test.execute(
+ text("""
+ INSERT INTO task (task_id, ttid, creator) VALUES (:task_id, :ttid, :creator);
+ """),
+ parameters={"task_id": task_id, "ttid": task_type, "creator": creator},
+ )
+ return Task(task_id, task_type, creator)
+
+ return create_task
+
+
class Flow(NamedTuple):
"""To be replaced by an actual ORM class."""
diff --git a/tests/routers/openml/task_tag_test.py b/tests/routers/openml/task_tag_test.py
index 98bf457..35ab120 100644
--- a/tests/routers/openml/task_tag_test.py
+++ b/tests/routers/openml/task_tag_test.py
@@ -8,6 +8,7 @@
from database.users import User
from routers.openml.tasks import tag_task
from tests import constants
+from tests.conftest import TaskFactory
from tests.routers.openml.tag_test_helper import assert_tag_response_is_identical
from tests.users import ADMIN_USER, OWNER_USER, SOME_USER, ApiKey
@@ -39,55 +40,44 @@ async def test_task_tag_rejects_unauthorized(key: ApiKey, py_api: httpx.AsyncCli
[ADMIN_USER, SOME_USER, OWNER_USER],
ids=["administrator", "non-owner", "owner"],
)
-async def test_task_tag(user: User, expdb_test: AsyncConnection) -> None:
- task_id, tag = 2, "test"
- result = await tag_task(
- task_id=task_id,
- tag=tag,
- user=user,
- expdb_db=expdb_test,
- )
- assert result == {"task_tag": {"id": str(task_id), "tag": [tag]}}
+async def test_task_tag(user: User, expdb_test: AsyncConnection, task_factory: TaskFactory) -> None:
+ tag = "test_task_tag"
+ task = await task_factory()
+ result = await tag_task(task_id=task.id, tag=tag, user=user, expdb_db=expdb_test)
+ assert result == {"task_tag": {"id": str(task.id), "tag": [tag]}}
- tags = await get_tags(id_=task_id, connection=expdb_test)
+ tags = await get_tags(id_=task.id, connection=expdb_test)
assert tag in tags
@pytest.mark.mut
-async def test_task_tag_returns_existing_tags(expdb_test: AsyncConnection) -> None:
- task_id, tag = 1, "test" # Task 1 already is tagged with 'OpenML100'
- result = await tag_task(
- task_id=task_id,
- tag=tag,
- user=ADMIN_USER,
- expdb_db=expdb_test,
- )
- assert result == {"task_tag": {"id": str(task_id), "tag": ["OpenML100", tag]}}
+async def test_task_tag_returns_existing_tags(
+ task_factory: TaskFactory, expdb_test: AsyncConnection
+) -> None:
+ task = await task_factory()
+ await tag_task(task_id=task.id, tag="first", user=ADMIN_USER, expdb_db=expdb_test)
+ result = await tag_task(task_id=task.id, tag="second", user=ADMIN_USER, expdb_db=expdb_test)
+ assert result == {"task_tag": {"id": str(task.id), "tag": ["first", "second"]}}
@pytest.mark.mut
-async def test_task_tag_fails_if_tag_exists(expdb_test: AsyncConnection) -> None:
- task_id, tag = 1, "OpenML100" # Task 1 already is tagged with 'OpenML100'
+async def test_task_tag_fails_if_tag_exists(
+ expdb_test: AsyncConnection, task_factory: TaskFactory
+) -> None:
+ tag = "fails_if_exist"
+ task = await task_factory()
+ await tag_task(task_id=task.id, tag=tag, user=ADMIN_USER, expdb_db=expdb_test)
+
with pytest.raises(TagAlreadyExistsError) as e:
- await tag_task(
- task_id=task_id,
- tag=tag,
- user=ADMIN_USER,
- expdb_db=expdb_test,
- )
- assert str(task_id) in e.value.detail
+ await tag_task(task_id=task.id, tag=tag, user=ADMIN_USER, expdb_db=expdb_test)
+ assert str(task.id) in e.value.detail
assert tag in e.value.detail
async def test_task_tag_fails_if_task_does_not_exist(expdb_test: AsyncConnection) -> None:
task_id = 1_000_000
with pytest.raises(TaskNotFoundError) as e:
- await tag_task(
- task_id=task_id,
- tag="foo",
- user=ADMIN_USER,
- expdb_db=expdb_test,
- )
+ await tag_task(task_id=task_id, tag="foo", user=ADMIN_USER, expdb_db=expdb_test)
assert str(task_id) in e.value.detail
task_not_found_in_tag_endpoint = 472
assert e.value.code == task_not_found_in_tag_endpoint
From 25992b2320befa2e8aed61a6d9ce2f7138891db7 Mon Sep 17 00:00:00 2001
From: PGijsbers
Date: Thu, 25 Jun 2026 16:19:28 +0200
Subject: [PATCH 05/11] Refactor dataset tag tests to not rely on database
state
---
tests/conftest.py | 44 +++++++++++++++++++++--
tests/routers/openml/dataset_tag_test.py | 45 ++++++++++++------------
2 files changed, 64 insertions(+), 25 deletions(-)
diff --git a/tests/conftest.py b/tests/conftest.py
index bd8cb5b..e604f0f 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -1,4 +1,5 @@
import contextlib
+import datetime
import json
from collections.abc import AsyncIterator, Awaitable, Callable, Iterable, Iterator
from pathlib import Path
@@ -150,7 +151,11 @@ class Task(NamedTuple):
class TaskFactory(Protocol):
def __call__(
- self, *, task_id: Identifier = 42_000, task_type: Identifier = 1, creator: Identifier = 1
+ self,
+ *,
+ task_id: Identifier = 42_000,
+ task_type: Identifier = 1,
+ creator: Identifier = OWNER_USER.user_id,
) -> Awaitable[Task]: ...
@@ -159,7 +164,10 @@ async def task_factory(
expdb_test: AsyncConnection,
) -> TaskFactory:
async def create_task(
- *, task_id: Identifier = 42_000, task_type: Identifier = 1, creator: Identifier = 1
+ *,
+ task_id: Identifier = 42_000,
+ task_type: Identifier = 1,
+ creator: Identifier = OWNER_USER.user_id,
) -> Task:
await expdb_test.execute(
text("""
@@ -172,6 +180,38 @@ async def create_task(
return create_task
+class DatasetFactory(Protocol):
+ def __call__(
+ self, *, dataset_id: Identifier = 42_000, creator: Identifier = OWNER_USER.user_id
+ ) -> Awaitable[Identifier]: ...
+
+
+@pytest.fixture
+async def dataset_factory(
+ expdb_test: AsyncConnection,
+) -> DatasetFactory:
+ async def create_dataset(
+ *, dataset_id: Identifier = 42_000, creator: Identifier = OWNER_USER.user_id
+ ) -> Identifier:
+ await expdb_test.execute(
+ text("""
+ INSERT INTO dataset
+ (did, uploader, name, version, format, upload_date, licence, url, visibility)
+ VALUES
+ (:dataset_id, :creator, 'dataset-name', 'dataset-version', 'dataset-format',
+ :now, 'public', 'dataset-url', 'public');
+ """),
+ parameters={
+ "dataset_id": dataset_id,
+ "creator": creator,
+ "now": datetime.datetime.now(tz=datetime.UTC),
+ },
+ )
+ return dataset_id
+
+ return create_dataset
+
+
class Flow(NamedTuple):
"""To be replaced by an actual ORM class."""
diff --git a/tests/routers/openml/dataset_tag_test.py b/tests/routers/openml/dataset_tag_test.py
index 16f719a..307927e 100644
--- a/tests/routers/openml/dataset_tag_test.py
+++ b/tests/routers/openml/dataset_tag_test.py
@@ -8,6 +8,7 @@
from database.users import User
from routers.openml.datasets import tag_dataset
from tests import constants
+from tests.conftest import DatasetFactory
from tests.routers.openml.tag_test_helper import assert_tag_response_is_identical
from tests.users import ADMIN_USER, OWNER_USER, SOME_USER, ApiKey
@@ -39,14 +40,12 @@ async def test_dataset_tag_rejects_unauthorized(key: ApiKey, py_api: httpx.Async
[ADMIN_USER, SOME_USER, OWNER_USER],
ids=["administrator", "non-owner", "owner"],
)
-async def test_dataset_tag(user: User, expdb_test: AsyncConnection) -> None:
- dataset_id, tag = next(iter(constants.PRIVATE_DATASET_ID)), "test"
- result = await tag_dataset(
- data_id=dataset_id,
- tag=tag,
- user=user,
- expdb_db=expdb_test,
- )
+async def test_dataset_tag(
+ user: User, expdb_test: AsyncConnection, dataset_factory: DatasetFactory
+) -> None:
+ dataset_id = await dataset_factory()
+ tag = "test_tag"
+ result = await tag_dataset(data_id=dataset_id, tag=tag, user=user, expdb_db=expdb_test)
assert result == {"data_tag": {"id": str(dataset_id), "tag": [tag]}}
tags = await get_tags_for(id_=dataset_id, connection=expdb_test)
@@ -54,27 +53,27 @@ async def test_dataset_tag(user: User, expdb_test: AsyncConnection) -> None:
@pytest.mark.mut
-async def test_dataset_tag_returns_existing_tags(expdb_test: AsyncConnection) -> None:
- dataset_id, tag = 1, "test" # Dataset 1 already is tagged with 'study_14'
+async def test_dataset_tag_returns_existing_tags(
+ expdb_test: AsyncConnection, dataset_factory: DatasetFactory
+) -> None:
+ dataset_id = await dataset_factory()
+ await tag_dataset(data_id=dataset_id, tag="first", user=OWNER_USER, expdb_db=expdb_test)
result = await tag_dataset(
- data_id=dataset_id,
- tag=tag,
- user=ADMIN_USER,
- expdb_db=expdb_test,
+ data_id=dataset_id, tag="second", user=ADMIN_USER, expdb_db=expdb_test
)
- assert result == {"data_tag": {"id": str(dataset_id), "tag": ["study_14", tag]}}
+ assert result == {"data_tag": {"id": str(dataset_id), "tag": ["first", "second"]}}
@pytest.mark.mut
-async def test_dataset_tag_fails_if_tag_exists(expdb_test: AsyncConnection) -> None:
- dataset_id, tag = 1, "study_14" # Dataset 1 already is tagged with 'study_14'
+async def test_dataset_tag_fails_if_tag_exists(
+ expdb_test: AsyncConnection, dataset_factory: DatasetFactory
+) -> None:
+ tag = "repeated_tag"
+ dataset_id = await dataset_factory()
+ await tag_dataset(data_id=dataset_id, tag=tag, user=OWNER_USER, expdb_db=expdb_test)
+
with pytest.raises(TagAlreadyExistsError) as e:
- await tag_dataset(
- data_id=dataset_id,
- tag=tag,
- user=ADMIN_USER,
- expdb_db=expdb_test,
- )
+ await tag_dataset(data_id=dataset_id, tag=tag, user=ADMIN_USER, expdb_db=expdb_test)
assert str(dataset_id) in e.value.detail
assert tag in e.value.detail
From 3078682fda368bfa6b37061e226282b1c4ea0436 Mon Sep 17 00:00:00 2001
From: PGijsbers
Date: Thu, 25 Jun 2026 16:28:38 +0200
Subject: [PATCH 06/11] indicate that the identifier does not matter to the
test
---
tests/routers/openml/dataset_tag_test.py | 3 ++-
tests/routers/openml/task_tag_test.py | 3 ++-
2 files changed, 4 insertions(+), 2 deletions(-)
diff --git a/tests/routers/openml/dataset_tag_test.py b/tests/routers/openml/dataset_tag_test.py
index 307927e..42ef658 100644
--- a/tests/routers/openml/dataset_tag_test.py
+++ b/tests/routers/openml/dataset_tag_test.py
@@ -24,9 +24,10 @@
)
async def test_dataset_tag_rejects_unauthorized(key: ApiKey, py_api: httpx.AsyncClient) -> None:
apikey = "" if key is None else f"?api_key={key}"
+ any_dataset_identifier = 1
response = await py_api.post(
f"/datasets/tag{apikey}",
- json={"data_id": next(iter(constants.PRIVATE_DATASET_ID)), "tag": "test"},
+ json={"data_id": any_dataset_identifier, "tag": "test"},
)
assert response.status_code == HTTPStatus.UNAUTHORIZED
diff --git a/tests/routers/openml/task_tag_test.py b/tests/routers/openml/task_tag_test.py
index 35ab120..ba76b61 100644
--- a/tests/routers/openml/task_tag_test.py
+++ b/tests/routers/openml/task_tag_test.py
@@ -24,9 +24,10 @@
)
async def test_task_tag_rejects_unauthorized(key: ApiKey, py_api: httpx.AsyncClient) -> None:
apikey = "" if key is None else f"?api_key={key}"
+ any_task_id = 1
response = await py_api.post(
f"/tasks/tag{apikey}",
- json={"task_id": next(iter(constants.PRIVATE_DATASET_ID)), "tag": "test"},
+ json={"task_id": any_task_id, "tag": "test"},
)
assert response.status_code == HTTPStatus.UNAUTHORIZED
From 0396bca80e9cff37e1e00321adfc7a34fb878e93 Mon Sep 17 00:00:00 2001
From: PGijsbers
Date: Thu, 25 Jun 2026 16:30:32 +0200
Subject: [PATCH 07/11] generalize the name since it should be valid for all
entities
---
tests/constants.py | 2 +-
tests/routers/openml/dataset_tag_test.py | 2 +-
tests/routers/openml/task_tag_test.py | 3 +--
3 files changed, 3 insertions(+), 4 deletions(-)
diff --git a/tests/constants.py b/tests/constants.py
index 5563dec..991c2e9 100644
--- a/tests/constants.py
+++ b/tests/constants.py
@@ -1,4 +1,4 @@
-DATASET_ID_THAT_DOES_NOT_EXIST = 9_9999_999
+ENTITY_ID_THAT_DOES_NOT_EXIST = 9_9999_999
SOME_PRIVATE_DATASET_ID = 130
PRIVATE_DATASET_ID = {130}
IN_PREPARATION_ID = {33, 161, 162, 163}
diff --git a/tests/routers/openml/dataset_tag_test.py b/tests/routers/openml/dataset_tag_test.py
index 42ef658..dd80966 100644
--- a/tests/routers/openml/dataset_tag_test.py
+++ b/tests/routers/openml/dataset_tag_test.py
@@ -103,7 +103,7 @@ async def test_dataset_tag_fails_if_dataset_does_not_exist(expdb_test: AsyncConn
*range(1, 10),
101,
constants.SOME_DEACTIVATED_DATASET_ID,
- constants.DATASET_ID_THAT_DOES_NOT_EXIST,
+ constants.ENTITY_ID_THAT_DOES_NOT_EXIST,
],
)
@pytest.mark.parametrize(
diff --git a/tests/routers/openml/task_tag_test.py b/tests/routers/openml/task_tag_test.py
index ba76b61..a2d9fc9 100644
--- a/tests/routers/openml/task_tag_test.py
+++ b/tests/routers/openml/task_tag_test.py
@@ -93,8 +93,7 @@ async def test_task_tag_fails_if_task_does_not_exist(expdb_test: AsyncConnection
[
*range(1, 10),
101,
- constants.SOME_DEACTIVATED_DATASET_ID,
- constants.DATASET_ID_THAT_DOES_NOT_EXIST,
+ constants.ENTITY_ID_THAT_DOES_NOT_EXIST,
],
)
@pytest.mark.parametrize(
From d7a406a3b39633dd8e2637a2bf5e8e1810aadf0e Mon Sep 17 00:00:00 2001
From: PGijsbers
Date: Fri, 26 Jun 2026 09:12:21 +0200
Subject: [PATCH 08/11] remove dead code
---
tests/routers/openml/tag_test_helper.py | 4 ----
1 file changed, 4 deletions(-)
diff --git a/tests/routers/openml/tag_test_helper.py b/tests/routers/openml/tag_test_helper.py
index 4c4a99d..ba3560d 100644
--- a/tests/routers/openml/tag_test_helper.py
+++ b/tests/routers/openml/tag_test_helper.py
@@ -68,10 +68,6 @@ async def assert_tag_response_is_identical( # noqa: PLR0913
return
assert py_response.status_code == php_response.status_code, php_response.json()
- if py_response.status_code != HTTPStatus.OK:
- assert py_response.json()["code"] == php_response.json()["error"]["code"]
- assert py_response.json()["detail"] == php_response.json()["error"]["message"]
- return
php_json = php_response.json()
py_json = py_response.json()
From c11d5aece94b799d9d06800fff477c5a69689b89 Mon Sep 17 00:00:00 2001
From: PGijsbers
Date: Fri, 26 Jun 2026 09:14:39 +0200
Subject: [PATCH 09/11] Make task not found in tag error code a constant
---
src/core/errors.py | 2 ++
tests/routers/openml/task_tag_test.py | 4 ++--
2 files changed, 4 insertions(+), 2 deletions(-)
diff --git a/src/core/errors.py b/src/core/errors.py
index f8de607..2ae6b6f 100644
--- a/src/core/errors.py
+++ b/src/core/errors.py
@@ -251,6 +251,8 @@ class AccountHasResourcesError(ProblemDetailError):
# Tag Errors
# =============================================================================
+TASK_NOT_FOUND_DURING_TAG = 472
+
class TagAlreadyExistsError(ProblemDetailError):
"""Raised when trying to add a tag that already exists."""
diff --git a/tests/routers/openml/task_tag_test.py b/tests/routers/openml/task_tag_test.py
index a2d9fc9..087f76e 100644
--- a/tests/routers/openml/task_tag_test.py
+++ b/tests/routers/openml/task_tag_test.py
@@ -3,7 +3,7 @@
import pytest
-from core.errors import TagAlreadyExistsError, TaskNotFoundError
+from core.errors import TASK_NOT_FOUND_DURING_TAG, TagAlreadyExistsError, TaskNotFoundError
from database.tasks import get_tags
from database.users import User
from routers.openml.tasks import tag_task
@@ -80,7 +80,7 @@ async def test_task_tag_fails_if_task_does_not_exist(expdb_test: AsyncConnection
with pytest.raises(TaskNotFoundError) as e:
await tag_task(task_id=task_id, tag="foo", user=ADMIN_USER, expdb_db=expdb_test)
assert str(task_id) in e.value.detail
- task_not_found_in_tag_endpoint = 472
+ task_not_found_in_tag_endpoint = TASK_NOT_FOUND_DURING_TAG
assert e.value.code == task_not_found_in_tag_endpoint
From e19bd0d77874e34b18ee0f35ad988b3c3bfe1714 Mon Sep 17 00:00:00 2001
From: PGijsbers
Date: Fri, 26 Jun 2026 09:27:12 +0200
Subject: [PATCH 10/11] Make the dataset and task factories callable multiple
times
---
tests/conftest.py | 24 +++++++++++++++++++++---
1 file changed, 21 insertions(+), 3 deletions(-)
diff --git a/tests/conftest.py b/tests/conftest.py
index e604f0f..4191c47 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -159,16 +159,32 @@ def __call__(
) -> Awaitable[Task]: ...
+def _create_identifier_factory() -> Callable[[], Identifier]:
+ _identifier_counter: Identifier = 10_000_000
+
+ def _get() -> Identifier:
+ nonlocal _identifier_counter
+ _identifier_counter += 1
+ return _identifier_counter
+
+ return _get
+
+
+_identifier_factory = _create_identifier_factory()
+
+
@pytest.fixture
async def task_factory(
expdb_test: AsyncConnection,
) -> TaskFactory:
async def create_task(
*,
- task_id: Identifier = 42_000,
+ task_id: Identifier | None = None,
task_type: Identifier = 1,
creator: Identifier = OWNER_USER.user_id,
) -> Task:
+ task_id = task_id or _identifier_factory()
+
await expdb_test.execute(
text("""
INSERT INTO task (task_id, ttid, creator) VALUES (:task_id, :ttid, :creator);
@@ -191,20 +207,22 @@ async def dataset_factory(
expdb_test: AsyncConnection,
) -> DatasetFactory:
async def create_dataset(
- *, dataset_id: Identifier = 42_000, creator: Identifier = OWNER_USER.user_id
+ *, dataset_id: Identifier | None = None, creator: Identifier = OWNER_USER.user_id
) -> Identifier:
+ dataset_id = dataset_id or _identifier_factory()
await expdb_test.execute(
text("""
INSERT INTO dataset
(did, uploader, name, version, format, upload_date, licence, url, visibility)
VALUES
- (:dataset_id, :creator, 'dataset-name', 'dataset-version', 'dataset-format',
+ (:dataset_id, :creator, :name, 'dataset-version', 'dataset-format',
:now, 'public', 'dataset-url', 'public');
"""),
parameters={
"dataset_id": dataset_id,
"creator": creator,
"now": datetime.datetime.now(tz=datetime.UTC),
+ "name": f"dataset-name-{dataset_id}",
},
)
return dataset_id
From 46ae998450480fe50e4d7d3910762a22f5fa153f Mon Sep 17 00:00:00 2001
From: PGijsbers
Date: Fri, 26 Jun 2026 09:35:10 +0200
Subject: [PATCH 11/11] Simplify control flow, fix comments
---
tests/routers/openml/tag_test_helper.py | 8 ++++----
1 file changed, 4 insertions(+), 4 deletions(-)
diff --git a/tests/routers/openml/tag_test_helper.py b/tests/routers/openml/tag_test_helper.py
index ba3560d..eb8fb97 100644
--- a/tests/routers/openml/tag_test_helper.py
+++ b/tests/routers/openml/tag_test_helper.py
@@ -29,8 +29,8 @@ async def assert_tag_response_is_identical( # noqa: PLR0913
and "already tagged" in php_response.json()["error"]["message"]
)
if not already_tagged:
- # undo the tag, because we don't want to persist this change to the taskbase
- # Sometimes a change is already committed to the taskbase even if an error occurs.
+ # Undo the tag, because we don't want to persist this change to the database.
+ # Sometimes a change is already committed to the database even if an error occurs.
await php_api.post(
f"/{php_alias}/untag",
data={"api_key": api_key, "tag": tag, f"{php_alias}_id": identifier},
@@ -47,8 +47,8 @@ async def assert_tag_response_is_identical( # noqa: PLR0913
json={f"{php_alias}_id": identifier, "tag": tag},
)
- # RFC 9457: Tag conflict now returns 409 instead of 500
- if php_response.status_code == HTTPStatus.INTERNAL_SERVER_ERROR and already_tagged:
+ # RFC 9457: Tag conflict now returns 409 (CONFLICT) instead of 500 (INTERNAL SERVER ERROR)
+ if already_tagged:
assert py_response.status_code == HTTPStatus.CONFLICT
assert py_response.json()["code"] == php_response.json()["error"]["code"]
assert php_response.json()["error"]["message"] == "Entity already tagged by this tag."