pyrecest.backend.concatenate(..., axis=None) should follow NumPy semantics and flatten inputs before concatenating. The current raw PyTorch backend implementation calls torch.cat(seq, dim=axis), so axis=None is forwarded as dim=None and raises at runtime.
Current implementation:

    def concatenate(seq, axis=0, out=None):
        seq = _tensor_sequence(seq)
        return _torch.cat(seq, dim=axis, out=out)

Minimal reproduction:

    import os

    # Select the backend before pyrecest.backend is imported.
    os.environ["PYRECEST_BACKEND"] = "pytorch"

    import pyrecest.backend as backend

    first = backend.array([[1, 2], [3, 4]])
    second = backend.array([[5, 6]])
    backend.concatenate((first, second), axis=None)  # raises: dim=None reaches torch.cat

Expected, matching NumPy: a flattened 1-D result containing 1, 2, 3, 4, 5, 6.
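
For reference, the NumPy behavior that the PyTorch backend should match can be checked directly. This is a minimal sketch that assumes only that NumPy is installed; it uses plain `numpy` rather than `pyrecest.backend`, with the same inputs as the reproduction.

```python
import numpy as np

first = np.array([[1, 2], [3, 4]])
second = np.array([[5, 6]])

# axis=None flattens every input before joining, so differing shapes are fine.
result = np.concatenate((first, second), axis=None)
print(result)  # [1 2 3 4 5 6]
```
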
Suggested fix in the PyTorch backend compatibility patch layer:

    # Assumed module-level import in the patch layer; not shown in the original snippet.
    from operator import index as _operator_index

    def _patch_pytorch_concatenate_axis_none_contract(raw_pytorch, torch) -> None:
        try:
            import pyrecest.backend as backend
        except ModuleNotFoundError:
            backend = None

        original_concatenate = raw_pytorch.concatenate
        # Idempotence guard: if already patched, only re-sync the public backend.
        if getattr(original_concatenate, "_pyrecest_axis_none_contract", False):
            if backend is not None and getattr(backend, "__backend_name__", None) == "pytorch":
                backend.concatenate = original_concatenate
            return

        def concatenate(seq, axis=0, out=None):
            tensors = [raw_pytorch.array(item) for item in seq]
            if axis is None:
                # NumPy semantics: flatten every input, then join along axis 0.
                tensors = [tensor.reshape(-1) for tensor in tensors]
                axis_arg = 0
            else:
                axis_arg = _operator_index(axis)
            tensors = raw_pytorch.convert_to_wider_dtype(tensors)
            return torch.cat(tensors, dim=axis_arg, out=out)

        concatenate.__name__ = getattr(original_concatenate, "__name__", "concatenate")
        concatenate.__doc__ = getattr(original_concatenate, "__doc__", None)
        concatenate._pyrecest_axis_none_contract = True
        raw_pytorch.concatenate = concatenate
        if backend is not None and getattr(backend, "__backend_name__", None) == "pytorch":
            backend.concatenate = concatenate

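
The branch logic of the suggested fix can be tried out without PyTorch. The sketch below uses NumPy as a stand-in, on the assumption that `reshape(-1)` followed by a join along axis 0 behaves the same way for tensors as for arrays. The function name `concatenate_axis_none_sketch` is hypothetical, and dtype widening and the `out` argument are left out.

```python
from operator import index

import numpy as np


def concatenate_axis_none_sketch(seq, axis=0):
    """NumPy stand-in for the patched branch logic (hypothetical helper)."""
    arrays = [np.asarray(item) for item in seq]
    if axis is None:
        # Flatten each input, then join along the only remaining axis.
        arrays = [array.reshape(-1) for array in arrays]
        axis_arg = 0
    else:
        # operator.index rejects non-integer axes such as 1.0.
        axis_arg = index(axis)
    return np.concatenate(arrays, axis=axis_arg)


first = np.array([[1, 2], [3, 4]])
second = np.array([[5, 6]])

flat = concatenate_axis_none_sketch((first, second), axis=None)
stacked = concatenate_axis_none_sketch((first, second), axis=0)
```

Here `flat` matches `np.concatenate((first, second), axis=None)`, while `stacked` keeps the two-column layout because an explicit axis skips the flattening step.
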
Also add a backend-portable regression test covering both backend.concatenate under PYRECEST_BACKEND=pytorch and raw pyrecest._backend.pytorch.concatenate under the NumPy public backend.
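
One possible shape for such a test is a helper that takes the `concatenate` and `array` callables of whichever backend is under test. This is only a sketch: the helper name `check_concatenate_axis_none` is hypothetical, and it is exercised here with NumPy alone so that it runs without PyTorch or pyrecest installed.

```python
import numpy as np


def check_concatenate_axis_none(concatenate, array):
    """Assert that `concatenate` flattens its inputs when axis=None.

    `concatenate` and `array` are whatever the backend under test exposes;
    this helper name is hypothetical, not part of pyrecest.
    """
    first = array([[1, 2], [3, 4]])
    second = array([[5, 6]])
    result = concatenate((first, second), axis=None)
    # np.asarray accepts NumPy arrays and CPU torch tensors without grad.
    values = np.asarray(result)
    assert values.shape == (6,)
    assert values.tolist() == [1, 2, 3, 4, 5, 6]
    return True


# Exercised here with NumPy only, so the sketch runs without PyTorch.
numpy_ok = check_concatenate_axis_none(np.concatenate, np.array)
```

In the real suite, the same helper could presumably be called once with `backend.concatenate` and `backend.array` under `PYRECEST_BACKEND=pytorch`, and once with the raw `pyrecest._backend.pytorch` functions while the public backend is NumPy.
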