diff --git a/gam/clustering.py b/gam/clustering.py index c909dd5..504deb8 100644 --- a/gam/clustering.py +++ b/gam/clustering.py @@ -807,6 +807,9 @@ def _swap_pairs( i = a_swap[1] d_ji = d[:, i] + E_batch = E[idx_ref] + D_batch = D[idx_ref] + if h_i == "h": if isinstance(X, da.Array): d_jh = dask_distance.cdist( @@ -818,14 +821,14 @@ def _swap_pairs( X[idx_ref, :], X[h, :].reshape(1, -1), metric=dist_func ).squeeze() K_jih = np.zeros(self.batchsize) - diff_ji = d_ji[idx_ref] - D[idx_ref] + diff_ji = d_ji[idx_ref] - D_batch idx = np.where(diff_ji > 0) - diff_jh = d_jh - D[idx_ref] + diff_jh = d_jh - D_batch K_jih[idx] = np.minimum(diff_jh[idx], 0) idx = np.where(diff_ji == 0) - K_jih[idx] = np.minimum(d_jh[idx], E[idx]) - D[idx] + K_jih[idx] = np.minimum(d_jh[idx], E_batch[idx]) - D_batch[idx] # base-line update of mu and sigma mu_x[h, i] = ((n_used_ref * mu_x[h, i]) + np.sum(K_jih)) / ( @@ -1011,3 +1014,4 @@ def _swap_bandit(self, X, centers, dist_func, max_iter, tol, verbose): # our best swap would degrade the clustering (min Tih > 0) current_iteration = current_iteration + 1 return centers + \ No newline at end of file