From 9d9776893874f190a1824f344fd42d8641b0dcee Mon Sep 17 00:00:00 2001 From: Florian Pfaff <6773539+FlorianPfaff@users.noreply.github.com> Date: Thu, 2 Jul 2026 12:48:29 +0200 Subject: [PATCH 1/2] Accept boolean scalar fftconvolve axes --- src/pyrecest/_backend/pytorch/signal.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/pyrecest/_backend/pytorch/signal.py b/src/pyrecest/_backend/pytorch/signal.py index b5e07ff49..4d3aa5f77 100644 --- a/src/pyrecest/_backend/pytorch/signal.py +++ b/src/pyrecest/_backend/pytorch/signal.py @@ -9,9 +9,12 @@ def _coerce_axis(axis): axis_array = _np.asarray(axis) except (TypeError, ValueError) as exc: raise TypeError(_AXIS_TYPE_ERROR) from exc - if axis_array.shape != () or axis_array.dtype.kind not in "iu": + if axis_array.shape != (): raise TypeError(_AXIS_TYPE_ERROR) - return int(axis_array.item()) + try: + return int(axis_array.item().__index__()) + except AttributeError as exc: + raise TypeError(_AXIS_TYPE_ERROR) from exc def _is_scalar_axis_argument(value): From 3eb5b0f51b94da2741e5509c9f79e3f85a35f00a Mon Sep 17 00:00:00 2001 From: Florian Pfaff <6773539+FlorianPfaff@users.noreply.github.com> Date: Thu, 2 Jul 2026 12:48:53 +0200 Subject: [PATCH 2/2] Test boolean fftconvolve axes contract --- .../test_pytorch_fftconvolve_contract.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/tests/backend_support/test_pytorch_fftconvolve_contract.py b/tests/backend_support/test_pytorch_fftconvolve_contract.py index 4742ab8b1..23e3446e0 100644 --- a/tests/backend_support/test_pytorch_fftconvolve_contract.py +++ b/tests/backend_support/test_pytorch_fftconvolve_contract.py @@ -108,6 +108,21 @@ def test_pytorch_fftconvolve_scalar_empty_axes_matches_scipy(): assert np.allclose(actual, expected) +def test_pytorch_fftconvolve_boolean_axes_match_scipy(): + if backend.__backend_name__ != "pytorch": + pytest.skip("PyTorch-specific signal backend contract") + + first = backend.asarray([1.0, 2.0]) + second = backend.asarray([3.0, 4.0]) + + actual = _as_numpy(backend.signal.fftconvolve(first, second, axes=False)) + expected = scipy_fftconvolve(_as_numpy(first), _as_numpy(second), axes=False) + + assert np.allclose(actual, expected) + with pytest.raises(ValueError, match="out of bounds"): + backend.signal.fftconvolve(first, second, axes=True) + + def test_pytorch_fft_helpers_accept_array_like_inputs(): if backend.__backend_name__ != "pytorch": pytest.skip("PyTorch-specific FFT backend contract")