Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 69 additions & 0 deletions src/pyrecest/backend_support/_pytorch_raw_reshape_contract.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
"""PyTorch raw ``reshape`` compatibility hook."""

from __future__ import annotations

from operator import index as _operator_index


def _pytorch_reshape_shape(shape, torch_module) -> tuple[int, ...]:
"""Normalize NumPy-style reshape dimensions for ``torch.reshape``."""
if torch_module.is_tensor(shape):
if shape.ndim == 0:
return (_operator_index(shape.item()),)
shape = shape.detach().cpu().tolist()
elif getattr(shape, "ndim", None) == 0 and hasattr(shape, "item"):
return (_operator_index(shape.item()),)

try:
return (_operator_index(shape),)
except TypeError:
pass

if isinstance(shape, (str, bytes)):
raise TypeError("reshape shape must be an integer or a sequence of integers")

try:
return tuple(_operator_index(dimension) for dimension in shape)
except TypeError as exc:
raise TypeError(
"reshape shape must be an integer or a sequence of integers"
) from exc


def patch_pytorch_raw_reshape_contract() -> None:
"""Patch raw/public PyTorch ``reshape`` to accept NumPy-style inputs."""

try:
import pyrecest._backend.pytorch as raw_pytorch # pylint: disable=import-outside-toplevel
import pyrecest.backend as backend # pylint: disable=import-outside-toplevel
import torch as torch_module # pylint: disable=import-outside-toplevel
except ModuleNotFoundError: # pragma: no cover - PyTorch may be unavailable
return

original_reshape = getattr(raw_pytorch, "reshape", None)
if original_reshape is None:
return
if getattr(original_reshape, "_pyrecest_raw_reshape_contract", False) or getattr(
original_reshape,
"_pyrecest_reshape_contract",
False,
):
if getattr(backend, "__backend_name__", None) == "pytorch":
backend.reshape = original_reshape
return

def reshape(x, shape):
return original_reshape(
raw_pytorch.array(x),
_pytorch_reshape_shape(shape, torch_module),
)

reshape.__name__ = getattr(original_reshape, "__name__", "reshape")
reshape.__doc__ = getattr(original_reshape, "__doc__", None)
reshape._pyrecest_raw_reshape_contract = True
raw_pytorch.reshape = reshape
if getattr(backend, "__backend_name__", None) == "pytorch":
backend.reshape = reshape


__all__ = ["patch_pytorch_raw_reshape_contract"]
4 changes: 4 additions & 0 deletions src/pyrecest/stability.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@
from pyrecest.backend_support._pytorch_minmax_device_contract import (
patch_pytorch_minmax_device_contract as _patch_pytorch_minmax_device_contract,
)
from pyrecest.backend_support._pytorch_raw_reshape_contract import (
patch_pytorch_raw_reshape_contract as _patch_pytorch_raw_reshape_contract,
)


def _patch_pytorch_raw_comparison_arraylike_contract() -> None:
Expand Down Expand Up @@ -117,6 +120,7 @@ def diag(v, k=0):
_patch_pytorch_dot_outer_device_contract()
_patch_pytorch_matmul_device_contract()
_patch_pytorch_minmax_device_contract()
_patch_pytorch_raw_reshape_contract()

P = ParamSpec("P")
R = TypeVar("R")
Expand Down
43 changes: 43 additions & 0 deletions tests/backend_support/test_raw_pytorch_reshape_contract.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
import importlib.util

import pytest

from tests.support.backend_runner import run_backend_code

pytestmark = pytest.mark.backend_portable


def test_raw_pytorch_reshape_accepts_array_like_inputs_when_public_backend_is_numpy():
if importlib.util.find_spec("torch") is None:
pytest.skip("PyTorch is not installed")

code = """
import numpy as np
import numpy.testing as npt
import pyrecest # noqa: F401
import pyrecest.backend as public_backend
import pyrecest._backend.pytorch as raw_pytorch

assert public_backend.__backend_name__ == "numpy"

result = raw_pytorch.reshape([1, 2, 3, 4], np.array([2, 2]))
assert tuple(result.shape) == (2, 2)
npt.assert_array_equal(raw_pytorch.to_numpy(result), np.array([[1, 2], [3, 4]]))

flat = raw_pytorch.reshape([[1, 2], [3, 4]], np.array(4, dtype=np.int64))
assert tuple(flat.shape) == (4,)
npt.assert_array_equal(raw_pytorch.to_numpy(flat), np.array([1, 2, 3, 4]))

for bad_shape in ("2", [2, 2.0]):
try:
raw_pytorch.reshape([1, 2, 3, 4], bad_shape)
except TypeError:
pass
else:
raise AssertionError(f"reshape accepted invalid shape {bad_shape!r}")

print("ok")
"""
result = run_backend_code("numpy", code)
assert result.returncode == 0, result.stderr
assert "ok" in result.stdout
Loading