Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 83 additions & 1 deletion src/pyrecest/_backend/capabilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from __future__ import annotations

from operator import index as _operator_index
from typing import Any, Final, cast

BACKEND_NAMES: Final = ("numpy", "pytorch", "jax")
Expand Down Expand Up @@ -153,6 +154,7 @@
_ALLOWED_API_CAPABILITY_KEYS: Final = frozenset(
(*REQUIRED_BACKENDS, *_OPTIONAL_API_CAPABILITY_KEYS)
)
_PYTORCH_ARGSORT_DEFAULT_AXIS: Final = object()


def _patch_jax_backend_contracts() -> None:
Expand All @@ -162,11 +164,91 @@ def _patch_jax_backend_contracts() -> None:
)
except ModuleNotFoundError: # pragma: no cover - backend support may be unavailable
return

patch_jax_randint_empty_size_contract()


def _resolve_pytorch_argsort_axis(axis, dim) -> int | None:
"""Resolve NumPy ``axis`` and PyTorch ``dim`` aliases for argsort."""
axis_was_omitted = axis is _PYTORCH_ARGSORT_DEFAULT_AXIS
if axis_was_omitted:
axis = -1
if dim is not None:
if not axis_was_omitted and axis is not None and axis != dim:
raise TypeError("argsort() got both 'axis' and 'dim'")
axis = dim
if axis is None:
return None
return _operator_index(axis)


def _patch_pytorch_argsort_contracts() -> None:
try:
import pyrecest._backend.pytorch as raw_pytorch # pylint: disable=import-outside-toplevel
import pyrecest.backend as backend # pylint: disable=import-outside-toplevel
import torch # pylint: disable=import-outside-toplevel
except ModuleNotFoundError: # pragma: no cover - PyTorch backend may be unavailable
return

original_argsort = getattr(raw_pytorch, "argsort", None)
if original_argsort is None:
return
if getattr(original_argsort, "_pyrecest_numpy_contract", False):
if getattr(backend, "__backend_name__", None) == "pytorch":
backend.argsort = original_argsort
return

def argsort(
a,
axis=_PYTORCH_ARGSORT_DEFAULT_AXIS,
kind=None,
order=None,
*,
stable=None,
dim=None,
descending=False,
):
axis_value = _resolve_pytorch_argsort_axis(axis, dim)
if order is not None:
raise ValueError("order is not supported by the PyTorch backend")
if kind is not None:
if kind in {"stable", "mergesort"}:
if stable is False:
raise TypeError(
"argsort() got conflicting 'kind' and 'stable' arguments"
)
stable = True
elif kind in {"quicksort", "heapsort"}:
if stable is True:
raise TypeError(
"argsort() got conflicting 'kind' and 'stable' arguments"
)
stable = False
else:
raise ValueError(
"sort kind must be one of 'quicksort', 'heapsort', 'stable', or 'mergesort'"
)

values = raw_pytorch.array(a)
if axis_value is None:
values = values.reshape(-1)
axis_value = 0
return torch.argsort(
values,
dim=axis_value,
descending=descending,
stable=bool(stable) if stable is not None else False,
)

argsort.__name__ = getattr(original_argsort, "__name__", "argsort")
argsort.__doc__ = getattr(original_argsort, "__doc__", None)
argsort._pyrecest_numpy_contract = True
raw_pytorch.argsort = argsort
if getattr(backend, "__backend_name__", None) == "pytorch":
backend.argsort = argsort


_patch_jax_backend_contracts()
_patch_pytorch_argsort_contracts()


def get_unsupported_functions(
Expand Down
84 changes: 84 additions & 0 deletions tests/backend_support/test_pytorch_argsort_contract.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
import importlib.util
import os
import subprocess
import sys

import pytest


def _run_backend_code(backend_name, code):
if importlib.util.find_spec("torch") is None:
pytest.skip("torch is not installed")

env = os.environ.copy()
env["PYRECEST_BACKEND"] = backend_name
src_path = os.path.abspath("src")
env["PYTHONPATH"] = (
src_path
if not env.get("PYTHONPATH")
else os.pathsep.join([src_path, env["PYTHONPATH"]])
)
subprocess.run([sys.executable, "-c", code], check=True, env=env)


@pytest.mark.backend_portable
def test_pytorch_argsort_accepts_numpy_axis_contract():
_run_backend_code(
"pytorch",
"""
import pyrecest.backend as backend
import pyrecest._backend.pytorch as pytorch_backend

axis_result = backend.argsort([[3, 1, 2], [0, 5, 4]], axis=1)
assert backend.to_numpy(axis_result).tolist() == [[1, 2, 0], [0, 2, 1]]

flat_result = backend.argsort([[3, 1], [0, 2]], axis=None)
assert backend.to_numpy(flat_result).tolist() == [2, 1, 3, 0]

dim_result = backend.argsort([[3, 1], [0, 2]], dim=0)
assert backend.to_numpy(dim_result).tolist() == [[1, 0], [0, 1]]

stable_result = backend.argsort([2, 1, 2], stable=True)
assert backend.to_numpy(stable_result).tolist() == [1, 0, 2]

raw_result = pytorch_backend.argsort([[2, 0], [1, 3]], axis=0)
assert pytorch_backend.to_numpy(raw_result).tolist() == [[1, 0], [0, 1]]

try:
backend.argsort([1, 2], axis=0, dim=1)
except TypeError:
pass
else:
raise AssertionError("argsort accepted conflicting axis and dim arguments")
""",
)


@pytest.mark.backend_portable
def test_raw_pytorch_argsort_is_patched_under_numpy_backend():
_run_backend_code(
"numpy",
"""
import pyrecest # noqa: F401
import pyrecest.backend as backend
import pyrecest._backend.pytorch as pytorch_backend

assert backend.__backend_name__ == "numpy"

raw_axis_result = pytorch_backend.argsort([[3, 1, 2], [0, 5, 4]], axis=1)
assert pytorch_backend.to_numpy(raw_axis_result).tolist() == [[1, 2, 0], [0, 2, 1]]

raw_flat_result = pytorch_backend.argsort([[3, 1], [0, 2]], axis=None)
assert pytorch_backend.to_numpy(raw_flat_result).tolist() == [2, 1, 3, 0]

public_result = backend.argsort([[3, 1, 2], [0, 5, 4]], axis=1)
assert public_result.tolist() == [[1, 2, 0], [0, 2, 1]]

try:
pytorch_backend.argsort([1, 2], axis=0, dim=1)
except TypeError:
pass
else:
raise AssertionError("raw argsort accepted conflicting axis and dim arguments")
""",
)
Loading