Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ def patch_pytorch_dtype_promotion_contract() -> None:
_patch_pytorch_logical_device_contract(raw_pytorch, backend, torch)
_patch_pytorch_binary_device_contract(raw_pytorch, backend, torch)
_patch_pytorch_equality_device_contract(raw_pytorch, backend, torch)
_patch_pytorch_matmul_device_contract(raw_pytorch, backend, torch)
_patch_pytorch_linspace_integer_dtype_contract(raw_pytorch, backend, torch)


Expand Down Expand Up @@ -265,6 +266,32 @@ def _patch_pytorch_equality_device_contract(raw_pytorch, backend, torch) -> None
setattr(backend, helper_name, wrapped_helper)


def _patch_pytorch_matmul_device_contract(raw_pytorch, backend, torch) -> None:
"""Keep matmul operands on an existing non-CPU tensor device."""
original_matmul = raw_pytorch.matmul
if getattr(original_matmul, "_pyrecest_device_contract", False):
if getattr(backend, "__backend_name__", None) == "pytorch":
backend.matmul = original_matmul
return

def matmul(x, y, out=None):
device = _preferred_pytorch_device(torch, x, y)
x = raw_pytorch.array(x)
y = raw_pytorch.array(y)
if device is not None:
x = x.to(device=device)
y = y.to(device=device)
x, y = raw_pytorch.convert_to_wider_dtype([x, y])
return torch.matmul(x, y, out=out)

matmul.__name__ = getattr(original_matmul, "__name__", "matmul")
matmul.__doc__ = getattr(original_matmul, "__doc__", None)
matmul._pyrecest_device_contract = True
raw_pytorch.matmul = matmul
if getattr(backend, "__backend_name__", None) == "pytorch":
backend.matmul = matmul


def _integer_torch_dtype(dtype, raw_pytorch, torch):
"""Return an explicit integer torch dtype, or ``None`` for non-integers."""
if dtype is None:
Expand Down
Loading