diff --git a/src/pyrecest/backend_support/_torch_dtype_promotion_contract/__init__.py b/src/pyrecest/backend_support/_torch_dtype_promotion_contract/__init__.py index 73843254a..2df4726e7 100644 --- a/src/pyrecest/backend_support/_torch_dtype_promotion_contract/__init__.py +++ b/src/pyrecest/backend_support/_torch_dtype_promotion_contract/__init__.py @@ -35,6 +35,7 @@ def patch_pytorch_dtype_promotion_contract() -> None: _patch_pytorch_assignment_numpy_index_contract(raw_pytorch, backend, torch, np) _patch_pytorch_logical_device_contract(raw_pytorch, backend, torch) + _patch_pytorch_comparison_device_contract(raw_pytorch, backend, torch) _patch_pytorch_binary_device_contract(raw_pytorch, backend, torch) _patch_pytorch_equality_device_contract(raw_pytorch, backend, torch) _patch_pytorch_linspace_integer_dtype_contract(raw_pytorch, backend, torch) @@ -217,6 +218,28 @@ def binary_helper(x1, x2, *args, **kwargs): return binary_helper +def _patch_pytorch_comparison_device_contract(raw_pytorch, backend, torch) -> None: + """Make PyTorch comparison helpers accept NumPy-style array-like inputs.""" + helper_names = ("greater", "less", "logical_or") + if all( + getattr(getattr(raw_pytorch, helper_name, None), "_pyrecest_device_contract", False) + for helper_name in helper_names + ): + if getattr(backend, "__backend_name__", None) == "pytorch": + for helper_name in helper_names: + setattr(backend, helper_name, getattr(raw_pytorch, helper_name)) + return + + for helper_name in helper_names: + wrapped_helper = _wrap_tensor_binary_device_helper( + getattr(raw_pytorch, helper_name), + torch, + ) + setattr(raw_pytorch, helper_name, wrapped_helper) + if getattr(backend, "__backend_name__", None) == "pytorch": + setattr(backend, helper_name, wrapped_helper) + + def _patch_pytorch_binary_device_contract(raw_pytorch, backend, torch) -> None: """Keep boxed PyTorch binary helper operands on an existing non-CPU device.""" helpers = {