Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from __future__ import annotations

import importlib.util
from operator import index as _operator_index
from pathlib import Path


Expand Down Expand Up @@ -38,6 +39,7 @@ def patch_pytorch_dtype_promotion_contract() -> None:
_patch_pytorch_binary_device_contract(raw_pytorch, backend, torch)
_patch_pytorch_equality_device_contract(raw_pytorch, backend, torch)
_patch_pytorch_linspace_integer_dtype_contract(raw_pytorch, backend, torch)
_patch_pytorch_creation_bool_shape_contract(raw_pytorch, backend, torch, np)
_patch_pytorch_arraylike_helper_contract(raw_pytorch, backend, torch)


Expand Down Expand Up @@ -313,6 +315,74 @@ def linspace(start, stop, num=50, endpoint=True, dtype=None):
backend.linspace = linspace


def _pytorch_creation_dimension(dimension, np) -> int:
"""Return one creation-shape dimension while rejecting booleans."""
if isinstance(dimension, (bool, np.bool_)):
raise TypeError("shape dimensions must be integers")
try:
return _operator_index(dimension)
except TypeError as exc:
raise TypeError("shape dimensions must be integers") from exc


def _pytorch_creation_shape(shape, torch, np) -> tuple[int, ...]:
"""Return a NumPy-style creation shape without accepting boolean dimensions."""
if torch.is_tensor(shape):
shape = shape.detach().cpu().numpy()

if isinstance(shape, (bool, np.bool_)):
raise TypeError("shape dimensions must be integers")
if isinstance(shape, (list, tuple)):
return tuple(_pytorch_creation_dimension(dimension, np) for dimension in shape)

shape_array = np.asarray(shape)
if shape_array.shape == ():
return (_pytorch_creation_dimension(shape_array.item(), np),)
if shape_array.size and np.issubdtype(shape_array.dtype, np.bool_):
raise TypeError("shape dimensions must be integers")
return tuple(
_pytorch_creation_dimension(dimension, np)
for dimension in shape_array.tolist()
)


def _wrap_creation_shape_helper(original_helper, torch, np):
"""Normalize creation shapes before the base PyTorch compatibility wrapper."""
if getattr(original_helper, "_pyrecest_bool_shape_contract", False):
return original_helper

def creation_helper(shape, *args, **kwargs):
return original_helper(_pytorch_creation_shape(shape, torch, np), *args, **kwargs)

creation_helper.__name__ = getattr(original_helper, "__name__", "creation_helper")
creation_helper.__doc__ = getattr(original_helper, "__doc__", None)
creation_helper._pyrecest_bool_shape_contract = True
return creation_helper


def _patch_pytorch_creation_bool_shape_contract(raw_pytorch, backend, torch, np) -> None:
"""Reject boolean creation shapes before PyTorch interprets them as integers."""
helper_names = ("empty", "zeros", "ones", "full")
if all(
getattr(getattr(raw_pytorch, helper_name, None), "_pyrecest_bool_shape_contract", False)
for helper_name in helper_names
):
if getattr(backend, "__backend_name__", None) == "pytorch":
for helper_name in helper_names:
setattr(backend, helper_name, getattr(raw_pytorch, helper_name))
return

for helper_name in helper_names:
wrapped_helper = _wrap_creation_shape_helper(
getattr(raw_pytorch, helper_name),
torch,
np,
)
setattr(raw_pytorch, helper_name, wrapped_helper)
if getattr(backend, "__backend_name__", None) == "pytorch":
setattr(backend, helper_name, wrapped_helper)


def _arraylike_tensor(value, raw_pytorch, torch):
"""Return array-like helper input as a PyTorch tensor."""
if torch.is_tensor(value):
Expand Down
103 changes: 103 additions & 0 deletions tests/backend_support/test_pytorch_creation_bool_shape_contract.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
import importlib.util

import pytest
from tests.support.backend_runner import run_backend_code


@pytest.mark.backend_portable
def test_pytorch_creation_helpers_reject_boolean_shapes():
if importlib.util.find_spec("torch") is None:
pytest.skip("PyTorch is not installed")

result = run_backend_code(
"pytorch",
"""
import numpy as np
import torch

import pyrecest.backend as backend
import pyrecest._backend.pytorch as raw_backend

assert getattr(backend, "__backend_name__", None) == "pytorch"

bad_shapes = (
True,
np.bool_(False),
np.array(True, dtype=np.bool_),
[True, 2],
[np.bool_(True), 2],
[True, False],
np.array([True, False], dtype=np.bool_),
np.array([True, 2], dtype=object),
torch.tensor(True),
torch.tensor([True, False]),
)

for creation_backend in (backend, raw_backend):
for helper_name, extra_args in (
("empty", ()),
("zeros", ()),
("ones", ()),
("full", (7,)),
):
helper = getattr(creation_backend, helper_name)
for bad_shape in bad_shapes:
try:
helper(bad_shape, *extra_args)
except TypeError:
pass
else:
raise AssertionError(f"{helper_name} accepted boolean shape {bad_shape!r}")
""",
)

assert result.returncode == 0, result.stderr


@pytest.mark.backend_portable
def test_raw_pytorch_creation_helpers_reject_boolean_shapes_with_numpy_backend():
if importlib.util.find_spec("torch") is None:
pytest.skip("PyTorch is not installed")

result = run_backend_code(
"numpy",
"""
import numpy as np
import torch

import pyrecest.backend as backend
import pyrecest._backend.pytorch as raw_backend

assert getattr(backend, "__backend_name__", None) == "numpy"

bad_shapes = (
True,
np.bool_(False),
np.array(True, dtype=np.bool_),
[True, 2],
[np.bool_(True), 2],
[True, False],
np.array([True, False], dtype=np.bool_),
np.array([True, 2], dtype=object),
torch.tensor(True),
torch.tensor([True, False]),
)

for helper_name, extra_args in (
("empty", ()),
("zeros", ()),
("ones", ()),
("full", (7,)),
):
helper = getattr(raw_backend, helper_name)
for bad_shape in bad_shapes:
try:
helper(bad_shape, *extra_args)
except TypeError:
pass
else:
raise AssertionError(f"raw {helper_name} accepted boolean shape {bad_shape!r}")
""",
)

assert result.returncode == 0, result.stderr
Loading