
PyTorch backend concatenate fails for axis=None #3677

Description

@FlorianPfaff

pyrecest.backend.concatenate(..., axis=None) should follow NumPy semantics: flatten every input, then concatenate the flattened arrays along axis 0. The current raw PyTorch backend implementation calls torch.cat(seq, dim=axis), so axis=None is forwarded unchanged as dim=None, which torch.cat does not accept, and the call raises at runtime instead of returning a 1-D result.
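For reference, the NumPy contract can be checked directly. This minimal sketch uses only NumPy (no pyrecest) and shows that axis=None is equivalent to concatenating the ravel()ed inputs along axis 0. Because the inputs are flattened first, they do not even need the same number of dimensions.

```python
import numpy as np

first = np.array([[1, 2], [3, 4]])
second = np.array([[5, 6]])

# axis=None: every input is flattened (row-major) first, then joined on axis 0.
flat = np.concatenate((first, second), axis=None)

# Equivalent explicit form: ravel each input, then concatenate along axis 0.
manual = np.concatenate([a.ravel() for a in (first, second)], axis=0)

# Flattening also allows inputs of different ndim, which axis=0 would reject.
mixed = np.concatenate((first, np.array([7, 8, 9])), axis=None)

print(flat)   # [1 2 3 4 5 6]
print(mixed)  # [1 2 3 4 7 8 9]
assert np.array_equal(flat, manual)
```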

Current implementation:

def concatenate(seq, axis=0, out=None):
    seq = _tensor_sequence(seq)
    return _torch.cat(seq, dim=axis, out=out)
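The core of the problem is axis normalization. An integer-normalization step such as Python's operator.index accepts int-like values (including NumPy integer scalars) but raises TypeError for None, so axis=None has to be intercepted before the axis ever reaches torch.cat. A minimal sketch; the helper name normalize_concat_axis is hypothetical and not part of pyrecest.

```python
import operator

import numpy as np


def normalize_concat_axis(axis):
    """Hypothetical helper: map a NumPy-style axis to (flatten, dim).

    Returns (True, 0) for axis=None, meaning "flatten inputs, join on dim 0".
    Otherwise returns (False, operator.index(axis)), which yields a plain int
    and raises TypeError for floats and other non-integer types.
    """
    if axis is None:
        return True, 0
    return False, operator.index(axis)


print(normalize_concat_axis(None))         # (True, 0)
print(normalize_concat_axis(np.int64(1)))  # (False, 1)

try:
    operator.index(None)
except TypeError:
    print("operator.index(None) raises TypeError")
```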

Minimal reproduction:

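import os

# Select the PyTorch backend before pyrecest.backend is imported.
os.environ["PYRECEST_BACKEND"] = "pytorch"
import pyrecest.backend as backend

first = backend.array([[1, 2], [3, 4]])
second = backend.array([[5, 6]])
# Raises instead of returning the flattened concatenation.
backend.concatenate((first, second), axis=None)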

Expected, matching NumPy:

[1, 2, 3, 4, 5, 6]

Suggested fix in the PyTorch backend compatibility patch layer:

import operator


def _patch_pytorch_concatenate_axis_none_contract(raw_pytorch, torch) -> None:
    """Make concatenate(..., axis=None) flatten its inputs, as NumPy does."""
    try:
        import pyrecest.backend as backend
    except ModuleNotFoundError:
        backend = None

    original_concatenate = raw_pytorch.concatenate
    if getattr(original_concatenate, "_pyrecest_axis_none_contract", False):
        # Already patched: only re-export the patched function if needed.
        if backend is not None and getattr(backend, "__backend_name__", None) == "pytorch":
            backend.concatenate = original_concatenate
        return

    def concatenate(seq, axis=0, out=None):
        tensors = [raw_pytorch.array(item) for item in seq]
        if axis is None:
            # NumPy semantics: flatten every input, then join along dim 0.
            tensors = [tensor.reshape(-1) for tensor in tensors]
            axis_arg = 0
        else:
            # Accept any integer-like axis (e.g. NumPy ints); reject floats.
            axis_arg = operator.index(axis)
        tensors = raw_pytorch.convert_to_wider_dtype(tensors)
        return torch.cat(tensors, dim=axis_arg, out=out)

    concatenate.__name__ = getattr(original_concatenate, "__name__", "concatenate")
    concatenate.__doc__ = getattr(original_concatenate, "__doc__", None)
    # Marker that makes repeated calls to this patch function idempotent.
    concatenate._pyrecest_axis_none_contract = True
    raw_pytorch.concatenate = concatenate
    if backend is not None and getattr(backend, "__backend_name__", None) == "pytorch":
        backend.concatenate = concatenate
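The suggested patch is designed to be idempotent: the _pyrecest_axis_none_contract marker lets a second call detect that raw_pytorch.concatenate is already wrapped and only re-export it. The sketch below isolates that pattern under stated assumptions: types.SimpleNamespace objects stand in for the raw PyTorch backend module and the public pyrecest.backend module, np.concatenate stands in for torch.cat, the __backend_name__ check is omitted, and patch_concatenate and broken_concatenate are hypothetical names.

```python
import operator
from types import SimpleNamespace

import numpy as np


def patch_concatenate(raw, public):
    """Hypothetical, NumPy-based analogue of the suggested patch function."""
    original = raw.concatenate
    if getattr(original, "_pyrecest_axis_none_contract", False):
        # Already patched: only make sure the public module re-exports it.
        public.concatenate = original
        return

    def concatenate(seq, axis=0):
        arrays = [np.asarray(item) for item in seq]
        if axis is None:
            arrays = [a.reshape(-1) for a in arrays]  # flatten, join on axis 0
            axis_arg = 0
        else:
            axis_arg = operator.index(axis)
        return np.concatenate(arrays, axis=axis_arg)

    concatenate._pyrecest_axis_none_contract = True  # idempotency marker
    raw.concatenate = concatenate
    public.concatenate = concatenate


def broken_concatenate(seq, axis=0):
    # Stand-in for the buggy version: axis=None is not handled.
    if axis is None:
        raise TypeError("dim=None is not supported")
    return np.concatenate(seq, axis=axis)


raw = SimpleNamespace(concatenate=broken_concatenate)
public = SimpleNamespace(concatenate=broken_concatenate)

patch_concatenate(raw, public)
first_patch = raw.concatenate
patch_concatenate(raw, public)  # second call must not wrap again

assert raw.concatenate is first_patch and public.concatenate is first_patch
print(public.concatenate(([[1, 2], [3, 4]], [[5, 6]]), axis=None))  # [1 2 3 4 5 6]
```

Checking the marker instead of comparing function identity means the patch stays safe even if it runs from several import paths.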

Also add a backend-portable regression test that covers both cases: backend.concatenate with PYRECEST_BACKEND=pytorch, and the raw pyrecest._backend.pytorch.concatenate while NumPy is the public backend.
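One backend-agnostic way to write that test is to compare any concatenate callable against numpy.concatenate, so the same check can run against backend.concatenate and against the raw PyTorch function. A minimal sketch; check_axis_none_contract and its to_numpy parameter are hypothetical, and the pyrecest wiring shown only in comments is an assumption about how the test suite would call it.

```python
import numpy as np


def check_axis_none_contract(concatenate, to_numpy=np.asarray):
    """Hypothetical check: `concatenate` must match NumPy for axis=None.

    `concatenate` is any NumPy-compatible concatenate function; `to_numpy`
    converts its result (for example a torch tensor) to a NumPy array.
    """
    first = [[1, 2], [3, 4]]
    second = [[5, 6]]
    third = [7, 8, 9]  # different ndim: only valid when axis=None

    for seq in [(first, second), (first, third)]:
        expected = np.concatenate([np.asarray(a) for a in seq], axis=None)
        result = to_numpy(concatenate(seq, axis=None))
        assert result.shape == expected.shape, (result.shape, expected.shape)
        assert np.array_equal(result, expected), (result, expected)


# NumPy satisfies the contract, so the check passes for np.concatenate.
check_axis_none_contract(np.concatenate)

# Assumed wiring in the pyrecest test suite (not runnable here):
#   import pyrecest._backend.pytorch as raw_pytorch
#   check_axis_none_contract(raw_pytorch.concatenate,
#                            to_numpy=lambda t: t.detach().cpu().numpy())
```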

Metadata

Assignees: No one assigned
Labels: bug (Something isn't working)
Projects: No projects
Milestone: No milestone
Relationships: None yet
Development: No branches or pull requests
