Skip to content

Commit db1082f

Browse files
feat: introduce ratio cutoff function and update instance ratio handling in evaluation metrics
1 parent 6af6745 commit db1082f

3 files changed

Lines changed: 87 additions & 75 deletions

File tree

src/cellmap_segmentation_challenge/evaluate.py

Lines changed: 78 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
import zarr
1111
from scipy.spatial.distance import dice
1212
from scipy.ndimage import distance_transform_edt
13-
from fastremap import remap, unique
13+
from fastremap import remap, unique, renumber
1414
import cc3d
1515
from cc3d.types import StatisticsDict, StatisticsSlicesDict
1616

@@ -39,10 +39,22 @@
3939
MAX_SEMANTIC_THREADS = int(os.getenv("MAX_SEMANTIC_THREADS", 25))
4040
PER_INSTANCE_THREADS = int(os.getenv("PER_INSTANCE_THREADS", 25))
4141
MAX_DISTANCE_CAP_EPS = float(os.getenv("MAX_DISTANCE_CAP_EPS", "1e-4"))
42-
INSTANCE_RATIO_CUTOFF = float(os.getenv("INSTANCE_RATIO_CUTOFF", 10))
42+
FINAL_INSTANCE_RATIO_CUTOFF = float(os.getenv("FINAL_INSTANCE_RATIO_CUTOFF", 10))
43+
INITIAL_INSTANCE_RATIO_CUTOFF = float(os.getenv("INITIAL_INSTANCE_RATIO_CUTOFF", 50))
44+
INSTANCE_RATIO_FACTOR = float(os.getenv("INSTANCE_RATIO_FACTOR", 5.0))
4345
MAX_OVERLAP_EDGES = int(os.getenv("MAX_OVERLAP_EDGES", "5000000"))
4446

4547

