Skip to content

Commit

Permalink
Add improved version of Hungarian algorithm.
Browse files Browse the repository at this point in the history
  • Loading branch information
carlosgmartin committed Nov 20, 2024
1 parent 63cdeb4 commit 64cdc1d
Show file tree
Hide file tree
Showing 5 changed files with 198 additions and 6 deletions.
2 changes: 1 addition & 1 deletion optax/_src/alias.py
Original file line number Diff line number Diff line change
Expand Up @@ -2482,7 +2482,7 @@ def lbfgs(
... )
... params = optax.apply_updates(params, updates)
... print('Objective function: ', f(params))
Objective function: 7.5166864
Objective function: 7.516686...
Objective function: 7.460699e-14
Objective function: 2.6505726e-28
Objective function: 0.0
Expand Down
2 changes: 1 addition & 1 deletion optax/assignment/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,4 @@

# pylint:disable=g-importing-member

from optax.assignment._hungarian_algorithm import hungarian_algorithm
from optax.assignment._hungarian_algorithm_v2 import hungarian_algorithm
8 changes: 4 additions & 4 deletions optax/assignment/_hungarian_algorithm_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import jax
import jax.numpy as jnp
import jax.random as jrd
from optax.assignment import _hungarian_algorithm
from optax.assignment import hungarian_algorithm
import scipy


Expand All @@ -33,7 +33,7 @@ def test_hungarian_algorithm(self, n, m):
key = jrd.key(0)
costs = jrd.normal(key, (n, m))

i, j = _hungarian_algorithm.hungarian_algorithm(costs)
i, j = hungarian_algorithm(costs)

r = min(costs.shape)

Expand Down Expand Up @@ -95,7 +95,7 @@ def test_hungarian_algorithm_vmap(self, k, n, m):
costs = jrd.normal(key, (k, n, m))

with self.subTest('works under vmap'):
i, j = jax.vmap(_hungarian_algorithm.hungarian_algorithm)(costs)
i, j = jax.vmap(hungarian_algorithm)(costs)

r = min(costs.shape[1:])

Expand All @@ -110,7 +110,7 @@ def test_hungarian_algorithm_jit(self):
costs = jrd.normal(key, (20, 30))

with self.subTest('works under jit'):
i, j = jax.jit(_hungarian_algorithm.hungarian_algorithm)(costs)
i, j = jax.jit(hungarian_algorithm)(costs)

r = min(costs.shape)

Expand Down
File renamed without changes.
192 changes: 192 additions & 0 deletions optax/assignment/_hungarian_algorithm_v2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,192 @@
# Copyright 2024 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""The Hungarian algorithm for the linear assignment problem."""

import jax
from jax import lax, numpy as jnp


def _masked_argmin(array, mask):
  """Return the index of the smallest entry of ``array`` selected by ``mask``.

  Entries where ``mask`` is false are replaced by ``+inf`` before taking the
  argmin, so they are only returned when no entry is selected (in which case
  the result is whatever ``jnp.argmin`` yields for an all-``inf`` array).

  Args:
    array: Array of values to search.
    mask: Boolean array, broadcastable against ``array``; true marks the
      entries that are eligible.

  Returns:
    The index (as computed by ``jnp.argmin``) of the minimal eligible entry.
  """
  eligible = jnp.where(mask, array, jnp.inf)
  # Narrow the static type for type checkers; `jnp.where` returns an array.
  assert isinstance(eligible, jax.Array)
  return jnp.argmin(eligible)


