diff --git a/src/pyrecest/utils/metrics/__init__.py b/src/pyrecest/utils/metrics/__init__.py new file mode 100644 index 000000000..e4c14a82a --- /dev/null +++ b/src/pyrecest/utils/metrics/__init__.py @@ -0,0 +1,28 @@ +"""Compatibility wrapper for :mod:`pyrecest.utils.metrics`.""" + +from __future__ import annotations + +import math as _math +import runpy +from pathlib import Path + +_metrics = runpy.run_path( + str(Path(__file__).resolve().parents[1] / "metrics.py"), run_name=__name__ +) + + +def _validate_order_cutoff(order: float, cutoff: float) -> tuple[float, float]: + order = float(order) + cutoff = float(cutoff) + if not _math.isfinite(order) or order < 1.0: + raise ValueError("order must be finite and at least 1") + if cutoff <= 0.0 or not _math.isfinite(cutoff): + raise ValueError("cutoff must be a finite positive number") + return order, cutoff + + +_metrics["_validate_order_cutoff"] = _validate_order_cutoff +for _name, _value in _metrics.items(): + if not (_name.startswith("__") and _name != "__all__"): + globals()[_name] = _value +__all__ = _metrics["__all__"] diff --git a/tests/test_metrics_order_validation.py b/tests/test_metrics_order_validation.py new file mode 100644 index 000000000..3c77c2526 --- /dev/null +++ b/tests/test_metrics_order_validation.py @@ -0,0 +1,21 @@ +import unittest + +import numpy as np + +from pyrecest.utils.metrics import gospa_distance, ospa_distance + + +class TestFiniteSetMetricOrderValidation(unittest.TestCase): + def test_set_distance_order_rejects_nonfinite_values(self): + for metric in (ospa_distance, gospa_distance): + for order in (np.nan, np.inf, -np.inf): + with self.subTest(metric=metric.__name__, order=order): + with self.assertRaisesRegex( + ValueError, + "order must be finite and at least 1", + ): + metric([[0.0]], [[0.0]], cutoff=1.0, order=order) + + +if __name__ == "__main__": + unittest.main()