Skip to content

Commit 8daa864

Browse files
committed
perf: optimize NaN infill step
1 parent fc3de1d commit 8daa864

File tree

1 file changed

+18
-10
lines changed

1 file changed

+18
-10
lines changed

soil_id/global_soil.py

Lines changed: 18 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -862,16 +862,24 @@ def rank_soils_global(
862862

863863
# Infill NaN data
864864
for idx, dis_mat in enumerate(dis_mat_list):
865-
soil_slice = soil_matrix.iloc[idx]
866-
for j in range(len(dis_mat)):
867-
for k in range(len(dis_mat[j])):
868-
if np.isnan(dis_mat[j, k]):
869-
if (soil_slice[j] and not soil_slice[k]) or (
870-
not soil_slice[j] and soil_slice[k]
871-
):
872-
dis_mat[j, k] = dis_max
873-
elif not (soil_slice[j] or soil_slice[k]):
874-
dis_mat[j, k] = 0
865+
soil_slice = soil_matrix.iloc[idx].to_numpy(dtype=bool)
866+
867+
# Mask of NaNs in dis_mat
868+
nan_mask = np.isnan(dis_mat)
869+
870+
# Broadcast soil slice to row and column vectors
871+
soil_row = soil_slice[:, np.newaxis] # column vector
872+
soil_col = soil_slice[np.newaxis, :] # row vector
873+
874+
# Matrix where one is soil and the other isn't
875+
mismatch_mask = (soil_row & ~soil_col) | (~soil_row & soil_col)
876+
877+
# Matrix where neither is soil
878+
nonsoil_mask = ~soil_row & ~soil_col
879+
880+
# Set values for NaNs based on condition
881+
dis_mat[nan_mask & mismatch_mask] = dis_max
882+
dis_mat[nan_mask & nonsoil_mask] = 0
875883

876884
# Weighted average of depth-wise dissimilarity matrices
877885
dis_mat_list_masked = np.ma.MaskedArray(dis_mat_list, mask=np.isnan(dis_mat_list))

0 commit comments

Comments
 (0)