From 5bb1f707530b7c9dd886b0277109a91affeb7523 Mon Sep 17 00:00:00 2001 From: Florian Pfaff <6773539+FlorianPfaff@users.noreply.github.com> Date: Fri, 3 Jul 2026 15:05:46 +0200 Subject: [PATCH 1/3] Add raw PyTorch reshape compatibility hook --- .../_pytorch_raw_reshape_contract.py | 69 +++++++++++++++++++ 1 file changed, 69 insertions(+) create mode 100644 src/pyrecest/backend_support/_pytorch_raw_reshape_contract.py diff --git a/src/pyrecest/backend_support/_pytorch_raw_reshape_contract.py b/src/pyrecest/backend_support/_pytorch_raw_reshape_contract.py new file mode 100644 index 000000000..a07c2c603 --- /dev/null +++ b/src/pyrecest/backend_support/_pytorch_raw_reshape_contract.py @@ -0,0 +1,69 @@ +"""PyTorch raw ``reshape`` compatibility hook.""" + +from __future__ import annotations + +from operator import index as _operator_index + + +def _pytorch_reshape_shape(shape, torch_module) -> tuple[int, ...]: + """Normalize NumPy-style reshape dimensions for ``torch.reshape``.""" + if torch_module.is_tensor(shape): + if shape.ndim == 0: + return (_operator_index(shape.item()),) + shape = shape.detach().cpu().tolist() + elif getattr(shape, "ndim", None) == 0 and hasattr(shape, "item"): + return (_operator_index(shape.item()),) + + try: + return (_operator_index(shape),) + except TypeError: + pass + + if isinstance(shape, (str, bytes)): + raise TypeError("reshape shape must be an integer or a sequence of integers") + + try: + return tuple(_operator_index(dimension) for dimension in shape) + except TypeError as exc: + raise TypeError( + "reshape shape must be an integer or a sequence of integers" + ) from exc + + +def patch_pytorch_raw_reshape_contract() -> None: + """Patch raw/public PyTorch ``reshape`` to accept NumPy-style inputs.""" + + try: + import pyrecest._backend.pytorch as raw_pytorch # pylint: disable=import-outside-toplevel + import pyrecest.backend as backend # pylint: disable=import-outside-toplevel + import torch as torch_module # pylint: disable=import-outside-toplevel + except ModuleNotFoundError: # pragma: no cover - PyTorch may be unavailable + return + + original_reshape = getattr(raw_pytorch, "reshape", None) + if original_reshape is None: + return + if getattr(original_reshape, "_pyrecest_raw_reshape_contract", False) or getattr( + original_reshape, + "_pyrecest_reshape_contract", + False, + ): + if getattr(backend, "__backend_name__", None) == "pytorch": + backend.reshape = original_reshape + return + + def reshape(x, shape): + return original_reshape( + raw_pytorch.array(x), + _pytorch_reshape_shape(shape, torch_module), + ) + + reshape.__name__ = getattr(original_reshape, "__name__", "reshape") + reshape.__doc__ = getattr(original_reshape, "__doc__", None) + reshape._pyrecest_raw_reshape_contract = True + raw_pytorch.reshape = reshape + if getattr(backend, "__backend_name__", None) == "pytorch": + backend.reshape = reshape + + +__all__ = ["patch_pytorch_raw_reshape_contract"] From d944965f7f0afeae02fe8e8a98983bb544a38864 Mon Sep 17 00:00:00 2001 From: Florian Pfaff <6773539+FlorianPfaff@users.noreply.github.com> Date: Fri, 3 Jul 2026 15:06:11 +0200 Subject: [PATCH 2/3] Patch raw PyTorch reshape during stability import --- src/pyrecest/stability.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/pyrecest/stability.py b/src/pyrecest/stability.py index 2f3ed50c1..f35b9f3c9 100644 --- a/src/pyrecest/stability.py +++ b/src/pyrecest/stability.py @@ -18,6 +18,9 @@ from pyrecest.backend_support._pytorch_minmax_device_contract import ( patch_pytorch_minmax_device_contract as _patch_pytorch_minmax_device_contract, ) +from pyrecest.backend_support._pytorch_raw_reshape_contract import ( + patch_pytorch_raw_reshape_contract as _patch_pytorch_raw_reshape_contract, +) def _patch_pytorch_raw_comparison_arraylike_contract() -> None: @@ -117,6 +120,7 @@ def diag(v, k=0): _patch_pytorch_dot_outer_device_contract() _patch_pytorch_matmul_device_contract() _patch_pytorch_minmax_device_contract() +_patch_pytorch_raw_reshape_contract() P = ParamSpec("P") R = TypeVar("R") From 93ea0a85517e8e78d1bde9de184eafbca7048585 Mon Sep 17 00:00:00 2001 From: Florian Pfaff <6773539+FlorianPfaff@users.noreply.github.com> Date: Fri, 3 Jul 2026 15:06:24 +0200 Subject: [PATCH 3/3] Cover raw PyTorch reshape with NumPy public backend --- .../test_raw_pytorch_reshape_contract.py | 43 +++++++++++++++++++ 1 file changed, 43 insertions(+) create mode 100644 tests/backend_support/test_raw_pytorch_reshape_contract.py diff --git a/tests/backend_support/test_raw_pytorch_reshape_contract.py b/tests/backend_support/test_raw_pytorch_reshape_contract.py new file mode 100644 index 000000000..502918380 --- /dev/null +++ b/tests/backend_support/test_raw_pytorch_reshape_contract.py @@ -0,0 +1,43 @@ +import importlib.util + +import pytest + +from tests.support.backend_runner import run_backend_code + +pytestmark = pytest.mark.backend_portable + + +def test_raw_pytorch_reshape_accepts_array_like_inputs_when_public_backend_is_numpy(): + if importlib.util.find_spec("torch") is None: + pytest.skip("PyTorch is not installed") + + code = """ +import numpy as np +import numpy.testing as npt +import pyrecest # noqa: F401 +import pyrecest.backend as public_backend +import pyrecest._backend.pytorch as raw_pytorch + +assert public_backend.__backend_name__ == "numpy" + +result = raw_pytorch.reshape([1, 2, 3, 4], np.array([2, 2])) +assert tuple(result.shape) == (2, 2) +npt.assert_array_equal(raw_pytorch.to_numpy(result), np.array([[1, 2], [3, 4]])) + +flat = raw_pytorch.reshape([[1, 2], [3, 4]], np.array(4, dtype=np.int64)) +assert tuple(flat.shape) == (4,) +npt.assert_array_equal(raw_pytorch.to_numpy(flat), np.array([1, 2, 3, 4])) + +for bad_shape in ("2", [2, 2.0]): + try: + raw_pytorch.reshape([1, 2, 3, 4], bad_shape) + except TypeError: + pass + else: + raise AssertionError(f"reshape accepted invalid shape {bad_shape!r}") + +print("ok") +""" + result = run_backend_code("numpy", code) + assert result.returncode == 0, result.stderr + assert "ok" in result.stdout