Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 38 additions & 0 deletions src/pyrecest/stability.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from collections.abc import Callable, Iterable
from dataclasses import asdict, dataclass
from operator import index as _operator_index
from typing import Final, Literal, ParamSpec, TypeVar

from pyrecest.backend_support._pytorch_allclose_device_contract import (
Expand Down Expand Up @@ -111,8 +112,45 @@ def diag(v, k=0):
backend.diag = diag


def _patch_pytorch_round_numpy_contract() -> None:
"""Patch raw/public PyTorch ``round`` to accept NumPy-style inputs."""
try:
import pyrecest._backend.pytorch as raw_pytorch # pylint: disable=import-outside-toplevel
import pyrecest.backend as backend # pylint: disable=import-outside-toplevel
import torch # pylint: disable=import-outside-toplevel
except ModuleNotFoundError: # pragma: no cover - PyTorch backend may be unavailable
return

original_round = getattr(raw_pytorch, "round", None)
if original_round is None:
return
if getattr(original_round, "_pyrecest_numpy_contract", False):
if getattr(backend, "__backend_name__", None) == "pytorch":
backend.round = original_round
return

def round(a, decimals=0, out=None): # pylint: disable=redefined-builtin
result = torch.round(raw_pytorch.array(a), decimals=_operator_index(decimals))
if out is None:
return result
copy_ = getattr(out, "copy_", None)
if copy_ is not None:
copy_(result)
return out
out[...] = result
return out

round.__name__ = getattr(original_round, "__name__", "round")
round.__doc__ = getattr(original_round, "__doc__", None)
round._pyrecest_numpy_contract = True
raw_pytorch.round = round
if getattr(backend, "__backend_name__", None) == "pytorch":
backend.round = round


_patch_pytorch_allclose_device_contract()
_patch_pytorch_diag_numpy_contract()
_patch_pytorch_round_numpy_contract()
_patch_pytorch_raw_comparison_arraylike_contract()
_patch_pytorch_dot_outer_device_contract()
_patch_pytorch_matmul_device_contract()
Expand Down
45 changes: 45 additions & 0 deletions tests/backend_support/test_raw_pytorch_round_contract.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
import importlib.util
import os
import subprocess
import sys

import pytest


def _backend_test_env(backend_name):
env = os.environ.copy()
env["PYRECEST_BACKEND"] = backend_name
src_path = os.path.abspath("src")
env["PYTHONPATH"] = (
src_path
if not env.get("PYTHONPATH")
else os.pathsep.join([src_path, env["PYTHONPATH"]])
)
return env


@pytest.mark.backend_portable
def test_raw_pytorch_round_accepts_array_like_with_numpy_backend():
if importlib.util.find_spec("torch") is None:
pytest.skip("torch is not installed")

code = """
import pyrecest.backend as backend
import pyrecest._backend.pytorch as raw_backend

assert getattr(backend, "__backend_name__", None) == "numpy"

rounded = raw_backend.round([1.2, 2.7])
assert raw_backend.to_numpy(rounded).tolist() == [1.0, 3.0]

rounded_with_keyword = raw_backend.round([1.2, 2.7], decimals=0)
assert raw_backend.to_numpy(rounded_with_keyword).tolist() == [1.0, 3.0]

out = raw_backend.empty(2, dtype=raw_backend.float64)
returned = raw_backend.round([1.2, 2.7], out=out)
assert returned is out
assert raw_backend.to_numpy(out).tolist() == [1.0, 3.0]
"""
subprocess.run(
[sys.executable, "-c", code], check=True, env=_backend_test_env("numpy")
)
Loading