From 0b3c6fb3fe15156a1cefddcaa34c9e7efcef77b1 Mon Sep 17 00:00:00 2001 From: Florian Pfaff <6773539+FlorianPfaff@users.noreply.github.com> Date: Thu, 2 Jul 2026 18:51:41 +0200 Subject: [PATCH 1/5] Test small update --- .../_pytorch_dot_outer_device_contract.py | 81 +------------------ 1 file changed, 1 insertion(+), 80 deletions(-) diff --git a/src/pyrecest/backend_support/_pytorch_dot_outer_device_contract.py b/src/pyrecest/backend_support/_pytorch_dot_outer_device_contract.py index f9a6d1da3..83c831f0b 100644 --- a/src/pyrecest/backend_support/_pytorch_dot_outer_device_contract.py +++ b/src/pyrecest/backend_support/_pytorch_dot_outer_device_contract.py @@ -1,80 +1 @@ -"""PyTorch ``dot``/``outer`` device compatibility hook.""" - -from __future__ import annotations - - -def _preferred_pytorch_device(torch_module, *values): - """Return a non-CPU tensor device when mixed-device operands are present.""" - for value in values: - if torch_module.is_tensor(value) and value.device.type != "cpu": - return value.device - for value in values: - if torch_module.is_tensor(value): - return value.device - return None - - -def _promoted_pair(raw_pytorch, torch_module, left, right): - """Return PyTorch operands on a common dtype and preferred existing device.""" - device = _preferred_pytorch_device(torch_module, left, right) - left = raw_pytorch.array(left) - right = raw_pytorch.array(right) - dtype = torch_module.promote_types(left.dtype, right.dtype) - if device is None: - return left.to(dtype=dtype), right.to(dtype=dtype) - return left.to(device=device, dtype=dtype), right.to(device=device, dtype=dtype) - - -def patch_pytorch_dot_outer_device_contract() -> None: - """Patch raw/public PyTorch ``dot`` and ``outer`` to preserve non-CPU operands.""" - try: - import pyrecest._backend.pytorch as raw_pytorch # pylint: disable=import-outside-toplevel - import pyrecest.backend as backend # pylint: disable=import-outside-toplevel - import torch # pylint: disable=import-outside-toplevel - except ModuleNotFoundError: # pragma: no cover - PyTorch backend may be unavailable - return - - original_dot = getattr(raw_pytorch, "dot", None) - original_outer = getattr(raw_pytorch, "outer", None) - if original_dot is None or original_outer is None: - return - if getattr(original_dot, "_pyrecest_dot_outer_device_contract", False) and getattr( - original_outer, - "_pyrecest_dot_outer_device_contract", - False, - ): - if getattr(backend, "__backend_name__", None) == "pytorch": - backend.dot = original_dot - backend.outer = original_outer - return - - def dot(a, b): - a, b = _promoted_pair(raw_pytorch, torch, a, b) - if a.ndim == 0 or b.ndim == 0: - return torch.multiply(a, b) - if a.ndim == 1 and b.ndim == 1: - return torch.dot(a, b) - if b.ndim == 1: - return torch.einsum("...i,i->...", a, b) - if a.ndim == 1: - return torch.einsum("i,...i->...", a, b) - return torch.einsum("...i,...i->...", a, b) - - def outer(a, b): - a, b = _promoted_pair(raw_pytorch, torch, a, b) - if a.ndim == 0 or b.ndim == 0: - return torch.multiply(a, b) - return a[..., :, None] * b[..., None, :] - - for helper_name, helper, original_helper in ( - ("dot", dot, original_dot), - ("outer", outer, original_outer), - ): - helper.__name__ = getattr(original_helper, "__name__", helper_name) - helper.__doc__ = getattr(original_helper, "__doc__", None) - helper._pyrecest_dot_outer_device_contract = True - helper._pyrecest_device_contract = True - helper._pyrecest_numpy_contract = True - setattr(raw_pytorch, helper_name, helper) - if getattr(backend, "__backend_name__", None) == "pytorch": - setattr(backend, helper_name, helper) +# test From 03bddd9ece4080384c615c8102405d9d0f8f7506 Mon Sep 17 00:00:00 2001 From: Florian Pfaff <6773539+FlorianPfaff@users.noreply.github.com> Date: Thu, 2 Jul 2026 18:52:00 +0200 Subject: [PATCH 2/5] Restore PyTorch dot outer device contract hook --- .../_pytorch_dot_outer_device_contract.py | 81 ++++++++++++++++++- 1 file changed, 80 insertions(+), 1 deletion(-) diff --git a/src/pyrecest/backend_support/_pytorch_dot_outer_device_contract.py b/src/pyrecest/backend_support/_pytorch_dot_outer_device_contract.py index 83c831f0b..f9a6d1da3 100644 --- a/src/pyrecest/backend_support/_pytorch_dot_outer_device_contract.py +++ b/src/pyrecest/backend_support/_pytorch_dot_outer_device_contract.py @@ -1 +1,80 @@ -# test +"""PyTorch ``dot``/``outer`` device compatibility hook.""" + +from __future__ import annotations + + +def _preferred_pytorch_device(torch_module, *values): + """Return a non-CPU tensor device when mixed-device operands are present.""" + for value in values: + if torch_module.is_tensor(value) and value.device.type != "cpu": + return value.device + for value in values: + if torch_module.is_tensor(value): + return value.device + return None + + +def _promoted_pair(raw_pytorch, torch_module, left, right): + """Return PyTorch operands on a common dtype and preferred existing device.""" + device = _preferred_pytorch_device(torch_module, left, right) + left = raw_pytorch.array(left) + right = raw_pytorch.array(right) + dtype = torch_module.promote_types(left.dtype, right.dtype) + if device is None: + return left.to(dtype=dtype), right.to(dtype=dtype) + return left.to(device=device, dtype=dtype), right.to(device=device, dtype=dtype) + + +def patch_pytorch_dot_outer_device_contract() -> None: + """Patch raw/public PyTorch ``dot`` and ``outer`` to preserve non-CPU operands.""" + try: + import pyrecest._backend.pytorch as raw_pytorch # pylint: disable=import-outside-toplevel + import pyrecest.backend as backend # pylint: disable=import-outside-toplevel + import torch # pylint: disable=import-outside-toplevel + except ModuleNotFoundError: # pragma: no cover - PyTorch backend may be unavailable + return + + original_dot = getattr(raw_pytorch, "dot", None) + original_outer = getattr(raw_pytorch, "outer", None) + if original_dot is None or original_outer is None: + return + if getattr(original_dot, "_pyrecest_dot_outer_device_contract", False) and getattr( + original_outer, + "_pyrecest_dot_outer_device_contract", + False, + ): + if getattr(backend, "__backend_name__", None) == "pytorch": + backend.dot = original_dot + backend.outer = original_outer + return + + def dot(a, b): + a, b = _promoted_pair(raw_pytorch, torch, a, b) + if a.ndim == 0 or b.ndim == 0: + return torch.multiply(a, b) + if a.ndim == 1 and b.ndim == 1: + return torch.dot(a, b) + if b.ndim == 1: + return torch.einsum("...i,i->...", a, b) + if a.ndim == 1: + return torch.einsum("i,...i->...", a, b) + return torch.einsum("...i,...i->...", a, b) + + def outer(a, b): + a, b = _promoted_pair(raw_pytorch, torch, a, b) + if a.ndim == 0 or b.ndim == 0: + return torch.multiply(a, b) + return a[..., :, None] * b[..., None, :] + + for helper_name, helper, original_helper in ( + ("dot", dot, original_dot), + ("outer", outer, original_outer), + ): + helper.__name__ = getattr(original_helper, "__name__", helper_name) + helper.__doc__ = getattr(original_helper, "__doc__", None) + helper._pyrecest_dot_outer_device_contract = True + helper._pyrecest_device_contract = True + helper._pyrecest_numpy_contract = True + setattr(raw_pytorch, helper_name, helper) + if getattr(backend, "__backend_name__", None) == "pytorch": + setattr(backend, helper_name, helper) From 02bac168a86869d15ec1cebaa58a02e60fe88730 Mon Sep 17 00:00:00 2001 From: Florian Pfaff <6773539+FlorianPfaff@users.noreply.github.com> Date: Thu, 2 Jul 2026 18:53:22 +0200 Subject: [PATCH 3/5] Add dot contract note --- .../backend_support/_pytorch_dot_outer_device_contract.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/pyrecest/backend_support/_pytorch_dot_outer_device_contract.py b/src/pyrecest/backend_support/_pytorch_dot_outer_device_contract.py index f9a6d1da3..3a970d116 100644 --- a/src/pyrecest/backend_support/_pytorch_dot_outer_device_contract.py +++ b/src/pyrecest/backend_support/_pytorch_dot_outer_device_contract.py @@ -3,6 +3,7 @@ from __future__ import annotations +# Keep this hook aligned with the backend dot tests. def _preferred_pytorch_device(torch_module, *values): """Return a non-CPU tensor device when mixed-device operands are present.""" for value in values: From c82c5d8854891fa366536d327cd6f39fa2915cde Mon Sep 17 00:00:00 2001 From: Florian Pfaff <6773539+FlorianPfaff@users.noreply.github.com> Date: Thu, 2 Jul 2026 18:53:45 +0200 Subject: [PATCH 4/5] Use matmul for low-rank PyTorch dot --- .../backend_support/_pytorch_dot_outer_device_contract.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/pyrecest/backend_support/_pytorch_dot_outer_device_contract.py b/src/pyrecest/backend_support/_pytorch_dot_outer_device_contract.py index 3a970d116..fd5440f1a 100644 --- a/src/pyrecest/backend_support/_pytorch_dot_outer_device_contract.py +++ b/src/pyrecest/backend_support/_pytorch_dot_outer_device_contract.py @@ -53,8 +53,8 @@ def dot(a, b): a, b = _promoted_pair(raw_pytorch, torch, a, b) if a.ndim == 0 or b.ndim == 0: return torch.multiply(a, b) - if a.ndim == 1 and b.ndim == 1: - return torch.dot(a, b) + if a.ndim <= 2 and b.ndim <= 2: + return torch.matmul(a, b) if b.ndim == 1: return torch.einsum("...i,i->...", a, b) if a.ndim == 1: From 6ab49e91f0f14b0ecbd6e20fdcbcf71a79b8996e Mon Sep 17 00:00:00 2001 From: Florian Pfaff <6773539+FlorianPfaff@users.noreply.github.com> Date: Thu, 2 Jul 2026 18:54:39 +0200 Subject: [PATCH 5/5] Test PyTorch dot matrix contract --- .../test_pytorch_dot_matrix_contract.py | 27 +++++++++++++++++++ 1 file changed, 27 insertions(+) create mode 100644 tests/backend_support/test_pytorch_dot_matrix_contract.py diff --git a/tests/backend_support/test_pytorch_dot_matrix_contract.py b/tests/backend_support/test_pytorch_dot_matrix_contract.py new file mode 100644 index 000000000..625303d4b --- /dev/null +++ b/tests/backend_support/test_pytorch_dot_matrix_contract.py @@ -0,0 +1,27 @@ +import importlib.util + +import pytest + +from tests.support.backend_runner import run_backend_code + +pytestmark = pytest.mark.backend_portable + + +def test_public_pytorch_dot_matches_matrix_product_shape(): + if importlib.util.find_spec("torch") is None: + pytest.skip("PyTorch is not installed") + + code = """ +import pyrecest.backend as backend + +left = backend.array([[1.0, 2.0], [3.0, 4.0]]) +right = backend.array([[5.0, 6.0], [7.0, 8.0]]) +result = backend.dot(left, right) +assert tuple(result.shape) == (2, 2) +assert backend.to_numpy(result).tolist() == [[19.0, 22.0], [43.0, 50.0]] +print("ok") +""" + result = run_backend_code("pytorch", code) + + assert result.returncode == 0, result.stderr + assert "ok" in result.stdout