From a20964df71e297408ea9cc01c5deefa838b7e70f Mon Sep 17 00:00:00 2001 From: Ashray Manepalli Date: Mon, 5 Aug 2024 13:29:15 -0700 Subject: [PATCH 1/2] pushing fix for batch idx issue --- gam/clustering.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/gam/clustering.py b/gam/clustering.py index c909dd5..f841944 100644 --- a/gam/clustering.py +++ b/gam/clustering.py @@ -807,6 +807,9 @@ def _swap_pairs( i = a_swap[1] d_ji = d[:, i] + E_batch = E[idx_ref] + D_batch = D[idx_ref] + if h_i == "h": if isinstance(X, da.Array): d_jh = dask_distance.cdist( @@ -818,14 +821,14 @@ def _swap_pairs( X[idx_ref, :], X[h, :].reshape(1, -1), metric=dist_func ).squeeze() K_jih = np.zeros(self.batchsize) - diff_ji = d_ji[idx_ref] - D[idx_ref] + diff_ji = d_ji[idx_ref] - D_batch idx = np.where(diff_ji > 0) - diff_jh = d_jh - D[idx_ref] + diff_jh = d_jh - D_batch K_jih[idx] = np.minimum(diff_jh[idx], 0) idx = np.where(diff_ji == 0) - K_jih[idx] = np.minimum(d_jh[idx], E[idx]) - D[idx] + K_jih[idx] = np.minimum(d_jh[idx], E_batch[idx]) - D_batch[idx] # base-line update of mu and sigma mu_x[h, i] = ((n_used_ref * mu_x[h, i]) + np.sum(K_jih)) / ( @@ -1010,4 +1013,4 @@ def _swap_bandit(self, X, centers, dist_func, max_iter, tol, verbose): print("\tNO Swap - ", i_swap, h_swap, Tih_min) # our best swap would degrade the clustering (min Tih > 0) current_iteration = current_iteration + 1 - return centers + return centers \ No newline at end of file From 00faca8019d3af5db80e88f7597f98661b4195cb Mon Sep 17 00:00:00 2001 From: Ashray Manepalli Date: Mon, 5 Aug 2024 13:30:09 -0700 Subject: [PATCH 2/2] adding final newline --- gam/clustering.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/gam/clustering.py b/gam/clustering.py index f841944..504deb8 100644 --- a/gam/clustering.py +++ b/gam/clustering.py @@ -1013,4 +1013,5 @@ def _swap_bandit(self, X, centers, dist_func, max_iter, tol, verbose): print("\tNO Swap - ", i_swap, h_swap, Tih_min) # our best swap would degrade the clustering (min Tih > 0) current_iteration = current_iteration + 1 - return centers \ No newline at end of file + return centers + \ No newline at end of file