Skip to content

Commit

Permalink
Add Hungarian algorithm for the linear assignment problem.
Browse files Browse the repository at this point in the history
  • Loading branch information
carlosgmartin committed Oct 2, 2024
1 parent 474b4fd commit bbd8386
Show file tree
Hide file tree
Showing 3 changed files with 363 additions and 0 deletions.
2 changes: 2 additions & 0 deletions optax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,7 @@
from optax._src.wrappers import ShouldSkipUpdateFunction
from optax._src.wrappers import skip_large_updates
from optax._src.wrappers import skip_not_finite
from optax.assignment._hungarian_algorithm import hungarian_algorithm


# TODO(mtthss): remove tree_utils aliases after updates.
Expand Down Expand Up @@ -340,6 +341,7 @@
"GradientTransformationExtraArgs",
"hinge_loss",
"huber_loss",
"hungarian_algorithm",
"identity",
"incremental_update",
"inject_hyperparams",
Expand Down
291 changes: 291 additions & 0 deletions optax/assignment/_hungarian_algorithm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,291 @@
# Copyright 2024 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""The Hungarian algorithm for the linear assignment problem."""

from functools import partial

import jax
from jax import lax, numpy as jnp


def hungarian_algorithm(costs, /):
  """Solve the linear assignment problem with the Hungarian algorithm.

  The assignment problem is a fundamental combinatorial optimization problem.
  In this problem, there are :math:`n` workers and :math:`m` jobs. For each
  worker and job, there is a cost associated with assigning that worker to
  that job. The goal is to assign at most one worker to each job and at most
  one job to each worker, in a way that minimizes the total cost of the
  assignment.

  Equivalently, given a weighted complete bipartite graph, the problem is to
  find a matching that minimizes the sum of the weights of the edges.

  The Hungarian algorithm is an :math:`O(n^3)` algorithm for this problem.

  Args:
    costs: A two-dimensional matrix of costs of shape ``(n, m)``. Any
      array-like accepted by ``jnp.asarray`` may be passed; the values are
      converted to floating point before solving.

  Returns:
    A pair ``(i, j)`` where ``i`` is an array of row indices and ``j`` is an
    array of column indices, both of length ``min(n, m)``. ``i`` is in
    ascending order. The cost of the assignment is ``costs[i, j].sum()``.

  Raises:
    ValueError: If ``costs`` is not two-dimensional.
  """
  # Normalise array-likes (nested lists, NumPy arrays) up front so that the
  # shape/dtype handling below always operates on a JAX array.
  costs = jnp.asarray(costs)
  if costs.ndim != 2:
    raise ValueError(
        f"costs must be a 2-D matrix, got an array with shape {costs.shape}."
    )

  # Nothing to assign when either side is empty.
  if costs.shape[0] == 0 or costs.shape[1] == 0:
    return jnp.zeros(0, int), jnp.zeros(0, int)

  # The solver assigns every row, so it needs rows <= columns. Work on the
  # transpose when the input is taller than it is wide.
  transpose = costs.shape[1] < costs.shape[0]
  if transpose:
    costs = costs.T

  costs = costs.astype(float)
  num_rows, num_cols = costs.shape

  # Dual variables for rows (u) and columns (v).
  u = jnp.zeros(num_rows, costs.dtype)
  v = jnp.zeros(num_cols, costs.dtype)

  # -1 marks "no predecessor" / "not assigned yet".
  path = jnp.full(num_cols, -1)
  col4row = jnp.full(num_rows, -1)
  row4col = jnp.full(num_cols, -1)

  init = costs, u, v, path, row4col, col4row
  costs, u, v, path, row4col, col4row = lax.fori_loop(
      0, num_rows, _lsa_body, init
  )

  if transpose:
    # col4row holds original row indices here (indexed by original column);
    # sort the pairs so the returned row indices are ascending.
    order = col4row.argsort()
    return col4row[order], order
  return jnp.arange(num_rows), col4row


