diff --git a/src/pyrecest/_backend/pytorch/signal.py b/src/pyrecest/_backend/pytorch/signal.py index b5e07ff49..4d3aa5f77 100644 --- a/src/pyrecest/_backend/pytorch/signal.py +++ b/src/pyrecest/_backend/pytorch/signal.py @@ -9,9 +9,12 @@ def _coerce_axis(axis): axis_array = _np.asarray(axis) except (TypeError, ValueError) as exc: raise TypeError(_AXIS_TYPE_ERROR) from exc - if axis_array.shape != () or axis_array.dtype.kind not in "iu": + if axis_array.shape != (): raise TypeError(_AXIS_TYPE_ERROR) - return int(axis_array.item()) + try: + return int(axis_array.item().__index__()) + except AttributeError as exc: + raise TypeError(_AXIS_TYPE_ERROR) from exc def _is_scalar_axis_argument(value): diff --git a/tests/backend_support/test_pytorch_fftconvolve_contract.py b/tests/backend_support/test_pytorch_fftconvolve_contract.py index 4742ab8b1..23e3446e0 100644 --- a/tests/backend_support/test_pytorch_fftconvolve_contract.py +++ b/tests/backend_support/test_pytorch_fftconvolve_contract.py @@ -108,6 +108,21 @@ def test_pytorch_fftconvolve_scalar_empty_axes_matches_scipy(): assert np.allclose(actual, expected) +def test_pytorch_fftconvolve_boolean_axes_match_scipy(): + if backend.__backend_name__ != "pytorch": + pytest.skip("PyTorch-specific signal backend contract") + + first = backend.asarray([1.0, 2.0]) + second = backend.asarray([3.0, 4.0]) + + actual = _as_numpy(backend.signal.fftconvolve(first, second, axes=False)) + expected = scipy_fftconvolve(_as_numpy(first), _as_numpy(second), axes=False) + + assert np.allclose(actual, expected) + with pytest.raises(ValueError, match="out of bounds"): + backend.signal.fftconvolve(first, second, axes=True) + + def test_pytorch_fft_helpers_accept_array_like_inputs(): if backend.__backend_name__ != "pytorch": pytest.skip("PyTorch-specific FFT backend contract")