diff --git a/src/core/errors.py b/src/core/errors.py index f8de607..2ae6b6f 100644 --- a/src/core/errors.py +++ b/src/core/errors.py @@ -251,6 +251,8 @@ class AccountHasResourcesError(ProblemDetailError): # Tag Errors # ============================================================================= +TASK_NOT_FOUND_DURING_TAG = 472 + class TagAlreadyExistsError(ProblemDetailError): """Raised when trying to add a tag that already exists.""" diff --git a/src/database/tasks.py b/src/database/tasks.py index 788fc93..b0f6010 100644 --- a/src/database/tasks.py +++ b/src/database/tasks.py @@ -2,8 +2,15 @@ from typing import TYPE_CHECKING, cast from sqlalchemy import Row, text +from sqlalchemy.exc import IntegrityError -from routers.types import Identifier +from database.exceptions import ( + _DUPLICATE_ENTRY, + _FOREIGN_KEY_CONSTRAINT_FAILED, + DuplicatePrimaryKeyError, + ForeignKeyConstraintError, +) +from routers.types import Identifier, TagString if TYPE_CHECKING: from sqlalchemy.ext.asyncio import AsyncConnection @@ -149,8 +156,8 @@ async def get_task_type_inout_with_template( ) -async def get_tags(id_: Identifier, expdb: AsyncConnection) -> list[str]: - rows = await expdb.execute( +async def get_tags(id_: Identifier, connection: AsyncConnection) -> list[str]: + rows = await connection.execute( text( """ SELECT `tag` @@ -162,3 +169,33 @@ async def get_tags(id_: Identifier, expdb: AsyncConnection) -> list[str]: ) tag_rows = rows.all() return [row.tag for row in tag_rows] + + +async def tag( + id_: Identifier, + tag_: TagString, + *, + user_id: Identifier, + connection: AsyncConnection, +) -> None: + try: + await connection.execute( + text( + """ + INSERT INTO task_tag(`id`, `tag`, `uploader`) + VALUES (:task_id, :tag, :user_id) + """, + ), + parameters={ + "task_id": id_, + "user_id": user_id, + "tag": tag_, + }, + ) + except IntegrityError as e: + code, msg = e.orig.args + if code == _FOREIGN_KEY_CONSTRAINT_FAILED: + raise ForeignKeyConstraintError(msg) from e + if code == _DUPLICATE_ENTRY: + raise DuplicatePrimaryKeyError(msg) from e + raise diff --git a/src/routers/openml/tasks.py b/src/routers/openml/tasks.py index 8faa898..f68c4b8 100644 --- a/src/routers/openml/tasks.py +++ b/src/routers/openml/tasks.py @@ -6,13 +6,16 @@ import xmltodict from fastapi import APIRouter, Body, Depends +from loguru import logger from sqlalchemy import bindparam, text import database.datasets import database.tasks from config import get_config -from core.errors import InternalError, NoResultsError, TaskNotFoundError -from routers.dependencies import Pagination, expdb_connection +from core.errors import InternalError, NoResultsError, TagAlreadyExistsError, TaskNotFoundError +from database.exceptions import DuplicatePrimaryKeyError, ForeignKeyConstraintError +from database.users import User +from routers.dependencies import Pagination, expdb_connection, fetch_user_or_raise from routers.types import ( CasualString128, Identifier, @@ -31,6 +34,31 @@ type JSON = dict[str, "JSON"] | list["JSON"] | str | int | float | bool | None +@router.post(path="/tag") +async def tag_task( + task_id: Annotated[Identifier, Body()], + tag: Annotated[TagString, Body()], + user: Annotated[User, Depends(fetch_user_or_raise)], + expdb_db: Annotated[AsyncConnection, Depends(expdb_connection)], +) -> dict[str, dict[str, Any]]: + try: + await database.tasks.tag(task_id, tag, user_id=user.user_id, connection=expdb_db) + except ForeignKeyConstraintError: + msg = f"Task {task_id} not found." + raise TaskNotFoundError(msg, code=472) from None + except DuplicatePrimaryKeyError: + msg = f"Task {task_id} already tagged with {tag!r}." + raise TagAlreadyExistsError(msg) from None + + logger.info("Task {task_id} tagged '{tag}'.", task_id=task_id, tag=tag) + + tags = await database.tasks.get_tags(task_id, expdb_db) + + return { + "task_tag": {"id": str(task_id), "tag": tags}, + } + + def convert_template_xml_to_json(xml_template: str) -> dict[str, JSON]: json_template = xmltodict.parse(xml_template.replace("oml:", "")) json_str = json.dumps(json_template) diff --git a/tests/conftest.py b/tests/conftest.py index 23a5295..4191c47 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,8 +1,9 @@ import contextlib +import datetime import json -from collections.abc import AsyncIterator, Callable, Iterable, Iterator +from collections.abc import AsyncIterator, Awaitable, Callable, Iterable, Iterator from pathlib import Path -from typing import TYPE_CHECKING, Any, NamedTuple +from typing import TYPE_CHECKING, Any, NamedTuple, Protocol import _pytest.mark import httpx @@ -22,6 +23,7 @@ from database.setup import expdb_database, user_database from main import create_api from routers.dependencies import expdb_connection, userdb_connection +from routers.types import Identifier from tests.users import OWNER_USER if TYPE_CHECKING: @@ -139,6 +141,95 @@ def dataset_130() -> Iterator[dict[str, Any]]: yield json.load(dataset_file) +class Task(NamedTuple): + """To be replaced by an actual ORM class.""" + + id: Identifier + task_type: Identifier + creator: Identifier + + +class TaskFactory(Protocol): + def __call__( + self, + *, + task_id: Identifier = 42_000, + task_type: Identifier = 1, + creator: Identifier = OWNER_USER.user_id, + ) -> Awaitable[Task]: ... + + +def _create_identifier_factory() -> Callable[[], Identifier]: + _identifier_counter: Identifier = 10_000_000 + + def _get() -> Identifier: + nonlocal _identifier_counter + _identifier_counter += 1 + return _identifier_counter + + return _get + + +_identifier_factory = _create_identifier_factory() + + +@pytest.fixture +async def task_factory( + expdb_test: AsyncConnection, +) -> TaskFactory: + async def create_task( + *, + task_id: Identifier | None = None, + task_type: Identifier = 1, + creator: Identifier = OWNER_USER.user_id, + ) -> Task: + task_id = task_id or _identifier_factory() + + await expdb_test.execute( + text(""" + INSERT INTO task (task_id, ttid, creator) VALUES (:task_id, :ttid, :creator); + """), + parameters={"task_id": task_id, "ttid": task_type, "creator": creator}, + ) + return Task(task_id, task_type, creator) + + return create_task + + +class DatasetFactory(Protocol): + def __call__( + self, *, dataset_id: Identifier = 42_000, creator: Identifier = OWNER_USER.user_id + ) -> Awaitable[Identifier]: ... + + +@pytest.fixture +async def dataset_factory( + expdb_test: AsyncConnection, +) -> DatasetFactory: + async def create_dataset( + *, dataset_id: Identifier | None = None, creator: Identifier = OWNER_USER.user_id + ) -> Identifier: + dataset_id = dataset_id or _identifier_factory() + await expdb_test.execute( + text(""" + INSERT INTO dataset + (did, uploader, name, version, format, upload_date, licence, url, visibility) + VALUES + (:dataset_id, :creator, :name, 'dataset-version', 'dataset-format', + :now, 'public', 'dataset-url', 'public'); + """), + parameters={ + "dataset_id": dataset_id, + "creator": creator, + "now": datetime.datetime.now(tz=datetime.UTC), + "name": f"dataset-name-{dataset_id}", + }, + ) + return dataset_id + + return create_dataset + + class Flow(NamedTuple): """To be replaced by an actual ORM class.""" diff --git a/tests/constants.py b/tests/constants.py index 5563dec..991c2e9 100644 --- a/tests/constants.py +++ b/tests/constants.py @@ -1,4 +1,4 @@ -DATASET_ID_THAT_DOES_NOT_EXIST = 9_9999_999 +ENTITY_ID_THAT_DOES_NOT_EXIST = 9_9999_999 SOME_PRIVATE_DATASET_ID = 130 PRIVATE_DATASET_ID = {130} IN_PREPARATION_ID = {33, 161, 162, 163} diff --git a/tests/routers/openml/dataset_tag_test.py b/tests/routers/openml/dataset_tag_test.py index bb12a74..dd80966 100644 --- a/tests/routers/openml/dataset_tag_test.py +++ b/tests/routers/openml/dataset_tag_test.py @@ -1,15 +1,15 @@ -import re from http import HTTPStatus from typing import TYPE_CHECKING import pytest -from core.conversions import nested_remove_single_element_list from core.errors import DatasetNotFoundError, TagAlreadyExistsError from database.datasets import get_tags_for from database.users import User from routers.openml.datasets import tag_dataset from tests import constants +from tests.conftest import DatasetFactory +from tests.routers.openml.tag_test_helper import assert_tag_response_is_identical from tests.users import ADMIN_USER, OWNER_USER, SOME_USER, ApiKey if TYPE_CHECKING: @@ -24,9 +24,10 @@ ) async def test_dataset_tag_rejects_unauthorized(key: ApiKey, py_api: httpx.AsyncClient) -> None: apikey = "" if key is None else f"?api_key={key}" + any_dataset_identifier = 1 response = await py_api.post( f"/datasets/tag{apikey}", - json={"data_id": next(iter(constants.PRIVATE_DATASET_ID)), "tag": "test"}, + json={"data_id": any_dataset_identifier, "tag": "test"}, ) assert response.status_code == HTTPStatus.UNAUTHORIZED @@ -40,14 +41,12 @@ async def test_dataset_tag_rejects_unauthorized(key: ApiKey, py_api: httpx.Async [ADMIN_USER, SOME_USER, OWNER_USER], ids=["administrator", "non-owner", "owner"], ) -async def test_dataset_tag(user: User, expdb_test: AsyncConnection) -> None: - dataset_id, tag = next(iter(constants.PRIVATE_DATASET_ID)), "test" - result = await tag_dataset( - data_id=dataset_id, - tag=tag, - user=user, - expdb_db=expdb_test, - ) +async def test_dataset_tag( + user: User, expdb_test: AsyncConnection, dataset_factory: DatasetFactory +) -> None: + dataset_id = await dataset_factory() + tag = "test_tag" + result = await tag_dataset(data_id=dataset_id, tag=tag, user=user, expdb_db=expdb_test) assert result == {"data_tag": {"id": str(dataset_id), "tag": [tag]}} tags = await get_tags_for(id_=dataset_id, connection=expdb_test) @@ -55,27 +54,27 @@ async def test_dataset_tag(user: User, expdb_test: AsyncConnection) -> None: @pytest.mark.mut -async def test_dataset_tag_returns_existing_tags(expdb_test: AsyncConnection) -> None: - dataset_id, tag = 1, "test" # Dataset 1 already is tagged with 'study_14' +async def test_dataset_tag_returns_existing_tags( + expdb_test: AsyncConnection, dataset_factory: DatasetFactory +) -> None: + dataset_id = await dataset_factory() + await tag_dataset(data_id=dataset_id, tag="first", user=OWNER_USER, expdb_db=expdb_test) result = await tag_dataset( - data_id=dataset_id, - tag=tag, - user=ADMIN_USER, - expdb_db=expdb_test, + data_id=dataset_id, tag="second", user=ADMIN_USER, expdb_db=expdb_test ) - assert result == {"data_tag": {"id": str(dataset_id), "tag": ["study_14", tag]}} + assert result == {"data_tag": {"id": str(dataset_id), "tag": ["first", "second"]}} @pytest.mark.mut -async def test_dataset_tag_fails_if_tag_exists(expdb_test: AsyncConnection) -> None: - dataset_id, tag = 1, "study_14" # Dataset 1 already is tagged with 'study_14' +async def test_dataset_tag_fails_if_tag_exists( + expdb_test: AsyncConnection, dataset_factory: DatasetFactory +) -> None: + tag = "repeated_tag" + dataset_id = await dataset_factory() + await tag_dataset(data_id=dataset_id, tag=tag, user=OWNER_USER, expdb_db=expdb_test) + with pytest.raises(TagAlreadyExistsError) as e: - await tag_dataset( - data_id=dataset_id, - tag=tag, - user=ADMIN_USER, - expdb_db=expdb_test, - ) + await tag_dataset(data_id=dataset_id, tag=tag, user=ADMIN_USER, expdb_db=expdb_test) assert str(dataset_id) in e.value.detail assert tag in e.value.detail @@ -104,7 +103,7 @@ async def test_dataset_tag_fails_if_dataset_does_not_exist(expdb_test: AsyncConn *range(1, 10), 101, constants.SOME_DEACTIVATED_DATASET_ID, - constants.DATASET_ID_THAT_DOES_NOT_EXIST, + constants.ENTITY_ID_THAT_DOES_NOT_EXIST, ], ) @pytest.mark.parametrize( @@ -124,60 +123,4 @@ async def test_dataset_tag_response_is_identical( py_api: httpx.AsyncClient, php_api: httpx.AsyncClient, ) -> None: - # PHP request must happen first to check state, can't parallelize - php_response = await php_api.post( - "/data/tag", - data={"api_key": api_key, "tag": tag, "data_id": dataset_id}, - ) - already_tagged = ( - php_response.status_code == HTTPStatus.INTERNAL_SERVER_ERROR - and "already tagged" in php_response.json()["error"]["message"] - ) - if not already_tagged: - # undo the tag, because we don't want to persist this change to the database - # Sometimes a change is already committed to the database even if an error occurs. - await php_api.post( - "/data/untag", - data={"api_key": api_key, "tag": tag, "data_id": dataset_id}, - ) - if ( - php_response.status_code != HTTPStatus.OK - and php_response.json()["error"]["message"] == "An Elastic Search Exception occured." - ): - pytest.skip("Encountered Elastic Search error.") - - py_response = await py_api.post( - f"/datasets/tag?api_key={api_key}", - json={"data_id": dataset_id, "tag": tag}, - ) - - # RFC 9457: Tag conflict now returns 409 instead of 500 - if php_response.status_code == HTTPStatus.INTERNAL_SERVER_ERROR and already_tagged: - assert py_response.status_code == HTTPStatus.CONFLICT - assert py_response.json()["code"] == php_response.json()["error"]["code"] - assert php_response.json()["error"]["message"] == "Entity already tagged by this tag." - assert re.match( - pattern=r"Dataset \d+ already tagged with " + f"'{tag}'.", - string=py_response.json()["detail"], - ) - return - - if py_response.status_code == HTTPStatus.NOT_FOUND: - assert php_response.status_code == HTTPStatus.PRECONDITION_FAILED - py_error = py_response.json() - php_error = php_response.json()["error"] - assert py_error["code"] == php_error["code"] - assert php_error["message"] == "Entity not found." - assert re.match(r"Dataset \d+ not found.", py_error["detail"]) - return - - assert py_response.status_code == php_response.status_code, php_response.json() - if py_response.status_code != HTTPStatus.OK: - assert py_response.json()["code"] == php_response.json()["error"]["code"] - assert py_response.json()["detail"] == php_response.json()["error"]["message"] - return - - php_json = php_response.json() - py_json = py_response.json() - py_json = nested_remove_single_element_list(py_json) - assert py_json == php_json + await assert_tag_response_is_identical(dataset_id, tag, api_key, "dataset", py_api, php_api) diff --git a/tests/routers/openml/tag_test_helper.py b/tests/routers/openml/tag_test_helper.py new file mode 100644 index 0000000..eb8fb97 --- /dev/null +++ b/tests/routers/openml/tag_test_helper.py @@ -0,0 +1,75 @@ +import re +from http import HTTPStatus +from typing import TYPE_CHECKING + +import pytest + +from core.conversions import nested_remove_single_element_list + +if TYPE_CHECKING: + import httpx + + +async def assert_tag_response_is_identical( # noqa: PLR0913 + identifier: int, + tag: str, + api_key: str, + entity: str, + py_api: httpx.AsyncClient, + php_api: httpx.AsyncClient, +) -> None: + php_alias = "data" if entity == "dataset" else entity + # PHP request must happen first to check state, can't parallelize + php_response = await php_api.post( + f"/{php_alias}/tag", + data={"api_key": api_key, "tag": tag, f"{php_alias}_id": identifier}, + ) + already_tagged = ( + php_response.status_code == HTTPStatus.INTERNAL_SERVER_ERROR + and "already tagged" in php_response.json()["error"]["message"] + ) + if not already_tagged: + # undo the tag, because we don't want to persist this change to the database + # Sometimes a change is already committed to the database even if an error occurs. + await php_api.post( + f"/{php_alias}/untag", + data={"api_key": api_key, "tag": tag, f"{php_alias}_id": identifier}, + ) + if ( + php_response.status_code != HTTPStatus.OK + and php_response.json()["error"]["message"] == "An Elastic Search Exception occured." + ): + pytest.skip("Encountered Elastic Search error.") + + entity_plural = f"{entity}s" + py_response = await py_api.post( + f"/{entity_plural}/tag?api_key={api_key}", + json={f"{php_alias}_id": identifier, "tag": tag}, + ) + + # RFC 9457: Tag conflict now returns 409 (CONFLICT) instead of 500 (INTERNAL SERVER ERROR) + if already_tagged: + assert py_response.status_code == HTTPStatus.CONFLICT + assert py_response.json()["code"] == php_response.json()["error"]["code"] + assert php_response.json()["error"]["message"] == "Entity already tagged by this tag." + assert re.match( + pattern=rf"{entity.capitalize()} \d+ already tagged with " + f"'{tag}'.", + string=py_response.json()["detail"], + ) + return + + if py_response.status_code == HTTPStatus.NOT_FOUND: + assert php_response.status_code == HTTPStatus.PRECONDITION_FAILED + py_error = py_response.json() + php_error = php_response.json()["error"] + assert py_error["code"] == php_error["code"] + assert php_error["message"] == "Entity not found." + assert re.match(rf"{entity.capitalize()} \d+ not found.", py_error["detail"]) + return + + assert py_response.status_code == php_response.status_code, php_response.json() + + php_json = php_response.json() + py_json = py_response.json() + py_json = nested_remove_single_element_list(py_json) + assert py_json == php_json diff --git a/tests/routers/openml/task_tag_test.py b/tests/routers/openml/task_tag_test.py new file mode 100644 index 0000000..087f76e --- /dev/null +++ b/tests/routers/openml/task_tag_test.py @@ -0,0 +1,116 @@ +from http import HTTPStatus +from typing import TYPE_CHECKING + +import pytest + +from core.errors import TASK_NOT_FOUND_DURING_TAG, TagAlreadyExistsError, TaskNotFoundError +from database.tasks import get_tags +from database.users import User +from routers.openml.tasks import tag_task +from tests import constants +from tests.conftest import TaskFactory +from tests.routers.openml.tag_test_helper import assert_tag_response_is_identical +from tests.users import ADMIN_USER, OWNER_USER, SOME_USER, ApiKey + +if TYPE_CHECKING: + import httpx + from sqlalchemy.ext.asyncio import AsyncConnection + + +@pytest.mark.parametrize( + "key", + [None, ApiKey.INVALID], + ids=["no authentication", "invalid key"], +) +async def test_task_tag_rejects_unauthorized(key: ApiKey, py_api: httpx.AsyncClient) -> None: + apikey = "" if key is None else f"?api_key={key}" + any_task_id = 1 + response = await py_api.post( + f"/tasks/tag{apikey}", + json={"task_id": any_task_id, "tag": "test"}, + ) + assert response.status_code == HTTPStatus.UNAUTHORIZED + + +# ── Direct call tests: tag_task ── + + +@pytest.mark.mut +@pytest.mark.parametrize( + "user", + [ADMIN_USER, SOME_USER, OWNER_USER], + ids=["administrator", "non-owner", "owner"], +) +async def test_task_tag(user: User, expdb_test: AsyncConnection, task_factory: TaskFactory) -> None: + tag = "test_task_tag" + task = await task_factory() + result = await tag_task(task_id=task.id, tag=tag, user=user, expdb_db=expdb_test) + assert result == {"task_tag": {"id": str(task.id), "tag": [tag]}} + + tags = await get_tags(id_=task.id, connection=expdb_test) + assert tag in tags + + +@pytest.mark.mut +async def test_task_tag_returns_existing_tags( + task_factory: TaskFactory, expdb_test: AsyncConnection +) -> None: + task = await task_factory() + await tag_task(task_id=task.id, tag="first", user=ADMIN_USER, expdb_db=expdb_test) + result = await tag_task(task_id=task.id, tag="second", user=ADMIN_USER, expdb_db=expdb_test) + assert result == {"task_tag": {"id": str(task.id), "tag": ["first", "second"]}} + + +@pytest.mark.mut +async def test_task_tag_fails_if_tag_exists( + expdb_test: AsyncConnection, task_factory: TaskFactory +) -> None: + tag = "fails_if_exist" + task = await task_factory() + await tag_task(task_id=task.id, tag=tag, user=ADMIN_USER, expdb_db=expdb_test) + + with pytest.raises(TagAlreadyExistsError) as e: + await tag_task(task_id=task.id, tag=tag, user=ADMIN_USER, expdb_db=expdb_test) + assert str(task.id) in e.value.detail + assert tag in e.value.detail + + +async def test_task_tag_fails_if_task_does_not_exist(expdb_test: AsyncConnection) -> None: + task_id = 1_000_000 + with pytest.raises(TaskNotFoundError) as e: + await tag_task(task_id=task_id, tag="foo", user=ADMIN_USER, expdb_db=expdb_test) + assert str(task_id) in e.value.detail + task_not_found_in_tag_endpoint = TASK_NOT_FOUND_DURING_TAG + assert e.value.code == task_not_found_in_tag_endpoint + + +# -- migration tests -- + + +@pytest.mark.mut +@pytest.mark.parametrize( + "task_id", + [ + *range(1, 10), + 101, + constants.ENTITY_ID_THAT_DOES_NOT_EXIST, + ], +) +@pytest.mark.parametrize( + "api_key", + [ApiKey.ADMIN, ApiKey.SOME_USER, ApiKey.OWNER_USER], + ids=["Administrator", "regular user", "possible owner"], +) +@pytest.mark.parametrize( + "tag", + ["OpenML100", "totally_new_tag_for_migration_testing"], + ids=["typically existing tag", "new tag"], +) +async def test_task_tag_response_is_identical( + task_id: int, + tag: str, + api_key: str, + py_api: httpx.AsyncClient, + php_api: httpx.AsyncClient, +) -> None: + await assert_tag_response_is_identical(task_id, tag, api_key, "task", py_api, php_api)