diff --git a/src/pyrecest/evaluation/get_axis_label.py b/src/pyrecest/evaluation/get_axis_label.py index ef2f51a2d..551ab02f9 100644 --- a/src/pyrecest/evaluation/get_axis_label.py +++ b/src/pyrecest/evaluation/get_axis_label.py @@ -1,5 +1,11 @@ +def _normalize_manifold_name(manifold_name): + if not isinstance(manifold_name, str) or not manifold_name.strip(): + raise ValueError("manifold_name must be a non-empty string") + return manifold_name.strip().lower() + + def get_axis_label(manifold_name): - normalized_name = manifold_name.lower() + normalized_name = _normalize_manifold_name(manifold_name) if "circlesymm" in normalized_name: error_label = "Error in radian" diff --git a/tests/test_axis_labels.py b/tests/test_axis_labels.py index 6139e0a0c..1dd64b015 100644 --- a/tests/test_axis_labels.py +++ b/tests/test_axis_labels.py @@ -26,3 +26,13 @@ def test_specific_manifolds_are_not_shadowed_by_generic_substrings( ) def test_generic_manifold_axis_labels_are_preserved(manifold_name, expected_label): assert get_axis_label(manifold_name) == expected_label + + +@pytest.mark.parametrize("manifold_name", [None, "", " "]) +def test_axis_label_rejects_invalid_manifold_names(manifold_name): + with pytest.raises(ValueError, match="manifold_name must be a non-empty string"): + get_axis_label(manifold_name) + + +def test_axis_label_strips_manifold_name_whitespace(): + assert get_axis_label(" hypersphere ") == "Error (orthodromic distance) in radian"