Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 16 additions & 3 deletions src/pyrecest/_backend/pytorch/fft.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,15 +73,20 @@ def _normalize_fft_dim_sequence(dim):
return tuple(normalized_entries)


def _with_dim_alias(kwargs, alias, func_name):
def _with_dim_alias(kwargs, alias, func_name, *, none_alias_means_default=True):
if alias not in kwargs:
return kwargs

kwargs = dict(kwargs)
alias_value = kwargs.pop(alias)
dim_value = kwargs.get("dim")
if alias_value is None:
if none_alias_means_default:
return kwargs
if dim_value is not None:
raise TypeError("conflicting FFT axis aliases")
kwargs["dim"] = None
return kwargs
dim_value = kwargs.get("dim")
if dim_value is not None:
dim_value = _normalize_fft_dim_sequence(dim_value)
alias_value = _normalize_fft_dim_sequence(alias_value)
Expand All @@ -100,11 +105,17 @@ def _wrap_arraylike_fft(
empty_dim_is_noop=False,
normalize_scalar_dim=False,
normalize_dim_sequence=False,
none_alias_means_default=True,
):
@_wraps(torch_func)
def fft_func(value, *args, **kwargs):
if dim_alias is not None:
kwargs = _with_dim_alias(kwargs, dim_alias, func_name)
kwargs = _with_dim_alias(
kwargs,
dim_alias,
func_name,
none_alias_means_default=none_alias_means_default,
)
if normalize_scalar_dim and "dim" in kwargs:
kwargs = dict(kwargs)
kwargs["dim"] = _normalize_single_fft_dim(kwargs["dim"])
Expand All @@ -128,12 +139,14 @@ def fft_func(value, *args, **kwargs):
func_name="rfft",
dim_alias="axis",
normalize_scalar_dim=True,
none_alias_means_default=False,
)
irfft = _wrap_arraylike_fft(
_torch.fft.irfft,
func_name="irfft",
dim_alias="axis",
normalize_scalar_dim=True,
none_alias_means_default=False,
)
fftshift = _wrap_arraylike_fft(
_torch.fft.fftshift,
Expand Down
13 changes: 13 additions & 0 deletions tests/backend_support/test_pytorch_fft_axis_contract.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,19 @@ def test_raw_pytorch_fft_helpers_accept_numpy_axis_aliases():
)


@pytest.mark.backend_portable
@pytest.mark.parametrize("fft_func", [pytorch_fft.rfft, pytorch_fft.irfft])
def test_raw_pytorch_single_axis_fft_rejects_none_axis_alias(fft_func):
with pytest.raises(TypeError):
fft_func(np.arange(4.0), axis=None)


@pytest.mark.backend_portable
def test_raw_pytorch_single_axis_fft_rejects_conflicting_none_axis_alias():
with pytest.raises(TypeError):
pytorch_fft.rfft(np.arange(4.0), axis=None, dim=0)


@pytest.mark.backend_portable
def test_raw_pytorch_fft_none_axis_alias_preserves_explicit_dim():
matrix = np.arange(6.0).reshape(2, 3)
Expand Down
Loading