-
Notifications
You must be signed in to change notification settings - Fork 203
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add Hungarian algorithm for the linear assignment problem.
- Loading branch information
1 parent
474b4fd
commit bbd8386
Showing
3 changed files
with
363 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,291 @@ | ||
# Copyright 2024 DeepMind Technologies Limited. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# ============================================================================== | ||
"""The Hungarian algorithm for the linear assignment problem.""" | ||
|
||
from functools import partial | ||
|
||
import jax | ||
from jax import lax, numpy as jnp | ||
|
||
|
||
def hungarian_algorithm(costs, /): | ||
"""The Hungarian algorithm for the linear assignment problem. | ||
The assignment problem is a fundamental combinatorial optimization problem. | ||
In this problem, there are :math:`n` workers and :math:`m` jobs. | ||
For each worker and job, there is a cost associated with assigning that worker | ||
to that job. | ||
The goal is to assign at most one worker to each job and at most one job to | ||
each worker, in a way that minimizes the total cost of the assignment. | ||
Equivalently, given a weighted complete bipartite graph, the problem is to | ||
find a matching that minimizes the sum of the weights of the edges. | ||
The Hungarian algorithm is an :math:`O(n^3)` algorithm for this problem. | ||
Args: | ||
costs: A matrix of costs. | ||
Returns: | ||
A pair ``(i, j)`` where ``i`` is an array of row indices and ``j`` is an | ||
array of column indices. | ||
The cost of the assignment is ``cost_matrix[i, j].sum()``. | ||
""" | ||
|
||
if costs.shape[0] == 0 or costs.shape[1] == 0: | ||
return jnp.zeros(0, int), jnp.zeros(0, int) | ||
|
||
transpose = costs.shape[1] < costs.shape[0] | ||
|
||
if transpose: | ||
costs = costs.T | ||
|
||
costs = costs.astype(float) | ||
u = jnp.zeros(costs.shape[0], costs.dtype) | ||
v = jnp.zeros(costs.shape[1], costs.dtype) | ||
|
||
path = jnp.full(costs.shape[1], -1) | ||
col4row = jnp.full(costs.shape[0], -1) | ||
row4col = jnp.full(costs.shape[1], -1) | ||
|
||
init = costs, u, v, path, row4col, col4row | ||
costs, u, v, path, row4col, col4row = lax.fori_loop( | ||
0, costs.shape[0], _lsa_body, init | ||
) | ||
|
||
if transpose: | ||
i = col4row.argsort() | ||
return col4row[i], i | ||
else: | ||
return jnp.arange(costs.shape[0]), col4row | ||
|
||
|
||
def _find_short_augpath_while_body_inner_for(it, val): | ||
( | ||
remaining, | ||
min_value, | ||
costs, | ||
i, | ||
u, | ||
v, | ||
shortest_path_costs, | ||
path, | ||
lowest, | ||
row4col, | ||
index, | ||
) = val | ||
|
||
j = remaining[it] | ||
|
||
r = min_value + costs[i, j] - u[i] - v[j] | ||
|
||
path = path.at[j].set(jnp.where(r < shortest_path_costs[j], i, path[j])) | ||
|
||
shortest_path_costs = shortest_path_costs.at[j].min(r) | ||
|
||
index = jnp.where( | ||
(shortest_path_costs[j] < lowest) | ||
| ((shortest_path_costs[j] == lowest) & (row4col[j] == -1)), | ||
it, | ||
index, | ||
) | ||
|
||
lowest = jnp.minimum(lowest, shortest_path_costs[j]) | ||
|
||
return ( | ||
remaining, | ||
min_value, | ||
costs, | ||
i, | ||
u, | ||
v, | ||
shortest_path_costs, | ||
path, | ||
lowest, | ||
row4col, | ||
index, | ||
) | ||
|
||
|
||
def _find_short_augpath_while_body_tail_alt(val): | ||
remaining, index, row4col, sink, i, sc, num_remaining = val | ||
|
||
j = remaining[index] | ||
pred = row4col[j] == -1 | ||
sink = jnp.where(pred, j, sink) | ||
i = jnp.where(pred, i, row4col[j]) | ||
|
||
sc = sc.at[j].set(True) | ||
num_remaining -= 1 | ||
remaining = remaining.at[index].set(remaining[num_remaining]) | ||
|
||
return remaining, sink, i, sc, num_remaining | ||
|
||
|
||
def _find_short_augpath_while_body(val): | ||
( | ||
costs, | ||
u, | ||
v, | ||
path, | ||
row4col, | ||
current_row, | ||
min_value, | ||
num_remaining, | ||
remaining, | ||
sr, | ||
sc, | ||
shortest_path_costs, | ||
sink, | ||
) = val | ||
|
||
index = -1 | ||
lowest = jnp.inf | ||
sr = sr.at[current_row].set(True) | ||
|
||
init = ( | ||
remaining, | ||
min_value, | ||
costs, | ||
current_row, | ||
u, | ||
v, | ||
shortest_path_costs, | ||
path, | ||
lowest, | ||
row4col, | ||
index, | ||
) | ||
output = lax.fori_loop( | ||
0, num_remaining, _find_short_augpath_while_body_inner_for, init | ||
) | ||
( | ||
remaining, | ||
min_value, | ||
costs, | ||
current_row, | ||
u, | ||
v, | ||
shortest_path_costs, | ||
path, | ||
lowest, | ||
row4col, | ||
index, | ||
) = output | ||
|
||
min_value = lowest | ||
# infeasible costs matrix | ||
sink = jnp.where(min_value == jnp.inf, -1, sink) | ||
|
||
state = remaining, index, row4col, sink, current_row, sc, num_remaining | ||
(remaining, sink, current_row, sc, num_remaining) = jax.tree.map( | ||
partial(jnp.where, sink == -1), | ||
_find_short_augpath_while_body_tail_alt(state), | ||
(remaining, sink, current_row, sc, num_remaining), | ||
) | ||
|
||
return ( | ||
costs, | ||
u, | ||
v, | ||
path, | ||
row4col, | ||
current_row, | ||
min_value, | ||
num_remaining, | ||
remaining, | ||
sr, | ||
sc, | ||
shortest_path_costs, | ||
sink, | ||
) | ||
|
||
|
||
def _find_augmenting_path(costs, u, v, path, row4col, current_row): | ||
min_value = 0 | ||
num_remaining = costs.shape[1] | ||
remaining = jnp.arange(costs.shape[1])[::-1] | ||
|
||
sr = jnp.zeros(costs.shape[0], bool) | ||
sc = jnp.zeros(costs.shape[1], bool) | ||
|
||
shortest_path_costs = jnp.full(costs.shape[1], jnp.inf) | ||
|
||
sink = -1 | ||
|
||
init = ( | ||
costs, | ||
u, | ||
v, | ||
path, | ||
row4col, | ||
current_row, | ||
min_value, | ||
num_remaining, | ||
remaining, | ||
sr, | ||
sc, | ||
shortest_path_costs, | ||
sink, | ||
) | ||
output = lax.while_loop( | ||
lambda val: val[-1] == -1, _find_short_augpath_while_body, init | ||
) | ||
( | ||
costs, | ||
u, | ||
v, | ||
path, | ||
row4col, | ||
current_row, | ||
min_value, | ||
num_remaining, | ||
remaining, | ||
sr, | ||
sc, | ||
shortest_path_costs, | ||
sink, | ||
) = output | ||
|
||
return sink, min_value, sr, sc, shortest_path_costs, path | ||
|
||
|
||
def _lsa_body(current_row, val): | ||
costs, u, v, path, row4col, col4row = val | ||
|
||
sink, min_value, sr, sc, shortest_path_costs, path = _find_augmenting_path( | ||
costs, u, v, path, row4col, current_row | ||
) | ||
|
||
u = u.at[current_row].add(min_value) | ||
|
||
mask = sr & (jnp.arange(costs.shape[0]) != current_row) | ||
u = jnp.where(mask, u + min_value - shortest_path_costs[col4row], u) | ||
|
||
v = jnp.where(sc, v + shortest_path_costs - min_value, v) | ||
|
||
def augment(carry): | ||
sink, row4col, col4row, _ = carry | ||
i = path[sink] | ||
row4col = row4col.at[sink].set(i) | ||
col4row, sink = col4row.at[i].set(sink), col4row[i] | ||
breakvar = i == current_row | ||
return sink, row4col, col4row, breakvar | ||
|
||
sink, row4col, col4row, _ = lax.while_loop( | ||
lambda val: ~val[-1], augment, (sink, row4col, col4row, False) | ||
) | ||
|
||
return costs, u, v, path, row4col, col4row |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,70 @@ | ||
# Copyright 2024 DeepMind Technologies Limited. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# ============================================================================== | ||
"""Tests for the Hungarian algorithm.""" | ||
|
||
from absl.testing import absltest | ||
from absl.testing import parameterized | ||
from jax import random, numpy as jnp | ||
import scipy | ||
|
||
from optax.assignment._hungarian_algorithm import hungarian_algorithm | ||
|
||
|
||
class HungarianAlgorithmTest(parameterized.TestCase): | ||
|
||
@parameterized.product( | ||
n=[0, 1, 2, 4, 8, 16], | ||
m=[0, 1, 2, 4, 8, 16], | ||
) | ||
def test(self, n, m): | ||
|
||
def test_hungarian_algorithm(costs): | ||
i, j = hungarian_algorithm(costs) | ||
|
||
r = min(costs.shape) | ||
assert i.shape == (r,) | ||
assert j.shape == (r,) | ||
|
||
assert jnp.issubdtype(i.dtype, jnp.integer) | ||
assert jnp.issubdtype(j.dtype, jnp.integer) | ||
|
||
assert jnp.all(0 <= i) | ||
assert jnp.all(0 <= j) | ||
|
||
assert (i < costs.shape[0]).all() | ||
assert (j < costs.shape[1]).all() | ||
|
||
x = jnp.zeros(costs.shape[0], int).at[i].add(1) | ||
assert (x <= 1).all() | ||
assert x.sum() == r | ||
|
||
y = jnp.zeros(costs.shape[1], int).at[j].add(1) | ||
assert (y <= 1).all() | ||
assert y.sum() == r | ||
|
||
cost_optax = costs[i, j].sum() | ||
|
||
i_scipy, j_scipy = scipy.optimize.linear_sum_assignment(costs) | ||
cost_scipy = costs[i_scipy, j_scipy].sum() | ||
|
||
assert jnp.isclose(cost_optax, cost_scipy) | ||
|
||
key = random.key(0) | ||
costs = random.normal(key, (n, m)) | ||
test_hungarian_algorithm(costs) | ||
|
||
|
||
if __name__ == '__main__': | ||
absltest.main() |