def _find_short_augpath_while_body_inner_for(it, val):
  """Process the ``it``-th remaining column for the current row.

  Computes the candidate path cost of reaching column ``remaining[it]``
  through row ``i``, keeps the smaller of that and the stored cost, records
  ``i`` as the column's predecessor when the candidate is strictly smaller,
  and tracks which position in ``remaining`` currently holds the lowest cost.
  On a tie with the current lowest, a column whose ``row4col`` entry is ``-1``
  takes over as the selected position.

  Args:
    it: Position in ``remaining`` to process.
    val: The loop carry ``(remaining, min_value, costs, i, u, v,
      shortest_path_costs, path, lowest, row4col, index)``.

  Returns:
    The carry in the same layout, with ``shortest_path_costs``, ``path``,
    ``lowest`` and ``index`` updated.
  """
  (remaining, min_value, costs, i, u, v, shortest_path_costs, path, lowest,
   row4col, index) = val

  col = remaining[it]
  candidate = min_value + costs[i, col] - u[i] - v[col]

  # Predecessor is updated against the cost stored *before* the min below.
  is_shorter = candidate < shortest_path_costs[col]
  path = path.at[col].set(jnp.where(is_shorter, i, path[col]))
  shortest_path_costs = shortest_path_costs.at[col].min(candidate)

  col_cost = shortest_path_costs[col]
  strictly_lower = col_cost < lowest
  tie_on_free_col = (col_cost == lowest) & (row4col[col] == -1)
  index = jnp.where(strictly_lower | tie_on_free_col, it, index)
  lowest = jnp.minimum(lowest, col_cost)

  return (remaining, min_value, costs, i, u, v, shortest_path_costs, path,
          lowest, row4col, index)


def _find_short_augpath_while_body_tail_alt(val):
  """Consume the selected column and advance the search state.

  The column at ``remaining[index]`` is marked in ``sc`` and removed from the
  active prefix of ``remaining`` by overwriting its slot with the last active
  entry. If that column has no row (``row4col`` is ``-1``) it becomes the
  sink; otherwise the search continues from the row it is assigned to.

  Args:
    val: ``(remaining, index, row4col, sink, i, sc, num_remaining)``.

  Returns:
    ``(remaining, sink, i, sc, num_remaining)`` after the update.
  """
  remaining, index, row4col, sink, i, sc, num_remaining = val

  col = remaining[index]
  owner = row4col[col]
  col_is_free = owner == -1

  sink = jnp.where(col_is_free, col, sink)
  i = jnp.where(col_is_free, i, owner)
  sc = sc.at[col].set(True)

  # Swap-remove: shrink the active prefix and move its last entry into the
  # slot that was just consumed.
  num_remaining = num_remaining - 1
  last_active = remaining[num_remaining]
  remaining = remaining.at[index].set(last_active)

  return remaining, sink, i, sc, num_remaining


def _find_short_augpath_while_body(val):
  """Run one iteration of the augmenting-path search.

  Marks the current row in ``sr``, scans the first ``num_remaining`` entries
  of ``remaining`` to update path costs and pick the lowest-cost column, then
  consumes that column: it either becomes the sink or hands the search over
  to the row it is assigned to.

  Args:
    val: The carry ``(costs, u, v, path, row4col, current_row, min_value,
      num_remaining, remaining, sr, sc, shortest_path_costs, sink)``.

  Returns:
    The carry in the same layout after one iteration.
  """
  (costs, u, v, path, row4col, current_row, min_value, num_remaining,
   remaining, sr, sc, shortest_path_costs, sink) = val

  sr = sr.at[current_row].set(True)

  # Scan every still-unvisited column; ``lowest`` starts at +inf and
  # ``index`` at -1 (no column selected yet).
  scan_carry = (remaining, min_value, costs, current_row, u, v,
                shortest_path_costs, path, jnp.inf, row4col, -1)
  scan_carry = lax.fori_loop(
      0, num_remaining, _find_short_augpath_while_body_inner_for, scan_carry
  )
  (remaining, min_value, costs, current_row, u, v, shortest_path_costs, path,
   lowest, row4col, index) = scan_carry

  min_value = lowest
  # infeasible costs matrix
  sink = jnp.where(min_value == jnp.inf, -1, sink)

  # Compute the advanced state unconditionally, then select it element-wise
  # only while no sink has been found.
  unchanged = (remaining, sink, current_row, sc, num_remaining)
  advanced = _find_short_augpath_while_body_tail_alt(
      (remaining, index, row4col, sink, current_row, sc, num_remaining)
  )
  still_searching = sink == -1
  remaining, sink, current_row, sc, num_remaining = jax.tree.map(
      lambda new, old: jnp.where(still_searching, new, old),
      advanced,
      unchanged,
  )

  return (costs, u, v, path, row4col, current_row, min_value, num_remaining,
          remaining, sr, sc, shortest_path_costs, sink)


