diff --git a/src/pyrecest/stability.py b/src/pyrecest/stability.py index 2f3ed50c1..7c55d1b38 100644 --- a/src/pyrecest/stability.py +++ b/src/pyrecest/stability.py @@ -4,6 +4,7 @@ from collections.abc import Callable, Iterable from dataclasses import asdict, dataclass +from operator import index as _operator_index from typing import Final, Literal, ParamSpec, TypeVar from pyrecest.backend_support._pytorch_allclose_device_contract import ( @@ -111,8 +112,45 @@ def diag(v, k=0): backend.diag = diag +def _patch_pytorch_round_numpy_contract() -> None: + """Patch raw/public PyTorch ``round`` to accept NumPy-style inputs.""" + try: + import pyrecest._backend.pytorch as raw_pytorch # pylint: disable=import-outside-toplevel + import pyrecest.backend as backend # pylint: disable=import-outside-toplevel + import torch # pylint: disable=import-outside-toplevel + except ModuleNotFoundError: # pragma: no cover - PyTorch backend may be unavailable + return + + original_round = getattr(raw_pytorch, "round", None) + if original_round is None: + return + if getattr(original_round, "_pyrecest_numpy_contract", False): + if getattr(backend, "__backend_name__", None) == "pytorch": + backend.round = original_round + return + + def round(a, decimals=0, out=None): # pylint: disable=redefined-builtin + result = torch.round(raw_pytorch.array(a), decimals=_operator_index(decimals)) + if out is None: + return result + copy_ = getattr(out, "copy_", None) + if copy_ is not None: + copy_(result) + return out + out[...] = result + return out + + round.__name__ = getattr(original_round, "__name__", "round") + round.__doc__ = getattr(original_round, "__doc__", None) + round._pyrecest_numpy_contract = True + raw_pytorch.round = round + if getattr(backend, "__backend_name__", None) == "pytorch": + backend.round = round + + _patch_pytorch_allclose_device_contract() _patch_pytorch_diag_numpy_contract() +_patch_pytorch_round_numpy_contract() _patch_pytorch_raw_comparison_arraylike_contract() _patch_pytorch_dot_outer_device_contract() _patch_pytorch_matmul_device_contract() diff --git a/tests/backend_support/test_raw_pytorch_round_contract.py b/tests/backend_support/test_raw_pytorch_round_contract.py new file mode 100644 index 000000000..445d1690d --- /dev/null +++ b/tests/backend_support/test_raw_pytorch_round_contract.py @@ -0,0 +1,45 @@ +import importlib.util +import os +import subprocess +import sys + +import pytest + + +def _backend_test_env(backend_name): + env = os.environ.copy() + env["PYRECEST_BACKEND"] = backend_name + src_path = os.path.abspath("src") + env["PYTHONPATH"] = ( + src_path + if not env.get("PYTHONPATH") + else os.pathsep.join([src_path, env["PYTHONPATH"]]) + ) + return env + + +@pytest.mark.backend_portable +def test_raw_pytorch_round_accepts_array_like_with_numpy_backend(): + if importlib.util.find_spec("torch") is None: + pytest.skip("torch is not installed") + + code = """ +import pyrecest.backend as backend +import pyrecest._backend.pytorch as raw_backend + +assert getattr(backend, "__backend_name__", None) == "numpy" + +rounded = raw_backend.round([1.2, 2.7]) +assert raw_backend.to_numpy(rounded).tolist() == [1.0, 3.0] + +rounded_with_keyword = raw_backend.round([1.2, 2.7], decimals=0) +assert raw_backend.to_numpy(rounded_with_keyword).tolist() == [1.0, 3.0] + +out = raw_backend.empty(2, dtype=raw_backend.float64) +returned = raw_backend.round([1.2, 2.7], out=out) +assert returned is out +assert raw_backend.to_numpy(out).tolist() == [1.0, 3.0] +""" + subprocess.run( + [sys.executable, "-c", code], check=True, env=_backend_test_env("numpy") + )