48+
def ratio_cutoff(
49+
nG: int,
50+
R_base: float = FINAL_INSTANCE_RATIO_CUTOFF,
51+
R_extra: float = INITIAL_INSTANCE_RATIO_CUTOFF,
52+
k: float = INSTANCE_RATIO_FACTOR,
53+
) -> float:
54+
# nG==0 handled upstream (ratio undefined); return max tolerance for completeness
55+
return float(R_base + R_extra * np.exp(-nG / k))
56+
57+
4658
def normalize_distance(distance: float, voxel_size) -> float:
4759
"""
4860
Normalize a distance value to [0, 1] using the maximum distance represented by a voxel
@@ -86,7 +98,7 @@ def match_instances(gt: np.ndarray, pred: np.ndarray) -> dict | None:
8698
logging.info("No GT or Pred instances; returning only background match.")
8799
return {0: 0}
88100

89-
if (nP / nG) > INSTANCE_RATIO_CUTOFF:
101+
if (nP / nG) > ratio_cutoff(nG):
90102
logging.warning(
91103
f"WARNING: Skipping {nP} instances in submission, {nG} in ground truth, "
92104
f"because there are too many instances in the submission."
@@ -412,7 +424,7 @@ def compute_hausdorff_distance_roi(
412424
if a_n == 0 and b_n == 0:
413425
return 0.0
414426
elif a_n == 0 or b_n == 0:
415-
return np.inf
427+
return max_distance
416428

417429
vs = np.asarray(voxel_size, dtype=np.float64)
418430
if vs.size != ndim:
@@ -446,7 +458,7 @@ def score_instance(
446458
truth_label,
447459
voxel_size,
448460
hausdorff_distance_max=None,
449-
) -> dict[str, float]:
461+
) -> dict[str, float | str]:
450462
"""
451463
Score a single instance label volume against the ground truth instance label volume.
452464
@@ -465,13 +477,15 @@ def score_instance(
465477
logging.info("Scoring instance segmentation...")
466478
if hausdorff_distance_max is None:
467479
hausdorff_distance_max = compute_default_max_distance(voxel_size)
468-
logging.info(
480+
logging.debug(
469481
f"Using default maximum Hausdorff distance of {hausdorff_distance_max:.2f} for voxel size {voxel_size}."
470482
)
471483

472484
# Relabel the predicted instance labels to be consistent with the ground truth instance labels
473485
logging.info("Relabeling predicted instance labels...")
474-
pred_label, n_pred = cc3d.connected_components(pred_label, return_N=True)
486+
# pred_label, n_pred = cc3d.connected_components(pred_label, return_N=True)
487+
pred_label, remapping = renumber(pred_label, in_place=True)
488+
n_pred = len(remapping) - 1 # exclude background
475489

476490
# Match instances between ground truth and prediction
477491
mapping = match_instances(truth_label, pred_label)
@@ -480,9 +494,13 @@ def score_instance(
480494
# Too many instances in submission, skip scoring
481495
return {
482496
"accuracy": 0,
483-
"hausdorff_distance": np.inf,
484-
"normalized_hausdorff_distance": 0,
497+
"binary_accuracy": ((truth_label > 0) == (pred_label > 0)).mean(),
498+
"hausdorff_distance": hausdorff_distance_max,
499+
"normalized_hausdorff_distance": normalize_distance(
500+
hausdorff_distance_max, voxel_size
501+
),
485502
"combined_score": 0,
503+
"status": "skipped_too_many_instances",
486504
}
487505
elif len(mapping) == 1 and 0 in mapping:
488506
# Only background present in both ground truth and prediction
@@ -504,23 +522,24 @@ def score_instance(
504522
hausdorff_distances = np.concatenate(
505523
[
506524
hausdorff_distances,
507-
np.full(len(unmatched_pred), np.inf, dtype=np.float32),
525+
np.full(
526+
len(unmatched_pred), hausdorff_distance_max, dtype=np.float32
527+
),
508528
]
509529
)
510530
else:
511531
# No predictions to match (no GT XOR no Pred instances)
512-
hausdorff_distances = []
532+
hausdorff_distances = [hausdorff_distance_max]
533+
534+
if len(hausdorff_distances) == 0:
535+
hausdorff_distances = [hausdorff_distance_max]
513536

514537
# Compute the scores
515538
logging.info("Computing accuracy score...")
516539
accuracy = float((truth_label == pred_label).mean())
517-
hausdorff_dist = (
518-
np.mean(hausdorff_distances) if len(hausdorff_distances) > 0 else np.inf
519-
)
520-
normalized_hausdorff_dist = (
521-
np.mean([normalize_distance(hd, voxel_size) for hd in hausdorff_distances])
522-
if len(hausdorff_distances) > 0
523-
else 0.0
540+
hausdorff_dist = np.mean(hausdorff_distances)
541+
normalized_hausdorff_dist = np.mean(
542+
[normalize_distance(hd, voxel_size) for hd in hausdorff_distances]
524543
)
525544
combined_score = (accuracy * normalized_hausdorff_dist) ** 0.5 # geometric mean
526545
logging.info(f"Accuracy: {accuracy:.4f}")
@@ -532,6 +551,7 @@ def score_instance(
532551
"hausdorff_distance": hausdorff_dist,
533552
"normalized_hausdorff_distance": normalized_hausdorff_dist,
534553
"combined_score": combined_score,
554+
"status": "scored",
535555
} # type: ignore
536556

537557

@@ -566,6 +586,7 @@ def score_semantic(pred_label, truth_label) -> dict[str, float]:
566586
scores = {
567587
"iou": iou_score,
568588
"dice_score": dice_score if not np.isnan(dice_score) else 1,
589+
"status": "scored",
569590
}
570591

571592
logging.info(f"IoU: {scores['iou']:.4f}")
@@ -655,36 +676,29 @@ def score_label(
655676
def empty_label_score(
656677
label, crop_name, instance_classes=INSTANCE_CLASSES, truth_path=TRUTH_PATH
657678
):
679+
truth_path = UPath(truth_path)
680+
ds = zarr.open((truth_path / crop_name / label).path, mode="r")
681+
voxel_size = ds.attrs["voxel_size"]
658682
if label in instance_classes:
659683
truth_path = UPath(truth_path)
660684
return {
661685
"accuracy": 0,
662-
"hausdorff_distance": 0,
686+
"hausdorff_distance": compute_default_max_distance(voxel_size),
663687
"normalized_hausdorff_distance": 0,
664688
"combined_score": 0,
665-
"num_voxels": int(
666-
np.prod(
667-
zarr.open((truth_path / crop_name / label).path, mode="r").shape
668-
)
669-
),
670-
"voxel_size": zarr.open(
671-
(truth_path / crop_name / label).path, mode="r"
672-
).attrs["voxel_size"],
689+
"num_voxels": int(np.prod(ds.shape)),
690+
"voxel_size": voxel_size,
673691
"is_missing": True,
692+
"status": "missing",
674693
}
675694
else:
676695
return {
677696
"iou": 0,
678697
"dice_score": 0,
679-
"num_voxels": int(
680-
np.prod(
681-
zarr.open((truth_path / crop_name / label).path, mode="r").shape
682-
)
683-
),
684-
"voxel_size": zarr.open(
685-
(truth_path / crop_name / label).path, mode="r"
686-
).attrs["voxel_size"],
698+
"num_voxels": int(np.prod(ds.shape)),
699+
"voxel_size": voxel_size,
687700
"is_missing": True,
701+
"status": "missing",
688702
}
689703

690704

@@ -755,55 +769,32 @@ def get_evaluation_args(
755769

756770

757771
def missing_volume_score(
758-
truth_volume_path, instance_classes=INSTANCE_CLASSES
772+
truth_path, volume, instance_classes=INSTANCE_CLASSES
759773
) -> list[tuple]:
760774
"""
761775
Score a missing volume as 0's, congruent with the score_volume function.
762776
763777
Args:
764-
truth_volume_path (str): The path to the ground truth volume.
778+
truth_path (str): The path to the ground truth volume.
779+
volume (str): The name of the volume.
780+
instance_classes (list): A list of instance classes.
765781
766782
Returns:
767783
dict: A dictionary of scores for the volume.
768784
769785
Example usage:
770786
scores = missing_volume_score('truth.zarr', 'test_volume')
771787
"""
772-
logging.info(f"Scoring missing volume {truth_volume_path}...")
773-
truth_volume_path = UPath(truth_volume_path)
788+
logging.info(f"Scoring missing volume {volume}...")
789+
truth_path = UPath(truth_path)
790+
truth_volume_path = truth_path / volume
774791

775792
# Find labels to score
776793
truth_labels = [a for a in ensure_zgroup(truth_volume_path).array_keys()]
777794

778795
# Score each label
779796
scores = {
780-
label: (
781-
{
782-
"accuracy": 0.0,
783-
"hausdorff_distance": 0.0,
784-
"normalized_hausdorff_distance": 0.0,
785-
"combined_score": 0.0,
786-
"num_voxels": int(
787-
np.prod(zarr.open((truth_volume_path / label).path, mode="r").shape)
788-
),
789-
"voxel_size": zarr.open(
790-
(truth_volume_path / label).path, mode="r"
791-
).attrs["voxel_size"],
792-
"is_missing": True,
793-
}
794-
if label in instance_classes
795-
else {
796-
"iou": 0.0,
797-
"dice_score": 0.0,
798-
"num_voxels": int(
799-
np.prod(zarr.open((truth_volume_path / label).path, mode="r").shape)
800-
),
801-
"voxel_size": zarr.open(
802-
(truth_volume_path / label).path, mode="r"
803-
).attrs["voxel_size"],
804-
"is_missing": True,
805-
}
806-
)
797+
label: empty_label_score(label, volume, instance_classes, truth_path)
807798
for label in truth_labels
808799
}
809800

@@ -883,16 +874,30 @@ def combine_scores(
883874
logging.info("Computing overall scores...")
884875
overall_instance_scores = []
885876
overall_semantic_scores = []
877+
instance_total_voxels = sum(
878+
total_voxels[label] for label in label_scores if label in instance_classes
879+
)
880+
semantic_total_voxels = sum(
881+
total_voxels[label] for label in label_scores if label not in instance_classes
882+
)
886883
for label in label_scores:
887884
if label in instance_classes:
888-
overall_instance_scores += [label_scores[label]["combined_score"]]
885+
overall_instance_scores += [
886+
label_scores[label]["combined_score"] * total_voxels[label]
887+
]
889888
else:
890-
overall_semantic_scores += [label_scores[label]["iou"]]
889+
overall_semantic_scores += [
890+
label_scores[label]["iou"] * total_voxels[label]
891+
]
891892
scores["overall_instance_score"] = (
892-
np.nanmean(overall_instance_scores) if overall_instance_scores else 0
893+
np.nansum(overall_instance_scores) / instance_total_voxels
894+
if overall_instance_scores
895+
else 0
893896
)
894897
scores["overall_semantic_score"] = (
895-
np.nanmean(overall_semantic_scores) if overall_semantic_scores else 0
898+
np.nansum(overall_semantic_scores) / semantic_total_voxels
899+
if overall_semantic_scores
900+
else 0
896901
)
897902
scores["overall_score"] = (
898903
scores["overall_instance_score"] * scores["overall_semantic_score"]
@@ -1058,7 +1063,7 @@ def score_submission(
10581063

10591064
scores = {
10601065
volume: missing_volume_score(
1061-
truth_path / volume, instance_classes=instance_classes
1066+
truth_path, volume, instance_classes=instance_classes
10621067
)
10631068
for volume in missing_volumes
10641069
}
@@ -1148,6 +1153,8 @@ def sanitize_scores(scores):
11481153
for key, value in label_scores.items():
11491154
if value is None:
11501155
continue
1156+
if isinstance(value, str):
1157+
continue
11511158
if not np.isscalar(value) and len(value) == 1:
11521159
value = value[0]
11531160
if np.isscalar(value):

tests/test_evaluate_metrics.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -255,7 +255,7 @@ def test_roi_hausdorff_empty_sets_and_missing_instance():
255255
voxel_size=voxel_size,
256256
max_distance=max_distance,
257257
)
258-
assert np.isclose(d1, np.inf)
258+
assert np.isclose(d1, max_distance)
259259

260260

261261
def test_roi_hausdorff_clips_to_max_distance_matches_reference():
@@ -358,6 +358,7 @@ def _fake_roi(*args, **kwargs):
358358

359359
truth_stats = cc3d.statistics(truth)
360360
pred_stats = cc3d.statistics(pred)
361+
max_distance = 7.0
361362

362363
d = ev.compute_hausdorff_distance_roi(
363364
truth,
@@ -366,9 +367,9 @@ def _fake_roi(*args, **kwargs):
366367
pred_stats,
367368
tid,
368369
voxel_size=(1.0, 1.0),
369-
max_distance=7.0,
370+
max_distance=max_distance,
370371
)
371-
assert np.isclose(d, np.inf)
372+
assert np.isclose(d, max_distance)
372373

373374

374375
def test_optimized_hausdorff_distances_per_instance():
@@ -615,7 +616,8 @@ def test_missing_volume_score_mixed_labels(tmp_path):
615616
_create_simple_volume(truth_root, "crop1", "sem", arr_sem)
616617

617618
scores = ev.missing_volume_score(
618-
truth_volume_path=(truth_root / "crop1").as_posix(),
619+
truth_path=truth_root,
620+
volume="crop1",
619621
instance_classes=["instance"],
620622
)
621623

tests/test_match_instance.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,9 @@ def test_ratio_cutoff_returns_none(monkeypatch):
9292
We reload the module after setting env to avoid import-time caching issues.
9393
"""
9494
monkeypatch.setenv("INSTANCE_RATIO_CUTOFF", "1.0")
95+
monkeypatch.setenv("INITIAL_INSTANCE_RATIO_CUTOFF", "1.0")
96+
monkeypatch.setenv("FINAL_INSTANCE_RATIO_CUTOFF", "1.0")
97+
monkeypatch.setenv("INSTANCE_RATIO_FACTOR", "1.0")
9598
monkeypatch.setenv("MAX_OVERLAP_EDGES", "5000000")
9699
ev = _reload_module()
97100

0 commit comments

Comments (0)