diff --git a/src/pyrecest/evaluation/summarize_filter_results.py b/src/pyrecest/evaluation/summarize_filter_results.py index 89a58e7c0..74d187d87 100644 --- a/src/pyrecest/evaluation/summarize_filter_results.py +++ b/src/pyrecest/evaluation/summarize_filter_results.py @@ -27,6 +27,8 @@ def summarize_filter_results( "Provided both last_filter_states and last_estimates. Using last_estimates." ) filter_results = last_estimates + elif last_estimates is not None: + filter_results = last_estimates elif last_filter_states is not None: filter_results = last_filter_states else: @@ -38,7 +40,7 @@ def summarize_filter_results( warnings.warn("Using less than 1000 runs. This may lead to unreliable results.") extract_mean = get_extract_mean( - scenario_config["manifold"], mtt_scenario=scenario_config["mtt"] + scenario_config["manifold"], mtt_scenario=scenario_config.get("mtt", False) ) distance_function = get_distance_function(scenario_config["manifold"]) errors_all = determine_all_deviations( diff --git a/tests/test_evaluation_summarize_filter_results.py b/tests/test_evaluation_summarize_filter_results.py index 9253c4100..65b315574 100644 --- a/tests/test_evaluation_summarize_filter_results.py +++ b/tests/test_evaluation_summarize_filter_results.py @@ -22,6 +22,35 @@ def test_rejects_jax_backend_explicitly(self): last_estimates=np.empty((0, 0), dtype=object), ) + @unittest.skipIf( + pyrecest.backend.__backend_name__ in ("pytorch", "jax"), + reason="Not supported on this backend", + ) + def test_accepts_last_estimates_without_explicit_mtt_flag(self): + groundtruths = np.empty((2, 2), dtype=object) + for index in np.ndindex(groundtruths.shape): + groundtruths[index] = np.zeros(2) + + last_estimates = np.zeros((1, 2, 2)) + runtimes = np.ones((1, 2)) + run_failed = np.zeros((1, 2), dtype=bool) + filter_configs = [{"name": "estimate-only", "parameter": None}] + + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + results = summarize_filter_results( + scenario_config={"manifold": "Euclidean"}, + filter_configs=filter_configs, + runtimes=runtimes, + groundtruths=groundtruths, + run_failed=run_failed, + last_estimates=last_estimates, + ) + + self.assertIs(results, filter_configs) + self.assertAlmostEqual(float(results[0]["error_mean"]), 0.0) + self.assertAlmostEqual(float(results[0]["failure_rate"]), 0.0) + @unittest.skipIf( pyrecest.backend.__backend_name__ in ("pytorch", "jax"), reason="Not supported on this backend",