diff --git a/src/pyrecest/evaluation/selection.py b/src/pyrecest/evaluation/selection.py index db174f778..f6d7124cd 100644 --- a/src/pyrecest/evaluation/selection.py +++ b/src/pyrecest/evaluation/selection.py @@ -22,9 +22,13 @@ def _is_text_scalar(value) -> bool: def _normalize_nonnegative_integer(value, name: str) -> int: - value_array = np.asarray(value) + message = f"{name} must be a non-negative integer." + try: + value_array = np.asarray(value) + except (TypeError, ValueError, RuntimeError) as exc: + raise ValueError(message) from exc if value_array.shape != () or value_array.dtype == np.bool_: - raise ValueError(f"{name} must be a non-negative integer.") + raise ValueError(message) scalar = value_array.item() if ( @@ -32,35 +36,42 @@ def _normalize_nonnegative_integer(value, name: str) -> int: or _is_text_scalar(scalar) or not isinstance(scalar, Real) ): - raise ValueError(f"{name} must be a non-negative integer.") + raise ValueError(message) if isinstance(scalar, (int, np.integer)): integer = int(scalar) else: try: scalar_float = float(scalar) except (TypeError, ValueError, OverflowError) as exc: - raise ValueError(f"{name} must be a non-negative integer.") from exc + raise ValueError(message) from exc if not np.isfinite(scalar_float) or not scalar_float.is_integer(): - raise ValueError(f"{name} must be a non-negative integer.") + raise ValueError(message) integer = int(scalar_float) if integer < 0: - raise ValueError(f"{name} must be a non-negative integer.") + raise ValueError(message) return integer def _normalize_bool_flag(value, name: str) -> bool: - value_array = np.asarray(value) + message = f"{name} must be a boolean." + try: + value_array = np.asarray(value) + except (TypeError, ValueError, RuntimeError) as exc: + raise ValueError(message) from exc if value_array.shape != (): - raise ValueError(f"{name} must be a boolean.") + raise ValueError(message) scalar = value_array.item() if not isinstance(scalar, (bool, np.bool_)): - raise ValueError(f"{name} must be a boolean.") + raise ValueError(message) return bool(scalar) def _normalize_finite_scalar(value, message: str) -> float: - value_array = np.asarray(value) + try: + value_array = np.asarray(value) + except (TypeError, ValueError, RuntimeError) as exc: + raise ValueError(message) from exc if value_array.shape != () or value_array.dtype == np.bool_: raise ValueError(message) diff --git a/tests/evaluation/test_selection.py b/tests/evaluation/test_selection.py index 15faf8c2b..9d1276fa5 100644 --- a/tests/evaluation/test_selection.py +++ b/tests/evaluation/test_selection.py @@ -15,6 +15,12 @@ ) +class UncoercibleScalar: + def __array__(self, dtype=None): + del dtype + raise TypeError("cannot convert") + + def test_top_count_mask_is_deterministic_with_ties() -> None: mask = top_count_mask([1.0, 2.0, 2.0, 0.5], 2) @@ -101,6 +107,23 @@ def test_selection_helpers_reject_text_scalar_fractions() -> None: tail_rescue_quota_count(3, rescue_fraction=fraction) +def test_selection_helpers_report_value_error_for_uncoercible_scalars() -> None: + uncoercible = UncoercibleScalar() + + with pytest.raises(ValueError, match="item_count"): + retained_count_from_fraction(uncoercible, 0.5) + with pytest.raises(ValueError, match="retention_fraction"): + retained_count_from_fraction(10, uncoercible) + with pytest.raises(ValueError, match="nonnegative"): + sanitized_score_vector([1.0], nonnegative=uncoercible) + with pytest.raises(ValueError, match="largest"): + top_count_mask([1.0], 1, largest=uncoercible) + with pytest.raises(ValueError, match="quantile"): + quantile_tail_threshold([0.0, 1.0], uncoercible) + with pytest.raises(ValueError, match="rescue_fraction"): + tail_rescue_quota_count(3, rescue_fraction=uncoercible) + + def test_quantile_tail_mask_selects_lower_tail() -> None: values = np.asarray([0.0, 1.0, 2.0, 3.0])