def _find_augmenting_path(costs, u, v, path, row4col, current_row):
  """Search for an augmenting path starting at ``current_row``.

  Sets up the search state (all columns remaining, nothing visited, every
  path cost at +inf, no sink) and iterates
  ``_find_short_augpath_while_body`` until a sink column has been found.

  Args:
    costs: The cost matrix.
    u: Row dual variables.
    v: Column dual variables.
    path: Predecessor row for each column.
    row4col: Row assigned to each column, ``-1`` where unassigned.
    current_row: The row to start the search from.

  Returns:
    ``(sink, min_value, sr, sc, shortest_path_costs, path)`` taken from the
    final search state.
  """
  num_rows = costs.shape[0]
  num_cols = costs.shape[1]

  def sink_not_found(state):
    return state[-1] == -1

  state = (
      costs,
      u,
      v,
      path,
      row4col,
      current_row,
      0,  # min_value
      num_cols,  # num_remaining
      jnp.arange(num_cols)[::-1],  # remaining
      jnp.zeros(num_rows, bool),  # sr: rows visited
      jnp.zeros(num_cols, bool),  # sc: columns visited
      jnp.full(num_cols, jnp.inf),  # shortest_path_costs
      -1,  # sink
  )
  state = lax.while_loop(sink_not_found, _find_short_augpath_while_body, state)

  (_, _, _, path, _, _, min_value, _, _, sr, sc, shortest_path_costs,
   sink) = state
  return sink, min_value, sr, sc, shortest_path_costs, path


def _lsa_body(current_row, val):
  """Assign ``current_row`` by augmenting along a shortest path.

  Finds an augmenting path from ``current_row``, updates the dual variables
  ``u`` and ``v`` for the rows and columns visited during the search, and then
  walks the path backwards from the sink, flipping assignments until
  ``current_row`` itself has been assigned.

  Args:
    current_row: Index of the row to assign in this iteration.
    val: The carry ``(costs, u, v, path, row4col, col4row)``.

  Returns:
    The carry in the same layout with updated duals and assignments.
  """
  costs, u, v, path, row4col, col4row = val

  found = _find_augmenting_path(costs, u, v, path, row4col, current_row)
  sink, min_value, sr, sc, shortest_path_costs, path = found

  # Dual update: the starting row first, then every other visited row, then
  # every visited column.
  u = u.at[current_row].add(min_value)
  other_visited_rows = sr & (jnp.arange(costs.shape[0]) != current_row)
  u = jnp.where(
      other_visited_rows, u + min_value - shortest_path_costs[col4row], u
  )
  v = jnp.where(sc, v + shortest_path_costs - min_value, v)

  def not_finished(carry):
    return ~carry[-1]

  def flip_one(carry):
    col, row4col, col4row, _ = carry
    row = path[col]
    row4col = row4col.at[col].set(row)
    # Read the row's previous column before overwriting it; that column is
    # the next one to re-assign.
    previous_col = col4row[row]
    col4row = col4row.at[row].set(col)
    return previous_col, row4col, col4row, row == current_row

  _, row4col, col4row, _ = lax.while_loop(
      not_finished, flip_one, (sink, row4col, col4row, False)
  )

  return costs, u, v, path, row4col, col4row
70 changes: 70 additions & 0 deletions optax/assignment/_hungarian_algorithm_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
# Copyright 2024 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for the Hungarian algorithm."""

from absl.testing import absltest
from absl.testing import parameterized
from jax import random, numpy as jnp
import scipy

from optax.assignment._hungarian_algorithm import hungarian_algorithm


class HungarianAlgorithmTest(parameterized.TestCase):
  """Checks `hungarian_algorithm` against SciPy on random cost matrices."""

  @parameterized.product(
      n=[0, 1, 2, 4, 8, 16],
      m=[0, 1, 2, 4, 8, 16],
  )
  def test(self, n, m):
    """Validates the assignment returned for a random ``(n, m)`` matrix."""
    costs = random.normal(random.key(0), (n, m))
    num_pairs = min(costs.shape)

    rows, cols = hungarian_algorithm(costs)

    # One index pair per row or column, whichever side is smaller.
    assert rows.shape == (num_pairs,)
    assert cols.shape == (num_pairs,)

    assert jnp.issubdtype(rows.dtype, jnp.integer)
    assert jnp.issubdtype(cols.dtype, jnp.integer)

    # Indices are within bounds.
    assert jnp.all(0 <= rows)
    assert jnp.all(0 <= cols)

    assert (rows < costs.shape[0]).all()
    assert (cols < costs.shape[1]).all()

    # No row and no column is used more than once.
    row_uses = jnp.zeros(costs.shape[0], int).at[rows].add(1)
    assert (row_uses <= 1).all()
    assert row_uses.sum() == num_pairs

    col_uses = jnp.zeros(costs.shape[1], int).at[cols].add(1)
    assert (col_uses <= 1).all()
    assert col_uses.sum() == num_pairs

    # The total cost matches SciPy's reference solver.
    cost_optax = costs[rows, cols].sum()

    rows_scipy, cols_scipy = scipy.optimize.linear_sum_assignment(costs)
    cost_scipy = costs[rows_scipy, cols_scipy].sum()

    assert jnp.isclose(cost_optax, cost_scipy)


# Run the absl test runner when this file is executed as a script.
if __name__ == '__main__':
  absltest.main()

0 comments on commit bbd8386

Please sign in to comment.