11import argparse
2+ import gc
23import json
34import os
45import shutil
89import numpy as np
910import zarr
1011from scipy .spatial .distance import dice
12+ from scipy .ndimage import distance_transform_edt
1113from fastremap import remap , unique
1214import cc3d
1315from cc3d .types import StatisticsDict , StatisticsSlicesDict
1416
1517from zarr .errors import PathNotFoundError
16- from sklearn .metrics import accuracy_score , jaccard_score
18+ from sklearn .metrics import jaccard_score
1719from tqdm import tqdm
1820from upath import UPath
1921
@@ -257,12 +259,16 @@ def optimized_hausdorff_distances(
257259 return np .empty ((0 ,), dtype = np .float32 )
258260
259261 voxel_size = np .asarray (voxel_size , dtype = np .float64 )
262+ truth_stats = cc3d .statistics (truth_label )
263+ pred_stats = cc3d .statistics (pred_label )
260264
261265 def get_distance (i : int ):
262266 tid = int (truth_ids [i ])
263267 h_dist = compute_hausdorff_distance_roi (
264268 truth_label ,
269+ truth_stats ,
265270 pred_label ,
271+ pred_stats ,
266272 tid ,
267273 voxel_size ,
268274 hausdorff_distance_max ,
@@ -285,62 +291,59 @@ def get_distance(i: int):
285291 return dists
286292
287293
288- def bbox_for_label_cc3d (
294+ def bbox_for_label (
289295 stats : StatisticsDict | StatisticsSlicesDict ,
290- label_vol : np . ndarray ,
296+ ndim : int ,
291297 label_id : int ,
292298):
293299 """
294300 Try to get bbox without allocating a full boolean mask using cc3d statistics.
295301 Falls back to mask-based bbox if cc3d doesn't provide expected fields.
296302 Returns (mins, maxs) inclusive-exclusive in voxel indices, or None if missing.
297303 """
298- try :
299- # stats = cc3d.statistics(label_vol)
300- # cc3d.statistics usually returns dict-like with keys per label id.
301- # There are multiple API variants; try common patterns.
302- if "bounding_boxes" in stats :
303- bb = stats ["bounding_boxes" ].get (label_id , None )
304- if bb is None :
305- return None
306- # bb might be (z0,z1,y0,y1,x0,x1) with end exclusive
307- ndim = label_vol .ndim
304+ # stats = cc3d.statistics(label_vol)
305+ # cc3d.statistics usually returns dict-like with keys per label id.
306+ # There are multiple API variants; try common patterns.
307+ if "bounding_boxes" in stats :
308+ # bounding_boxes is a list where index corresponds to label_id
309+ bounding_boxes = stats ["bounding_boxes" ]
310+ if label_id >= len (bounding_boxes ):
311+ return None
312+ bb = bounding_boxes [label_id ]
313+ if bb is None :
314+ return None
315+ # bb is a tuple of slices, convert to (mins, maxs)
316+ if isinstance (bb , tuple ) and all (isinstance (s , slice ) for s in bb ):
317+ mins = [s .start for s in bb ]
318+ maxs = [s .stop for s in bb ]
319+ return mins , maxs
320+ # bb might be (z0,z1,y0,y1,x0,x1) with end exclusive
321+ mins = [bb [2 * k ] for k in range (ndim )]
322+ maxs = [bb [2 * k + 1 ] for k in range (ndim )]
323+ return mins , maxs
324+
325+ if label_id in stats :
326+ s = stats [label_id ]
327+ if "bounding_box" in s :
328+ bb = s ["bounding_box" ]
308329 mins = [bb [2 * k ] for k in range (ndim )]
309330 maxs = [bb [2 * k + 1 ] for k in range (ndim )]
310331 return mins , maxs
311332
312- if label_id in stats :
313- s = stats [label_id ]
314- if "bounding_box" in s :
315- bb = s ["bounding_box" ]
316- ndim = label_vol .ndim
317- mins = [bb [2 * k ] for k in range (ndim )]
318- maxs = [bb [2 * k + 1 ] for k in range (ndim )]
319- return mins , maxs
320- except Exception :
321- pass
322-
323- # Fallback: mask-based (allocates a temporary bool)
324- coords = np .where (label_vol == label_id )
325- if len (coords ) == 0 or coords [0 ].size == 0 :
326- return None
327- mins = [int (c .min ()) for c in coords ]
328- maxs = [int (c .max ()) + 1 for c in coords ]
329- return mins , maxs
330-
331333
332334def roi_slices_for_pair (
333- truth_label : np . ndarray ,
334- pred_label : np . ndarray ,
335+ truth_stats : StatisticsDict | StatisticsSlicesDict ,
336+ pred_stats : StatisticsDict | StatisticsSlicesDict ,
335337 tid : int ,
336338 voxel_size ,
339+ ndim : int ,
340+ shape : tuple [int , ...],
337341 max_distance : float ,
338342):
339343 """
340344 ROI = union(bbox(truth==tid), bbox(pred==tid)) padded by P derived from max_distance.
341345 Returns tuple of slices suitable for numpy indexing.
342346 """
343- ndim = truth_label .ndim
344347 vs = np .asarray (voxel_size , dtype = float )
345348 if vs .size != ndim :
346349 # tolerate vs longer (e.g. includes channel), take last ndim
@@ -349,14 +352,12 @@ def roi_slices_for_pair(
349352 # padding per axis in voxels
350353 pad = np .ceil (max_distance / vs ).astype (int ) + 2
351354
352- truth_stats = cc3d .statistics (truth_label )
353- tb = bbox_for_label_cc3d (truth_stats , truth_label , tid )
355+ tb = bbox_for_label (truth_stats , ndim , tid )
354356 if tb is None :
355357 return None # should not happen for tid from truth_ids
356358
357359 tmins , tmaxs = tb
358- pred_stats = cc3d .statistics (pred_label )
359- pb = bbox_for_label_cc3d (pred_stats , pred_label , tid )
360+ pb = bbox_for_label (pred_stats , ndim , tid )
360361 if pb is None :
361362 pmins , pmaxs = tmins , tmaxs
362363 else :
@@ -366,7 +367,6 @@ def roi_slices_for_pair(
366367 maxs = [max (tmaxs [d ], pmaxs [d ]) for d in range (ndim )]
367368
368369 # expand and clamp
369- shape = truth_label .shape
370370 out_slices = []
371371 for d in range (ndim ):
372372 a = max (0 , mins [d ] - int (pad [d ]))
@@ -377,7 +377,9 @@ def roi_slices_for_pair(
377377
378378def compute_hausdorff_distance_roi (
379379 truth_label : np .ndarray ,
380+ truth_stats : StatisticsDict | StatisticsSlicesDict ,
380381 pred_label : np .ndarray ,
382+ pred_stats : StatisticsDict | StatisticsSlicesDict ,
381383 tid : int ,
382384 voxel_size ,
383385 max_distance : float ,
@@ -388,9 +390,17 @@ def compute_hausdorff_distance_roi(
388390 Same metric as compute_hausdorff_distance(), but operates on an ROI slice
389391 and builds masks only inside ROI.
390392 """
391- from scipy . ndimage import distance_transform_edt
393+ ndim = truth_label . ndim
392394
393- roi = roi_slices_for_pair (truth_label , pred_label , tid , voxel_size , max_distance )
395+ roi = roi_slices_for_pair (
396+ truth_stats ,
397+ pred_stats ,
398+ tid ,
399+ voxel_size ,
400+ ndim ,
401+ truth_label .shape ,
402+ max_distance ,
403+ )
394404 if roi is None :
395405 return float (max_distance )
396406
@@ -407,7 +417,6 @@ def compute_hausdorff_distance_roi(
407417 if a_n == 0 or b_n == 0 :
408418 return float (max_distance )
409419
410- ndim = truth_label .ndim
411420 vs = np .asarray (voxel_size , dtype = np .float64 )
412421 if vs .size != ndim :
413422 vs = vs [- ndim :]
@@ -497,7 +506,7 @@ def score_instance(
497506
498507 # Compute the scores
499508 logging .info ("Computing accuracy score..." )
500- accuracy = accuracy_score ( truth_label . ravel (), pred_label . ravel ())
509+ accuracy = float (( truth_label == pred_label ). mean ())
501510 # When there are no Hausdorff distances, use np.inf so that
502511 # normalize_distance(hausdorff_dist, voxel_size) returns 0.0. This encodes
503512 # the absence of matched instances and ensures the combined_score is 0.0.
@@ -612,6 +621,7 @@ def score_label(
612621 mask = zarr .open (mask_path .path , mode = "r" )[:]
613622 pred_label = pred_label * mask
614623 truth_label = truth_label * mask
624+ del mask
615625
616626 # Compute the scores
617627 if label_name in instance_classes :
@@ -628,6 +638,9 @@ def score_label(
628638 results ["num_voxels" ] = int (np .prod (truth_label .shape ))
629639 results ["voxel_size" ] = crop .voxel_size
630640 results ["is_missing" ] = False
641+ # drop big arrays before returning
642+ del truth_label , pred_label , truth_label_ds
643+ gc .collect ()
631644 return crop_name , label_name , results
632645
633646
0 commit comments