Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ def patch_pytorch_dtype_promotion_contract() -> None:

_patch_pytorch_assignment_numpy_index_contract(raw_pytorch, backend, torch, np)
_patch_pytorch_logical_device_contract(raw_pytorch, backend, torch)
_patch_pytorch_comparison_device_contract(raw_pytorch, backend, torch)
_patch_pytorch_binary_device_contract(raw_pytorch, backend, torch)
_patch_pytorch_equality_device_contract(raw_pytorch, backend, torch)
_patch_pytorch_linspace_integer_dtype_contract(raw_pytorch, backend, torch)
Expand Down Expand Up @@ -217,6 +218,28 @@ def binary_helper(x1, x2, *args, **kwargs):
return binary_helper


def _patch_pytorch_comparison_device_contract(raw_pytorch, backend, torch) -> None:
"""Make PyTorch comparison helpers accept NumPy-style array-like inputs."""
helper_names = ("greater", "less", "logical_or")
if all(
getattr(getattr(raw_pytorch, helper_name, None), "_pyrecest_device_contract", False)
for helper_name in helper_names
):
if getattr(backend, "__backend_name__", None) == "pytorch":
for helper_name in helper_names:
setattr(backend, helper_name, getattr(raw_pytorch, helper_name))
return

for helper_name in helper_names:
wrapped_helper = _wrap_tensor_binary_device_helper(
getattr(raw_pytorch, helper_name),
torch,
)
setattr(raw_pytorch, helper_name, wrapped_helper)
if getattr(backend, "__backend_name__", None) == "pytorch":
setattr(backend, helper_name, wrapped_helper)


def _patch_pytorch_binary_device_contract(raw_pytorch, backend, torch) -> None:
"""Keep boxed PyTorch binary helper operands on an existing non-CPU device."""
helpers = {
Expand Down
Loading