Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from __future__ import annotations


# Keep this hook aligned with the backend dot tests.
def _preferred_pytorch_device(torch_module, *values):
"""Return a non-CPU tensor device when mixed-device operands are present."""
for value in values:
Expand Down Expand Up @@ -52,8 +53,8 @@ def dot(a, b):
a, b = _promoted_pair(raw_pytorch, torch, a, b)
if a.ndim == 0 or b.ndim == 0:
return torch.multiply(a, b)
if a.ndim == 1 and b.ndim == 1:
return torch.dot(a, b)
if a.ndim <= 2 and b.ndim <= 2:
return torch.matmul(a, b)
if b.ndim == 1:
return torch.einsum("...i,i->...", a, b)
if a.ndim == 1:
Expand Down
27 changes: 27 additions & 0 deletions tests/backend_support/test_pytorch_dot_matrix_contract.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
import importlib.util

import pytest

from tests.support.backend_runner import run_backend_code

pytestmark = pytest.mark.backend_portable


def test_public_pytorch_dot_matches_matrix_product_shape():
if importlib.util.find_spec("torch") is None:
pytest.skip("PyTorch is not installed")

code = """
import pyrecest.backend as backend

left = backend.array([[1.0, 2.0], [3.0, 4.0]])
right = backend.array([[5.0, 6.0], [7.0, 8.0]])
result = backend.dot(left, right)
assert tuple(result.shape) == (2, 2)
assert backend.to_numpy(result).tolist() == [[19.0, 22.0], [43.0, 50.0]]
print("ok")
"""
result = run_backend_code("pytorch", code)

assert result.returncode == 0, result.stderr
assert "ok" in result.stdout
Loading