Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion src/pyrecest/utils/assignment.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,10 +249,17 @@ def _solve_subproblem( # pylint: disable=too-many-locals
for row_index, col_index in subproblem.forbidden_pairs:
modified_cost_matrix[row_index, col_index] = large_cost

forbidden_pairs = set(subproblem.forbidden_pairs)
forced_rows = set()
forced_cols = set()
for row_index, col_index in subproblem.forced_pairs:
if row_index in forced_rows or col_index in forced_cols:
if (
row_index in forced_rows
or col_index in forced_cols
or (row_index, col_index) in forbidden_pairs
):
return None
if bool(augmented_cost_matrix[row_index, col_index] >= large_cost / 2.0):
return None
forced_rows.add(row_index)
forced_cols.add(col_index)
Expand Down
32 changes: 32 additions & 0 deletions tests/utils/test_assignment_murty_forced_prefix.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
import numpy as np

from pyrecest.utils.assignment import murty_k_best_assignments


def _assignment_tuple(solution):
return tuple(int(index) for index in solution["assignment"])


def test_murty_forced_prefix_subproblems_do_not_repeat_assignments():
cost_matrix = np.asarray(
[
[0.0, 100.0],
[1.0, 2.0],
]
)

solutions = murty_k_best_assignments(cost_matrix, k=5)
assignments = [_assignment_tuple(solution) for solution in solutions]

assert len(assignments) == 5
assert len(assignments) == len(set(assignments))
assert set(assignments) == {
(0, -1),
(-1, -1),
(-1, 0),
(0, 1),
(-1, 1),
}
assert [solution["cost"] for solution in solutions] == sorted(
solution["cost"] for solution in solutions
)
Loading