From d5012e0a5d7b2251f01f9458ecafc61d9943c768 Mon Sep 17 00:00:00 2001 From: Florian Pfaff <6773539+FlorianPfaff@users.noreply.github.com> Date: Fri, 3 Jul 2026 15:32:24 +0200 Subject: [PATCH 1/3] Fix PyTorch argsort NumPy axis contract --- src/pyrecest/_backend/capabilities.py | 80 ++++++++++++++++++++++++++- 1 file changed, 79 insertions(+), 1 deletion(-) diff --git a/src/pyrecest/_backend/capabilities.py b/src/pyrecest/_backend/capabilities.py index 43bade893..b7e56f383 100644 --- a/src/pyrecest/_backend/capabilities.py +++ b/src/pyrecest/_backend/capabilities.py @@ -8,6 +8,7 @@ from __future__ import annotations +from operator import index as _operator_index from typing import Any, Final, cast BACKEND_NAMES: Final = ("numpy", "pytorch", "jax") @@ -162,11 +163,88 @@ def _patch_jax_backend_contracts() -> None: ) except ModuleNotFoundError: # pragma: no cover - backend support may be unavailable return - patch_jax_randint_empty_size_contract() +def _resolve_pytorch_argsort_axis(axis, dim) -> int | None: + """Resolve NumPy ``axis`` and PyTorch ``dim`` aliases for argsort.""" + if dim is not None: + if axis is not None and axis != dim: + raise TypeError("argsort() got both 'axis' and 'dim'") + axis = dim + if axis is None: + return None + return _operator_index(axis) + + +def _patch_pytorch_argsort_contracts() -> None: + try: + import pyrecest._backend.pytorch as raw_pytorch # pylint: disable=import-outside-toplevel + import pyrecest.backend as backend # pylint: disable=import-outside-toplevel + import torch # pylint: disable=import-outside-toplevel + except ModuleNotFoundError: # pragma: no cover - PyTorch backend may be unavailable + return + + original_argsort = getattr(raw_pytorch, "argsort", None) + if original_argsort is None: + return + if getattr(original_argsort, "_pyrecest_numpy_contract", False): + if getattr(backend, "__backend_name__", None) == "pytorch": + backend.argsort = original_argsort + return + + def argsort( + a, + axis=-1, + kind=None, + order=None, + *, + stable=None, + dim=None, + descending=False, + ): + axis_value = _resolve_pytorch_argsort_axis(axis, dim) + if order is not None: + raise ValueError("order is not supported by the PyTorch backend") + if kind is not None: + if kind in {"stable", "mergesort"}: + if stable is False: + raise TypeError( + "argsort() got conflicting 'kind' and 'stable' arguments" + ) + stable = True + elif kind in {"quicksort", "heapsort"}: + if stable is True: + raise TypeError( + "argsort() got conflicting 'kind' and 'stable' arguments" + ) + stable = False + else: + raise ValueError( + "sort kind must be one of 'quicksort', 'heapsort', 'stable', or 'mergesort'" + ) + + values = raw_pytorch.array(a) + if axis_value is None: + values = values.reshape(-1) + axis_value = 0 + return torch.argsort( + values, + dim=axis_value, + descending=descending, + stable=bool(stable) if stable is not None else False, + ) + + argsort.__name__ = getattr(original_argsort, "__name__", "argsort") + argsort.__doc__ = getattr(original_argsort, "__doc__", None) + argsort._pyrecest_numpy_contract = True + raw_pytorch.argsort = argsort + if getattr(backend, "__backend_name__", None) == "pytorch": + backend.argsort = argsort + + _patch_jax_backend_contracts() +_patch_pytorch_argsort_contracts() def get_unsupported_functions( From 819747c6f4383f94f7dfff5f6a56cf37f2f6769e Mon Sep 17 00:00:00 2001 From: Florian Pfaff <6773539+FlorianPfaff@users.noreply.github.com> Date: Fri, 3 Jul 2026 15:32:43 +0200 Subject: [PATCH 2/3] Add PyTorch argsort backend contract regression tests --- .../test_pytorch_argsort_contract.py | 84 +++++++++++++++++++ 1 file changed, 84 insertions(+) create mode 100644 tests/backend_support/test_pytorch_argsort_contract.py diff --git a/tests/backend_support/test_pytorch_argsort_contract.py b/tests/backend_support/test_pytorch_argsort_contract.py new file mode 100644 index 000000000..6ac092b3c --- /dev/null +++ b/tests/backend_support/test_pytorch_argsort_contract.py @@ -0,0 +1,84 @@ +import importlib.util +import os +import subprocess +import sys + +import pytest + + +def _run_backend_code(backend_name, code): + if importlib.util.find_spec("torch") is None: + pytest.skip("torch is not installed") + + env = os.environ.copy() + env["PYRECEST_BACKEND"] = backend_name + src_path = os.path.abspath("src") + env["PYTHONPATH"] = ( + src_path + if not env.get("PYTHONPATH") + else os.pathsep.join([src_path, env["PYTHONPATH"]]) + ) + subprocess.run([sys.executable, "-c", code], check=True, env=env) + + +@pytest.mark.backend_portable +def test_pytorch_argsort_accepts_numpy_axis_contract(): + _run_backend_code( + "pytorch", + """ +import pyrecest.backend as backend +import pyrecest._backend.pytorch as pytorch_backend + +axis_result = backend.argsort([[3, 1, 2], [0, 5, 4]], axis=1) +assert backend.to_numpy(axis_result).tolist() == [[1, 2, 0], [0, 2, 1]] + +flat_result = backend.argsort([[3, 1], [0, 2]], axis=None) +assert backend.to_numpy(flat_result).tolist() == [2, 1, 3, 0] + +dim_result = backend.argsort([[3, 1], [0, 2]], dim=0) +assert backend.to_numpy(dim_result).tolist() == [[1, 0], [0, 1]] + +stable_result = backend.argsort([2, 1, 2], stable=True) +assert backend.to_numpy(stable_result).tolist() == [1, 0, 2] + +raw_result = pytorch_backend.argsort([[2, 0], [1, 3]], axis=0) +assert pytorch_backend.to_numpy(raw_result).tolist() == [[1, 0], [0, 1]] + +try: + backend.argsort([1, 2], axis=0, dim=1) +except TypeError: + pass +else: + raise AssertionError("argsort accepted conflicting axis and dim arguments") +""", + ) + + +@pytest.mark.backend_portable +def test_raw_pytorch_argsort_is_patched_under_numpy_backend(): + _run_backend_code( + "numpy", + """ +import pyrecest # noqa: F401 +import pyrecest.backend as backend +import pyrecest._backend.pytorch as pytorch_backend + +assert backend.__backend_name__ == "numpy" + +raw_axis_result = pytorch_backend.argsort([[3, 1, 2], [0, 5, 4]], axis=1) +assert pytorch_backend.to_numpy(raw_axis_result).tolist() == [[1, 2, 0], [0, 2, 1]] + +raw_flat_result = pytorch_backend.argsort([[3, 1], [0, 2]], axis=None) +assert pytorch_backend.to_numpy(raw_flat_result).tolist() == [2, 1, 3, 0] + +public_result = backend.argsort([[3, 1, 2], [0, 5, 4]], axis=1) +assert public_result.tolist() == [[1, 2, 0], [0, 2, 1]] + +try: + pytorch_backend.argsort([1, 2], axis=0, dim=1) +except TypeError: + pass +else: + raise AssertionError("raw argsort accepted conflicting axis and dim arguments") +""", + ) From e9f6bc6ddde337d1c34f8a0499481722baa908eb Mon Sep 17 00:00:00 2001 From: Florian Pfaff <6773539+FlorianPfaff@users.noreply.github.com> Date: Fri, 3 Jul 2026 15:34:01 +0200 Subject: [PATCH 3/3] Allow PyTorch argsort dim alias without axis conflict --- src/pyrecest/_backend/capabilities.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/pyrecest/_backend/capabilities.py b/src/pyrecest/_backend/capabilities.py index b7e56f383..ee4e7deec 100644 --- a/src/pyrecest/_backend/capabilities.py +++ b/src/pyrecest/_backend/capabilities.py @@ -154,6 +154,7 @@ _ALLOWED_API_CAPABILITY_KEYS: Final = frozenset( (*REQUIRED_BACKENDS, *_OPTIONAL_API_CAPABILITY_KEYS) ) +_PYTORCH_ARGSORT_DEFAULT_AXIS: Final = object() def _patch_jax_backend_contracts() -> None: @@ -168,8 +169,11 @@ def _patch_jax_backend_contracts() -> None: def _resolve_pytorch_argsort_axis(axis, dim) -> int | None: """Resolve NumPy ``axis`` and PyTorch ``dim`` aliases for argsort.""" + axis_was_omitted = axis is _PYTORCH_ARGSORT_DEFAULT_AXIS + if axis_was_omitted: + axis = -1 if dim is not None: - if axis is not None and axis != dim: + if not axis_was_omitted and axis is not None and axis != dim: raise TypeError("argsort() got both 'axis' and 'dim'") axis = dim if axis is None: @@ -195,7 +199,7 @@ def _patch_pytorch_argsort_contracts() -> None: def argsort( a, - axis=-1, + axis=_PYTORCH_ARGSORT_DEFAULT_AXIS, kind=None, order=None, *,