Skip to content

Commit b70d175

Browse files
feat: enhance Hausdorff distance computation by utilizing statistics for truth and prediction labels
1 parent be461a3 commit b70d175

2 files changed

Lines changed: 119 additions & 48 deletions

File tree

src/cellmap_segmentation_challenge/evaluate.py

Lines changed: 57 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import argparse
2+
import gc
23
import json
34
import os
45
import shutil
@@ -8,12 +9,13 @@
89
import numpy as np
910
import zarr
1011
from scipy.spatial.distance import dice
12+
from scipy.ndimage import distance_transform_edt
1113
from fastremap import remap, unique
1214
import cc3d
1315
from cc3d.types import StatisticsDict, StatisticsSlicesDict
1416

1517
from zarr.errors import PathNotFoundError
16-
from sklearn.metrics import accuracy_score, jaccard_score
18+
from sklearn.metrics import jaccard_score
1719
from tqdm import tqdm
1820
from upath import UPath
1921

@@ -257,12 +259,16 @@ def optimized_hausdorff_distances(
257259
return np.empty((0,), dtype=np.float32)
258260

259261
voxel_size = np.asarray(voxel_size, dtype=np.float64)
262+
truth_stats = cc3d.statistics(truth_label)
263+
pred_stats = cc3d.statistics(pred_label)
260264

261265
def get_distance(i: int):
262266
tid = int(truth_ids[i])
263267
h_dist = compute_hausdorff_distance_roi(
264268
truth_label,
269+
truth_stats,
265270
pred_label,
271+
pred_stats,
266272
tid,
267273
voxel_size,
268274
hausdorff_distance_max,
@@ -285,62 +291,59 @@ def get_distance(i: int):
285291
return dists
286292

287293

288-
def bbox_for_label_cc3d(
294+
def bbox_for_label(
289295
stats: StatisticsDict | StatisticsSlicesDict,
290-
label_vol: np.ndarray,
296+
ndim: int,
291297
label_id: int,
292298
):
293299
"""
294300
Try to get bbox without allocating a full boolean mask using cc3d statistics.
295301
Falls back to mask-based bbox if cc3d doesn't provide expected fields.
296302
Returns (mins, maxs) inclusive-exclusive in voxel indices, or None if missing.
297303
"""
298-
try:
299-
# stats = cc3d.statistics(label_vol)
300-
# cc3d.statistics usually returns dict-like with keys per label id.
301-
# There are multiple API variants; try common patterns.
302-
if "bounding_boxes" in stats:
303-
bb = stats["bounding_boxes"].get(label_id, None)
304-
if bb is None:
305-
return None
306-
# bb might be (z0,z1,y0,y1,x0,x1) with end exclusive
307-
ndim = label_vol.ndim
304+
# stats = cc3d.statistics(label_vol)
305+
# cc3d.statistics usually returns dict-like with keys per label id.
306+
# There are multiple API variants; try common patterns.
307+
if "bounding_boxes" in stats:
308+
# bounding_boxes is a list where index corresponds to label_id
309+
bounding_boxes = stats["bounding_boxes"]
310+
if label_id >= len(bounding_boxes):
311+
return None
312+
bb = bounding_boxes[label_id]
313+
if bb is None:
314+
return None
315+
# bb is a tuple of slices, convert to (mins, maxs)
316+
if isinstance(bb, tuple) and all(isinstance(s, slice) for s in bb):
317+
mins = [s.start for s in bb]
318+
maxs = [s.stop for s in bb]
319+
return mins, maxs
320+
# bb might be (z0,z1,y0,y1,x0,x1) with end exclusive
321+
mins = [bb[2 * k] for k in range(ndim)]
322+
maxs = [bb[2 * k + 1] for k in range(ndim)]
323+
return mins, maxs
324+
325+
if label_id in stats:
326+
s = stats[label_id]
327+
if "bounding_box" in s:
328+
bb = s["bounding_box"]
308329
mins = [bb[2 * k] for k in range(ndim)]
309330
maxs = [bb[2 * k + 1] for k in range(ndim)]
310331
return mins, maxs
311332

312-
if label_id in stats:
313-
s = stats[label_id]
314-
if "bounding_box" in s:
315-
bb = s["bounding_box"]
316-
ndim = label_vol.ndim
317-
mins = [bb[2 * k] for k in range(ndim)]
318-
maxs = [bb[2 * k + 1] for k in range(ndim)]
319-
return mins, maxs
320-
except Exception:
321-
pass
322-
323-
# Fallback: mask-based (allocates a temporary bool)
324-
coords = np.where(label_vol == label_id)
325-
if len(coords) == 0 or coords[0].size == 0:
326-
return None
327-
mins = [int(c.min()) for c in coords]
328-
maxs = [int(c.max()) + 1 for c in coords]
329-
return mins, maxs
330-
331333

332334
def roi_slices_for_pair(
333-
truth_label: np.ndarray,
334-
pred_label: np.ndarray,
335+
truth_stats: StatisticsDict | StatisticsSlicesDict,
336+
pred_stats: StatisticsDict | StatisticsSlicesDict,
335337
tid: int,
336338
voxel_size,
339+
ndim: int,
340+
shape: tuple[int, ...],
337341
max_distance: float,
338342
):
339343
"""
340344
ROI = union(bbox(truth==tid), bbox(pred==tid)) padded by P derived from max_distance.
341345
Returns tuple of slices suitable for numpy indexing.
342346
"""
343-
ndim = truth_label.ndim
344347
vs = np.asarray(voxel_size, dtype=float)
345348
if vs.size != ndim:
346349
# tolerate vs longer (e.g. includes channel), take last ndim
@@ -349,14 +352,12 @@ def roi_slices_for_pair(
349352
# padding per axis in voxels
350353
pad = np.ceil(max_distance / vs).astype(int) + 2
351354

352-
truth_stats = cc3d.statistics(truth_label)
353-
tb = bbox_for_label_cc3d(truth_stats, truth_label, tid)
355+
tb = bbox_for_label(truth_stats, ndim, tid)
354356
if tb is None:
355357
return None # should not happen for tid from truth_ids
356358

357359
tmins, tmaxs = tb
358-
pred_stats = cc3d.statistics(pred_label)
359-
pb = bbox_for_label_cc3d(pred_stats, pred_label, tid)
360+
pb = bbox_for_label(pred_stats, ndim, tid)
360361
if pb is None:
361362
pmins, pmaxs = tmins, tmaxs
362363
else:
@@ -366,7 +367,6 @@ def roi_slices_for_pair(
366367
maxs = [max(tmaxs[d], pmaxs[d]) for d in range(ndim)]
367368

368369
# expand and clamp
369-
shape = truth_label.shape
370370
out_slices = []
371371
for d in range(ndim):
372372
a = max(0, mins[d] - int(pad[d]))
@@ -377,7 +377,9 @@ def roi_slices_for_pair(
377377

378378
def compute_hausdorff_distance_roi(
379379
truth_label: np.ndarray,
380+
truth_stats: StatisticsDict | StatisticsSlicesDict,
380381
pred_label: np.ndarray,
382+
pred_stats: StatisticsDict | StatisticsSlicesDict,
381383
tid: int,
382384
voxel_size,
383385
max_distance: float,
@@ -388,9 +390,17 @@ def compute_hausdorff_distance_roi(
388390
Same metric as compute_hausdorff_distance(), but operates on an ROI slice
389391
and builds masks only inside ROI.
390392
"""
391-
from scipy.ndimage import distance_transform_edt
393+
ndim = truth_label.ndim
392394

393-
roi = roi_slices_for_pair(truth_label, pred_label, tid, voxel_size, max_distance)
395+
roi = roi_slices_for_pair(
396+
truth_stats,
397+
pred_stats,
398+
tid,
399+
voxel_size,
400+
ndim,
401+
truth_label.shape,
402+
max_distance,
403+
)
394404
if roi is None:
395405
return float(max_distance)
396406

@@ -407,7 +417,6 @@ def compute_hausdorff_distance_roi(
407417
if a_n == 0 or b_n == 0:
408418
return float(max_distance)
409419

410-
ndim = truth_label.ndim
411420
vs = np.asarray(voxel_size, dtype=np.float64)
412421
if vs.size != ndim:
413422
vs = vs[-ndim:]
@@ -497,7 +506,7 @@ def score_instance(
497506

498507
# Compute the scores
499508
logging.info("Computing accuracy score...")
500-
accuracy = accuracy_score(truth_label.ravel(), pred_label.ravel())
509+
accuracy = float((truth_label == pred_label).mean())
501510
# When there are no Hausdorff distances, use np.inf so that
502511
# normalize_distance(hausdorff_dist, voxel_size) returns 0.0. This encodes
503512
# the absence of matched instances and ensures the combined_score is 0.0.
@@ -612,6 +621,7 @@ def score_label(
612621
mask = zarr.open(mask_path.path, mode="r")[:]
613622
pred_label = pred_label * mask
614623
truth_label = truth_label * mask
624+
del mask
615625

616626
# Compute the scores
617627
if label_name in instance_classes:
@@ -628,6 +638,9 @@ def score_label(
628638
results["num_voxels"] = int(np.prod(truth_label.shape))
629639
results["voxel_size"] = crop.voxel_size
630640
results["is_missing"] = False
641+
# drop big arrays before returning
642+
del truth_label, pred_label, truth_label_ds
643+
gc.collect()
631644
return crop_name, label_name, results
632645

633646

tests/test_evaluate_metrics.py

Lines changed: 62 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from dataclasses import dataclass
22
from pathlib import Path
33

4+
import cc3d
45
import numpy as np
56
import zarr
67
from fastremap import unique
@@ -114,8 +115,18 @@ def test_roi_hausdorff_identical_instance_is_zero():
114115
truth[2, 2] = tid
115116
pred[2, 2] = tid
116117

118+
truth_stats = cc3d.statistics(truth)
119+
pred_stats = cc3d.statistics(pred)
120+
117121
d = ev.compute_hausdorff_distance_roi(
118-
truth, pred, tid, voxel_size=(1.0, 1.0), max_distance=10.0, method="standard"
122+
truth,
123+
truth_stats,
124+
pred,
125+
pred_stats,
126+
tid,
127+
voxel_size=(1.0, 1.0),
128+
max_distance=10.0,
129+
method="standard",
119130
)
120131
assert np.isclose(d, 0.0)
121132

@@ -133,9 +144,14 @@ def test_roi_hausdorff_matches_full_reference_standard_and_modified():
133144
voxel_size = (1.0, 1.0)
134145
max_distance = 100.0
135146

147+
truth_stats = cc3d.statistics(truth)
148+
pred_stats = cc3d.statistics(pred)
149+
136150
d_roi = ev.compute_hausdorff_distance_roi(
137151
truth,
152+
truth_stats,
138153
pred,
154+
pred_stats,
139155
tid,
140156
voxel_size=voxel_size,
141157
max_distance=max_distance,
@@ -154,7 +170,9 @@ def test_roi_hausdorff_matches_full_reference_standard_and_modified():
154170

155171
d_roi_mod = ev.compute_hausdorff_distance_roi(
156172
truth,
173+
truth_stats,
157174
pred,
175+
pred_stats,
158176
tid,
159177
voxel_size=voxel_size,
160178
max_distance=max_distance,
@@ -188,9 +206,14 @@ def test_roi_hausdorff_percentile_matches_full_reference():
188206
voxel_size = (1.0, 1.0)
189207
max_distance = 100.0
190208

209+
truth_stats = cc3d.statistics(truth)
210+
pred_stats = cc3d.statistics(pred)
211+
191212
d_roi = ev.compute_hausdorff_distance_roi(
192213
truth,
214+
truth_stats,
193215
pred,
216+
pred_stats,
194217
tid,
195218
voxel_size=voxel_size,
196219
max_distance=max_distance,
@@ -221,16 +244,32 @@ def test_roi_hausdorff_empty_sets_and_missing_instance():
221244

222245
# present only in truth -> max_distance
223246
truth[0, 0] = tid
247+
truth_stats = cc3d.statistics(truth)
248+
pred_stats = cc3d.statistics(pred)
224249
d1 = ev.compute_hausdorff_distance_roi(
225-
truth, pred, tid, voxel_size=voxel_size, max_distance=max_distance
250+
truth,
251+
truth_stats,
252+
pred,
253+
pred_stats,
254+
tid,
255+
voxel_size=voxel_size,
256+
max_distance=max_distance,
226257
)
227258
assert np.isclose(d1, max_distance)
228259

229260
# present only in pred -> max_distance
230261
truth[0, 0] = 0
231262
pred[0, 0] = tid
263+
truth_stats = cc3d.statistics(truth)
264+
pred_stats = cc3d.statistics(pred)
232265
d2 = ev.compute_hausdorff_distance_roi(
233-
truth, pred, tid, voxel_size=voxel_size, max_distance=max_distance
266+
truth,
267+
truth_stats,
268+
pred,
269+
pred_stats,
270+
tid,
271+
voxel_size=voxel_size,
272+
max_distance=max_distance,
234273
)
235274
assert np.isclose(d2, max_distance)
236275

@@ -248,9 +287,14 @@ def test_roi_hausdorff_clips_to_max_distance_matches_reference():
248287
voxel_size = (1.0, 1.0)
249288
max_distance = 3.0
250289

290+
truth_stats = cc3d.statistics(truth)
291+
pred_stats = cc3d.statistics(pred)
292+
251293
d_roi = ev.compute_hausdorff_distance_roi(
252294
truth,
295+
truth_stats,
253296
pred,
297+
pred_stats,
254298
tid,
255299
voxel_size=voxel_size,
256300
max_distance=max_distance,
@@ -282,9 +326,14 @@ def test_roi_hausdorff_anisotropic_voxel_size_matches_reference():
282326
voxel_size = (2.0, 0.5) # physical distance = 2 * 0.5 = 1.0
283327
max_distance = 100.0
284328

329+
truth_stats = cc3d.statistics(truth)
330+
pred_stats = cc3d.statistics(pred)
331+
285332
d_roi = ev.compute_hausdorff_distance_roi(
286333
truth,
334+
truth_stats,
287335
pred,
336+
pred_stats,
288337
tid,
289338
voxel_size=voxel_size,
290339
max_distance=max_distance,
@@ -323,8 +372,17 @@ def _fake_roi(*args, **kwargs):
323372
# Make tid present in exactly one volume so we don't trigger the "both absent -> 0" shortcut
324373
truth[2, 2] = tid # present in truth only
325374

375+
truth_stats = cc3d.statistics(truth)
376+
pred_stats = cc3d.statistics(pred)
377+
326378
d = ev.compute_hausdorff_distance_roi(
327-
truth, pred, tid, voxel_size=(1.0, 1.0), max_distance=7.0
379+
truth,
380+
truth_stats,
381+
pred,
382+
pred_stats,
383+
tid,
384+
voxel_size=(1.0, 1.0),
385+
max_distance=7.0,
328386
)
329387
assert np.isclose(d, 7.0)
330388

0 commit comments

Comments
 (0)