def hungarian_algorithm(cost_matrix):
  r"""The Hungarian algorithm for the linear assignment problem.

  In `this problem <https://en.wikipedia.org/wiki/Linear_assignment_problem>`_,
  we are given an :math:`n \times m` cost matrix. The goal is to compute an
  assignment, i.e. a set of pairs of rows and columns, in such a way that:

  - At most one column is assigned to each row.
  - At most one row is assigned to each column.
  - The total number of assignments is :math:`\min(n, m)`.
  - The assignment minimizes the sum of costs.

  Equivalently, given a weighted complete bipartite graph, the problem is to
  find a maximum-cardinality matching that minimizes the sum of the weights of
  the edges included in the matching.

  Formally, the problem is as follows. Given
  :math:`C \in \mathbb{R}^{n \times m}`, solve the following
  `integer linear program
  <https://en.wikipedia.org/wiki/Integer_linear_program>`_:

  .. math::

    \begin{align*}
      \text{minimize} \quad & \sum_{i \in [n]} \sum_{j \in [m]} C_{ij} X_{ij}
      \\ \text{subject to} \quad
      & X_{ij} \in \{0, 1\} & \forall i \in [n], j \in [m] \\
      & \sum_{i \in [n]} X_{ij} \leq 1 & \forall j \in [m] \\
      & \sum_{j \in [m]} X_{ij} \leq 1 & \forall i \in [n] \\
      & \sum_{i \in [n]} \sum_{j \in [m]} X_{ij} = \min(n, m)
    \end{align*}

  The `Hungarian algorithm <https://en.wikipedia.org/wiki/Hungarian_algorithm>`_
  is a cubic-time algorithm that solves this problem.

  This implementation is based on that of the Scenic library (see references).

  Args:
    cost_matrix: A matrix of costs.

  Returns:
    A pair ``(i, j)`` where ``i`` is an array of row indices and ``j`` is an
    array of column indices.
    The cost of the assignment is ``cost_matrix[i, j].sum()``.

  Examples:
    >>> import optax
    >>> from jax import numpy as jnp
    >>> cost = jnp.array(
    ...   [
    ...     [8, 4, 7],
    ...     [5, 2, 3],
    ...     [9, 6, 7],
    ...     [9, 4, 8],
    ...   ])
    >>> i, j = optax.assignment.hungarian_algorithm(cost)
    >>> print("cost:", cost[i, j].sum())
    cost: 15
    >>> cost = jnp.array(
    ...   [
    ...     [90, 80, 75, 70],
    ...     [35, 85, 55, 65],
    ...     [125, 95, 90, 95],
    ...     [45, 110, 95, 115],
    ...     [50, 100, 90, 100],
    ...   ])
    >>> i, j = optax.assignment.hungarian_algorithm(cost)
    >>> print("cost:", cost[i, j].sum())
    cost: 265

  References:
    Dehghani et al., `Scenic: A JAX Library for Computer Vision Research and
    Beyond <https://arxiv.org/abs/2110.11403>`_, 2022
  """
  # Nothing to assign when either dimension is empty.
  if cost_matrix.shape[0] == 0 or cost_matrix.shape[1] == 0:
    return jnp.zeros(0, int), jnp.zeros(0, int)

  # Work on a matrix with no more rows than columns; the swap is undone when
  # the result is returned.
  transposed = cost_matrix.shape[0] > cost_matrix.shape[1]
  costs = cost_matrix.T if transposed else cost_matrix
  rows, cols = costs.shape

  # Rows and columns are 1-based inside the loops below. Column 0 is a virtual
  # column that holds the row currently being inserted, and a value of 0 in
  # `col_to_row` means "no row assigned to this column".

  def insert_row(carry, row):
    """Scan body: extends the current matching so that it also covers `row`."""
    row_pot, col_pot, col_to_row = carry
    col_to_row = col_to_row.at[0].set(row + 1)

    def search_continues(search):
      # Keep searching until the current column has no row assigned to it.
      cur_col = search[-1]
      return col_to_row[cur_col] != 0

    def search_step(search):
      row_pot, col_pot, visited, slack, prev_col, cur_col = search

      # Visit the current column and look at the row assigned to it.
      visited = visited.at[cur_col].set(True)
      unvisited = ~visited[1:]
      cur_row = col_to_row[cur_col]

      # Reduced costs from that row to every unvisited column; wherever they
      # improve on the best value seen so far, record both the value and the
      # column we came from.
      reduced = costs[cur_row - 1, :] - row_pot[cur_row] - col_pot[1:]
      reduced = jnp.where(unvisited, reduced, jnp.inf)
      improved = reduced < slack
      prev_col = jnp.where(improved, cur_col, prev_col)
      slack = jnp.where(improved, reduced, slack)  # type: ignore

      # The next column is the unvisited one with the smallest slack; that
      # smallest slack is the amount by which the potentials are shifted.
      next_col = _masked_argmin(slack, unvisited) + 1
      shift = jnp.min(slack, where=unvisited, initial=jnp.inf)

      # Shift potentials. Unvisited columns are redirected to index rows + 1,
      # a spare slot of `row_pot` (it has rows + 2 entries) that never
      # corresponds to a real row.
      targets = jnp.where(visited, col_to_row, rows + 1)
      row_pot = row_pot.at[targets].add(shift)
      col_pot = jnp.where(visited, col_pot - shift, col_pot)
      slack = jnp.where(unvisited, slack - shift, slack)

      return row_pot, col_pot, visited, slack, prev_col, next_col

    # Search state: potentials, visited flags (including virtual column 0),
    # per-column best slack, per-column predecessor, and the current column
    # (starting from the virtual column).
    search = (
        row_pot,
        col_pot,
        jnp.zeros(cols + 1, bool),
        jnp.full(cols, jnp.inf),
        jnp.zeros(cols, int),
        0,
    )
    search = lax.while_loop(search_continues, search_step, search)
    row_pot, col_pot, _, _, prev_col, last_col = search

    def backtrack_continues(trace):
      # Stop once the virtual column 0 has been reached.
      _, col = trace
      return col != 0

    def backtrack_step(trace):
      # Hand the predecessor column's row over to the current column, then
      # move on to the predecessor.
      mapping, col = trace
      earlier_col = prev_col[col - 1]
      mapping = mapping.at[col].set(mapping[earlier_col])
      return mapping, earlier_col

    # Walk the recorded predecessors back from the free column found by the
    # search, reassigning rows along the way.
    col_to_row, _ = lax.while_loop(
        backtrack_continues, backtrack_step, (col_to_row, last_col)
    )

    return (row_pot, col_pot, col_to_row), None

  initial = (
      jnp.zeros(rows + 2),  # row potentials (plus the spare slot)
      jnp.zeros(cols + 1),  # column potentials
      jnp.zeros(cols + 1, int),  # maps columns to rows
  )

  # Insert the rows of the cost matrix one at a time.
  (_, _, col_to_row), _ = lax.scan(insert_row, initial, jnp.arange(rows))
  # At this point minus the column potential at index 0 is the matching cost.

  assigned_rows = col_to_row[1:]
  if rows == cols:
    # top_k is costly, so skip it when possible (i.e. for square matrices).
    row_idx, col_idx = assigned_rows, jnp.arange(rows)
  else:
    # Unassigned columns hold 0, so the `rows` largest entries are exactly the
    # assigned ones; top_k also yields their column positions.
    row_idx, col_idx = lax.top_k(assigned_rows, rows)

  row_idx = row_idx - 1  # switch back to 0-based indexing

  if transposed:
    return col_idx, row_idx

  return row_idx, col_idx

0 comments on commit 64cdc1d

Please sign in to comment.