diff --git a/src/pyrecest/backend_support/_pytorch_dot_outer_device_contract.py b/src/pyrecest/backend_support/_pytorch_dot_outer_device_contract.py index f9a6d1da3..fd5440f1a 100644 --- a/src/pyrecest/backend_support/_pytorch_dot_outer_device_contract.py +++ b/src/pyrecest/backend_support/_pytorch_dot_outer_device_contract.py @@ -3,6 +3,7 @@ from __future__ import annotations +# Keep this hook aligned with the backend dot tests. def _preferred_pytorch_device(torch_module, *values): """Return a non-CPU tensor device when mixed-device operands are present.""" for value in values: @@ -52,8 +53,8 @@ def dot(a, b): a, b = _promoted_pair(raw_pytorch, torch, a, b) if a.ndim == 0 or b.ndim == 0: return torch.multiply(a, b) - if a.ndim == 1 and b.ndim == 1: - return torch.dot(a, b) + if a.ndim <= 2 and b.ndim <= 2: + return torch.matmul(a, b) if b.ndim == 1: return torch.einsum("...i,i->...", a, b) if a.ndim == 1: diff --git a/tests/backend_support/test_pytorch_dot_matrix_contract.py b/tests/backend_support/test_pytorch_dot_matrix_contract.py new file mode 100644 index 000000000..625303d4b --- /dev/null +++ b/tests/backend_support/test_pytorch_dot_matrix_contract.py @@ -0,0 +1,27 @@ +import importlib.util + +import pytest + +from tests.support.backend_runner import run_backend_code + +pytestmark = pytest.mark.backend_portable + + +def test_public_pytorch_dot_matches_matrix_product_shape(): + if importlib.util.find_spec("torch") is None: + pytest.skip("PyTorch is not installed") + + code = """ +import pyrecest.backend as backend + +left = backend.array([[1.0, 2.0], [3.0, 4.0]]) +right = backend.array([[5.0, 6.0], [7.0, 8.0]]) +result = backend.dot(left, right) +assert tuple(result.shape) == (2, 2) +assert backend.to_numpy(result).tolist() == [[19.0, 22.0], [43.0, 50.0]] +print("ok") +""" + result = run_backend_code("pytorch", code) + + assert result.returncode == 0, result.stderr + assert "ok" in result.stdout