Skip to content
2 changes: 2 additions & 0 deletions src/core/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,6 +251,8 @@ class AccountHasResourcesError(ProblemDetailError):
# Tag Errors
# =============================================================================

TASK_NOT_FOUND_DURING_TAG = 472


class TagAlreadyExistsError(ProblemDetailError):
"""Raised when trying to add a tag that already exists."""
Expand Down
43 changes: 40 additions & 3 deletions src/database/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,15 @@
from typing import TYPE_CHECKING, cast

from sqlalchemy import Row, text
from sqlalchemy.exc import IntegrityError

from routers.types import Identifier
from database.exceptions import (
_DUPLICATE_ENTRY,
_FOREIGN_KEY_CONSTRAINT_FAILED,
DuplicatePrimaryKeyError,
ForeignKeyConstraintError,
)
from routers.types import Identifier, TagString

if TYPE_CHECKING:
from sqlalchemy.ext.asyncio import AsyncConnection
Expand Down Expand Up @@ -149,8 +156,8 @@ async def get_task_type_inout_with_template(
)


async def get_tags(id_: Identifier, expdb: AsyncConnection) -> list[str]:
rows = await expdb.execute(
async def get_tags(id_: Identifier, connection: AsyncConnection) -> list[str]:
rows = await connection.execute(
text(
"""
SELECT `tag`
Expand All @@ -162,3 +169,33 @@ async def get_tags(id_: Identifier, expdb: AsyncConnection) -> list[str]:
)
tag_rows = rows.all()
return [row.tag for row in tag_rows]


async def tag(
id_: Identifier,
tag_: TagString,
*,
user_id: Identifier,
connection: AsyncConnection,
) -> None:
try:
await connection.execute(
text(
"""
INSERT INTO task_tag(`id`, `tag`, `uploader`)
VALUES (:task_id, :tag, :user_id)
""",
),
parameters={
"task_id": id_,
"user_id": user_id,
"tag": tag_,
},
)
except IntegrityError as e:
code, msg = e.orig.args
if code == _FOREIGN_KEY_CONSTRAINT_FAILED:
raise ForeignKeyConstraintError(msg) from e
if code == _DUPLICATE_ENTRY:
raise DuplicatePrimaryKeyError(msg) from e
raise
32 changes: 30 additions & 2 deletions src/routers/openml/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,16 @@

import xmltodict
from fastapi import APIRouter, Body, Depends
from loguru import logger
from sqlalchemy import bindparam, text

import database.datasets
import database.tasks
from config import get_config
from core.errors import InternalError, NoResultsError, TaskNotFoundError
from routers.dependencies import Pagination, expdb_connection
from core.errors import InternalError, NoResultsError, TagAlreadyExistsError, TaskNotFoundError
from database.exceptions import DuplicatePrimaryKeyError, ForeignKeyConstraintError
from database.users import User
from routers.dependencies import Pagination, expdb_connection, fetch_user_or_raise
from routers.types import (
CasualString128,
Identifier,
Expand All @@ -31,6 +34,31 @@
type JSON = dict[str, "JSON"] | list["JSON"] | str | int | float | bool | None


@router.post(path="/tag")
async def tag_task(
task_id: Annotated[Identifier, Body()],
tag: Annotated[TagString, Body()],
user: Annotated[User, Depends(fetch_user_or_raise)],
expdb_db: Annotated[AsyncConnection, Depends(expdb_connection)],
) -> dict[str, dict[str, Any]]:
try:
await database.tasks.tag(task_id, tag, user_id=user.user_id, connection=expdb_db)
except ForeignKeyConstraintError:
msg = f"Task {task_id} not found."
raise TaskNotFoundError(msg, code=472) from None
except DuplicatePrimaryKeyError:
msg = f"Task {task_id} already tagged with {tag!r}."
raise TagAlreadyExistsError(msg) from None

logger.info("Task {task_id} tagged '{tag}'.", task_id=task_id, tag=tag)

tags = await database.tasks.get_tags(task_id, expdb_db)

return {
"task_tag": {"id": str(task_id), "tag": tags},
}


def convert_template_xml_to_json(xml_template: str) -> dict[str, JSON]:
json_template = xmltodict.parse(xml_template.replace("oml:", ""))
json_str = json.dumps(json_template)
Expand Down
95 changes: 93 additions & 2 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import contextlib
import datetime
import json
from collections.abc import AsyncIterator, Callable, Iterable, Iterator
from collections.abc import AsyncIterator, Awaitable, Callable, Iterable, Iterator
from pathlib import Path
from typing import TYPE_CHECKING, Any, NamedTuple
from typing import TYPE_CHECKING, Any, NamedTuple, Protocol

import _pytest.mark
import httpx
Expand All @@ -22,6 +23,7 @@
from database.setup import expdb_database, user_database
from main import create_api
from routers.dependencies import expdb_connection, userdb_connection
from routers.types import Identifier
from tests.users import OWNER_USER

if TYPE_CHECKING:
Expand Down Expand Up @@ -139,6 +141,95 @@ def dataset_130() -> Iterator[dict[str, Any]]:
yield json.load(dataset_file)


class Task(NamedTuple):
"""To be replaced by an actual ORM class."""

id: Identifier
task_type: Identifier
creator: Identifier


class TaskFactory(Protocol):
def __call__(
self,
*,
task_id: Identifier = 42_000,
task_type: Identifier = 1,
creator: Identifier = OWNER_USER.user_id,
) -> Awaitable[Task]: ...


def _create_identifier_factory() -> Callable[[], Identifier]:
_identifier_counter: Identifier = 10_000_000

def _get() -> Identifier:
nonlocal _identifier_counter
_identifier_counter += 1
return _identifier_counter

return _get


_identifier_factory = _create_identifier_factory()


@pytest.fixture
async def task_factory(
expdb_test: AsyncConnection,
) -> TaskFactory:
async def create_task(
*,
task_id: Identifier | None = None,
task_type: Identifier = 1,
creator: Identifier = OWNER_USER.user_id,
) -> Task:
task_id = task_id or _identifier_factory()

await expdb_test.execute(
text("""
INSERT INTO task (task_id, ttid, creator) VALUES (:task_id, :ttid, :creator);
"""),
parameters={"task_id": task_id, "ttid": task_type, "creator": creator},
)
return Task(task_id, task_type, creator)

return create_task


class DatasetFactory(Protocol):
def __call__(
self, *, dataset_id: Identifier = 42_000, creator: Identifier = OWNER_USER.user_id
) -> Awaitable[Identifier]: ...


@pytest.fixture
async def dataset_factory(
expdb_test: AsyncConnection,
) -> DatasetFactory:
async def create_dataset(
*, dataset_id: Identifier | None = None, creator: Identifier = OWNER_USER.user_id
) -> Identifier:
dataset_id = dataset_id or _identifier_factory()
await expdb_test.execute(
text("""
INSERT INTO dataset
(did, uploader, name, version, format, upload_date, licence, url, visibility)
VALUES
(:dataset_id, :creator, :name, 'dataset-version', 'dataset-format',
:now, 'public', 'dataset-url', 'public');
"""),
parameters={
"dataset_id": dataset_id,
"creator": creator,
"now": datetime.datetime.now(tz=datetime.UTC),
"name": f"dataset-name-{dataset_id}",
},
)
return dataset_id

return create_dataset


class Flow(NamedTuple):
"""To be replaced by an actual ORM class."""

Expand Down
2 changes: 1 addition & 1 deletion tests/constants.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
DATASET_ID_THAT_DOES_NOT_EXIST = 9_9999_999
ENTITY_ID_THAT_DOES_NOT_EXIST = 9_9999_999
SOME_PRIVATE_DATASET_ID = 130
PRIVATE_DATASET_ID = {130}
IN_PREPARATION_ID = {33, 161, 162, 163}
Expand Down
Loading
Loading